├── LICENSE ├── README.md ├── config.py ├── dataset ├── CULane.py ├── Tusimple.py └── __init__.py ├── demo ├── demo.jpg └── demo_result.jpg ├── demo_test.py ├── experiments ├── exp0 │ └── cfg.json ├── exp10 │ └── cfg.json └── vgg_SCNN_DULR_w9 │ ├── cfg.json │ └── t7_to_pt.py ├── model.py ├── requirements.txt ├── test_CULane.py ├── test_tusimple.py ├── train.py └── utils ├── lane_evaluation ├── CULane │ ├── CMakeLists.txt │ ├── Run.sh │ ├── include │ │ ├── counter.hpp │ │ ├── hungarianGraph.hpp │ │ ├── lane_compare.hpp │ │ └── spline.hpp │ ├── src │ │ ├── counter.cpp │ │ ├── evaluate.cpp │ │ ├── lane_compare.cpp │ │ └── spline.cpp │ └── src_origin │ │ ├── counter.cpp │ │ ├── evaluate.cpp │ │ ├── lane_compare.cpp │ │ └── spline.cpp └── tusimple │ └── lane.py ├── lr_scheduler.py ├── prob2lines └── getLane.py ├── tensorboard.py └── transforms ├── __init__.py ├── data_augmentation.py └── transforms.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 HarryHan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SCNN lane detection in PyTorch

SCNN is a segmentation-based lane detection algorithm, described in ['Spatial As Deep: Spatial CNN for Traffic Scene Understanding'](https://arxiv.org/abs/1712.06080). The [official implementation]() is written in Lua Torch.

This repository contains a re-implementation in PyTorch.

### Updates

- 2019 / 08 / 14: Code refined, including more convenient test and evaluation scripts.
- 2019 / 08 / 12: Trained models on both datasets are provided.
- 2019 / 05 / 08: Evaluation is provided.
- 2019 / 04 / 23: Trained model converted from the [official t7 model](https://github.com/XingangPan/SCNN#Testing) is provided.


## Data preparation

### CULane

The dataset is available at [CULane](https://xingangpan.github.io/projects/CULane.html). Please download and unzip the files into one folder, referred to below as `CULane_path`. Then set the `CULane` entry in `config.py` to this path. Also set `data_dir` to `CULane_path` in `utils/lane_evaluation/CULane/Run.sh`.
```
CULane_path
├── driver_100_30frame
├── driver_161_90frame
├── driver_182_30frame
├── driver_193_90frame
├── driver_23_30frame
├── driver_37_30frame
├── laneseg_label_w16
├── laneseg_label_w16_test
└── list
```

**Note: absolute paths are encouraged.**

### Tusimple
The dataset is available [here](https://github.com/TuSimple/tusimple-benchmark/issues/3). Please download and unzip the files into one folder, referred to below as `Tusimple_path`. Then set the `Tusimple` entry in `config.py` to this path.
```
Tusimple_path
├── clips
├── label_data_0313.json
├── label_data_0531.json
├── label_data_0601.json
└── test_label.json
```

**Note: `seg_label` images and `gt.txt` list files, in the CULane dataset format, are generated the first time a `Tusimple` object is instantiated. This may take some time.**

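Before training, it can help to confirm that the unzipped folders match the layouts above. The following is a minimal sketch and not part of this repository: `missing_entries` and `EXPECTED` are hypothetical names, and the entry lists are simply copied from the two directory trees in this section.

```python
import os
import tempfile

# Entries copied from the directory trees above; this table is an assumption
# made for illustration, not something the repository defines.
EXPECTED = {
    "CULane": [
        "driver_100_30frame", "driver_161_90frame", "driver_182_30frame",
        "driver_193_90frame", "driver_23_30frame", "driver_37_30frame",
        "laneseg_label_w16", "laneseg_label_w16_test", "list",
    ],
    "Tusimple": [
        "clips", "label_data_0313.json", "label_data_0531.json",
        "label_data_0601.json", "test_label.json",
    ],
}


def missing_entries(root, dataset_name):
    """Return the expected top-level entries that are absent under `root`."""
    return [name for name in EXPECTED[dataset_name]
            if not os.path.exists(os.path.join(root, name))]


if __name__ == "__main__":
    # Demonstration on a temporary, deliberately incomplete folder.
    with tempfile.TemporaryDirectory() as root:
        os.makedirs(os.path.join(root, "clips"))
        print(missing_entries(root, "Tusimple"))
```

An empty list means every expected entry was found; in practice `root` would be the path configured in `config.py`.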

## Trained Model Provided

* A model trained on the CULane dataset can be converted from the [official implementation](https://github.com/XingangPan/SCNN#Testing); the original weights can be downloaded [here](https://drive.google.com/open?id=1Wv3r3dCYNBwJdKl_WPEfrEOt-XGaROKu). Please put the `vgg_SCNN_DULR_w9.t7` file into `experiments/vgg_SCNN_DULR_w9`.

  ```bash
  python experiments/vgg_SCNN_DULR_w9/t7_to_pt.py
  ```

  The converted model is saved to `experiments/vgg_SCNN_DULR_w9/vgg_SCNN_DULR_w9.pth`.

  **Note**: `torch.utils.serialization` is obsolete in PyTorch 1.0+. You can directly download **the converted model [here](https://drive.google.com/open?id=1bBdN3yhoOQBC9pRtBUxzeRrKJdF7uVTJ)**.

* My trained model on Tusimple can be downloaded [here](https://drive.google.com/open?id=1IwEenTekMt-t6Yr5WJU9_kv4d_Pegd_Q). Its configuration file is in `experiments/exp0`.

  | Accuracy | FP     | FN     |
  | -------- | ------ | ------ |
  | 94.16%   | 0.0735 | 0.0825 |

* My trained model on CULane can be downloaded [here](https://drive.google.com/open?id=1AZn23w8RbMh1P6lJcVcf6PcTIWJvQg9u). Its configuration file is in `experiments/exp10`.

  | Category  | F1-measure        |
  | --------- | ----------------- |
  | Normal    | 90.26             |
  | Crowded   | 68.23             |
  | HLight    | 61.84             |
  | Shadow    | 61.16             |
  | No line   | 43.44             |
  | Arrow     | 84.64             |
  | Curve     | 61.74             |
  | Crossroad | 2728 (FP measure) |
  | Night     | 65.32             |

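The CULane table reports an F1-measure per category. As a reminder of what that number means, F1 is the harmonic mean of precision and recall. The sketch below computes it from true-positive, false-positive, and false-negative counts; the function name and the example counts are illustrative assumptions, and the actual scores come from the C++ tool under `utils/lane_evaluation/CULane`, not from this snippet.

```python
def f1_measure(tp, fp, fn):
    """F1 = 2 * precision * recall / (precision + recall)."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0  # no correct predictions at all
    return 2 * precision * recall / (precision + recall)


# Made-up counts, purely for illustration (not taken from the tables above).
print(round(100 * f1_measure(tp=90, fp=10, fn=10), 2))
```

The Crossroad row is listed as a raw false-positive count rather than an F1 score, so this formula does not apply to it.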


## Demo Test

For a single-image demo test:

```shell
python demo_test.py -i demo/demo.jpg \
    -w experiments/vgg_SCNN_DULR_w9/vgg_SCNN_DULR_w9.pth \
    [--visualize / -v]
```

![](demo/demo_result.jpg "demo_result")

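To make the demo output easier to interpret: `demo_test.py` takes the per-pixel argmax over five classes (background plus four lanes) and colours a lane only when its predicted existence probability exceeds 0.5. The following pure-Python sketch mirrors that post-processing on nested lists. `lane_mask` is a hypothetical helper written for this illustration; the real script works on NumPy arrays.

```python
def lane_mask(seg_scores, exist_probs, threshold=0.5):
    """Return mask[y][x] holding a lane id (1..4) or 0 for background.

    seg_scores[c][y][x] is the score of class c (0 = background, 1..4 = lanes);
    exist_probs[i] is the predicted existence probability of lane i + 1.
    """
    num_classes = len(seg_scores)
    height = len(seg_scores[0])
    width = len(seg_scores[0][0])
    mask = [[0] * width for _ in range(height)]
    for y in range(height):
        for x in range(width):
            # Per-pixel argmax over classes (first maximum wins on ties).
            best = max(range(num_classes), key=lambda c: seg_scores[c][y][x])
            # Keep the lane only if the network also predicts that it exists.
            if best > 0 and exist_probs[best - 1] > threshold:
                mask[y][x] = best
    return mask


# Tiny 1x2 "image": lane 1 wins at pixel 0, lane 2 wins at pixel 1.
scores = [
    [[0.1, 0.1]],  # background
    [[0.8, 0.1]],  # lane 1
    [[0.1, 0.7]],  # lane 2
    [[0.0, 0.1]],  # lane 3
    [[0.0, 0.0]],  # lane 4
]
print(lane_mask(scores, exist_probs=[0.9, 0.2, 0.1, 0.1]))  # [[1, 0]]
```

Lane 2 has the highest score at the second pixel but is dropped because its existence probability (0.2) is below the threshold.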

## Train

1. Specify an experiment directory, e.g. `experiments/exp0`.

2. Modify the hyperparameters in `experiments/exp0/cfg.json`.

3. Start training:

   ```shell
   python train.py --exp_dir ./experiments/exp0 [--resume/-r]
   ```

4. Monitor training with TensorBoard:

   ```bash
   tensorboard --logdir='experiments/exp0'
   ```

**Note**

- My model is trained with `torch.nn.DataParallel`. Modify it according to your hardware configuration.
- Currently the backbone is VGG16 (`vgg16_bn`) from torchvision. Following the paper, the torchvision model is modified in two ways: (i) the dilation of the last three conv layers is changed to 2, and (ii) the last two max-pooling layers are removed.

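The `resize_shape` setting in `cfg.json` and the backbone changes in the note together determine the feature-map sizes inside the network. VGG16 has five 2x max-pooling layers, so removing the last two leaves a total stride of 8, and `model.py` sizes its fully connected input as `5 * (w // 16) * (h // 16)` after one more 2x average pool. The sketch below reproduces that arithmetic from a config excerpt. The helper names are hypothetical, and `CFG_TEXT` is a trimmed copy of `experiments/exp10/cfg.json`.

```python
import json

# Trimmed excerpt of experiments/exp10/cfg.json; resize_shape is (width, height).
CFG_TEXT = """
{
    "dataset": {
        "dataset_name": "CULane",
        "batch_size": 128,
        "resize_shape": [800, 288]
    }
}
"""


def feature_map_size(resize_shape, stride=8):
    """(width, height) of the backbone output for a total stride of 8."""
    width, height = resize_shape
    return width // stride, height // stride


def fc_input_features(resize_shape, num_classes=5):
    """Input size of the existence branch: 5 classes on a stride-16 grid."""
    width, height = resize_shape
    return num_classes * (width // 16) * (height // 16)


cfg = json.loads(CFG_TEXT)
shape = cfg["dataset"]["resize_shape"]
print(feature_map_size(shape))   # (100, 36)
print(fc_input_features(shape))  # 4500
```

For the Tusimple setting in `exp0` (`[512, 288]`), the same arithmetic gives a 64 x 36 feature map and 2880 fully connected inputs.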

## Evaluation

* CULane evaluation code is ported from the [official implementation](), and an extra `CMakeLists.txt` is provided.

  1. Build the C++ code first, using the commands below.
  2. Then set `root` to the absolute project path in `utils/lane_evaluation/CULane/Run.sh`.

  ```bash
  cd utils/lane_evaluation/CULane
  mkdir build && cd build
  cmake ..
  make
  ```

  Then run the evaluation script. The result is saved into the corresponding `exp_dir` directory.

  ```shell
  python test_CULane.py --exp_dir ./experiments/exp10
  ```

* Tusimple evaluation code is ported from the [tusimple repo](https://github.com/TuSimple/tusimple-benchmark/blob/master/evaluate/lane.py).

  ```shell
  python test_tusimple.py --exp_dir ./experiments/exp0
  ```

## Acknowledgement

This repo is built on the [official implementation]().

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
Dataset_Path = dict(
    CULane = "/home/lion/Dataset/CULane/data/CULane",
    Tusimple = "/home/lion/Dataset/tusimple"
)

--------------------------------------------------------------------------------
/dataset/CULane.py:
--------------------------------------------------------------------------------
import cv2
import os
import numpy as np

import torch
from torch.utils.data import Dataset


class CULane(Dataset):
    def __init__(self, path, image_set, transforms=None):
        super(CULane, self).__init__()
        assert image_set in ('train', 'val', 'test'), "image_set is not valid!"
13 | self.data_dir_path = path 14 | self.image_set = image_set 15 | self.transforms = transforms 16 | 17 | if image_set != 'test': 18 | self.createIndex() 19 | else: 20 | self.createIndex_test() 21 | 22 | 23 | def createIndex(self): 24 | listfile = os.path.join(self.data_dir_path, "list", "{}_gt.txt".format(self.image_set)) 25 | 26 | self.img_list = [] 27 | self.segLabel_list = [] 28 | self.exist_list = [] 29 | with open(listfile) as f: 30 | for line in f: 31 | line = line.strip() 32 | l = line.split(" ") 33 | self.img_list.append(os.path.join(self.data_dir_path, l[0][1:])) # l[0][1:] get rid of the first '/' so as for os.path.join 34 | self.segLabel_list.append(os.path.join(self.data_dir_path, l[1][1:])) 35 | self.exist_list.append([int(x) for x in l[2:]]) 36 | 37 | def createIndex_test(self): 38 | listfile = os.path.join(self.data_dir_path, "list", "{}.txt".format(self.image_set)) 39 | 40 | self.img_list = [] 41 | with open(listfile) as f: 42 | for line in f: 43 | line = line.strip() 44 | self.img_list.append(os.path.join(self.data_dir_path, line[1:])) # l[0][1:] get rid of the first '/' so as for os.path.join 45 | 46 | def __getitem__(self, idx): 47 | img = cv2.imread(self.img_list[idx]) 48 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 49 | if self.image_set != 'test': 50 | segLabel = cv2.imread(self.segLabel_list[idx])[:, :, 0] 51 | exist = np.array(self.exist_list[idx]) 52 | else: 53 | segLabel = None 54 | exist = None 55 | 56 | sample = {'img': img, 57 | 'segLabel': segLabel, 58 | 'exist': exist, 59 | 'img_name': self.img_list[idx]} 60 | if self.transforms is not None: 61 | sample = self.transforms(sample) 62 | return sample 63 | 64 | def __len__(self): 65 | return len(self.img_list) 66 | 67 | @staticmethod 68 | def collate(batch): 69 | if isinstance(batch[0]['img'], torch.Tensor): 70 | img = torch.stack([b['img'] for b in batch]) 71 | else: 72 | img = [b['img'] for b in batch] 73 | 74 | if batch[0]['segLabel'] is None: 75 | segLabel = None 76 | exist = None 
77 | elif isinstance(batch[0]['segLabel'], torch.Tensor): 78 | segLabel = torch.stack([b['segLabel'] for b in batch]) 79 | exist = torch.stack([b['exist'] for b in batch]) 80 | else: 81 | segLabel = [b['segLabel'] for b in batch] 82 | exist = [b['exist'] for b in batch] 83 | 84 | samples = {'img': img, 85 | 'segLabel': segLabel, 86 | 'exist': exist, 87 | 'img_name': [x['img_name'] for x in batch]} 88 | 89 | return samples -------------------------------------------------------------------------------- /dataset/Tusimple.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class Tusimple(Dataset): 11 | """ 12 | image_set is splitted into three partitions: train, val, test. 13 | train includes label_data_0313.json, label_data_0601.json 14 | val includes label_data_0531.json 15 | test includes test_label.json 16 | """ 17 | TRAIN_SET = ['label_data_0313.json', 'label_data_0601.json'] 18 | VAL_SET = ['label_data_0531.json'] 19 | TEST_SET = ['test_label.json'] 20 | 21 | def __init__(self, path, image_set, transforms=None): 22 | super(Tusimple, self).__init__() 23 | assert image_set in ('train', 'val', 'test'), "image_set is not valid!" 24 | self.data_dir_path = path 25 | self.image_set = image_set 26 | self.transforms = transforms 27 | 28 | if not os.path.exists(os.path.join(path, "seg_label")): 29 | print("Label is going to get generated into dir: {} ...".format(os.path.join(path, "seg_label"))) 30 | self.generate_label() 31 | self.createIndex() 32 | 33 | def createIndex(self): 34 | self.img_list = [] 35 | self.segLabel_list = [] 36 | self.exist_list = [] 37 | 38 | listfile = os.path.join(self.data_dir_path, "seg_label", "list", "{}_gt.txt".format(self.image_set)) 39 | if not os.path.exists(listfile): 40 | raise FileNotFoundError("List file doesn't exist. Label has to be generated! 
...") 41 | 42 | with open(listfile) as f: 43 | for line in f: 44 | line = line.strip() 45 | l = line.split(" ") 46 | self.img_list.append(os.path.join(self.data_dir_path, l[0][1:])) # l[0][1:] get rid of the first '/' so as for os.path.join 47 | self.segLabel_list.append(os.path.join(self.data_dir_path, l[1][1:])) 48 | self.exist_list.append([int(x) for x in l[2:]]) 49 | 50 | def __getitem__(self, idx): 51 | img = cv2.imread(self.img_list[idx]) 52 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 53 | if self.image_set != 'test': 54 | segLabel = cv2.imread(self.segLabel_list[idx])[:, :, 0] 55 | exist = np.array(self.exist_list[idx]) 56 | else: 57 | segLabel = None 58 | exist = None 59 | 60 | sample = {'img': img, 61 | 'segLabel': segLabel, 62 | 'exist': exist, 63 | 'img_name': self.img_list[idx]} 64 | if self.transforms is not None: 65 | sample = self.transforms(sample) 66 | return sample 67 | 68 | def __len__(self): 69 | return len(self.img_list) 70 | 71 | def generate_label(self): 72 | save_dir = os.path.join(self.data_dir_path, "seg_label") 73 | os.makedirs(save_dir, exist_ok=True) 74 | 75 | # --------- merge json into one file --------- 76 | with open(os.path.join(save_dir, "train.json"), "w") as outfile: 77 | for json_name in self.TRAIN_SET: 78 | with open(os.path.join(self.data_dir_path, json_name)) as infile: 79 | for line in infile: 80 | outfile.write(line) 81 | 82 | with open(os.path.join(save_dir, "val.json"), "w") as outfile: 83 | for json_name in self.VAL_SET: 84 | with open(os.path.join(self.data_dir_path, json_name)) as infile: 85 | for line in infile: 86 | outfile.write(line) 87 | 88 | with open(os.path.join(save_dir, "test.json"), "w") as outfile: 89 | for json_name in self.TEST_SET: 90 | with open(os.path.join(self.data_dir_path, json_name)) as infile: 91 | for line in infile: 92 | outfile.write(line) 93 | 94 | self._gen_label_for_json('train') 95 | print("train set is done") 96 | self._gen_label_for_json('val') 97 | print("val set is done") 98 | 
self._gen_label_for_json('test') 99 | print("test set is done") 100 | 101 | def _gen_label_for_json(self, image_set): 102 | H, W = 720, 1280 103 | SEG_WIDTH = 30 104 | save_dir = "seg_label" 105 | 106 | os.makedirs(os.path.join(self.data_dir_path, save_dir, "list"), exist_ok=True) 107 | list_f = open(os.path.join(self.data_dir_path, save_dir, "list", "{}_gt.txt".format(image_set)), "w") 108 | 109 | json_path = os.path.join(self.data_dir_path, save_dir, "{}.json".format(image_set)) 110 | with open(json_path) as f: 111 | for line in f: 112 | label = json.loads(line) 113 | 114 | # ---------- clean and sort lanes ------------- 115 | lanes = [] 116 | _lanes = [] 117 | slope = [] # identify 1st, 2nd, 3rd, 4th lane through slope 118 | for i in range(len(label['lanes'])): 119 | l = [(x, y) for x, y in zip(label['lanes'][i], label['h_samples']) if x >= 0] 120 | if (len(l)>1): 121 | _lanes.append(l) 122 | slope.append(np.arctan2(l[-1][1]-l[0][1], l[0][0]-l[-1][0]) / np.pi * 180) 123 | _lanes = [_lanes[i] for i in np.argsort(slope)] 124 | slope = [slope[i] for i in np.argsort(slope)] 125 | 126 | idx_1 = None 127 | idx_2 = None 128 | idx_3 = None 129 | idx_4 = None 130 | for i in range(len(slope)): 131 | if slope[i]<=90: 132 | idx_2 = i 133 | idx_1 = i-1 if i>0 else None 134 | else: 135 | idx_3 = i 136 | idx_4 = i+1 if i+1 < len(slope) else None 137 | break 138 | lanes.append([] if idx_1 is None else _lanes[idx_1]) 139 | lanes.append([] if idx_2 is None else _lanes[idx_2]) 140 | lanes.append([] if idx_3 is None else _lanes[idx_3]) 141 | lanes.append([] if idx_4 is None else _lanes[idx_4]) 142 | # --------------------------------------------- 143 | 144 | img_path = label['raw_file'] 145 | seg_img = np.zeros((H, W, 3)) 146 | list_str = [] # str to be written to list.txt 147 | for i in range(len(lanes)): 148 | coords = lanes[i] 149 | if len(coords) < 4: 150 | list_str.append('0') 151 | continue 152 | for j in range(len(coords)-1): 153 | cv2.line(seg_img, coords[j], coords[j+1], 
(i+1, i+1, i+1), SEG_WIDTH//2) 154 | list_str.append('1') 155 | 156 | seg_path = img_path.split("/") 157 | seg_path, img_name = os.path.join(self.data_dir_path, save_dir, seg_path[1], seg_path[2]), seg_path[3] 158 | os.makedirs(seg_path, exist_ok=True) 159 | seg_path = os.path.join(seg_path, img_name[:-3]+"png") 160 | cv2.imwrite(seg_path, seg_img) 161 | 162 | seg_path = "/".join([save_dir, *img_path.split("/")[1:3], img_name[:-3]+"png"]) 163 | if seg_path[0] != '/': 164 | seg_path = '/' + seg_path 165 | if img_path[0] != '/': 166 | img_path = '/' + img_path 167 | list_str.insert(0, seg_path) 168 | list_str.insert(0, img_path) 169 | list_str = " ".join(list_str) + "\n" 170 | list_f.write(list_str) 171 | 172 | list_f.close() 173 | 174 | @staticmethod 175 | def collate(batch): 176 | if isinstance(batch[0]['img'], torch.Tensor): 177 | img = torch.stack([b['img'] for b in batch]) 178 | else: 179 | img = [b['img'] for b in batch] 180 | 181 | if batch[0]['segLabel'] is None: 182 | segLabel = None 183 | exist = None 184 | elif isinstance(batch[0]['segLabel'], torch.Tensor): 185 | segLabel = torch.stack([b['segLabel'] for b in batch]) 186 | exist = torch.stack([b['exist'] for b in batch]) 187 | else: 188 | segLabel = [b['segLabel'] for b in batch] 189 | exist = [b['exist'] for b in batch] 190 | 191 | samples = {'img': img, 192 | 'segLabel': segLabel, 193 | 'exist': exist, 194 | 'img_name': [x['img_name'] for x in batch]} 195 | 196 | return samples -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .CULane import CULane 2 | from .Tusimple import Tusimple -------------------------------------------------------------------------------- /demo/demo.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/harryhan618/SCNN_Pytorch/bbe0acff4681bcda7e984f867dd31fdaa9a7bf81/demo/demo.jpg -------------------------------------------------------------------------------- /demo/demo_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harryhan618/SCNN_Pytorch/bbe0acff4681bcda7e984f867dd31fdaa9a7bf81/demo/demo_result.jpg -------------------------------------------------------------------------------- /demo_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import torch 4 | 5 | from model import SCNN 6 | from utils.prob2lines import getLane 7 | from utils.transforms import * 8 | 9 | net = SCNN(input_size=(800, 288), pretrained=False) 10 | mean=(0.3598, 0.3653, 0.3662) # CULane mean, std 11 | std=(0.2573, 0.2663, 0.2756) 12 | transform_img = Resize((800, 288)) 13 | transform_to_net = Compose(ToTensor(), Normalize(mean=mean, std=std)) 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--img_path", '-i', type=str, default="demo/demo.jpg", help="Path to demo img") 19 | parser.add_argument("--weight_path", '-w', type=str, help="Path to model weights") 20 | parser.add_argument("--visualize", '-v', action="store_true", default=False, help="Visualize the result") 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def main(): 26 | args = parse_args() 27 | img_path = args.img_path 28 | weight_path = args.weight_path 29 | 30 | img = cv2.imread(img_path) 31 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 32 | img = transform_img({'img': img})['img'] 33 | x = transform_to_net({'img': img})['img'] 34 | x.unsqueeze_(0) 35 | 36 | save_dict = torch.load(weight_path, map_location='cpu') 37 | net.load_state_dict(save_dict['net']) 38 | net.eval() 39 | 40 | seg_pred, exist_pred = net(x)[:2] 41 | seg_pred = seg_pred.detach().cpu().numpy() 42 | exist_pred = 
exist_pred.detach().cpu().numpy() 43 | seg_pred = seg_pred[0] 44 | exist = [1 if exist_pred[0, i] > 0.5 else 0 for i in range(4)] 45 | 46 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 47 | lane_img = np.zeros_like(img) 48 | color = np.array([[255, 125, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255]], dtype='uint8') 49 | coord_mask = np.argmax(seg_pred, axis=0) 50 | for i in range(0, 4): 51 | if exist_pred[0, i] > 0.5: 52 | lane_img[coord_mask == (i + 1)] = color[i] 53 | img = cv2.addWeighted(src1=lane_img, alpha=0.8, src2=img, beta=1., gamma=0.) 54 | cv2.imwrite("demo/demo_result.jpg", img) 55 | 56 | for x in getLane.prob2lines_CULane(seg_pred, exist): 57 | print(x) 58 | 59 | if args.visualize: 60 | print([1 if exist_pred[0, i] > 0.5 else 0 for i in range(4)]) 61 | cv2.imshow("", img) 62 | cv2.waitKey(0) 63 | cv2.destroyAllWindows() 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /experiments/exp0/cfg.json: -------------------------------------------------------------------------------- 1 | { 2 | "device": "cuda:0", 3 | "MAX_EPOCHES": 60, 4 | 5 | "dataset": { 6 | "dataset_name": "Tusimple", 7 | "batch_size": 32, 8 | "resize_shape": [512, 288] 9 | }, 10 | 11 | "optim": { 12 | "lr": 15e-2, 13 | "momentum": 0.9, 14 | "weight_decay": 1e-4, 15 | "nesterov": true 16 | }, 17 | 18 | "lr_scheduler": { 19 | "warmup": 20, 20 | "max_iter": 1500, 21 | "min_lrs": 1e-10 22 | }, 23 | 24 | "model": { 25 | "scale_exist": 0.07 26 | } 27 | } -------------------------------------------------------------------------------- /experiments/exp10/cfg.json: -------------------------------------------------------------------------------- 1 | { 2 | "device": "cuda:0", 3 | "MAX_EPOCHES": 30, 4 | 5 | "dataset": { 6 | "dataset_name": "CULane", 7 | "batch_size": 128, 8 | "resize_shape": [800, 288] 9 | }, 10 | 11 | "optim": { 12 | "lr": 16e-2, 13 | "momentum": 0.9, 14 | "weight_decay": 1e-3, 15 | "nesterov": true 
16 | }, 17 | 18 | "lr_scheduler": { 19 | "warmup": 50, 20 | "max_iter": 8000 21 | } 22 | } -------------------------------------------------------------------------------- /experiments/vgg_SCNN_DULR_w9/cfg.json: -------------------------------------------------------------------------------- 1 | { 2 | "device": "cuda:0", 3 | 4 | "dataset": { 5 | "dataset_name": "CULane", 6 | "resize_shape": [800, 288] 7 | } 8 | 9 | 10 | } -------------------------------------------------------------------------------- /experiments/vgg_SCNN_DULR_w9/t7_to_pt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | abs_file_path = os.path.abspath(os.path.dirname(__file__)) 4 | sys.path.append(os.path.join(abs_file_path, "..", "..")) # add path 5 | 6 | import torch 7 | import torch.nn as nn 8 | import collections 9 | from torch.utils.serialization import load_lua 10 | from model import SCNN 11 | 12 | model1 = load_lua('experiments/vgg_SCNN_DULR_w9/vgg_SCNN_DULR_w9.t7', unknown_classes=True) 13 | model2 = collections.OrderedDict() 14 | 15 | model2['backbone.0.weight'] = model1.modules[0].weight 16 | model2['backbone.1.weight'] = model1.modules[1].weight 17 | model2['backbone.1.bias'] = model1.modules[1].bias 18 | model2['backbone.1.running_mean'] = model1.modules[1].running_mean 19 | model2['backbone.1.running_var'] = model1.modules[1].running_var 20 | model2['backbone.3.weight'] = model1.modules[3].weight 21 | model2['backbone.4.weight'] = model1.modules[4].weight 22 | model2['backbone.4.bias'] = model1.modules[4].bias 23 | model2['backbone.4.running_mean'] = model1.modules[4].running_mean 24 | model2['backbone.4.running_var'] = model1.modules[4].running_var 25 | 26 | model2['backbone.7.weight'] = model1.modules[7].weight 27 | model2['backbone.8.weight'] = model1.modules[8].weight 28 | model2['backbone.8.bias'] = model1.modules[8].bias 29 | model2['backbone.8.running_mean'] = model1.modules[8].running_mean 30 | 
model2['backbone.8.running_var'] = model1.modules[8].running_var 31 | model2['backbone.10.weight'] = model1.modules[10].weight 32 | model2['backbone.11.weight'] = model1.modules[11].weight 33 | model2['backbone.11.bias'] = model1.modules[11].bias 34 | model2['backbone.11.running_mean'] = model1.modules[11].running_mean 35 | model2['backbone.11.running_var'] = model1.modules[11].running_var 36 | 37 | model2['backbone.14.weight'] = model1.modules[14].weight 38 | model2['backbone.15.weight'] = model1.modules[15].weight 39 | model2['backbone.15.bias'] = model1.modules[15].bias 40 | model2['backbone.15.running_mean'] = model1.modules[15].running_mean 41 | model2['backbone.15.running_var'] = model1.modules[15].running_var 42 | model2['backbone.17.weight'] = model1.modules[17].weight 43 | model2['backbone.18.weight'] = model1.modules[18].weight 44 | model2['backbone.18.bias'] = model1.modules[18].bias 45 | model2['backbone.18.running_mean'] = model1.modules[18].running_mean 46 | model2['backbone.18.running_var'] = model1.modules[18].running_var 47 | model2['backbone.20.weight'] = model1.modules[20].weight 48 | model2['backbone.21.weight'] = model1.modules[21].weight 49 | model2['backbone.21.bias'] = model1.modules[21].bias 50 | model2['backbone.21.running_mean'] = model1.modules[21].running_mean 51 | model2['backbone.21.running_var'] = model1.modules[21].running_var 52 | 53 | model2['backbone.24.weight'] = model1.modules[24].weight 54 | model2['backbone.25.weight'] = model1.modules[25].weight 55 | model2['backbone.25.bias'] = model1.modules[25].bias 56 | model2['backbone.25.running_mean'] = model1.modules[25].running_mean 57 | model2['backbone.25.running_var'] = model1.modules[25].running_var 58 | model2['backbone.27.weight'] = model1.modules[27].weight 59 | model2['backbone.28.weight'] = model1.modules[28].weight 60 | model2['backbone.28.bias'] = model1.modules[28].bias 61 | model2['backbone.28.running_mean'] = model1.modules[28].running_mean 62 | 
model2['backbone.28.running_var'] = model1.modules[28].running_var 63 | model2['backbone.30.weight'] = model1.modules[30].weight 64 | model2['backbone.31.weight'] = model1.modules[31].weight 65 | model2['backbone.31.bias'] = model1.modules[31].bias 66 | model2['backbone.31.running_mean'] = model1.modules[31].running_mean 67 | model2['backbone.31.running_var'] = model1.modules[31].running_var 68 | 69 | model2['backbone.34.weight'] = model1.modules[33].weight 70 | model2['backbone.35.weight'] = model1.modules[34].weight 71 | model2['backbone.35.bias'] = model1.modules[34].bias 72 | model2['backbone.35.running_mean'] = model1.modules[34].running_mean 73 | model2['backbone.35.running_var'] = model1.modules[34].running_var 74 | model2['backbone.37.weight'] = model1.modules[36].weight 75 | model2['backbone.38.weight'] = model1.modules[37].weight 76 | model2['backbone.38.bias'] = model1.modules[37].bias 77 | model2['backbone.38.running_mean'] = model1.modules[37].running_mean 78 | model2['backbone.38.running_var'] = model1.modules[37].running_var 79 | model2['backbone.40.weight'] = model1.modules[39].weight 80 | model2['backbone.41.weight'] = model1.modules[40].weight 81 | model2['backbone.41.bias'] = model1.modules[40].bias 82 | model2['backbone.41.running_mean'] = model1.modules[40].running_mean 83 | model2['backbone.41.running_var'] = model1.modules[40].running_var 84 | 85 | model2['layer1.0.weight'] = model1.modules[42].modules[0].weight 86 | model2['layer1.1.weight'] = model1.modules[42].modules[1].weight 87 | model2['layer1.1.bias'] = model1.modules[42].modules[1].bias 88 | model2['layer1.1.running_mean'] = model1.modules[42].modules[1].running_mean 89 | model2['layer1.1.running_var'] = model1.modules[42].modules[1].running_var 90 | model2['layer1.3.weight'] = model1.modules[42].modules[3].weight 91 | model2['layer1.4.weight'] = model1.modules[42].modules[4].weight 92 | model2['layer1.4.bias'] = model1.modules[42].modules[4].bias 93 | model2['layer1.4.running_mean'] 
= model1.modules[42].modules[4].running_mean 94 | model2['layer1.4.running_var'] = model1.modules[42].modules[4].running_var 95 | 96 | model2['message_passing.up_down.weight'] = model1.modules[42].modules[6].modules[0].modules[0].modules[2].modules[0].modules[1].modules[1].modules[0].weight 97 | model2['message_passing.down_up.weight'] = model1.modules[42].modules[6].modules[0].modules[0].modules[140].modules[1].modules[2].modules[0].modules[0].weight 98 | model2['message_passing.left_right.weight'] = model1.modules[42].modules[6].modules[1].modules[0].modules[2].modules[0].modules[1].modules[1].modules[0].weight 99 | model2['message_passing.right_left.weight'] = model1.modules[42].modules[6].modules[1].modules[0].modules[396].modules[1].modules[2].modules[0].modules[0].weight 100 | 101 | model2['layer2.1.weight'] = model1.modules[42].modules[8].weight 102 | model2['layer2.1.bias'] = model1.modules[42].modules[8].bias 103 | model2['fc.0.weight'] = model1.modules[43].modules[1].modules[3].weight 104 | model2['fc.0.bias'] = model1.modules[43].modules[1].modules[3].bias 105 | model2['fc.2.weight'] = model1.modules[43].modules[1].modules[5].weight 106 | model2['fc.2.bias'] = model1.modules[43].modules[1].modules[5].bias 107 | 108 | save_name = os.path.join('experiments', 'vgg_SCNN_DULR_w9', 'vgg_SCNN_DULR_w9.pth') 109 | torch.save(model2, save_name) 110 | 111 | # load and save again 112 | net = SCNN(input_size=(800, 288), pretrained=False) 113 | d = torch.load(save_name) 114 | net.load_state_dict(d, strict=False) 115 | for m in net.backbone.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | if m.bias is not None: 118 | m.bias.data.zero_() 119 | 120 | 121 | save_dict = { 122 | "epoch": 0, 123 | "net": net.state_dict(), 124 | "optim": None, 125 | "lr_scheduler": None 126 | } 127 | 128 | if not os.path.exists(os.path.join('experiments', 'vgg_SCNN_DULR_w9')): 129 | os.makedirs(os.path.join('experiments', 'vgg_SCNN_DULR_w9'), exist_ok=True) 130 | torch.save(save_dict, 
save_name) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | 6 | 7 | class SCNN(nn.Module): 8 | def __init__( 9 | self, 10 | input_size, 11 | ms_ks=9, 12 | pretrained=True 13 | ): 14 | """ 15 | Argument 16 | ms_ks: kernel size in message passing conv 17 | """ 18 | super(SCNN, self).__init__() 19 | self.pretrained = pretrained 20 | self.net_init(input_size, ms_ks) 21 | if not pretrained: 22 | self.weight_init() 23 | 24 | self.scale_background = 0.4 25 | self.scale_seg = 1.0 26 | self.scale_exist = 0.1 27 | 28 | self.ce_loss = nn.CrossEntropyLoss(weight=torch.tensor([self.scale_background, 1, 1, 1, 1])) 29 | self.bce_loss = nn.BCELoss() 30 | 31 | def forward(self, img, seg_gt=None, exist_gt=None): 32 | x = self.backbone(img) 33 | x = self.layer1(x) 34 | x = self.message_passing_forward(x) 35 | x = self.layer2(x) 36 | 37 | seg_pred = F.interpolate(x, scale_factor=8, mode='bilinear', align_corners=True) 38 | x = self.layer3(x) 39 | x = x.view(-1, self.fc_input_feature) 40 | exist_pred = self.fc(x) 41 | 42 | if seg_gt is not None and exist_gt is not None: 43 | loss_seg = self.ce_loss(seg_pred, seg_gt) 44 | loss_exist = self.bce_loss(exist_pred, exist_gt) 45 | loss = loss_seg * self.scale_seg + loss_exist * self.scale_exist 46 | else: 47 | loss_seg = torch.tensor(0, dtype=img.dtype, device=img.device) 48 | loss_exist = torch.tensor(0, dtype=img.dtype, device=img.device) 49 | loss = torch.tensor(0, dtype=img.dtype, device=img.device) 50 | 51 | return seg_pred, exist_pred, loss_seg, loss_exist, loss 52 | 53 | def message_passing_forward(self, x): 54 | Vertical = [True, True, False, False] 55 | Reverse = [False, True, False, True] 56 | for ms_conv, v, r in zip(self.message_passing, Vertical, Reverse): 57 | x = 
self.message_passing_once(x, ms_conv, v, r) 58 | return x 59 | 60 | def message_passing_once(self, x, conv, vertical=True, reverse=False): 61 | """ 62 | Argument: 63 | ---------- 64 | x: input tensor 65 | vertical: vertical message passing or horizontal 66 | reverse: False for up-down or left-right, True for down-up or right-left 67 | """ 68 | nB, C, H, W = x.shape 69 | if vertical: 70 | slices = [x[:, :, i:(i + 1), :] for i in range(H)] 71 | dim = 2 72 | else: 73 | slices = [x[:, :, :, i:(i + 1)] for i in range(W)] 74 | dim = 3 75 | if reverse: 76 | slices = slices[::-1] 77 | 78 | out = [slices[0]] 79 | for i in range(1, len(slices)): 80 | out.append(slices[i] + F.relu(conv(out[i - 1]))) 81 | if reverse: 82 | out = out[::-1] 83 | return torch.cat(out, dim=dim) 84 | 85 | def net_init(self, input_size, ms_ks): 86 | input_w, input_h = input_size 87 | self.fc_input_feature = 5 * int(input_w/16) * int(input_h/16) 88 | self.backbone = models.vgg16_bn(pretrained=self.pretrained).features 89 | 90 | # ----------------- process backbone ----------------- 91 | for i in [34, 37, 40]: 92 | conv = self.backbone._modules[str(i)] 93 | dilated_conv = nn.Conv2d( 94 | conv.in_channels, conv.out_channels, conv.kernel_size, stride=conv.stride, 95 | padding=tuple(p * 2 for p in conv.padding), dilation=2, bias=(conv.bias is not None) 96 | ) 97 | dilated_conv.load_state_dict(conv.state_dict()) 98 | self.backbone._modules[str(i)] = dilated_conv 99 | self.backbone._modules.pop('33') 100 | self.backbone._modules.pop('43') 101 | 102 | # ----------------- SCNN part ----------------- 103 | self.layer1 = nn.Sequential( 104 | nn.Conv2d(512, 1024, 3, padding=4, dilation=4, bias=False), 105 | nn.BatchNorm2d(1024), 106 | nn.ReLU(), 107 | nn.Conv2d(1024, 128, 1, bias=False), 108 | nn.BatchNorm2d(128), 109 | nn.ReLU() # (nB, 128, 36, 100) 110 | ) 111 | 112 | # ----------------- add message passing ----------------- 113 | self.message_passing = nn.ModuleList() 114 | 
self.message_passing.add_module('up_down', nn.Conv2d(128, 128, (1, ms_ks), padding=(0, ms_ks // 2), bias=False)) 115 | self.message_passing.add_module('down_up', nn.Conv2d(128, 128, (1, ms_ks), padding=(0, ms_ks // 2), bias=False)) 116 | self.message_passing.add_module('left_right', 117 | nn.Conv2d(128, 128, (ms_ks, 1), padding=(ms_ks // 2, 0), bias=False)) 118 | self.message_passing.add_module('right_left', 119 | nn.Conv2d(128, 128, (ms_ks, 1), padding=(ms_ks // 2, 0), bias=False)) 120 | # (nB, 128, 36, 100) 121 | 122 | # ----------------- SCNN part ----------------- 123 | self.layer2 = nn.Sequential( 124 | nn.Dropout2d(0.1), 125 | nn.Conv2d(128, 5, 1) # get (nB, 5, 36, 100) 126 | ) 127 | 128 | self.layer3 = nn.Sequential( 129 | nn.Softmax(dim=1), # (nB, 5, 36, 100) 130 | nn.AvgPool2d(2, 2), # (nB, 5, 18, 50) 131 | ) 132 | self.fc = nn.Sequential( 133 | nn.Linear(self.fc_input_feature, 128), 134 | nn.ReLU(), 135 | nn.Linear(128, 4), 136 | nn.Sigmoid() 137 | ) 138 | 139 | def weight_init(self): 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | m.reset_parameters() 143 | elif isinstance(m, nn.BatchNorm2d): 144 | m.weight.data[:] = 1. 
145 | m.bias.data.zero_() 146 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | torch>=0.4.1 4 | torchvision 5 | tqdm -------------------------------------------------------------------------------- /test_CULane.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | 9 | import dataset 10 | from config import * 11 | from model import SCNN 12 | from utils.prob2lines import getLane 13 | from utils.transforms import * 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--exp_dir", type=str, default="./experiments/exp10") 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | # ------------ config ------------ 24 | args = parse_args() 25 | exp_dir = args.exp_dir.rstrip('/')  # drop trailing slashes, as train.py does, so exp_name is never empty 26 | exp_name = exp_dir.split('/')[-1] 27 | 28 | with open(os.path.join(exp_dir, "cfg.json")) as f: 29 | exp_cfg = json.load(f) 30 | resize_shape = tuple(exp_cfg['dataset']['resize_shape']) 31 | device = torch.device('cuda') 32 | 33 | 34 | def split_path(path): 35 | """Split a path into a list of its components.""" 36 | folders = [] 37 | while True: 38 | path, folder = os.path.split(path) 39 | if folder != "": 40 | folders.insert(0, folder) 41 | else: 42 | if path != "": 43 | folders.insert(0, path) 44 | break 45 | return folders 46 | 47 | 48 | # ------------ data and model ------------ 49 | # # CULane mean, std 50 | # mean=(0.3598, 0.3653, 0.3662) 51 | # std=(0.2573, 0.2663, 0.2756) 52 | # ImageNet mean, std 53 | mean = (0.485, 0.456, 0.406) 54 | std = (0.229, 0.224, 0.225) 55 | dataset_name = exp_cfg['dataset'].pop('dataset_name') 56 | Dataset_Type = getattr(dataset, dataset_name) 57 | transform = Compose(Resize(resize_shape),
ToTensor(), 58 | Normalize(mean=mean, std=std)) 59 | test_dataset = Dataset_Type(Dataset_Path[dataset_name], "test", transform) 60 | test_loader = DataLoader(test_dataset, batch_size=64, collate_fn=test_dataset.collate, num_workers=4) 61 | 62 | net = SCNN(resize_shape, pretrained=False) 63 | save_name = os.path.join(exp_dir, exp_dir.split('/')[-1] + '_best.pth') 64 | save_dict = torch.load(save_name, map_location='cpu') 65 | print("\nloading", save_name, "...... From Epoch: ", save_dict['epoch']) 66 | net.load_state_dict(save_dict['net']) 67 | net = torch.nn.DataParallel(net.to(device)) 68 | net.eval() 69 | 70 | # ------------ test ------------ 71 | out_path = os.path.join(exp_dir, "coord_output") 72 | evaluation_path = os.path.join(exp_dir, "evaluate") 73 | if not os.path.exists(out_path): 74 | os.mkdir(out_path) 75 | if not os.path.exists(evaluation_path): 76 | os.mkdir(evaluation_path) 77 | 78 | progressbar = tqdm(range(len(test_loader))) 79 | with torch.no_grad(): 80 | for batch_idx, sample in enumerate(test_loader): 81 | img = sample['img'].to(device) 82 | img_name = sample['img_name'] 83 | 84 | seg_pred, exist_pred = net(img)[:2] 85 | seg_pred = F.softmax(seg_pred, dim=1) 86 | seg_pred = seg_pred.detach().cpu().numpy() 87 | exist_pred = exist_pred.detach().cpu().numpy() 88 | 89 | for b in range(len(seg_pred)): 90 | seg = seg_pred[b] 91 | exist = [1 if exist_pred[b, i] > 0.5 else 0 for i in range(4)] 92 | lane_coords = getLane.prob2lines_CULane(seg, exist, resize_shape=(590, 1640), y_px_gap=20, pts=18) 93 | 94 | path_tree = split_path(img_name[b]) 95 | save_dir, save_name = path_tree[-3:-1], path_tree[-1] 96 | save_dir = os.path.join(out_path, *save_dir) 97 | save_name = save_name[:-3] + "lines.txt" 98 | save_name = os.path.join(save_dir, save_name) 99 | if not os.path.exists(save_dir): 100 | os.makedirs(save_dir) 101 | 102 | with open(save_name, "w") as f: 103 | for l in lane_coords: 104 | for (x, y) in l: 105 | print("{} {}".format(x, y), end=" ", file=f) 
106 | print(file=f) 107 | 108 | progressbar.update(1) 109 | progressbar.close() 110 | 111 | # ---- evaluate ---- 112 | os.system("sh utils/lane_evaluation/CULane/Run.sh " + exp_name) 113 | -------------------------------------------------------------------------------- /test_tusimple.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | 9 | import dataset 10 | from config import * 11 | from model import SCNN 12 | from utils.prob2lines import getLane 13 | from utils.transforms import * 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--exp_dir", type=str, default="./experiments/exp0") 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | # ------------ config ------------ 24 | args = parse_args() 25 | exp_dir = args.exp_dir.rstrip('/')  # drop trailing slashes, as train.py does, so exp_name is never empty 26 | exp_name = exp_dir.split('/')[-1] 27 | 28 | with open(os.path.join(exp_dir, "cfg.json")) as f: 29 | exp_cfg = json.load(f) 30 | resize_shape = tuple(exp_cfg['dataset']['resize_shape']) 31 | device = torch.device('cuda') 32 | 33 | 34 | def split_path(path): 35 | """Split a path into a list of its components.""" 36 | folders = [] 37 | while True: 38 | path, folder = os.path.split(path) 39 | if folder != "": 40 | folders.insert(0, folder) 41 | else: 42 | if path != "": 43 | folders.insert(0, path) 44 | break 45 | return folders 46 | 47 | 48 | # ------------ data and model ------------ 49 | # # CULane mean, std 50 | # mean=(0.3598, 0.3653, 0.3662) 51 | # std=(0.2573, 0.2663, 0.2756) 52 | # ImageNet mean, std 53 | mean = (0.485, 0.456, 0.406) 54 | std = (0.229, 0.224, 0.225) 55 | transform = Compose(Resize(resize_shape), ToTensor(), 56 | Normalize(mean=mean, std=std)) 57 | dataset_name = exp_cfg['dataset'].pop('dataset_name') 58 | Dataset_Type = getattr(dataset, dataset_name) 59 | test_dataset =
Dataset_Type(Dataset_Path['Tusimple'], "test", transform) 60 | test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=test_dataset.collate, num_workers=4) 61 | 62 | net = SCNN(input_size=resize_shape, pretrained=False) 63 | save_name = os.path.join(exp_dir, exp_dir.split('/')[-1] + '_best.pth') 64 | save_dict = torch.load(save_name, map_location='cpu') 65 | print("\nloading", save_name, "...... From Epoch: ", save_dict['epoch']) 66 | net.load_state_dict(save_dict['net']) 67 | net = torch.nn.DataParallel(net.to(device)) 68 | net.eval() 69 | 70 | # ------------ test ------------ 71 | out_path = os.path.join(exp_dir, "coord_output") 72 | evaluation_path = os.path.join(exp_dir, "evaluate") 73 | if not os.path.exists(out_path): 74 | os.mkdir(out_path) 75 | if not os.path.exists(evaluation_path): 76 | os.mkdir(evaluation_path) 77 | dump_to_json = [] 78 | 79 | progressbar = tqdm(range(len(test_loader))) 80 | with torch.no_grad(): 81 | for batch_idx, sample in enumerate(test_loader): 82 | img = sample['img'].to(device) 83 | img_name = sample['img_name'] 84 | 85 | seg_pred, exist_pred = net(img)[:2] 86 | seg_pred = F.softmax(seg_pred, dim=1) 87 | seg_pred = seg_pred.detach().cpu().numpy() 88 | exist_pred = exist_pred.detach().cpu().numpy() 89 | 90 | for b in range(len(seg_pred)): 91 | seg = seg_pred[b] 92 | exist = [1 if exist_pred[b, i] > 0.5 else 0 for i in range(4)] 93 | lane_coords = getLane.prob2lines_tusimple(seg, exist, resize_shape=(720, 1280), y_px_gap=10, pts=56) 94 | for i in range(len(lane_coords)): 95 | lane_coords[i] = sorted(lane_coords[i], key=lambda pair: pair[1]) 96 | 97 | path_tree = split_path(img_name[b]) 98 | save_dir, save_name = path_tree[-3:-1], path_tree[-1] 99 | save_dir = os.path.join(out_path, *save_dir) 100 | save_name = save_name[:-3] + "lines.txt" 101 | save_name = os.path.join(save_dir, save_name) 102 | if not os.path.exists(save_dir): 103 | os.makedirs(save_dir, exist_ok=True) 104 | 105 | with open(save_name, "w") as f: 106 | for 
l in lane_coords: 107 | for (x, y) in l: 108 | print("{} {}".format(x, y), end=" ", file=f) 109 | print(file=f) 110 | 111 | json_dict = {} 112 | json_dict['lanes'] = [] 113 | json_dict['h_sample'] = [] 114 | json_dict['raw_file'] = os.path.join(*path_tree[-4:]) 115 | json_dict['run_time'] = 0 116 | for l in lane_coords: 117 | if len(l) == 0: 118 | continue 119 | json_dict['lanes'].append([]) 120 | for (x, y) in l: 121 | json_dict['lanes'][-1].append(int(x)) 122 | for (x, y) in next((l for l in lane_coords if len(l) > 0), []):  # first non-empty lane; safe when no lane is detected 123 | json_dict['h_sample'].append(int(y)) 124 | dump_to_json.append(json.dumps(json_dict)) 125 | 126 | progressbar.update(1) 127 | progressbar.close() 128 | 129 | with open(os.path.join(out_path, "predict_test.json"), "w") as f: 130 | for line in dump_to_json: 131 | print(line, end="\n", file=f) 132 | 133 | # ---- evaluate ---- 134 | from utils.lane_evaluation.tusimple.lane import LaneEval 135 | 136 | eval_result = LaneEval.bench_one_submit(os.path.join(out_path, "predict_test.json"), 137 | os.path.join(Dataset_Path['Tusimple'], 'test_label.json')) 138 | print(eval_result) 139 | with open(os.path.join(evaluation_path, "evaluation_result.txt"), "w") as f: 140 | print(eval_result, file=f) 141 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import shutil 5 | import time 6 | 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | from config import * 12 | import dataset 13 | from model import SCNN 14 | from utils.tensorboard import TensorBoard 15 | from utils.transforms import * 16 | from utils.lr_scheduler import PolyLR 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--exp_dir", type=str, default="./experiments/exp0") 22 | parser.add_argument("--resume", "-r", action="store_true") 23 | args =
parser.parse_args() 24 | return args 25 | args = parse_args() 26 | 27 | # ------------ config ------------ 28 | exp_dir = args.exp_dir 29 | while exp_dir[-1]=='/': 30 | exp_dir = exp_dir[:-1] 31 | exp_name = exp_dir.split('/')[-1] 32 | 33 | with open(os.path.join(exp_dir, "cfg.json")) as f: 34 | exp_cfg = json.load(f) 35 | resize_shape = tuple(exp_cfg['dataset']['resize_shape']) 36 | 37 | device = torch.device(exp_cfg['device']) 38 | tensorboard = TensorBoard(exp_dir) 39 | 40 | # ------------ train data ------------ 41 | # # CULane mean, std 42 | # mean=(0.3598, 0.3653, 0.3662) 43 | # std=(0.2573, 0.2663, 0.2756) 44 | # Imagenet mean, std 45 | mean=(0.485, 0.456, 0.406) 46 | std=(0.229, 0.224, 0.225) 47 | transform_train = Compose(Resize(resize_shape), Rotation(2), ToTensor(), 48 | Normalize(mean=mean, std=std)) 49 | dataset_name = exp_cfg['dataset'].pop('dataset_name') 50 | Dataset_Type = getattr(dataset, dataset_name) 51 | train_dataset = Dataset_Type(Dataset_Path[dataset_name], "train", transform_train) 52 | train_loader = DataLoader(train_dataset, batch_size=exp_cfg['dataset']['batch_size'], shuffle=True, collate_fn=train_dataset.collate, num_workers=8) 53 | 54 | # ------------ val data ------------ 55 | transform_val_img = Resize(resize_shape) 56 | transform_val_x = Compose(ToTensor(), Normalize(mean=mean, std=std)) 57 | transform_val = Compose(transform_val_img, transform_val_x) 58 | val_dataset = Dataset_Type(Dataset_Path[dataset_name], "val", transform_val) 59 | val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=val_dataset.collate, num_workers=4) 60 | 61 | # ------------ preparation ------------ 62 | net = SCNN(resize_shape, pretrained=True) 63 | net = net.to(device) 64 | net = torch.nn.DataParallel(net) 65 | 66 | optimizer = optim.SGD(net.parameters(), **exp_cfg['optim']) 67 | lr_scheduler = PolyLR(optimizer, 0.9, **exp_cfg['lr_scheduler']) 68 | best_val_loss = 1e6 69 | 70 | 71 | def train(epoch): 72 | print("Train Epoch: {}".format(epoch)) 73 
| net.train() 74 | train_loss = 0 75 | train_loss_seg = 0 76 | train_loss_exist = 0 77 | progressbar = tqdm(range(len(train_loader))) 78 | 79 | for batch_idx, sample in enumerate(train_loader): 80 | img = sample['img'].to(device) 81 | segLabel = sample['segLabel'].to(device) 82 | exist = sample['exist'].to(device) 83 | 84 | optimizer.zero_grad() 85 | seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist) 86 | if isinstance(net, torch.nn.DataParallel): 87 | loss_seg = loss_seg.sum() 88 | loss_exist = loss_exist.sum() 89 | loss = loss.sum() 90 | loss.backward() 91 | optimizer.step() 92 | lr_scheduler.step() 93 | 94 | iter_idx = epoch * len(train_loader) + batch_idx 95 | train_loss = loss.item() 96 | train_loss_seg = loss_seg.item() 97 | train_loss_exist = loss_exist.item() 98 | progressbar.set_description("batch loss: {:.3f}".format(loss.item())) 99 | progressbar.update(1) 100 | 101 | lr = optimizer.param_groups[0]['lr'] 102 | tensorboard.scalar_summary(exp_name + "/train_loss", train_loss, iter_idx) 103 | tensorboard.scalar_summary(exp_name + "/train_loss_seg", train_loss_seg, iter_idx) 104 | tensorboard.scalar_summary(exp_name + "/train_loss_exist", train_loss_exist, iter_idx) 105 | tensorboard.scalar_summary(exp_name + "/learning_rate", lr, iter_idx) 106 | 107 | progressbar.close() 108 | tensorboard.writer.flush() 109 | 110 | if epoch % 1 == 0: 111 | save_dict = { 112 | "epoch": epoch, 113 | "net": net.module.state_dict() if isinstance(net, torch.nn.DataParallel) else net.state_dict(), 114 | "optim": optimizer.state_dict(), 115 | "lr_scheduler": lr_scheduler.state_dict(), 116 | "best_val_loss": best_val_loss 117 | } 118 | save_name = os.path.join(exp_dir, exp_name + '.pth') 119 | torch.save(save_dict, save_name) 120 | print("model is saved: {}".format(save_name)) 121 | 122 | print("------------------------\n") 123 | 124 | 125 | def val(epoch): 126 | global best_val_loss 127 | 128 | print("Val Epoch: {}".format(epoch)) 129 | 130 | net.eval() 
131 | val_loss = 0 132 | val_loss_seg = 0 133 | val_loss_exist = 0 134 | progressbar = tqdm(range(len(val_loader))) 135 | 136 | with torch.no_grad(): 137 | for batch_idx, sample in enumerate(val_loader): 138 | img = sample['img'].to(device) 139 | segLabel = sample['segLabel'].to(device) 140 | exist = sample['exist'].to(device) 141 | 142 | seg_pred, exist_pred, loss_seg, loss_exist, loss = net(img, segLabel, exist) 143 | if isinstance(net, torch.nn.DataParallel): 144 | loss_seg = loss_seg.sum() 145 | loss_exist = loss_exist.sum() 146 | loss = loss.sum() 147 | 148 | # visualize validation every 5 frame, 50 frames in all 149 | gap_num = 5 150 | if batch_idx%gap_num == 0 and batch_idx < 50 * gap_num: 151 | origin_imgs = [] 152 | seg_pred = seg_pred.detach().cpu().numpy() 153 | exist_pred = exist_pred.detach().cpu().numpy() 154 | 155 | for b in range(len(img)): 156 | img_name = sample['img_name'][b] 157 | img = cv2.imread(img_name) 158 | img = transform_val_img({'img': img})['img'] 159 | 160 | lane_img = np.zeros_like(img) 161 | color = np.array([[255, 125, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255]], dtype='uint8') 162 | 163 | coord_mask = np.argmax(seg_pred[b], axis=0) 164 | for i in range(0, 4): 165 | if exist_pred[b, i] > 0.5: 166 | lane_img[coord_mask==(i+1)] = color[i] 167 | img = cv2.addWeighted(src1=lane_img, alpha=0.8, src2=img, beta=1., gamma=0.) 
168 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 169 | lane_img = cv2.cvtColor(lane_img, cv2.COLOR_BGR2RGB) 170 | cv2.putText(lane_img, "{}".format([1 if exist_pred[b, i]>0.5 else 0 for i in range(4)]), (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (255, 255, 255), 2) 171 | origin_imgs.append(img) 172 | origin_imgs.append(lane_img) 173 | tensorboard.image_summary("img_{}".format(batch_idx), origin_imgs, epoch) 174 | 175 | val_loss += loss.item() 176 | val_loss_seg += loss_seg.item() 177 | val_loss_exist += loss_exist.item() 178 | 179 | progressbar.set_description("batch loss: {:.3f}".format(loss.item())) 180 | progressbar.update(1) 181 | 182 | progressbar.close() 183 | iter_idx = (epoch + 1) * len(train_loader) # keep align with training process iter_idx 184 | tensorboard.scalar_summary("val_loss", val_loss, iter_idx) 185 | tensorboard.scalar_summary("val_loss_seg", val_loss_seg, iter_idx) 186 | tensorboard.scalar_summary("val_loss_exist", val_loss_exist, iter_idx) 187 | tensorboard.writer.flush() 188 | 189 | print("------------------------\n") 190 | if val_loss < best_val_loss: 191 | best_val_loss = val_loss 192 | save_name = os.path.join(exp_dir, exp_name + '.pth') 193 | copy_name = os.path.join(exp_dir, exp_name + '_best.pth') 194 | shutil.copyfile(save_name, copy_name) 195 | 196 | 197 | def main(): 198 | global best_val_loss 199 | if args.resume: 200 | save_dict = torch.load(os.path.join(exp_dir, exp_name + '.pth')) 201 | if isinstance(net, torch.nn.DataParallel): 202 | net.module.load_state_dict(save_dict['net']) 203 | else: 204 | net.load_state_dict(save_dict['net']) 205 | optimizer.load_state_dict(save_dict['optim']) 206 | lr_scheduler.load_state_dict(save_dict['lr_scheduler']) 207 | start_epoch = save_dict['epoch'] + 1 208 | best_val_loss = save_dict.get("best_val_loss", 1e6) 209 | else: 210 | start_epoch = 0 211 | 212 | exp_cfg['MAX_EPOCHES'] = int(np.ceil(exp_cfg['lr_scheduler']['max_iter'] / len(train_loader))) 213 | for epoch in range(start_epoch, 
exp_cfg['MAX_EPOCHES']): 214 | train(epoch) 215 | if epoch % 1 == 0: 216 | print("\nValidation For Experiment: ", exp_dir) 217 | print(time.strftime('%H:%M:%S', time.localtime())) 218 | val(epoch) 219 | 220 | 221 | if __name__ == "__main__": 222 | main() 223 | -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) 2 | project (evaluate) 3 | 4 | SET(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}) 5 | set(CMAKE_CXX_STANDARD 11) 6 | set(CMAKE_CXX_FLAGS "-DCPU_ONLY -fopenmp") 7 | 8 | find_package(OpenCV REQUIRED) 9 | include_directories("${PROJECT_SOURCE_DIR}/include") 10 | 11 | add_executable(evaluate 12 | ${PROJECT_SOURCE_DIR}/src/evaluate.cpp 13 | ${PROJECT_SOURCE_DIR}/src/counter.cpp 14 | ${PROJECT_SOURCE_DIR}/src/lane_compare.cpp 15 | ${PROJECT_SOURCE_DIR}/src/spline.cpp 16 | ) 17 | target_link_libraries(evaluate ${OpenCV_LIBS}) 18 | -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/Run.sh: -------------------------------------------------------------------------------- 1 | root=/home/lion/SCNN_Pytorch/ 2 | exp=$1 3 | data_dir=/home/lion/Dataset/CULane/data/CULane/ 4 | detect_dir=${root}/experiments/${exp}/coord_output/ 5 | bin_dir=${root}/utils/lane_evaluation/CULane 6 | 7 | w_lane=30; 8 | iou=0.5; # Set iou to 0.3 or 0.5 9 | im_w=1640 10 | im_h=590 11 | frame=1 12 | list0=${data_dir}list/test_split/test0_normal.txt 13 | list1=${data_dir}list/test_split/test1_crowd.txt 14 | list2=${data_dir}list/test_split/test2_hlight.txt 15 | list3=${data_dir}list/test_split/test3_shadow.txt 16 | list4=${data_dir}list/test_split/test4_noline.txt 17 | list5=${data_dir}list/test_split/test5_arrow.txt 18 | list6=${data_dir}list/test_split/test6_curve.txt 19 | list7=${data_dir}list/test_split/test7_cross.txt 20 | 
list8=${data_dir}list/test_split/test8_night.txt 21 | out0=${detect_dir}../evaluate/out0_normal.txt 22 | out1=${detect_dir}../evaluate/out1_crowd.txt 23 | out2=${detect_dir}../evaluate/out2_hlight.txt 24 | out3=${detect_dir}../evaluate/out3_shadow.txt 25 | out4=${detect_dir}../evaluate/out4_noline.txt 26 | out5=${detect_dir}../evaluate/out5_arrow.txt 27 | out6=${detect_dir}../evaluate/out6_curve.txt 28 | out7=${detect_dir}../evaluate/out7_cross.txt 29 | out8=${detect_dir}../evaluate/out8_night.txt 30 | ${bin_dir}/evaluate -a $data_dir -d $detect_dir -i $data_dir -l $list0 -w $w_lane -t $iou -c $im_w -r $im_h -f $frame -o $out0 31 | ${bin_dir}/evaluate -a $data_dir -d $detect_dir -i $data_dir -l $list1 -w $w_lane -t $iou -c $im_w -r $im_h -f $frame -o $out1 32 | ${bin_dir}/evaluate -a $data_dir -d $detect_dir -i $data_dir -l $list2 -w $w_lane -t $iou -c $im_w -r $im_h -f $frame -o $out2 33 | ${bin_dir}/evaluate -a $data_dir -d $detect_dir -i $data_dir -l $list3 -w $w_lane -t $iou -c $im_w -r $im_h -f $frame -o $out3 34 | ${bin_dir}/evaluate -a $data_dir -d $detect_dir -i $data_dir -l $list4 -w $w_lane -t $iou -c $im_w -r $im_h -f $frame -o $out4 35 | ${bin_dir}/evaluate -a $data_dir -d $detect_dir -i $data_dir -l $list5 -w $w_lane -t $iou -c $im_w -r $im_h -f $frame -o $out5 36 | ${bin_dir}/evaluate -a $data_dir -d $detect_dir -i $data_dir -l $list6 -w $w_lane -t $iou -c $im_w -r $im_h -f $frame -o $out6 37 | ${bin_dir}/evaluate -a $data_dir -d $detect_dir -i $data_dir -l $list7 -w $w_lane -t $iou -c $im_w -r $im_h -f $frame -o $out7 38 | ${bin_dir}/evaluate -a $data_dir -d $detect_dir -i $data_dir -l $list8 -w $w_lane -t $iou -c $im_w -r $im_h -f $frame -o $out8 39 | cat ${detect_dir}/../evaluate/out*.txt > ${detect_dir}/../evaluate/${exp}_iou${iou}_split.txt 40 | -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/include/counter.hpp: 
-------------------------------------------------------------------------------- 1 | #ifndef COUNTER_HPP 2 | #define COUNTER_HPP 3 | 4 | #include "lane_compare.hpp" 5 | #include "hungarianGraph.hpp" 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | using namespace std; 12 | using namespace cv; 13 | 14 | // Before using this class, resize the lanes to im_width x im_height with resize_lane() from lane_compare.hpp 15 | class Counter 16 | { 17 | public: 18 | Counter(int _im_width, int _im_height, double _iou_threshold=0.4, int _lane_width=10):tp(0),fp(0),fn(0){ 19 | im_width = _im_width; 20 | im_height = _im_height; 21 | sim_threshold = _iou_threshold; 22 | lane_compare = new LaneCompare(_im_width, _im_height, _lane_width, LaneCompare::IOU); 23 | }; 24 | double get_precision(void); 25 | double get_recall(void); 26 | long getTP(void); 27 | long getFP(void); 28 | long getFN(void); 29 | // accumulates tp, fp and fn directly 30 | // lanes are first matched with the Hungarian algorithm 31 | vector<int> count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes); 32 | void makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2); 33 | 34 | private: 35 | double sim_threshold; 36 | int im_width; 37 | int im_height; 38 | long tp; 39 | long fp; 40 | long fn; 41 | LaneCompare *lane_compare; 42 | }; 43 | #endif 44 | -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/include/hungarianGraph.hpp: -------------------------------------------------------------------------------- 1 | #ifndef HUNGARIAN_GRAPH_HPP 2 | #define HUNGARIAN_GRAPH_HPP 3 | #include 4 | using namespace std; 5 | 6 | struct pipartiteGraph { 7 | vector<vector<double> > mat; 8 | vector<bool> leftUsed, rightUsed; 9 | vector<double> leftWeight, rightWeight; 10 | vector<int> rightMatch, leftMatch; 11 | int leftNum, rightNum; 12 | bool matchDfs(int u) { 13 | leftUsed[u] = true; 14 | for (int v = 0; v < rightNum; v++) { 15 | if (!rightUsed[v] && fabs(leftWeight[u] +
rightWeight[v] - mat[u][v]) < 1e-2) { 16 | rightUsed[v] = true; 17 | if (rightMatch[v] == -1 || matchDfs(rightMatch[v])) { 18 | rightMatch[v] = u; 19 | leftMatch[u] = v; 20 | return true; 21 | } 22 | } 23 | } 24 | return false; 25 | } 26 | void resize(int leftNum, int rightNum) { 27 | this->leftNum = leftNum; 28 | this->rightNum = rightNum; 29 | leftMatch.resize(leftNum); 30 | rightMatch.resize(rightNum); 31 | leftUsed.resize(leftNum); 32 | rightUsed.resize(rightNum); 33 | leftWeight.resize(leftNum); 34 | rightWeight.resize(rightNum); 35 | mat.resize(leftNum); 36 | for (int i = 0; i < leftNum; i++) mat[i].resize(rightNum); 37 | } 38 | void match() { 39 | for (int i = 0; i < leftNum; i++) leftMatch[i] = -1; 40 | for (int i = 0; i < rightNum; i++) rightMatch[i] = -1; 41 | for (int i = 0; i < rightNum; i++) rightWeight[i] = 0; 42 | for (int i = 0; i < leftNum; i++) { 43 | leftWeight[i] = -1e5; 44 | for (int j = 0; j < rightNum; j++) { 45 | if (leftWeight[i] < mat[i][j]) leftWeight[i] = mat[i][j]; 46 | } 47 | } 48 | 49 | for (int u = 0; u < leftNum; u++) { 50 | while (1) { 51 | for (int i = 0; i < leftNum; i++) leftUsed[i] = false; 52 | for (int i = 0; i < rightNum; i++) rightUsed[i] = false; 53 | if (matchDfs(u)) break; 54 | double d = 1e10; 55 | for (int i = 0; i < leftNum; i++) { 56 | if (leftUsed[i] ) { 57 | for (int j = 0; j < rightNum; j++) { 58 | if (!rightUsed[j]) d = min(d, leftWeight[i] + rightWeight[j] - mat[i][j]); 59 | } 60 | } 61 | } 62 | if (d == 1e10) return ; 63 | for (int i = 0; i < leftNum; i++) if (leftUsed[i]) leftWeight[i] -= d; 64 | for (int i = 0; i < rightNum; i++) if (rightUsed[i]) rightWeight[i] += d; 65 | } 66 | } 67 | } 68 | }; 69 | 70 | 71 | #endif // HUNGARIAN_GRAPH_HPP 72 | -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/include/lane_compare.hpp: -------------------------------------------------------------------------------- 1 | #ifndef LANE_COMPARE_HPP 2 | #define 
LANE_COMPARE_HPP 3 | 4 | #include "spline.hpp" 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace std; 10 | using namespace cv; 11 | 12 | class LaneCompare{ 13 | public: 14 | enum CompareMode{ 15 | IOU, 16 | Caltech 17 | }; 18 | 19 | LaneCompare(int _im_width, int _im_height, int _lane_width = 10, CompareMode _compare_mode = IOU){ 20 | im_width = _im_width; 21 | im_height = _im_height; 22 | compare_mode = _compare_mode; 23 | lane_width = _lane_width; 24 | } 25 | 26 | double get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2); 27 | void resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height); 28 | private: 29 | CompareMode compare_mode; 30 | int im_width; 31 | int im_height; 32 | int lane_width; 33 | Spline splineSolver; 34 | }; 35 | 36 | #endif 37 | -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/include/spline.hpp: -------------------------------------------------------------------------------- 1 | #ifndef SPLINE_HPP 2 | #define SPLINE_HPP 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | using namespace cv; 9 | using namespace std; 10 | 11 | struct Func { 12 | double a_x; 13 | double b_x; 14 | double c_x; 15 | double d_x; 16 | double a_y; 17 | double b_y; 18 | double c_y; 19 | double d_y; 20 | double h; 21 | }; 22 | class Spline { 23 | public: 24 | vector<Point2f> splineInterpTimes(const vector<Point2f> &tmp_line, int times); 25 | vector<Point2f> splineInterpStep(vector<Point2f> tmp_line, double step); 26 | vector<Func> cal_fun(const vector<Point2f> &point_v); 27 | }; 28 | #endif 29 | -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/src/counter.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | > File Name: counter.cpp 3 | > Author: Xingang Pan, Jun Li 4 | > Mail: px117@ie.cuhk.edu.hk 5 | > Created Time: Thu Jul 14 20:23:08 2016 6 |
************************************************************************/ 7 | 8 | #include "counter.hpp" 9 | #include 10 | 11 | double Counter::get_precision(void) 12 | { 13 | cerr<<"tp: "< Counter::count_im_pair(const vector > &anno_lanes, const vector > &detect_lanes) 48 | { 49 | vector anno_match(anno_lanes.size(), -1); 50 | vector detect_match; 51 | if(anno_lanes.empty()) 52 | { 53 | fp += detect_lanes.size(); 54 | return anno_match; 55 | } 56 | 57 | if(detect_lanes.empty()) 58 | { 59 | fn += anno_lanes.size(); 60 | return anno_match; 61 | } 62 | // hungarian match first 63 | 64 | // first calc similarity matrix 65 | vector > similarity(anno_lanes.size(), vector(detect_lanes.size(), 0)); 66 | for(int i=0; i &curr_anno_lane = anno_lanes[i]; 69 | for(int j=0; j &curr_detect_lane = detect_lanes[j]; 72 | similarity[i][j] = lane_compare->get_lane_similarity(ref(curr_anno_lane), ref(curr_detect_lane)); 73 | } 74 | } 75 | 76 | 77 | 78 | makeMatch(ref(similarity), ref(anno_match), ref(detect_match)); 79 | 80 | 81 | int curr_tp = 0; 82 | // count and add 83 | for(int i=0; i=0 && similarity[i][anno_match[i]] > sim_threshold) 86 | { 87 | curr_tp++; 88 | } 89 | else 90 | { 91 | anno_match[i] = -1; 92 | } 93 | } 94 | int curr_fn = anno_lanes.size() - curr_tp; 95 | int curr_fp = detect_lanes.size() - curr_tp; 96 | tp += curr_tp; 97 | fn += curr_fn; 98 | fp += curr_fp; 99 | return anno_match; 100 | } 101 | 102 | 103 | void Counter::makeMatch(const vector > &similarity, vector &match1, vector &match2) { 104 | int m = similarity.size(); 105 | int n = similarity[0].size(); 106 | pipartiteGraph gra; 107 | bool have_exchange = false; 108 | if (m > n) { 109 | have_exchange = true; 110 | swap(m, n); 111 | } 112 | gra.resize(m, n); 113 | for (int i = 0; i < gra.leftNum; i++) { 114 | for (int j = 0; j < gra.rightNum; j++) { 115 | if(have_exchange) 116 | gra.mat[i][j] = similarity[j][i]; 117 | else 118 | gra.mat[i][j] = similarity[i][j]; 119 | } 120 | } 121 | gra.match(); 122 | match1 
= gra.leftMatch; 123 | match2 = gra.rightMatch; 124 | if (have_exchange) swap(match1, match2); 125 | } 126 | -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/src/evaluate.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | > File Name: evaluate.cpp 3 | > Author: Xingang Pan, Jun Li 4 | > Mail: px117@ie.cuhk.edu.hk 5 | > Created Time: 2016年07月14日 星期四 18时28分45秒 6 | ************************************************************************/ 7 | 8 | #include "counter.hpp" 9 | #include "spline.hpp" 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | #include 21 | #include 22 | #include 23 | 24 | using namespace std; 25 | using namespace cv; 26 | 27 | void help(void) 28 | { 29 | cout<<"./evaluate [OPTIONS]"< > &lanes); 45 | void visualize(string &full_im_name, vector > &anno_lanes, vector > &detect_lanes, vector anno_match, int width_lane); 46 | void worker_func(vector &lines_list_v, int start, int end, int &tp, int &fp, int &fn); 47 | void update_tp_fp_fn(int &tp, int &fp, int &fn, int _tp, int _fp, int _fn); 48 | 49 | double get_precision(int tp, int fp, int fn) 50 | { 51 | cerr<<"tp: "< lines_list_v; 167 | string line; 168 | while(getline(ifs_im_list, line)) { 169 | lines_list_v.push_back(line); 170 | } 171 | ifs_im_list.close(); 172 | 173 | int TP=0, FP=0, FN=0; //result 174 | int NUM = lines_list_v.size(); 175 | int batch_size = NUM / NUM_PROCESS; 176 | vector thread_v; 177 | for (int i=0; iNUM) ? 
NUM:_end; 180 | thread_v.push_back(thread(worker_func, ref(lines_list_v), _start, _end, ref(TP), ref(FP), ref(FN))); 181 | } 182 | 183 | for (int i=0; i anno_match; 190 | // string sub_im_name; 191 | // int count = 0; 192 | 193 | 194 | // while(getline(ifs_im_list, sub_im_name)) 195 | // { 196 | // count++; 197 | // if (count < frame) 198 | // continue; 199 | // string full_im_name = im_dir + sub_im_name; 200 | // string sub_txt_name = sub_im_name.substr(0, sub_im_name.find_last_of(".")) + ".lines.txt"; 201 | // string anno_file_name = anno_dir + sub_txt_name; 202 | // string detect_file_name = detect_dir + sub_txt_name; 203 | // vector > anno_lanes; 204 | // vector > detect_lanes; 205 | // read_lane_file(anno_file_name, anno_lanes); 206 | // read_lane_file(detect_file_name, detect_lanes); 207 | // //cerr< > &lanes) 240 | { 241 | lanes.clear(); 242 | ifstream ifs_lane(file_name, ios::in); 243 | if(ifs_lane.fail()) 244 | { 245 | return; 246 | } 247 | 248 | string str_line; 249 | while(getline(ifs_lane, str_line)) 250 | { 251 | vector curr_lane; 252 | stringstream ss; 253 | ss<>x>>y) 256 | { 257 | curr_lane.push_back(Point2f(x, y)); 258 | } 259 | lanes.push_back(curr_lane); 260 | } 261 | 262 | ifs_lane.close(); 263 | } 264 | 265 | void visualize(string &full_im_name, vector > &anno_lanes, vector > &detect_lanes, vector anno_match, int width_lane) 266 | { 267 | Mat img = imread(full_im_name, 1); 268 | Mat img2 = imread(full_im_name, 1); 269 | vector curr_lane; 270 | vector p_interp; 271 | Spline splineSolver; 272 | Scalar color_B = Scalar(255, 0, 0); 273 | Scalar color_G = Scalar(0, 255, 0); 274 | Scalar color_R = Scalar(0, 0, 255); 275 | Scalar color_P = Scalar(255, 0, 255); 276 | Scalar color; 277 | for (int i=0; i= 0) 289 | { 290 | color = color_G; 291 | } 292 | else 293 | { 294 | color = color_G; 295 | } 296 | for (int n=0; n guard(myMutex); 346 | tp += _tp; 347 | fp += _fp; 348 | fn += _fn; 349 | } 350 | 351 | void worker_func(vector &lines_list_v, int start, int 
end, int &tp, int &fp, int &fn) 352 | { 353 | Counter counter(im_width, im_height, iou_threshold, width_lane); 354 | 355 | vector anno_match; 356 | string sub_im_name; 357 | int count = 0; 358 | 359 | for (int i=start; i > anno_lanes; 368 | vector > detect_lanes; 369 | 370 | read_lane_file(anno_file_name, ref(anno_lanes)); 371 | read_lane_file(detect_file_name, ref(detect_lanes)); 372 | 373 | anno_match = counter.count_im_pair(ref(anno_lanes), ref(detect_lanes)); 374 | } 375 | 376 | update_tp_fp_fn(ref(tp), ref(fp), ref(fn), counter.getTP(), counter.getFP(), counter.getFN()); 377 | } -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/src/lane_compare.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | > File Name: lane_compare.cpp 3 | > Author: Xingang Pan, Jun Li 4 | > Mail: px117@ie.cuhk.edu.hk 5 | > Created Time: Fri Jul 15 10:26:32 2016 6 | ************************************************************************/ 7 | 8 | #include "lane_compare.hpp" 9 | #include 10 | #include 11 | 12 | 13 | double LaneCompare::get_lane_similarity(const vector &lane1, const vector &lane2) 14 | { 15 | if(lane1.size()<2 || lane2.size()<2) 16 | { 17 | cerr<<"lane size must be greater or equal to 2"< p_interp1; 24 | vector p_interp2; 25 | if(lane1.size() == 2) 26 | { 27 | p_interp1 = lane1; 28 | } 29 | else 30 | { 31 | p_interp1 = splineSolver.splineInterpTimes(lane1, 50); 32 | } 33 | 34 | if(lane2.size() == 2) 35 | { 36 | p_interp2 = lane2; 37 | } 38 | else 39 | { 40 | p_interp2 = splineSolver.splineInterpTimes(lane2, 50); 41 | } 42 | 43 | Scalar color_white = Scalar(1); 44 | for(int n=0; n &curr_lane, int curr_width, int curr_height) 64 | { 65 | if(curr_width == im_width && curr_height == im_height) 66 | { 67 | return; 68 | } 69 | double x_scale = im_width/(double)curr_width; 70 | double y_scale = 
im_height/(double)curr_height; 71 | for(int n=0; n 2 | #include 3 | #include "spline.hpp" 4 | using namespace std; 5 | using namespace cv; 6 | 7 | vector Spline::splineInterpTimes(const vector& tmp_line, int times) { 8 | vector res; 9 | 10 | if(tmp_line.size() == 2) { 11 | double x1 = tmp_line[0].x; 12 | double y1 = tmp_line[0].y; 13 | double x2 = tmp_line[1].x; 14 | double y2 = tmp_line[1].y; 15 | 16 | for (int k = 0; k <= times; k++) { 17 | double xi = x1 + double((x2 - x1) * k) / times; 18 | double yi = y1 + double((y2 - y1) * k) / times; 19 | res.push_back(Point2f(xi, yi)); 20 | } 21 | } 22 | 23 | else if(tmp_line.size() > 2) 24 | { 25 | vector tmp_func; 26 | tmp_func = this->cal_fun(tmp_line); 27 | if (tmp_func.empty()) { 28 | cout << "in splineInterpTimes: cal_fun failed" << endl; 29 | return res; 30 | } 31 | for(int j = 0; j < tmp_func.size(); j++) 32 | { 33 | double delta = tmp_func[j].h / times; 34 | for(int k = 0; k < times; k++) 35 | { 36 | double t1 = delta*k; 37 | double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3); 38 | double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3); 39 | res.push_back(Point2f(x1, y1)); 40 | } 41 | } 42 | res.push_back(tmp_line[tmp_line.size() - 1]); 43 | } 44 | else { 45 | cerr << "in splineInterpTimes: not enough points" << endl; 46 | } 47 | return res; 48 | } 49 | vector Spline::splineInterpStep(vector tmp_line, double step) { 50 | vector res; 51 | /* 52 | if (tmp_line.size() == 2) { 53 | double x1 = tmp_line[0].x; 54 | double y1 = tmp_line[0].y; 55 | double x2 = tmp_line[1].x; 56 | double y2 = tmp_line[1].y; 57 | 58 | for (double yi = std::min(y1, y2); yi < std::max(y1, y2); yi += step) { 59 | double xi; 60 | if (yi == y1) xi = x1; 61 | else xi = (x2 - x1) / (y2 - y1) * (yi - y1) + x1; 62 | res.push_back(Point2f(xi, yi)); 63 | } 64 | }*/ 65 | if (tmp_line.size() == 2) { 66 | double x1 = tmp_line[0].x; 67 | double y1 = 
tmp_line[0].y; 68 | double x2 = tmp_line[1].x; 69 | double y2 = tmp_line[1].y; 70 | tmp_line[1].x = (x1 + x2) / 2; 71 | tmp_line[1].y = (y1 + y2) / 2; 72 | tmp_line.push_back(Point2f(x2, y2)); 73 | } 74 | if (tmp_line.size() > 2) { 75 | vector tmp_func; 76 | tmp_func = this->cal_fun(tmp_line); 77 | double ystart = tmp_line[0].y; 78 | double yend = tmp_line[tmp_line.size() - 1].y; 79 | bool down; 80 | if (ystart < yend) down = 1; 81 | else down = 0; 82 | if (tmp_func.empty()) { 83 | cerr << "in splineInterpStep: cal_fun failed" << endl; 84 | } 85 | 86 | for(int j = 0; j < tmp_func.size(); j++) 87 | { 88 | for(double t1 = 0; t1 < tmp_func[j].h; t1 += step) 89 | { 90 | double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3); 91 | double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3); 92 | res.push_back(Point2f(x1, y1)); 93 | } 94 | } 95 | res.push_back(tmp_line[tmp_line.size() - 1]); 96 | } 97 | else { 98 | cerr << "in splineInterpStep: not enough points" << endl; 99 | } 100 | return res; 101 | } 102 | 103 | vector Spline::cal_fun(const vector &point_v) 104 | { 105 | vector func_v; 106 | int n = point_v.size(); 107 | if(n<=2) { 108 | cout << "in cal_fun: point number less than 3" << endl; 109 | return func_v; 110 | } 111 | 112 | func_v.resize(point_v.size()-1); 113 | 114 | vector Mx(n); 115 | vector My(n); 116 | vector A(n-2); 117 | vector B(n-2); 118 | vector C(n-2); 119 | vector Dx(n-2); 120 | vector Dy(n-2); 121 | vector h(n-1); 122 | //vector func_v(n-1); 123 | 124 | for(int i = 0; i < n-1; i++) 125 | { 126 | h[i] = sqrt(pow(point_v[i+1].x - point_v[i].x, 2) + pow(point_v[i+1].y - point_v[i].y, 2)); 127 | } 128 | 129 | for(int i = 0; i < n-2; i++) 130 | { 131 | A[i] = h[i]; 132 | B[i] = 2*(h[i]+h[i+1]); 133 | C[i] = h[i+1]; 134 | 135 | Dx[i] = 6*( (point_v[i+2].x - point_v[i+1].x)/h[i+1] - (point_v[i+1].x - point_v[i].x)/h[i] ); 136 | Dy[i] = 6*( (point_v[i+2].y - 
point_v[i+1].y)/h[i+1] - (point_v[i+1].y - point_v[i].y)/h[i] ); 137 | } 138 | 139 | //TDMA 140 | C[0] = C[0] / B[0]; 141 | Dx[0] = Dx[0] / B[0]; 142 | Dy[0] = Dy[0] / B[0]; 143 | for(int i = 1; i < n-2; i++) 144 | { 145 | double tmp = B[i] - A[i]*C[i-1]; 146 | C[i] = C[i] / tmp; 147 | Dx[i] = (Dx[i] - A[i]*Dx[i-1]) / tmp; 148 | Dy[i] = (Dy[i] - A[i]*Dy[i-1]) / tmp; 149 | } 150 | Mx[n-2] = Dx[n-3]; 151 | My[n-2] = Dy[n-3]; 152 | for(int i = n-4; i >= 0; i--) 153 | { 154 | Mx[i+1] = Dx[i] - C[i]*Mx[i+2]; 155 | My[i+1] = Dy[i] - C[i]*My[i+2]; 156 | } 157 | 158 | Mx[0] = 0; 159 | Mx[n-1] = 0; 160 | My[0] = 0; 161 | My[n-1] = 0; 162 | 163 | for(int i = 0; i < n-1; i++) 164 | { 165 | func_v[i].a_x = point_v[i].x; 166 | func_v[i].b_x = (point_v[i+1].x - point_v[i].x)/h[i] - (2*h[i]*Mx[i] + h[i]*Mx[i+1]) / 6; 167 | func_v[i].c_x = Mx[i]/2; 168 | func_v[i].d_x = (Mx[i+1] - Mx[i]) / (6*h[i]); 169 | 170 | func_v[i].a_y = point_v[i].y; 171 | func_v[i].b_y = (point_v[i+1].y - point_v[i].y)/h[i] - (2*h[i]*My[i] + h[i]*My[i+1]) / 6; 172 | func_v[i].c_y = My[i]/2; 173 | func_v[i].d_y = (My[i+1] - My[i]) / (6*h[i]); 174 | 175 | func_v[i].h = h[i]; 176 | } 177 | return func_v; 178 | } 179 | -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/src_origin/counter.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 2 | > File Name: counter.cpp 3 | > Author: Xingang Pan, Jun Li 4 | > Mail: px117@ie.cuhk.edu.hk 5 | > Created Time: Thu Jul 14 20:23:08 2016 6 | ************************************************************************/ 7 | 8 | #include "counter.hpp" 9 | 10 | double Counter::get_precision(void) 11 | { 12 | cerr<<"tp: "< Counter::count_im_pair(const vector > &anno_lanes, const vector > &detect_lanes) 47 | { 48 | vector anno_match(anno_lanes.size(), -1); 49 | vector detect_match; 50 | if(anno_lanes.empty()) 
51 | { 52 | fp += detect_lanes.size(); 53 | return anno_match; 54 | } 55 | 56 | if(detect_lanes.empty()) 57 | { 58 | fn += anno_lanes.size(); 59 | return anno_match; 60 | } 61 | // hungarian match first 62 | 63 | // first calc similarity matrix 64 | vector > similarity(anno_lanes.size(), vector(detect_lanes.size(), 0)); 65 | for(int i=0; i &curr_anno_lane = anno_lanes[i]; 68 | for(int j=0; j &curr_detect_lane = detect_lanes[j]; 71 | similarity[i][j] = lane_compare->get_lane_similarity(curr_anno_lane, curr_detect_lane); 72 | } 73 | } 74 | 75 | 76 | 77 | makeMatch(similarity, anno_match, detect_match); 78 | 79 | 80 | int curr_tp = 0; 81 | // count and add 82 | for(int i=0; i=0 && similarity[i][anno_match[i]] > sim_threshold) 85 | { 86 | curr_tp++; 87 | } 88 | else 89 | { 90 | anno_match[i] = -1; 91 | } 92 | } 93 | int curr_fn = anno_lanes.size() - curr_tp; 94 | int curr_fp = detect_lanes.size() - curr_tp; 95 | tp += curr_tp; 96 | fn += curr_fn; 97 | fp += curr_fp; 98 | return anno_match; 99 | } 100 | 101 | 102 | void Counter::makeMatch(const vector > &similarity, vector &match1, vector &match2) { 103 | int m = similarity.size(); 104 | int n = similarity[0].size(); 105 | pipartiteGraph gra; 106 | bool have_exchange = false; 107 | if (m > n) { 108 | have_exchange = true; 109 | swap(m, n); 110 | } 111 | gra.resize(m, n); 112 | for (int i = 0; i < gra.leftNum; i++) { 113 | for (int j = 0; j < gra.rightNum; j++) { 114 | if(have_exchange) 115 | gra.mat[i][j] = similarity[j][i]; 116 | else 117 | gra.mat[i][j] = similarity[i][j]; 118 | } 119 | } 120 | gra.match(); 121 | match1 = gra.leftMatch; 122 | match2 = gra.rightMatch; 123 | if (have_exchange) swap(match1, match2); 124 | } 125 | -------------------------------------------------------------------------------- /utils/lane_evaluation/CULane/src_origin/evaluate.cpp: -------------------------------------------------------------------------------- 1 | /************************************************************************* 
2 | > File Name: evaluate.cpp 3 | > Author: Xingang Pan, Jun Li 4 | > Mail: px117@ie.cuhk.edu.hk 5 | > Created Time: 2016年07月14日 星期四 18时28分45秒 6 | ************************************************************************/ 7 | 8 | #include "counter.hpp" 9 | #include "spline.hpp" 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | using namespace std; 21 | using namespace cv; 22 | 23 | void help(void) 24 | { 25 | cout<<"./evaluate [OPTIONS]"< > &lanes); 41 | void visualize(string &full_im_name, vector > &anno_lanes, vector > &detect_lanes, vector anno_match, int width_lane); 42 | 43 | int main(int argc, char **argv) 44 | { 45 | // process params 46 | string anno_dir = "/data/driving/eval_data/anno_label/"; 47 | string detect_dir = "/data/driving/eval_data/predict_label/"; 48 | string im_dir = "/data/driving/eval_data/img/"; 49 | string list_im_file = "/data/driving/eval_data/img/all.txt"; 50 | string output_file = "./output.txt"; 51 | int width_lane = 10; 52 | double iou_threshold = 0.4; 53 | int im_width = 1920; 54 | int im_height = 1080; 55 | int oc; 56 | bool show = false; 57 | int frame = 1; 58 | while((oc = getopt(argc, argv, "ha:d:i:l:w:t:c:r:sf:o:")) != -1) 59 | { 60 | switch(oc) 61 | { 62 | case 'h': 63 | help(); 64 | return 0; 65 | case 'a': 66 | anno_dir = optarg; 67 | break; 68 | case 'd': 69 | detect_dir = optarg; 70 | break; 71 | case 'i': 72 | im_dir = optarg; 73 | break; 74 | case 'l': 75 | list_im_file = optarg; 76 | break; 77 | case 'w': 78 | width_lane = atoi(optarg); 79 | break; 80 | case 't': 81 | iou_threshold = atof(optarg); 82 | break; 83 | case 'c': 84 | im_width = atoi(optarg); 85 | break; 86 | case 'r': 87 | im_height = atoi(optarg); 88 | break; 89 | case 's': 90 | show = true; 91 | break; 92 | case 'f': 93 | frame = atoi(optarg); 94 | break; 95 | case 'o': 96 | output_file = optarg; 97 | break; 98 | } 99 | } 100 | 101 | 102 | 
cout<<"------------Configuration---------"< anno_match; 134 | string sub_im_name; 135 | int count = 0; 136 | while(getline(ifs_im_list, sub_im_name)) 137 | { 138 | count++; 139 | if (count < frame) 140 | continue; 141 | string full_im_name = im_dir + sub_im_name; 142 | string sub_txt_name = sub_im_name.substr(0, sub_im_name.find_last_of(".")) + ".lines.txt"; 143 | string anno_file_name = anno_dir + sub_txt_name; 144 | string detect_file_name = detect_dir + sub_txt_name; 145 | vector > anno_lanes; 146 | vector > detect_lanes; 147 | read_lane_file(anno_file_name, anno_lanes); 148 | read_lane_file(detect_file_name, detect_lanes); 149 | //cerr< > &lanes) 180 | { 181 | lanes.clear(); 182 | ifstream ifs_lane(file_name, ios::in); 183 | if(ifs_lane.fail()) 184 | { 185 | return; 186 | } 187 | 188 | string str_line; 189 | while(getline(ifs_lane, str_line)) 190 | { 191 | vector curr_lane; 192 | stringstream ss; 193 | ss<>x>>y) 196 | { 197 | curr_lane.push_back(Point2f(x, y)); 198 | } 199 | lanes.push_back(curr_lane); 200 | } 201 | 202 | ifs_lane.close(); 203 | } 204 | 205 | void visualize(string &full_im_name, vector > &anno_lanes, vector > &detect_lanes, vector anno_match, int width_lane) 206 | { 207 | Mat img = imread(full_im_name, 1); 208 | Mat img2 = imread(full_im_name, 1); 209 | vector curr_lane; 210 | vector p_interp; 211 | Spline splineSolver; 212 | Scalar color_B = Scalar(255, 0, 0); 213 | Scalar color_G = Scalar(0, 255, 0); 214 | Scalar color_R = Scalar(0, 0, 255); 215 | Scalar color_P = Scalar(255, 0, 255); 216 | Scalar color; 217 | for (int i=0; i= 0) 229 | { 230 | color = color_G; 231 | } 232 | else 233 | { 234 | color = color_G; 235 | } 236 | for (int n=0; n File Name: lane_compare.cpp 3 | > Author: Xingang Pan, Jun Li 4 | > Mail: px117@ie.cuhk.edu.hk 5 | > Created Time: Fri Jul 15 10:26:32 2016 6 | ************************************************************************/ 7 | 8 | #include "lane_compare.hpp" 9 | #include 10 | #include 11 | 12 | 13 | double 
LaneCompare::get_lane_similarity(const vector &lane1, const vector &lane2) 14 | { 15 | if(lane1.size()<2 || lane2.size()<2) 16 | { 17 | cerr<<"lane size must be greater or equal to 2"< p_interp1; 24 | vector p_interp2; 25 | if(lane1.size() == 2) 26 | { 27 | p_interp1 = lane1; 28 | } 29 | else 30 | { 31 | p_interp1 = splineSolver.splineInterpTimes(lane1, 50); 32 | } 33 | 34 | if(lane2.size() == 2) 35 | { 36 | p_interp2 = lane2; 37 | } 38 | else 39 | { 40 | p_interp2 = splineSolver.splineInterpTimes(lane2, 50); 41 | } 42 | 43 | Scalar color_white = Scalar(1); 44 | for(int n=0; n &curr_lane, int curr_width, int curr_height) 64 | { 65 | if(curr_width == im_width && curr_height == im_height) 66 | { 67 | return; 68 | } 69 | double x_scale = im_width/(double)curr_width; 70 | double y_scale = im_height/(double)curr_height; 71 | for(int n=0; n 2 | #include 3 | #include "spline.hpp" 4 | using namespace std; 5 | using namespace cv; 6 | 7 | vector Spline::splineInterpTimes(const vector& tmp_line, int times) { 8 | vector res; 9 | 10 | if(tmp_line.size() == 2) { 11 | double x1 = tmp_line[0].x; 12 | double y1 = tmp_line[0].y; 13 | double x2 = tmp_line[1].x; 14 | double y2 = tmp_line[1].y; 15 | 16 | for (int k = 0; k <= times; k++) { 17 | double xi = x1 + double((x2 - x1) * k) / times; 18 | double yi = y1 + double((y2 - y1) * k) / times; 19 | res.push_back(Point2f(xi, yi)); 20 | } 21 | } 22 | 23 | else if(tmp_line.size() > 2) 24 | { 25 | vector tmp_func; 26 | tmp_func = this->cal_fun(tmp_line); 27 | if (tmp_func.empty()) { 28 | cout << "in splineInterpTimes: cal_fun failed" << endl; 29 | return res; 30 | } 31 | for(int j = 0; j < tmp_func.size(); j++) 32 | { 33 | double delta = tmp_func[j].h / times; 34 | for(int k = 0; k < times; k++) 35 | { 36 | double t1 = delta*k; 37 | double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3); 38 | double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + 
tmp_func[j].d_y*pow(t1,3); 39 | res.push_back(Point2f(x1, y1)); 40 | } 41 | } 42 | res.push_back(tmp_line[tmp_line.size() - 1]); 43 | } 44 | else { 45 | cerr << "in splineInterpTimes: not enough points" << endl; 46 | } 47 | return res; 48 | } 49 | vector Spline::splineInterpStep(vector tmp_line, double step) { 50 | vector res; 51 | /* 52 | if (tmp_line.size() == 2) { 53 | double x1 = tmp_line[0].x; 54 | double y1 = tmp_line[0].y; 55 | double x2 = tmp_line[1].x; 56 | double y2 = tmp_line[1].y; 57 | 58 | for (double yi = std::min(y1, y2); yi < std::max(y1, y2); yi += step) { 59 | double xi; 60 | if (yi == y1) xi = x1; 61 | else xi = (x2 - x1) / (y2 - y1) * (yi - y1) + x1; 62 | res.push_back(Point2f(xi, yi)); 63 | } 64 | }*/ 65 | if (tmp_line.size() == 2) { 66 | double x1 = tmp_line[0].x; 67 | double y1 = tmp_line[0].y; 68 | double x2 = tmp_line[1].x; 69 | double y2 = tmp_line[1].y; 70 | tmp_line[1].x = (x1 + x2) / 2; 71 | tmp_line[1].y = (y1 + y2) / 2; 72 | tmp_line.push_back(Point2f(x2, y2)); 73 | } 74 | if (tmp_line.size() > 2) { 75 | vector tmp_func; 76 | tmp_func = this->cal_fun(tmp_line); 77 | double ystart = tmp_line[0].y; 78 | double yend = tmp_line[tmp_line.size() - 1].y; 79 | bool down; 80 | if (ystart < yend) down = 1; 81 | else down = 0; 82 | if (tmp_func.empty()) { 83 | cerr << "in splineInterpStep: cal_fun failed" << endl; 84 | } 85 | 86 | for(int j = 0; j < tmp_func.size(); j++) 87 | { 88 | for(double t1 = 0; t1 < tmp_func[j].h; t1 += step) 89 | { 90 | double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3); 91 | double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3); 92 | res.push_back(Point2f(x1, y1)); 93 | } 94 | } 95 | res.push_back(tmp_line[tmp_line.size() - 1]); 96 | } 97 | else { 98 | cerr << "in splineInterpStep: not enough points" << endl; 99 | } 100 | return res; 101 | } 102 | 103 | vector Spline::cal_fun(const vector &point_v) 104 | { 105 | 
vector func_v; 106 | int n = point_v.size(); 107 | if(n<=2) { 108 | cout << "in cal_fun: point number less than 3" << endl; 109 | return func_v; 110 | } 111 | 112 | func_v.resize(point_v.size()-1); 113 | 114 | vector Mx(n); 115 | vector My(n); 116 | vector A(n-2); 117 | vector B(n-2); 118 | vector C(n-2); 119 | vector Dx(n-2); 120 | vector Dy(n-2); 121 | vector h(n-1); 122 | //vector func_v(n-1); 123 | 124 | for(int i = 0; i < n-1; i++) 125 | { 126 | h[i] = sqrt(pow(point_v[i+1].x - point_v[i].x, 2) + pow(point_v[i+1].y - point_v[i].y, 2)); 127 | } 128 | 129 | for(int i = 0; i < n-2; i++) 130 | { 131 | A[i] = h[i]; 132 | B[i] = 2*(h[i]+h[i+1]); 133 | C[i] = h[i+1]; 134 | 135 | Dx[i] = 6*( (point_v[i+2].x - point_v[i+1].x)/h[i+1] - (point_v[i+1].x - point_v[i].x)/h[i] ); 136 | Dy[i] = 6*( (point_v[i+2].y - point_v[i+1].y)/h[i+1] - (point_v[i+1].y - point_v[i].y)/h[i] ); 137 | } 138 | 139 | //TDMA 140 | C[0] = C[0] / B[0]; 141 | Dx[0] = Dx[0] / B[0]; 142 | Dy[0] = Dy[0] / B[0]; 143 | for(int i = 1; i < n-2; i++) 144 | { 145 | double tmp = B[i] - A[i]*C[i-1]; 146 | C[i] = C[i] / tmp; 147 | Dx[i] = (Dx[i] - A[i]*Dx[i-1]) / tmp; 148 | Dy[i] = (Dy[i] - A[i]*Dy[i-1]) / tmp; 149 | } 150 | Mx[n-2] = Dx[n-3]; 151 | My[n-2] = Dy[n-3]; 152 | for(int i = n-4; i >= 0; i--) 153 | { 154 | Mx[i+1] = Dx[i] - C[i]*Mx[i+2]; 155 | My[i+1] = Dy[i] - C[i]*My[i+2]; 156 | } 157 | 158 | Mx[0] = 0; 159 | Mx[n-1] = 0; 160 | My[0] = 0; 161 | My[n-1] = 0; 162 | 163 | for(int i = 0; i < n-1; i++) 164 | { 165 | func_v[i].a_x = point_v[i].x; 166 | func_v[i].b_x = (point_v[i+1].x - point_v[i].x)/h[i] - (2*h[i]*Mx[i] + h[i]*Mx[i+1]) / 6; 167 | func_v[i].c_x = Mx[i]/2; 168 | func_v[i].d_x = (Mx[i+1] - Mx[i]) / (6*h[i]); 169 | 170 | func_v[i].a_y = point_v[i].y; 171 | func_v[i].b_y = (point_v[i+1].y - point_v[i].y)/h[i] - (2*h[i]*My[i] + h[i]*My[i+1]) / 6; 172 | func_v[i].c_y = My[i]/2; 173 | func_v[i].d_y = (My[i+1] - My[i]) / (6*h[i]); 174 | 175 | func_v[i].h = h[i]; 176 | } 177 | return func_v; 178 
| }
--------------------------------------------------------------------------------
/utils/lane_evaluation/tusimple/lane.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.linear_model import LinearRegression
import json as json


class LaneEval(object):
    lr = LinearRegression()
    pixel_thresh = 20
    pt_thresh = 0.85

    @staticmethod
    def get_angle(xs, y_samples):
        xs, ys = xs[xs >= 0], y_samples[xs >= 0]
        if len(xs) > 1:
            LaneEval.lr.fit(ys[:, None], xs)
            k = LaneEval.lr.coef_[0]
            theta = np.arctan(k)
        else:
            theta = 0
        return theta

    @staticmethod
    def line_accuracy(pred, gt, thresh):
        pred = np.array([p if p >= 0 else -100 for p in pred])
        gt = np.array([g if g >= 0 else -100 for g in gt])
        return np.sum(np.where(np.abs(pred - gt) < thresh, 1., 0.)) / len(gt)

    @staticmethod
    def bench(pred, gt, y_samples, running_time):
        if any(len(p) != len(y_samples) for p in pred):
            raise Exception('Format of lanes error.')
        if running_time > 200 or len(gt) + 2 < len(pred):
            return 0., 0., 1.
        angles = [LaneEval.get_angle(np.array(x_gts), np.array(y_samples)) for x_gts in gt]
        threshs = [LaneEval.pixel_thresh / np.cos(angle) for angle in angles]
        line_accs = []
        fp, fn = 0., 0.
        matched = 0.
        for x_gts, thresh in zip(gt, threshs):
            accs = [LaneEval.line_accuracy(np.array(x_preds), np.array(x_gts), thresh) for x_preds in pred]
            max_acc = np.max(accs) if len(accs) > 0 else 0.
            if max_acc < LaneEval.pt_thresh:
                fn += 1
            else:
                matched += 1
            line_accs.append(max_acc)
        fp = len(pred) - matched
        if len(gt) > 4 and fn > 0:
            fn -= 1
        s = sum(line_accs)
        if len(gt) > 4:
            s -= min(line_accs)
        return s / max(min(4.0, len(gt)), 1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(min(len(gt), 4.) , 1.)

    @staticmethod
    def bench_one_submit(pred_file, gt_file):
        try:
            json_pred = [json.loads(line) for line in open(pred_file).readlines()]
        except BaseException as e:
            raise Exception('Fail to load json file of the prediction.')
        json_gt = [json.loads(line) for line in open(gt_file).readlines()]
        if len(json_gt) != len(json_pred):
            raise Exception('We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}
        accuracy, fp, fn = 0., 0., 0.
        for pred in json_pred:
            if 'raw_file' not in pred or 'lanes' not in pred or 'run_time' not in pred:
                raise Exception('raw_file or lanes or run_time not in some predictions.')
            raw_file = pred['raw_file']
            pred_lanes = pred['lanes']
            run_time = pred['run_time']
            if raw_file not in gts:
                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]
            gt_lanes = gt['lanes']
            y_samples = gt['h_samples']
            try:
                a, p, n = LaneEval.bench(pred_lanes, gt_lanes, y_samples, run_time)
            except BaseException as e:
                raise Exception('Format of lanes error.')
            accuracy += a
            fp += p
            fn += n
        num = len(gts)
        # the first return parameter is the default ranking parameter
        return json.dumps([
            {'name': 'Accuracy', 'value': accuracy / num, 'order': 'desc'},
            {'name': 'FP', 'value': fp / num, 'order': 'asc'},
            {'name': 'FN', 'value': fn / num, 'order': 'asc'}
        ])


if __name__ == '__main__':
    import sys
    try:
        if len(sys.argv) != 3:
            raise Exception('Invalid input arguments')
        print(LaneEval.bench_one_submit(sys.argv[1], sys.argv[2]))
    except Exception as e:
        # Exceptions have no `.message` attribute in Python 3; use str(e) instead.
        print(e)
        sys.exit(str(e))

--------------------------------------------------------------------------------
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
from torch.optim.lr_scheduler import _LRScheduler


class PolyLR(_LRScheduler):
    def __init__(self, optimizer, pow, max_iter, min_lrs=1e-20, last_epoch=-1, warmup=0):
        """
        :param warmup: how many steps for linearly warmup lr
        """
        self.pow = pow
        self.max_iter = max_iter
        if not isinstance(min_lrs, list) and not isinstance(min_lrs, tuple):
            self.min_lrs = [min_lrs] * len(optimizer.param_groups)
        else:
            # A list or tuple gives one minimum lr per parameter group.
            # Without this branch self.min_lrs would be undefined in get_lr().
            self.min_lrs = list(min_lrs)

        assert isinstance(warmup, int), "The type of warmup is incorrect, got {}".format(type(warmup))
        self.warmup = max(warmup, 0)

        super(PolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Linear warmup phase
        if self.last_epoch < self.warmup:
            return [base_lr / self.warmup * (self.last_epoch+1) for base_lr in self.base_lrs]

        # Polynomial decay from base_lr down to min_lr
        if self.last_epoch < self.max_iter:
            coeff = (1 - (self.last_epoch-self.warmup) / (self.max_iter-self.warmup)) ** self.pow
        else:
            coeff = 0
        return [(base_lr - min_lr) * coeff + min_lr
                for base_lr, min_lr in zip(self.base_lrs, self.min_lrs)]

--------------------------------------------------------------------------------
/utils/prob2lines/getLane.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np


def getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape=None):
    """
    Arguments:
    ----------
    prob_map: prob map for single lane, np array size (h, w)
    resize_shape: reshape size target, (H, W)

    Return:
    ----------
    coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
    """
    if resize_shape is None:
        resize_shape = prob_map.shape
    h, w = prob_map.shape
    H, W = resize_shape

    coords = np.zeros(pts)
    for i in range(pts):
        y = int((H - 10 - i * y_px_gap) * h / H)
        if y < 0:
            break
        line = prob_map[y, :]
        id = np.argmax(line)
        if line[id] > thresh:
            coords[i] = int(id / w * W)
    if (coords > 0).sum() < 2:
        coords = np.zeros(pts)
    return coords


def prob2lines_tusimple(seg_pred, exist, resize_shape=None, smooth=True, y_px_gap=10, pts=None, thresh=0.3):
    """
    Arguments:
    ----------
    seg_pred: np.array size (5, h, w)
    resize_shape: reshape size target, (H, W)
    exist: list of existence, e.g. [0, 1, 1, 0]
    smooth: whether to smooth the probability or not
    y_px_gap: y pixel gap for sampling
    pts: how many points for one lane
    thresh: probability threshold

    Return:
    ----------
    coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
    """
    if resize_shape is None:
        resize_shape = seg_pred.shape[1:]  # seg_pred (5, h, w)
    _, h, w = seg_pred.shape
    H, W = resize_shape
    coordinates = []

    if pts is None:
        pts = round(H / 2 / y_px_gap)

    seg_pred = np.ascontiguousarray(np.transpose(seg_pred, (1, 2, 0)))
    for i in range(4):
        prob_map = seg_pred[..., i + 1]
        if smooth:
            prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
        if exist[i] > 0:
            coords = getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape)
            if (coords>0).sum() < 2:
                continue
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

    return coordinates


def getLane_CULane(prob_map, y_px_gap, pts, thresh, resize_shape=None):
    """
    Arguments:
    ----------
    prob_map: prob map for single lane, np array size (h, w)
    resize_shape: reshape size target, (H, W)
    Return:
    ----------
    coords: x coords bottom up every y_px_gap px, 0 for non-exist, in resized shape
    """
    if resize_shape is None:
        resize_shape = prob_map.shape
    h, w = prob_map.shape
    H, W = resize_shape

    coords = np.zeros(pts)
    for i in range(pts):
        y = int(h - i * y_px_gap / H * h - 1)
        if y < 0:
            break
        line = prob_map[y, :]
        id = np.argmax(line)
        if line[id] > thresh:
            coords[i] = int(id / w * W)
    if (coords > 0).sum() < 2:
        coords = np.zeros(pts)
    return coords


def prob2lines_CULane(seg_pred, exist, resize_shape=None, smooth=True, y_px_gap=20, pts=None, thresh=0.3):
    """
    Arguments:
    ----------
    seg_pred: np.array size (5, h, w)
    resize_shape: reshape size target, (H, W)
    exist: list of existence, e.g. [0, 1, 1, 0]
    smooth: whether to smooth the probability or not
    y_px_gap: y pixel gap for sampling
    pts: how many points for one lane
    thresh: probability threshold
    Return:
    ----------
    coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
    """
    if resize_shape is None:
        resize_shape = seg_pred.shape[1:]  # seg_pred (5, h, w)
    _, h, w = seg_pred.shape
    H, W = resize_shape
    coordinates = []

    if pts is None:
        pts = round(H / 2 / y_px_gap)

    seg_pred = np.ascontiguousarray(np.transpose(seg_pred, (1, 2, 0)))
    for i in range(4):
        prob_map = seg_pred[..., i + 1]
        if smooth:
            prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
        if exist[i] > 0:
            coords = getLane_CULane(prob_map, y_px_gap, pts, thresh, resize_shape)
            if (coords>0).sum() < 2:
                continue
            coordinates.append([[coords[j], H - 1 - j * y_px_gap] for j in range(pts) if coords[j] > 0])

    return coordinates

--------------------------------------------------------------------------------
/utils/tensorboard.py:
--------------------------------------------------------------------------------
# Code copied from pytorch-tutorial
# https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/04-utils/tensorboard/logger.py
import tensorflow as tf
import numpy as np
from PIL import Image
import scipy.misc
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO         # Python 3.x


class TensorBoard(object):

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to an in-memory buffer
            try:
                s = StringIO()
            except NameError:
                # StringIO is only imported on Python 2.7; use BytesIO on Python 3.x
                s = BytesIO()
            # scipy.misc.toimage(img).save(s, format="png")
            Image.fromarray(img).save(s, format='png')

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values**2))

        # Drop the start
        # of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()

--------------------------------------------------------------------------------
/utils/transforms/__init__.py:
--------------------------------------------------------------------------------
from .transforms import *
from .data_augmentation import *

--------------------------------------------------------------------------------
/utils/transforms/data_augmentation.py:
--------------------------------------------------------------------------------
import random
import numpy as np
import cv2

from utils.transforms.transforms import CustomTransform


class RandomFlip(CustomTransform):
    def __init__(self, prob_x=0, prob_y=0):
        """
        Arguments:
        ----------
        prob_x: range [0, 1], probability to use horizontal flip, setting to 0 means disabling flip
        prob_y: range [0, 1], probability to use vertical flip
        """
        self.prob_x = prob_x
        self.prob_y = prob_y

    def __call__(self, sample):
        img = sample.get('img').copy()
        segLabel = sample.get('segLabel', None)
        if segLabel is not None:
            segLabel = segLabel.copy()

        flip_x = np.random.choice([False, True], p=(1 - self.prob_x, self.prob_x))
        flip_y = np.random.choice([False, True], p=(1 - self.prob_y, self.prob_y))
        if flip_x:
            img = np.ascontiguousarray(np.flip(img, axis=1))
            if segLabel is not None:
                segLabel = np.ascontiguousarray(np.flip(segLabel, axis=1))

        if flip_y:
            img = np.ascontiguousarray(np.flip(img, axis=0))
            if segLabel is not None:
                segLabel = np.ascontiguousarray(np.flip(segLabel, axis=0))

        _sample = sample.copy()
        _sample['img'] = img
        _sample['segLabel'] = segLabel
        return _sample


class Darkness(CustomTransform):
    def __init__(self, coeff):
        assert coeff >= 1., "Darkness coefficient must be at least 1"
        self.coeff = coeff

    def __call__(self, sample):
        img = sample.get('img')
        # Divide by a random factor in [1, coeff] to darken the image
        coeff = np.random.uniform(1., self.coeff)
        img = (img.astype('float32') / coeff).astype('uint8')

        _sample = sample.copy()
        _sample['img'] = img
        return _sample

--------------------------------------------------------------------------------
/utils/transforms/transforms.py:
--------------------------------------------------------------------------------
import cv2

import numpy as np
import torch
from torchvision.transforms import Normalize as Normalize_th


class CustomTransform:
    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def __str__(self):
        return self.__class__.__name__

    def __eq__(self, name):
        return str(self) == name

    def __iter__(self):
        def iter_fn():
            for t in [self]:
                yield t
        return iter_fn()

    def __contains__(self, name):
        for t in self.__iter__():
            if isinstance(t, Compose):
                if name in t:
                    return True
            elif name == t:
                return True
        return False


class Compose(CustomTransform):
    """
    Each transform in Compose is called with the sample dict returned by the previous transform.
    """
    def __init__(self, *transforms):
        self.transforms = [*transforms]

    def __call__(self, sample):
        for t in self.transforms:
            sample = t(sample)
        return sample

    def __iter__(self):
        return iter(self.transforms)

    def modules(self):
        yield self
        for t in self.transforms:
            if isinstance(t, Compose):
                for _t in t.modules():
                    yield _t
            else:
                yield t


class Resize(CustomTransform):
    def __init__(self, size):
        if isinstance(size, int):
            size = (size, size)
        self.size = size  # (W, H)

    def __call__(self, sample):
        img = sample.get('img')
        segLabel = sample.get('segLabel', None)

        img = cv2.resize(img, self.size, interpolation=cv2.INTER_CUBIC)
        if segLabel is not None:
            segLabel = cv2.resize(segLabel, self.size, interpolation=cv2.INTER_NEAREST)

        _sample = sample.copy()
        _sample['img'] = img
        _sample['segLabel'] = segLabel
        return _sample

    def reset_size(self, size):
        if isinstance(size, int):
            size = (size, size)
        self.size = size


class RandomResize(Resize):
    """
    Resize to (w, h), where w randomly samples from (minW, maxW) and h randomly samples from (minH, maxH)
    """
    def __init__(self, minW, maxW, minH=None, maxH=None, batch=False):
        if minH is None or maxH is None:
            minH, maxH = minW, maxW
        super(RandomResize, self).__init__((minW, minH))
        self.minW = minW
        self.maxW = maxW
        self.minH = minH
        self.maxH = maxH
        self.batch = batch

    def random_set_size(self):
        w = np.random.randint(self.minW, self.maxW+1)
        h = np.random.randint(self.minH, self.maxH+1)
        self.reset_size((w, h))


class Rotation(CustomTransform):
    def __init__(self, theta):
        self.theta = theta

    def __call__(self, sample):
        img = sample.get('img')
        segLabel = sample.get('segLabel', None)

        u = np.random.uniform()
        degree = (u-0.5) * self.theta
        R = cv2.getRotationMatrix2D((img.shape[1]//2, img.shape[0]//2), degree, 1)
        img = cv2.warpAffine(img, R, (img.shape[1], img.shape[0]), flags=cv2.INTER_LINEAR)
        if segLabel is not None:
            segLabel = cv2.warpAffine(segLabel, R, (segLabel.shape[1], segLabel.shape[0]), flags=cv2.INTER_NEAREST)

        _sample = sample.copy()
        _sample['img'] = img
        _sample['segLabel'] = segLabel
        return _sample

    def reset_theta(self, theta):
        self.theta = theta


class Normalize(CustomTransform):
    def __init__(self, mean, std):
        self.transform = Normalize_th(mean, std)

    def __call__(self, sample):
        img = sample.get('img')

        img = self.transform(img)

        _sample = sample.copy()
        _sample['img'] = img
        return _sample


class ToTensor(CustomTransform):
    def __init__(self, dtype=torch.float):
        self.dtype = dtype

    def __call__(self, sample):
        img = sample.get('img')
        segLabel = sample.get('segLabel', None)
        exist = sample.get('exist', None)

        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).type(self.dtype) / 255.
        if segLabel is not None:
            segLabel = torch.from_numpy(segLabel).type(torch.long)
        if exist is not None:
            exist = torch.from_numpy(exist).type(torch.float32)  # BCEloss requires float tensor

        _sample = sample.copy()
        _sample['img'] = img
        _sample['segLabel'] = segLabel
        _sample['exist'] = exist
        return _sample

--------------------------------------------------------------------------------
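The transforms in `utils/transforms/transforms.py` share one convention: each takes a sample dict, copies it, and returns the modified copy, while `Compose` chains them and supports lookup by class name. The sketch below illustrates that convention using only the standard library. `MiniTransform`, `MiniCompose`, `Halve` and `Reverse` are hypothetical stand-ins written for this illustration, not classes from the repository, and plain lists stand in for image arrays.

```python
# Stand-alone sketch of the sample-dict convention used by utils/transforms.
# All class names here are hypothetical stand-ins, not part of the repository.


class MiniTransform:
    def __call__(self, sample):
        raise NotImplementedError

    def __str__(self):
        return self.__class__.__name__

    def __eq__(self, name):
        # Compare by class name so a transform can be matched against a string.
        return str(self) == name


class MiniCompose(MiniTransform):
    def __init__(self, *transforms):
        self.transforms = list(transforms)

    def __call__(self, sample):
        # Feed the output of each transform into the next one.
        for t in self.transforms:
            sample = t(sample)
        return sample

    def __contains__(self, name):
        # Search nested compositions recursively for a transform name.
        for t in self.transforms:
            if isinstance(t, MiniCompose):
                if name in t:
                    return True
            elif t == name:
                return True
        return False


class Halve(MiniTransform):
    """Halve every value in sample['img'] (loosely analogous to Darkness)."""

    def __call__(self, sample):
        out = sample.copy()
        out['img'] = [v / 2 for v in sample['img']]
        return out


class Reverse(MiniTransform):
    """Reverse img and segLabel together (loosely analogous to a flip)."""

    def __call__(self, sample):
        out = sample.copy()
        out['img'] = sample['img'][::-1]
        seg = sample.get('segLabel')
        out['segLabel'] = None if seg is None else seg[::-1]
        return out


pipeline = MiniCompose(Reverse(), MiniCompose(Halve()))
sample = {'img': [2, 4, 6], 'segLabel': [0, 1, 1]}
result = pipeline(sample)
print(result)
print('Halve' in pipeline, 'Resize' in pipeline)
```

Because every transform works on a shallow copy of the dict and builds new values for the keys it changes, the input `sample` is left untouched, which is the same reason the repository's transforms call `sample.copy()` before assigning.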