├── README.md ├── adversarials.py ├── dag.py ├── dag_utils.py ├── data └── samples ├── dataset.py ├── loss.py ├── model ├── DenseNet.py ├── SegNet.py ├── UNet.py └── __init__.py ├── test.py ├── train.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Segmentation and Adversarial Attacks 2 | 3 | This repository implements segmentation models and adversarial attacks on semantic segmentation in PyTorch. The main algorithms are referenced from "Generalizability vs. Robustness: Adversarial Examples for Medical Imaging" by Paschali, M., Conjeti, S., Navarro, F., & Navab, N. at MICCAI 2018. 4 | 5 | There are three segmentation models: UNet, SegNet, and DenseNet. There are also three types of dense adversarial generation: Type A (the target is all background), Type B (the targets are the top-3 most frequent labels), and Type C (a single random target class). 6 | 7 | 8 | Segmentation models 9 | 10 | - UNet : [Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/pdf/1505.04597.pdf) 11 | - SegNet : [A Deep Convolutional Encoder-Decoder 12 | Architecture for Image Segmentation](https://arxiv.org/pdf/1511.00561.pdf) 13 | - DenseNet : [The One Hundred Layers Tiramisu: 14 | Fully Convolutional DenseNets for Semantic Segmentation](https://arxiv.org/pdf/1611.09326.pdf) 15 | 16 | Adversarial attacks for semantic segmentation DNNs 17 | 18 | - Dense Adversarial Generation : [Adversarial examples for semantic segmentation and object detection](https://arxiv.org/pdf/1703.08603.pdf) 19 | 20 | ## Usage 21 | 22 | `train.py` : train segmentation models 23 | 24 | `test.py` : test data with trained models 25 | 26 | `adversarials.py` : generate adversarial examples based on segmentation models 27 | 28 | A simple example: 29 | 30 | ``` 31 | python train.py --model UNet 32 | ``` 33 | 34 | You can also use multiple GPUs to train models. 35 | 36 | ``` 37 | python train.py --model UNet --device1 0 --device2 1 --device3 2 38 | ``` 39 | 40 | You can see more detailed arguments with: 41 | 42 | ``` 43 | python train.py -h 44 | ```
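To generate adversarial examples from a trained model, pass the checkpoint saved by `train.py` (the path below matches its default `checkpoints/` output):

```
python adversarials.py --model UNet --model_path checkpoints/UNet.pth --attacks DAG_A
```

`--attacks` accepts `Rician`, `DAG_A`, `DAG_B`, or `DAG_C`; the resulting examples are pickled under `data/` by default.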
45 | -------------------------------------------------------------------------------- /adversarials.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision import transforms 4 | from torch.utils.data import DataLoader 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchsummary import summary 8 | import pickle 9 | import random 10 | import sys 11 | import os 12 | 13 | from model import UNet, SegNet, DenseNet 14 | 15 | from dataset import SampleDataset 16 | from scipy.stats import rice 17 | from skimage.measure import compare_ssim as ssim 18 | from dag import DAG 19 | from dag_utils import generate_target, generate_target_swap 20 | from util import make_one_hot 21 | 22 | from optparse import OptionParser 23 | 24 | BATCH_SIZE = 10 25 | 26 | 27 | def get_args(): 28 | 29 | parser = OptionParser() 30 | parser.add_option('--data_path', dest='data_path',type='string', 31 | default='data/samples', help='data path') 32 | parser.add_option('--attack_path', dest='attack_path',type='string', 33 | default=None, help='the path of adversarial attack examples') 34 | parser.add_option('--model_path', dest='model_path',type='string', 35 | help='model_path') 36 | parser.add_option('--classes', dest='classes', default=28, type='int', 37 | help='number of classes') 38 | parser.add_option('--channels', dest='channels', default=1, type='int', 39 | help='number of channels') 40 | parser.add_option('--width', dest='width', default=256, type='int', 41 | help='image width') 42 | parser.add_option('--height', dest='height', default=256, type='int', 43 | help='image height') 44 | parser.add_option('--model', dest='model', type='string', 45 | help='model name(UNet, SegNet, DenseNet)') 46 | parser.add_option('--attacks', dest='attacks', type='string', 47 | help='attack types: Rician, DAG_A, DAG_B, DAG_C') 48 | parser.add_option('--gpu', dest='gpu',type='string', 49 | default='gpu', help='gpu or cpu') 50 | parser.add_option('--device1', dest='device1', default=0, type='int', 51 | help='device1 index number') 52 | parser.add_option('--device2', dest='device2', default=-1, type='int', 53 | help='device2 index number') 54 | parser.add_option('--device3', dest='device3', default=-1, type='int', 55 | help='device3 index number') 56 | parser.add_option('--device4', dest='device4', default=-1, type='int', 57 | help='device4 index number') 58 | 59 | (options, args) = parser.parse_args() 60 | return options 61 | 62 | def load_data(args): 63 | 64 | data_path = args.data_path 65 | n_classes = args.classes 66 | data_width = args.width 67 | data_height = args.height 68 | 69 | # generate loader 70 | test_dataset = SampleDataset(data_path) 71 | 72 | test_loader = DataLoader( 73 | test_dataset, 74 | batch_size=BATCH_SIZE, 75 | num_workers=4, 76 | ) 77 | 78 | print('test_dataset : {}, test_loader : {}'.format(len(test_dataset), len(test_loader))) 79 | 80 | 81 | return test_dataset, test_loader 82 | 83 | # generate Rician noise examples 84 | # Measure the difference between the original and adversarial examples using structural similarity (SSIM). 85 | # Only adversarial examples with an SSIM value between 0.97 and 0.99 are kept.
86 | # SSIM adapted from https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html 87 | def Rician(test_dataset): 88 | 89 | def generate_rician(image): 90 | 91 | ssim_noise = 0 92 | 93 | if torch.is_tensor(image): 94 | image = image.numpy() 95 | 96 | rician_image = np.zeros_like(image) 97 | 98 | while ssim_noise <= 0.97 or ssim_noise >= 0.99: 99 | b = random.uniform(0, 1) 100 | rv = rice(b) 101 | rician_image = rv.pdf(image) 102 | ssim_noise = ssim(image[0], rician_image[0], data_range=rician_image[0].max() - rician_image[0].min()) 103 | 104 | #print('ssim : {:.2f}'.format(ssim_noise)) 105 | 106 | return rician_image 107 | 108 | adversarial_examples = [] 109 | 110 | for batch_idx in range(len(test_dataset)): 111 | 112 | image, labels = test_dataset.__getitem__(batch_idx) 113 | 114 | rician_image = generate_rician(image) 115 | 116 | #print("image {} save".format(batch_idx)) 117 | 118 | adversarial_examples.append([rician_image[0],labels.squeeze(0).numpy()]) 119 | 120 | print('total {} Rician noise images are generated'.format(len(adversarial_examples))) 121 | 122 | return adversarial_examples 123 | 124 | 125 | def DAG_Attack(model, test_dataset, args): 126 | 127 | # Hyperparameters for DAG 128 | 129 | num_iterations=20 130 | gamma=0.5 131 | num=15 132 | 133 | gpu = args.gpu 134 | 135 | # set device configuration 136 | device_ids = [] 137 | 138 | if gpu == 'gpu' : 139 | 140 | if not torch.cuda.is_available() : 141 | print("No cuda available") 142 | raise SystemExit 143 | 144 | device = torch.device(args.device1) 145 | 146 | device_ids.append(args.device1) 147 | 148 | if args.device2 != -1 : 149 | device_ids.append(args.device2) 150 | 151 | if args.device3 != -1 : 152 | device_ids.append(args.device3) 153 | 154 | if args.device4 != -1 : 155 | device_ids.append(args.device4) 156 | 157 | 158 | else : 159 | device = torch.device("cpu") 160 | 161 | if len(device_ids) > 1: 162 | model = nn.DataParallel(model, device_ids = device_ids) 163 | 164 | model = model.to(device) 165 | 166 | adversarial_examples = [] 167 | 168 | for batch_idx in range(len(test_dataset)): 169 | image, label = test_dataset.__getitem__(batch_idx) 170 | 171 | image = image.unsqueeze(0) 172 | pure_label = label.squeeze(0).numpy() 173 | 174 | image , label = image.clone().detach().requires_grad_(True).float(), label.clone().detach().float() 175 | image , label = image.to(device), label.to(device) 176 | 177 | # Change labels from [batch_size, height, width] to [batch_size, num_classes, height, width] 178 | label_oh=make_one_hot(label.long(), args.classes, device) # use args.classes instead of relying on a module-level global 179 | 180 | if args.attacks == 'DAG_A': 181 | 182 | adv_target = torch.zeros_like(label_oh) 183 | 184 | elif args.attacks == 'DAG_B': 185 | 186 | adv_target=generate_target_swap(label_oh.cpu().numpy()) 187 | adv_target=torch.from_numpy(adv_target).float() 188 | 189 | elif args.attacks == 'DAG_C': 190 | 191 | # choose one random class, excluding the background class (0) 192 | unique_label = torch.unique(label) 193 | target_class = int(random.choice(unique_label[1:]).item()) 194 | 195 | adv_target=generate_target(label_oh.cpu().numpy(), target_class = target_class) 196 | adv_target=torch.from_numpy(adv_target).float() 197 | 198 | else : 199 | print("wrong adversarial attack type : must be DAG_A, DAG_B, or DAG_C") 200 | raise SystemExit 201 | 202 | 203 | adv_target=adv_target.to(device) 204 | 205 | _, _, _, _, _, image_iteration=DAG(model=model, 206 | image=image, 207 | ground_truth=label_oh, 208 | adv_target=adv_target, 209 | num_iterations=num_iterations, 210 |
gamma=gamma, 211 | no_background=True, 212 | background_class=0, 213 | device=device, 214 | verbose=False) 215 | 216 | if len(image_iteration) >= 1: 217 | 218 | adversarial_examples.append([image_iteration[-1], 219 | pure_label]) 220 | 221 | del image_iteration 222 | 223 | print('total {} {} images are generated'.format(len(adversarial_examples), args.attacks)) 224 | 225 | return adversarial_examples 226 | 227 | if __name__ == "__main__": 228 | 229 | args = get_args() 230 | 231 | n_channels = args.channels 232 | n_classes = args.classes 233 | 234 | test_dataset, test_loader = load_data(args) 235 | 236 | if args.attacks == 'Rician': 237 | 238 | adversarial_examples = Rician(test_dataset) 239 | 240 | if args.attack_path is None: 241 | 242 | adversarial_path = 'data/' + args.attacks + '.pickle' 243 | 244 | else: 245 | 246 | adversarial_path = args.attack_path 247 | 248 | else: 249 | 250 | model = None 251 | 252 | if args.model == 'UNet': 253 | model = UNet(in_channels = n_channels, n_classes = n_classes) 254 | 255 | elif args.model == 'SegNet': 256 | model = SegNet(in_channels = n_channels, n_classes = n_classes) 257 | 258 | elif args.model == 'DenseNet': 259 | model = DenseNet(in_channels = n_channels, n_classes = n_classes) 260 | 261 | else : 262 | print("wrong model : must be UNet, SegNet, or DenseNet") 263 | raise SystemExit 264 | 265 | summary(model, input_size=(n_channels, args.height, args.width), device = 'cpu') 266 | 267 | model.load_state_dict(torch.load(args.model_path)) 268 | 269 | adversarial_examples = DAG_Attack(model, test_dataset, args) 270 | 271 | if args.attack_path is None: 272 | 273 | adversarial_path = 'data/' + args.model + '_' + args.attacks + '.pickle' 274 | 275 | else: 276 | adversarial_path = args.attack_path 277 | 278 | # save adversarial examples as [adversarial example, label] pairs 279 | with open(adversarial_path, 'wb') as fp: 280 | pickle.dump(adversarial_examples, fp) 281 | 282 | -------------------------------------------------------------------------------- /dag.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Function for Dense Adversarial Generation 3 | Adversarial Examples for Semantic Segmentation 4 | Muhammad Ferjad Naeem 5 | ferjad.naeem@tum.de 6 | adapted from https://github.com/IFL-CAMP/dense_adversarial_generation_pytorch 7 | ''' 8 | import torch 9 | from util import make_one_hot 10 | 11 | 12 | def DAG(model,image,ground_truth,adv_target,num_iterations=20,gamma=0.07,no_background=True,background_class=0,device='cuda:0',verbose=False): 13 | ''' 14 | Generates an adversarial example for a given image 15 | 16 | Parameters 17 | ---------- 18 | model: Torch Model 19 | image: Torch tensor of dtype=float. Requires gradient. [b*c*h*w] 20 | ground_truth: Torch tensor of labels as one hot vector per class 21 | adv_target: Torch tensor of dtype=float. These are the perturbed labels. [b*classes*h*w] 22 | num_iterations: Number of iterations for the algorithm 23 | gamma: epsilon value, the maximum change possible per step 24 | no_background: If True, does not perturb the background class 25 | background_class: The index of the background class. Used to filter background 26 | device: Device to perform the computations on 27 | verbose: Bool.
If true, prints the amount of change and the number of values changed in each iteration 28 | Returns 29 | ------- 30 | image: the adversarial image as a torch tensor 31 | logits: logits of the clean input image as a torch tensor 32 | noise_total: List of total noise added per iteration as numpy array 33 | noise_iteration: List of noise added per iteration as numpy array 34 | prediction_iteration: List of prediction per iteration as numpy array 35 | image_iteration: List of image per iteration as numpy array 36 | 37 | ''' 38 | 39 | noise_total=[] 40 | noise_iteration=[] 41 | prediction_iteration=[] 42 | image_iteration=[] 43 | background=None 44 | logits=model(image) 45 | orig_image=image 46 | _,predictions_orig=torch.max(logits,1) 47 | predictions_orig=make_one_hot(predictions_orig,logits.shape[1],device) 48 | 49 | if(no_background): 50 | background=torch.zeros(logits.shape) 51 | background[:,background_class,:,:]=torch.ones((background.shape[2],background.shape[3])) 52 | background=background.to(device) 53 | 54 | for a in range(num_iterations): 55 | output=model(image) 56 | _,predictions=torch.max(output,1) 57 | prediction_iteration.append(predictions[0].cpu().numpy()) 58 | predictions=make_one_hot(predictions,logits.shape[1],device) 59 | 60 | condition1=torch.eq(predictions,ground_truth) 61 | condition=condition1 62 | 63 | if no_background: 64 | condition2=(ground_truth!=background) 65 | condition=torch.mul(condition1,condition2) 66 | condition=condition.float() 67 | 68 | if(condition.sum()==0): 69 | print("Condition Reached") 70 | image=None 71 | break 72 | 73 | #Finding pixels to perturb 74 | adv_log=torch.mul(output,adv_target) 75 | #Getting the values of the original output 76 | clean_log=torch.mul(output,ground_truth) 77 | 78 | #Finding r_m 79 | adv_direction=adv_log-clean_log 80 | r_m=torch.mul(adv_direction,condition) 81 | r_m.requires_grad_() 82 | #Summation 83 | r_m_sum=r_m.sum() 84 | r_m_sum.requires_grad_() 85 | #Finding gradient with respect to image 86 | r_m_grad=torch.autograd.grad(r_m_sum,image,retain_graph=True) 87 | #Saving gradient for calculation 88 | r_m_grad_calc=r_m_grad[0] 89 | 90 | #Calculating Magnitude of the gradient 91 | r_m_grad_mag=r_m_grad_calc.norm() 92 | 93 | if(r_m_grad_mag==0): 94 | print("Condition Reached, no gradient") 95 | #image=None 96 | break 97 | #Calculating final value of r_m 98 | r_m_norm=(gamma/r_m_grad_mag)*r_m_grad_calc 99 | 100 | #if no_background: 101 | #if False: 102 | if no_background is False: 103 | condition_image=condition.sum(dim=1) 104 | condition_image=condition_image.unsqueeze(1) 105 | r_m_norm=torch.mul(r_m_norm,condition_image) 106 | 107 | #Updating the image 108 | #print("r_m_norm : ",torch.unique(r_m_norm)) 109 | image=torch.clamp((image+r_m_norm),0,1) 110 | image_iteration.append(image[0][0].detach().cpu().numpy()) 111 | noise_total.append((image-orig_image)[0][0].detach().cpu().numpy()) 112 | noise_iteration.append(r_m_norm[0][0].cpu().numpy()) 113 | 114 | if verbose: 115 | print("Iteration ",a) 116 | print("Change to the image is ",r_m_norm.sum()) 117 | print("Magnitude of grad is ",r_m_grad_mag) 118 | print("Condition 1 ",condition1.sum()) 119 | if no_background: 120 | print("Condition 2 ",condition2.sum()) 121 | print("Condition is", condition.sum()) 122 | 123 | return image, logits, noise_total, noise_iteration, prediction_iteration, image_iteration
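# --- Editorial usage sketch (not part of the original file) ---------------------
# DAG() is driven from adversarials.py roughly as follows; `model`, `image`
# (a float tensor with requires_grad=True), `label_oh`, and `adv_target` are
# assumed to be prepared as in DAG_Attack:
#
#   _, _, _, _, _, image_iteration = DAG(model=model, image=image,
#       ground_truth=label_oh, adv_target=adv_target, num_iterations=20,
#       gamma=0.5, no_background=True, background_class=0, device=device,
#       verbose=False)
#   # image_iteration[-1] is the final adversarial image, provided at least one
#   # iteration ran before a stopping condition was reached.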
-------------------------------------------------------------------------------- /dag_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import scipy.misc as smp 5 | import scipy.ndimage 6 | from random import randint 7 | import random 8 | 9 | def ensure_dir(path): 10 | if not os.path.exists(path): 11 | os.makedirs(path) 12 | 13 | def generate_target(y_test, target_class = 13, width = 256, height = 256): 14 | 15 | y_target = np.copy(y_test) # work on a copy so the caller's one-hot labels are not modified in place 16 | 17 | dilated_image = scipy.ndimage.binary_dilation(y_target[0, target_class, :, :], iterations=6).astype(y_test.dtype) 18 | 19 | for i in range(width): 20 | for j in range(height): 21 | y_target[0, target_class, i, j] = dilated_image[i,j] 22 | 23 | for i in range(width): 24 | for j in range(height): 25 | count = np.count_nonzero(y_target[0,:,i,j]) 26 | if (count > 1): 27 | x = np.where(y_target[0, : ,i, j] > 0) 28 | k = x[0] 29 | #print("{}, {}, {}".format(i,j,k)) 30 | if k[0] == target_class: 31 | y_target[0,k[1],i,j] = 0. 32 | else: 33 | y_target[0, k[0], i, j] = 0. 34 | 35 | return y_target 36 | 37 | def generate_target_swap(y_test): 38 | 39 | 40 | y_target = np.copy(y_test) # work on a copy so the caller's one-hot labels are not modified in place 41 | 42 | y_target_arg = np.argmax(y_test, axis = 1) 43 | 44 | y_target_arg_no_back = np.where(y_target_arg>0) 45 | 46 | y_target_arg = y_target_arg[y_target_arg_no_back] 47 | 48 | classes = np.unique(y_target_arg) 49 | 50 | if len(classes) > 3: 51 | 52 | first_class = 0 53 | 54 | second_class = 0 55 | 56 | third_class = 0 57 | 58 | while first_class == second_class == third_class: 59 | first_class = classes[randint(0, len(classes)-1)] 60 | f_ind = np.where(y_target_arg==first_class) 61 | #print(np.shape(f_ind)) 62 | 63 | second_class = classes[randint(0, len(classes)-1)] 64 | s_ind = np.where(y_target_arg == second_class) 65 | 66 | third_class = classes[randint(0, len(classes) - 1)] 67 | t_ind = np.where(y_target_arg == third_class) 68 | 69 | summ = np.shape(f_ind)[1] + np.shape(s_ind)[1] + np.shape(t_ind)[1] 70 | 71 | if summ < 1000: # require enough pixels in the chosen classes, otherwise resample 72 | first_class = 0 73 | 74 | second_class = 0 75 | 76 | third_class = 0 77 | 78 | for i in range(256): 79 | for j in range(256): 80 | temp = y_target[0,second_class, i,j] 81 | y_target[0,second_class, i,j] = y_target[0,first_class,i,j] 82 | y_target[0, first_class,i, j] = temp 83 | 84 | 85 | else: 86 | y_target = y_test 87 | print('Not enough classes to swap!') 88 | return y_target 89 | -------------------------------------------------------------------------------- /data/samples: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sky4689524/Pytorch_AdversarialAttacks/5fe7f050341d2afb4834dbab28f20cd80cd85317/data/samples -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | from torchvision import transforms 3 | import numpy as np 4 | import os.path 5 | import pickle 6 | import torch 7 | 8 | class SampleDataset(Dataset): 9 | 10 | def __init__(self, root_dir): 11 | 12 | self.data_path = root_dir 13 | 14 | self.images = [] 15 | self.labels = [] 16 | 17 | # data format : a list of [image, label] pairs 18 | with open (self.data_path, 'rb') as fp: 19 | data = pickle.load(fp) 20 | 21 | for i in range(len(data)): 22 | 23 | self.images.append(data[i][0]) 24 | self.labels.append(data[i][1]) 25 | 26 | 27 | def __len__(self): 28 | return len(self.labels) 29 | 30 | def __getitem__(self, index): 31 | image = self.images[index] 32 | labels = self.labels[index] 33 | 34 | torch_transform = transforms.Compose([ 35 | transforms.ToTensor() 36 | ]) 37 | 38 | image = torch_transform(image) 39 | labels = torch_transform(labels) 40 | 41 | return (image, labels)
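# --- Editorial note (not part of the original file) -----------------------------
# SampleDataset expects `root_dir` to be a pickle file containing a list of
# [image, label] pairs of numpy arrays -- the same layout adversarials.py writes:
#
#   import pickle
#   with open('data/samples', 'rb') as fp:
#       data = pickle.load(fp)   # data[i][0] -> image array, data[i][1] -> label array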
-------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use a composite loss of weighted cross-entropy and Dice loss, as proposed in https://arxiv.org/pdf/1801.04161.pdf 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import sys, os.path 9 | from util import estimate_weights, make_one_hot 10 | 11 | def dice_score(pred, encoded_target): 12 | """ 13 | :param pred : N x C x H x W logits 14 | :param encoded_target : N x C x H x W LongTensor 15 | """ 16 | 17 | output = F.softmax(pred, dim = 1) 18 | 19 | eps = 1 20 | 21 | intersection = output * encoded_target 22 | numerator = 2 * intersection.sum(0).sum(1).sum(1) + eps 23 | denominator = output + encoded_target 24 | denominator = denominator.sum(0).sum(1).sum(1) + eps 25 | 26 | loss_per_channel = numerator / denominator 27 | 28 | score = loss_per_channel.sum() / output.size(1) 29 | 30 | del output, encoded_target 31 | 32 | return score.mean() 33 | 34 | 35 | def dice_loss(pred, encoded_target): 36 | """ 37 | :param pred : N x C x H x W logits 38 | :param encoded_target : N x C x H x W LongTensor 39 | """ 40 | 41 | output = F.softmax(pred, dim = 1) 42 | 43 | eps = 1 44 | 45 | intersection = output * encoded_target 46 | numerator = 2 * intersection.sum(0).sum(1).sum(1) + eps 47 | denominator = output + encoded_target 48 | denominator = denominator.sum(0).sum(1).sum(1) + eps 49 | 50 | loss_per_channel = 1 - (numerator / denominator) 51 | 52 | loss = loss_per_channel.sum() / output.size(1) 53 | del output, encoded_target 54 | 55 | return loss.mean() 56 | 57 | 58 | def cross_entropy_loss(pred, target, weight): 59 | """ 60 | :param pred : N x C x H x W 61 | :param target : N x H x W 62 | :param weight : N x H x W 63 | 64 | """ 65 | 66 | loss_func = nn.CrossEntropyLoss(reduction='none') # reduction='none' keeps the per-pixel loss map so the weights are applied element-wise below 67 | 68 | loss = loss_func(pred, target) 69 | 70 | return torch.mean(torch.mul(loss, weight)) 71 | 72 | def combined_loss(pred, target, device, n_classes): 73 | """ 74 | :param pred: N x C x H x W 75 | :param target: N x H x W 76 | """ 77 | 78 | weights = estimate_weights(target.float()) 79 | weights = weights.to(device) 80 | 81 | cross = cross_entropy_loss(pred, target, weights) 82 | 83 | target_oh = make_one_hot(target.long(), n_classes, device) 84 | 85 | dice = dice_loss(pred, target_oh) 86 | 87 | loss = cross + dice 88 | 89 | del weights 90 | 91 | return loss, cross, dice -------------------------------------------------------------------------------- /model/DenseNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch implementation of DenseNet (the fully convolutional Tiramisu variant) 3 | 4 | reference from https://github.com/bfortuner/pytorch_tiramisu 5 | 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class DenseLayer(nn.Sequential): 13 | def __init__(self, in_channels, growth_rate, kernel_size = 3): 14 | super().__init__() 15 | self.add_module('Norm', nn.BatchNorm2d(in_channels)) 16 | self.add_module('Relu', nn.ReLU(True)) 17 | self.add_module('Conv', nn.Conv2d(in_channels, growth_rate, kernel_size, padding=1)) 18 | self.add_module('Drop', nn.Dropout2d(0.2)) 19 | 20 | def forward(self, x): 21 | return super().forward(x) 22 | 23 |
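# --- Editorial note (not part of the original file) -----------------------------
# Each DenseLayer emits `growth_rate` feature maps, and DenseBlock (below)
# concatenates them onto its input, so layer n sees in_channels + n * growth_rate
# channels. For example, with out_chans_first_conv=48, growth_rate=16 and a
# 4-layer block (the defaults used by DenseNet below), the block outputs
# 48 + 4 * 16 = 112 channels when upsample=False.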
class DenseBlock(nn.Module): 24 | def __init__(self, in_channels, growth_rate, n_layers, upsample=False): 25 | super().__init__() 26 | self.upsample = upsample 27 | self.layers = nn.ModuleList([DenseLayer( 28 | in_channels + n * growth_rate, growth_rate) 29 | for n in range(n_layers)]) 30 | 31 | def forward(self, x): 32 | if self.upsample: 33 | new_features = [] 34 | for layer in self.layers: 35 | out = layer(x) 36 | x = torch.cat([x, out], 1) 37 | new_features.append(out) 38 | return torch.cat(new_features,1) 39 | else: 40 | for layer in self.layers: 41 | out = layer(x) 42 | x = torch.cat([x, out], 1) 43 | return x 44 | 45 | class TransitionDown(nn.Sequential): 46 | def __init__(self, in_channels, kernel_size = 1): 47 | super().__init__() 48 | self.add_module('Norm', nn.BatchNorm2d(num_features=in_channels)) 49 | self.add_module('Relu', nn.ReLU(inplace=True)) 50 | self.add_module('Conv', nn.Conv2d(in_channels, in_channels,kernel_size)) 51 | self.add_module('Drop', nn.Dropout2d(0.2)) 52 | self.add_module('Maxp', nn.MaxPool2d(2)) 53 | 54 | def forward(self, x): 55 | return super().forward(x) 56 | 57 | 58 | class TransitionUp(nn.Module): 59 | def __init__(self, in_channels, out_channels, kernel_size = 3): 60 | super().__init__() 61 | self.TransUp = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=2) 62 | 63 | def forward(self, x, skip): 64 | out = self.TransUp(x) 65 | out = center_crop(out, skip.size(2), skip.size(3)) 66 | out = torch.cat([out, skip], 1) 67 | return out 68 | 69 | 70 | class Bottleneck(nn.Sequential): 71 | def __init__(self, in_channels, growth_rate, n_layers): 72 | super().__init__() 73 | self.add_module('bottleneck', DenseBlock( 74 | in_channels, growth_rate, n_layers, upsample=True)) 75 | 76 | def forward(self, x): 77 | return super().forward(x) 78 | 79 | def center_crop(layer, max_height, max_width): 80 | _, _, h, w = layer.size() 81 | xy1 = (w - max_width) // 2 82 | xy2 = (h - max_height) // 2 83 | return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)] 84 | 85 | class DenseNet(nn.Module): 86 | def __init__(self, in_channels=1, down_blocks=(4,5,7,10,12), 87 | up_blocks=(12,10,7,5,4), bottleneck_layers=15, 88 | growth_rate=16, out_chans_first_conv=48, n_classes=28): 89 | 90 | super().__init__() 91 | 92 | self.down_blocks = down_blocks 93 | self.up_blocks = up_blocks 94 | cur_channels_count = 0 95 | skip_connection_channel_counts = [] 96 | 97 | ## First Convolution 98 | 99 | self.add_module('firstconv', nn.Conv2d(in_channels=in_channels, 100 | out_channels=out_chans_first_conv, kernel_size=3, padding=1)) 101 | cur_channels_count = out_chans_first_conv 102 | 103 | ## Encoding 104 | 105 | self.denseBlocksDown = nn.ModuleList([]) 106 | self.transDownBlocks = nn.ModuleList([]) 107 | for i in range(len(down_blocks)): 108 | self.denseBlocksDown.append( 109 | DenseBlock(cur_channels_count, growth_rate, down_blocks[i])) 110 | cur_channels_count += (growth_rate*down_blocks[i]) 111 | skip_connection_channel_counts.insert(0,cur_channels_count) 112 | self.transDownBlocks.append(TransitionDown(cur_channels_count)) 113 | 114 | ## Bottleneck 115 | 116 | self.add_module('bottleneck',Bottleneck(cur_channels_count, 117 | growth_rate, bottleneck_layers)) 118 | prev_block_channels = growth_rate*bottleneck_layers 119 | cur_channels_count += prev_block_channels 120 | 121 | ## Decoding 122 | 123 | self.transUpBlocks = nn.ModuleList([]) 124 | self.denseBlocksUp = nn.ModuleList([]) 125 | for i in range(len(up_blocks)-1): 126 | 
self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels)) 127 | cur_channels_count = prev_block_channels + skip_connection_channel_counts[i] 128 | 129 | self.denseBlocksUp.append(DenseBlock( 130 | cur_channels_count, growth_rate, up_blocks[i], 131 | upsample=True)) 132 | prev_block_channels = growth_rate*up_blocks[i] 133 | cur_channels_count += prev_block_channels 134 | 135 | ## Final DenseBlock 136 | 137 | self.transUpBlocks.append(TransitionUp( 138 | prev_block_channels, prev_block_channels)) 139 | cur_channels_count = prev_block_channels + skip_connection_channel_counts[-1] 140 | 141 | self.denseBlocksUp.append(DenseBlock( 142 | cur_channels_count, growth_rate, up_blocks[-1], 143 | upsample=False)) 144 | cur_channels_count += growth_rate*up_blocks[-1] 145 | 146 | ## Final layer 147 | 148 | self.finalConv = nn.Conv2d(in_channels=cur_channels_count, 149 | out_channels=n_classes, kernel_size=1) 150 | 151 | 152 | def forward(self, x): 153 | out = self.firstconv(x) 154 | 155 | skip_connections = [] 156 | 157 | for i in range(len(self.down_blocks)): 158 | out = self.denseBlocksDown[i](out) 159 | skip_connections.append(out) 160 | out = self.transDownBlocks[i](out) 161 | 162 | out = self.bottleneck(out) 163 | 164 | for i in range(len(self.up_blocks)): 165 | skip = skip_connections.pop() 166 | out = self.transUpBlocks[i](out, skip) 167 | out = self.denseBlocksUp[i](out) 168 | 169 | out = self.finalConv(out) 170 | 171 | 172 | return out -------------------------------------------------------------------------------- /model/SegNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class encode_conv(nn.Module): 7 | def __init__(self, in_ch, out_ch, kernel_size = 3): 8 | super(encode_conv, self).__init__() 9 | self.conv = nn.Sequential( 10 | nn.Conv2d(in_ch, out_ch, kernel_size, stride = 1, padding=1), 11 | nn.BatchNorm2d(out_ch), 12 | nn.ReLU(inplace=True), 13 | ) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | return x 18 | 19 | class decode_conv(nn.Module): 20 | def __init__(self, in_ch, out_ch, kernel_size = 3): 21 | super(decode_conv, self).__init__() 22 | self.conv = nn.Sequential( 23 | nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride = 1, padding=1), 24 | nn.BatchNorm2d(out_ch), 25 | nn.ReLU(inplace=True), 26 | ) 27 | 28 | def forward(self, x): 29 | x = self.conv(x) 30 | return x 31 | 32 | class outconv(nn.Module): 33 | def __init__(self, in_ch, out_ch, kernel_size = 3): 34 | super(outconv, self).__init__() 35 | self.conv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride = 1, padding=1) 36 | 37 | def forward(self, x): 38 | x = self.conv(x) 39 | return x 40 | 41 | 42 | class SegNet(nn.Module): 43 | def __init__(self, in_channels = 1, n_classes = 28): 44 | super().__init__() 45 | 46 | self.conv1_1 = encode_conv(in_channels, 64) 47 | self.conv1_2 = encode_conv(64, 64) 48 | 49 | self.conv2_1 = encode_conv(64, 128) 50 | self.conv2_2 = encode_conv(128, 128) 51 | 52 | self.conv3_1 = encode_conv(128, 256) 53 | self.conv3_2 = encode_conv(256, 256) 54 | self.conv3_3 = encode_conv(256, 256) 55 | 56 | self.conv4_1 = encode_conv(256, 512) 57 | self.conv4_2 = encode_conv(512, 512) 58 | self.conv4_3 = encode_conv(512, 512) 59 | 60 | self.conv5_1 = encode_conv(512, 512) 61 | self.conv5_2 = encode_conv(512, 512) 62 | self.conv5_3 = encode_conv(512, 512) 63 | 64 | self.deconv5_3 = decode_conv(512,512) 65 | self.deconv5_2 = 
decode_conv(512,512) 66 | self.deconv5_1 = decode_conv(512,512) 67 | 68 | self.deconv4_3 = decode_conv(512,512) 69 | self.deconv4_2 = decode_conv(512,512) 70 | self.deconv4_1 = decode_conv(512,256) 71 | 72 | self.deconv3_3 = decode_conv(256,256) 73 | self.deconv3_2 = decode_conv(256,256) 74 | self.deconv3_1 = decode_conv(256,128) 75 | 76 | self.deconv2_2 = decode_conv(128,128) 77 | self.deconv2_1 = decode_conv(128,64) 78 | 79 | self.deconv1_2 = decode_conv(64,64) 80 | self.deconv1_1 = outconv(64,n_classes) 81 | 82 | 83 | def forward(self, x): 84 | 85 | #Encoder 86 | 87 | dim0 = x.size() 88 | x = self.conv1_1(x) 89 | x = self.conv1_2(x) 90 | x, indices0 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True) 91 | 92 | dim1 = x.size() 93 | x = self.conv2_1(x) 94 | x = self.conv2_2(x) 95 | x, indices1 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True) 96 | 97 | dim2 = x.size() 98 | x = self.conv3_1(x) 99 | x = self.conv3_2(x) 100 | x = self.conv3_3(x) 101 | x, indices2 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True) 102 | 103 | dim3 = x.size() 104 | x = self.conv4_1(x) 105 | x = self.conv4_2(x) 106 | x = self.conv4_3(x) 107 | x, indices3 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True) 108 | 109 | dim4 = x.size() 110 | x = self.conv5_1(x) 111 | x = self.conv5_2(x) 112 | x = self.conv5_3(x) 113 | x, indices4 = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True) 114 | 115 | # Decoder 116 | 117 | x = F.max_unpool2d(x, indices4, kernel_size = 2, stride = 2, output_size = dim4) 118 | x = self.deconv5_3(x) 119 | x = self.deconv5_2(x) 120 | x = self.deconv5_1(x) 121 | 122 | x = F.max_unpool2d(x, indices3, kernel_size = 2, stride = 2, output_size = dim3) 123 | x = self.deconv4_3(x) 124 | x = self.deconv4_2(x) 125 | x = self.deconv4_1(x) 126 | 127 | x = F.max_unpool2d(x, indices2, kernel_size = 2, stride = 2, output_size = dim2) 128 | x = self.deconv3_3(x) 129 | x = self.deconv3_2(x) 130 | x = self.deconv3_1(x) 131 | 132 | x = F.max_unpool2d(x, indices1, kernel_size = 2, stride = 2, output_size = dim1) 133 | x = self.deconv2_2(x) 134 | x = self.deconv2_1(x) 135 | 136 | x = F.max_unpool2d(x, indices0, kernel_size = 2, stride = 2, output_size = dim0) 137 | x = self.deconv1_2(x) 138 | x = self.deconv1_1(x) 139 | 140 | return x
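# --- Editorial sketch (not part of the original file) ----------------------------
# A minimal, standalone illustration of the pooling-indices trick SegNet's forward
# pass relies on: max_unpool2d places each pooled value back at the argmax location
# recorded by max_pool2d(..., return_indices=True); all other positions become zero.
#
#   import torch
#   import torch.nn.functional as F
#
#   x = torch.arange(16.).reshape(1, 1, 4, 4)
#   pooled, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
#   restored = F.max_unpool2d(pooled, idx, kernel_size=2, stride=2, output_size=x.size())
#   # `restored` keeps each maximum at its original position and is zero elsewhere.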
-------------------------------------------------------------------------------- /model/UNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class double_conv(nn.Module): 7 | '''(conv => BN => ReLU) * 2''' 8 | def __init__(self, in_ch, out_ch, kernel_size=3): 9 | super(double_conv, self).__init__() 10 | self.conv = nn.Sequential( 11 | nn.Conv2d(in_ch, out_ch, kernel_size, padding=1), 12 | nn.BatchNorm2d(out_ch), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(out_ch, out_ch, kernel_size, padding=1), 15 | nn.BatchNorm2d(out_ch), 16 | nn.ReLU(inplace=True) 17 | ) 18 | 19 | def forward(self, x): 20 | x = self.conv(x) 21 | return x 22 | 23 | class contraction_path(nn.Module): 24 | def __init__(self, in_ch, out_ch): 25 | super().__init__() 26 | self.contract = nn.Sequential( 27 | nn.MaxPool2d(2), 28 | double_conv(in_ch, out_ch) 29 | ) 30 | 31 | def forward(self, x): 32 | x = self.contract(x) 33 | return x 34 | 35 | class expansion_path(nn.Module): 36 | def __init__(self, in_ch, out_ch): 37 | super().__init__() 38 | self.scale_factor = 2 39 | #self.expansion = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 40 | self.expansion = nn.ConvTranspose2d(in_ch , out_ch, 2, stride=2) 41 | 42 | self.conv = double_conv(in_ch, out_ch) 43 | 44 | def forward(self, x1, x2): 45 | #x = F.interpolate(x, scale_factor = self.scale_factor, mode='bilinear', align_corners=True) 46 | # x = self.expansion(x) 47 | x1 = self.expansion(x1) 48 | x = torch.cat([x2, x1], 1) 49 | x = self.conv(x) 50 | 51 | return x 52 | 53 | class outconv(nn.Module): 54 | def __init__(self, in_ch, out_ch): 55 | super(outconv, self).__init__() 56 | self.conv = nn.Conv2d(in_ch, out_ch, kernel_size = 1) 57 | 58 | def forward(self, x): 59 | x = self.conv(x) 60 | return x 61 | 62 | 63 | class UNet(nn.Module): 64 | def __init__(self, in_channels = 3, n_classes = 15): 65 | super().__init__() 66 | self.inc = double_conv(in_channels, 64) 67 | self.down1 = contraction_path(64, 128) 68 | self.down2 = contraction_path(128, 256) 69 | self.down3 = contraction_path(256, 512) 70 | self.middle = contraction_path(512, 1024) 71 | self.up4 = expansion_path(1024, 512) 72 | self.up3 = expansion_path(512, 256) 73 | self.up2 = expansion_path(256, 128) 74 | self.up1 = expansion_path(128, 64) 75 | #self.up4 = double_conv(512 + 1024, 512) 76 | #self.up3 = double_conv(256 + 512, 256) 77 | #self.up2 = double_conv(128 + 256, 128) 78 | #self.up1 = double_conv(64 + 128, 64) 79 | self.out = outconv(64, n_classes) 80 | 81 | 82 | def forward(self, x): 83 | down1 = self.inc(x) 84 | down2 = self.down1(down1) 85 | down3 = self.down2(down2) 86 | down4 = self.down3(down3) 87 | 88 | middle = self.middle(down4) 89 | 90 | out = self.up4(middle, down4) 91 | out = self.up3(out, down3) 92 | out = self.up2(out, down2) 93 | out = self.up1(out, down1) 94 | 95 | #out = F.upsample(middle, scale_factor=2) 96 | #out = torch.cat([down4, out], 1) 97 | #out = self.up4(out) 98 | 99 | #out = F.upsample(out, scale_factor=2) 100 | #out = torch.cat([down3, out], 1) 101 | #out = self.up3(out) 102 | 103 | #out = F.upsample(out, scale_factor=2) 104 | #out = torch.cat([down2, out], 1) 105 | #out = self.up2(out) 106 | 107 | #out = F.upsample(out, scale_factor=2) 108 | #out = torch.cat([down1, out], 1) 109 | #out = self.up1(out) 110 | 111 | x = self.out(out) 112 | 113 | return x -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .UNet import UNet 2 | from .SegNet import SegNet 3 | from .DenseNet import DenseNet -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | import sys 6 | from torch.utils.data.dataset import Dataset 7 | from torch.utils.data import DataLoader 8 | from torchvision import transforms 9 | from torchsummary import summary 10 | import copy 11 | import pickle 12 | 13 | 14 | from optparse import OptionParser 15 | 16 | from util import make_one_hot 17 | from dataset import SampleDataset 18 | from model import UNet, SegNet, DenseNet 19 | from loss import dice_score 20 | 21 | 22 | def get_args(): 23 | 24 | parser = OptionParser() 25 | parser.add_option('--data_path', dest='data_path',type='string', 26 | default='data/samples', help='data path') 27 | parser.add_option('--model_path', dest='model_path',type='string', 28 | default='checkpoints/', help='model_path') 29 | parser.add_option('--classes', dest='classes', default=28, type='int',
30 | help='number of classes') 31 | parser.add_option('--channels', dest='channels', default=1, type='int', 32 | help='number of channels') 33 | parser.add_option('--width', dest='width', default=256, type='int', 34 | help='image width') 35 | parser.add_option('--height', dest='height', default=256, type='int', 36 | help='image height') 37 | parser.add_option('--model', dest='model', type='string', 38 | help='model name(UNet, SegNet, DenseNet)') 39 | parser.add_option('--gpu', dest='gpu',type='string', 40 | default='gpu', help='gpu or cpu') 41 | parser.add_option('--device1', dest='device1', default=0, type='int', 42 | help='device1 index number') 43 | parser.add_option('--device2', dest='device2', default=-1, type='int', 44 | help='device2 index number') 45 | parser.add_option('--device3', dest='device3', default=-1, type='int', 46 | help='device3 index number') 47 | parser.add_option('--device4', dest='device4', default=-1, type='int', 48 | help='device4 index number') 49 | 50 | (options, args) = parser.parse_args() 51 | return options 52 | 53 | 54 | def test(model, args): 55 | 56 | data_path = args.data_path 57 | gpu = args.gpu 58 | n_classes = args.classes 59 | data_width = args.width 60 | data_height = args.height 61 | 62 | # set device configuration 63 | device_ids = [] 64 | 65 | if gpu == 'gpu' : 66 | 67 | if not torch.cuda.is_available() : 68 | print("No cuda available") 69 | raise SystemExit 70 | 71 | device = torch.device(args.device1) 72 | 73 | device_ids.append(args.device1) 74 | 75 | if args.device2 != -1 : 76 | device_ids.append(args.device2) 77 | 78 | if args.device3 != -1 : 79 | device_ids.append(args.device3) 80 | 81 | if args.device4 != -1 : 82 | device_ids.append(args.device4) 83 | 84 | 85 | else : 86 | device = torch.device("cpu") 87 | 88 | if len(device_ids) > 1: 89 | model = nn.DataParallel(model, device_ids = device_ids) 90 | 91 | model = model.to(device) 92 | 93 | # set test dataset 94 | 95 | test_dataset = SampleDataset(data_path) 96 | 97 | test_loader = DataLoader( 98 | test_dataset, 99 | batch_size=10, 100 | num_workers=4, 101 | ) 102 | 103 | print('test_dataset : {}, test_loader : {}'.format(len(test_dataset), len(test_loader))) 104 | 105 | avg_score = 0.0 106 | 107 | # test 108 | 109 | model.eval() # Set model to evaluate mode 110 | 111 | with torch.no_grad(): 112 | for batch_idx, (inputs, labels) in enumerate(test_loader): 113 | 114 | inputs = inputs.to(device).float() 115 | labels = labels.to(device).long() 116 | 117 | target = make_one_hot(labels[:,0,:,:], n_classes, device) 118 | 119 | pred = model(inputs) 120 | 121 | loss = dice_score(pred,target) 122 | 123 | avg_score += loss.data.cpu().numpy() 124 | 125 | del inputs, labels, target, pred, loss 126 | 127 | avg_score /= len(test_loader) 128 | 129 | print('dice_score : {:.4f}'.format(avg_score)) 130 | 131 | if __name__ == "__main__": 132 | 133 | args = get_args() 134 | 135 | n_channels = args.channels 136 | n_classes = args.classes 137 | 138 | model = None 139 | 140 | if args.model == 'UNet': 141 | model = UNet(in_channels = n_channels, n_classes = n_classes) 142 | 143 | elif args.model == 'SegNet': 144 | model = SegNet(in_channels = n_channels, n_classes = n_classes) 145 | 146 | elif args.model == 'DenseNet': 147 | model = DenseNet(in_channels = n_channels, n_classes = n_classes) 148 | 149 | else : 150 | print("wrong model : must be UNet, SegNet, or DenseNet") 151 | raise SystemExit 152 | 153 | summary(model, input_size=(n_channels, args.height, args.width), device = 'cpu') 154 | 155 |
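# Editorial note (not part of the original file): weights trained on GPU need an
# explicit map_location when testing on a CPU-only machine, e.g.
#   model.load_state_dict(torch.load(args.model_path, map_location='cpu'))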
model.load_state_dict(torch.load(args.model_path)) 156 | 157 | test(model, args) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import os 6 | import sys 7 | from torch.autograd import Variable 8 | from torch.utils.data.dataset import Dataset 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | from torch.utils.data.sampler import SubsetRandomSampler 12 | from sklearn.model_selection import train_test_split 13 | from torchsummary import summary 14 | from collections import defaultdict 15 | import copy 16 | import pickle 17 | 18 | from optparse import OptionParser 19 | 20 | from dataset import SampleDataset 21 | from model import UNet, SegNet, DenseNet 22 | from util import save_metrics, print_metrics 23 | from loss import combined_loss 24 | 25 | def get_args(): 26 | 27 | parser = OptionParser() 28 | parser.add_option('--data_path', dest='data_path',type='string', 29 | default='data/samples', help='data path') 30 | parser.add_option('--epochs', dest='epochs', default=50, type='int', 31 | help='number of epochs') 32 | parser.add_option('--classes', dest='classes', default=28, type='int', 33 | help='number of classes') 34 | parser.add_option('--channels', dest='channels', default=1, type='int', 35 | help='number of channels') 36 | parser.add_option('--width', dest='width', default=256, type='int', 37 | help='image width') 38 | parser.add_option('--height', dest='height', default=256, type='int', 39 | help='image height') 40 | parser.add_option('--model', dest='model', type='string', 41 | help='model name(UNet, SegNet, DenseNet)') 42 | parser.add_option('--gpu', dest='gpu',type='string', 43 | default='gpu', help='gpu or cpu') 44 | parser.add_option('--device1', dest='device1', default=0, type='int', 45 | help='device1 index number') 46 | parser.add_option('--device2', dest='device2', default=-1, type='int', 47 | help='device2 index number') 48 | parser.add_option('--device3', dest='device3', default=-1, type='int', 49 | help='device3 index number') 50 | parser.add_option('--device4', dest='device4', default=-1, type='int', 51 | help='device4 index number') 52 | 53 | (options, args) = parser.parse_args() 54 | return options 55 | 56 | def train_net(model, args): 57 | 58 | data_path = args.data_path 59 | num_epochs = args.epochs 60 | gpu = args.gpu 61 | n_classes = args.classes 62 | data_width = args.width 63 | data_height = args.height 64 | 65 | # set device configuration 66 | device_ids = [] 67 | 68 | if gpu == 'gpu' : 69 | 70 | if not torch.cuda.is_available() : 71 | print("No cuda available") 72 | raise SystemExit 73 | 74 | device = torch.device(args.device1) 75 | 76 | device_ids.append(args.device1) 77 | 78 | if args.device2 != -1 : 79 | device_ids.append(args.device2) 80 | 81 | if args.device3 != -1 : 82 | device_ids.append(args.device3) 83 | 84 | if args.device4 != -1 : 85 | device_ids.append(args.device4) 86 | 87 | 88 | else : 89 | device = torch.device("cpu") 90 | 91 | if len(device_ids) > 1: 92 | model = nn.DataParallel(model, device_ids = device_ids) 93 | 94 | model = model.to(device) 95 | 96 | # set image into training and validation dataset 97 | 98 | train_dataset = SampleDataset(data_path) 99 | 100 | print('total image : {}'.format(len(train_dataset))) 101 | 102 | train_indices, val_indices = 
train_test_split(np.arange(len(train_dataset)), test_size=0.2, random_state=42) 103 | 104 | train_sampler = SubsetRandomSampler(train_indices) 105 | valid_sampler = SubsetRandomSampler(val_indices) 106 | 107 | train_loader = DataLoader( 108 | train_dataset, 109 | batch_size=20, 110 | num_workers=4, 111 | sampler=train_sampler 112 | ) 113 | 114 | val_loader = DataLoader( 115 | train_dataset, 116 | batch_size=10, 117 | num_workers=4, 118 | sampler=valid_sampler 119 | ) 120 | 121 | model_folder = os.path.abspath('./checkpoints') 122 | if not os.path.exists(model_folder): 123 | os.mkdir(model_folder) 124 | 125 | if args.model == 'UNet': 126 | model_path = os.path.join(model_folder, 'UNet.pth') 127 | 128 | elif args.model == 'SegNet': 129 | model_path = os.path.join(model_folder, 'SegNet.pth') 130 | 131 | elif args.model == 'DenseNet': 132 | model_path = os.path.join(model_folder, 'DenseNet.pth') 133 | 134 | # set optimizer 135 | 136 | optimizer = torch.optim.Adam(model.parameters()) 137 | 138 | 139 | # main train 140 | 141 | display_steps = 30 142 | best_loss = 1e10 143 | loss_history = [] 144 | 145 | ## for early stopping 146 | early_stop = False 147 | patience = 7 148 | counter = 0 149 | 150 | for epoch in range(num_epochs): 151 | print('Starting epoch {}/{}'.format(epoch+1, num_epochs)) 152 | 153 | # train 154 | model.train() 155 | 156 | metrics = defaultdict(float) 157 | epoch_size = 0 158 | 159 | # train model 160 | for batch_idx, (images, masks) in enumerate(train_loader): 161 | 162 | images = images.to(device).float() 163 | masks = masks.to(device).long() 164 | 165 | optimizer.zero_grad() 166 | outputs = model(images) 167 | 168 | loss, cross, dice = combined_loss(outputs, masks.squeeze(1), device, n_classes) 169 | 170 | save_metrics(metrics, images.size(0), loss, cross, dice) 171 | 172 | loss.backward() 173 | optimizer.step() 174 | 175 | # statistics 176 | epoch_size += images.size(0) 177 | 178 | if batch_idx % display_steps == 0: 179 | print(' ', end='') 180 | print('batch {:>3}/{:>3} cross: {:.4f} , dice {:.4f} , combined_loss {:.4f}\r'\ 181 | .format(batch_idx+1, len(train_loader), cross.item(), dice.item(),loss.item())) 182 | 183 | del images, masks, outputs, loss, cross, dice 184 | 185 | print_metrics(metrics, epoch_size, 'train') 186 | 187 | # evaluate 188 | print('Finished epoch {}, starting evaluation'.format(epoch+1)) 189 | model.eval() 190 | metrics = defaultdict(float); epoch_size = 0 # reset so the 'val' metrics and epoch_loss below reflect validation only 191 | # validate model 192 | for images, masks in val_loader: 193 | images = images.to(device).float() 194 | masks = masks.to(device).long() 195 | 196 | with torch.no_grad(): outputs = model(images) # no gradients needed during validation 197 | 198 | loss, cross, dice = combined_loss(outputs, masks.squeeze(1), device, n_classes) 199 | 200 | save_metrics(metrics, images.size(0), loss, cross, dice) 201 | 202 | # statistics 203 | epoch_size += images.size(0) 204 | 205 | del images, masks, outputs, loss, cross, dice 206 | 207 | print_metrics(metrics, epoch_size, 'val') 208 | 209 | epoch_loss = metrics['loss'] / epoch_size 210 | 211 | # save model if best validation loss 212 | if epoch_loss < best_loss: 213 | print("saving best model") 214 | best_loss = epoch_loss 215 | 216 | model_copy = copy.deepcopy(model) 217 | model_copy = model_copy.cpu() 218 | 219 | model_state_dict = model_copy.module.state_dict() if len(device_ids) > 1 else model_copy.state_dict() 220 | torch.save(model_state_dict, model_path) 221 | 222 | del model_copy 223 | 224 | counter = 0 225 | 226 | else: 227 | counter += 1 228 | print('EarlyStopping counter : {:>3} / {:>3}'.format(counter, patience)) 229 | 230 | if counter >= patience : 231 | early_stop = True
232 | 233 | loss_history.append(best_loss) 234 | print('Best val loss: {:4f}'.format(best_loss)) 235 | 236 | if early_stop : 237 | print('Early Stopping') 238 | break 239 | 240 | return loss_history 241 | 242 | 243 | if __name__ == "__main__": 244 | 245 | args = get_args() 246 | 247 | n_channels = args.channels 248 | n_classes = args.classes 249 | 250 | model = None 251 | 252 | if args.model == 'UNet': 253 | model = UNet(in_channels = n_channels, n_classes = n_classes) 254 | 255 | elif args.model == 'SegNet': 256 | model = SegNet(in_channels = n_channels, n_classes = n_classes) 257 | 258 | elif args.model == 'DenseNet': 259 | model = DenseNet(in_channels = n_channels, n_classes = n_classes) 260 | 261 | else : 262 | print("wrong model : must be UNet, SegNet, or DenseNet") 263 | raise SystemExit 264 | 265 | summary(model, input_size=(n_channels, args.height, args.width), device = 'cpu') 266 | 267 | loss_history = train_net(model, args) 268 | 269 | # save validation loss history 270 | with open('./checkpoints/validation_losses', 'wb') as fp: 271 | pickle.dump(loss_history, fp) 272 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | def estimate_weights(labels): 6 | ''' 7 | reference from https://github.com/ai-med/quickNAT_pytorch 8 | Estimate weights which balance the relative importance of pixels in the logistic loss; 9 | more details in https://arxiv.org/pdf/1801.04161.pdf 10 | ''' 11 | 12 | if torch.is_tensor(labels): 13 | labels = labels.cpu().numpy() 14 | 15 | class_weights = np.zeros_like(labels) 16 | unique, counts = np.unique(labels, return_counts=True) 17 | median_freq = np.median(counts) 18 | 19 | for i, label in enumerate(unique): 20 | class_weights += (median_freq // counts[i]) * np.array(labels == label) 21 | 22 | grads = np.gradient(labels) 23 | edge_weights = (grads[0] ** 2 + grads[1] ** 2) > 0 24 | class_weights += 2 * edge_weights 25 | 26 | class_weights = torch.tensor(class_weights).float() 27 | 28 | return class_weights
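# --- Editorial note (not part of the original file) -----------------------------
# estimate_weights (above) implements median-frequency balancing: every pixel of
# class c is weighted by median(class counts) // count(c), so rare classes receive
# weights > 1 while very frequent ones can round down to 0; pixels on label
# boundaries (nonzero spatial gradient) receive an extra +2, emphasizing edges as
# in the quickNAT recipe referenced in the docstring.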
29 | 30 | 31 | def make_one_hot(labels, num_classes, device): 32 | ''' 33 | Converts an integer label map to one-hot values. 34 | 35 | Parameters 36 | ---------- 37 | labels : N x H x W torch.Tensor, where N is the batch size 38 | num_classes : int 39 | device : torch.device information 40 | 41 | Returns 42 | ------- 43 | target : torch.Tensor on the given device; N x C x H x W, where C is the number of classes. One-hot encoded. 44 | ''' 45 | 46 | labels=labels.unsqueeze(1) 47 | one_hot = torch.FloatTensor(labels.size(0), num_classes, labels.size(2), labels.size(3)).zero_() 48 | one_hot = one_hot.to(device) 49 | target = one_hot.scatter_(1, labels.data, 1) 50 | return target 51 | 52 | 53 | def save_metrics(metrics, size, loss, cross = None, dice = None): 54 | ''' 55 | accumulate loss values in the metrics dict, weighted by batch size 56 | ''' 57 | 58 | if cross is not None: 59 | metrics['cross'] += cross.data.cpu().numpy() * size 60 | 61 | if dice is not None: 62 | metrics['dice'] += dice.data.cpu().numpy() * size 63 | 64 | metrics['loss'] += loss.data.cpu().numpy() * size 65 | 66 | def print_metrics(metrics, epoch_size, phase): 67 | ''' 68 | print the metrics accumulated by save_metrics, averaged over the epoch 69 | ''' 70 | 71 | outputs = [] 72 | 73 | for k in metrics.keys(): 74 | outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_size)) 75 | 76 | print("{}: {}".format(phase, ", ".join(outputs))) 77 | --------------------------------------------------------------------------------