├── README.md
├── model
│   ├── train.sh
│   ├── dataloader.py
│   ├── inference.py
│   ├── evaluation.py
│   ├── lanenet.py
│   ├── loss.py
│   └── main.py
├── LICENSE
├── hnet
│   ├── dataloader_parallel.py
│   ├── dataloader_thinning.py
│   ├── thinning.py
│   ├── inference_h.py
│   ├── lanenet_hnet.py
│   ├── main_thinning.py
│   └── main_parallel.py
└── .gitignore

/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch-Instance-Lane-Segmentation
2 | A PyTorch implementation of "Towards End-to-End Lane Detection: an Instance Segmentation Approach" (Neven et al., 2018). A shared encoder feeds a binary lane-segmentation branch, a pixel-embedding branch whose output is clustered into lane instances, and a lane-count branch; the `hnet/` variant adds the paper's H-Net branch, which learns a homography for polynomial lane fitting.
3 | 
--------------------------------------------------------------------------------
/model/train.sh:
--------------------------------------------------------------------------------
1 | #MV2_USE_CUDA=1 MV2_ENABLE_AFFINITY=0 MV2_SMP_USE_CMA=0 GLOG_logtostderr=1 \
2 | srun --mpi=pmi2 -p bj11part -n1 --job-name=all --gres=gpu:4 --ntasks-per-node=1 python -u main.py 2>&1 | tee log.txt &
3 | 
4 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2018
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/model/dataloader.py:
--------------------------------------------------------------------------------
 1 | import torch.utils.data as data
 2 | import os
 3 | import numpy as np
 4 | import cv2
 5 | #/mnt/lustre/share/dingmingyu/new_list_lane.txt
 6 | 
 7 | class MyDataset(data.Dataset):
 8 |     def __init__(self, file, dir_path, new_width, new_height, label_width, label_height):
 9 |         imgs = []
10 |         with open(file, 'r') as fw:
11 |             lines = fw.readlines()
12 |         for line in lines:
13 |             words = line.strip().split()
14 |             imgs.append((words[0], words[1]))  # (image path, instance-label path)
15 |         self.imgs = imgs
16 |         self.dir_path = dir_path
17 |         self.height = new_height
18 |         self.width = new_width
19 |         self.label_height = label_height
20 |         self.label_width = label_width
21 | 
22 |     def __getitem__(self, index):
23 |         path, label = self.imgs[index]
24 |         path = os.path.join(self.dir_path, path)
25 |         img = cv2.imread(path).astype(np.float32)
26 |         img = img[:,:,:3]
27 |         img = cv2.resize(img, (self.width, self.height))
28 |         img -= [104, 117, 123]  # subtract per-channel BGR means
29 |         img = img.transpose(2, 0, 1)  # HWC -> CHW
30 |         gt = cv2.imread(label, -1)
31 |         gt = cv2.resize(gt, (self.label_width, self.label_height), interpolation = cv2.INTER_NEAREST)
32 |         if len(gt.shape) == 3:
33 |             gt = gt[:,:,0]
34 | 
35 |         gt_num_list = list(np.unique(gt))
36 |         if 0 in gt_num_list: gt_num_list.remove(0)  # drop background (may be absent)
37 |         target_ins = np.zeros((4, gt.shape[0], gt.shape[1])).astype('uint8')  # one binary mask per lane, at most 4
38 |         for index, ins in enumerate(gt_num_list):
39 |             target_ins[index,:,:] += (gt == ins)
40 |         return img, target_ins, len(gt_num_list)
41 | 
42 |     def __len__(self):
43 |         return len(self.imgs)
44 | 
--------------------------------------------------------------------------------
/hnet/dataloader_parallel.py:
--------------------------------------------------------------------------------
 1 | import torch.utils.data as data
 2 | import os
 3 | import numpy as np
 4 | import cv2
 5 | #/mnt/lustre/share/dingmingyu/new_list_lane.txt
 6 | 
 7 | class MyDataset(data.Dataset):
 8 |     def __init__(self, file, dir_path, new_width, new_height, label_width, label_height):
 9 |         imgs = []
10 |         with open(file, 'r') as fw:
11 |             lines = fw.readlines()
12 |         for line in lines:
13 |             words = line.strip().split()
14 |             imgs.append((words[0], words[1]))
15 |         self.imgs = imgs
16 |         self.dir_path = dir_path
17 |         self.height = new_height
18 |         self.width = new_width
19 |         self.label_height = label_height
20 |         self.label_width = label_width
21 | 
22 |     def __getitem__(self, index):
23 |         path, label = self.imgs[index]
24 |         path = os.path.join(self.dir_path, path)
25 |         img = cv2.imread(path).astype(np.float32)
26 |         img = img[:,:,:3]
27 |         img = cv2.resize(img, (self.width, self.height))
28 |         img -= [104, 117, 123]
29 |         img = img.transpose(2, 0, 1)
30 |         gt = cv2.imread(label, -1)
31 |         gt = cv2.resize(gt, (self.label_width, self.label_height), interpolation = cv2.INTER_NEAREST)
32 |         if len(gt.shape) == 3:
33 |             gt = gt[:,:,0]
34 |         gt_num_list = list(np.unique(gt))
35 |         if 0 in gt_num_list: gt_num_list.remove(0)
36 |         target_ins = np.zeros((4, gt.shape[0], gt.shape[1])).astype('uint8')
37 |         for index, ins in enumerate(gt_num_list):
38 |             target_ins[index,:,:] += (gt == ins)
39 |         return img, target_ins, len(gt_num_list), gt  # also returns the raw instance map
40 | 
41 |     def __len__(self):
42 |         return len(self.imgs)
43 | 
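For quick orientation, a minimal usage sketch for the dataset classes above (an illustration, not repository code: it assumes an annotation list with one `image_path label_path` pair per line, as parsed by `__init__`, and uses the input/label sizes that `model/main.py` passes in):

```python
# Minimal sketch: consuming model/dataloader.py (paths are placeholders).
from torch.utils.data import DataLoader
from dataloader import MyDataset

train_data = MyDataset('train_list.txt', dir_path='',
                       new_width=833, new_height=705,      # network input size
                       label_width=209, label_height=177)  # label-map size
loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=2)

for img, target_ins, n_objects in loader:
    # img:        (B, 3, 705, 833) float32, BGR with channel means subtracted
    # target_ins: (B, 4, 177, 209) uint8, one binary mask per lane instance
    # n_objects:  (B,)  number of lanes present in each frame (at most 4)
    break
```

`dataloader_parallel.py` additionally returns the raw instance map `gt`, and `dataloader_thinning.py` below expects a third column per line pointing at the thinned label written by `hnet/thinning.py`.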
--------------------------------------------------------------------------------
/hnet/dataloader_thinning.py:
--------------------------------------------------------------------------------
 1 | import torch.utils.data as data
 2 | import os
 3 | import numpy as np
 4 | import cv2
 5 | #/mnt/lustre/share/dingmingyu/new_list_lane.txt
 6 | 
 7 | class MyDataset(data.Dataset):
 8 |     def __init__(self, file, dir_path, new_width, new_height, label_width, label_height):
 9 |         imgs = []
10 |         with open(file, 'r') as fw:
11 |             lines = fw.readlines()
12 |         for line in lines:
13 |             words = line.strip().split()
14 |             imgs.append((words[0], words[1], words[2]))  # third column: thinned label from hnet/thinning.py
15 |         self.imgs = imgs
16 |         self.dir_path = dir_path
17 |         self.height = new_height
18 |         self.width = new_width
19 |         self.label_height = label_height
20 |         self.label_width = label_width
21 | 
22 |     def __getitem__(self, index):
23 |         path, label, h_label = self.imgs[index]
24 |         path = os.path.join(self.dir_path, path)
25 |         img = cv2.imread(path).astype(np.float32)
26 |         img = img[:,:,:3]
27 |         img = cv2.resize(img, (self.width, self.height))
28 |         img -= [104, 117, 123]
29 |         img = img.transpose(2, 0, 1)
30 |         gt = cv2.imread(label, -1)
31 |         gt = cv2.resize(gt, (self.label_width, self.label_height), interpolation = cv2.INTER_NEAREST)
32 |         if len(gt.shape) == 3:
33 |             gt = gt[:,:,0]
34 |         thining_gt = cv2.imread(h_label, -1)  # already at label resolution, see thinning.py
35 |         gt_num_list = list(np.unique(gt))
36 |         if 0 in gt_num_list: gt_num_list.remove(0)
37 |         target_ins = np.zeros((4, gt.shape[0], gt.shape[1])).astype('uint8')
38 |         for index, ins in enumerate(gt_num_list):
39 |             target_ins[index,:,:] += (gt == ins)
40 |         return img, target_ins, len(gt_num_list), thining_gt
41 | 
42 |     def __len__(self):
43 |         return len(self.imgs)
44 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # C extensions
 7 | *.so
 8 | 
 9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /hnet/thinning.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from scipy import weave 4 | import numpy as np 5 | import cv2 6 | 7 | def _thinningIteration(im, iter): 8 | I, M = im, np.zeros(im.shape, np.uint8) 9 | expr = """ 10 | for (int i = 1; i < NI[0]-1; i++) { 11 | for (int j = 1; j < NI[1]-1; j++) { 12 | int p2 = I2(i-1, j); 13 | int p3 = I2(i-1, j+1); 14 | int p4 = I2(i, j+1); 15 | int p5 = I2(i+1, j+1); 16 | int p6 = I2(i+1, j); 17 | int p7 = I2(i+1, j-1); 18 | int p8 = I2(i, j-1); 19 | int p9 = I2(i-1, j-1); 20 | int A = (p2 == 0 && p3 == 1) + (p3 == 0 && p4 == 1) + 21 | (p4 == 0 && p5 == 1) + (p5 == 0 && p6 == 1) + 22 | (p6 == 0 && p7 == 1) + (p7 == 0 && p8 == 1) + 23 | (p8 == 0 && p9 == 1) + (p9 == 0 && p2 == 1); 24 | int B = p2 + p3 + p4 + p5 + p6 + p7 + p8 + p9; 25 | int m1 = iter == 0 ? (p2 * p4 * p6) : (p2 * p4 * p8); 26 | int m2 = iter == 0 ? (p4 * p6 * p8) : (p2 * p6 * p8); 27 | //int m1 = iter == 0 ? 1 - ((p2 * p4 * p6 == 0) && (p4 * p6 * p8 == 0)): 0; 28 | //int m2 = iter == 1 ? 
1 - ((p2 * p4 * p8 == 0) && (p2 * p6 * p8 == 0)): 0; 29 | if (A == 1 && B >= 2 && B <= 6 && m1 == 0 && m2 == 0) { 30 | M2(i,j) = 1; 31 | } 32 | } 33 | } 34 | """ 35 | 36 | weave.inline(expr, ["I", "iter", "M"]) 37 | return (I & ~M) 38 | 39 | 40 | def thinning(src): 41 | dst = src.copy() 42 | prev = np.zeros(src.shape[:2], np.uint8) 43 | diff = None 44 | 45 | while True: 46 | dst = _thinningIteration(dst, 0) 47 | dst = _thinningIteration(dst, 1) 48 | diff = np.absolute(dst - prev) 49 | prev = dst.copy() 50 | if np.sum(diff) == 0: 51 | break 52 | dst[0,:] = 0 53 | dst[-1,:]=0 54 | dst[:,0] = 0 55 | dst[:,-1] = 0 56 | return dst 57 | 58 | f = open('/mnt/lustre/share/dingmingyu/new_list_lane.txt').readlines() 59 | for index, line in enumerate(f): 60 | if index % 200 == 0: 61 | print index 62 | gt_name = line.strip().split()[1] 63 | img = cv2.imread(gt_name,-1) 64 | img = cv2.resize(img,(209,177), interpolation=cv2.INTER_NEAREST) 65 | if len(img.shape) == 3: 66 | img = img[:,:,0] 67 | 68 | new_image = np.zeros((177, 209)).astype('uint8') 69 | for i in range(4): 70 | image = img.copy() 71 | image[image != i+1] = 0 72 | image[image == i+1] = 1 73 | thinning_img = thinning(image) 74 | new_image += thinning_img * (i+1) 75 | cv2.imwrite(gt_name[:-4] + '_thin.png', new_image) 76 | #print gt_name[:-4] + '_thin.png' 77 | -------------------------------------------------------------------------------- /model/inference.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import time 3 | import argparse 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader 13 | import lanenet 14 | import os 15 | import torchvision as tv 16 | from torch.autograd import Variable 17 | import cv2 18 | 19 | from sklearn.cluster import KMeans, AffinityPropagation, MeanShift, estimate_bandwidth 20 | 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--img_list', dest='img_list', default='/mnt/lustre/dingmingyu/data/test_pic/list.txt', help='the test image list', type=str) 24 | parser.add_argument('--img_dir', dest='img_dir', default='/mnt/lustre/dingmingyu/data/test_pic/', 25 | help='the test image dir', type=str) 26 | parser.add_argument('--model_path', dest='model_path', default='checkpoints/020_checkpoint.pth.tar', 27 | help='the test model', type=str) 28 | 29 | def main(): 30 | global args 31 | args = parser.parse_args() 32 | print ("Build model ...") 33 | model = lanenet.Net() 34 | model = torch.nn.DataParallel(model).cuda() 35 | state = torch.load(args.model_path)['state_dict'] 36 | model.load_state_dict(state) 37 | model.eval() 38 | 39 | mean = [0.485, 0.456, 0.406] 40 | std = [0.229, 0.224, 0.225] 41 | f = open(args.img_list) 42 | ni = 0 43 | for line in f: 44 | line = line.strip() 45 | arrl = line.split(" ") 46 | image = cv2.imread(args.img_dir + arrl[0]) 47 | image = cv2.resize(image,(833,705)).astype(np.float32) 48 | image -= [104, 117, 123] 49 | image = image.transpose(2, 0, 1) 50 | #image = cv2.imread(args.img_dir + arrl[0], -1) 51 | #image = cv2.resize(image, (732,704), interpolation = cv2.INTER_NEAREST) 52 | #print image.shape 53 | image = torch.from_numpy(image).unsqueeze(0) 54 | image = Variable(image.float().cuda(0), volatile=True) 55 | output, embedding, n_objects_predictions = model(image) 56 | #prob = 
output.data[0].max(0)[1].cpu().numpy() 57 | #print prob.max(),prob.shape 58 | output = torch.nn.functional.softmax(output[0],dim=0) 59 | prob = output.data.cpu().numpy() 60 | # prob = output.data[0].max(0)[1].cpu().numpy() 61 | # print output.size() 62 | #print output.max(),type(output) 63 | 64 | print prob[1].max(),prob.shape 65 | prob = (prob[1] >= 0.4) 66 | n_objects_predictions = n_objects_predictions * 4 67 | n_objects_predictions = torch.round(n_objects_predictions).int() 68 | n_objects_predictions = int(n_objects_predictions.data.cpu()) 69 | embedding = embedding.data.cpu().numpy() 70 | embedding = embedding[0,:,:,:].transpose((1, 2, 0)) 71 | mylist = [] 72 | indexlist = [] 73 | for i in range(embedding.shape[0]): 74 | for j in range(embedding.shape[1]): 75 | if prob[i][j] > 0: 76 | mylist.append(embedding[i,j,:]) 77 | indexlist.append((i,j)) 78 | mylist = np.array(mylist) 79 | 80 | #bandwidth = estimate_bandwidth(mylist, quantile=0.3, n_samples=100, n_jobs = 8) 81 | #print bandwidth 82 | estimator = MeanShift(bandwidth=1.5, bin_seeding=True) 83 | #estimator = KMeans(n_clusters = n_objects_predictions) 84 | estimator.fit(mylist) 85 | print len(np.unique(estimator.labels_)),'~~~~~~~~~~~~~~~~' 86 | for i in range(4): 87 | print len(estimator.labels_[estimator.labels_==i]),' ', 88 | 89 | probAll = np.zeros((prob.shape[0], prob.shape[1], 3), dtype=np.float) 90 | # probAll[:,:,0] += prob # line 1 91 | # probAll[:,:,1] += prob # line 2 92 | # probAll[:,:,2] += prob # line 3 93 | 94 | for index,item in enumerate(estimator.labels_): 95 | x = indexlist[index][0] 96 | y = indexlist[index][1] 97 | if item < 3: 98 | probAll[x,y,item] += prob[x,y] # line 1 99 | else: 100 | probAll[x,y,:] += 1 101 | 102 | probAll = np.clip(probAll * 255, 0, 255) 103 | 104 | test_img = cv2.imread(args.img_dir + arrl[0], -1) 105 | probAll = cv2.resize(probAll, (1280,720), interpolation = cv2.INTER_NEAREST) 106 | test_img = cv2.resize(test_img, (1280,720)) 107 | 108 | ni = ni + 1 109 | test_img = np.clip(test_img + probAll, 0, 255).astype('uint8') 110 | cv2.imwrite(args.img_dir + 'prob/test_' + str(ni) + '_lane.png', test_img) 111 | print('write img: ' + str(ni+1)) 112 | f.close() 113 | 114 | if __name__ == '__main__': 115 | main() 116 | 117 | -------------------------------------------------------------------------------- /hnet/inference_h.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import time 3 | import argparse 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader 13 | import lanenet 14 | import os 15 | import torchvision as tv 16 | from torch.autograd import Variable 17 | import cv2 18 | 19 | from sklearn.cluster import KMeans, AffinityPropagation, MeanShift, estimate_bandwidth 20 | 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--img_list', dest='img_list', default='/mnt/lustre/dingmingyu/data/test_pic/list.txt', help='the test image list', type=str) 24 | parser.add_argument('--img_dir', dest='img_dir', default='/mnt/lustre/dingmingyu/data/test_pic/', 25 | help='the test image dir', type=str) 26 | parser.add_argument('--model_path', dest='model_path', default='checkpoints/020_checkpoint.pth.tar', 27 | help='the test model', type=str) 28 | 29 | def main(): 30 | global args 31 | args = 
parser.parse_args() 32 | print ("Build model ...") 33 | model = lanenet.Net() 34 | model = torch.nn.DataParallel(model).cuda() 35 | state = torch.load(args.model_path)['state_dict'] 36 | model.load_state_dict(state) 37 | model.eval() 38 | 39 | mean = [0.485, 0.456, 0.406] 40 | std = [0.229, 0.224, 0.225] 41 | f = open(args.img_list) 42 | ni = 0 43 | if 1: 44 | arrl = ('/mnt/lustre/share/dingmingyu/share/zhuangpeiye/unlabeled_data/Curve/Sunny/Daytime/20170922-20170924/DriverUnknown/GOPR0023_3.MP4/00000015.jpg','/mnt/lustre/share/dingmingyu/share/sunpeng/ADAS/GMXL_DATA/201801_return/lane_and_type/Label/Curve/Sunny/Daytime/20170922-20170924/DriverUnknown/GOPR0023_3.MP4/00000015_thin.png') 45 | image = cv2.imread(arrl[0]) 46 | image = cv2.resize(image,(833,705)).astype(np.float32) 47 | image -= [104, 117, 123] 48 | image = image.transpose(2, 0, 1) 49 | #image = cv2.imread(args.img_dir + arrl[0], -1) 50 | #image = cv2.resize(image, (732,704), interpolation = cv2.INTER_NEAREST) 51 | #print image.shape 52 | image = torch.from_numpy(image).unsqueeze(0) 53 | image = Variable(image.float().cuda(0), volatile=True) 54 | output, embedding, n_objects_predictions, x_hnet = model(image) 55 | x_hnet = x_hnet.data.cpu().numpy() 56 | print x_hnet 57 | H = np.zeros((3,3)).astype(np.float64) 58 | H[0,0] = x_hnet[0,0] 59 | H[0,1] = x_hnet[0,1] 60 | H[0,2] = x_hnet[0,2] 61 | H[1,1] = x_hnet[0,3] 62 | H[1,2] = x_hnet[0,4] 63 | H[2,1] = x_hnet[0,5] 64 | H[2,2] = 1 65 | gt = cv2.imread(arrl[1],-1) 66 | cv2.imwrite('gt.png', gt) 67 | list_y, list_x = np.where(gt >= 1) 68 | print list_y 69 | p = np.vstack((list_x, list_y, np.ones(len(list_y)))).astype(np.float64) 70 | p_ = H.dot(p).astype(np.int32) 71 | print p_/(-10000) 72 | 73 | 74 | #prob = output.data[0].max(0)[1].cpu().numpy() 75 | #print prob.max(),prob.shape 76 | output = torch.nn.functional.softmax(output[0],dim=0) 77 | prob = output.data.cpu().numpy() 78 | # prob = output.data[0].max(0)[1].cpu().numpy() 79 | # print output.size() 80 | #print output.max(),type(output) 81 | 82 | print prob[1].max(),prob.shape 83 | prob = (prob[1] >= 0.2) 84 | n_objects_predictions = n_objects_predictions * 4 85 | n_objects_predictions = torch.round(n_objects_predictions).int() 86 | n_objects_predictions = int(n_objects_predictions.data.cpu()) 87 | embedding = embedding.data.cpu().numpy() 88 | embedding = embedding[0,:,:,:].transpose((1, 2, 0)) 89 | mylist = [] 90 | indexlist = [] 91 | for i in range(embedding.shape[0]): 92 | for j in range(embedding.shape[1]): 93 | if prob[i][j] > 0: 94 | mylist.append(embedding[i,j,:]) 95 | indexlist.append((i,j)) 96 | mylist = np.array(mylist) 97 | print n_objects_predictions,'~~~~~~~~~~~' 98 | 99 | bandwidth = estimate_bandwidth(mylist, quantile=0.3, n_samples=100, n_jobs = 8) 100 | print bandwidth 101 | estimator = MeanShift(bandwidth=1.5, bin_seeding=True) 102 | #estimator = KMeans(n_clusters = n_objects_predictions) 103 | estimator.fit(mylist) 104 | print len(np.unique(estimator.labels_)),'~~~~~~~~~~~~~~~~' 105 | for i in range(4): 106 | print len(estimator.labels_[estimator.labels_==i]),' ', 107 | 108 | probAll = np.zeros((prob.shape[0], prob.shape[1], 3), dtype=np.float) 109 | # probAll[:,:,0] += prob # line 1 110 | # probAll[:,:,1] += prob # line 2 111 | # probAll[:,:,2] += prob # line 3 112 | 113 | for index,item in enumerate(estimator.labels_): 114 | x = indexlist[index][0] 115 | y = indexlist[index][1] 116 | if item < 3: 117 | probAll[x,y,item] += prob[x,y] # line 1 118 | else: 119 | probAll[x,y,:] += 1 120 | 121 | probAll = 
np.clip(probAll * 255, 0, 255) 122 | 123 | test_img = cv2.imread(arrl[0], -1) 124 | probAll = cv2.resize(probAll, (1280,720), interpolation = cv2.INTER_NEAREST) 125 | test_img = cv2.resize(test_img, (1280,720)) 126 | 127 | ni = ni + 1 128 | test_img = np.clip(test_img + probAll, 0, 255).astype('uint8') 129 | cv2.imwrite(args.img_dir + 'prob/test_' + str(ni) + '_lane.png', test_img) 130 | print('write img: ' + str(ni+1)) 131 | 132 | if __name__ == '__main__': 133 | main() 134 | 135 | -------------------------------------------------------------------------------- /model/evaluation.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import time 3 | import argparse 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader 13 | import lanenet 14 | import os 15 | import torchvision as tv 16 | from torch.autograd import Variable 17 | from PIL import Image 18 | import cv2 19 | from sklearn.cluster import KMeans, AffinityPropagation, MeanShift, estimate_bandwidth 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--img_list', dest='img_list', default='/mnt/lustre/sunpeng/ADAS_lane_prob_IoU_evaluation/GMXL/new_camera1212_test.txt', 23 | help='the test image list', type=str) 24 | 25 | parser.add_argument('--img_dir', dest='img_dir', default='/mnt/lustre/drive_data/adas-video/record/', 26 | help='the test image dir', type=str) 27 | parser.add_argument('--gtpng_dir', dest='gtpng_dir', default='/mnt/lustre/drive_data/adas-video/record/', 28 | help='the test png label dir', type=str) 29 | 30 | parser.add_argument('--model_path', dest='model_path', default='checkpoints/020_checkpoint.pth.tar', 31 | help='the test model', type=str) 32 | 33 | parser.add_argument('--prethreshold', dest='prethreshold', default=0.4, 34 | help='preset bad threshold', type=float) 35 | 36 | 37 | def get_iou(gt, prob): 38 | cross = np.multiply(gt, prob) 39 | unite = gt + prob 40 | unite[unite >= 1] = 1 41 | union = np.sum(unite) 42 | inter = np.sum(cross) 43 | if (union != 0): 44 | iou = inter * 1.0 / union 45 | if (np.sum(gt) == 0): 46 | iou = -1 47 | else: 48 | iou = -10 49 | return iou 50 | 51 | def iou_one_frame(gtpng, rprob, bias_threshold): 52 | iou = [0, 0, 0, 0] 53 | gt = [gtpng.copy(),gtpng.copy(),gtpng.copy(),gtpng.copy()] 54 | for i, item in enumerate(gt): 55 | item[item != i+1] = 0 56 | item[item == i+1] = 1 57 | prob = [rprob.copy(),rprob.copy(),rprob.copy(),rprob.copy()] 58 | for i, item in enumerate(prob): 59 | item[item != i+1] = 0 60 | item[item == i+1] = 1 61 | iou_all = np.zeros((4,4)) 62 | for i in range(4): 63 | for j in range(4): 64 | iou_all[i,j] = get_iou(gt[i],prob[j]) 65 | #print iou_all 66 | for i in range(4): 67 | max = np.argmax(iou_all) 68 | gt_max = int(max/4) 69 | prob_max = max % 4 70 | #print gt_max, prob_max 71 | iou[gt_max] = iou_all.max() 72 | iou_all[gt_max,:] = -100 73 | iou_all[:,prob_max] = -100 74 | return iou 75 | 76 | 77 | 78 | def main(): 79 | global args 80 | args = parser.parse_args() 81 | print ("Build model ...") 82 | model = lanenet.Net() 83 | model = torch.nn.DataParallel(model).cuda() 84 | state = torch.load(args.model_path)['state_dict'] 85 | model.load_state_dict(state) 86 | model.eval() 87 | 88 | mean = [104, 117, 123] 89 | f = open(args.img_list) 90 | ni = 0 91 | count_gt = 
[0,0,0,0] 92 | total_iou = [0,0,0,0] 93 | total_iou_big = [0,0,0,0] 94 | for line in f: 95 | #if ni>1: 96 | # break 97 | line = line.strip() 98 | arrl = line.split(" ") 99 | 100 | #gtlb = cv2.imread(args.gtpng_dir + arrl[1], -1) 101 | gtlb = cv2.imread(args.gtpng_dir + arrl[1], -1) 102 | #print gtlb.shape 103 | gt_num_list = list(np.unique(gtlb)) 104 | gt_num_list.remove(0) 105 | n_objects_gt = len(gt_num_list) 106 | image = cv2.imread(args.img_dir + arrl[0]) 107 | image = cv2.resize(image,(833,705)).astype(np.float32) 108 | image -= mean 109 | image = image.transpose(2, 0, 1) 110 | #image = cv2.imread(args.img_dir + arrl[0], -1) 111 | #image = cv2.resize(image, (732,704), interpolation = cv2.INTER_NEAREST) 112 | #print image.shape 113 | image = torch.from_numpy(image).unsqueeze(0) 114 | image = Variable(image.float().cuda(0), volatile=True) 115 | output, embedding, n_objects_predictions = model(image) 116 | output = torch.nn.functional.softmax(output[0],dim=0) 117 | prob = output.data.cpu().numpy() 118 | embedding = embedding.data.cpu().numpy() 119 | n_objects_predictions = n_objects_predictions * 4 120 | n_objects_predictions = torch.round(n_objects_predictions).int() 121 | n_objects_predictions = int(n_objects_predictions.data.cpu()) 122 | print n_objects_predictions,'~~~~~~~~~~~~~~~~~~~~~~~`' 123 | 124 | if not n_objects_predictions: 125 | continue 126 | prob[prob >= args.prethreshold] = 1.0 127 | prob[prob < args.prethreshold] = 0 128 | embedding = embedding[0,:,:,:].transpose((1, 2, 0)) 129 | #print prob.shape 130 | mylist = [] 131 | indexlist = [] 132 | for i in range(embedding.shape[0]): 133 | for j in range(embedding.shape[1]): 134 | if prob[1][i][j] > 0: 135 | mylist.append(embedding[i,j,:]) 136 | indexlist.append((i,j)) 137 | if not mylist: 138 | continue 139 | mylist = np.array(mylist) 140 | # bandwidth = estimate_bandwidth(mylist, quantile=0.3, n_samples=100, n_jobs = 8) 141 | # print bandwidth 142 | # estimator = MeanShift(bandwidth=1, bin_seeding=True) 143 | estimator = KMeans(n_clusters = n_objects_predictions) 144 | #estimator = AffinityPropagation(preference=-0.4, damping = 0.5) 145 | t = time.time() 146 | estimator.fit(mylist) 147 | print time.time() - t 148 | for i in range(4): 149 | print len(estimator.labels_[estimator.labels_==i]) 150 | #print len(np.unique(estimator.labels_)),'~~~~~~~~~~~~~~~~' 151 | new_prob = np.zeros((embedding.shape[0],embedding.shape[1]),dtype=int) 152 | for index, item in enumerate(estimator.labels_): 153 | if item <= 4: 154 | new_prob[indexlist[index][0]][indexlist[index][1]] = item + 1 155 | 156 | gtlb = cv2.resize(gtlb, (prob.shape[2], prob.shape[1]),interpolation = cv2.INTER_NEAREST) 157 | iou = iou_one_frame(gtlb, new_prob, args.prethreshold) 158 | 159 | print('IoU of ' + str(ni) + ' '+ arrl[0] + ': ' + str(iou)) 160 | for i in range(0,4): 161 | if iou[i] >= 0: 162 | count_gt[i] = count_gt[i] + 1 163 | total_iou[i] = total_iou[i] + iou[i] 164 | ni += 1 165 | mean_iou = np.divide(total_iou, count_gt) 166 | print('Image numer: ' + str(ni)) 167 | print('Mean IoU of four lanes: ' + str(mean_iou)) 168 | print('Overall evaluation: ' + str(mean_iou[0] * 0.2 + mean_iou[1] * 0.3 + mean_iou[2] * 0.3 + mean_iou[3] * 0.2)) 169 | 170 | f.close() 171 | 172 | if __name__ == '__main__': 173 | main() 174 | 175 | -------------------------------------------------------------------------------- /model/lanenet.py: -------------------------------------------------------------------------------- 1 | #--coding:utf-8-- 2 | import torch.nn as nn 3 | import torch 4 
| import math 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | "3x3 convolution with padding" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=1, bias=False) 10 | 11 | 12 | class conv2DBatchNormRelu(nn.Module): 13 | def __init__(self, in_channels, n_filters, stride=1, k_size=3, padding=1, bias=True, dilation=1, with_bn=True): 14 | super(conv2DBatchNormRelu, self).__init__() 15 | 16 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 17 | padding=padding, stride=stride, bias=bias, dilation=dilation) 18 | 19 | if with_bn: 20 | self.cbr_unit = nn.Sequential(conv_mod, 21 | nn.BatchNorm2d(int(n_filters)), 22 | nn.ReLU(inplace=True),) 23 | else: 24 | self.cbr_unit = nn.Sequential(conv_mod, 25 | nn.ReLU(inplace=True),) 26 | 27 | def forward(self, inputs): 28 | outputs = self.cbr_unit(inputs) 29 | return outputs 30 | 31 | 32 | class sharedBottom(nn.Module): 33 | def __init__(self,): 34 | super(sharedBottom, self).__init__() 35 | self.conv1 = conv2DBatchNormRelu(3, 16, 2) 36 | self.conv2a1 = conv2DBatchNormRelu(16, 16, 2) 37 | self.conv2a2 = conv2DBatchNormRelu(16,8) 38 | self.conv2a3 = conv2DBatchNormRelu(8,4) 39 | self.conv2a4 = conv2DBatchNormRelu(4,4) 40 | self.conv2a_strided = conv2DBatchNormRelu(32,32,2) 41 | self.conv3 = conv2DBatchNormRelu(32,32,2) 42 | self.conv4 = conv2DBatchNormRelu(32,32,1) 43 | self.conv6 = conv2DBatchNormRelu(32,64,2) 44 | self.conv8 = conv2DBatchNormRelu(64,64,1) 45 | self.conv9 = conv2DBatchNormRelu(64,128,2) 46 | self.conv11 = conv2DBatchNormRelu(128,128,1) 47 | self.conv11_1 = conv2DBatchNormRelu(128,32,1) 48 | self.conv11_2 = conv2DBatchNormRelu(128,32,1) 49 | self.conv11_3 = conv2DBatchNormRelu(128,32,1) 50 | self.conv11_4 = conv2DBatchNormRelu(32,64,1) 51 | self.conv11_6 = conv2DBatchNormRelu(32,64,1) 52 | self.conv11_5 = conv2DBatchNormRelu(64,128,1) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x1 = self.conv2a1(x) 57 | x2 = self.conv2a2(x1) 58 | x3 = self.conv2a3(x2) 59 | x4 = self.conv2a4(x3) 60 | x = torch.cat([x1, x2, x3, x4], dim = 1) 61 | x = self.conv2a_strided(x) 62 | x = self.conv3(x) 63 | x = self.conv4(x) 64 | x = self.conv6(x) 65 | x = self.conv8(x) 66 | x = self.conv9(x) 67 | x = self.conv11(x) 68 | x1= self.conv11_1(x) 69 | x2= self.conv11_2(x) 70 | x3= self.conv11_3(x) 71 | x4= self.conv11_4(x3) 72 | x6= self.conv11_6(x2) 73 | x5= self.conv11_5(x4) 74 | x = torch.cat([x1, x5, x6], dim = 1) 75 | return x 76 | 77 | class laneNet(nn.Module): 78 | def __init__(self,): 79 | super(laneNet, self).__init__() 80 | self.conv11_7 = conv2DBatchNormRelu(224,128,1) 81 | self.conv11_8 = conv2DBatchNormRelu(128,128,1) 82 | self.conv11_9 = conv2DBatchNormRelu(128,128,1) 83 | self.conv12 = conv2DBatchNormRelu(128,16,1) 84 | self.conv13 = conv2DBatchNormRelu(16,8,1) 85 | self.conv14 = nn.Conv2d(8, 2, 1,stride = 1,padding = 0, bias=True) 86 | def forward(self, x): 87 | x = self.conv11_7(x) 88 | x = self.conv11_8(x) 89 | x = self.conv11_9(x) 90 | x = nn.Upsample(size=(45,53),mode='bilinear')(x) 91 | x = self.conv12(x) 92 | x = nn.Upsample(size=(177,209),mode='bilinear')(x) 93 | x = self.conv13(x) 94 | x = self.conv14(x) 95 | return x 96 | 97 | class clusterNet(nn.Module): 98 | def __init__(self,): 99 | super(clusterNet, self).__init__() 100 | self.conv11_7 = conv2DBatchNormRelu(224,128,1) 101 | self.conv11_8 = conv2DBatchNormRelu(128,128,1) 102 | self.conv11_9 = conv2DBatchNormRelu(128,128,1) 103 | self.conv12 = conv2DBatchNormRelu(128,16,1) 104 | self.conv13 = 
conv2DBatchNormRelu(16,8,1)
105 |         self.deconv1 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, bias=True)
106 |         self.deconv2 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, bias=True)
107 |         self.deconv3 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, bias=True)
108 |         self.deconv4 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, bias=True)
109 |         self.conv14 = nn.Conv2d(8, 4, 1, stride = 1, padding = 0, bias=True)
110 |     def forward(self, x):
111 |         x = self.conv11_7(x)
112 |         x = self.deconv1(x)
113 |         x = self.conv11_8(x)
114 |         x = self.deconv2(x)
115 |         x = self.conv11_9(x)
116 |         x = self.deconv3(x)
117 |         x = self.conv12(x)
118 |         x = self.deconv4(x)
119 |         x = self.conv13(x)
120 |         x = self.conv14(x)
121 |         return x
122 | 
123 | class insClsNet(nn.Module):
124 |     def __init__(self,):
125 |         super(insClsNet, self).__init__()
126 |         self.conv11_7 = conv2DBatchNormRelu(224,128,1)
127 |         self.conv11_8 = conv2DBatchNormRelu(128,128,1)
128 |         self.conv11_9 = conv2DBatchNormRelu(128,128,1)
129 |         self.conv12 = conv2DBatchNormRelu(128,64,1)
130 |         self.conv13 = conv2DBatchNormRelu(64,64,1)
131 |         self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
132 |         self.ins_cls_out = nn.Sequential()
133 |         self.ins_cls_out.add_module('linear', nn.Linear(64, 1))
134 |         self.ins_cls_out.add_module('sigmoid', nn.Sigmoid())
135 | 
136 | 
137 |     def forward(self, x):
138 |         x = self.conv11_7(x)
139 |         x = self.conv11_8(x)
140 |         x = self.conv11_9(x)
141 |         x = self.conv12(x)
142 |         x = self.conv13(x)
143 |         x = self.global_pool(x)
144 |         x = x.squeeze(3).squeeze(2)
145 |         x_ins_cls = self.ins_cls_out(x)
146 |         return x_ins_cls
147 | 
148 | class Net(nn.Module):
149 |     def __init__(self):
150 |         # Subclasses of nn.Module must call the parent constructor in their own constructor;
151 |         # the line below is equivalent to nn.Module.__init__(self)
152 |         super(Net, self).__init__()
153 |         self.bottom = sharedBottom()
154 |         self.sem_seg = laneNet()    # binary lane/background segmentation
155 |         self.ins_seg = clusterNet() # 4-d pixel embeddings for instance clustering
156 |         self.ins_cls = insClsNet()  # predicts (number of lanes) / 4
157 |         self._initialize_weights()
158 | 
159 |     def _initialize_weights(self):
160 |         for m in self.modules():
161 |             if isinstance(m, nn.Conv2d):
162 |                 nn.init.kaiming_normal(m.weight.data)
163 |                 if m.bias is not None:
164 |                     m.bias.data.zero_()
165 |             elif isinstance(m, nn.BatchNorm2d):
166 |                 m.weight.data.fill_(1)
167 |                 m.bias.data.zero_()
168 | 
169 |     def forward(self, x):
170 |         x = self.bottom(x)
171 |         x_sem = self.sem_seg(x)
172 |         x_ins = self.ins_seg(x)
173 |         x_cls = self.ins_cls(x)
174 |         return x_sem, x_ins, x_cls
175 | 
176 | #net = Net()
177 | #print(net)
178 | 
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | import torch.nn as nn
 4 | import torch.nn.functional as F
 5 | from torch.nn.modules.loss import _assert_no_grad, _Loss
 6 | from torch.autograd import Variable
 7 | 
 8 | def cross_entropy2d(input, target, weight=None, size_average=True):
 9 |     n, c, h, w = input.size()
10 |     log_p = F.log_softmax(input, dim=1)
11 |     log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
12 |     log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0]
13 |     log_p = log_p.view(-1, c)
14 | 
15 |     mask = target >= 0
16 |     target = target[mask]
17 |     loss = F.nll_loss(log_p, target, ignore_index=255, weight=weight, size_average=False)
18 |     if size_average:
19 |         loss /= mask.data.sum()
20 |     return loss
21 | 
22 | def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True):
23 | 
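    # Bootstrapping: rather than averaging the cross-entropy over every pixel,
    # keep only the K largest per-pixel losses for each image, so the gradient
    # concentrates on hard pixels (thin lane markings) instead of the abundant,
    # easy background.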
24 | batch_size = input.size()[0] 25 | 26 | def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True): 27 | n, c, h, w = input.size() 28 | log_p = F.log_softmax(input, dim=1) 29 | log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 30 | log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0] 31 | log_p = log_p.view(-1, c) 32 | 33 | mask = target >= 0 34 | target = target[mask] 35 | loss = F.nll_loss(log_p, target, weight=weight, ignore_index=250, 36 | reduce=False, size_average=False) 37 | topk_loss, _ = loss.topk(K) 38 | reduced_topk_loss = topk_loss.sum() / K 39 | 40 | return reduced_topk_loss 41 | 42 | loss = 0.0 43 | # Bootstrap from each image not entire batch 44 | for i in range(batch_size): 45 | loss += _bootstrap_xentropy_single(input=torch.unsqueeze(input[i], 0), 46 | target=torch.unsqueeze(target[i], 0), 47 | K=K, 48 | weight=weight, 49 | size_average=size_average) 50 | return loss / float(batch_size) 51 | 52 | 53 | def discriminative_loss(input, target, n_objects, max_n_objects, usegpu): 54 | """input: bs, n_filters, fmap, fmap 55 | target: bs, n_instances, fmap, fmap 56 | n_objects: bs""" 57 | bs, n_filters, height, width = input.size() 58 | n_instances = target.size(1) 59 | 60 | input = input.permute(0, 2, 3, 1).contiguous().view( 61 | bs, height * width, n_filters) 62 | target = target.permute(0, 2, 3, 1).contiguous().view( 63 | bs, height * width, n_instances) 64 | cluster_means = calculate_means( 65 | input, target, n_objects, max_n_objects, usegpu) 66 | var_term = calculate_variance_term( 67 | input, target, cluster_means, n_objects, delta_v=0.5, norm=2) 68 | dist_term = calculate_distance_term( 69 | cluster_means, n_objects, delta_d=3, norm=2, usegpu=True) 70 | reg_term = calculate_regularization_term(cluster_means, n_objects, norm=2) 71 | loss = var_term + dist_term + 0.001*reg_term 72 | return loss 73 | 74 | 75 | def calculate_means(pred, gt, n_objects, max_n_objects, usegpu): 76 | """pred: bs, height * width, n_filters 77 | gt: bs, height * width, n_instances""" 78 | 79 | bs, n_loc, n_filters = pred.size() 80 | n_instances = gt.size(2) 81 | 82 | pred_repeated = pred.unsqueeze(2).expand( 83 | bs, n_loc, n_instances, n_filters) # bs, n_loc, n_instances, n_filters 84 | # bs, n_loc, n_instances, 1 85 | gt_expanded = gt.unsqueeze(3) 86 | 87 | #print pred_repeated.size(),pred_repeated.type,gt_expanded.size(),gt_expanded.type 88 | pred_masked = pred_repeated * gt_expanded 89 | 90 | means = [] 91 | for i in range(bs): 92 | _n_objects_sample = n_objects[i] 93 | # n_loc, n_objects, n_filters 94 | if _n_objects_sample: 95 | _pred_masked_sample = pred_masked[i, :, : _n_objects_sample] 96 | # n_loc, n_objects, 1 97 | _gt_expanded_sample = gt_expanded[i, :, : _n_objects_sample] 98 | 99 | _mean_sample = _pred_masked_sample.sum(0) / _gt_expanded_sample.sum(0) # n_objects, n_filters 100 | if (max_n_objects - _n_objects_sample) != 0: 101 | n_fill_objects = max_n_objects - _n_objects_sample 102 | _fill_sample = torch.zeros(n_fill_objects, n_filters) 103 | if usegpu: 104 | _fill_sample = _fill_sample.cuda() 105 | _fill_sample = Variable(_fill_sample) 106 | _mean_sample = torch.cat((_mean_sample, _fill_sample), dim=0) 107 | else: 108 | _mean_sample = torch.zeros(max_n_objects, n_filters) 109 | if usegpu: 110 | _mean_sample = _mean_sample.cuda() 111 | _mean_sample = Variable(_mean_sample) 112 | means.append(_mean_sample) 113 | 114 | means = torch.stack(means) 115 | 116 | # means = pred_masked.sum(1) / gt_expanded.sum(1) 117 | # # bs, n_instances, 
n_filters 118 | 119 | return means 120 | 121 | 122 | 123 | def calculate_variance_term(pred, gt, means, n_objects, delta_v, norm=2): 124 | """pred: bs, height * width, n_filters 125 | gt: bs, height * width, n_instances 126 | means: bs, n_instances, n_filters""" 127 | 128 | bs, n_loc, n_filters = pred.size() 129 | n_instances = gt.size(2) 130 | 131 | # bs, n_loc, n_instances, n_filters 132 | means = means.unsqueeze(1).expand(bs, n_loc, n_instances, n_filters) 133 | # bs, n_loc, n_instances, n_filters 134 | pred = pred.unsqueeze(2).expand(bs, n_loc, n_instances, n_filters) 135 | # bs, n_loc, n_instances, n_filters 136 | gt = gt.unsqueeze(3).expand(bs, n_loc, n_instances, n_filters) 137 | 138 | _var = (torch.clamp(torch.norm((pred - means), norm, 3) - 139 | delta_v, min=0.0) ** 2) * gt[:, :, :, 0] 140 | 141 | var_term = 0.0 142 | for i in range(bs): 143 | if n_objects[i]: 144 | _var_sample = _var[i, :, :n_objects[i]] # n_loc, n_objects 145 | _gt_sample = gt[i, :, :n_objects[i], 0] # n_loc, n_objects 146 | 147 | var_term += torch.sum(_var_sample) / torch.sum(_gt_sample) 148 | var_term = var_term / bs 149 | 150 | return var_term 151 | 152 | def calculate_distance_term(means, n_objects, delta_d, norm=2, usegpu=True): 153 | """means: bs, n_instances, n_filters""" 154 | 155 | bs, n_instances, n_filters = means.size() 156 | 157 | dist_term = 0.0 158 | for i in range(bs): 159 | _n_objects_sample = n_objects[i] 160 | 161 | if _n_objects_sample <= 1: 162 | continue 163 | 164 | _mean_sample = means[i, : _n_objects_sample, :] # n_objects, n_filters 165 | means_1 = _mean_sample.unsqueeze(1).expand( 166 | _n_objects_sample, _n_objects_sample, n_filters) 167 | means_2 = means_1.permute(1, 0, 2) 168 | 169 | diff = means_1 - means_2 # n_objects, n_objects, n_filters 170 | 171 | _norm = torch.norm(diff, norm, 2) 172 | 173 | margin = delta_d * (1.0 - torch.eye(_n_objects_sample)) 174 | if usegpu: 175 | margin = margin.cuda() 176 | margin = Variable(margin) 177 | 178 | _dist_term_sample = torch.sum( 179 | torch.clamp(margin - _norm, min=0.0) ** 2) 180 | _dist_term_sample = _dist_term_sample / \ 181 | (_n_objects_sample * (_n_objects_sample - 1)) 182 | dist_term += _dist_term_sample 183 | 184 | dist_term = dist_term / bs 185 | 186 | return dist_term 187 | 188 | 189 | def calculate_regularization_term(means, n_objects, norm): 190 | """means: bs, n_instances, n_filters""" 191 | 192 | bs, n_instances, n_filters = means.size() 193 | 194 | reg_term = 0.0 195 | for i in range(bs): 196 | if n_objects[i]: 197 | _mean_sample = means[i, : n_objects[i], :] # n_objects, n_filters 198 | _norm = torch.norm(_mean_sample, norm, 1) 199 | reg_term += torch.mean(_norm) 200 | reg_term = reg_term / bs 201 | 202 | return reg_term -------------------------------------------------------------------------------- /hnet/lanenet_hnet.py: -------------------------------------------------------------------------------- 1 | #--coding:utf-8-- 2 | import torch.nn as nn 3 | import torch 4 | import math 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | "3x3 convolution with padding" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=1, bias=False) 10 | 11 | 12 | class conv2DBatchNormRelu(nn.Module): 13 | def __init__(self, in_channels, n_filters, stride=1, k_size=3, padding=1, bias=True, dilation=1, with_bn=True): 14 | super(conv2DBatchNormRelu, self).__init__() 15 | 16 | conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, 17 | padding=padding, stride=stride, bias=bias, 
dilation=dilation) 18 | 19 | if with_bn: 20 | self.cbr_unit = nn.Sequential(conv_mod, 21 | nn.BatchNorm2d(int(n_filters)), 22 | nn.ReLU(inplace=True),) 23 | else: 24 | self.cbr_unit = nn.Sequential(conv_mod, 25 | nn.ReLU(inplace=True),) 26 | 27 | def forward(self, inputs): 28 | outputs = self.cbr_unit(inputs) 29 | return outputs 30 | 31 | 32 | class sharedBottom(nn.Module): 33 | def __init__(self,): 34 | super(sharedBottom, self).__init__() 35 | self.conv1 = conv2DBatchNormRelu(3, 16, 2) 36 | self.conv2a1 = conv2DBatchNormRelu(16, 16, 2) 37 | self.conv2a2 = conv2DBatchNormRelu(16,8) 38 | self.conv2a3 = conv2DBatchNormRelu(8,4) 39 | self.conv2a4 = conv2DBatchNormRelu(4,4) 40 | self.conv2a_strided = conv2DBatchNormRelu(32,32,2) 41 | self.conv3 = conv2DBatchNormRelu(32,32,2) 42 | self.conv4 = conv2DBatchNormRelu(32,32,1) 43 | self.conv6 = conv2DBatchNormRelu(32,64,2) 44 | self.conv8 = conv2DBatchNormRelu(64,64,1) 45 | self.conv9 = conv2DBatchNormRelu(64,128,2) 46 | self.conv11 = conv2DBatchNormRelu(128,128,1) 47 | self.conv11_1 = conv2DBatchNormRelu(128,32,1) 48 | self.conv11_2 = conv2DBatchNormRelu(128,32,1) 49 | self.conv11_3 = conv2DBatchNormRelu(128,32,1) 50 | self.conv11_4 = conv2DBatchNormRelu(32,64,1) 51 | self.conv11_6 = conv2DBatchNormRelu(32,64,1) 52 | self.conv11_5 = conv2DBatchNormRelu(64,128,1) 53 | 54 | def forward(self, x): 55 | x = self.conv1(x) 56 | x1 = self.conv2a1(x) 57 | x2 = self.conv2a2(x1) 58 | x3 = self.conv2a3(x2) 59 | x4 = self.conv2a4(x3) 60 | x = torch.cat([x1, x2, x3, x4], dim = 1) 61 | x = self.conv2a_strided(x) 62 | x = self.conv3(x) 63 | x = self.conv4(x) 64 | x = self.conv6(x) 65 | x = self.conv8(x) 66 | x = self.conv9(x) 67 | x = self.conv11(x) 68 | x1= self.conv11_1(x) 69 | x2= self.conv11_2(x) 70 | x3= self.conv11_3(x) 71 | x4= self.conv11_4(x3) 72 | x6= self.conv11_6(x2) 73 | x5= self.conv11_5(x4) 74 | x = torch.cat([x1, x5, x6], dim = 1) 75 | return x 76 | 77 | class laneNet(nn.Module): 78 | def __init__(self,): 79 | super(laneNet, self).__init__() 80 | self.conv11_7 = conv2DBatchNormRelu(224,128,1) 81 | self.conv11_8 = conv2DBatchNormRelu(128,128,1) 82 | self.conv11_9 = conv2DBatchNormRelu(128,128,1) 83 | self.conv12 = conv2DBatchNormRelu(128,16,1) 84 | self.conv13 = conv2DBatchNormRelu(16,8,1) 85 | self.conv14 = nn.Conv2d(8, 2, 1,stride = 1,padding = 0, bias=True) 86 | def forward(self, x): 87 | x = self.conv11_7(x) 88 | x = self.conv11_8(x) 89 | x = self.conv11_9(x) 90 | x = nn.Upsample(size=(45,53),mode='bilinear')(x) 91 | x = self.conv12(x) 92 | x = nn.Upsample(size=(177,209),mode='bilinear')(x) 93 | x = self.conv13(x) 94 | x = self.conv14(x) 95 | return x 96 | 97 | class clusterNet(nn.Module): 98 | def __init__(self,): 99 | super(clusterNet, self).__init__() 100 | self.conv11_7 = conv2DBatchNormRelu(224,128,1) 101 | self.conv11_8 = conv2DBatchNormRelu(128,128,1) 102 | self.conv11_9 = conv2DBatchNormRelu(128,128,1) 103 | self.conv12 = conv2DBatchNormRelu(128,16,1) 104 | self.conv13 = conv2DBatchNormRelu(16,8,1) 105 | self.deconv1 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, bias=True) 106 | self.deconv2 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, bias=True) 107 | self.deconv3 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, bias=True) 108 | self.deconv4 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, bias=True) 109 | self.conv14 = nn.Conv2d(8, 4, 1,stride = 1,padding = 0, bias=True) 110 | def forward(self, x): 111 | x = self.conv11_7(x) 112 | x = self.deconv1(x) 113 | x 
= self.conv11_8(x)
114 |         x = self.deconv2(x)
115 |         x = self.conv11_9(x)
116 |         x = self.deconv3(x)
117 |         x = self.conv12(x)
118 |         x = self.deconv4(x)
119 |         x = self.conv13(x)
120 |         x = self.conv14(x)
121 |         return x
122 | 
123 | class insClsNet(nn.Module):
124 |     def __init__(self,):
125 |         super(insClsNet, self).__init__()
126 |         self.conv11_7 = conv2DBatchNormRelu(224,128,1)
127 |         self.conv11_8 = conv2DBatchNormRelu(128,128,1)
128 |         self.conv11_9 = conv2DBatchNormRelu(128,128,1)
129 |         self.conv12 = conv2DBatchNormRelu(128,64,1)
130 |         self.conv13 = conv2DBatchNormRelu(64,64,1)
131 |         self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
132 |         self.ins_cls_out = nn.Sequential()
133 |         self.ins_cls_out.add_module('linear', nn.Linear(64, 1))
134 |         self.ins_cls_out.add_module('sigmoid', nn.Sigmoid())
135 | 
136 | 
137 |     def forward(self, x):
138 |         x = self.conv11_7(x)
139 |         x = self.conv11_8(x)
140 |         x = self.conv11_9(x)
141 |         x = self.conv12(x)
142 |         x = self.conv13(x)
143 |         x = self.global_pool(x)
144 |         x = x.squeeze(3).squeeze(2)
145 |         x_ins_cls = self.ins_cls_out(x)
146 |         return x_ins_cls
147 | 
148 | class hNet(nn.Module):
149 |     def __init__(self,):
150 |         super(hNet, self).__init__()
151 |         self.conv11_7 = conv2DBatchNormRelu(224,128,1)
152 |         self.conv11_8 = conv2DBatchNormRelu(128,128,1)
153 |         self.conv11_9 = conv2DBatchNormRelu(128,128,1)
154 |         self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
155 |         self.h_cls_out = nn.Sequential()
156 |         self.h_cls_out.add_module('linear1', nn.Linear(128, 256))
157 |         self.h_cls_out.add_module('bn', torch.nn.BatchNorm1d(256))
158 |         self.h_cls_out.add_module('relu', torch.nn.ReLU())
159 |         self.h_cls_out.add_module('linear2', nn.Linear(256, 6))  # six homography coefficients; H[2,2] is fixed to 1
160 | 
161 | 
162 |     def forward(self, x):
163 |         x = self.conv11_7(x)
164 |         x = self.conv11_8(x)
165 |         x = self.conv11_9(x)
166 |         x = self.global_pool(x)
167 |         x = x.squeeze(3).squeeze(2)
168 |         x_h_cls = self.h_cls_out(x)
169 |         return x_h_cls
170 | 
171 | class Net(nn.Module):
172 |     def __init__(self):
173 |         # Subclasses of nn.Module must call the parent constructor in their own constructor;
174 |         # the line below is equivalent to nn.Module.__init__(self)
175 |         super(Net, self).__init__()
176 |         self.bottom = sharedBottom()
177 |         self.sem_seg = laneNet()
178 |         self.ins_seg = clusterNet()
179 |         self.ins_cls = insClsNet()
180 |         self.hnet = hNet()
181 |         self._initialize_weights()
182 | 
183 |     def _initialize_weights(self):
184 |         for m in self.modules():
185 |             if isinstance(m, nn.Conv2d):
186 |                 nn.init.kaiming_normal(m.weight.data)
187 |                 if m.bias is not None:
188 |                     m.bias.data.zero_()
189 |             elif isinstance(m, nn.BatchNorm2d):
190 |                 m.weight.data.fill_(1)
191 |                 m.bias.data.zero_()
192 | 
193 |     def forward(self, x):
194 |         x = self.bottom(x)
195 |         x_sem = self.sem_seg(x)
196 |         x_ins = self.ins_seg(x)
197 |         x_cls = self.ins_cls(x)
198 |         x_hnet = self.hnet(x)
199 |         return x_sem, x_ins, x_cls, x_hnet
200 | 
201 | #net = Net()
202 | #print(net)
203 | 
--------------------------------------------------------------------------------
/model/main.py:
--------------------------------------------------------------------------------
 1 | import shutil
 2 | import time
 3 | import argparse
 4 | import numpy as np
 5 | import torch
 6 | import torch.nn as nn
 7 | import torch.nn.parallel
 8 | import torch.backends.cudnn as cudnn
 9 | import torch.optim
10 | from torchvision import transforms
11 | from torch.utils.data import DataLoader
12 | from loss import *
13 | from dataloader import *
14 | import lanenet
15 | import os
16 | #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
17 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
18 | parser =
argparse.ArgumentParser() 19 | parser.add_argument('--dir_path', default='') 20 | parser.add_argument('--workers', default=4, type=int, metavar='N', 21 | help='number of data loading workers (default: 8)') 22 | parser.add_argument('--epochs', default=20, type=int, metavar='N', 23 | help='number of total epochs to run') 24 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 25 | help='manual epoch number (useful on restarts)') 26 | parser.add_argument('--batch_size', default=128, type=int, 27 | metavar='N', help='mini-batch size') 28 | parser.add_argument('--new_length', default=705, type=int) 29 | parser.add_argument('--new_width', default=833, type=int) 30 | parser.add_argument('--label_length', default=177, type=int) 31 | parser.add_argument('--label_width', default=209, type=int) 32 | parser.add_argument('--lr', '--learning-rate', default=0.05, type=float, 33 | metavar='LR', help='initial learning rate') 34 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 35 | help='momentum') 36 | parser.add_argument('--weight-decay', default=1e-4, type=float, 37 | metavar='W', help='weight decay (default: 1e-4)') 38 | parser.add_argument('--print-freq', default=1, type=int, 39 | metavar='N', help='print frequency (default: 20)') 40 | parser.add_argument('--save-freq', default=1, type=int, 41 | metavar='N', help='save frequency (default: 200)') 42 | parser.add_argument('--resume', default='checkpoints', type=str, metavar='PATH', 43 | help='path to latest checkpoint (default: none)') 44 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 45 | help='evaluate model on validation set') 46 | 47 | best_prec = 0 48 | 49 | 50 | def main(): 51 | global args, best_prec 52 | args = parser.parse_args() 53 | print ("Build model ...") 54 | model = lanenet.Net() 55 | model = torch.nn.DataParallel(model).cuda() 56 | #model.apply(weights_init) 57 | #params = torch.load('checkpoints/old.pth.tar') 58 | #model.load_state_dict(params['state_dict']) 59 | if not os.path.exists(args.resume): 60 | os.makedirs(args.resume) 61 | print("Saving everything to directory %s." 
% (args.resume)) 62 | 63 | # define loss function (criterion) and optimizer 64 | criterion = cross_entropy2d 65 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 66 | momentum=args.momentum, 67 | weight_decay=args.weight_decay) 68 | cudnn.benchmark = True 69 | 70 | # data transform 71 | 72 | train_data = MyDataset('/mnt/lustre/share/dingmingyu/new_list_lane.txt', args.dir_path, args.new_width, args.new_length,args.label_width,args.label_length) 73 | 74 | train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,pin_memory=True) 75 | 76 | 77 | for epoch in range(args.start_epoch, args.epochs): 78 | print 'epoch: ' + str(epoch + 1) 79 | 80 | # train for one epoch 81 | train(train_loader, model, criterion, optimizer, epoch) 82 | 83 | # evaluate on validation set 84 | 85 | # remember best prec and save checkpoint 86 | 87 | if (epoch + 1) % args.save_freq == 0: 88 | checkpoint_name = "%03d_%s" % (epoch + 1, "checkpoint.pth.tar") 89 | save_checkpoint({ 90 | 'epoch': epoch + 1, 91 | 'state_dict': model.state_dict(), 92 | 'optimizer' : optimizer.state_dict(), 93 | }, checkpoint_name, args.resume) 94 | 95 | 96 | 97 | def train(train_loader, model, criterion, optimizer, epoch): 98 | batch_time = AverageMeter() 99 | data_time = AverageMeter() 100 | losses_seg = AverageMeter() 101 | losses_ins = AverageMeter() 102 | losses_cls = AverageMeter() 103 | lrs = AverageMeter() 104 | # switch to train mode 105 | model.train() 106 | weight_cus = torch.ones(2) 107 | weight_cus[1] = 2 108 | weight_cus = weight_cus.cuda() 109 | end = time.time() 110 | for i, (input, target_ins, n_objects) in enumerate(train_loader): 111 | # measure data loading time 112 | data_time.update(time.time() - end) 113 | lr = adjust_learning_rate(optimizer, epoch*len(train_loader)+i, args.epochs*len(train_loader)) 114 | lrs.update(lr) 115 | input = input.float().cuda() 116 | target01 = target_ins.sum(1) 117 | target01 = target01.long().cuda() 118 | 119 | 120 | target_ins = target_ins.float().cuda() 121 | #n_objects = n_objects.long().cuda() 122 | 123 | input_var = torch.autograd.Variable(input) 124 | target_var = torch.autograd.Variable(target01) 125 | target_ins_var = torch.autograd.Variable(target_ins) 126 | #n_objects_var = torch.autograd.Variable(n_objects) 127 | 128 | n_objects_normalized = n_objects.float().cuda() / 4 129 | n_objects_normalized_var = torch.autograd.Variable(n_objects_normalized) 130 | 131 | 132 | x_sem, x_ins, x_cls= model(input_var) 133 | 134 | criterion_mse = torch.nn.MSELoss().cuda() 135 | loss_cls = criterion_mse(x_cls, n_objects_normalized_var) 136 | #print x_sem.size(), x_ins.size(), target_ins_var.size(), n_objects_var.size() (256L, 2L, 177L, 209L) (256L, 4L, 177L, 209L) (256L, 4L, 177L, 209L) (256L,) 137 | loss_ins = discriminative_loss(x_ins, target_ins_var, n_objects, 4, usegpu=True) 138 | #print x_sem.size(),target_var.size() 139 | loss_seg = criterion(x_sem, target_var, weight= weight_cus, size_average=True) 140 | loss = loss_seg + loss_ins + loss_cls 141 | 142 | losses_seg.update(loss_seg.data[0], input.size(0)) 143 | losses_ins.update(loss_ins.data[0], input.size(0)) 144 | losses_cls.update(loss_cls.data[0], input.size(0)) 145 | # compute gradient and do SGD step 146 | optimizer.zero_grad() 147 | loss.backward() 148 | optimizer.step() 149 | 150 | # measure elapsed time 151 | batch_time.update(time.time() - end) 152 | end = time.time() 153 | 154 | if i % args.print_freq == 0: 155 | print('Epoch: [{0}][{1}/{2}]\t' 156 | 'Time 
{batch_time.val:.3f} ({batch_time.avg:.3f})\t'
157 |                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
158 |                   'Loss_seg {loss_seg.val:.4f} ({loss_seg.avg:.4f})\t'
159 |                   'Loss_ins {loss_ins.val:.4f} ({loss_ins.avg:.4f})\t'
160 |                   'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t'
161 |                   'Lr {lr.val:.5f} ({lr.avg:.5f})\t'.format(
162 |                    epoch, i, len(train_loader), batch_time=batch_time,
163 |                    data_time=data_time, loss_seg=losses_seg, loss_ins=losses_ins, loss_cls=losses_cls, lr=lrs))
164 | 
165 | 
166 | 
167 | def save_checkpoint(state, filename, resume_path):
168 |     cur_path = os.path.join(resume_path, filename)
169 |     torch.save(state, cur_path)
170 | 
171 | 
172 | class AverageMeter(object):
173 |     """Computes and stores the average and current value"""
174 |     def __init__(self):
175 |         self.reset()
176 | 
177 |     def reset(self):
178 |         self.val = 0
179 |         self.avg = 0
180 |         self.sum = 0
181 |         self.count = 0
182 | 
183 |     def update(self, val, n=1):
184 |         self.val = val
185 |         self.sum += val * n
186 |         self.count += n
187 |         self.avg = self.sum / self.count
188 | 
189 | 
190 | def adjust_learning_rate(optimizer, curr_iter, max_iter, power=0.9):
191 |     lr = args.lr * (1 - float(curr_iter)/max_iter)**power  # polynomial ("poly") decay
192 |     for param_group in optimizer.param_groups:
193 |         param_group['lr'] = lr
194 | 
195 |     return lr
196 | 
197 | if __name__ == '__main__':
198 |     main()
199 | 
--------------------------------------------------------------------------------
/hnet/main_thinning.py:
--------------------------------------------------------------------------------
 1 | import shutil
 2 | import time
 3 | import argparse
 4 | import numpy as np
 5 | import torch
 6 | import torch.nn as nn
 7 | import torch.nn.parallel
 8 | import torch.backends.cudnn as cudnn
 9 | import torch.optim
10 | from torchvision import transforms
11 | from torch.utils.data import DataLoader
12 | from loss import *
13 | from dataloader_thinning import *  # this dataset also yields the thinned H-Net label
14 | import lanenet_hnet as lanenet  # the Net here returns the extra H-Net output
15 | import os
16 | #os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
17 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--dir_path', default='')
20 | parser.add_argument('--workers', default=4, type=int, metavar='N',
21 |                     help='number of data loading workers (default: 8)')
22 | parser.add_argument('--epochs', default=20, type=int, metavar='N',
23 |                     help='number of total epochs to run')
24 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
25 |                     help='manual epoch number (useful on restarts)')
26 | parser.add_argument('--batch_size', default=128, type=int,
27 |                     metavar='N', help='mini-batch size')
28 | parser.add_argument('--new_length', default=705, type=int)
29 | parser.add_argument('--new_width', default=833, type=int)
30 | parser.add_argument('--label_length', default=177, type=int)
31 | parser.add_argument('--label_width', default=209, type=int)
32 | parser.add_argument('--lr', '--learning-rate', default=0.05, type=float,
33 |                     metavar='LR', help='initial learning rate')
34 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
35 |                     help='momentum')
36 | parser.add_argument('--weight-decay', default=1e-4, type=float,
37 |                     metavar='W', help='weight decay (default: 1e-4)')
38 | parser.add_argument('--print-freq', default=1, type=int,
39 |                     metavar='N', help='print frequency (default: 20)')
40 | parser.add_argument('--save-freq', default=1, type=int,
41 |                     metavar='N', help='save frequency (default: 200)')
42 | parser.add_argument('--resume', default='checkpoints', type=str, metavar='PATH',
43 |                     help='path
import shutil
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torchvision import transforms
from torch.utils.data import DataLoader
from loss import *  # cross_entropy2d / discriminative_loss (defined in model/loss.py)
from dataloader_thinning import *  # this directory has no dataloader.py; presumably the thinning loader is intended
import lanenet_hnet as lanenet  # no hnet/lanenet.py exists; the four-output HNet model is lanenet_hnet.py
import os
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"

parser = argparse.ArgumentParser()
parser.add_argument('--dir_path', default='')
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=20, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--batch_size', default=128, type=int,
                    metavar='N', help='mini-batch size')
parser.add_argument('--new_length', default=705, type=int,
                    help='network input height')
parser.add_argument('--new_width', default=833, type=int,
                    help='network input width')
parser.add_argument('--label_length', default=177, type=int,
                    help='label height')
parser.add_argument('--label_width', default=209, type=int,
                    help='label width')
parser.add_argument('--lr', '--learning-rate', default=0.05, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', default=1, type=int,
                    metavar='N', help='print frequency (default: 1)')
parser.add_argument('--save-freq', default=1, type=int,
                    metavar='N', help='save frequency in epochs (default: 1)')
parser.add_argument('--resume', default='checkpoints', type=str, metavar='PATH',
                    help='checkpoint directory (default: checkpoints)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')

best_prec = 0


def main():
    global args, best_prec
    args = parser.parse_args()
    print("Build model ...")
    model = lanenet.Net()
    model = torch.nn.DataParallel(model).cuda()
    #model.apply(weights_init)
    #params = torch.load('checkpoints/old.pth.tar')
    #model.load_state_dict(params['state_dict'])
    if not os.path.exists(args.resume):
        os.makedirs(args.resume)
    print("Saving everything to directory %s." % (args.resume))

    # define loss function (criterion) and optimizer
    criterion = cross_entropy2d
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True

    train_data = MyDataset('new_list_lane.txt', args.dir_path, args.new_width,
                           args.new_length, args.label_width, args.label_length)

    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.workers, pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        print('epoch: ' + str(epoch + 1))

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set

        # remember best prec and save checkpoint

        if (epoch + 1) % args.save_freq == 0:
            checkpoint_name = "%03d_%s" % (epoch + 1, "checkpoint.pth.tar")
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, checkpoint_name, args.resume)


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_seg = AverageMeter()
    losses_ins = AverageMeter()
    losses_cls = AverageMeter()
    losses_hnet = AverageMeter()
    lrs = AverageMeter()
    # switch to train mode
    model.train()
    # class weights for the binary segmentation CE: lane pixels count double
    weight_cus = torch.ones(2)
    weight_cus[1] = 2
    weight_cus = weight_cus.cuda()
    end = time.time()
    for i, (input, target_ins, n_objects, thinning_gt) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        lr = adjust_learning_rate(optimizer, epoch * len(train_loader) + i,
                                  args.epochs * len(train_loader))
        lrs.update(lr)
        input = input.float().cuda()
        # collapse the per-instance masks into a binary lane/background target
        target01 = target_ins.sum(1)
        target01 = target01.long().cuda()

        target_ins = target_ins.float().cuda()

        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target01)
        target_ins_var = torch.autograd.Variable(target_ins)

        # regress the lane count, normalized by the maximum of 4 lanes
        n_objects_normalized = n_objects.float().cuda() / 4
        n_objects_normalized_var = torch.autograd.Variable(n_objects_normalized)
        x_sem, x_ins, x_cls, x_hnet = model(input_var)
        criterion_mse = torch.nn.MSELoss().cuda()
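
        # ---- HNet fitting loss -----------------------------------------
        # The 6 HNet outputs parameterize a homography H (H[1,0] = H[2,0] = 0,
        # H[2,2] = 1) that maps lane pixels into a birds-eye-like view. There,
        # each lane is fitted with a closed-form least-squares parabola
        #   x' = w0*y'^2 + w1*y' + w2,  with  W = (Y Y^T)^-1 Y x',
        # the fitted points are mapped back through H^-1, and the loss is the
        # MSE between the reprojected and the original x coordinates, so H is
        # trained end-to-end to make every lane well fit by a polynomial.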
        ############################################## HNET ##################################################

        n, c = x_hnet.size()
        thinning_gt = thinning_gt.numpy()
        loss_hnet = torch.autograd.Variable(torch.zeros(1).float().cuda())
        avg_size = 0
        for n_index in range(n):
            # assemble the 3x3 homography from the 6 predicted coefficients
            H = torch.autograd.Variable(torch.zeros((3, 3)).float().cuda())
            H[0, 0] = x_hnet[n_index, 0]
            H[0, 1] = x_hnet[n_index, 1]
            H[0, 2] = x_hnet[n_index, 2]
            H[1, 1] = x_hnet[n_index, 3]
            H[1, 2] = x_hnet[n_index, 4]
            H[2, 1] = x_hnet[n_index, 5]
            H[2, 2] = 1
            if (thinning_gt[n_index] > 0).any():
                for index in range(1, 5):
                    list_y, list_x = np.where(thinning_gt[n_index] == index)
                    if len(list_y):
                        avg_size += 1
                        old_x = torch.autograd.Variable(torch.from_numpy(list_x).float().cuda())
                        # homogeneous pixel coordinates of this lane's thinned points
                        p = np.vstack((list_x, list_y, np.ones(len(list_y)))).astype(np.float64)
                        P = torch.autograd.Variable(torch.from_numpy(p).float().cuda())
                        P_trans = torch.mm(H, P)
                        list_x = P_trans[0]
                        list_y = P_trans[1]
                        scale = P_trans[2]
                        # closed-form least-squares fit of x = w0*y^2 + w1*y + w2
                        Y = torch.stack((list_y ** 2, list_y,
                                         torch.autograd.Variable(torch.ones(list_y.size()[0]).float().cuda())), dim=0)
                        W = torch.inverse(Y.mm(Y.t()))
                        W = W.mm(Y).mm(list_x.unsqueeze(1))
                        new_x = W[0] * list_y ** 2 + W[1] * list_y + W[2]
                        # map the fitted points back to the image and compare
                        new_P = torch.stack((new_x, list_y, scale), dim=0)
                        trans_x = torch.inverse(H).mm(new_P)[0]
                        mse_fit = criterion_mse(trans_x / 50, old_x / 50)
                        if mse_fit.data.cpu().numpy() < 1:  # skip diverged fits
                            loss_hnet += mse_fit

        if avg_size:  # avoid dividing by zero when no lane points survive thinning
            loss_hnet /= avg_size
        print(avg_size)  # debug: number of lane segments fitted this batch
        loss_cls = criterion_mse(x_cls, n_objects_normalized_var)
        loss_ins = discriminative_loss(x_ins, target_ins_var, n_objects, 4, usegpu=True)
        loss_seg = criterion(x_sem, target_var, weight=weight_cus, size_average=True)
        loss = loss_seg + loss_ins + loss_cls + loss_hnet

        losses_seg.update(loss_seg.data[0], input.size(0))
        losses_ins.update(loss_ins.data[0], input.size(0))
        losses_cls.update(loss_cls.data[0], input.size(0))
        losses_hnet.update(loss_hnet.data[0], input.size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss_seg {loss_seg.val:.4f} ({loss_seg.avg:.4f})\t'
                  'Loss_ins {loss_ins.val:.4f} ({loss_ins.avg:.4f})\t'
                  'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t'
                  'Loss_hnet {loss_hnet.val:.4f} ({loss_hnet.avg:.4f})\t'
                  'Lr {lr.val:.5f} ({lr.avg:.5f})\t'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss_seg=losses_seg, loss_ins=losses_ins,
                      loss_cls=losses_cls, loss_hnet=losses_hnet, lr=lrs))


def save_checkpoint(state, filename, resume_path):
    cur_path = os.path.join(resume_path, filename)
    torch.save(state, cur_path)


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, curr_iter, max_iter, power=0.9):
    lr = args.lr * (1 - float(curr_iter) / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/hnet/main_parallel.py:
--------------------------------------------------------------------------------
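# Variant of main_thinning.py that adds a lane-width consistency loss: the
# two ego-lane boundaries are mapped to the birds-eye view by HNet, their
# horizontal distance is sampled at 10 heights, and deviations from a
# constant width are penalized.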
import shutil
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torchvision import transforms
from torch.utils.data import DataLoader
from loss import *  # cross_entropy2d / discriminative_loss (defined in model/loss.py)
from dataloader_parallel import *  # this directory has no dataloader.py; presumably the parallel loader is intended
import lanenet_hnet as lanenet  # no hnet/lanenet.py exists; the four-output HNet model is lanenet_hnet.py
import os
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"

parser = argparse.ArgumentParser()
parser.add_argument('--dir_path', default='')
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=20, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--batch_size', default=128, type=int,
                    metavar='N', help='mini-batch size')
parser.add_argument('--new_length', default=705, type=int,
                    help='network input height')
parser.add_argument('--new_width', default=833, type=int,
                    help='network input width')
parser.add_argument('--label_length', default=177, type=int,
                    help='label height')
parser.add_argument('--label_width', default=209, type=int,
                    help='label width')
parser.add_argument('--lr', '--learning-rate', default=0.05, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', default=1, type=int,
                    metavar='N', help='print frequency (default: 1)')
parser.add_argument('--save-freq', default=1, type=int,
                    metavar='N', help='save frequency in epochs (default: 1)')
parser.add_argument('--resume', default='checkpoints', type=str, metavar='PATH',
                    help='checkpoint directory (default: checkpoints)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')

best_prec = 0


def main():
    global args, best_prec
    args = parser.parse_args()
    print("Build model ...")
    model = lanenet.Net()
    model = torch.nn.DataParallel(model).cuda()
    #model.apply(weights_init)
    #params = torch.load('checkpoints/old.pth.tar')
    #model.load_state_dict(params['state_dict'])
    if not os.path.exists(args.resume):
        os.makedirs(args.resume)
    print("Saving everything to directory %s." % (args.resume))

    # define loss function (criterion) and optimizer
    criterion = cross_entropy2d
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    cudnn.benchmark = True

    train_data = MyDataset('/mnt/lustre/share/dingmingyu/new_list_lane.txt', args.dir_path,
                           args.new_width, args.new_length, args.label_width, args.label_length)

    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.workers, pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        print('epoch: ' + str(epoch + 1))

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set

        # remember best prec and save checkpoint

        if (epoch + 1) % args.save_freq == 0:
            checkpoint_name = "%03d_%s" % (epoch + 1, "checkpoint.pth.tar")
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, checkpoint_name, args.resume)

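
# The objective optimized below combines five terms:
#   loss = loss_seg + loss_ins + loss_cls + loss_hnet + loss_width
# (binary segmentation CE, discriminative instance embedding, lane-count
# regression, HNet reprojection fit, and birds-eye width consistency).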
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_seg = AverageMeter()
    losses_ins = AverageMeter()
    losses_cls = AverageMeter()
    losses_hnet = AverageMeter()
    losses_width = AverageMeter()
    lrs = AverageMeter()
    # switch to train mode
    model.train()
    # class weights for the binary segmentation CE: lane pixels count double
    weight_cus = torch.ones(2)
    weight_cus[1] = 2
    weight_cus = weight_cus.cuda()
    end = time.time()
    for i, (input, target_ins, n_objects, thinning_gt) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        lr = adjust_learning_rate(optimizer, epoch * len(train_loader) + i,
                                  args.epochs * len(train_loader))
        lrs.update(lr)
        input = input.float().cuda()
        # collapse the per-instance masks into a binary lane/background target
        target01 = target_ins.sum(1)
        target01 = target01.long().cuda()

        target_ins = target_ins.float().cuda()

        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target01)
        target_ins_var = torch.autograd.Variable(target_ins)

        # regress the lane count, normalized by the maximum of 4 lanes
        n_objects_normalized = n_objects.float().cuda() / 4
        n_objects_normalized_var = torch.autograd.Variable(n_objects_normalized)
        x_sem, x_ins, x_cls, x_hnet = model(input_var)
        criterion_mse = torch.nn.MSELoss().cuda()
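
        # ---- HNet + width losses ---------------------------------------
        # As in main_thinning.py, each lane is mapped through the predicted
        # homography H and fitted with a least-squares parabola; additionally,
        # the parabolas of lane labels 2 and 3 (the two ego-lane boundaries,
        # assuming labels run left to right) are sampled at 10 heights and the
        # resulting widths are pushed toward a constant value.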
        ############################################## HNET ##################################################

        n, c = x_hnet.size()
        thinning_gt = thinning_gt.numpy()
        loss_hnet = torch.autograd.Variable(torch.zeros(1).float().cuda())
        avg_size = 0
        loss_width = torch.autograd.Variable(torch.zeros(1).float().cuda())
        avg_width = 0
        for n_index in range(n):
            # assemble the 3x3 homography from the 6 predicted coefficients
            H = torch.autograd.Variable(torch.zeros((3, 3)).float().cuda())
            H[0, 0] = x_hnet[n_index, 0]
            H[0, 1] = x_hnet[n_index, 1]
            H[0, 2] = x_hnet[n_index, 2]
            H[1, 1] = x_hnet[n_index, 3]
            H[1, 2] = x_hnet[n_index, 4]
            H[2, 1] = x_hnet[n_index, 5]
            H[2, 2] = 1
            if (thinning_gt[n_index] > 0).any():
                bird_list = []
                for index in range(1, 5):
                    list_y, list_x = np.where(thinning_gt[n_index] == index)
                    if len(list_y):
                        avg_size += 1
                        old_x = torch.autograd.Variable(torch.from_numpy(list_x).float().cuda())
                        # homogeneous pixel coordinates of this lane's thinned points
                        p = np.vstack((list_x, list_y, np.ones(len(list_y)))).astype(np.float64)
                        P = torch.autograd.Variable(torch.from_numpy(p).float().cuda())
                        P_trans = torch.mm(H, P)
                        list_x = P_trans[0]
                        list_y = P_trans[1]
                        scale = P_trans[2]
                        # closed-form least-squares fit of x = w0*y^2 + w1*y + w2
                        Y = torch.stack((list_y ** 2, list_y,
                                         torch.autograd.Variable(torch.ones(list_y.size()[0]).float().cuda())), dim=0)
                        W = torch.inverse(Y.mm(Y.t()))
                        W = W.mm(Y).mm(list_x.unsqueeze(1))
                        new_x = W[0] * list_y ** 2 + W[1] * list_y + W[2]
                        # map the fitted points back to the image and compare
                        new_P = torch.stack((new_x, list_y, scale), dim=0)
                        trans_x = torch.inverse(H).mm(new_P)[0]
                        if index == 2 or index == 3:
                            # keep the fitted ego-lane boundaries for the width term
                            bird_list.append((W, list_y))
                        mse_fit = criterion_mse(trans_x / 50, old_x / 50)
                        if mse_fit.data.cpu().numpy() < 1:  # skip diverged fits
                            loss_hnet += mse_fit
                if len(bird_list) == 2:
                    # sample both parabolas at 10 evenly spaced birds-eye heights
                    miny = torch.min(bird_list[0][1].min(), bird_list[1][1].min())
                    maxy = torch.max(bird_list[0][1].max(), bird_list[1][1].max())
                    width_list = []
                    step = float(((maxy - miny) / 10).data.cpu().numpy())
                    for yy in torch.arange(float(miny.data.cpu().numpy()),
                                           float(maxy.data.cpu().numpy()), step):
                        a = bird_list[0][0][0] * yy ** 2 + bird_list[0][0][1] * yy + bird_list[0][0][2]
                        b = bird_list[1][0][0] * yy ** 2 + bird_list[1][0][1] * yy + bird_list[1][0][2]
                        width_list.append(a - b)
                    if len(width_list) >= 10:
                        avg_width += 1
                        # normalize the sampled widths by their mean; a constant
                        # width gives a vector of ones
                        width_variable = torch.cat(width_list[:10], 0)
                        width_variable = width_variable / width_variable.mean()
                        mean_variable = torch.autograd.Variable(torch.ones(10).float().cuda())
                        loss_width += criterion_mse(width_variable, mean_variable) / 3  # down-weighted by 1/3

        if avg_width:
            loss_width /= avg_width
        if avg_size:  # avoid dividing by zero when no lane points survive thinning
            loss_hnet /= avg_size
        loss_cls = criterion_mse(x_cls, n_objects_normalized_var)
        loss_ins = discriminative_loss(x_ins, target_ins_var, n_objects, 4, usegpu=True)
        loss_seg = criterion(x_sem, target_var, weight=weight_cus, size_average=True)
        loss = loss_seg + loss_ins + loss_cls + loss_hnet + loss_width

        losses_seg.update(loss_seg.data[0], input.size(0))
        losses_ins.update(loss_ins.data[0], input.size(0))
        losses_cls.update(loss_cls.data[0], input.size(0))
        losses_hnet.update(loss_hnet.data[0], input.size(0))
        losses_width.update(loss_width.data[0], input.size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss_seg {loss_seg.val:.4f} ({loss_seg.avg:.4f})\t'
                  'Loss_ins {loss_ins.val:.4f} ({loss_ins.avg:.4f})\t'
                  'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t'
                  'Loss_hnet {loss_hnet.val:.4f} ({loss_hnet.avg:.4f})\t'
                  'Loss_width {loss_width.val:.4f} ({loss_width.avg:.4f})\t'
                  'Lr {lr.val:.5f} ({lr.avg:.5f})\t'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss_seg=losses_seg, loss_ins=losses_ins,
                      loss_cls=losses_cls, loss_hnet=losses_hnet, loss_width=losses_width, lr=lrs))


def save_checkpoint(state, filename, resume_path):
    cur_path = os.path.join(resume_path, filename)
    torch.save(state, cur_path)


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, curr_iter, max_iter, power=0.9):
    lr = args.lr * (1 - float(curr_iter) / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------