├── .idea
│   ├── .gitignore
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── pet.iml
│   └── vcs.xml
├── config
│   ├── sample_ddpm_128.json
│   ├── sample_sr3_128.json
│   ├── sr_ddpm_16_128.json
│   ├── sr_sr3_16_128.json
│   └── sr_sr3_64_512.json
├── core
│   ├── logger.py
│   ├── metrics.py
│   └── wandb_logger.py
├── data
│   ├── LRHR_dataset.py
│   ├── __init__.py
│   ├── dataloader.py
│   ├── prepare_data.py
│   └── util.py
├── easy_train.py
├── inference.py
├── model
│   ├── __init__.py
│   ├── base_model.py
│   ├── ddpm_modules
│   │   ├── diffusion.py
│   │   └── unet.py
│   ├── model.py
│   ├── networks.py
│   └── sr3_modules
│       ├── diffusion.py
│       └── unet.py
├── requirement.txt
└── train.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/config/sample_ddpm_128.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "generation_ffhq",
3 | "phase": "train",
4 | "gpu_ids": [
5 | 0
6 | ],
7 | "path": {
8 | "log": "logs",
9 | "tb_logger": "tb_logger",
10 | "results": "results",
11 | "checkpoint": "checkpoint",
12 | "resume_state": null
13 | // "resume_state": "experiments/generation_ffhq_210811_140902/checkpoint/I30_E1"
14 | },
15 | "datasets": {
16 | "train": {
17 | "name": "FFHQ",
18 | "mode": "HR",
19 | "dataroot": "dataset/ffhq_16_128",
20 | "datatype": "lmdb", //lmdb or img, path of img files
21 | "l_resolution": 16,
22 | "r_resolution": 128,
23 | "batch_size": 12,
24 | "num_workers": 8,
25 | "use_shuffle": true,
26 | "data_len": -1
27 | },
28 | "val": {
29 | "name": "CelebaHQ",
30 | "mode": "HR",
31 | "dataroot": "dataset/celebahq_16_128",
32 | "datatype": "lmdb", //lmdb or img, path of img files
33 | "l_resolution": 16,
34 | "r_resolution": 128,
35 | "data_len": 10
36 | }
37 | },
38 | "model": {
39 | "which_model_G": "ddpm", //ddpm, sr3
40 | "finetune_norm": false,
41 | "unet": {
42 | "in_channel": 3,
43 | "out_channel": 3,
44 | "inner_channel": 64,
45 | "channel_multiplier": [
46 | 1,
47 | 1,
48 | 2,
49 | 2,
50 | 4,
51 | 4
52 | ],
53 | "attn_res": [
54 | 16
55 | ],
56 | "res_blocks": 2,
57 | "dropout": 0.2
58 | },
59 | "beta_schedule": {
60 | "train": {
61 | "schedule": "linear",
62 | "n_timestep": 2000,
63 | "linear_start": 1e-4,
64 | "linear_end": 2e-2
65 | },
66 | "val": {
67 | "schedule": "linear",
68 | "n_timestep": 2000,
69 | "linear_start": 1e-4,
70 | "linear_end": 2e-2
71 | }
72 | },
73 | "diffusion": {
74 | "image_size": 128,
75 | "channels": 3, //sample channel
76 | "conditional": false
77 | }
78 | },
79 | "train": {
80 | "n_iter": 1000000,
81 | "val_freq": 1e4,
82 | "save_checkpoint_freq": 1e4,
83 | "print_freq": 200,
84 | "optimizer": {
85 | "type": "adam",
86 | "lr": 1e-4
87 | },
88 | "ema_scheduler": {
89 | "step_start_ema": 5000,
90 | "update_ema_every": 1,
91 | "ema_decay": 0.9999
92 | }
93 | },
94 | "wandb": {
95 | "project": "generation_ffhq_ddpm"
96 | }
97 | }
--------------------------------------------------------------------------------
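
These config files are JSON with non-standard // comments, so a plain json.loads would reject them; core/logger.py (shown later in this pack) strips the comments line by line before parsing. A minimal sketch of that loading step, mirroring parse() in core/logger.py (the config path is just an example):

import json
from collections import OrderedDict

def load_commented_json(path):
    # drop everything after '//' on each line, exactly as core/logger.py does
    # (note: this naive split would also truncate any '//' inside string values)
    json_str = ''
    with open(path, 'r') as f:
        for line in f:
            json_str += line.split('//')[0] + '\n'
    return json.loads(json_str, object_pairs_hook=OrderedDict)

opt = load_commented_json('config/sample_ddpm_128.json')
print(opt['model']['which_model_G'])  # 'ddpm'
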
/config/sample_sr3_128.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "generation_ffhq",
3 | "phase": "train", // train or val
4 | "gpu_ids": [
5 | 0
6 | ],
7 | "path": { //set the path
8 | "log": "logs",
9 | "tb_logger": "tb_logger",
10 | "results": "results",
11 | "checkpoint": "checkpoint",
12 | "resume_state": null
13 | // "resume_state": "experiments/generation_ffhq_210811_140902/checkpoint/I1560000_E91" //pretrain model or training state
14 | },
15 | "datasets": {
16 | "train": {
17 | "name": "FFHQ",
18 | "mode": "HR", // whether need LR img
19 | "dataroot": "dataset/ffhq_16_128",
20 | "datatype": "lmdb", //lmdb or img, path of img files
21 | "l_resolution": 16, // low resolution need to super_resolution
22 | "r_resolution": 128, // high resolution
23 | "batch_size": 4,
24 | "num_workers": 8,
25 | "use_shuffle": true,
26 | "data_len": -1 // -1 represents all data used in train
27 | },
28 | "val": {
29 | "name": "CelebaHQ",
30 | "mode": "HR",
31 | "dataroot": "dataset/celebahq_16_128",
32 | "datatype": "lmdb", //lmdb or img, path of img files
33 | "l_resolution": 16,
34 | "r_resolution": 128,
35 | "data_len": 50
36 | }
37 | },
38 | "model": {
39 | "which_model_G": "sr3", // use the ddpm or sr3 network structure
40 | "finetune_norm": false,
41 | "unet": {
42 | "in_channel": 3,
43 | "out_channel": 3,
44 | "inner_channel": 64,
45 | "channel_multiplier": [
46 | 1,
47 | 2,
48 | 4,
49 | 8,
50 | 8
51 | ],
52 | "attn_res": [
53 | 16
54 | ],
55 | "res_blocks": 2,
56 | "dropout": 0.2
57 | },
58 |         "beta_schedule": { // use manual beta_schedule for acceleration
59 | "train": {
60 | "schedule": "linear",
61 | "n_timestep": 2000,
62 | "linear_start": 1e-6,
63 | "linear_end": 1e-2
64 | },
65 | "val": {
66 | "schedule": "linear",
67 | "n_timestep": 2000,
68 | "linear_start": 1e-6,
69 | "linear_end": 1e-2
70 | }
71 | },
72 | "diffusion": {
73 | "image_size": 128,
74 | "channels": 3, //sample channel
75 |             "conditional": false // unconditional generation or conditional generation (super_resolution)
76 | }
77 | },
78 | "train": {
79 | "n_iter": 10000000,
80 | "val_freq": 1e4,
81 | "save_checkpoint_freq": 1e4,
82 | "print_freq": 200,
83 | "optimizer": {
84 | "type": "adam",
85 | "lr": 1e-4
86 | },
87 | "ema_scheduler": { // not used now
88 | "step_start_ema": 5000,
89 | "update_ema_every": 1,
90 | "ema_decay": 0.9999
91 | }
92 | },
93 | "wandb": {
94 | "project": "generation_ffhq_sr3"
95 | }
96 | }
--------------------------------------------------------------------------------
/config/sr_ddpm_16_128.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "sr_ffhq",
3 | "phase": "train",
4 | "gpu_ids": [
5 | 0
6 | ],
7 | "path": {
8 | "log": "logs",
9 | "tb_logger": "tb_logger",
10 | "results": "results",
11 | "checkpoint": "checkpoint",
12 | "resume_state": null
13 | // "resume_state": "experiments/sr_ffhq_210806_204158/checkpoint/I640000_E37" //pretrain model or training state
14 | },
15 | "datasets": {
16 | "train": {
17 | "name": "FFHQ",
18 | "mode": "HR",
19 | "dataroot": "dataset/processed",
20 | "datatype": "lmdb", //lmdb or img, path of img files
21 | "l_resolution": 64,
22 | "r_resolution": 64,
23 | "batch_size": 2,
24 | "num_workers": 0,
25 | "use_shuffle": true,
26 | "data_len": -1
27 | },
28 | "val": {
29 | "name": "CelebaHQ",
30 | "mode": "LRHR",
31 | "dataroot": "dataset/processed",
32 | "datatype": "lmdb", //lmdb or img, path of img files
33 | "l_resolution": 64,
34 | "r_resolution": 64,
35 | "data_len": 3
36 | }
37 | },
38 | "model": {
39 | "which_model_G": "ddpm", //ddpm, sr3
40 | "finetune_norm": false,
41 | "unet": {
42 | "in_channel": 2,
43 | "out_channel": 1,
44 | "inner_channel": 32,
45 | "channel_multiplier": [
46 | 1,
47 | 1,
48 | 2,
49 | 2,
50 | 4,
51 | 4
52 | ],
53 | "attn_res": [
54 | 16
55 | ],
56 | "res_blocks": 2,
57 | "dropout": 0.2
58 | },
59 | "beta_schedule": {
60 | "train": {
61 | "schedule": "linear",
62 | "n_timestep": 2000,
63 | "linear_start": 1e-4,
64 | "linear_end": 2e-2
65 | },
66 | "val": {
67 | "schedule": "linear",
68 | "n_timestep": 2000,
69 | "linear_start": 1e-4,
70 | "linear_end": 2e-2
71 | }
72 | },
73 | "diffusion": {
74 | "image_size": 64,
75 | "channels": 1, //sample channel
76 | "conditional": true
77 | }
78 | },
79 | "train": {
80 | "n_iter": 1000000,
81 | "val_freq": 1e4,
82 | "save_checkpoint_freq": 1e4,
83 | "print_freq": 200,
84 | "optimizer": {
85 | "type": "adam",
86 | "lr": 1e-4
87 | },
88 | "ema_scheduler": {
89 | "step_start_ema": 5000,
90 | "update_ema_every": 1,
91 | "ema_decay": 0.9999
92 | }
93 | },
94 | "wandb": {
95 | "project": "sr_ffhq"
96 | }
97 | }
--------------------------------------------------------------------------------
/config/sr_sr3_16_128.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "sr_ffhq",
3 | "phase": "train", // train or val
4 | "gpu_ids": [
5 | 3
6 | ],
7 | "path": { //set the path
8 | "log": "logs",
9 | "tb_logger": "tb_logger",
10 | "results": "results",
11 | "checkpoint": "checkpoint",
12 |         "resume_state": null // "resume_state": "experiments/sr_ffhq_210806_204158/checkpoint/I640000_E37" //pretrain model or training state
13 | },
14 | "datasets": {
15 | "train": {
16 | "name": "FFHQ",
17 | "mode": "HR", // whether need LR img
18 | "dataroot": "train_mat",
19 | "datatype": "lmdb", //lmdb or img, path of img files
20 | "l_resolution": 64, // low resolution need to super_resolution
21 | "r_resolution": 64, // high resolution
22 | "batch_size": 4,
23 | "num_workers": 0,
24 | "use_shuffle": true,
25 | "data_len": -1 // -1 represents all data used in train
26 | },
27 | "val": {
28 | "name": "CelebaHQ",
29 | "mode": "LRHR",
30 | "dataroot": "dataset",
31 | "datatype": "lmdb", //lmdb or img, path of img files
32 | "l_resolution": 64,
33 | "r_resolution": 64,
34 | "data_len": -1 // data length in validation
35 | }
36 | },
37 | "model": {
38 | "which_model_G": "sr3", // use the ddpm or sr3 network structure
39 | "finetune_norm": false,
40 | "unet": {
41 | "PreNet": {
42 | "in_channel": 1,
43 | "out_channel": 1,
44 | "inner_channel": 64,
45 | "channel_multiplier": [
46 | 1,
47 | 2,
48 | 3,
49 | 4
50 | ],
51 | "attn_res": [
52 | 32
53 | ],
54 | "res_blocks": 3,
55 | "dropout": 0.1
56 | },
57 | "DenoiseNet": {
58 | "in_channel": 2,
59 | "out_channel": 1,
60 | "inner_channel": 32,
61 | "channel_multiplier": [
62 | 1,
63 | 2,
64 | 3,
65 | 4
66 | ],
67 | "attn_res": [
68 | 32
69 | ],
70 | "res_blocks": 3,
71 | "dropout": 0.1
72 | }
73 |
74 | },
75 |         "beta_schedule": { // use manual beta_schedule for acceleration
76 | "train": {
77 | "schedule": "linear",
78 | "n_timestep": 2000,
79 | "linear_start": 1e-6,
80 | "linear_end": 1e-2
81 | },
82 | "val": {
83 | "schedule": "linear",
84 | "n_timestep": 2000,
85 | "linear_start": 1e-6,
86 | "linear_end": 1e-2
87 | }
88 | },
89 | "diffusion": {
90 | "image_size": 128,
91 | "channels": 1, //sample channel
92 |             "conditional": true // unconditional generation or conditional generation (super_resolution)
93 | }
94 | },
95 | "train": {
96 | "n_iter": 1000000,
97 | "val_freq": 1e4,
98 | "save_checkpoint_freq": 2e4,
99 | "print_freq": 200,
100 | "optimizer": {
101 | "type": "adam",
102 | "lr": 1e-4
103 | },
104 | "ema_scheduler": { // not used now
105 | "step_start_ema": 5000,
106 | "update_ema_every": 1,
107 | "ema_decay": 0.9999
108 | }
109 | },
110 | "wandb": {
111 | "project": "sr_ffhq"
112 | }
113 | }
--------------------------------------------------------------------------------
/config/sr_sr3_64_512.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "distributed_high_sr_ffhq",
3 | "phase": "train", // train or val
4 | "gpu_ids": [
5 | 0,1
6 | ],
7 | "path": { //set the path
8 | "log": "logs",
9 | "tb_logger": "tb_logger",
10 | "results": "results",
11 | "checkpoint": "checkpoint",
12 | "resume_state": null
13 | // "resume_state": "experiments/distributed_high_sr_ffhq_210901_121212/checkpoint/I830000_E32" //pretrain model or training state
14 | },
15 | "datasets": {
16 | "train": {
17 | "name": "FFHQ",
18 | "mode": "HR", // whether need LR img
19 | "dataroot": "dataset/ffhq_64_512",
20 | "datatype": "img", //lmdb or img, path of img files
21 | "l_resolution": 64, // low resolution need to super_resolution
22 | "r_resolution": 512, // high resolution
23 | "batch_size": 2,
24 | "num_workers": 8,
25 | "use_shuffle": true,
26 | "data_len": -1 // -1 represents all data used in train
27 | },
28 | "val": {
29 | "name": "CelebaHQ",
30 | "mode": "LRHR",
31 | "dataroot": "dataset/celebahq_64_512",
32 | "datatype": "img", //lmdb or img, path of img files
33 | "l_resolution": 64,
34 | "r_resolution": 512,
35 | "data_len": 50
36 | }
37 | },
38 | "model": {
39 | "which_model_G": "sr3", // use the ddpm or sr3 network structure
40 | "finetune_norm": false,
41 | "unet": {
42 | "in_channel": 6,
43 | "out_channel": 3,
44 | "inner_channel": 64,
45 | "norm_groups": 16,
46 | "channel_multiplier": [
47 | 1,
48 | 2,
49 | 4,
50 | 8,
51 | // 8,
52 | // 16,
53 | 16
54 | ],
55 | "attn_res": [
56 | // 16
57 | ],
58 | "res_blocks": 1,
59 | "dropout": 0
60 | },
61 |         "beta_schedule": { // use manual beta_schedule for acceleration
62 | "train": {
63 | "schedule": "linear",
64 | "n_timestep": 2000,
65 | "linear_start": 1e-6,
66 | "linear_end": 1e-2
67 | },
68 | "val": {
69 | "schedule": "linear",
70 | "n_timestep": 2000,
71 | "linear_start": 1e-6,
72 | "linear_end": 1e-2
73 | }
74 | },
75 | "diffusion": {
76 | "image_size": 512,
77 | "channels": 3, //sample channel
78 |             "conditional": true // unconditional generation or conditional generation (super_resolution)
79 | }
80 | },
81 | "train": {
82 | "n_iter": 1000000,
83 | "val_freq": 1e4,
84 | "save_checkpoint_freq": 1e4,
85 | "print_freq": 50,
86 | "optimizer": {
87 | "type": "adam",
88 | "lr": 3e-6
89 | },
90 | "ema_scheduler": { // not used now
91 | "step_start_ema": 5000,
92 | "update_ema_every": 1,
93 | "ema_decay": 0.9999
94 | }
95 | },
96 | "wandb": {
97 | "project": "distributed_high_sr_ffhq"
98 | }
99 | }
--------------------------------------------------------------------------------
/core/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | from collections import OrderedDict
5 | import json
6 | from datetime import datetime
7 |
8 |
9 | def mkdirs(paths):
10 | if isinstance(paths, str):
11 | os.makedirs(paths, exist_ok=True)
12 | else:
13 | for path in paths:
14 | os.makedirs(path, exist_ok=True)
15 |
16 |
17 | def get_timestamp():
18 | return datetime.now().strftime('%y%m%d_%H%M%S')
19 |
20 |
21 | def parse(args):
22 | phase = args.phase
23 | opt_path = args.config
24 | gpu_ids = args.gpu_ids
25 | enable_wandb = args.enable_wandb
26 | # remove comments starting with '//'
27 | json_str = ''
28 | with open(opt_path, 'r') as f:
29 | for line in f:
30 | line = line.split('//')[0] + '\n'
31 | json_str += line
32 | opt = json.loads(json_str, object_pairs_hook=OrderedDict)
33 |
34 | # set log directory
35 | if args.debug:
36 | opt['name'] = 'debug_{}'.format(opt['name'])
37 | experiments_root = os.path.join(
38 | 'experiments', '{}_{}'.format(opt['name'], get_timestamp()))
39 | opt['path']['experiments_root'] = experiments_root
40 | for key, path in opt['path'].items():
41 | if 'resume' not in key and 'experiments' not in key:
42 | opt['path'][key] = os.path.join(experiments_root, path)
43 | mkdirs(opt['path'][key])
44 |
45 | # change dataset length limit
46 | opt['phase'] = phase
47 |
48 | # export CUDA_VISIBLE_DEVICES
49 | if gpu_ids is not None:
50 | opt['gpu_ids'] = [int(id) for id in gpu_ids.split(',')]
51 | gpu_list = gpu_ids
52 | else:
53 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
54 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
55 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
56 |     if len(opt['gpu_ids']) > 1:  # compare GPU count, not the length of the id string
57 | opt['distributed'] = True
58 | else:
59 | opt['distributed'] = False
60 |
61 | # debug
62 | if 'debug' in opt['name']:
63 | opt['train']['val_freq'] = 2
64 | opt['train']['print_freq'] = 2
65 | opt['train']['save_checkpoint_freq'] = 3
66 | opt['datasets']['train']['batch_size'] = 2
67 | opt['model']['beta_schedule']['train']['n_timestep'] = 10
68 | opt['model']['beta_schedule']['val']['n_timestep'] = 10
69 | opt['datasets']['train']['data_len'] = 6
70 | opt['datasets']['val']['data_len'] = 3
71 |
72 | # validation in train phase
73 | if phase == 'train':
74 | opt['datasets']['val']['data_len'] = 3
75 |
76 | # W&B Logging
77 |     try:
78 |         log_wandb_ckpt = args.log_wandb_ckpt
79 |         opt['log_wandb_ckpt'] = log_wandb_ckpt
80 |     except AttributeError:  # flag absent from some entry points' argparsers
81 |         pass
82 |     try:
83 |         log_eval = args.log_eval
84 |         opt['log_eval'] = log_eval
85 |     except AttributeError:
86 |         pass
87 |     try:
88 |         log_infer = args.log_infer
89 |         opt['log_infer'] = log_infer
90 |     except AttributeError:
91 |         pass
92 | opt['enable_wandb'] = enable_wandb
93 |
94 | return opt
95 |
96 |
97 | class NoneDict(dict):
98 | def __missing__(self, key):
99 | return None
100 |
101 |
102 | # convert to NoneDict, which return None for missing key.
103 | def dict_to_nonedict(opt):
104 | if isinstance(opt, dict):
105 | new_opt = dict()
106 | for key, sub_opt in opt.items():
107 | new_opt[key] = dict_to_nonedict(sub_opt)
108 | return NoneDict(**new_opt)
109 | elif isinstance(opt, list):
110 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
111 | else:
112 | return opt
113 |
114 |
115 | def dict2str(opt, indent_l=1):
116 | '''dict to string for logger'''
117 | msg = ''
118 | for k, v in opt.items():
119 | if isinstance(v, dict):
120 | msg += ' ' * (indent_l * 2) + k + ':[\n'
121 | msg += dict2str(v, indent_l + 1)
122 | msg += ' ' * (indent_l * 2) + ']\n'
123 | else:
124 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
125 | return msg
126 |
127 |
128 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False):
129 | '''set up logger'''
130 | l = logging.getLogger(logger_name)
131 | formatter = logging.Formatter(
132 | '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
133 | log_file = os.path.join(root, '{}.log'.format(phase))
134 | fh = logging.FileHandler(log_file, mode='w')
135 | fh.setFormatter(formatter)
136 | l.setLevel(level)
137 | l.addHandler(fh)
138 | if screen:
139 | sh = logging.StreamHandler()
140 | sh.setFormatter(formatter)
141 | l.addHandler(sh)
142 |
--------------------------------------------------------------------------------
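
A short usage sketch for this module; the argparse namespace below is hypothetical (train.py builds the real one), and only the attributes read by parse() are set:

import argparse
import logging

args = argparse.Namespace(phase='train', config='config/sr_sr3_16_128.json',
                          gpu_ids=None, enable_wandb=False, debug=False)
opt = parse(args)                # strips // comments, creates experiment dirs
opt = dict_to_nonedict(opt)      # missing keys now read as None instead of raising KeyError
setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
logging.getLogger('base').info(dict2str(opt))
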
/core/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | import cv2
5 | import SimpleITK as sitk  # needed by save_img below; the unused make_grid import was dropped
6 |
7 |
8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
9 |     '''
10 |     Converts a torch Tensor into a Numpy array.
11 |     Note: out_type and min_max are kept for API compatibility but are unused;
12 |     the squeezed float array is returned without rescaling to [0, 255].
13 |     '''
14 |     tensor = tensor.squeeze().float().cpu()
15 |     img_np = tensor.numpy()
16 |     return img_np
17 |
18 |
19 | def save_img(img, img_path, mode='RGB'):  # 'mode' is unused; saving goes through SimpleITK
20 | savImg = sitk.GetImageFromArray(img[:, :, :])
21 | sitk.WriteImage(savImg, img_path)
22 | # cv2.imwrite(img_path, img)
23 |
24 |
25 |
26 | def calculate_psnr(img1, img2):
27 | # img1 and img2 have range [0, 255]
28 | img1 = img1.astype(np.float64)
29 | img2 = img2.astype(np.float64)
30 | mse = np.mean((img1 - img2)**2)
31 | if mse == 0:
32 | return float('inf')
33 | return 20 * math.log10(255.0 / math.sqrt(mse))
34 |
35 |
36 | def ssim(img1, img2):
37 | C1 = (0.01 * 255)**2
38 | C2 = (0.03 * 255)**2
39 |
40 | img1 = img1.astype(np.float64)
41 | img2 = img2.astype(np.float64)
42 | kernel = cv2.getGaussianKernel(11, 1.5)
43 | window = np.outer(kernel, kernel.transpose())
44 |
45 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
46 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
47 | mu1_sq = mu1**2
48 | mu2_sq = mu2**2
49 | mu1_mu2 = mu1 * mu2
50 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
51 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
52 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
53 |
54 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
55 | (sigma1_sq + sigma2_sq + C2))
56 | return ssim_map.mean()
57 |
58 |
59 | def calculate_ssim(img1, img2):
60 | '''calculate SSIM
61 | the same outputs as MATLAB's
62 | img1, img2: [0, 255]
63 | '''
64 | if not img1.shape == img2.shape:
65 | raise ValueError('Input images must have the same dimensions.')
66 | if img1.ndim == 2:
67 | return ssim(img1, img2)
68 | elif img1.ndim == 3:
69 | if img1.shape[2] == 3:
70 | ssims = []
71 |             for i in range(3):
72 |                 ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
73 | return np.array(ssims).mean()
74 | elif img1.shape[2] == 1:
75 | return ssim(np.squeeze(img1), np.squeeze(img2))
76 | else:
77 | raise ValueError('Wrong input image dimensions.')
78 |
--------------------------------------------------------------------------------
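
A quick sanity check of the two metrics on synthetic uint8 images (the arrays here are made up for illustration):

import numpy as np
from core.metrics import calculate_psnr, calculate_ssim

ref = np.random.randint(0, 256, (64, 64), dtype=np.uint8)
noisy = np.clip(ref.astype(np.int16) + np.random.randint(-5, 6, ref.shape), 0, 255).astype(np.uint8)
print(calculate_psnr(ref, noisy))  # high PSNR for a small perturbation
print(calculate_ssim(ref, noisy))  # close to 1.0
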
/core/wandb_logger.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | class WandbLogger:
4 | """
5 | Log using `Weights and Biases`.
6 | """
7 | def __init__(self, opt):
8 | try:
9 | import wandb
10 | except ImportError:
11 | raise ImportError(
12 | "To use the Weights and Biases Logger please install wandb."
13 | "Run `pip install wandb` to install it."
14 | )
15 |
16 | self._wandb = wandb
17 |
18 | # Initialize a W&B run
19 | if self._wandb.run is None:
20 | self._wandb.init(
21 | project=opt['wandb']['project'],
22 | config=opt,
23 | dir='./experiments'
24 | )
25 |
26 | self.config = self._wandb.config
27 |
28 | if self.config.get('log_eval', None):
29 |             self.eval_table = self._wandb.Table(columns=[
30 |                 'fake_image',  # 5 columns, matching the 5 values passed in log_eval_data
31 |                 'sr_image',
32 |                 'hr_image',
33 |                 'psnr', 'ssim'])
34 | else:
35 | self.eval_table = None
36 |
37 | if self.config.get('log_infer', None):
38 |             self.infer_table = self._wandb.Table(columns=[
39 |                 'fake_image', 'sr_image',
40 |                 'hr_image'])
41 | else:
42 | self.infer_table = None
43 |
44 | def log_metrics(self, metrics, commit=True):
45 | """
46 | Log train/validation metrics onto W&B.
47 |
48 | metrics: dictionary of metrics to be logged
49 | """
50 | self._wandb.log(metrics, commit=commit)
51 |
52 | def log_image(self, key_name, image_array):
53 | """
54 | Log image array onto W&B.
55 |
56 | key_name: name of the key
57 | image_array: numpy array of image.
58 | """
59 | self._wandb.log({key_name: self._wandb.Image(image_array)})
60 |
61 | def log_images(self, key_name, list_images):
62 | """
63 | Log list of image array onto W&B
64 |
65 | key_name: name of the key
66 | list_images: list of numpy image arrays
67 | """
68 | self._wandb.log({key_name: [self._wandb.Image(img) for img in list_images]})
69 |
70 | def log_checkpoint(self, current_epoch, current_step):
71 | """
72 | Log the model checkpoint as W&B artifacts
73 |
74 | current_epoch: the current epoch
75 | current_step: the current batch step
76 | """
77 | model_artifact = self._wandb.Artifact(
78 | self._wandb.run.id + "_model", type="model"
79 | )
80 |
81 | gen_path = os.path.join(
82 | self.config.path['checkpoint'], 'I{}_E{}_gen.pth'.format(current_step, current_epoch))
83 | opt_path = os.path.join(
84 | self.config.path['checkpoint'], 'I{}_E{}_opt.pth'.format(current_step, current_epoch))
85 |
86 | model_artifact.add_file(gen_path)
87 | model_artifact.add_file(opt_path)
88 | self._wandb.log_artifact(model_artifact, aliases=["latest"])
89 |
90 | def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None):
91 | """
92 | Add data row-wise to the initialized table.
93 | """
94 | if psnr is not None and ssim is not None:
95 | self.eval_table.add_data(
96 | self._wandb.Image(fake_img),
97 | self._wandb.Image(sr_img),
98 | self._wandb.Image(hr_img),
99 | psnr,
100 | ssim
101 | )
102 | else:
103 | self.infer_table.add_data(
104 | self._wandb.Image(fake_img),
105 | self._wandb.Image(sr_img),
106 | self._wandb.Image(hr_img)
107 | )
108 |
109 | def log_eval_table(self, commit=False):
110 | """
111 | Log the table
112 | """
113 | if self.eval_table:
114 | self._wandb.log({'eval_data': self.eval_table}, commit=commit)
115 | elif self.infer_table:
116 | self._wandb.log({'infer_data': self.infer_table}, commit=commit)
117 |
--------------------------------------------------------------------------------
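
A sketch of how this logger is typically driven; opt is assumed to be a parsed config dict like the ones above (with log_eval set), and the image arguments are numpy arrays produced during validation:

from core.wandb_logger import WandbLogger

wb = WandbLogger(opt)
wb.log_metrics({'validation/val_psnr': 28.5, 'validation/val_ssim': 0.91})
wb.log_eval_data(fake_img, sr_img, hr_img, psnr=28.5, ssim=0.91)
wb.log_eval_table(commit=True)
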
/data/LRHR_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 | import scipy.io as sio
5 | import data.util as Util
6 |
7 |
8 | class LRHRDataset(Dataset):
9 |     def __init__(self, dataroot, datatype, l_resolution=64, r_resolution=64, split='train', data_len=-1, need_LR=False):
10 |         self.datatype = datatype
11 |         self.data_len = data_len
12 |         self.need_LR = need_LR
13 |         self.split = split
14 |         self.path = Util.get_paths_from_images('{}'.format(dataroot))
15 |         self.dataset_len = len(self.path)
16 |         if self.data_len <= 0:
17 |             self.data_len = self.dataset_len
18 |         else:
19 |             self.data_len = min(self.data_len, self.dataset_len)
20 |
21 |     def __len__(self):
22 |         return self.data_len
23 |
24 |     def __getitem__(self, index):
25 |         image_path = os.path.join(self.path[index])
26 |         # each .mat file stores an 'img' array whose middle axis is split:
27 |         # columns 0:128 hold the low-dose (conditioning) image, 128:256 the high-dose target
28 |         image = sio.loadmat(image_path)['img']
29 |         img_hpet = torch.Tensor(image[:, 128:256, :])
30 |         img_spet = torch.Tensor(image[:, 0:128, :])
31 |         if self.need_LR:
32 |             img_lpet = torch.Tensor(image[:, 0:128, :])
33 |             return {'LR': img_lpet, 'HR': img_hpet, 'SR': img_spet, 'Index': index}
34 |         return {'HR': img_hpet, 'SR': img_spet, 'Index': index}
--------------------------------------------------------------------------------
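
Usage sketch, assuming a directory of .mat files shaped as described in the comment above (train_mat matches the dataroot used in config/sr_sr3_16_128.json):

from data.LRHR_dataset import LRHRDataset

ds = LRHRDataset(dataroot='train_mat', datatype='lmdb',
                 l_resolution=64, r_resolution=64, split='train', need_LR=False)
sample = ds[0]
print(sample['HR'].shape, sample['SR'].shape)  # two [H, 128, W] slabs of the same volume
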
/data/__init__.py:
--------------------------------------------------------------------------------
1 | '''create dataset and dataloader'''
2 | import logging
3 | import torch.utils.data
4 | import os
5 |
6 |
7 | def create_dataloader(dataset, dataset_opt, phase):
8 | '''create dataloader '''
9 | if phase == 'train':
10 | return torch.utils.data.DataLoader(
11 | dataset,
12 | batch_size=dataset_opt['batch_size'],
13 | shuffle=dataset_opt['use_shuffle'],
14 | num_workers=dataset_opt['num_workers'],
15 | pin_memory=True)
16 | elif phase == 'val':
17 | return torch.utils.data.DataLoader(
18 | dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
19 | else:
20 | raise NotImplementedError(
21 | 'Dataloader [{:s}] is not found.'.format(phase))
22 |
23 |
24 | def create_dataset(dataset_opt, phase):
25 | '''create dataset'''
26 | mode = dataset_opt['mode']
27 | from data.LRHR_dataset import LRHRDataset as D
28 | dataset = D(dataroot=dataset_opt['dataroot'],
29 | datatype=dataset_opt['datatype'],
30 | l_resolution=dataset_opt['l_resolution'],
31 | r_resolution=dataset_opt['r_resolution'],
32 | split=phase,
33 | data_len=dataset_opt['data_len'],
34 | need_LR=(mode == 'LRHR')
35 | )
36 | logger = logging.getLogger('base')
37 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
38 | dataset_opt['name']))
39 | return dataset
40 | if __name__ == "__main__":
41 | from data.LRHR_dataset import LRHRDataset as D
42 | dataset = D(
43 |         dataroot=r'C:\Users\Administrator\Desktop\PET-Reconstruction-with-Diffusion\dataset\processed',
44 | datatype='jpg',
45 | l_resolution=64,
46 | r_resolution=64,
47 | split='train',
48 | data_len=-1,
49 | need_LR=False
50 | )
51 | train_set = dataset
52 | train_loader=torch.utils.data.DataLoader(
53 | dataset,
54 | batch_size=2,
55 |         shuffle=True,
56 | num_workers=0,
57 | pin_memory=True)
58 | for _, train_data in enumerate(train_loader):
59 | print(train_data['HR'].shape)
60 | # print(torch.zeros(train_data['HR'].shape[0:2], dtype=torch.float))
61 | # path = 'dataset/processed'
62 | # # print(os.path.join(path.split(path.split('\\')[-1])[0]),'heihei',path.split('\\')[-1])
63 | # print(os.path.join(
64 | # path.split(path.split('/')[-1])[0],'PreNet', 'I{}_E{}_gen.pth'))
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import medpy
2 | from medpy.io import load
3 | import os
4 | import pickle
5 | import torch
6 | from torch.utils.data import Dataset, DataLoader
7 | from torchvision import transforms, datasets
8 | from torchvision.utils import save_image
9 | import numpy as np
10 | class MyDataset(Dataset):
11 | def __init__(self, root_l, subfolder_l,root_s,subfolder_s,prefixs, transform=None):
12 | super(MyDataset, self).__init__()
13 | self.prefixs=prefixs
14 | self.l_path = os.path.join(root_l, subfolder_l)
15 | self.s_path=os.path.join(root_s, subfolder_s)
16 | self.templ = [x for x in os.listdir(self.l_path) if os.path.splitext(x)[1] == ".img"]
17 | self.temps = [x for x in os.listdir(self.s_path) if os.path.splitext(x)[1] == ".img"]
18 | self.image_list_l=[]
19 | self.image_list_s = []
20 |         # keep only the files whose names contain one of the given prefixes
21 | for file in self.templ:
22 | for pre in prefixs:
23 | if pre in file:
24 | self.image_list_l.append(file)
25 |         # same prefix filtering for the standard-dose PET list
26 | for file in self.temps:
27 | for pre in prefixs:
28 | if pre in file:
29 | self.image_list_s.append(file)
30 | # print(self.image_list_l)
31 | # print(self.image_list_s)
32 | self.transform = transform
33 |
34 | def __len__(self):
35 | return len(self.image_list_l)
36 |
37 |     def __getitem__(self, item):
38 |         # read the input image (low-dose PET)
39 |         image_path_l = os.path.join(self.l_path, self.image_list_l[item])
40 |         image_l, h = load(image_path_l)
41 |         image_l = np.array(image_l)
42 |         if self.transform is not None:
43 |             image_l = self.transform(image_l)
44 |         # read the label (high-quality PET)
45 |         image_path_s = os.path.join(self.s_path, self.image_list_s[item])
46 |         image_s, h2 = load(image_path_s)
47 |         image_s = np.array(image_s)
48 |         if self.transform is not None:
49 |             image_s = self.transform(image_s)
50 |         # add a channel dimension
51 |         image_l = image_l[np.newaxis, :]
52 |         image_s = image_s[np.newaxis, :]
53 |         image_l = torch.Tensor(image_l)
54 |         image_s = torch.Tensor(image_s)
55 |         # return: (input image, label)
56 |         return image_l, image_s
62 | ###
63 | class MyMultiDataset(Dataset):
64 | def __init__(self, root_l, subfolder_l,root_s,subfolder_s,root_mri,subfolder_mri,prefixs, transform=None):
65 | super(MyMultiDataset, self).__init__()
66 | self.prefixs=prefixs
67 |         self.l_path = os.path.join(root_l, subfolder_l)
68 |         self.s_path = os.path.join(root_s, subfolder_s)
69 |         self.mri_path = os.path.join(root_mri, subfolder_mri)
70 |         self.templ = [x for x in os.listdir(self.l_path) if os.path.splitext(x)[1] == ".img"]
71 |         self.temps = [x for x in os.listdir(self.s_path) if os.path.splitext(x)[1] == ".img"]
72 |         # temp_mri was referenced below but never defined; build it the same way
73 |         self.temp_mri = [x for x in os.listdir(self.mri_path) if os.path.splitext(x)[1] == ".img"]
74 |         self.image_list_l = []
75 |         self.image_list_s = []
76 |         self.image_list_mri = []
77 |         # keep only the files whose names contain one of the given prefixes
78 |         for file in self.templ:
79 |             for pre in prefixs:
80 |                 if pre in file:
81 |                     self.image_list_l.append(file)
82 |         for file in self.temps:
83 |             for pre in prefixs:
84 |                 if pre in file:
85 |                     self.image_list_s.append(file)
86 |         for file in self.temp_mri:
87 |             for pre in prefixs:
88 |                 if pre in file:
89 |                     self.image_list_mri.append(file)
90 | # print(self.image_list_s)
91 | self.transform = transform
92 |
93 | def __len__(self):
94 | return len(self.image_list_l)
95 |
96 |     def __getitem__(self, item):
97 |         # read the input image (low-dose PET)
98 |         image_path_l = os.path.join(self.l_path, self.image_list_l[item])
99 |         image_l, h = load(image_path_l)
100 |         image_l = np.array(image_l)
101 |         if self.transform is not None:
102 |             image_l = self.transform(image_l)
103 |         # read the label (high-quality PET)
104 |         image_path_s = os.path.join(self.s_path, self.image_list_s[item])
105 |         image_s, h2 = load(image_path_s)
106 |         image_s = np.array(image_s)
107 |         if self.transform is not None:
108 |             image_s = self.transform(image_s)
109 |         # add a channel dimension
110 |         image_l = image_l[np.newaxis, :]
111 |         image_s = image_s[np.newaxis, :]
112 |         image_l = torch.Tensor(image_l)
113 |         image_s = torch.Tensor(image_s)
114 |         # return: (input image, label); note the MRI list built in __init__
115 |         # is never loaded or returned here
116 |         return image_l, image_s
121 | #
122 | #data
123 | def loadData(root1, subfolder1,root2,subfolder2,prefixs, batch_size, shuffle=True):
124 |
125 | transform = None
126 |     # modified for testing
127 | dataset = MyDataset(root1, subfolder1,root2,subfolder2,prefixs,transform=transform)
128 | #dataset = MyDataset(root, subfolder,transform=None)
129 |
130 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
131 | #multi data
132 | def loadMultiData(root1, subfolder1,root2,subfolder2,root3,subfolder3,prefixs, batch_size, shuffle=True):
133 |
134 | transform = None
135 |     # modified for testing
136 | dataset = MyMultiDataset(root1, subfolder1,root2,subfolder2,root3,subfolder3,prefixs,transform=transform)
137 | #dataset = MyDataset(root, subfolder,transform=None)
138 |
139 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
140 | #
141 | #x=MyDataset('./data/l_cut','')
142 | def readTxtLineAsList(txt_path):
143 | fi = open(txt_path, 'r')
144 | txt = fi.readlines()
145 | res_list = []
146 | for w in txt:
147 | w = w.replace('\n', '')
148 | res_list.append(w)
149 | return res_list
150 |
151 | if __name__ == '__main__':
152 | train_txt_path = r"E:\Projects\PyCharm Projects\dataset\split\Ex2\train.txt"
153 | val_txt_path = r"E:\Projects\PyCharm Projects\dataset\split\Ex2\val.txt"
154 | train_imgs = readTxtLineAsList(train_txt_path)
155 | print(train_imgs)
156 | val_imgs = readTxtLineAsList(val_txt_path)
157 | print(val_imgs)
158 |     trainloader = loadData(r'E:\Projects\PyCharm Projects\dataset\clinical/train_l_cut', '',
159 |                            r'E:\Projects\PyCharm Projects\dataset\clinical/train_s_cut', '',
160 |                            prefixs=train_imgs, batch_size=1)
161 |     valloader = loadData(r'E:\Projects\PyCharm Projects\dataset\clinical/train_l_cut', '',
162 |                          r'E:\Projects\PyCharm Projects\dataset\clinical/train_s_cut', '',
163 |                          prefixs=val_imgs, batch_size=1)
--------------------------------------------------------------------------------
/data/prepare_data.py:
--------------------------------------------------------------------------------
1 | import medpy
2 | from medpy.io import load
3 | from medpy.io import save
4 | import numpy as np
5 | import os
6 | import SimpleITK as sitk
7 |
8 |
9 | # Used for slicing the volumes into patches
10 | def Datamake(root_l, root_s):
11 | all_l_names = []
12 | all_s_names = []
13 |     for root, dirs, files in os.walk(root_l):
14 |         all_l_names = files
15 |     for root, dirs, files in os.walk(root_s):
16 |         all_s_names = files
17 | #
18 | all_l_name = []
19 | all_s_name = []
20 | for i in all_l_names:
21 | if os.path.splitext(i)[1] == ".img":
22 | # print(i)
23 | all_l_name.append(i)
24 | for i in all_s_names:
25 | if os.path.splitext(i)[1] == ".img":
26 | all_s_name.append(i)
27 | #
28 | print(all_l_name)
29 | #
30 | for file in all_l_name:
31 | image_path_l = os.path.join(root_l, file)
32 | image_l, h = load(image_path_l)
33 | image_l = np.array(image_l)
34 | # print(image_l.shape)
35 | cut_cnt = 0
36 | # print(cut_cnt)
37 | for i in range(0, 8):
38 | for j in range(0, 8):
39 | for k in range(0, 8):
40 | image_cut = image_l[9 * i:64 + 9 * i, 9 * j:64 + 9 * j, 9 * k:64 + 9 * k]
41 | savImg = sitk.GetImageFromArray(image_cut.transpose(2, 1, 0))
42 |                     sitk.WriteImage(savImg,
43 |                                     r'C:\Users\Administrator\Desktop\PET-Reconstruction-with-Diffusion\dataset\processed\LPET_cut' + '/' + file + '_cut' + str(cut_cnt) + '.img')
44 | cut_cnt += 1
45 |
46 | for file in all_s_name:
47 | image_path_s = os.path.join(root_s, file)
48 | image_s, h = load(image_path_s)
49 | image_s = np.array(image_s)
50 | # print(image_l.shape)
51 | cut_cnt = 0
52 | for i in range(0, 8):
53 | for j in range(0, 8):
54 | for k in range(0, 8):
55 | image_cut = image_s[9 * i:64 + 9 * i, 9 * j:64 + 9 * j, 9 * k:64 + 9 * k]
56 | savImg = sitk.GetImageFromArray(image_cut.transpose(2, 1, 0))
57 |                     sitk.WriteImage(savImg,
58 |                                     r'C:\Users\Administrator\Desktop\PET-Reconstruction-with-Diffusion\dataset\processed\HPET_cut' + '/' + file + '_cut' + str(cut_cnt) + '.img')
59 | cut_cnt += 1
60 | if __name__ == '__main__':
61 |     Datamake(r'D:\zpx\CVT3D\dataset\processed\LPET', r'D:\zpx\CVT3D\dataset\processed\SPET')
62 |
--------------------------------------------------------------------------------
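
The nested loops above take 8 patch offsets per axis with stride 9 and patch size 64, so each axis spans 9*(8-1) + 64 = 127 voxels; input volumes are therefore assumed to be at least 127 voxels per axis. A one-line check of that arithmetic:

n_offsets, stride, patch = 8, 9, 64
print(stride * (n_offsets - 1) + patch)  # 127
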
/data/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | import random
5 | import numpy as np
6 |
7 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG',
8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
9 |
10 |
11 | def is_image_file(filename):
12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 |
14 |
15 | def get_paths_from_images(path):
16 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
17 | images = []
18 | for dirpath, _, fnames in sorted(os.walk(path)):
19 | for fname in sorted(fnames):
20 |             if fname.endswith('.mat'):  # only .mat volumes are collected; IMG_EXTENSIONS above is unused here
21 | img_path = os.path.join(dirpath, fname)
22 | images.append(img_path)
23 | assert images, '{:s} has no valid image file'.format(path)
24 | return sorted(images)
25 |
26 |
27 |
28 | def transform2numpy(img):
29 | img = np.array(img)
30 | img = img.astype(np.float32) / 255.
31 | if img.ndim == 2:
32 | img = np.expand_dims(img, axis=2)
33 | # some images have 4 channels
34 | if img.shape[2] > 3:
35 | img = img[:, :, :3]
36 | return img
37 |
38 |
39 | def transform2tensor(img, min_max=(0, 1)):
40 | # HWC to CHW
41 | img = torch.from_numpy(np.ascontiguousarray(
42 | np.transpose(img, (2, 0, 1)))).float()
43 | # to range min_max
44 | img = img*(min_max[1] - min_max[0]) + min_max[0]
45 | return img
46 |
47 | totensor = torchvision.transforms.ToTensor()
48 | def transform_augment(img_list, split='val', min_max=(0, 1)):
49 |     imgs = [totensor(img) for img in img_list]  # split and min_max are unused in this simplified version
50 | return imgs
51 | # implementation by torchvision, detail in https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/issues/14
52 |
--------------------------------------------------------------------------------
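
Example of the intended pipeline here (the PNG path is hypothetical): transform2numpy yields an HWC float image in [0, 1], and transform2tensor with min_max=(-1, 1) maps it linearly to [-1, 1] via x*2 - 1:

from PIL import Image
from data.util import transform2numpy, transform2tensor

img = transform2numpy(Image.open('example.png'))  # HWC float32 in [0, 1]
t = transform2tensor(img, min_max=(-1, 1))        # CHW tensor in [-1, 1]
print(t.shape, t.min().item(), t.max().item())
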
/easy_train.py:
--------------------------------------------------------------------------------
1 | import torch, torchvision
2 | from torch import nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 | from torch.utils.data import DataLoader
6 | import torchvision.transforms as transforms
7 |
8 | from einops import rearrange, repeat
9 | from tqdm.notebook import tqdm
10 | from functools import partial
11 | from PIL import Image
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 | import math, os, copy
15 |
16 | """
17 | Define U-net Architecture:
18 | Approximate reverse diffusion process by using U-net
19 | U-net of SR3 : U-net backbone + Positional Encoding of time + Multihead Self-Attention
20 | """
21 |
22 | # Positional encoding of the noise level (time embedding)
23 | class PositionalEncoding(nn.Module):
24 | def __init__(self, dim):
25 | super().__init__()
26 | self.dim = dim
27 |
28 | def forward(self, noise_level):
29 | # Input : tensor of value of coefficient alpha at specific step of diffusion process e.g. torch.Tensor([0.03])
30 | # Transform level of noise into representation of given desired dimension
31 | count = self.dim // 2
32 | step = torch.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count
33 | encoding = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
34 | encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
35 | return encoding
36 |
37 | #
38 | class FeatureWiseAffine(nn.Module):
39 | def __init__(self, in_channels, out_channels, use_affine_level=False):
40 | super(FeatureWiseAffine, self).__init__()
41 | self.use_affine_level = use_affine_level
42 | self.noise_func = nn.Sequential(nn.Linear(in_channels, out_channels * (1 + self.use_affine_level)))
43 |
44 | def forward(self, x, noise_embed):
45 | noise = self.noise_func(noise_embed).view(x.shape[0], -1, 1, 1)
46 | if self.use_affine_level:
47 | gamma, beta = noise.chunk(2, dim=1)
48 | x = (1 + gamma) * x + beta
49 | else:
50 | x = x + noise
51 | return x
52 |
53 | # swish activation function
54 | class Swish(nn.Module):
55 | def forward(self, x):
56 | return x * torch.sigmoid(x)
57 |
58 |
59 | class Upsample(nn.Module):
60 | def __init__(self, dim):
61 | super().__init__()
62 | self.up = nn.Upsample(scale_factor=2, mode="nearest")
63 | self.conv = nn.Conv2d(dim, dim, 3, padding=1)
64 |
65 | def forward(self, x):
66 | return self.conv(self.up(x))
67 |
68 |
69 | class Downsample(nn.Module):
70 | def __init__(self, dim):
71 | super().__init__()
72 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1)
73 |
74 | def forward(self, x):
75 | return self.conv(x)
76 |
77 |
78 | class Block(nn.Module):
79 | def __init__(self, dim, dim_out, groups=32, dropout=0):
80 | super().__init__()
81 | self.block = nn.Sequential(
82 | nn.GroupNorm(groups, dim),
83 | Swish(),
84 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
85 | nn.Conv2d(dim, dim_out, 3, padding=1)
86 | )
87 |
88 | def forward(self, x):
89 | return self.block(x)
90 |
91 |
92 | # Linear Multi-head Self-attention
93 | class SelfAtt(nn.Module):
94 | def __init__(self, channel_dim, num_heads, norm_groups=32):
95 | super(SelfAtt, self).__init__()
96 | self.groupnorm = nn.GroupNorm(norm_groups, channel_dim)
97 | self.num_heads = num_heads
98 | self.qkv = nn.Conv2d(channel_dim, channel_dim * 3, 1, bias=False)
99 | self.proj = nn.Conv2d(channel_dim, channel_dim, 1)
100 |
101 | def forward(self, x):
102 | b, c, h, w = x.size()
103 | x = self.groupnorm(x)
104 | qkv = rearrange(self.qkv(x), "b (qkv heads c) h w -> (qkv) b heads c (h w)", heads=self.num_heads, qkv=3)
105 | queries, keys, values = qkv[0], qkv[1], qkv[2]
106 |
107 | keys = F.softmax(keys, dim=-1)
108 | att = torch.einsum('bhdn,bhen->bhde', keys, values)
109 | out = torch.einsum('bhde,bhdn->bhen', att, queries)
110 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.num_heads, h=h, w=w)
111 |
112 | return self.proj(out)
113 |
114 |
115 | class ResBlock(nn.Module):
116 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0,
117 | num_heads=1, use_affine_level=False, norm_groups=32, att=True):
118 | super().__init__()
119 | self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level)
120 | self.block1 = Block(dim, dim_out, groups=norm_groups)
121 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
122 | self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
123 | self.att = att
124 | self.attn = SelfAtt(dim_out, num_heads=num_heads, norm_groups=norm_groups)
125 |
126 | def forward(self, x, time_emb):
127 | y = self.block1(x)
128 | y = self.noise_func(y, time_emb)
129 | y = self.block2(y)
130 | x = y + self.res_conv(x)
131 | if self.att:
132 | x = self.attn(x)
133 | return x
134 |
135 |
136 | class UNet(nn.Module):
137 | def __init__(self, in_channel=6, out_channel=3, inner_channel=32, norm_groups=32,
138 | channel_mults=[1, 2, 4, 8, 8], res_blocks=3, dropout=0, img_size=128):
139 | super().__init__()
140 |
141 | noise_level_channel = inner_channel
142 | self.noise_level_mlp = nn.Sequential(
143 | PositionalEncoding(inner_channel),
144 | nn.Linear(inner_channel, inner_channel * 4),
145 | Swish(),
146 | nn.Linear(inner_channel * 4, inner_channel)
147 | )
148 |
149 | num_mults = len(channel_mults)
150 | pre_channel = inner_channel
151 | feat_channels = [pre_channel]
152 | now_res = img_size
153 |
154 | # Downsampling stage of U-net
155 | downs = [nn.Conv2d(in_channel, inner_channel, kernel_size=3, padding=1)]
156 | for ind in range(num_mults):
157 | is_last = (ind == num_mults - 1)
158 | channel_mult = inner_channel * channel_mults[ind]
159 | for _ in range(0, res_blocks):
160 | downs.append(ResBlock(
161 | pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel,
162 | norm_groups=norm_groups, dropout=dropout))
163 | feat_channels.append(channel_mult)
164 | pre_channel = channel_mult
165 | if not is_last:
166 | downs.append(Downsample(pre_channel))
167 | feat_channels.append(pre_channel)
168 | now_res = now_res // 2
169 | self.downs = nn.ModuleList(downs)
170 |
171 | self.mid = nn.ModuleList([
172 | ResBlock(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
173 | norm_groups=norm_groups, dropout=dropout),
174 | ResBlock(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
175 | norm_groups=norm_groups, dropout=dropout, att=False)
176 | ])
177 |
178 | # Upsampling stage of U-net
179 | ups = []
180 | for ind in reversed(range(num_mults)):
181 | is_last = (ind < 1)
182 | channel_mult = inner_channel * channel_mults[ind]
183 | for _ in range(0, res_blocks + 1):
184 | ups.append(ResBlock(
185 | pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel,
186 | norm_groups=norm_groups, dropout=dropout))
187 | pre_channel = channel_mult
188 | if not is_last:
189 | ups.append(Upsample(pre_channel))
190 | now_res = now_res * 2
191 |
192 | self.ups = nn.ModuleList(ups)
193 |
194 | self.final_conv = Block(pre_channel, out_channel, groups=norm_groups)
195 |
196 | def forward(self, x, noise_level):
197 | # Embedding of time step with noise coefficient alpha
198 | t = self.noise_level_mlp(noise_level)
199 |         # print(t.shape)  # leftover debug print, disabled
200 | feats = []
201 | for layer in self.downs:
202 | if isinstance(layer, ResBlock):
203 | x = layer(x, t)
204 | else:
205 | x = layer(x)
206 | feats.append(x)
207 |
208 | for layer in self.mid:
209 | x = layer(x, t)
210 |
211 | for layer in self.ups:
212 | if isinstance(layer, ResBlock):
213 | x = layer(torch.cat((x, feats.pop()), dim=1), t)
214 | else:
215 | x = layer(x)
216 |
217 | return self.final_conv(x)
218 |
219 |
220 | """
221 | Define Diffusion process framework to train desired model:
222 | Forward Diffusion process:
223 |     Given an original image x_0, apply Gaussian noise ε_t at each time step t
224 |     After a proper number of time steps, the image x_T reaches pure Gaussian noise
225 | Objective of model f:
226 |     model f is trained to predict the actual noise ε_t added at each time step t
227 | """
228 |
229 |
230 | class Diffusion(nn.Module):
231 | def __init__(self, model, device, img_size, LR_size, channels=3):
232 | super().__init__()
233 | self.channels = channels
234 | self.model = model.to(device)
235 | self.img_size = img_size
236 | self.LR_size = LR_size
237 | self.device = device
238 |
239 | def set_loss(self, loss_type):
240 | if loss_type == 'l1':
241 | self.loss_func = nn.L1Loss(reduction='sum')
242 | elif loss_type == 'l2':
243 | self.loss_func = nn.MSELoss(reduction='sum')
244 | else:
245 | raise NotImplementedError()
246 |
247 | def make_beta_schedule(self, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2):
248 | if schedule == 'linear':
249 | betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
250 | elif schedule == 'warmup':
251 | warmup_frac = 0.1
252 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
253 | warmup_time = int(n_timestep * warmup_frac)
254 | betas[:warmup_time] = np.linspace(linear_start, linear_end, warmup_time, dtype=np.float64)
255 | elif schedule == "cosine":
256 | cosine_s = 8e-3
257 | timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
258 | alphas = timesteps / (1 + cosine_s) * math.pi / 2
259 | alphas = torch.cos(alphas).pow(2)
260 | alphas = alphas / alphas[0]
261 | betas = 1 - alphas[1:] / alphas[:-1]
262 | betas = betas.clamp(max=0.999)
263 | else:
264 | raise NotImplementedError(schedule)
265 | return betas
266 |
267 | def set_new_noise_schedule(self, schedule_opt):
268 | to_torch = partial(torch.tensor, dtype=torch.float32, device=self.device)
269 |
270 | betas = self.make_beta_schedule(
271 | schedule=schedule_opt['schedule'],
272 | n_timestep=schedule_opt['n_timestep'],
273 | linear_start=schedule_opt['linear_start'],
274 | linear_end=schedule_opt['linear_end'])
275 | betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
276 | alphas = 1. - betas
277 | alphas_cumprod = np.cumprod(alphas, axis=0)
278 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
279 | self.sqrt_alphas_cumprod_prev = np.sqrt(np.append(1., alphas_cumprod))
280 | self.num_timesteps = int(len(betas))
281 | # Coefficient for forward diffusion q(x_t | x_{t-1}) and others
282 | self.register_buffer('betas', to_torch(betas))
283 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
284 | self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
285 | self.register_buffer('pred_coef1', to_torch(np.sqrt(1. / alphas_cumprod)))
286 | self.register_buffer('pred_coef2', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
287 |
288 | # Coefficient for reverse diffusion posterior q(x_{t-1} | x_t, x_0)
289 | variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
290 | self.register_buffer('variance', to_torch(variance))
291 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
292 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(variance, 1e-20))))
293 | self.register_buffer('posterior_mean_coef1',
294 | to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
295 | self.register_buffer('posterior_mean_coef2',
296 | to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
297 |
298 | # Predict desired image x_0 from x_t with noise z_t -> Output is predicted x_0
299 | def predict_start(self, x_t, t, noise):
300 | return self.pred_coef1[t] * x_t - self.pred_coef2[t] * noise
301 |
302 | # Compute mean and log variance of posterior(reverse diffusion process) distribution
303 | def q_posterior(self, x_start, x_t, t):
304 | posterior_mean = self.posterior_mean_coef1[t] * x_start + self.posterior_mean_coef2[t] * x_t
305 | posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
306 | return posterior_mean, posterior_log_variance_clipped
307 |
308 | # Note that posterior q for reverse diffusion process is conditioned Gaussian distribution q(x_{t-1}|x_t, x_0)
309 | # Thus to compute desired posterior q, we need original image x_0 in ideal,
310 | # but it's impossible for actual training procedure -> Thus we reconstruct desired x_0 and use this for posterior
311 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None):
312 | batch_size = x.shape[0]
313 | noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[t + 1]]).repeat(batch_size, 1).to(x.device)
314 | x_recon = self.predict_start(x, t, noise=self.model(torch.cat([condition_x, x], dim=1), noise_level))
315 |
316 | if clip_denoised:
317 | x_recon.clamp_(-1., 1.)
318 |
319 | mean, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
320 | return mean, posterior_log_variance
321 |
322 | # Progress single step of reverse diffusion process
323 | # Given mean and log variance of posterior, sample reverse diffusion result from the posterior
324 | @torch.no_grad()
325 | def p_sample(self, x, t, clip_denoised=True, condition_x=None):
326 | mean, log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
327 | noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
328 | return mean + noise * (0.5 * log_variance).exp()
329 |
330 | # Progress whole reverse diffusion process
331 | @torch.no_grad()
332 | def super_resolution(self, x_in):
333 |         img = torch.randn_like(x_in)  # start from pure Gaussian noise (randn, not uniform rand)
334 | for i in reversed(range(0, self.num_timesteps)):
335 | img = self.p_sample(img, i, condition_x=x_in)
336 | return img
337 |
338 | # Compute loss to train the model
339 | def p_losses(self, x_in):
340 | x_start = x_in
341 | lr_imgs = transforms.Resize(self.img_size)(transforms.Resize(self.LR_size)(x_in))
342 | b, c, h, w = x_start.shape
343 | t = np.random.randint(1, self.num_timesteps + 1)
344 | sqrt_alpha = torch.FloatTensor(
345 | np.random.uniform(self.sqrt_alphas_cumprod_prev[t - 1], self.sqrt_alphas_cumprod_prev[t], size=b)
346 | ).to(x_start.device)
347 | sqrt_alpha = sqrt_alpha.view(-1, 1, 1, 1)
348 |
349 | noise = torch.randn_like(x_start).to(x_start.device)
350 | # Perturbed image obtained by forward diffusion process at random time step t
351 | x_noisy = sqrt_alpha * x_start + (1 - sqrt_alpha ** 2).sqrt() * noise
352 | # The model predict actual noise added at time step t
353 | pred_noise = self.model(torch.cat([lr_imgs, x_noisy], dim=1), noise_level=sqrt_alpha)
354 |
355 | return self.loss_func(noise, pred_noise)
356 |
357 | def forward(self, x, *args, **kwargs):
358 | return self.p_losses(x, *args, **kwargs)
359 |
360 |
361 | # Class to train & test desired model
362 | class SR3():
363 | def __init__(self, device, img_size, LR_size, loss_type, dataloader, testloader,
364 | schedule_opt, save_path, load_path=None, load=False,
365 | in_channel=6, out_channel=3, inner_channel=32, norm_groups=8,
366 | channel_mults=(1, 2, 4, 8, 8), res_blocks=3, dropout=0, lr=1e-5, distributed=False):
367 | super(SR3, self).__init__()
368 | self.dataloader = dataloader
369 | self.testloader = testloader
370 | self.device = device
371 | self.save_path = save_path
372 | self.img_size = img_size
373 | self.LR_size = LR_size
374 |
375 | model = UNet(in_channel, out_channel, inner_channel, norm_groups, channel_mults, res_blocks, dropout, img_size)
376 | self.sr3 = Diffusion(model, device, img_size, LR_size, out_channel)
377 |
378 | # Apply weight initialization & set loss & set noise schedule
379 | self.sr3.apply(self.weights_init_orthogonal)
380 | self.sr3.set_loss(loss_type)
381 | self.sr3.set_new_noise_schedule(schedule_opt)
382 |
383 | if distributed:
384 | assert torch.cuda.is_available()
385 | self.sr3 = nn.DataParallel(self.sr3)
386 |
387 | self.optimizer = torch.optim.Adam(self.sr3.parameters(), lr=lr)
388 |
389 | params = sum(p.numel() for p in self.sr3.parameters())
390 | print(f"Number of model parameters : {params}")
391 |
392 | if load:
393 | self.load(load_path)
394 |
395 | def weights_init_orthogonal(self, m):
396 | classname = m.__class__.__name__
397 | if classname.find('Conv') != -1:
398 | init.orthogonal_(m.weight.data, gain=1)
399 | if m.bias is not None:
400 | m.bias.data.zero_()
401 | elif classname.find('Linear') != -1:
402 | init.orthogonal_(m.weight.data, gain=1)
403 | if m.bias is not None:
404 | m.bias.data.zero_()
405 | elif classname.find('BatchNorm2d') != -1:
406 | init.constant_(m.weight.data, 1.0)
407 | init.constant_(m.bias.data, 0.0)
408 |
409 | def train(self, epoch, verbose):
410 | fixed_imgs = copy.deepcopy(next(iter(self.testloader)))
411 | fixed_imgs = fixed_imgs[0].to(self.device)
412 | # Transform to low-resolution images
413 | fixed_imgs = transforms.Resize(self.img_size)(transforms.Resize(self.LR_size)(fixed_imgs))
414 |
415 | for i in tqdm(range(epoch)):
416 | train_loss = 0
417 | for _, imgs in enumerate(self.dataloader):
418 | # Initial imgs are high-resolution
419 | imgs = imgs[0].to(self.device)
420 | b, c, h, w = imgs.shape
421 |
422 | self.optimizer.zero_grad()
423 | loss = self.sr3(imgs)
424 | loss = loss.sum() / int(b * c * h * w)
425 | loss.backward()
426 | self.optimizer.step()
427 | train_loss += loss.item() * b
428 |
429 | if (i + 1) % verbose == 0:
430 | self.sr3.eval()
431 | test_imgs = next(iter(self.testloader))
432 | test_imgs = test_imgs[0].to(self.device)
433 | b, c, h, w = test_imgs.shape
434 |
435 | with torch.no_grad():
436 | val_loss = self.sr3(test_imgs)
437 | val_loss = val_loss.sum() / int(b * c * h * w)
438 | self.sr3.train()
439 |
440 | train_loss = train_loss / len(self.dataloader.dataset)  # per-sample mean: the running sum above is weighted by batch size b
441 | print(f'Epoch: {i + 1} / loss:{train_loss:.3f} / val_loss:{val_loss.item():.3f}')
442 |
443 | # Save example of test images to check training
444 | plt.figure(figsize=(15, 10))
445 | plt.subplot(1, 2, 1)
446 | plt.axis("off")
447 | plt.title("Low-Resolution Inputs")
448 | plt.imshow(np.transpose(torchvision.utils.make_grid(fixed_imgs,
449 | nrow=2, padding=1, normalize=True).cpu(),
450 | (1, 2, 0)))
451 |
452 | plt.subplot(1, 2, 2)
453 | plt.axis("off")
454 | plt.title("Super-Resolution Results")
455 | plt.imshow(np.transpose(torchvision.utils.make_grid(self.test(fixed_imgs).detach().cpu(),
456 | nrow=2, padding=1, normalize=True), (1, 2, 0)))
457 | plt.savefig('SuperResolution_Result.jpg')
458 | plt.close()
459 |
460 | # Save model weight
461 | self.save(self.save_path)
462 |
463 | def test(self, imgs):
464 | imgs_lr = transforms.Resize(self.img_size)(transforms.Resize(self.LR_size)(imgs))
465 | self.sr3.eval()
466 | with torch.no_grad():
467 | if isinstance(self.sr3, nn.DataParallel):
468 | result_SR = self.sr3.module.super_resolution(imgs_lr)
469 | else:
470 | result_SR = self.sr3.super_resolution(imgs_lr)
471 | self.sr3.train()
472 | return result_SR
473 |
474 | def save(self, save_path):
475 | network = self.sr3
476 | if isinstance(self.sr3, nn.DataParallel):
477 | network = network.module
478 | state_dict = network.state_dict()
479 | for key, param in state_dict.items():
480 | state_dict[key] = param.cpu()
481 | torch.save(state_dict, save_path)
482 |
483 | def load(self, load_path):
484 | network = self.sr3
485 | if isinstance(self.sr3, nn.DataParallel):
486 | network = network.module
487 | network.load_state_dict(torch.load(load_path))
488 | print("Model loaded successfully")
489 |
490 |
491 | if __name__ == "__main__":
492 | batch_size = 16
493 | LR_size = 32
494 | img_size = 128
495 | root = './data/ffhq_thumb'
496 | testroot = './data/celeba_hq'
497 |
498 | transforms_ = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(),
499 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
500 | dataloader = DataLoader(torchvision.datasets.ImageFolder(root, transform=transforms_),
501 | batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
502 | testloader = DataLoader(torchvision.datasets.ImageFolder(testroot, transform=transforms_),
503 | batch_size=4, shuffle=True, num_workers=8, pin_memory=True)
504 |
505 | cuda = torch.cuda.is_available()
506 | device = torch.device("cuda:2" if cuda else "cpu")
507 | schedule_opt = {'schedule': 'linear', 'n_timestep': 2000, 'linear_start': 1e-4, 'linear_end': 0.05}
508 |
509 | sr3 = SR3(device, img_size=img_size, LR_size=LR_size, loss_type='l1',
510 | dataloader=dataloader, testloader=testloader, schedule_opt=schedule_opt,
511 | save_path='./SR3.pt', load_path='./SR3.pt', load=True, inner_channel=96,
512 | norm_groups=16, channel_mults=(1, 2, 2, 2), dropout=0.2, res_blocks=2, lr=1e-5, distributed=False)
513 | sr3.train(epoch=250, verbose=25)
--------------------------------------------------------------------------------
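A note on p_losses above: unlike vanilla DDPM, which conditions the U-Net on an integer timestep, this SR3-style loss conditions it on a continuous noise level. It first picks a random integer step t, then draws sqrt(alpha_bar) uniformly between the schedule values at t-1 and t. A minimal self-contained sketch of just that sampling step (function and variable names here are illustrative, not part of the repository):

    import numpy as np
    import torch

    def sample_continuous_noise_level(sqrt_alphas_cumprod_prev, num_timesteps, batch_size):
        # pick a random integer timestep t in [1, T] ...
        t = np.random.randint(1, num_timesteps + 1)
        # ... then draw a continuous sqrt(alpha_bar) between the schedule values at t-1 and t
        level = np.random.uniform(sqrt_alphas_cumprod_prev[t - 1],
                                  sqrt_alphas_cumprod_prev[t],
                                  size=batch_size)
        # shape (b, 1, 1, 1) so it broadcasts over image tensors
        return torch.from_numpy(level).float().view(-1, 1, 1, 1)
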
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import data as Data
3 | import model as Model
4 | import argparse
5 | import logging
6 | import core.logger as Logger
7 | import core.metrics as Metrics
8 | from core.wandb_logger import WandbLogger
9 | from tensorboardX import SummaryWriter
10 | import os
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('-c', '--config', type=str, default='config/sr_sr3_64_512.json',
15 | help='JSON file for configuration')
16 | parser.add_argument('-p', '--phase', type=str, choices=['val'], help='val(generation)', default='val')
17 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
18 | parser.add_argument('-debug', '-d', action='store_true')
19 | parser.add_argument('-enable_wandb', action='store_true')
20 | parser.add_argument('-log_infer', action='store_true')
21 |
22 | # parse configs
23 | args = parser.parse_args()
24 | opt = Logger.parse(args)
25 | # Convert to NoneDict, which return None for missing key.
26 | opt = Logger.dict_to_nonedict(opt)
27 |
28 | # logging
29 | torch.backends.cudnn.enabled = True
30 | torch.backends.cudnn.benchmark = True
31 |
32 | Logger.setup_logger(None, opt['path']['log'],
33 | 'train', level=logging.INFO, screen=True)
34 | Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
35 | logger = logging.getLogger('base')
36 | logger.info(Logger.dict2str(opt))
37 | tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])
38 |
39 | # Initialize WandbLogger
40 | if opt['enable_wandb']:
41 | wandb_logger = WandbLogger(opt)
42 | else:
43 | wandb_logger = None
44 |
45 | # dataset
46 | for phase, dataset_opt in opt['datasets'].items():
47 | if phase == 'val':
48 | val_set = Data.create_dataset(dataset_opt, phase)
49 | val_loader = Data.create_dataloader(
50 | val_set, dataset_opt, phase)
51 | logger.info('Initial Dataset Finished')
52 |
53 | # model
54 | diffusion = Model.create_model(opt)
55 | logger.info('Initial Model Finished')
56 |
57 | diffusion.set_new_noise_schedule(
58 | opt['model']['beta_schedule']['val'], schedule_phase='val')
59 |
60 | logger.info('Begin Model Inference.')
61 | current_step = 0
62 | current_epoch = 0
63 | idx = 0
64 |
65 | result_path = '{}'.format(opt['path']['results'])
66 | os.makedirs(result_path, exist_ok=True)
67 | for _, val_data in enumerate(val_loader):
68 | idx += 1
69 | diffusion.feed_data(val_data)
70 | diffusion.test(continous=True)
71 | visuals = diffusion.get_current_visuals(need_LR=False)
72 |
73 | hr_img = Metrics.tensor2img(visuals['HR']) # uint8
74 | fake_img = Metrics.tensor2img(visuals['INF']) # uint8
75 |
76 | sr_img_mode = 'grid'
77 | if sr_img_mode == 'single':
78 | # single img series
79 | sr_img = visuals['SR'] # uint8
80 | sample_num = sr_img.shape[0]
81 | for i in range(0, sample_num):
82 | Metrics.save_img(
83 | Metrics.tensor2img(sr_img[i]), '{}/{}_{}_sr_{}.png'.format(result_path, current_step, idx, i))
84 | else:
85 | # grid img
86 | sr_img = Metrics.tensor2img(visuals['SR']) # uint8
87 | Metrics.save_img(
88 | sr_img, '{}/{}_{}_sr_process.png'.format(result_path, current_step, idx))
89 | Metrics.save_img(
90 | Metrics.tensor2img(visuals['SR'][-1]), '{}/{}_{}_sr.png'.format(result_path, current_step, idx))
91 |
92 | Metrics.save_img(
93 | hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
94 | Metrics.save_img(
95 | fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))
96 |
97 | if wandb_logger and opt['log_infer']:
98 | wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img)
99 |
100 | if wandb_logger and opt['log_infer']:
101 | wandb_logger.log_eval_table(commit=True)
102 |
--------------------------------------------------------------------------------
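inference.py is driven entirely by its config file; a typical invocation, assuming a trained checkpoint is referenced by resume_state in the config (paths illustrative):

    python inference.py -c config/sr_sr3_64_512.json -p val

Results (*_sr.png, *_hr.png, *_inf.png and the sampling-process grid) are written to the results directory named in the config; add -enable_wandb -log_infer to also log an evaluation table to Weights & Biases.
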
/model/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger('base')
3 |
4 |
5 | def create_model(opt):
6 | from .model import DDPM as M
7 | m = M(opt)
8 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
9 | return m
10 |
--------------------------------------------------------------------------------
/model/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class BaseModel():
7 | def __init__(self, opt):
8 | self.opt = opt
9 | self.device = torch.device(
10 | 'cuda' if opt['gpu_ids'] is not None else 'cpu')
11 | self.begin_step = 0
12 | self.begin_epoch = 0
13 |
14 | def feed_data(self, data):
15 | pass
16 |
17 | def optimize_parameters(self):
18 | pass
19 |
20 | def get_current_visuals(self):
21 | pass
22 |
23 | def get_current_losses(self):
24 | pass
25 |
26 | def print_network(self):
27 | pass
28 |
29 | def set_device(self, x):
30 | if isinstance(x, dict):
31 | for key, item in x.items():
32 | if item is not None:
33 | x[key] = item.to(self.device)
34 | elif isinstance(x, list):
35 | # rebuild the list: reassigning the loop variable would not move the tensors
36 | x = [item.to(self.device) if item is not None
37 | else None for item in x]
38 | else:
39 | x = x.to(self.device)
40 | return x
41 |
42 | def get_network_description(self, network):
43 | '''Get the string and total parameters of the network'''
44 | if isinstance(network, nn.DataParallel):
45 | network = network.module
46 | s = str(network)
47 | n = sum(map(lambda x: x.numel(), network.parameters()))
48 | return s, n
49 |
--------------------------------------------------------------------------------
/model/ddpm_modules/diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn, einsum
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 | from functools import partial
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 |
11 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
12 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
13 | warmup_time = int(n_timestep * warmup_frac)
14 | betas[:warmup_time] = np.linspace(
15 | linear_start, linear_end, warmup_time, dtype=np.float64)
16 | return betas
17 |
18 |
19 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
20 | if schedule == 'quad':
21 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
22 | n_timestep, dtype=np.float64) ** 2
23 | elif schedule == 'linear':
24 | betas = np.linspace(linear_start, linear_end,
25 | n_timestep, dtype=np.float64)
26 | elif schedule == 'warmup10':
27 | betas = _warmup_beta(linear_start, linear_end,
28 | n_timestep, 0.1)
29 | elif schedule == 'warmup50':
30 | betas = _warmup_beta(linear_start, linear_end,
31 | n_timestep, 0.5)
32 | elif schedule == 'const':
33 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
34 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
35 | betas = 1. / np.linspace(n_timestep,
36 | 1, n_timestep, dtype=np.float64)
37 | elif schedule == "cosine":
38 | timesteps = (
39 | torch.arange(n_timestep + 1, dtype=torch.float64) /
40 | n_timestep + cosine_s
41 | )
42 | alphas = timesteps / (1 + cosine_s) * math.pi / 2
43 | alphas = torch.cos(alphas).pow(2)
44 | alphas = alphas / alphas[0]
45 | betas = 1 - alphas[1:] / alphas[:-1]
46 | betas = betas.clamp(max=0.999)
47 | else:
48 | raise NotImplementedError(schedule)
49 | return betas
50 |
51 |
52 | # gaussian diffusion trainer class
53 |
54 | def exists(x):
55 | return x is not None
56 |
57 |
58 | def default(val, d):
59 | if exists(val):
60 | return val
61 | return d() if isfunction(d) else d
62 |
63 |
64 | def extract(a, t, x_shape):
65 | b, *_ = t.shape
66 | out = a.gather(-1, t)
67 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
68 |
69 |
70 | def noise_like(shape, device, repeat=False):
71 | def repeat_noise(): return torch.randn(
72 | (1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
73 |
74 | def noise(): return torch.randn(shape, device=device)
75 | return repeat_noise() if repeat else noise()
76 |
77 |
78 | class GaussianDiffusion(nn.Module):
79 | def __init__(
80 | self,
81 | denoise_fn,
82 | image_size,
83 | channels=3,
84 | loss_type='l1',
85 | conditional=True,
86 | schedule_opt=None
87 | ):
88 | super().__init__()
89 | self.channels = channels
90 | self.image_size = image_size
91 | self.denoise_fn = denoise_fn
92 | self.conditional = conditional
93 | self.loss_type = loss_type
94 | if schedule_opt is not None:
95 | pass
96 | # self.set_new_noise_schedule(schedule_opt)
97 |
98 | def set_loss(self, device):
99 | if self.loss_type == 'l1':
100 | self.loss_func = nn.L1Loss(reduction='sum').to(device)
101 | elif self.loss_type == 'l2':
102 | self.loss_func = nn.MSELoss(reduction='sum').to(device)
103 | else:
104 | raise NotImplementedError()
105 |
106 | def set_new_noise_schedule(self, schedule_opt, device):
107 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
108 | betas = make_beta_schedule(
109 | schedule=schedule_opt['schedule'],
110 | n_timestep=schedule_opt['n_timestep'],
111 | linear_start=schedule_opt['linear_start'],
112 | linear_end=schedule_opt['linear_end'])
113 | betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
114 | alphas = 1. - betas
115 | alphas_cumprod = np.cumprod(alphas, axis=0)
116 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
117 |
118 | timesteps, = betas.shape
119 | self.num_timesteps = int(timesteps)
120 | self.register_buffer('betas', to_torch(betas))
121 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
122 | self.register_buffer('alphas_cumprod_prev',
123 | to_torch(alphas_cumprod_prev))
124 |
125 | # calculations for diffusion q(x_t | x_{t-1}) and others
126 | self.register_buffer('sqrt_alphas_cumprod',
127 | to_torch(np.sqrt(alphas_cumprod)))
128 | self.register_buffer('sqrt_one_minus_alphas_cumprod',
129 | to_torch(np.sqrt(1. - alphas_cumprod)))
130 | self.register_buffer('log_one_minus_alphas_cumprod',
131 | to_torch(np.log(1. - alphas_cumprod)))
132 | self.register_buffer('sqrt_recip_alphas_cumprod',
133 | to_torch(np.sqrt(1. / alphas_cumprod)))
134 | self.register_buffer('sqrt_recipm1_alphas_cumprod',
135 | to_torch(np.sqrt(1. / alphas_cumprod - 1)))
136 |
137 | # calculations for posterior q(x_{t-1} | x_t, x_0)
138 | posterior_variance = betas * \
139 | (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
140 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
141 | self.register_buffer('posterior_variance',
142 | to_torch(posterior_variance))
143 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
144 | self.register_buffer('posterior_log_variance_clipped', to_torch(
145 | np.log(np.maximum(posterior_variance, 1e-20))))
146 | self.register_buffer('posterior_mean_coef1', to_torch(
147 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
148 | self.register_buffer('posterior_mean_coef2', to_torch(
149 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
150 |
151 | def q_mean_variance(self, x_start, t):
152 | mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
153 | variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
154 | log_variance = extract(
155 | self.log_one_minus_alphas_cumprod, t, x_start.shape)
156 | return mean, variance, log_variance
157 |
158 | def predict_start_from_noise(self, x_t, t, noise):
159 | return (
160 | extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
161 | extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
162 | )
163 |
164 | def q_posterior(self, x_start, x_t, t):
165 | posterior_mean = (
166 | extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
167 | extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
168 | )
169 | posterior_variance = extract(self.posterior_variance, t, x_t.shape)
170 | posterior_log_variance_clipped = extract(
171 | self.posterior_log_variance_clipped, t, x_t.shape)
172 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
173 |
174 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None):
175 | if condition_x is not None:
176 | x_recon = self.predict_start_from_noise(
177 | x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), t))
178 | else:
179 | x_recon = self.predict_start_from_noise(
180 | x, t=t, noise=self.denoise_fn(x, t))
181 |
182 | if clip_denoised:
183 | x_recon.clamp_(-1., 1.)
184 |
185 | model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
186 | x_start=x_recon, x_t=x, t=t)
187 | return model_mean, posterior_variance, posterior_log_variance
188 |
189 | @torch.no_grad()
190 | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False, condition_x=None):
191 | b, *_, device = *x.shape, x.device
192 | model_mean, _, model_log_variance = self.p_mean_variance(
193 | x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
194 | noise = noise_like(x.shape, device, repeat_noise)
195 | # no noise when t == 0
196 | nonzero_mask = (1 - (t == 0).float()).reshape(b,
197 | *((1,) * (len(x.shape) - 1)))
198 | return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
199 |
200 | @torch.no_grad()
201 | def p_sample_loop(self, x_in, continous=False):
202 | device = self.betas.device
203 | sample_inter = (1 | (self.num_timesteps//10))
204 |
205 | if not self.conditional:
206 | shape = x_in
207 | b = shape[0]
208 | img = torch.randn(shape, device=device)
209 | ret_img = img
210 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
211 | img = self.p_sample(img, torch.full(
212 | (b,), i, device=device, dtype=torch.long))
213 | if i % sample_inter == 0:
214 | ret_img = torch.cat([ret_img, img], dim=0)
215 | return ret_img if continous else img  # honor the continous flag, matching the conditional branch
216 | else:
217 | x = x_in
218 | shape = x.shape
219 | b = shape[0]
220 | img = torch.randn(shape, device=device)
221 | ret_img = x
222 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
223 | img = self.p_sample(img, torch.full(
224 | (b,), i, device=device, dtype=torch.long), condition_x=x)
225 | if i % sample_inter == 0:
226 | ret_img = torch.cat([ret_img, img], dim=0)
227 | if continous:
228 | return ret_img
229 | else:
230 | return ret_img[-1]
231 |
232 | @torch.no_grad()
233 | def sample(self, batch_size=1, continous=False):
234 | image_size = self.image_size
235 | channels = self.channels
236 | return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)
237 |
238 | @torch.no_grad()
239 | def super_resolution(self, x_in, continous=False):
240 | return self.p_sample_loop(x_in, continous)
241 |
242 | @torch.no_grad()
243 | def interpolate(self, x1, x2, t=None, lam=0.5):
244 | b, *_, device = *x1.shape, x1.device
245 | t = default(t, self.num_timesteps - 1)
246 |
247 | assert x1.shape == x2.shape
248 |
249 | t_batched = torch.stack([torch.tensor(t, device=device)] * b)
250 | xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
251 |
252 | img = (1 - lam) * xt1 + lam * xt2
253 | for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
254 | img = self.p_sample(img, torch.full(
255 | (b,), i, device=device, dtype=torch.long))
256 |
257 | return img
258 |
259 | def q_sample(self, x_start, t, noise=None):
260 | noise = default(noise, lambda: torch.randn_like(x_start))
261 |
262 | # fix gama
263 | return (
264 | extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
265 | extract(self.sqrt_one_minus_alphas_cumprod,
266 | t, x_start.shape) * noise
267 | )
268 |
269 | def p_losses(self, x_in, noise=None):
270 | x_start = x_in['HR']
271 | [b, c, h, w, l] = x_in['HR'].shape
272 | t = torch.randint(0, self.num_timesteps, (b,),
273 | device=x_start.device).long()
274 |
275 | noise = default(noise, lambda: torch.randn_like(x_start))
276 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
277 |
278 | if not self.conditional:
279 | x_recon = self.denoise_fn(x_noisy, t)
280 | else:
281 | x_recon = self.denoise_fn(
282 | torch.cat([x_in['SR'], x_noisy], dim=1), t)
283 | loss = self.loss_func(noise, x_recon)
284 |
285 | return loss
286 |
287 | def forward(self, x, *args, **kwargs):
288 | return self.p_losses(x, *args, **kwargs)
289 |
--------------------------------------------------------------------------------
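The q_sample above implements the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, so any x_t can be produced in a single step rather than by t sequential perturbations. A toy scalar check under the same linear schedule this file builds (values illustrative):

    import numpy as np

    betas = np.linspace(1e-4, 2e-2, 2000)      # make_beta_schedule('linear', 2000, 1e-4, 2e-2)
    alphas_cumprod = np.cumprod(1.0 - betas)   # alpha_bar_t

    x0, eps, t = 1.0, 0.5, 999                 # toy "image", noise sample, timestep
    x_t = np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * eps
    print(x_t)                                 # a single draw from q(x_t | x_0)
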
/model/ddpm_modules/unet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from inspect import isfunction
8 |
9 |
10 | def exists(x):
11 | return x is not None
12 |
13 |
14 | def default(val, d):
15 | if exists(val):
16 | return val
17 | return d() if isfunction(d) else d
18 |
19 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
20 | class PositionalEncoding(nn.Module):
21 | def __init__(self, dim):
22 | super().__init__()
23 | self.dim = dim
24 |
25 | def forward(self, noise_level):
26 | count = self.dim // 2
27 | step = torch.arange(count, dtype=noise_level.dtype,
28 | device=noise_level.device) / count
29 | encoding = noise_level.unsqueeze(
30 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
31 | encoding = torch.cat(
32 | [torch.sin(encoding), torch.cos(encoding)], dim=-1)
33 | return encoding
34 |
35 |
36 | class FeatureWiseAffine(nn.Module):
37 | def __init__(self, in_channels, out_channels, use_affine_level=False):
38 | super(FeatureWiseAffine, self).__init__()
39 |
40 | self.use_affine_level = use_affine_level
41 | self.noise_func = nn.Sequential(
42 | nn.Linear(in_channels, out_channels * (1 + self.use_affine_level))
43 | )
44 |
45 | def forward(self, x, noise_embed):
46 | batch = x.shape[0]
47 | if noise_embed is None:
48 | return x
49 | elif self.use_affine_level:
50 | gamma, beta = self.noise_func(noise_embed).view(
51 | batch, -1, 1, 1, 1).chunk(2, dim=1)
52 | x = (1 + gamma) * x + beta
53 | else:
54 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1, 1)
55 | return x
56 |
57 |
58 | class Swish(nn.Module):
59 | def forward(self, x):
60 | return x * torch.sigmoid(x)
61 |
62 |
63 | class Upsample(nn.Module):
64 | def __init__(self, dim):
65 | super().__init__()
66 | self.up = nn.Upsample(scale_factor=2, mode="nearest")
67 | self.conv = nn.Conv3d(dim, dim, 3, padding=1)
68 |
69 | def forward(self, x):
70 | return self.conv(self.up(x))
71 |
72 |
73 | class Downsample(nn.Module):
74 | def __init__(self, dim):
75 | super().__init__()
76 | self.conv = nn.Conv3d(dim, dim, 3, 2, 1)
77 |
78 | def forward(self, x):
79 | return self.conv(x)
80 |
81 |
82 | # building block modules
83 |
84 |
85 | class Block(nn.Module):
86 | def __init__(self, dim, dim_out, groups=16, dropout=0):
87 | super().__init__()
88 | self.block = nn.Sequential(
89 | nn.GroupNorm(groups, dim),
90 | Swish(),
91 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
92 | nn.Conv3d(dim, dim_out, 3, padding=1)
93 | )
94 |
95 | def forward(self, x):
96 | return self.block(x)
97 |
98 |
99 | class ResnetBlock(nn.Module):
100 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=16):
101 | super().__init__()
102 | if noise_level_emb_dim is not None:
103 | self.noise_func = FeatureWiseAffine(
104 | noise_level_emb_dim, dim_out, use_affine_level)
105 |
106 | self.block1 = Block(dim, dim_out, groups=norm_groups)
107 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
108 | self.res_conv = nn.Conv3d(
109 | dim, dim_out, 1) if dim != dim_out else nn.Identity()
110 |
111 | def forward(self, x, time_emb):
112 | b, c, h, w, d = x.shape
113 | h = self.block1(x)
114 | if time_emb is not None:
115 | h = self.noise_func(h, time_emb)
116 | h = self.block2(h)
117 | return h + self.res_conv(x)
118 |
119 |
120 | class SelfAttention(nn.Module):
121 | def __init__(self, in_channel, n_head=1, norm_groups=32):
122 | super().__init__()
123 |
124 | self.n_head = n_head
125 |
126 | self.norm = nn.GroupNorm(norm_groups, in_channel)
127 | self.qkv = nn.Conv3d(in_channel, in_channel * 3, 1, bias=False)
128 | self.out = nn.Conv3d(in_channel, in_channel, 1)
129 |
130 | def forward(self, input):
131 | batch, channel, height, width, depth = input.shape
132 | n_head = self.n_head
133 | head_dim = channel // n_head
134 |
135 | norm = self.norm(input)
136 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width, depth)
137 | query, key, value = qkv.chunk(3, dim=2)
138 |
139 | attn = torch.einsum(
140 | "bnchwd, bncyxz -> bnhwdyxz", query, key
141 | ).contiguous() / math.sqrt(channel)
142 | attn = attn.view(batch, n_head, height, width, depth, -1)
143 | attn = torch.softmax(attn, -1)
144 | attn = attn.view(batch, n_head, height, width, depth, height, width, depth)
145 |
146 | out = torch.einsum("bnhwdyxz, bncyxz -> bnchwd", attn, value).contiguous()
147 | out = self.out(out.view(batch, channel, height, width, depth))
148 |
149 | return out + input
150 |
151 |
152 | class ResnetBlocWithAttn(nn.Module):
153 | def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
154 | super().__init__()
155 | self.with_attn = with_attn
156 | self.res_block = ResnetBlock(
157 | dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
158 | if with_attn:
159 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
160 |
161 | def forward(self, x, time_emb):
162 | x = self.res_block(x, time_emb)
163 | if self.with_attn:
164 | x = self.attn(x)
165 | return x
166 |
167 |
168 | class UNet(nn.Module):
169 | def __init__(
170 | self,
171 | in_channel=1,
172 | out_channel=1,
173 | inner_channel=32,
174 | norm_groups=16,
175 | channel_mults=(1, 2, 4, 8, 8),
176 | attn_res=(8,),
177 | res_blocks=3,
178 | dropout=0,
179 | with_noise_level_emb=False,
180 | image_size=64
181 | ):
182 | super().__init__()
183 |
184 | if with_noise_level_emb:
185 | noise_level_channel = inner_channel
186 | self.noise_level_mlp = nn.Sequential(
187 | PositionalEncoding(inner_channel),
188 | nn.Linear(inner_channel, inner_channel * 4),
189 | Swish(),
190 | nn.Linear(inner_channel * 4, inner_channel)
191 | )
192 | else:
193 | noise_level_channel = None
194 | self.noise_level_mlp = None
195 |
196 | num_mults = len(channel_mults)
197 | pre_channel = inner_channel
198 | feat_channels = [pre_channel]
199 | now_res = image_size
200 | downs = [nn.Conv3d(in_channel, inner_channel,
201 | kernel_size=3, padding=1)]
202 | for ind in range(num_mults):
203 | is_last = (ind == num_mults - 1)
204 | use_attn = (now_res in attn_res)
205 | channel_mult = inner_channel * channel_mults[ind]
206 | for _ in range(0, res_blocks):
207 | downs.append(ResnetBlocWithAttn(
208 | pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn))
209 | feat_channels.append(channel_mult)
210 | pre_channel = channel_mult
211 | if not is_last:
212 | downs.append(Downsample(pre_channel))
213 | feat_channels.append(pre_channel)
214 | now_res = now_res//2
215 | self.downs = nn.ModuleList(downs)
216 |
217 | self.mid = nn.ModuleList([
218 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
219 | dropout=dropout, with_attn=True),
220 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
221 | dropout=dropout, with_attn=False)
222 | ])
223 |
224 | ups = []
225 | for ind in reversed(range(num_mults)):
226 | is_last = (ind < 1)
227 | use_attn = (now_res in attn_res)
228 | channel_mult = inner_channel * channel_mults[ind]
229 | for _ in range(0, res_blocks+1):
230 | ups.append(ResnetBlocWithAttn(
231 | pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
232 | dropout=dropout, with_attn=use_attn))
233 | pre_channel = channel_mult
234 | if not is_last:
235 | ups.append(Upsample(pre_channel))
236 | now_res = now_res*2
237 |
238 | self.ups = nn.ModuleList(ups)
239 |
240 | self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)
241 |
242 | def forward(self, x, time):
243 | t = self.noise_level_mlp(time) if exists(
244 | self.noise_level_mlp) else None
245 |
246 | feats = []
247 | for layer in self.downs:
248 | if isinstance(layer, ResnetBlocWithAttn):
249 | x = layer(x, t)
250 | else:
251 | x = layer(x)
252 | feats.append(x)
253 |
254 | for layer in self.mid:
255 | if isinstance(layer, ResnetBlocWithAttn):
256 | x = layer(x, t)
257 | else:
258 | x = layer(x)
259 |
260 | for layer in self.ups:
261 | if isinstance(layer, ResnetBlocWithAttn):
262 | x = layer(torch.cat((x, feats.pop()), dim=1), t)
263 | else:
264 | x = layer(x)
268 |
269 | return self.final_conv(x)
270 |
271 |
272 | if __name__ == "__main__":
273 | model = UNet().to("cuda")
274 | noise_level = torch.FloatTensor(
275 | [0.5]).repeat(1, 1).to("cuda")
276 | x = torch.randn(1, 1, 64, 64, 64).to("cuda")
277 | y = model(x, noise_level)
278 | print(y.shape)
--------------------------------------------------------------------------------
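The SelfAttention block in this file attends over every pair of voxels: its attention tensor has shape (batch, heads, H, W, D, H, W, D), i.e. (H*W*D)^2 weights per head. That quadratic cost is why attn_res restricts attention to the coarsest feature maps. A quick back-of-envelope check (sizes illustrative):

    positions = 16 ** 3       # a 16x16x16 feature map has 4096 positions
    weights = positions ** 2  # 16,777,216 attention weights per head
    print(weights)
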
/model/model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | import os
7 | import model.networks as networks
8 | from .base_model import BaseModel
9 | logger = logging.getLogger('base')
10 |
11 |
12 | class DDPM(BaseModel):
13 | def __init__(self, opt):
14 | super(DDPM, self).__init__(opt)
15 | # define network and load pretrained models
16 | self.netP = self.set_device(networks.define_P(opt))
17 | self.netG = self.set_device(networks.define_G(opt))
18 | self.schedule_phase = None
19 | # set loss and load resume state
20 | self.loss_func = nn.L1Loss(reduction='sum').to(self.device)
21 | self.lr = opt['train']["optimizer"]["lr"]
22 | self.old_lr = self.lr
23 | self.set_loss()
24 | self.set_new_noise_schedule(
25 | opt['model']['beta_schedule']['train'], schedule_phase='train')
26 | if self.opt['phase'] == 'train':
27 | self.netG.train()
28 | self.netP.train()
29 | # find the parameters to optimize
30 | if opt['model']['finetune_norm']:
31 | optim_params = []
32 | optim_params_P = []
33 | for k, v in self.netG.named_parameters():
34 | v.requires_grad = False
35 | if k.find('transformer') >= 0:
36 | v.requires_grad = True
37 | v.data.zero_()
38 | optim_params.append(v)
39 | logger.info(
40 | 'Params [{:s}] initialized to 0 and will optimize.'.format(k))
41 | for k, v in self.netP.named_parameters():
42 | v.requires_grad = False
43 | if k.find('transformer') >= 0:
44 | v.requires_grad = True
45 | v.data.zero_()
46 | optim_params_P.append(v)  # PreNet params go to the PreNet optimizer's list, not the DenoiseNet's
47 | logger.info(
48 | 'Params [{:s}] initialized to 0 and will optimize.'.format(k))
49 | else:
50 | optim_params = list(self.netG.parameters())
51 | optim_params_P = list(self.netP.parameters())
52 | self.optG = torch.optim.Adam(
53 | optim_params, lr=opt['train']["optimizer"]["lr"])
54 | self.optP = torch.optim.Adam(
55 | optim_params_P, lr=opt['train']["optimizer"]["lr"])
56 | self.log_dict = OrderedDict()
57 | self.load_network()
58 | self.print_network()
59 |
60 | def feed_data(self, data):
61 | self.data = self.set_device(data)
62 |
63 | def optimize_parameters(self):
64 | self.optG.zero_grad()
65 | self.optP.zero_grad()
66 | # Run the PreNet to get the initial prediction
67 | self.initial_predict()
68 | # Compute the residual and use it as x_start for the diffusion loss
69 | self.data['IP'] = self.IP
70 | self.data['RS'] = self.data['HR'] - self.IP
71 | l_pix = self.netG(self.data)
72 | # need to average in multi-gpu
73 | b, c, h, w = self.data['HR'].shape
74 | l_pix = (l_pix.sum())/int(b*c*h*w)
75 | l_pix.backward()
76 | # Update both networks
77 | self.optG.step()
78 | self.optP.step()
79 | # set log
80 | self.log_dict['l_pix'] = l_pix.item()
81 | # self.log_dict['loss_pix'] = l_loss.item()
82 | def initial_predict(self):
83 | self.IP = self.netP(self.data['SR'], time=None)
84 |
85 | def test(self, continous=False):
86 | self.netG.eval()
87 | self.netP.eval()
88 | with torch.no_grad():
89 |
90 | if isinstance(self.netG, nn.DataParallel):
91 | self.SR = self.netG.module.super_resolution(
92 | self.data['SR'], continous)
93 | else:
94 | self.SR = self.netG.super_resolution(
95 | self.data['SR'], continous)
96 | self.netG.train()
97 | self.netP.train()
98 |
99 | def sample(self, batch_size=1, continous=False):
100 | self.netG.eval()
101 | with torch.no_grad():
102 | if isinstance(self.netG, nn.DataParallel):
103 | self.SR = self.netG.module.sample(batch_size, continous)
104 | else:
105 | self.SR = self.netG.sample(batch_size, continous)
106 | self.netG.train()
107 |
108 | def set_loss(self):
109 | if isinstance(self.netG, nn.DataParallel):
110 | self.netG.module.set_loss(self.device)
111 | else:
112 | self.netG.set_loss(self.device)
113 |
114 |
115 | def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'):
116 | if self.schedule_phase is None or self.schedule_phase != schedule_phase:
117 | self.schedule_phase = schedule_phase
118 | if isinstance(self.netG, nn.DataParallel):
119 | self.netG.module.set_new_noise_schedule(
120 | schedule_opt, self.device)
121 | else:
122 | self.netG.set_new_noise_schedule(schedule_opt, self.device)
123 |
124 |
125 | def get_current_log(self):
126 | return self.log_dict
127 |
128 | def get_current_visuals(self, need_LR=True, sample=False):
129 | out_dict = OrderedDict()
130 | if sample:
131 | out_dict['SAM'] = self.SR.detach().float().cpu()
132 | else:
133 | out_dict['SR'] = self.SR.detach().float().cpu()
134 | out_dict['INF'] = self.data['SR'].detach().float().cpu()
135 | out_dict['HR'] = self.data['HR'].detach().float().cpu()
136 | if need_LR and 'LR' in self.data:
137 | out_dict['LR'] = self.data['LR'].detach().float().cpu()
138 | else:
139 | out_dict['LR'] = out_dict['INF']
140 | return out_dict
141 |
142 | def print_network(self):
143 | s, n = self.get_network_description(self.netG)
144 | if isinstance(self.netG, nn.DataParallel):
145 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
146 | self.netG.module.__class__.__name__)
147 | else:
148 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
149 |
150 | logger.info(
151 | 'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
152 | logger.info(s)
153 |
154 | def save_network(self, epoch, iter_step):
155 | # Save the PreNet
156 | gen_path = os.path.join(
157 | self.opt['path']['checkpoint'], 'I{}_E{}_PreNet_gen.pth'.format(iter_step, epoch))
158 | opt_path = os.path.join(
159 | self.opt['path']['checkpoint'], 'I{}_E{}_PreNet_opt.pth'.format(iter_step, epoch))
160 | # gen
161 | network = self.netP
162 | if isinstance(self.netP, nn.DataParallel):
163 | network = network.module
164 | state_dict = network.state_dict()
165 | for key, param in state_dict.items():
166 | state_dict[key] = param.cpu()
167 | torch.save(state_dict, gen_path)
168 | # opt
169 | opt_state = {'epoch': epoch, 'iter': iter_step,
170 | 'scheduler': None, 'optimizer': None}
171 | opt_state['optimizer'] = self.optP.state_dict()
172 | torch.save(opt_state, opt_path)
173 |
174 | # Save the DenoiseNet
175 | gen_path = os.path.join(
176 | self.opt['path']['checkpoint'], 'I{}_E{}_DenoiseNet_gen.pth'.format(iter_step, epoch))
177 | opt_path = os.path.join(
178 | self.opt['path']['checkpoint'], 'I{}_E{}_DenoiseNet_opt.pth'.format(iter_step, epoch))
179 | # gen
180 | network = self.netG
181 | if isinstance(self.netG, nn.DataParallel):
182 | network = network.module
183 | state_dict = network.state_dict()
184 | for key, param in state_dict.items():
185 | state_dict[key] = param.cpu()
186 | torch.save(state_dict, gen_path)
187 | # opt
188 | opt_state = {'epoch': epoch, 'iter': iter_step,
189 | 'scheduler': None, 'optimizer': None}
190 | opt_state['optimizer'] = self.optG.state_dict()
191 | torch.save(opt_state, opt_path)
192 |
193 | logger.info(
194 | 'Saved model in [{:s}] ...'.format(gen_path))
195 |
196 | def load_network(self):
197 | # Load the PreNet
198 | if self.opt['path']['resume_state'] is not None:
199 | load_path = self.opt['path']['resume_state']
200 | logger.info(
201 | 'Loading pretrained model for P [{:s}] ...'.format(load_path))
202 | gen_path = '{}_PreNet_gen.pth'.format(load_path)
203 | opt_path = '{}_PreNet_opt.pth'.format(load_path)
204 | # gen
205 | network = self.netP
206 | if isinstance(self.netP, nn.DataParallel):
207 | network = network.module
208 | network.load_state_dict(torch.load(
209 | gen_path), strict=(not self.opt['model']['finetune_norm']))
210 | # network.load_state_dict(torch.load(
211 | # gen_path), strict=False)
212 | if self.opt['phase'] == 'train':
213 | # optimizer
214 | opt = torch.load(opt_path)
215 | self.optP.load_state_dict(opt['optimizer'])
216 | self.begin_step = opt['iter']
217 | self.begin_epoch = opt['epoch']
218 |
219 | # Load the DenoiseNet
220 | if self.opt['path']['resume_state'] is not None:
221 | load_path = self.opt['path']['resume_state']
222 | logger.info(
223 | 'Loading pretrained model for G [{:s}] ...'.format(load_path))
224 | gen_path = '{}_DenoiseNet_gen.pth'.format(load_path)
225 | opt_path = '{}_DenoiseNet_opt.pth'.format(load_path)
226 | # gen
227 | network = self.netG
228 | if isinstance(self.netG, nn.DataParallel):
229 | network = network.module
230 | network.load_state_dict(torch.load(
231 | gen_path), strict=(not self.opt['model']['finetune_norm']))
232 | # network.load_state_dict(torch.load(
233 | # gen_path), strict=False)
234 | if self.opt['phase'] == 'train':
235 | # optimizer
236 | opt = torch.load(opt_path)
237 | self.optG.load_state_dict(opt['optimizer'])
238 | self.begin_step = opt['iter']
239 | self.begin_epoch = opt['epoch']
240 | def update_learning_rate(self):
241 | self.niter_decay = 1000000
242 | if self.old_lr > 0.000001:
243 | lrd = 200 * self.lr / self.niter_decay
244 | lr = self.old_lr - lrd
245 | else:
246 | lr = self.old_lr
247 | for param_group in self.optP.param_groups:
248 | param_group['lr'] = lr
249 | for param_group in self.optG.param_groups:
250 | param_group['lr'] = lr
251 | print('update learning rate: %f -> %f' % (self.old_lr, lr))
252 | self.old_lr = lr
--------------------------------------------------------------------------------
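The training step in optimize_parameters above couples the two networks: the PreNet produces an initial prediction from the interpolated input, and the diffusion model is trained on the residual between the ground truth and that prediction. The data flow, extracted as a sketch (the string keys are the ones the code uses; net_p/net_g stand in for self.netP/self.netG):

    IP = net_p(data['SR'], time=None)   # initial prediction from the PreNet
    data['IP'] = IP
    data['RS'] = data['HR'] - IP        # residual becomes x_start in GaussianDiffusion.p_losses
    l_pix = net_g(data)                 # denoiser is trained to match the noise added to the residual
    (l_pix.sum() / data['HR'].numel()).backward()  # one backward pass feeds both optP and optG
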
/model/networks.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | from torch.nn import modules
7 | logger = logging.getLogger('base')
8 | ####################
9 | # initialize
10 | ####################
11 |
12 |
13 | def weights_init_normal(m, std=0.02):
14 | classname = m.__class__.__name__
15 | if classname.find('Conv') != -1:
16 | init.normal_(m.weight.data, 0.0, std)
17 | if m.bias is not None:
18 | m.bias.data.zero_()
19 | elif classname.find('Linear') != -1:
20 | init.normal_(m.weight.data, 0.0, std)
21 | if m.bias is not None:
22 | m.bias.data.zero_()
23 | elif classname.find('BatchNorm2d') != -1:
24 | init.normal_(m.weight.data, 1.0, std)  # BatchNorm weights are also drawn from a normal distribution
25 | init.constant_(m.bias.data, 0.0)
26 |
27 |
28 | def weights_init_kaiming(m, scale=1):
29 | classname = m.__class__.__name__
30 | if classname.find('Conv2d') != -1:
31 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
32 | m.weight.data *= scale
33 | if m.bias is not None:
34 | m.bias.data.zero_()
35 | elif classname.find('Linear') != -1:
36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
37 | m.weight.data *= scale
38 | if m.bias is not None:
39 | m.bias.data.zero_()
40 | elif classname.find('BatchNorm2d') != -1:
41 | init.constant_(m.weight.data, 1.0)
42 | init.constant_(m.bias.data, 0.0)
43 |
44 |
45 | def weights_init_orthogonal(m):
46 | classname = m.__class__.__name__
47 | if classname.find('Conv') != -1:
48 | init.orthogonal_(m.weight.data, gain=1)
49 | if m.bias is not None:
50 | m.bias.data.zero_()
51 | elif classname.find('Linear') != -1:
52 | init.orthogonal_(m.weight.data, gain=1)
53 | if m.bias is not None:
54 | m.bias.data.zero_()
55 | elif classname.find('BatchNorm2d') != -1:
56 | init.constant_(m.weight.data, 1.0)
57 | init.constant_(m.bias.data, 0.0)
58 |
59 |
60 | def init_weights(net, init_type='kaiming', scale=1, std=0.02):
61 | # scale for 'kaiming', std for 'normal'.
62 | logger.info('Initialization method [{:s}]'.format(init_type))
63 | if init_type == 'normal':
64 | weights_init_normal_ = functools.partial(weights_init_normal, std=std)
65 | net.apply(weights_init_normal_)
66 | elif init_type == 'kaiming':
67 | weights_init_kaiming_ = functools.partial(
68 | weights_init_kaiming, scale=scale)
69 | net.apply(weights_init_kaiming_)
70 | elif init_type == 'orthogonal':
71 | net.apply(weights_init_orthogonal)
72 | else:
73 | raise NotImplementedError(
74 | 'initialization method [{:s}] not implemented'.format(init_type))
75 |
76 |
77 | ####################
78 | # define network
79 | ####################
80 | def define_P(opt):
81 | model_opt = opt['model']
82 | if model_opt['which_model_G'] == 'ddpm':
83 | from .ddpm_modules import diffusion, unet
84 | elif model_opt['which_model_G'] == 'sr3':
85 | from .sr3_modules import diffusion, unet
86 | if ('norm_groups' not in model_opt['unet']['PreNet']) or model_opt['unet']['PreNet']['norm_groups'] is None:
87 | model_opt['unet']['PreNet']['norm_groups'] = 32
88 | model = unet.UNet(
89 | in_channel=model_opt['unet']['PreNet']['in_channel'],
90 | out_channel=model_opt['unet']['PreNet']['out_channel'],
91 | norm_groups=model_opt['unet']['PreNet']['norm_groups'],
92 | inner_channel=model_opt['unet']['PreNet']['inner_channel'],
93 | channel_mults=model_opt['unet']['PreNet']['channel_multiplier'],
94 | attn_res=model_opt['unet']['PreNet']['attn_res'],
95 | res_blocks=model_opt['unet']['PreNet']['res_blocks'],
96 | dropout=model_opt['unet']['PreNet']['dropout'],
97 | with_noise_level_emb=False,
98 | image_size=model_opt['diffusion']['image_size']
99 | )
100 |
101 | if opt['phase'] == 'train':
102 | # init_weights(netG, init_type='kaiming', scale=0.1)
103 | init_weights(model, init_type='orthogonal')
104 | if opt['gpu_ids'] and opt['distributed']:
105 | assert torch.cuda.is_available()
106 | model = nn.DataParallel(model)
107 | return model
108 |
109 | # Generator
110 | def define_G(opt):
111 | model_opt = opt['model']
112 | if model_opt['which_model_G'] == 'ddpm':
113 | from .ddpm_modules import diffusion, unet
114 | elif model_opt['which_model_G'] == 'sr3':
115 | from .sr3_modules import diffusion, unet
116 | if ('norm_groups' not in model_opt['unet']['DenoiseNet']) or model_opt['unet']['DenoiseNet']['norm_groups'] is None:
117 | model_opt['unet']['DenoiseNet']['norm_groups'] = 32
118 | model = unet.UNet(
119 | in_channel=model_opt['unet']['DenoiseNet']['in_channel'],
120 | out_channel=model_opt['unet']['DenoiseNet']['out_channel'],
121 | norm_groups=model_opt['unet']['DenoiseNet']['norm_groups'],
122 | inner_channel=model_opt['unet']['DenoiseNet']['inner_channel'],
123 | channel_mults=model_opt['unet']['DenoiseNet']['channel_multiplier'],
124 | attn_res=model_opt['unet']['DenoiseNet']['attn_res'],
125 | res_blocks=model_opt['unet']['DenoiseNet']['res_blocks'],
126 | dropout=model_opt['unet']['DenoiseNet']['dropout'],
127 | image_size=model_opt['diffusion']['image_size']
128 | )
129 | netG = diffusion.GaussianDiffusion(
130 | model,
131 | image_size=model_opt['diffusion']['image_size'],
132 | channels=model_opt['diffusion']['channels'],
133 | loss_type='l1', # L1 or L2
134 | conditional=model_opt['diffusion']['conditional'],
135 | schedule_opt=model_opt['beta_schedule']['train']
136 | )
137 | if opt['phase'] == 'train':
138 | # init_weights(netG, init_type='kaiming', scale=0.1)
139 | init_weights(netG, init_type='orthogonal')
140 | if opt['gpu_ids'] and opt['distributed']:
141 | assert torch.cuda.is_available()
142 | netG = nn.DataParallel(netG)
143 | return netG
144 |
145 |
--------------------------------------------------------------------------------
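define_P and define_G read a fixed set of config keys: which_model_G selects between ddpm_modules and sr3_modules, and each U-Net is configured from its own sub-block. A minimal fragment showing the expected shape of the "model" section (all values are placeholders; norm_groups falls back to 32 when omitted):

    "model": {
        "which_model_G": "sr3",
        "unet": {
            "PreNet": {
                "in_channel": 1, "out_channel": 1, "inner_channel": 32,
                "channel_multiplier": [1, 2, 4, 8, 8], "attn_res": [8],
                "res_blocks": 3, "dropout": 0
            },
            "DenoiseNet": {
                "in_channel": 2, "out_channel": 1, "inner_channel": 64,
                "channel_multiplier": [1, 2, 4, 8, 8], "attn_res": [8],
                "res_blocks": 3, "dropout": 0.2
            }
        },
        "diffusion": { "image_size": 128, "channels": 1, "conditional": true },
        "beta_schedule": { "train": { "schedule": "linear", "n_timestep": 2000,
                                      "linear_start": 1e-4, "linear_end": 2e-2 } }
    }
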
/model/sr3_modules/diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import device, nn, einsum
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 | from functools import partial
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 |
11 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
12 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
13 | warmup_time = int(n_timestep * warmup_frac)
14 | betas[:warmup_time] = np.linspace(
15 | linear_start, linear_end, warmup_time, dtype=np.float64)
16 | return betas
17 |
18 |
19 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
20 | if schedule == 'quad':
21 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
22 | n_timestep, dtype=np.float64) ** 2
23 | elif schedule == 'linear':
24 | betas = np.linspace(linear_start, linear_end,
25 | n_timestep, dtype=np.float64)
26 | elif schedule == 'warmup10':
27 | betas = _warmup_beta(linear_start, linear_end,
28 | n_timestep, 0.1)
29 | elif schedule == 'warmup50':
30 | betas = _warmup_beta(linear_start, linear_end,
31 | n_timestep, 0.5)
32 | elif schedule == 'const':
33 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
34 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
35 | betas = 1. / np.linspace(n_timestep,
36 | 1, n_timestep, dtype=np.float64)
37 | elif schedule == "cosine":
38 | timesteps = (
39 | torch.arange(n_timestep + 1, dtype=torch.float64) /
40 | n_timestep + cosine_s
41 | )
42 | alphas = timesteps / (1 + cosine_s) * math.pi / 2
43 | alphas = torch.cos(alphas).pow(2)
44 | alphas = alphas / alphas[0]
45 | betas = 1 - alphas[1:] / alphas[:-1]
46 | betas = betas.clamp(max=0.999)
47 | else:
48 | raise NotImplementedError(schedule)
49 | return betas
50 |
51 |
52 | # gaussian diffusion trainer class
53 |
54 | def exists(x):
55 | return x is not None
56 |
57 |
58 | def default(val, d):
59 | if exists(val):
60 | return val
61 | return d() if isfunction(d) else d
62 |
63 |
64 | class GaussianDiffusion(nn.Module):
65 | def __init__(
66 | self,
67 | denoise_fn,
68 | image_size,
69 | channels=1,
70 | loss_type='l1',
71 | conditional=True,
72 | schedule_opt=None
73 | ):
74 | super().__init__()
75 | self.channels = channels
76 | self.image_size = image_size
77 | self.denoise_fn = denoise_fn
78 | self.loss_type = loss_type
79 | self.conditional = conditional
80 | if schedule_opt is not None:
81 | pass
82 | # self.set_new_noise_schedule(schedule_opt)
83 |
84 | def set_loss(self, device):
85 | if self.loss_type == 'l1':
86 | self.loss_func = nn.L1Loss(reduction='sum').to(device)
87 | elif self.loss_type == 'l2':
88 | self.loss_func = nn.MSELoss(reduction='sum').to(device)
89 | else:
90 | raise NotImplementedError()
91 |
92 | def set_new_noise_schedule(self, schedule_opt, device):
93 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
94 |
95 | betas = make_beta_schedule(
96 | schedule=schedule_opt['schedule'],
97 | n_timestep=schedule_opt['n_timestep'],
98 | linear_start=schedule_opt['linear_start'],
99 | linear_end=schedule_opt['linear_end'])
100 | betas = betas.detach().cpu().numpy() if isinstance(
101 | betas, torch.Tensor) else betas
102 | alphas = 1. - betas
103 | alphas_cumprod = np.cumprod(alphas, axis=0)
104 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
105 | self.sqrt_alphas_cumprod_prev = np.sqrt(
106 | np.append(1., alphas_cumprod))
107 |
108 | timesteps, = betas.shape
109 | self.num_timesteps = int(timesteps)
110 | self.register_buffer('betas', to_torch(betas))
111 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
112 | self.register_buffer('alphas_cumprod_prev',
113 | to_torch(alphas_cumprod_prev))
114 |
115 | # calculations for diffusion q(x_t | x_{t-1}) and others
116 | self.register_buffer('sqrt_alphas_cumprod',
117 | to_torch(np.sqrt(alphas_cumprod)))
118 | self.register_buffer('sqrt_one_minus_alphas_cumprod',
119 | to_torch(np.sqrt(1. - alphas_cumprod)))
120 | self.register_buffer('log_one_minus_alphas_cumprod',
121 | to_torch(np.log(1. - alphas_cumprod)))
122 | self.register_buffer('sqrt_recip_alphas_cumprod',
123 | to_torch(np.sqrt(1. / alphas_cumprod)))
124 | self.register_buffer('sqrt_recipm1_alphas_cumprod',
125 | to_torch(np.sqrt(1. / alphas_cumprod - 1)))
126 |
127 | # calculations for posterior q(x_{t-1} | x_t, x_0)
128 | posterior_variance = betas * \
129 | (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
130 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
131 | self.register_buffer('posterior_variance',
132 | to_torch(posterior_variance))
133 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
134 | self.register_buffer('posterior_log_variance_clipped', to_torch(
135 | np.log(np.maximum(posterior_variance, 1e-20))))
136 | self.register_buffer('posterior_mean_coef1', to_torch(
137 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
138 | self.register_buffer('posterior_mean_coef2', to_torch(
139 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
140 |
141 | def predict_start_from_noise(self, x_t, t, noise):
142 | return self.sqrt_recip_alphas_cumprod[t] * x_t - \
143 | self.sqrt_recipm1_alphas_cumprod[t] * noise
144 |
145 | def q_posterior(self, x_start, x_t, t):
146 | posterior_mean = self.posterior_mean_coef1[t] * \
147 | x_start + self.posterior_mean_coef2[t] * x_t
148 | posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
149 | return posterior_mean, posterior_log_variance_clipped
150 |
151 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None):
152 | batch_size = x.shape[0]
153 | noise_level = torch.FloatTensor(
154 | [self.sqrt_alphas_cumprod_prev[t+1]]).repeat(batch_size, 1).to(x.device)
155 | if condition_x is not None:
156 | x_recon = self.predict_start_from_noise(
157 | x, t=t, noise=self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level))
158 | else:
159 | x_recon = self.predict_start_from_noise(
160 | x, t=t, noise=self.denoise_fn(x, noise_level))
161 |
162 | if clip_denoised:
163 | x_recon.clamp_(-1., 1.)
164 |
165 | model_mean, posterior_log_variance = self.q_posterior(
166 | x_start=x_recon, x_t=x, t=t)
167 | return model_mean, posterior_log_variance
168 |
169 | @torch.no_grad()
170 | def p_sample(self, x, t, clip_denoised=True, condition_x=None):
171 | model_mean, model_log_variance = self.p_mean_variance(
172 | x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
173 | noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
174 | return model_mean + noise * (0.5 * model_log_variance).exp()
175 |
176 | @torch.no_grad()
177 | def p_sample_loop(self, x_in, continous=False):
178 | device = self.betas.device
179 | sample_inter = (1 | (self.num_timesteps//10))
180 | if not self.conditional:
181 | shape = x_in
182 | img = torch.randn(shape, device=device)
183 | ret_img = img
184 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
185 | img = self.p_sample(img, i)
186 | if i % sample_inter == 0:
187 | ret_img = torch.cat([ret_img, img], dim=0)
188 | else:
189 | x = x_in
190 | shape = x.shape
191 | img = torch.randn(shape, device=device)
192 | ret_img = x
193 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
194 | img = self.p_sample(img, i, condition_x=x)
195 | if i % sample_inter == 0:
196 | ret_img = torch.cat([ret_img, img], dim=0)
197 | if continous:
198 | return ret_img
199 | else:
200 | return ret_img[-1]
201 |
202 | @torch.no_grad()
203 | def sample(self, batch_size=1, continous=False):
204 | image_size = self.image_size
205 | channels = self.channels
206 | return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)
207 |
208 | @torch.no_grad()
209 | def super_resolution(self, x_in, continous=False):
210 | return self.p_sample_loop(x_in, continous)
211 |
212 | def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None):
213 | noise = default(noise, lambda: torch.randn_like(x_start))
214 |
215 | # random gama
216 | return (
217 | continuous_sqrt_alpha_cumprod * x_start +
218 | (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * noise
219 | )
220 |
221 | def p_losses(self, x_in, noise=None):
222 | # x_start = x_in['IP']
223 | x_start = x_in['RS']
224 | [b, c, h, w] = x_start.shape
225 | t = np.random.randint(1, self.num_timesteps + 1)
226 | continuous_sqrt_alpha_cumprod = torch.FloatTensor(
227 | np.random.uniform(
228 | self.sqrt_alphas_cumprod_prev[t-1],
229 | self.sqrt_alphas_cumprod_prev[t],
230 | size=b
231 | )
232 | ).to(x_start.device)
233 | continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(
234 | b, -1)
235 |
236 | noise = default(noise, lambda: torch.randn_like(x_start))
237 | x_noisy = self.q_sample(
238 | x_start=x_start, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise=noise)
239 |
240 | if not self.conditional:
241 | x_recon = self.denoise_fn(x_noisy, continuous_sqrt_alpha_cumprod)
242 | else:
243 | x_recon = self.denoise_fn(torch.cat([x_in['SR'], x_noisy], dim=1), continuous_sqrt_alpha_cumprod)
244 | loss = self.loss_func(noise, x_recon)
245 | # loss = self.loss_func(x_start,x_in['HR'])
246 | return loss
247 |
248 | def forward(self, x, *args, **kwargs):
249 | return self.p_losses(x, *args, **kwargs)
250 |
251 |
--------------------------------------------------------------------------------
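A minimal end-to-end use of this module, wiring it to the companion U-Net below (tiny sizes and CPU so it merely demonstrates the API; every value here is illustrative, not a recommended setting):

    import torch
    from model.sr3_modules import diffusion, unet

    net = unet.UNet(in_channel=2, out_channel=1, inner_channel=32, norm_groups=16,
                    channel_mults=(1, 2, 4), attn_res=(8,), res_blocks=1, image_size=32)
    gd = diffusion.GaussianDiffusion(net, image_size=32, channels=1,
                                     loss_type='l1', conditional=True)
    gd.set_loss('cpu')
    gd.set_new_noise_schedule({'schedule': 'linear', 'n_timestep': 50,
                               'linear_start': 1e-4, 'linear_end': 2e-2}, device='cpu')

    lr_up = torch.randn(1, 1, 32, 32)   # an upsampled low-resolution image would go here
    sr = gd.super_resolution(lr_up)     # 50 reverse steps, conditioned on lr_up
    print(sr.shape)
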
/model/sr3_modules/unet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 |
7 |
8 | def exists(x):
9 | return x is not None
10 |
11 |
12 | def default(val, d):
13 | if exists(val):
14 | return val
15 | return d() if isfunction(d) else d
16 |
17 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
18 | class PositionalEncoding(nn.Module):
19 | def __init__(self, dim):
20 | super().__init__()
21 | self.dim = dim
22 |
23 | def forward(self, noise_level):
24 | count = self.dim // 2
25 | step = torch.arange(count, dtype=noise_level.dtype,
26 | device=noise_level.device) / count
27 | encoding = noise_level.unsqueeze(
28 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
29 | encoding = torch.cat(
30 | [torch.sin(encoding), torch.cos(encoding)], dim=-1)
31 | return encoding
32 |
33 |
34 | class FeatureWiseAffine(nn.Module):
35 | def __init__(self, in_channels, out_channels, use_affine_level=False):
36 | super(FeatureWiseAffine, self).__init__()
37 | self.use_affine_level = use_affine_level
38 | self.noise_func = nn.Sequential(
39 | nn.Linear(in_channels, out_channels*(1+self.use_affine_level))
40 | )
41 |
42 | def forward(self, x, noise_embed):
43 | batch = x.shape[0]
44 | if self.use_affine_level:
45 | gamma, beta = self.noise_func(noise_embed).view(
46 | batch, -1, 1, 1).chunk(2, dim=1)
47 | x = (1 + gamma) * x + beta
48 | else:
49 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
50 | return x
51 |
52 |
53 | class Swish(nn.Module):
54 | def forward(self, x):
55 | return x * torch.sigmoid(x)
56 |
57 |
58 | class Upsample(nn.Module):
59 | def __init__(self, dim):
60 | super().__init__()
61 | self.up = nn.Upsample(scale_factor=2, mode="nearest")
62 | self.conv = nn.Conv2d(dim, dim, 3, padding=1)
63 |
64 | def forward(self, x):
65 | return self.conv(self.up(x))
66 |
67 |
68 | class Downsample(nn.Module):
69 | def __init__(self, dim):
70 | super().__init__()
71 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1)
72 |
73 | def forward(self, x):
74 | return self.conv(x)
75 |
76 |
77 | # building block modules
78 |
79 |
80 | class Block(nn.Module):
81 | def __init__(self, dim, dim_out, groups=32, dropout=0):
82 | super().__init__()
83 | self.block = nn.Sequential(
84 | nn.GroupNorm(groups, dim),
85 | Swish(),
86 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
87 | nn.Conv2d(dim, dim_out, 3, padding=1)
88 | )
89 |
90 | def forward(self, x):
91 | return self.block(x)
92 |
93 |
94 | class ResnetBlock(nn.Module):
95 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
96 | super().__init__()
97 | if noise_level_emb_dim is not None:
98 | self.noise_func = FeatureWiseAffine(
99 | noise_level_emb_dim, dim_out, use_affine_level)
100 |
101 | self.block1 = Block(dim, dim_out, groups=norm_groups)
102 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
103 | self.res_conv = nn.Conv2d(
104 | dim, dim_out, 1) if dim != dim_out else nn.Identity()
105 |
106 | def forward(self, x, time_emb):
107 | b, c, h, w = x.shape
108 | h = self.block1(x)
109 | if time_emb is not None:
110 | h = self.noise_func(h, time_emb)
111 | h = self.block2(h)
112 | return h + self.res_conv(x)
113 |
114 |
115 | class SelfAttention(nn.Module):
116 | def __init__(self, in_channel, n_head=1, norm_groups=32):
117 | super().__init__()
118 |
119 | self.n_head = n_head
120 |
121 | self.norm = nn.GroupNorm(norm_groups, in_channel)
122 | self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
123 | self.out = nn.Conv2d(in_channel, in_channel, 1)
124 |
125 | def forward(self, input):
126 | batch, channel, height, width = input.shape
127 | n_head = self.n_head
128 | head_dim = channel // n_head
129 |
130 | norm = self.norm(input)
131 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
132 | query, key, value = qkv.chunk(3, dim=2) # bhdyx
133 |
134 |         attn = torch.einsum(
135 |             "bnchw, bncyx -> bnhwyx", query, key
136 |         ).contiguous() / math.sqrt(channel)  # dot product of each query pixel (h,w) with each key pixel (y,x)
137 |         attn = attn.view(batch, n_head, height, width, -1)
138 |         attn = torch.softmax(attn, -1)  # normalise over all key positions
139 |         attn = attn.view(batch, n_head, height, width, height, width)
140 | 
141 |         out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()  # attention-weighted sum of values
142 | out = self.out(out.view(batch, channel, height, width))
143 |
144 | return out + input
145 |
146 |
147 | class ResnetBlocWithAttn(nn.Module):
148 | def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
149 | super().__init__()
150 |
151 | self.with_attn = with_attn
152 | self.res_block = ResnetBlock(
153 | dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
154 | if with_attn:
155 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
156 |
157 | def forward(self, x, time_emb):
158 | x = self.res_block(x, time_emb)
159 |         if self.with_attn:
160 | x = self.attn(x)
161 | return x
162 |
163 |
164 | class UNet(nn.Module):
165 | def __init__(
166 | self,
167 | in_channel=6,
168 | out_channel=3,
169 | inner_channel=32,
170 | norm_groups=32,
171 | channel_mults=(1, 2, 4, 8, 8),
172 |         attn_res=(8,),  # trailing comma makes this a tuple; (8) is just the int 8 and the 'in' checks below would fail on it
173 | res_blocks=3,
174 | dropout=0,
175 | with_noise_level_emb=True,
176 | image_size=128
177 | ):
178 | super().__init__()
179 |
180 | if with_noise_level_emb:
181 | noise_level_channel = inner_channel
182 | self.noise_level_mlp = nn.Sequential(
183 | PositionalEncoding(inner_channel),
184 | nn.Linear(inner_channel, inner_channel * 4),
185 | Swish(),
186 | nn.Linear(inner_channel * 4, inner_channel)
187 | )
188 | else:
189 | noise_level_channel = None
190 | self.noise_level_mlp = None
191 |
192 |         num_mults = len(channel_mults)
193 |         pre_channel = inner_channel
194 |         feat_channels = [pre_channel]  # channel count of every skip tensor produced on the way down
195 |         now_res = image_size
196 |         downs = [nn.Conv2d(in_channel, inner_channel,
197 |                            kernel_size=3, padding=1)]  # stem conv; its output is the first skip
198 | for ind in range(num_mults):
199 | is_last = (ind == num_mults - 1)
200 | use_attn = (now_res in attn_res)
201 | channel_mult = inner_channel * channel_mults[ind]
202 | for _ in range(0, res_blocks):
203 | downs.append(ResnetBlocWithAttn(
204 | pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn))
205 | feat_channels.append(channel_mult)
206 | pre_channel = channel_mult
207 | if not is_last:
208 | downs.append(Downsample(pre_channel))
209 | feat_channels.append(pre_channel)
210 | now_res = now_res//2
211 | self.downs = nn.ModuleList(downs)
212 |
213 | self.mid = nn.ModuleList([
214 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
215 | dropout=dropout, with_attn=True),
216 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
217 | dropout=dropout, with_attn=False)
218 | ])
219 |
220 | ups = []
221 | for ind in reversed(range(num_mults)):
222 | is_last = (ind < 1)
223 | use_attn = (now_res in attn_res)
224 | channel_mult = inner_channel * channel_mults[ind]
225 |         for _ in range(0, res_blocks+1):  # one extra block per level consumes the skip saved by the Downsample (or the stem conv)
226 | ups.append(ResnetBlocWithAttn(
227 | pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
228 | dropout=dropout, with_attn=use_attn))
229 | pre_channel = channel_mult
230 | if not is_last:
231 | ups.append(Upsample(pre_channel))
232 | now_res = now_res*2
233 |
234 | self.ups = nn.ModuleList(ups)
235 |
236 |         self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)  # default() falls back to in_channel if out_channel is None
237 |
238 | def forward(self, x, time):
239 | t = self.noise_level_mlp(time) if exists(
240 | self.noise_level_mlp) else None
241 |
242 | feats = []
243 | for layer in self.downs:
244 | if isinstance(layer, ResnetBlocWithAttn):
245 | x = layer(x, t)
246 | else:
247 | x = layer(x)
248 | feats.append(x)
249 |
250 | for layer in self.mid:
251 | if isinstance(layer, ResnetBlocWithAttn):
252 | x = layer(x, t)
253 | else:
254 | x = layer(x)
255 |
256 | for layer in self.ups:
257 | if isinstance(layer, ResnetBlocWithAttn):
258 | x = layer(torch.cat((x, feats.pop()), dim=1), t)
259 | else:
260 | x = layer(x)
261 |
262 | return self.final_conv(x)
263 |
--------------------------------------------------------------------------------
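
A quick smoke test for the UNet above (a minimal sketch, not part of the repository; the import path is assumed from the tree layout, and the shapes follow the 16->128 configs, where the 6 input channels are the upsampled LR image concatenated with the noisy HR image):

import torch
from model.sr3_modules.unet import UNet  # import path assumed from the repo tree

net = UNet(in_channel=6, out_channel=3, inner_channel=32,
           channel_mults=(1, 2, 4, 8, 8), attn_res=(8,), res_blocks=3,
           image_size=128)
x = torch.randn(1, 6, 128, 128)  # upsampled LR and noisy HR, concatenated on channels
noise_level = torch.rand(1)      # one continuous noise level per batch element
with torch.no_grad():
    out = net(x, noise_level)
print(out.shape)                 # torch.Size([1, 3, 128, 128]): the predicted noise
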
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch>=1.6
2 | torchvision
3 | numpy
4 | pandas
5 | tqdm
6 | lmdb
7 | opencv-python
8 | pillow
9 | tensorboardx
10 | wandb
11 |
12 |
--------------------------------------------------------------------------------
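
These dependencies install with pip install -r requirement.txt; note the repo names the file requirement.txt (no trailing "s"), so the -r path must match. tensorboardx supplies the SummaryWriter imported by train.py below, and wandb backs the optional -enable_wandb logging.
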
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import data as Data
3 | import model as Model
4 | import argparse
5 | import logging
6 | import core.logger as Logger
7 | import core.metrics as Metrics
8 | from core.wandb_logger import WandbLogger
9 | from tensorboardX import SummaryWriter
10 | import os
11 | import numpy as np
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('-c', '--config', type=str, default='config/sr_sr3_16_128.json',
16 | help='JSON file for configuration')
17 |     parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'],
18 |                         help='Run either train (training) or val (generation)', default='train')
19 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
20 | parser.add_argument('-debug', '-d', action='store_true')
21 | parser.add_argument('-enable_wandb', action='store_true')
22 | parser.add_argument('-log_wandb_ckpt', action='store_true')
23 | parser.add_argument('-log_eval', action='store_true')
24 |
25 | # parse configs
26 | args = parser.parse_args()
27 | opt = Logger.parse(args)
28 |     # Convert to NoneDict, which returns None for missing keys.
29 | opt = Logger.dict_to_nonedict(opt)
30 |
31 | # logging
32 | torch.backends.cudnn.enabled = True
33 | torch.backends.cudnn.benchmark = True
34 |
35 | Logger.setup_logger(None, opt['path']['log'],
36 | 'train', level=logging.INFO, screen=True)
37 | Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
38 | logger = logging.getLogger('base')
39 | logger.info(Logger.dict2str(opt))
40 | tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])
41 |
42 | # Initialize WandbLogger
43 | if opt['enable_wandb']:
44 | import wandb
45 | wandb_logger = WandbLogger(opt)
46 | wandb.define_metric('validation/val_step')
47 | wandb.define_metric('epoch')
48 | wandb.define_metric("validation/*", step_metric="val_step")
49 | val_step = 0
50 | else:
51 | wandb_logger = None
52 |
53 | # dataset
54 | for phase, dataset_opt in opt['datasets'].items():
55 | if phase == 'train' and args.phase != 'val':
56 | train_set = Data.create_dataset(dataset_opt, phase)
57 | train_loader = Data.create_dataloader(
58 | train_set, dataset_opt, phase)
59 | # elif phase == 'val':
60 | # val_set = Data.create_dataset(dataset_opt, phase)
61 | # val_loader = Data.create_dataloader(
62 | # val_set, dataset_opt, phase)
63 |     logger.info('Dataset initialization finished.')
64 |
65 | # model
66 | diffusion = Model.create_model(opt)
67 |     logger.info('Model initialization finished.')
68 |
69 | # Train
70 | current_step = diffusion.begin_step
71 | current_epoch = diffusion.begin_epoch
72 | n_iter = opt['train']['n_iter']
73 |
74 | if opt['path']['resume_state']:
75 | logger.info('Resuming training from epoch: {}, iter: {}.'.format(
76 | current_epoch, current_step))
77 |
78 | diffusion.set_new_noise_schedule(
79 | opt['model']['beta_schedule'][opt['phase']], schedule_phase=opt['phase'])
80 | if opt['phase'] == 'train':
81 | while current_step < n_iter:
82 | current_epoch += 1
83 | for _, train_data in enumerate(train_loader):
84 | current_step += 1
85 | if current_step > n_iter:
86 | break
87 |
88 | diffusion.feed_data(train_data)
89 | diffusion.optimize_parameters()
90 | # log
91 | if current_step % opt['train']['print_freq'] == 0:
92 | logs = diffusion.get_current_log()
93 |                     message = '<epoch:{:3d}, iter:{:8,d}> '.format(
94 |                         current_epoch, current_step)
95 | for k, v in logs.items():
96 | message += '{:s}: {:.4e} '.format(k, v)
97 | tb_logger.add_scalar(k, v, current_step)
98 | logger.info(message)
99 | # diffusion.update_learning_rate()
100 |
101 | if wandb_logger:
102 | wandb_logger.log_metrics(logs)
103 |
104 | # # validation
105 | # if current_step % opt['train']['val_freq'] == 0:
106 | # avg_psnr = 0.0
107 | # idx = 0
108 | # result_path = '{}/{}'.format(opt['path']
109 | # ['results'], current_epoch)
110 | # os.makedirs(result_path, exist_ok=True)
111 | #
112 | # diffusion.set_new_noise_schedule(
113 | # opt['model']['beta_schedule']['val'], schedule_phase='val')
114 | # for _, val_data in enumerate(val_loader):
115 | # idx += 1
116 | # diffusion.feed_data(val_data)
117 | # diffusion.test(continous=False)
118 | # visuals = diffusion.get_current_visuals()
119 | # sr_img = Metrics.tensor2img(visuals['SR']) # uint8
120 | # hr_img = Metrics.tensor2img(visuals['HR']) # uint8
121 | # lr_img = Metrics.tensor2img(visuals['LR']) # uint8
122 | # fake_img = Metrics.tensor2img(visuals['INF']) # uint8
123 | #
124 | # # generation
125 | # Metrics.save_img(
126 | # hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
127 | # Metrics.save_img(
128 | # sr_img, '{}/{}_{}_sr.png'.format(result_path, current_step, idx))
129 | # Metrics.save_img(
130 | # lr_img, '{}/{}_{}_lr.png'.format(result_path, current_step, idx))
131 | # Metrics.save_img(
132 | # fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))
133 | # tb_logger.add_image(
134 | # 'Iter_{}'.format(current_step),
135 | # np.transpose(np.concatenate(
136 | # (fake_img, sr_img, hr_img), axis=1), [2, 0, 1]),
137 | # idx)
138 | # avg_psnr += Metrics.calculate_psnr(
139 | # sr_img, hr_img)
140 | #
141 | # if wandb_logger:
142 | # wandb_logger.log_image(
143 | # f'validation_{idx}',
144 | # np.concatenate((fake_img, sr_img, hr_img), axis=1)
145 | # )
146 | #
147 | # avg_psnr = avg_psnr / idx
148 | # diffusion.set_new_noise_schedule(
149 | # opt['model']['beta_schedule']['train'], schedule_phase='train')
150 | # # log
151 | # logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
152 | # logger_val = logging.getLogger('val') # validation logger
153 | #             logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
154 | #                 current_epoch, current_step, avg_psnr))
155 | # # tensorboard logger
156 | # tb_logger.add_scalar('psnr', avg_psnr, current_step)
157 | #
158 | # if wandb_logger:
159 | # wandb_logger.log_metrics({
160 | # 'validation/val_psnr': avg_psnr,
161 | # 'validation/val_step': val_step
162 | # })
163 | # val_step += 1
164 |
165 | if current_step % opt['train']['save_checkpoint_freq'] == 0:
166 | logger.info('Saving models and training states.')
167 | diffusion.save_network(current_epoch, current_step)
168 |
169 | if wandb_logger and opt['log_wandb_ckpt']:
170 | wandb_logger.log_checkpoint(current_epoch, current_step)
171 |
172 | if wandb_logger:
173 | wandb_logger.log_metrics({'epoch': current_epoch-1})
174 |
175 |     # training complete
176 | logger.info('End of training.')
--------------------------------------------------------------------------------
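
For reference, a typical invocation given the argparse flags above (a sketch; pick any config from config/):

python train.py -c config/sr_sr3_16_128.json -p train -gpu 0

Add -enable_wandb (and optionally -log_wandb_ckpt / -log_eval) to mirror metrics and checkpoints to Weights & Biases.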