├── .gitignore ├── Dataloader ├── __init__.py ├── baseloader.py ├── camvid_loader.py ├── citys_loader.py ├── custom_loader.py ├── seg11valid.txt └── voc_loader.py ├── Models ├── DeepLab_v1.py ├── DeepLab_v2.py ├── DeepLab_v3.py ├── DeepLab_v3plus.py ├── Dilation8.py ├── FCN.py ├── PSPNet.py ├── SegNet.py ├── UNet.py └── __init__.py ├── README.md ├── augmentations.py ├── evaluate.py ├── learning_curve.py ├── loss.py ├── metrics.py ├── optimizer.py ├── preparation.py ├── requirements.txt ├── train.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | logs 3 | .idea 4 | -------------------------------------------------------------------------------- /Dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from .custom_loader import CustomLoader 2 | from .voc_loader import VOCLoader, SBDLoader, VOC11Val 3 | from .citys_loader import CityscapesLoader 4 | from .camvid_loader import CamVidLoader 5 | 6 | VALID_DATASET = ['voc', 'cityscapes', 'sbd', 'voc11', 'camvid', 'custom'] 7 | 8 | 9 | def get_loader(dataset_type): 10 | if dataset_type.lower() == 'custom': 11 | return CustomLoader 12 | elif dataset_type.lower() == 'voc': 13 | return VOCLoader 14 | elif dataset_type.lower() == 'cityscapes': 15 | return CityscapesLoader 16 | elif dataset_type.lower() == 'sbd': 17 | return SBDLoader 18 | elif dataset_type.lower() == 'voc11': 19 | return VOC11Val 20 | elif dataset_type.lower() == 'camvid': 21 | return CamVidLoader 22 | else: 23 | raise ValueError('Unsupported dataset, ' 24 | 'valid datasets are as follows:\n{}\n' 25 | 'voc11 is only for evaluation'.format(', '.join(VALID_DATASET))) 26 | -------------------------------------------------------------------------------- /Dataloader/baseloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils import data 4 | from torchvision import transforms 5 | 6 | 7 | class BaseLoader(data.Dataset): 8 | # specify class_names if available 9 | class_names = None 10 | 11 | def __init__(self, 12 | root, 13 | n_classes, 14 | split='train', 15 | img_size=None, 16 | augmentations=None, 17 | ignore_index=None, 18 | class_weight=None, 19 | pretrained=False): 20 | 21 | self.root = root 22 | self.n_classes = n_classes 23 | self.split = split 24 | self.img_size = img_size 25 | self.augmentations = augmentations 26 | self.ignore_index = ignore_index 27 | self.class_weight = class_weight 28 | 29 | if pretrained: 30 | # if using a pretrained model, subtract the mean and divide by the standard deviation 31 | self.mean = torch.tensor([0.485, 0.456, 0.406]) 32 | self.std = torch.tensor([0.229, 0.224, 0.225]) 33 | self.tf = transforms.Compose([ 34 | transforms.ToTensor(), 35 | transforms.Normalize(self.mean.tolist(), self.std.tolist()) 36 | ]) 37 | self.untf = transforms.Compose([ 38 | transforms.Normalize((-self.mean / self.std).tolist(), 39 | (1.0 / self.std).tolist()) 40 | ]) 41 | else: 42 | # if not using a pretrained model, only scale images to [0, 1] 43 | self.tf = transforms.Compose([transforms.ToTensor()]) 44 | self.untf = transforms.Compose( 45 | [transforms.Normalize([0, 0, 0], [1, 1, 1])]) 46 | 47 | def __getitem__(self, index): 48 | raise NotImplementedError 49 | 50 | def transform(self, img, lbl): 51 | img = self.tf(img) 52 | lbl = np.array(lbl, dtype=np.int32) 53 | lbl[lbl == 255] = -1 54 | if self.ignore_index: 55 | lbl[lbl == self.ignore_index] = -1 56 | lbl = 
torch.from_numpy(lbl).long() 57 | return img, lbl 58 | 59 | def untransform(self, img, lbl): 60 | img = self.untf(img) 61 | img = img.numpy() 62 | img = img.transpose(1, 2, 0) 63 | img = img * 255 64 | img = img.astype(np.uint8) 65 | lbl = lbl.numpy() 66 | return img, lbl 67 | 68 | def getpalette(self): 69 | raise NotImplementedError 70 | -------------------------------------------------------------------------------- /Dataloader/camvid_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from .baseloader import BaseLoader 6 | 7 | class CamVidLoader(BaseLoader): 8 | """CamVid dataset loader. 9 | Parameters 10 | ---------- 11 | root: path to CamVid dataset. 12 | n_classes: number of classes, default 11. 13 | split: choose subset of dataset, 'train', 'val' or 'test'. 14 | img_size: a list or a tuple, scale image to proper size. 15 | augmentations: whether to perform augmentation. 16 | ignore_index: ignore_index will be ignored in training phase and evaluation, default 11. 17 | class_weight: useful in unbalanced datasets. 18 | pretrained: whether to use pretrained models 19 | """ 20 | class_names = np.array([ 21 | 'sky', 22 | 'building', 23 | 'pole', 24 | 'road', 25 | 'pavement', 26 | 'tree', 27 | 'sign', 28 | 'fence', 29 | 'vehicle', 30 | 'pedestrian', 31 | 'bicyclist', 32 | 'void' 33 | ]) 34 | 35 | def __init__( 36 | self, 37 | root, 38 | n_classes=11, 39 | split='train', 40 | img_size=None, 41 | augmentations=None, 42 | ignore_index=11, 43 | class_weight=None, 44 | pretrained=False 45 | ): 46 | super(CamVidLoader, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained) 47 | 48 | path = os.path.join(self.root, self.split + ".txt") 49 | with open(path, "r") as f: 50 | self.file_list = [file_name.rstrip() for file_name in f] 51 | 52 | self.class_weight = [0.2595, 0.1826, 4.5640, 0.1417, 53 | 0.9051, 0.3826, 9.6446, 1.8418, 54 | 0.6823, 6.2478, 7.3614] 55 | 56 | print(f"Found {len(self.file_list)} {split} images") 57 | 58 | def __len__(self): 59 | return len(self.file_list) 60 | 61 | def __getitem__(self, index): 62 | img_name = self.file_list[index] 63 | img_name = img_name.split()[0].split('/')[-1] 64 | img_path = os.path.join(self.root, self.split, img_name) 65 | if self.split == 'train': 66 | lbl_path = os.path.join(self.root, 'trainannot', img_name) 67 | elif self.split == 'val': 68 | lbl_path = os.path.join(self.root, 'valannot', img_name) 69 | elif self.split == 'test': 70 | lbl_path = os.path.join(self.root, 'testannot', img_name) 71 | 72 | img = Image.open(img_path).convert('RGB') 73 | lbl = Image.open(lbl_path) 74 | if self.img_size: 75 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR) 76 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST) 77 | if self.augmentations: 78 | img, lbl = self.augmentations(img, lbl) 79 | 80 | img, lbl = self.transform(img, lbl) 81 | return img, lbl 82 | 83 | def getpalette(self): 84 | return np.asarray( 85 | [ 86 | [128, 128, 128], 87 | [128, 0, 0], 88 | [192, 192, 128], 89 | [128, 64, 128], 90 | [0, 0, 192], 91 | [128, 128, 0], 92 | [192, 128, 128], 93 | [64, 64, 128], 94 | [64, 0, 128], 95 | [64, 64, 0], 96 | [0, 128, 192], 97 | ] 98 | ) 99 | 100 | 101 | # Test code 102 | # if __name__ == '__main__': 103 | # from torch.utils.data import DataLoader 104 | # root = r'D:/Datasets/CamVid' 105 | # batch_size = 2 106 | # loader = 
CamVidLoader(root=root, img_size=None) 107 | # test_loader = DataLoader(loader, batch_size=batch_size, shuffle=True) 108 | 109 | # palette = test_loader.dataset.getpalette() 110 | # fig, axes = plt.subplots(batch_size, 2, subplot_kw={'xticks': [], 'yticks': []}) 111 | # fig.subplots_adjust(left=0.03, right=0.97, hspace=0.2, wspace=0.05) 112 | 113 | # for imgs, labels in test_loader: 114 | # imgs = imgs.numpy() 115 | # imgs = np.transpose(imgs, [0,2,3,1]) 116 | # labels = labels.numpy() 117 | 118 | # for i in range(batch_size): 119 | # axes[i][0].imshow(imgs[i]) 120 | 121 | # mask_unlabeled = labels[i] == -1 122 | # viz_unlabeled = ( 123 | # np.zeros((labels[i].shape[0], labels[i].shape[1], 3)) 124 | # ).astype(np.uint8) 125 | 126 | # lbl_viz = palette[labels[i]] 127 | # lbl_viz[labels[i] == -1] = (0, 0, 0) 128 | # lbl_viz[mask_unlabeled] = viz_unlabeled[mask_unlabeled] 129 | 130 | # axes[i][1].imshow(lbl_viz.astype(np.uint8)) 131 | # plt.show() 132 | # break 133 | 134 | -------------------------------------------------------------------------------- /Dataloader/citys_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from .baseloader import BaseLoader 6 | from collections import namedtuple 7 | from torch.utils import data 8 | from torchvision import transforms 9 | 10 | 11 | class CityscapesLoader(BaseLoader): 12 | """Cityscapes dataset loader. 13 | Parameters 14 | ---------- 15 | root: path to cityscapes dataset. 16 | for directory: 17 | --cityscapes---leftImg8bit 18 | |-gtFine 19 | |- ... 20 | root should be xxx/cityscapes 21 | n_classes: number of classes, default 19. 22 | split: choose subset of dataset, 'train', 'val' or 'test'. 23 | img_size: scale image to proper size. 24 | augmentations: whether to perform augmentation. 25 | ignore_index: ignore_index will be ignored in training phase and evaluation, default 255. 26 | class_weight: useful in unbalanced datasets.
27 | pretrained: whether to use pretrained models 28 | """ 29 | 30 | CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 31 | 'ignore_in_eval', 'color']) 32 | 33 | classes = [ 34 | # name id trainId ignoreInEval color 35 | CityscapesClass('unlabeled', 0, 255, True, (0, 0, 0)), 36 | CityscapesClass('ego vehicle', 1, 255, True, (0, 0, 0)), 37 | CityscapesClass('rectification border', 2, 255, True, (0, 0, 0)), 38 | CityscapesClass('out of roi', 3, 255, True, (0, 0, 0)), 39 | CityscapesClass('static', 4, 255, True, (0, 0, 0)), 40 | CityscapesClass('dynamic', 5, 255, True, (111, 74, 0)), 41 | CityscapesClass('ground', 6, 255, True, (81, 0, 81)), 42 | CityscapesClass('road', 7, 0, False, (128, 64, 128)), 43 | CityscapesClass('sidewalk', 8, 1, False, (244, 35, 232)), 44 | CityscapesClass('parking', 9, 255, True, (250, 170, 160)), 45 | CityscapesClass('rail track', 10, 255, True, (230, 150, 140)), 46 | CityscapesClass('building', 11, 2, False, (70, 70, 70)), 47 | CityscapesClass('wall', 12, 3, False, (102, 102, 156)), 48 | CityscapesClass('fence', 13, 4, False, (190, 153, 153)), 49 | CityscapesClass('guard rail', 14, 255, True, (180, 165, 180)), 50 | CityscapesClass('bridge', 15, 255, True, (150, 100, 100)), 51 | CityscapesClass('tunnel', 16, 255, True, (150, 120, 90)), 52 | CityscapesClass('pole', 17, 5, False, (153, 153, 153)), 53 | CityscapesClass('polegroup', 18, 255, True, (153, 153, 153)), 54 | CityscapesClass('traffic light', 19, 6, False, (250, 170, 30)), 55 | CityscapesClass('traffic sign', 20, 7, False, (220, 220, 0)), 56 | CityscapesClass('vegetation', 21, 8, False, (107, 142, 35)), 57 | CityscapesClass('terrain', 22, 9, False, (152, 251, 152)), 58 | CityscapesClass('sky', 23, 10, False, (70, 130, 180)), 59 | CityscapesClass('person', 24, 11, False, (220, 20, 60)), 60 | CityscapesClass('rider', 25, 12, False, (255, 0, 0)), 61 | CityscapesClass('car', 26, 13, False, (0, 0, 142)), 62 | CityscapesClass('truck', 27, 14, False, (0, 0, 70)), 63 | CityscapesClass('bus', 28, 15, False, (0, 60, 100)), 64 | CityscapesClass('caravan', 29, 255, True, (0, 0, 90)), 65 | CityscapesClass('trailer', 30, 255, True, (0, 0, 110)), 66 | CityscapesClass('train', 31, 16, False, (0, 80, 100)), 67 | CityscapesClass('motorcycle', 32, 17, False, (0, 0, 230)), 68 | CityscapesClass('bicycle', 33, 18, False, (119, 11, 32)), 69 | CityscapesClass('license plate', -1, -1, True, (0, 0, 142)), 70 | ] 71 | 72 | def __init__( 73 | self, 74 | root, 75 | n_classes=19, 76 | split="train", 77 | img_size=None, 78 | augmentations=None, 79 | ignore_index=255, 80 | class_weight=None, 81 | pretrained=False 82 | ): 83 | super(CityscapesLoader, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained) 84 | 85 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split) 86 | self.labels_dir = os.path.join(self.root, 'gtFine', split) 87 | self.images = [] 88 | self.labels = [] 89 | 90 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 91 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 92 | 27, 28, 31, 32, 33] 93 | self.class_map = dict(zip(self.valid_classes, range(self.n_classes))) 94 | 95 | for city in os.listdir(self.images_dir): 96 | img_dir = os.path.join(self.images_dir, city) 97 | label_dir = os.path.join(self.labels_dir, city) 98 | for file_name in os.listdir(img_dir): 99 | label_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], 100 | 'gtFine_labelIds.png') 101 | 
self.images.append(os.path.join(img_dir, file_name)) 102 | self.labels.append(os.path.join(label_dir, label_name)) 103 | 104 | print(f"Found {len(self.images)} {split} images") 105 | 106 | def __len__(self): 107 | return len(self.images) 108 | 109 | def __getitem__(self, index): 110 | img = Image.open(self.images[index]).convert('RGB') 111 | lbl = Image.open(self.labels[index]) 112 | 113 | if self.img_size: 114 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR) 115 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST) 116 | 117 | if self.augmentations: 118 | img, lbl = self.augmentations(img, lbl) 119 | 120 | img, lbl = self.transform(img, lbl) 121 | return img, lbl 122 | 123 | def transform(self, img, lbl): 124 | img = self.tf(img) 125 | 126 | lbl = np.array(lbl, dtype=np.int32) 127 | lbl = self.encode_segmap(lbl) 128 | lbl = torch.from_numpy(lbl).long() 129 | return img, lbl 130 | 131 | def getpalette(self): 132 | return np.array([ 133 | [128, 64, 128], 134 | [244, 35, 232], 135 | [70, 70, 70], 136 | [102, 102, 156], 137 | [190, 153, 153], 138 | [153, 153, 153], 139 | [250, 170, 30], 140 | [220, 220, 0], 141 | [107, 142, 35], 142 | [152, 251, 152], 143 | [70, 130, 180], 144 | [220, 20, 60], 145 | [255, 0, 0], 146 | [0, 0, 142], 147 | [0, 0, 70], 148 | [0, 60, 100], 149 | [0, 80, 100], 150 | [0, 0, 230], 151 | [119, 11, 32] 152 | ]) 153 | 154 | def decode_segmap(self, label_mask): 155 | label_colours = self.getpalette() 156 | r = label_mask.copy() 157 | g = label_mask.copy() 158 | b = label_mask.copy() 159 | for ll in range(0, self.n_classes): 160 | r[label_mask == ll] = label_colours[ll, 0] 161 | g[label_mask == ll] = label_colours[ll, 1] 162 | b[label_mask == ll] = label_colours[ll, 2] 163 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 164 | rgb[:, :, 0] = r / 255.0 165 | rgb[:, :, 1] = g / 255.0 166 | rgb[:, :, 2] = b / 255.0 167 | 168 | return rgb 169 | 170 | def encode_segmap(self, mask): 171 | # Put all void classes to -1 172 | for _voidc in self.void_classes: 173 | mask[mask == _voidc] = -1 174 | for _validc in self.valid_classes: 175 | mask[mask == _validc] = self.class_map[_validc] 176 | return mask 177 | 178 | 179 | if __name__ == "__main__": 180 | import matplotlib.pyplot as plt 181 | 182 | local_path = "/home/ecust/zww/DANet/datasets/cityscapes" 183 | dst = CityscapesLoader(local_path) 184 | bs = 4 185 | trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0) 186 | for i, data_samples in enumerate(trainloader): 187 | imgs, labels = data_samples 188 | 189 | plt.subplots(1, 2) 190 | for j in range(1): 191 | plt.subplot(1, 2, j + 1) 192 | plt.imshow(np.transpose(imgs.numpy()[j], [1, 2, 0])) 193 | plt.subplot(1, 2, j + 2) 194 | plt.imshow(dst.decode_segmap(labels.numpy()[j])) 195 | plt.show() 196 | -------------------------------------------------------------------------------- /Dataloader/custom_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from .baseloader import BaseLoader 5 | from PIL import Image 6 | 7 | 8 | class CustomLoader(BaseLoader): 9 | """Custom dataset loader. 10 | Parameters 11 | ---------- 12 | root: path to custom dataset, with train.txt and val.txt together. 13 | i.e., -----dataset 14 | |--train.txt 15 | |--val.txt 16 | n_classes: number of classes. 17 | split: choose subset of dataset, 'train','val' or 'test'. 18 | img_size: scale image to proper size.
19 | augmentations: whether to perform augmentation. 20 | ignore_index: ingore_index will be ignored in training phase and evaluation. 21 | class_weight: useful in unbalanced datasets. 22 | pretrained: whether to use pretrained models 23 | """ 24 | # specify class_names if necessary 25 | class_names = None 26 | 27 | def __init__( 28 | self, 29 | root, 30 | n_classes, 31 | split="train", 32 | img_size=None, 33 | augmentations=None, 34 | ignore_index=None, 35 | class_weight=None, 36 | pretrained=False 37 | ): 38 | super(CustomLoader, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained) 39 | 40 | path = os.path.join(self.root, split + ".txt") 41 | with open(path, "r") as f: 42 | self.file_list = [file_name.rstrip().split() for file_name in f] 43 | 44 | print(f"Found {len(self.file_list)} {split} images") 45 | 46 | def __len__(self): 47 | return len(self.file_list) 48 | 49 | def __getitem__(self, index): 50 | img_name = self.file_list[index][0] 51 | lbl_name = self.file_list[index][1] 52 | 53 | img = Image.open(os.path.join(self.root, img_name)).convert('RGB') 54 | lbl = Image.open(os.path.join(self.root, lbl_name)) 55 | 56 | if self.img_size: 57 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR) 58 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST) 59 | 60 | if self.augmentations: 61 | img, lbl = self.augmentations(img, lbl) 62 | 63 | img, lbl = self.transform(img, lbl) 64 | return img, lbl 65 | 66 | def getpalette(self): 67 | """for custom palette, if not specified, use pascal voc palette by default. 68 | """ 69 | n = self.n_classes 70 | palette = [0]*(n*3) 71 | for j in range(0, n): 72 | lab = j 73 | palette[j*3+0] = 0 74 | palette[j*3+1] = 0 75 | palette[j*3+2] = 0 76 | i = 0 77 | while (lab > 0): 78 | palette[j*3+0] |= (((lab >> 0) & 1) << (7-i)) 79 | palette[j*3+1] |= (((lab >> 1) & 1) << (7-i)) 80 | palette[j*3+2] |= (((lab >> 2) & 1) << (7-i)) 81 | i = i + 1 82 | lab >>= 3 83 | palette = np.array(palette).reshape([-1, 3]).astype(np.uint8) 84 | return palette 85 | 86 | 87 | # Test code 88 | # if __name__ == '__main__': 89 | # from torch.utils.data import DataLoader 90 | # root = '' 91 | # batch_size = 2 92 | # loader = CustomLoader(root=root, img_size=None) 93 | # test_loader = DataLoader(loader, batch_size=batch_size, shuffle=True) 94 | 95 | # palette = test_loader.dataset.getpalette() 96 | # fig, axes = plt.subplots(batch_size, 2, subplot_kw={'xticks': [], 'yticks': []}) 97 | # fig.subplots_adjust(left=0.03, right=0.97, hspace=0.2, wspace=0.05) 98 | 99 | # for imgs, labels in test_loader: 100 | # imgs = imgs.numpy() 101 | # imgs = np.transpose(imgs, [0,2,3,1]) 102 | # labels = labels.numpy() 103 | 104 | # for i in range(batch_size): 105 | # axes[i][0].imshow(imgs[i]) 106 | 107 | # mask_unlabeled = labels[i] == -1 108 | # viz_unlabeled = ( 109 | # np.zeros((labels[i].shape[0], labels[i].shape[1], 3)) 110 | # ).astype(np.uint8) 111 | 112 | # lbl_viz = palette[labels[i]] 113 | # lbl_viz[labels[i] == -1] = (0, 0, 0) 114 | # lbl_viz[mask_unlabeled] = viz_unlabeled[mask_unlabeled] 115 | 116 | # axes[i][1].imshow(lbl_viz.astype(np.uint8)) 117 | # plt.show() 118 | # break 119 | 120 | -------------------------------------------------------------------------------- /Dataloader/seg11valid.txt: -------------------------------------------------------------------------------- 1 | 2007_000033 2 | 2007_000042 3 | 2007_000061 4 | 2007_000123 5 | 2007_000129 6 | 2007_000175 7 | 2007_000187 8 | 2007_000323 9 | 
2007_000332 10 | 2007_000346 11 | 2007_000452 12 | 2007_000464 13 | 2007_000491 14 | 2007_000529 15 | 2007_000559 16 | 2007_000572 17 | 2007_000629 18 | 2007_000636 19 | 2007_000661 20 | 2007_000663 21 | 2007_000676 22 | 2007_000727 23 | 2007_000762 24 | 2007_000783 25 | 2007_000799 26 | 2007_000804 27 | 2007_000830 28 | 2007_000837 29 | 2007_000847 30 | 2007_000862 31 | 2007_000925 32 | 2007_000999 33 | 2007_001154 34 | 2007_001175 35 | 2007_001239 36 | 2007_001284 37 | 2007_001288 38 | 2007_001289 39 | 2007_001299 40 | 2007_001311 41 | 2007_001321 42 | 2007_001377 43 | 2007_001408 44 | 2007_001423 45 | 2007_001430 46 | 2007_001457 47 | 2007_001458 48 | 2007_001526 49 | 2007_001568 50 | 2007_001585 51 | 2007_001586 52 | 2007_001587 53 | 2007_001594 54 | 2007_001630 55 | 2007_001677 56 | 2007_001678 57 | 2007_001717 58 | 2007_001733 59 | 2007_001761 60 | 2007_001763 61 | 2007_001774 62 | 2007_001884 63 | 2007_001955 64 | 2007_002046 65 | 2007_002094 66 | 2007_002119 67 | 2007_002132 68 | 2007_002260 69 | 2007_002266 70 | 2007_002268 71 | 2007_002284 72 | 2007_002376 73 | 2007_002378 74 | 2007_002387 75 | 2007_002400 76 | 2007_002412 77 | 2007_002426 78 | 2007_002427 79 | 2007_002445 80 | 2007_002470 81 | 2007_002539 82 | 2007_002565 83 | 2007_002597 84 | 2007_002618 85 | 2007_002619 86 | 2007_002624 87 | 2007_002643 88 | 2007_002648 89 | 2007_002719 90 | 2007_002728 91 | 2007_002823 92 | 2007_002824 93 | 2007_002852 94 | 2007_002903 95 | 2007_003011 96 | 2007_003020 97 | 2007_003022 98 | 2007_003051 99 | 2007_003088 100 | 2007_003101 101 | 2007_003106 102 | 2007_003110 103 | 2007_003131 104 | 2007_003134 105 | 2007_003137 106 | 2007_003143 107 | 2007_003169 108 | 2007_003188 109 | 2007_003194 110 | 2007_003195 111 | 2007_003201 112 | 2007_003349 113 | 2007_003367 114 | 2007_003373 115 | 2007_003499 116 | 2007_003503 117 | 2007_003506 118 | 2007_003530 119 | 2007_003571 120 | 2007_003587 121 | 2007_003611 122 | 2007_003621 123 | 2007_003682 124 | 2007_003711 125 | 2007_003714 126 | 2007_003742 127 | 2007_003786 128 | 2007_003841 129 | 2007_003848 130 | 2007_003861 131 | 2007_003872 132 | 2007_003917 133 | 2007_003957 134 | 2007_003991 135 | 2007_004033 136 | 2007_004052 137 | 2007_004112 138 | 2007_004121 139 | 2007_004143 140 | 2007_004189 141 | 2007_004190 142 | 2007_004193 143 | 2007_004241 144 | 2007_004275 145 | 2007_004281 146 | 2007_004380 147 | 2007_004392 148 | 2007_004405 149 | 2007_004468 150 | 2007_004483 151 | 2007_004510 152 | 2007_004538 153 | 2007_004558 154 | 2007_004644 155 | 2007_004649 156 | 2007_004712 157 | 2007_004722 158 | 2007_004856 159 | 2007_004866 160 | 2007_004902 161 | 2007_004969 162 | 2007_005058 163 | 2007_005074 164 | 2007_005107 165 | 2007_005114 166 | 2007_005149 167 | 2007_005173 168 | 2007_005281 169 | 2007_005294 170 | 2007_005296 171 | 2007_005304 172 | 2007_005331 173 | 2007_005354 174 | 2007_005358 175 | 2007_005428 176 | 2007_005460 177 | 2007_005469 178 | 2007_005509 179 | 2007_005547 180 | 2007_005600 181 | 2007_005608 182 | 2007_005626 183 | 2007_005689 184 | 2007_005696 185 | 2007_005705 186 | 2007_005759 187 | 2007_005803 188 | 2007_005813 189 | 2007_005828 190 | 2007_005844 191 | 2007_005845 192 | 2007_005857 193 | 2007_005911 194 | 2007_005915 195 | 2007_005978 196 | 2007_006028 197 | 2007_006035 198 | 2007_006046 199 | 2007_006076 200 | 2007_006086 201 | 2007_006117 202 | 2007_006171 203 | 2007_006241 204 | 2007_006260 205 | 2007_006277 206 | 2007_006348 207 | 2007_006364 208 | 2007_006373 209 | 2007_006444 210 | 2007_006449 211 | 
2007_006549 212 | 2007_006553 213 | 2007_006560 214 | 2007_006647 215 | 2007_006678 216 | 2007_006680 217 | 2007_006698 218 | 2007_006761 219 | 2007_006802 220 | 2007_006837 221 | 2007_006841 222 | 2007_006864 223 | 2007_006866 224 | 2007_006946 225 | 2007_007007 226 | 2007_007084 227 | 2007_007109 228 | 2007_007130 229 | 2007_007165 230 | 2007_007168 231 | 2007_007195 232 | 2007_007196 233 | 2007_007203 234 | 2007_007211 235 | 2007_007235 236 | 2007_007341 237 | 2007_007414 238 | 2007_007417 239 | 2007_007470 240 | 2007_007477 241 | 2007_007493 242 | 2007_007498 243 | 2007_007524 244 | 2007_007534 245 | 2007_007624 246 | 2007_007651 247 | 2007_007688 248 | 2007_007748 249 | 2007_007795 250 | 2007_007810 251 | 2007_007815 252 | 2007_007818 253 | 2007_007836 254 | 2007_007849 255 | 2007_007881 256 | 2007_007996 257 | 2007_008051 258 | 2007_008084 259 | 2007_008106 260 | 2007_008110 261 | 2007_008204 262 | 2007_008222 263 | 2007_008256 264 | 2007_008260 265 | 2007_008339 266 | 2007_008374 267 | 2007_008415 268 | 2007_008430 269 | 2007_008543 270 | 2007_008547 271 | 2007_008596 272 | 2007_008645 273 | 2007_008670 274 | 2007_008708 275 | 2007_008722 276 | 2007_008747 277 | 2007_008802 278 | 2007_008815 279 | 2007_008897 280 | 2007_008944 281 | 2007_008964 282 | 2007_008973 283 | 2007_008980 284 | 2007_009015 285 | 2007_009068 286 | 2007_009084 287 | 2007_009088 288 | 2007_009096 289 | 2007_009221 290 | 2007_009245 291 | 2007_009251 292 | 2007_009252 293 | 2007_009258 294 | 2007_009320 295 | 2007_009323 296 | 2007_009331 297 | 2007_009346 298 | 2007_009392 299 | 2007_009413 300 | 2007_009419 301 | 2007_009446 302 | 2007_009458 303 | 2007_009521 304 | 2007_009562 305 | 2007_009592 306 | 2007_009654 307 | 2007_009655 308 | 2007_009684 309 | 2007_009687 310 | 2007_009691 311 | 2007_009706 312 | 2007_009750 313 | 2007_009756 314 | 2007_009764 315 | 2007_009794 316 | 2007_009817 317 | 2007_009841 318 | 2007_009897 319 | 2007_009911 320 | 2007_009923 321 | 2007_009938 322 | 2008_000073 323 | 2008_000075 324 | 2008_000107 325 | 2008_000123 326 | 2008_000149 327 | 2008_000213 328 | 2008_000215 329 | 2008_000223 330 | 2008_000233 331 | 2008_000239 332 | 2008_000271 333 | 2008_000345 334 | 2008_000391 335 | 2008_000401 336 | 2008_000501 337 | 2008_000533 338 | 2008_000573 339 | 2008_000589 340 | 2008_000657 341 | 2008_000661 342 | 2008_000725 343 | 2008_000731 344 | 2008_000763 345 | 2008_000765 346 | 2008_000811 347 | 2008_000853 348 | 2008_000911 349 | 2008_000919 350 | 2008_000943 351 | 2008_001135 352 | 2008_001231 353 | 2008_001249 354 | 2008_001379 355 | 2008_001433 356 | 2008_001439 357 | 2008_001513 358 | 2008_001531 359 | 2008_001547 360 | 2008_001715 361 | 2008_001821 362 | 2008_001885 363 | 2008_001971 364 | 2008_002043 365 | 2008_002205 366 | 2008_002239 367 | 2008_002269 368 | 2008_002273 369 | 2008_002379 370 | 2008_002383 371 | 2008_002467 372 | 2008_002521 373 | 2008_002623 374 | 2008_002681 375 | 2008_002775 376 | 2008_002835 377 | 2008_002859 378 | 2008_003105 379 | 2008_003135 380 | 2008_003155 381 | 2008_003369 382 | 2008_003709 383 | 2008_003777 384 | 2008_003821 385 | 2008_003885 386 | 2008_004069 387 | 2008_004172 388 | 2008_004175 389 | 2008_004279 390 | 2008_004339 391 | 2008_004345 392 | 2008_004363 393 | 2008_004453 394 | 2008_004562 395 | 2008_004575 396 | 2008_004621 397 | 2008_004659 398 | 2008_004705 399 | 2008_004995 400 | 2008_005049 401 | 2008_005097 402 | 2008_005105 403 | 2008_005145 404 | 2008_005217 405 | 2008_005262 406 | 2008_005439 407 | 2008_005525 408 | 
2008_005633 409 | 2008_005637 410 | 2008_005691 411 | 2008_006055 412 | 2008_006229 413 | 2008_006327 414 | 2008_006553 415 | 2008_006835 416 | 2008_007025 417 | 2008_007031 418 | 2008_007123 419 | 2008_007497 420 | 2008_007677 421 | 2008_007797 422 | 2008_007811 423 | 2008_008051 424 | 2008_008103 425 | 2008_008301 426 | 2009_000013 427 | 2009_000022 428 | 2009_000032 429 | 2009_000037 430 | 2009_000039 431 | 2009_000087 432 | 2009_000121 433 | 2009_000149 434 | 2009_000201 435 | 2009_000205 436 | 2009_000219 437 | 2009_000335 438 | 2009_000351 439 | 2009_000387 440 | 2009_000391 441 | 2009_000446 442 | 2009_000455 443 | 2009_000457 444 | 2009_000469 445 | 2009_000487 446 | 2009_000523 447 | 2009_000619 448 | 2009_000641 449 | 2009_000675 450 | 2009_000705 451 | 2009_000723 452 | 2009_000727 453 | 2009_000771 454 | 2009_000845 455 | 2009_000879 456 | 2009_000919 457 | 2009_000931 458 | 2009_000935 459 | 2009_000989 460 | 2009_000991 461 | 2009_001255 462 | 2009_001299 463 | 2009_001333 464 | 2009_001363 465 | 2009_001391 466 | 2009_001411 467 | 2009_001433 468 | 2009_001505 469 | 2009_001535 470 | 2009_001565 471 | 2009_001607 472 | 2009_001663 473 | 2009_001683 474 | 2009_001687 475 | 2009_001731 476 | 2009_001775 477 | 2009_001851 478 | 2009_001941 479 | 2009_002035 480 | 2009_002165 481 | 2009_002171 482 | 2009_002221 483 | 2009_002291 484 | 2009_002295 485 | 2009_002317 486 | 2009_002445 487 | 2009_002487 488 | 2009_002521 489 | 2009_002527 490 | 2009_002535 491 | 2009_002539 492 | 2009_002549 493 | 2009_002571 494 | 2009_002573 495 | 2009_002591 496 | 2009_002635 497 | 2009_002649 498 | 2009_002651 499 | 2009_002727 500 | 2009_002749 501 | 2009_002753 502 | 2009_002771 503 | 2009_002887 504 | 2009_002975 505 | 2009_003003 506 | 2009_003005 507 | 2009_003059 508 | 2009_003063 509 | 2009_003065 510 | 2009_003071 511 | 2009_003105 512 | 2009_003123 513 | 2009_003193 514 | 2009_003269 515 | 2009_003273 516 | 2009_003311 517 | 2009_003323 518 | 2009_003343 519 | 2009_003387 520 | 2009_003481 521 | 2009_003517 522 | 2009_003523 523 | 2009_003549 524 | 2009_003551 525 | 2009_003589 526 | 2009_003607 527 | 2009_003703 528 | 2009_003707 529 | 2009_003771 530 | 2009_003849 531 | 2009_003857 532 | 2009_003895 533 | 2009_004021 534 | 2009_004033 535 | 2009_004043 536 | 2009_004099 537 | 2009_004125 538 | 2009_004217 539 | 2009_004255 540 | 2009_004455 541 | 2009_004507 542 | 2009_004509 543 | 2009_004579 544 | 2009_004581 545 | 2009_004687 546 | 2009_004801 547 | 2009_004859 548 | 2009_004867 549 | 2009_004895 550 | 2009_004969 551 | 2009_004993 552 | 2009_005087 553 | 2009_005089 554 | 2009_005137 555 | 2009_005189 556 | 2009_005217 557 | 2009_005219 558 | 2010_000003 559 | 2010_000065 560 | 2010_000083 561 | 2010_000159 562 | 2010_000163 563 | 2010_000309 564 | 2010_000427 565 | 2010_000559 566 | 2010_000573 567 | 2010_000639 568 | 2010_000683 569 | 2010_000907 570 | 2010_000961 571 | 2010_001017 572 | 2010_001061 573 | 2010_001069 574 | 2010_001149 575 | 2010_001151 576 | 2010_001251 577 | 2010_001313 578 | 2010_001327 579 | 2010_001331 580 | 2010_001553 581 | 2010_001557 582 | 2010_001563 583 | 2010_001577 584 | 2010_001579 585 | 2010_001767 586 | 2010_001773 587 | 2010_001851 588 | 2010_001995 589 | 2010_002017 590 | 2010_002025 591 | 2010_002137 592 | 2010_002147 593 | 2010_002161 594 | 2010_002271 595 | 2010_002305 596 | 2010_002361 597 | 2010_002531 598 | 2010_002623 599 | 2010_002693 600 | 2010_002701 601 | 2010_002763 602 | 2010_002921 603 | 2010_002929 604 | 2010_002939 605 | 
2010_003123 606 | 2010_003187 607 | 2010_003207 608 | 2010_003239 609 | 2010_003275 610 | 2010_003325 611 | 2010_003365 612 | 2010_003381 613 | 2010_003409 614 | 2010_003453 615 | 2010_003473 616 | 2010_003495 617 | 2010_003531 618 | 2010_003547 619 | 2010_003675 620 | 2010_003781 621 | 2010_003813 622 | 2010_003915 623 | 2010_003971 624 | 2010_004041 625 | 2010_004063 626 | 2010_004149 627 | 2010_004165 628 | 2010_004219 629 | 2010_004355 630 | 2010_004419 631 | 2010_004479 632 | 2010_004529 633 | 2010_004543 634 | 2010_004551 635 | 2010_004559 636 | 2010_004697 637 | 2010_004763 638 | 2010_004783 639 | 2010_004795 640 | 2010_004815 641 | 2010_004825 642 | 2010_005013 643 | 2010_005021 644 | 2010_005063 645 | 2010_005159 646 | 2010_005187 647 | 2010_005245 648 | 2010_005305 649 | 2010_005421 650 | 2010_005531 651 | 2010_005705 652 | 2010_005709 653 | 2010_005719 654 | 2010_005727 655 | 2010_005871 656 | 2010_005877 657 | 2010_005899 658 | 2010_005991 659 | 2011_000045 660 | 2011_000051 661 | 2011_000173 662 | 2011_000185 663 | 2011_000291 664 | 2011_000419 665 | 2011_000435 666 | 2011_000455 667 | 2011_000479 668 | 2011_000503 669 | 2011_000521 670 | 2011_000536 671 | 2011_000598 672 | 2011_000607 673 | 2011_000661 674 | 2011_000669 675 | 2011_000747 676 | 2011_000789 677 | 2011_000809 678 | 2011_000843 679 | 2011_000969 680 | 2011_001069 681 | 2011_001071 682 | 2011_001161 683 | 2011_001263 684 | 2011_001281 685 | 2011_001287 686 | 2011_001313 687 | 2011_001341 688 | 2011_001421 689 | 2011_001447 690 | 2011_001529 691 | 2011_001567 692 | 2011_001589 693 | 2011_001597 694 | 2011_001601 695 | 2011_001607 696 | 2011_001613 697 | 2011_001619 698 | 2011_001665 699 | 2011_001669 700 | 2011_001713 701 | 2011_001745 702 | 2011_001775 703 | 2011_001793 704 | 2011_001812 705 | 2011_001868 706 | 2011_001984 707 | 2011_002041 708 | 2011_002121 709 | 2011_002223 710 | 2011_002279 711 | 2011_002295 712 | 2011_002317 713 | 2011_002327 714 | 2011_002343 715 | 2011_002371 716 | 2011_002379 717 | 2011_002391 718 | 2011_002509 719 | 2011_002535 720 | 2011_002575 721 | 2011_002589 722 | 2011_002623 723 | 2011_002641 724 | 2011_002675 725 | 2011_002685 726 | 2011_002713 727 | 2011_002863 728 | 2011_002929 729 | 2011_002993 730 | 2011_002997 731 | 2011_003011 732 | 2011_003055 733 | 2011_003085 734 | 2011_003145 735 | 2011_003197 736 | 2011_003271 737 | -------------------------------------------------------------------------------- /Dataloader/voc_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from scipy.io import loadmat 6 | from .baseloader import BaseLoader 7 | 8 | 9 | class VOCLoader(BaseLoader): 10 | """PASCAL VOC dataset loader. 11 | Parameters 12 | ---------- 13 | root: path to pascal voc dataset. 14 | for directory: 15 | --VOCdevkit--VOC2012---ImageSets 16 | |-JPEGImages 17 | |- ... 18 | root should be xxx/VOCdevkit/VOC2012 19 | n_classes: number of classes, default 21. 20 | split: choose subset of dataset, 'train','val' or 'trainval'. 21 | img_size: scale image to proper size. 22 | augmentations: whether to perform augmentation. 23 | ignore_index: ingore_index will be ignored in training phase and evaluation, default 255. 24 | class_weight: useful in unbalanced datasets. 
25 | pretrained: whether to use pretrained models 26 | """ 27 | class_names = np.array([ 28 | 'background', 'aeroplane', 'bicycle', 29 | 'bird', 'boat', 'bottle', 'bus', 30 | 'car', 'cat', 'chair', 'cow', 'diningtable', 31 | 'dog', 'horse', 'motorbike', 'person', 32 | 'potted plant', 'sheep', 'sofa', 'train', 33 | 'tv/monitor', 34 | ]) 35 | 36 | def __init__( 37 | self, 38 | root, 39 | n_classes=21, 40 | split="train", 41 | img_size=None, 42 | augmentations=None, 43 | ignore_index=255, 44 | class_weight=None, 45 | pretrained=False 46 | ): 47 | super(VOCLoader, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained) 48 | 49 | path = os.path.join(self.root, "ImageSets/Segmentation", split + ".txt") 50 | with open(path, "r") as f: 51 | self.file_list = [file_name.rstrip() for file_name in f] 52 | 53 | print(f"Found {len(self.file_list)} {split} images") 54 | 55 | def __len__(self): 56 | return len(self.file_list) 57 | 58 | def __getitem__(self, index): 59 | img_name = self.file_list[index] 60 | img_path = os.path.join(self.root, "JPEGImages", img_name + ".jpg") 61 | lbl_path = os.path.join(self.root, "SegmentationClass", img_name + ".png") 62 | img = Image.open(img_path).convert('RGB') 63 | lbl = Image.open(lbl_path) 64 | if self.img_size: 65 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR) 66 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST) 67 | if self.augmentations: 68 | img, lbl = self.augmentations(img, lbl) 69 | 70 | img, lbl = self.transform(img, lbl) 71 | return img, lbl 72 | 73 | def getpalette(self): 74 | n = self.n_classes 75 | palette = [0]*(n*3) 76 | for j in range(0, n): 77 | lab = j 78 | palette[j*3+0] = 0 79 | palette[j*3+1] = 0 80 | palette[j*3+2] = 0 81 | i = 0 82 | while (lab > 0): 83 | palette[j*3+0] |= (((lab >> 0) & 1) << (7-i)) 84 | palette[j*3+1] |= (((lab >> 1) & 1) << (7-i)) 85 | palette[j*3+2] |= (((lab >> 2) & 1) << (7-i)) 86 | i = i + 1 87 | lab >>= 3 88 | palette = np.array(palette).reshape([-1, 3]).astype(np.uint8) 89 | return palette 90 | 91 | class SBDLoader(BaseLoader): 92 | """Semantic Boundaries Dataset(SBD) dataset loader. 93 | Parameters 94 | ---------- 95 | root: path to SBD dataset. 96 | for directory: 97 | --benchmark_RELEASE--dataset---img 98 | |-cls 99 | |-train.txt 100 | |- ... 101 | root should be xxx/benchmark_RELEASE 102 | n_classes: number of classes, default 21. 103 | split: choose subset of dataset, 'train' or 'val'. 104 | img_size: scale image to proper size. 105 | augmentations: whether to perform augmentation. 106 | ignore_index: ingore_index will be ignored in training phase and evaluation, default 255. 107 | class_weight: useful in unbalanced datasets. 
108 | pretrained: whether to use pretrained models 109 | """ 110 | class_names = np.array([ 111 | 'background', 'aeroplane', 'bicycle', 112 | 'bird', 'boat', 'bottle', 'bus', 113 | 'car', 'cat', 'chair', 'cow', 'diningtable', 114 | 'dog', 'horse', 'motorbike', 'person', 115 | 'potted plant', 'sheep', 'sofa', 'train', 116 | 'tv/monitor', 117 | ]) 118 | def __init__( 119 | self, 120 | root, 121 | n_classes=21, 122 | split="train", 123 | img_size=None, 124 | augmentations=None, 125 | ignore_index=255, 126 | class_weight=None, 127 | pretrained=False 128 | ): 129 | super(SBDLoader, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained) 130 | 131 | path = os.path.join(self.root, 'dataset', split + ".txt") 132 | with open(path, "r") as f: 133 | self.file_list = [file_name.rstrip() for file_name in f] 134 | 135 | print(f"Found {len(self.file_list)} {split} images") 136 | 137 | def __len__(self): 138 | return len(self.file_list) 139 | 140 | def __getitem__(self, index): 141 | img_name = self.file_list[index] 142 | img_path = os.path.join(self.root, 'dataset/img', img_name + '.jpg') 143 | lbl_path = os.path.join(self.root, 'dataset/cls', img_name + '.mat') 144 | 145 | img = Image.open(img_path).convert('RGB') 146 | lbl = loadmat(lbl_path) 147 | lbl = lbl['GTcls'][0]['Segmentation'][0].astype(np.int32) 148 | lbl = Image.fromarray(lbl) 149 | if self.img_size: 150 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR) 151 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST) 152 | if self.augmentations: 153 | img, lbl = self.augmentations(img, lbl) 154 | 155 | img, lbl = self.transform(img, lbl) 156 | return img, lbl 157 | 158 | def getpalette(self): 159 | n = self.n_classes 160 | palette = [0]*(n*3) 161 | for j in range(0, n): 162 | lab = j 163 | palette[j*3+0] = 0 164 | palette[j*3+1] = 0 165 | palette[j*3+2] = 0 166 | i = 0 167 | while (lab > 0): 168 | palette[j*3+0] |= (((lab >> 0) & 1) << (7-i)) 169 | palette[j*3+1] |= (((lab >> 1) & 1) << (7-i)) 170 | palette[j*3+2] |= (((lab >> 2) & 1) << (7-i)) 171 | i = i + 1 172 | lab >>= 3 173 | palette = np.array(palette).reshape([-1, 3]).astype(np.uint8) 174 | return palette 175 | 176 | class VOC11Val(BaseLoader): 177 | """load PASCAL VOC 2012 dataset, but only use seg11valid.txt for evaluation. 178 | Parameters 179 | ---------- 180 | root: path to PASCAL VOC 2012 dataset. 181 | n_classes: number of classes, default 21. 182 | split: only 'seg11valid' is available. 183 | img_size: scale image to proper size. 184 | augmentations: whether to perform augmentation. 185 | ignore_index: ingore_index will be ignored in training phase and evaluation, default 255. 186 | class_weight: useful in unbalanced datasets. 
187 | pretrained: whether to use pretrained models 188 | """ 189 | class_names = np.array([ 190 | 'background', 'aeroplane', 'bicycle', 191 | 'bird', 'boat', 'bottle', 'bus', 192 | 'car', 'cat', 'chair', 'cow', 'diningtable', 193 | 'dog', 'horse', 'motorbike', 'person', 194 | 'potted plant', 'sheep', 'sofa', 'train', 195 | 'tv/monitor', 196 | ]) 197 | 198 | def __init__( 199 | self, 200 | root, 201 | n_classes=21, 202 | split="seg11valid", 203 | img_size=None, 204 | augmentations=None, 205 | ignore_index=255, 206 | class_weight=None, 207 | pretrained=False 208 | ): 209 | super(VOC11Val, self).__init__(root, n_classes, split, img_size, augmentations, ignore_index, class_weight, pretrained) 210 | 211 | current_path = os.path.realpath(__file__) 212 | 213 | path = os.path.join(current_path[:-13] + "seg11valid.txt") 214 | with open(path, "r") as f: 215 | self.file_list = [file_name.rstrip() for file_name in f] 216 | 217 | print(f"Found {len(self.file_list)} {split} images") 218 | 219 | def __len__(self): 220 | return len(self.file_list) 221 | 222 | def __getitem__(self, index): 223 | img_name = self.file_list[index] 224 | img_path = os.path.join(self.root, "JPEGImages", img_name + ".jpg") 225 | lbl_path = os.path.join(self.root, "SegmentationClass", img_name + ".png") 226 | img = Image.open(img_path).convert('RGB') 227 | lbl = Image.open(lbl_path) 228 | 229 | if self.img_size: 230 | img = img.resize((self.img_size[1], self.img_size[0]), Image.BILINEAR) 231 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), Image.NEAREST) 232 | if self.augmentations: 233 | img, lbl = self.augmentations(img, lbl) 234 | 235 | img, lbl = self.transform(img, lbl) 236 | return img, lbl 237 | 238 | def getpalette(self): 239 | n = self.n_classes 240 | palette = [0]*(n*3) 241 | for j in range(0, n): 242 | lab = j 243 | palette[j*3+0] = 0 244 | palette[j*3+1] = 0 245 | palette[j*3+2] = 0 246 | i = 0 247 | while (lab > 0): 248 | palette[j*3+0] |= (((lab >> 0) & 1) << (7-i)) 249 | palette[j*3+1] |= (((lab >> 1) & 1) << (7-i)) 250 | palette[j*3+2] |= (((lab >> 2) & 1) << (7-i)) 251 | i = i + 1 252 | lab >>= 3 253 | palette = np.array(palette).reshape([-1, 3]).astype(np.uint8) 254 | return palette 255 | 256 | # Test code 257 | # if __name__ == '__main__': 258 | # from torch.utils.data import DataLoader 259 | # root = r'D:\Datasets\VOCdevkit\VOC2012' 260 | # batch_size = 2 261 | # loader = VOCLoader(root=root, img_size=(500, 500)) 262 | # test_loader = DataLoader(loader, batch_size=batch_size, shuffle=True) 263 | 264 | # palette = test_loader.dataset.getpalette() 265 | # fig, axes = plt.subplots(batch_size, 2, subplot_kw={'xticks': [], 'yticks': []}) 266 | # fig.subplots_adjust(left=0.03, right=0.97, hspace=0.2, wspace=0.05) 267 | 268 | # for imgs, labels in test_loader: 269 | # imgs = imgs.numpy() 270 | # imgs = np.transpose(imgs, [0,2,3,1]) 271 | # labels = labels.numpy() 272 | 273 | # for i in range(batch_size): 274 | # axes[i][0].imshow(imgs[i]) 275 | 276 | # mask_unlabeled = labels[i] == -1 277 | # viz_unlabeled = ( 278 | # np.zeros((labels[i].shape[0], labels[i].shape[1], 3)) 279 | # ).astype(np.uint8) 280 | 281 | # lbl_viz = palette[labels[i]] 282 | # lbl_viz[labels[i] == -1] = (0, 0, 0) 283 | # lbl_viz[mask_unlabeled] = viz_unlabeled[mask_unlabeled] 284 | 285 | # axes[i][1].imshow(lbl_viz.astype(np.uint8)) 286 | # plt.show() 287 | # break 288 | 289 | 290 | -------------------------------------------------------------------------------- /Models/DeepLab_v1.py: 
-------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class DeepLabLargeFOV(nn.Module): 8 | """Adapted from official implementation: 9 | 10 | http://www.cs.jhu.edu/~alanlab/ccvl/DeepLab-LargeFOV/train.prototxt 11 | 12 | input dimension equal to 13 | n = 32 * k - 31, e.g., 321 (for k = 11) 14 | Dimension after pooling w. subsampling: 15 | (16 * k - 15); (8 * k - 7); (4 * k - 3); (2 * k - 1); (k). 16 | For k = 11, these translate to 17 | 161; 81; 41; 21; 11 18 | """ 19 | def __init__(self, n_classes): 20 | super(DeepLabLargeFOV, self).__init__() 21 | 22 | features = [] 23 | features.append(nn.Conv2d(3, 64, 3, padding=1)) 24 | features.append(nn.ReLU(inplace=True)) 25 | features.append(nn.Conv2d(64, 64, 3, padding=1)) 26 | features.append(nn.ReLU(inplace=True)) 27 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True)) 28 | 29 | features.append(nn.Conv2d(64, 128, 3, padding=1)) 30 | features.append(nn.ReLU(inplace=True)) 31 | features.append(nn.Conv2d(128, 128, 3, padding=1)) 32 | features.append(nn.ReLU(inplace=True)) 33 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True)) 34 | 35 | features.append(nn.Conv2d(128, 256, 3, padding=1)) 36 | features.append(nn.ReLU(inplace=True)) 37 | features.append(nn.Conv2d(256, 256, 3, padding=1)) 38 | features.append(nn.ReLU(inplace=True)) 39 | features.append(nn.Conv2d(256, 256, 3, padding=1)) 40 | features.append(nn.ReLU(inplace=True)) 41 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True)) 42 | 43 | features.append(nn.Conv2d(256, 512, 3, padding=1)) 44 | features.append(nn.ReLU(inplace=True)) 45 | features.append(nn.Conv2d(512, 512, 3, padding=1)) 46 | features.append(nn.ReLU(inplace=True)) 47 | features.append(nn.Conv2d(512, 512, 3, padding=1)) 48 | features.append(nn.ReLU(inplace=True)) 49 | features.append(nn.MaxPool2d(3, stride=1, padding=1, ceil_mode=True)) 50 | 51 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2)) 52 | features.append(nn.ReLU(inplace=True)) 53 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2)) 54 | features.append(nn.ReLU(inplace=True)) 55 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2)) 56 | features.append(nn.ReLU(inplace=True)) 57 | features.append(nn.MaxPool2d(3, stride=1, padding=1, ceil_mode=True)) 58 | self.features = nn.Sequential(*features) 59 | 60 | fc = [] 61 | fc.append(nn.AvgPool2d(3, stride=1, padding=1)) 62 | fc.append(nn.Conv2d(512, 1024, 3, padding=12, dilation=12)) 63 | fc.append(nn.ReLU(inplace=True)) 64 | fc.append(nn.Conv2d(1024, 1024, 1)) 65 | fc.append(nn.ReLU(inplace=True)) 66 | fc.append(nn.Dropout(p=0.5)) 67 | self.fc = nn.Sequential(*fc) 68 | 69 | self.score = nn.Conv2d(1024, n_classes, 1) 70 | 71 | self._initialize_weights() 72 | 73 | def _initialize_weights(self): 74 | 75 | vgg = torchvision.models.vgg16(pretrained=True) 76 | state_dict = vgg.features.state_dict() 77 | self.features.load_state_dict(state_dict) 78 | 79 | # for m in self.fc.modules(): 80 | # if isinstance(m, nn.Conv2d): 81 | # nn.init.kaiming_normal_(m.weight) 82 | # nn.init.constant_(m.bias, 0) 83 | 84 | nn.init.normal_(self.score.weight, std=0.01) 85 | nn.init.constant_(self.score.bias, 0) 86 | 87 | def forward(self, x): 88 | _, _, h, w = x.size() 89 | out = self.features(x) 90 | out = self.fc(out) 91 | out = self.score(out) 92 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) 93 
| return out 94 | 95 | def get_parameters(self, bias=False, score=False): 96 | if score: 97 | if bias: 98 | yield self.score.bias 99 | else: 100 | yield self.score.weight 101 | else: 102 | for module in [self.features, self.fc]: 103 | for m in module.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | if bias: 106 | yield m.bias 107 | else: 108 | yield m.weight 109 | 110 | 111 | if __name__ == "__main__": 112 | import torch 113 | import time 114 | model = DeepLabLargeFOV(21) 115 | print(f'==> Testing {model.__class__.__name__} with PyTorch') 116 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 117 | 118 | model = model.to(device) 119 | model.eval() 120 | 121 | x = torch.Tensor(1, 3, 321, 321) 122 | x = x.to(device) 123 | 124 | torch.cuda.synchronize() 125 | t_start = time.time() 126 | for i in range(10): 127 | model(x) 128 | torch.cuda.synchronize() 129 | elapsed_time = time.time() - t_start 130 | 131 | print(f'Speed: {(elapsed_time / 10) * 1000:.2f} ms') 132 | -------------------------------------------------------------------------------- /Models/DeepLab_v2.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torchvision 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DeepLabASPPVGG(nn.Module): 9 | """Adapted from official implementation: 10 | 11 | http://liangchiehchen.com/projects/DeepLabv2_vgg.html 12 | """ 13 | def __init__(self, n_classes): 14 | super(DeepLabASPPVGG, self).__init__() 15 | 16 | features = [] 17 | features.append(nn.Conv2d(3, 64, 3, padding=1)) 18 | features.append(nn.ReLU(inplace=True)) 19 | features.append(nn.Conv2d(64, 64, 3, padding=1)) 20 | features.append(nn.ReLU(inplace=True)) 21 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True)) 22 | 23 | features.append(nn.Conv2d(64, 128, 3, padding=1)) 24 | features.append(nn.ReLU(inplace=True)) 25 | features.append(nn.Conv2d(128, 128, 3, padding=1)) 26 | features.append(nn.ReLU(inplace=True)) 27 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True)) 28 | 29 | features.append(nn.Conv2d(128, 256, 3, padding=1)) 30 | features.append(nn.ReLU(inplace=True)) 31 | features.append(nn.Conv2d(256, 256, 3, padding=1)) 32 | features.append(nn.ReLU(inplace=True)) 33 | features.append(nn.Conv2d(256, 256, 3, padding=1)) 34 | features.append(nn.ReLU(inplace=True)) 35 | features.append(nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=True)) 36 | 37 | features.append(nn.Conv2d(256, 512, 3, padding=1)) 38 | features.append(nn.ReLU(inplace=True)) 39 | features.append(nn.Conv2d(512, 512, 3, padding=1)) 40 | features.append(nn.ReLU(inplace=True)) 41 | features.append(nn.Conv2d(512, 512, 3, padding=1)) 42 | features.append(nn.ReLU(inplace=True)) 43 | features.append(nn.MaxPool2d(3, stride=1, padding=1, ceil_mode=True)) 44 | 45 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2)) 46 | features.append(nn.ReLU(inplace=True)) 47 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2)) 48 | features.append(nn.ReLU(inplace=True)) 49 | features.append(nn.Conv2d(512, 512, 3, padding=2, dilation=2)) 50 | features.append(nn.ReLU(inplace=True)) 51 | features.append(nn.MaxPool2d(3, stride=1, padding=1, ceil_mode=True)) 52 | self.features = nn.Sequential(*features) 53 | 54 | # hole = 6 55 | fc1 = [] 56 | fc1.append(nn.Conv2d(512, 1024, 3, padding=6, dilation=6)) 57 | fc1.append(nn.ReLU(inplace=True)) 58 | fc1.append(nn.Dropout(p=0.5)) 59 | fc1.append(nn.Conv2d(1024, 1024, 1)) 
60 | fc1.append(nn.ReLU(inplace=True)) 61 | fc1.append(nn.Dropout(p=0.5)) 62 | self.fc1 = nn.Sequential(*fc1) 63 | self.fc1_score = nn.Conv2d(1024, n_classes, 1) 64 | 65 | # hole = 12 66 | fc2 = [] 67 | fc2.append(nn.Conv2d(512, 1024, 3, padding=12, dilation=12)) 68 | fc2.append(nn.ReLU(inplace=True)) 69 | fc2.append(nn.Dropout(p=0.5)) 70 | fc2.append(nn.Conv2d(1024, 1024, 1)) 71 | fc2.append(nn.ReLU(inplace=True)) 72 | fc2.append(nn.Dropout(p=0.5)) 73 | self.fc2 = nn.Sequential(*fc2) 74 | self.fc2_score = nn.Conv2d(1024, n_classes, 1) 75 | 76 | # hole = 18 77 | fc3 = [] 78 | fc3.append(nn.Conv2d(512, 1024, 3, padding=18, dilation=18)) 79 | fc3.append(nn.ReLU(inplace=True)) 80 | fc3.append(nn.Dropout(p=0.5)) 81 | fc3.append(nn.Conv2d(1024, 1024, 1)) 82 | fc3.append(nn.ReLU(inplace=True)) 83 | fc3.append(nn.Dropout(p=0.5)) 84 | self.fc3 = nn.Sequential(*fc3) 85 | self.fc3_score = nn.Conv2d(1024, n_classes, 1) 86 | 87 | # hole = 24 88 | fc4 = [] 89 | fc4.append(nn.Conv2d(512, 1024, 3, padding=24, dilation=24)) 90 | fc4.append(nn.ReLU(inplace=True)) 91 | fc4.append(nn.Dropout(p=0.5)) 92 | fc4.append(nn.Conv2d(1024, 1024, 1)) 93 | fc4.append(nn.ReLU(inplace=True)) 94 | fc4.append(nn.Dropout(p=0.5)) 95 | self.fc4 = nn.Sequential(*fc4) 96 | self.fc4_score = nn.Conv2d(1024, n_classes, 1) 97 | 98 | self._initialize_weights() 99 | 100 | def _initialize_weights(self): 101 | for m in [self.fc1_score, self.fc2_score, self.fc3_score, self.fc4_score]: 102 | nn.init.normal_(m.weight, std=0.01) 103 | nn.init.constant_(m.bias, 0) 104 | 105 | # for module in [self.fc1, self.fc2, self.fc3, self.fc4]: 106 | # for m in self.modules(): 107 | # if isinstance(m, nn.Conv2d): 108 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 109 | # nn.init.normal_(m.weight, std=math.sqrt(2. 
/ n)) 110 | # nn.init.kaiming_normal_(m.weight, mode='fan_in') 111 | # nn.init.constant_(m.bias, 0) 112 | 113 | vgg = torchvision.models.vgg16(pretrained=True) 114 | state_dict = vgg.features.state_dict() 115 | self.features.load_state_dict(state_dict) 116 | 117 | def forward(self, x): 118 | _, _, h, w = x.size() 119 | out = self.features(x) 120 | fuse1 = self.fc1_score(self.fc1(out)) 121 | fuse2 = self.fc2_score(self.fc2(out)) 122 | fuse3 = self.fc3_score(self.fc3(out)) 123 | fuse4 = self.fc4_score(self.fc4(out)) 124 | out = fuse1 + fuse2 + fuse3 + fuse4 125 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) 126 | return out 127 | 128 | def get_parameters(self, bias=False, score=False): 129 | if score: 130 | for m in [self.fc1_score, self.fc2_score, self.fc3_score, self.fc4_score]: 131 | if bias: 132 | yield m.bias 133 | else: 134 | yield m.weight 135 | else: 136 | for module in [self.features, self.fc1, self.fc2, self.fc3, self.fc4]: 137 | for m in module.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | if bias: 140 | yield m.bias 141 | else: 142 | yield m.weight 143 | 144 | def freeze_bn(m): 145 | classname = m.__class__.__name__ 146 | if classname.find('BatchNorm') != -1: 147 | for p in m.parameters(): 148 | p.requires_grad = False 149 | 150 | 151 | class DeepLabASPPResNet(nn.Module): 152 | def __init__(self, n_classes): 153 | super(DeepLabASPPResNet, self).__init__() 154 | self.resnet = ResNet(Bottleneck, [3, 4, 23, 3]) 155 | self.atrous_rates = [6, 12, 18, 24] 156 | self.aspp = ASPP(2048, self.atrous_rates, n_classes) 157 | self.resnet.apply(freeze_bn) 158 | 159 | def forward(self, x): 160 | _, _, h, w = x.size() 161 | x2 = F.interpolate(x, size=(int(h * 0.75) + 1, int(w * 0.75) + 1), mode='bilinear', align_corners=True) 162 | x3 = F.interpolate(x, size=(int(h * 0.5) + 1, int(w * 0.5) + 1), mode='bilinear', align_corners=True) 163 | x = self.aspp(self.resnet(x)) 164 | x2 = self.aspp(self.resnet(x2)) 165 | x3 = self.aspp(self.resnet(x3)) 166 | 167 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) 168 | 169 | x2 = F.interpolate(x2, size=(h, w), mode='bilinear', align_corners=True) 170 | 171 | x3 = F.interpolate(x3, size=(h, w), mode='bilinear', align_corners=True) 172 | 173 | x4 = torch.max(torch.max(x, x2), x3) 174 | return x, x2, x3, x4 175 | 176 | def get_parameters(self, bias=False, score=False): 177 | if score: 178 | for m in self.aspp.modules(): 179 | if isinstance(m, nn.Conv2d): 180 | if bias: 181 | yield m.bias 182 | else: 183 | yield m.weight 184 | else: 185 | for m in self.resnet.modules(): 186 | for p in m.parameters(): 187 | if p.requires_grad: 188 | yield p 189 | 190 | 191 | class ASPP(nn.Module): 192 | def __init__(self, in_channels, atrous_rates, n_classes): 193 | super(ASPP, self).__init__() 194 | 195 | rate1, rate2, rate3, rate4 = atrous_rates 196 | self.conv1 = nn.Conv2d(2048, n_classes, kernel_size=3, padding=rate1, dilation=rate1, bias=True) 197 | self.conv2 = nn.Conv2d(2048, n_classes, kernel_size=3, padding=rate2, dilation=rate2, bias=True) 198 | self.conv3 = nn.Conv2d(2048, n_classes, kernel_size=3, padding=rate3, dilation=rate3, bias=True) 199 | self.conv4 = nn.Conv2d(2048, n_classes, kernel_size=3, padding=rate4, dilation=rate4, bias=True) 200 | 201 | self._initialize_weights() 202 | 203 | def _initialize_weights(self): 204 | for m in self.modules(): 205 | if isinstance(m, nn.Conv2d): 206 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 207 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 208 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 209 | nn.init.constant_(m.bias, 0) 210 | 211 | def forward(self, x): 212 | features1 = self.conv1(x) 213 | features2 = self.conv2(x) 214 | features3 = self.conv3(x) 215 | features4 = self.conv4(x) 216 | out = features1 + features2 + features3 + features4 217 | 218 | return out 219 | 220 | def conv1x1(in_planes, out_planes, stride=1): 221 | """1x1 convolution""" 222 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 223 | 224 | class Bottleneck(nn.Module): 225 | expansion = 4 226 | 227 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 228 | super(Bottleneck, self).__init__() 229 | self.conv1 = conv1x1(inplanes, planes) 230 | self.bn1 = nn.BatchNorm2d(planes) 231 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 232 | padding=dilation, dilation=dilation, bias=False) 233 | self.bn2 = nn.BatchNorm2d(planes) 234 | self.conv3 = conv1x1(planes, planes * self.expansion) 235 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 236 | self.relu = nn.ReLU(inplace=True) 237 | self.downsample = downsample 238 | self.stride = stride 239 | 240 | def forward(self, x): 241 | identity = x 242 | 243 | out = self.conv1(x) 244 | out = self.bn1(out) 245 | out = self.relu(out) 246 | 247 | out = self.conv2(out) 248 | out = self.bn2(out) 249 | out = self.relu(out) 250 | 251 | out = self.conv3(out) 252 | out = self.bn3(out) 253 | 254 | if self.downsample is not None: 255 | identity = self.downsample(x) 256 | 257 | out += identity 258 | out = self.relu(out) 259 | 260 | return out 261 | 262 | 263 | class ResNet(nn.Module): 264 | """ 265 | Adapted from https://github.com/speedinghzl/pytorch-segmentation-toolbox/blob/master/networks/deeplabv3.py 266 | """ 267 | def __init__(self, block, layers): 268 | super(ResNet, self).__init__() 269 | self.inplanes = 64 270 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 271 | bias=False) 272 | self.bn1 = nn.BatchNorm2d(64) 273 | 274 | self.relu = nn.ReLU(inplace=True) 275 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 276 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 277 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 278 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 279 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 280 | 281 | self._initialize_weights() 282 | 283 | def _initialize_weights(self): 284 | # for m in self.modules(): 285 | # if isinstance(m, nn.Conv2d): 286 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 287 | # m.weight.data.normal_(0, math.sqrt(2.
/ n)) 288 | # nn.init.kaiming_normal_(m.weight, mode='fan_out') 289 | # elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.GroupNorm): 290 | # m.weight.data.fill_(1) 291 | # m.bias.data.zero_() 292 | 293 | resnet = torchvision.models.resnet101(pretrained=True) 294 | self.conv1.load_state_dict(resnet.conv1.state_dict()) 295 | self.bn1.load_state_dict(resnet.bn1.state_dict()) 296 | self.layer1.load_state_dict(resnet.layer1.state_dict()) 297 | self.layer2.load_state_dict(resnet.layer2.state_dict()) 298 | self.layer3.load_state_dict(resnet.layer3.state_dict()) 299 | self.layer4.load_state_dict(resnet.layer4.state_dict()) 300 | 301 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 302 | downsample = None 303 | if stride != 1 or self.inplanes != planes * block.expansion: 304 | downsample = nn.Sequential( 305 | nn.Conv2d(self.inplanes, planes * block.expansion, 306 | kernel_size=1, stride=stride, bias=False), 307 | nn.BatchNorm2d(planes * block.expansion)) 308 | 309 | layers = [] 310 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample)) 311 | self.inplanes = planes * block.expansion 312 | for i in range(1, blocks): 313 | layers.append(block(self.inplanes, planes, dilation=dilation)) 314 | 315 | return nn.Sequential(*layers) 316 | 317 | def forward(self, x): 318 | x = self.conv1(x) 319 | x = self.bn1(x) 320 | x = self.relu(x) 321 | x = self.maxpool(x) 322 | 323 | x = self.layer1(x) 324 | x = self.layer2(x) 325 | x = self.layer3(x) 326 | x = self.layer4(x) 327 | 328 | return x 329 | -------------------------------------------------------------------------------- /Models/DeepLab_v3.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class DeepLabV3(nn.Module): 8 | def __init__(self, n_classes): 9 | super(DeepLabV3, self).__init__() 10 | self.n_classes = n_classes 11 | self.resnet = ResNet(Bottleneck, [3, 4, 6, 3]) 12 | # self.atrous_rates = [6, 12, 18] # output_stride = 16 13 | self.atrous_rates = [12, 24, 36] # output_stride = 8 14 | self.aspp = ASPP(2048, self.atrous_rates) 15 | 16 | self.final = nn.Conv2d(256, n_classes, 1) 17 | nn.init.normal_(self.final.weight, 0.01) 18 | nn.init.constant_(self.final.bias, 0) 19 | 20 | def forward(self, x): 21 | _, _, h, w = x.size() 22 | out = self.resnet(x) 23 | out = self.aspp(out) 24 | out = self.final(out) 25 | out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True) 26 | return out 27 | 28 | 29 | class ASPP(nn.Module): 30 | def __init__(self, in_channels, atrous_rates): 31 | super(ASPP, self).__init__() 32 | out_channels = 256 33 | 34 | self.imagepool = nn.Sequential( 35 | nn.AdaptiveAvgPool2d(1), 36 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 37 | kernel_size=1, bias=False), 38 | nn.BatchNorm2d(num_features=out_channels), 39 | nn.ReLU(inplace=True) 40 | ) 41 | 42 | self.conv1x1 = nn.Sequential( 43 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 44 | kernel_size=1, bias=False), 45 | nn.BatchNorm2d(num_features=out_channels), 46 | nn.ReLU(inplace=True) 47 | ) 48 | 49 | rate1, rate2, rate3 = tuple(atrous_rates) 50 | self.conv1 = self._ASPPConv(in_channels, out_channels, rate1) 51 | self.conv2 = self._ASPPConv(in_channels, out_channels, rate2) 52 | self.conv3 = self._ASPPConv(in_channels, out_channels, rate3) 53 | 54 | self.project = nn.Sequential( 55 | nn.Conv2d(in_channels=5*out_channels, 
out_channels=out_channels, 56 | kernel_size=1, bias=False), 57 | nn.BatchNorm2d(num_features=out_channels), 58 | nn.ReLU(inplace=True), 59 | nn.Dropout(p=0.1) 60 | ) 61 | 62 | self._initialize_weights() 63 | 64 | def _initialize_weights(self): 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | nn.init.kaiming_normal_(m.weight) 68 | if m.bias is not None: 69 | nn.init.constant_(m.bias, 0) 70 | 71 | 72 | def _ASPPConv(self, in_channels, out_channels, atrous_rate): 73 | block = nn.Sequential( 74 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 75 | kernel_size=3, padding=atrous_rate, 76 | dilation=atrous_rate, bias=False), 77 | nn.BatchNorm2d(num_features=out_channels), 78 | nn.ReLU(inplace=True) 79 | ) 80 | return block 81 | 82 | def forward(self, x): 83 | _, _, h, w = x.size() 84 | 85 | features1 = F.interpolate(self.imagepool(x), size=(h, w), mode='bilinear', align_corners=True) 86 | 87 | features2 = self.conv1x1(x) 88 | features3 = self.conv1(x) 89 | features4 = self.conv2(x) 90 | features5 = self.conv3(x) 91 | out = torch.cat((features1, features2, features3, features4, features5), 1) 92 | 93 | out = self.project(out) 94 | return out 95 | 96 | def conv1x1(in_planes, out_planes, stride=1): 97 | """1x1 convolution""" 98 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 99 | 100 | class Bottleneck(nn.Module): 101 | expansion = 4 102 | 103 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, multi_grid=1): 104 | super(Bottleneck, self).__init__() 105 | self.conv1 = conv1x1(inplanes, planes) 106 | self.bn1 = nn.BatchNorm2d(planes) 107 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 108 | padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False) 109 | self.bn2 = nn.BatchNorm2d(planes) 110 | self.conv3 = conv1x1(planes, planes * self.expansion) 111 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.downsample = downsample 114 | self.stride = stride 115 | 116 | def forward(self, x): 117 | identity = x 118 | 119 | out = self.conv1(x) 120 | out = self.bn1(out) 121 | out = self.relu(out) 122 | 123 | out = self.conv2(out) 124 | out = self.bn2(out) 125 | out = self.relu(out) 126 | 127 | out = self.conv3(out) 128 | out = self.bn3(out) 129 | 130 | if self.downsample is not None: 131 | identity = self.downsample(x) 132 | 133 | out += identity 134 | out = self.relu(out) 135 | 136 | return out 137 | 138 | 139 | class ResNet(nn.Module): 140 | """ 141 | Adapted from https://github.com/speedinghzl/pytorch-segmentation-toolbox/blob/master/networks/deeplabv3.py 142 | """ 143 | def __init__(self, block, layers): 144 | super(ResNet, self).__init__() 145 | self.inplanes = 64 146 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 147 | bias=False) 148 | self.bn1 = nn.BatchNorm2d(64) 149 | self.relu = nn.ReLU(inplace=True) 150 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 151 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 152 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 153 | 154 | # for output_stride = 16 155 | # self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 156 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 2, 4)) 157 | 158 | # for output_stride = 8 159 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 160 | self.layer4 = self._make_layer(block, 512, layers[3], 
stride=1, dilation=4, multi_grid=(1, 2, 4)) 161 | 162 | self._initialize_weights() 163 | 164 | def _initialize_weights(self): 165 | resnet = torchvision.models.resnet50(pretrained=True) 166 | self.conv1.load_state_dict(resnet.conv1.state_dict()) 167 | self.bn1.load_state_dict(resnet.bn1.state_dict()) 168 | self.layer1.load_state_dict(resnet.layer1.state_dict()) 169 | self.layer2.load_state_dict(resnet.layer2.state_dict()) 170 | self.layer3.load_state_dict(resnet.layer3.state_dict()) 171 | self.layer4.load_state_dict(resnet.layer4.state_dict()) 172 | 173 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1): 174 | downsample = None 175 | if stride != 1 or self.inplanes != planes * block.expansion: 176 | downsample = nn.Sequential( 177 | nn.Conv2d(self.inplanes, planes * block.expansion, 178 | kernel_size=1, stride=stride, bias=False), 179 | nn.BatchNorm2d(planes * block.expansion)) 180 | 181 | layers = [] 182 | generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1 183 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid))) 184 | self.inplanes = planes * block.expansion 185 | for i in range(1, blocks): 186 | layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid))) 187 | 188 | return nn.Sequential(*layers) 189 | 190 | def forward(self, x): 191 | x = self.conv1(x) 192 | x = self.bn1(x) 193 | x = self.relu(x) 194 | x = self.maxpool(x) 195 | 196 | x = self.layer1(x) 197 | x = self.layer2(x) 198 | x = self.layer3(x) 199 | x = self.layer4(x) 200 | 201 | return x -------------------------------------------------------------------------------- /Models/DeepLab_v3plus.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class DeepLabV3Plus(nn.Module): 8 | def __init__(self, n_classes): 9 | super(DeepLabV3Plus, self).__init__() 10 | 11 | self.resnet = ResNet(Bottleneck, [3, 4, 6, 3]) 12 | self.head = _DeepLabHead() 13 | self.decoder1 = nn.Conv2d(64, 48, 1) 14 | self.decoder2 = nn.Sequential( 15 | nn.Conv2d(304, 256, 3), 16 | nn.BatchNorm2d(256), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(256, 256, 3), 19 | nn.BatchNorm2d(256), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(256, n_classes, 1) 22 | ) 23 | 24 | def forward(self, x): 25 | _, _, h, w = x.size() 26 | out, branch = self.resnet(x) 27 | _, _, uh, uw = branch.size() 28 | out = self.head(out) 29 | out = F.interpolate(out, size=(uh, uw), mode='bilinear', align_corners=True) 30 | branch = self.decoder1(branch) 31 | out = torch.cat([out, branch], 1) 32 | out = self.decoder2(out) 33 | out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True) 34 | return out 35 | 36 | class _DeepLabHead(nn.Module): 37 | def __init__(self): 38 | super(_DeepLabHead, self).__init__() 39 | self.aspp = ASPP(2048, [6, 12, 18]) # output_stride = 16 40 | # self.aspp = ASPP(2048, [12, 24, 36]) # output_stride = 8 41 | # self.block = nn.Sequential( 42 | # nn.Conv2d(in_channels=256, out_channels=256, 43 | # kernel_size=3, padding=1, bias=False), 44 | # nn.BatchNorm2d(num_features=256), 45 | # nn.ReLU(inplace=True), 46 | # nn.Dropout(0.1), 47 | # nn.Conv2d(in_channels=256, out_channels=n_classes, 48 | # kernel_size=1) 49 | # ) 50 | 51 | def forward(self, x): 52 | out = self.aspp(x) 53 | # out = self.block(out) 
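        # Editor's note (added for clarity, not in the original file): with the
        # optional refinement block above commented out, the head returns the
        # 256-channel ASPP projection at output_stride 16. DeepLabV3Plus.forward
        # then upsamples it to the low-level branch resolution, reduces that
        # branch to 48 channels with decoder1, concatenates (256 + 48 = 304
        # channels) and feeds decoder2 to produce the class scores.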
54 | return out 55 | 56 | 57 | class ASPP(nn.Module): 58 | def __init__(self, in_channels, atrous_rates): 59 | super(ASPP, self).__init__() 60 | out_channels = 256 61 | 62 | self.imagepool = nn.Sequential( 63 | nn.AdaptiveAvgPool2d(1), 64 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 65 | kernel_size=1, bias=False), 66 | nn.BatchNorm2d(num_features=out_channels), 67 | nn.ReLU(inplace=True) 68 | ) 69 | 70 | self.conv1x1 = nn.Sequential( 71 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 72 | kernel_size=1, bias=False), 73 | nn.BatchNorm2d(num_features=out_channels), 74 | nn.ReLU(inplace=True) 75 | ) 76 | 77 | rate1, rate2, rate3 = tuple(atrous_rates) 78 | self.conv1 = self._ASPPConv(in_channels, out_channels, rate1) 79 | self.conv2 = self._ASPPConv(in_channels, out_channels, rate2) 80 | self.conv3 = self._ASPPConv(in_channels, out_channels, rate3) 81 | 82 | self.project = nn.Sequential( 83 | nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels, 84 | kernel_size=1, bias=False), 85 | nn.BatchNorm2d(num_features=out_channels), 86 | nn.ReLU(inplace=True), 87 | # nn.Dropout(p=0.5) 88 | ) 89 | 90 | def forward(self, x): 91 | _, _, h, w = x.size() 92 | 93 | features1 = F.interpolate(self.imagepool(x), size=(h, w), mode='bilinear', align_corners=True) 94 | 95 | features2 = self.conv1x1(x) 96 | features3 = self.conv1(x) 97 | features4 = self.conv2(x) 98 | features5 = self.conv3(x) 99 | out = torch.cat((features1, features2, features3, features4, features5), 1) 100 | out = self.project(out) 101 | return out 102 | 103 | def _ASPPConv(self, in_channels, out_channels, atrous_rate): 104 | block = nn.Sequential( 105 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 106 | kernel_size=3, padding=atrous_rate, 107 | dilation=atrous_rate, bias=False), 108 | nn.BatchNorm2d(num_features=out_channels), 109 | nn.ReLU(inplace=True) 110 | ) 111 | return block 112 | 113 | 114 | def conv1x1(in_planes, out_planes, stride=1): 115 | """1x1 convolution""" 116 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 117 | 118 | class Bottleneck(nn.Module): 119 | expansion = 4 120 | 121 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, multi_grid=1): 122 | super(Bottleneck, self).__init__() 123 | self.conv1 = conv1x1(inplanes, planes) 124 | self.bn1 = nn.BatchNorm2d(planes) 125 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 126 | padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False) 127 | self.bn2 = nn.BatchNorm2d(planes) 128 | self.conv3 = conv1x1(planes, planes * self.expansion) 129 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 130 | self.relu = nn.ReLU(inplace=True) 131 | self.downsample = downsample 132 | self.stride = stride 133 | 134 | def forward(self, x): 135 | identity = x 136 | 137 | out = self.conv1(x) 138 | out = self.bn1(out) 139 | out = self.relu(out) 140 | 141 | out = self.conv2(out) 142 | out = self.bn2(out) 143 | out = self.relu(out) 144 | 145 | out = self.conv3(out) 146 | out = self.bn3(out) 147 | 148 | if self.downsample is not None: 149 | identity = self.downsample(x) 150 | 151 | out += identity 152 | out = self.relu(out) 153 | 154 | return out 155 | 156 | 157 | class ResNet(nn.Module): 158 | """ 159 | Adapted from https://github.com/speedinghzl/pytorch-segmentation-toolbox/blob/master/networks/deeplabv3.py 160 | """ 161 | def __init__(self, block, layers): 162 | super(ResNet, self).__init__() 163 | self.inplanes = 64 164 | self.conv1 = 
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 165 | bias=False) 166 | self.bn1 = nn.BatchNorm2d(64) 167 | self.relu = nn.ReLU(inplace=True) 168 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 169 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 170 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 171 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1) 172 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=2, multi_grid=(1, 2, 4)) 173 | 174 | # for output_stride = 8 175 | # self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 176 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1, 2, 4)) 177 | 178 | self._initialize_weights() 179 | 180 | def _initialize_weights(self): 181 | resnet = torchvision.models.resnet50(pretrained=True) 182 | self.conv1.load_state_dict(resnet.conv1.state_dict()) 183 | self.bn1.load_state_dict(resnet.bn1.state_dict()) 184 | self.layer1.load_state_dict(resnet.layer1.state_dict()) 185 | self.layer2.load_state_dict(resnet.layer2.state_dict()) 186 | self.layer3.load_state_dict(resnet.layer3.state_dict()) 187 | self.layer4.load_state_dict(resnet.layer4.state_dict()) 188 | 189 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1): 190 | downsample = None 191 | if stride != 1 or self.inplanes != planes * block.expansion: 192 | downsample = nn.Sequential( 193 | nn.Conv2d(self.inplanes, planes * block.expansion, 194 | kernel_size=1, stride=stride, bias=False), 195 | nn.BatchNorm2d(planes * block.expansion)) 196 | 197 | layers = [] 198 | generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1 199 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid))) 200 | self.inplanes = planes * block.expansion 201 | for i in range(1, blocks): 202 | layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid))) 203 | 204 | return nn.Sequential(*layers) 205 | 206 | def forward(self, x): 207 | out = self.conv1(x) 208 | out = self.bn1(out) 209 | out = self.relu(out) 210 | branch = self.maxpool(out) 211 | 212 | out = self.layer1(branch) 213 | out = self.layer2(out) 214 | out = self.layer3(out) 215 | out = self.layer4(out) 216 | 217 | return out, branch -------------------------------------------------------------------------------- /Models/Dilation8.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | 6 | class Dilation8(nn.Module): 7 | """Adapted from official dilated8 implementation: 8 | 9 | https://github.com/fyu/dilation/blob/master/models/dilation8_pascal_voc_deploy.prototxt 10 | """ 11 | def __init__(self, n_classes): 12 | super(Dilation8, self).__init__() 13 | features1 = [] 14 | # conv1 15 | features1.append(nn.Conv2d(3, 64, 3)) 16 | features1.append(nn.ReLU(inplace=True)) 17 | features1.append(nn.Conv2d(64, 64, 3)) 18 | features1.append(nn.ReLU(inplace=True)) 19 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/2 20 | 21 | # conv2 22 | features1.append(nn.Conv2d(64, 128, 3)) 23 | features1.append(nn.ReLU(inplace=True)) 24 | features1.append(nn.Conv2d(128, 128, 3)) 25 | features1.append(nn.ReLU(inplace=True)) 26 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) 
# 1/4 27 | 28 | # conv3 29 | features1.append(nn.Conv2d(128, 256, 3)) 30 | features1.append(nn.ReLU(inplace=True)) 31 | features1.append(nn.Conv2d(256, 256, 3)) 32 | features1.append(nn.ReLU(inplace=True)) 33 | features1.append(nn.Conv2d(256, 256, 3)) 34 | features1.append(nn.ReLU(inplace=True)) 35 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/8 36 | 37 | # conv4 38 | features1.append(nn.Conv2d(256, 512, 3)) 39 | features1.append(nn.ReLU(inplace=True)) 40 | features1.append(nn.Conv2d(512, 512, 3)) 41 | features1.append(nn.ReLU(inplace=True)) 42 | features1.append(nn.Conv2d(512, 512, 3)) 43 | features1.append(nn.ReLU(inplace=True)) 44 | self.features1 = nn.Sequential(*features1) 45 | 46 | # conv5 47 | features2 = [] 48 | features2.append(nn.Conv2d(512, 512, 3, dilation=2)) 49 | features2.append(nn.ReLU(inplace=True)) 50 | features2.append(nn.Conv2d(512, 512, 3, dilation=2)) 51 | features2.append(nn.ReLU(inplace=True)) 52 | features2.append(nn.Conv2d(512, 512, 3, dilation=2)) 53 | features2.append(nn.ReLU(inplace=True)) 54 | self.features2 = nn.Sequential(*features2) 55 | 56 | fc = [] 57 | fc.append(nn.Conv2d(512, 4096, 7, dilation=4)) 58 | fc.append(nn.ReLU(inplace=True)) 59 | fc.append(nn.Dropout(p=0.5)) 60 | fc.append(nn.Conv2d(4096, 4096, 1)) 61 | fc.append(nn.ReLU(inplace=True)) 62 | fc.append(nn.Dropout(p=0.5)) 63 | fc.append(nn.Conv2d(4096, n_classes, 1)) 64 | self.fc = nn.Sequential(*fc) 65 | 66 | context = [] 67 | context.append(nn.Conv2d(n_classes, 2 * n_classes, 3, padding=33)) 68 | context.append(nn.ReLU(inplace=True)) 69 | context.append(nn.Conv2d(2 * n_classes, 2 * n_classes, 3, padding=0)) 70 | context.append(nn.ReLU(inplace=True)) 71 | context.append(nn.Conv2d(2 * n_classes, 4 * n_classes, 3, dilation=2)) 72 | context.append(nn.ReLU(inplace=True)) 73 | context.append(nn.Conv2d(4 * n_classes, 8 * n_classes, 3, dilation=4)) 74 | context.append(nn.ReLU(inplace=True)) 75 | context.append(nn.Conv2d(8 * n_classes, 16 * n_classes, 3, dilation=8)) 76 | context.append(nn.ReLU(inplace=True)) 77 | context.append(nn.Conv2d(16 * n_classes, 32 * n_classes, 3, dilation=16)) 78 | context.append(nn.ReLU(inplace=True)) 79 | context.append(nn.Conv2d(32 * n_classes, 32 * n_classes, 3)) 80 | context.append(nn.ReLU(inplace=True)) 81 | context.append(nn.Conv2d(32 * n_classes, n_classes, 1)) 82 | context.append(nn.ReLU(inplace=True)) 83 | self.context = nn.Sequential(*context) 84 | 85 | self._initialize_weights() 86 | 87 | def _initialize_weights(self): 88 | vgg16 = torchvision.models.vgg16(pretrained=True) 89 | vgg_features1 = vgg16.features[0:23] 90 | self.features1.load_state_dict(vgg_features1.state_dict()) 91 | 92 | vgg_features2 = vgg16.features[24:30] 93 | for l1, l2 in zip(vgg_features2.children(), self.features2.children()): 94 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 95 | assert l1.weight.size() == l2.weight.size() 96 | assert l1.bias.size() == l2.bias.size() 97 | l2.weight.data = l1.weight.data 98 | l2.bias.data = l1.bias.data 99 | 100 | fc = self.fc[0:4] 101 | for l1, l2 in zip(vgg16.classifier.children(), fc.children()): 102 | if isinstance(l1, nn.Linear) and isinstance(l2, nn.Conv2d): 103 | l2.weight.data = l1.weight.data.view(l2.weight.size()) 104 | l2.bias.data = l1.bias.data.view(l2.bias.size()) 105 | 106 | for m in self.context.modules(): 107 | if isinstance(m, nn.Conv2d): 108 | nn.init.normal_(m.weight, std=0.001) 109 | nn.init.constant_(m.bias, 0) 110 | 111 | def forward(self, x): 112 | _, _, h, w = x.size() 113 | out = self.features1(x) 
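        # Editor's note (added for clarity, not in the original file): features1
        # is the unpadded VGG-16 front end through conv4 (three 2x poolings, so
        # 1/8 resolution), features2 is the dilation-2 conv5 block, and fc is
        # the dilated fc6/fc7 classifier. The context module that follows keeps
        # its spatial size unchanged: its first 3x3 layer pads by 33, exactly
        # offsetting the shrinkage from its later unpadded and dilated 3x3
        # convolutions, before the final bilinear upsample back to input size.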
114 | out = self.features2(out) 115 | out = self.fc(out) 116 | out = self.context(out) 117 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) 118 | return out -------------------------------------------------------------------------------- /Models/FCN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | 6 | # https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py 7 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 8 | """Make a 2D bilinear kernel suitable for upsampling""" 9 | factor = (kernel_size + 1) // 2 10 | if kernel_size % 2 == 1: 11 | center = factor - 1 12 | else: 13 | center = factor - 0.5 14 | og = np.ogrid[:kernel_size, :kernel_size] 15 | filt = (1 - abs(og[0] - center) / factor) * \ 16 | (1 - abs(og[1] - center) / factor) 17 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 18 | dtype=np.float64) 19 | weight[range(in_channels), range(out_channels), :, :] = filt 20 | return torch.from_numpy(weight).float() 21 | 22 | 23 | class FCN32s(nn.Module): 24 | """Adapted from official implementation: 25 | 26 | https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/voc-fcn32s/train.prototxt 27 | """ 28 | def __init__(self, n_classes): 29 | super(FCN32s, self).__init__() 30 | 31 | features = [] 32 | # conv1 33 | features.append(nn.Conv2d(3, 64, 3, padding=100)) 34 | features.append(nn.ReLU(inplace=True)) 35 | features.append(nn.Conv2d(64, 64, 3, padding=1)) 36 | features.append(nn.ReLU(inplace=True)) 37 | features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/2 38 | 39 | # conv2 40 | features.append(nn.Conv2d(64, 128, 3, padding=1)) 41 | features.append(nn.ReLU(inplace=True)) 42 | features.append(nn.Conv2d(128, 128, 3, padding=1)) 43 | features.append(nn.ReLU(inplace=True)) 44 | features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/4 45 | 46 | # conv3 47 | features.append(nn.Conv2d(128, 256, 3, padding=1)) 48 | features.append(nn.ReLU(inplace=True)) 49 | features.append(nn.Conv2d(256, 256, 3, padding=1)) 50 | features.append(nn.ReLU(inplace=True)) 51 | features.append(nn.Conv2d(256, 256, 3, padding=1)) 52 | features.append(nn.ReLU(inplace=True)) 53 | features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/8 54 | 55 | # conv4 56 | features.append(nn.Conv2d(256, 512, 3, padding=1)) 57 | features.append(nn.ReLU(inplace=True)) 58 | features.append(nn.Conv2d(512, 512, 3, padding=1)) 59 | features.append(nn.ReLU(inplace=True)) 60 | features.append(nn.Conv2d(512, 512, 3, padding=1)) 61 | features.append(nn.ReLU(inplace=True)) 62 | features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/16 63 | 64 | # conv5 65 | features.append(nn.Conv2d(512, 512, 3, padding=1)) 66 | features.append(nn.ReLU(inplace=True)) 67 | features.append(nn.Conv2d(512, 512, 3, padding=1)) 68 | features.append(nn.ReLU(inplace=True)) 69 | features.append(nn.Conv2d(512, 512, 3, padding=1)) 70 | features.append(nn.ReLU(inplace=True)) 71 | features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/32 72 | 73 | self.features = nn.Sequential(*features) 74 | 75 | fc = [] 76 | fc.append(nn.Conv2d(512, 4096, 7)) 77 | fc.append(nn.ReLU(inplace=True)) 78 | fc.append(nn.Dropout(p=0.5)) 79 | fc.append(nn.Conv2d(4096, 4096, 1)) 80 | fc.append(nn.ReLU(inplace=True)) 81 | fc.append(nn.Dropout(p=0.5)) 82 | self.fc = nn.Sequential(*fc) 83 | 84 | self.score_fr = nn.Conv2d(4096, n_classes, 1) 85 | 
self.upscore = nn.ConvTranspose2d(n_classes, n_classes, 64, stride=32, 86 | bias=False) 87 | 88 | self._initialize_weights() 89 | 90 | def _initialize_weights(self): 91 | self.score_fr.weight.data.zero_() 92 | self.score_fr.bias.data.zero_() 93 | 94 | assert self.upscore.kernel_size[0] == self.upscore.kernel_size[1] 95 | initial_weight = get_upsampling_weight( 96 | self.upscore.in_channels, self.upscore.out_channels, 97 | self.upscore.kernel_size[0]) 98 | self.upscore.weight.data.copy_(initial_weight) 99 | 100 | 101 | vgg16 = torchvision.models.vgg16(pretrained=True) 102 | state_dict = vgg16.features.state_dict() 103 | self.features.load_state_dict(state_dict) 104 | 105 | for l1, l2 in zip(vgg16.classifier.children(), self.fc): 106 | if isinstance(l1, nn.Linear) and isinstance(l2, nn.Conv2d): 107 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size())) 108 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size())) 109 | 110 | def forward(self, x): 111 | out = self.features(x) 112 | out = self.fc(out) 113 | out = self.score_fr(out) 114 | out = self.upscore(out) 115 | out = out[:, :, 19:19 + x.size()[2], 19:19 + x.size()[3]].contiguous() 116 | 117 | return out 118 | 119 | def get_parameters(self, bias=False): 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | if bias: 123 | yield m.bias 124 | else: 125 | yield m.weight 126 | 127 | 128 | class FCN8sAtOnce(nn.Module): 129 | def __init__(self, n_classes): 130 | super(FCN8sAtOnce, self).__init__() 131 | 132 | features1 = [] 133 | # conv1 134 | features1.append(nn.Conv2d(3, 64, 3, padding=100)) 135 | features1.append(nn.ReLU(inplace=True)) 136 | features1.append(nn.Conv2d(64, 64, 3, padding=1)) 137 | features1.append(nn.ReLU(inplace=True)) 138 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/2 139 | 140 | # conv2 141 | features1.append(nn.Conv2d(64, 128, 3, padding=1)) 142 | features1.append(nn.ReLU(inplace=True)) 143 | features1.append(nn.Conv2d(128, 128, 3, padding=1)) 144 | features1.append(nn.ReLU(inplace=True)) 145 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/4 146 | 147 | # conv3 148 | features1.append(nn.Conv2d(128, 256, 3, padding=1)) 149 | features1.append(nn.ReLU(inplace=True)) 150 | features1.append(nn.Conv2d(256, 256, 3, padding=1)) 151 | features1.append(nn.ReLU(inplace=True)) 152 | features1.append(nn.Conv2d(256, 256, 3, padding=1)) 153 | features1.append(nn.ReLU(inplace=True)) 154 | features1.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/8 155 | self.features1 = nn.Sequential(*features1) 156 | 157 | features2 = [] 158 | # conv4 159 | features2.append(nn.Conv2d(256, 512, 3, padding=1)) 160 | features2.append(nn.ReLU(inplace=True)) 161 | features2.append(nn.Conv2d(512, 512, 3, padding=1)) 162 | features2.append(nn.ReLU(inplace=True)) 163 | features2.append(nn.Conv2d(512, 512, 3, padding=1)) 164 | features2.append(nn.ReLU(inplace=True)) 165 | features2.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/16 166 | self.features2 = nn.Sequential(*features2) 167 | 168 | features3 = [] 169 | # conv5 170 | features3.append(nn.Conv2d(512, 512, 3, padding=1)) 171 | features3.append(nn.ReLU(inplace=True)) 172 | features3.append(nn.Conv2d(512, 512, 3, padding=1)) 173 | features3.append(nn.ReLU(inplace=True)) 174 | features3.append(nn.Conv2d(512, 512, 3, padding=1)) 175 | features3.append(nn.ReLU(inplace=True)) 176 | features3.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/32 177 | self.features3 = nn.Sequential(*features3) 178 | 179 | fc = [] 180 | # fc6 181 | 
fc.append(nn.Conv2d(512, 4096, 7)) 182 | fc.append(nn.ReLU(inplace=True)) 183 | fc.append(nn.Dropout2d()) 184 | 185 | # fc7 186 | fc.append(nn.Conv2d(4096, 4096, 1)) 187 | fc.append(nn.ReLU(inplace=True)) 188 | fc.append(nn.Dropout2d()) 189 | self.fc = nn.Sequential(*fc) 190 | 191 | self.score_fr = nn.Conv2d(4096, n_classes, 1) 192 | self.score_pool3 = nn.Conv2d(256, n_classes, 1) 193 | self.score_pool4 = nn.Conv2d(512, n_classes, 1) 194 | 195 | self.upscore2 = nn.ConvTranspose2d( 196 | n_classes, n_classes, 4, stride=2, bias=False) 197 | self.upscore8 = nn.ConvTranspose2d( 198 | n_classes, n_classes, 16, stride=8, bias=False) 199 | self.upscore_pool4 = nn.ConvTranspose2d( 200 | n_classes, n_classes, 4, stride=2, bias=False) 201 | 202 | self._initialize_weights() 203 | 204 | def _initialize_weights(self): 205 | for m in [self.score_fr, self.score_pool3, self.score_pool4]: 206 | m.weight.data.zero_() 207 | m.bias.data.zero_() 208 | 209 | for m in [self.upscore2, self.upscore8, self.upscore_pool4]: 210 | assert m.kernel_size[0] == m.kernel_size[1] 211 | initial_weight = get_upsampling_weight( 212 | m.in_channels, m.out_channels, m.kernel_size[0]) 213 | m.weight.data.copy_(initial_weight) 214 | 215 | vgg16 = torchvision.models.vgg16(pretrained=True) 216 | vgg_features = [ 217 | vgg16.features[:17], 218 | vgg16.features[17:24], 219 | vgg16.features[24:], 220 | ] 221 | features = [ 222 | self.features1, 223 | self.features2, 224 | self.features3, 225 | ] 226 | 227 | for l1, l2 in zip(vgg_features, features): 228 | for ll1, ll2 in zip(l1.children(), l2.children()): 229 | if isinstance(ll1, nn.Conv2d) and isinstance(ll2, nn.Conv2d): 230 | assert ll1.weight.size() == ll2.weight.size() 231 | assert ll1.bias.size() == ll2.bias.size() 232 | ll2.weight.data.copy_(ll1.weight.data) 233 | ll2.bias.data.copy_(ll1.bias.data) 234 | 235 | for l1, l2 in zip(vgg16.classifier.children(), self.fc): 236 | if isinstance(l1, nn.Linear) and isinstance(l2, nn.Conv2d): 237 | l2.weight.data.copy_(l1.weight.data.view(l2.weight.size())) 238 | l2.bias.data.copy_(l1.bias.data.view(l2.bias.size())) 239 | 240 | def forward(self, x): 241 | pool3 = self.features1(x) # 1/8 242 | pool4 = self.features2(pool3) # 1/16 243 | pool5 = self.features3(pool4) # 1/32 244 | fc = self.fc(pool5) 245 | score_fr = self.score_fr(fc) 246 | upscore2 = self.upscore2(score_fr) # 1/16 247 | 248 | score_pool4 = self.score_pool4(pool4 * 0.01) # XXX: scaling to train at once 249 | score_pool4c = score_pool4[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 250 | upscore_pool4 = self.upscore_pool4(upscore2 + score_pool4c) # 1/8 251 | 252 | score_pool3 = self.score_pool3(pool3 * 0.0001) # XXX: scaling to train at once 253 | score_pool3c = score_pool3[:, :, 254 | 9:9 + upscore_pool4.size()[2], 255 | 9:9 + upscore_pool4.size()[3]] 256 | out = self.upscore8(upscore_pool4 + score_pool3c) 257 | 258 | out = out[:, :, 31:31 + x.size()[2], 31:31 + x.size()[3]].contiguous() 259 | 260 | return out 261 | 262 | def get_parameters(self, bias=False): 263 | for m in self.modules(): 264 | if isinstance(m, nn.Conv2d): 265 | if bias: 266 | yield m.bias 267 | else: 268 | yield m.weight 269 | 270 | 271 | if __name__ == "__main__": 272 | import torch 273 | import time 274 | model = FCN32s(21) 275 | print(f'==> Testing {model.__class__.__name__} with PyTorch') 276 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 277 | # torch.backends.cudnn.benchmark = True 278 | 279 | model = model.to(device) 280 | model.eval() 281 | 282 | x = torch.Tensor(1, 
3, 500, 500) 283 | x = x.to(device) 284 | 285 | torch.cuda.synchronize() 286 | t_start = time.time() 287 | for i in range(10): 288 | model(x) 289 | torch.cuda.synchronize() 290 | elapsed_time = time.time() - t_start 291 | 292 | print(f'Speed: {(elapsed_time / 10) * 1000:.2f} ms') -------------------------------------------------------------------------------- /Models/PSPNet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # def freeze_bn(self): 7 | # for m in self.modules(): 8 | # if isinstance(m, nn.BatchNorm2d): 9 | # m.eval() 10 | 11 | class PSPNet(nn.Module): 12 | """set crop size to 480 13 | """ 14 | def __init__(self, n_classes): 15 | super(PSPNet, self).__init__() 16 | 17 | self.resnet = ResNet(Bottleneck, [3, 4, 6, 3]) 18 | self.pyramid_pooling = PyramidPooling(2048, 512) 19 | self.final = nn.Sequential( 20 | nn.Conv2d(4096, 512, 3, padding=1, bias=False), 21 | nn.BatchNorm2d(512, momentum=.95), 22 | nn.ReLU(inplace=True), 23 | nn.Dropout(p=0.1), 24 | nn.Conv2d(512, n_classes, 1) 25 | ) 26 | 27 | self._initialize_weights() 28 | 29 | def _initialize_weights(self): 30 | for m in self.final: 31 | if isinstance(m, nn.Conv2d): 32 | nn.init.kaiming_normal_(m.weight) 33 | if m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | if isinstance(m, nn.BatchNorm2d): 36 | nn.init.constant_(m.weight, 1) 37 | nn.init.constant_(m.bias, 0) 38 | 39 | def forward(self, x): 40 | _, _, h, w = x.size() 41 | out = self.resnet(x) 42 | out = self.pyramid_pooling(out) 43 | out = self.final(out) 44 | out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True) 45 | return out 46 | 47 | def conv1x1(in_planes, out_planes, stride=1): 48 | """1x1 convolution""" 49 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 50 | 51 | class PyramidPooling(nn.Module): 52 | def __init__(self, in_channels, out_channels): 53 | super(PyramidPooling, self).__init__() 54 | self.pool1 = self._pyramid_conv(in_channels, out_channels, 10) 55 | self.pool2 = self._pyramid_conv(in_channels, out_channels, 20) 56 | self.pool3 = self._pyramid_conv(in_channels, out_channels, 30) 57 | self.pool4 = self._pyramid_conv(in_channels, out_channels, 60) 58 | 59 | self._initialize_weights() 60 | 61 | def _initialize_weights(self): 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv2d): 64 | nn.init.kaiming_normal_(m.weight) 65 | if isinstance(m, nn.BatchNorm2d): 66 | nn.init.constant_(m.weight, 1) 67 | nn.init.constant_(m.bias, 0) 68 | 69 | def _pyramid_conv(self, in_channels, out_channels, scale): 70 | module = nn.Sequential( 71 | # nn.AdaptiveAvgPool2d(scale), 72 | nn.AvgPool2d(kernel_size=scale, stride=scale), 73 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 74 | nn.BatchNorm2d(out_channels, momentum=.95), 75 | nn.ReLU(inplace=True) 76 | ) 77 | return module 78 | 79 | def forward(self, x): 80 | _, _, h, w = x.size() 81 | pool1 = self.pool1(x) 82 | pool2 = self.pool2(x) 83 | pool3 = self.pool3(x) 84 | pool4 = self.pool4(x) 85 | pool1 = F.interpolate(pool1, size=(h, w), mode='bilinear', align_corners=True) 86 | pool2 = F.interpolate(pool2, size=(h, w), mode='bilinear', align_corners=True) 87 | pool3 = F.interpolate(pool3, size=(h, w), mode='bilinear', align_corners=True) 88 | pool4 = F.interpolate(pool4, size=(h, w), mode='bilinear', align_corners=True) 89 | out = torch.cat([x, pool1, pool2, pool3, pool4], 1) 90 | return out 
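# Editor's note (illustrative addition, not part of the original file): with the
# 480x480 crop suggested in the PSPNet docstring, this dilated ResNet has output
# stride 8, so PyramidPooling receives a 60x60 feature map and the 10/20/30/60
# average pools produce the 6x6, 3x3, 2x2 and 1x1 bins used in the paper. The
# concatenation yields 2048 + 4 * 512 = 4096 channels, matching self.final.
# A minimal shape check under that assumption (the helper name below is the
# editor's, not the repository's):
def _editor_pyramid_pooling_shape_check():
    # eval() is needed because the 1x1-bin branch leaves a single value per
    # channel, which BatchNorm2d rejects in training mode
    pp = PyramidPooling(2048, 512).eval()
    feat = torch.randn(1, 2048, 60, 60)  # stand-in for the backbone output
    with torch.no_grad():
        out = pp(feat)
    assert out.shape == (1, 2048 + 4 * 512, 60, 60)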
91 | 92 | 93 | class Bottleneck(nn.Module): 94 | expansion = 4 95 | 96 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 97 | super(Bottleneck, self).__init__() 98 | self.conv1 = conv1x1(inplanes, planes) 99 | self.bn1 = nn.BatchNorm2d(planes, momentum=.95) 100 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 101 | padding=dilation, dilation=dilation, bias=False) 102 | self.bn2 = nn.BatchNorm2d(planes, momentum=.95) 103 | self.conv3 = conv1x1(planes, planes * self.expansion) 104 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=.95) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.downsample = downsample 107 | self.stride = stride 108 | 109 | def forward(self, x): 110 | identity = x 111 | 112 | out = self.conv1(x) 113 | out = self.bn1(out) 114 | out = self.relu(out) 115 | 116 | out = self.conv2(out) 117 | out = self.bn2(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv3(out) 121 | out = self.bn3(out) 122 | 123 | if self.downsample is not None: 124 | identity = self.downsample(x) 125 | 126 | out += identity 127 | out = self.relu(out) 128 | 129 | return out 130 | 131 | 132 | class ResNet(nn.Module): 133 | def __init__(self, block, layers): 134 | super(ResNet, self).__init__() 135 | self.inplanes = 64 136 | self.conv1 = nn.Sequential( 137 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False), 138 | nn.BatchNorm2d(64, momentum=.95), 139 | nn.ReLU(inplace=True), 140 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), 141 | nn.BatchNorm2d(64, momentum=.95), 142 | nn.ReLU(inplace=True), 143 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), 144 | nn.BatchNorm2d(64, momentum=.95), 145 | nn.ReLU(inplace=True), 146 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 147 | ) 148 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 151 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 152 | 153 | self._initialize_weights() 154 | 155 | def _initialize_weights(self): 156 | for m in self.conv1.children(): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.kaiming_normal_(m.weight) 159 | if isinstance(m, nn.BatchNorm2d): 160 | nn.init.constant_(m.weight, 1) 161 | nn.init.constant_(m.bias, 0) 162 | 163 | for module in [self.layer1, self.layer2, self.layer3, self.layer4]: 164 | for m in module.modules(): 165 | if isinstance(m, nn.BatchNorm2d): 166 | nn.init.constant_(m.weight, 1) 167 | nn.init.constant_(m.bias, 0) 168 | 169 | resnet = torchvision.models.resnet50(pretrained=True) 170 | self.layer1.load_state_dict(resnet.layer1.state_dict()) 171 | self.layer2.load_state_dict(resnet.layer2.state_dict()) 172 | self.layer3.load_state_dict(resnet.layer3.state_dict()) 173 | self.layer4.load_state_dict(resnet.layer4.state_dict()) 174 | 175 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 176 | downsample = None 177 | if stride != 1 or self.inplanes != planes * block.expansion: 178 | downsample = nn.Sequential( 179 | nn.Conv2d(self.inplanes, planes * block.expansion, 180 | kernel_size=1, stride=stride, bias=False), 181 | nn.BatchNorm2d(planes * block.expansion, momentum=0.95)) 182 | 183 | layers = [] 184 | layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample)) 185 | self.inplanes = planes * block.expansion 186 | for i in range(1, blocks): 187 | 
layers.append(block(self.inplanes, planes, dilation=dilation)) 188 | 189 | return nn.Sequential(*layers) 190 | 191 | 192 | def forward(self, x): 193 | out = self.conv1(x) 194 | 195 | out = self.layer1(out) 196 | out = self.layer2(out) 197 | out = self.layer3(out) 198 | out = self.layer4(out) 199 | 200 | return out -------------------------------------------------------------------------------- /Models/SegNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torchvision 4 | 5 | 6 | # use vgg16_bn pretrained model 7 | class SegNet(nn.Module): 8 | """Adapted from official implementation: 9 | 10 | https://github.com/alexgkendall/SegNet-Tutorial/tree/master/Models 11 | """ 12 | def __init__(self, n_classes): 13 | super(SegNet, self).__init__() 14 | 15 | # conv1 16 | features1 = [] 17 | features1.append(nn.Conv2d(3, 64, 3, padding=1)) 18 | features1.append(nn.BatchNorm2d(64)) 19 | features1.append(nn.ReLU(inplace=True)) 20 | features1.append(nn.Conv2d(64, 64, 3, padding=1)) 21 | features1.append(nn.BatchNorm2d(64)) 22 | features1.append(nn.ReLU(inplace=True)) 23 | self.features1 = nn.Sequential(*features1) 24 | self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/2 25 | 26 | # conv2 27 | features2 = [] 28 | features2.append(nn.Conv2d(64, 128, 3, padding=1)) 29 | features2.append(nn.BatchNorm2d(128)) 30 | features2.append(nn.ReLU(inplace=True)) 31 | features2.append(nn.Conv2d(128, 128, 3, padding=1)) 32 | features2.append(nn.BatchNorm2d(128)) 33 | features2.append(nn.ReLU(inplace=True)) 34 | self.features2 = nn.Sequential(*features2) 35 | self.pool2 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/4 36 | 37 | # conv3 38 | features3 = [] 39 | features3.append(nn.Conv2d(128, 256, 3, padding=1)) 40 | features3.append(nn.BatchNorm2d(256)) 41 | features3.append(nn.ReLU(inplace=True)) 42 | features3.append(nn.Conv2d(256, 256, 3, padding=1)) 43 | features3.append(nn.BatchNorm2d(256)) 44 | features3.append(nn.ReLU(inplace=True)) 45 | features3.append(nn.Conv2d(256, 256, 3, padding=1)) 46 | features3.append(nn.BatchNorm2d(256)) 47 | features3.append(nn.ReLU(inplace=True)) 48 | self.features3 = nn.Sequential(*features3) 49 | self.pool3 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/8 50 | 51 | # conv4 52 | features4 = [] 53 | features4.append(nn.Conv2d(256, 512, 3, padding=1)) 54 | features4.append(nn.BatchNorm2d(512)) 55 | features4.append(nn.ReLU(inplace=True)) 56 | features4.append(nn.Conv2d(512, 512, 3, padding=1)) 57 | features4.append(nn.BatchNorm2d(512)) 58 | features4.append(nn.ReLU(inplace=True)) 59 | features4.append(nn.Conv2d(512, 512, 3, padding=1)) 60 | features4.append(nn.BatchNorm2d(512)) 61 | features4.append(nn.ReLU(inplace=True)) 62 | self.features4 = nn.Sequential(*features4) 63 | self.pool4 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/16 64 | 65 | # conv5 66 | features5 = [] 67 | features5.append(nn.Conv2d(512, 512, 3, padding=1)) 68 | features5.append(nn.BatchNorm2d(512)) 69 | features5.append(nn.ReLU(inplace=True)) 70 | features5.append(nn.Conv2d(512, 512, 3, padding=1)) 71 | features5.append(nn.BatchNorm2d(512)) 72 | features5.append(nn.ReLU(inplace=True)) 73 | features5.append(nn.Conv2d(512, 512, 3, padding=1)) 74 | features5.append(nn.BatchNorm2d(512)) 75 | features5.append(nn.ReLU(inplace=True)) 76 | self.features5 = nn.Sequential(*features5) 77 | self.pool5 = nn.MaxPool2d(2, 
stride=2, return_indices=True, ceil_mode=True) # 1/32 78 | 79 | # convTranspose1 80 | self.unpool6 = nn.MaxUnpool2d(2, stride=2) 81 | features6 = [] 82 | features6.append(nn.Conv2d(512, 512, 3, padding=1)) 83 | features6.append(nn.BatchNorm2d(512)) 84 | features6.append(nn.ReLU(inplace=True)) 85 | features6.append(nn.Conv2d(512, 512, 3, padding=1)) 86 | features6.append(nn.BatchNorm2d(512)) 87 | features6.append(nn.ReLU(inplace=True)) 88 | features6.append(nn.Conv2d(512, 512, 3, padding=1)) 89 | features6.append(nn.BatchNorm2d(512)) 90 | features6.append(nn.ReLU(inplace=True)) 91 | self.features6 = nn.Sequential(*features6) 92 | 93 | # convTranspose2 94 | self.unpool7 = nn.MaxUnpool2d(2, stride=2) 95 | features7 = [] 96 | features7.append(nn.Conv2d(512, 512, 3, padding=1)) 97 | features7.append(nn.BatchNorm2d(512)) 98 | features7.append(nn.ReLU(inplace=True)) 99 | features7.append(nn.Conv2d(512, 512, 3, padding=1)) 100 | features7.append(nn.BatchNorm2d(512)) 101 | features7.append(nn.ReLU(inplace=True)) 102 | features7.append(nn.Conv2d(512, 256, 3, padding=1)) 103 | features7.append(nn.BatchNorm2d(256)) 104 | features7.append(nn.ReLU(inplace=True)) 105 | self.features7 = nn.Sequential(*features7) 106 | 107 | # convTranspose3 108 | self.unpool8 = nn.MaxUnpool2d(2, stride=2) 109 | features8 = [] 110 | features8.append(nn.Conv2d(256, 256, 3, padding=1)) 111 | features8.append(nn.BatchNorm2d(256)) 112 | features8.append(nn.ReLU(inplace=True)) 113 | features8.append(nn.Conv2d(256, 256, 3, padding=1)) 114 | features8.append(nn.BatchNorm2d(256)) 115 | features8.append(nn.ReLU(inplace=True)) 116 | features8.append(nn.Conv2d(256, 128, 3, padding=1)) 117 | features8.append(nn.BatchNorm2d(128)) 118 | features8.append(nn.ReLU(inplace=True)) 119 | self.features8 = nn.Sequential(*features8) 120 | 121 | # convTranspose4 122 | self.unpool9 = nn.MaxUnpool2d(2, stride=2) 123 | features9 = [] 124 | features9.append(nn.Conv2d(128, 128, 3, padding=1)) 125 | features9.append(nn.BatchNorm2d(128)) 126 | features9.append(nn.ReLU(inplace=True)) 127 | features9.append(nn.Conv2d(128, 64, 3, padding=1)) 128 | features9.append(nn.BatchNorm2d(64)) 129 | features9.append(nn.ReLU(inplace=True)) 130 | self.features9 = nn.Sequential(*features9) 131 | 132 | # convTranspose5 133 | self.unpool10 = nn.MaxUnpool2d(2, stride=2) 134 | self.final = nn.Sequential( 135 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 136 | nn.BatchNorm2d(64), 137 | nn.ReLU(inplace=True), 138 | nn.Conv2d(64, n_classes, kernel_size=3, padding=1), 139 | ) 140 | 141 | self._initialize_weights() 142 | 143 | def _initialize_weights(self): 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | nn.init.kaiming_normal_(m.weight) 147 | # if isinstance(m, nn.BatchNorm2d): 148 | # nn.init.constant_(m.weight, 1) 149 | # nn.init.constant_(m.bias, 0.001) 150 | 151 | vgg16 = torchvision.models.vgg16_bn(pretrained=True) 152 | vgg_features = [ 153 | vgg16.features[0:6], 154 | vgg16.features[7:13], 155 | vgg16.features[14:23], 156 | vgg16.features[24:33], 157 | vgg16.features[34:43] 158 | ] 159 | features = [ 160 | self.features1, 161 | self.features2, 162 | self.features3, 163 | self.features4, 164 | self.features5 165 | ] 166 | for l1, l2 in zip(vgg_features, features): 167 | for ll1, ll2 in zip(l1.children(), l2.children()): 168 | if isinstance(ll1, nn.Conv2d) and isinstance(ll2, nn.Conv2d): 169 | assert ll1.weight.size() == ll2.weight.size() 170 | assert ll1.bias.size() == ll2.bias.size() 171 | ll2.weight.data = ll1.weight.data 172 | ll2.bias.data = 
ll1.bias.data 173 | if isinstance(ll1, nn.BatchNorm2d) and isinstance(ll2, nn.BatchNorm2d): 174 | assert ll1.weight.size() == ll2.weight.size() 175 | assert ll1.bias.size() == ll2.bias.size() 176 | ll2.weight.data = ll1.weight.data 177 | ll2.bias.data = ll1.bias.data 178 | 179 | def forward(self, x): 180 | out = self.features1(x) 181 | out, indices_1 = self.pool1(out) 182 | out = self.features2(out) 183 | out, indices_2 = self.pool2(out) 184 | out = self.features3(out) 185 | out, indices_3 = self.pool3(out) 186 | out = self.features4(out) 187 | out, indices_4 = self.pool4(out) 188 | out = self.features5(out) 189 | out, indices_5 = self.pool5(out) 190 | out = self.unpool6(out, indices_5) 191 | out = self.features6(out) 192 | out = self.unpool7(out, indices_4) 193 | out = self.features7(out) 194 | out = self.unpool8(out, indices_3) 195 | out = self.features8(out) 196 | out = self.unpool9(out, indices_2) 197 | out = self.features9(out) 198 | out = self.unpool10(out, indices_1) 199 | out = self.final(out) 200 | return out 201 | 202 | 203 | # use vgg16 pretrained model 204 | # class SegNet(nn.Module): 205 | # """ 206 | # Adapted from official implementation: 207 | 208 | # https://github.com/alexgkendall/SegNet-Tutorial/tree/master/Models 209 | # """ 210 | # def __init__(self, n_classes): 211 | # super(SegNet, self).__init__() 212 | 213 | # # conv1 214 | # features1 = [] 215 | # features1.append(nn.Conv2d(3, 64, 3, padding=1)) 216 | # features1.append(nn.BatchNorm2d(64)) 217 | # features1.append(nn.ReLU(inplace=True)) 218 | # features1.append(nn.Conv2d(64, 64, 3, padding=1)) 219 | # features1.append(nn.BatchNorm2d(64)) 220 | # features1.append(nn.ReLU(inplace=True)) 221 | # self.features1 = nn.Sequential(*features1) 222 | # self.pool1 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/2 223 | 224 | # # conv2 225 | # features2 = [] 226 | # features2.append(nn.Conv2d(64, 128, 3, padding=1)) 227 | # features2.append(nn.BatchNorm2d(128)) 228 | # features2.append(nn.ReLU(inplace=True)) 229 | # features2.append(nn.Conv2d(128, 128, 3, padding=1)) 230 | # features2.append(nn.BatchNorm2d(128)) 231 | # features2.append(nn.ReLU(inplace=True)) 232 | # self.features2 = nn.Sequential(*features2) 233 | # self.pool2 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/4 234 | 235 | # # conv3 236 | # features3 = [] 237 | # features3.append(nn.Conv2d(128, 256, 3, padding=1)) 238 | # features3.append(nn.BatchNorm2d(256)) 239 | # features3.append(nn.ReLU(inplace=True)) 240 | # features3.append(nn.Conv2d(256, 256, 3, padding=1)) 241 | # features3.append(nn.BatchNorm2d(256)) 242 | # features3.append(nn.ReLU(inplace=True)) 243 | # features3.append(nn.Conv2d(256, 256, 3, padding=1)) 244 | # features3.append(nn.BatchNorm2d(256)) 245 | # features3.append(nn.ReLU(inplace=True)) 246 | # self.features3 = nn.Sequential(*features3) 247 | # self.pool3 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/8 248 | 249 | # # conv4 250 | # features4 = [] 251 | # features4.append(nn.Conv2d(256, 512, 3, padding=1)) 252 | # features4.append(nn.BatchNorm2d(512)) 253 | # features4.append(nn.ReLU(inplace=True)) 254 | # features4.append(nn.Conv2d(512, 512, 3, padding=1)) 255 | # features4.append(nn.BatchNorm2d(512)) 256 | # features4.append(nn.ReLU(inplace=True)) 257 | # features4.append(nn.Conv2d(512, 512, 3, padding=1)) 258 | # features4.append(nn.BatchNorm2d(512)) 259 | # features4.append(nn.ReLU(inplace=True)) 260 | # self.features4 = nn.Sequential(*features4) 261 | # self.pool4 = 
nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/16 262 | 263 | # # conv5 264 | # features5 = [] 265 | # features5.append(nn.Conv2d(512, 512, 3, padding=1)) 266 | # features5.append(nn.BatchNorm2d(512)) 267 | # features5.append(nn.ReLU(inplace=True)) 268 | # features5.append(nn.Conv2d(512, 512, 3, padding=1)) 269 | # features5.append(nn.BatchNorm2d(512)) 270 | # features5.append(nn.ReLU(inplace=True)) 271 | # features5.append(nn.Conv2d(512, 512, 3, padding=1)) 272 | # features5.append(nn.BatchNorm2d(512)) 273 | # features5.append(nn.ReLU(inplace=True)) 274 | # self.features5 = nn.Sequential(*features5) 275 | # self.pool5 = nn.MaxPool2d(2, stride=2, return_indices=True, ceil_mode=True) # 1/32 276 | 277 | # # convTranspose1 278 | # self.unpool6 = nn.MaxUnpool2d(2, stride=2) 279 | # features6 = [] 280 | # features6.append(nn.Conv2d(512, 512, 3, padding=1)) 281 | # features6.append(nn.BatchNorm2d(512)) 282 | # features6.append(nn.ReLU(inplace=True)) 283 | # features6.append(nn.Conv2d(512, 512, 3, padding=1)) 284 | # features6.append(nn.BatchNorm2d(512)) 285 | # features6.append(nn.ReLU(inplace=True)) 286 | # features6.append(nn.Conv2d(512, 512, 3, padding=1)) 287 | # features6.append(nn.BatchNorm2d(512)) 288 | # features6.append(nn.ReLU(inplace=True)) 289 | # self.features6 = nn.Sequential(*features6) 290 | 291 | # # convTranspose2 292 | # self.unpool7 = nn.MaxUnpool2d(2, stride=2) 293 | # features7 = [] 294 | # features7.append(nn.Conv2d(512, 512, 3, padding=1)) 295 | # features7.append(nn.BatchNorm2d(512)) 296 | # features7.append(nn.ReLU(inplace=True)) 297 | # features7.append(nn.Conv2d(512, 512, 3, padding=1)) 298 | # features7.append(nn.BatchNorm2d(512)) 299 | # features7.append(nn.ReLU(inplace=True)) 300 | # features7.append(nn.Conv2d(512, 256, 3, padding=1)) 301 | # features7.append(nn.BatchNorm2d(256)) 302 | # features7.append(nn.ReLU(inplace=True)) 303 | # self.features7 = nn.Sequential(*features7) 304 | 305 | # # convTranspose3 306 | # self.unpool8 = nn.MaxUnpool2d(2, stride=2) 307 | # features8 = [] 308 | # features8.append(nn.Conv2d(256, 256, 3, padding=1)) 309 | # features8.append(nn.BatchNorm2d(256)) 310 | # features8.append(nn.ReLU(inplace=True)) 311 | # features8.append(nn.Conv2d(256, 256, 3, padding=1)) 312 | # features8.append(nn.BatchNorm2d(256)) 313 | # features8.append(nn.ReLU(inplace=True)) 314 | # features8.append(nn.Conv2d(256, 128, 3, padding=1)) 315 | # features8.append(nn.BatchNorm2d(128)) 316 | # features8.append(nn.ReLU(inplace=True)) 317 | # self.features8 = nn.Sequential(*features8) 318 | 319 | # # convTranspose4 320 | # self.unpool9 = nn.MaxUnpool2d(2, stride=2) 321 | # features9 = [] 322 | # features9.append(nn.Conv2d(128, 128, 3, padding=1)) 323 | # features9.append(nn.BatchNorm2d(128)) 324 | # features9.append(nn.ReLU(inplace=True)) 325 | # features9.append(nn.Conv2d(128, 64, 3, padding=1)) 326 | # features9.append(nn.BatchNorm2d(64)) 327 | # features9.append(nn.ReLU(inplace=True)) 328 | # self.features9 = nn.Sequential(*features9) 329 | 330 | # # convTranspose5 331 | # self.unpool10 = nn.MaxUnpool2d(2, stride=2) 332 | # self.final = nn.Sequential( 333 | # nn.Conv2d(64, 64, kernel_size=3, padding=1), 334 | # nn.BatchNorm2d(64), 335 | # nn.ReLU(inplace=True), 336 | # nn.Conv2d(64, n_classes, kernel_size=3, padding=1), 337 | # ) 338 | 339 | # self._initialize_weights() 340 | 341 | # def _initialize_weights(self): 342 | # for m in self.modules(): 343 | # if isinstance(m, nn.Conv2d): 344 | # nn.init.kaiming_normal_(m.weight) 345 | 346 | # vgg16 
= torchvision.models.vgg16(pretrained=True) 347 | # vgg_features = [ 348 | # vgg16.features[0:4], 349 | # vgg16.features[5:9], 350 | # vgg16.features[10:16], 351 | # vgg16.features[17:23], 352 | # vgg16.features[24:29] 353 | # ] 354 | # features = [ 355 | # self.features1, 356 | # self.features2, 357 | # self.features3, 358 | # self.features4, 359 | # self.features5 360 | # ] 361 | # for l1, l2 in zip(vgg_features, features): 362 | # for i in range(len(list(l1.modules())) // 2): 363 | # assert isinstance(l1[i * 2], nn.Conv2d) == isinstance(l2[i * 3], nn.Conv2d) 364 | # assert l1[i * 2].weight.size() == l2[i * 3].weight.size() 365 | # assert l1[i * 2].bias.size() == l2[i * 3].bias.size() 366 | # l2[i * 3].weight.data = l1[i * 2].weight.data 367 | # l2[i * 3].bias.data = l1[i * 2].bias.data 368 | 369 | # def forward(self, x): 370 | # out = self.features1(x) 371 | # out, indices_1 = self.pool1(out) 372 | # out = self.features2(out) 373 | # out, indices_2 = self.pool2(out) 374 | # out = self.features3(out) 375 | # out, indices_3 = self.pool3(out) 376 | # out = self.features4(out) 377 | # out, indices_4 = self.pool4(out) 378 | # out = self.features5(out) 379 | # out, indices_5 = self.pool5(out) 380 | # out = self.unpool6(out, indices_5) 381 | # out = self.features6(out) 382 | # out = self.unpool7(out, indices_4) 383 | # out = self.features7(out) 384 | # out = self.unpool8(out, indices_3) 385 | # out = self.features8(out) 386 | # out = self.unpool9(out, indices_2) 387 | # out = self.features9(out) 388 | # out = self.unpool10(out, indices_1) 389 | # out = self.final(out) 390 | # return out 391 | 392 | 393 | # Bilinear interpolation upsampling version 394 | # class SegNet(nn.Module): 395 | # def __init__(self, n_classes): 396 | # super(SegNet, self).__init__() 397 | # 398 | # # conv1 399 | # features = [] 400 | # features.append(nn.Conv2d(3, 64, 3, padding=1)) 401 | # features.append(nn.BatchNorm2d(64)) 402 | # features.append(nn.ReLU(inplace=True)) 403 | # features.append(nn.Conv2d(64, 64, 3, padding=1)) 404 | # features.append(nn.BatchNorm2d(64)) 405 | # features.append(nn.ReLU(inplace=True)) 406 | # features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/2 407 | 408 | # # conv2 409 | # features.append(nn.Conv2d(64, 128, 3, padding=1)) 410 | # features.append(nn.BatchNorm2d(128)) 411 | # features.append(nn.ReLU(inplace=True)) 412 | # features.append(nn.Conv2d(128, 128, 3, padding=1)) 413 | # features.append(nn.BatchNorm2d(128)) 414 | # features.append(nn.ReLU(inplace=True)) 415 | # features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/4 416 | 417 | # # conv3 418 | # features.append(nn.Conv2d(128, 256, 3, padding=1)) 419 | # features.append(nn.BatchNorm2d(256)) 420 | # features.append(nn.ReLU(inplace=True)) 421 | # features.append(nn.Conv2d(256, 256, 3, padding=1)) 422 | # features.append(nn.BatchNorm2d(256)) 423 | # features.append(nn.ReLU(inplace=True)) 424 | # features.append(nn.Conv2d(256, 256, 3, padding=1)) 425 | # features.append(nn.BatchNorm2d(256)) 426 | # features.append(nn.ReLU(inplace=True)) 427 | # features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/8 428 | 429 | # # conv4 430 | # features.append(nn.Conv2d(256, 512, 3, padding=1)) 431 | # features.append(nn.BatchNorm2d(512)) 432 | # features.append(nn.ReLU(inplace=True)) 433 | # features.append(nn.Conv2d(512, 512, 3, padding=1)) 434 | # features.append(nn.BatchNorm2d(512)) 435 | # features.append(nn.ReLU(inplace=True)) 436 | # features.append(nn.Conv2d(512, 512, 3, padding=1)) 437 | # 
features.append(nn.BatchNorm2d(512)) 438 | # features.append(nn.ReLU(inplace=True)) 439 | # features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/16 440 | 441 | # # conv5 442 | # features.append(nn.Conv2d(512, 512, 3, padding=1)) 443 | # features.append(nn.BatchNorm2d(512)) 444 | # features.append(nn.ReLU(inplace=True)) 445 | # features.append(nn.Conv2d(512, 512, 3, padding=1)) 446 | # features.append(nn.BatchNorm2d(512)) 447 | # features.append(nn.ReLU(inplace=True)) 448 | # features.append(nn.Conv2d(512, 512, 3, padding=1)) 449 | # features.append(nn.BatchNorm2d(512)) 450 | # features.append(nn.ReLU(inplace=True)) 451 | # features.append(nn.MaxPool2d(2, stride=2, ceil_mode=True)) # 1/32 452 | # self.features = nn.Sequential(*features) 453 | 454 | # # convTranspose1 455 | # up1 = [] 456 | # up1.append(nn.Conv2d(512, 512, 3, padding=1)) 457 | # up1.append(nn.BatchNorm2d(512)) 458 | # up1.append(nn.ReLU(inplace=True)) 459 | # up1.append(nn.Conv2d(512, 512, 3, padding=1)) 460 | # up1.append(nn.BatchNorm2d(512)) 461 | # up1.append(nn.ReLU(inplace=True)) 462 | # up1.append(nn.Conv2d(512, 512, 3, padding=1)) 463 | # up1.append(nn.BatchNorm2d(512)) 464 | # up1.append(nn.ReLU(inplace=True)) 465 | # self.up1 = nn.Sequential(*up1) 466 | 467 | # # convTranspose2 468 | # up2 = [] 469 | # up2.append(nn.Conv2d(512, 512, 3, padding=1)) 470 | # up2.append(nn.BatchNorm2d(512)) 471 | # up2.append(nn.ReLU(inplace=True)) 472 | # up2.append(nn.Conv2d(512, 512, 3, padding=1)) 473 | # up2.append(nn.BatchNorm2d(512)) 474 | # up2.append(nn.ReLU(inplace=True)) 475 | # up2.append(nn.Conv2d(512, 256, 3, padding=1)) 476 | # up2.append(nn.BatchNorm2d(256)) 477 | # up2.append(nn.ReLU(inplace=True)) 478 | # self.up2 = nn.Sequential(*up2) 479 | 480 | # # convTranspose3 481 | # up3 = [] 482 | # up3.append(nn.Conv2d(256, 256, 3, padding=1)) 483 | # up3.append(nn.BatchNorm2d(256)) 484 | # up3.append(nn.ReLU(inplace=True)) 485 | # up3.append(nn.Conv2d(256, 256, 3, padding=1)) 486 | # up3.append(nn.BatchNorm2d(256)) 487 | # up3.append(nn.ReLU(inplace=True)) 488 | # up3.append(nn.Conv2d(256, 128, 3, padding=1)) 489 | # up3.append(nn.BatchNorm2d(128)) 490 | # up3.append(nn.ReLU(inplace=True)) 491 | # self.up3 = nn.Sequential(*up3) 492 | 493 | # # convTranspose4 494 | # up4 = [] 495 | # up4.append(nn.Conv2d(128, 128, 3, padding=1)) 496 | # up4.append(nn.BatchNorm2d(128)) 497 | # up4.append(nn.ReLU(inplace=True)) 498 | # up4.append(nn.Conv2d(128, 64, 3, padding=1)) 499 | # up4.append(nn.BatchNorm2d(64)) 500 | # up4.append(nn.ReLU(inplace=True)) 501 | # self.up4 = nn.Sequential(*up4) 502 | 503 | # self.final = nn.Sequential( 504 | # nn.Conv2d(64, 64, kernel_size=3, padding=1), 505 | # nn.BatchNorm2d(64), 506 | # nn.ReLU(inplace=True), 507 | # nn.Conv2d(64, n_classes, kernel_size=3, padding=1), 508 | # ) 509 | 510 | # self._initialize_weights() 511 | 512 | # def _initialize_weights(self): 513 | # for m in self.modules(): 514 | # if isinstance(m, nn.Conv2d): 515 | # nn.init.kaiming_normal_(m.weight) 516 | # if isinstance(m, nn.BatchNorm2d): 517 | # nn.init.constant_(m.weight, 1) 518 | # nn.init.constant_(m.bias, 0.001) 519 | 520 | # vgg16 = torchvision.models.vgg16_bn(pretrained=True) 521 | # state_dict = vgg16.features.state_dict() 522 | # self.features.load_state_dict(state_dict) 523 | 524 | # def forward(self, x): 525 | # out = self.features(x) 526 | # out = F.interpolate(out, scale_factor=2, mode='bilinear') 527 | # out = self.up1(out) 528 | # out = F.interpolate(out, scale_factor=2, mode='bilinear') 529 | # out = 
self.up2(out) 530 | # out = F.interpolate(out, scale_factor=2, mode='bilinear') 531 | # out = self.up3(out) 532 | # out = F.interpolate(out, scale_factor=2, mode='bilinear') 533 | # out = self.up4(out) 534 | # out = F.interpolate(out, scale_factor=2, mode='bilinear') 535 | # out = self.final(out) 536 | # return out 537 | -------------------------------------------------------------------------------- /Models/UNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class EncoderBlock(nn.Module): 6 | def __init__(self, in_channels, out_channels): 7 | super(EncoderBlock, self).__init__() 8 | layers = [ 9 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 10 | nn.BatchNorm2d(out_channels), 11 | nn.ReLU(inplace=True), 12 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 13 | nn.BatchNorm2d(out_channels), 14 | nn.ReLU(inplace=True), 15 | ] 16 | self.encode = nn.Sequential(*layers) 17 | 18 | def forward(self, x): 19 | return self.encode(x) 20 | 21 | 22 | class DecoderBlock(nn.Module): 23 | def __init__(self, in_channels, middle_channels, out_channels): 24 | super(DecoderBlock, self).__init__() 25 | self.decode = nn.Sequential( 26 | nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1), 27 | nn.BatchNorm2d(middle_channels), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1), 30 | nn.BatchNorm2d(middle_channels), 31 | nn.ReLU(inplace=True), 32 | nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=2, stride=2), 33 | ) 34 | 35 | def forward(self, x): 36 | return self.decode(x) 37 | 38 | 39 | class UNet(nn.Module): 40 | def __init__(self, n_classes): 41 | super(UNet, self).__init__() 42 | self.enc1 = EncoderBlock(3, 64) 43 | self.enc1_pool = nn.MaxPool2d(kernel_size=2, stride=2) 44 | self.enc2 = EncoderBlock(64, 128) 45 | self.enc2_pool = nn.MaxPool2d(kernel_size=2, stride=2) 46 | self.enc3 = EncoderBlock(128, 256) 47 | self.enc3_pool = nn.MaxPool2d(kernel_size=2, stride=2) 48 | self.enc4 = EncoderBlock(256, 512) 49 | self.enc4_pool = nn.MaxPool2d(kernel_size=2, stride=2) 50 | self.center = DecoderBlock(512, 1024, 512) 51 | self.dec4 = DecoderBlock(1024, 512, 256) 52 | self.dec3 = DecoderBlock(512, 256, 128) 53 | self.dec2 = DecoderBlock(256, 128, 64) 54 | self.dec1 = nn.Sequential( 55 | nn.Conv2d(128, 64, kernel_size=3, padding=1), 56 | nn.BatchNorm2d(64), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 59 | nn.BatchNorm2d(64), 60 | nn.ReLU(inplace=True), 61 | ) 62 | self.final = nn.Conv2d(64, n_classes, kernel_size=1) 63 | initialize_weights(self) 64 | 65 | def forward(self, x): 66 | enc1 = self.enc1(x) 67 | enc1_pool = self.enc1_pool(enc1) 68 | enc2 = self.enc2(enc1_pool) 69 | enc2_pool = self.enc2_pool(enc2) 70 | enc3 = self.enc3(enc2_pool) 71 | enc3_pool = self.enc3_pool(enc3) 72 | enc4 = self.enc4(enc3_pool) 73 | enc4_pool = self.enc4_pool(enc4) 74 | center = self.center(enc4_pool) 75 | dec4 = self.dec4(torch.cat([center, enc4], 1)) 76 | dec3 = self.dec3(torch.cat([dec4, enc3], 1)) 77 | dec2 = self.dec2(torch.cat([dec3, enc2], 1)) 78 | dec1 = self.dec1(torch.cat([dec2, enc1], 1)) 79 | final = self.final(dec1) 80 | return final 81 | 82 | 83 | def initialize_weights(model): 84 | for module in model.modules(): 85 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 86 | nn.init.kaiming_normal_(module.weight) 87 | if module.bias is not None: 88 | 
nn.init.constant_(module.bias, 0) 89 | elif isinstance(module, nn.BatchNorm2d): 90 | module.weight.data.fill_(1) 91 | module.bias.data.zero_() 92 | -------------------------------------------------------------------------------- /Models/__init__.py: -------------------------------------------------------------------------------- 1 | from .FCN import FCN32s, FCN8sAtOnce 2 | from .UNet import UNet 3 | from .SegNet import SegNet 4 | from .DeepLab_v1 import DeepLabLargeFOV 5 | from .DeepLab_v2 import DeepLabASPPVGG, DeepLabASPPResNet 6 | from .DeepLab_v3 import DeepLabV3 7 | from .DeepLab_v3plus import DeepLabV3Plus 8 | from .Dilation8 import Dilation8 9 | from .PSPNet import PSPNet 10 | import torch 11 | 12 | VALID_MODEL = [ 13 | 'fcn32s', 'fcn8s', 'unet', 'segnet', 'deeplab-largefov', 'deeplab-aspp-vgg', 14 | 'deeplab-aspp-resnet', 'deeplab-v3', 'deeplab-v3+', 'dilation8', 'pspnet' 15 | ] 16 | 17 | 18 | def model_loader(model_name, n_classes, resume): 19 | model_name = model_name.lower() 20 | if model_name == 'fcn32s': 21 | model = FCN32s(n_classes=n_classes) 22 | elif model_name == 'fcn8s': 23 | model = FCN8sAtOnce(n_classes=n_classes) 24 | elif model_name == 'unet': 25 | model = UNet(n_classes=n_classes) 26 | elif model_name == 'segnet': 27 | model = SegNet(n_classes=n_classes) 28 | elif model_name == 'deeplab-largefov': 29 | model = DeepLabLargeFOV(n_classes=n_classes) 30 | elif model_name == 'deeplab-aspp-vgg': 31 | model = DeepLabASPPVGG(n_classes=n_classes) 32 | elif model_name == 'deeplab-aspp-resnet': 33 | model = DeepLabASPPResNet(n_classes=n_classes) 34 | elif model_name == 'deeplab-v3': 35 | model = DeepLabV3(n_classes=n_classes) 36 | elif model_name == 'deeplab-v3+': 37 | model = DeepLabV3Plus(n_classes=n_classes) 38 | elif model_name == 'dilation8': 39 | model = Dilation8(n_classes=n_classes) 40 | elif model_name == 'pspnet': 41 | model = PSPNet(n_classes=n_classes) 42 | else: 43 | raise ValueError('Unsupported model, ' 44 | 'valid models as follows:\n{}'.format( 45 | ', '.join(VALID_MODEL))) 46 | 47 | start_epoch = 1 48 | if resume: 49 | checkpoint = torch.load(resume) 50 | model.load_state_dict(checkpoint['model_state_dict']) 51 | start_epoch = checkpoint['epoch'] 52 | else: 53 | checkpoint = None 54 | 55 | return model, start_epoch, checkpoint 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch models 2 | 3 | ## Dataset 4 | ### PASCAL VOC 5 | model|acc|acc_cls|mean_iu|notes 6 | ---|---|---|---|--- 7 | FCN32s|90.17%|75.56%|61.81%|lr=1.0e-10
reduction='sum' 8 | FCN32s(original)|-|-|63.6%| 9 | FCN8sAtOnce|90.27%|74.95%|62.13%|lr=1.0e-10
reduction='sum'
10 | FCN8sAtOnce(original)|-|-|65.4%|
11 | DeepLab-LargeFov|93.71%|72.21%|61.32%|pad images to 513x513 for evaluation
12 | DeepLab-LargeFov|90.90%|73.89%|62.09%|use original resolution for evaluation
13 | DeepLab-LargeFov(original)|-|-|62.25%|
14 | DeepLab-ASPP|93.10%|80.13%|61.07%|
15 | DeepLab-ASPP(original)|-|-|68.96%|
16 | 
17 | ### CamVid
18 | model|acc|acc_cls|mean_iu|notes
19 | ---|---|---|---|---
20 | SegNet(Maxunpooling, vgg16-based)|86.71%|66.39%|54.09%|lr=0.01
21 | SegNet(Maxunpooling, vgg16_bn-based)|87.84%|70.75%|57.68%|lr=0.01
22 | SegNet(Bilinear interpolation)|85.86%|71.95%|56.22%|lr=0.01
23 | SegNet(original)|88.6%|65.9%|50.2%|
24 | UNet|84.38%|62.80%|49.83%|lr=0.01
25 | 
--------------------------------------------------------------------------------
/augmentations.py:
--------------------------------------------------------------------------------
1 | import random
2 | from PIL import Image, ImageOps
3 | 
4 | 
5 | class Compose:
6 |     def __init__(self, augmentations):
7 |         self.augmentations = augmentations
8 | 
9 |     def __call__(self, imgs, lbls):
10 |         assert imgs.size == lbls.size
11 |         for aug in self.augmentations:
12 |             imgs, lbls = aug(imgs, lbls)
13 | 
14 |         return imgs, lbls
15 | 
16 | 
17 | class RandomFlip:
18 |     """Flip images horizontally.
19 |     """
20 |     def __init__(self, prob=0.5):
21 |         self.prob = prob
22 | 
23 |     def __call__(self, image, label):
24 |         if random.random() < self.prob:
25 |             image = image.transpose(Image.FLIP_LEFT_RIGHT)
26 |             label = label.transpose(Image.FLIP_LEFT_RIGHT)
27 |         return image, label
28 | 
29 | 
30 | class RandomCrop:
31 |     """Crop images to given size.
32 | 
33 |     Parameters
34 |     ----------
35 |     crop_size: a tuple (height, width) specifying the crop size,
36 |         which may be larger than the original size (smaller images are padded first).
37 |     """
38 |     def __init__(self, crop_size):
39 |         self.crop_size = crop_size
40 | 
41 |     @staticmethod
42 |     def get_params(img, output_size):
43 |         w, h = img.size
44 |         th, tw = output_size
45 |         if w == tw and h == th:
46 |             return 0, 0, h, w
47 | 
48 |         i = random.randint(0, h - th)
49 |         j = random.randint(0, w - tw)
50 |         return i, j, th, tw
51 | 
52 |     def __call__(self, image, label):
53 |         if image.size[0] < self.crop_size[1]:
54 |             image = ImageOps.expand(image, (self.crop_size[1] - image.size[0], 0), fill=0)
55 |             label = ImageOps.expand(label, (self.crop_size[1] - label.size[0], 0), fill=255)
56 |         if image.size[1] < self.crop_size[0]:
57 |             image = ImageOps.expand(image, (0, self.crop_size[0] - image.size[1]), fill=0)
58 |             label = ImageOps.expand(label, (0, self.crop_size[0] - label.size[1]), fill=255)
59 | 
60 |         i, j, h, w = self.get_params(image, self.crop_size)
61 |         image = image.crop((j, i, j + w, i + h))
62 |         label = label.crop((j, i, j + w, i + h))
63 | 
64 |         return image, label
65 | 
66 | 
67 | class RandomScale:
68 |     """Scale images within range.
69 | 
70 |     Parameters
71 |     ----------
72 |     scale_range: a tuple (min, max) giving the lowest and highest scale factors.
73 |     """
74 |     def __init__(self, scale_range):
75 |         self.scale = scale_range
76 | 
77 |     def __call__(self, image, label):
78 |         w, h = image.size
79 |         scale = random.uniform(self.scale[0], self.scale[1])
80 |         ow, oh = int(w * scale), int(h * scale)
81 |         image = image.resize((ow, oh), Image.BILINEAR)
82 |         label = label.resize((ow, oh), Image.NEAREST)
83 | 
84 |         return image, label
85 | 
86 | 
87 | def get_augmentations(args):
88 |     """Specify augmentation.
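    Reads ``args.flip``, ``args.crop_size`` and ``args.scale_range`` and returns a
    Compose of the matching transforms (RandomFlip, RandomCrop, RandomScale), or
    None when none of them is set.

    A minimal usage sketch, assuming an argparse-style namespace with these
    attributes (the values are illustrative only):

        from argparse import Namespace
        augs = get_augmentations(Namespace(flip=True, crop_size=(321, 321), scale_range=None))
        # -> Compose([RandomFlip(), RandomCrop((321, 321))])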
89 | """ 90 | augs = [] 91 | if args.flip: 92 | augs.append(RandomFlip()) 93 | if args.crop_size: 94 | augs.append(RandomCrop(args.crop_size)) 95 | if args.scale_range: 96 | augs.append(RandomScale(args.scale_range)) 97 | 98 | if augs == []: 99 | return None 100 | print('Using augmentations: ', end=' ') 101 | for x in augs: 102 | print(x.__class__.__name__, end=' ') 103 | print('\n') 104 | 105 | return Compose(augs) -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import scipy 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import tqdm 9 | import Models 10 | from utils import visualize_segmentation, get_tile_image, runningScore, averageMeter 11 | from Dataloader import get_loader 12 | from augmentations import RandomCrop, Compose 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser( 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 17 | ) 18 | parser.add_argument('--model', type=str, default='deeplab-largefov') 19 | parser.add_argument('--model_file', type=str, default='/home/ecust/lx/Semantic-Segmentation-PyTorch/logs/deeplab-largefov_20190417_230357/model_best.pth.tar',help='Model path') 20 | parser.add_argument('--dataset_type', type=str, default='voc',help='type of dataset') 21 | parser.add_argument('--dataset', type=str, default='/home/ecust/Datasets/PASCAL VOC/VOCdevkit/VOC2012',help='path to dataset') 22 | parser.add_argument('--img_size', type=tuple, default=None, help='resize images using bilinear interpolation') 23 | parser.add_argument('--crop_size', type=tuple, default=None, help='crop images') 24 | parser.add_argument('--n_classes', type=int, default=21, help='number of classes') 25 | parser.add_argument('--pretrained', type=bool, default=True, help='should be set the same as train.py') 26 | args = parser.parse_args() 27 | 28 | model_file = args.model_file 29 | root = args.dataset 30 | n_classes = args.n_classes 31 | 32 | crop=None 33 | # crop = Compose([RandomCrop(args.crop_size)]) 34 | loader = get_loader(args.dataset_type) 35 | val_loader = DataLoader( 36 | loader(root, n_classes=n_classes, split='val', img_size=args.img_size, augmentations=crop, pretrained=args.pretrained), 37 | batch_size=1, shuffle=False, num_workers=4) 38 | 39 | model, _, _ = Models.model_loader(args.model, n_classes, resume=None) 40 | 41 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | model = model.to(device) 43 | 44 | print('==> Loading {} model file: {}'.format(model.__class__.__name__, model_file)) 45 | 46 | model_data = torch.load(model_file) 47 | 48 | try: 49 | model.load_state_dict(model_data) 50 | except Exception: 51 | model.load_state_dict(model_data['model_state_dict']) 52 | model.eval() 53 | 54 | print('==> Evaluating with {} dataset'.format(args.dataset_type)) 55 | visualizations = [] 56 | metrics = runningScore(n_classes) 57 | 58 | for data, target in tqdm.tqdm(val_loader, total=len(val_loader), ncols=80, leave=False): 59 | data, target = data.to(device), target.to(device) 60 | score = model(data) 61 | 62 | imgs = data.data.cpu() 63 | lbl_pred = score.data.max(1)[1].cpu().numpy() 64 | lbl_true = target.data.cpu() 65 | for img, lt, lp in zip(imgs, lbl_true, lbl_pred): 66 | img, lt = val_loader.dataset.untransform(img, lt) 67 | metrics.update(lt, lp) 68 | if len(visualizations) < 9: 69 | viz = visualize_segmentation( 70 | 
lbl_pred=lp, lbl_true=lt, img=img, 71 | n_classes=n_classes, dataloader=val_loader) 72 | visualizations.append(viz) 73 | acc, acc_cls, mean_iu, fwavacc, cls_iu = metrics.get_scores() 74 | print(''' 75 | Accuracy: {0:.2f} 76 | Accuracy Class: {1:.2f} 77 | Mean IoU: {2:.2f} 78 | FWAV Accuracy: {3:.2f}'''.format(acc * 100, 79 | acc_cls * 100, 80 | mean_iu * 100, 81 | fwavacc * 100) + '\n') 82 | 83 | class_name = val_loader.dataset.class_names 84 | if class_name is not None: 85 | for index, value in enumerate(cls_iu.values()): 86 | offset = 20 - len(class_name[index]) 87 | print(class_name[index] + ' ' * offset + f'{value * 100:>.2f}') 88 | else: 89 | print("\nyou don't specify class_names, use number instead") 90 | for key, value in cls_iu.items(): 91 | print(key, f'{value * 100:>.2f}') 92 | 93 | viz = get_tile_image(visualizations) 94 | # img = Image.fromarray(viz) 95 | # img.save('viz_evaluate.png') 96 | scipy.misc.imsave('viz_evaluate.png', viz) 97 | 98 | if __name__ == '__main__': 99 | main() -------------------------------------------------------------------------------- /learning_curve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import learning_curve 3 | 4 | 5 | def main(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('log_file') 8 | args = parser.parse_args() 9 | 10 | log_file = args.log_file 11 | 12 | learning_curve(log_file) 13 | 14 | 15 | if __name__ == '__main__': 16 | main() -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import numpy as np 3 | from PIL import Image 4 | import torch 5 | 6 | 7 | def CrossEntropyLoss(score, target, weight, ignore_index, reduction): 8 | """Cross entropy for single or multiple outputs. 
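    ``score`` may be a single logits tensor of shape (N, C, H, W) or a tuple of such
    tensors (e.g. multi-scale outputs); for a tuple, the per-output losses are summed.

    A minimal sketch of a call, mirroring how trainer.py invokes it (shapes are
    illustrative; -1 marks ignored pixels in ``target``):

        # logits: (N, C, H, W) float tensor, target: (N, H, W) long tensor
        loss = CrossEntropyLoss(logits, target, weight=None, ignore_index=-1, reduction='mean')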
9 | """ 10 | if not isinstance(score, tuple): 11 | loss = F.cross_entropy( 12 | score, target, weight=weight, ignore_index=ignore_index, reduction=reduction) 13 | return loss 14 | 15 | loss = 0 16 | for s in score: 17 | loss = loss + F.cross_entropy( 18 | s, target, weight=weight, ignore_index=ignore_index, reduction=reduction) 19 | return loss 20 | 21 | def resize_labels(labels, size): 22 | new_labels = [] 23 | for label in labels: 24 | label = label.float().cpu().numpy() 25 | label = Image.fromarray(label).resize((size[1], size[0]), Image.NEAREST) 26 | new_labels.append(np.asarray(label)) 27 | new_labels = torch.LongTensor(new_labels) 28 | return new_labels -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn.functional as F 3 | 4 | 5 | # Adapted from: 6 | # https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/metrics.py 7 | 8 | 9 | class runningScore(object): 10 | def __init__(self, n_classes): 11 | self.n_classes = n_classes 12 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 13 | 14 | def _fast_hist(self, label_true, label_pred, n_class): 15 | mask = (label_true >= 0) & (label_true < n_class) 16 | hist = np.bincount( 17 | n_class * label_true[mask].astype(int) + label_pred[mask], 18 | minlength=n_class**2).reshape(n_class, n_class) 19 | return hist 20 | 21 | def update(self, label_trues, label_preds): 22 | for lt, lp in zip(label_trues, label_preds): 23 | self.confusion_matrix += self._fast_hist( 24 | lt.flatten(), lp.flatten(), self.n_classes) 25 | 26 | def get_scores(self): 27 | """Returns accuracy score evaluation result. 28 | - overall accuracy 29 | - mean accuracy 30 | - mean IU 31 | - fwavacc 32 | """ 33 | hist = self.confusion_matrix 34 | acc = np.diag(hist).sum() / hist.sum() 35 | acc_cls = np.diag(hist) / hist.sum(axis=1) 36 | acc_cls = np.nanmean(acc_cls) 37 | iu = np.diag(hist) / ( 38 | hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 39 | mean_iu = np.nanmean(iu) 40 | freq = hist.sum(axis=1) / hist.sum() 41 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 42 | cls_iu = dict(zip(range(self.n_classes), iu)) 43 | 44 | return acc, acc_cls, mean_iu, fwavacc, cls_iu 45 | 46 | def reset(self): 47 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 48 | 49 | 50 | class averageMeter(object): 51 | """Computes and stores the average and current value""" 52 | 53 | def __init__(self): 54 | self.reset() 55 | 56 | def reset(self): 57 | self.val = 0 58 | self.avg = 0 59 | self.sum = 0 60 | self.count = 0 61 | 62 | def update(self, val, n=1): 63 | self.val = val 64 | self.sum += val * n 65 | self.count += n 66 | self.avg = self.sum / self.count 67 | 68 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_optimizer(args, model): 6 | """Optimizer for different models 7 | """ 8 | if args.optim.lower() == 'sgd': 9 | if args.model.lower() in ['fcn32s', 'fcn8s']: 10 | optim = fcn_optim(model, args) 11 | elif args.model.lower() in ['deeplab-largefov', 'deeplab-aspp-vgg']: 12 | optim = deeplab_optim(model, args) 13 | elif args.model.lower() in ['deeplab-aspp-resnet']: 14 | optim = deeplabv2_optim(model, args) 15 | else: 16 | optim = torch.optim.SGD( 17 | model.parameters(), 18 | lr=args.lr, 19 | 
momentum=args.beta1, 20 | weight_decay=args.weight_decay) 21 | 22 | elif args.optim.lower() == 'adam': 23 | optim = torch.optim.Adam( 24 | model.parameters(), 25 | lr=args.lr, 26 | betas=(args.beta1, 0.999), 27 | weight_decay=args.weight_decay) 28 | 29 | return optim 30 | 31 | def fcn_optim(model, args): 32 | """optimizer for fcn32s and fcn8s 33 | """ 34 | optim = torch.optim.SGD( 35 | [{'params': model.get_parameters(bias=False)}, 36 | {'params': model.get_parameters(bias=True), 'lr': args.lr * 2, 'weight_decay': 0}], 37 | lr=args.lr, 38 | momentum=args.beta1, 39 | weight_decay=args.weight_decay) 40 | return optim 41 | 42 | def deeplab_optim(model, args): 43 | """optimizer for deeplab-v1 and deeplab-v2-vgg 44 | """ 45 | optim = torch.optim.SGD( 46 | [{'params': model.get_parameters(bias=False, score=False)}, 47 | {'params': model.get_parameters(bias=True, score=False), 'lr': args.lr * 2, 'weight_decay': 0}, 48 | {'params': model.get_parameters(bias=False, score=True), 'lr': args.lr * 10}, 49 | {'params': model.get_parameters(bias=True, score=True), 'lr': args.lr * 20, 'weight_decay': 0}], 50 | lr=args.lr, 51 | momentum=args.beta1, 52 | weight_decay=args.weight_decay) 53 | return optim 54 | 55 | def deeplabv2_optim(model, args): 56 | """optimizer for deeplab-v2-resnet 57 | """ 58 | optim = torch.optim.SGD( 59 | [{'params': model.get_parameters(bias=False, score=False)}, 60 | {'params': model.get_parameters(bias=False, score=True), 'lr': args.lr * 10}, 61 | {'params': model.get_parameters(bias=True, score=True), 'lr': args.lr * 20, 'weight_decay': 0}], 62 | lr=args.lr, 63 | momentum=args.beta1, 64 | weight_decay=args.weight_decay) 65 | return optim -------------------------------------------------------------------------------- /preparation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser( 6 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 7 | ) 8 | parser.add_argument('--train_img_root', type=str, required=True, help='path to training images') 9 | parser.add_argument('--train_lbl_root', type=str, required=True, help='path to training labels') 10 | parser.add_argument('--val_img_root', type=str, help='path to validation images') 11 | parser.add_argument('--val_lbl_root', type=str, help='path to validation labels') 12 | parser.add_argument('--train_split', type=float, help='proportion of the dataset to include in the train split') 13 | 14 | args = parser.parse_args() 15 | 16 | train_img_root = args.train_img_root 17 | train_lbl_root = args.train_lbl_root 18 | val_img_root = args.val_img_root 19 | val_lbl_root = args.val_lbl_root 20 | train_split = args.train_split 21 | 22 | if val_img_root is None: 23 | img = [] 24 | lbl = [] 25 | for root, _, files in os.walk(train_img_root): 26 | for filename in files: 27 | if filename.lower().endswith(('.png', '.jpg', '.jpeg')): 28 | img.append(os.path.join(root, filename)) 29 | 30 | for root, _, files in os.walk(train_lbl_root): 31 | for filename in files: 32 | if filename.lower().endswith(('.png', '.jpg', '.jpeg')): 33 | lbl.append(os.path.join(root, filename)) 34 | 35 | assert len(img) == len(lbl), 'numbers of images and labels are not equal' 36 | 37 | choice = np.random.choice(len(img), len(img), replace=False) 38 | train = choice[:int(len(img) * train_split)] 39 | val = choice[int(len(img) * train_split):] 40 | 41 | with open('train.txt', 'a') as f: 42 | for index in train: 43 | f.write(' '.join([img[index], 
lbl[index]]) + '\n') 44 | 45 | with open('val.txt', 'a') as f: 46 | for index in val: 47 | f.write(' '.join([img[index], lbl[index]]) + '\n') 48 | else: 49 | train_img = [] 50 | train_lbl = [] 51 | val_img = [] 52 | val_lbl = [] 53 | name_list = [train_img, train_lbl, val_img, val_lbl] 54 | root = [train_img_root, train_lbl_root, val_img_root, val_lbl_root] 55 | for nlist, root in zip(name_list, root): 56 | for _root, _, files in os.walk(root): 57 | for filename in files: 58 | if filename.lower().endswith(('.png', '.jpg', '.jpeg')): 59 | nlist.append(os.path.join(_root, filename)) 60 | 61 | with open('train.txt', 'a') as f: 62 | for index in range(len(train_img)): 63 | f.write(' '.join([train_img[index], train_lbl[index]]) + '\n') 64 | 65 | with open('val.txt', 'a') as f: 66 | for index in range(len(val_img)): 67 | f.write(' '.join([val_img[index], val_lbl[index]]) + '\n') 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.0.2 2 | numpy==1.22.0 3 | PyYAML 4 | scikit-image==0.14.1 5 | scipy==1.2.1 6 | tqdm==4.31.1 7 | pytz 8 | seaborn 9 | pandas -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import random 4 | import yaml 5 | import argparse 6 | import datetime 7 | import torch 8 | from Dataloader import get_loader 9 | from torch.utils.data import DataLoader 10 | from Models import model_loader 11 | from trainer import Trainer 12 | from utils import get_scheduler 13 | from optimizer import get_optimizer 14 | from augmentations import get_augmentations 15 | 16 | here = osp.dirname(osp.abspath(__file__)) 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser( 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 22 | ) 23 | parser.add_argument('--model', type=str, default='deeplab-largefov', help='model to train for') 24 | parser.add_argument('--epochs', type=int, default=50, help='total epochs') 25 | parser.add_argument('--val_epoch', type=int, default=10, help='validation interval') 26 | parser.add_argument('--batch_size', type=int, default=16, help='number of batch size') 27 | parser.add_argument('--img_size', type=tuple, default=None, help='resize images to proper size') 28 | parser.add_argument('--dataset_type', type=str, default='voc', help='choose which dataset to use') 29 | parser.add_argument('--dataset_root', type=str, default='/home/ecust/Datasets/PASCAL VOC/VOC_Aug', help='path to dataset') 30 | parser.add_argument('--n_classes', type=int, default=21, help='number of classes') 31 | parser.add_argument('--resume', default=None, help='path to checkpoint') 32 | parser.add_argument('--optim', type=str, default='sgd', help='optimizer') 33 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 34 | parser.add_argument('--lr_policy', type=str, default='poly', help='learning rate policy') 35 | parser.add_argument('--weight-decay', type=float, default=0.0005, help='weight decay') 36 | parser.add_argument('--beta1', type=float, default=0.9, help='momentum for sgd, beta1 for adam') 37 | parser.add_argument('--lr_decay_step', type=float, default=10, help='step size for step learning policy') 38 | parser.add_argument('--lr_power', type=int, default=0.9, help='power parameter for poly learning policy') 39 | 
parser.add_argument('--pretrained', type=bool, default=True, help='whether to use pretrained models') 40 | parser.add_argument('--iter_size', type=int, default=10, help='iters to accumulate gradients') 41 | 42 | parser.add_argument('--crop_size', type=tuple, default=(321, 321), help='crop sizes of images') 43 | parser.add_argument('--flip', type=bool, default=True, help='whether to use horizontal flip') 44 | 45 | args = parser.parse_args() 46 | 47 | now = datetime.datetime.now() 48 | args.out = osp.join(here, 'logs', args.model + '_' + now.strftime('%Y%m%d_%H%M%S')) 49 | 50 | if not osp.exists(args.out): 51 | os.makedirs(args.out) 52 | with open(osp.join(args.out, 'config.yaml'), 'w') as f: 53 | yaml.safe_dump(args.__dict__, f, default_flow_style=False) 54 | 55 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 56 | print(f'Start training {args.model} using {device.type}\n') 57 | 58 | random.seed(1337) 59 | torch.manual_seed(1337) 60 | torch.cuda.manual_seed(1337) 61 | 62 | # 1. dataset 63 | 64 | root = args.dataset_root 65 | loader = get_loader(args.dataset_type) 66 | 67 | augmentations = get_augmentations(args) 68 | 69 | train_loader = DataLoader( 70 | loader(root, n_classes=args.n_classes, split='train_aug', img_size=args.img_size, augmentations=augmentations, 71 | pretrained=args.pretrained), 72 | batch_size=args.batch_size, shuffle=True, num_workers=4) 73 | val_loader = DataLoader( 74 | loader(root, n_classes=args.n_classes, split='val_id', img_size=args.img_size, pretrained=args.pretrained), 75 | batch_size=1, shuffle=False, num_workers=4) 76 | 77 | # 2. model 78 | model, start_epoch, ckpt = model_loader(args.model, args.n_classes, args.resume) 79 | model = model.to(device) 80 | 81 | # 3. optimizer 82 | optim = get_optimizer(args, model) 83 | if args.resume: 84 | optim.load_state_dict(ckpt['optim_state_dict']) 85 | 86 | scheduler = get_scheduler(optim, args) 87 | 88 | # 4. 
train 89 | trainer = Trainer( 90 | device=device, 91 | model=model, 92 | optimizer=optim, 93 | scheduler=scheduler, 94 | train_loader=train_loader, 95 | val_loader=val_loader, 96 | out=args.out, 97 | epochs=args.epochs, 98 | n_classes=args.n_classes, 99 | val_epoch=args.val_epoch, 100 | iter_size=args.iter_size 101 | ) 102 | trainer.epoch = start_epoch 103 | trainer.train() 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import os.path as osp 4 | import shutil 5 | import numpy as np 6 | import pytz 7 | import scipy.misc 8 | import torch 9 | import tqdm 10 | from PIL import Image 11 | from loss import CrossEntropyLoss, resize_labels 12 | from utils import visualize_segmentation, get_tile_image, learning_curve 13 | from metrics import runningScore, averageMeter 14 | 15 | 16 | class Trainer: 17 | def __init__(self, device, model, optimizer, scheduler, train_loader, 18 | val_loader, out, epochs, n_classes, val_epoch=10, iter_size=1): 19 | self.device = device 20 | 21 | self.model = model 22 | self.optim = optimizer 23 | self.scheduler = scheduler 24 | self.train_loader = train_loader 25 | self.val_loader = val_loader 26 | 27 | self.timestamp_start = \ 28 | datetime.datetime.now(pytz.timezone('UTC')) 29 | 30 | self.val_epoch = val_epoch 31 | self.iter_size = iter_size 32 | 33 | self.out = out 34 | if not osp.exists(self.out): 35 | os.makedirs(self.out) 36 | 37 | self.log_headers = [ 38 | 'epoch', 39 | 'train/loss', 40 | 'train/acc', 41 | 'train/acc_cls', 42 | 'train/mean_iu', 43 | 'train/fwavacc', 44 | 'valid/loss', 45 | 'valid/acc', 46 | 'valid/acc_cls', 47 | 'valid/mean_iu', 48 | 'valid/fwavacc', 49 | 'elapsed_time', 50 | ] 51 | if not osp.exists(osp.join(self.out, 'log.csv')): 52 | with open(osp.join(self.out, 'log.csv'), 'w') as f: 53 | f.write(','.join(self.log_headers) + '\n') 54 | 55 | self.n_classes = n_classes 56 | self.epoch = 1 57 | self.epochs = epochs 58 | self.best_mean_iu = 0 59 | 60 | def train_epoch(self): 61 | if self.epoch % self.val_epoch == 0 or self.epoch == 1: 62 | self.validate() 63 | 64 | self.model.train() 65 | train_metrics = runningScore(self.n_classes) 66 | train_loss_meter = averageMeter() 67 | 68 | self.optim.zero_grad() 69 | 70 | for data, target in tqdm.tqdm( 71 | self.train_loader, total=len(self.train_loader), 72 | desc=f'Train epoch={self.epoch}', ncols=80, leave=False): 73 | 74 | self.iter += 1 75 | assert self.model.training 76 | 77 | data, target = data.to(self.device), target.to(self.device) 78 | score = self.model(data) 79 | 80 | weight = self.train_loader.dataset.class_weight 81 | if weight: 82 | weight = torch.Tensor(weight).to(self.device) 83 | 84 | loss = CrossEntropyLoss(score, target, weight=weight, ignore_index=-1, reduction='mean') 85 | 86 | loss_data = loss.data.item() 87 | train_loss_meter.update(loss_data) 88 | 89 | if np.isnan(loss_data): 90 | raise ValueError('loss is nan while training') 91 | 92 | loss /= self.iter_size 93 | loss.backward() 94 | 95 | if self.iter % self.iter_size == 0: 96 | self.optim.step() 97 | self.optim.zero_grad() 98 | 99 | 100 | # if not isinstance(score, tuple): 101 | # lbl_pred = score.data.max(1)[1].cpu().numpy() 102 | # else: 103 | # lbl_pred = score[-1].data.max(1)[1].cpu().numpy() 104 | 105 | # lbl_true = target.data.cpu().numpy() 106 | # lbl_pred, lbl_true = 
get_multiscale_results(score, target, upsample_logits=False) 107 | if isinstance(score, tuple): 108 | lbl_pred = score[-1].data.max(1)[1].cpu().numpy() 109 | else: 110 | lbl_pred = score.data.max(1)[1].cpu().numpy() 111 | lbl_true = target.data.cpu().numpy() 112 | train_metrics.update(lbl_true, lbl_pred) 113 | 114 | acc, acc_cls, mean_iou, fwavacc, _ = train_metrics.get_scores() 115 | metrics = [acc, acc_cls, mean_iou, fwavacc] 116 | 117 | with open(osp.join(self.out, 'log.csv'), 'a') as f: 118 | elapsed_time = ( 119 | datetime.datetime.now(pytz.timezone('UTC')) - 120 | self.timestamp_start).total_seconds() 121 | log = [self.epoch] + [train_loss_meter.avg] + \ 122 | metrics + [''] * 5 + [elapsed_time] 123 | log = map(str, log) 124 | f.write(','.join(log) + '\n') 125 | 126 | if self.scheduler: 127 | self.scheduler.step() 128 | if self.epoch % self.val_epoch == 0 or self.epoch == 1: 129 | lr = self.optim.param_groups[0]['lr'] 130 | print(f'\nCurrent base learning rate of epoch {self.epoch}: {lr:.7f}') 131 | 132 | train_loss_meter.reset() 133 | train_metrics.reset() 134 | 135 | def validate(self): 136 | 137 | visualizations = [] 138 | val_metrics = runningScore(self.n_classes) 139 | val_loss_meter = averageMeter() 140 | 141 | with torch.no_grad(): 142 | self.model.eval() 143 | for data, target in tqdm.tqdm( 144 | self.val_loader, total=len(self.val_loader), 145 | desc=f'Valid epoch={self.epoch}', ncols=80, leave=False): 146 | 147 | data, target = data.to(self.device), target.to(self.device) 148 | 149 | score = self.model(data) 150 | 151 | weight = self.val_loader.dataset.class_weight 152 | if weight: 153 | weight = torch.Tensor(weight).to(self.device) 154 | 155 | # target = resize_labels(target, (score.size()[2], score.size()[3])) 156 | # target = target.to(self.device) 157 | loss = CrossEntropyLoss(score, target, weight=weight, reduction='mean', ignore_index=-1) 158 | loss_data = loss.data.item() 159 | if np.isnan(loss_data): 160 | raise ValueError('loss is nan while validating') 161 | 162 | val_loss_meter.update(loss_data) 163 | 164 | # if not isinstance(score, tuple): 165 | # lbl_pred = score.data.max(1)[1].cpu().numpy() 166 | # else: 167 | # lbl_pred = score[-1].data.max(1)[1].cpu().numpy() 168 | 169 | # lbl_pred, lbl_true = get_multiscale_results(score, target, upsample_logits=False) 170 | imgs = data.data.cpu() 171 | if isinstance(score, tuple): 172 | lbl_pred = score[-1].data.max(1)[1].cpu().numpy() 173 | else: 174 | lbl_pred = score.data.max(1)[1].cpu().numpy() 175 | lbl_true = target.data.cpu() 176 | for img, lt, lp in zip(imgs, lbl_true, lbl_pred): 177 | img, lt = self.val_loader.dataset.untransform(img, lt) 178 | val_metrics.update(lt, lp) 179 | # img = Image.fromarray(img).resize((lt.shape[1], lt.shape[0]), Image.BILINEAR) 180 | # img = np.array(img) 181 | if len(visualizations) < 9: 182 | viz = visualize_segmentation( 183 | lbl_pred=lp, lbl_true=lt, img=img, 184 | n_classes=self.n_classes, dataloader=self.train_loader) 185 | visualizations.append(viz) 186 | 187 | acc, acc_cls, mean_iou, fwavacc, _ = val_metrics.get_scores() 188 | metrics = [acc, acc_cls, mean_iou, fwavacc] 189 | 190 | print(f'\nEpoch: {self.epoch}', f'loss: {val_loss_meter.avg}, mIoU: {mean_iou}') 191 | 192 | out = osp.join(self.out, 'visualization_viz') 193 | if not osp.exists(out): 194 | os.makedirs(out) 195 | out_file = osp.join(out, 'epoch{:0>5d}.jpg'.format(self.epoch)) 196 | scipy.misc.imsave(out_file, get_tile_image(visualizations)) 197 | 198 | with open(osp.join(self.out, 'log.csv'), 'a') as f: 199 | 
elapsed_time = ( 200 | datetime.datetime.now(pytz.timezone('UTC')) - 201 | self.timestamp_start).total_seconds() 202 | log = [self.epoch] + [''] * 5 + \ 203 | [val_loss_meter.avg] + metrics + [elapsed_time] 204 | log = map(str, log) 205 | f.write(','.join(log) + '\n') 206 | 207 | mean_iu = metrics[2] 208 | is_best = mean_iu > self.best_mean_iu 209 | if is_best: 210 | self.best_mean_iu = mean_iu 211 | torch.save({ 212 | 'epoch': self.epoch, 213 | 'arch': self.model.__class__.__name__, 214 | 'optim_state_dict': self.optim.state_dict(), 215 | 'model_state_dict': self.model.state_dict(), 216 | 'best_mean_iu': self.best_mean_iu, 217 | }, osp.join(self.out, 'checkpoint.pth.tar')) 218 | if is_best: 219 | shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'), 220 | osp.join(self.out, 'model_best.pth.tar')) 221 | 222 | val_loss_meter.reset() 223 | val_metrics.reset() 224 | 225 | def train(self): 226 | self.iter = 0 227 | for epoch in tqdm.trange(self.epoch, self.epochs + 1, 228 | desc='Train', ncols=80): 229 | self.epoch = epoch 230 | self.train_epoch() 231 | 232 | learning_curve(osp.join(self.out, 'log.csv')) 233 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import pandas 5 | import seaborn 6 | import skimage 7 | import skimage.color 8 | import skimage.transform 9 | from torch.optim import lr_scheduler 10 | 11 | # Adapted from https://github.com/wkentaro/fcn/blob/master/fcn/utils.py 12 | 13 | # ----------------------------------------------------------------------------- 14 | # Visualization 15 | # ----------------------------------------------------------------------------- 16 | 17 | 18 | def centerize(src, dst_shape, margin_color=None): 19 | """Centerize image for specified image size 20 | @param src: image to centerize 21 | @param dst_shape: image shape (height, width) or (height, width, channel) 22 | """ 23 | if src.shape[:2] == dst_shape[:2]: 24 | return src 25 | centerized = np.zeros(dst_shape, dtype=src.dtype) 26 | if margin_color: 27 | centerized[:, :] = margin_color 28 | pad_vertical, pad_horizontal = 0, 0 29 | h, w = src.shape[:2] 30 | dst_h, dst_w = dst_shape[:2] 31 | if h < dst_h: 32 | pad_vertical = (dst_h - h) // 2 33 | if w < dst_w: 34 | pad_horizontal = (dst_w - w) // 2 35 | centerized[pad_vertical:pad_vertical + h, pad_horizontal:pad_horizontal + 36 | w] = src 37 | return centerized 38 | 39 | 40 | def _tile_images(imgs, tile_shape, concatenated_image): 41 | """Concatenate images whose sizes are same. 42 | @param imgs: image list which should be concatenated 43 | @param tile_shape: shape for which images should be concatenated 44 | @param concatenated_image: returned image. 45 | if it is None, new image will be created. 
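    Note: tile_shape is interpreted as (rows, columns); with e.g. (3, 3) up to nine
        equally sized images are placed row by row and any unused cells are left black.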
46 | """ 47 | y_num, x_num = tile_shape 48 | one_width = imgs[0].shape[1] 49 | one_height = imgs[0].shape[0] 50 | if concatenated_image is None: 51 | if len(imgs[0].shape) == 3: 52 | n_channels = imgs[0].shape[2] 53 | assert all(im.shape[2] == n_channels for im in imgs) 54 | concatenated_image = np.zeros( 55 | (one_height * y_num, one_width * x_num, n_channels), 56 | dtype=np.uint8, 57 | ) 58 | else: 59 | concatenated_image = np.zeros( 60 | (one_height * y_num, one_width * x_num), dtype=np.uint8) 61 | for y in range(y_num): 62 | for x in range(x_num): 63 | i = x + y * x_num 64 | if i >= len(imgs): 65 | pass 66 | else: 67 | concatenated_image[y * one_height:(y + 1) * one_height, x * 68 | one_width:(x + 1) * one_width] = imgs[i] 69 | return concatenated_image 70 | 71 | 72 | def get_tile_image(imgs, tile_shape=None, result_img=None, margin_color=None): 73 | """Concatenate images whose sizes are different. 74 | @param imgs: image list which should be concatenated 75 | @param tile_shape: shape for which images should be concatenated 76 | @param result_img: numpy array to put result image 77 | """ 78 | 79 | def resize(*args, **kwargs): 80 | return skimage.transform.resize(*args, **kwargs) 81 | 82 | def get_tile_shape(img_num): 83 | x_num = 0 84 | y_num = int(math.sqrt(img_num)) 85 | while x_num * y_num < img_num: 86 | x_num += 1 87 | return y_num, x_num 88 | 89 | if tile_shape is None: 90 | tile_shape = get_tile_shape(len(imgs)) 91 | 92 | # get max tile size to which each image should be resized 93 | max_height, max_width = np.inf, np.inf 94 | for img in imgs: 95 | max_height = min([max_height, img.shape[0]]) 96 | max_width = min([max_width, img.shape[1]]) 97 | 98 | # resize and concatenate images 99 | for i, img in enumerate(imgs): 100 | h, w = img.shape[:2] 101 | dtype = img.dtype 102 | h_scale, w_scale = max_height / h, max_width / w 103 | scale = min([h_scale, w_scale]) 104 | h, w = int(scale * h), int(scale * w) 105 | img = resize( 106 | image=img, 107 | output_shape=(h, w), 108 | mode='reflect', 109 | preserve_range=True, 110 | anti_aliasing=True, 111 | ).astype(dtype) 112 | if len(img.shape) == 3: 113 | img = centerize(img, (max_height, max_width, 3), margin_color) 114 | else: 115 | img = centerize(img, (max_height, max_width), margin_color) 116 | imgs[i] = img 117 | return _tile_images(imgs, tile_shape, result_img) 118 | 119 | 120 | def label2rgb(lbl, dataloader, img=None, n_labels=None, alpha=0.5): 121 | if n_labels is None: 122 | n_labels = lbl.max() + 1 # +1 for bg_label 0 123 | 124 | cmap = dataloader.dataset.getpalette() 125 | # cmap = getpalette(n_labels) 126 | # cmap = np.array(cmap).reshape([-1, 3]).astype(np.uint8) 127 | 128 | lbl_viz = cmap[lbl] 129 | lbl_viz[lbl == -1] = (0, 0, 0) # unlabeled 130 | 131 | if img is not None: 132 | 133 | # img_gray = skimage.color.rgb2gray(img) 134 | # img_gray = skimage.color.gray2rgb(img_gray) 135 | # img_gray *= 255 136 | lbl_viz = alpha * lbl_viz + (1 - alpha) * img 137 | lbl_viz = lbl_viz.astype(np.uint8) 138 | 139 | return lbl_viz 140 | 141 | 142 | def visualize_segmentation(**kwargs): 143 | """Visualize segmentation. 144 | Parameters 145 | ---------- 146 | img: ndarray 147 | Input image to predict label. 148 | lbl_true: ndarray 149 | Ground truth of the label. 150 | lbl_pred: ndarray 151 | Label predicted. 152 | n_class: int 153 | Number of classes. 154 | label_names: dict or list 155 | Names of each label value. 156 | Key or index is label_value and value is its name. 
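    dataloader: torch.utils.data.DataLoader
        Loader whose dataset implements getpalette(); the palette is used to
        colorize the label maps.

    A typical call, mirroring trainer.py (note that the kwargs actually consumed
    below are n_classes and dataloader rather than n_class and label_names):

        viz = visualize_segmentation(lbl_pred=lp, lbl_true=lt, img=img,
                                     n_classes=21, dataloader=val_loader)  # 21 = VOC, illustrative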
157 | Returns 158 | ------- 159 | img_array: ndarray 160 | Visualized image. 161 | """ 162 | img = kwargs.pop('img', None) 163 | lbl_true = kwargs.pop('lbl_true', None) 164 | lbl_pred = kwargs.pop('lbl_pred', None) 165 | n_class = kwargs.pop('n_classes', None) 166 | dataloader = kwargs.pop('dataloader', None) 167 | if kwargs: 168 | raise RuntimeError('Unexpected keys in kwargs: {}'.format( 169 | kwargs.keys())) 170 | 171 | if lbl_true is None and lbl_pred is None: 172 | raise ValueError('lbl_true or lbl_pred must be not None.') 173 | 174 | mask_unlabeled = None 175 | viz_unlabeled = None 176 | if lbl_true is not None: 177 | mask_unlabeled = lbl_true == -1 178 | # lbl_true[mask_unlabeled] = 0 179 | viz_unlabeled = (np.zeros((lbl_true.shape[0], lbl_true.shape[1], 180 | 3))).astype(np.uint8) 181 | # if lbl_pred is not None: 182 | # lbl_pred[mask_unlabeled] = 0 183 | 184 | vizs = [] 185 | 186 | if lbl_true is not None: 187 | viz_trues = [ 188 | img, 189 | label2rgb(lbl_true, dataloader, n_labels=n_class), 190 | label2rgb(lbl_true, dataloader, img, n_labels=n_class), 191 | ] 192 | viz_trues[1][mask_unlabeled] = viz_unlabeled[mask_unlabeled] 193 | viz_trues[2][mask_unlabeled] = viz_unlabeled[mask_unlabeled] 194 | vizs.append(get_tile_image(viz_trues, (1, 3))) 195 | 196 | if lbl_pred is not None: 197 | viz_preds = [ 198 | img, 199 | label2rgb(lbl_pred, dataloader, n_labels=n_class), 200 | label2rgb(lbl_pred, dataloader, img, n_labels=n_class), 201 | ] 202 | if mask_unlabeled is not None and viz_unlabeled is not None: 203 | viz_preds[1][mask_unlabeled] = viz_unlabeled[mask_unlabeled] 204 | viz_preds[2][mask_unlabeled] = viz_unlabeled[mask_unlabeled] 205 | vizs.append(get_tile_image(viz_preds, (1, 3))) 206 | 207 | if len(vizs) == 1: 208 | return vizs[0] 209 | elif len(vizs) == 2: 210 | return get_tile_image(vizs, (2, 1)) 211 | else: 212 | raise RuntimeError 213 | 214 | 215 | # ----------------------------------------------------------------------------- 216 | # Utilities 217 | # ----------------------------------------------------------------------------- 218 | 219 | # Adapted from official CycleGAN implementation 220 | 221 | 222 | def get_scheduler(optimizer, opt): 223 | """Return a learning rate scheduler 224 | Parameters: 225 | optimizer -- the optimizer of the network 226 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  227 | opt.lr_policy is the name of learning rate policy: linear | poly | step | plateau | cosine 228 | For 'linear', we keep the same learning rate for the first epochs 229 | and linearly decay the rate to zero over the next epochs. 230 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 231 | See https://pytorch.org/docs/stable/optim.html for more details. 
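    In this project `opt` is simply the argparse Namespace built in train.py, so the
    attributes used here are opt.lr_policy, opt.epochs, opt.lr_power and
    opt.lr_decay_step (the 'linear' branch additionally expects opt.niter_decay,
    which train.py does not define).

    A minimal sketch of the intended use (argument names as defined in train.py):

        optim = get_optimizer(args, model)
        scheduler = get_scheduler(optim, args)  # e.g. args.lr_policy='poly', args.lr_power=0.9
        # Trainer.train_epoch() then calls scheduler.step() once per epoch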
232 | """ 233 | if opt.lr_policy == 'linear': 234 | 235 | def lambda_rule(epoch): 236 | lr = 1.0 - max(0, 237 | epoch + 1 - opt.epochs) / float(opt.niter_decay + 1) 238 | return lr 239 | 240 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 241 | elif opt.lr_policy == 'poly': 242 | 243 | def lambda_rule(epoch): 244 | lr = (1 - epoch / opt.epochs)**opt.lr_power 245 | return lr 246 | 247 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 248 | elif opt.lr_policy == 'step': 249 | scheduler = lr_scheduler.StepLR( 250 | optimizer, step_size=opt.lr_decay_step, gamma=0.1) 251 | elif opt.lr_policy == 'plateau': 252 | scheduler = lr_scheduler.ReduceLROnPlateau( 253 | optimizer, mode='min', factor=0.2, threshold=1e-4, patience=5) 254 | elif opt.lr_policy == 'cosine': 255 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs) 256 | elif opt.lr_policy is None: 257 | scheduler = None 258 | else: 259 | return NotImplementedError( 260 | f'learning rate policy {opt.lr_policy} is not implemented') 261 | return scheduler 262 | 263 | 264 | # Adapted from: 265 | # https://github.com/wkentaro/pytorch-fcn/blob/master/examples/voc/learning_curve.py 266 | 267 | 268 | def learning_curve(log_file): 269 | print(f'==> Plotting log file: {log_file}') 270 | 271 | df = pandas.read_csv(log_file) 272 | 273 | colors = ['red', 'green', 'blue', 'purple', 'orange'] 274 | colors = seaborn.xkcd_palette(colors) 275 | 276 | plt.figure(figsize=(20, 6), dpi=300) 277 | 278 | row_min = df.min() 279 | row_max = df.max() 280 | 281 | # initialize DataFrame for train 282 | columns = [ 283 | 'epoch', 284 | 'train/loss', 285 | 'train/acc', 286 | 'train/acc_cls', 287 | 'train/mean_iu', 288 | 'train/fwavacc', 289 | ] 290 | df_train = df[columns] 291 | # if hasattr(df_train, 'rolling'): 292 | # df_train = df_train.rolling(window=10).mean() 293 | # else: 294 | # df_train = pandas.rolling_mean(df_train, window=10) 295 | df_train = df_train.dropna() 296 | 297 | # initialize DataFrame for val 298 | columns = [ 299 | 'epoch', 300 | 'valid/loss', 301 | 'valid/acc', 302 | 'valid/acc_cls', 303 | 'valid/mean_iu', 304 | 'valid/fwavacc', 305 | ] 306 | df_valid = df[columns] 307 | df_valid = df_valid.dropna() 308 | 309 | data_frames = {'train': df_train, 'valid': df_valid} 310 | 311 | n_row = 2 312 | n_col = 2 313 | for i, split in enumerate(['train', 'valid']): 314 | df_split = data_frames[split] 315 | 316 | # loss 317 | plt.subplot(n_row, n_col, i * n_col + 1) 318 | plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0)) 319 | plt.plot( 320 | df_split['epoch'], 321 | df_split[f'{split}/loss'], 322 | '-', 323 | markersize=1, 324 | color=colors[0], 325 | alpha=.5, 326 | label=f'{split} loss') 327 | plt.xlim((1, row_max['epoch'])) 328 | plt.ylim( 329 | min(df_split[f'{split}/loss']), max(df_split[f'{split}/loss'])) 330 | plt.xlabel('epoch') 331 | plt.ylabel(f'{split} loss') 332 | 333 | # loss (log) 334 | # plt.subplot(n_row, n_col, i * n_col + 2) 335 | # plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0)) 336 | # plt.semilogy(df_split['epoch'], df_split[f'{split}/loss'], 337 | # '-', markersize=1, color=colors[0], alpha=.5, 338 | # label=f'{split} loss') 339 | # plt.xlim((1, row_max['epoch'])) 340 | # plt.ylim(min(df_split[f'{split}/loss']), max(df_split[f'{split}/loss'])) 341 | # plt.xlabel('epoch') 342 | # plt.ylabel('f{split} loss (log)') 343 | 344 | # lbl accuracy 345 | plt.subplot(n_row, n_col, i * n_col + 2) 346 | plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0)) 
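        # second panel for this split: overall accuracy, per-class accuracy,
        # mean IU and fwav accuracy share one set of axes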
347 | plt.plot( 348 | df_split['epoch'], 349 | df_split[f'{split}/acc'], 350 | '-', 351 | markersize=1, 352 | color=colors[1], 353 | alpha=.5, 354 | label=f'{split} accuracy') 355 | plt.plot( 356 | df_split['epoch'], 357 | df_split[f'{split}/acc_cls'], 358 | '-', 359 | markersize=1, 360 | color=colors[2], 361 | alpha=.5, 362 | label=f'{split} accuracy class') 363 | plt.plot( 364 | df_split['epoch'], 365 | df_split[f'{split}/mean_iu'], 366 | '-', 367 | markersize=1, 368 | color=colors[3], 369 | alpha=.5, 370 | label=f'{split} mean IU') 371 | plt.plot( 372 | df_split['epoch'], 373 | df_split[f'{split}/fwavacc'], 374 | '-', 375 | markersize=1, 376 | color=colors[4], 377 | alpha=.5, 378 | label=f'{split} fwav accuracy') 379 | plt.legend() 380 | plt.xlim((1, row_max['epoch'])) 381 | plt.ylim((0, 1)) 382 | plt.xlabel('epoch') 383 | plt.ylabel(f'{split} label accuracy') 384 | 385 | # out_file = osp.splitext(log_file)[0] + '.png' 386 | out_file = log_file[:-4] + '.png' 387 | plt.savefig(out_file) 388 | print(f'==> Wrote figure to: {out_file}') 389 | --------------------------------------------------------------------------------