├── .gitignore
├── Dataloader
│   ├── __init__.py
│   ├── baseloader.py
│   ├── camvid_loader.py
│   ├── citys_loader.py
│   ├── custom_loader.py
│   ├── seg11valid.txt
│   └── voc_loader.py
├── Models
│   ├── DeepLab_v1.py
│   ├── DeepLab_v2.py
│   ├── DeepLab_v3.py
│   ├── DeepLab_v3plus.py
│   ├── Dilation8.py
│   ├── FCN.py
│   ├── PSPNet.py
│   ├── SegNet.py
│   ├── UNet.py
│   └── __init__.py
├── README.md
├── augmentations.py
├── evaluate.py
├── learning_curve.py
├── loss.py
├── metrics.py
├── optimizer.py
├── preparation.py
├── requirements.txt
├── train.py
├── trainer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | logs
3 | .idea
4 |
--------------------------------------------------------------------------------
/Dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | from .custom_loader import CustomLoader
2 | from .voc_loader import VOCLoader, SBDLoader, VOC11Val
3 | from .citys_loader import CityscapesLoader
4 | from .camvid_loader import CamVidLoader
5 |
6 | VALID_DATASET = ['voc', 'cityscapes', 'sbd', 'voc11', 'camvid', 'custom']
7 |
8 |
9 | def get_loader(dataset_type):
10 | if dataset_type.lower() == 'custom':
11 | return CustomLoader
12 | elif dataset_type.lower() == 'voc':
13 | return VOCLoader
14 | elif dataset_type.lower() == 'cityscapes':
15 | return CityscapesLoader
16 | elif dataset_type.lower() == 'sbd':
17 | return SBDLoader
18 | elif dataset_type.lower() == 'voc11':
19 | return VOC11Val
20 | elif dataset_type.lower() == 'camvid':
21 | return CamVidLoader
22 | else:
23 |         raise ValueError('Unsupported dataset, '
24 |                          'valid datasets are as follows:\n{}\n'
25 |                          '(voc11 is only for evaluation)'.format(', '.join(VALID_DATASET)))
26 |
--------------------------------------------------------------------------------
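
Usage sketch (not part of the repository): get_loader returns the loader class itself, which is then instantiated and wrapped in a torch DataLoader. The dataset root below is a placeholder.

    from torch.utils.data import DataLoader
    from Dataloader import get_loader

    LoaderCls = get_loader('camvid')                    # -> CamVidLoader class
    train_set = LoaderCls(root='/path/to/CamVid',       # placeholder path
                          split='train', img_size=(360, 480))
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
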
/Dataloader/baseloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils import data
4 | from torchvision import transforms
5 |
6 |
7 | class BaseLoader(data.Dataset):
8 |     # specify class_names if available
9 |     class_names = None
10 |
11 | def __init__(self,
12 | root,
13 | n_classes,
14 | split='train',
15 | img_size=None,
16 | augmentations=None,
17 | ignore_index=None,
18 | class_weight=None,
19 | pretrained=False):
20 |
21 | self.root = root
22 | self.n_classes = n_classes
23 | self.split = split
24 | self.img_size = img_size
25 | self.augmentations = augmentations
26 | self.ignore_index = ignore_index
27 | self.class_weight = class_weight
28 |
29 | if pretrained:
30 |             # if a pretrained model is used, subtract the ImageNet mean and divide by the std
31 | self.mean = torch.tensor([0.485, 0.456, 0.406])
32 | self.std = torch.tensor([0.229, 0.224, 0.225])
33 | self.tf = transforms.Compose([
34 | transforms.ToTensor(),
35 | transforms.Normalize(self.mean.tolist(), self.std.tolist())
36 | ])
37 | self.untf = transforms.Compose([
38 | transforms.Normalize((-self.mean / self.std).tolist(),
39 | (1.0 / self.std).tolist())
40 | ])
41 | else:
42 |             # otherwise, only scale images to [0, 1]
43 | self.tf = transforms.Compose([transforms.ToTensor()])
44 | self.untf = transforms.Compose(
45 | [transforms.Normalize([0, 0, 0], [1, 1, 1])])
46 |
47 | def __getitem__(self, index):
48 |         raise NotImplementedError
49 |
50 | def transform(self, img, lbl):
51 | img = self.tf(img)
52 | lbl = np.array(lbl, dtype=np.int32)
53 | lbl[lbl == 255] = -1
54 |         if self.ignore_index is not None:
55 | lbl[lbl == self.ignore_index] = -1
56 | lbl = torch.from_numpy(lbl).long()
57 | return img, lbl
58 |
59 | def untransform(self, img, lbl):
60 | img = self.untf(img)
61 | img = img.numpy()
62 | img = img.transpose(1, 2, 0)
63 | img = img * 255
64 | img = img.astype(np.uint8)
65 | lbl = lbl.numpy()
66 | return img, lbl
67 |
68 | def getpalette(self):
69 |         raise NotImplementedError
70 |
--------------------------------------------------------------------------------
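
Note (added for clarity, not part of the repository): the untf transform is the algebraic inverse of the ImageNet normalization, since Normalize(-mean/std, 1/std) applied to (x - mean)/std returns x. A minimal self-contained check:

    import torch
    from torchvision import transforms

    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    tf = transforms.Normalize(mean.tolist(), std.tolist())
    untf = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

    x = torch.rand(3, 8, 8)                      # fake image tensor in [0, 1]
    assert torch.allclose(untf(tf(x)), x, atol=1e-6)
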
/Dataloader/camvid_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from .baseloader import BaseLoader
6 |
7 | class CamVidLoader(BaseLoader):
8 | """CamVid dataset loader.
9 | Parameters
10 | ----------
11 | root: path to CamVid dataset.
12 | n_classes: number of classes, default 11.
13 |     split: choose subset of dataset, 'train', 'val' or 'test'.
14 |     img_size: a list or a tuple, scale image to proper size.
15 |     augmentations: whether to perform augmentation.
16 |     ignore_index: pixels with this value are ignored during training and evaluation, default 11.
17 | class_weight: useful in unbalanced datasets.
18 | pretrained: whether to use pretrained models
19 | """
20 | class_names = np.array([
21 | 'sky',
22 | 'building',
23 | 'pole',
24 | 'road',
25 | 'pavement',
26 | 'tree',
27 | 'sign',
28 | 'fence',
29 | 'vehicle',
30 | 'pedestrian',
31 | 'bicyclist',
32 | 'void'
33 | ])
34 |
35 | def __init__(
36 | self,
37 | root,
38 | n_classes=11,
39 | split='train',
40 | img_size=None,
41 | augmentations=None,
42 | ignore_index=11,
43 | class_weight=None,
44 | pretrained=False
45 | ):
46 | super(CamVidLoader, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained)
47 |
48 | path = os.path.join(self.root, self.split + ".txt")
49 | with open(path, "r") as f:
50 | self.file_list = [file_name.rstrip() for file_name in f]
51 |
52 | self.class_weight = [0.2595, 0.1826, 4.5640, 0.1417,
53 | 0.9051, 0.3826, 9.6446, 1.8418,
54 |                              0.6823, 6.2478, 7.3614]
55 |
56 | print(f"Found {len(self.file_list)} {split} images")
57 |
58 | def __len__(self):
59 | return len(self.file_list)
60 |
61 | def __getitem__(self, index):
62 | img_name = self.file_list[index]
63 | img_name = img_name.split()[0].split('/')[-1]
64 | img_path = os.path.join(self.root, self.split, img_name)
65 | if self.split == 'train':
66 | lbl_path = os.path.join(self.root, 'trainannot', img_name)
67 | elif self.split == 'val':
68 | lbl_path = os.path.join(self.root, 'valannot', img_name)
69 | elif self.split == 'test':
70 | lbl_path = os.path.join(self.root, 'testannot', img_name)
71 |
72 | img = Image.open(img_path).convert('RGB')
73 | lbl = Image.open(lbl_path)
74 | if self.img_size:
75 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR)
76 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST)
77 | if self.augmentations:
78 | img, lbl = self.augmentations(img, lbl)
79 |
80 | img, lbl = self.transform(img, lbl)
81 | return img, lbl
82 |
83 | def getpalette(self):
84 | return np.asarray(
85 | [
86 | [128, 128, 128],
87 | [128, 0, 0],
88 | [192, 192, 128],
89 | [128, 64, 128],
90 | [0, 0, 192],
91 | [128, 128, 0],
92 | [192, 128, 128],
93 | [64, 64, 128],
94 | [64, 0, 128],
95 | [64, 64, 0],
96 | [0, 128, 192],
97 | ]
98 | )
99 |
100 |
101 | # Test code
102 | # if __name__ == '__main__':
103 | # from torch.utils.data import DataLoader
104 | # root = r'D:/Datasets/CamVid'
105 | # batch_size = 2
106 | # loader = CamVidLoader(root=root, img_size=None)
107 | # test_loader = DataLoader(loader, batch_size=batch_size, shuffle=True)
108 |
109 | # palette = test_loader.dataset.getpalette()
110 | # fig, axes = plt.subplots(batch_size, 2, subplot_kw={'xticks': [], 'yticks': []})
111 | # fig.subplots_adjust(left=0.03, right=0.97, hspace=0.2, wspace=0.05)
112 |
113 | # for imgs, labels in test_loader:
114 | # imgs = imgs.numpy()
115 | # imgs = np.transpose(imgs, [0,2,3,1])
116 | # labels = labels.numpy()
117 |
118 | # for i in range(batch_size):
119 | # axes[i][0].imshow(imgs[i])
120 |
121 | # mask_unlabeled = labels[i] == -1
122 | # viz_unlabeled = (
123 | # np.zeros((labels[i].shape[0], labels[i].shape[1], 3))
124 | # ).astype(np.uint8)
125 |
126 | # lbl_viz = palette[labels[i]]
127 | # lbl_viz[labels[i] == -1] = (0, 0, 0)
128 | # lbl_viz[mask_unlabeled] = viz_unlabeled[mask_unlabeled]
129 |
130 | # axes[i][1].imshow(lbl_viz.astype(np.uint8))
131 | # plt.show()
132 | # break
133 |
134 |
--------------------------------------------------------------------------------
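
Usage sketch (assumption, not part of the repository): the hard-coded class_weight list is presumably consumed by a weighted loss; ignore_index=-1 matches the value produced by BaseLoader.transform. The root path is a placeholder.

    import torch
    import torch.nn as nn
    from Dataloader import CamVidLoader

    dataset = CamVidLoader(root='/path/to/CamVid', split='train')   # placeholder path
    weight = torch.tensor(dataset.class_weight, dtype=torch.float32)
    criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=-1)
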
/Dataloader/citys_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from .baseloader import BaseLoader
6 | from collections import namedtuple
7 | from torch.utils import data
8 | from torchvision import transforms
9 |
10 |
11 | class CityscapesLoader(BaseLoader):
12 | """Cityscapes dataset loader.
13 | Parameters
14 | ----------
15 |     root: path to cityscapes dataset.
16 |         for directory:
17 |             --cityscapes---leftImg8bit---train/val/test
18 |                         |--gtFine-------train/val/test
19 |                         |- ...
20 |         root should be xxx/cityscapes
21 |     n_classes: number of classes, default 19.
22 |     split: choose subset of dataset, 'train', 'val' or 'test'.
23 |     img_size: scale image to proper size.
24 |     augmentations: whether to perform augmentation.
25 |     ignore_index: pixels with this value are ignored during training and evaluation, default 255.
26 |     class_weight: useful in unbalanced datasets.
27 |     pretrained: whether to use pretrained models.
28 | """
29 |
30 | CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id',
31 | 'ignore_in_eval', 'color'])
32 |
33 | classes = [
34 | # name id trainId ignoreInEval color
35 | CityscapesClass('unlabeled', 0, 255, True, (0, 0, 0)),
36 | CityscapesClass('ego vehicle', 1, 255, True, (0, 0, 0)),
37 | CityscapesClass('rectification border', 2, 255, True, (0, 0, 0)),
38 | CityscapesClass('out of roi', 3, 255, True, (0, 0, 0)),
39 | CityscapesClass('static', 4, 255, True, (0, 0, 0)),
40 | CityscapesClass('dynamic', 5, 255, True, (111, 74, 0)),
41 | CityscapesClass('ground', 6, 255, True, (81, 0, 81)),
42 | CityscapesClass('road', 7, 0, False, (128, 64, 128)),
43 | CityscapesClass('sidewalk', 8, 1, False, (244, 35, 232)),
44 | CityscapesClass('parking', 9, 255, True, (250, 170, 160)),
45 | CityscapesClass('rail track', 10, 255, True, (230, 150, 140)),
46 | CityscapesClass('building', 11, 2, False, (70, 70, 70)),
47 | CityscapesClass('wall', 12, 3, False, (102, 102, 156)),
48 | CityscapesClass('fence', 13, 4, False, (190, 153, 153)),
49 | CityscapesClass('guard rail', 14, 255, True, (180, 165, 180)),
50 | CityscapesClass('bridge', 15, 255, True, (150, 100, 100)),
51 | CityscapesClass('tunnel', 16, 255, True, (150, 120, 90)),
52 | CityscapesClass('pole', 17, 5, False, (153, 153, 153)),
53 | CityscapesClass('polegroup', 18, 255, True, (153, 153, 153)),
54 | CityscapesClass('traffic light', 19, 6, False, (250, 170, 30)),
55 | CityscapesClass('traffic sign', 20, 7, False, (220, 220, 0)),
56 | CityscapesClass('vegetation', 21, 8, False, (107, 142, 35)),
57 | CityscapesClass('terrain', 22, 9, False, (152, 251, 152)),
58 | CityscapesClass('sky', 23, 10, False, (70, 130, 180)),
59 | CityscapesClass('person', 24, 11, False, (220, 20, 60)),
60 | CityscapesClass('rider', 25, 12, False, (255, 0, 0)),
61 | CityscapesClass('car', 26, 13, False, (0, 0, 142)),
62 | CityscapesClass('truck', 27, 14, False, (0, 0, 70)),
63 | CityscapesClass('bus', 28, 15, False, (0, 60, 100)),
64 | CityscapesClass('caravan', 29, 255, True, (0, 0, 90)),
65 | CityscapesClass('trailer', 30, 255, True, (0, 0, 110)),
66 | CityscapesClass('train', 31, 16, False, (0, 80, 100)),
67 | CityscapesClass('motorcycle', 32, 17, False, (0, 0, 230)),
68 | CityscapesClass('bicycle', 33, 18, False, (119, 11, 32)),
69 | CityscapesClass('license plate', -1, -1, True, (0, 0, 142)),
70 | ]
71 |
72 | def __init__(
73 | self,
74 | root,
75 | n_classes=19,
76 | split="train",
77 | img_size=None,
78 | augmentations=None,
79 | ignore_index=255,
80 | class_weight=None,
81 | pretrained=False
82 | ):
83 | super(CityscapesLoader, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained)
84 |
85 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
86 | self.labels_dir = os.path.join(self.root, 'gtFine', split)
87 | self.images = []
88 | self.labels = []
89 |
90 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
91 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26,
92 | 27, 28, 31, 32, 33]
93 | self.class_map = dict(zip(self.valid_classes, range(self.n_classes)))
94 |
95 | for city in os.listdir(self.images_dir):
96 | img_dir = os.path.join(self.images_dir, city)
97 | label_dir = os.path.join(self.labels_dir, city)
98 | for file_name in os.listdir(img_dir):
99 | label_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
100 | 'gtFine_labelIds.png')
101 | self.images.append(os.path.join(img_dir, file_name))
102 | self.labels.append(os.path.join(label_dir, label_name))
103 |
104 | print(f"Found {len(self.images)} {split} images")
105 |
106 | def __len__(self):
107 | return len(self.images)
108 |
109 | def __getitem__(self, index):
110 | img = Image.open(self.images[index]).convert('RGB')
111 | lbl = Image.open(self.labels[index])
112 |
113 | if self.img_size:
114 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR)
115 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST)
116 |
117 | if self.augmentations:
118 | img, lbl = self.augmentations(img, lbl)
119 |
120 | img, lbl = self.transform(img, lbl)
121 | return img, lbl
122 |
123 | def transform(self, img, lbl):
124 | img = self.tf(img)
125 |
126 | lbl = np.array(lbl, dtype=np.int32)
127 | lbl = self.encode_segmap(lbl)
128 | lbl = torch.from_numpy(lbl).long()
129 | return img, lbl
130 |
131 | def getpalette(self):
132 | return np.array([
133 | [128, 64, 128],
134 | [244, 35, 232],
135 | [70, 70, 70],
136 | [102, 102, 156],
137 | [190, 153, 153],
138 | [153, 153, 153],
139 | [250, 170, 30],
140 | [220, 220, 0],
141 | [107, 142, 35],
142 | [152, 251, 152],
143 |             [70, 130, 180],
144 | [220, 20, 60],
145 | [255, 0, 0],
146 | [0, 0, 142],
147 | [0, 0, 70],
148 | [0, 60, 100],
149 | [0, 80, 100],
150 | [0, 0, 230],
151 | [119, 11, 32]
152 | ])
153 |
154 |     def decode_segmap(self, label_mask):
155 | label_colours = self.getpalette()
156 | r = label_mask.copy()
157 | g = label_mask.copy()
158 | b = label_mask.copy()
159 | for ll in range(0, self.n_classes):
160 | r[label_mask == ll] = label_colours[ll, 0]
161 | g[label_mask == ll] = label_colours[ll, 1]
162 | b[label_mask == ll] = label_colours[ll, 2]
163 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
164 | rgb[:, :, 0] = r / 255.0
165 | rgb[:, :, 1] = g / 255.0
166 | rgb[:, :, 2] = b / 255.0
167 |
168 | return rgb
169 |
170 | def encode_segmap(self, mask):
171 | # Put all void classes to -1
172 | for _voidc in self.void_classes:
173 | mask[mask == _voidc] = -1
174 | for _validc in self.valid_classes:
175 | mask[mask == _validc] = self.class_map[_validc]
176 | return mask
177 |
178 |
179 | if __name__ == "__main__":
180 | import matplotlib.pyplot as plt
181 |
182 | local_path = "/home/ecust/zww/DANet/datasets/cityscapes"
183 |     dst = CityscapesLoader(local_path)
184 | bs = 4
185 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
186 | for i, data_samples in enumerate(trainloader):
187 | imgs, labels = data_samples
188 |
189 | plt.subplots(1, 1)
190 | for j in range(1):
191 | plt.subplot(1, 2, j + 1)
192 | plt.imshow(np.transpose(imgs.numpy()[j], [1, 2, 0]))
193 | plt.subplot(1, 2, j + 2)
194 | plt.imshow(dst.decode_segmap(labels.numpy()[j]))
195 | plt.show()
196 |
--------------------------------------------------------------------------------
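
Illustration (not part of the repository): encode_segmap maps raw Cityscapes label ids to the 19 train ids and sends every void id to -1. The root path is a placeholder that must point at a real dataset for the loader to index files.

    import numpy as np
    from Dataloader import CityscapesLoader

    loader = CityscapesLoader(root='/path/to/cityscapes', split='val')  # placeholder path
    raw = np.array([[7, 8, 26, 0]], dtype=np.int32)   # road, sidewalk, car, unlabeled
    print(loader.encode_segmap(raw))                  # [[ 0  1 13 -1]]
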
/Dataloader/custom_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from .baseloader import BaseLoader
5 | from PIL import Image
6 |
7 |
8 | class CustomLoader(BaseLoader):
9 | """Custom dataset loader.
10 | Parameters
11 | ----------
12 | root: path to custom dataset, with train.txt and val.txt together.
13 | i.e., -----dataset
14 | |--train.txt
15 | |--val.txt
16 | n_classes: number of classes.
17 |     split: choose subset of dataset, 'train', 'val' or 'test'.
18 |     img_size: scale image to proper size.
19 |     augmentations: whether to perform augmentation.
20 |     ignore_index: pixels with this value are ignored during training and evaluation.
21 | class_weight: useful in unbalanced datasets.
22 | pretrained: whether to use pretrained models
23 | """
24 | # specify class_names if necessary
25 | class_names = None
26 |
27 | def __init__(
28 | self,
29 | root,
30 | n_classes,
31 | split="train",
32 | img_size=None,
33 | augmentations=None,
34 | ignore_index=None,
35 | class_weight=None,
36 | pretrained=False
37 | ):
38 | super(CustomLoader, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained)
39 |
40 | path = os.path.join(self.root, split + ".txt")
41 | with open(path, "r") as f:
42 | self.file_list = [file_name.rstrip().split() for file_name in f]
43 |
44 | print(f"Found {len(self.file_list)} {split} images")
45 |
46 | def __len__(self):
47 | return len(self.file_list)
48 |
49 | def __getitem__(self, index):
50 | img_name = self.file_list[index][0]
51 | lbl_name = self.file_list[index][1]
52 |
53 | img = Image.open(os.path.join(self.root, img_name)).convert('RGB')
54 | lbl = Image.open(os.path.join(self.root, lbl_name))
55 |
56 | if self.img_size:
57 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR)
58 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST)
59 |
60 | if self.augmentations:
61 | img, lbl = self.augmentations(img, lbl)
62 |
63 | img, lbl = self.transform(img, lbl)
64 | return img, lbl
65 |
66 | def getpalette(self):
67 | """for custom palette, if not specified, use pascal voc palette by default.
68 | """
69 | n = self.n_classes
70 | palette = [0]*(n*3)
71 | for j in range(0, n):
72 | lab = j
73 | palette[j*3+0] = 0
74 | palette[j*3+1] = 0
75 | palette[j*3+2] = 0
76 | i = 0
77 | while (lab > 0):
78 | palette[j*3+0] |= (((lab >> 0) & 1) << (7-i))
79 | palette[j*3+1] |= (((lab >> 1) & 1) << (7-i))
80 | palette[j*3+2] |= (((lab >> 2) & 1) << (7-i))
81 | i = i + 1
82 | lab >>= 3
83 | palette = np.array(palette).reshape([-1, 3]).astype(np.uint8)
84 | return palette
85 |
86 |
87 | # Test code
88 | # if __name__ == '__main__':
89 | # from torch.utils.data import DataLoader
90 | # root = ''
91 | # batch_size = 2
92 | # loader = CustomLoader(root=root, img_size=None)
93 | # test_loader = DataLoader(loader, batch_size=batch_size, shuffle=True)
94 |
95 | # palette = test_loader.dataset.getpalette()
96 | # fig, axes = plt.subplots(batch_size, 2, subplot_kw={'xticks': [], 'yticks': []})
97 | # fig.subplots_adjust(left=0.03, right=0.97, hspace=0.2, wspace=0.05)
98 |
99 | # for imgs, labels in test_loader:
100 | # imgs = imgs.numpy()
101 | # imgs = np.transpose(imgs, [0,2,3,1])
102 | # labels = labels.numpy()
103 |
104 | # for i in range(batch_size):
105 | # axes[i][0].imshow(imgs[i])
106 |
107 | # mask_unlabeled = labels[i] == -1
108 | # viz_unlabeled = (
109 | # np.zeros((labels[i].shape[0], labels[i].shape[1], 3))
110 | # ).astype(np.uint8)
111 |
112 | # lbl_viz = palette[labels[i]]
113 | # lbl_viz[labels[i] == -1] = (0, 0, 0)
114 | # lbl_viz[mask_unlabeled] = viz_unlabeled[mask_unlabeled]
115 |
116 | # axes[i][1].imshow(lbl_viz.astype(np.uint8))
117 | # plt.show()
118 | # break
119 |
120 |
--------------------------------------------------------------------------------
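
Note (hypothetical file names, not part of the repository): CustomLoader expects each line of train.txt / val.txt to contain an image path and a label path, relative to root and separated by whitespace, e.g.:

    images/0001.jpg  masks/0001.png
    images/0002.jpg  masks/0002.png
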
/Dataloader/seg11valid.txt:
--------------------------------------------------------------------------------
1 | 2007_000033
2 | 2007_000042
3 | 2007_000061
4 | 2007_000123
5 | 2007_000129
6 | 2007_000175
7 | 2007_000187
8 | 2007_000323
9 | 2007_000332
10 | 2007_000346
11 | 2007_000452
12 | 2007_000464
13 | 2007_000491
14 | 2007_000529
15 | 2007_000559
16 | 2007_000572
17 | 2007_000629
18 | 2007_000636
19 | 2007_000661
20 | 2007_000663
21 | 2007_000676
22 | 2007_000727
23 | 2007_000762
24 | 2007_000783
25 | 2007_000799
26 | 2007_000804
27 | 2007_000830
28 | 2007_000837
29 | 2007_000847
30 | 2007_000862
31 | 2007_000925
32 | 2007_000999
33 | 2007_001154
34 | 2007_001175
35 | 2007_001239
36 | 2007_001284
37 | 2007_001288
38 | 2007_001289
39 | 2007_001299
40 | 2007_001311
41 | 2007_001321
42 | 2007_001377
43 | 2007_001408
44 | 2007_001423
45 | 2007_001430
46 | 2007_001457
47 | 2007_001458
48 | 2007_001526
49 | 2007_001568
50 | 2007_001585
51 | 2007_001586
52 | 2007_001587
53 | 2007_001594
54 | 2007_001630
55 | 2007_001677
56 | 2007_001678
57 | 2007_001717
58 | 2007_001733
59 | 2007_001761
60 | 2007_001763
61 | 2007_001774
62 | 2007_001884
63 | 2007_001955
64 | 2007_002046
65 | 2007_002094
66 | 2007_002119
67 | 2007_002132
68 | 2007_002260
69 | 2007_002266
70 | 2007_002268
71 | 2007_002284
72 | 2007_002376
73 | 2007_002378
74 | 2007_002387
75 | 2007_002400
76 | 2007_002412
77 | 2007_002426
78 | 2007_002427
79 | 2007_002445
80 | 2007_002470
81 | 2007_002539
82 | 2007_002565
83 | 2007_002597
84 | 2007_002618
85 | 2007_002619
86 | 2007_002624
87 | 2007_002643
88 | 2007_002648
89 | 2007_002719
90 | 2007_002728
91 | 2007_002823
92 | 2007_002824
93 | 2007_002852
94 | 2007_002903
95 | 2007_003011
96 | 2007_003020
97 | 2007_003022
98 | 2007_003051
99 | 2007_003088
100 | 2007_003101
101 | 2007_003106
102 | 2007_003110
103 | 2007_003131
104 | 2007_003134
105 | 2007_003137
106 | 2007_003143
107 | 2007_003169
108 | 2007_003188
109 | 2007_003194
110 | 2007_003195
111 | 2007_003201
112 | 2007_003349
113 | 2007_003367
114 | 2007_003373
115 | 2007_003499
116 | 2007_003503
117 | 2007_003506
118 | 2007_003530
119 | 2007_003571
120 | 2007_003587
121 | 2007_003611
122 | 2007_003621
123 | 2007_003682
124 | 2007_003711
125 | 2007_003714
126 | 2007_003742
127 | 2007_003786
128 | 2007_003841
129 | 2007_003848
130 | 2007_003861
131 | 2007_003872
132 | 2007_003917
133 | 2007_003957
134 | 2007_003991
135 | 2007_004033
136 | 2007_004052
137 | 2007_004112
138 | 2007_004121
139 | 2007_004143
140 | 2007_004189
141 | 2007_004190
142 | 2007_004193
143 | 2007_004241
144 | 2007_004275
145 | 2007_004281
146 | 2007_004380
147 | 2007_004392
148 | 2007_004405
149 | 2007_004468
150 | 2007_004483
151 | 2007_004510
152 | 2007_004538
153 | 2007_004558
154 | 2007_004644
155 | 2007_004649
156 | 2007_004712
157 | 2007_004722
158 | 2007_004856
159 | 2007_004866
160 | 2007_004902
161 | 2007_004969
162 | 2007_005058
163 | 2007_005074
164 | 2007_005107
165 | 2007_005114
166 | 2007_005149
167 | 2007_005173
168 | 2007_005281
169 | 2007_005294
170 | 2007_005296
171 | 2007_005304
172 | 2007_005331
173 | 2007_005354
174 | 2007_005358
175 | 2007_005428
176 | 2007_005460
177 | 2007_005469
178 | 2007_005509
179 | 2007_005547
180 | 2007_005600
181 | 2007_005608
182 | 2007_005626
183 | 2007_005689
184 | 2007_005696
185 | 2007_005705
186 | 2007_005759
187 | 2007_005803
188 | 2007_005813
189 | 2007_005828
190 | 2007_005844
191 | 2007_005845
192 | 2007_005857
193 | 2007_005911
194 | 2007_005915
195 | 2007_005978
196 | 2007_006028
197 | 2007_006035
198 | 2007_006046
199 | 2007_006076
200 | 2007_006086
201 | 2007_006117
202 | 2007_006171
203 | 2007_006241
204 | 2007_006260
205 | 2007_006277
206 | 2007_006348
207 | 2007_006364
208 | 2007_006373
209 | 2007_006444
210 | 2007_006449
211 | 2007_006549
212 | 2007_006553
213 | 2007_006560
214 | 2007_006647
215 | 2007_006678
216 | 2007_006680
217 | 2007_006698
218 | 2007_006761
219 | 2007_006802
220 | 2007_006837
221 | 2007_006841
222 | 2007_006864
223 | 2007_006866
224 | 2007_006946
225 | 2007_007007
226 | 2007_007084
227 | 2007_007109
228 | 2007_007130
229 | 2007_007165
230 | 2007_007168
231 | 2007_007195
232 | 2007_007196
233 | 2007_007203
234 | 2007_007211
235 | 2007_007235
236 | 2007_007341
237 | 2007_007414
238 | 2007_007417
239 | 2007_007470
240 | 2007_007477
241 | 2007_007493
242 | 2007_007498
243 | 2007_007524
244 | 2007_007534
245 | 2007_007624
246 | 2007_007651
247 | 2007_007688
248 | 2007_007748
249 | 2007_007795
250 | 2007_007810
251 | 2007_007815
252 | 2007_007818
253 | 2007_007836
254 | 2007_007849
255 | 2007_007881
256 | 2007_007996
257 | 2007_008051
258 | 2007_008084
259 | 2007_008106
260 | 2007_008110
261 | 2007_008204
262 | 2007_008222
263 | 2007_008256
264 | 2007_008260
265 | 2007_008339
266 | 2007_008374
267 | 2007_008415
268 | 2007_008430
269 | 2007_008543
270 | 2007_008547
271 | 2007_008596
272 | 2007_008645
273 | 2007_008670
274 | 2007_008708
275 | 2007_008722
276 | 2007_008747
277 | 2007_008802
278 | 2007_008815
279 | 2007_008897
280 | 2007_008944
281 | 2007_008964
282 | 2007_008973
283 | 2007_008980
284 | 2007_009015
285 | 2007_009068
286 | 2007_009084
287 | 2007_009088
288 | 2007_009096
289 | 2007_009221
290 | 2007_009245
291 | 2007_009251
292 | 2007_009252
293 | 2007_009258
294 | 2007_009320
295 | 2007_009323
296 | 2007_009331
297 | 2007_009346
298 | 2007_009392
299 | 2007_009413
300 | 2007_009419
301 | 2007_009446
302 | 2007_009458
303 | 2007_009521
304 | 2007_009562
305 | 2007_009592
306 | 2007_009654
307 | 2007_009655
308 | 2007_009684
309 | 2007_009687
310 | 2007_009691
311 | 2007_009706
312 | 2007_009750
313 | 2007_009756
314 | 2007_009764
315 | 2007_009794
316 | 2007_009817
317 | 2007_009841
318 | 2007_009897
319 | 2007_009911
320 | 2007_009923
321 | 2007_009938
322 | 2008_000073
323 | 2008_000075
324 | 2008_000107
325 | 2008_000123
326 | 2008_000149
327 | 2008_000213
328 | 2008_000215
329 | 2008_000223
330 | 2008_000233
331 | 2008_000239
332 | 2008_000271
333 | 2008_000345
334 | 2008_000391
335 | 2008_000401
336 | 2008_000501
337 | 2008_000533
338 | 2008_000573
339 | 2008_000589
340 | 2008_000657
341 | 2008_000661
342 | 2008_000725
343 | 2008_000731
344 | 2008_000763
345 | 2008_000765
346 | 2008_000811
347 | 2008_000853
348 | 2008_000911
349 | 2008_000919
350 | 2008_000943
351 | 2008_001135
352 | 2008_001231
353 | 2008_001249
354 | 2008_001379
355 | 2008_001433
356 | 2008_001439
357 | 2008_001513
358 | 2008_001531
359 | 2008_001547
360 | 2008_001715
361 | 2008_001821
362 | 2008_001885
363 | 2008_001971
364 | 2008_002043
365 | 2008_002205
366 | 2008_002239
367 | 2008_002269
368 | 2008_002273
369 | 2008_002379
370 | 2008_002383
371 | 2008_002467
372 | 2008_002521
373 | 2008_002623
374 | 2008_002681
375 | 2008_002775
376 | 2008_002835
377 | 2008_002859
378 | 2008_003105
379 | 2008_003135
380 | 2008_003155
381 | 2008_003369
382 | 2008_003709
383 | 2008_003777
384 | 2008_003821
385 | 2008_003885
386 | 2008_004069
387 | 2008_004172
388 | 2008_004175
389 | 2008_004279
390 | 2008_004339
391 | 2008_004345
392 | 2008_004363
393 | 2008_004453
394 | 2008_004562
395 | 2008_004575
396 | 2008_004621
397 | 2008_004659
398 | 2008_004705
399 | 2008_004995
400 | 2008_005049
401 | 2008_005097
402 | 2008_005105
403 | 2008_005145
404 | 2008_005217
405 | 2008_005262
406 | 2008_005439
407 | 2008_005525
408 | 2008_005633
409 | 2008_005637
410 | 2008_005691
411 | 2008_006055
412 | 2008_006229
413 | 2008_006327
414 | 2008_006553
415 | 2008_006835
416 | 2008_007025
417 | 2008_007031
418 | 2008_007123
419 | 2008_007497
420 | 2008_007677
421 | 2008_007797
422 | 2008_007811
423 | 2008_008051
424 | 2008_008103
425 | 2008_008301
426 | 2009_000013
427 | 2009_000022
428 | 2009_000032
429 | 2009_000037
430 | 2009_000039
431 | 2009_000087
432 | 2009_000121
433 | 2009_000149
434 | 2009_000201
435 | 2009_000205
436 | 2009_000219
437 | 2009_000335
438 | 2009_000351
439 | 2009_000387
440 | 2009_000391
441 | 2009_000446
442 | 2009_000455
443 | 2009_000457
444 | 2009_000469
445 | 2009_000487
446 | 2009_000523
447 | 2009_000619
448 | 2009_000641
449 | 2009_000675
450 | 2009_000705
451 | 2009_000723
452 | 2009_000727
453 | 2009_000771
454 | 2009_000845
455 | 2009_000879
456 | 2009_000919
457 | 2009_000931
458 | 2009_000935
459 | 2009_000989
460 | 2009_000991
461 | 2009_001255
462 | 2009_001299
463 | 2009_001333
464 | 2009_001363
465 | 2009_001391
466 | 2009_001411
467 | 2009_001433
468 | 2009_001505
469 | 2009_001535
470 | 2009_001565
471 | 2009_001607
472 | 2009_001663
473 | 2009_001683
474 | 2009_001687
475 | 2009_001731
476 | 2009_001775
477 | 2009_001851
478 | 2009_001941
479 | 2009_002035
480 | 2009_002165
481 | 2009_002171
482 | 2009_002221
483 | 2009_002291
484 | 2009_002295
485 | 2009_002317
486 | 2009_002445
487 | 2009_002487
488 | 2009_002521
489 | 2009_002527
490 | 2009_002535
491 | 2009_002539
492 | 2009_002549
493 | 2009_002571
494 | 2009_002573
495 | 2009_002591
496 | 2009_002635
497 | 2009_002649
498 | 2009_002651
499 | 2009_002727
500 | 2009_002749
501 | 2009_002753
502 | 2009_002771
503 | 2009_002887
504 | 2009_002975
505 | 2009_003003
506 | 2009_003005
507 | 2009_003059
508 | 2009_003063
509 | 2009_003065
510 | 2009_003071
511 | 2009_003105
512 | 2009_003123
513 | 2009_003193
514 | 2009_003269
515 | 2009_003273
516 | 2009_003311
517 | 2009_003323
518 | 2009_003343
519 | 2009_003387
520 | 2009_003481
521 | 2009_003517
522 | 2009_003523
523 | 2009_003549
524 | 2009_003551
525 | 2009_003589
526 | 2009_003607
527 | 2009_003703
528 | 2009_003707
529 | 2009_003771
530 | 2009_003849
531 | 2009_003857
532 | 2009_003895
533 | 2009_004021
534 | 2009_004033
535 | 2009_004043
536 | 2009_004099
537 | 2009_004125
538 | 2009_004217
539 | 2009_004255
540 | 2009_004455
541 | 2009_004507
542 | 2009_004509
543 | 2009_004579
544 | 2009_004581
545 | 2009_004687
546 | 2009_004801
547 | 2009_004859
548 | 2009_004867
549 | 2009_004895
550 | 2009_004969
551 | 2009_004993
552 | 2009_005087
553 | 2009_005089
554 | 2009_005137
555 | 2009_005189
556 | 2009_005217
557 | 2009_005219
558 | 2010_000003
559 | 2010_000065
560 | 2010_000083
561 | 2010_000159
562 | 2010_000163
563 | 2010_000309
564 | 2010_000427
565 | 2010_000559
566 | 2010_000573
567 | 2010_000639
568 | 2010_000683
569 | 2010_000907
570 | 2010_000961
571 | 2010_001017
572 | 2010_001061
573 | 2010_001069
574 | 2010_001149
575 | 2010_001151
576 | 2010_001251
577 | 2010_001313
578 | 2010_001327
579 | 2010_001331
580 | 2010_001553
581 | 2010_001557
582 | 2010_001563
583 | 2010_001577
584 | 2010_001579
585 | 2010_001767
586 | 2010_001773
587 | 2010_001851
588 | 2010_001995
589 | 2010_002017
590 | 2010_002025
591 | 2010_002137
592 | 2010_002147
593 | 2010_002161
594 | 2010_002271
595 | 2010_002305
596 | 2010_002361
597 | 2010_002531
598 | 2010_002623
599 | 2010_002693
600 | 2010_002701
601 | 2010_002763
602 | 2010_002921
603 | 2010_002929
604 | 2010_002939
605 | 2010_003123
606 | 2010_003187
607 | 2010_003207
608 | 2010_003239
609 | 2010_003275
610 | 2010_003325
611 | 2010_003365
612 | 2010_003381
613 | 2010_003409
614 | 2010_003453
615 | 2010_003473
616 | 2010_003495
617 | 2010_003531
618 | 2010_003547
619 | 2010_003675
620 | 2010_003781
621 | 2010_003813
622 | 2010_003915
623 | 2010_003971
624 | 2010_004041
625 | 2010_004063
626 | 2010_004149
627 | 2010_004165
628 | 2010_004219
629 | 2010_004355
630 | 2010_004419
631 | 2010_004479
632 | 2010_004529
633 | 2010_004543
634 | 2010_004551
635 | 2010_004559
636 | 2010_004697
637 | 2010_004763
638 | 2010_004783
639 | 2010_004795
640 | 2010_004815
641 | 2010_004825
642 | 2010_005013
643 | 2010_005021
644 | 2010_005063
645 | 2010_005159
646 | 2010_005187
647 | 2010_005245
648 | 2010_005305
649 | 2010_005421
650 | 2010_005531
651 | 2010_005705
652 | 2010_005709
653 | 2010_005719
654 | 2010_005727
655 | 2010_005871
656 | 2010_005877
657 | 2010_005899
658 | 2010_005991
659 | 2011_000045
660 | 2011_000051
661 | 2011_000173
662 | 2011_000185
663 | 2011_000291
664 | 2011_000419
665 | 2011_000435
666 | 2011_000455
667 | 2011_000479
668 | 2011_000503
669 | 2011_000521
670 | 2011_000536
671 | 2011_000598
672 | 2011_000607
673 | 2011_000661
674 | 2011_000669
675 | 2011_000747
676 | 2011_000789
677 | 2011_000809
678 | 2011_000843
679 | 2011_000969
680 | 2011_001069
681 | 2011_001071
682 | 2011_001161
683 | 2011_001263
684 | 2011_001281
685 | 2011_001287
686 | 2011_001313
687 | 2011_001341
688 | 2011_001421
689 | 2011_001447
690 | 2011_001529
691 | 2011_001567
692 | 2011_001589
693 | 2011_001597
694 | 2011_001601
695 | 2011_001607
696 | 2011_001613
697 | 2011_001619
698 | 2011_001665
699 | 2011_001669
700 | 2011_001713
701 | 2011_001745
702 | 2011_001775
703 | 2011_001793
704 | 2011_001812
705 | 2011_001868
706 | 2011_001984
707 | 2011_002041
708 | 2011_002121
709 | 2011_002223
710 | 2011_002279
711 | 2011_002295
712 | 2011_002317
713 | 2011_002327
714 | 2011_002343
715 | 2011_002371
716 | 2011_002379
717 | 2011_002391
718 | 2011_002509
719 | 2011_002535
720 | 2011_002575
721 | 2011_002589
722 | 2011_002623
723 | 2011_002641
724 | 2011_002675
725 | 2011_002685
726 | 2011_002713
727 | 2011_002863
728 | 2011_002929
729 | 2011_002993
730 | 2011_002997
731 | 2011_003011
732 | 2011_003055
733 | 2011_003085
734 | 2011_003145
735 | 2011_003197
736 | 2011_003271
737 |
--------------------------------------------------------------------------------
/Dataloader/voc_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from scipy.io import loadmat
6 | from .baseloader import BaseLoader
7 |
8 |
9 | class VOCLoader(BaseLoader):
10 | """PASCAL VOC dataset loader.
11 | Parameters
12 | ----------
13 | root: path to pascal voc dataset.
14 | for directory:
15 | --VOCdevkit--VOC2012---ImageSets
16 | |-JPEGImages
17 | |- ...
18 | root should be xxx/VOCdevkit/VOC2012
19 | n_classes: number of classes, default 21.
20 |     split: choose subset of dataset, 'train', 'val' or 'trainval'.
21 |     img_size: scale image to proper size.
22 |     augmentations: whether to perform augmentation.
23 |     ignore_index: pixels with this value are ignored during training and evaluation, default 255.
24 | class_weight: useful in unbalanced datasets.
25 | pretrained: whether to use pretrained models
26 | """
27 | class_names = np.array([
28 | 'background', 'aeroplane', 'bicycle',
29 | 'bird', 'boat', 'bottle', 'bus',
30 | 'car', 'cat', 'chair', 'cow', 'diningtable',
31 | 'dog', 'horse', 'motorbike', 'person',
32 | 'potted plant', 'sheep', 'sofa', 'train',
33 | 'tv/monitor',
34 | ])
35 |
36 | def __init__(
37 | self,
38 | root,
39 | n_classes=21,
40 | split="train",
41 | img_size=None,
42 | augmentations=None,
43 | ignore_index=255,
44 | class_weight=None,
45 | pretrained=False
46 | ):
47 | super(VOCLoader, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained)
48 |
49 | path = os.path.join(self.root, "ImageSets/Segmentation", split + ".txt")
50 | with open(path, "r") as f:
51 | self.file_list = [file_name.rstrip() for file_name in f]
52 |
53 | print(f"Found {len(self.file_list)} {split} images")
54 |
55 | def __len__(self):
56 | return len(self.file_list)
57 |
58 | def __getitem__(self, index):
59 | img_name = self.file_list[index]
60 | img_path = os.path.join(self.root, "JPEGImages", img_name + ".jpg")
61 | lbl_path = os.path.join(self.root, "SegmentationClass", img_name + ".png")
62 | img = Image.open(img_path).convert('RGB')
63 | lbl = Image.open(lbl_path)
64 | if self.img_size:
65 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR)
66 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST)
67 | if self.augmentations:
68 | img, lbl = self.augmentations(img, lbl)
69 |
70 | img, lbl = self.transform(img, lbl)
71 | return img, lbl
72 |
73 | def getpalette(self):
74 | n = self.n_classes
75 | palette = [0]*(n*3)
76 | for j in range(0, n):
77 | lab = j
78 | palette[j*3+0] = 0
79 | palette[j*3+1] = 0
80 | palette[j*3+2] = 0
81 | i = 0
82 | while (lab > 0):
83 | palette[j*3+0] |= (((lab >> 0) & 1) << (7-i))
84 | palette[j*3+1] |= (((lab >> 1) & 1) << (7-i))
85 | palette[j*3+2] |= (((lab >> 2) & 1) << (7-i))
86 | i = i + 1
87 | lab >>= 3
88 | palette = np.array(palette).reshape([-1, 3]).astype(np.uint8)
89 | return palette
90 |
91 | class SBDLoader(BaseLoader):
92 | """Semantic Boundaries Dataset(SBD) dataset loader.
93 | Parameters
94 | ----------
95 | root: path to SBD dataset.
96 | for directory:
97 | --benchmark_RELEASE--dataset---img
98 | |-cls
99 | |-train.txt
100 | |- ...
101 | root should be xxx/benchmark_RELEASE
102 | n_classes: number of classes, default 21.
103 | split: choose subset of dataset, 'train' or 'val'.
104 | img_size: scale image to proper size.
105 | augmentations: whether to perform augmentation.
106 |     ignore_index: pixels with this value are ignored during training and evaluation, default 255.
107 | class_weight: useful in unbalanced datasets.
108 | pretrained: whether to use pretrained models
109 | """
110 | class_names = np.array([
111 | 'background', 'aeroplane', 'bicycle',
112 | 'bird', 'boat', 'bottle', 'bus',
113 | 'car', 'cat', 'chair', 'cow', 'diningtable',
114 | 'dog', 'horse', 'motorbike', 'person',
115 | 'potted plant', 'sheep', 'sofa', 'train',
116 | 'tv/monitor',
117 | ])
118 | def __init__(
119 | self,
120 | root,
121 | n_classes=21,
122 | split="train",
123 | img_size=None,
124 | augmentations=None,
125 | ignore_index=255,
126 | class_weight=None,
127 | pretrained=False
128 | ):
129 | super(SBDLoader, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained)
130 |
131 | path = os.path.join(self.root, 'dataset', split + ".txt")
132 | with open(path, "r") as f:
133 | self.file_list = [file_name.rstrip() for file_name in f]
134 |
135 | print(f"Found {len(self.file_list)} {split} images")
136 |
137 | def __len__(self):
138 | return len(self.file_list)
139 |
140 | def __getitem__(self, index):
141 | img_name = self.file_list[index]
142 | img_path = os.path.join(self.root, 'dataset/img', img_name + '.jpg')
143 | lbl_path = os.path.join(self.root, 'dataset/cls', img_name + '.mat')
144 |
145 | img = Image.open(img_path).convert('RGB')
146 | lbl = loadmat(lbl_path)
147 | lbl = lbl['GTcls'][0]['Segmentation'][0].astype(np.int32)
148 | lbl = Image.fromarray(lbl)
149 | if self.img_size:
150 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR)
151 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST)
152 | if self.augmentations:
153 | img, lbl = self.augmentations(img, lbl)
154 |
155 | img, lbl = self.transform(img, lbl)
156 | return img, lbl
157 |
158 | def getpalette(self):
159 | n = self.n_classes
160 | palette = [0]*(n*3)
161 | for j in range(0, n):
162 | lab = j
163 | palette[j*3+0] = 0
164 | palette[j*3+1] = 0
165 | palette[j*3+2] = 0
166 | i = 0
167 | while (lab > 0):
168 | palette[j*3+0] |= (((lab >> 0) & 1) << (7-i))
169 | palette[j*3+1] |= (((lab >> 1) & 1) << (7-i))
170 | palette[j*3+2] |= (((lab >> 2) & 1) << (7-i))
171 | i = i + 1
172 | lab >>= 3
173 | palette = np.array(palette).reshape([-1, 3]).astype(np.uint8)
174 | return palette
175 |
176 | class VOC11Val(BaseLoader):
177 | """load PASCAL VOC 2012 dataset, but only use seg11valid.txt for evaluation.
178 | Parameters
179 | ----------
180 | root: path to PASCAL VOC 2012 dataset.
181 | n_classes: number of classes, default 21.
182 | split: only 'seg11valid' is available.
183 | img_size: scale image to proper size.
184 | augmentations: whether to perform augmentation.
185 |     ignore_index: pixels with this value are ignored during training and evaluation, default 255.
186 | class_weight: useful in unbalanced datasets.
187 | pretrained: whether to use pretrained models
188 | """
189 | class_names = np.array([
190 | 'background', 'aeroplane', 'bicycle',
191 | 'bird', 'boat', 'bottle', 'bus',
192 | 'car', 'cat', 'chair', 'cow', 'diningtable',
193 | 'dog', 'horse', 'motorbike', 'person',
194 | 'potted plant', 'sheep', 'sofa', 'train',
195 | 'tv/monitor',
196 | ])
197 |
198 | def __init__(
199 | self,
200 | root,
201 | n_classes=21,
202 | split="seg11valid",
203 | img_size=None,
204 | augmentations=None,
205 | ignore_index=255,
206 | class_weight=None,
207 | pretrained=False
208 | ):
209 | super(VOC11Val, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained)
210 |
211 |         current_dir = os.path.dirname(os.path.realpath(__file__))
212 | 
213 |         path = os.path.join(current_dir, "seg11valid.txt")
214 | with open(path, "r") as f:
215 | self.file_list = [file_name.rstrip() for file_name in f]
216 |
217 | print(f"Found {len(self.file_list)} {split} images")
218 |
219 | def __len__(self):
220 | return len(self.file_list)
221 |
222 | def __getitem__(self, index):
223 | img_name = self.file_list[index]
224 | img_path = os.path.join(self.root, "JPEGImages", img_name + ".jpg")
225 | lbl_path = os.path.join(self.root, "SegmentationClass", img_name + ".png")
226 | img = Image.open(img_path).convert('RGB')
227 | lbl = Image.open(lbl_path)
228 |
229 | if self.img_size:
230 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR)
231 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST)
232 | if self.augmentations:
233 | img, lbl = self.augmentations(img, lbl)
234 |
235 | img, lbl = self.transform(img, lbl)
236 | return img, lbl
237 |
238 | def getpalette(self):
239 | n = self.n_classes
240 | palette = [0]*(n*3)
241 | for j in range(0, n):
242 | lab = j
243 | palette[j*3+0] = 0
244 | palette[j*3+1] = 0
245 | palette[j*3+2] = 0
246 | i = 0
247 | while (lab > 0):
248 | palette[j*3+0] |= (((lab >> 0) & 1) << (7-i))
249 | palette[j*3+1] |= (((lab >> 1) & 1) << (7-i))
250 | palette[j*3+2] |= (((lab >> 2) & 1) << (7-i))
251 | i = i + 1
252 | lab >>= 3
253 | palette = np.array(palette).reshape([-1, 3]).astype(np.uint8)
254 | return palette
255 |
256 | # Test code
257 | # if __name__ == '__main__':
258 | # from torch.utils.data import DataLoader
259 | # root = r'D:\Datasets\VOCdevkit\VOC2012'
260 | # batch_size = 2
261 | # loader = VOCLoader(root=root, img_size=(500, 500))
262 | # test_loader = DataLoader(loader, batch_size=batch_size, shuffle=True)
263 |
264 | # palette = test_loader.dataset.getpalette()
265 | # fig, axes = plt.subplots(batch_size, 2, subplot_kw={'xticks': [], 'yticks': []})
266 | # fig.subplots_adjust(left=0.03, right=0.97, hspace=0.2, wspace=0.05)
267 |
268 | # for imgs, labels in test_loader:
269 | # imgs = imgs.numpy()
270 | # imgs = np.transpose(imgs, [0,2,3,1])
271 | # labels = labels.numpy()
272 |
273 | # for i in range(batch_size):
274 | # axes[i][0].imshow(imgs[i])
275 |
276 | # mask_unlabeled = labels[i] == -1
277 | # viz_unlabeled = (
278 | # np.zeros((labels[i].shape[0], labels[i].shape[1], 3))
279 | # ).astype(np.uint8)
280 |
281 | # lbl_viz = palette[labels[i]]
282 | # lbl_viz[labels[i] == -1] = (0, 0, 0)
283 | # lbl_viz[mask_unlabeled] = viz_unlabeled[mask_unlabeled]
284 |
285 | # axes[i][1].imshow(lbl_viz.astype(np.uint8))
286 | # plt.show()
287 | # break
288 |
289 |
290 |
--------------------------------------------------------------------------------
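
Note (added for clarity, not part of the repository): getpalette, duplicated across the VOC-style loaders, builds the standard PASCAL VOC colormap by spreading the bits of the class index across the R/G/B channels, most significant bit first. A standalone sketch of the same routine and its first entries:

    import numpy as np

    def voc_colormap(n_classes=21):
        palette = np.zeros((n_classes, 3), dtype=np.uint8)
        for j in range(n_classes):
            lab, i = j, 0
            while lab > 0:
                # the i-th group of 3 index bits feeds bit (7 - i) of R, G and B
                palette[j, 0] |= ((lab >> 0) & 1) << (7 - i)
                palette[j, 1] |= ((lab >> 1) & 1) << (7 - i)
                palette[j, 2] |= ((lab >> 2) & 1) << (7 - i)
                i += 1
                lab >>= 3
        return palette

    print(voc_colormap()[:4])
    # [[  0   0   0]   background
    #  [128   0   0]   aeroplane
    #  [  0 128   0]   bicycle
    #  [128 128   0]]  bird
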
/Models/DeepLab_v1.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class DeepLabLargeFOV(nn.Module):
8 | """Adapted from official implementation:
9 |
10 | http://www.cs.jhu.edu/~alanlab/ccvl/DeepLab-LargeFOV/train.prototxt
11 |
12 | input dimension equal to
13 | n = 32 * k - 31, e.g., 321 (for k = 11)
14 | Dimension after pooling w. subsampling:
15 | (16 * k - 15); (8 * k - 7); (4 * k - 3); (2 * k - 1); (k).
16 | For k = 11, these translate to
17 | 161; 81; 41; 21; 11
18 | """
19 | def __init__(self, n_classes):
20 | super(DeepLabLargeFOV, self).__init__()
21 |
22 | features = []
23 | features.append(nn.Conv2d(3, 64, 3, padding=1))
24 | features.append(nn.ReLU(inplace=True))
25 | features.append(nn.Conv2d(64, 64, 3, padding=1))
26 | features.append(nn.ReLU(inplace=True))
27 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True))
28 |
29 | features.append(nn.Conv2d(64, 128, 3, padding=1))
30 | features.append(nn.ReLU(inplace=True))
31 | features.append(nn.Conv2d(128, 128, 3, padding=1))
32 | features.append(nn.ReLU(inplace=True))
33 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True))
34 |
35 | features.append(nn.Conv2d(128, 256, 3, padding=1))
36 | features.append(nn.ReLU(inplace=True))
37 | features.append(nn.Conv2d(256, 256, 3, padding=1))
38 | features.append(nn.ReLU(inplace=True))
39 | features.append(nn.Conv2d(256, 256, 3, padding=1))
40 | features.append(nn.ReLU(inplace=True))
41 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True))
42 |
43 | features.append(nn.Conv2d(256, 512, 3, padding=1))
44 | features.append(nn.ReLU(inplace=True))
45 | features.append(nn.Conv2d(512, 512, 3, padding=1))
46 | features.append(nn.ReLU(inplace=True))
47 | features.append(nn.Conv2d(512, 512, 3, padding=1))
48 | features.append(nn.ReLU(inplace=True))
49 | features.append(nn.MaxPool2d(3, stride=1, padding=1, ceil_mode=True))
50 |
51 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2))
52 | features.append(nn.ReLU(inplace=True))
53 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2))
54 | features.append(nn.ReLU(inplace=True))
55 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2))
56 | features.append(nn.ReLU(inplace=True))
57 | features.append(nn.MaxPool2d(3, stride=1, padding=1, ceil_mode=True))
58 | self.features = nn.Sequential(*features)
59 |
60 | fc = []
61 | fc.append(nn.AvgPool2d(3, stride=1, padding=1))
62 | fc.append(nn.Conv2d(512, 1024, 3, padding=12, dilation=12))
63 | fc.append(nn.ReLU(inplace=True))
64 | fc.append(nn.Conv2d(1024, 1024, 1))
65 | fc.append(nn.ReLU(inplace=True))
66 | fc.append(nn.Dropout(p=0.5))
67 | self.fc = nn.Sequential(*fc)
68 |
69 | self.score = nn.Conv2d(1024, n_classes, 1)
70 |
71 | self._initialize_weights()
72 |
73 | def _initialize_weights(self):
74 |
75 | vgg = torchvision.models.vgg16(pretrained=True)
76 | state_dict = vgg.features.state_dict()
77 | self.features.load_state_dict(state_dict)
78 |
79 | # for m in self.fc.modules():
80 | # if isinstance(m, nn.Conv2d):
81 | # nn.init.kaiming_normal_(m.weight)
82 | # nn.init.constant_(m.bias, 0)
83 |
84 | nn.init.normal_(self.score.weight, std=0.01)
85 | nn.init.constant_(self.score.bias, 0)
86 |
87 | def forward(self, x):
88 | _, _, h, w = x.size()
89 | out = self.features(x)
90 | out = self.fc(out)
91 | out = self.score(out)
92 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
93 | return out
94 |
95 | def get_parameters(self, bias=False, score=False):
96 | if score:
97 | if bias:
98 | yield self.score.bias
99 | else:
100 | yield self.score.weight
101 | else:
102 | for module in [self.features, self.fc]:
103 | for m in module.modules():
104 | if isinstance(m, nn.Conv2d):
105 | if bias:
106 | yield m.bias
107 | else:
108 | yield m.weight
109 |
110 |
111 | if __name__ == "__main__":
112 | import torch
113 | import time
114 | model = DeepLabLargeFOV(21)
115 | print(f'==> Testing {model.__class__.__name__} with PyTorch')
116 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
117 |
118 | model = model.to(device)
119 | model.eval()
120 |
121 |     x = torch.randn(1, 3, 321, 321)
122 |     x = x.to(device)
123 | 
124 |     if torch.cuda.is_available(): torch.cuda.synchronize()
125 |     t_start = time.time()
126 |     for i in range(10):
127 |         model(x)
128 |     if torch.cuda.is_available(): torch.cuda.synchronize()
129 |     elapsed_time = time.time() - t_start
130 |
131 | print(f'Speed: {(elapsed_time / 10) * 1000:.2f} ms')
132 |
--------------------------------------------------------------------------------
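
Shape check (not part of the repository; assumes the repo root is on PYTHONPATH and that constructing the model can download the torchvision VGG-16 weights). For k = 11 the canonical crop is 32 * 11 - 31 = 321, and the final bilinear interpolation restores the input resolution:

    import torch
    from Models.DeepLab_v1 import DeepLabLargeFOV

    model = DeepLabLargeFOV(n_classes=21).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 321, 321))
    print(out.shape)   # torch.Size([1, 21, 321, 321])
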
/Models/DeepLab_v2.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torchvision
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class DeepLabASPPVGG(nn.Module):
9 | """Adapted from official implementation:
10 |
11 | http://liangchiehchen.com/projects/DeepLabv2_vgg.html
12 | """
13 | def __init__(self, n_classes):
14 | super(DeepLabASPPVGG, self).__init__()
15 |
16 | features = []
17 | features.append(nn.Conv2d(3, 64, 3, padding=1))
18 | features.append(nn.ReLU(inplace=True))
19 | features.append(nn.Conv2d(64, 64, 3, padding=1))
20 | features.append(nn.ReLU(inplace=True))
21 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True))
22 |
23 | features.append(nn.Conv2d(64, 128, 3, padding=1))
24 | features.append(nn.ReLU(inplace=True))
25 | features.append(nn.Conv2d(128, 128, 3, padding=1))
26 | features.append(nn.ReLU(inplace=True))
27 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True))
28 |
29 | features.append(nn.Conv2d(128, 256, 3, padding=1))
30 | features.append(nn.ReLU(inplace=True))
31 | features.append(nn.Conv2d(256, 256, 3, padding=1))
32 | features.append(nn.ReLU(inplace=True))
33 | features.append(nn.Conv2d(256, 256, 3, padding=1))
34 | features.append(nn.ReLU(inplace=True))
35 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True))
36 |
37 | features.append(nn.Conv2d(256, 512, 3, padding=1))
38 | features.append(nn.ReLU(inplace=True))
39 | features.append(nn.Conv2d(512, 512, 3, padding=1))
40 | features.append(nn.ReLU(inplace=True))
41 | features.append(nn.Conv2d(512, 512, 3, padding=1))
42 | features.append(nn.ReLU(inplace=True))
43 | features.append(nn.MaxPool2d(3, stride=1, padding=1, ceil_mode=True))
44 |
45 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2))
46 | features.append(nn.ReLU(inplace=True))
47 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2))
48 | features.append(nn.ReLU(inplace=True))
49 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2))
50 | features.append(nn.ReLU(inplace=True))
51 | features.append(nn.MaxPool2d(3, stride=1, padding=1, ceil_mode=True))
52 | self.features = nn.Sequential(*features)
53 |
54 | # hole = 6
55 | fc1 = []
56 | fc1.append(nn.Conv2d(512, 1024, 3, padding=6, dilation=6))
57 | fc1.append(nn.ReLU(inplace=True))
58 | fc1.append(nn.Dropout(p=0.5))
59 | fc1.append(nn.Conv2d(1024, 1024, 1))
60 | fc1.append(nn.ReLU(inplace=True))
61 | fc1.append(nn.Dropout(p=0.5))
62 | self.fc1 = nn.Sequential(*fc1)
63 | self.fc1_score = nn.Conv2d(1024, n_classes, 1)
64 |
65 | # hole = 12
66 | fc2 = []
67 | fc2.append(nn.Conv2d(512, 1024, 3, padding=12, dilation=12))
68 | fc2.append(nn.ReLU(inplace=True))
69 | fc2.append(nn.Dropout(p=0.5))
70 | fc2.append(nn.Conv2d(1024, 1024, 1))
71 | fc2.append(nn.ReLU(inplace=True))
72 | fc2.append(nn.Dropout(p=0.5))
73 | self.fc2 = nn.Sequential(*fc2)
74 | self.fc2_score = nn.Conv2d(1024, n_classes, 1)
75 |
76 | # hole = 18
77 | fc3 = []
78 | fc3.append(nn.Conv2d(512, 1024, 3, padding=18, dilation=18))
79 | fc3.append(nn.ReLU(inplace=True))
80 | fc3.append(nn.Dropout(p=0.5))
81 | fc3.append(nn.Conv2d(1024, 1024, 1))
82 | fc3.append(nn.ReLU(inplace=True))
83 | fc3.append(nn.Dropout(p=0.5))
84 | self.fc3 = nn.Sequential(*fc3)
85 | self.fc3_score = nn.Conv2d(1024, n_classes, 1)
86 |
87 | # hole = 24
88 | fc4 = []
89 | fc4.append(nn.Conv2d(512, 1024, 3, padding=24, dilation=24))
90 | fc4.append(nn.ReLU(inplace=True))
91 | fc4.append(nn.Dropout(p=0.5))
92 | fc4.append(nn.Conv2d(1024, 1024, 1))
93 | fc4.append(nn.ReLU(inplace=True))
94 | fc4.append(nn.Dropout(p=0.5))
95 | self.fc4 = nn.Sequential(*fc4)
96 | self.fc4_score = nn.Conv2d(1024, n_classes, 1)
97 |
98 | self._initialize_weights()
99 |
100 | def _initialize_weights(self):
101 | for m in [self.fc1_score, self.fc2_score, self.fc3_score, self.fc4_score]:
102 | nn.init.normal_(m.weight, std=0.01)
103 | nn.init.constant_(m.bias, 0)
104 |
105 | # for module in [self.fc1, self.fc2, self.fc3, self.fc4]:
106 | # for m in self.modules():
107 | # if isinstance(m, nn.Conv2d):
108 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
109 | # nn.init.normal_(m.weight, std=math.sqrt(2. / n))
110 | # nn.init.kaiming_normal_(m.weight, mode='fan_in')
111 | # nn.init.constant_(m.bias, 0)
112 |
113 | vgg = torchvision.models.vgg16(pretrained=True)
114 | state_dict = vgg.features.state_dict()
115 | self.features.load_state_dict(state_dict)
116 |
117 | def forward(self, x):
118 | _, _, h, w = x.size()
119 | out = self.features(x)
120 | fuse1 = self.fc1_score(self.fc1(out))
121 | fuse2 = self.fc2_score(self.fc2(out))
122 | fuse3 = self.fc3_score(self.fc3(out))
123 | fuse4 = self.fc4_score(self.fc4(out))
124 | out = fuse1 + fuse2 + fuse3 + fuse4
125 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
126 | return out
127 |
128 | def get_parameters(self, bias=False, score=False):
129 | if score:
130 | for m in [self.fc1_score, self.fc2_score, self.fc3_score, self.fc4_score]:
131 | if bias:
132 | yield m.bias
133 | else:
134 | yield m.weight
135 | else:
136 | for module in [self.features, self.fc1, self.fc2, self.fc3, self.fc4]:
137 | for m in module.modules():
138 | if isinstance(m, nn.Conv2d):
139 | if bias:
140 | yield m.bias
141 | else:
142 | yield m.weight
143 |
144 | def freeze_bn(m):
145 | classname = m.__class__.__name__
146 | if classname.find('BatchNorm') != -1:
147 | for p in m.parameters():
148 | p.requires_grad = False
149 |
150 |
151 | class DeepLabASPPResNet(nn.Module):
152 | def __init__(self, n_classes):
153 | super(DeepLabASPPResNet, self).__init__()
154 | self.resnet = ResNet(Bottleneck, [3, 4, 23, 3])
155 | self.atrous_rates = [6, 12, 18, 24]
156 | self.aspp = ASPP(2048, self.atrous_rates, n_classes)
157 | self.resnet.apply(freeze_bn)
158 |
159 | def forward(self, x):
160 | _, _, h, w = x.size()
161 | x2 = F.interpolate(x, size=(int(h * 0.75) + 1, int(w * 0.75) + 1), mode='bilinear', align_corners=True)
162 | x3 = F.interpolate(x, size=(int(h * 0.5) + 1, int(w * 0.5) + 1), mode='bilinear', align_corners=True)
163 | x = self.aspp(self.resnet(x))
164 | x2 = self.aspp(self.resnet(x2))
165 | x3 = self.aspp(self.resnet(x3))
166 |
167 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
168 |
169 | x2 = F.interpolate(x2, size=(h, w), mode='bilinear', align_corners=True)
170 |
171 | x3 = F.interpolate(x3, size=(h, w), mode='bilinear', align_corners=True)
172 |
173 | x4 = torch.max(torch.max(x, x2), x3)
174 | return x, x2, x3, x4
175 |
176 | def get_parameters(self, bias=False, score=False):
177 | if score:
178 | for m in self.aspp.modules():
179 | if isinstance(m, nn.Conv2d):
180 | if bias:
181 | yield m.bias
182 | else:
183 | yield m.weight
184 | else:
185 | for m in self.resnet.modules():
186 | for p in m.parameters():
187 | if p.requires_grad:
188 | yield p
189 |
190 |
191 | class ASPP(nn.Module):
192 | def __init__(self, in_channels, atrous_rates, n_classes):
193 | super(ASPP, self).__init__()
194 |
195 | rate1, rate2, rate3, rate4 = atrous_rates
196 | self.conv1 = nn.Conv2d(2048, n_classes, kernel_size=3, padding=rate1, dilation=rate1, bias=True)
197 | self.conv2 = nn.Conv2d(2048, n_classes, kernel_size=3, padding=rate2, dilation=rate2, bias=True)
198 | self.conv3 = nn.Conv2d(2048, n_classes, kernel_size=3, padding=rate3, dilation=rate3, bias=True)
199 | self.conv4 = nn.Conv2d(2048, n_classes, kernel_size=3, padding=rate4, dilation=rate4, bias=True)
200 |
201 |         self._initialize_weights()
202 | 
203 |     def _initialize_weights(self):
204 |         for m in self.modules():
205 |             if isinstance(m, nn.Conv2d):
206 |                 # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
207 |                 # m.weight.data.normal_(0, math.sqrt(2. / n))
208 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
209 |                 nn.init.constant_(m.bias, 0)
210 |
211 | def forward(self, x):
212 | features1 = self.conv1(x)
213 | features2 = self.conv2(x)
214 | features3 = self.conv3(x)
215 | features4 = self.conv4(x)
216 | out = features1 + features2 + features3 + features4
217 |
218 | return out
219 |
220 | def conv1x1(in_planes, out_planes, stride=1):
221 | """1x1 convolution"""
222 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
223 |
224 | class Bottleneck(nn.Module):
225 | expansion = 4
226 |
227 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
228 | super(Bottleneck, self).__init__()
229 | self.conv1 = conv1x1(inplanes, planes)
230 | self.bn1 = nn.BatchNorm2d(planes)
231 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
232 | padding=dilation, dilation=dilation, bias=False)
233 | self.bn2 = nn.BatchNorm2d(planes)
234 | self.conv3 = conv1x1(planes, planes * self.expansion)
235 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
236 | self.relu = nn.ReLU(inplace=True)
237 | self.downsample = downsample
238 | self.stride = stride
239 |
240 | def forward(self, x):
241 | identity = x
242 |
243 | out = self.conv1(x)
244 | out = self.bn1(out)
245 | out = self.relu(out)
246 |
247 | out = self.conv2(out)
248 | out = self.bn2(out)
249 | out = self.relu(out)
250 |
251 | out = self.conv3(out)
252 | out = self.bn3(out)
253 |
254 | if self.downsample is not None:
255 | identity = self.downsample(x)
256 |
257 | out += identity
258 | out = self.relu(out)
259 |
260 | return out
261 |
262 |
263 | class ResNet(nn.Module):
264 | """
265 | Adapted from https://github.com/speedinghzl/pytorch-segmentation-toolbox/blob/master/networks/deeplabv3.py
266 | """
267 | def __init__(self, block, layers):
268 | super(ResNet, self).__init__()
269 | self.inplanes = 64
270 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
271 | bias=False)
272 | self.bn1 = nn.BatchNorm2d(64)
273 |
274 | self.relu = nn.ReLU(inplace=True)
275 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
276 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
277 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
278 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
279 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
280 |
281 | self._initialize_weights()
282 |
283 |     def _initialize_weights(self):
284 |         # Kaiming init for convs, unit weight / zero bias for norm layers; the
285 |         # backbone layers are then overwritten with ImageNet-pretrained weights below.
286 |         for m in self.modules():
287 |             if isinstance(m, nn.Conv2d):
288 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
289 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
290 |                 nn.init.constant_(m.weight, 1)
291 |                 nn.init.constant_(m.bias, 0)
292 |
293 | resnet = torchvision.models.resnet101(pretrained=True)
294 | self.conv1.load_state_dict(resnet.conv1.state_dict())
295 | self.bn1.load_state_dict(resnet.bn1.state_dict())
296 | self.layer1.load_state_dict(resnet.layer1.state_dict())
297 | self.layer2.load_state_dict(resnet.layer2.state_dict())
298 | self.layer3.load_state_dict(resnet.layer3.state_dict())
299 | self.layer4.load_state_dict(resnet.layer4.state_dict())
300 |
301 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
302 | downsample = None
303 | if stride != 1 or self.inplanes != planes * block.expansion:
304 | downsample = nn.Sequential(
305 | nn.Conv2d(self.inplanes, planes * block.expansion,
306 | kernel_size=1, stride=stride, bias=False),
307 | nn.BatchNorm2d(planes * block.expansion))
308 |
309 | layers = []
310 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
311 | self.inplanes = planes * block.expansion
312 | for i in range(1, blocks):
313 | layers.append(block(self.inplanes, planes, dilation=dilation))
314 |
315 | return nn.Sequential(*layers)
316 |
317 | def forward(self, x):
318 | x = self.conv1(x)
319 | x = self.bn1(x)
320 | x = self.relu(x)
321 | x = self.maxpool(x)
322 |
323 | x = self.layer1(x)
324 | x = self.layer2(x)
325 | x = self.layer3(x)
326 | x = self.layer4(x)
327 |
328 | return x
329 |
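330 |
331 | if __name__ == "__main__":
332 |     # Minimal smoke test (illustrative sketch, not part of the training pipeline).
333 |     # Assumed: a 21-class VOC-style setup, the ResNet-101 layer layout expected by
334 |     # the pretrained weights loaded above, and the standard DeepLab v2 ASPP rates.
335 |     backbone = ResNet(Bottleneck, [3, 4, 23, 3])
336 |     head = ASPP(2048, [6, 12, 18, 24], 21)
337 |     x = torch.randn(1, 3, 321, 321)
338 |     with torch.no_grad():
339 |         feat = backbone(x)   # stride-8 features, 2048 channels
340 |         logits = head(feat)  # summed class scores from the four dilated branches
341 |     print(feat.shape, logits.shape)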
--------------------------------------------------------------------------------
/Models/DeepLab_v3.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class DeepLabV3(nn.Module):
8 | def __init__(self, n_classes):
9 | super(DeepLabV3, self).__init__()
10 | self.n_classes = n_classes
11 | self.resnet = ResNet(Bottleneck, [3, 4, 6, 3])
12 | # self.atrous_rates = [6, 12, 18] # output_stride = 16
13 | self.atrous_rates = [12, 24, 36] # output_stride = 8
14 | self.aspp = ASPP(2048, self.atrous_rates)
15 |
16 | self.final = nn.Conv2d(256, n_classes, 1)
17 |         nn.init.normal_(self.final.weight, std=0.01)
18 | nn.init.constant_(self.final.bias, 0)
19 |
20 | def forward(self, x):
21 | _, _, h, w = x.size()
22 | out = self.resnet(x)
23 | out = self.aspp(out)
24 | out = self.final(out)
25 | out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True)
26 | return out
27 |
28 |
29 | class ASPP(nn.Module):
30 | def __init__(self, in_channels, atrous_rates):
31 | super(ASPP, self).__init__()
32 | out_channels = 256
33 |
34 | self.imagepool = nn.Sequential(
35 | nn.AdaptiveAvgPool2d(1),
36 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
37 | kernel_size=1, bias=False),
38 | nn.BatchNorm2d(num_features=out_channels),
39 | nn.ReLU(inplace=True)
40 | )
41 |
42 | self.conv1x1 = nn.Sequential(
43 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
44 | kernel_size=1, bias=False),
45 | nn.BatchNorm2d(num_features=out_channels),
46 | nn.ReLU(inplace=True)
47 | )
48 |
49 | rate1, rate2, rate3 = tuple(atrous_rates)
50 | self.conv1 = self._ASPPConv(in_channels, out_channels, rate1)
51 | self.conv2 = self._ASPPConv(in_channels, out_channels, rate2)
52 | self.conv3 = self._ASPPConv(in_channels, out_channels, rate3)
53 |
54 | self.project = nn.Sequential(
55 | nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
56 | kernel_size=1, bias=False),
57 | nn.BatchNorm2d(num_features=out_channels),
58 | nn.ReLU(inplace=True),
59 | nn.Dropout(p=0.1)
60 | )
61 |
62 | self._initialize_weights()
63 |
64 | def _initialize_weights(self):
65 | for m in self.modules():
66 | if isinstance(m, nn.Conv2d):
67 | nn.init.kaiming_normal_(m.weight)
68 | if m.bias is not None:
69 | nn.init.constant_(m.bias, 0)
70 |
71 |
72 | def _ASPPConv(self, in_channels, out_channels, atrous_rate):
73 | block = nn.Sequential(
74 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
75 | kernel_size=3, padding=atrous_rate,
76 | dilation=atrous_rate, bias=False),
77 | nn.BatchNorm2d(num_features=out_channels),
78 | nn.ReLU(inplace=True)
79 | )
80 | return block
81 |
82 | def forward(self, x):
83 | _, _, h, w = x.size()
84 |
85 | features1 = F.interpolate(self.imagepool(x), size=(h, w), mode='bilinear', align_corners=True)
86 |
87 | features2 = self.conv1x1(x)
88 | features3 = self.conv1(x)
89 | features4 = self.conv2(x)
90 | features5 = self.conv3(x)
91 | out = torch.cat((features1, features2, features3, features4, features5), 1)
92 |
93 | out = self.project(out)
94 | return out
95 |
96 | def conv1x1(in_planes, out_planes, stride=1):
97 | """1x1 convolution"""
98 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
99 |
100 | class Bottleneck(nn.Module):
101 | expansion = 4
102 |
103 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, multi_grid=1):
104 | super(Bottleneck, self).__init__()
105 | self.conv1 = conv1x1(inplanes, planes)
106 | self.bn1 = nn.BatchNorm2d(planes)
107 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
108 | padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False)
109 | self.bn2 = nn.BatchNorm2d(planes)
110 | self.conv3 = conv1x1(planes, planes * self.expansion)
111 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
112 | self.relu = nn.ReLU(inplace=True)
113 | self.downsample = downsample
114 | self.stride = stride
115 |
116 | def forward(self, x):
117 | identity = x
118 |
119 | out = self.conv1(x)
120 | out = self.bn1(out)
121 | out = self.relu(out)
122 |
123 | out = self.conv2(out)
124 | out = self.bn2(out)
125 | out = self.relu(out)
126 |
127 | out = self.conv3(out)
128 | out = self.bn3(out)
129 |
130 | if self.downsample is not None:
131 | identity = self.downsample(x)
132 |
133 | out += identity
134 | out = self.relu(out)
135 |
136 | return out
137 |
138 |
139 | class ResNet(nn.Module):
140 | """
141 | Adapted from https://github.com/speedinghzl/pytorch-segmentation-toolbox/blob/master/networks/deeplabv3.py
142 | """
143 | def __init__(self, block, layers):
144 | super(ResNet, self).__init__()
145 | self.inplanes = 64
146 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
147 | bias=False)
148 | self.bn1 = nn.BatchNorm2d(64)
149 | self.relu = nn.ReLU(inplace=True)
150 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
151 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
152 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
153 |
154 | # for output_stride = 16
155 | # self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
156 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 2, 4))
157 |
158 | # for output_stride = 8
159 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
160 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1, 2, 4))
161 |
162 | self._initialize_weights()
163 |
164 | def _initialize_weights(self):
165 | resnet = torchvision.models.resnet50(pretrained=True)
166 | self.conv1.load_state_dict(resnet.conv1.state_dict())
167 | self.bn1.load_state_dict(resnet.bn1.state_dict())
168 | self.layer1.load_state_dict(resnet.layer1.state_dict())
169 | self.layer2.load_state_dict(resnet.layer2.state_dict())
170 | self.layer3.load_state_dict(resnet.layer3.state_dict())
171 | self.layer4.load_state_dict(resnet.layer4.state_dict())
172 |
173 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
174 | downsample = None
175 | if stride != 1 or self.inplanes != planes * block.expansion:
176 | downsample = nn.Sequential(
177 | nn.Conv2d(self.inplanes, planes * block.expansion,
178 | kernel_size=1, stride=stride, bias=False),
179 | nn.BatchNorm2d(planes * block.expansion))
180 |
181 | layers = []
182 | generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1
183 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid)))
184 | self.inplanes = planes * block.expansion
185 | for i in range(1, blocks):
186 | layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid)))
187 |
188 | return nn.Sequential(*layers)
189 |
190 | def forward(self, x):
191 | x = self.conv1(x)
192 | x = self.bn1(x)
193 | x = self.relu(x)
194 | x = self.maxpool(x)
195 |
196 | x = self.layer1(x)
197 | x = self.layer2(x)
198 | x = self.layer3(x)
199 | x = self.layer4(x)
200 |
201 | return x
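202 |
203 |
204 | if __name__ == "__main__":
205 |     # Minimal smoke test (illustrative sketch): a 21-class VOC-style setup is
206 |     # assumed; logits are upsampled back to the input resolution in forward().
207 |     model = DeepLabV3(n_classes=21)
208 |     model.eval()
209 |     x = torch.randn(1, 3, 513, 513)
210 |     with torch.no_grad():
211 |         out = model(x)
212 |     print(out.shape)  # expected: torch.Size([1, 21, 513, 513])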
--------------------------------------------------------------------------------
/Models/DeepLab_v3plus.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class DeepLabV3Plus(nn.Module):
8 | def __init__(self, n_classes):
9 | super(DeepLabV3Plus, self).__init__()
10 |
11 | self.resnet = ResNet(Bottleneck, [3, 4, 6, 3])
12 | self.head = _DeepLabHead()
13 | self.decoder1 = nn.Conv2d(64, 48, 1)
14 | self.decoder2 = nn.Sequential(
15 | nn.Conv2d(304, 256, 3),
16 | nn.BatchNorm2d(256),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(256, 256, 3),
19 | nn.BatchNorm2d(256),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(256, n_classes, 1)
22 | )
23 |
24 | def forward(self, x):
25 | _, _, h, w = x.size()
26 | out, branch = self.resnet(x)
27 | _, _, uh, uw = branch.size()
28 | out = self.head(out)
29 | out = F.interpolate(out, size=(uh, uw), mode='bilinear', align_corners=True)
30 | branch = self.decoder1(branch)
31 | out = torch.cat([out, branch], 1)
32 | out = self.decoder2(out)
33 | out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True)
34 | return out
35 |
36 | class _DeepLabHead(nn.Module):
37 | def __init__(self):
38 | super(_DeepLabHead, self).__init__()
39 | self.aspp = ASPP(2048, [6, 12, 18]) # output_stride = 16
40 | # self.aspp = ASPP(2048, [12, 24, 36]) # output_stride = 8
41 | # self.block = nn.Sequential(
42 | # nn.Conv2d(in_channels=256, out_channels=256,
43 | # kernel_size=3, padding=1, bias=False),
44 | # nn.BatchNorm2d(num_features=256),
45 | # nn.ReLU(inplace=True),
46 | # nn.Dropout(0.1),
47 | # nn.Conv2d(in_channels=256, out_channels=n_classes,
48 | # kernel_size=1)
49 | # )
50 |
51 | def forward(self, x):
52 | out = self.aspp(x)
53 | # out = self.block(out)
54 | return out
55 |
56 |
57 | class ASPP(nn.Module):
58 | def __init__(self, in_channels, atrous_rates):
59 | super(ASPP, self).__init__()
60 | out_channels = 256
61 |
62 | self.imagepool = nn.Sequential(
63 | nn.AdaptiveAvgPool2d(1),
64 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
65 | kernel_size=1, bias=False),
66 | nn.BatchNorm2d(num_features=out_channels),
67 | nn.ReLU(inplace=True)
68 | )
69 |
70 | self.conv1x1 = nn.Sequential(
71 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
72 | kernel_size=1, bias=False),
73 | nn.BatchNorm2d(num_features=out_channels),
74 | nn.ReLU(inplace=True)
75 | )
76 |
77 | rate1, rate2, rate3 = tuple(atrous_rates)
78 | self.conv1 = self._ASPPConv(in_channels, out_channels, rate1)
79 | self.conv2 = self._ASPPConv(in_channels, out_channels, rate2)
80 | self.conv3 = self._ASPPConv(in_channels, out_channels, rate3)
81 |
82 | self.project = nn.Sequential(
83 | nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
84 | kernel_size=1, bias=False),
85 | nn.BatchNorm2d(num_features=out_channels),
86 | nn.ReLU(inplace=True),
87 | # nn.Dropout(p=0.5)
88 | )
89 |
90 | def forward(self, x):
91 | _, _, h, w = x.size()
92 |
93 | features1 = F.interpolate(self.imagepool(x), size=(h, w), mode='bilinear', align_corners=True)
94 |
95 | features2 = self.conv1x1(x)
96 | features3 = self.conv1(x)
97 | features4 = self.conv2(x)
98 | features5 = self.conv3(x)
99 | out = torch.cat((features1, features2, features3, features4, features5), 1)
100 | out = self.project(out)
101 | return out
102 |
103 | def _ASPPConv(self, in_channels, out_channels, atrous_rate):
104 | block = nn.Sequential(
105 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
106 | kernel_size=3, padding=atrous_rate,
107 | dilation=atrous_rate, bias=False),
108 | nn.BatchNorm2d(num_features=out_channels),
109 | nn.ReLU(inplace=True)
110 | )
111 | return block
112 |
113 |
114 | def conv1x1(in_planes, out_planes, stride=1):
115 | """1x1 convolution"""
116 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
117 |
118 | class Bottleneck(nn.Module):
119 | expansion = 4
120 |
121 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, multi_grid=1):
122 | super(Bottleneck, self).__init__()
123 | self.conv1 = conv1x1(inplanes, planes)
124 | self.bn1 = nn.BatchNorm2d(planes)
125 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
126 | padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False)
127 | self.bn2 = nn.BatchNorm2d(planes)
128 | self.conv3 = conv1x1(planes, planes * self.expansion)
129 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
130 | self.relu = nn.ReLU(inplace=True)
131 | self.downsample = downsample
132 | self.stride = stride
133 |
134 | def forward(self, x):
135 | identity = x
136 |
137 | out = self.conv1(x)
138 | out = self.bn1(out)
139 | out = self.relu(out)
140 |
141 | out = self.conv2(out)
142 | out = self.bn2(out)
143 | out = self.relu(out)
144 |
145 | out = self.conv3(out)
146 | out = self.bn3(out)
147 |
148 | if self.downsample is not None:
149 | identity = self.downsample(x)
150 |
151 | out += identity
152 | out = self.relu(out)
153 |
154 | return out
155 |
156 |
157 | class ResNet(nn.Module):
158 | """
159 | Adapted from https://github.com/speedinghzl/pytorch-segmentation-toolbox/blob/master/networks/deeplabv3.py
160 | """
161 | def __init__(self, block, layers):
162 | super(ResNet, self).__init__()
163 | self.inplanes = 64
164 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
165 | bias=False)
166 | self.bn1 = nn.BatchNorm2d(64)
167 | self.relu = nn.ReLU(inplace=True)
168 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
169 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
170 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
171 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1)
172 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 2, 4))
173 |
174 | # for output_stride = 8
175 | # self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
176 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1, 2, 4))
177 |
178 | self._initialize_weights()
179 |
180 | def _initialize_weights(self):
181 | resnet = torchvision.models.resnet50(pretrained=True)
182 | self.conv1.load_state_dict(resnet.conv1.state_dict())
183 | self.bn1.load_state_dict(resnet.bn1.state_dict())
184 | self.layer1.load_state_dict(resnet.layer1.state_dict())
185 | self.layer2.load_state_dict(resnet.layer2.state_dict())
186 | self.layer3.load_state_dict(resnet.layer3.state_dict())
187 | self.layer4.load_state_dict(resnet.layer4.state_dict())
188 |
189 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
190 | downsample = None
191 | if stride != 1 or self.inplanes != planes * block.expansion:
192 | downsample = nn.Sequential(
193 | nn.Conv2d(self.inplanes, planes * block.expansion,
194 | kernel_size=1, stride=stride, bias=False),
195 | nn.BatchNorm2d(planes * block.expansion))
196 |
197 | layers = []
198 | generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1
199 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid)))
200 | self.inplanes = planes * block.expansion
201 | for i in range(1, blocks):
202 | layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid)))
203 |
204 | return nn.Sequential(*layers)
205 |
206 | def forward(self, x):
207 | out = self.conv1(x)
208 | out = self.bn1(out)
209 | out = self.relu(out)
210 | branch = self.maxpool(out)
211 |
212 | out = self.layer1(branch)
213 | out = self.layer2(out)
214 | out = self.layer3(out)
215 | out = self.layer4(out)
216 |
217 | return out, branch
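218 |
219 |
220 | if __name__ == "__main__":
221 |     # Minimal smoke test (illustrative sketch): 21 classes assumed. The ASPP output
222 |     # (output_stride = 16 here) is fused with the stride-4 low-level branch in the
223 |     # decoder, then upsampled back to the input size.
224 |     model = DeepLabV3Plus(n_classes=21)
225 |     model.eval()
226 |     x = torch.randn(1, 3, 513, 513)
227 |     with torch.no_grad():
228 |         out = model(x)
229 |     print(out.shape)  # expected: torch.Size([1, 21, 513, 513])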
--------------------------------------------------------------------------------
/Models/Dilation8.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 |
6 | class Dilation8(nn.Module):
7 | """Adapted from official dilated8 implementation:
8 |
9 | https://github.com/fyu/dilation/blob/master/models/dilation8_pascal_voc_deploy.prototxt
10 | """
11 | def __init__(self, n_classes):
12 | super(Dilation8, self).__init__()
13 | features1 = []
14 | # conv1
15 | features1.append(nn.Conv2d(3, 64, 3))
16 | features1.append(nn.ReLU(inplace=True))
17 | features1.append(nn.Conv2d(64, 64, 3))
18 | features1.append(nn.ReLU(inplace=True))
19 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/2
20 |
21 | # conv2
22 | features1.append(nn.Conv2d(64, 128, 3))
23 | features1.append(nn.ReLU(inplace=True))
24 | features1.append(nn.Conv2d(128, 128, 3))
25 | features1.append(nn.ReLU(inplace=True))
26 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/4
27 |
28 | # conv3
29 | features1.append(nn.Conv2d(128, 256, 3))
30 | features1.append(nn.ReLU(inplace=True))
31 | features1.append(nn.Conv2d(256, 256, 3))
32 | features1.append(nn.ReLU(inplace=True))
33 | features1.append(nn.Conv2d(256, 256, 3))
34 | features1.append(nn.ReLU(inplace=True))
35 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/8
36 |
37 | # conv4
38 | features1.append(nn.Conv2d(256, 512, 3))
39 | features1.append(nn.ReLU(inplace=True))
40 | features1.append(nn.Conv2d(512, 512, 3))
41 | features1.append(nn.ReLU(inplace=True))
42 | features1.append(nn.Conv2d(512, 512, 3))
43 | features1.append(nn.ReLU(inplace=True))
44 | self.features1 = nn.Sequential(*features1)
45 |
46 | # conv5
47 | features2 = []
48 | features2.append(nn.Conv2d(512, 512, 3, dilation=2))
49 | features2.append(nn.ReLU(inplace=True))
50 | features2.append(nn.Conv2d(512, 512, 3, dilation=2))
51 | features2.append(nn.ReLU(inplace=True))
52 | features2.append(nn.Conv2d(512, 512, 3, dilation=2))
53 | features2.append(nn.ReLU(inplace=True))
54 | self.features2 = nn.Sequential(*features2)
55 |
56 | fc = []
57 | fc.append(nn.Conv2d(512, 4096, 7, dilation=4))
58 | fc.append(nn.ReLU(inplace=True))
59 | fc.append(nn.Dropout(p=0.5))
60 | fc.append(nn.Conv2d(4096, 4096, 1))
61 | fc.append(nn.ReLU(inplace=True))
62 | fc.append(nn.Dropout(p=0.5))
63 | fc.append(nn.Conv2d(4096, n_classes, 1))
64 | self.fc = nn.Sequential(*fc)
65 |
66 | context = []
67 | context.append(nn.Conv2d(n_classes, 2 * n_classes, 3, padding=33))
68 | context.append(nn.ReLU(inplace=True))
69 | context.append(nn.Conv2d(2 * n_classes, 2 * n_classes, 3, padding=0))
70 | context.append(nn.ReLU(inplace=True))
71 | context.append(nn.Conv2d(2 * n_classes, 4 * n_classes, 3, dilation=2))
72 | context.append(nn.ReLU(inplace=True))
73 | context.append(nn.Conv2d(4 * n_classes, 8 * n_classes, 3, dilation=4))
74 | context.append(nn.ReLU(inplace=True))
75 | context.append(nn.Conv2d(8 * n_classes, 16 * n_classes, 3, dilation=8))
76 | context.append(nn.ReLU(inplace=True))
77 | context.append(nn.Conv2d(16 * n_classes, 32 * n_classes, 3, dilation=16))
78 | context.append(nn.ReLU(inplace=True))
79 | context.append(nn.Conv2d(32 * n_classes, 32 * n_classes, 3))
80 | context.append(nn.ReLU(inplace=True))
81 | context.append(nn.Conv2d(32 * n_classes, n_classes, 1))
82 | context.append(nn.ReLU(inplace=True))
83 | self.context = nn.Sequential(*context)
84 |
85 | self._initialize_weights()
86 |
87 | def _initialize_weights(self):
88 | vgg16 = torchvision.models.vgg16(pretrained=True)
89 | vgg_features1 = vgg16.features[0:23]
90 | self.features1.load_state_dict(vgg_features1.state_dict())
91 |
92 | vgg_features2 = vgg16.features[24:30]
93 | for l1, l2 in zip(vgg_features2.children(), self.features2.children()):
94 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
95 | assert l1.weight.size() == l2.weight.size()
96 | assert l1.bias.size() == l2.bias.size()
97 | l2.weight.data = l1.weight.data
98 | l2.bias.data = l1.bias.data
99 |
100 | fc = self.fc[0:4]
101 | for l1, l2 in zip(vgg16.classifier.children(), fc.children()):
102 | if isinstance(l1, nn.Linear) and isinstance(l2, nn.Conv2d):
103 | l2.weight.data = l1.weight.data.view(l2.weight.size())
104 | l2.bias.data = l1.bias.data.view(l2.bias.size())
105 |
106 | for m in self.context.modules():
107 | if isinstance(m, nn.Conv2d):
108 | nn.init.normal_(m.weight, std=0.001)
109 | nn.init.constant_(m.bias, 0)
110 |
111 | def forward(self, x):
112 | _, _, h, w = x.size()
113 | out = self.features1(x)
114 | out = self.features2(out)
115 | out = self.fc(out)
116 | out = self.context(out)
117 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
118 | return out
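119 |
120 |
121 | if __name__ == "__main__":
122 |     # Minimal smoke test (illustrative sketch): 21 classes assumed. The front end
123 |     # uses unpadded convolutions, so inputs should be reasonably large (e.g. the
124 |     # 500x500 crops commonly used for PASCAL VOC).
125 |     model = Dilation8(n_classes=21)
126 |     model.eval()
127 |     x = torch.randn(1, 3, 500, 500)
128 |     with torch.no_grad():
129 |         out = model(x)
130 |     print(out.shape)  # expected: torch.Size([1, 21, 500, 500])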
--------------------------------------------------------------------------------
/Models/FCN.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 |
6 | # https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py
7 | def get_upsampling_weight(in_channels, out_channels, kernel_size):
8 | """Make a 2D bilinear kernel suitable for upsampling"""
9 | factor = (kernel_size + 1) // 2
10 | if kernel_size % 2 == 1:
11 | center = factor - 1
12 | else:
13 | center = factor - 0.5
14 | og = np.ogrid[:kernel_size, :kernel_size]
15 | filt = (1 - abs(og[0] - center) / factor) * \
16 | (1 - abs(og[1] - center) / factor)
17 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
18 | dtype=np.float64)
19 | weight[range(in_channels), range(out_channels), :, :] = filt
20 | return torch.from_numpy(weight).float()
21 |
22 |
23 | class FCN32s(nn.Module):
24 | """Adapted from official implementation:
25 |
26 | https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn32s/train.prototxt
27 | """
28 | def __init__(self, n_classes):
29 | super(FCN32s, self).__init__()
30 |
31 | features = []
32 | # conv1
33 | features.append(nn.Conv2d(3, 64, 3, padding=100))
34 | features.append(nn.ReLU(inplace=True))
35 | features.append(nn.Conv2d(64, 64, 3, padding=1))
36 | features.append(nn.ReLU(inplace=True))
37 | features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/2
38 |
39 | # conv2
40 | features.append(nn.Conv2d(64, 128, 3, padding=1))
41 | features.append(nn.ReLU(inplace=True))
42 | features.append(nn.Conv2d(128, 128, 3, padding=1))
43 | features.append(nn.ReLU(inplace=True))
44 | features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/4
45 |
46 | # conv3
47 | features.append(nn.Conv2d(128, 256, 3, padding=1))
48 | features.append(nn.ReLU(inplace=True))
49 | features.append(nn.Conv2d(256, 256, 3, padding=1))
50 | features.append(nn.ReLU(inplace=True))
51 | features.append(nn.Conv2d(256, 256, 3, padding=1))
52 | features.append(nn.ReLU(inplace=True))
53 | features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/8
54 |
55 | # conv4
56 | features.append(nn.Conv2d(256, 512, 3, padding=1))
57 | features.append(nn.ReLU(inplace=True))
58 | features.append(nn.Conv2d(512, 512, 3, padding=1))
59 | features.append(nn.ReLU(inplace=True))
60 | features.append(nn.Conv2d(512, 512, 3, padding=1))
61 | features.append(nn.ReLU(inplace=True))
62 | features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/16
63 |
64 | # conv5
65 | features.append(nn.Conv2d(512, 512, 3, padding=1))
66 | features.append(nn.ReLU(inplace=True))
67 | features.append(nn.Conv2d(512, 512, 3, padding=1))
68 | features.append(nn.ReLU(inplace=True))
69 | features.append(nn.Conv2d(512, 512, 3, padding=1))
70 | features.append(nn.ReLU(inplace=True))
71 | features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/32
72 |
73 | self.features = nn.Sequential(*features)
74 |
75 | fc = []
76 | fc.append(nn.Conv2d(512, 4096, 7))
77 | fc.append(nn.ReLU(inplace=True))
78 | fc.append(nn.Dropout(p=0.5))
79 | fc.append(nn.Conv2d(4096, 4096, 1))
80 | fc.append(nn.ReLU(inplace=True))
81 | fc.append(nn.Dropout(p=0.5))
82 | self.fc = nn.Sequential(*fc)
83 |
84 | self.score_fr = nn.Conv2d(4096, n_classes, 1)
85 | self.upscore = nn.ConvTranspose2d(n_classes, n_classes, 64, stride=32,
86 | bias=False)
87 |
88 | self._initialize_weights()
89 |
90 | def _initialize_weights(self):
91 | self.score_fr.weight.data.zero_()
92 | self.score_fr.bias.data.zero_()
93 |
94 | assert self.upscore.kernel_size[0] == self.upscore.kernel_size[1]
95 | initial_weight = get_upsampling_weight(
96 | self.upscore.in_channels, self.upscore.out_channels,
97 | self.upscore.kernel_size[0])
98 | self.upscore.weight.data.copy_(initial_weight)
99 |
100 |
101 | vgg16 = torchvision.models.vgg16(pretrained=True)
102 | state_dict = vgg16.features.state_dict()
103 | self.features.load_state_dict(state_dict)
104 |
105 | for l1, l2 in zip(vgg16.classifier.children(), self.fc):
106 | if isinstance(l1, nn.Linear) and isinstance(l2, nn.Conv2d):
107 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size()))
108 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size()))
109 |
110 | def forward(self, x):
111 | out = self.features(x)
112 | out = self.fc(out)
113 | out = self.score_fr(out)
114 | out = self.upscore(out)
115 | out = out[:, :, 19:19 + x.size()[2], 19:19 + x.size()[3]].contiguous()
116 |
117 | return out
118 |
119 | def get_parameters(self, bias=False):
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | if bias:
123 | yield m.bias
124 | else:
125 | yield m.weight
126 |
127 |
128 | class FCN8sAtOnce(nn.Module):
129 | def __init__(self, n_classes):
130 | super(FCN8sAtOnce, self).__init__()
131 |
132 | features1 = []
133 | # conv1
134 | features1.append(nn.Conv2d(3, 64, 3, padding=100))
135 | features1.append(nn.ReLU(inplace=True))
136 | features1.append(nn.Conv2d(64, 64, 3, padding=1))
137 | features1.append(nn.ReLU(inplace=True))
138 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/2
139 |
140 | # conv2
141 | features1.append(nn.Conv2d(64, 128, 3, padding=1))
142 | features1.append(nn.ReLU(inplace=True))
143 | features1.append(nn.Conv2d(128, 128, 3, padding=1))
144 | features1.append(nn.ReLU(inplace=True))
145 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/4
146 |
147 | # conv3
148 | features1.append(nn.Conv2d(128, 256, 3, padding=1))
149 | features1.append(nn.ReLU(inplace=True))
150 | features1.append(nn.Conv2d(256, 256, 3, padding=1))
151 | features1.append(nn.ReLU(inplace=True))
152 | features1.append(nn.Conv2d(256, 256, 3, padding=1))
153 | features1.append(nn.ReLU(inplace=True))
154 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/8
155 | self.features1 = nn.Sequential(*features1)
156 |
157 | features2 = []
158 | # conv4
159 | features2.append(nn.Conv2d(256, 512, 3, padding=1))
160 | features2.append(nn.ReLU(inplace=True))
161 | features2.append(nn.Conv2d(512, 512, 3, padding=1))
162 | features2.append(nn.ReLU(inplace=True))
163 | features2.append(nn.Conv2d(512, 512, 3, padding=1))
164 | features2.append(nn.ReLU(inplace=True))
165 | features2.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/16
166 | self.features2 = nn.Sequential(*features2)
167 |
168 | features3 = []
169 | # conv5
170 | features3.append(nn.Conv2d(512, 512, 3, padding=1))
171 | features3.append(nn.ReLU(inplace=True))
172 | features3.append(nn.Conv2d(512, 512, 3, padding=1))
173 | features3.append(nn.ReLU(inplace=True))
174 | features3.append(nn.Conv2d(512, 512, 3, padding=1))
175 | features3.append(nn.ReLU(inplace=True))
176 | features3.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/32
177 | self.features3 = nn.Sequential(*features3)
178 |
179 | fc = []
180 | # fc6
181 | fc.append(nn.Conv2d(512, 4096, 7))
182 | fc.append(nn.ReLU(inplace=True))
183 | fc.append(nn.Dropout2d())
184 |
185 | # fc7
186 | fc.append(nn.Conv2d(4096, 4096, 1))
187 | fc.append(nn.ReLU(inplace=True))
188 | fc.append(nn.Dropout2d())
189 | self.fc = nn.Sequential(*fc)
190 |
191 | self.score_fr = nn.Conv2d(4096, n_classes, 1)
192 | self.score_pool3 = nn.Conv2d(256, n_classes, 1)
193 | self.score_pool4 = nn.Conv2d(512, n_classes, 1)
194 |
195 | self.upscore2 = nn.ConvTranspose2d(
196 | n_classes, n_classes, 4, stride=2, bias=False)
197 | self.upscore8 = nn.ConvTranspose2d(
198 | n_classes, n_classes, 16, stride=8, bias=False)
199 | self.upscore_pool4 = nn.ConvTranspose2d(
200 | n_classes, n_classes, 4, stride=2, bias=False)
201 |
202 | self._initialize_weights()
203 |
204 | def _initialize_weights(self):
205 | for m in [self.score_fr, self.score_pool3, self.score_pool4]:
206 | m.weight.data.zero_()
207 | m.bias.data.zero_()
208 |
209 | for m in [self.upscore2, self.upscore8, self.upscore_pool4]:
210 | assert m.kernel_size[0] == m.kernel_size[1]
211 | initial_weight = get_upsampling_weight(
212 | m.in_channels, m.out_channels, m.kernel_size[0])
213 | m.weight.data.copy_(initial_weight)
214 |
215 | vgg16 = torchvision.models.vgg16(pretrained=True)
216 | vgg_features = [
217 | vgg16.features[:17],
218 | vgg16.features[17:24],
219 | vgg16.features[24:],
220 | ]
221 | features = [
222 | self.features1,
223 | self.features2,
224 | self.features3,
225 | ]
226 |
227 | for l1, l2 in zip(vgg_features, features):
228 | for ll1, ll2 in zip(l1.children(), l2.children()):
229 | if isinstance(ll1, nn.Conv2d) and isinstance(ll2, nn.Conv2d):
230 | assert ll1.weight.size() == ll2.weight.size()
231 | assert ll1.bias.size() == ll2.bias.size()
232 | ll2.weight.data.copy_(ll1.weight.data)
233 | ll2.bias.data.copy_(ll1.bias.data)
234 |
235 | for l1, l2 in zip(vgg16.classifier.children(), self.fc):
236 | if isinstance(l1, nn.Linear) and isinstance(l2, nn.Conv2d):
237 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size()))
238 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size()))
239 |
240 | def forward(self, x):
241 | pool3 = self.features1(x) # 1/8
242 | pool4 = self.features2(pool3) # 1/16
243 | pool5 = self.features3(pool4) # 1/32
244 | fc = self.fc(pool5)
245 | score_fr = self.score_fr(fc)
246 | upscore2 = self.upscore2(score_fr) # 1/16
247 |
248 | score_pool4 = self.score_pool4(pool4 * 0.01) # XXX: scaling to train at once
249 | score_pool4c = score_pool4[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
250 | upscore_pool4 = self.upscore_pool4(upscore2 + score_pool4c) # 1/8
251 |
252 | score_pool3 = self.score_pool3(pool3 * 0.0001) # XXX: scaling to train at once
253 | score_pool3c = score_pool3[:, :,
254 | 9:9 + upscore_pool4.size()[2],
255 | 9:9 + upscore_pool4.size()[3]]
256 | out = self.upscore8(upscore_pool4 + score_pool3c)
257 |
258 | out = out[:, :, 31:31 + x.size()[2], 31:31 + x.size()[3]].contiguous()
259 |
260 | return out
261 |
262 | def get_parameters(self, bias=False):
263 | for m in self.modules():
264 | if isinstance(m, nn.Conv2d):
265 | if bias:
266 | yield m.bias
267 | else:
268 | yield m.weight
269 |
270 |
271 | if __name__ == "__main__":
272 | import torch
273 | import time
274 | model = FCN32s(21)
275 | print(f'==> Testing {model.__class__.__name__} with PyTorch')
276 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
277 | # torch.backends.cudnn.benchmark = True
278 |
279 | model = model.to(device)
280 | model.eval()
281 |
282 |     x = torch.randn(1, 3, 500, 500)
283 | x = x.to(device)
284 |
285 |     if torch.cuda.is_available(): torch.cuda.synchronize()
286 | t_start = time.time()
287 | for i in range(10):
288 | model(x)
289 |     if torch.cuda.is_available(): torch.cuda.synchronize()
290 | elapsed_time = time.time() - t_start
291 |
292 | print(f'Speed: {(elapsed_time / 10) * 1000:.2f} ms')
--------------------------------------------------------------------------------
/Models/PSPNet.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | # def freeze_bn(self):
7 | # for m in self.modules():
8 | # if isinstance(m, nn.BatchNorm2d):
9 | # m.eval()
10 |
11 | class PSPNet(nn.Module):
12 | """set crop size to 480
13 | """
14 | def __init__(self, n_classes):
15 | super(PSPNet, self).__init__()
16 |
17 | self.resnet = ResNet(Bottleneck, [3, 4, 6, 3])
18 | self.pyramid_pooling = PyramidPooling(2048, 512)
19 | self.final = nn.Sequential(
20 | nn.Conv2d(4096, 512, 3, padding=1, bias=False),
21 | nn.BatchNorm2d(512, momentum=.95),
22 | nn.ReLU(inplace=True),
23 | nn.Dropout(p=0.1),
24 | nn.Conv2d(512, n_classes, 1)
25 | )
26 |
27 | self._initialize_weights()
28 |
29 | def _initialize_weights(self):
30 | for m in self.final:
31 | if isinstance(m, nn.Conv2d):
32 | nn.init.kaiming_normal_(m.weight)
33 | if m.bias is not None:
34 | nn.init.constant_(m.bias, 0)
35 | if isinstance(m, nn.BatchNorm2d):
36 | nn.init.constant_(m.weight, 1)
37 | nn.init.constant_(m.bias, 0)
38 |
39 | def forward(self, x):
40 | _, _, h, w = x.size()
41 | out = self.resnet(x)
42 | out = self.pyramid_pooling(out)
43 | out = self.final(out)
44 | out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True)
45 | return out
46 |
47 | def conv1x1(in_planes, out_planes, stride=1):
48 | """1x1 convolution"""
49 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
50 |
51 | class PyramidPooling(nn.Module):
52 | def __init__(self, in_channels, out_channels):
53 | super(PyramidPooling, self).__init__()
54 | self.pool1 = self._pyramid_conv(in_channels, out_channels, 10)
55 | self.pool2 = self._pyramid_conv(in_channels, out_channels, 20)
56 | self.pool3 = self._pyramid_conv(in_channels, out_channels, 30)
57 | self.pool4 = self._pyramid_conv(in_channels, out_channels, 60)
58 |
59 | self._initialize_weights()
60 |
61 | def _initialize_weights(self):
62 | for m in self.modules():
63 | if isinstance(m, nn.Conv2d):
64 | nn.init.kaiming_normal_(m.weight)
65 | if isinstance(m, nn.BatchNorm2d):
66 | nn.init.constant_(m.weight, 1)
67 | nn.init.constant_(m.bias, 0)
68 |
69 | def _pyramid_conv(self, in_channels, out_channels, scale):
70 | module = nn.Sequential(
71 | # nn.AdaptiveAvgPool2d(scale),
72 | nn.AvgPool2d(kernel_size=scale, stride=scale),
73 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
74 | nn.BatchNorm2d(out_channels, momentum=.95),
75 | nn.ReLU(inplace=True)
76 | )
77 | return module
78 |
79 | def forward(self, x):
80 | _, _, h, w = x.size()
81 | pool1 = self.pool1(x)
82 | pool2 = self.pool2(x)
83 | pool3 = self.pool3(x)
84 | pool4 = self.pool4(x)
85 | pool1 = F.interpolate(pool1, size=(h, w), mode='bilinear', align_corners=True)
86 | pool2 = F.interpolate(pool2, size=(h, w), mode='bilinear', align_corners=True)
87 | pool3 = F.interpolate(pool3, size=(h, w), mode='bilinear', align_corners=True)
88 | pool4 = F.interpolate(pool4, size=(h, w), mode='bilinear', align_corners=True)
89 | out = torch.cat([x, pool1, pool2, pool3, pool4], 1)
90 | return out
91 |
92 |
93 | class Bottleneck(nn.Module):
94 | expansion = 4
95 |
96 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
97 | super(Bottleneck, self).__init__()
98 | self.conv1 = conv1x1(inplanes, planes)
99 | self.bn1 = nn.BatchNorm2d(planes, momentum=.95)
100 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
101 | padding=dilation, dilation=dilation, bias=False)
102 | self.bn2 = nn.BatchNorm2d(planes, momentum=.95)
103 | self.conv3 = conv1x1(planes, planes * self.expansion)
104 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=.95)
105 | self.relu = nn.ReLU(inplace=True)
106 | self.downsample = downsample
107 | self.stride = stride
108 |
109 | def forward(self, x):
110 | identity = x
111 |
112 | out = self.conv1(x)
113 | out = self.bn1(out)
114 | out = self.relu(out)
115 |
116 | out = self.conv2(out)
117 | out = self.bn2(out)
118 | out = self.relu(out)
119 |
120 | out = self.conv3(out)
121 | out = self.bn3(out)
122 |
123 | if self.downsample is not None:
124 | identity = self.downsample(x)
125 |
126 | out += identity
127 | out = self.relu(out)
128 |
129 | return out
130 |
131 |
132 | class ResNet(nn.Module):
133 | def __init__(self, block, layers):
134 | super(ResNet, self).__init__()
135 | self.inplanes = 64
136 | self.conv1 = nn.Sequential(
137 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
138 | nn.BatchNorm2d(64, momentum=.95),
139 | nn.ReLU(inplace=True),
140 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
141 | nn.BatchNorm2d(64, momentum=.95),
142 | nn.ReLU(inplace=True),
143 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
144 | nn.BatchNorm2d(64, momentum=.95),
145 | nn.ReLU(inplace=True),
146 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
147 | )
148 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
151 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
152 |
153 | self._initialize_weights()
154 |
155 | def _initialize_weights(self):
156 | for m in self.conv1.children():
157 | if isinstance(m, nn.Conv2d):
158 | nn.init.kaiming_normal_(m.weight)
159 | if isinstance(m, nn.BatchNorm2d):
160 | nn.init.constant_(m.weight, 1)
161 | nn.init.constant_(m.bias, 0)
162 |
163 | for module in [self.layer1, self.layer2, self.layer3, self.layer4]:
164 | for m in module.modules():
165 | if isinstance(m, nn.BatchNorm2d):
166 | nn.init.constant_(m.weight, 1)
167 | nn.init.constant_(m.bias, 0)
168 |
169 | resnet = torchvision.models.resnet50(pretrained=True)
170 | self.layer1.load_state_dict(resnet.layer1.state_dict())
171 | self.layer2.load_state_dict(resnet.layer2.state_dict())
172 | self.layer3.load_state_dict(resnet.layer3.state_dict())
173 | self.layer4.load_state_dict(resnet.layer4.state_dict())
174 |
175 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
176 | downsample = None
177 | if stride != 1 or self.inplanes != planes * block.expansion:
178 | downsample = nn.Sequential(
179 | nn.Conv2d(self.inplanes, planes * block.expansion,
180 | kernel_size=1, stride=stride, bias=False),
181 | nn.BatchNorm2d(planes * block.expansion, momentum=0.95))
182 |
183 | layers = []
184 |         layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
185 | self.inplanes = planes * block.expansion
186 | for i in range(1, blocks):
187 | layers.append(block(self.inplanes, planes, dilation=dilation))
188 |
189 | return nn.Sequential(*layers)
190 |
191 |
192 | def forward(self, x):
193 | out = self.conv1(x)
194 |
195 | out = self.layer1(out)
196 | out = self.layer2(out)
197 | out = self.layer3(out)
198 | out = self.layer4(out)
199 |
200 | return out
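201 |
202 |
203 | if __name__ == "__main__":
204 |     # Minimal smoke test (illustrative sketch): 21 classes assumed. The fixed
205 |     # pooling kernels (10/20/30/60) expect a 60x60 feature map, i.e. a 480x480
206 |     # crop at output stride 8, as noted in the PSPNet docstring above.
207 |     model = PSPNet(n_classes=21)
208 |     model.eval()
209 |     x = torch.randn(1, 3, 480, 480)
210 |     with torch.no_grad():
211 |         out = model(x)
212 |     print(out.shape)  # expected: torch.Size([1, 21, 480, 480])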
--------------------------------------------------------------------------------
/Models/SegNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torchvision
4 |
5 |
6 | # use vgg16_bn pretrained model
7 | class SegNet(nn.Module):
8 | """Adapted from official implementation:
9 |
10 | https://github.com/alexgkendall/SegNet-Tutorial/tree/master/Models
11 | """
12 | def __init__(self, n_classes):
13 | super(SegNet, self).__init__()
14 |
15 | # conv1
16 | features1 = []
17 | features1.append(nn.Conv2d(3, 64, 3, padding=1))
18 | features1.append(nn.BatchNorm2d(64))
19 | features1.append(nn.ReLU(inplace=True))
20 | features1.append(nn.Conv2d(64, 64, 3, padding=1))
21 | features1.append(nn.BatchNorm2d(64))
22 | features1.append(nn.ReLU(inplace=True))
23 | self.features1 = nn.Sequential(*features1)
24 | self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/2
25 |
26 | # conv2
27 | features2 = []
28 | features2.append(nn.Conv2d(64, 128, 3, padding=1))
29 | features2.append(nn.BatchNorm2d(128))
30 | features2.append(nn.ReLU(inplace=True))
31 | features2.append(nn.Conv2d(128, 128, 3, padding=1))
32 | features2.append(nn.BatchNorm2d(128))
33 | features2.append(nn.ReLU(inplace=True))
34 | self.features2 = nn.Sequential(*features2)
35 | self.pool2 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/4
36 |
37 | # conv3
38 | features3 = []
39 | features3.append(nn.Conv2d(128, 256, 3, padding=1))
40 | features3.append(nn.BatchNorm2d(256))
41 | features3.append(nn.ReLU(inplace=True))
42 | features3.append(nn.Conv2d(256, 256, 3, padding=1))
43 | features3.append(nn.BatchNorm2d(256))
44 | features3.append(nn.ReLU(inplace=True))
45 | features3.append(nn.Conv2d(256, 256, 3, padding=1))
46 | features3.append(nn.BatchNorm2d(256))
47 | features3.append(nn.ReLU(inplace=True))
48 | self.features3 = nn.Sequential(*features3)
49 | self.pool3 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/8
50 |
51 | # conv4
52 | features4 = []
53 | features4.append(nn.Conv2d(256, 512, 3, padding=1))
54 | features4.append(nn.BatchNorm2d(512))
55 | features4.append(nn.ReLU(inplace=True))
56 | features4.append(nn.Conv2d(512, 512, 3, padding=1))
57 | features4.append(nn.BatchNorm2d(512))
58 | features4.append(nn.ReLU(inplace=True))
59 | features4.append(nn.Conv2d(512, 512, 3, padding=1))
60 | features4.append(nn.BatchNorm2d(512))
61 | features4.append(nn.ReLU(inplace=True))
62 | self.features4 = nn.Sequential(*features4)
63 | self.pool4 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/16
64 |
65 | # conv5
66 | features5 = []
67 | features5.append(nn.Conv2d(512, 512, 3, padding=1))
68 | features5.append(nn.BatchNorm2d(512))
69 | features5.append(nn.ReLU(inplace=True))
70 | features5.append(nn.Conv2d(512, 512, 3, padding=1))
71 | features5.append(nn.BatchNorm2d(512))
72 | features5.append(nn.ReLU(inplace=True))
73 | features5.append(nn.Conv2d(512, 512, 3, padding=1))
74 | features5.append(nn.BatchNorm2d(512))
75 | features5.append(nn.ReLU(inplace=True))
76 | self.features5 = nn.Sequential(*features5)
77 | self.pool5 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/32
78 |
79 | # convTranspose1
80 | self.unpool6 = nn.MaxUnpool2d(2, stride=2)
81 | features6 = []
82 | features6.append(nn.Conv2d(512, 512, 3, padding=1))
83 | features6.append(nn.BatchNorm2d(512))
84 | features6.append(nn.ReLU(inplace=True))
85 | features6.append(nn.Conv2d(512, 512, 3, padding=1))
86 | features6.append(nn.BatchNorm2d(512))
87 | features6.append(nn.ReLU(inplace=True))
88 | features6.append(nn.Conv2d(512, 512, 3, padding=1))
89 | features6.append(nn.BatchNorm2d(512))
90 | features6.append(nn.ReLU(inplace=True))
91 | self.features6 = nn.Sequential(*features6)
92 |
93 | # convTranspose2
94 | self.unpool7 = nn.MaxUnpool2d(2, stride=2)
95 | features7 = []
96 | features7.append(nn.Conv2d(512, 512, 3, padding=1))
97 | features7.append(nn.BatchNorm2d(512))
98 | features7.append(nn.ReLU(inplace=True))
99 | features7.append(nn.Conv2d(512, 512, 3, padding=1))
100 | features7.append(nn.BatchNorm2d(512))
101 | features7.append(nn.ReLU(inplace=True))
102 | features7.append(nn.Conv2d(512, 256, 3, padding=1))
103 | features7.append(nn.BatchNorm2d(256))
104 | features7.append(nn.ReLU(inplace=True))
105 | self.features7 = nn.Sequential(*features7)
106 |
107 | # convTranspose3
108 | self.unpool8 = nn.MaxUnpool2d(2, stride=2)
109 | features8 = []
110 | features8.append(nn.Conv2d(256, 256, 3, padding=1))
111 | features8.append(nn.BatchNorm2d(256))
112 | features8.append(nn.ReLU(inplace=True))
113 | features8.append(nn.Conv2d(256, 256, 3, padding=1))
114 | features8.append(nn.BatchNorm2d(256))
115 | features8.append(nn.ReLU(inplace=True))
116 | features8.append(nn.Conv2d(256, 128, 3, padding=1))
117 | features8.append(nn.BatchNorm2d(128))
118 | features8.append(nn.ReLU(inplace=True))
119 | self.features8 = nn.Sequential(*features8)
120 |
121 | # convTranspose4
122 | self.unpool9 = nn.MaxUnpool2d(2, stride=2)
123 | features9 = []
124 | features9.append(nn.Conv2d(128, 128, 3, padding=1))
125 | features9.append(nn.BatchNorm2d(128))
126 | features9.append(nn.ReLU(inplace=True))
127 | features9.append(nn.Conv2d(128, 64, 3, padding=1))
128 | features9.append(nn.BatchNorm2d(64))
129 | features9.append(nn.ReLU(inplace=True))
130 | self.features9 = nn.Sequential(*features9)
131 |
132 | # convTranspose5
133 | self.unpool10 = nn.MaxUnpool2d(2, stride=2)
134 | self.final = nn.Sequential(
135 | nn.Conv2d(64, 64, kernel_size=3, padding=1),
136 | nn.BatchNorm2d(64),
137 | nn.ReLU(inplace=True),
138 | nn.Conv2d(64, n_classes, kernel_size=3, padding=1),
139 | )
140 |
141 | self._initialize_weights()
142 |
143 | def _initialize_weights(self):
144 | for m in self.modules():
145 | if isinstance(m, nn.Conv2d):
146 | nn.init.kaiming_normal_(m.weight)
147 | # if isinstance(m, nn.BatchNorm2d):
148 | # nn.init.constant_(m.weight, 1)
149 | # nn.init.constant_(m.bias, 0.001)
150 |
151 | vgg16 = torchvision.models.vgg16_bn(pretrained=True)
152 | vgg_features = [
153 | vgg16.features[0:6],
154 | vgg16.features[7:13],
155 | vgg16.features[14:23],
156 | vgg16.features[24:33],
157 | vgg16.features[34:43]
158 | ]
159 | features = [
160 | self.features1,
161 | self.features2,
162 | self.features3,
163 | self.features4,
164 | self.features5
165 | ]
166 | for l1, l2 in zip(vgg_features, features):
167 | for ll1, ll2 in zip(l1.children(), l2.children()):
168 | if isinstance(ll1, nn.Conv2d) and isinstance(ll2, nn.Conv2d):
169 | assert ll1.weight.size() == ll2.weight.size()
170 | assert ll1.bias.size() == ll2.bias.size()
171 | ll2.weight.data = ll1.weight.data
172 | ll2.bias.data = ll1.bias.data
173 | if isinstance(ll1, nn.BatchNorm2d) and isinstance(ll2, nn.BatchNorm2d):
174 | assert ll1.weight.size() == ll2.weight.size()
175 | assert ll1.bias.size() == ll2.bias.size()
176 | ll2.weight.data = ll1.weight.data
177 | ll2.bias.data = ll1.bias.data
178 |
179 | def forward(self, x):
180 | out = self.features1(x)
181 | out, indices_1 = self.pool1(out)
182 | out = self.features2(out)
183 | out, indices_2 = self.pool2(out)
184 | out = self.features3(out)
185 | out, indices_3 = self.pool3(out)
186 | out = self.features4(out)
187 | out, indices_4 = self.pool4(out)
188 | out = self.features5(out)
189 | out, indices_5 = self.pool5(out)
190 | out = self.unpool6(out, indices_5)
191 | out = self.features6(out)
192 | out = self.unpool7(out, indices_4)
193 | out = self.features7(out)
194 | out = self.unpool8(out, indices_3)
195 | out = self.features8(out)
196 | out = self.unpool9(out, indices_2)
197 | out = self.features9(out)
198 | out = self.unpool10(out, indices_1)
199 | out = self.final(out)
200 | return out
201 |
202 |
203 | # use vgg16 pretrained model
204 | # class SegNet(nn.Module):
205 | # """
206 | # Adapted from official implementation:
207 |
208 | # https://github.com/alexgkendall/SegNet-Tutorial/tree/master/Models
209 | # """
210 | # def __init__(self, n_classes):
211 | # super(SegNet, self).__init__()
212 |
213 | # # conv1
214 | # features1 = []
215 | # features1.append(nn.Conv2d(3, 64, 3, padding=1))
216 | # features1.append(nn.BatchNorm2d(64))
217 | # features1.append(nn.ReLU(inplace=True))
218 | # features1.append(nn.Conv2d(64, 64, 3, padding=1))
219 | # features1.append(nn.BatchNorm2d(64))
220 | # features1.append(nn.ReLU(inplace=True))
221 | # self.features1 = nn.Sequential(*features1)
222 | # self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/2
223 |
224 | # # conv2
225 | # features2 = []
226 | # features2.append(nn.Conv2d(64, 128, 3, padding=1))
227 | # features2.append(nn.BatchNorm2d(128))
228 | # features2.append(nn.ReLU(inplace=True))
229 | # features2.append(nn.Conv2d(128, 128, 3, padding=1))
230 | # features2.append(nn.BatchNorm2d(128))
231 | # features2.append(nn.ReLU(inplace=True))
232 | # self.features2 = nn.Sequential(*features2)
233 | # self.pool2 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/4
234 |
235 | # # conv3
236 | # features3 = []
237 | # features3.append(nn.Conv2d(128, 256, 3, padding=1))
238 | # features3.append(nn.BatchNorm2d(256))
239 | # features3.append(nn.ReLU(inplace=True))
240 | # features3.append(nn.Conv2d(256, 256, 3, padding=1))
241 | # features3.append(nn.BatchNorm2d(256))
242 | # features3.append(nn.ReLU(inplace=True))
243 | # features3.append(nn.Conv2d(256, 256, 3, padding=1))
244 | # features3.append(nn.BatchNorm2d(256))
245 | # features3.append(nn.ReLU(inplace=True))
246 | # self.features3 = nn.Sequential(*features3)
247 | # self.pool3 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/8
248 |
249 | # # conv4
250 | # features4 = []
251 | # features4.append(nn.Conv2d(256, 512, 3, padding=1))
252 | # features4.append(nn.BatchNorm2d(512))
253 | # features4.append(nn.ReLU(inplace=True))
254 | # features4.append(nn.Conv2d(512, 512, 3, padding=1))
255 | # features4.append(nn.BatchNorm2d(512))
256 | # features4.append(nn.ReLU(inplace=True))
257 | # features4.append(nn.Conv2d(512, 512, 3, padding=1))
258 | # features4.append(nn.BatchNorm2d(512))
259 | # features4.append(nn.ReLU(inplace=True))
260 | # self.features4 = nn.Sequential(*features4)
261 | # self.pool4 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/16
262 |
263 | # # conv5
264 | # features5 = []
265 | # features5.append(nn.Conv2d(512, 512, 3, padding=1))
266 | # features5.append(nn.BatchNorm2d(512))
267 | # features5.append(nn.ReLU(inplace=True))
268 | # features5.append(nn.Conv2d(512, 512, 3, padding=1))
269 | # features5.append(nn.BatchNorm2d(512))
270 | # features5.append(nn.ReLU(inplace=True))
271 | # features5.append(nn.Conv2d(512, 512, 3, padding=1))
272 | # features5.append(nn.BatchNorm2d(512))
273 | # features5.append(nn.ReLU(inplace=True))
274 | # self.features5 = nn.Sequential(*features5)
275 | # self.pool5 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/32
276 |
277 | # # convTranspose1
278 | # self.unpool6 = nn.MaxUnpool2d(2, stride=2)
279 | # features6 = []
280 | # features6.append(nn.Conv2d(512, 512, 3, padding=1))
281 | # features6.append(nn.BatchNorm2d(512))
282 | # features6.append(nn.ReLU(inplace=True))
283 | # features6.append(nn.Conv2d(512, 512, 3, padding=1))
284 | # features6.append(nn.BatchNorm2d(512))
285 | # features6.append(nn.ReLU(inplace=True))
286 | # features6.append(nn.Conv2d(512, 512, 3, padding=1))
287 | # features6.append(nn.BatchNorm2d(512))
288 | # features6.append(nn.ReLU(inplace=True))
289 | # self.features6 = nn.Sequential(*features6)
290 |
291 | # # convTranspose2
292 | # self.unpool7 = nn.MaxUnpool2d(2, stride=2)
293 | # features7 = []
294 | # features7.append(nn.Conv2d(512, 512, 3, padding=1))
295 | # features7.append(nn.BatchNorm2d(512))
296 | # features7.append(nn.ReLU(inplace=True))
297 | # features7.append(nn.Conv2d(512, 512, 3, padding=1))
298 | # features7.append(nn.BatchNorm2d(512))
299 | # features7.append(nn.ReLU(inplace=True))
300 | # features7.append(nn.Conv2d(512, 256, 3, padding=1))
301 | # features7.append(nn.BatchNorm2d(256))
302 | # features7.append(nn.ReLU(inplace=True))
303 | # self.features7 = nn.Sequential(*features7)
304 |
305 | # # convTranspose3
306 | # self.unpool8 = nn.MaxUnpool2d(2, stride=2)
307 | # features8 = []
308 | # features8.append(nn.Conv2d(256, 256, 3, padding=1))
309 | # features8.append(nn.BatchNorm2d(256))
310 | # features8.append(nn.ReLU(inplace=True))
311 | # features8.append(nn.Conv2d(256, 256, 3, padding=1))
312 | # features8.append(nn.BatchNorm2d(256))
313 | # features8.append(nn.ReLU(inplace=True))
314 | # features8.append(nn.Conv2d(256, 128, 3, padding=1))
315 | # features8.append(nn.BatchNorm2d(128))
316 | # features8.append(nn.ReLU(inplace=True))
317 | # self.features8 = nn.Sequential(*features8)
318 |
319 | # # convTranspose4
320 | # self.unpool9 = nn.MaxUnpool2d(2, stride=2)
321 | # features9 = []
322 | # features9.append(nn.Conv2d(128, 128, 3, padding=1))
323 | # features9.append(nn.BatchNorm2d(128))
324 | # features9.append(nn.ReLU(inplace=True))
325 | # features9.append(nn.Conv2d(128, 64, 3, padding=1))
326 | # features9.append(nn.BatchNorm2d(64))
327 | # features9.append(nn.ReLU(inplace=True))
328 | # self.features9 = nn.Sequential(*features9)
329 |
330 | # # convTranspose5
331 | # self.unpool10 = nn.MaxUnpool2d(2, stride=2)
332 | # self.final = nn.Sequential(
333 | # nn.Conv2d(64, 64, kernel_size=3, padding=1),
334 | # nn.BatchNorm2d(64),
335 | # nn.ReLU(inplace=True),
336 | # nn.Conv2d(64, n_classes, kernel_size=3, padding=1),
337 | # )
338 |
339 | # self._initialize_weights()
340 |
341 | # def _initialize_weights(self):
342 | # for m in self.modules():
343 | # if isinstance(m, nn.Conv2d):
344 | # nn.init.kaiming_normal_(m.weight)
345 |
346 | # vgg16 = torchvision.models.vgg16(pretrained=True)
347 | # vgg_features = [
348 | # vgg16.features[0:4],
349 | # vgg16.features[5:9],
350 | # vgg16.features[10:16],
351 | # vgg16.features[17:23],
352 | # vgg16.features[24:29]
353 | # ]
354 | # features = [
355 | # self.features1,
356 | # self.features2,
357 | # self.features3,
358 | # self.features4,
359 | # self.features5
360 | # ]
361 | # for l1, l2 in zip(vgg_features, features):
362 | # for i in range(len(list(l1.modules())) // 2):
363 | # assert isinstance(l1[i * 2], nn.Conv2d) == isinstance(l2[i * 3], nn.Conv2d)
364 | # assert l1[i * 2].weight.size() == l2[i * 3].weight.size()
365 | # assert l1[i * 2].bias.size() == l2[i * 3].bias.size()
366 | # l2[i * 3].weight.data = l1[i * 2].weight.data
367 | # l2[i * 3].bias.data = l1[i * 2].bias.data
368 |
369 | # def forward(self, x):
370 | # out = self.features1(x)
371 | # out, indices_1 = self.pool1(out)
372 | # out = self.features2(out)
373 | # out, indices_2 = self.pool2(out)
374 | # out = self.features3(out)
375 | # out, indices_3 = self.pool3(out)
376 | # out = self.features4(out)
377 | # out, indices_4 = self.pool4(out)
378 | # out = self.features5(out)
379 | # out, indices_5 = self.pool5(out)
380 | # out = self.unpool6(out, indices_5)
381 | # out = self.features6(out)
382 | # out = self.unpool7(out, indices_4)
383 | # out = self.features7(out)
384 | # out = self.unpool8(out, indices_3)
385 | # out = self.features8(out)
386 | # out = self.unpool9(out, indices_2)
387 | # out = self.features9(out)
388 | # out = self.unpool10(out, indices_1)
389 | # out = self.final(out)
390 | # return out
391 |
392 |
393 | # Bilinear interpolation upsampling version
394 | # class SegNet(nn.Module):
395 | # def __init__(self, n_classes):
396 | # super(SegNet, self).__init__()
397 | #
398 | # # conv1
399 | # features = []
400 | # features.append(nn.Conv2d(3, 64, 3, padding=1))
401 | # features.append(nn.BatchNorm2d(64))
402 | # features.append(nn.ReLU(inplace=True))
403 | # features.append(nn.Conv2d(64, 64, 3, padding=1))
404 | # features.append(nn.BatchNorm2d(64))
405 | # features.append(nn.ReLU(inplace=True))
406 | # features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/2
407 |
408 | # # conv2
409 | # features.append(nn.Conv2d(64, 128, 3, padding=1))
410 | # features.append(nn.BatchNorm2d(128))
411 | # features.append(nn.ReLU(inplace=True))
412 | # features.append(nn.Conv2d(128, 128, 3, padding=1))
413 | # features.append(nn.BatchNorm2d(128))
414 | # features.append(nn.ReLU(inplace=True))
415 | # features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/4
416 |
417 | # # conv3
418 | # features.append(nn.Conv2d(128, 256, 3, padding=1))
419 | # features.append(nn.BatchNorm2d(256))
420 | # features.append(nn.ReLU(inplace=True))
421 | # features.append(nn.Conv2d(256, 256, 3, padding=1))
422 | # features.append(nn.BatchNorm2d(256))
423 | # features.append(nn.ReLU(inplace=True))
424 | # features.append(nn.Conv2d(256, 256, 3, padding=1))
425 | # features.append(nn.BatchNorm2d(256))
426 | # features.append(nn.ReLU(inplace=True))
427 | # features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/8
428 |
429 | # # conv4
430 | # features.append(nn.Conv2d(256, 512, 3, padding=1))
431 | # features.append(nn.BatchNorm2d(512))
432 | # features.append(nn.ReLU(inplace=True))
433 | # features.append(nn.Conv2d(512, 512, 3, padding=1))
434 | # features.append(nn.BatchNorm2d(512))
435 | # features.append(nn.ReLU(inplace=True))
436 | # features.append(nn.Conv2d(512, 512, 3, padding=1))
437 | # features.append(nn.BatchNorm2d(512))
438 | # features.append(nn.ReLU(inplace=True))
439 | # features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/16
440 |
441 | # # conv5
442 | # features.append(nn.Conv2d(512, 512, 3, padding=1))
443 | # features.append(nn.BatchNorm2d(512))
444 | # features.append(nn.ReLU(inplace=True))
445 | # features.append(nn.Conv2d(512, 512, 3, padding=1))
446 | # features.append(nn.BatchNorm2d(512))
447 | # features.append(nn.ReLU(inplace=True))
448 | # features.append(nn.Conv2d(512, 512, 3, padding=1))
449 | # features.append(nn.BatchNorm2d(512))
450 | # features.append(nn.ReLU(inplace=True))
451 | # features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/32
452 | # self.features = nn.Sequential(*features)
453 |
454 | # # convTranspose1
455 | # up1 = []
456 | # up1.append(nn.Conv2d(512, 512, 3, padding=1))
457 | # up1.append(nn.BatchNorm2d(512))
458 | # up1.append(nn.ReLU(inplace=True))
459 | # up1.append(nn.Conv2d(512, 512, 3, padding=1))
460 | # up1.append(nn.BatchNorm2d(512))
461 | # up1.append(nn.ReLU(inplace=True))
462 | # up1.append(nn.Conv2d(512, 512, 3, padding=1))
463 | # up1.append(nn.BatchNorm2d(512))
464 | # up1.append(nn.ReLU(inplace=True))
465 | # self.up1 = nn.Sequential(*up1)
466 |
467 | # # convTranspose2
468 | # up2 = []
469 | # up2.append(nn.Conv2d(512, 512, 3, padding=1))
470 | # up2.append(nn.BatchNorm2d(512))
471 | # up2.append(nn.ReLU(inplace=True))
472 | # up2.append(nn.Conv2d(512, 512, 3, padding=1))
473 | # up2.append(nn.BatchNorm2d(512))
474 | # up2.append(nn.ReLU(inplace=True))
475 | # up2.append(nn.Conv2d(512, 256, 3, padding=1))
476 | # up2.append(nn.BatchNorm2d(256))
477 | # up2.append(nn.ReLU(inplace=True))
478 | # self.up2 = nn.Sequential(*up2)
479 |
480 | # # convTranspose3
481 | # up3 = []
482 | # up3.append(nn.Conv2d(256, 256, 3, padding=1))
483 | # up3.append(nn.BatchNorm2d(256))
484 | # up3.append(nn.ReLU(inplace=True))
485 | # up3.append(nn.Conv2d(256, 256, 3, padding=1))
486 | # up3.append(nn.BatchNorm2d(256))
487 | # up3.append(nn.ReLU(inplace=True))
488 | # up3.append(nn.Conv2d(256, 128, 3, padding=1))
489 | # up3.append(nn.BatchNorm2d(128))
490 | # up3.append(nn.ReLU(inplace=True))
491 | # self.up3 = nn.Sequential(*up3)
492 |
493 | # # convTranspose4
494 | # up4 = []
495 | # up4.append(nn.Conv2d(128, 128, 3, padding=1))
496 | # up4.append(nn.BatchNorm2d(128))
497 | # up4.append(nn.ReLU(inplace=True))
498 | # up4.append(nn.Conv2d(128, 64, 3, padding=1))
499 | # up4.append(nn.BatchNorm2d(64))
500 | # up4.append(nn.ReLU(inplace=True))
501 | # self.up4 = nn.Sequential(*up4)
502 |
503 | # self.final = nn.Sequential(
504 | # nn.Conv2d(64, 64, kernel_size=3, padding=1),
505 | # nn.BatchNorm2d(64),
506 | # nn.ReLU(inplace=True),
507 | # nn.Conv2d(64, n_classes, kernel_size=3, padding=1),
508 | # )
509 |
510 | # self._initialize_weights()
511 |
512 | # def _initialize_weights(self):
513 | # for m in self.modules():
514 | # if isinstance(m, nn.Conv2d):
515 | # nn.init.kaiming_normal_(m.weight)
516 | # if isinstance(m, nn.BatchNorm2d):
517 | # nn.init.constant_(m.weight, 1)
518 | # nn.init.constant_(m.bias, 0.001)
519 |
520 | # vgg16 = torchvision.models.vgg16_bn(pretrained=True)
521 | # state_dict = vgg16.features.state_dict()
522 | # self.features.load_state_dict(state_dict)
523 |
524 | # def forward(self, x):
525 | # out = self.features(x)
526 | # out = F.interpolate(out, scale_factor=2, mode='bilinear')
527 | # out = self.up1(out)
528 | # out = F.interpolate(out, scale_factor=2, mode='bilinear')
529 | # out = self.up2(out)
530 | # out = F.interpolate(out, scale_factor=2, mode='bilinear')
531 | # out = self.up3(out)
532 | # out = F.interpolate(out, scale_factor=2, mode='bilinear')
533 | # out = self.up4(out)
534 | # out = F.interpolate(out, scale_factor=2, mode='bilinear')
535 | # out = self.final(out)
536 | # return out
537 |
--------------------------------------------------------------------------------
/Models/UNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class EncoderBlock(nn.Module):
6 | def __init__(self, in_channels, out_channels):
7 | super(EncoderBlock, self).__init__()
8 | layers = [
9 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
10 | nn.BatchNorm2d(out_channels),
11 | nn.ReLU(inplace=True),
12 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
13 | nn.BatchNorm2d(out_channels),
14 | nn.ReLU(inplace=True),
15 | ]
16 | self.encode = nn.Sequential(*layers)
17 |
18 | def forward(self, x):
19 | return self.encode(x)
20 |
21 |
22 | class DecoderBlock(nn.Module):
23 | def __init__(self, in_channels, middle_channels, out_channels):
24 | super(DecoderBlock, self).__init__()
25 | self.decode = nn.Sequential(
26 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
27 | nn.BatchNorm2d(middle_channels),
28 | nn.ReLU(inplace=True),
29 | nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1),
30 | nn.BatchNorm2d(middle_channels),
31 | nn.ReLU(inplace=True),
32 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=2, stride=2),
33 | )
34 |
35 | def forward(self, x):
36 | return self.decode(x)
37 |
38 |
39 | class UNet(nn.Module):
40 | def __init__(self, n_classes):
41 | super(UNet, self).__init__()
42 | self.enc1 = EncoderBlock(3, 64)
43 | self.enc1_pool = nn.MaxPool2d(kernel_size=2, stride=2)
44 | self.enc2 = EncoderBlock(64, 128)
45 | self.enc2_pool = nn.MaxPool2d(kernel_size=2, stride=2)
46 | self.enc3 = EncoderBlock(128, 256)
47 | self.enc3_pool = nn.MaxPool2d(kernel_size=2, stride=2)
48 | self.enc4 = EncoderBlock(256, 512)
49 | self.enc4_pool = nn.MaxPool2d(kernel_size=2, stride=2)
50 | self.center = DecoderBlock(512, 1024, 512)
51 | self.dec4 = DecoderBlock(1024, 512, 256)
52 | self.dec3 = DecoderBlock(512, 256, 128)
53 | self.dec2 = DecoderBlock(256, 128, 64)
54 | self.dec1 = nn.Sequential(
55 | nn.Conv2d(128, 64, kernel_size=3, padding=1),
56 | nn.BatchNorm2d(64),
57 | nn.ReLU(inplace=True),
58 | nn.Conv2d(64, 64, kernel_size=3, padding=1),
59 | nn.BatchNorm2d(64),
60 | nn.ReLU(inplace=True),
61 | )
62 | self.final = nn.Conv2d(64, n_classes, kernel_size=1)
63 | initialize_weights(self)
64 |
65 | def forward(self, x):
66 | enc1 = self.enc1(x)
67 | enc1_pool = self.enc1_pool(enc1)
68 | enc2 = self.enc2(enc1_pool)
69 | enc2_pool = self.enc2_pool(enc2)
70 | enc3 = self.enc3(enc2_pool)
71 | enc3_pool = self.enc3_pool(enc3)
72 | enc4 = self.enc4(enc3_pool)
73 | enc4_pool = self.enc4_pool(enc4)
74 | center = self.center(enc4_pool)
75 | dec4 = self.dec4(torch.cat([center, enc4], 1))
76 | dec3 = self.dec3(torch.cat([dec4, enc3], 1))
77 | dec2 = self.dec2(torch.cat([dec3, enc2], 1))
78 | dec1 = self.dec1(torch.cat([dec2, enc1], 1))
79 | final = self.final(dec1)
80 | return final
81 |
82 |
83 | def initialize_weights(model):
84 | for module in model.modules():
85 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
86 | nn.init.kaiming_normal_(module.weight)
87 | if module.bias is not None:
88 | nn.init.constant_(module.bias, 0)
89 | elif isinstance(module, nn.BatchNorm2d):
90 | module.weight.data.fill_(1)
91 | module.bias.data.zero_()
92 |
--------------------------------------------------------------------------------
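A quick shape check for the UNet above (a sketch; it assumes the repository root is on PYTHONPATH and that input height and width are divisible by 16, since the encoder downsamples four times):

    import torch
    from Models.UNet import UNet

    model = UNet(n_classes=12)                  # e.g. a CamVid-style setup
    model.eval()
    with torch.no_grad():
        x = torch.randn(1, 3, 352, 480)         # H and W divisible by 16
        out = model(x)
    print(out.shape)                            # torch.Size([1, 12, 352, 480])
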
/Models/__init__.py:
--------------------------------------------------------------------------------
1 | from .FCN import FCN32s, FCN8sAtOnce
2 | from .UNet import UNet
3 | from .SegNet import SegNet
4 | from .DeepLab_v1 import DeepLabLargeFOV
5 | from .DeepLab_v2 import DeepLabASPPVGG, DeepLabASPPResNet
6 | from .DeepLab_v3 import DeepLabV3
7 | from .DeepLab_v3plus import DeepLabV3Plus
8 | from .Dilation8 import Dilation8
9 | from .PSPNet import PSPNet
10 | import torch
11 |
12 | VALID_MODEL = [
13 | 'fcn32s', 'fcn8s', 'unet', 'segnet', 'deeplab-largefov', 'deeplab-aspp-vgg',
14 | 'deeplab-aspp-resnet', 'deeplab-v3', 'deeplab-v3+', 'dilation8', 'pspnet'
15 | ]
16 |
17 |
18 | def model_loader(model_name, n_classes, resume):
19 | model_name = model_name.lower()
20 | if model_name == 'fcn32s':
21 | model = FCN32s(n_classes=n_classes)
22 | elif model_name == 'fcn8s':
23 | model = FCN8sAtOnce(n_classes=n_classes)
24 | elif model_name == 'unet':
25 | model = UNet(n_classes=n_classes)
26 | elif model_name == 'segnet':
27 | model = SegNet(n_classes=n_classes)
28 | elif model_name == 'deeplab-largefov':
29 | model = DeepLabLargeFOV(n_classes=n_classes)
30 | elif model_name == 'deeplab-aspp-vgg':
31 | model = DeepLabASPPVGG(n_classes=n_classes)
32 | elif model_name == 'deeplab-aspp-resnet':
33 | model = DeepLabASPPResNet(n_classes=n_classes)
34 | elif model_name == 'deeplab-v3':
35 | model = DeepLabV3(n_classes=n_classes)
36 | elif model_name == 'deeplab-v3+':
37 | model = DeepLabV3Plus(n_classes=n_classes)
38 | elif model_name == 'dilation8':
39 | model = Dilation8(n_classes=n_classes)
40 | elif model_name == 'pspnet':
41 | model = PSPNet(n_classes=n_classes)
42 | else:
43 | raise ValueError('Unsupported model, '
44 | 'valid models as follows:\n{}'.format(
45 | ', '.join(VALID_MODEL)))
46 |
47 | start_epoch = 1
48 | if resume:
49 | checkpoint = torch.load(resume)
50 | model.load_state_dict(checkpoint['model_state_dict'])
51 | start_epoch = checkpoint['epoch']
52 | else:
53 | checkpoint = None
54 |
55 | return model, start_epoch, checkpoint
56 |
--------------------------------------------------------------------------------
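A minimal usage sketch for model_loader; the resume path in the comment is only a placeholder:

    from Models import model_loader

    # fresh model, no checkpoint
    model, start_epoch, ckpt = model_loader('unet', n_classes=12, resume=None)
    print(model.__class__.__name__, start_epoch)    # UNet 1

    # resuming restores the weights and the epoch stored in the checkpoint, e.g.
    # model, start_epoch, ckpt = model_loader('unet', 12, resume='logs/<run>/checkpoint.pth.tar')
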
/README.md:
--------------------------------------------------------------------------------
1 | # Semantic segmentation models in PyTorch
2 | 
3 | ## Results
4 | ### PASCAL VOC
5 | model|acc|acc_cls|mean_iu|notes
6 | ---|---|---|---|---
7 | FCN32s|90.17%|75.56%|61.81%|lr=1.0e-10, reduction='sum'
8 | FCN32s(original)|-|-|63.6%|
9 | FCN8sAtOnce|90.27%|74.95%|62.13%|lr=1.0e-10, reduction='sum'
10 | FCN8sAtOnce(original)|-|-|65.4%|
11 | DeepLab-LargeFov|93.71%|72.21%|61.32%|pad images to 513x513 for evaluation
12 | DeepLab-LargeFov|90.90%|73.89%|62.09%|use original resolution for evaluation
13 | DeepLab-LargeFov(original)|-|-|62.25%|
14 | DeepLab-ASPP|93.10%|80.13%|61.07%|
15 | DeepLab-ASPP(original)|-|-|68.96%|
16 |
17 | ### CamVid
18 | model|acc|acc_cls|mean_iu|notes
19 | ---|---|---|---|---
20 | SegNet(Maxunpooling, vgg16-based)|86.71%|66.39%|54.09%|lr=0.01
21 | SegNet(Maxunpooling, vgg16_bn-based)|87.84%|70.75%|57.68%|lr=0.01
22 | SegNet(Bilinear interpolation)|85.86%|71.95%|56.22%|lr=0.01
23 | SegNet(original)|88.6%|65.9%|50.2%|
24 | UNet|84.38%|62.80%|49.83%|lr=0.01
25 |
--------------------------------------------------------------------------------
/augmentations.py:
--------------------------------------------------------------------------------
1 | import random
2 | from PIL import Image, ImageOps
3 |
4 |
5 | class Compose:
6 | def __init__(self, augmentations):
7 | self.augmentations = augmentations
8 |
9 | def __call__(self, imgs, lbls):
10 | assert imgs.size == lbls.size
11 | for aug in self.augmentations:
12 | imgs, lbls = aug(imgs, lbls)
13 |
14 | return imgs, lbls
15 |
16 |
17 | class RandomFlip:
18 | """Flip images horizontally.
19 | """
20 | def __init__(self, prob=0.5):
21 | self.prob = prob
22 |
23 | def __call__(self, image, label):
24 | if random.random() < self.prob:
25 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
26 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
27 | return image, label
28 |
29 |
30 | class RandomCrop:
31 | """Crop images to given size.
32 |
33 | Parameters
34 | ----------
35 |     crop_size: a (height, width) tuple specifying the crop size,
36 |         which may be larger than the original size (the image is padded first).
37 | """
38 | def __init__(self, crop_size):
39 | self.crop_size = crop_size
40 |
41 | @staticmethod
42 | def get_params(img, output_size):
43 | w, h = img.size
44 | th, tw = output_size
45 | if w == tw and h == th:
46 | return 0, 0, h, w
47 |
48 | i = random.randint(0, h - th)
49 | j = random.randint(0, w - tw)
50 | return i, j, th, tw
51 |
52 | def __call__(self, image, label):
53 | if image.size[0] < self.crop_size[1]:
54 | image = ImageOps.expand(image, (self.crop_size[1] - image.size[0], 0), fill=0)
55 | label = ImageOps.expand(label, (self.crop_size[1] - label.size[0], 0), fill=255)
56 | if image.size[1] < self.crop_size[0]:
57 | image = ImageOps.expand(image, (0, self.crop_size[0] - image.size[1]), fill=0)
58 | label = ImageOps.expand(label, (0, self.crop_size[0] - label.size[1]), fill=255)
59 |
60 | i, j, h, w = self.get_params(image, self.crop_size)
61 | image = image.crop((j, i, j + w, i + h))
62 | label = label.crop((j, i, j + w, i + h))
63 |
64 | return image, label
65 |
66 |
67 | class RandomScale:
68 | """Scale images within range.
69 |
70 | Parameters
71 | ----------
72 |     scale_range: a (low, high) tuple giving the range of scale factors.
73 | """
74 | def __init__(self, scale_range):
75 | self.scale = scale_range
76 |
77 | def __call__(self, image, label):
78 | w, h = image.size
79 | scale = random.uniform(self.scale[0], self.scale[1])
80 | ow, oh = int(w * scale), int(h * scale)
81 | image = image.resize((ow, oh), Image.BILINEAR)
82 | label = label.resize((ow, oh), Image.NEAREST)
83 |
84 | return image, label
85 |
86 |
87 | def get_augmentations(args):
88 | """Specify augmentation.
89 | """
90 | augs = []
91 | if args.flip:
92 | augs.append(RandomFlip())
93 |     if getattr(args, 'scale_range', None):  # not all entry points define scale_range
94 |         augs.append(RandomScale(args.scale_range))
95 |     if args.crop_size:  # crop after scaling so every sample has a fixed size
96 |         augs.append(RandomCrop(args.crop_size))
97 |
98 | if augs == []:
99 | return None
100 | print('Using augmentations: ', end=' ')
101 | for x in augs:
102 | print(x.__class__.__name__, end=' ')
103 | print('\n')
104 |
105 | return Compose(augs)
--------------------------------------------------------------------------------
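A sketch of how the augmentations compose on a (PIL image, PIL label) pair; the image sizes below are arbitrary:

    from PIL import Image
    from augmentations import Compose, RandomScale, RandomCrop, RandomFlip

    augs = Compose([RandomFlip(), RandomScale((0.5, 1.5)), RandomCrop((321, 321))])
    img = Image.new('RGB', (500, 375))   # dummy image
    lbl = Image.new('L', (500, 375))     # dummy label map
    img, lbl = augs(img, lbl)
    print(img.size, lbl.size)            # (321, 321) (321, 321)
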
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | from PIL import Image
5 | import scipy.misc
6 | import torch
7 | from torch.utils.data import DataLoader
8 | import tqdm
9 | import Models
10 | from utils import visualize_segmentation, get_tile_image, runningScore, averageMeter
11 | from Dataloader import get_loader
12 | from augmentations import RandomCrop, Compose
13 |
14 | def main():
15 | parser = argparse.ArgumentParser(
16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
17 | )
18 | parser.add_argument('--model', type=str, default='deeplab-largefov')
19 | parser.add_argument('--model_file', type=str, default='/home/ecust/lx/Semantic-Segmentation-PyTorch/logs/deeplab-largefov_20190417_230357/model_best.pth.tar', help='Model path')
20 | parser.add_argument('--dataset_type', type=str, default='voc', help='type of dataset')
21 | parser.add_argument('--dataset', type=str, default='/home/ecust/Datasets/PASCAL VOC/VOCdevkit/VOC2012', help='path to dataset')
22 | parser.add_argument('--img_size', type=tuple, default=None, help='resize images using bilinear interpolation')
23 | parser.add_argument('--crop_size', type=tuple, default=None, help='crop images')
24 | parser.add_argument('--n_classes', type=int, default=21, help='number of classes')
25 | parser.add_argument('--pretrained', type=bool, default=True, help='should be set the same as train.py')
26 | args = parser.parse_args()
27 |
28 | model_file = args.model_file
29 | root = args.dataset
30 | n_classes = args.n_classes
31 |
32 |     crop = None
33 | # crop = Compose([RandomCrop(args.crop_size)])
34 | loader = get_loader(args.dataset_type)
35 | val_loader = DataLoader(
36 | loader(root, n_classes=n_classes, split='val', img_size=args.img_size, augmentations=crop, pretrained=args.pretrained),
37 | batch_size=1, shuffle=False, num_workers=4)
38 |
39 | model, _, _ = Models.model_loader(args.model, n_classes, resume=None)
40 |
41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 | model = model.to(device)
43 |
44 | print('==> Loading {} model file: {}'.format(model.__class__.__name__, model_file))
45 |
46 | model_data = torch.load(model_file)
47 |
48 | try:
49 | model.load_state_dict(model_data)
50 | except Exception:
51 | model.load_state_dict(model_data['model_state_dict'])
52 | model.eval()
53 |
54 | print('==> Evaluating with {} dataset'.format(args.dataset_type))
55 | visualizations = []
56 | metrics = runningScore(n_classes)
57 |
58 | for data, target in tqdm.tqdm(val_loader, total=len(val_loader), ncols=80, leave=False):
59 | data, target = data.to(device), target.to(device)
60 | score = model(data)
61 |
62 | imgs = data.data.cpu()
63 | lbl_pred = score.data.max(1)[1].cpu().numpy()
64 | lbl_true = target.data.cpu()
65 | for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
66 | img, lt = val_loader.dataset.untransform(img, lt)
67 | metrics.update(lt, lp)
68 | if len(visualizations) < 9:
69 | viz = visualize_segmentation(
70 | lbl_pred=lp, lbl_true=lt, img=img,
71 | n_classes=n_classes, dataloader=val_loader)
72 | visualizations.append(viz)
73 | acc, acc_cls, mean_iu, fwavacc, cls_iu = metrics.get_scores()
74 | print('''
75 | Accuracy: {0:.2f}
76 | Accuracy Class: {1:.2f}
77 | Mean IoU: {2:.2f}
78 | FWAV Accuracy: {3:.2f}'''.format(acc * 100,
79 | acc_cls * 100,
80 | mean_iu * 100,
81 | fwavacc * 100) + '\n')
82 |
83 | class_name = val_loader.dataset.class_names
84 | if class_name is not None:
85 | for index, value in enumerate(cls_iu.values()):
86 | offset = 20 - len(class_name[index])
87 | print(class_name[index] + ' ' * offset + f'{value * 100:>.2f}')
88 | else:
89 |         print("\nclass_names is not specified, using class indices instead")
90 | for key, value in cls_iu.items():
91 | print(key, f'{value * 100:>.2f}')
92 |
93 | viz = get_tile_image(visualizations)
94 | # img = Image.fromarray(viz)
95 | # img.save('viz_evaluate.png')
96 | scipy.misc.imsave('viz_evaluate.png', viz)
97 |
98 | if __name__ == '__main__':
99 | main()
--------------------------------------------------------------------------------
/learning_curve.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils import learning_curve
3 |
4 |
5 | def main():
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('log_file')
8 | args = parser.parse_args()
9 |
10 | log_file = args.log_file
11 |
12 | learning_curve(log_file)
13 |
14 |
15 | if __name__ == '__main__':
16 | main()
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import numpy as np
3 | from PIL import Image
4 | import torch
5 |
6 |
7 | def CrossEntropyLoss(score, target, weight, ignore_index, reduction):
8 | """Cross entropy for single or multiple outputs.
9 | """
10 | if not isinstance(score, tuple):
11 | loss = F.cross_entropy(
12 | score, target, weight=weight, ignore_index=ignore_index, reduction=reduction)
13 | return loss
14 |
15 | loss = 0
16 | for s in score:
17 | loss = loss + F.cross_entropy(
18 | s, target, weight=weight, ignore_index=ignore_index, reduction=reduction)
19 | return loss
20 |
21 | def resize_labels(labels, size):
22 | new_labels = []
23 | for label in labels:
24 | label = label.float().cpu().numpy()
25 | label = Image.fromarray(label).resize((size[1], size[0]), Image.NEAREST)
26 | new_labels.append(np.asarray(label))
27 | new_labels = torch.LongTensor(new_labels)
28 | return new_labels
--------------------------------------------------------------------------------
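A sketch of both call forms of CrossEntropyLoss above: a single score map, or a tuple of score maps (as produced by the multi-output models handled in trainer.py), whose losses are summed:

    import torch
    from loss import CrossEntropyLoss

    target = torch.randint(0, 21, (2, 65, 65))
    single = torch.randn(2, 21, 65, 65)
    multi = (torch.randn(2, 21, 65, 65), torch.randn(2, 21, 65, 65))

    l1 = CrossEntropyLoss(single, target, weight=None, ignore_index=-1, reduction='mean')
    l2 = CrossEntropyLoss(multi, target, weight=None, ignore_index=-1, reduction='mean')
    print(l1.item(), l2.item())
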
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn.functional as F
3 |
4 |
5 | # Adapted from:
6 | # https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/metrics.py
7 |
8 |
9 | class runningScore(object):
10 | def __init__(self, n_classes):
11 | self.n_classes = n_classes
12 | self.confusion_matrix = np.zeros((n_classes, n_classes))
13 |
14 | def _fast_hist(self, label_true, label_pred, n_class):
15 | mask = (label_true >= 0) & (label_true < n_class)
16 | hist = np.bincount(
17 | n_class * label_true[mask].astype(int) + label_pred[mask],
18 | minlength=n_class**2).reshape(n_class, n_class)
19 | return hist
20 |
21 | def update(self, label_trues, label_preds):
22 | for lt, lp in zip(label_trues, label_preds):
23 | self.confusion_matrix += self._fast_hist(
24 | lt.flatten(), lp.flatten(), self.n_classes)
25 |
26 | def get_scores(self):
27 | """Returns accuracy score evaluation result.
28 | - overall accuracy
29 | - mean accuracy
30 | - mean IU
31 | - fwavacc
32 | """
33 | hist = self.confusion_matrix
34 | acc = np.diag(hist).sum() / hist.sum()
35 | acc_cls = np.diag(hist) / hist.sum(axis=1)
36 | acc_cls = np.nanmean(acc_cls)
37 | iu = np.diag(hist) / (
38 | hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
39 | mean_iu = np.nanmean(iu)
40 | freq = hist.sum(axis=1) / hist.sum()
41 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
42 | cls_iu = dict(zip(range(self.n_classes), iu))
43 |
44 | return acc, acc_cls, mean_iu, fwavacc, cls_iu
45 |
46 | def reset(self):
47 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
48 |
49 |
50 | class averageMeter(object):
51 | """Computes and stores the average and current value"""
52 |
53 | def __init__(self):
54 | self.reset()
55 |
56 | def reset(self):
57 | self.val = 0
58 | self.avg = 0
59 | self.sum = 0
60 | self.count = 0
61 |
62 | def update(self, val, n=1):
63 | self.val = val
64 | self.sum += val * n
65 | self.count += n
66 | self.avg = self.sum / self.count
67 |
68 |
--------------------------------------------------------------------------------
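A small sketch of runningScore: update accumulates a confusion matrix over (true, predicted) label maps, and get_scores reads out the aggregate metrics:

    import numpy as np
    from metrics import runningScore

    score = runningScore(n_classes=3)
    gt = np.array([[[0, 1], [2, 2]]])      # a batch of one 2x2 label map
    pred = np.array([[[0, 1], [2, 1]]])
    score.update(gt, pred)
    acc, acc_cls, mean_iu, fwavacc, cls_iu = score.get_scores()
    print(f'acc={acc:.2f} mean_iu={mean_iu:.2f}')   # acc=0.75 ...
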
/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def get_optimizer(args, model):
6 | """Optimizer for different models
7 | """
8 | if args.optim.lower() == 'sgd':
9 | if args.model.lower() in ['fcn32s', 'fcn8s']:
10 | optim = fcn_optim(model, args)
11 | elif args.model.lower() in ['deeplab-largefov', 'deeplab-aspp-vgg']:
12 | optim = deeplab_optim(model, args)
13 | elif args.model.lower() in ['deeplab-aspp-resnet']:
14 | optim = deeplabv2_optim(model, args)
15 | else:
16 | optim = torch.optim.SGD(
17 | model.parameters(),
18 | lr=args.lr,
19 | momentum=args.beta1,
20 | weight_decay=args.weight_decay)
21 |
22 | elif args.optim.lower() == 'adam':
23 | optim = torch.optim.Adam(
24 | model.parameters(),
25 | lr=args.lr,
26 | betas=(args.beta1, 0.999),
27 | weight_decay=args.weight_decay)
28 |
29 | return optim
30 |
31 | def fcn_optim(model, args):
32 | """optimizer for fcn32s and fcn8s
33 | """
34 | optim = torch.optim.SGD(
35 | [{'params': model.get_parameters(bias=False)},
36 | {'params': model.get_parameters(bias=True), 'lr': args.lr * 2, 'weight_decay': 0}],
37 | lr=args.lr,
38 | momentum=args.beta1,
39 | weight_decay=args.weight_decay)
40 | return optim
41 |
42 | def deeplab_optim(model, args):
43 | """optimizer for deeplab-v1 and deeplab-v2-vgg
44 | """
45 | optim = torch.optim.SGD(
46 | [{'params': model.get_parameters(bias=False, score=False)},
47 | {'params': model.get_parameters(bias=True, score=False), 'lr': args.lr * 2, 'weight_decay': 0},
48 | {'params': model.get_parameters(bias=False, score=True), 'lr': args.lr * 10},
49 | {'params': model.get_parameters(bias=True, score=True), 'lr': args.lr * 20, 'weight_decay': 0}],
50 | lr=args.lr,
51 | momentum=args.beta1,
52 | weight_decay=args.weight_decay)
53 | return optim
54 |
55 | def deeplabv2_optim(model, args):
56 | """optimizer for deeplab-v2-resnet
57 | """
58 | optim = torch.optim.SGD(
59 | [{'params': model.get_parameters(bias=False, score=False)},
60 | {'params': model.get_parameters(bias=False, score=True), 'lr': args.lr * 10},
61 | {'params': model.get_parameters(bias=True, score=True), 'lr': args.lr * 20, 'weight_decay': 0}],
62 | lr=args.lr,
63 | momentum=args.beta1,
64 | weight_decay=args.weight_decay)
65 | return optim
--------------------------------------------------------------------------------
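The FCN/DeepLab optimizers above rely on per-parameter-group settings; a generic sketch of how such groups override the optimizer defaults (the two-layer model is just a stand-in):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))
    base_lr = 0.001
    optim = torch.optim.SGD(
        [{'params': model[0].parameters()},                                        # backbone: 1x lr
         {'params': model[1].parameters(), 'lr': base_lr * 10, 'weight_decay': 0}],  # head: 10x lr, no decay
        lr=base_lr, momentum=0.9, weight_decay=0.0005)
    for i, group in enumerate(optim.param_groups):
        print(i, group['lr'], group['weight_decay'])   # 0 0.001 0.0005 / 1 0.01 0
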
/preparation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 |
5 | parser = argparse.ArgumentParser(
6 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
7 | )
8 | parser.add_argument('--train_img_root', type=str, required=True, help='path to training images')
9 | parser.add_argument('--train_lbl_root', type=str, required=True, help='path to training labels')
10 | parser.add_argument('--val_img_root', type=str, help='path to validation images')
11 | parser.add_argument('--val_lbl_root', type=str, help='path to validation labels')
12 | parser.add_argument('--train_split', type=float, help='proportion of the dataset to include in the train split')
13 |
14 | args = parser.parse_args()
15 |
16 | train_img_root = args.train_img_root
17 | train_lbl_root = args.train_lbl_root
18 | val_img_root = args.val_img_root
19 | val_lbl_root = args.val_lbl_root
20 | train_split = args.train_split
21 |
22 | if val_img_root is None:
23 | img = []
24 | lbl = []
25 | for root, _, files in os.walk(train_img_root):
27 |         for filename in sorted(files):  # sort so images and labels pair up by index
27 | if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
28 | img.append(os.path.join(root, filename))
29 |
30 | for root, _, files in os.walk(train_lbl_root):
31 |         for filename in sorted(files):
32 | if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
33 | lbl.append(os.path.join(root, filename))
34 |
35 | assert len(img) == len(lbl), 'numbers of images and labels are not equal'
36 |
37 | choice = np.random.choice(len(img), len(img), replace=False)
38 | train = choice[:int(len(img) * train_split)]
39 | val = choice[int(len(img) * train_split):]
40 |
41 | with open('train.txt', 'a') as f:
42 | for index in train:
43 | f.write(' '.join([img[index], lbl[index]]) + '\n')
44 |
45 | with open('val.txt', 'a') as f:
46 | for index in val:
47 | f.write(' '.join([img[index], lbl[index]]) + '\n')
48 | else:
49 | train_img = []
50 | train_lbl = []
51 | val_img = []
52 | val_lbl = []
53 | name_list = [train_img, train_lbl, val_img, val_lbl]
54 |     roots = [train_img_root, train_lbl_root, val_img_root, val_lbl_root]
55 |     for nlist, root in zip(name_list, roots):
56 |         for _root, _, files in os.walk(root):
57 |             for filename in sorted(files):  # sort so images and labels pair up by index
58 | if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
59 | nlist.append(os.path.join(_root, filename))
60 |
61 | with open('train.txt', 'a') as f:
62 | for index in range(len(train_img)):
63 | f.write(' '.join([train_img[index], train_lbl[index]]) + '\n')
64 |
65 | with open('val.txt', 'a') as f:
66 | for index in range(len(val_img)):
67 | f.write(' '.join([val_img[index], val_lbl[index]]) + '\n')
68 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.0.2
2 | numpy==1.22.0
3 | PyYAML
4 | scikit-image==0.14.1
5 | scipy==1.2.1
6 | tqdm==4.31.1
7 | pytz
8 | seaborn
9 | pandas
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import random
4 | import yaml
5 | import argparse
6 | import datetime
7 | import torch
8 | from Dataloader import get_loader
9 | from torch.utils.data import DataLoader
10 | from Models import model_loader
11 | from trainer import Trainer
12 | from utils import get_scheduler
13 | from optimizer import get_optimizer
14 | from augmentations import get_augmentations
15 |
16 | here = osp.dirname(osp.abspath(__file__))
17 |
18 |
19 | def main():
20 | parser = argparse.ArgumentParser(
21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
22 | )
23 | parser.add_argument('--model', type=str, default='deeplab-largefov', help='model to train for')
24 | parser.add_argument('--epochs', type=int, default=50, help='total epochs')
25 | parser.add_argument('--val_epoch', type=int, default=10, help='validation interval')
26 | parser.add_argument('--batch_size', type=int, default=16, help='number of batch size')
27 | parser.add_argument('--img_size', type=tuple, default=None, help='resize images to proper size')
28 | parser.add_argument('--dataset_type', type=str, default='voc', help='choose which dataset to use')
29 | parser.add_argument('--dataset_root', type=str, default='/home/ecust/Datasets/PASCAL VOC/VOC_Aug', help='path to dataset')
30 | parser.add_argument('--n_classes', type=int, default=21, help='number of classes')
31 | parser.add_argument('--resume', default=None, help='path to checkpoint')
32 | parser.add_argument('--optim', type=str, default='sgd', help='optimizer')
33 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
34 | parser.add_argument('--lr_policy', type=str, default='poly', help='learning rate policy')
35 | parser.add_argument('--weight-decay', type=float, default=0.0005, help='weight decay')
36 | parser.add_argument('--beta1', type=float, default=0.9, help='momentum for sgd, beta1 for adam')
37 | parser.add_argument('--lr_decay_step', type=float, default=10, help='step size for step learning policy')
38 | parser.add_argument('--lr_power', type=float, default=0.9, help='power parameter for poly learning policy')
39 | parser.add_argument('--pretrained', type=bool, default=True, help='whether to use pretrained models')
40 | parser.add_argument('--iter_size', type=int, default=10, help='iters to accumulate gradients')
41 |
42 | parser.add_argument('--crop_size', type=tuple, default=(321, 321), help='crop sizes of images')
43 | parser.add_argument('--flip', type=bool, default=True, help='whether to use horizontal flip')
44 |
45 | args = parser.parse_args()
46 |
47 | now = datetime.datetime.now()
48 | args.out = osp.join(here, 'logs', args.model + '_' + now.strftime('%Y%m%d_%H%M%S'))
49 |
50 | if not osp.exists(args.out):
51 | os.makedirs(args.out)
52 | with open(osp.join(args.out, 'config.yaml'), 'w') as f:
53 | yaml.safe_dump(args.__dict__, f, default_flow_style=False)
54 |
55 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56 | print(f'Start training {args.model} using {device.type}\n')
57 |
58 | random.seed(1337)
59 | torch.manual_seed(1337)
60 | torch.cuda.manual_seed(1337)
61 |
62 | # 1. dataset
63 |
64 | root = args.dataset_root
65 | loader = get_loader(args.dataset_type)
66 |
67 | augmentations = get_augmentations(args)
68 |
69 | train_loader = DataLoader(
70 | loader(root, n_classes=args.n_classes, split='train_aug', img_size=args.img_size, augmentations=augmentations,
71 | pretrained=args.pretrained),
72 | batch_size=args.batch_size, shuffle=True, num_workers=4)
73 | val_loader = DataLoader(
74 | loader(root, n_classes=args.n_classes, split='val_id', img_size=args.img_size, pretrained=args.pretrained),
75 | batch_size=1, shuffle=False, num_workers=4)
76 |
77 | # 2. model
78 | model, start_epoch, ckpt = model_loader(args.model, args.n_classes, args.resume)
79 | model = model.to(device)
80 |
81 | # 3. optimizer
82 | optim = get_optimizer(args, model)
83 | if args.resume:
84 | optim.load_state_dict(ckpt['optim_state_dict'])
85 |
86 | scheduler = get_scheduler(optim, args)
87 |
88 | # 4. train
89 | trainer = Trainer(
90 | device=device,
91 | model=model,
92 | optimizer=optim,
93 | scheduler=scheduler,
94 | train_loader=train_loader,
95 | val_loader=val_loader,
96 | out=args.out,
97 | epochs=args.epochs,
98 | n_classes=args.n_classes,
99 | val_epoch=args.val_epoch,
100 | iter_size=args.iter_size
101 | )
102 | trainer.epoch = start_epoch
103 | trainer.train()
104 |
105 |
106 | if __name__ == '__main__':
107 | main()
108 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import os.path as osp
4 | import shutil
5 | import numpy as np
6 | import pytz
7 | import scipy.misc
8 | import torch
9 | import tqdm
10 | from PIL import Image
11 | from loss import CrossEntropyLoss, resize_labels
12 | from utils import visualize_segmentation, get_tile_image, learning_curve
13 | from metrics import runningScore, averageMeter
14 |
15 |
16 | class Trainer:
17 | def __init__(self, device, model, optimizer, scheduler, train_loader,
18 | val_loader, out, epochs, n_classes, val_epoch=10, iter_size=1):
19 | self.device = device
20 |
21 | self.model = model
22 | self.optim = optimizer
23 | self.scheduler = scheduler
24 | self.train_loader = train_loader
25 | self.val_loader = val_loader
26 |
27 | self.timestamp_start = \
28 | datetime.datetime.now(pytz.timezone('UTC'))
29 |
30 | self.val_epoch = val_epoch
31 | self.iter_size = iter_size
32 |
33 | self.out = out
34 | if not osp.exists(self.out):
35 | os.makedirs(self.out)
36 |
37 | self.log_headers = [
38 | 'epoch',
39 | 'train/loss',
40 | 'train/acc',
41 | 'train/acc_cls',
42 | 'train/mean_iu',
43 | 'train/fwavacc',
44 | 'valid/loss',
45 | 'valid/acc',
46 | 'valid/acc_cls',
47 | 'valid/mean_iu',
48 | 'valid/fwavacc',
49 | 'elapsed_time',
50 | ]
51 | if not osp.exists(osp.join(self.out, 'log.csv')):
52 | with open(osp.join(self.out, 'log.csv'), 'w') as f:
53 | f.write(','.join(self.log_headers) + '\n')
54 |
55 | self.n_classes = n_classes
56 | self.epoch = 1
57 | self.epochs = epochs
58 | self.best_mean_iu = 0
59 |
60 | def train_epoch(self):
61 | if self.epoch % self.val_epoch == 0 or self.epoch == 1:
62 | self.validate()
63 |
64 | self.model.train()
65 | train_metrics = runningScore(self.n_classes)
66 | train_loss_meter = averageMeter()
67 |
68 | self.optim.zero_grad()
69 |
70 | for data, target in tqdm.tqdm(
71 | self.train_loader, total=len(self.train_loader),
72 | desc=f'Train epoch={self.epoch}', ncols=80, leave=False):
73 |
74 | self.iter += 1
75 | assert self.model.training
76 |
77 | data, target = data.to(self.device), target.to(self.device)
78 | score = self.model(data)
79 |
80 | weight = self.train_loader.dataset.class_weight
81 | if weight:
82 | weight = torch.Tensor(weight).to(self.device)
83 |
84 | loss = CrossEntropyLoss(score, target, weight=weight, ignore_index=-1, reduction='mean')
85 |
86 | loss_data = loss.data.item()
87 | train_loss_meter.update(loss_data)
88 |
89 | if np.isnan(loss_data):
90 | raise ValueError('loss is nan while training')
91 |
92 | loss /= self.iter_size
93 | loss.backward()
94 |
95 | if self.iter % self.iter_size == 0:
96 | self.optim.step()
97 | self.optim.zero_grad()
98 |
99 |
100 | # if not isinstance(score, tuple):
101 | # lbl_pred = score.data.max(1)[1].cpu().numpy()
102 | # else:
103 | # lbl_pred = score[-1].data.max(1)[1].cpu().numpy()
104 |
105 | # lbl_true = target.data.cpu().numpy()
106 | # lbl_pred, lbl_true = get_multiscale_results(score, target, upsample_logits=False)
107 | if isinstance(score, tuple):
108 | lbl_pred = score[-1].data.max(1)[1].cpu().numpy()
109 | else:
110 | lbl_pred = score.data.max(1)[1].cpu().numpy()
111 | lbl_true = target.data.cpu().numpy()
112 | train_metrics.update(lbl_true, lbl_pred)
113 |
114 | acc, acc_cls, mean_iou, fwavacc, _ = train_metrics.get_scores()
115 | metrics = [acc, acc_cls, mean_iou, fwavacc]
116 |
117 | with open(osp.join(self.out, 'log.csv'), 'a') as f:
118 | elapsed_time = (
119 | datetime.datetime.now(pytz.timezone('UTC')) -
120 | self.timestamp_start).total_seconds()
121 | log = [self.epoch] + [train_loss_meter.avg] + \
122 | metrics + [''] * 5 + [elapsed_time]
123 | log = map(str, log)
124 | f.write(','.join(log) + '\n')
125 |
126 | if self.scheduler:
127 | self.scheduler.step()
128 | if self.epoch % self.val_epoch == 0 or self.epoch == 1:
129 | lr = self.optim.param_groups[0]['lr']
130 | print(f'\nCurrent base learning rate of epoch {self.epoch}: {lr:.7f}')
131 |
132 | train_loss_meter.reset()
133 | train_metrics.reset()
134 |
135 | def validate(self):
136 |
137 | visualizations = []
138 | val_metrics = runningScore(self.n_classes)
139 | val_loss_meter = averageMeter()
140 |
141 | with torch.no_grad():
142 | self.model.eval()
143 | for data, target in tqdm.tqdm(
144 | self.val_loader, total=len(self.val_loader),
145 | desc=f'Valid epoch={self.epoch}', ncols=80, leave=False):
146 |
147 | data, target = data.to(self.device), target.to(self.device)
148 |
149 | score = self.model(data)
150 |
151 | weight = self.val_loader.dataset.class_weight
152 | if weight:
153 | weight = torch.Tensor(weight).to(self.device)
154 |
155 | # target = resize_labels(target, (score.size()[2], score.size()[3]))
156 | # target = target.to(self.device)
157 | loss = CrossEntropyLoss(score, target, weight=weight, reduction='mean', ignore_index=-1)
158 | loss_data = loss.data.item()
159 | if np.isnan(loss_data):
160 | raise ValueError('loss is nan while validating')
161 |
162 | val_loss_meter.update(loss_data)
163 |
164 | # if not isinstance(score, tuple):
165 | # lbl_pred = score.data.max(1)[1].cpu().numpy()
166 | # else:
167 | # lbl_pred = score[-1].data.max(1)[1].cpu().numpy()
168 |
169 | # lbl_pred, lbl_true = get_multiscale_results(score, target, upsample_logits=False)
170 | imgs = data.data.cpu()
171 | if isinstance(score, tuple):
172 | lbl_pred = score[-1].data.max(1)[1].cpu().numpy()
173 | else:
174 | lbl_pred = score.data.max(1)[1].cpu().numpy()
175 | lbl_true = target.data.cpu()
176 | for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
177 | img, lt = self.val_loader.dataset.untransform(img, lt)
178 | val_metrics.update(lt, lp)
179 | # img = Image.fromarray(img).resize((lt.shape[1], lt.shape[0]), Image.BILINEAR)
180 | # img = np.array(img)
181 | if len(visualizations) < 9:
182 | viz = visualize_segmentation(
183 | lbl_pred=lp, lbl_true=lt, img=img,
184 | n_classes=self.n_classes, dataloader=self.train_loader)
185 | visualizations.append(viz)
186 |
187 | acc, acc_cls, mean_iou, fwavacc, _ = val_metrics.get_scores()
188 | metrics = [acc, acc_cls, mean_iou, fwavacc]
189 |
190 | print(f'\nEpoch: {self.epoch}', f'loss: {val_loss_meter.avg}, mIoU: {mean_iou}')
191 |
192 | out = osp.join(self.out, 'visualization_viz')
193 | if not osp.exists(out):
194 | os.makedirs(out)
195 | out_file = osp.join(out, 'epoch{:0>5d}.jpg'.format(self.epoch))
196 | scipy.misc.imsave(out_file, get_tile_image(visualizations))
197 |
198 | with open(osp.join(self.out, 'log.csv'), 'a') as f:
199 | elapsed_time = (
200 | datetime.datetime.now(pytz.timezone('UTC')) -
201 | self.timestamp_start).total_seconds()
202 | log = [self.epoch] + [''] * 5 + \
203 | [val_loss_meter.avg] + metrics + [elapsed_time]
204 | log = map(str, log)
205 | f.write(','.join(log) + '\n')
206 |
207 | mean_iu = metrics[2]
208 | is_best = mean_iu > self.best_mean_iu
209 | if is_best:
210 | self.best_mean_iu = mean_iu
211 | torch.save({
212 | 'epoch': self.epoch,
213 | 'arch': self.model.__class__.__name__,
214 | 'optim_state_dict': self.optim.state_dict(),
215 | 'model_state_dict': self.model.state_dict(),
216 | 'best_mean_iu': self.best_mean_iu,
217 | }, osp.join(self.out, 'checkpoint.pth.tar'))
218 | if is_best:
219 | shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'),
220 | osp.join(self.out, 'model_best.pth.tar'))
221 |
222 | val_loss_meter.reset()
223 | val_metrics.reset()
224 |
225 | def train(self):
226 | self.iter = 0
227 | for epoch in tqdm.trange(self.epoch, self.epochs + 1,
228 | desc='Train', ncols=80):
229 | self.epoch = epoch
230 | self.train_epoch()
231 |
232 | learning_curve(osp.join(self.out, 'log.csv'))
233 |
--------------------------------------------------------------------------------
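A stripped-down sketch of the iter_size gradient accumulation used in train_epoch: the loss is scaled by 1/iter_size and the optimizer steps only every iter_size iterations, so the effective batch size is batch_size * iter_size (the linear model and random data are placeholders):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    model = nn.Linear(4, 2)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    iter_size = 4

    optim.zero_grad()
    for step in range(1, 9):
        x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
        loss = F.cross_entropy(model(x), y) / iter_size
        loss.backward()                 # gradients accumulate across iterations
        if step % iter_size == 0:
            optim.step()                # one update per iter_size mini-batches
            optim.zero_grad()
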
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import pandas
5 | import seaborn
6 | import skimage
7 | import skimage.color
8 | import skimage.transform
9 | from torch.optim import lr_scheduler
10 |
11 | # Adapted from https://github.com/wkentaro/fcn/blob/master/fcn/utils.py
12 |
13 | # -----------------------------------------------------------------------------
14 | # Visualization
15 | # -----------------------------------------------------------------------------
16 |
17 |
18 | def centerize(src, dst_shape, margin_color=None):
19 | """Centerize image for specified image size
20 | @param src: image to centerize
21 | @param dst_shape: image shape (height, width) or (height, width, channel)
22 | """
23 | if src.shape[:2] == dst_shape[:2]:
24 | return src
25 | centerized = np.zeros(dst_shape, dtype=src.dtype)
26 | if margin_color:
27 | centerized[:, :] = margin_color
28 | pad_vertical, pad_horizontal = 0, 0
29 | h, w = src.shape[:2]
30 | dst_h, dst_w = dst_shape[:2]
31 | if h < dst_h:
32 | pad_vertical = (dst_h - h) // 2
33 | if w < dst_w:
34 | pad_horizontal = (dst_w - w) // 2
35 |     centerized[pad_vertical:pad_vertical + h,
36 |                pad_horizontal:pad_horizontal + w] = src
37 | return centerized
38 |
39 |
40 | def _tile_images(imgs, tile_shape, concatenated_image):
41 | """Concatenate images whose sizes are same.
42 | @param imgs: image list which should be concatenated
43 | @param tile_shape: shape for which images should be concatenated
44 | @param concatenated_image: returned image.
45 | if it is None, new image will be created.
46 | """
47 | y_num, x_num = tile_shape
48 | one_width = imgs[0].shape[1]
49 | one_height = imgs[0].shape[0]
50 | if concatenated_image is None:
51 | if len(imgs[0].shape) == 3:
52 | n_channels = imgs[0].shape[2]
53 | assert all(im.shape[2] == n_channels for im in imgs)
54 | concatenated_image = np.zeros(
55 | (one_height * y_num, one_width * x_num, n_channels),
56 | dtype=np.uint8,
57 | )
58 | else:
59 | concatenated_image = np.zeros(
60 | (one_height * y_num, one_width * x_num), dtype=np.uint8)
61 | for y in range(y_num):
62 | for x in range(x_num):
63 | i = x + y * x_num
64 | if i >= len(imgs):
65 | pass
66 | else:
67 |                 concatenated_image[y * one_height:(y + 1) * one_height,
68 |                                    x * one_width:(x + 1) * one_width] = imgs[i]
69 | return concatenated_image
70 |
71 |
72 | def get_tile_image(imgs, tile_shape=None, result_img=None, margin_color=None):
73 | """Concatenate images whose sizes are different.
74 | @param imgs: image list which should be concatenated
75 | @param tile_shape: shape for which images should be concatenated
76 | @param result_img: numpy array to put result image
77 | """
78 |
79 | def resize(*args, **kwargs):
80 | return skimage.transform.resize(*args, **kwargs)
81 |
82 | def get_tile_shape(img_num):
83 | x_num = 0
84 | y_num = int(math.sqrt(img_num))
85 | while x_num * y_num < img_num:
86 | x_num += 1
87 | return y_num, x_num
88 |
89 | if tile_shape is None:
90 | tile_shape = get_tile_shape(len(imgs))
91 |
92 | # get max tile size to which each image should be resized
93 | max_height, max_width = np.inf, np.inf
94 | for img in imgs:
95 | max_height = min([max_height, img.shape[0]])
96 | max_width = min([max_width, img.shape[1]])
97 |
98 | # resize and concatenate images
99 | for i, img in enumerate(imgs):
100 | h, w = img.shape[:2]
101 | dtype = img.dtype
102 | h_scale, w_scale = max_height / h, max_width / w
103 | scale = min([h_scale, w_scale])
104 | h, w = int(scale * h), int(scale * w)
105 | img = resize(
106 | image=img,
107 | output_shape=(h, w),
108 | mode='reflect',
109 | preserve_range=True,
110 | anti_aliasing=True,
111 | ).astype(dtype)
112 | if len(img.shape) == 3:
113 | img = centerize(img, (max_height, max_width, 3), margin_color)
114 | else:
115 | img = centerize(img, (max_height, max_width), margin_color)
116 | imgs[i] = img
117 | return _tile_images(imgs, tile_shape, result_img)
118 |
119 |
120 | def label2rgb(lbl, dataloader, img=None, n_labels=None, alpha=0.5):
121 | if n_labels is None:
122 | n_labels = lbl.max() + 1 # +1 for bg_label 0
123 |
124 | cmap = dataloader.dataset.getpalette()
125 | # cmap = getpalette(n_labels)
126 | # cmap = np.array(cmap).reshape([-1, 3]).astype(np.uint8)
127 |
128 | lbl_viz = cmap[lbl]
129 | lbl_viz[lbl == -1] = (0, 0, 0) # unlabeled
130 |
131 | if img is not None:
132 |
133 | # img_gray = skimage.color.rgb2gray(img)
134 | # img_gray = skimage.color.gray2rgb(img_gray)
135 | # img_gray *= 255
136 | lbl_viz = alpha * lbl_viz + (1 - alpha) * img
137 | lbl_viz = lbl_viz.astype(np.uint8)
138 |
139 | return lbl_viz
140 |
141 |
142 | def visualize_segmentation(**kwargs):
143 | """Visualize segmentation.
144 | Parameters
145 | ----------
146 | img: ndarray
147 | Input image to predict label.
148 | lbl_true: ndarray
149 | Ground truth of the label.
150 | lbl_pred: ndarray
151 | Label predicted.
152 |     n_classes: int
153 |         Number of classes.
154 |     dataloader: DataLoader
155 |         Dataloader whose dataset provides the color palette
156 |         used to colorize the labels.
157 | Returns
158 | -------
159 | img_array: ndarray
160 | Visualized image.
161 | """
162 | img = kwargs.pop('img', None)
163 | lbl_true = kwargs.pop('lbl_true', None)
164 | lbl_pred = kwargs.pop('lbl_pred', None)
165 | n_class = kwargs.pop('n_classes', None)
166 | dataloader = kwargs.pop('dataloader', None)
167 | if kwargs:
168 | raise RuntimeError('Unexpected keys in kwargs: {}'.format(
169 | kwargs.keys()))
170 |
171 | if lbl_true is None and lbl_pred is None:
172 | raise ValueError('lbl_true or lbl_pred must be not None.')
173 |
174 | mask_unlabeled = None
175 | viz_unlabeled = None
176 | if lbl_true is not None:
177 | mask_unlabeled = lbl_true == -1
178 | # lbl_true[mask_unlabeled] = 0
179 | viz_unlabeled = (np.zeros((lbl_true.shape[0], lbl_true.shape[1],
180 | 3))).astype(np.uint8)
181 | # if lbl_pred is not None:
182 | # lbl_pred[mask_unlabeled] = 0
183 |
184 | vizs = []
185 |
186 | if lbl_true is not None:
187 | viz_trues = [
188 | img,
189 | label2rgb(lbl_true, dataloader, n_labels=n_class),
190 | label2rgb(lbl_true, dataloader, img, n_labels=n_class),
191 | ]
192 | viz_trues[1][mask_unlabeled] = viz_unlabeled[mask_unlabeled]
193 | viz_trues[2][mask_unlabeled] = viz_unlabeled[mask_unlabeled]
194 | vizs.append(get_tile_image(viz_trues, (1, 3)))
195 |
196 | if lbl_pred is not None:
197 | viz_preds = [
198 | img,
199 | label2rgb(lbl_pred, dataloader, n_labels=n_class),
200 | label2rgb(lbl_pred, dataloader, img, n_labels=n_class),
201 | ]
202 | if mask_unlabeled is not None and viz_unlabeled is not None:
203 | viz_preds[1][mask_unlabeled] = viz_unlabeled[mask_unlabeled]
204 | viz_preds[2][mask_unlabeled] = viz_unlabeled[mask_unlabeled]
205 | vizs.append(get_tile_image(viz_preds, (1, 3)))
206 |
207 | if len(vizs) == 1:
208 | return vizs[0]
209 | elif len(vizs) == 2:
210 | return get_tile_image(vizs, (2, 1))
211 | else:
212 | raise RuntimeError
213 |
214 |
215 | # -----------------------------------------------------------------------------
216 | # Utilities
217 | # -----------------------------------------------------------------------------
218 |
219 | # Adapted from official CycleGAN implementation
220 |
221 |
222 | def get_scheduler(optimizer, opt):
223 | """Return a learning rate scheduler
224 | Parameters:
225 | optimizer -- the optimizer of the network
226 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
227 | opt.lr_policy is the name of learning rate policy: linear | poly | step | plateau | cosine
228 | For 'linear', we keep the same learning rate for the first epochs
229 | and linearly decay the rate to zero over the next epochs.
230 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
231 | See https://pytorch.org/docs/stable/optim.html for more details.
232 | """
233 | if opt.lr_policy == 'linear':
234 |
235 | def lambda_rule(epoch):
236 | lr = 1.0 - max(0,
237 | epoch + 1 - opt.epochs) / float(opt.niter_decay + 1)
238 | return lr
239 |
240 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
241 | elif opt.lr_policy == 'poly':
242 |
243 | def lambda_rule(epoch):
244 | lr = (1 - epoch / opt.epochs)**opt.lr_power
245 | return lr
246 |
247 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
248 | elif opt.lr_policy == 'step':
249 | scheduler = lr_scheduler.StepLR(
250 | optimizer, step_size=opt.lr_decay_step, gamma=0.1)
251 | elif opt.lr_policy == 'plateau':
252 | scheduler = lr_scheduler.ReduceLROnPlateau(
253 | optimizer, mode='min', factor=0.2, threshold=1e-4, patience=5)
254 | elif opt.lr_policy == 'cosine':
255 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs)
256 | elif opt.lr_policy is None:
257 | scheduler = None
258 | else:
259 |         raise NotImplementedError(
260 | f'learning rate policy {opt.lr_policy} is not implemented')
261 | return scheduler
262 |
263 |
264 | # Adapted from:
265 | # https://github.com/wkentaro/pytorch-fcn/blob/master/examples/voc/learning_curve.py
266 |
267 |
268 | def learning_curve(log_file):
269 | print(f'==> Plotting log file: {log_file}')
270 |
271 | df = pandas.read_csv(log_file)
272 |
273 | colors = ['red', 'green', 'blue', 'purple', 'orange']
274 | colors = seaborn.xkcd_palette(colors)
275 |
276 | plt.figure(figsize=(20, 6), dpi=300)
277 |
278 | row_min = df.min()
279 | row_max = df.max()
280 |
281 | # initialize DataFrame for train
282 | columns = [
283 | 'epoch',
284 | 'train/loss',
285 | 'train/acc',
286 | 'train/acc_cls',
287 | 'train/mean_iu',
288 | 'train/fwavacc',
289 | ]
290 | df_train = df[columns]
291 | # if hasattr(df_train, 'rolling'):
292 | # df_train = df_train.rolling(window=10).mean()
293 | # else:
294 | # df_train = pandas.rolling_mean(df_train, window=10)
295 | df_train = df_train.dropna()
296 |
297 | # initialize DataFrame for val
298 | columns = [
299 | 'epoch',
300 | 'valid/loss',
301 | 'valid/acc',
302 | 'valid/acc_cls',
303 | 'valid/mean_iu',
304 | 'valid/fwavacc',
305 | ]
306 | df_valid = df[columns]
307 | df_valid = df_valid.dropna()
308 |
309 | data_frames = {'train': df_train, 'valid': df_valid}
310 |
311 | n_row = 2
312 | n_col = 2
313 | for i, split in enumerate(['train', 'valid']):
314 | df_split = data_frames[split]
315 |
316 | # loss
317 | plt.subplot(n_row, n_col, i * n_col + 1)
318 | plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
319 | plt.plot(
320 | df_split['epoch'],
321 | df_split[f'{split}/loss'],
322 | '-',
323 | markersize=1,
324 | color=colors[0],
325 | alpha=.5,
326 | label=f'{split} loss')
327 | plt.xlim((1, row_max['epoch']))
328 | plt.ylim(
329 | min(df_split[f'{split}/loss']), max(df_split[f'{split}/loss']))
330 | plt.xlabel('epoch')
331 | plt.ylabel(f'{split} loss')
332 |
333 | # loss (log)
334 | # plt.subplot(n_row, n_col, i * n_col + 2)
335 | # plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
336 | # plt.semilogy(df_split['epoch'], df_split[f'{split}/loss'],
337 | # '-', markersize=1, color=colors[0], alpha=.5,
338 | # label=f'{split} loss')
339 | # plt.xlim((1, row_max['epoch']))
340 | # plt.ylim(min(df_split[f'{split}/loss']), max(df_split[f'{split}/loss']))
341 | # plt.xlabel('epoch')
342 | # plt.ylabel('f{split} loss (log)')
343 |
344 | # lbl accuracy
345 | plt.subplot(n_row, n_col, i * n_col + 2)
346 | plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
347 | plt.plot(
348 | df_split['epoch'],
349 | df_split[f'{split}/acc'],
350 | '-',
351 | markersize=1,
352 | color=colors[1],
353 | alpha=.5,
354 | label=f'{split} accuracy')
355 | plt.plot(
356 | df_split['epoch'],
357 | df_split[f'{split}/acc_cls'],
358 | '-',
359 | markersize=1,
360 | color=colors[2],
361 | alpha=.5,
362 | label=f'{split} accuracy class')
363 | plt.plot(
364 | df_split['epoch'],
365 | df_split[f'{split}/mean_iu'],
366 | '-',
367 | markersize=1,
368 | color=colors[3],
369 | alpha=.5,
370 | label=f'{split} mean IU')
371 | plt.plot(
372 | df_split['epoch'],
373 | df_split[f'{split}/fwavacc'],
374 | '-',
375 | markersize=1,
376 | color=colors[4],
377 | alpha=.5,
378 | label=f'{split} fwav accuracy')
379 | plt.legend()
380 | plt.xlim((1, row_max['epoch']))
381 | plt.ylim((0, 1))
382 | plt.xlabel('epoch')
383 | plt.ylabel(f'{split} label accuracy')
384 |
385 | # out_file = osp.splitext(log_file)[0] + '.png'
386 | out_file = log_file[:-4] + '.png'
387 | plt.savefig(out_file)
388 | print(f'==> Wrote figure to: {out_file}')
389 |
--------------------------------------------------------------------------------
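A sketch of the 'poly' policy from get_scheduler, showing how the base learning rate decays as base_lr * (1 - epoch / epochs) ** lr_power (the tiny model is a placeholder):

    import torch
    import torch.nn as nn
    from torch.optim import lr_scheduler

    model = nn.Linear(2, 2)
    optim = torch.optim.SGD(model.parameters(), lr=0.001)
    epochs, lr_power = 50, 0.9
    scheduler = lr_scheduler.LambdaLR(optim, lr_lambda=lambda e: (1 - e / epochs) ** lr_power)

    for epoch in range(1, 6):
        optim.step()                            # stand-in for one epoch of training
        scheduler.step()
        print(epoch, optim.param_groups[0]['lr'])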