├── .idea ├── .gitignore ├── deployment.xml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── pet.iml └── vcs.xml ├── config ├── sample_ddpm_128.json ├── sample_sr3_128.json ├── sr_ddpm_16_128.json ├── sr_sr3_16_128.json └── sr_sr3_64_512.json ├── core ├── logger.py ├── metrics.py └── wandb_logger.py ├── data ├── LRHR_dataset.py ├── __init__.py ├── dataloader.py ├── prepare_data.py └── util.py ├── easy_train.py ├── inference.py ├── model ├── __init__.py ├── base_model.py ├── ddpm_modules │ ├── diffusion.py │ └── unet.py ├── model.py ├── networks.py └── sr3_modules │ ├── diffusion.py │ └── unet.py ├── requirement.txt └── train.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /config/sample_ddpm_128.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "generation_ffhq", 3 | "phase": "train", 4 | "gpu_ids": [ 5 | 0 6 | ], 7 | "path": { 8 | "log": "logs", 9 | "tb_logger": "tb_logger", 10 | "results": "results", 11 | "checkpoint": "checkpoint", 12 | "resume_state": null 13 | // "resume_state": "experiments/generation_ffhq_210811_140902/checkpoint/I30_E1" 14 | }, 15 | "datasets": { 16 | "train": { 17 | "name": "FFHQ", 18 | "mode": "HR", 19 | "dataroot": "dataset/ffhq_16_128", 20 | "datatype": "lmdb", //lmdb or img, path of img files 21 | "l_resolution": 16, 22 | "r_resolution": 128, 23 | "batch_size": 12, 24 | "num_workers": 8, 25 | "use_shuffle": true, 26 | "data_len": -1 27 | }, 28 | "val": { 29 | "name": "CelebaHQ", 30 | "mode": "HR", 31 | "dataroot": "dataset/celebahq_16_128", 32 | "datatype": "lmdb", //lmdb or img, path of img files 33 | "l_resolution": 16, 34 | 
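// l_resolution / r_resolution follow the convention documented in sample_sr3_128.json:
// the low-resolution input size and the target high-resolution size. Inline // comments
// are safe in these configs because core/logger.py strips everything after "//" before json.loads.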
"r_resolution": 128, 35 | "data_len": 10 36 | } 37 | }, 38 | "model": { 39 | "which_model_G": "ddpm", //ddpm, sr3 40 | "finetune_norm": false, 41 | "unet": { 42 | "in_channel": 3, 43 | "out_channel": 3, 44 | "inner_channel": 64, 45 | "channel_multiplier": [ 46 | 1, 47 | 1, 48 | 2, 49 | 2, 50 | 4, 51 | 4 52 | ], 53 | "attn_res": [ 54 | 16 55 | ], 56 | "res_blocks": 2, 57 | "dropout": 0.2 58 | }, 59 | "beta_schedule": { 60 | "train": { 61 | "schedule": "linear", 62 | "n_timestep": 2000, 63 | "linear_start": 1e-4, 64 | "linear_end": 2e-2 65 | }, 66 | "val": { 67 | "schedule": "linear", 68 | "n_timestep": 2000, 69 | "linear_start": 1e-4, 70 | "linear_end": 2e-2 71 | } 72 | }, 73 | "diffusion": { 74 | "image_size": 128, 75 | "channels": 3, //sample channel 76 | "conditional": false 77 | } 78 | }, 79 | "train": { 80 | "n_iter": 1000000, 81 | "val_freq": 1e4, 82 | "save_checkpoint_freq": 1e4, 83 | "print_freq": 200, 84 | "optimizer": { 85 | "type": "adam", 86 | "lr": 1e-4 87 | }, 88 | "ema_scheduler": { 89 | "step_start_ema": 5000, 90 | "update_ema_every": 1, 91 | "ema_decay": 0.9999 92 | } 93 | }, 94 | "wandb": { 95 | "project": "generation_ffhq_ddpm" 96 | } 97 | } -------------------------------------------------------------------------------- /config/sample_sr3_128.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "generation_ffhq", 3 | "phase": "train", // train or val 4 | "gpu_ids": [ 5 | 0 6 | ], 7 | "path": { //set the path 8 | "log": "logs", 9 | "tb_logger": "tb_logger", 10 | "results": "results", 11 | "checkpoint": "checkpoint", 12 | "resume_state": null 13 | // "resume_state": "experiments/generation_ffhq_210811_140902/checkpoint/I1560000_E91" //pretrain model or training state 14 | }, 15 | "datasets": { 16 | "train": { 17 | "name": "FFHQ", 18 | "mode": "HR", // whether need LR img 19 | "dataroot": "dataset/ffhq_16_128", 20 | "datatype": "lmdb", //lmdb or img, path of img files 21 | "l_resolution": 16, // low resolution need to super_resolution 22 | "r_resolution": 128, // high resolution 23 | "batch_size": 4, 24 | "num_workers": 8, 25 | "use_shuffle": true, 26 | "data_len": -1 // -1 represents all data used in train 27 | }, 28 | "val": { 29 | "name": "CelebaHQ", 30 | "mode": "HR", 31 | "dataroot": "dataset/celebahq_16_128", 32 | "datatype": "lmdb", //lmdb or img, path of img files 33 | "l_resolution": 16, 34 | "r_resolution": 128, 35 | "data_len": 50 36 | } 37 | }, 38 | "model": { 39 | "which_model_G": "sr3", // use the ddpm or sr3 network structure 40 | "finetune_norm": false, 41 | "unet": { 42 | "in_channel": 3, 43 | "out_channel": 3, 44 | "inner_channel": 64, 45 | "channel_multiplier": [ 46 | 1, 47 | 2, 48 | 4, 49 | 8, 50 | 8 51 | ], 52 | "attn_res": [ 53 | 16 54 | ], 55 | "res_blocks": 2, 56 | "dropout": 0.2 57 | }, 58 | "beta_schedule": { // use munual beta_schedule for acceleration 59 | "train": { 60 | "schedule": "linear", 61 | "n_timestep": 2000, 62 | "linear_start": 1e-6, 63 | "linear_end": 1e-2 64 | }, 65 | "val": { 66 | "schedule": "linear", 67 | "n_timestep": 2000, 68 | "linear_start": 1e-6, 69 | "linear_end": 1e-2 70 | } 71 | }, 72 | "diffusion": { 73 | "image_size": 128, 74 | "channels": 3, //sample channel 75 | "conditional": false // unconditional generation or unconditional generation(super_resolution) 76 | } 77 | }, 78 | "train": { 79 | "n_iter": 10000000, 80 | "val_freq": 1e4, 81 | "save_checkpoint_freq": 1e4, 82 | "print_freq": 200, 83 | "optimizer": { 84 | "type": "adam", 85 | "lr": 1e-4 86 | }, 87 | 
"ema_scheduler": { // not used now 88 | "step_start_ema": 5000, 89 | "update_ema_every": 1, 90 | "ema_decay": 0.9999 91 | } 92 | }, 93 | "wandb": { 94 | "project": "generation_ffhq_sr3" 95 | } 96 | } -------------------------------------------------------------------------------- /config/sr_ddpm_16_128.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "sr_ffhq", 3 | "phase": "train", 4 | "gpu_ids": [ 5 | 0 6 | ], 7 | "path": { 8 | "log": "logs", 9 | "tb_logger": "tb_logger", 10 | "results": "results", 11 | "checkpoint": "checkpoint", 12 | "resume_state": null 13 | // "resume_state": "experiments/sr_ffhq_210806_204158/checkpoint/I640000_E37" //pretrain model or training state 14 | }, 15 | "datasets": { 16 | "train": { 17 | "name": "FFHQ", 18 | "mode": "HR", 19 | "dataroot": "dataset/processed", 20 | "datatype": "lmdb", //lmdb or img, path of img files 21 | "l_resolution": 64, 22 | "r_resolution": 64, 23 | "batch_size": 2, 24 | "num_workers": 0, 25 | "use_shuffle": true, 26 | "data_len": -1 27 | }, 28 | "val": { 29 | "name": "CelebaHQ", 30 | "mode": "LRHR", 31 | "dataroot": "dataset/processed", 32 | "datatype": "lmdb", //lmdb or img, path of img files 33 | "l_resolution": 64, 34 | "r_resolution": 64, 35 | "data_len": 3 36 | } 37 | }, 38 | "model": { 39 | "which_model_G": "ddpm", //ddpm, sr3 40 | "finetune_norm": false, 41 | "unet": { 42 | "in_channel": 2, 43 | "out_channel": 1, 44 | "inner_channel": 32, 45 | "channel_multiplier": [ 46 | 1, 47 | 1, 48 | 2, 49 | 2, 50 | 4, 51 | 4 52 | ], 53 | "attn_res": [ 54 | 16 55 | ], 56 | "res_blocks": 2, 57 | "dropout": 0.2 58 | }, 59 | "beta_schedule": { 60 | "train": { 61 | "schedule": "linear", 62 | "n_timestep": 2000, 63 | "linear_start": 1e-4, 64 | "linear_end": 2e-2 65 | }, 66 | "val": { 67 | "schedule": "linear", 68 | "n_timestep": 2000, 69 | "linear_start": 1e-4, 70 | "linear_end": 2e-2 71 | } 72 | }, 73 | "diffusion": { 74 | "image_size": 64, 75 | "channels": 1, //sample channel 76 | "conditional": true 77 | } 78 | }, 79 | "train": { 80 | "n_iter": 1000000, 81 | "val_freq": 1e4, 82 | "save_checkpoint_freq": 1e4, 83 | "print_freq": 200, 84 | "optimizer": { 85 | "type": "adam", 86 | "lr": 1e-4 87 | }, 88 | "ema_scheduler": { 89 | "step_start_ema": 5000, 90 | "update_ema_every": 1, 91 | "ema_decay": 0.9999 92 | } 93 | }, 94 | "wandb": { 95 | "project": "sr_ffhq" 96 | } 97 | } -------------------------------------------------------------------------------- /config/sr_sr3_16_128.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "sr_ffhq", 3 | "phase": "train", // train or val 4 | "gpu_ids": [ 5 | 3 6 | ], 7 | "path": { //set the path 8 | "log": "logs", 9 | "tb_logger": "tb_logger", 10 | "results": "results", 11 | "checkpoint": "checkpoint", 12 | "resume_state": null //resume_state": "experiments/sr_ffhq_210806_204158/checkpoint/I640000_E37" //pretrain model or training state 13 | }, 14 | "datasets": { 15 | "train": { 16 | "name": "FFHQ", 17 | "mode": "HR", // whether need LR img 18 | "dataroot": "train_mat", 19 | "datatype": "lmdb", //lmdb or img, path of img files 20 | "l_resolution": 64, // low resolution need to super_resolution 21 | "r_resolution": 64, // high resolution 22 | "batch_size": 4, 23 | "num_workers": 0, 24 | "use_shuffle": true, 25 | "data_len": -1 // -1 represents all data used in train 26 | }, 27 | "val": { 28 | "name": "CelebaHQ", 29 | "mode": "LRHR", 30 | "dataroot": "dataset", 31 | "datatype": "lmdb", //lmdb or img, path 
of img files 32 | "l_resolution": 64, 33 | "r_resolution": 64, 34 | "data_len": -1 // data length in validation 35 | } 36 | }, 37 | "model": { 38 | "which_model_G": "sr3", // use the ddpm or sr3 network structure 39 | "finetune_norm": false, 40 | "unet": { 41 | "PreNet": { 42 | "in_channel": 1, 43 | "out_channel": 1, 44 | "inner_channel": 64, 45 | "channel_multiplier": [ 46 | 1, 47 | 2, 48 | 3, 49 | 4 50 | ], 51 | "attn_res": [ 52 | 32 53 | ], 54 | "res_blocks": 3, 55 | "dropout": 0.1 56 | }, 57 | "DenoiseNet": { 58 | "in_channel": 2, 59 | "out_channel": 1, 60 | "inner_channel": 32, 61 | "channel_multiplier": [ 62 | 1, 63 | 2, 64 | 3, 65 | 4 66 | ], 67 | "attn_res": [ 68 | 32 69 | ], 70 | "res_blocks": 3, 71 | "dropout": 0.1 72 | } 73 | 74 | }, 75 | "beta_schedule": { // use manual beta_schedule for acceleration 76 | "train": { 77 | "schedule": "linear", 78 | "n_timestep": 2000, 79 | "linear_start": 1e-6, 80 | "linear_end": 1e-2 81 | }, 82 | "val": { 83 | "schedule": "linear", 84 | "n_timestep": 2000, 85 | "linear_start": 1e-6, 86 | "linear_end": 1e-2 87 | } 88 | }, 89 | "diffusion": { 90 | "image_size": 128, 91 | "channels": 1, //sample channel 92 | "conditional": true // unconditional generation or conditional generation (super_resolution) 93 | } 94 | }, 95 | "train": { 96 | "n_iter": 1000000, 97 | "val_freq": 1e4, 98 | "save_checkpoint_freq": 2e4, 99 | "print_freq": 200, 100 | "optimizer": { 101 | "type": "adam", 102 | "lr": 1e-4 103 | }, 104 | "ema_scheduler": { // not used now 105 | "step_start_ema": 5000, 106 | "update_ema_every": 1, 107 | "ema_decay": 0.9999 108 | } 109 | }, 110 | "wandb": { 111 | "project": "sr_ffhq" 112 | } 113 | } -------------------------------------------------------------------------------- /config/sr_sr3_64_512.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "distributed_high_sr_ffhq", 3 | "phase": "train", // train or val 4 | "gpu_ids": [ 5 | 0,1 6 | ], 7 | "path": { //set the path 8 | "log": "logs", 9 | "tb_logger": "tb_logger", 10 | "results": "results", 11 | "checkpoint": "checkpoint", 12 | "resume_state": null 13 | // "resume_state": "experiments/distributed_high_sr_ffhq_210901_121212/checkpoint/I830000_E32" //pretrained model or training state 14 | }, 15 | "datasets": { 16 | "train": { 17 | "name": "FFHQ", 18 | "mode": "HR", // whether LR images are needed 19 | "dataroot": "dataset/ffhq_64_512", 20 | "datatype": "img", //lmdb or img, path of img files 21 | "l_resolution": 64, // low resolution to be super-resolved 22 | "r_resolution": 512, // high resolution 23 | "batch_size": 2, 24 | "num_workers": 8, 25 | "use_shuffle": true, 26 | "data_len": -1 // -1 means use all training data 27 | }, 28 | "val": { 29 | "name": "CelebaHQ", 30 | "mode": "LRHR", 31 | "dataroot": "dataset/celebahq_64_512", 32 | "datatype": "img", //lmdb or img, path of img files 33 | "l_resolution": 64, 34 | "r_resolution": 512, 35 | "data_len": 50 36 | } 37 | }, 38 | "model": { 39 | "which_model_G": "sr3", // use the ddpm or sr3 network structure 40 | "finetune_norm": false, 41 | "unet": { 42 | "in_channel": 6, 43 | "out_channel": 3, 44 | "inner_channel": 64, 45 | "norm_groups": 16, 46 | "channel_multiplier": [ 47 | 1, 48 | 2, 49 | 4, 50 | 8, 51 | // 8, 52 | // 16, 53 | 16 54 | ], 55 | "attn_res": [ 56 | // 16 57 | ], 58 | "res_blocks": 1, 59 | "dropout": 0 60 | }, 61 | "beta_schedule": { // use manual beta_schedule for acceleration 62 | "train": { 63 | "schedule": "linear", 64 | "n_timestep": 2000, 65 | "linear_start": 1e-6, 66 
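// For the "linear" schedule, the betas ramp uniformly from linear_start to linear_end
// over n_timestep steps (np.linspace(linear_start, linear_end, n_timestep) in
// make_beta_schedule in easy_train.py).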
| "linear_end": 1e-2 67 | }, 68 | "val": { 69 | "schedule": "linear", 70 | "n_timestep": 2000, 71 | "linear_start": 1e-6, 72 | "linear_end": 1e-2 73 | } 74 | }, 75 | "diffusion": { 76 | "image_size": 512, 77 | "channels": 3, //sample channel 78 | "conditional": true // unconditional generation or unconditional generation(super_resolution) 79 | } 80 | }, 81 | "train": { 82 | "n_iter": 1000000, 83 | "val_freq": 1e4, 84 | "save_checkpoint_freq": 1e4, 85 | "print_freq": 50, 86 | "optimizer": { 87 | "type": "adam", 88 | "lr": 3e-6 89 | }, 90 | "ema_scheduler": { // not used now 91 | "step_start_ema": 5000, 92 | "update_ema_every": 1, 93 | "ema_decay": 0.9999 94 | } 95 | }, 96 | "wandb": { 97 | "project": "distributed_high_sr_ffhq" 98 | } 99 | } -------------------------------------------------------------------------------- /core/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | from collections import OrderedDict 5 | import json 6 | from datetime import datetime 7 | 8 | 9 | def mkdirs(paths): 10 | if isinstance(paths, str): 11 | os.makedirs(paths, exist_ok=True) 12 | else: 13 | for path in paths: 14 | os.makedirs(path, exist_ok=True) 15 | 16 | 17 | def get_timestamp(): 18 | return datetime.now().strftime('%y%m%d_%H%M%S') 19 | 20 | 21 | def parse(args): 22 | phase = args.phase 23 | opt_path = args.config 24 | gpu_ids = args.gpu_ids 25 | enable_wandb = args.enable_wandb 26 | # remove comments starting with '//' 27 | json_str = '' 28 | with open(opt_path, 'r') as f: 29 | for line in f: 30 | line = line.split('//')[0] + '\n' 31 | json_str += line 32 | opt = json.loads(json_str, object_pairs_hook=OrderedDict) 33 | 34 | # set log directory 35 | if args.debug: 36 | opt['name'] = 'debug_{}'.format(opt['name']) 37 | experiments_root = os.path.join( 38 | 'experiments', '{}_{}'.format(opt['name'], get_timestamp())) 39 | opt['path']['experiments_root'] = experiments_root 40 | for key, path in opt['path'].items(): 41 | if 'resume' not in key and 'experiments' not in key: 42 | opt['path'][key] = os.path.join(experiments_root, path) 43 | mkdirs(opt['path'][key]) 44 | 45 | # change dataset length limit 46 | opt['phase'] = phase 47 | 48 | # export CUDA_VISIBLE_DEVICES 49 | if gpu_ids is not None: 50 | opt['gpu_ids'] = [int(id) for id in gpu_ids.split(',')] 51 | gpu_list = gpu_ids 52 | else: 53 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 54 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 55 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 56 | if len(gpu_list) > 1: 57 | opt['distributed'] = True 58 | else: 59 | opt['distributed'] = False 60 | 61 | # debug 62 | if 'debug' in opt['name']: 63 | opt['train']['val_freq'] = 2 64 | opt['train']['print_freq'] = 2 65 | opt['train']['save_checkpoint_freq'] = 3 66 | opt['datasets']['train']['batch_size'] = 2 67 | opt['model']['beta_schedule']['train']['n_timestep'] = 10 68 | opt['model']['beta_schedule']['val']['n_timestep'] = 10 69 | opt['datasets']['train']['data_len'] = 6 70 | opt['datasets']['val']['data_len'] = 3 71 | 72 | # validation in train phase 73 | if phase == 'train': 74 | opt['datasets']['val']['data_len'] = 3 75 | 76 | # W&B Logging 77 | try: 78 | log_wandb_ckpt = args.log_wandb_ckpt 79 | opt['log_wandb_ckpt'] = log_wandb_ckpt 80 | except: 81 | pass 82 | try: 83 | log_eval = args.log_eval 84 | opt['log_eval'] = log_eval 85 | except: 86 | pass 87 | try: 88 | log_infer = args.log_infer 89 | opt['log_infer'] = log_infer 90 | except: 91 | pass 92 
| opt['enable_wandb'] = enable_wandb 93 | 94 | return opt 95 | 96 | 97 | class NoneDict(dict): 98 | def __missing__(self, key): 99 | return None 100 | 101 | 102 | # convert to NoneDict, which returns None for missing keys. 103 | def dict_to_nonedict(opt): 104 | if isinstance(opt, dict): 105 | new_opt = dict() 106 | for key, sub_opt in opt.items(): 107 | new_opt[key] = dict_to_nonedict(sub_opt) 108 | return NoneDict(**new_opt) 109 | elif isinstance(opt, list): 110 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 111 | else: 112 | return opt 113 | 114 | 115 | def dict2str(opt, indent_l=1): 116 | '''dict to string for logger''' 117 | msg = '' 118 | for k, v in opt.items(): 119 | if isinstance(v, dict): 120 | msg += ' ' * (indent_l * 2) + k + ':[\n' 121 | msg += dict2str(v, indent_l + 1) 122 | msg += ' ' * (indent_l * 2) + ']\n' 123 | else: 124 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 125 | return msg 126 | 127 | 128 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False): 129 | '''set up logger''' 130 | l = logging.getLogger(logger_name) 131 | formatter = logging.Formatter( 132 | '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S') 133 | log_file = os.path.join(root, '{}.log'.format(phase)) 134 | fh = logging.FileHandler(log_file, mode='w') 135 | fh.setFormatter(formatter) 136 | l.setLevel(level) 137 | l.addHandler(fh) 138 | if screen: 139 | sh = logging.StreamHandler() 140 | sh.setFormatter(formatter) 141 | l.addHandler(sh) 142 | -------------------------------------------------------------------------------- /core/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | import cv2 5 | import SimpleITK as sitk # required by save_img below (the unused make_grid import was replaced) 6 | 7 | 8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)): 9 | ''' 10 | Converts a torch Tensor into an image Numpy array 11 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range 12 | Output: squeezed float32 array in the tensor's original range (out_type and min_max are kept for API compatibility but are not applied here) 13 | ''' 14 | tensor = tensor.squeeze().float().cpu() 15 | img_np = tensor.numpy() 16 | return img_np 17 | 18 | 19 | def save_img(img, img_path, mode='RGB'): 20 | savImg = sitk.GetImageFromArray(img[:, :, :]) 21 | sitk.WriteImage(savImg, img_path) 22 | # cv2.imwrite(img_path, img) 23 | 24 | 25 | 26 | def calculate_psnr(img1, img2): 27 | # img1 and img2 have range [0, 255] 28 | img1 = img1.astype(np.float64) 29 | img2 = img2.astype(np.float64) 30 | mse = np.mean((img1 - img2)**2) 31 | if mse == 0: 32 | return float('inf') 33 | return 20 * math.log10(255.0 / math.sqrt(mse)) 34 | 35 | 36 | def ssim(img1, img2): 37 | C1 = (0.01 * 255)**2 38 | C2 = (0.03 * 255)**2 39 | 40 | img1 = img1.astype(np.float64) 41 | img2 = img2.astype(np.float64) 42 | kernel = cv2.getGaussianKernel(11, 1.5) 43 | window = np.outer(kernel, kernel.transpose()) 44 | 45 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 46 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 47 | mu1_sq = mu1**2 48 | mu2_sq = mu2**2 49 | mu1_mu2 = mu1 * mu2 50 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 51 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 52 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 53 | 54 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 55 | (sigma1_sq + sigma2_sq + C2)) 56 | return ssim_map.mean() 57 | 58 | 59 | def calculate_ssim(img1, img2): 60 |
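# The ssim() helper above follows Wang et al. (2004): C1 = (K1*L)^2 and C2 = (K2*L)^2
# with K1 = 0.01, K2 = 0.03 and dynamic range L = 255, computed over an 11x11
# Gaussian window (sigma = 1.5) with the 5-pixel border cropped ("valid" filtering).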
'''calculate SSIM 61 | produces the same output as MATLAB's SSIM implementation 62 | img1, img2: [0, 255] 63 | ''' 64 | if not img1.shape == img2.shape: 65 | raise ValueError('Input images must have the same dimensions.') 66 | if img1.ndim == 2: 67 | return ssim(img1, img2) 68 | elif img1.ndim == 3: 69 | if img1.shape[2] == 3: 70 | ssims = [] 71 | for i in range(3): 72 | ssims.append(ssim(img1[:, :, i], img2[:, :, i])) 73 | return np.array(ssims).mean() 74 | elif img1.shape[2] == 1: 75 | return ssim(np.squeeze(img1), np.squeeze(img2)) 76 | else: 77 | raise ValueError('Wrong input image dimensions.') 78 | -------------------------------------------------------------------------------- /core/wandb_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class WandbLogger: 4 | """ 5 | Log using `Weights and Biases`. 6 | """ 7 | def __init__(self, opt): 8 | try: 9 | import wandb 10 | except ImportError: 11 | raise ImportError( 12 | "To use the Weights and Biases Logger please install wandb. " 13 | "Run `pip install wandb` to install it." 14 | ) 15 | 16 | self._wandb = wandb 17 | 18 | # Initialize a W&B run 19 | if self._wandb.run is None: 20 | self._wandb.init( 21 | project=opt['wandb']['project'], 22 | config=opt, 23 | dir='./experiments' 24 | ) 25 | 26 | self.config = self._wandb.config 27 | 28 | if self.config.get('log_eval', None): 29 | self.eval_table = self._wandb.Table(columns=[ 30 | 'fake_image', 'sr_image', 31 | 'hr_image', 32 | 'psnr', 33 | 'ssim']) 34 | else: 35 | self.eval_table = None 36 | 37 | if self.config.get('log_infer', None): 38 | self.infer_table = self._wandb.Table(columns=[ 39 | 'fake_image', 'sr_image', 40 | 'hr_image']) 41 | else: 42 | self.infer_table = None 43 | 44 | def log_metrics(self, metrics, commit=True): 45 | """ 46 | Log train/validation metrics onto W&B. 47 | 48 | metrics: dictionary of metrics to be logged 49 | """ 50 | self._wandb.log(metrics, commit=commit) 51 | 52 | def log_image(self, key_name, image_array): 53 | """ 54 | Log image array onto W&B. 55 | 56 | key_name: name of the key 57 | image_array: numpy array of image. 58 | """ 59 | self._wandb.log({key_name: self._wandb.Image(image_array)}) 60 | 61 | def log_images(self, key_name, list_images): 62 | """ 63 | Log a list of image arrays onto W&B 64 | 65 | key_name: name of the key 66 | list_images: list of numpy image arrays 67 | """ 68 | self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]}) 69 | 70 | def log_checkpoint(self, current_epoch, current_step): 71 | """ 72 | Log the model checkpoint as W&B artifacts 73 | 74 | current_epoch: the current epoch 75 | current_step: the current batch step 76 | """ 77 | model_artifact = self._wandb.Artifact( 78 | self._wandb.run.id + "_model", type="model" 79 | ) 80 | 81 | gen_path = os.path.join( 82 | self.config.path['checkpoint'], 'I{}_E{}_gen.pth'.format(current_step, current_epoch)) 83 | opt_path = os.path.join( 84 | self.config.path['checkpoint'], 'I{}_E{}_opt.pth'.format(current_step, current_epoch)) 85 | 86 | model_artifact.add_file(gen_path) 87 | model_artifact.add_file(opt_path) 88 | self._wandb.log_artifact(model_artifact, aliases=["latest"]) 89 | 90 | def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None): 91 | """ 92 | Add data row-wise to the initialized table. 
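fake_img: the conditioning/input image, logged to the 'fake_image' column
sr_img: the model's reconstructed (super-resolved) output
hr_img: the ground-truth high-resolution image
psnr, ssim: optional metrics; if either is None the row is added to infer_table instead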
93 | """ 94 | if psnr is not None and ssim is not None: 95 | self.eval_table.add_data( 96 | self._wandb.Image(fake_img), 97 | self._wandb.Image(sr_img), 98 | self._wandb.Image(hr_img), 99 | psnr, 100 | ssim 101 | ) 102 | else: 103 | self.infer_table.add_data( 104 | self._wandb.Image(fake_img), 105 | self._wandb.Image(sr_img), 106 | self._wandb.Image(hr_img) 107 | ) 108 | 109 | def log_eval_table(self, commit=False): 110 | """ 111 | Log the table 112 | """ 113 | if self.eval_table: 114 | self._wandb.log({'eval_data': self.eval_table}, commit=commit) 115 | elif self.infer_table: 116 | self._wandb.log({'infer_data': self.infer_table}, commit=commit) 117 | -------------------------------------------------------------------------------- /data/LRHR_dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | from torchvision import transforms, datasets 5 | from torchvision.utils import save_image 6 | import torchvision.transforms 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | import random 10 | import data.util as Util 11 | import os 12 | from medpy.io import load 13 | import numpy as np 14 | import scipy.io as io 15 | class LRHRDataset(Dataset): 16 | def __init__(self, dataroot, datatype, l_resolution=64, r_resolution=64, split='train', data_len=-1, need_LR=False): 17 | self.datatype = datatype 18 | self.data_len = data_len 19 | self.need_LR = need_LR 20 | self.split = split 21 | self.path = Util.get_paths_from_images( 22 | '{}'.format(dataroot)) 23 | self.dataset_len = len(self.path) 24 | if self.data_len <= 0: 25 | self.data_len = self.dataset_len 26 | else: 27 | self.data_len = min(self.data_len, self.dataset_len) 28 | def __len__(self): 29 | return self.data_len 30 | 31 | def __getitem__(self, index): 32 | image_path = os.path.join(self.path[index]) 33 | image= io.loadmat(image_path)['img'] 34 | image_h = image[:,128:256,:] 35 | img_hpet = torch.Tensor(image_h) 36 | image_s = image[:,0:128,:] 37 | img_spet = torch.Tensor(image_s) 38 | if self.need_LR: 39 | image_l = image[:,0:128,:] 40 | img_lpet = torch.Tensor(image_l) 41 | if self.need_LR: 42 | return {'LR': img_lpet, 'HR': img_hpet, 'SR': img_spet, 'Index': index} 43 | else: 44 | return {'HR': img_hpet, 'SR': img_spet, 'Index': index} -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | '''create dataset and dataloader''' 2 | import logging 3 | from re import split 4 | import torch.utils.data 5 | import os 6 | 7 | def create_dataloader(dataset, dataset_opt, phase): 8 | '''create dataloader ''' 9 | if phase == 'train': 10 | return torch.utils.data.DataLoader( 11 | dataset, 12 | batch_size=dataset_opt['batch_size'], 13 | shuffle=dataset_opt['use_shuffle'], 14 | num_workers=dataset_opt['num_workers'], 15 | pin_memory=True) 16 | elif phase == 'val': 17 | return torch.utils.data.DataLoader( 18 | dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) 19 | else: 20 | raise NotImplementedError( 21 | 'Dataloader [{:s}] is not found.'.format(phase)) 22 | 23 | 24 | def create_dataset(dataset_opt, phase): 25 | '''create dataset''' 26 | mode = dataset_opt['mode'] 27 | from data.LRHR_dataset import LRHRDataset as D 28 | dataset = D(dataroot=dataset_opt['dataroot'], 29 | datatype=dataset_opt['datatype'], 30 | l_resolution=dataset_opt['l_resolution'], 31 
| r_resolution=dataset_opt['r_resolution'], 32 | split=phase, 33 | data_len=dataset_opt['data_len'], 34 | need_LR=(mode == 'LRHR') 35 | ) 36 | logger = logging.getLogger('base') 37 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 38 | dataset_opt['name'])) 39 | return dataset 40 | if __name__ == "__main__": 41 | from data.LRHR_dataset import LRHRDataset as D 42 | dataset = D( 43 | dataroot='C:\\Users\Administrator\Desktop\PET-Reconstruction-with-Diffusion\dataset\processed', 44 | datatype='jpg', 45 | l_resolution=64, 46 | r_resolution=64, 47 | split='train', 48 | data_len=-1, 49 | need_LR=False 50 | ) 51 | train_set = dataset 52 | train_loader=torch.utils.data.DataLoader( 53 | dataset, 54 | batch_size=2, 55 | shuffle=True, 56 | num_workers=0, 57 | pin_memory=True) 58 | for _, train_data in enumerate(train_loader): 59 | print(train_data['HR'].shape) 60 | # print(torch.zeros(train_data['HR'].shape[0:2], dtype=torch.float)) 61 | # path = 'dataset/processed' 62 | # # print(os.path.join(path.split(path.split('\\')[-1])[0]),'heihei',path.split('\\')[-1]) 63 | # print(os.path.join( 64 | # path.split(path.split('/')[-1])[0],'PreNet', 'I{}_E{}_gen.pth')) -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import medpy 2 | from medpy.io import load 3 | import os 4 | import pickle 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | from torchvision import transforms, datasets 8 | from torchvision.utils import save_image 9 | import numpy as np 10 | class MyDataset(Dataset): 11 | def __init__(self, root_l, subfolder_l,root_s,subfolder_s,prefixs, transform=None): 12 | super(MyDataset, self).__init__() 13 | self.prefixs=prefixs 14 | self.l_path = os.path.join(root_l, subfolder_l) 15 | self.s_path=os.path.join(root_s, subfolder_s) 16 | self.templ = [x for x in os.listdir(self.l_path) if os.path.splitext(x)[1] == ".img"] 17 | self.temps = [x for x in os.listdir(self.s_path) if os.path.splitext(x)[1] == ".img"] 18 | self.image_list_l=[] 19 | self.image_list_s = [] 20 | # keep only files with one of the given prefixes 21 | for file in self.templ: 22 | for pre in prefixs: 23 | if pre in file: 24 | self.image_list_l.append(file) 25 | # keep only files with one of the given prefixes 26 | for file in self.temps: 27 | for pre in prefixs: 28 | if pre in file: 29 | self.image_list_s.append(file) 30 | # print(self.image_list_l) 31 | # print(self.image_list_s) 32 | self.transform = transform 33 | 34 | def __len__(self): 35 | return len(self.image_list_l) 36 | 37 | def __getitem__(self, item): 38 | # read the image (low-dose PET) 39 | image_path_l = os.path.join(self.l_path, self.image_list_l[item]) 40 | #image = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)[:, :, [2, 1, 0]] # BGR -> RGB 41 | image_l,h=load(image_path_l) 42 | image=np.array(image_l) 43 | #print(image.shape) 44 | if self.transform is not None: 45 | image = self.transform(image_l) 46 | # read the label (high-quality PET) 47 | image_path_s = os.path.join(self.s_path, self.image_list_s[item]) 48 | image_s,h2=load(image_path_s) 49 | image_s=np.array(image_s) 50 | #print(image_l.shape) 51 | # print(image_path_l,image_path_s) 52 | # add a channel dimension 53 | image_l=image_l[np.newaxis,:] 54 | image_s=image_s[np.newaxis,:] 55 | image_l=torch.Tensor(image_l) 56 | image_s=torch.Tensor(image_s) 57 | #print(image.shape) 58 | if self.transform is not None: 59 | image = self.transform(image_s) 60 | # return: image, label 61 | return image_l, image_s 62 | ### 63 | class MyMultiDataset(Dataset): 64 | def __init__(self, 
root_l, subfolder_l,root_s,subfolder_s,root_mri,subfolder_mri,prefixs, transform=None): 65 | super(MyMultiDataset, self).__init__() 66 | self.prefixs=prefixs 67 | self.l_path = os.path.join(root_l, subfolder_l) 68 | self.s_path=os.path.join(root_s, subfolder_s); self.mri_path = os.path.join(root_mri, subfolder_mri) 69 | self.templ = [x for x in os.listdir(self.l_path) if os.path.splitext(x)[1] == ".img"] 70 | self.temps = [x for x in os.listdir(self.s_path) if os.path.splitext(x)[1] == ".img"]; self.temp_mri = [x for x in os.listdir(self.mri_path) if os.path.splitext(x)[1] == ".img"] 71 | self.image_list_l=[] 72 | self.image_list_s = [] 73 | self.image_list_mri = [] 74 | # keep only files with one of the given prefixes 75 | for file in self.templ: 76 | for pre in prefixs: 77 | if pre in file: 78 | self.image_list_l.append(file) 79 | # keep only files with one of the given prefixes 80 | for file in self.temps: 81 | for pre in prefixs: 82 | if pre in file: 83 | self.image_list_s.append(file) 84 | # keep only files with one of the given prefixes 85 | for file in self.temp_mri: 86 | for pre in prefixs: 87 | if pre in file: 88 | self.image_list_mri.append(file) 89 | # print(self.image_list_l) 90 | # print(self.image_list_s) 91 | self.transform = transform 92 | 93 | def __len__(self): 94 | return len(self.image_list_l) 95 | 96 | def __getitem__(self, item): 97 | # read the image (low-dose PET) 98 | image_path_l = os.path.join(self.l_path, self.image_list_l[item]) 99 | #image = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)[:, :, [2, 1, 0]] # BGR -> RGB 100 | image_l,h=load(image_path_l) 101 | image=np.array(image_l) 102 | #print(image.shape) 103 | if self.transform is not None: 104 | image = self.transform(image_l) 105 | # read the label (high-quality PET) 106 | image_path_s = os.path.join(self.s_path, self.image_list_s[item]) 107 | image_s,h2=load(image_path_s) 108 | image_s=np.array(image_s) 109 | #print(image_l.shape) 110 | # print(image_path_l,image_path_s) 111 | # add a channel dimension 112 | image_l=image_l[np.newaxis,:] 113 | image_s=image_s[np.newaxis,:] 114 | image_l=torch.Tensor(image_l) 115 | image_s=torch.Tensor(image_s) 116 | #print(image.shape) 117 | if self.transform is not None: 118 | image = self.transform(image_s) 119 | # return: image, label 120 | return image_l, image_s 121 | # 122 | # data 123 | def loadData(root1, subfolder1,root2,subfolder2,prefixs, batch_size, shuffle=True): 124 | 125 | transform = None 126 | # modified for testing 127 | dataset = MyDataset(root1, subfolder1,root2,subfolder2,prefixs,transform=transform) 128 | #dataset = MyDataset(root, subfolder,transform=None) 129 | 130 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) 131 | # multi-modality data 132 | def loadMultiData(root1, subfolder1,root2,subfolder2,root3,subfolder3,prefixs, batch_size, shuffle=True): 133 | 134 | transform = None 135 | # modified for testing 136 | dataset = MyMultiDataset(root1, subfolder1,root2,subfolder2,root3,subfolder3,prefixs,transform=transform) 137 | #dataset = MyDataset(root, subfolder,transform=None) 138 | 139 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) 140 | # 141 | #x=MyDataset('./data/l_cut','') 142 | def readTxtLineAsList(txt_path): 143 | fi = open(txt_path, 'r') 144 | txt = fi.readlines() 145 | res_list = [] 146 | for w in txt: 147 | w = w.replace('\n', '') 148 | res_list.append(w) 149 | return res_list 150 | 151 | if __name__ == '__main__': 152 | train_txt_path = r"E:\Projects\PyCharm Projects\dataset\split\Ex2\train.txt" 153 | val_txt_path = r"E:\Projects\PyCharm Projects\dataset\split\Ex2\val.txt" 154 | train_imgs = readTxtLineAsList(train_txt_path) 155 | print(train_imgs) 156 | val_imgs = readTxtLineAsList(val_txt_path) 157 | print(val_imgs) 158 | trainloader=loadData('E:\Projects\PyCharm Projects\dataset\clinical/train_l_cut','','E:\Projects\PyCharm
Projects\dataset\clinical/train_s_cut','',prefixs=train_imgs,batch_size=1) 159 | valloader = loadData('E:\Projects\PyCharm Projects\dataset\clinical/train_l_cut', '', 160 | 'E:\Projects\PyCharm Projects\dataset\clinical/train_s_cut', '', prefixs=val_imgs, 161 | batch_size=1) -------------------------------------------------------------------------------- /data/prepare_data.py: -------------------------------------------------------------------------------- 1 | import medpy 2 | from medpy.io import load 3 | from medpy.io import save 4 | import numpy as np 5 | import os 6 | import SimpleITK as sitk 7 | 8 | 9 | # slice the whole volumes into 64x64x64 patches (stride 9 along each axis) 10 | def Datamake(root_l, root_s): 11 | all_l_names = [] 12 | all_s_names = [] 13 | for root, dirs, files in os.walk(root_l): 14 | all_l_names = (files) 15 | for root, dirs, files in os.walk(root_s): 16 | all_s_names = (files) 17 | # 18 | all_l_name = [] 19 | all_s_name = [] 20 | for i in all_l_names: 21 | if os.path.splitext(i)[1] == ".img": 22 | # print(i) 23 | all_l_name.append(i) 24 | for i in all_s_names: 25 | if os.path.splitext(i)[1] == ".img": 26 | all_s_name.append(i) 27 | # 28 | print(all_l_name) 29 | # 30 | for file in all_l_name: 31 | image_path_l = os.path.join(root_l, file) 32 | image_l, h = load(image_path_l) 33 | image_l = np.array(image_l) 34 | # print(image_l.shape) 35 | cut_cnt = 0 36 | # print(cut_cnt) 37 | for i in range(0, 8): 38 | for j in range(0, 8): 39 | for k in range(0, 8): 40 | image_cut = image_l[9 * i:64 + 9 * i, 9 * j:64 + 9 * j, 9 * k:64 + 9 * k] 41 | savImg = sitk.GetImageFromArray(image_cut.transpose(2, 1, 0)) 42 | sitk.WriteImage(savImg, 43 | 'C:\\Users\Administrator\Desktop\PET-Reconstruction-with-Diffusion\dataset\processed\LPET_cut' + '/' + file + '_cut' + str(cut_cnt) + '.img') 44 | cut_cnt += 1 45 | 46 | for file in all_s_name: 47 | image_path_s = os.path.join(root_s, file) 48 | image_s, h = load(image_path_s) 49 | image_s = np.array(image_s) 50 | # print(image_l.shape) 51 | cut_cnt = 0 52 | for i in range(0, 8): 53 | for j in range(0, 8): 54 | for k in range(0, 8): 55 | image_cut = image_s[9 * i:64 + 9 * i, 9 * j:64 + 9 * j, 9 * k:64 + 9 * k] 56 | savImg = sitk.GetImageFromArray(image_cut.transpose(2, 1, 0)) 57 | sitk.WriteImage(savImg, 58 | 'C:\\Users\Administrator\Desktop\PET-Reconstruction-with-Diffusion\dataset\processed\HPET_cut' + '/' + file + '_cut' + str(cut_cnt) + '.img') 59 | cut_cnt += 1 60 | if __name__ == '__main__': 61 | Datamake('D:\zpx\CVT3D\dataset\processed\LPET','D:\zpx\CVT3D\dataset\processed\SPET') 62 | -------------------------------------------------------------------------------- /data/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import random 5 | import numpy as np 6 | 7 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', 8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 9 | 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | 15 | def get_paths_from_images(path): 16 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 17 | images = [] 18 | for dirpath, _, fnames in sorted(os.walk(path)): 19 | for fname in sorted(fnames): 20 | if fname.endswith('.mat'): 21 | img_path = os.path.join(dirpath, fname) 22 | images.append(img_path) 23 | assert images, '{:s} has no valid image file'.format(path) 24 | return sorted(images) 25 | 26 | 27 | 28 | def transform2numpy(img): 29 | img = np.array(img) 30 | img = 
img.astype(np.float32) / 255. 31 | if img.ndim == 2: 32 | img = np.expand_dims(img, axis=2) 33 | # some images have 4 channels 34 | if img.shape[2] > 3: 35 | img = img[:, :, :3] 36 | return img 37 | 38 | 39 | def transform2tensor(img, min_max=(0, 1)): 40 | # HWC to CHW 41 | img = torch.from_numpy(np.ascontiguousarray( 42 | np.transpose(img, (2, 0, 1)))).float() 43 | # to range min_max 44 | img = img*(min_max[1] - min_max[0]) + min_max[0] 45 | return img 46 | 47 | totensor = torchvision.transforms.ToTensor() 48 | def transform_augment(img_list, split='val', min_max=(0, 1)): 49 | imgs = [totensor(img) for img in img_list] 50 | return imgs 51 | # implementation by torchvision, detail in https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/issues/14 52 | -------------------------------------------------------------------------------- /easy_train.py: -------------------------------------------------------------------------------- 1 | import torch, torchvision 2 | from torch import nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | import torchvision.transforms as transforms 7 | 8 | from einops import rearrange, repeat 9 | from tqdm.notebook import tqdm 10 | from functools import partial 11 | from PIL import Image 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import math, os, copy 15 | 16 | """ 17 | Define U-net Architecture: 18 | Approximate reverse diffusion process by using U-net 19 | U-net of SR3 : U-net backbone + Positional Encoding of time + Multihead Self-Attention 20 | """ 21 | 22 | # U-net Encoding 23 | class PositionalEncoding(nn.Module): 24 | def __init__(self, dim): 25 | super().__init__() 26 | self.dim = dim 27 | 28 | def forward(self, noise_level): 29 | # Input : tensor of value of coefficient alpha at specific step of diffusion process e.g. 
torch.Tensor([0.03]) 30 | # Transform level of noise into representation of given desired dimension 31 | count = self.dim // 2 32 | step = torch.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count 33 | encoding = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) 34 | encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1) 35 | return encoding 36 | 37 | # 38 | class FeatureWiseAffine(nn.Module): 39 | def __init__(self, in_channels, out_channels, use_affine_level=False): 40 | super(FeatureWiseAffine, self).__init__() 41 | self.use_affine_level = use_affine_level 42 | self.noise_func = nn.Sequential(nn.Linear(in_channels, out_channels * (1 + self.use_affine_level))) 43 | 44 | def forward(self, x, noise_embed): 45 | noise = self.noise_func(noise_embed).view(x.shape[0], -1, 1, 1) 46 | if self.use_affine_level: 47 | gamma, beta = noise.chunk(2, dim=1) 48 | x = (1 + gamma) * x + beta 49 | else: 50 | x = x + noise 51 | return x 52 | 53 | # swish activation function 54 | class Swish(nn.Module): 55 | def forward(self, x): 56 | return x * torch.sigmoid(x) 57 | 58 | 59 | class Upsample(nn.Module): 60 | def __init__(self, dim): 61 | super().__init__() 62 | self.up = nn.Upsample(scale_factor=2, mode="nearest") 63 | self.conv = nn.Conv2d(dim, dim, 3, padding=1) 64 | 65 | def forward(self, x): 66 | return self.conv(self.up(x)) 67 | 68 | 69 | class Downsample(nn.Module): 70 | def __init__(self, dim): 71 | super().__init__() 72 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1) 73 | 74 | def forward(self, x): 75 | return self.conv(x) 76 | 77 | 78 | class Block(nn.Module): 79 | def __init__(self, dim, dim_out, groups=32, dropout=0): 80 | super().__init__() 81 | self.block = nn.Sequential( 82 | nn.GroupNorm(groups, dim), 83 | Swish(), 84 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 85 | nn.Conv2d(dim, dim_out, 3, padding=1) 86 | ) 87 | 88 | def forward(self, x): 89 | return self.block(x) 90 | 91 | 92 | # Linear Multi-head Self-attention 93 | class SelfAtt(nn.Module): 94 | def __init__(self, channel_dim, num_heads, norm_groups=32): 95 | super(SelfAtt, self).__init__() 96 | self.groupnorm = nn.GroupNorm(norm_groups, channel_dim) 97 | self.num_heads = num_heads 98 | self.qkv = nn.Conv2d(channel_dim, channel_dim * 3, 1, bias=False) 99 | self.proj = nn.Conv2d(channel_dim, channel_dim, 1) 100 | 101 | def forward(self, x): 102 | b, c, h, w = x.size() 103 | x = self.groupnorm(x) 104 | qkv = rearrange(self.qkv(x), "b (qkv heads c) h w -> (qkv) b heads c (h w)", heads=self.num_heads, qkv=3) 105 | queries, keys, values = qkv[0], qkv[1], qkv[2] 106 | 107 | keys = F.softmax(keys, dim=-1) 108 | att = torch.einsum('bhdn,bhen->bhde', keys, values) 109 | out = torch.einsum('bhde,bhdn->bhen', att, queries) 110 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.num_heads, h=h, w=w) 111 | 112 | return self.proj(out) 113 | 114 | 115 | class ResBlock(nn.Module): 116 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, 117 | num_heads=1, use_affine_level=False, norm_groups=32, att=True): 118 | super().__init__() 119 | self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level) 120 | self.block1 = Block(dim, dim_out, groups=norm_groups) 121 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 122 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 123 | self.att = att 124 | self.attn = SelfAtt(dim_out, num_heads=num_heads, norm_groups=norm_groups) 125 
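# SelfAtt above is a "linear" attention variant: qkv is rearranged to
# (3, B, heads, c, h*w), softmax is applied over the keys' spatial axis, a
# (c x c) context matrix att = einsum('bhdn,bhen->bhde', keys, values) is built
# once, and out = einsum('bhde,bhdn->bhen', att, queries), so the cost scales
# linearly with h*w rather than quadratically as in standard self-attention.
# ResBlock applies it (when att=True) after the affine noise conditioning.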
| 126 | def forward(self, x, time_emb): 127 | y = self.block1(x) 128 | y = self.noise_func(y, time_emb) 129 | y = self.block2(y) 130 | x = y + self.res_conv(x) 131 | if self.att: 132 | x = self.attn(x) 133 | return x 134 | 135 | 136 | class UNet(nn.Module): 137 | def __init__(self, in_channel=6, out_channel=3, inner_channel=32, norm_groups=32, 138 | channel_mults=[1, 2, 4, 8, 8], res_blocks=3, dropout=0, img_size=128): 139 | super().__init__() 140 | 141 | noise_level_channel = inner_channel 142 | self.noise_level_mlp = nn.Sequential( 143 | PositionalEncoding(inner_channel), 144 | nn.Linear(inner_channel, inner_channel * 4), 145 | Swish(), 146 | nn.Linear(inner_channel * 4, inner_channel) 147 | ) 148 | 149 | num_mults = len(channel_mults) 150 | pre_channel = inner_channel 151 | feat_channels = [pre_channel] 152 | now_res = img_size 153 | 154 | # Downsampling stage of U-net 155 | downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)] 156 | for ind in range(num_mults): 157 | is_last = (ind == num_mults - 1) 158 | channel_mult = inner_channel * channel_mults[ind] 159 | for _ in range(0, res_blocks): 160 | downs.append(ResBlock( 161 | pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, 162 | norm_groups=norm_groups, dropout=dropout)) 163 | feat_channels.append(channel_mult) 164 | pre_channel = channel_mult 165 | if not is_last: 166 | downs.append(Downsample(pre_channel)) 167 | feat_channels.append(pre_channel) 168 | now_res = now_res // 2 169 | self.downs = nn.ModuleList(downs) 170 | 171 | self.mid = nn.ModuleList([ 172 | ResBlock(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, 173 | norm_groups=norm_groups, dropout=dropout), 174 | ResBlock(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, 175 | norm_groups=norm_groups, dropout=dropout, att=False) 176 | ]) 177 | 178 | # Upsampling stage of U-net 179 | ups = [] 180 | for ind in reversed(range(num_mults)): 181 | is_last = (ind < 1) 182 | channel_mult = inner_channel * channel_mults[ind] 183 | for _ in range(0, res_blocks + 1): 184 | ups.append(ResBlock( 185 | pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, 186 | norm_groups=norm_groups, dropout=dropout)) 187 | pre_channel = channel_mult 188 | if not is_last: 189 | ups.append(Upsample(pre_channel)) 190 | now_res = now_res * 2 191 | 192 | self.ups = nn.ModuleList(ups) 193 | 194 | self.final_conv = Block(pre_channel, out_channel, groups=norm_groups) 195 | 196 | def forward(self, x, noise_level): 197 | # Embedding of time step with noise coefficient alpha 198 | t = self.noise_level_mlp(noise_level) 199 | # t: (B, inner_channel) embedding of the noise level 200 | feats = [] 201 | for layer in self.downs: 202 | if isinstance(layer, ResBlock): 203 | x = layer(x, t) 204 | else: 205 | x = layer(x) 206 | feats.append(x) 207 | 208 | for layer in self.mid: 209 | x = layer(x, t) 210 | 211 | for layer in self.ups: 212 | if isinstance(layer, ResBlock): 213 | x = layer(torch.cat((x, feats.pop()), dim=1), t) 214 | else: 215 | x = layer(x) 216 | 217 | return self.final_conv(x) 218 | 219 | 220 | """ 221 | Define Diffusion process framework to train desired model: 222 | Forward Diffusion process: 223 | Given original image x_0, apply Gaussian noise ε_t for each time step t 224 | After a sufficient number of time steps, image x_T reaches pure Gaussian noise 225 | Objective of model f : 226 | model f is trained to predict actual added noise ε_t for each time step t 227 | """ 228 | 229 | 230 | class Diffusion(nn.Module): 231 | def __init__(self, 
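# Closed-form forward diffusion used by p_losses below (SR3-style continuous
# noise level): x_noisy = s * x_0 + sqrt(1 - s^2) * eps with eps ~ N(0, I), where
# s is drawn uniformly from [sqrt(abar_{t-1}), sqrt(abar_t)]; the U-Net is
# trained to recover eps from (condition, x_noisy).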
model, device, img_size, LR_size, channels=3): 232 | super().__init__() 233 | self.channels = channels 234 | self.model = model.to(device) 235 | self.img_size = img_size 236 | self.LR_size = LR_size 237 | self.device = device 238 | 239 | def set_loss(self, loss_type): 240 | if loss_type == 'l1': 241 | self.loss_func = nn.L1Loss(reduction='sum') 242 | elif loss_type == 'l2': 243 | self.loss_func = nn.MSELoss(reduction='sum') 244 | else: 245 | raise NotImplementedError() 246 | 247 | def make_beta_schedule(self, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2): 248 | if schedule == 'linear': 249 | betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) 250 | elif schedule == 'warmup': 251 | warmup_frac = 0.1 252 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 253 | warmup_time = int(n_timestep * warmup_frac) 254 | betas[:warmup_time] = np.linspace(linear_start, linear_end, warmup_time, dtype=np.float64) 255 | elif schedule == "cosine": 256 | cosine_s = 8e-3 257 | timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 258 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 259 | alphas = torch.cos(alphas).pow(2) 260 | alphas = alphas / alphas[0] 261 | betas = 1 - alphas[1:] / alphas[:-1] 262 | betas = betas.clamp(max=0.999) 263 | else: 264 | raise NotImplementedError(schedule) 265 | return betas 266 | 267 | def set_new_noise_schedule(self, schedule_opt): 268 | to_torch = partial(torch.tensor, dtype=torch.float32, device=self.device) 269 | 270 | betas = self.make_beta_schedule( 271 | schedule=schedule_opt['schedule'], 272 | n_timestep=schedule_opt['n_timestep'], 273 | linear_start=schedule_opt['linear_start'], 274 | linear_end=schedule_opt['linear_end']) 275 | betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas 276 | alphas = 1. - betas 277 | alphas_cumprod = np.cumprod(alphas, axis=0) 278 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 279 | self.sqrt_alphas_cumprod_prev = np.sqrt(np.append(1., alphas_cumprod)) 280 | self.num_timesteps = int(len(betas)) 281 | # Coefficient for forward diffusion q(x_t | x_{t-1}) and others 282 | self.register_buffer('betas', to_torch(betas)) 283 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 284 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) 285 | self.register_buffer('pred_coef1', to_torch(np.sqrt(1. / alphas_cumprod))) 286 | self.register_buffer('pred_coef2', to_torch(np.sqrt(1. / alphas_cumprod - 1))) 287 | 288 | # Coefficient for reverse diffusion posterior q(x_{t-1} | x_t, x_0) 289 | variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 290 | self.register_buffer('variance', to_torch(variance)) 291 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 292 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(variance, 1e-20)))) 293 | self.register_buffer('posterior_mean_coef1', 294 | to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 295 | self.register_buffer('posterior_mean_coef2', 296 | to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
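# q(x_{t-1} | x_t, x_0) is Gaussian with mean
#   posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t
# = [beta_t * sqrt(abar_{t-1}) / (1 - abar_t)] * x_0
#   + [(1 - abar_{t-1}) * sqrt(alpha_t) / (1 - abar_t)] * x_t
# (Ho et al. 2020, Eq. 7); the two register_buffer lines here store exactly these coefficients.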
- alphas_cumprod))) 297 | 298 | # Predict desired image x_0 from x_t with noise z_t -> Output is predicted x_0 299 | def predict_start(self, x_t, t, noise): 300 | return self.pred_coef1[t] * x_t - self.pred_coef2[t] * noise 301 | 302 | # Compute mean and log variance of posterior(reverse diffusion process) distribution 303 | def q_posterior(self, x_start, x_t, t): 304 | posterior_mean = self.posterior_mean_coef1[t] * x_start + self.posterior_mean_coef2[t] * x_t 305 | posterior_log_variance_clipped = self.posterior_log_variance_clipped[t] 306 | return posterior_mean, posterior_log_variance_clipped 307 | 308 | # Note that posterior q for reverse diffusion process is conditioned Gaussian distribution q(x_{t-1}|x_t, x_0) 309 | # Thus to compute the desired posterior q, we would ideally need the original image x_0, 310 | # which is unavailable during sampling -> we therefore reconstruct x_0 from the predicted noise and use it for the posterior 311 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None): 312 | batch_size = x.shape[0] 313 | noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[t + 1]]).repeat(batch_size, 1).to(x.device) 314 | x_recon = self.predict_start(x, t, noise=self.model(torch.cat([condition_x, x], dim=1), noise_level)) 315 | 316 | if clip_denoised: 317 | x_recon.clamp_(-1., 1.) 318 | 319 | mean, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) 320 | return mean, posterior_log_variance 321 | 322 | # Progress single step of reverse diffusion process 323 | # Given mean and log variance of posterior, sample reverse diffusion result from the posterior 324 | @torch.no_grad() 325 | def p_sample(self, x, t, clip_denoised=True, condition_x=None): 326 | mean, log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x) 327 | noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x) 328 | return mean + noise * (0.5 * log_variance).exp() 329 | 330 | # Progress whole reverse diffusion process 331 | @torch.no_grad() 332 | def super_resolution(self, x_in): 333 | img = torch.randn_like(x_in, device=x_in.device) # start reverse diffusion from Gaussian noise x_T ~ N(0, I) 334 | for i in reversed(range(0, self.num_timesteps)): 335 | img = self.p_sample(img, i, condition_x=x_in) 336 | return img 337 | 338 | # Compute loss to train the model 339 | def p_losses(self, x_in): 340 | x_start = x_in 341 | lr_imgs = transforms.Resize(self.img_size)(transforms.Resize(self.LR_size)(x_in)) 342 | b, c, h, w = x_start.shape 343 | t = np.random.randint(1, self.num_timesteps + 1) 344 | sqrt_alpha = torch.FloatTensor( 345 | np.random.uniform(self.sqrt_alphas_cumprod_prev[t - 1], self.sqrt_alphas_cumprod_prev[t], size=b) 346 | ).to(x_start.device) 347 | sqrt_alpha = sqrt_alpha.view(-1, 1, 1, 1) 348 | 349 | noise = torch.randn_like(x_start).to(x_start.device) 350 | # Perturbed image obtained by forward diffusion process at random time step t 351 | x_noisy = sqrt_alpha * x_start + (1 - sqrt_alpha ** 2).sqrt() * noise 352 | # The model predicts the actual noise added at time step t 353 | pred_noise = self.model(torch.cat([lr_imgs, x_noisy], dim=1), noise_level=sqrt_alpha) 354 | 355 | return self.loss_func(noise, pred_noise) 356 | 357 | def forward(self, x, *args, **kwargs): 358 | return self.p_losses(x, *args, **kwargs) 359 | 360 | 361 | # Class to train & test desired model 362 | class SR3(): 363 | def __init__(self, device, img_size, LR_size, loss_type, dataloader, testloader, 364 | schedule_opt, save_path, load_path=None, load=False, 365 | in_channel=6, out_channel=3, inner_channel=32, norm_groups=8, 366 | 
channel_mults=(1, 2, 4, 8, 8), res_blocks=3, dropout=0, lr=1e-5, distributed=False): 367 | super(SR3, self).__init__() 368 | self.dataloader = dataloader 369 | self.testloader = testloader 370 | self.device = device 371 | self.save_path = save_path 372 | self.img_size = img_size 373 | self.LR_size = LR_size 374 | 375 | model = UNet(in_channel, out_channel, inner_channel, norm_groups, channel_mults, res_blocks, dropout, img_size) 376 | self.sr3 = Diffusion(model, device, img_size, LR_size, out_channel) 377 | 378 | # Apply weight initialization & set loss & set noise schedule 379 | self.sr3.apply(self.weights_init_orthogonal) 380 | self.sr3.set_loss(loss_type) 381 | self.sr3.set_new_noise_schedule(schedule_opt) 382 | 383 | if distributed: 384 | assert torch.cuda.is_available() 385 | self.sr3 = nn.DataParallel(self.sr3) 386 | 387 | self.optimizer = torch.optim.Adam(self.sr3.parameters(), lr=lr) 388 | 389 | params = sum(p.numel() for p in self.sr3.parameters()) 390 | print(f"Number of model parameters : {params}") 391 | 392 | if load: 393 | self.load(load_path) 394 | 395 | def weights_init_orthogonal(self, m): 396 | classname = m.__class__.__name__ 397 | if classname.find('Conv') != -1: 398 | init.orthogonal_(m.weight.data, gain=1) 399 | if m.bias is not None: 400 | m.bias.data.zero_() 401 | elif classname.find('Linear') != -1: 402 | init.orthogonal_(m.weight.data, gain=1) 403 | if m.bias is not None: 404 | m.bias.data.zero_() 405 | elif classname.find('BatchNorm2d') != -1: 406 | init.constant_(m.weight.data, 1.0) 407 | init.constant_(m.bias.data, 0.0) 408 | 409 | def train(self, epoch, verbose): 410 | fixed_imgs = copy.deepcopy(next(iter(self.testloader))) 411 | fixed_imgs = fixed_imgs[0].to(self.device) 412 | # Transform to low-resolution images 413 | fixed_imgs = transforms.Resize(self.img_size)(transforms.Resize(self.LR_size)(fixed_imgs)) 414 | 415 | for i in tqdm(range(epoch)): 416 | train_loss = 0 417 | for _, imgs in enumerate(self.dataloader): 418 | # Initial imgs are high-resolution 419 | imgs = imgs[0].to(self.device) 420 | b, c, h, w = imgs.shape 421 | 422 | self.optimizer.zero_grad() 423 | loss = self.sr3(imgs) 424 | loss = loss.sum() / int(b * c * h * w) 425 | loss.backward() 426 | self.optimizer.step() 427 | train_loss += loss.item() * b 428 | 429 | if (i + 1) % verbose == 0: 430 | self.sr3.eval() 431 | test_imgs = next(iter(self.testloader)) 432 | test_imgs = test_imgs[0].to(self.device) 433 | b, c, h, w = test_imgs.shape 434 | 435 | with torch.no_grad(): 436 | val_loss = self.sr3(test_imgs) 437 | val_loss = val_loss.sum() / int(b * c * h * w) 438 | self.sr3.train() 439 | 440 | train_loss = train_loss / len(self.dataloader) 441 | print(f'Epoch: {i + 1} / loss:{train_loss:.3f} / val_loss:{val_loss.item():.3f}') 442 | 443 | # Save example of test images to check training 444 | plt.figure(figsize=(15, 10)) 445 | plt.subplot(1, 2, 1) 446 | plt.axis("off") 447 | plt.title("Low-Resolution Inputs") 448 | plt.imshow(np.transpose(torchvision.utils.make_grid(fixed_imgs, 449 | nrow=2, padding=1, normalize=True).cpu(), 450 | (1, 2, 0))) 451 | 452 | plt.subplot(1, 2, 2) 453 | plt.axis("off") 454 | plt.title("Super-Resolution Results") 455 | plt.imshow(np.transpose(torchvision.utils.make_grid(self.test(fixed_imgs).detach().cpu(), 456 | nrow=2, padding=1, normalize=True), (1, 2, 0))) 457 | plt.savefig('SuperResolution_Result.jpg') 458 | plt.close() 459 | 460 | # Save model weight 461 | self.save(self.save_path) 462 | 463 | def test(self, imgs): 464 | imgs_lr = 
transforms.Resize(self.img_size)(transforms.Resize(self.LR_size)(imgs))
465 | self.sr3.eval()
466 | with torch.no_grad():
467 | if isinstance(self.sr3, nn.DataParallel):
468 | result_SR = self.sr3.module.super_resolution(imgs_lr)
469 | else:
470 | result_SR = self.sr3.super_resolution(imgs_lr)
471 | self.sr3.train()
472 | return result_SR
473 |
474 | def save(self, save_path):
475 | network = self.sr3
476 | if isinstance(self.sr3, nn.DataParallel):
477 | network = network.module
478 | state_dict = network.state_dict()
479 | for key, param in state_dict.items():
480 | state_dict[key] = param.cpu()
481 | torch.save(state_dict, save_path)
482 |
483 | def load(self, load_path):
484 | network = self.sr3
485 | if isinstance(self.sr3, nn.DataParallel):
486 | network = network.module
487 | network.load_state_dict(torch.load(load_path))
488 | print("Model loaded successfully")
489 |
490 |
491 | if __name__ == "__main__":
492 | batch_size = 16
493 | LR_size = 32
494 | img_size = 128
495 | root = './data/ffhq_thumb'
496 | testroot = './data/celeba_hq'
497 |
498 | transforms_ = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(),
499 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
500 | dataloader = DataLoader(torchvision.datasets.ImageFolder(root, transform=transforms_),
501 | batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
502 | testloader = DataLoader(torchvision.datasets.ImageFolder(testroot, transform=transforms_),
503 | batch_size=4, shuffle=True, num_workers=8, pin_memory=True)
504 |
505 | cuda = torch.cuda.is_available()
506 | device = torch.device("cuda:2" if cuda else "cpu")
507 | schedule_opt = {'schedule': 'linear', 'n_timestep': 2000, 'linear_start': 1e-4, 'linear_end': 0.05}
508 |
509 | sr3 = SR3(device, img_size=img_size, LR_size=LR_size, loss_type='l1',
510 | dataloader=dataloader, testloader=testloader, schedule_opt=schedule_opt,
511 | save_path='./SR3.pt', load_path='./SR3.pt', load=True, inner_channel=96,
512 | norm_groups=16, channel_mults=(1, 2, 2, 2), dropout=0.2, res_blocks=2, lr=1e-5, distributed=False)
513 | sr3.train(epoch=250, verbose=25)
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import data as Data
3 | import model as Model
4 | import argparse
5 | import logging
6 | import core.logger as Logger
7 | import core.metrics as Metrics
8 | from core.wandb_logger import WandbLogger
9 | from tensorboardX import SummaryWriter
10 | import os
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('-c', '--config', type=str, default='config/sr_sr3_64_512.json',
15 | help='JSON file for configuration')
16 | parser.add_argument('-p', '--phase', type=str, choices=['val'], help='val(generation)', default='val')
17 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
18 | parser.add_argument('-debug', '-d', action='store_true')
19 | parser.add_argument('-enable_wandb', action='store_true')
20 | parser.add_argument('-log_infer', action='store_true')
21 |
22 | # parse configs
23 | args = parser.parse_args()
24 | opt = Logger.parse(args)
25 | # Convert to NoneDict, which returns None for missing keys.
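# (e.g. opt['path']['resume_state'] then evaluates to None when the key is absent,
# instead of raising a KeyError)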
26 | opt = Logger.dict_to_nonedict(opt) 27 | 28 | # logging 29 | torch.backends.cudnn.enabled = True 30 | torch.backends.cudnn.benchmark = True 31 | 32 | Logger.setup_logger(None, opt['path']['log'], 33 | 'train', level=logging.INFO, screen=True) 34 | Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO) 35 | logger = logging.getLogger('base') 36 | logger.info(Logger.dict2str(opt)) 37 | tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger']) 38 | 39 | # Initialize WandbLogger 40 | if opt['enable_wandb']: 41 | wandb_logger = WandbLogger(opt) 42 | else: 43 | wandb_logger = None 44 | 45 | # dataset 46 | for phase, dataset_opt in opt['datasets'].items(): 47 | if phase == 'val': 48 | val_set = Data.create_dataset(dataset_opt, phase) 49 | val_loader = Data.create_dataloader( 50 | val_set, dataset_opt, phase) 51 | logger.info('Initial Dataset Finished') 52 | 53 | # model 54 | diffusion = Model.create_model(opt) 55 | logger.info('Initial Model Finished') 56 | 57 | diffusion.set_new_noise_schedule( 58 | opt['model']['beta_schedule']['val'], schedule_phase='val') 59 | 60 | logger.info('Begin Model Inference.') 61 | current_step = 0 62 | current_epoch = 0 63 | idx = 0 64 | 65 | result_path = '{}'.format(opt['path']['results']) 66 | os.makedirs(result_path, exist_ok=True) 67 | for _, val_data in enumerate(val_loader): 68 | idx += 1 69 | diffusion.feed_data(val_data) 70 | diffusion.test(continous=True) 71 | visuals = diffusion.get_current_visuals(need_LR=False) 72 | 73 | hr_img = Metrics.tensor2img(visuals['HR']) # uint8 74 | fake_img = Metrics.tensor2img(visuals['INF']) # uint8 75 | 76 | sr_img_mode = 'grid' 77 | if sr_img_mode == 'single': 78 | # single img series 79 | sr_img = visuals['SR'] # uint8 80 | sample_num = sr_img.shape[0] 81 | for iter in range(0, sample_num): 82 | Metrics.save_img( 83 | Metrics.tensor2img(sr_img[iter]), '{}/{}_{}_sr_{}.png'.format(result_path, current_step, idx, iter)) 84 | else: 85 | # grid img 86 | sr_img = Metrics.tensor2img(visuals['SR']) # uint8 87 | Metrics.save_img( 88 | sr_img, '{}/{}_{}_sr_process.png'.format(result_path, current_step, idx)) 89 | Metrics.save_img( 90 | Metrics.tensor2img(visuals['SR'][-1]), '{}/{}_{}_sr.png'.format(result_path, current_step, idx)) 91 | 92 | Metrics.save_img( 93 | hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx)) 94 | Metrics.save_img( 95 | fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx)) 96 | 97 | if wandb_logger and opt['log_infer']: 98 | wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img) 99 | 100 | if wandb_logger and opt['log_infer']: 101 | wandb_logger.log_eval_table(commit=True) 102 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def create_model(opt): 6 | from .model import DDPM as M 7 | m = M(opt) 8 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 9 | return m 10 | -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BaseModel(): 7 | def __init__(self, opt): 8 | self.opt = opt 9 | self.device = torch.device( 10 | 'cuda' if opt['gpu_ids'] is not None else 'cpu') 11 | self.begin_step = 0 12 | 
self.begin_epoch = 0
13 |
14 | def feed_data(self, data):
15 | pass
16 |
17 | def optimize_parameters(self):
18 | pass
19 |
20 | def get_current_visuals(self):
21 | pass
22 |
23 | def get_current_losses(self):
24 | pass
25 |
26 | def print_network(self):
27 | pass
28 |
29 | def set_device(self, x):
30 | if isinstance(x, dict):
31 | for key, item in x.items():
32 | if item is not None:
33 | x[key] = item.to(self.device)
34 | elif isinstance(x, list):
35 | # NB: rebinding the loop variable would not update the list,
36 | # so rebuild it explicitly
37 | x = [item.to(self.device) if item is not None else item for item in x]
38 | else:
39 | x = x.to(self.device)
40 | return x
41 |
42 | def get_network_description(self, network):
43 | '''Get the string and total parameters of the network'''
44 | if isinstance(network, nn.DataParallel):
45 | network = network.module
46 | s = str(network)
47 | n = sum(map(lambda x: x.numel(), network.parameters()))
48 | return s, n
49 |
--------------------------------------------------------------------------------
/model/ddpm_modules/diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 | from functools import partial
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 |
11 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
12 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
13 | warmup_time = int(n_timestep * warmup_frac)
14 | betas[:warmup_time] = np.linspace(
15 | linear_start, linear_end, warmup_time, dtype=np.float64)
16 | return betas
17 |
18 |
19 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
20 | if schedule == 'quad':
21 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
22 | n_timestep, dtype=np.float64) ** 2
23 | elif schedule == 'linear':
24 | betas = np.linspace(linear_start, linear_end,
25 | n_timestep, dtype=np.float64)
26 | elif schedule == 'warmup10':
27 | betas = _warmup_beta(linear_start, linear_end,
28 | n_timestep, 0.1)
29 | elif schedule == 'warmup50':
30 | betas = _warmup_beta(linear_start, linear_end,
31 | n_timestep, 0.5)
32 | elif schedule == 'const':
33 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
34 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
35 | betas = 1.
/ np.linspace(n_timestep, 36 | 1, n_timestep, dtype=np.float64) 37 | elif schedule == "cosine": 38 | timesteps = ( 39 | torch.arange(n_timestep + 1, dtype=torch.float64) / 40 | n_timestep + cosine_s 41 | ) 42 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 43 | alphas = torch.cos(alphas).pow(2) 44 | alphas = alphas / alphas[0] 45 | betas = 1 - alphas[1:] / alphas[:-1] 46 | betas = betas.clamp(max=0.999) 47 | else: 48 | raise NotImplementedError(schedule) 49 | return betas 50 | 51 | 52 | # gaussian diffusion trainer class 53 | 54 | def exists(x): 55 | return x is not None 56 | 57 | 58 | def default(val, d): 59 | if exists(val): 60 | return val 61 | return d() if isfunction(d) else d 62 | 63 | 64 | def extract(a, t, x_shape): 65 | b, *_ = t.shape 66 | out = a.gather(-1, t) 67 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 68 | 69 | 70 | def noise_like(shape, device, repeat=False): 71 | def repeat_noise(): return torch.randn( 72 | (1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 73 | 74 | def noise(): return torch.randn(shape, device=device) 75 | return repeat_noise() if repeat else noise() 76 | 77 | 78 | class GaussianDiffusion(nn.Module): 79 | def __init__( 80 | self, 81 | denoise_fn, 82 | image_size, 83 | channels=3, 84 | loss_type='l1', 85 | conditional=True, 86 | schedule_opt=None 87 | ): 88 | super().__init__() 89 | self.channels = channels 90 | self.image_size = image_size 91 | self.denoise_fn = denoise_fn 92 | self.conditional = conditional 93 | self.loss_type = loss_type 94 | if schedule_opt is not None: 95 | pass 96 | # self.set_new_noise_schedule(schedule_opt) 97 | 98 | def set_loss(self, device): 99 | if self.loss_type == 'l1': 100 | self.loss_func = nn.L1Loss(reduction='sum').to(device) 101 | elif self.loss_type == 'l2': 102 | self.loss_func = nn.MSELoss(reduction='sum').to(device) 103 | else: 104 | raise NotImplementedError() 105 | 106 | def set_new_noise_schedule(self, schedule_opt, device): 107 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 108 | betas = make_beta_schedule( 109 | schedule=schedule_opt['schedule'], 110 | n_timestep=schedule_opt['n_timestep'], 111 | linear_start=schedule_opt['linear_start'], 112 | linear_end=schedule_opt['linear_end']) 113 | betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas 114 | alphas = 1. - betas 115 | alphas_cumprod = np.cumprod(alphas, axis=0) 116 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 117 | 118 | timesteps, = betas.shape 119 | self.num_timesteps = int(timesteps) 120 | self.register_buffer('betas', to_torch(betas)) 121 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 122 | self.register_buffer('alphas_cumprod_prev', 123 | to_torch(alphas_cumprod_prev)) 124 | 125 | # calculations for diffusion q(x_t | x_{t-1}) and others 126 | self.register_buffer('sqrt_alphas_cumprod', 127 | to_torch(np.sqrt(alphas_cumprod))) 128 | self.register_buffer('sqrt_one_minus_alphas_cumprod', 129 | to_torch(np.sqrt(1. - alphas_cumprod))) 130 | self.register_buffer('log_one_minus_alphas_cumprod', 131 | to_torch(np.log(1. - alphas_cumprod))) 132 | self.register_buffer('sqrt_recip_alphas_cumprod', 133 | to_torch(np.sqrt(1. / alphas_cumprod))) 134 | self.register_buffer('sqrt_recipm1_alphas_cumprod', 135 | to_torch(np.sqrt(1. / alphas_cumprod - 1))) 136 | 137 | # calculations for posterior q(x_{t-1} | x_t, x_0) 138 | posterior_variance = betas * \ 139 | (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 140 | # above: equal to 1. / (1. 
/ (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 141 | self.register_buffer('posterior_variance', 142 | to_torch(posterior_variance)) 143 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 144 | self.register_buffer('posterior_log_variance_clipped', to_torch( 145 | np.log(np.maximum(posterior_variance, 1e-20)))) 146 | self.register_buffer('posterior_mean_coef1', to_torch( 147 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 148 | self.register_buffer('posterior_mean_coef2', to_torch( 149 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) 150 | 151 | def q_mean_variance(self, x_start, t): 152 | mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 153 | variance = extract(1. - self.alphas_cumprod, t, x_start.shape) 154 | log_variance = extract( 155 | self.log_one_minus_alphas_cumprod, t, x_start.shape) 156 | return mean, variance, log_variance 157 | 158 | def predict_start_from_noise(self, x_t, t, noise): 159 | return ( 160 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - 161 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise 162 | ) 163 | 164 | def q_posterior(self, x_start, x_t, t): 165 | posterior_mean = ( 166 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + 167 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t 168 | ) 169 | posterior_variance = extract(self.posterior_variance, t, x_t.shape) 170 | posterior_log_variance_clipped = extract( 171 | self.posterior_log_variance_clipped, t, x_t.shape) 172 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 173 | 174 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None): 175 | if condition_x is not None: 176 | x_recon = self.predict_start_from_noise( 177 | x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), t)) 178 | else: 179 | x_recon = self.predict_start_from_noise( 180 | x, t=t, noise=self.denoise_fn(x, t)) 181 | 182 | if clip_denoised: 183 | x_recon.clamp_(-1., 1.) 
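# (clamping the reconstructed x_0 to the data range [-1, 1] keeps the posterior mean
# well-behaved early in sampling, when the noise estimate is least reliable)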
184 | 185 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior( 186 | x_start=x_recon, x_t=x, t=t) 187 | return model_mean, posterior_variance, posterior_log_variance 188 | 189 | @torch.no_grad() 190 | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False, condition_x=None): 191 | b, *_, device = *x.shape, x.device 192 | model_mean, _, model_log_variance = self.p_mean_variance( 193 | x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x) 194 | noise = noise_like(x.shape, device, repeat_noise) 195 | # no noise when t == 0 196 | nonzero_mask = (1 - (t == 0).float()).reshape(b, 197 | *((1,) * (len(x.shape) - 1))) 198 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise 199 | 200 | @torch.no_grad() 201 | def p_sample_loop(self, x_in, continous=False): 202 | device = self.betas.device 203 | sample_inter = (1 | (self.num_timesteps//10)) 204 | 205 | if not self.conditional: 206 | shape = x_in 207 | b = shape[0] 208 | img = torch.randn(shape, device=device) 209 | ret_img = img 210 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 211 | img = self.p_sample(img, torch.full( 212 | (b,), i, device=device, dtype=torch.long)) 213 | if i % sample_inter == 0: 214 | ret_img = torch.cat([ret_img, img], dim=0) 215 | return img 216 | else: 217 | x = x_in 218 | shape = x.shape 219 | b = shape[0] 220 | img = torch.randn(shape, device=device) 221 | ret_img = x 222 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 223 | img = self.p_sample(img, torch.full( 224 | (b,), i, device=device, dtype=torch.long), condition_x=x) 225 | if i % sample_inter == 0: 226 | ret_img = torch.cat([ret_img, img], dim=0) 227 | if continous: 228 | return ret_img 229 | else: 230 | return ret_img[-1] 231 | 232 | @torch.no_grad() 233 | def sample(self, batch_size=1, continous=False): 234 | image_size = self.image_size 235 | channels = self.channels 236 | return self.p_sample_loop((batch_size, channels, image_size, image_size), continous) 237 | 238 | @torch.no_grad() 239 | def super_resolution(self, x_in, continous=False): 240 | return self.p_sample_loop(x_in, continous) 241 | 242 | @torch.no_grad() 243 | def interpolate(self, x1, x2, t=None, lam=0.5): 244 | b, *_, device = *x1.shape, x1.device 245 | t = default(t, self.num_timesteps - 1) 246 | 247 | assert x1.shape == x2.shape 248 | 249 | t_batched = torch.stack([torch.tensor(t, device=device)] * b) 250 | xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) 251 | 252 | img = (1 - lam) * xt1 + lam * xt2 253 | for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t): 254 | img = self.p_sample(img, torch.full( 255 | (b,), i, device=device, dtype=torch.long)) 256 | 257 | return img 258 | 259 | def q_sample(self, x_start, t, noise=None): 260 | noise = default(noise, lambda: torch.randn_like(x_start)) 261 | 262 | # fix gama 263 | return ( 264 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 265 | extract(self.sqrt_one_minus_alphas_cumprod, 266 | t, x_start.shape) * noise 267 | ) 268 | 269 | def p_losses(self, x_in, noise=None): 270 | x_start = x_in['HR'] 271 | [b, c, h, w , l] = x_in['HR'].shape 272 | t = torch.randint(0, self.num_timesteps, (b,), 273 | device=x_start.device).long() 274 | 275 | noise = default(noise, lambda: torch.randn_like(x_start)) 276 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) 277 | 278 | if not 
self.conditional: 279 | x_recon = self.denoise_fn(x_noisy, t) 280 | else: 281 | x_recon = self.denoise_fn( 282 | torch.cat([x_in['SR'], x_noisy], dim=1), t) 283 | loss = self.loss_func(noise, x_recon) 284 | 285 | return loss 286 | 287 | def forward(self, x, *args, **kwargs): 288 | return self.p_losses(x, *args, **kwargs) 289 | -------------------------------------------------------------------------------- /model/ddpm_modules/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from inspect import isfunction 8 | 9 | 10 | def exists(x): 11 | return x is not None 12 | 13 | 14 | def default(val, d): 15 | if exists(val): 16 | return val 17 | return d() if isfunction(d) else d 18 | 19 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py 20 | class PositionalEncoding(nn.Module): 21 | def __init__(self, dim): 22 | super().__init__() 23 | self.dim = dim 24 | 25 | def forward(self, noise_level): 26 | count = self.dim // 2 27 | step = torch.arange(count, dtype=noise_level.dtype, 28 | device=noise_level.device) / count 29 | encoding = noise_level.unsqueeze( 30 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) 31 | encoding = torch.cat( 32 | [torch.sin(encoding), torch.cos(encoding)], dim=-1) 33 | return encoding 34 | 35 | 36 | class FeatureWiseAffine(nn.Module): 37 | def __init__(self, in_channels, out_channels, use_affine_level=False): 38 | super(FeatureWiseAffine, self).__init__() 39 | 40 | self.use_affine_level = use_affine_level 41 | self.noise_func = nn.Sequential( 42 | nn.Linear(in_channels,out_channels*(1+self.use_affine_level)) 43 | ) 44 | 45 | def forward(self, x, noise_embed): 46 | batch = x.shape[0] 47 | if noise_embed is None: 48 | return x 49 | elif self.use_affine_level: 50 | gamma, beta = self.noise_func(noise_embed).view( 51 | batch, -1, 1, 1, 1).chunk(2, dim=1) 52 | x = (1 + gamma) * x + beta 53 | else: 54 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1, 1) 55 | return x 56 | 57 | 58 | class Swish(nn.Module): 59 | def forward(self, x): 60 | return x * torch.sigmoid(x) 61 | 62 | 63 | class Upsample(nn.Module): 64 | def __init__(self, dim): 65 | super().__init__() 66 | self.up = nn.Upsample(scale_factor=2, mode="nearest") 67 | self.conv = nn.Conv3d(dim, dim, 3, padding=1) 68 | 69 | def forward(self, x): 70 | return self.conv(self.up(x)) 71 | 72 | 73 | class Downsample(nn.Module): 74 | def __init__(self, dim): 75 | super().__init__() 76 | self.conv = nn.Conv3d(dim, dim, 3, 2, 1) 77 | 78 | def forward(self, x): 79 | return self.conv(x) 80 | 81 | 82 | # building block modules 83 | 84 | 85 | class Block(nn.Module): 86 | def __init__(self, dim, dim_out, groups=16, dropout=0): 87 | super().__init__() 88 | self.block = nn.Sequential( 89 | nn.GroupNorm(groups, dim), 90 | Swish(), 91 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 92 | nn.Conv3d(dim, dim_out, 3, padding=1) 93 | ) 94 | 95 | def forward(self, x): 96 | return self.block(x) 97 | 98 | 99 | class ResnetBlock(nn.Module): 100 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=16): 101 | super().__init__() 102 | if noise_level_emb_dim is not None: 103 | self.noise_func = FeatureWiseAffine( 104 | noise_level_emb_dim, dim_out, use_affine_level) 105 | 106 | self.block1 = Block(dim, dim_out, groups=norm_groups) 107 | self.block2 = Block(dim_out, 
dim_out, groups=norm_groups, dropout=dropout) 108 | self.res_conv = nn.Conv3d( 109 | dim, dim_out, 1) if dim != dim_out else nn.Identity() 110 | 111 | def forward(self, x, time_emb): 112 | b, c, h, w, d = x.shape 113 | h = self.block1(x) 114 | if time_emb is not None: 115 | h = self.noise_func(h, time_emb) 116 | h = self.block2(h) 117 | return h + self.res_conv(x) 118 | 119 | 120 | class SelfAttention(nn.Module): 121 | def __init__(self, in_channel, n_head=1, norm_groups=32): 122 | super().__init__() 123 | 124 | self.n_head = n_head 125 | 126 | self.norm = nn.GroupNorm(norm_groups, in_channel) 127 | self.qkv = nn.Conv3d(in_channel, in_channel * 3, 1, bias=False) 128 | self.out = nn.Conv3d(in_channel, in_channel, 1) 129 | 130 | def forward(self, input): 131 | batch, channel, height, width, depth = input.shape 132 | n_head = self.n_head 133 | head_dim = channel // n_head 134 | 135 | norm = self.norm(input) 136 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width, depth) 137 | query, key, value = qkv.chunk(3, dim=2) 138 | 139 | attn = torch.einsum( 140 | "bnchwd, bncyxz -> bnhwdyxz", query, key 141 | ).contiguous() / math.sqrt(channel) 142 | attn = attn.view(batch, n_head, height, width, depth, -1) 143 | attn = torch.softmax(attn, -1) 144 | attn = attn.view(batch, n_head, height, width, depth, height, width, depth) 145 | 146 | out = torch.einsum("bnhwdyxz, bncyxz -> bnchwd", attn, value).contiguous() 147 | out = self.out(out.view(batch, channel, height, width, depth)) 148 | 149 | return out + input 150 | 151 | 152 | class ResnetBlocWithAttn(nn.Module): 153 | def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False): 154 | super().__init__() 155 | self.with_attn = with_attn 156 | self.res_block = ResnetBlock( 157 | dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout) 158 | if with_attn: 159 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups) 160 | 161 | def forward(self, x, time_emb): 162 | x = self.res_block(x, time_emb) 163 | if(self.with_attn): 164 | x = self.attn(x) 165 | return x 166 | 167 | 168 | class UNet(nn.Module): 169 | def __init__( 170 | self, 171 | in_channel=1, 172 | out_channel=1, 173 | inner_channel=32, 174 | norm_groups=16, 175 | channel_mults=(1, 2, 4, 8, 8), 176 | attn_res=(8,), 177 | res_blocks=3, 178 | dropout=0, 179 | with_noise_level_emb=False, 180 | image_size=64 181 | ): 182 | super().__init__() 183 | 184 | if with_noise_level_emb: 185 | noise_level_channel = inner_channel 186 | self.noise_level_mlp = nn.Sequential( 187 | PositionalEncoding(inner_channel), 188 | nn.Linear(inner_channel, inner_channel * 4), 189 | Swish(), 190 | nn.Linear(inner_channel * 4, inner_channel) 191 | ) 192 | else: 193 | noise_level_channel = None 194 | self.noise_level_mlp = None 195 | 196 | num_mults = len(channel_mults) 197 | pre_channel = inner_channel 198 | feat_channels = [pre_channel] 199 | now_res = image_size 200 | downs = [nn.Conv3d(in_channel, inner_channel, 201 | kernel_size=3, padding=1)] 202 | for ind in range(num_mults): 203 | is_last = (ind == num_mults - 1) 204 | use_attn = (now_res in attn_res) 205 | channel_mult = inner_channel * channel_mults[ind] 206 | for _ in range(0, res_blocks): 207 | downs.append(ResnetBlocWithAttn( 208 | pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn)) 209 | feat_channels.append(channel_mult) 210 | pre_channel = channel_mult 211 | if not is_last: 212 | 
downs.append(Downsample(pre_channel)) 213 | feat_channels.append(pre_channel) 214 | now_res = now_res//2 215 | self.downs = nn.ModuleList(downs) 216 | 217 | self.mid = nn.ModuleList([ 218 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 219 | dropout=dropout, with_attn=True), 220 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 221 | dropout=dropout, with_attn=False) 222 | ]) 223 | 224 | ups = [] 225 | for ind in reversed(range(num_mults)): 226 | is_last = (ind < 1) 227 | use_attn = (now_res in attn_res) 228 | channel_mult = inner_channel * channel_mults[ind] 229 | for _ in range(0, res_blocks+1): 230 | ups.append(ResnetBlocWithAttn( 231 | pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 232 | dropout=dropout, with_attn=use_attn)) 233 | pre_channel = channel_mult 234 | if not is_last: 235 | ups.append(Upsample(pre_channel)) 236 | now_res = now_res*2 237 | 238 | self.ups = nn.ModuleList(ups) 239 | 240 | self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups) 241 | 242 | def forward(self, x, time): 243 | t = self.noise_level_mlp(time) if exists( 244 | self.noise_level_mlp) else None 245 | 246 | feats = [] 247 | for layer in self.downs: 248 | if isinstance(layer, ResnetBlocWithAttn): 249 | x = layer(x, t) 250 | else: 251 | x = layer(x) 252 | feats.append(x) 253 | 254 | for layer in self.mid: 255 | if isinstance(layer, ResnetBlocWithAttn): 256 | 257 | x = layer(x, t) 258 | else: 259 | 260 | x = layer(x) 261 | 262 | for layer in self.ups: 263 | if isinstance(layer, ResnetBlocWithAttn): 264 | x = layer(torch.cat((x, feats.pop()), dim=1), t) 265 | else: 266 | 267 | x = layer(x) 268 | 269 | return self.final_conv(x) 270 | 271 | 272 | if __name__ == "__main__": 273 | model = UNet().to("cuda") 274 | noise_level = torch.FloatTensor( 275 | [0.5]).repeat(1, 1).to("cuda") 276 | x = torch.randn(1, 1, 64, 64, 64).to("cuda") 277 | y = model(x, noise_level) 278 | print(y.shape) -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import os 7 | import model.networks as networks 8 | from .base_model import BaseModel 9 | logger = logging.getLogger('base') 10 | 11 | 12 | class DDPM(BaseModel): 13 | def __init__(self, opt): 14 | super(DDPM, self).__init__(opt) 15 | # define network and load pretrained models 16 | self.netP = self.set_device(networks.define_P(opt)) 17 | self.netG = self.set_device(networks.define_G(opt)) 18 | self.schedule_phase = None 19 | # set loss and load resume state 20 | self.loss_func = nn.L1Loss(reduction='sum').to(self.device) 21 | self.lr = opt['train']["optimizer"]["lr"] 22 | self.old_lr = self.lr 23 | self.set_loss() 24 | self.set_new_noise_schedule( 25 | opt['model']['beta_schedule']['train'], schedule_phase='train') 26 | if self.opt['phase'] == 'train': 27 | self.netG.train() 28 | self.netP.train() 29 | # find the parameters to optimize 30 | if opt['model']['finetune_norm']: 31 | optim_params = [] 32 | optim_params_P = [] 33 | for k, v in self.netG.named_parameters(): 34 | v.requires_grad = False 35 | if k.find('transformer') >= 0: 36 | v.requires_grad = True 37 | v.data.zero_() 38 | optim_params.append(v) 39 | logger.info( 40 
| 'Params [{:s}] initialized to 0 and will be optimized.'.format(k))
41 | for k, v in self.netP.named_parameters():
42 | v.requires_grad = False
43 | if k.find('transformer') >= 0:
44 | v.requires_grad = True
45 | v.data.zero_()
46 | optim_params_P.append(v) # fixed: PreNet params belong in optim_params_P, not optim_params
47 | logger.info(
48 | 'Params [{:s}] initialized to 0 and will be optimized.'.format(k))
49 | else:
50 | optim_params = list(self.netG.parameters())
51 | optim_params_P = list(self.netP.parameters())
52 | self.optG = torch.optim.Adam(
53 | optim_params, lr=opt['train']["optimizer"]["lr"])
54 | self.optP = torch.optim.Adam(
55 | optim_params_P, lr=opt['train']["optimizer"]["lr"])
56 | self.log_dict = OrderedDict()
57 | self.load_network()
58 | self.print_network()
59 |
60 | def feed_data(self, data):
61 | self.data = self.set_device(data)
62 |
63 | def optimize_parameters(self):
64 | self.optG.zero_grad()
65 | self.optP.zero_grad()
66 | # Run PreNet to obtain the initial prediction
67 | self.initial_predict()
68 | # Compute the residual and use it as x_start for the loss
69 | self.data['IP'] = self.IP
70 | self.data['RS'] = self.data['HR'] - self.IP
71 | l_pix = self.netG(self.data)
72 | # need to average in multi-gpu
73 | b, c, h, w = self.data['HR'].shape
74 | l_pix = (l_pix.sum())/int(b*c*h*w)
75 | l_pix.backward()
76 | # Update both networks
77 | self.optG.step()
78 | self.optP.step()
79 | # set log
80 | self.log_dict['l_pix'] = l_pix.item()
81 | # self.log_dict['loss_pix'] = l_loss.item()
82 | def initial_predict(self):
83 | self.IP = self.netP(self.data['SR'], time=None)
84 |
85 | def test(self, continous=False):
86 | self.netG.eval()
87 | self.netP.eval()
88 | with torch.no_grad():
89 |
90 | if isinstance(self.netG, nn.DataParallel):
91 | self.SR = self.netG.module.super_resolution(
92 | self.data['SR'], continous)
93 | else:
94 | self.SR = self.netG.super_resolution(
95 | self.data['SR'], continous)
96 | self.netG.train()
97 | self.netP.train()
98 |
99 | def sample(self, batch_size=1, continous=False):
100 | self.netG.eval()
101 | with torch.no_grad():
102 | if isinstance(self.netG, nn.DataParallel):
103 | self.SR = self.netG.module.sample(batch_size, continous)
104 | else:
105 | self.SR = self.netG.sample(batch_size, continous)
106 | self.netG.train()
107 |
108 | def set_loss(self):
109 | if isinstance(self.netG, nn.DataParallel):
110 | self.netG.module.set_loss(self.device)
111 | else:
112 | self.netG.set_loss(self.device)
113 |
114 |
115 | def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'):
116 | if self.schedule_phase is None or self.schedule_phase != schedule_phase:
117 | self.schedule_phase = schedule_phase
118 | if isinstance(self.netG, nn.DataParallel):
119 | self.netG.module.set_new_noise_schedule(
120 | schedule_opt, self.device)
121 | else:
122 | self.netG.set_new_noise_schedule(schedule_opt, self.device)
123 |
124 |
125 | def get_current_log(self):
126 | return self.log_dict
127 |
128 | def get_current_visuals(self, need_LR=True, sample=False):
129 | out_dict = OrderedDict()
130 | if sample:
131 | out_dict['SAM'] = self.SR.detach().float().cpu()
132 | else:
133 | out_dict['SR'] = self.SR.detach().float().cpu()
134 | out_dict['INF'] = self.data['SR'].detach().float().cpu()
135 | out_dict['HR'] = self.data['HR'].detach().float().cpu()
136 | if need_LR and 'LR' in self.data:
137 | out_dict['LR'] = self.data['LR'].detach().float().cpu()
138 | else:
139 | out_dict['LR'] = out_dict['INF']
140 | return out_dict
141 |
142 | def print_network(self):
143 | s, n = self.get_network_description(self.netG)
144 | if isinstance(self.netG, nn.DataParallel):
145 | net_struc_str = '{} -
{}'.format(self.netG.__class__.__name__,
146 | self.netG.module.__class__.__name__)
147 | else:
148 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
149 |
150 | logger.info(
151 | 'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
152 | logger.info(s)
153 |
154 | def save_network(self, epoch, iter_step):
155 | # Save PreNet
156 | gen_path = os.path.join(
157 | self.opt['path']['checkpoint'], 'I{}_E{}_PreNet_gen.pth'.format(iter_step, epoch))
158 | opt_path = os.path.join(
159 | self.opt['path']['checkpoint'], 'I{}_E{}_PreNet_opt.pth'.format(iter_step, epoch))
160 | # gen
161 | network = self.netP
162 | if isinstance(self.netP, nn.DataParallel):
163 | network = network.module
164 | state_dict = network.state_dict()
165 | for key, param in state_dict.items():
166 | state_dict[key] = param.cpu()
167 | torch.save(state_dict, gen_path)
168 | # opt
169 | opt_state = {'epoch': epoch, 'iter': iter_step,
170 | 'scheduler': None, 'optimizer': None}
171 | opt_state['optimizer'] = self.optP.state_dict()
172 | torch.save(opt_state, opt_path)
173 |
174 | # Save DenoiseNet
175 | gen_path = os.path.join(
176 | self.opt['path']['checkpoint'], 'I{}_E{}_DenoiseNet_gen.pth'.format(iter_step, epoch))
177 | opt_path = os.path.join(
178 | self.opt['path']['checkpoint'], 'I{}_E{}_DenoiseNet_opt.pth'.format(iter_step, epoch))
179 | # gen
180 | network = self.netG
181 | if isinstance(self.netG, nn.DataParallel):
182 | network = network.module
183 | state_dict = network.state_dict()
184 | for key, param in state_dict.items():
185 | state_dict[key] = param.cpu()
186 | torch.save(state_dict, gen_path)
187 | # opt
188 | opt_state = {'epoch': epoch, 'iter': iter_step,
189 | 'scheduler': None, 'optimizer': None}
190 | opt_state['optimizer'] = self.optG.state_dict()
191 | torch.save(opt_state, opt_path)
192 |
193 | logger.info(
194 | 'Saved model in [{:s}] ...'.format(gen_path))
195 |
196 | def load_network(self):
197 | # Load PreNet
198 | if self.opt['path']['resume_state'] is not None:
199 | load_path = self.opt['path']['resume_state']
200 | logger.info(
201 | 'Loading pretrained model for P [{:s}] ...'.format(load_path))
202 | gen_path = '{}_PreNet_gen.pth'.format(load_path)
203 | opt_path = '{}_PreNet_opt.pth'.format(load_path)
204 | # gen
205 | network = self.netP
206 | if isinstance(self.netP, nn.DataParallel):
207 | network = network.module
208 | network.load_state_dict(torch.load(
209 | gen_path), strict=(not self.opt['model']['finetune_norm']))
210 | # network.load_state_dict(torch.load(
211 | # gen_path), strict=False)
212 | if self.opt['phase'] == 'train':
213 | # optimizer
214 | opt = torch.load(opt_path)
215 | self.optP.load_state_dict(opt['optimizer'])
216 | self.begin_step = opt['iter']
217 | self.begin_epoch = opt['epoch']
218 |
219 | # Load DenoiseNet
220 | if self.opt['path']['resume_state'] is not None:
221 | load_path = self.opt['path']['resume_state']
222 | logger.info(
223 | 'Loading pretrained model for G [{:s}] ...'.format(load_path))
224 | gen_path = '{}_DenoiseNet_gen.pth'.format(load_path)
225 | opt_path = '{}_DenoiseNet_opt.pth'.format(load_path)
226 | # gen
227 | network = self.netG
228 | if isinstance(self.netG, nn.DataParallel):
229 | network = network.module
230 | network.load_state_dict(torch.load(
231 | gen_path), strict=(not self.opt['model']['finetune_norm']))
232 | # network.load_state_dict(torch.load(
233 | # gen_path), strict=False)
234 | if self.opt['phase'] == 'train':
235 | # optimizer
236 | opt = torch.load(opt_path)
237 |
self.optG.load_state_dict(opt['optimizer']) 238 | self.begin_step = opt['iter'] 239 | self.begin_epoch = opt['epoch'] 240 | def update_learning_rate(self): 241 | self.niter_decay = 1000000 242 | if self.old_lr > 0.000001: 243 | lrd = 200 * self.lr / self.niter_decay 244 | lr = self.old_lr - lrd 245 | else: 246 | lr = self.old_lr 247 | for param_group in self.optP.param_groups: 248 | param_group['lr'] = lr 249 | for param_group in self.optG.param_groups: 250 | param_group['lr'] = lr 251 | print('update learning rate: %f -> %f' % (self.old_lr, lr)) 252 | self.old_lr = lr -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torch.nn import modules 7 | logger = logging.getLogger('base') 8 | #################### 9 | # initialize 10 | #################### 11 | 12 | 13 | def weights_init_normal(m, std=0.02): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.normal_(m.weight.data, 0.0, std) 17 | if m.bias is not None: 18 | m.bias.data.zero_() 19 | elif classname.find('Linear') != -1: 20 | init.normal_(m.weight.data, 0.0, std) 21 | if m.bias is not None: 22 | m.bias.data.zero_() 23 | elif classname.find('BatchNorm2d') != -1: 24 | init.normal_(m.weight.data, 1.0, std) # BN also uses norm 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | 28 | def weights_init_kaiming(m, scale=1): 29 | classname = m.__class__.__name__ 30 | if classname.find('Conv2d') != -1: 31 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 32 | m.weight.data *= scale 33 | if m.bias is not None: 34 | m.bias.data.zero_() 35 | elif classname.find('Linear') != -1: 36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 37 | m.weight.data *= scale 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | elif classname.find('BatchNorm2d') != -1: 41 | init.constant_(m.weight.data, 1.0) 42 | init.constant_(m.bias.data, 0.0) 43 | 44 | 45 | def weights_init_orthogonal(m): 46 | classname = m.__class__.__name__ 47 | if classname.find('Conv') != -1: 48 | init.orthogonal_(m.weight.data, gain=1) 49 | if m.bias is not None: 50 | m.bias.data.zero_() 51 | elif classname.find('Linear') != -1: 52 | init.orthogonal_(m.weight.data, gain=1) 53 | if m.bias is not None: 54 | m.bias.data.zero_() 55 | elif classname.find('BatchNorm2d') != -1: 56 | init.constant_(m.weight.data, 1.0) 57 | init.constant_(m.bias.data, 0.0) 58 | 59 | 60 | def init_weights(net, init_type='kaiming', scale=1, std=0.02): 61 | # scale for 'kaiming', std for 'normal'. 
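# (e.g. define_P and define_G below call init_weights(model, init_type='orthogonal'))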
62 | logger.info('Initialization method [{:s}]'.format(init_type)) 63 | if init_type == 'normal': 64 | weights_init_normal_ = functools.partial(weights_init_normal, std=std) 65 | net.apply(weights_init_normal_) 66 | elif init_type == 'kaiming': 67 | weights_init_kaiming_ = functools.partial( 68 | weights_init_kaiming, scale=scale) 69 | net.apply(weights_init_kaiming_) 70 | elif init_type == 'orthogonal': 71 | net.apply(weights_init_orthogonal) 72 | else: 73 | raise NotImplementedError( 74 | 'initialization method [{:s}] not implemented'.format(init_type)) 75 | 76 | 77 | #################### 78 | # define network 79 | #################### 80 | def define_P(opt): 81 | model_opt = opt['model'] 82 | if model_opt['which_model_G'] == 'ddpm': 83 | from .ddpm_modules import diffusion, unet 84 | elif model_opt['which_model_G'] == 'sr3': 85 | from .sr3_modules import diffusion, unet 86 | if ('norm_groups' not in model_opt['unet']['PreNet']) or model_opt['unet']['PreNet']['norm_groups'] is None: 87 | model_opt['unet']['PreNet']['norm_groups']=32 88 | model = unet.UNet( 89 | in_channel=model_opt['unet']['PreNet']['in_channel'], 90 | out_channel=model_opt['unet']['PreNet']['out_channel'], 91 | norm_groups=model_opt['unet']['PreNet']['norm_groups'], 92 | inner_channel=model_opt['unet']['PreNet']['inner_channel'], 93 | channel_mults=model_opt['unet']['PreNet']['channel_multiplier'], 94 | attn_res=model_opt['unet']['PreNet']['attn_res'], 95 | res_blocks=model_opt['unet']['PreNet']['res_blocks'], 96 | dropout=model_opt['unet']['PreNet']['dropout'], 97 | with_noise_level_emb= False, 98 | image_size=model_opt['diffusion']['image_size'] 99 | ) 100 | 101 | if opt['phase'] == 'train': 102 | # init_weights(netG, init_type='kaiming', scale=0.1) 103 | init_weights(model, init_type='orthogonal') 104 | if opt['gpu_ids'] and opt['distributed']: 105 | assert torch.cuda.is_available() 106 | model = nn.DataParallel(model) 107 | return model 108 | 109 | # Generator 110 | def define_G(opt): 111 | model_opt = opt['model'] 112 | if model_opt['which_model_G'] == 'ddpm': 113 | from .ddpm_modules import diffusion, unet 114 | elif model_opt['which_model_G'] == 'sr3': 115 | from .sr3_modules import diffusion, unet 116 | if ('norm_groups' not in model_opt['unet']['DenoiseNet']) or model_opt['unet']['DenoiseNet']['norm_groups'] is None: 117 | model_opt['unet']['DenoiseNet']['norm_groups']=32 118 | model = unet.UNet( 119 | in_channel=model_opt['unet']['DenoiseNet']['in_channel'], 120 | out_channel=model_opt['unet']['DenoiseNet']['out_channel'], 121 | norm_groups=model_opt['unet']['DenoiseNet']['norm_groups'], 122 | inner_channel=model_opt['unet']['DenoiseNet']['inner_channel'], 123 | channel_mults=model_opt['unet']['DenoiseNet']['channel_multiplier'], 124 | attn_res=model_opt['unet']['DenoiseNet']['attn_res'], 125 | res_blocks=model_opt['unet']['DenoiseNet']['res_blocks'], 126 | dropout=model_opt['unet']['DenoiseNet']['dropout'], 127 | image_size=model_opt['diffusion']['image_size'] 128 | ) 129 | netG = diffusion.GaussianDiffusion( 130 | model, 131 | image_size=model_opt['diffusion']['image_size'], 132 | channels=model_opt['diffusion']['channels'], 133 | loss_type='l1', # L1 or L2 134 | conditional=model_opt['diffusion']['conditional'], 135 | schedule_opt=model_opt['beta_schedule']['train'] 136 | ) 137 | if opt['phase'] == 'train': 138 | # init_weights(netG, init_type='kaiming', scale=0.1) 139 | init_weights(netG, init_type='orthogonal') 140 | if opt['gpu_ids'] and opt['distributed']: 141 | assert torch.cuda.is_available() 142 | 
netG = nn.DataParallel(netG) 143 | return netG 144 | 145 | -------------------------------------------------------------------------------- /model/sr3_modules/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import device, nn, einsum 4 | import torch.nn.functional as F 5 | from inspect import isfunction 6 | from functools import partial 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | 11 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac): 12 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 13 | warmup_time = int(n_timestep * warmup_frac) 14 | betas[:warmup_time] = np.linspace( 15 | linear_start, linear_end, warmup_time, dtype=np.float64) 16 | return betas 17 | 18 | 19 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 20 | if schedule == 'quad': 21 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, 22 | n_timestep, dtype=np.float64) ** 2 23 | elif schedule == 'linear': 24 | betas = np.linspace(linear_start, linear_end, 25 | n_timestep, dtype=np.float64) 26 | elif schedule == 'warmup10': 27 | betas = _warmup_beta(linear_start, linear_end, 28 | n_timestep, 0.1) 29 | elif schedule == 'warmup50': 30 | betas = _warmup_beta(linear_start, linear_end, 31 | n_timestep, 0.5) 32 | elif schedule == 'const': 33 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 34 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1 35 | betas = 1. / np.linspace(n_timestep, 36 | 1, n_timestep, dtype=np.float64) 37 | elif schedule == "cosine": 38 | timesteps = ( 39 | torch.arange(n_timestep + 1, dtype=torch.float64) / 40 | n_timestep + cosine_s 41 | ) 42 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 43 | alphas = torch.cos(alphas).pow(2) 44 | alphas = alphas / alphas[0] 45 | betas = 1 - alphas[1:] / alphas[:-1] 46 | betas = betas.clamp(max=0.999) 47 | else: 48 | raise NotImplementedError(schedule) 49 | return betas 50 | 51 | 52 | # gaussian diffusion trainer class 53 | 54 | def exists(x): 55 | return x is not None 56 | 57 | 58 | def default(val, d): 59 | if exists(val): 60 | return val 61 | return d() if isfunction(d) else d 62 | 63 | 64 | class GaussianDiffusion(nn.Module): 65 | def __init__( 66 | self, 67 | denoise_fn, 68 | image_size, 69 | channels=1, 70 | loss_type='l1', 71 | conditional=True, 72 | schedule_opt=None 73 | ): 74 | super().__init__() 75 | self.channels = channels 76 | self.image_size = image_size 77 | self.denoise_fn = denoise_fn 78 | self.loss_type = loss_type 79 | self.conditional = conditional 80 | if schedule_opt is not None: 81 | pass 82 | # self.set_new_noise_schedule(schedule_opt) 83 | 84 | def set_loss(self, device): 85 | if self.loss_type == 'l1': 86 | self.loss_func = nn.L1Loss(reduction='sum').to(device) 87 | elif self.loss_type == 'l2': 88 | self.loss_func = nn.MSELoss(reduction='sum').to(device) 89 | else: 90 | raise NotImplementedError() 91 | 92 | def set_new_noise_schedule(self, schedule_opt, device): 93 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 94 | 95 | betas = make_beta_schedule( 96 | schedule=schedule_opt['schedule'], 97 | n_timestep=schedule_opt['n_timestep'], 98 | linear_start=schedule_opt['linear_start'], 99 | linear_end=schedule_opt['linear_end']) 100 | betas = betas.detach().cpu().numpy() if isinstance( 101 | betas, torch.Tensor) else betas 102 | alphas = 1. 
- betas 103 | alphas_cumprod = np.cumprod(alphas, axis=0) 104 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 105 | self.sqrt_alphas_cumprod_prev = np.sqrt( 106 | np.append(1., alphas_cumprod)) 107 | 108 | timesteps, = betas.shape 109 | self.num_timesteps = int(timesteps) 110 | self.register_buffer('betas', to_torch(betas)) 111 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 112 | self.register_buffer('alphas_cumprod_prev', 113 | to_torch(alphas_cumprod_prev)) 114 | 115 | # calculations for diffusion q(x_t | x_{t-1}) and others 116 | self.register_buffer('sqrt_alphas_cumprod', 117 | to_torch(np.sqrt(alphas_cumprod))) 118 | self.register_buffer('sqrt_one_minus_alphas_cumprod', 119 | to_torch(np.sqrt(1. - alphas_cumprod))) 120 | self.register_buffer('log_one_minus_alphas_cumprod', 121 | to_torch(np.log(1. - alphas_cumprod))) 122 | self.register_buffer('sqrt_recip_alphas_cumprod', 123 | to_torch(np.sqrt(1. / alphas_cumprod))) 124 | self.register_buffer('sqrt_recipm1_alphas_cumprod', 125 | to_torch(np.sqrt(1. / alphas_cumprod - 1))) 126 | 127 | # calculations for posterior q(x_{t-1} | x_t, x_0) 128 | posterior_variance = betas * \ 129 | (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 130 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 131 | self.register_buffer('posterior_variance', 132 | to_torch(posterior_variance)) 133 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 134 | self.register_buffer('posterior_log_variance_clipped', to_torch( 135 | np.log(np.maximum(posterior_variance, 1e-20)))) 136 | self.register_buffer('posterior_mean_coef1', to_torch( 137 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 138 | self.register_buffer('posterior_mean_coef2', to_torch( 139 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) 140 | 141 | def predict_start_from_noise(self, x_t, t, noise): 142 | return self.sqrt_recip_alphas_cumprod[t] * x_t - \ 143 | self.sqrt_recipm1_alphas_cumprod[t] * noise 144 | 145 | def q_posterior(self, x_start, x_t, t): 146 | posterior_mean = self.posterior_mean_coef1[t] * \ 147 | x_start + self.posterior_mean_coef2[t] * x_t 148 | posterior_log_variance_clipped = self.posterior_log_variance_clipped[t] 149 | return posterior_mean, posterior_log_variance_clipped 150 | 151 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None): 152 | batch_size = x.shape[0] 153 | noise_level = torch.FloatTensor( 154 | [self.sqrt_alphas_cumprod_prev[t+1]]).repeat(batch_size, 1).to(x.device) 155 | if condition_x is not None: 156 | x_recon = self.predict_start_from_noise( 157 | x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level)) 158 | else: 159 | x_recon = self.predict_start_from_noise( 160 | x, t=t, noise=self.denoise_fn(x, noise_level)) 161 | 162 | if clip_denoised: 163 | x_recon.clamp_(-1., 1.) 
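# (note: unlike the DDPM module above, which conditions the denoiser on the integer
# timestep t, this SR3 variant feeds it the continuous noise level drawn from
# sqrt_alphas_cumprod_prev)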
164 | 165 | model_mean, posterior_log_variance = self.q_posterior( 166 | x_start=x_recon, x_t=x, t=t) 167 | return model_mean, posterior_log_variance 168 | 169 | @torch.no_grad() 170 | def p_sample(self, x, t, clip_denoised=True, condition_x=None): 171 | model_mean, model_log_variance = self.p_mean_variance( 172 | x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x) 173 | noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x) 174 | return model_mean + noise * (0.5 * model_log_variance).exp() 175 | 176 | @torch.no_grad() 177 | def p_sample_loop(self, x_in, continous=False): 178 | device = self.betas.device 179 | sample_inter = (1 | (self.num_timesteps//10)) 180 | if not self.conditional: 181 | shape = x_in 182 | img = torch.randn(shape, device=device) 183 | ret_img = img 184 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 185 | img = self.p_sample(img, i) 186 | if i % sample_inter == 0: 187 | ret_img = torch.cat([ret_img, img], dim=0) 188 | else: 189 | x = x_in 190 | shape = x.shape 191 | img = torch.randn(shape, device=device) 192 | ret_img = x 193 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 194 | img = self.p_sample(img, i, condition_x=x) 195 | if i % sample_inter == 0: 196 | ret_img = torch.cat([ret_img, img], dim=0) 197 | if continous: 198 | return ret_img 199 | else: 200 | return ret_img[-1] 201 | 202 | @torch.no_grad() 203 | def sample(self, batch_size=1, continous=False): 204 | image_size = self.image_size 205 | channels = self.channels 206 | return self.p_sample_loop((batch_size, channels, image_size, image_size), continous) 207 | 208 | @torch.no_grad() 209 | def super_resolution(self, x_in, continous=False): 210 | return self.p_sample_loop(x_in, continous) 211 | 212 | def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None): 213 | noise = default(noise, lambda: torch.randn_like(x_start)) 214 | 215 | # random gama 216 | return ( 217 | continuous_sqrt_alpha_cumprod * x_start + 218 | (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * noise 219 | ) 220 | 221 | def p_losses(self, x_in, noise=None): 222 | # x_start = x_in['IP'] 223 | x_start = x_in['RS'] 224 | [b, c, h, w] = x_start.shape 225 | t = np.random.randint(1, self.num_timesteps + 1) 226 | continuous_sqrt_alpha_cumprod = torch.FloatTensor( 227 | np.random.uniform( 228 | self.sqrt_alphas_cumprod_prev[t-1], 229 | self.sqrt_alphas_cumprod_prev[t], 230 | size=b 231 | ) 232 | ).to(x_start.device) 233 | continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view( 234 | b, -1) 235 | 236 | noise = default(noise, lambda: torch.randn_like(x_start)) 237 | x_noisy = self.q_sample( 238 | x_start=x_start, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise=noise) 239 | 240 | if not self.conditional: 241 | x_recon = self.denoise_fn(x_noisy, continuous_sqrt_alpha_cumprod) 242 | else: 243 | x_recon = self.denoise_fn(torch.cat([x_in['SR'],x_noisy], dim=1), continuous_sqrt_alpha_cumprod) 244 | loss = self.loss_func(noise, x_recon) 245 | # loss = self.loss_func(x_start,x_in['HR']) 246 | return loss 247 | # 248 | def forward(self, x, *args, **kwargs): 249 | return self.p_losses(x, *args, **kwargs) 250 | 251 | -------------------------------------------------------------------------------- /model/sr3_modules/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch 
import nn 4 | import torch.nn.functional as F 5 | from inspect import isfunction 6 | 7 | 8 | def exists(x): 9 | return x is not None 10 | 11 | 12 | def default(val, d): 13 | if exists(val): 14 | return val 15 | return d() if isfunction(d) else d 16 | 17 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py 18 | class PositionalEncoding(nn.Module): 19 | def __init__(self, dim): 20 | super().__init__() 21 | self.dim = dim 22 | 23 | def forward(self, noise_level): 24 | count = self.dim // 2 25 | step = torch.arange(count, dtype=noise_level.dtype, 26 | device=noise_level.device) / count 27 | encoding = noise_level.unsqueeze( 28 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) 29 | encoding = torch.cat( 30 | [torch.sin(encoding), torch.cos(encoding)], dim=-1) 31 | return encoding 32 | 33 | 34 | class FeatureWiseAffine(nn.Module): 35 | def __init__(self, in_channels, out_channels, use_affine_level=False): 36 | super(FeatureWiseAffine, self).__init__() 37 | self.use_affine_level = use_affine_level 38 | self.noise_func = nn.Sequential( 39 | nn.Linear(in_channels, out_channels*(1+self.use_affine_level)) 40 | ) 41 | 42 | def forward(self, x, noise_embed): 43 | batch = x.shape[0] 44 | if self.use_affine_level: 45 | gamma, beta = self.noise_func(noise_embed).view( 46 | batch, -1, 1, 1).chunk(2, dim=1) 47 | x = (1 + gamma) * x + beta 48 | else: 49 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1) 50 | return x 51 | 52 | 53 | class Swish(nn.Module): 54 | def forward(self, x): 55 | return x * torch.sigmoid(x) 56 | 57 | 58 | class Upsample(nn.Module): 59 | def __init__(self, dim): 60 | super().__init__() 61 | self.up = nn.Upsample(scale_factor=2, mode="nearest") 62 | self.conv = nn.Conv2d(dim, dim, 3, padding=1) 63 | 64 | def forward(self, x): 65 | return self.conv(self.up(x)) 66 | 67 | 68 | class Downsample(nn.Module): 69 | def __init__(self, dim): 70 | super().__init__() 71 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1) 72 | 73 | def forward(self, x): 74 | return self.conv(x) 75 | 76 | 77 | # building block modules 78 | 79 | 80 | class Block(nn.Module): 81 | def __init__(self, dim, dim_out, groups=32, dropout=0): 82 | super().__init__() 83 | self.block = nn.Sequential( 84 | nn.GroupNorm(groups, dim), 85 | Swish(), 86 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 87 | nn.Conv2d(dim, dim_out, 3, padding=1) 88 | ) 89 | 90 | def forward(self, x): 91 | return self.block(x) 92 | 93 | 94 | class ResnetBlock(nn.Module): 95 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32): 96 | super().__init__() 97 | if noise_level_emb_dim is not None: 98 | self.noise_func = FeatureWiseAffine( 99 | noise_level_emb_dim, dim_out, use_affine_level) 100 | 101 | self.block1 = Block(dim, dim_out, groups=norm_groups) 102 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 103 | self.res_conv = nn.Conv2d( 104 | dim, dim_out, 1) if dim != dim_out else nn.Identity() 105 | 106 | def forward(self, x, time_emb): 107 | b, c, h, w = x.shape 108 | h = self.block1(x) 109 | if time_emb is not None: 110 | h = self.noise_func(h, time_emb) 111 | h = self.block2(h) 112 | return h + self.res_conv(x) 113 | 114 | 115 | class SelfAttention(nn.Module): 116 | def __init__(self, in_channel, n_head=1, norm_groups=32): 117 | super().__init__() 118 | 119 | self.n_head = n_head 120 | 121 | self.norm = nn.GroupNorm(norm_groups, in_channel) 122 | self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, 
bias=False)
123 | self.out = nn.Conv2d(in_channel, in_channel, 1)
124 |
125 | def forward(self, input):
126 | batch, channel, height, width = input.shape
127 | n_head = self.n_head
128 | head_dim = channel // n_head
129 |
130 | norm = self.norm(input)
131 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
132 | query, key, value = qkv.chunk(3, dim=2) # bhdyx
133 |
134 | attn = torch.einsum(
135 | "bnchw, bncyx -> bnhwyx", query, key
136 | ).contiguous() / math.sqrt(channel)
137 | attn = attn.view(batch, n_head, height, width, -1)
138 | attn = torch.softmax(attn, -1)
139 | attn = attn.view(batch, n_head, height, width, height, width)
140 |
141 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
142 | out = self.out(out.view(batch, channel, height, width))
143 |
144 | return out + input
145 |
146 |
147 | class ResnetBlocWithAttn(nn.Module):
148 | def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
149 | super().__init__()
150 |
151 | self.with_attn = with_attn
152 | self.res_block = ResnetBlock(
153 | dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
154 | if with_attn:
155 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
156 |
157 | def forward(self, x, time_emb):
158 | x = self.res_block(x, time_emb)
159 | if self.with_attn:
160 | x = self.attn(x)
161 | return x
162 |
163 |
164 | class UNet(nn.Module):
165 | def __init__(
166 | self,
167 | in_channel=6,
168 | out_channel=3,
169 | inner_channel=32,
170 | norm_groups=32,
171 | channel_mults=(1, 2, 4, 8, 8),
172 | attn_res=(8,), # fixed: a bare (8) is just the int 8, and 'now_res in attn_res' would raise TypeError
173 | res_blocks=3,
174 | dropout=0,
175 | with_noise_level_emb=True,
176 | image_size=128
177 | ):
178 | super().__init__()
179 |
180 | if with_noise_level_emb:
181 | noise_level_channel = inner_channel
182 | self.noise_level_mlp = nn.Sequential(
183 | PositionalEncoding(inner_channel),
184 | nn.Linear(inner_channel, inner_channel * 4),
185 | Swish(),
186 | nn.Linear(inner_channel * 4, inner_channel)
187 | )
188 | else:
189 | noise_level_channel = None
190 | self.noise_level_mlp = None
191 |
192 | num_mults = len(channel_mults)
193 | pre_channel = inner_channel
194 | feat_channels = [pre_channel]
195 | now_res = image_size
196 | downs = [nn.Conv2d(in_channel, inner_channel,
197 | kernel_size=3, padding=1)]
198 | for ind in range(num_mults):
199 | is_last = (ind == num_mults - 1)
200 | use_attn = (now_res in attn_res)
201 | channel_mult = inner_channel * channel_mults[ind]
202 | for _ in range(0, res_blocks):
203 | downs.append(ResnetBlocWithAttn(
204 | pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn))
205 | feat_channels.append(channel_mult)
206 | pre_channel = channel_mult
207 | if not is_last:
208 | downs.append(Downsample(pre_channel))
209 | feat_channels.append(pre_channel)
210 | now_res = now_res//2
211 | self.downs = nn.ModuleList(downs)
212 |
213 | self.mid = nn.ModuleList([
214 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
215 | dropout=dropout, with_attn=True),
216 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
217 | dropout=dropout, with_attn=False)
218 | ])
219 |
220 | ups = []
221 | for ind in reversed(range(num_mults)):
222 | is_last = (ind < 1)
223 | use_attn = (now_res in attn_res)
224 | channel_mult = inner_channel * channel_mults[ind]
225 | for _ in
range(0, res_blocks+1): 226 | ups.append(ResnetBlocWithAttn( 227 | pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 228 | dropout=dropout, with_attn=use_attn)) 229 | pre_channel = channel_mult 230 | if not is_last: 231 | ups.append(Upsample(pre_channel)) 232 | now_res = now_res*2 233 | 234 | self.ups = nn.ModuleList(ups) 235 | 236 | self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups) 237 | 238 | def forward(self, x, time): 239 | t = self.noise_level_mlp(time) if exists( 240 | self.noise_level_mlp) else None 241 | 242 | feats = [] 243 | for layer in self.downs: 244 | if isinstance(layer, ResnetBlocWithAttn): 245 | x = layer(x, t) 246 | else: 247 | x = layer(x) 248 | feats.append(x) 249 | 250 | for layer in self.mid: 251 | if isinstance(layer, ResnetBlocWithAttn): 252 | x = layer(x, t) 253 | else: 254 | x = layer(x) 255 | 256 | for layer in self.ups: 257 | if isinstance(layer, ResnetBlocWithAttn): 258 | x = layer(torch.cat((x, feats.pop()), dim=1), t) 259 | else: 260 | x = layer(x) 261 | 262 | return self.final_conv(x) 263 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6 2 | torchvision 3 | numpy 4 | pandas 5 | tqdm 6 | lmdb 7 | opencv-python 8 | pillow 9 | tensorboardx 10 | wandb 11 | 12 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import data as Data 3 | import model as Model 4 | import argparse 5 | import logging 6 | import core.logger as Logger 7 | import core.metrics as Metrics 8 | from core.wandb_logger import WandbLogger 9 | from tensorboardX import SummaryWriter 10 | import os 11 | import numpy as np 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-c', '--config', type=str, default='config/sr_sr3_16_128.json', 16 | help='JSON file for configuration') 17 | parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'], 18 | help='Run either train(training) or val(generation)', default='train') 19 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) 20 | parser.add_argument('-debug', '-d', action='store_true') 21 | parser.add_argument('-enable_wandb', action='store_true') 22 | parser.add_argument('-log_wandb_ckpt', action='store_true') 23 | parser.add_argument('-log_eval', action='store_true') 24 | 25 | # parse configs 26 | args = parser.parse_args() 27 | opt = Logger.parse(args) 28 | # Convert to NoneDict, which return None for missing key. 
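# (illustrative sketch, assuming NoneDict in core/logger.py overrides
#  missing-key lookup; it behaves roughly like:
#      cfg = Logger.dict_to_nonedict({'lr': 1e-4})
#      cfg['lr']       -> 1e-4
#      cfg['missing']  -> None, rather than raising KeyError)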
29 |     opt = Logger.dict_to_nonedict(opt)
30 | 
31 |     # logging
32 |     torch.backends.cudnn.enabled = True
33 |     torch.backends.cudnn.benchmark = True
34 | 
35 |     Logger.setup_logger(None, opt['path']['log'],
36 |                         'train', level=logging.INFO, screen=True)
37 |     Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
38 |     logger = logging.getLogger('base')
39 |     logger.info(Logger.dict2str(opt))
40 |     tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])
41 | 
42 |     # Initialize WandbLogger
43 |     if opt['enable_wandb']:
44 |         import wandb
45 |         wandb_logger = WandbLogger(opt)
46 |         wandb.define_metric('validation/val_step')
47 |         wandb.define_metric('epoch')
48 |         wandb.define_metric("validation/*", step_metric="val_step")
49 |         val_step = 0
50 |     else:
51 |         wandb_logger = None
52 | 
53 |     # dataset
54 |     for phase, dataset_opt in opt['datasets'].items():
55 |         if phase == 'train' and args.phase != 'val':
56 |             train_set = Data.create_dataset(dataset_opt, phase)
57 |             train_loader = Data.create_dataloader(
58 |                 train_set, dataset_opt, phase)
59 |         # elif phase == 'val':
60 |         #     val_set = Data.create_dataset(dataset_opt, phase)
61 |         #     val_loader = Data.create_dataloader(
62 |         #         val_set, dataset_opt, phase)
63 |     logger.info('Dataset initialization finished.')
64 | 
65 |     # model
66 |     diffusion = Model.create_model(opt)
67 |     logger.info('Model initialization finished.')
68 | 
69 |     # Train
70 |     current_step = diffusion.begin_step
71 |     current_epoch = diffusion.begin_epoch
72 |     n_iter = opt['train']['n_iter']
73 | 
74 |     if opt['path']['resume_state']:
75 |         logger.info('Resuming training from epoch: {}, iter: {}.'.format(
76 |             current_epoch, current_step))
77 | 
78 |     diffusion.set_new_noise_schedule(
79 |         opt['model']['beta_schedule'][opt['phase']], schedule_phase=opt['phase'])
80 |     if opt['phase'] == 'train':
81 |         while current_step < n_iter:
82 |             current_epoch += 1
83 |             for _, train_data in enumerate(train_loader):
84 |                 current_step += 1
85 |                 if current_step > n_iter:
86 |                     break
87 | 
88 |                 diffusion.feed_data(train_data)
89 |                 diffusion.optimize_parameters()
90 |                 # log
91 |                 if current_step % opt['train']['print_freq'] == 0:
92 |                     logs = diffusion.get_current_log()
93 |                     message = '<epoch:{:3d}, iter:{:8,d}> '.format(
94 |                         current_epoch, current_step)
95 |                     for k, v in logs.items():
96 |                         message += '{:s}: {:.4e} '.format(k, v)
97 |                         tb_logger.add_scalar(k, v, current_step)
98 |                     logger.info(message)
99 |                     # diffusion.update_learning_rate()
100 | 
101 |                     if wandb_logger:
102 |                         wandb_logger.log_metrics(logs)
103 | 
104 |                 # # validation
105 |                 # if current_step % opt['train']['val_freq'] == 0:
106 |                 #     avg_psnr = 0.0
107 |                 #     idx = 0
108 |                 #     result_path = '{}/{}'.format(opt['path']
109 |                 #                                  ['results'], current_epoch)
110 |                 #     os.makedirs(result_path, exist_ok=True)
111 |                 #
112 |                 #     diffusion.set_new_noise_schedule(
113 |                 #         opt['model']['beta_schedule']['val'], schedule_phase='val')
114 |                 #     for _, val_data in enumerate(val_loader):
115 |                 #         idx += 1
116 |                 #         diffusion.feed_data(val_data)
117 |                 #         diffusion.test(continous=False)
118 |                 #         visuals = diffusion.get_current_visuals()
119 |                 #         sr_img = Metrics.tensor2img(visuals['SR'])  # uint8
120 |                 #         hr_img = Metrics.tensor2img(visuals['HR'])  # uint8
121 |                 #         lr_img = Metrics.tensor2img(visuals['LR'])  # uint8
122 |                 #         fake_img = Metrics.tensor2img(visuals['INF'])  # uint8
123 |                 #
124 |                 #         # generation
125 |                 #         Metrics.save_img(
126 |                 #             hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
127 |                 #         Metrics.save_img(
128 |                 #             sr_img, '{}/{}_{}_sr.png'.format(result_path, current_step, idx))
129 |                 #         Metrics.save_img(
130 |                 #             lr_img, '{}/{}_{}_lr.png'.format(result_path, current_step, idx))
131 |                 #         Metrics.save_img(
132 |                 #             fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))
133 |                 #         tb_logger.add_image(
134 |                 #             'Iter_{}'.format(current_step),
135 |                 #             np.transpose(np.concatenate(
136 |                 #                 (fake_img, sr_img, hr_img), axis=1), [2, 0, 1]),
137 |                 #             idx)
138 |                 #         avg_psnr += Metrics.calculate_psnr(
139 |                 #             sr_img, hr_img)
140 |                 #
141 |                 #         if wandb_logger:
142 |                 #             wandb_logger.log_image(
143 |                 #                 f'validation_{idx}',
144 |                 #                 np.concatenate((fake_img, sr_img, hr_img), axis=1)
145 |                 #             )
146 |                 #
147 |                 #     avg_psnr = avg_psnr / idx
148 |                 #     diffusion.set_new_noise_schedule(
149 |                 #         opt['model']['beta_schedule']['train'], schedule_phase='train')
150 |                 #     # log
151 |                 #     logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
152 |                 #     logger_val = logging.getLogger('val')  # validation logger
153 |                 #     logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
154 |                 #         current_epoch, current_step, avg_psnr))
155 |                 #     # tensorboard logger
156 |                 #     tb_logger.add_scalar('psnr', avg_psnr, current_step)
157 |                 #
158 |                 #     if wandb_logger:
159 |                 #         wandb_logger.log_metrics({
160 |                 #             'validation/val_psnr': avg_psnr,
161 |                 #             'validation/val_step': val_step
162 |                 #         })
163 |                 #         val_step += 1
164 | 
165 |                 if current_step % opt['train']['save_checkpoint_freq'] == 0:
166 |                     logger.info('Saving models and training states.')
167 |                     diffusion.save_network(current_epoch, current_step)
168 | 
169 |                     if wandb_logger and opt['log_wandb_ckpt']:
170 |                         wandb_logger.log_checkpoint(current_epoch, current_step)
171 | 
172 |             if wandb_logger:
173 |                 wandb_logger.log_metrics({'epoch': current_epoch-1})
174 | 
175 |         # save model
176 |         logger.info('End of training.')
--------------------------------------------------------------------------------
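A minimal smoke test of the sr3_modules UNet above, sketched under the assumption that the repository root is on PYTHONPATH; the hyperparameter values are chosen only for illustration (note attn_res written as a proper tuple):

import torch
from model.sr3_modules.unet import UNet

# unconditional generation uses 3 input channels; the conditional
# super-resolution setup concatenates the upsampled LR image, giving in_channel=6
net = UNet(in_channel=3, out_channel=3, inner_channel=64,
           channel_mults=(1, 2, 4, 8, 8), attn_res=(16,),
           res_blocks=2, dropout=0.2, image_size=128)
net.eval()

x = torch.randn(1, 3, 128, 128)   # a batch of noisy 128x128 RGB images
noise_level = torch.rand(1)       # one continuous noise level per sample
with torch.no_grad():
    out = net(x, noise_level)     # network output, shape (1, 3, 128, 128)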