├── .devcontainer ├── devcontainer.json └── docker-compose.yml ├── Dockerfile ├── LICENSE ├── README.md ├── docs ├── IV2023_AMFNet.pdf ├── overall.png └── results.jpg ├── model ├── AMFNet.py └── __init__.py ├── run_demo.py └── util ├── RGB_Depth_Fusion_dataset.py ├── __init__.py ├── __pycache__ ├── MF_dataset.cpython-36.pyc ├── Po_dataset.cpython-36.pyc ├── RGB_Depth_Fusion_dataset.cpython-36.pyc ├── RGB_Depth_Fusion_dataset.cpython-38.pyc ├── RGB_Depth_dataset.cpython-36.pyc ├── RGB_Depth_dataset.cpython-38.pyc ├── RGB_Disp_dataset.cpython-36.pyc ├── RT_dataset.cpython-36.pyc ├── __init__.cpython-36.pyc ├── __init__.cpython-38.pyc ├── augmentation.cpython-36.pyc ├── augmentation.cpython-38.pyc ├── lr_policy.cpython-36.pyc ├── util.cpython-36.pyc └── util.cpython-38.pyc ├── augmentation.py └── util.py /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at: 2 | // https://github.com/microsoft/vscode-dev-containers/tree/v0.163.1/containers/docker-from-docker-compose 3 | { 4 | "name": "amfnet", 5 | "dockerComposeFile": "docker-compose.yml", 6 | "service": "amfnet", 7 | "workspaceFolder": "/workspace" 8 | 9 | // // Use this environment variable if you need to bind mount your local source code into a new container. 10 | // "remoteEnv": { 11 | // "LOCAL_WORKSPACE_FOLDER": "${localWorkspaceFolder}" 12 | // }, 13 | 14 | // // Set *default* container specific settings.json values on container create. 15 | // "settings": { 16 | // "terminal.integrated.shell.linux": "/bin/bash" 17 | // }, 18 | 19 | // // Add the IDs of extensions you want installed when the container is created. 20 | // "extensions": [ 21 | // "ms-azuretools.vscode-docker" 22 | // ], 23 | 24 | // // Use 'forwardPorts' to make a list of ports inside the container available locally. 25 | // // "forwardPorts": [], 26 | 27 | // // Use 'postCreateCommand' to run commands after the container is created. 28 | // // "postCreateCommand": "docker --version", 29 | 30 | // // Comment out connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. 31 | // "remoteUser": "vscode" 32 | } -------------------------------------------------------------------------------- /.devcontainer/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '2.3' 2 | services: 3 | amfnet: 4 | # Uncomment the next line to use a non-root user for all processes. You can also 5 | # simply use the "remoteUser" property in devcontainer.json if you just want VS Code 6 | # and its sub-processes (terminals, tasks, debugging) to execute as the user. On Linux, 7 | # you may need to update USER_UID and USER_GID in .devcontainer/Dockerfile to match your 8 | # user if not 1000. See https://aka.ms/vscode-remote/containers/non-root for details. 9 | # user: vscode 10 | runtime: nvidia 11 | image: docker_image_test # The name of the docker image 12 | ports: 13 | - '12347:6006' 14 | volumes: 15 | # Update this to wherever you want VS Code to mount the folder of your project 16 | - ..:/workspace:cached # Do not change! 17 | # - /home/sun/somefolder/:/somefolder # folder_in_local_computer:folder_in_docker_container 18 | 19 | # Forwards the local Docker socket to the container. 
20 | - /var/run/docker.sock:/var/run/docker-host.sock 21 | shm_size: 32g 22 | devices: 23 | - /dev/nvidia0 24 | 25 | # Uncomment the next four lines if you will use a ptrace-based debuggers like C++, Go, and Rust. 26 | # cap_add: 27 | # - SYS_PTRACE 28 | # security_opt: 29 | # - seccomp:unconfined 30 | 31 | # Overrides default command so things don't shut down after the process ends. 32 | #entrypoint: /usr/local/share/docker-init.sh 33 | command: sleep infinity -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu20.04 2 | ENV DEBIAN_FRONTEND=noninteractive 3 | #RUN apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 'A4B469963BF863CC' 4 | 5 | RUN apt-key del 7fa2af80 6 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub 7 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu2004/x86_64/7fa2af80.pub 8 | 9 | 10 | RUN apt-get update && apt-get install -y vim python3 python3-pip 11 | 12 | RUN pip3 install --upgrade pip 13 | RUN pip3 install setuptools>=40.3.0 14 | 15 | RUN pip3 install -U scipy scikit-learn 16 | RUN pip3 install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 17 | RUN pip3 install torchsummary 18 | RUN pip3 install tensorboard==2.11.0 19 | RUN pip3 install einops 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AMFNet-PyTorch 2 | 3 | The official pytorch implementation of **Adaptive-Mask Fusion Network for Segmentation of Drivable Road and Negative Obstacle With Untrustworthy Features**. 4 | 5 | Paper link: https://arxiv.org/abs/2304.13979 6 | 7 | We test our code in Python 3.7, CUDA 11.1, cuDNN 8, and PyTorch 1.7.1. We provide `Dockerfile` to build the docker image we used. You can modify the `Dockerfile` as you want. 8 |
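A quick sanity check (an illustrative snippet, not part of the repository) that can be run inside the container to confirm that PyTorch and the GPU are visible before starting the demo or training:
```
# Illustrative snippet, not part of the original code.
# Run inside the container to confirm the PyTorch build and GPU visibility.
import torch
import torchvision

print('torch:', torch.__version__)
print('torchvision:', torchvision.__version__)
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))
```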
9 | ![The overall architecture of AMFNet](./docs/overall.png)
10 | 
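The key idea, as implemented in `Fusion_V2` of `model/AMFNet.py`, is to trust depth features only where the depth map is valid, with a learned weight controlling how much the valid depth features contribute. The following is a simplified sketch of that fusion rule (tensor names and numbers are made up for illustration; the real module additionally applies channel and position attention to the fused features):
```
import torch

# Simplified sketch of the adaptive-mask fusion rule in model/AMFNet.py (Fusion_V2).
# m is the per-pixel depth-validity mask (1 where depth > threshold, 0 otherwise),
# w_depth is the learned weight of the depth branch (second output of the softmax).
def adaptive_mask_fusion(rgb_feat, depth_feat, m, w_depth):
    mask_depth = m * w_depth        # untrustworthy pixels (m == 0) get zero depth weight
    mask_rgb = 1.0 - mask_depth     # complement, so the two weights sum to one per pixel
    return rgb_feat * mask_rgb + depth_feat * mask_depth

# Toy example: a 2x2 feature map where one depth pixel is untrustworthy.
rgb_feat = torch.ones(1, 1, 2, 2)
depth_feat = torch.zeros(1, 1, 2, 2)
m = torch.tensor([[[[1., 1.], [0., 1.]]]])
print(adaptive_mask_fusion(rgb_feat, depth_feat, m, w_depth=0.6))
# Valid pixels blend the RGB and depth features (0.4 here); the invalid pixel keeps the pure RGB feature (1.0).
```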
11 | 
12 | # Demo
13 | 
14 | 
18 | 
19 | # Introduction
20 | AMFNet is a multi-modal fusion network for the semantic segmentation of drivable roads and negative obstacles.
21 | # Dataset
22 | We developed the [**NPO dataset**](https://ieeexplore.ieee.org/document/10114585/) to build our **DRNO dataset**. You can download the [**DRNO dataset**](https://pan.baidu.com/s/1ca5vx5QavXNewws9scPqRA?pwd=drno).
23 | # Pretrained weights
24 | The pretrained weights of AMFNet can be downloaded from [**here**](https://pan.baidu.com/s/1ca5vx5QavXNewws9scPqRA?pwd=drno).
25 | # Usage
26 | * Clone this repo
27 | ```
28 | $ git clone https://github.com/lab-sun/AMFNet.git
29 | ```
30 | * Build the docker image
31 | ```
32 | $ cd ~/AMFNet
33 | $ docker build -t docker_image_amfnet .
34 | ```
35 | * Download the dataset (the folder layout expected by the data loader is shown in a note near the end of this README)
36 | ```
37 | $ (You should be in the AMFNet folder)
38 | $ mkdir ./dataset
39 | $ cd ./dataset
40 | $ (download our preprocessed dataset.zip into this folder)
41 | $ unzip -d .. dataset.zip
42 | ```
43 | * To reproduce our results, you need to download our pretrained weights.
44 | ```
45 | $ (You should be in the AMFNet folder)
46 | $ mkdir ./weights_backup/AMFNet
47 | $ cd ./weights_backup/AMFNet
48 | $ (download our pretrained weight file final.pth into this folder)
49 | $ cd ../..
50 | $ docker run -it --shm-size 8G -p 1234:6006 --name docker_container_amfnet --gpus all -v ~/AMFNet:/workspace docker_image_amfnet
51 | $ (now you should be inside the docker container)
52 | $ cd /workspace
53 | $ python3 run_demo.py
54 | ```
55 | The results will be saved in the `./runs` folder.
56 | * To train AMFNet
57 | ```
58 | $ (You should be in the AMFNet folder)
59 | $ docker run -it --shm-size 8G -p 1234:6006 --name docker_container_amfnet --gpus all -v ~/AMFNet:/workspace docker_image_amfnet
60 | $ (now you should be inside the docker container)
61 | $ cd /workspace
62 | $ python3 train.py
63 | ```
64 | * To see the training process
65 | ```
66 | $ (open another terminal)
67 | $ docker exec -it docker_container_amfnet /bin/bash
68 | $ cd /workspace
69 | $ tensorboard --bind_all --logdir=./runs/tensorboard_log/
70 | $ (open http://localhost:1234 in your browser to see the tensorboard)
71 | ```
72 | The training results will be saved in the `./runs` folder.
73 | Note: please change the smoothing factor on the Tensorboard webpage to `0.999`, otherwise you may not be able to see the trends through the noisy plots. If you get the error `docker: Error response from daemon: could not select device driver`, please first install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) on your computer.
74 | 
75 | # Citation
76 | If you use AMFNet in your academic work, please cite:
77 | ```
78 | @INPROCEEDINGS{feng2023amfnet,
79 | author={Zhen Feng and Yuchao Feng and Yanning Guo and Yuxiang Sun},
80 | booktitle={2023 IEEE Intelligent Vehicles Symposium (IV)},
81 | title={Adaptive-Mask Fusion Network for Segmentation of Drivable Road and Negative Obstacle With Untrustworthy Features},
82 | year={2023},
83 | volume={},
84 | number={},
85 | pages={},
86 | doi={}}
87 | ```
88 | 
89 | # Demo
90 | 
91 | ![Segmentation results of AMFNet](./docs/results.jpg)
92 | 
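Note on the dataset layout (for the "Download the dataset" step above): `util/RGB_Depth_Fusion_dataset.py` reads the sample names from `<data_dir>/<split>.txt` (e.g., `./dataset/test.txt`) and loads the RGB, depth, and label images from fixed sub-folders. The snippet below (the sample name `0001` is hypothetical) mirrors how the loader builds its file paths, so the unzipped dataset should follow this structure:
```
import os

# Illustrative only: mirrors the path construction in util/RGB_Depth_Fusion_dataset.py.
data_dir = './dataset'   # --data_dir default in run_demo.py
name = '0001'            # hypothetical sample name listed in ./dataset/test.txt
for folder, head in [('left', 'left'), ('depth', 'depth'), ('labels', 'label')]:
    print(os.path.join(data_dir, '%s/%s%s.png' % (folder, head, name)))
# ./dataset/left/left0001.png    -> RGB image
# ./dataset/depth/depth0001.png  -> depth image (also used to build the validity mask)
# ./dataset/labels/label0001.png -> segmentation label
```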
93 | 94 | # Acknowledgement 95 | Some of the codes are borrowed from [RTFNet](https://github.com/yuxiangsun/RTFNet) 96 | 97 | Contact: yx.sun@polyu.edu.hk 98 | 99 | Website: https://yuxiangsun.github.io/ 100 | -------------------------------------------------------------------------------- /docs/IV2023_AMFNet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/docs/IV2023_AMFNet.pdf -------------------------------------------------------------------------------- /docs/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/docs/overall.png -------------------------------------------------------------------------------- /docs/results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/docs/results.jpg -------------------------------------------------------------------------------- /model/AMFNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules import padding 4 | import torchvision.models as models 5 | import torch.nn.functional as F 6 | 7 | from torch import nn, einsum 8 | from einops import rearrange, repeat 9 | 10 | class AMFNet(nn.Module): 11 | 12 | def __init__(self, n_class): 13 | super(AMFNet, self).__init__() 14 | 15 | resnet_raw_model1 = models.resnet50(pretrained=True) 16 | resnet_raw_model2 = models.resnet50(pretrained=True) 17 | self.inplanes = 2048 18 | 19 | ######## Thermal ENCODER ######## 20 | 21 | self.encoder_thermal_conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 22 | self.encoder_thermal_conv1.weight.data = torch.unsqueeze(torch.mean(resnet_raw_model1.conv1.weight.data, dim=1), dim=1) 23 | self.encoder_thermal_bn1 = resnet_raw_model1.bn1 24 | self.encoder_thermal_relu = resnet_raw_model1.relu 25 | self.encoder_thermal_maxpool = resnet_raw_model1.maxpool 26 | self.encoder_thermal_layer1 = resnet_raw_model1.layer1 27 | self.encoder_thermal_layer2 = resnet_raw_model1.layer2 28 | self.encoder_thermal_layer3 = resnet_raw_model1.layer3 29 | self.encoder_thermal_layer4 = BottleStack(dim=1024,fmap_size=(18,32),dim_out=2048,proj_factor = 4,num_layers=3,heads=4,dim_head=512) 30 | 31 | ######## RGB ENCODER ######## 32 | 33 | self.encoder_rgb_conv1 = resnet_raw_model2.conv1 34 | self.encoder_rgb_bn1 = resnet_raw_model2.bn1 35 | self.encoder_rgb_relu = resnet_raw_model2.relu 36 | self.encoder_rgb_maxpool = resnet_raw_model2.maxpool 37 | self.encoder_rgb_layer1 = resnet_raw_model2.layer1 38 | self.encoder_rgb_layer2 = resnet_raw_model2.layer2 39 | self.encoder_rgb_layer3 = resnet_raw_model2.layer3 40 | self.encoder_rgb_layer4= BottleStack(dim=1024,fmap_size=(18,32),dim_out=2048,proj_factor = 4,num_layers=3,heads=4,dim_head=512) 41 | 42 | ######## DECODER ######## 43 | self.upconv5 = upbolckV2(cin=2048,cout=1024) 44 | self.upconv4 = upbolckV2(cin=1024,cout=512) 45 | self.upconv3 = upbolckV2(cin=512,cout=256) 46 | self.upconv2 = upbolckV2(cin=256,cout=128) 47 | self.upconv1 = upbolckV2(cin=128,cout=n_class) 48 | 49 | ######## FUSION ######## 50 | self.fusion1 = Fusion_V2(in_channels=64,med_channels=32,channel=128) 51 | self.fusion2 = Fusion_V2(in_channels=256,med_channels=128,channel=512) 52 
| self.fusion3 = Fusion_V2(in_channels=512,med_channels=256,channel=1024) 53 | self.fusion4 = Fusion_V2(in_channels=1024,med_channels=512,channel=2048) 54 | self.fusion5 = Fusion_V2(in_channels=2048,med_channels=1024,channel=4096) 55 | 56 | self.skip_tranform = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,stride=1,padding=1) 57 | 58 | def forward(self, input): 59 | 60 | rgb = input[:,:3] 61 | thermal = input[:,3:4] 62 | 63 | mask = input[:,4:5] 64 | 65 | verbose = False 66 | 67 | # encoder 68 | 69 | ###################################################################### 70 | 71 | if verbose: print("rgb.size() original: ", rgb.size()) 72 | if verbose: print("thermal.size() original: ", thermal.size()) 73 | 74 | ###################################################################### 75 | 76 | rgb = self.encoder_rgb_conv1(rgb) 77 | if verbose: print("rgb.size() after conv1: ", rgb.size()) 78 | rgb = self.encoder_rgb_bn1(rgb) 79 | if verbose: print("rgb.size() after bn1: ", rgb.size()) 80 | rgb = self.encoder_rgb_relu(rgb) 81 | if verbose: print("rgb.size() after relu: ", rgb.size()) 82 | 83 | thermal = self.encoder_thermal_conv1(thermal) 84 | if verbose: print("thermal.size() after conv1: ", thermal.size()) 85 | thermal = self.encoder_thermal_bn1(thermal) 86 | if verbose: print("thermal.size() after bn1: ", thermal.size()) 87 | thermal = self.encoder_thermal_relu(thermal) 88 | if verbose: print("thermal.size() after relu: ", thermal.size()) 89 | 90 | rgb = self.fusion1(rgb,thermal,mask) 91 | skip1 = rgb 92 | rgb = self.encoder_rgb_maxpool(rgb) 93 | if verbose: print("rgb.size() after maxpool: ", rgb.size()) 94 | thermal = self.encoder_thermal_maxpool(thermal) 95 | if verbose: print("thermal.size() after maxpool: ", thermal.size()) 96 | ###################################################################### 97 | rgb = self.encoder_rgb_layer1(rgb) 98 | if verbose: print("rgb.size() after layer1: ", rgb.size()) 99 | thermal = self.encoder_thermal_layer1(thermal) 100 | if verbose: print("thermal.size() after layer1: ", thermal.size()) 101 | rgb = self.fusion2(rgb,thermal,mask) 102 | skip2 = rgb 103 | ###################################################################### 104 | rgb = self.encoder_rgb_layer2(rgb) 105 | if verbose: print("rgb.size() after layer2: ", rgb.size()) 106 | thermal = self.encoder_thermal_layer2(thermal) 107 | if verbose: print("thermal.size() after layer2: ", thermal.size()) 108 | rgb = self.fusion3(rgb,thermal,mask) 109 | skip3 = rgb 110 | ###################################################################### 111 | rgb = self.encoder_rgb_layer3(rgb) 112 | if verbose: print("rgb.size() after layer3: ", rgb.size()) 113 | thermal = self.encoder_thermal_layer3(thermal) 114 | if verbose: print("thermal.size() after layer3: ", thermal.size()) 115 | rgb = self.fusion4(rgb,thermal,mask) 116 | skip4 = rgb 117 | if verbose: print("rgb.size() after fusion_con3d4: ", rgb.size()) 118 | ###################################################################### 119 | rgb = self.encoder_rgb_layer4(rgb) 120 | if verbose: print("thermal.size() after layer4: ", rgb.size()) 121 | thermal = self.encoder_thermal_layer4(thermal) 122 | if verbose: print("thermal.size() after layer4: ", thermal.size()) 123 | fuse = self.fusion5(rgb,thermal,mask) 124 | 125 | ###################################################################### 126 | 127 | # decoder 128 | fuse = self.upconv5(fuse) 129 | if verbose: print("fuse after deconv1: ", fuse.size()) # (30, 40) 130 | fuse = fuse+skip4 131 | 132 | fuse = 
self.upconv4(fuse) 133 | if verbose: print("fuse after deconv2: ", fuse.size()) # (60, 80) 134 | fuse = fuse+skip3 135 | 136 | fuse = self.upconv3(fuse) 137 | if verbose: print("fuse after deconv3: ", fuse.size()) # (120, 160) 138 | fuse = fuse+skip2 139 | fuse = self.upconv2(fuse) 140 | if verbose: print("fuse after deconv4: ", fuse.size()) # (240, 320) 141 | skip1 = self.skip_tranform(skip1) 142 | fuse = fuse+skip1 143 | fuse = self.upconv1(fuse) 144 | if verbose: print("fuse after deconv5: ", fuse.size()) # (480, 640) 145 | 146 | return fuse 147 | 148 | class TransBottleneck(nn.Module): 149 | 150 | def __init__(self, inplanes, planes, stride=1, upsample=None): 151 | super(TransBottleneck, self).__init__() 152 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 153 | self.bn1 = nn.BatchNorm2d(planes) 154 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 155 | self.bn2 = nn.BatchNorm2d(planes) 156 | 157 | if upsample is not None and stride != 1: 158 | self.conv3 = nn.ConvTranspose2d(planes, planes, kernel_size=2, stride=stride, padding=0, bias=False) 159 | else: 160 | self.conv3 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 161 | 162 | self.bn3 = nn.BatchNorm2d(planes) 163 | self.relu = nn.ReLU(inplace=True) 164 | self.upsample = upsample 165 | self.stride = stride 166 | 167 | for m in self.modules(): 168 | if isinstance(m, nn.Conv2d): 169 | nn.init.xavier_uniform_(m.weight.data) 170 | elif isinstance(m, nn.ConvTranspose2d): 171 | nn.init.xavier_uniform_(m.weight.data) 172 | elif isinstance(m, nn.BatchNorm2d): 173 | m.weight.data.fill_(1) 174 | m.bias.data.zero_() 175 | 176 | def forward(self, x): 177 | residual = x 178 | 179 | out = self.conv1(x) 180 | out = self.bn1(out) 181 | out = self.relu(out) 182 | 183 | out = self.conv2(out) 184 | out = self.bn2(out) 185 | out = self.relu(out) 186 | 187 | out = self.conv3(out) 188 | out = self.bn3(out) 189 | 190 | if self.upsample is not None: 191 | residual = self.upsample(x) 192 | 193 | out += residual 194 | out = self.relu(out) 195 | 196 | return out 197 | 198 | class upbolckV2(nn.Module): 199 | def __init__(self,cin,cout): 200 | super().__init__() 201 | 202 | self.conv1 = nn.Conv2d(cin,cin//2,kernel_size=3,stride=1,padding=1) 203 | self.bn1 = nn.BatchNorm2d(cin//2) 204 | self.relu1 = nn.ReLU(inplace=True) 205 | 206 | self.conv2 = nn.Conv2d(cin//2,cin//2,kernel_size=3,stride=1,padding=1) 207 | self.bn2 = nn.BatchNorm2d(cin//2) 208 | self.relu2 = nn.ReLU(inplace=True) 209 | 210 | self.conv3 = nn.Conv2d(cin//2,cin//2,kernel_size=3,stride=1,padding=1) 211 | self.bn3 = nn.BatchNorm2d(cin//2) 212 | self.relu3 = nn.ReLU(inplace=True) 213 | 214 | self.shortcutconv = nn.Conv2d(cin,cin//2,kernel_size=1,stride=1) 215 | self.shortcutbn = nn.BatchNorm2d(cin//2) 216 | self.shortcutrelu = nn.ReLU(inplace=True) 217 | 218 | 219 | self.se = SE_fz(in_channels=cin//2,med_channels=cin//4) 220 | 221 | self.transconv = nn.ConvTranspose2d(cin//2,cout,kernel_size=2, stride=2, padding=0, bias=False) 222 | self.transbn = nn.BatchNorm2d(cout) 223 | self.transrelu = nn.ReLU(inplace=True) 224 | 225 | def forward(self,x): 226 | 227 | fusion = self.conv1(x) 228 | fusion = self.bn1(fusion) 229 | fusion = self.relu1(fusion) 230 | 231 | sc0 = fusion 232 | 233 | 234 | fusion = self.conv2(fusion) 235 | fusion = self.bn2(fusion) 236 | fusion = self.relu2(fusion) 237 | 238 | fusion = sc0 + fusion 239 | 240 | fusion = self.conv3(fusion) 241 | fusion = self.bn3(fusion) 242 | fusion = 
self.relu3(fusion) 243 | 244 | sc = self.shortcutconv(x) 245 | sc = self.shortcutbn(sc) 246 | sc = self.shortcutrelu(sc) 247 | 248 | fusion = fusion+sc 249 | 250 | fusion = self.se(fusion) 251 | 252 | 253 | fusion = self.transconv(fusion) 254 | fusion = self.transbn(fusion) 255 | fusion = self.transrelu(fusion) 256 | 257 | return fusion 258 | 259 | class SE_fz(nn.Module): 260 | def __init__(self, in_channels, med_channels): 261 | super(SE_fz, self).__init__() 262 | 263 | self.average = nn.AdaptiveAvgPool2d(1) 264 | 265 | self.fc1 = nn.Linear(in_channels,med_channels) 266 | self.bn1 = nn.BatchNorm1d(med_channels) 267 | self.relu = nn.ReLU() 268 | self.fc2 = nn.Linear(med_channels,in_channels) 269 | self.sg = nn.Sigmoid() 270 | 271 | def forward(self,input): 272 | x = input 273 | x = self.average(input) 274 | x = x.squeeze(2) 275 | x = x.squeeze(2) 276 | x = self.fc1(x) 277 | x= self.bn1(x) 278 | x = self.relu(x) 279 | x = self.fc2(x) 280 | x = self.sg(x) 281 | x = x.unsqueeze(2) 282 | x = x.unsqueeze(3) 283 | out = torch.mul(input,x) 284 | return out 285 | 286 | 287 | class Fusion_V2(nn.Module): 288 | def __init__(self, in_channels, med_channels,channel): 289 | super().__init__() 290 | self.Weight = weight(linearhidden=channel) 291 | 292 | self.pam = PAM(channel=in_channels) 293 | self.cam = SE_fz(in_channels=in_channels,med_channels=med_channels) 294 | 295 | def forward(self, rgb, thermal, mask): 296 | 297 | weights = self.Weight(rgb,thermal) 298 | 299 | B,C,H,W = rgb.size() 300 | 301 | mask = F.interpolate(mask,[H,W],mode="nearest") 302 | 303 | mask_rgb = torch.ones(B,1,H,W) 304 | if mask.is_cuda: 305 | mask_rgb = mask_rgb.cuda(mask.device) 306 | 307 | mask_thermal = torch.mul(mask.reshape(B,-1),weights[:,1].reshape((B,1))).reshape(B,1,H,W) 308 | 309 | mask_rgb = mask_rgb-mask_thermal 310 | 311 | fusion = rgb * mask_rgb + thermal * mask_thermal 312 | 313 | fusion = self.cam(fusion) 314 | fusion = self.pam(fusion) 315 | 316 | 317 | return fusion 318 | 319 | 320 | 321 | class weight(nn.Module): 322 | def __init__(self, linearhidden): 323 | super().__init__() 324 | self.adapool = nn.AdaptiveAvgPool2d(output_size=1) 325 | self.fc1 = nn.Linear(linearhidden,linearhidden//2) 326 | self.bn1 = nn.BatchNorm1d(linearhidden//2) 327 | self.relu1 = nn.ReLU(True) 328 | 329 | self.fc2 = nn.Linear(linearhidden//2,linearhidden//4) 330 | self.bn2 = nn.BatchNorm1d(linearhidden//4) 331 | self.relu2 = nn.ReLU(True) 332 | 333 | self.fc3 = nn.Linear(linearhidden//4,2) 334 | self.relu3 = nn.ReLU(True) 335 | self.sf = nn.Softmax(dim=1) 336 | def forward(self,rgb,thermal): 337 | 338 | x = torch.cat((rgb,thermal),dim=1) 339 | x = self.adapool(x) 340 | b = x.size(0) 341 | x=x.reshape(b,-1) 342 | x = self.fc1(x) 343 | x = self.bn1(x) 344 | x = self.relu1(x) 345 | 346 | x = self.fc2(x) 347 | x = self.bn2(x) 348 | x = self.relu2(x) 349 | 350 | x = self.fc3(x) 351 | x = self.relu3(x) 352 | x = self.sf(x) 353 | return x 354 | 355 | 356 | class PAM(nn.Module): 357 | """ Position Attention Module """ 358 | def __init__(self, channel): 359 | super(PAM, self).__init__() 360 | self.conv = nn.Conv2d(channel, 1, kernel_size=1) 361 | self.act = nn.Sigmoid() 362 | 363 | def forward(self, x): 364 | """ 365 | Args: 366 | x ([torch.tensor]): size N*C*H*W 367 | Returns: 368 | [torch.tensor]: size N*C*H*W 369 | """ 370 | _, c, _, _ = x.size() 371 | y = self.act(self.conv(x)) 372 | y = y.repeat(1, c, 1, 1) 373 | return x * y 374 | 375 | 376 | class Attention(nn.Module): 377 | def __init__( 378 | self, 379 | *, 380 | dim, 381 | 
fmap_size, 382 | heads = 4, 383 | dim_head = 128, 384 | rel_pos_emb = False 385 | ): 386 | super().__init__() 387 | self.heads = heads 388 | self.scale = dim_head ** -0.5 389 | inner_dim = heads * dim_head 390 | 391 | self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False) 392 | 393 | self.pos_emb = AbsPosEmb(fmap_size, dim_head) 394 | 395 | 396 | def forward(self, fmap): 397 | heads, b, c, h, w = self.heads, *fmap.shape 398 | 399 | q, k, v = self.to_qkv(fmap).chunk(3, dim = 1) 400 | q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> b h (x y) d', h = heads), (q, k, v)) 401 | 402 | sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 403 | sim += self.pos_emb(q) 404 | 405 | attn = sim.softmax(dim = -1) 406 | 407 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 408 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) 409 | 410 | return out 411 | 412 | class AbsPosEmb(nn.Module): 413 | def __init__( 414 | self, 415 | fmap_size, 416 | dim_head 417 | ): 418 | super().__init__() 419 | scale = dim_head ** -0.5 420 | self.scale = scale 421 | self.height = nn.Parameter(torch.randn(fmap_size[0], dim_head) * scale) 422 | self.width = nn.Parameter(torch.randn(fmap_size[1], dim_head) * scale) 423 | 424 | def forward(self, q): 425 | emb = rearrange(self.height, 'h d -> h () d') + rearrange(self.width, 'w d -> () w d') 426 | emb = rearrange(emb, ' h w d -> (h w) d') 427 | logits = einsum('b h i d, j d -> b h i j', q, emb) * self.scale 428 | return logits 429 | 430 | class BottleBlock(nn.Module): 431 | def __init__( 432 | self, 433 | *, 434 | dim, 435 | fmap_size, 436 | dim_out, 437 | proj_factor, 438 | downsample, 439 | heads = 4, 440 | dim_head = 128, 441 | rel_pos_emb = False, 442 | activation = nn.ReLU() 443 | ): 444 | super().__init__() 445 | 446 | # shortcut 447 | 448 | if dim != dim_out or downsample: #di yi bian de shi hou zhi xing 449 | kernel_size, stride, padding = (3, 2, 1) if downsample else (1, 1, 0) 450 | 451 | self.shortcut = nn.Sequential( 452 | nn.Conv2d(dim, dim_out, kernel_size, stride = stride, padding = padding, bias = False), 453 | nn.BatchNorm2d(dim_out), 454 | activation 455 | ) 456 | else: 457 | self.shortcut = nn.Identity() 458 | 459 | # contraction and expansion 460 | 461 | attention_dim = dim_out // proj_factor 462 | 463 | 464 | 465 | self.net = nn.Sequential( 466 | nn.Conv2d(dim, attention_dim, 1, bias = False), 467 | nn.BatchNorm2d(attention_dim), 468 | activation, 469 | Attention( 470 | dim = attention_dim, 471 | fmap_size = fmap_size, 472 | heads = heads, 473 | dim_head = dim_head, 474 | rel_pos_emb = rel_pos_emb 475 | ), 476 | nn.AvgPool2d((2, 2)) if downsample else nn.Identity(), 477 | nn.BatchNorm2d(heads*dim_head), 478 | activation, 479 | nn.Conv2d(heads*dim_head, dim_out, 1, bias = False), 480 | nn.BatchNorm2d(dim_out) 481 | ) 482 | 483 | # init last batch norm gamma to zero 484 | 485 | nn.init.zeros_(self.net[-1].weight) 486 | 487 | # final activation 488 | 489 | self.activation = activation 490 | 491 | def forward(self, x): 492 | 493 | shortcut = self.shortcut(x) 494 | 495 | x = self.net(x) 496 | 497 | 498 | x += shortcut 499 | return self.activation(x) 500 | 501 | # main bottle stack 502 | 503 | class BottleStack(nn.Module): 504 | def __init__( 505 | self, 506 | *, 507 | dim, 508 | fmap_size, 509 | dim_out = 2048, 510 | proj_factor = 4, 511 | num_layers = 3, 512 | heads = 4, 513 | dim_head = 128, 514 | downsample = True, 515 | rel_pos_emb = False, 516 | activation = nn.ReLU() 517 | ): 518 | super().__init__() 519 | self.dim = 
dim 520 | self.fmap_size = fmap_size 521 | 522 | layers = [] 523 | 524 | for i in range(num_layers): 525 | is_first = i == 0 526 | dim = (dim if is_first else dim_out) 527 | layer_downsample = is_first and downsample 528 | #layer_fmap_size = fmap_size 529 | layer_fmap_size = (fmap_size[0] // (2 if downsample and not is_first else 1),fmap_size[1] // (2 if downsample and not is_first else 1)) 530 | #layer_fmap_size = fmap_size[1] // (2 if downsample and not is_first else 1) 531 | layers.append(BottleBlock( 532 | dim = dim, 533 | fmap_size = layer_fmap_size, 534 | dim_out = dim_out, 535 | proj_factor = proj_factor, 536 | heads = heads, 537 | dim_head = dim_head, 538 | downsample = layer_downsample, 539 | rel_pos_emb = rel_pos_emb, 540 | activation = activation 541 | )) 542 | 543 | self.net = nn.Sequential(*layers) 544 | 545 | def forward(self, x): 546 | _, c, h, w = x.shape 547 | assert c == self.dim, f'channels of feature map {c} must match channels given at init {self.dim}' 548 | assert h == self.fmap_size[0] and w == self.fmap_size[1], f'height and width of feature map must match the fmap_size given at init {self.fmap_size}' 549 | return self.net(x) 550 | 551 | 552 | 553 | class SE_fz(nn.Module): 554 | def __init__(self, in_channels, med_channels): 555 | super(SE_fz, self).__init__() 556 | 557 | self.average = nn.AdaptiveAvgPool2d(1) 558 | 559 | self.fc1 = nn.Linear(in_channels,med_channels) 560 | self.bn1 = nn.BatchNorm1d(med_channels) 561 | self.relu = nn.ReLU() 562 | self.fc2 = nn.Linear(med_channels,in_channels) 563 | self.sg = nn.Sigmoid() 564 | 565 | def forward(self,input): 566 | x = input 567 | x = self.average(input) 568 | x = x.squeeze(2) 569 | x = x.squeeze(2) 570 | x = self.fc1(x) 571 | x= self.bn1(x) 572 | x = self.relu(x) 573 | x = self.fc2(x) 574 | x = self.sg(x) 575 | x = x.unsqueeze(2) 576 | x = x.unsqueeze(3) 577 | out = torch.mul(input,x) 578 | return out 579 | 580 | def unit_test(): 581 | num_minibatch = 2 582 | rgb = torch.randn(num_minibatch, 3, 288, 512).cuda(0) 583 | thermal = torch.randn(num_minibatch, 2, 288, 512).cuda(0) 584 | rtf_net = AMFNet(9).cuda(0) 585 | input = torch.cat((rgb, thermal), dim=1) 586 | output = rtf_net(input) 587 | print("output size:",output.size()) 588 | 589 | if __name__ == '__main__': 590 | unit_test() -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .AMFNet import AMFNet -------------------------------------------------------------------------------- /run_demo.py: -------------------------------------------------------------------------------- 1 | import os, argparse, time, datetime, sys, shutil, stat, torch 2 | import numpy as np 3 | from torch.autograd import Variable 4 | from torch.utils.data import DataLoader 5 | from util.RGB_Depth_Fusion_dataset import RGB_Depth_Fusion_dataset 6 | from util.util import compute_results, visualize 7 | from sklearn.metrics import confusion_matrix 8 | from scipy.io import savemat 9 | from model import AMFNet 10 | 11 | ############################################################################################# 12 | parser = argparse.ArgumentParser(description='Test with pytorch') 13 | ############################################################################################# 14 | 15 | parser.add_argument('--model_name', '-m', type=str, default='AMFNet') 16 | parser.add_argument('--weight_name', '-w', type=str, default='AMFNet') 17 | 
parser.add_argument('--file_name', '-f', type=str, default='final.pth') 18 | parser.add_argument('--dataset_split', '-d', type=str, default='test') 19 | parser.add_argument('--gpu', '-g', type=int, default=0) 20 | ############################################################################################# 21 | parser.add_argument('--img_height', '-ih', type=int, default=288) 22 | parser.add_argument('--img_width', '-iw', type=int, default=512) 23 | parser.add_argument('--num_workers', '-j', type=int, default=16) 24 | parser.add_argument('--n_class', '-nc', type=int, default=3) 25 | parser.add_argument('--data_dir', '-dr', type=str, default='./dataset/') 26 | parser.add_argument('--model_dir', '-wd', type=str, default='./weights_backup/') 27 | args = parser.parse_args() 28 | ############################################################################################# 29 | 30 | 31 | if __name__ == '__main__': 32 | 33 | torch.cuda.set_device(args.gpu) 34 | print("\nthe pytorch version:", torch.__version__) 35 | print("the gpu count:", torch.cuda.device_count()) 36 | print("the current used gpu:", torch.cuda.current_device(), '\n') 37 | 38 | # prepare save direcotry 39 | if os.path.exists("./runs"): 40 | print("previous \"./runs\" folder exist, will delete this folder") 41 | shutil.rmtree("./runs") 42 | os.makedirs("./runs") 43 | os.chmod("./runs", stat.S_IRWXO) # allow the folder created by docker read, written, and execuated by local machine 44 | model_dir = os.path.join(args.model_dir, args.weight_name) 45 | if os.path.exists(model_dir) is False: 46 | sys.exit("the %s does not exit." %(model_dir)) 47 | model_file = os.path.join(model_dir, args.file_name) 48 | if os.path.exists(model_file) is True: 49 | print('use the final model file.') 50 | else: 51 | sys.exit('no model file found.') 52 | print('testing %s: %s on GPU #%d with pytorch' % (args.model_name, args.weight_name, args.gpu)) 53 | 54 | conf_total = np.zeros((args.n_class, args.n_class)) 55 | 56 | model = eval(args.model_name)(n_class=args.n_class) 57 | 58 | if args.gpu >= 0: model.cuda(args.gpu) 59 | print('loading model file %s... 
' % model_file) 60 | pretrained_weight = torch.load(model_file, map_location = lambda storage, loc: storage.cuda(args.gpu)) 61 | own_state = model.state_dict() 62 | for name, param in pretrained_weight.items(): 63 | own_state[name].copy_(param) 64 | print('done!') 65 | 66 | batch_size = 1 67 | test_dataset = RGB_Depth_Fusion_dataset(data_dir=args.data_dir, split=args.dataset_split, input_h=args.img_height, input_w=args.img_width) 68 | test_loader = DataLoader( 69 | dataset = test_dataset, 70 | batch_size = batch_size, 71 | shuffle = False, 72 | num_workers = args.num_workers, 73 | pin_memory = True, 74 | drop_last = False 75 | ) 76 | ave_time_cost = 0.0 77 | 78 | model.eval() 79 | with torch.no_grad(): 80 | for it, (images, labels, names) in enumerate(test_loader): 81 | images = Variable(images).cuda(args.gpu) 82 | labels = Variable(labels).cuda(args.gpu) 83 | torch.cuda.synchronize() 84 | start_time = time.time() 85 | logits = model(images) 86 | torch.cuda.synchronize() 87 | end_time = time.time() 88 | 89 | if it>=5: # # ignore the first 5 frames 90 | ave_time_cost += (end_time-start_time) 91 | # convert tensor to numpy 1d array 92 | label = labels.cpu().numpy().squeeze().flatten() 93 | prediction = logits.argmax(1).cpu().numpy().squeeze().flatten() # prediction and label are both 1-d array, size: minibatch*640*480 94 | # generate confusion matrix frame-by-frame 95 | conf = confusion_matrix(y_true=label, y_pred=prediction, labels=[0,1,2]) # conf is an n_class*n_class matrix, vertical axis: groundtruth, horizontal axis: prediction 96 | conf_total += conf 97 | # save demo images 98 | visualize(image_name=names, predictions=logits.argmax(1), weight_name=args.weight_name) 99 | print("%s, %s, frame %d/%d, %s, time cost: %.2f ms, demo result saved." 100 | %(args.model_name, args.weight_name, it+1, len(test_loader), names, (end_time-start_time)*1000)) 101 | 102 | precision_per_class, recall_per_class, iou_per_class,F1_per_class = compute_results(conf_total) 103 | conf_total_matfile = os.path.join("./runs", 'conf_'+args.weight_name+'.mat') 104 | savemat(conf_total_matfile, {'conf': conf_total}) # 'conf' is the variable name when loaded in Matlab 105 | 106 | print('\n###########################################################################') 107 | print('\n%s: %s test results (with batch size %d) on %s using %s:' %(args.model_name, args.weight_name, batch_size, datetime.date.today(), torch.cuda.get_device_name(args.gpu))) 108 | print('\n* the tested dataset name: %s' % args.dataset_split) 109 | print('* the tested image count: %d' % len(test_loader)) 110 | print('* the tested image size: %d*%d' %(args.img_height, args.img_width)) 111 | print('* the weight name: %s' %args.weight_name) 112 | print('* the file name: %s' %args.file_name) 113 | print("* iou per class: \n unlabeled: %.6f, road: %.6f, negative: %.6f" \ 114 | %(iou_per_class[0]*100, iou_per_class[1]*100, iou_per_class[2]*100)) 115 | print("* recall per class: \n unlabeled: %.6f, road: %.6f, negative: %.6f" \ 116 | %(recall_per_class[0]*100, recall_per_class[1]*100, recall_per_class[2]*100)) 117 | print("* pre per class: \n unlabeled: %.6f, road: %.6f, negative: %.6f" \ 118 | %(precision_per_class[0]*100, precision_per_class[1]*100, precision_per_class[2]*100)) 119 | print("* F1 per class: \n unlabeled: %.6f, road: %.6f, negative: %.6f" \ 120 | %(F1_per_class[0]*100, F1_per_class[1]*100, F1_per_class[2]*100)) 121 | 122 | print("\n* average values (np.mean(x)): \n iou: %.6f, recall: %.6f, pre: %.6f, F1: %.6f" \ 123 | 
%(iou_per_class.mean()*100,recall_per_class.mean()*100, precision_per_class.mean()*100,F1_per_class.mean()*100)) 124 | print("* average values (np.mean(np.nan_to_num(x))): \n iou: %.6f, recall: %.6f, pre: %.6f, F1: %.6f" \ 125 | %(np.mean(np.nan_to_num(iou_per_class))*100, np.mean(np.nan_to_num(recall_per_class))*100, np.mean(np.nan_to_num(precision_per_class))*100, np.mean(np.nan_to_num(F1_per_class))*100)) 126 | print('\n* the average time cost per frame (with batch size %d): %.2f ms, namely, the inference speed is %.2f fps' %(batch_size, ave_time_cost*1000/(len(test_loader)-5), 1.0/(ave_time_cost/(len(test_loader)-5)))) # ignore the first 10 frames 127 | 128 | print('\n###########################################################################') -------------------------------------------------------------------------------- /util/RGB_Depth_Fusion_dataset.py: -------------------------------------------------------------------------------- 1 | import os, torch 2 | from torch.utils.data.dataset import Dataset 3 | import numpy as np 4 | import PIL 5 | 6 | class RGB_Depth_Fusion_dataset(Dataset): 7 | 8 | def __init__(self, data_dir, split, input_h=288, input_w=512 ,transform=[],threshold = 0): 9 | super(RGB_Depth_Fusion_dataset, self).__init__() 10 | 11 | with open(os.path.join(data_dir, split+'.txt'), 'r') as f: 12 | self.names = [name.strip() for name in f.readlines()] 13 | 14 | self.data_dir = data_dir 15 | self.split = split 16 | self.input_h = input_h 17 | self.input_w = input_w 18 | self.transform = transform 19 | self.n_data = len(self.names) 20 | self.threshold = threshold 21 | 22 | def read_image(self, name, folder,head): 23 | file_path = os.path.join(self.data_dir, '%s/%s%s.png' % (folder, head,name)) 24 | image = np.asarray(PIL.Image.open(file_path)) 25 | return image 26 | 27 | def __getitem__(self, index): 28 | name = self.names[index] 29 | image = self.read_image(name, 'left','left') 30 | label = self.read_image(name, 'labels','label') 31 | depth = self.read_image(name, 'depth','depth') 32 | 33 | image = np.asarray(PIL.Image.fromarray(image).resize((self.input_w, self.input_h))) 34 | image = image.astype('float32') 35 | image = np.transpose(image, (2,0,1))/255.0 36 | depth = np.asarray(PIL.Image.fromarray(depth).resize((self.input_w, self.input_h))) 37 | depth = depth.astype('float32') 38 | 39 | mask = np.float32(depth>self.threshold) 40 | 41 | M = depth.max() 42 | depth = depth/M 43 | 44 | label = np.asarray(PIL.Image.fromarray(label).resize((self.input_w, self.input_h), resample=PIL.Image.NEAREST)) 45 | label = label.astype('int64') 46 | return torch.cat((torch.tensor(image), torch.tensor(depth).unsqueeze(0), torch.tensor(mask).unsqueeze(0)),dim=0), torch.tensor(label),name 47 | 48 | def __len__(self): 49 | return self.n_data 50 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /util/__pycache__/MF_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/MF_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/Po_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/Po_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/RGB_Depth_Fusion_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/RGB_Depth_Fusion_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/RGB_Depth_Fusion_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/RGB_Depth_Fusion_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /util/__pycache__/RGB_Depth_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/RGB_Depth_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/RGB_Depth_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/RGB_Depth_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /util/__pycache__/RGB_Disp_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/RGB_Disp_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/RT_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/RT_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /util/__pycache__/augmentation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/augmentation.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/augmentation.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/augmentation.cpython-38.pyc -------------------------------------------------------------------------------- /util/__pycache__/lr_policy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/lr_policy.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-sun/AMFNet/98336a2c623bc37c68c3fbec85348c29e3d695ae/util/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /util/augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | #from ipdb import set_trace as st 4 | 5 | 6 | class RandomFlip(): 7 | def __init__(self, prob=0.5): 8 | #super(RandomFlip, self).__init__() 9 | self.prob = prob 10 | 11 | def __call__(self, image, label): 12 | if np.random.rand() < self.prob: 13 | image = image[:,::-1] 14 | label = label[:,::-1] 15 | return image, label 16 | 17 | 18 | class RandomCrop(): 19 | def __init__(self, crop_rate=0.1, prob=1.0): 20 | #super(RandomCrop, self).__init__() 21 | self.crop_rate = crop_rate 22 | self.prob = prob 23 | 24 | def __call__(self, image, label): 25 | if np.random.rand() < self.prob: 26 | w, h, c = image.shape 27 | 28 | h1 = np.random.randint(0, h*self.crop_rate) 29 | w1 = np.random.randint(0, w*self.crop_rate) 30 | h2 = np.random.randint(h-h*self.crop_rate, h+1) 31 | w2 = np.random.randint(w-w*self.crop_rate, w+1) 32 | 33 | image = image[w1:w2, h1:h2] 34 | label = label[w1:w2, h1:h2] 35 | 36 | return image, label 37 | 38 | 39 | class RandomCropOut(): 40 | def __init__(self, crop_rate=0.2, prob=1.0): 41 | #super(RandomCropOut, self).__init__() 42 | self.crop_rate = crop_rate 43 | self.prob = prob 44 | 45 | def __call__(self, image, label): 46 | if np.random.rand() < self.prob: 47 | w, h, c = image.shape 48 | 49 | h1 = np.random.randint(0, h*self.crop_rate) 50 | w1 = np.random.randint(0, w*self.crop_rate) 51 | h2 = int(h1 + h*self.crop_rate) 52 | w2 = int(w1 + w*self.crop_rate) 53 | 54 | image[w1:w2, h1:h2] = 0 55 | label[w1:w2, h1:h2] = 0 56 | 57 | return image, label 58 | 59 | 60 | class RandomBrightness(): 61 | def __init__(self, bright_range=0.15, prob=0.9): 62 | #super(RandomBrightness, self).__init__() 63 | self.bright_range = bright_range 64 | self.prob = prob 65 | 66 | def __call__(self, image, label): 67 | if np.random.rand() < self.prob: 68 | bright_factor = np.random.uniform(1-self.bright_range, 1+self.bright_range) 69 | image = (image * bright_factor).astype(image.dtype) 70 | 71 | return image, label 72 | 73 | 74 | class RandomNoise(): 75 | def __init__(self, noise_range=5, prob=0.9): 76 | #super(RandomNoise, self).__init__() 77 | self.noise_range = noise_range 78 | self.prob = prob 79 | 80 | def __call__(self, image, label): 81 | if 
np.random.rand() < self.prob: 82 | w, h, c = image.shape 83 | 84 | noise = np.random.randint( 85 | -self.noise_range, 86 | self.noise_range, 87 | (w,h,c) 88 | ) 89 | 90 | image = (image + noise).clip(0,255).astype(image.dtype) 91 | 92 | return image, label 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | # By Yuxiang Sun, Dec. 4, 2020 2 | # Email: sun.yuxiang@outlook.com 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | # 0:unlabeled, 1:car, 2:person, 3:bike, 4:curve, 5:car_stop, 6:guardrail, 7:color_cone, 8:bump 8 | def get_palette(): 9 | unlabelled = [0,0,0] 10 | car = [64,0,128] 11 | person = [64,64,0] 12 | bike = [0,128,192] 13 | others = [192,128,64] 14 | curve = [0,0,192] 15 | car_stop = [128,128,0] 16 | color_cone = [192,128,128] 17 | guardrail = [64,64,128] 18 | 19 | #bump = [192,64,0] 20 | ashcan = [64,128,64] 21 | 22 | palette = np.array([unlabelled,person,car, bike, others, curve, ashcan,color_cone,car_stop, guardrail]) 23 | return palette 24 | 25 | def visualize(image_name, predictions, weight_name): 26 | palette = get_palette() 27 | for (i, pred) in enumerate(predictions): 28 | pred = predictions[i].cpu().numpy() 29 | img = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8) 30 | for cid in range(0, len(palette)): # fix the mistake from the MFNet code on Dec.27, 2019 31 | img[pred == cid] = palette[cid] 32 | img = Image.fromarray(np.uint8(img)) 33 | img.save('runs/Pred_' + weight_name + '_' + image_name[i] + '.png') 34 | 35 | def compute_results(conf_total): 36 | n_class = conf_total.shape[0] 37 | consider_unlabeled = True # must consider the unlabeled, please set it to True 38 | if consider_unlabeled is True: 39 | start_index = 0 40 | else: 41 | start_index = 1 42 | precision_per_class = np.zeros(n_class) 43 | recall_per_class = np.zeros(n_class) 44 | iou_per_class = np.zeros(n_class) 45 | F1_per_class = np.zeros(n_class) 46 | for cid in range(start_index, n_class): # cid: class id 47 | if conf_total[start_index:, cid].sum() == 0: 48 | precision_per_class[cid] = np.nan 49 | else: 50 | precision_per_class[cid] = float(conf_total[cid, cid]) / float(conf_total[start_index:, cid].sum()) # precision = TP/TP+FP 51 | if conf_total[cid, start_index:].sum() == 0: 52 | recall_per_class[cid] = np.nan 53 | else: 54 | recall_per_class[cid] = float(conf_total[cid, cid]) / float(conf_total[cid, start_index:].sum()) # recall = TP/TP+FN 55 | if (conf_total[cid, start_index:].sum() + conf_total[start_index:, cid].sum() - conf_total[cid, cid]) == 0: 56 | iou_per_class[cid] = np.nan 57 | else: 58 | iou_per_class[cid] = float(conf_total[cid, cid]) / float((conf_total[cid, start_index:].sum() + conf_total[start_index:, cid].sum() - conf_total[cid, cid])) # IoU = TP/TP+FP+FN 59 | if (recall_per_class[cid] == np.nan) | (precision_per_class[cid] == np.nan) |(precision_per_class[cid]==0)|(recall_per_class[cid]==0): 60 | F1_per_class[cid] = np.nan 61 | else : 62 | F1_per_class[cid] = 2 / (1/precision_per_class[cid] +1/recall_per_class[cid]) 63 | 64 | return precision_per_class, recall_per_class, iou_per_class,F1_per_class 65 | --------------------------------------------------------------------------------
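A minimal usage example for `compute_results` in `util/util.py` (the confusion-matrix counts below are made up purely for illustration; the class names follow the ones printed by `run_demo.py`):
```
import numpy as np
from util.util import compute_results

# Toy 3x3 confusion matrix (rows: ground truth, columns: prediction) for the classes
# evaluated in run_demo.py: 0 unlabeled, 1 drivable road, 2 negative obstacle.
conf = np.array([[90., 5., 5.],
                 [10., 85., 5.],
                 [5., 5., 90.]])

precision, recall, iou, f1 = compute_results(conf)
for cid, cname in enumerate(['unlabeled', 'road', 'negative']):
    print('%s: precision %.3f, recall %.3f, IoU %.3f, F1 %.3f'
          % (cname, precision[cid], recall[cid], iou[cid], f1[cid]))
```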