├── D-D-M.png ├── README.md ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── camvid.cpython-36.pyc │ ├── camvid_PG.cpython-36.pyc │ └── joint_transforms.cpython-36.pyc ├── data_PG.py └── joint_transforms.py ├── eval_PG.py ├── test.py ├── train.py ├── train_1024_DataParallel.py ├── train_1024_apex.py ├── train_1024_distributed.py ├── train_1024_horovod.py ├── train_1024_multiprocessing.py ├── train_PG.py ├── unet ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── dice_loss.cpython-36.pyc │ ├── unet_model.cpython-36.pyc │ ├── unet_model_BN_MP_PG.cpython-36.pyc │ ├── unet_model_PG.cpython-36.pyc │ ├── unet_model_PG_256.cpython-36.pyc │ ├── unet_model_PG_more_conv.cpython-36.pyc │ ├── unet_model_PG_new.cpython-36.pyc │ └── unet_parts.cpython-36.pyc ├── dice_loss.py ├── unet_model.py └── unet_parts.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── crf.cpython-36.pyc ├── data_vis.cpython-36.pyc ├── imgs.cpython-36.pyc ├── load.cpython-36.pyc ├── training.cpython-36.pyc └── utils.cpython-36.pyc ├── crf.py ├── data_vis.py ├── imgs.py ├── load.py ├── training.py └── utils.py /D-D-M.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/D-D-M.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PGU-net-Model (PyTorch 1.0, Python 3.7, multi-GPU) 2 | 3 | This is the implementation of my contribution "PGU-net+: Progressive Growing of U-net+ for Automated Cervical Nuclei Segmentation" to the MMMI workshop at MICCAI 2019. 4 | 5 | Paper link: http://arxiv.org/abs/1911.01062 6 | Please cite this paper if you find this project helpful for your research. 
7 | 8 | # Dependencies 9 | 10 | Python 3.7, CUDA 10.1, numpy 1.17, matplotlib 3.1.1, scikit-image (0.21), scipy (0.3.1), pyparsing (2.4.2), PyTorch (1.0) (Anaconda is recommended). 11 | Other packages can be the latest versions. 12 | 13 | # Training: 14 | 15 | 1. Install all dependencies. 16 | 17 | 2. For PGU-net+ (Progressive Growing of U-net+), run 18 | python train_PG.py 19 | 20 | For the U-net structure of the final stage, run 21 | python train.py 22 | 23 | # Testing: 24 | For testing, run 25 | python test.py 26 | 27 | # Distributed training mode 28 | Several multi-GPU parallel and distributed training methods were added later, adapted from (参考自) https://github.com/tczhangzhi/pytorch-distributed: 29 | 30 | 1. nn.DataParallel: 简单方便的 nn.DataParallel (simple and convenient nn.DataParallel) 31 | 32 | 2. torch.distributed: 使用 torch.distributed 加速并行训练 (accelerate parallel training with torch.distributed) 33 | 34 | 3. torch.multiprocessing: 使用 torch.multiprocessing 取代启动器 (use torch.multiprocessing instead of the launcher) 35 | 36 | 4. apex: 使用 apex 再加速 (further speed-up with apex) 37 | 38 | 5. horovod: horovod 的优雅实现 (an elegant implementation with horovod) 39 | 40 | To train the U-net structure with each of the five distributed training methods above, run: 41 | 42 | python train_1024_DataParallel.py for nn.DataParallel 43 | 44 | python train_1024_distributed.py for torch.distributed 45 | 46 | python train_1024_multiprocessing.py for torch.multiprocessing 47 | 48 | python train_1024_apex.py for apex 49 | 50 | python train_1024_horovod.py for horovod 51 | 52 | 实验结果如下图(apex和horovod存在一些bug): the experimental results are shown in the figure below (the apex and horovod versions still have some bugs). 53 | 54 | 55 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- 
/datasets/__pycache__/camvid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/datasets/__pycache__/camvid.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/camvid_PG.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/datasets/__pycache__/camvid_PG.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/joint_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/datasets/__pycache__/joint_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/data_PG.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | import numpy as np 5 | from PIL import Image 6 | import cv2 7 | import random 8 | from torchvision.datasets.folder import default_loader 9 | classes = ['Sky', 'Building', 'Column-Pole', 'Road'] 10 | class_weight = torch.FloatTensor([0.1, 0.1, 0.5,1]) 11 | # class_weight = torch.FloatTensor([0.0057471264, 0.0050251, 0.00884955752,1]) 12 | # mean = [0.611, 0.506, 0.54] 13 | mean = [0.6127558736339982,0.5071148744673234,0.5406509545283443] 14 | std = [0.13964046123851956,0.16156206296516235,0.165885041027991] 15 | testmean = [0.6170943891910641,0.5133861905981716,0.545347489522038] 16 | teststd = [0.14098655787705194,0.16313775003634445,0.16636559984060037] 17 | class_color = [(128, 128, 128), (128, 0, 0), (192, 192, 128), (128, 64, 128)] 18 | 
pred_dir0 = './train_pred00/' 19 | def _make_dataset(dir): 20 | images = [] 21 | names = [] 22 | for root, _, fnames in sorted(os.walk(dir)): 23 | for fname in fnames: 24 | # if is_image_file(fname): 25 | path = os.path.join(root, fname) 26 | 27 | item = path 28 | name = path[-27:] 29 | # print(name) 30 | images.append(item) 31 | names.append(name) 32 | return images, names 33 | 34 | class LabelToLongTensor(object): 35 | def __call__(self, pic): 36 | if isinstance(pic, np.ndarray): 37 | # handle numpy array 38 | label = torch.from_numpy(pic).long() 39 | else: 40 | label = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 41 | label = label.view(pic.size[1], pic.size[0], 1) 42 | label = label.transpose(0, 1).transpose(0, 2).squeeze().contiguous().long() 43 | # print('0', label.unsqueeze(2).shape, name) 44 | # cv2.imwrite(name,80*(label.unsqueeze(2).cpu().numpy())) 45 | return label 46 | 47 | class LabelTensorToPILImage(object): 48 | def __call__(self, label): 49 | label = label.unsqueeze(0) 50 | colored_label = torch.zeros(3, label.size(1), label.size(2)).byte() 51 | for i, color in enumerate(class_color): 52 | mask = label.eq(i) 53 | for j in range(3): 54 | colored_label[j].masked_fill_(mask, color[j]) 55 | npimg = colored_label.numpy() 56 | npimg = np.transpose(npimg, (1, 2, 0)) 57 | mode = None 58 | if npimg.shape[2] == 1: 59 | npimg = npimg[:, :, 0] 60 | mode = "L" 61 | return Image.fromarray(npimg, mode=mode) 62 | 63 | class CamVid(data.Dataset): 64 | def __init__(self, root, split='train', joint_transform=None, 65 | transform=None, target_transform=LabelToLongTensor(), 66 | download=False, 67 | loader=default_loader): 68 | self.root = root 69 | assert split in ('train', 'val', 'test') 70 | self.split = split 71 | self.transform = transform 72 | self.target_transform = target_transform 73 | self.joint_transform = joint_transform 74 | self.loader = loader 75 | self.class_weight = class_weight 76 | self.classes = classes 77 | self.mean = mean 78 | 
self.std = std 79 | # print(os.path.join(self.root, self.split)) 80 | print(self.split) 81 | if download: 82 | self.download() 83 | self.imgs, self.names = _make_dataset(os.path.join(self.root, self.split)) 84 | # self.scaled_savepath = scaled_savepath 85 | 86 | def __getitem__(self, index): 87 | path = self.imgs[index] 88 | img = self.loader(path) 89 | # print('1',img.size) 90 | 91 | target = Image.open(path.replace(self.split, self.split + 'annot')) 92 | # i = random.randint(0,100) 93 | # img.save(os.path.join(pred_dir0, '%r_input.png' % i)) 94 | # 255*target.save(os.path.join(pred_dir0, '%r_gt.png' % i)) 95 | if self.joint_transform is not None: 96 | # img= self.joint_transform(img) 97 | img, target = self.joint_transform([img, target]) 98 | # target = self.joint_transform(target) 99 | # print('1',np.max(img), np.max(target)) 100 | if self.transform is not None: 101 | img = self.transform(img) 102 | # name = self.scaled_savepath + self.split + '/' + name 103 | target = self.target_transform(target) 104 | # print('1',img.shape, target.shape) 105 | # print('1',np.max(img), np.max(target)) 106 | return img, target 107 | def __len__(self): 108 | return len(self.imgs) 109 | def download(self): 110 | raise NotImplementedError 111 | -------------------------------------------------------------------------------- /datasets/joint_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | import numbers 8 | import types 9 | 10 | 11 | class JointScale(object): 12 | """Rescales the input PIL.Image to the given 'size'. 13 | 'size' will be the size of the smaller edge. 
14 | For example, if height > width, then image will be 15 | rescaled to (size * height / width, size) 16 | size: size of the smaller edge 17 | interpolation: Default: PIL.Image.BILINEAR 18 | """ 19 | 20 | def __init__(self, size, interpolation=Image.BILINEAR): 21 | self.size = size 22 | self.interpolation = interpolation 23 | 24 | def __call__(self, imgs): 25 | w, h = imgs[0].size 26 | if (w <= h and w == self.size) or (h <= w and h == self.size): 27 | return imgs 28 | if w < h: 29 | ow = self.size 30 | oh = int(self.size * h / w) 31 | return [img.resize((ow, oh), self.interpolation) for img in imgs] 32 | else: 33 | oh = self.size 34 | ow = int(self.size * w / h) 35 | return [img.resize((ow, oh), self.interpolation) for img in imgs] 36 | 37 | 38 | class JointCenterCrop(object): 39 | """Crops the given PIL.Image at the center to have a region of 40 | the given size. size can be a tuple (target_height, target_width) 41 | or an integer, in which case the target will be of a square shape (size, size) 42 | """ 43 | 44 | def __init__(self, size): 45 | if isinstance(size, numbers.Number): 46 | self.size = (int(size), int(size)) 47 | else: 48 | self.size = size 49 | 50 | def __call__(self, imgs): 51 | w, h = imgs[0].size 52 | th, tw = self.size 53 | x1 = int(round((w - tw) / 2.)) 54 | y1 = int(round((h - th) / 2.)) 55 | return [img.crop((x1, y1, x1 + tw, y1 + th)) for img in imgs] 56 | 57 | 58 | class JointPad(object): 59 | """Pads the given PIL.Image on all sides with the given "pad" value""" 60 | 61 | def __init__(self, padding, fill=0): 62 | assert isinstance(padding, numbers.Number) 63 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 64 | self.padding = padding 65 | self.fill = fill 66 | 67 | def __call__(self, imgs): 68 | return [ImageOps.expand(img, border=self.padding, fill=self.fill) for img in imgs] 69 | 70 | 71 | class JointLambda(object): 72 | """Applies a lambda as a transform.""" 73 | 74 | def __init__(self, 
lambd): 75 | assert isinstance(lambd, types.LambdaType) 76 | self.lambd = lambd 77 | 78 | def __call__(self, imgs): 79 | return [self.lambd(img) for img in imgs] 80 | 81 | class JointRandomCrop(object): 82 | """Crops the given list of PIL.Image at a random location to have a region of 83 | the given size. size can be a tuple (target_height, target_width) 84 | or an integer, in which case the target will be of a square shape (size, size) 85 | """ 86 | 87 | def __init__(self, size, padding=0): 88 | if isinstance(size, numbers.Number): 89 | self.size = (int(size), int(size)) 90 | else: 91 | self.size = size 92 | self.padding = padding 93 | 94 | def __call__(self, imgs): 95 | if self.padding > 0: 96 | imgs = [ImageOps.expand(img, border=self.padding, fill=0) for img in imgs] 97 | 98 | # print('2',imgs.size) 99 | w, h = imgs[0].size 100 | th, tw = self.size 101 | if w == tw and h == th: 102 | return imgs 103 | 104 | x1 = random.randint(0, w - tw) 105 | y1 = random.randint(0, h - th) 106 | # return [img.crop((x1, y1, x1 + tw, y1 + th)) for img in imgs] 107 | return imgs.crop((x1, y1, x1 + tw, y1 + th)) 108 | 109 | 110 | class JointRandomHorizontalFlip(object): 111 | """Randomly horizontally flips the given list of PIL.Image with a probability of 0.5 112 | """ 113 | 114 | def __call__(self, imgs): 115 | if random.random() < 0.5: 116 | return [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs] 117 | # return imgs.transpose(Image.FLIP_LEFT_RIGHT) 118 | # return imgs.transpose(Image.FLIP_LEFT_RIGHT) 119 | return imgs 120 | 121 | 122 | class JointRandomSizedCrop(object): 123 | """Random crop the given list of PIL.Image to a random size of (0.08 to 1.0) of the original size 124 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 125 | This is popularly used to train the Inception networks 126 | size: size of the smaller edge 127 | interpolation: Default: PIL.Image.BILINEAR 128 | """ 129 | 130 | def __init__(self, size, interpolation=Image.BILINEAR): 
131 | self.size = size 132 | self.interpolation = interpolation 133 | 134 | def __call__(self, imgs): 135 | for attempt in range(10): 136 | area = imgs[0].size[0] * imgs[0].size[1] 137 | target_area = random.uniform(0.08, 1.0) * area 138 | aspect_ratio = random.uniform(3. / 4, 4. / 3) 139 | 140 | w = int(round(math.sqrt(target_area * aspect_ratio))) 141 | h = int(round(math.sqrt(target_area / aspect_ratio))) 142 | 143 | if random.random() < 0.5: 144 | w, h = h, w 145 | 146 | if w <= imgs[0].size[0] and h <= imgs[0].size[1]: 147 | x1 = random.randint(0, imgs[0].size[0] - w) 148 | y1 = random.randint(0, imgs[0].size[1] - h) 149 | 150 | imgs = [img.crop((x1, y1, x1 + w, y1 + h)) for img in imgs] 151 | assert(imgs[0].size == (w, h)) 152 | 153 | return [img.resize((self.size, self.size), self.interpolation) for img in imgs] 154 | 155 | # Fallback 156 | scale = JointScale(self.size, interpolation=self.interpolation) 157 | crop = JointCenterCrop(self.size) 158 | return crop(scale(imgs)) 159 | -------------------------------------------------------------------------------- /eval_PG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import glob 5 | import os 6 | # from dice_loss import dice_coeff 7 | from torch.autograd import Variable 8 | import numpy as np 9 | import cv2 10 | def get_predictions(output_batch): 11 | bs,c,h,w = output_batch.size() 12 | tensor = output_batch.data 13 | values, indices = tensor.cpu().max(1) 14 | indices = indices.view(bs,h,w) 15 | return indices 16 | 17 | def error(preds, targets): 18 | assert preds.size() == targets.size() 19 | bs,h,w = preds.size() 20 | n_pixels = bs*h*w 21 | # n_pixels = 1 22 | incorrect = preds.ne(targets).cpu().sum().numpy() 23 | err = incorrect/n_pixels 24 | # print(incorrect,n_pixels,err) 25 | # return round(err,5) 26 | return err 27 | 28 | def train_net(net, dataset, criterion, optimizer, EE_size): 29 | 
net.train() 30 | trn_loss = 0 31 | trn_error = 0 32 | DICE = 0 33 | for i, data in enumerate(dataset): 34 | # with torch.no_grad(): 35 | imgs = data[0] 36 | true_masks = data[1] 37 | imgs = torch.FloatTensor(F.interpolate(imgs, EE_size, mode='nearest')).cuda() 38 | true_masks = torch.FloatTensor(F.interpolate((true_masks.unsqueeze(1).float()), EE_size, mode='nearest')).squeeze(1).cuda().long() 39 | # imgs = torch.FloatTensor(imgs) 40 | imgs = Variable(imgs.cuda(), requires_grad=True) 41 | # true_masks = Variable(true_masks.cuda())/80###for pre data 42 | true_masks = Variable(true_masks.cuda()) 43 | masks_pred = net(imgs) 44 | pred = get_predictions(masks_pred) 45 | # print(pred.shape) 46 | dice = DiceLoss(true_masks.data.cpu(), pred) 47 | DICE = DICE + dice 48 | # print('----------------------train-', true_masks.shape, masks_pred.shape, imgs.shape) 49 | loss = criterion(masks_pred, true_masks) 50 | # print('loss',loss.item()) 51 | trn_loss = trn_loss + loss.item() 52 | trn_error += error(pred, true_masks.data.cpu()) 53 | 54 | optimizer.zero_grad() 55 | loss.backward() 56 | optimizer.step() 57 | # print('----------------------train-', true_masks.shape, masks_pred.shape, imgs.shape) 58 | trn_loss /= len(dataset) 59 | trn_error /= len(dataset) 60 | DICE /= len(dataset) 61 | del imgs, true_masks, masks_pred, pred 62 | return trn_loss, trn_error, DICE 63 | 64 | def eval_net(net, test_loader, criterion, EE_size): 65 | test_loss = 0 66 | test_error = 0 67 | DICE = 0 68 | for data, target in test_loader: 69 | # with torch.no_grad(): 70 | data = torch.FloatTensor(F.interpolate(data, EE_size, mode='nearest')).cuda() 71 | target = torch.FloatTensor(F.interpolate((target.unsqueeze(1).float()), EE_size, mode='nearest')).squeeze(1).cuda().long() 72 | # print('----------------------val-', target.shape, data.shape) 73 | data = Variable(data.cuda()) 74 | # target = Variable(target.cuda())/80###for pre data 75 | target = Variable(target.cuda()) 76 | output = net(data) 77 | pred = 
get_predictions(output) 78 | dice = DiceLoss(target.data.cpu(), pred) 79 | DICE = DICE + dice 80 | test_loss += criterion(output, target).item() 81 | pred = get_predictions(output) 82 | test_error += error(pred, target.data.cpu()) 83 | del data, target, output, pred 84 | # print('----------------------val-', target.shape, output.shape, data.shape) 85 | test_loss /= len(test_loader) 86 | test_error /= len(test_loader) 87 | DICE /= len(test_loader) 88 | 89 | return test_loss, test_error, DICE 90 | 91 | def save_weights(WEIGHTS_PATH,model, epoch, loss, err): 92 | # weights_fname = 'weights-%d-%.3f-%.3f-%.4f.pth' % (epoch, loss, err, dice) 93 | # weights_fpath = os.path.join(WEIGHTS_PATH, weights_fname) 94 | name = str(WEIGHTS_PATH) + '/' + str(epoch) + '_' + str(round(loss,5)) + '_' + str(round(err,5)) + '_model.pkl' 95 | print(name) 96 | torch.save(model, name) 97 | # torch.save({ 98 | def save_weights_dice(WEIGHTS_PATH,model, epoch, loss, err, dice, EE): 99 | # weights_fname = 'weights-%d-%.3f-%.3f-%.4f.pth' % (epoch, loss, err, dice) 100 | # weights_fpath = os.path.join(WEIGHTS_PATH, weights_fname) 101 | name = str(WEIGHTS_PATH) + '/' + 'PG_' + str(EE) + '_' + str(epoch) + '_' + str(round(loss,5)) + '_' + str(round(err,5)) + '_' + str(round(dice,5)) + '_model.pkl' 102 | print(name) 103 | torch.save(model, name) 104 | # torch.save({ 105 | # 'startEpoch': epoch, 106 | # 'loss':loss, 107 | # 'error': err, 108 | # 'acc': 1-err, 109 | # 'state_dict': model.state_dict() 110 | # }, weights_fpath) 111 | # shutil.copyfile(weights_fpath, WEIGHTS_PATH+'latest.th') 112 | 113 | def adjust_learning_rate(lr, decay, optimizer, cur_epoch, n_epochs): 114 | """Sets the learning rate to the initially 115 | configured `lr` decayed by `decay` every `n_epochs`""" 116 | # new_lr = lr * (decay ** (cur_epoch // n_epochs)) 117 | new = decay ** (cur_epoch // n_epochs) 118 | # print('currunt LR:', new_lr) 119 | for param_group in optimizer.param_groups: 120 | # param_group['lr'] = new_lr 121 | 
print('currunt LR:', param_group['lr'] * new) 122 | def view_sample_predictions(model, loader, FILE_test_imgs_original, pred_dir, EE_size): 123 | names = [] 124 | DICE = 0 125 | # print('1',FILE_test_imgs_original) 126 | for files in glob.glob(FILE_test_imgs_original + "/*.png"): 127 | # print('name:',files[-27:]) 128 | # print('name:',files.split('/')[-1]) 129 | # names.append(files[-27:]) 130 | names.append(files.split('/')[-1]) 131 | i = 0 132 | for idx, data in enumerate(loader): 133 | inputs = data[0] 134 | targets = data[1] 135 | inputs = torch.FloatTensor(F.interpolate(inputs, EE_size, mode='nearest')).cuda() 136 | targets = torch.FloatTensor(F.interpolate((targets.unsqueeze(1).float()), EE_size, mode='nearest')).squeeze(1).cuda().long() 137 | inputs = Variable(inputs.cuda()) 138 | # targets = Variable(true_masks.cuda())/80###for pre data 139 | targets = Variable(targets.cuda()) 140 | output = model(inputs) 141 | pred = get_predictions(output) 142 | # pred[pred = 1] = 0 143 | pred[pred > 1] = 2 144 | for i_ST in range(pred.shape[0]): 145 | # print('imgs.shape,true_masks.shape,pred.shape',inputs.shape,targets.shape,pred.shape) 146 | 147 | # print(pred[i_ST].unsqueeze(0).data.cpu().numpy().shape) 148 | # cv2.imwrite(os.path.join(pred_dir, '%r_input.png' % i),127*np.transpose(inputs[i_ST].data.cpu().numpy(),(1,2,0))) 149 | cv2.imwrite(os.path.join(pred_dir, '%r_prediction.png' % i),127*np.transpose(pred[i_ST].unsqueeze_(0).data.cpu().numpy(),(1,2,0))) 150 | # cv2.imwrite(os.path.join(pred_dir, '%r_gt.png' % i),127*np.transpose(targets[i_ST].unsqueeze_(0).data.cpu().numpy(),(1,2,0))) 151 | i = i + 1 152 | # cv2.imwrite(name,127*np.transpose(np.asarray(pred),(1,2,0))) 153 | # cv2.imwrite(name,np.transpose(np.asarray(pred),(1,2,0))) 154 | dice = DiceLoss(targets.data.cpu(), pred) 155 | # print('----------------------test-', targets.shape, output.shape, imgs.shape) 156 | # print('TEST-masks_pred, true_masks', pred.shape, output.shape, targets.shape) 157 | # 
print(name,dice) 158 | DICE = DICE + dice 159 | 160 | # print('TEST-masks_pred, true_masks', inputs.shape, output.shape, targets.shape) 161 | del inputs, targets, data, output, pred 162 | return DICE/len(loader) 163 | def calculateDice(pred_dir, true_dir): 164 | print('pred_dir, true_dir',pred_dir, true_dir) 165 | imgNumber = len(glob.glob(pred_dir + "*.png")) 166 | # print(imgNumber) 167 | dice = 0 168 | iou = 0 169 | for files in glob.glob(pred_dir + "*.png"): 170 | up = 0 171 | down = 0 172 | img_pred = cv2.imread(files,0)/80 173 | # print('pred:',files) 174 | file = files[-27:] 175 | # print('true file:', true_dir+file) 176 | img_true = cv2.imread(true_dir+file,0)/80 177 | # print(np.max(img_true),np.max(img_pred)) 178 | height = img_true.shape[0] 179 | width = img_true.shape[1] 180 | new_masks_look = np.empty((height,width,3)) 181 | for j in range(height): 182 | for k in range(width): 183 | if img_pred[j][k]>1 and img_true[j][k]>1: 184 | up = up +1 185 | else: 186 | pass 187 | if img_pred[j][k]>1: 188 | down = down +1 189 | else: 190 | pass 191 | if img_true[j][k]>1: 192 | down = down +1 193 | else: 194 | pass 195 | dice_value = 2*up/down 196 | # print(file,dice_value) 197 | dice = dice + dice_value 198 | Jaccard_value = up/(down-up) 199 | iou=iou+Jaccard_value 200 | dice = dice/imgNumber 201 | # print('dice ave:',dice/imgNumber)#,Jaccard_value) 202 | # print('iou ave:',iou/imgNumber)#,Jaccard_value) 203 | return dice 204 | 205 | def DiceLoss(pred, target): 206 | N = target.shape[0] 207 | # pred = pred.view(N, -1) 208 | # target = target.view(N, -1) 209 | pred, target = pred.data.cpu().numpy(), target.data.cpu().numpy() 210 | # pred = pred.squeeze(dim=1) 211 | 212 | pred[pred <= 1] = 0 213 | pred[pred > 1] = 1 214 | 215 | target[target <= 1] = 0 216 | target[target > 1] = 1 217 | # print('pred, masks', pred.shape, target.shape) 218 | dice = 2 * (pred * target).sum(1).sum(1) / (pred.sum(1).sum(1) + target.sum(1).sum(1) + 1e-5) 219 | dice = dice.sum() / N 220 | 
# 返回的是dice距离 221 | return dice 222 | 223 | class DiceLoss_23(nn.Module): 224 | def __init__(self): 225 | super().__init__() 226 | 227 | def forward(self, pred, target, EE): 228 | size = 4 * 2 ** (EE - 1) 229 | classnum = 4 230 | # 首先将金标准拆开 231 | target_copy = torch.zeros((target.size(0), classnum, size, size)) 232 | print('1',target_copy.shape) 233 | # print('1',target.size(0)) 234 | for index in range(1, classnum + 1): 235 | temp_target = torch.zeros(target.size()) 236 | # print('1',temp_target.shape) 237 | temp_target[target == index] = 1 238 | target_copy[:, index - 1, :, :, :] = temp_target 239 | # target_copy: (B, 2, 20, 256, 256) 240 | 241 | target_copy = target_copy.cuda() 242 | dice = 0.0 243 | 244 | dice_L = 2 * (pred[:, 1, :, :, :] * target_copy[:, 1 - 1, :, :, :]).sum(dim=1).sum(dim=1).sum(dim=1) / (pred[:, 1, :, :, :].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + target_copy[:, 1 - 1, :, :, :].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + 1e-10) 245 | dice_T = (2 * (pred[:, 2, :, :, :] * target_copy[:, 2 - 1, :, :, :]).sum(dim=1).sum(dim=1).sum(dim=1) + 1e-10) / (pred[:, 2, :, :, :].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + target_copy[:, 2 - 1, :, :, :].pow(2).sum(dim=1).sum(dim=1).sum(dim=1) + 1e-10) 246 | dice = 0.2 * dice_L + 1.8 * dice_T ###256 more attention to Tumor 247 | dice /= classnum 248 | 249 | # 返回的是dice距离 250 | # return (1 - dice).mean(),int(dice_L.mean().data.cpu().numpy()* 1000)/1000.0,int(dice_T.mean().data.cpu().numpy() * 1000)/1000.0 251 | return (1 - dice).mean(),dice_L.mean().data.cpu().numpy(),dice_T.mean().data.cpu().numpy() 252 | def resize(images, EE): 253 | x = 7 254 | if EE >= x: 255 | EE = x 256 | return F.adaptive_avg_pool2d(images, 8 * 2 ** (EE - 1)) 257 | # return images 258 | def weights_init(m): 259 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d: 260 | nn.init.xavier_normal_(m.weight.data) 261 | if m.bias is not None: 262 | nn.init.constant_(m.bias, 0) 263 | elif type(m) == nn.BatchNorm2d: 264 | 
nn.init.normal_(m.weight, 1.0, 0.02) 265 | nn.init.constant_(m.bias, 0) 266 | 267 | def weight_transform(pretrained_dict,pretrained_dict0, EE): 268 | if EE == 7: 269 | print('7') 270 | pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 271 | pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 272 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 273 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 274 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.1.weight'] 275 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.1.bias'] 276 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.weight'] 277 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.bias'] 278 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.weight'] 279 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.bias'] 280 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.1.weight'] 281 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.1.bias'] 282 | 
pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight'] 283 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias'] 284 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.weight'] 285 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.bias'] 286 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.5.weight'] 287 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.5.bias'] 288 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.5.weight'] 289 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.5.bias'] 290 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.weight'] 291 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.bias'] 292 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.5.weight'] 293 
| pretrained_dict['module.model.model.1.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.5.bias'] 294 | pretrained_dict['module.model.model.1.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.5.weight'] 295 | pretrained_dict['module.model.model.1.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.5.bias'] 296 | pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 297 | pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 298 | if EE == 6: 299 | print('6') 300 | pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 301 | pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 302 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 303 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 304 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.1.weight'] 305 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.1.bias'] 306 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.weight'] 307 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.bias'] 308 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.weight'] 309 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.bias'] = 
pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.bias'] 310 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.1.weight'] 311 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.1.bias'] 312 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.weight'] 313 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.bias'] 314 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.5.weight'] 315 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.5.bias'] 316 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.weight'] 317 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.bias'] 318 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.5.weight'] 319 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.5.bias'] 320 | pretrained_dict['module.model.model.1.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.5.weight'] 321 | pretrained_dict['module.model.model.1.model.3.model.5.bias'] = 
pretrained_dict0['module.model.model.1.model.5.bias'] 322 | pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 323 | pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 324 | 325 | if EE == 5: 326 | print('5') 327 | pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 328 | pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 329 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 330 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 331 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.1.weight'] 332 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.1.bias'] 333 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.weight'] 334 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.bias'] 335 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.weight'] 336 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.bias'] 337 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.weight'] 338 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.bias'] = 
pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.bias'] 339 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.weight'] 340 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.bias'] 341 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.5.weight'] 342 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.5.bias'] 343 | pretrained_dict['module.model.model.1.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.5.weight'] 344 | pretrained_dict['module.model.model.1.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.5.bias'] 345 | pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 346 | pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 347 | if EE == 4: 348 | print('4') 349 | pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 350 | pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 351 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 352 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 353 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.1.weight'] 354 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.1.bias'] 355 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.weight'] = 
pretrained_dict0['module.model.model.1.model.3.model.3.model.1.weight'] 356 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.bias'] 357 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.weight'] 358 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.bias'] 359 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.5.weight'] 360 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.5.bias'] 361 | pretrained_dict['module.model.model.1.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.5.weight'] 362 | pretrained_dict['module.model.model.1.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.5.bias'] 363 | pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 364 | pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 365 | if EE == 3: 366 | print('3') 367 | pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 368 | pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 369 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 370 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 371 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.1.weight'] 372 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.bias'] = 
pretrained_dict0['module.model.model.1.model.3.model.1.bias'] 373 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.weight'] 374 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.bias'] 375 | pretrained_dict['module.model.model.1.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.5.weight'] 376 | pretrained_dict['module.model.model.1.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.5.bias'] 377 | pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 378 | pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 379 | if EE == 2: 380 | print('E = 2') 381 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 382 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 383 | pretrained_dict['module.model.model.1.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.weight'] 384 | pretrained_dict['module.model.model.1.model.3.model.3.bias'] = pretrained_dict0['module.model.model.1.model.3.bias'] 385 | 386 | pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 387 | pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 388 | pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 389 | pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 390 | return pretrained_dict 391 | 392 | def weight_transform_new(pretrained_dict,pretrained_dict0, EE): 393 | if EE == 7: 394 | print('7') 395 | # pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 396 | # 
pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 397 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 398 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 399 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.1.weight'] 400 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.1.bias'] 401 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.weight'] 402 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.bias'] 403 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.weight'] 404 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.bias'] 405 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.1.weight'] 406 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.1.bias'] 407 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight'] 408 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.bias'] = 
pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias'] 409 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.weight'] 410 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.bias'] 411 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.5.weight'] 412 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.5.bias'] 413 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.5.weight'] 414 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.5.bias'] 415 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.weight'] 416 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.bias'] 417 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.5.weight'] 418 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.5.bias'] 419 | pretrained_dict['module.model.model.1.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.5.weight'] 420 | 
pretrained_dict['module.model.model.1.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.5.bias'] 421 | # pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 422 | # pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 423 | if EE == 6: 424 | print('6') 425 | # pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 426 | # pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 427 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 428 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 429 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.1.weight'] 430 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.1.bias'] 431 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.weight'] 432 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.bias'] 433 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.weight'] 434 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.bias'] 435 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.1.weight'] 436 | 
pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.1.bias'] 437 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.weight'] 438 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.model.3.bias'] 439 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.5.weight'] 440 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.5.bias'] 441 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.weight'] 442 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.bias'] 443 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.5.weight'] 444 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.5.bias'] 445 | pretrained_dict['module.model.model.1.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.5.weight'] 446 | pretrained_dict['module.model.model.1.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.5.bias'] 447 | # pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 448 | # pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 449 | 450 | if EE == 5: 451 | print('5') 
452 | # pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 453 | # pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 454 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 455 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 456 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.1.weight'] 457 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.1.bias'] 458 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.weight'] 459 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.bias'] 460 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.weight'] 461 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.1.bias'] 462 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.weight'] 463 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.model.3.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.model.3.bias'] 464 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.5.weight'] 465 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.5.bias'] = 
pretrained_dict0['module.model.model.1.model.3.model.3.model.5.bias'] 466 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.5.weight'] 467 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.5.bias'] 468 | pretrained_dict['module.model.model.1.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.5.weight'] 469 | pretrained_dict['module.model.model.1.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.5.bias'] 470 | # pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 471 | # pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 472 | if EE == 4: 473 | print('4') 474 | # pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 475 | # pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 476 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 477 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 478 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.1.weight'] 479 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.1.bias'] 480 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.weight'] 481 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.1.bias'] 482 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.weight'] = 
pretrained_dict0['module.model.model.1.model.3.model.3.model.3.weight'] 483 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.model.3.bias'] = pretrained_dict0['module.model.model.1.model.3.model.3.model.3.bias'] 484 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.3.model.5.weight'] 485 | pretrained_dict['module.model.model.1.model.3.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.3.model.5.bias'] 486 | pretrained_dict['module.model.model.1.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.5.weight'] 487 | pretrained_dict['module.model.model.1.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.5.bias'] 488 | # pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 489 | # pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 490 | if EE == 3: 491 | print('3') 492 | # pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 493 | # pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 494 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 495 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 496 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.3.model.1.weight'] 497 | pretrained_dict['module.model.model.1.model.3.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.3.model.1.bias'] 498 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.model.3.weight'] 499 | pretrained_dict['module.model.model.1.model.3.model.3.model.3.bias'] = 
pretrained_dict0['module.model.model.1.model.3.model.3.bias'] 500 | pretrained_dict['module.model.model.1.model.3.model.5.weight'] = pretrained_dict0['module.model.model.1.model.5.weight'] 501 | pretrained_dict['module.model.model.1.model.3.model.5.bias'] = pretrained_dict0['module.model.model.1.model.5.bias'] 502 | # pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 503 | # pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 504 | if EE == 2: 505 | print('E = 2') 506 | pretrained_dict['module.model.model.1.model.3.model.1.weight'] = pretrained_dict0['module.model.model.1.model.1.weight'] 507 | pretrained_dict['module.model.model.1.model.3.model.1.bias'] = pretrained_dict0['module.model.model.1.model.1.bias'] 508 | pretrained_dict['module.model.model.1.model.3.model.3.weight'] = pretrained_dict0['module.model.model.1.model.3.weight'] 509 | pretrained_dict['module.model.model.1.model.3.model.3.bias'] = pretrained_dict0['module.model.model.1.model.3.bias'] 510 | 511 | # pretrained_dict['module.model.model.0.weight'] = pretrained_dict0['module.model.model.0.weight'] 512 | # pretrained_dict['module.model.model.0.bias'] = pretrained_dict0['module.model.model.0.bias'] 513 | # pretrained_dict['module.model.model.3.weight'] = pretrained_dict0['module.model.model.3.weight'] 514 | # pretrained_dict['module.model.model.3.bias'] = pretrained_dict0['module.model.model.3.bias'] 515 | return pretrained_dict 516 | 517 | def weight_transform_MC(pretrained_dict,pretrained_dict0, EE): 518 | if EE == 7: 519 | print('7') 520 | # pretrained_dict[''] = pretrained_dict0[''] 521 | pretrained_dict['module.model.model.3.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.1.weight'] 522 | pretrained_dict['module.model.model.3.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.1.bias'] 523 | pretrained_dict['module.model.model.3.model.4.model.3.weight'] = 
pretrained_dict0['module.model.model.3.model.3.weight'] 524 | pretrained_dict['module.model.model.3.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.3.bias'] 525 | pretrained_dict['module.model.model.3.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.1.weight'] 526 | pretrained_dict['module.model.model.3.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.1.bias'] 527 | pretrained_dict['module.model.model.3.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.3.weight'] 528 | pretrained_dict['module.model.model.3.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.3.bias'] 529 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.1.weight'] 530 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.1.bias'] 531 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.3.weight'] 532 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.3.bias'] 533 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.1.weight'] 534 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.1.bias'] 535 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.3.weight'] 536 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.3.bias'] = 
pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.3.bias'] 537 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.1.weight'] 538 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.1.bias'] 539 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.3.weight'] 540 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.3.bias'] 541 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.1.weight'] 542 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.1.bias'] 543 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.3.weight'] 544 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.3.bias'] 545 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.4.model.5.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.5.weight'] 546 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.4.model.5.bias'] = 
pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.5.bias'] 547 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.4.model.7.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.7.weight'] 548 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.4.model.7.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.7.bias'] 549 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.6.weight'] 550 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.6.bias'] 551 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.8.weight'] 552 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.8.bias'] 553 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.6.weight'] 554 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.6.bias'] 555 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.8.weight'] 556 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.8.bias'] 557 | 
pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.6.weight'] 558 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.6.bias'] 559 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.8.weight'] 560 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.8.bias'] 561 | pretrained_dict['module.model.model.3.model.4.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.4.model.6.weight'] 562 | pretrained_dict['module.model.model.3.model.4.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.4.model.6.bias'] 563 | pretrained_dict['module.model.model.3.model.4.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.4.model.8.weight'] 564 | pretrained_dict['module.model.model.3.model.4.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.4.model.8.bias'] 565 | pretrained_dict['module.model.model.3.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.6.weight'] 566 | pretrained_dict['module.model.model.3.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.6.bias'] 567 | pretrained_dict['module.model.model.3.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.8.weight'] 568 | pretrained_dict['module.model.model.3.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.8.bias'] 569 | 570 | 571 | if EE == 6: 572 | print('6') 573 | # pretrained_dict[''] = pretrained_dict0[''] 574 | pretrained_dict['module.model.model.3.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.1.weight'] 575 | pretrained_dict['module.model.model.3.model.4.model.1.bias'] = 
pretrained_dict0['module.model.model.3.model.1.bias'] 576 | pretrained_dict['module.model.model.3.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.3.weight'] 577 | pretrained_dict['module.model.model.3.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.3.bias'] 578 | pretrained_dict['module.model.model.3.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.1.weight'] 579 | pretrained_dict['module.model.model.3.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.1.bias'] 580 | pretrained_dict['module.model.model.3.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.3.weight'] 581 | pretrained_dict['module.model.model.3.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.3.bias'] 582 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.1.weight'] 583 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.1.bias'] 584 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.3.weight'] 585 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.3.bias'] 586 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.1.weight'] 587 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.1.bias'] 588 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.3.weight'] = 
pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.3.weight'] 589 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.3.bias'] 590 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.1.weight'] 591 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.1.bias'] 592 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.3.weight'] 593 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.3.bias'] 594 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.5.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.5.weight'] 595 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.5.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.5.bias'] 596 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.7.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.7.weight'] 597 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.4.model.7.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.4.model.7.bias'] 598 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.6.weight'] 599 | 
pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.6.bias'] 600 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.8.weight'] 601 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.8.bias'] 602 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.6.weight'] 603 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.6.bias'] 604 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.8.weight'] 605 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.8.bias'] 606 | pretrained_dict['module.model.model.3.model.4.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.4.model.6.weight'] 607 | pretrained_dict['module.model.model.3.model.4.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.4.model.6.bias'] 608 | pretrained_dict['module.model.model.3.model.4.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.4.model.8.weight'] 609 | pretrained_dict['module.model.model.3.model.4.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.4.model.8.bias'] 610 | pretrained_dict['module.model.model.3.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.6.weight'] 611 | pretrained_dict['module.model.model.3.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.6.bias'] 612 | 
pretrained_dict['module.model.model.3.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.8.weight'] 613 | pretrained_dict['module.model.model.3.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.8.bias'] 614 | 615 | if EE == 5: 616 | print('5') 617 | 618 | pretrained_dict['module.model.model.3.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.1.weight'] 619 | pretrained_dict['module.model.model.3.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.1.bias'] 620 | pretrained_dict['module.model.model.3.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.3.weight'] 621 | pretrained_dict['module.model.model.3.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.3.bias'] 622 | pretrained_dict['module.model.model.3.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.1.weight'] 623 | pretrained_dict['module.model.model.3.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.1.bias'] 624 | pretrained_dict['module.model.model.3.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.3.weight'] 625 | pretrained_dict['module.model.model.3.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.3.bias'] 626 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.1.weight'] 627 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.1.bias'] 628 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.3.weight'] 629 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.3.bias'] 630 | 
pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.1.weight'] 631 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.1.bias'] 632 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.3.weight'] 633 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.3.bias'] 634 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.5.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.5.weight'] 635 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.5.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.5.bias'] 636 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.7.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.7.weight'] 637 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.4.model.7.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.4.model.7.bias'] 638 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.6.weight'] 639 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.6.bias'] 640 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.8.weight'] 641 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.8.bias'] = 
pretrained_dict0['module.model.model.3.model.4.model.4.model.8.bias'] 642 | pretrained_dict['module.model.model.3.model.4.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.4.model.6.weight'] 643 | pretrained_dict['module.model.model.3.model.4.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.4.model.6.bias'] 644 | pretrained_dict['module.model.model.3.model.4.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.4.model.8.weight'] 645 | pretrained_dict['module.model.model.3.model.4.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.4.model.8.bias'] 646 | pretrained_dict['module.model.model.3.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.6.weight'] 647 | pretrained_dict['module.model.model.3.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.6.bias'] 648 | pretrained_dict['module.model.model.3.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.8.weight'] 649 | pretrained_dict['module.model.model.3.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.8.bias'] 650 | 651 | if EE == 4: 652 | print('4') 653 | pretrained_dict['module.model.model.3.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.1.weight'] 654 | pretrained_dict['module.model.model.3.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.1.bias'] 655 | pretrained_dict['module.model.model.3.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.3.weight'] 656 | pretrained_dict['module.model.model.3.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.3.bias'] 657 | pretrained_dict['module.model.model.3.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.1.weight'] 658 | pretrained_dict['module.model.model.3.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.1.bias'] 659 | 
pretrained_dict['module.model.model.3.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.3.weight'] 660 | pretrained_dict['module.model.model.3.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.3.bias'] 661 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.1.weight'] 662 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.1.bias'] 663 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.3.weight'] 664 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.3.bias'] 665 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.5.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.5.weight'] 666 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.5.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.5.bias'] 667 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.7.weight'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.7.weight'] 668 | pretrained_dict['module.model.model.3.model.4.model.4.model.4.model.7.bias'] = pretrained_dict0['module.model.model.3.model.4.model.4.model.7.bias'] 669 | pretrained_dict['module.model.model.3.model.4.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.4.model.6.weight'] 670 | pretrained_dict['module.model.model.3.model.4.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.4.model.6.bias'] 671 | pretrained_dict['module.model.model.3.model.4.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.4.model.8.weight'] 672 | 
pretrained_dict['module.model.model.3.model.4.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.4.model.8.bias'] 673 | pretrained_dict['module.model.model.3.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.6.weight'] 674 | pretrained_dict['module.model.model.3.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.6.bias'] 675 | pretrained_dict['module.model.model.3.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.8.weight'] 676 | pretrained_dict['module.model.model.3.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.8.bias'] 677 | 678 | if EE == 3: 679 | print('3') 680 | 681 | pretrained_dict['module.model.model.3.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.1.weight'] 682 | pretrained_dict['module.model.model.3.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.1.bias'] 683 | pretrained_dict['module.model.model.3.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.3.weight'] 684 | pretrained_dict['module.model.model.3.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.3.bias'] 685 | pretrained_dict['module.model.model.3.model.4.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.4.model.1.weight'] 686 | pretrained_dict['module.model.model.3.model.4.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.4.model.1.bias'] 687 | pretrained_dict['module.model.model.3.model.4.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.4.model.3.weight'] 688 | pretrained_dict['module.model.model.3.model.4.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.4.model.3.bias'] 689 | pretrained_dict['module.model.model.3.model.4.model.4.model.5.weight'] = pretrained_dict0['module.model.model.3.model.4.model.5.weight'] 690 | pretrained_dict['module.model.model.3.model.4.model.4.model.5.bias'] = 
pretrained_dict0['module.model.model.3.model.4.model.5.bias'] 691 | pretrained_dict['module.model.model.3.model.4.model.4.model.7.weight'] = pretrained_dict0['module.model.model.3.model.4.model.7.weight'] 692 | pretrained_dict['module.model.model.3.model.4.model.4.model.7.bias'] = pretrained_dict0['module.model.model.3.model.4.model.7.bias'] 693 | pretrained_dict['module.model.model.3.model.4.model.6.weight'] = pretrained_dict0['module.model.model.3.model.6.weight'] 694 | pretrained_dict['module.model.model.3.model.4.model.6.bias'] = pretrained_dict0['module.model.model.3.model.6.bias'] 695 | pretrained_dict['module.model.model.3.model.4.model.8.weight'] = pretrained_dict0['module.model.model.3.model.8.weight'] 696 | pretrained_dict['module.model.model.3.model.4.model.8.bias'] = pretrained_dict0['module.model.model.3.model.8.bias'] 697 | if EE == 2: 698 | print('E = 2') 699 | pretrained_dict['module.model.model.3.model.4.model.1.weight'] = pretrained_dict0['module.model.model.3.model.1.weight'] 700 | pretrained_dict['module.model.model.3.model.4.model.1.bias'] = pretrained_dict0['module.model.model.3.model.1.bias'] 701 | pretrained_dict['module.model.model.3.model.4.model.3.weight'] = pretrained_dict0['module.model.model.3.model.3.weight'] 702 | pretrained_dict['module.model.model.3.model.4.model.3.bias'] = pretrained_dict0['module.model.model.3.model.3.bias'] 703 | pretrained_dict['module.model.model.3.model.4.model.5.weight'] = pretrained_dict0['module.model.model.3.model.5.weight'] 704 | pretrained_dict['module.model.model.3.model.4.model.5.bias'] = pretrained_dict0['module.model.model.3.model.5.bias'] 705 | pretrained_dict['module.model.model.3.model.4.model.7.weight'] = pretrained_dict0['module.model.model.3.model.7.weight'] 706 | pretrained_dict['module.model.model.3.model.4.model.7.bias'] = pretrained_dict0['module.model.model.3.model.7.bias'] 707 | 708 | return pretrained_dict 709 | 710 | def weight_transform_BN_MP(pretrained_dict,pretrained_dict0, EE): 711 
| Keys0 = [] 712 | Values0 = [] 713 | for k, v in pretrained_dict0.items(): 714 | Keys0.append(k) 715 | Values0.append(v) 716 | Keys1 = [] 717 | Values1 = [] 718 | for k, v in pretrained_dict.items(): 719 | Keys1.append(k) 720 | Values1.append(v) 721 | # print('1', len(Keys0),len(Keys0)) 722 | if EE == 2: 723 | print('2') 724 | for i in range(14,42): 725 | pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 726 | if EE == 3: 727 | print('3') 728 | for i in range(14,70): 729 | pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 730 | if EE == 4: 731 | print('4') 732 | for i in range(14,98): 733 | pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 734 | if EE == 5: 735 | print('5') 736 | for i in range(14,126): 737 | pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 738 | if EE == 6: 739 | print('6') 740 | for i in range(14,154): 741 | pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 742 | if EE == 7: 743 | print('7') 744 | for i in range(14,182): 745 | pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 746 | # return pretrained_dict 747 | return pretrained_dict 748 | 749 | def weight_transform_1024(pretrained_dict,pretrained_dict0, EE): 750 | Keys0 = [] 751 | Values0 = [] 752 | for k, v in pretrained_dict0.items(): 753 | Keys0.append(k) 754 | Values0.append(v) 755 | Keys1 = [] 756 | Values1 = [] 757 | for k, v in pretrained_dict.items(): 758 | Keys1.append(k) 759 | Values1.append(v) 760 | # print('1', len(Keys0),len(Keys0)) 761 | if EE == 2: 762 | print('2') 763 | for i in range(14,42): 764 | pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 765 | if EE == 3: 766 | print('3') 767 | for i in range(14,70): 768 | pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 769 | if EE == 4: 770 | print('4') 771 | for i in range(14,98): 772 | pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 773 | # if EE == 
5: 774 | # print('5') 775 | # for i in range(14,126): 776 | # pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 777 | # if EE == 6: 778 | # print('6') 779 | # for i in range(14,154): 780 | # pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 781 | # if EE == 7: 782 | # print('7') 783 | # for i in range(14,182): 784 | # pretrained_dict[str(Keys1[i + 14])] = pretrained_dict0[str(Keys0[i])] 785 | # return pretrained_dict 786 | return pretrained_dict -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from optparse import OptionParser 4 | import numpy as np 5 | from pathlib import Path 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | from torch import optim 10 | from datasets import data_PG 11 | from datasets import joint_transforms 12 | import torchvision.transforms as transforms 13 | from eval_PG import eval_net, train_net, save_weights, adjust_learning_rate, view_sample_predictions, calculateDice, save_weights_dice, weight_transform_1024 14 | # from unet import UNet 15 | import time 16 | from math import ceil 17 | # from utils import get_ids, split_ids, split_train_val, get_imgs_and_masks, batch 18 | import utils.imgs 19 | import utils.training as train_utils 20 | from torch.autograd import Variable 21 | # from unet import unet_model_PG 22 | from unet.unet_model import UNet1, UNet2, UNet3, UNet4 23 | ################################################data 24 | 25 | data_PATH = Path('/home/zhaojie/zhaojie/PG/Pdata/') 26 | batch_size = 10 27 | normalize = transforms.Normalize(mean=data_PG.mean, std=data_PG.std) 28 | 29 | train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()]) 30 | # train_joint_transform = None 31 | train_dset = data_PG.CamVid(data_PATH, 'train', joint_transform=train_joint_transformer, 
transform=transforms.Compose([transforms.ToTensor(), normalize])) 32 | train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, shuffle=True) 33 | 34 | val_dset = data_PG.CamVid(data_PATH, 'val', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 35 | val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, shuffle=False) 36 | 37 | test_dset = data_PG.CamVid(data_PATH, 'test', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 38 | test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, shuffle=False) 39 | 40 | # print("Train: %d" %len(train_loader.dataset.imgs)) 41 | # print("Val: %d" %len(val_loader.dataset.imgs)) 42 | print("Test: %d" %len(test_loader.dataset.imgs)) 43 | print("Classes: %d" % len(train_loader.dataset.classes)) 44 | 45 | # inputs, targets = next(iter(train_loader)) 46 | # print("Inputs: ", inputs.size()) 47 | # print("Targets: ", targets.size()) 48 | # utils.imgs.view_image(inputs[0]) 49 | # utils.imgs.view_annotated(targets[0]) 50 | EE = 4 51 | EE_size = 256 52 | LR = 1e-3 53 | net = UNet4(n_channels=3, n_classes=4).cuda() 54 | device = 'cuda' 55 | module_dir = '' 56 | net0 = torch.load(module_dir).cuda() 57 | 58 | pretrained_dict0 = net0.state_dict() 59 | pretrained_dict1 = net.state_dict() 60 | Keys0 = [] 61 | Values0 = [] 62 | for k, v in pretrained_dict0.items(): 63 | Keys0.append(k) 64 | Values0.append(v) 65 | Keys1 = [] 66 | Values1 = [] 67 | for k, v in pretrained_dict1.items(): 68 | Keys1.append(k) 69 | Values1.append(v) 70 | print('1', len(Keys0),len(Keys1)) 71 | 72 | for i in range(0,134): 73 | pretrained_dict1[str(Keys1[i])] = pretrained_dict0[str(Keys0[i])] 74 | net.load_state_dict(pretrained_dict1) 75 | net = net.to(device) 76 | model = torch.nn.DataParallel(net, device_ids=[0]).cuda() 77 | 78 | optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4) 79 | # print('EE, model', EE, 
model) 80 | pred_dir = './train_PG_pred/' 81 | 82 | FILE_test_imgs_original = '/home/zhaojie/zhaojie/PG/Pdata/test' 83 | if __name__ == '__main__': 84 | 85 | epoch_num = 2 86 | best_loss = 1. 87 | best_dice = 0. 88 | LR_DECAY = 0.95 89 | DECAY_EVERY_N_EPOCHS = 10 90 | criterion = nn.NLLLoss(weight=data_PG.class_weight.cuda()).cuda() 91 | 92 | for epoch in range(1, epoch_num): 93 | model = model.cuda() 94 | ### Checkpoint ### 95 | DICE1 = view_sample_predictions(model, test_loader, FILE_test_imgs_original, pred_dir, EE_size) 96 | print('-----------test_dice',DICE1) 97 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from optparse import OptionParser 4 | import numpy as np 5 | from pathlib import Path 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | from torch import optim 10 | from datasets import data_PG 11 | from datasets import joint_transforms 12 | import torchvision.transforms as transforms 13 | from eval_PG import eval_net, train_net, save_weights, adjust_learning_rate, view_sample_predictions, calculateDice, save_weights_dice, weight_transform_1024 14 | import time 15 | from math import ceil 16 | import utils.imgs 17 | import utils.training as train_utils 18 | from torch.autograd import Variable 19 | from unet.unet_model import UNet4 20 | 21 | #################################################data 22 | PAP_PATH = Path('/home/zhaojie/zhaojie/PG/Pdata/') 23 | WEIGHTS_PATH = Path('./weights/train_1024/') 24 | 25 | WEIGHTS_PATH.mkdir(exist_ok=True) 26 | batch_size = 60 27 | normalize = transforms.Normalize(mean=data_PG.mean, std=data_PG.std) 28 | 29 | train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()]) 30 | train_dset = data_PG.CamVid(PAP_PATH, 'train', joint_transform=train_joint_transformer, 
transform=transforms.Compose([transforms.ToTensor(), normalize])) 31 | train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, shuffle=True) 32 | 33 | val_dset = data_PG.CamVid(PAP_PATH, 'val', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 34 | val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, shuffle=False) 35 | 36 | test_dset = data_PG.CamVid(PAP_PATH, 'test', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 37 | test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, shuffle=False) 38 | 39 | print("Train: %d" %len(train_loader.dataset.imgs)) 40 | print("Val: %d" %len(val_loader.dataset.imgs)) 41 | print("Test: %d" %len(test_loader.dataset.imgs)) 42 | print("Classes: %d" % len(train_loader.dataset.classes)) 43 | 44 | inputs, targets = next(iter(train_loader)) 45 | print("Inputs: ", inputs.size()) 46 | print("Targets: ", targets.size()) 47 | # utils.imgs.view_image(inputs[0]) 48 | # utils.imgs.view_annotated(targets[0]) 49 | EE = 4 50 | device = 'cuda' 51 | EE_size = 256 52 | LR = 1e-3 53 | model = UNet4(n_channels=3, n_classes=4).cuda() 54 | model = model.to(device) 55 | model = torch.nn.DataParallel(model).cuda() 56 | optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4) 57 | # print('EE, model', EE, model) 58 | pred_dir = './train_PG_pred/' 59 | FILE_test_imgs_original = '/home/zhaojie/zhaojie/PG/Pdata/test' 60 | ######################## 61 | params = list(model.parameters()) 62 | k = 0 63 | for i in params: 64 | l = 1 65 | # print("该层的结构:" + str(list(i.size()))) 66 | for j in i.size(): 67 | l *= j 68 | # print("该层参数和:" + str(l)) 69 | k = k + l 70 | print("总参数数量和:" + str(k)) 71 | ############################# 72 | if __name__ == '__main__': 73 | EE0 = 1 #EE =1-8#2-16#3-32#4-64#5-128#6-256#7-512#8-1024# 74 | 75 | epoch_num = 10000 76 | best_loss = 1. 77 | best_dice = 0. 
78 | LR_DECAY = 0.95 79 | DECAY_EVERY_N_EPOCHS = 10 80 | criterion = nn.NLLLoss(weight=data_PG.class_weight.cuda()).cuda() 81 | 82 | for epoch in range(1, epoch_num): 83 | 84 | model = model.cuda() 85 | 86 | ################################################## 87 | ### Train ### 88 | trn_loss, trn_err, train_DICE = train_net(model, train_loader, criterion, optimizer, EE_size) 89 | print('Epoch {:d}\nTrain - Loss: {:.4f}, Acc: {:.4f}, Dice: {:.4f}'.format(epoch, trn_loss, 1-trn_err, train_DICE)) 90 | ## Test ### 91 | val_loss, val_err, val_DICE = eval_net(model, val_loader, criterion, EE_size) 92 | print('Val - Loss: {:.4f} | Acc: {:.4f}, Dice: {:.4f}'.format(val_loss, 1-val_err, val_DICE)) 93 | ### Checkpoint ### 94 | DICE1 = view_sample_predictions(model, test_loader, FILE_test_imgs_original, pred_dir, EE_size) 95 | print('-----------test_dice',DICE1) 96 | if best_dice < DICE1: 97 | # save_weights_dice(WEIGHTS_PATH, model, epoch, train_DICE, val_DICE, DICE1, EE) 98 | best_dice = DICE1 99 | ### Adjust Lr ### 100 | adjust_learning_rate(LR, LR_DECAY, optimizer, epoch, DECAY_EVERY_N_EPOCHS) 101 | -------------------------------------------------------------------------------- /train_1024_DataParallel.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from optparse import OptionParser 4 | import numpy as np 5 | from pathlib import Path 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | from torch import optim 10 | from datasets import data_PG 11 | from datasets import joint_transforms 12 | import torchvision.transforms as transforms 13 | from eval_PG import eval_net, train_net, save_weights, adjust_learning_rate, view_sample_predictions, calculateDice, save_weights_dice, weight_transform_1024 14 | import time 15 | import datetime 16 | from math import ceil 17 | import utils.imgs 18 | import utils.training as train_utils 19 | from torch.autograd import Variable 20 | 
from unet.unet_model import UNet4 21 | 22 | #################################################data 23 | PAP_PATH = Path('/home/zhaojie/zhaojie/PG/Pdata/') 24 | WEIGHTS_PATH = Path('./weights/train_1024/') 25 | 26 | WEIGHTS_PATH.mkdir(exist_ok=True) 27 | batch_size = 60 28 | normalize = transforms.Normalize(mean=data_PG.mean, std=data_PG.std) 29 | 30 | train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()]) 31 | train_dset = data_PG.CamVid(PAP_PATH, 'train', joint_transform=train_joint_transformer, transform=transforms.Compose([transforms.ToTensor(), normalize])) 32 | train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, shuffle=True) 33 | 34 | val_dset = data_PG.CamVid(PAP_PATH, 'val', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 35 | val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, shuffle=False) 36 | 37 | test_dset = data_PG.CamVid(PAP_PATH, 'test', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 38 | test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, shuffle=False) 39 | 40 | print("Train: %d" %len(train_loader.dataset.imgs)) 41 | print("Val: %d" %len(val_loader.dataset.imgs)) 42 | print("Test: %d" %len(test_loader.dataset.imgs)) 43 | print("Classes: %d" % len(train_loader.dataset.classes)) 44 | 45 | inputs, targets = next(iter(train_loader)) 46 | print("Inputs: ", inputs.size()) 47 | print("Targets: ", targets.size()) 48 | # utils.imgs.view_image(inputs[0]) 49 | # utils.imgs.view_annotated(targets[0]) 50 | EE = 4 51 | device = 'cuda' 52 | EE_size = 256 53 | LR = 1e-3 54 | model = UNet4(n_channels=3, n_classes=4).cuda() 55 | model = model.to(device) 56 | model = torch.nn.DataParallel(model).cuda() 57 | optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4) 58 | # print('EE, model', EE, model) 59 | pred_dir = './train_PG_pred/' 60 | 
FILE_test_imgs_original = '/home/zhaojie/zhaojie/PG/Pdata/test' 61 | ######################## 62 | params = list(model.parameters()) 63 | k = 0 64 | for i in params: 65 | l = 1 66 | # print("该层的结构:" + str(list(i.size()))) 67 | for j in i.size(): 68 | l *= j 69 | # print("该层参数和:" + str(l)) 70 | k = k + l 71 | print("总参数数量和:" + str(k)) 72 | ############################# 73 | if __name__ == '__main__': 74 | EE0 = 1 #EE =1-8#2-16#3-32#4-64#5-128#6-256#7-512#8-1024# 75 | 76 | epoch_num = 10000 77 | best_loss = 1. 78 | best_dice = 0. 79 | LR_DECAY = 0.95 80 | DECAY_EVERY_N_EPOCHS = 10 81 | criterion = nn.NLLLoss(weight=data_PG.class_weight.cuda()).cuda() 82 | 83 | for epoch in range(1, epoch_num): 84 | 85 | model = model.cuda() 86 | start_time = datetime.datetime.now() 87 | print('start_time',start_time) 88 | ################################################## 89 | ### Train ### 90 | trn_loss, trn_err, train_DICE = train_net(model, train_loader, criterion, optimizer, EE_size) 91 | print('Epoch {:d}\nTrain - Loss: {:.4f}, Acc: {:.4f}, Dice: {:.4f}'.format(epoch, trn_loss, 1-trn_err, train_DICE)) 92 | ## Test ### 93 | val_loss, val_err, val_DICE = eval_net(model, val_loader, criterion, EE_size) 94 | print('Val - Loss: {:.4f} | Acc: {:.4f}, Dice: {:.4f}'.format(val_loss, 1-val_err, val_DICE)) 95 | ### Checkpoint ### 96 | DICE1 = view_sample_predictions(model, test_loader, FILE_test_imgs_original, pred_dir, EE_size) 97 | print('-----------test_dice',DICE1) 98 | if best_dice < DICE1: 99 | # save_weights_dice(WEIGHTS_PATH, model, epoch, train_DICE, val_DICE, DICE1, EE) 100 | best_dice = DICE1 101 | ### Adjust Lr ### 102 | adjust_learning_rate(LR, LR_DECAY, optimizer, epoch, DECAY_EVERY_N_EPOCHS) 103 | end_time = datetime.datetime.now() 104 | print('end_time', end_time) 105 | print('time', (end_time - start_time).seconds) 106 | -------------------------------------------------------------------------------- /train_1024_apex.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from optparse import OptionParser 4 | import numpy as np 5 | from pathlib import Path 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | from torch import optim 10 | from datasets import data_PG 11 | from datasets import joint_transforms 12 | import torchvision.transforms as transforms 13 | from eval_PG import eval_net, train_net, save_weights, adjust_learning_rate, view_sample_predictions, calculateDice, save_weights_dice, weight_transform_1024 14 | import time 15 | import datetime 16 | from math import ceil 17 | import utils.imgs 18 | import utils.training as train_utils 19 | from torch.autograd import Variable 20 | from unet.unet_model import UNet4 21 | ################################################################## 22 | import random 23 | import argparse 24 | import torch.distributed as dist 25 | # import torch.multiprocessing as mp 26 | #torch.multiprocessing 会帮助我们自动创建进程,spawn 开启了 nprocs=4 个线程,每个线程执行 main_worker 并向其中传入 local_rank(当前进程 index)和 args(即 4 和 myargs)作为参数 27 | 28 | parser = argparse.ArgumentParser(description='使用 Apex 再加速') 29 | from apex.parallel import DistributedDataParallel 30 | from apex import amp 31 | parser.add_argument('--local_rank', default=-1, type=int, 32 | help='node rank for distributed training') 33 | def main(): 34 | args = parser.parse_args() 35 | 36 | # random.seed(args.seed) 37 | # torch.manual_seed(args.seed) 38 | cudnn.deterministic = True 39 | 40 | 41 | main_worker(args.local_rank, 4, args) 42 | def main_worker(gpu, ngpus_per_node, args): 43 | 44 | dist.init_process_group(backend='nccl') 45 | torch.cuda.set_device(gpu) 46 | 47 | #################################################data 48 | PAP_PATH = Path('/home/zhaojie/zhaojie/PG/Pdata/') 49 | WEIGHTS_PATH = Path('./weights/train_1024/') 50 | 51 | WEIGHTS_PATH.mkdir(exist_ok=True) 52 | batch_size = 60 53 | normalize = 
transforms.Normalize(mean=data_PG.mean, std=data_PG.std) 54 | 55 | train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()]) 56 | train_dset = data_PG.CamVid(PAP_PATH, 'train', joint_transform=train_joint_transformer, transform=transforms.Compose([transforms.ToTensor(), normalize])) 57 | val_dset = data_PG.CamVid(PAP_PATH, 'val', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 58 | test_dset = data_PG.CamVid(PAP_PATH, 'test', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 59 | ####################################### 60 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dset) 61 | train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, num_workers=2, pin_memory=True,sampler=train_sampler) 62 | ####################################### 63 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dset) 64 | val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, num_workers=2,pin_memory=True,sampler=val_sampler) 65 | ####################################### 66 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dset) 67 | test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, num_workers=2, pin_memory=True,sampler=test_sampler) 68 | ####################################### 69 | 70 | print("Train: %d" %len(train_loader.dataset.imgs)) 71 | print("Val: %d" %len(val_loader.dataset.imgs)) 72 | print("Test: %d" %len(test_loader.dataset.imgs)) 73 | print("Classes: %d" % len(train_loader.dataset.classes)) 74 | 75 | inputs, targets = next(iter(train_loader)) 76 | print("Inputs: ", inputs.size()) 77 | print("Targets: ", targets.size()) 78 | # utils.imgs.view_image(inputs[0]) 79 | # utils.imgs.view_annotated(targets[0]) 80 | EE = 4 81 | device = 'cuda' 82 | EE_size = 256 83 | LR = 1e-3 84 | model = UNet4(n_channels=3, n_classes=4) 85 | # gpu = args.local_rank 
86 | torch.cuda.set_device(gpu) 87 | model.cuda(gpu) 88 | # model = model.to(device) 89 | # model = torch.nn.DataParallel(model).cuda() 90 | ######################################## 91 | optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4) 92 | model, optimizer = amp.initialize(model,optimizer) 93 | model = DistributedDataParallel(model) 94 | ################################################### 95 | 96 | # print('EE, model', EE, model) 97 | pred_dir = './train_PG_pred/' 98 | FILE_test_imgs_original = '/home/zhaojie/zhaojie/PG/Pdata/test' 99 | 100 | ############################# 101 | 102 | EE0 = 1 #EE =1-8#2-16#3-32#4-64#5-128#6-256#7-512#8-1024# 103 | 104 | epoch_num = 10000 105 | best_loss = 1. 106 | best_dice = 0. 107 | LR_DECAY = 0.95 108 | DECAY_EVERY_N_EPOCHS = 10 109 | criterion = nn.NLLLoss(weight=data_PG.class_weight.cuda()).cuda(gpu) 110 | cudnn.benchmark = True 111 | for epoch in range(1, epoch_num): 112 | start_time = datetime.datetime.now() 113 | print('start_time',start_time) 114 | model = model.cuda() 115 | 116 | ################################################## 117 | ### Train ### 118 | trn_loss, trn_err, train_DICE = train_net(model, train_loader, criterion, optimizer, EE_size) 119 | print('Epoch {:d}\nTrain - Loss: {:.4f}, Acc: {:.4f}, Dice: {:.4f}'.format(epoch, trn_loss, 1-trn_err, train_DICE)) 120 | ## Test ### 121 | val_loss, val_err, val_DICE = eval_net(model, val_loader, criterion, EE_size) 122 | print('Val - Loss: {:.4f} | Acc: {:.4f}, Dice: {:.4f}'.format(val_loss, 1-val_err, val_DICE)) 123 | ### Checkpoint ### 124 | DICE1 = view_sample_predictions(model, test_loader, FILE_test_imgs_original, pred_dir, EE_size) 125 | print('-----------test_dice',DICE1) 126 | if best_dice < DICE1: 127 | # save_weights_dice(WEIGHTS_PATH, model, epoch, train_DICE, val_DICE, DICE1, EE) 128 | best_dice = DICE1 129 | ### Adjust Lr ### 130 | adjust_learning_rate(LR, LR_DECAY, optimizer, epoch, DECAY_EVERY_N_EPOCHS) 131 | end_time = 
datetime.datetime.now() 132 | print('end_time', end_time) 133 | print('time', (end_time - start_time).seconds) 134 | if __name__ == '__main__': 135 | main() 136 | -------------------------------------------------------------------------------- /train_1024_distributed.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from optparse import OptionParser 4 | import numpy as np 5 | from pathlib import Path 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | from torch import optim 10 | from datasets import data_PG 11 | from datasets import joint_transforms 12 | import torchvision.transforms as transforms 13 | from eval_PG import eval_net, train_net, save_weights, adjust_learning_rate, view_sample_predictions, calculateDice, save_weights_dice, weight_transform_1024 14 | import time 15 | import datetime 16 | from math import ceil 17 | import utils.imgs 18 | import utils.training as train_utils 19 | from torch.autograd import Variable 20 | from unet.unet_model import UNet4 21 | ################################################################## 22 | import random 23 | import argparse 24 | import torch.distributed as dist 25 | 26 | parser = argparse.ArgumentParser(description='使用 torch.distributed 加速并行训练') 27 | 28 | parser.add_argument('--local_rank', default=-1, type=int, 29 | help='node rank for distributed training') 30 | def main(): 31 | args = parser.parse_args() 32 | 33 | # random.seed(args.seed) 34 | # torch.manual_seed(args.seed) 35 | cudnn.deterministic = True 36 | 37 | 38 | main_worker(args.local_rank, 4, args) 39 | def main_worker(gpu, ngpus_per_node, args): 40 | 41 | dist.init_process_group(backend='nccl') 42 | torch.cuda.set_device(gpu) 43 | 44 | #################################################data 45 | PAP_PATH = Path('/home/zhaojie/zhaojie/PG/Pdata/') 46 | WEIGHTS_PATH = Path('./weights/train_1024/') 47 | 48 | WEIGHTS_PATH.mkdir(exist_ok=True) 49 | 
batch_size = 60 50 | normalize = transforms.Normalize(mean=data_PG.mean, std=data_PG.std) 51 | 52 | train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()]) 53 | train_dset = data_PG.CamVid(PAP_PATH, 'train', joint_transform=train_joint_transformer, transform=transforms.Compose([transforms.ToTensor(), normalize])) 54 | val_dset = data_PG.CamVid(PAP_PATH, 'val', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 55 | test_dset = data_PG.CamVid(PAP_PATH, 'test', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 56 | ####################################### 57 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dset) 58 | train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, num_workers=2, pin_memory=True,sampler=train_sampler) 59 | ####################################### 60 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dset) 61 | val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, num_workers=2,pin_memory=True,sampler=val_sampler) 62 | ####################################### 63 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dset) 64 | test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, num_workers=2, pin_memory=True,sampler=test_sampler) 65 | ####################################### 66 | 67 | print("Train: %d" %len(train_loader.dataset.imgs)) 68 | print("Val: %d" %len(val_loader.dataset.imgs)) 69 | print("Test: %d" %len(test_loader.dataset.imgs)) 70 | print("Classes: %d" % len(train_loader.dataset.classes)) 71 | 72 | inputs, targets = next(iter(train_loader)) 73 | print("Inputs: ", inputs.size()) 74 | print("Targets: ", targets.size()) 75 | # utils.imgs.view_image(inputs[0]) 76 | # utils.imgs.view_annotated(targets[0]) 77 | EE = 4 78 | device = 'cuda' 79 | EE_size = 256 80 | LR = 1e-3 81 | model = UNet4(n_channels=3, 
n_classes=4) 82 | # gpu = args.local_rank 83 | torch.cuda.set_device(gpu) 84 | model.cuda(gpu) 85 | # model = model.to(device) 86 | # model = torch.nn.DataParallel(model).cuda() 87 | ######################################## 88 | model = torch.nn.parallel.DistributedDataParallel(model,find_unused_parameters=True, device_ids=[gpu]) 89 | ################################################### 90 | optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4) 91 | # print('EE, model', EE, model) 92 | pred_dir = './train_PG_pred/' 93 | FILE_test_imgs_original = '/home/zhaojie/zhaojie/PG/Pdata/test' 94 | 95 | ############################# 96 | 97 | EE0 = 1 #EE =1-8#2-16#3-32#4-64#5-128#6-256#7-512#8-1024# 98 | 99 | epoch_num = 10000 100 | best_loss = 1. 101 | best_dice = 0. 102 | LR_DECAY = 0.95 103 | DECAY_EVERY_N_EPOCHS = 10 104 | criterion = nn.NLLLoss(weight=data_PG.class_weight.cuda()).cuda(gpu) 105 | cudnn.benchmark = True 106 | for epoch in range(1, epoch_num): 107 | start_time = datetime.datetime.now() 108 | print('start_time',start_time) 109 | model = model.cuda() 110 | 111 | ################################################## 112 | ### Train ### 113 | trn_loss, trn_err, train_DICE = train_net(model, train_loader, criterion, optimizer, EE_size) 114 | print('Epoch {:d}\nTrain - Loss: {:.4f}, Acc: {:.4f}, Dice: {:.4f}'.format(epoch, trn_loss, 1-trn_err, train_DICE)) 115 | ## Test ### 116 | val_loss, val_err, val_DICE = eval_net(model, val_loader, criterion, EE_size) 117 | print('Val - Loss: {:.4f} | Acc: {:.4f}, Dice: {:.4f}'.format(val_loss, 1-val_err, val_DICE)) 118 | ### Checkpoint ### 119 | DICE1 = view_sample_predictions(model, test_loader, FILE_test_imgs_original, pred_dir, EE_size) 120 | print('-----------test_dice',DICE1) 121 | if best_dice < DICE1: 122 | # save_weights_dice(WEIGHTS_PATH, model, epoch, train_DICE, val_DICE, DICE1, EE) 123 | best_dice = DICE1 124 | ### Adjust Lr ### 125 | adjust_learning_rate(LR, LR_DECAY, optimizer, epoch, 
DECAY_EVERY_N_EPOCHS) 126 | end_time = datetime.datetime.now() 127 | print('end_time', end_time) 128 | print('time', (end_time - start_time).seconds) 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /train_1024_horovod.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from optparse import OptionParser 4 | import numpy as np 5 | from pathlib import Path 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | from torch import optim 10 | from datasets import data_PG 11 | from datasets import joint_transforms 12 | import torchvision.transforms as transforms 13 | from eval_PG import eval_net, train_net, save_weights, adjust_learning_rate, view_sample_predictions, calculateDice, save_weights_dice, weight_transform_1024 14 | import time 15 | import datetime 16 | from math import ceil 17 | import utils.imgs 18 | import utils.training as train_utils 19 | from torch.autograd import Variable 20 | from unet.unet_model import UNet4 21 | ################################################################## 22 | import random 23 | import argparse 24 | import torch.distributed as dist 25 | 26 | 27 | parser = argparse.ArgumentParser(description='使用 Apex 再加速') 28 | import horovod.torch as hvd 29 | def main(): 30 | args = parser.parse_args() 31 | 32 | # random.seed(args.seed) 33 | # torch.manual_seed(args.seed) 34 | cudnn.deterministic = True 35 | 36 | hvd.init() 37 | local_rank = hvd.local_rank() 38 | torch.cuda.set_device(local_rank) 39 | main_worker(args.local_rank, 4, args) 40 | def main_worker(gpu, ngpus_per_node, args): 41 | 42 | # dist.init_process_group(backend='nccl') 43 | torch.cuda.set_device(gpu) 44 | 45 | #################################################data 46 | PAP_PATH = Path('/home/zhaojie/zhaojie/PG/Pdata/') 47 | WEIGHTS_PATH = Path('./weights/train_1024/') 48 | 49 | 
WEIGHTS_PATH.mkdir(exist_ok=True) 50 | batch_size = 60 51 | normalize = transforms.Normalize(mean=data_PG.mean, std=data_PG.std) 52 | 53 | train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()]) 54 | train_dset = data_PG.CamVid(PAP_PATH, 'train', joint_transform=train_joint_transformer, transform=transforms.Compose([transforms.ToTensor(), normalize])) 55 | val_dset = data_PG.CamVid(PAP_PATH, 'val', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 56 | test_dset = data_PG.CamVid(PAP_PATH, 'test', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 57 | ####################################### 58 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dset) 59 | train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, num_workers=2, pin_memory=True,sampler=train_sampler) 60 | ####################################### 61 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dset) 62 | val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, num_workers=2,pin_memory=True,sampler=val_sampler) 63 | ####################################### 64 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dset) 65 | test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, num_workers=2, pin_memory=True,sampler=test_sampler) 66 | ####################################### 67 | 68 | print("Train: %d" %len(train_loader.dataset.imgs)) 69 | print("Val: %d" %len(val_loader.dataset.imgs)) 70 | print("Test: %d" %len(test_loader.dataset.imgs)) 71 | print("Classes: %d" % len(train_loader.dataset.classes)) 72 | 73 | inputs, targets = next(iter(train_loader)) 74 | print("Inputs: ", inputs.size()) 75 | print("Targets: ", targets.size()) 76 | # utils.imgs.view_image(inputs[0]) 77 | # utils.imgs.view_annotated(targets[0]) 78 | EE = 4 79 | device = 'cuda' 80 | EE_size = 256 81 | LR = 1e-3 82 
| model = UNet4(n_channels=3, n_classes=4) 83 | # gpu = args.local_rank 84 | torch.cuda.set_device(gpu) 85 | model.cuda(gpu) 86 | # model = model.to(device) 87 | hvd.broadcast_parameters(model.state_dict(), root_rank=0) 88 | ######################################## 89 | optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4) 90 | hvd.broadcast_optimizer_state(optimizer, root_rank=0) 91 | compression = hvd.Compression.fp16 92 | 93 | optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters(), compression=compression) 94 | model = DistributedDataParallel(model) 95 | ################################################### 96 | 97 | # print('EE, model', EE, model) 98 | pred_dir = './train_PG_pred/' 99 | FILE_test_imgs_original = '/home/zhaojie/zhaojie/PG/Pdata/test' 100 | 101 | ############################# 102 | 103 | EE0 = 1 #EE =1-8#2-16#3-32#4-64#5-128#6-256#7-512#8-1024# 104 | 105 | epoch_num = 10000 106 | best_loss = 1. 107 | best_dice = 0. 
108 | LR_DECAY = 0.95 109 | DECAY_EVERY_N_EPOCHS = 10 110 | criterion = nn.NLLLoss(weight=data_PG.class_weight.cuda()).cuda(gpu) 111 | cudnn.benchmark = True 112 | for epoch in range(1, epoch_num): 113 | start_time = datetime.datetime.now() 114 | print('start_time',start_time) 115 | model = model.cuda() 116 | 117 | ################################################## 118 | ### Train ### 119 | trn_loss, trn_err, train_DICE = train_net(model, train_loader, criterion, optimizer, EE_size) 120 | print('Epoch {:d}\nTrain - Loss: {:.4f}, Acc: {:.4f}, Dice: {:.4f}'.format(epoch, trn_loss, 1-trn_err, train_DICE)) 121 | ## Test ### 122 | val_loss, val_err, val_DICE = eval_net(model, val_loader, criterion, EE_size) 123 | print('Val - Loss: {:.4f} | Acc: {:.4f}, Dice: {:.4f}'.format(val_loss, 1-val_err, val_DICE)) 124 | ### Checkpoint ### 125 | DICE1 = view_sample_predictions(model, test_loader, FILE_test_imgs_original, pred_dir, EE_size) 126 | print('-----------test_dice',DICE1) 127 | if best_dice < DICE1: 128 | # save_weights_dice(WEIGHTS_PATH, model, epoch, train_DICE, val_DICE, DICE1, EE) 129 | best_dice = DICE1 130 | ### Adjust Lr ### 131 | adjust_learning_rate(LR, LR_DECAY, optimizer, epoch, DECAY_EVERY_N_EPOCHS) 132 | end_time = datetime.datetime.now() 133 | print('end_time', end_time) 134 | print('time', (end_time - start_time).seconds) 135 | if __name__ == '__main__': 136 | main() 137 | -------------------------------------------------------------------------------- /train_1024_multiprocessing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from optparse import OptionParser 4 | import numpy as np 5 | from pathlib import Path 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | from torch import optim 10 | from datasets import data_PG 11 | from datasets import joint_transforms 12 | import torchvision.transforms as transforms 13 | from eval_PG import eval_net, 
train_net, save_weights, adjust_learning_rate, view_sample_predictions, calculateDice, save_weights_dice, weight_transform_1024 14 | import time 15 | import datetime 16 | from math import ceil 17 | import utils.imgs 18 | import utils.training as train_utils 19 | from torch.autograd import Variable 20 | from unet.unet_model import UNet4 21 | ################################################################## 22 | import random 23 | import argparse 24 | import torch.distributed as dist 25 | import torch.multiprocessing as mp 26 | #torch.multiprocessing 会帮助我们自动创建进程,spawn 开启了 nprocs=4 个线程,每个线程执行 main_worker 并向其中传入 local_rank(当前进程 index)和 args(即 4 和 myargs)作为参数 27 | 28 | 29 | parser = argparse.ArgumentParser(description='使用 torch.multiprocessing 取代启动器') 30 | def main(): 31 | args = parser.parse_args() 32 | 33 | # random.seed(args.seed) 34 | # torch.manual_seed(args.seed) 35 | cudnn.deterministic = True 36 | 37 | 38 | mp.spawn(main_worker, nprocs=4, args=(4, args)) 39 | def main_worker(gpu, ngpus_per_node, args): 40 | #由于没有 torch.distributed.launch 读取的默认环境变量作为配置,我们需要手动为 init_process_group 指定参数 41 | dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456', world_size=4, rank=gpu) 42 | # torch.cuda.set_device(args.local_rank) 43 | 44 | #################################################data 45 | PAP_PATH = Path('/home/zhaojie/zhaojie/PG/Pdata/') 46 | WEIGHTS_PATH = Path('./weights/train_1024/') 47 | 48 | WEIGHTS_PATH.mkdir(exist_ok=True) 49 | batch_size = 60 50 | normalize = transforms.Normalize(mean=data_PG.mean, std=data_PG.std) 51 | 52 | train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()]) 53 | train_dset = data_PG.CamVid(PAP_PATH, 'train', joint_transform=train_joint_transformer, transform=transforms.Compose([transforms.ToTensor(), normalize])) 54 | val_dset = data_PG.CamVid(PAP_PATH, 'val', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 55 | test_dset = 
data_PG.CamVid(PAP_PATH, 'test', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 56 | ####################################### 57 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dset) 58 | train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, sampler=train_sampler) 59 | ####################################### 60 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dset) 61 | val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, sampler=val_sampler) 62 | ####################################### 63 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dset) 64 | test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, sampler=test_sampler) 65 | ####################################### 66 | 67 | print("Train: %d" %len(train_loader.dataset.imgs)) 68 | print("Val: %d" %len(val_loader.dataset.imgs)) 69 | print("Test: %d" %len(test_loader.dataset.imgs)) 70 | print("Classes: %d" % len(train_loader.dataset.classes)) 71 | 72 | inputs, targets = next(iter(train_loader)) 73 | print("Inputs: ", inputs.size()) 74 | print("Targets: ", targets.size()) 75 | # utils.imgs.view_image(inputs[0]) 76 | # utils.imgs.view_annotated(targets[0]) 77 | EE = 4 78 | device = 'cuda' 79 | EE_size = 256 80 | LR = 1e-3 81 | model = UNet4(n_channels=3, n_classes=4) 82 | # gpu = args.local_rank 83 | torch.cuda.set_device(gpu) 84 | model.cuda(gpu) 85 | # model = model.to(device) 86 | # model = torch.nn.DataParallel(model).cuda() 87 | ######################################## 88 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu], find_unused_parameters=True) 89 | ################################################### 90 | optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4) 91 | # print('EE, model', EE, model) 92 | pred_dir = './train_PG_pred/' 93 | FILE_test_imgs_original = 
'/home/zhaojie/zhaojie/PG/Pdata/test' 94 | 95 | ############################# 96 | 97 | EE0 = 1 #EE =1-8#2-16#3-32#4-64#5-128#6-256#7-512#8-1024# 98 | 99 | epoch_num = 10000 100 | best_loss = 1. 101 | best_dice = 0. 102 | LR_DECAY = 0.95 103 | DECAY_EVERY_N_EPOCHS = 10 104 | criterion = nn.NLLLoss(weight=data_PG.class_weight.cuda()).cuda() 105 | 106 | for epoch in range(1, epoch_num): 107 | start_time = datetime.datetime.now() 108 | print('start_time',start_time) 109 | model = model.cuda() 110 | 111 | ################################################## 112 | ### Train ### 113 | trn_loss, trn_err, train_DICE = train_net(model, train_loader, criterion, optimizer, EE_size) 114 | print('Epoch {:d}\nTrain - Loss: {:.4f}, Acc: {:.4f}, Dice: {:.4f}'.format(epoch, trn_loss, 1-trn_err, train_DICE)) 115 | ## Test ### 116 | val_loss, val_err, val_DICE = eval_net(model, val_loader, criterion, EE_size) 117 | print('Val - Loss: {:.4f} | Acc: {:.4f}, Dice: {:.4f}'.format(val_loss, 1-val_err, val_DICE)) 118 | ### Checkpoint ### 119 | DICE1 = view_sample_predictions(model, test_loader, FILE_test_imgs_original, pred_dir, EE_size) 120 | print('-----------test_dice',DICE1) 121 | if best_dice < DICE1: 122 | # save_weights_dice(WEIGHTS_PATH, model, epoch, train_DICE, val_DICE, DICE1, EE) 123 | best_dice = DICE1 124 | ### Adjust Lr ### 125 | adjust_learning_rate(LR, LR_DECAY, optimizer, epoch, DECAY_EVERY_N_EPOCHS) 126 | end_time = datetime.datetime.now() 127 | print('end_time', end_time) 128 | print('time', (end_time - start_time).seconds) 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /train_PG.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from optparse import OptionParser 4 | import numpy as np 5 | from pathlib import Path 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | from torch import 
optim 10 | from datasets import data_PG 11 | from datasets import joint_transforms 12 | import torchvision.transforms as transforms 13 | from eval_PG import eval_net, train_net, save_weights, adjust_learning_rate, view_sample_predictions, calculateDice, save_weights_dice, weight_transform_1024 14 | import time 15 | from math import ceil 16 | import utils.imgs 17 | import utils.training as train_utils 18 | from torch.autograd import Variable 19 | from unet.unet_model import UNet1, UNet2, UNet3, UNet4 20 | 21 | ###############################################data 22 | 23 | PAP_PATH = Path('/home/zhaojie/zhaojie/PG/Pdata/') 24 | WEIGHTS_PATH = Path('./weights/train_PG_1024/') 25 | WEIGHTS_PATH.mkdir(exist_ok=True) 26 | batch_size = 30 27 | normalize = transforms.Normalize(mean=data_PG.mean, std=data_PG.std) 28 | 29 | train_joint_transformer = transforms.Compose([joint_transforms.JointRandomHorizontalFlip()]) 30 | # train_joint_transform = None 31 | train_dset = data_PG.CamVid(PAP_PATH, 'train', joint_transform=train_joint_transformer, transform=transforms.Compose([transforms.ToTensor(), normalize])) 32 | train_loader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, shuffle=True) 33 | 34 | val_dset = data_PG.CamVid(PAP_PATH, 'val', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 35 | val_loader = torch.utils.data.DataLoader(val_dset, batch_size=batch_size, shuffle=False) 36 | 37 | test_dset = data_PG.CamVid(PAP_PATH, 'test', joint_transform=None, transform=transforms.Compose([ transforms.ToTensor(), normalize])) 38 | test_loader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, shuffle=False) 39 | 40 | print("Train: %d" %len(train_loader.dataset.imgs)) 41 | print("Val: %d" %len(val_loader.dataset.imgs)) 42 | print("Test: %d" %len(test_loader.dataset.imgs)) 43 | print("Classes: %d" % len(train_loader.dataset.classes)) 44 | 45 | inputs, targets = next(iter(train_loader)) 46 | print("Inputs: ", 
inputs.size()) 47 | print("Targets: ", targets.size()) 48 | # utils.imgs.view_image(inputs[0]) 49 | # utils.imgs.view_annotated(targets[0]) 50 | device = 'cuda' 51 | LR = 3e-4 52 | EE = 1 53 | EE_size = 32 54 | model = UNet1(n_channels=3, n_classes=4).cuda() 55 | 56 | model = model.to(device) 57 | model = torch.nn.DataParallel(model).cuda() 58 | optimizer = torch.optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4) 59 | # print('EE, model', EE, model) 60 | pred_dir = './train_PG_pred/' 61 | FILE_test_imgs_original = '/home/zhaojie/zhaojie/PG/Pdata/test' 62 | if __name__ == '__main__': 63 | EE0 = 1 64 | 65 | epoch_num = 1000 66 | best_loss = 1. 67 | best_dice = 0. 68 | LR_DECAY = 0.95 69 | DECAY_EVERY_N_EPOCHS = 40 70 | criterion = nn.NLLLoss(weight=data_PG.class_weight.cuda()).cuda() 71 | base_params = list(map(id, model.parameters())) 72 | for epoch in range(1, epoch_num): 73 | #################EE######################################## 74 | 75 | if int(epoch / 40) < 4: 76 | EE = ceil(epoch / 40) #1,2,3,4,5,6,7 77 | else: 78 | EE = 4 79 | print('---------------EE-------------------',EE) 80 | 81 | if EE0 != EE: 82 | if EE == 2: 83 | model1 = UNet2(n_channels=3, n_classes=4).cuda() 84 | EE_size = 64 85 | logits_params = filter(lambda p: id(p) not in base_params, model1.parameters()) 86 | params = [{"params": logits_params, "lr": 1e-4}, 87 | {"params": model.parameters(), "lr": 1e-6}] 88 | if EE == 3: 89 | model1 = UNet3(n_channels=3, n_classes=4).cuda() 90 | EE_size = 128 91 | logits_params = filter(lambda p: id(p) not in base_params, model1.parameters()) 92 | params = [{"params": logits_params, "lr": 1e-4}, 93 | {"params": model.parameters(), "lr": 1e-6}] 94 | if EE == 4: 95 | model1 = UNet4(n_channels=3, n_classes=4).cuda() 96 | EE_size = 256 97 | logits_params = filter(lambda p: id(p) not in base_params, model1.parameters()) 98 | params = [{"params": logits_params, "lr": 1e-4}, 99 | {"params": model.parameters(), "lr": 1e-4}] 100 | EE0 = EE 101 | 
pretrained_dict0 = model.state_dict() 102 | 103 | 104 | 105 | 106 | optimizer1 = torch.optim.RMSprop(params, lr=LR, weight_decay=1e-4) 107 | optimizer = optimizer1 108 | pretrained_dict = model1.state_dict() 109 | pretrained_dict = weight_transform_1024(pretrained_dict,pretrained_dict0, EE) 110 | 111 | model1.load_state_dict(pretrained_dict) 112 | model = model1.cuda() 113 | model = model.to(device) 114 | model = torch.nn.DataParallel(model).cuda() 115 | # print('EE, model', model) 116 | ################################################################# 117 | 118 | model = model.cuda() 119 | 120 | ################################################## 121 | ### Train ### 122 | trn_loss, trn_err, train_DICE = train_net(model, train_loader, criterion, optimizer, EE_size) 123 | print('Epoch {:d}\nTrain - Loss: {:.4f}, Acc: {:.4f}, Dice: {:.4f}'.format(epoch, trn_loss, 1-trn_err, train_DICE)) 124 | ## Test ### 125 | val_loss, val_err, val_DICE = eval_net(model, val_loader, criterion, EE_size) 126 | print('Val - Loss: {:.4f} | Acc: {:.4f}, Dice: {:.4f}'.format(val_loss, 1-val_err, val_DICE)) 127 | ### Checkpoint ### 128 | DICE1 = view_sample_predictions(model, test_loader, FILE_test_imgs_original, pred_dir, EE_size) 129 | print('----------test_dice',DICE1, best_dice) 130 | if best_dice < DICE1: 131 | # save_weights_dice(WEIGHTS_PATH, model, epoch, train_DICE, val_DICE, DICE1, EE) 132 | best_dice = DICE1 133 | ### Adjust Lr ### 134 | adjust_learning_rate(LR, LR_DECAY, optimizer, epoch, DECAY_EVERY_N_EPOCHS) 135 | -------------------------------------------------------------------------------- /unet/__init__.py: -------------------------------------------------------------------------------- 1 | # from .unet_model import UNet1, UNet2, UNet3, UNet4 2 | -------------------------------------------------------------------------------- /unet/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/unet/__pycache__/__init__.cpython-36.pyc
# ---- /unet/__pycache__/dice_loss.cpython-36.pyc ----
# https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/unet/__pycache__/dice_loss.cpython-36.pyc
# ---- /unet/__pycache__/unet_model.cpython-36.pyc ----
# https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/unet/__pycache__/unet_model.cpython-36.pyc
# ---- /unet/__pycache__/unet_model_BN_MP_PG.cpython-36.pyc ----
# https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/unet/__pycache__/unet_model_BN_MP_PG.cpython-36.pyc
# ---- /unet/__pycache__/unet_model_PG.cpython-36.pyc ----
# https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/unet/__pycache__/unet_model_PG.cpython-36.pyc
# ---- /unet/__pycache__/unet_model_PG_256.cpython-36.pyc ----
# https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/unet/__pycache__/unet_model_PG_256.cpython-36.pyc
# ---- /unet/__pycache__/unet_model_PG_more_conv.cpython-36.pyc ----
# https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/unet/__pycache__/unet_model_PG_more_conv.cpython-36.pyc
# ---- /unet/__pycache__/unet_model_PG_new.cpython-36.pyc ----
# https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/unet/__pycache__/unet_model_PG_new.cpython-36.pyc
# ---- /unet/__pycache__/unet_parts.cpython-36.pyc ----
# https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/unet/__pycache__/unet_parts.cpython-36.pyc

# ---- /unet/dice_loss.py ----
import torch
from torch.autograd import Function, Variable


class DiceCoeff(Function):
    """Dice coefficient for a single example.

    Legacy (instance-method) autograd Function. ``dice_coeff`` below calls
    ``forward`` directly, which bypasses autograd; ``backward`` is only a
    hand-written gradient for the old-style Function API.
    """

    def forward(self, input, target):
        # Kept for the hand-written backward below.
        self.save_for_backward(input, target)
        eps = 0.0001
        self.inter = torch.dot(input.view(-1), target.view(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        # ``saved_tensors`` replaces the deprecated ``saved_variables`` accessor.
        input, target = self.saved_tensors
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            # Gradient of 2*inter/union w.r.t. input (the eps term is dropped).
            grad_input = grad_output * 2 * (target * self.union - self.inter) \
                / (self.union * self.union)
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input, target):
    """Mean Dice coefficient over the first (batch) dimension.

    Args:
        input: batch of predictions, iterated along dim 0.
        target: batch of targets, iterated along dim 0 in lockstep.

    Returns:
        1-element float tensor on ``input``'s device.

    Raises:
        ValueError: if the batch is empty (previously an unbound-name error).
    """
    # Allocate on the input's device, which need not be the default GPU.
    s = torch.zeros(1, device=input.device)
    count = 0
    for count, (inp, tgt) in enumerate(zip(input, target), start=1):
        s = s + DiceCoeff().forward(inp, tgt)

    if count == 0:
        raise ValueError('dice_coeff() received an empty batch')
    return s / count


# ---- /unet/unet_model.py ----
import torch.nn.functional as F

from .unet_parts import *
import torch.nn as nn


def _upsample_to(x, ref):
    """Bilinearly resize ``x`` to ``ref``'s spatial size (align_corners=True).

    When ``ref`` is exactly k times larger, this matches
    ``interpolate(x, scale_factor=(k, k), align_corners=True)``.
    With align_corners=True the sampling grid is derived from the sizes.
    It also tolerates odd sizes introduced by max-pooling.
    """
    return nn.functional.interpolate(x, size=ref.shape[-2:], mode='bilinear',
                                     align_corners=True)


# NOTE: submodule attribute names (inc, downN, upN, outcN, LS) are the
# state_dict keys used when weights are loaded between progressive stages
# via load_state_dict, so they must stay stable.

class UNet1(nn.Module):
    """Stage-1 network: one pooling level; returns log-probabilities per pixel."""

    def __init__(self, n_channels, n_classes):
        super(UNet1, self).__init__()
        self.inc = inconv(n_channels, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.outc = outconv(256, n_classes)
        self.LS = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down4(x1)
        x3 = self.up1(x2, x1)
        x = self.outc(x3)
        return self.LS(x)


class UNet2(nn.Module):
    """Stage-2 network: two pooling levels, summed side outputs, log-probabilities."""

    def __init__(self, n_channels, n_classes):
        super(UNet2, self).__init__()
        self.inc = inconv(n_channels, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.outc1 = outconv(256, n_classes)
        self.outc2 = outconv(128, n_classes)
        self.LS = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x1 = self.inc(x)        # full resolution
        x2 = self.down3(x1)     # 1/2
        x3 = self.down4(x2)     # 1/4
        x4 = self.up1(x3, x2)   # 1/2
        x5 = self.up2(x4, x1)   # full resolution
        side1 = self.outc1(x4)
        side2 = self.outc2(x5)
        # Side outputs are upsampled to full resolution and summed.
        x = _upsample_to(side1, side2) + side2
        return self.LS(x)


class UNet3(nn.Module):
    """Stage-3 network: three pooling levels, summed side outputs, log-probabilities."""

    def __init__(self, n_channels, n_classes):
        super(UNet3, self).__init__()
        self.inc = inconv(n_channels, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.outc1 = outconv(256, n_classes)
        self.outc2 = outconv(128, n_classes)
        self.outc3 = outconv(64, n_classes)
        self.LS = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x1 = self.inc(x)        # full resolution
        x2 = self.down2(x1)     # 1/2
        x3 = self.down3(x2)     # 1/4
        x4 = self.down4(x3)     # 1/8
        x5 = self.up1(x4, x3)   # 1/4
        x6 = self.up2(x5, x2)   # 1/2
        x7 = self.up3(x6, x1)   # full resolution
        side1 = self.outc1(x5)
        side2 = self.outc2(x6)
        side3 = self.outc3(x7)
        x = _upsample_to(side1, side3) + _upsample_to(side2, side3) + side3
        return self.LS(x)


class UNet4(nn.Module):
    """Stage-4 (full) network: four pooling levels, summed side outputs, log-probabilities."""

    def __init__(self, n_channels, n_classes):
        super(UNet4, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc1 = outconv(256, n_classes)
        self.outc2 = outconv(128, n_classes)
        self.outc3 = outconv(64, n_classes)
        self.outc4 = outconv(64, n_classes)
        self.LS = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x1 = self.inc(x)        # full resolution
        x2 = self.down1(x1)     # 1/2
        x3 = self.down2(x2)     # 1/4
        x4 = self.down3(x3)     # 1/8
        x5 = self.down4(x4)     # 1/16
        x6 = self.up1(x5, x4)   # 1/8
        x7 = self.up2(x6, x3)   # 1/4
        x8 = self.up3(x7, x2)   # 1/2
        x9 = self.up4(x8, x1)   # full resolution
        side1 = self.outc1(x6)
        side2 = self.outc2(x7)
        side3 = self.outc3(x8)
        side4 = self.outc4(x9)
        x = (_upsample_to(side1, side4) + _upsample_to(side2, side4)
             + _upsample_to(side3, side4) + side4)
        return self.LS(x)


# ---- /unet/unet_parts.py ----
# sub-parts of the U-Net model

import torch
import torch.nn as nn
import torch.nn.functional as F


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    """Input block: a single ``double_conv`` at full resolution."""
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    """2x2 max-pool followed by ``double_conv``."""
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    """Upsample ``x1``, concatenate it after skip tensor ``x2``, then ``double_conv``.

    Args:
        in_ch: channels after concatenation (``x2`` channels + ``x1`` channels).
        out_ch: output channels of the ``double_conv``.
        bilinear: True (default) uses parameter-free bilinear interpolation.
            False uses a learned ``ConvTranspose2d(in_ch // 2, in_ch // 2, 2,
            stride=2)``, which assumes ``x1`` has ``in_ch // 2`` channels.
    """
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        self.bilinear = bilinear
        if not bilinear:
            # Only registered on request, so the default state_dict keys are
            # unchanged.
            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        # getattr: whole-module pickles saved before ``bilinear`` was stored
        # lack the attribute.
        if getattr(self, 'bilinear', True):
            # Resize straight to the skip tensor's size. This equals a 2x
            # align_corners upsample when sizes match exactly, and it keeps
            # torch.cat working when pooling floored an odd dimension.
            x1 = F.interpolate(x1, size=x2.shape[-2:], mode='bilinear',
                               align_corners=True)
        else:
            x1 = self.up(x1)
            diff_y = x2.size(2) - x1.size(2)
            diff_x = x2.size(3) - x1.size(3)
            if diff_y or diff_x:
                x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                                diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class outconv(nn.Module):
    """1x1 convolution mapping features to per-class scores."""
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x


# ---- /utils/__init__.py ----
from .crf import *
from .load import *
from .utils import *
from .data_vis import *

# ---- /utils/__pycache__/__init__.cpython-36.pyc ----
# https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/utils/__pycache__/__init__.cpython-36.pyc
# ---- /utils/__pycache__/crf.cpython-36.pyc ----
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/utils/__pycache__/crf.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_vis.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/utils/__pycache__/data_vis.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/imgs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/utils/__pycache__/imgs.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/load.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/utils/__pycache__/load.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/training.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/utils/__pycache__/training.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Minerva-J/PGU-net-Model/7a059a2da53eaefb6d73557aa459a117841e5625/utils/__pycache__/utils.cpython-36.pyc 
# ---- /utils/crf.py ----
import numpy as np
import pydensecrf.densecrf as dcrf


def dense_crf(img, output_probs):
    """Refine a foreground-probability map with a 2-label dense CRF.

    Args:
        img: image for the bilateral term (passed as ``rgbim``).
            NOTE(review): pydensecrf expects a C-contiguous HxWx3 uint8 array
            here - confirm at the caller.
        output_probs: 2-D array of foreground probabilities, shape (h, w).

    Returns:
        (h, w) array of labels: 0 = background, 1 = foreground.
    """
    h = output_probs.shape[0]
    w = output_probs.shape[1]

    # Stack [P(background), P(foreground)] along a new leading axis.
    output_probs = np.expand_dims(output_probs, 0)
    output_probs = np.append(1 - output_probs, output_probs, axis=0)

    d = dcrf.DenseCRF2D(w, h, 2)
    # Clip so hard 0/1 probabilities give a large finite energy instead of inf.
    # setUnaryEnergy needs a C-contiguous float32 buffer.
    U = -np.log(np.clip(output_probs, 1e-8, 1.0))
    U = np.ascontiguousarray(U.reshape((2, -1)), dtype=np.float32)
    img = np.ascontiguousarray(img)

    d.setUnaryEnergy(U)

    d.addPairwiseGaussian(sxy=20, compat=3)
    d.addPairwiseBilateral(sxy=30, srgb=20, rgbim=img, compat=10)

    Q = d.inference(5)
    Q = np.argmax(np.array(Q), axis=0).reshape((h, w))

    return Q


# ---- /utils/data_vis.py ----
import matplotlib.pyplot as plt


def plot_img_and_mask(img, mask):
    """Show ``img`` and ``mask`` side by side in one matplotlib figure."""
    fig = plt.figure()
    a = fig.add_subplot(1, 2, 1)
    a.set_title('Input image')
    plt.imshow(img)

    b = fig.add_subplot(1, 2, 2)
    b.set_title('Output mask')
    plt.imshow(mask)
    plt.show()


# ---- /utils/imgs.py ----
import numpy as np
import matplotlib.pyplot as plt

# Display colours for class indices 0..3 (the names look like CamVid leftovers).
Sky = [128,128,128]
Building = [128,0,0]
Pole = [192,192,128]
Road = [128,64,128]


DSET_MEAN = [0.611, 0.506, 0.54]
DSET_STD = [0.14, 0.16, 0.165]

label_colours = np.array([Sky, Building, Pole, Road])


def view_annotated(tensor, plot=True):
    """Colour a 2-D label map with ``label_colours``; show it or return it.

    Args:
        tensor: CPU tensor of class indices, indexed as [H, W].
        plot: True shows the image with matplotlib and returns None.
            False returns the (H, W, 3) float RGB array scaled to [0, 1].
    """
    temp = tensor.numpy()
    r = temp.copy()
    g = temp.copy()
    b = temp.copy()
    # Colour every class. The old range(0, 3) skipped class 3, leaving those
    # pixels at 3/255 (almost black).
    for l in range(len(label_colours)):
        r[temp == l] = label_colours[l, 0]
        g[temp == l] = label_colours[l, 1]
        b[temp == l] = label_colours[l, 2]

    rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb


def decode_image(tensor):
    """Undo DSET_MEAN/DSET_STD normalisation of a (C, H, W) tensor; return (H, W, C)."""
    inp = tensor.numpy().transpose((1, 2, 0))
    mean = np.array(DSET_MEAN)
    std = np.array(DSET_STD)
    inp = std * inp + mean
    return inp


def view_image(tensor):
    """Show a normalised (C, H, W) image tensor, clipped to [0, 1]."""
    inp = decode_image(tensor)
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.show()


# ---- /utils/load.py ----
#
# load.py : utils on generators / lists of ids to transform from strings to
# cropped images and masks

import os

import numpy as np
from PIL import Image

from .utils import resize_and_crop, get_square, normalize, hwc_to_chw


def get_ids(dir):
    """Returns a generator of the ids in the directory (file names minus 4 trailing chars)"""
    return (f[:-4] for f in os.listdir(dir))


def split_ids(ids, n=2):
    """Split each id in n, creating n tuples (id, k) for each id"""
    return ((id, i) for id in ids for i in range(n))


def to_cropped_imgs(ids, dir, suffix, scale):
    """From a list of tuples, returns the correct cropped img"""
    for img_id, pos in ids:
        # Close the file before yielding; resize_and_crop returns an ndarray,
        # so nothing keeps a reference to the PIL image.
        with Image.open(dir + img_id + suffix) as pil_img:
            im = resize_and_crop(pil_img, scale=scale)
        yield get_square(im, pos)


def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
    """Return all the couples (img, mask)"""

    imgs = to_cropped_imgs(ids, dir_img, '.png', scale)

    # need to transform from HWC to CHW
    imgs_switched = map(hwc_to_chw, imgs)
    imgs_normalized = map(normalize, imgs_switched)

    masks = to_cropped_imgs(ids, dir_mask, '.png', scale)

    return zip(imgs_normalized, masks)


def get_full_img_and_mask(id, dir_img, dir_mask):
    """Load ``<id>.png`` from both directories as numpy arrays (files are closed)."""
    with Image.open(dir_img + id + '.png') as im, \
            Image.open(dir_mask + id + '.png') as mask:
        return np.array(im), np.array(mask)


# ---- /utils/training.py ----
import os
import numpy as np
import cv2
import sys
import math
import string
import random
import shutil

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.nn.functional as F

RESULTS_PATH = 'results/'
# WEIGHTS_PATH = 'weights/fcdencenet103/'
# WEIGHTS_PATH = 'weights/defcdencenet56/'
WEIGHTS_PATH = 'weights/36uDeformDenceNet/'


def save_weights(model, epoch, loss, err):
    """Pickle the whole ``model`` to WEIGHTS_PATH as ``<epoch>_<loss>_<err>_model.pkl``.

    NOTE(review): this saves the full module object, while ``load_weights``
    expects a dict with 'startEpoch'/'state_dict' keys, so the two functions
    are not round-trip compatible.
    """
    # torch.save does not create missing directories.
    os.makedirs(WEIGHTS_PATH, exist_ok=True)
    name = (WEIGHTS_PATH + str(epoch) + '_' + str(round(loss, 5)) + '_'
            + str(round(err, 5)) + '_model.pkl')
    print(name)
    torch.save(model, name)


def load_weights(model, fpath):
    """Load a {'startEpoch', 'loss', 'error', 'state_dict'} checkpoint into ``model``; return startEpoch."""
    print("loading weights '{}'".format(fpath))
    weights = torch.load(fpath)
    startEpoch = weights['startEpoch']
    model.load_state_dict(weights['state_dict'])
    print("loaded weights (lastEpoch {}, loss {}, error {})"
          .format(startEpoch-1, weights['loss'], weights['error']))
    return startEpoch


def get_predictions(output_batch):
    """Per-pixel arg-max over dim 1 of a (bs, c, h, w) tensor, as a CPU (bs, h, w) tensor."""
    bs, c, h, w = output_batch.size()
    tensor = output_batch.data
    values, indices = tensor.cpu().max(1)
    indices = indices.view(bs, h, w)
    return indices


def error(preds, targets):
    """Fraction of mismatched pixels between two (bs, h, w) label tensors, rounded to 5 places."""
    assert preds.size() == targets.size()
    bs, h, w = preds.size()
    n_pixels = bs * h * w
    # .item() yields a Python int, so this is float division on every PyTorch
    # version. LongTensor / int floors to 0 on PyTorch 1.0.
    incorrect = preds.ne(targets).cpu().sum().item()
    return round(incorrect / n_pixels, 5)


def train(model, trn_loader, optimizer, criterion, epoch):
    """Run one training epoch on CUDA; return (mean batch loss, mean batch pixel error)."""
    model.train()
    trn_loss = 0
    trn_error = 0
    for data in trn_loader:
        inputs = data[0].cuda()
        targets = data[1].cuda()

        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        # loss.data[0] raises on 0-dim tensors (PyTorch >= 0.4); use .item().
        trn_loss += loss.item()
        pred = get_predictions(output)
        trn_error += error(pred, targets.data.cpu())

    trn_loss /= len(trn_loader)
    trn_error /= len(trn_loader)
    return trn_loss, trn_error


def test(model, test_loader, criterion, epoch=1):
    """Evaluate on CUDA without gradients; return (mean batch loss, mean batch pixel error)."""
    # NOTE(review): model.eval() is deliberately left disabled as in the
    # original, so BatchNorm still uses batch statistics here - confirm intent.
    test_loss = 0
    test_error = 0
    # torch.no_grad() replaces the no-op Variable(..., volatile=True).
    with torch.no_grad():
        for data, target in test_loader:
            data = data.cuda()
            target = target.cuda()
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = get_predictions(output)
            test_error += error(pred, target.data.cpu())
    test_loss /= len(test_loader)
    test_error /= len(test_loader)
    return test_loss, test_error


def adjust_learning_rate(lr, decay, optimizer, cur_epoch, n_epochs):
    """Sets the learning rate to the initially
        configured `lr` decayed by `decay` every `n_epochs`"""
    new_lr = lr * (decay ** (cur_epoch // n_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr


def weights_init(m):
    """Kaiming-uniform init for Conv2d weights; zero the bias if there is one."""
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()


def predict(model, input_loader, n_batches=1):
    """Return ``[input, target, pred]`` for every batch of ``input_loader``.

    ``n_batches`` is accepted for compatibility but unused. The batch size is
    whatever the loader was built with: DataLoader forbids changing
    ``batch_size`` after construction (the old assignment raised ValueError).
    """
    predictions = []
    with torch.no_grad():
        for inputs, targets in input_loader:
            output = model(inputs.cuda())
            pred = get_predictions(output)
            predictions.append([inputs, targets, pred])
    return predictions


def view_sample_predictions_b(model, loader, n):
    """Show input, ground truth and prediction for up to ``n`` samples of one batch."""
    # The module-level import of imgs is not present in this file; without
    # this the function raised NameError.
    from . import imgs as img_utils

    inputs, targets = next(iter(loader))
    with torch.no_grad():
        output = model(inputs.cuda())
    pred = get_predictions(output)
    batch_size = inputs.size(0)
    for i in range(min(n, batch_size)):
        img_utils.view_image(inputs[i])
        img_utils.view_annotated(targets[i])
        img_utils.view_annotated(pred[i])


def view_sample_predictions(model, loader, names):
    """Run ``model`` over ``loader``, pairing batch i with ``names[i]``.

    NOTE(review): the write of each prediction to ``./data/pred/Full/<name>``
    is disabled (see commented cv2.imwrite), so this currently only runs
    inference.
    """
    with torch.no_grad():
        for i, data in enumerate(loader):
            output = model(data[0].cuda())
            pred = get_predictions(output)
            name = './data/pred/Full/' + names[i]
            # cv2.imwrite(name, 80 * np.transpose(np.asarray(pred), (1, 2, 0)))


# ---- /utils/utils.py ----
import random
import numpy as np


def get_square(img, pos):
    """Extract a left or a right square from ndarray shape : (H, W, C))"""
    h = img.shape[0]
    if pos == 0:
        return img[:, :h]
    else:
        return img[:, -h:]


def split_img_into_squares(img):
    """Return the left and right (H, H) squares of ``img``."""
    return get_square(img, 0), get_square(img, 1)


def hwc_to_chw(img):
    """Transpose an (H, W, C) array to (C, H, W)."""
    return np.transpose(img, axes=[2, 0, 1])


def resize_and_crop(pilimg, scale=0.5, final_height=None):
    """Scale a PIL image by ``scale``, optionally crop vertically to ``final_height``, return float32 array."""
    w = pilimg.size[0]
    h = pilimg.size[1]
    newW = int(w * scale)
    newH = int(h * scale)

    if not final_height:
        diff = 0
    else:
        diff = newH - final_height

    img = pilimg.resize((newW, newH))
    img = img.crop((0, diff // 2, newW, newH - diff // 2))
    return np.array(img, dtype=np.float32)


def batch(iterable, batch_size):
    """Yields lists by batch"""
    b = []
    for i, t in enumerate(iterable):
        b.append(t)
        if (i + 1) % batch_size == 0:
            yield b
            b = []

    if len(b) > 0:
        yield b


def split_train_val(dataset, val_percent=0.05):
    """Shuffle ``dataset`` and split it into {'train': ..., 'val': ...} lists."""
    dataset = list(dataset)
    length = len(dataset)
    n = int(length * val_percent)
    random.shuffle(dataset)
    # Slice at an explicit index. With n == 0, dataset[:-n] would be
    # dataset[:0], i.e. an empty training set.
    split = length - n
    return {'train': dataset[:split], 'val': dataset[split:]}


def normalize(x):
    """Scale 0-255 pixel values to [0, 1]."""
    return x / 255


def merge_masks(img1, img2, full_w):
    """Stitch two square masks into one (h, full_w) mask.

    Columns [0, full_w//2 + 1) come from the left of ``img1``; the remaining
    columns come from the right end of ``img2``.
    """
    h = img1.shape[0]
    split = full_w // 2 + 1
    # Width of the right part. The old -(full_w // 2 - 1) slice was wrong for
    # odd full_w and became a whole-array slice when full_w == 2.
    right_w = full_w - split

    new = np.zeros((h, full_w), np.float32)
    new[:, :split] = img1[:, :split]
    if right_w > 0:
        new[:, split:] = img2[:, img2.shape[1] - right_w:]

    return new


# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
def rle_encode(mask_image):
    """Run-length encode a mask as alternating 1-based (start, length) values."""
    pixels = mask_image.flatten()
    # We avoid issues with '1' at the start or end (at the corners of
    # the original image) by setting those pixels to '0' explicitly.
    # We do not expect these to be non-zero for an accurate mask,
    # so this should not harm the score.
    pixels[0] = 0
    pixels[-1] = 0
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
    runs[1::2] = runs[1::2] - runs[:-1:2]
    return runs