├── yxk_loss ├── __init__.py ├── loss_test.py └── loss_collection.py ├── image ├── 0_label.png ├── 0_out.png ├── 0_src.jpg ├── 1_label.png ├── 1_out.png ├── 1_src.jpg ├── 4_label.png ├── 4_out.png ├── 4_src.jpg ├── 5_label.png ├── 5_out.png └── 5_src.jpg ├── __pycache__ ├── loss.cpython-36.pyc ├── tools.cpython-36.pyc ├── models.cpython-36.pyc └── voc_loader.cpython-36.pyc ├── .idea ├── vcs.xml ├── misc.xml ├── modules.xml ├── FCN_pytorch.iml └── workspace.xml ├── README.md ├── evaluate.py ├── tensorboard.py ├── predict.py ├── tools.py ├── loss.py ├── voc_loader.py ├── train.py ├── result.txt~ └── models.py /yxk_loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/0_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/0_label.png -------------------------------------------------------------------------------- /image/0_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/0_out.png -------------------------------------------------------------------------------- /image/0_src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/0_src.jpg -------------------------------------------------------------------------------- /image/1_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/1_label.png -------------------------------------------------------------------------------- /image/1_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/1_out.png -------------------------------------------------------------------------------- /image/1_src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/1_src.jpg -------------------------------------------------------------------------------- /image/4_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/4_label.png -------------------------------------------------------------------------------- /image/4_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/4_out.png -------------------------------------------------------------------------------- /image/4_src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/4_src.jpg -------------------------------------------------------------------------------- /image/5_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/5_label.png -------------------------------------------------------------------------------- /image/5_out.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/5_out.png
--------------------------------------------------------------------------------
/image/5_src.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/5_src.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fully Convolutional Networks for Semantic Segmentation

## Data
The Pascal VOC dataset, laid out as:
/home/my_name/data/VOC/VOCdevkit/VOC2012/JPEGImages

## Requirements
pytorch 0.4.1
tensorboardX (visualization)

## Notes
If you build on another project and training makes no progress after many epochs, that project's loss implementation may be broken. Either retrain with one of its other loss functions, or use the loss from this project.
The cross-entropy loss here is verified to work (see the reference snippet below).

## Usage

### train
python train.py

### evaluate
python evaluate.py

### predict
python predict.py


## Visualization
```
cd FCN_pytorch
tensorboard --logdir ./runs
```

## Saving images
Run python predict.py; source images, labels and predictions are saved into the chosen folder for easy inspection.

## Results
After training for 200 epochs:

![avatar](./image/1_src.jpg)

![avatar](./image/1_label.png)

![avatar](./image/1_out.png)
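## Loss reference
For reference, the loss that works is `CrossEntropyLoss2d` from loss.py: log-softmax over the class dimension followed by a 2D negative log-likelihood. A minimal sketch of the same idea (condensed here, with the softmax dimension made explicit):

```
import torch.nn as nn
import torch.nn.functional as F

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.NLLLoss2d(weight, size_average)

    def forward(self, inputs, targets):
        # inputs: (n, c, h, w) raw scores; targets: (n, h, w) class indices
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)
```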
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import models
import voc_loader
import numpy as np
from torch.autograd import Variable
import tools


n_class = 21

def evaluate():
    use_cuda = torch.cuda.is_available()
    path = os.path.expanduser('/home/yxk/data/')
    val_data = voc_loader.VOC2012ClassSeg(root=path,
                                          split='val',
                                          transform=True)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=5)
    print('load model .....')
    vgg_model = models.VGGNet(requires_grad=True)
    fcn_model = models.FCN8s(pretrained_net=vgg_model, n_class=n_class)
    fcn_model.load_state_dict(torch.load('params.pth'))

    if use_cuda:
        fcn_model.cuda()
    fcn_model.eval()

    label_trues, label_preds = [], []
    # for idx, (img, label) in enumerate(val_loader):
    for idx in range(len(val_data)):
        img, label = val_data[idx]
        img = img.unsqueeze(0)
        if use_cuda:
            img = img.cuda()
        img = Variable(img)

        out = fcn_model(img)  # 1, 21, 320, 320

        pred = out.data.max(1)[1].squeeze_(1).squeeze_(0)  # 320, 320

        if use_cuda:
            pred = pred.cpu()
        label_trues.append(label.numpy())
        label_preds.append(pred.numpy())

        if idx % 30 == 0:
            print('evaluate [%d/%d]' % (idx, len(val_data)))

    metrics = tools.accuracy_score(label_trues, label_preds)
    metrics = np.array(metrics)
    metrics *= 100
    print('''\
Accuracy: {0}
Accuracy Class: {1}
Mean IU: {2}
FWAV Accuracy: {3}'''.format(*metrics))


if __name__ == '__main__':
    evaluate()
--------------------------------------------------------------------------------
/tensorboard.py:
--------------------------------------------------------------------------------
import torch
import torchvision.utils as vutils
import numpy as np
import torchvision.models as models
from torchvision import datasets
from tensorboardX import SummaryWriter


resnet18 = models.resnet18(False)
writer = SummaryWriter()
sample_rate = 44100
freqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440]

for n_iter in range(100):
    s1 = torch.rand(1)  # value to keep
    s2 = torch.rand(1)

    writer.add_scalar('data/scalar1', s1[0], n_iter)
    writer.add_scalar('data/scalar2', s2[0], n_iter)
    writer.add_scalars('data/scalar_group', {"xsinx": n_iter * np.sin(n_iter),
                                             "xcosx": n_iter * np.cos(n_iter),
                                             "arctanx": np.arctan(n_iter)}, n_iter)

    x = torch.rand(32, 3, 64, 64)
    if n_iter % 10 == 0:
        x = vutils.make_grid(x, normalize=True, scale_each=True)
        writer.add_image('Image', x, n_iter)

        x = torch.zeros(sample_rate * 2)
        for i in range(x.size(0)):
            x[i] = np.cos(freqs[n_iter // 10] * np.pi * float(i) / float(sample_rate))
        writer.add_audio('myAudio', x, n_iter, sample_rate=sample_rate)

        writer.add_text('Text', 'text logged at step:' + str(n_iter), n_iter)

        for name, param in resnet18.named_parameters():
            writer.add_histogram(name, param.clone().cpu().data.numpy(), n_iter)

        writer.add_pr_curve('xoxo', np.random.randint(2, size=100),
                            np.random.rand(100), n_iter)  # needs tensorboard 0.4RC or later

dataset = datasets.MNIST('mnist', train=False, download=True)
images = dataset.test_data[:100].float()
label = dataset.test_labels[:100]
features = images.view(100, 784)
writer.add_embedding(features, metadata=label, label_img=images.unsqueeze(1))

# export scalar data to JSON for external processing
writer.export_scalars_to_json("./all_scalars.json")
writer.close()
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import cv2
import numpy as np
from torch.autograd import Variable
import torchvision.transforms as transforms
import voc_loader
import models
import random
import tools
from loss import CrossEntropyLoss2d

n_class = 21

def main():
    use_cuda = torch.cuda.is_available()
    path = os.path.expanduser('/home/yxk/data/')

    dataset = voc_loader.VOC2012ClassSeg(root=path,
                                         split='train',
                                         transform=True)

    vgg_model = models.VGGNet(requires_grad=True)
    fcn_model = models.FCN8s(pretrained_net=vgg_model, n_class=n_class)
    fcn_model.load_state_dict(torch.load('./pretrained_models/model120.pth', map_location='cpu'))

    fcn_model.eval()

    if use_cuda:
        fcn_model.cuda()

    criterion = CrossEntropyLoss2d()

    for i in range(len(dataset)):
        idx = random.randrange(0, len(dataset))
        img, label = dataset[idx]
        img_name = str(i)

        img_src, _ = dataset.untransform(img, label)  # whc

        cv2.imwrite(path + 'image/%s_src.jpg' % img_name, img_src)
        tools.labelTopng(label, path + 'image/%s_label.png' % img_name)  # save the ground-truth label as a png

        # a = tools.labelToimg(label)
        #
        # print(a)

        if use_cuda:
            img = img.cuda()
            label = label.cuda()
        img = Variable(img.unsqueeze(0), volatile=True)
        label = Variable(label.unsqueeze(0), volatile=True)
        # print("label: ", label.data)

        out = fcn_model(img)  # (1, 21, 320, 320)
        loss = criterion(out, label)
        # print(img_name, 'loss:', loss.data[0])

        net_out = out.data.max(1)[1].squeeze_(0)  # 320, 320: argmax over the 21 class scores
        # print(out.data.max(1)[1].shape)
        # print("out", net_out)
        if use_cuda:
            net_out = net_out.cpu()

        tools.labelTopng(net_out, path + 'image/%s_out.png' % img_name)  # save the network prediction as a png

        if i == 10:
            break


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/yxk_loss/loss_test.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable

'''
Reference: https://blog.csdn.net/zhangxb35/article/details/72464152?utm_source=itdadao&utm_medium=referral
If reduce = False, size_average is ignored and the loss is returned element-wise;
if reduce = True, the loss is reduced to a scalar:
    if size_average = True, return loss.mean();
    if size_average = False, return loss.sum().
'''
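# A quick numerical check of the table above (a sketch, not from the cited
# reference): with a fixed 3x4 input the reduced losses relate to the
# element-wise loss exactly as documented.
if False:
    input = torch.zeros(3, 4)
    target = torch.ones(3, 4)
    per_elem = torch.nn.L1Loss(reduce=False)(input, target)                        # 3x4 matrix of 1.0
    mean_loss = torch.nn.L1Loss(reduce=True, size_average=True)(input, target)     # 1.0
    sum_loss = torch.nn.L1Loss(reduce=True, size_average=False)(input, target)     # 12.0
    print(per_elem.sum(), mean_loss, sum_loss)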
# nn.L1Loss: loss(input, target)=|input-target|
if False:
    loss_fn = torch.nn.L1Loss(reduce=True, size_average=False)
    input = torch.autograd.Variable(torch.randn(3, 4))
    target = torch.autograd.Variable(torch.randn(3, 4))
    loss = loss_fn(input, target)
    print(input)
    print(target)
    print(loss)
    print(input.size(), target.size(), loss.size())


# nn.SmoothL1Loss: squared loss on (-1, 1), L1 loss elsewhere
if False:
    loss_fn = torch.nn.SmoothL1Loss(reduce=False, size_average=False)
    input = torch.autograd.Variable(torch.randn(3, 4))
    target = torch.autograd.Variable(torch.randn(3, 4))
    loss = loss_fn(input, target)
    print(input)
    print(target)
    print(loss)
    print(input.size(), target.size(), loss.size())

# nn.MSELoss: mean squared error
if False:
    loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)
    input = torch.autograd.Variable(torch.randn(3, 4))
    target = torch.autograd.Variable(torch.randn(3, 4))
    loss = loss_fn(input, target)
    print(input)
    print(target)
    print(loss)
    print(input.size(), target.size(), loss.size())

# nn.BCELoss
if False:
    import torch.nn.functional as F

    loss_fn = torch.nn.BCELoss(reduce=False, size_average=False)
    input = torch.autograd.Variable(torch.randn(3, 4))
    target = torch.autograd.Variable(torch.FloatTensor(3, 4).random_(2))
    loss = loss_fn(F.sigmoid(input), target)
    print(input, input.shape)
    print(F.sigmoid(input))
    print(target, target.shape)
    print(loss, loss.shape)

# nn.CrossEntropyLoss
if False:
    weight = torch.Tensor([1, 2, 1, 1, 10])
    loss_fn = torch.nn.CrossEntropyLoss(reduce=False, size_average=False, weight=None)
    input = Variable(torch.randn(3, 5))  # (batch_size, C)
    target = Variable(torch.LongTensor(3).random_(5))
    loss = loss_fn(input, target)
    print(input)
    print(target)
    print(loss)
--------------------------------------------------------------------------------
/tools.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import PIL.Image as Image

def getPalette():
    '''
    http://blog.csdn.net/yhl_leo/article/details/52185581
    '''
    pal = np.array([[0, 0, 0],
                    [128, 0, 0],
                    [0, 128, 0],
                    [128, 128, 0],
                    [0, 0, 128],
                    [128, 0, 128],
                    [0, 128, 128],
                    [128, 128, 128],
                    [64, 0, 0],
                    [192, 0, 0],
                    [64, 128, 0],
                    [192, 128, 0],
                    [64, 0, 128],
                    [192, 0, 128],
                    [64, 128, 128],
                    [192, 128, 128],
                    [0, 64, 0],
                    [128, 64, 0],
                    [0, 192, 0],
                    [128, 192, 0],
                    [0, 64, 128]], dtype='uint8').flatten()
    return pal


def colorize_mask(mask):
    """
    :param mask: a 2-D array of per-pixel class indices; each index maps to a palette color
    :return:
    """
    new_mask = Image.fromarray(mask.astype(np.uint8), 'P')  # convert the 2-D array to a palette ('P' mode) image
    pal = getPalette()
    new_mask.putpalette(pal)
    # print(new_mask.show())
    return new_mask

# m = np.array([[1,2], [3,4]])
# colorize_mask(m)


def getFileName(file_path):
    '''
    get file_path name from path+name+'test.jpg'
    return test
    '''
    full_name = file_path.split('/')[-1]
    name = os.path.splitext(full_name)[0]

    return name


def labelTopng(label, img_name):
    '''
    convert tensor cpu label to png and save
    '''
    label = label.numpy()  # 320 320
    label_pil = colorize_mask(label)
    label_pil.save(img_name)
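# Usage sketch (an illustration; the file name is hypothetical): a toy 2x2
# label map saved as a palette PNG, each pixel colored per getPalette().
#
# toy = torch.LongTensor([[0, 1], [2, 3]])   # requires a cpu LongTensor
# labelTopng(toy, 'toy_labels.png')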
def labelToimg(label):
    label = label.numpy()
    label_pil = colorize_mask(label)
    return label_pil


def _fast_hist(label_true, label_pred, n_class):
    # confusion matrix via bincount: bin index = n_class * true_class + pred_class
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist


def accuracy_score(label_trues, label_preds, n_class=21):
    """Returns accuracy score evaluation result.
      - overall accuracy
      - mean accuracy
      - mean IU
      - fwavacc
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)  # n_class, n_class
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
    return acc, acc_cls, mean_iu, fwavacc
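if __name__ == '__main__':
    # Toy sanity check (a sketch, not part of the original tools): two 2x2
    # maps, 2 classes. hist = [[2, 1], [0, 1]], so acc = 3/4,
    # IU(class0) = 2/3, IU(class1) = 1/2, mean IU = 7/12.
    t = np.array([[0, 0], [0, 1]])
    p = np.array([[0, 0], [1, 1]])
    acc, acc_cls, mean_iu, fwavacc = accuracy_score([t], [p], n_class=2)
    print(acc, acc_cls, mean_iu, fwavacc)  # 0.75, ~0.833, ~0.583, 0.625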
--------------------------------------------------------------------------------
/yxk_loss/loss_collection.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable

'''
If reduce = False, size_average is ignored and the loss is returned element-wise;
if reduce = True, the loss is reduced to a scalar:
    if size_average = True, return loss.mean();
    if size_average = False, return loss.sum().
'''


class CrossEntropyLoss(nn.Module):
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss(weight=None, size_average=False)

    def forward(self, inputs, targets):
        return self.cross_entropy_loss(inputs, targets)

class CrossEntropyLoss2d(nn.Module):
    """
    Negative Log Likelihood
    """

    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.NLLLoss2d(weight, size_average)

    def forward(self, inputs, targets):
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)

def smooth_l1(deltas, targets, sigma=3.0):
    """
    :param deltas: (tensor) predictions, sized [N,D].
    :param targets: (tensor) targets, sized [N,].
    :param sigma: 3.0
    :return:
    """

    sigma2 = sigma * sigma
    diffs = deltas - targets
    # 1 where |diff| < 1/sigma^2 (the quadratic zone), 0 elsewhere
    smooth_l1_signs = (torch.abs(diffs) < 1.0 / sigma2).detach().float()

    smooth_l1_option1 = torch.mul(diffs, diffs) * 0.5 * sigma2
    smooth_l1_option2 = torch.abs(diffs) - 0.5 / sigma2
    smooth_l1_add = torch.mul(smooth_l1_option1, smooth_l1_signs) + \
                    torch.mul(smooth_l1_option2, 1 - smooth_l1_signs)
    smooth_l1 = smooth_l1_add

    return smooth_l1

class FocalLoss(nn.Module):

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1-alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else:
            return loss.sum()
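if __name__ == '__main__':
    # Smoke test (a sketch, not part of the original file): with gamma=0 and no
    # alpha, FocalLoss reduces to plain cross entropy, so the two printed
    # values should match closely.
    input = Variable(torch.randn(4, 5))              # (batch, classes)
    target = Variable(torch.LongTensor(4).random_(5))
    fl = FocalLoss(gamma=0)(input, target)
    ce = nn.CrossEntropyLoss()(input, target)
    print(fl.item(), ce.item())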
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable


class CrossEntropy2d(nn.Module):
    '''
    Warning: this implementation is problematic and wasted a lot of time here.
    The loss barely changes -- can it not be backpropagated properly?
    (Why would anything else need to change? Only the net weights should change.)
    '''
    def __init__(self):
        super(CrossEntropy2d, self).__init__()
        self.criterion = nn.CrossEntropyLoss(weight=None, size_average=False)  # should size_average=False? True averages, False totals

    def forward(self, out, target):
        n, c, h, w = out.size()  # n:batch_size, c:class
        out = out.view(-1, c)  # (n*h*w, c)
        target = target.view(-1)  # (n*h*w)
        # print('out', out.size(), 'target', target.size())
        loss = self.criterion(out, target)

        return loss

class CrossEntropyLoss(nn.Module):
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss(weight=None, size_average=False)

    def forward(self, inputs, targets):
        return self.cross_entropy_loss(inputs, targets)

class CrossEntropyLoss2d(nn.Module):
    """
    Verified to work.
    """

    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.NLLLoss2d(weight, size_average)

    def forward(self, inputs, targets):
        return self.nll_loss(F.log_softmax(inputs, dim=1), targets)


def smooth_l1(deltas, targets, sigma=3.0):
    """
    :param deltas: (tensor) predictions, sized [N,D].
    :param targets: (tensor) targets, sized [N,].
    :param sigma: 3.0
    :return:
    """

    sigma2 = sigma * sigma
    diffs = deltas - targets
    # 1 where |diff| < 1/sigma^2 (the quadratic zone), 0 elsewhere
    smooth_l1_signs = (torch.abs(diffs) < 1.0 / sigma2).detach().float()

    smooth_l1_option1 = torch.mul(diffs, diffs) * 0.5 * sigma2
    smooth_l1_option2 = torch.abs(diffs) - 0.5 / sigma2
    smooth_l1_add = torch.mul(smooth_l1_option1, smooth_l1_signs) + \
                    torch.mul(smooth_l1_option2, 1 - smooth_l1_signs)
    smooth_l1 = smooth_l1_add

    return smooth_l1


class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1-alpha])
        if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else:
            return loss.sum()


if __name__ == '__main__':
    loss_fn = torch.nn.CrossEntropyLoss(reduce=False, size_average=False, weight=None)
    input = Variable(torch.randn(2, 3, 5))  # (batch_size, C, d1)
    target = Variable(torch.LongTensor(2, 5).random_(3))
    loss = loss_fn(input, target)
    print(input)
    print(target)
    print(loss)
--------------------------------------------------------------------------------
/voc_loader.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import collections
import os.path as osp

import numpy as np
import PIL.Image
import scipy.io
import torch
from torch.utils import data
import cv2
import random


"""
https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/datasets/voc.py
"""


class VOCClassSegBase(data.Dataset):

    class_names = np.array([
        'background',
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
        'dog',
        'horse',
        'motorbike',
        'person',
        'potted plant',
        'sheep',
        'sofa',
        'train',
        'tv/monitor',
    ])
    mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])


    def __init__(self, root, split='train', transform=True):
        self.root = root
        self.split = split
        self._transform = transform

        # VOC2011 and others are subset of VOC2012
        dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012')
        # dataset_dir = osp.join(self.root, 'VOC2007')

        self.files = collections.defaultdict(list)
        for split_file in ['train', 'val']:
            imgsets_file = osp.join(
                dataset_dir, 'ImageSets/Segmentation/%s.txt' % split_file)
            for img_name in open(imgsets_file):
                img_name = img_name.strip()
                img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % img_name)
                lbl_file = osp.join(dataset_dir, 'SegmentationClass/%s.png' % img_name)
                self.files[split_file].append({
                    'img': img_file,
                    'lbl': lbl_file,
                })

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        data_file = self.files[self.split][index]  # one {'img', 'lbl'} record
        # load image
        img_file = data_file['img']
        img = PIL.Image.open(img_file)
        img = np.array(img, dtype=np.uint8)
        # load label
        lbl_file = data_file['lbl']
        lbl = PIL.Image.open(lbl_file)
        lbl = np.array(lbl, dtype=np.uint8)

        lbl[lbl == 255] = 0  # fold the 255 'void'/boundary pixels into background
        # augment
        img, lbl = self.randomFlip(img, lbl)
        img, lbl = self.randomCrop(img, lbl)
        img, lbl = self.resize(img, lbl)

        if self._transform:
            return self.transform(img, lbl)
        else:
            return img, lbl


    def transform(self, img, lbl):
        img = img[:, :, ::-1]  # RGB -> BGR
        img = img.astype(np.float64)
        img -= self.mean_bgr
        img = img.transpose(2, 0, 1)  # whc -> cwh
        img = torch.from_numpy(img).float()
        lbl = torch.from_numpy(lbl).long()
        return img, lbl

    def untransform(self, img, lbl):
        img = img.numpy()
        img = img.transpose(1, 2, 0)  # cwh -> whc
        img += self.mean_bgr
        img = img.astype(np.uint8)
        img = img[:, :, ::-1]  # BGR -> RGB
        lbl = lbl.numpy()
        return img, lbl

    def randomFlip(self, img, label):
        if random.random() < 0.5:
            img = np.fliplr(img)
            label = np.fliplr(label)
        return img, label

    def resize(self, img, label, s=320):
        # print(s, img.shape)
        img = cv2.resize(img, (s, s), interpolation=cv2.INTER_LINEAR)
        label = cv2.resize(label, (s, s), interpolation=cv2.INTER_NEAREST)
        return img, label

    def randomCrop(self, img, label):
        h, w, _ = img.shape
        short_size = min(w, h)
        rand_size = random.randrange(int(0.7 * short_size), short_size)
        x = random.randrange(0, w - rand_size)
        y = random.randrange(0, h - rand_size)

        return img[y:y + rand_size, x:x + rand_size], label[y:y + rand_size, x:x + rand_size]

    # data augmentation
    def augmentation(self, img, lbl):
        img, lbl = self.randomFlip(img, lbl)
        img, lbl = self.randomCrop(img, lbl)
        img, lbl = self.resize(img, lbl)
        return img, lbl

    # elif not self.predict:  # for batch test, this is needed
    #     img, label = self.randomCrop(img, label)
    #     img, label = self.resize(img, label, VOCClassSeg.img_size)
    # else:
    #     pass


class VOC2012ClassSeg(VOCClassSegBase):

    # url = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar'  # NOQA

    def __init__(self, root, split='train', transform=False):
        super(VOC2012ClassSeg, self).__init__(
            root, split=split, transform=transform)


"""
vocbase = VOC2012ClassSeg(root="/home/yxk/Downloads/")

print(vocbase.__len__())
img, lbl = vocbase.__getitem__(0)
img = img[:, :, ::-1]
img = cv2.resize(img, (320, 320), interpolation=cv2.INTER_LINEAR)
print(np.shape(img))
print(np.shape(lbl))

"""
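# A minimal smoke test of the full pipeline (a sketch; the root path is
# hypothetical and assumes the layout described in the README):
#
# train_data = VOC2012ClassSeg(root='/home/my_name/data/', split='train', transform=True)
# loader = data.DataLoader(train_data, batch_size=2, shuffle=True, num_workers=2)
# imgs, lbls = next(iter(loader))
# print(imgs.size(), lbls.size())  # torch.Size([2, 3, 320, 320]) torch.Size([2, 320, 320])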
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data.dataloader
import numpy as np
import torchvision
import models
import voc_loader
import loss
from torch.optim import Adam, SGD
from tensorboardX import SummaryWriter
from argparse import ArgumentParser
import tools


# argument parsing
parser = ArgumentParser()
parser.add_argument('-bs', '--batch_size', type=int, default=2, help="batch size of the data")
parser.add_argument('-e', '--epochs', type=int, default=300, help='epoch of the train')
parser.add_argument('-c', '--n_class', type=int, default=21, help='the classes of the dataset')
parser.add_argument('-lr', '--learning_rate', type=float, default=1e-3, help='learning rate')
args = parser.parse_args()

# visualization writer
writer = SummaryWriter()

batch_size = args.batch_size
learning_rate = args.learning_rate
epoch_num = args.epochs
n_class = args.n_class


best_test_loss = np.inf
pretrained = 'reload'
use_cuda = torch.cuda.is_available()

# path = os.path.expanduser('/home/yxk/Downloads/')

# dataset 2007
data_path = os.path.expanduser('/home/yxk/data/')

print('load data....')
train_data = voc_loader.VOC2012ClassSeg(root=data_path, split='train', transform=True)

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=5)
val_data = voc_loader.VOC2012ClassSeg(root=data_path,
                                      split='val',
                                      transform=True)
val_loader = torch.utils.data.DataLoader(val_data,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=5)

vgg_model = models.VGGNet(requires_grad=True)
fcn_model = models.FCN8s(pretrained_net=vgg_model, n_class=n_class)

if use_cuda:
    fcn_model.cuda()

criterion = loss.CrossEntropyLoss2d()
# create your optimizer
optimizer = Adam(fcn_model.parameters())
# optimizer = torch.optim.SGD(fcn_model.parameters(), lr=0.01)
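# Optional (a sketch, not used by default): CrossEntropyLoss2d accepts per-class
# weights, which can counter VOC's background-heavy class distribution. The 0.5
# value below is purely illustrative.
#
# class_weights = torch.ones(n_class)
# class_weights[0] = 0.5  # down-weight class 0 (background)
# criterion = loss.CrossEntropyLoss2d(weight=class_weights.cuda() if use_cuda else class_weights)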
def train(epoch):
    fcn_model.train()  # train mode
    total_loss = 0.
    for batch_idx, (imgs, labels) in enumerate(train_loader):
        N = imgs.size(0)
        if use_cuda:
            imgs = imgs.cuda()
            labels = labels.cuda()

        imgs_tensor = Variable(imgs)      # torch.Size([2, 3, 320, 320])
        labels_tensor = Variable(labels)  # torch.Size([2, 320, 320])
        out = fcn_model(imgs_tensor)      # torch.Size([2, 21, 320, 320])

        # with open('./result.txt', 'r+') as f:
        #     f.write(str(out.detach().numpy()))
        #     f.write("\n")

        loss = criterion(out, labels_tensor)
        loss /= N
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # update all arguments
        total_loss += loss.item()  # a Python float

        # if batch_idx == 2:
        #     break

        if (batch_idx) % 20 == 0:
            print('train epoch [%d/%d], iter[%d/%d], lr %.7f, aver_loss %.5f' % (epoch,
                  epoch_num, batch_idx,
                  len(train_loader), learning_rate,
                  total_loss / (batch_idx + 1)))

        # # visualize scalars
        # if epoch % 10 == 0:
        #     label_img = tools.labelToimg(labels[0])
        #     net_out = out[0].data.max(1)[1].squeeze_(0)
        #     out_img = tools.labelToimg(net_out)
        #     writer.add_scalar("loss", loss, epoch)
        #     writer.add_scalar("total_loss", total_loss, epoch)
        #     writer.add_scalars('loss/scalar_group', {"loss": epoch * loss,
        #                                              "total_loss": epoch * total_loss})
        #     writer.add_image('Image', imgs[0], epoch)
        #     writer.add_image('label', label_img, epoch)
        #     writer.add_image("out", out_img, epoch)

    assert not np.isnan(total_loss)
    assert not np.isinf(total_loss)

    # model save
    if (epoch) % 20 == 0:
        torch.save(fcn_model.state_dict(), './pretrained_models/model%d.pth' % epoch)  # checkpoint every 20 epochs
    total_loss /= len(train_loader)
    print('train epoch [%d/%d] average_loss %.5f' % (epoch, epoch_num, total_loss))


def test(epoch):
    fcn_model.eval()
    total_loss = 0.
    for batch_idx, (imgs, labels) in enumerate(val_loader):
        N = imgs.size(0)
        if use_cuda:
            imgs = imgs.cuda()
            labels = labels.cuda()
        imgs = Variable(imgs)  # , volatile=True
        labels = Variable(labels)  # , volatile=True
        out = fcn_model(imgs)
        loss = criterion(out, labels)
        loss /= N
        total_loss += loss.item()

        if (batch_idx + 1) % 3 == 0:
            print('test epoch [%d/%d], iter[%d/%d], aver_loss %.5f' % (epoch,
                  epoch_num, batch_idx, len(val_loader),
                  total_loss / (batch_idx + 1)))

    total_loss /= len(val_loader)
    print('test epoch [%d/%d] average_loss %.5f' % (epoch, epoch_num, total_loss))

    global best_test_loss
    if best_test_loss > total_loss:
        best_test_loss = total_loss
        print('best loss....')
        # fcn_model.save('SBD.pth')


if __name__ == '__main__':
    # print(torch.cuda.is_available())
    for epoch in range(epoch_num):
        train(epoch)
        # test(epoch)
        # adjust learning rate
        if epoch == 20:
            learning_rate *= 0.01
            optimizer.param_groups[0]['lr'] = learning_rate
            # optimizer.param_groups[1]['lr'] = learning_rate * 2
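# To resume from a checkpoint instead of training from scratch (a sketch; the
# file name follows the pattern predict.py loads):
#
# fcn_model.load_state_dict(torch.load('./pretrained_models/model120.pth'))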
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------

from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.models.vgg import VGG
import torch.nn.functional as F
from torch.nn import init
import numpy as np

"""

"""


def get_upsample_weight(in_channels, out_channels, kernel_size):
    '''
    make a 2D bilinear kernel suitable for upsampling
    '''
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]  # list (64 x 1), (1 x 64)
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)  # 64 x 64
    weight = np.zeros((in_channels, out_channels, kernel_size,
                       kernel_size), dtype=np.float64)
    weight[range(in_channels), range(out_channels), :, :] = filt

    return torch.from_numpy(weight).float()
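# Worked example (a sketch): for kernel_size=4, factor=(4+1)//2=2 and
# center=1.5, the 1-D profile 1-|i-1.5|/2 is [0.25, 0.75, 0.75, 0.25], and the
# 2-D kernel is its outer product:
#
# >>> get_upsample_weight(1, 1, 4)[0, 0][0]   # first row
# 0.0625, 0.1875, 0.1875, 0.0625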
class FCN32s(nn.Module):
    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net.forward(x)
        x5 = output['x5']  # size=[n, 512, x.h/32, x.w/32]
        score = self.bn1(self.relu(self.deconv1(x5)))     # size=[n, 512, x.h/16, x.w/16]
        score = self.bn2(self.relu(self.deconv2(score)))  # size=[n, 256, x.h/8, x.w/8]
        score = self.bn3(self.relu(self.deconv3(score)))  # size=[n, 128, x.h/4, x.w/4]
        score = self.bn4(self.relu(self.deconv4(score)))  # size=[n, 64, x.h/2, x.w/2]
        score = self.bn5(self.relu(self.deconv5(score)))  # size=[n, 32, x.h, x.w]
        score = self.classifier(score)                    # size=[n, n_class, x.h, x.w]

        return score


class FCN16s(nn.Module):
    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net.forward(x)
        x5 = output['x5']  # size=[n, 512, x.h/32, x.w/32]
        x4 = output['x4']  # size=[n, 512, x.h/16, x.w/16]

        score = self.relu(self.deconv1(x5))               # size=[n, 512, x.h/16, x.w/16]
        score = self.bn1(score + x4)                      # element-wise add, size=[n, 512, x.h/16, x.w/16]
        score = self.bn2(self.relu(self.deconv2(score)))  # size=[n, 256, x.h/8, x.w/8]
        score = self.bn3(self.relu(self.deconv3(score)))  # size=[n, 128, x.h/4, x.w/4]
        score = self.bn4(self.relu(self.deconv4(score)))  # size=[n, 64, x.h/2, x.w/2]
        score = self.bn5(self.relu(self.deconv5(score)))  # size=[n, 32, x.h, x.w]
        score = self.classifier(score)                    # size=[n, n_class, x.h, x.w]

        return score
class FCN8s(nn.Module):
    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

        # self._init_weights()  # bilinear-kernel alternative, defined below
        init.xavier_uniform_(self.deconv1.weight)
        init.xavier_uniform_(self.deconv2.weight)
        init.xavier_uniform_(self.deconv3.weight)
        init.xavier_uniform_(self.deconv4.weight)
        init.xavier_uniform_(self.deconv5.weight)
        init.xavier_uniform_(self.classifier.weight)

    def forward(self, x):
        output = self.pretrained_net.forward(x)
        x5 = output['x5']  # size=[n, 512, x.h/32, x.w/32]
        x4 = output['x4']  # size=[n, 512, x.h/16, x.w/16]
        x3 = output['x3']  # size=[n, 256, x.h/8, x.w/8]

        score = self.relu(self.deconv1(x5))     # size=[n, 512, x.h/16, x.w/16]
        score = self.bn1(score + x4)            # element-wise add, size=[n, 512, x.h/16, x.w/16]
        score = self.relu(self.deconv2(score))  # size=[n, 256, x.h/8, x.w/8]
        score = self.bn2(score + x3)
        score = self.bn3(self.relu(self.deconv3(score)))  # size=[n, 128, x.h/4, x.w/4]
        score = self.bn4(self.relu(self.deconv4(score)))  # size=[n, 64, x.h/2, x.w/2]
        score = self.bn5(self.relu(self.deconv5(score)))  # size=[n, 32, x.h, x.w]
        score = self.classifier(score)                    # size=[n, n_class, x.h, x.w]

        return score

    def _init_weights(self):
        '''
        private helper, only used inside this class
        '''
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.zero_()
                # if m.bias is not None:
                m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                initial_weight = get_upsample_weight(m.in_channels,
                                                     m.out_channels, m.kernel_size[0])
                m.weight.data.copy_(initial_weight)  # copy_ keeps the Parameter; plain '=' would replace it
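# Shape walk-through for FCN8s with a 320x320 input (derived from the size
# comments above):
#   x5 (n, 512, 10, 10) --deconv1--> (n, 512, 20, 20), fused with x4
#   --deconv2--> (n, 256, 40, 40), fused with x3
#   --deconv3--> (n, 128, 80, 80) --deconv4--> (n, 64, 160, 160)
#   --deconv5--> (n, 32, 320, 320) --classifier--> (n, n_class, 320, 320)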
class FCN1s(nn.Module):
    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        output = self.pretrained_net.forward(x)
        x5 = output['x5']  # size=[n, 512, x.h/32, x.w/32]
        x4 = output['x4']  # size=[n, 512, x.h/16, x.w/16]
        x3 = output['x3']  # size=[n, 256, x.h/8, x.w/8]
        x2 = output['x2']  # size=[n, 128, x.h/4, x.w/4]
        x1 = output['x1']  # size=[n, 64, x.h/2, x.w/2]

        score = self.relu(self.deconv1(x5))     # size=[n, 512, x.h/16, x.w/16]
        score = self.bn1(score + x4)            # element-wise add, size=[n, 512, x.h/16, x.w/16]
        score = self.relu(self.deconv2(score))  # size=[n, 256, x.h/8, x.w/8]
        score = self.bn2(score + x3)
        score = self.relu(self.deconv3(score))  # size=[n, 128, x.h/4, x.w/4]
        score = self.bn3(score + x2)
        score = self.relu(self.deconv4(score))  # size=[n, 64, x.h/2, x.w/2]
        score = self.bn4(score + x1)
        score = self.bn5(self.relu(self.deconv5(score)))  # size=[n, 32, x.h, x.w]
        score = self.classifier(score)                    # size=[n, n_class, x.h, x.w]

        return score

ranges = {
    'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)),
    'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
    'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
    'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
}

# cropped version from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
cfg = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def make_layers(cfg, batch_norm=False):
    """
    :param cfg: e.g. cfg['vgg16'] = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'];
                numbers are 3x3 conv layers (output channels), 'M' is a max-pooling stage
    :param batch_norm: insert BatchNorm after each conv
    :return: nn.Sequential of the feature layers
    """
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v

    return nn.Sequential(*layers)


class VGGNet(VGG):
    def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
        super().__init__(make_layers(cfg[model]))
        self.ranges = ranges[model]

        if pretrained:
            vgg16 = models.vgg16(pretrained=False)
            vgg16.load_state_dict(torch.load('/home/yxk/.torch/models/vgg16-397923af.pth'))
            self.load_state_dict(vgg16.state_dict())  # copy the pretrained weights into this net
            # exec("self.load_state_dict(models.%s(pretrained=True).state_dict())" % model)

        if not requires_grad:
            for param in super().parameters():
                param.requires_grad = False

        if remove_fc:  # delete redundant fully-connected layer params, can save memory
            del self.classifier

        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, x):
        output = {}
        # get the output of each maxpooling layer (5 maxpool in VGG net)
        for idx in range(len(self.ranges)):
            for layer in range(self.ranges[idx][0], self.ranges[idx][1]):
                x = self.features[layer](x)
            output["x%d" % (idx+1)] = x
        return output
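# Consistency check (a note, not in the original): make_layers(cfg['vgg16'])
# yields 31 modules (13 conv + 13 ReLU + 5 max-pool), matching
# ranges['vgg16'][-1][1] == 31, so each slice in `ranges` ends exactly at one
# of the five pooling stages:
#
# assert len(make_layers(cfg['vgg16'])) == ranges['vgg16'][-1][1] == 31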

# other models
class UNetEnc(nn.Module):
    '''U-Net up-sampling block (the expansive path). Note the naming in this
    file is inverted: UNetEnc up-samples and UNetDec down-samples.'''

    def __init__(self, in_channels, features, out_channels):
        super().__init__()

        self.up = nn.Sequential(
            nn.Conv2d(in_channels, features, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, 3),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(features, out_channels, 2, stride=2),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.up(x)


class UNetDec(nn.Module):
    '''U-Net down-sampling block (the contracting path).'''

    def __init__(self, in_channels, out_channels, dropout=False):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers += [nn.Dropout(.5)]
        layers += [nn.MaxPool2d(2, stride=2, ceil_mode=True)]

        self.down = nn.Sequential(*layers)

    def forward(self, x):
        return self.down(x)


class UNet(nn.Module):

    def __init__(self, num_classes):
        super().__init__()

        self.dec1 = UNetDec(3, 64)
        self.dec2 = UNetDec(64, 128)
        self.dec3 = UNetDec(128, 256)
        self.dec4 = UNetDec(256, 512, dropout=True)
        self.center = nn.Sequential(
            nn.Conv2d(512, 1024, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.ConvTranspose2d(1024, 512, 2, stride=2),
            nn.ReLU(inplace=True),
        )
        self.enc4 = UNetEnc(1024, 512, 256)
        self.enc3 = UNetEnc(512, 256, 128)
        self.enc2 = UNetEnc(256, 128, 64)
        self.enc1 = nn.Sequential(
            nn.Conv2d(128, 64, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3),
            nn.ReLU(inplace=True),
        )
        self.final = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        dec1 = self.dec1(x)
        dec2 = self.dec2(dec1)
        dec3 = self.dec3(dec2)
        dec4 = self.dec4(dec3)
        center = self.center(dec4)
        # the 3x3 convs are unpadded, so each skip tensor is resized to match
        # the up-sampled features before concatenation
        enc4 = self.enc4(torch.cat([
            center, F.upsample_bilinear(dec4, center.size()[2:])], 1))
        enc3 = self.enc3(torch.cat([
            enc4, F.upsample_bilinear(dec3, enc4.size()[2:])], 1))
        enc2 = self.enc2(torch.cat([
            enc3, F.upsample_bilinear(dec2, enc3.size()[2:])], 1))
        enc1 = self.enc1(torch.cat([
            enc2, F.upsample_bilinear(dec1, enc2.size()[2:])], 1))

        return F.upsample_bilinear(self.final(enc1), x.size()[2:])
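
# Illustrative smoke test (not in the original file): the unpadded convs shrink
# the intermediate maps, but the final upsample_bilinear snaps the prediction
# back to the input resolution; the input just has to be large enough to
# survive the shrinking.
def _unet_smoke_test():
    unet = UNet(num_classes=21)  # 21 = Pascal VOC classes incl. background
    out = unet(torch.randn(1, 3, 256, 256))
    print(tuple(out.shape))  # -> (1, 21, 256, 256)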

class SegNetEnc(nn.Module):

    def __init__(self, in_channels, out_channels, num_layers):
        super().__init__()

        layers = [
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(in_channels, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
        ]
        # NOTE: list repetition reuses the *same* module objects, which would
        # share weights for num_layers > 1; fine here since only 0 or 1 is used
        layers += [
            nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
            nn.BatchNorm2d(in_channels // 2),
            nn.ReLU(inplace=True),
        ] * num_layers
        layers += [
            nn.Conv2d(in_channels // 2, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        return self.encode(x)


class SegNet(nn.Module):

    def __init__(self, num_classes):
        super().__init__()

        # should be vgg16bn but at the moment we have no pretrained bn models
        decoders = list(models.vgg16(pretrained=True).features.children())

        self.dec1 = nn.Sequential(*decoders[:5])
        self.dec2 = nn.Sequential(*decoders[5:10])
        self.dec3 = nn.Sequential(*decoders[10:17])
        self.dec4 = nn.Sequential(*decoders[17:24])
        self.dec5 = nn.Sequential(*decoders[24:])

        # freezing the pretrained VGG convs gives better results; note that
        # requires_grad must be set on the parameters, not on the module itself
        # (the original set m.requires_grad, which has no effect)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                for param in m.parameters():
                    param.requires_grad = False

        self.enc5 = SegNetEnc(512, 512, 1)
        self.enc4 = SegNetEnc(1024, 256, 1)
        self.enc3 = SegNetEnc(512, 128, 1)
        self.enc2 = SegNetEnc(256, 64, 0)
        self.enc1 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.final = nn.Conv2d(64, num_classes, 3, padding=1)

    def forward(self, x):
        dec1 = self.dec1(x)
        dec2 = self.dec2(dec1)
        dec3 = self.dec3(dec2)
        dec4 = self.dec4(dec3)
        dec5 = self.dec5(dec4)
        enc5 = self.enc5(dec5)
        enc4 = self.enc4(torch.cat([dec4, enc5], 1))
        enc3 = self.enc3(torch.cat([dec3, enc4], 1))
        enc2 = self.enc2(torch.cat([dec2, enc3], 1))
        enc1 = self.enc1(torch.cat([dec1, enc2], 1))

        return F.upsample_bilinear(self.final(enc1), x.size()[2:])


class PSPDec(nn.Module):

    def __init__(self, in_features, out_features, downsize, upsize=60):
        super().__init__()

        self.features = nn.Sequential(
            nn.AvgPool2d(downsize, stride=downsize),
            nn.Conv2d(in_features, out_features, 1, bias=False),
            nn.BatchNorm2d(out_features, momentum=.95),
            nn.ReLU(inplace=True),
            nn.UpsamplingBilinear2d(upsize)
        )

    def forward(self, x):
        return self.features(x)


class PSPNet(nn.Module):

    def __init__(self, num_classes):
        super().__init__()

        '''
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128, momentum=.95),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        '''

        resnet = models.resnet101(pretrained=True)

        # NOTE: resnet's bn1/relu/maxpool stem layers are not used here
        # (behavior kept from the original)
        self.conv1 = resnet.conv1
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # remove all down-sampling strides and freeze the backbone; as above,
        # requires_grad must be set per parameter, not on the module
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.stride = 1
                for param in m.parameters():
                    param.requires_grad = False
            if isinstance(m, nn.BatchNorm2d):
                for param in m.parameters():
                    param.requires_grad = False

        # pyramid pooling: on an assumed 60x60 feature map these pool sizes
        # produce the classic 1x1 / 2x2 / 3x3 / 6x6 bins
        self.layer5a = PSPDec(2048, 512, 60)
        self.layer5b = PSPDec(2048, 512, 30)
        self.layer5c = PSPDec(2048, 512, 20)
        self.layer5d = PSPDec(2048, 512, 10)

        self.final = nn.Sequential(
            # cat of layer4 (2048) + four 512-channel pyramid branches; the
            # original declared 2048 input channels, which cannot match
            nn.Conv2d(2048 + 4 * 512, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(.1),
            nn.Conv2d(512, num_classes, 1),
        )

    def forward(self, x):
        input_size = x.size()[2:]
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.final(torch.cat([
            x,
            self.layer5a(x),
            self.layer5b(x),
            self.layer5c(x),
            self.layer5d(x),
        ], 1))

        # the original returned F.upsample_bilinear(self.final, ...), passing
        # the module itself instead of the computed tensor
        return F.upsample_bilinear(x, input_size)
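
# Illustrative demo of one pyramid branch (not in the original file), on an
# assumed 60x60 feature map (PSPNet's 473-input / stride-8 setting):
# AvgPool2d(60, stride=60) collapses the map to a single 1x1 "global" bin, the
# 1x1 conv reduces 2048 -> 512 channels, and the bilinear upsample stretches
# the bin back to 60x60 so it can be concatenated with the other branches.
def _psp_branch_demo():
    branch = PSPDec(2048, 512, downsize=60, upsize=60)
    branch.eval()  # eval mode: BatchNorm over a single 1x1 training sample would error
    feat = torch.randn(1, 2048, 60, 60)
    print(tuple(branch(feat).shape))  # -> (1, 512, 60, 60)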

if __name__ == "__main__":
    batch_size, n_class, h, w = 10, 20, 160, 160

    # test output size
    vgg_model = VGGNet(requires_grad=True)
    input = torch.randn(batch_size, 3, h, w)  # plain tensors suffice since pytorch 0.4; Variable is redundant
    output = vgg_model(input)
    assert output['x5'].size() == torch.Size([batch_size, 512, 5, 5])  # 160 / 32 = 5

    fcn_model = FCN32s(pretrained_net=vgg_model, n_class=n_class)
    input = torch.randn(batch_size, 3, h, w)
    output = fcn_model(input)
    assert output.size() == torch.Size([batch_size, n_class, h, w])

    fcn_model = FCN16s(pretrained_net=vgg_model, n_class=n_class)
    input = torch.randn(batch_size, 3, h, w)
    output = fcn_model(input)
    assert output.size() == torch.Size([batch_size, n_class, h, w])

    fcn_model = FCN8s(pretrained_net=vgg_model, n_class=n_class)
    input = torch.randn(batch_size, 3, h, w)
    output = fcn_model(input)
    assert output.size() == torch.Size([batch_size, n_class, h, w])

    fcn_model = FCN1s(pretrained_net=vgg_model, n_class=n_class)
    input = torch.randn(batch_size, 3, h, w)
    output = fcn_model(input)
    assert output.size() == torch.Size([batch_size, n_class, h, w])

    # test a random batch, loss should decrease
    fcn_model = FCN1s(pretrained_net=vgg_model, n_class=n_class)
    criterion = nn.BCELoss()
    optimizer = optim.SGD(fcn_model.parameters(), lr=1e-3, momentum=0.9)
    input = torch.randn(batch_size, 3, h, w)
    y = torch.rand(batch_size, n_class, h, w)  # BCE targets must lie in [0, 1]; randn would be invalid
    for step in range(10):
        optimizer.zero_grad()
        output = fcn_model(input)

        output = torch.sigmoid(output)  # squash scores to (0, 1) for BCELoss
        loss = criterion(output, y)
        loss.backward()
        print("iter{}, loss {}".format(step, loss.item()))  # .item() replaces the deprecated loss.data[0]
        optimizer.step()
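
    # Hedged sketch (not in the original file) of the cross-entropy setup the
    # README recommends for real training; the BCE loop above is only a random
    # smoke test. CrossEntropyLoss takes the raw [n, n_class, h, w] scores plus
    # an integer label map directly, with no sigmoid.
    criterion = nn.CrossEntropyLoss()
    labels = torch.randint(0, n_class, (batch_size, h, w), dtype=torch.long)
    loss = criterion(fcn_model(input), labels)
    loss.backward()
--------------------------------------------------------------------------------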