├── .devcontainer
│   ├── devcontainer.json
│   ├── docker-compose.yml
│   └── readme
├── Dockerfile
├── LICENSE
├── README.md
├── config.py
├── docs
│   ├── overall.jpg
│   └── results.jpg
├── model
│   ├── __init__.py
│   └── mysegformer
│       ├── PotCrackSeg.py
│       ├── __init__.py
│       ├── decoders
│       │   └── Decoder.py
│       └── encoders
│           ├── dual_segformer.py
│           └── one_segformer.py
├── run_demo.py
├── train.py
└── util
    ├── MY_dataset.py
    ├── __init__.py
    ├── augmentation.py
    ├── init_func.py
    ├── load_utils.py
    ├── lr_policy.py
    └── util.py
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at:
2 | // https://github.com/microsoft/vscode-dev-containers/tree/v0.163.1/containers/docker-from-docker-compose
3 | {
4 | "name": "potcrackseg",
5 | "dockerComposeFile": "docker-compose.yml",
6 | "service": "potcrackseg",
7 | "workspaceFolder": "/workspace"
8 |
9 | // // Use this environment variable if you need to bind mount your local source code into a new container.
10 | // "remoteEnv": {
11 | // "LOCAL_WORKSPACE_FOLDER": "${localWorkspaceFolder}"
12 | // },
13 |
14 | // // Set *default* container specific settings.json values on container create.
15 | // "settings": {
16 | // "terminal.integrated.shell.linux": "/bin/bash"
17 | // },
18 |
19 | // // Add the IDs of extensions you want installed when the container is created.
20 | // "extensions": [
21 | // "ms-azuretools.vscode-docker"
22 | // ],
23 |
24 | // // Use 'forwardPorts' to make a list of ports inside the container available locally.
25 | // // "forwardPorts": [],
26 |
27 | // // Use 'postCreateCommand' to run commands after the container is created.
28 | // // "postCreateCommand": "docker --version",
29 |
30 | // // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.
31 | // "remoteUser": "vscode"
32 | }
--------------------------------------------------------------------------------
/.devcontainer/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '2.3'
2 | services:
3 | potcrackseg:
4 | # Uncomment the next line to use a non-root user for all processes. You can also
5 | # simply use the "remoteUser" property in devcontainer.json if you just want VS Code
6 | # and its sub-processes (terminals, tasks, debugging) to execute as the user. On Linux,
7 | # you may need to update USER_UID and USER_GID in .devcontainer/Dockerfile to match your
8 | # user if not 1000. See https://aka.ms/vscode-remote/containers/non-root for details.
9 | # user: vscode
10 | runtime: nvidia
11 | image: docker_image_potcrackseg # The name of the docker image; must match the image built from the Dockerfile
12 | ports:
13 | - '11011:6006'
14 | volumes:
15 | # Update this to wherever you want VS Code to mount the folder of your project
16 | - ..:/workspace:cached # Do not change!
17 | # - /home/sun/somefolder/:/somefolder # folder_in_local_computer:folder_in_docker_container
18 |
19 | # Forwards the local Docker socket to the container.
20 | - /var/run/docker.sock:/var/run/docker-host.sock
21 | shm_size: 32g
22 | devices:
23 | - /dev/nvidia0
24 | # - /dev/nvidia1 # Uncomment this line only if your computer has a second GPU
25 |
26 | # Uncomment the next four lines if you will use ptrace-based debuggers (e.g., for C++, Go, or Rust).
27 | # cap_add:
28 | # - SYS_PTRACE
29 | # security_opt:
30 | # - seccomp:unconfined
31 |
32 | # Overrides default command so things don't shut down after the process ends.
33 | #entrypoint: /usr/local/share/docker-init.sh
34 | command: sleep infinity
35 |
--------------------------------------------------------------------------------
/.devcontainer/readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu20.04
2 | ENV DEBIAN_FRONTEND=noninteractive
3 | #RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 'A4B469963BF863CC'
4 |
5 | RUN apt-key del 7fa2af80
6 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub
7 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu2004/x86_64/7fa2af80.pub
8 |
9 |
10 | RUN apt-get update && apt-get install -y vim python3 python3-pip
11 |
12 | RUN pip3 install --upgrade pip
13 | RUN pip3 install "setuptools>=40.3.0"
14 |
15 | RUN pip3 install -U scipy scikit-learn
16 | RUN pip3 install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
17 | RUN pip3 install torchsummary
18 | RUN pip3 install tensorboard==2.11.0
19 | RUN pip3 install einops
20 | RUN pip3 install easydict timm
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Autonomous Systems Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PotCrackSeg
2 | The official PyTorch implementation of **Segmentation of Road Negative Obstacles Based on Dual Semantic-feature Complementary Fusion for Autonomous Driving**. ([TIV](https://ieeexplore.ieee.org/document/10468640/))
3 |
4 |
5 | We tested our code with Python 3.8, CUDA 11.3, cuDNN 8, and PyTorch 1.12.1. We provide a `Dockerfile` to build the docker image we used. You can modify the `Dockerfile` as needed.
6 |
7 |
8 |
9 |
10 | # Demo
11 |
12 | The accompanying video can be found at:
13 |
16 |
17 | # Introduction
18 |
19 | PotCrackSeg is an RGB-Depth fusion network with a dual semantic-feature complementary fusion module for the segmentation of potholes and cracks in traffic scenes.
20 |
21 | # Dataset
22 |
23 | The **NPO++** dataset is upgraded from the existing [**NPO**](https://pan.baidu.com/s/1-LuHyKXEuJ0oLMe1PHtq0Q?pwd=drno) dataset by re-labeling potholes and cracks. You can download the **NPO++** dataset from [here](https://pan.baidu.com/s/1608EIKo-be63XE3-7UYcIQ?pwd=uxks).
24 |
25 | # Pretrained weights
26 | The pretrained weights of PotCrackSeg can be downloaded from [here](https://pan.baidu.com/s/18xGs1Jp1xbSekBjJVEh9Pg?pwd=ynva).
27 |
28 | # Usage
29 | * Clone this repo
30 | ```
31 | $ git clone https://github.com/lab-sun/PotCrackSeg.git
32 | ```
33 | * Build docker image
34 | ```
35 | $ cd ~/PotCrackSeg
36 | $ docker build -t docker_image_potcrackseg .
37 | ```
38 | * Download the dataset
39 | ```
40 | $ (You should be in the PotCrackSeg folder)
41 | $ mkdir ./NPO++
42 | $ cd ./NPO++
43 | $ (download our preprocessed NPO++.zip in this folder)
44 | $ unzip -d . NPO++.zip
45 | ```
46 | * To reproduce our results, you need to download our pretrained weights.
47 | ```
48 | $ (You should be in the PotCrackSeg folder)
49 | $ mkdir ./weights_backup
50 | $ cd ./weights_backup
51 | $ (download our pretrained weights_backup.zip into this folder)
52 | $ unzip -d . weights_backup.zip
53 | $ docker run -it --shm-size 8G -p 1234:6006 --name docker_container_potcrackseg --gpus all -v ~/PotCrackSeg:/workspace docker_image_potcrackseg
54 | $ (now you should be inside the docker container)
55 | $ cd /workspace
56 | $ (To reproduce the results)
57 | $ python3 run_demo.py
58 | ```
59 | The results will be saved in the `./runs` folder. The default configuration is PotCrackSeg-4B. To reproduce the results of PotCrackSeg-2B, change *PotCrackSeg-4B* to *PotCrackSeg-2B* in `run_demo.py`.
60 |
61 | * To train PotCrackSeg.
62 | ```
63 | $ (You should be in the PotCrackSeg folder)
64 | $ docker run -it --shm-size 8G -p 1234:6006 --name docker_container_potcrackseg --gpus all -v ~/PotCrackSeg:/workspace docker_image_potcrackseg
65 | $ (now you should be inside the docker container)
66 | $ cd /workspace
67 | $ python3 train.py
68 | ```
69 |
70 | * To see the training process
71 | ```
72 | $ (fire up another terminal)
73 | $ docker exec -it docker_container_potcrackseg /bin/bash
74 | $ cd /workspace
75 | $ tensorboard --bind_all --logdir=./runs/tensorboard_log/
76 | $ (open http://localhost:1234 in your favorite browser to see the TensorBoard)
77 | ```
78 | The results will be saved in the `./runs` folder.
79 | Note: Please set the smoothing factor on the TensorBoard webpage to `0.999`; otherwise, the plots may be too noisy to reveal any patterns. If you get the error `docker: Error response from daemon: could not select device driver`, please first install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) on your computer!
80 |
81 | # Citation
82 | If you use PotCrackSeg in your academic work, please cite:
83 | ```
84 |
85 | ```
86 |
87 | # Acknowledgement
88 | Some of the codes are borrowed from [IGFNet](https://github.com/lab-sun/IGFNet).
89 |
--------------------------------------------------------------------------------
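The README above selects the network variant through a name string such as *PotCrackSeg-4B* or *PotCrackSeg-2B*. `run_demo.py` itself is not included in this dump, so the snippet below is only a minimal sketch of how such a string maps onto the model defined in `model/mysegformer/PotCrackSeg.py` further below; the real demo script additionally loads the pretrained weights from `./weights_backup` and writes results to `./runs`.

```python
# Minimal sketch (not run_demo.py itself): build the model from its variant name and
# run one forward pass. Assumes the repository root is the working directory so that
# `from config import config` inside PotCrackSeg.py resolves.
import torch
from config import config
from model import PotCrackSeg  # EncoderDecoder re-exported in model/__init__.py

net = PotCrackSeg(cfg=config, encoder_name='PotCrackSeg-2B', n_class=3)  # or 'PotCrackSeg-4B'
net.eval()

# The network expects RGB (3 channels) concatenated with a single-channel depth map.
x = torch.randn(1, 4, config.image_height, config.image_width)
with torch.no_grad():
    outputs = net(x)      # 7 tensors; the last one is the final voted segmentation map
print(outputs[-1].shape)  # torch.Size([1, 3, 288, 512])
```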
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import sys
4 | import time
5 | import numpy as np
6 | from easydict import EasyDict as edict
7 | import argparse
8 |
9 | C = edict()
10 | config = C
11 | cfg = C
12 |
13 | C.seed = 12345
14 |
15 | remoteip = os.popen('pwd').read()
16 | C.root_dir = os.path.abspath(os.path.join(os.getcwd(), './'))
17 | C.abs_dir = osp.realpath(".")
18 |
19 | # Dataset config
20 | """Dataset Path"""
21 | C.dataset_name = 'NPO++'
22 | C.dataset_path = osp.join(C.root_dir, 'NPO++')
23 | C.rgb_root_folder = osp.join(C.dataset_path, 'left')
24 | C.rgb_format = '.png'
25 | C.gt_root_folder = osp.join(C.dataset_path, 'labels')
26 | C.gt_format = '.png'
27 | C.x_root_folder = osp.join(C.dataset_path, 'depth')
28 | C.x_format = '.png'
29 | C.x_is_single_channel = False # True for raw depth, thermal and aolp/dolp(not aolp/dolp tri) input
30 | C.train_source = osp.join(C.dataset_path, "train.txt")
31 | C.eval_source = osp.join(C.dataset_path, "test.txt")
32 | C.is_test = False
33 | C.num_train_imgs = 2300
34 | C.num_eval_imgs = 1150
35 |
36 |
37 | """Image Config"""
38 | C.background = 255
39 | C.image_height = 288
40 | C.image_width = 512
41 | C.norm_mean = np.array([0.485, 0.456, 0.406])
42 | C.norm_std = np.array([0.229, 0.224, 0.225])
43 |
44 | """ Settings for network, this would be different for each kind of model"""
45 | C.backbone = 'mit_b2' # Remember to change the path below.
46 | # C.pretrained_model = C.root_dir + '/pretrained/mit_b2.pth' # Using for training
47 | C.pretrained_model = None # Using for testing
48 | C.decoder = 'MLPDecoder'
49 | C.decoder_embed_dim = 512
50 | C.optimizer = 'AdamW'
51 |
52 | """Train Config"""
53 | C.lr = 6e-5
54 | C.lr_power = 0.9
55 | C.momentum = 0.9
56 | C.weight_decay = 0.0005
57 | C.batch_size = 2
58 | C.nepochs = 500
59 | C.niters_per_epoch = C.num_train_imgs // C.batch_size + 1
60 | C.num_workers = 16
61 | C.train_scale_array = [0.5, 0.75, 1, 1.25, 1.5, 1.75]
62 | C.warm_up_epoch = 10
63 |
64 | C.fix_bias = True
65 | C.bn_eps = 1e-3
66 | C.bn_momentum = 0.1
67 |
68 | """Eval Config"""
69 | C.eval_iter = 25
70 | C.eval_stride_rate = 2 / 3
71 | C.eval_scale_array = [1] # [0.75, 1, 1.25] #
72 | C.eval_flip = False # True #
73 | C.eval_crop_size = [288, 512] # [height, width]
74 |
75 | """Store Config"""
76 | C.checkpoint_start_epoch = 50
77 | C.checkpoint_step = 1
78 |
79 | """Path Config"""
80 | def add_path(path):
81 | if path not in sys.path:
82 | sys.path.insert(0, path)
83 | add_path(osp.join(C.root_dir))
84 |
85 | C.log_dir = osp.abspath('log_' + C.dataset_name + '_' + C.backbone)
86 | C.tb_dir = osp.abspath(osp.join(C.log_dir, "tb"))
87 | C.log_dir_link = C.log_dir
88 | C.checkpoint_dir = osp.abspath(osp.join(C.log_dir, "checkpoint"))
89 |
90 | exp_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
91 | C.log_file = C.log_dir + '/log_' + exp_time + '.log'
92 | C.link_log_file = C.log_dir + '/log_last.log'
93 | C.val_log_file = C.log_dir + '/val_' + exp_time + '.log'
94 | C.link_val_log_file = C.log_dir + '/val_last.log'
95 |
96 | if __name__ == '__main__':
97 | print(config.nepochs)
98 | parser = argparse.ArgumentParser()
99 | parser.add_argument(
100 | '-tb', '--tensorboard', default=False, action='store_true')
101 | args = parser.parse_args()
102 |
103 | if args.tensorboard:
104 | os.system('tensorboard --bind_all --logdir=' + C.tb_dir)  # launch TensorBoard on the log directory
105 |
--------------------------------------------------------------------------------
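Two things in `config.py` are easy to misread: the dataset folder names are fixed by the `C.*_root_folder` entries, and the number of iterations per epoch is derived from `num_train_imgs` and `batch_size`. The sketch below spells both out; the warm-up/poly learning-rate formula at the end is only an assumption based on `C.lr`, `C.lr_power`, and `C.warm_up_epoch`, since `util/lr_policy.py` is not shown in this section.

```python
# Sketch of how the values in config.py fit together. The dataset layout comes straight
# from the path settings; the LR schedule is an assumed warm-up + poly rule (the actual
# util/lr_policy.py may differ).
from config import config

# Expected layout of ./NPO++ relative to the repository root:
#   NPO++/left/*.png     RGB images      (C.rgb_root_folder)
#   NPO++/depth/*.png    depth maps      (C.x_root_folder)
#   NPO++/labels/*.png   ground truth    (C.gt_root_folder)
#   NPO++/train.txt      training split  (C.train_source)
#   NPO++/test.txt       test split      (C.eval_source)

iters_per_epoch = config.num_train_imgs // config.batch_size + 1   # 2300 // 2 + 1 = 1151
total_iters = config.nepochs * iters_per_epoch                     # 500 * 1151
warmup_iters = config.warm_up_epoch * iters_per_epoch              # 10 * 1151

def assumed_poly_lr(cur_iter):
    """Warm-up followed by polynomial decay with exponent C.lr_power (assumption)."""
    if cur_iter < warmup_iters:
        return config.lr * cur_iter / warmup_iters
    frac = (cur_iter - warmup_iters) / (total_iters - warmup_iters)
    return config.lr * (1 - frac) ** config.lr_power

print(iters_per_epoch, assumed_poly_lr(warmup_iters), assumed_poly_lr(total_iters))
```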
/docs/overall.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lab-sun/PotCrackSeg/4d1aa7b53b4cf7f6ebd1cb660cd66a793d77150d/docs/overall.jpg
--------------------------------------------------------------------------------
/docs/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lab-sun/PotCrackSeg/4d1aa7b53b4cf7f6ebd1cb660cd66a793d77150d/docs/results.jpg
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .mysegformer.PotCrackSeg import EncoderDecoder as PotCrackSeg
2 |
3 |
--------------------------------------------------------------------------------
/model/mysegformer/PotCrackSeg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import sys
6 | sys.path.append('.')
7 |
8 | from util.init_func import init_weight
9 | from util.load_utils import load_pretrain
10 | from functools import partial
11 | from config import config
12 | from model.mysegformer.decoders.Decoder import DecoderHead
13 |
14 |
15 | class EncoderDecoder(nn.Module):
16 | def __init__(self,cfg=None, criterion=nn.CrossEntropyLoss(reduction='mean', ignore_index=255), encoder_name='mit_b2', n_class=3, norm_layer=nn.BatchNorm2d):
17 | super(EncoderDecoder, self).__init__()
18 | self.channels = [64, 128, 320, 512]
19 | self.norm_layer = norm_layer
20 |
21 | # import backbone and decoder
22 | if encoder_name == 'PotCrackSeg-5B':
23 | #logger.info('Using backbone: Segformer-B5')
24 | from model.mysegformer.encoders.dual_segformer import mit_b5 as backbone
25 | print("chose mit_b5")
26 | self.backbone = backbone(norm_fuse=norm_layer)
27 | elif encoder_name == 'PotCrackSeg-4B':
28 | #logger.info('Using backbone: Segformer-B4')
29 | from model.mysegformer.encoders.dual_segformer import mit_b4 as backbone
30 | print("chose mit_b4")
31 | self.backbone = backbone(norm_fuse=norm_layer)
32 | elif encoder_name == 'PotCrackSeg-3B':
33 | #logger.info('Using backbone: Segformer-B4')
34 | from model.mysegformer.encoders.dual_segformer import mit_b3 as backbone
35 | print("chose mit_b3")
36 | self.backbone = backbone(norm_fuse=norm_layer)
37 | elif encoder_name == 'PotCrackSeg-2B':
38 | #logger.info('Using backbone: Segformer-B2')
39 | from model.mysegformer.encoders.dual_segformer import mit_b2 as backbone
40 | print("chose mit_b2")
41 | self.backbone = backbone(norm_fuse=norm_layer)
42 | elif encoder_name == 'PotCrackSeg-1B':
43 | #logger.info('Using backbone: Segformer-B1')
44 | from model.mysegformer.encoders.dual_segformer import mit_b1 as backbone
45 | print("chose mit_b1")
46 | self.backbone = backbone(norm_fuse=norm_layer)
47 | elif encoder_name == 'PotCrackSeg-0B':
48 | #logger.info('Using backbone: Segformer-B0')
49 | self.channels = [32, 64, 160, 256]
50 | from model.mysegformer.encoders.dual_segformer import mit_b0 as backbone
51 | print("chose mit_b0")
52 | self.backbone = backbone(norm_fuse=norm_layer)
53 | else:
54 | #logger.info('Using backbone: Segformer-B2')
55 | from model.mysegformer.encoders.dual_segformer import mit_b2 as backbone
56 | self.backbone = backbone(norm_fuse=norm_layer)
57 |
58 | self.aux_head = None
59 |
60 | self.decode_head = DecoderHead(in_channels=self.channels, num_classes=n_class, norm_layer=norm_layer, embed_dim=512)
61 |
62 |
63 | self.voting = nn.Conv2d(in_channels=n_class*2,out_channels=n_class,kernel_size=3,stride=1,padding=1)
64 |
65 | self.init_weights(cfg, pretrained=cfg.pretrained_model)
66 |
67 | def init_weights(self, cfg, pretrained=None):
68 | if pretrained:
69 | self.backbone.init_weights(pretrained=pretrained)
70 | init_weight(self.decode_head, nn.init.kaiming_normal_,
71 | self.norm_layer, cfg.bn_eps, cfg.bn_momentum,
72 | mode='fan_in', nonlinearity='relu')
73 |
74 | def encode_decode(self, rgb, modal_x):
75 | """Encode images with backbone and decode into a semantic segmentation
76 | map of the same size as input."""
77 | orisize = rgb.shape
78 | rgb,depth = self.backbone(rgb, modal_x)
79 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion = self.decode_head.forward(rgb,depth)
80 |
81 | rgb_fusion = F.interpolate(rgb_fusion, size=orisize[2:], mode='bilinear', align_corners=False)
82 | depth_fusion = F.interpolate(depth_fusion, size=orisize[2:], mode='bilinear', align_corners=False)
83 |
84 | final = self.voting(torch.cat((rgb_fusion,depth_fusion),dim=1))
85 |
86 | return rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple,depth_fusion,final
87 |
88 | def forward(self, input):
89 |
90 | rgb = input[:,:3]
91 | modal_x = input[:,3:]
92 | modal_x = torch.cat((modal_x,modal_x,modal_x),dim=1)
93 |
94 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion,final = self.encode_decode(rgb, modal_x)
95 |
96 | return rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion,final
97 |
98 |
99 | def unit_test():
100 |
101 | num_minibatch = 2
102 | rgb = torch.randn(num_minibatch, 3, 288, 512).cuda(0)
103 | thermal = torch.randn(num_minibatch, 1, 288, 512).cuda(0)
104 | images = torch.cat((rgb,thermal),dim=1)
105 | rtf_net = EncoderDecoder(cfg=config, encoder_name='PotCrackSeg-2B', n_class=3).cuda(0)
106 | #input = torch.cat((rgb, thermal), dim=1)
107 | #out_seg_4,out_seg_3,out_seg_2,out_seg_1,out_x,out_d,final = rtf_net(images)
108 |
109 | rtf_net.eval()
110 |
111 | final = rtf_net(images)
112 |
113 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion, final = final
114 |
115 | # for i in out_seg_1:
116 | # print(i.shape)
117 |
118 |
119 | print(final.shape)
120 |
121 | if __name__ == '__main__':
122 | unit_test()
123 |
--------------------------------------------------------------------------------
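`encode_decode` returns seven tensors: `rgb_predict`, `rgb_comple`, `depth_predict`, and `depth_comple` stay at 1/4 of the input resolution, while `rgb_fusion`, `depth_fusion`, and the voted `final` map are upsampled to the input size. How these seven outputs are weighted against each other during training lives in `train.py`, which is not shown here; the sketch below only illustrates one plausible way to supervise all of them with the default `CrossEntropyLoss` carried by `EncoderDecoder`.

```python
# A hedged sketch of multi-output supervision (train.py may weight the terms differently).
import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=255)

def multi_output_loss(outputs, label):
    """outputs: the 7 tensors from EncoderDecoder.forward; label: B x H x W class indices."""
    rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion, final = outputs
    # rgb_fusion, depth_fusion and final are already at the input resolution;
    # the remaining four maps are at 1/4 resolution, so the label is downsampled for them.
    small = F.interpolate(label.unsqueeze(1).float(), size=rgb_predict.shape[2:],
                          mode='nearest').squeeze(1).long()
    loss = criterion(final, label) + criterion(rgb_fusion, label) + criterion(depth_fusion, label)
    for out in (rgb_predict, rgb_comple, depth_predict, depth_comple):
        loss = loss + criterion(out, small)
    return loss
```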
/model/mysegformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lab-sun/PotCrackSeg/4d1aa7b53b4cf7f6ebd1cb660cd66a793d77150d/model/mysegformer/__init__.py
--------------------------------------------------------------------------------
/model/mysegformer/decoders/Decoder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import torch
4 |
5 | from torch.nn.modules import module
6 | import torch.nn.functional as F
7 |
8 | class MLP(nn.Module):
9 | """
10 | Linear Embedding:
11 | """
12 | def __init__(self, input_dim=2048, embed_dim=768):
13 | super().__init__()
14 | self.proj = nn.Linear(input_dim, embed_dim)
15 |
16 | def forward(self, x):
17 | x = x.flatten(2).transpose(1, 2)
18 | x = self.proj(x)
19 | return x
20 |
21 |
22 | class DecoderHead(nn.Module):
23 | def __init__(self,
24 | in_channels=[64, 128, 320, 512],
25 | num_classes=40,
26 | dropout_ratio=0.1,
27 | norm_layer=nn.BatchNorm2d,
28 | embed_dim=768,
29 | align_corners=False):
30 |
31 | super(DecoderHead, self).__init__()
32 | self.num_classes = num_classes
33 | self.dropout_ratio = dropout_ratio
34 | self.align_corners = align_corners
35 |
36 | self.in_channels = in_channels
37 |
38 | if dropout_ratio > 0:
39 | self.dropout = nn.Dropout2d(dropout_ratio)
40 | else:
41 | self.dropout = None
42 |
43 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
44 |
45 | embedding_dim = embed_dim
46 |
47 |
48 | # RGB decoder
49 |
50 | self.linear_c4_rgb = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
51 | self.linear_c3_rgb = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
52 | self.linear_c2_rgb = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
53 | self.linear_c1_rgb = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
54 |
55 | self.linear_fuse_rgb = nn.Sequential(
56 | nn.Conv2d(in_channels=embedding_dim*4, out_channels=embedding_dim, kernel_size=1),
57 | norm_layer(embedding_dim),
58 | nn.ReLU(inplace=True)
59 | )
60 |
61 | self.linear_pred_rgb = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
62 |
63 |
64 | #depth decoder
65 |
66 | self.linear_c4_depth = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
67 | self.linear_c3_depth = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
68 | self.linear_c2_depth = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
69 | self.linear_c1_depth = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
70 |
71 | self.linear_fuse_depth = nn.Sequential(
72 | nn.Conv2d(in_channels=embedding_dim*4, out_channels=embedding_dim, kernel_size=1),
73 | norm_layer(embedding_dim),
74 | nn.ReLU(inplace=True)
75 | )
76 |
77 | self.linear_pred_depth = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
78 |
79 |
80 | self.linear_fusion = nn.Sequential(
81 | nn.Conv2d(in_channels=embedding_dim*4, out_channels=embedding_dim, kernel_size=1),
82 | norm_layer(embedding_dim),
83 | nn.ReLU(inplace=True)
84 | )
85 |
86 | #self.linear_pred_fusion = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
87 |
88 |
89 |
90 | self.DRC = DRC(embedding_dim,self.num_classes)
91 |
92 | def forward(self, rgb,depth):
93 | # len=4, 1/4,1/8,1/16,1/32
94 |
95 | c1_r, c2_r, c3_r, c4_r = rgb
96 | verbose = False
97 |
98 | if verbose: print("c1_r size",c1_r.size())
99 | if verbose: print("c2_r size",c2_r.size())
100 | if verbose: print("c3_r size",c3_r.size())
101 | if verbose: print("c4_r size",c4_r.size())
102 |
103 | c1_d, c2_d, c3_d, c4_d = depth
104 |
105 | if verbose: print("c1_d size",c1_d.size())
106 | if verbose: print("c2_d size",c2_d.size())
107 | if verbose: print("c3_d size",c3_d.size())
108 | if verbose: print("c4_d size",c4_d.size())
109 |
110 | ############## MLP decoder on C1-C4 ###########
111 | n, _, h, w = c4_r.shape
112 |
113 | _c4_r = self.linear_c4_rgb(c4_r).permute(0,2,1).reshape(n, -1, c4_r.shape[2], c4_r.shape[3])
114 | if verbose: print("_c4_r size after linear_c4_rgb",_c4_r.size())
115 | _c4_d = self.linear_c4_depth(c4_d).permute(0,2,1).reshape(n, -1, c4_d.shape[2], c4_d.shape[3])
116 | if verbose: print("_c4_d size after linear_c4_depth",_c4_d.size())
117 |
118 | _c4_r = F.interpolate(_c4_r, size=c1_r.size()[2:],mode='bilinear',align_corners=self.align_corners)
119 | if verbose: print("_c4_r size after interpolate",_c4_r.size())
120 | _c4_d = F.interpolate(_c4_d, size=c1_d.size()[2:],mode='bilinear',align_corners=self.align_corners)
121 | if verbose: print("_c4_d size after interpolate",_c4_d.size())
122 |
123 |
124 |
125 | _c3_r = self.linear_c3_rgb(c3_r).permute(0,2,1).reshape(n, -1, c3_r.shape[2], c3_r.shape[3])
126 | if verbose: print("_c3_r size after linear_c3_rgb",_c3_r.size())
127 | _c3_d = self.linear_c3_depth(c3_d).permute(0,2,1).reshape(n, -1, c3_d.shape[2], c3_d.shape[3])
128 | if verbose: print("_c3_d size after linear_c3_depth",_c3_d.size())
129 |
130 | _c3_r = F.interpolate(_c3_r, size=c1_r.size()[2:],mode='bilinear',align_corners=self.align_corners)
131 | if verbose: print("_c3_r size after interpolate",_c3_r.size())
132 | _c3_d = F.interpolate(_c3_d, size=c1_d.size()[2:],mode='bilinear',align_corners=self.align_corners)
133 | if verbose: print("_c3_d size after interpolate",_c3_d.size())
134 |
135 |
136 |
137 | _c2_r = self.linear_c2_rgb(c2_r).permute(0,2,1).reshape(n, -1, c2_r.shape[2], c2_r.shape[3])
138 | if verbose: print("_c2_r size after linear_c2_rgb",_c2_r.size())
139 | _c2_d = self.linear_c2_depth(c2_d).permute(0,2,1).reshape(n, -1, c2_d.shape[2], c2_d.shape[3])
140 | if verbose: print("_c2_d size after linear_c2_depth",_c2_d.size())
141 |
142 | _c2_r = F.interpolate(_c2_r, size=c1_r.size()[2:],mode='bilinear',align_corners=self.align_corners)
143 | if verbose: print("_c2_r size after interpolate",_c2_r.size())
144 | _c2_d = F.interpolate(_c2_d, size=c1_d.size()[2:],mode='bilinear',align_corners=self.align_corners)
145 | if verbose: print("_c2_d size after interpolate",_c2_d.size())
146 |
147 |
148 |
149 | _c1_d = self.linear_c1_depth(c1_d).permute(0,2,1).reshape(n, -1, c1_d.shape[2], c1_d.shape[3])
150 | if verbose: print("_c1_d size after linear_c1_depth",_c1_d.size())
151 | _c1_r = self.linear_c1_rgb(c1_r).permute(0,2,1).reshape(n, -1, c1_r.shape[2], c1_r.shape[3])
152 | if verbose: print("_c1_r size after linear_c1_rgb",_c1_r.size())
153 |
154 |
155 | _c_d = self.linear_fuse_depth(torch.cat([_c4_d, _c3_d, _c2_d, _c1_d], dim=1))
156 | x_d = self.dropout(_c_d)
157 | x_d = self.linear_pred_depth(x_d)
158 |
159 | _c_r = self.linear_fuse_rgb(torch.cat([_c4_r, _c3_r, _c2_r, _c1_r], dim=1))
160 | x_r = self.dropout(_c_r)
161 | x_r = self.linear_pred_rgb(x_r)
162 |
163 |
164 | fusion = self.linear_fusion(torch.cat([_c4_r+_c4_d, _c3_r+_c3_d, _c2_r+_c2_d, _c1_r+_c1_d], dim=1))
165 |
166 |
167 |
168 | rgb_comple,depth_comple,rgb_fusion,depth_fusion = self.DRC(fusion,x_r,x_d)
169 |
170 |
171 | return x_r, rgb_comple, rgb_fusion, x_d, depth_comple, depth_fusion
172 |
173 | class DRC(nn.Module):
174 | def __init__(self,in_channel, n_class):
175 | super(DRC, self).__init__()
176 |
177 | self.rgb_segconv = nn.Conv2d(in_channels=in_channel,out_channels=n_class,kernel_size=1,stride=1,padding=0)
178 | self.depth_segconv = nn.Conv2d(in_channels=in_channel,out_channels=n_class,kernel_size=1,stride=1,padding=0)
179 |
180 |
181 | self.rgb_comple_conv1 = nn.Conv2d(in_channels=n_class,out_channels=n_class,kernel_size=3,stride=1,padding=1)
182 | self.rgb_comple_bn1 = nn.BatchNorm2d(n_class)
183 | self.rgb_comple_relu1 = nn.ReLU()
184 | self.rgb_comple_fusion_conv1 = nn.Conv2d(in_channels=n_class*2,out_channels=n_class,kernel_size=1,stride=1,padding=0)
185 |
186 |
187 | self.depth_comple_conv1 = nn.Conv2d(in_channels=n_class,out_channels=n_class,kernel_size=3,stride=1,padding=1)
188 | self.depth_comple_bn1 = nn.BatchNorm2d(n_class)
189 | self.depth_comple_relu1 = nn.ReLU()
190 | self.depth_comple_fusion_conv1 = nn.Conv2d(in_channels=n_class*2,out_channels=n_class,kernel_size=1,stride=1,padding=0)
191 |
192 |
193 | def forward(self,fusion,x_r,x_d):
194 |
195 | rgb_missing = self.rgb_segconv(fusion)
196 | rgb_comple1 = self.rgb_comple_conv1(rgb_missing)
197 | rgb_comple1 = self.rgb_comple_bn1(rgb_comple1)
198 | rgb_comple1 = self.rgb_comple_relu1(rgb_comple1)
199 | rgb_comple = rgb_missing+rgb_comple1
200 | rgb_fusion = self.rgb_comple_fusion_conv1(torch.cat((x_r,rgb_comple),dim=1))
201 |
202 | depth_missing = self.depth_segconv(fusion)
203 | depth_comple1 = self.depth_comple_conv1(depth_missing)
204 | depth_comple1 = self.depth_comple_bn1(depth_comple1)
205 | depth_comple1 = self.depth_comple_relu1(depth_comple1)
206 | depth_comple = depth_missing+depth_comple1
207 | depth_fusion = self.depth_comple_fusion_conv1(torch.cat((x_d,depth_comple),dim=1))
208 |
209 | return rgb_comple,depth_comple, rgb_fusion, depth_fusion
--------------------------------------------------------------------------------
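For a 288x512 input (the size set in `config.py`), the four encoder stages hand the decoder feature maps at 1/4, 1/8, 1/16, and 1/32 resolution with 64, 128, 320, and 512 channels. The sketch below feeds `DecoderHead` mock tensors of exactly those shapes to make the output shapes concrete; it is a standalone check, not part of the repository.

```python
# Shape walk-through for DecoderHead with mit_b2-sized features (288x512 input assumed).
import torch
from model.mysegformer.decoders.Decoder import DecoderHead

B, H, W = 1, 288, 512
channels, strides = [64, 128, 320, 512], [4, 8, 16, 32]
rgb   = [torch.randn(B, c, H // s, W // s) for c, s in zip(channels, strides)]
depth = [torch.randn(B, c, H // s, W // s) for c, s in zip(channels, strides)]

head = DecoderHead(in_channels=channels, num_classes=3, embed_dim=512)
head.eval()
with torch.no_grad():
    x_r, rgb_comple, rgb_fusion, x_d, depth_comple, depth_fusion = head(rgb, depth)

# Every output is a 3-class map at 1/4 resolution; PotCrackSeg.py upsamples the two
# fusion maps to the full input size before the final voting convolution.
print(rgb_fusion.shape)  # torch.Size([1, 3, 72, 128])
```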
/model/mysegformer/encoders/dual_segformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from functools import partial
5 |
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 | import math
8 | import time
9 | #from engine.logger import get_logger
10 |
11 | #logger = get_logger()
12 |
13 |
14 | class DWConv(nn.Module):
15 | """
16 | Depthwise convolution bloc: input: x with size(B N C); output size (B N C)
17 | """
18 | def __init__(self, dim=768):
19 | super(DWConv, self).__init__()
20 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)
21 |
22 | def forward(self, x, H, W):
23 | B, N, C = x.shape
24 | x = x.permute(0, 2, 1).reshape(B, C, H, W).contiguous() # B N C -> B C N -> B C H W
25 | x = self.dwconv(x)
26 | x = x.flatten(2).transpose(1, 2) # B C H W -> B N C
27 |
28 | return x
29 |
30 |
31 | class Mlp(nn.Module):
32 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
33 | super().__init__()
34 | """
35 | MLP Block:
36 | """
37 | out_features = out_features or in_features
38 | hidden_features = hidden_features or in_features
39 | self.fc1 = nn.Linear(in_features, hidden_features)
40 | self.dwconv = DWConv(hidden_features)
41 | self.act = act_layer()
42 | self.fc2 = nn.Linear(hidden_features, out_features)
43 | self.drop = nn.Dropout(drop)
44 |
45 | self.apply(self._init_weights)
46 |
47 | def _init_weights(self, m):
48 | if isinstance(m, nn.Linear):
49 | trunc_normal_(m.weight, std=.02)
50 | if isinstance(m, nn.Linear) and m.bias is not None:
51 | nn.init.constant_(m.bias, 0)
52 | elif isinstance(m, nn.LayerNorm):
53 | nn.init.constant_(m.bias, 0)
54 | nn.init.constant_(m.weight, 1.0)
55 | elif isinstance(m, nn.Conv2d):
56 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
57 | fan_out //= m.groups
58 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
59 | if m.bias is not None:
60 | m.bias.data.zero_()
61 |
62 | def forward(self, x, H, W):
63 | x = self.fc1(x)
64 | x = self.dwconv(x, H, W)
65 | x = self.act(x)
66 | x = self.drop(x)
67 | x = self.fc2(x)
68 | x = self.drop(x)
69 | return x
70 |
71 |
72 | class Attention(nn.Module):
73 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
74 | super().__init__()
75 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
76 |
77 | self.dim = dim
78 | self.num_heads = num_heads
79 | head_dim = dim // num_heads
80 | self.scale = qk_scale or head_dim ** -0.5
81 |
82 | # Linear embedding
83 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
84 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
85 | self.attn_drop = nn.Dropout(attn_drop)
86 | self.proj = nn.Linear(dim, dim)
87 | self.proj_drop = nn.Dropout(proj_drop)
88 |
89 | self.sr_ratio = sr_ratio
90 | if sr_ratio > 1:
91 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
92 | self.norm = nn.LayerNorm(dim)
93 |
94 | self.apply(self._init_weights)
95 |
96 | def _init_weights(self, m):
97 | if isinstance(m, nn.Linear):
98 | trunc_normal_(m.weight, std=.02)
99 | if isinstance(m, nn.Linear) and m.bias is not None:
100 | nn.init.constant_(m.bias, 0)
101 | elif isinstance(m, nn.LayerNorm):
102 | nn.init.constant_(m.bias, 0)
103 | nn.init.constant_(m.weight, 1.0)
104 | elif isinstance(m, nn.Conv2d):
105 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | fan_out //= m.groups
107 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
108 | if m.bias is not None:
109 | m.bias.data.zero_()
110 |
111 | def forward(self, x, H, W):
112 | B, N, C = x.shape
113 | # B N C -> B N num_head C//num_head -> B C//num_head N num_heads
114 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
115 |
116 | if self.sr_ratio > 1:
117 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
118 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
119 | x_ = self.norm(x_)
120 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
121 | else:
122 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
123 | k, v = kv[0], kv[1]
124 |
125 | attn = (q @ k.transpose(-2, -1)) * self.scale
126 | attn = attn.softmax(dim=-1)
127 | attn = self.attn_drop(attn)
128 |
129 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
130 | x = self.proj(x)
131 | x = self.proj_drop(x)
132 |
133 | return x
134 |
135 |
136 | class Block(nn.Module):
137 | """
138 | Transformer Block: Self-Attention -> Mix FFN -> OverLap Patch Merging
139 | """
140 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
141 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
142 | super().__init__()
143 | self.norm1 = norm_layer(dim)
144 | self.attn = Attention(
145 | dim,
146 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
147 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
148 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
149 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
150 | self.norm2 = norm_layer(dim)
151 | mlp_hidden_dim = int(dim * mlp_ratio)
152 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
153 |
154 | self.apply(self._init_weights)
155 |
156 | def _init_weights(self, m):
157 | if isinstance(m, nn.Linear):
158 | trunc_normal_(m.weight, std=.02)
159 | if isinstance(m, nn.Linear) and m.bias is not None:
160 | nn.init.constant_(m.bias, 0)
161 | elif isinstance(m, nn.LayerNorm):
162 | nn.init.constant_(m.bias, 0)
163 | nn.init.constant_(m.weight, 1.0)
164 | elif isinstance(m, nn.Conv2d):
165 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
166 | fan_out //= m.groups
167 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
168 | if m.bias is not None:
169 | m.bias.data.zero_()
170 |
171 | def forward(self, x, H, W):
172 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
173 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
174 |
175 | return x
176 |
177 |
178 | class OverlapPatchEmbed(nn.Module):
179 | """ Image to Patch Embedding
180 | """
181 |
182 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
183 | super().__init__()
184 | img_size = to_2tuple(img_size)
185 | patch_size = to_2tuple(patch_size)
186 |
187 | self.img_size = img_size
188 | self.patch_size = patch_size
189 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
190 | self.num_patches = self.H * self.W
191 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
192 | padding=(patch_size[0] // 2, patch_size[1] // 2))
193 | self.norm = nn.LayerNorm(embed_dim)
194 |
195 | self.apply(self._init_weights)
196 |
197 | def _init_weights(self, m):
198 | if isinstance(m, nn.Linear):
199 | trunc_normal_(m.weight, std=.02)
200 | if isinstance(m, nn.Linear) and m.bias is not None:
201 | nn.init.constant_(m.bias, 0)
202 | elif isinstance(m, nn.LayerNorm):
203 | nn.init.constant_(m.bias, 0)
204 | nn.init.constant_(m.weight, 1.0)
205 | elif isinstance(m, nn.Conv2d):
206 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
207 | fan_out //= m.groups
208 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
209 | if m.bias is not None:
210 | m.bias.data.zero_()
211 |
212 | def forward(self, x):
213 | # B C H W
214 | x = self.proj(x)
215 | _, _, H, W = x.shape
216 | x = x.flatten(2).transpose(1, 2)
217 | # B H*W/16 C
218 | x = self.norm(x)
219 |
220 | return x, H, W
221 |
222 |
223 | class RGBXTransformer(nn.Module):
224 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
225 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
226 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, norm_fuse=nn.BatchNorm2d,
227 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
228 | super().__init__()
229 | self.num_classes = num_classes
230 | self.depths = depths
231 |
232 | # patch_embed
233 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
234 | embed_dim=embed_dims[0])
235 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
236 | embed_dim=embed_dims[1])
237 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
238 | embed_dim=embed_dims[2])
239 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
240 | embed_dim=embed_dims[3])
241 |
242 | self.extra_patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
243 | embed_dim=embed_dims[0])
244 | self.extra_patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
245 | embed_dim=embed_dims[1])
246 | self.extra_patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
247 | embed_dim=embed_dims[2])
248 | self.extra_patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
249 | embed_dim=embed_dims[3])
250 |
251 | # transformer encoder
252 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
253 | cur = 0
254 |
255 | self.block1 = nn.ModuleList([Block(
256 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
257 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
258 | sr_ratio=sr_ratios[0])
259 | for i in range(depths[0])])
260 | self.norm1 = norm_layer(embed_dims[0])
261 |
262 | self.extra_block1 = nn.ModuleList([Block(
263 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
264 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
265 | sr_ratio=sr_ratios[0])
266 | for i in range(depths[0])])
267 | self.extra_norm1 = norm_layer(embed_dims[0])
268 | cur += depths[0]
269 |
270 | self.block2 = nn.ModuleList([Block(
271 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
272 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
273 | sr_ratio=sr_ratios[1])
274 | for i in range(depths[1])])
275 | self.norm2 = norm_layer(embed_dims[1])
276 |
277 | self.extra_block2 = nn.ModuleList([Block(
278 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
279 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
280 | sr_ratio=sr_ratios[1])
281 | for i in range(depths[1])])
282 | self.extra_norm2 = norm_layer(embed_dims[1])
283 |
284 | cur += depths[1]
285 |
286 | self.block3 = nn.ModuleList([Block(
287 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
288 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
289 | sr_ratio=sr_ratios[2])
290 | for i in range(depths[2])])
291 | self.norm3 = norm_layer(embed_dims[2])
292 |
293 | self.extra_block3 = nn.ModuleList([Block(
294 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
295 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
296 | sr_ratio=sr_ratios[2])
297 | for i in range(depths[2])])
298 | self.extra_norm3 = norm_layer(embed_dims[2])
299 |
300 | cur += depths[2]
301 |
302 | self.block4 = nn.ModuleList([Block(
303 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
304 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
305 | sr_ratio=sr_ratios[3])
306 | for i in range(depths[3])])
307 | self.norm4 = norm_layer(embed_dims[3])
308 |
309 | self.extra_block4 = nn.ModuleList([Block(
310 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
311 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
312 | sr_ratio=sr_ratios[3])
313 | for i in range(depths[3])])
314 | self.extra_norm4 = norm_layer(embed_dims[3])
315 |
316 | cur += depths[3]
317 |
318 | self.apply(self._init_weights)
319 |
320 | def _init_weights(self, m):
321 | if isinstance(m, nn.Linear):
322 | trunc_normal_(m.weight, std=.02)
323 | if isinstance(m, nn.Linear) and m.bias is not None:
324 | nn.init.constant_(m.bias, 0)
325 | elif isinstance(m, nn.LayerNorm):
326 | nn.init.constant_(m.bias, 0)
327 | nn.init.constant_(m.weight, 1.0)
328 | elif isinstance(m, nn.Conv2d):
329 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
330 | fan_out //= m.groups
331 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
332 | if m.bias is not None:
333 | m.bias.data.zero_()
334 |
335 | def init_weights(self, pretrained=None):
336 | if isinstance(pretrained, str):
337 | load_dualpath_model(self, pretrained)
338 | else:
339 | raise TypeError('pretrained must be a str or None')
340 |
341 | def forward_features(self, x_rgb, x_e):
342 | """
343 | x_rgb: B x N x H x W
344 | """
345 | B = x_rgb.shape[0]
346 | rgb = []
347 | depth = []
348 |
349 | # stage 1
350 | x_rgb, H, W = self.patch_embed1(x_rgb)
351 | # B H*W/16 C
352 | x_e, _, _ = self.extra_patch_embed1(x_e)
353 | for i, blk in enumerate(self.block1):
354 | x_rgb = blk(x_rgb, H, W)
355 | for i, blk in enumerate(self.extra_block1):
356 | x_e = blk(x_e, H, W)
357 | x_rgb = self.norm1(x_rgb)
358 | x_e = self.extra_norm1(x_e)
359 |
360 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
361 | x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
362 |
363 | rgb.append(x_rgb)
364 | depth.append(x_e)
365 |
366 | # stage 2
367 | x_rgb, H, W = self.patch_embed2(x_rgb)
368 | x_e, _, _ = self.extra_patch_embed2(x_e)
369 | for i, blk in enumerate(self.block2):
370 | x_rgb = blk(x_rgb, H, W)
371 | for i, blk in enumerate(self.extra_block2):
372 | x_e = blk(x_e, H, W)
373 | x_rgb = self.norm2(x_rgb)
374 | x_e = self.extra_norm2(x_e)
375 |
376 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
377 | x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
378 |
379 | rgb.append(x_rgb)
380 | depth.append(x_e)
381 |
382 | # stage 3
383 | x_rgb, H, W = self.patch_embed3(x_rgb)
384 | x_e, _, _ = self.extra_patch_embed3(x_e)
385 | for i, blk in enumerate(self.block3):
386 | x_rgb = blk(x_rgb, H, W)
387 | for i, blk in enumerate(self.extra_block3):
388 | x_e = blk(x_e, H, W)
389 | x_rgb = self.norm3(x_rgb)
390 | x_e = self.extra_norm3(x_e)
391 |
392 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
393 | x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
394 |
395 | rgb.append(x_rgb)
396 | depth.append(x_e)
397 |
398 |
399 | # stage 4
400 | x_rgb, H, W = self.patch_embed4(x_rgb)
401 | x_e, _, _ = self.extra_patch_embed4(x_e)
402 | for i, blk in enumerate(self.block4):
403 | x_rgb = blk(x_rgb, H, W)
404 | for i, blk in enumerate(self.extra_block4):
405 | x_e = blk(x_e, H, W)
406 | x_rgb = self.norm4(x_rgb)
407 | x_e = self.extra_norm4(x_e)
408 |
409 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
410 | x_e = x_e.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
411 |
412 | rgb.append(x_rgb)
413 | depth.append(x_e)
414 |
415 | return rgb,depth
416 |
417 | def forward(self, x_rgb, x_e):
418 | rgb,depth = self.forward_features(x_rgb, x_e)
419 | return rgb,depth
420 |
421 |
422 | def load_dualpath_model(model, model_file):
423 | # load raw state_dict
424 | t_start = time.time()
425 | if isinstance(model_file, str):
426 | raw_state_dict = torch.load(model_file, map_location=torch.device('cpu'))
427 | #raw_state_dict = torch.load(model_file)
428 | if 'model' in raw_state_dict.keys():
429 | raw_state_dict = raw_state_dict['model']
430 | else:
431 | raw_state_dict = model_file
432 |
433 | state_dict = {}
434 | for k, v in raw_state_dict.items():
435 | if k.find('patch_embed') >= 0:
436 | state_dict[k] = v
437 | state_dict[k.replace('patch_embed', 'extra_patch_embed')] = v
438 | elif k.find('block') >= 0:
439 | state_dict[k] = v
440 | state_dict[k.replace('block', 'extra_block')] = v
441 | elif k.find('norm') >= 0:
442 | state_dict[k] = v
443 | state_dict[k.replace('norm', 'extra_norm')] = v
444 |
445 | t_ioend = time.time()
446 |
447 | model.load_state_dict(state_dict, strict=False)
448 | del state_dict
449 |
450 | t_end = time.time()
451 |
452 |
453 | class mit_b0(RGBXTransformer):
454 | def __init__(self, fuse_cfg=None, **kwargs):
455 | super(mit_b0, self).__init__(
456 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
457 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
458 | drop_rate=0.0, drop_path_rate=0.1)
459 |
460 |
461 | class mit_b1(RGBXTransformer):
462 | def __init__(self, fuse_cfg=None, **kwargs):
463 | super(mit_b1, self).__init__(
464 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
465 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
466 | drop_rate=0.0, drop_path_rate=0.1)
467 |
468 |
469 | class mit_b2(RGBXTransformer):
470 | def __init__(self, fuse_cfg=None, **kwargs):
471 | super(mit_b2, self).__init__(
472 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
473 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
474 | drop_rate=0.0, drop_path_rate=0.1)
475 |
476 |
477 | class mit_b3(RGBXTransformer):
478 | def __init__(self, fuse_cfg=None, **kwargs):
479 | super(mit_b3, self).__init__(
480 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
481 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
482 | drop_rate=0.0, drop_path_rate=0.1)
483 |
484 |
485 | class mit_b4(RGBXTransformer):
486 | def __init__(self, fuse_cfg=None, **kwargs):
487 | super(mit_b4, self).__init__(
488 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
489 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
490 | drop_rate=0.0, drop_path_rate=0.1)
491 |
492 |
493 | class mit_b5(RGBXTransformer):
494 | def __init__(self, fuse_cfg=None, **kwargs):
495 | super(mit_b5, self).__init__(
496 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
497 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
498 | drop_rate=0.0, drop_path_rate=0.1)
499 |
--------------------------------------------------------------------------------
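The four `OverlapPatchEmbed` layers use strides 4, 2, 2, and 2, so each branch of the encoder emits a pyramid at 1/4, 1/8, 1/16, and 1/32 of the input resolution. Below is a small standalone check of those shapes for the 288x512 images used in this repo (the depth map is replicated to three channels beforehand, exactly as `PotCrackSeg.py` does):

```python
# Standalone shape check of the dual-branch mit_b2 encoder on a 288x512 RGB/depth pair.
import torch
from model.mysegformer.encoders.dual_segformer import mit_b2

enc = mit_b2()
enc.eval()
rgb    = torch.randn(1, 3, 288, 512)
depth3 = torch.randn(1, 3, 288, 512)   # single-channel depth replicated to 3 channels

with torch.no_grad():
    rgb_feats, depth_feats = enc(rgb, depth3)

print([tuple(f.shape) for f in rgb_feats])
# [(1, 64, 72, 128), (1, 128, 36, 64), (1, 320, 18, 32), (1, 512, 9, 16)]
```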
/model/mysegformer/encoders/one_segformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from functools import partial
5 |
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 | import math
8 | import time
9 | #from engine.logger import get_logger
10 |
11 | #logger = get_logger()
12 |
13 |
14 | class DWConv(nn.Module):
15 | """
16 | Depthwise convolution bloc: input: x with size(B N C); output size (B N C)
17 | """
18 | def __init__(self, dim=768):
19 | super(DWConv, self).__init__()
20 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)
21 |
22 | def forward(self, x, H, W):
23 | B, N, C = x.shape
24 | x = x.permute(0, 2, 1).reshape(B, C, H, W).contiguous() # B N C -> B C N -> B C H W
25 | x = self.dwconv(x)
26 | x = x.flatten(2).transpose(1, 2) # B C H W -> B N C
27 |
28 | return x
29 |
30 |
31 | class Mlp(nn.Module):
32 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
33 | super().__init__()
34 | """
35 | MLP Block:
36 | """
37 | out_features = out_features or in_features
38 | hidden_features = hidden_features or in_features
39 | self.fc1 = nn.Linear(in_features, hidden_features)
40 | self.dwconv = DWConv(hidden_features)
41 | self.act = act_layer()
42 | self.fc2 = nn.Linear(hidden_features, out_features)
43 | self.drop = nn.Dropout(drop)
44 |
45 | self.apply(self._init_weights)
46 |
47 | def _init_weights(self, m):
48 | if isinstance(m, nn.Linear):
49 | trunc_normal_(m.weight, std=.02)
50 | if isinstance(m, nn.Linear) and m.bias is not None:
51 | nn.init.constant_(m.bias, 0)
52 | elif isinstance(m, nn.LayerNorm):
53 | nn.init.constant_(m.bias, 0)
54 | nn.init.constant_(m.weight, 1.0)
55 | elif isinstance(m, nn.Conv2d):
56 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
57 | fan_out //= m.groups
58 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
59 | if m.bias is not None:
60 | m.bias.data.zero_()
61 |
62 | def forward(self, x, H, W):
63 | x = self.fc1(x)
64 | x = self.dwconv(x, H, W)
65 | x = self.act(x)
66 | x = self.drop(x)
67 | x = self.fc2(x)
68 | x = self.drop(x)
69 | return x
70 |
71 |
72 | class Attention(nn.Module):
73 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
74 | super().__init__()
75 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
76 |
77 | self.dim = dim
78 | self.num_heads = num_heads
79 | head_dim = dim // num_heads
80 | self.scale = qk_scale or head_dim ** -0.5
81 |
82 | # Linear embedding
83 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
84 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
85 | self.attn_drop = nn.Dropout(attn_drop)
86 | self.proj = nn.Linear(dim, dim)
87 | self.proj_drop = nn.Dropout(proj_drop)
88 |
89 | self.sr_ratio = sr_ratio
90 | if sr_ratio > 1:
91 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
92 | self.norm = nn.LayerNorm(dim)
93 |
94 | self.apply(self._init_weights)
95 |
96 | def _init_weights(self, m):
97 | if isinstance(m, nn.Linear):
98 | trunc_normal_(m.weight, std=.02)
99 | if isinstance(m, nn.Linear) and m.bias is not None:
100 | nn.init.constant_(m.bias, 0)
101 | elif isinstance(m, nn.LayerNorm):
102 | nn.init.constant_(m.bias, 0)
103 | nn.init.constant_(m.weight, 1.0)
104 | elif isinstance(m, nn.Conv2d):
105 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | fan_out //= m.groups
107 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
108 | if m.bias is not None:
109 | m.bias.data.zero_()
110 |
111 | def forward(self, x, H, W):
112 | B, N, C = x.shape
113 | # B N C -> B N num_head C//num_head -> B C//num_head N num_heads
114 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
115 |
116 | if self.sr_ratio > 1:
117 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
118 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
119 | x_ = self.norm(x_)
120 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
121 | else:
122 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
123 | k, v = kv[0], kv[1]
124 |
125 | attn = (q @ k.transpose(-2, -1)) * self.scale
126 | attn = attn.softmax(dim=-1)
127 | attn = self.attn_drop(attn)
128 |
129 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
130 | x = self.proj(x)
131 | x = self.proj_drop(x)
132 |
133 | return x
134 |
135 |
136 | class Block(nn.Module):
137 | """
138 | Transformer Block: Self-Attention -> Mix FFN -> OverLap Patch Merging
139 | """
140 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
141 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
142 | super().__init__()
143 | self.norm1 = norm_layer(dim)
144 | self.attn = Attention(
145 | dim,
146 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
147 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
148 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
149 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
150 | self.norm2 = norm_layer(dim)
151 | mlp_hidden_dim = int(dim * mlp_ratio)
152 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
153 |
154 | self.apply(self._init_weights)
155 |
156 | def _init_weights(self, m):
157 | if isinstance(m, nn.Linear):
158 | trunc_normal_(m.weight, std=.02)
159 | if isinstance(m, nn.Linear) and m.bias is not None:
160 | nn.init.constant_(m.bias, 0)
161 | elif isinstance(m, nn.LayerNorm):
162 | nn.init.constant_(m.bias, 0)
163 | nn.init.constant_(m.weight, 1.0)
164 | elif isinstance(m, nn.Conv2d):
165 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
166 | fan_out //= m.groups
167 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
168 | if m.bias is not None:
169 | m.bias.data.zero_()
170 |
171 | def forward(self, x, H, W):
172 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
173 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
174 |
175 | return x
176 |
177 |
178 | class OverlapPatchEmbed(nn.Module):
179 | """ Image to Patch Embedding
180 | """
181 |
182 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
183 | super().__init__()
184 | img_size = to_2tuple(img_size)
185 | patch_size = to_2tuple(patch_size)
186 |
187 | self.img_size = img_size
188 | self.patch_size = patch_size
189 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
190 | self.num_patches = self.H * self.W
191 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
192 | padding=(patch_size[0] // 2, patch_size[1] // 2))
193 | self.norm = nn.LayerNorm(embed_dim)
194 |
195 | self.apply(self._init_weights)
196 |
197 | def _init_weights(self, m):
198 | if isinstance(m, nn.Linear):
199 | trunc_normal_(m.weight, std=.02)
200 | if isinstance(m, nn.Linear) and m.bias is not None:
201 | nn.init.constant_(m.bias, 0)
202 | elif isinstance(m, nn.LayerNorm):
203 | nn.init.constant_(m.bias, 0)
204 | nn.init.constant_(m.weight, 1.0)
205 | elif isinstance(m, nn.Conv2d):
206 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
207 | fan_out //= m.groups
208 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
209 | if m.bias is not None:
210 | m.bias.data.zero_()
211 |
212 | def forward(self, x):
213 | # B C H W
214 | x = self.proj(x)
215 | _, _, H, W = x.shape
216 | x = x.flatten(2).transpose(1, 2)
217 | # B H*W/16 C
218 | x = self.norm(x)
219 |
220 | return x, H, W
221 |
222 |
223 | class RGBXTransformer(nn.Module):
224 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
225 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
226 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, norm_fuse=nn.BatchNorm2d,
227 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
228 | super().__init__()
229 | self.num_classes = num_classes
230 | self.depths = depths
231 |
232 | # patch_embed
233 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
234 | embed_dim=embed_dims[0])
235 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
236 | embed_dim=embed_dims[1])
237 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
238 | embed_dim=embed_dims[2])
239 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
240 | embed_dim=embed_dims[3])
241 |
242 | # transformer encoder
243 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
244 | cur = 0
245 |
246 | self.block1 = nn.ModuleList([Block(
247 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
248 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
249 | sr_ratio=sr_ratios[0])
250 | for i in range(depths[0])])
251 | self.norm1 = norm_layer(embed_dims[0])
252 |         cur += depths[0]
253 |         self.block2 = nn.ModuleList([Block(
254 |             dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
255 |             drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
256 | sr_ratio=sr_ratios[1])
257 | for i in range(depths[1])])
258 | self.norm2 = norm_layer(embed_dims[1])
259 |
260 | cur += depths[1]
261 |
262 | self.block3 = nn.ModuleList([Block(
263 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
264 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
265 | sr_ratio=sr_ratios[2])
266 | for i in range(depths[2])])
267 | self.norm3 = norm_layer(embed_dims[2])
268 |
269 | cur += depths[2]
270 |
271 | self.block4 = nn.ModuleList([Block(
272 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
273 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
274 | sr_ratio=sr_ratios[3])
275 | for i in range(depths[3])])
276 | self.norm4 = norm_layer(embed_dims[3])
277 |
278 | cur += depths[3]
279 |
280 | self.apply(self._init_weights)
281 |
282 | def _init_weights(self, m):
283 | if isinstance(m, nn.Linear):
284 | trunc_normal_(m.weight, std=.02)
285 | if isinstance(m, nn.Linear) and m.bias is not None:
286 | nn.init.constant_(m.bias, 0)
287 | elif isinstance(m, nn.LayerNorm):
288 | nn.init.constant_(m.bias, 0)
289 | nn.init.constant_(m.weight, 1.0)
290 | elif isinstance(m, nn.Conv2d):
291 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
292 | fan_out //= m.groups
293 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
294 | if m.bias is not None:
295 | m.bias.data.zero_()
296 |
297 | def init_weights(self, pretrained=None):
298 | if isinstance(pretrained, str):
299 | load_dualpath_model(self, pretrained)
300 | else:
301 | raise TypeError('pretrained must be a str or None')
302 |
303 | def forward_features(self, x_rgb, x_e):
304 | """
305 | x_rgb: B x N x H x W
306 | """
307 | B = x_rgb.shape[0]
308 | rgb = []
309 | depth = []
310 |
311 |
312 | # stage 1
313 | x_rgb, H, W = self.patch_embed1(x_rgb)
314 | # B H*W/16 C
315 | for i, blk in enumerate(self.block1):
316 | x_rgb = blk(x_rgb, H, W)
317 | x_rgb = self.norm1(x_rgb)
318 |
319 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
320 |
321 | rgb.append(x_rgb)
322 |
323 |
324 |
325 | # stage 2
326 | x_rgb, H, W = self.patch_embed2(x_rgb)
327 | for i, blk in enumerate(self.block2):
328 | x_rgb = blk(x_rgb, H, W)
329 | x_rgb = self.norm2(x_rgb)
330 |
331 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
332 |
333 | rgb.append(x_rgb)
334 |
335 | # stage 3
336 | x_rgb, H, W = self.patch_embed3(x_rgb)
337 | for i, blk in enumerate(self.block3):
338 | x_rgb = blk(x_rgb, H, W)
339 | x_rgb = self.norm3(x_rgb)
340 |
341 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
342 |
343 | rgb.append(x_rgb)
344 |
345 |
346 | # stage 4
347 | x_rgb, H, W = self.patch_embed4(x_rgb)
348 | for i, blk in enumerate(self.block4):
349 | x_rgb = blk(x_rgb, H, W)
350 | x_rgb = self.norm4(x_rgb)
351 |
352 | x_rgb = x_rgb.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
353 |
354 | rgb.append(x_rgb)
355 | depth = []
356 | return rgb,depth
357 |
358 | def forward(self, x_rgb, x_e):
359 | rgb,depth = self.forward_features(x_rgb, x_e)
360 | return rgb,depth
361 |
362 |
363 | def load_dualpath_model(model, model_file):
364 | # load raw state_dict
365 | t_start = time.time()
366 | if isinstance(model_file, str):
367 | raw_state_dict = torch.load(model_file, map_location=torch.device('cpu'))
368 | #raw_state_dict = torch.load(model_file)
369 | if 'model' in raw_state_dict.keys():
370 | raw_state_dict = raw_state_dict['model']
371 | else:
372 | raw_state_dict = model_file
373 |
374 | state_dict = {}
375 | for k, v in raw_state_dict.items():
376 | if k.find('patch_embed') >= 0:
377 | state_dict[k] = v
378 | state_dict[k.replace('patch_embed', 'extra_patch_embed')] = v
379 | elif k.find('block') >= 0:
380 | state_dict[k] = v
381 | state_dict[k.replace('block', 'extra_block')] = v
382 | elif k.find('norm') >= 0:
383 | state_dict[k] = v
384 | state_dict[k.replace('norm', 'extra_norm')] = v
385 |
386 | t_ioend = time.time()
387 |
388 | model.load_state_dict(state_dict, strict=False)
389 | del state_dict
390 |
391 | t_end = time.time()
392 |
393 |
394 | class mit_b0(RGBXTransformer):
395 | def __init__(self, fuse_cfg=None, **kwargs):
396 | super(mit_b0, self).__init__(
397 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
398 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
399 | drop_rate=0.0, drop_path_rate=0.1)
400 |
401 |
402 | class mit_b1(RGBXTransformer):
403 | def __init__(self, fuse_cfg=None, **kwargs):
404 | super(mit_b1, self).__init__(
405 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
406 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
407 | drop_rate=0.0, drop_path_rate=0.1)
408 |
409 |
410 | class mit_b2(RGBXTransformer):
411 | def __init__(self, fuse_cfg=None, **kwargs):
412 | super(mit_b2, self).__init__(
413 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
414 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
415 | drop_rate=0.0, drop_path_rate=0.1)
416 |
417 |
418 | class mit_b3(RGBXTransformer):
419 | def __init__(self, fuse_cfg=None, **kwargs):
420 | super(mit_b3, self).__init__(
421 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
422 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
423 | drop_rate=0.0, drop_path_rate=0.1)
424 |
425 |
426 | class mit_b4(RGBXTransformer):
427 | def __init__(self, fuse_cfg=None, **kwargs):
428 | super(mit_b4, self).__init__(
429 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
430 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
431 | drop_rate=0.0, drop_path_rate=0.1)
432 |
433 |
434 | class mit_b5(RGBXTransformer):
435 | def __init__(self, fuse_cfg=None, **kwargs):
436 | super(mit_b5, self).__init__(
437 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
438 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
439 | drop_rate=0.0, drop_path_rate=0.1)
440 |
--------------------------------------------------------------------------------
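For reference, a minimal sketch (not part of the repository) of how the stochastic depth decay rule in the encoder above spreads drop-path rates over the transformer blocks; the depths and drop_path_rate mirror the mit_b2 defaults, everything else is illustrative:

import torch

# mit_b2-style settings taken from the classes above
depths = [3, 4, 6, 3]
drop_path_rate = 0.1

# one linearly spaced drop-path probability per block, as in RGBXTransformer.__init__
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

cur = 0
for stage, depth in enumerate(depths, start=1):
    print('stage %d:' % stage, [round(r, 3) for r in dpr[cur:cur + depth]])
    cur += depth
# deeper blocks receive progressively larger drop-path probabilities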
/run_demo.py:
--------------------------------------------------------------------------------
1 | import os, argparse, time, datetime, stat, shutil,sys
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | import torchvision.utils as vutils
9 | from util.MY_dataset import MY_dataset
10 | from sklearn.metrics import confusion_matrix
11 | from util.util import compute_results, visualize
12 | from scipy.io import savemat
13 | from torch.utils.tensorboard import SummaryWriter
14 | from model import PotCrackSeg
15 | from util.lr_policy import WarmUpPolyLR
16 | from util.init_func import init_weight, group_weight
17 | from config import config
18 |
19 | #############################################################################################
20 | parser = argparse.ArgumentParser(description='Test with pytorch')
21 | #############################################################################################
22 | parser.add_argument('--model_name', '-m', type=str, default='PotCrackSeg') # DRCNet_RDe_b3V3, DRCNet_RDe_b4V3, DRCNet_RDe_b5V3
23 | parser.add_argument('--weight_name', '-w', type=str, default='PotCrackSeg-4B') # DRCNet_RDe_b3V3, DRCNet_RDe_b4V3, DRCNet_RDe_b5V3
24 | parser.add_argument('--backbone', '-bac', type=str, default='PotCrackSeg-4B') # mit_3, mit_4, mit_5
25 | parser.add_argument('--file_name', '-f', type=str, default='final.pth')
26 | parser.add_argument('--dataset_split', '-d', type=str, default='test') # normal_test, abnormal_test, urban_test,rural_test
27 | parser.add_argument('--gpu', '-g', type=int, default=1)
28 | #############################################################################################
29 | parser.add_argument('--img_height', '-ih', type=int, default=288)
30 | parser.add_argument('--img_width', '-iw', type=int, default=512)
31 | parser.add_argument('--num_workers', '-j', type=int, default=16)
32 | parser.add_argument('--n_class', '-nc', type=int, default=3)
33 | parser.add_argument('--data_dir', '-dr', type=str, default='./NPO++/')
34 | parser.add_argument('--model_dir', '-wd', type=str, default='./weights_backup/')
35 | args = parser.parse_args()
36 | #############################################################################################
37 |
38 | def get_palette():
39 | unlabelled = [0,0,0]
40 | potholes = [153,0,0]
41 | cracks = [0,153,0]
42 | palette = np.array([unlabelled,potholes, cracks])
43 | return palette
44 |
45 |
46 | if __name__ == '__main__':
47 |
48 | torch.cuda.set_device(args.gpu)
49 | print("\nthe pytorch version:", torch.__version__)
50 | print("the gpu count:", torch.cuda.device_count())
51 | print("the current used gpu:", torch.cuda.current_device(), '\n')
52 |
53 |     # prepare save directory
54 |     if os.path.exists("./runs"):
55 |         print("previous \"./runs\" folder exists, will delete this folder")
56 |         shutil.rmtree("./runs")
57 |     os.makedirs("./runs")
58 |     os.chmod("./runs", stat.S_IRWXO) # allow the folder created by docker to be read, written, and executed by the local machine
59 | model_dir = os.path.join(args.model_dir, args.weight_name)
60 | if os.path.exists(model_dir) is False:
61 |         sys.exit("the %s does not exist." %(model_dir))
62 | model_file = os.path.join(model_dir, args.file_name)
63 | if os.path.exists(model_file) is True:
64 | print('use the final model file.')
65 | else:
66 | sys.exit('no model file found.')
67 | print('testing %s: %s on GPU #%d with pytorch' % (args.model_name, args.weight_name, args.gpu))
68 |
69 | conf_total = np.zeros((args.n_class, args.n_class))
70 | model = eval(args.model_name)(cfg = config ,n_class=args.n_class, encoder_name=args.backbone)
71 | if args.gpu >= 0: model.cuda(args.gpu)
72 | print('loading model file %s... ' % model_file)
73 | pretrained_weight = torch.load(model_file, map_location = lambda storage, loc: storage.cuda(args.gpu))
74 | own_state = model.state_dict()
75 | for name, param in pretrained_weight.items():
76 | own_state[name].copy_(param)
77 | print('done!')
78 |
79 | batch_size = 1
80 | test_dataset = MY_dataset(data_dir=args.data_dir, split=args.dataset_split, input_h=args.img_height, input_w=args.img_width)
81 | test_loader = DataLoader(
82 | dataset = test_dataset,
83 | batch_size = batch_size,
84 | shuffle = False,
85 | num_workers = args.num_workers,
86 | pin_memory = True,
87 | drop_last = False
88 | )
89 | ave_time_cost = 0.0
90 |
91 | model.eval()
92 | with torch.no_grad():
93 | for it, (images, labels, names) in enumerate(test_loader):
94 | images = Variable(images).cuda(args.gpu)
95 | labels = Variable(labels).cuda(args.gpu)
96 | # flop,params = profile(model,inputs=(images,))
97 | # print(flop)
98 | # print(params)
99 | torch.cuda.synchronize()
100 | start_time = time.time()
101 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion, logits= model(images)
102 | torch.cuda.synchronize()
103 | end_time = time.time()
104 |             if it>=5: # ignore the first 5 frames
105 | ave_time_cost += (end_time-start_time)
106 | # convert tensor to numpy 1d array
107 | label = labels.cpu().numpy().squeeze().flatten()
108 |                 prediction = logits.argmax(1).cpu().numpy().squeeze().flatten() # prediction and label are both 1-d arrays, size: minibatch*img_height*img_width
109 | # generate confusion matrix frame-by-frame
110 | conf = confusion_matrix(y_true=label, y_pred=prediction, labels=[0,1,2]) # conf is an n_class*n_class matrix, vertical axis: groundtruth, horizontal axis: prediction
111 | conf_total += conf
112 | # save demo images
113 | visualize(image_name=names, predictions=logits.argmax(1), weight_name=args.weight_name)
114 | print("%s, %s, frame %d/%d, %s, time cost: %.2f ms, demo result saved."
115 | %(args.model_name, args.weight_name, it+1, len(test_loader), names, (end_time-start_time)*1000))
116 |
117 | precision_per_class, recall_per_class, iou_per_class,F1_per_class = compute_results(conf_total)
118 | #precision, recall, IoU,F1 = compute_results(conf_total)
119 | conf_total_matfile = os.path.join("./runs", 'conf_'+args.weight_name+'.mat')
120 | savemat(conf_total_matfile, {'conf': conf_total}) # 'conf' is the variable name when loaded in Matlab
121 |
122 | print('\n###########################################################################')
123 | print('\n%s: %s test results (with batch size %d) on %s using %s:' %(args.model_name, args.weight_name, batch_size, datetime.date.today(), torch.cuda.get_device_name(args.gpu)))
124 | print('\n* the tested dataset name: %s' % args.dataset_split)
125 | print('* the tested image count: %d' % len(test_loader))
126 | print('* the tested image size: %d*%d' %(args.img_height, args.img_width))
127 | print('* the weight name: %s' %args.weight_name)
128 | print('* the file name: %s' %args.file_name)
129 | print("* iou per class: \n unlabeled: %.1f, pothole: %.1f, crack: %.1f" \
130 | %(iou_per_class[0]*100, iou_per_class[1]*100, iou_per_class[2]*100))
131 | print("* recall per class: \n unlabeled: %.1f, pothole: %.1f, crack: %.1f" \
132 | %(recall_per_class[0]*100, recall_per_class[1]*100, recall_per_class[2]*100))
133 | print("* pre per class: \n unlabeled: %.1f, pothole: %.1f, crack: %.1f" \
134 | %(precision_per_class[0]*100, precision_per_class[1]*100, precision_per_class[2]*100))
135 | print("* F1 per class: \n unlabeled: %.1f, pothole: %.1f, crack: %.1f" \
136 | %(F1_per_class[0]*100, F1_per_class[1]*100, F1_per_class[2]*100))
137 |
138 | print("\n* average values (np.mean(x)): \n iou: %.3f, recall: %.3f, pre: %.3f, F1: %.3f" \
139 | %(iou_per_class[1:].mean()*100,recall_per_class[1:].mean()*100, precision_per_class[1:].mean()*100,F1_per_class[1:].mean()*100))
140 | print("* average values (np.mean(np.nan_to_num(x))): \n iou: %.1f, recall: %.1f, pre: %.1f, F1: %.1f" \
141 | %(np.mean(np.nan_to_num(iou_per_class[1:]))*100, np.mean(np.nan_to_num(recall_per_class[1:]))*100, np.mean(np.nan_to_num(precision_per_class[1:]))*100, np.mean(np.nan_to_num(F1_per_class[1:]))*100))
142 |
143 |
144 |     print('\n* the average time cost per frame (with batch size %d): %.2f ms, namely, the inference speed is %.2f fps' %(batch_size, ave_time_cost*1000/(len(test_loader)-5), 1.0/(ave_time_cost/(len(test_loader)-5)))) # ignore the first 5 frames
145 | print('\n###########################################################################')
146 |
--------------------------------------------------------------------------------
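As a minimal sketch of how the demo colorizes results (the palette matches get_palette() in run_demo.py; the prediction map here is random and purely illustrative):

import numpy as np
from PIL import Image

# unlabelled, potholes, cracks -- same colors as get_palette() above
palette = np.array([[0, 0, 0], [153, 0, 0], [0, 153, 0]], dtype=np.uint8)

# hypothetical class-id map, e.g. logits.argmax(1)[0].cpu().numpy() in the real demo
pred = np.random.randint(0, 3, size=(288, 512))

color = palette[pred]              # (H, W, 3) color image via per-pixel lookup
Image.fromarray(color).save('demo_pred.png')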
/train.py:
--------------------------------------------------------------------------------
1 | import os, argparse, time, datetime, stat, shutil,sys
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | import torchvision.utils as vutils
9 | from util.MY_dataset import MY_dataset
10 | from util.augmentation import RandomFlip, RandomCrop, RandomCropOut, RandomBrightness, RandomNoise
11 | from util.util import compute_results
12 | from sklearn.metrics import confusion_matrix
13 | from torch.utils.tensorboard import SummaryWriter
14 | from model import PotCrackSeg
15 | from util.lr_policy import WarmUpPolyLR
16 | from util.init_func import init_weight, group_weight
17 | from config import config
18 |
19 | #############################################################################################
20 | parser = argparse.ArgumentParser(description='Train with pytorch')
21 | #############################################################################################
22 | parser.add_argument('--model_name', '-m', type=str, default='PotCrackSeg')
23 | parser.add_argument('--batch_size', '-b', type=int, default=2)
24 | parser.add_argument('--lr_start', '-ls', type=float, default=6e-5)
25 | parser.add_argument('--gpu', '-g', type=int, default=0)
26 | #############################################################################################
27 | parser.add_argument('--lr_decay', '-ld', type=float, default=0.95)
28 | parser.add_argument('--epoch_max', '-em', type=int, default=500) # please stop training manually
29 | parser.add_argument('--epoch_from', '-ef', type=int, default=0)
30 | parser.add_argument('--num_workers', '-j', type=int, default=8)
31 | parser.add_argument('--n_class', '-nc', type=int, default=3)
32 | parser.add_argument('--data_dir', '-dr', type=str, default='./NPO++/')
33 | parser.add_argument('--pre_weight', '-prw', type=str, default='/pretrained/mit_b2.pth')
34 | parser.add_argument('--backbone', '-bac', type=str, default='PotCrackSeg-2B')
35 | parser.add_argument('--model_dir', '-wd', type=str, default='./weights_backup/')
36 | # parser.add_argument('--weight_name', '-w', type=str, default='DRCNet_0DRC_RDe_b0') # RTFNet_152, RTFNet_50, please change the number of layers in the network file
37 | # parser.add_argument('--file_name', '-f', type=str, default='109.pth')
38 | args = parser.parse_args()
39 | #############################################################################################
40 |
41 | augmentation_methods = [
42 | RandomFlip(prob=0.5),
43 | RandomCrop(crop_rate=0.1, prob=1.0),
44 | # RandomCropOut(crop_rate=0.2, prob=1.0),
45 | # RandomBrightness(bright_range=0.15, prob=0.9),
46 | # RandomNoise(noise_range=5, prob=0.9),
47 | ]
48 |
49 | def fusion_loss(rgb_predict, rgb_comple,depth_predict, depth_comple,label):
50 |
51 | feature_map_B, feature_map_C, feature_map_W, feature_map_H = rgb_predict.size()
52 | label_B, label_W, label_H = label.size()
53 |
54 | if feature_map_W != label_W:
55 | label = torch.cuda.FloatTensor(label.unsqueeze(1).cpu().numpy())
56 | label = F.interpolate(label,[feature_map_W,feature_map_H],mode="nearest")
57 | label = torch.cuda.LongTensor(label.squeeze(1).cpu().numpy())
58 |
59 | loss_pr_rgb_seg = F.cross_entropy(rgb_predict, label)
60 | rgb_predict = rgb_predict.detach()
61 | rgb_predict=rgb_predict.argmax(1)
62 | rgb_predict.eq_(label)
63 | rgb_predict=rgb_predict.clone().detach_().requires_grad_(False)
64 | add_map_rgb = (1-rgb_predict)*label
65 | add_map_rgb=add_map_rgb.clone().detach_().requires_grad_(False)
66 | loss_add_rgb = F.cross_entropy(rgb_comple,add_map_rgb)
67 |
68 | loss_pr_depth_seg = F.cross_entropy(depth_predict, label)
69 | depth_predict = depth_predict.detach()
70 | depth_predict=depth_predict.argmax(1)
71 | depth_predict.eq_(label)
72 | depth_predict=depth_predict.clone().detach_().requires_grad_(False)
73 | add_map_depth = (1-depth_predict)*label
74 | add_map_depth=add_map_depth.clone().detach_().requires_grad_(False)
75 | loss_add_depth = F.cross_entropy(depth_comple,add_map_depth)
76 |
77 | loss = loss_pr_rgb_seg+loss_add_rgb+loss_pr_depth_seg+loss_add_depth
78 | return loss,add_map_rgb,add_map_depth
79 |
80 |
81 |
82 | def train(epo, model, train_loader, optimizer):
83 | model.train()
84 | loss_sum = 0
85 |
86 | loss_seg_sum = 0
87 | for it, (images, labels, names) in enumerate(train_loader):
88 | images = Variable(images).cuda(args.gpu)
89 | labels = Variable(labels).cuda(args.gpu)
90 |
91 | start_t = time.time() # time.time() returns the current time
92 | optimizer.zero_grad()
93 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion, logits= model(images)
94 |
95 | loss1,add_map_rgb,add_map_depth = fusion_loss(rgb_predict, rgb_comple,depth_predict, depth_comple,labels)
96 | #loss2 = F.cross_entropy(rgb_fusion, labels)
97 | #loss3 = F.cross_entropy(depth_fusion, labels)
98 | loss4 = F.cross_entropy(logits, labels)
99 |
100 | #loss = 0.5*loss1+loss2+loss3+loss4
101 | loss = 0.5*loss1+loss4
102 |
103 | loss.backward()
104 | optimizer.step()
105 |
106 | loss_sum = loss_sum+loss
107 | loss_seg_sum = loss_seg_sum+loss4
108 |
109 | current_idx = (epo- 0) * config.niters_per_epoch + it
110 | lr = lr_policy.get_lr(current_idx)
111 |
112 | for i in range(len(optimizer.param_groups)):
113 | optimizer.param_groups[i]['lr'] = lr
114 |
115 | lr_this_epo=0
116 | for param_group in optimizer.param_groups:
117 | lr_this_epo = param_group['lr']
118 |
119 | print('Train: %s, epo %s/%s, iter %s/%s, lr %.8f, %.2f img/sec, loss %.4f, loss_average %.4f, loss_seg_average %.4f, time %s' \
120 | % (args.model_name, epo, args.epoch_max, it+1, len(train_loader), lr_this_epo, len(names)/(time.time()-start_t), float(loss), float(loss_sum/(it+1)), float(loss_seg_sum/(it+1)),
121 | datetime.datetime.now().replace(microsecond=0)-start_datetime))
122 | if accIter['train'] % 1 == 0:
123 | writer.add_scalar('Train/loss', loss, accIter['train'])
124 | view_figure = True # note that I have not colorized the GT and predictions here
125 | if accIter['train'] % 1000 == 0:
126 | if view_figure:
127 | input_rgb_images = vutils.make_grid(images[:,:3], nrow=8, padding=10) # can only display 3-channel images, so images[:,:3]
128 | writer.add_image('Train/input_rgb_images', input_rgb_images, accIter['train'])
129 |                     scale = max(1, 255//args.n_class) # label (0,1,2..) is invisible, multiply a constant for visualization
130 | groundtruth_tensor = labels.unsqueeze(1) * scale # mini_batch*480*640 -> mini_batch*1*480*640
131 | groundtruth_tensor = torch.cat((groundtruth_tensor, groundtruth_tensor, groundtruth_tensor), 1) # change to 3-channel for visualization
132 | groudtruth_images = vutils.make_grid(groundtruth_tensor, nrow=8, padding=10)
133 | writer.add_image('Train/groudtruth_images', groudtruth_images, accIter['train'])
134 | predicted_tensor = logits.argmax(1).unsqueeze(1) * scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640
135 | predicted_tensor = torch.cat((predicted_tensor, predicted_tensor, predicted_tensor),1) # change to 3-channel for visualization, mini_batch*1*480*640
136 | predicted_images = vutils.make_grid(predicted_tensor, nrow=8, padding=10)
137 | writer.add_image('Train/predicted_images', predicted_images, accIter['train'])
138 |
139 | predicted_tensor_rgb_1 = rgb_predict.argmax(1).unsqueeze(1) * scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640
140 | predicted_tensor_rgb_1 = torch.cat((predicted_tensor_rgb_1, predicted_tensor_rgb_1, predicted_tensor_rgb_1),1) # change to 3-channel for visualization, mini_batch*1*480*640
141 | predicted_images_rgb_1 = vutils.make_grid(predicted_tensor_rgb_1, nrow=8, padding=10)
142 | writer.add_image('Train/predicted_images_rgb_1', predicted_images_rgb_1, accIter['train'])
143 |
144 | predicted_tensor_depth_1 = depth_predict.argmax(1).unsqueeze(1) * scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640
145 | predicted_tensor_depth_1 = torch.cat((predicted_tensor_depth_1, predicted_tensor_depth_1, predicted_tensor_depth_1),1) # change to 3-channel for visualization, mini_batch*1*480*640
146 | predicted_images_depth_1 = vutils.make_grid(predicted_tensor_depth_1, nrow=8, padding=10)
147 | writer.add_image('Train/predicted_images_depth_1', predicted_images_depth_1, accIter['train'])
148 |
149 | add_map_rgb_1 = add_map_rgb.unsqueeze(1) * scale
150 | predicted_tensor_need_rgb_1 = torch.cat((add_map_rgb_1, add_map_rgb_1, add_map_rgb_1), 1) # change to 3-channel for visualization
151 | predicted_images_need_rgb_1 = vutils.make_grid(predicted_tensor_need_rgb_1, nrow=8, padding=10)
152 | writer.add_image('Train/predicted_images_need_rgb_1', predicted_images_need_rgb_1, accIter['train'])
153 |
154 | add_map_depth_1 = add_map_depth.unsqueeze(1) * scale
155 | predicted_tensor_need_depth_1 = torch.cat((add_map_depth_1, add_map_depth_1, add_map_depth_1), 1) # change to 3-channel for visualization
156 | predicted_images_need_depth_1 = vutils.make_grid(predicted_tensor_need_depth_1, nrow=8, padding=10)
157 | writer.add_image('Train/predicted_images_need_depth_1', predicted_images_need_depth_1, accIter['train'])
158 |
159 | predicted_tensor_complex_rgb_1 = rgb_comple.argmax(1).unsqueeze(1) * scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640
160 | predicted_tensor_complex_rgb_1 = torch.cat((predicted_tensor_complex_rgb_1, predicted_tensor_complex_rgb_1, predicted_tensor_complex_rgb_1),1) # change to 3-channel for visualization, mini_batch*1*480*640
161 | predicted_images_complex_rgb_1 = vutils.make_grid(predicted_tensor_complex_rgb_1, nrow=8, padding=10)
162 | writer.add_image('Train/predicted_images_complex_rgb_1', predicted_images_complex_rgb_1, accIter['train'])
163 |
164 | predicted_tensor_complex_depth_1 = depth_comple.argmax(1).unsqueeze(1) * scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640
165 | predicted_tensor_complex_depth_1 = torch.cat((predicted_tensor_complex_depth_1, predicted_tensor_complex_depth_1, predicted_tensor_complex_depth_1),1) # change to 3-channel for visualization, mini_batch*1*480*640
166 | predicted_images_complex_depth_1 = vutils.make_grid(predicted_tensor_complex_depth_1, nrow=8, padding=10)
167 | writer.add_image('Train/predicted_images_complex_depth_1', predicted_images_complex_depth_1, accIter['train'])
168 | accIter['train'] = accIter['train'] + 1
169 |
170 | def validation(epo, model, val_loader):
171 | model.eval()
172 | with torch.no_grad():
173 | for it, (images, labels, names) in enumerate(val_loader):
174 | images = Variable(images).cuda(args.gpu)
175 | labels = Variable(labels).cuda(args.gpu)
176 | start_t = time.time() # time.time() returns the current time
177 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion, logits= model(images)
178 | loss = F.cross_entropy(logits, labels) # Note that the cross_entropy function has already include the softmax function
179 | print('Val: %s, epo %s/%s, iter %s/%s, %.2f img/sec, loss %.4f, time %s' \
180 | % (args.model_name, epo, args.epoch_max, it + 1, len(val_loader), len(names)/(time.time()-start_t), float(loss),
181 | datetime.datetime.now().replace(microsecond=0)-start_datetime))
182 | if accIter['val'] % 1 == 0:
183 | writer.add_scalar('Validation/loss', loss, accIter['val'])
184 | view_figure = False # note that I have not colorized the GT and predictions here
185 | if accIter['val'] % 1000 == 0:
186 | if view_figure:
187 | input_rgb_images = vutils.make_grid(images[:, :3], nrow=8, padding=10) # can only display 3-channel images, so images[:,:3]
188 | writer.add_image('Validation/input_rgb_images', input_rgb_images, accIter['val'])
189 |                     scale = max(1, 255 // args.n_class) # label (0,1,2..) is invisible, multiply a constant for visualization
190 | groundtruth_tensor = labels.unsqueeze(1) * scale # mini_batch*480*640 -> mini_batch*1*480*640
191 | groundtruth_tensor = torch.cat((groundtruth_tensor, groundtruth_tensor, groundtruth_tensor), 1) # change to 3-channel for visualization
192 | groudtruth_images = vutils.make_grid(groundtruth_tensor, nrow=8, padding=10)
193 | writer.add_image('Validation/groudtruth_images', groudtruth_images, accIter['val'])
194 | predicted_tensor = logits.argmax(1).unsqueeze(1)*scale # mini_batch*args.n_class*480*640 -> mini_batch*480*640 -> mini_batch*1*480*640
195 | predicted_tensor = torch.cat((predicted_tensor, predicted_tensor, predicted_tensor), 1) # change to 3-channel for visualization, mini_batch*1*480*640
196 | predicted_images = vutils.make_grid(predicted_tensor, nrow=8, padding=10)
197 | writer.add_image('Validation/predicted_images', predicted_images, accIter['val'])
198 | accIter['val'] += 1
199 |
200 | def testing(epo, model, test_loader):
201 | model.eval()
202 | conf_total = np.zeros((args.n_class, args.n_class))
203 | label_list = ["unlabeled", "pothole", "crack"]
204 | testing_results_file = os.path.join(weight_dir, 'testing_results_file.txt')
205 | with torch.no_grad():
206 | for it, (images, labels, names) in enumerate(test_loader):
207 | images = Variable(images).cuda(args.gpu)
208 | labels = Variable(labels).cuda(args.gpu)
209 | rgb_predict, rgb_comple, rgb_fusion, depth_predict, depth_comple, depth_fusion, logits= model(images)
210 | label = labels.cpu().numpy().squeeze().flatten()
211 |                 prediction = logits.argmax(1).cpu().numpy().squeeze().flatten() # prediction and label are both 1-d arrays, size: minibatch*img_height*img_width
212 | conf = confusion_matrix(y_true=label, y_pred=prediction, labels=[0,1,2]) # conf is args.n_class*args.n_class matrix, vertical axis: groundtruth, horizontal axis: prediction
213 | conf_total += conf
214 | print('Test: %s, epo %s/%s, iter %s/%s, time %s' % (args.model_name, epo, args.epoch_max, it+1, len(test_loader),
215 | datetime.datetime.now().replace(microsecond=0)-start_datetime))
216 | precision, recall, IoU, F1 = compute_results(conf_total)
217 | writer.add_scalar('Test/average_recall', recall.mean(), epo)
218 | writer.add_scalar('Test/average_IoU', IoU.mean(), epo)
219 | writer.add_scalar('Test/average_precision',precision.mean(), epo)
220 | writer.add_scalar('Test/average_F1', F1.mean(), epo)
221 | for i in range(len(precision)):
222 | writer.add_scalar("Test(class)/precision_class_%s" % label_list[i], precision[i], epo)
223 | writer.add_scalar("Test(class)/recall_class_%s"% label_list[i], recall[i],epo)
224 | writer.add_scalar('Test(class)/Iou_%s'% label_list[i], IoU[i], epo)
225 | writer.add_scalar('Test(class)/F1_%s'% label_list[i], F1[i], epo)
226 | if epo==0:
227 | with open(testing_results_file, 'w') as f:
228 | f.write("# %s, initial lr: %s, batch size: %s, date: %s \n" %(args.model_name, args.lr_start, args.batch_size, datetime.date.today()))
229 |             f.write("# epoch: unlabeled, pothole, crack, average(nan_to_num). (Pre %, Acc %, IoU %, F1 %)\n")
230 | with open(testing_results_file, 'a') as f:
231 | f.write(str(epo)+': ')
232 | for i in range(len(precision)):
233 | f.write('%0.4f, %0.4f, %0.4f, %0.4f ' % (100*precision[i], 100*recall[i], 100*IoU[i], 100*F1[i]))
234 | f.write('%0.4f, %0.4f, %0.4f, %0.4f\n' % (100*np.mean(np.nan_to_num(precision)), 100*np.mean(np.nan_to_num(recall)), 100*np.mean(np.nan_to_num(IoU)), 100*np.mean(np.nan_to_num(F1)) ))
235 | #f.write('%0.4f, %0.4f, %0.4f, %0.4f\n' % (100*np.mean(np.nan_to_num(recall)), 100*np.mean(np.nan_to_num(IoU), 100*np.mean(np.nan_to_num(precision)), ))))
236 | print('saving testing results.')
237 | with open(testing_results_file, "r") as file:
238 | writer.add_text('testing_results', file.read().replace('\n', ' \n'), epo)
239 |
240 | if __name__ == '__main__':
241 |
242 | torch.cuda.set_device(args.gpu)
243 | print("\nthe pytorch version:", torch.__version__)
244 | print("the gpu count:", torch.cuda.device_count())
245 | print("the current used gpu:", torch.cuda.current_device(), '\n')
246 |
247 | config.pretrained_model = config.root_dir + args.pre_weight
248 |
249 | model = eval(args.model_name)(cfg = config ,n_class=args.n_class, encoder_name=args.backbone)
250 |
251 | base_lr = args.lr_start
252 |
253 | if args.gpu >= 0: model.cuda(args.gpu)
254 | #optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_start, momentum=0.9, weight_decay=0.0005)
255 | #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_decay, last_epoch=-1)
256 | params_list = []
257 | params_list = group_weight(params_list, model, nn.BatchNorm2d, base_lr)
258 | optimizer = torch.optim.AdamW(params_list, lr=base_lr, betas=(0.9, 0.999), weight_decay=config.weight_decay)
259 |
260 | total_iteration = config.nepochs * config.niters_per_epoch
261 | lr_policy = WarmUpPolyLR(base_lr, config.lr_power, total_iteration, config.niters_per_epoch * config.warm_up_epoch)
262 |
263 | # preparing folders
264 | if os.path.exists("./PotCrackSeg"):
265 | shutil.rmtree("./PotCrackSeg")
266 | weight_dir = os.path.join("./PotCrackSeg", args.model_name)
267 | os.makedirs(weight_dir)
268 |     os.chmod(weight_dir, stat.S_IRWXO) # allow the folder created by docker to be read, written, and executed by the local machine
269 | 
270 |     writer = SummaryWriter("./PotCrackSeg/tensorboard_log")
271 |     os.chmod("./PotCrackSeg/tensorboard_log", stat.S_IRWXO) # allow the folder created by docker to be read, written, and executed by the local machine
272 | os.chmod("./PotCrackSeg", stat.S_IRWXO)
273 |
274 | print('training %s on GPU #%d with pytorch' % (args.model_name, args.gpu))
275 | print('from epoch %d / %s' % (args.epoch_from, args.epoch_max))
276 | print('weight will be saved in: %s' % weight_dir)
277 |
278 | train_dataset = MY_dataset(data_dir=args.data_dir, split='train', transform=augmentation_methods,input_h=288, input_w=512)
279 | val_dataset = MY_dataset(data_dir=args.data_dir, split='validation',input_h=288, input_w=512)
280 | test_dataset = MY_dataset(data_dir=args.data_dir, split='test',input_h=288, input_w=512)
281 |
282 | train_loader = DataLoader(
283 | dataset = train_dataset,
284 | batch_size = args.batch_size,
285 | shuffle = True,
286 | num_workers = args.num_workers,
287 | pin_memory = True,
288 | drop_last = False
289 | )
290 | val_loader = DataLoader(
291 | dataset = val_dataset,
292 | batch_size = args.batch_size,
293 | shuffle = False,
294 | num_workers = args.num_workers,
295 | pin_memory = True,
296 | drop_last = False
297 | )
298 | test_loader = DataLoader(
299 | dataset = test_dataset,
300 | batch_size = args.batch_size,
301 | shuffle = False,
302 | num_workers = args.num_workers,
303 | pin_memory = True,
304 | drop_last = False
305 | )
306 | start_datetime = datetime.datetime.now().replace(microsecond=0)
307 | accIter = {'train': 0, 'val': 0}
308 | for epo in range(args.epoch_from, args.epoch_max):
309 | print('\ntrain %s, epo #%s begin...' % (args.model_name, epo))
310 | #scheduler.step() # if using pytorch 0.4.1, please put this statement here
311 | train(epo, model, train_loader, optimizer)
312 | validation(epo, model, val_loader)
313 |
314 | checkpoint_model_file = os.path.join(weight_dir, str(epo) + '.pth')
315 | print('saving check point %s: ' % checkpoint_model_file)
316 | torch.save(model.state_dict(), checkpoint_model_file)
317 |
318 | testing(epo, model, test_loader)
319 | #scheduler.step() # if using pytorch 1.1 or above, please put this statement here
320 |
--------------------------------------------------------------------------------
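A minimal standalone sketch of the complement-map idea behind fusion_loss in train.py: pixels the main head classifies correctly are zeroed out, pixels it gets wrong keep their ground-truth label, and the complementary head is supervised on that map. All tensors below are toy values; only the masking logic mirrors the code above:

import torch
import torch.nn.functional as F

rgb_predict = torch.randn(1, 3, 4, 4)          # logits from the main head (toy values)
rgb_comple  = torch.randn(1, 3, 4, 4)          # logits from the complementary head
label       = torch.randint(0, 3, (1, 4, 4))   # toy ground truth

loss_seg = F.cross_entropy(rgb_predict, label)

# mask of correctly classified pixels; detached because it is only a supervision target
correct = rgb_predict.detach().argmax(1).eq(label).long()

# wrong pixels keep their class id, correct pixels become class 0 (unlabelled)
add_map = (1 - correct) * label

loss_add = F.cross_entropy(rgb_comple, add_map)
loss = loss_seg + loss_add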
/util/MY_dataset.py:
--------------------------------------------------------------------------------
1 | import os, torch
2 | from torch.utils.data.dataset import Dataset
3 | import numpy as np
4 | import PIL
5 |
6 | class MY_dataset(Dataset):
7 |
8 | def __init__(self, data_dir, split, input_h=288, input_w=512 ,transform=[]):
9 | super(MY_dataset, self).__init__()
10 |
11 | with open(os.path.join(data_dir, split+'.txt'), 'r') as f:
12 | self.names = [name.strip() for name in f.readlines()]
13 |
14 | self.data_dir = data_dir
15 | self.split = split
16 | self.input_h = input_h
17 | self.input_w = input_w
18 | self.transform = transform
19 | self.n_data = len(self.names)
20 |
21 | def read_image(self, name, folder,head):
22 | file_path = os.path.join(self.data_dir, '%s/%s%s.png' % (folder, head,name))
23 | image = np.asarray(PIL.Image.open(file_path))
24 | return image
25 |
26 | def __getitem__(self, index):
27 | name = self.names[index]
28 | image = self.read_image(name, 'left','left')
29 | label = self.read_image(name, 'labels','label')
30 | depth = self.read_image(name, 'depth','depth')
31 | image = np.asarray(PIL.Image.fromarray(image).resize((self.input_w, self.input_h)))
32 | image = image.astype('float32')
33 | image = np.transpose(image, (2,0,1))/255.0
34 | depth = np.asarray(PIL.Image.fromarray(depth).resize((self.input_w, self.input_h)))
35 |
36 | depth = depth.astype('float32')
37 |
38 | M = depth.max()
39 | depth = depth/M
40 |
41 | label = np.asarray(PIL.Image.fromarray(label).resize((self.input_w, self.input_h), resample=PIL.Image.NEAREST))
42 | label = label.astype('int64')
43 |
44 |
45 | return torch.cat((torch.tensor(image), torch.tensor(depth).unsqueeze(0)),dim=0), torch.tensor(label),name
46 |
47 | def __len__(self):
48 | return self.n_data
49 |
--------------------------------------------------------------------------------
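A minimal sketch of the RGB-D sample layout produced by MY_dataset.__getitem__: three RGB channels scaled to [0, 1] concatenated with one max-normalized depth channel (the arrays below are random stand-ins for the real PNG files):

import numpy as np
import torch

rgb   = np.random.rand(288, 512, 3).astype('float32')       # stands in for the resized left image / 255.0
depth = np.random.rand(288, 512).astype('float32') * 10.0    # stands in for the raw depth map

rgb_t   = torch.tensor(np.transpose(rgb, (2, 0, 1)))         # (3, H, W)
depth_t = torch.tensor(depth / depth.max()).unsqueeze(0)     # (1, H, W), normalized by its maximum

sample = torch.cat((rgb_t, depth_t), dim=0)                  # (4, H, W) tensor fed to the network
print(sample.shape)                                          # torch.Size([4, 288, 512])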
/util/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/util/augmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | class RandomFlip():
5 | def __init__(self, prob=0.5):
6 | #super(RandomFlip, self).__init__()
7 | self.prob = prob
8 |
9 | def __call__(self, image, label):
10 | if np.random.rand() < self.prob:
11 | image = image[:,::-1]
12 | label = label[:,::-1]
13 | return image, label
14 |
15 | class RandomCrop():
16 | def __init__(self, crop_rate=0.1, prob=1.0):
17 | #super(RandomCrop, self).__init__()
18 | self.crop_rate = crop_rate
19 | self.prob = prob
20 |
21 | def __call__(self, image, label):
22 | if np.random.rand() < self.prob:
23 | w, h, c = image.shape
24 |
25 | h1 = np.random.randint(0, h*self.crop_rate)
26 | w1 = np.random.randint(0, w*self.crop_rate)
27 | h2 = np.random.randint(h-h*self.crop_rate, h+1)
28 | w2 = np.random.randint(w-w*self.crop_rate, w+1)
29 |
30 | image = image[w1:w2, h1:h2]
31 | label = label[w1:w2, h1:h2]
32 |
33 | return image, label
34 |
35 |
36 | class RandomCropOut():
37 | def __init__(self, crop_rate=0.2, prob=1.0):
38 | #super(RandomCropOut, self).__init__()
39 | self.crop_rate = crop_rate
40 | self.prob = prob
41 |
42 | def __call__(self, image, label):
43 | if np.random.rand() < self.prob:
44 | w, h, c = image.shape
45 |
46 | h1 = np.random.randint(0, h*self.crop_rate)
47 | w1 = np.random.randint(0, w*self.crop_rate)
48 | h2 = int(h1 + h*self.crop_rate)
49 | w2 = int(w1 + w*self.crop_rate)
50 |
51 | image[w1:w2, h1:h2] = 0
52 | label[w1:w2, h1:h2] = 0
53 |
54 | return image, label
55 |
56 |
57 | class RandomBrightness():
58 | def __init__(self, bright_range=0.15, prob=0.9):
59 | #super(RandomBrightness, self).__init__()
60 | self.bright_range = bright_range
61 | self.prob = prob
62 |
63 | def __call__(self, image, label):
64 | if np.random.rand() < self.prob:
65 | bright_factor = np.random.uniform(1-self.bright_range, 1+self.bright_range)
66 | image = (image * bright_factor).astype(image.dtype)
67 |
68 | return image, label
69 |
70 |
71 | class RandomNoise():
72 | def __init__(self, noise_range=5, prob=0.9):
73 | #super(RandomNoise, self).__init__()
74 | self.noise_range = noise_range
75 | self.prob = prob
76 |
77 | def __call__(self, image, label):
78 | if np.random.rand() < self.prob:
79 | w, h, c = image.shape
80 |
81 | noise = np.random.randint(
82 | -self.noise_range,
83 | self.noise_range,
84 | (w,h,c)
85 | )
86 |
87 | image = (image + noise).clip(0,255).astype(image.dtype)
88 |
89 | return image, label
90 |
91 |
92 |
93 |
--------------------------------------------------------------------------------
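A minimal usage sketch for the augmentations above, applied to a random image/label pair (the import assumes the util/ package layout of this repository):

import numpy as np
from util.augmentation import RandomFlip, RandomCrop  # assumed repo layout

image = np.random.randint(0, 256, (288, 512, 3), dtype=np.uint8)
label = np.random.randint(0, 3, (288, 512), dtype=np.int64)

for aug in [RandomFlip(prob=0.5), RandomCrop(crop_rate=0.1, prob=1.0)]:
    image, label = aug(image, label)

# RandomCrop trims up to crop_rate from each border, so shapes shrink slightly
print(image.shape, label.shape)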
/util/init_func.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # encoding: utf-8
3 | # @Time : 2018/9/28 12:13 PM
4 | # @Author : yuchangqian
5 | # @Contact : changqian_yu@163.com
6 | # @File : init_func.py
7 | import torch
8 | import torch.nn as nn
9 |
10 | def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
11 | **kwargs):
12 | for name, m in feature.named_modules():
13 | if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
14 | conv_init(m.weight, **kwargs)
15 | elif isinstance(m, norm_layer):
16 | m.eps = bn_eps
17 | m.momentum = bn_momentum
18 | nn.init.constant_(m.weight, 1)
19 | nn.init.constant_(m.bias, 0)
20 |
21 |
22 | def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
23 | **kwargs):
24 | if isinstance(module_list, list):
25 | for feature in module_list:
26 | __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum,
27 | **kwargs)
28 | else:
29 | __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum,
30 | **kwargs)
31 |
32 |
33 | def group_weight(weight_group, module, norm_layer, lr):
34 | group_decay = []
35 | group_no_decay = []
36 | count = 0
37 | for m in module.modules():
38 | if isinstance(m, nn.Linear):
39 | group_decay.append(m.weight)
40 | if m.bias is not None:
41 | group_no_decay.append(m.bias)
42 | elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
43 | group_decay.append(m.weight)
44 | if m.bias is not None:
45 | group_no_decay.append(m.bias)
46 | elif isinstance(m, norm_layer) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) \
47 | or isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm) or isinstance(m, nn.LayerNorm):
48 | if m.weight is not None:
49 | group_no_decay.append(m.weight)
50 | if m.bias is not None:
51 | group_no_decay.append(m.bias)
52 | elif isinstance(m, nn.Parameter):
53 | group_decay.append(m)
54 |
55 | assert len(list(module.parameters())) >= len(group_decay) + len(group_no_decay)
56 | weight_group.append(dict(params=group_decay, lr=lr))
57 | weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
58 | return weight_group
--------------------------------------------------------------------------------
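A minimal sketch of how group_weight separates parameters into a weight-decay group and a no-decay group before building the optimizer, as train.py does; the toy model below is only for illustration:

import torch
import torch.nn as nn
from util.init_func import group_weight  # assumed repo layout

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(), nn.Linear(8, 3))

params_list = group_weight([], model, nn.BatchNorm2d, lr=6e-5)

# group 0: conv/linear weights (weight decay applies)
# group 1: biases and norm parameters (weight_decay overridden to 0)
optimizer = torch.optim.AdamW(params_list, lr=6e-5, betas=(0.9, 0.999), weight_decay=0.01)
print([len(g['params']) for g in optimizer.param_groups])   # e.g. [2, 4]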
/util/load_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import re
3 | from torch import distributed as dist
4 |
5 |
6 | def get_dist_info():
7 | if dist.is_available():
8 | initialized = dist.is_initialized()
9 | else:
10 | initialized = False
11 | if initialized:
12 | rank = dist.get_rank()
13 | world_size = dist.get_world_size()
14 | else:
15 | rank = 0
16 | world_size = 1
17 | return rank, world_size
18 |
19 |
20 | def load_state_dict(module, state_dict, strict=False, logger=None):
21 | unexpected_keys = []
22 | all_missing_keys = []
23 | err_msg = []
24 |
25 | metadata = getattr(state_dict, '_metadata', None)
26 | state_dict = state_dict.copy()
27 | if metadata is not None:
28 | state_dict._metadata = metadata
29 |
30 | # use _load_from_state_dict to enable checkpoint version control
31 | def load(module, prefix=''):
32 | # recursively check parallel module in case that the model has a
33 | # complicated structure, e.g., nn.Module(nn.Module(DDP))
34 | local_metadata = {} if metadata is None else metadata.get(
35 | prefix[:-1], {})
36 | module._load_from_state_dict(state_dict, prefix, local_metadata, True,
37 | all_missing_keys, unexpected_keys,
38 | err_msg)
39 | for name, child in module._modules.items():
40 | if child is not None:
41 | load(child, prefix + name + '.')
42 |
43 | load(module)
44 | load = None # break load->load reference cycle
45 |
46 | # ignore "num_batches_tracked" of BN layers
47 | missing_keys = [
48 | key for key in all_missing_keys if 'num_batches_tracked' not in key
49 | ]
50 |
51 | if unexpected_keys:
52 | err_msg.append('unexpected key in source '
53 | f'state_dict: {", ".join(unexpected_keys)}\n')
54 | if missing_keys:
55 | err_msg.append(
56 | f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
57 |
58 | rank, _ = get_dist_info()
59 | if len(err_msg) > 0 and rank == 0:
60 | err_msg.insert(
61 | 0, 'The model and loaded state dict do not match exactly\n')
62 | err_msg = '\n'.join(err_msg)
63 | if strict:
64 | raise RuntimeError(err_msg)
65 | else:
66 | print(err_msg)
67 |
68 |
69 |
70 | def load_pretrain(model,
71 | filename,
72 | strict=False,
73 | revise_keys=[(r'^module\.', '')]):
74 | checkpoint = torch.load(filename)
75 | # OrderedDict is a subclass of dict
76 | if not isinstance(checkpoint, dict):
77 | raise RuntimeError(
78 | f'No state_dict found in checkpoint file {filename}')
79 | # get state_dict from checkpoint
80 | if 'state_dict' in checkpoint:
81 | state_dict = checkpoint['state_dict']
82 | elif 'model' in checkpoint:
83 | state_dict = checkpoint['model']
84 | else:
85 | state_dict = checkpoint
86 | # strip prefix of state_dict
87 | for p, r in revise_keys:
88 | state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}
89 | # load state_dict
90 | load_state_dict(model, state_dict, strict)
91 | return checkpoint
--------------------------------------------------------------------------------
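A minimal sketch of the revise_keys mechanism used by load_pretrain above: a regex substitution applied to every key, typically to strip the 'module.' prefix that DataParallel adds (the state_dict here is a toy placeholder):

import re

revise_keys = [(r'^module\.', '')]
state_dict = {'module.patch_embed1.proj.weight': 0, 'module.block1.0.attn.q.weight': 1}

for p, r in revise_keys:
    state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}

print(list(state_dict.keys()))
# ['patch_embed1.proj.weight', 'block1.0.attn.q.weight']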
/util/lr_policy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # encoding: utf-8
3 | # @Time : 2018/8/1 1:50 AM
4 | # @Author : yuchangqian
5 | # @Contact : changqian_yu@163.com
6 | # @File : lr_policy.py
7 |
8 | from abc import ABCMeta, abstractmethod
9 |
10 |
11 | class BaseLR():
12 | __metaclass__ = ABCMeta
13 |
14 | @abstractmethod
15 | def get_lr(self, cur_iter): pass
16 |
17 |
18 | class PolyLR(BaseLR):
19 | def __init__(self, start_lr, lr_power, total_iters):
20 | self.start_lr = start_lr
21 | self.lr_power = lr_power
22 | self.total_iters = total_iters + 0.0
23 |
24 | def get_lr(self, cur_iter):
25 | return self.start_lr * (
26 | (1 - float(cur_iter) / self.total_iters) ** self.lr_power)
27 |
28 |
29 | class WarmUpPolyLR(BaseLR):
30 | def __init__(self, start_lr, lr_power, total_iters, warmup_steps):
31 | print(start_lr)
32 | self.start_lr = start_lr
33 | self.lr_power = lr_power
34 | self.total_iters = total_iters + 0.0
35 | self.warmup_steps = warmup_steps
36 |
37 | def get_lr(self, cur_iter):
38 | if cur_iter < self.warmup_steps:
39 | return self.start_lr * (cur_iter / self.warmup_steps)
40 | else:
41 | return self.start_lr * (
42 | (1 - float(cur_iter) / self.total_iters) ** self.lr_power)
43 |
44 |
45 | class MultiStageLR(BaseLR):
46 | def __init__(self, lr_stages):
47 | assert type(lr_stages) in [list, tuple] and len(lr_stages[0]) == 2, \
48 | 'lr_stages must be list or tuple, with [iters, lr] format'
49 | self._lr_stagess = lr_stages
50 |
51 | def get_lr(self, epoch):
52 | for it_lr in self._lr_stagess:
53 | if epoch < it_lr[0]:
54 | return it_lr[1]
55 |
56 |
57 | class LinearIncreaseLR(BaseLR):
58 | def __init__(self, start_lr, end_lr, warm_iters):
59 | self._start_lr = start_lr
60 | self._end_lr = end_lr
61 | self._warm_iters = warm_iters
62 | self._delta_lr = (end_lr - start_lr) / warm_iters
63 |
64 | def get_lr(self, cur_epoch):
65 | return self._start_lr + cur_epoch * self._delta_lr
--------------------------------------------------------------------------------
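A minimal sketch of the WarmUpPolyLR schedule used in train.py: a linear ramp during warm-up followed by polynomial decay. The iteration counts, lr_power and warm-up length below are illustrative assumptions, not the repository's config values:

from util.lr_policy import WarmUpPolyLR  # assumed repo layout

base_lr, lr_power = 6e-5, 0.9                    # lr_power is an assumed value
niters_per_epoch, nepochs, warm_up_epoch = 100, 500, 10

lr_policy = WarmUpPolyLR(base_lr, lr_power,
                         total_iters=nepochs * niters_per_epoch,
                         warmup_steps=niters_per_epoch * warm_up_epoch)

for it in [0, 500, 1000, 25000, 49999]:
    print(it, lr_policy.get_lr(it))
# linear ramp-up until warmup_steps, then polynomial decay towards zero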
/util/util.py:
--------------------------------------------------------------------------------
1 | # By Yuxiang Sun, Dec. 4, 2020
2 | # Email: sun.yuxiang@outlook.com
3 |
4 | import numpy as np
5 | from PIL import Image
6 | import torch
7 |
8 | # class ids for the palette below: 0:unlabelled, 1:sat, 2:fanban, 3:lidar, 4:penzui
9 | def get_palette():
10 | unlabelled = [0,0,0]
11 | sat = [153,0,0]
12 | fanban = [0,153,0]
13 | lidar = [0,0,153]
14 | penzui = [153,153,0]
15 | palette = np.array([unlabelled,sat, fanban, lidar, penzui])
16 | return palette
17 |
18 | def visualize(image_name, predictions, weight_name):
19 | palette = get_palette()
20 | for (i, pred) in enumerate(predictions):
21 | pred = predictions[i].cpu().numpy()
22 | img = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
23 | for cid in range(0, len(palette)): # fix the mistake from the MFNet code on Dec.27, 2019
24 | img[pred == cid] = palette[cid]
25 | img = Image.fromarray(np.uint8(img))
26 | img.save('runs/Pred_' + weight_name + '_' + image_name[i] + '.png')
27 |
28 | def compute_results(conf_total):
29 | n_class = conf_total.shape[0]
30 | consider_unlabeled = True # must consider the unlabeled, please set it to True
31 | if consider_unlabeled is True:
32 | start_index = 0
33 | else:
34 | start_index = 1
35 | precision_per_class = np.zeros(n_class)
36 | recall_per_class = np.zeros(n_class)
37 | iou_per_class = np.zeros(n_class)
38 | F1_per_class = np.zeros(n_class)
39 | for cid in range(start_index, n_class): # cid: class id
40 | if conf_total[start_index:, cid].sum() == 0:
41 | precision_per_class[cid] = np.nan
42 | else:
43 | precision_per_class[cid] = float(conf_total[cid, cid]) / float(conf_total[start_index:, cid].sum()) # precision = TP/TP+FP
44 | if conf_total[cid, start_index:].sum() == 0:
45 | recall_per_class[cid] = np.nan
46 | else:
47 | recall_per_class[cid] = float(conf_total[cid, cid]) / float(conf_total[cid, start_index:].sum()) # recall = TP/TP+FN
48 | if (conf_total[cid, start_index:].sum() + conf_total[start_index:, cid].sum() - conf_total[cid, cid]) == 0:
49 | iou_per_class[cid] = np.nan
50 | else:
51 | iou_per_class[cid] = float(conf_total[cid, cid]) / float((conf_total[cid, start_index:].sum() + conf_total[start_index:, cid].sum() - conf_total[cid, cid])) # IoU = TP/TP+FP+FN
52 |         if np.isnan(recall_per_class[cid]) or np.isnan(precision_per_class[cid]) or (precision_per_class[cid] == 0) or (recall_per_class[cid] == 0):  # note: '== np.nan' is always False, so use np.isnan
53 | F1_per_class[cid] = np.nan
54 | else :
55 | F1_per_class[cid] = 2 / (1/precision_per_class[cid] +1/recall_per_class[cid])
56 |
57 | return precision_per_class, recall_per_class, iou_per_class,F1_per_class
58 |
59 | def label2onehot(label,n_class,input_w,input_h):
60 | label = torch.tensor(label).unsqueeze(0)
61 | print(label.size())
62 |     onehot = torch.zeros(n_class, input_h, input_w).long() # create the zero template first
63 |     print(onehot.size())
64 |     onehot.scatter_(0, label, 1).float() # scatter 1s along the class dimension to build the one-hot channels; scatter_ does not need deep study, knowing this usage is enough
65 |
66 | return onehot
67 |
68 |
--------------------------------------------------------------------------------
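Finally, a minimal sketch showing how compute_results turns an accumulated confusion matrix into per-class metrics (the 3x3 matrix below is made-up data for illustration only):

import numpy as np
from util.util import compute_results  # assumed repo layout

# rows: ground truth, columns: prediction (toy counts)
conf = np.array([[90,  5,  5],
                 [10, 30,  0],
                 [ 5,  0, 25]], dtype=np.float64)

precision, recall, iou, f1 = compute_results(conf)
for cid, name in enumerate(["unlabeled", "pothole", "crack"]):
    print("%s: precision=%.3f, recall=%.3f, IoU=%.3f, F1=%.3f"
          % (name, precision[cid], recall[cid], iou[cid], f1[cid]))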