├── yxk_loss
│   ├── __init__.py
│   ├── loss_test.py
│   └── loss_collection.py
├── image
│   ├── 0_label.png
│   ├── 0_out.png
│   ├── 0_src.jpg
│   ├── 1_label.png
│   ├── 1_out.png
│   ├── 1_src.jpg
│   ├── 4_label.png
│   ├── 4_out.png
│   ├── 4_src.jpg
│   ├── 5_label.png
│   ├── 5_out.png
│   └── 5_src.jpg
├── __pycache__
│   ├── loss.cpython-36.pyc
│   ├── tools.cpython-36.pyc
│   ├── models.cpython-36.pyc
│   └── voc_loader.cpython-36.pyc
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── FCN_pytorch.iml
│   └── workspace.xml
├── README.md
├── evaluate.py
├── tensorboard.py
├── predict.py
├── tools.py
├── loss.py
├── voc_loader.py
├── train.py
├── result.txt~
└── models.py
/yxk_loss/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/image/0_label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/0_label.png
--------------------------------------------------------------------------------
/image/0_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/0_out.png
--------------------------------------------------------------------------------
/image/0_src.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/0_src.jpg
--------------------------------------------------------------------------------
/image/1_label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/1_label.png
--------------------------------------------------------------------------------
/image/1_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/1_out.png
--------------------------------------------------------------------------------
/image/1_src.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/1_src.jpg
--------------------------------------------------------------------------------
/image/4_label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/4_label.png
--------------------------------------------------------------------------------
/image/4_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/4_out.png
--------------------------------------------------------------------------------
/image/4_src.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/4_src.jpg
--------------------------------------------------------------------------------
/image/5_label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/5_label.png
--------------------------------------------------------------------------------
/image/5_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/5_out.png
--------------------------------------------------------------------------------
/image/5_src.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/image/5_src.jpg
--------------------------------------------------------------------------------
/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/tools.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/__pycache__/tools.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/__pycache__/models.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/voc_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overfitover/fcn_pytorch/HEAD/__pycache__/voc_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/FCN_pytorch.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fully Convolutional Networks for Semantic Segmentation
2 |
3 | ## Data
4 | Pascal VOC dataset
5 | Expected directory layout:
6 | /home/my_name/data/VOC/VOCdevkit/VOC2012/JPEGImages
7 |
8 | ## Tools
9 | pytorch 0.4.1
10 | tensorboardX (visualization tool)
11 |
12 |
13 | ## Notes
14 | If you work from another project and training shows no improvement after many iterations, its loss implementation may be faulty. Retrain with one of its other loss functions, or use the loss from this project.
15 | The cross-entropy loss here has been verified to work.
16 |
17 | ## Usage
18 |
19 | ### train
20 | python train.py
21 |
22 | ### evaluate
23 | python evaluate.py
24 |
25 | ### predict
26 | python predict.py
27 |
28 |
29 | ## Visualization
30 | ```
31 | cd FCN_pytorch
32 | tensorboard --logdir ./runs
33 | ```
34 |
35 | ## Saving images
36 | Run python predict.py; the images are saved to the specified folder so you can inspect them yourself.
37 |
38 | ## Results
39 | Results after training for 200 epochs:
40 |
41 | 
42 |
43 | 
44 |
45 | 
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
4 | import torch
5 | import models
6 | import voc_loader
7 | import numpy as np
8 | from torch.autograd import Variable
9 | import tools
10 |
11 |
12 | n_class = 21
13 | def evaluate():
14 |     use_cuda = torch.cuda.is_available()
15 |     path = os.path.expanduser('/home/yxk/data/')
16 |     val_data = voc_loader.VOC2012ClassSeg(root=path,
17 |                                           split='val',
18 |                                           transform=True)
19 |     val_loader = torch.utils.data.DataLoader(val_data,
20 |                                              batch_size=1,
21 |                                              shuffle=False,
22 |                                              num_workers=5)
23 |     print('load model .....')
24 |     vgg_model = models.VGGNet(requires_grad=True)
25 |     fcn_model = models.FCN8s(pretrained_net=vgg_model, n_class=n_class)
26 |     fcn_model.load_state_dict(torch.load('params.pth'))
27 |
28 |     if use_cuda:
29 |         fcn_model.cuda()
30 |     fcn_model.eval()
31 |
32 |     label_trues, label_preds = [], []
33 |     # for idx, (img, label) in enumerate(val_loader):
34 |     for idx in range(len(val_data)):
35 |         img, label = val_data[idx]
36 |         img = img.unsqueeze(0)
37 |         if use_cuda:
38 |             img = img.cuda()
39 |         img = Variable(img)
40 |
41 |         out = fcn_model(img)  # 1, 21, 320, 320
42 |
43 |         pred = out.data.max(1)[1].squeeze_(1).squeeze_(0)  # 320, 320
44 |
45 |         if use_cuda:
46 |             pred = pred.cpu()
47 |         label_trues.append(label.numpy())
48 |         label_preds.append(pred.numpy())
49 |
50 |         if idx % 30 == 0:
51 |             print('evaluate [%d/%d]' % (idx, len(val_loader)))
52 |
53 |     metrics = tools.accuracy_score(label_trues, label_preds)
54 |     metrics = np.array(metrics)
55 |     metrics *= 100
56 |     print('''\
57 | Accuracy: {0}
58 | Accuracy Class: {1}
59 | Mean IU: {2}
60 | FWAV Accuracy: {3}'''.format(*metrics))
61 |
62 |
63 | if __name__ == '__main__':
64 |     evaluate()
--------------------------------------------------------------------------------
/tensorboard.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torchvision.utils as vutils
4 |
5 | import numpy as np
6 | import torchvision.models as models
7 |
8 | from torchvision import datasets
9 |
10 | from tensorboardX import SummaryWriter
11 |
12 |
13 | resnet18 = models.resnet18(False)
14 |
15 | writer = SummaryWriter()
16 |
17 | sample_rate = 44100
18 |
19 | freqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440]
20 |
21 | for n_iter in range(100):
22 |
23 |     s1 = torch.rand(1)  # value to keep
24 |
25 |     s2 = torch.rand(1)
26 |
27 |     writer.add_scalar('data/scalar1', s1[0], n_iter)
28 |
29 |     writer.add_scalar('data/scalar2', s2[0], n_iter)
30 |
31 |     writer.add_scalars('data/scalar_group', {"xsinx": n_iter*np.sin(n_iter),
32 |
33 |                                              "xcosx": n_iter*np.cos(n_iter),
34 |
35 |                                              "arctanx": np.arctan(n_iter)}, n_iter)
36 |
37 |     x = torch.rand(32, 3, 64, 64)
38 |
39 |     if n_iter % 10 == 0:
40 |
41 |         x = vutils.make_grid(x, normalize=True, scale_each=True)
42 |
43 |         writer.add_image('Image', x, n_iter)
44 |
45 |         x = torch.zeros(sample_rate*2)
46 |
47 |         for i in range(x.size(0)):
48 |
49 |             x[i] = np.cos(freqs[n_iter//10]*np.pi*float(i)/float(sample_rate))
50 |
51 |         writer.add_audio('myAudio', x, n_iter, sample_rate=sample_rate)
52 |
53 |         writer.add_text('Text', 'text logged at step:'+str(n_iter), n_iter)
54 |
55 |         for name, param in resnet18.named_parameters():
56 |
57 |             writer.add_histogram(name, param.clone().cpu().data.numpy(), n_iter)
58 |
59 |
60 |         writer.add_pr_curve('xoxo', np.random.randint(2, size=100),
61 |
62 |                             np.random.rand(100), n_iter)  # needs tensorboard 0.4RC or later
63 |
64 | dataset = datasets.MNIST('mnist', train=False, download=True)
65 |
66 | images = dataset.test_data[:100].float()
67 |
68 | label = dataset.test_labels[:100]
69 |
70 | features = images.view(100, 784)
71 |
72 | writer.add_embedding(features, metadata=label, label_img=images.unsqueeze(1))
73 |
74 | # export scalar data to JSON for external processing
75 |
76 | writer.export_scalars_to_json("./all_scalars.json")
77 |
78 | writer.close()
79 |
80 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
4 | import torch
5 | import cv2
6 | import numpy as np
7 | from torch.autograd import Variable
8 | import torchvision.transforms as transforms
9 | import voc_loader
10 | import models
11 | import random
12 | import tools
13 | from loss import CrossEntropyLoss2d
14 |
15 | n_class = 21
16 | def main():
17 |     use_cuda = torch.cuda.is_available()
18 |     path = os.path.expanduser('/home/yxk/data/')
19 |
20 |     dataset = voc_loader.VOC2012ClassSeg(root=path,
21 |                                          split='train',
22 |                                          transform=True)
23 |
24 |     vgg_model = models.VGGNet(requires_grad=True)
25 |     fcn_model = models.FCN8s(pretrained_net=vgg_model, n_class=n_class)
26 |     fcn_model.load_state_dict(torch.load('./pretrained_models/model120.pth', map_location='cpu'))
27 |
28 |     fcn_model.eval()
29 |
30 |     if use_cuda:
31 |         fcn_model.cuda()
32 |
33 |     criterion = CrossEntropyLoss2d()
34 |
35 |     for i in range(len(dataset)):
36 |         idx = random.randrange(0, len(dataset))
37 |         img, label = dataset[idx]
38 |         img_name = str(i)
39 |
40 |         img_src, _ = dataset.untransform(img, label)  # HWC
41 |
42 |         cv2.imwrite(path + 'image/%s_src.jpg' % img_name, img_src)
43 |         tools.labelTopng(label, path + 'image/%s_label.png' % img_name)  # convert the label to a png image
44 |
45 |         # a = tools.labelToimg(label)
46 |         #
47 |         # print(a)
48 |
49 |         if use_cuda:
50 |             img = img.cuda()
51 |             label = label.cuda()
52 |         img = Variable(img.unsqueeze(0), volatile=True)
53 |         label = Variable(label.unsqueeze(0), volatile=True)
54 |         # print("label: ", label.data)
55 |
56 |         out = fcn_model(img)  # (1, 21, 320, 320)
57 |         loss = criterion(out, label)
58 |         # print(img_name, 'loss:', loss.data[0])
59 |
60 |         net_out = out.data.max(1)[1].squeeze_(0)  # 320, 320
61 |         # print(out.data.max(1)[1].shape)
62 |         # print("out", net_out)
63 |         if use_cuda:
64 |             net_out = net_out.cpu()
65 |
66 |         tools.labelTopng(net_out, path + 'image/%s_out.png' % img_name)  # convert the network output to a png image
67 |
68 |         if i == 10:
69 |             break
70 |
71 |
72 | if __name__ == '__main__':
73 |     main()
--------------------------------------------------------------------------------
/yxk_loss/loss_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 |
4 | '''
5 | Reference: https://blog.csdn.net/zhangxb35/article/details/72464152?utm_source=itdadao&utm_medium=referral
6 | If reduce = False, size_average is ignored and the loss is returned element-wise, as a vector;
7 | if reduce = True, the loss is returned as a scalar:
8 |
9 | with size_average = True, it returns loss.mean();
10 | with size_average = False, it returns loss.sum().
11 | '''
12 |
13 | # nn.L1Loss: loss(input, target) = |input - target|
14 | if False:
15 |     loss_fn = torch.nn.L1Loss(reduce=True, size_average=False)
16 |     input = torch.autograd.Variable(torch.randn(3, 4))
17 |     target = torch.autograd.Variable(torch.randn(3, 4))
18 |     loss = loss_fn(input, target)
19 |     print(input)
20 |     print(target)
21 |     print(loss)
22 |     print(input.size(), target.size(), loss.size())
23 |
24 |
25 | # nn.SmoothL1Loss: squared loss on (-1, 1), L1 loss elsewhere
26 | if False:
27 |     loss_fn = torch.nn.SmoothL1Loss(reduce=False, size_average=False)
28 |     input = torch.autograd.Variable(torch.randn(3, 4))
29 |     target = torch.autograd.Variable(torch.randn(3, 4))
30 |     loss = loss_fn(input, target)
31 |     print(input)
32 |     print(target)
33 |     print(loss)
34 |     print(input.size(), target.size(), loss.size())
35 |
36 | # nn.MSELoss: mean squared error loss
37 | if False:
38 |     loss_fn = torch.nn.MSELoss(reduce=False, size_average=False)
39 |     input = torch.autograd.Variable(torch.randn(3, 4))
40 |     target = torch.autograd.Variable(torch.randn(3, 4))
41 |     loss = loss_fn(input, target)
42 |     print(input)
43 |     print(target)
44 |     print(loss)
45 |     print(input.size(), target.size(), loss.size())
46 |
47 | # nn.BCELoss
48 | if False:
49 |     import torch.nn.functional as F
50 |
51 |     loss_fn = torch.nn.BCELoss(reduce=False, size_average=False)
52 |     input = torch.autograd.Variable(torch.randn(3, 4))
53 |     target = torch.autograd.Variable(torch.FloatTensor(3, 4).random_(2))
54 |     loss = loss_fn(F.sigmoid(input), target)
55 |     print(input, input.shape)
56 |     print(F.sigmoid(input))
57 |     print(target, target.shape)
58 |     print(loss, loss.shape)
59 |
60 | # nn.CrossEntropyLoss
61 | if False:
62 |     weight = torch.Tensor([1, 2, 1, 1, 10])
63 |     loss_fn = torch.nn.CrossEntropyLoss(reduce=False, size_average=False, weight=None)
64 |     input = Variable(torch.randn(3, 5))  # (batch_size, C)
65 |     target = Variable(torch.LongTensor(3).random_(5))
66 |     loss = loss_fn(input, target)
67 |     print(input)
68 |     print(target)
69 |     print(loss)
70 |
71 |
72 |
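73 | # nn.CrossEntropyLoss also accepts 4D score maps, which is what the
74 | # segmentation losses in loss.py rely on. A minimal sketch (0.4-era API,
75 | # illustrative shapes): log_softmax + NLLLoss2d should match
76 | # CrossEntropyLoss up to floating-point error.
77 | if False:
78 |     import torch.nn.functional as F
79 |     input = Variable(torch.randn(2, 21, 4, 4))                # (N, C, H, W) scores
80 |     target = Variable(torch.LongTensor(2, 4, 4).random_(21))  # (N, H, W) class labels
81 |     loss_a = torch.nn.NLLLoss2d()(F.log_softmax(input, dim=1), target)
82 |     loss_b = torch.nn.CrossEntropyLoss()(input, target)
83 |     print(loss_a, loss_b)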
--------------------------------------------------------------------------------
/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL.Image as Image
4 |
5 | def getPalette():
6 |     '''
7 |     http://blog.csdn.net/yhl_leo/article/details/52185581
8 |     '''
9 |     pal = np.array([[0, 0, 0],
10 |                     [128, 0, 0],
11 |                     [0, 128, 0],
12 |                     [128, 128, 0],
13 |                     [0, 0, 128],
14 |                     [128, 0, 128],
15 |                     [0, 128, 128],
16 |                     [128, 128, 128],
17 |                     [64, 0, 0],
18 |                     [192, 0, 0],
19 |                     [64, 128, 0],
20 |                     [192, 128, 0],
21 |                     [64, 0, 128],
22 |                     [192, 0, 128],
23 |                     [64, 128, 128],
24 |                     [192, 128, 128],
25 |                     [0, 64, 0],
26 |                     [128, 64, 0],
27 |                     [0, 192, 0],
28 |                     [128, 192, 0],
29 |                     [0, 64, 128]], dtype='uint8').flatten()
30 |     return pal
31 |
32 |
33 | def colorize_mask(mask):
34 |     """
35 |     :param mask: image-sized array of class indices; each index maps to a colour
36 |     :return:
37 |     """
38 |     new_mask = Image.fromarray(mask.astype(np.uint8), 'P')  # convert the 2D array to a palette image
39 |
40 |     pal = getPalette()
41 |     new_mask.putpalette(pal)
42 |     # print(new_mask.show())
43 |     return new_mask
44 |
45 | # m = np.array([[1,2], [3,4]])
46 | # colorize_mask(m)
47 |
48 |
49 | def getFileName(file_path):
50 |     '''
51 |     get file_path name from path+name+'test.jpg'
52 |     return test
53 |     '''
54 |     full_name = file_path.split('/')[-1]
55 |     name = os.path.splitext(full_name)[0]
56 |
57 |     return name
58 |
59 |
60 | def labelTopng(label, img_name):
61 |     '''
62 |     convert tensor cpu label to png and save
63 |     '''
64 |     label = label.numpy()  # 320, 320
65 |     label_pil = colorize_mask(label)
66 |     label_pil.save(img_name)
67 |
68 | def labelToimg(label):
69 |     label = label.numpy()
70 |     label_pil = colorize_mask(label)
71 |     return label_pil
72 |
73 |
74 | def _fast_hist(label_true, label_pred, n_class):
75 |     mask = (label_true >= 0) & (label_true < n_class)
76 |     hist = np.bincount(
77 |         n_class * label_true[mask].astype(int) +
78 |         label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
79 |     return hist
80 |
81 |
82 | def accuracy_score(label_trues, label_preds, n_class=21):
83 |     """Returns accuracy score evaluation result.
84 |     - overall accuracy
85 |     - mean accuracy
86 |     - mean IU
87 |     - fwavacc
88 |     """
89 |     hist = np.zeros((n_class, n_class))
90 |     for lt, lp in zip(label_trues, label_preds):
91 |         hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)  # n_class x n_class confusion matrix
92 |     acc = np.diag(hist).sum() / hist.sum()
93 |     acc_cls = np.diag(hist) / hist.sum(axis=1)
94 |     acc_cls = np.nanmean(acc_cls)
95 |     iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
96 |     mean_iu = np.nanmean(iu)
97 |     freq = hist.sum(axis=1) / hist.sum()
98 |     fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
99 |     return acc, acc_cls, mean_iu, fwavacc
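100 |
101 |
102 | # Minimal self-check (toy labels, hypothetical -- not part of the pipeline):
103 | # a perfect prediction should score 1.0 on all four metrics.
104 | if __name__ == '__main__':
105 |     y = np.array([[0, 1], [1, 2]])
106 |     print(accuracy_score([y], [y], n_class=3))  # expect (1.0, 1.0, 1.0, 1.0)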
--------------------------------------------------------------------------------
/yxk_loss/loss_collection.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | from torch.autograd import Variable
5 |
6 | '''
7 | If reduce = False, size_average is ignored and the loss is returned element-wise, as a vector;
8 | if reduce = True, the loss is returned as a scalar:
9 |
10 | with size_average = True, it returns loss.mean();
11 | with size_average = False, it returns loss.sum().
12 | '''
13 |
14 |
15 | class CrossEntropyLoss(nn.Module):
16 |     def __init__(self):
17 |         super(CrossEntropyLoss, self).__init__()
18 |         self.cross_entropy_loss = nn.CrossEntropyLoss(weight=None, size_average=False)
19 |
20 |     def forward(self, inputs, targets):
21 |         return self.cross_entropy_loss(inputs, targets)
22 |
23 | class CrossEntropyLoss2d(nn.Module):
24 |     """
25 |     Negative Log Likelihood
26 |     """
27 |
28 |     def __init__(self, weight=None, size_average=True):
29 |         super(CrossEntropyLoss2d, self).__init__()
30 |         self.nll_loss = nn.NLLLoss2d(weight, size_average)
31 |
32 |     def forward(self, inputs, targets):
33 |         return self.nll_loss(F.log_softmax(inputs, dim=1), targets)
34 |
35 | def smooth_l1(deltas, targets, sigma=3.0):
36 |     """
37 |     :param deltas: (tensor) predictions, sized [N,D].
38 |     :param targets: (tensor) targets, sized [N,].
39 |     :param sigma: 3.0
40 |     :return:
41 |     """
42 |
43 |     sigma2 = sigma * sigma
44 |     diffs = deltas - targets
45 |     smooth_l1_signs = (torch.abs(diffs) < 1.0 / sigma2).detach().float()  # 1 in the quadratic region, 0 in the linear region
46 |
47 |     smooth_l1_option1 = torch.mul(diffs, diffs) * 0.5 * sigma2
48 |     smooth_l1_option2 = torch.abs(diffs) - 0.5 / sigma2
49 |     smooth_l1_add = torch.mul(smooth_l1_option1, smooth_l1_signs) + \
50 |                     torch.mul(smooth_l1_option2, 1 - smooth_l1_signs)
51 |     smooth_l1 = smooth_l1_add
52 |
53 |     return smooth_l1
54 |
55 | class FocalLoss(nn.Module):
56 |
57 |     def __init__(self, gamma=0, alpha=None, size_average=True):
58 |         super(FocalLoss, self).__init__()
59 |         self.gamma = gamma
60 |         self.alpha = alpha
61 |         if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1-alpha])
62 |         if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
63 |         self.size_average = size_average
64 |
65 |     def forward(self, input, target):
66 |         if input.dim() > 2:
67 |             input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
68 |             input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
69 |             input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
70 |         target = target.view(-1, 1)
71 |
72 |         logpt = F.log_softmax(input, dim=1)
73 |         logpt = logpt.gather(1, target)
74 |         logpt = logpt.view(-1)
75 |         pt = Variable(logpt.data.exp())
76 |
77 |         if self.alpha is not None:
78 |             if self.alpha.type() != input.data.type():
79 |                 self.alpha = self.alpha.type_as(input.data)
80 |             at = self.alpha.gather(0, target.data.view(-1))
81 |             logpt = logpt * Variable(at)
82 |
83 |         loss = -1 * (1-pt)**self.gamma * logpt
84 |         if self.size_average: return loss.mean()
85 |         else:
86 |             return loss.sum()
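87 |
88 |
89 | # Minimal usage sketch (illustrative shapes, not taken from the training code):
90 | # with gamma=0 and alpha=None, FocalLoss reduces to plain mean cross entropy.
91 | if __name__ == '__main__':
92 |     input = Variable(torch.randn(4, 21, 8, 8))  # (N, C, H, W) scores
93 |     target = Variable(torch.LongTensor(4, 8, 8).random_(21))
94 |     print(FocalLoss(gamma=2)(input, target))
95 |     print(CrossEntropyLoss2d()(input, target))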
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | from torch.autograd import Variable
5 |
6 |
7 | class CrossEntropy2d(nn.Module):
8 |     '''
9 |     This implementation is buggy and cost a lot of wasted time:
10 |     the loss doesn't change -- can the loss not be backpropagated?
11 |
12 |     Why would the loss itself need to change? Only the net weights need to change.
13 |     '''
14 |     def __init__(self):
15 |         super(CrossEntropy2d, self).__init__()
16 |         self.criterion = nn.CrossEntropyLoss(weight=None, size_average=False)  # should size_average=False? True averages, False returns the total
17 |
18 |     def forward(self, out, target):
19 |         n, c, h, w = out.size()  # n: batch_size, c: classes
20 |         out = out.view(-1, c)  # (n*h*w, c) -- bug: the NCHW tensor must be permuted to NHWC before this view, so pixels and channels are misaligned here
21 |         target = target.view(-1)  # (n*h*w)
22 |         # print('out', out.size(), 'target', target.size())
23 |         loss = self.criterion(out, target)
24 |
25 |         return loss
26 |
27 | class CrossEntropyLoss(nn.Module):
28 |     def __init__(self):
29 |         super(CrossEntropyLoss, self).__init__()
30 |         self.cross_entropy_loss = nn.CrossEntropyLoss(weight=None, size_average=False)
31 |
32 |     def forward(self, inputs, targets):
33 |         return self.cross_entropy_loss(inputs, targets)
34 |
35 | class CrossEntropyLoss2d(nn.Module):
36 |     """
37 |     Verified to work.
38 |     """
39 |
40 |     def __init__(self, weight=None, size_average=True):
41 |         super(CrossEntropyLoss2d, self).__init__()
42 |         self.nll_loss = nn.NLLLoss2d(weight, size_average)
43 |
44 |     def forward(self, inputs, targets):
45 |         return self.nll_loss(F.log_softmax(inputs, dim=1), targets)
46 |
47 |
48 | def smooth_l1(deltas, targets, sigma=3.0):
49 |     """
50 |     :param deltas: (tensor) predictions, sized [N,D].
51 |     :param targets: (tensor) targets, sized [N,].
52 |     :param sigma: 3.0
53 |     :return:
54 |     """
55 |
56 |     sigma2 = sigma * sigma
57 |     diffs = deltas - targets
58 |     smooth_l1_signs = (torch.abs(diffs) < 1.0 / sigma2).detach().float()  # 1 in the quadratic region, 0 in the linear region
59 |
60 |     smooth_l1_option1 = torch.mul(diffs, diffs) * 0.5 * sigma2
61 |     smooth_l1_option2 = torch.abs(diffs) - 0.5 / sigma2
62 |     smooth_l1_add = torch.mul(smooth_l1_option1, smooth_l1_signs) + \
63 |                     torch.mul(smooth_l1_option2, 1 - smooth_l1_signs)
64 |     smooth_l1 = smooth_l1_add
65 |
66 |     return smooth_l1
67 |
68 |
69 | class FocalLoss(nn.Module):
70 |     def __init__(self, gamma=0, alpha=None, size_average=True):
71 |         super(FocalLoss, self).__init__()
72 |         self.gamma = gamma
73 |         self.alpha = alpha
74 |         if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1-alpha])
75 |         if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
76 |         self.size_average = size_average
77 |
78 |     def forward(self, input, target):
79 |         if input.dim() > 2:
80 |             input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
81 |             input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
82 |             input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
83 |         target = target.view(-1, 1)
84 |
85 |         logpt = F.log_softmax(input, dim=1)
86 |         logpt = logpt.gather(1, target)
87 |         logpt = logpt.view(-1)
88 |         pt = Variable(logpt.data.exp())
89 |
90 |         if self.alpha is not None:
91 |             if self.alpha.type() != input.data.type():
92 |                 self.alpha = self.alpha.type_as(input.data)
93 |             at = self.alpha.gather(0, target.data.view(-1))
94 |             logpt = logpt * Variable(at)
95 |
96 |         loss = -1 * (1-pt)**self.gamma * logpt
97 |         if self.size_average: return loss.mean()
98 |         else:
99 |             return loss.sum()
100 |
101 |
102 | if __name__ == '__main__':
103 |     loss_fn = torch.nn.CrossEntropyLoss(reduce=False, size_average=False, weight=None)
104 |     input = Variable(torch.randn(2, 3, 5))  # (batch_size, C, d) per-position scores
105 |     target = Variable(torch.LongTensor(2, 5).random_(3))
106 |     loss = loss_fn(input, target)
107 |     print(input)
108 |     print(target)
109 |     print(loss)
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
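122 | # Sanity check for the segmentation loss (a sketch: shapes mirror train.py's
123 | # (N, 21, 320, 320) outputs, shrunk here to keep the check cheap):
124 | if __name__ == '__main__':
125 |     out = Variable(torch.randn(2, 21, 16, 16))  # network scores
126 |     target = Variable(torch.LongTensor(2, 16, 16).random_(21))
127 |     print(CrossEntropyLoss2d()(out, target))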
--------------------------------------------------------------------------------
/voc_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import collections
4 | import os.path as osp
5 |
6 | import numpy as np
7 | import PIL.Image
8 | import scipy.io
9 | import torch
10 | from torch.utils import data
11 | import cv2
12 | import random
13 |
14 |
15 | """
16 | https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/datasets/voc.py
17 | """
18 |
19 |
20 | class VOCClassSegBase(data.Dataset):
21 |
22 |     class_names = np.array([
23 |         'background',
24 |         'aeroplane',
25 |         'bicycle',
26 |         'bird',
27 |         'boat',
28 |         'bottle',
29 |         'bus',
30 |         'car',
31 |         'cat',
32 |         'chair',
33 |         'cow',
34 |         'diningtable',
35 |         'dog',
36 |         'horse',
37 |         'motorbike',
38 |         'person',
39 |         'potted plant',
40 |         'sheep',
41 |         'sofa',
42 |         'train',
43 |         'tv/monitor',
44 |     ])
45 |     mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
46 |
47 |
48 |     def __init__(self, root, split='train', transform=True):
49 |         self.root = root
50 |         self.split = split
51 |         self._transform = transform
52 |
53 |         # VOC2011 and others are subsets of VOC2012
54 |         dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012')
55 |         # dataset_dir = osp.join(self.root, 'VOC2007')
56 |
57 |         self.files = collections.defaultdict(list)
58 |         for split_file in ['train', 'val']:
59 |             imgsets_file = osp.join(
60 |                 dataset_dir, 'ImageSets/Segmentation/%s.txt' % split_file)
61 |             for img_name in open(imgsets_file):
62 |                 img_name = img_name.strip()
63 |                 img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % img_name)
64 |                 lbl_file = osp.join(dataset_dir, 'SegmentationClass/%s.png' % img_name)
65 |                 self.files[split_file].append({
66 |                     'img': img_file,
67 |                     'lbl': lbl_file,
68 |                 })
69 |
70 |     def __len__(self):
71 |         return len(self.files[self.split])
72 |
73 |     def __getitem__(self, index):
74 |         data_file = self.files[self.split][index]
75 |         # load image
76 |         img_file = data_file['img']
77 |         img = PIL.Image.open(img_file)
78 |         img = np.array(img, dtype=np.uint8)
79 |         # load label
80 |         lbl_file = data_file['lbl']
81 |         lbl = PIL.Image.open(lbl_file)
82 |         lbl = np.array(lbl, dtype=np.uint8)
83 |
84 |         lbl[lbl == 255] = 0  # map the 'ignore' boundary label to background
85 |         # augment
86 |         img, lbl = self.randomFlip(img, lbl)
87 |         img, lbl = self.randomCrop(img, lbl)
88 |         img, lbl = self.resize(img, lbl)
89 |
90 |         if self._transform:
91 |             return self.transform(img, lbl)
92 |         else:
93 |             return img, lbl
94 |
95 |
96 |     def transform(self, img, lbl):
97 |         img = img[:, :, ::-1]  # RGB -> BGR
98 |         img = img.astype(np.float64)
99 |         img -= self.mean_bgr
100 |         img = img.transpose(2, 0, 1)  # HWC -> CHW
101 |         img = torch.from_numpy(img).float()
102 |         lbl = torch.from_numpy(lbl).long()
103 |         return img, lbl
104 |
105 |     def untransform(self, img, lbl):
106 |         img = img.numpy()
107 |         img = img.transpose(1, 2, 0)  # CHW -> HWC
108 |         img += self.mean_bgr
109 |         img = img.astype(np.uint8)
110 |         img = img[:, :, ::-1]  # BGR -> RGB
111 |         lbl = lbl.numpy()
112 |         return img, lbl
113 |
114 |     def randomFlip(self, img, label):
115 |         if random.random() < 0.5:
116 |             img = np.fliplr(img)
117 |             label = np.fliplr(label)
118 |         return img, label
119 |
120 |     def resize(self, img, label, s=320):
121 |         # print(s, img.shape)
122 |         img = cv2.resize(img, (s, s), interpolation=cv2.INTER_LINEAR)
123 |         label = cv2.resize(label, (s, s), interpolation=cv2.INTER_NEAREST)
124 |         return img, label
125 |
126 |     def randomCrop(self, img, label):
127 |         h, w, _ = img.shape
128 |         short_size = min(w, h)
129 |         rand_size = random.randrange(int(0.7 * short_size), short_size)
130 |         x = random.randrange(0, w - rand_size)
131 |         y = random.randrange(0, h - rand_size)
132 |
133 |         return img[y:y + rand_size, x:x + rand_size], label[y:y + rand_size, x:x + rand_size]
134 |     # data augmentation
135 |     def augmentation(self, img, lbl):
136 |         img, lbl = self.randomFlip(img, lbl)
137 |         img, lbl = self.randomCrop(img, lbl)
138 |         img, lbl = self.resize(img, lbl)
139 |         return img, lbl
140 |
141 |     # elif not self.predict:  # for batch test, this is needed
142 |     #     img, label = self.randomCrop(img, label)
143 |     #     img, label = self.resize(img, label, VOCClassSeg.img_size)
144 |     # else:
145 |     #     pass
146 |
147 |
148 | class VOC2012ClassSeg(VOCClassSegBase):
149 |
150 |     # url = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar'  # NOQA
151 |
152 |     def __init__(self, root, split='train', transform=False):
153 |         super(VOC2012ClassSeg, self).__init__(
154 |             root, split=split, transform=transform)
155 |
156 |
157 | """
158 | vocbase = VOC2012ClassSeg(root="/home/yxk/Downloads/")
159 |
160 | print(vocbase.__len__())
161 | img, lbl = vocbase.__getitem__(0)
162 | img = img[:, :, ::-1]
163 | img = cv2.resize(img, (320, 320), interpolation=cv2.INTER_LINEAR)
164 | print(np.shape(img))
165 | print(np.shape(lbl))
166 |
167 | """
168 |
169 |
170 |
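171 | # Quick smoke test (a sketch: mirrors how train.py consumes this loader, and
172 | # assumes the VOC2012 layout described in the README under the given root).
173 | if __name__ == '__main__':
174 |     dataset = VOC2012ClassSeg(root='/home/yxk/data/', split='train', transform=True)
175 |     loader = data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=5)
176 |     imgs, lbls = next(iter(loader))
177 |     print(imgs.size(), lbls.size())  # expect (2, 3, 320, 320) and (2, 320, 320)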
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import torch.utils.data.dataloader
6 | import numpy as np
7 | import torchvision
8 | import models
9 | import voc_loader
10 | import loss
11 | from torch.optim import Adam, SGD
12 | from tensorboardX import SummaryWriter
13 | from argparse import ArgumentParser
14 | import tools
15 |
16 |
17 | # argument parsing
18 | parser = ArgumentParser()
19 | parser.add_argument('-bs', '--batch_size', type=int, default=2, help="batch size of the data")
20 | parser.add_argument('-e', '--epochs', type=int, default=300, help='number of training epochs')
21 | parser.add_argument('-c', '--n_class', type=int, default=21, help='number of classes in the dataset')
22 | parser.add_argument('-lr', '--learning_rate', type=float, default=1e-3, help='learning rate')
23 | args = parser.parse_args()
24 |
25 | # visualization writer
26 | writer = SummaryWriter()
27 |
28 | batch_size = args.batch_size
29 | learning_rate = args.learning_rate
30 | epoch_num = args.epochs
31 | n_class = args.n_class
32 |
33 |
34 | best_test_loss = np.inf
35 | pretrained = 'reload'
36 | use_cuda = torch.cuda.is_available()
37 |
38 | # path = os.path.expanduser('/home/yxk/Downloads/')
39 |
40 | # dataset path
41 | data_path = os.path.expanduser('/home/yxk/data/')
42 |
43 | print('load data....')
44 | train_data = voc_loader.VOC2012ClassSeg(root=data_path, split='train', transform=True)
45 |
46 | train_loader = torch.utils.data.DataLoader(train_data,
47 |                                            batch_size=batch_size,
48 |                                            shuffle=True,
49 |                                            num_workers=5)
50 | val_data = voc_loader.VOC2012ClassSeg(root=data_path,
51 |                                       split='val',
52 |                                       transform=True)
53 | val_loader = torch.utils.data.DataLoader(val_data,
54 |                                          batch_size=batch_size,
55 |                                          shuffle=False,
56 |                                          num_workers=5)
57 |
58 | vgg_model = models.VGGNet(requires_grad=True)
59 | fcn_model = models.FCN8s(pretrained_net=vgg_model, n_class=n_class)
60 |
61 | if use_cuda:
62 |     fcn_model.cuda()
63 |
64 | criterion = loss.CrossEntropyLoss2d()
65 | # create your optimizer
66 | optimizer = Adam(fcn_model.parameters())
67 | # optimizer = torch.optim.SGD(fcn_model.parameters(), lr=0.01)
68 |
69 | def train(epoch):
70 |     fcn_model.train()  # train mode
71 |     total_loss = 0.
72 |     for batch_idx, (imgs, labels) in enumerate(train_loader):
73 |         N = imgs.size(0)
74 |         if use_cuda:
75 |             imgs = imgs.cuda()
76 |             labels = labels.cuda()
77 |
78 |         imgs_tensor = Variable(imgs)      # torch.Size([2, 3, 320, 320])
79 |         labels_tensor = Variable(labels)  # torch.Size([2, 320, 320])
80 |         out = fcn_model(imgs_tensor)      # torch.Size([2, 21, 320, 320])
81 |
82 |         # with open('./result.txt', 'r+') as f:
83 |         #     f.write(str(out.detach().numpy()))
84 |         #     f.write("\n")
85 |
86 |         loss = criterion(out, labels_tensor)
87 |         loss /= N
88 |         optimizer.zero_grad()
89 |         loss.backward()
90 |         optimizer.step()  # update all parameters
91 |         total_loss += loss.item()  # returns a Python float
92 |
93 |         # if batch_idx == 2:
94 |         #     break
95 |
96 |         if (batch_idx) % 20 == 0:
97 |             print('train epoch [%d/%d], iter[%d/%d], lr %.7f, aver_loss %.5f' % (epoch,
98 |                   epoch_num, batch_idx,
99 |                   len(train_loader), learning_rate,
100 |                   total_loss / (batch_idx + 1)))
101 |
102 |     # # visualize scalars
103 |     # if epoch % 10 == 0:
104 |     #     label_img = tools.labelToimg(labels[0])
105 |     #     net_out = out[0].data.max(1)[1].squeeze_(0)
106 |     #     out_img = tools.labelToimg(net_out)
107 |     #     writer.add_scalar("loss", loss, epoch)
108 |     #     writer.add_scalar("total_loss", total_loss, epoch)
109 |     #     writer.add_scalars('loss/scalar_group', {"loss": epoch * loss,
110 |     #                                              "total_loss": epoch * total_loss})
111 |     #     writer.add_image('Image', imgs[0], epoch)
112 |     #     writer.add_image('label', label_img, epoch)
113 |     #     writer.add_image("out", out_img, epoch)
114 |
115 |     assert not np.isnan(total_loss)
116 |     assert not np.isinf(total_loss)
117 |
118 |     # model save
119 |     if (epoch) % 20 == 0:
120 |         torch.save(fcn_model.state_dict(), './pretrained_models/model%d.pth' % epoch)  # checkpoint every 20 epochs
121 |     total_loss /= len(train_loader)
122 |     print('train epoch [%d/%d] average_loss %.5f' % (epoch, epoch_num, total_loss))
123 |
124 |
125 | def test(epoch):
126 |     fcn_model.eval()
127 |     total_loss = 0.
128 |     for batch_idx, (imgs, labels) in enumerate(val_loader):
129 |         N = imgs.size(0)
130 |         if use_cuda:
131 |             imgs = imgs.cuda()
132 |             labels = labels.cuda()
133 |         imgs = Variable(imgs)      # , volatile=True
134 |         labels = Variable(labels)  # , volatile=True
135 |         out = fcn_model(imgs)
136 |         loss = criterion(out, labels)
137 |         loss /= N
138 |         total_loss += loss.item()
139 |
140 |         if (batch_idx + 1) % 3 == 0:
141 |             print('test epoch [%d/%d], iter[%d/%d], aver_loss %.5f' % (epoch,
142 |                   epoch_num, batch_idx, len(val_loader),
143 |                   total_loss / (batch_idx + 1)))
144 |
145 |
146 |
147 |     total_loss /= len(val_loader)
148 |     print('test epoch [%d/%d] average_loss %.5f' % (epoch, epoch_num, total_loss))
149 |
150 |     global best_test_loss
151 |     if best_test_loss > total_loss:
152 |         best_test_loss = total_loss
153 |         print('best loss....')
154 |         # fcn_model.save('SBD.pth')
155 |
156 |
157 | if __name__ == '__main__':
158 |     # print(torch.cuda.is_available())
159 |     for epoch in range(epoch_num):
160 |         train(epoch)
161 |         # test(epoch)
162 |         # adjust learning rate
163 |         if epoch == 20:
164 |             learning_rate *= 0.01
165 |             optimizer.param_groups[0]['lr'] = learning_rate
166 |             # optimizer.param_groups[1]['lr'] = learning_rate * 2
--------------------------------------------------------------------------------
/result.txt~:
--------------------------------------------------------------------------------
1 | [raw numpy dump of a debug run's network output, shape (2, 21, 320, 320); numeric contents truncated]
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import print_function
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torchvision import models
7 | from torchvision.models.vgg import VGG
8 | import torch.nn.functional as F
9 | from torch.nn import init
10 | import numpy as np
11 |
12 | """
13 |
14 | """
15 |
16 |
17 | def get_upsample_weight(in_channels, out_channels, kernel_size):
18 | '''
19 | make a 2D bilinear kernel suitable for upsampling
20 | '''
21 | factor = (kernel_size + 1) // 2
22 | if kernel_size % 2 == 1:
23 | center = factor - 1
24 | else:
25 | center = factor - 0.5
26 | og = np.ogrid[:kernel_size, :kernel_size] # list (64 x 1), (1 x 64)
27 | filt = (1 - abs(og[0] - center) / factor) * \
28 | (1 - abs(og[1] - center) / factor) # 64 x 64
29 | weight = np.zeros((in_channels, out_channels, kernel_size,
30 | kernel_size), dtype=np.float64)
31 | weight[range(in_channels), range(out_channels), :, :] = filt
32 |
33 | return torch.from_numpy(weight).float()
34 |
35 | class FCN32s(nn.Module):
36 | def __init__(self, pretrained_net, n_class):
37 | super().__init__()
38 | self.n_class = n_class
39 | self.pretrained_net = pretrained_net
40 | self.relu = nn.ReLU(inplace=True)
41 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
42 | self.bn1 = nn.BatchNorm2d(512)
43 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
44 | self.bn2 = nn.BatchNorm2d(256)
45 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
46 | self.bn3 = nn.BatchNorm2d(128)
47 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
48 | self.bn4 = nn.BatchNorm2d(64)
49 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
50 | self.bn5 = nn.BatchNorm2d(32)
51 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
52 |
53 | def forward(self, x):
54 | output = self.pretrained_net.forward(x)
55 | x5 = output['x5'] # size=[n, 512, x.h/32, x.w/32]
56 | score = self.bn1(self.relu(self.deconv1(x5))) # size=[n, 512, x.h/16, x.w/16]
57 | score = self.bn2(self.relu(self.deconv2(score))) # size=[n, 256, x.h/8, x.w/8]
58 | score = self.bn3(self.relu(self.deconv3(score))) # size=[n, 128, x.h/4, x.w/4]
59 | score = self.bn4(self.relu(self.deconv4(score))) # size=[n, 64, x.h/2, x.w/2]
60 | score = self.bn5(self.relu(self.deconv5(score))) # size=[n, 32, x.h, x.w]
61 | score = self.classifier(score) # size=[n, n_class, x.h, x.w]
62 |
63 | return score
64 |
65 |
66 | class FCN16s(nn.Module):
67 | def __init__(self, pretrained_net, n_class):
68 | super().__init__()
69 | self.n_class = n_class
70 | self.pretrained_net = pretrained_net
71 | self.relu = nn.ReLU(inplace=True)
72 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
73 | self.bn1 = nn.BatchNorm2d(512)
74 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
75 | self.bn2 = nn.BatchNorm2d(256)
76 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
77 | self.bn3 = nn.BatchNorm2d(128)
78 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
79 | self.bn4 = nn.BatchNorm2d(64)
80 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
81 | self.bn5 = nn.BatchNorm2d(32)
82 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
83 |
84 | def forward(self, x):
85 | output = self.pretrained_net.forward(x)
86 | x5 = output['x5'] # size=[n, 512, x.h/32, x.w/32]
87 | x4 = output['x4'] # size=[n, 512, x.h/16, x.w/16]
88 |
89 | score = self.relu(self.deconv1(x5)) # size=[n, 512, x.h/16, x.w/16]
90 | score = self.bn1(score + x4) # element-wise add, size=[n, 512, x.h/16, x.w/16]
91 | score = self.bn2(self.relu(self.deconv2(score))) # size=[n, 256, x.h/8, x.w/8]
92 | score = self.bn3(self.relu(self.deconv3(score))) # size=[n, 128, x.h/4, x.w/4]
93 | score = self.bn4(self.relu(self.deconv4(score))) # size=[n, 64, x.h/2, x.w/2]
94 | score = self.bn5(self.relu(self.deconv5(score))) # size=[n, 32, x.h, x.w]
95 | score = self.classifier(score) # size=[n, n_class, x.h, x.w]
96 |
97 | return score
98 |
99 |
100 | class FCN8s(nn.Module):
101 | def __init__(self, pretrained_net, n_class):
102 | super().__init__()
103 | self.n_class = n_class
104 | self.pretrained_net = pretrained_net
105 | self.relu = nn.ReLU(inplace=True)
106 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
107 | self.bn1 = nn.BatchNorm2d(512)
108 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
109 | self.bn2 = nn.BatchNorm2d(256)
110 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
111 | self.bn3 = nn.BatchNorm2d(128)
112 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
113 | self.bn4 = nn.BatchNorm2d(64)
114 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
115 | self.bn5 = nn.BatchNorm2d(32)
116 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
117 |
118 | # self._init_weights()
119 | #
120 | # 1
121 | init.xavier_uniform_(self.deconv1.weight)
122 | # 2
123 | init.xavier_uniform_(self.deconv2.weight)
124 | # 3
125 | init.xavier_uniform_(self.deconv3.weight)
126 | init.xavier_uniform_(self.deconv4.weight)
127 | init.xavier_uniform_(self.deconv5.weight)
128 | init.xavier_uniform_(self.classifier.weight)
129 |
130 | def forward(self, x):
131 | output = self.pretrained_net.forward(x)
132 | x5 = output['x5'] # size=[n, 512, x.h/32, x.w/32]
133 | x4 = output['x4'] # size=[n, 512, x.h/16, x.w/16]
134 | x3 = output['x3'] # size=[n, 512, x.h/8, x.w/8]
135 |
136 | score = self.relu(self.deconv1(x5)) # size=[n, 512, x.h/16, x.w/16]
137 | score = self.bn1(score + x4) # element-wise add, size=[n, 512, x.h/16, x.w/16]
138 | score = self.relu(self.deconv2(score)) # size=[n, 256, x.h/8, x.w/8]
139 | score = self.bn2(score+x3)
140 | score = self.bn3(self.relu(self.deconv3(score))) # size=[n, 128, x.h/4, x.w/4]
141 | score = self.bn4(self.relu(self.deconv4(score))) # size=[n, 64, x.h/2, x.w/2]
142 | score = self.bn5(self.relu(self.deconv5(score))) # size=[n, 32, x.h, x.w]
143 | score = self.classifier(score) # size=[n, n_class, x.h, x.w]
144 |
145 | return score
146 |
147 | def _init_weights(self):
148 | '''
149 | hide method, used just in class
150 | '''
151 | for m in self.modules():
152 | if isinstance(m, nn.Conv2d):
153 | m.weight.data.zero_()
154 | # if m.bias is not None:
155 | m.bias.data.zero_()
156 | if isinstance(m, nn.ConvTranspose2d):
157 | assert m.kernel_size[0] == m.kernel_size[1]
158 | initial_weight = get_upsample_weight(m.in_channels,
159 | m.out_channels, m.kernel_size[0])
160 | m.weight.data.copy_(initial_weight) # copy not = ?
161 |
162 |
163 | class FCN1s(nn.Module):
164 | def __init__(self, pretrained_net, n_class):
165 | super().__init__()
166 | self.n_class = n_class
167 | self.pretrained_net = pretrained_net
168 | self.relu = nn.ReLU(inplace=True)
169 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
170 | self.bn1 = nn.BatchNorm2d(512)
171 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
172 | self.bn2 = nn.BatchNorm2d(256)
173 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
174 | self.bn3 = nn.BatchNorm2d(128)
175 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
176 | self.bn4 = nn.BatchNorm2d(64)
177 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
178 | self.bn5 = nn.BatchNorm2d(32)
179 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
180 |
181 | def forward(self, x):
182 | output = self.pretrained_net.forward(x)
183 | x5 = output['x5'] # size=[n, 512, x.h/32, x.w/32]
184 | x4 = output['x4'] # size=[n, 512, x.h/16, x.w/16]
185 | x3 = output['x3'] # size=[n, 512, x.h/8, x.w/8]
186 | x2 = output['x2'] # size=[n, 512, x.h/4, x.w/4]
187 | x1 = output['x1'] # size=[n, 512, x.h/2, x.w/2]
188 |
189 | score = self.relu(self.deconv1(x5)) # size=[n, 512, x.h/16, x.w/16]
190 | score = self.bn1(score + x4) # element-wise add, size=[n, 512, x.h/16, x.w/16]
191 | score = self.relu(self.deconv2(score)) # size=[n, 256, x.h/8, x.w/8]
192 | score = self.bn2(score+x3)
193 | score = self.relu(self.deconv3(score)) # size=[n, 128, x.h/4, x.w/4]
194 | score = self.bn3(score+x2)
195 | score = self.relu(self.deconv4(score)) # size=[n, 64, x.h/2, x.w/2]
196 | score = self.bn4(score+x1)
197 | score = self.bn5(self.relu(self.deconv5(score))) # size=[n, 32, x.h, x.w]
198 | score = self.classifier(score) # size=[n, n_class, x.h, x.w]
199 |
200 | return score
201 |
202 | ranges = {
203 | 'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)),
204 | 'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
205 | 'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
206 | 'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
207 | }
208 |
209 | # cropped version from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
210 | cfg = {
211 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
212 | 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
213 | 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
214 | 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
215 | }
216 |
217 |
218 | def make_layers(cfg, batch_norm=False):
219 | """
220 | :param cfg: cfg['vgg16']
221 | :param batch_norm:
222 | :return: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 数字表示卷积 'M': 表示池化
223 | """
224 | layers = []
225 | in_channels = 3
226 | for v in cfg:
227 | if v == 'M':
228 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
229 | else:
230 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
231 | if batch_norm:
232 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
233 | else:
234 | layers += [conv2d, nn.ReLU(inplace=True)]
235 | in_channels = v
236 |
237 | return nn.Sequential(*layers)
238 |
239 |
240 | class VGGNet(VGG):
241 | def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
242 | super().__init__(make_layers(cfg[model]))
243 | self.ranges = ranges[model]
244 |
245 | if pretrained:
246 | vgg16 = models.vgg16(pretrained=False)
247 | vgg16.load_state_dict(torch.load('/home/yxk/.torch/models/vgg16-397923af.pth'))
248 | # exec("self.load_state_dict(models.%s(pretrained=True).state_dict())" % model)
249 |
250 | if not requires_grad:
251 | for param in super().parameters():
252 | param.requires_grad = False
253 |
254 | if remove_fc: # delete redundant fully-connected layer params, can save memory
255 | del self.classifier
256 |
257 | if show_params:
258 | for name, param in self.named_parameters():
259 | print(name, param.size())
260 |
261 | def forward(self, x):
262 | output = {}
263 | # get the output of each maxpooling layer (5 maxpool in VGG net)
264 | for idx in range(len(self.ranges)):
265 | for layer in range(self.ranges[idx][0], self.ranges[idx][1]):
266 | x = self.features[layer](x)
267 | output["x%d" % (idx+1)] = x
268 | return output
269 |
270 |
271 | # other models
272 | class UNetEnc(nn.Module):
273 |
274 | def __init__(self, in_channels, features, out_channels):
275 | super().__init__()
276 |
277 | self.up = nn.Sequential(
278 | nn.Conv2d(in_channels, features, 3),
279 | nn.ReLU(inplace=True),
280 | nn.Conv2d(features, features, 3),
281 | nn.ReLU(inplace=True),
282 | nn.ConvTranspose2d(features, out_channels, 2, stride=2),
283 | nn.ReLU(inplace=True),
284 | )
285 |
286 | def forward(self, x):
287 | return self.up(x)
288 |
289 |
290 | class UNetDec(nn.Module):
291 |
292 | def __init__(self, in_channels, out_channels, dropout=False):
293 | super().__init__()
294 |
295 | layers = [
296 | nn.Conv2d(in_channels, out_channels, 3),
297 | nn.ReLU(inplace=True),
298 | nn.Conv2d(out_channels, out_channels, 3),
299 | nn.ReLU(inplace=True),
300 | ]
301 | if dropout:
302 | layers += [nn.Dropout(.5)]
303 | layers += [nn.MaxPool2d(2, stride=2, ceil_mode=True)]
304 |
305 | self.down = nn.Sequential(*layers)
306 |
307 | def forward(self, x):
308 | return self.down(x)
309 |
310 |
311 | class UNet(nn.Module):
312 |
313 | def __init__(self, num_classes):
314 | super().__init__()
315 |
316 | self.dec1 = UNetDec(3, 64)
317 | self.dec2 = UNetDec(64, 128)
318 | self.dec3 = UNetDec(128, 256)
319 | self.dec4 = UNetDec(256, 512, dropout=True)
320 | self.center = nn.Sequential(
321 | nn.Conv2d(512, 1024, 3),
322 | nn.ReLU(inplace=True),
323 | nn.Conv2d(1024, 1024, 3),
324 | nn.ReLU(inplace=True),
325 | nn.Dropout(),
326 | nn.ConvTranspose2d(1024, 512, 2, stride=2),
327 | nn.ReLU(inplace=True),
328 | )
329 | self.enc4 = UNetEnc(1024, 512, 256)
330 | self.enc3 = UNetEnc(512, 256, 128)
331 | self.enc2 = UNetEnc(256, 128, 64)
332 | self.enc1 = nn.Sequential(
333 | nn.Conv2d(128, 64, 3),
334 | nn.ReLU(inplace=True),
335 | nn.Conv2d(64, 64, 3),
336 | nn.ReLU(inplace=True),
337 | )
338 | self.final = nn.Conv2d(64, num_classes, 1)
339 |
340 | def forward(self, x):
341 | dec1 = self.dec1(x)
342 | dec2 = self.dec2(dec1)
343 | dec3 = self.dec3(dec2)
344 | dec4 = self.dec4(dec3)
345 | center = self.center(dec4)
346 | enc4 = self.enc4(torch.cat([
347 | center, F.upsample_bilinear(dec4, center.size()[2:])], 1))
348 | enc3 = self.enc3(torch.cat([
349 | enc4, F.upsample_bilinear(dec3, enc4.size()[2:])], 1))
350 | enc2 = self.enc2(torch.cat([
351 | enc3, F.upsample_bilinear(dec2, enc3.size()[2:])], 1))
352 | enc1 = self.enc1(torch.cat([
353 | enc2, F.upsample_bilinear(dec1, enc2.size()[2:])], 1))
354 |
355 | return F.upsample_bilinear(self.final(enc1), x.size()[2:])
356 |
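# Note on the skip connections above: the 3x3 convolutions in UNetDec/UNetEnc
# are unpadded, so each skip tensor (dec1..dec4) is larger than the decoder
# tensor it is concatenated with; instead of the original paper's center-crop,
# the skip feature is bilinearly resized down to the decoder's size.
# Smoke-test sketch:
#   net = UNet(num_classes=21)
#   out = net(torch.randn(1, 3, 256, 256))  # -> torch.Size([1, 21, 256, 256])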
357 |
358 | class SegNetEnc(nn.Module):
359 |
360 | def __init__(self, in_channels, out_channels, num_layers):
361 | super().__init__()
362 |
363 | layers = [
364 | nn.UpsamplingBilinear2d(scale_factor=2),
365 | nn.Conv2d(in_channels, in_channels // 2, 3, padding=1),
366 | nn.BatchNorm2d(in_channels // 2),
367 | nn.ReLU(inplace=True),
368 | ]
369 | layers += [
370 | nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1),
371 | nn.BatchNorm2d(in_channels // 2),
372 | nn.ReLU(inplace=True),
373 | ] * num_layers
374 | layers += [
375 | nn.Conv2d(in_channels // 2, out_channels, 3, padding=1),
376 | nn.BatchNorm2d(out_channels),
377 | nn.ReLU(inplace=True),
378 | ]
379 | self.encode = nn.Sequential(*layers)
380 |
381 | def forward(self, x):
382 | return self.encode(x)
383 |
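# Caveat: `[...] * num_layers` repeats references to the *same* Conv2d and
# BatchNorm2d objects, so num_layers > 1 would apply one shared-weight block
# several times; SegNet below only passes 0 or 1, where this is harmless.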
384 |
385 | class SegNet(nn.Module):
386 |
387 | def __init__(self, num_classes):
388 | super().__init__()
389 |
390 | # should be vgg16bn but at the moment we have no pretrained bn models
391 | decoders = list(models.vgg16(pretrained=True).features.children())
392 |
393 | self.dec1 = nn.Sequential(*decoders[:5])
394 | self.dec2 = nn.Sequential(*decoders[5:10])
395 | self.dec3 = nn.Sequential(*decoders[10:17])
396 | self.dec4 = nn.Sequential(*decoders[17:24])
397 | self.dec5 = nn.Sequential(*decoders[24:])
398 |
399 |         # freezing the pretrained encoder convs gave better results here
400 |         for m in self.modules():
401 |             if isinstance(m, nn.Conv2d):
402 |                 for p in m.parameters(): p.requires_grad = False
403 |
404 | self.enc5 = SegNetEnc(512, 512, 1)
405 | self.enc4 = SegNetEnc(1024, 256, 1)
406 | self.enc3 = SegNetEnc(512, 128, 1)
407 | self.enc2 = SegNetEnc(256, 64, 0)
408 | self.enc1 = nn.Sequential(
409 | nn.UpsamplingBilinear2d(scale_factor=2),
410 | nn.Conv2d(128, 64, 3, padding=1),
411 | nn.BatchNorm2d(64),
412 | nn.ReLU(inplace=True),
413 | )
414 | self.final = nn.Conv2d(64, num_classes, 3, padding=1)
415 |
416 | def forward(self, x):
417 | dec1 = self.dec1(x)
418 | dec2 = self.dec2(dec1)
419 | dec3 = self.dec3(dec2)
420 | dec4 = self.dec4(dec3)
421 | dec5 = self.dec5(dec4)
422 | enc5 = self.enc5(dec5)
423 | enc4 = self.enc4(torch.cat([dec4, enc5], 1))
424 | enc3 = self.enc3(torch.cat([dec3, enc4], 1))
425 | enc2 = self.enc2(torch.cat([dec2, enc3], 1))
426 | enc1 = self.enc1(torch.cat([dec1, enc2], 1))
427 |
428 | return F.upsample_bilinear(self.final(enc1), x.size()[2:])
429 |
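# Channel bookkeeping for the decoder above: dec1..dec5 output 64/128/256/512/512
# channels (the vgg16 stages), so each torch.cat doubles the SegNetEnc input --
# e.g. cat([dec4, enc5]) is 512 + 512 = 1024 channels, matching
# SegNetEnc(1024, 256, 1) -- and SegNetEnc's first conv halves that width again.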
430 |
431 | class PSPDec(nn.Module):
432 |
433 | def __init__(self, in_features, out_features, downsize, upsize=60):
434 | super().__init__()
435 |
436 | self.features = nn.Sequential(
437 | nn.AvgPool2d(downsize, stride=downsize),
438 | nn.Conv2d(in_features, out_features, 1, bias=False),
439 | nn.BatchNorm2d(out_features, momentum=.95),
440 | nn.ReLU(inplace=True),
441 | nn.UpsamplingBilinear2d(upsize)
442 | )
443 |
444 | def forward(self, x):
445 | return self.features(x)
446 |
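# Each PSPDec is one level of the pyramid pooling module: with window = stride =
# downsize on a 60x60 trunk, the four branches below pool to 1x1, 2x2, 3x3 and
# 6x6 grids, reduce to 512 channels with a 1x1 conv, and bilinearly upsample
# back to 60x60 so they can be concatenated with the trunk.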
447 |
448 | class PSPNet(nn.Module):
449 |
450 | def __init__(self, num_classes):
451 | super().__init__()
452 |
453 | '''
454 | self.conv1 = nn.Sequential(
455 | nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),
456 | nn.BatchNorm2d(64, momentum=.95),
457 | nn.ReLU(inplace=True),
458 | nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False),
459 | nn.BatchNorm2d(64, momentum=.95),
460 | nn.ReLU(inplace=True),
461 | nn.Conv2d(64, 128, 3, stride=1, padding=1, bias=False),
462 | nn.BatchNorm2d(128, momentum=.95),
463 | nn.ReLU(inplace=True),
464 | nn.MaxPool2d(3, stride=2, padding=1),
465 | )
466 | '''
467 |
468 | resnet = models.resnet101(pretrained=True)
469 |
470 | self.conv1 = resnet.conv1
471 | self.layer1 = resnet.layer1
472 | self.layer2 = resnet.layer2
473 | self.layer3 = resnet.layer3
474 | self.layer4 = resnet.layer4
475 |
476 | for m in self.modules():
477 | if isinstance(m, nn.Conv2d):
478 | m.stride = 1
479 |                 for p in m.parameters(): p.requires_grad = False  # freeze the pretrained weights
480 |             if isinstance(m, nn.BatchNorm2d):
481 |                 for p in m.parameters(): p.requires_grad = False
482 |
483 | self.layer5a = PSPDec(2048, 512, 60)
484 | self.layer5b = PSPDec(2048, 512, 30)
485 | self.layer5c = PSPDec(2048, 512, 20)
486 | self.layer5d = PSPDec(2048, 512, 10)
487 |
488 | self.final = nn.Sequential(
489 |             nn.Conv2d(2048 + 4 * 512, 512, 3, padding=1, bias=False),  # trunk (2048) + four 512-channel pyramid branches
490 | nn.BatchNorm2d(512, momentum=.95),
491 | nn.ReLU(inplace=True),
492 | nn.Dropout(.1),
493 | nn.Conv2d(512, num_classes, 1),
494 | )
495 |
496 | def forward(self, x):
497 |         size = x.size()[2:]  # remember the input resolution for the final upsample
498 | x = self.conv1(x)
499 | print('conv1', x.size())
500 | x = self.layer1(x)
501 | print('layer1', x.size())
502 | x = self.layer2(x)
503 | print('layer2', x.size())
504 | x = self.layer3(x)
505 | print('layer3', x.size())
506 | x = self.layer4(x)
507 | print('layer4', x.size())
508 | x = self.final(torch.cat([
509 | x,
510 | self.layer5a(x),
511 | self.layer5b(x),
512 | self.layer5c(x),
513 | self.layer5d(x),
514 | ], 1))
515 | print('final', x.size())
516 |
517 |         return F.upsample_bilinear(x, size)
518 |
519 |
520 | if __name__ == "__main__":
521 | batch_size, n_class, h, w = 10, 20, 160, 160
522 |
523 | # test output size
524 | vgg_model = VGGNet(requires_grad=True)
525 |     input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
526 |     output = vgg_model(input)
527 |     assert output['x5'].size() == torch.Size([batch_size, 512, 5, 5])  # 160 / 2**5 = 5
528 |
529 | fcn_model = FCN32s(pretrained_net=vgg_model, n_class=n_class)
530 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
531 |     output = fcn_model(input)
532 | assert output.size() == torch.Size([batch_size, n_class, h, w])
533 |
534 | fcn_model = FCN16s(pretrained_net=vgg_model, n_class=n_class)
535 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
536 | output = fcn_model(input)
537 | assert output.size() == torch.Size([batch_size, n_class, h, w])
538 |
539 | fcn_model = FCN8s(pretrained_net=vgg_model, n_class=n_class)
540 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
541 | output = fcn_model(input)
542 | assert output.size() == torch.Size([batch_size, n_class, h, w])
543 |
544 | fcn_model = FCN1s(pretrained_net=vgg_model, n_class=n_class)
545 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
546 | output = fcn_model(input)
547 | assert output.size() == torch.Size([batch_size, n_class, h, w])
548 |
549 | # test a random batch, loss should decrease
550 | fcn_model = FCN1s(pretrained_net=vgg_model, n_class=n_class)
551 | criterion = nn.BCELoss()
552 | optimizer = optim.SGD(fcn_model.parameters(), lr=1e-3, momentum=0.9)
553 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
554 |     y = torch.autograd.Variable(torch.rand(batch_size, n_class, h, w), requires_grad=False)  # BCELoss targets must lie in [0, 1]
555 |     for step in range(10):
556 | optimizer.zero_grad()
557 | output = fcn_model(input)
558 |
559 |         output = torch.sigmoid(output)
560 |         loss = criterion(output, y)
561 | loss.backward()
562 |         print("iter {}, loss {}".format(step, loss.item()))
563 | optimizer.step()
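# Overfitting this single fixed batch is a quick sanity check: with the sigmoid +
# BCELoss pairing above and targets in [0, 1], the printed loss should fall
# steadily over the 10 iterations; an immediately flat loss usually points to a
# learning-rate or activation/loss mismatch.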
--------------------------------------------------------------------------------