├── .gitignore
├── README.md
├── configs
│   └── flamnet
│       ├── flamnet_dla34_culane.py
│       └── flamnet_resnet34_culane.py
├── flamnet
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── culane.py
│   │   ├── curvelanes.py
│   │   ├── llamas.py
│   │   ├── process
│   │   │   ├── __init__.py
│   │   │   ├── generate_lane_line.py
│   │   │   ├── process.py
│   │   │   └── transforms.py
│   │   ├── registry.py
│   │   └── tusimple.py
│   ├── engine
│   │   ├── __init__.py
│   │   ├── optimizer.py
│   │   ├── registry.py
│   │   ├── runner.py
│   │   └── scheduler.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── dla34.py
│   │   │   ├── resnet.py
│   │   │   └── topformer.py
│   │   ├── heads
│   │   │   ├── __init__.py
│   │   │   └── flamnet_head.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── accuracy.py
│   │   │   ├── focal_loss.py
│   │   │   └── lineiou_loss.py
│   │   ├── necks
│   │   │   ├── __init__.py
│   │   │   ├── fpn.py
│   │   │   ├── pafpn.py
│   │   │   └── ppam_dsaformer.py
│   │   ├── nets
│   │   │   ├── __init__.py
│   │   │   └── detector.py
│   │   ├── registry.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── dynamic_assign.py
│   │       ├── roi_gather.py
│   │       └── seg_decoder.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── csrc
│   │   │   ├── nms.cpp
│   │   │   └── nms_kernel.cu
│   │   └── nms.py
│   └── utils
│       ├── __init__.py
│       ├── config.py
│       ├── culane_metric.py
│       ├── lane.py
│       ├── llamas_metric.py
│       ├── llamas_utils.py
│       ├── logger.py
│       ├── net_utils.py
│       ├── recorder.py
│       ├── registry.py
│       ├── tusimple_metric.py
│       └── visualization.py
├── main.py
├── requirements.txt
├── setup.py
└── tools
    ├── detect.py
    └── generate_seg_tusimple.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
work_dirs
data
cache/

__pycache__/
*/*.un~
.*.swp


logs/
*.egg-info/
*.egg
*.eggs
.ipynb_checkpoints/

output.txt
.vscode/*
.DS_Store
tmp.*
*.pt
*.pth
*.un~
*.so
build
*.jpg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FLAMNet
PyTorch implementation of the paper "**FLAMNet: A Flexible Line Anchor Mechanism Network for Lane Detection**" by Hao Ran, Yunfei Yin, Faliang Huang, and Xianjian Bao.

## CULane demo

https://user-images.githubusercontent.com/79684320/233837770-5d0b6579-ff7e-4969-bc3f-03bb2dcf05a9.mp4

## Real traffic scene test demo

* **Test of the FLAMNet-DLA34 model trained on the CULane dataset in real traffic scenarios (Chongqing, China). The demo video is uploaded to Google Drive.**

[![IMAGE ALT TEXT](https://user-images.githubusercontent.com/79684320/233836692-a980b0c1-3ed8-412e-b573-ef6e4d620c31.png)](https://drive.google.com/file/d/1V4gHCJGESfLwda-4dflLajrzL4gEgVhA/view?usp=sharing)

* **Performance comparison between FLAMNet and CLRNet in real traffic scenarios (Chongqing, China). The demo video is uploaded to Google Drive.**

[![IMAGE ALT TEXT](https://user-images.githubusercontent.com/79684320/233836708-fc54aa48-beea-4b2e-865e-f4ff49c96c5a.png)](https://drive.google.com/file/d/1kbuZM1sK7lv_EbDmXL6GKLMJh_hz-DMr/view?usp=sharing)

## Results
![FLAMNet_CULane](https://user-images.githubusercontent.com/79684320/234516261-e433c528-da16-4225-abc4-3ac2c67faac4.jpg)

### CULane

| Method | Backbone | F1@50 | F1w | FPS |
| :--- | :--- | :---: | :---: | :---:|
| FLAMNet | [ResNet-34](https://drive.google.com/file/d/1mtX-lf7T1F88j7BIB6agG6erIPDkUYvI/view?usp=sharing) | 80.15 | 80.93 | 93 |
| FLAMNet | [DLA-34](https://drive.google.com/file/d/1SK8rr7jHhR_8sLynLQQwdUnO3yCeO1Nu/view?usp=sharing) | 80.67 | 82.31 | 101 |

## Introduction
![FLAMNet](https://user-images.githubusercontent.com/79684320/233835753-07905d1a-ff30-44ff-9ea8-d68a03030781.png)
- FLAMNet, a lane detection network with a flexible line anchor mechanism, adopts a model architecture in which a CNN and a Transformer are connected in series.
- FLAMNet achieves state-of-the-art results on the CULane, TuSimple, and LLAMAS datasets.


## Installation

### Prerequisites
Tested only on Ubuntu 18.04 and 20.04 with:
- Python >= 3.8 (tested with Python 3.8)
- PyTorch >= 1.6 (tested with PyTorch 1.6)
- CUDA (tested with CUDA 10.2)
- Other dependencies described in `requirements.txt`

### Clone this repository
Clone this code to your workspace.
We will refer to this directory as `$FLAMNET_ROOT`.
```Shell
git clone https://github.com/RanHao-cq/FLAMNet.git
```

### Create a conda virtual environment and activate it (conda is optional)

```Shell
conda create -n flamnet python=3.8 -y
conda activate flamnet
```

### Install dependencies

```Shell
# Install PyTorch first; the cudatoolkit version should match the CUDA version installed on your system.

conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Or you can install via pip
pip install torch==1.8.0 torchvision==0.9.0

# Install python packages
python setup.py build develop
```

### Data preparation

#### CULane

Download [CULane](https://xingangpan.github.io/projects/CULane.html). Then extract it to `$CULANEROOT` and create a link to the `data` directory.

```Shell
cd $FLAMNET_ROOT
mkdir -p data
ln -s $CULANEROOT data/CULane
```

For CULane, you should have a structure like this:
```
$CULANEROOT/driver_xx_xxframe # data folders x6
$CULANEROOT/laneseg_label_w16 # lane segmentation labels
$CULANEROOT/list # data lists
```


#### TuSimple
Download [TuSimple](https://github.com/TuSimple/tusimple-benchmark/issues/3). Then extract it to `$TUSIMPLEROOT` and create a link to the `data` directory.

```Shell
cd $FLAMNET_ROOT
mkdir -p data
ln -s $TUSIMPLEROOT data/tusimple
```

For TuSimple, you should have a structure like this:
```
$TUSIMPLEROOT/clips # data folders
$TUSIMPLEROOT/label_data_xxxx.json # label json files x4
$TUSIMPLEROOT/test_tasks_0627.json # test tasks json file
$TUSIMPLEROOT/test_label.json # test label json file
```

For TuSimple, segmentation annotations are not provided, so we need to generate them from the JSON annotations:

```Shell
python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
# this will generate a seg_label directory
```

#### LLAMAS
Download [LLAMAS](https://unsupervised-llamas.com/llamas/). Then extract it to `$LLAMASROOT` and create a link to the `data` directory.

```Shell
cd $FLAMNET_ROOT
mkdir -p data
ln -s $LLAMASROOT data/llamas
```

Unzip both files (`color_images.zip` and `labels.zip`) into the same directory (e.g., `data/llamas/`), which will be the dataset's root. For LLAMAS, you should have a structure like this:
```
$LLAMASROOT/color_images/train # data folders
$LLAMASROOT/color_images/test # data folders
$LLAMASROOT/color_images/valid # data folders
$LLAMASROOT/labels/train # labels folders
$LLAMASROOT/labels/valid # labels folders
```


## Getting Started

### Training
For training, run
```Shell
python main.py [configs/path_to_your_config] --gpus [gpu_num]
```

For example, run
```Shell
python main.py configs/flamnet/flamnet_dla34_culane.py --gpus 0
```

### Validation
For testing or validation, run
```Shell
python main.py [configs/path_to_your_config] --[test|validate] --load_from [path_to_your_model] --gpus [gpu_num]
```

For example, run
```Shell
python main.py configs/flamnet/flamnet_dla34_culane.py --validate --load_from culane_dla34.pth --gpus 0
```
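
### Note on the length of the LR schedule

The CULane configs schedule the learning rate per iteration rather than per epoch (`lr_update_by_epoch = False`), with `CosineAnnealingLR` annealed over `total_iter = (88880 // batch_size) * epochs`, where 88,880 is the number of CULane training images. A quick sanity check of the arithmetic in `configs/flamnet/flamnet_dla34_culane.py` (this snippet only reproduces the config's computation; it is not part of the training code):

```Python
batch_size = 24
epochs = 15
iters_per_epoch = 88880 // batch_size  # 3703 iterations per epoch
total_iter = iters_per_epoch * epochs  # 55545 iterations in total
print(iters_per_epoch, total_iter)     # total_iter is the T_max passed to CosineAnnealingLR
```

If you train with a different batch size, `total_iter` adapts automatically, but consider adjusting `lr` as well; the config comment notes that `3e-4` was used for batch size 8.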

--------------------------------------------------------------------------------
/configs/flamnet/flamnet_dla34_culane.py:
--------------------------------------------------------------------------------
net = dict(type='Detector', )

norm_cfg = dict(type='SyncBN', requires_grad=True)

backbone = dict(
    type='DLAWrapper',
    dla='dla34',
    pretrained=True,
)

num_points = 72
max_lanes = 4
sample_y = range(589, 230, -20)

heads = dict(type='FLAMNetHead',
             num_priors=192,
             refine_layers=3,
             fc_hidden_dim=64,
             sample_points=36)

iou_loss_weight = 2.
cls_loss_weight = 2.
xyt_loss_weight = 0.2
seg_loss_weight = 1.0

work_dirs = "work_dirs/flamnet/dla34_culane"

neck = dict(
    type='PPAM_DSAformer',
    in_channels=[64, 128, 256, 512],
    channels=[32, 64, 128, 160],
    out_channels=[None, 64, 64, 64],
    decode_out_indices=[1, 2, 3],
    depths=4,
    num_heads=8,
    c2t_stride=2,
    drop_path_rate=0.2,
    norm_cfg=norm_cfg,
)


test_parameters = dict(conf_threshold=0.4, nms_thres=50, nms_topk=max_lanes)

epochs = 15
batch_size = 24

optimizer = dict(type='AdamW', lr=0.6e-3)  # 3e-4 for batch size 8
total_iter = (88880 // batch_size) * epochs
scheduler = dict(type='CosineAnnealingLR', T_max=total_iter)

eval_ep = 1
save_ep = 1

img_norm = dict(mean=[103.939, 116.779, 123.68], std=[1., 1., 1.])
ori_img_w = 1640
ori_img_h = 590
img_w = 800
img_h = 320
cut_height = 270

train_process = [
    dict(
        type='GenerateLaneLine',
        transforms=[
            dict(name='Resize',
                 parameters=dict(size=dict(height=img_h, width=img_w)),
                 p=1.0),
            dict(name='HorizontalFlip', parameters=dict(p=1.0), p=0.5),
            dict(name='ChannelShuffle', parameters=dict(p=1.0), p=0.1),
            dict(name='MultiplyAndAddToBrightness',
                 parameters=dict(mul=(0.85, 1.15), add=(-10, 10)),
                 p=0.6),
            dict(name='AddToHueAndSaturation',
                 parameters=dict(value=(-10, 10)),
                 p=0.7),
            dict(name='OneOf',
                 transforms=[
                     dict(name='MotionBlur', parameters=dict(k=(3, 5))),
                     dict(name='MedianBlur', parameters=dict(k=(3, 5)))
                 ],
                 p=0.2),
            dict(name='Affine',
                 parameters=dict(translate_percent=dict(x=(-0.1, 0.1),
                                                        y=(-0.1, 0.1)),
                                 rotate=(-10, 10),
                                 scale=(0.8, 1.2)),
                 p=0.7),
            dict(name='Resize',
                 parameters=dict(size=dict(height=img_h, width=img_w)),
                 p=1.0),
        ],
    ),
    dict(type='ToTensor', keys=['img', 'lane_line', 'seg']),
]

val_process = [
    dict(type='GenerateLaneLine',
         transforms=[
             dict(name='Resize',
                  parameters=dict(size=dict(height=img_h, width=img_w)),
                  p=1.0),
         ],
         training=False),
    dict(type='ToTensor', keys=['img']),
]

dataset_path = './data/CULane'
dataset_type = 'CULane'
dataset = dict(train=dict(
    type=dataset_type,
    data_root=dataset_path,
    split='train',
    processes=train_process,
),
               val=dict(
                   type=dataset_type,
                   data_root=dataset_path,
                   split='test',
                   processes=val_process,
               ),
               test=dict(
                   type=dataset_type,
                   data_root=dataset_path,
                   split='test',
                   processes=val_process,
               ))

workers = 12
log_interval = 500
# seed = 0
num_classes = 4 + 1
ignore_label = 255
bg_weight = 0.4
lr_update_by_epoch = False

--------------------------------------------------------------------------------
/configs/flamnet/flamnet_resnet34_culane.py:
--------------------------------------------------------------------------------
net = dict(type='Detector', )

norm_cfg = dict(type='SyncBN', requires_grad=True)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    replace_stride_with_dilation=[False, False, False],
    out_conv=False,
)

num_points = 72
max_lanes = 4
sample_y = range(589, 230, -20)

heads = dict(type='FLAMNetHead',
             num_priors=384,
             refine_layers=3,
             fc_hidden_dim=64,
             sample_points=36)

iou_loss_weight = 2.
cls_loss_weight = 2.
xyt_loss_weight = 0.2
seg_loss_weight = 1.0

work_dirs = "work_dirs/flamnet/resnet34_culane"

neck = dict(
    type='PPAM_DSAformer',
    in_channels=[64, 128, 256, 512],
    channels=[32, 64, 128, 160],
    out_channels=[None, 64, 64, 64],
    decode_out_indices=[1, 2, 3],
    depths=4,
    num_heads=8,
    c2t_stride=2,
    drop_path_rate=0.1,
    norm_cfg=norm_cfg,
)

test_parameters = dict(conf_threshold=0.4, nms_thres=50, nms_topk=max_lanes)

epochs = 15
batch_size = 24

optimizer = dict(type='AdamW', lr=0.6e-3)  # 3e-4 for batch size 8
total_iter = (88880 // batch_size) * epochs
scheduler = dict(type='CosineAnnealingLR', T_max=total_iter)

eval_ep = 1
save_ep = 1

img_norm = dict(mean=[103.939, 116.779, 123.68], std=[1., 1., 1.])
ori_img_w = 1640
ori_img_h = 590
img_w = 800
img_h = 320
cut_height = 270

train_process = [
    dict(
        type='GenerateLaneLine',
        transforms=[
            dict(name='Resize',
                 parameters=dict(size=dict(height=img_h, width=img_w)),
                 p=1.0),
            dict(name='HorizontalFlip', parameters=dict(p=1.0), p=0.5),
            dict(name='ChannelShuffle', parameters=dict(p=1.0), p=0.1),
            dict(name='MultiplyAndAddToBrightness',
                 parameters=dict(mul=(0.85, 1.15), add=(-10, 10)),
                 p=0.6),
            dict(name='AddToHueAndSaturation',
                 parameters=dict(value=(-10, 10)),
                 p=0.7),
            dict(name='OneOf',
                 transforms=[
                     dict(name='MotionBlur', parameters=dict(k=(3, 5))),
                     dict(name='MedianBlur', parameters=dict(k=(3, 5)))
                 ],
                 p=0.2),
            dict(name='Affine',
                 parameters=dict(translate_percent=dict(x=(-0.1, 0.1),
                                                        y=(-0.1, 0.1)),
                                 rotate=(-10, 10),
                                 scale=(0.8, 1.2)),
                 p=0.7),
            dict(name='Resize',
                 parameters=dict(size=dict(height=img_h, width=img_w)),
                 p=1.0),
        ],
    ),
    dict(type='ToTensor', keys=['img', 'lane_line', 'seg']),
]

val_process = [
    dict(type='GenerateLaneLine',
         transforms=[
             dict(name='Resize',
                  parameters=dict(size=dict(height=img_h, width=img_w)),
                  p=1.0),
         ],
         training=False),
    dict(type='ToTensor', keys=['img']),
]

dataset_path = './data/CULane'
dataset_type = 'CULane'
dataset = dict(train=dict(
    type=dataset_type,
    data_root=dataset_path,
    split='train',
    processes=train_process,
),
               val=dict(
                   type=dataset_type,
                   data_root=dataset_path,
                   split='test',
                   processes=val_process,
               ),
               test=dict(
                   type=dataset_type,
                   data_root=dataset_path,
                   split='test',
                   processes=val_process,
               ))

workers = 12
log_interval = 500
# seed = 0
num_classes = 4 + 1
ignore_label = 255
bg_weight = 0.4
lr_update_by_epoch = False

--------------------------------------------------------------------------------
/flamnet/__init__.py:
--------------------------------------------------------------------------------
from .ops import *

--------------------------------------------------------------------------------
/flamnet/datasets/__init__.py:
--------------------------------------------------------------------------------
from .registry import build_dataset, build_dataloader

from .tusimple import TuSimple
from .culane import CULane
from .llamas import LLAMAS
from .process import *
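# NOTE: importing these modules at package-import time is what registers the
# dataset classes with the DATASETS registry (via @DATASETS.register_module),
# so configs can refer to them by name, e.g. dataset_type = 'CULane'.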

--------------------------------------------------------------------------------
/flamnet/datasets/base_dataset.py:
--------------------------------------------------------------------------------
import os.path as osp
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import torchvision
import logging
from .registry import DATASETS
from .process import Process
from flamnet.utils.visualization import imshow_lanes
from mmcv.parallel import DataContainer as DC


@DATASETS.register_module
class BaseDataset(Dataset):
    def __init__(self, data_root, split, processes=None, cfg=None):
        self.cfg = cfg
        self.logger = logging.getLogger(__name__)
        self.data_root = data_root
        self.training = 'train' in split
        self.processes = Process(processes, cfg)

    def view(self, predictions, img_metas):
        img_metas = [item for img_meta in img_metas.data for item in img_meta]
        for lanes, img_meta in zip(predictions, img_metas):
            img_name = img_meta['img_name']
            img = cv2.imread(osp.join(self.data_root, img_name))
            out_file = osp.join(self.cfg.work_dir, 'visualization',
                                img_name.replace('/', '_'))
            lanes = [lane.to_array(self.cfg) for lane in lanes]
            imshow_lanes(img, lanes, out_file=out_file)

    def __len__(self):
        return len(self.data_infos)

    def __getitem__(self, idx):
        data_info = self.data_infos[idx]
        img = cv2.imread(data_info['img_path'])
        img = img[self.cfg.cut_height:, :, :]
        sample = data_info.copy()
        sample.update({'img': img})

        if self.training:
            label = cv2.imread(sample['mask_path'], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cfg.cut_height:, :]
            sample.update({'mask': label})

            if self.cfg.cut_height != 0:
                new_lanes = []
                for i in sample['lanes']:
                    lanes = []
                    for p in i:
                        lanes.append((p[0], p[1] - self.cfg.cut_height))
                    new_lanes.append(lanes)
                sample.update({'lanes': new_lanes})

        sample = self.processes(sample)
        meta = {'full_img_path': data_info['img_path'],
                'img_name': data_info['img_name']}
        meta = DC(meta, cpu_only=True)
        sample.update({'meta': meta})

        return sample

--------------------------------------------------------------------------------
/flamnet/datasets/culane.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
import numpy as np
from .base_dataset import BaseDataset
from .registry import DATASETS
import flamnet.utils.culane_metric as culane_metric
import cv2
from tqdm import tqdm
import logging
import pickle as pkl

LIST_FILE = {
    'train': 'list/train_gt.txt',
    'val': 'list/val.txt',
    'test': 'list/test.txt',
}

CATEGORYS = {
    'normal': 'list/test_split/test0_normal.txt',
    'crowd': 'list/test_split/test1_crowd.txt',
    'hlight': 'list/test_split/test2_hlight.txt',
    'shadow': 'list/test_split/test3_shadow.txt',
    'noline': 'list/test_split/test4_noline.txt',
    'arrow': 'list/test_split/test5_arrow.txt',
    'curve': 'list/test_split/test6_curve.txt',
    'cross': 'list/test_split/test7_cross.txt',
    'night': 'list/test_split/test8_night.txt',
}

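# Annotations are parsed once and cached to cache/culane_<split>.pkl; delete
# that file to force a re-parse after changing the list files. At evaluation
# time, TP/FP/FN are accumulated over the nine scenario splits in CATEGORYS
# and a single global F1 is reported (see evaluate() below).
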
@DATASETS.register_module
class CULane(BaseDataset):
    def __init__(self, data_root, split, processes=None, cfg=None):
        super().__init__(data_root, split, processes=processes, cfg=cfg)
        self.list_path = osp.join(data_root, LIST_FILE[split])
        self.split = split
        self.load_annotations()

    def load_annotations(self):
        self.logger.info('Loading CULane annotations...')
        # Waiting for the dataset to load is tedious, let's cache it
        os.makedirs('cache', exist_ok=True)
        cache_path = 'cache/culane_{}.pkl'.format(self.split)
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as cache_file:
                self.data_infos = pkl.load(cache_file)
                self.max_lanes = max(
                    len(anno['lanes']) for anno in self.data_infos)
                return

        self.data_infos = []
        with open(self.list_path) as list_file:
            for line in list_file:
                infos = self.load_annotation(line.split())
                self.data_infos.append(infos)

        # cache data infos to file
        with open(cache_path, 'wb') as cache_file:
            pkl.dump(self.data_infos, cache_file)

    def load_annotation(self, line):
        infos = {}
        img_line = line[0]
        img_line = img_line[1 if img_line[0] == '/' else 0::]
        img_path = os.path.join(self.data_root, img_line)
        infos['img_name'] = img_line
        infos['img_path'] = img_path
        if len(line) > 1:
            mask_line = line[1]
            mask_line = mask_line[1 if mask_line[0] == '/' else 0::]
            mask_path = os.path.join(self.data_root, mask_line)
            infos['mask_path'] = mask_path

        if len(line) > 2:
            exist_list = [int(l) for l in line[2:]]
            infos['lane_exist'] = np.array(exist_list)

        anno_path = img_path[:-3] + 'lines.txt'  # replace the 'jpg' suffix with 'lines.txt'
        with open(anno_path, 'r') as anno_file:
            data = [
                list(map(float, line.split()))
                for line in anno_file.readlines()
            ]
        lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)
                  if lane[i] >= 0 and lane[i + 1] >= 0] for lane in data]
        lanes = [list(set(lane)) for lane in lanes]  # remove duplicated points
        lanes = [lane for lane in lanes
                 if len(lane) > 2]  # remove lanes with fewer than 3 points

        lanes = [sorted(lane, key=lambda x: x[1])
                 for lane in lanes]  # sort by y
        infos['lanes'] = lanes

        return infos

    def get_prediction_string(self, pred):
        ys = np.arange(270, 590, 8) / self.cfg.ori_img_h
        out = []
        for lane in pred:
            xs = lane(ys)
            valid_mask = (xs >= 0) & (xs < 1)
            xs = xs * self.cfg.ori_img_w
            lane_xs = xs[valid_mask]
            lane_ys = ys[valid_mask] * self.cfg.ori_img_h
            lane_xs, lane_ys = lane_xs[::-1], lane_ys[::-1]
            lane_str = ' '.join([
                '{:.5f} {:.5f}'.format(x, y) for x, y in zip(lane_xs, lane_ys)
            ])
            if lane_str != '':
                out.append(lane_str)

        return '\n'.join(out)

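    # Prediction files mirror the CULane annotation format: one
    # '<x1> <y1> <x2> <y2> ...' line per detected lane, with points sampled
    # every 8 px between y=270 and y=582 in the original 1640x590 image and
    # ordered bottom-to-top; the official metric consumes these .lines.txt files.
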
    def evaluate(self, predictions, output_basedir):
        logger = logging.getLogger(__name__)
        print('Generating prediction output...')
        for idx, pred in enumerate(predictions):
            output_dir = os.path.join(
                output_basedir,
                os.path.dirname(self.data_infos[idx]['img_name']))
            output_filename = os.path.basename(
                self.data_infos[idx]['img_name'])[:-3] + 'lines.txt'
            os.makedirs(output_dir, exist_ok=True)
            output = self.get_prediction_string(pred)

            with open(os.path.join(output_dir, output_filename),
                      'w') as out_file:
                out_file.write(output)

        total_tp = 0
        total_fp = 0
        total_fn = 0
        for cate, cate_file in CATEGORYS.items():
            result = culane_metric.eval_predictions(output_basedir,
                                                    self.data_root,
                                                    os.path.join(self.data_root, cate_file),
                                                    iou_thresholds=[0.5],
                                                    official=True)

            total_tp += result[0.5]['TP']
            total_fp += result[0.5]['FP']
            total_fn += result[0.5]['FN']

        total_p = total_tp * 1.0 / (total_tp + total_fp + 1e-9)
        total_r = total_tp * 1.0 / (total_tp + total_fn + 1e-9)
        total_f = 2 * total_p * total_r / (total_p + total_r + 1e-9)

        logger.info('precision: {}, recall: {}, f1: {}'.format(total_p, total_r, total_f))

        return total_f

--------------------------------------------------------------------------------
/flamnet/datasets/curvelanes.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
import numpy as np
from .base_dataset import BaseDataset
from .registry import DATASETS
import flamnet.utils.culane_metric as culane_metric
import cv2
from tqdm import tqdm
import logging
import pickle as pkl

LIST_FILE = {
    'train': 'list/train_gt.txt',
    'val': 'list/val.txt',
    'test': 'list/test.txt',
}

CATEGORYS = {
    'normal': 'list/test_split/test0_normal.txt',
    'crowd': 'list/test_split/test1_crowd.txt',
    'hlight': 'list/test_split/test2_hlight.txt',
    'shadow': 'list/test_split/test3_shadow.txt',
    'noline': 'list/test_split/test4_noline.txt',
    'arrow': 'list/test_split/test5_arrow.txt',
    'curve': 'list/test_split/test6_curve.txt',
    'cross': 'list/test_split/test7_cross.txt',
    'night': 'list/test_split/test8_night.txt',
}


# NOTE: this module is currently a verbatim copy of culane.py. The class is
# registered under its own name so that importing this module does not clash
# with datasets.culane.CULane in the DATASETS registry.
@DATASETS.register_module
class CurveLanes(BaseDataset):
    def __init__(self, data_root, split, processes=None, cfg=None):
        super().__init__(data_root, split, processes=processes, cfg=cfg)
        self.list_path = osp.join(data_root, LIST_FILE[split])
        self.split = split
        self.load_annotations()

    def load_annotations(self):
        self.logger.info('Loading CurveLanes annotations...')
        # Waiting for the dataset to load is tedious, let's cache it
        os.makedirs('cache', exist_ok=True)
        cache_path = 'cache/curvelanes_{}.pkl'.format(self.split)
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as cache_file:
                self.data_infos = pkl.load(cache_file)
                self.max_lanes = max(
                    len(anno['lanes']) for anno in self.data_infos)
                return

        self.data_infos = []
        with open(self.list_path) as list_file:
            for line in list_file:
                infos = self.load_annotation(line.split())
                self.data_infos.append(infos)

        # cache data infos to file
        with open(cache_path, 'wb') as cache_file:
            pkl.dump(self.data_infos, cache_file)

    def load_annotation(self, line):
        infos = {}
        img_line = line[0]
        img_line = img_line[1 if img_line[0] == '/' else 0::]
        img_path = os.path.join(self.data_root, img_line)
        infos['img_name'] = img_line
        infos['img_path'] = img_path
        if len(line) > 1:
            mask_line = line[1]
            mask_line = mask_line[1 if mask_line[0] == '/' else 0::]
            mask_path = os.path.join(self.data_root, mask_line)
            infos['mask_path'] = mask_path

        if len(line) > 2:
            exist_list = [int(l) for l in line[2:]]
            infos['lane_exist'] = np.array(exist_list)

        anno_path = img_path[:-3] + 'lines.txt'  # replace the 'jpg' suffix with 'lines.txt'
        with open(anno_path, 'r') as anno_file:
            data = [
                list(map(float, line.split()))
                for line in anno_file.readlines()
            ]
        lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)
                  if lane[i] >= 0 and lane[i + 1] >= 0] for lane in data]
        lanes = [list(set(lane)) for lane in lanes]  # remove duplicated points
        lanes = [lane for lane in lanes
                 if len(lane) > 2]  # remove lanes with fewer than 3 points

        lanes = [sorted(lane, key=lambda x: x[1])
                 for lane in lanes]  # sort by y
        infos['lanes'] = lanes

        return infos

    def get_prediction_string(self, pred):
        ys = np.arange(270, 590, 8) / self.cfg.ori_img_h
        out = []
        for lane in pred:
            xs = lane(ys)
            valid_mask = (xs >= 0) & (xs < 1)
            xs = xs * self.cfg.ori_img_w
            lane_xs = xs[valid_mask]
            lane_ys = ys[valid_mask] * self.cfg.ori_img_h
            lane_xs, lane_ys = lane_xs[::-1], lane_ys[::-1]
            lane_str = ' '.join([
                '{:.5f} {:.5f}'.format(x, y) for x, y in zip(lane_xs, lane_ys)
            ])
            if lane_str != '':
                out.append(lane_str)

        return '\n'.join(out)

    def evaluate(self, predictions, output_basedir):
        logger = logging.getLogger(__name__)
        print('Generating prediction output...')
        for idx, pred in enumerate(predictions):
            output_dir = os.path.join(
                output_basedir,
                os.path.dirname(self.data_infos[idx]['img_name']))
            output_filename = os.path.basename(
                self.data_infos[idx]['img_name'])[:-3] + 'lines.txt'
            os.makedirs(output_dir, exist_ok=True)
            output = self.get_prediction_string(pred)

            with open(os.path.join(output_dir, output_filename),
                      'w') as out_file:
                out_file.write(output)

        total_tp = 0
        total_fp = 0
        total_fn = 0
        for cate, cate_file in CATEGORYS.items():
            result = culane_metric.eval_predictions(output_basedir,
                                                    self.data_root,
                                                    os.path.join(self.data_root, cate_file),
                                                    iou_thresholds=[0.5],
                                                    official=True)

            total_tp += result[0.5]['TP']
            total_fp += result[0.5]['FP']
            total_fn += result[0.5]['FN']

        total_p = total_tp * 1.0 / (total_tp + total_fp + 1e-9)
        total_r = total_tp * 1.0 / (total_tp + total_fn + 1e-9)
        total_f = 2 * total_p * total_r / (total_p + total_r + 1e-9)

        logger.info('precision: {}, recall: {}, f1: {}'.format(total_p, total_r, total_f))

        return total_f

--------------------------------------------------------------------------------
/flamnet/datasets/llamas.py:
--------------------------------------------------------------------------------
import os
import pickle as pkl
import cv2

from .registry import DATASETS
import numpy as np
from tqdm import tqdm
from .base_dataset import BaseDataset

TRAIN_LABELS_DIR = 'labels/train'
TEST_LABELS_DIR = 'labels/valid'
TEST_IMGS_DIR = 'color_images/test'
SPLIT_DIRECTORIES = {'train': 'labels/train', 'val': 'labels/valid'}
from flamnet.utils.llamas_utils import get_horizontal_values_for_four_lanes
import flamnet.utils.llamas_metric as llamas_metric

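# LLAMAS notes: lane annotations are recovered from the official JSON label
# files, and segmentation masks are rendered once during the first load and
# written next to the labels. Labels for the test split are not public, so
# for 'test' only prediction files are produced and evaluate() returns None.
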
@DATASETS.register_module
class LLAMAS(BaseDataset):
    def __init__(self, data_root, split='train', processes=None, cfg=None):
        self.split = split
        self.data_root = data_root
        super().__init__(data_root, split, processes, cfg)
        if split != 'test' and split not in SPLIT_DIRECTORIES.keys():
            raise Exception('Split `{}` does not exist.'.format(split))
        if split != 'test':
            self.labels_dir = os.path.join(self.data_root,
                                           SPLIT_DIRECTORIES[split])

        self.data_infos = []
        self.load_annotations()

    def get_img_heigth(self, _):
        return self.cfg.ori_img_h

    def get_img_width(self, _):
        return self.cfg.ori_img_w

    def get_metrics(self, lanes, _):
        # Placeholders
        return [0] * len(lanes), [0] * len(lanes), [1] * len(lanes), [
            1
        ] * len(lanes)

    def get_img_path(self, json_path):
        # /foo/bar/test/folder/image_label.ext --> test/folder/image_label.ext
        base_name = '/'.join(json_path.split('/')[-3:])
        image_path = os.path.join(
            'color_images', base_name.replace('.json', '_color_rect.png'))
        return image_path

    def get_img_name(self, json_path):
        base_name = (json_path.split('/')[-1]).replace('.json',
                                                       '_color_rect.png')
        return base_name

    def get_json_paths(self):
        json_paths = []
        for root, _, files in os.walk(self.labels_dir):
            for file in files:
                if file.endswith(".json"):
                    json_paths.append(os.path.join(root, file))
        return json_paths

    def load_annotations(self):
        # the labels are not public for the test set yet
        if self.split == 'test':
            imgs_dir = os.path.join(self.data_root, TEST_IMGS_DIR)
            self.data_infos = [{
                'img_path':
                os.path.join(root, file),
                'img_name':
                os.path.join(TEST_IMGS_DIR,
                             root.split('/')[-1], file),
                'lanes': [],
                'relative_path':
                os.path.join(root.split('/')[-1], file)
            } for root, _, files in os.walk(imgs_dir) for file in files
                               if file.endswith('.png')]
            self.data_infos = sorted(self.data_infos,
                                     key=lambda x: x['img_path'])
            return

        # Waiting for the dataset to load is tedious, let's cache it
        os.makedirs('cache', exist_ok=True)
        cache_path = 'cache/llamas_{}.pkl'.format(self.split)
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as cache_file:
                self.data_infos = pkl.load(cache_file)
                self.max_lanes = max(
                    len(anno['lanes']) for anno in self.data_infos)
                return

        self.max_lanes = 0
        print("Searching annotation files...")
        json_paths = self.get_json_paths()
        print('{} annotations found.'.format(len(json_paths)))

        for json_path in tqdm(json_paths):
            lanes = get_horizontal_values_for_four_lanes(json_path)
            lanes = [[(x, y) for x, y in zip(lane, range(self.cfg.ori_img_h))
                      if x >= 0] for lane in lanes]
            lanes = [lane for lane in lanes if len(lane) > 0]
            lanes = [list(set(lane))
                     for lane in lanes]  # remove duplicated points
            lanes = [lane for lane in lanes
                     if len(lane) > 2]  # remove lanes with fewer than 3 points

            lanes = [sorted(lane, key=lambda x: x[1])
                     for lane in lanes]  # sort by y
            lanes.sort(key=lambda lane: lane[0][0])
            mask_path = json_path.replace('.json', '.png')

            # generate seg labels
            seg = np.zeros((717, 1276, 3))
            for i, lane in enumerate(lanes):
                for j in range(0, len(lane) - 1):
                    cv2.line(seg, (round(lane[j][0]), lane[j][1]),
                             (round(lane[j + 1][0]), lane[j + 1][1]),
                             (i + 1, i + 1, i + 1),
                             thickness=15)

            cv2.imwrite(mask_path, seg)

            relative_path = self.get_img_path(json_path)
            img_path = os.path.join(self.data_root, relative_path)
            self.max_lanes = max(self.max_lanes, len(lanes))
            self.data_infos.append({
                'img_path': img_path,
                'img_name': relative_path,
                'mask_path': mask_path,
                'lanes': lanes,
                'relative_path': relative_path
            })

        with open(cache_path, 'wb') as cache_file:
            pkl.dump(self.data_infos, cache_file)

    def assign_class_to_lanes(self, lanes):
        return {
            label: value
            for label, value in zip(['l0', 'l1', 'r0', 'r1'], lanes)
        }

    def get_prediction_string(self, pred):
        ys = np.arange(300, 717, 1) / (self.cfg.ori_img_h - 1)
        out = []
        for lane in pred:
            xs = lane(ys)
            valid_mask = (xs >= 0) & (xs < 1)
            xs = xs * (self.cfg.ori_img_w - 1)
            lane_xs = xs[valid_mask]
            lane_ys = ys[valid_mask] * (self.cfg.ori_img_h - 1)
            lane_xs, lane_ys = lane_xs[::-1], lane_ys[::-1]
            lane_str = ' '.join([
                '{:.5f} {:.5f}'.format(x, y) for x, y in zip(lane_xs, lane_ys)
            ])
            if lane_str != '':
                out.append(lane_str)

        return '\n'.join(out)

    def evaluate(self, predictions, output_basedir):
        print('Generating prediction output...')
        for idx, pred in enumerate(predictions):
            relative_path = self.data_infos[idx]['relative_path']
            output_filename = '/'.join(relative_path.split('/')[-2:]).replace(
                '_color_rect.png', '.lines.txt')
            output_filepath = os.path.join(output_basedir, output_filename)
            os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
            output = self.get_prediction_string(pred)
            with open(output_filepath, 'w') as out_file:
                out_file.write(output)
        if self.split == 'test':
            return None
        result = llamas_metric.eval_predictions(output_basedir,
                                                self.labels_dir,
                                                iou_thresholds=[0.5],
                                                unofficial=False)
        return result[0.5]['F1']

--------------------------------------------------------------------------------
/flamnet/datasets/process/__init__.py:
--------------------------------------------------------------------------------
from .transforms import (RandomLROffsetLABEL, RandomUDoffsetLABEL, Resize,
                         RandomCrop, CenterCrop, RandomRotation, RandomBlur,
                         RandomHorizontalFlip, Normalize, ToTensor)

from .generate_lane_line import GenerateLaneLine
from .process import Process

__all__ = [
    'Process',
    'RandomLROffsetLABEL',
    'RandomUDoffsetLABEL',
    'Resize',
    'RandomCrop',
    'CenterCrop',
    'RandomRotation',
    'RandomBlur',
    'RandomHorizontalFlip',
    'Normalize',
    'ToTensor',
    'GenerateLaneLine',
]

--------------------------------------------------------------------------------
/flamnet/datasets/process/generate_lane_line.py:
--------------------------------------------------------------------------------
import math
import logging  # needed by GenerateLaneLine to report repeated augmentation failures
import numpy as np
import cv2
import imgaug.augmenters as iaa
from imgaug.augmentables.lines import LineString, LineStringsOnImage
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from scipy.interpolate import InterpolatedUnivariateSpline
from flamnet.datasets.process.transforms import CLRTransforms

from ..registry import PROCESS


@PROCESS.register_module
class GenerateLaneLine(object):
    def __init__(self, transforms=None, cfg=None, training=True):
        self.logger = logging.getLogger(__name__)
        self.transforms = transforms
        self.img_w, self.img_h = cfg.img_w, cfg.img_h
        self.num_points = cfg.num_points
        self.n_offsets = cfg.num_points
        self.n_strips = cfg.num_points - 1
        self.strip_size = self.img_h / self.n_strips
        self.max_lanes = cfg.max_lanes
        self.offsets_ys = np.arange(self.img_h, -1, -self.strip_size)
        self.training = training

        if transforms is None:
            transforms = CLRTransforms(self.img_h, self.img_w)

        if transforms is not None:
            img_transforms = []
            for aug in transforms:
                p = aug['p']
                if aug['name'] != 'OneOf':
                    img_transforms.append(
                        iaa.Sometimes(p=p,
                                      then_list=getattr(
                                          iaa,
                                          aug['name'])(**aug['parameters'])))
                else:
                    img_transforms.append(
                        iaa.Sometimes(
                            p=p,
                            then_list=iaa.OneOf([
                                getattr(iaa,
                                        aug_['name'])(**aug_['parameters'])
                                for aug_ in aug['transforms']
                            ])))
        else:
            img_transforms = []
        self.transform = iaa.Sequential(img_transforms)

    def lane_to_linestrings(self, lanes):
        lines = []
        for lane in lanes:
            lines.append(LineString(lane))

        return lines

    def sample_lane(self, points, sample_ys):
        # this function expects the points to be sorted
        points = np.array(points)
        if not np.all(points[1:, 1] < points[:-1, 1]):
            raise Exception('Annotation points have to be sorted')
        x, y = points[:, 0], points[:, 1]

        # interpolate points inside domain
        assert len(points) > 1
        interp = InterpolatedUnivariateSpline(y[::-1],
                                              x[::-1],
                                              k=min(3,
                                                    len(points) - 1))
        domain_min_y = y.min()
        domain_max_y = y.max()
        sample_ys_inside_domain = sample_ys[(sample_ys >= domain_min_y)
                                            & (sample_ys <= domain_max_y)]
        assert len(sample_ys_inside_domain) > 0
        interp_xs = interp(sample_ys_inside_domain)

        # extrapolate lane to the bottom of the image with a straight line
        # using the 2 points closest to the bottom
        two_closest_points = points[:2]
        extrap = np.polyfit(two_closest_points[:, 1],
                            two_closest_points[:, 0],
                            deg=1)
        extrap_ys = sample_ys[sample_ys > domain_max_y]
        extrap_xs = np.polyval(extrap, extrap_ys)
        all_xs = np.hstack((extrap_xs, interp_xs))

        # separate between inside and outside points
        inside_mask = (all_xs >= 0) & (all_xs < self.img_w)
        xs_inside_image = all_xs[inside_mask]
        xs_outside_image = all_xs[~inside_mask]

        return xs_outside_image, xs_inside_image

    def filter_lane(self, lane):
        assert lane[-1][1] <= lane[0][1]
        filtered_lane = []
        used = set()
        for p in lane:
            if p[1] not in used:
                filtered_lane.append(p)
                used.add(p[1])

        return filtered_lane

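    # Each lane prior in the label tensor built by transform_annotation() below
    # is laid out as:
    #   [neg_score, pos_score, start_y (normalized by n_strips), start_x,
    #    theta, length (number of valid points), x_0, ..., x_{S-1}]
    # where the S = num_points x-coordinates are sampled bottom-to-top at
    # self.offsets_ys; rows stay marked invalid until a lane is assigned.
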
    def transform_annotation(self, anno, img_wh=None):
        img_w, img_h = self.img_w, self.img_h

        old_lanes = anno['lanes']

        # removing lanes with less than 2 points
        old_lanes = filter(lambda x: len(x) > 1, old_lanes)
        # sort lane points by Y (bottom to top of the image)
        old_lanes = [sorted(lane, key=lambda x: -x[1]) for lane in old_lanes]
        # remove points with same Y (keep first occurrence)
        old_lanes = [self.filter_lane(lane) for lane in old_lanes]
        # normalize the annotation coordinates
        old_lanes = [[[
            x * self.img_w / float(img_w), y * self.img_h / float(img_h)
        ] for x, y in lane] for lane in old_lanes]
        # create transformed annotations
        lanes = np.ones(
            (self.max_lanes, 2 + 1 + 1 + 2 + self.n_offsets), dtype=np.float32
        ) * -1e5  # 2 scores, 1 start_y, 1 start_x, 1 theta, 1 length, S coordinates
        lanes_endpoints = np.ones((self.max_lanes, 2))
        # lanes are invalid by default
        lanes[:, 0] = 1
        lanes[:, 1] = 0
        for lane_idx, lane in enumerate(old_lanes):
            if lane_idx >= self.max_lanes:
                break

            try:
                xs_outside_image, xs_inside_image = self.sample_lane(
                    lane, self.offsets_ys)
            except AssertionError:
                continue
            if len(xs_inside_image) <= 1:
                continue
            all_xs = np.hstack((xs_outside_image, xs_inside_image))
            lanes[lane_idx, 0] = 0
            lanes[lane_idx, 1] = 1
            lanes[lane_idx, 2] = len(xs_outside_image) / self.n_strips
            lanes[lane_idx, 3] = xs_inside_image[0]

            thetas = []
            for i in range(1, len(xs_inside_image)):
                theta = math.atan(
                    i * self.strip_size /
                    (xs_inside_image[i] - xs_inside_image[0] + 1e-5)) / math.pi
                theta = theta if theta > 0 else 1 - abs(theta)
                thetas.append(theta)

            theta_far = sum(thetas) / len(thetas)

            lanes[lane_idx, 4] = theta_far
            lanes[lane_idx, 5] = len(xs_inside_image)
            lanes[lane_idx, 6:6 + len(all_xs)] = all_xs
            lanes_endpoints[lane_idx, 0] = (len(all_xs) - 1) / self.n_strips
            lanes_endpoints[lane_idx, 1] = xs_inside_image[-1]

        new_anno = {
            'label': lanes,
            'old_anno': anno,
            'lane_endpoints': lanes_endpoints
        }
        return new_anno

    def linestrings_to_lanes(self, lines):
        lanes = []
        for line in lines:
            lanes.append(line.coords)

        return lanes

    def __call__(self, sample):
        img_org = sample['img']
        line_strings_org = self.lane_to_linestrings(sample['lanes'])
        line_strings_org = LineStringsOnImage(line_strings_org,
                                              shape=img_org.shape)

        for i in range(30):
            if self.training:
                mask_org = SegmentationMapsOnImage(sample['mask'],
                                                   shape=img_org.shape)
                img, line_strings, seg = self.transform(
                    image=img_org.copy().astype(np.uint8),
                    line_strings=line_strings_org,
                    segmentation_maps=mask_org)
            else:
                img, line_strings = self.transform(
                    image=img_org.copy().astype(np.uint8),
                    line_strings=line_strings_org)
            line_strings.clip_out_of_image_()
            new_anno = {'lanes': self.linestrings_to_lanes(line_strings)}
            try:
                annos = self.transform_annotation(new_anno,
                                                  img_wh=(self.img_w,
                                                          self.img_h))
                label = annos['label']
                lane_endpoints = annos['lane_endpoints']
                break
            except Exception:
                if (i + 1) == 30:
                    self.logger.critical(
                        'Transform annotation failed 30 times :(')
                    exit()

        sample['img'] = img.astype(np.float32) / 255.
        sample['lane_line'] = label
        sample['lanes_endpoints'] = lane_endpoints
        sample['gt_points'] = new_anno['lanes']
        sample['seg'] = seg.get_arr() if self.training else np.zeros(
            img_org.shape)

        return sample

--------------------------------------------------------------------------------
/flamnet/datasets/process/process.py:
--------------------------------------------------------------------------------
import collections

from flamnet.utils import build_from_cfg

from ..registry import PROCESS


class Process(object):
    """Compose multiple processes sequentially.
    Args:
        process (Sequence[dict | callable]): Sequence of process objects or
            config dicts to be composed.
    """
    def __init__(self, processes, cfg):
        assert isinstance(processes, collections.abc.Sequence)
        self.processes = []
        for process in processes:
            if isinstance(process, dict):
                process = build_from_cfg(process,
                                         PROCESS,
                                         default_args=dict(cfg=cfg))
                self.processes.append(process)
            elif callable(process):
                self.processes.append(process)
            else:
                raise TypeError('process must be callable or a dict')

    def __call__(self, data):
        """Call function to apply processes sequentially.
        Args:
            data (dict): A result dict containing the data to process.
        Returns:
            dict: Processed data.
        """

        for t in self.processes:
            data = t(data)
            if data is None:
                return None
        return data

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.processes:
            format_string += '\n'
            format_string += f'    {t}'
        format_string += '\n)'
        return format_string
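
# Usage sketch (how BaseDataset wires this up):
#   processes = Process(cfg.train_process, cfg)   # list of PROCESS config dicts
#   sample = processes(dict(img=img, lanes=lanes, mask=mask))
# Each dict names a registered class (e.g. GenerateLaneLine, ToTensor) and is
# instantiated via build_from_cfg with the global cfg passed as a default arg.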

--------------------------------------------------------------------------------
/flamnet/datasets/process/transforms.py:
--------------------------------------------------------------------------------
import random
import cv2
import numpy as np
import torch
import numbers
import collections
from PIL import Image

from ..registry import PROCESS


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.
    """

    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(f'type {type(data)} cannot be converted to tensor.')


@PROCESS.register_module
class ToTensor(object):
    """Convert some results to :obj:`torch.Tensor` by given keys.

    Args:
        keys (Sequence[str]): Keys that need to be converted to Tensor.
    """
    def __init__(self, keys=['img', 'mask'], cfg=None):
        self.keys = keys

    def __call__(self, sample):
        data = {}
        if len(sample['img'].shape) < 3:
            sample['img'] = np.expand_dims(sample['img'], -1)
        for key in self.keys:
            if key == 'img_metas' or key == 'gt_masks' or key == 'lane_line':
                data[key] = sample[key]
                continue
            data[key] = to_tensor(sample[key])
        data['img'] = data['img'].permute(2, 0, 1)
        return data

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'


@PROCESS.register_module
class RandomLROffsetLABEL(object):
    def __init__(self, max_offset, cfg=None):
        self.max_offset = max_offset

    def __call__(self, sample):
        img = sample['img']
        label = sample['mask']
        offset = np.random.randint(-self.max_offset, self.max_offset)
        h, w = img.shape[:2]

        img = np.array(img)
        if offset > 0:
            img[:, offset:, :] = img[:, 0:w - offset, :]
            img[:, :offset, :] = 0
        if offset < 0:
            real_offset = -offset
            img[:, 0:w - real_offset, :] = img[:, real_offset:, :]
            img[:, w - real_offset:, :] = 0

        label = np.array(label)
        if offset > 0:
            label[:, offset:] = label[:, 0:w - offset]
            label[:, :offset] = 0
        if offset < 0:
            offset = -offset
            label[:, 0:w - offset] = label[:, offset:]
            label[:, w - offset:] = 0
        sample['img'] = img
        sample['mask'] = label

        return sample


@PROCESS.register_module
class RandomUDoffsetLABEL(object):
    def __init__(self, max_offset, cfg=None):
        self.max_offset = max_offset

    def __call__(self, sample):
        img = sample['img']
        label = sample['mask']
        offset = np.random.randint(-self.max_offset, self.max_offset)
        h, w = img.shape[:2]

        img = np.array(img)
        if offset > 0:
            img[offset:, :, :] = img[0:h - offset, :, :]
            img[:offset, :, :] = 0
        if offset < 0:
            real_offset = -offset
            img[0:h - real_offset, :, :] = img[real_offset:, :, :]
            img[h - real_offset:, :, :] = 0

        label = np.array(label)
        if offset > 0:
            label[offset:, :] = label[0:h - offset, :]
            label[:offset, :] = 0
        if offset < 0:
            offset = -offset
            label[0:h - offset, :] = label[offset:, :]
            label[h - offset:, :] = 0
        sample['img'] = img
        sample['mask'] = label
        return sample


@PROCESS.register_module
class Resize(object):
    def __init__(self, size, cfg=None):
        assert (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size

    def __call__(self, sample):
        sample['img'] = cv2.resize(sample['img'],
                                   self.size,
                                   interpolation=cv2.INTER_CUBIC)
        if 'mask' in sample:
            sample['mask'] = cv2.resize(sample['mask'],
                                        self.size,
                                        interpolation=cv2.INTER_NEAREST)
        return sample


@PROCESS.register_module
class RandomCrop(object):
    def __init__(self, size, cfg=None):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = random.randint(0, max(0, h - th))
        w1 = random.randint(0, max(0, w - tw))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


@PROCESS.register_module
class CenterCrop(object):
    def __init__(self, size, cfg=None):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        h, w = img_group[0].shape[0:2]
        th, tw = self.size

        out_images = list()
        h1 = max(0, int((h - th) / 2))
        w1 = max(0, int((w - tw) / 2))
        h2 = min(h1 + th, h)
        w2 = min(w1 + tw, w)

        for img in img_group:
            assert (img.shape[0] == h and img.shape[1] == w)
            out_images.append(img[h1:h2, w1:w2, ...])
        return out_images


@PROCESS.register_module
class RandomRotation(object):
    def __init__(self,
                 degree=(-10, 10),
                 interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST),
                 padding=None,
                 cfg=None):
        self.degree = degree
        self.interpolation = interpolation
        self.padding = padding
        if self.padding is None:
            self.padding = [0, 0]

    def _rotate_img(self, sample, map_matrix):
        h, w = sample['img'].shape[0:2]
        sample['img'] = cv2.warpAffine(sample['img'],
                                       map_matrix, (w, h),
                                       flags=cv2.INTER_LINEAR,
                                       borderMode=cv2.BORDER_CONSTANT,
                                       borderValue=self.padding)

    def _rotate_mask(self, sample, map_matrix):
        if 'mask' not in sample:
            return
        h, w = sample['mask'].shape[0:2]
        sample['mask'] = cv2.warpAffine(sample['mask'],
                                        map_matrix, (w, h),
                                        flags=cv2.INTER_NEAREST,
                                        borderMode=cv2.BORDER_CONSTANT,
                                        borderValue=self.padding)

    def __call__(self, sample):
        v = random.random()
        if v < 0.5:
            degree = random.uniform(self.degree[0], self.degree[1])
            h, w = sample['img'].shape[0:2]
            center = (w / 2, h / 2)
            map_matrix = cv2.getRotationMatrix2D(center, degree, 1.0)
            self._rotate_img(sample, map_matrix)
            self._rotate_mask(sample, map_matrix)
        return sample


@PROCESS.register_module
class RandomBlur(object):
    def __init__(self, applied, cfg=None):
        self.applied = applied

    def __call__(self, img_group):
        assert (len(self.applied) == len(img_group))
        v = random.random()
        if v < 0.5:
            out_images = []
            for img, a in zip(img_group, self.applied):
                if a:
                    img = cv2.GaussianBlur(img, (5, 5),
                                           random.uniform(1e-6, 0.6))
                out_images.append(img)
                if len(img.shape) > len(out_images[-1].shape):
                    out_images[-1] = out_images[-1][
                        ..., np.newaxis]  # single channel image
            return out_images
        else:
            return img_group


@PROCESS.register_module
class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given numpy Image with a probability of 0.5
    """
    def __init__(self, cfg=None):
        pass

    def __call__(self, sample):
        v = random.random()
        if v < 0.5:
            sample['img'] = np.fliplr(sample['img'])
            if 'mask' in sample:
                sample['mask'] = np.fliplr(sample['mask'])
        return sample

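# Note: the img_norm values in the provided configs (mean=[103.939, 116.779,
# 123.68], std=[1., 1., 1.]) appear to be Caffe-style BGR channel means,
# matching images loaded with cv2 in BGR order; also note that the CULane
# pipelines in this repo do not include Normalize by default.
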
@PROCESS.register_module
class Normalize(object):
    def __init__(self, img_norm, cfg=None):
        self.mean = np.array(img_norm['mean'], dtype=np.float32)
        self.std = np.array(img_norm['std'], dtype=np.float32)

    def __call__(self, sample):
        m = self.mean
        s = self.std
        img = sample['img']
        if len(m) == 1:
            img = img - np.array(m)  # single channel image
            img = img / np.array(s)
        else:
            img = img - np.array(m)[np.newaxis, np.newaxis, ...]
            img = img / np.array(s)[np.newaxis, np.newaxis, ...]
        sample['img'] = img

        return sample


def CLRTransforms(img_h, img_w):
    return [
        dict(name='Resize',
             parameters=dict(size=dict(height=img_h, width=img_w)),
             p=1.0),
        dict(name='HorizontalFlip', parameters=dict(p=1.0), p=0.5),
        dict(name='Affine',
             parameters=dict(translate_percent=dict(x=(-0.1, 0.1),
                                                    y=(-0.1, 0.1)),
                             rotate=(-10, 10),
                             scale=(0.8, 1.2)),
             p=0.7),
        dict(name='Resize',
             parameters=dict(size=dict(height=img_h, width=img_w)),
             p=1.0),
    ]

--------------------------------------------------------------------------------
/flamnet/datasets/registry.py:
--------------------------------------------------------------------------------
from flamnet.utils import Registry, build_from_cfg

import torch
import torch.nn as nn
from functools import partial
import numpy as np
import random
from mmcv.parallel import collate

DATASETS = Registry('datasets')
PROCESS = Registry('process')


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_dataset(split_cfg, cfg):
    return build(split_cfg, DATASETS, default_args=dict(cfg=cfg))


def worker_init_fn(worker_id, seed):
    worker_seed = worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def build_dataloader(split_cfg, cfg, is_train=True):
    if is_train:
        shuffle = True
    else:
        shuffle = False

    dataset = build_dataset(split_cfg, cfg)

    init_fn = partial(worker_init_fn, seed=cfg.seed)

    samples_per_gpu = cfg.batch_size // cfg.gpus
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        num_workers=cfg.workers,
        pin_memory=False,
        drop_last=False,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        worker_init_fn=init_fn)

    return data_loader

--------------------------------------------------------------------------------
/flamnet/datasets/tusimple.py:
--------------------------------------------------------------------------------
import os.path as osp
import numpy as np
import cv2
import os
import json
import torchvision
from .base_dataset import BaseDataset
from flamnet.utils.tusimple_metric import LaneEval
from .registry import DATASETS
import logging
import random

SPLIT_FILES = {
    'trainval':
    ['label_data_0313.json', 'label_data_0601.json', 'label_data_0531.json'],
    'train': ['label_data_0313.json', 'label_data_0601.json'],
    'val': ['label_data_0531.json'],
    'test': ['test_label.json'],
}

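# TuSimple label files are line-delimited JSON: each record stores per-lane
# x-coordinates ('lanes') sampled at fixed y-values ('h_samples'), with -2
# marking points where the lane is absent; predictions are written back in
# the same format for the official metric.
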
10)) 29 | 30 | def load_annotations(self): 31 | self.logger.info('Loading TuSimple annotations...') 32 | self.data_infos = [] 33 | max_lanes = 0 34 | for anno_file in self.anno_files: 35 | anno_file = osp.join(self.data_root, anno_file) 36 | with open(anno_file, 'r') as anno_obj: 37 | lines = anno_obj.readlines() 38 | for line in lines: 39 | data = json.loads(line) 40 | y_samples = data['h_samples'] 41 | gt_lanes = data['lanes'] 42 | mask_path = data['raw_file'].replace('clips', 43 | 'seg_label')[:-3] + 'png' 44 | lanes = [[(x, y) for (x, y) in zip(lane, y_samples) if x >= 0] 45 | for lane in gt_lanes] 46 | lanes = [lane for lane in lanes if len(lane) > 0] 47 | max_lanes = max(max_lanes, len(lanes)) 48 | self.data_infos.append({ 49 | 'img_path': 50 | osp.join(self.data_root, data['raw_file']), 51 | 'img_name': 52 | data['raw_file'], 53 | 'mask_path': 54 | osp.join(self.data_root, mask_path), 55 | 'lanes': 56 | lanes, 57 | }) 58 | 59 | if self.training: 60 | random.shuffle(self.data_infos) 61 | self.max_lanes = max_lanes 62 | 63 | def pred2lanes(self, pred): 64 | ys = np.array(self.h_samples) / self.cfg.ori_img_h 65 | lanes = [] 66 | for lane in pred: 67 | xs = lane(ys) 68 | invalid_mask = xs < 0 69 | lane = (xs * self.cfg.ori_img_w).astype(int) 70 | lane[invalid_mask] = -2 71 | lanes.append(lane.tolist()) 72 | 73 | return lanes 74 | 75 | def pred2tusimpleformat(self, idx, pred, runtime): 76 | runtime *= 1000. # s to ms 77 | img_name = self.data_infos[idx]['img_name'] 78 | lanes = self.pred2lanes(pred) 79 | output = {'raw_file': img_name, 'lanes': lanes, 'run_time': runtime} 80 | return json.dumps(output) 81 | 82 | def save_tusimple_predictions(self, predictions, filename, runtimes=None): 83 | if runtimes is None: 84 | runtimes = np.ones(len(predictions)) * 1.e-3 85 | lines = [] 86 | for idx, (prediction, runtime) in enumerate(zip(predictions, 87 | runtimes)): 88 | line = self.pred2tusimpleformat(idx, prediction, runtime) 89 | lines.append(line) 90 | with open(filename, 'w') as output_file: 91 | output_file.write('\n'.join(lines)) 92 | 93 | def evaluate(self, predictions, output_basedir, runtimes=None): 94 | pred_filename = os.path.join(output_basedir, 95 | 'tusimple_predictions.json') 96 | self.save_tusimple_predictions(predictions, pred_filename, runtimes) 97 | result, acc = LaneEval.bench_one_submit(pred_filename, 98 | self.cfg.test_json_file) 99 | self.logger.info(result) 100 | return acc 101 | -------------------------------------------------------------------------------- /flamnet/engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RanHao-cq/FLAMNet/0ad0e4c3245bfcf23aa5ecba2a17d31bb0e7d960/flamnet/engine/__init__.py -------------------------------------------------------------------------------- /flamnet/engine/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def build_optimizer(cfg, net): 5 | params = [] 6 | cfg_cp = cfg.optimizer.copy() 7 | cfg_type = cfg_cp.pop('type') 8 | 9 | if cfg_type not in dir(torch.optim): 10 | raise ValueError("{} is not defined.".format(cfg_type)) 11 | 12 | _optim = getattr(torch.optim, cfg_type) 13 | return _optim(net.parameters(), **cfg_cp) 14 | -------------------------------------------------------------------------------- /flamnet/engine/registry.py: -------------------------------------------------------------------------------- 1 | from flamnet.utils import Registry, build_from_cfg 2 | 3 | 
TRAINER = Registry('trainer')
4 | EVALUATOR = Registry('evaluator')
5 | import torch.nn as nn  # build() below relies on nn.Sequential
6 | 
7 | def build(cfg, registry, default_args=None):
8 |     if isinstance(cfg, list):
9 |         modules = [
10 |             build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
11 |         ]
12 |         return nn.Sequential(*modules)
13 |     else:
14 |         return build_from_cfg(cfg, registry, default_args)
15 | 
16 | 
17 | def build_trainer(cfg):
18 |     return build(cfg.trainer, TRAINER, default_args=dict(cfg=cfg))
19 | 
20 | 
21 | def build_evaluator(cfg):
22 |     return build(cfg.evaluator, EVALUATOR, default_args=dict(cfg=cfg))
23 | 
--------------------------------------------------------------------------------
/flamnet/engine/runner.py:
--------------------------------------------------------------------------------
1 | import time
2 | import cv2
3 | import torch
4 | from tqdm import tqdm
5 | import pytorch_warmup as warmup
6 | import numpy as np
7 | import random
8 | import os
9 | 
10 | from flamnet.models.registry import build_net
11 | from .registry import build_trainer, build_evaluator
12 | from .optimizer import build_optimizer
13 | from .scheduler import build_scheduler
14 | from flamnet.datasets import build_dataloader
15 | from flamnet.utils.recorder import build_recorder
16 | from flamnet.utils.net_utils import save_model, load_network, resume_network
17 | from mmcv.parallel import MMDataParallel
18 | 
19 | 
20 | class Runner(object):
21 |     def __init__(self, cfg):
22 |         torch.manual_seed(cfg.seed)
23 |         np.random.seed(cfg.seed)
24 |         random.seed(cfg.seed)
25 |         self.cfg = cfg
26 |         self.recorder = build_recorder(self.cfg)
27 |         self.net = build_net(self.cfg)
28 |         self.net = MMDataParallel(self.net,
29 |                                   device_ids=range(self.cfg.gpus)).cuda()
30 |         self.recorder.logger.info('Network: \n' + str(self.net))
31 |         self.resume()
32 |         self.optimizer = build_optimizer(self.cfg, self.net)
33 |         self.scheduler = build_scheduler(self.cfg, self.optimizer)
34 |         self.metric = 0.
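        # self.metric holds the best validation metric seen so far and
        # self.best_epoch the epoch it was reached in; validate() compares
        # new results against these to decide when to save the "best" checkpoint.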
35 | self.best_epoch = 0 36 | self.val_loader = None 37 | self.test_loader = None 38 | 39 | def to_cuda(self, batch): 40 | for k in batch: 41 | if not isinstance(batch[k], torch.Tensor): 42 | continue 43 | batch[k] = batch[k].cuda() 44 | return batch 45 | 46 | def resume(self): 47 | if not self.cfg.load_from and not self.cfg.finetune_from: 48 | return 49 | load_network(self.net, self.cfg.load_from, finetune_from=self.cfg.finetune_from, logger=self.recorder.logger) 50 | 51 | def train_epoch(self, epoch, train_loader): 52 | self.net.train() 53 | end = time.time() 54 | max_iter = len(train_loader) 55 | for i, data in enumerate(train_loader): 56 | if self.recorder.step >= self.cfg.total_iter: 57 | break 58 | date_time = time.time() - end 59 | self.recorder.step += 1 60 | data = self.to_cuda(data) 61 | output = self.net(data) 62 | self.optimizer.zero_grad() 63 | loss = output['loss'].sum() 64 | loss.backward() 65 | self.optimizer.step() 66 | if not self.cfg.lr_update_by_epoch: 67 | self.scheduler.step() 68 | batch_time = time.time() - end 69 | end = time.time() 70 | self.recorder.update_loss_stats(output['loss_stats']) 71 | self.recorder.batch_time.update(batch_time) 72 | self.recorder.data_time.update(date_time) 73 | 74 | if i % self.cfg.log_interval == 0 or i == max_iter - 1: 75 | lr = self.optimizer.param_groups[0]['lr'] 76 | self.recorder.lr = lr 77 | self.recorder.record('train') 78 | 79 | def train(self): 80 | self.recorder.logger.info('Build train loader...') 81 | train_loader = build_dataloader(self.cfg.dataset.train, 82 | self.cfg, 83 | is_train=True) 84 | 85 | self.recorder.logger.info('Start training...') 86 | start_epoch = 0 87 | if self.cfg.resume_from: 88 | start_epoch = resume_network(self.cfg.resume_from, self.net, 89 | self.optimizer, self.scheduler, 90 | self.recorder) 91 | for epoch in range(start_epoch, self.cfg.epochs): 92 | self.recorder.epoch = epoch 93 | self.train_epoch(epoch, train_loader) 94 | if (epoch +1) % self.cfg.save_ep == 0 or epoch == self.cfg.epochs - 1: 95 | self.save_ckpt() 96 | if (epoch +1) % self.cfg.eval_ep == 0 or epoch == self.cfg.epochs - 1: 97 | self.validate() 98 | if self.recorder.step >= self.cfg.total_iter: 99 | break 100 | if self.cfg.lr_update_by_epoch: 101 | self.scheduler.step() 102 | 103 | 104 | def test(self): 105 | if not self.test_loader: 106 | self.test_loader = build_dataloader(self.cfg.dataset.test, 107 | self.cfg, 108 | is_train=False) 109 | self.net.eval() 110 | predictions = [] 111 | for i, data in enumerate(tqdm(self.test_loader, desc=f'Testing')): 112 | data = self.to_cuda(data) 113 | with torch.no_grad(): 114 | output = self.net(data) 115 | output = self.net.module.heads.get_lanes(output) 116 | predictions.extend(output) 117 | if self.cfg.view: 118 | self.test_loader.dataset.view(output, data['meta']) 119 | 120 | metric = self.test_loader.dataset.evaluate(predictions, 121 | self.cfg.work_dir) 122 | if metric is not None: 123 | self.recorder.logger.info('metric: ' + str(metric)) 124 | 125 | def validate(self): 126 | if not self.val_loader: 127 | self.val_loader = build_dataloader(self.cfg.dataset.val, 128 | self.cfg, 129 | is_train=False) 130 | self.net.eval() 131 | predictions = [] 132 | for i, data in enumerate(tqdm(self.val_loader, desc=f'Validate')): 133 | data = self.to_cuda(data) 134 | with torch.no_grad(): 135 | output = self.net(data) 136 | output = self.net.module.heads.get_lanes(output) 137 | predictions.extend(output) 138 | if self.cfg.view: 139 | self.val_loader.dataset.view(output, data['meta']) 140 | 141 | metric = 
self.val_loader.dataset.evaluate(predictions,self.cfg.work_dir) 142 | self.recorder.logger.info('metric: ' + str(metric)) 143 | 144 | if not metric: 145 | return 146 | if metric > self.metric: 147 | self.metric = metric 148 | self.best_epoch = self.recorder.epoch 149 | self.save_ckpt(is_best=True) 150 | self.recorder.logger.info('Best metric: ' + str(self.metric) + ' epoch_best:' + str(self.best_epoch)) 151 | 152 | def save_ckpt(self, is_best=False): 153 | save_model(self.net, self.optimizer, self.scheduler, self.recorder, 154 | is_best) 155 | -------------------------------------------------------------------------------- /flamnet/engine/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def build_scheduler(cfg, optimizer): 6 | 7 | cfg_cp = cfg.scheduler.copy() 8 | cfg_type = cfg_cp.pop('type') 9 | 10 | if cfg_type not in dir(torch.optim.lr_scheduler): 11 | raise ValueError("{} is not defined.".format(cfg_type)) 12 | 13 | _scheduler = getattr(torch.optim.lr_scheduler, cfg_type) 14 | 15 | return _scheduler(optimizer, **cfg_cp) 16 | -------------------------------------------------------------------------------- /flamnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * 2 | from .heads import * 3 | from .nets import * 4 | from .necks import * 5 | from .registry import build_backbones 6 | -------------------------------------------------------------------------------- /flamnet/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNet 2 | from .dla34 import DLA 3 | # from .topformer import Topformer -------------------------------------------------------------------------------- /flamnet/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .flamnet_head import FLAMNetHead -------------------------------------------------------------------------------- /flamnet/models/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RanHao-cq/FLAMNet/0ad0e4c3245bfcf23aa5ecba2a17d31bb0e7d960/flamnet/models/losses/__init__.py -------------------------------------------------------------------------------- /flamnet/models/losses/accuracy.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import torch.nn as nn 3 | 4 | 5 | @mmcv.jit(coderize=True) 6 | def accuracy(pred, target, topk=1, thresh=None): 7 | """Calculate accuracy according to the prediction and target. 8 | 9 | Args: 10 | pred (torch.Tensor): The model prediction, shape (N, num_class) 11 | target (torch.Tensor): The target of each prediction, shape (N, ) 12 | topk (int | tuple[int], optional): If the predictions in ``topk`` 13 | matches the target, the predictions will be regarded as 14 | correct ones. Defaults to 1. 15 | thresh (float, optional): If not None, predictions with scores under 16 | this threshold are considered incorrect. Default to None. 17 | 18 | Returns: 19 | float | tuple[float]: If the input ``topk`` is a single integer, 20 | the function will return a single float as accuracy. If 21 | ``topk`` is a tuple containing multiple integers, the 22 | function will return a tuple containing accuracies of 23 | each ``topk`` number. 
24 | """ 25 | assert isinstance(topk, (int, tuple)) 26 | if isinstance(topk, int): 27 | topk = (topk, ) 28 | return_single = True 29 | else: 30 | return_single = False 31 | 32 | maxk = max(topk) 33 | if pred.size(0) == 0: 34 | accu = [pred.new_tensor(0.) for i in range(len(topk))] 35 | return accu[0] if return_single else accu 36 | assert pred.ndim == 2 and target.ndim == 1 37 | assert pred.size(0) == target.size(0) 38 | assert maxk <= pred.size(1), \ 39 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 40 | pred_value, pred_label = pred.topk(maxk, dim=1) 41 | pred_label = pred_label.t() # transpose to shape (maxk, N) 42 | correct = pred_label.eq(target.view(1, -1).expand_as(pred_label)) 43 | if thresh is not None: 44 | # Only prediction values larger than thresh are counted as correct 45 | correct = correct & (pred_value > thresh).t() 46 | res = [] 47 | for k in topk: 48 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 49 | res.append(correct_k.mul_(100.0 / pred.size(0))) 50 | return res[0] if return_single else res 51 | 52 | 53 | class Accuracy(nn.Module): 54 | def __init__(self, topk=(1, ), thresh=None): 55 | """Module to calculate the accuracy. 56 | 57 | Args: 58 | topk (tuple, optional): The criterion used to calculate the 59 | accuracy. Defaults to (1,). 60 | thresh (float, optional): If not None, predictions with scores 61 | under this threshold are considered incorrect. Default to None. 62 | """ 63 | super().__init__() 64 | self.topk = topk 65 | self.thresh = thresh 66 | 67 | def forward(self, pred, target): 68 | """Forward function to calculate accuracy. 69 | 70 | Args: 71 | pred (torch.Tensor): Prediction of models. 72 | target (torch.Tensor): Target for each prediction. 73 | 74 | Returns: 75 | tuple[float]: The accuracies under different topk criterions. 76 | """ 77 | return accuracy(pred, target, self.topk, self.thresh) 78 | -------------------------------------------------------------------------------- /flamnet/models/losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | # pylint: disable-all 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | # Source: https://github.com/kornia/kornia/blob/f4f70fefb63287f72bc80cd96df9c061b1cb60dd/kornia/losses/focal.py 9 | 10 | 11 | class SoftmaxFocalLoss(nn.Module): 12 | def __init__(self, gamma, ignore_lb=255, *args, **kwargs): 13 | super(SoftmaxFocalLoss, self).__init__() 14 | self.gamma = gamma 15 | self.nll = nn.NLLLoss(ignore_index=ignore_lb) 16 | 17 | def forward(self, logits, labels): 18 | scores = F.softmax(logits, dim=1) 19 | factor = torch.pow(1. - scores, self.gamma) 20 | log_score = F.log_softmax(logits, dim=1) 21 | log_score = factor * log_score 22 | loss = self.nll(log_score, labels) 23 | return loss 24 | 25 | 26 | def one_hot(labels: torch.Tensor, 27 | num_classes: int, 28 | device: Optional[torch.device] = None, 29 | dtype: Optional[torch.dtype] = None, 30 | eps: Optional[float] = 1e-6) -> torch.Tensor: 31 | r"""Converts an integer label x-D tensor to a one-hot (x+1)-D tensor. 32 | 33 | Args: 34 | labels (torch.Tensor) : tensor with labels of shape :math:`(N, *)`, 35 | where N is batch size. Each value is an integer 36 | representing correct classification. 37 | num_classes (int): number of classes in labels. 38 | device (Optional[torch.device]): the desired device of returned tensor. 
39 | Default: if None, uses the current device for the default tensor type 40 | (see torch.set_default_tensor_type()). device will be the CPU for CPU 41 | tensor types and the current CUDA device for CUDA tensor types. 42 | dtype (Optional[torch.dtype]): the desired data type of returned 43 | tensor. Default: if None, infers data type from values. 44 | 45 | Returns: 46 | torch.Tensor: the labels in one hot tensor of shape :math:`(N, C, *)`, 47 | 48 | Examples:: 49 | >>> labels = torch.LongTensor([[[0, 1], [2, 0]]]) 50 | >>> kornia.losses.one_hot(labels, num_classes=3) 51 | tensor([[[[1., 0.], 52 | [0., 1.]], 53 | [[0., 1.], 54 | [0., 0.]], 55 | [[0., 0.], 56 | [1., 0.]]]] 57 | """ 58 | if not torch.is_tensor(labels): 59 | raise TypeError( 60 | "Input labels type is not a torch.Tensor. Got {}".format( 61 | type(labels))) 62 | if not labels.dtype == torch.int64: 63 | raise ValueError( 64 | "labels must be of the same dtype torch.int64. Got: {}".format( 65 | labels.dtype)) 66 | if num_classes < 1: 67 | raise ValueError("The number of classes must be bigger than one." 68 | " Got: {}".format(num_classes)) 69 | shape = labels.shape 70 | one_hot = torch.zeros(shape[0], 71 | num_classes, 72 | *shape[1:], 73 | device=device, 74 | dtype=dtype) 75 | return one_hot.scatter_(1, labels.unsqueeze(1), 1.0) + eps 76 | 77 | 78 | def focal_loss(input: torch.Tensor, 79 | target: torch.Tensor, 80 | alpha: float, 81 | gamma: float = 2.0, 82 | reduction: str = 'none', 83 | eps: float = 1e-8) -> torch.Tensor: 84 | r"""Function that computes Focal loss. 85 | 86 | See :class:`~kornia.losses.FocalLoss` for details. 87 | """ 88 | if not torch.is_tensor(input): 89 | raise TypeError("Input type is not a torch.Tensor. Got {}".format( 90 | type(input))) 91 | 92 | if not len(input.shape) >= 2: 93 | raise ValueError( 94 | "Invalid input shape, we expect BxCx*. Got: {}".format( 95 | input.shape)) 96 | 97 | if input.size(0) != target.size(0): 98 | raise ValueError( 99 | 'Expected input batch_size ({}) to match target batch_size ({}).'. 100 | format(input.size(0), target.size(0))) 101 | 102 | n = input.size(0) 103 | out_size = (n, ) + input.size()[2:] 104 | if target.size()[1:] != input.size()[2:]: 105 | raise ValueError('Expected target size {}, got {}'.format( 106 | out_size, target.size())) 107 | 108 | if not input.device == target.device: 109 | raise ValueError( 110 | "input and target must be in the same device. Got: {} and {}". 111 | format(input.device, target.device)) 112 | 113 | # compute softmax over the classes axis 114 | input_soft: torch.Tensor = F.softmax(input, dim=1) + eps 115 | 116 | # create the labels one hot tensor 117 | target_one_hot: torch.Tensor = one_hot(target, 118 | num_classes=input.shape[1], 119 | device=input.device, 120 | dtype=input.dtype) 121 | 122 | # compute the actual focal loss 123 | weight = torch.pow(-input_soft + 1., gamma) 124 | 125 | focal = -alpha * weight * torch.log(input_soft) 126 | loss_tmp = torch.sum(target_one_hot * focal, dim=1) 127 | 128 | if reduction == 'none': 129 | loss = loss_tmp 130 | elif reduction == 'mean': 131 | loss = torch.mean(loss_tmp) 132 | elif reduction == 'sum': 133 | loss = torch.sum(loss_tmp) 134 | else: 135 | raise NotImplementedError( 136 | "Invalid reduction mode: {}".format(reduction)) 137 | return loss 138 | 139 | 140 | class FocalLoss(nn.Module): 141 | r"""Criterion that computes Focal loss. 142 | 143 | According to [1], the Focal loss is computed as follows: 144 | 145 | .. 
math:: 146 | 147 | \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t) 148 | 149 | where: 150 | - :math:`p_t` is the model's estimated probability for each class. 151 | 152 | 153 | Arguments: 154 | alpha (float): Weighting factor :math:`\alpha \in [0, 1]`. 155 | gamma (float): Focusing parameter :math:`\gamma >= 0`. 156 | reduction (str, optional): Specifies the reduction to apply to the 157 | output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, 158 | ‘mean’: the sum of the output will be divided by the number of elements 159 | in the output, ‘sum’: the output will be summed. Default: ‘none’. 160 | 161 | Shape: 162 | - Input: :math:`(N, C, *)` where C = number of classes. 163 | - Target: :math:`(N, *)` where each value is 164 | :math:`0 ≤ targets[i] ≤ C−1`. 165 | 166 | Examples: 167 | >>> N = 5 # num_classes 168 | >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'} 169 | >>> loss = kornia.losses.FocalLoss(**kwargs) 170 | >>> input = torch.randn(1, N, 3, 5, requires_grad=True) 171 | >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N) 172 | >>> output = loss(input, target) 173 | >>> output.backward() 174 | 175 | References: 176 | [1] https://arxiv.org/abs/1708.02002 177 | """ 178 | def __init__(self, 179 | alpha: float, 180 | gamma: float = 2.0, 181 | reduction: str = 'none') -> None: 182 | super(FocalLoss, self).__init__() 183 | self.alpha: float = alpha 184 | self.gamma: float = gamma 185 | self.reduction: str = reduction 186 | self.eps: float = 1e-6 187 | 188 | def forward( # type: ignore 189 | self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 190 | return focal_loss(input, target, self.alpha, self.gamma, 191 | self.reduction, self.eps) 192 | -------------------------------------------------------------------------------- /flamnet/models/losses/lineiou_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def line_iou(pred, target, img_w, length=15, aligned=True): 5 | ''' 6 | Calculate the line iou value between predictions and targets 7 | Args: 8 | pred: lane predictions, shape: (num_pred, 72) 9 | target: ground truth, shape: (num_target, 72) 10 | img_w: image width 11 | length: extended radius 12 | aligned: True for iou loss calculation, False for pair-wise ious in assign 13 | ''' 14 | px1 = pred - length 15 | px2 = pred + length 16 | tx1 = target - length 17 | tx2 = target + length 18 | if aligned: 19 | invalid_mask = target 20 | ovr = torch.min(px2, tx2) - torch.max(px1, tx1) 21 | union = torch.max(px2, tx2) - torch.min(px1, tx1) 22 | else: 23 | num_pred = pred.shape[0] 24 | invalid_mask = target.repeat(num_pred, 1, 1) 25 | ovr = (torch.min(px2[:, None, :], tx2[None, ...]) - 26 | torch.max(px1[:, None, :], tx1[None, ...])) 27 | union = (torch.max(px2[:, None, :], tx2[None, ...]) - 28 | torch.min(px1[:, None, :], tx1[None, ...])) 29 | 30 | invalid_masks = (invalid_mask < 0) | (invalid_mask >= img_w) 31 | ovr[invalid_masks] = 0. 32 | union[invalid_masks] = 0. 
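    # Reduce the per-point segment overlaps to one scalar IoU per lane pair:
    # IoU = sum(ovr) / (sum(union) + eps); points marked invalid above
    # contribute zero to both sums. E.g. for a single sample row with
    # pred = 10, target = 12, length = 15: ovr = min(25, 27) - max(-5, -3) = 28
    # and union = max(25, 27) - min(-5, -3) = 32, so that row contributes 28/32.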
33 | iou = ovr.sum(dim=-1) / (union.sum(dim=-1) + 1e-9) 34 | return iou 35 | 36 | 37 | def liou_loss(pred, target, img_w, length=15): 38 | return (1 - line_iou(pred, target, img_w, length)).mean() -------------------------------------------------------------------------------- /flamnet/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .fpn import FPN 2 | from .pafpn import PAFPN 3 | from .ppam_dsaformer import PPAM_DSAformer 4 | -------------------------------------------------------------------------------- /flamnet/models/necks/fpn.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from mmcv.cnn import ConvModule 7 | from ..registry import NECKS 8 | 9 | 10 | @NECKS.register_module 11 | class FPN(nn.Module): 12 | def __init__(self, 13 | in_channels, 14 | out_channels, 15 | num_outs, 16 | start_level=0, 17 | end_level=-1, 18 | add_extra_convs=False, 19 | extra_convs_on_inputs=True, 20 | relu_before_extra_convs=False, 21 | no_norm_on_lateral=False, 22 | conv_cfg=None, 23 | norm_cfg=None, 24 | attention=False, 25 | act_cfg=None, 26 | upsample_cfg=dict(mode='nearest'), 27 | init_cfg=dict(type='Xavier', 28 | layer='Conv2d', 29 | distribution='uniform'), 30 | cfg=None): 31 | super(FPN, self).__init__() 32 | assert isinstance(in_channels, list) 33 | self.in_channels = in_channels 34 | self.out_channels = out_channels 35 | self.num_ins = len(in_channels) 36 | self.num_outs = num_outs 37 | self.attention = attention 38 | self.relu_before_extra_convs = relu_before_extra_convs 39 | self.no_norm_on_lateral = no_norm_on_lateral 40 | self.upsample_cfg = upsample_cfg.copy() 41 | 42 | if end_level == -1: 43 | self.backbone_end_level = self.num_ins 44 | assert num_outs >= self.num_ins - start_level 45 | else: 46 | # if end_level < inputs, no extra level is allowed 47 | self.backbone_end_level = end_level 48 | assert end_level <= len(in_channels) 49 | assert num_outs == end_level - start_level 50 | self.start_level = start_level 51 | self.end_level = end_level 52 | self.add_extra_convs = add_extra_convs 53 | assert isinstance(add_extra_convs, (str, bool)) 54 | if isinstance(add_extra_convs, str): 55 | # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' 56 | assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') 57 | elif add_extra_convs: # True 58 | if extra_convs_on_inputs: 59 | # TODO: deprecate `extra_convs_on_inputs` 60 | warnings.simplefilter('once') 61 | warnings.warn( 62 | '"extra_convs_on_inputs" will be deprecated in v2.9.0,' 63 | 'Please use "add_extra_convs"', DeprecationWarning) 64 | self.add_extra_convs = 'on_input' 65 | else: 66 | self.add_extra_convs = 'on_output' 67 | 68 | self.lateral_convs = nn.ModuleList() 69 | self.fpn_convs = nn.ModuleList() 70 | 71 | for i in range(self.start_level, self.backbone_end_level): 72 | l_conv = ConvModule( 73 | in_channels[i], 74 | out_channels, 75 | 1, 76 | conv_cfg=conv_cfg, 77 | norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, 78 | act_cfg=act_cfg, 79 | inplace=False) 80 | fpn_conv = ConvModule(out_channels, 81 | out_channels, 82 | 3, 83 | padding=1, 84 | conv_cfg=conv_cfg, 85 | norm_cfg=norm_cfg, 86 | act_cfg=act_cfg, 87 | inplace=False) 88 | 89 | self.lateral_convs.append(l_conv) 90 | self.fpn_convs.append(fpn_conv) 91 | 92 | # add extra conv layers (e.g., RetinaNet) 93 | extra_levels = num_outs - 
self.backbone_end_level + self.start_level
94 |         if self.add_extra_convs and extra_levels >= 1:
95 |             for i in range(extra_levels):
96 |                 if i == 0 and self.add_extra_convs == 'on_input':
97 |                     in_channels = self.in_channels[self.backbone_end_level - 1]
98 |                 else:
99 |                     in_channels = out_channels
100 |                 extra_fpn_conv = ConvModule(in_channels,
101 |                                             out_channels,
102 |                                             3,
103 |                                             stride=2,
104 |                                             padding=1,
105 |                                             conv_cfg=conv_cfg,
106 |                                             norm_cfg=norm_cfg,
107 |                                             act_cfg=act_cfg,
108 |                                             inplace=False)
109 |                 self.fpn_convs.append(extra_fpn_conv)
110 | 
111 |     def forward(self, inputs):
112 |         """Forward function."""
113 |         assert len(inputs) >= len(self.in_channels)
114 | 
115 |         if len(inputs) > len(self.in_channels):
116 |             for _ in range(len(inputs) - len(self.in_channels)):
117 |                 del inputs[0]
118 | 
119 |         # build laterals
120 |         laterals = [
121 |             lateral_conv(inputs[i + self.start_level])
122 |             for i, lateral_conv in enumerate(self.lateral_convs)
123 |         ]
124 | 
125 |         # build top-down path
126 |         used_backbone_levels = len(laterals)
127 |         for i in range(used_backbone_levels - 1, 0, -1):
128 |             # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
129 |             # it cannot co-exist with `size` in `F.interpolate`.
130 |             if 'scale_factor' in self.upsample_cfg:
131 |                 laterals[i - 1] += F.interpolate(laterals[i],
132 |                                                  **self.upsample_cfg)
133 |             else:
134 |                 prev_shape = laterals[i - 1].shape[2:]
135 |                 laterals[i - 1] += F.interpolate(laterals[i],
136 |                                                  size=prev_shape,
137 |                                                  **self.upsample_cfg)
138 | 
139 |         # build outputs
140 |         # part 1: from original levels
141 |         outs = [
142 |             self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
143 |         ]
144 |         # part 2: add extra levels
145 |         if self.num_outs > len(outs):
146 |             # use max pool to get more levels on top of outputs
147 |             # (e.g., Faster R-CNN, Mask R-CNN)
148 |             if not self.add_extra_convs:
149 |                 for i in range(self.num_outs - used_backbone_levels):
150 |                     outs.append(F.max_pool2d(outs[-1], 1, stride=2))
151 |             # add conv layers on top of original feature maps (RetinaNet)
152 |             else:
153 |                 if self.add_extra_convs == 'on_input':
154 |                     extra_source = inputs[self.backbone_end_level - 1]
155 |                 elif self.add_extra_convs == 'on_lateral':
156 |                     extra_source = laterals[-1]
157 |                 elif self.add_extra_convs == 'on_output':
158 |                     extra_source = outs[-1]
159 |                 else:
160 |                     raise NotImplementedError
161 |                 outs.append(self.fpn_convs[used_backbone_levels](extra_source))
162 |                 for i in range(used_backbone_levels + 1, self.num_outs):
163 |                     if self.relu_before_extra_convs:
164 |                         outs.append(self.fpn_convs[i](F.relu(outs[-1])))
165 |                     else:
166 |                         outs.append(self.fpn_convs[i](outs[-1]))
167 |         return tuple(outs)
--------------------------------------------------------------------------------
/flamnet/models/necks/pafpn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from mmcv.cnn import ConvModule
4 | from mmcv.runner import auto_fp16
5 | 
6 | from ..registry import NECKS
7 | from .fpn import FPN
8 | 
9 | 
10 | @NECKS.register_module
11 | class PAFPN(FPN):
12 |     """Path Aggregation Network for Instance Segmentation.
13 | 
14 |     This is an implementation of the PAFPN in `Path Aggregation Network
15 |     <https://arxiv.org/abs/1803.01534>`_.
16 | 
17 |     Args:
18 |         in_channels (List[int]): Number of input channels per scale.
19 |         out_channels (int): Number of output channels (used at each scale)
20 |         num_outs (int): Number of output scales.
21 | start_level (int): Index of the start input backbone level used to 22 | build the feature pyramid. Default: 0. 23 | end_level (int): Index of the end input backbone level (exclusive) to 24 | build the feature pyramid. Default: -1, which means the last level. 25 | add_extra_convs (bool): Whether to add conv layers on top of the 26 | original feature maps. Default: False. 27 | extra_convs_on_inputs (bool): Whether to apply extra conv on 28 | the original feature from the backbone. Default: False. 29 | relu_before_extra_convs (bool): Whether to apply relu before the extra 30 | conv. Default: False. 31 | no_norm_on_lateral (bool): Whether to apply norm on lateral. 32 | Default: False. 33 | conv_cfg (dict): Config dict for convolution layer. Default: None. 34 | norm_cfg (dict): Config dict for normalization layer. Default: None. 35 | act_cfg (str): Config dict for activation layer in ConvModule. 36 | Default: None. 37 | """ 38 | def __init__(self, 39 | in_channels, 40 | out_channels, 41 | num_outs, 42 | start_level=0, 43 | end_level=-1, 44 | add_extra_convs=False, 45 | extra_convs_on_inputs=True, 46 | relu_before_extra_convs=False, 47 | no_norm_on_lateral=False, 48 | conv_cfg=None, 49 | norm_cfg=None, 50 | act_cfg=None, 51 | cfg=None, 52 | attention=False): 53 | super(PAFPN, self).__init__(in_channels, 54 | out_channels, 55 | num_outs, 56 | start_level, 57 | end_level, 58 | add_extra_convs, 59 | extra_convs_on_inputs, 60 | relu_before_extra_convs, 61 | no_norm_on_lateral, 62 | conv_cfg, 63 | norm_cfg, 64 | attention, 65 | act_cfg, 66 | cfg=cfg) 67 | # add extra bottom up pathway 68 | self.downsample_convs = nn.ModuleList() 69 | self.pafpn_convs = nn.ModuleList() 70 | for i in range(self.start_level + 1, self.backbone_end_level): 71 | d_conv = ConvModule(out_channels, 72 | out_channels, 73 | 3, 74 | stride=2, 75 | padding=1, 76 | conv_cfg=conv_cfg, 77 | norm_cfg=norm_cfg, 78 | act_cfg=act_cfg, 79 | inplace=False) 80 | pafpn_conv = ConvModule(out_channels, 81 | out_channels, 82 | 3, 83 | padding=1, 84 | conv_cfg=conv_cfg, 85 | norm_cfg=norm_cfg, 86 | act_cfg=act_cfg, 87 | inplace=False) 88 | self.downsample_convs.append(d_conv) 89 | self.pafpn_convs.append(pafpn_conv) 90 | 91 | def forward(self, inputs): 92 | """Forward function.""" 93 | assert len(inputs) >= len(self.in_channels) 94 | 95 | if len(inputs) > len(self.in_channels): 96 | for _ in range(len(inputs) - len(self.in_channels)): 97 | del inputs[0] 98 | 99 | # build laterals 100 | laterals = [ 101 | lateral_conv(inputs[i + self.start_level]) 102 | for i, lateral_conv in enumerate(self.lateral_convs) 103 | ] 104 | 105 | # build top-down path 106 | used_backbone_levels = len(laterals) 107 | for i in range(used_backbone_levels - 1, 0, -1): 108 | prev_shape = laterals[i - 1].shape[2:] 109 | laterals[i - 1] += F.interpolate(laterals[i], 110 | size=prev_shape, 111 | mode='nearest') 112 | 113 | # build outputs 114 | # part 1: from original levels 115 | inter_outs = [ 116 | self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) 117 | ] 118 | 119 | # part 2: add bottom-up path 120 | for i in range(0, used_backbone_levels - 1): 121 | inter_outs[i + 1] += self.downsample_convs[i](inter_outs[i]) 122 | 123 | outs = [] 124 | outs.append(inter_outs[0]) 125 | outs.extend([ 126 | self.pafpn_convs[i - 1](inter_outs[i]) 127 | for i in range(1, used_backbone_levels) 128 | ]) 129 | 130 | # part 3: add extra levels 131 | if self.num_outs > len(outs): 132 | # use max pool to get more levels on top of outputs 133 | # (e.g., Faster R-CNN, Mask 
R-CNN)
134 |             if not self.add_extra_convs:
135 |                 for i in range(self.num_outs - used_backbone_levels):
136 |                     outs.append(F.max_pool2d(outs[-1], 1, stride=2))
137 |             # add conv layers on top of original feature maps (RetinaNet)
138 |             else:
139 |                 if self.add_extra_convs == 'on_input':
140 |                     orig = inputs[self.backbone_end_level - 1]
141 |                     outs.append(self.fpn_convs[used_backbone_levels](orig))
142 |                 elif self.add_extra_convs == 'on_lateral':
143 |                     outs.append(self.fpn_convs[used_backbone_levels](
144 |                         laterals[-1]))
145 |                 elif self.add_extra_convs == 'on_output':
146 |                     outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
147 |                 else:
148 |                     raise NotImplementedError
149 |                 for i in range(used_backbone_levels + 1, self.num_outs):
150 |                     if self.relu_before_extra_convs:
151 |                         outs.append(self.fpn_convs[i](F.relu(outs[-1])))
152 |                     else:
153 |                         outs.append(self.fpn_convs[i](outs[-1]))
154 |         return tuple(outs)
155 | 
--------------------------------------------------------------------------------
/flamnet/models/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from .detector import Detector
2 | 
--------------------------------------------------------------------------------
/flamnet/models/nets/detector.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | 
4 | from flamnet.models.registry import NETS
5 | from ..registry import build_backbones, build_aggregator, build_heads, build_necks
6 | # from ..backbones.topformer import Topformer
7 | from ..necks.ppam_dsaformer import PPAM_DSAformer
8 | 
9 | @NETS.register_module
10 | class Detector(nn.Module):
11 |     def __init__(self, cfg):
12 |         super(Detector, self).__init__()
13 |         self.cfg = cfg
14 |         self.backbone = build_backbones(cfg)
15 |         # self.backbone = Topformer(cfgs=cfg.backbone.cfgs,
16 |         #                           channels=cfg.backbone.channels,
17 |         #                           out_channels=cfg.backbone.out_channels,
18 |         #                           embed_out_indice=cfg.backbone.embed_out_indice)
19 |         self.aggregator = build_aggregator(cfg) if cfg.haskey('aggregator') else None
20 |         # self.neck = build_necks(cfg) if cfg.haskey('neck') else None
21 |         self.neck = PPAM_DSAformer(in_channels=cfg.neck.in_channels,
22 |                                    channels=cfg.neck.channels,
23 |                                    out_channels=cfg.neck.out_channels,
24 |                                    depths=cfg.neck.depths,
25 |                                    num_heads=cfg.neck.num_heads,
26 |                                    c2t_stride=cfg.neck.c2t_stride,
27 |                                    drop_path_rate=cfg.neck.drop_path_rate)
28 |         self.heads = build_heads(cfg)
29 | 
30 |     def get_lanes(self, output):  # `output` must be passed in; it is not an attribute
31 |         return self.heads.get_lanes(output)
32 | 
33 |     def forward(self, batch):
34 |         output = {}
35 |         fea = self.backbone(batch['img'] if isinstance(batch, dict) else batch)
36 | 
37 |         if self.aggregator:
38 |             fea[-1] = self.aggregator(fea[-1])
39 | 
40 |         if self.neck:
41 |             fea = self.neck(fea)
42 | 
43 |         if self.training:
44 |             output = self.heads(fea, batch=batch)
45 |         else:
46 |             output = self.heads(fea)
47 | 
48 |         return output
49 | 
--------------------------------------------------------------------------------
/flamnet/models/registry.py:
--------------------------------------------------------------------------------
1 | from flamnet.utils import Registry, build_from_cfg
2 | import torch.nn as nn
3 | 
4 | BACKBONES = Registry('backbones')
5 | AGGREGATORS = Registry('aggregators')
6 | HEADS = Registry('heads')
7 | NECKS = Registry('necks')
8 | NETS = Registry('nets')
9 | 
10 | 
11 | def build(cfg, registry, default_args=None):
12 |     if isinstance(cfg, list):
13 |         modules = [
14 |             build_from_cfg(cfg_, registry,
default_args) for cfg_ in cfg
15 |         ]
16 |         return nn.Sequential(*modules)
17 |     else:
18 |         return build_from_cfg(cfg, registry, default_args)
19 | 
20 | 
21 | def build_backbones(cfg):
22 |     return build(cfg.backbone, BACKBONES, default_args=dict(cfg=cfg))
23 | 
24 | 
25 | def build_necks(cfg):
26 |     return build(cfg.neck, NECKS, default_args=dict(cfg=cfg))
27 | 
28 | 
29 | def build_aggregator(cfg):
30 |     return build(cfg.aggregator, AGGREGATORS, default_args=dict(cfg=cfg))
31 | 
32 | 
33 | def build_heads(cfg):
34 |     return build(cfg.heads, HEADS, default_args=dict(cfg=cfg))
35 | 
36 | 
37 | def build_head(split_cfg, cfg):
38 |     return build(split_cfg, HEADS, default_args=dict(cfg=cfg))
39 | 
40 | 
41 | def build_net(cfg):
42 |     return build(cfg.net, NETS, default_args=dict(cfg=cfg))
43 | 
--------------------------------------------------------------------------------
/flamnet/models/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RanHao-cq/FLAMNet/0ad0e4c3245bfcf23aa5ecba2a17d31bb0e7d960/flamnet/models/utils/__init__.py
--------------------------------------------------------------------------------
/flamnet/models/utils/dynamic_assign.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from flamnet.models.losses.lineiou_loss import line_iou
3 | 
4 | 
5 | def distance_cost(predictions, targets, img_w):
6 |     """
7 |     repeat predictions and targets to generate all combinations
8 |     use the abs distance as the new distance cost
9 |     """
10 |     num_priors = predictions.shape[0]
11 |     num_targets = targets.shape[0]
12 | 
13 |     predictions = torch.repeat_interleave(
14 |         predictions, num_targets, dim=0
15 |     )[...,
16 |       6:]  # repeat_interleave'ing [a, b] 2 times gives [a, a, b, b] ((np + nt) * 78)
17 | 
18 |     targets = torch.cat(
19 |         num_priors *
20 |         [targets])[...,
21 |                    6:]  # applying this 2 times on [c, d] gives [c, d, c, d]
22 | 
23 |     invalid_masks = (targets < 0) | (targets >= img_w)
24 |     lengths = (~invalid_masks).sum(dim=1)
25 |     distances = torch.abs((targets - predictions))
26 |     distances[invalid_masks] = 0.
27 |     distances = distances.sum(dim=1) / (lengths.float() + 1e-9)
28 |     distances = distances.view(num_priors, num_targets)
29 | 
30 |     return distances
31 | 
32 | 
33 | def focal_cost(cls_pred, gt_labels, alpha=0.25, gamma=2, eps=1e-12):
34 |     """
35 |     Args:
36 |         cls_pred (Tensor): Predicted classification logits, shape
37 |             [num_query, num_class].
38 |         gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
39 | 
40 |     Returns:
41 |         torch.Tensor: cls_cost value
42 |     """
43 |     cls_pred = cls_pred.sigmoid()
44 |     neg_cost = -(1 - cls_pred + eps).log() * (1 - alpha) * cls_pred.pow(gamma)
45 |     pos_cost = -(cls_pred + eps).log() * alpha * (1 - cls_pred).pow(gamma)
46 |     cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
47 |     return cls_cost
48 | 
49 | 
50 | def dynamic_k_assign(cost, pair_wise_ious):
51 |     """
52 |     Assign ground truths to priors dynamically.
53 | 
54 |     Args:
55 |         cost: the assign cost.
56 |         pair_wise_ious: iou of ground truths and priors.
57 | 
58 |     Returns:
59 |         prior_idx: the index of assigned prior.
60 |         gt_idx: the corresponding ground truth index.
61 |     """
62 |     matching_matrix = torch.zeros_like(cost)
63 |     ious_matrix = pair_wise_ious
64 |     ious_matrix[ious_matrix < 0] = 0.
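    # Dynamic-k estimation (in the spirit of SimOTA-style assignment): for each
    # ground truth, sum its top-4 IoUs over all priors and clamp to at least 1;
    # that integer becomes the number k of lowest-cost priors assigned to it below.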
65 |     n_candidate_k = 4
66 |     topk_ious, _ = torch.topk(ious_matrix, n_candidate_k, dim=0)
67 |     dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
68 |     num_gt = cost.shape[1]
69 |     for gt_idx in range(num_gt):
70 |         _, pos_idx = torch.topk(cost[:, gt_idx],
71 |                                 k=dynamic_ks[gt_idx].item(),
72 |                                 largest=False)
73 |         matching_matrix[pos_idx, gt_idx] = 1.0
74 |     del topk_ious, dynamic_ks, pos_idx
75 | 
76 |     matched_gt = matching_matrix.sum(1)
77 |     if (matched_gt > 1).sum() > 0:
78 |         _, cost_argmin = torch.min(cost[matched_gt > 1, :], dim=1)
79 |         matching_matrix[matched_gt > 1, :] *= 0.0  # clear all matches for over-assigned priors, not only column 0
80 |         matching_matrix[matched_gt > 1, cost_argmin] = 1.0
81 | 
82 |     prior_idx = matching_matrix.sum(1).nonzero()
83 |     gt_idx = matching_matrix[prior_idx].argmax(-1)
84 |     return prior_idx.flatten(), gt_idx.flatten()
85 | 
86 | 
87 | def assign(
88 |     predictions,
89 |     targets,
90 |     img_w,
91 |     img_h,
92 |     distance_cost_weight=3.,
93 |     cls_cost_weight=1.,
94 | ):
95 |     '''
96 |     dynamically computes matching based on the cost, including cls cost and lane similarity cost
97 |     Args:
98 |         predictions (Tensor): predictions predicted by each stage, shape: (num_priors, 78)
99 |         targets (Tensor): lane targets, shape: (num_targets, 78)
100 |     Returns:
101 |         matched_row_inds (Tensor): matched predictions, shape: (num_targets)
102 |         matched_col_inds (Tensor): matched targets, shape: (num_targets)
103 |     '''
104 |     predictions = predictions.detach().clone()
105 |     predictions[:, 3] *= (img_w - 1)
106 |     predictions[:, 6:] *= (img_w - 1)
107 |     targets = targets.detach().clone()
108 | 
109 |     # distances cost
110 |     distances_score = distance_cost(predictions, targets, img_w)
111 |     distances_score = 1 - (distances_score / torch.max(distances_score)
112 |                            ) + 1e-2  # normalize the distance
113 | 
114 |     # classification cost
115 |     cls_score = focal_cost(predictions[:, :2], targets[:, 1].long())
116 |     num_priors = predictions.shape[0]
117 |     num_targets = targets.shape[0]
118 | 
119 |     target_start_xys = targets[:, 2:4]  # num_targets, 2
120 |     target_start_xys[..., 0] *= (img_h - 1)
121 |     prediction_start_xys = predictions[:, 2:4]
122 |     prediction_start_xys[..., 0] *= (img_h - 1)
123 | 
124 |     start_xys_score = torch.cdist(prediction_start_xys, target_start_xys,
125 |                                   p=2).reshape(num_priors, num_targets)
126 |     start_xys_score = (1 - start_xys_score / torch.max(start_xys_score)) + 1e-2
127 | 
128 |     target_thetas = targets[:, 4].unsqueeze(-1)
129 |     theta_score = torch.cdist(predictions[:, 4].unsqueeze(-1),
130 |                               target_thetas,
131 |                               p=1).reshape(num_priors, num_targets) * 180
132 |     theta_score = (1 - theta_score / torch.max(theta_score)) + 1e-2
133 | 
134 |     cost = -(distances_score * start_xys_score * theta_score
135 |              )**2 * distance_cost_weight + cls_score * cls_cost_weight
136 | 
137 |     iou = line_iou(predictions[..., 6:], targets[..., 6:], img_w, aligned=False)
138 |     matched_row_inds, matched_col_inds = dynamic_k_assign(cost, iou)
139 | 
140 |     return matched_row_inds, matched_col_inds
141 | 
--------------------------------------------------------------------------------
/flamnet/models/utils/roi_gather.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from mmcv.cnn import ConvModule
5 | 
6 | 
7 | def LinearModule(hidden_dim):
8 |     return nn.ModuleList(
9 |         [nn.Linear(hidden_dim, hidden_dim),
10 |          nn.ReLU(inplace=True)])
11 | 
12 | 
13 | class FeatureResize(nn.Module):
14 |     def __init__(self, size=(10, 25)):
15 |         super(FeatureResize,
self).__init__() 16 | self.size = size 17 | 18 | def forward(self, x): 19 | x = F.interpolate(x, self.size) 20 | return x.flatten(2) 21 | 22 | 23 | class ContextBlock(nn.Module): 24 | def __init__(self,inplanes,ratio,pooling_type='att', 25 | fusion_types=('channel_add', )): 26 | super(ContextBlock, self).__init__() 27 | valid_fusion_types = ['channel_add', 'channel_mul'] 28 | 29 | assert pooling_type in ['avg', 'att'] 30 | assert isinstance(fusion_types, (list, tuple)) 31 | assert all([f in valid_fusion_types for f in fusion_types]) 32 | assert len(fusion_types) > 0, 'at least one fusion should be used' 33 | 34 | self.inplanes = inplanes 35 | self.ratio = ratio 36 | self.planes = int(inplanes * ratio) 37 | self.pooling_type = pooling_type 38 | self.fusion_types = fusion_types 39 | 40 | if pooling_type == 'att': 41 | self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1) 42 | self.softmax = nn.Softmax(dim=2) 43 | else: 44 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 45 | if 'channel_add' in fusion_types: 46 | self.channel_add_conv = nn.Sequential( 47 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 48 | nn.LayerNorm([self.planes, 1, 1]), 49 | nn.ReLU(inplace=True), # yapf: disable 50 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 51 | else: 52 | self.channel_add_conv = None 53 | if 'channel_mul' in fusion_types: 54 | self.channel_mul_conv = nn.Sequential( 55 | nn.Conv2d(self.inplanes, self.planes, kernel_size=1), 56 | nn.LayerNorm([self.planes, 1, 1]), 57 | nn.ReLU(inplace=True), # yapf: disable 58 | nn.Conv2d(self.planes, self.inplanes, kernel_size=1)) 59 | else: 60 | self.channel_mul_conv = None 61 | 62 | 63 | def spatial_pool(self, x): 64 | batch, channel, height, width = x.size() 65 | if self.pooling_type == 'att': 66 | input_x = x 67 | # [N, C, H * W] 68 | input_x = input_x.view(batch, channel, height * width) 69 | # [N, 1, C, H * W] 70 | input_x = input_x.unsqueeze(1) 71 | # [N, 1, H, W] 72 | context_mask = self.conv_mask(x) 73 | # [N, 1, H * W] 74 | context_mask = context_mask.view(batch, 1, height * width) 75 | # [N, 1, H * W] 76 | context_mask = self.softmax(context_mask) 77 | # [N, 1, H * W, 1] 78 | context_mask = context_mask.unsqueeze(-1) 79 | # [N, 1, C, 1] 80 | context = torch.matmul(input_x, context_mask) 81 | # [N, C, 1, 1] 82 | context = context.view(batch, channel, 1, 1) 83 | else: 84 | # [N, C, 1, 1] 85 | context = self.avg_pool(x) 86 | return context 87 | 88 | def forward(self, x): 89 | # [N, C, 1, 1] 90 | context = self.spatial_pool(x) 91 | out = x 92 | if self.channel_mul_conv is not None: 93 | # [N, C, 1, 1] 94 | channel_mul_term = torch.sigmoid(self.channel_mul_conv(context)) 95 | out = out * channel_mul_term 96 | if self.channel_add_conv is not None: 97 | # [N, C, 1, 1] 98 | channel_add_term = self.channel_add_conv(context) 99 | out = out + channel_add_term 100 | return out 101 | 102 | 103 | class anchor_pre(nn.Module): 104 | 105 | def __init__(self, 106 | in_channels, 107 | num_priors, 108 | sample_points, 109 | fc_hidden_dim, 110 | refine_layers, 111 | mid_channels=48): 112 | super(anchor_pre, self).__init__() 113 | self.in_channels = in_channels 114 | self.num_priors = num_priors 115 | 116 | self.convs = nn.ModuleList() 117 | self.catconv = nn.ModuleList() 118 | for i in range(refine_layers): 119 | self.convs.append( 120 | ConvModule(in_channels, 121 | mid_channels, (9, 1), 122 | padding=(4, 0), 123 | bias=False, 124 | norm_cfg=dict(type='BN'))) 125 | 126 | self.catconv.append( 127 | ConvModule(mid_channels * (i + 1), 128 | in_channels, (9, 1), 129 | 
padding=(4, 0),
130 |                            bias=False,
131 |                            norm_cfg=dict(type='BN')))
132 | 
133 |         self.fc = nn.Linear(sample_points * fc_hidden_dim, fc_hidden_dim)
134 | 
135 |         self.fc_norm = nn.LayerNorm(fc_hidden_dim)
136 | 
137 |         self.CB = ContextBlock(inplanes=64, ratio=1. / 16., pooling_type='att')
138 | 
139 |     def roi_fea(self, x, layer_index):
140 |         feats = []
141 |         for i, feature in enumerate(x):
142 |             feat_trans = self.convs[i](feature)
143 |             feats.append(feat_trans)
144 |         cat_feat = torch.cat(feats, dim=1)
145 |         cat_feat = self.catconv[layer_index](cat_feat)
146 |         return cat_feat
147 | 
148 |     def forward(self, roi_features, x, layer_index):
149 |         '''
150 |         Args:
151 |             roi_features: prior feature, shape: (Batch * num_priors, prior_feat_channel, sample_point, 1)
152 |             x: feature map
153 |             layer_index: currently on which layer to refine
154 |         Return:
155 |             roi: prior features with gathered global information, shape: (Batch, num_priors, fc_hidden_dim)
156 |         '''
157 |         roi = self.roi_fea(roi_features, layer_index)
158 |         roi = self.CB(roi)
159 |         bs = x.size(0)
160 |         roi = roi.contiguous().view(bs * self.num_priors, -1)
161 | 
162 |         roi = F.relu(self.fc_norm(self.fc(roi)))
163 |         roi = roi.view(bs, self.num_priors, -1)
164 | 
165 |         return roi
166 | 
167 | 
168 | class HIAM(nn.Module):
169 |     def __init__(self,
170 |                  in_channels,
171 |                  num_priors,
172 |                  sample_points,
173 |                  fc_hidden_dim,
174 |                  refine_layers,
175 |                  mid_channels=48):
176 |         super(HIAM, self).__init__()
177 |         self.in_channels = in_channels
178 |         self.num_priors = num_priors
179 |         self.sample_points = sample_points
180 |         self.f_key = ConvModule(in_channels=self.in_channels,
181 |                                 out_channels=self.in_channels,
182 |                                 kernel_size=1,
183 |                                 stride=1,
184 |                                 padding=0,
185 |                                 norm_cfg=dict(type='BN'))
186 | 
187 |         self.f_query = nn.Sequential(
188 |             nn.Conv1d(in_channels=self.sample_points//2,
189 |                       out_channels=self.sample_points//2,
190 |                       kernel_size=1,
191 |                       stride=1,
192 |                       padding=0,
193 |                       groups=self.sample_points//2),
194 |             nn.ReLU(),
195 |         )
196 |         self.f_value = nn.Conv2d(in_channels=self.in_channels,
197 |                                  out_channels=self.in_channels,
198 |                                  kernel_size=1,
199 |                                  stride=1,
200 |                                  padding=0)
201 |         self.anchor_Conv = nn.Sequential(
202 |             nn.Conv1d(in_channels=self.in_channels,
203 |                       out_channels=self.in_channels,
204 |                       kernel_size=5,
205 |                       stride=1,
206 |                       padding=2),
207 |             nn.BatchNorm1d(self.in_channels),
208 |             nn.ReLU()
209 |         )
210 | 
211 |         self.convs = nn.ModuleList()
212 |         self.catconv = nn.ModuleList()
213 |         for i in range(refine_layers):
214 |             self.convs.append(
215 |                 ConvModule(in_channels,
216 |                            mid_channels, (9, 1),
217 |                            padding=(4, 0),
218 |                            bias=False,
219 |                            norm_cfg=dict(type='BN')))
220 | 
221 |             self.catconv.append(
222 |                 ConvModule(mid_channels * (i + 1),  # concat the previous refine layers' feature maps with the current one (skipped for the first layer)
223 |                            in_channels, (9, 1),
224 |                            padding=(4, 0),
225 |                            bias=False,
226 |                            norm_cfg=dict(type='BN')))
227 | 
228 |         self.fc = nn.Linear(sample_points * fc_hidden_dim, fc_hidden_dim)
229 | 
230 |         self.fc_norm = nn.LayerNorm(fc_hidden_dim)
231 | 
232 |     def roi_fea(self, x, layer_index):
233 |         feats = []
234 |         for i, feature in enumerate(x):
235 |             feat_trans = self.convs[i](feature)
236 |             feats.append(feat_trans)
237 |         cat_feat = torch.cat(feats, dim=1)  # merge the per-layer feat_trans list into one tensor by concatenating along the channel dim
238 |         cat_feat = self.catconv[layer_index](cat_feat)
239 |         return cat_feat
240 | 
241 |     def forward(self, roi_features, x, layer_index):
242 |         '''
243 |         Args:
244 |             roi_features: prior feature, shape: (Batch * num_priors, prior_feat_channel, sample_point, 1)
245 |             x: feature map
246 |             layer_index: currently on which layer to refine
247 |         Return:
248 |             roi: prior features with gathered global information, shape: (Batch, num_priors, fc_hidden_dim)
249 |         '''
250 |         roi = self.roi_fea(roi_features, layer_index)  # roi shape: 576, 64, 36, 1
251 |         bs = x.size(0)
252 |         roi = roi.permute(0,2,1,3).contiguous().view(bs*self.num_priors, -1,self.in_channels)  # roi: (bs * num_priors, sample_points, in_channels), e.g. (576, 36, 64)
253 | 
254 |         query = roi[:,18:,:].clone()
255 |         value = self.f_value(x).permute(0,3,2,1).contiguous().view(bs*self.num_priors, -1, self.in_channels)  # resized to 10x25 and flattened into a 250x1 vector
256 |         query = self.f_query(query)  # query shape: 3, 192, 64
257 |         key = self.f_key(x).permute(0,2,1,3).contiguous().view(bs*self.num_priors, self.in_channels,-1)  # key shape: B, C, H, W
258 |         sim_map = torch.matmul(query, key)
259 |         sim_map = (self.in_channels**-.5) * sim_map
260 |         sim_map = sim_map.reshape(bs,self.num_priors,-1)
261 |         sim_map = F.softmax(sim_map, dim=-1)
262 |         sim_map = sim_map.reshape(bs*self.num_priors,18,-1)
263 | 
264 |         context = torch.matmul(sim_map, value)
265 |         # context = self.W(context)
266 | 
267 |         roi[:,18:,:] = roi[:,18:,:] + F.dropout(context, p=0.2, training=self.training)
268 |         roi = self.anchor_Conv(roi.permute(0,2,1)).permute(0,2,1)
269 | 
270 |         roi = roi.contiguous().view(bs * self.num_priors, -1)
271 | 
272 |         roi = F.relu(self.fc_norm(self.fc(roi)))
273 |         roi = roi.view(bs, self.num_priors, -1)
274 | 
275 |         return roi
276 | 
--------------------------------------------------------------------------------
/flamnet/models/utils/seg_decoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | 
4 | 
5 | class SegDecoder(nn.Module):
6 |     '''
7 |     Optional segmentation decoder.
8 |     '''
9 |     def __init__(self,
10 |                  image_height,
11 |                  image_width,
12 |                  num_class,
13 |                  prior_feat_channels=64,
14 |                  refine_layers=3):
15 |         super().__init__()
16 |         self.dropout = nn.Dropout2d(0.1)
17 |         self.conv = nn.Conv2d(prior_feat_channels * refine_layers, num_class,
18 |                               1)
19 |         self.image_height = image_height
20 |         self.image_width = image_width
21 | 
22 |     def forward(self, x):
23 |         x = self.dropout(x)
24 |         x = self.conv(x)
25 |         x = F.interpolate(x,
26 |                           size=[self.image_height, self.image_width],
27 |                           mode='bilinear',
28 |                           align_corners=False)
29 |         return x
--------------------------------------------------------------------------------
/flamnet/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from .nms import nms
2 | 
3 | __all__ = ['nms']
4 | 
--------------------------------------------------------------------------------
/flamnet/ops/csrc/nms.cpp:
--------------------------------------------------------------------------------
1 | /* Copyright (c) 2018, Grégoire Payen de La Garanderie, Durham University
2 |  * All rights reserved.
3 |  *
4 |  * Redistribution and use in source and binary forms, with or without
5 |  * modification, are permitted provided that the following conditions are met:
6 |  *
7 |  * * Redistributions of source code must retain the above copyright notice, this
8 |  *   list of conditions and the following disclaimer.
9 |  *
10 |  * * Redistributions in binary form must reproduce the above copyright notice,
11 |  *   this list of conditions and the following disclaimer in the documentation
12 |  *   and/or other materials provided with the distribution.
13 |  *
14 |  * * Neither the name of the copyright holder nor the names of its
15 |  *   contributors may be used to endorse or promote products derived from
16 |  *   this software without specific prior written permission.
17 |  *
18 |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |  */
29 | 
30 | #include <torch/extension.h>
31 | #include <ATen/ATen.h>
32 | #include <vector>
33 | 
34 | std::vector<at::Tensor> nms_cuda_forward(
35 |         at::Tensor boxes,
36 |         at::Tensor idx,
37 |         float nms_overlap_thresh,
38 |         unsigned long top_k);
39 | 
40 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
41 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
42 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
43 | 
44 | std::vector<at::Tensor> nms_forward(
45 |         at::Tensor boxes,
46 |         at::Tensor scores,
47 |         float thresh,
48 |         unsigned long top_k) {
49 | 
50 | 
51 |     auto idx = std::get<1>(scores.sort(0,true));
52 | 
53 |     CHECK_INPUT(boxes);
54 |     CHECK_INPUT(idx);
55 | 
56 |     return nms_cuda_forward(boxes, idx, thresh, top_k);
57 | }
58 | 
59 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
60 |     m.def("nms_forward", &nms_forward, "NMS");
61 | }
62 | 
63 | 
--------------------------------------------------------------------------------
/flamnet/ops/csrc/nms_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <ATen/cuda/CUDAContext.h>
3 | 
4 | #include <THC/THC.h>
5 | #include <THC/THCDeviceUtils.cuh>
6 | #include <vector>
7 | #include <iostream>
8 | 
9 | // Hard-coded maximum. Increase if needed.
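// Proposal layout assumed by devIoU() below: each lane proposal occupies
// PROP_SIZE = 5 + N_OFFSETS values -- a 5-element header (element [2] holds the
// normalized start y, element [4] the valid length in strips) followed by
// N_OFFSETS per-row x coordinates that are compared in the distance test.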
10 | #define MAX_COL_BLOCKS 1000
11 | #define STRIDE 4
12 | #define N_OFFSETS 72 // if you use more than 73 offsets you will have to adjust this value
13 | #define N_STRIPS (N_OFFSETS - 1)
14 | #define PROP_SIZE (5 + N_OFFSETS)
15 | #define DATASET_OFFSET 0
16 | 
17 | #define DIVUP(m,n) (((m)+(n)-1) / (n))
18 | int64_t const threadsPerBlock = sizeof(unsigned long long) * 8;
19 | 
20 | // The functions below originates from Fast R-CNN
21 | // See https://github.com/rbgirshick/py-faster-rcnn
22 | // Copyright (c) 2015 Microsoft
23 | // Licensed under The MIT License
24 | // Written by Shaoqing Ren
25 | 
26 | template <typename scalar_t>
27 | // __device__ inline scalar_t devIoU(scalar_t const * const a, scalar_t const * const b) {
28 | __device__ inline bool devIoU(scalar_t const * const a, scalar_t const * const b, const float threshold) {
29 |   const int start_a = (int) (a[2] * N_STRIPS - DATASET_OFFSET + 0.5); // 0.5 rounding trick
30 |   const int start_b = (int) (b[2] * N_STRIPS - DATASET_OFFSET + 0.5);
31 |   const int start = max(start_a, start_b);
32 |   const int end_a = start_a + a[4] - 1 + 0.5 - ((a[4] - 1) < 0); // - (x<0) trick to adjust for negative numbers (in case length is 0)
33 |   const int end_b = start_b + b[4] - 1 + 0.5 - ((b[4] - 1) < 0);
34 |   const int end = min(min(end_a, end_b), N_OFFSETS - 1);
35 |   // if (end < start) return 1e9;
36 |   if (end < start) return false;
37 |   scalar_t dist = 0;
38 |   for(unsigned char i = 5 + start; i <= 5 + end; ++i) {
39 |     if (a[i] < b[i]) {
40 |       dist += b[i] - a[i];
41 |     } else {
42 |       dist += a[i] - b[i];
43 |     }
44 |   }
45 |   // return (dist / (end - start + 1)) < threshold;
46 |   return dist < (threshold * (end - start + 1));
47 |   // return dist / (end - start + 1);
48 | }
49 | 
50 | template <typename scalar_t>
51 | __global__ void nms_kernel(const int64_t n_boxes, const scalar_t nms_overlap_thresh,
52 |                            const scalar_t *dev_boxes, const int64_t *idx, int64_t *dev_mask) {
53 |   const int64_t row_start = blockIdx.y;
54 |   const int64_t col_start = blockIdx.x;
55 | 
56 |   if (row_start > col_start) return;
57 | 
58 |   const int row_size =
59 |         min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
60 |   const int col_size =
61 |         min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
62 | 
63 |   __shared__ scalar_t block_boxes[threadsPerBlock * PROP_SIZE];
64 |   if (threadIdx.x < col_size) {
65 |     for (int i = 0; i < PROP_SIZE; ++i) {
66 |       block_boxes[threadIdx.x * PROP_SIZE + i] = dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * PROP_SIZE + i];
67 |     }
68 |     // block_boxes[threadIdx.x * 4 + 0] =
69 |     //    dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 0];
70 |     // block_boxes[threadIdx.x * 4 + 1] =
71 |     //    dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 1];
72 |     // block_boxes[threadIdx.x * 4 + 2] =
73 |     //    dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 2];
74 |     // block_boxes[threadIdx.x * 4 + 3] =
75 |     //    dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 3];
76 |   }
77 |   __syncthreads();
78 | 
79 |   if (threadIdx.x < row_size) {
80 |     const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
81 |     const scalar_t *cur_box = dev_boxes + idx[cur_box_idx] * PROP_SIZE;
82 |     int i = 0;
83 |     unsigned long long t = 0;
84 |     int start = 0;
85 |     if (row_start == col_start) {
86 |       start = threadIdx.x + 1;
87 |     }
88 |     for (i = start; i < col_size; i++) {
89 |       if (devIoU(cur_box, block_boxes + i * PROP_SIZE, nms_overlap_thresh)) {
90 |         t |= 1ULL << i;
91 |       }
92 |     }
93 |     const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
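    // Bit i of t now marks that proposal i of this column block overlaps this
    // thread's proposal beyond the threshold; the 64-bit word is stored so
    // that nms_collect() can later walk the full suppression mask serially.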
94 |     dev_mask[cur_box_idx * col_blocks + col_start] = t;
95 |   }
96 | }
97 | 
98 | 
99 | __global__ void nms_collect(const int64_t boxes_num, const int64_t col_blocks, int64_t top_k, const int64_t *idx, const int64_t *mask, int64_t *keep, int64_t *parent_object_index, int64_t *num_to_keep) {
100 |   int64_t remv[MAX_COL_BLOCKS];
101 |   int64_t num_to_keep_ = 0;
102 | 
103 |   for (int i = 0; i < col_blocks; i++) {
104 |     remv[i] = 0;
105 |   }
106 | 
107 |   for (int i = 0; i < boxes_num; ++i) {
108 |     parent_object_index[i] = 0;
109 |   }
110 | 
111 |   for (int i = 0; i < boxes_num; i++) {
112 |     int nblock = i / threadsPerBlock;
113 |     int inblock = i % threadsPerBlock;
114 | 
115 | 
116 |     if (!(remv[nblock] & (1ULL << inblock))) {
117 |       int64_t idxi = idx[i];
118 |       keep[num_to_keep_] = idxi;
119 |       const int64_t *p = &mask[0] + i * col_blocks;
120 |       for (int j = nblock; j < col_blocks; j++) {
121 |         remv[j] |= p[j];
122 |       }
123 |       for (int j = i; j < boxes_num; j++) {
124 |         int nblockj = j / threadsPerBlock;
125 |         int inblockj = j % threadsPerBlock;
126 |         if (p[nblockj] & (1ULL << inblockj))
127 |           parent_object_index[idx[j]] = num_to_keep_+1;
128 |       }
129 |       parent_object_index[idx[i]] = num_to_keep_+1;
130 | 
131 |       num_to_keep_++;
132 | 
133 |       if (num_to_keep_==top_k)
134 |         break;
135 |     }
136 |   }
137 | 
138 |   // Initialize the rest of the keep array to avoid uninitialized values.
139 |   for (int i = num_to_keep_; i < boxes_num; ++i)
140 |     keep[i] = 0;
141 | 
142 |   *num_to_keep = min(top_k,num_to_keep_);
143 | }
144 | 
145 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
146 | 
147 | std::vector<at::Tensor> nms_cuda_forward(
148 |         at::Tensor boxes,
149 |         at::Tensor idx,
150 |         float nms_overlap_thresh,
151 |         unsigned long top_k) {
152 | 
153 |   const auto boxes_num = boxes.size(0);
154 |   TORCH_CHECK(boxes.size(1) == PROP_SIZE, "Wrong number of offsets. Please adjust `PROP_SIZE`");
155 | 
156 |   const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
157 | 
158 |   AT_ASSERTM (col_blocks < MAX_COL_BLOCKS, "The number of column blocks must be less than MAX_COL_BLOCKS. Increase the MAX_COL_BLOCKS constant if needed.");
159 | 
160 |   auto longOptions = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kLong);
161 |   auto mask = at::empty({boxes_num * col_blocks}, longOptions);
162 | 
163 |   dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
164 |               DIVUP(boxes_num, threadsPerBlock));
165 |   dim3 threads(threadsPerBlock);
166 | 
167 |   CHECK_CONTIGUOUS(boxes);
168 |   CHECK_CONTIGUOUS(idx);
169 |   CHECK_CONTIGUOUS(mask);
170 | 
171 |   AT_DISPATCH_FLOATING_TYPES(boxes.type(), "nms_cuda_forward", ([&] {
172 |     nms_kernel<scalar_t><<<blocks, threads>>>(boxes_num,
173 |                                               (scalar_t)nms_overlap_thresh,
174 |                                               boxes.data<scalar_t>(),
175 |                                               idx.data<int64_t>(),
176 |                                               mask.data<int64_t>());
177 |   }));
178 | 
179 |   auto keep = at::empty({boxes_num}, longOptions);
180 |   auto parent_object_index = at::empty({boxes_num}, longOptions);
181 |   auto num_to_keep = at::empty({}, longOptions);
182 | 
183 |   nms_collect<<<1, 1>>>(boxes_num, col_blocks, top_k,
184 |                         idx.data<int64_t>(),
185 |                         mask.data<int64_t>(),
186 |                         keep.data<int64_t>(),
187 |                         parent_object_index.data<int64_t>(),
188 |                         num_to_keep.data<int64_t>());
189 | 
190 | 
191 |   return {keep,num_to_keep,parent_object_index};
192 | }
193 | 
194 | 
--------------------------------------------------------------------------------
/flamnet/ops/nms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, Grégoire Payen de La Garanderie, Durham University
2 | # All rights reserved.
3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source FLAMNet must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | from . import nms_impl 30 | 31 | 32 | def nms(boxes, scores, overlap, top_k): 33 | return nms_impl.nms_forward(boxes, scores, overlap, top_k) 34 | -------------------------------------------------------------------------------- /flamnet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config 2 | from .registry import Registry, build_from_cfg 3 | -------------------------------------------------------------------------------- /flamnet/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 
2 | import ast
3 | import os.path as osp
4 | import shutil
5 | import sys
6 | import tempfile
7 | from argparse import Action, ArgumentParser
8 | from collections import abc
9 | from importlib import import_module
10 | 
11 | from addict import Dict
12 | from yapf.yapflib.yapf_api import FormatCode
13 | 
14 | BASE_KEY = '_base_'
15 | DELETE_KEY = '_delete_'
16 | RESERVED_KEYS = ['filename', 'text', 'pretty_text']
17 | 
18 | 
19 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
20 |     if not osp.isfile(filename):
21 |         raise FileNotFoundError(msg_tmpl.format(filename))
22 | 
23 | 
24 | class ConfigDict(Dict):
25 |     def __missing__(self, name):
26 |         raise KeyError(name)
27 | 
28 |     def __getattr__(self, name):
29 |         try:
30 |             value = super(ConfigDict, self).__getattr__(name)
31 |         except KeyError:
32 |             ex = AttributeError(f"'{self.__class__.__name__}' object has no "
33 |                                 f"attribute '{name}'")
34 |         except Exception as e:
35 |             ex = e
36 |         else:
37 |             return value
38 |         raise ex
39 | 
40 | 
41 | def add_args(parser, cfg, prefix=''):
42 |     for k, v in cfg.items():
43 |         if isinstance(v, str):
44 |             parser.add_argument('--' + prefix + k)
45 |         elif isinstance(v, int):
46 |             parser.add_argument('--' + prefix + k, type=int)
47 |         elif isinstance(v, float):
48 |             parser.add_argument('--' + prefix + k, type=float)
49 |         elif isinstance(v, bool):
50 |             parser.add_argument('--' + prefix + k, action='store_true')
51 |         elif isinstance(v, dict):
52 |             add_args(parser, v, prefix + k + '.')
53 |         elif isinstance(v, abc.Iterable):
54 |             parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
55 |         else:
56 |             print(f'cannot parse key {prefix + k} of type {type(v)}')
57 |     return parser
58 | 
59 | 
60 | class Config:
61 |     """A facility for config and config files.
62 |     It supports common file formats as configs: python/json/yaml. The interface
63 |     is the same as a dict object and also allows accessing config values as
64 |     attributes.
65 | Example: 66 | >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) 67 | >>> cfg.a 68 | 1 69 | >>> cfg.b 70 | {'b1': [0, 1]} 71 | >>> cfg.b.b1 72 | [0, 1] 73 | >>> cfg = Config.fromfile('tests/data/config/a.py') 74 | >>> cfg.filename 75 | "/home/kchen/projects/mmcv/tests/data/config/a.py" 76 | >>> cfg.item4 77 | 'test' 78 | >>> cfg 79 | "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " 80 | "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" 81 | """ 82 | @staticmethod 83 | def _validate_py_syntax(filename): 84 | with open(filename) as f: 85 | content = f.read() 86 | try: 87 | ast.parse(content) 88 | except SyntaxError: 89 | raise SyntaxError('There are syntax errors in config ' 90 | f'file {filename}') 91 | 92 | @staticmethod 93 | def _file2dict(filename): 94 | filename = osp.abspath(osp.expanduser(filename)) 95 | check_file_exist(filename) 96 | if filename.endswith('.py'): 97 | with tempfile.TemporaryDirectory() as temp_config_dir: 98 | temp_config_file = tempfile.NamedTemporaryFile( 99 | dir=temp_config_dir, suffix='.py') 100 | temp_config_name = osp.basename(temp_config_file.name) 101 | shutil.copyfile(filename, 102 | osp.join(temp_config_dir, temp_config_name)) 103 | temp_module_name = osp.splitext(temp_config_name)[0] 104 | sys.path.insert(0, temp_config_dir) 105 | Config._validate_py_syntax(filename) 106 | mod = import_module(temp_module_name) 107 | sys.path.pop(0) 108 | cfg_dict = { 109 | name: value 110 | for name, value in mod.__dict__.items() 111 | if not name.startswith('__') 112 | } 113 | # delete imported module 114 | del sys.modules[temp_module_name] 115 | # close temp file 116 | temp_config_file.close() 117 | elif filename.endswith(('.yml', '.yaml', '.json')): 118 | import mmcv 119 | cfg_dict = mmcv.load(filename) 120 | else: 121 | raise IOError('Only py/yml/yaml/json type are supported now!') 122 | 123 | cfg_text = '' 124 | with open(filename, 'r') as f: 125 | cfg_text += f.read() 126 | 127 | if BASE_KEY in cfg_dict: 128 | cfg_dir = osp.dirname(filename) 129 | base_filename = cfg_dict.pop(BASE_KEY) 130 | base_filename = base_filename if isinstance( 131 | base_filename, list) else [base_filename] 132 | 133 | cfg_dict_list = list() 134 | cfg_text_list = list() 135 | for f in base_filename: 136 | _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) 137 | cfg_dict_list.append(_cfg_dict) 138 | cfg_text_list.append(_cfg_text) 139 | 140 | base_cfg_dict = dict() 141 | for c in cfg_dict_list: 142 | if len(base_cfg_dict.keys() & c.keys()) > 0: 143 | raise KeyError('Duplicate key is not allowed among bases') 144 | base_cfg_dict.update(c) 145 | 146 | base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) 147 | cfg_dict = base_cfg_dict 148 | 149 | # merge cfg_text 150 | cfg_text_list.append(cfg_text) 151 | cfg_text = '\n'.join(cfg_text_list) 152 | 153 | return cfg_dict, cfg_text 154 | 155 | @staticmethod 156 | def _merge_a_into_b(a, b): 157 | # merge dict `a` into dict `b` (non-inplace). values in `a` will 158 | # overwrite `b`. 159 | # copy first to avoid inplace modification 160 | b = b.copy() 161 | for k, v in a.items(): 162 | if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False): 163 | if not isinstance(b[k], dict): 164 | raise TypeError( 165 | f'{k}={v} in child config cannot inherit from base ' 166 | f'because {k} is a dict in the child config but is of ' 167 | f'type {type(b[k])} in base config. 
You may set ' 168 | f'`{DELETE_KEY}=True` to ignore the base config') 169 | b[k] = Config._merge_a_into_b(v, b[k]) 170 | else: 171 | b[k] = v 172 | return b 173 | 174 | @staticmethod 175 | def fromfile(filename): 176 | cfg_dict, cfg_text = Config._file2dict(filename) 177 | return Config(cfg_dict, cfg_text=cfg_text, filename=filename) 178 | 179 | @staticmethod 180 | def auto_argparser(description=None): 181 | """Generate argparser from config file automatically (experimental) 182 | """ 183 | partial_parser = ArgumentParser(description=description) 184 | partial_parser.add_argument('config', help='config file path') 185 | cfg_file = partial_parser.parse_known_args()[0].config 186 | cfg = Config.fromfile(cfg_file) 187 | parser = ArgumentParser(description=description) 188 | parser.add_argument('config', help='config file path') 189 | add_args(parser, cfg) 190 | return parser, cfg 191 | 192 | def __init__(self, cfg_dict=None, cfg_text=None, filename=None): 193 | if cfg_dict is None: 194 | cfg_dict = dict() 195 | elif not isinstance(cfg_dict, dict): 196 | raise TypeError('cfg_dict must be a dict, but ' 197 | f'got {type(cfg_dict)}') 198 | for key in cfg_dict: 199 | if key in RESERVED_KEYS: 200 | raise KeyError(f'{key} is reserved for config file') 201 | 202 | super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict)) 203 | super(Config, self).__setattr__('_filename', filename) 204 | if cfg_text: 205 | text = cfg_text 206 | elif filename: 207 | with open(filename, 'r') as f: 208 | text = f.read() 209 | else: 210 | text = '' 211 | super(Config, self).__setattr__('_text', text) 212 | 213 | @property 214 | def filename(self): 215 | return self._filename 216 | 217 | @property 218 | def text(self): 219 | return self._text 220 | 221 | @property 222 | def pretty_text(self): 223 | 224 | indent = 4 225 | 226 | def _indent(s_, num_spaces): 227 | s = s_.split('\n') 228 | if len(s) == 1: 229 | return s_ 230 | first = s.pop(0) 231 | s = [(num_spaces * ' ') + line for line in s] 232 | s = '\n'.join(s) 233 | s = first + '\n' + s 234 | return s 235 | 236 | def _format_basic_types(k, v, use_mapping=False): 237 | if isinstance(v, str): 238 | v_str = f"'{v}'" 239 | else: 240 | v_str = str(v) 241 | 242 | if use_mapping: 243 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 244 | attr_str = f'{k_str}: {v_str}' 245 | else: 246 | attr_str = f'{str(k)}={v_str}' 247 | attr_str = _indent(attr_str, indent) 248 | 249 | return attr_str 250 | 251 | def _format_list(k, v, use_mapping=False): 252 | # check if all items in the list are dict 253 | if all(isinstance(_, dict) for _ in v): 254 | v_str = '[\n' 255 | v_str += '\n'.join( 256 | f'dict({_indent(_format_dict(v_), indent)}),' 257 | for v_ in v).rstrip(',') 258 | if use_mapping: 259 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 260 | attr_str = f'{k_str}: {v_str}' 261 | else: 262 | attr_str = f'{str(k)}={v_str}' 263 | attr_str = _indent(attr_str, indent) + ']' 264 | else: 265 | attr_str = _format_basic_types(k, v, use_mapping) 266 | return attr_str 267 | 268 | def _contain_invalid_identifier(dict_str): 269 | contain_invalid_identifier = False 270 | for key_name in dict_str: 271 | contain_invalid_identifier |= \ 272 | (not str(key_name).isidentifier()) 273 | return contain_invalid_identifier 274 | 275 | def _format_dict(input_dict, outest_level=False): 276 | r = '' 277 | s = [] 278 | 279 | use_mapping = _contain_invalid_identifier(input_dict) 280 | if use_mapping: 281 | r += '{' 282 | for idx, (k, v) in enumerate(input_dict.items()): 283 | is_last = idx >= 
len(input_dict) - 1 284 | end = '' if outest_level or is_last else ',' 285 | if isinstance(v, dict): 286 | v_str = '\n' + _format_dict(v) 287 | if use_mapping: 288 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 289 | attr_str = f'{k_str}: dict({v_str}' 290 | else: 291 | attr_str = f'{str(k)}=dict({v_str}' 292 | attr_str = _indent(attr_str, indent) + ')' + end 293 | elif isinstance(v, list): 294 | attr_str = _format_list(k, v, use_mapping) + end 295 | else: 296 | attr_str = _format_basic_types(k, v, use_mapping) + end 297 | 298 | s.append(attr_str) 299 | r += '\n'.join(s) 300 | if use_mapping: 301 | r += '}' 302 | return r 303 | 304 | cfg_dict = self._cfg_dict.to_dict() 305 | text = _format_dict(cfg_dict, outest_level=True) 306 | # copied from setup.cfg 307 | yapf_style = dict(based_on_style='pep8', 308 | blank_line_before_nested_class_or_def=True, 309 | split_before_expression_after_opening_paren=True) 310 | text, _ = FormatCode(text, style_config=yapf_style, verify=True) 311 | 312 | return text 313 | 314 | def __repr__(self): 315 | return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' 316 | 317 | def __len__(self): 318 | return len(self._cfg_dict) 319 | 320 | def __getattr__(self, name): 321 | return getattr(self._cfg_dict, name) 322 | 323 | def __getitem__(self, name): 324 | return self._cfg_dict.__getitem__(name) 325 | 326 | def __setattr__(self, name, value): 327 | if isinstance(value, dict): 328 | value = ConfigDict(value) 329 | self._cfg_dict.__setattr__(name, value) 330 | 331 | def __setitem__(self, name, value): 332 | if isinstance(value, dict): 333 | value = ConfigDict(value) 334 | self._cfg_dict.__setitem__(name, value) 335 | 336 | def __iter__(self): 337 | return iter(self._cfg_dict) 338 | 339 | def haskey(self, name): 340 | return hasattr(self._cfg_dict, name) 341 | 342 | def dump(self, file=None): 343 | cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict() 344 | if self.filename.endswith('.py'): 345 | if file is None: 346 | return self.pretty_text 347 | else: 348 | with open(file, 'w') as f: 349 | f.write(self.pretty_text) 350 | else: 351 | import mmcv 352 | if file is None: 353 | file_format = self.filename.split('.')[-1] 354 | return mmcv.dump(cfg_dict, file_format=file_format) 355 | else: 356 | mmcv.dump(cfg_dict, file) 357 | 358 | def has_attr_in_cfg(self, name): 359 | return hasattr(self._cfg_dict, name) 360 | 361 | def merge_from_dict(self, options): 362 | """Merge list into cfg_dict 363 | Merge the dict parsed by MultipleKVAction into this cfg. 364 | Examples: 365 | >>> options = {'model.backbone.depth': 50, 366 | ... 'model.backbone.with_cp':True} 367 | >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) 368 | >>> cfg.merge_from_dict(options) 369 | >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 370 | >>> assert cfg_dict == dict( 371 | ... model=dict(backbone=dict(depth=50, with_cp=True))) 372 | Args: 373 | options (dict): dict of configs to merge from. 
374 | """ 375 | option_cfg_dict = {} 376 | for full_key, v in options.items(): 377 | d = option_cfg_dict 378 | key_list = full_key.split('.') 379 | for subkey in key_list[:-1]: 380 | d.setdefault(subkey, ConfigDict()) 381 | d = d[subkey] 382 | subkey = key_list[-1] 383 | d[subkey] = v 384 | 385 | cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 386 | super(Config, self).__setattr__( 387 | '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict)) 388 | 389 | 390 | class DictAction(Action): 391 | """ 392 | argparse action to split an argument into KEY=VALUE form 393 | on the first = and append to a dictionary. List options should 394 | be passed as comma separated values, i.e KEY=V1,V2,V3 395 | """ 396 | @staticmethod 397 | def _parse_int_float_bool(val): 398 | try: 399 | return int(val) 400 | except ValueError: 401 | pass 402 | try: 403 | return float(val) 404 | except ValueError: 405 | pass 406 | if val.lower() in ['true', 'false']: 407 | return True if val.lower() == 'true' else False 408 | return val 409 | 410 | def __call__(self, parser, namespace, values, option_string=None): 411 | options = {} 412 | for kv in values: 413 | key, val = kv.split('=', maxsplit=1) 414 | val = [self._parse_int_float_bool(v) for v in val.split(',')] 415 | if len(val) == 1: 416 | val = val[0] 417 | options[key] = val 418 | setattr(namespace, self.dest, options) 419 | -------------------------------------------------------------------------------- /flamnet/utils/culane_metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from functools import partial 4 | 5 | import cv2 6 | import numpy as np 7 | from tqdm import tqdm 8 | from p_tqdm import t_map, p_map 9 | from scipy.interpolate import splprep, splev 10 | from scipy.optimize import linear_sum_assignment 11 | from shapely.geometry import LineString, Polygon 12 | 13 | 14 | def draw_lane(lane, img=None, img_shape=None, width=30): 15 | if img is None: 16 | img = np.zeros(img_shape, dtype=np.uint8) 17 | lane = lane.astype(np.int32) 18 | for p1, p2 in zip(lane[:-1], lane[1:]): 19 | cv2.line(img, 20 | tuple(p1), 21 | tuple(p2), 22 | color=(255, 255, 255), 23 | thickness=width) 24 | return img 25 | 26 | 27 | def discrete_cross_iou(xs, ys, width=30, img_shape=(590, 1640, 3)): 28 | xs = [draw_lane(lane, img_shape=img_shape, width=width) > 0 for lane in xs] 29 | ys = [draw_lane(lane, img_shape=img_shape, width=width) > 0 for lane in ys] 30 | 31 | ious = np.zeros((len(xs), len(ys))) 32 | for i, x in enumerate(xs): 33 | for j, y in enumerate(ys): 34 | ious[i, j] = (x & y).sum() / (x | y).sum() 35 | return ious 36 | 37 | 38 | def continuous_cross_iou(xs, ys, width=30, img_shape=(590, 1640, 3)): 39 | h, w, _ = img_shape 40 | image = Polygon([(0, 0), (0, h - 1), (w - 1, h - 1), (w - 1, 0)]) 41 | xs = [ 42 | LineString(lane).buffer(distance=width / 2., cap_style=1, 43 | join_style=2).intersection(image) 44 | for lane in xs 45 | ] 46 | ys = [ 47 | LineString(lane).buffer(distance=width / 2., cap_style=1, 48 | join_style=2).intersection(image) 49 | for lane in ys 50 | ] 51 | 52 | ious = np.zeros((len(xs), len(ys))) 53 | for i, x in enumerate(xs): 54 | for j, y in enumerate(ys): 55 | ious[i, j] = x.intersection(y).area / x.union(y).area 56 | 57 | return ious 58 | 59 | 60 | def interp(points, n=50): 61 | x = [x for x, _ in points] 62 | y = [y for _, y in points] 63 | tck, u = splprep([x, y], s=0, t=n, k=min(3, len(points) - 1)) 64 | 65 | u = np.linspace(0., 1., num=(len(u) - 1) * n + 1) 
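    # Resample the fitted parametric spline so that every original segment is
    # subdivided into n steps; the densified polyline is what draw_lane later
    # rasterizes with cv2.line for the pixel-wise IoU computation.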
66 | return np.array(splev(u, tck)).T 67 | 68 | 69 | def culane_metric(pred, 70 | anno, 71 | width=30, 72 | iou_thresholds=[0.5], 73 | official=True, 74 | img_shape=(590, 1640, 3)): 75 | _metric = {} 76 | for thr in iou_thresholds: 77 | tp = 0 78 | fp = 0 if len(anno) != 0 else len(pred) 79 | fn = 0 if len(pred) != 0 else len(anno) 80 | _metric[thr] = [tp, fp, fn] 81 | 82 | interp_pred = np.array([interp(pred_lane, n=5) for pred_lane in pred], 83 | dtype=object) # (4, 50, 2) 84 | interp_anno = np.array([interp(anno_lane, n=5) for anno_lane in anno], 85 | dtype=object) # (4, 50, 2) 86 | 87 | if official: 88 | ious = discrete_cross_iou(interp_pred, 89 | interp_anno, 90 | width=width, 91 | img_shape=img_shape) 92 | else: 93 | ious = continuous_cross_iou(interp_pred, 94 | interp_anno, 95 | width=width, 96 | img_shape=img_shape) 97 | 98 | row_ind, col_ind = linear_sum_assignment(1 - ious) 99 | 100 | _metric = {} 101 | for thr in iou_thresholds: 102 | tp = int((ious[row_ind, col_ind] > thr).sum()) 103 | fp = len(pred) - tp 104 | fn = len(anno) - tp 105 | _metric[thr] = [tp, fp, fn] 106 | return _metric 107 | 108 | 109 | def load_culane_img_data(path): 110 | with open(path, 'r') as data_file: 111 | img_data = data_file.readlines() 112 | img_data = [line.split() for line in img_data] 113 | img_data = [list(map(float, lane)) for lane in img_data] 114 | img_data = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)] 115 | for lane in img_data] 116 | img_data = [lane for lane in img_data if len(lane) >= 2] 117 | 118 | return img_data 119 | 120 | 121 | def load_culane_data(data_dir, file_list_path): 122 | with open(file_list_path, 'r') as file_list: 123 | filepaths = [ 124 | os.path.join( 125 | data_dir, line[1 if line[0] == '/' else 0:].rstrip().replace( 126 | '.jpg', '.lines.txt')) for line in file_list.readlines() 127 | ] 128 | 129 | data = [] 130 | for path in filepaths: 131 | img_data = load_culane_img_data(path) 132 | data.append(img_data) 133 | 134 | return data 135 | 136 | def eval_predictions(pred_dir, 137 | anno_dir, 138 | list_path, 139 | iou_thresholds=[0.5], 140 | width=30, 141 | official=True, 142 | sequential=False): 143 | import logging 144 | logger = logging.getLogger(__name__) 145 | logger.info('Calculating metric for List: {}'.format(list_path)) 146 | predictions = load_culane_data(pred_dir, list_path) 147 | annotations = load_culane_data(anno_dir, list_path) 148 | img_shape = (590, 1640, 3) 149 | if sequential: 150 | results = map( 151 | partial(culane_metric, 152 | width=width, 153 | official=official, 154 | iou_thresholds=iou_thresholds, 155 | img_shape=img_shape), predictions, annotations) 156 | else: 157 | from multiprocessing import Pool, cpu_count 158 | from itertools import repeat 159 | with Pool(cpu_count()) as p: 160 | results = p.starmap(culane_metric, zip(predictions, annotations, 161 | repeat(width), 162 | repeat(iou_thresholds), 163 | repeat(official), 164 | repeat(img_shape))) 165 | 166 | mean_f1, mean_prec, mean_recall, total_tp, total_fp, total_fn = 0, 0, 0, 0, 0, 0 167 | ret = {} 168 | for thr in iou_thresholds: 169 | tp = sum(m[thr][0] for m in results) 170 | fp = sum(m[thr][1] for m in results) 171 | fn = sum(m[thr][2] for m in results) 172 | precision = float(tp) / (tp + fp) if tp != 0 else 0 173 | recall = float(tp) / (tp + fn) if tp != 0 else 0 174 | f1 = 2 * precision * recall / (precision + recall) if tp !=0 else 0 175 | logger.info('iou thr: {:.2f}, tp: {}, fp: {}, fn: {},' 176 | 'precision: {}, recall: {}, f1: {}'.format( 177 | thr, tp, fp, fn, 
precision, recall, f1)) 178 | mean_f1 += f1 / len(iou_thresholds) 179 | mean_prec += precision / len(iou_thresholds) 180 | mean_recall += recall / len(iou_thresholds) 181 | total_tp += tp 182 | total_fp += fp 183 | total_fn += fn 184 | ret[thr] = { 185 | 'TP': tp, 186 | 'FP': fp, 187 | 'FN': fn, 188 | 'Precision': precision, 189 | 'Recall': recall, 190 | 'F1': f1 191 | } 192 | if len(iou_thresholds) > 2: 193 | logger.info('mean result, total_tp: {}, total_fp: {}, total_fn: {},' 194 | 'precision: {}, recall: {}, f1: {}'.format(total_tp, total_fp, 195 | total_fn, mean_prec, mean_recall, mean_f1)) 196 | ret['mean'] = { 197 | 'TP': total_tp, 198 | 'FP': total_fp, 199 | 'FN': total_fn, 200 | 'Precision': mean_prec, 201 | 'Recall': mean_recall, 202 | 'F1': mean_f1 203 | } 204 | return ret 205 | 206 | 207 | def main(): 208 | args = parse_args() 209 | for list_path in args.list: 210 | results = eval_predictions(args.pred_dir, 211 | args.anno_dir, 212 | list_path, 213 | width=args.width, 214 | official=args.official, 215 | sequential=args.sequential) 216 | 217 | header = '=' * 20 + ' Results ({})'.format( 218 | os.path.basename(list_path)) + '=' * 20 219 | print(header) 220 | for metric, value in results.items(): 221 | if isinstance(value, float): 222 | print('{}: {:.4f}'.format(metric, value)) 223 | else: 224 | print('{}: {}'.format(metric, value)) 225 | print('=' * len(header)) 226 | 227 | 228 | def parse_args(): 229 | parser = argparse.ArgumentParser(description="Measure CULane's metric") 230 | parser.add_argument( 231 | "--pred_dir", 232 | help="Path to directory containing the predicted lanes", 233 | required=True) 234 | parser.add_argument( 235 | "--anno_dir", 236 | help="Path to directory containing the annotated lanes", 237 | required=True) 238 | parser.add_argument("--width", 239 | type=int, 240 | default=30, 241 | help="Width of the lane") 242 | parser.add_argument("--list", 243 | nargs='+', 244 | help="Path to txt file containing the list of files", 245 | required=True) 246 | parser.add_argument("--sequential", 247 | action='store_true', 248 | help="Run sequentially instead of in parallel") 249 | parser.add_argument("--official", 250 | action='store_true', 251 | help="Use official way to calculate the metric") 252 | 253 | return parser.parse_args() 254 | 255 | 256 | if __name__ == '__main__': 257 | main() 258 | -------------------------------------------------------------------------------- /flamnet/utils/lane.py: -------------------------------------------------------------------------------- 1 | from scipy.interpolate import InterpolatedUnivariateSpline 2 | import numpy as np 3 | 4 | 5 | class Lane: 6 | def __init__(self, points=None, invalid_value=-2., metadata=None): 7 | super(Lane, self).__init__() 8 | self.curr_iter = 0 9 | self.points = points 10 | self.invalid_value = invalid_value 11 | self.function = InterpolatedUnivariateSpline(points[:, 1], 12 | points[:, 0], 13 | k=min(3, 14 | len(points) - 1)) 15 | self.min_y = points[:, 1].min() - 0.01 16 | self.max_y = points[:, 1].max() + 0.01 17 | 18 | self.metadata = metadata or {} 19 | 20 | def __repr__(self): 21 | return '[Lane]\n' + str(self.points) + '\n[/Lane]' 22 | 23 | def __call__(self, lane_ys): 24 | lane_xs = self.function(lane_ys) 25 | 26 | lane_xs[(lane_ys < self.min_y) | 27 | (lane_ys > self.max_y)] = self.invalid_value 28 | return lane_xs 29 | 30 | def to_array(self, cfg): 31 | sample_y = cfg.sample_y 32 | img_w, img_h = cfg.ori_img_w, cfg.ori_img_h 33 | ys = np.array(sample_y) / float(img_h) 34 | xs = self(ys) 35 | 
valid_mask = (xs >= 0) & (xs < 1)
36 |         lane_xs = xs[valid_mask] * img_w
37 |         lane_ys = ys[valid_mask] * img_h
38 |         lane = np.concatenate((lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)),
39 |                               axis=1)
40 |         return lane
41 | 
42 |     def __iter__(self):
43 |         return self
44 | 
45 |     def __next__(self):
46 |         if self.curr_iter < len(self.points):
47 |             self.curr_iter += 1
48 |             return self.points[self.curr_iter - 1]
49 |         self.curr_iter = 0
50 |         raise StopIteration
51 | 
--------------------------------------------------------------------------------
/flamnet/utils/llamas_metric.py:
--------------------------------------------------------------------------------
1 | """ Evaluation script for the CULane metric on the LLAMAS dataset.
2 | This script will compute the F1, precision and recall metrics as described in the CULane benchmark.
3 | The predictions format is the same one used in the CULane benchmark.
4 | In summary, for every annotation file:
5 |     labels/a/b/c.json
6 | There should be a prediction file:
7 |     predictions/a/b/c.lines.txt
8 | Inside each .lines.txt file each line will contain a sequence of points (x, y) separated by spaces.
9 | For more information, please see https://xingangpan.github.io/projects/CULane.html
10 | This script uses two methods to compute the IoU: one using an image to draw the lanes (named `discrete` here) and
11 | another one that uses shapes with the shapely library (named `continuous` here). The results achieved with the first
12 | method are very close to the official CULane implementation. Although the second should be a more exact method and is
13 | faster to compute, it deviates more from the official implementation. By default, the method closer to the official
14 | metric is used.
15 | """
16 | 
17 | import os
18 | import argparse
19 | from functools import partial
20 | 
21 | import cv2
22 | import numpy as np
23 | from p_tqdm import t_map, p_map
24 | from scipy.interpolate import splprep, splev
25 | from scipy.optimize import linear_sum_assignment
26 | from shapely.geometry import LineString, Polygon
27 | 
28 | import flamnet.utils.llamas_utils as llamas_utils
29 | 
30 | LLAMAS_IMG_RES = (717, 1276)
31 | 
32 | 
33 | def add_ys(xs):
34 |     """For each x in xs, make a tuple with x and its corresponding y."""
35 |     xs = np.array(xs[300:])
36 |     valid = xs >= 0
37 |     xs = xs[valid]
38 |     assert len(xs) > 1
39 |     ys = np.arange(300, 717)[valid]
40 |     return list(zip(xs, ys))
41 | 
42 | 
43 | def draw_lane(lane, img=None, img_shape=None, width=30):
44 |     """Draw a lane (a list of points) on an image by drawing a line with width `width` through each
45 |     pair of points i and i+1"""
46 |     if img is None:
47 |         img = np.zeros(img_shape, dtype=np.uint8)
48 |     lane = lane.astype(np.int32)
49 |     for p1, p2 in zip(lane[:-1], lane[1:]):
50 |         cv2.line(img, tuple(p1), tuple(p2), color=(1, ), thickness=width)
51 |     return img
52 | 
53 | 
54 | def discrete_cross_iou(xs, ys, width=30, img_shape=LLAMAS_IMG_RES):
55 |     """For each lane in xs, compute its Intersection Over Union (IoU) with each lane in ys by drawing the lanes on
56 |     an image"""
57 |     xs = [draw_lane(lane, img_shape=img_shape, width=width) > 0 for lane in xs]
58 |     ys = [draw_lane(lane, img_shape=img_shape, width=width) > 0 for lane in ys]
59 | 
60 |     ious = np.zeros((len(xs), len(ys)))
61 |     for i, x in enumerate(xs):
62 |         for j, y in enumerate(ys):
63 |             # IoU by the definition: sum all intersections (binary and) and divide by the sum of the union (binary or)
64 |             ious[i, j] = (x & y).sum() / (x | y).sum()
65 |     return ious
66 | 
67 | 
68 | def continuous_cross_iou(xs, ys, width=30, img_shape=LLAMAS_IMG_RES):
69 |     """For each lane in xs, compute its Intersection Over Union (IoU) with each lane in ys using the area between each
70 |     pair of points"""
71 |     h, w = img_shape
72 |     image = Polygon([(0, 0), (0, h - 1), (w - 1, h - 1), (w - 1, 0)])
73 |     xs = [
74 |         LineString(lane).buffer(distance=width / 2., cap_style=1,
75 |                                 join_style=2).intersection(image)
76 |         for lane in xs
77 |     ]
78 |     ys = [
79 |         LineString(lane).buffer(distance=width / 2., cap_style=1,
80 |                                 join_style=2).intersection(image)
81 |         for lane in ys
82 |     ]
83 | 
84 |     ious = np.zeros((len(xs), len(ys)))
85 |     for i, x in enumerate(xs):
86 |         for j, y in enumerate(ys):
87 |             ious[i, j] = x.intersection(y).area / x.union(y).area
88 | 
89 |     return ious
90 | 
91 | 
92 | def interpolate_lane(points, n=50):
93 |     """Spline interpolation of a lane. Used on the predictions"""
94 |     x = [x for x, _ in points]
95 |     y = [y for _, y in points]
96 |     tck, _ = splprep([x, y], s=0, t=n, k=min(3, len(points) - 1))
97 | 
98 |     u = np.linspace(0., 1., n)
99 |     return np.array(splev(u, tck)).T
100 | 
101 | 
102 | def culane_metric(pred,
103 |                   anno,
104 |                   width=30,
105 |                   iou_thresholds=[0.5],
106 |                   unofficial=False,
107 |                   img_shape=LLAMAS_IMG_RES):
108 |     """Computes CULane's metric for a single image"""
109 |     _metric = {}
110 |     for thr in iou_thresholds:
111 |         tp = 0
112 |         fp = 0 if len(anno) != 0 else len(pred)
113 |         fn = 0 if len(pred) != 0 else len(anno)
114 |         _metric[thr] = [tp, fp, fn]
115 | 
116 |     if len(pred) == 0:
117 |         return _metric  # the fp/fn counts above already cover the empty case
118 |     if len(anno) == 0:
119 |         return _metric  # returning a tuple here would break `m[thr]` in eval_predictions
120 | 
121 |     interp_pred = np.array([
122 |         interpolate_lane(pred_lane, n=50) for pred_lane in pred
123 |     ])  # (4, 50, 2)
124 |     anno = np.array([np.array(anno_lane) for anno_lane in anno], dtype=object)
125 | 
126 |     if unofficial:
127 |         ious = continuous_cross_iou(interp_pred, anno, width=width)
128 |     else:
129 |         ious = discrete_cross_iou(interp_pred,
130 |                                   anno,
131 |                                   width=width,
132 |                                   img_shape=img_shape)
133 |     row_ind, col_ind = linear_sum_assignment(1 - ious)
134 |     _metric = {}
135 |     for thr in iou_thresholds:
136 |         tp = int((ious[row_ind, col_ind] > thr).sum())
137 |         fp = len(pred) - tp
138 |         fn = len(anno) - tp
139 |         _metric[thr] = [tp, fp, fn]
140 | 
141 |     return _metric
142 | 
143 | 
144 | def load_prediction(path):
145 |     """Loads an image's predictions
146 |     Returns a list of lanes, where each lane is a list of points (x,y)
147 |     """
148 |     with open(path, 'r') as data_file:
149 |         img_data = data_file.readlines()
150 |     img_data = [line.split() for line in img_data]
151 |     img_data = [list(map(float, lane)) for lane in img_data]
152 |     img_data = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)]
153 |                 for lane in img_data]
154 |     img_data = [lane for lane in img_data if len(lane) >= 2]
155 | 
156 |     return img_data
157 | 
158 | 
159 | def load_prediction_list(label_paths, pred_dir):
160 |     return [
161 |         load_prediction(
162 |             os.path.join(pred_dir, path.replace('.json', '.lines.txt')))
163 |         for path in label_paths
164 |     ]
165 | 
166 | 
167 | def load_labels(label_dir):
168 |     """Loads the annotations and their paths
169 |     Each annotation is converted to a list of points (x, y)
170 |     """
171 |     label_paths = llamas_utils.get_files_from_folder(label_dir, '.json')
172 |     annos = [
173 |         [
174 |             add_ys(xs) for xs in
175 |             llamas_utils.get_horizontal_values_for_four_lanes(label_path)
176 |             if (np.array(xs) >= 0).sum() > 1
177 |         ]  # lanes annotated with a single point are ignored
178 |         for
label_path in label_paths 179 | ] 180 | label_paths = [llamas_utils.get_label_base(p) for p in label_paths] 181 | return np.array(annos, dtype=object), np.array(label_paths, dtype=object) 182 | 183 | 184 | def eval_predictions(pred_dir, 185 | anno_dir, 186 | width=30, 187 | iou_thresholds=[0.5], 188 | unofficial=True, 189 | sequential=False): 190 | """Evaluates the predictions in pred_dir and returns CULane's metrics (precision, recall, F1 and its components)""" 191 | print(f'Loading annotation data ({anno_dir})...') 192 | os.makedirs('cache', exist_ok=True) 193 | annotations_path = 'cache/llamas_annotations.pkl' 194 | label_path = 'cache/llamas_label_paths.pkl' 195 | import pickle as pkl 196 | if os.path.exists(annotations_path) and os.path.exists(label_path): 197 | with open(annotations_path, 'rb') as cache_file: 198 | annotations = pkl.load(cache_file) 199 | with open(label_path, 'rb') as cache_file: 200 | label_paths = pkl.load(cache_file) 201 | else: 202 | annotations, label_paths = load_labels(anno_dir) 203 | with open(annotations_path, 'wb') as cache_file: 204 | pkl.dump(annotations, cache_file) 205 | with open(label_path, 'wb') as cache_file: 206 | pkl.dump(label_paths, cache_file) 207 | 208 | print(f'Loading prediction data ({pred_dir})...') 209 | predictions = load_prediction_list(label_paths, pred_dir) 210 | print('Calculating metric {}...'.format( 211 | 'sequentially' if sequential else 'in parallel')) 212 | if sequential: 213 | results = map( 214 | partial(culane_metric, 215 | width=width, 216 | unofficial=unofficial, 217 | img_shape=LLAMAS_IMG_RES), predictions, annotations) 218 | else: 219 | from multiprocessing import Pool, cpu_count 220 | from itertools import repeat 221 | with Pool(cpu_count()) as p: 222 | results = p.starmap(culane_metric, zip(predictions, annotations, 223 | repeat(width), 224 | repeat(iou_thresholds), 225 | repeat(unofficial), 226 | repeat(LLAMAS_IMG_RES))) 227 | 228 | import logging 229 | logger = logging.getLogger(__name__) 230 | mean_f1, mean_prec, mean_recall, total_tp, total_fp, total_fn = 0, 0, 0, 0, 0, 0 231 | ret = {} 232 | for thr in iou_thresholds: 233 | tp = sum(m[thr][0] for m in results) 234 | fp = sum(m[thr][1] for m in results) 235 | fn = sum(m[thr][2] for m in results) 236 | precision = float(tp) / (tp + fp) if tp != 0 else 0 237 | recall = float(tp) / (tp + fn) if tp != 0 else 0 238 | f1 = 2 * precision * recall / (precision + recall) if tp !=0 else 0 239 | logger.info('iou thr: {:.2f}, tp: {}, fp: {}, fn: {}, ' 240 | 'precision: {}, recall: {}, f1: {}'.format( 241 | thr, tp, fp, fn, precision, recall, f1)) 242 | mean_f1 += f1 / len(iou_thresholds) 243 | mean_prec += precision / len(iou_thresholds) 244 | mean_recall += recall / len(iou_thresholds) 245 | total_tp += tp 246 | total_fp += fp 247 | total_fn += fn 248 | ret[thr] = { 249 | 'TP': tp, 250 | 'FP': fp, 251 | 'FN': fn, 252 | 'Precision': precision, 253 | 'Recall': recall, 254 | 'F1': f1 255 | } 256 | if len(iou_thresholds) > 2: 257 | logger.info('mean result, total_tp: {}, total_fp: {}, total_fn: {},' 258 | 'precision: {}, recall: {}, f1: {}'.format(total_tp, total_fp, 259 | total_fn, mean_prec, mean_recall, mean_f1)) 260 | ret['mean'] = { 261 | 'TP': total_tp, 262 | 'FP': total_fp, 263 | 'FN': total_fn, 264 | 'Precision': mean_prec, 265 | 'Recall': mean_recall, 266 | 'F1': mean_f1 267 | } 268 | return ret 269 | 270 | def parse_args(): 271 | parser = argparse.ArgumentParser( 272 | description="Measure CULane's metric on the LLAMAS dataset") 273 | parser.add_argument( 274 | 
"--pred_dir", 275 | help="Path to directory containing the predicted lanes", 276 | required=True) 277 | parser.add_argument( 278 | "--anno_dir", 279 | help="Path to directory containing the annotated lanes", 280 | required=True) 281 | parser.add_argument("--width", 282 | type=int, 283 | default=30, 284 | help="Width of the lane") 285 | parser.add_argument("--sequential", 286 | action='store_true', 287 | help="Run sequentially instead of in parallel") 288 | parser.add_argument("--unofficial", 289 | action='store_true', 290 | help="Use a faster but unofficial algorithm") 291 | 292 | return parser.parse_args() 293 | 294 | 295 | def main(): 296 | args = parse_args() 297 | results = eval_predictions(args.pred_dir, 298 | args.anno_dir, 299 | width=args.width, 300 | iou_thresholds=[0.5], 301 | unofficial=args.unofficial, 302 | sequential=args.sequential) 303 | 304 | header = '=' * 20 + ' Results' + '=' * 20 305 | print(header) 306 | for metric, value in results.items(): 307 | if isinstance(value, float): 308 | print('{}: {:.4f}'.format(metric, value)) 309 | else: 310 | print('{}: {}'.format(metric, value)) 311 | print('=' * len(header)) 312 | 313 | 314 | if __name__ == '__main__': 315 | main() 316 | -------------------------------------------------------------------------------- /flamnet/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def init_logger(log_file=None, log_level=logging.INFO): 5 | stream_handler = logging.StreamHandler() 6 | handlers = [stream_handler] 7 | 8 | if log_file is not None: 9 | file_handler = logging.FileHandler(log_file, 'w') 10 | handlers.append(file_handler) 11 | 12 | formatter = logging.Formatter( 13 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 14 | for handler in handlers: 15 | handler.setFormatter(formatter) 16 | handler.setLevel(log_level) 17 | 18 | logging.basicConfig(level=log_level, handlers=handlers) 19 | -------------------------------------------------------------------------------- /flamnet/utils/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch import nn 4 | import numpy as np 5 | import torch.nn.functional 6 | 7 | 8 | def save_model(net, optim, scheduler, recorder, is_best=False): 9 | model_dir = os.path.join(recorder.work_dir, 'ckpt') 10 | os.system('mkdir -p {}'.format(model_dir)) 11 | epoch = recorder.epoch 12 | ckpt_name = 'best' if is_best else epoch 13 | torch.save( 14 | { 15 | 'net': net.state_dict(), 16 | 'optim': optim.state_dict(), 17 | 'scheduler': scheduler.state_dict(), 18 | 'recorder': recorder.state_dict(), 19 | 'epoch': epoch 20 | }, os.path.join(model_dir, '{}.pth'.format(ckpt_name))) 21 | 22 | 23 | def load_network_specified(net, model_dir, logger=None): 24 | pretrained_net = torch.load(model_dir)['net'] 25 | net_state = net.state_dict() 26 | state = {} 27 | for k, v in pretrained_net.items(): 28 | if k not in net_state.keys() or v.size() != net_state[k].size(): 29 | if logger: 30 | logger.info('skip weights: ' + k) 31 | continue 32 | state[k] = v 33 | net.load_state_dict(state, strict=False) 34 | 35 | 36 | def load_network(net, model_dir, finetune_from=None, logger=None): 37 | if finetune_from: 38 | if logger: 39 | logger.info('Finetune model from: ' + finetune_from) 40 | load_network_specified(net, finetune_from, logger) 41 | return 42 | pretrained_model = torch.load(model_dir) 43 | net.load_state_dict(pretrained_model['net'], strict=False) 44 | 45 | 46 | 
def resume_network(model_dir, net, optim, scheduler, recorder):
47 |     if not os.path.exists(model_dir):
48 |         print('WARNING: NO MODEL LOADED !!!')
49 |         return 0
50 | 
51 |     print('resume model: {}'.format(model_dir))
52 | 
53 |     pretrained_model = torch.load(model_dir)
54 |     net.load_state_dict(pretrained_model['net'])
55 |     optim.load_state_dict(pretrained_model['optim'])
56 |     scheduler.load_state_dict(pretrained_model['scheduler'])
57 |     recorder.load_state_dict(pretrained_model['recorder'])
58 |     return pretrained_model['epoch'] + 1
59 | 
--------------------------------------------------------------------------------
/flamnet/utils/recorder.py:
--------------------------------------------------------------------------------
1 | from collections import deque, defaultdict
2 | import torch
3 | import os
4 | import datetime
5 | from .logger import init_logger
6 | import logging
7 | import pathspec
8 | 
9 | 
10 | class SmoothedValue(object):
11 |     """Track a series of values and provide access to smoothed values over a
12 |     window or the global series average.
13 |     """
14 |     def __init__(self, window_size=20):
15 |         self.deque = deque(maxlen=window_size)
16 |         self.total = 0.0
17 |         self.count = 0
18 | 
19 |     def update(self, value):
20 |         self.deque.append(value)
21 |         self.count += 1
22 |         self.total += value
23 | 
24 |     @property
25 |     def median(self):
26 |         d = torch.tensor(list(self.deque))
27 |         return d.median().item()
28 | 
29 |     @property
30 |     def avg(self):
31 |         d = torch.tensor(list(self.deque))
32 |         return d.mean().item()
33 | 
34 |     @property
35 |     def global_avg(self):
36 |         return self.total / self.count
37 | 
38 | 
39 | class Recorder(object):
40 |     def __init__(self, cfg):
41 |         self.cfg = cfg
42 |         self.work_dir = self.get_work_dir()
43 |         cfg.work_dir = self.work_dir
44 |         self.log_path = os.path.join(self.work_dir, 'log.txt')
45 | 
46 |         init_logger(self.log_path)
47 |         self.logger = logging.getLogger(__name__)
48 |         self.logger.info('Config: \n' + cfg.text)
49 | 
50 |         self.save_cfg(cfg)
51 |         self.cp_projects(self.work_dir)
52 | 
53 |         # scalars
54 |         self.epoch = 0
55 |         self.step = 0
56 |         self.loss_stats = defaultdict(SmoothedValue)
57 |         self.batch_time = SmoothedValue()
58 |         self.data_time = SmoothedValue()
59 |         self.max_iter = self.cfg.total_iter
60 |         self.lr = 0.
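        # batch_time and data_time are SmoothedValue windows; their averages are
        # reported by __str__, and batch_time.global_avg drives the ETA estimate
        # (remaining iterations x average seconds per batch).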
61 | 62 | def save_cfg(self, cfg): 63 | cfg_path = os.path.join(self.work_dir, 'config.py') 64 | with open(cfg_path, 'w') as cfg_file: 65 | cfg_file.write(cfg.text) 66 | 67 | def cp_projects(self, to_path): 68 | with open('./.gitignore', 'r') as fp: 69 | ign = fp.read() 70 | ign += '\n.git' 71 | spec = pathspec.PathSpec.from_lines( 72 | pathspec.patterns.GitWildMatchPattern, ign.splitlines()) 73 | all_files = { 74 | os.path.join(root, name) 75 | for root, dirs, files in os.walk('./') for name in files 76 | } 77 | matches = spec.match_files(all_files) 78 | matches = set(matches) 79 | to_cp_files = all_files - matches 80 | for f in to_cp_files: 81 | dirs = os.path.join(to_path, 'FLAMNet', os.path.split(f[2:])[0]) 82 | if not os.path.exists(dirs): 83 | os.makedirs(dirs) 84 | os.system('cp %s %s' % (f, os.path.join(to_path, 'FLAMNet', f[2:]))) 85 | 86 | def get_work_dir(self): 87 | now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') 88 | hyper_param_str = '_lr_%1.0e_b_%d' % (self.cfg.optimizer.lr, 89 | self.cfg.batch_size) 90 | work_dir = os.path.join(self.cfg.work_dirs, now + hyper_param_str) 91 | if not os.path.exists(work_dir): 92 | os.makedirs(work_dir) 93 | return work_dir 94 | 95 | def update_loss_stats(self, loss_dict): 96 | for k, v in loss_dict.items(): 97 | if not isinstance(v, torch.Tensor): continue 98 | self.loss_stats[k].update(v.detach().mean().cpu()) 99 | 100 | def record(self, prefix, step=-1, loss_stats=None, image_stats=None): 101 | self.logger.info(self) 102 | # self.write(str(self)) 103 | 104 | def write(self, content): 105 | with open(self.log_path, 'a+') as f: 106 | f.write(content) 107 | f.write('\n') 108 | 109 | def state_dict(self): 110 | scalar_dict = {} 111 | scalar_dict['step'] = self.step 112 | return scalar_dict 113 | 114 | def load_state_dict(self, scalar_dict): 115 | self.step = scalar_dict['step'] 116 | 117 | def __str__(self): 118 | loss_state = [] 119 | for k, v in self.loss_stats.items(): 120 | loss_state.append('{}: {:.4f}'.format(k, v.avg)) 121 | loss_state = ' '.join(loss_state) 122 | 123 | recording_state = ' '.join([ 124 | 'epoch: {}', 'step: {}', 'lr: {:.6f}', '{}', 'data: {:.4f}', 125 | 'batch: {:.4f}', 'eta: {}' 126 | ]) 127 | eta_seconds = self.batch_time.global_avg * (self.max_iter - self.step) 128 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 129 | return recording_state.format(self.epoch, self.step, self.lr, 130 | loss_state, self.data_time.avg, 131 | self.batch_time.avg, eta_string) 132 | 133 | 134 | def build_recorder(cfg): 135 | return Recorder(cfg) 136 | -------------------------------------------------------------------------------- /flamnet/utils/registry.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import six 4 | 5 | # borrow from mmdetection 6 | 7 | 8 | def is_str(x): 9 | """Whether the input is an string instance.""" 10 | return isinstance(x, six.string_types) 11 | 12 | 13 | class Registry(object): 14 | def __init__(self, name): 15 | self._name = name 16 | self._module_dict = dict() 17 | 18 | def __repr__(self): 19 | format_str = self.__class__.__name__ + '(name={}, items={})'.format( 20 | self._name, list(self._module_dict.keys())) 21 | return format_str 22 | 23 | @property 24 | def name(self): 25 | return self._name 26 | 27 | @property 28 | def module_dict(self): 29 | return self._module_dict 30 | 31 | def get(self, key): 32 | return self._module_dict.get(key, None) 33 | 34 | def _register_module(self, module_class): 35 | """Register a module. 
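        Typically applied through the ``register_module`` decorator, which makes
        a class buildable from a config dict. A minimal sketch (``BACKBONES``
        and ``ResNetWrapper`` are hypothetical names, not ones defined here):

        Example:
            >>> BACKBONES = Registry('backbones')
            >>> @BACKBONES.register_module
            ... class ResNetWrapper(object):
            ...     pass
            >>> net = build_from_cfg(dict(type='ResNetWrapper'), BACKBONES)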
36 | 37 | Args: 38 | module (:obj:`nn.Module`): Module to be registered. 39 | """ 40 | if not inspect.isclass(module_class): 41 | raise TypeError('module must be a class, but got {}'.format( 42 | type(module_class))) 43 | module_name = module_class.__name__ 44 | if module_name in self._module_dict: 45 | raise KeyError('{} is already registered in {}'.format( 46 | module_name, self.name)) 47 | self._module_dict[module_name] = module_class 48 | 49 | def register_module(self, cls): 50 | self._register_module(cls) 51 | return cls 52 | 53 | 54 | def build_from_cfg(cfg, registry, default_args=None): 55 | """Build a module from config dict. 56 | 57 | Args: 58 | cfg (dict): Config dict. It should at least contain the key "type". 59 | registry (:obj:`Registry`): The registry to search the type from. 60 | default_args (dict, optional): Default initialization arguments. 61 | 62 | Returns: 63 | obj: The constructed object. 64 | """ 65 | assert isinstance(cfg, dict) and 'type' in cfg 66 | assert isinstance(default_args, dict) or default_args is None 67 | args = cfg.copy() 68 | obj_type = args.pop('type') 69 | if is_str(obj_type): 70 | obj_cls = registry.get(obj_type) 71 | if obj_cls is None: 72 | raise KeyError('{} is not in the {} registry'.format( 73 | obj_type, registry.name)) 74 | elif inspect.isclass(obj_type): 75 | obj_cls = obj_type 76 | else: 77 | raise TypeError('type must be a str or valid type, but got {}'.format( 78 | type(obj_type))) 79 | if default_args is not None: 80 | for name, value in default_args.items(): 81 | args.setdefault(name, value) 82 | return obj_cls(**args) 83 | -------------------------------------------------------------------------------- /flamnet/utils/tusimple_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.linear_model import LinearRegression 3 | import json as json 4 | 5 | 6 | class LaneEval(object): 7 | lr = LinearRegression() 8 | pixel_thresh = 20 9 | pt_thresh = 0.85 10 | 11 | @staticmethod 12 | def get_angle(xs, y_samples): 13 | xs, ys = xs[xs >= 0], y_samples[xs >= 0] 14 | if len(xs) > 1: 15 | LaneEval.lr.fit(ys[:, None], xs) 16 | k = LaneEval.lr.coef_[0] 17 | theta = np.arctan(k) 18 | else: 19 | theta = 0 20 | return theta 21 | 22 | @staticmethod 23 | def line_accuracy(pred, gt, thresh): 24 | pred = np.array([p if p >= 0 else -100 for p in pred]) 25 | gt = np.array([g if g >= 0 else -100 for g in gt]) 26 | return np.sum(np.where(np.abs(pred - gt) < thresh, 1., 0.)) / len(gt) 27 | 28 | @staticmethod 29 | def bench(pred, gt, y_samples, running_time): 30 | if any(len(p) != len(y_samples) for p in pred): 31 | raise Exception('Format of lanes error.') 32 | if running_time > 200 or len(gt) + 2 < len(pred): 33 | return 0., 0., 1. 34 | angles = [ 35 | LaneEval.get_angle(np.array(x_gts), np.array(y_samples)) 36 | for x_gts in gt 37 | ] 38 | threshs = [LaneEval.pixel_thresh / np.cos(angle) for angle in angles] 39 | line_accs = [] 40 | fp, fn = 0., 0. 41 | matched = 0. 42 | for x_gts, thresh in zip(gt, threshs): 43 | accs = [ 44 | LaneEval.line_accuracy(np.array(x_preds), np.array(x_gts), 45 | thresh) for x_preds in pred 46 | ] 47 | max_acc = np.max(accs) if len(accs) > 0 else 0. 
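            # A ground-truth lane is matched only when its best prediction keeps
            # at least pt_thresh (85%) of sampled points within the angle-adjusted
            # pixel threshold; anything below that counts as a false negative.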
48 |             if max_acc < LaneEval.pt_thresh:
49 |                 fn += 1
50 |             else:
51 |                 matched += 1
52 |             line_accs.append(max_acc)
53 |         fp = len(pred) - matched
54 |         if len(gt) > 4 and fn > 0:
55 |             fn -= 1
56 |         s = sum(line_accs)
57 |         if len(gt) > 4:
58 |             s -= min(line_accs)
59 |         return s / max(min(4.0, len(gt)),
60 |                        1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(
61 |                            min(len(gt), 4.), 1.)
62 | 
63 |     @staticmethod
64 |     def bench_one_submit(pred_file, gt_file):
65 |         try:
66 |             json_pred = [
67 |                 json.loads(line) for line in open(pred_file).readlines()
68 |             ]
69 |         except BaseException as e:
70 |             raise Exception('Failed to load the json file of the prediction.')
71 |         json_gt = [json.loads(line) for line in open(gt_file).readlines()]
72 |         if len(json_gt) != len(json_pred):
73 |             raise Exception(
74 |                 'We do not get the predictions of all the test tasks')
75 |         gts = {l['raw_file']: l for l in json_gt}
76 |         accuracy, fp, fn = 0., 0., 0.
77 |         for pred in json_pred:
78 |             if 'raw_file' not in pred or 'lanes' not in pred or 'run_time' not in pred:
79 |                 raise Exception(
80 |                     'raw_file or lanes or run_time not in some predictions.')
81 |             raw_file = pred['raw_file']
82 |             pred_lanes = pred['lanes']
83 |             run_time = pred['run_time']
84 |             if raw_file not in gts:
85 |                 raise Exception(
86 |                     'Some raw_file from your predictions do not exist in the test tasks.'
87 |                 )
88 |             gt = gts[raw_file]
89 |             gt_lanes = gt['lanes']
90 |             y_samples = gt['h_samples']
91 |             try:
92 |                 a, p, n = LaneEval.bench(pred_lanes, gt_lanes, y_samples,
93 |                                          run_time)
94 |             except BaseException as e:
95 |                 raise Exception('Format of lanes error.')
96 |             accuracy += a
97 |             fp += p
98 |             fn += n
99 |         num = len(gts)
100 |         # the first return parameter is the default ranking parameter
101 | 
102 |         fp = fp / num
103 |         fn = fn / num
104 |         tp = 1 - fp
105 |         precision = tp / (tp + fp)
106 |         recall = tp / (tp + fn)
107 |         f1 = 2 * precision * recall / (precision + recall)
108 | 
109 |         return json.dumps([{
110 |             'name': 'Accuracy',
111 |             'value': accuracy / num,
112 |             'order': 'desc'
113 |         }, {
114 |             'name': 'F1_score',
115 |             'value': f1,
116 |             'order': 'desc'
117 |         }, {
118 |             'name': 'FP',
119 |             'value': fp,
120 |             'order': 'asc'
121 |         }, {
122 |             'name': 'FN',
123 |             'value': fn,
124 |             'order': 'asc'
125 |         }]), accuracy / num
126 | 
127 | 
128 | if __name__ == '__main__':
129 |     import sys
130 |     try:
131 |         if len(sys.argv) != 3:
132 |             raise Exception('Invalid input arguments')
133 |         print(LaneEval.bench_one_submit(sys.argv[1], sys.argv[2]))
134 |     except Exception as e:
135 |         print(str(e))  # Exception objects have no `.message` attribute in Python 3
136 |         sys.exit(str(e))
137 | 
--------------------------------------------------------------------------------
/flamnet/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import os.path as osp
4 | 
5 | 
6 | COLORS = [
7 |     (255, 0, 0),
8 |     (0, 255, 0),
9 |     (0, 0, 255),
10 |     (255, 255, 0),
11 |     (255, 0, 255),
12 |     (0, 255, 255),
13 |     (128, 255, 0),
14 |     (255, 128, 0),
15 |     (128, 0, 255),
16 |     (255, 0, 128),
17 |     (0, 128, 255),
18 |     (0, 255, 128),
19 |     (128, 255, 255),
20 |     (255, 128, 255),
21 |     (255, 255, 128),
22 |     (60, 180, 0),
23 |     (180, 60, 0),
24 |     (0, 60, 180),
25 |     (0, 180, 60),
26 |     (60, 0, 180),
27 |     (180, 0, 60),
28 |     (255, 0, 0),
29 |     (0, 255, 0),
30 |     (0, 0, 255),
31 |     (255, 255, 0),
32 |     (255, 0, 255),
33 |     (0, 255, 255),
34 |     (128, 255, 0),
35 |     (255, 128, 0),
36 |     (128, 0, 255),
37 | ]
38 | 
39 | 
40 | def imshow_lanes(img, lanes, show=False, out_file=None, width=4):
41 |     lanes_xys = []
42 |     for lane in lanes:
43 |         xys = []
44 |         for x, y in lane:
45 |             if x <= 0 or y <= 0:
46 |                 continue
47 |             x, y = int(x), int(y)
48 |             xys.append((x, y))
49 |         lanes_xys.append(xys)
50 |     lanes_xys = [xys for xys in lanes_xys if xys]  # drop lanes with no visible points so the sort below cannot index an empty list
51 |     lanes_xys.sort(key=lambda xys: xys[0][0])
52 | 
53 |     for idx, xys in enumerate(lanes_xys):
54 |         for i in range(1, len(xys)):
55 |             cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width)
56 | 
57 |     if show:
58 |         cv2.imshow('view', img)
59 |         cv2.waitKey(1)
60 |         # cv2.waitKey(0)
61 | 
62 |     if out_file:
63 |         if not osp.exists(osp.dirname(out_file)):
64 |             os.makedirs(osp.dirname(out_file))
65 |         cv2.imwrite(out_file, img)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import torch.nn.parallel
5 | import torch.backends.cudnn as cudnn
6 | import argparse
7 | import numpy as np
8 | import random
9 | from flamnet.utils.config import Config
10 | from flamnet.engine.runner import Runner
11 | from flamnet.datasets import build_dataloader
12 | 
13 | 
14 | def main():
15 |     args = parse_args()
16 |     os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
17 |         str(gpu) for gpu in args.gpus)
18 | 
19 |     cfg = Config.fromfile(args.config)
20 |     cfg.gpus = len(args.gpus)
21 | 
22 |     cfg.load_from = args.load_from
23 |     cfg.resume_from = args.resume_from
24 |     cfg.finetune_from = args.finetune_from
25 |     cfg.view = args.view
26 |     cfg.seed = args.seed
27 | 
28 |     cfg.work_dirs = args.work_dirs if args.work_dirs else cfg.work_dirs
29 | 
30 |     cudnn.benchmark = True
31 | 
32 |     runner = Runner(cfg)
33 | 
34 |     if args.validate:
35 |         runner.validate()
36 |     elif args.test:
37 |         runner.test()
38 |     else:
39 |         runner.train()
40 | 
41 | 
42 | def parse_args():
43 |     parser = argparse.ArgumentParser(description='Train a detector')
44 |     parser.add_argument('config', help='train config file path')
45 |     parser.add_argument('--work_dirs',
46 |                         type=str,
47 |                         default=None,
48 |                         help='work dirs')
49 |     parser.add_argument('--load_from',
50 |                         default=None,
51 |                         help='the checkpoint file to load from')
52 |     parser.add_argument('--resume_from',
53 |                         default=None,
54 |                         help='the checkpoint file to resume from')
55 |     parser.add_argument('--finetune_from',
56 |                         default=None,
57 |                         help='the checkpoint file to finetune from')
58 |     parser.add_argument('--view', action='store_true', help='whether to view')
59 |     parser.add_argument(
60 |         '--validate',
61 |         action='store_true',
62 |         help='whether to evaluate the checkpoint during training')
63 |     parser.add_argument(
64 |         '--test',
65 |         action='store_true',
66 |         help='whether to test the checkpoint on testing set')
67 |     parser.add_argument('--gpus', nargs='+', type=int, default=[0])
68 |     parser.add_argument('--seed', type=int, default=0, help='random seed')
69 |     args = parser.parse_args()
70 | 
71 |     return args
72 | 
73 | 
74 | if __name__ == '__main__':
75 |     main()
76 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.0
2 | torchvision==0.9.0
3 | pandas
4 | addict
5 | scikit-learn
6 | opencv-python
7 | pytorch_warmup
8 | scikit-image
9 | tqdm
10 | p_tqdm
11 | imgaug>=0.4.0
12 | Shapely==1.7.0
13 | ujson==1.35
14 | yapf
15 | pathspec
16 | timm
17 | mmcv==1.2.5
18 | albumentations==0.4.6
19 | ptflops
20 | 
--------------------------------------------------------------------------------
/setup.py:
-------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | from setuptools import find_packages, setup 5 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 6 | 7 | 8 | def parse_requirements(fname='requirements.txt', with_version=True): 9 | """Parse the package dependencies listed in a requirements file, optionally 10 | stripping specific versioning information. 11 | Args: 12 | fname (str): path to requirements file 13 | with_version (bool, default=True): if True, include version specs 14 | Returns: 15 | List[str]: list of requirements items 16 | CommandLine: 17 | python -c "import setup; print(setup.parse_requirements())" 18 | """ 19 | import sys 20 | from os.path import exists 21 | require_fpath = fname 22 | 23 | def parse_line(line): 24 | """Parse information from a line in a requirements text file.""" 25 | if line.startswith('-r '): 26 | # Allow specifying requirements in other files 27 | target = line.split(' ')[1] 28 | for info in parse_require_file(target): 29 | yield info 30 | else: 31 | info = {'line': line} 32 | if line.startswith('-e '): 33 | info['package'] = line.split('#egg=')[1] 34 | else: 35 | # Remove versioning from the package 36 | pat = '(' + '|'.join(['>=', '==', '>']) + ')' 37 | parts = re.split(pat, line, maxsplit=1) 38 | parts = [p.strip() for p in parts] 39 | 40 | info['package'] = parts[0] 41 | if len(parts) > 1: 42 | op, rest = parts[1:] 43 | if ';' in rest: 44 | # Handle platform specific dependencies 45 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies 46 | version, platform_deps = map(str.strip, 47 | rest.split(';')) 48 | info['platform_deps'] = platform_deps 49 | else: 50 | version = rest # NOQA 51 | info['version'] = (op, version) 52 | yield info 53 | 54 | def parse_require_file(fpath): 55 | with open(fpath, 'r') as f: 56 | for line in f.readlines(): 57 | line = line.strip() 58 | if line and not line.startswith('#'): 59 | for info in parse_line(line): 60 | yield info 61 | 62 | def gen_packages_items(): 63 | if exists(require_fpath): 64 | for info in parse_require_file(require_fpath): 65 | parts = [info['package']] 66 | if with_version and 'version' in info: 67 | parts.extend(info['version']) 68 | if not sys.version.startswith('3.4'): 69 | # apparently platform_deps are broken in 3.4 70 | platform_deps = info.get('platform_deps') 71 | if platform_deps is not None: 72 | parts.append(';' + platform_deps) 73 | item = ''.join(parts) 74 | yield item 75 | 76 | packages = list(gen_packages_items()) 77 | return packages 78 | 79 | 80 | install_requires = parse_requirements() 81 | 82 | 83 | def get_extensions(): 84 | extensions = [] 85 | 86 | op_files = glob.glob('flamnet/ops/csrc/*.c*') 87 | extension = CUDAExtension 88 | ext_name = 'flamnet.ops.nms_impl' 89 | 90 | ext_ops = extension( 91 | name=ext_name, 92 | sources=op_files, 93 | ) 94 | 95 | extensions.append(ext_ops) 96 | 97 | return extensions 98 | 99 | 100 | setup(name='flamnet', 101 | version="1.0", 102 | keywords='computer vision & lane detection', 103 | classifiers=[ 104 | 'License :: OSI Approved :: MIT License', 105 | 'Programming Language :: Python :: 3', 106 | 'Intended Audience :: Developers', 107 | 'Operating System :: OS Independent' 108 | ], 109 | packages=find_packages(), 110 | include_package_data=True, 111 | setup_requires=['pytest-runner'], 112 | tests_require=['pytest'], 113 | install_requires=install_requires, 114 | ext_modules=get_extensions(), 115 | 
cmdclass={'build_ext': BuildExtension}, 116 | zip_safe=False) 117 | -------------------------------------------------------------------------------- /tools/detect.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | import os 5 | import os.path as osp 6 | import glob 7 | import argparse 8 | from flamnet.datasets.process import Process 9 | from flamnet.models.registry import build_net 10 | from flamnet.utils.config import Config 11 | from flamnet.utils.visualization import imshow_lanes 12 | from flamnet.utils.net_utils import load_network 13 | from pathlib import Path 14 | from tqdm import tqdm 15 | 16 | class Detect(object): 17 | def __init__(self, cfg): 18 | self.cfg = cfg 19 | self.processes = Process(cfg.val_process, cfg) 20 | self.net = build_net(self.cfg) 21 | self.net = torch.nn.parallel.DataParallel( 22 | self.net, device_ids = range(1)).cuda() 23 | self.net.eval() 24 | load_network(self.net, self.cfg.load_from) 25 | 26 | def preprocess(self, img_path): 27 | ori_img = cv2.imread(img_path) 28 | img = ori_img[self.cfg.cut_height:, :].astype(np.float32) 29 | data = {'img': img, 'lanes': []} 30 | data = self.processes(data) 31 | data['img'] = data['img'].unsqueeze(0) 32 | data.update({'img_path':img_path, 'ori_img':ori_img}) 33 | return data 34 | 35 | def inference(self, data): 36 | with torch.no_grad(): 37 | data = self.net(data) 38 | data = self.net.module.heads.get_lanes(data) 39 | return data 40 | 41 | def show(self, data): 42 | out_file = self.cfg.savedir 43 | if out_file: 44 | out_file = osp.join(out_file, osp.basename(data['img_path'])) 45 | lanes = [lane.to_array(self.cfg) for lane in data['lanes']] 46 | imshow_lanes(data['ori_img'], lanes, show=self.cfg.show, out_file=out_file) 47 | 48 | def run(self, data): 49 | data = self.preprocess(data) 50 | data['lanes'] = self.inference(data)[0] 51 | if self.cfg.show or self.cfg.savedir: 52 | self.show(data) 53 | return data 54 | 55 | def get_img_paths(path): 56 | p = str(Path(path).absolute()) # os-agnostic absolute path 57 | if '*' in p: 58 | paths = sorted(glob.glob(p, recursive=True)) # glob 59 | elif os.path.isdir(p): 60 | paths = sorted(glob.glob(os.path.join(p, '*.*'))) # dir 61 | elif os.path.isfile(p): 62 | paths = [p] # files 63 | else: 64 | raise Exception(f'ERROR: {p} does not exist') 65 | return paths 66 | 67 | def process(args): 68 | cfg = Config.fromfile(args.config) 69 | cfg.show = args.show 70 | cfg.savedir = args.savedir 71 | cfg.load_from = args.load_from 72 | detect = Detect(cfg) 73 | paths = get_img_paths(args.img) 74 | for p in tqdm(paths): 75 | detect.run(p) 76 | 77 | if __name__ == '__main__': 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('config', help='The path of config file') 80 | parser.add_argument('--img', help='The path of the img (img file or img_folder), for example: data/*.png') 81 | parser.add_argument('--show', action='store_true', 82 | help='Whether to show the image') 83 | parser.add_argument('--savedir', type=str, default="work_dirs/clr/FLAMNet_test/vision", help='The root of save directory') 84 | parser.add_argument('--load_from', type=str, default='best.pth', help='The path of model') 85 | args = parser.parse_args() 86 | process(args) -------------------------------------------------------------------------------- /tools/generate_seg_tusimple.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import cv2 4 | 
import os 5 | import argparse 6 | 7 | TRAIN_SET = ['label_data_0313.json', 'label_data_0601.json'] 8 | VAL_SET = ['label_data_0531.json'] 9 | TRAIN_VAL_SET = TRAIN_SET + VAL_SET 10 | TEST_SET = ['test_label.json'] 11 | 12 | 13 | def gen_label_for_json(args, image_set): 14 | H, W = 720, 1280 15 | SEG_WIDTH = 30 16 | save_dir = args.savedir 17 | 18 | os.makedirs(os.path.join(args.root, save_dir, "list"), exist_ok=True) 19 | list_f = open( 20 | os.path.join(args.root, save_dir, "list", 21 | "{}_gt.txt".format(image_set)), "w") 22 | 23 | json_path = os.path.join(args.root, save_dir, 24 | "{}.json".format(image_set)) 25 | with open(json_path) as f: 26 | for line in f: 27 | label = json.loads(line) 28 | # ---------- clean and sort lanes ------------- 29 | lanes = [] 30 | _lanes = [] 31 | slope = [ 32 | ] # identify 0th, 1st, 2nd, 3rd, 4th, 5th lane through slope 33 | for i in range(len(label['lanes'])): 34 | l = [(x, y) 35 | for x, y in zip(label['lanes'][i], label['h_samples']) 36 | if x >= 0] 37 | if len(l) > 1: 38 | _lanes.append(l) 39 | slope.append( 40 | np.arctan2(l[-1][1] - l[0][1], l[0][0] - l[-1][0]) / 41 | np.pi * 180) 42 | _lanes = [_lanes[i] for i in np.argsort(slope)] 43 | slope = [slope[i] for i in np.argsort(slope)] 44 | 45 | idx = [None for i in range(6)] 46 | for i in range(len(slope)): 47 | if slope[i] <= 90: 48 | idx[2] = i 49 | idx[1] = i - 1 if i > 0 else None 50 | idx[0] = i - 2 if i > 1 else None 51 | else: 52 | idx[3] = i 53 | idx[4] = i + 1 if i + 1 < len(slope) else None 54 | idx[5] = i + 2 if i + 2 < len(slope) else None 55 | break 56 | for i in range(6): 57 | lanes.append([] if idx[i] is None else _lanes[idx[i]]) 58 | 59 | # --------------------------------------------- 60 | 61 | img_path = label['raw_file'] 62 | seg_img = np.zeros((H, W, 3)) 63 | list_str = [] # str to be written to list.txt 64 | for i in range(len(lanes)): 65 | coords = lanes[i] 66 | if len(coords) < 4: 67 | list_str.append('0') 68 | continue 69 | for j in range(len(coords) - 1): 70 | cv2.line(seg_img, coords[j], coords[j + 1], 71 | (i + 1, i + 1, i + 1), SEG_WIDTH // 2) 72 | list_str.append('1') 73 | 74 | seg_path = img_path.split("/") 75 | seg_path, img_name = os.path.join(args.root, save_dir, 76 | seg_path[1], 77 | seg_path[2]), seg_path[3] 78 | os.makedirs(seg_path, exist_ok=True) 79 | seg_path = os.path.join(seg_path, img_name[:-3] + "png") 80 | cv2.imwrite(seg_path, seg_img) 81 | 82 | seg_path = "/".join([ 83 | save_dir, *img_path.split("/")[1:3], img_name[:-3] + "png" 84 | ]) 85 | if seg_path[0] != '/': 86 | seg_path = '/' + seg_path 87 | if img_path[0] != '/': 88 | img_path = '/' + img_path 89 | list_str.insert(0, seg_path) 90 | list_str.insert(0, img_path) 91 | list_str = " ".join(list_str) + "\n" 92 | list_f.write(list_str) 93 | list_f.close()  # close the list file explicitly; it was opened outside a with-block 94 | 95 | 96 | def generate_json_file(args, save_dir, json_file, image_set):  # take args explicitly instead of relying on the module-level global 97 | with open(os.path.join(save_dir, json_file), "w") as outfile: 98 | for json_name in image_set: 99 | with open(os.path.join(args.root, json_name)) as infile: 100 | for line in infile: 101 | outfile.write(line) 102 | 103 | 104 | def generate_label(args): 105 | save_dir = os.path.join(args.root, args.savedir) 106 | os.makedirs(save_dir, exist_ok=True) 107 | generate_json_file(args, save_dir, "train_val.json", TRAIN_VAL_SET) 108 | generate_json_file(args, save_dir, "test.json", TEST_SET) 109 | 110 | print("generating train_val set...") 111 | gen_label_for_json(args, 'train_val') 112 | print("generating test set...") 113 | gen_label_for_json(args, 'test') 114 | 115 | 
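116 | # Example invocation (the dataset path is illustrative): python tools/generate_seg_tusimple.py --root /data/tusimple --savedir seg_label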
117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('--root', 120 | required=True, 121 | help='The root of the Tusimple dataset') 122 | parser.add_argument('--savedir', 123 | type=str, 124 | default='seg_label', 125 | help='The directory under --root where the generated segmentation labels are saved') 126 | args = parser.parse_args() 127 | 128 | generate_label(args) 129 | --------------------------------------------------------------------------------