├── .gitignore
├── result
│   ├── result.gif
│   └── trials.png
├── requirements.txt
├── script
│   └── install_packages.sh
├── README.md
└── python
    ├── Cityscapes_loader.py
    ├── CamVid_loader.py
    ├── CamVid_utils.py
    ├── train.py
    ├── Cityscapes_utils.py
    └── fcn.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | **/.DS_Store
3 | **/__pycache__
4 | CamVid/
5 | CityScapes/
6 |
--------------------------------------------------------------------------------
/result/result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pochih/FCN-pytorch/HEAD/result/result.gif
--------------------------------------------------------------------------------
/result/trials.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pochih/FCN-pytorch/HEAD/result/trials.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==2.0.2
2 | numpy==1.13.3
3 | scipy==0.19.1
4 | torchvision==0.1.9
5 |
--------------------------------------------------------------------------------
/script/install_packages.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | pip3 install -r requirements.txt
3 | pip3 install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
3 | ## 🚘 The easiest implementation of fully convolutional networks
4 |
5 | - Task: __semantic segmentation__, a core perception task for automated driving
6 |
7 | - The model is based on the CVPR '15 best-paper honorable mention [Fully Convolutional Networks for Semantic Segmentation](https://arxiv.org/abs/1411.4038)
8 |
9 | ## Results
10 | ### Trials
11 | ![trials](result/trials.png)
12 |
13 | ### Training Procedures
14 | ![training procedure](result/result.gif)
15 |
16 |
17 | ## Performance
18 |
19 | I trained on two popular benchmark datasets: CamVid and Cityscapes
20 |
21 | |dataset|n_class|pixel accuracy|
22 | |---|---|---|
23 | |Cityscapes|20|96%|
24 | |CamVid|32|93%|
25 |
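Here, pixel accuracy is the fraction of pixels whose predicted class matches the ground-truth label, averaged over the validation images (per-class IoUs are logged as well); this mirrors `pixel_acc` in `python/train.py`:

```python
def pixel_acc(pred, target):
    # pred, target: (H, W) arrays of class indices
    correct = (pred == target).sum()
    total = (target == target).sum()
    return correct / total
```
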
26 | ## Training
27 |
28 | ### Install packages
29 | ```bash
30 | pip3 install -r requirements.txt
31 | ```
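
Note that `requirements.txt` does not include `torch` itself; `script/install_packages.sh` installs the requirements plus the CUDA 8 / Python 3.6 Linux wheel of PyTorch 0.2.0:

```bash
sh script/install_packages.sh
```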
32 |
33 | or download PyTorch 0.2.0 manually from [pytorch.org](https://pytorch.org)
34 |
35 | and download the [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) dataset (recommended) or the [Cityscapes](https://www.cityscapes-dataset.com/) dataset
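
For CamVid, the preprocessing script expects the layout below (these paths are the globals at the top of `python/CamVid_utils.py`):

```
CamVid/
├── 701_StillsRaw_full/       # raw images
├── LabeledApproved_full/     # color-coded label images
└── label_colors.txt          # color-to-class-name mapping
```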
36 |
37 | ### Run the code
38 | - the default dataset is CamVid
39 |
40 | create a directory named "CamVid", put the data into it, then run the scripts below (`CamVid_utils.py` splits train/val and converts the color-coded labels into class-index `.npy` files):
41 | ```bash
42 | python3 python/CamVid_utils.py
43 | python3 python/train.py CamVid
44 | ```
45 |
46 | - or train on CityScapes
47 |
48 | create a directory named "CityScapes", put the data into it, then run:
49 | ```bash
50 | python3 python/Cityscapes_utils.py
51 | python3 python/train.py CityScapes
52 | ```
53 |
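### Inference

`train.py` saves the whole model to `models/<configs>` after every epoch via `torch.save(fcn_model, model_path)`, so a checkpoint can be reloaded directly. A minimal sketch, assuming a preprocessed `(3, H, W)` image tensor `img` (BGR, mean-subtracted, exactly as in the loaders) and a hypothetical checkpoint name:

```python
import torch
from torch.autograd import Variable

# hypothetical file name; the real one is the `configs` string built in train.py
model = torch.load("models/FCNs-BCEWithLogits_...")
model.eval()

output = model(Variable(img.unsqueeze(0)))       # (1, n_class, H, W) scores
pred = output.data.cpu().numpy().argmax(axis=1)  # (1, H, W) class-index map
```
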
54 | ## Author
55 | Po-Chih Huang / [@pochih](https://pochih.github.io/)
56 |
--------------------------------------------------------------------------------
/python/Cityscapes_loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function
4 |
5 | from matplotlib import pyplot as plt
6 | import pandas as pd
7 | import numpy as np
8 | import scipy.misc
9 | import random
10 | import os
11 |
12 | import torch
13 | from torch.utils.data import Dataset, DataLoader
14 | from torchvision import utils
15 |
16 |
17 | root_dir = "CityScapes/"
18 | train_file = os.path.join(root_dir, "train.csv")
19 | val_file = os.path.join(root_dir, "val.csv")
20 |
21 | num_class = 20
22 | means = np.array([103.939, 116.779, 123.68]) / 255. # mean of three channels in the order of BGR
23 | h, w = 1024, 2048
24 | train_h = int(h/2) # 512
25 | train_w = int(w/2) # 1024
26 | val_h = h # 1024
27 | val_w = w # 2048
28 |
29 |
30 | class CityscapesDataset(Dataset):
31 |
32 | def __init__(self, csv_file, phase, n_class=num_class, crop=False, flip_rate=0.):
33 | self.data = pd.read_csv(csv_file)
34 | self.means = means
35 | self.n_class = n_class
36 |
37 | self.flip_rate = flip_rate
38 | self.crop = crop
39 | if phase == 'train':
40 | self.crop = True
41 | self.flip_rate = 0.5
42 | self.new_h = train_h
43 | self.new_w = train_w
44 |
45 | def __len__(self):
46 | return len(self.data)
47 |
48 | def __getitem__(self, idx):
49 |         img_name = self.data.iloc[idx, 0]
50 |         img = scipy.misc.imread(img_name, mode='RGB')
51 |         label_name = self.data.iloc[idx, 1]
52 | label = np.load(label_name)
53 |
54 | if self.crop:
55 | h, w, _ = img.shape
56 | top = random.randint(0, h - self.new_h)
57 | left = random.randint(0, w - self.new_w)
58 | img = img[top:top + self.new_h, left:left + self.new_w]
59 | label = label[top:top + self.new_h, left:left + self.new_w]
60 |
61 | if random.random() < self.flip_rate:
62 | img = np.fliplr(img)
63 | label = np.fliplr(label)
64 |
65 | # reduce mean
66 | img = img[:, :, ::-1] # switch to BGR
67 | img = np.transpose(img, (2, 0, 1)) / 255.
68 | img[0] -= self.means[0]
69 | img[1] -= self.means[1]
70 | img[2] -= self.means[2]
71 |
72 | # convert to tensor
73 | img = torch.from_numpy(img.copy()).float()
74 | label = torch.from_numpy(label.copy()).long()
75 |
76 | # create one-hot encoding
77 | h, w = label.size()
78 | target = torch.zeros(self.n_class, h, w)
79 | for c in range(self.n_class):
80 | target[c][label == c] = 1
81 |
82 | sample = {'X': img, 'Y': target, 'l': label}
83 |
84 | return sample
85 |
86 |
87 | def show_batch(batch):
88 | img_batch = batch['X']
89 | img_batch[:,0,...].add_(means[0])
90 | img_batch[:,1,...].add_(means[1])
91 | img_batch[:,2,...].add_(means[2])
92 | batch_size = len(img_batch)
93 |
94 | grid = utils.make_grid(img_batch)
95 | plt.imshow(grid.numpy()[::-1].transpose((1, 2, 0)))
96 |
97 | plt.title('Batch from dataloader')
98 |
99 |
100 | if __name__ == "__main__":
101 |     train_data = CityscapesDataset(csv_file=train_file, phase='train')
102 |
103 | # show a batch
104 | batch_size = 4
105 | for i in range(batch_size):
106 | sample = train_data[i]
107 | print(i, sample['X'].size(), sample['Y'].size())
108 |
109 | dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=False, num_workers=4)
110 |
111 | for i, batch in enumerate(dataloader):
112 | print(i, batch['X'].size(), batch['Y'].size())
113 |
114 | # observe 4th batch
115 | if i == 3:
116 | plt.figure()
117 | show_batch(batch)
118 | plt.axis('off')
119 | plt.ioff()
120 | plt.show()
121 | break
122 |
--------------------------------------------------------------------------------
/python/CamVid_loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function
4 |
5 | from matplotlib import pyplot as plt
6 | import pandas as pd
7 | import numpy as np
8 | import scipy.misc
9 | import random
10 | import os
11 |
12 | import torch
13 | from torch.utils.data import Dataset, DataLoader
14 | from torchvision import utils
15 |
16 |
17 | root_dir = "CamVid/"
18 | train_file = os.path.join(root_dir, "train.csv")
19 | val_file = os.path.join(root_dir, "val.csv")
20 |
21 | num_class = 32
22 | means = np.array([103.939, 116.779, 123.68]) / 255. # mean of three channels in the order of BGR
23 | h, w = 720, 960
24 | train_h = int(h * 2 / 3) # 480
25 | train_w = int(w * 2 / 3) # 640
26 | val_h = int(h/32) * 32 # 704
27 | val_w = w # 960
28 |
29 |
30 | class CamVidDataset(Dataset):
31 |
32 | def __init__(self, csv_file, phase, n_class=num_class, crop=True, flip_rate=0.5):
33 | self.data = pd.read_csv(csv_file)
34 | self.means = means
35 | self.n_class = n_class
36 |
37 | self.flip_rate = flip_rate
38 | self.crop = crop
39 | if phase == 'train':
40 | self.new_h = train_h
41 | self.new_w = train_w
42 | elif phase == 'val':
43 | self.flip_rate = 0.
44 | self.crop = False
45 | self.new_h = val_h
46 | self.new_w = val_w
47 |
48 |
49 | def __len__(self):
50 | return len(self.data)
51 |
52 | def __getitem__(self, idx):
53 |         img_name = self.data.iloc[idx, 0]
54 |         img = scipy.misc.imread(img_name, mode='RGB')
55 |         label_name = self.data.iloc[idx, 1]
56 | label = np.load(label_name)
57 |
58 | if self.crop:
59 | h, w, _ = img.shape
60 | top = random.randint(0, h - self.new_h)
61 | left = random.randint(0, w - self.new_w)
62 | img = img[top:top + self.new_h, left:left + self.new_w]
63 | label = label[top:top + self.new_h, left:left + self.new_w]
64 |
65 | if random.random() < self.flip_rate:
66 | img = np.fliplr(img)
67 | label = np.fliplr(label)
68 |
69 | # reduce mean
70 | img = img[:, :, ::-1] # switch to BGR
71 | img = np.transpose(img, (2, 0, 1)) / 255.
72 | img[0] -= self.means[0]
73 | img[1] -= self.means[1]
74 | img[2] -= self.means[2]
75 |
76 | # convert to tensor
77 | img = torch.from_numpy(img.copy()).float()
78 | label = torch.from_numpy(label.copy()).long()
79 |
80 | # create one-hot encoding
81 | h, w = label.size()
82 | target = torch.zeros(self.n_class, h, w)
83 | for c in range(self.n_class):
84 | target[c][label == c] = 1
85 |
86 | sample = {'X': img, 'Y': target, 'l': label}
87 |
88 | return sample
89 |
90 |
91 | def show_batch(batch):
92 | img_batch = batch['X']
93 | img_batch[:,0,...].add_(means[0])
94 | img_batch[:,1,...].add_(means[1])
95 | img_batch[:,2,...].add_(means[2])
96 | batch_size = len(img_batch)
97 |
98 | grid = utils.make_grid(img_batch)
99 | plt.imshow(grid.numpy()[::-1].transpose((1, 2, 0)))
100 |
101 | plt.title('Batch from dataloader')
102 |
103 |
104 | if __name__ == "__main__":
105 | train_data = CamVidDataset(csv_file=train_file, phase='train')
106 |
107 | # show a batch
108 | batch_size = 4
109 | for i in range(batch_size):
110 | sample = train_data[i]
111 | print(i, sample['X'].size(), sample['Y'].size())
112 |
113 | dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
114 |
115 | for i, batch in enumerate(dataloader):
116 | print(i, batch['X'].size(), batch['Y'].size())
117 |
118 | # observe 4th batch
119 | if i == 3:
120 | plt.figure()
121 | show_batch(batch)
122 | plt.axis('off')
123 | plt.ioff()
124 | plt.show()
125 | break
126 |
--------------------------------------------------------------------------------
/python/CamVid_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function
4 |
5 | from matplotlib import pyplot as plt
6 | import matplotlib.image as mpimg
7 | import numpy as np
8 | import scipy.misc
9 | import random
10 | import os
11 |
12 |
13 | #############################
14 | # global variables #
15 | #############################
16 | root_dir = "CamVid/"
17 | data_dir = os.path.join(root_dir, "701_StillsRaw_full") # train data
18 | label_dir = os.path.join(root_dir, "LabeledApproved_full") # train label
19 | label_colors_file = os.path.join(root_dir, "label_colors.txt") # color to label
20 | val_label_file = os.path.join(root_dir, "val.csv") # validation file
21 | train_label_file = os.path.join(root_dir, "train.csv") # train file
22 |
23 | # create dir for label index
24 | label_idx_dir = os.path.join(root_dir, "Labeled_idx")
25 | if not os.path.exists(label_idx_dir):
26 | os.makedirs(label_idx_dir)
27 |
28 | label2color = {}
29 | color2label = {}
30 | label2index = {}
31 | index2label = {}
32 |
33 |
34 | def divide_train_val(val_rate=0.1, shuffle=True, random_seed=None):
35 | data_list = os.listdir(data_dir)
36 | data_len = len(data_list)
37 | val_len = int(data_len * val_rate)
38 |
39 | if random_seed:
40 | random.seed(random_seed)
41 |
42 | if shuffle:
43 | data_idx = random.sample(range(data_len), data_len)
44 | else:
45 | data_idx = list(range(data_len))
46 |
47 | val_idx = [data_list[i] for i in data_idx[:val_len]]
48 | train_idx = [data_list[i] for i in data_idx[val_len:]]
49 |
50 | # create val.csv
51 |     with open(val_label_file, "w") as v:
52 |         v.write("img,label\n")
53 |         for idx, name in enumerate(val_idx):
54 |             if 'png' not in name:
55 |                 continue
56 |             img_name = os.path.join(data_dir, name)
57 |             lab_name = os.path.join(label_idx_dir, name)
58 |             lab_name = lab_name.split(".")[0] + "_L.png.npy"
59 |             v.write("{},{}\n".format(img_name, lab_name))
60 |
61 | # create train.csv
62 |     with open(train_label_file, "w") as t:
63 |         t.write("img,label\n")
64 |         for idx, name in enumerate(train_idx):
65 |             if 'png' not in name:
66 |                 continue
67 |             img_name = os.path.join(data_dir, name)
68 |             lab_name = os.path.join(label_idx_dir, name)
69 |             lab_name = lab_name.split(".")[0] + "_L.png.npy"
70 |             t.write("{},{}\n".format(img_name, lab_name))
71 |
72 |
73 | def parse_label():
74 | # change label to class index
75 | f = open(label_colors_file, "r").read().split("\n")[:-1] # ignore the last empty line
76 | for idx, line in enumerate(f):
77 | label = line.split()[-1]
78 | color = tuple([int(x) for x in line.split()[:-1]])
79 | print(label, color)
80 | label2color[label] = color
81 | color2label[color] = label
82 | label2index[label] = idx
83 | index2label[idx] = label
84 | # rgb = np.zeros((255, 255, 3), dtype=np.uint8)
85 | # rgb[..., 0] = color[0]
86 | # rgb[..., 1] = color[1]
87 | # rgb[..., 2] = color[2]
88 | # imshow(rgb, title=label)
89 |
90 | for idx, name in enumerate(os.listdir(label_dir)):
91 | filename = os.path.join(label_idx_dir, name)
92 | if os.path.exists(filename + '.npy'):
93 | print("Skip %s" % (name))
94 | continue
95 | print("Parse %s" % (name))
96 | img = os.path.join(label_dir, name)
97 | img = scipy.misc.imread(img, mode='RGB')
98 |         height, width, _ = img.shape
99 |
100 |         idx_mat = np.zeros((height, width))
101 |         for h in range(height):
102 |             for w in range(width):
103 |                 color = tuple(img[h, w])
104 |                 try:
105 |                     label = color2label[color]
106 |                     index = label2index[label]
107 |                     idx_mat[h, w] = index
108 |                 except KeyError:
109 |                     print("error: img:%s, h:%d, w:%d" % (name, h, w))
110 |         idx_mat = idx_mat.astype(np.uint8)
111 |         np.save(filename, idx_mat)
112 | print("Finish %s" % (name))
113 |
114 | # test some pixels' label
115 | img = os.path.join(label_dir, os.listdir(label_dir)[0])
116 | img = scipy.misc.imread(img, mode='RGB')
117 | test_cases = [(555, 405), (0, 0), (380, 645), (577, 943)]
118 | test_ans = ['Car', 'Building', 'Truck_Bus', 'Car']
119 | for idx, t in enumerate(test_cases):
120 | color = img[t]
121 | assert color2label[tuple(color)] == test_ans[idx]
122 |
123 |
124 | '''debug function'''
125 | def imshow(img, title=None):
126 | try:
127 | img = mpimg.imread(img)
128 | imgplot = plt.imshow(img)
129 |     except Exception:
130 | plt.imshow(img, interpolation='nearest')
131 |
132 | if title is not None:
133 | plt.title(title)
134 |
135 | plt.show()
136 |
137 |
138 | if __name__ == '__main__':
139 | divide_train_val(random_seed=1)
140 | parse_label()
141 |
--------------------------------------------------------------------------------
/python/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from torch.optim import lr_scheduler
9 | from torch.autograd import Variable
10 | from torch.utils.data import DataLoader
11 |
12 | from fcn import VGGNet, FCN32s, FCN16s, FCN8s, FCNs
13 | from Cityscapes_loader import CityscapesDataset
14 | from CamVid_loader import CamVidDataset
15 |
16 | from matplotlib import pyplot as plt
17 | import numpy as np
18 | import time
19 | import sys
20 | import os
21 |
22 |
23 | n_class    = 32 if sys.argv[1] == 'CamVid' else 20  # CamVid has 32 classes, Cityscapes 20
24 |
25 | batch_size = 6
26 | epochs = 500
27 | lr = 1e-4
28 | momentum = 0
29 | w_decay = 1e-5
30 | step_size = 50
31 | gamma = 0.5
32 | configs = "FCNs-BCEWithLogits_batch{}_epoch{}_RMSprop_scheduler-step{}-gamma{}_lr{}_momentum{}_w_decay{}".format(batch_size, epochs, step_size, gamma, lr, momentum, w_decay)
33 | print("Configs:", configs)
34 |
35 | if sys.argv[1] == 'CamVid':
36 | root_dir = "CamVid/"
37 | else:
38 | root_dir = "CityScapes/"
39 | train_file = os.path.join(root_dir, "train.csv")
40 | val_file = os.path.join(root_dir, "val.csv")
41 |
42 | # create dir for model
43 | model_dir = "models"
44 | if not os.path.exists(model_dir):
45 | os.makedirs(model_dir)
46 | model_path = os.path.join(model_dir, configs)
47 |
48 | use_gpu = torch.cuda.is_available()
49 | num_gpu = list(range(torch.cuda.device_count()))
50 |
51 | if sys.argv[1] == 'CamVid':
52 | train_data = CamVidDataset(csv_file=train_file, phase='train')
53 | else:
54 | train_data = CityscapesDataset(csv_file=train_file, phase='train')
55 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
56 |
57 | if sys.argv[1] == 'CamVid':
58 | val_data = CamVidDataset(csv_file=val_file, phase='val', flip_rate=0)
59 | else:
60 | val_data = CityscapesDataset(csv_file=val_file, phase='val', flip_rate=0)
61 | val_loader = DataLoader(val_data, batch_size=1, num_workers=8)
62 |
63 | vgg_model = VGGNet(requires_grad=True, remove_fc=True)
64 | fcn_model = FCNs(pretrained_net=vgg_model, n_class=n_class)
65 |
66 | if use_gpu:
67 | ts = time.time()
68 | vgg_model = vgg_model.cuda()
69 | fcn_model = fcn_model.cuda()
70 | fcn_model = nn.DataParallel(fcn_model, device_ids=num_gpu)
71 | print("Finish cuda loading, time elapsed {}".format(time.time() - ts))
72 |
73 | criterion = nn.BCEWithLogitsLoss()
74 | optimizer = optim.RMSprop(fcn_model.parameters(), lr=lr, momentum=momentum, weight_decay=w_decay)
75 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)  # decay LR by a factor of 0.5 every 50 epochs
76 |
77 | # create dir for score
78 | score_dir = os.path.join("scores", configs)
79 | if not os.path.exists(score_dir):
80 | os.makedirs(score_dir)
81 | IU_scores = np.zeros((epochs, n_class))
82 | pixel_scores = np.zeros(epochs)
83 |
84 |
85 | def train():
86 | for epoch in range(epochs):
87 | scheduler.step()
88 |         fcn_model.train()  # re-enable training mode; val() switches the model to eval
89 | ts = time.time()
90 | for iter, batch in enumerate(train_loader):
91 | optimizer.zero_grad()
92 |
93 | if use_gpu:
94 | inputs = Variable(batch['X'].cuda())
95 | labels = Variable(batch['Y'].cuda())
96 | else:
97 | inputs, labels = Variable(batch['X']), Variable(batch['Y'])
98 |
99 | outputs = fcn_model(inputs)
100 | loss = criterion(outputs, labels)
101 | loss.backward()
102 | optimizer.step()
103 |
104 | if iter % 10 == 0:
105 | print("epoch{}, iter{}, loss: {}".format(epoch, iter, loss.data[0]))
106 |
107 | print("Finish epoch {}, time elapsed {}".format(epoch, time.time() - ts))
108 | torch.save(fcn_model, model_path)
109 |
110 | val(epoch)
111 |
112 |
113 | def val(epoch):
114 | fcn_model.eval()
115 | total_ious = []
116 | pixel_accs = []
117 | for iter, batch in enumerate(val_loader):
118 | if use_gpu:
119 | inputs = Variable(batch['X'].cuda())
120 | else:
121 | inputs = Variable(batch['X'])
122 |
123 | output = fcn_model(inputs)
124 | output = output.data.cpu().numpy()
125 |
126 | N, _, h, w = output.shape
127 | pred = output.transpose(0, 2, 3, 1).reshape(-1, n_class).argmax(axis=1).reshape(N, h, w)
128 |
129 | target = batch['l'].cpu().numpy().reshape(N, h, w)
130 | for p, t in zip(pred, target):
131 | total_ious.append(iou(p, t))
132 | pixel_accs.append(pixel_acc(p, t))
133 |
134 | # Calculate average IoU
135 | total_ious = np.array(total_ious).T # n_class * val_len
136 | ious = np.nanmean(total_ious, axis=1)
137 | pixel_accs = np.array(pixel_accs).mean()
138 | print("epoch{}, pix_acc: {}, meanIoU: {}, IoUs: {}".format(epoch, pixel_accs, np.nanmean(ious), ious))
139 | IU_scores[epoch] = ious
140 | np.save(os.path.join(score_dir, "meanIU"), IU_scores)
141 | pixel_scores[epoch] = pixel_accs
142 | np.save(os.path.join(score_dir, "meanPixel"), pixel_scores)
143 |
144 |
145 | # borrowed and modified from https://github.com/Kaixhin/FCN-semantic-segmentation/blob/master/main.py
146 | # Calculates class intersections over unions
147 | def iou(pred, target):
148 | ious = []
149 | for cls in range(n_class):
150 | pred_inds = pred == cls
151 | target_inds = target == cls
152 | intersection = pred_inds[target_inds].sum()
153 | union = pred_inds.sum() + target_inds.sum() - intersection
154 | if union == 0:
155 | ious.append(float('nan')) # if there is no ground truth, do not include in evaluation
156 | else:
157 | ious.append(float(intersection) / max(union, 1))
158 | # print("cls", cls, pred_inds.sum(), target_inds.sum(), intersection, float(intersection) / max(union, 1))
159 | return ious
160 |
161 |
162 | def pixel_acc(pred, target):
163 | correct = (pred == target).sum()
164 | total = (target == target).sum()
165 | return correct / total
166 |
167 |
168 | if __name__ == "__main__":
169 | val(0) # show the accuracy before training
170 | train()
171 |
--------------------------------------------------------------------------------
/python/Cityscapes_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function
4 |
5 | from collections import namedtuple
6 | from matplotlib import pyplot as plt
7 | import matplotlib.image as mpimg
8 | import numpy as np
9 | import scipy.misc
10 | import random
11 | import os
12 |
13 |
14 | #############################
15 | # global variables #
16 | #############################
17 | root_dir = "CityScapes/"
18 |
19 | label_dir = os.path.join(root_dir, "gtFine")
20 | train_dir = os.path.join(label_dir, "train")
21 | val_dir = os.path.join(label_dir, "val")
22 | test_dir = os.path.join(label_dir, "test")
23 |
24 | # create dir for label index
25 | label_idx_dir = os.path.join(root_dir, "Labeled_idx")
26 | train_idx_dir = os.path.join(label_idx_dir, "train")
27 | val_idx_dir = os.path.join(label_idx_dir, "val")
28 | test_idx_dir = os.path.join(label_idx_dir, "test")
29 | for d in [train_idx_dir, val_idx_dir, test_idx_dir]:
30 |     if not os.path.exists(d):
31 |         os.makedirs(d)
32 |
33 | train_file = os.path.join(root_dir, "train.csv")
34 | val_file = os.path.join(root_dir, "val.csv")
35 | test_file = os.path.join(root_dir, "test.csv")
36 |
37 | color2index = {}
38 |
39 | Label = namedtuple('Label', [
40 | 'name',
41 | 'id',
42 | 'trainId',
43 | 'category',
44 | 'categoryId',
45 | 'hasInstances',
46 | 'ignoreInEval',
47 | 'color'])
48 |
49 | labels = [
50 | # name id trainId category catId hasInstances ignoreInEval color
51 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
52 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
53 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
54 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
55 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
56 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
57 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
58 | Label( 'road' , 7 , 1 , 'flat' , 1 , False , False , (128, 64,128) ),
59 | Label( 'sidewalk' , 8 , 2 , 'flat' , 1 , False , False , (244, 35,232) ),
60 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
61 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
62 | Label( 'building' , 11 , 3 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
63 | Label( 'wall' , 12 , 4 , 'construction' , 2 , False , False , (102,102,156) ),
64 | Label( 'fence' , 13 , 5 , 'construction' , 2 , False , False , (190,153,153) ),
65 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
66 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
67 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
68 | Label( 'pole' , 17 , 6 , 'object' , 3 , False , False , (153,153,153) ),
69 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
70 | Label( 'traffic light' , 19 , 7 , 'object' , 3 , False , False , (250,170, 30) ),
71 | Label( 'traffic sign' , 20 , 8 , 'object' , 3 , False , False , (220,220, 0) ),
72 | Label( 'vegetation' , 21 , 9 , 'nature' , 4 , False , False , (107,142, 35) ),
73 | Label( 'terrain' , 22 , 10 , 'nature' , 4 , False , False , (152,251,152) ),
74 | Label( 'sky' , 23 , 11 , 'sky' , 5 , False , False , ( 70,130,180) ),
75 | Label( 'person' , 24 , 12 , 'human' , 6 , True , False , (220, 20, 60) ),
76 | Label( 'rider' , 25 , 13 , 'human' , 6 , True , False , (255, 0, 0) ),
77 | Label( 'car' , 26 , 14 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
78 | Label( 'truck' , 27 , 15 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
79 | Label( 'bus' , 28 , 16 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
80 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
81 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
82 | Label( 'train' , 31 , 17 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
83 | Label( 'motorcycle' , 32 , 18 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
84 | Label( 'bicycle' , 33 , 19 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
85 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
86 | ]
87 |
88 |
89 | def parse_label():
90 | # change label to class index
91 |     color2index[(0,0,0)] = 0  # add a void class
92 | for obj in labels:
93 | if obj.ignoreInEval:
94 | continue
95 | idx = obj.trainId
96 | label = obj.name
97 | color = obj.color
98 | color2index[color] = idx
99 |
100 | # parse train, val, test data
101 | for label_dir, index_dir, csv_file in zip([train_dir, val_dir, test_dir], [train_idx_dir, val_idx_dir, test_idx_dir], [train_file, val_file, test_file]):
102 | f = open(csv_file, "w")
103 | f.write("img,label\n")
104 | for city in os.listdir(label_dir):
105 | city_dir = os.path.join(label_dir, city)
106 | city_idx_dir = os.path.join(index_dir, city)
107 | data_dir = city_dir.replace("gtFine", "leftImg8bit")
108 | if not os.path.exists(city_idx_dir):
109 | os.makedirs(city_idx_dir)
110 | for filename in os.listdir(city_dir):
111 | if 'color' not in filename:
112 | continue
113 | lab_name = os.path.join(city_idx_dir, filename)
114 | img_name = filename.split("gtFine")[0] + "leftImg8bit.png"
115 | img_name = os.path.join(data_dir, img_name)
116 | f.write("{},{}.npy\n".format(img_name, lab_name))
117 |
118 | if os.path.exists(lab_name + '.npy'):
119 | print("Skip %s" % (filename))
120 | continue
121 | print("Parse %s" % (filename))
122 | img = os.path.join(city_dir, filename)
123 | img = scipy.misc.imread(img, mode='RGB')
124 |                 height, width, _ = img.shape
125 |
126 |                 idx_mat = np.zeros((height, width))
127 |                 for h in range(height):
128 |                     for w in range(width):
129 |                         color = tuple(img[h, w])
130 |                         try:
131 |                             index = color2index[color]
132 |                             idx_mat[h, w] = index
133 |                         except KeyError:
134 |                             # unknown color: map to the void class (index 0)
135 |                             idx_mat[h, w] = 0
136 |                 idx_mat = idx_mat.astype(np.uint8)
137 |                 np.save(lab_name, idx_mat)
138 | print("Finish %s" % (filename))
139 |
140 |
141 | '''debug function'''
142 | def imshow(img, title=None):
143 | try:
144 | img = mpimg.imread(img)
145 | imgplot = plt.imshow(img)
146 |     except Exception:
147 | plt.imshow(img, interpolation='nearest')
148 |
149 | if title is not None:
150 | plt.title(title)
151 |
152 | plt.show()
153 |
154 |
155 | if __name__ == '__main__':
156 | parse_label()
157 |
--------------------------------------------------------------------------------
/python/fcn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from torchvision import models
9 | from torchvision.models.vgg import VGG
10 |
11 |
12 | class FCN32s(nn.Module):
13 |
14 | def __init__(self, pretrained_net, n_class):
15 | super().__init__()
16 | self.n_class = n_class
17 | self.pretrained_net = pretrained_net
18 | self.relu = nn.ReLU(inplace=True)
19 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
20 | self.bn1 = nn.BatchNorm2d(512)
21 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
22 | self.bn2 = nn.BatchNorm2d(256)
23 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
24 | self.bn3 = nn.BatchNorm2d(128)
25 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
26 | self.bn4 = nn.BatchNorm2d(64)
27 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
28 | self.bn5 = nn.BatchNorm2d(32)
29 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
30 |
31 | def forward(self, x):
32 | output = self.pretrained_net(x)
33 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32)
34 |
35 | score = self.bn1(self.relu(self.deconv1(x5))) # size=(N, 512, x.H/16, x.W/16)
36 | score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8)
37 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4)
38 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2)
39 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W)
40 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1)
41 |
42 | return score # size=(N, n_class, x.H/1, x.W/1)
43 |
44 |
45 | class FCN16s(nn.Module):
46 |
47 | def __init__(self, pretrained_net, n_class):
48 | super().__init__()
49 | self.n_class = n_class
50 | self.pretrained_net = pretrained_net
51 | self.relu = nn.ReLU(inplace=True)
52 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
53 | self.bn1 = nn.BatchNorm2d(512)
54 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
55 | self.bn2 = nn.BatchNorm2d(256)
56 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
57 | self.bn3 = nn.BatchNorm2d(128)
58 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
59 | self.bn4 = nn.BatchNorm2d(64)
60 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
61 | self.bn5 = nn.BatchNorm2d(32)
62 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
63 |
64 | def forward(self, x):
65 | output = self.pretrained_net(x)
66 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32)
67 | x4 = output['x4'] # size=(N, 512, x.H/16, x.W/16)
68 |
69 | score = self.relu(self.deconv1(x5)) # size=(N, 512, x.H/16, x.W/16)
70 | score = self.bn1(score + x4) # element-wise add, size=(N, 512, x.H/16, x.W/16)
71 | score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8)
72 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4)
73 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2)
74 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W)
75 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1)
76 |
77 | return score # size=(N, n_class, x.H/1, x.W/1)
78 |
79 |
80 | class FCN8s(nn.Module):
81 |
82 | def __init__(self, pretrained_net, n_class):
83 | super().__init__()
84 | self.n_class = n_class
85 | self.pretrained_net = pretrained_net
86 | self.relu = nn.ReLU(inplace=True)
87 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
88 | self.bn1 = nn.BatchNorm2d(512)
89 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
90 | self.bn2 = nn.BatchNorm2d(256)
91 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
92 | self.bn3 = nn.BatchNorm2d(128)
93 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
94 | self.bn4 = nn.BatchNorm2d(64)
95 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
96 | self.bn5 = nn.BatchNorm2d(32)
97 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
98 |
99 | def forward(self, x):
100 | output = self.pretrained_net(x)
101 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32)
102 | x4 = output['x4'] # size=(N, 512, x.H/16, x.W/16)
103 | x3 = output['x3'] # size=(N, 256, x.H/8, x.W/8)
104 |
105 | score = self.relu(self.deconv1(x5)) # size=(N, 512, x.H/16, x.W/16)
106 | score = self.bn1(score + x4) # element-wise add, size=(N, 512, x.H/16, x.W/16)
107 | score = self.relu(self.deconv2(score)) # size=(N, 256, x.H/8, x.W/8)
108 | score = self.bn2(score + x3) # element-wise add, size=(N, 256, x.H/8, x.W/8)
109 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4)
110 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2)
111 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W)
112 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1)
113 |
114 | return score # size=(N, n_class, x.H/1, x.W/1)
115 |
116 |
117 | class FCNs(nn.Module):
118 |
119 | def __init__(self, pretrained_net, n_class):
120 | super().__init__()
121 | self.n_class = n_class
122 | self.pretrained_net = pretrained_net
123 | self.relu = nn.ReLU(inplace=True)
124 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
125 | self.bn1 = nn.BatchNorm2d(512)
126 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
127 | self.bn2 = nn.BatchNorm2d(256)
128 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
129 | self.bn3 = nn.BatchNorm2d(128)
130 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
131 | self.bn4 = nn.BatchNorm2d(64)
132 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
133 | self.bn5 = nn.BatchNorm2d(32)
134 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
135 |
136 | def forward(self, x):
137 | output = self.pretrained_net(x)
138 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32)
139 | x4 = output['x4'] # size=(N, 512, x.H/16, x.W/16)
140 | x3 = output['x3'] # size=(N, 256, x.H/8, x.W/8)
141 | x2 = output['x2'] # size=(N, 128, x.H/4, x.W/4)
142 | x1 = output['x1'] # size=(N, 64, x.H/2, x.W/2)
143 |
144 | score = self.bn1(self.relu(self.deconv1(x5))) # size=(N, 512, x.H/16, x.W/16)
145 | score = score + x4 # element-wise add, size=(N, 512, x.H/16, x.W/16)
146 | score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8)
147 | score = score + x3 # element-wise add, size=(N, 256, x.H/8, x.W/8)
148 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4)
149 | score = score + x2 # element-wise add, size=(N, 128, x.H/4, x.W/4)
150 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2)
151 | score = score + x1 # element-wise add, size=(N, 64, x.H/2, x.W/2)
152 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W)
153 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1)
154 |
155 | return score # size=(N, n_class, x.H/1, x.W/1)
156 |
157 |
158 | class VGGNet(VGG):
159 | def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
160 | super().__init__(make_layers(cfg[model]))
161 | self.ranges = ranges[model]
162 |
163 | if pretrained:
164 |             self.load_state_dict(getattr(models, model)(pretrained=True).state_dict())
165 |
166 | if not requires_grad:
167 |             for param in self.parameters():
168 | param.requires_grad = False
169 |
170 | if remove_fc: # delete redundant fully-connected layer params, can save memory
171 | del self.classifier
172 |
173 | if show_params:
174 | for name, param in self.named_parameters():
175 | print(name, param.size())
176 |
177 | def forward(self, x):
178 | output = {}
179 |
180 | # get the output of each maxpooling layer (5 maxpool in VGG net)
181 | for idx in range(len(self.ranges)):
182 | for layer in range(self.ranges[idx][0], self.ranges[idx][1]):
183 | x = self.features[layer](x)
184 | output["x%d"%(idx+1)] = x
185 |
186 | return output
187 |
188 |
189 | ranges = {
190 | 'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)),
191 | 'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
192 | 'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
193 | 'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
194 | }
195 |
196 | # cropped version from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
197 | cfg = {
198 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
199 | 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
200 | 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
201 | 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
202 | }
203 |
204 | def make_layers(cfg, batch_norm=False):
205 | layers = []
206 | in_channels = 3
207 | for v in cfg:
208 | if v == 'M':
209 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
210 | else:
211 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
212 | if batch_norm:
213 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
214 | else:
215 | layers += [conv2d, nn.ReLU(inplace=True)]
216 | in_channels = v
217 | return nn.Sequential(*layers)
218 |
219 |
220 | if __name__ == "__main__":
221 | batch_size, n_class, h, w = 10, 20, 160, 160
222 |
223 | # test output size
224 | vgg_model = VGGNet(requires_grad=True)
225 | input = torch.autograd.Variable(torch.randn(batch_size, 3, 224, 224))
226 | output = vgg_model(input)
227 | assert output['x5'].size() == torch.Size([batch_size, 512, 7, 7])
228 |
229 | fcn_model = FCN32s(pretrained_net=vgg_model, n_class=n_class)
230 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
231 | output = fcn_model(input)
232 | assert output.size() == torch.Size([batch_size, n_class, h, w])
233 |
234 | fcn_model = FCN16s(pretrained_net=vgg_model, n_class=n_class)
235 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
236 | output = fcn_model(input)
237 | assert output.size() == torch.Size([batch_size, n_class, h, w])
238 |
239 | fcn_model = FCN8s(pretrained_net=vgg_model, n_class=n_class)
240 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
241 | output = fcn_model(input)
242 | assert output.size() == torch.Size([batch_size, n_class, h, w])
243 |
244 | fcn_model = FCNs(pretrained_net=vgg_model, n_class=n_class)
245 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
246 | output = fcn_model(input)
247 | assert output.size() == torch.Size([batch_size, n_class, h, w])
248 |
249 | print("Pass size check")
250 |
251 | # test a random batch, loss should decrease
252 | fcn_model = FCNs(pretrained_net=vgg_model, n_class=n_class)
253 | criterion = nn.BCELoss()
254 | optimizer = optim.SGD(fcn_model.parameters(), lr=1e-3, momentum=0.9)
255 | input = torch.autograd.Variable(torch.randn(batch_size, 3, h, w))
256 |     y = torch.autograd.Variable(torch.rand(batch_size, n_class, h, w), requires_grad=False)  # BCE targets must lie in [0, 1]
257 | for iter in range(10):
258 | optimizer.zero_grad()
259 | output = fcn_model(input)
260 | output = nn.functional.sigmoid(output)
261 | loss = criterion(output, y)
262 | loss.backward()
263 | print("iter{}, loss {}".format(iter, loss.data[0]))
264 | optimizer.step()
265 |
--------------------------------------------------------------------------------