├── .devcontainer ├── devcontainer.json ├── docker-compose.yml └── readme ├── Dockerfile ├── LICENSE ├── README.md ├── config.py ├── docs ├── overall.jpg └── results.jpg ├── model ├── __init__.py └── mysegformer │ ├── PotCrackSeg.py │ ├── __init__.py │ ├── decoders │ └── Decoder.py │ └── encoders │ ├── dual_segformer.py │ └── one_segformer.py ├── run_demo.py ├── train.py └── util ├── MY_dataset.py ├── __init__.py ├── augmentation.py ├── init_func.py ├── load_utils.py ├── lr_policy.py └── util.py /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at: 2 | // https://github.com/microsoft/vscode-dev-containers/tree/v0.163.1/containers/docker-from-docker-compose 3 | { 4 | "name": "potcrackseg", 5 | "dockerComposeFile": "docker-compose.yml", 6 | "service": "potcrackseg", 7 | "workspaceFolder": "/workspace" 8 | 9 | // // Use this environment variable if you need to bind mount your local source code into a new container. 10 | // "remoteEnv": { 11 | // "LOCAL_WORKSPACE_FOLDER": "${localWorkspaceFolder}" 12 | // }, 13 | 14 | // // Set *default* container specific settings.json values on container create. 15 | // "settings": { 16 | // "terminal.integrated.shell.linux": "/bin/bash" 17 | // }, 18 | 19 | // // Add the IDs of extensions you want installed when the container is created. 20 | // "extensions": [ 21 | // "ms-azuretools.vscode-docker" 22 | // ], 23 | 24 | // // Use 'forwardPorts' to make a list of ports inside the container available locally. 25 | // // "forwardPorts": [], 26 | 27 | // // Use 'postCreateCommand' to run commands after the container is created. 28 | // // "postCreateCommand": "docker --version", 29 | 30 | // // Comment out connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. 31 | // "remoteUser": "vscode" 32 | } -------------------------------------------------------------------------------- /.devcontainer/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '2.3' 2 | services: 3 | potcrackseg: 4 | # Uncomment the next line to use a non-root user for all processes. You can also 5 | # simply use the "remoteUser" property in devcontainer.json if you just want VS Code 6 | # and its sub-processes (terminals, tasks, debugging) to execute as the user. On Linux, 7 | # you may need to update USER_UID and USER_GID in .devcontainer/Dockerfile to match your 8 | # user if not 1000. See https://aka.ms/vscode-remote/containers/non-root for details. 9 | # user: vscode 10 | runtime: nvidia 11 | image: docker_image_drcnet # The name of the docker image 12 | ports: 13 | - '11011:6006' 14 | volumes: 15 | # Update this to wherever you want VS Code to mount the folder of your project 16 | - ..:/workspace:cached # Do not change! 17 | # - /home/sun/somefolder/:/somefolder # folder_in_local_computer:folder_in_docker_container 18 | 19 | # Forwards the local Docker socket to the container. 20 | - /var/run/docker.sock:/var/run/docker-host.sock 21 | shm_size: 32g 22 | devices: 23 | - /dev/nvidia0 24 | # - /dev/nvidia1 # Please note this line, if your computer has only one GPU 25 | 26 | # Uncomment the next four lines if you will use a ptrace-based debuggers like C++, Go, and Rust. 27 | # cap_add: 28 | # - SYS_PTRACE 29 | # security_opt: 30 | # - seccomp:unconfined 31 | 32 | # Overrides default command so things don't shut down after the process ends. 
33 | #entrypoint: /usr/local/share/docker-init.sh 34 | command: sleep infinity 35 | -------------------------------------------------------------------------------- /.devcontainer/readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu20.04 2 | ENV DEBIAN_FRONTEND=noninteractive 3 | #RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 'A4B469963BF863CC' 4 | 5 | RUN apt-key del 7fa2af80 6 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub 7 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu2004/x86_64/7fa2af80.pub 8 | 9 | 10 | RUN apt-get update && apt-get install -y vim python3 python3-pip 11 | 12 | RUN pip3 install --upgrade pip 13 | RUN pip3 install setuptools>=40.3.0 14 | 15 | RUN pip3 install -U scipy scikit-learn 16 | RUN pip3 install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 17 | RUN pip3 install torchsummary 18 | RUN pip3 install tensorboard==2.11.0 19 | RUN pip3 install einops 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Autonomous Systems Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PotCrackSeg 2 | The official pytorch implementation of **Segmentation of Road Negative Obstacles Based on Dual Semantic-feature Complementary Fusion for Autonomous Driving**. ([TIV](https://ieeexplore.ieee.org/document/10468640/)) 3 | 4 | 5 | We test our code in Python 3.8, CUDA 11.3, cuDNN 8, and PyTorch 1.12.1. We provide `Dockerfile` to build the docker image we used. You can modify the `Dockerfile` as you want. 6 |
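A quick way to confirm that a container built from this `Dockerfile` matches the tested environment is to run a short check inside it. This is a minimal sketch; the expected version strings simply mirror the `pip3 install` lines of the `Dockerfile` above.

```
# Hedged sanity check -- run inside the built container.
import torch
import torchvision

print(torch.__version__)               # expected: 1.12.1+cu113
print(torchvision.__version__)         # expected: 0.13.1+cu113
print(torch.version.cuda)              # expected: 11.3
print(torch.backends.cudnn.version())  # expected: a cuDNN 8.x build number
print(torch.cuda.is_available())       # True when a GPU is passed through
```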
7 | 8 |
 9 |  10 | # Demo 11 | 12 | The accompanying video can be found at: 13 |
14 | 15 |
16 |  17 | # Introduction 18 | 19 | PotCrackSeg is an RGB-Depth fusion network with a dual semantic-feature complementary fusion module for the segmentation of potholes and cracks in traffic scenes. 20 | 21 | # Dataset 22 | 23 | The **NPO++** dataset is upgraded from the existing [**NPO**](https://pan.baidu.com/s/1-LuHyKXEuJ0oLMe1PHtq0Q?pwd=drno) dataset by re-labeling potholes and cracks. You can download the **NPO++** dataset from [here](https://pan.baidu.com/s/1608EIKo-be63XE3-7UYcIQ?pwd=uxks). 24 | 25 | # Pretrained weights 26 | The pretrained weights of PotCrackSeg can be downloaded from [here](https://pan.baidu.com/s/18xGs1Jp1xbSekBjJVEh9Pg?pwd=ynva). 27 | 28 | # Usage 29 | * Clone this repo 30 | ``` 31 | $ git clone https://github.com/lab-sun/PotCrackSeg.git 32 | ``` 33 | * Build the docker image 34 | ``` 35 | $ cd ~/PotCrackSeg 36 | $ docker build -t docker_image_PotCrackSeg . 37 | ``` 38 | * Download the dataset 39 | ``` 40 | $ (You should be in the PotCrackSeg folder) 41 | $ mkdir ./NPO++ 42 | $ cd ./NPO++ 43 | $ (download our preprocessed NPO++.zip in this folder) 44 | $ unzip -d . NPO++.zip 45 | ``` 46 | * To reproduce our results, you need to download our pretrained weights. 47 | ``` 48 | $ (You should be in the PotCrackSeg folder) 49 | $ mkdir ./weights_backup 50 | $ cd ./weights_backup 51 | $ (download our preprocessed weights_backup.zip in this folder) 52 | $ unzip -d . weights_backup.zip 53 | $ docker run -it --shm-size 8G -p 1234:6006 --name docker_container_potcrackseg --gpus all -v ~/PotCrackSeg:/workspace docker_image_potcrackseg 54 | $ (currently, you should be in the docker) 55 | $ cd /workspace 56 | $ (To reproduce the results) 57 | $ python3 run_demo.py 58 | ``` 59 | The results will be saved in the `./runs` folder. The default results are for PotCrackSeg-4B. If you want to reproduce the results of PotCrackSeg-2B, change *PotCrackSeg-4B* to *PotCrackSeg-2B* in `run_demo.py` (see the sketch at the end of this README). 60 | 61 | * To train PotCrackSeg 62 | ``` 63 | $ (You should be in the PotCrackSeg folder) 64 | $ docker run -it --shm-size 8G -p 1234:6006 --name docker_container_potcrackseg --gpus all -v ~/PotCrackSeg:/workspace docker_image_potcrackseg 65 | $ (currently, you should be in the docker) 66 | $ cd /workspace 67 | $ python3 train.py 68 | ``` 69 | 70 | * To see the training process 71 | ``` 72 | $ (fire up another terminal) 73 | $ docker exec -it docker_container_potcrackseg /bin/bash 74 | $ cd /workspace 75 | $ tensorboard --bind_all --logdir=./runs/tensorboard_log/ 76 | $ (fire up your favorite browser with http://localhost:1234, you will see the tensorboard) 77 | ``` 78 | The results will be saved in the `./runs` folder. 79 | Note: Please change the smoothing factor on the TensorBoard webpage to `0.999`; otherwise, you may not be able to see the patterns in the noisy plots. If you get the error `docker: Error response from daemon: could not select device driver`, please first install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) on your computer! 80 | 81 | # Citation 82 | If you use PotCrackSeg in your academic work, please cite: 83 | ``` 84 | 85 | ``` 86 | 87 | # Acknowledgement 88 | Some of the codes are borrowed from [IGFNet](https://github.com/lab-sun/IGFNet).
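For reference, below is a minimal sketch of how a specific variant can be constructed and run on a single RGB-Depth pair. The constructor, the `PotCrackSeg-xB` encoder names, and the 4-channel input layout come from `model/mysegformer/PotCrackSeg.py` and `config.py`; the variable names are illustrative and are not the exact contents of `run_demo.py`.

```
import torch
from config import config
from model import PotCrackSeg  # EncoderDecoder exported by model/__init__.py

# 'PotCrackSeg-0B' ... 'PotCrackSeg-5B' select the mit_b0 ... mit_b5 backbones.
model = PotCrackSeg(cfg=config, encoder_name='PotCrackSeg-2B', n_class=3).eval()

# The input is RGB (3 channels) concatenated with a single-channel depth map.
rgb_d = torch.randn(1, 4, config.image_height, config.image_width)
with torch.no_grad():
    *_, final = model(rgb_d)  # the last output is the fused ("voting") prediction
print(final.shape)            # torch.Size([1, 3, 288, 512])
```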
89 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import time 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | import argparse 8 | 9 | C = edict() 10 | config = C 11 | cfg = C 12 | 13 | C.seed = 12345 14 | 15 | remoteip = os.popen('pwd').read() 16 | C.root_dir = os.path.abspath(os.path.join(os.getcwd(), './')) 17 | C.abs_dir = osp.realpath(".") 18 | 19 | # Dataset config 20 | """Dataset Path""" 21 | C.dataset_name = 'NPO++' 22 | C.dataset_path = osp.join(C.root_dir, 'NPO++') 23 | C.rgb_root_folder = osp.join(C.dataset_path, 'left') 24 | C.rgb_format = '.png' 25 | C.gt_root_folder = osp.join(C.dataset_path, 'labels') 26 | C.gt_format = '.png' 27 | C.x_root_folder = osp.join(C.dataset_path, 'depth') 28 | C.x_format = '.png' 29 | C.x_is_single_channel = False # True for raw depth, thermal and aolp/dolp(not aolp/dolp tri) input 30 | C.train_source = osp.join(C.dataset_path, "train.txt") 31 | C.eval_source = osp.join(C.dataset_path, "test.txt") 32 | C.is_test = False 33 | C.num_train_imgs = 2300 34 | C.num_eval_imgs = 1150 35 | 36 | 37 | """Image Config""" 38 | C.background = 255 39 | C.image_height = 288 40 | C.image_width = 512 41 | C.norm_mean = np.array([0.485, 0.456, 0.406]) 42 | C.norm_std = np.array([0.229, 0.224, 0.225]) 43 | 44 | """ Settings for network, this would be different for each kind of model""" 45 | C.backbone = 'mit_b2' # Remember change the path below. 46 | # C.pretrained_model = C.root_dir + '/pretrained/mit_b2.pth' # Using for training 47 | C.pretrained_model = None # Using for testing 48 | C.decoder = 'MLPDecoder' 49 | C.decoder_embed_dim = 512 50 | C.optimizer = 'AdamW' 51 | 52 | """Train Config""" 53 | C.lr = 6e-5 54 | C.lr_power = 0.9 55 | C.momentum = 0.9 56 | C.weight_decay = 0.0005 57 | C.batch_size = 2 58 | C.nepochs = 500 59 | C.niters_per_epoch = C.num_train_imgs // C.batch_size + 1 60 | C.num_workers = 16 61 | C.train_scale_array = [0.5, 0.75, 1, 1.25, 1.5, 1.75] 62 | C.warm_up_epoch = 10 63 | 64 | C.fix_bias = True 65 | C.bn_eps = 1e-3 66 | C.bn_momentum = 0.1 67 | 68 | """Eval Config""" 69 | C.eval_iter = 25 70 | C.eval_stride_rate = 2 / 3 71 | C.eval_scale_array = [1] # [0.75, 1, 1.25] # 72 | C.eval_flip = False # True # 73 | C.eval_crop_size = [288, 512] # [height weight] 74 | 75 | """Store Config""" 76 | C.checkpoint_start_epoch = 50 77 | C.checkpoint_step = 1 78 | 79 | """Path Config""" 80 | def add_path(path): 81 | if path not in sys.path: 82 | sys.path.insert(0, path) 83 | add_path(osp.join(C.root_dir)) 84 | 85 | C.log_dir = osp.abspath('log_' + C.dataset_name + '_' + C.backbone) 86 | C.tb_dir = osp.abspath(osp.join(C.log_dir, "tb")) 87 | C.log_dir_link = C.log_dir 88 | C.checkpoint_dir = osp.abspath(osp.join(C.log_dir, "checkpoint")) 89 | 90 | exp_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()) 91 | C.log_file = C.log_dir + '/log_' + exp_time + '.log' 92 | C.link_log_file = C.log_file + '/log_last.log' 93 | C.val_log_file = C.log_dir + '/val_' + exp_time + '.log' 94 | C.link_val_log_file = C.log_dir + '/val_last.log' 95 | 96 | if __name__ == '__main__': 97 | print(config.nepochs) 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument( 100 | '-tb', '--tensorboard', default=False, action='store_true') 101 | args = parser.parse_args() 102 | 103 | if args.tensorboard: 104 | open_tensorboard() 105 | 
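One caveat in `config.py`: the `__main__` block above calls `open_tensorboard()`, which is not defined anywhere in the file, so `python3 config.py -tb` would raise a `NameError` as written. Below is a minimal sketch of a helper with that name, assuming it is only meant to launch TensorBoard on the log directory; the port and the use of `C.tb_dir` are assumptions.

```
import subprocess

def open_tensorboard(logdir=None, port=6006):
    """Launch TensorBoard on the experiment log directory (hedged sketch)."""
    logdir = logdir or C.tb_dir
    # --bind_all mirrors the TensorBoard invocation shown in the README.
    return subprocess.Popen(
        ['tensorboard', '--bind_all', '--logdir=' + logdir, '--port=' + str(port)])
```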
-------------------------------------------------------------------------------- /docs/overall.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/PotCrackSeg/4d1aa7b53b4cf7f6ebd1cb660cd66a793d77150d/docs/overall.jpg -------------------------------------------------------------------------------- /docs/results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/PotCrackSeg/4d1aa7b53b4cf7f6ebd1cb660cd66a793d77150d/docs/results.jpg -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .mysegformer.PotCrackSeg import EncoderDecoder as PotCrackSeg 2 | 3 | -------------------------------------------------------------------------------- /model/mysegformer/PotCrackSeg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import sys 6 | sys.path.append('.') 7 | 8 | from util.init_func import init_weight 9 | from util.load_utils import load_pretrain 10 | from functools import partial 11 | from config import config 12 | from model.mysegformer.decoders.Decoder import DecoderHead 13 | 14 | 15 | class EncoderDecoder(nn.Module): 16 | def __init__(self,cfg=None, criterion=nn.CrossEntropyLoss(reduction='mean', ignore_index=255), encoder_name='mit_b2', n_class=3, norm_layer=nn.BatchNorm2d): 17 | super(EncoderDecoder, self).__init__() 18 | self.channels = [64, 128, 320, 512] 19 | self.norm_layer = norm_layer 20 | 21 | # import backbone and decoder 22 | if encoder_name == 'PotCrackSeg-5B': 23 | #logger.info('Using backbone: Segformer-B5') 24 | from model.mysegformer.encoders.dual_segformer import mit_b5 as backbone 25 | print("chose mit_b5") 26 | self.backbone = backbone(norm_fuse=norm_layer) 27 | elif encoder_name == 'PotCrackSeg-4B': 28 | #logger.info('Using backbone: Segformer-B4') 29 | from model.mysegformer.encoders.dual_segformer import mit_b4 as backbone 30 | print("chose mit_b4") 31 | self.backbone = backbone(norm_fuse=norm_layer) 32 | elif encoder_name == 'PotCrackSeg-3B': 33 | #logger.info('Using backbone: Segformer-B4') 34 | from model.mysegformer.encoders.dual_segformer import mit_b3 as backbone 35 | print("chose mit_b3") 36 | self.backbone = backbone(norm_fuse=norm_layer) 37 | elif encoder_name == 'PotCrackSeg-2B': 38 | #logger.info('Using backbone: Segformer-B2') 39 | from model.mysegformer.encoders.dual_segformer import mit_b2 as backbone 40 | print("chose mit_b2") 41 | self.backbone = backbone(norm_fuse=norm_layer) 42 | elif encoder_name == 'PotCrackSeg-1B': 43 | #logger.info('Using backbone: Segformer-B1') 44 | from model.mysegformer.encoders.dual_segformer import mit_b1 as backbone 45 | print("chose mit_b1") 46 | self.backbone = backbone(norm_fuse=norm_layer) 47 | elif encoder_name == 'PotCrackSeg-0B': 48 | #logger.info('Using backbone: Segformer-B0') 49 | self.channels = [32, 64, 160, 256] 50 | from model.mysegformer.encoders.dual_segformer import mit_b0 as backbone 51 | print("chose mit_b0") 52 | self.backbone = backbone(norm_fuse=norm_layer) 53 | else: 54 | #logger.info('Using backbone: Segformer-B2') 55 | from encoders.dual_segformer import mit_b2 as backbone 56 | self.backbone = backbone(norm_fuse=norm_layer) 57 | 58 | self.aux_head = None 59 | 60 | self.decode_head = 
DecoderHead(in_channels=self.channels, num_classes=n_class, norm_layer=norm_layer, embed_dim=512) 61 | 62 | 63 | self.voting = nn.Conv2d(in_channels=n_class*2,out_channels=n_class,kernel_size=3,stride=1,padding=1) 64 | 65 | self.init_weights(cfg, pretrained=cfg.pretrained_model) 66 | 67 | def init_weights(self, cfg, pretrained=None): 68 | if pretrained: 69 | self.backbone.init_weights(pretrained=pretrained) 70 | init_weight(self.decode_head, nn.init.kaiming_normal_, 71 | self.norm_layer, cfg.bn_eps, cfg.bn_momentum, 72 | mode='fan_in', nonlinearity='relu') 73 | 74 | def encode_decode(self, rgb, modal_x): 75 | """Encode images with backbone and decode into a semantic segmentation 76 | map of the same size as input.""" 77 | orisize = rgb.shape 78 | rgb,depth = self.backbone(rgb, modal_x) 79 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion = self.decode_head.forward(rgb,depth) 80 | 81 | rgb_fusion = F.interpolate(rgb_fusion, size=orisize[2:], mode='bilinear', align_corners=False) 82 | depth_fusion = F.interpolate(depth_fusion, size=orisize[2:], mode='bilinear', align_corners=False) 83 | 84 | final = self.voting(torch.cat((rgb_fusion,depth_fusion),dim=1)) 85 | 86 | return rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple,depth_fusion,final 87 | 88 | def forward(self, input): 89 | 90 | rgb = input[:,:3] 91 | modal_x = input[:,3:] 92 | modal_x = torch.cat((modal_x,modal_x,modal_x),dim=1) 93 | 94 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion,final = self.encode_decode(rgb, modal_x) 95 | 96 | return rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion,final 97 | 98 | 99 | def unit_test(): 100 | 101 | num_minibatch = 2 102 | rgb = torch.randn(num_minibatch, 3, 288, 512).cuda(0) 103 | thermal = torch.randn(num_minibatch, 1, 288, 512).cuda(0) 104 | images = torch.cat((rgb,thermal),dim=1) 105 | rtf_net = EncoderDecoder(cfg = config, encoder_name='mit_b2', decoder_name='MLPDecoderNewDRC', n_class=3).cuda(0) 106 | #input = torch.cat((rgb, thermal), dim=1) 107 | #out_seg_4,out_seg_3,out_seg_2,out_seg_1,out_x,out_d,final = rtf_net(images) 108 | 109 | rtf_net.eval() 110 | 111 | final = rtf_net(images) 112 | 113 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion ,out_x,out_d,final = final 114 | 115 | # for i in out_seg_1: 116 | # print(i.shape) 117 | 118 | 119 | print(final.shape) 120 | 121 | if __name__ == '__main__': 122 | unit_test() 123 | -------------------------------------------------------------------------------- /model/mysegformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/PotCrackSeg/4d1aa7b53b4cf7f6ebd1cb660cd66a793d77150d/model/mysegformer/__init__.py -------------------------------------------------------------------------------- /model/mysegformer/decoders/Decoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | 5 | from torch.nn.modules import module 6 | import torch.nn.functional as F 7 | 8 | class MLP(nn.Module): 9 | """ 10 | Linear Embedding: 11 | """ 12 | def __init__(self, input_dim=2048, embed_dim=768): 13 | super().__init__() 14 | self.proj = nn.Linear(input_dim, embed_dim) 15 | 16 | def forward(self, x): 17 | x = x.flatten(2).transpose(1, 2) 18 | x = self.proj(x) 19 | return x 20 | 21 | 22 | class DecoderHead(nn.Module): 23 | def __init__(self, 24 | 
in_channels=[64, 128, 320, 512], 25 | num_classes=40, 26 | dropout_ratio=0.1, 27 | norm_layer=nn.BatchNorm2d, 28 | embed_dim=768, 29 | align_corners=False): 30 | 31 | super(DecoderHead, self).__init__() 32 | self.num_classes = num_classes 33 | self.dropout_ratio = dropout_ratio 34 | self.align_corners = align_corners 35 | 36 | self.in_channels = in_channels 37 | 38 | if dropout_ratio > 0: 39 | self.dropout = nn.Dropout2d(dropout_ratio) 40 | else: 41 | self.dropout = None 42 | 43 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels 44 | 45 | embedding_dim = embed_dim 46 | 47 | 48 | # RGB decoder 49 | 50 | self.linear_c4_rgb = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 51 | self.linear_c3_rgb = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 52 | self.linear_c2_rgb = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 53 | self.linear_c1_rgb = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 54 | 55 | self.linear_fuse_rgb = nn.Sequential( 56 | nn.Conv2d(in_channels=embedding_dim*4, out_channels=embedding_dim, kernel_size=1), 57 | norm_layer(embedding_dim), 58 | nn.ReLU(inplace=True) 59 | ) 60 | 61 | self.linear_pred_rgb = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 62 | 63 | 64 | #depth decoder 65 | 66 | self.linear_c4_depth = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 67 | self.linear_c3_depth = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 68 | self.linear_c2_depth = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 69 | self.linear_c1_depth = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 70 | 71 | self.linear_fuse_depth = nn.Sequential( 72 | nn.Conv2d(in_channels=embedding_dim*4, out_channels=embedding_dim, kernel_size=1), 73 | norm_layer(embedding_dim), 74 | nn.ReLU(inplace=True) 75 | ) 76 | 77 | self.linear_pred_depth = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 78 | 79 | 80 | self.linear_fusion = nn.Sequential( 81 | nn.Conv2d(in_channels=embedding_dim*4, out_channels=embedding_dim, kernel_size=1), 82 | norm_layer(embedding_dim), 83 | nn.ReLU(inplace=True) 84 | ) 85 | 86 | #self.linear_pred_fusion = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 87 | 88 | 89 | 90 | self.DRC = DRC(embedding_dim,self.num_classes) 91 | 92 | def forward(self, rgb,depth): 93 | # len=4, 1/4,1/8,1/16,1/32 94 | 95 | c1_r, c2_r, c3_r, c4_r = rgb 96 | verbose = False 97 | 98 | if verbose: print("c1_r size",c1_r.size()) 99 | if verbose: print("c2_r size",c2_r.size()) 100 | if verbose: print("c3_r size",c3_r.size()) 101 | if verbose: print("c4_r size",c4_r.size()) 102 | 103 | c1_d, c2_d, c3_d, c4_d = depth 104 | 105 | if verbose: print("c1_d size",c1_d.size()) 106 | if verbose: print("c2_d size",c2_d.size()) 107 | if verbose: print("c3_d size",c3_d.size()) 108 | if verbose: print("c4_d size",c4_d.size()) 109 | 110 | ############## MLP decoder on C1-C4 ########### 111 | n, _, h, w = c4_r.shape 112 | 113 | _c4_r = self.linear_c4_rgb(c4_r).permute(0,2,1).reshape(n, -1, c4_r.shape[2], c4_r.shape[3]) 114 | if verbose: print("_c4_r size after linear_c4_rgb",_c4_r.size()) 115 | _c4_d = self.linear_c4_depth(c4_d).permute(0,2,1).reshape(n, -1, c4_d.shape[2], c4_d.shape[3]) 116 | if verbose: print("_c4_d size after linear_c4_depth",_c4_d.size()) 117 | 118 | _c4_r = F.interpolate(_c4_r, size=c1_r.size()[2:],mode='bilinear',align_corners=self.align_corners) 119 | if verbose: print("_c4_r size after interpolate",_c4_r.size()) 120 | _c4_d = F.interpolate(_c4_d, 
size=c1_d.size()[2:],mode='bilinear',align_corners=self.align_corners) 121 | if verbose: print("_c4_d size after interpolate",_c4_d.size()) 122 | 123 | 124 | 125 | _c3_r = self.linear_c3_rgb(c3_r).permute(0,2,1).reshape(n, -1, c3_r.shape[2], c3_r.shape[3]) 126 | if verbose: print("_c3_r size after linear_c3_rgb",_c3_r.size()) 127 | _c3_d = self.linear_c3_depth(c3_d).permute(0,2,1).reshape(n, -1, c3_d.shape[2], c3_d.shape[3]) 128 | if verbose: print("_c3_d size after linear_c3_depth",_c3_d.size()) 129 | 130 | _c3_r = F.interpolate(_c3_r, size=c1_r.size()[2:],mode='bilinear',align_corners=self.align_corners) 131 | if verbose: print("_c3_r size after interpolate",_c3_r.size()) 132 | _c3_d = F.interpolate(_c3_d, size=c1_d.size()[2:],mode='bilinear',align_corners=self.align_corners) 133 | if verbose: print("_c3_d size after interpolate",_c3_d.size()) 134 | 135 | 136 | 137 | _c2_r = self.linear_c2_rgb(c2_r).permute(0,2,1).reshape(n, -1, c2_r.shape[2], c2_r.shape[3]) 138 | if verbose: print("_c2_r size after linear_c2_rgb",_c2_r.size()) 139 | _c2_d = self.linear_c2_depth(c2_d).permute(0,2,1).reshape(n, -1, c2_d.shape[2], c2_d.shape[3]) 140 | if verbose: print("_c2_d size after linear_c2_depth",_c2_d.size()) 141 | 142 | _c2_r = F.interpolate(_c2_r, size=c1_r.size()[2:],mode='bilinear',align_corners=self.align_corners) 143 | if verbose: print("_c2_r size after interpolate",_c2_r.size()) 144 | _c2_d = F.interpolate(_c2_d, size=c1_d.size()[2:],mode='bilinear',align_corners=self.align_corners) 145 | if verbose: print("_c2_d size after interpolate",_c2_d.size()) 146 | 147 | 148 | 149 | _c1_d = self.linear_c1_depth(c1_d).permute(0,2,1).reshape(n, -1, c1_d.shape[2], c1_d.shape[3]) 150 | if verbose: print("_c1_d size after linear_c1_depth",_c1_d.size()) 151 | _c1_r = self.linear_c1_rgb(c1_r).permute(0,2,1).reshape(n, -1, c1_r.shape[2], c1_r.shape[3]) 152 | if verbose: print("_c1_r size after linear_c1_rgb",_c1_r.size()) 153 | 154 | 155 | _c_d = self.linear_fuse_depth(torch.cat([_c4_d, _c3_d, _c2_d, _c1_d], dim=1)) 156 | x_d = self.dropout(_c_d) 157 | x_d = self.linear_pred_depth(x_d) 158 | 159 | _c_r = self.linear_fuse_rgb(torch.cat([_c4_r, _c3_r, _c2_r, _c1_r], dim=1)) 160 | x_r = self.dropout(_c_r) 161 | x_r = self.linear_pred_rgb(x_r) 162 | 163 | 164 | fusion = self.linear_fusion(torch.cat([_c4_r+_c4_d, _c3_r+_c3_d, _c2_r+_c2_d, _c1_r+_c1_d], dim=1)) 165 | 166 | 167 | 168 | rgb_comple,depth_comple,rgb_fusion,depth_fusion = self.DRC(fusion,x_r,x_d) 169 | 170 | 171 | return x_r, rgb_comple, rgb_fusion, x_d, depth_comple, depth_fusion 172 | 173 | class DRC(nn.Module): 174 | def __init__(self,in_channel, n_class): 175 | super(DRC, self).__init__() 176 | 177 | self.rgb_segconv = nn.Conv2d(in_channels=in_channel,out_channels=n_class,kernel_size=1,stride=1,padding=0) 178 | self.depth_segconv = nn.Conv2d(in_channels=in_channel,out_channels=n_class,kernel_size=1,stride=1,padding=0) 179 | 180 | 181 | self.rgb_comple_conv1 = nn.Conv2d(in_channels=n_class,out_channels=n_class,kernel_size=3,stride=1,padding=1) 182 | self.rgb_comple_bn1 = nn.BatchNorm2d(n_class) 183 | self.rgb_comple_relu1 = nn.ReLU() 184 | self.rgb_comple_fusion_conv1 = nn.Conv2d(in_channels=n_class*2,out_channels=n_class,kernel_size=1,stride=1,padding=0) 185 | 186 | 187 | self.depth_comple_conv1 = nn.Conv2d(in_channels=n_class,out_channels=n_class,kernel_size=3,stride=1,padding=1) 188 | self.depth_comple_bn1 = nn.BatchNorm2d(n_class) 189 | self.depth_comple_relu1 = nn.ReLU() 190 | self.depth_comple_fusion_conv1 = 
nn.Conv2d(in_channels=n_class*2,out_channels=n_class,kernel_size=1,stride=1,padding=0) 191 | 192 | 193 | def forward(self,fusion,x_r,x_d): 194 | 195 | rgb_missing = self.rgb_segconv(fusion) 196 | rgb_comple1 = self.rgb_comple_conv1(rgb_missing) 197 | rgb_comple1 = self.rgb_comple_bn1(rgb_comple1) 198 | rgb_comple1 = self.rgb_comple_relu1(rgb_comple1) 199 | rgb_comple = rgb_missing+rgb_comple1 200 | rgb_fusion = self.rgb_comple_fusion_conv1(torch.cat((x_r,rgb_comple),dim=1)) 201 | 202 | depth_missing = self.depth_segconv(fusion) 203 | depth_comple1 = self.depth_comple_conv1(depth_missing) 204 | depth_comple1 = self.depth_comple_bn1(depth_comple1) 205 | depth_comple1 = self.depth_comple_relu1(depth_comple1) 206 | depth_comple = depth_missing+depth_comple1 207 | depth_fusion = self.depth_comple_fusion_conv1(torch.cat((x_d,depth_comple),dim=1)) 208 | 209 | return rgb_comple,depth_comple, rgb_fusion, depth_fusion -------------------------------------------------------------------------------- /model/mysegformer/encoders/dual_segformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | import math 8 | import time 9 | #from engine.logger import get_logger 10 | 11 | #logger = get_logger() 12 | 13 | 14 | class DWConv(nn.Module): 15 | """ 16 | Depthwise convolution bloc: input: x with size(B N C); output size (B N C) 17 | """ 18 | def __init__(self, dim=768): 19 | super(DWConv, self).__init__() 20 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim) 21 | 22 | def forward(self, x, H, W): 23 | B, N, C = x.shape 24 | x = x.permute(0, 2, 1).reshape(B, C, H, W).contiguous() # B N C -> B C N -> B C H W 25 | x = self.dwconv(x) 26 | x = x.flatten(2).transpose(1, 2) # B C H W -> B N C 27 | 28 | return x 29 | 30 | 31 | class Mlp(nn.Module): 32 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 33 | super().__init__() 34 | """ 35 | MLP Block: 36 | """ 37 | out_features = out_features or in_features 38 | hidden_features = hidden_features or in_features 39 | self.fc1 = nn.Linear(in_features, hidden_features) 40 | self.dwconv = DWConv(hidden_features) 41 | self.act = act_layer() 42 | self.fc2 = nn.Linear(hidden_features, out_features) 43 | self.drop = nn.Dropout(drop) 44 | 45 | self.apply(self._init_weights) 46 | 47 | def _init_weights(self, m): 48 | if isinstance(m, nn.Linear): 49 | trunc_normal_(m.weight, std=.02) 50 | if isinstance(m, nn.Linear) and m.bias is not None: 51 | nn.init.constant_(m.bias, 0) 52 | elif isinstance(m, nn.LayerNorm): 53 | nn.init.constant_(m.bias, 0) 54 | nn.init.constant_(m.weight, 1.0) 55 | elif isinstance(m, nn.Conv2d): 56 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 57 | fan_out //= m.groups 58 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 59 | if m.bias is not None: 60 | m.bias.data.zero_() 61 | 62 | def forward(self, x, H, W): 63 | x = self.fc1(x) 64 | x = self.dwconv(x, H, W) 65 | x = self.act(x) 66 | x = self.drop(x) 67 | x = self.fc2(x) 68 | x = self.drop(x) 69 | return x 70 | 71 | 72 | class Attention(nn.Module): 73 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 74 | super().__init__() 75 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads 
{num_heads}." 76 | 77 | self.dim = dim 78 | self.num_heads = num_heads 79 | head_dim = dim // num_heads 80 | self.scale = qk_scale or head_dim ** -0.5 81 | 82 | # Linear embedding 83 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 84 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 85 | self.attn_drop = nn.Dropout(attn_drop) 86 | self.proj = nn.Linear(dim, dim) 87 | self.proj_drop = nn.Dropout(proj_drop) 88 | 89 | self.sr_ratio = sr_ratio 90 | if sr_ratio > 1: 91 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 92 | self.norm = nn.LayerNorm(dim) 93 | 94 | self.apply(self._init_weights) 95 | 96 | def _init_weights(self, m): 97 | if isinstance(m, nn.Linear): 98 | trunc_normal_(m.weight, std=.02) 99 | if isinstance(m, nn.Linear) and m.bias is not None: 100 | nn.init.constant_(m.bias, 0) 101 | elif isinstance(m, nn.LayerNorm): 102 | nn.init.constant_(m.bias, 0) 103 | nn.init.constant_(m.weight, 1.0) 104 | elif isinstance(m, nn.Conv2d): 105 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | fan_out //= m.groups 107 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 108 | if m.bias is not None: 109 | m.bias.data.zero_() 110 | 111 | def forward(self, x, H, W): 112 | B, N, C = x.shape 113 | # B N C -> B N num_head C//num_head -> B C//num_head N num_heads 114 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 115 | 116 | if self.sr_ratio > 1: 117 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 118 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 119 | x_ = self.norm(x_) 120 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 121 | else: 122 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 123 | k, v = kv[0], kv[1] 124 | 125 | attn = (q @ k.transpose(-2, -1)) * self.scale 126 | attn = attn.softmax(dim=-1) 127 | attn = self.attn_drop(attn) 128 | 129 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 130 | x = self.proj(x) 131 | x = self.proj_drop(x) 132 | 133 | return x 134 | 135 | 136 | class Block(nn.Module): 137 | """ 138 | Transformer Block: Self-Attention -> Mix FFN -> OverLap Patch Merging 139 | """ 140 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 141 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 142 | super().__init__() 143 | self.norm1 = norm_layer(dim) 144 | self.attn = Attention( 145 | dim, 146 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 147 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 148 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 149 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 150 | self.norm2 = norm_layer(dim) 151 | mlp_hidden_dim = int(dim * mlp_ratio) 152 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 153 | 154 | self.apply(self._init_weights) 155 | 156 | def _init_weights(self, m): 157 | if isinstance(m, nn.Linear): 158 | trunc_normal_(m.weight, std=.02) 159 | if isinstance(m, nn.Linear) and m.bias is not None: 160 | nn.init.constant_(m.bias, 0) 161 | elif isinstance(m, nn.LayerNorm): 162 | nn.init.constant_(m.bias, 0) 163 | nn.init.constant_(m.weight, 1.0) 164 | elif isinstance(m, nn.Conv2d): 165 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 166 | fan_out //= m.groups 167 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 168 | if m.bias is not None: 169 | m.bias.data.zero_() 170 | 171 | def forward(self, x, H, W): 172 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 173 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 174 | 175 | return x 176 | 177 | 178 | class OverlapPatchEmbed(nn.Module): 179 | """ Image to Patch Embedding 180 | """ 181 | 182 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 183 | super().__init__() 184 | img_size = to_2tuple(img_size) 185 | patch_size = to_2tuple(patch_size) 186 | 187 | self.img_size = img_size 188 | self.patch_size = patch_size 189 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 190 | self.num_patches = self.H * self.W 191 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 192 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 193 | self.norm = nn.LayerNorm(embed_dim) 194 | 195 | self.apply(self._init_weights) 196 | 197 | def _init_weights(self, m): 198 | if isinstance(m, nn.Linear): 199 | trunc_normal_(m.weight, std=.02) 200 | if isinstance(m, nn.Linear) and m.bias is not None: 201 | nn.init.constant_(m.bias, 0) 202 | elif isinstance(m, nn.LayerNorm): 203 | nn.init.constant_(m.bias, 0) 204 | nn.init.constant_(m.weight, 1.0) 205 | elif isinstance(m, nn.Conv2d): 206 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 207 | fan_out //= m.groups 208 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 209 | if m.bias is not None: 210 | m.bias.data.zero_() 211 | 212 | def forward(self, x): 213 | # B C H W 214 | x = self.proj(x) 215 | _, _, H, W = x.shape 216 | x = x.flatten(2).transpose(1, 2) 217 | # B H*W/16 C 218 | x = self.norm(x) 219 | 220 | return x, H, W 221 | 222 | 223 | class RGBXTransformer(nn.Module): 224 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 225 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 226 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, norm_fuse=nn.BatchNorm2d, 227 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 228 | super().__init__() 229 | self.num_classes = num_classes 230 | self.depths = depths 231 | 232 | # patch_embed 233 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 234 | embed_dim=embed_dims[0]) 235 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 236 | embed_dim=embed_dims[1]) 237 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 238 | embed_dim=embed_dims[2]) 239 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, 
in_chans=embed_dims[2], 240 | embed_dim=embed_dims[3]) 241 | 242 | self.extra_patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 243 | embed_dim=embed_dims[0]) 244 | self.extra_patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 245 | embed_dim=embed_dims[1]) 246 | self.extra_patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 247 | embed_dim=embed_dims[2]) 248 | self.extra_patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], 249 | embed_dim=embed_dims[3]) 250 | 251 | # transformer encoder 252 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 253 | cur = 0 254 | 255 | self.block1 = nn.ModuleList([Block( 256 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 257 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 258 | sr_ratio=sr_ratios[0]) 259 | for i in range(depths[0])]) 260 | self.norm1 = norm_layer(embed_dims[0]) 261 | 262 | self.extra_block1 = nn.ModuleList([Block( 263 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 264 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 265 | sr_ratio=sr_ratios[0]) 266 | for i in range(depths[0])]) 267 | self.extra_norm1 = norm_layer(embed_dims[0]) 268 | cur += depths[0] 269 | 270 | self.block2 = nn.ModuleList([Block( 271 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 272 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur], norm_layer=norm_layer, 273 | sr_ratio=sr_ratios[1]) 274 | for i in range(depths[1])]) 275 | self.norm2 = norm_layer(embed_dims[1]) 276 | 277 | self.extra_block2 = nn.ModuleList([Block( 278 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 279 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur+1], norm_layer=norm_layer, 280 | sr_ratio=sr_ratios[1]) 281 | for i in range(depths[1])]) 282 | self.extra_norm2 = norm_layer(embed_dims[1]) 283 | 284 | cur += depths[1] 285 | 286 | self.block3 = nn.ModuleList([Block( 287 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 288 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 289 | sr_ratio=sr_ratios[2]) 290 | for i in range(depths[2])]) 291 | self.norm3 = norm_layer(embed_dims[2]) 292 | 293 | self.extra_block3 = nn.ModuleList([Block( 294 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 295 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 296 | sr_ratio=sr_ratios[2]) 297 | for i in range(depths[2])]) 298 | self.extra_norm3 = norm_layer(embed_dims[2]) 299 | 300 | cur += depths[2] 301 | 302 | self.block4 = nn.ModuleList([Block( 303 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 304 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 305 | sr_ratio=sr_ratios[3]) 306 | for i in range(depths[3])]) 307 | self.norm4 = norm_layer(embed_dims[3]) 308 | 309 | self.extra_block4 = nn.ModuleList([Block( 310 | dim=embed_dims[3], 
num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 311 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 312 | sr_ratio=sr_ratios[3]) 313 | for i in range(depths[3])]) 314 | self.extra_norm4 = norm_layer(embed_dims[3]) 315 | 316 | cur += depths[3] 317 | 318 | self.apply(self._init_weights) 319 | 320 | def _init_weights(self, m): 321 | if isinstance(m, nn.Linear): 322 | trunc_normal_(m.weight, std=.02) 323 | if isinstance(m, nn.Linear) and m.bias is not None: 324 | nn.init.constant_(m.bias, 0) 325 | elif isinstance(m, nn.LayerNorm): 326 | nn.init.constant_(m.bias, 0) 327 | nn.init.constant_(m.weight, 1.0) 328 | elif isinstance(m, nn.Conv2d): 329 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 330 | fan_out //= m.groups 331 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 332 | if m.bias is not None: 333 | m.bias.data.zero_() 334 | 335 | def init_weights(self, pretrained=None): 336 | if isinstance(pretrained, str): 337 | load_dualpath_model(self, pretrained) 338 | else: 339 | raise TypeError('pretrained must be a str or None') 340 | 341 | def forward_features(self, x_rgb, x_e): 342 | """ 343 | x_rgb: B x N x H x W 344 | """ 345 | B = x_rgb.shape[0] 346 | rgb = [] 347 | depth = [] 348 | 349 | # stage 1 350 | x_rgb, H, W = self.patch_embed1(x_rgb) 351 | # B H*W/16 C 352 | x_e, _, _ = self.extra_patch_embed1(x_e) 353 | for i, blk in enumerate(self.block1): 354 | x_rgb = blk(x_rgb, H, W) 355 | for i, blk in enumerate(self.extra_block1): 356 | x_e = blk(x_e, H, W) 357 | x_rgb = self.norm1(x_rgb) 358 | x_e = self.extra_norm1(x_e) 359 | 360 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 361 | x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 362 | 363 | rgb.append(x_rgb) 364 | depth.append(x_e) 365 | 366 | # stage 2 367 | x_rgb, H, W = self.patch_embed2(x_rgb) 368 | x_e, _, _ = self.extra_patch_embed2(x_e) 369 | for i, blk in enumerate(self.block2): 370 | x_rgb = blk(x_rgb, H, W) 371 | for i, blk in enumerate(self.extra_block2): 372 | x_e = blk(x_e, H, W) 373 | x_rgb = self.norm2(x_rgb) 374 | x_e = self.extra_norm2(x_e) 375 | 376 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 377 | x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 378 | 379 | rgb.append(x_rgb) 380 | depth.append(x_e) 381 | 382 | # stage 3 383 | x_rgb, H, W = self.patch_embed3(x_rgb) 384 | x_e, _, _ = self.extra_patch_embed3(x_e) 385 | for i, blk in enumerate(self.block3): 386 | x_rgb = blk(x_rgb, H, W) 387 | for i, blk in enumerate(self.extra_block3): 388 | x_e = blk(x_e, H, W) 389 | x_rgb = self.norm3(x_rgb) 390 | x_e = self.extra_norm3(x_e) 391 | 392 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 393 | x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 394 | 395 | rgb.append(x_rgb) 396 | depth.append(x_e) 397 | 398 | 399 | # stage 4 400 | x_rgb, H, W = self.patch_embed4(x_rgb) 401 | x_e, _, _ = self.extra_patch_embed4(x_e) 402 | for i, blk in enumerate(self.block4): 403 | x_rgb = blk(x_rgb, H, W) 404 | for i, blk in enumerate(self.extra_block4): 405 | x_e = blk(x_e, H, W) 406 | x_rgb = self.norm4(x_rgb) 407 | x_e = self.extra_norm4(x_e) 408 | 409 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 410 | x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 411 | 412 | rgb.append(x_rgb) 413 | depth.append(x_e) 414 | 415 | return rgb,depth 416 | 417 | def forward(self, x_rgb, x_e): 418 | rgb,depth = 
self.forward_features(x_rgb, x_e) 419 | return rgb,depth 420 | 421 | 422 | def load_dualpath_model(model, model_file): 423 | # load raw state_dict 424 | t_start = time.time() 425 | if isinstance(model_file, str): 426 | raw_state_dict = torch.load(model_file, map_location=torch.device('cpu')) 427 | #raw_state_dict = torch.load(model_file) 428 | if 'model' in raw_state_dict.keys(): 429 | raw_state_dict = raw_state_dict['model'] 430 | else: 431 | raw_state_dict = model_file 432 | 433 | state_dict = {} 434 | for k, v in raw_state_dict.items(): 435 | if k.find('patch_embed') >= 0: 436 | state_dict[k] = v 437 | state_dict[k.replace('patch_embed', 'extra_patch_embed')] = v 438 | elif k.find('block') >= 0: 439 | state_dict[k] = v 440 | state_dict[k.replace('block', 'extra_block')] = v 441 | elif k.find('norm') >= 0: 442 | state_dict[k] = v 443 | state_dict[k.replace('norm', 'extra_norm')] = v 444 | 445 | t_ioend = time.time() 446 | 447 | model.load_state_dict(state_dict, strict=False) 448 | del state_dict 449 | 450 | t_end = time.time() 451 | 452 | 453 | class mit_b0(RGBXTransformer): 454 | def __init__(self, fuse_cfg=None, **kwargs): 455 | super(mit_b0, self).__init__( 456 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 457 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 458 | drop_rate=0.0, drop_path_rate=0.1) 459 | 460 | 461 | class mit_b1(RGBXTransformer): 462 | def __init__(self, fuse_cfg=None, **kwargs): 463 | super(mit_b1, self).__init__( 464 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 465 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 466 | drop_rate=0.0, drop_path_rate=0.1) 467 | 468 | 469 | class mit_b2(RGBXTransformer): 470 | def __init__(self, fuse_cfg=None, **kwargs): 471 | super(mit_b2, self).__init__( 472 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 473 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 474 | drop_rate=0.0, drop_path_rate=0.1) 475 | 476 | 477 | class mit_b3(RGBXTransformer): 478 | def __init__(self, fuse_cfg=None, **kwargs): 479 | super(mit_b3, self).__init__( 480 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 481 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 482 | drop_rate=0.0, drop_path_rate=0.1) 483 | 484 | 485 | class mit_b4(RGBXTransformer): 486 | def __init__(self, fuse_cfg=None, **kwargs): 487 | super(mit_b4, self).__init__( 488 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 489 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 490 | drop_rate=0.0, drop_path_rate=0.1) 491 | 492 | 493 | class mit_b5(RGBXTransformer): 494 | def __init__(self, fuse_cfg=None, **kwargs): 495 | super(mit_b5, self).__init__( 496 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 497 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 498 | drop_rate=0.0, drop_path_rate=0.1) 499 | -------------------------------------------------------------------------------- /model/mysegformer/encoders/one_segformer.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | import math 8 | import time 9 | #from engine.logger import get_logger 10 | 11 | #logger = get_logger() 12 | 13 | 14 | class DWConv(nn.Module): 15 | """ 16 | Depthwise convolution bloc: input: x with size(B N C); output size (B N C) 17 | """ 18 | def __init__(self, dim=768): 19 | super(DWConv, self).__init__() 20 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim) 21 | 22 | def forward(self, x, H, W): 23 | B, N, C = x.shape 24 | x = x.permute(0, 2, 1).reshape(B, C, H, W).contiguous() # B N C -> B C N -> B C H W 25 | x = self.dwconv(x) 26 | x = x.flatten(2).transpose(1, 2) # B C H W -> B N C 27 | 28 | return x 29 | 30 | 31 | class Mlp(nn.Module): 32 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 33 | super().__init__() 34 | """ 35 | MLP Block: 36 | """ 37 | out_features = out_features or in_features 38 | hidden_features = hidden_features or in_features 39 | self.fc1 = nn.Linear(in_features, hidden_features) 40 | self.dwconv = DWConv(hidden_features) 41 | self.act = act_layer() 42 | self.fc2 = nn.Linear(hidden_features, out_features) 43 | self.drop = nn.Dropout(drop) 44 | 45 | self.apply(self._init_weights) 46 | 47 | def _init_weights(self, m): 48 | if isinstance(m, nn.Linear): 49 | trunc_normal_(m.weight, std=.02) 50 | if isinstance(m, nn.Linear) and m.bias is not None: 51 | nn.init.constant_(m.bias, 0) 52 | elif isinstance(m, nn.LayerNorm): 53 | nn.init.constant_(m.bias, 0) 54 | nn.init.constant_(m.weight, 1.0) 55 | elif isinstance(m, nn.Conv2d): 56 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 57 | fan_out //= m.groups 58 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 59 | if m.bias is not None: 60 | m.bias.data.zero_() 61 | 62 | def forward(self, x, H, W): 63 | x = self.fc1(x) 64 | x = self.dwconv(x, H, W) 65 | x = self.act(x) 66 | x = self.drop(x) 67 | x = self.fc2(x) 68 | x = self.drop(x) 69 | return x 70 | 71 | 72 | class Attention(nn.Module): 73 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 74 | super().__init__() 75 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
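        # NOTE: this is the SegFormer-style spatial-reduction attention.  When
        # sr_ratio > 1, forward() first shrinks the token map with a
        # stride-`sr_ratio` convolution (self.sr) before computing keys and
        # values, so the attention matrix has shape (N, N / sr_ratio**2)
        # instead of (N, N).  For example, at stage 1 with sr_ratio=8 and the
        # 288x512 inputs used in this repo, the 72x128 = 9216 query tokens
        # attend to only 9x16 = 144 reduced key/value tokens.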
76 | 77 | self.dim = dim 78 | self.num_heads = num_heads 79 | head_dim = dim // num_heads 80 | self.scale = qk_scale or head_dim ** -0.5 81 | 82 | # Linear embedding 83 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 84 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 85 | self.attn_drop = nn.Dropout(attn_drop) 86 | self.proj = nn.Linear(dim, dim) 87 | self.proj_drop = nn.Dropout(proj_drop) 88 | 89 | self.sr_ratio = sr_ratio 90 | if sr_ratio > 1: 91 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 92 | self.norm = nn.LayerNorm(dim) 93 | 94 | self.apply(self._init_weights) 95 | 96 | def _init_weights(self, m): 97 | if isinstance(m, nn.Linear): 98 | trunc_normal_(m.weight, std=.02) 99 | if isinstance(m, nn.Linear) and m.bias is not None: 100 | nn.init.constant_(m.bias, 0) 101 | elif isinstance(m, nn.LayerNorm): 102 | nn.init.constant_(m.bias, 0) 103 | nn.init.constant_(m.weight, 1.0) 104 | elif isinstance(m, nn.Conv2d): 105 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | fan_out //= m.groups 107 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 108 | if m.bias is not None: 109 | m.bias.data.zero_() 110 | 111 | def forward(self, x, H, W): 112 | B, N, C = x.shape 113 | # B N C -> B N num_head C//num_head -> B C//num_head N num_heads 114 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 115 | 116 | if self.sr_ratio > 1: 117 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 118 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 119 | x_ = self.norm(x_) 120 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 121 | else: 122 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 123 | k, v = kv[0], kv[1] 124 | 125 | attn = (q @ k.transpose(-2, -1)) * self.scale 126 | attn = attn.softmax(dim=-1) 127 | attn = self.attn_drop(attn) 128 | 129 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 130 | x = self.proj(x) 131 | x = self.proj_drop(x) 132 | 133 | return x 134 | 135 | 136 | class Block(nn.Module): 137 | """ 138 | Transformer Block: Self-Attention -> Mix FFN -> OverLap Patch Merging 139 | """ 140 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 141 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 142 | super().__init__() 143 | self.norm1 = norm_layer(dim) 144 | self.attn = Attention( 145 | dim, 146 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 147 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 148 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 149 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 150 | self.norm2 = norm_layer(dim) 151 | mlp_hidden_dim = int(dim * mlp_ratio) 152 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 153 | 154 | self.apply(self._init_weights) 155 | 156 | def _init_weights(self, m): 157 | if isinstance(m, nn.Linear): 158 | trunc_normal_(m.weight, std=.02) 159 | if isinstance(m, nn.Linear) and m.bias is not None: 160 | nn.init.constant_(m.bias, 0) 161 | elif isinstance(m, nn.LayerNorm): 162 | nn.init.constant_(m.bias, 0) 163 | nn.init.constant_(m.weight, 1.0) 164 | elif isinstance(m, nn.Conv2d): 165 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 166 | fan_out //= m.groups 167 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 168 | if m.bias is not None: 169 | m.bias.data.zero_() 170 | 171 | def forward(self, x, H, W): 172 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 173 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 174 | 175 | return x 176 | 177 | 178 | class OverlapPatchEmbed(nn.Module): 179 | """ Image to Patch Embedding 180 | """ 181 | 182 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 183 | super().__init__() 184 | img_size = to_2tuple(img_size) 185 | patch_size = to_2tuple(patch_size) 186 | 187 | self.img_size = img_size 188 | self.patch_size = patch_size 189 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 190 | self.num_patches = self.H * self.W 191 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 192 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 193 | self.norm = nn.LayerNorm(embed_dim) 194 | 195 | self.apply(self._init_weights) 196 | 197 | def _init_weights(self, m): 198 | if isinstance(m, nn.Linear): 199 | trunc_normal_(m.weight, std=.02) 200 | if isinstance(m, nn.Linear) and m.bias is not None: 201 | nn.init.constant_(m.bias, 0) 202 | elif isinstance(m, nn.LayerNorm): 203 | nn.init.constant_(m.bias, 0) 204 | nn.init.constant_(m.weight, 1.0) 205 | elif isinstance(m, nn.Conv2d): 206 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 207 | fan_out //= m.groups 208 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 209 | if m.bias is not None: 210 | m.bias.data.zero_() 211 | 212 | def forward(self, x): 213 | # B C H W 214 | x = self.proj(x) 215 | _, _, H, W = x.shape 216 | x = x.flatten(2).transpose(1, 2) 217 | # B H*W/16 C 218 | x = self.norm(x) 219 | 220 | return x, H, W 221 | 222 | 223 | class RGBXTransformer(nn.Module): 224 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 225 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 226 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, norm_fuse=nn.BatchNorm2d, 227 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 228 | super().__init__() 229 | self.num_classes = num_classes 230 | self.depths = depths 231 | 232 | # patch_embed 233 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 234 | embed_dim=embed_dims[0]) 235 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 236 | embed_dim=embed_dims[1]) 237 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 238 | embed_dim=embed_dims[2]) 239 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, 
in_chans=embed_dims[2], 240 | embed_dim=embed_dims[3]) 241 | 242 | # transformer encoder 243 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 244 | cur = 0 245 | 246 | self.block1 = nn.ModuleList([Block( 247 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 248 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 249 | sr_ratio=sr_ratios[0]) 250 | for i in range(depths[0])]) 251 | self.norm1 = norm_layer(embed_dims[0]) 252 | 253 | self.block2 = nn.ModuleList([Block( 254 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 255 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur], norm_layer=norm_layer, 256 | sr_ratio=sr_ratios[1]) 257 | for i in range(depths[1])]) 258 | self.norm2 = norm_layer(embed_dims[1]) 259 | 260 | cur += depths[1] 261 | 262 | self.block3 = nn.ModuleList([Block( 263 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 264 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 265 | sr_ratio=sr_ratios[2]) 266 | for i in range(depths[2])]) 267 | self.norm3 = norm_layer(embed_dims[2]) 268 | 269 | cur += depths[2] 270 | 271 | self.block4 = nn.ModuleList([Block( 272 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 273 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 274 | sr_ratio=sr_ratios[3]) 275 | for i in range(depths[3])]) 276 | self.norm4 = norm_layer(embed_dims[3]) 277 | 278 | cur += depths[3] 279 | 280 | self.apply(self._init_weights) 281 | 282 | def _init_weights(self, m): 283 | if isinstance(m, nn.Linear): 284 | trunc_normal_(m.weight, std=.02) 285 | if isinstance(m, nn.Linear) and m.bias is not None: 286 | nn.init.constant_(m.bias, 0) 287 | elif isinstance(m, nn.LayerNorm): 288 | nn.init.constant_(m.bias, 0) 289 | nn.init.constant_(m.weight, 1.0) 290 | elif isinstance(m, nn.Conv2d): 291 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 292 | fan_out //= m.groups 293 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 294 | if m.bias is not None: 295 | m.bias.data.zero_() 296 | 297 | def init_weights(self, pretrained=None): 298 | if isinstance(pretrained, str): 299 | load_dualpath_model(self, pretrained) 300 | else: 301 | raise TypeError('pretrained must be a str or None') 302 | 303 | def forward_features(self, x_rgb, x_e): 304 | """ 305 | x_rgb: B x N x H x W 306 | """ 307 | B = x_rgb.shape[0] 308 | rgb = [] 309 | depth = [] 310 | 311 | 312 | # stage 1 313 | x_rgb, H, W = self.patch_embed1(x_rgb) 314 | # B H*W/16 C 315 | for i, blk in enumerate(self.block1): 316 | x_rgb = blk(x_rgb, H, W) 317 | x_rgb = self.norm1(x_rgb) 318 | 319 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 320 | 321 | rgb.append(x_rgb) 322 | 323 | 324 | 325 | # stage 2 326 | x_rgb, H, W = self.patch_embed2(x_rgb) 327 | for i, blk in enumerate(self.block2): 328 | x_rgb = blk(x_rgb, H, W) 329 | x_rgb = self.norm2(x_rgb) 330 | 331 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 332 | 333 | rgb.append(x_rgb) 334 | 335 | # stage 3 336 | x_rgb, H, W = self.patch_embed3(x_rgb) 337 | for i, blk in enumerate(self.block3): 338 | x_rgb = blk(x_rgb, H, W) 339 | x_rgb = self.norm3(x_rgb) 340 | 341 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 
3, 1, 2).contiguous() 342 | 343 | rgb.append(x_rgb) 344 | 345 | 346 | # stage 4 347 | x_rgb, H, W = self.patch_embed4(x_rgb) 348 | for i, blk in enumerate(self.block4): 349 | x_rgb = blk(x_rgb, H, W) 350 | x_rgb = self.norm4(x_rgb) 351 | 352 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 353 | 354 | rgb.append(x_rgb) 355 | depth = [] 356 | return rgb,depth 357 | 358 | def forward(self, x_rgb, x_e): 359 | rgb,depth = self.forward_features(x_rgb, x_e) 360 | return rgb,depth 361 | 362 | 363 | def load_dualpath_model(model, model_file): 364 | # load raw state_dict 365 | t_start = time.time() 366 | if isinstance(model_file, str): 367 | raw_state_dict = torch.load(model_file, map_location=torch.device('cpu')) 368 | #raw_state_dict = torch.load(model_file) 369 | if 'model' in raw_state_dict.keys(): 370 | raw_state_dict = raw_state_dict['model'] 371 | else: 372 | raw_state_dict = model_file 373 | 374 | state_dict = {} 375 | for k, v in raw_state_dict.items(): 376 | if k.find('patch_embed') >= 0: 377 | state_dict[k] = v 378 | state_dict[k.replace('patch_embed', 'extra_patch_embed')] = v 379 | elif k.find('block') >= 0: 380 | state_dict[k] = v 381 | state_dict[k.replace('block', 'extra_block')] = v 382 | elif k.find('norm') >= 0: 383 | state_dict[k] = v 384 | state_dict[k.replace('norm', 'extra_norm')] = v 385 | 386 | t_ioend = time.time() 387 | 388 | model.load_state_dict(state_dict, strict=False) 389 | del state_dict 390 | 391 | t_end = time.time() 392 | 393 | 394 | class mit_b0(RGBXTransformer): 395 | def __init__(self, fuse_cfg=None, **kwargs): 396 | super(mit_b0, self).__init__( 397 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 398 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 399 | drop_rate=0.0, drop_path_rate=0.1) 400 | 401 | 402 | class mit_b1(RGBXTransformer): 403 | def __init__(self, fuse_cfg=None, **kwargs): 404 | super(mit_b1, self).__init__( 405 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 406 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 407 | drop_rate=0.0, drop_path_rate=0.1) 408 | 409 | 410 | class mit_b2(RGBXTransformer): 411 | def __init__(self, fuse_cfg=None, **kwargs): 412 | super(mit_b2, self).__init__( 413 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 414 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 415 | drop_rate=0.0, drop_path_rate=0.1) 416 | 417 | 418 | class mit_b3(RGBXTransformer): 419 | def __init__(self, fuse_cfg=None, **kwargs): 420 | super(mit_b3, self).__init__( 421 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 422 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 423 | drop_rate=0.0, drop_path_rate=0.1) 424 | 425 | 426 | class mit_b4(RGBXTransformer): 427 | def __init__(self, fuse_cfg=None, **kwargs): 428 | super(mit_b4, self).__init__( 429 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 430 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 431 | drop_rate=0.0, drop_path_rate=0.1) 432 | 433 | 434 | class mit_b5(RGBXTransformer): 435 | def __init__(self, fuse_cfg=None, **kwargs): 436 | super(mit_b5, 
self).__init__( 437 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 438 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 439 | drop_rate=0.0, drop_path_rate=0.1) 440 | -------------------------------------------------------------------------------- /run_demo.py: -------------------------------------------------------------------------------- 1 | import os, argparse, time, datetime, stat, shutil,sys 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader 8 | import torchvision.utils as vutils 9 | from util.MY_dataset import MY_dataset 10 | from sklearn.metrics import confusion_matrix 11 | from util.util import compute_results, visualize 12 | from scipy.io import savemat 13 | from torch.utils.tensorboard import SummaryWriter 14 | from model import PotCrackSeg 15 | from util.lr_policy import WarmUpPolyLR 16 | from util.init_func import init_weight, group_weight 17 | from config import config 18 | 19 | ############################################################################################# 20 | parser = argparse.ArgumentParser(description='Test with pytorch') 21 | ############################################################################################# 22 | parser.add_argument('--model_name', '-m', type=str, default='PotCrackSeg') # DRCNet_RDe_b3V3, DRCNet_RDe_b4V3, DRCNet_RDe_b5V3 23 | parser.add_argument('--weight_name', '-w', type=str, default='PotCrackSeg-4B') # DRCNet_RDe_b3V3, DRCNet_RDe_b4V3, DRCNet_RDe_b5V3 24 | parser.add_argument('--backbone', '-bac', type=str, default='PotCrackSeg-4B') # mit_3, mit_4, mit_5 25 | parser.add_argument('--file_name', '-f', type=str, default='final.pth') 26 | parser.add_argument('--dataset_split', '-d', type=str, default='test') # normal_test, abnormal_test, urban_test,rural_test 27 | parser.add_argument('--gpu', '-g', type=int, default=1) 28 | ############################################################################################# 29 | parser.add_argument('--img_height', '-ih', type=int, default=288) 30 | parser.add_argument('--img_width', '-iw', type=int, default=512) 31 | parser.add_argument('--num_workers', '-j', type=int, default=16) 32 | parser.add_argument('--n_class', '-nc', type=int, default=3) 33 | parser.add_argument('--data_dir', '-dr', type=str, default='./NPO++/') 34 | parser.add_argument('--model_dir', '-wd', type=str, default='./weights_backup/') 35 | args = parser.parse_args() 36 | ############################################################################################# 37 | 38 | def get_palette(): 39 | unlabelled = [0,0,0] 40 | potholes = [153,0,0] 41 | cracks = [0,153,0] 42 | palette = np.array([unlabelled,potholes, cracks]) 43 | return palette 44 | 45 | 46 | if __name__ == '__main__': 47 | 48 | torch.cuda.set_device(args.gpu) 49 | print("\nthe pytorch version:", torch.__version__) 50 | print("the gpu count:", torch.cuda.device_count()) 51 | print("the current used gpu:", torch.cuda.current_device(), '\n') 52 | 53 | # prepare save direcotry 54 | if os.path.exists("./runs"): 55 | print("previous \"./runs\" folder exist, will delete this folder") 56 | shutil.rmtree("./runs") 57 | os.makedirs("./runs") 58 | os.chmod("./runs", stat.S_IRWXO) # allow the folder created by docker read, written, and execuated by local machine 59 | model_dir = os.path.join(args.model_dir, args.weight_name) 60 | 
if os.path.exists(model_dir) is False: 61 | sys.exit("the %s does not exist." %(model_dir)) 62 | model_file = os.path.join(model_dir, args.file_name) 63 | if os.path.exists(model_file) is True: 64 | print('use the final model file.') 65 | else: 66 | sys.exit('no model file found.') 67 | print('testing %s: %s on GPU #%d with pytorch' % (args.model_name, args.weight_name, args.gpu)) 68 | 69 | conf_total = np.zeros((args.n_class, args.n_class)) 70 | model = eval(args.model_name)(cfg = config ,n_class=args.n_class, encoder_name=args.backbone) 71 | if args.gpu >= 0: model.cuda(args.gpu) 72 | print('loading model file %s... ' % model_file) 73 | pretrained_weight = torch.load(model_file, map_location = lambda storage, loc: storage.cuda(args.gpu)) 74 | own_state = model.state_dict() 75 | for name, param in pretrained_weight.items(): 76 | own_state[name].copy_(param) 77 | print('done!') 78 | 79 | batch_size = 1 80 | test_dataset = MY_dataset(data_dir=args.data_dir, split=args.dataset_split, input_h=args.img_height, input_w=args.img_width) 81 | test_loader = DataLoader( 82 | dataset = test_dataset, 83 | batch_size = batch_size, 84 | shuffle = False, 85 | num_workers = args.num_workers, 86 | pin_memory = True, 87 | drop_last = False 88 | ) 89 | ave_time_cost = 0.0 90 | 91 | model.eval() 92 | with torch.no_grad(): 93 | for it, (images, labels, names) in enumerate(test_loader): 94 | images = Variable(images).cuda(args.gpu) 95 | labels = Variable(labels).cuda(args.gpu) 96 | # flop,params = profile(model,inputs=(images,)) 97 | # print(flop) 98 | # print(params) 99 | torch.cuda.synchronize() 100 | start_time = time.time() 101 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion, logits= model(images) 102 | torch.cuda.synchronize() 103 | end_time = time.time() 104 | if it>=5: # ignore the first 5 frames 105 | ave_time_cost += (end_time-start_time) 106 | # convert tensor to numpy 1d array 107 | label = labels.cpu().numpy().squeeze().flatten() 108 | prediction = logits.argmax(1).cpu().numpy().squeeze().flatten() # prediction and label are both 1-d array, size: minibatch*H*W 109 | # generate confusion matrix frame-by-frame 110 | conf = confusion_matrix(y_true=label, y_pred=prediction, labels=[0,1,2]) # conf is an n_class*n_class matrix, vertical axis: groundtruth, horizontal axis: prediction 111 | conf_total += conf 112 | # save demo images 113 | visualize(image_name=names, predictions=logits.argmax(1), weight_name=args.weight_name) 114 | print("%s, %s, frame %d/%d, %s, time cost: %.2f ms, demo result saved." 
115 | %(args.model_name, args.weight_name, it+1, len(test_loader), names, (end_time-start_time)*1000)) 116 | 117 | precision_per_class, recall_per_class, iou_per_class,F1_per_class = compute_results(conf_total) 118 | #precision, recall, IoU,F1 = compute_results(conf_total) 119 | conf_total_matfile = os.path.join("./runs", 'conf_'+args.weight_name+'.mat') 120 | savemat(conf_total_matfile, {'conf': conf_total}) # 'conf' is the variable name when loaded in Matlab 121 | 122 | print('\n###########################################################################') 123 | print('\n%s: %s test results (with batch size %d) on %s using %s:' %(args.model_name, args.weight_name, batch_size, datetime.date.today(), torch.cuda.get_device_name(args.gpu))) 124 | print('\n* the tested dataset name: %s' % args.dataset_split) 125 | print('* the tested image count: %d' % len(test_loader)) 126 | print('* the tested image size: %d*%d' %(args.img_height, args.img_width)) 127 | print('* the weight name: %s' %args.weight_name) 128 | print('* the file name: %s' %args.file_name) 129 | print("* iou per class: \n unlabeled: %.1f, pothole: %.1f, crack: %.1f" \ 130 | %(iou_per_class[0]*100, iou_per_class[1]*100, iou_per_class[2]*100)) 131 | print("* recall per class: \n unlabeled: %.1f, pothole: %.1f, crack: %.1f" \ 132 | %(recall_per_class[0]*100, recall_per_class[1]*100, recall_per_class[2]*100)) 133 | print("* pre per class: \n unlabeled: %.1f, pothole: %.1f, crack: %.1f" \ 134 | %(precision_per_class[0]*100, precision_per_class[1]*100, precision_per_class[2]*100)) 135 | print("* F1 per class: \n unlabeled: %.1f, pothole: %.1f, crack: %.1f" \ 136 | %(F1_per_class[0]*100, F1_per_class[1]*100, F1_per_class[2]*100)) 137 | 138 | print("\n* average values (np.mean(x)): \n iou: %.3f, recall: %.3f, pre: %.3f, F1: %.3f" \ 139 | %(iou_per_class[1:].mean()*100,recall_per_class[1:].mean()*100, precision_per_class[1:].mean()*100,F1_per_class[1:].mean()*100)) 140 | print("* average values (np.mean(np.nan_to_num(x))): \n iou: %.1f, recall: %.1f, pre: %.1f, F1: %.1f" \ 141 | %(np.mean(np.nan_to_num(iou_per_class[1:]))*100, np.mean(np.nan_to_num(recall_per_class[1:]))*100, np.mean(np.nan_to_num(precision_per_class[1:]))*100, np.mean(np.nan_to_num(F1_per_class[1:]))*100)) 142 | 143 | 144 | print('\n* the average time cost per frame (with batch size %d): %.2f ms, namely, the inference speed is %.2f fps' %(batch_size, ave_time_cost*1000/(len(test_loader)-5), 1.0/(ave_time_cost/(len(test_loader)-5)))) # ignore the first 10 frames 145 | print('\n###########################################################################') 146 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, argparse, time, datetime, stat, shutil,sys 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader 8 | import torchvision.utils as vutils 9 | from util.MY_dataset import MY_dataset 10 | from util.augmentation import RandomFlip, RandomCrop, RandomCropOut, RandomBrightness, RandomNoise 11 | from util.util import compute_results 12 | from sklearn.metrics import confusion_matrix 13 | from torch.utils.tensorboard import SummaryWriter 14 | from model import PotCrackSeg 15 | from util.lr_policy import WarmUpPolyLR 16 | from util.init_func import init_weight, group_weight 17 | from config import config 18 | 19 | 
############################################################################################# 20 | parser = argparse.ArgumentParser(description='Train with pytorch') 21 | ############################################################################################# 22 | parser.add_argument('--model_name', '-m', type=str, default='PotCrackSeg') 23 | parser.add_argument('--batch_size', '-b', type=int, default=2) 24 | parser.add_argument('--lr_start', '-ls', type=float, default=6e-5) 25 | parser.add_argument('--gpu', '-g', type=int, default=0) 26 | ############################################################################################# 27 | parser.add_argument('--lr_decay', '-ld', type=float, default=0.95) 28 | parser.add_argument('--epoch_max', '-em', type=int, default=500) # please stop training mannully 29 | parser.add_argument('--epoch_from', '-ef', type=int, default=0) 30 | parser.add_argument('--num_workers', '-j', type=int, default=8) 31 | parser.add_argument('--n_class', '-nc', type=int, default=3) 32 | parser.add_argument('--data_dir', '-dr', type=str, default='./NPO++/') 33 | parser.add_argument('--pre_weight', '-prw', type=str, default='/pretrained/mit_b2.pth') 34 | parser.add_argument('--backbone', '-bac', type=str, default='PotCrackSeg-2B') 35 | parser.add_argument('--model_dir', '-wd', type=str, default='./weights_backup/') 36 | # parser.add_argument('--weight_name', '-w', type=str, default='DRCNet_0DRC_RDe_b0') # RTFNet_152, RTFNet_50, please change the number of layers in the network file 37 | # parser.add_argument('--file_name', '-f', type=str, default='109.pth') 38 | args = parser.parse_args() 39 | ############################################################################################# 40 | 41 | augmentation_methods = [ 42 | RandomFlip(prob=0.5), 43 | RandomCrop(crop_rate=0.1, prob=1.0), 44 | # RandomCropOut(crop_rate=0.2, prob=1.0), 45 | # RandomBrightness(bright_range=0.15, prob=0.9), 46 | # RandomNoise(noise_range=5, prob=0.9), 47 | ] 48 | 49 | def fusion_loss(rgb_predict, rgb_comple,depth_predict, depth_comple,label): 50 | 51 | feature_map_B, feature_map_C, feature_map_W, feature_map_H = rgb_predict.size() 52 | label_B, label_W, label_H = label.size() 53 | 54 | if feature_map_W != label_W: 55 | label = torch.cuda.FloatTensor(label.unsqueeze(1).cpu().numpy()) 56 | label = F.interpolate(label,[feature_map_W,feature_map_H],mode="nearest") 57 | label = torch.cuda.LongTensor(label.squeeze(1).cpu().numpy()) 58 | 59 | loss_pr_rgb_seg = F.cross_entropy(rgb_predict, label) 60 | rgb_predict = rgb_predict.detach() 61 | rgb_predict=rgb_predict.argmax(1) 62 | rgb_predict.eq_(label) 63 | rgb_predict=rgb_predict.clone().detach_().requires_grad_(False) 64 | add_map_rgb = (1-rgb_predict)*label 65 | add_map_rgb=add_map_rgb.clone().detach_().requires_grad_(False) 66 | loss_add_rgb = F.cross_entropy(rgb_comple,add_map_rgb) 67 | 68 | loss_pr_depth_seg = F.cross_entropy(depth_predict, label) 69 | depth_predict = depth_predict.detach() 70 | depth_predict=depth_predict.argmax(1) 71 | depth_predict.eq_(label) 72 | depth_predict=depth_predict.clone().detach_().requires_grad_(False) 73 | add_map_depth = (1-depth_predict)*label 74 | add_map_depth=add_map_depth.clone().detach_().requires_grad_(False) 75 | loss_add_depth = F.cross_entropy(depth_comple,add_map_depth) 76 | 77 | loss = loss_pr_rgb_seg+loss_add_rgb+loss_pr_depth_seg+loss_add_depth 78 | return loss,add_map_rgb,add_map_depth 79 | 80 | 81 | 82 | def train(epo, model, train_loader, optimizer): 83 | model.train() 84 | loss_sum = 0 85 
| 86 | loss_seg_sum = 0 87 | for it, (images, labels, names) in enumerate(train_loader): 88 | images = Variable(images).cuda(args.gpu) 89 | labels = Variable(labels).cuda(args.gpu) 90 | 91 | start_t = time.time() # time.time() returns the current time 92 | optimizer.zero_grad() 93 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion, logits= model(images) 94 | 95 | loss1,add_map_rgb,add_map_depth = fusion_loss(rgb_predict, rgb_comple,depth_predict, depth_comple,labels) 96 | #loss2 = F.cross_entropy(rgb_fusion, labels) 97 | #loss3 = F.cross_entropy(depth_fusion, labels) 98 | loss4 = F.cross_entropy(logits, labels) 99 | 100 | #loss = 0.5*loss1+loss2+loss3+loss4 101 | loss = 0.5*loss1+loss4 102 | 103 | loss.backward() 104 | optimizer.step() 105 | 106 | loss_sum = loss_sum+loss 107 | loss_seg_sum = loss_seg_sum+loss4 108 | 109 | current_idx = (epo- 0) * config.niters_per_epoch + it 110 | lr = lr_policy.get_lr(current_idx) 111 | 112 | for i in range(len(optimizer.param_groups)): 113 | optimizer.param_groups[i]['lr'] = lr 114 | 115 | lr_this_epo=0 116 | for param_group in optimizer.param_groups: 117 | lr_this_epo = param_group['lr'] 118 | 119 | print('Train: %s, epo %s/%s, iter %s/%s, lr %.8f, %.2f img/sec, loss %.4f, loss_average %.4f, loss_seg_average %.4f, time %s' \ 120 | % (args.model_name, epo, args.epoch_max, it+1, len(train_loader), lr_this_epo, len(names)/(time.time()-start_t), float(loss), float(loss_sum/(it+1)), float(loss_seg_sum/(it+1)), 121 | datetime.datetime.now().replace(microsecond=0)-start_datetime)) 122 | if accIter['train'] % 1 == 0: 123 | writer.add_scalar('Train/loss', loss, accIter['train']) 124 | view_figure = True # note that I have not colorized the GT and predictions here 125 | if accIter['train'] % 1000 == 0: 126 | if view_figure: 127 | input_rgb_images = vutils.make_grid(images[:,:3], nrow=8, padding=10) # can only display 3-channel images, so images[:,:3] 128 | writer.add_image('Train/input_rgb_images', input_rgb_images, accIter['train']) 129 | scale = max(1, 255//args.n_class) # label (0,1,2..) 
is invisable, multiply a constant for visualization 130 | groundtruth_tensor = labels.unsqueeze(1) * scale # mini_batch*480*640 -> mini_batch*1*480*640 131 | groundtruth_tensor = torch.cat((groundtruth_tensor, groundtruth_tensor, groundtruth_tensor), 1) # change to 3-channel for visualization 132 | groudtruth_images = vutils.make_grid(groundtruth_tensor, nrow=8, padding=10) 133 | writer.add_image('Train/groudtruth_images', groudtruth_images, accIter['train']) 134 | predicted_tensor = logits.argmax(1).unsqueeze(1) * scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640 135 | predicted_tensor = torch.cat((predicted_tensor, predicted_tensor, predicted_tensor),1) # change to 3-channel for visualization, mini_batch*1*480*640 136 | predicted_images = vutils.make_grid(predicted_tensor, nrow=8, padding=10) 137 | writer.add_image('Train/predicted_images', predicted_images, accIter['train']) 138 | 139 | predicted_tensor_rgb_1 = rgb_predict.argmax(1).unsqueeze(1) * scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640 140 | predicted_tensor_rgb_1 = torch.cat((predicted_tensor_rgb_1, predicted_tensor_rgb_1, predicted_tensor_rgb_1),1) # change to 3-channel for visualization, mini_batch*1*480*640 141 | predicted_images_rgb_1 = vutils.make_grid(predicted_tensor_rgb_1, nrow=8, padding=10) 142 | writer.add_image('Train/predicted_images_rgb_1', predicted_images_rgb_1, accIter['train']) 143 | 144 | predicted_tensor_depth_1 = depth_predict.argmax(1).unsqueeze(1) * scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640 145 | predicted_tensor_depth_1 = torch.cat((predicted_tensor_depth_1, predicted_tensor_depth_1, predicted_tensor_depth_1),1) # change to 3-channel for visualization, mini_batch*1*480*640 146 | predicted_images_depth_1 = vutils.make_grid(predicted_tensor_depth_1, nrow=8, padding=10) 147 | writer.add_image('Train/predicted_images_depth_1', predicted_images_depth_1, accIter['train']) 148 | 149 | add_map_rgb_1 = add_map_rgb.unsqueeze(1) * scale 150 | predicted_tensor_need_rgb_1 = torch.cat((add_map_rgb_1, add_map_rgb_1, add_map_rgb_1), 1) # change to 3-channel for visualization 151 | predicted_images_need_rgb_1 = vutils.make_grid(predicted_tensor_need_rgb_1, nrow=8, padding=10) 152 | writer.add_image('Train/predicted_images_need_rgb_1', predicted_images_need_rgb_1, accIter['train']) 153 | 154 | add_map_depth_1 = add_map_depth.unsqueeze(1) * scale 155 | predicted_tensor_need_depth_1 = torch.cat((add_map_depth_1, add_map_depth_1, add_map_depth_1), 1) # change to 3-channel for visualization 156 | predicted_images_need_depth_1 = vutils.make_grid(predicted_tensor_need_depth_1, nrow=8, padding=10) 157 | writer.add_image('Train/predicted_images_need_depth_1', predicted_images_need_depth_1, accIter['train']) 158 | 159 | predicted_tensor_complex_rgb_1 = rgb_comple.argmax(1).unsqueeze(1) * scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640 160 | predicted_tensor_complex_rgb_1 = torch.cat((predicted_tensor_complex_rgb_1, predicted_tensor_complex_rgb_1, predicted_tensor_complex_rgb_1),1) # change to 3-channel for visualization, mini_batch*1*480*640 161 | predicted_images_complex_rgb_1 = vutils.make_grid(predicted_tensor_complex_rgb_1, nrow=8, padding=10) 162 | writer.add_image('Train/predicted_images_complex_rgb_1', predicted_images_complex_rgb_1, accIter['train']) 163 | 164 | predicted_tensor_complex_depth_1 = depth_comple.argmax(1).unsqueeze(1) * scale # 
mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640 165 | predicted_tensor_complex_depth_1 = torch.cat((predicted_tensor_complex_depth_1, predicted_tensor_complex_depth_1, predicted_tensor_complex_depth_1),1) # change to 3-channel for visualization, mini_batch*1*480*640 166 | predicted_images_complex_depth_1 = vutils.make_grid(predicted_tensor_complex_depth_1, nrow=8, padding=10) 167 | writer.add_image('Train/predicted_images_complex_depth_1', predicted_images_complex_depth_1, accIter['train']) 168 | accIter['train'] = accIter['train'] + 1 169 | 170 | def validation(epo, model, val_loader): 171 | model.eval() 172 | with torch.no_grad(): 173 | for it, (images, labels, names) in enumerate(val_loader): 174 | images = Variable(images).cuda(args.gpu) 175 | labels = Variable(labels).cuda(args.gpu) 176 | start_t = time.time() # time.time() returns the current time 177 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion, logits= model(images) 178 | loss = F.cross_entropy(logits, labels) # Note that the cross_entropy function has already include the softmax function 179 | print('Val: %s, epo %s/%s, iter %s/%s, %.2f img/sec, loss %.4f, time %s' \ 180 | % (args.model_name, epo, args.epoch_max, it + 1, len(val_loader), len(names)/(time.time()-start_t), float(loss), 181 | datetime.datetime.now().replace(microsecond=0)-start_datetime)) 182 | if accIter['val'] % 1 == 0: 183 | writer.add_scalar('Validation/loss', loss, accIter['val']) 184 | view_figure = False # note that I have not colorized the GT and predictions here 185 | if accIter['val'] % 1000 == 0: 186 | if view_figure: 187 | input_rgb_images = vutils.make_grid(images[:, :3], nrow=8, padding=10) # can only display 3-channel images, so images[:,:3] 188 | writer.add_image('Validation/input_rgb_images', input_rgb_images, accIter['val']) 189 | scale = max(1, 255 // args.n_class) # label (0,1,2..) 
is invisible, multiply a constant for visualization 190 | groundtruth_tensor = labels.unsqueeze(1) * scale # mini_batch*480*640 -> mini_batch*1*480*640 191 | groundtruth_tensor = torch.cat((groundtruth_tensor, groundtruth_tensor, groundtruth_tensor), 1) # change to 3-channel for visualization 192 | groudtruth_images = vutils.make_grid(groundtruth_tensor, nrow=8, padding=10) 193 | writer.add_image('Validation/groudtruth_images', groudtruth_images, accIter['val']) 194 | predicted_tensor = logits.argmax(1).unsqueeze(1)*scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640 195 | predicted_tensor = torch.cat((predicted_tensor, predicted_tensor, predicted_tensor), 1) # change to 3-channel for visualization, mini_batch*1*480*640 196 | predicted_images = vutils.make_grid(predicted_tensor, nrow=8, padding=10) 197 | writer.add_image('Validation/predicted_images', predicted_images, accIter['val']) 198 | accIter['val'] += 1 199 | 200 | def testing(epo, model, test_loader): 201 | model.eval() 202 | conf_total = np.zeros((args.n_class, args.n_class)) 203 | label_list = ["unlabeled", "pothole", "crack"] 204 | testing_results_file = os.path.join(weight_dir, 'testing_results_file.txt') 205 | with torch.no_grad(): 206 | for it, (images, labels, names) in enumerate(test_loader): 207 | images = Variable(images).cuda(args.gpu) 208 | labels = Variable(labels).cuda(args.gpu) 209 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion, logits= model(images) 210 | label = labels.cpu().numpy().squeeze().flatten() 211 | prediction = logits.argmax(1).cpu().numpy().squeeze().flatten() # prediction and label are both 1-d array, size: minibatch*640*480 212 | conf = confusion_matrix(y_true=label, y_pred=prediction, labels=[0,1,2]) # conf is args.n_class*args.n_class matrix, vertical axis: groundtruth, horizontal axis: prediction 213 | conf_total += conf 214 | print('Test: %s, epo %s/%s, iter %s/%s, time %s' % (args.model_name, epo, args.epoch_max, it+1, len(test_loader), 215 | datetime.datetime.now().replace(microsecond=0)-start_datetime)) 216 | precision, recall, IoU, F1 = compute_results(conf_total) 217 | writer.add_scalar('Test/average_recall', recall.mean(), epo) 218 | writer.add_scalar('Test/average_IoU', IoU.mean(), epo) 219 | writer.add_scalar('Test/average_precision',precision.mean(), epo) 220 | writer.add_scalar('Test/average_F1', F1.mean(), epo) 221 | for i in range(len(precision)): 222 | writer.add_scalar("Test(class)/precision_class_%s" % label_list[i], precision[i], epo) 223 | writer.add_scalar("Test(class)/recall_class_%s"% label_list[i], recall[i],epo) 224 | writer.add_scalar('Test(class)/Iou_%s'% label_list[i], IoU[i], epo) 225 | writer.add_scalar('Test(class)/F1_%s'% label_list[i], F1[i], epo) 226 | if epo==0: 227 | with open(testing_results_file, 'w') as f: 228 | f.write("# %s, initial lr: %s, batch size: %s, date: %s \n" %(args.model_name, args.lr_start, args.batch_size, datetime.date.today())) 229 | f.write("# epoch: unlabeled, pothole, crack, average(nan_to_num). 
(Pre %, Acc %, IoU %, F1 %)\n") 230 | with open(testing_results_file, 'a') as f: 231 | f.write(str(epo)+': ') 232 | for i in range(len(precision)): 233 | f.write('%0.4f, %0.4f, %0.4f, %0.4f ' % (100*precision[i], 100*recall[i], 100*IoU[i], 100*F1[i])) 234 | f.write('%0.4f, %0.4f, %0.4f, %0.4f\n' % (100*np.mean(np.nan_to_num(precision)), 100*np.mean(np.nan_to_num(recall)), 100*np.mean(np.nan_to_num(IoU)), 100*np.mean(np.nan_to_num(F1)) )) 235 | #f.write('%0.4f, %0.4f, %0.4f, %0.4f\n' % (100*np.mean(np.nan_to_num(recall)), 100*np.mean(np.nan_to_num(IoU), 100*np.mean(np.nan_to_num(precision)), )))) 236 | print('saving testing results.') 237 | with open(testing_results_file, "r") as file: 238 | writer.add_text('testing_results', file.read().replace('\n', ' \n'), epo) 239 | 240 | if __name__ == '__main__': 241 | 242 | torch.cuda.set_device(args.gpu) 243 | print("\nthe pytorch version:", torch.__version__) 244 | print("the gpu count:", torch.cuda.device_count()) 245 | print("the current used gpu:", torch.cuda.current_device(), '\n') 246 | 247 | config.pretrained_model = config.root_dir + args.pre_weight 248 | 249 | model = eval(args.model_name)(cfg = config ,n_class=args.n_class, encoder_name=args.backbone) 250 | 251 | base_lr = args.lr_start 252 | 253 | if args.gpu >= 0: model.cuda(args.gpu) 254 | #optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_start, momentum=0.9, weight_decay=0.0005) 255 | #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay, last_epoch=-1) 256 | params_list = [] 257 | params_list = group_weight(params_list, model, nn.BatchNorm2d, base_lr) 258 | optimizer = torch.optim.AdamW(params_list, lr=base_lr, betas=(0.9, 0.999), weight_decay=config.weight_decay) 259 | 260 | total_iteration = config.nepochs * config.niters_per_epoch 261 | lr_policy = WarmUpPolyLR(base_lr, config.lr_power, total_iteration, config.niters_per_epoch * config.warm_up_epoch) 262 | 263 | # preparing folders 264 | if os.path.exists("./PotCrackSeg"): 265 | shutil.rmtree("./PotCrackSeg") 266 | weight_dir = os.path.join("./PotCrackSeg", args.model_name) 267 | os.makedirs(weight_dir) 268 | os.chmod(weight_dir, stat.S_IRWXO) # allow the folder created by docker read, written, and execuated by local machine 269 | 270 | writer = SummaryWriter("./PotCrackSeg/tensorboard_log") 271 | os.chmod("./PotCrackSeg/tensorboard_log", stat.S_IRWXO) # allow the folder created by docker read, written, and execuated by local machine 272 | os.chmod("./PotCrackSeg", stat.S_IRWXO) 273 | 274 | print('training %s on GPU #%d with pytorch' % (args.model_name, args.gpu)) 275 | print('from epoch %d / %s' % (args.epoch_from, args.epoch_max)) 276 | print('weight will be saved in: %s' % weight_dir) 277 | 278 | train_dataset = MY_dataset(data_dir=args.data_dir, split='train', transform=augmentation_methods,input_h=288, input_w=512) 279 | val_dataset = MY_dataset(data_dir=args.data_dir, split='validation',input_h=288, input_w=512) 280 | test_dataset = MY_dataset(data_dir=args.data_dir, split='test',input_h=288, input_w=512) 281 | 282 | train_loader = DataLoader( 283 | dataset = train_dataset, 284 | batch_size = args.batch_size, 285 | shuffle = True, 286 | num_workers = args.num_workers, 287 | pin_memory = True, 288 | drop_last = False 289 | ) 290 | val_loader = DataLoader( 291 | dataset = val_dataset, 292 | batch_size = args.batch_size, 293 | shuffle = False, 294 | num_workers = args.num_workers, 295 | pin_memory = True, 296 | drop_last = False 297 | ) 298 | test_loader = DataLoader( 299 | dataset = 
test_dataset, 300 | batch_size = args.batch_size, 301 | shuffle = False, 302 | num_workers = args.num_workers, 303 | pin_memory = True, 304 | drop_last = False 305 | ) 306 | start_datetime = datetime.datetime.now().replace(microsecond=0) 307 | accIter = {'train': 0, 'val': 0} 308 | for epo in range(args.epoch_from, args.epoch_max): 309 | print('\ntrain %s, epo #%s begin...' % (args.model_name, epo)) 310 | #scheduler.step() # if using pytorch 0.4.1, please put this statement here 311 | train(epo, model, train_loader, optimizer) 312 | validation(epo, model, val_loader) 313 | 314 | checkpoint_model_file = os.path.join(weight_dir, str(epo) + '.pth') 315 | print('saving check point %s: ' % checkpoint_model_file) 316 | torch.save(model.state_dict(), checkpoint_model_file) 317 | 318 | testing(epo, model, test_loader) 319 | #scheduler.step() # if using pytorch 1.1 or above, please put this statement here 320 | -------------------------------------------------------------------------------- /util/MY_dataset.py: -------------------------------------------------------------------------------- 1 | import os, torch 2 | from torch.utils.data.dataset import Dataset 3 | import numpy as np 4 | import PIL 5 | 6 | class MY_dataset(Dataset): 7 | 8 | def __init__(self, data_dir, split, input_h=288, input_w=512 ,transform=[]): 9 | super(MY_dataset, self).__init__() 10 | 11 | with open(os.path.join(data_dir, split+'.txt'), 'r') as f: 12 | self.names = [name.strip() for name in f.readlines()] 13 | 14 | self.data_dir = data_dir 15 | self.split = split 16 | self.input_h = input_h 17 | self.input_w = input_w 18 | self.transform = transform 19 | self.n_data = len(self.names) 20 | 21 | def read_image(self, name, folder,head): 22 | file_path = os.path.join(self.data_dir, '%s/%s%s.png' % (folder, head,name)) 23 | image = np.asarray(PIL.Image.open(file_path)) 24 | return image 25 | 26 | def __getitem__(self, index): 27 | name = self.names[index] 28 | image = self.read_image(name, 'left','left') 29 | label = self.read_image(name, 'labels','label') 30 | depth = self.read_image(name, 'depth','depth') 31 | image = np.asarray(PIL.Image.fromarray(image).resize((self.input_w, self.input_h))) 32 | image = image.astype('float32') 33 | image = np.transpose(image, (2,0,1))/255.0 34 | depth = np.asarray(PIL.Image.fromarray(depth).resize((self.input_w, self.input_h))) 35 | 36 | depth = depth.astype('float32') 37 | 38 | M = depth.max() 39 | depth = depth/M 40 | 41 | label = np.asarray(PIL.Image.fromarray(label).resize((self.input_w, self.input_h), resample=PIL.Image.NEAREST)) 42 | label = label.astype('int64') 43 | 44 | 45 | return torch.cat((torch.tensor(image), torch.tensor(depth).unsqueeze(0)),dim=0), torch.tensor(label),name 46 | 47 | def __len__(self): 48 | return self.n_data 49 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /util/augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | class RandomFlip(): 5 | def __init__(self, prob=0.5): 6 | #super(RandomFlip, self).__init__() 7 | self.prob = prob 8 | 9 | def __call__(self, image, label): 10 | if np.random.rand() < self.prob: 11 | image = image[:,::-1] 12 | label = label[:,::-1] 13 | return image, label 14 | 15 | class RandomCrop(): 16 | def __init__(self, 
crop_rate=0.1, prob=1.0): 17 | #super(RandomCrop, self).__init__() 18 | self.crop_rate = crop_rate 19 | self.prob = prob 20 | 21 | def __call__(self, image, label): 22 | if np.random.rand() < self.prob: 23 | w, h, c = image.shape 24 | 25 | h1 = np.random.randint(0, h*self.crop_rate) 26 | w1 = np.random.randint(0, w*self.crop_rate) 27 | h2 = np.random.randint(h-h*self.crop_rate, h+1) 28 | w2 = np.random.randint(w-w*self.crop_rate, w+1) 29 | 30 | image = image[w1:w2, h1:h2] 31 | label = label[w1:w2, h1:h2] 32 | 33 | return image, label 34 | 35 | 36 | class RandomCropOut(): 37 | def __init__(self, crop_rate=0.2, prob=1.0): 38 | #super(RandomCropOut, self).__init__() 39 | self.crop_rate = crop_rate 40 | self.prob = prob 41 | 42 | def __call__(self, image, label): 43 | if np.random.rand() < self.prob: 44 | w, h, c = image.shape 45 | 46 | h1 = np.random.randint(0, h*self.crop_rate) 47 | w1 = np.random.randint(0, w*self.crop_rate) 48 | h2 = int(h1 + h*self.crop_rate) 49 | w2 = int(w1 + w*self.crop_rate) 50 | 51 | image[w1:w2, h1:h2] = 0 52 | label[w1:w2, h1:h2] = 0 53 | 54 | return image, label 55 | 56 | 57 | class RandomBrightness(): 58 | def __init__(self, bright_range=0.15, prob=0.9): 59 | #super(RandomBrightness, self).__init__() 60 | self.bright_range = bright_range 61 | self.prob = prob 62 | 63 | def __call__(self, image, label): 64 | if np.random.rand() < self.prob: 65 | bright_factor = np.random.uniform(1-self.bright_range, 1+self.bright_range) 66 | image = (image * bright_factor).astype(image.dtype) 67 | 68 | return image, label 69 | 70 | 71 | class RandomNoise(): 72 | def __init__(self, noise_range=5, prob=0.9): 73 | #super(RandomNoise, self).__init__() 74 | self.noise_range = noise_range 75 | self.prob = prob 76 | 77 | def __call__(self, image, label): 78 | if np.random.rand() < self.prob: 79 | w, h, c = image.shape 80 | 81 | noise = np.random.randint( 82 | -self.noise_range, 83 | self.noise_range, 84 | (w,h,c) 85 | ) 86 | 87 | image = (image + noise).clip(0,255).astype(image.dtype) 88 | 89 | return image, label 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /util/init_func.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | # @Time : 2018/9/28 下午12:13 4 | # @Author : yuchangqian 5 | # @Contact : changqian_yu@163.com 6 | # @File : init_func.py.py 7 | import torch 8 | import torch.nn as nn 9 | 10 | def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, 11 | **kwargs): 12 | for name, m in feature.named_modules(): 13 | if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | conv_init(m.weight, **kwargs) 15 | elif isinstance(m, norm_layer): 16 | m.eps = bn_eps 17 | m.momentum = bn_momentum 18 | nn.init.constant_(m.weight, 1) 19 | nn.init.constant_(m.bias, 0) 20 | 21 | 22 | def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, 23 | **kwargs): 24 | if isinstance(module_list, list): 25 | for feature in module_list: 26 | __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, 27 | **kwargs) 28 | else: 29 | __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, 30 | **kwargs) 31 | 32 | 33 | def group_weight(weight_group, module, norm_layer, lr): 34 | group_decay = [] 35 | group_no_decay = [] 36 | count = 0 37 | for m in module.modules(): 38 | if isinstance(m, nn.Linear): 39 | group_decay.append(m.weight) 40 | if m.bias is not None: 41 | group_no_decay.append(m.bias) 42 | elif isinstance(m, 
(nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)): 43 | group_decay.append(m.weight) 44 | if m.bias is not None: 45 | group_no_decay.append(m.bias) 46 | elif isinstance(m, norm_layer) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) \ 47 | or isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm) or isinstance(m, nn.LayerNorm): 48 | if m.weight is not None: 49 | group_no_decay.append(m.weight) 50 | if m.bias is not None: 51 | group_no_decay.append(m.bias) 52 | elif isinstance(m, nn.Parameter): 53 | group_decay.append(m) 54 | 55 | assert len(list(module.parameters())) >= len(group_decay) + len(group_no_decay) 56 | weight_group.append(dict(params=group_decay, lr=lr)) 57 | weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr)) 58 | return weight_group -------------------------------------------------------------------------------- /util/load_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import re 3 | from torch import distributed as dist 4 | 5 | 6 | def get_dist_info(): 7 | if dist.is_available(): 8 | initialized = dist.is_initialized() 9 | else: 10 | initialized = False 11 | if initialized: 12 | rank = dist.get_rank() 13 | world_size = dist.get_world_size() 14 | else: 15 | rank = 0 16 | world_size = 1 17 | return rank, world_size 18 | 19 | 20 | def load_state_dict(module, state_dict, strict=False, logger=None): 21 | unexpected_keys = [] 22 | all_missing_keys = [] 23 | err_msg = [] 24 | 25 | metadata = getattr(state_dict, '_metadata', None) 26 | state_dict = state_dict.copy() 27 | if metadata is not None: 28 | state_dict._metadata = metadata 29 | 30 | # use _load_from_state_dict to enable checkpoint version control 31 | def load(module, prefix=''): 32 | # recursively check parallel module in case that the model has a 33 | # complicated structure, e.g., nn.Module(nn.Module(DDP)) 34 | local_metadata = {} if metadata is None else metadata.get( 35 | prefix[:-1], {}) 36 | module._load_from_state_dict(state_dict, prefix, local_metadata, True, 37 | all_missing_keys, unexpected_keys, 38 | err_msg) 39 | for name, child in module._modules.items(): 40 | if child is not None: 41 | load(child, prefix + name + '.') 42 | 43 | load(module) 44 | load = None # break load->load reference cycle 45 | 46 | # ignore "num_batches_tracked" of BN layers 47 | missing_keys = [ 48 | key for key in all_missing_keys if 'num_batches_tracked' not in key 49 | ] 50 | 51 | if unexpected_keys: 52 | err_msg.append('unexpected key in source ' 53 | f'state_dict: {", ".join(unexpected_keys)}\n') 54 | if missing_keys: 55 | err_msg.append( 56 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n') 57 | 58 | rank, _ = get_dist_info() 59 | if len(err_msg) > 0 and rank == 0: 60 | err_msg.insert( 61 | 0, 'The model and loaded state dict do not match exactly\n') 62 | err_msg = '\n'.join(err_msg) 63 | if strict: 64 | raise RuntimeError(err_msg) 65 | else: 66 | print(err_msg) 67 | 68 | 69 | 70 | def load_pretrain(model, 71 | filename, 72 | strict=False, 73 | revise_keys=[(r'^module\.', '')]): 74 | checkpoint = torch.load(filename) 75 | # OrderedDict is a subclass of dict 76 | if not isinstance(checkpoint, dict): 77 | raise RuntimeError( 78 | f'No state_dict found in checkpoint file {filename}') 79 | # get state_dict from checkpoint 80 | if 'state_dict' in checkpoint: 81 | state_dict = checkpoint['state_dict'] 82 | elif 'model' in checkpoint: 83 | state_dict = checkpoint['model'] 84 | else: 85 | 
state_dict = checkpoint 86 | # strip prefix of state_dict 87 | for p, r in revise_keys: 88 | state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()} 89 | # load state_dict 90 | load_state_dict(model, state_dict, strict) 91 | return checkpoint -------------------------------------------------------------------------------- /util/lr_policy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | # @Time : 2018/8/1 上午1:50 4 | # @Author : yuchangqian 5 | # @Contact : changqian_yu@163.com 6 | # @File : lr_policy.py.py 7 | 8 | from abc import ABCMeta, abstractmethod 9 | 10 | 11 | class BaseLR(): 12 | __metaclass__ = ABCMeta 13 | 14 | @abstractmethod 15 | def get_lr(self, cur_iter): pass 16 | 17 | 18 | class PolyLR(BaseLR): 19 | def __init__(self, start_lr, lr_power, total_iters): 20 | self.start_lr = start_lr 21 | self.lr_power = lr_power 22 | self.total_iters = total_iters + 0.0 23 | 24 | def get_lr(self, cur_iter): 25 | return self.start_lr * ( 26 | (1 - float(cur_iter) / self.total_iters) ** self.lr_power) 27 | 28 | 29 | class WarmUpPolyLR(BaseLR): 30 | def __init__(self, start_lr, lr_power, total_iters, warmup_steps): 31 | print(start_lr) 32 | self.start_lr = start_lr 33 | self.lr_power = lr_power 34 | self.total_iters = total_iters + 0.0 35 | self.warmup_steps = warmup_steps 36 | 37 | def get_lr(self, cur_iter): 38 | if cur_iter < self.warmup_steps: 39 | return self.start_lr * (cur_iter / self.warmup_steps) 40 | else: 41 | return self.start_lr * ( 42 | (1 - float(cur_iter) / self.total_iters) ** self.lr_power) 43 | 44 | 45 | class MultiStageLR(BaseLR): 46 | def __init__(self, lr_stages): 47 | assert type(lr_stages) in [list, tuple] and len(lr_stages[0]) == 2, \ 48 | 'lr_stages must be list or tuple, with [iters, lr] format' 49 | self._lr_stagess = lr_stages 50 | 51 | def get_lr(self, epoch): 52 | for it_lr in self._lr_stagess: 53 | if epoch < it_lr[0]: 54 | return it_lr[1] 55 | 56 | 57 | class LinearIncreaseLR(BaseLR): 58 | def __init__(self, start_lr, end_lr, warm_iters): 59 | self._start_lr = start_lr 60 | self._end_lr = end_lr 61 | self._warm_iters = warm_iters 62 | self._delta_lr = (end_lr - start_lr) / warm_iters 63 | 64 | def get_lr(self, cur_epoch): 65 | return self._start_lr + cur_epoch * self._delta_lr -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | # By Yuxiang Sun, Dec. 
4, 2020 2 | # Email: sun.yuxiang@outlook.com 3 | 4 | import numpy as np 5 | from PIL import Image 6 | import torch 7 | 8 | # colour palette indexed by class id; index 0 is unlabelled 9 | def get_palette(): 10 | unlabelled = [0,0,0] 11 | sat = [153,0,0] 12 | fanban = [0,153,0] 13 | lidar = [0,0,153] 14 | penzui = [153,153,0] 15 | palette = np.array([unlabelled,sat, fanban, lidar, penzui]) 16 | return palette 17 | 18 | def visualize(image_name, predictions, weight_name): 19 | palette = get_palette() 20 | for (i, pred) in enumerate(predictions): 21 | pred = predictions[i].cpu().numpy() 22 | img = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8) 23 | for cid in range(0, len(palette)): # fix the mistake from the MFNet code on Dec.27, 2019 24 | img[pred == cid] = palette[cid] 25 | img = Image.fromarray(np.uint8(img)) 26 | img.save('runs/Pred_' + weight_name + '_' + image_name[i] + '.png') 27 | 28 | def compute_results(conf_total): 29 | n_class = conf_total.shape[0] 30 | consider_unlabeled = True # must consider the unlabeled, please set it to True 31 | if consider_unlabeled is True: 32 | start_index = 0 33 | else: 34 | start_index = 1 35 | precision_per_class = np.zeros(n_class) 36 | recall_per_class = np.zeros(n_class) 37 | iou_per_class = np.zeros(n_class) 38 | F1_per_class = np.zeros(n_class) 39 | for cid in range(start_index, n_class): # cid: class id 40 | if conf_total[start_index:, cid].sum() == 0: 41 | precision_per_class[cid] = np.nan 42 | else: 43 | precision_per_class[cid] = float(conf_total[cid, cid]) / float(conf_total[start_index:, cid].sum()) # precision = TP/TP+FP 44 | if conf_total[cid, start_index:].sum() == 0: 45 | recall_per_class[cid] = np.nan 46 | else: 47 | recall_per_class[cid] = float(conf_total[cid, cid]) / float(conf_total[cid, start_index:].sum()) # recall = TP/TP+FN 48 | if (conf_total[cid, start_index:].sum() + conf_total[start_index:, cid].sum() - conf_total[cid, cid]) == 0: 49 | iou_per_class[cid] = np.nan 50 | else: 51 | iou_per_class[cid] = float(conf_total[cid, cid]) / float((conf_total[cid, start_index:].sum() + conf_total[start_index:, cid].sum() - conf_total[cid, cid])) # IoU = TP/TP+FP+FN 52 | if np.isnan(recall_per_class[cid]) | np.isnan(precision_per_class[cid]) | (precision_per_class[cid]==0) | (recall_per_class[cid]==0): 53 | F1_per_class[cid] = np.nan 54 | else : 55 | F1_per_class[cid] = 2 / (1/precision_per_class[cid] +1/recall_per_class[cid]) 56 | 57 | return precision_per_class, recall_per_class, iou_per_class,F1_per_class 58 | 59 | def label2onehot(label,n_class,input_w,input_h): 60 | label = torch.tensor(label).unsqueeze(0) 61 | print(label.size()) 62 | onehot = torch.zeros(n_class, input_h, input_w).long() # first build an all-zero template 63 | print(onehot.size()) 64 | onehot.scatter_(0, label, 1).float() # scatter_ fills one channel per class to build the one-hot map; knowing this usage pattern is enough 65 | 66 | return onehot 67 | 68 | --------------------------------------------------------------------------------
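A note for readers of `fusion_loss()` in `train.py`: the complementary targets `add_map_rgb` / `add_map_depth` keep the ground-truth label only at pixels where the corresponding branch prediction was wrong, so the complementary output `rgb_comple` (and likewise `depth_comple`) is asked to recover exactly the pixels that `rgb_predict` missed. The sketch below is an illustration only, not part of the repository; the tensor values and the name `rgb_logits` are invented for the example.

```python
import torch
import torch.nn.functional as F

# 3 classes (0: unlabeled, 1: pothole, 2: crack), a single 1x4 "image".
label = torch.tensor([[[0, 1, 2, 1]]])            # B x H x W ground truth
rgb_logits = torch.tensor([[[[5., 0., 0., 0.]],   # class-0 scores
                            [[0., 5., 0., 0.]],   # class-1 scores
                            [[0., 0., 0., 5.]]]]) # class-2 scores -> argmax prediction: 0, 1, 0, 2

correct = rgb_logits.argmax(1).eq(label).long()   # 1 where the branch is already right
add_map_rgb = (1 - correct) * label               # keep the true label only where it failed
# note: mistaken pixels whose ground truth is class 0 stay 0 in the target
print(add_map_rgb)                                # tensor([[[0, 0, 2, 1]]])

# The complementary prediction is then trained against this sparse target, as in fusion_loss():
rgb_comple = torch.randn(1, 3, 1, 4, requires_grad=True)
loss_add_rgb = F.cross_entropy(rgb_comple, add_map_rgb)
```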