├── .gitattributes ├── LICENSE ├── README.md ├── dataloader ├── MRBrain_loader.py ├── __init__.py ├── augmentation.py └── coder.py ├── evaluate.py ├── images ├── image.png ├── image_mask.png ├── mask.png ├── pred_crf_mask.png └── pred_mask.png ├── model ├── __init__.py ├── fcn.py └── unet.py ├── predict.py ├── train.py ├── train_unet.py └── utils ├── __init__.py ├── crf.py ├── loss.py └── metrics.py /.gitattributes: -------------------------------------------------------------------------------- 1 | checkpoint/best_unet_model.pkl filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 彭智亮 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 医学图像分割 2 | 3 | #### 数据集 4 | MRBrainS 数据集 5 | 1~4 用于train 5用于val 6 | #### 实现 7 | 1. VGG-FCN 8 | 2. UNet 参考:[pytorch-unrt](https://github.com/milesial/Pytorch-UNet) 9 | #### predict 10 | checkpoint中已提供训练好的model 11 | ```python 12 | python predict.py 13 | ``` 14 | #### 结果 15 | Mean IoU : 0.8053 Mean dice: 0.8921 16 | 1. 原图: 17 | 18 | ![原图](./images/image.png) 19 | 20 | 2. mask: 21 | 22 | ![](./images/mask.png) 23 | 24 | 3. 预测 25 | 26 | ![](./images/pred_mask.png) 27 | 28 | 4. dense crf 29 | 30 | ![](./images/pred_crf_mask.png) 31 | -------------------------------------------------------------------------------- /dataloader/MRBrain_loader.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | ''' 3 | Created on Nov 13,2018 4 | 5 | @author: pengzhiliang 6 | ''' 7 | import os 8 | import os.path as osp 9 | import numpy as np 10 | 11 | import torch 12 | from torch.utils import data 13 | import cv2 14 | 15 | class MRBrainSDataset(data.Dataset): 16 | def __init__(self, root, split='train', is_transform=True, augmentations=None, img_norm=False): 17 | self.root = root 18 | self.is_transform = is_transform 19 | self.augmentations = augmentations 20 | self.img_norm = img_norm 21 | files = tuple(open(osp.join(root, split+'.txt'), 'r')) 22 | self.files = [file_.rstrip('\n') for file_ in files] 23 | 24 | def __len__(self): 25 | return len(self.files) 26 | 27 | def __getitem__(self, index): 28 | 29 | img = cv2.imread(osp.join(self.root, self.files[index]+'.png')) 30 | mask = cv2.imread(osp.join(self.root, self.files[index]+'_mask.png'), 0) 31 | 32 | if self.augmentations is not None: 33 | img, mask = self.augmentations(img, mask) 34 | img_g = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 35 | 36 | mask = mask.astype(np.uint8) 37 | 38 | if self.img_norm: 
39 | img = img / 255.0 40 | 41 | if self.is_transform: 42 | # to transpose, to Tensor 43 | img, mask = self._transform(img, mask) 44 | else: 45 | img = data_np.astype(np.float32) 46 | 47 | return img, mask 48 | 49 | def _transform(self, img, mask): 50 | img = img.astype(np.float64) 51 | img = img.transpose(2, 0, 1) 52 | img = torch.from_numpy(img).float() 53 | mask = torch.from_numpy(mask).long() 54 | return img, mask 55 | 56 | 57 | if __name__ == '__main__': 58 | from augmentations import * 59 | # from utils.utils import * 60 | # data_aug=None 61 | data_aug = Compose([ 62 | RandomHorizontallyFlip(0.5), 63 | RandomRotate(10), 64 | # Scale(256), 65 | ]) 66 | 67 | dloader = torch.utils.data.DataLoader(MRBrainSDataset(osp.join('/home/cv_xfwang/data/', 'MRBrainS'), split='train', is_transform=True, img_norm=True, augmentations=data_aug), batch_size=1) 68 | for idx, (img, mask) in enumerate(dloader): 69 | if idx < 10: 70 | img = img.cpu().data[0].numpy().transpose(1, 2, 0) 71 | mask = mask.cpu().data[0].numpy() 72 | #print(mask,np.sum(mask)) 73 | img = img * 255.0 74 | import cv2 75 | #cv2.imwrite('sample/%d.png'%(idx+1), img.astype(np.uint8)) 76 | cv2.imshow('sample image',img.astype(np.uint8)) 77 | print(np.unique(mask)) 78 | #cv2.imwrite('sample/%d_mask.png'%(idx+1), mask.astype(np.uint8)) 79 | cv2.imshow('sample mask', mask.astype(np.uint8)) 80 | cv2.waitKey(0) 81 | else: 82 | exit(-1) 83 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengzhiliang/MRBrainS_seg/52c392edb0b3d3988cdf526002f2e6df5c8401fe/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/augmentation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | import numpy as np 5 | 
import torchvision.transforms.functional as tf 6 | 7 | from PIL import Image, ImageOps 8 | 9 | 10 | class Compose(object): 11 | def __init__(self, augmentations): 12 | self.augmentations = augmentations 13 | self.PIL2Numpy = False 14 | 15 | def __call__(self, img, mask): 16 | if isinstance(img, np.ndarray): 17 | img = Image.fromarray(img, mode="RGB") 18 | mask = Image.fromarray(mask, mode="L") 19 | self.PIL2Numpy = True 20 | 21 | assert img.size == mask.size 22 | for a in self.augmentations: 23 | img, mask = a(img, mask) 24 | 25 | if self.PIL2Numpy: 26 | img, mask = np.array(img), np.array(mask, dtype=np.uint8) 27 | 28 | return img, mask 29 | 30 | 31 | class RandomCrop(object): 32 | def __init__(self, size, padding=0): 33 | if isinstance(size, numbers.Number): 34 | self.size = (int(size), int(size)) 35 | else: 36 | self.size = size 37 | self.padding = padding 38 | 39 | def __call__(self, img, mask): 40 | if self.padding > 0: 41 | img = ImageOps.expand(img, border=self.padding, fill=0) 42 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 43 | 44 | assert img.size == mask.size 45 | w, h = img.size 46 | th, tw = self.size 47 | if w == tw and h == th: 48 | return img, mask 49 | if w < tw or h < th: 50 | return ( 51 | img.resize((tw, th), Image.BILINEAR), 52 | mask.resize((tw, th), Image.NEAREST), 53 | ) 54 | 55 | x1 = random.randint(0, w - tw) 56 | y1 = random.randint(0, h - th) 57 | return ( 58 | img.crop((x1, y1, x1 + tw, y1 + th)), 59 | mask.crop((x1, y1, x1 + tw, y1 + th)), 60 | ) 61 | 62 | 63 | class AdjustGamma(object): 64 | def __init__(self, gamma): 65 | self.gamma = gamma 66 | 67 | def __call__(self, img, mask): 68 | assert img.size == mask.size 69 | return tf.adjust_gamma(img, random.uniform(1, 1 + self.gamma)), mask 70 | 71 | 72 | class AdjustSaturation(object): 73 | def __init__(self, saturation): 74 | self.saturation = saturation 75 | 76 | def __call__(self, img, mask): 77 | assert img.size == mask.size 78 | return tf.adjust_saturation(img, 79 | 
random.uniform(1 - self.saturation, 80 | 1 + self.saturation)), mask 81 | 82 | 83 | class AdjustHue(object): 84 | def __init__(self, hue): 85 | self.hue = hue 86 | 87 | def __call__(self, img, mask): 88 | assert img.size == mask.size 89 | return tf.adjust_hue(img, random.uniform(-self.hue, 90 | self.hue)), mask 91 | 92 | 93 | class AdjustBrightness(object): 94 | def __init__(self, bf): 95 | self.bf = bf 96 | 97 | def __call__(self, img, mask): 98 | assert img.size == mask.size 99 | return tf.adjust_brightness(img, 100 | random.uniform(1 - self.bf, 101 | 1 + self.bf)), mask 102 | 103 | class AdjustContrast(object): 104 | def __init__(self, cf): 105 | self.cf = cf 106 | 107 | def __call__(self, img, mask): 108 | assert img.size == mask.size 109 | return tf.adjust_contrast(img, 110 | random.uniform(1 - self.cf, 111 | 1 + self.cf)), mask 112 | 113 | class CenterCrop(object): 114 | def __init__(self, size): 115 | if isinstance(size, numbers.Number): 116 | self.size = (int(size), int(size)) 117 | else: 118 | self.size = size 119 | 120 | def __call__(self, img, mask): 121 | assert img.size == mask.size 122 | w, h = img.size 123 | th, tw = self.size 124 | x1 = int(round((w - tw) / 2.)) 125 | y1 = int(round((h - th) / 2.)) 126 | return ( 127 | img.crop((x1, y1, x1 + tw, y1 + th)), 128 | mask.crop((x1, y1, x1 + tw, y1 + th)), 129 | ) 130 | 131 | 132 | class RandomHorizontallyFlip(object): 133 | def __init__(self, p): 134 | self.p = p 135 | 136 | def __call__(self, img, mask): 137 | if random.random() < self.p: 138 | return ( 139 | img.transpose(Image.FLIP_LEFT_RIGHT), 140 | mask.transpose(Image.FLIP_LEFT_RIGHT), 141 | ) 142 | return img, mask 143 | 144 | 145 | class RandomVerticallyFlip(object): 146 | def __init__(self, p): 147 | self.p = p 148 | 149 | def __call__(self, img, mask): 150 | if random.random() < self.p: 151 | return ( 152 | img.transpose(Image.FLIP_TOP_BOTTOM), 153 | mask.transpose(Image.FLIP_TOP_BOTTOM), 154 | ) 155 | return img, mask 156 | 157 | 158 | class 
FreeScale(object): 159 | def __init__(self, size): 160 | self.size = tuple(reversed(size)) # size: (h, w) 161 | 162 | def __call__(self, img, mask): 163 | assert img.size == mask.size 164 | return ( 165 | img.resize(self.size, Image.BILINEAR), 166 | mask.resize(self.size, Image.NEAREST), 167 | ) 168 | 169 | 170 | class RandomTranslate(object): 171 | def __init__(self, offset): 172 | self.offset = offset # tuple (delta_x, delta_y) 173 | 174 | def __call__(self, img, mask): 175 | assert img.size == mask.size 176 | x_offset = int(2 * (random.random() - 0.5) * self.offset[0]) 177 | y_offset = int(2 * (random.random() - 0.5) * self.offset[1]) 178 | 179 | x_crop_offset = x_offset 180 | y_crop_offset = y_offset 181 | if x_offset < 0: 182 | x_crop_offset = 0 183 | if y_offset < 0: 184 | y_crop_offset = 0 185 | 186 | cropped_img = tf.crop(img, 187 | y_crop_offset, 188 | x_crop_offset, 189 | img.size[1]-abs(y_offset), 190 | img.size[0]-abs(x_offset)) 191 | 192 | if x_offset >= 0 and y_offset >= 0: 193 | padding_tuple = (0, 0, x_offset, y_offset) 194 | 195 | elif x_offset >= 0 and y_offset < 0: 196 | padding_tuple = (0, abs(y_offset), x_offset, 0) 197 | 198 | elif x_offset < 0 and y_offset >= 0: 199 | padding_tuple = (abs(x_offset), 0, 0, y_offset) 200 | 201 | elif x_offset < 0 and y_offset < 0: 202 | padding_tuple = (abs(x_offset), abs(y_offset), 0, 0) 203 | 204 | return ( 205 | tf.pad(cropped_img, 206 | padding_tuple, 207 | padding_mode='reflect'), 208 | tf.affine(mask, 209 | translate=(-x_offset, -y_offset), 210 | scale=1.0, 211 | angle=0.0, 212 | shear=0.0, 213 | fillcolor=0)) 214 | 215 | 216 | class RandomRotate(object): 217 | def __init__(self, degree): 218 | self.degree = degree 219 | 220 | def __call__(self, img, mask): 221 | rotate_degree = random.random() * 2 * self.degree - self.degree 222 | return ( 223 | tf.affine(img, 224 | translate=(0, 0), 225 | scale=1.0, 226 | angle=rotate_degree, 227 | resample=Image.BILINEAR, 228 | fillcolor=(0, 0, 0), 229 | shear=0.0), 
230 | tf.affine(mask, 231 | translate=(0, 0), 232 | scale=1.0, 233 | angle=rotate_degree, 234 | resample=Image.NEAREST, 235 | fillcolor=0, 236 | shear=0.0)) 237 | 238 | 239 | 240 | class Scale(object): 241 | def __init__(self, size=512, random_scale=False): 242 | self.size = size 243 | self.random_scale = random_scale 244 | 245 | def __call__(self, img, mask): 246 | assert img.size == mask.size 247 | w, h = img.size 248 | if self.random_scale: 249 | new_w = int(random.uniform(0.5, 2) * w) 250 | new_h = int(random.uniform(0.5, 2) * h) 251 | return ( 252 | img.resize((new_w, new_h), Image.BILINEAR), 253 | mask.resize((new_w, new_h), Image.NEAREST), 254 | ) 255 | if (w >= h and w == self.size) or (h >= w and h == self.size): 256 | return img, mask 257 | if w > h: 258 | ow = self.size 259 | oh = int(self.size * h / w) 260 | return ( 261 | img.resize((ow, oh), Image.BILINEAR), 262 | mask.resize((ow, oh), Image.NEAREST), 263 | ) 264 | else: 265 | oh = self.size 266 | ow = int(self.size * w / h) 267 | return ( 268 | img.resize((ow, oh), Image.BILINEAR), 269 | mask.resize((ow, oh), Image.NEAREST), 270 | ) 271 | 272 | 273 | class RandomSizedCrop(object): 274 | def __init__(self, size): 275 | self.size = size 276 | 277 | def __call__(self, img, mask): 278 | assert img.size == mask.size 279 | for attempt in range(10): 280 | area = img.size[0] * img.size[1] 281 | target_area = random.uniform(0.45, 1.0) * area 282 | aspect_ratio = random.uniform(0.5, 2) 283 | 284 | w = int(round(math.sqrt(target_area * aspect_ratio))) 285 | h = int(round(math.sqrt(target_area / aspect_ratio))) 286 | 287 | if random.random() < 0.5: 288 | w, h = h, w 289 | 290 | if w <= img.size[0] and h <= img.size[1]: 291 | x1 = random.randint(0, img.size[0] - w) 292 | y1 = random.randint(0, img.size[1] - h) 293 | 294 | img = img.crop((x1, y1, x1 + w, y1 + h)) 295 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 296 | assert img.size == (w, h) 297 | 298 | return ( 299 | img.resize((self.size, self.size), 
Image.BILINEAR), 300 | mask.resize((self.size, self.size), Image.NEAREST), 301 | ) 302 | 303 | # Fallback 304 | scale = Scale(self.size) 305 | crop = CenterCrop(self.size) 306 | return crop(*scale(img, mask)) 307 | 308 | 309 | 310 | 311 | 312 | class RandomSized(object): 313 | def __init__(self, size): 314 | self.size = size 315 | self.scale = Scale(self.size) 316 | self.crop = RandomCrop(self.size) 317 | 318 | def __call__(self, img, mask): 319 | assert img.size == mask.size 320 | 321 | w = int(random.uniform(0.5, 2) * img.size[0]) 322 | h = int(random.uniform(0.5, 2) * img.size[1]) 323 | 324 | img, mask = ( 325 | img.resize((w, h), Image.BILINEAR), 326 | mask.resize((w, h), Image.NEAREST), 327 | ) 328 | 329 | return self.crop(*self.scale(img, mask)) 330 | -------------------------------------------------------------------------------- /dataloader/coder.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | ''' 3 | Created on Nov 13,2018 4 | 5 | @author: pengzhiliang 6 | ''' 7 | import numpy as np 8 | import cv2 9 | # 训练时类别 10 | # 0. Background 11 | # 1. Cortical gray matter (皮质灰质) 12 | # 2. Basal ganglia (基底神经节 ) 13 | # 3. White matter (白质) 14 | # 4. White matter lesions (白质组织) 15 | # 5. Cerebrospinal fluid in the extracerebral space (脑脊液) 16 | # 6. Ventricles (脑室) 17 | # 7. Cerebellum (小脑) 18 | # 8. Brainstem (脑干) 19 | # 测试时类别,类别合并 20 | # 0. Background 21 | # 1. Cerebrospinal fluid (including ventricles) 22 | # 2. Gray matter (cortical gray matter and basal ganglia) 23 | # 3. 
White matter (including white matter lesions) 24 | # label_test:[0,2,2,3,3,1,1,0,0] 25 | 26 | # Back 0 : Background 27 | # GM 2 : Cortical GM(red), Basal ganglia(green) 28 | # WM 3: WM(yellow), WM lesions(blue) 29 | # CSF 1 : CSF(pink), Ventricles(light blue) 30 | # Back: Cerebellum(white), Brainstem(dark red) 31 | 32 | color = np.asarray([[0,0,0],[0,0,255],[0,255,0],[0,255,255],[255,0,0],\ 33 | [255,0,255],[255,255,0],[255,255,255],[0,0,128],[0,128,0],[128,0,0]]).astype(np.uint8) 34 | color_test = np.asarray([[0,0,0],[0,0,255],[0,255,0],[255,0,0]]).astype(np.uint8) 35 | # Back , CSF , GM , WM 36 | label_test=[0,2,2,3,3,1,1,0,0] 37 | 38 | def merge_classes(label): 39 | """ 40 | 功能:将九类按一定的规则合并成4类,具体间上方注释 41 | 输入: 有9类的二维np.array 42 | 输出: 只有4类的二维np.array 43 | """ 44 | label = label.astype(np.int) 45 | label[label == 1] = 2 46 | label[label == 4] = 3 47 | label[label == 5] = 1 48 | label[label == 6] = 1 49 | label[label == 7] = 0 50 | label[label == 8] = 0 51 | return label 52 | def encode(label,color): 53 | """ 54 | 将输入的灰度图转换成RGB图 55 | 输入: 56 | label: 灰度图(二维 np.array) 57 | color: 每类对应的颜色 58 | """ 59 | H,W = label.shape 60 | img = np.zeros((H,W,3)) 61 | for i in range(H): 62 | for j in range(W): 63 | img[i,j] = color[label[i,j]] 64 | return img 65 | 66 | if __name__ == '__main__': 67 | a = np.array([[0,1,2,3],[4,5,6,7]],dtype=np.int) 68 | print(a,'\n',merge_classes(a)) 69 | img = cv2.imread('/home/cv_xfwang/data/MRBrainS/TrainingData/5/slices/5_001.png') 70 | cv2.imshow('image',img) 71 | mask = cv2.imread('/home/cv_xfwang/data/MRBrainS/TrainingData/5/slices/5_001_mask.png',0) 72 | cv2.imshow('Encode mask image',encode(mask,color)) 73 | cv2.waitKey(0) 74 | cv2.destroyAllWindows() 75 | 76 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | ''' 3 | Created on Nov14 31,2018 4 | 5 | @author: pengzhiliang 6 | ''' 7 | 8 | 
import time 9 | import numpy as np 10 | import os 11 | import os.path as osp 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | 18 | from tqdm import tqdm 19 | from torch.utils.data import Dataset,DataLoader 20 | from torch.optim import lr_scheduler,Adam,SGD 21 | from torchvision import datasets, models, transforms 22 | from torchsummary import summary 23 | from model.unet import UNet 24 | from model.fcn import fcn 25 | from utils.metrics import Score,averageMeter 26 | from utils.crf import dense_crf 27 | from dataloader.MRBrain_loader import MRBrainSDataset 28 | from dataloader.augmentation import * 29 | from dataloader.coder import merge_classes 30 | 31 | 32 | # GPU or CPU 33 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 34 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | # 参数设置 36 | defualt_path = osp.join('/home/cv_xfwang/data/', 'MRBrainS') 37 | batch_size = 1 38 | num_workers = 4 39 | resume_path = '/home/cv_xfwang/MRBrainS_seg/checkpoint/best_unet_model.pkl' 40 | # data loader 41 | val_loader = DataLoader(MRBrainSDataset(defualt_path, split='val', is_transform=True, \ 42 | img_norm=True, augmentations=Compose([Scale(224)])), \ 43 | batch_size=1,num_workers=num_workers,pin_memory=True,shuffle=False) 44 | # Setup Model and summary 45 | model = UNet().to(device) 46 | # summary(model,(3,224,224),batch_size) # summary 网络参数 47 | 48 | # running_metrics = Score(n_classes=9) 49 | running_metrics = Score(n_classes=4) # label_test=[0,2,2,3,3,1,1,0,0] 50 | # resume 51 | if osp.isfile(resume_path): 52 | checkpoint = torch.load(resume_path) 53 | model.load_state_dict(checkpoint["model_state"]) 54 | best_iou = checkpoint['best_iou'] 55 | print("=====>", 56 | "Loaded checkpoint '{}' (iter {})".format( 57 | resume_path, checkpoint["epoch"] 58 | )) 59 | print("=====> best mIoU: %.4f best mean dice: %.4f"%(best_iou,(best_iou*2)/(best_iou+1))) 60 | else: 61 | raise 
ValueError("can't find model") 62 | 63 | 64 | print(">>>Test After Dense CRF: ") 65 | model.eval() 66 | running_metrics.reset() 67 | with torch.no_grad(): 68 | for i, (img, mask) in tqdm(enumerate(val_loader)): 69 | img = img.to(device) 70 | output = model(img) #[-1, 9, 256, 256] 71 | probs = F.softmax(output, dim=1) 72 | pred = probs.cpu().data[0].numpy() 73 | label = mask.cpu().data[0].numpy() 74 | # crf 75 | img = img.cpu().data[0].numpy() 76 | pred = dense_crf(img*255, pred) 77 | # print(pred.shape) 78 | # _, pred = torch.max(torch.tensor(pred), dim=-1) 79 | pred = np.asarray(pred, dtype=np.int) 80 | label = np.asarray(label, dtype=np.int) 81 | # 合并特征 82 | pred = merge_classes(pred) 83 | label = merge_classes(label) 84 | # print(pred.shape,label.shape) 85 | running_metrics.update(label,pred) 86 | 87 | score, class_iou = running_metrics.get_scores() 88 | for k, v in score.items(): 89 | print(k,':',v) 90 | print(i, class_iou) -------------------------------------------------------------------------------- /images/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengzhiliang/MRBrainS_seg/52c392edb0b3d3988cdf526002f2e6df5c8401fe/images/image.png -------------------------------------------------------------------------------- /images/image_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengzhiliang/MRBrainS_seg/52c392edb0b3d3988cdf526002f2e6df5c8401fe/images/image_mask.png -------------------------------------------------------------------------------- /images/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengzhiliang/MRBrainS_seg/52c392edb0b3d3988cdf526002f2e6df5c8401fe/images/mask.png -------------------------------------------------------------------------------- /images/pred_crf_mask.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengzhiliang/MRBrainS_seg/52c392edb0b3d3988cdf526002f2e6df5c8401fe/images/pred_crf_mask.png -------------------------------------------------------------------------------- /images/pred_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengzhiliang/MRBrainS_seg/52c392edb0b3d3988cdf526002f2e6df5c8401fe/images/pred_mask.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengzhiliang/MRBrainS_seg/52c392edb0b3d3988cdf526002f2e6df5c8401fe/model/__init__.py -------------------------------------------------------------------------------- /model/fcn.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | ''' 3 | Created on Nov 13,2018 4 | 5 | @author: pengzhiliang 6 | ''' 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torchsummary import summary 12 | 13 | class fcn(nn.Module): 14 | def __init__(self,n_classes=9): 15 | super(fcn, self).__init__() 16 | self.n_classes = n_classes 17 | 18 | self.conv_block1 = nn.Sequential( 19 | nn.Conv2d(3, 64, 3, padding=1),nn.ReLU(inplace=True), 20 | nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(inplace=True),) 21 | self.conv_block2 = nn.Sequential( 22 | nn.Conv2d(64, 128, 3, padding=1),nn.ReLU(inplace=True), 23 | nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),) 24 | self.conv_block3 = nn.Sequential( 25 | nn.Conv2d(128, 256, 3, padding=1),nn.ReLU(inplace=True), 26 | nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True), 27 | nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),) 28 | self.conv_block4 = nn.Sequential( 29 | nn.Conv2d(256, 512, 3, padding=1),nn.ReLU(inplace=True), 
30 | nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), 31 | nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) 32 | self.conv_block5 = nn.Sequential( 33 | nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), 34 | nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), 35 | nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) 36 | 37 | self.pool=nn.MaxPool2d(2, stride=2, ceil_mode=True) 38 | 39 | self.conv1_16=nn.Conv2d(64, 16, 3, padding=1) 40 | self.conv2_16=nn.Conv2d(128, 16, 3, padding=1) 41 | self.up_conv2_16 = nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2) 42 | self.conv3_16=nn.Conv2d(256, 16, 3, padding=1) 43 | self.up_conv3_16 = nn.ConvTranspose2d(16, 16, kernel_size=4, stride=4) 44 | self.conv4_16=nn.Conv2d(512, 16, 3, padding=1) 45 | self.up_conv4_16 = nn.ConvTranspose2d(16, 16, kernel_size=8, stride=8) 46 | self.conv5_16=nn.Conv2d(512, 16, 3, padding=1) 47 | self.up_conv5_16 = nn.ConvTranspose2d(16, 16, kernel_size=16, stride=16) 48 | 49 | self.score=nn.Sequential( 50 | nn.Conv2d(4*16,self.n_classes,1), 51 | nn.Dropout(0.5), 52 | ) 53 | 54 | def forward(self, x): 55 | conv1 = self.conv_block1(x) 56 | conv2 = self.conv_block2(self.pool(conv1)) 57 | conv3 = self.conv_block3(self.pool(conv2)) 58 | conv4 = self.conv_block4(self.pool(conv3)) 59 | conv5 = self.conv_block5(self.pool(conv4)) 60 | 61 | conv1_16=self.conv1_16(conv1) 62 | up_conv2_16=self.up_conv2_16(self.conv2_16(conv2)) 63 | up_conv3_16=self.up_conv3_16(self.conv3_16(conv3)) 64 | up_conv4_16=self.up_conv4_16(self.conv4_16(conv4)) 65 | up_conv5_16=self.up_conv5_16(self.conv5_16(conv5)) 66 | 67 | concat_1_to_4=torch.cat([conv1_16,up_conv2_16,up_conv3_16,up_conv4_16], 1) 68 | score=self.score(concat_1_to_4) 69 | return score # [-1, 9, 256, 256] 70 | 71 | def init_vgg16_params(self, vgg16, copy_fc8=True): 72 | blocks = [self.conv_block1, 73 | self.conv_block2, 74 | self.conv_block3, 75 | self.conv_block4, 76 | self.conv_block5] 77 | 78 | ranges = [[0, 4], [5, 9], [10, 16], 
[17, 23], [24, 29]] 79 | features = list(vgg16.features.children()) 80 | 81 | for idx, conv_block in enumerate(blocks): 82 | for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block): 83 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 84 | assert l1.weight.size() == l2.weight.size() 85 | assert l1.bias.size() == l2.bias.size() 86 | l2.weight.data = l1.weight.data 87 | l2.bias.data = l1.bias.data 88 | 89 | class fcn_5(nn.Module): 90 | def __init__(self,n_classes=9): 91 | super(fcn_5, self).__init__() 92 | self.n_classes = n_classes 93 | 94 | self.conv_block1 = nn.Sequential( 95 | nn.Conv2d(3, 64, 3, padding=1),nn.ReLU(inplace=True), 96 | nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(inplace=True),) 97 | self.conv_block2 = nn.Sequential( 98 | nn.Conv2d(64, 128, 3, padding=1),nn.ReLU(inplace=True), 99 | nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),) 100 | self.conv_block3 = nn.Sequential( 101 | nn.Conv2d(128, 256, 3, padding=1),nn.ReLU(inplace=True), 102 | nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True), 103 | nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),) 104 | self.conv_block4 = nn.Sequential( 105 | nn.Conv2d(256, 512, 3, padding=1),nn.ReLU(inplace=True), 106 | nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), 107 | nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) 108 | self.conv_block5 = nn.Sequential( 109 | nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), 110 | nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), 111 | nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) 112 | 113 | self.pool=nn.MaxPool2d(2, stride=2, ceil_mode=True) 114 | 115 | self.conv1_16=nn.Conv2d(64, 16, 3, padding=1) 116 | self.conv2_16=nn.Conv2d(128, 16, 3, padding=1) 117 | self.up_conv2_16 = nn.ConvTranspose2d(16, 16, kernel_size=2, stride=2) 118 | self.conv3_16=nn.Conv2d(256, 16, 3, padding=1) 119 | self.up_conv3_16 = nn.ConvTranspose2d(16, 16, kernel_size=4, stride=4) 120 | self.conv4_16=nn.Conv2d(512, 
16, 3, padding=1) 121 | self.up_conv4_16 = nn.ConvTranspose2d(16, 16, kernel_size=8, stride=8) 122 | self.conv5_16=nn.Conv2d(512, 16, 3, padding=1) 123 | self.up_conv5_16 = nn.ConvTranspose2d(16, 16, kernel_size=16, stride=16) 124 | 125 | self.finscore=nn.Sequential( 126 | nn.Conv2d(5*16,self.n_classes,1), 127 | nn.Dropout(0.5), 128 | ) 129 | 130 | def forward(self, x): 131 | conv1 = self.conv_block1(x) 132 | conv2 = self.conv_block2(self.pool(conv1)) 133 | conv3 = self.conv_block3(self.pool(conv2)) 134 | conv4 = self.conv_block4(self.pool(conv3)) 135 | conv5 = self.conv_block5(self.pool(conv4)) 136 | 137 | conv1_16=self.conv1_16(conv1) 138 | up_conv2_16=self.up_conv2_16(self.conv2_16(conv2)) 139 | up_conv3_16=self.up_conv3_16(self.conv3_16(conv3)) 140 | up_conv4_16=self.up_conv4_16(self.conv4_16(conv4)) 141 | up_conv5_16=self.up_conv5_16(self.conv5_16(conv5)) 142 | 143 | concat_1_to_5=torch.cat([conv1_16,up_conv2_16,up_conv3_16,up_conv4_16,up_conv5_16], 1) 144 | score=self.finscore(concat_1_to_5) 145 | return score # [-1, 9, 256, 256] 146 | 147 | def init_vgg16_params(self, vgg16, copy_fc8=True): 148 | blocks = [self.conv_block1, 149 | self.conv_block2, 150 | self.conv_block3, 151 | self.conv_block4, 152 | self.conv_block5] 153 | 154 | ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] 155 | features = list(vgg16.features.children()) 156 | 157 | for idx, conv_block in enumerate(blocks): 158 | for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block): 159 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 160 | assert l1.weight.size() == l2.weight.size() 161 | assert l1.bias.size() == l2.bias.size() 162 | l2.weight.data = l1.weight.data 163 | l2.bias.data = l1.bias.data 164 | class fcn_dilated(nn.Module): 165 | def __init__(self,n_classes=9): 166 | super(fcn_dilated, self).__init__() 167 | self.n_classes = n_classes 168 | 169 | self.conv_block1 = nn.Sequential( 170 | nn.Conv2d(3, 64, 3, dilation=1, padding=1),nn.ReLU(inplace=True), 171 
| nn.Conv2d(64, 64, 3, dilation=2, padding=2),nn.ReLU(inplace=True),) 172 | self.conv_block2 = nn.Sequential( 173 | nn.Conv2d(64, 128, 3, dilation=1, padding=1),nn.ReLU(inplace=True), 174 | nn.Conv2d(128, 128, 3, dilation=2, padding=2),nn.ReLU(inplace=True),) 175 | self.conv_block3 = nn.Sequential( 176 | nn.Conv2d(128, 256, 3, dilation=1, padding=1),nn.ReLU(inplace=True), 177 | nn.Conv2d(256, 256, 3, dilation=2, padding=2),nn.ReLU(inplace=True), 178 | nn.Conv2d(256, 256, 3, dilation=3, padding=3),nn.ReLU(inplace=True),) 179 | self.conv_block4 = nn.Sequential( 180 | nn.Conv2d(256, 512, 3, dilation=1, padding=1),nn.ReLU(inplace=True), 181 | nn.Conv2d(512, 512, 3, dilation=2, padding=2),nn.ReLU(inplace=True), 182 | nn.Conv2d(512, 512, 3, dilation=3, padding=3),nn.ReLU(inplace=True),) 183 | self.conv_block5 = nn.Sequential( 184 | nn.Conv2d(512, 512, 3, dilation=1, padding=1),nn.ReLU(inplace=True), 185 | nn.Conv2d(512, 512, 3, dilation=2, padding=2),nn.ReLU(inplace=True), 186 | nn.Conv2d(512, 512, 3, dilation=3, padding=3),nn.ReLU(inplace=True),) 187 | 188 | self.conv1_16=nn.Conv2d(64, 16, 3, padding=1) 189 | self.conv2_16=nn.Conv2d(128, 16, 3, padding=1) 190 | self.conv3_16=nn.Conv2d(256, 16, 3, padding=1) 191 | self.conv4_16=nn.Conv2d(512, 16, 3, padding=1) 192 | self.conv5_16=nn.Conv2d(512, 16, 3, padding=1) 193 | 194 | self.score=nn.Sequential( 195 | nn.Conv2d(4*16,self.n_classes,1), 196 | nn.Dropout(0.5), 197 | ) 198 | 199 | def forward(self, x): 200 | conv1 = self.conv_block1(x) 201 | conv2 = self.conv_block2(conv1) 202 | conv3 = self.conv_block3(conv2) 203 | conv4 = self.conv_block4(conv3) 204 | conv5 = self.conv_block5(conv4) 205 | 206 | conv1_16=self.conv1_16(conv1) 207 | conv2_16=self.conv2_16(conv2) 208 | conv3_16=self.conv3_16(conv3) 209 | conv4_16=self.conv4_16(conv4) 210 | conv5_16=self.conv5_16(conv5) 211 | 212 | concat_1_to_4=torch.cat([conv1_16,conv2_16,conv3_16,conv4_16], 1) 213 | score=self.score(concat_1_to_4) 214 | return score 215 | 216 | def 
init_vgg16_params(self, vgg16, copy_fc8=True): 217 | blocks = [self.conv_block1, 218 | self.conv_block2, 219 | self.conv_block3, 220 | self.conv_block4, 221 | self.conv_block5] 222 | 223 | ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] 224 | features = list(vgg16.features.children()) 225 | 226 | for idx, conv_block in enumerate(blocks): 227 | for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block): 228 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 229 | assert l1.weight.size() == l2.weight.size() 230 | assert l1.bias.size() == l2.bias.size() 231 | l2.weight.data = l1.weight.data 232 | l2.bias.data = l1.bias.data 233 | 234 | if __name__=='__main__': 235 | # x=torch.Tensor(4,3,256,256) 236 | # model=fcn(n_classes=9) 237 | # y=model(x) 238 | # print(y.shape) 239 | model=fcn(n_classes=9) 240 | summary(model.cuda(),(3,256,256)) -------------------------------------------------------------------------------- /model/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class UNet(nn.Module): 7 | def __init__(self, n_channels=3, n_classes=9): 8 | super(UNet, self).__init__() 9 | self.inc = inconv(n_channels, 64) 10 | self.down1 = down(64, 128) 11 | self.down2 = down(128, 256) 12 | self.down3 = down(256, 512) 13 | self.down4 = down(512, 512) 14 | self.up1 = up(1024, 256) 15 | self.up2 = up(512, 128) 16 | self.up3 = up(256, 64) 17 | self.up4 = up(128, 64) 18 | self.outfin = outconv(64, n_classes) 19 | 20 | def forward(self, x): 21 | x1 = self.inc(x) 22 | x2 = self.down1(x1) 23 | x3 = self.down2(x2) 24 | x4 = self.down3(x3) 25 | x5 = self.down4(x4) 26 | x = self.up1(x5, x4) 27 | x = self.up2(x, x3) 28 | x = self.up3(x, x2) 29 | x = self.up4(x, x1) 30 | x = self.outfin(x) 31 | return x 32 | 33 | 34 | class double_conv(nn.Module): 35 | '''(conv => BN => ReLU) * 2''' 36 | def __init__(self, in_ch, out_ch): 37 | 
super(double_conv, self).__init__() 38 | self.conv = nn.Sequential( 39 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 40 | nn.BatchNorm2d(out_ch), 41 | nn.ReLU(inplace=True), 42 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 43 | nn.BatchNorm2d(out_ch), 44 | nn.ReLU(inplace=True) 45 | ) 46 | 47 | def forward(self, x): 48 | x = self.conv(x) 49 | return x 50 | 51 | 52 | class inconv(nn.Module): 53 | def __init__(self, in_ch, out_ch): 54 | super(inconv, self).__init__() 55 | self.conv = double_conv(in_ch, out_ch) 56 | 57 | def forward(self, x): 58 | x = self.conv(x) 59 | return x 60 | 61 | 62 | class down(nn.Module): 63 | def __init__(self, in_ch, out_ch): 64 | super(down, self).__init__() 65 | self.mpconv = nn.Sequential( 66 | nn.MaxPool2d(2), 67 | double_conv(in_ch, out_ch) 68 | ) 69 | 70 | def forward(self, x): 71 | x = self.mpconv(x) 72 | return x 73 | 74 | 75 | class up(nn.Module): 76 | def __init__(self, in_ch, out_ch, bilinear=True): 77 | super(up, self).__init__() 78 | 79 | # would be a nice idea if the upsampling could be learned too, 80 | # but my machine do not have enough memory to handle all those weights 81 | if bilinear: 82 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 83 | else: 84 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 85 | 86 | self.conv = double_conv(in_ch, out_ch) 87 | 88 | def forward(self, x1, x2): 89 | x1 = self.up(x1) 90 | diffX = x1.size()[2] - x2.size()[2] 91 | diffY = x1.size()[3] - x2.size()[3] 92 | x2 = F.pad(x2, (diffX // 2, int(diffX / 2), 93 | diffY // 2, int(diffY / 2))) 94 | x = torch.cat([x2, x1], dim=1) 95 | x = self.conv(x) 96 | return x 97 | 98 | 99 | class outconv(nn.Module): 100 | def __init__(self, in_ch, out_ch): 101 | super(outconv, self).__init__() 102 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 103 | 104 | def forward(self, x): 105 | x = self.conv(x) 106 | return x -------------------------------------------------------------------------------- /predict.py: 
# -*- coding: utf-8 -*-
"""Segment one MRBrainS slice with a trained model and show/save the masks.

Reads ``./image.png`` and ``./image_mask.png``, runs the network (optionally
followed by a dense CRF), merges labels with ``merge_classes`` and writes
``image.png``, ``mask.png``, ``pred_mask.png`` (and ``pred_crf_mask.png``
when the CRF is enabled) to the working directory.

Created on Nov 14, 2018

@author: pengzhiliang
"""

import time
import numpy as np
import cv2
import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
from model.fcn import fcn, fcn_dilated
from model.unet import UNet
from dataloader.augmentation import *
from dataloader.coder import *
from utils.crf import dense_crf


# Image preprocessing
def transform(img, mask):
    """Convert an HWC image and an HW mask into tensors.

    The image is scaled by 1/255 and transposed to CHW float32; the mask
    becomes an int64 tensor.
    """
    img = img / 255.0
    img = img.astype(np.float64)
    img = img.transpose(2, 0, 1)
    img = torch.from_numpy(img).float()
    mask = torch.from_numpy(mask).long()
    return img, mask


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Choose either fcn or UNet as the model.
model = UNet().to(device)
resume_path = './checkpoint/best_' + 'unet' + '_model.pkl'  # fcn, unet
root_path = './image'
img_path = root_path + '.png'
mask_path = root_path + '_mask.png'

image = cv2.imread(img_path)
mask = cv2.imread(mask_path, 0)
# cv2.imread returns None instead of raising when a file is missing/unreadable.
if image is None or mask is None:
    raise FileNotFoundError("can't read '{}' or '{}'".format(img_path, mask_path))
img, mask = Compose([Scale(224)])(image.copy(), mask)
mask = mask.astype(np.uint8)
img, mask = transform(img, mask)
img, mask = torch.unsqueeze(img, 0), torch.unsqueeze(mask, 0)

# Resume from the checkpoint.
if osp.isfile(resume_path):
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(resume_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    best_iou = checkpoint['best_iou']
    print("=====>",
          "Loaded checkpoint '{}' (iter {})".format(
              resume_path, checkpoint["epoch"]
          ))
    print("=====> best mIoU: %.4f best mean dice: %.4f" % (best_iou, (best_iou * 2) / (best_iou + 1)))
else:
    raise ValueError("can't find model")

# Inference mode: BatchNorm must use its running statistics rather than the
# statistics of this single image.
model.eval()

crf = True

with torch.no_grad():
    img, mask = img.to(device), mask.to(device)
    output = model(img)  # [1, n_classes, H, W]
    probs = F.softmax(output, dim=1)
    if crf:
        pred_crf = probs.cpu().data[0].numpy()
        img = img.cpu().data[0].numpy()
        pred_crf = dense_crf(img * 255, pred_crf)
        pred_crf = np.asarray(pred_crf, dtype=np.int64)
        # Merge classes
        pred_crf = merge_classes(pred_crf)
    _, pred = torch.max(probs, dim=1)
    pred = pred.cpu().data[0].numpy()
    label = mask.cpu().data[0].numpy()
    pred = np.asarray(pred, dtype=np.int64)
    label = np.asarray(label, dtype=np.int64)
    pred = merge_classes(pred)
    label = merge_classes(label)
    cv2.namedWindow("image", 0)
    cv2.imshow("image", image)
    cv2.namedWindow("mask", 0)
    cv2.imshow("mask", encode(label, color_test))
    cv2.namedWindow("pred", 0)
    cv2.imshow("pred", encode(pred, color_test))
    if crf:  # pred_crf only exists when the CRF ran
        cv2.namedWindow("pred_crf", 0)
        cv2.imshow("pred_crf", encode(pred_crf, color_test))
    cv2.waitKey(0)
    cv2.destroyAllWindows()

    cv2.imwrite("image.png", image)
    cv2.imwrite("mask.png", encode(label, color_test))
    cv2.imwrite("pred_mask.png", encode(pred, color_test))
    if crf:
        cv2.imwrite("pred_crf_mask.png", encode(pred_crf, color_test))

# ---------------------------------------------------------------------------
# /train.py
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Train the VGG-FCN model on MRBrainS.
#
# Created on Oct 31, 2018
# @author: pengzhiliang

import time
import numpy as np
import os
import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler, Adam, SGD
from torchvision import datasets, models, transforms
from torchsummary import summary
from model.fcn import fcn, fcn_dilated
from utils.metrics import Score, averageMeter
from utils.loss import cross_entropy2d
from dataloader.MRBrain_loader import MRBrainSDataset
from dataloader.augmentation import *

# Hyper-parameters
defualt_path = osp.join('/home/cv_xfwang/data/', 'MRBrainS')
learning_rate = 1e-8
batch_size = 32
num_workers = 4
resume_path = '/home/cv_xfwang/MRBrainS_seg/checkpoint/best_model.pkl'
resume_flag = True
start_epoch = 0
end_epoch = 1000
test_interval = 10
print_interval = 1
momentum = 0.99
weight_decay = 0.005
best_iou = -100

# GPU or CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Setup Dataloader
data_aug = Compose([
    RandomHorizontallyFlip(0.5),
    RandomRotate(10),
    Scale(256),
])
train_loader = DataLoader(
    MRBrainSDataset(defualt_path, split='train', is_transform=True,
                    img_norm=True, augmentations=data_aug),
    batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
val_loader = DataLoader(
    MRBrainSDataset(defualt_path, split='val', is_transform=True,
                    img_norm=True, augmentations=None),
    batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=False)

# Setup model and summary.
model = fcn(n_classes=9)
vgg16 = models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("/home/cv_xfwang/pretrained/vgg16-397923af.pth", map_location="cpu"))
# Copy the VGG16 weights BEFORE moving to the device: init_vgg16_params assigns
# the (CPU) VGG tensors directly, which would otherwise leave params off-device.
model.init_vgg16_params(vgg16)
model = model.to(device)
# model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
# Print the network parameter summary; torchsummary defaults to device="cuda".
summary(model, (3, 256, 256), device=device.type)

# Optimizer and learning-rate schedule
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay
)

scheduler = lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(0.4 * end_epoch), int(0.7 * end_epoch), int(0.8 * end_epoch), int(0.9 * end_epoch)],
    gamma=0.1)
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10, verbose=True)
# cross_entropy2d is a plain function (logits, target) -> loss, not a factory.
criterion = cross_entropy2d

# resume
if osp.isfile(resume_path) and resume_flag:
    checkpoint = torch.load(resume_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    best_iou = checkpoint['best_iou']
    # scheduler.load_state_dict(checkpoint["scheduler_state"])
    # start_epoch = checkpoint["epoch"]
    print("=====>",
          "Loaded checkpoint '{}' (iter {})".format(
              resume_path, checkpoint["epoch"]
          )
          )
else:
    print("=====>", "No checkpoint found at '{}'".format(resume_path))


# Training
def train(epoch):
    """Run one epoch over ``train_loader`` and print the mean per-batch loss."""
    print("Epoch: ", epoch)
    model.train()
    total_loss = 0.0
    n_batches = 0
    for img, mask in train_loader:
        img, mask = img.to(device), mask.to(device)
        optimizer.zero_grad()
        output = model(img)  # [N, 9, H, W]
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()
        # .item(): accumulate a float so each batch's autograd graph is freed.
        total_loss += loss.item()
        n_batches += 1

    # criterion already averages over pixels, so average over batches only.
    print("loss: %.4f" % (total_loss / max(n_batches, 1)))


# return mean IoU, mean dice
def test(epoch):
    """Evaluate on ``val_loader``; checkpoint the model when mean IoU improves.

    Returns:
        (mIoU, mean_dice): mIoU is the nan-mean IoU over classes 1..8 (class 0
        is left out of the average); mean_dice = 2 * mIoU / (mIoU + 1).
    """
    print(">>>Test: ")
    global best_iou
    model.eval()
    running_metrics = Score(9)
    with torch.no_grad():
        for img, mask in val_loader:
            img = img.to(device)
            output = model(img)
            # Resize the logits to the label resolution before the arg-max.
            output = F.interpolate(output, size=tuple(mask.shape[-2:]), mode='bilinear', align_corners=True)
            probs = F.softmax(output, dim=1)
            _, pred = torch.max(probs, dim=1)
            # Score every sample in the batch, not only the first one.
            label = mask.cpu().numpy().astype(np.int64)
            pred = pred.cpu().numpy().astype(np.int64)
            running_metrics.update(label, pred)

    scores, class_iou = running_metrics.get_scores()
    mIoU = np.nanmean(class_iou[1::])
    mean_dice = (mIoU * 2) / (mIoU + 1)

    print("mean Iou", mIoU, "mean dice", mean_dice)
    if mIoU > best_iou:
        best_iou = mIoU
        state = {
            "epoch": epoch + 1,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_iou": best_iou,
        }
        save_path = osp.join(osp.split(resume_path)[0], "best_model.pkl")
        print("saving......")
        torch.save(state, save_path)
    return mIoU, mean_dice


for epoch in range(start_epoch, end_epoch):
    train(epoch)
    test(epoch)
    scheduler.step()

# ---------------------------------------------------------------------------
# /train_unet.py
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Train the UNet model on MRBrainS; validation is scored on merged labels.
#
# Created on Nov 14, 2018
# @author: pengzhiliang

import time
import numpy as np
import os
import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler, Adam, SGD
from torchvision import datasets, models, transforms
from torchsummary import summary
from model.unet import UNet
from utils.metrics import Score, averageMeter
from utils.loss import cross_entropy2d, BCEDiceLoss, bootstrapped_cross_entropy2d
from dataloader.MRBrain_loader import MRBrainSDataset
from dataloader.augmentation import *
from dataloader.coder import merge_classes


# Hyper-parameters
defualt_path = osp.join('/home/cv_xfwang/data/', 'MRBrainS')
learning_rate = 1e-6
batch_size = 32
num_workers = 4
resume_path = '/home/cv_xfwang/MRBrainS_seg/checkpoint/best_unet_model.pkl'
resume_flag = True
start_epoch = 0
end_epoch = 500
test_interval = 10
print_interval = 1
momentum = 0.99
weight_decay = 0.005
best_iou = -100

# GPU or CPU (the env var must be set before CUDA is first initialised).
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Setup Dataloader
data_aug = Compose([
    RandomHorizontallyFlip(0.5),
    RandomRotate(10),
    Scale(224),
])
train_loader = DataLoader(
    MRBrainSDataset(defualt_path, split='train', is_transform=True,
                    img_norm=True, augmentations=data_aug),
    batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
val_loader = DataLoader(
    MRBrainSDataset(defualt_path, split='val', is_transform=True,
                    img_norm=True, augmentations=Compose([Scale(224)])),
    batch_size=1, num_workers=num_workers, pin_memory=True, shuffle=False)

# Setup model and print the network parameter summary.
model = UNet().to(device)
summary(model, (3, 224, 224), batch_size, device=device.type)
# model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

# Optimizer and learning-rate settings
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
# Learning-rate scheduler
scheduler = lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(0.2 * end_epoch), int(0.6 * end_epoch), int(0.9 * end_epoch)],
    gamma=0.01)
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10, verbose=True)
criterion = cross_entropy2d
# criterion = BCEDiceLoss()

# running_metrics = Score(n_classes=9)
running_metrics = Score(n_classes=4)  # scored on merge_classes() output
# NOTE(review): presumably the 9 -> 4 mapping applied by merge_classes; unused here.
label_test = [0, 2, 2, 3, 3, 1, 1, 0, 0]
# resume
if osp.isfile(resume_path) and resume_flag:
    checkpoint = torch.load(resume_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    best_iou = checkpoint['best_iou']
    # scheduler.load_state_dict(checkpoint["scheduler_state"])
    # start_epoch = checkpoint["epoch"]
    print("=====>",
          "Loaded checkpoint '{}' (iter {})".format(
              resume_path, checkpoint["epoch"]
          )
          )
else:
    print("=====>", "No checkpoint found at '{}'".format(resume_path))
    print("load unet weight and bias")
    # Initialise from a Pytorch-UNet checkpoint, keeping only matching keys.
    model_dict = model.state_dict()
    pretrained_dict = torch.load("/home/cv_xfwang/Pytorch-UNet/MODEL.pth", map_location=device)
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)


# Training
def train(epoch):
    """Run one epoch over ``train_loader`` and print the mean per-batch loss."""
    print("Epoch: ", epoch)
    model.train()
    total_loss = 0.0
    n_batches = 0
    for img, mask in tqdm(train_loader, total=len(train_loader), desc="Epoch {}".format(epoch), ncols=0):
        # img: [N, 3, 224, 224]  mask: [N, 224, 224]
        img, mask = img.to(device), mask.to(device)
        optimizer.zero_grad()
        output = model(img)  # [N, 9, 224, 224]
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()
        # .item(): accumulate a float so each batch's autograd graph is freed.
        total_loss += loss.item()
        n_batches += 1

    # criterion already averages over pixels, so average over batches only.
    print("Average loss: %.4f" % (total_loss / max(n_batches, 1)))


# Evaluate and checkpoint on improvement.
def test(epoch):
    """Score ``val_loader`` on merged labels; save the model when mean IoU improves."""
    print(">>>Test: ")
    global best_iou
    model.eval()
    running_metrics.reset()
    with torch.no_grad():
        for i, (img, mask) in tqdm(enumerate(val_loader)):
            img = img.to(device)
            output = model(img)  # [1, 9, H, W]
            probs = F.softmax(output, dim=1)
            _, pred = torch.max(probs, dim=1)
            # val_loader uses batch_size=1, so index 0 is the only sample.
            pred = pred.cpu().data[0].numpy().astype(np.int64)
            label = mask.cpu().data[0].numpy().astype(np.int64)

            running_metrics.update(merge_classes(label), merge_classes(pred))

    score, class_iou = running_metrics.get_scores()
    for k, v in score.items():
        print(k, ':', v)
    print(i, class_iou)
    if score["Mean IoU : \t"] > best_iou:
        best_iou = score["Mean IoU : \t"]
        state = {
            "epoch": epoch + 1,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_iou": best_iou,
        }
        save_path = osp.join(osp.split(resume_path)[0], "best_unet_model.pkl")
        print("saving......")
        torch.save(state, save_path)


for epoch in range(start_epoch, end_epoch):
    train(epoch)
    test(epoch)
    scheduler.step()

# ---------------------------------------------------------------------------
# /utils/crf.py
# ---------------------------------------------------------------------------
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import compute_unary, create_pairwise_bilateral, create_pairwise_gaussian, unary_from_softmax


def dense_crf(img, prob, num_iter=50, sxy=5, srgb=3, compat=3):
    """Refine a soft segmentation with a fully connected (dense) CRF.

    Args:
        img: array of shape (num_channels=3, H, W) with intensities in
            [0, 255]; values are clipped to that range and truncated to uint8
            for the bilateral term.
        prob: array of shape (n_classes, H, W) of per-pixel class
            probabilities (e.g. the network's softmax output for ``img``).
        num_iter: number of mean-field inference iterations.
        sxy, srgb, compat: parameters of the bilateral pairwise term.

    Returns:
        res: (H, W) array holding the arg-max label of every pixel after
        inference.

    Raises:
        ValueError: if ``img``/``prob`` are not 3-D or their spatial sizes differ.

    Works for non-square images (the previous implementation transposed the
    arrays and only matched pydensecrf's (h, w, 3) check when H == W), and
    takes the number of classes from ``prob`` instead of hard-coding 9.

    Modified from:
    http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/18/image-segmentation-with-tensorflow-using-cnns-and-conditional-random-fields/
    https://github.com/yt605155624/tensorflow-deeplab-resnet/blob/e81482d7bb1ae674f07eae32b0953fe09ff1c9d1/inference_crf.py
    """
    img = np.asarray(img)
    prob = np.asarray(prob, dtype=np.float32)
    if img.ndim != 3 or prob.ndim != 3 or img.shape[1:] != prob.shape[1:]:
        raise ValueError(
            "expected img (C, H, W) and prob (n_classes, H, W) with equal H, W; "
            "got {} and {}".format(img.shape, prob.shape))
    n_classes, height, width = prob.shape

    # pydensecrf wants a C-contiguous (H, W, 3) uint8 image.
    rgb = np.ascontiguousarray(np.clip(img, 0, 255).transpose(1, 2, 0), dtype=np.uint8)
    # Unary energy is -log(prob) flattened to (n_classes, H * W), row-major pixels.
    unary = np.ascontiguousarray(unary_from_softmax(prob))

    d = dcrf.DenseCRF2D(width, height, n_classes)  # (width, height, n_labels)
    d.setUnaryEnergy(unary)
    d.addPairwiseBilateral(sxy=sxy, srgb=srgb, rgbim=rgb, compat=compat)

    q = d.inference(num_iter)
    return np.argmax(q, axis=0).reshape((height, width))
# ---------------------------------------------------------------------------
# /utils/loss.py
# ---------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


def cross_entropy2d(inp, target, weight=None, size_average=True):
    """Pixel-wise cross-entropy between logits and an integer label map.

    Args:
        inp: logits of shape (N, C, H, W).
        target: integer labels of shape (N, Ht, Wt). Pixels with a negative
            label are dropped; label 250 is ignored by the loss but still
            counted in the normaliser.
        weight: optional per-class weight tensor of shape (C,).
        size_average: if true, divide the summed loss by the number of pixels
            with a non-negative label; otherwise return the sum.

    When both spatial dims differ in the same direction, the smaller tensor is
    upsampled (labels: nearest neighbour, logits: bilinear).

    Raises:
        ValueError: ("Only support upsampling") for any other size mismatch.
    """
    n, c, h, w = inp.size()
    nt, ht, wt = target.size()

    # Handle inconsistent size between inp and target
    if h > ht and w > wt:  # upsample labels
        # interpolate() rejects integer tensors: go through float and back.
        target = F.interpolate(target.unsqueeze(1).float(), size=(h, w), mode='nearest')
        target = target.squeeze(1).long()
    elif h < ht and w < wt:  # upsample logits
        inp = F.interpolate(inp, size=(ht, wt), mode='bilinear', align_corners=False)
    elif h != ht or w != wt:
        raise ValueError("Only support upsampling")

    # (N, C, H, W) -> (N*H*W, C): one row of log-probabilities per pixel.
    log_p = F.log_softmax(inp, dim=1).permute(0, 2, 3, 1).reshape(-1, c)
    flat_target = target.reshape(-1)
    valid = flat_target >= 0
    log_p = log_p[valid]
    flat_target = flat_target[valid]

    loss = F.nll_loss(log_p, flat_target, weight=weight, ignore_index=250, reduction='sum')
    if size_average:
        # clamp avoids 0/0 when no pixel carries a valid label.
        loss = loss / valid.sum().clamp(min=1)
    return loss


def dice_loss(preds, trues, weight=None, is_average=True):
    """Return the smoothed Dice coefficient between ``preds`` and ``trues``.

    Both inputs are flattened per sample (first dim). With smoothing 1 the
    score is (2*|P.T| + 1) / (|P| + |T| + 1), which stays in [0, 1] for
    inputs in [0, 1] (identical masks, including two empty masks, score 1).

    Args:
        weight: optional tensor broadcast-multiplied into both inputs.
        is_average: if true return the clamped batch mean, else per-sample scores.
    """
    num = preds.size(0)
    preds = preds.reshape(num, -1)
    trues = trues.reshape(num, -1)
    if weight is not None:
        w = weight.reshape(num, -1)
        preds = preds * w
        trues = trues * w
    smooth = 1.0
    intersection = (preds * trues).sum(1)
    scores = (2.0 * intersection + smooth) / (preds.sum(1) + trues.sum(1) + smooth)

    if is_average:
        score = scores.sum() / num
        return torch.clamp(score, 0., 1.)
    return scores


def dice_clamp(preds, trues, is_average=True):
    """Dice coefficient of rounded (hard 0/1) predictions."""
    preds = torch.round(preds)
    return dice_loss(preds, trues, is_average=is_average)


class DiceLoss(nn.Module):
    """1 - Dice coefficient of sigmoid(logits) against a binary target."""

    def __init__(self, size_average=True):
        super().__init__()
        self.size_average = size_average

    def forward(self, input, target, weight=None):
        return 1 - dice_loss(torch.sigmoid(input), target, weight=weight, is_average=self.size_average)


class BCEDiceLoss(nn.Module):
    """Binary cross-entropy on logits plus ``DiceLoss``."""

    def __init__(self, size_average=True):
        super().__init__()
        self.size_average = size_average
        self.dice = DiceLoss(size_average=size_average)

    def forward(self, input, target, weight=None):
        bce = F.binary_cross_entropy_with_logits(
            input, target, weight=weight,
            reduction='mean' if self.size_average else 'sum')
        return bce + self.dice(input, target, weight=weight)


def bootstrapped_cross_entropy2d(inp, target, K, weight=None, size_average=True):
    """Mean of the ``K`` largest per-pixel cross-entropies, averaged over the batch.

    Hard pixels are mined per image, not over the whole batch. If an image has
    fewer than ``K`` valid pixels, all of them are used.

    Args:
        inp: logits of shape (N, C, H, W).
        target: labels of shape (N, H, W); negatives dropped, 250 ignored.
    """
    batch_size = inp.size()[0]

    def _bootstrap_xentropy_single(inp, target, K, weight=None, size_average=True):
        n, c, h, w = inp.size()
        log_p = F.log_softmax(inp, dim=1).permute(0, 2, 3, 1).reshape(-1, c)
        flat_target = target.reshape(-1)
        valid = flat_target >= 0
        log_p = log_p[valid]
        flat_target = flat_target[valid]
        loss = F.nll_loss(log_p, flat_target, weight=weight, ignore_index=250, reduction='none')
        k = min(K, loss.numel())
        if k == 0:
            return loss.sum()  # no valid pixels: zero loss, graph kept
        topk_loss, _ = loss.topk(k)
        return topk_loss.sum() / k

    loss = 0.0
    # Bootstrap from each image, not the entire batch.
    for i in range(batch_size):
        loss += _bootstrap_xentropy_single(inp=torch.unsqueeze(inp[i], 0),
                                           target=torch.unsqueeze(target[i], 0),
                                           K=K,
                                           weight=weight,
                                           size_average=size_average)
    return loss / float(batch_size)


# ---------------------------------------------------------------------------
# /utils/metrics.py
# ---------------------------------------------------------------------------
import numpy as np

# Usage:
#     running_metrics = Score(n_classes=9)
#     running_metrics.update(gt, pred)  # gt, pred: integer numpy arrays
#     score, class_iou = running_metrics.get_scores()
#     for k, v in score.items():
#         print(k, v)
#     for i in range(n_classes):
#         print(i, class_iou[i])


class Score(object):
    """Accumulate a confusion matrix and derive segmentation scores from it."""

    def __init__(self, n_classes):
        self.n_classes = n_classes
        # confusion_matrix[t, p] counts pixels with true label t predicted as p.
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        """Confusion matrix of two flat label arrays; true labels outside [0, n_class) are skipped."""
        mask = (label_true >= 0) & (label_true < n_class)
        hist = np.bincount(
            n_class * label_true[mask].astype(int) + label_pred[mask].astype(int),
            minlength=n_class ** 2,
        ).reshape(n_class, n_class)
        return hist

    def update(self, label_trues, label_preds):
        """Add paired ground-truth/prediction arrays, iterating over their first axis."""
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(
                lt.flatten(), lp.flatten(), self.n_classes
            )

    def get_scores(self):
        """Return (scores, per-class IoU) from the accumulated confusion matrix.

        scores holds overall accuracy, mean per-class accuracy, frequency-
        weighted IoU accuracy, mean IoU and the Dice value derived from mean IoU
        (2*IoU / (IoU + 1)). Classes absent from both GT and predictions get
        NaN IoU and are skipped by the nan-means.
        """
        hist = self.confusion_matrix
        tp = np.diag(hist)
        gt_count = hist.sum(axis=1)
        pred_count = hist.sum(axis=0)
        # Absent classes legitimately produce 0/0; keep the NaN without warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            acc = tp.sum() / hist.sum()
            acc_cls = np.nanmean(tp / gt_count)
            iu = tp / (gt_count + pred_count - tp)
            mean_iu = np.nanmean(iu)
            mean_dice = (mean_iu * 2) / (mean_iu + 1)
            freq = gt_count / hist.sum()
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        cls_iu = iu

        return (
            {
                "Overall Acc: \t": acc,
                "Mean Acc : \t": acc_cls,
                "FreqW Acc : \t": fwavacc,
                "Mean IoU : \t": mean_iu,
                "Mean dice : \t": mean_dice,
            },
            cls_iu,
        )

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))


class averageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count