├── thumbnail
│   └── abstract.png
├── __pycache__
│   ├── util.cpython-36.pyc
│   ├── dataset.cpython-36.pyc
│   ├── model.cpython-36.pyc
│   └── networks.cpython-36.pyc
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── .gitignore
│   ├── modules.xml
│   ├── VEUS.iml
│   └── deployment.xml
├── README.md
├── train.py
├── dataset.py
├── util.py
├── networks.py
└── model.py
/thumbnail/abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yyyzzzhao/VEUS/HEAD/thumbnail/abstract.png
--------------------------------------------------------------------------------
/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yyyzzzhao/VEUS/HEAD/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yyyzzzhao/VEUS/HEAD/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yyyzzzhao/VEUS/HEAD/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/networks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yyyzzzhao/VEUS/HEAD/__pycache__/networks.cpython-36.pyc
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../../:\github\VEUS\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VEUS
2 | Code for ["Virtual elastography ultrasound via generative adversarial network for breast cancer diagnosis"](https://www.nature.com/articles/s41467-023-36102-1.pdf).
3 |
4 | If you find this work useful in your research, please cite
5 | ```
6 | @article{yao2023virtual,
7 | title={Virtual elastography ultrasound via generative adversarial network for breast cancer diagnosis},
8 | author={Yao, Zhao and Luo, Ting and Dong, YiJie and Jia, XiaoHong and Deng, YinHui and Wu, GuoQing and Zhu, Ying and Zhang, JingWen and Liu, Juan and Yang, LiChun and others},
9 | journal={Nature Communications},
10 | volume={14},
11 | number={1},
12 | pages={788},
13 | year={2023},
14 | publisher={Nature Publishing Group UK London}
15 | }
16 | ```
17 |
18 | ![abstract](thumbnail/abstract.png)
19 |
20 | ### Prerequisites
21 |
22 | * python=3.6
23 | * pytorch=1.10.0
24 | * CUDA=10.2
25 | * torchvision=0.11.1
26 | * other dependencies (e.g., visdom, dominate)
27 |
28 | ## Getting started
29 |
30 | * clone this repository:
31 | ```
32 | git clone git@github.com:yyyzzzhao/VEUS.git
33 | ```
34 |
35 | ### Usage
36 | ```
37 | python train.py
38 | ```
39 |
40 | We partially borrowed the [Pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) model for this project; many thanks to the authors.
41 |
42 |
--------------------------------------------------------------------------------
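For reference, the training flags all live in `get_parser()` in `train.py` below, and the `Visualizer` in `util.py` expects a visdom server on localhost:8097. A minimal sketch of a typical invocation, with the data path as a placeholder:

```
# start visdom first; train.py plots losses and sample images through it
python -m visdom.server -p 8097

# train on a folder <data_root>/train/ of A|B|M triptych images
python train.py --data_root /path/to/mask_dataset --name my_experiment --batch_size 1
```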
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import argparse
3 | import torch
4 | from model import EnhanceGANModel
5 | from dataset import AlignedDataset
6 | from util import Visualizer
7 | 
8 | 
9 | def get_parser():
10 |     parser = argparse.ArgumentParser()
11 |     parser.add_argument('--data_root', default=r'E:\Database\RUIJIN_data\03_DATASET\mask_dataset', type=str, help='path/to/data')
12 |     parser.add_argument('--EnhanceT', default=True, help='use an extra discriminator on the tumor region')
13 | 
14 |     # training related
15 |     parser.add_argument('--isTrain', default=True, help='training phase flag')
16 |     parser.add_argument('--epoch_count', default=1, type=int, help='the starting epoch count')
17 |     parser.add_argument('--lr', default=0.0002, type=float, help='initial learning rate')
18 |     parser.add_argument('--lr_policy', default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
19 |     parser.add_argument('--batch_size', default=1, type=int, help='input batch size')
20 |     parser.add_argument('--niter', default=100, type=int, help='# of epochs at the starting learning rate')
21 |     parser.add_argument('--niter_decay', default=100, type=int, help='# of epochs to linearly decay the learning rate to zero')
22 |     parser.add_argument('--epoch', default='latest', help='which epoch to load when testing')
23 |     parser.add_argument('--gpu_ids', default=[0], help='which device, e.g. [0]; use [] for CPU')
24 | 
25 |     parser.add_argument('--checkpoints_dir', default='./checkpoints', help='models are saved here')
26 |     parser.add_argument('--name', default='experiment_name', help='name of the experiment')
27 |     return parser.parse_args()
28 | 
29 | 
30 | if __name__ == '__main__':
31 |     opt = get_parser()
32 |     dataset = AlignedDataset(opt.data_root, 'train')
33 |     dataset_size = len(dataset)
34 |     dataset = torch.utils.data.DataLoader(
35 |         dataset,
36 |         batch_size=opt.batch_size,
37 |         shuffle=True,
38 |         num_workers=0)
39 |     print('The number of training images = %d' % dataset_size)
40 | 
41 |     model = EnhanceGANModel(opt)
42 |     model.setup(opt)
43 | 
44 |     visualizer = Visualizer(opt)
45 |     total_iters = 0
46 | 
47 |     for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
48 |         epoch_start_time = time.time()  # timer for the entire epoch
49 |         iter_data_time = time.time()    # timer for data loading per iteration
50 |         epoch_iter = 0                  # the number of training iterations in the current epoch, reset to 0 every epoch
51 | 
52 |         for i, data in enumerate(dataset):  # inner loop within one epoch
53 |             iter_start_time = time.time()   # timer for computation per iteration
54 |             if total_iters % 100 == 0:
55 |                 t_data = iter_start_time - iter_data_time
56 |             visualizer.reset()
57 |             total_iters += opt.batch_size
58 |             epoch_iter += opt.batch_size
59 |             model.set_input(data)        # unpack data from dataset and apply preprocessing
60 |             model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
61 | 
62 |             if total_iters % 400 == 0:   # display images on visdom and save images to an HTML file
63 |                 save_result = total_iters % 1000 == 0
64 |                 model.compute_visuals()
65 |                 visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
66 | 
67 |             if total_iters % 100 == 0:   # print training losses and save logging information to the disk
68 |                 losses = model.get_current_losses()
69 |                 t_comp = (time.time() - iter_start_time) / opt.batch_size
70 |                 visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
71 |                 visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
72 | 
73 |             if total_iters % 5000 == 0:  # cache the latest model every 5000 iterations
74 |                 print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
75 |                 save_suffix = 'latest'
76 |                 model.save_networks(save_suffix)
77 | 
78 |             iter_data_time = time.time()
79 |         if epoch % 5 == 0:               # cache the model every 5 epochs
80 |             print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
81 |             model.save_networks('latest')
82 |             model.save_networks(epoch)
83 | 
84 |         print('End of epoch %d / %d \t Time Taken: %d sec' % (
85 |             epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
86 |         model.update_learning_rate()
87 |
88 |
89 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision
3 | import os.path
4 | import torch.utils.data as data
5 | import pandas as pd
6 | from PIL import Image
7 | from util import *
8 | import numpy as np
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 | 
12 | 
13 | IMG_EXTENSIONS = [
14 |     '.jpg', '.JPG', '.jpeg', '.JPEG',
15 |     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | ]
17 | 
18 | 
19 | def is_image_file(filename):
20 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 | 
22 | 
23 | def make_dataset(dir, max_dataset_size=float("inf")):
24 |     images = []
25 |     assert os.path.isdir(dir), '%s is not a valid directory' % dir
26 | 
27 |     for root, _, fnames in sorted(os.walk(dir)):
28 |         for fname in fnames:
29 |             if is_image_file(fname):
30 |                 path = os.path.join(root, fname)
31 |                 images.append(path)
32 |     return images[:min(max_dataset_size, len(images))]
33 | 
34 | 
35 | class AlignedDataset(data.Dataset):
36 |     """A dataset class for paired image datasets.
37 | 
38 |     It assumes that the directory '/path/to/data/train' contains image triplets in the form {A, B, M}.
39 |     During test time, you need to prepare a directory '/path/to/data/test'.
40 |     """
41 | 
42 |     def __init__(self, data_root, phase='train'):
43 |         """Initialize this dataset class.
44 | 
45 |         Parameters:
46 |             data_root (str), phase (str) -- dataset root directory and the sub-directory ('train' or 'test') to load
47 |         """
48 |         super(AlignedDataset, self).__init__()
49 |         self.data_root = data_root
50 |         self.phase = phase
51 | 
52 |         self.load_size = 286
53 |         self.crop_size = 256
54 |         self.input_nc = 1
55 |         self.output_nc = 3
56 | 
57 |         self.dir_ABM = os.path.join(self.data_root, self.phase)  # get the image directory
58 |         self.ABM_paths = sorted(make_dataset(self.dir_ABM))  # get image paths
59 | 
60 |     def __getitem__(self, index):
61 |         """Return a data point and its metadata information.
62 | 
63 |         Parameters:
64 |             index -- a random integer for data indexing
65 | 
66 |         Returns a dictionary that contains A, B, M, bbox
67 |             A (tensor) -- an image in the input domain
68 |             B (tensor) -- its corresponding image in the target domain
69 |             M (tensor) -- tumor area mask
70 |             bbox (ndarray) -- tumor bounding box (y0, x0, y1, x1) computed from M
71 |         (the source path of the triplet is self.ABM_paths[index])
72 |         """
73 |         # read an image given a random integer index
74 |         ABM_path = self.ABM_paths[index]
75 |         ABM = Image.open(ABM_path).convert('RGB')
76 |         # split the ABM image into A, B and Mask (equal thirds, left to right)
77 |         w, h = ABM.size
78 |         w2 = int(w / 3)
79 |         A = ABM.crop((0, 0, w2, h))  # left, upper, right, lower
80 |         B = ABM.crop((w2, 0, 2 * w2, h))
81 |         M = ABM.crop((2 * w2, 0, w, h))
82 | 
83 |         # apply the same transform to A, B and M
84 |         transform_params = self.get_params(A.size)
85 |         A_transform = self.get_transform(transform_params, grayscale=(self.input_nc == 1))
86 |         B_transform = self.get_transform(transform_params, grayscale=(self.output_nc == 1))
87 |         M_transform = self.get_transform(transform_params, grayscale=True)
88 | 
89 |         A = A_transform(A)
90 |         B = B_transform(B)
91 |         M = M_transform(M)
92 | 
93 |         # get bounding box coordinates of the tumor area
94 |         M_num = M.numpy()  # values in (-1, 1) after normalization
95 |         M_num = np.squeeze(M_num)
96 |         logical_M = np.where(M_num < 0, 0, 1)
97 |         bbox = extract_bboxes(logical_M)  # (y0, x0, y1, x1)
98 | 
99 |         return {'A': A, 'B': B, 'M': M, 'bbox': bbox}
100 | 
101 |     def __len__(self):
102 |         """Return the total number of images in the dataset."""
103 |         return len(self.ABM_paths)
104 | 
105 |     # misc
106 |     def get_params(self, size):
107 |         w, h = size
108 |         new_h = new_w = self.load_size
109 | 
110 |         x = random.randint(0, np.maximum(0, new_w - self.crop_size))
111 |         y = random.randint(0, np.maximum(0, new_h - self.crop_size))
112 | 
113 |         flip = random.random() > 0.5
114 | 
115 |         return {'crop_pos': (x, y), 'flip': flip}
116 | 
117 |     def get_transform(self, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
118 |         transform_list = []
119 |         if grayscale:
120 |             transform_list.append(transforms.Grayscale(1))
121 | 
122 |         osize = [self.load_size, self.load_size]
123 |         transform_list.append(transforms.Resize(osize, method))
124 | 
125 |         if params is None:
126 |             transform_list.append(transforms.RandomCrop(self.crop_size))
127 |         else:
128 |             transform_list.append(transforms.Lambda(lambda img: my_crop(img, params['crop_pos'], self.crop_size)))
129 | 
130 |         if params is None:
131 |             transform_list.append(transforms.RandomHorizontalFlip())
132 |         elif params['flip']:
133 |             transform_list.append(transforms.Lambda(lambda img: my_flip(img, params['flip'])))
134 | 
135 |         transform_list += [transforms.ToTensor()]
136 |         if convert:
137 |             if grayscale:
138 |                 transform_list += [transforms.Normalize((0.5,), (0.5,))]
139 |             else:
140 |                 transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
141 |         return transforms.Compose(transform_list)
142 | 
143 | 
144 | def my_make_power_2(img, base, method=Image.BICUBIC):
145 |     ow, oh = img.size
146 |     h = int(round(oh / base) * base)
147 |     w = int(round(ow / base) * base)
148 |     if (h == oh) and (w == ow):
149 |         return img
150 | 
151 |     my_print_size_warning(ow, oh, w, h)
152 |     return img.resize((w, h), method)
153 | 
154 | 
155 | def __scale_width(img, target_width, method=Image.BICUBIC):
156 |     ow, oh = img.size
157 |     if (ow == target_width):
158 |         return img
159 |     w = target_width
160 |     h = int(target_width * oh / ow)
161 |     return img.resize((w, h), method)
162 | 
163 | 
164 | def my_crop(img, pos, size):
165 |     ow, oh = img.size
166 |     x1, y1 = pos
167 |     tw = th = size
168 |     if (ow > tw or oh > th):
169 |         return img.crop((x1, y1, x1 + tw, y1 + th))
170 |     return img
171 | 
172 | 
173 | def my_flip(img, flip):
174 |     if flip:
175 |         return img.transpose(Image.FLIP_LEFT_RIGHT)
176 |     return img
177 | 
178 | 
179 | def my_print_size_warning(ow, oh, w, h):
180 |     """Print a warning about the image size (only printed once)"""
181 |     if not hasattr(my_print_size_warning, 'has_printed'):
182 |         print("The image size needs to be a multiple of 4. "
183 |               "The loaded image size was (%d, %d), so it was adjusted to "
184 |               "(%d, %d). This adjustment will be done to all images "
185 |               "whose sizes are not multiples of 4" % (ow, oh, w, h))
186 |         my_print_size_warning.has_printed = True
187 | 
188 |
189 |
190 |
--------------------------------------------------------------------------------
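`AlignedDataset.__getitem__` above assumes every training image is a horizontal triptych cut into equal thirds: B-mode input (A), elastography target (B), and tumor mask (M). A minimal sketch of assembling such a triptych with PIL; the file names are hypothetical placeholders:

```python
from PIL import Image

# A, B and M must share the same size; these paths are placeholders
a = Image.open('bmode.png').convert('RGB')    # input B-mode image
b = Image.open('elasto.png').convert('RGB')   # target elastography image
m = Image.open('mask.png').convert('RGB')     # binary tumor mask

w, h = a.size
abm = Image.new('RGB', (3 * w, h))
abm.paste(a, (0, 0))        # left third   -> A
abm.paste(b, (w, 0))        # middle third -> B
abm.paste(m, (2 * w, 0))    # right third  -> M
abm.save('sample_000.png')  # place under <data_root>/train/
```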
/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import ntpath
5 | import time
6 | import cv2
7 | import torch
8 | from PIL import Image
9 | from subprocess import Popen, PIPE
10 | # from util import *  # circular self-import
11 | # from scipy.misc import imresize
12 | 
13 | if sys.version_info[0] == 2:
14 |     VisdomExceptionBase = Exception
15 | else:
16 |     VisdomExceptionBase = ConnectionError
17 | 
18 | 
19 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
20 |     """Save images to the disk.
21 | 
22 |     Parameters:
23 |         webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
24 |         visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
25 |         image_path (str)         -- the string used to create image paths
26 |         aspect_ratio (float)     -- the aspect ratio of saved images
27 |         width (int)              -- the images will be resized to width x width
28 | 
29 |     This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
30 |     """
31 |     image_dir = webpage.get_image_dir()
32 |     short_path = ntpath.basename(image_path[0])
33 |     name = os.path.splitext(short_path)[0]
34 | 
35 |     webpage.add_header(name)
36 |     ims, txts, links = [], [], []
37 | 
38 |     for label, im_data in visuals.items():
39 |         im = tensor2im(im_data)
40 |         image_name = '%s_%s.png' % (name, label)
41 |         save_path = os.path.join(image_dir, image_name)
42 |         h, w, _ = im.shape
43 |         save_image(im, save_path)
44 | 
45 |         ims.append(image_name)
46 |         txts.append(label)
47 |         links.append(image_name)
48 |     webpage.add_images(ims, txts, links, width=width)
49 | 
50 | 
51 | class Visualizer:
52 |     """This class includes several functions that can display/save images and print/save logging information.
53 | 
54 |     It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
55 |     """
56 | 
57 |     def __init__(self, opt):
58 |         """Initialize the Visualizer class
59 | 
60 |         Parameters:
61 |             opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
62 |         Step 1: Cache the training/test options
63 |         Step 2: connect to a visdom server
64 |         Step 3: create an HTML object for saving HTML files
65 |         Step 4: create a logging file to store training losses
66 |         """
67 |         self.opt = opt  # cache the option
68 |         self.display_id = 1
69 |         self.win_size = 256
70 |         self.name = opt.name
71 |         self.port = 8097
72 |         self.saved = False
73 |         self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
74 |         self.img_dir = os.path.join(self.web_dir, 'images')
75 |         mkdirs([self.web_dir, self.img_dir])
76 |         if self.display_id > 0:  # connect to a visdom server
77 |             import visdom
78 |             self.ncols = 4
79 |             self.vis = visdom.Visdom(server='http://localhost', port=8097, env='main')
80 |             if not self.vis.check_connection():
81 |                 self.create_visdom_connections()
82 | 
83 |         # create a logging file to store training losses
84 |         self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
85 |         with open(self.log_name, "a") as log_file:
86 |             now = time.strftime("%c")
87 |             log_file.write('================ Training Loss (%s) ================\n' % now)
88 | 
89 |     def reset(self):
90 |         """Reset the self.saved status"""
91 |         self.saved = False
92 | 
93 |     def create_visdom_connections(self):
94 |         """If the program could not connect to a Visdom server, this function will start a new server at port <self.port>"""
95 |         cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
96 |         print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
97 |         print('Command: %s' % cmd)
98 |         Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
99 | 
100 |     def display_current_results(self, visuals, epoch, save_result):
101 |         """Display current results on visdom; save current results to an HTML file.
102 | 
103 |         Parameters:
104 |             visuals (OrderedDict) -- dictionary of images to display or save
105 |             epoch (int)           -- the current epoch
106 |             save_result (bool)    -- if save the current results to an HTML file
107 |         """
108 |         if self.display_id > 0:  # show images in the browser using visdom
109 |             ncols = self.ncols
110 |             if ncols > 0:  # show all the images in one visdom panel
111 |                 ncols = min(ncols, len(visuals))
112 |                 h, w = next(iter(visuals.values())).shape[:2]
113 |                 table_css = """<style>
114 |                         table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
115 |                         table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
116 |                         </style>""" % (w, h)  # create a table css
117 |                 # create a table of images.
118 |                 title = self.name
119 |                 label_html = ''
120 |                 label_html_row = ''
121 |                 images = []
122 |                 idx = 0
123 |                 for label, image in visuals.items():
124 |                     image_numpy = tensor2im(image)
125 |                     # print(image_numpy.shape)
126 |                     if not image_numpy.shape[:2] == (h, w):
127 |                         image_numpy = cv2.resize(image_numpy, (w, h))
128 |                     label_html_row += '<td>%s</td>' % label
129 |                     images.append(image_numpy.transpose([2, 0, 1]))
130 |                     idx += 1
131 |                     if idx % ncols == 0:
132 |                         label_html += '<tr>%s</tr>' % label_html_row
133 |                         label_html_row = ''
134 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
135 |                 while idx % ncols != 0:
136 |                     images.append(white_image)
137 |                     label_html_row += '<td></td>'
138 |                     idx += 1
139 |                 if label_html_row != '':
140 |                     label_html += '<tr>%s</tr>' % label_html_row
141 |                 try:
142 |                     self.vis.images(images, nrow=ncols, win=self.display_id + 1,
143 |                                     padding=2, opts=dict(title=title + ' images'))
144 |                     label_html = '<table>%s</table>' % label_html
145 |                     self.vis.text(table_css + label_html, win=self.display_id + 2,
146 |                                   opts=dict(title=title + ' labels'))
147 |                 except VisdomExceptionBase:
148 |                     self.create_visdom_connections()
149 | 
150 |             else:  # show each image in a separate visdom panel
151 |                 idx = 1
152 |                 try:
153 |                     for label, image in visuals.items():
154 |                         image_numpy = tensor2im(image)
155 |                         self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
156 |                                        win=self.display_id + idx)
157 |                         idx += 1
158 |                 except VisdomExceptionBase:
159 |                     self.create_visdom_connections()
160 | 
161 |         # save images to the disk
162 |         for label, image in visuals.items():
163 |             image_numpy = tensor2im(image)
164 |             img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
165 |             save_image(image_numpy, img_path)
166 | 
167 |     def plot_current_losses(self, epoch, counter_ratio, losses):
168 |         """Display the current losses on the visdom display: dictionary of error labels and values
169 | 
170 |         Parameters:
171 |             epoch (int)           -- current epoch
172 |             counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
173 |             losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
174 |         """
175 |         if not hasattr(self, 'plot_data'):
176 |             self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
177 |         self.plot_data['X'].append(epoch + counter_ratio)
178 |         self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
179 |         try:
180 |             self.vis.line(
181 |                 X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
182 |                 Y=np.array(self.plot_data['Y']),
183 |                 opts={
184 |                     'title': self.name + ' loss over time',
185 |                     'legend': self.plot_data['legend'],
186 |                     'xlabel': 'epoch',
187 |                     'ylabel': 'loss'},
188 |                 win=self.display_id)
189 |         except VisdomExceptionBase:
190 |             self.create_visdom_connections()
191 | 
192 |     # losses: same format as |losses| of plot_current_losses
193 |     def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
194 |         """Print current losses on the console; also save the losses to the disk
195 | 
196 |         Parameters:
197 |             epoch (int)          -- current epoch
198 |             iters (int)          -- current training iteration during this epoch (reset to 0 at the end of every epoch)
199 |             losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
200 |             t_comp (float)       -- computational time per data point (normalized by batch_size)
201 |             t_data (float)       -- data loading time per data point (normalized by batch_size)
202 |         """
203 |         message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
204 |         for k, v in losses.items():
205 |             message += '%s: %.3f ' % (k, v)
206 | 
207 |         print(message)  # print the message
208 |         with open(self.log_name, "a") as log_file:
209 |             log_file.write('%s\n' % message)  # save the message
210 | 
211 | 
212 | def extract_bboxes(mask, exp_ratio=1.0):
213 |     """Compute a bounding box from a mask.
214 |     mask: [height, width]. Mask pixels are either 1 or 0.
215 |     exp_ratio: expand the side length of the rectangle by some ratio, >= 1.0
216 |     Returns: bbox array [y1, x1, y2, x2].
217 |     """
218 |     boxes = np.zeros([1, 4], dtype=np.int32)
219 |     m = mask[:, :]
220 |     # Bounding box.
221 |     horizontal_indicies = np.where(np.any(m, axis=0))[0]
222 |     vertical_indicies = np.where(np.any(m, axis=1))[0]
223 |     if horizontal_indicies.shape[0]:
224 |         x1, x2 = horizontal_indicies[[0, -1]]
225 |         y1, y2 = vertical_indicies[[0, -1]]
226 |         # x2 and y2 should not be part of the box. Increment by 1.
227 |         x2 += 1
228 |         y2 += 1
229 |         if exp_ratio > 1.0:
230 |             side_x = (x2 - x1 + 1) * exp_ratio
231 |             side_y = (y2 - y1 + 1) * exp_ratio
232 |             x1 = x1 - side_x / 2 if (x1 - side_x / 2) > 0 else 0
233 |             x2 = x2 + side_x / 2 if (x2 + side_x / 2) < np.size(mask, 1) else np.size(mask, 1)
234 |             y1 = y1 - side_y / 2 if (y1 - side_y / 2) > 0 else 0
235 |             y2 = y2 + side_y / 2 if (y2 + side_y / 2) < np.size(mask, 0) else np.size(mask, 0)
236 |     else:
237 |         # No mask for this instance. Might happen due to
238 |         # resizing or cropping. Set bbox to zeros
239 |         x1, x2, y1, y2 = 0, 0, 0, 0
240 |     boxes = np.array([y1, x1, y2, x2])
241 |     return boxes.astype(np.int32)
242 | 
243 | 
244 | def tensor2im(input_image, imtype=np.uint8):
245 |     """Convert a Tensor array into a numpy image array.
246 | 
247 |     Parameters:
248 |         input_image (tensor) -- the input image tensor array
249 |         imtype (type)        -- the desired type of the converted numpy array
250 |     """
251 |     if not isinstance(input_image, np.ndarray):
252 |         if isinstance(input_image, torch.Tensor):  # get the data from a variable
253 |             image_tensor = input_image.data
254 |         else:
255 |             return input_image
256 |         image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
257 |         if image_numpy.shape[0] == 1:  # grayscale to RGB
258 |             image_numpy = np.tile(image_numpy, (3, 1, 1))
259 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
260 |     else:  # if it is a numpy array, do nothing
261 |         image_numpy = input_image
262 |     return image_numpy.astype(imtype)
263 | 
264 | 
265 | def save_image(image_numpy, image_path):
266 |     """Save a numpy image to the disk
267 | 
268 |     Parameters:
269 |         image_numpy (numpy array) -- input numpy array
270 |         image_path (str)          -- the path of the image
271 |     """
272 |     image_pil = Image.fromarray(image_numpy)
273 |     image_pil.save(image_path)
274 | 
275 | 
276 | def mkdirs(paths):
277 |     """Create empty directories if they don't exist
278 | 
279 |     Parameters:
280 |         paths (str list) -- a list of directory paths
281 |     """
282 |     if isinstance(paths, list) and not isinstance(paths, str):
283 |         for path in paths:
284 |             mkdir(path)
285 |     else:
286 |         mkdir(paths)
287 | 
288 | 
289 | def mkdir(path):
290 |     """Create a single empty directory if it doesn't exist
291 | 
292 |     Parameters:
293 |         path (str) -- a single directory path
294 |     """
295 |     if not os.path.exists(path):
296 |         os.makedirs(path)
297 | 
298 | 
299 | def index_closest(pixel, bar):
300 |     n = len(bar)  # the color bar has n entries; entry 0 maps to value n, entry n-1 maps to value 1
301 |     temp = []
302 |     for i in range(n):  # L1 distance from the pixel to each color-bar entry
303 |         dis = sum(list(map(lambda x: abs(x[0] - x[1]), zip(pixel, bar[i]))))
304 |         temp.append(dis)
305 |     value = n - temp.index(min(temp))
306 |     return value
307 | 
308 | 
309 | def cal_color_vector(color_img, cluster):
310 |     """Calculate a hard-coded strain ratio according to the cluster map"""
311 |     color_bar = [[0, 0, 143], [0, 0, 159], [0, 0, 175], [0, 0, 191], [0, 0, 207], [0, 0, 223], [0, 0, 239], [0, 0, 255],
312 |                  [0, 16, 255], [0, 32, 255], [0, 48, 255], [0, 64, 255], [0, 80, 255], [0, 96, 255], [0, 112, 255], [0, 128, 255],
313 |                  [0, 143, 255], [0, 159, 255], [0, 175, 255], [0, 191, 255], [0, 207, 255], [0, 223, 255], [0, 239, 255], [0, 255, 255],
314 |                  [16, 255, 239], [32, 255, 223], [48, 255, 207], [64, 255, 191], [80, 255, 175], [96, 255, 159], [112, 255, 143], [128, 255, 128],
315 |                  [143, 255, 112], [159, 255, 96], [175, 255, 80], [191, 255, 64], [207, 255, 48], [223, 255, 32], [239, 255, 16], [255, 255, 0],
316 |                  [255, 239, 0], [255, 223, 0], [255, 207, 0], [255, 191, 0], [255, 175, 0], [255, 159, 0], [255, 143, 0], [255, 128, 0],
317 |                  [255, 112, 0], [255, 96, 0], [255, 80, 0], [255, 64, 0], [255, 48, 0], [255, 32, 0], [255, 16, 0], [255, 0, 0],
318 |                  [239, 0, 0], [223, 0, 0], [207, 0, 0], [191, 0, 0], [175, 0, 0], [159, 0, 0], [143, 0, 0], [128, 0, 0]]
319 |     color_img = np.squeeze(color_img)
320 |     cluster = np.squeeze(cluster)  # shape (1, 1, 256, 256)
321 |     num = int(torch.max(cluster)) + 1  # number of clusters
322 |     vec = []
323 |     # print(num)
324 |     for i in range(num):
325 |         values = 0
326 |         XYs = torch.nonzero(cluster == i)  # (N, 2) tensor of (y, x) coordinates
327 |         n = XYs.size()[0]  # the number of nonzero values
328 |         if n > 0:
329 |             for j in range(n):
330 |                 # print(j)
331 |                 pixel = color_img[:, XYs[j, 0], XYs[j, 1]]
332 |                 value = index_closest(pixel, color_bar)
333 |                 values += value
334 |             vec.append(values / n)
335 |         else:
336 |             vec.append(0)
337 |     return vec
338 |
339 |
340 |
--------------------------------------------------------------------------------
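A quick sanity check of `extract_bboxes` on a toy mask (a minimal sketch; values worked out by hand from the function above):

```python
import numpy as np
from util import extract_bboxes

mask = np.zeros((8, 8), dtype=np.int32)
mask[2:5, 3:7] = 1           # "tumor" occupies rows 2..4, cols 3..6

print(extract_bboxes(mask))  # -> [2 3 5 7]: (y1, x1, y2, x2), end indices exclusive
```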
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 | 
7 | 
8 | class UnetGenerator(nn.Module):
9 |     """Create a Unet-based generator"""
10 | 
11 |     def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
12 |         """Construct a Unet generator
13 |         Parameters:
14 |             input_nc (int)  -- the number of channels in input images
15 |             output_nc (int) -- the number of channels in output images
16 |             num_downs (int) -- the number of downsamplings in the UNet. For example, if |num_downs| == 7,
17 |                                an image of size 128x128 will become of size 1x1 at the bottleneck
18 |             ngf (int)       -- the number of filters in the last conv layer
19 |             norm_layer      -- normalization layer
20 | 
21 |         We construct the U-Net from the innermost layer to the outermost layer.
22 |         It is a recursive process.
23 |         """
24 |         super(UnetGenerator, self).__init__()
25 |         # construct the unet structure
26 |         unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
27 |         for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
28 |             unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
29 |         # gradually reduce the number of filters from ngf * 8 to ngf
30 |         unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
31 |         unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
32 |         unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
33 |         self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer
34 | 
35 |     def forward(self, input):
36 |         """Standard forward"""
37 |         return self.model(input)
38 | 
39 | 
40 | class UnetSkipConnectionBlock(nn.Module):
41 |     """Defines the Unet submodule with a skip connection.
42 |         X -------------------identity----------------------
43 |         |-- downsampling -- |submodule| -- upsampling --|
44 |     """
45 | 
46 |     def __init__(self, outer_nc, inner_nc, input_nc=None,
47 |                  submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
48 |         """Construct a Unet submodule with skip connections.
49 | 
50 |         Parameters:
51 |             outer_nc (int) -- the number of filters in the outer conv layer
52 |             inner_nc (int) -- the number of filters in the inner conv layer
53 |             input_nc (int) -- the number of channels in input images/features
54 |             submodule (UnetSkipConnectionBlock) -- previously defined submodules
55 |             outermost (bool)   -- if this module is the outermost module
56 |             innermost (bool)   -- if this module is the innermost module
57 |             norm_layer         -- normalization layer
58 |             use_dropout (bool) -- if use dropout layers.
59 |         """
60 |         super(UnetSkipConnectionBlock, self).__init__()
61 |         self.outermost = outermost
62 |         if type(norm_layer) == functools.partial:
63 |             use_bias = norm_layer.func == nn.InstanceNorm2d
64 |         else:
65 |             use_bias = norm_layer == nn.InstanceNorm2d
66 |         if input_nc is None:
67 |             input_nc = outer_nc
68 |         downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
69 |                              stride=2, padding=1, bias=use_bias)
70 |         downrelu = nn.LeakyReLU(0.2, True)
71 |         downnorm = norm_layer(inner_nc)
72 |         uprelu = nn.ReLU(True)
73 |         upnorm = norm_layer(outer_nc)
74 | 
75 |         if outermost:
76 |             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
77 |                                         kernel_size=4, stride=2,
78 |                                         padding=1)
79 |             down = [downconv]
80 |             up = [uprelu, upconv, nn.Tanh()]
81 |             model = down + [submodule] + up
82 |         elif innermost:
83 |             upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
84 |                                         kernel_size=4, stride=2,
85 |                                         padding=1, bias=use_bias)
86 |             down = [downrelu, downconv]
87 |             up = [uprelu, upconv, upnorm]
88 |             model = down + up
89 |         else:
90 |             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
91 |                                         kernel_size=4, stride=2,
92 |                                         padding=1, bias=use_bias)
93 |             down = [downrelu, downconv, downnorm]
94 |             up = [uprelu, upconv, upnorm]
95 | 
96 |             if use_dropout:
97 |                 model = down + [submodule] + up + [nn.Dropout(0.5)]
98 |             else:
99 |                 model = down + [submodule] + up
100 | 
101 |         self.model = nn.Sequential(*model)
102 | 
103 |     def forward(self, x):
104 |         if self.outermost:
105 |             return self.model(x)
106 |         else:  # add skip connections
107 |             return torch.cat([x, self.model(x)], 1)
108 | 
109 | 
110 | class NLayerDiscriminator(nn.Module):
111 |     """Defines a PatchGAN discriminator"""
112 | 
113 |     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
114 |         """Construct a PatchGAN discriminator
115 | 
116 |         Parameters:
117 |             input_nc (int) -- the number of channels in input images
118 |             ndf (int)      -- the number of filters in the last conv layer
119 |             n_layers (int) -- the number of conv layers in the discriminator
120 |             norm_layer     -- normalization layer
121 |         """
122 |         super(NLayerDiscriminator, self).__init__()
123 |         if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
124 |             use_bias = norm_layer.func != nn.BatchNorm2d
125 |         else:
126 |             use_bias = norm_layer != nn.BatchNorm2d
127 | 
128 |         kw = 4
129 |         padw = 1
130 |         sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
131 |         nf_mult = 1
132 |         nf_mult_prev = 1
133 |         for n in range(1, n_layers):  # gradually increase the number of filters
134 |             nf_mult_prev = nf_mult
135 |             nf_mult = min(2 ** n, 8)  # n = 1, 2 -> nf_mult = 2, 4
136 |             sequence += [
137 |                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
138 |                 norm_layer(ndf * nf_mult),
139 |                 nn.LeakyReLU(0.2, True)
140 |             ]
141 | 
142 |         nf_mult_prev = nf_mult
143 |         nf_mult = min(2 ** n_layers, 8)
144 |         sequence += [
145 |             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
146 |             norm_layer(ndf * nf_mult),
147 |             nn.LeakyReLU(0.2, True)
148 |         ]
149 | 
150 |         sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output a 1-channel prediction map
151 |         self.model = nn.Sequential(*sequence)
152 | 
153 |     def forward(self, input):
154 |         """Standard forward."""
155 |         return self.model(input)
156 | 
157 | 
158 | class GANLoss(nn.Module):
159 |     """Define different GAN objectives.
160 | 
161 |     The GANLoss class abstracts away the need to create the target label tensor
162 |     that has the same size as the input.
163 |     """
164 | 
165 |     def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
166 |         """Initialize the GANLoss class.
167 | 
168 |         Parameters:
169 |             gan_mode (str)            -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
170 |             target_real_label (float) -- label for a real image
171 |             target_fake_label (float) -- label for a fake image
172 | 
173 |         Note: Do not use sigmoid as the last layer of the Discriminator.
174 |         LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
175 |         """
176 |         super(GANLoss, self).__init__()
177 |         self.register_buffer('real_label', torch.tensor(target_real_label))
178 |         self.register_buffer('fake_label', torch.tensor(target_fake_label))
179 |         self.gan_mode = gan_mode
180 |         if gan_mode == 'lsgan':
181 |             self.loss = nn.MSELoss()
182 |         elif gan_mode == 'vanilla':
183 |             self.loss = nn.BCEWithLogitsLoss()
184 |         elif gan_mode in ['wgangp']:
185 |             self.loss = None
186 |         else:
187 |             raise NotImplementedError('gan mode %s not implemented' % gan_mode)
188 | 
189 |     def get_target_tensor(self, prediction, target_is_real):
190 |         """Create label tensors with the same size as the input.
191 | 
192 |         Parameters:
193 |             prediction (tensor)   -- typically the prediction from a discriminator
194 |             target_is_real (bool) -- if the ground truth label is for real images or fake images
195 | 
196 |         Returns:
197 |             A label tensor filled with the ground truth label, and with the size of the input
198 |         """
199 | 
200 |         if target_is_real:
201 |             target_tensor = self.real_label
202 |         else:
203 |             target_tensor = self.fake_label
204 |         return target_tensor.expand_as(prediction)
205 | 
206 |     def __call__(self, prediction, target_is_real):
207 |         """Calculate loss given the Discriminator's output and ground truth labels.
208 | 
209 |         Parameters:
210 |             prediction (tensor)   -- typically the prediction output from a discriminator
211 |             target_is_real (bool) -- if the ground truth label is for real images or fake images
212 | 
213 |         Returns:
214 |             the calculated loss.
215 |         """
216 |         if self.gan_mode in ['lsgan', 'vanilla']:
217 |             target_tensor = self.get_target_tensor(prediction, target_is_real)
218 |             loss = self.loss(prediction, target_tensor)
219 |         elif self.gan_mode == 'wgangp':
220 |             if target_is_real:
221 |                 loss = -prediction.mean()
222 |             else:
223 |                 loss = prediction.mean()
224 |         return loss
225 | 
226 | 
227 | # VGG architecture used for the identity-preserving loss
228 | class VGG_(nn.Module):
229 |     def __init__(self, requires_grad=False):
230 |         super().__init__()
231 |         net = VGG('VGG13', num_classes=4)  # NOTE: the VGG class and './models/ckpt.pth' are not defined in this repository; they must come from an external source
232 |         checkpoint = torch.load('./models/ckpt.pth')
233 |         for key in list(checkpoint['net'].keys()):
234 |             key_ = key[key.index('.') + 1:]  # strip the leading module prefix (e.g. left by DataParallel)
235 |             checkpoint['net'][key_] = checkpoint['net'].pop(key)
236 |         net.load_state_dict(checkpoint['net'])
237 |         vgg_pretrained_features = net.features
238 |         self.slice1 = torch.nn.Sequential()
239 |         self.slice2 = torch.nn.Sequential()
240 |         self.slice3 = torch.nn.Sequential()
241 |         self.slice4 = torch.nn.Sequential()
242 |         self.slice5 = torch.nn.Sequential()
243 |         for x in range(2):
244 |             self.slice1.add_module(str(x), vgg_pretrained_features[x])
245 |         for x in range(2, 7):
246 |             self.slice2.add_module(str(x), vgg_pretrained_features[x])
247 |         for x in range(7, 12):
248 |             self.slice3.add_module(str(x), vgg_pretrained_features[x])
249 |         for x in range(12, 21):
250 |             self.slice4.add_module(str(x), vgg_pretrained_features[x])
251 |         for x in range(21, 30):
252 |             self.slice5.add_module(str(x), vgg_pretrained_features[x])
253 |         if not requires_grad:
254 |             for param in self.parameters():
255 |                 param.requires_grad = False
256 | 
257 |     def forward(self, X):
258 |         h_relu1 = self.slice1(X)
259 |         h_relu2 = self.slice2(h_relu1)
260 |         h_relu3 = self.slice3(h_relu2)
261 |         h_relu4 = self.slice4(h_relu3)
262 |         h_relu5 = self.slice5(h_relu4)
263 |         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
264 |         return out
265 | 
266 | 
267 | # VGG identity-preserving loss definition
268 | class VGGLoss(nn.Module):
269 |     def __init__(self):
270 |         super(VGGLoss, self).__init__()
271 |         self.vgg = VGG_().cuda()
272 |         self.criterion = nn.L1Loss()
273 |         self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
274 | 
275 |     def forward(self, x, y):
276 |         """
277 |         x : ground truth (tensor)
278 |         y : fake images (tensor)
279 |         """
280 |         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
281 |         loss = 0
282 |         for i in range(len(x_vgg)):
283 |             loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
284 |         return loss
285 | 
286 | 
287 | # misc
288 | def get_scheduler(optimizer, opt):
289 |     """Return a learning rate scheduler
290 | 
291 |     Parameters:
292 |         optimizer          -- the optimizer of the network
293 |         opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
294 |                               opt.lr_policy is the name of the learning rate policy: linear | step | plateau | cosine
295 | 
296 |     For 'linear', we keep the same learning rate for the first <opt.niter> epochs
297 |     and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
298 |     For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
299 |     See https://pytorch.org/docs/stable/optim.html for more details.
300 |     """
301 |     if opt.lr_policy == 'linear':
302 |         def lambda_rule(epoch):
303 |             lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
304 |             return lr_l
305 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
306 |     elif opt.lr_policy == 'step':
307 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
308 |     elif opt.lr_policy == 'plateau':
309 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
310 |     elif opt.lr_policy == 'cosine':
311 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
312 |     else:
313 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
314 |     return scheduler
315 | 
316 | 
317 | def init_weights(net, init_type='normal', init_gain=0.02):
318 |     """Initialize network weights.
319 | 
320 |     Parameters:
321 |         net (network)     -- network to be initialized
322 |         init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
323 |         init_gain (float) -- scaling factor for normal, xavier and orthogonal.
324 | 
325 |     We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
326 |     work better for some applications. Feel free to try yourself.
327 |     """
328 |     def init_func(m):  # define the initialization function
329 |         classname = m.__class__.__name__
330 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
331 |             if init_type == 'normal':
332 |                 init.normal_(m.weight.data, 0.0, init_gain)
333 |             elif init_type == 'xavier':
334 |                 init.xavier_normal_(m.weight.data, gain=init_gain)
335 |             elif init_type == 'kaiming':
336 |                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
337 |             elif init_type == 'orthogonal':
338 |                 init.orthogonal_(m.weight.data, gain=init_gain)
339 |             else:
340 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
341 |             if hasattr(m, 'bias') and m.bias is not None:
342 |                 init.constant_(m.bias.data, 0.0)
343 |         elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
344 |             init.normal_(m.weight.data, 1.0, init_gain)
345 |             init.constant_(m.bias.data, 0.0)
346 | 
347 |     print('initialize network with %s' % init_type)
348 |     net.apply(init_func)  # apply the initialization function
349 | 
350 | 
351 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
352 |     """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
353 |     Parameters:
354 |         net (network)      -- the network to be initialized
355 |         init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
356 |         init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
357 |         gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
358 | 
359 |     Return an initialized network.
360 |     """
361 |     if len(gpu_ids) > 0:
362 |         assert(torch.cuda.is_available())
363 |         net.to(gpu_ids[0])
364 |         net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
365 |     init_weights(net, init_type, init_gain=init_gain)
366 |     return net
367 |
368 |
369 |
370 |
371 |
372 |
--------------------------------------------------------------------------------
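The network geometry used by `model.py` (1-channel B-mode in, 3-channel elastography out, 256x256 crops, 8 U-Net downsamplings, and a PatchGAN that sees the 4-channel A/B concatenation) can be sanity-checked with a minimal sketch like this:

```python
import torch
from networks import UnetGenerator, NLayerDiscriminator

G = UnetGenerator(input_nc=1, output_nc=3, num_downs=8)      # as instantiated in model.py
D = NLayerDiscriminator(input_nc=1 + 3, ndf=64, n_layers=3)  # conditional D sees A and B stacked

a = torch.randn(1, 1, 256, 256)          # dummy B-mode batch
fake_b = G(a)                            # -> (1, 3, 256, 256)
pred = D(torch.cat([a, fake_b], dim=1))  # -> (1, 1, 30, 30) patch-wise real/fake scores
print(fake_b.shape, pred.shape)
```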
/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | from networks import *
6 | from util import *
7 | 
8 | 
9 | class BaseModel(ABC):
10 |     """This class is an abstract base class (ABC) for models.
11 |     To create a subclass, you need to implement the following five functions:
12 |         -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
13 |         -- <set_input>: unpack data from dataset and apply preprocessing.
14 |         -- <forward>: produce intermediate results.
15 |         -- <optimize_parameters>: calculate losses, gradients, and update network weights.
16 |         -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
17 |     """
18 | 
19 |     def __init__(self, opt):
20 |         """Initialize the BaseModel class.
21 | 
22 |         Parameters:
23 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
24 | 
25 |         When creating your custom class, you need to implement your own initialization.
26 |         In this function, you should first call <BaseModel.__init__(self, opt)>
27 |         Then, you need to define four lists:
28 |             -- self.loss_names (str list):       specify the training losses that you want to plot and save.
29 |             -- self.model_names (str list):      define networks used in our training.
30 |             -- self.visual_names (str list):     specify the images that you want to display and save.
31 |             -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
32 |         """
33 |         self.opt = opt
34 |         self.gpu_ids = opt.gpu_ids
35 |         self.isTrain = opt.isTrain
36 |         self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
37 |         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
38 |         torch.backends.cudnn.benchmark = True
39 |         self.loss_names = []
40 |         self.model_names = []
41 |         self.visual_names = []
42 |         self.optimizers = []
43 |         self.image_paths = []
44 |         self.metric = 0  # used for learning rate policy 'plateau'
45 | 
46 |     @staticmethod
47 |     def modify_commandline_options(parser, is_train):
48 |         """Add new model-specific options, and rewrite default values for existing options.
49 | 
50 |         Parameters:
51 |             parser -- original option parser
52 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
53 | 
54 |         Returns:
55 |             the modified parser.
56 |         """
57 |         return parser
58 | 
59 |     @abstractmethod
60 |     def set_input(self, input):
61 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
62 | 
63 |         Parameters:
64 |             input (dict): includes the data itself and its metadata information.
65 |         """
66 |         pass
67 | 
68 |     @abstractmethod
69 |     def forward(self):
70 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
71 |         pass
72 | 
73 |     @abstractmethod
74 |     def optimize_parameters(self):
75 |         """Calculate losses, gradients, and update network weights; called in every training iteration"""
76 |         pass
77 | 
78 |     def setup(self, opt):
79 |         """Load and print networks; create schedulers
80 | 
81 |         Parameters:
82 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
83 |         """
84 |         if self.isTrain:
85 |             self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]
86 |         if not self.isTrain:
87 |             load_suffix = str(opt.epoch)  # e.g. 'latest' or an epoch number
88 |             self.load_networks(load_suffix)
89 |         self.print_networks(True)
90 | 
91 |     def eval(self):
92 |         """Make models eval mode during test time"""
93 |         for name in self.model_names:
94 |             if isinstance(name, str):
95 |                 net = getattr(self, 'net' + name)
96 |                 net.eval()
97 | 
98 |     def test(self):
99 |         """Forward function used in test time.
100 | 
101 |         This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop
102 |         It also calls <compute_visuals> to produce additional visualization results
103 |         """
104 |         with torch.no_grad():
105 |             self.forward()
106 |             self.compute_visuals()
107 | 
108 |     def compute_visuals(self):
109 |         """Calculate additional output images for visdom and HTML visualization"""
110 |         pass
111 | 
112 |     def get_image_paths(self):
113 |         """Return image paths that are used to load current data"""
114 |         return self.image_paths
115 | 
116 |     def update_learning_rate(self):
117 |         """Update learning rates for all the networks; called at the end of every epoch"""
118 |         for scheduler in self.schedulers:
119 |             if self.opt.lr_policy == 'plateau':
120 |                 scheduler.step(self.metric)
121 |             else:
122 |                 scheduler.step()
123 | 
124 |         lr = self.optimizers[0].param_groups[0]['lr']
125 |         print('learning rate = %.7f' % lr)
126 | 
127 |     def get_current_visuals(self):
128 |         """Return visualization images. train.py will display these images with visdom, and save the images to an HTML"""
129 |         visual_ret = OrderedDict()
130 |         for name in self.visual_names:
131 |             if isinstance(name, str):
132 |                 visual_ret[name] = getattr(self, name)
133 |         return visual_ret
134 | 
135 |     def get_current_losses(self):
136 |         """Return training losses / errors. train.py will print out these errors on the console, and save them to a file"""
137 |         errors_ret = OrderedDict()
138 |         for name in self.loss_names:
139 |             if isinstance(name, str):
140 |                 errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
141 |         return errors_ret
142 | 
143 |     def save_networks(self, epoch):
144 |         """Save all the networks to the disk.
145 | 
146 |         Parameters:
147 |             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
148 |         """
149 |         for name in self.model_names:
150 |             if isinstance(name, str):
151 |                 save_filename = '%s_net_%s.pth' % (epoch, name)
152 |                 save_path = os.path.join(self.save_dir, save_filename)
153 |                 net = getattr(self, 'net' + name)
154 | 
155 |                 if len(self.gpu_ids) > 0 and torch.cuda.is_available():
156 |                     torch.save(net.module.cpu().state_dict(), save_path)
157 |                     net.cuda(self.gpu_ids[0])
158 |                 else:
159 |                     torch.save(net.cpu().state_dict(), save_path)
160 | 
161 |     def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
162 |         """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
163 |         key = keys[i]
164 |         if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
165 |             if module.__class__.__name__.startswith('InstanceNorm') and \
166 |                     (key == 'running_mean' or key == 'running_var'):
167 |                 if getattr(module, key) is None:
168 |                     state_dict.pop('.'.join(keys))
169 |             if module.__class__.__name__.startswith('InstanceNorm') and \
170 |                     (key == 'num_batches_tracked'):
171 |                 state_dict.pop('.'.join(keys))
172 |         else:
173 |             self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
174 | 
175 |     def load_networks(self, epoch):
176 |         """Load all the networks from the disk.
177 | 
178 |         Parameters:
179 |             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
180 |         """
181 |         for name in self.model_names:
182 |             if isinstance(name, str):
183 |                 load_filename = '%s_net_%s.pth' % (epoch, name)
184 |                 load_path = os.path.join(self.save_dir, load_filename)
185 |                 net = getattr(self, 'net' + name)
186 |                 if isinstance(net, torch.nn.DataParallel):
187 |                     net = net.module
188 |                 print('loading the model from %s' % load_path)
189 |                 # if you are using PyTorch newer than 0.4 (e.g., built from
190 |                 # GitHub source), you can remove str() on self.device
191 |                 state_dict = torch.load(load_path, map_location=str(self.device))
192 |                 if hasattr(state_dict, '_metadata'):
193 |                     del state_dict._metadata
194 | 
195 |                 # patch InstanceNorm checkpoints prior to 0.4
196 |                 for key in list(state_dict.keys()):  # need to copy keys here because we mutate in the loop
197 |                     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
198 |                 net.load_state_dict(state_dict)
199 | 
200 |     def print_networks(self, verbose):
201 |         """Print the total number of parameters in the network and (if verbose) the network architecture
202 | 
203 |         Parameters:
204 |             verbose (bool) -- if verbose: print the network architecture
205 |         """
206 |         print('---------- Networks initialized -------------')
207 |         for name in self.model_names:
208 |             if isinstance(name, str):
209 |                 net = getattr(self, 'net' + name)
210 |                 num_params = 0
211 |                 for param in net.parameters():
212 |                     num_params += param.numel()
213 |                 if verbose:
214 |                     print(net)
215 |                 print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
216 |         print('-----------------------------------------------')
217 | 
218 |     def set_requires_grad(self, nets, requires_grad=False):
219 |         """Set requires_grad=False for all the networks to avoid unnecessary computations
220 |         Parameters:
221 |             nets (network list)  -- a list of networks
222 |             requires_grad (bool) -- whether the networks require gradients or not
223 |         """
224 |         if not isinstance(nets, list):
225 |             nets = [nets]
226 |         for net in nets:
227 |             if net is not None:
228 |                 for param in net.parameters():
229 |                     param.requires_grad = requires_grad
230 | 
231 | 
232 | class EnhanceGANModel(BaseModel):
233 |     @staticmethod
234 |     def modify_commandline_options(parser, is_train=True):
235 |         """Add new dataset-specific options, and rewrite default values for existing options.
236 | 
237 |         Parameters:
238 |             parser -- original option parser
239 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
240 | 
241 |         Returns:
242 |             the modified parser.
243 | 
244 |         For pix2pix, we do not use an image buffer
245 |         The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
246 |         By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
247 |         """
248 |         # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
249 |         parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
250 |         if is_train:
251 |             parser.set_defaults(pool_size=0, gan_mode='vanilla')
252 |             parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
253 | 
254 |         return parser
255 | 
256 |     def __init__(self, opt):
257 |         """Initialize the pix2pix class.
258 | 
259 |         Parameters:
260 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
261 |         """
262 |         BaseModel.__init__(self, opt)
263 |         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
264 |         self.loss_names = ['GA_GAN', 'GT_GAN', 'G_L1', 'D_A', 'D_T']
265 |         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
266 |         self.visual_names = ['real_A', 'fake_B', 'real_B', 'real_AT', 'real_BT', 'fake_BT']
267 |         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
268 |         if self.isTrain:
269 |             self.model_names = ['G', 'DA', 'DT']
270 |         else:  # during test time, only load G
271 |             self.model_names = ['G']
272 |         # define networks (both generator and discriminator)
273 |         self.netG = UnetGenerator(1, 3, 8, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=True)
274 |         self.netG = init_net(self.netG, gpu_ids=opt.gpu_ids)
275 | 
276 |         if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; therefore, #channels for D is input_nc + output_nc
277 |             self.netDA = NLayerDiscriminator(1 + 3, 64, 3, norm_layer=nn.BatchNorm2d)
278 |             self.netDA = init_net(self.netDA, gpu_ids=opt.gpu_ids)
279 | 
280 |             if self.opt.EnhanceT:
281 |                 self.netDT = NLayerDiscriminator(1 + 3, 64, 3, norm_layer=nn.BatchNorm2d)
282 |                 self.netDT = init_net(self.netDT, gpu_ids=opt.gpu_ids)
283 | 
284 |         if self.isTrain:
285 |             # define loss functions
286 |             self.criterionGAN = GANLoss('vanilla').to(self.device)
287 |             self.criterionL1 = torch.nn.L1Loss()
288 |             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
289 |             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(0.5, 0.999))
290 |             self.optimizer_DA = torch.optim.Adam(self.netDA.parameters(), lr=opt.lr, betas=(0.5, 0.999))
291 |             if self.opt.EnhanceT:
292 |                 self.optimizer_DT = torch.optim.Adam(self.netDT.parameters(), lr=opt.lr, betas=(0.5, 0.999))
293 |                 self.optimizers.append(self.optimizer_DT)
294 |             self.optimizers.append(self.optimizer_G)
295 |             self.optimizers.append(self.optimizer_DA)
296 |             # default backup dicts to speed up the training process
297 |             self.real_backup = dict()
298 |             self.fake_backup = dict()
299 | 
300 |     def set_input(self, input):
301 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
302 | 
303 |         Parameters:
304 |             input (dict): includes the data itself and its metadata information.
305 | 
306 |         The option 'direction' can be used to swap images in domain A and domain B.
307 |         """
308 |         AtoB = True
309 |         self.real_A = input['A' if AtoB else 'B'].to(self.device)
310 |         self.real_B = input['B' if AtoB else 'A'].to(self.device)
311 |         self.mask = input['M']
312 |         self.bbox = input['bbox'][0]  # (y0, x0, y1, x1)
313 | 
314 |     def forward(self):
315 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
316 |         self.fake_B = self.netG(self.real_A)  # G(A)
317 |         # crop the tumor area
318 |         self.real_AT = self.real_A[:, :, self.bbox[0]: self.bbox[2], self.bbox[1]: self.bbox[3]]
319 |         self.real_BT = self.real_B[:, :, self.bbox[0]: self.bbox[2], self.bbox[1]: self.bbox[3]]
320 |         self.fake_BT = self.fake_B[:, :, self.bbox[0]: self.bbox[2], self.bbox[1]: self.bbox[3]]
321 | 
322 |     def backward_D_basic(self, netD, realA, realB, fakeB):
323 |         """A general calculation of the discriminator backward pass"""
324 |         # Real (we use a conditional GAN)
325 |         real_AB = torch.cat((realA, realB), 1)
326 |         pred_real = netD.forward(real_AB)
327 |         loss_D_real = self.criterionGAN(pred_real, True)
328 |         # Fake
329 |         fake_AB = torch.cat((realA, fakeB), 1)
330 |         pred_fake = netD.forward(fake_AB.detach())
331 |         loss_D_fake = self.criterionGAN(pred_fake, False)
332 | 
333 |         loss_D = (loss_D_real + loss_D_fake) * 0.5
334 |         return loss_D
335 | 
336 |     def backward_DA(self):
337 |         self.loss_D_A = self.backward_D_basic(self.netDA, self.real_A, self.real_B, self.fake_B)
338 |         self.loss_D_A.backward()
339 | 
340 |     def backward_DT(self):
341 |         self.loss_D_T = self.backward_D_basic(self.netDT, self.real_AT, self.real_BT, self.fake_BT)
342 |         self.loss_D_T.backward()
343 | 
344 |     def backward_G(self):
345 |         """Calculate GAN and L1 loss for the generator"""
346 |         # First, G(A) should fake the discriminator
347 |         fake_AB = torch.cat((self.real_A, self.fake_B), 1)
348 |         pred_fake_A = self.netDA(fake_AB)
349 |         self.loss_GA_GAN = self.criterionGAN(pred_fake_A, True)
350 |         # Tumor enhance
351 |         if self.opt.EnhanceT:
352 |             fake_ABT = torch.cat((self.real_AT, self.fake_BT), 1)
353 |             pred_fake_AT = self.netDT(fake_ABT)
354 |             self.loss_GT_GAN = self.criterionGAN(pred_fake_AT, True)
355 |         else:
356 |             self.loss_GT_GAN = 0
357 |         # Second, G(A) = B
358 |         self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * 100  # lambda_L1 = 100
359 |         # combine losses and calculate gradients
360 |         self.loss_G = self.loss_GA_GAN + self.loss_G_L1 + self.loss_GT_GAN
361 | 
362 |         self.loss_G.backward()
363 | 
364 |     def optimize_parameters(self):
365 |         self.forward()                             # compute fake images: G(A)
366 |         # update DA
367 |         self.set_requires_grad(self.netDA, True)   # enable backprop for D
368 |         self.optimizer_DA.zero_grad()              # set D's gradients to zero
369 |         self.backward_DA()                         # calculate gradients for D
370 |         self.optimizer_DA.step()                   # update D's weights
371 |         if self.opt.EnhanceT:
372 |             # update DT
373 |             self.set_requires_grad(self.netDT, True)
374 |             self.optimizer_DT.zero_grad()
375 |             self.backward_DT()
376 |             self.optimizer_DT.step()
377 |         # update G
378 |         self.set_requires_grad(self.netDA, False)  # D requires no gradients when optimizing G
379 |         if self.opt.EnhanceT:
380 |             self.set_requires_grad(self.netDT, False)
381 |         self.optimizer_G.zero_grad()               # set G's gradients to zero
382 |         self.backward_G()                          # calculate gradients for G
383 |         self.optimizer_G.step()                    # update G's weights
384 | 
--------------------------------------------------------------------------------
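`EnhanceGANModel` reads everything off the `opt` namespace rather than taking keyword arguments, so outside `train.py` it can be exercised with a bare `Namespace` carrying the same fields. A CPU-only smoke-test sketch, under the assumption that the fields below (mirroring `get_parser()` in `train.py`) are all the model touches:

```python
import torch
from argparse import Namespace
from model import EnhanceGANModel

# mirrors the flags defined in train.py's get_parser()
opt = Namespace(isTrain=True, EnhanceT=True, gpu_ids=[], lr=2e-4, lr_policy='linear',
                epoch_count=1, niter=100, niter_decay=100, epoch='latest',
                checkpoints_dir='./checkpoints', name='smoke_test')

model = EnhanceGANModel(opt)
model.setup(opt)

batch = {
    'A': torch.randn(1, 1, 256, 256),            # B-mode input
    'B': torch.randn(1, 3, 256, 256),            # elastography target
    'M': torch.randn(1, 1, 256, 256),            # tumor mask (kept but unused by the losses)
    'bbox': torch.tensor([[60, 60, 180, 180]]),  # (y0, x0, y1, x1)
}
model.set_input(batch)
model.optimize_parameters()  # one full D/G update
print(model.get_current_losses())
```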