├── README.md
├── config
│   └── underwater.json
├── core
│   ├── logger.py
│   ├── metrics.py
│   └── wandb_logger.py
├── data
│   ├── LRHR_dataset.py
│   ├── __init__.py
│   ├── prepare_data.py
│   └── util.py
├── dataset
│   └── water_val_16_256
│       ├── hr_256
│       │   ├── 00001.png
│       │   ├── 00002.png
│       │   ├── 00003.png
│       │   ├── 00004.png
│       │   ├── 00005.png
│       │   ├── 00006.png
│       │   ├── 00007.png
│       │   ├── 00008.png
│       │   ├── 00009.png
│       │   ├── 00010.png
│       │   ├── 00011.png
│       │   ├── 00012.png
│       │   ├── 00013.png
│       │   ├── 00014.png
│       │   ├── 00015.png
│       │   ├── 00016.png
│       │   ├── 00017.png
│       │   ├── 00018.png
│       │   ├── 00019.png
│       │   └── 00020.png
│       └── sr_16_256
│           ├── 00001.png
│           ├── 00002.png
│           ├── 00003.png
│           ├── 00004.png
│           ├── 00005.png
│           ├── 00006.png
│           ├── 00007.png
│           ├── 00008.png
│           ├── 00009.png
│           ├── 00010.png
│           ├── 00011.png
│           ├── 00012.png
│           ├── 00013.png
│           ├── 00014.png
│           ├── 00015.png
│           ├── 00016.png
│           ├── 00017.png
│           ├── 00018.png
│           ├── 00019.png
│           └── 00020.png
├── experiments_supervised
│   └── model
├── infer.py
├── model
│   ├── __init__.py
│   ├── base_model.py
│   ├── ddpm_trans_modules
│   │   ├── diffusion.py
│   │   ├── style_transfer.py
│   │   ├── trans_block_eca.py
│   │   └── unet.py
│   ├── model.py
│   ├── networks.py
│   └── utils.py
├── requirement.txt
├── search_diffusion.py
├── tester_water.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # DM_underwater
2 | This is the code for the paper "Underwater Image Enhancement by Transformer-based Diffusion Model with Non-uniform Sampling for Skip Strategy".
3 | 
4 | Usage steps:
5 | 
6 | - Install the necessary Python packages from requirement.txt.
7 | - Put your data into the dataset folder (sample data is already provided there).
8 | - Download the pre-trained model from https://drive.google.com/file/d/1As3Pd8W6XmQBU__83iYtBT5vssoZHSqn/view?usp=sharing and put it in the experiments_supervised folder.
9 | - Run infer.py to get the inference results in a new folder called experiments. See the example commands below.
10 | - To switch to the training setup, comment/uncomment lines 13 and 14 in config/underwater.json (training from scratch vs. resuming from the pre-trained state), then run train.py.
11 | - search_diffusion.py searches the sequence of time steps with the evolutionary algorithm; it can be used in the inference process.
12 | 
13 | P.S. The author is so lazy that he doesn't want to write down more instructions.
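
For example, a typical install-and-infer session looks like this (the infer.py flags follow its argument parser as defined in infer.py; train.py is assumed to accept the same -c/--config flag):

```
pip install -r requirement.txt
python infer.py -c config/underwater.json -p val
python train.py -c config/underwater.json
```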
14 | 
15 | 
16 | 
--------------------------------------------------------------------------------
/config/underwater.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "sr_ffhq",
3 |     "phase": "train", // train or val
4 |     "gpu_ids": [
5 |         0
6 |     ],
7 |     "path": { // set the paths
8 |         "stage": "train", // change between the train and val folders for saving
9 |         "log": "logs",
10 |         "tb_logger": "tb_logger",
11 |         "results": "results",
12 |         "checkpoint": "checkpoint",
13 |         //"resume_state": null
14 |         "resume_state": "experiments_supervised/I950000_E3369" // pre-trained model or training state
15 |     },
16 |     "datasets": {
17 |         "train": {
18 |             "name": "Water",
19 |             "mode": "HR", // whether the LR img is needed
20 |             "dataroot": "dataset/water_train_16_128",
21 |             "datatype": "img", //lmdb or img, path of img files
22 |             "l_resolution": 16, // low resolution to be super-resolved
23 |             "r_resolution": 128, // high resolution
24 |             "batch_size": 10,
25 |             "num_workers": 8,
26 |             "use_shuffle": true,
27 |             "data_len": -1 // -1 represents all data used in train
28 |         },
29 |         "val": {
30 |             "name": "Water",
31 |             "mode": "HR",
32 |             "dataroot": "dataset/water_val_16_256",
33 |             "datatype": "img", //lmdb or img, path of img files
34 |             "l_resolution": 16,
35 |             "r_resolution": 256,
36 |             "data_len": -1 // data length in validation
37 |         }
38 |     },
39 |     "model": {
40 |         "which_model_G": "trans", // use the ddpm or sr3 network structure
41 |         "finetune_norm": false,
42 |         "unet": {
43 |             "in_channel": 6,
44 |             "out_channel": 3,
45 |             "inner_channel": 48,
46 |             "norm_groups": 24,
47 |             "channel_multiplier": [
48 |                 1,
49 |                 2,
50 |                 4,
51 |                 8,
52 |                 8
53 |             ],
54 |             "attn_res": [
55 |                 16
56 |             ],
57 |             "res_blocks": 2,
58 |             "dropout": 0.2
59 |         },
60 |         "beta_schedule": { // use manual beta_schedule for acceleration
61 |             "train": {
62 |                 "schedule": "linear",
63 |                 "n_timestep": 2000,
64 |                 "linear_start": 1e-6,
65 |                 "linear_end": 1e-2
66 |             },
67 |             "val": {
68 |                 "schedule": "linear",
69 |                 "n_timestep": 2000,
70 |                 "linear_start": 1e-6,
71 |                 "linear_end": 1e-2
72 |             }
73 |         },
74 |         "diffusion": {
75 |             "image_size": 256,
76 |             "channels": 3, //sample channel
77 |             "conditional": true // unconditional or conditional generation (super_resolution)
78 |         }
79 |     },
80 |     "train": {
81 |         "n_iter": 1000000,
82 |         "val_freq": 50000,
83 |         "save_checkpoint_freq": 50000,
84 |         "print_freq": 200,
85 |         "optimizer": {
86 |             "type": "adam",
87 |             "lr": 1e-4
88 |         },
89 |         "ema_scheduler": { // not used now
90 |             "step_start_ema": 5000,
91 |             "update_ema_every": 1,
92 |             "ema_decay": 0.9999
93 |         }
94 |     },
95 |     "wandb": {
96 |         "project": "sr_ffhq"
97 |     }
98 | }
99 | 
--------------------------------------------------------------------------------
/core/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | from collections import OrderedDict
5 | import json
6 | from datetime import datetime
7 | 
8 | 
9 | def mkdirs(paths):
10 |     if isinstance(paths, str):
11 |         os.makedirs(paths, exist_ok=True)
12 |     else:
13 |         for path in paths:
14 |             os.makedirs(path, exist_ok=True)
15 | 
16 | 
17 | def get_timestamp():
18 |     return datetime.now().strftime('%y%m%d_%H%M%S')
19 | 
20 | 
21 | def parse(args):
22 |     phase = args.phase
23 |     opt_path = args.config
24 |     gpu_ids = args.gpu_ids
25 |     enable_wandb = args.enable_wandb
26 |     # remove comments starting with '//'
27 |     json_str = ''
28 |     with open(opt_path, 'r') as f:
29 |         for line in f:
30 |             line = line.split('//')[0] + '\n'
31 |             json_str += line
32 |     opt = json.loads(json_str, object_pairs_hook=OrderedDict)
33 | 
34 |     # set log directory
35 |     if args.debug:
36 |         opt['name'] = 'debug_{}'.format(opt['name'])
37 |     experiments_root = os.path.join(
38 |         'experiments_{}'.format(opt['path']['stage']), '{}_{}'.format(opt['name'], get_timestamp()))
39 |     opt['path']['experiments_root'] = experiments_root
40 |     for key, path in opt['path'].items():
41 |         if 'resume' not in key and 'experiments' not in key:
42 |             opt['path'][key] = os.path.join(experiments_root, path)
43 |             mkdirs(opt['path'][key])
44 | 
45 |     # record the running phase
46 |     opt['phase'] = phase
47 | 
48 |     # export CUDA_VISIBLE_DEVICES
49 |     if gpu_ids is not None:
50 |         opt['gpu_ids'] = [int(gpu_id) for gpu_id in gpu_ids.split(',')]
51 |         gpu_list = gpu_ids
52 |     else:
53 |         gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
54 |     os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
55 |     print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
56 |     if len(gpu_list) > 1:  # a multi-GPU string such as '0,1' is longer than one character
57 |         opt['distributed'] = True
58 |     else:
59 |         opt['distributed'] = False
60 | 
61 |     # debug
62 |     if 'debug' in opt['name']:
63 |         opt['train']['val_freq'] = 2
64 |         opt['train']['print_freq'] = 2
65 |         opt['train']['save_checkpoint_freq'] = 3
66 |         opt['datasets']['train']['batch_size'] = 2
67 |         opt['model']['beta_schedule']['train']['n_timestep'] = 10
68 |         opt['model']['beta_schedule']['val']['n_timestep'] = 10
69 |         opt['datasets']['train']['data_len'] = 6
70 |         opt['datasets']['val']['data_len'] = 3
71 | 
72 |     # validation in train phase
73 |     if phase == 'train':
74 |         opt['datasets']['val']['data_len'] = 3
75 | 
76 |     # W&B Logging (these flags only exist on some entry scripts' parsers)
77 |     try:
78 |         log_wandb_ckpt = args.log_wandb_ckpt
79 |         opt['log_wandb_ckpt'] = log_wandb_ckpt
80 |     except AttributeError:
81 |         pass
82 |     try:
83 |         log_eval = args.log_eval
84 |         opt['log_eval'] = log_eval
85 |     except AttributeError:
86 |         pass
87 |     try:
88 |         log_infer = args.log_infer
89 |         opt['log_infer'] = log_infer
90 |     except AttributeError:
91 |         pass
92 |     opt['enable_wandb'] = enable_wandb
93 | 
94 |     return opt
95 | 
96 | 
97 | class NoneDict(dict):
98 |     def __missing__(self, key):
99 |         return None
100 | 
101 | 
102 | # convert to NoneDict, which returns None for missing keys.
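# Illustrative usage sketch (an addition, mirroring how infer.py chains these
# helpers; `args` is an argparse.Namespace):
#     opt = parse(args)
#     opt = dict_to_nonedict(opt)
#     opt['some_missing_key']  # -> None instead of raising KeyError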
103 | def dict_to_nonedict(opt):
104 |     if isinstance(opt, dict):
105 |         new_opt = dict()
106 |         for key, sub_opt in opt.items():
107 |             new_opt[key] = dict_to_nonedict(sub_opt)
108 |         return NoneDict(**new_opt)
109 |     elif isinstance(opt, list):
110 |         return [dict_to_nonedict(sub_opt) for sub_opt in opt]
111 |     else:
112 |         return opt
113 | 
114 | 
115 | def dict2str(opt, indent_l=1):
116 |     '''dict to string for logger'''
117 |     msg = ''
118 |     for k, v in opt.items():
119 |         if isinstance(v, dict):
120 |             msg += ' ' * (indent_l * 2) + k + ':[\n'
121 |             msg += dict2str(v, indent_l + 1)
122 |             msg += ' ' * (indent_l * 2) + ']\n'
123 |         else:
124 |             msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
125 |     return msg
126 | 
127 | 
128 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False):
129 |     '''set up logger'''
130 |     l = logging.getLogger(logger_name)
131 |     formatter = logging.Formatter(
132 |         '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
133 |     log_file = os.path.join(root, '{}.log'.format(phase))
134 |     fh = logging.FileHandler(log_file, mode='w')
135 |     fh.setFormatter(formatter)
136 |     l.setLevel(level)
137 |     l.addHandler(fh)
138 |     if screen:
139 |         sh = logging.StreamHandler()
140 |         sh.setFormatter(formatter)
141 |         l.addHandler(sh)
142 | 
--------------------------------------------------------------------------------
/core/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | import cv2
5 | from torchvision.utils import make_grid
6 | 
7 | 
8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
9 |     '''
10 |     Converts a torch Tensor into an image Numpy array
11 |     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
12 |     Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
13 |     '''
14 |     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
15 |     tensor = (tensor - min_max[0]) / \
16 |         (min_max[1] - min_max[0])  # to range [0,1]
17 |     n_dim = tensor.dim()
18 |     if n_dim == 4:
19 |         n_img = len(tensor)
20 |         img_np = make_grid(tensor, nrow=int(
21 |             math.sqrt(n_img)), normalize=False).numpy()
22 |         img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
23 |     elif n_dim == 3:
24 |         img_np = tensor.numpy()
25 |         img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
26 |     elif n_dim == 2:
27 |         img_np = tensor.numpy()
28 |     else:
29 |         raise TypeError(
30 |             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
31 |     if out_type == np.uint8:
32 |         img_np = (img_np * 255.0).round()
33 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
34 |     return img_np.astype(out_type)
35 | 
36 | 
37 | def save_img(img, img_path, mode='RGB'):  # `mode` is unused; input is assumed RGB
38 |     cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
39 |     # cv2.imwrite(img_path, img)
40 | 
41 | 
42 | def calculate_psnr(img1, img2):
43 |     # img1 and img2 have range [0, 255]
44 |     img1 = img1.astype(np.float64)
45 |     img2 = img2.astype(np.float64)
46 |     mse = np.mean((img1 - img2)**2)
47 |     if mse == 0:
48 |         return float('inf')
49 |     return 20 * math.log10(255.0 / math.sqrt(mse))
50 | 
51 | 
52 | def ssim(img1, img2):
53 |     C1 = (0.01 * 255)**2
54 |     C2 = (0.03 * 255)**2
55 | 
56 |     img1 = img1.astype(np.float64)
57 |     img2 = img2.astype(np.float64)
58 |     kernel = cv2.getGaussianKernel(11, 1.5)
59 |     window = np.outer(kernel, kernel.transpose())
60 | 
61 |     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
62 |     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
63 |     mu1_sq = mu1**2
64 |     mu2_sq = mu2**2
65 |     mu1_mu2 = mu1 * mu2
66 |     sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
67 |     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
68 |     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
69 | 
70 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
71 |                                                             (sigma1_sq + sigma2_sq + C2))
72 |     return ssim_map.mean()
73 | 
74 | 
75 | def calculate_ssim(img1, img2):
76 |     '''calculate SSIM
77 |     the same outputs as MATLAB's
78 |     img1, img2: [0, 255]
79 |     '''
80 |     if not img1.shape == img2.shape:
81 |         raise ValueError('Input images must have the same dimensions.')
82 |     if img1.ndim == 2:
83 |         return ssim(img1, img2)
84 |     elif img1.ndim == 3:
85 |         if img1.shape[2] == 3:
86 |             ssims = []
87 |             for i in range(3):
88 |                 ssims.append(ssim(img1[:, :, i], img2[:, :, i]))  # per-channel SSIM, averaged below
89 |             return np.array(ssims).mean()
90 |         elif img1.shape[2] == 1:
91 |             return ssim(np.squeeze(img1), np.squeeze(img2))
92 |     else:
93 |         raise ValueError('Wrong input image dimensions.')
94 | 
--------------------------------------------------------------------------------
/core/wandb_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | class WandbLogger:
4 |     """
5 |     Log using `Weights and Biases`.
6 |     """
7 |     def __init__(self, opt):
8 |         try:
9 |             import wandb
10 |         except ImportError:
11 |             raise ImportError(
12 |                 "To use the Weights and Biases Logger please install wandb. "
13 |                 "Run `pip install wandb` to install it."
14 |             )
15 | 
16 |         self._wandb = wandb
17 | 
18 |         # Initialize a W&B run
19 |         if self._wandb.run is None:
20 |             self._wandb.init(
21 |                 project=opt['wandb']['project'],
22 |                 config=opt,
23 |                 dir='./experiments'
24 |             )
25 | 
26 |         self.config = self._wandb.config
27 | 
28 |         if self.config.get('log_eval', None):
29 |             self.eval_table = self._wandb.Table(columns=['fake_image',
30 |                                                          'sr_image',
31 |                                                          'hr_image',
32 |                                                          'psnr',
33 |                                                          'ssim'])
34 |         else:
35 |             self.eval_table = None
36 | 
37 |         if self.config.get('log_infer', None):
38 |             self.infer_table = self._wandb.Table(columns=['fake_image',
39 |                                                           'sr_image',
40 |                                                           'hr_image'])
41 |         else:
42 |             self.infer_table = None
43 | 
44 |     def log_metrics(self, metrics, commit=True):
45 |         """
46 |         Log train/validation metrics to W&B.
47 | 
48 |         metrics: dictionary of metrics to be logged
49 |         """
50 |         self._wandb.log(metrics, commit=commit)
51 | 
52 |     def log_image(self, key_name, image_array):
53 |         """
54 |         Log an image array to W&B.
55 | 
56 |         key_name: name of the key
57 |         image_array: numpy array of the image.
58 | """ 59 | self._wandb.log({key_name: self._wandb.Image(image_array)}) 60 | 61 | def log_images(self, key_name, list_images): 62 | """ 63 | Log list of image array onto W&B 64 | 65 | key_name: name of the key 66 | list_images: list of numpy image arrays 67 | """ 68 | self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]}) 69 | 70 | def log_checkpoint(self, current_epoch, current_step): 71 | """ 72 | Log the model checkpoint as W&B artifacts 73 | 74 | current_epoch: the current epoch 75 | current_step: the current batch step 76 | """ 77 | model_artifact = self._wandb.Artifact( 78 | self._wandb.run.id + "_model", type="model" 79 | ) 80 | 81 | gen_path = os.path.join( 82 | self.config.path['checkpoint'], 'I{}_E{}_gen.pth'.format(current_step, current_epoch)) 83 | opt_path = os.path.join( 84 | self.config.path['checkpoint'], 'I{}_E{}_opt.pth'.format(current_step, current_epoch)) 85 | 86 | model_artifact.add_file(gen_path) 87 | model_artifact.add_file(opt_path) 88 | self._wandb.log_artifact(model_artifact, aliases=["latest"]) 89 | 90 | def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None): 91 | """ 92 | Add data row-wise to the initialized table. 93 | """ 94 | if psnr is not None and ssim is not None: 95 | self.eval_table.add_data( 96 | self._wandb.Image(fake_img), 97 | self._wandb.Image(sr_img), 98 | self._wandb.Image(hr_img), 99 | psnr, 100 | ssim 101 | ) 102 | else: 103 | self.infer_table.add_data( 104 | self._wandb.Image(fake_img), 105 | self._wandb.Image(sr_img), 106 | self._wandb.Image(hr_img) 107 | ) 108 | 109 | def log_eval_table(self, commit=False): 110 | """ 111 | Log the table 112 | """ 113 | if self.eval_table: 114 | self._wandb.log({'eval_data': self.eval_table}, commit=commit) 115 | elif self.infer_table: 116 | self._wandb.log({'infer_data': self.infer_table}, commit=commit) 117 | -------------------------------------------------------------------------------- /data/LRHR_dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import lmdb 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | import random 6 | import data.util as Util 7 | import numpy as np 8 | 9 | class LRHRDataset(Dataset): 10 | def __init__(self, dataroot, datatype, l_resolution=16, r_resolution=128, split='train', data_len=-1, need_LR=False): 11 | self.datatype = datatype 12 | self.l_res = l_resolution 13 | self.r_res = r_resolution 14 | self.data_len = data_len 15 | self.need_LR = need_LR 16 | self.split = split 17 | 18 | if datatype == 'lmdb': 19 | self.env = lmdb.open(dataroot, readonly=True, lock=False, 20 | readahead=False, meminit=False) 21 | # init the datalen 22 | with self.env.begin(write=False) as txn: 23 | self.dataset_len = int(txn.get("length".encode("utf-8"))) 24 | if self.data_len <= 0: 25 | self.data_len = self.dataset_len 26 | else: 27 | self.data_len = min(self.data_len, self.dataset_len) 28 | elif datatype == 'img': 29 | self.sr_path = Util.get_paths_from_images( 30 | '{}/sr_{}_{}'.format(dataroot, l_resolution, r_resolution)) 31 | self.hr_path = Util.get_paths_from_images( 32 | '{}/hr_{}'.format(dataroot, r_resolution)) 33 | if self.need_LR: 34 | self.lr_path = Util.get_paths_from_images( 35 | '{}/lr_{}'.format(dataroot, l_resolution)) 36 | self.dataset_len = len(self.hr_path) 37 | if self.data_len <= 0: 38 | self.data_len = self.dataset_len 39 | else: 40 | self.data_len = min(self.data_len, self.dataset_len) 41 | else: 42 | raise NotImplementedError( 43 | 'data_type 
[{:s}] is not recognized.'.format(datatype)) 44 | 45 | def __len__(self): 46 | return self.data_len 47 | 48 | def __getitem__(self, index): 49 | img_HR = None 50 | img_LR = None 51 | 52 | if self.datatype == 'lmdb': 53 | with self.env.begin(write=False) as txn: 54 | hr_img_bytes = txn.get( 55 | 'hr_{}_{}'.format( 56 | self.r_res, str(index).zfill(5)).encode('utf-8') 57 | ) 58 | sr_img_bytes = txn.get( 59 | 'sr_{}_{}_{}'.format( 60 | self.l_res, self.r_res, str(index).zfill(5)).encode('utf-8') 61 | ) 62 | if self.need_LR: 63 | lr_img_bytes = txn.get( 64 | 'lr_{}_{}'.format( 65 | self.l_res, str(index).zfill(5)).encode('utf-8') 66 | ) 67 | # skip the invalid index 68 | while (hr_img_bytes is None) or (sr_img_bytes is None): 69 | new_index = random.randint(0, self.data_len-1) 70 | hr_img_bytes = txn.get( 71 | 'hr_{}_{}'.format( 72 | self.r_res, str(new_index).zfill(5)).encode('utf-8') 73 | ) 74 | sr_img_bytes = txn.get( 75 | 'sr_{}_{}_{}'.format( 76 | self.l_res, self.r_res, str(new_index).zfill(5)).encode('utf-8') 77 | ) 78 | if self.need_LR: 79 | lr_img_bytes = txn.get( 80 | 'lr_{}_{}'.format( 81 | self.l_res, str(new_index).zfill(5)).encode('utf-8') 82 | ) 83 | img_HR = Image.open(BytesIO(hr_img_bytes)).convert("RGB") 84 | img_SR = Image.open(BytesIO(sr_img_bytes)).convert("RGB") 85 | if self.need_LR: 86 | img_LR = Image.open(BytesIO(lr_img_bytes)).convert("RGB") 87 | else: 88 | img_HR = Image.open(self.hr_path[index]).convert("RGB") 89 | img_SR = Image.open(self.sr_path[index]).convert("RGB") 90 | if self.need_LR: 91 | img_LR = Image.open(self.lr_path[index]).convert("RGB") 92 | if self.need_LR: 93 | [img_LR, img_SR, img_HR] = Util.transform_augment( 94 | [img_LR, img_SR, img_HR], split=self.split, min_max=(-1, 1)) 95 | return {'LR': img_LR, 'HR': img_HR, 'SR': img_SR, 'Index': index} 96 | else: 97 | [img_SR, img_HR] = Util.transform_augment( 98 | [img_SR, img_HR], split=self.split, min_max=(-1, 1)) 99 | return {'HR': img_HR, 'SR': img_SR, 'Index': index} 100 | 101 | class LRHRDataset2(Dataset): 102 | def __init__(self, dataroot, datatype, l_resolution=16, r_resolution=128, split='train', data_len=-1, need_LR=False): 103 | self.datatype = datatype 104 | self.l_res = l_resolution 105 | self.r_res = r_resolution 106 | self.data_len = data_len 107 | self.need_LR = need_LR 108 | self.split = split 109 | 110 | if datatype == 'img': 111 | self.sr_path = Util.get_paths_from_images( 112 | '{}/sr_{}_{}'.format(dataroot, l_resolution, r_resolution)) 113 | self.hr_path = Util.get_paths_from_images( 114 | '{}/hr_{}'.format(dataroot, r_resolution)) 115 | self.style_path = Util.get_paths_from_images( 116 | '{}/style_{}'.format(dataroot, r_resolution)) 117 | if self.need_LR: 118 | self.lr_path = Util.get_paths_from_images( 119 | '{}/lr_{}'.format(dataroot, l_resolution)) 120 | self.dataset_len = len(self.hr_path) 121 | if self.data_len <= 0: 122 | self.data_len = self.dataset_len 123 | else: 124 | self.data_len = min(self.data_len, self.dataset_len) 125 | else: 126 | raise NotImplementedError( 127 | 'data_type [{:s}] is not recognized.'.format(datatype)) 128 | 129 | def __len__(self): 130 | return self.data_len 131 | 132 | def __getitem__(self, index): 133 | img_HR = None 134 | img_LR = None 135 | # index_style = np.random.randint(0, self.data_len) 136 | 137 | # img_HR = Image.open(self.hr_path[index]).convert("RGB").resize((128, 128)) 138 | # img_SR = Image.open(self.sr_path[index]).convert("RGB").resize((128, 128)) 139 | # img_style = Image.open(self.style_path[index]).convert("RGB").resize((64, 64)) 
140 | 141 | img_HR = Image.open(self.hr_path[index]).convert("RGB") 142 | img_SR = Image.open(self.sr_path[index]).convert("RGB") 143 | # img_style = Image.open(self.style_path[index]).convert("RGB") 144 | img_style = Image.open(self.sr_path[index]).convert("RGB") 145 | if self.need_LR: 146 | img_LR = Image.open(self.lr_path[index]).convert("RGB") 147 | if self.need_LR: 148 | [img_LR, img_SR, img_HR, img_style] = Util.transform_augment( 149 | [img_LR, img_SR, img_HR, img_style], split=self.split, min_max=(-1, 1)) 150 | return {'LR': img_LR, 'HR': img_HR, 'SR': img_SR, 'style': img_style, 'Index': index} 151 | else: 152 | [img_SR, img_HR, img_style] = Util.transform_augment( 153 | [img_SR, img_HR, img_style], split=self.split, min_max=(-1, 1)) 154 | return {'HR': img_HR, 'SR': img_SR, 'style': img_style, 'Index': index} 155 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | '''create dataset and dataloader''' 2 | import logging 3 | from re import split 4 | import torch.utils.data 5 | 6 | 7 | def create_dataloader(dataset, dataset_opt, phase): 8 | '''create dataloader ''' 9 | if phase == 'train': 10 | return torch.utils.data.DataLoader( 11 | dataset, 12 | batch_size=dataset_opt['batch_size'], 13 | shuffle=dataset_opt['use_shuffle'], 14 | num_workers=dataset_opt['num_workers'], 15 | pin_memory=True) 16 | elif phase == 'val': 17 | return torch.utils.data.DataLoader( 18 | dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 19 | else: 20 | raise NotImplementedError( 21 | 'Dataloader [{:s}] is not found.'.format(phase)) 22 | 23 | 24 | def create_dataset(dataset_opt, phase): 25 | '''create dataset''' 26 | mode = dataset_opt['mode'] 27 | from data.LRHR_dataset import LRHRDataset2 as D 28 | dataset = D(dataroot=dataset_opt['dataroot'], 29 | datatype=dataset_opt['datatype'], 30 | l_resolution=dataset_opt['l_resolution'], 31 | r_resolution=dataset_opt['r_resolution'], 32 | split=phase, 33 | data_len=dataset_opt['data_len'], 34 | need_LR=(mode == 'LRHR') 35 | ) 36 | logger = logging.getLogger('base') 37 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 38 | dataset_opt['name'])) 39 | return dataset 40 | -------------------------------------------------------------------------------- /data/prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | import multiprocessing 4 | from multiprocessing import Lock, Process, RawValue 5 | from functools import partial 6 | from multiprocessing.sharedctypes import RawValue 7 | from PIL import Image 8 | from tqdm import tqdm 9 | from torchvision.transforms import functional as trans_fn 10 | import os 11 | from pathlib import Path 12 | import lmdb 13 | import numpy as np 14 | import time 15 | 16 | 17 | def resize_and_convert(img, size, resample): 18 | if(img.size[0] != size): 19 | img = trans_fn.resize(img, size, resample) 20 | img = trans_fn.center_crop(img, size) 21 | return img 22 | 23 | 24 | def image_convert_bytes(img): 25 | buffer = BytesIO() 26 | img.save(buffer, format='png') 27 | return buffer.getvalue() 28 | 29 | 30 | def resize_multiple(img, sizes=(16, 128), resample=Image.BICUBIC, lmdb_save=False): 31 | lr_img = resize_and_convert(img, sizes[0], resample) 32 | hr_img = resize_and_convert(img, sizes[1], resample) 33 | sr_img = resize_and_convert(lr_img, sizes[1], resample) 34 | 35 | 
if lmdb_save: 36 | lr_img = image_convert_bytes(lr_img) 37 | hr_img = image_convert_bytes(hr_img) 38 | sr_img = image_convert_bytes(sr_img) 39 | 40 | return [lr_img, hr_img, sr_img] 41 | 42 | def resize_worker(img_file, sizes, resample, lmdb_save=False): 43 | img = Image.open(img_file) 44 | img = img.convert('RGB') 45 | out = resize_multiple( 46 | img, sizes=sizes, resample=resample, lmdb_save=lmdb_save) 47 | 48 | return img_file.name.split('.')[0], out 49 | 50 | class WorkingContext(): 51 | def __init__(self, resize_fn, lmdb_save, out_path, env, sizes): 52 | self.resize_fn = resize_fn 53 | self.lmdb_save = lmdb_save 54 | self.out_path = out_path 55 | self.env = env 56 | self.sizes = sizes 57 | 58 | self.counter = RawValue('i', 0) 59 | self.counter_lock = Lock() 60 | 61 | def inc_get(self): 62 | with self.counter_lock: 63 | self.counter.value += 1 64 | return self.counter.value 65 | 66 | def value(self): 67 | with self.counter_lock: 68 | return self.counter.value 69 | 70 | def prepare_process_worker(wctx, file_subset): 71 | for file in file_subset: 72 | i, imgs = wctx.resize_fn(file) 73 | lr_img, hr_img, sr_img = imgs 74 | if not wctx.lmdb_save: 75 | lr_img.save( 76 | '{}/lr_{}/{}.png'.format(wctx.out_path, wctx.sizes[0], i.zfill(5))) 77 | hr_img.save( 78 | '{}/hr_{}/{}.png'.format(wctx.out_path, wctx.sizes[1], i.zfill(5))) 79 | sr_img.save( 80 | '{}/sr_{}_{}/{}.png'.format(wctx.out_path, wctx.sizes[0], wctx.sizes[1], i.zfill(5))) 81 | else: 82 | with wctx.env.begin(write=True) as txn: 83 | txn.put('lr_{}_{}'.format( 84 | wctx.sizes[0], i.zfill(5)).encode('utf-8'), lr_img) 85 | txn.put('hr_{}_{}'.format( 86 | wctx.sizes[1], i.zfill(5)).encode('utf-8'), hr_img) 87 | txn.put('sr_{}_{}_{}'.format( 88 | wctx.sizes[0], wctx.sizes[1], i.zfill(5)).encode('utf-8'), sr_img) 89 | curr_total = wctx.inc_get() 90 | if wctx.lmdb_save: 91 | with wctx.env.begin(write=True) as txn: 92 | txn.put('length'.encode('utf-8'), str(curr_total).encode('utf-8')) 93 | 94 | def all_threads_inactive(worker_threads): 95 | for thread in worker_threads: 96 | if thread.is_alive(): 97 | return False 98 | return True 99 | 100 | def prepare(img_path, out_path, n_worker, sizes=(16, 128), resample=Image.BICUBIC, lmdb_save=False): 101 | resize_fn = partial(resize_worker, sizes=sizes, 102 | resample=resample, lmdb_save=lmdb_save) 103 | files = [p for p in Path( 104 | '{}'.format(img_path)).glob(f'**/*')] 105 | 106 | if not lmdb_save: 107 | os.makedirs(out_path, exist_ok=True) 108 | os.makedirs('{}/lr_{}'.format(out_path, sizes[0]), exist_ok=True) 109 | os.makedirs('{}/hr_{}'.format(out_path, sizes[1]), exist_ok=True) 110 | os.makedirs('{}/sr_{}_{}'.format(out_path, 111 | sizes[0], sizes[1]), exist_ok=True) 112 | else: 113 | env = lmdb.open(out_path, map_size=1024 ** 4, readahead=False) 114 | 115 | if n_worker > 1: 116 | # prepare data subsets 117 | multi_env = None 118 | if lmdb_save: 119 | multi_env = env 120 | 121 | file_subsets = np.array_split(files, n_worker) 122 | worker_threads = [] 123 | wctx = WorkingContext(resize_fn, lmdb_save, out_path, multi_env, sizes) 124 | 125 | # start worker processes, monitor results 126 | for i in range(n_worker): 127 | proc = Process(target=prepare_process_worker, args=(wctx, file_subsets[i])) 128 | proc.start() 129 | worker_threads.append(proc) 130 | 131 | total_count = str(len(files)) 132 | while not all_threads_inactive(worker_threads): 133 | print("\r{}/{} images processed".format(wctx.value(), total_count), end=" ") 134 | time.sleep(0.1) 135 | 136 | else: 137 | total = 0 138 | for file 
in tqdm(files): 139 | i, imgs = resize_fn(file) 140 | lr_img, hr_img, sr_img = imgs 141 | if not lmdb_save: 142 | lr_img.save( 143 | '{}/lr_{}/{}.png'.format(out_path, sizes[0], i.zfill(5))) 144 | hr_img.save( 145 | '{}/hr_{}/{}.png'.format(out_path, sizes[1], i.zfill(5))) 146 | sr_img.save( 147 | '{}/sr_{}_{}/{}.png'.format(out_path, sizes[0], sizes[1], i.zfill(5))) 148 | else: 149 | with env.begin(write=True) as txn: 150 | txn.put('lr_{}_{}'.format( 151 | sizes[0], i.zfill(5)).encode('utf-8'), lr_img) 152 | txn.put('hr_{}_{}'.format( 153 | sizes[1], i.zfill(5)).encode('utf-8'), hr_img) 154 | txn.put('sr_{}_{}_{}'.format( 155 | sizes[0], sizes[1], i.zfill(5)).encode('utf-8'), sr_img) 156 | total += 1 157 | if lmdb_save: 158 | with env.begin(write=True) as txn: 159 | txn.put('length'.encode('utf-8'), str(total).encode('utf-8')) 160 | 161 | if __name__ == '__main__': 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument('--path', '-p', type=str, 164 | default='{}/Dataset/celebahq_256'.format(Path.home())) 165 | parser.add_argument('--out', '-o', type=str, 166 | default='./dataset/celebahq') 167 | 168 | parser.add_argument('--size', type=str, default='64,512') 169 | parser.add_argument('--n_worker', type=int, default=3) 170 | parser.add_argument('--resample', type=str, default='bicubic') 171 | # default save in png format 172 | parser.add_argument('--lmdb', '-l', action='store_true') 173 | 174 | args = parser.parse_args() 175 | 176 | resample_map = {'bilinear': Image.BILINEAR, 'bicubic': Image.BICUBIC} 177 | resample = resample_map[args.resample] 178 | sizes = [int(s.strip()) for s in args.size.split(',')] 179 | 180 | args.out = '{}_{}_{}'.format(args.out, sizes[0], sizes[1]) 181 | prepare(args.path, args.out, args.n_worker, 182 | sizes=sizes, resample=resample, lmdb_save=args.lmdb) 183 | -------------------------------------------------------------------------------- /data/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import random 5 | import numpy as np 6 | 7 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', 8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 9 | 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | 15 | def get_paths_from_images(path): 16 | # assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 17 | images = [] 18 | for dirpath, _, fnames in sorted(os.walk(path)): 19 | for fname in sorted(fnames): 20 | if is_image_file(fname): 21 | img_path = os.path.join(dirpath, fname) 22 | images.append(img_path) 23 | # assert images, '{:s} has no valid image file'.format(path) 24 | return sorted(images) 25 | 26 | 27 | def augment(img_list, hflip=True, rot=True, split='val'): 28 | # horizontal flip OR rotate 29 | hflip = hflip and (split == 'train' and random.random() < 0.5) 30 | vflip = rot and (split == 'train' and random.random() < 0.5) 31 | rot90 = rot and (split == 'train' and random.random() < 0.5) 32 | 33 | def _augment(img): 34 | if hflip: 35 | img = img[:, ::-1, :] 36 | if vflip: 37 | img = img[::-1, :, :] 38 | if rot90: 39 | img = img.transpose(1, 0, 2) 40 | return img 41 | 42 | return [_augment(img) for img in img_list] 43 | 44 | 45 | def transform2numpy(img): 46 | img = np.array(img) 47 | img = img.astype(np.float32) / 255. 
48 | if img.ndim == 2: 49 | img = np.expand_dims(img, axis=2) 50 | # some images have 4 channels 51 | if img.shape[2] > 3: 52 | img = img[:, :, :3] 53 | return img 54 | 55 | 56 | def transform2tensor(img, min_max=(0, 1)): 57 | # HWC to CHW 58 | img = torch.from_numpy(np.ascontiguousarray( 59 | np.transpose(img, (2, 0, 1)))).float() 60 | # to range min_max 61 | img = img*(min_max[1] - min_max[0]) + min_max[0] 62 | return img 63 | 64 | 65 | # implementation by numpy and torch 66 | # def transform_augment(img_list, split='val', min_max=(0, 1)): 67 | # imgs = [transform2numpy(img) for img in img_list] 68 | # imgs = augment(imgs, split=split) 69 | # ret_img = [transform2tensor(img, min_max) for img in imgs] 70 | # return ret_img 71 | 72 | 73 | # implementation by torchvision, detail in https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/issues/14 74 | totensor = torchvision.transforms.ToTensor() 75 | hflip = torchvision.transforms.RandomHorizontalFlip() 76 | def transform_augment(img_list, split='val', min_max=(0, 1)): 77 | imgs = [totensor(img) for img in img_list] 78 | if split == 'train': 79 | imgs = torch.stack(imgs, 0) 80 | imgs = hflip(imgs) 81 | imgs = torch.unbind(imgs, dim=0) 82 | ret_img = [img * (min_max[1] - min_max[0]) + min_max[0] for img in imgs] 83 | return ret_img 84 | -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00001.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00002.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00003.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00004.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00005.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00006.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00007.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00007.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00008.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00009.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00010.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00011.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00012.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00013.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00013.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00014.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00015.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00016.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00017.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00017.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00018.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00019.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/hr_256/00020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/hr_256/00020.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00001.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00002.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00003.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00004.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00005.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00006.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00007.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00007.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00008.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00009.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00010.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00011.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00012.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00013.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00013.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00014.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00015.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00016.png -------------------------------------------------------------------------------- /dataset/water_val_16_256/sr_16_256/00017.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00017.png
--------------------------------------------------------------------------------
/dataset/water_val_16_256/sr_16_256/00018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00018.png
--------------------------------------------------------------------------------
/dataset/water_val_16_256/sr_16_256/00019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00019.png
--------------------------------------------------------------------------------
/dataset/water_val_16_256/sr_16_256/00020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/piggy2009/DM_underwater/cb7a4a5734b575279c16e52ffd1d3e21dd673fcf/dataset/water_val_16_256/sr_16_256/00020.png
--------------------------------------------------------------------------------
/experiments_supervised/model:
--------------------------------------------------------------------------------
1 | Put the model parameters in this folder.
2 | 
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import data as Data
3 | import model as Model
4 | import argparse
5 | import logging
6 | import core.logger as Logger
7 | import core.metrics as Metrics
8 | from core.wandb_logger import WandbLogger
9 | from tensorboardX import SummaryWriter
10 | import os
11 | import time
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     parser = argparse.ArgumentParser()
16 |     parser.add_argument('-c', '--config', type=str, default='config/underwater.json',
17 |                         help='JSON file for configuration')
18 |     parser.add_argument('-p', '--phase', type=str, choices=['val'], help='val(generation)', default='val')
19 |     parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
20 |     parser.add_argument('-debug', '-d', action='store_true')
21 |     parser.add_argument('-enable_wandb', action='store_true')
22 |     parser.add_argument('-log_infer', action='store_true')
23 | 
24 |     # parse configs
25 |     args = parser.parse_args()
26 |     opt = Logger.parse(args)
27 |     # Convert to NoneDict, which returns None for missing keys.
28 | opt = Logger.dict_to_nonedict(opt) 29 | 30 | # logging 31 | torch.backends.cudnn.enabled = True 32 | torch.backends.cudnn.benchmark = True 33 | 34 | Logger.setup_logger(None, opt['path']['log'], 35 | 'train', level=logging.INFO, screen=True) 36 | Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO) 37 | logger = logging.getLogger('base') 38 | logger.info(Logger.dict2str(opt)) 39 | tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger']) 40 | 41 | # Initialize WandbLogger 42 | if opt['enable_wandb']: 43 | wandb_logger = WandbLogger(opt) 44 | else: 45 | wandb_logger = None 46 | 47 | # dataset 48 | for phase, dataset_opt in opt['datasets'].items(): 49 | if phase == 'val': 50 | val_set = Data.create_dataset(dataset_opt, phase) 51 | val_loader = Data.create_dataloader( 52 | val_set, dataset_opt, phase) 53 | logger.info('Initial Dataset Finished') 54 | 55 | # model 56 | diffusion = Model.create_model(opt) 57 | logger.info('Initial Model Finished') 58 | 59 | diffusion.set_new_noise_schedule( 60 | opt['model']['beta_schedule']['val'], schedule_phase='val') 61 | 62 | logger.info('Begin Model Inference.') 63 | current_step = 0 64 | current_epoch = 0 65 | idx = 0 66 | 67 | result_path = '{}'.format(opt['path']['results']) 68 | os.makedirs(result_path, exist_ok=True) 69 | for _, val_data in enumerate(val_loader): 70 | idx += 1 71 | diffusion.feed_data(val_data) 72 | start = time.time() 73 | diffusion.test(continous=True) 74 | end = time.time() 75 | print('Execution time:', (end - start), 'seconds') 76 | visuals = diffusion.get_current_visuals(need_LR=False) 77 | 78 | hr_img = Metrics.tensor2img(visuals['HR']) # uint8 79 | fake_img = Metrics.tensor2img(visuals['INF']) # uint8 80 | 81 | sr_img_mode = 'grid' 82 | if sr_img_mode == 'single': 83 | # single img series 84 | sr_img = visuals['SR'] # uint8 85 | sample_num = sr_img.shape[0] 86 | for iter in range(0, sample_num): 87 | Metrics.save_img( 88 | Metrics.tensor2img(sr_img[iter]), '{}/{}_{}_sr_{}.png'.format(result_path, current_step, idx, iter)) 89 | else: 90 | # grid img 91 | sr_img = Metrics.tensor2img(visuals['SR']) # uint8 92 | Metrics.save_img( 93 | sr_img, '{}/{}_{}_sr_process.png'.format(result_path, current_step, idx)) 94 | Metrics.save_img( 95 | Metrics.tensor2img(visuals['SR'][-1]), '{}/{}_{}_sr.png'.format(result_path, current_step, idx)) 96 | # for i in range(len(visuals['SR'])): 97 | # Metrics.save_img( 98 | # Metrics.tensor2img(visuals['SR'][i]), '{}/{}_{}_sr_{}.png'.format(result_path, current_step, idx, str(i))) 99 | 100 | Metrics.save_img( 101 | hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx)) 102 | Metrics.save_img( 103 | fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx)) 104 | 105 | if wandb_logger and opt['log_infer']: 106 | wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img) 107 | 108 | if wandb_logger and opt['log_infer']: 109 | wandb_logger.log_eval_table(commit=True) 110 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def create_model(opt): 6 | from .model import DDPM as M 7 | m = M(opt) 8 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 9 | return m 10 | -------------------------------------------------------------------------------- /model/base_model.py: 
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | 
5 | 
6 | class BaseModel():
7 |     def __init__(self, opt):
8 |         self.opt = opt
9 |         self.device = torch.device(
10 |             'cuda' if opt['gpu_ids'] is not None else 'cpu')
11 |         self.begin_step = 0
12 |         self.begin_epoch = 0
13 | 
14 |     def feed_data(self, data):
15 |         pass
16 | 
17 |     def optimize_parameters(self):
18 |         pass
19 | 
20 |     def get_current_visuals(self):
21 |         pass
22 | 
23 |     def get_current_losses(self):
24 |         pass
25 | 
26 |     def print_network(self):
27 |         pass
28 | 
29 |     def set_device(self, x):
30 |         if isinstance(x, dict):
31 |             for key, item in x.items():
32 |                 if item is not None:
33 |                     x[key] = item.to(self.device)
34 |         elif isinstance(x, list):
35 |             for i, item in enumerate(x):
36 |                 if item is not None:
37 |                     x[i] = item.to(self.device)  # write back: rebinding the loop variable would leave the list unchanged
38 |         else:
39 |             x = x.to(self.device)
40 |         return x
41 | 
42 |     def get_network_description(self, network):
43 |         '''Get the string and total parameters of the network'''
44 |         if isinstance(network, nn.DataParallel):
45 |             network = network.module
46 |         s = str(network)
47 |         n = sum(map(lambda x: x.numel(), network.parameters()))
48 |         return s, n
49 | 
--------------------------------------------------------------------------------
/model/ddpm_trans_modules/diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 | from functools import partial
7 | import numpy as np
8 | from tqdm import tqdm
9 | from model.ddpm_trans_modules.style_transfer import VGGPerceptualLoss
10 | 
11 | 
12 | 
13 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
14 |     betas = linear_end * np.ones(n_timestep, dtype=np.float64)
15 |     warmup_time = int(n_timestep * warmup_frac)
16 |     betas[:warmup_time] = np.linspace(
17 |         linear_start, linear_end, warmup_time, dtype=np.float64)
18 |     return betas
19 | 
20 | 
21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22 |     if schedule == 'quad':
23 |         betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
24 |                             n_timestep, dtype=np.float64) ** 2
25 |     elif schedule == 'linear':
26 |         betas = np.linspace(linear_start, linear_end,
27 |                             n_timestep, dtype=np.float64)
28 |     elif schedule == 'warmup10':
29 |         betas = _warmup_beta(linear_start, linear_end,
30 |                              n_timestep, 0.1)
31 |     elif schedule == 'warmup50':
32 |         betas = _warmup_beta(linear_start, linear_end,
33 |                              n_timestep, 0.5)
34 |     elif schedule == 'const':
35 |         betas = linear_end * np.ones(n_timestep, dtype=np.float64)
36 |     elif schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
37 |         betas = 1.
/ np.linspace(n_timestep, 38 | 1, n_timestep, dtype=np.float64) 39 | elif schedule == "cosine": 40 | timesteps = ( 41 | torch.arange(n_timestep + 1, dtype=torch.float64) / 42 | n_timestep + cosine_s 43 | ) 44 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 45 | alphas = torch.cos(alphas).pow(2) 46 | alphas = alphas / alphas[0] 47 | betas = 1 - alphas[1:] / alphas[:-1] 48 | betas = betas.clamp(max=0.999) 49 | else: 50 | raise NotImplementedError(schedule) 51 | return betas 52 | 53 | 54 | # gaussian diffusion trainer class 55 | 56 | def exists(x): 57 | return x is not None 58 | 59 | 60 | def default(val, d): 61 | if exists(val): 62 | return val 63 | return d() if isfunction(d) else d 64 | 65 | 66 | # def extract(a, t, x_shape): 67 | # b, *_ = t.shape 68 | # out = a.gather(-1, t) 69 | # return out.reshape(b, *((1,) * (len(x_shape) - 1))) 70 | 71 | def extract(a, t, x_shape): 72 | """Extract coefficients from a based on t and reshape to make it 73 | broadcastable with x_shape.""" 74 | bs, = t.shape 75 | assert x_shape[0] == bs 76 | out = torch.gather(torch.tensor(a, dtype=torch.float, device=t.device), 0, t.long()) 77 | assert out.shape == (bs,) 78 | out = out.reshape((bs,) + (1,) * (len(x_shape) - 1)) 79 | return out 80 | 81 | 82 | def noise_like(shape, device, repeat=False): 83 | def repeat_noise(): return torch.randn( 84 | (1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 85 | 86 | def noise(): return torch.randn(shape, device=device) 87 | return repeat_noise() if repeat else noise() 88 | 89 | 90 | class GaussianDiffusion(nn.Module): 91 | def __init__( 92 | self, 93 | denoise_fn, 94 | image_size, 95 | channels=3, 96 | loss_type='l1', 97 | conditional=True, 98 | schedule_opt=None 99 | ): 100 | super().__init__() 101 | self.channels = channels 102 | self.image_size = image_size 103 | self.denoise_fn = denoise_fn 104 | self.conditional = conditional 105 | self.loss_type = loss_type 106 | if schedule_opt is not None: 107 | pass 108 | # self.set_new_noise_schedule(schedule_opt) 109 | self.eta = 0 110 | self.sample_proc = 'ddim' 111 | def set_loss(self, device): 112 | if self.loss_type == 'l1': 113 | self.loss_func = nn.L1Loss().to(device) 114 | self.style_loss = VGGPerceptualLoss().to(device) 115 | elif self.loss_type == 'l2': 116 | self.loss_func = nn.MSELoss().to(device) 117 | else: 118 | raise NotImplementedError() 119 | 120 | def set_new_noise_schedule(self, schedule_opt, device): 121 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 122 | betas = make_beta_schedule( 123 | schedule=schedule_opt['schedule'], 124 | n_timestep=schedule_opt['n_timestep'], 125 | linear_start=schedule_opt['linear_start'], 126 | linear_end=schedule_opt['linear_end']) 127 | betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas 128 | alphas = 1. 
- betas 129 | alphas_cumprod = np.cumprod(alphas, axis=0) 130 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 131 | 132 | ddim_sigma = (self.eta * ((1 - alphas_cumprod_prev) / (1 - alphas_cumprod) * (1 - alphas_cumprod / alphas_cumprod_prev)) ** 0.5) 133 | self.ddim_sigma = to_torch(ddim_sigma) 134 | timesteps, = betas.shape 135 | self.num_timesteps = int(timesteps) 136 | self.register_buffer('betas', to_torch(betas)) 137 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 138 | self.register_buffer('alphas_cumprod_prev', 139 | to_torch(alphas_cumprod_prev)) 140 | 141 | # self.register_buffer('ddim_sigma', 142 | # to_torch(ddim_sigma)) 143 | 144 | # calculations for diffusion q(x_t | x_{t-1}) and others 145 | self.register_buffer('sqrt_alphas_cumprod', 146 | to_torch(np.sqrt(alphas_cumprod))) 147 | self.register_buffer('sqrt_one_minus_alphas_cumprod', 148 | to_torch(np.sqrt(1. - alphas_cumprod))) 149 | self.register_buffer('log_one_minus_alphas_cumprod', 150 | to_torch(np.log(1. - alphas_cumprod))) 151 | self.register_buffer('sqrt_recip_alphas_cumprod', 152 | to_torch(np.sqrt(1. / alphas_cumprod))) 153 | self.register_buffer('sqrt_recipm1_alphas_cumprod', 154 | to_torch(np.sqrt(1. / alphas_cumprod - 1))) 155 | 156 | # calculations for posterior q(x_{t-1} | x_t, x_0) 157 | posterior_variance = betas * \ 158 | (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 159 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 160 | self.register_buffer('posterior_variance', 161 | to_torch(posterior_variance)) 162 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 163 | self.register_buffer('posterior_log_variance_clipped', to_torch( 164 | np.log(np.maximum(posterior_variance, 1e-20)))) 165 | self.register_buffer('posterior_mean_coef1', to_torch( 166 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 167 | self.register_buffer('posterior_mean_coef2', to_torch( 168 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) 169 | 170 | def q_mean_variance(self, x_start, t): 171 | mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 172 | variance = extract(1. - self.alphas_cumprod, t, x_start.shape) 173 | log_variance = extract( 174 | self.log_one_minus_alphas_cumprod, t, x_start.shape) 175 | return mean, variance, log_variance 176 | 177 | def predict_start_from_noise(self, x_t, t, noise): 178 | return ( 179 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 180 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 181 | ) 182 | 183 | def q_posterior(self, x_start, x_t, t): 184 | posterior_mean = ( 185 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 186 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 187 | ) 188 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 189 | posterior_log_variance_clipped = extract( 190 | self.posterior_log_variance_clipped, t, x_t.shape) 191 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 192 | 193 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None, style=None): 194 | if condition_x is not None: 195 | x_recon = self.predict_start_from_noise( 196 | x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), t)) 197 | else: 198 | x_recon = self.predict_start_from_noise( 199 | x, t=t, noise=self.denoise_fn(x, t)) 200 | 201 | if clip_denoised: 202 | x_recon.clamp_(-1., 1.) 
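        # The call below plugs the predicted x0 into the closed-form posterior
        # q(x_{t-1} | x_t, x0); its mean coefficients and clipped log-variance
        # were precomputed in set_new_noise_schedule.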
203 | 204 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior( 205 | x_start=x_recon, x_t=x, t=t) 206 | return model_mean, posterior_variance, posterior_log_variance 207 | 208 | def p_mean_variance_ddim(self, x, t, clip_denoised: bool, condition_x=None, style=None): 209 | if condition_x is not None: 210 | # x_recon = self.predict_start_from_noise( 211 | # x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), t, style)) 212 | x_recon = self.denoise_fn(torch.cat([condition_x, x], dim=1), t, style) 213 | else: 214 | # x_recon = self.predict_start_from_noise( 215 | # x, t=t, noise=self.denoise_fn(x, t)) 216 | x_recon = self.denoise_fn(x, t) 217 | # if clip_denoised: 218 | # x_recon.clamp_(-1., 1.) 219 | 220 | alpha = extract(self.alphas_cumprod, t, x_recon.shape) 221 | alpha_prev = extract(self.alphas_cumprod_prev, t, x_recon.shape) 222 | sigma = extract(self.ddim_sigma, t, x_recon.shape) 223 | sqrt_one_minus_alphas = extract(self.sqrt_one_minus_alphas_cumprod, t, x_recon.shape) 224 | pred_x0 = (x - sqrt_one_minus_alphas * x_recon) / (alpha ** 0.5) 225 | 226 | dir_xt = torch.sqrt(1. - alpha_prev - sigma ** 2) * x_recon 227 | noise = torch.randn(x.shape, device=x.device) 228 | x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise 229 | return x_prev, pred_x0 230 | 231 | 232 | @torch.no_grad() 233 | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False, condition_x=None, style=None): 234 | b, *_, device = *x.shape, x.device 235 | model_mean, _, model_log_variance = self.p_mean_variance( 236 | x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x, style=style) 237 | noise = noise_like(x.shape, device, repeat_noise) 238 | # no noise when t == 0 239 | nonzero_mask = (1 - (t == 0).float()).reshape(b, 240 | *((1,) * (len(x.shape) - 1))) 241 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 242 | 243 | @torch.no_grad() 244 | def p_sample2(self, x, t, clip_denoised=True, repeat_noise=False, condition_x=None, style=None): 245 | bt = extract(self.betas, t, x.shape) 246 | at = extract((1.0 - self.betas).cumprod(dim=0), t, x.shape) 247 | logvar = extract( 248 | self.posterior_log_variance_clipped, t, x.shape) 249 | weight = bt / torch.sqrt(1 - at) 250 | et = self.denoise_fn(torch.cat([condition_x, x], dim=1), t, style) 251 | # if clip_denoised: 252 | # et.clamp_(-1., 1.) 
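        # Since alpha_t = 1 - beta_t, the mean below is the standard DDPM
        # ancestral step mu(x_t) = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t),
        # with weight = beta_t / sqrt(1 - alpha_bar_t) computed above.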
253 | mean = 1 / torch.sqrt(1.0 - bt) * (x - weight * et) 254 | noise = torch.randn_like(x) 255 | mask = 1 - (t == 0).float() 256 | mask = mask.reshape((x.shape[0],) + (1,) * (len(x.shape) - 1)) 257 | xt_next = mean + mask * torch.exp(0.5 * logvar) * noise 258 | xt_next = xt_next.float() 259 | return xt_next 260 | 261 | @torch.no_grad() 262 | def p_sample_ddim(self, x, t, clip_denoised=True, repeat_noise=False, condition_x=None, style=None): 263 | b, *_, device = *x.shape, x.device 264 | x_prev, pred_x0 = self.p_mean_variance_ddim( 265 | x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x, style=style) 266 | # noise = noise_like(x.shape, device, repeat_noise) 267 | # no noise when t == 0 268 | 269 | return x_prev 270 | 271 | def p_sample_ddim2(self, x, t, t_next, clip_denoised=True, repeat_noise=False, condition_x=None, style=None): 272 | b, *_, device = *x.shape, x.device 273 | bt = extract(self.betas, t, x.shape) 274 | at = extract((1.0 - self.betas).cumprod(dim=0), t, x.shape) 275 | 276 | if condition_x is not None: 277 | et = self.denoise_fn(torch.cat([condition_x, x], dim=1), t) 278 | else: 279 | et = self.denoise_fn(x, t) 280 | 281 | 282 | x0_t = (x - et * (1 - at).sqrt()) / at.sqrt() 283 | # x0_air_t = (x_air - et_air * (1 - at).sqrt()) / at.sqrt() 284 | if t_next == None: 285 | at_next = torch.ones_like(at) 286 | else: 287 | at_next = extract((1.0 - self.betas).cumprod(dim=0), t_next, x.shape) 288 | if self.eta == 0: 289 | xt_next = at_next.sqrt() * x0_t + (1 - at_next).sqrt() * et 290 | # xt_air_next = at_next.sqrt() * x0_air_t + (1 - at_next).sqrt() * et_air 291 | elif at > (at_next): 292 | print('Inversion process is only possible with eta = 0') 293 | raise ValueError 294 | else: 295 | c1 = self.eta * ((1 - at / (at_next)) * (1 - at_next) / (1 - at)).sqrt() 296 | c2 = ((1 - at_next) - c1 ** 2).sqrt() 297 | xt_next = at_next.sqrt() * x0_t + c2 * et + c1 * torch.randn_like(x0_t) 298 | # xt_air_next = at_next.sqrt() * x0_air_t + c2 * et_air + c1 * torch.randn_like(x0_t) 299 | 300 | # noise = noise_like(x.shape, device, repeat_noise) 301 | # no noise when t == 0 302 | 303 | return xt_next 304 | 305 | @torch.no_grad() 306 | def p_sample_loop(self, x_in, continous=False, cand=None): 307 | device = self.betas.device 308 | sample_inter = 10 309 | g_gpu = torch.Generator(device=device).manual_seed(44444) 310 | if not self.conditional: 311 | x = x_in['SR'] 312 | shape = x.shape 313 | b = shape[0] 314 | img = torch.randn(shape, device=device, generator=g_gpu) 315 | ret_img = img 316 | if cand is not None: 317 | time_steps = np.array(cand) 318 | else: 319 | num_timesteps_ddim = np.array([0, 245, 521, 1052, 1143, 1286, 1475, 1587, 1765, 1859]) # searching 320 | time_steps = np.flip(num_timesteps_ddim) 321 | for j, i in enumerate(tqdm(time_steps, desc='sampling loop time step', total=len(time_steps))): 322 | # print('i = ', i) 323 | t = torch.full((b,), i, device=device, dtype=torch.long) 324 | if j == len(time_steps) - 1: 325 | t_next = None 326 | else: 327 | t_next = torch.full((b,), time_steps[j + 1], device=device, dtype=torch.long) 328 | img = self.p_sample_ddim2(img, t, t_next, style=x_in['style']) 329 | if i % sample_inter == 0: 330 | ret_img = torch.cat([ret_img, img], dim=0) 331 | return img 332 | else: 333 | x = x_in['SR'] 334 | shape = x.shape 335 | b = shape[0] 336 | img = torch.randn(shape, device=device, generator=g_gpu) 337 | ret_img = x 338 | 339 | if self.sample_proc == 'ddpm': 340 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time 
step', total=self.num_timesteps): 341 | # print('i = ', i) 342 | img = self.p_sample(img, torch.full( 343 | (b,), i, device=device, dtype=torch.long), condition_x=x) 344 | if i % sample_inter == 0: 345 | ret_img = torch.cat([ret_img, img], dim=0) 346 | else: 347 | if cand is not None: 348 | time_steps = np.array(cand) 349 | # print(time_steps) 350 | else: 351 | time_steps = np.array([1898, 1640, 1539, 1491, 1370, 1136, 972, 858, 680, 340]) 352 | # time_steps = np.asarray(list(range(0, 1000, int(1000/4))) + list(range(1000, 2000, int(1000/6)))) 353 | # time_steps = np.flip(time_steps[:-1]) 354 | for j, i in enumerate(time_steps): 355 | # print('i = ', i) 356 | t = torch.full((b,), i, device=device, dtype=torch.long) 357 | if j == len(time_steps) - 1: 358 | t_next = None 359 | else: 360 | t_next = torch.full((b,), time_steps[j + 1], device=device, dtype=torch.long) 361 | img = self.p_sample_ddim2(img, t, t_next, condition_x=x, style=x_in['style']) 362 | if i % sample_inter == 0: 363 | ret_img = torch.cat([ret_img, img], dim=0) 364 | if continous: 365 | return ret_img 366 | else: 367 | return ret_img[-1] 368 | 369 | @torch.no_grad() 370 | def sample(self, batch_size=1, continous=False): 371 | image_size = self.image_size 372 | channels = self.channels 373 | return self.p_sample_loop((batch_size, channels, image_size, image_size), continous) 374 | 375 | @torch.no_grad() 376 | def super_resolution(self, x_in, continous=False, cand=None): 377 | return self.p_sample_loop(x_in, continous, cand=cand) 378 | 379 | 380 | @torch.no_grad() 381 | def interpolate(self, x1, x2, t=None, lam=0.5): 382 | b, *_, device = *x1.shape, x1.device 383 | t = default(t, self.num_timesteps - 1) 384 | 385 | assert x1.shape == x2.shape 386 | 387 | t_batched = torch.stack([torch.tensor(t, device=device)] * b) 388 | xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) 389 | 390 | img = (1 - lam) * xt1 + lam * xt2 391 | for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t): 392 | img = self.p_sample(img, torch.full( 393 | (b,), i, device=device, dtype=torch.long)) 394 | 395 | return img 396 | 397 | def q_sample_recover(self, x_noisy, t, predict_noise=None): 398 | # noise = default(noise, lambda: torch.randn_like(x_start)) 399 | return (x_noisy - extract(self.sqrt_one_minus_alphas_cumprod, 400 | t, x_noisy.shape) * predict_noise) / extract(self.sqrt_alphas_cumprod, t, x_noisy.shape) 401 | 402 | 403 | def q_sample(self, x_start, t, noise=None): 404 | noise = default(noise, lambda: torch.randn_like(x_start)) 405 | 406 | # fix gama 407 | return ( 408 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 409 | extract(self.sqrt_one_minus_alphas_cumprod, 410 | t, x_start.shape) * noise 411 | ) 412 | 413 | 414 | def p_losses(self, x_in, noise=None): 415 | x_start = x_in['HR'] 416 | condition_x = x_in['SR'] 417 | [b, c, h, w] = x_start.shape 418 | t = torch.randint(0, self.num_timesteps, (b,), 419 | device=x_start.device).long() 420 | 421 | noise = default(noise, lambda: torch.randn_like(x_start)) 422 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) 423 | 424 | if not self.conditional: 425 | x_recon = self.denoise_fn(x_noisy, t) 426 | else: 427 | x_recon = self.denoise_fn( 428 | torch.cat([condition_x, x_noisy], dim=1), t) 429 | 430 | 431 | loss = self.loss_func(noise, x_recon) 432 | 433 | return loss 434 | 435 | def forward(self, x, flag, *args, **kwargs): 436 | return self.p_losses(x, *args, **kwargs) 
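The ten-step samplers above are the heart of the skip strategy, and the arithmetic is easier to check in isolation. Below is a minimal, self-contained sketch of the eta = 0 update that p_sample_ddim2 applies along a non-uniform timestep sequence; the helper name ddim_step, the toy tensor sizes, and the zero-noise stand-in for denoise_fn are illustrative assumptions, not code from this repository.

import torch

@torch.no_grad()
def ddim_step(x_t, eps, a_t, a_next):
    # predict x0 from the current noise estimate, then jump straight to the next kept step
    x0 = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()
    return a_next.sqrt() * x0 + (1 - a_next).sqrt() * eps

betas = torch.linspace(1e-6, 1e-2, 2000, dtype=torch.float64)  # matches the config's linear schedule
alpha_bar = (1.0 - betas).cumprod(dim=0)
steps = [1898, 1640, 1539, 1491, 1370, 1136, 972, 858, 680, 340]  # default sequence in p_sample_loop
x = torch.randn(1, 3, 8, 8, dtype=torch.float64)
for i, t in enumerate(steps):
    eps = torch.zeros_like(x)  # stand-in for denoise_fn(torch.cat([condition, x], dim=1), t)
    a_next = alpha_bar[steps[i + 1]] if i + 1 < len(steps) else torch.ones((), dtype=torch.float64)
    x = ddim_step(x, eps, alpha_bar[t], a_next)
print(x.shape)

Because the eta = 0 update is deterministic and consistent under any decreasing sub-sequence of timesteps, the evolutionary search in search_diffusion.py only has to decide which ten indices to keep.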
-------------------------------------------------------------------------------- /model/ddpm_trans_modules/style_transfer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | def gram_matrix(input): 5 | b, c, h, w = input.size() 6 | features = input.view(b * c, h * w) 7 | G = torch.mm(features, features.t()) 8 | return G.div(b * c * h * w) 9 | 10 | class VGGPerceptualLoss(torch.nn.Module): 11 | def __init__(self, resize=False): 12 | super(VGGPerceptualLoss, self).__init__() 13 | blocks = [] 14 | blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) 15 | blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) 16 | blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) 17 | blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) 18 | for bl in blocks: 19 | for p in bl.parameters(): 20 | p.requires_grad = False 21 | 22 | self.blocks = torch.nn.ModuleList(blocks) 23 | self.transform = torch.nn.functional.interpolate 24 | self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406], device='cuda').view(1, 3, 1, 1)) 25 | self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225], device='cuda').view(1, 3, 1, 1)) 26 | self.resize = resize 27 | 28 | def forward(self, input, support, support2, style): 29 | # print(input.shape, '--', target.shape, '--', style.shape) 30 | if input.shape[1] != 3: 31 | input = input.repeat(1, 3, 1, 1) 32 | support = support.repeat(1, 3, 1, 1) 33 | support2 = support2.repeat(1, 3, 1, 1) 34 | # target = target.repeat(1, 3, 1, 1) 35 | style = style.repeat(1, 3, 1, 1) 36 | 37 | # input = (input + 1) / 2 38 | # target = (target + 1) / 2 39 | # style = (style + 1) / 2 40 | # input = (input - self.mean) / self.std 41 | # target = (target - self.mean) / self.std 42 | # style = (style - self.mean) / self.std 43 | if self.resize: 44 | input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) 45 | # target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) 46 | style = self.transform(style, mode='bilinear', size=(224, 224), align_corners=False) 47 | content_loss = 0.0 48 | style_loss = 0.0 49 | style_loss2 = 0.0 50 | x = input 51 | y = support 52 | y2 = support2 53 | # y = target 54 | s = style 55 | for block in self.blocks: 56 | x = block(x) 57 | y = block(y) 58 | y2 = block(y2) 59 | s = block(s) 60 | # content_loss += torch.nn.functional.mse_loss(x, y) 61 | style_loss += torch.nn.functional.mse_loss(gram_matrix(x), gram_matrix(s).detach()) 62 | style_loss2 += torch.nn.functional.mse_loss(gram_matrix(y), gram_matrix(s).detach()) \ 63 | + torch.nn.functional.mse_loss(gram_matrix(y2), gram_matrix(s).detach()) 64 | 65 | # return content_loss, style_loss 66 | return style_loss + style_loss2 -------------------------------------------------------------------------------- /model/ddpm_trans_modules/trans_block_eca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pdb import set_trace as stx 5 | import numbers 6 | 7 | from einops import rearrange 8 | 9 | 10 | ########################################################################## 11 | ## Layer Norm 12 | 13 | def to_3d(x): 14 | return rearrange(x, 'b c h w -> b (h w) c') 15 | 16 | 17 | def to_4d(x, h, w): 18 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 19 | 20 | 21 | 
class BiasFree_LayerNorm(nn.Module): 22 | def __init__(self, normalized_shape): 23 | super(BiasFree_LayerNorm, self).__init__() 24 | if isinstance(normalized_shape, numbers.Integral): 25 | normalized_shape = (normalized_shape,) 26 | normalized_shape = torch.Size(normalized_shape) 27 | 28 | assert len(normalized_shape) == 1 29 | 30 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 31 | self.normalized_shape = normalized_shape 32 | 33 | def forward(self, x): 34 | sigma = x.var(-1, keepdim=True, unbiased=False) 35 | return x / torch.sqrt(sigma + 1e-5) * self.weight 36 | 37 | 38 | class WithBias_LayerNorm(nn.Module): 39 | def __init__(self, normalized_shape): 40 | super(WithBias_LayerNorm, self).__init__() 41 | if isinstance(normalized_shape, numbers.Integral): 42 | normalized_shape = (normalized_shape,) 43 | normalized_shape = torch.Size(normalized_shape) 44 | 45 | assert len(normalized_shape) == 1 46 | 47 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 48 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 49 | self.normalized_shape = normalized_shape 50 | 51 | def forward(self, x): 52 | mu = x.mean(-1, keepdim=True) 53 | sigma = x.var(-1, keepdim=True, unbiased=False) 54 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 55 | 56 | 57 | class LayerNorm(nn.Module): 58 | def __init__(self, dim, LayerNorm_type): 59 | super(LayerNorm, self).__init__() 60 | if LayerNorm_type == 'BiasFree': 61 | self.body = BiasFree_LayerNorm(dim) 62 | else: 63 | self.body = WithBias_LayerNorm(dim) 64 | 65 | def forward(self, x): 66 | h, w = x.shape[-2:] 67 | return to_4d(self.body(to_3d(x)), h, w) 68 | 69 | 70 | ########################################################################## 71 | ## Gated-Dconv Feed-Forward Network (GDFN) 72 | class FeedForward(nn.Module): 73 | def __init__(self, dim, ffn_expansion_factor, bias): 74 | super(FeedForward, self).__init__() 75 | 76 | hidden_features = int(dim * ffn_expansion_factor) 77 | 78 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) 79 | 80 | self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, 81 | groups=hidden_features * 2, bias=bias) 82 | 83 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 84 | 85 | def forward(self, x): 86 | x = self.project_in(x) 87 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 88 | x = F.gelu(x1) * x2 89 | x = self.project_out(x) 90 | return x 91 | 92 | # ECA attention module 93 | class Attention_eca(nn.Module): 94 | def __init__(self, num_heads, k_size, bias): 95 | super(Attention_eca, self).__init__() 96 | self.num_heads = num_heads 97 | 98 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 99 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=bias) 100 | self.sigmoid = nn.Sigmoid() 101 | 102 | def forward(self, x): 103 | heads = x.chunk(self.num_heads, dim=1) 104 | outputs = [] 105 | for head in heads: 106 | y = self.avg_pool(head) 107 | y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 108 | y = self.sigmoid(y) 109 | out = head * y.expand_as(head) 110 | outputs.append(out) 111 | # Two different branches of ECA module 112 | output = torch.cat(outputs, dim=1) 113 | 114 | return output 115 | 116 | ########################################################################## 117 | class TransformerBlock_eca(nn.Module): 118 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): 119 | super(TransformerBlock_eca, 
self).__init__() 120 | 121 | self.norm1 = LayerNorm(dim, LayerNorm_type) 122 | self.attn = Attention_eca(num_heads, 3, bias) 123 | self.norm2 = LayerNorm(dim, LayerNorm_type) 124 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 125 | 126 | def forward(self, x): 127 | x = x + self.attn(self.norm1(x)) 128 | x = x + self.ffn(self.norm2(x)) 129 | 130 | return x 131 | 132 | if __name__ == '__main__': 133 | input = torch.zeros([2, 48, 128, 128]) 134 | # model = Restormer() 135 | # output = model(input) 136 | model2 = nn.Sequential(*[ 137 | TransformerBlock_eca(dim=int(48), num_heads=2, ffn_expansion_factor=2.66, 138 | bias=False, LayerNorm_type='WithBias') for i in range(1)]) 139 | # model3 = Attention_sa(1, 16, 48) 140 | output2 = model2(input) 141 | print(output2.shape) -------------------------------------------------------------------------------- /model/ddpm_trans_modules/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from inspect import isfunction 6 | 7 | from model.ddpm_trans_modules.trans_block_eca import TransformerBlock_eca 8 | 9 | 10 | def exists(x): 11 | return x is not None 12 | 13 | 14 | def default(val, d): 15 | if exists(val): 16 | return val 17 | return d() if isfunction(d) else d 18 | 19 | def gram_matrix(input): 20 | a, b, c, d = input.size() # a=batch size(=1) 21 | # b=number of feature maps 22 | # (c,d)=dimensions of a f. map (N=c*d) 23 | # input = F.normalize(input, p=2, dim=1, eps=1e-12) 24 | # input = (input - torch.min(input)) / (torch.max(input) - torch.min(input)) 25 | 26 | features = input.view(a * b, c * d) # resise F_XL into \hat F_XL 27 | 28 | G = torch.mm(features, features.t()) # compute the gram product 29 | 30 | # we 'normalize' the values of the gram matrix 31 | # by dividing by the number of element in each feature maps. 
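    # Note that this helper divides the Gram matrix by a * b only, whereas the
    # gram_matrix in style_transfer.py divides by b * c * h * w, so the two
    # helpers are scaled differently and are not interchangeable.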
32 | return G.div(a * b) 33 | 34 | # model 35 | 36 | 37 | 38 | class TimeEmbedding(nn.Module): 39 | def __init__(self, dim): 40 | super().__init__() 41 | self.dim = dim 42 | inv_freq = torch.exp( 43 | torch.arange(0, dim, 2, dtype=torch.float32) * 44 | (-math.log(10000) / dim) 45 | ) 46 | self.register_buffer("inv_freq", inv_freq) 47 | 48 | def forward(self, input): 49 | shape = input.shape 50 | sinusoid_in = torch.ger(input.view(-1).float(), self.inv_freq) 51 | pos_emb = torch.cat([sinusoid_in.sin(), sinusoid_in.cos()], dim=-1) 52 | pos_emb = pos_emb.view(*shape, self.dim) 53 | return pos_emb 54 | 55 | 56 | class Swish(nn.Module): 57 | def forward(self, x): 58 | return x * torch.sigmoid(x) 59 | 60 | 61 | class Upsample(nn.Module): 62 | def __init__(self, dim): 63 | super().__init__() 64 | self.up = nn.Upsample(scale_factor=2, mode="nearest") 65 | self.conv = nn.Conv2d(dim, dim, 3, padding=1) 66 | 67 | def forward(self, x): 68 | return self.conv(self.up(x)) 69 | 70 | 71 | class Downsample(nn.Module): 72 | def __init__(self, dim): 73 | super().__init__() 74 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1) 75 | 76 | def forward(self, x): 77 | return self.conv(x) 78 | 79 | 80 | # building block modules 81 | 82 | 83 | class Block(nn.Module): 84 | def __init__(self, dim, dim_out, groups=32, dropout=0): 85 | super().__init__() 86 | self.block = nn.Sequential( 87 | nn.GroupNorm(groups, dim), 88 | Swish(), 89 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 90 | nn.Conv2d(dim, dim_out, 3, padding=1) 91 | ) 92 | 93 | def forward(self, x): 94 | return self.block(x) 95 | 96 | 97 | class ResnetBlock(nn.Module): 98 | def __init__(self, dim, dim_out, time_emb_dim=None, dropout=0, norm_groups=32): 99 | super().__init__() 100 | self.mlp = nn.Sequential( 101 | Swish(), 102 | nn.Linear(time_emb_dim, dim_out) 103 | ) if exists(time_emb_dim) else None 104 | 105 | self.block1 = Block(dim, dim_out, groups=norm_groups) 106 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 107 | self.res_conv = nn.Conv2d( 108 | dim, dim_out, 1) if dim != dim_out else nn.Identity() 109 | 110 | def forward(self, x, time_emb): 111 | h = self.block1(x) 112 | if exists(self.mlp): 113 | h += self.mlp(time_emb)[:, :, None, None] 114 | h = self.block2(h) 115 | return h + self.res_conv(x) 116 | 117 | class ResnetBloc_eca(nn.Module): 118 | def __init__(self, dim, dim_out, *, time_emb_dim=None, norm_groups=32, dropout=0, with_attn=False): 119 | super().__init__() 120 | self.with_attn = with_attn 121 | self.res_block = ResnetBlock( 122 | dim, dim_out, time_emb_dim, norm_groups=norm_groups, dropout=dropout) 123 | if with_attn: 124 | self.attn = nn.Sequential(*[TransformerBlock_eca(dim=int(dim), num_heads=2, ffn_expansion_factor=2.66, 125 | bias=False, LayerNorm_type='WithBias') for i in range(1)]) 126 | 127 | def forward(self, x, time_emb): 128 | x = self.res_block(x, time_emb) 129 | if(self.with_attn): 130 | x = self.attn(x) 131 | return x 132 | 133 | class Encoder(nn.Module): 134 | def __init__( 135 | self, 136 | in_channel=6, 137 | inner_channel=32, 138 | norm_groups=32, 139 | ): 140 | super().__init__() 141 | 142 | dim = inner_channel 143 | time_dim = inner_channel 144 | 145 | self.conv1 = nn.Sequential(nn.Conv2d(in_channel, dim, kernel_size=3, stride=1, padding=1, bias=False)) 146 | self.conv2 = nn.Sequential(nn.Conv2d(dim, dim // 2, kernel_size=3, stride=1, padding=1), 147 | nn.PixelUnshuffle(2)) 148 | self.conv3 = nn.Sequential( 149 | nn.Conv2d(int(dim * 2 ** 1), int(dim * 2 ** 1) // 2, kernel_size=3, 
stride=1, padding=1), 150 | nn.PixelUnshuffle(2)) 151 | 152 | self.conv4 = nn.Sequential( 153 | nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 2) // 2, kernel_size=3, stride=1, padding=1), 154 | nn.PixelUnshuffle(2)) 155 | 156 | self.block1 = ResnetBloc_eca(dim=dim, dim_out=dim, time_emb_dim=time_dim, norm_groups=norm_groups, 157 | with_attn=True) 158 | self.block2 = ResnetBloc_eca(dim=dim * 2 ** 1, dim_out=dim * 2 ** 1, time_emb_dim=time_dim, 159 | norm_groups=norm_groups, with_attn=True) 160 | self.block3 = ResnetBloc_eca(dim=dim * 2 ** 2, dim_out=dim * 2 ** 2, time_emb_dim=time_dim, 161 | norm_groups=norm_groups, with_attn=True) 162 | self.block4 = ResnetBloc_eca(dim=dim * 2 ** 3, dim_out=dim * 2 ** 3, time_emb_dim=time_dim, 163 | norm_groups=norm_groups, with_attn=True) 164 | 165 | self.conv_up3 = nn.Sequential( 166 | nn.Conv2d((dim * 2 ** 3), (dim * 2 ** 3) * 2, kernel_size=3, stride=1, padding=1, bias=False), 167 | nn.PixelShuffle(2)) 168 | 169 | self.conv_up2 = nn.Sequential( 170 | nn.Conv2d((dim * 2 ** 2), (dim * 2 ** 2) * 2, kernel_size=3, stride=1, padding=1, bias=False), 171 | nn.PixelShuffle(2)) 172 | self.conv_up1 = nn.Sequential( 173 | nn.Conv2d((dim * 2 ** 1), (dim * 2 ** 1) * 2, kernel_size=3, stride=1, padding=1, bias=False), 174 | nn.PixelShuffle(2)) 175 | 176 | self.conv_cat3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=False) 177 | self.conv_cat2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=False) 178 | 179 | self.decoder_block3 = ResnetBloc_eca(dim=dim * 2 ** 2, dim_out=dim * 2 ** 2, time_emb_dim=time_dim, 180 | norm_groups=norm_groups, with_attn=True) 181 | self.decoder_block2 = ResnetBloc_eca(dim=dim * 2 ** 1, dim_out=dim * 2 ** 1, time_emb_dim=time_dim, 182 | norm_groups=norm_groups, with_attn=True) 183 | self.decoder_block1 = ResnetBloc_eca(dim=dim * 2 ** 1, dim_out=dim * 2 ** 1, time_emb_dim=time_dim, 184 | norm_groups=norm_groups, with_attn=True) 185 | 186 | def forward(self, x, t): 187 | x = self.conv1(x) 188 | x1 = self.block1(x, t) 189 | 190 | x2 = self.conv2(x1) 191 | x2 = self.block2(x2, t) 192 | 193 | x3 = self.conv3(x2) 194 | x3 = self.block3(x3, t) 195 | 196 | x4 = self.conv4(x3) 197 | x4 = self.block4(x4, t) 198 | 199 | de_level3 = self.conv_up3(x4) 200 | de_level3 = torch.cat([de_level3, x3], 1) 201 | de_level3 = self.conv_cat3(de_level3) 202 | de_level3 = self.decoder_block3(de_level3, t) 203 | 204 | de_level2 = self.conv_up2(de_level3) 205 | de_level2 = torch.cat([de_level2, x2], 1) 206 | de_level2 = self.conv_cat2(de_level2) 207 | de_level2 = self.decoder_block2(de_level2, t) 208 | 209 | de_level1 = self.conv_up1(de_level2) 210 | de_level1 = torch.cat([de_level1, x1], 1) 211 | mid_feat = self.decoder_block1(de_level1, t) 212 | 213 | return mid_feat, de_level2 214 | 215 | class AdaptiveInstanceNorm2d(nn.Module): 216 | def __init__(self): 217 | super(AdaptiveInstanceNorm2d, self).__init__() 218 | self.eps = 1e-5 219 | 220 | def forward(self, x, y): 221 | mean_x, mean_y = torch.mean(x, dim=(2, 3), keepdim=True), torch.mean(y, dim=(2, 3), keepdim=True) 222 | std_x, std_y = torch.std(x, dim=(2, 3), keepdim=True) + self.eps, torch.std(y, dim=(2, 3), keepdim=True) + self.eps 223 | return std_y * (x - mean_x) / std_x + mean_y 224 | 225 | class UNet(nn.Module): 226 | def __init__( 227 | self, 228 | in_channel=6, 229 | out_channel=3, 230 | inner_channel=32, 231 | norm_groups=32, 232 | channel_mults=(1, 2, 4, 8, 8), 233 | attn_res=(8), 234 | res_blocks=3, 235 | dropout=0, 236 | with_time_emb=True, 237 | 
image_size=128
238 |     ):
239 |         super().__init__()
240 | 
241 |         if with_time_emb:
242 |             time_dim = inner_channel
243 |             self.time_mlp = nn.Sequential(
244 |                 TimeEmbedding(inner_channel),
245 |                 nn.Linear(inner_channel, inner_channel * 4),
246 |                 Swish(),
247 |                 nn.Linear(inner_channel * 4, inner_channel)
248 |             )
249 |         else:
250 |             time_dim = None
251 |             self.time_mlp = None
252 | 
253 |         dim = inner_channel
254 | 
255 |         self.encoder_water = Encoder(in_channel=in_channel, inner_channel=inner_channel, norm_groups=norm_groups)
256 | 
257 |         self.refine = ResnetBloc_eca(dim=dim*2**1, dim_out=dim*2**1, time_emb_dim=time_dim, norm_groups=norm_groups, with_attn=True)
258 |         self.de_predict = nn.Sequential(nn.Conv2d(dim * 2 ** 1, out_channel, kernel_size=1, stride=1))
259 | 
260 | 
261 |     def forward(self, x, time):
262 |         t = self.time_mlp(time) if exists(self.time_mlp) else None
263 | 
264 |         mid_feat, x1 = self.encoder_water(x, t)
265 |         # mid_feat_air, x1_air = self.encoder_air(x_air, t)
266 | 
267 |         mid_feat2 = self.refine(mid_feat, t)
268 | 
269 |         return self.de_predict(mid_feat2)
270 | 
271 | if __name__ == '__main__':
272 |     img = torch.zeros(2, 3, 128, 128)
273 |     time = torch.tensor([1, 2])
274 |     model = UNet(inner_channel=48, norm_groups=24, in_channel=3)
275 |     # forward takes (x, time) and returns a single tensor
276 |     output = model(img, time)
277 |     print(output.shape)
278 |     # total = sum([param.nelement() for param in model.parameters()])
279 |     # print('parameter: %.2fM' % (total / 1e6))
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 | import torch.nn.functional as F
4 | import torch
5 | import torch.nn as nn
6 | import os
7 | import model.networks as networks
8 | from .base_model import BaseModel
9 | # from .ddpm_trans_modules.style_loss import LossNetwork
10 | logger = logging.getLogger('base')
11 | 
12 | 
13 | 
14 | class DDPM(BaseModel):
15 |     def __init__(self, opt):
16 |         super(DDPM, self).__init__(opt)
17 |         # define network and load pretrained models
18 |         self.netG = self.set_device(networks.define_G(opt))
19 |         # self.netG_air = self.set_device(networks.define_G(opt))
20 |         # self.dis_water = self.set_device(D2())
21 |         # self.dis_air = self.set_device(D2())
22 |         self.schedule_phase = None
23 | 
24 |         # set loss and load resume state
25 |         self.set_loss()
26 |         self.loss_func = nn.MSELoss(reduction='sum').to(self.device)
27 |         # self.loss_style = LossNetwork().to(self.device)
28 |         # self.style_loss = VGGPerceptualLoss().to(self.device)
29 |         self.set_new_noise_schedule(
30 |             opt['model']['beta_schedule']['train'], schedule_phase='train')
31 |         if self.opt['phase'] == 'train':
32 |             self.netG.train()
33 |             # self.netG_air.train()
34 |             # self.dis_water.train()
35 |             # self.dis_air.train()
36 |             # find the parameters to optimize
37 |             if opt['model']['finetune_norm']:
38 |                 optim_params = []
39 |                 for k, v in self.netG.named_parameters():
40 |                     v.requires_grad = False
41 |                     if k.find('transformer') >= 0:
42 |                         v.requires_grad = True
43 |                         v.data.zero_()
44 |                         optim_params.append(v)
45 |                         logger.info(
46 |                             'Params [{:s}] initialized to 0 and will optimize.'.format(k))
47 |             else:
48 |                 optim_params = list(self.netG.parameters())
49 |                 # optim_params = list(self.netG.parameters()) + list(self.netG_air.parameters())
50 |                 # optim_params_dis = list(self.dis_air.parameters()) +
list(self.dis_water.parameters()) 51 | 52 | self.optG = torch.optim.Adam( 53 | optim_params, lr=opt['train']["optimizer"]["lr"]) 54 | # self.optD = torch.optim.Adam( 55 | # optim_params_dis, lr=opt['train']["optimizer"]["lr"]) 56 | self.log_dict = OrderedDict() 57 | self.load_network() 58 | self.print_network() 59 | 60 | 61 | def set_requires_grad(self, nets, requires_grad=False): 62 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 63 | Parameters: 64 | nets (network list) -- a list of networks 65 | requires_grad (bool) -- whether the networks require gradients or not 66 | """ 67 | if not isinstance(nets, list): 68 | nets = [nets] 69 | for net in nets: 70 | if net is not None: 71 | for param in net.parameters(): 72 | param.requires_grad = requires_grad 73 | 74 | def feed_data(self, data): 75 | self.data = self.set_device(data) 76 | 77 | def optimize_parameters(self, flag=None): 78 | # need to average in multi-gpu 79 | if flag is None: 80 | 81 | self.optG.zero_grad() 82 | l_pix = self.netG(self.data, flag=None) 83 | 84 | l_pix.backward() 85 | self.optG.step() 86 | # print('single mse:', l_pix.item()) 87 | # set log 88 | self.log_dict['l_pix'] = l_pix.item() 89 | 90 | def optimize_parameters2(self): 91 | # need to average in multi-gpu 92 | self.optG.zero_grad() 93 | l_pix = self.netG(self.data) 94 | # need to average in multi-gpu 95 | b, c, h, w = self.data['HR'].shape 96 | l_pix = l_pix.sum() / int(b * c * h * w) 97 | l_pix.backward() 98 | self.optG.step() 99 | 100 | # set log 101 | self.log_dict['l_pix'] = l_pix.item() 102 | 103 | 104 | def test(self, cand=None, continous=False): 105 | self.netG.eval() 106 | with torch.no_grad(): 107 | if isinstance(self.netG, nn.DataParallel): 108 | self.SR = self.netG.module.super_resolution( 109 | self.data, continous) 110 | else: 111 | # n = None 112 | # self.temp, miu, var = self.netG_air.super_resolution( 113 | # self.data, continous, flag='style', n=n) 114 | self.SR = self.netG.super_resolution( 115 | self.data, continous, cand=cand) 116 | 117 | self.netG.train() 118 | 119 | def sample(self, batch_size=1, continous=False): 120 | self.netG.eval() 121 | with torch.no_grad(): 122 | if isinstance(self.netG, nn.DataParallel): 123 | self.SR = self.netG.module.sample(batch_size, continous) 124 | else: 125 | self.SR = self.netG.sample(batch_size, continous) 126 | self.netG.train() 127 | 128 | def set_loss(self): 129 | if isinstance(self.netG, nn.DataParallel): 130 | self.netG.module.set_loss(self.device) 131 | else: 132 | self.netG.set_loss(self.device) 133 | 134 | def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'): 135 | if self.schedule_phase is None or self.schedule_phase != schedule_phase: 136 | self.schedule_phase = schedule_phase 137 | if isinstance(self.netG, nn.DataParallel): 138 | self.netG.module.set_new_noise_schedule( 139 | schedule_opt, self.device) 140 | else: 141 | self.netG.set_new_noise_schedule(schedule_opt, self.device) 142 | 143 | 144 | def get_current_log(self): 145 | return self.log_dict 146 | 147 | def get_current_visuals(self, need_LR=True, sample=False): 148 | out_dict = OrderedDict() 149 | if sample: 150 | out_dict['SAM'] = self.SR.detach().float().cpu() 151 | else: 152 | out_dict['SR'] = self.SR.detach().float().cpu() 153 | out_dict['INF'] = self.data['SR'].detach().float().cpu() 154 | out_dict['HR'] = self.data['HR'].detach().float().cpu() 155 | if need_LR and 'LR' in self.data: 156 | out_dict['LR'] = self.data['LR'].detach().float().cpu() 157 | else: 158 | out_dict['LR'] = 
out_dict['INF'] 159 | return out_dict 160 | 161 | def print_network(self): 162 | s, n = self.get_network_description(self.netG) 163 | if isinstance(self.netG, nn.DataParallel): 164 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 165 | self.netG.module.__class__.__name__) 166 | else: 167 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 168 | 169 | logger.info( 170 | 'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 171 | logger.info(s) 172 | 173 | def save_network(self, epoch, iter_step): 174 | gen_path = os.path.join( 175 | self.opt['path']['checkpoint'], 'I{}_E{}_gen.pth'.format(iter_step, epoch)) 176 | opt_path = os.path.join( 177 | self.opt['path']['checkpoint'], 'I{}_E{}_opt.pth'.format(iter_step, epoch)) 178 | # gen 179 | network = self.netG 180 | if isinstance(self.netG, nn.DataParallel): 181 | network = network.module 182 | state_dict = network.state_dict() 183 | for key, param in state_dict.items(): 184 | state_dict[key] = param.cpu() 185 | torch.save(state_dict, gen_path) 186 | # opt 187 | opt_state = {'epoch': epoch, 'iter': iter_step, 188 | 'scheduler': None, 'optimizer': None} 189 | opt_state['optimizer'] = self.optG.state_dict() 190 | torch.save(opt_state, opt_path) 191 | 192 | logger.info( 193 | 'Saved model in [{:s}] ...'.format(gen_path)) 194 | 195 | def load_network(self): 196 | load_path = self.opt['path']['resume_state'] 197 | if load_path is not None: 198 | logger.info( 199 | 'Loading pretrained model for G [{:s}] ...'.format(load_path)) 200 | gen_path = '{}_gen.pth'.format(load_path) 201 | opt_path = '{}_opt.pth'.format(load_path) 202 | # gen 203 | network = self.netG 204 | if isinstance(self.netG, nn.DataParallel): 205 | network = network.module 206 | network.load_state_dict(torch.load( 207 | gen_path), strict=(not self.opt['model']['finetune_norm'])) 208 | 209 | # load_part_of_model(network, gen_path, s=(not self.opt['model']['finetune_norm'])) 210 | if self.opt['phase'] == 'train': 211 | # optimizer 212 | opt = torch.load(opt_path) 213 | self.optG.load_state_dict(opt['optimizer']) 214 | self.begin_step = opt['iter'] 215 | self.begin_epoch = opt['epoch'] 216 | -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | 7 | logger = logging.getLogger('base') 8 | #################### 9 | # initialize 10 | #################### 11 | 12 | 13 | def weights_init_normal(m, std=0.02): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.normal_(m.weight.data, 0.0, std) 17 | if m.bias is not None: 18 | m.bias.data.zero_() 19 | elif classname.find('Linear') != -1: 20 | init.normal_(m.weight.data, 0.0, std) 21 | if m.bias is not None: 22 | m.bias.data.zero_() 23 | elif classname.find('BatchNorm2d') != -1: 24 | init.normal_(m.weight.data, 1.0, std) # BN also uses norm 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | 28 | def weights_init_kaiming(m, scale=1): 29 | classname = m.__class__.__name__ 30 | if classname.find('Conv2d') != -1: 31 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 32 | m.weight.data *= scale 33 | if m.bias is not None: 34 | m.bias.data.zero_() 35 | elif classname.find('Linear') != -1: 36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 37 | m.weight.data *= scale 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | 
elif classname.find('BatchNorm2d') != -1: 41 | init.constant_(m.weight.data, 1.0) 42 | init.constant_(m.bias.data, 0.0) 43 | 44 | 45 | def weights_init_orthogonal(m): 46 | classname = m.__class__.__name__ 47 | if classname.find('Conv') != -1: 48 | init.orthogonal_(m.weight.data, gain=1) 49 | if m.bias is not None: 50 | m.bias.data.zero_() 51 | elif classname.find('Linear') != -1: 52 | init.orthogonal_(m.weight.data, gain=1) 53 | if m.bias is not None: 54 | m.bias.data.zero_() 55 | elif classname.find('BatchNorm2d') != -1: 56 | init.constant_(m.weight.data, 1.0) 57 | init.constant_(m.bias.data, 0.0) 58 | 59 | 60 | def init_weights(net, init_type='kaiming', scale=1, std=0.02): 61 | # scale for 'kaiming', std for 'normal'. 62 | logger.info('Initialization method [{:s}]'.format(init_type)) 63 | if init_type == 'normal': 64 | weights_init_normal_ = functools.partial(weights_init_normal, std=std) 65 | net.apply(weights_init_normal_) 66 | elif init_type == 'kaiming': 67 | weights_init_kaiming_ = functools.partial( 68 | weights_init_kaiming, scale=scale) 69 | net.apply(weights_init_kaiming_) 70 | elif init_type == 'orthogonal': 71 | net.apply(weights_init_orthogonal) 72 | else: 73 | raise NotImplementedError( 74 | 'initialization method [{:s}] not implemented'.format(init_type)) 75 | 76 | 77 | #################### 78 | # define network 79 | #################### 80 | 81 | 82 | # Generator 83 | def define_G(opt): 84 | model_opt = opt['model'] 85 | if model_opt['which_model_G'] == 'ddpm': 86 | from .ddpm_modules import diffusion, unet 87 | elif model_opt['which_model_G'] == 'sr3': 88 | from .sr3_modules import diffusion, unet 89 | elif model_opt['which_model_G'] == 'trans': 90 | from .ddpm_trans_modules import diffusion, unet 91 | elif model_opt['which_model_G'] == 'trans_div': 92 | from .ddpm_modules import diffusion, unet_backup 93 | if ('norm_groups' not in model_opt['unet']) or model_opt['unet']['norm_groups'] is None: 94 | model_opt['unet']['norm_groups']=32 95 | if model_opt['which_model_G'] == 'trans_div': 96 | model = unet_backup.DiT(depth=12, in_channels=6, hidden_size=384, patch_size=4, num_heads=6, input_size=128) 97 | else: 98 | model = unet.UNet( 99 | in_channel=model_opt['unet']['in_channel'], 100 | out_channel=model_opt['unet']['out_channel'], 101 | norm_groups=model_opt['unet']['norm_groups'], 102 | inner_channel=model_opt['unet']['inner_channel'], 103 | channel_mults=model_opt['unet']['channel_multiplier'], 104 | attn_res=model_opt['unet']['attn_res'], 105 | res_blocks=model_opt['unet']['res_blocks'], 106 | dropout=model_opt['unet']['dropout'], 107 | image_size=model_opt['diffusion']['image_size'] 108 | ) 109 | netG = diffusion.GaussianDiffusion( 110 | model, 111 | image_size=model_opt['diffusion']['image_size'], 112 | channels=model_opt['diffusion']['channels'], 113 | loss_type='l1', # L1 or L2 114 | conditional=model_opt['diffusion']['conditional'], 115 | schedule_opt=model_opt['beta_schedule']['train'] 116 | ) 117 | # if opt['phase'] == 'train': 118 | # # init_weights(netG, init_type='kaiming', scale=0.1) 119 | # init_weights(netG, init_type='orthogonal') 120 | if opt['gpu_ids'] and opt['distributed']: 121 | assert torch.cuda.is_available() 122 | netG = nn.DataParallel(netG) 123 | return netG 124 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def load_part_of_model(new_model, src_model_path, s): 4 | src_model = 
torch.load(src_model_path) 5 | m_dict = new_model.state_dict() 6 | for k in src_model.keys(): 7 | if k in m_dict.keys(): 8 | param = src_model.get(k) 9 | if param.shape == m_dict[k].data.shape: 10 | m_dict[k].data = param 11 | print('loading:', k) 12 | else: 13 | print('shape is different, not loading:', k) 14 | else: 15 | print('not loading:', k) 16 | 17 | new_model.load_state_dict(m_dict, strict=s) 18 | return new_model 19 | 20 | def load_part_of_model2(new_model, src_model_path): 21 | src_model = torch.load(src_model_path) 22 | m_dict = new_model.state_dict() 23 | for k in src_model.keys(): 24 | k2 = k.replace('denoise_fn.', '') 25 | if k2 in m_dict.keys(): 26 | # print(k) 27 | param = src_model.get(k) 28 | if param.shape == m_dict[k2].data.shape: 29 | m_dict[k2].data = param 30 | print('loading:', k) 31 | # else: 32 | # print('shape is different, not loading:', k) 33 | else: 34 | print('not loading:', k) 35 | 36 | new_model.load_state_dict(m_dict) 37 | return new_model -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6 2 | torchvision 3 | numpy 4 | pandas 5 | tqdm 6 | lmdb 7 | opencv-python 8 | pillow 9 | tensorboardx 10 | wandb 11 | 12 | -------------------------------------------------------------------------------- /search_diffusion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import glob 5 | import numpy as np 6 | import pickle 7 | import torch 8 | import logging 9 | import argparse 10 | import torch 11 | import random 12 | 13 | import data as Data 14 | import model as Model 15 | import core.logger as Logger 16 | from core.wandb_logger import WandbLogger 17 | from tensorboardX import SummaryWriter 18 | 19 | from tester_water import get_cand_err2 20 | import sys 21 | sys.setrecursionlimit(10000) 22 | import argparse 23 | 24 | import functools 25 | print = functools.partial(print, flush=True) 26 | 27 | choice = lambda x: x[np.random.randint(len(x))] if isinstance( 28 | x, tuple) else choice(tuple(x)) 29 | 30 | # device_id = 0 31 | # torch.cuda.set_device(device_id) 32 | 33 | args = { 34 | 'max_num': 2000, 35 | 'choice': 8, 36 | 'layers': 10, 37 | 'en_channels': [64, 128, 256], 38 | 'dim': 48, 39 | 'log_dir': 'log', 40 | 'max_epochs': 100, 41 | 'select_num': 10, 42 | 'population_num': 40, 43 | 'top_k': 20, 44 | 'm_prob': 0.1, 45 | 'crossover_num': 50, 46 | 'mutation_num': 50, 47 | 'flops_limit': 330 * 1e6, 48 | } 49 | 50 | 51 | class EvolutionSearcher(object): 52 | 53 | def __init__(self): 54 | self.args = args 55 | # print(args['flops-limit']) 56 | 57 | self.max_epochs = args['max_epochs'] 58 | self.select_num = args['select_num'] 59 | self.top_k = args['top_k'] 60 | self.population_num = args['population_num'] 61 | self.m_prob = args['m_prob'] 62 | self.crossover_num = args['crossover_num'] 63 | self.mutation_num = args['mutation_num'] 64 | self.flops_limit = args['flops_limit'] 65 | 66 | # diffusion model init 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('-c', '--config', type=str, default='config/underwater.json', 69 | help='JSON file for configuration') 70 | parser.add_argument('-p', '--phase', type=str, choices=['val'], help='val(generation)', default='val') 71 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) 72 | parser.add_argument('-debug', '-d', action='store_true') 73 | parser.add_argument('-enable_wandb', 
action='store_true') 74 | parser.add_argument('-log_infer', action='store_true') 75 | 76 | # parse configs 77 | args2 = parser.parse_args() 78 | opt = Logger.parse(args2) 79 | # Convert to NoneDict, which return None for missing key. 80 | opt = Logger.dict_to_nonedict(opt) 81 | 82 | # logging 83 | torch.backends.cudnn.enabled = True 84 | torch.backends.cudnn.benchmark = True 85 | 86 | # dataset 87 | for phase, dataset_opt in opt['datasets'].items(): 88 | if phase == 'val': 89 | val_set = Data.create_dataset(dataset_opt, phase) 90 | val_loader = Data.create_dataloader( 91 | val_set, dataset_opt, phase) 92 | 93 | # model 94 | diffusion = Model.create_model(opt) 95 | 96 | diffusion.set_new_noise_schedule( 97 | opt['model']['beta_schedule']['val'], schedule_phase='val') 98 | 99 | self.model = diffusion 100 | self.val_loader = val_loader 101 | 102 | 103 | self.log_dir = args['log_dir'] 104 | self.checkpoint_name = os.path.join(self.log_dir, 'checkpoint.pth.tar') 105 | 106 | self.memory = [] 107 | self.vis_dict = {} 108 | self.keep_top_k = {self.select_num: [], self.top_k: []} 109 | self.epoch = 0 110 | self.candidates = [] 111 | 112 | self.nr_layer = args['layers'] 113 | self.nr_state = args['choice'] 114 | self.max_num = args['max_num'] 115 | 116 | def save_checkpoint(self): 117 | if not os.path.exists(self.log_dir): 118 | os.makedirs(self.log_dir) 119 | info = {} 120 | info['memory'] = self.memory 121 | info['candidates'] = self.candidates 122 | info['vis_dict'] = self.vis_dict 123 | info['keep_top_k'] = self.keep_top_k 124 | info['epoch'] = self.epoch 125 | torch.save(info, self.checkpoint_name) 126 | print('save checkpoint to', self.checkpoint_name) 127 | 128 | def load_checkpoint(self): 129 | if not os.path.exists(self.checkpoint_name): 130 | return False 131 | info = torch.load(self.checkpoint_name) 132 | self.memory = info['memory'] 133 | self.candidates = info['candidates'] 134 | self.vis_dict = info['vis_dict'] 135 | self.keep_top_k = info['keep_top_k'] 136 | self.epoch = info['epoch'] 137 | 138 | print('load checkpoint from', self.checkpoint_name) 139 | print('infor message:', info) 140 | return True 141 | 142 | def is_legal(self, cand): 143 | assert isinstance(cand, tuple) and len(cand) == self.nr_layer 144 | if cand not in self.vis_dict: 145 | self.vis_dict[cand] = {} 146 | info = self.vis_dict[cand] 147 | if 'visited' in info: 148 | return False 149 | 150 | # if 'flops' not in info: 151 | # info['flops'] = get_cand_flops(cand) 152 | 153 | # if info['flops'] > self.flops_limit: 154 | # print('flops limit exceed') 155 | # return False 156 | 157 | info['err'] = get_cand_err2(self.model, cand, self.val_loader, self.args) 158 | print(cand, '--- psnr:', info['err']) 159 | info['visited'] = True 160 | 161 | return True 162 | 163 | def update_top_k(self, candidates, *, k, key, reverse=False): 164 | assert k in self.keep_top_k 165 | print('select ......') 166 | t = self.keep_top_k[k] 167 | t += candidates 168 | t.sort(key=key, reverse=reverse) 169 | self.keep_top_k[k] = t[:k] 170 | 171 | def stack_random_cand(self, random_func, *, batchsize=10): 172 | while True: 173 | cands = [random_func() for _ in range(batchsize)] 174 | for cand in cands: 175 | if cand not in self.vis_dict: 176 | self.vis_dict[cand] = {} 177 | info = self.vis_dict[cand] 178 | for cand in cands: 179 | yield cand 180 | 181 | def get_random(self, num): 182 | print('random select ........') 183 | 184 | def random_func(): 185 | no_dup = False 186 | random_des_seq = None 187 | while (no_dup == False): 188 | random_des_seq = 
[np.random.randint(self.max_num) for i in range(self.nr_layer)] 189 | dup = [x for x in random_des_seq if random_des_seq.count(x) > 1] 190 | if len(dup) == 0: 191 | no_dup = True 192 | random_des_seq.sort(reverse=True) 193 | return tuple(random_des_seq) 194 | 195 | cand_iter = self.stack_random_cand(random_func) 196 | while len(self.candidates) < num: 197 | cand = next(cand_iter) 198 | if not self.is_legal(cand): 199 | continue 200 | self.candidates.append(cand) 201 | print('random {}/{}'.format(len(self.candidates), num)) 202 | print('random_num = {}'.format(len(self.candidates))) 203 | 204 | def get_mutation(self, k, mutation_num, m_prob): 205 | assert k in self.keep_top_k 206 | print('mutation ......') 207 | res = [] 208 | iter = 0 209 | max_iters = mutation_num * 10 210 | 211 | def random_func(): 212 | cand = list(choice(self.keep_top_k[k])) 213 | for i in range(self.nr_layer): 214 | if np.random.random_sample() < m_prob: 215 | if i == 0: 216 | cand[i] = np.random.randint(cand[i + 1] + 1, self.max_num) 217 | elif i == self.nr_layer - 1: 218 | cand[i] = np.random.randint(1, cand[i - 1]) 219 | else: 220 | cand[i] = np.random.randint(cand[i + 1] + 1, cand[i - 1]) 221 | 222 | return tuple(cand) 223 | 224 | cand_iter = self.stack_random_cand(random_func) 225 | while len(res) < mutation_num and max_iters > 0: 226 | max_iters -= 1 227 | cand = next(cand_iter) 228 | if not self.is_legal(cand): 229 | continue 230 | res.append(cand) 231 | print('mutation {}/{}'.format(len(res), mutation_num)) 232 | 233 | print('mutation_num = {}'.format(len(res))) 234 | return res 235 | 236 | def get_crossover(self, k, crossover_num): 237 | assert k in self.keep_top_k 238 | print('crossover ......') 239 | res = [] 240 | iter = 0 241 | max_iters = 10 * crossover_num 242 | 243 | def random_func(): 244 | p1 = choice(self.keep_top_k[k]) 245 | p2 = choice(self.keep_top_k[k]) 246 | no_dup = False 247 | cand = None 248 | while (no_dup == False): 249 | cand = [choice([i, j]) for i, j in zip(p1, p2)] 250 | dup = [x for x in cand if cand.count(x) > 1] 251 | if len(dup) == 0: 252 | no_dup = True 253 | cand.sort(reverse=True) 254 | return tuple(cand) 255 | cand_iter = self.stack_random_cand(random_func) 256 | while len(res) < crossover_num and max_iters > 0: 257 | max_iters -= 1 258 | cand = next(cand_iter) 259 | if not self.is_legal(cand): 260 | continue 261 | res.append(cand) 262 | print('crossover {}/{}'.format(len(res), crossover_num)) 263 | 264 | print('crossover_num = {}'.format(len(res))) 265 | return res 266 | 267 | def search(self): 268 | print('population_num = {} select_num = {} mutation_num = {} crossover_num = {} random_num = {} max_epochs = {}'.format( 269 | self.population_num, self.select_num, self.mutation_num, self.crossover_num, self.population_num - self.mutation_num - self.crossover_num, self.max_epochs)) 270 | 271 | self.load_checkpoint() 272 | 273 | self.get_random(self.population_num) 274 | 275 | while self.epoch < self.max_epochs: 276 | print('epoch = {}'.format(self.epoch)) 277 | 278 | self.memory.append([]) 279 | for cand in self.candidates: 280 | self.memory[-1].append(cand) 281 | 282 | self.update_top_k( 283 | self.candidates, k=self.select_num, key=lambda x: self.vis_dict[x]['err'], reverse=True) 284 | self.update_top_k( 285 | self.candidates, k=self.top_k, key=lambda x: self.vis_dict[x]['err'], reverse=True) 286 | 287 | print('epoch = {} : top {} result'.format( 288 | self.epoch, len(self.keep_top_k[self.top_k]))) 289 | for i, cand in enumerate(self.keep_top_k[self.top_k]): 290 | print('No.{} 
{} Top-1 err = {}'.format( 291 | i + 1, cand, self.vis_dict[cand]['err'])) 292 | ops = [i for i in cand] 293 | print('ops:', ops) 294 | 295 | mutation = self.get_mutation( 296 | self.select_num, self.mutation_num, self.m_prob) 297 | crossover = self.get_crossover(self.select_num, self.crossover_num) 298 | 299 | self.candidates = mutation + crossover 300 | 301 | self.get_random(self.population_num) 302 | 303 | self.epoch += 1 304 | 305 | self.save_checkpoint() 306 | 307 | 308 | def main(): 309 | # print(args['max-epochs']) 310 | t = time.time() 311 | 312 | searcher = EvolutionSearcher() 313 | 314 | searcher.search() 315 | 316 | print('total searching time = {:.2f} hours'.format( 317 | (time.time() - t) / 3600)) 318 | 319 | if __name__ == '__main__': 320 | main() 321 | -------------------------------------------------------------------------------- /tester_water.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import core.metrics as Metrics 3 | 4 | def no_grad_wrapper(func): 5 | def new_func(*args, **kwargs): 6 | with torch.no_grad(): 7 | return func(*args, **kwargs) 8 | return new_func 9 | 10 | def get_cand_err2(model, cand, data, args): 11 | avg_psnr = 0.0 12 | idx = 0 13 | for _, val_data in enumerate(data): 14 | idx += 1 15 | model.feed_data(val_data) 16 | model.test(cand=cand, continous=True) 17 | 18 | visuals = model.get_current_visuals(need_LR=False) 19 | 20 | hr_img = Metrics.tensor2img(visuals['HR']) # uint8 21 | sr_img = Metrics.tensor2img(visuals['SR'][-1]) 22 | psnr = Metrics.calculate_psnr(sr_img, hr_img) 23 | avg_psnr += psnr 24 | avg_psnr = avg_psnr / idx 25 | return avg_psnr 26 | 27 | 28 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import data as Data 3 | import model as Model 4 | import argparse 5 | import logging 6 | import core.logger as Logger 7 | import core.metrics as Metrics 8 | from core.wandb_logger import WandbLogger 9 | from tensorboardX import SummaryWriter 10 | import os 11 | import numpy as np 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-c', '--config', type=str, default='config/underwater.json', 16 | help='JSON file for configuration') 17 | parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'], 18 | help='Run either train(training) or val(generation)', default='train') 19 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) 20 | parser.add_argument('-debug', '-d', action='store_true') 21 | parser.add_argument('-enable_wandb', action='store_true') 22 | parser.add_argument('-log_wandb_ckpt', action='store_true') 23 | parser.add_argument('-log_eval', action='store_true') 24 | 25 | # parse configs 26 | args = parser.parse_args() 27 | opt = Logger.parse(args) 28 | # Convert to NoneDict, which return None for missing key. 
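    # NoneDict returns None for absent keys, so optional entries such as
    # opt['path']['resume_state'] can be probed without raising KeyError.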
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import data as Data
3 | import model as Model
4 | import argparse
5 | import logging
6 | import core.logger as Logger
7 | import core.metrics as Metrics
8 | from core.wandb_logger import WandbLogger
9 | from tensorboardX import SummaryWriter
10 | import os
11 | import numpy as np
12 | 
13 | if __name__ == "__main__":
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument('-c', '--config', type=str, default='config/underwater.json',
16 |                         help='JSON file for configuration')
17 |     parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'],
18 |                         help='Run either train(training) or val(generation)', default='train')
19 |     parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
20 |     parser.add_argument('-debug', '-d', action='store_true')
21 |     parser.add_argument('-enable_wandb', action='store_true')
22 |     parser.add_argument('-log_wandb_ckpt', action='store_true')
23 |     parser.add_argument('-log_eval', action='store_true')
24 | 
25 |     # parse configs
26 |     args = parser.parse_args()
27 |     opt = Logger.parse(args)
28 |     # Convert to NoneDict, which returns None for missing keys.
29 |     opt = Logger.dict_to_nonedict(opt)
30 | 
31 |     # logging
32 |     torch.backends.cudnn.enabled = True
33 |     torch.backends.cudnn.benchmark = True
34 | 
35 |     Logger.setup_logger(None, opt['path']['log'],
36 |                         'train', level=logging.INFO, screen=True)
37 |     Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
38 |     logger = logging.getLogger('base')
39 |     logger.info(Logger.dict2str(opt))
40 |     tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])
41 | 
42 |     # Initialize WandbLogger
43 |     if opt['enable_wandb']:
44 |         import wandb
45 |         wandb_logger = WandbLogger(opt)
46 |         wandb.define_metric('validation/val_step')
47 |         wandb.define_metric('epoch')
48 |         wandb.define_metric("validation/*", step_metric="val_step")
49 |         val_step = 0
50 |     else:
51 |         wandb_logger = None
52 | 
53 |     # dataset
54 |     for phase, dataset_opt in opt['datasets'].items():
55 |         if phase == 'train' and args.phase != 'val':
56 |             train_set = Data.create_dataset(dataset_opt, phase)
57 |             train_loader = Data.create_dataloader(
58 |                 train_set, dataset_opt, phase)
59 |         elif phase == 'val':
60 |             val_set = Data.create_dataset(dataset_opt, phase)
61 |             val_loader = Data.create_dataloader(
62 |                 val_set, dataset_opt, phase)
63 |     logger.info('Initial Dataset Finished')
64 | 
65 |     # model
66 |     diffusion = Model.create_model(opt)
67 |     logger.info('Initial Model Finished')
68 | 
69 |     # Train
70 |     current_step = diffusion.begin_step
71 |     current_epoch = diffusion.begin_epoch
72 |     n_iter = opt['train']['n_iter']
73 | 
74 |     if opt['path']['resume_state']:
75 |         logger.info('Resuming training from epoch: {}, iter: {}.'.format(
76 |             current_epoch, current_step))
77 | 
78 |     diffusion.set_new_noise_schedule(
79 |         opt['model']['beta_schedule'][opt['phase']], schedule_phase=opt['phase'])
80 |     if opt['phase'] == 'train':
81 |         while current_step < n_iter:
82 |             current_epoch += 1
83 |             for _, train_data in enumerate(train_loader):
84 |                 current_step += 1
85 |                 if current_step > n_iter:
86 |                     break
87 |                 diffusion.feed_data(train_data)
88 |                 diffusion.optimize_parameters()
89 |                 # diffusion.finetune_parameters()
90 |                 # log
91 |                 if current_step % opt['train']['print_freq'] == 0:
92 |                     logs = diffusion.get_current_log()
93 |                     message = '<epoch:{:3d}, iter:{:8,d}> '.format(
94 |                         current_epoch, current_step)
95 |                     for k, v in logs.items():
96 |                         message += '{:s}: {:.4e} '.format(k, v)
97 |                         tb_logger.add_scalar(k, v, current_step)
98 |                     logger.info(message)
99 | 
100 |                     if wandb_logger:
101 |                         wandb_logger.log_metrics(logs)
102 | 
103 |                 # validation
104 |                 if current_step % opt['train']['val_freq'] == 0:
105 |                     avg_psnr = 0.0
106 |                     idx = 0
107 |                     result_path = '{}/{}'.format(
108 |                         opt['path']['results'], current_epoch)
109 |                     os.makedirs(result_path, exist_ok=True)
110 | 
111 |                     diffusion.set_new_noise_schedule(
112 |                         opt['model']['beta_schedule']['val'], schedule_phase='val')
113 |                     for _, val_data in enumerate(val_loader):
114 |                         idx += 1
115 |                         diffusion.feed_data(val_data)
116 |                         diffusion.test(continous=False)
117 |                         visuals = diffusion.get_current_visuals()
118 |                         sr_img = Metrics.tensor2img(visuals['SR'])  # uint8
119 |                         hr_img = Metrics.tensor2img(visuals['HR'])  # uint8
120 |                         lr_img = Metrics.tensor2img(visuals['LR'])  # uint8
121 |                         fake_img = Metrics.tensor2img(visuals['INF'])  # uint8
122 | 
123 |                         # generation
124 |                         Metrics.save_img(
125 |                             hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
126 |                         Metrics.save_img(
127 |                             sr_img, '{}/{}_{}_sr.png'.format(result_path, current_step, idx))
128 |                         Metrics.save_img(
129 |                             lr_img, '{}/{}_{}_lr.png'.format(result_path, current_step, idx))
130 |                         Metrics.save_img(
131 |                             fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))
132 |                         tb_logger.add_image(
133 |                             'Iter_{}'.format(current_step),
134 |                             np.transpose(np.concatenate(
135 |                                 (fake_img, sr_img, hr_img), axis=1), [2, 0, 1]),
136 |                             idx)
137 |                         avg_psnr += Metrics.calculate_psnr(
138 |                             sr_img, hr_img)
139 | 
140 |                         if wandb_logger:
141 |                             wandb_logger.log_image(
142 |                                 f'validation_{idx}',
143 |                                 np.concatenate((fake_img, sr_img, hr_img), axis=1)
144 |                             )
145 | 
146 |                     avg_psnr = avg_psnr / idx
147 |                     diffusion.set_new_noise_schedule(
148 |                         opt['model']['beta_schedule']['train'], schedule_phase='train')
149 |                     # log
150 |                     logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
151 |                     logger_val = logging.getLogger('val')  # validation logger
152 |                     logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
153 |                         current_epoch, current_step, avg_psnr))
154 |                     # tensorboard logger
155 |                     tb_logger.add_scalar('psnr', avg_psnr, current_step)
156 | 
157 |                     if wandb_logger:
158 |                         wandb_logger.log_metrics({
159 |                             'validation/val_psnr': avg_psnr,
160 |                             'validation/val_step': val_step
161 |                         })
162 |                         val_step += 1
163 | 
164 |                 if current_step % opt['train']['save_checkpoint_freq'] == 0:
165 |                     logger.info('Saving models and training states.')
166 |                     diffusion.save_network(current_epoch, current_step)
167 | 
168 |                     if wandb_logger and opt['log_wandb_ckpt']:
169 |                         wandb_logger.log_checkpoint(current_epoch, current_step)
170 | 
171 |             if wandb_logger:
172 |                 wandb_logger.log_metrics({'epoch': current_epoch - 1})
173 | 
174 |         # save model
175 |         logger.info('End of training.')
176 |     else:
177 |         logger.info('Begin Model Evaluation.')
178 |         avg_psnr = 0.0
179 |         avg_ssim = 0.0
180 |         idx = 0
181 |         result_path = '{}'.format(opt['path']['results'])
182 |         os.makedirs(result_path, exist_ok=True)
183 |         for _, val_data in enumerate(val_loader):
184 |             idx += 1
185 |             diffusion.feed_data(val_data)
186 |             diffusion.test(continous=True)
187 |             visuals = diffusion.get_current_visuals()
188 | 
189 |             hr_img = Metrics.tensor2img(visuals['HR'])  # uint8
190 |             lr_img = Metrics.tensor2img(visuals['LR'])  # uint8
191 |             fake_img = Metrics.tensor2img(visuals['INF'])  # uint8
192 | 
193 |             sr_img_mode = 'grid'
194 |             if sr_img_mode == 'single':
195 |                 # single img series
196 |                 sr_img = visuals['SR']  # uint8
197 |                 sample_num = sr_img.shape[0]
198 |                 for i in range(0, sample_num):
199 |                     Metrics.save_img(
200 |                         Metrics.tensor2img(sr_img[i]), '{}/{}_{}_sr_{}.png'.format(result_path, current_step, idx, i))
201 |             else:
202 |                 # grid img
203 |                 sr_img = Metrics.tensor2img(visuals['SR'])  # uint8
204 |                 Metrics.save_img(
205 |                     sr_img, '{}/{}_{}_sr_process.png'.format(result_path, current_step, idx))
206 |                 Metrics.save_img(
207 |                     Metrics.tensor2img(visuals['SR'][-1]), '{}/{}_{}_sr.png'.format(result_path, current_step, idx))
208 | 
209 |             Metrics.save_img(
210 |                 hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
211 |             Metrics.save_img(
212 |                 lr_img, '{}/{}_{}_lr.png'.format(result_path, current_step, idx))
213 |             Metrics.save_img(
214 |                 fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))
215 | 
216 |             # generation
217 |             eval_psnr = Metrics.calculate_psnr(Metrics.tensor2img(visuals['SR'][-1]), hr_img)
218 |             eval_ssim = Metrics.calculate_ssim(Metrics.tensor2img(visuals['SR'][-1]), hr_img)
219 | 
220 |             avg_psnr += eval_psnr
221 |             avg_ssim += eval_ssim
222 | 
223 |             if wandb_logger and opt['log_eval']:
224 |                 wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img, eval_psnr, eval_ssim)
225 | 
226 |         avg_psnr = avg_psnr / idx
227 |         avg_ssim = avg_ssim / idx
228 | 
229 |         # log
230 |         logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
231 |         logger.info('# Validation # SSIM: {:.4e}'.format(avg_ssim))
232 |         logger_val = logging.getLogger('val')  # validation logger
233 |         logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}, ssim: {:.4e}'.format(
234 |             current_epoch, current_step, avg_psnr, avg_ssim))
235 | 
236 |         if wandb_logger:
237 |             if opt['log_eval']:
238 |                 wandb_logger.log_eval_table()
239 |             wandb_logger.log_metrics({
240 |                 'PSNR': float(avg_psnr),
241 |                 'SSIM': float(avg_ssim)
242 |             })
243 | 
--------------------------------------------------------------------------------
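A note on the config handling in train.py: Logger.dict_to_nonedict converts the
parsed JSON so that looking up an absent key, such as opt['log_eval'] when that
flag is unset, returns None instead of raising KeyError. The class below is a
minimal sketch of that idea; the actual implementation lives in core/logger.py
and may differ in details (for example, recursing into nested dicts).

class NoneDict(dict):
    # A dict that yields None for missing keys instead of raising KeyError,
    # so optional config flags can be tested with a plain truthiness check.
    def __missing__(self, key):
        return None

opt = NoneDict({'phase': 'train'})
assert opt['phase'] == 'train'
assert opt['log_eval'] is None  # absent flag reads as None, not an error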