├── thumbnail
│   └── abstract.png
├── README.md
├── train.py
├── dataset.py
├── util.py
├── networks.py
└── model.py

/thumbnail/abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yyyzzzhao/VEUS/HEAD/thumbnail/abstract.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# VEUS
Code for ["Virtual elastography ultrasound via generative adversarial network for breast cancer diagnosis"](https://www.nature.com/articles/s41467-023-36102-1.pdf).

If you find this work useful in your research, please cite
```
@article{yao2023virtual,
  title={Virtual elastography ultrasound via generative adversarial network for breast cancer diagnosis},
  author={Yao, Zhao and Luo, Ting and Dong, YiJie and Jia, XiaoHong and Deng, YinHui and Wu, GuoQing and Zhu, Ying and Zhang, JingWen and Liu, Juan and Yang, LiChun and others},
  journal={Nature Communications},
  volume={14},
  number={1},
  pages={788},
  year={2023},
  publisher={Nature Publishing Group UK London}
}
```

!['paper abstract'](thumbnail/abstract.png)

### Prerequisites

* python=3.6
* pytorch=1.10.0
* CUDA=10.2
* torchvision=0.11.1
* other dependencies (e.g., visdom, dominate)

## Getting Started

* Clone this repository:
```
git clone git@github.com:yyyzzzhao/VEUS.git
```

### Usage
```
python train.py
```

We partially borrowed the [Pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) model for this project; many thanks to its authors.

--------------------------------------------------------------------------------
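`train.py` reads paired training images from `<data_root>/train` (see `dataset.py`): each file holds the B-mode input A, the elastography target B, and the tumor mask M side by side in a single image. Below is a minimal sketch of how such a triptych could be assembled; the helper and file names are hypothetical, not part of this repository:

```python
# Hypothetical helper (not in this repository): stitch a B-mode image (A),
# its elastography image (B) and the tumor mask (M) into the side-by-side
# {A, B, M} layout that AlignedDataset.__getitem__ splits into thirds.
from PIL import Image

def make_triptych(a_path, b_path, m_path, out_path):
    a = Image.open(a_path).convert('RGB')
    b = Image.open(b_path).convert('RGB').resize(a.size)
    m = Image.open(m_path).convert('RGB').resize(a.size)
    w, h = a.size
    abm = Image.new('RGB', (3 * w, h))
    abm.paste(a, (0, 0))       # left third: input domain A
    abm.paste(b, (w, 0))       # middle third: target domain B
    abm.paste(m, (2 * w, 0))   # right third: tumor mask M
    abm.save(out_path)

make_triptych('bmode.png', 'elasto.png', 'mask.png', 'train/0001.png')
```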
/train.py:
--------------------------------------------------------------------------------
import time
import argparse
import torch
from model import EnhanceGANModel
from dataset import AlignedDataset
from util import Visualizer


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', default=r'E:\Database\RUIJIN_data\03_DATASET\mask_dataset', type=str, help='path/to/data')
    parser.add_argument('--EnhanceT', default=True, help='use an extra discriminator on the cropped tumor region')

    # training related
    parser.add_argument('--isTrain', default=True, help='training phase flag')
    parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
    parser.add_argument('--lr_policy', default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
    parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
    parser.add_argument('--niter', type=int, default=100, help='# of epochs at the starting learning rate')
    parser.add_argument('--niter_decay', type=int, default=100, help='# of epochs to linearly decay the learning rate to zero')
    parser.add_argument('--epoch', default='latest', help='which epoch/checkpoint to load')
    parser.add_argument('--gpu_ids', default=[0], help='which devices to use')

    parser.add_argument('--checkpoints_dir', default='./checkpoints', help='where models are saved')
    parser.add_argument('--name', default='experiment_name', help='name of the experiment')
    return parser.parse_args()


if __name__ == '__main__':
    opt = get_parser()
    dataset = AlignedDataset(opt.data_root, 'train')
    dataset_size = len(dataset)
    dataset = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=0)
    print('The number of training images = %d' % dataset_size)

    model = EnhanceGANModel(opt)
    model.setup(opt)

    visualizer = Visualizer(opt)
    total_iters = 0

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()  # timer for the entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # number of training iterations in the current epoch, reset to 0 every epoch

        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time()   # timer for computation per iteration
            if total_iters % 100 == 0:
                t_data = iter_start_time - iter_data_time
                visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)        # unpack data from the dataset and apply preprocessing
            model.optimize_parameters()  # calculate loss functions, get gradients, update network weights

            if total_iters % 400 == 0:   # display images on visdom and save images to an HTML file
                save_result = total_iters % 1000 == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            if total_iters % 100 == 0:   # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            if total_iters % 5000 == 0:  # cache the latest model every 5000 iterations
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                save_suffix = 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()
        if epoch % 5 == 0:               # cache the model every 5 epochs
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' % (
            epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import random
import os.path
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from util import *


IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


class AlignedDataset:
    """A dataset class for paired image datasets.

    It assumes that the directory '/path/to/data/train' contains image triplets in the form {A, B, M},
    stored side by side in a single image file.
    During test time, you need to prepare a directory '/path/to/data/test'.
    """

    def __init__(self, data_root, phase='train'):
        """Initialize this dataset class.

        Parameters:
            data_root (str) -- path to the dataset root directory
            phase (str)     -- 'train' or 'test'; selects the sub-directory to read from
        """
        super(AlignedDataset, self).__init__()
        self.data_root = data_root
        self.phase = phase

        self.load_size = 286
        self.crop_size = 256
        self.input_nc = 1
        self.output_nc = 3

        self.dir_ABM = os.path.join(self.data_root, self.phase)  # get the image directory
        self.ABM_paths = sorted(make_dataset(self.dir_ABM))      # get image paths

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns a dictionary that contains A, B, M and bbox
            A (tensor)   -- an image in the input domain
            B (tensor)   -- its corresponding image in the target domain
            M (tensor)   -- tumor area mask
            bbox (array) -- tumor bounding box (y0, x0, y1, x1)
        """
        # read an image given a random integer index
        ABM_path = self.ABM_paths[index]
        ABM = Image.open(ABM_path).convert('RGB')
        # split the ABM image into A, B and mask
        w, h = ABM.size
        w2 = int(w / 3)
        A = ABM.crop((0, 0, w2, h))  # left, upper, right, lower
        B = ABM.crop((w2, 0, 2 * w2, h))
        M = ABM.crop((2 * w2, 0, w, h))

        # apply the same transform to A, B and M
        transform_params = self.get_params(A.size)
        A_transform = self.get_transform(transform_params, grayscale=(self.input_nc == 1))
        B_transform = self.get_transform(transform_params, grayscale=(self.output_nc == 1))
        M_transform = self.get_transform(transform_params, grayscale=True)

        A = A_transform(A)
        B = B_transform(B)
        M = M_transform(M)

        # get bounding box coordinates
        M_num = M.numpy()  # values in (-1, 1)
        M_num = np.squeeze(M_num)
        logical_M = np.where(M_num < 0, 0, 1)
        bbox = extract_bboxes(logical_M)  # (y0, x0, y1, x1)

        return {'A': A, 'B': B, 'M': M, 'bbox': bbox}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.ABM_paths)

    # misc
    def get_params(self, size):
        w, h = size
        new_h = new_w = self.load_size

        x = random.randint(0, np.maximum(0, new_w - self.crop_size))
        y = random.randint(0, np.maximum(0, new_h - self.crop_size))

        flip = random.random() > 0.5

        return {'crop_pos': (x, y), 'flip': flip}

    def get_transform(self, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
        transform_list = []
        if grayscale:
            transform_list.append(transforms.Grayscale(1))

        osize = [self.load_size, self.load_size]
        transform_list.append(transforms.Resize(osize, method))

        if params is None:
            transform_list.append(transforms.RandomCrop(self.crop_size))
        else:
            transform_list.append(transforms.Lambda(lambda img: my_crop(img, params['crop_pos'], self.crop_size)))

        if params is None:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            transform_list.append(transforms.Lambda(lambda img: my_flip(img, params['flip'])))

        transform_list += [transforms.ToTensor()]
        if convert:
            if grayscale:
                transform_list += [transforms.Normalize((0.5,), (0.5,))]
            else:
                transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        return transforms.Compose(transform_list)


def my_make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img

    my_print_size_warning(ow, oh, w, h)
    return img.resize((w, h), method)


def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)


def my_crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if ow > tw or oh > th:
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def my_flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def my_print_size_warning(ow, oh, w, h):
    """Print a warning about the image size (only printed once)"""
    if not hasattr(my_print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        my_print_size_warning.has_printed = True
--------------------------------------------------------------------------------
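A quick smoke test of the dataset class, as a sketch assuming the triptych layout described above; the data path is a placeholder:

```python
# Assumes <path/to/data>/train contains {A, B, M} triptych images.
from dataset import AlignedDataset

ds = AlignedDataset('path/to/data', phase='train')
sample = ds[0]
print(sample['A'].shape)  # torch.Size([1, 256, 256]) -- grayscale B-mode input
print(sample['B'].shape)  # torch.Size([3, 256, 256]) -- RGB elastography target
print(sample['M'].shape)  # torch.Size([1, 256, 256]) -- tumor mask
print(sample['bbox'])     # (y0, x0, y1, x1) of the tumor region
```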
/util.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import sys
import ntpath
import time
import cv2
import torch
from PIL import Image
from subprocess import Popen, PIPE

if sys.version_info[0] == 2:
    VisdomExceptionBase = Exception
else:
    VisdomExceptionBase = ConnectionError


def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
    """Save images to the disk.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
        image_path (str)         -- the string is used to create image paths
        aspect_ratio (float)     -- the aspect ratio of saved images
        width (int)              -- the images will be resized to width x width

    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
    """
    image_dir = webpage.get_image_dir()
    short_path = ntpath.basename(image_path[0])
    name = os.path.splitext(short_path)[0]

    webpage.add_header(name)
    ims, txts, links = [], [], []

    for label, im_data in visuals.items():
        im = tensor2im(im_data)
        image_name = '%s_%s.png' % (name, label)
        save_path = os.path.join(image_dir, image_name)
        h, w, _ = im.shape
        save_image(im, save_path)

        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=width)
class Visualizer:
    """This class includes several functions that can display/save images and print/save logging information.

    It uses the Python library 'visdom' for display, and the Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
    """

    def __init__(self, opt):
        """Initialize the Visualizer class

        Parameters:
            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
        Step 1: Cache the training/test options
        Step 2: connect to a visdom server
        Step 3: create an HTML object for saving HTML files
        Step 4: create a logging file to store training losses
        """
        self.opt = opt  # cache the option
        self.display_id = 1
        self.win_size = 256
        self.name = opt.name
        self.port = 8097
        self.saved = False
        self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
        self.img_dir = os.path.join(self.web_dir, 'images')
        mkdirs([self.web_dir, self.img_dir])
        if self.display_id > 0:  # connect to a visdom server given <display_id > 0>
            import visdom
            self.ncols = 4
            self.vis = visdom.Visdom(server='http://localhost', port=8097, env='main')
            if not self.vis.check_connection():
                self.create_visdom_connections()

        # create a logging file to store training losses
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        """Reset the self.saved status"""
        self.saved = False

    def create_visdom_connections(self):
        """If the program could not connect to a Visdom server, this function will start a new server at port <self.port>"""
        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
        print('Command: %s' % cmd)
        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
    def display_current_results(self, visuals, epoch, save_result):
        """Display current results on visdom; save current results to an HTML file.

        Parameters:
            visuals (OrderedDict) -- dictionary of images to display or save
            epoch (int)           -- the current epoch
            save_result (bool)    -- if save the current results to an HTML file
        """
        if self.display_id > 0:  # show images in the browser using visdom
            ncols = self.ncols
            if ncols > 0:  # show all the images in one visdom panel
                ncols = min(ncols, len(visuals))
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)  # create a table css
                # create a table of images.
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = tensor2im(image)
                    if not image_numpy.shape[:2] == (h, w):
                        image_numpy = cv2.resize(image_numpy, (w, h))
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                try:
                    self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                    padding=2, opts=dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html, win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except VisdomExceptionBase:
                    self.create_visdom_connections()

            else:  # show each image in a separate visdom panel
                idx = 1
                try:
                    for label, image in visuals.items():
                        image_numpy = tensor2im(image)
                        self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                                       win=self.display_id + idx)
                        idx += 1
                except VisdomExceptionBase:
                    self.create_visdom_connections()

        # save images to the disk
        for label, image in visuals.items():
            image_numpy = tensor2im(image)
            img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
            save_image(image_numpy, img_path)

    def plot_current_losses(self, epoch, counter_ratio, losses):
        """Display the current losses on the visdom display: dictionary of error labels and values

        Parameters:
            epoch (int)           -- current epoch
            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
        """
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
                Y=np.array(self.plot_data['Y']),
                opts={
                    'title': self.name + ' loss over time',
                    'legend': self.plot_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id)
        except VisdomExceptionBase:
            self.create_visdom_connections()

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
        """Print current losses on the console; also save the losses to the disk

        Parameters:
            epoch (int)          -- current epoch
            iters (int)          -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float)       -- computational time per data point (normalized by batch_size)
            t_data (float)       -- data loading time per data point (normalized by batch_size)
        """
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)  # save the message
def extract_bboxes(mask, exp_ratio=1.0):
    """Compute a bounding box from a mask.
    mask: [height, width]. Mask pixels are either 1 or 0.
    exp_ratio: expand the side length of the rectangle by some ratio, >= 1.0.
    Returns: bbox array [y1, x1, y2, x2].
    """
    m = mask[:, :]
    # Bounding box.
    horizontal_indicies = np.where(np.any(m, axis=0))[0]
    vertical_indicies = np.where(np.any(m, axis=1))[0]
    if horizontal_indicies.shape[0]:
        x1, x2 = horizontal_indicies[[0, -1]]
        y1, y2 = vertical_indicies[[0, -1]]
        # x2 and y2 should not be part of the box. Increment by 1.
        x2 += 1
        y2 += 1
        if exp_ratio > 1.0:
            side_x = (x2 - x1 + 1) * exp_ratio
            side_y = (y2 - y1 + 1) * exp_ratio
            # clamp the expanded box to the mask extent (width: axis 1, height: axis 0)
            x1 = x1 - side_x / 2 if (x1 - side_x / 2) > 0 else 0
            x2 = x2 + side_x / 2 if (x2 + side_x / 2) < np.size(mask, 1) else np.size(mask, 1)
            y1 = y1 - side_y / 2 if (y1 - side_y / 2) > 0 else 0
            y2 = y2 + side_y / 2 if (y2 + side_y / 2) < np.size(mask, 0) else np.size(mask, 0)
    else:
        # No mask for this instance. Might happen due to
        # resizing or cropping. Set bbox to zeros.
        x1, x2, y1, y2 = 0, 0, 0, 0
    boxes = np.array([y1, x1, y2, x2])
    return boxes.astype(np.int32)


def tensor2im(input_image, imtype=np.uint8):
    """Convert a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)


def save_image(image_numpy, image_path):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str)          -- the path of the image
    """
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)


def mkdirs(paths):
    """Create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths
    """
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """Create a single empty directory if it doesn't exist

    Parameters:
        path (str) -- a single directory path
    """
    if not os.path.exists(path):
        os.makedirs(path)
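def _demo_tensor2im():
    """Illustrative sketch (added for documentation; not part of the original
    file): round-trip a fake network output through tensor2im."""
    t = torch.rand(1, 3, 8, 8) * 2 - 1  # network outputs live in (-1, 1) after Tanh
    img = tensor2im(t)                  # HxWx3 uint8 image in [0, 255]
    print(img.shape, img.dtype)         # (8, 8, 3) uint8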
def index_closest(pixel, bar):
    n = len(bar)        # bar:   =====================
    temp = []           # index: n         ^         0
    for i in range(n):  #                  |
        dis = sum(list(map(lambda x: abs(x[0] - x[1]), zip(pixel, bar[i]))))
        temp.append(dis)
    value = n - temp.index(min(temp))
    return value


def cal_color_vector(color_img, cluster):
    """Calculate a hard-coded strain ratio according to the cluster map"""
    color_bar = [[0, 0, 143], [0, 0, 159], [0, 0, 175], [0, 0, 191], [0, 0, 207], [0, 0, 223], [0, 0, 239], [0, 0, 255],
                 [0, 16, 255], [0, 32, 255], [0, 48, 255], [0, 64, 255], [0, 80, 255], [0, 96, 255], [0, 112, 255], [0, 128, 255],
                 [0, 143, 255], [0, 159, 255], [0, 175, 255], [0, 191, 255], [0, 207, 255], [0, 223, 255], [0, 239, 255], [0, 255, 255],
                 [16, 255, 239], [32, 255, 223], [48, 255, 207], [64, 255, 191], [80, 255, 175], [96, 255, 159], [112, 255, 143], [128, 255, 128],
                 [143, 255, 112], [159, 255, 96], [175, 255, 80], [191, 255, 64], [207, 255, 48], [223, 255, 32], [239, 255, 16], [255, 255, 0],
                 [255, 239, 0], [255, 223, 0], [255, 207, 0], [255, 191, 0], [255, 175, 0], [255, 159, 0], [255, 143, 0], [255, 128, 0],
                 [255, 112, 0], [255, 96, 0], [255, 80, 0], [255, 64, 0], [255, 48, 0], [255, 32, 0], [255, 16, 0], [255, 0, 0],
                 [239, 0, 0], [223, 0, 0], [207, 0, 0], [191, 0, 0], [175, 0, 0], [159, 0, 0], [143, 0, 0], [128, 0, 0]]
    color_img = np.squeeze(color_img)
    cluster = np.squeeze(cluster)  # shape (1, 1, 256, 256) -> (256, 256)
    num = torch.max(cluster) + 1   # number of clusters
    vec = []
    for i in range(num):
        values = 0
        XYs = torch.nonzero(cluster == i)  # Nx2 tensor of (x, y) indices
        n = XYs.size()[0]                  # the number of nonzero values
        if n > 0:
            for j in range(n):
                pixel = color_img[:, XYs[j, 0], XYs[j, 1]]
                value = index_closest(pixel, color_bar)
                values += value
            vec.append(values / n)
        else:
            vec.append(0)
    return vec
--------------------------------------------------------------------------------
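A small worked example of `extract_bboxes` above, added as an illustration (not part of the repository):

```python
import numpy as np
from util import extract_bboxes

mask = np.zeros((8, 8), dtype=np.int32)
mask[2:5, 3:7] = 1           # a 3x4 rectangular "tumor"
print(extract_bboxes(mask))  # [2 3 5 7] -> (y1, x1, y2, x2), end-exclusive
```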
/networks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler


class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in the UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct the unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with a skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int)                      -- the number of filters in the outer conv layer
            inner_nc (int)                      -- the number of filters in the inner conv layer
            input_nc (int)                      -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)                    -- if this module is the outermost module
            innermost (bool)                    -- if this module is the innermost module
            norm_layer                          -- normalization layer
            use_dropout (bool)                  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int)      -- the number of filters in the last conv layer
            n_layers (int) -- the number of conv layers in the discriminator
            norm_layer     -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)  # n = 1, 2 -> nf_mult = 2, 4
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output a 1-channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
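def _demo_patchgan_shape():
    """Illustrative sketch (added for documentation; not part of the original
    file): with input_nc = 1 + 3 (the conditional pair A and B concatenated,
    as in model.py), a 3-layer PatchGAN maps 256x256 inputs to a 30x30 patch
    prediction map."""
    netD = NLayerDiscriminator(1 + 3, ndf=64, n_layers=3)
    pred = netD(torch.randn(1, 4, 256, 256))
    print(pred.shape)  # torch.Size([1, 1, 30, 30])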
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str)            -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) -- label for a real image
            target_fake_label (float) -- label for a fake image

        Note: Do not use sigmoid as the last layer of the Discriminator.
        LSGAN needs no sigmoid. Vanilla GANs handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor)   -- typically the prediction from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with the ground truth label, with the size of the input
        """

        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate the loss given the Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor)   -- typically the prediction output from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss
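def _demo_gan_loss():
    """Illustrative sketch (added for documentation; not part of the original
    file): GANLoss builds the all-real / all-fake label map internally, so the
    caller only passes raw (pre-sigmoid) patch logits and a boolean."""
    criterion = GANLoss('vanilla')
    pred = torch.randn(1, 1, 30, 30)    # raw patch logits from NLayerDiscriminator
    loss_real = criterion(pred, True)   # compare against an all-ones label map
    loss_fake = criterion(pred, False)  # compare against an all-zeros label map
    print(loss_real.item(), loss_fake.item())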
# VGG architecture used for the identity-preserving loss.
# NOTE: the VGG class used below (VGG('VGG13', num_classes=4)) and the
# checkpoint './models/ckpt.pth' are not included in this repository and must
# be provided separately.
class VGG_(nn.Module):
    def __init__(self, requires_grad=False):
        super().__init__()
        net = VGG('VGG13', num_classes=4)
        checkpoint = torch.load('./models/ckpt.pth')
        for key in list(checkpoint['net'].keys()):
            key_ = key[key.index('.') + 1:]
            checkpoint['net'][key_] = checkpoint['net'].pop(key)
        net.load_state_dict(checkpoint['net'])
        vgg_pretrained_features = net.features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


# VGG identity-preserving loss definition
class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = VGG_().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """
        x: ground truth (tensor)
        y: fake images (tensor)
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss


# misc
def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of the learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.niter> epochs
    and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
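def _demo_linear_lr_lambda():
    """Illustrative sketch (added for documentation; not part of the original
    file): with the train.py defaults (epoch_count=1, niter=100,
    niter_decay=100), the 'linear' policy keeps the multiplier at 1.0 through
    epoch ~100 and then decays it linearly towards zero by epoch 200."""
    epoch_count, niter, niter_decay = 1, 100, 100

    def lambda_rule(epoch):
        return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)

    print(lambda_rule(50))   # 1.0    (constant phase)
    print(lambda_rule(150))  # ~0.495 (halfway through the decay phase)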
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net
--------------------------------------------------------------------------------
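A quick shape check for the generator as configured in `model.py`, added as an illustrative sketch (not part of the repository):

```python
import torch
from networks import UnetGenerator

# num_downs=8 halves 256 -> 1 at the bottleneck; 1-channel B-mode in, 3-channel color out.
netG = UnetGenerator(1, 3, 8, ngf=64, use_dropout=True)
fake_B = netG(torch.randn(1, 1, 256, 256))
print(fake_B.shape)  # torch.Size([1, 3, 256, 256]), values in (-1, 1) from the final Tanh
```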
/model.py:
--------------------------------------------------------------------------------
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from networks import *
from util import *


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from the dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>.
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define the networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict) -- includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain:
            load_suffix = opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(True)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop.
        It also calls <compute_visuals> to produce additional visualization results.
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """Return image paths that are used to load the current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save them to an HTML file"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print these on the console and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret
    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoint incompatibility (prior to PyTorch 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in the loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) the network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)  -- a list of networks
            requires_grad (bool) -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
class EnhanceGANModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        For pix2pix, we do not use an image buffer.
        The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
        By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
        """
        # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
        parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
            parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')

        return parser

    def __init__(self, opt):
        """Initialize the pix2pix-style class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['GA_GAN', 'GT_GAN', 'G_L1', 'D_A', 'D_T']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['real_A', 'fake_B', 'real_B', 'real_AT', 'real_BT', 'fake_BT']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        if self.isTrain:
            self.model_names = ['G', 'DA', 'DT']
        else:  # during test time, only load G
            self.model_names = ['G']
        # define networks (both generator and discriminators)
        self.netG = UnetGenerator(1, 3, 8, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=True)
        self.netG = init_net(self.netG, gpu_ids=opt.gpu_ids)

        if self.isTrain:  # define discriminators; conditional GANs need to take both input and output images; therefore, #channels for D is input_nc + output_nc
            self.netDA = NLayerDiscriminator(1 + 3, 64, 3, norm_layer=nn.BatchNorm2d)
            self.netDA = init_net(self.netDA, gpu_ids=opt.gpu_ids)

            if self.opt.EnhanceT:
                self.netDT = NLayerDiscriminator(1 + 3, 64, 3, norm_layer=nn.BatchNorm2d)
                self.netDT = init_net(self.netDT, gpu_ids=opt.gpu_ids)

        if self.isTrain:
            # define loss functions
            self.criterionGAN = GANLoss('vanilla').to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(0.5, 0.999))
            self.optimizer_DA = torch.optim.Adam(self.netDA.parameters(), lr=opt.lr, betas=(0.5, 0.999))
            if self.opt.EnhanceT:
                self.optimizer_DT = torch.optim.Adam(self.netDT.parameters(), lr=opt.lr, betas=(0.5, 0.999))
                self.optimizers.append(self.optimizer_DT)
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_DA)
        # default backup dicts to speed up the training process
        self.real_backup = dict()
        self.fake_backup = dict()
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict) -- includes the data itself and its metadata information.

        The option 'direction' can be used to swap images in domain A and domain B.
        """
        AtoB = True
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.mask = input['M']
        self.bbox = input['bbox'][0]  # (y0, x0, y1, x1)

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG(self.real_A)  # G(A)
        # crop the tumor area
        self.real_AT = self.real_A[:, :, self.bbox[0]: self.bbox[2], self.bbox[1]: self.bbox[3]]
        self.real_BT = self.real_B[:, :, self.bbox[0]: self.bbox[2], self.bbox[1]: self.bbox[3]]
        self.fake_BT = self.fake_B[:, :, self.bbox[0]: self.bbox[2], self.bbox[1]: self.bbox[3]]

    def backward_D_basic(self, netD, realA, realB, fakeB):
        """A generic discriminator backward pass"""
        # Real (we use a conditional GAN)
        real_AB = torch.cat((realA, realB), 1)
        pred_real = netD.forward(real_AB)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        fake_AB = torch.cat((realA, fakeB), 1)
        pred_fake = netD.forward(fake_AB.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        return loss_D

    def backward_DA(self):
        self.loss_D_A = self.backward_D_basic(self.netDA, self.real_A, self.real_B, self.fake_B)
        self.loss_D_A.backward()

    def backward_DT(self):
        self.loss_D_T = self.backward_D_basic(self.netDT, self.real_AT, self.real_BT, self.fake_BT)
        self.loss_D_T.backward()

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # First, G(A) should fool the whole-image discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake_A = self.netDA(fake_AB)
        self.loss_GA_GAN = self.criterionGAN(pred_fake_A, True)
        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * 100  # lambda_L1 = 100
        # combine losses and calculate gradients
        self.loss_G = self.loss_GA_GAN + self.loss_G_L1
        # Tumor enhancement: G(A) should also fool the tumor-region discriminator
        if self.opt.EnhanceT:
            fake_ABT = torch.cat((self.real_AT, self.fake_BT), 1)
            pred_fake_AT = self.netDT(fake_ABT)
            self.loss_GT_GAN = self.criterionGAN(pred_fake_AT, True)
            self.loss_G = self.loss_G + self.loss_GT_GAN

        self.loss_G.backward()
    def optimize_parameters(self):
        self.forward()                             # compute fake images: G(A)
        # update DA
        self.set_requires_grad(self.netDA, True)   # enable backprop for DA
        self.optimizer_DA.zero_grad()              # set DA's gradients to zero
        self.backward_DA()                         # calculate gradients for DA
        self.optimizer_DA.step()                   # update DA's weights
        if self.opt.EnhanceT:
            # update DT
            self.set_requires_grad(self.netDT, True)  # enable backprop for DT
            self.optimizer_DT.zero_grad()             # set DT's gradients to zero
            self.backward_DT()                        # calculate gradients for DT
            self.optimizer_DT.step()                  # update DT's weights
        # update G
        self.set_requires_grad(self.netDA, False)  # the discriminators require no gradients when optimizing G
        if self.opt.EnhanceT:
            self.set_requires_grad(self.netDT, False)
        self.optimizer_G.zero_grad()               # set G's gradients to zero
        self.backward_G()                          # calculate gradients for G
        self.optimizer_G.step()                    # update G's weights
--------------------------------------------------------------------------------
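The repository ships no test script. Below is a minimal test-time sketch added for illustration (not part of the repository): it loads a trained generator checkpoint, named as in `BaseModel.save_networks`, and translates a single B-mode image; the file paths are placeholders.

```python
import torch
import torchvision.transforms as transforms
from PIL import Image
from networks import UnetGenerator
from util import tensor2im, save_image

netG = UnetGenerator(1, 3, 8, ngf=64, use_dropout=True)
state = torch.load('./checkpoints/experiment_name/latest_net_G.pth', map_location='cpu')
netG.load_state_dict(state)
netG.eval()

tf = transforms.Compose([
    transforms.Grayscale(1),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
A = tf(Image.open('bmode.png')).unsqueeze(0)  # 1x1x256x256 in (-1, 1)
with torch.no_grad():
    fake_B = netG(A)                          # 1x3x256x256 virtual elastography
save_image(tensor2im(fake_B), 'virtual_elasto.png')
```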