├── .idea ├── Efficient-VRNet.iml ├── encodings.xml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── 2007_val.txt ├── README.md ├── backbone ├── __init__.py ├── __pycache__ │ └── __init__.cpython-39.pyc ├── attention_modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── eca.cpython-39.pyc │ │ └── shuffle_attention.cpython-39.pyc │ ├── contextual_attention.py │ ├── eca.py │ ├── mobile_attention.py │ ├── mobile_vit_attention.py │ └── shuffle_attention.py ├── conv_utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ └── normal_conv.cpython-39.pyc │ ├── dcn.py │ ├── ds_conv.py │ ├── dynamic_conv.py │ └── normal_conv.py ├── fusion │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ └── vr_coc.cpython-39.pyc │ ├── context_cluster_down.py │ ├── context_cluster_iterative.py │ ├── context_cluster_nofold.py │ ├── context_cluster_withGAP.py │ └── vr_coc.py ├── radar │ ├── __init__.py │ ├── context_cluster.py │ ├── context_cluster_down.py │ ├── context_cluster_iterative.py │ ├── context_cluster_nofold.py │ ├── context_cluster_withGAP.py │ ├── contextcluster.py │ ├── model_utils.py │ ├── pointmlp.py │ └── pointnet.py └── vision │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── context_cluster.cpython-39.pyc │ ├── context_cluster.py │ ├── context_cluster_down.py │ ├── context_cluster_iterative.py │ ├── context_cluster_nofold.py │ └── context_cluster_withGAP.py ├── deeplab.py ├── get_miou.py ├── head ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── decouplehead.cpython-39.pyc └── decouplehead.py ├── image_augmentation_test ├── __init__.py ├── dark_channel.py └── sharpen.py ├── model_data ├── heatmap_vision.png ├── voc_classes.txt └── waterscenes.txt ├── neck ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── coc_fpn_dual.cpython-39.pyc └── coc_fpn_dual.py ├── nets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── deeplabv3_training.cpython-39.pyc │ ├── efficient_vrnet.cpython-39.pyc │ └── yolo_training.cpython-39.pyc ├── deeplabv3_training.py ├── efficient_vrnet.py └── yolo_training.py ├── predict.py ├── predict_seg.py ├── requirements.txt ├── train.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── callbacks.cpython-39.pyc │ ├── dataloader.cpython-39.pyc │ ├── multitaskloss.cpython-39.pyc │ ├── utils.cpython-39.pyc │ ├── utils_bbox.cpython-39.pyc │ ├── utils_fit.cpython-39.pyc │ └── utils_map.cpython-39.pyc ├── callbacks.py ├── dataloader.py ├── multitaskloss.py ├── utils.py ├── utils_bbox.py ├── utils_fit.py └── utils_map.py ├── utils_seg ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── callbacks.cpython-39.pyc │ ├── dataloader.cpython-39.pyc │ ├── utils.cpython-39.pyc │ ├── utils_fit.cpython-39.pyc │ └── utils_metrics.cpython-39.pyc ├── callbacks.py ├── dataloader.py ├── utils.py ├── utils_fit.py └── utils_metrics.py ├── venv ├── Scripts │ ├── Activate.ps1 │ ├── activate │ ├── activate.bat │ ├── deactivate.bat │ ├── python.exe │ └── pythonw.exe └── pyvenv.cfg ├── voc_annotation.py ├── voc_annotation_seg.py └── yolo.py /.idea/Efficient-VRNet.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 
-------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ASY-VRNet: Waterway Panoptic Driving Perception Model based on Asymmetric Fair Fusion of Vision and 4D mmWave Radar 2 | 3 | # Device: 4 | Camera: Sony IMX-317 5 | 6 | Radar: Ocuii Imaging Radar 7 | 8 | # Implementation: 9 | 10 | * Create a conda environment and install dependencies 11 | > git clone https://github.com/GuanRunwei/Efficient-VRNet.git \ 12 | > cd Efficient-VRNet \ 13 | > conda create -n efficientvrnet python=3.7 \ 14 | > conda activate efficientvrnet \ 15 | > pip install -r requirements.txt 16 | 17 | 18 | * Prepare datasets for object detection and semantic segmentation based on image and radar 19 | 20 | > For object detection, make two files, one for training and one for testing. There are two ways to do this. \ 21 | > 1. Make two txt files, one for training and one for testing. Both files use the same format: 22 | > each line has two parts, an image path and its objects: 23 | > **image path**: E:/Big_Datasets/water_surface/all-1114/all/VOCdevkit/VOC2007/JPEGImages/1664091257.87023.jpg 24 | > **object 1** (the first four numbers are the bounding box and the last one is the category): 1131,430,1152,473,0 25 | > **object 2**: 920,425,937,451,0 26 | > Therefore, each line looks like this: E:/Big_Datasets/water_surface/all-1114/all/VOCdevkit/VOC2007/JPEGImages/1664091257.87023.jpg 1131,430,1152,473,0 920,425,937,451,0 27 | > 2. Organize the files in VOC format in one folder like this: \ 28 | > VOCdevkit \ 29 | > -VOC2007 \ 30 | > -- Annotations -> xml annotations in VOC format (you need to put the annotations in it) \ 31 | > -- ImageSets -> id files (you do not need to prepare these) \ 32 | > -- JPEGImages -> images (you need to put the images in it) \ 33 | > **enter** voc_annotation.py and follow the comments in it to make your dataset 34 | 35 | 36 | > For semantic segmentation, make the folders in VOC format. \ 37 | > VOCdevkit \ 38 | > -VOC2007 \ 39 | > -- ImageSets -> id files (you do not need to prepare these) \ 40 | > -- JPEGImages -> images (you need to put the images in it) \ 41 | > -- SegmentationClass -> images_seg (you need to put the segmentation masks in it) 42 | 43 | 44 | > For radar files, you need to make a radar map with the same spatial size as the images used for object detection. 45 | We need four features: range, velocity, elevation and power, so first project the 3D point clouds into the 2D image plane, 46 | then build a 4×512×512 numpy matrix in which each channel holds one feature, and save it in npz format (a sketch is given after the Train step below). 47 | 48 | > ***Attention:*** The names of the images for object detection, segmentation, and radar must be the same! 49 | The only difference between them is the file extension. 50 | 51 | * Train 52 | 53 | > After you have completed the above, enter **train.py**, change the file path variables and hyperparameters, and run it.
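* Radar map sketch

> As a reference for the radar-map preparation described above, here is a minimal sketch of how the 4×512×512 radar map could be built and saved. The point-cloud column layout, the projection matrix, the original image resolution and the npz key name are assumptions and must be adapted to your own calibration and file layout.

```python
import numpy as np


def make_radar_map(points, proj, image_size, out_path, map_size=(512, 512)):
    """Rasterize a radar point cloud into a 4 x H x W map (range, velocity, elevation, power).

    points     : (N, 7) array, columns assumed to be [x, y, z, range, velocity, elevation, power]
    proj       : (3, 4) projection matrix from radar coordinates to image pixels (assumed calibration)
    image_size : (width, height) of the camera image the detection annotations refer to
    out_path   : output file, e.g. "radar/1664091257.87023.npz" (same basename as the image)
    """
    map_h, map_w = map_size
    img_w, img_h = image_size
    radar_map = np.zeros((4, map_h, map_w), dtype=np.float32)

    # Project the 3D radar points into the 2D image plane.
    xyz1 = np.concatenate([points[:, :3], np.ones((points.shape[0], 1))], axis=1)
    uvw = xyz1 @ proj.T
    u = uvw[:, 0] / uvw[:, 2]
    v = uvw[:, 1] / uvw[:, 2]

    # Rescale the pixel coordinates from the image resolution to the 512x512 radar map.
    u = np.round(u * map_w / img_w).astype(int)
    v = np.round(v * map_h / img_h).astype(int)

    # Keep points in front of the sensor that land inside the map; write one feature per channel.
    keep = (uvw[:, 2] > 0) & (u >= 0) & (u < map_w) & (v >= 0) & (v < map_h)
    for channel, column in enumerate([3, 4, 5, 6]):  # range, velocity, elevation, power
        radar_map[channel, v[keep], u[keep]] = points[keep, column]

    np.savez_compressed(out_path, radar=radar_map)  # the key name "radar" is an assumption
```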
54 | 55 | 56 | * Visualization 57 | 58 | > **predict.py** is used to test object detection \ 59 | > **yolo.py** defines the model for object detection \ 60 | 61 | > **predict_seg.py** is used to test semantic segmentation \ 62 | > **deeplab.py** defines the model for semantic segmentation \ 63 | 64 | > enter these files to see the details in the comments inside them 65 | 66 | If you have any questions, please put them in Issues. --- 67 | 68 | -------------------------------------------------------------------------------- /backbone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/__init__.py -------------------------------------------------------------------------------- /backbone/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /backbone/attention_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/attention_modules/__init__.py -------------------------------------------------------------------------------- /backbone/attention_modules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/attention_modules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /backbone/attention_modules/__pycache__/eca.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/attention_modules/__pycache__/eca.cpython-39.pyc -------------------------------------------------------------------------------- /backbone/attention_modules/__pycache__/shuffle_attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/attention_modules/__pycache__/shuffle_attention.cpython-39.pyc -------------------------------------------------------------------------------- /backbone/attention_modules/contextual_attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import flatten, nn 4 | from torch.nn import init 5 | from torch.nn.modules.activation import ReLU 6 | from torch.nn.modules.batchnorm import BatchNorm2d 7 | from torch.nn import functional as F 8 | 9 | 10 | class ContextAttention(nn.Module): 11 | 12 | def __init__(self, dim=512, kernel_size=3): 13 | super().__init__() 14 | self.dim = dim 15 | self.kernel_size = kernel_size 16 | 17 | self.key_embed = nn.Sequential( 18 | nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=4, bias=False), 19 | nn.BatchNorm2d(dim), 20 | nn.ReLU() 21 | ) 22 | self.value_embed=nn.Sequential( 23 | nn.Conv2d(dim, dim, 1, bias=False), 24 | nn.BatchNorm2d(dim) 25 | ) 26 | 27 |
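# The attention embedding built below compresses the concatenated [key context, input] pair
# (2*dim channels) by `factor`, then expands it to kernel_size*kernel_size*dim logits; forward()
# averages them over the kernel window, softmax-normalizes them over the spatial positions and
# multiplies them with the values to obtain the dynamic key k2.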
factor = 4 28 | self.attention_embed = nn.Sequential( 29 | nn.Conv2d(2*dim, 2*dim//factor, 1, bias=False), 30 | nn.BatchNorm2d(2*dim//factor), 31 | nn.ReLU(), 32 | nn.Conv2d(2*dim//factor, kernel_size*kernel_size*dim, 1) 33 | ) 34 | 35 | 36 | def forward(self, x): 37 | bs, c, h, w = x.shape 38 | k1 = self.key_embed(x) # bs,c,h,w 39 | v = self.value_embed(x).view(bs, c, -1) # bs,c,h,w 40 | 41 | y = torch.cat([k1, x], dim=1) # bs,2c,h,w 42 | att = self.attention_embed(y) # bs, c*k*k,h,w 43 | att = att.reshape(bs, c, self.kernel_size*self.kernel_size, h, w) 44 | att = att.mean(2, keepdim=False).view(bs, c, -1) # bs,c,h*w 45 | k2 = F.softmax(att, dim=-1)*v 46 | k2 = k2.view(bs, c, h, w) 47 | return k1+k2 48 | 49 | 50 | if __name__ == '__main__': 51 | input = torch.randn(50, 512, 7, 7) 52 | cot = ContextAttention(dim=512, kernel_size=3) 53 | output = cot(input) 54 | print(output.shape) 55 | -------------------------------------------------------------------------------- /backbone/attention_modules/eca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class eca_block(nn.Module): 7 | def __init__(self, channel, b=1, gamma=2): 8 | super(eca_block, self).__init__() 9 | kernel_size = int(abs((math.log(channel, 2) + b) / gamma)) 10 | kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1 11 | 12 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 13 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) 14 | self.sigmoid = nn.Sigmoid() 15 | 16 | def forward(self, x): 17 | output = self.avg_pool(x) 18 | # print("average pool shape:", output.shape) 19 | output = self.conv(output.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 20 | output = self.sigmoid(output) 21 | 22 | return x * output.expand_as(x) 23 | 24 | 25 | -------------------------------------------------------------------------------- /backbone/attention_modules/mobile_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def position(H, W, is_cuda=True): 7 | if is_cuda: 8 | loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1) 9 | loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W) 10 | else: 11 | loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1) 12 | loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W) 13 | loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0) 14 | return loc 15 | 16 | 17 | def stride(x, stride): 18 | b, c, h, w = x.shape 19 | return x[:, :, ::stride, ::stride] 20 | 21 | 22 | def init_rate_half(tensor): 23 | if tensor is not None: 24 | tensor.data.fill_(0.5) 25 | 26 | 27 | def init_rate_0(tensor): 28 | if tensor is not None: 29 | tensor.data.fill_(0.) 
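# MobileAttention mixes two branches computed from the same 1x1-projected q/k/v maps:
# a local-window self-attention branch (kernel_att x kernel_att neighbourhood gathered with unfold,
# plus a learned positional map from conv_p) and a convolution branch (fc mixes q/k/v into
# kernel_conv^2 maps that feed the depthwise dep_conv). The two outputs are blended with the
# learnable scalars rate1 and rate2, both initialized to 0.5 in reset_parameters().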
30 | 31 | 32 | class MobileAttention(nn.Module): 33 | def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1): 34 | super(MobileAttention, self).__init__() 35 | self.in_planes = in_planes 36 | self.out_planes = out_planes 37 | self.head = head 38 | self.kernel_att = kernel_att 39 | self.kernel_conv = kernel_conv 40 | self.stride = stride 41 | self.dilation = dilation 42 | self.rate1 = torch.nn.Parameter(torch.Tensor(1)) 43 | self.rate2 = torch.nn.Parameter(torch.Tensor(1)) 44 | self.head_dim = self.out_planes // self.head 45 | 46 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1) 47 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1) 48 | self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1) 49 | self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1) 50 | 51 | self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2 52 | self.pad_att = torch.nn.ReflectionPad2d(self.padding_att) 53 | self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride) 54 | self.softmax = torch.nn.Softmax(dim=1) 55 | 56 | self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False) 57 | self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes, 58 | kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1, 59 | stride=stride) 60 | 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self): 64 | init_rate_half(self.rate1) 65 | init_rate_half(self.rate2) 66 | kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv) 67 | for i in range(self.kernel_conv * self.kernel_conv): 68 | kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1. 69 | kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1) 70 | self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True) 71 | self.dep_conv.bias = init_rate_0(self.dep_conv.bias) 72 | 73 | def forward(self, x): 74 | q, k, v = self.conv1(x), self.conv2(x), self.conv3(x) 75 | scaling = float(self.head_dim) ** -0.5 76 | b, c, h, w = q.shape 77 | h_out, w_out = h // self.stride, w // self.stride 78 | 79 | pe = self.conv_p(position(h, w, x.is_cuda)) 80 | 81 | q_att = q.view(b * self.head, self.head_dim, h, w) * scaling 82 | k_att = k.view(b * self.head, self.head_dim, h, w) 83 | v_att = v.view(b * self.head, self.head_dim, h, w) 84 | 85 | if self.stride > 1: 86 | q_att = stride(q_att, self.stride) 87 | q_pe = stride(pe, self.stride) 88 | else: 89 | q_pe = pe 90 | 91 | unfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim, 92 | self.kernel_att * self.kernel_att, h_out, 93 | w_out) # b*head, head_dim, k_att^2, h_out, w_out 94 | unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out, 95 | w_out) # 1, head_dim, k_att^2, h_out, w_out 96 | 97 | att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum( 98 | 1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out) 99 | att = self.softmax(att) 100 | 101 | out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att, 102 | h_out, w_out) 103 | out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out) 104 | 105 | f_all = self.fc(torch.cat( 106 | [q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w), 107 | v.view(b, self.head, 
self.head_dim, h * w)], 1)) 108 | f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1]) 109 | 110 | out_conv = self.dep_conv(f_conv) 111 | 112 | return self.rate1 * out_att + self.rate2 * out_conv -------------------------------------------------------------------------------- /backbone/attention_modules/mobile_vit_attention.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from einops import rearrange 4 | 5 | 6 | class PreNorm(nn.Module): 7 | def __init__(self, dim, fn): 8 | super().__init__() 9 | self.ln = nn.LayerNorm(dim) 10 | self.fn = fn 11 | 12 | def forward(self, x, **kwargs): 13 | return self.fn(self.ln(x), **kwargs) 14 | 15 | 16 | class FeedForward(nn.Module): 17 | def __init__(self, dim, mlp_dim, dropout): 18 | super().__init__() 19 | self.net = nn.Sequential( 20 | nn.Linear(dim, mlp_dim), 21 | nn.SiLU(), 22 | nn.Dropout(dropout), 23 | nn.Linear(mlp_dim, dim), 24 | nn.Dropout(dropout) 25 | ) 26 | 27 | def forward(self, x): 28 | return self.net(x) 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__(self, dim, heads, head_dim, dropout): 33 | super().__init__() 34 | inner_dim = heads * head_dim 35 | project_out = not (heads == 1 and head_dim == dim) 36 | 37 | self.heads = heads 38 | self.scale = head_dim ** -0.5 39 | 40 | self.attend = nn.Softmax(dim=-1) 41 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 42 | 43 | self.to_out = nn.Sequential( 44 | nn.Linear(inner_dim, dim), 45 | nn.Dropout(dropout) 46 | ) if project_out else nn.Identity() 47 | 48 | def forward(self, x): 49 | qkv = self.to_qkv(x).chunk(3, dim=-1) 50 | q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv) 51 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 52 | attn = self.attend(dots) 53 | out = torch.matmul(attn, v) 54 | out = rearrange(out, 'b p h n d -> b p n (h d)') 55 | return self.to_out(out) 56 | 57 | 58 | class Transformer(nn.Module): 59 | def __init__(self, dim, depth, heads, head_dim, mlp_dim, dropout=0.): 60 | super().__init__() 61 | self.layers = nn.ModuleList([]) 62 | for _ in range(depth): 63 | self.layers.append(nn.ModuleList([ 64 | PreNorm(dim, Attention(dim, heads, head_dim, dropout)), 65 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout)) 66 | ])) 67 | 68 | def forward(self, x): 69 | out = x 70 | for att, ffn in self.layers: 71 | out = out + att(out) 72 | out = out + ffn(out) 73 | return out 74 | 75 | 76 | class MobileViTAttention(nn.Module): 77 | def __init__(self, in_channel=3, dim=512, kernel_size=3, patch_size=4): 78 | super().__init__() 79 | self.ph, self.pw = patch_size, patch_size 80 | self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, padding=kernel_size // 2) 81 | self.conv2 = nn.Conv2d(in_channel, dim, kernel_size=1) 82 | 83 | self.trans = Transformer(dim=dim, depth=3, heads=8, head_dim=64, mlp_dim=1024) 84 | 85 | self.conv3 = nn.Conv2d(dim, in_channel, kernel_size=1) 86 | self.conv4 = nn.Conv2d(2 * in_channel, in_channel, kernel_size=kernel_size, padding=kernel_size // 2) 87 | 88 | def forward(self, x): 89 | y = x.clone() # bs,c,h,w 90 | 91 | ## Local Representation 92 | y = self.conv2(self.conv1(x)) # bs,dim,h,w 93 | 94 | ## Global Representation 95 | _, _, h, w = y.shape 96 | y = rearrange(y, 'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim', ph=self.ph, pw=self.pw) # bs,h,w,dim 97 | y = self.trans(y) 98 | y = rearrange(y, 'bs (ph pw) (nh nw) dim -> bs dim (nh ph) (nw pw)', ph=self.ph, 
pw=self.pw, nh=h // self.ph, 99 | nw=w // self.pw) # bs,dim,h,w 100 | 101 | ## Fusion 102 | y = self.conv3(y) # bs,dim,h,w 103 | y = torch.cat([x, y], 1) # bs,2*dim,h,w 104 | y = self.conv4(y) # bs,c,h,w 105 | 106 | return y 107 | 108 | 109 | if __name__ == '__main__': 110 | m = MobileViTAttention(in_channel=96) 111 | input = torch.randn(16, 96, 80, 80) 112 | output = m(input) 113 | print(output.shape) -------------------------------------------------------------------------------- /backbone/attention_modules/shuffle_attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from torch.nn.parameter import Parameter 6 | 7 | 8 | class ShuffleAttention(nn.Module): 9 | 10 | def __init__(self, channel=512,reduction=16,G=8): 11 | super().__init__() 12 | self.G=G 13 | self.channel=channel 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G)) 16 | self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1)) 17 | self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1)) 18 | self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1)) 19 | self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1)) 20 | self.sigmoid=nn.Sigmoid() 21 | 22 | def init_weights(self): 23 | for m in self.modules(): 24 | if isinstance(m, nn.Conv2d): 25 | init.kaiming_normal_(m.weight, mode='fan_out') 26 | if m.bias is not None: 27 | init.constant_(m.bias, 0) 28 | elif isinstance(m, nn.BatchNorm2d): 29 | init.constant_(m.weight, 1) 30 | init.constant_(m.bias, 0) 31 | elif isinstance(m, nn.Linear): 32 | init.normal_(m.weight, std=0.001) 33 | if m.bias is not None: 34 | init.constant_(m.bias, 0) 35 | 36 | 37 | @staticmethod 38 | def channel_shuffle(x, groups): 39 | b, c, h, w = x.shape 40 | x = x.reshape(b, groups, -1, h, w) 41 | x = x.permute(0, 2, 1, 3, 4) 42 | 43 | # flatten 44 | x = x.reshape(b, -1, h, w) 45 | 46 | return x 47 | 48 | def forward(self, x): 49 | b, c, h, w = x.size() 50 | #group into subfeatures 51 | x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w 52 | 53 | #channel_split 54 | x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w 55 | 56 | #channel attention 57 | x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1 58 | x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1 59 | x_channel=x_0*self.sigmoid(x_channel) 60 | 61 | #spatial attention 62 | x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w 63 | x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w 64 | x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w 65 | 66 | # concatenate along channel axis 67 | out=torch.cat([x_channel,x_spatial],dim=1) #bs*G,c//G,h,w 68 | out=out.contiguous().view(b,-1,h,w) 69 | 70 | # channel shuffle 71 | out = self.channel_shuffle(out, 2) 72 | return out 73 | 74 | 75 | if __name__ == '__main__': 76 | input=torch.randn(50,32,224,224) 77 | se = ShuffleAttention(channel=32,G=4) 78 | output=se(input) 79 | print(output.shape) -------------------------------------------------------------------------------- /backbone/conv_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/conv_utils/__init__.py -------------------------------------------------------------------------------- /backbone/conv_utils/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/conv_utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /backbone/conv_utils/__pycache__/normal_conv.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/conv_utils/__pycache__/normal_conv.cpython-39.pyc -------------------------------------------------------------------------------- /backbone/conv_utils/dcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.ops 3 | from torch import nn 4 | 5 | 6 | class DeformableConv2d(nn.Module): 7 | def __init__(self, 8 | in_channels, 9 | out_channels, 10 | kernel_size=3, 11 | stride=1, 12 | padding=1, 13 | bias=False): 14 | super(DeformableConv2d, self).__init__() 15 | 16 | assert type(kernel_size) == tuple or type(kernel_size) == int 17 | 18 | kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size) 19 | self.stride = stride if type(stride) == tuple else (stride, stride) 20 | self.padding = padding 21 | 22 | self.offset_conv = nn.Conv2d(in_channels, 23 | 2 * kernel_size[0] * kernel_size[1], 24 | kernel_size=kernel_size, 25 | stride=stride, 26 | padding=self.padding, 27 | bias=True) 28 | 29 | nn.init.constant_(self.offset_conv.weight, 0.) 30 | nn.init.constant_(self.offset_conv.bias, 0.) 31 | 32 | self.modulator_conv = nn.Conv2d(in_channels, 33 | 1 * kernel_size[0] * kernel_size[1], 34 | kernel_size=kernel_size, 35 | stride=stride, 36 | padding=self.padding, 37 | bias=True) 38 | 39 | nn.init.constant_(self.modulator_conv.weight, 0.) 40 | nn.init.constant_(self.modulator_conv.bias, 0.) 41 | 42 | self.regular_conv = nn.Conv2d(in_channels=in_channels, 43 | out_channels=out_channels, 44 | kernel_size=kernel_size, 45 | stride=stride, 46 | padding=self.padding, 47 | bias=bias) 48 | 49 | def forward(self, x): 50 | # h, w = x.shape[2:] 51 | # max_offset = max(h, w)/4. 52 | 53 | offset = self.offset_conv(x) # .clamp(-max_offset, max_offset) 54 | modulator = 2. 
* torch.sigmoid(self.modulator_conv(x)) 55 | 56 | x = torchvision.ops.deform_conv2d(input=x, 57 | offset=offset, 58 | weight=self.regular_conv.weight, 59 | bias=self.regular_conv.bias, 60 | padding=self.padding, 61 | mask=modulator, 62 | stride=self.stride, 63 | ) 64 | return x -------------------------------------------------------------------------------- /backbone/conv_utils/ds_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.autograd 5 | 6 | 7 | class BN_Conv2d(nn.Module): 8 | """ 9 | BN_CONV, default activation is SiLU 10 | """ 11 | 12 | def __init__(self, in_channels: object, out_channels: object, kernel_size: object, stride: object, padding: object, 13 | dilation=1, groups=1, bias=False, activation=True) -> object: 14 | super(BN_Conv2d, self).__init__() 15 | layers = [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 16 | padding=padding, dilation=dilation, groups=groups, bias=bias), 17 | nn.BatchNorm2d(out_channels)] 18 | if activation: 19 | layers.append(nn.ReLU(inplace=True)) 20 | self.seq = nn.Sequential(*layers) 21 | 22 | def forward(self, x): 23 | return self.seq(x) 24 | 25 | 26 | class DWConv(nn.Module): 27 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, activation=True): 28 | super().__init__() 29 | self.dconv = BN_Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, groups=in_channels, 30 | activation=activation) 31 | self.pconv = BN_Conv2d(in_channels, out_channels, kernel_size=1, stride=1, groups=1, activation=activation) 32 | 33 | def forward(self, x): 34 | x = self.dconv(x) 35 | return self.pconv(x) -------------------------------------------------------------------------------- /backbone/conv_utils/dynamic_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.autograd 5 | 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16): 9 | super(Attention, self).__init__() 10 | attention_channel = max(int(in_planes * reduction), min_channel) 11 | self.kernel_size = kernel_size 12 | self.kernel_num = kernel_num 13 | self.temperature = 1.0 14 | 15 | self.avgpool = nn.AdaptiveAvgPool2d(1) 16 | self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False) 17 | self.bn = nn.BatchNorm2d(attention_channel) 18 | self.relu = nn.ReLU(inplace=True) 19 | 20 | self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True) 21 | self.func_channel = self.get_channel_attention 22 | 23 | if in_planes == groups and in_planes == out_planes: # depth-wise convolution 24 | self.func_filter = self.skip 25 | else: 26 | self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True) 27 | self.func_filter = self.get_filter_attention 28 | 29 | if kernel_size == 1: # point-wise convolution 30 | self.func_spatial = self.skip 31 | else: 32 | self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True) 33 | self.func_spatial = self.get_spatial_attention 34 | 35 | if kernel_num == 1: 36 | self.func_kernel = self.skip 37 | else: 38 | self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True) 39 | self.func_kernel = self.get_kernel_attention 40 | 41 | self._initialize_weights() 42 | 43 | def _initialize_weights(self): 44 | 
for m in self.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 47 | if m.bias is not None: 48 | nn.init.constant_(m.bias, 0) 49 | if isinstance(m, nn.BatchNorm2d): 50 | nn.init.constant_(m.weight, 1) 51 | nn.init.constant_(m.bias, 0) 52 | 53 | def update_temperature(self, temperature): 54 | self.temperature = temperature 55 | 56 | @staticmethod 57 | def skip(_): 58 | return 1.0 59 | 60 | def get_channel_attention(self, x): 61 | channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature) 62 | return channel_attention 63 | 64 | def get_filter_attention(self, x): 65 | filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature) 66 | return filter_attention 67 | 68 | def get_spatial_attention(self, x): 69 | spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size) 70 | spatial_attention = torch.sigmoid(spatial_attention / self.temperature) 71 | return spatial_attention 72 | 73 | def get_kernel_attention(self, x): 74 | kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1) 75 | kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1) 76 | return kernel_attention 77 | 78 | def forward(self, x): 79 | x = self.avgpool(x) 80 | x = self.fc(x) 81 | x = self.bn(x) 82 | x = self.relu(x) 83 | return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x) 84 | 85 | 86 | class DynamicConv(nn.Module): 87 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, 88 | reduction=0.0625, kernel_num=4): 89 | super(DynamicConv, self).__init__() 90 | self.in_planes = in_planes 91 | self.out_planes = out_planes 92 | self.kernel_size = kernel_size 93 | self.stride = stride 94 | self.padding = padding 95 | self.dilation = dilation 96 | self.groups = groups 97 | self.kernel_num = kernel_num 98 | self.attention = Attention(in_planes, out_planes, kernel_size, groups=groups, 99 | reduction=reduction, kernel_num=kernel_num) 100 | self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size), 101 | requires_grad=True) 102 | self._initialize_weights() 103 | 104 | if self.kernel_size == 1 and self.kernel_num == 1: 105 | self._forward_impl = self._forward_impl_pw1x 106 | else: 107 | self._forward_impl = self._forward_impl_common 108 | 109 | def _initialize_weights(self): 110 | for i in range(self.kernel_num): 111 | nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu') 112 | 113 | def update_temperature(self, temperature): 114 | self.attention.update_temperature(temperature) 115 | 116 | def _forward_impl_common(self, x): 117 | # Multiplying channel attention (or filter attention) to weights and feature maps are equivalent, 118 | # while we observe that when using the latter method the models will run faster with less gpu memory cost. 
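# The batch dimension is folded into the channels and the per-sample aggregated kernels are stacked,
# so a single grouped convolution (groups = self.groups * batch_size) applies a different dynamic
# kernel to every sample in one F.conv2d call.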
119 | channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x) 120 | batch_size, in_planes, height, width = x.size() 121 | x = x * channel_attention 122 | x = x.reshape(1, -1, height, width) 123 | aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0) 124 | aggregate_weight = torch.sum(aggregate_weight, dim=1).view( 125 | [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size]) 126 | output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding, 127 | dilation=self.dilation, groups=self.groups * batch_size) 128 | output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1)) 129 | output = output * filter_attention 130 | return output 131 | 132 | def _forward_impl_pw1x(self, x): 133 | channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x) 134 | x = x * channel_attention 135 | output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding, 136 | dilation=self.dilation, groups=self.groups) 137 | output = output * filter_attention 138 | return output 139 | 140 | def forward(self, x): 141 | return self._forward_impl(x) -------------------------------------------------------------------------------- /backbone/conv_utils/normal_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SiLU(nn.Module): 6 | @staticmethod 7 | def forward(x): 8 | return x * torch.sigmoid(x) 9 | 10 | 11 | def get_activation(name="silu", inplace=True): 12 | if name == "silu": 13 | module = SiLU() 14 | elif name == "relu": 15 | module = nn.ReLU(inplace=inplace) 16 | elif name == "lrelu": 17 | module = nn.LeakyReLU(0.1, inplace=inplace) 18 | else: 19 | raise AttributeError("Unsupported act type: {}".format(name)) 20 | return module 21 | 22 | 23 | class DWConv(nn.Module): 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True): 25 | super().__init__() 26 | self.dconv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, 27 | stride=stride, groups=in_channels, padding=padding, dilation=dilation, bias=bias) 28 | self.pconv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, groups=1, 29 | bias=bias) 30 | 31 | def forward(self, x): 32 | x = self.dconv(x) 33 | return self.pconv(x) 34 | 35 | 36 | class BaseConv(nn.Module): 37 | def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="relu", ds_conv=False): 38 | super().__init__() 39 | pad = (ksize - 1) // 2 40 | if ds_conv is False: 41 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, 42 | groups=groups, bias=bias) 43 | else: 44 | self.conv = DWConv(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, bias=bias) 45 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03) 46 | self.act = get_activation(act, inplace=True) 47 | 48 | def forward(self, x): 49 | return self.act(self.bn(self.conv(x))) 50 | 51 | def fuseforward(self, x): 52 | return self.act(self.conv(x)) 53 | -------------------------------------------------------------------------------- /backbone/fusion/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/fusion/__init__.py -------------------------------------------------------------------------------- /backbone/fusion/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/fusion/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /backbone/fusion/__pycache__/vr_coc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/fusion/__pycache__/vr_coc.cpython-39.pyc -------------------------------------------------------------------------------- /backbone/radar/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/radar/__init__.py -------------------------------------------------------------------------------- /backbone/radar/contextcluster.py: -------------------------------------------------------------------------------- 1 | """ 2 | We add our Context Cluster based on PointMLP, to validate its effectiveness and applicability on point cloud data. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | # from torch import einsum 8 | # from einops import rearrange, repeat 9 | from einops import rearrange 10 | import torch.nn.functional as F 11 | 12 | 13 | try: 14 | from model_utils import get_activation, square_distance, index_points, query_ball_point, knn_point 15 | except: 16 | from .model_utils import get_activation, square_distance, index_points, query_ball_point, knn_point 17 | 18 | try: 19 | from pointnet2_ops.pointnet2_utils import furthest_point_sample as furthest_point_sample 20 | except: 21 | print("==> not using CUDA FPS\n") 22 | try: 23 | from model_utils import farthest_point_sample as furthest_point_sample 24 | except: 25 | from .model_utils import farthest_point_sample as furthest_point_sample 26 | 27 | 28 | class LocalGrouper(nn.Module): 29 | def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs): 30 | """ 31 | Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d] 32 | :param groups: groups number 33 | :param kneighbors: k-nerighbors 34 | :param kwargs: others 35 | """ 36 | super(LocalGrouper, self).__init__() 37 | self.groups = groups 38 | self.kneighbors = kneighbors 39 | self.use_xyz = use_xyz 40 | if normalize is not None: 41 | self.normalize = normalize.lower() 42 | else: 43 | self.normalize = None 44 | if self.normalize not in ["center", "anchor"]: 45 | print(f"Unrecognized normalize parameter (self.normalize), set to None. 
Should be one of [center, anchor].") 46 | self.normalize = None 47 | if self.normalize is not None: 48 | add_channel=3 if self.use_xyz else 0 49 | self.affine_alpha = nn.Parameter(torch.ones([1,1,1,channel + add_channel])) 50 | self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel])) 51 | 52 | def forward(self, xyz, points): 53 | B, N, C = xyz.shape 54 | S = self.groups 55 | xyz = xyz.contiguous() # xyz [btach, points, xyz] 56 | 57 | # fps_idx = torch.multinomial(torch.linspace(0, N - 1, steps=N).repeat(B, 1).to(xyz.device), num_samples=self.groups, replacement=False).long() 58 | # fps_idx = farthest_point_sample(xyz, self.groups).long() 59 | fps_idx = furthest_point_sample(xyz, self.groups).long() # [B, npoint] 60 | new_xyz = index_points(xyz, fps_idx) # [B, npoint, 3] 61 | new_points = index_points(points, fps_idx) # [B, npoint, d] 62 | 63 | idx = knn_point(self.kneighbors, xyz, new_xyz) 64 | # idx = query_ball_point(radius, nsample, xyz, new_xyz) 65 | grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3] 66 | grouped_points = index_points(points, idx) # [B, npoint, k, d] 67 | if self.use_xyz: 68 | grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1) # [B, npoint, k, d+3] 69 | if self.normalize is not None: 70 | if self.normalize =="center": 71 | mean = torch.mean(grouped_points, dim=2, keepdim=True) 72 | if self.normalize =="anchor": 73 | mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points 74 | mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3] 75 | std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1) 76 | grouped_points = (grouped_points-mean)/(std + 1e-5) 77 | grouped_points = self.affine_alpha*grouped_points + self.affine_beta 78 | 79 | new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1) 80 | return new_xyz, new_points 81 | 82 | 83 | class ConvBNReLU1D(nn.Module): 84 | def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'): 85 | super(ConvBNReLU1D, self).__init__() 86 | self.act = get_activation(activation) 87 | self.net = nn.Sequential( 88 | nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias), 89 | nn.BatchNorm1d(out_channels), 90 | self.act 91 | ) 92 | 93 | def forward(self, x): 94 | return self.net(x) 95 | 96 | 97 | class ConvBNReLURes1D(nn.Module): 98 | def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'): 99 | super(ConvBNReLURes1D, self).__init__() 100 | self.act = get_activation(activation) 101 | self.net1 = nn.Sequential( 102 | nn.Conv1d(in_channels=channel, out_channels=int(channel * res_expansion), 103 | kernel_size=kernel_size, groups=groups, bias=bias), 104 | nn.BatchNorm1d(int(channel * res_expansion)), 105 | self.act 106 | ) 107 | if groups > 1: 108 | self.net2 = nn.Sequential( 109 | nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel, 110 | kernel_size=kernel_size, groups=groups, bias=bias), 111 | nn.BatchNorm1d(channel), 112 | self.act, 113 | nn.Conv1d(in_channels=channel, out_channels=channel, 114 | kernel_size=kernel_size, bias=bias), 115 | nn.BatchNorm1d(channel), 116 | ) 117 | else: 118 | self.net2 = nn.Sequential( 119 | nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel, 120 | kernel_size=kernel_size, bias=bias), 121 | nn.BatchNorm1d(channel) 122 | ) 123 | 124 | def forward(self, x): 125 | return 
self.act(self.net2(self.net1(x)) + x) 126 | 127 | 128 | def pairwise_cos_sim(x1: torch.Tensor, x2:torch.Tensor): 129 | """ 130 | return pair-wise similarity matrix between two tensors 131 | :param x1: [B,...,M,D] 132 | :param x2: [B,...,N,D] 133 | :return: similarity matrix [B,...,M,N] 134 | """ 135 | x1 = F.normalize(x1,dim=-1) 136 | x2 = F.normalize(x2,dim=-1) 137 | 138 | sim = torch.matmul(x1, x2.transpose(-2, -1)) 139 | return sim 140 | 141 | 142 | class ContextCluster(nn.Module): 143 | def __init__(self, dim, heads=4, head_dim=24): 144 | super(ContextCluster, self).__init__() 145 | self.heads = heads 146 | self.head_dim=head_dim 147 | self.fc1 = nn.Linear(dim, heads*head_dim) 148 | self.fc2 = nn.Linear(heads*head_dim, dim) 149 | self.fc_v = nn.Linear(dim, heads*head_dim) 150 | self.sim_alpha = nn.Parameter(torch.ones(1)) 151 | self.sim_beta = nn.Parameter(torch.zeros(1)) 152 | 153 | def forward(self, x): #[b,d,k] 154 | res = x 155 | x = rearrange(x, "b d k -> b k d") 156 | value = self.fc_v(x) # [b,k,head*head_d] 157 | x = self.fc1(x) # [b,k,head*head_d] 158 | x = rearrange(x, "b k (h d) -> (b h) k d", h=self.heads) # [b,k,d] 159 | value = rearrange(value, "b k (h d) -> (b h) k d", h=self.heads) # [b,k,d] 160 | center = x.mean(dim=1, keepdim=True) # [b,1,d] 161 | value_center = value.mean(dim=1, keepdim=True) # [b,1,d] 162 | sim = torch.sigmoid(self.sim_beta + self.sim_alpha * pairwise_cos_sim(center, x) )#[B,1,k] 163 | # out [b, 1, d] 164 | out = ( (value.unsqueeze(dim=1)*sim.unsqueeze(dim=-1) ).sum(dim=2) + value_center)/ (sim.sum(dim=-1,keepdim=True)+ 1.0) # [B,M,D] 165 | out = out*(sim.squeeze(dim=1).unsqueeze(dim=-1)) # [b,k,d] 166 | out = rearrange(out, "(b h) k d -> b k (h d)", h=self.heads) # [b,k,d] 167 | out = self.fc2(out) 168 | out = rearrange(out, "b k d -> b d k") 169 | return res + out 170 | 171 | 172 | class PreExtraction(nn.Module): 173 | def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True, 174 | activation='relu', use_xyz=True): 175 | """ 176 | input: [b,g,k,d]: output:[b,d,g] 177 | :param channels: 178 | :param blocks: 179 | """ 180 | super(PreExtraction, self).__init__() 181 | in_channels = 3+2*channels if use_xyz else 2*channels 182 | self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation) 183 | operation = [] 184 | for _ in range(blocks): 185 | operation.append( 186 | ContextCluster(out_channels, heads=1,head_dim=max(out_channels//4, 32)) 187 | ) 188 | operation.append( 189 | # add context cluster here. 
190 | ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion, 191 | bias=bias, activation=activation) 192 | ) 193 | self.operation = nn.Sequential(*operation) 194 | 195 | def forward(self, x): 196 | b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) 197 | x = x.permute(0, 1, 3, 2) 198 | x = x.reshape(-1, d, s) 199 | x = self.transfer(x) 200 | batch_size, _, _ = x.size() 201 | x = self.operation(x) # [b, d, k] 202 | x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) 203 | x = x.reshape(b, n, -1).permute(0, 2, 1) 204 | return x 205 | 206 | 207 | class PosExtraction(nn.Module): 208 | def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'): 209 | """ 210 | input[b,d,g]; output[b,d,g] 211 | :param channels: 212 | :param blocks: 213 | """ 214 | super(PosExtraction, self).__init__() 215 | operation = [] 216 | for _ in range(blocks): 217 | operation.append( 218 | ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation) 219 | ) 220 | self.operation = nn.Sequential(*operation) 221 | 222 | def forward(self, x): # [b, d, g] 223 | return self.operation(x) 224 | 225 | 226 | 227 | 228 | class Model(nn.Module): 229 | def __init__(self, points=1024, class_num=40, embed_dim=64, groups=1, res_expansion=1.0, 230 | activation="relu", bias=True, use_xyz=True, normalize="center", 231 | dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], 232 | k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs): 233 | super(Model, self).__init__() 234 | self.stages = len(pre_blocks) 235 | self.class_num = class_num 236 | self.points = points 237 | self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation) 238 | assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \ 239 | "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers." 
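# One (LocalGrouper, PreExtraction, PosExtraction) triple is built per stage: each stage divides the
# number of anchor points by reducers[i] and multiplies the channel width by dim_expansion[i].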
240 | self.local_grouper_list = nn.ModuleList() 241 | self.pre_blocks_list = nn.ModuleList() 242 | self.pos_blocks_list = nn.ModuleList() 243 | last_channel = embed_dim 244 | anchor_points = self.points 245 | for i in range(len(pre_blocks)): 246 | out_channel = last_channel * dim_expansion[i] 247 | pre_block_num = pre_blocks[i] 248 | pos_block_num = pos_blocks[i] 249 | kneighbor = k_neighbors[i] 250 | reduce = reducers[i] 251 | anchor_points = anchor_points // reduce 252 | # append local_grouper_list 253 | local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d] 254 | self.local_grouper_list.append(local_grouper) 255 | # append pre_block_list 256 | pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups, 257 | res_expansion=res_expansion, 258 | bias=bias, activation=activation, use_xyz=use_xyz) 259 | self.pre_blocks_list.append(pre_block_module) 260 | # append pos_block_list 261 | pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups, 262 | res_expansion=res_expansion, bias=bias, activation=activation) 263 | self.pos_blocks_list.append(pos_block_module) 264 | 265 | last_channel = out_channel 266 | 267 | self.act = get_activation(activation) 268 | self.classifier = nn.Sequential( 269 | nn.Linear(last_channel, 512), 270 | nn.BatchNorm1d(512), 271 | self.act, 272 | nn.Dropout(0.5), 273 | nn.Linear(512, 256), 274 | nn.BatchNorm1d(256), 275 | self.act, 276 | nn.Dropout(0.5), 277 | nn.Linear(256, self.class_num) 278 | ) 279 | 280 | def forward(self, x): 281 | xyz = x.permute(0, 2, 1) 282 | batch_size, _, _ = x.size() 283 | x = self.embedding(x) # B,D,N 284 | for i in range(self.stages): 285 | # Give xyz[b, p, 3] and fea[b, p, d], return new_xyz[b, g, 3] and new_fea[b, g, k, d] 286 | xyz, x = self.local_grouper_list[i](xyz, x.permute(0, 2, 1)) # [b,g,3] [b,g,k,d] 287 | x = self.pre_blocks_list[i](x) # [b,d,g] 288 | x = self.pos_blocks_list[i](x) # [b,d,g] 289 | 290 | x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1) 291 | x = self.classifier(x) 292 | return x 293 | 294 | 295 | def pointMLP_CoC(num_classes=40, **kwargs) -> Model: 296 | return Model(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=1, 297 | activation="relu", bias=False, use_xyz=False, normalize="anchor", 298 | dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], 299 | k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs) 300 | 301 | 302 | if __name__ == '__main__': 303 | data = torch.rand(2, 3, 1024) 304 | print("===> testing pointMLP ...") 305 | model = pointMLP_CoC2() 306 | out = model(data) 307 | print(out.shape) 308 | parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 309 | print(f"Paramter number is: {parameters}") 310 | 311 | -------------------------------------------------------------------------------- /backbone/radar/model_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model utils for our Point Cloud Efficient Transformer for TPAMI. 
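Provides the activation lookup and the point-cloud helpers (pairwise squared distance, point indexing,
farthest point sampling, ball query and kNN grouping) used by the point-cloud backbones in this directory.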
3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | # from torch import einsum 8 | # from einops import rearrange, repeat 9 | 10 | 11 | def get_activation(activation): 12 | if activation.lower() == 'gelu': 13 | return nn.GELU() 14 | elif activation.lower() == 'rrelu': 15 | return nn.RReLU(inplace=True) 16 | elif activation.lower() == 'selu': 17 | return nn.SELU(inplace=True) 18 | elif activation.lower() == 'silu': 19 | return nn.SiLU(inplace=True) 20 | elif activation.lower() == 'hardswish': 21 | return nn.Hardswish(inplace=True) 22 | elif activation.lower() == 'leakyrelu': 23 | return nn.LeakyReLU(inplace=True) 24 | elif activation.lower() == 'leakyrelu0.2': 25 | return nn.LeakyReLU(0.2, inplace=True) 26 | else: 27 | return nn.ReLU(inplace=True) 28 | 29 | 30 | def square_distance(src, dst): 31 | """ 32 | Calculate Euclid distance between each two points. 33 | src^T * dst = xn * xm + yn * ym + zn * zm; 34 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 35 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 36 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 37 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 38 | Input: 39 | src: source points, [B, N, C] 40 | dst: target points, [B, M, C] 41 | Output: 42 | dist: per-point square distance, [B, N, M] 43 | """ 44 | B, N, _ = src.shape 45 | _, M, _ = dst.shape 46 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 47 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 48 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 49 | return dist 50 | 51 | 52 | def index_points(points, idx): 53 | """ 54 | Input: 55 | points: input points data, [B, N, C] 56 | idx: sample index data, [B, S] 57 | Return: 58 | new_points:, indexed points data, [B, S, C] 59 | """ 60 | device = points.device 61 | B = points.shape[0] 62 | view_shape = list(idx.shape) 63 | view_shape[1:] = [1] * (len(view_shape) - 1) 64 | repeat_shape = list(idx.shape) 65 | repeat_shape[0] = 1 66 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 67 | new_points = points[batch_indices, idx, :] 68 | return new_points 69 | 70 | 71 | def farthest_point_sample(xyz, npoint): 72 | """ 73 | Input: 74 | xyz: pointcloud data, [B, N, 3] 75 | npoint: number of samples 76 | Return: 77 | centroids: sampled pointcloud index, [B, npoint] 78 | """ 79 | device = xyz.device 80 | B, N, C = xyz.shape 81 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 82 | distance = torch.ones(B, N).to(device) * 1e10 83 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 84 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 85 | for i in range(npoint): 86 | centroids[:, i] = farthest 87 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 88 | dist = torch.sum((xyz - centroid) ** 2, -1) 89 | distance = torch.min(distance, dist) 90 | farthest = torch.max(distance, -1)[1] 91 | return centroids 92 | 93 | 94 | def query_ball_point(radius, nsample, xyz, new_xyz): 95 | """ 96 | Input: 97 | radius: local region radius 98 | nsample: max sample number in local region 99 | xyz: all points, [B, N, 3] 100 | new_xyz: query points, [B, S, 3] 101 | Return: 102 | group_idx: grouped points index, [B, S, nsample] 103 | """ 104 | device = xyz.device 105 | B, N, C = xyz.shape 106 | _, S, _ = new_xyz.shape 107 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 108 | sqrdists = square_distance(new_xyz, xyz) 109 | group_idx[sqrdists > radius ** 2] = N 110 | group_idx = 
group_idx.sort(dim=-1)[0][:, :, :nsample] 111 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 112 | mask = group_idx == N 113 | group_idx[mask] = group_first[mask] 114 | return group_idx 115 | 116 | 117 | def knn_point(nsample, xyz, new_xyz): 118 | """ 119 | Input: 120 | nsample: max sample number in local region 121 | xyz: all points, [B, N, C] 122 | new_xyz: query points, [B, S, C] 123 | Return: 124 | group_idx: grouped points index, [B, S, nsample] 125 | """ 126 | sqrdists = square_distance(new_xyz, xyz) 127 | _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False) 128 | return group_idx 129 | 130 | 131 | -------------------------------------------------------------------------------- /backbone/radar/pointnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.parallel 5 | import torch.utils.data 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | 11 | class STN3d(nn.Module): 12 | def __init__(self): 13 | super(STN3d, self).__init__() 14 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 15 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 16 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 17 | self.fc1 = nn.Linear(1024, 512) 18 | self.fc2 = nn.Linear(512, 256) 19 | self.fc3 = nn.Linear(256, 9) 20 | self.relu = nn.ReLU() 21 | 22 | self.bn1 = nn.BatchNorm1d(64) 23 | self.bn2 = nn.BatchNorm1d(128) 24 | self.bn3 = nn.BatchNorm1d(1024) 25 | self.bn4 = nn.BatchNorm1d(512) 26 | self.bn5 = nn.BatchNorm1d(256) 27 | 28 | 29 | def forward(self, x): 30 | batchsize = x.size()[0] 31 | x = F.relu(self.bn1(self.conv1(x))) 32 | x = F.relu(self.bn2(self.conv2(x))) 33 | x = F.relu(self.bn3(self.conv3(x))) 34 | x = torch.max(x, 2, keepdim=True)[0] 35 | x = x.view(-1, 1024) 36 | 37 | x = F.relu(self.bn4(self.fc1(x))) 38 | x = F.relu(self.bn5(self.fc2(x))) 39 | x = self.fc3(x) 40 | 41 | iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1) 42 | if x.is_cuda: 43 | iden = iden.cuda() 44 | x = x + iden 45 | x = x.view(-1, 3, 3) 46 | return x 47 | 48 | 49 | class STNkd(nn.Module): 50 | def __init__(self, k=64): 51 | super(STNkd, self).__init__() 52 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 53 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 54 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 55 | self.fc1 = nn.Linear(1024, 512) 56 | self.fc2 = nn.Linear(512, 256) 57 | self.fc3 = nn.Linear(256, k*k) 58 | self.relu = nn.ReLU() 59 | 60 | self.bn1 = nn.BatchNorm1d(64) 61 | self.bn2 = nn.BatchNorm1d(128) 62 | self.bn3 = nn.BatchNorm1d(1024) 63 | self.bn4 = nn.BatchNorm1d(512) 64 | self.bn5 = nn.BatchNorm1d(256) 65 | 66 | self.k = k 67 | 68 | def forward(self, x): 69 | batchsize = x.size()[0] 70 | x = F.relu(self.bn1(self.conv1(x))) 71 | x = F.relu(self.bn2(self.conv2(x))) 72 | x = F.relu(self.bn3(self.conv3(x))) 73 | x = torch.max(x, 2, keepdim=True)[0] 74 | x = x.view(-1, 1024) 75 | 76 | x = F.relu(self.bn4(self.fc1(x))) 77 | x = F.relu(self.bn5(self.fc2(x))) 78 | x = self.fc3(x) 79 | 80 | iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1) 81 | if x.is_cuda: 82 | iden = iden.cuda() 83 | x = x + iden 84 | x = x.view(-1, self.k, self.k) 85 | return x 86 | 87 | class PointNetfeat(nn.Module): 88 | def __init__(self, global_feat = True, feature_transform = False): 89 | 
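# PointNetfeat: input spatial transform (STN3d), shared 1x1 Conv1d layers, an optional 64-d feature
# transform, then global max pooling. forward() returns the 1024-d global feature (global_feat=True)
# or the per-point features concatenated with the repeated global feature, plus both transform matrices.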
super(PointNetfeat, self).__init__() 90 | self.stn = STN3d() 91 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 92 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 93 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 94 | self.bn1 = nn.BatchNorm1d(64) 95 | self.bn2 = nn.BatchNorm1d(128) 96 | self.bn3 = nn.BatchNorm1d(1024) 97 | self.global_feat = global_feat 98 | self.feature_transform = feature_transform 99 | if self.feature_transform: 100 | self.fstn = STNkd(k=64) 101 | 102 | def forward(self, x): 103 | n_pts = x.size()[2] 104 | trans = self.stn(x) 105 | x = x.transpose(2, 1) 106 | x = torch.bmm(x, trans) 107 | x = x.transpose(2, 1) 108 | x = F.relu(self.bn1(self.conv1(x))) 109 | 110 | if self.feature_transform: 111 | trans_feat = self.fstn(x) 112 | x = x.transpose(2,1) 113 | x = torch.bmm(x, trans_feat) 114 | x = x.transpose(2,1) 115 | else: 116 | trans_feat = None 117 | 118 | pointfeat = x 119 | x = F.relu(self.bn2(self.conv2(x))) 120 | x = self.bn3(self.conv3(x)) 121 | x = torch.max(x, 2, keepdim=True)[0] 122 | x = x.view(-1, 1024) 123 | if self.global_feat: 124 | return x, trans, trans_feat 125 | else: 126 | x = x.view(-1, 1024, 1).repeat(1, 1, n_pts) 127 | return torch.cat([x, pointfeat], 1), trans, trans_feat 128 | 129 | class PointNetCls(nn.Module): 130 | def __init__(self, k=2, feature_transform=False): 131 | super(PointNetCls, self).__init__() 132 | self.feature_transform = feature_transform 133 | self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform) 134 | self.fc1 = nn.Linear(1024, 512) 135 | self.fc2 = nn.Linear(512, 256) 136 | self.fc3 = nn.Linear(256, k) 137 | self.dropout = nn.Dropout(p=0.3) 138 | self.bn1 = nn.BatchNorm1d(512) 139 | self.bn2 = nn.BatchNorm1d(256) 140 | self.relu = nn.ReLU() 141 | 142 | def forward(self, x): 143 | x, trans, trans_feat = self.feat(x) 144 | x = F.relu(self.bn1(self.fc1(x))) 145 | x = F.relu(self.bn2(self.dropout(self.fc2(x)))) 146 | x = self.fc3(x) 147 | return F.log_softmax(x, dim=1), trans, trans_feat 148 | 149 | 150 | class PointNetDenseCls(nn.Module): 151 | def __init__(self, k = 2, feature_transform=False): 152 | super(PointNetDenseCls, self).__init__() 153 | self.k = k 154 | self.feature_transform=feature_transform 155 | self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform) 156 | self.conv1 = torch.nn.Conv1d(1088, 512, 1) 157 | self.conv2 = torch.nn.Conv1d(512, 256, 1) 158 | self.conv3 = torch.nn.Conv1d(256, 128, 1) 159 | self.conv4 = torch.nn.Conv1d(128, self.k, 1) 160 | self.bn1 = nn.BatchNorm1d(512) 161 | self.bn2 = nn.BatchNorm1d(256) 162 | self.bn3 = nn.BatchNorm1d(128) 163 | 164 | def forward(self, x): 165 | batchsize = x.size()[0] 166 | n_pts = x.size()[2] 167 | x, trans, trans_feat = self.feat(x) 168 | x = F.relu(self.bn1(self.conv1(x))) 169 | x = F.relu(self.bn2(self.conv2(x))) 170 | x = F.relu(self.bn3(self.conv3(x))) 171 | x = self.conv4(x) 172 | x = x.transpose(2,1).contiguous() 173 | x = F.log_softmax(x.view(-1,self.k), dim=-1) 174 | x = x.view(batchsize, n_pts, self.k) 175 | return x, trans, trans_feat 176 | 177 | def feature_transform_regularizer(trans): 178 | d = trans.size()[1] 179 | batchsize = trans.size()[0] 180 | I = torch.eye(d)[None, :, :] 181 | if trans.is_cuda: 182 | I = I.cuda() 183 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2))) 184 | return loss 185 | 186 | if __name__ == '__main__': 187 | sim_data = Variable(torch.rand(32,3,2500)) 188 | trans = STN3d() 189 | out = trans(sim_data) 190 | print('stn', out.size()) 191 | print('loss', 
feature_transform_regularizer(out)) 192 | 193 | sim_data_64d = Variable(torch.rand(32, 64, 2500)) 194 | trans = STNkd(k=64) 195 | out = trans(sim_data_64d) 196 | print('stn64d', out.size()) 197 | print('loss', feature_transform_regularizer(out)) 198 | 199 | pointfeat = PointNetfeat(global_feat=True) 200 | out, _, _ = pointfeat(sim_data) 201 | print('global feat', out.size()) 202 | 203 | pointfeat = PointNetfeat(global_feat=False) 204 | out, _, _ = pointfeat(sim_data) 205 | print('point feat', out.size()) 206 | 207 | cls = PointNetCls(k = 5) 208 | out, _, _ = cls(sim_data) 209 | print('class', out.size()) 210 | 211 | seg = PointNetDenseCls(k = 3) 212 | out, _, _ = seg(sim_data) 213 | print('seg', out.size()) -------------------------------------------------------------------------------- /backbone/vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/vision/__init__.py -------------------------------------------------------------------------------- /backbone/vision/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/vision/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /backbone/vision/__pycache__/context_cluster.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/vision/__pycache__/context_cluster.cpython-39.pyc -------------------------------------------------------------------------------- /get_miou.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | from tqdm import tqdm 5 | 6 | from deeplab import DeeplabV3 7 | from utils_seg.utils_metrics import compute_mIoU, show_results 8 | 9 | ''' 10 | Notes on metric evaluation: 11 | 1. The images generated by this script are grayscale; their pixel values are small, so they show almost nothing when viewed as ordinary PNGs. Seeing nearly all-black images is therefore normal. 12 | 2. This script computes the mIoU of the validation set. The repository currently uses the test set as the validation set and does not split off a separate test set. 13 | ''' 14 | if __name__ == "__main__": 15 | #---------------------------------------------------------------------------# 16 | # miou_mode specifies what this script computes when it runs 17 | # miou_mode = 0: the full mIoU pipeline, i.e. generate the prediction results and then compute mIoU. 18 | # miou_mode = 1: only generate the prediction results. 19 | # miou_mode = 2: only compute mIoU. 20 | #---------------------------------------------------------------------------# 21 | miou_mode = 0 22 | #------------------------------# 23 | # number of classes + 1, e.g. 2 + 1 24 | #------------------------------# 25 | num_classes = 21 26 | #--------------------------------------------# 27 | # class names to distinguish, identical to those in json_to_dataset 28 | #--------------------------------------------# 29 | name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] 30 | # name_classes = ["_background_","cat","dog"] 31 | #-------------------------------------------------------# 32 | # path to the folder that contains the VOC dataset 33 | # defaults to the VOC dataset under the root directory 34 | #-------------------------------------------------------# 35 | VOCdevkit_path = 'E:/Big_Datasets/voc_seg/VOCdevkit/VOCdevkit' 36 | 37 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),'r').read().splitlines() 38 | gt_dir = 
os.path.join(VOCdevkit_path, "VOC2007/SegmentationClass/") 39 | miou_out_path = "miou_out" 40 | pred_dir = os.path.join(miou_out_path, 'detection-results') 41 | 42 | if miou_mode == 0 or miou_mode == 1: 43 | if not os.path.exists(pred_dir): 44 | os.makedirs(pred_dir) 45 | 46 | print("Load model.") 47 | deeplab = DeeplabV3() 48 | print("Load model done.") 49 | 50 | print("Get predict result.") 51 | for image_id in tqdm(image_ids): 52 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg") 53 | image = Image.open(image_path) 54 | image = deeplab.get_miou_png(image) 55 | image.save(os.path.join(pred_dir, image_id + ".png")) 56 | print("Get predict result done.") 57 | 58 | if miou_mode == 0 or miou_mode == 2: 59 | print("Get miou.") 60 | hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes) # 执行计算mIoU的函数 61 | print("Get miou done.") 62 | show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes) -------------------------------------------------------------------------------- /head/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/head/__init__.py -------------------------------------------------------------------------------- /head/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/head/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /head/__pycache__/decouplehead.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/head/__pycache__/decouplehead.cpython-39.pyc -------------------------------------------------------------------------------- /head/decouplehead.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from backbone.conv_utils.normal_conv import BaseConv 5 | 6 | 7 | class DecoupleHead(nn.Module): 8 | def __init__(self, num_classes, width=1.0, in_channels=[128, 320, 512], act="relu", depthwise=False): 9 | super().__init__() 10 | Conv = BaseConv 11 | 12 | self.cls_convs = nn.ModuleList() 13 | self.reg_convs = nn.ModuleList() 14 | self.cls_preds = nn.ModuleList() 15 | self.reg_preds = nn.ModuleList() 16 | self.obj_preds = nn.ModuleList() 17 | self.stems = nn.ModuleList() 18 | 19 | for i in range(len(in_channels)): 20 | self.stems.append( 21 | Conv(in_channels=int(in_channels[i] * width), out_channels=int(256 * width), ksize=1, stride=1, 22 | act=act)) 23 | self.cls_convs.append(nn.Sequential(*[ 24 | Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act, ds_conv=True), 25 | Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act, ds_conv=True), 26 | ])) 27 | self.cls_preds.append( 28 | nn.Conv2d(in_channels=int(256 * width), out_channels=num_classes, kernel_size=1, stride=1, padding=0) 29 | ) 30 | 31 | self.reg_convs.append(nn.Sequential(*[ 32 | Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act, ds_conv=True), 33 | Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, 
stride=1, act=act, ds_conv=True) 34 | ])) 35 | self.reg_preds.append( 36 | nn.Conv2d(in_channels=int(256 * width), out_channels=4, kernel_size=1, stride=1, padding=0) 37 | ) 38 | self.obj_preds.append( 39 | nn.Conv2d(in_channels=int(256 * width), out_channels=1, kernel_size=1, stride=1, padding=0) 40 | ) 41 | 42 | def forward(self, inputs): 43 | # ---------------------------------------------------# 44 | # inputs 45 | # P3_out 80, 80, 256 46 | # P4_out 40, 40, 512 47 | # P5_out 20, 20, 1024 48 | # ---------------------------------------------------# 49 | outputs = [] 50 | for k, x in enumerate(inputs): 51 | # ---------------------------------------------------# 52 | # integrate the channels with a 1x1 convolution 53 | # ---------------------------------------------------# 54 | x = self.stems[k](x) 55 | # ---------------------------------------------------# 56 | # extract features with two conv + normalization + activation blocks 57 | # ---------------------------------------------------# 58 | cls_feat = self.cls_convs[k](x) 59 | # ---------------------------------------------------# 60 | # predict the class of each feature point 61 | # 80, 80, num_classes 62 | # 40, 40, num_classes 63 | # 20, 20, num_classes 64 | # ---------------------------------------------------# 65 | cls_output = self.cls_preds[k](cls_feat) 66 | 67 | # ---------------------------------------------------# 68 | # extract features with two conv + normalization + activation blocks 69 | # ---------------------------------------------------# 70 | reg_feat = self.reg_convs[k](x) 71 | # ---------------------------------------------------# 72 | # regression coefficients of each feature point 73 | # reg_pred 80, 80, 4 74 | # reg_pred 40, 40, 4 75 | # reg_pred 20, 20, 4 76 | # ---------------------------------------------------# 77 | reg_output = self.reg_preds[k](reg_feat) 78 | # ---------------------------------------------------# 79 | # predict whether each feature point corresponds to an object 80 | # obj_pred 80, 80, 1 81 | # obj_pred 40, 40, 1 82 | # obj_pred 20, 20, 1 83 | # ---------------------------------------------------# 84 | obj_output = self.obj_preds[k](reg_feat) 85 | 86 | output = torch.cat([reg_output, obj_output, cls_output], 1) 87 | outputs.append(output) 88 | return outputs -------------------------------------------------------------------------------- /image_augmentation_test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/image_augmentation_test/__init__.py -------------------------------------------------------------------------------- /image_augmentation_test/dark_channel.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import cv2 3 | import math 4 | import numpy as np 5 | 6 | 7 | def DarkChannel(im, sz): 8 | b, g, r = cv2.split(im) 9 | dc = cv2.min(cv2.min(r, g), b) 10 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (sz, sz)) 11 | dark = cv2.erode(dc, kernel) 12 | return dark 13 | 14 | 15 | def AtmLight(im, dark): 16 | [h, w] = im.shape[:2] 17 | imsz = h * w 18 | numpx = int(max(math.floor(imsz / 1000), 1)) 19 | darkvec = dark.reshape(imsz, 1) 20 | imvec = im.reshape(imsz, 3) 21 | 22 | indices = darkvec.argsort() 23 | indices = indices[imsz - numpx::] 24 | 25 | atmsum = np.zeros([1, 3]) 26 | for ind in range(1, numpx): 27 | atmsum = atmsum + imvec[indices[ind]] 28 | 29 | A = atmsum / numpx 30 | return A 31 | 32 | 33 | def TransmissionEstimate(im, A, sz): 34 | omega = 0.95 35 | im3 = np.empty(im.shape, im.dtype) 36 | 37 | for ind in range(0, 3): 38 | im3[:, :, ind] = im[:, :, ind] / A[0, ind] 39 | 40 | transmission = 1 - omega * 
DarkChannel(im3, sz) 41 | return transmission 42 | 43 | 44 | def Guidedfilter(im, p, r, eps): 45 | mean_I = cv2.boxFilter(im, cv2.CV_64F, (r, r)) 46 | mean_p = cv2.boxFilter(p, cv2.CV_64F, (r, r)) 47 | mean_Ip = cv2.boxFilter(im * p, cv2.CV_64F, (r, r)) 48 | cov_Ip = mean_Ip - mean_I * mean_p 49 | 50 | mean_II = cv2.boxFilter(im * im, cv2.CV_64F, (r, r)) 51 | var_I = mean_II - mean_I * mean_I 52 | 53 | a = cov_Ip / (var_I + eps) 54 | b = mean_p - a * mean_I 55 | 56 | mean_a = cv2.boxFilter(a, cv2.CV_64F, (r, r)) 57 | mean_b = cv2.boxFilter(b, cv2.CV_64F, (r, r)) 58 | 59 | q = mean_a * im + mean_b 60 | return q 61 | 62 | 63 | def TransmissionRefine(im, et): 64 | gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) 65 | gray = np.float64(gray) / 255 66 | r = 60 67 | eps = 0.0001 68 | t = Guidedfilter(gray, et, r, eps) 69 | 70 | return t 71 | 72 | 73 | def Recover(im, t, A, tx=0.1): 74 | res = np.empty(im.shape, im.dtype) 75 | t = cv2.max(t, tx) 76 | 77 | for ind in range(0, 3): 78 | res[:, :, ind] = (im[:, :, ind] - A[0, ind]) / t + A[0, ind] 79 | 80 | return res 81 | 82 | 83 | if __name__ == '__main__': 84 | fn = '../images/fog_image.png' 85 | src = cv2.imread(fn) 86 | I = src.astype('float64') / 255 87 | 88 | dark = DarkChannel(I, 15) 89 | A = AtmLight(I, dark) 90 | te = TransmissionEstimate(I, A, 15) 91 | t = TransmissionRefine(src, te) 92 | J = Recover(I, t, A, 0.1) 93 | 94 | arr = np.hstack((I, J)) 95 | cv2.imshow("contrast", arr) 96 | cv2.imwrite("../images/car-02-dehaze.png", J * 255) 97 | cv2.imwrite("../images/car-02-contrast.png", arr * 255) 98 | cv2.waitKey(); -------------------------------------------------------------------------------- /image_augmentation_test/sharpen.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import math 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | # 线性拉伸处理 8 | # 去掉最大最小0.5%的像素值 线性拉伸至[0,1] 9 | def stretchImage(data, s=0.005, bins=2000): 10 | ht = np.histogram(data, bins); 11 | d = np.cumsum(ht[0]) / float(data.size) 12 | lmin = 0; 13 | lmax = bins - 1 14 | while lmin < bins: 15 | if d[lmin] >= s: 16 | break 17 | lmin += 1 18 | while lmax >= 0: 19 | if d[lmax] <= 1 - s: 20 | break 21 | lmax -= 1 22 | return np.clip((data - ht[1][lmin]) / (ht[1][lmax] - ht[1][lmin]), 0, 1) 23 | 24 | 25 | # 根据半径计算权重参数矩阵 26 | g_para = {} 27 | 28 | 29 | def getPara(radius=5): 30 | global g_para 31 | m = g_para.get(radius, None) 32 | if m is not None: 33 | return m 34 | size = radius * 2 + 1 35 | m = np.zeros((size, size)) 36 | for h in range(-radius, radius + 1): 37 | for w in range(-radius, radius + 1): 38 | if h == 0 and w == 0: 39 | continue 40 | m[radius + h, radius + w] = 1.0 / math.sqrt(h ** 2 + w ** 2) 41 | m /= m.sum() 42 | g_para[radius] = m 43 | return m 44 | 45 | 46 | # 常规的ACE实现 47 | def zmIce(I, ratio=4, radius=300): 48 | para = getPara(radius) 49 | height, width = I.shape 50 | zh = [] 51 | zw = [] 52 | n = 0 53 | while n < radius: 54 | zh.append(0) 55 | zw.append(0) 56 | n += 1 57 | for n in range(height): 58 | zh.append(n) 59 | for n in range(width): 60 | zw.append(n) 61 | n = 0 62 | while n < radius: 63 | zh.append(height - 1) 64 | zw.append(width - 1) 65 | n += 1 66 | # print(zh) 67 | # print(zw) 68 | 69 | Z = I[np.ix_(zh, zw)] 70 | res = np.zeros(I.shape) 71 | for h in range(radius * 2 + 1): 72 | for w in range(radius * 2 + 1): 73 | if para[h][w] == 0: 74 | continue 75 | res += (para[h][w] * np.clip((I - Z[h:h + height, w:w + width]) * ratio, -1, 1)) 76 | return res 77 | 78 | 79 | # 
单通道ACE快速增强实现 80 | def zmIceFast(I, ratio, radius): 81 | print(I) 82 | height, width = I.shape[:2] 83 | if min(height, width) <= 2: 84 | return np.zeros(I.shape) + 0.5 85 | Rs = cv2.resize(I, (int((width + 1) / 2), int((height + 1) / 2))) 86 | Rf = zmIceFast(Rs, ratio, radius) # 递归调用 87 | Rf = cv2.resize(Rf, (width, height)) 88 | Rs = cv2.resize(Rs, (width, height)) 89 | 90 | return Rf + zmIce(I, ratio, radius) - zmIce(Rs, ratio, radius) 91 | 92 | 93 | # rgb三通道分别增强 ratio是对比度增强因子 radius是卷积模板半径 94 | def zmIceColor(I, ratio=4, radius=3): 95 | res = np.zeros(I.shape) 96 | for k in range(3): 97 | res[:, :, k] = stretchImage(zmIceFast(I[:, :, k], ratio, radius)) 98 | return res 99 | 100 | 101 | # 主函数 102 | if __name__ == '__main__': 103 | img = cv2.imread('../images/rain_image.png') 104 | res = zmIceColor(img / 255.0) * 255 105 | cv2.imwrite('../images/ship-remove-rain.jpg', res) -------------------------------------------------------------------------------- /model_data/heatmap_vision.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/model_data/heatmap_vision.png -------------------------------------------------------------------------------- /model_data/voc_classes.txt: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | boat 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | diningtable 12 | dog 13 | horse 14 | motorbike 15 | person 16 | pottedplant 17 | sheep 18 | sofa 19 | train 20 | tvmonitor -------------------------------------------------------------------------------- /model_data/waterscenes.txt: -------------------------------------------------------------------------------- 1 | pier 2 | vessel 3 | ship 4 | boat -------------------------------------------------------------------------------- /neck/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/neck/__init__.py -------------------------------------------------------------------------------- /neck/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/neck/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /neck/__pycache__/coc_fpn_dual.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/neck/__pycache__/coc_fpn_dual.cpython-39.pyc -------------------------------------------------------------------------------- /neck/coc_fpn_dual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from thop import profile 5 | from thop import clever_format 6 | import math 7 | from backbone.attention_modules.eca import eca_block 8 | from backbone.attention_modules.shuffle_attention import ShuffleAttention 9 | from backbone.conv_utils.normal_conv import DWConv, BaseConv 10 | from backbone.vision.context_cluster import ClusterBlock 11 | from backbone.fusion.vr_coc import coc_medium, coc_small 12 | from torchinfo import summary 13 | 14 | 15 
| class CoCUpsample(nn.Module): 16 | def __init__(self, in_channels, out_channels, scale=2, ds_conv=False): 17 | super().__init__() 18 | 19 | self.upsample = nn.Sequential( 20 | BaseConv(in_channels, out_channels, 1, 1, act='relu', ds_conv=ds_conv), 21 | nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True) 22 | ) 23 | 24 | def forward(self, x,): 25 | x = self.upsample(x) 26 | return x 27 | 28 | 29 | class CoC_Conv(nn.Module): 30 | def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="relu", ds_conv=False): 31 | super(CoC_Conv, self).__init__() 32 | 33 | self.coc = ClusterBlock(dim=in_channels) 34 | self.conv_att = BaseConv(in_channels, out_channels, ksize=ksize, stride=stride, act=act, ds_conv=ds_conv) 35 | 36 | def forward(self, x): 37 | x = self.coc(x) 38 | x = self.conv_att(x) 39 | return x 40 | 41 | 42 | # -----------------------------------------# 43 | # ASPP特征提取模块 44 | # 利用不同膨胀率的膨胀卷积进行特征提取 45 | # -----------------------------------------# 46 | class ASPP(nn.Module): 47 | def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1): 48 | super(ASPP, self).__init__() 49 | self.branch1 = nn.Sequential( 50 | nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True), 51 | nn.BatchNorm2d(dim_out, momentum=bn_mom), 52 | nn.ReLU(inplace=True), 53 | ) 54 | self.branch2 = nn.Sequential( 55 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True), 56 | nn.BatchNorm2d(dim_out, momentum=bn_mom), 57 | nn.ReLU(inplace=True), 58 | ) 59 | self.branch3 = nn.Sequential( 60 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True), 61 | nn.BatchNorm2d(dim_out, momentum=bn_mom), 62 | nn.ReLU(inplace=True), 63 | ) 64 | self.branch4 = nn.Sequential( 65 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True), 66 | nn.BatchNorm2d(dim_out, momentum=bn_mom), 67 | nn.ReLU(inplace=True), 68 | ) 69 | self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True) 70 | self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom) 71 | self.branch5_relu = nn.ReLU(inplace=True) 72 | 73 | self.conv_cat = nn.Sequential( 74 | nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True), 75 | nn.BatchNorm2d(dim_out, momentum=bn_mom), 76 | nn.ReLU(inplace=True), 77 | ) 78 | 79 | def forward(self, x): 80 | [b, c, row, col] = x.size() 81 | # -----------------------------------------# 82 | # 一共五个分支 83 | # -----------------------------------------# 84 | conv1x1 = self.branch1(x) 85 | conv3x3_1 = self.branch2(x) 86 | conv3x3_2 = self.branch3(x) 87 | conv3x3_3 = self.branch4(x) 88 | # -----------------------------------------# 89 | # 第五个分支,全局平均池化+卷积 90 | # -----------------------------------------# 91 | global_feature = torch.mean(x, 2, True) 92 | global_feature = torch.mean(global_feature, 3, True) 93 | global_feature = self.branch5_conv(global_feature) 94 | global_feature = self.branch5_bn(global_feature) 95 | global_feature = self.branch5_relu(global_feature) 96 | global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True) 97 | 98 | # -----------------------------------------# 99 | # 将五个分支的内容堆叠起来 100 | # 然后1x1卷积整合特征。 101 | # -----------------------------------------# 102 | feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1) 103 | result = self.conv_cat(feature_cat) 104 | return result 105 | 106 | 107 | class SpatialPyramidPooling(nn.Module): 108 | def __init__(self, pool_sizes=[5, 9, 13]): 109 | super(SpatialPyramidPooling, self).__init__() 110 
| 111 | self.maxpools = nn.ModuleList([nn.MaxPool2d(pool_size, 1, pool_size//2) for pool_size in pool_sizes]) 112 | 113 | def forward(self, x): 114 | features = [maxpool(x) for maxpool in self.maxpools[::-1]] 115 | features = torch.cat(features + [x], dim=1) 116 | 117 | return features 118 | 119 | 120 | def shuffle_channels(x, groups=2): 121 | """Channel Shuffle""" 122 | 123 | batch_size, channels, h, w = x.data.size() 124 | if channels % groups: 125 | return x 126 | channels_per_group = channels // groups 127 | x = x.view(batch_size, groups, channels_per_group, h, w) 128 | x = torch.transpose(x, 1, 2).contiguous() 129 | x = x.view(batch_size, -1, h, w) 130 | return x 131 | 132 | 133 | class CoCFpnDual(nn.Module): 134 | def __init__(self, num_seg_class=9, depth=1.0, width=1.0, in_features=("dark2", "dark3", "dark4", "dark5"), 135 | in_channels=[64, 128, 320, 512], aspp_channel=1024): 136 | super().__init__() 137 | 138 | Conv = CoC_Conv 139 | 140 | self.backbone = coc_small(pretrained=False, width=width) 141 | self.in_features = in_features 142 | self.num_seg_class = num_seg_class 143 | in_channels = [int(item*width) for item in in_channels] 144 | 145 | self.aspp = ASPP(dim_in=in_channels[-1], dim_out=in_channels[-1]) 146 | 147 | # ================================= segmentation modules =================================== # 148 | # ----------------------- 20*20*512 -> 40*40*320 -> 40*40*640 ------------------------ # 149 | self.upsample5_4 = CoCUpsample(in_channels=in_channels[-1], out_channels=in_channels[-2]) 150 | self.sc_attn_seg4 = ShuffleAttention(channel=in_channels[-2]*2) 151 | # ------------------------------------------------------------------------------------- # 152 | 153 | # ----------------------- 40*40*640 -> 80*80*128 -> 80*80*256 ------------------------ # 154 | self.upsample4_3 = CoCUpsample(in_channels=in_channels[-2]*2, out_channels=in_channels[-3]) 155 | self.sc_attn_seg3 = ShuffleAttention(channel=in_channels[-3] * 2) 156 | # ------------------------------------------------------------------------------------ # 157 | 158 | # ----------------------- 80*80*256 -> 160*160*64 -> 160*160*128 ------------------------ # 159 | self.upsample3_2 = CoCUpsample(in_channels=in_channels[-3] * 2, out_channels=in_channels[0]) 160 | self.sc_attn_seg2 = ShuffleAttention(channel=in_channels[0] * 2) 161 | # ------------------------------------------------------------------------------------ # 162 | 163 | # ----------------------- 80*80*256 -> 160*160*64 -> 640*640*9 ------------------------ # 164 | self.upsample2_0 = CoCUpsample(in_channels=in_channels[0] * 2, out_channels=self.num_seg_class, scale=4) 165 | # ------------------------------------------------------------------------------------ # 166 | # ========================================================================================== # 167 | 168 | # ================================= detection modules ====================================== # 169 | # ----------------------- 20*20*512 -> 20*20*512 ----------------------- # 170 | self.p5_out_det = Conv(in_channels=in_channels[-1], out_channels=in_channels[-1]) 171 | # ----------------------------------------------------------------------- # 172 | 173 | # ----------------------- 20*20*512 -> 40*40*320 -> 40*40*640 -> 40*40*320 ------------------------ # 174 | self.p5_4_det = CoCUpsample(in_channels=in_channels[-1], out_channels=in_channels[-2]) 175 | self.p4_out_det = Conv(in_channels=in_channels[-2]*2, out_channels=in_channels[-2]) 176 | # 
------------------------------------------------------------------------------------------------- # 177 | 178 | # ----------------------- 40*40*320 -> 80*80*128 -> 80*80*256 -> 80*80*128 ------------------------ # 179 | self.p4_3_det = CoCUpsample(in_channels=in_channels[-2], out_channels=in_channels[-3]) 180 | self.p3_out_det = Conv(in_channels=in_channels[-3]*2, out_channels=in_channels[-3]) 181 | # ------------------------------------------------------------------------------------------------- # 182 | # ========================================================================================== # 183 | 184 | def forward(self, x, x_radar): 185 | 186 | x_out, x_radar_out = self.backbone(x, x_radar) 187 | 188 | x_stage2, x_stage3, x_stage4, x_stage5 = x_out 189 | x_stage5 = self.aspp(x_stage5) 190 | 191 | x_radar_stage2, x_radar_stage3, x_radar_stage4, x_radar_stage5 = x_radar_out 192 | 193 | # ---------------------------- segmentation ------------------------------- # 194 | x_stage5_4 = self.upsample5_4(x_stage5) 195 | x_stage4_concat_5 = torch.cat([x_stage4, x_stage5_4], dim=1) 196 | x_stage4_concat_5 = shuffle_channels(x_stage4_concat_5) 197 | x_stage4_concat_5 = self.sc_attn_seg4(x_stage4_concat_5) 198 | 199 | x_stage4_3 = self.upsample4_3(x_stage4_concat_5) 200 | x_stage3_concat_4 = torch.cat([x_stage4_3, x_stage3], dim=1) 201 | x_stage3_concat_4 = shuffle_channels(x_stage3_concat_4) 202 | x_stage3_concat_4 = self.sc_attn_seg3(x_stage3_concat_4) 203 | 204 | x_stage3_2 = self.upsample3_2(x_stage3_concat_4) 205 | x_stage2_concat_3 = torch.cat([x_stage3_2, x_stage2], dim=1) 206 | x_stage2_concat_3 = shuffle_channels(x_stage2_concat_3) 207 | x_stage2_concat_3 = self.sc_attn_seg2(x_stage2_concat_3) 208 | 209 | x_segmentation_out = self.upsample2_0(x_stage2_concat_3) 210 | # ------------------------------------------------------------------------ # 211 | 212 | # ----------------------------- detection -------------------------------- # 213 | p5_out = self.p5_out_det(x_radar_stage5) 214 | 215 | p5_4_upsample = self.p5_4_det(p5_out) 216 | p4_concat_5 = torch.cat([x_radar_stage4, p5_4_upsample], dim=1) 217 | p4_out = self.p4_out_det(p4_concat_5) 218 | 219 | p4_3_upsample = self.p4_3_det(p4_out) 220 | p3_concat_4 = torch.cat([x_radar_stage3, p4_3_upsample], dim=1) 221 | p3_out = self.p3_out_det(p3_concat_4) 222 | # ------------------------------------------------------------------------ # 223 | 224 | return (p3_out, p4_out, p5_out), x_segmentation_out 225 | 226 | 227 | if __name__ == '__main__': 228 | # input_map = torch.randn((1, 512, 20, 20)).cuda() 229 | # aspp = ASPP(dim_in=512, dim_out=1024).cuda() 230 | model = CoCFpnDual(width=1.0).cuda() 231 | model.eval() 232 | input = torch.rand(1, 3, 512, 512).cuda() 233 | input_radar = torch.rand(1, 4, 512, 512).cuda() 234 | output = model(input, input_radar) 235 | print(summary(model, input_size=[(1, 3, 512, 512), (1, 4, 512, 512)])) 236 | macs, params = profile(model, inputs=(input, input_radar)) 237 | macs, params = clever_format([macs, params], "%.3f") 238 | print("params:", params) 239 | print("macs:", macs) 240 | print(output[1].shape) 241 | print(output[0][0].shape) 242 | print(output[0][1].shape) 243 | print(output[0][2].shape) 244 | # model = SpatialPyramidPooling() 245 | # input = torch.rand(1, 512, 20, 20) 246 | # output = model(input) 247 | # print(output.shape) 248 | 249 | 250 | 251 | 252 | -------------------------------------------------------------------------------- /nets/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/nets/__init__.py -------------------------------------------------------------------------------- /nets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/nets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /nets/__pycache__/deeplabv3_training.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/nets/__pycache__/deeplabv3_training.cpython-39.pyc -------------------------------------------------------------------------------- /nets/__pycache__/efficient_vrnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/nets/__pycache__/efficient_vrnet.cpython-39.pyc -------------------------------------------------------------------------------- /nets/__pycache__/yolo_training.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/nets/__pycache__/yolo_training.cpython-39.pyc -------------------------------------------------------------------------------- /nets/deeplabv3_training.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def CE_Loss(inputs, target, cls_weights, num_classes=21): 10 | n, c, h, w = inputs.size() 11 | nt, ht, wt = target.size() 12 | if h != ht and w != wt: 13 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 14 | 15 | temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 16 | temp_target = target.view(-1) 17 | 18 | CE_loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target) 19 | return CE_loss 20 | 21 | 22 | def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2): 23 | n, c, h, w = inputs.size() 24 | nt, ht, wt = target.size() 25 | if h != ht and w != wt: 26 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 27 | 28 | temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 29 | temp_target = target.view(-1) 30 | 31 | logpt = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs, 32 | temp_target) 33 | pt = torch.exp(logpt) 34 | if alpha is not None: 35 | logpt *= alpha 36 | loss = -((1 - pt) ** gamma) * logpt 37 | loss = loss.mean() 38 | return loss 39 | 40 | 41 | def Dice_loss(inputs, target, beta=1, smooth=1e-5): 42 | n, c, h, w = inputs.size() 43 | nt, ht, wt, ct = target.size() 44 | if h != ht and w != wt: 45 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 46 | 47 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c), -1) 48 | temp_target = target.view(n, -1, ct) 49 | 50 | # 
--------------------------------------------# 51 | # 计算dice loss 52 | # --------------------------------------------# 53 | tp = torch.sum(temp_target[..., :-1] * temp_inputs, axis=[0, 1]) 54 | fp = torch.sum(temp_inputs, axis=[0, 1]) - tp 55 | fn = torch.sum(temp_target[..., :-1], axis=[0, 1]) - tp 56 | 57 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) 58 | dice_loss = 1 - torch.mean(score) 59 | return dice_loss 60 | 61 | 62 | def weights_init(net, init_type='normal', init_gain=0.02): 63 | def init_func(m): 64 | classname = m.__class__.__name__ 65 | if hasattr(m, 'weight') and classname.find('Conv') != -1: 66 | if init_type == 'normal': 67 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain) 68 | elif init_type == 'xavier': 69 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain) 70 | elif init_type == 'kaiming': 71 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 72 | elif init_type == 'orthogonal': 73 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain) 74 | else: 75 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 76 | elif classname.find('BatchNorm2d') != -1: 77 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 78 | torch.nn.init.constant_(m.bias.data, 0.0) 79 | 80 | print('initialize network with %s type' % init_type) 81 | net.apply(init_func) 82 | 83 | 84 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio=0.1, warmup_lr_ratio=0.1, 85 | no_aug_iter_ratio=0.3, step_num=10): 86 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters): 87 | if iters <= warmup_total_iters: 88 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start 89 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start 90 | elif iters >= total_iters - no_aug_iter: 91 | lr = min_lr 92 | else: 93 | lr = min_lr + 0.5 * (lr - min_lr) * ( 94 | 1.0 + math.cos( 95 | math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)) 96 | ) 97 | return lr 98 | 99 | def step_lr(lr, decay_rate, step_size, iters): 100 | if step_size < 1: 101 | raise ValueError("step_size must above 1.") 102 | n = iters // step_size 103 | out_lr = lr * decay_rate ** n 104 | return out_lr 105 | 106 | if lr_decay_type == "cos": 107 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3) 108 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6) 109 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15) 110 | func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter) 111 | else: 112 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) 113 | step_size = total_iters / step_num 114 | func = partial(step_lr, lr, decay_rate, step_size) 115 | 116 | return func 117 | 118 | 119 | def set_optimizer_lr(optimizer, lr_scheduler_func, epoch): 120 | lr = lr_scheduler_func(epoch) 121 | for param_group in optimizer.param_groups: 122 | param_group['lr'] = lr 123 | -------------------------------------------------------------------------------- /nets/efficient_vrnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from neck.coc_fpn_dual import CoCFpnDual 5 | from head.decouplehead import DecoupleHead 6 | from torchinfo import summary 7 | from thop import profile 8 | from thop import 
clever_format 9 | # from torchsummary import summary 10 | import time 11 | 12 | 13 | class EfficientVRNet(nn.Module): 14 | def __init__(self, num_classes, num_seg_classes, phi): 15 | super().__init__() 16 | depth_dict = {'nano': 0.33, 'tiny': 0.33, 's' : 0.33, 'm' : 0.67, 'l' : 1.00} 17 | width_dict = {'nano': 0.25, 'tiny': 0.375, 's' : 0.50, 'm' : 0.75, 'l' : 1.00} 18 | depth, width = depth_dict[phi], width_dict[phi] 19 | 20 | self.backbone = CoCFpnDual(width=width, num_seg_class=num_seg_classes) 21 | 22 | self.head = DecoupleHead(num_classes, width, depthwise=True) 23 | 24 | def forward(self, x, x_radar): 25 | fpn_outs, seg_outputs = self.backbone.forward(x, x_radar) 26 | det_outputs = self.head.forward(fpn_outs) 27 | return det_outputs, seg_outputs 28 | 29 | 30 | if __name__ == '__main__': 31 | model = EfficientVRNet(num_classes=4, phi='l', num_seg_classes=9).cuda() 32 | model.eval() 33 | input_map1 = torch.randn((1, 3, 512, 512)).cuda() 34 | input_map2 = torch.randn((1, 4, 512, 512)).cuda() 35 | t1 = time.time() 36 | test_times = 300 37 | for i in range(test_times): 38 | output_map, output_seg = model(input_map1, input_map2) 39 | t2 = time.time() 40 | print("fps:", (1 / ((t2 - t1) / test_times))) 41 | 42 | output_map, output_seg = model(input_map1, input_map2) 43 | print(output_map[0].shape) 44 | print(output_map[1].shape) 45 | print(output_map[2].shape) 46 | print(output_seg.shape) 47 | print(summary(model, input_size=((1, 3, 512, 512), (1, 4, 512, 512)))) 48 | 49 | macs, params = profile(model, inputs=([input_map1, input_map2])) 50 | flops = macs * 2 51 | flops, params = clever_format([flops, params], "%.3f") 52 | print("params:", params) 53 | print("flops:", flops) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # -----------------------------------------------------------------------# 2 | # predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能 3 | # 整合到了一个py文件中,通过指定mode进行模式的修改。 4 | # -----------------------------------------------------------------------# 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from yolo import YOLO 12 | 13 | if __name__ == "__main__": 14 | yolo = YOLO() 15 | # ----------------------------------------------------------------------------------------------------------# 16 | # mode用于指定测试的模式: 17 | # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 18 | # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。 19 | # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 20 | # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。 21 | # 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。 22 | # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。 23 | # ----------------------------------------------------------------------------------------------------------# 24 | mode = "predict" 25 | # -------------------------------------------------------------------------# 26 | # crop 指定了是否在单张图片预测后对目标进行截取 27 | # count 指定了是否进行目标的计数 28 | # crop、count仅在mode='predict'时有效 29 | # -------------------------------------------------------------------------# 30 | crop = False 31 | count = False 32 | # ----------------------------------------------------------------------------------------------------------# 33 | # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头 34 | # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。 35 | # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存 36 | # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。 37 | # 
video_fps 用于保存的视频的fps 38 | # 39 | # video_path、video_save_path和video_fps仅在mode='video'时有效 40 | # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。 41 | # ----------------------------------------------------------------------------------------------------------# 42 | video_path = 'images/video2.mp4' 43 | video_save_path = "images/video_det_out.mp4" 44 | video_fps = 33.0 45 | # ----------------------------------------------------------------------------------------------------------# 46 | # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。 47 | # fps_image_path 用于指定测试的fps图片 48 | # 49 | # test_interval和fps_image_path仅在mode='fps'有效 50 | # ----------------------------------------------------------------------------------------------------------# 51 | test_interval = 100 52 | fps_image_path = "images/example1.jpg" 53 | # -------------------------------------------------------------------------# 54 | # dir_origin_path 指定了用于检测的图片的文件夹路径 55 | # dir_save_path 指定了检测完图片的保存路径 56 | # 57 | # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效 58 | # -------------------------------------------------------------------------# 59 | dir_origin_path = "img/" 60 | dir_save_path = "img_out/" 61 | # -------------------------------------------------------------------------# 62 | # heatmap_save_path 热力图的保存路径,默认保存在model_data下 63 | # 64 | # heatmap_save_path仅在mode='heatmap'有效 65 | # -------------------------------------------------------------------------# 66 | heatmap_save_path = "model_data/heatmap_vision.png" 67 | # -------------------------------------------------------------------------# 68 | # simplify 使用Simplify onnx 69 | # onnx_save_path 指定了onnx的保存路径 70 | # -------------------------------------------------------------------------# 71 | simplify = True 72 | onnx_save_path = "model_data/models.onnx" 73 | 74 | if mode == "predict": 75 | ''' 76 | 1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。 77 | 2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。 78 | 3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值 79 | 在原图上利用矩阵的方式进行截取。 80 | 4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断, 81 | 比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。 82 | ''' 83 | while True: 84 | img = input('Input image filename:') 85 | try: 86 | image = Image.open(img) 87 | image_id = img[-20:-4] 88 | except: 89 | print('Open Error! Try again!') 90 | continue 91 | else: 92 | r_image = yolo.detect_image(image, image_id, crop=crop, count=count) 93 | r_image.show() 94 | 95 | elif mode == "video": 96 | capture = cv2.VideoCapture(video_path) 97 | if video_save_path != "": 98 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 99 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 100 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 101 | 102 | ref, frame = capture.read() 103 | if not ref: 104 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 105 | 106 | fps = 0.0 107 | while (True): 108 | t1 = time.time() 109 | # 读取某一帧 110 | ref, frame = capture.read() 111 | if not ref: 112 | break 113 | # 格式转变,BGRtoRGB 114 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 115 | # 转变成Image 116 | frame = Image.fromarray(np.uint8(frame)) 117 | # 进行检测 118 | frame = np.array(yolo.detect_image(frame)) 119 | # RGBtoBGR满足opencv显示格式 120 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 121 | 122 | fps = (fps + (1. 
/ (time.time() - t1))) / 2 123 | print("fps= %.2f" % (fps)) 124 | # frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 125 | 126 | # cv2.imshow("video", frame) 127 | c = cv2.waitKey(1) & 0xff 128 | if video_save_path != "": 129 | out.write(frame) 130 | 131 | if c == 27: 132 | capture.release() 133 | break 134 | 135 | print("Video Detection Done!") 136 | capture.release() 137 | if video_save_path != "": 138 | print("Save processed video to the path :" + video_save_path) 139 | out.release() 140 | cv2.destroyAllWindows() 141 | 142 | elif mode == "fps": 143 | img = Image.open(fps_image_path) 144 | tact_time = yolo.get_FPS(img, test_interval) 145 | print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1') 146 | 147 | elif mode == "dir_predict": 148 | import os 149 | 150 | from tqdm import tqdm 151 | 152 | img_names = os.listdir(dir_origin_path) 153 | for img_name in tqdm(img_names): 154 | if img_name.lower().endswith( 155 | ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 156 | image_path = os.path.join(dir_origin_path, img_name) 157 | image = Image.open(image_path) 158 | r_image = yolo.detect_image(image) 159 | if not os.path.exists(dir_save_path): 160 | os.makedirs(dir_save_path) 161 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0) 162 | 163 | elif mode == "heatmap": 164 | while True: 165 | img = input('Input image filename:') 166 | try: 167 | image = Image.open(img) 168 | image_id = img[-20:-4] 169 | except: 170 | print('Open Error! Try again!') 171 | continue 172 | else: 173 | yolo.detect_heatmap(image, image_id, heatmap_save_path) 174 | 175 | elif mode == "export_onnx": 176 | yolo.convert_to_onnx(simplify, onnx_save_path) 177 | 178 | else: 179 | raise AssertionError( 180 | "Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.") 181 | -------------------------------------------------------------------------------- /predict_seg.py: -------------------------------------------------------------------------------- 1 | #----------------------------------------------------# 2 | # 将单张图片预测、摄像头检测和FPS测试功能 3 | # 整合到了一个py文件中,通过指定mode进行模式的修改。 4 | #----------------------------------------------------# 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from deeplab import DeeplabV3 12 | 13 | if __name__ == "__main__": 14 | #-------------------------------------------------------------------------# 15 | # 如果想要修改对应种类的颜色,到__init__函数里修改self.colors即可 16 | #-------------------------------------------------------------------------# 17 | deeplab = DeeplabV3() 18 | #----------------------------------------------------------------------------------------------------------# 19 | # mode用于指定测试的模式: 20 | # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 21 | # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。 22 | # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 23 | # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。 24 | # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。 25 | #----------------------------------------------------------------------------------------------------------# 26 | mode = "predict" 27 | #-------------------------------------------------------------------------# 28 | # count 指定了是否进行目标的像素点计数(即面积)与比例计算 29 | # name_classes 区分的种类,和json_to_dataset里面的一样,用于打印种类和数量 30 | # 31 | # count、name_classes仅在mode='predict'时有效 32 | 
#-------------------------------------------------------------------------# 33 | count = False 34 | name_classes = [ "free-space", "pier", "vessel", "ship", "boat", "buoy", "sailor", "kayak"] 35 | # name_classes = ["background","cat","dog"] 36 | #----------------------------------------------------------------------------------------------------------# 37 | # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头 38 | # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。 39 | # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存 40 | # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。 41 | # video_fps 用于保存的视频的fps 42 | # 43 | # video_path、video_save_path和video_fps仅在mode='video'时有效 44 | # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。 45 | #----------------------------------------------------------------------------------------------------------# 46 | video_path = 'images/video2.mp4' 47 | video_save_path = "images/video_seg_out.mp4" 48 | video_fps = 33.0 49 | #----------------------------------------------------------------------------------------------------------# 50 | # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。 51 | # fps_image_path 用于指定测试的fps图片 52 | # 53 | # test_interval和fps_image_path仅在mode='fps'有效 54 | #----------------------------------------------------------------------------------------------------------# 55 | test_interval = 100 56 | fps_image_path = "img/street.jpg" 57 | #-------------------------------------------------------------------------# 58 | # dir_origin_path 指定了用于检测的图片的文件夹路径 59 | # dir_save_path 指定了检测完图片的保存路径 60 | # 61 | # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效 62 | #-------------------------------------------------------------------------# 63 | dir_origin_path = "img/" 64 | dir_save_path = "img_out/" 65 | #-------------------------------------------------------------------------# 66 | # simplify 使用Simplify onnx 67 | # onnx_save_path 指定了onnx的保存路径 68 | #-------------------------------------------------------------------------# 69 | simplify = True 70 | onnx_save_path = "model_data/models.onnx" 71 | 72 | if mode == "predict": 73 | ''' 74 | predict.py有几个注意点 75 | 1、该代码无法直接进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。 76 | 具体流程可以参考get_miou_prediction.py,在get_miou_prediction.py即实现了遍历。 77 | 2、如果想要保存,利用r_image.save("img.jpg")即可保存。 78 | 3、如果想要原图和分割图不混合,可以把blend参数设置成False。 79 | 4、如果想根据mask获取对应的区域,可以参考detect_image函数中,利用预测结果绘图的部分,判断每一个像素点的种类,然后根据种类获取对应的部分。 80 | seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3)) 81 | for c in range(self.num_classes): 82 | seg_img[:, :, 0] += ((pr == c)*( self.colors[c][0] )).astype('uint8') 83 | seg_img[:, :, 1] += ((pr == c)*( self.colors[c][1] )).astype('uint8') 84 | seg_img[:, :, 2] += ((pr == c)*( self.colors[c][2] )).astype('uint8') 85 | ''' 86 | while True: 87 | img = input('Input image filename:') 88 | try: 89 | image = Image.open(img) 90 | image_id = img[-20:-4] 91 | except Exception as e: 92 | print('Open Error! 
Try again!') 93 | print(e) 94 | continue 95 | else: 96 | r_image = deeplab.detect_image(image, image_id, count=count, name_classes=name_classes) 97 | r_image.show() 98 | 99 | elif mode == "video": 100 | capture=cv2.VideoCapture(video_path) 101 | if video_save_path!="": 102 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 103 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 104 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 105 | 106 | ref, frame = capture.read() 107 | if not ref: 108 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 109 | 110 | fps = 0.0 111 | while(True): 112 | t1 = time.time() 113 | # 读取某一帧 114 | ref, frame = capture.read() 115 | if not ref: 116 | break 117 | # 格式转变,BGRtoRGB 118 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 119 | # 转变成Image 120 | frame = Image.fromarray(np.uint8(frame)) 121 | # 进行检测 122 | frame = np.array(deeplab.detect_image(frame)) 123 | # RGBtoBGR满足opencv显示格式 124 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR) 125 | 126 | fps = ( fps + (1./(time.time()-t1)) ) / 2 127 | print("fps= %.2f"%(fps)) 128 | # frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 129 | # 130 | # cv2.imshow("video",frame) 131 | c= cv2.waitKey(1) & 0xff 132 | if video_save_path!="": 133 | out.write(frame) 134 | 135 | if c==27: 136 | capture.release() 137 | break 138 | print("Video Detection Done!") 139 | capture.release() 140 | if video_save_path!="": 141 | print("Save processed video to the path :" + video_save_path) 142 | out.release() 143 | cv2.destroyAllWindows() 144 | 145 | elif mode == "fps": 146 | img = Image.open(fps_image_path) 147 | tact_time = deeplab.get_FPS(img, test_interval) 148 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1') 149 | 150 | elif mode == "dir_predict": 151 | import os 152 | from tqdm import tqdm 153 | 154 | img_names = os.listdir(dir_origin_path) 155 | for img_name in tqdm(img_names): 156 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 157 | image_path = os.path.join(dir_origin_path, img_name) 158 | image = Image.open(image_path) 159 | r_image = deeplab.detect_image(image) 160 | if not os.path.exists(dir_save_path): 161 | os.makedirs(dir_save_path) 162 | r_image.save(os.path.join(dir_save_path, img_name)) 163 | elif mode == "export_onnx": 164 | deeplab.convert_to_onnx(simplify, onnx_save_path) 165 | 166 | else: 167 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.") 168 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.1.0 2 | einops==0.4.1 3 | matplotlib==3.4.2 4 | numpy==1.21.6 5 | onnx==1.11.0 6 | onnxsim==0.4.17 7 | opencv_python_headless==4.5.5.64 8 | pandas==1.3.5 9 | Pillow==9.4.0 10 | pycocotools==2.0.5 11 | scikit_learn==1.0.2 12 | scipy==1.7.3 13 | thop==0.0.31.post2005241907 14 | timm==0.6.7 15 | torch==1.9.0 16 | torchinfo==1.7.0 17 | torchvision==0.10.0 18 | tqdm==4.64.0 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__init__.py 
-------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/callbacks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/callbacks.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/multitaskloss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/multitaskloss.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_bbox.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/utils_bbox.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_fit.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/utils_fit.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_map.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/utils_map.cpython-39.pyc -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import torch 4 | import matplotlib 5 | 6 | matplotlib.use('Agg') 7 | import scipy.signal 8 | from matplotlib import pyplot as plt 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | import shutil 12 | import numpy as np 13 | 14 | from PIL import Image 15 | from tqdm import tqdm 16 | from .utils import cvtColor, preprocess_input, resize_image 17 | from .utils_bbox import decode_outputs, non_max_suppression 18 | from .utils_map import get_coco_map, get_map 19 | 20 | 21 | class LossHistory(): 22 | def __init__(self, log_dir, model, input_shape): 23 | 
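# Usage sketch (illustrative note based on the methods below, not part of the original source): the training loop
# builds one LossHistory per run and calls append_loss(epoch, loss, val_loss) once per epoch; losses are appended
# to epoch_loss.txt / epoch_val_loss.txt under log_dir, logged to TensorBoard, and re-plotted to epoch_loss.png
# by loss_plot(), e.g.
#   history = LossHistory("logs/exp1", model, input_shape=(512, 512))   # "logs/exp1" is a hypothetical path
#   history.append_loss(1, train_loss, val_loss)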
self.log_dir = log_dir 24 | self.losses = [] 25 | self.val_loss = [] 26 | 27 | os.makedirs(self.log_dir) 28 | self.writer = SummaryWriter(self.log_dir) 29 | try: 30 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) 31 | self.writer.add_graph(model, dummy_input) 32 | except: 33 | pass 34 | 35 | def append_loss(self, epoch, loss, val_loss): 36 | if not os.path.exists(self.log_dir): 37 | os.makedirs(self.log_dir) 38 | 39 | self.losses.append(loss) 40 | self.val_loss.append(val_loss) 41 | 42 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: 43 | f.write(str(loss)) 44 | f.write("\n") 45 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: 46 | f.write(str(val_loss)) 47 | f.write("\n") 48 | 49 | self.writer.add_scalar('loss', loss, epoch) 50 | self.writer.add_scalar('val_loss', val_loss, epoch) 51 | self.loss_plot() 52 | 53 | def loss_plot(self): 54 | iters = range(len(self.losses)) 55 | 56 | plt.figure() 57 | plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss') 58 | plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss') 59 | try: 60 | if len(self.losses) < 25: 61 | num = 5 62 | else: 63 | num = 15 64 | 65 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle='--', linewidth=2, 66 | label='smooth train loss') 67 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle='--', linewidth=2, 68 | label='smooth val loss') 69 | except: 70 | pass 71 | 72 | plt.grid(True) 73 | plt.xlabel('Epoch') 74 | plt.ylabel('Loss') 75 | plt.legend(loc="upper right") 76 | 77 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) 78 | 79 | plt.cla() 80 | plt.close("all") 81 | 82 | 83 | class EvalCallback(): 84 | def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, local_rank, radar_path, \ 85 | map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, 86 | MINOVERLAP=0.5, eval_flag=True, period=1): 87 | super(EvalCallback, self).__init__() 88 | 89 | self.net = net 90 | self.input_shape = input_shape 91 | self.class_names = class_names 92 | self.num_classes = num_classes 93 | self.val_lines = val_lines 94 | self.log_dir = log_dir 95 | self.cuda = cuda 96 | self.local_rank = local_rank 97 | self.map_out_path = map_out_path 98 | self.max_boxes = max_boxes 99 | self.confidence = confidence 100 | self.nms_iou = nms_iou 101 | self.letterbox_image = letterbox_image 102 | self.MINOVERLAP = MINOVERLAP 103 | self.eval_flag = eval_flag 104 | self.period = period 105 | self.radar_path = radar_path 106 | 107 | self.maps = [0] 108 | self.epoches = [0] 109 | if self.eval_flag: 110 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: 111 | f.write(str(0)) 112 | f.write("\n") 113 | 114 | def get_map_txt(self, image_id, image, radar_data, class_names, map_out_path): 115 | f = open(os.path.join(map_out_path, "detection-results/" + image_id + ".txt"), "w") 116 | image_shape = np.array(np.shape(image)[0:2]) 117 | # ---------------------------------------------------------# 118 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 119 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 120 | # ---------------------------------------------------------# 121 | image = cvtColor(image) 122 | # ---------------------------------------------------------# 123 | # 给图像增加灰条,实现不失真的resize 124 | # 也可以直接resize进行识别 125 | # ---------------------------------------------------------# 126 | image_data = resize_image(image, (self.input_shape[1], 
self.input_shape[0]), self.letterbox_image) 127 | # ---------------------------------------------------------# 128 | # 添加上batch_size维度 129 | # ---------------------------------------------------------# 130 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0) 131 | 132 | with torch.no_grad(): 133 | images = torch.from_numpy(image_data) 134 | if self.cuda: 135 | images = images.cuda(self.local_rank) 136 | # ---------------------------------------------------------# 137 | # 将图像输入网络当中进行预测! 138 | # ---------------------------------------------------------# 139 | outputs, _ = self.net(images, radar_data) 140 | outputs = decode_outputs(outputs, self.input_shape, self.local_rank) 141 | # ---------------------------------------------------------# 142 | # 将预测框进行堆叠,然后进行非极大抑制 143 | # ---------------------------------------------------------# 144 | results = non_max_suppression(outputs, self.num_classes, self.input_shape, 145 | image_shape, self.letterbox_image, conf_thres=self.confidence, 146 | nms_thres=self.nms_iou) 147 | 148 | if results[0] is None: 149 | return 150 | 151 | top_label = np.array(results[0][:, 6], dtype='int32') 152 | top_conf = results[0][:, 4] * results[0][:, 5] 153 | top_boxes = results[0][:, :4] 154 | 155 | top_100 = np.argsort(top_conf)[::-1][:self.max_boxes] 156 | top_boxes = top_boxes[top_100] 157 | top_conf = top_conf[top_100] 158 | top_label = top_label[top_100] 159 | 160 | for i, c in list(enumerate(top_label)): 161 | predicted_class = self.class_names[int(c)] 162 | box = top_boxes[i] 163 | score = str(top_conf[i]) 164 | 165 | top, left, bottom, right = box 166 | if predicted_class not in class_names: 167 | continue 168 | 169 | f.write("%s %s %s %s %s %s\n" % ( 170 | predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom)))) 171 | 172 | f.close() 173 | return 174 | 175 | def on_epoch_end(self, epoch, model_eval): 176 | if epoch % self.period == 0 and self.eval_flag: 177 | self.net = model_eval 178 | if not os.path.exists(self.map_out_path): 179 | os.makedirs(self.map_out_path) 180 | if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")): 181 | os.makedirs(os.path.join(self.map_out_path, "ground-truth")) 182 | if not os.path.exists(os.path.join(self.map_out_path, "detection-results")): 183 | os.makedirs(os.path.join(self.map_out_path, "detection-results")) 184 | print("Get map.") 185 | for annotation_line in tqdm(self.val_lines): 186 | line = annotation_line.split() 187 | 188 | # ------------------------------# 189 | # 读取雷达特征map 190 | # ------------------------------# 191 | pattern_string = "\d{10}.\d{5}" 192 | pattern = re.compile(pattern_string) # 查找数字 193 | name = pattern.findall(annotation_line)[-1] 194 | 195 | radar_path = os.path.join(self.radar_path, name + '.npz') 196 | radar_data = np.load(radar_path)['arr_0'] 197 | radar_data = torch.from_numpy(radar_data).type(torch.FloatTensor).unsqueeze(0).cuda(self.local_rank) 198 | 199 | image_id = os.path.basename(line[0]).split('.')[0] 200 | # ------------------------------# 201 | # 读取图像并转换成RGB图像 202 | # ------------------------------# 203 | image = Image.open(line[0]) 204 | # ------------------------------# 205 | # 获得预测框 206 | # ------------------------------# 207 | gt_boxes = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]]) 208 | # ------------------------------# 209 | # 获得预测txt 210 | # ------------------------------# 211 | self.get_map_txt(image_id, image, radar_data, self.class_names, 
self.map_out_path) 212 | 213 | # ------------------------------# 214 | # 获得真实框txt 215 | # ------------------------------# 216 | with open(os.path.join(self.map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f: 217 | for box in gt_boxes: 218 | left, top, right, bottom, obj = box 219 | obj_name = self.class_names[obj] 220 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 221 | 222 | print("Calculate Map.") 223 | try: 224 | temp_map = get_coco_map(class_names=self.class_names, path=self.map_out_path)[1] 225 | except: 226 | temp_map = get_map(self.MINOVERLAP, False, path=self.map_out_path) 227 | self.maps.append(temp_map) 228 | self.epoches.append(epoch) 229 | 230 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: 231 | f.write(str(temp_map)) 232 | f.write("\n") 233 | 234 | plt.figure() 235 | plt.plot(self.epoches, self.maps, 'red', linewidth=2, label='train map') 236 | 237 | plt.grid(True) 238 | plt.xlabel('Epoch') 239 | plt.ylabel('Map %s' % str(self.MINOVERLAP)) 240 | plt.title('A Map Curve') 241 | plt.legend(loc="upper right") 242 | 243 | plt.savefig(os.path.join(self.log_dir, "epoch_map.png")) 244 | plt.cla() 245 | plt.close("all") 246 | 247 | print("Get map done.") 248 | shutil.rmtree(self.map_out_path) 249 | -------------------------------------------------------------------------------- /utils/multitaskloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class MultiTaskLossWrapper(nn.Module): 7 | def __init__(self, task_num): 8 | super().__init__() 9 | self.task_num = task_num 10 | self.log_vars = nn.Parameter(torch.zeros(task_num-1)) 11 | 12 | def forward(self, loss_seg, loss_det): 13 | loss0 = loss_det 14 | 15 | precision1 = torch.exp(-self.log_vars[0]) 16 | loss1 = precision1 * loss_seg + self.log_vars[0] 17 | 18 | return loss0 + loss1 19 | 20 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | #---------------------------------------------------------# 6 | # 将图像转换成RGB图像,防止灰度图在预测时报错。 7 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 8 | #---------------------------------------------------------# 9 | def cvtColor(image): 10 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: 11 | return image 12 | else: 13 | image = image.convert('RGB') 14 | return image 15 | 16 | #---------------------------------------------------# 17 | # 对输入图像进行resize 18 | #---------------------------------------------------# 19 | def resize_image(image, size, letterbox_image): 20 | iw, ih = image.size 21 | w, h = size 22 | if letterbox_image: 23 | scale = min(w/iw, h/ih) 24 | nw = int(iw*scale) 25 | nh = int(ih*scale) 26 | 27 | image = image.resize((nw,nh), Image.BICUBIC) 28 | new_image = Image.new('RGB', size, (128,128,128)) 29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 30 | else: 31 | new_image = image.resize((w, h), Image.BICUBIC) 32 | return new_image 33 | 34 | #---------------------------------------------------# 35 | # 获得类 36 | #---------------------------------------------------# 37 | def get_classes(classes_path): 38 | with open(classes_path, encoding='utf-8') as f: 39 | class_names = f.readlines() 40 | class_names = [c.strip() for c in class_names] 41 | return class_names, len(class_names) 42 | 43 | def preprocess_input(image): 44 | image /= 255.0 45 | 
image -= np.array([0.485, 0.456, 0.406]) 46 | image /= np.array([0.229, 0.224, 0.225]) 47 | return image 48 | 49 | 50 | def preprocess_input_radar(data): 51 | _range = np.max(data) - np.min(data) 52 | data = (data - np.min(data)) / (_range + 0.0000000000001) 53 | return data 54 | 55 | #---------------------------------------------------# 56 | # Get the current learning rate 57 | #---------------------------------------------------# 58 | def get_lr(optimizer): 59 | for param_group in optimizer.param_groups: 60 | return param_group['lr'] 61 | 62 | def show_config(**kwargs): 63 | print('Configurations:') 64 | print('-' * 70) 65 | print('|%25s | %40s|' % ('keys', 'values')) 66 | print('-' * 70) 67 | for key, value in kwargs.items(): 68 | print('|%25s | %40s|' % (str(key), str(value))) 69 | print('-' * 70) 70 | -------------------------------------------------------------------------------- /utils/utils_bbox.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.ops import nms, boxes 4 | 5 | def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image): 6 | #-----------------------------------------------------------------# 7 | # Put the y axis first so the boxes can be multiplied by the image height and width directly 8 | #-----------------------------------------------------------------# 9 | box_yx = box_xy[..., ::-1] 10 | box_hw = box_wh[..., ::-1] 11 | input_shape = np.array(input_shape) 12 | image_shape = np.array(image_shape) 13 | 14 | if letterbox_image: 15 | #-----------------------------------------------------------------# 16 | # offset is the shift of the valid image area relative to the top-left corner of the input 17 | # new_shape is the height and width after letterbox scaling 18 | #-----------------------------------------------------------------# 19 | new_shape = np.round(image_shape * np.min(input_shape/image_shape)) 20 | offset = (input_shape - new_shape)/2./input_shape 21 | scale = input_shape/new_shape 22 | 23 | box_yx = (box_yx - offset) * scale 24 | box_hw *= scale 25 | 26 | box_mins = box_yx - (box_hw / 2.) 27 | box_maxes = box_yx + (box_hw / 2.)
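    #-----------------------------------------------------------------#
    #   At this point (when letterbox_image is True) the offset/scale
    #   step above has already removed the grey padding, so box_mins and
    #   box_maxes are normalised to the original image.  The lines below
    #   stack them as (y_min, x_min, y_max, x_max) and multiply by
    #   image_shape = (h, w) to obtain pixel coordinates.
    #   Illustrative values only: input_shape = (512, 512) and
    #   image_shape = (720, 1280) give new_shape = (288, 512),
    #   offset = (0.21875, 0) and scale = (1.78, 1).
    #-----------------------------------------------------------------#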
28 | boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1) 29 | boxes *= np.concatenate([image_shape, image_shape], axis=-1) 30 | return boxes 31 | 32 | def decode_outputs(outputs, input_shape, local_rank): 33 | grids = [] 34 | strides = [] 35 | hw = [x.shape[-2:] for x in outputs] 36 | #---------------------------------------------------# 37 | # outputs输入前代表每个特征层的预测结果 38 | # batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400 39 | # batch_size, 5 + num_classes, 40, 40 40 | # batch_size, 5 + num_classes, 20, 20 41 | # batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400 42 | # 堆叠后为batch_size, 8400, 5 + num_classes 43 | #---------------------------------------------------# 44 | outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1) 45 | #---------------------------------------------------# 46 | # 获得每一个特征点属于每一个种类的概率 47 | #---------------------------------------------------# 48 | outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:]) 49 | for h, w in hw: 50 | #---------------------------# 51 | # 根据特征层的高宽生成网格点 52 | #---------------------------# 53 | grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)]) 54 | #---------------------------# 55 | # 1, 6400, 2 56 | # 1, 1600, 2 57 | # 1, 400, 2 58 | #---------------------------# 59 | grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2) 60 | shape = grid.shape[:2] 61 | 62 | grids.append(grid) 63 | strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h)) 64 | #---------------------------# 65 | # 将网格点堆叠到一起 66 | # 1, 6400, 2 67 | # 1, 1600, 2 68 | # 1, 400, 2 69 | # 70 | # 1, 8400, 2 71 | #---------------------------# 72 | grids = torch.cat(grids, dim=1).type(outputs.type()).cuda(local_rank) 73 | strides = torch.cat(strides, dim=1).type(outputs.type()).cuda(local_rank) 74 | #------------------------# 75 | # 根据网格点进行解码 76 | #------------------------# 77 | outputs[..., :2] = (outputs[..., :2] + grids) * strides 78 | outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides 79 | #-----------------# 80 | # 归一化 81 | #-----------------# 82 | outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1] 83 | outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0] 84 | return outputs 85 | 86 | def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4): 87 | #----------------------------------------------------------# 88 | # 将预测结果的格式转换成左上角右下角的格式。 89 | # prediction [batch_size, num_anchors, 85] 90 | #----------------------------------------------------------# 91 | box_corner = prediction.new(prediction.shape) 92 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 93 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 94 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 95 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 96 | prediction[:, :, :4] = box_corner[:, :, :4] 97 | 98 | output = [None for _ in range(len(prediction))] 99 | #----------------------------------------------------------# 100 | # 对输入图片进行循环,一般只会进行一次 101 | #----------------------------------------------------------# 102 | for i, image_pred in enumerate(prediction): 103 | #----------------------------------------------------------# 104 | # 对种类预测部分取max。 105 | # class_conf [num_anchors, 1] 种类置信度 106 | # class_pred [num_anchors, 1] 种类 107 | 
#----------------------------------------------------------# 108 | class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True) 109 | 110 | #----------------------------------------------------------# 111 | # 利用置信度进行第一轮筛选 112 | #----------------------------------------------------------# 113 | conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze() 114 | 115 | if not image_pred.size(0): 116 | continue 117 | #-------------------------------------------------------------------------# 118 | # detections [num_anchors, 7] 119 | # 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred 120 | #-------------------------------------------------------------------------# 121 | detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1) 122 | detections = detections[conf_mask] 123 | 124 | nms_out_index = boxes.batched_nms( 125 | detections[:, :4], 126 | detections[:, 4] * detections[:, 5], 127 | detections[:, 6], 128 | nms_thres, 129 | ) 130 | 131 | output[i] = detections[nms_out_index] 132 | 133 | # #------------------------------------------# 134 | # # 获得预测结果中包含的所有种类 135 | # #------------------------------------------# 136 | # unique_labels = detections[:, -1].cpu().unique() 137 | 138 | # if prediction.is_cuda: 139 | # unique_labels = unique_labels.cuda() 140 | # detections = detections.cuda() 141 | 142 | # for c in unique_labels: 143 | # #------------------------------------------# 144 | # # 获得某一类得分筛选后全部的预测结果 145 | # #------------------------------------------# 146 | # detections_class = detections[detections[:, -1] == c] 147 | 148 | # #------------------------------------------# 149 | # # 使用官方自带的非极大抑制会速度更快一些! 150 | # #------------------------------------------# 151 | # keep = nms( 152 | # detections_class[:, :4], 153 | # detections_class[:, 4] * detections_class[:, 5], 154 | # nms_thres 155 | # ) 156 | # max_detections = detections_class[keep] 157 | 158 | # # # 按照存在物体的置信度排序 159 | # # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True) 160 | # # detections_class = detections_class[conf_sort_index] 161 | # # # 进行非极大抑制 162 | # # max_detections = [] 163 | # # while detections_class.size(0): 164 | # # # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉 165 | # # max_detections.append(detections_class[0].unsqueeze(0)) 166 | # # if len(detections_class) == 1: 167 | # # break 168 | # # ious = bbox_iou(max_detections[-1], detections_class[1:]) 169 | # # detections_class = detections_class[1:][ious < nms_thres] 170 | # # # 堆叠 171 | # # max_detections = torch.cat(max_detections).data 172 | 173 | # # Add max detections to outputs 174 | # output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections)) 175 | 176 | if output[i] is not None: 177 | output[i] = output[i].cpu().numpy() 178 | box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2] 179 | output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image) 180 | return output 181 | -------------------------------------------------------------------------------- /utils/utils_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from utils.utils import get_lr 7 | 8 | from nets.deeplabv3_training import (CE_Loss, Dice_loss, Focal_Loss, 9 | weights_init) 10 | 11 | from utils_seg.utils import get_lr 12 | from utils_seg.utils_metrics import f_score 13 | 14 | 
from utils.multitaskloss import MultiTaskLossWrapper 15 | 16 | 17 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, loss_history_seg, eval_callback, eval_callback_seg, optimizer, epoch, epoch_step, 18 | epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, dice_loss, focal_loss, cls_weights, num_class_seg, local_rank=0): 19 | total_loss_det = 0 20 | total_loss_seg = 0 21 | total_f_score = 0 22 | 23 | val_loss_det = 0 24 | val_loss_seg = 0 25 | val_f_score = 0 26 | 27 | total_loss = 0 28 | val_total_loss = 0 29 | 30 | if local_rank >= 0: 31 | print('Start Train') 32 | pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3) 33 | model_train.train() 34 | for iteration, batch in enumerate(gen): 35 | if iteration >= epoch_step: 36 | break 37 | 38 | images, targets, radars, pngs, seg_labels = batch[0], batch[1], batch[2], batch[3], batch[4] 39 | 40 | with torch.no_grad(): 41 | weights = torch.from_numpy(cls_weights) 42 | if cuda: 43 | images = images.cuda(local_rank) 44 | targets = [ann.cuda(local_rank) for ann in targets] 45 | radars = radars.cuda(local_rank) 46 | pngs = pngs.cuda(local_rank) 47 | seg_labels = seg_labels.cuda(local_rank) 48 | weights = weights.cuda(local_rank) 49 | 50 | # ----------------------# 51 | # 清零梯度 52 | # ----------------------# 53 | optimizer.zero_grad() 54 | if not fp16: 55 | # ----------------------# 56 | # 前向传播 57 | # ----------------------# 58 | outputs, outputs_seg = model_train(images, radars) 59 | 60 | if focal_loss: 61 | loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg) 62 | else: 63 | loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg) 64 | 65 | if dice_loss: 66 | main_dice = Dice_loss(outputs_seg, seg_labels) 67 | loss_seg = loss_seg + main_dice 68 | 69 | # ----------------------# 70 | # 计算损失 71 | # ----------------------# 72 | loss_det = yolo_loss(outputs, targets) 73 | 74 | mtl = MultiTaskLossWrapper(task_num=2) 75 | total_loss = mtl(loss_seg, loss_det) 76 | 77 | with torch.no_grad(): 78 | train_f_score = f_score(outputs_seg, seg_labels) 79 | 80 | # ----------------------# 81 | # 反向传播 82 | # ----------------------# 83 | total_loss.backward() 84 | optimizer.step() 85 | else: 86 | from torch.cuda.amp import autocast 87 | with autocast(): 88 | outputs, outputs_seg = model_train(images, radars) 89 | 90 | if focal_loss: 91 | loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg) 92 | else: 93 | loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg) 94 | 95 | if dice_loss: 96 | main_dice = Dice_loss(outputs_seg, seg_labels) 97 | loss_seg = loss_seg + main_dice 98 | 99 | # ----------------------# 100 | # calculate loss 101 | # ----------------------# 102 | loss_det = yolo_loss(outputs, targets) 103 | 104 | # mtl = MultiTaskLossWrapper(task_num=2) 105 | # total_loss = mtl(loss_seg, loss_det) 106 | total_loss = loss_det + 5 * loss_seg 107 | 108 | with torch.no_grad(): 109 | train_f_score = f_score(outputs_seg, seg_labels) 110 | 111 | # ----------------------# 112 | # back-propagation 113 | # ----------------------# 114 | scaler.scale(total_loss).backward() 115 | scaler.step(optimizer) 116 | scaler.update() 117 | if ema: 118 | ema.update(model_train) 119 | 120 | total_loss_det += loss_det.item() 121 | total_loss_seg += loss_seg.item() 122 | total_loss += total_loss_det + total_loss_seg 123 | total_f_score += train_f_score.item() 124 | 125 | if local_rank >= 0: 126 | 
pbar.set_postfix(**{'detection loss': total_loss_det / (iteration + 1), 127 | 'segmentation loss': total_loss_seg / (iteration + 1), 128 | 'total loss': total_loss / (iteration + 1), 129 | 'f score': total_f_score / (iteration + 1), 130 | 'lr': get_lr(optimizer)}) 131 | pbar.update(1) 132 | 133 | if local_rank >= 0: 134 | pbar.close() 135 | print('Finish Train') 136 | print('Start Validation') 137 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3) 138 | 139 | if ema: 140 | model_train_eval = ema.ema 141 | else: 142 | model_train_eval = model_train.eval() 143 | 144 | for iteration, batch in enumerate(gen_val): 145 | if iteration >= epoch_step_val: 146 | break 147 | images, targets, radars, pngs, seg_labels = batch[0], batch[1], batch[2], batch[3], batch[4] 148 | with torch.no_grad(): 149 | if cuda: 150 | images = images.cuda(local_rank) 151 | targets = [ann.cuda(local_rank) for ann in targets] 152 | radars = radars.cuda(local_rank) 153 | pngs = pngs.cuda(local_rank) 154 | seg_labels = seg_labels.cuda(local_rank) 155 | weights = weights.cuda(local_rank) 156 | # ----------------------# 157 | # 清零梯度 158 | # ----------------------# 159 | optimizer.zero_grad() 160 | # ----------------------# 161 | # 前向传播 162 | # ----------------------# 163 | outputs, outputs_seg = model_train(images, radars) 164 | 165 | if focal_loss: 166 | loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg) 167 | else: 168 | loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg) 169 | 170 | if dice_loss: 171 | main_dice = Dice_loss(outputs_seg, seg_labels) 172 | loss_seg = loss_seg + main_dice 173 | 174 | # -------------------------------# 175 | # 计算f_score 176 | # -------------------------------# 177 | _f_score = f_score(outputs_seg, seg_labels) 178 | 179 | # ----------------------# 180 | # 计算损失 181 | # ----------------------# 182 | loss_value = yolo_loss(outputs, targets) 183 | loss_value_seg = loss_seg 184 | val_f_score += _f_score.item() 185 | 186 | val_loss_det += loss_value.item() 187 | val_loss_seg += loss_value_seg.item() 188 | val_total_loss = val_loss_det + val_loss_seg 189 | 190 | if local_rank >= 0: 191 | pbar.set_postfix(**{'detection val_loss': val_loss_det / (iteration + 1), 192 | 'segmentation val_loss': val_loss_seg / (iteration + 1), 193 | 'val loss': val_total_loss / (iteration + 1), 194 | 'f_score': val_f_score / (iteration + 1), 195 | }) 196 | pbar.update(1) 197 | 198 | if local_rank >= 0: 199 | pbar.close() 200 | print('Finish Validation') 201 | loss_history.append_loss(epoch + 1, total_loss_det / epoch_step, val_loss_det / epoch_step_val) 202 | loss_history_seg.append_loss(epoch + 1, total_loss_seg / epoch_step, val_loss_seg / epoch_step_val) 203 | eval_callback.on_epoch_end(epoch + 1, model_train_eval) 204 | eval_callback_seg.on_epoch_end(epoch + 1, model_train_eval) 205 | print('Epoch:' + str(epoch + 1) + '/' + str(Epoch)) 206 | print('Total Loss: %.3f || Val Loss Det: %.3f || Val Loss Seg: %.3f' % ((total_loss / epoch_step, 207 | val_loss_det / epoch_step_val, 208 | val_loss_seg / epoch_step_val))) 209 | 210 | # -----------------------------------------------# 211 | # 保存权值 212 | # -----------------------------------------------# 213 | if ema: 214 | save_state_dict = ema.ema.state_dict() 215 | else: 216 | save_state_dict = model.state_dict() 217 | 218 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 219 | torch.save(save_state_dict, os.path.join(save_dir, 
"ep%03d-loss%.3f-det_val_loss%.3f-seg_val_loss%.3f.pth" % ( 220 | epoch + 1, val_total_loss / epoch_step, val_loss_det / epoch_step_val, val_loss_seg / epoch_step_val))) 221 | 222 | if len(loss_history.val_loss) <= 1 or (val_total_loss / epoch_step_val) <= min(loss_history.val_loss) + min(loss_history_seg.val_loss): 223 | print('Save best model to best_epoch_weights.pth') 224 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth")) 225 | 226 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth")) -------------------------------------------------------------------------------- /utils_seg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__init__.py -------------------------------------------------------------------------------- /utils_seg/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils_seg/__pycache__/callbacks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/callbacks.cpython-39.pyc -------------------------------------------------------------------------------- /utils_seg/__pycache__/dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /utils_seg/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils_seg/__pycache__/utils_fit.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/utils_fit.cpython-39.pyc -------------------------------------------------------------------------------- /utils_seg/__pycache__/utils_metrics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/utils_metrics.cpython-39.pyc -------------------------------------------------------------------------------- /utils_seg/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import matplotlib 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | matplotlib.use('Agg') 8 | from matplotlib import pyplot as plt 9 | import scipy.signal 10 | 11 | import cv2 12 | import shutil 13 | import numpy as np 14 | 15 | from PIL import Image 16 | from tqdm import tqdm 17 | from torch.utils.tensorboard import SummaryWriter 18 | from 
.utils import cvtColor, preprocess_input, resize_image 19 | from .utils_metrics import compute_mIoU 20 | 21 | 22 | class LossHistory(): 23 | def __init__(self, log_dir, model, input_shape): 24 | self.log_dir = log_dir 25 | self.losses = [] 26 | self.val_loss = [] 27 | 28 | os.makedirs(self.log_dir) 29 | self.writer = SummaryWriter(self.log_dir) 30 | try: 31 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) 32 | self.writer.add_graph(model, dummy_input) 33 | except: 34 | pass 35 | 36 | def append_loss(self, epoch, loss, val_loss): 37 | if not os.path.exists(self.log_dir): 38 | os.makedirs(self.log_dir) 39 | 40 | self.losses.append(loss) 41 | self.val_loss.append(val_loss) 42 | 43 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: 44 | f.write(str(loss)) 45 | f.write("\n") 46 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: 47 | f.write(str(val_loss)) 48 | f.write("\n") 49 | 50 | self.writer.add_scalar('loss', loss, epoch) 51 | self.writer.add_scalar('val_loss', val_loss, epoch) 52 | self.loss_plot() 53 | 54 | def loss_plot(self): 55 | iters = range(len(self.losses)) 56 | 57 | plt.figure() 58 | plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss') 59 | plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss') 60 | try: 61 | if len(self.losses) < 25: 62 | num = 5 63 | else: 64 | num = 15 65 | 66 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle='--', linewidth=2, 67 | label='smooth train loss') 68 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle='--', linewidth=2, 69 | label='smooth val loss') 70 | except: 71 | pass 72 | 73 | plt.grid(True) 74 | plt.xlabel('Epoch') 75 | plt.ylabel('Loss') 76 | plt.legend(loc="upper right") 77 | 78 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) 79 | 80 | plt.cla() 81 | plt.close("all") 82 | 83 | 84 | class EvalCallback(): 85 | def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, local_rank, radar_path, \ 86 | miou_out_path=".temp_miou_out", eval_flag=True, period=1): 87 | super(EvalCallback, self).__init__() 88 | 89 | self.net = net 90 | self.input_shape = input_shape 91 | self.num_classes = num_classes 92 | self.image_ids = image_ids 93 | self.dataset_path = dataset_path 94 | self.log_dir = log_dir 95 | self.cuda = cuda 96 | self.local_rank = local_rank 97 | self.miou_out_path = miou_out_path 98 | self.eval_flag = eval_flag 99 | self.period = period 100 | self.radar_path = radar_path 101 | 102 | pattern_string = "\d{10}.\d{5}" 103 | pattern = re.compile(pattern_string) # 查找数字 104 | self.image_ids = [pattern.findall(image_id)[-1] for image_id in image_ids] 105 | 106 | self.mious = [0] 107 | self.epoches = [0] 108 | if self.eval_flag: 109 | with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f: 110 | f.write(str(0)) 111 | f.write("\n") 112 | 113 | def get_miou_png(self, image, radar_data): 114 | # ---------------------------------------------------------# 115 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 116 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 117 | # ---------------------------------------------------------# 118 | image = cvtColor(image) 119 | orininal_h = np.array(image).shape[0] 120 | orininal_w = np.array(image).shape[1] 121 | # ---------------------------------------------------------# 122 | # 给图像增加灰条,实现不失真的resize 123 | # 也可以直接resize进行识别 124 | # ---------------------------------------------------------# 125 | image_data, nw, nh = 
resize_image(image, (self.input_shape[1], self.input_shape[0])) 126 | # ---------------------------------------------------------# 127 | # 添加上batch_size维度 128 | # ---------------------------------------------------------# 129 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0) 130 | 131 | with torch.no_grad(): 132 | images = torch.from_numpy(image_data) 133 | if self.cuda: 134 | images = images.cuda(self.local_rank) 135 | radar_data = radar_data.cuda(self.local_rank) 136 | 137 | # ---------------------------------------------------# 138 | # 图片传入网络进行预测 139 | # ---------------------------------------------------# 140 | pr = self.net(images, radar_data)[1][0] 141 | # ---------------------------------------------------# 142 | # 取出每一个像素点的种类 143 | # ---------------------------------------------------# 144 | pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy() 145 | # --------------------------------------# 146 | # 将灰条部分截取掉 147 | # --------------------------------------# 148 | pr = pr[int((self.input_shape[0] - nh) // 2): int((self.input_shape[0] - nh) // 2 + nh), \ 149 | int((self.input_shape[1] - nw) // 2): int((self.input_shape[1] - nw) // 2 + nw)] 150 | # ---------------------------------------------------# 151 | # 进行图片的resize 152 | # ---------------------------------------------------# 153 | pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR) 154 | # ---------------------------------------------------# 155 | # 取出每一个像素点的种类 156 | # ---------------------------------------------------# 157 | pr = pr.argmax(axis=-1) 158 | 159 | image = Image.fromarray(np.uint8(pr)) 160 | return image 161 | 162 | def on_epoch_end(self, epoch, model_eval): 163 | if epoch % self.period == 0 and self.eval_flag: 164 | self.net = model_eval 165 | gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/") 166 | pred_dir = os.path.join(self.miou_out_path, 'detection-results') 167 | if not os.path.exists(self.miou_out_path): 168 | os.makedirs(self.miou_out_path) 169 | if not os.path.exists(pred_dir): 170 | os.makedirs(pred_dir) 171 | print("Get miou.") 172 | for image_id in tqdm(self.image_ids): 173 | # ------------------------------# 174 | # 读取雷达特征map 175 | # ------------------------------# 176 | radar_path = os.path.join(self.radar_path, image_id + '.npz') 177 | radar_data = np.load(radar_path)['arr_0'] 178 | radar_data = torch.from_numpy(radar_data).type(torch.FloatTensor).unsqueeze(0).cuda(self.local_rank) 179 | 180 | # -------------------------------# 181 | # 从文件中读取图像 182 | # -------------------------------# 183 | image_path = os.path.join(self.dataset_path, "VOC2007/JPEGImages/" + image_id + ".jpg") 184 | image = Image.open(image_path) 185 | # ------------------------------# 186 | # 获得预测txt 187 | # ------------------------------# 188 | image = self.get_miou_png(image, radar_data) 189 | image.save(os.path.join(pred_dir, image_id + ".png")) 190 | 191 | print("Calculate miou.") 192 | _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None) # 执行计算mIoU的函数 193 | temp_miou = np.nanmean(IoUs) * 100 194 | 195 | self.mious.append(temp_miou) 196 | self.epoches.append(epoch) 197 | 198 | with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f: 199 | f.write(str(temp_miou)) 200 | f.write("\n") 201 | 202 | plt.figure() 203 | plt.plot(self.epoches, self.mious, 'red', linewidth=2, label='train miou') 204 | 205 | plt.grid(True) 206 | plt.xlabel('Epoch') 207 | plt.ylabel('Miou') 208 | 
plt.title('A Miou Curve') 209 | plt.legend(loc="upper right") 210 | 211 | plt.savefig(os.path.join(self.log_dir, "epoch_miou.png")) 212 | plt.cla() 213 | plt.close("all") 214 | 215 | print("Get miou done.") 216 | shutil.rmtree(self.miou_out_path) 217 | -------------------------------------------------------------------------------- /utils_seg/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data.dataset import Dataset 8 | 9 | from utils_seg.utils import cvtColor, preprocess_input 10 | 11 | 12 | class DeeplabDataset(Dataset): 13 | def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path): 14 | super(DeeplabDataset, self).__init__() 15 | self.annotation_lines = annotation_lines 16 | self.length = len(annotation_lines) 17 | self.input_shape = input_shape 18 | self.num_classes = num_classes 19 | self.train = train 20 | self.dataset_path = dataset_path 21 | 22 | def __len__(self): 23 | return self.length 24 | 25 | def __getitem__(self, index): 26 | annotation_line = self.annotation_lines[index] 27 | name = annotation_line.split()[0] 28 | 29 | #-------------------------------# 30 | # 从文件中读取图像 31 | #-------------------------------# 32 | jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg")) 33 | png = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png")) 34 | #-------------------------------# 35 | # 数据增强 36 | #-------------------------------# 37 | jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train) 38 | 39 | jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1]) 40 | png = np.array(png) 41 | png[png >= self.num_classes] = self.num_classes 42 | #-------------------------------------------------------# 43 | # 转化成one_hot的形式 44 | # 在这里需要+1是因为voc数据集有些标签具有白边部分 45 | # 我们需要将白边部分进行忽略,+1的目的是方便忽略。 46 | #-------------------------------------------------------# 47 | seg_labels = np.eye(self.num_classes+1)[png.reshape([-1])] 48 | seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1)) 49 | 50 | return jpg, png, seg_labels 51 | 52 | def rand(self, a=0, b=1): 53 | return np.random.rand() * (b - a) + a 54 | 55 | def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True): 56 | image = cvtColor(image) 57 | label = Image.fromarray(np.array(label)) 58 | #------------------------------# 59 | # 获得图像的高宽与目标高宽 60 | #------------------------------# 61 | iw, ih = image.size 62 | h, w = input_shape 63 | 64 | if not random: 65 | iw, ih = image.size 66 | scale = min(w/iw, h/ih) 67 | nw = int(iw*scale) 68 | nh = int(ih*scale) 69 | 70 | image = image.resize((nw,nh), Image.BICUBIC) 71 | new_image = Image.new('RGB', [w, h], (128,128,128)) 72 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 73 | 74 | label = label.resize((nw,nh), Image.NEAREST) 75 | new_label = Image.new('L', [w, h], (0)) 76 | new_label.paste(label, ((w-nw)//2, (h-nh)//2)) 77 | return new_image, new_label 78 | 79 | #------------------------------------------# 80 | # 对图像进行缩放并且进行长和宽的扭曲 81 | #------------------------------------------# 82 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter) 83 | scale = self.rand(0.25, 2) 84 | if new_ar < 1: 85 | nh = int(scale*h) 86 | nw = int(nh*new_ar) 87 | else: 88 | nw = 
int(scale*w) 89 | nh = int(nw/new_ar) 90 | image = image.resize((nw,nh), Image.BICUBIC) 91 | label = label.resize((nw,nh), Image.NEAREST) 92 | 93 | #------------------------------------------# 94 | # 翻转图像 95 | #------------------------------------------# 96 | flip = self.rand()<.5 97 | if flip: 98 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 99 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 100 | 101 | #------------------------------------------# 102 | # 将图像多余的部分加上灰条 103 | #------------------------------------------# 104 | dx = int(self.rand(0, w-nw)) 105 | dy = int(self.rand(0, h-nh)) 106 | new_image = Image.new('RGB', (w,h), (128,128,128)) 107 | new_label = Image.new('L', (w,h), (0)) 108 | new_image.paste(image, (dx, dy)) 109 | new_label.paste(label, (dx, dy)) 110 | image = new_image 111 | label = new_label 112 | 113 | image_data = np.array(image, np.uint8) 114 | 115 | #------------------------------------------# 116 | # 高斯模糊 117 | #------------------------------------------# 118 | blur = self.rand() < 0.25 119 | if blur: 120 | image_data = cv2.GaussianBlur(image_data, (5, 5), 0) 121 | 122 | #------------------------------------------# 123 | # 旋转 124 | #------------------------------------------# 125 | rotate = self.rand() < 0.25 126 | if rotate: 127 | center = (w // 2, h // 2) 128 | rotation = np.random.randint(-10, 11) 129 | M = cv2.getRotationMatrix2D(center, -rotation, scale=1) 130 | image_data = cv2.warpAffine(image_data, M, (w, h), flags=cv2.INTER_CUBIC, borderValue=(128,128,128)) 131 | label = cv2.warpAffine(np.array(label, np.uint8), M, (w, h), flags=cv2.INTER_NEAREST, borderValue=(0)) 132 | 133 | #---------------------------------# 134 | # 对图像进行色域变换 135 | # 计算色域变换的参数 136 | #---------------------------------# 137 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1 138 | #---------------------------------# 139 | # 将图像转到HSV上 140 | #---------------------------------# 141 | hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV)) 142 | dtype = image_data.dtype 143 | #---------------------------------# 144 | # 应用变换 145 | #---------------------------------# 146 | x = np.arange(0, 256, dtype=r.dtype) 147 | lut_hue = ((x * r[0]) % 180).astype(dtype) 148 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) 149 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype) 150 | 151 | image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) 152 | image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB) 153 | 154 | return image_data, label 155 | 156 | 157 | # DataLoader中collate_fn使用 158 | def deeplab_dataset_collate(batch): 159 | images = [] 160 | pngs = [] 161 | seg_labels = [] 162 | for img, png, labels in batch: 163 | images.append(img) 164 | pngs.append(png) 165 | seg_labels.append(labels) 166 | images = torch.from_numpy(np.array(images)).type(torch.FloatTensor) 167 | pngs = torch.from_numpy(np.array(pngs)).long() 168 | seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor) 169 | return images, pngs, seg_labels 170 | -------------------------------------------------------------------------------- /utils_seg/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | #---------------------------------------------------------# 6 | # 将图像转换成RGB图像,防止灰度图在预测时报错。 7 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 8 | #---------------------------------------------------------# 9 | def cvtColor(image): 10 | if len(np.shape(image)) == 3 and 
np.shape(image)[2] == 3: 11 | return image 12 | else: 13 | image = image.convert('RGB') 14 | return image 15 | 16 | #---------------------------------------------------# 17 | # 对输入图像进行resize 18 | #---------------------------------------------------# 19 | def resize_image(image, size): 20 | iw, ih = image.size 21 | w, h = size 22 | 23 | scale = min(w/iw, h/ih) 24 | nw = int(iw*scale) 25 | nh = int(ih*scale) 26 | 27 | image = image.resize((nw,nh), Image.BICUBIC) 28 | new_image = Image.new('RGB', size, (128,128,128)) 29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 30 | 31 | return new_image, nw, nh 32 | 33 | #---------------------------------------------------# 34 | # 获得学习率 35 | #---------------------------------------------------# 36 | def get_lr(optimizer): 37 | for param_group in optimizer.param_groups: 38 | return param_group['lr'] 39 | 40 | def preprocess_input(image): 41 | image /= 255.0 42 | image -= np.array([0.485, 0.456, 0.406]) 43 | image /= np.array([0.229, 0.224, 0.225]) 44 | return image 45 | 46 | def show_config(**kwargs): 47 | print('Configurations:') 48 | print('-' * 70) 49 | print('|%25s | %40s|' % ('keys', 'values')) 50 | print('-' * 70) 51 | for key, value in kwargs.items(): 52 | print('|%25s | %40s|' % (str(key), str(value))) 53 | print('-' * 70) 54 | 55 | def download_weights(backbone, model_dir="./model_data"): 56 | import os 57 | from torch.hub import load_state_dict_from_url 58 | 59 | download_urls = { 60 | 'mobilenet' : 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/mobilenet_v2.pth.tar', 61 | 'xception' : 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/xception_pytorch_imagenet.pth', 62 | } 63 | url = download_urls[backbone] 64 | 65 | if not os.path.exists(model_dir): 66 | os.makedirs(model_dir) 67 | load_state_dict_from_url(url, model_dir) -------------------------------------------------------------------------------- /utils_seg/utils_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from nets.deeplabv3_training import (CE_Loss, Dice_loss, Focal_Loss, 5 | weights_init) 6 | from tqdm import tqdm 7 | 8 | from utils_seg.utils import get_lr 9 | from utils_seg.utils_metrics import f_score 10 | 11 | 12 | 13 | def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, \ 14 | fp16, scaler, save_period, save_dir, local_rank=0): 15 | total_loss = 0 16 | total_f_score = 0 17 | 18 | val_loss = 0 19 | val_f_score = 0 20 | 21 | if local_rank == 0: 22 | print('Start Train') 23 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 24 | model_train.train() 25 | for iteration, batch in enumerate(gen): 26 | if iteration >= epoch_step: 27 | break 28 | imgs, targets, radars, pngs, labels = batch 29 | 30 | with torch.no_grad(): 31 | weights = torch.from_numpy(cls_weights) 32 | if cuda: 33 | imgs = imgs.cuda(local_rank) 34 | targets = [ann.cuda(local_rank) for ann in targets] 35 | radars = radars.cuda(local_rank) 36 | pngs = pngs.cuda(local_rank) 37 | labels = labels.cuda(local_rank) 38 | weights = weights.cuda(local_rank) 39 | 40 | optimizer.zero_grad() 41 | if not fp16: 42 | _, outputs = model_train(imgs, radars) 43 | if focal_loss: 44 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 45 | else: 46 | loss = CE_Loss(outputs, pngs, weights, num_classes = 
num_classes) 47 | 48 | if dice_loss: 49 | main_dice = Dice_loss(outputs, labels) 50 | loss = loss + main_dice 51 | 52 | with torch.no_grad(): 53 | #-------------------------------# 54 | # 计算f_score 55 | #-------------------------------# 56 | _f_score = f_score(outputs, labels) 57 | 58 | loss.backward() 59 | optimizer.step() 60 | else: 61 | from torch.cuda.amp import autocast 62 | with autocast(): 63 | _, outputs = model_train(imgs, radars) 64 | if focal_loss: 65 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 66 | else: 67 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) 68 | 69 | if dice_loss: 70 | main_dice = Dice_loss(outputs, labels) 71 | loss = loss + main_dice 72 | 73 | with torch.no_grad(): 74 | #-------------------------------# 75 | # 计算f_score 76 | #-------------------------------# 77 | _f_score = f_score(outputs, labels) 78 | 79 | #----------------------# 80 | # 反向传播 81 | #----------------------# 82 | scaler.scale(loss).backward() 83 | scaler.step(optimizer) 84 | scaler.update() 85 | 86 | total_loss += loss.item() 87 | total_f_score += _f_score.item() 88 | 89 | if local_rank == 0: 90 | pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1), 91 | 'f_score' : total_f_score / (iteration + 1), 92 | 'lr' : get_lr(optimizer)}) 93 | pbar.update(1) 94 | 95 | if local_rank == 0: 96 | pbar.close() 97 | print('Finish Train') 98 | print('Start Validation') 99 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 100 | 101 | model_train.eval() 102 | for iteration, batch in enumerate(gen_val): 103 | if iteration >= epoch_step_val: 104 | break 105 | imgs, targets, radars, pngs, labels = batch 106 | with torch.no_grad(): 107 | weights = torch.from_numpy(cls_weights) 108 | if cuda: 109 | imgs = imgs.cuda(local_rank) 110 | targets = [ann.cuda(local_rank) for ann in targets] 111 | radars = radars.cuda(local_rank) 112 | pngs = pngs.cuda(local_rank) 113 | labels = labels.cuda(local_rank) 114 | weights = weights.cuda(local_rank) 115 | 116 | _, outputs = model_train(imgs, radars) 117 | if focal_loss: 118 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes) 119 | else: 120 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes) 121 | 122 | if dice_loss: 123 | main_dice = Dice_loss(outputs, labels) 124 | loss = loss + main_dice 125 | #-------------------------------# 126 | # 计算f_score 127 | #-------------------------------# 128 | _f_score = f_score(outputs, labels) 129 | 130 | val_loss += loss.item() 131 | val_f_score += _f_score.item() 132 | 133 | if local_rank == 0: 134 | pbar.set_postfix(**{'val_loss' : val_loss / (iteration + 1), 135 | 'f_score' : val_f_score / (iteration + 1), 136 | 'lr' : get_lr(optimizer)}) 137 | pbar.update(1) 138 | 139 | if local_rank == 0: 140 | pbar.close() 141 | print('Finish Validation') 142 | loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val) 143 | eval_callback.on_epoch_end(epoch + 1, model_train) 144 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch)) 145 | print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val)) 146 | 147 | #-----------------------------------------------# 148 | # 保存权值 149 | #-----------------------------------------------# 150 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 151 | torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth' % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val))) 152 | 
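    #-----------------------------------------------#
    #   Checkpoint policy: a snapshot named with the epoch and losses is
    #   written every `save_period` epochs (and at the final epoch) above;
    #   "best_epoch_weights.pth" below is refreshed whenever the current
    #   validation loss is the lowest recorded so far in this run; and
    #   "last_epoch_weights.pth" is overwritten at the end of every epoch.
    #-----------------------------------------------#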
153 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): 154 | print('Save best model to best_epoch_weights.pth') 155 | torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth")) 156 | 157 | torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth")) -------------------------------------------------------------------------------- /utils_seg/utils_metrics.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from os.path import join 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | 11 | 12 | def f_score(inputs, target, beta=1, smooth=1e-5, threhold=0.5): 13 | n, c, h, w = inputs.size() 14 | nt, ht, wt, ct = target.size() 15 | if h != ht and w != wt: 16 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 17 | 18 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c), -1) 19 | temp_target = target.view(n, -1, ct) 20 | 21 | # --------------------------------------------# 22 | # 计算dice系数 23 | # --------------------------------------------# 24 | temp_inputs = torch.gt(temp_inputs, threhold).float() 25 | tp = torch.sum(temp_target[..., :-1] * temp_inputs, axis=[0, 1]) 26 | fp = torch.sum(temp_inputs, axis=[0, 1]) - tp 27 | fn = torch.sum(temp_target[..., :-1], axis=[0, 1]) - tp 28 | 29 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) 30 | score = torch.mean(score) 31 | return score 32 | 33 | 34 | # 设标签宽W,长H 35 | def fast_hist(a, b, n): 36 | # --------------------------------------------------------------------------------# 37 | # a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,) 38 | # --------------------------------------------------------------------------------# 39 | k = (a >= 0) & (a < n) 40 | # --------------------------------------------------------------------------------# 41 | # np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n) 42 | # 返回中,写对角线上的为分类正确的像素点 43 | # --------------------------------------------------------------------------------# 44 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) 45 | 46 | 47 | def per_class_iu(hist): 48 | return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) 49 | 50 | 51 | def per_class_PA_Recall(hist): 52 | return np.diag(hist) / np.maximum(hist.sum(1), 1) 53 | 54 | 55 | def per_class_Precision(hist): 56 | return np.diag(hist) / np.maximum(hist.sum(0), 1) 57 | 58 | 59 | def per_Accuracy(hist): 60 | return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1) 61 | 62 | 63 | def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None): 64 | print('Num classes', num_classes) 65 | # -----------------------------------------# 66 | # 创建一个全是0的矩阵,是一个混淆矩阵 67 | # -----------------------------------------# 68 | hist = np.zeros((num_classes, num_classes)) 69 | 70 | # ------------------------------------------------# 71 | # 获得验证集标签路径列表,方便直接读取 72 | # 获得验证集图像分割结果路径列表,方便直接读取 73 | # ------------------------------------------------# 74 | gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list] 75 | pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list] 76 | 77 | # ------------------------------------------------# 78 | # 读取每一个(图片-标签)对 79 | # ------------------------------------------------# 80 | for ind in 
range(len(gt_imgs)): 81 | # ------------------------------------------------# 82 | # 读取一张图像分割结果,转化成numpy数组 83 | # ------------------------------------------------# 84 | 85 | pred = np.array(Image.open(pred_imgs[ind])) 86 | # ------------------------------------------------# 87 | # 读取一张对应的标签,转化成numpy数组 88 | # ------------------------------------------------# 89 | label = np.array(Image.open(gt_imgs[ind])) 90 | 91 | # 如果图像分割结果与标签的大小不一样,这张图片就不计算 92 | if len(label.flatten()) != len(pred.flatten()): 93 | print( 94 | 'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format( 95 | len(label.flatten()), len(pred.flatten()), gt_imgs[ind], 96 | pred_imgs[ind])) 97 | continue 98 | 99 | # ------------------------------------------------# 100 | # 对一张图片计算21×21的hist矩阵,并累加 101 | # ------------------------------------------------# 102 | hist += fast_hist(label.flatten(), pred.flatten(), num_classes) 103 | # 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值 104 | if name_classes is not None and ind > 0 and ind % 10 == 0: 105 | print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format( 106 | ind, 107 | len(gt_imgs), 108 | 100 * np.nanmean(per_class_iu(hist)), 109 | 100 * np.nanmean(per_class_PA_Recall(hist)), 110 | 100 * per_Accuracy(hist) 111 | ) 112 | ) 113 | # ------------------------------------------------# 114 | # 计算所有验证集图片的逐类别mIoU值 115 | # ------------------------------------------------# 116 | IoUs = per_class_iu(hist) 117 | PA_Recall = per_class_PA_Recall(hist) 118 | Precision = per_class_Precision(hist) 119 | # ------------------------------------------------# 120 | # 逐类别输出一下mIoU值 121 | # ------------------------------------------------# 122 | if name_classes is not None: 123 | for ind_class in range(num_classes): 124 | print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \ 125 | + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2)) + '; Precision-' + str( 126 | round(Precision[ind_class] * 100, 2))) 127 | 128 | # -----------------------------------------------------------------# 129 | # 在所有验证集图像上求所有类别平均的mIoU值,计算时忽略NaN值 130 | # -----------------------------------------------------------------# 131 | print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str( 132 | round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2))) 133 | return np.array(hist, np.int), IoUs, PA_Recall, Precision 134 | 135 | 136 | def adjust_axes(r, t, fig, axes): 137 | bb = t.get_window_extent(renderer=r) 138 | text_width_inches = bb.width / fig.dpi 139 | current_fig_width = fig.get_figwidth() 140 | new_fig_width = current_fig_width + text_width_inches 141 | propotion = new_fig_width / current_fig_width 142 | x_lim = axes.get_xlim() 143 | axes.set_xlim([x_lim[0], x_lim[1] * propotion]) 144 | 145 | 146 | def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size=12, plt_show=True): 147 | fig = plt.gcf() 148 | axes = plt.gca() 149 | plt.barh(range(len(values)), values, color='royalblue') 150 | plt.title(plot_title, fontsize=tick_font_size + 2) 151 | plt.xlabel(x_label, fontsize=tick_font_size) 152 | plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size) 153 | r = fig.canvas.get_renderer() 154 | for i, val in enumerate(values): 155 | str_val = " " + str(val) 156 | if val < 1.0: 157 | str_val = " {0:.2f}".format(val) 158 | t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold') 159 | if i == (len(values) - 1): 160 | adjust_axes(r, t, fig, 
axes) 161 | 162 | fig.tight_layout() 163 | fig.savefig(output_path) 164 | if plt_show: 165 | plt.show() 166 | plt.close() 167 | 168 | 169 | def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size=12): 170 | draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs) * 100), "Intersection over Union", \ 171 | os.path.join(miou_out_path, "mIoU.png"), tick_font_size=tick_font_size, plt_show=True) 172 | print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png")) 173 | 174 | draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall) * 100), "Pixel Accuracy", \ 175 | os.path.join(miou_out_path, "mPA.png"), tick_font_size=tick_font_size, plt_show=False) 176 | print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png")) 177 | 178 | draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall) * 100), "Recall", \ 179 | os.path.join(miou_out_path, "Recall.png"), tick_font_size=tick_font_size, plt_show=False) 180 | print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png")) 181 | 182 | draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision) * 100), "Precision", \ 183 | os.path.join(miou_out_path, "Precision.png"), tick_font_size=tick_font_size, plt_show=False) 184 | print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png")) 185 | 186 | with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f: 187 | writer = csv.writer(f) 188 | writer_list = [] 189 | writer_list.append([' '] + [str(c) for c in name_classes]) 190 | for i in range(len(hist)): 191 | writer_list.append([name_classes[i]] + [str(x) for x in hist[i]]) 192 | writer.writerows(writer_list) 193 | print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv")) 194 | -------------------------------------------------------------------------------- /venv/Scripts/Activate.ps1: -------------------------------------------------------------------------------- 1 | <# 2 | .Synopsis 3 | Activate a Python virtual environment for the current PowerShell session. 4 | 5 | .Description 6 | Pushes the python executable for a virtual environment to the front of the 7 | $Env:PATH environment variable and sets the prompt to signify that you are 8 | in a Python virtual environment. Makes use of the command line switches as 9 | well as the `pyvenv.cfg` file values present in the virtual environment. 10 | 11 | .Parameter VenvDir 12 | Path to the directory that contains the virtual environment to activate. The 13 | default value for this is the parent of the directory that the Activate.ps1 14 | script is located within. 15 | 16 | .Parameter Prompt 17 | The prompt prefix to display when this virtual environment is activated. By 18 | default, this prompt is the name of the virtual environment folder (VenvDir) 19 | surrounded by parentheses and followed by a single space (ie. '(.venv) '). 20 | 21 | .Example 22 | Activate.ps1 23 | Activates the Python virtual environment that contains the Activate.ps1 script. 24 | 25 | .Example 26 | Activate.ps1 -Verbose 27 | Activates the Python virtual environment that contains the Activate.ps1 script, 28 | and shows extra information about the activation as it executes. 29 | 30 | .Example 31 | Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv 32 | Activates the Python virtual environment located in the specified location. 
33 | 34 | .Example 35 | Activate.ps1 -Prompt "MyPython" 36 | Activates the Python virtual environment that contains the Activate.ps1 script, 37 | and prefixes the current prompt with the specified string (surrounded in 38 | parentheses) while the virtual environment is active. 39 | 40 | .Notes 41 | On Windows, it may be required to enable this Activate.ps1 script by setting the 42 | execution policy for the user. You can do this by issuing the following PowerShell 43 | command: 44 | 45 | PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser 46 | 47 | For more information on Execution Policies: 48 | https://go.microsoft.com/fwlink/?LinkID=135170 49 | 50 | #> 51 | Param( 52 | [Parameter(Mandatory = $false)] 53 | [String] 54 | $VenvDir, 55 | [Parameter(Mandatory = $false)] 56 | [String] 57 | $Prompt 58 | ) 59 | 60 | <# Function declarations --------------------------------------------------- #> 61 | 62 | <# 63 | .Synopsis 64 | Remove all shell session elements added by the Activate script, including the 65 | addition of the virtual environment's Python executable from the beginning of 66 | the PATH variable. 67 | 68 | .Parameter NonDestructive 69 | If present, do not remove this function from the global namespace for the 70 | session. 71 | 72 | #> 73 | function global:deactivate ([switch]$NonDestructive) { 74 | # Revert to original values 75 | 76 | # The prior prompt: 77 | if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) { 78 | Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt 79 | Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT 80 | } 81 | 82 | # The prior PYTHONHOME: 83 | if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) { 84 | Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME 85 | Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME 86 | } 87 | 88 | # The prior PATH: 89 | if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) { 90 | Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH 91 | Remove-Item -Path Env:_OLD_VIRTUAL_PATH 92 | } 93 | 94 | # Just remove the VIRTUAL_ENV altogether: 95 | if (Test-Path -Path Env:VIRTUAL_ENV) { 96 | Remove-Item -Path env:VIRTUAL_ENV 97 | } 98 | 99 | # Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether: 100 | if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) { 101 | Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force 102 | } 103 | 104 | # Leave deactivate function in the global namespace if requested: 105 | if (-not $NonDestructive) { 106 | Remove-Item -Path function:deactivate 107 | } 108 | } 109 | 110 | <# 111 | .Description 112 | Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the 113 | given folder, and returns them in a map. 114 | 115 | For each line in the pyvenv.cfg file, if that line can be parsed into exactly 116 | two strings separated by `=` (with any amount of whitespace surrounding the =) 117 | then it is considered a `key = value` line. The left hand string is the key, 118 | the right hand is the value. 119 | 120 | If the value starts with a `'` or a `"` then the first and last character is 121 | stripped from the value before being captured. 122 | 123 | .Parameter ConfigDir 124 | Path to the directory that contains the `pyvenv.cfg` file. 125 | #> 126 | function Get-PyVenvConfig( 127 | [String] 128 | $ConfigDir 129 | ) { 130 | Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg" 131 | 132 | # Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue). 
133 | $pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue 134 | 135 | # An empty map will be returned if no config file is found. 136 | $pyvenvConfig = @{ } 137 | 138 | if ($pyvenvConfigPath) { 139 | 140 | Write-Verbose "File exists, parse `key = value` lines" 141 | $pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath 142 | 143 | $pyvenvConfigContent | ForEach-Object { 144 | $keyval = $PSItem -split "\s*=\s*", 2 145 | if ($keyval[0] -and $keyval[1]) { 146 | $val = $keyval[1] 147 | 148 | # Remove extraneous quotations around a string value. 149 | if ("'""".Contains($val.Substring(0, 1))) { 150 | $val = $val.Substring(1, $val.Length - 2) 151 | } 152 | 153 | $pyvenvConfig[$keyval[0]] = $val 154 | Write-Verbose "Adding Key: '$($keyval[0])'='$val'" 155 | } 156 | } 157 | } 158 | return $pyvenvConfig 159 | } 160 | 161 | 162 | <# Begin Activate script --------------------------------------------------- #> 163 | 164 | # Determine the containing directory of this script 165 | $VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition 166 | $VenvExecDir = Get-Item -Path $VenvExecPath 167 | 168 | Write-Verbose "Activation script is located in path: '$VenvExecPath'" 169 | Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)" 170 | Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)" 171 | 172 | # Set values required in priority: CmdLine, ConfigFile, Default 173 | # First, get the location of the virtual environment, it might not be 174 | # VenvExecDir if specified on the command line. 175 | if ($VenvDir) { 176 | Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values" 177 | } 178 | else { 179 | Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir." 180 | $VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/") 181 | Write-Verbose "VenvDir=$VenvDir" 182 | } 183 | 184 | # Next, read the `pyvenv.cfg` file to determine any required value such 185 | # as `prompt`. 186 | $pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir 187 | 188 | # Next, set the prompt from the command line, or the config file, or 189 | # just use the name of the virtual environment folder. 190 | if ($Prompt) { 191 | Write-Verbose "Prompt specified as argument, using '$Prompt'" 192 | } 193 | else { 194 | Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value" 195 | if ($pyvenvCfg -and $pyvenvCfg['prompt']) { 196 | Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'" 197 | $Prompt = $pyvenvCfg['prompt']; 198 | } 199 | else { 200 | Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)" 201 | Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'" 202 | $Prompt = Split-Path -Path $venvDir -Leaf 203 | } 204 | } 205 | 206 | Write-Verbose "Prompt = '$Prompt'" 207 | Write-Verbose "VenvDir='$VenvDir'" 208 | 209 | # Deactivate any currently active virtual environment, but leave the 210 | # deactivate function in place. 211 | deactivate -nondestructive 212 | 213 | # Now set the environment variable VIRTUAL_ENV, used by many tools to determine 214 | # that there is an activated venv. 
215 | $env:VIRTUAL_ENV = $VenvDir 216 | 217 | if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) { 218 | 219 | Write-Verbose "Setting prompt to '$Prompt'" 220 | 221 | # Set the prompt to include the env name 222 | # Make sure _OLD_VIRTUAL_PROMPT is global 223 | function global:_OLD_VIRTUAL_PROMPT { "" } 224 | Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT 225 | New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt 226 | 227 | function global:prompt { 228 | Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) " 229 | _OLD_VIRTUAL_PROMPT 230 | } 231 | } 232 | 233 | # Clear PYTHONHOME 234 | if (Test-Path -Path Env:PYTHONHOME) { 235 | Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME 236 | Remove-Item -Path Env:PYTHONHOME 237 | } 238 | 239 | # Add the venv to the PATH 240 | Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH 241 | $Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH" 242 | -------------------------------------------------------------------------------- /venv/Scripts/activate: -------------------------------------------------------------------------------- 1 | # This file must be used with "source bin/activate" *from bash* 2 | # you cannot run it directly 3 | 4 | deactivate () { 5 | # reset old environment variables 6 | if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then 7 | PATH="${_OLD_VIRTUAL_PATH:-}" 8 | export PATH 9 | unset _OLD_VIRTUAL_PATH 10 | fi 11 | if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then 12 | PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}" 13 | export PYTHONHOME 14 | unset _OLD_VIRTUAL_PYTHONHOME 15 | fi 16 | 17 | # This should detect bash and zsh, which have a hash command that must 18 | # be called to get it to forget past commands. Without forgetting 19 | # past commands the $PATH changes we made may not be respected 20 | if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then 21 | hash -r 2> /dev/null 22 | fi 23 | 24 | if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then 25 | PS1="${_OLD_VIRTUAL_PS1:-}" 26 | export PS1 27 | unset _OLD_VIRTUAL_PS1 28 | fi 29 | 30 | unset VIRTUAL_ENV 31 | if [ ! "${1:-}" = "nondestructive" ] ; then 32 | # Self destruct! 33 | unset -f deactivate 34 | fi 35 | } 36 | 37 | # unset irrelevant variables 38 | deactivate nondestructive 39 | 40 | VIRTUAL_ENV="E:\Normal_Workspace_Collection\Efficient-VRNet-beta\Efficient-VRNet-beta\venv" 41 | export VIRTUAL_ENV 42 | 43 | _OLD_VIRTUAL_PATH="$PATH" 44 | PATH="$VIRTUAL_ENV/Scripts:$PATH" 45 | export PATH 46 | 47 | # unset PYTHONHOME if set 48 | # this will fail if PYTHONHOME is set to the empty string (which is bad anyway) 49 | # could use `if (set -u; : $PYTHONHOME) ;` in bash 50 | if [ -n "${PYTHONHOME:-}" ] ; then 51 | _OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}" 52 | unset PYTHONHOME 53 | fi 54 | 55 | if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then 56 | _OLD_VIRTUAL_PS1="${PS1:-}" 57 | PS1="(venv) ${PS1:-}" 58 | export PS1 59 | fi 60 | 61 | # This should detect bash and zsh, which have a hash command that must 62 | # be called to get it to forget past commands. 
Without forgetting
63 | # past commands the $PATH changes we made may not be respected
64 | if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
65 |     hash -r 2> /dev/null
66 | fi
67 |
--------------------------------------------------------------------------------
/venv/Scripts/activate.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | rem This file is UTF-8 encoded, so we need to update the current code page while executing it
4 | for /f "tokens=2 delims=:." %%a in ('"%SystemRoot%\System32\chcp.com"') do (
5 |     set _OLD_CODEPAGE=%%a
6 | )
7 | if defined _OLD_CODEPAGE (
8 |     "%SystemRoot%\System32\chcp.com" 65001 > nul
9 | )
10 |
11 | set VIRTUAL_ENV=E:\Normal_Workspace_Collection\Efficient-VRNet-beta\Efficient-VRNet-beta\venv
12 |
13 | if not defined PROMPT set PROMPT=$P$G
14 |
15 | if defined _OLD_VIRTUAL_PROMPT set PROMPT=%_OLD_VIRTUAL_PROMPT%
16 | if defined _OLD_VIRTUAL_PYTHONHOME set PYTHONHOME=%_OLD_VIRTUAL_PYTHONHOME%
17 |
18 | set _OLD_VIRTUAL_PROMPT=%PROMPT%
19 | set PROMPT=(venv) %PROMPT%
20 |
21 | if defined PYTHONHOME set _OLD_VIRTUAL_PYTHONHOME=%PYTHONHOME%
22 | set PYTHONHOME=
23 |
24 | if defined _OLD_VIRTUAL_PATH set PATH=%_OLD_VIRTUAL_PATH%
25 | if not defined _OLD_VIRTUAL_PATH set _OLD_VIRTUAL_PATH=%PATH%
26 |
27 | set PATH=%VIRTUAL_ENV%\Scripts;%PATH%
28 |
29 | :END
30 | if defined _OLD_CODEPAGE (
31 |     "%SystemRoot%\System32\chcp.com" %_OLD_CODEPAGE% > nul
32 |     set _OLD_CODEPAGE=
33 | )
34 |
--------------------------------------------------------------------------------
/venv/Scripts/deactivate.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | if defined _OLD_VIRTUAL_PROMPT (
4 |     set "PROMPT=%_OLD_VIRTUAL_PROMPT%"
5 | )
6 | set _OLD_VIRTUAL_PROMPT=
7 |
8 | if defined _OLD_VIRTUAL_PYTHONHOME (
9 |     set "PYTHONHOME=%_OLD_VIRTUAL_PYTHONHOME%"
10 |     set _OLD_VIRTUAL_PYTHONHOME=
11 | )
12 |
13 | if defined _OLD_VIRTUAL_PATH (
14 |     set "PATH=%_OLD_VIRTUAL_PATH%"
15 | )
16 |
17 | set _OLD_VIRTUAL_PATH=
18 |
19 | set VIRTUAL_ENV=
20 |
21 | :END
22 |
--------------------------------------------------------------------------------
/venv/Scripts/python.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/venv/Scripts/python.exe
--------------------------------------------------------------------------------
/venv/Scripts/pythonw.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/venv/Scripts/pythonw.exe
--------------------------------------------------------------------------------
/venv/pyvenv.cfg:
--------------------------------------------------------------------------------
1 | home = D:\Anaconda\Software
2 | include-system-site-packages = false
3 | version = 3.9.12
4 |
--------------------------------------------------------------------------------
/voc_annotation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import xml.etree.ElementTree as ET
4 |
5 | from utils.utils import get_classes
6 |
7 | # --------------------------------------------------------------------------------------------------------------------------------#
8 | #   annotation_mode controls what this script computes when it runs
9 | #   annotation_mode = 0 runs the whole labelling pipeline: it generates the txt files in VOCdevkit/VOC2007/ImageSets as well as the 2007_train.txt and 2007_val.txt used for training
10 | #   annotation_mode = 1 only generates the txt files in VOCdevkit/VOC2007/ImageSets
11 | #   annotation_mode = 2 only generates the 2007_train.txt and 2007_val.txt used for training
12 | # --------------------------------------------------------------------------------------------------------------------------------#
13 | annotation_mode = 0
14 | # -------------------------------------------------------------------#
15 | #   Must be modified: it provides the class names used to write the object information into 2007_train.txt and 2007_val.txt
16 | #   Keep it identical to the classes_path used for training and prediction
17 | #   If the generated 2007_train.txt contains no object information,
18 | #   the classes have not been set correctly
19 | #   Only effective when annotation_mode is 0 or 2
20 | # -------------------------------------------------------------------#
21 | classes_path = 'model_data/waterscenes.txt'
22 | # --------------------------------------------------------------------------------------------------------------------------------#
23 | #   trainval_percent sets the ratio of (train + val) to test; with the value 0.8 below, (train + val) : test = 8:2
24 | #   train_percent sets the ratio of train to val inside (train + val); with the value 0.8 below, train : val = 8:2
25 | #   Only effective when annotation_mode is 0 or 1
26 | # --------------------------------------------------------------------------------------------------------------------------------#
27 | trainval_percent = 0.8
28 | train_percent = 0.8
29 | # -------------------------------------------------------#
30 | #   Path to the folder that holds the VOC-format dataset
31 | #   By default it points to the VOC dataset under the repository root
32 | # -------------------------------------------------------#
33 | VOCdevkit_path = "E:/dataset_collection/WaterScenes/all-1114-voc/all-1114/all/VOCdevkit"
34 |
35 | VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]
36 | classes, _ = get_classes(classes_path)
37 |
38 |
39 | def convert_annotation(year, image_id, list_file):
40 |     # print(year)
41 |     # print(image_id)
42 |     in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml' % (year, image_id)), encoding='utf-8')
43 |     tree = ET.parse(in_file)
44 |     root = tree.getroot()
45 |
46 |     for obj in root.iter('object'):
47 |         difficult = 0
48 |         if obj.find('difficult') is not None:
49 |             difficult = obj.find('difficult').text
50 |         cls = obj.find('name').text
51 |         if cls not in classes or int(difficult) == 1:
52 |             continue
53 |         cls_id = classes.index(cls)
54 |         xmlbox = obj.find('bndbox')
55 |         b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)),
56 |              int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
57 |         list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
58 |
59 |
60 | if __name__ == "__main__":
61 |     random.seed(0)
62 |     if annotation_mode == 0 or annotation_mode == 1:
63 |         print("Generate txt in ImageSets.")
64 |         xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
65 |         saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
66 |         temp_xml = os.listdir(xmlfilepath)
67 |         total_xml = []
68 |         for xml in temp_xml:
69 |             if xml.endswith(".xml"):
70 |                 total_xml.append(xml)
71 |
72 |         num = len(total_xml)
73 |         list = range(num)
74 |         tv = int(num * trainval_percent)
75 |         tr = int(tv * train_percent)
76 |         trainval = random.sample(list, tv)
77 |         train = random.sample(trainval, tr)
78 |
79 |         print("train and val size", tv)
80 |         print("train size", tr)
81 |         ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
82 |         ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
83 |         ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
84 |         fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')
85 |
86 |         for i in list:
87 |             name = total_xml[i][:-4] + '\n'
88 |             if i in trainval:
89 |                 ftrainval.write(name)
90 |                 if i in train:
91 |                     ftrain.write(name)
92 |                 else:
93 |                     fval.write(name)
94 |             else:
95 |                 ftest.write(name)
96 |
97 |         ftrainval.close()
98 |         ftrain.close()
99 |         fval.close()
100 |         ftest.close()
101 |         print("Generate txt in ImageSets done.")
102 |
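    # Illustrative example with hypothetical numbers (not taken from the dataset): if 1000 annotation
    # XML files are found above, tv = int(1000 * 0.8) = 800 and tr = int(800 * 0.8) = 640, so
    # train.txt receives 640 ids, val.txt receives 160, and test.txt receives the remaining 200.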
103 |     if annotation_mode == 0 or annotation_mode == 2:
104 |         print("Generate 2007_train.txt and 2007_val.txt for train.")
105 |         for year, image_set in VOCdevkit_sets:
106 |             image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt' % (year, image_set)),
107 |                              encoding='utf-8').read().strip().split()
108 |             list_file = open('%s_%s.txt' % (year, image_set), 'w', encoding='utf-8')
109 |             for image_id in image_ids:
110 |                 list_file.write('%s/VOC%s/JPEGImages/%s.jpg' % (os.path.abspath(VOCdevkit_path), year, image_id))
111 |
112 |                 try:
113 |                     convert_annotation(year, image_id, list_file)
114 |                 except:  # skip ids whose XML annotation is missing or unreadable
115 |                     continue
116 |                 list_file.write('\n')
117 |             list_file.close()
118 |         print("Generate 2007_train.txt and 2007_val.txt for train done.")
119 |
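# Illustrative note on the lines written above (assumes image paths contain no spaces): each line holds
# an image path followed by space-separated "x1,y1,x2,y2,class_id" groups, so downstream code can
# recover the boxes with, for example:
#     parts = line.strip().split()
#     image_path, boxes = parts[0], [list(map(int, p.split(','))) for p in parts[1:]]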
--------------------------------------------------------------------------------
/voc_annotation_seg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 |
8 | #-------------------------------------------------------#
9 | #   To carve out a test set, lower trainval_percent (with the value 1 below, no images go to the test set)
10 | #   Change train_percent to adjust the train : val ratio; 0.8 below gives train : val = 8:2
11 | #
12 | #   This repository currently uses the validation set as the test set and does not split a separate test set
13 | #-------------------------------------------------------#
14 | trainval_percent = 1
15 | train_percent = 0.8
16 | #-------------------------------------------------------#
17 | #   Path to the folder that holds the VOC-format dataset
18 | #   By default it points to the VOC dataset under the repository root
19 | #-------------------------------------------------------#
20 | VOCdevkit_path = 'E:/Big_Datasets/voc_seg/VOCdevkit/VOCdevkit'
21 |
22 | if __name__ == "__main__":
23 |     random.seed(0)
24 |     print("Generate txt in ImageSets.")
25 |     segfilepath = os.path.join(VOCdevkit_path, 'VOC2007/SegmentationClass')
26 |     saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Segmentation')
27 |
28 |     temp_seg = os.listdir(segfilepath)
29 |     total_seg = []
30 |     for seg in temp_seg:
31 |         if seg.endswith(".png"):
32 |             total_seg.append(seg)
33 |
34 |     num = len(total_seg)
35 |     list = range(num)
36 |     tv = int(num * trainval_percent)
37 |     tr = int(tv * train_percent)
38 |     trainval = random.sample(list, tv)
39 |     train = random.sample(trainval, tr)
40 |
41 |     print("train and val size", tv)
42 |     print("train size", tr)
43 |     ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
44 |     ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
45 |     ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
46 |     fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')
47 |
48 |     for i in list:
49 |         name = total_seg[i][:-4] + '\n'
50 |         if i in trainval:
51 |             ftrainval.write(name)
52 |             if i in train:
53 |                 ftrain.write(name)
54 |             else:
55 |                 fval.write(name)
56 |         else:
57 |             ftest.write(name)
58 |
59 |     ftrainval.close()
60 |     ftrain.close()
61 |     fval.close()
62 |     ftest.close()
63 |     print("Generate txt in ImageSets done.")
64 |
65 |     print("Check datasets format, this may take a while.")
66 |
67 |     classes_nums = np.zeros([256], np.int64)
68 |     for i in tqdm(list):
69 |         name = total_seg[i]
70 |         png_file_name = os.path.join(segfilepath, name)
71 |         if not os.path.exists(png_file_name):
72 |             raise ValueError("Label image %s was not found. Please check that the file exists at this path and that its suffix is png." % (png_file_name))
73 |
74 |         png = np.array(Image.open(png_file_name), np.uint8)
75 |         if len(np.shape(png)) > 2:
76 |             print("The shape of label image %s is %s; it is not a grayscale or 8-bit paletted image, please check the dataset format carefully." % (name, str(np.shape(png))))
77 |             print("Label images must be grayscale or 8-bit paletted images; the value of each pixel is the class index of that pixel.")
78 |
79 |         classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
80 |
81 |     print("Print the pixel values and their counts.")
82 |     print('-' * 37)
83 |     print("| %15s | %15s |" % ("Key", "Value"))
84 |     print('-' * 37)
85 |     for i in range(256):
86 |         if classes_nums[i] > 0:
87 |             print("| %15s | %15s |" % (str(i), str(classes_nums[i])))
88 |     print('-' * 37)
89 |
90 |     if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0:
91 |         print("Detected that the label pixels only contain the values 0 and 255; the data format is wrong.")
92 |         print("For binary segmentation, labels must use pixel value 0 for background and 1 for the target class.")
93 |     elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0:
94 |         print("Detected that the labels contain only background pixels; the data format is wrong, please check the dataset format carefully.")
95 |
--------------------------------------------------------------------------------
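
A small, self-contained illustration of the segmentation metrics used in utils_seg/utils_metrics.py above. This sketch only mirrors the hist/IoU logic with plain NumPy on made-up numbers; it is not the repository's own fast_hist / per_class_iu code:

    import numpy as np

    num_classes = 3
    label = np.array([0, 0, 1, 1, 2, 2])   # flattened ground-truth class per pixel
    pred = np.array([0, 1, 1, 1, 2, 0])    # flattened predicted class per pixel

    # hist[i, j] counts pixels whose ground truth is class i and whose prediction is class j
    valid = (label >= 0) & (label < num_classes)
    hist = np.bincount(num_classes * label[valid] + pred[valid],
                       minlength=num_classes ** 2).reshape(num_classes, num_classes)

    # per-class IoU = TP / (TP + FP + FN) = diagonal / (row sum + column sum - diagonal)
    iou = np.diag(hist) / np.maximum(hist.sum(1) + hist.sum(0) - np.diag(hist), 1)
    print(iou)              # [0.333..., 0.666..., 0.5]
    print(np.nanmean(iou))  # mIoU = 0.5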