├── tests ├── __init__.py ├── model │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── test_loss.cpython-37-PYTEST.pyc │ └── test_loss.py └── __pycache__ │ └── __init__.cpython-37.pyc ├── lanenet ├── __init__.py ├── model │ ├── __init__.py │ ├── __pycache__ │ │ ├── loss.cpython-37.pyc │ │ ├── blocks.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── BiseNet_v2.cpython-37.pyc │ │ └── BiseNet_v2_2.cpython-37.pyc │ ├── model-old.py │ ├── model.py │ ├── loss.py │ ├── decoders.py │ ├── encoders.py │ ├── BiseNet.py │ ├── BiseNet_v2.py │ ├── BiseNet_v2-bak.py │ ├── BiseNet_v2-2021-04-21.py │ ├── bisenetv2.py │ ├── BiseNet_v2_1.py │ └── BiseNet_v2_2.py ├── utils │ ├── __init__.py │ ├── random.jpg │ ├── random1.jpg │ ├── frame0270.jpg │ ├── frame0830.jpg │ ├── frame0974.png │ ├── average_meter.py │ ├── postprocess.py │ ├── cli_helper.py │ └── PicEnhanceUtils.py ├── dataloader │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── transformers.cpython-37.pyc │ ├── transformers.py │ ├── lmdb_data_loaders.py │ └── data_loaders.py ├── __pycache__ │ ├── config.cpython-37.pyc │ └── __init__.cpython-37.pyc ├── config.py ├── Tensor2Pic.py ├── GetFrame1.py ├── GetFrame.py ├── ConcatVideo.py ├── test.py ├── online_test.py ├── demo_test.py ├── online_test_video.py └── train.py ├── data ├── tusimple_test_image │ ├── 0.jpg │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ └── 4.jpg └── training_data_example │ ├── image │ ├── 0000.png │ ├── 0001.png │ ├── 0002.png │ ├── 0003.png │ ├── 0004.png │ └── 0005.png │ ├── gt_image_binary │ ├── 0000.png │ ├── 0001.png │ ├── 0002.png │ ├── 0003.png │ ├── 0004.png │ └── 0005.png │ ├── gt_image_instance │ ├── 0000.png │ ├── 0001.png │ ├── 0002.png │ ├── 0003.png │ ├── 0004.png │ └── 0005.png │ ├── val.txt │ └── train.txt ├── setup.py ├── README.md └── scripts ├── MultiWeightTrain.py ├── Convert2LMDB.py └── generateMGdataset_v3.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lanenet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lanenet/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lanenet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lanenet/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lanenet/utils/random.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/utils/random.jpg -------------------------------------------------------------------------------- /lanenet/utils/random1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/utils/random1.jpg 
-------------------------------------------------------------------------------- /lanenet/utils/frame0270.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/utils/frame0270.jpg -------------------------------------------------------------------------------- /lanenet/utils/frame0830.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/utils/frame0830.jpg -------------------------------------------------------------------------------- /lanenet/utils/frame0974.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/utils/frame0974.png -------------------------------------------------------------------------------- /data/tusimple_test_image/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/tusimple_test_image/0.jpg -------------------------------------------------------------------------------- /data/tusimple_test_image/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/tusimple_test_image/1.jpg -------------------------------------------------------------------------------- /data/tusimple_test_image/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/tusimple_test_image/2.jpg -------------------------------------------------------------------------------- /data/tusimple_test_image/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/tusimple_test_image/3.jpg -------------------------------------------------------------------------------- /data/tusimple_test_image/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/tusimple_test_image/4.jpg -------------------------------------------------------------------------------- /data/training_data_example/image/0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/image/0000.png -------------------------------------------------------------------------------- /data/training_data_example/image/0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/image/0001.png -------------------------------------------------------------------------------- /data/training_data_example/image/0002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/image/0002.png -------------------------------------------------------------------------------- /data/training_data_example/image/0003.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/image/0003.png -------------------------------------------------------------------------------- /data/training_data_example/image/0004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/image/0004.png -------------------------------------------------------------------------------- /data/training_data_example/image/0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/image/0005.png -------------------------------------------------------------------------------- /lanenet/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /tests/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/tests/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lanenet/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lanenet/model/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/model/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /lanenet/model/__pycache__/blocks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/model/__pycache__/blocks.cpython-37.pyc -------------------------------------------------------------------------------- /tests/model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/tests/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lanenet/model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/training_data_example/gt_image_binary/0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_binary/0000.png -------------------------------------------------------------------------------- /data/training_data_example/gt_image_binary/0001.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_binary/0001.png -------------------------------------------------------------------------------- /data/training_data_example/gt_image_binary/0002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_binary/0002.png -------------------------------------------------------------------------------- /data/training_data_example/gt_image_binary/0003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_binary/0003.png -------------------------------------------------------------------------------- /data/training_data_example/gt_image_binary/0004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_binary/0004.png -------------------------------------------------------------------------------- /data/training_data_example/gt_image_binary/0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_binary/0005.png -------------------------------------------------------------------------------- /lanenet/model/__pycache__/BiseNet_v2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/model/__pycache__/BiseNet_v2.cpython-37.pyc -------------------------------------------------------------------------------- /data/training_data_example/gt_image_instance/0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_instance/0000.png -------------------------------------------------------------------------------- /data/training_data_example/gt_image_instance/0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_instance/0001.png -------------------------------------------------------------------------------- /data/training_data_example/gt_image_instance/0002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_instance/0002.png -------------------------------------------------------------------------------- /data/training_data_example/gt_image_instance/0003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_instance/0003.png -------------------------------------------------------------------------------- /data/training_data_example/gt_image_instance/0004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_instance/0004.png 
-------------------------------------------------------------------------------- /data/training_data_example/gt_image_instance/0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/data/training_data_example/gt_image_instance/0005.png -------------------------------------------------------------------------------- /lanenet/dataloader/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/dataloader/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /lanenet/model/__pycache__/BiseNet_v2_2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/model/__pycache__/BiseNet_v2_2.cpython-37.pyc -------------------------------------------------------------------------------- /tests/model/__pycache__/test_loss.cpython-37-PYTEST.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/tests/model/__pycache__/test_loss.cpython-37-PYTEST.pyc -------------------------------------------------------------------------------- /lanenet/config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | n_labels=6 3 | no_of_instances=6 4 | num_classes=2 5 | gpu_no='cuda:0' 6 | device_ids = [0, 1,2,3,4,5,6] 7 | # 1-train,2-val 8 | is_training=1 -------------------------------------------------------------------------------- /lanenet/dataloader/__pycache__/transformers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mengpengfei/pytorch-lanenet/HEAD/lanenet/dataloader/__pycache__/transformers.cpython-37.pyc -------------------------------------------------------------------------------- /lanenet/Tensor2Pic.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import torch 3 | from torchvision import transforms 4 | 5 | toPIL = transforms.ToPILImage() #这个函数可以将张量转为PIL图片,由小数转为0-255之间的像素值 6 | img = torch.randn(3,128,64) 7 | pic = toPIL(img) 8 | pic.save('./random1.jpg') -------------------------------------------------------------------------------- /data/training_data_example/val.txt: -------------------------------------------------------------------------------- 1 | ./data/training_data_example/image/0004.png ./data/training_data_example/gt_image_binary/0004.png ./data/training_data_example/gt_image_instance/0004.png 2 | ./data/training_data_example/image/0005.png ./data/training_data_example/gt_image_binary/0005.png ./data/training_data_example/gt_image_instance/0005.png -------------------------------------------------------------------------------- /lanenet/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(): 2 | """Computes and stores the average and current value 3 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 4 | """ 5 | 6 | def __init__(self): 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, 
n=1): 16 | self.val = val 17 | self.sum += val * n 18 | self.count += n 19 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | package_name = 'lanenet' 4 | 5 | setup( 6 | name=package_name, 7 | version='0.1.0', 8 | packages=find_packages(), 9 | py_modules=[], 10 | zip_safe=True, 11 | install_requires=[ 12 | 'setuptools', 13 | 'torch', 14 | 'torchvision', 15 | 'opencv-python', 16 | 'numpy', 17 | 'tqdm' 18 | ], 19 | author='Andreas Klintberg', 20 | maintainer='Andreas Klintberg', 21 | description='Lanenet implementation in PyTorch', 22 | license='Apache License, Version 2.0', 23 | test_suite='pytest' 24 | ) 25 | -------------------------------------------------------------------------------- /data/training_data_example/train.txt: -------------------------------------------------------------------------------- 1 | ./data/training_data_example/image/0000.png ./data/training_data_example/gt_image_binary/0000.png ./data/training_data_example/gt_image_instance/0000.png 2 | ./data/training_data_example/image/0001.png ./data/training_data_example/gt_image_binary/0001.png ./data/training_data_example/gt_image_instance/0001.png 3 | ./data/training_data_example/image/0002.png ./data/training_data_example/gt_image_binary/0002.png ./data/training_data_example/gt_image_instance/0002.png 4 | ./data/training_data_example/image/0003.png ./data/training_data_example/gt_image_binary/0000.png ./data/training_data_example/gt_image_instance/0003.png 5 | -------------------------------------------------------------------------------- /lanenet/dataloader/transformers.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | from skimage.transform import resize 4 | 5 | class Rescale(): 6 | """Rescale the image in a sample to a given size. 7 | 8 | Args: 9 | output_size (width, height) (tuple): Desired output size (width, height). Output is 10 | matched to output_size. 
11 | """ 12 | 13 | def __init__(self, output_size): 14 | assert isinstance(output_size, (tuple)) 15 | self.output_size = output_size 16 | 17 | def __call__(self, sample): 18 | #sample = resize(sample, self.output_size) 19 | sample = cv2.resize(sample, dsize=self.output_size, interpolation=cv2.INTER_NEAREST) 20 | 21 | return sample 22 | -------------------------------------------------------------------------------- /lanenet/GetFrame1.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import time 4 | #要提取视频的文件名,隐藏后缀 5 | sourceFileName='ch0_20200318140335_20200318140435' 6 | #在这里把后缀接上 7 | video_path = os.path.join("/workspace/lanenet-lane-detection-11000", sourceFileName+'.mp4') 8 | times=0 9 | #提取视频的频率,每1帧提取一个 10 | frameFrequency=10 11 | print(video_path) 12 | camera = cv2.VideoCapture(video_path) 13 | 14 | file_dir='./vedio/ch0_20200318140335_20200318140435/' 15 | 16 | while True: 17 | times=times+1 18 | res, image = camera.read() 19 | if not res: 20 | print('not res , not image') 21 | break 22 | if times%frameFrequency==0: 23 | cv2.imwrite(file_dir + str(times)+'.jpg', image) 24 | # image=cv2.resize(image,(1280,720)) #将图片转换为1280*720 25 | print('-----------------------end---------------------------------') 26 | camera.release() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-lanenet 2 | 效果很好的lanenet网络,主干网络基于bisenetv2并对主干网络做了修改,效果远好于bisnetv2 3 | 可直接训练自己的数据应用于生产 4 | 5 | inspired by https://github.com/MaybeShewill-CV/lanenet-lane-detection 6 | 7 | Using Bisenetv2 as Encoder. 8 | 9 | 使用步骤: 10 | 11 | 1、安装pytorch环境,pytorch官网有说明,推荐使用docker 12 | 13 | 2、生成样本的train.txt和val.txt文件 14 | 15 | 文件内容: 16 | 17 | 原始图片 语义分割图 实例分割图 18 | 19 | 3、修改 script目录下的Convert2LMDB.py main函数中的txt文件路径和生成的lmdb文件名 20 | 21 | 4、修改lanenet/train.py中的train_dataset_file和val_dataset_file为自己生成的lmdb文件路径 22 | 23 | 按照自己的需要lanenet/config.py文件中的配置参数 24 | 25 | 5、执行 python setup.py install 26 | 27 | 6、执行python lanenet/train.py --lr 0.001 --val True --bs 16 --save ./checkpoints --w1 0.25 --w2 0.25 --w3 0.25 --w4 0.25 --epochs 200 28 | 29 | 开始训练。 30 | 31 | 7、训练完成后 修改 lanenet/online_test_video.py 中的模型路径和视频路径,测试视频‘ 32 | 33 | 34 | 35 | TODO: 36 | 1、将代码封装,抽取配置,整理代码 37 | 38 | -------------------------------------------------------------------------------- /scripts/MultiWeightTrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | if __name__ == '__main__': 3 | for i in range(1,10): 4 | for j in range(1,10): 5 | for m in range(1,10): 6 | for n in range(1,10): 7 | if i+j+m+n==10: 8 | print('python lanenet/train.py --dataset /workspace/mogo_data/index ' 9 | '--lr 0.001 --val True --bs 16 --save ./checkpoints --w1 {4} --w2 {5} --w3 {6} --w4 {7} --epochs 30 2>&1 ' 10 | '> logs/mogodata{0}_{1}_{2}_{3}.log'.format(i/10,j/10,m/10,n/10,i/10,j/10,m/10,n/10)) 11 | # print("{0}_{1}_{2}_{3}".format(i/10,j/10,m/10,n/10)) 12 | # os.popen('python lanenet/train.py --dataset /workspace/mogo_data/index ' 13 | # '--lr 0.001 --val True --bs 16 --save ./checkpoints --w1 {4} --w2 {5} --w3 {6} --w4 {7} 2>&1 ' 14 | # '> mogodata{0}_{1}_{2}_{3}.log'.format(i/10,j/10,m/10,n/10,i/10,j/10,m/10,n/10)) 15 | 16 | -------------------------------------------------------------------------------- /lanenet/GetFrame.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 
/scripts/MultiWeightTrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | if __name__ == '__main__': 3 | for i in range(1,10): 4 | for j in range(1,10): 5 | for m in range(1,10): 6 | for n in range(1,10): 7 | if i+j+m+n==10: 8 | print('python lanenet/train.py --dataset /workspace/mogo_data/index ' 9 | '--lr 0.001 --val True --bs 16 --save ./checkpoints --w1 {4} --w2 {5} --w3 {6} --w4 {7} --epochs 30 2>&1 ' 10 | '> logs/mogodata{0}_{1}_{2}_{3}.log'.format(i/10,j/10,m/10,n/10,i/10,j/10,m/10,n/10)) 11 | # print("{0}_{1}_{2}_{3}".format(i/10,j/10,m/10,n/10)) 12 | # os.popen('python lanenet/train.py --dataset /workspace/mogo_data/index ' 13 | # '--lr 0.001 --val True --bs 16 --save ./checkpoints --w1 {4} --w2 {5} --w3 {6} --w4 {7} 2>&1 ' 14 | # '> mogodata{0}_{1}_{2}_{3}.log'.format(i/10,j/10,m/10,n/10,i/10,j/10,m/10,n/10)) 15 | 16 | -------------------------------------------------------------------------------- /lanenet/GetFrame.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import time 4 | # name of the video file to extract frames from, without the extension 5 | sourceFileName='ch0_20200318140335_20200318140435' 6 | # the extension is appended here 7 | video_path = os.path.join("/workspace/lanenet-lane-detection-11000", sourceFileName+'.mp4') 8 | times=0 9 | # sampling interval: keep one frame every frameFrequency frames 10 | frameFrequency=1 11 | print(video_path) 12 | camera = cv2.VideoCapture(video_path) 13 | 14 | file_dir='./vedio/ch0_20200318140335_20200318140435/' 15 | video=cv2.VideoWriter('./vedio/test.avi',cv2.VideoWriter_fourcc(*'MJPG'),25,(1280,720)) # output video path and codec, 25 fps, 1280*720 resolution 16 | 17 | while True: 18 | times=times+1 19 | res, image = camera.read() 20 | print(res) 21 | print(image.shape) 22 | if not res: 23 | print('not res , not image') 24 | break 25 | if times%frameFrequency==0: 26 | # cv2.imwrite(outPutDirName + str(times)+'.jpg', image) 27 | image=cv2.resize(image,(1280,720)) # resize the frame to 1280*720 28 | video.write(image) # write the frame into the output video 29 | print('-----------------------end---------------------------------') 30 | camera.release() 31 | video.release() -------------------------------------------------------------------------------- /lanenet/ConcatVideo.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | INPUT_FILE1 = 'D:/yum/tmcdata/test123/test.avi' 5 | INPUT_FILE2 = 'D:/yum/tmcdata/test123/test-ch.avi' 6 | OUTPUT_FILE = 'D:/yum/tmcdata/test123/merge.avi' 7 | 8 | reader1 = cv2.VideoCapture(INPUT_FILE1) 9 | reader2 = cv2.VideoCapture(INPUT_FILE2) 10 | width = int(reader1.get(cv2.CAP_PROP_FRAME_WIDTH)) 11 | height = int(reader1.get(cv2.CAP_PROP_FRAME_HEIGHT)) 12 | writer = cv2.VideoWriter(OUTPUT_FILE, 13 | cv2.VideoWriter_fourcc(*'MJPG'),25, #fps 14 | (width, height//2)) # resolution 15 | 16 | print(reader1.isOpened()) 17 | print(reader2.isOpened()) 18 | have_more_frame = True 19 | c = 0 20 | while have_more_frame: 21 | have_more_frame, frame1 = reader1.read() 22 | _, frame2 = reader2.read() 23 | try: 24 | frame1 = cv2.resize(frame1, (width//2, height//2)) 25 | frame2 = cv2.resize(frame2, (width//2, height//2)) 26 | img = np.hstack((frame1, frame2)) 27 | cv2.waitKey(1) 28 | writer.write(img) 29 | c += 1 30 | except: 31 | pass 32 | print(str(c) + ' is ok') 33 | 34 | 35 | writer.release() 36 | reader1.release() 37 | reader2.release() 38 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /lanenet/utils/postprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.cluster import MeanShift, estimate_bandwidth 3 | 4 | 5 | def embedding_post_process(embedding, bin_seg, band_width=1.5, max_num_lane=4): 6 | """ 7 | First use mean shift to find dense cluster center. 8 | 9 | Arguments: 10 | ---------- 11 | embedding: numpy [H, W, embed_dim] 12 | bin_seg: numpy [H, W], each pixel is 0 or 1, 0 for background pixel 13 | band_width: bandwidth of the mean shift kernel used to group lane pixels in embedding space 14 | 15 | Return: 16 | --------- 17 | cluster_result: numpy [H, W], index of different lanes on each pixel 18 | """ 19 | cluster_result = np.zeros(bin_seg.shape, dtype=np.int32) 20 | 21 | cluster_list = embedding[bin_seg>0] 22 | if len(cluster_list)==0: 23 | return cluster_result 24 | 25 | mean_shift = MeanShift(bandwidth=band_width, bin_seeding=True, n_jobs=-1) 26 | mean_shift.fit(cluster_list) 27 | 28 | labels = mean_shift.labels_ 29 | cluster_result[bin_seg>0] = labels + 1 30 | 31 | cluster_result[cluster_result > max_num_lane] = 0 32 | for idx in np.unique(cluster_result): 33 | if len(cluster_result[cluster_result==idx]) < 15: 34 | cluster_result[cluster_result==idx] = 0 35 | 36 | return cluster_result --------------------------------------------------------------------------------
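embedding_post_process above is the clustering step consumed by demo_test.py later in this listing; a minimal sketch of the expected call, assuming a hypothetical net_output dict returned by LaneNet for a single input image:

    import numpy as np
    # instance embedding of sample 0, moved to channel-last [H, W, embed_dim] layout
    embedding = net_output['instance_seg_logits'][0].detach().cpu().numpy().transpose(1, 2, 0)
    # binary lane mask of sample 0, [H, W] with 0 for background
    bin_seg = net_output['binary_seg_pred'][0][0].detach().cpu().numpy()
    lane_mask = embedding_post_process(embedding, bin_seg, band_width=3, max_num_lane=6)
    print(np.unique(lane_mask))  # 0 = background, 1..max_num_lane = lane ids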
/lanenet/utils/cli_helper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--dataset", help="Dataset path") 6 | parser.add_argument("--save", required=False, help="Directory to save model checkpoint", default="./checkpoints") 7 | parser.add_argument("--epochs", required=False, type=int, help="Training epochs", default=100000) 8 | parser.add_argument("--bs", required=False, type=int, help="Batch size", default=32) 9 | parser.add_argument("--val", required=False, type=bool, help="Use validation", default=False) 10 | parser.add_argument("--lr", required=False, type=float, help="Learning rate", default=0.0005) 11 | parser.add_argument("--pretrained", required=False, default=None, help="pretrained model path") 12 | parser.add_argument("--image", default="./output", help="output image folder") 13 | parser.add_argument("--net", help="backbone network") 14 | parser.add_argument("--json", help="post processing json") 15 | parser.add_argument("--w1", help="weight of the binary segmentation loss on binary_seg_logits") 16 | parser.add_argument("--w2", help="weight of the binary segmentation loss on bsb_out1") 17 | parser.add_argument("--w3", help="weight of the binary segmentation loss on bsb_out2") 18 | parser.add_argument("--w4", help="weight of the averaged binary segmentation losses on sg_out1/3/4/5") 19 | return parser.parse_args() 20 | -------------------------------------------------------------------------------- /tests/model/test_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from lanenet.model import HNetLoss 4 | 5 | 6 | def test_hnet(): 7 | gt_labels = torch.tensor([[[1.0, 1.0, 1.0], [2.0, 2.0, 1.0], [3.0, 3.0, 1.0]], 8 | [[1.0, 1.0, 1.0], [2.0, 2.0, 1.0], [3.0, 3.0, 1.0]]], 9 | dtype=torch.float32).view(6,3) 10 | transformation_coffecient = torch.tensor([0.58348501, -0.79861236, 2.30343866, 11 | -0.09976104, -1.22268307, 2.43086767], 12 | dtype=torch.float32) 13 | 14 | # import numpy as np 15 | # c_val = [0.58348501, -0.79861236, 2.30343866, 16 | # -0.09976104, -1.22268307, 2.43086767] 17 | # R = np.zeros([3, 3], np.float32) 18 | # R[0, 0] = c_val[0] 19 | # R[0, 1] = c_val[1] 20 | # R[0, 2] = c_val[2] 21 | # R[1, 1] = c_val[3] 22 | # R[1, 2] = c_val[4] 23 | # R[2, 1] = c_val[5] 24 | # R[2, 2] = 1 25 | # 26 | # print(np.mat(R).I) 27 | hnet_loss = HNetLoss(gt_labels, transformation_coffecient, 'loss') 28 | hnet_inference = HNetLoss(gt_labels, transformation_coffecient, 'inference') 29 | 30 | _loss = hnet_loss._hnet_loss()
31 | 32 | _pred = hnet_inference._hnet_transformation() 33 | 34 | print("loss: ", _loss) 35 | print("pred: ", _pred) 36 | -------------------------------------------------------------------------------- /lanenet/test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import sys 4 | 5 | import torch 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import cv2 9 | from tqdm import tqdm 10 | from lanenet import config 11 | 12 | from lanenet.model.model import compute_loss 13 | from lanenet.utils.average_meter import AverageMeter 14 | 15 | DEVICE = torch.device(config.gpu_no if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | def test(val_loader, model, epoch): 19 | model.eval() 20 | step = 0 21 | batch_time = AverageMeter() 22 | total_losses = AverageMeter() 23 | binary_losses = AverageMeter() 24 | instance_losses = AverageMeter() 25 | mean_iou = AverageMeter() 26 | end = time.time() 27 | val_img_list = [] 28 | # val_img_md5 = open(os.path.join(im_path, "val_" + str(epoch + 1) + ".txt"), "w") 29 | for batch_idx, batch in enumerate(iter(val_loader)): 30 | step += 1 31 | # image_data = Variable(input_data["input_tensor"]).to(DEVICE) 32 | # instance_label = Variable(input_data["instance_label"]).to(DEVICE) 33 | # binary_label = Variable(input_data["binary_label"]).to(DEVICE) 34 | 35 | image_data = Variable(batch[0]).type(torch.FloatTensor).to(DEVICE) 36 | binary_label = Variable(batch[1]).type(torch.LongTensor).to(DEVICE) 37 | instance_label = Variable(batch[2]).type(torch.FloatTensor).to(DEVICE) 38 | 39 | 40 | # output process 41 | net_output = model(image_data) 42 | net_output 43 | total_loss, binary_loss, instance_loss, out, val_iou = compute_loss(net_output, binary_label, instance_label) 44 | total_losses.update(total_loss.item(), image_data.size()[0]) 45 | binary_losses.update(binary_loss.item(), image_data.size()[0]) 46 | instance_losses.update(instance_loss.item(), image_data.size()[0]) 47 | mean_iou.update(val_iou, image_data.size()[0]) 48 | 49 | # if step % 100 == 0: 50 | # val_img_list.append( 51 | # compose_img(image_data, out, binary_label, net_output["instance_seg_logits"], instance_label, 0)) 52 | # val_img_md5.write(input_data["img_name"][0] + "\n") 53 | # lane_cluster_and_draw(image_data, net_output["binary_seg_pred"], net_output["instance_seg_logits"], input_data["o_size"], input_data["img_name"], json_path) 54 | batch_time.update(time.time() - end) 55 | end = time.time() 56 | 57 | # print( 58 | # "Epoch {ep} Validation Report | ETA: {et:.2f}|Total:{tot:.5f}|Binary:{bin:.5f}|Instance:{ins:.5f}|IoU:{iou:.5f}".format( 59 | # ep=epoch + 1, 60 | # et=batch_time.val, 61 | # tot=total_losses.avg, 62 | # bin=binary_losses.avg, 63 | # ins=instance_losses.avg, 64 | # iou=mean_iou.avg, 65 | # )) 66 | # sys.stdout.flush() 67 | # val_img = np.concatenate(val_img_list, axis=1) 68 | # cv2.imwrite(os.path.join(im_path, "val_" + str(epoch + 1) + ".png"), val_img) 69 | # val_img_md5.close() 70 | return mean_iou.avg 71 | -------------------------------------------------------------------------------- /lanenet/online_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from PIL import Image 4 | import numpy as np 5 | from lanenet.model.model import LaneNet, compute_loss 6 | import cv2 7 | # import torch.nn.functional as F 8 | from torchvision.transforms import functional as F 9 | import time 10 | from torchvision 
import transforms 11 | from lanenet.dataloader.transformers import Rescale 12 | import os 13 | DEVICE = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu') 14 | os.environ["CUDA_VISIBLE_DEVICES"] = '7' 15 | 16 | if __name__ == '__main__': 17 | import os 18 | os.environ["CUDA_VISIBLE_DEVICES"] = '7' 19 | model_path = './checkpoints-combine/773_checkpoint.pth' 20 | gpu = True 21 | if not torch.cuda.is_available(): 22 | gpu = False 23 | model = LaneNet() 24 | model.to(DEVICE) 25 | # if gpu: 26 | # model = model.cuda() 27 | # print('loading pretrained model from %s' % model_path) 28 | if gpu: 29 | model.load_state_dict(torch.load(model_path)) 30 | else: 31 | model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) 32 | model.eval() 33 | 34 | #if gpu: 35 | # model.load_state_dict(torch.load(model_path)) 36 | #else: 37 | # model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) 38 | 39 | picPath="/workspace/mogo_data/index/test" 40 | video=cv2.VideoWriter('./vedio/test1.avi',cv2.VideoWriter_fourcc(*'MJPG'),25,(1280, 720)) 41 | for file in os.listdir(picPath): 42 | #file="frame0270.jpg" 43 | toPIL = transforms.ToPILImage() 44 | #transform=transforms.Compose([Rescale((720, 1280))]) 45 | file_path = os.path.join(picPath, file) 46 | imgori = cv2.imread(file_path, cv2.IMREAD_COLOR) 47 | 48 | #img0=np.asarray(F.crop(toPIL(imgori[:,:,[2,1,0]]), imgori.shape[0]-240,0, 240, imgori.shape[1])) 49 | #img0 = transform(imgori) 50 | img0=imgori 51 | img = img0.reshape(img0.shape[2], img0.shape[0], img0.shape[1]) 52 | img = np.expand_dims(img,0) 53 | print(img.shape) 54 | imgdata=Variable(torch.from_numpy(img)).type(torch.FloatTensor).to(DEVICE) 55 | output=model(imgdata) 56 | binary_seg_pred=output["binary_seg_pred"] 57 | binary_seg_pred = binary_seg_pred.squeeze(0) 58 | binary_seg_pred1=binary_seg_pred.to(torch.float32).cpu() 59 | pic = toPIL(binary_seg_pred1) 60 | imgx = cv2.cvtColor(np.asarray(pic),cv2.COLOR_RGB2BGR) 61 | imgx[np.where((imgx!=[0, 0, 0]).all(axis=2))] = [255,255,255] 62 | src7 = cv2.addWeighted(img0,0.8,imgx,1,0) 63 | final_img=src7 64 | # final_img=cv2.resize(src7,(1280, 720)) 65 | final_img[np.where((final_img==[255, 255, 255]).all(axis=2))] = [0,0,255] 66 | #cv2.imwrite('./random.jpg',final_img) 67 | video.write(final_img) 68 | #break 69 | print('-----------------------runing---------------------------------') 70 | print('-----------------------end---------------------------------') 71 | video.release() 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /lanenet/demo_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import cv2 8 | from torchvision import transforms 9 | from lanenet.dataloader.transformers import Rescale 10 | from lanenet.model.model import LaneNet 11 | from lanenet.utils.postprocess import embedding_post_process 12 | import torch.nn as nn 13 | import os 14 | DEVICE = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu') 15 | os.environ["CUDA_VISIBLE_DEVICES"] = '7' 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--img_path", '-i', type=str, default="demo/demo.jpg", help="Path to demo img") 21 | parser.add_argument("--model_path", '-w', type=str, help="Path to model weights") 22 | # parser.add_argument("--band_width", '-b', 
type=float, default=1.5, help="Value of delta_v") 23 | # parser.add_argument("--visualize", '-v', action="store_true", default=False, help="Visualize the result") 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def main(): 29 | args = parse_args() 30 | img_path = args.img_path 31 | model_path = args.model_path 32 | 33 | # global best_epoch 34 | # global args 35 | gpu = True 36 | if not torch.cuda.is_available(): 37 | gpu = False 38 | 39 | model = LaneNet() 40 | model.to(DEVICE) 41 | 42 | if gpu: 43 | model.load_state_dict(torch.load(model_path)) 44 | else: 45 | model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) 46 | 47 | model.eval() 48 | 49 | transform=transforms.Compose([Rescale((1280, 720))]) 50 | imgori = cv2.imread(img_path, cv2.IMREAD_COLOR) 51 | imgori=transform(imgori) 52 | toPIL = transforms.ToPILImage() 53 | img = np.asarray(toPIL(imgori[:,:,[2,1,0]])) 54 | img=np.transpose(img,(2,0,1)) 55 | img = np.expand_dims(img,0) 56 | print(img.shape) 57 | imgdata=Variable(torch.from_numpy(img)).type(torch.FloatTensor).to(DEVICE) 58 | print(imgdata.size()) 59 | output=model(imgdata) 60 | 61 | embedding = output['instance_seg_logits'] 62 | embedding = embedding.detach().cpu().numpy() 63 | embedding = np.transpose(embedding[0], (1, 2, 0)) 64 | 65 | bin_seg_pred=output["binary_seg_pred"][0][0].detach().cpu().numpy() 66 | 67 | img = cv2.cvtColor(imgori, cv2.COLOR_RGB2BGR) 68 | seg_img = np.zeros_like(img) 69 | 70 | lane_seg_img = embedding_post_process(embedding, bin_seg_pred, band_width=3, max_num_lane=6) 71 | color = np.array([ 72 | [255, 125, 0], 73 | [0, 255, 0], 74 | [0, 0, 255], 75 | [0, 255, 255], 76 | [255, 0, 0], 77 | [255, 255, 0]], dtype='uint8') 78 | 79 | for i, lane_idx in enumerate(np.unique(lane_seg_img)): 80 | if lane_idx==0: 81 | continue 82 | seg_img[lane_seg_img == lane_idx] = color[i-1] 83 | img = cv2.addWeighted(src1=seg_img, alpha=0.8, src2=img, beta=1., gamma=0.) 
84 | 85 | cv2.imwrite("demo/demo_result.jpg", img) 86 | 87 | # if args.visualize: 88 | if True: 89 | cv2.imshow("", img) 90 | cv2.waitKey(0) 91 | cv2.destroyAllWindows() 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /lanenet/model/model-old.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | LaneNet model 4 | https://arxiv.org/pdf/1807.01726.pdf 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from lanenet.model.loss import DiscriminativeLoss 11 | from lanenet.model.encoders import VGGEncoder 12 | from lanenet.model.decoders import ESPNetDecoder, FCNDecoder 13 | from lanenet import config 14 | 15 | DEVICE = torch.device(config.gpu_no if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | class LaneNet1(nn.Module): 19 | def __init__(self, arch="VGG"): 20 | super(LaneNet1, self).__init__() 21 | # no of instances for segmentation 22 | self.no_of_instances = 5 23 | encode_num_blocks = 5 24 | in_channels = [3, 64, 128, 256, 512] 25 | out_channels = in_channels[1:] + [512] 26 | self._arch = arch 27 | if self._arch == 'VGG': 28 | self._encoder = VGGEncoder(encode_num_blocks, in_channels, out_channels) 29 | self._encoder.to(DEVICE) 30 | 31 | decode_layers = ["pool5", "pool4", "pool3"] 32 | decode_channels = out_channels[:-len(decode_layers) - 1:-1] 33 | decode_last_stride = 8 34 | # self._decoder = ESPNetDecoder() 35 | self._decoder = FCNDecoder(decode_layers, decode_channels, decode_last_stride) 36 | self._decoder.to(DEVICE) 37 | elif self._arch == 'ESPNet': 38 | raise NotImplementedError 39 | elif self._arch == 'ENNet': 40 | raise NotImplementedError 41 | 42 | self._pix_layer = nn.Conv2d(in_channels=64, out_channels=self.no_of_instances, kernel_size=1, bias=False).to( 43 | DEVICE) 44 | self.relu = nn.ReLU().to(DEVICE) 45 | 46 | def forward(self, input_tensor): 47 | encode_ret = self._encoder(input_tensor) 48 | decode_ret = self._decoder(encode_ret) 49 | 50 | decode_logits = decode_ret['logits'] 51 | 52 | decode_logits = decode_logits.to(DEVICE) 53 | 54 | binary_seg_ret = torch.argmax(F.softmax(decode_logits, dim=1), dim=1, keepdim=True) 55 | 56 | decode_deconv = decode_ret['deconv'] 57 | pix_embedding = self.relu(self._pix_layer(decode_deconv)) 58 | 59 | return { 60 | 'instance_seg_logits': pix_embedding, 61 | 'binary_seg_pred': binary_seg_ret, 62 | 'binary_seg_logits': decode_logits 63 | } 64 | if __name__ == '__main__': 65 | import os 66 | os.environ["CUDA_VISIBLE_DEVICES"] = '5' 67 | input = torch.rand(1, 3, 256, 512).cuda() 68 | model = LaneNet1().cuda() 69 | model.eval() 70 | print(model) 71 | output = model(input) 72 | binary_seg_pred=output["binary_seg_pred"].squeeze(0) 73 | binary_seg_pred=binary_seg_pred.squeeze(0) 74 | 75 | instance_seg_logits=output["instance_seg_logits"].squeeze(0) 76 | instance_seg_logits=instance_seg_logits.permute(1, 2, 0) 77 | 78 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 79 | print('BiSeNet_v2', binary_seg_pred.size()) 80 | print('BiSeNet_v2', instance_seg_logits.size()) 81 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 82 | 83 | def compute_loss(net_output, binary_label, instance_label): 84 | k_binary = 0.7 85 | k_instance = 0.3 86 | k_dist = 1.0 87 | 88 | ce_loss_fn = nn.CrossEntropyLoss() 89 | binary_seg_logits = net_output["binary_seg_logits"] 90 | binary_loss = ce_loss_fn(binary_seg_logits, 
binary_label) 91 | 92 | pix_embedding = net_output["instance_seg_logits"] 93 | ds_loss_fn = DiscriminativeLoss(0.5, 1.5, 1.0, 1.0, 0.001) 94 | var_loss, dist_loss, reg_loss = ds_loss_fn(pix_embedding, instance_label) 95 | binary_loss = binary_loss * k_binary 96 | instance_loss = var_loss * k_instance 97 | dist_loss = dist_loss * k_dist 98 | total_loss = binary_loss + instance_loss + dist_loss 99 | out = net_output["binary_seg_pred"] 100 | iou = 0 101 | batch_size = out.size()[0] 102 | for i in range(batch_size): 103 | PR = out[i].squeeze(0).nonzero().size()[0] 104 | GT = binary_label[i].nonzero().size()[0] 105 | TP = (out[i].squeeze(0) * binary_label[i]).nonzero().size()[0] 106 | union = PR + GT - TP 107 | iou += TP / union 108 | iou = iou / batch_size 109 | return total_loss, binary_loss, instance_loss, out, iou 110 | -------------------------------------------------------------------------------- /lanenet/dataloader/lmdb_data_loaders.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | from torch.utils.data import Dataset, DataLoader 3 | from lanenet import config 4 | import numpy as np 5 | from torchvision import transforms 6 | from lanenet.utils import PicEnhanceUtils 7 | from PIL import Image 8 | 9 | class LaneDataSet(Dataset): 10 | 11 | def __init__(self, dataset, n_labels=config.n_labels, transform=None): 12 | env = lmdb.open(dataset, max_dbs=6, map_size=int(1024*1024*1024*8), readonly=True) 13 | # 创建对应的数据库 14 | self.gt_img_db = env.open_db("gt_img".encode()) 15 | self.gt_img_shape_db = env.open_db("gt_img_shape".encode()) 16 | 17 | self.gt_binary_img_db = env.open_db("gt_binary_img".encode()) 18 | self.gt_binary_img_shape_db = env.open_db("gt_binary_img_shape".encode()) 19 | 20 | self.gt_instance_img_db = env.open_db("gt_instance_img".encode()) 21 | self.gt_instance_img_shape_db = env.open_db("gt_instance_shape".encode()) 22 | self.n_labels = n_labels 23 | 24 | self.txn = env.begin() 25 | self._length = self.txn.stat(db=self.gt_img_db)["entries"] 26 | 27 | def _split_instance_gt(self, label_instance_img): 28 | # number of channels, number of unique pixel values, subtracting no label 29 | # adapted from here https://github.com/nyoki-mtl/pytorch-discriminative-loss/blob/master/src/dataset.py 30 | no_of_instances = self.n_labels 31 | ins = np.zeros((no_of_instances, label_instance_img.shape[0], label_instance_img.shape[1])) 32 | for _ch, label in enumerate(np.unique(label_instance_img)[1:]): 33 | ins[_ch, label_instance_img == label] = 1 34 | return ins 35 | 36 | def __getitem__(self, idx): 37 | idx = str(idx).encode() 38 | 39 | gt_img_buf = self.txn.get(idx, db=self.gt_img_db) 40 | gt_img_array = np.frombuffer(gt_img_buf, dtype=np.uint8) 41 | gt_img_list=str(self.txn.get(idx, db=self.gt_img_shape_db).decode()).replace(' ','').split(',') 42 | img=gt_img_array.reshape(int(gt_img_list[0]), int(gt_img_list[1]), int(gt_img_list[2])) 43 | 44 | 45 | gt_binary_img_buf = self.txn.get(idx, db=self.gt_binary_img_db) 46 | gt_binary_img_array = np.frombuffer(gt_binary_img_buf, dtype=np.uint8) 47 | gt_binary_img_list=str(self.txn.get(idx, db=self.gt_binary_img_shape_db).decode()).replace(' ','').split(',') 48 | label_img=gt_binary_img_array.reshape(int(gt_binary_img_list[0]), int(gt_binary_img_list[1]),int(gt_binary_img_list[2])) 49 | 50 | gt_instance_img_buf = self.txn.get(idx, db=self.gt_instance_img_db) 51 | gt_instance_img_array = np.frombuffer(gt_instance_img_buf, dtype=np.uint8) 52 | gt_instance_img_list=str(self.txn.get(idx, 
db=self.gt_instance_img_shape_db).decode()).replace(' ','').split(',') 53 | label_instance_img=gt_instance_img_array.reshape(int(gt_instance_img_list[0]), int(gt_instance_img_list[1])) 54 | 55 | toPil=transforms.ToPILImage() 56 | 57 | # img=toPil(img[:,:,[2,1,0]]) 58 | img=toPil(img[:,:,[2,1,0]]) 59 | 60 | label_instance_img=toPil(label_instance_img) 61 | label_img=toPil(label_img[:,:,[2,1,0]]) 62 | 63 | img=PicEnhanceUtils.random_color_augmentation(img) 64 | img,label_img,label_instance_img=PicEnhanceUtils.random_horizon_flip_batch_images(img,label_img,label_instance_img) 65 | img,label_img,label_instance_img=PicEnhanceUtils.random_crop(img,label_img,label_instance_img) 66 | 67 | resize=transforms.Resize((720,1280),interpolation=Image.NEAREST) 68 | img=resize(img) 69 | label_img=resize(label_img) 70 | label_instance_img=resize(label_instance_img) 71 | 72 | # if self.transform: 73 | img=np.asarray(img) 74 | label_img=np.asarray(label_img) 75 | label_instance_img=np.asarray(label_instance_img) 76 | label_instance_img = self._split_instance_gt(label_instance_img) 77 | 78 | # img = img.reshape(img.shape[2], img.shape[0], img.shape[1]) 79 | 80 | img=np.transpose(img,(2,0,1)) 81 | 82 | 83 | label_binary = np.zeros([label_img.shape[0], label_img.shape[1]], dtype=np.uint8) 84 | mask = np.where((label_img[:, :, :] != [0, 0, 0]).all(axis=2)) 85 | label_binary[mask] = 1 86 | 87 | # we could split the instance label here, each instance in one channel (basically a binary mask for each) 88 | return img, label_binary, label_instance_img 89 | 90 | def __len__(self): 91 | return self._length -------------------------------------------------------------------------------- /scripts/Convert2LMDB.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: UTF-8 -*- 3 | import lmdb 4 | import cv2 5 | import numpy as np 6 | from torchvision import transforms 7 | import random 8 | 9 | 10 | def get_image_label_list(data_txt): 11 | gt_imgs=[] 12 | gt_binary_imgs=[] 13 | gt_instance_imgs=[] 14 | lines=open(data_txt).readlines() 15 | new = [] # empty list used to collect the results 16 | for line in lines: 17 | temp1 = line.strip('\n') # strip the trailing newline '\n' from each line 18 | new.append(temp1) # append the stripped line to new 19 | random.shuffle(new) # shuffle the list 20 | for line in new: 21 | line_arr=line.rstrip('\n').split(' ') 22 | gt_imgs.append(line_arr[0]) 23 | gt_binary_imgs.append(line_arr[1]) 24 | gt_instance_imgs.append(line_arr[2]) 25 | return gt_imgs,gt_binary_imgs,gt_instance_imgs 26 | 27 | 28 | def img2lmdb(txt_path,data_name): 29 | # create the LMDB database file 30 | # env = lmdb.open(data_name) 31 | env = lmdb.open(data_name, max_dbs=6, map_size=int(1024*1024*1024*50)) 32 | 33 | # open the corresponding sub-databases 34 | gt_img = env.open_db("gt_img".encode()) 35 | gt_img_shape = env.open_db("gt_img_shape".encode()) 36 | 37 | gt_binary_img = env.open_db("gt_binary_img".encode()) 38 | gt_binary_img_shape = env.open_db("gt_binary_img_shape".encode()) 39 | 40 | gt_instance_img = env.open_db("gt_instance_img".encode()) 41 | gt_instance_shape = env.open_db("gt_instance_shape".encode()) 42 | # -----------------------val------------------------------ 43 | # val_gt_img = env.open_db("val_gt_img") 44 | # val_gt_img_shape = env.open_db("val_gt_img_shape") 45 | # 46 | # val_gt_binary_img = env.open_db("val_gt_binary_img") 47 | # val_gt_binary_img_shape = env.open_db("val_gt_binary_img_shape") 48 | # 49 | # val_gt_instance_img = env.open_db("val_gt_instance_img") 50 | # val_gt_instance_img_shape = env.open_db("val_gt_instance_img_shape") 51 | 52 | gt_imgs,gt_binary_imgs,gt_instance_imgs = get_image_label_list(txt_path) 53 | #print(gt_binary_imgs) 54 | # val_gt_imgs,val_gt_binary_imgs,val_gt_instance_imgs = get_image_label_list('val_txt_path') 55 | # write the image data into the LMDB 56 | with env.begin(write=True) as txn: 57 | for idx, path in enumerate(gt_imgs): 58 | print("{} {}".format(idx, path)) 59 | data = cv2.imread(path, cv2.IMREAD_COLOR) 60 | print(path) 61 | txn.put(str(idx).encode(), data, db=gt_img) 62 | txn.put(str(idx).encode(),"".join(str(data.shape)).replace('(','').replace(')','').encode(), db=gt_img_shape) 63 | 64 | for idx, path in enumerate(gt_binary_imgs): 65 | print("{} {}".format(idx, path)) 66 | data = cv2.imread(path, cv2.IMREAD_COLOR) 67 | txn.put(str(idx).encode(), data, db=gt_binary_img) 68 | txn.put(str(idx).encode(),"".join(str(data.shape)).replace('(','').replace(')','').encode(), db=gt_binary_img_shape) 69 | 70 | for idx, path in enumerate(gt_instance_imgs): 71 | print("{} {}".format(idx, path)) 72 | data = cv2.imread(path, cv2.IMREAD_UNCHANGED) 73 | txn.put(str(idx).encode(), data, db=gt_instance_img) 74 | txn.put(str(idx).encode(),"".join(str(data.shape)).replace('(','').replace(')','').encode(), db=gt_instance_shape) 75 | # txn.commit() 76 | env.close() 77 | if __name__ == '__main__': 78 | # get_image_label_list('D:/yum/tmcdata/test123/val.txt') 79 | img2lmdb('/workspace/mogo_data/index/train.txt','train') 80 | # env = lmdb.open('./val', max_dbs=6, map_size=int(1024*1024*1024*8), readonly=True) 81 | # open the corresponding sub-databases 82 | # gt_img = env.open_db("gt_img".encode()) 83 | # gt_img_shape = env.open_db("gt_img_shape".encode()) 84 | # 85 | # gt_binary_img = env.open_db("gt_binary_img".encode()) 86 | # gt_binary_img_shape = env.open_db("gt_binary_img_shape".encode()) 87 | # 88 | # gt_instance_img = env.open_db("gt_instance_img".encode()) 89 | # gt_instance_shape = env.open_db("gt_instance_shape".encode()) 90 | # 91 | # txn = env.begin() 92 | # _length = txn.stat(db=gt_img)["entries"] 93 | # 94 | # a=np.frombuffer(txn.get('0'.encode(), db=gt_img),'uint8') 95 | # abcd=str(txn.get('0'.encode(), db=gt_img_shape).decode()).replace(' ','').split(',') 96 | # img = a.reshape(int(abcd[0]), int(abcd[1]),int(abcd[2])) 97 | # 98 | # toPIL = transforms.ToPILImage() 99 | # imgori = toPIL(img) 100 | # imgori.save('./randomxx.jpg') 101 | # cv2.imwrite('./random.jpg',np.asarray(imgori)) 102 | # print() 103 | 104 | # np.reshape(a,) --------------------------------------------------------------------------------
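Convert2LMDB.py above stores each image under the key str(idx).encode() and its shape as a comma-separated string in a parallel *_shape sub-database, which is exactly the convention lmdb_data_loaders.py relies on. A minimal read-back check, adapted from the commented-out block in Convert2LMDB.py's __main__ and assuming an LMDB directory named 'train' produced by img2lmdb:

    import lmdb
    import numpy as np

    env = lmdb.open('train', max_dbs=6, map_size=int(1024*1024*1024*8), readonly=True)
    gt_img_db = env.open_db("gt_img".encode())
    gt_img_shape_db = env.open_db("gt_img_shape".encode())
    txn = env.begin()
    # entry 0: raw BGR bytes plus the original (H, W, C) shape stored as "H, W, C"
    buf = np.frombuffer(txn.get('0'.encode(), db=gt_img_db), dtype=np.uint8)
    shape = [int(s) for s in txn.get('0'.encode(), db=gt_img_shape_db).decode().replace(' ', '').split(',')]
    img = buf.reshape(shape[0], shape[1], shape[2])
    print(img.shape)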
/lanenet/dataloader/data_loaders.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | from PIL import Image 7 | import cv2 8 | import numpy as np 9 | 10 | from torchvision.transforms import ToTensor 11 | from torchvision import datasets, transforms 12 | 13 | from lanenet.utils import PicEnhanceUtils 14 | 15 | import random 16 | from lanenet import config 17 | 18 | class LaneDataSet(Dataset): 19 | def __init__(self, dataset, n_labels=config.n_labels, transform=None): 20 | self._gt_img_list = [] 21 | self._gt_label_binary_list = [] 22 | self._gt_label_instance_list = [] 23 | self.transform = transform 24 | self.n_labels = n_labels 25 | 26 | with open(dataset, 'r') as file: 27 | for _info in file: 28 | info_tmp = _info.strip(' ').split() 29 | 30 | self._gt_img_list.append(info_tmp[0]) 31 | self._gt_label_binary_list.append(info_tmp[1]) 32 | self._gt_label_instance_list.append(info_tmp[2]) 33 | 34 | assert
len(self._gt_img_list) == len(self._gt_label_binary_list) == len(self._gt_label_instance_list) 35 | 36 | self._shuffle() 37 | 38 | def _shuffle(self): 39 | # randomly shuffle all list identically 40 | c = list(zip(self._gt_img_list, self._gt_label_binary_list, self._gt_label_instance_list)) 41 | random.shuffle(c) 42 | self._gt_img_list, self._gt_label_binary_list, self._gt_label_instance_list = zip(*c) 43 | 44 | def _split_instance_gt(self, label_instance_img): 45 | # number of channels, number of unique pixel values, subtracting no label 46 | # adapted from here https://github.com/nyoki-mtl/pytorch-discriminative-loss/blob/master/src/dataset.py 47 | no_of_instances = self.n_labels 48 | ins = np.zeros((no_of_instances, label_instance_img.shape[0], label_instance_img.shape[1])) 49 | for _ch, label in enumerate(np.unique(label_instance_img)[1:]): 50 | ins[_ch, label_instance_img == label] = 1 51 | return ins 52 | 53 | def __len__(self): 54 | return len(self._gt_img_list) 55 | 56 | def __getitem__(self, idx): 57 | assert len(self._gt_label_binary_list) == len(self._gt_label_instance_list) \ 58 | == len(self._gt_img_list) 59 | 60 | # load all 61 | img = cv2.imread(self._gt_img_list[idx], cv2.IMREAD_COLOR) 62 | 63 | label_instance_img = cv2.imread(self._gt_label_instance_list[idx], cv2.IMREAD_UNCHANGED) 64 | 65 | label_img = cv2.imread(self._gt_label_binary_list[idx], cv2.IMREAD_COLOR) 66 | # print("------------------------------------------------------------------") 67 | # print(img.size()) 68 | # print("------------------------------------------------------------------") 69 | # optional transformations 70 | 71 | toPil=transforms.ToPILImage() 72 | 73 | img=toPil(img[:,:,[2,1,0]]) 74 | label_instance_img=toPil(label_instance_img) 75 | label_img=toPil(label_img[:,:,[2,1,0]]) 76 | 77 | img=PicEnhanceUtils.random_color_augmentation(img) 78 | img,label_img,label_instance_img=PicEnhanceUtils.random_horizon_flip_batch_images(img,label_img,label_instance_img) 79 | img,label_img,label_instance_img=PicEnhanceUtils.random_crop(img,label_img,label_instance_img) 80 | 81 | resize=transforms.Resize((720,1280),interpolation=Image.NEAREST) 82 | img=resize(img) 83 | label_img=resize(label_img) 84 | label_instance_img=resize(label_instance_img) 85 | 86 | # if self.transform: 87 | img=np.asarray(img) 88 | label_img=np.asarray(label_img) 89 | label_instance_img=np.asarray(label_instance_img) 90 | 91 | # extract each label into separate binary channels 92 | # print(self._gt_label_instance_list[idx]) 93 | label_instance_img = self._split_instance_gt(label_instance_img) 94 | 95 | 96 | # reshape for pytorch 97 | # tensorflow: [height, width, channels] 98 | # pytorch: [channels, height, width] 99 | # print("------------------------------------------------------------------") 100 | # print(img.size()) 101 | # print("------------------------------------------------------------------") 102 | # img = img.reshape(img.shape[2], img.shape[0], img.shape[1]) 103 | 104 | img=np.transpose(img,(2,0,1)) 105 | 106 | # print("------------------------------------------------------------------") 107 | # print(img.size()) 108 | # print("------------------------------------------------------------------") 109 | 110 | 111 | # print("///////////////////////////////////////////////////") 112 | # print(img.shape) 113 | # print(label_img.shape) 114 | # print("///////////////////////////////////////////////////") 115 | 116 | label_binary = np.zeros([label_img.shape[0], label_img.shape[1]], dtype=np.uint8) 117 | mask = np.where((label_img[:, :, 
:] != [0, 0, 0]).all(axis=2)) 118 | label_binary[mask] = 1 119 | 120 | # we could split the instance label here, each instance in one channel (basically a binary mask for each) 121 | return img, label_binary, label_instance_img 122 | -------------------------------------------------------------------------------- /lanenet/model/model.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | LaneNet model 4 | https://arxiv.org/pdf/1807.01726.pdf 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from lanenet.model.loss import DiscriminativeLoss 11 | from lanenet.model.BiseNet_v2_2 import BiSeNet 12 | from lanenet import config 13 | 14 | import numpy as np 15 | 16 | DEVICE = torch.device(config.gpu_no if torch.cuda.is_available() else 'cpu') 17 | # from lanenet.model import lanenet_postprocess 18 | 19 | class LaneNet(nn.Module): 20 | def __init__(self): 21 | super(LaneNet, self).__init__() 22 | self.biSeNet=BiSeNet() 23 | self.no_of_instances=config.no_of_instances 24 | self.bn = nn.BatchNorm2d(128) 25 | self._pix_layer = nn.Conv2d(in_channels=128, out_channels=self.no_of_instances, kernel_size=1, bias=False).to( 26 | DEVICE) 27 | self.relu = nn.ReLU().to(DEVICE) 28 | 29 | def forward(self, input_tensor): 30 | 31 | biseNet_out=self.biSeNet(input_tensor) 32 | 33 | binary_seg_logits=biseNet_out["binary_seg_logits"] 34 | 35 | softmax_out=F.softmax(binary_seg_logits, dim=1) 36 | if config.is_training==2: 37 | tensor1=torch.zeros_like(softmax_out) 38 | tensor2=torch.ones_like(softmax_out) 39 | softmax_out=torch.where(softmax_out <=0.5 , tensor1, tensor2) 40 | binary_seg_ret = torch.argmax(softmax_out, dim=1, keepdim=True) 41 | 42 | pix_embedding = self.relu(self._pix_layer(self.bn(biseNet_out["instance_seg_logits"]))) 43 | 44 | # instance_seg_logits=pix_embedding.squeeze(0) 45 | # print(instance_seg_logits.size()) 46 | # instance_seg_logits=instance_seg_logits.permute(1, 2, 0) 47 | # 48 | # binary_seg_pred=binary_seg_ret.squeeze(0) 49 | # binary_seg_pred=binary_seg_pred.squeeze(0) 50 | if config.is_training==1: 51 | return { 52 | 'instance_seg_logits': pix_embedding, 53 | 'binary_seg_pred': binary_seg_ret, 54 | 'binary_seg_logits': biseNet_out["binary_seg_logits"], 55 | 'bsb_out1': biseNet_out["bsb_out1"], 56 | 'bsb_out2': biseNet_out["bsb_out2"], 57 | 'sg_out1': biseNet_out["sg_out1"], 58 | 'sg_out3': biseNet_out["sg_out3"], 59 | 'sg_out4': biseNet_out["sg_out4"], 60 | 'sg_out5': biseNet_out["sg_out5"] 61 | } 62 | return { 63 | 'instance_seg_logits': pix_embedding, 64 | 'binary_seg_pred': binary_seg_ret, 65 | 'binary_seg_logits': biseNet_out["binary_seg_logits"] 66 | } 67 | 68 | if __name__ == '__main__': 69 | 70 | print(1==2) 71 | import os 72 | os.environ["CUDA_VISIBLE_DEVICES"] = '5' 73 | input = torch.rand(1, 3, 224, 1280) 74 | model = LaneNet() 75 | model.eval() 76 | print(model) 77 | output = model(input) 78 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 79 | print('BiSeNet_v2', output["instance_seg_logits"].size()) 80 | print('BiSeNet_v2', output["binary_seg_pred"].size()) 81 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 82 | # postprocessor = lanenet_postprocess.LaneNetPostProcessor(cfg=CFG) 83 | 84 | # postprocess_result = postprocessor.postprocess( 85 | # binary_seg_result=binary_seg_image[0], 86 | # instance_seg_result=instance_seg_image[0], 87 | # source_image=image_vis 88 | # ) 89 | # mask_image = 
postprocess_result['mask_image'] 90 | # 91 | # for i in range(CFG.MODEL.EMBEDDING_FEATS_DIMS): 92 | # instance_seg_image[0][:, :, i] = minmax_scale(instance_seg_image[0][:, :, i]) 93 | # embedding_image = np.array(instance_seg_image[0], np.uint8) 94 | 95 | 96 | def compute_loss(net_output, binary_label, instance_label,w1=0.25,w2=0.25,w3=0.25,w4=0.25): 97 | k_binary = 1.0 98 | k_instance = 0 99 | k_dist = 0 100 | 101 | # k_binary = 0.7 102 | # k_instance = 0.3 103 | # k_dist = 1.0 104 | 105 | ce_loss_fn = nn.CrossEntropyLoss() 106 | binary_seg_logits = net_output["binary_seg_logits"] 107 | bsb_out1 = net_output["bsb_out1"] 108 | bsb_out2 = net_output["bsb_out2"] 109 | 110 | sg_out1 = net_output["sg_out1"] 111 | sg_out3 = net_output["sg_out3"] 112 | sg_out4 = net_output["sg_out4"] 113 | sg_out5 = net_output["sg_out5"] 114 | # binary_seg_logits = net_output["binary_seg_pred"] 115 | binary_loss0 = ce_loss_fn(binary_seg_logits, binary_label) 116 | binary_loss1 = ce_loss_fn(bsb_out1, binary_label) 117 | binary_loss2 = ce_loss_fn(bsb_out2, binary_label) 118 | 119 | binary_loss3 = ce_loss_fn(sg_out1, binary_label) 120 | binary_loss4 = ce_loss_fn(sg_out3, binary_label) 121 | binary_loss5 = ce_loss_fn(sg_out4, binary_label) 122 | binary_loss6 = ce_loss_fn(sg_out5, binary_label) 123 | 124 | binary_loss=w1*binary_loss0+w2*binary_loss1+w3*binary_loss2+w4*(binary_loss3+binary_loss4+binary_loss5+binary_loss6)/4 125 | # binary_loss=(binary_loss0+binary_loss1+binary_loss2+(binary_loss3+binary_loss4+binary_loss5+binary_loss6)/4)/4 126 | # binary_loss=(binary_loss0+binary_loss1+binary_loss2)/3 127 | 128 | pix_embedding = net_output["instance_seg_logits"] 129 | # ds_loss_fn = DiscriminativeLoss(0.5, 1.5, 1.0, 1.0, 0.001) 130 | ds_loss_fn = DiscriminativeLoss(0.4, 3.0, 1.0, 1.0, 0.001) 131 | var_loss, dist_loss, reg_loss = ds_loss_fn(pix_embedding, instance_label) 132 | binary_loss = binary_loss * k_binary 133 | instance_loss = var_loss * k_instance 134 | dist_loss = dist_loss * k_dist 135 | total_loss = binary_loss + instance_loss + dist_loss 136 | out = net_output["binary_seg_pred"] 137 | iou = 0 138 | batch_size = out.size()[0] 139 | k=0 140 | for i in range(batch_size): 141 | PR = out[i].squeeze(0).nonzero().size()[0] 142 | GT = binary_label[i].nonzero().size()[0] 143 | TP = (out[i].squeeze(0) * binary_label[i]).nonzero().size()[0] 144 | union = PR + GT - TP 145 | if union!=0: 146 | iou += TP / union 147 | k+=1 148 | iou = iou / k 149 | return total_loss, binary_loss, instance_loss, out, iou 150 | -------------------------------------------------------------------------------- /lanenet/model/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/nyoki-mtl/pytorch-discriminative-loss/blob/master/src/loss.py 3 | This is the implementation of following paper: 4 | https://arxiv.org/pdf/1802.05591.pdf 5 | This implementation is based on following code: 6 | https://github.com/Wizaron/instance-segmentation-pytorch 7 | """ 8 | 9 | from torch.nn.modules.loss import _Loss 10 | from torch.autograd import Variable 11 | import torch 12 | from torch.functional import F 13 | from lanenet import config 14 | 15 | DEVICE = torch.device(config.gpu_no if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | class DiscriminativeLoss(_Loss): 19 | 20 | def __init__(self, delta_var=0.5, delta_dist=1.5, norm=2, alpha=1.0, beta=1.0, gamma=0.001, 21 | usegpu=False, size_average=True): 22 | super(DiscriminativeLoss, self).__init__(reduction='mean') 23 | 
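# Hyper-parameters of the discriminative (pull/push) loss from the paper cited above:
#   delta_var  - margin of the variance (pull) term: pixel embeddings closer than this
#                to their lane's mean centroid are not penalised
#   delta_dist - margin of the distance (push) term: lane centroids farther apart than
#                this are not penalised
#   alpha, beta, gamma - the paper's weights for the variance, distance and regularisation
#                terms (stored here; this implementation returns the three terms separately)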
self.delta_var = delta_var 24 | self.delta_dist = delta_dist 25 | self.norm = norm 26 | self.alpha = alpha 27 | self.beta = beta 28 | self.gamma = gamma 29 | self.usegpu = usegpu 30 | assert self.norm in [1, 2] 31 | 32 | def forward(self, input, target): 33 | # _assert_no_grad(target) 34 | return self._discriminative_loss(input, target) 35 | 36 | def _discriminative_loss(self, embedding, seg_gt): 37 | batch_size = embedding.shape[0] 38 | embed_dim = embedding.shape[1] 39 | 40 | var_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device) 41 | dist_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device) 42 | reg_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device) 43 | 44 | for b in range(batch_size): 45 | embedding_b = embedding[b] # (embed_dim, H, W) 46 | seg_gt_b = seg_gt[b] 47 | 48 | labels = torch.unique(seg_gt_b) 49 | labels = labels[labels != 0] 50 | num_lanes = len(labels) 51 | if num_lanes == 0: 52 | # please refer to issue here: https://github.com/harryhan618/LaneNet/issues/12 53 | _nonsense = embedding.sum() 54 | _zero = torch.zeros_like(_nonsense) 55 | var_loss = var_loss + _nonsense * _zero 56 | dist_loss = dist_loss + _nonsense * _zero 57 | reg_loss = reg_loss + _nonsense * _zero 58 | continue 59 | 60 | centroid_mean = [] 61 | for lane_idx in labels: 62 | seg_mask_i = (seg_gt_b == lane_idx) 63 | if not seg_mask_i.any(): 64 | continue 65 | embedding_i = embedding_b[seg_mask_i] 66 | 67 | mean_i = torch.mean(embedding_i, dim=0) 68 | centroid_mean.append(mean_i) 69 | 70 | # ---------- var_loss ------------- 71 | var_loss = var_loss + torch.mean(F.relu( 72 | torch.norm(embedding_i - mean_i, dim=0) - self.delta_var) ** 2) / num_lanes 73 | centroid_mean = torch.stack(centroid_mean) # (n_lane, embed_dim) 74 | 75 | if num_lanes > 1: 76 | centroid_mean1 = centroid_mean.reshape(-1, 1, embed_dim) 77 | centroid_mean2 = centroid_mean.reshape(1, -1, embed_dim) 78 | dist = torch.norm(centroid_mean1 - centroid_mean2, dim=2) # shape (num_lanes, num_lanes) 79 | dist = dist + torch.eye(num_lanes, dtype=dist.dtype, 80 | device=dist.device) * self.delta_dist # diagonal elements are 0, now mask above delta_d 81 | 82 | # divided by two for double calculated loss above, for implementation convenience 83 | dist_loss = dist_loss + torch.sum(F.relu(-dist + self.delta_dist) ** 2) / ( 84 | num_lanes * (num_lanes - 1)) / 2 85 | 86 | # reg_loss is not used in original paper 87 | # reg_loss = reg_loss + torch.mean(torch.norm(centroid_mean, dim=1)) 88 | 89 | var_loss = var_loss / batch_size 90 | dist_loss = dist_loss / batch_size 91 | reg_loss = reg_loss / batch_size 92 | return var_loss, dist_loss, reg_loss 93 | 94 | 95 | class HNetLoss(_Loss): 96 | """ 97 | HNet Loss 98 | """ 99 | 100 | def __init__(self, gt_pts, transformation_coefficient, name, usegpu=True): 101 | """ 102 | 103 | :param gt_pts: [x, y, 1] 104 | :param transformation_coeffcient: [[a, b, c], [0, d, e], [0, f, 1]] 105 | :param name: 106 | :return: 107 | """ 108 | super(HNetLoss, self).__init__() 109 | 110 | self.gt_pts = gt_pts 111 | 112 | self.transformation_coefficient = transformation_coefficient 113 | self.name = name 114 | self.usegpu = usegpu 115 | 116 | def _hnet_loss(self): 117 | """ 118 | 119 | :return: 120 | """ 121 | H, preds = self._hnet() 122 | x_transformation_back = torch.matmul(torch.inverse(H), preds) 123 | loss = torch.mean(torch.pow(self.gt_pts.t()[0, :] - x_transformation_back[0, :], 2)) 124 | 125 | return loss 126 | 127 | def _hnet(self): 128 | """ 129 | 130 | :return: 131 | 
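Assembles the 3x3 homography H from the six learned coefficients (a constant 1.0 is
appended for H[2, 2]; H[1, 0] and H[2, 0] stay zero), projects the ground-truth points
with H, fits a third-order polynomial x = f(y) to the projected points by least squares,
and returns (H, preds), where preds are the fitted points in the projected space.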
""" 132 | self.transformation_coefficient = torch.cat((self.transformation_coefficient, torch.tensor([1.0])), 133 | dim=0) 134 | H_indices = torch.tensor([0, 1, 2, 4, 5, 7, 8]) 135 | H_shape = 9 136 | H = torch.zeros(H_shape) 137 | H.scatter_(dim=0, index=H_indices, src=self.transformation_coefficient) 138 | H = H.view((3, 3)) 139 | 140 | pts_projects = torch.matmul(H, self.gt_pts.t()) 141 | 142 | Y = pts_projects[1, :] 143 | X = pts_projects[0, :] 144 | Y_One = torch.ones(Y.size()) 145 | Y_stack = torch.stack((torch.pow(Y, 3), torch.pow(Y, 2), Y, Y_One), dim=1).squeeze() 146 | w = torch.matmul(torch.matmul(torch.inverse(torch.matmul(Y_stack.t(), Y_stack)), 147 | Y_stack.t()), 148 | X.view(-1, 1)) 149 | 150 | x_preds = torch.matmul(Y_stack, w) 151 | preds = torch.stack((x_preds.squeeze(), Y, Y_One), dim=1).t() 152 | return (H, preds) 153 | 154 | def _hnet_transformation(self): 155 | """ 156 | """ 157 | H, preds = self._hnet() 158 | x_transformation_back = torch.matmul(torch.inverse(H), preds) 159 | 160 | return x_transformation_back 161 | 162 | def forward(self, input, target, n_clusters): 163 | return self._hnet_loss(input, target) 164 | -------------------------------------------------------------------------------- /lanenet/online_test_video.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import cv2 7 | from torchvision import transforms 8 | from lanenet.dataloader.transformers import Rescale 9 | from lanenet.model.model import LaneNet 10 | import torch.nn as nn 11 | import os 12 | DEVICE = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu') 13 | os.environ["CUDA_VISIBLE_DEVICES"] = '3' 14 | 15 | def compose_img(image_data, out, i=0): 16 | oridata=image_data[i].cpu().numpy() 17 | val_gt=oridata.transpose(1, 2, 0).astype(np.uint8) 18 | #val_gt=oridata.reshape(oridata.shape[1],oridata.shape[2],oridata.shape[0]).astype(np.uint8) 19 | # val_gt = (.transpose(1, 2, 0)).astype(np.uint8) 20 | predata=out[i].squeeze(0).cpu().numpy() 21 | val_pred=predata.transpose(0, 1)*255 22 | # val_pred=predata.reshape(predata.shape[0],predata.shape[1])*255 23 | # val_pred = .transpose(0, 1) * 255 24 | # val_label = binary_label[i].squeeze(0).cpu().numpy().transpose(0, 1) * 255 25 | val_out = np.zeros((val_pred.shape[0], val_pred.shape[1], 3), dtype=np.uint8) 26 | val_out[:, :, 0] = val_pred 27 | # val_out[:, :, 1] = val_label 28 | val_gt[val_out == 255] = 255 29 | # epsilon = 1e-5 30 | # pix_embedding = pix_embedding[i].data.cpu().numpy() 31 | # pix_vec = pix_embedding / (np.sum(pix_embedding, axis=0, keepdims=True) + epsilon) * 255 32 | # pix_vec = np.round(pix_vec).astype(np.uint8).transpose(1, 2, 0) 33 | # ins_label = instance_label[i].data.cpu().numpy().transpose(0, 1) 34 | # ins_label = np.repeat(np.expand_dims(ins_label, -1), 3, -1) 35 | # val_img = np.concatenate((val_gt, pix_vec, ins_label), axis=0) 36 | # val_img = np.concatenate((val_gt, pix_vec), axis=0) 37 | # return val_img 38 | return val_gt 39 | 40 | if __name__ == '__main__': 41 | model_path = './checkpoints-combine-new1/83_checkpoint.pth' 42 | gpu = True 43 | if not torch.cuda.is_available(): 44 | gpu = False 45 | 46 | # device = torch.device('cpu') 47 | 48 | model = LaneNet() 49 | model.to(DEVICE) 50 | # model = nn.DataParallel(model, device_ids=[0,7]) 51 | 52 | # model.load_state_dict(torch.load(model_path, map_location=device)) 53 | 54 | # if gpu: 55 | # model = model.cuda() 56 | # 
print('loading pretrained model from %s' % model_path) 57 | if gpu: 58 | model.load_state_dict(torch.load(model_path)) 59 | else: 60 | model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) 61 | 62 | model.eval() 63 | 64 | sourceFileName='ch0_20200318140335_20200318140435' 65 | video_path = os.path.join("/workspace/lanenet-lane-detection-11000", sourceFileName+'.mp4') 66 | times=0 67 | frameFrequency=1 68 | camera = cv2.VideoCapture(video_path) 69 | 70 | video=cv2.VideoWriter('./vedio/test.avi',cv2.VideoWriter_fourcc(*'MJPG'),25,(1280, 720)) 71 | 72 | while True: 73 | res, imgori = camera.read() 74 | # cv2.imwrite('./vedio/pic/random%s.jpg'%(str(times)),imgori) 75 | times=times+1 76 | if not res: 77 | print('not res , not image') 78 | break 79 | if times%frameFrequency==0: 80 | # print('--------------------------------------------') 81 | # cv2.imwrite(outPutDirName + str(times)+'.jpg', image) 82 | transform=transforms.Compose([Rescale((1280, 720))]) 83 | imgori=transform(imgori) 84 | # img=imgori 85 | #print(img.shape) 86 | toPIL = transforms.ToPILImage() 87 | img = np.asarray(toPIL(imgori[:,:,[2,1,0]])) 88 | # imgori = transforms.ColorJitter(brightness=0.0001)(imgori) 89 | # img0 = transform(np.asarray(imgori)) 90 | # img=img0 91 | 92 | img=np.transpose(img,(2,0,1)) 93 | 94 | #img = img.reshape(img.shape[2], img.shape[0], img.shape[1]) 95 | 96 | img = np.expand_dims(img,0) 97 | # toTensor=transforms.ToTensor() 98 | # imgdata=Variable(img).type(torch.FloatTensor).cuda() 99 | print(img.shape) 100 | # imgTensor=toTensor(img) 101 | #img = np.expand_dims(img,0) 102 | imgdata=Variable(torch.from_numpy(img)).type(torch.FloatTensor).to(DEVICE) 103 | # imgdata=imgdata.unsqueeze(0) 104 | print(imgdata.size()) 105 | output=model(imgdata) 106 | binary_seg_pred=output["binary_seg_pred"] 107 | 108 | # out=compose_img(imgdata,binary_seg_pred) 109 | 110 | binary_seg_pred = binary_seg_pred.squeeze(0) 111 | binary_seg_pred1=binary_seg_pred.to(torch.float32).cpu() 112 | # .numpy() 113 | # binary_seg_pred1=np.transpose(binary_seg_pred1,(1,2,0)) 114 | # # print(binary_seg_pred1.shape) 115 | # binary_seg_pred1=binary_seg_pred1.reshape(binary_seg_pred1.shape[1],binary_seg_pred1.shape[2],binary_seg_pred1.shape[0]) 116 | # # print(binary_seg_pred1[0]) 117 | # # binary_seg_pred1 = binary_seg_pred.squeeze(0).cpu().numpy() 118 | # # pic.save('./vedio/pic/random%s.jpg'%(str(int( round(time_stamp * 1000) )))) 119 | pic=toPIL(binary_seg_pred1) 120 | imgx = cv2.cvtColor(np.asarray(pic),cv2.COLOR_RGB2BGR) 121 | imgx[np.where((imgx!=[0, 0, 0]).all(axis=2))] = [255,255,255] 122 | # # final_img[np.where((final_img==[255, 255, 255]).all(axis=2))] = [0,0,255] 123 | # #cv2.imwrite('./vedio/pic/random%s.jpg'%(str(times)),imgx) 124 | # time_stamp = time.time() 125 | # 126 | # # img2 = cv2.merge((imgx,imgx,imgx)) 127 | # 128 | print (imgori.shape) 129 | print (imgx.shape) 130 | # 131 | src7 = cv2.addWeighted(imgori,0.8,imgx,1,0) 132 | # 133 | # #pic.save('./vedio/pic/random%s.jpg'%(str(int( round(time_stamp * 1000) )))) 134 | # # array1 = binary_seg_pred1.numpy()#to numpy array 135 | # # array1 = array1.reshape(array1.shape[1], array1.shape[2], array1.shape[0]) 136 | # # final_img=np.uint8(array1) 137 | # # pic = toPIL() 138 | # # pic.save('./random1.jpg') 139 | # #print(binary_seg_pred1.shape) 140 | final_img=cv2.resize(src7,(1280, 720)) 141 | final_img[np.where((final_img==[255, 255, 255]).all(axis=2))] = [0,0,255] 142 | # print(final_img) 143 | video.write(final_img) 144 | 
print("frame"+str(times)+str(times)) 145 | print('-----------------------end---------------------------------') 146 | camera.release() 147 | video.release() 148 | 149 | -------------------------------------------------------------------------------- /lanenet/model/decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lanenet import config 4 | 5 | DEVICE = torch.device(config.gpu_no if torch.cuda.is_available() else 'cpu') 6 | from .blocks import * 7 | 8 | class ESPNetDecoder(): 9 | def __init__(self): 10 | 11 | # light-weight decoder 12 | self.level3_C = C(128 + 3, classes, 1, 1) 13 | self.br = nn.BatchNorm2d(classes, eps=1e-03) 14 | self.conv = CBR(19 + classes, classes, 3, 1) 15 | 16 | self.up_l3 = nn.Sequential( 17 | nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False)) 18 | self.combine_l2_l3 = nn.Sequential(BR(2 * classes), 19 | DilatedParllelResidualBlockB(2 * classes, classes, add=False)) 20 | 21 | self.up_l2 = nn.Sequential( 22 | nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False), BR(classes)) 23 | 24 | self.classifier = nn.ConvTranspose2d(classes, classes, 2, stride=2, padding=0, output_padding=0, bias=False) 25 | 26 | def forward(self, input): 27 | ''' 28 | :param input: RGB image 29 | :return: transformed feature map 30 | ''' 31 | output0 = self.modules[0](input) 32 | inp1 = self.modules[1](input) 33 | inp2 = self.modules[2](input) 34 | 35 | output0_cat = self.modules[3](torch.cat([output0, inp1], 1)) 36 | output1_0 = self.modules[4](output0_cat) # down-sampled 37 | 38 | for i, layer in enumerate(self.modules[5]): 39 | if i == 0: 40 | output1 = layer(output1_0) 41 | else: 42 | output1 = layer(output1) 43 | 44 | output1_cat = self.modules[6](torch.cat([output1, output1_0, inp2], 1)) 45 | 46 | output2_0 = self.modules[7](output1_cat) # down-sampled 47 | for i, layer in enumerate(self.modules[8]): 48 | if i == 0: 49 | output2 = layer(output2_0) 50 | else: 51 | output2 = layer(output2) 52 | 53 | output2_cat = self.modules[9](torch.cat([output2_0, output2], 1)) # concatenate for feature map width expansion 54 | 55 | output2_c = self.up_l3(self.br(self.modules[10](output2_cat))) # RUM 56 | 57 | output1_C = self.level3_C(output1_cat) # project to C-dimensional space 58 | comb_l2_l3 = self.up_l2(self.combine_l2_l3(torch.cat([output1_C, output2_c], 1))) # RUM 59 | 60 | concat_features = self.conv(torch.cat([comb_l2_l3, output0_cat], 1)) 61 | 62 | classifier = self.classifier(concat_features) 63 | return classifier 64 | 65 | 66 | class ENetDecoder(): 67 | def __init__(self): 68 | # Stage 4 - Decoder 69 | self.upsample4_0 = UpsamplingBottleneck( 70 | 128, 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 71 | self.regular4_1 = RegularBottleneck( 72 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 73 | self.regular4_2 = RegularBottleneck( 74 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 75 | 76 | # Stage 5 - Decoder 77 | self.upsample5_0 = UpsamplingBottleneck( 78 | 64, 16, padding=1, dropout_prob=0.1, relu=decoder_relu) 79 | self.regular5_1 = RegularBottleneck( 80 | 16, padding=1, dropout_prob=0.1, relu=decoder_relu) 81 | self.transposed_conv = nn.ConvTranspose2d( 82 | 16, 83 | num_classes, 84 | kernel_size=3, 85 | stride=2, 86 | padding=1, 87 | output_padding=1, 88 | bias=False) 89 | 90 | def forward(self, x): 91 | # Initial block 92 | x = self.initial_block(x) 93 | 94 | # Stage 1 - Encoder 95 | x, 
max_indices1_0 = self.downsample1_0(x) 96 | x = self.regular1_1(x) 97 | x = self.regular1_2(x) 98 | x = self.regular1_3(x) 99 | x = self.regular1_4(x) 100 | 101 | # Stage 2 - Encoder 102 | x, max_indices2_0 = self.downsample2_0(x) 103 | x = self.regular2_1(x) 104 | x = self.dilated2_2(x) 105 | x = self.asymmetric2_3(x) 106 | x = self.dilated2_4(x) 107 | x = self.regular2_5(x) 108 | x = self.dilated2_6(x) 109 | x = self.asymmetric2_7(x) 110 | x = self.dilated2_8(x) 111 | 112 | # Stage 3 - Encoder 113 | x = self.regular3_0(x) 114 | x = self.dilated3_1(x) 115 | x = self.asymmetric3_2(x) 116 | x = self.dilated3_3(x) 117 | x = self.regular3_4(x) 118 | x = self.dilated3_5(x) 119 | x = self.asymmetric3_6(x) 120 | x = self.dilated3_7(x) 121 | 122 | # Stage 4 - Decoder 123 | x = self.upsample4_0(x, max_indices2_0) 124 | x = self.regular4_1(x) 125 | x = self.regular4_2(x) 126 | 127 | # Stage 5 - Decoder 128 | x = self.upsample5_0(x, max_indices1_0) 129 | x = self.regular5_1(x) 130 | x = self.transposed_conv(x) 131 | 132 | return x 133 | 134 | 135 | class FCNDecoder(nn.Module): 136 | def __init__(self, decode_layers, decode_channels=[], decode_last_stride=8): 137 | super(FCNDecoder, self).__init__() 138 | 139 | self._decode_channels = [512, 256] 140 | self._out_channel = 64 141 | self._decode_layers = decode_layers 142 | 143 | self._conv_layers = [] 144 | for _ch in self._decode_channels: 145 | self._conv_layers.append(nn.Conv2d(_ch, self._out_channel, kernel_size=1, bias=False).to(DEVICE)) 146 | 147 | self._conv_final = nn.Conv2d(self._out_channel, 2, kernel_size=1, bias=False) 148 | self._deconv = nn.ConvTranspose2d(self._out_channel, self._out_channel, kernel_size=4, stride=2, padding=1, 149 | bias=False) 150 | 151 | self._deconv_final = nn.ConvTranspose2d(self._out_channel, self._out_channel, kernel_size=16, 152 | stride=decode_last_stride, 153 | padding=4, bias=False) 154 | 155 | def forward(self, encode_data): 156 | ret = {} 157 | input_tensor = encode_data[self._decode_layers[0]] 158 | input_tensor.to(DEVICE) 159 | score = self._conv_layers[0](input_tensor) 160 | for i, layer in enumerate(self._decode_layers[1:]): 161 | deconv = self._deconv(score) 162 | 163 | input_tensor = encode_data[layer] 164 | score = self._conv_layers[i](input_tensor) 165 | 166 | fused = torch.add(deconv, score) 167 | score = fused 168 | 169 | deconv_final = self._deconv_final(score) 170 | score_final = self._conv_final(deconv_final) 171 | 172 | 173 | ret['logits'] = score_final 174 | ret['deconv'] = deconv_final 175 | return ret 176 | -------------------------------------------------------------------------------- /lanenet/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import sys 4 | 5 | from tqdm import tqdm 6 | 7 | import torch 8 | from lanenet.dataloader.lmdb_data_loaders import LaneDataSet 9 | from lanenet.dataloader.transformers import Rescale 10 | from lanenet.model.model import LaneNet, compute_loss 11 | from torch.utils.data import DataLoader 12 | from torch.autograd import Variable 13 | import torch.nn as nn 14 | 15 | from torchvision import transforms 16 | 17 | from lanenet.utils.cli_helper import parse_args 18 | from lanenet.utils.average_meter import AverageMeter 19 | from lanenet.test import test 20 | 21 | import numpy as np 22 | import cv2 23 | from lanenet import config 24 | 25 | # might want this in the transformer part as well 26 | # VGG_MEAN = [103.939, 116.779, 123.68] 27 | 28 | DEVICE = torch.device(config.gpu_no if 
torch.cuda.is_available() else 'cpu') 29 | 30 | 31 | def compose_img(image_data, out, binary_label, pix_embedding, instance_label, i): 32 | val_gt = (image_data[i].cpu().numpy().transpose(1, 2, 0)).astype(np.uint8) 33 | val_pred = out[i].squeeze(0).cpu().numpy().transpose(0, 1) * 255 34 | val_label = binary_label[i].squeeze(0).cpu().numpy().transpose(0, 1) * 255 35 | val_out = np.zeros((val_pred.shape[0], val_pred.shape[1], 3), dtype=np.uint8) 36 | val_out[:, :, 0] = val_pred 37 | val_out[:, :, 1] = val_label 38 | val_gt[val_out == 255] = 255 39 | # epsilon = 1e-5 40 | # pix_embedding = pix_embedding[i].data.cpu().numpy() 41 | # pix_vec = pix_embedding / (np.sum(pix_embedding, axis=0, keepdims=True) + epsilon) * 255 42 | # pix_vec = np.round(pix_vec).astype(np.uint8).transpose(1, 2, 0) 43 | # ins_label = instance_label[i].data.cpu().numpy().transpose(0, 1) 44 | # ins_label = np.repeat(np.expand_dims(ins_label, -1), 3, -1) 45 | # val_img = np.concatenate((val_gt, pix_vec, ins_label), axis=0) 46 | # val_img = np.concatenate((val_gt, pix_vec), axis=0) 47 | # return val_img 48 | return val_gt 49 | 50 | def train(train_loader, model, optimizer, epoch,w1,w2,w3,w4): 51 | model.train() 52 | batch_time = AverageMeter() 53 | mean_iou = AverageMeter() 54 | total_losses = AverageMeter() 55 | binary_losses = AverageMeter() 56 | instance_losses = AverageMeter() 57 | end = time.time() 58 | step = 0 59 | 60 | t = tqdm(enumerate(iter(train_loader)), leave=False, total=len(train_loader)) 61 | 62 | for batch_idx, batch in t: 63 | try: 64 | step += 1 65 | image_data = Variable(batch[0]).type(torch.FloatTensor).to(DEVICE) 66 | binary_label = Variable(batch[1]).type(torch.LongTensor).to(DEVICE) 67 | instance_label = Variable(batch[2]).type(torch.FloatTensor).to(DEVICE) 68 | 69 | #print("///////////////////////////////////////////////////") 70 | #print(image_data.size()) 71 | #print(binary_label.size()) 72 | # # print(image_data.shape) 73 | #print("///////////////////////////////////////////////////") 74 | 75 | # forward pass 76 | net_output = model(image_data) 77 | 78 | # compute loss 79 | total_loss, binary_loss, instance_loss, out, train_iou = compute_loss(net_output, binary_label, instance_label,w1,w2,w3,w4) 80 | 81 | # update loss in AverageMeter instance 82 | total_losses.update(total_loss.item(), image_data.size()[0]) 83 | binary_losses.update(binary_loss.item(), image_data.size()[0]) 84 | instance_losses.update(instance_loss.item(), image_data.size()[0]) 85 | mean_iou.update(train_iou, image_data.size()[0]) 86 | 87 | # reset gradients 88 | optimizer.zero_grad() 89 | 90 | # backpropagate 91 | total_loss.backward() 92 | 93 | # update weights 94 | optimizer.step() 95 | 96 | # update batch time 97 | batch_time.update(time.time() - end) 98 | end = time.time() 99 | 100 | if step % 30 == 0: 101 | print( 102 | "Epoch {ep} Step {st} |({batch}/{size})| ETA: {et:.2f}|Total loss:{tot:.5f}|Binary loss:{bin:.5f}|Instance loss:{ins:.5f}|IoU:{iou:.5f}".format( 103 | ep=epoch + 1, 104 | st=step, 105 | batch=batch_idx + 1, 106 | size=len(train_loader), 107 | et=batch_time.val, 108 | tot=total_losses.avg, 109 | bin=binary_losses.avg, 110 | ins=instance_losses.avg, 111 | iou=train_iou, 112 | )) 113 | print("current learning rate is %s"%(str(optimizer.state_dict()['param_groups'][0]['lr']))) 114 | sys.stdout.flush() 115 | train_img_list = [] 116 | for i in range(3): 117 | train_img_list.append( 118 | compose_img(image_data, out, binary_label, net_output["instance_seg_logits"], instance_label, i)) 119 | train_img = 
np.concatenate(train_img_list, axis=1) 120 | cv2.imwrite(os.path.join("./output", "train_" + str(epoch + 1) + "_step_" + str(step) + ".png"), train_img) 121 | except Exception as e: 122 | print(e) 123 | print('error') 124 | return mean_iou.avg 125 | 126 | 127 | def save_model(save_path, epoch, model): 128 | save_name = os.path.join(save_path, f'{epoch}_checkpoint.pth') 129 | torch.save(model.module.state_dict(), save_name) 130 | print("model is saved: {}".format(save_name)) 131 | 132 | 133 | def main(): 134 | args = parse_args() 135 | 136 | save_path = args.save 137 | w1 = args.w1 138 | w2 = args.w2 139 | w3 = args.w3 140 | w4 = args.w4 141 | 142 | if not os.path.isdir(save_path): 143 | os.makedirs(save_path) 144 | 145 | train_dataset_file = '/workspace/all/index/train1' 146 | val_dataset_file = '/workspace/all/index/val1' 147 | 148 | train_dataset = LaneDataSet(train_dataset_file, transform=None) 149 | # train_dataset = LaneDataSet(train_dataset_file, transform=transforms.Compose([Rescale((1280, 720))])) 150 | train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True,num_workers=24,pin_memory=True,drop_last=True) 151 | 152 | if args.val: 153 | val_dataset = LaneDataSet(val_dataset_file, transform=None) 154 | # val_dataset = LaneDataSet(val_dataset_file, transform=transforms.Compose([Rescale((1280, 720))])) 155 | val_loader = DataLoader(val_dataset, batch_size=args.bs, shuffle=True,num_workers=24,pin_memory=True,drop_last=True) 156 | 157 | model = LaneNet() 158 | model = nn.DataParallel(model, device_ids=config.device_ids) 159 | model.to(DEVICE) 160 | 161 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 162 | print(f"{args.epochs} epochs {len(train_dataset)} training samples\n") 163 | log_model="/workspace/pytorch-lanenet-master/checkpoints-combine-new1/83_checkpoint_state.pth" 164 | # 如果有保存的模型,则加载模型,并在其基础上继续训练 165 | if os.path.exists(log_model): 166 | checkpoint = torch.load(log_model) 167 | model.module.load_state_dict(checkpoint["net"]) 168 | optimizer.load_state_dict(checkpoint['optimizer']) 169 | 170 | #for p in optimizer.param_groups: 171 | # p['lr'] = args.lr 172 | 173 | start_epoch = int(checkpoint['epoch'])+1 174 | # start_epoch = 272 175 | print('load epoch {} success'.format(start_epoch-1)) 176 | else: 177 | start_epoch = 0 178 | print('no model,will start train from 0 epoche') 179 | 180 | for epoch in range(start_epoch, args.epochs): 181 | print(f"Epoch {epoch}") 182 | train_iou = train(train_loader, model, optimizer, epoch,w1,w2,w3,w4) 183 | if args.val: 184 | val_iou = test(val_loader, model, epoch) 185 | if (epoch+1) % 3 == 0: 186 | save_model(save_path, epoch, model) 187 | save_state_name = os.path.join(save_path, f'{epoch}_checkpoint_state.pth') 188 | checkpoint = { 189 | "net": model.module.state_dict(), 190 | 'optimizer':optimizer.state_dict(), 191 | "epoch": epoch 192 | } 193 | torch.save(checkpoint, save_state_name) 194 | print(f"Train IoU : {train_iou}") 195 | if args.val: 196 | print(f"Val IoU : {val_iou}") 197 | 198 | 199 | if __name__ == '__main__': 200 | main() 201 | -------------------------------------------------------------------------------- /lanenet/model/encoders.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Shared encoders (U-net). 
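Example (illustrative only; the channel lists below are assumptions for a standard
VGG-16 layout, not values taken from this repository):

    encoder = VGGEncoder(num_blocks=5,
                         in_channels=[3, 64, 128, 256, 512],
                         out_channels=[64, 128, 256, 512, 512])
    feats = encoder(torch.rand(1, 3, 256, 512))
    # feats is an OrderedDict with keys "pool1" ... "pool5"; each block halves the
    # spatial resolution of its input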
4 | """ 5 | import torch 6 | import torch.nn as nn 7 | from collections import OrderedDict 8 | 9 | import torchvision.models as models 10 | 11 | from .blocks import RegularBottleneck, DownsamplingBottleneck, InitialBlock, InputProjectionA, \ 12 | DilatedParallelResidualBlockB, DownSamplerB, C, CBR, BR 13 | 14 | 15 | class VGGEncoder(nn.Module): 16 | """ 17 | Simple VGG Encoder 18 | """ 19 | 20 | def __init__(self, num_blocks, in_channels, out_channels): 21 | super(VGGEncoder, self).__init__() 22 | 23 | self.pretrained_modules = models.vgg16(pretrained=True).features 24 | 25 | self.num_blocks = num_blocks 26 | self._in_channels = in_channels 27 | self._out_channels = out_channels 28 | self._conv_reps = [2, 2, 3, 3, 3] 29 | self.net = nn.Sequential() 30 | self.pretrained_net = nn.Sequential() 31 | 32 | for i in range(num_blocks): 33 | self.net.add_module("block" + str(i + 1), self._encode_block(i + 1)) 34 | self.pretrained_net.add_module("block" + str(i + 1), self._encode_pretrained_block(i + 1)) 35 | 36 | def _encode_block(self, block_id, kernel_size=3, stride=1): 37 | out_channels = self._out_channels[block_id - 1] 38 | padding = (kernel_size - 1) // 2 39 | seq = nn.Sequential() 40 | 41 | for i in range(self._conv_reps[block_id - 1]): 42 | if i == 0: 43 | in_channels = self._in_channels[block_id - 1] 44 | else: 45 | in_channels = out_channels 46 | seq.add_module("conv_{}_{}".format(block_id, i + 1), 47 | nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)) 48 | seq.add_module("bn_{}_{}".format(block_id, i + 1), nn.BatchNorm2d(out_channels)) 49 | seq.add_module("relu_{}_{}".format(block_id, i + 1), nn.ReLU()) 50 | seq.add_module("maxpool" + str(block_id), nn.MaxPool2d(kernel_size=2, stride=2)) 51 | return seq 52 | 53 | def _encode_pretrained_block(self, block_id): 54 | seq = nn.Sequential() 55 | for i in range(0, self._conv_reps[block_id - 1], 4): 56 | seq.add_module("conv_{}_{}".format(block_id, i + 1), self.pretrained_modules[i]) 57 | seq.add_module("relu_{}_{}".format(block_id, i + 2), self.pretrained_modules[i + 1]) 58 | seq.add_module("conv_{}_{}".format(block_id, i + 3), self.pretrained_modules[i + 2]) 59 | seq.add_module("relu_{}_{}".format(block_id, i + 4), self.pretrained_modules[i + 3]) 60 | seq.add_module("maxpool" + str(block_id), self.pretrained_modules[i + 4]) 61 | return seq 62 | 63 | def forward(self, input_tensor): 64 | ret = OrderedDict() 65 | # 5 stage of encoding 66 | X = input_tensor 67 | for i, block in enumerate(self.net): 68 | pool = block(X) 69 | ret["pool" + str(i + 1)] = pool 70 | 71 | X = pool 72 | return ret 73 | 74 | 75 | class ESPNetEncoder(nn.Module): 76 | """ 77 | ESPNet-C encoder 78 | """ 79 | 80 | def __init__(self, classes=20, p=5, q=3): 81 | ''' 82 | :param classes: number of classes in the dataset. 
Default is 20 for the cityscapes 83 | :param p: depth multiplier 84 | :param q: depth multiplier 85 | ''' 86 | super().__init__() 87 | self.level1 = CBR(3, 16, 3, 2) 88 | self.sample1 = InputProjectionA(1) 89 | self.sample2 = InputProjectionA(2) 90 | 91 | self.b1 = BR(16 + 3) 92 | self.level2_0 = DownSamplerB(16 + 3, 64) 93 | 94 | self.level2 = nn.ModuleList() 95 | for i in range(0, p): 96 | self.level2.append(DilatedParallelResidualBlockB(64, 64)) 97 | self.b2 = BR(128 + 3) 98 | 99 | self.level3_0 = DownSamplerB(128 + 3, 128) 100 | self.level3 = nn.ModuleList() 101 | for i in range(0, q): 102 | self.level3.append(DilatedParallelResidualBlockB(128, 128)) 103 | self.b3 = BR(256) 104 | 105 | self.classifier = C(256, classes, 1, 1) 106 | 107 | def forward(self, input): 108 | ''' 109 | :param input: Receives the input RGB image 110 | :return: the transformed feature map with spatial dimensions 1/8th of the input image 111 | ''' 112 | output0 = self.level1(input) 113 | inp1 = self.sample1(input) 114 | inp2 = self.sample2(input) 115 | 116 | output0_cat = self.b1(torch.cat([output0, inp1], 1)) 117 | output1_0 = self.level2_0(output0_cat) # down-sampled 118 | 119 | for i, layer in enumerate(self.level2): 120 | if i == 0: 121 | output1 = layer(output1_0) 122 | else: 123 | output1 = layer(output1) 124 | 125 | output1_cat = self.b2(torch.cat([output1, output1_0, inp2], 1)) 126 | 127 | output2_0 = self.level3_0(output1_cat) # down-sampled 128 | for i, layer in enumerate(self.level3): 129 | if i == 0: 130 | output2 = layer(output2_0) 131 | else: 132 | output2 = layer(output2) 133 | 134 | output2_cat = self.b3(torch.cat([output2_0, output2], 1)) 135 | 136 | classifier = self.classifier(output2_cat) 137 | 138 | return classifier 139 | 140 | 141 | class ENetEncoder(nn.Module): 142 | """ 143 | ENET Encoder 144 | """ 145 | 146 | def __init__(self, num_classes, encoder_relu=False, decoder_relu=True): 147 | super().__init__() 148 | 149 | def forward(self, input): 150 | self.initial_block = InitialBlock(3, 16, padding=1, relu=encoder_relu) 151 | 152 | # Stage 1 - Encoder 153 | self.downsample1_0 = DownsamplingBottleneck(16, 64, padding=1, return_indices=True, dropout_prob=0.01, 154 | relu=encoder_relu) 155 | self.regular1_1 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu) 156 | self.regular1_2 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu) 157 | self.regular1_3 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu) 158 | self.regular1_4 = RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu) 159 | 160 | # Stage 2 - Encoder 161 | self.downsample2_0 = DownsamplingBottleneck(64, 128, padding=1, return_indices=True, dropout_prob=0.1, 162 | relu=encoder_relu) 163 | self.regular2_1 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu) 164 | self.dilated2_2 = RegularBottleneck(128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 165 | self.asymmetric2_3 = RegularBottleneck(128, kernel_size=5, padding=2, asymmetric=True, dropout_prob=0.1, 166 | relu=encoder_relu) 167 | self.dilated2_4 = RegularBottleneck(128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 168 | self.regular2_5 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu) 169 | self.dilated2_6 = RegularBottleneck(128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 170 | self.asymmetric2_7 = RegularBottleneck(128, kernel_size=5, asymmetric=True, padding=2, dropout_prob=0.1, 171 | relu=encoder_relu) 172 
| self.dilated2_8 = RegularBottleneck(128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 173 | 174 | # Stage 3 - Encoder 175 | self.regular3_0 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu) 176 | self.dilated3_1 = RegularBottleneck(128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 177 | self.asymmetric3_2 = RegularBottleneck(128, kernel_size=5, padding=2, asymmetric=True, dropout_prob=0.1, 178 | relu=encoder_relu) 179 | self.dilated3_3 = RegularBottleneck(128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 180 | self.regular3_4 = RegularBottleneck(128, padding=1, dropout_prob=0.1, relu=encoder_relu) 181 | self.dilated3_5 = RegularBottleneck(128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 182 | self.asymmetric3_6 = RegularBottleneck(128, kernel_size=5, asymmetric=True, padding=2, dropout_prob=0.1, 183 | relu=encoder_relu) 184 | self.dilated3_7 = RegularBottleneck(128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 185 | -------------------------------------------------------------------------------- /lanenet/model/BiseNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.utils import model_zoo 6 | from torchvision import models 7 | 8 | class conv2d(nn.Module): 9 | def __init__(self,in_dim,out_dim,k,pad,stride,groups = 1,bias=False,use_bn = True,use_rl = True): 10 | super(conv2d,self).__init__() 11 | self.use_bn = use_bn 12 | self.use_rl = use_rl 13 | self.conv = nn.Conv2d(in_dim,out_dim,k,padding=pad,stride=stride, groups=groups,bias=bias) 14 | self.bn = nn.BatchNorm2d(out_dim) 15 | self.relu = nn.ReLU(inplace=True) 16 | def forward(self,bottom): 17 | if self.use_bn and self.use_rl: 18 | return self.relu(self.bn(self.conv(bottom))) 19 | elif self.use_bn: 20 | return self.bn(self.conv(bottom)) 21 | else: 22 | return self.conv(bottom) 23 | 24 | class StemBlock(nn.Module): 25 | def __init__(self): 26 | super(StemBlock,self).__init__() 27 | self.conv1 = conv2d(3,16,3,1,2) 28 | self.conv_1x1 = conv2d(16,8,1,0,1) 29 | self.conv_3x3 = conv2d(8,16,3,1,2) 30 | self.mpooling = nn.MaxPool2d(3,2,1) 31 | self.conv2 = conv2d(32,16,3,1,1) 32 | def forward(self,bottom): 33 | base = self.conv1(bottom) 34 | conv_1 = self.conv_1x1(base) 35 | conv_3 = self.conv_3x3(conv_1) 36 | pool = self.mpooling(base) 37 | cat = torch.cat([conv_3,pool],1) 38 | res = self.conv2(cat) 39 | return res 40 | 41 | class ContextEmbeddingBlock(nn.Module): 42 | def __init__(self,in_dim): 43 | super(ContextEmbeddingBlock,self).__init__() 44 | self.gap = nn.AdaptiveAvgPool2d(1)#1 45 | self.bn1 = nn.BatchNorm2d(in_dim) 46 | self.conv1 = conv2d(in_dim,in_dim,1,0,1) 47 | self.conv2 = conv2d(in_dim,in_dim,3,1,1,use_bn = False,use_rl = False) 48 | def forward(self,bottom): 49 | gap = self.gap(bottom) 50 | bn = self.bn1(gap) 51 | conv1 = self.conv1(bn) 52 | feat = bottom+conv1 53 | res = self.conv2(feat) 54 | return res 55 | 56 | class GatherExpansion(nn.Module): 57 | def __init__(self,in_dim,out_dim,stride = 1,exp = 6): 58 | super(GatherExpansion,self).__init__() 59 | exp_dim = in_dim*exp 60 | self.stride = stride 61 | self.conv1 = conv2d(in_dim,exp_dim,3,1,1) 62 | self.dwconv2 = conv2d(exp_dim,exp_dim,3,1,1,exp_dim,use_rl = False) 63 | self.conv_11 = conv2d(exp_dim,out_dim,1,0,1,use_rl = False) 64 | 65 | self.dwconv1 = conv2d(exp_dim,exp_dim,3,1,2,exp_dim,use_rl = False) 66 | 
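# Shortcut branch used when stride == 2: the depthwise 3x3 conv below downsamples the
# input and the following 1x1 conv projects it to out_dim so it can be added to the
# main branch in forward().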
self.dwconv3 = conv2d(in_dim,in_dim,3,1,2,in_dim,use_rl = False) 67 | self.conv_12 = conv2d(in_dim,out_dim,1,0,1,use_rl = False) 68 | self.relu = nn.ReLU(inplace=True) 69 | def forward(self,bottom): 70 | base = self.conv1(bottom) 71 | if self.stride == 2: 72 | base = self.dwconv1(base) 73 | bottom = self.dwconv3(bottom) 74 | bottom = self.conv_12(bottom) 75 | x = self.dwconv2(base) 76 | x = self.conv_11(x) 77 | res = self.relu(x+bottom) 78 | return res 79 | 80 | class BGA(nn.Module): 81 | def __init__(self,in_dim): 82 | super(BGA,self).__init__() 83 | self.in_dim = in_dim 84 | self.db_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 85 | self.db_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 86 | self.db_conv = conv2d(in_dim,in_dim,3,1,2,use_rl=False) 87 | self.db_apooling = nn.AvgPool2d(3,2,1) 88 | 89 | self.sb_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 90 | self.sb_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 91 | self.sb_conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 92 | self.sb_sigmoid = nn.Sigmoid() 93 | 94 | self.conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 95 | def forward(self,db,sb): 96 | db_dwc = self.db_dwconv(db) 97 | db_out = self.db_conv1x1(db_dwc)# 98 | db_conv = self.db_conv(db) 99 | db_pool = self.db_apooling(db_conv) 100 | 101 | sb_dwc = self.sb_dwconv(sb) 102 | sb_out = self.sb_sigmoid(self.sb_conv1x1(sb_dwc))# 103 | sb_conv = self.sb_conv(sb) 104 | sb_up = self.sb_sigmoid(F.interpolate(sb_conv, size=db_out.size()[2:], mode="bilinear",align_corners=True)) 105 | db_l = db_out*sb_up 106 | sb_r = F.interpolate(sb_out*db_pool, size=db_out.size()[2:], mode="bilinear",align_corners=True) 107 | res = self.conv(db_l+sb_r) 108 | return res 109 | 110 | class SegHead(nn.Module): 111 | def __init__(self,in_dim,out_dim,cls,size=[512,1024]): 112 | super(SegHead,self).__init__() 113 | self.size = size 114 | self.conv = conv2d(in_dim,out_dim,3,1,1) 115 | self.cls = conv2d(out_dim,cls,1,0,1,use_bn=False,use_rl=False) 116 | def forward(self,feat): 117 | x = self.conv(feat) 118 | x = self.cls(x) 119 | pred = F.interpolate(x, size=self.size, mode="bilinear",align_corners=True) 120 | return pred 121 | 122 | 123 | class DetailedBranch(nn.Module): 124 | def __init__(self): 125 | super(DetailedBranch,self).__init__() 126 | self.s1_conv1 = conv2d(3,64,3,1,2) 127 | self.s1_conv2 = conv2d(64,64,3,1,1) 128 | 129 | self.s2_conv1 = conv2d(64,64,3,1,2) 130 | self.s2_conv2 = conv2d(64,64,3,1,1) 131 | self.s2_conv3 = conv2d(64,64,3,1,1) 132 | 133 | self.s3_conv1 = conv2d(64,128,3,1,2) 134 | self.s3_conv2 = conv2d(128,128,3,1,1) 135 | self.s3_conv3 = conv2d(128,128,3,1,1) 136 | def forward(self,bottom): 137 | s1_1 = self.s1_conv1(bottom) 138 | s1_2 = self.s1_conv2(s1_1) 139 | 140 | s2_1 = self.s2_conv1(s1_2) 141 | s2_2 = self.s2_conv2(s2_1) 142 | s2_3 = self.s2_conv3(s2_2) 143 | 144 | s3_1 = self.s3_conv1(s2_3) 145 | s3_2 = self.s3_conv2(s3_1) 146 | s3_3 = self.s3_conv3(s3_2) 147 | return s3_3 148 | 149 | class SemanticBranch(nn.Module): 150 | def __init__(self, cls): 151 | super(SemanticBranch,self).__init__() 152 | self.stem = StemBlock() 153 | self.s3_ge1 = GatherExpansion(16,32,2) 154 | self.s3_ge2 = GatherExpansion(32,32) 155 | 156 | self.s4_ge1 = GatherExpansion(32,64,2) 157 | self.s4_ge2 = GatherExpansion(64,64) 158 | 159 | self.s5_ge1 = GatherExpansion(64,128,2) 160 | self.s5_ge2 = GatherExpansion(128,128) 161 | self.s5_ge3 = GatherExpansion(128,128) 162 | self.s5_ge4 = GatherExpansion(128,128) 163 | self.s5_ge5 = 
GatherExpansion(128,128,exp=1) 164 | if self.training: 165 | self.seghead1 = SegHead(16,16,cls) 166 | self.seghead2 = SegHead(32,32,cls) 167 | self.seghead3 = SegHead(64,64,cls) 168 | self.seghead4 = SegHead(128,128,cls) 169 | 170 | self.ceb = ContextEmbeddingBlock(128) 171 | 172 | def forward(self,bottom): 173 | stg12 = self.stem(bottom) 174 | #print(stg12.size()) 175 | stg3 = self.s3_ge1(stg12) 176 | stg3 = self.s3_ge2(stg3) 177 | #print(stg3.size()) 178 | stg4 = self.s4_ge1(stg3) 179 | stg4 = self.s4_ge2(stg4) 180 | #print(stg4.size()) 181 | stg5 = self.s5_ge1(stg4) 182 | stg5 = self.s5_ge2(stg5) 183 | stg5 = self.s5_ge3(stg5) 184 | stg5 = self.s5_ge4(stg5) 185 | stg5 = self.s5_ge5(stg5) 186 | #print(stg5.size()) 187 | out = self.ceb(stg5) 188 | if self.training: 189 | seghead1 = self.seghead1(stg12) 190 | seghead2 = self.seghead2(stg3) 191 | seghead3 = self.seghead3(stg4) 192 | seghead4 = self.seghead4(stg5) 193 | return out,seghead1,seghead2,seghead3,seghead4 194 | else: 195 | return out 196 | 197 | 198 | class BiSeNet(nn.Module): 199 | def __init__(self,cls): 200 | super(BiSeNet, self).__init__() 201 | self.db = DetailedBranch() 202 | self.sb = SemanticBranch(cls) 203 | self.bga = BGA(128) 204 | self.seghead = SegHead(128,128,cls) 205 | self._init_params() 206 | self.criterion = nn.CrossEntropyLoss(ignore_index=255) 207 | def _init_params(self): 208 | for m in self.modules(): 209 | if isinstance(m, nn.Conv2d): 210 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 211 | if m.bias is not None: 212 | nn.init.constant_(m.bias, 0) 213 | elif isinstance(m, nn.BatchNorm2d): 214 | nn.init.constant_(m.weight, 1) 215 | nn.init.constant_(m.bias, 0) 216 | elif isinstance(m, nn.BatchNorm1d): 217 | nn.init.constant_(m.weight, 1) 218 | nn.init.constant_(m.bias, 0) 219 | elif isinstance(m, nn.Linear): 220 | nn.init.normal_(m.weight, 0, 0.01) 221 | if m.bias is not None: 222 | nn.init.constant_(m.bias, 0) 223 | def forward(self,data,y=None): 224 | db = self.db(data) 225 | if self.training: 226 | sb,head1,head2,head3,head4 = self.sb(data) 227 | else: 228 | sb = self.sb(data) 229 | bga = self.bga(db,sb) 230 | pred = self.seghead(bga) 231 | if self.training: 232 | main_loss = self.criterion(pred, y) 233 | aux1_loss = self.criterion(head1, y) 234 | aux2_loss = self.criterion(head2, y) 235 | aux3_loss = self.criterion(head3, y) 236 | aux4_loss = self.criterion(head4, y) 237 | return pred.max(1)[1],main_loss,(aux1_loss,aux2_loss,aux3_loss,aux4_loss) 238 | return pred 239 | 240 | if __name__ == '__main__': 241 | import os 242 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 243 | input = torch.rand(4, 3, 720, 960).cuda() 244 | model = BiSeNet(11,False).cuda() 245 | model.eval() 246 | print(model) 247 | output = model(input) 248 | print('BiSeNet', output.size()) -------------------------------------------------------------------------------- /lanenet/utils/PicEnhanceUtils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import os 3 | import torch 4 | from torchvision import transforms as transforms 5 | import numpy as np 6 | from torchvision.transforms import functional as F 7 | import random 8 | import cv2 9 | import numbers 10 | from PIL import Image 11 | 12 | def clip_by_tensor(t,t_min,t_max): 13 | """ 14 | clip_by_tensor 15 | :param t: tensor 16 | :param t_min: min 17 | :param t_max: max 18 | :return: cliped tensor 19 | """ 20 | t=t.float() 21 | t_min=t_min.float() 22 | t_max=t_max.float() 23 | 24 | result = (t >= t_min).float() 
* t + (t < t_min).float() * t_min 25 | result = (result <= t_max).float() * result + (result > t_max).float() * t_max 26 | return result 27 | 28 | # return pil,input pil 29 | def random_color_augmentation(gt_image): 30 | rv=random.random() 31 | if rv<0.7: 32 | gt_image = transforms.ColorJitter(brightness=0.05)(gt_image) 33 | gt_image = transforms.ColorJitter(contrast=[0.7, 1.3])(gt_image) 34 | gt_image = transforms.ColorJitter(saturation=[0.8, 1.2])(gt_image) 35 | return gt_image 36 | return gt_image 37 | 38 | # return pil,input pil 39 | def random_horizon_flip_batch_images(gt_image, gt_binary_image, gt_instance_image): 40 | trans=MyRandomHorizontalFlip(p=0.5) 41 | gt_image=trans(gt_image) 42 | gt_binary_image=trans(gt_binary_image) 43 | gt_instance_image=trans(gt_instance_image) 44 | return gt_image,gt_binary_image,gt_instance_image 45 | 46 | # return pil,input pil 47 | def random_crop(gt_image, gt_binary_image, gt_instance_image): 48 | rv=random.random() 49 | w, h = gt_image.size 50 | i = random.randint(0, h - 244) 51 | j = random.randint(0, w - 1280) 52 | RandomCrop = MyRandomCrop(size=(244, 1280),i=i,j=j) 53 | if rv<1: 54 | gt_image = RandomCrop(gt_image) 55 | # gt_image.save('./random1.jpg') 56 | gt_binary_image = RandomCrop(gt_binary_image) 57 | # gt_image.save('./random2.jpg') 58 | gt_instance_image = RandomCrop(gt_instance_image) 59 | # gt_image.save('./random3.jpg') 60 | return gt_image,gt_binary_image,gt_instance_image 61 | return gt_image,gt_binary_image,gt_instance_image 62 | 63 | class MyRandomHorizontalFlip(object): 64 | """Horizontally flip the given PIL Image randomly with a given probability. 65 | 66 | Args: 67 | p (float): probability of the image being flipped. Default value is 0.5 68 | """ 69 | 70 | def __init__(self, p=0.5): 71 | self.p = p 72 | self.rv=random.random() 73 | # self.rv=0 74 | 75 | def __call__(self, img): 76 | """ 77 | Args: 78 | img (PIL Image): Image to be flipped. 79 | 80 | Returns: 81 | PIL Image: Randomly flipped image. 82 | """ 83 | if self.rv < self.p: 84 | return F.hflip(img) 85 | return img 86 | 87 | def __repr__(self): 88 | return self.__class__.__name__ + '(p={})'.format(self.p) 89 | 90 | class MyRandomCrop(object): 91 | """Crop the given PIL Image at a random location. 92 | 93 | Args: 94 | size (sequence or int): Desired output size of the crop. If size is an 95 | int instead of sequence like (h, w), a square crop (size, size) is 96 | made. 97 | padding (int or sequence, optional): Optional padding on each border 98 | of the image. Default is None, i.e no padding. If a sequence of length 99 | 4 is provided, it is used to pad left, top, right, bottom borders 100 | respectively. If a sequence of length 2 is provided, it is used to 101 | pad left/right, top/bottom borders, respectively. 102 | pad_if_needed (boolean): It will pad the image if smaller than the 103 | desired size to avoid raising an exception. Since cropping is done 104 | after padding, the padding seems to be done at a random offset. 105 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 106 | length 3, it is used to fill R, G, B channels respectively. 107 | This value is only used when the padding_mode is constant 108 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 
109 | 110 | - constant: pads with a constant value, this value is specified with fill 111 | 112 | - edge: pads with the last value on the edge of the image 113 | 114 | - reflect: pads with reflection of image (without repeating the last value on the edge) 115 | 116 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 117 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 118 | 119 | - symmetric: pads with reflection of image (repeating the last value on the edge) 120 | 121 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 122 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 123 | 124 | """ 125 | 126 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant',i=0,j=0): 127 | if isinstance(size, numbers.Number): 128 | self.size = (int(size), int(size)) 129 | else: 130 | self.size = size 131 | self.padding = padding 132 | self.pad_if_needed = pad_if_needed 133 | self.fill = fill 134 | self.padding_mode = padding_mode 135 | self.i=i 136 | self.j=j 137 | @staticmethod 138 | def get_params(img, output_size,i,j): 139 | """Get parameters for ``crop`` for a random crop. 140 | 141 | Args: 142 | img (PIL Image): Image to be cropped. 143 | output_size (tuple): Expected output size of the crop. 144 | 145 | Returns: 146 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 147 | """ 148 | w, h = img.size 149 | th, tw = output_size 150 | if w <= tw and h <= th: 151 | return 0, 0, h, w 152 | return i, j, th, tw 153 | 154 | def __call__(self, img): 155 | """ 156 | Args: 157 | img (PIL Image): Image to be cropped. 158 | 159 | Returns: 160 | PIL Image: Cropped image. 161 | """ 162 | if self.padding is not None: 163 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 164 | 165 | # pad the width if needed 166 | if self.pad_if_needed and img.size[0] < self.size[1]: 167 | img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) 168 | # pad the height if needed 169 | if self.pad_if_needed and img.size[1] < self.size[0]: 170 | img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 171 | 172 | i, j, h, w = self.get_params(img, self.size,self.i,self.j) 173 | 174 | return F.crop(img, i, j, h, w) 175 | 176 | def __repr__(self): 177 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 178 | 179 | if __name__ == '__main__': 180 | 181 | # transforms.Resize((32,32)) 182 | toPil=transforms.ToPILImage() 183 | # gt_imageOri = cv2.imread("D:/code/pytorch-lanenet-master/lanenet/utils/frame0270.jpg", cv2.IMREAD_COLOR) 184 | label_img = cv2.imread("D:/yum/tmcdata/test123/frame0270.jpg", cv2.IMREAD_COLOR) 185 | label_img=toPil(label_img[:,:,[2,1,0]]) 186 | label_img.save('./random.jpg') 187 | print(np.asarray(label_img).shape) 188 | cv2.imwrite('./random1.jpg',np.asarray(label_img)[:,:,[2,1,0]]) 189 | 190 | # gt_image=toPil(gt_imageOri) 191 | # imgx=np.asarray(gt_image) 192 | # 193 | # img=np.transpose(imgx,(2,0,1)) 194 | # img=np.transpose(img,(1,2,0)) 195 | # 196 | # 197 | # img1 = imgx.reshape(imgx.shape[2], imgx.shape[0], imgx.shape[1]) 198 | # img1=img1.reshape(img1.shape[1],img1.shape[2],img1.shape[0]) 199 | # 200 | # src7 = cv2.addWeighted(img,1.0,img1,0.3,0) 201 | # gt_image=random_color_augmentation(gt_image) 202 | # gt_image.save('./random1.jpg') 203 | # gt_image, gt_binary_image, label_instance_img=random_horizon_flip_batch_images(gt_image,label_img,label_img) 204 | # gt_image, gt_binary_image, label_instance_img=random_crop(gt_image, gt_binary_image, 
label_instance_img) 205 | # 206 | # resize=transforms.Resize((720,1280),interpolation=Image.NEAREST) 207 | # gt_image=resize(gt_image) 208 | # gt_binary_image=resize(gt_binary_image) 209 | # 210 | # src7 = cv2.addWeighted(cv2.cvtColor(np.asarray(gt_image),cv2.COLOR_RGB2BGR),0.8,cv2.cvtColor(np.asarray(gt_binary_image),cv2.COLOR_RGB2BGR),1,0) 211 | # final_img=src7 212 | # final_img[np.where((final_img==[255, 255, 255]).all(axis=2))] = [0,0,255] 213 | 214 | # gt_image, gt_binary_image, label_instance_img=random_horizon_flip_batch_images(gt_image,gt_image,label_instance_img) 215 | # gt_image, gt_binary_image, label_instance_img=random_crop(gt_image, gt_binary_image, label_instance_img) 216 | # print(np.asarray(label_instance_img).shape) 217 | 218 | # from lanenet.dataloader.transformers import Rescale 219 | # transform=transforms.Compose([Rescale((1280, 720))]) 220 | # for i in range(0,10000): 221 | 222 | # label_instance_img1=label_instance_img.resize((1280, 720), Image.NEAREST) 223 | 224 | # resize=transforms.Resize((720,1280),interpolation=Image.NEAREST) 225 | # label_instance_img1=resize(label_instance_img) 226 | # label_instance_img=resize(label_instance_img) 227 | # label_instance_img=resize(label_instance_img) 228 | # label_instance_img.save('./random1.jpg') 229 | 230 | # print(np.unique(np.asarray(label_instance_img1))[1:]) 231 | 232 | # imgx = cv2.cvtColor(np.asarray(toPil(label_instance_img[:,:,[2,1,0]])),cv2.COLOR_RGB2BGR) 233 | # label_instance_img=toPil(label_instance_img) 234 | # label_instance_img.save('./random1.jpg') 235 | # print(label_instance_img.shape) 236 | # label_instance_img=random_color_augmentation(label_instance_img) 237 | # label_instance_img.save('./random1.jpg') 238 | # gt_image, gt_binary_image, gt_instance_image=random_horizon_flip_batch_images(label_instance_img,label_instance_img,label_instance_img) 239 | # gt_binary_image.save('./random2.jpg') 240 | # gt_image, gt_binary_image, gt_instance_image=random_crop(gt_image, gt_binary_image, gt_instance_image) 241 | # print() 242 | # gt_binary_image.save('./random3.jpg') 243 | -------------------------------------------------------------------------------- /scripts/generateMGdataset_v3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import xml.etree.ElementTree as ET 4 | import cv2 5 | import numpy as np 6 | import shutil 7 | import random 8 | import json 9 | import math 10 | import sys 11 | import codecs 12 | 13 | sys.stdout = codecs.getwriter('utf-8')(sys.stdout.detach()) 14 | 15 | semantic_label_dict = { 16 | '单虚线': 255, 17 | '单实线': 255, 18 | '双实线': 255, 19 | '双虚线': 255, 20 | # '可行驶区域': 1, 21 | # '直行或左转': 3, 22 | # '左转或直行': 3, 23 | # '直行或右转': 3, 24 | # '左弯或向左合流': 3, 25 | # '右弯或向右合流': 3, 26 | # '右转或向右合流': 3, 27 | # '左右转弯': 3, 28 | # '左转或掉头': 3, 29 | # '直行': 3, 30 | # '左转': 3, 31 | # '右转': 3, 32 | # '掉头': 3, 33 | # '箭头': 3, 34 | # 35 | # '停止线': 1, 36 | # '减速带': 1, 37 | # '减速让行': 1, 38 | # '斑马线': 4, 39 | # '车距确认线': 4, 40 | # '导流带': 5, 41 | # '菱形减速标': 1, 42 | # 43 | # '限速': 1, 44 | # '文字': 1, 45 | # '其他': 1, 46 | # '其它': 1, 47 | # 'TrafficSign': 0, 48 | } 49 | instance_label_dict = { 50 | '左三': 40, 51 | '左二': 80, 52 | '左一': 120, 53 | '右一': 160, 54 | '右二': 200, 55 | '右三': 240, 56 | '左四': 0, 57 | '右四': 0, 58 | '左五': 0, 59 | '右五': 0, 60 | '左六': 0, 61 | '右六': 0, 62 | } 63 | 64 | #semantic_new = semantic_image[min_h:max_h, :] 65 | #CityTunnel: 20032514*; 505, 825 66 | #Highway: 14*; 415, 735 67 | #sanhuan: 2002*; 480, 800 68 | #shunyi: frame*; 
281, 505 69 | 70 | def compute_polygon_area(points): 71 | point_num = len(points) 72 | if (point_num < 3): return 0.0 73 | s = points[0][1] * (points[point_num - 1][0] - points[1][0]) 74 | # for i in range(point_num): # (int i = 1 i < point_num ++i): 75 | for i in range(1, point_num): # 有小伙伴发现一个bug,这里做了修改,但是没有测试,需要使用的亲请测试下,以免结果不正确。 76 | s += points[i][1] * (points[i - 1][0] - points[(i + 1) % point_num][0]) 77 | return abs(s / 2.0) 78 | 79 | 80 | def processXml(root_path): 81 | label_dir = os.path.join(root_path, 'gt_xml') 82 | gt_image_dir = os.path.join(root_path, 'gt_image') 83 | gt_semantic_dir = os.path.join(root_path, 'gt_binary_image') 84 | gt_instance_dir = os.path.join(root_path, 'gt_instance_image') 85 | semantic_keys = semantic_label_dict.keys() 86 | instance_keys = instance_label_dict.keys() 87 | 88 | files = os.listdir(label_dir) 89 | for idx, name in enumerate(files): 90 | try: 91 | if not name.endswith('.xml'): 92 | continue 93 | name = name[:-3]+'xml' 94 | png_name = name[:-3] + 'png' 95 | jpg_name = name[:-3] + 'jpg' 96 | src_path = os.path.join(gt_image_dir, jpg_name) 97 | semantic_path = os.path.join(gt_semantic_dir, png_name) 98 | instance_path = os.path.join(gt_instance_dir, png_name) 99 | _path = os.path.join(gt_image_dir, jpg_name) 100 | # src_image = cv2.imread(src_path, cv2.IMREAD_COLOR) 101 | semantic_image = cv2.imread(semantic_path, cv2.IMREAD_GRAYSCALE) 102 | instance_image = cv2.imread(instance_path, cv2.IMREAD_GRAYSCALE) 103 | 104 | xml_path = os.path.join(label_dir, name) 105 | tree = ET.parse(xml_path) 106 | root = tree.getroot() 107 | for item in root.iter('item'): 108 | instance = item.find('name') 109 | semantic_v = -1 110 | instance_v = -1 111 | for semantic_key in semantic_keys: 112 | if instance.text.find(semantic_key) >= 0: 113 | semantic_v = semantic_label_dict[semantic_key] 114 | break 115 | for instance_key in instance_keys: 116 | if instance.text.find(instance_key) >= 0: 117 | # semantic_v = 255 118 | instance_v = instance_label_dict[instance_key] 119 | break 120 | if semantic_v == -1 and instance_v == -1: 121 | print(name, instance.text) 122 | continue 123 | if semantic_v == 0 or instance_v == 0: 124 | continue 125 | 126 | arr = [] 127 | polygon = item.find('polygon') 128 | cubic_bezier = item.find('cubic_bezier') 129 | bndbox = item.find('bndbox') 130 | if polygon is not None: 131 | for xy in polygon: 132 | arr.append(int(xy.text)) 133 | elif cubic_bezier is not None: 134 | for xy in cubic_bezier: 135 | if len(xy.tag) <= 3: 136 | arr.append(int(xy.text)) 137 | elif bndbox is not None: 138 | for xy in bndbox: 139 | arr.append(int(xy.text)) 140 | else: 141 | print('Error boundingbox: %s, %s ' % (name, instance.text)) 142 | # shutil.copy(src_path, os.path.join(root_path, 'errors')) 143 | # shutil.copy(xml_path, os.path.join(root_path, 'errors')) 144 | continue 145 | 146 | pt = [] 147 | for i in range(0, len(arr) - 1, 2): 148 | pt.append([arr[i], arr[i + 1]]) 149 | b = np.array([pt], dtype=np.int32) 150 | s = compute_polygon_area(pt) 151 | # print('v: %d, s: %d' % (instance_v, s)) 152 | if s > 3 and semantic_v > 0: 153 | cv2.fillPoly(semantic_image, b, semantic_v) 154 | if s > 3 and instance_v > 0: 155 | cv2.fillPoly(instance_image, b, instance_v) 156 | if s <= 3: 157 | print('Warning: Too little target: %s, %s, %d' % (name, instance.text, s)) 158 | if semantic_image.max() > 0: 159 | cv2.imwrite(semantic_path, semantic_image) 160 | cv2.imwrite(instance_path, instance_image) 161 | else: 162 | # print('Error null name: ', name) 163 | 
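# No labelled pixels were drawn into the semantic mask for this frame, so its
# annotation file is dropped.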
os.remove(xml_path) 164 | except: 165 | print('error') 166 | 167 | def processJson(root_path): 168 | gt_json_dir = os.path.join(root_path, 'gt_json') 169 | gt_image_dir = os.path.join(root_path, 'gt_image') 170 | gt_semantic_dir = os.path.join(root_path, 'gt_binary_image') 171 | if os.path.exists(gt_semantic_dir): 172 | shutil.rmtree(gt_semantic_dir) 173 | os.mkdir(gt_semantic_dir) 174 | gt_instance_dir = os.path.join(root_path, 'gt_instance_image') 175 | if os.path.exists(gt_instance_dir): 176 | shutil.rmtree(gt_instance_dir) 177 | os.mkdir(gt_instance_dir) 178 | semantic_keys = semantic_label_dict.keys() 179 | instance_keys = instance_label_dict.keys() 180 | 181 | jsons = os.listdir(gt_json_dir) 182 | for name in jsons: 183 | try: 184 | json_path = os.path.join(gt_json_dir, name) 185 | org_im_name = name[:-4] + 'jpg' 186 | semantic_name = name[:-4] + 'png' 187 | instance_name = semantic_name 188 | im = cv2.imread(os.path.join(gt_image_dir, org_im_name), cv2.IMREAD_COLOR) 189 | ss = (im.shape[0], im.shape[1]) 190 | semantic_img_path = os.path.join(gt_semantic_dir, semantic_name) 191 | instance_img_path = os.path.join(gt_instance_dir, instance_name) 192 | semantic_im = np.zeros(ss, np.uint8) 193 | instance_im = np.zeros(ss, np.uint8) 194 | #print(json_path) 195 | with open(json_path, encoding='utf-8-sig', errors='ignore') as f: 196 | info_dict = json.loads(f.read(), strict=False) 197 | info_shapes = info_dict['shapes'] 198 | for info in info_shapes: 199 | semantic_v = -1 200 | instance_v = -1 201 | for semantic_key in semantic_keys: 202 | if info['label'].find(semantic_key) >= 0: 203 | semantic_v = semantic_label_dict[semantic_key] 204 | break 205 | for instance_key in instance_keys: 206 | if info['label'].find(instance_key) >= 0: 207 | # if 208 | # semantic_v = 255 209 | instance_v = instance_label_dict[instance_key] 210 | break 211 | if semantic_v == -1 and instance_v == -1: 212 | print(json_path, info['label']) 213 | continue 214 | 215 | info_pts = info['points'] 216 | # print(info_pts) 217 | pts_nums = len(info_pts) 218 | edge_lines = [] 219 | for pti in range(pts_nums): 220 | x0 = info_pts[pti][0] 221 | y0 = info_pts[pti][1] 222 | edge_lines.append([round(x0),round(y0)]) 223 | b = np.array([edge_lines], dtype=np.int32) 224 | s = compute_polygon_area(edge_lines) 225 | # print('v: %d, s: %d' % (instance_v, s)) 226 | if semantic_v > -1: 227 | cv2.fillPoly(semantic_im, b, semantic_v) 228 | if instance_v > -1: 229 | cv2.fillPoly(instance_im, b, instance_v) 230 | cv2.imwrite(semantic_img_path, semantic_im) 231 | cv2.imwrite(instance_img_path, instance_im) 232 | except: 233 | print('error') 234 | return 235 | 236 | def resizeAll(root_path): 237 | gt_image_dir = os.path.join(root_path, 'gt_image') 238 | gt_semantic_dir = os.path.join(root_path, 'gt_binary_image') 239 | gt_instance_dir = os.path.join(root_path, 'gt_instance_image') 240 | dim = (1280, 720) 241 | pngs = os.listdir(gt_semantic_dir) 242 | for png in pngs: 243 | #semantic_new = semantic_image[min_h:max_h, :] 244 | #CityTunnel: 20032514*; 505, 825 245 | #Highway: 14*; 415, 735 246 | #sanhuan: 2002*; 480, 800 247 | #shunyi: frame*; 281, 505 248 | 249 | 250 | jpg = png[:-3] + 'jpg' 251 | org_im_path = os.path.join(gt_image_dir, jpg) 252 | semantic_im_path = os.path.join(gt_semantic_dir, png) 253 | instance_im_path = os.path.join(gt_instance_dir, png) 254 | org_im = cv2.imread(org_im_path, cv2.IMREAD_COLOR) 255 | semantic_im = cv2.imread(semantic_im_path, cv2.IMREAD_GRAYSCALE) 256 | instance_im = cv2.imread(instance_im_path, 
cv2.IMREAD_GRAYSCALE) 257 | if jpg[:8] == '20032514': #CityTunnel 258 | #print(jpg[:8]) 259 | org_im = org_im[505:825,:] 260 | semantic_im = semantic_im[505:825,:] 261 | instance_im = instance_im[505:825,:] 262 | elif jpg[:2] == '14': #Highway 263 | #print(jpg[:2]) 264 | org_im = org_im[415:735,:] 265 | semantic_im = semantic_im[415:735,:] 266 | instance_im = instance_im[415:735,:] 267 | elif jpg[:4] == '2002': #sanhuan 268 | #print(jpg[:4]) 269 | org_im = org_im[480:800,:] 270 | semantic_im = semantic_im[480:800,:] 271 | instance_im = instance_im[480:800,:] 272 | elif jpg[:5] == 'frame': #shunyi 273 | #print(jpg[:5]) 274 | org_im = org_im[281:505,:] 275 | semantic_im = semantic_im[281:505,:] 276 | instance_im = instance_im[281:505,:] 277 | if org_im.shape[1] > 1280: 278 | dim = (1280, 224) 279 | org_im = cv2.resize(org_im, dim, interpolation = cv2.INTER_CUBIC) 280 | semantic_im = cv2.resize(semantic_im, dim, interpolation = cv2.INTER_NEAREST) 281 | instance_im = cv2.resize(instance_im, dim, interpolation = cv2.INTER_NEAREST) 282 | cv2.imwrite(org_im_path, org_im) 283 | cv2.imwrite(semantic_im_path, semantic_im) 284 | cv2.imwrite(instance_im_path, instance_im) 285 | 286 | def gen_index(root_path): 287 | semantic_path = os.path.join(root_path, 'gt_binary_image') 288 | index_path = os.path.join(root_path, 'index') 289 | if not os.path.exists(index_path): 290 | os.mkdir(index_path) 291 | else: 292 | shutil.rmtree(index_path) 293 | os.mkdir(index_path) 294 | all_txt_path = os.path.join(index_path, 'all.txt') 295 | train_txt_path = os.path.join(index_path, 'train.txt') 296 | trainval_txt_path = os.path.join(index_path, 'test.txt') 297 | val_txt_path = os.path.join(index_path, 'val.txt') 298 | 299 | semantic_list = os.listdir(semantic_path) 300 | for each in semantic_list: 301 | if not each.endswith('.png'): 302 | continue 303 | jpg_each = each[:-3] + 'jpg' 304 | im_path = os.path.join(root_path, 'gt_image', jpg_each) 305 | semantic_path = os.path.join(root_path, 'gt_binary_image', each) 306 | instance_path = os.path.join(root_path, 'gt_instance_image', each) 307 | line = im_path + ' ' + semantic_path + ' ' + instance_path 308 | with open(all_txt_path, 'a') as f: 309 | f.write(line + '\n') 310 | 311 | with open(all_txt_path, 'r') as f: 312 | lines = f.readlines() 313 | g = [i for i in range(len(lines))] 314 | random.shuffle(g) 315 | train = g[:int(len(lines) * 19 / 20)] 316 | #trainval = g[int(len(lines) * 18 / 20):int(len(lines) * 19 / 20)] 317 | val = g[int(len(lines) * 19 / 20):] 318 | 319 | for n, line in enumerate(lines): 320 | if n in train: 321 | with open(train_txt_path, 'a') as trainf: 322 | trainf.write(line) 323 | #elif n in trainval: 324 | #with open(trainval_txt_path, 'a') as trainvalf: 325 | #trainvalf.write(line) 326 | elif n in val: 327 | with open(val_txt_path, 'a') as valf: 328 | valf.write(line) 329 | shutil.copyfile(val_txt_path, trainval_txt_path) 330 | 331 | if __name__ == '__main__': 332 | root_path = '/workspace/mogo_data' 333 | # root_path = 'D:\\yum\\tmcdata\\test123' 334 | processJson(root_path) 335 | print('processJson done') 336 | processXml(root_path) 337 | print('processxml done') 338 | # resizeAll(root_path) 339 | # print('resizeAll done') 340 | gen_index(root_path) 341 | print('======finished!======') 342 | 343 | -------------------------------------------------------------------------------- /lanenet/model/BiseNet_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.utils import model_zoo 6 | from torchvision import models 7 | from lanenet import config 8 | import collections 9 | 10 | class conv2d(nn.Module): 11 | def __init__(self,in_dim,out_dim,k,pad,stride,groups = 1,bias=False,use_bn = True,use_rl = True): 12 | super(conv2d,self).__init__() 13 | self.use_bn = use_bn 14 | self.use_rl = use_rl 15 | self.conv = nn.Conv2d(in_dim,out_dim,k,padding=pad,stride=stride, groups=groups,bias=bias) 16 | self.bn = nn.BatchNorm2d(out_dim) 17 | self.relu = nn.ReLU(inplace=True) 18 | def forward(self,bottom): 19 | if self.use_bn and self.use_rl: 20 | return self.relu(self.bn(self.conv(bottom))) 21 | elif self.use_bn: 22 | return self.bn(self.conv(bottom)) 23 | else: 24 | return self.conv(bottom) 25 | 26 | class SegHead(nn.Module): 27 | def __init__(self,in_dim,out_dim,cls,size=[720,1280]): 28 | super(SegHead,self).__init__() 29 | self.size = size 30 | self.conv = conv2d(in_dim,out_dim,3,1,1) 31 | self.cls = conv2d(out_dim,cls,1,0,1,use_bn=False,use_rl=False) 32 | def forward(self,feat): 33 | x = self.conv(feat) 34 | x = self.cls(x) 35 | pred = F.interpolate(x, size=self.size, mode="bilinear",align_corners=True) 36 | return pred 37 | 38 | class StemBlock(nn.Module): 39 | def __init__(self): 40 | super(StemBlock,self).__init__() 41 | self.conv1 = conv2d(3,16,3,1,2) 42 | self.conv_1x1 = conv2d(16,32,1,0,1) 43 | self.conv_3x3 = conv2d(32,32,3,1,2) 44 | self.mpooling = nn.MaxPool2d(3,2,1) 45 | self.conv2 = conv2d(48,32,3,1,1) 46 | def forward(self,bottom): 47 | base = self.conv1(bottom) 48 | conv_1 = self.conv_1x1(base) 49 | conv_3 = self.conv_3x3(conv_1) 50 | pool = self.mpooling(base) 51 | cat = torch.cat([conv_3,pool],1) 52 | res = self.conv2(cat) 53 | return res 54 | 55 | class ContextEmbeddingBlock(nn.Module): 56 | def __init__(self,in_dim): 57 | super(ContextEmbeddingBlock,self).__init__() 58 | self.gap = nn.AdaptiveAvgPool2d(1)#1 59 | # self.gap = nn.AvgPool2d(3,1,1) 60 | self.bn1 = nn.BatchNorm2d(in_dim) 61 | self.conv1 = conv2d(in_dim,in_dim,1,0,1) 62 | self.conv2 = conv2d(in_dim,in_dim,3,1,1,use_bn = False,use_rl = False) 63 | def forward(self,bottom): 64 | gap = self.gap(bottom) 65 | # print(gap) 66 | bn = self.bn1(gap) 67 | conv1 = self.conv1(bn) 68 | feat = bottom+conv1 69 | res = self.conv2(feat) 70 | return res 71 | 72 | class GatherExpansion(nn.Module): 73 | def __init__(self,in_dim,out_dim,stride = 1,exp = 6): 74 | super(GatherExpansion,self).__init__() 75 | exp_dim = in_dim*exp 76 | self.stride = stride 77 | self.conv1 = conv2d(in_dim,exp_dim,3,1,1) 78 | self.dwconv2 = conv2d(exp_dim,exp_dim,3,1,1,exp_dim,use_rl = False) 79 | self.conv_11 = conv2d(exp_dim,out_dim,1,0,1,use_rl = False) 80 | 81 | self.dwconv1 = conv2d(exp_dim,exp_dim,3,1,2,exp_dim,use_rl = False) 82 | self.dwconv3 = conv2d(in_dim,in_dim,3,1,2,in_dim,use_rl = False) 83 | self.conv_12 = conv2d(in_dim,out_dim,1,0,1,use_rl = False) 84 | self.relu = nn.ReLU(inplace=True) 85 | def forward(self,bottom): 86 | base = self.conv1(bottom) 87 | if self.stride == 2: 88 | base = self.dwconv1(base) 89 | bottom = self.dwconv3(bottom) 90 | bottom = self.conv_12(bottom) 91 | x = self.dwconv2(base) 92 | x = self.conv_11(x) 93 | res = self.relu(x+bottom) 94 | return res 95 | 96 | class BGA(nn.Module): 97 | def __init__(self,in_dim): 98 | super(BGA,self).__init__() 99 | self.in_dim = in_dim 100 | self.db_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 101 | self.db_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 
102 | self.db_conv = conv2d(in_dim,in_dim,3,1,2,use_rl=False) 103 | self.db_apooling = nn.AvgPool2d(3,2,1) 104 | 105 | self.sb_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 106 | self.sb_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 107 | self.sb_conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 108 | self.sb_sigmoid = nn.Sigmoid() 109 | 110 | self.conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 111 | def forward(self,db,sb): 112 | db_dwc = self.db_dwconv(db) 113 | db_out = self.db_conv1x1(db_dwc)# 114 | db_conv = self.db_conv(db) 115 | db_pool = self.db_apooling(db_conv) 116 | 117 | sb_dwc = self.sb_dwconv(sb) 118 | sb_out = self.sb_sigmoid(self.sb_conv1x1(sb_dwc))# 119 | sb_conv = self.sb_conv(sb) 120 | sb_up = self.sb_sigmoid(F.interpolate(sb_conv, size=db_out.size()[2:], mode="bilinear",align_corners=True)) 121 | db_l = db_out*sb_up 122 | sb_r = F.interpolate(sb_out*db_pool, size=db_out.size()[2:], mode="bilinear",align_corners=True) 123 | res = self.conv(db_l+sb_r) 124 | return res 125 | 126 | class DetailedBranch(nn.Module): 127 | def __init__(self): 128 | super(DetailedBranch,self).__init__() 129 | self.s1_conv1 = conv2d(3,64,3,1,2) 130 | self.s1_conv2 = conv2d(64,64,3,1,1) 131 | 132 | self.s2_conv1 = conv2d(64,64,3,1,2) 133 | self.s2_conv2 = conv2d(64,64,3,1,1) 134 | self.s2_conv3 = conv2d(64,64,3,1,1) 135 | 136 | self.s3_conv1 = conv2d(64,128,3,1,2) 137 | self.s3_conv2 = conv2d(128,128,3,1,1) 138 | self.s3_conv3 = conv2d(128,128,3,1,1) 139 | def forward(self,bottom): 140 | 141 | detail_stage_outputs = collections.OrderedDict() 142 | 143 | s1_1 = self.s1_conv1(bottom) 144 | s1_2 = self.s1_conv2(s1_1) 145 | 146 | detail_stage_outputs["stg1"] = s1_2 147 | 148 | s2_1 = self.s2_conv1(s1_2) 149 | s2_2 = self.s2_conv2(s2_1) 150 | s2_3 = self.s2_conv3(s2_2) 151 | 152 | detail_stage_outputs["stg2"] = s2_3 153 | 154 | s3_1 = self.s3_conv1(s2_3) 155 | s3_2 = self.s3_conv2(s3_1) 156 | s3_3 = self.s3_conv3(s3_2) 157 | 158 | detail_stage_outputs["stg3"] = s3_3 159 | 160 | return { 161 | 'out': s3_3, 162 | 'detail_stage_outputs': detail_stage_outputs 163 | } 164 | 165 | class SemanticBranch(nn.Module): 166 | def __init__(self): 167 | super(SemanticBranch,self).__init__() 168 | self.stem = StemBlock() 169 | self.s3_ge1 = GatherExpansion(32,32,2) 170 | self.s3_ge2 = GatherExpansion(32,32) 171 | 172 | self.s4_ge1 = GatherExpansion(32,64,2) 173 | self.s4_ge2 = GatherExpansion(64,64) 174 | 175 | self.s5_ge1 = GatherExpansion(64,128,2) 176 | self.s5_ge2 = GatherExpansion(128,128) 177 | self.s5_ge3 = GatherExpansion(128,128) 178 | self.s5_ge4 = GatherExpansion(128,128) 179 | self.s5_ge5 = GatherExpansion(128,128,exp=1) 180 | 181 | self.ceb = ContextEmbeddingBlock(128) 182 | 183 | def forward(self,bottom): 184 | seg_stage_outputs = collections.OrderedDict() 185 | 186 | stg1 = self.stem(bottom) 187 | #print(stg12.size()) 188 | seg_stage_outputs["stg1"] = stg1 189 | 190 | stg3 = self.s3_ge1(stg1) 191 | stg3 = self.s3_ge2(stg3) 192 | #print(stg3.size()) 193 | seg_stage_outputs["stg3"] = stg3 194 | 195 | stg4 = self.s4_ge1(stg3) 196 | stg4 = self.s4_ge2(stg4) 197 | 198 | seg_stage_outputs["stg4"] = stg4 199 | #print(stg4.size()) 200 | stg5 = self.s5_ge1(stg4) 201 | stg5 = self.s5_ge2(stg5) 202 | stg5 = self.s5_ge3(stg5) 203 | stg5 = self.s5_ge4(stg5) 204 | stg5 = self.s5_ge5(stg5) 205 | 206 | seg_stage_outputs["stg5"] = stg5 207 | #print(stg5.size()) 208 | out = self.ceb(stg5) 209 | 210 | return { 211 | 'out': out, 212 | 'seg_stage_outputs': seg_stage_outputs 213 | } 214 | 215 | 
class InstanceSegmentationBranch(nn.Module): 216 | def __init__(self): 217 | super(InstanceSegmentationBranch,self).__init__() 218 | self.bsconv1 = conv2d(128,256,3,1,1,use_rl=True) 219 | self.bsconv2 = conv2d(256,128,1,0,1,use_rl=True) 220 | def forward(self,data): 221 | input_tensor_size=list(data.size()) 222 | tmp_size=input_tensor_size[2:] 223 | out_put_tensor_size=tuple([int(tmp * 8) for tmp in tmp_size]) 224 | conv1_out=self.bsconv1(data) 225 | conv2_out=self.bsconv2(conv1_out) 226 | isb_out = F.interpolate(conv2_out, size=out_put_tensor_size, mode="bilinear",align_corners=True) 227 | return isb_out 228 | 229 | class BinarySegmentationBranch(nn.Module): 230 | 231 | def __init__(self): 232 | super(BinarySegmentationBranch,self).__init__() 233 | 234 | self.bsconv1_pre = conv2d(128,256,3,1,1,use_rl=True) 235 | self.bsconv1_pre1 = conv2d(256,256,3,1,1,use_rl=True) 236 | self.bsconv1_pre2 = conv2d(352,256,3,1,1,use_rl=True) 237 | self.bsconv1_pre3 = conv2d(256,256,3,1,1,use_rl=True) 238 | self.bsconv1_pre4 = conv2d(320,256,3,1,1,use_rl=True) 239 | self.bsconv1_pre5 = conv2d(256,128,3,1,1,use_rl=True) 240 | self.bsconv3 = conv2d(128,config.num_classes,1,0,1,use_rl=False,use_bn=True) 241 | 242 | 243 | # self.bsconv1 = conv2d(128,256,3,1,1,use_rl=True) 244 | # self.bsconv2 = conv2d(256,128,1,0,1,use_rl=True) 245 | # self.bsconv3 = conv2d(128,config.num_classes,1,0,1,use_rl=False,use_bn=True) 246 | 247 | def forward(self,data,seg_stage_outputs,detail_stage_outputs): 248 | 249 | input_tensor_size=list(data.size()) 250 | tmp_size=input_tensor_size[2:] 251 | output_stage2_size=[int(tmp * 2) for tmp in tmp_size] 252 | output_stage1_size=[int(tmp * 4) for tmp in tmp_size] 253 | out_put_tensor_size=[int(tmp * 8) for tmp in tmp_size] 254 | 255 | out1=self.bsconv1_pre(data) 256 | # channel=256 257 | out2=self.bsconv1_pre1(out1) 258 | 259 | output_stage2_tensor = F.interpolate(out2, size=tuple(output_stage2_size), mode="bilinear",align_corners=True) 260 | # output_stage2_tensor = tf.concat([output_stage2_tensor, detail_stage_outputs['stage_2'], semantic_stage_outputs['stage_1']], axis=-1, name='stage_2_concate_features') 261 | output_stage2_tensor=torch.cat([output_stage2_tensor,detail_stage_outputs['stg2'], seg_stage_outputs['stg1']],1) 262 | output_stage2_tensor=self.bsconv1_pre2(output_stage2_tensor) 263 | output_stage2_tensor=self.bsconv1_pre3(output_stage2_tensor) 264 | # channel=256 265 | output_stage1_tensor = F.interpolate(output_stage2_tensor, size=tuple(output_stage1_size), mode="bilinear",align_corners=True) 266 | #output_stage1_tensor = tf.concat([output_stage1_tensor, detail_stage_outputs['stage_1']], axis=-1, name='stage_1_concate_features') 267 | output_stage1_tensor=torch.cat([output_stage1_tensor,detail_stage_outputs['stg1']],1) 268 | output_stage1_tensor=self.bsconv1_pre4(output_stage1_tensor) 269 | output_stage1_tensor=self.bsconv1_pre5(output_stage1_tensor) 270 | output_stage1_tensor=self.bsconv3(output_stage1_tensor) 271 | 272 | # conv_out1=self.bsconv1(data) 273 | # conv_out2=self.bsconv2(conv_out1) 274 | # conv_out3=self.bsconv3(conv_out2) 275 | 276 | # print("--------------------conv_out3 size--------------------") 277 | # print(list(conv_out3.size())) 278 | # print("--------------------bga size--------------------") 279 | 280 | # print("--------------------out_put_tensor_size--------------------") 281 | # print(out_put_tensor_size) 282 | # print(tuple(out_put_tensor_size)) 283 | # print("--------------------out_put_tensor_size--------------------") 284 | bsb_out = 
F.interpolate(output_stage1_tensor, size=tuple(out_put_tensor_size), mode="bilinear",align_corners=True) 285 | return bsb_out 286 | 287 | class BiSeNet(nn.Module): 288 | def __init__(self): 289 | super(BiSeNet, self).__init__() 290 | self.db = DetailedBranch() 291 | self.sb = SemanticBranch() 292 | self.bga = BGA(128) 293 | self._init_params() 294 | self.criterion = nn.CrossEntropyLoss(ignore_index=255) 295 | self.binarySegmentationBranch=BinarySegmentationBranch() 296 | self.instanceSegmentationBranch=InstanceSegmentationBranch() 297 | def _init_params(self): 298 | for m in self.modules(): 299 | if isinstance(m, nn.Conv2d): 300 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 301 | if m.bias is not None: 302 | nn.init.constant_(m.bias, 0) 303 | elif isinstance(m, nn.BatchNorm2d): 304 | nn.init.constant_(m.weight, 1) 305 | nn.init.constant_(m.bias, 0) 306 | elif isinstance(m, nn.BatchNorm1d): 307 | nn.init.constant_(m.weight, 1) 308 | nn.init.constant_(m.bias, 0) 309 | elif isinstance(m, nn.Linear): 310 | nn.init.normal_(m.weight, 0, 0.01) 311 | if m.bias is not None: 312 | nn.init.constant_(m.bias, 0) 313 | def forward(self,data,y=None): 314 | db = self.db(data) 315 | sb = self.sb(data) 316 | bga = self.bga(db["out"],sb["out"]) 317 | bsb_res=self.binarySegmentationBranch(bga,sb["seg_stage_outputs"],db["detail_stage_outputs"]) 318 | isb_res=self.instanceSegmentationBranch(bga) 319 | return { 320 | 'instance_seg_logits': isb_res, 321 | # 'binary_seg_pred': binary_seg_ret, 322 | 'binary_seg_logits': bsb_res 323 | } 324 | 325 | if __name__ == '__main__': 326 | import os 327 | os.environ["CUDA_VISIBLE_DEVICES"] = '5' 328 | input = torch.rand(1, 3, 256, 512) 329 | model = BiSeNet() 330 | model.eval() 331 | print(model) 332 | output = model(input) 333 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 334 | print('BiSeNet_v2', output["instance_seg_logits"].size()) 335 | print('BiSeNet_v2', output["binary_seg_logits"].size()) 336 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 337 | -------------------------------------------------------------------------------- /lanenet/model/BiseNet_v2-bak.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.utils import model_zoo 6 | from torchvision import models 7 | from lanenet import config 8 | import collections 9 | 10 | class conv2d(nn.Module): 11 | def __init__(self,in_dim,out_dim,k,pad,stride,groups = 1,bias=False,use_bn = True,use_rl = True): 12 | super(conv2d,self).__init__() 13 | self.use_bn = use_bn 14 | self.use_rl = use_rl 15 | self.conv = nn.Conv2d(in_dim,out_dim,k,padding=pad,stride=stride, groups=groups,bias=bias) 16 | self.bn = nn.BatchNorm2d(out_dim) 17 | self.relu = nn.ReLU(inplace=True) 18 | def forward(self,bottom): 19 | if self.use_bn and self.use_rl: 20 | return self.relu(self.bn(self.conv(bottom))) 21 | elif self.use_bn: 22 | return self.bn(self.conv(bottom)) 23 | else: 24 | return self.conv(bottom) 25 | 26 | class SegHead(nn.Module): 27 | def __init__(self,in_dim,out_dim,cls,size=[720,1280]): 28 | super(SegHead,self).__init__() 29 | self.size = size 30 | self.conv = conv2d(in_dim,out_dim,3,1,1) 31 | self.cls = conv2d(out_dim,cls,1,0,1,use_bn=False,use_rl=False) 32 | def forward(self,feat): 33 | x = self.conv(feat) 34 | x = self.cls(x) 35 | pred = F.interpolate(x, size=self.size, 
mode="bilinear",align_corners=True) 36 | return pred 37 | 38 | class StemBlock(nn.Module): 39 | def __init__(self): 40 | super(StemBlock,self).__init__() 41 | self.conv1 = conv2d(3,16,3,1,2) 42 | self.conv_1x1 = conv2d(16,8,1,0,1) 43 | self.conv_3x3 = conv2d(8,16,3,1,2) 44 | self.mpooling = nn.MaxPool2d(3,2,1) 45 | self.conv2 = conv2d(32,16,3,1,1) 46 | def forward(self,bottom): 47 | base = self.conv1(bottom) 48 | conv_1 = self.conv_1x1(base) 49 | conv_3 = self.conv_3x3(conv_1) 50 | pool = self.mpooling(base) 51 | cat = torch.cat([conv_3,pool],1) 52 | res = self.conv2(cat) 53 | return res 54 | 55 | class ContextEmbeddingBlock(nn.Module): 56 | def __init__(self,in_dim): 57 | super(ContextEmbeddingBlock,self).__init__() 58 | self.gap = nn.AdaptiveAvgPool2d(1)#1 59 | self.bn1 = nn.BatchNorm2d(in_dim) 60 | self.conv1 = conv2d(in_dim,in_dim,1,0,1) 61 | self.conv2 = conv2d(in_dim,in_dim,3,1,1,use_bn = False,use_rl = False) 62 | def forward(self,bottom): 63 | gap = self.gap(bottom) 64 | bn = self.bn1(gap) 65 | conv1 = self.conv1(bn) 66 | feat = bottom+conv1 67 | res = self.conv2(feat) 68 | return res 69 | 70 | class GatherExpansion(nn.Module): 71 | def __init__(self,in_dim,out_dim,stride = 1,exp = 6): 72 | super(GatherExpansion,self).__init__() 73 | exp_dim = in_dim*exp 74 | self.stride = stride 75 | self.conv1 = conv2d(in_dim,exp_dim,3,1,1) 76 | self.dwconv2 = conv2d(exp_dim,exp_dim,3,1,1,exp_dim,use_rl = False) 77 | self.conv_11 = conv2d(exp_dim,out_dim,1,0,1,use_rl = False) 78 | 79 | self.dwconv1 = conv2d(exp_dim,exp_dim,3,1,2,exp_dim,use_rl = False) 80 | self.dwconv3 = conv2d(in_dim,in_dim,3,1,2,in_dim,use_rl = False) 81 | self.conv_12 = conv2d(in_dim,out_dim,1,0,1,use_rl = False) 82 | self.relu = nn.ReLU(inplace=True) 83 | def forward(self,bottom): 84 | base = self.conv1(bottom) 85 | if self.stride == 2: 86 | base = self.dwconv1(base) 87 | bottom = self.dwconv3(bottom) 88 | bottom = self.conv_12(bottom) 89 | x = self.dwconv2(base) 90 | x = self.conv_11(x) 91 | res = self.relu(x+bottom) 92 | return res 93 | 94 | class BGA(nn.Module): 95 | def __init__(self,in_dim): 96 | super(BGA,self).__init__() 97 | self.in_dim = in_dim 98 | self.db_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 99 | self.db_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 100 | self.db_conv = conv2d(in_dim,in_dim,3,1,2,use_rl=False) 101 | self.db_apooling = nn.AvgPool2d(3,2,1) 102 | 103 | self.sb_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 104 | self.sb_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 105 | self.sb_conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 106 | self.sb_sigmoid = nn.Sigmoid() 107 | 108 | self.conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 109 | def forward(self,db,sb): 110 | db_dwc = self.db_dwconv(db) 111 | db_out = self.db_conv1x1(db_dwc)# 112 | db_conv = self.db_conv(db) 113 | db_pool = self.db_apooling(db_conv) 114 | 115 | sb_dwc = self.sb_dwconv(sb) 116 | sb_out = self.sb_sigmoid(self.sb_conv1x1(sb_dwc))# 117 | sb_conv = self.sb_conv(sb) 118 | sb_up = self.sb_sigmoid(F.interpolate(sb_conv, size=db_out.size()[2:], mode="bilinear",align_corners=True)) 119 | db_l = db_out*sb_up 120 | sb_r = F.interpolate(sb_out*db_pool, size=db_out.size()[2:], mode="bilinear",align_corners=True) 121 | res = self.conv(db_l+sb_r) 122 | return res 123 | 124 | class DetailedBranch(nn.Module): 125 | def __init__(self): 126 | super(DetailedBranch,self).__init__() 127 | self.s1_conv1 = conv2d(3,64,3,1,2) 128 | self.s1_conv2 = conv2d(64,64,3,1,1) 129 | 130 | 
self.s2_conv1 = conv2d(64,64,3,1,2) 131 | self.s2_conv2 = conv2d(64,64,3,1,1) 132 | self.s2_conv3 = conv2d(64,64,3,1,1) 133 | 134 | self.s3_conv1 = conv2d(64,128,3,1,2) 135 | self.s3_conv2 = conv2d(128,128,3,1,1) 136 | self.s3_conv3 = conv2d(128,128,3,1,1) 137 | def forward(self,bottom): 138 | 139 | detail_stage_outputs = collections.OrderedDict() 140 | 141 | s1_1 = self.s1_conv1(bottom) 142 | s1_2 = self.s1_conv2(s1_1) 143 | 144 | detail_stage_outputs["stg1"] = s1_2 145 | 146 | s2_1 = self.s2_conv1(s1_2) 147 | s2_2 = self.s2_conv2(s2_1) 148 | s2_3 = self.s2_conv3(s2_2) 149 | 150 | detail_stage_outputs["stg2"] = s2_3 151 | 152 | s3_1 = self.s3_conv1(s2_3) 153 | s3_2 = self.s3_conv2(s3_1) 154 | s3_3 = self.s3_conv3(s3_2) 155 | 156 | detail_stage_outputs["stg3"] = s3_3 157 | 158 | return { 159 | 'out': s3_3, 160 | 'detail_stage_outputs': detail_stage_outputs 161 | } 162 | 163 | class SemanticBranch(nn.Module): 164 | def __init__(self): 165 | super(SemanticBranch,self).__init__() 166 | self.stem = StemBlock() 167 | self.s3_ge1 = GatherExpansion(16,32,2) 168 | self.s3_ge2 = GatherExpansion(32,32) 169 | 170 | self.s4_ge1 = GatherExpansion(32,64,2) 171 | self.s4_ge2 = GatherExpansion(64,64) 172 | 173 | self.s5_ge1 = GatherExpansion(64,128,2) 174 | self.s5_ge2 = GatherExpansion(128,128) 175 | self.s5_ge3 = GatherExpansion(128,128) 176 | self.s5_ge4 = GatherExpansion(128,128) 177 | self.s5_ge5 = GatherExpansion(128,128,exp=1) 178 | 179 | self.ceb = ContextEmbeddingBlock(128) 180 | 181 | def forward(self,bottom): 182 | seg_stage_outputs = collections.OrderedDict() 183 | 184 | stg1 = self.stem(bottom) 185 | #print(stg12.size()) 186 | seg_stage_outputs["stg1"] = stg1 187 | 188 | stg3 = self.s3_ge1(stg1) 189 | stg3 = self.s3_ge2(stg3) 190 | #print(stg3.size()) 191 | seg_stage_outputs["stg3"] = stg3 192 | 193 | stg4 = self.s4_ge1(stg3) 194 | stg4 = self.s4_ge2(stg4) 195 | 196 | seg_stage_outputs["stg4"] = stg4 197 | #print(stg4.size()) 198 | stg5 = self.s5_ge1(stg4) 199 | stg5 = self.s5_ge2(stg5) 200 | stg5 = self.s5_ge3(stg5) 201 | stg5 = self.s5_ge4(stg5) 202 | stg5 = self.s5_ge5(stg5) 203 | 204 | seg_stage_outputs["stg5"] = stg5 205 | #print(stg5.size()) 206 | out = self.ceb(stg5) 207 | 208 | return { 209 | 'out': out, 210 | 'seg_stage_outputs': seg_stage_outputs 211 | } 212 | 213 | class InstanceSegmentationBranch(nn.Module): 214 | def __init__(self): 215 | super(InstanceSegmentationBranch,self).__init__() 216 | self.bsconv1 = conv2d(128,256,3,1,1,use_rl=True) 217 | self.bsconv2 = conv2d(256,128,1,0,1,use_rl=True) 218 | def forward(self,data): 219 | input_tensor_size=list(data.size()) 220 | tmp_size=input_tensor_size[2:] 221 | out_put_tensor_size=tuple([int(tmp * 8) for tmp in tmp_size]) 222 | conv1_out=self.bsconv1(data) 223 | conv2_out=self.bsconv2(conv1_out) 224 | isb_out = F.interpolate(conv2_out, size=out_put_tensor_size, mode="bilinear",align_corners=True) 225 | return isb_out 226 | 227 | class BinarySegmentationBranch(nn.Module): 228 | 229 | def __init__(self): 230 | super(BinarySegmentationBranch,self).__init__() 231 | 232 | self.bsconv1_pre = conv2d(128,32,3,1,1,use_rl=True) 233 | self.bsconv1_pre1 = conv2d(32,64,3,1,1,use_rl=True) 234 | self.bsconv1_pre2 = conv2d(144,64,3,1,1,use_rl=True) 235 | self.bsconv1_pre3 = conv2d(64,128,3,1,1,use_rl=True) 236 | self.bsconv1_pre4 = conv2d(192,64,3,1,1,use_rl=True) 237 | self.bsconv1_pre5 = conv2d(64,128,3,1,1,use_rl=True) 238 | self.bsconv3 = conv2d(128,config.num_classes,1,0,1,use_rl=False,use_bn=True) 239 | 240 | 241 | # self.bsconv1 = 
conv2d(128,256,3,1,1,use_rl=True) 242 | # self.bsconv2 = conv2d(256,128,1,0,1,use_rl=True) 243 | # self.bsconv3 = conv2d(128,config.num_classes,1,0,1,use_rl=False,use_bn=True) 244 | 245 | def forward(self,data,seg_stage_outputs,detail_stage_outputs): 246 | 247 | input_tensor_size=list(data.size()) 248 | tmp_size=input_tensor_size[2:] 249 | output_stage2_size=[int(tmp * 2) for tmp in tmp_size] 250 | output_stage1_size=[int(tmp * 4) for tmp in tmp_size] 251 | out_put_tensor_size=[int(tmp * 8) for tmp in tmp_size] 252 | 253 | out1=self.bsconv1_pre(data) 254 | out2=self.bsconv1_pre1(out1) 255 | output_stage2_tensor = F.interpolate(out2, size=tuple(output_stage2_size), mode="bilinear",align_corners=True) 256 | # output_stage2_tensor = tf.concat([output_stage2_tensor, detail_stage_outputs['stage_2'], semantic_stage_outputs['stage_1']], axis=-1, name='stage_2_concate_features') 257 | output_stage2_tensor=torch.cat([output_stage2_tensor,detail_stage_outputs['stg2'], seg_stage_outputs['stg1']],1) 258 | output_stage2_tensor=self.bsconv1_pre2(output_stage2_tensor) 259 | output_stage2_tensor=self.bsconv1_pre3(output_stage2_tensor) 260 | output_stage1_tensor = F.interpolate(output_stage2_tensor, size=tuple(output_stage1_size), mode="bilinear",align_corners=True) 261 | #output_stage1_tensor = tf.concat([output_stage1_tensor, detail_stage_outputs['stage_1']], axis=-1, name='stage_1_concate_features') 262 | output_stage1_tensor=torch.cat([output_stage1_tensor,detail_stage_outputs['stg1']],1) 263 | output_stage1_tensor=self.bsconv1_pre4(output_stage1_tensor) 264 | output_stage1_tensor=self.bsconv1_pre5(output_stage1_tensor) 265 | output_stage1_tensor=self.bsconv3(output_stage1_tensor) 266 | 267 | # conv_out1=self.bsconv1(data) 268 | # conv_out2=self.bsconv2(conv_out1) 269 | # conv_out3=self.bsconv3(conv_out2) 270 | 271 | # print("--------------------conv_out3 size--------------------") 272 | # print(list(conv_out3.size())) 273 | # print("--------------------bga size--------------------") 274 | 275 | # print("--------------------out_put_tensor_size--------------------") 276 | # print(out_put_tensor_size) 277 | # print(tuple(out_put_tensor_size)) 278 | # print("--------------------out_put_tensor_size--------------------") 279 | bsb_out = F.interpolate(output_stage1_tensor, size=tuple(out_put_tensor_size), mode="bilinear",align_corners=True) 280 | return bsb_out 281 | 282 | class BiSeNet(nn.Module): 283 | def __init__(self): 284 | super(BiSeNet, self).__init__() 285 | self.db = DetailedBranch() 286 | self.sb = SemanticBranch() 287 | self.bga = BGA(128) 288 | self._init_params() 289 | self.criterion = nn.CrossEntropyLoss(ignore_index=255) 290 | self.binarySegmentationBranch=BinarySegmentationBranch() 291 | self.instanceSegmentationBranch=InstanceSegmentationBranch() 292 | def _init_params(self): 293 | for m in self.modules(): 294 | if isinstance(m, nn.Conv2d): 295 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 296 | if m.bias is not None: 297 | nn.init.constant_(m.bias, 0) 298 | elif isinstance(m, nn.BatchNorm2d): 299 | nn.init.constant_(m.weight, 1) 300 | nn.init.constant_(m.bias, 0) 301 | elif isinstance(m, nn.BatchNorm1d): 302 | nn.init.constant_(m.weight, 1) 303 | nn.init.constant_(m.bias, 0) 304 | elif isinstance(m, nn.Linear): 305 | nn.init.normal_(m.weight, 0, 0.01) 306 | if m.bias is not None: 307 | nn.init.constant_(m.bias, 0) 308 | def forward(self,data,y=None): 309 | db = self.db(data) 310 | sb = self.sb(data) 311 | bga = self.bga(db["out"],sb["out"]) 312 | # 
print("--------------------bga size--------------------") 313 | # print(bga.size()) 314 | # print("--------------------bga size--------------------") 315 | bsb_res=self.binarySegmentationBranch(bga,sb["seg_stage_outputs"],db["detail_stage_outputs"]) 316 | isb_res=self.instanceSegmentationBranch(bga) 317 | return { 318 | 'instance_seg_logits': isb_res, 319 | # 'binary_seg_pred': binary_seg_ret, 320 | 'binary_seg_logits': bsb_res 321 | } 322 | 323 | if __name__ == '__main__': 324 | import os 325 | os.environ["CUDA_VISIBLE_DEVICES"] = '5' 326 | input = torch.rand(1, 3, 256, 512).cuda() 327 | model = BiSeNet().cuda() 328 | model.eval() 329 | print(model) 330 | output = model(input) 331 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 332 | print('BiSeNet_v2', output["instance_seg_logits"].size()) 333 | print('BiSeNet_v2', output["binary_seg_logits"].size()) 334 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 335 | -------------------------------------------------------------------------------- /lanenet/model/BiseNet_v2-2021-04-21.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.utils import model_zoo 6 | from torchvision import models 7 | from lanenet import config 8 | import collections 9 | 10 | class conv2d(nn.Module): 11 | def __init__(self,in_dim,out_dim,k,pad,stride,groups = 1,bias=False,use_bn = True,use_rl = True): 12 | super(conv2d,self).__init__() 13 | self.use_bn = use_bn 14 | self.use_rl = use_rl 15 | self.conv = nn.Conv2d(in_dim,out_dim,k,padding=pad,stride=stride, groups=groups,bias=bias) 16 | self.bn = nn.BatchNorm2d(out_dim) 17 | self.relu = nn.ReLU(inplace=True) 18 | def forward(self,bottom): 19 | if self.use_bn and self.use_rl: 20 | return self.relu(self.bn(self.conv(bottom))) 21 | elif self.use_bn: 22 | return self.bn(self.conv(bottom)) 23 | else: 24 | return self.conv(bottom) 25 | 26 | class SegHead(nn.Module): 27 | def __init__(self,in_dim,out_dim,cls,size=[720,1280]): 28 | super(SegHead,self).__init__() 29 | self.size = size 30 | self.conv = conv2d(in_dim,out_dim,3,1,1) 31 | self.cls = conv2d(out_dim,cls,1,0,1,use_bn=False,use_rl=False) 32 | def forward(self,feat): 33 | x = self.conv(feat) 34 | x = self.cls(x) 35 | pred = F.interpolate(x, size=self.size, mode="bilinear",align_corners=True) 36 | return pred 37 | 38 | class StemBlock(nn.Module): 39 | def __init__(self): 40 | super(StemBlock,self).__init__() 41 | self.conv1 = conv2d(3,16,3,1,2) 42 | self.conv_1x1 = conv2d(16,8,1,0,1) 43 | self.conv_3x3 = conv2d(8,16,3,1,2) 44 | self.mpooling = nn.MaxPool2d(3,2,1) 45 | self.conv2 = conv2d(32,16,3,1,1) 46 | def forward(self,bottom): 47 | base = self.conv1(bottom) 48 | conv_1 = self.conv_1x1(base) 49 | conv_3 = self.conv_3x3(conv_1) 50 | pool = self.mpooling(base) 51 | cat = torch.cat([conv_3,pool],1) 52 | res = self.conv2(cat) 53 | return res 54 | 55 | class ContextEmbeddingBlock(nn.Module): 56 | def __init__(self,in_dim): 57 | super(ContextEmbeddingBlock,self).__init__() 58 | self.gap = nn.AdaptiveAvgPool2d(1)#1 59 | self.bn1 = nn.BatchNorm2d(in_dim) 60 | self.conv1 = conv2d(in_dim,in_dim,1,0,1) 61 | self.conv2 = conv2d(in_dim,in_dim,3,1,1,use_bn = False,use_rl = False) 62 | def forward(self,bottom): 63 | gap = self.gap(bottom) 64 | bn = self.bn1(gap) 65 | conv1 = self.conv1(bn) 66 | feat = bottom+conv1 67 | res = self.conv2(feat) 68 | return res 69 | 70 | class 
GatherExpansion(nn.Module): 71 | def __init__(self,in_dim,out_dim,stride = 1,exp = 6): 72 | super(GatherExpansion,self).__init__() 73 | exp_dim = in_dim*exp 74 | self.stride = stride 75 | self.conv1 = conv2d(in_dim,exp_dim,3,1,1) 76 | self.dwconv2 = conv2d(exp_dim,exp_dim,3,1,1,exp_dim,use_rl = False) 77 | self.conv_11 = conv2d(exp_dim,out_dim,1,0,1,use_rl = False) 78 | 79 | self.dwconv1 = conv2d(exp_dim,exp_dim,3,1,2,exp_dim,use_rl = False) 80 | self.dwconv3 = conv2d(in_dim,in_dim,3,1,2,in_dim,use_rl = False) 81 | self.conv_12 = conv2d(in_dim,out_dim,1,0,1,use_rl = False) 82 | self.relu = nn.ReLU(inplace=True) 83 | def forward(self,bottom): 84 | base = self.conv1(bottom) 85 | if self.stride == 2: 86 | base = self.dwconv1(base) 87 | bottom = self.dwconv3(bottom) 88 | bottom = self.conv_12(bottom) 89 | x = self.dwconv2(base) 90 | x = self.conv_11(x) 91 | res = self.relu(x+bottom) 92 | return res 93 | 94 | class BGA(nn.Module): 95 | def __init__(self,in_dim): 96 | super(BGA,self).__init__() 97 | self.in_dim = in_dim 98 | self.db_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 99 | self.db_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 100 | self.db_conv = conv2d(in_dim,in_dim,3,1,2,use_rl=False) 101 | self.db_apooling = nn.AvgPool2d(3,2,1) 102 | 103 | self.sb_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 104 | self.sb_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 105 | self.sb_conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 106 | self.sb_sigmoid = nn.Sigmoid() 107 | 108 | self.conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 109 | def forward(self,db,sb): 110 | db_dwc = self.db_dwconv(db) 111 | db_out = self.db_conv1x1(db_dwc)# 112 | db_conv = self.db_conv(db) 113 | db_pool = self.db_apooling(db_conv) 114 | 115 | sb_dwc = self.sb_dwconv(sb) 116 | sb_out = self.sb_sigmoid(self.sb_conv1x1(sb_dwc))# 117 | sb_conv = self.sb_conv(sb) 118 | sb_up = self.sb_sigmoid(F.interpolate(sb_conv, size=db_out.size()[2:], mode="bilinear",align_corners=True)) 119 | db_l = db_out*sb_up 120 | sb_r = F.interpolate(sb_out*db_pool, size=db_out.size()[2:], mode="bilinear",align_corners=True) 121 | res = self.conv(db_l+sb_r) 122 | return res 123 | 124 | class DetailedBranch(nn.Module): 125 | def __init__(self): 126 | super(DetailedBranch,self).__init__() 127 | self.s1_conv1 = conv2d(3,64,3,1,2) 128 | self.s1_conv2 = conv2d(64,64,3,1,1) 129 | 130 | self.s2_conv1 = conv2d(64,64,3,1,2) 131 | self.s2_conv2 = conv2d(64,64,3,1,1) 132 | self.s2_conv3 = conv2d(64,64,3,1,1) 133 | 134 | self.s3_conv1 = conv2d(64,128,3,1,2) 135 | self.s3_conv2 = conv2d(128,128,3,1,1) 136 | self.s3_conv3 = conv2d(128,128,3,1,1) 137 | def forward(self,bottom): 138 | 139 | detail_stage_outputs = collections.OrderedDict() 140 | 141 | s1_1 = self.s1_conv1(bottom) 142 | s1_2 = self.s1_conv2(s1_1) 143 | 144 | detail_stage_outputs["stg1"] = s1_2 145 | 146 | s2_1 = self.s2_conv1(s1_2) 147 | s2_2 = self.s2_conv2(s2_1) 148 | s2_3 = self.s2_conv3(s2_2) 149 | 150 | detail_stage_outputs["stg2"] = s2_3 151 | 152 | s3_1 = self.s3_conv1(s2_3) 153 | s3_2 = self.s3_conv2(s3_1) 154 | s3_3 = self.s3_conv3(s3_2) 155 | 156 | detail_stage_outputs["stg3"] = s3_3 157 | 158 | return { 159 | 'out': s3_3, 160 | 'detail_stage_outputs': detail_stage_outputs 161 | } 162 | 163 | class SemanticBranch(nn.Module): 164 | def __init__(self): 165 | super(SemanticBranch,self).__init__() 166 | self.stem = StemBlock() 167 | self.s3_ge1 = GatherExpansion(16,32,2) 168 | self.s3_ge2 = GatherExpansion(32,32) 169 | 170 | self.s4_ge1 = 
GatherExpansion(32,64,2) 171 | self.s4_ge2 = GatherExpansion(64,64) 172 | 173 | self.s5_ge1 = GatherExpansion(64,128,2) 174 | self.s5_ge2 = GatherExpansion(128,128) 175 | self.s5_ge3 = GatherExpansion(128,128) 176 | self.s5_ge4 = GatherExpansion(128,128) 177 | self.s5_ge5 = GatherExpansion(128,128,exp=1) 178 | 179 | self.ceb = ContextEmbeddingBlock(128) 180 | 181 | def forward(self,bottom): 182 | seg_stage_outputs = collections.OrderedDict() 183 | 184 | stg1 = self.stem(bottom) 185 | #print(stg12.size()) 186 | seg_stage_outputs["stg1"] = stg1 187 | 188 | stg3 = self.s3_ge1(stg1) 189 | stg3 = self.s3_ge2(stg3) 190 | #print(stg3.size()) 191 | seg_stage_outputs["stg3"] = stg3 192 | 193 | stg4 = self.s4_ge1(stg3) 194 | stg4 = self.s4_ge2(stg4) 195 | 196 | seg_stage_outputs["stg4"] = stg4 197 | #print(stg4.size()) 198 | stg5 = self.s5_ge1(stg4) 199 | stg5 = self.s5_ge2(stg5) 200 | stg5 = self.s5_ge3(stg5) 201 | stg5 = self.s5_ge4(stg5) 202 | stg5 = self.s5_ge5(stg5) 203 | 204 | seg_stage_outputs["stg5"] = stg5 205 | #print(stg5.size()) 206 | out = self.ceb(stg5) 207 | 208 | return { 209 | 'out': out, 210 | 'seg_stage_outputs': seg_stage_outputs 211 | } 212 | 213 | class InstanceSegmentationBranch(nn.Module): 214 | def __init__(self): 215 | super(InstanceSegmentationBranch,self).__init__() 216 | self.bsconv1 = conv2d(128,256,3,1,1,use_rl=True) 217 | self.bsconv2 = conv2d(256,128,1,0,1,use_rl=True) 218 | def forward(self,data): 219 | input_tensor_size=list(data.size()) 220 | tmp_size=input_tensor_size[2:] 221 | out_put_tensor_size=tuple([int(tmp * 8) for tmp in tmp_size]) 222 | conv1_out=self.bsconv1(data) 223 | conv2_out=self.bsconv2(conv1_out) 224 | isb_out = F.interpolate(conv2_out, size=out_put_tensor_size, mode="bilinear",align_corners=True) 225 | return isb_out 226 | 227 | class BinarySegmentationBranch(nn.Module): 228 | 229 | def __init__(self): 230 | super(BinarySegmentationBranch,self).__init__() 231 | 232 | self.bsconv1_pre = conv2d(128,32,3,1,1,use_rl=True) 233 | self.bsconv1_pre1 = conv2d(32,64,3,1,1,use_rl=True) 234 | self.bsconv1_pre2 = conv2d(144,64,3,1,1,use_rl=True) 235 | self.bsconv1_pre3 = conv2d(64,128,3,1,1,use_rl=True) 236 | self.bsconv1_pre4 = conv2d(192,64,3,1,1,use_rl=True) 237 | self.bsconv1_pre5 = conv2d(64,128,3,1,1,use_rl=True) 238 | self.bsconv3 = conv2d(128,config.num_classes,1,0,1,use_rl=False,use_bn=True) 239 | 240 | 241 | # self.bsconv1 = conv2d(128,256,3,1,1,use_rl=True) 242 | # self.bsconv2 = conv2d(256,128,1,0,1,use_rl=True) 243 | # self.bsconv3 = conv2d(128,config.num_classes,1,0,1,use_rl=False,use_bn=True) 244 | 245 | def forward(self,data,seg_stage_outputs,detail_stage_outputs): 246 | 247 | input_tensor_size=list(data.size()) 248 | tmp_size=input_tensor_size[2:] 249 | output_stage2_size=[int(tmp * 2) for tmp in tmp_size] 250 | output_stage1_size=[int(tmp * 4) for tmp in tmp_size] 251 | out_put_tensor_size=[int(tmp * 8) for tmp in tmp_size] 252 | 253 | out1=self.bsconv1_pre(data) 254 | out2=self.bsconv1_pre1(out1) 255 | output_stage2_tensor = F.interpolate(out2, size=tuple(output_stage2_size), mode="bilinear",align_corners=True) 256 | # output_stage2_tensor = tf.concat([output_stage2_tensor, detail_stage_outputs['stage_2'], semantic_stage_outputs['stage_1']], axis=-1, name='stage_2_concate_features') 257 | output_stage2_tensor=torch.cat([output_stage2_tensor,detail_stage_outputs['stg2'], seg_stage_outputs['stg1']],1) 258 | output_stage2_tensor=self.bsconv1_pre2(output_stage2_tensor) 259 | output_stage2_tensor=self.bsconv1_pre3(output_stage2_tensor) 260 | 
output_stage1_tensor = F.interpolate(output_stage2_tensor, size=tuple(output_stage1_size), mode="bilinear",align_corners=True) 261 | #output_stage1_tensor = tf.concat([output_stage1_tensor, detail_stage_outputs['stage_1']], axis=-1, name='stage_1_concate_features') 262 | output_stage1_tensor=torch.cat([output_stage1_tensor,detail_stage_outputs['stg1']],1) 263 | output_stage1_tensor=self.bsconv1_pre4(output_stage1_tensor) 264 | output_stage1_tensor=self.bsconv1_pre5(output_stage1_tensor) 265 | output_stage1_tensor=self.bsconv3(output_stage1_tensor) 266 | 267 | # conv_out1=self.bsconv1(data) 268 | # conv_out2=self.bsconv2(conv_out1) 269 | # conv_out3=self.bsconv3(conv_out2) 270 | 271 | # print("--------------------conv_out3 size--------------------") 272 | # print(list(conv_out3.size())) 273 | # print("--------------------bga size--------------------") 274 | 275 | # print("--------------------out_put_tensor_size--------------------") 276 | # print(out_put_tensor_size) 277 | # print(tuple(out_put_tensor_size)) 278 | # print("--------------------out_put_tensor_size--------------------") 279 | bsb_out = F.interpolate(output_stage1_tensor, size=tuple(out_put_tensor_size), mode="bilinear",align_corners=True) 280 | return bsb_out 281 | 282 | class BiSeNet(nn.Module): 283 | def __init__(self): 284 | super(BiSeNet, self).__init__() 285 | self.db = DetailedBranch() 286 | self.sb = SemanticBranch() 287 | self.bga = BGA(128) 288 | self._init_params() 289 | self.criterion = nn.CrossEntropyLoss(ignore_index=255) 290 | self.binarySegmentationBranch=BinarySegmentationBranch() 291 | self.instanceSegmentationBranch=InstanceSegmentationBranch() 292 | def _init_params(self): 293 | for m in self.modules(): 294 | if isinstance(m, nn.Conv2d): 295 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 296 | if m.bias is not None: 297 | nn.init.constant_(m.bias, 0) 298 | elif isinstance(m, nn.BatchNorm2d): 299 | nn.init.constant_(m.weight, 1) 300 | nn.init.constant_(m.bias, 0) 301 | elif isinstance(m, nn.BatchNorm1d): 302 | nn.init.constant_(m.weight, 1) 303 | nn.init.constant_(m.bias, 0) 304 | elif isinstance(m, nn.Linear): 305 | nn.init.normal_(m.weight, 0, 0.01) 306 | if m.bias is not None: 307 | nn.init.constant_(m.bias, 0) 308 | def forward(self,data,y=None): 309 | db = self.db(data) 310 | sb = self.sb(data) 311 | bga = self.bga(db["out"],sb["out"]) 312 | # print("--------------------bga size--------------------") 313 | # print(bga.size()) 314 | # print("--------------------bga size--------------------") 315 | bsb_res=self.binarySegmentationBranch(bga,sb["seg_stage_outputs"],db["detail_stage_outputs"]) 316 | isb_res=self.instanceSegmentationBranch(bga) 317 | return { 318 | 'instance_seg_logits': isb_res, 319 | # 'binary_seg_pred': binary_seg_ret, 320 | 'binary_seg_logits': bsb_res 321 | } 322 | 323 | if __name__ == '__main__': 324 | import os 325 | os.environ["CUDA_VISIBLE_DEVICES"] = '5' 326 | input = torch.rand(1, 3, 256, 512).cuda() 327 | model = BiSeNet().cuda() 328 | model.eval() 329 | print(model) 330 | output = model(input) 331 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 332 | print('BiSeNet_v2', output["instance_seg_logits"].size()) 333 | print('BiSeNet_v2', output["binary_seg_logits"].size()) 334 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 335 | -------------------------------------------------------------------------------- /lanenet/model/bisenetv2.py: 
-------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class ConvBNReLU(nn.Module): 8 | 9 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, 10 | dilation=1, groups=1, bias=False): 11 | super(ConvBNReLU, self).__init__() 12 | self.conv = nn.Conv2d( 13 | in_chan, out_chan, kernel_size=ks, stride=stride, 14 | padding=padding, dilation=dilation, 15 | groups=groups, bias=bias) 16 | self.bn = nn.BatchNorm2d(out_chan) 17 | self.relu = nn.ReLU(inplace=True) 18 | 19 | def forward(self, x): 20 | feat = self.conv(x) 21 | feat = self.bn(feat) 22 | feat = self.relu(feat) 23 | return feat 24 | 25 | 26 | class UpSample(nn.Module): 27 | 28 | def __init__(self, n_chan, factor=2): 29 | super(UpSample, self).__init__() 30 | out_chan = n_chan * factor * factor 31 | self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0) 32 | self.up = nn.PixelShuffle(factor) 33 | self.init_weight() 34 | 35 | def forward(self, x): 36 | feat = self.proj(x) 37 | feat = self.up(feat) 38 | return feat 39 | 40 | def init_weight(self): 41 | nn.init.xavier_normal_(self.proj.weight, gain=1.) 42 | 43 | 44 | class DetailBranch(nn.Module): 45 | 46 | def __init__(self): 47 | super(DetailBranch, self).__init__() 48 | self.S1 = nn.Sequential( 49 | ConvBNReLU(3, 64, 3, stride=2), 50 | ConvBNReLU(64, 64, 3, stride=1), 51 | ) 52 | self.S2 = nn.Sequential( 53 | ConvBNReLU(64, 64, 3, stride=2), 54 | ConvBNReLU(64, 64, 3, stride=1), 55 | ConvBNReLU(64, 64, 3, stride=1), 56 | ) 57 | self.S3 = nn.Sequential( 58 | ConvBNReLU(64, 128, 3, stride=2), 59 | ConvBNReLU(128, 128, 3, stride=1), 60 | ConvBNReLU(128, 128, 3, stride=1), 61 | ) 62 | 63 | def forward(self, x): 64 | feat = self.S1(x) 65 | feat = self.S2(feat) 66 | feat = self.S3(feat) 67 | return feat 68 | 69 | 70 | class StemBlock(nn.Module): 71 | 72 | def __init__(self): 73 | super(StemBlock, self).__init__() 74 | self.conv = ConvBNReLU(3, 16, 3, stride=2) 75 | self.left = nn.Sequential( 76 | ConvBNReLU(16, 8, 1, stride=1, padding=0), 77 | ConvBNReLU(8, 16, 3, stride=2), 78 | ) 79 | self.right = nn.MaxPool2d( 80 | kernel_size=3, stride=2, padding=1, ceil_mode=False) 81 | self.fuse = ConvBNReLU(32, 16, 3, stride=1) 82 | 83 | def forward(self, x): 84 | feat = self.conv(x) 85 | feat_left = self.left(feat) 86 | feat_right = self.right(feat) 87 | feat = torch.cat([feat_left, feat_right], dim=1) 88 | feat = self.fuse(feat) 89 | return feat 90 | 91 | 92 | class CEBlock(nn.Module): 93 | 94 | def __init__(self): 95 | super(CEBlock, self).__init__() 96 | self.bn = nn.BatchNorm2d(128) 97 | self.conv_gap = ConvBNReLU(128, 128, 1, stride=1, padding=0) 98 | #TODO: in paper here is naive conv2d, no bn-relu 99 | self.conv_last = ConvBNReLU(128, 128, 3, stride=1) 100 | 101 | def forward(self, x): 102 | feat = torch.mean(x, dim=(2, 3), keepdim=True) 103 | feat = self.bn(feat) 104 | feat = self.conv_gap(feat) 105 | feat = feat + x 106 | feat = self.conv_last(feat) 107 | return feat 108 | 109 | 110 | class GELayerS1(nn.Module): 111 | 112 | def __init__(self, in_chan, out_chan, exp_ratio=6): 113 | super(GELayerS1, self).__init__() 114 | mid_chan = in_chan * exp_ratio 115 | self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1) 116 | self.dwconv = nn.Sequential( 117 | nn.Conv2d( 118 | in_chan, mid_chan, kernel_size=3, stride=1, 119 | padding=1, groups=in_chan, bias=False), 120 | nn.BatchNorm2d(mid_chan), 121 | nn.ReLU(inplace=True), # not shown in paper 122 | ) 123 | self.conv2 = 
nn.Sequential( 124 | nn.Conv2d( 125 | mid_chan, out_chan, kernel_size=1, stride=1, 126 | padding=0, bias=False), 127 | nn.BatchNorm2d(out_chan), 128 | ) 129 | self.conv2[1].last_bn = True 130 | self.relu = nn.ReLU(inplace=True) 131 | 132 | def forward(self, x): 133 | feat = self.conv1(x) 134 | feat = self.dwconv(feat) 135 | feat = self.conv2(feat) 136 | feat = feat + x 137 | feat = self.relu(feat) 138 | return feat 139 | 140 | 141 | class GELayerS2(nn.Module): 142 | 143 | def __init__(self, in_chan, out_chan, exp_ratio=6): 144 | super(GELayerS2, self).__init__() 145 | mid_chan = in_chan * exp_ratio 146 | self.conv1 = ConvBNReLU(in_chan, in_chan, 3, stride=1) 147 | self.dwconv1 = nn.Sequential( 148 | nn.Conv2d( 149 | in_chan, mid_chan, kernel_size=3, stride=2, 150 | padding=1, groups=in_chan, bias=False), 151 | nn.BatchNorm2d(mid_chan), 152 | ) 153 | self.dwconv2 = nn.Sequential( 154 | nn.Conv2d( 155 | mid_chan, mid_chan, kernel_size=3, stride=1, 156 | padding=1, groups=mid_chan, bias=False), 157 | nn.BatchNorm2d(mid_chan), 158 | nn.ReLU(inplace=True), # not shown in paper 159 | ) 160 | self.conv2 = nn.Sequential( 161 | nn.Conv2d( 162 | mid_chan, out_chan, kernel_size=1, stride=1, 163 | padding=0, bias=False), 164 | nn.BatchNorm2d(out_chan), 165 | ) 166 | self.conv2[1].last_bn = True 167 | self.shortcut = nn.Sequential( 168 | nn.Conv2d( 169 | in_chan, in_chan, kernel_size=3, stride=2, 170 | padding=1, groups=in_chan, bias=False), 171 | nn.BatchNorm2d(in_chan), 172 | nn.Conv2d( 173 | in_chan, out_chan, kernel_size=1, stride=1, 174 | padding=0, bias=False), 175 | nn.BatchNorm2d(out_chan), 176 | ) 177 | self.relu = nn.ReLU(inplace=True) 178 | 179 | def forward(self, x): 180 | feat = self.conv1(x) 181 | feat = self.dwconv1(feat) 182 | feat = self.dwconv2(feat) 183 | feat = self.conv2(feat) 184 | shortcut = self.shortcut(x) 185 | feat = feat + shortcut 186 | feat = self.relu(feat) 187 | return feat 188 | 189 | 190 | class SegmentBranch(nn.Module): 191 | 192 | def __init__(self): 193 | super(SegmentBranch, self).__init__() 194 | self.S1S2 = StemBlock() 195 | self.S3 = nn.Sequential( 196 | GELayerS2(16, 32), 197 | GELayerS1(32, 32), 198 | ) 199 | self.S4 = nn.Sequential( 200 | GELayerS2(32, 64), 201 | GELayerS1(64, 64), 202 | ) 203 | self.S5_4 = nn.Sequential( 204 | GELayerS2(64, 128), 205 | GELayerS1(128, 128), 206 | GELayerS1(128, 128), 207 | GELayerS1(128, 128), 208 | ) 209 | self.S5_5 = CEBlock() 210 | 211 | def forward(self, x): 212 | feat2 = self.S1S2(x) 213 | feat3 = self.S3(feat2) 214 | feat4 = self.S4(feat3) 215 | feat5_4 = self.S5_4(feat4) 216 | feat5_5 = self.S5_5(feat5_4) 217 | return feat2, feat3, feat4, feat5_4, feat5_5 218 | 219 | 220 | class BGALayer(nn.Module): 221 | 222 | def __init__(self): 223 | super(BGALayer, self).__init__() 224 | self.left1 = nn.Sequential( 225 | nn.Conv2d( 226 | 128, 128, kernel_size=3, stride=1, 227 | padding=1, groups=128, bias=False), 228 | nn.BatchNorm2d(128), 229 | nn.Conv2d( 230 | 128, 128, kernel_size=1, stride=1, 231 | padding=0, bias=False), 232 | ) 233 | self.left2 = nn.Sequential( 234 | nn.Conv2d( 235 | 128, 128, kernel_size=3, stride=2, 236 | padding=1, bias=False), 237 | nn.BatchNorm2d(128), 238 | nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False) 239 | ) 240 | self.right1 = nn.Sequential( 241 | nn.Conv2d( 242 | 128, 128, kernel_size=3, stride=1, 243 | padding=1, bias=False), 244 | nn.BatchNorm2d(128), 245 | ) 246 | self.right2 = nn.Sequential( 247 | nn.Conv2d( 248 | 128, 128, kernel_size=3, stride=1, 249 | padding=1, 
groups=128, bias=False), 250 | nn.BatchNorm2d(128), 251 | nn.Conv2d( 252 | 128, 128, kernel_size=1, stride=1, 253 | padding=0, bias=False), 254 | ) 255 | self.up1 = nn.Upsample(scale_factor=4) 256 | self.up2 = nn.Upsample(scale_factor=4) 257 | ##TODO: does this really has no relu? 258 | self.conv = nn.Sequential( 259 | nn.Conv2d( 260 | 128, 128, kernel_size=3, stride=1, 261 | padding=1, bias=False), 262 | nn.BatchNorm2d(128), 263 | nn.ReLU(inplace=True), # not shown in paper 264 | ) 265 | 266 | def forward(self, x_d, x_s): 267 | dsize = x_d.size()[2:] 268 | left1 = self.left1(x_d) 269 | left2 = self.left2(x_d) 270 | right1 = self.right1(x_s) 271 | right2 = self.right2(x_s) 272 | right1 = self.up1(right1) 273 | left = left1 * torch.sigmoid(right1) 274 | right = left2 * torch.sigmoid(right2) 275 | right = self.up2(right) 276 | out = self.conv(left + right) 277 | return out 278 | 279 | 280 | 281 | class SegmentHead(nn.Module): 282 | 283 | def __init__(self, in_chan, mid_chan, n_classes, up_factor=8, aux=True): 284 | super(SegmentHead, self).__init__() 285 | self.conv = ConvBNReLU(in_chan, mid_chan, 3, stride=1) 286 | self.drop = nn.Dropout(0.1) 287 | self.up_factor = up_factor 288 | 289 | out_chan = n_classes * up_factor * up_factor 290 | if aux: 291 | self.conv_out = nn.Sequential( 292 | ConvBNReLU(mid_chan, up_factor * up_factor, 3, stride=1), 293 | nn.Conv2d(up_factor * up_factor, out_chan, 1, 1, 0), 294 | nn.PixelShuffle(up_factor) 295 | ) 296 | else: 297 | self.conv_out = nn.Sequential( 298 | nn.Conv2d(mid_chan, out_chan, 1, 1, 0), 299 | nn.PixelShuffle(up_factor) 300 | ) 301 | 302 | def forward(self, x): 303 | feat = self.conv(x) 304 | feat = self.drop(feat) 305 | feat = self.conv_out(feat) 306 | return feat 307 | 308 | 309 | class BiSeNetV2(nn.Module): 310 | 311 | def __init__(self, n_classes, output_aux=True): 312 | super(BiSeNetV2, self).__init__() 313 | self.output_aux = output_aux 314 | self.detail = DetailBranch() 315 | self.segment = SegmentBranch() 316 | self.bga = BGALayer() 317 | 318 | ## TODO: what is the number of mid chan ? 
319 | self.head = SegmentHead(128, 1024, n_classes, up_factor=8, aux=False) 320 | if self.output_aux: 321 | self.aux2 = SegmentHead(16, 128, n_classes, up_factor=4) 322 | self.aux3 = SegmentHead(32, 128, n_classes, up_factor=8) 323 | self.aux4 = SegmentHead(64, 128, n_classes, up_factor=16) 324 | self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32) 325 | 326 | self.init_weights() 327 | 328 | def forward(self, x): 329 | size = x.size()[2:] 330 | feat_d = self.detail(x) 331 | feat2, feat3, feat4, feat5_4, feat_s = self.segment(x) 332 | feat_head = self.bga(feat_d, feat_s) 333 | 334 | logits = self.head(feat_head) 335 | if self.output_aux: 336 | logits_aux2 = self.aux2(feat2) 337 | logits_aux3 = self.aux3(feat3) 338 | logits_aux4 = self.aux4(feat4) 339 | logits_aux5_4 = self.aux5_4(feat5_4) 340 | return logits, logits_aux2, logits_aux3, logits_aux4, logits_aux5_4 341 | pred = logits.argmax(dim=1) 342 | return pred 343 | 344 | def init_weights(self): 345 | for name, module in self.named_modules(): 346 | if isinstance(module, (nn.Conv2d, nn.Linear)): 347 | nn.init.kaiming_normal_(module.weight, mode='fan_out') 348 | if not module.bias is None: nn.init.constant_(module.bias, 0) 349 | elif isinstance(module, nn.modules.batchnorm._BatchNorm): 350 | if hasattr(module, 'last_bn') and module.last_bn: 351 | nn.init.zeros_(module.weight) 352 | else: 353 | nn.init.ones_(module.weight) 354 | nn.init.zeros_(module.bias) 355 | 356 | 357 | if __name__ == "__main__": 358 | # x = torch.randn(16, 3, 1024, 2048) 359 | # detail = DetailBranch() 360 | # feat = detail(x) 361 | # print('detail', feat.size()) 362 | # 363 | # x = torch.randn(16, 3, 1024, 2048) 364 | # stem = StemBlock() 365 | # feat = stem(x) 366 | # print('stem', feat.size()) 367 | # 368 | # x = torch.randn(16, 128, 16, 32) 369 | # ceb = CEBlock() 370 | # feat = ceb(x) 371 | # print(feat.size()) 372 | # 373 | # x = torch.randn(16, 32, 16, 32) 374 | # ge1 = GELayerS1(32, 32) 375 | # feat = ge1(x) 376 | # print(feat.size()) 377 | # 378 | # x = torch.randn(16, 16, 16, 32) 379 | # ge2 = GELayerS2(16, 32) 380 | # feat = ge2(x) 381 | # print(feat.size()) 382 | # 383 | # left = torch.randn(16, 128, 64, 128) 384 | # right = torch.randn(16, 128, 16, 32) 385 | # bga = BGALayer() 386 | # feat = bga(left, right) 387 | # print(feat.size()) 388 | # 389 | # x = torch.randn(16, 128, 64, 128) 390 | # head = SegmentHead(128, 128, 19) 391 | # logits = head(x) 392 | # print(logits.size()) 393 | # 394 | # x = torch.randn(16, 3, 1024, 2048) 395 | # segment = SegmentBranch() 396 | # feat = segment(x)[0] 397 | # print(feat.size()) 398 | # 399 | x = torch.randn(16, 3, 1024, 2048) 400 | model = BiSeNetV2(n_classes=19) 401 | outs = model(x) 402 | for out in outs: 403 | print(out.size()) 404 | # print(logits.size()) 405 | 406 | # for name, param in model.named_parameters(): 407 | # if len(param.size()) == 1: 408 | # print(name) -------------------------------------------------------------------------------- /lanenet/model/BiseNet_v2_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.utils import model_zoo 6 | from torchvision import models 7 | from lanenet import config 8 | import collections 9 | 10 | class conv2d(nn.Module): 11 | def __init__(self,in_dim,out_dim,k,pad,stride,groups = 1,bias=False,use_bn = True,use_rl = True): 12 | super(conv2d,self).__init__() 13 | self.use_bn = use_bn 14 | self.use_rl = 
use_rl 15 | self.conv = nn.Conv2d(in_dim,out_dim,k,padding=pad,stride=stride, groups=groups,bias=bias) 16 | self.bn = nn.BatchNorm2d(out_dim) 17 | self.relu = nn.ReLU(inplace=True) 18 | def forward(self,bottom): 19 | if self.use_bn and self.use_rl: 20 | return self.relu(self.bn(self.conv(bottom))) 21 | elif self.use_bn: 22 | return self.bn(self.conv(bottom)) 23 | else: 24 | return self.conv(bottom) 25 | 26 | class SegHead(nn.Module): 27 | def __init__(self,in_dim,out_dim,cls,size=[720,1280]): 28 | super(SegHead,self).__init__() 29 | self.size = size 30 | self.conv = conv2d(in_dim,out_dim,3,1,1) 31 | self.cls = conv2d(out_dim,cls,1,0,1,use_bn=False,use_rl=False) 32 | def forward(self,feat): 33 | x = self.conv(feat) 34 | x = self.cls(x) 35 | pred = F.interpolate(x, size=self.size, mode="bilinear",align_corners=True) 36 | return pred 37 | 38 | class StemBlock(nn.Module): 39 | def __init__(self): 40 | super(StemBlock,self).__init__() 41 | self.conv1 = conv2d(3,16,3,1,2) 42 | self.conv_1x1 = conv2d(16,32,1,0,1) 43 | self.conv_3x3 = conv2d(32,32,3,1,2) 44 | self.mpooling = nn.MaxPool2d(3,2,1) 45 | self.conv2 = conv2d(48,32,3,1,1) 46 | def forward(self,bottom): 47 | base = self.conv1(bottom) 48 | conv_1 = self.conv_1x1(base) 49 | conv_3 = self.conv_3x3(conv_1) 50 | pool = self.mpooling(base) 51 | cat = torch.cat([conv_3,pool],1) 52 | res = self.conv2(cat) 53 | return res 54 | 55 | class ContextEmbeddingBlock(nn.Module): 56 | def __init__(self,in_dim): 57 | super(ContextEmbeddingBlock,self).__init__() 58 | self.gap = nn.AdaptiveAvgPool2d(1)#1 59 | # self.gap = nn.AvgPool2d(3,1,1) 60 | self.bn1 = nn.BatchNorm2d(in_dim) 61 | self.conv1 = conv2d(in_dim,in_dim,1,0,1) 62 | self.conv2 = conv2d(in_dim,in_dim,3,1,1,use_bn = False,use_rl = False) 63 | def forward(self,bottom): 64 | gap = self.gap(bottom) 65 | # print(gap) 66 | bn = self.bn1(gap) 67 | conv1 = self.conv1(bn) 68 | feat = bottom+conv1 69 | res = self.conv2(feat) 70 | return res 71 | 72 | class GatherExpansion(nn.Module): 73 | def __init__(self,in_dim,out_dim,stride = 1,exp = 6): 74 | super(GatherExpansion,self).__init__() 75 | exp_dim = in_dim*exp 76 | self.stride = stride 77 | self.conv1 = conv2d(in_dim,exp_dim,3,1,1) 78 | self.dwconv2 = conv2d(exp_dim,exp_dim,3,1,1,exp_dim,use_rl = False) 79 | self.conv_11 = conv2d(exp_dim,out_dim,1,0,1,use_rl = False) 80 | 81 | self.dwconv1 = conv2d(exp_dim,exp_dim,3,1,2,exp_dim,use_rl = False) 82 | self.dwconv3 = conv2d(in_dim,in_dim,3,1,2,in_dim,use_rl = False) 83 | self.conv_12 = conv2d(in_dim,out_dim,1,0,1,use_rl = False) 84 | self.relu = nn.ReLU(inplace=True) 85 | def forward(self,bottom): 86 | base = self.conv1(bottom) 87 | if self.stride == 2: 88 | base = self.dwconv1(base) 89 | bottom = self.dwconv3(bottom) 90 | bottom = self.conv_12(bottom) 91 | x = self.dwconv2(base) 92 | x = self.conv_11(x) 93 | res = self.relu(x+bottom) 94 | return res 95 | 96 | class BGA(nn.Module): 97 | def __init__(self,in_dim): 98 | super(BGA,self).__init__() 99 | self.in_dim = in_dim 100 | self.db_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 101 | self.db_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 102 | self.db_conv = conv2d(in_dim,in_dim,3,1,2,use_rl=False) 103 | self.db_apooling = nn.AvgPool2d(3,2,1) 104 | 105 | self.sb_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 106 | self.sb_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 107 | self.sb_conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 108 | self.sb_sigmoid = nn.Sigmoid() 109 | 110 | self.conv = 
conv2d(in_dim,in_dim,3,1,1,use_rl=False) 111 | def forward(self,db,sb): 112 | db_dwc = self.db_dwconv(db) 113 | db_out = self.db_conv1x1(db_dwc)# 114 | db_conv = self.db_conv(db) 115 | db_pool = self.db_apooling(db_conv) 116 | 117 | sb_dwc = self.sb_dwconv(sb) 118 | sb_out = self.sb_sigmoid(self.sb_conv1x1(sb_dwc))# 119 | sb_conv = self.sb_conv(sb) 120 | sb_up = self.sb_sigmoid(F.interpolate(sb_conv, size=db_out.size()[2:], mode="bilinear",align_corners=True)) 121 | db_l = db_out*sb_up 122 | sb_r = F.interpolate(sb_out*db_pool, size=db_out.size()[2:], mode="bilinear",align_corners=True) 123 | res = self.conv(db_l+sb_r) 124 | return res 125 | 126 | class DetailedBranch(nn.Module): 127 | def __init__(self): 128 | super(DetailedBranch,self).__init__() 129 | self.s1_conv1 = conv2d(3,64,3,1,2) 130 | self.s1_conv2 = conv2d(64,64,3,1,1) 131 | 132 | self.s2_conv1 = conv2d(64,64,3,1,2) 133 | self.s2_conv2 = conv2d(64,64,3,1,1) 134 | self.s2_conv3 = conv2d(64,64,3,1,1) 135 | 136 | self.s3_conv1 = conv2d(64,128,3,1,2) 137 | self.s3_conv2 = conv2d(128,128,3,1,1) 138 | self.s3_conv3 = conv2d(128,128,3,1,1) 139 | def forward(self,bottom): 140 | 141 | detail_stage_outputs = collections.OrderedDict() 142 | 143 | s1_1 = self.s1_conv1(bottom) 144 | s1_2 = self.s1_conv2(s1_1) 145 | 146 | detail_stage_outputs["stg1"] = s1_2 147 | 148 | s2_1 = self.s2_conv1(s1_2) 149 | s2_2 = self.s2_conv2(s2_1) 150 | s2_3 = self.s2_conv3(s2_2) 151 | 152 | detail_stage_outputs["stg2"] = s2_3 153 | 154 | s3_1 = self.s3_conv1(s2_3) 155 | s3_2 = self.s3_conv2(s3_1) 156 | s3_3 = self.s3_conv3(s3_2) 157 | 158 | detail_stage_outputs["stg3"] = s3_3 159 | 160 | return { 161 | 'out': s3_3, 162 | 'detail_stage_outputs': detail_stage_outputs 163 | } 164 | 165 | class SemanticBranch(nn.Module): 166 | def __init__(self): 167 | super(SemanticBranch,self).__init__() 168 | self.stem = StemBlock() 169 | self.s3_ge1 = GatherExpansion(32,32,2) 170 | self.s3_ge2 = GatherExpansion(32,32) 171 | 172 | self.s4_ge1 = GatherExpansion(32,64,2) 173 | self.s4_ge2 = GatherExpansion(64,64) 174 | 175 | self.s5_ge1 = GatherExpansion(64,128,2) 176 | self.s5_ge2 = GatherExpansion(128,128) 177 | self.s5_ge3 = GatherExpansion(128,128) 178 | self.s5_ge4 = GatherExpansion(128,128) 179 | self.s5_ge5 = GatherExpansion(128,128,exp=1) 180 | 181 | self.ceb = ContextEmbeddingBlock(128) 182 | 183 | if config.is_training==1: 184 | self.seghead1 = SegHead(32,32,config.num_classes) 185 | self.seghead3 = SegHead(32,32,config.num_classes) 186 | self.seghead4 = SegHead(64,64,config.num_classes) 187 | self.seghead5 = SegHead(128,128,config.num_classes) 188 | 189 | def forward(self,bottom): 190 | seg_stage_outputs = collections.OrderedDict() 191 | 192 | stg1 = self.stem(bottom) 193 | #print(stg12.size()) 194 | seg_stage_outputs["stg1"] = stg1 195 | 196 | stg3 = self.s3_ge1(stg1) 197 | stg3 = self.s3_ge2(stg3) 198 | #print(stg3.size()) 199 | seg_stage_outputs["stg3"] = stg3 200 | 201 | stg4 = self.s4_ge1(stg3) 202 | stg4 = self.s4_ge2(stg4) 203 | 204 | seg_stage_outputs["stg4"] = stg4 205 | #print(stg4.size()) 206 | stg5 = self.s5_ge1(stg4) 207 | stg5 = self.s5_ge2(stg5) 208 | stg5 = self.s5_ge3(stg5) 209 | stg5 = self.s5_ge4(stg5) 210 | stg5 = self.s5_ge5(stg5) 211 | 212 | seg_stage_outputs["stg5"] = stg5 213 | #print(stg5.size()) 214 | out = self.ceb(stg5) 215 | 216 | if self.training: 217 | seghead1 = self.seghead1(stg1) 218 | seghead2 = self.seghead3(stg3) 219 | seghead3 = self.seghead4(stg4) 220 | seghead4 = self.seghead5(stg5) 221 | 222 | return { 223 | 'out': out, 224 | 
'seg_stage_outputs': seg_stage_outputs,
225 | 'seghead1':seghead1,
226 | 'seghead2':seghead2,
227 | 'seghead3':seghead3,
228 | 'seghead4':seghead4
229 | }
230 | else:
231 | return {
232 | 'out': out,
233 | 'seg_stage_outputs': seg_stage_outputs
234 | }
235 |
236 | class InstanceSegmentationBranch(nn.Module):
237 | def __init__(self):
238 | super(InstanceSegmentationBranch,self).__init__()
239 | self.bsconv1 = conv2d(128,256,3,1,1,use_rl=True)
240 | self.bsconv2 = conv2d(256,128,1,0,1,use_rl=True)
241 | def forward(self,data):
242 | input_tensor_size=list(data.size())
243 | tmp_size=input_tensor_size[2:]
244 | out_put_tensor_size=tuple([int(tmp * 8) for tmp in tmp_size])
245 | conv1_out=self.bsconv1(data)
246 | conv2_out=self.bsconv2(conv1_out)
247 | isb_out = F.interpolate(conv2_out, size=out_put_tensor_size, mode="bilinear",align_corners=True)
248 | return isb_out
249 |
250 | class BinarySegmentationBranch(nn.Module):
251 |
252 | def __init__(self):
253 | super(BinarySegmentationBranch,self).__init__()
254 |
255 | self.bsconv1_pre = conv2d(128,128,3,1,1,use_rl=True) #
256 | # fuse the 1/2-scale features
257 | self.bsconv1_pre2 = conv2d((128+(128-32)),32,3,1,1,use_rl=True) #
258 | # fuse the 1/4-scale features
259 | self.bsconv1_pre4 = conv2d(128,16,3,1,1,use_rl=True)#
260 | # fuse the 1/8-scale features
261 | self.bsconv3 = conv2d(16,config.num_classes,1,0,1,use_rl=False,use_bn=True)
262 |
263 |
264 | # self.bsconv1 = conv2d(128,256,3,1,1,use_rl=True)
265 | # self.bsconv2 = conv2d(256,128,1,0,1,use_rl=True)
266 | # self.bsconv3 = conv2d(128,config.num_classes,1,0,1,use_rl=False,use_bn=True)
267 |
268 | def forward(self,data,seg_stage_outputs,detail_stage_outputs):
269 |
270 | input_tensor_size=list(data.size())
271 | tmp_size=input_tensor_size[2:]
272 | output_stage2_size=[int(tmp * 2) for tmp in tmp_size]
273 | output_stage1_size=[int(tmp * 4) for tmp in tmp_size]
274 | out_put_tensor_size=[int(tmp * 8) for tmp in tmp_size]
275 |
276 | out1=self.bsconv1_pre(data)
277 | # channel=256
278 | # out2=self.bsconv1_pre1(out1)
279 |
280 | output_stage2_tensor = F.interpolate(out1, size=tuple(output_stage2_size), mode="bilinear",align_corners=True)
281 | # output_stage2_tensor = tf.concat([output_stage2_tensor, detail_stage_outputs['stage_2'], semantic_stage_outputs['stage_1']], axis=-1, name='stage_2_concate_features')
282 | output_stage2_tensor=torch.cat([output_stage2_tensor,detail_stage_outputs['stg2'], seg_stage_outputs['stg1']],1)
283 |
284 | output_stage2_tensor=self.bsconv1_pre2(output_stage2_tensor)
285 | # output_stage2_tensor=self.bsconv1_pre3(output_stage2_tensor)
286 | # channel=256
287 | output_stage1_tensor = F.interpolate(output_stage2_tensor, size=tuple(output_stage1_size), mode="bilinear",align_corners=True)
288 | #output_stage1_tensor = tf.concat([output_stage1_tensor, detail_stage_outputs['stage_1']], axis=-1, name='stage_1_concate_features')
289 | output_stage1_tensor=torch.cat([output_stage1_tensor,detail_stage_outputs['stg1']],1)
290 |
291 | # print(output_stage1_tensor.size())  # leftover debug print, disabled
292 |
293 | output_stage1_tensor=self.bsconv1_pre4(output_stage1_tensor)
294 | # output_stage1_tensor=self.bsconv1_pre5(output_stage1_tensor)
295 | output_stage1_tensor=self.bsconv3(output_stage1_tensor)
296 |
297 | # conv_out1=self.bsconv1(data)
298 | # conv_out2=self.bsconv2(conv_out1)
299 | # conv_out3=self.bsconv3(conv_out2)
300 |
301 | # print("--------------------conv_out3 size--------------------")
302 | # print(list(conv_out3.size()))
303 | # print("--------------------bga size--------------------")
304 |
305 | # 
print("--------------------out_put_tensor_size--------------------") 306 | # print(out_put_tensor_size) 307 | # print(tuple(out_put_tensor_size)) 308 | # print("--------------------out_put_tensor_size--------------------") 309 | bsb_out = F.interpolate(output_stage1_tensor, size=tuple(out_put_tensor_size), mode="bilinear",align_corners=True) 310 | return bsb_out 311 | 312 | class BiSeNet(nn.Module): 313 | def __init__(self): 314 | super(BiSeNet, self).__init__() 315 | self.db = DetailedBranch() 316 | self.sb = SemanticBranch() 317 | self.bga = BGA(128) 318 | if config.is_training==1: 319 | self.seghead = SegHead(128,128,config.num_classes) 320 | self.criterion = nn.CrossEntropyLoss(ignore_index=255) 321 | self._init_params() 322 | self.binarySegmentationBranch=BinarySegmentationBranch() 323 | self.instanceSegmentationBranch=InstanceSegmentationBranch() 324 | def _init_params(self): 325 | for m in self.modules(): 326 | if isinstance(m, nn.Conv2d): 327 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 328 | if m.bias is not None: 329 | nn.init.constant_(m.bias, 0) 330 | elif isinstance(m, nn.BatchNorm2d): 331 | nn.init.constant_(m.weight, 1) 332 | nn.init.constant_(m.bias, 0) 333 | elif isinstance(m, nn.BatchNorm1d): 334 | nn.init.constant_(m.weight, 1) 335 | nn.init.constant_(m.bias, 0) 336 | elif isinstance(m, nn.Linear): 337 | nn.init.normal_(m.weight, 0, 0.01) 338 | if m.bias is not None: 339 | nn.init.constant_(m.bias, 0) 340 | def forward(self,data,y=None): 341 | db = self.db(data) 342 | sb = self.sb(data) 343 | bga = self.bga(db["out"],sb["out"]) 344 | bsb_res=self.binarySegmentationBranch(bga,sb["seg_stage_outputs"],db["detail_stage_outputs"]) 345 | isb_res=self.instanceSegmentationBranch(bga) 346 | if config.is_training==1: 347 | aux1_loss = self.criterion(sb["seghead1"], y) 348 | aux2_loss = self.criterion(sb["seghead2"], y) 349 | aux3_loss = self.criterion(sb["seghead3"], y) 350 | aux4_loss = self.criterion(sb["seghead4"], y) 351 | return { 352 | 'instance_seg_logits': isb_res, 353 | # 'binary_seg_pred': binary_seg_ret, 354 | 'binary_seg_logits': bsb_res, 355 | 'aux1_loss':aux1_loss, 356 | 'aux2_loss':aux2_loss, 357 | 'aux3_loss':aux3_loss, 358 | 'aux4_loss':aux4_loss 359 | } 360 | return { 361 | 'instance_seg_logits': isb_res, 362 | # 'binary_seg_pred': binary_seg_ret, 363 | 'binary_seg_logits': bsb_res 364 | } 365 | 366 | if __name__ == '__main__': 367 | import os 368 | os.environ["CUDA_VISIBLE_DEVICES"] = '5' 369 | input = torch.rand(1, 3, 256, 512) 370 | model = BiSeNet() 371 | model.eval() 372 | print(model) 373 | output = model(input) 374 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 375 | print('BiSeNet_v2', output["instance_seg_logits"].size()) 376 | print('BiSeNet_v2', output["binary_seg_logits"].size()) 377 | print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 378 | -------------------------------------------------------------------------------- /lanenet/model/BiseNet_v2_2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.utils import model_zoo 6 | from torchvision import models 7 | from lanenet import config 8 | import collections 9 | 10 | class conv2d(nn.Module): 11 | def __init__(self,in_dim,out_dim,k,pad,stride,groups = 1,bias=False,use_bn = True,use_rl = True): 12 | super(conv2d,self).__init__() 13 | self.use_bn = use_bn 14 | self.use_rl = 
use_rl 15 | self.conv = nn.Conv2d(in_dim,out_dim,k,padding=pad,stride=stride, groups=groups,bias=bias) 16 | self.bn = nn.BatchNorm2d(out_dim) 17 | self.relu = nn.ReLU(inplace=True) 18 | def forward(self,bottom): 19 | if self.use_bn and self.use_rl: 20 | return self.relu(self.bn(self.conv(bottom))) 21 | elif self.use_bn: 22 | return self.bn(self.conv(bottom)) 23 | else: 24 | return self.conv(bottom) 25 | 26 | class SegHead(nn.Module): 27 | def __init__(self,in_dim,out_dim,cls): 28 | super(SegHead,self).__init__() 29 | self.conv = conv2d(in_dim,out_dim,3,1,1) 30 | self.cls = conv2d(out_dim,cls,1,0,1,use_bn=False,use_rl=False) 31 | def forward(self,feat,size=[720,1280]): 32 | x = self.conv(feat) 33 | x = self.cls(x) 34 | pred = F.interpolate(x, size=size, mode="bilinear",align_corners=True) 35 | return pred 36 | 37 | class StemBlock(nn.Module): 38 | def __init__(self): 39 | super(StemBlock,self).__init__() 40 | self.conv1 = conv2d(3,16,3,1,2) 41 | self.conv_1x1 = conv2d(16,32,1,0,1) 42 | self.conv_3x3 = conv2d(32,32,3,1,2) 43 | self.mpooling = nn.MaxPool2d(3,2,1) 44 | self.conv2 = conv2d(48,32,3,1,1) 45 | def forward(self,bottom): 46 | base = self.conv1(bottom) 47 | conv_1 = self.conv_1x1(base) 48 | conv_3 = self.conv_3x3(conv_1) 49 | pool = self.mpooling(base) 50 | cat = torch.cat([conv_3,pool],1) 51 | res = self.conv2(cat) 52 | return res 53 | 54 | class ContextEmbeddingBlock(nn.Module): 55 | def __init__(self,in_dim): 56 | super(ContextEmbeddingBlock,self).__init__() 57 | self.gap = nn.AdaptiveAvgPool2d(1)#1 58 | # self.gap = nn.AvgPool2d(3,1,1) 59 | self.bn1 = nn.BatchNorm2d(in_dim) 60 | self.conv1 = conv2d(in_dim,in_dim,1,0,1) 61 | self.conv2 = conv2d(in_dim,in_dim,3,1,1,use_bn = False,use_rl = False) 62 | def forward(self,bottom): 63 | gap = self.gap(bottom) 64 | # print(gap) 65 | bn = self.bn1(gap) 66 | conv1 = self.conv1(bn) 67 | feat = bottom+conv1 68 | res = self.conv2(feat) 69 | return res 70 | 71 | class GatherExpansion(nn.Module): 72 | def __init__(self,in_dim,out_dim,stride = 1,exp = 6): 73 | super(GatherExpansion,self).__init__() 74 | exp_dim = in_dim*exp 75 | self.stride = stride 76 | self.conv1 = conv2d(in_dim,exp_dim,3,1,1) 77 | self.dwconv2 = conv2d(exp_dim,exp_dim,3,1,1,exp_dim,use_rl = False) 78 | self.conv_11 = conv2d(exp_dim,out_dim,1,0,1,use_rl = False) 79 | 80 | self.dwconv1 = conv2d(exp_dim,exp_dim,3,1,2,exp_dim,use_rl = False) 81 | self.dwconv3 = conv2d(in_dim,in_dim,3,1,2,in_dim,use_rl = False) 82 | self.conv_12 = conv2d(in_dim,out_dim,1,0,1,use_rl = False) 83 | self.relu = nn.ReLU(inplace=True) 84 | def forward(self,bottom): 85 | base = self.conv1(bottom) 86 | if self.stride == 2: 87 | base = self.dwconv1(base) 88 | bottom = self.dwconv3(bottom) 89 | bottom = self.conv_12(bottom) 90 | x = self.dwconv2(base) 91 | x = self.conv_11(x) 92 | res = self.relu(x+bottom) 93 | return res 94 | 95 | class BGA(nn.Module): 96 | def __init__(self,in_dim): 97 | super(BGA,self).__init__() 98 | self.in_dim = in_dim 99 | self.db_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 100 | self.db_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 101 | self.db_conv = conv2d(in_dim,in_dim,3,1,2,use_rl=False) 102 | self.db_apooling = nn.AvgPool2d(3,2,1) 103 | 104 | self.sb_dwconv = conv2d(in_dim,in_dim,3,1,1,in_dim,use_rl=False) 105 | self.sb_conv1x1 = conv2d(in_dim,in_dim,1,0,1,use_rl=False,use_bn=False) 106 | self.sb_conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 107 | self.sb_sigmoid = nn.Sigmoid() 108 | 109 | self.conv = conv2d(in_dim,in_dim,3,1,1,use_rl=False) 110 
| def forward(self,db,sb): 111 | db_dwc = self.db_dwconv(db) 112 | db_out = self.db_conv1x1(db_dwc)# 113 | db_conv = self.db_conv(db) 114 | db_pool = self.db_apooling(db_conv) 115 | 116 | sb_dwc = self.sb_dwconv(sb) 117 | sb_out = self.sb_sigmoid(self.sb_conv1x1(sb_dwc))# 118 | sb_conv = self.sb_conv(sb) 119 | sb_up = self.sb_sigmoid(F.interpolate(sb_conv, size=db_out.size()[2:], mode="bilinear",align_corners=True)) 120 | db_l = db_out*sb_up 121 | sb_r = F.interpolate(sb_out*db_pool, size=db_out.size()[2:], mode="bilinear",align_corners=True) 122 | res = self.conv(db_l+sb_r) 123 | return res 124 | 125 | class DetailedBranch(nn.Module): 126 | def __init__(self): 127 | super(DetailedBranch,self).__init__() 128 | self.s1_conv1 = conv2d(3,64,3,1,2) 129 | 130 | self.s1_conv2 = conv2d(64,64,3,1,1) 131 | 132 | self.s2_conv1 = conv2d(64,64,3,1,2) 133 | self.s2_conv2 = conv2d(64,64,3,1,1) 134 | self.s2_conv3 = conv2d(64,64,3,1,1) 135 | 136 | self.s3_conv1 = conv2d(64,128,3,1,2) 137 | self.s3_conv2 = conv2d(128,128,3,1,1) 138 | self.s3_conv3 = conv2d(128,128,3,1,1) 139 | def forward(self,bottom): 140 | 141 | detail_stage_outputs = collections.OrderedDict() 142 | 143 | s1_1 = self.s1_conv1(bottom) 144 | # print(s1_1.size()) 145 | s1_2 = self.s1_conv2(s1_1) 146 | 147 | detail_stage_outputs["stg1"] = s1_2 148 | 149 | s2_1 = self.s2_conv1(s1_2) 150 | s2_2 = self.s2_conv2(s2_1) 151 | s2_3 = self.s2_conv3(s2_2) 152 | 153 | detail_stage_outputs["stg2"] = s2_3 154 | 155 | s3_1 = self.s3_conv1(s2_3) 156 | s3_2 = self.s3_conv2(s3_1) 157 | s3_3 = self.s3_conv3(s3_2) 158 | 159 | detail_stage_outputs["stg3"] = s3_3 160 | 161 | return { 162 | 'out': s3_3, 163 | 'detail_stage_outputs': detail_stage_outputs 164 | } 165 | 166 | class SemanticBranch(nn.Module): 167 | def __init__(self): 168 | super(SemanticBranch,self).__init__() 169 | self.stem = StemBlock() 170 | self.s3_ge1 = GatherExpansion(32,32,2) 171 | self.s3_ge2 = GatherExpansion(32,32) 172 | 173 | self.s4_ge1 = GatherExpansion(32,64,2) 174 | self.s4_ge2 = GatherExpansion(64,64) 175 | 176 | self.s5_ge1 = GatherExpansion(64,128,2) 177 | self.s5_ge2 = GatherExpansion(128,128) 178 | self.s5_ge3 = GatherExpansion(128,128) 179 | self.s5_ge4 = GatherExpansion(128,128) 180 | self.s5_ge5 = GatherExpansion(128,128,exp=1) 181 | 182 | self.ceb = ContextEmbeddingBlock(128) 183 | 184 | def forward(self,bottom): 185 | seg_stage_outputs = collections.OrderedDict() 186 | print(bottom.size()) 187 | stg1 = self.stem(bottom) 188 | seg_stage_outputs["stg1"] = stg1 189 | print(stg1.size()) 190 | 191 | stg3 = self.s3_ge1(stg1) 192 | stg3 = self.s3_ge2(stg3) 193 | print(stg3.size()) 194 | seg_stage_outputs["stg3"] = stg3 195 | 196 | stg4 = self.s4_ge1(stg3) 197 | stg4 = self.s4_ge2(stg4) 198 | 199 | seg_stage_outputs["stg4"] = stg4 200 | print(stg4.size()) 201 | stg5 = self.s5_ge1(stg4) 202 | stg5 = self.s5_ge2(stg5) 203 | stg5 = self.s5_ge3(stg5) 204 | stg5 = self.s5_ge4(stg5) 205 | stg5 = self.s5_ge5(stg5) 206 | 207 | seg_stage_outputs["stg5"] = stg5 208 | print(stg5.size()) 209 | out = self.ceb(stg5) 210 | 211 | return { 212 | 'out': out, 213 | 'seg_stage_outputs': seg_stage_outputs 214 | } 215 | 216 | class InstanceSegmentationBranch(nn.Module): 217 | def __init__(self): 218 | super(InstanceSegmentationBranch,self).__init__() 219 | self.bsconv1 = conv2d(128,256,3,1,1,use_rl=True) 220 | self.bsconv2 = conv2d(256,128,1,0,1,use_rl=True) 221 | def forward(self,data): 222 | input_tensor_size=list(data.size()) 223 | tmp_size=input_tensor_size[2:] 224 | 
out_put_tensor_size=tuple([int(tmp * 8) for tmp in tmp_size])
225 | conv1_out=self.bsconv1(data)
226 | conv2_out=self.bsconv2(conv1_out)
227 | isb_out = F.interpolate(conv2_out, size=out_put_tensor_size, mode="bilinear",align_corners=True)
228 | return isb_out
229 |
230 | class BinarySegmentationBranch(nn.Module):
231 |
232 | def __init__(self):
233 | super(BinarySegmentationBranch,self).__init__()
234 |
235 | self.bsconv1_pre = conv2d(128,32,3,1,1,use_rl=True) #
236 |
237 | if config.is_training==1:
238 | self.bsconv1_pre_help = conv2d(32,config.num_classes,1,0,1,use_rl=False,use_bn=True)# reduce channels to num_classes
239 | # fuse the 1/4-scale features
240 | self.bsconv1_pre2 = conv2d(128,192,3,1,1,use_rl=True) #
241 |
242 | if config.is_training==1:
243 | self.bsconv1_pre2_help = conv2d(192,config.num_classes,1,0,1,use_rl=False,use_bn=True)# reduce channels to num_classes
244 |
245 | # fuse the 1/4-scale features
246 | self.bsconv1_pre4 = conv2d((128+192-64),128,3,1,1,use_rl=True)#
247 | # fuse the 1/8-scale features
248 | self.bsconv3 = conv2d(128,config.num_classes,1,0,1,use_rl=False,use_bn=True)
249 |
250 | if config.is_training==1:
251 | self.seghead1 = SegHead(32,32,config.num_classes)
252 | self.seghead3 = SegHead(32,32,config.num_classes)
253 | self.seghead4 = SegHead(64,64,config.num_classes)
254 | self.seghead5 = SegHead(128,128,config.num_classes)
255 |
256 | # self.bsconv1 = conv2d(128,256,3,1,1,use_rl=True)
257 | # self.bsconv2 = conv2d(256,128,1,0,1,use_rl=True)
258 | # self.bsconv3 = conv2d(128,config.num_classes,1,0,1,use_rl=False,use_bn=True)
259 |
260 | def forward(self,data,seg_stage_outputs,detail_stage_outputs):
261 |
262 | input_tensor_size=list(data.size())
263 | tmp_size=input_tensor_size[2:]
264 | output_stage2_size=[int(tmp * 2) for tmp in tmp_size]
265 | output_stage1_size=[int(tmp * 4) for tmp in tmp_size]
266 | out_put_tensor_size=[int(tmp * 8) for tmp in tmp_size]
267 |
268 | out1=self.bsconv1_pre(data)
269 | # channel=256
270 | # out2=self.bsconv1_pre1(out1)
271 |
272 | output_stage2_tensor = F.interpolate(out1, size=tuple(output_stage2_size), mode="bilinear",align_corners=True)
273 | if config.is_training==1:
274 | bsb_out1=self.bsconv1_pre_help(out1)
275 | bsb_out1 = F.interpolate(bsb_out1, size=tuple(out_put_tensor_size), mode="bilinear",align_corners=True)
276 |
277 | # output_stage2_tensor = tf.concat([output_stage2_tensor, detail_stage_outputs['stage_2'], semantic_stage_outputs['stage_1']], axis=-1, name='stage_2_concate_features')
278 | output_stage2_tensor=torch.cat([output_stage2_tensor,detail_stage_outputs['stg2'], seg_stage_outputs['stg1']],1)
279 |
280 | output_stage2_tensor=self.bsconv1_pre2(output_stage2_tensor)
281 | # output_stage2_tensor=self.bsconv1_pre3(output_stage2_tensor)
282 | # channel=256
283 | output_stage1_tensor = F.interpolate(output_stage2_tensor, size=tuple(output_stage1_size), mode="bilinear",align_corners=True)
284 | if config.is_training==1:
285 | bsb_out2=self.bsconv1_pre2_help(output_stage2_tensor)
286 | bsb_out2 = F.interpolate(bsb_out2, size=tuple(out_put_tensor_size), mode="bilinear",align_corners=True)
287 |
288 | #output_stage1_tensor = tf.concat([output_stage1_tensor, detail_stage_outputs['stage_1']], axis=-1, name='stage_1_concate_features')
289 | output_stage1_tensor=torch.cat([output_stage1_tensor,detail_stage_outputs['stg1']],1)
290 |
291 | # print(output_stage1_tensor.size())
292 |
293 | output_stage1_tensor=self.bsconv1_pre4(output_stage1_tensor)
294 | # output_stage1_tensor=self.bsconv1_pre5(output_stage1_tensor)
295 | output_stage1_tensor=self.bsconv3(output_stage1_tensor)
296 |
297 | bsb_out = 
F.interpolate(output_stage1_tensor, size=tuple(out_put_tensor_size), mode="bilinear",align_corners=True) 298 | 299 | if config.is_training==1: 300 | sg_out1=self.seghead1(seg_stage_outputs['stg1'],size=out_put_tensor_size) 301 | sg_out3=self.seghead3(seg_stage_outputs['stg3'],size=out_put_tensor_size) 302 | sg_out4=self.seghead4(seg_stage_outputs['stg4'],size=out_put_tensor_size) 303 | sg_out5=self.seghead5(seg_stage_outputs['stg5'],size=out_put_tensor_size) 304 | return bsb_out,bsb_out1,bsb_out2,sg_out1,sg_out3,sg_out4,sg_out5 305 | 306 | return bsb_out 307 | 308 | 309 | class BiSeNet(nn.Module): 310 | def __init__(self): 311 | super(BiSeNet, self).__init__() 312 | self.db = DetailedBranch() 313 | self.sb = SemanticBranch() 314 | self.bga = BGA(128) 315 | self._init_params() 316 | self.criterion = nn.CrossEntropyLoss(ignore_index=255) 317 | self.binarySegmentationBranch=BinarySegmentationBranch() 318 | self.instanceSegmentationBranch=InstanceSegmentationBranch() 319 | def _init_params(self): 320 | for m in self.modules(): 321 | if isinstance(m, nn.Conv2d): 322 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 323 | if m.bias is not None: 324 | nn.init.constant_(m.bias, 0) 325 | elif isinstance(m, nn.BatchNorm2d): 326 | nn.init.constant_(m.weight, 1) 327 | nn.init.constant_(m.bias, 0) 328 | elif isinstance(m, nn.BatchNorm1d): 329 | nn.init.constant_(m.weight, 1) 330 | nn.init.constant_(m.bias, 0) 331 | elif isinstance(m, nn.Linear): 332 | nn.init.normal_(m.weight, 0, 0.01) 333 | if m.bias is not None: 334 | nn.init.constant_(m.bias, 0) 335 | def forward(self,data,y=None): 336 | db = self.db(data) 337 | sb = self.sb(data) 338 | bga = self.bga(db["out"],sb["out"]) 339 | isb_res=self.instanceSegmentationBranch(bga) 340 | if config.is_training==1: 341 | bsb_out,bsb_out1,bsb_out2,sg_out1,sg_out3,sg_out4,sg_out5=self.binarySegmentationBranch(bga,sb["seg_stage_outputs"],db["detail_stage_outputs"]) 342 | return { 343 | 'instance_seg_logits': isb_res, 344 | # 'binary_seg_pred': binary_seg_ret, 345 | 'binary_seg_logits': bsb_out, 346 | 'bsb_out1': bsb_out1, 347 | 'bsb_out2': bsb_out2, 348 | 'sg_out1': sg_out1, 349 | 'sg_out3': sg_out3, 350 | 'sg_out4': sg_out4, 351 | 'sg_out5': sg_out5 352 | } 353 | bsb_out=self.binarySegmentationBranch(bga,sb["seg_stage_outputs"],db["detail_stage_outputs"]) 354 | return { 355 | 'instance_seg_logits': isb_res, 356 | # 'binary_seg_pred': binary_seg_ret, 357 | 'binary_seg_logits': bsb_out 358 | } 359 | 360 | if __name__ == '__main__': 361 | # import os 362 | # os.environ["CUDA_VISIBLE_DEVICES"] = '5' 363 | input = torch.rand(1, 3, 800, 800) 364 | model = BiSeNet() 365 | model.eval() 366 | print(model) 367 | output = model(input) 368 | # print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 369 | # print('BiSeNet_v2', output["instance_seg_logits"].size()) 370 | # print('BiSeNet_v2', output["bsb_out2"].size()) 371 | # print('+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++') 372 | print((0.11+0.12+0.13+0.14)/3) 373 | --------------------------------------------------------------------------------
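Editor's note: the files above only define the networks; the loss that consumes their training-time outputs lives elsewhere in the repository (lanenet/model/loss.py and lanenet/train.py, which are not part of this section). The sketch below is a minimal, hypothetical illustration of how the dictionary returned by the BiseNet_v2_2 BiSeNet in training mode could be reduced to a single scalar. The 0.4 auxiliary weight, the plain cross-entropy terms, and the omission of an instance-embedding (discriminative) loss are illustrative assumptions, not the repository's actual training objective.

# Hypothetical sketch only: combining the training-mode outputs of BiseNet_v2_2.BiSeNet
# into one scalar loss. Weights and loss choices are assumptions, not the repo's own loss.
import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss(ignore_index=255)

def combined_binary_loss(outputs, binary_label, aux_weight=0.4):
    # outputs: dict returned by BiSeNet.forward(...) when config.is_training == 1
    # binary_label: LongTensor of shape [N, H, W] with class indices (255 = ignore)
    main = ce(outputs['binary_seg_logits'], binary_label)
    aux_keys = ('bsb_out1', 'bsb_out2', 'sg_out1', 'sg_out3', 'sg_out4', 'sg_out5')
    aux = sum(ce(outputs[k], binary_label) for k in aux_keys)
    # The instance branch ('instance_seg_logits') would additionally need an embedding
    # loss (e.g. a discriminative/clustering loss); it is deliberately omitted here.
    return main + aux_weight * aux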