├── .idea
│   ├── Efficient-VRNet.iml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── 2007_val.txt
├── README.md
├── backbone
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-39.pyc
│   ├── attention_modules
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── eca.cpython-39.pyc
│   │   │   └── shuffle_attention.cpython-39.pyc
│   │   ├── contextual_attention.py
│   │   ├── eca.py
│   │   ├── mobile_attention.py
│   │   ├── mobile_vit_attention.py
│   │   └── shuffle_attention.py
│   ├── conv_utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   └── normal_conv.cpython-39.pyc
│   │   ├── dcn.py
│   │   ├── ds_conv.py
│   │   ├── dynamic_conv.py
│   │   └── normal_conv.py
│   ├── fusion
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   └── vr_coc.cpython-39.pyc
│   │   ├── context_cluster_down.py
│   │   ├── context_cluster_iterative.py
│   │   ├── context_cluster_nofold.py
│   │   ├── context_cluster_withGAP.py
│   │   └── vr_coc.py
│   ├── radar
│   │   ├── __init__.py
│   │   ├── context_cluster.py
│   │   ├── context_cluster_down.py
│   │   ├── context_cluster_iterative.py
│   │   ├── context_cluster_nofold.py
│   │   ├── context_cluster_withGAP.py
│   │   ├── contextcluster.py
│   │   ├── model_utils.py
│   │   ├── pointmlp.py
│   │   └── pointnet.py
│   └── vision
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-39.pyc
│       │   └── context_cluster.cpython-39.pyc
│       ├── context_cluster.py
│       ├── context_cluster_down.py
│       ├── context_cluster_iterative.py
│       ├── context_cluster_nofold.py
│       └── context_cluster_withGAP.py
├── deeplab.py
├── get_miou.py
├── head
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   └── decouplehead.cpython-39.pyc
│   └── decouplehead.py
├── image_augmentation_test
│   ├── __init__.py
│   ├── dark_channel.py
│   └── sharpen.py
├── model_data
│   ├── heatmap_vision.png
│   ├── voc_classes.txt
│   └── waterscenes.txt
├── neck
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   └── coc_fpn_dual.cpython-39.pyc
│   └── coc_fpn_dual.py
├── nets
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   ├── deeplabv3_training.cpython-39.pyc
│   │   ├── efficient_vrnet.cpython-39.pyc
│   │   └── yolo_training.cpython-39.pyc
│   ├── deeplabv3_training.py
│   ├── efficient_vrnet.py
│   └── yolo_training.py
├── predict.py
├── predict_seg.py
├── requirements.txt
├── train.py
├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   ├── callbacks.cpython-39.pyc
│   │   ├── dataloader.cpython-39.pyc
│   │   ├── multitaskloss.cpython-39.pyc
│   │   ├── utils.cpython-39.pyc
│   │   ├── utils_bbox.cpython-39.pyc
│   │   ├── utils_fit.cpython-39.pyc
│   │   └── utils_map.cpython-39.pyc
│   ├── callbacks.py
│   ├── dataloader.py
│   ├── multitaskloss.py
│   ├── utils.py
│   ├── utils_bbox.py
│   ├── utils_fit.py
│   └── utils_map.py
├── utils_seg
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-39.pyc
│   │   ├── callbacks.cpython-39.pyc
│   │   ├── dataloader.cpython-39.pyc
│   │   ├── utils.cpython-39.pyc
│   │   ├── utils_fit.cpython-39.pyc
│   │   └── utils_metrics.cpython-39.pyc
│   ├── callbacks.py
│   ├── dataloader.py
│   ├── utils.py
│   ├── utils_fit.py
│   └── utils_metrics.py
├── venv
│   ├── Scripts
│   │   ├── Activate.ps1
│   │   ├── activate
│   │   ├── activate.bat
│   │   ├── deactivate.bat
│   │   ├── python.exe
│   │   └── pythonw.exe
│   └── pyvenv.cfg
├── voc_annotation.py
├── voc_annotation_seg.py
└── yolo.py
/README.md:
--------------------------------------------------------------------------------
1 | # ASY-VRNet: Waterway Panoptic Driving Perception Model based on Asymmetric Fair Fusion of Vision and 4D mmWave Radar
2 |
3 | # Devices:
4 | Camera: Sony IMX-317
5 | 
6 | Radar: Oculii Imaging Radar
7 |
8 | # Implementation:
9 |
10 | * Create a conda environment and install dependencies
11 | > git clone https://github.com/GuanRunwei/Efficient-VRNet.git \
12 | > cd Efficient-VRNet \
13 | > conda create -n efficientvrnet python=3.7 \
14 | > conda activate efficientvrnet \
15 | > pip install -r requirements.txt
16 |
17 |
18 | * Prepare the datasets for object detection and semantic segmentation based on images and radar
19 | 
20 | > For object detection, prepare two annotation files, one for training and one for testing. There are two ways to do this (a short parsing sketch follows this list). \
21 | > 1. Write two txt files, one for training and one for testing, both with the same format:
22 | > each line has two parts, an image path followed by the objects:
23 | > **image path**: E:/Big_Datasets/water_surface/all-1114/all/VOCdevkit/VOC2007/JPEGImages/1664091257.87023.jpg
24 | > **object 1** (the first four numbers are the bounding box and the last one is the category): 1131,430,1152,473,0
25 | > **object 2**: 920,425,937,451,0
26 | > A complete line therefore looks like this: E:/Big_Datasets/water_surface/all-1114/all/VOCdevkit/VOC2007/JPEGImages/1664091257.87023.jpg 1131,430,1152,473,0 920,425,937,451,0
27 | > 2. Organize the files in VOC format in one folder like this: \
28 | > VOCdevkit \
29 | > -VOC2007 \
30 | > -- Annotations -> xml annotations in VOC format (you need to put the annotations here) \
31 | > -- ImageSets -> ids (nothing for you to do here) \
32 | > -- JPEGImages -> images (you need to put the images here) \
33 | > then **open** voc_annotation.py and follow the comments in it to build your dataset
34 |
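A minimal sketch (not part of the repo) of how one annotation line in format 1 can be parsed; the path and variable names below are illustrative only:

```python
# One line: "<image path> <x1,y1,x2,y2,class> <x1,y1,x2,y2,class> ..."
line = "VOCdevkit/VOC2007/JPEGImages/1664091257.87023.jpg 1131,430,1152,473,0 920,425,937,451,0"
image_path, *objects = line.split()
boxes = [tuple(map(int, obj.split(","))) for obj in objects]  # each box: (x_min, y_min, x_max, y_max, class_id)
print(image_path, boxes)
```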
35 |
36 | > For semantic segmentation, organize the folders in VOC format. \
37 | > VOCdevkit \
38 | > -VOC2007 \
39 | > -- ImageSets -> ids (nothing for you to do here) \
40 | > -- JPEGImages -> images (you need to put the images here) \
41 | > -- SegmentationClass -> images_seg (you need to put the segmentation masks here)
42 |
43 |
44 | > For the radar files, you need to build a radar map with the same spatial size as the detection images.
45 | Four features are used: range, velocity, elevation and power. First project the 3D point clouds onto the 2D image plane,
46 | then build a 4×512×512 numpy matrix in which each channel holds one feature, and save the matrix in npz format (a sketch follows below).
47 | 
48 | > ***Attention:*** The file names of the detection images, segmentation masks, and radar maps must be identical!
49 | The only difference between them is the file extension.
50 |
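A rough sketch of the radar preprocessing described above (not the repo's actual script; the projection matrix and the helper name are assumptions):

```python
import numpy as np

def make_radar_map(points_xyz, features, proj_matrix, size=512):
    """points_xyz: (N, 3) radar points; features: (N, 4) columns = range, velocity, elevation, power;
    proj_matrix: assumed 3x4 projection matrix from the radar frame to image pixels."""
    radar_map = np.zeros((4, size, size), dtype=np.float32)
    homo = np.concatenate([points_xyz, np.ones((points_xyz.shape[0], 1))], axis=1)  # (N, 4) homogeneous coords
    uvw = (proj_matrix @ homo.T).T                                                  # (N, 3) projected coords
    u = (uvw[:, 0] / uvw[:, 2]).astype(int)
    v = (uvw[:, 1] / uvw[:, 2]).astype(int)
    valid = (uvw[:, 2] > 0) & (u >= 0) & (u < size) & (v >= 0) & (v < size)
    for c in range(4):                                   # one channel per radar feature
        radar_map[c, v[valid], u[valid]] = features[valid, c]
    return radar_map

# the .npz name must match the corresponding image name, e.g.
# np.savez("1664091257.87023.npz", radar_map=make_radar_map(points_xyz, features, proj_matrix))
```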
51 | * Train
52 |
53 | > After completing the steps above, open **train.py**, change the file path variables and hyperparameters, and run it.
54 |
55 |
56 | * Visualization
57 |
58 | > **predict.py** tests the object detection \
59 | > **yolo.py** defines the model for object detection \
60 | 
61 | > **predict_seg.py** tests the semantic segmentation \
62 | > **deeplab.py** defines the model for semantic segmentation \
63 | 
64 | > open these files and follow the comments in them for the details
65 |
66 | If you have any questions, please open an Issue.
67 |
68 |
--------------------------------------------------------------------------------
/backbone/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/__init__.py
--------------------------------------------------------------------------------
/backbone/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/backbone/attention_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/attention_modules/__init__.py
--------------------------------------------------------------------------------
/backbone/attention_modules/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/attention_modules/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/backbone/attention_modules/__pycache__/eca.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/attention_modules/__pycache__/eca.cpython-39.pyc
--------------------------------------------------------------------------------
/backbone/attention_modules/__pycache__/shuffle_attention.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/attention_modules/__pycache__/shuffle_attention.cpython-39.pyc
--------------------------------------------------------------------------------
/backbone/attention_modules/contextual_attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import flatten, nn
4 | from torch.nn import init
5 | from torch.nn.modules.activation import ReLU
6 | from torch.nn.modules.batchnorm import BatchNorm2d
7 | from torch.nn import functional as F
8 |
9 |
10 | class ContextAttention(nn.Module):
11 |
12 | def __init__(self, dim=512, kernel_size=3):
13 | super().__init__()
14 | self.dim = dim
15 | self.kernel_size = kernel_size
16 |
17 | self.key_embed = nn.Sequential(
18 | nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=4, bias=False),
19 | nn.BatchNorm2d(dim),
20 | nn.ReLU()
21 | )
22 | self.value_embed=nn.Sequential(
23 | nn.Conv2d(dim, dim, 1, bias=False),
24 | nn.BatchNorm2d(dim)
25 | )
26 |
27 | factor = 4
28 | self.attention_embed = nn.Sequential(
29 | nn.Conv2d(2*dim, 2*dim//factor, 1, bias=False),
30 | nn.BatchNorm2d(2*dim//factor),
31 | nn.ReLU(),
32 | nn.Conv2d(2*dim//factor, kernel_size*kernel_size*dim, 1)
33 | )
34 |
35 |
36 | def forward(self, x):
37 | bs, c, h, w = x.shape
38 | k1 = self.key_embed(x) # bs,c,h,w
39 |         v = self.value_embed(x).view(bs, c, -1)  # bs,c,h*w
40 |
41 | y = torch.cat([k1, x], dim=1) # bs,2c,h,w
42 | att = self.attention_embed(y) # bs, c*k*k,h,w
43 | att = att.reshape(bs, c, self.kernel_size*self.kernel_size, h, w)
44 | att = att.mean(2, keepdim=False).view(bs, c, -1) # bs,c,h*w
45 | k2 = F.softmax(att, dim=-1)*v
46 | k2 = k2.view(bs, c, h, w)
47 | return k1+k2
48 |
49 |
50 | if __name__ == '__main__':
51 | input = torch.randn(50, 512, 7, 7)
52 | cot = ContextAttention(dim=512, kernel_size=3)
53 | output = cot(input)
54 | print(output.shape)
55 |
--------------------------------------------------------------------------------
/backbone/attention_modules/eca.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | class eca_block(nn.Module):
7 | def __init__(self, channel, b=1, gamma=2):
8 | super(eca_block, self).__init__()
9 | kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
10 | kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
11 |
12 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
13 | self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
14 | self.sigmoid = nn.Sigmoid()
15 |
16 | def forward(self, x):
17 | output = self.avg_pool(x)
18 | # print("average pool shape:", output.shape)
19 | output = self.conv(output.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
20 | output = self.sigmoid(output)
21 |
22 | return x * output.expand_as(x)
23 |
24 |
25 |
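# --- Added usage sketch (not part of the original file): a minimal smoke test, mirroring the
# --- __main__ blocks of the sibling attention modules; it assumes only the eca_block class above.
if __name__ == '__main__':
    feat = torch.randn(2, 64, 32, 32)
    eca = eca_block(channel=64)
    print(eca(feat).shape)  # expected: torch.Size([2, 64, 32, 32])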
--------------------------------------------------------------------------------
/backbone/attention_modules/mobile_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def position(H, W, is_cuda=True):
7 | if is_cuda:
8 | loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)
9 | loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)
10 | else:
11 | loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)
12 | loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)
13 | loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)
14 | return loc
15 |
16 |
17 | def stride(x, stride):
18 | b, c, h, w = x.shape
19 | return x[:, :, ::stride, ::stride]
20 |
21 |
22 | def init_rate_half(tensor):
23 | if tensor is not None:
24 | tensor.data.fill_(0.5)
25 |
26 |
27 | def init_rate_0(tensor):
28 | if tensor is not None:
29 | tensor.data.fill_(0.)
30 |
31 |
32 | class MobileAttention(nn.Module):
33 | def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):
34 | super(MobileAttention, self).__init__()
35 | self.in_planes = in_planes
36 | self.out_planes = out_planes
37 | self.head = head
38 | self.kernel_att = kernel_att
39 | self.kernel_conv = kernel_conv
40 | self.stride = stride
41 | self.dilation = dilation
42 | self.rate1 = torch.nn.Parameter(torch.Tensor(1))
43 | self.rate2 = torch.nn.Parameter(torch.Tensor(1))
44 | self.head_dim = self.out_planes // self.head
45 |
46 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
47 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
48 | self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
49 | self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)
50 |
51 | self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
52 | self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)
53 | self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)
54 | self.softmax = torch.nn.Softmax(dim=1)
55 |
56 | self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)
57 | self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,
58 | kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1,
59 | stride=stride)
60 |
61 | self.reset_parameters()
62 |
63 | def reset_parameters(self):
64 | init_rate_half(self.rate1)
65 | init_rate_half(self.rate2)
66 | kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)
67 | for i in range(self.kernel_conv * self.kernel_conv):
68 | kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.
69 | kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)
70 | self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)
71 | self.dep_conv.bias = init_rate_0(self.dep_conv.bias)
72 |
73 | def forward(self, x):
74 | q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
75 | scaling = float(self.head_dim) ** -0.5
76 | b, c, h, w = q.shape
77 | h_out, w_out = h // self.stride, w // self.stride
78 |
79 | pe = self.conv_p(position(h, w, x.is_cuda))
80 |
81 | q_att = q.view(b * self.head, self.head_dim, h, w) * scaling
82 | k_att = k.view(b * self.head, self.head_dim, h, w)
83 | v_att = v.view(b * self.head, self.head_dim, h, w)
84 |
85 | if self.stride > 1:
86 | q_att = stride(q_att, self.stride)
87 | q_pe = stride(pe, self.stride)
88 | else:
89 | q_pe = pe
90 |
91 | unfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim,
92 | self.kernel_att * self.kernel_att, h_out,
93 | w_out) # b*head, head_dim, k_att^2, h_out, w_out
94 | unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out,
95 | w_out) # 1, head_dim, k_att^2, h_out, w_out
96 |
97 | att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(
98 | 1) # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)
99 | att = self.softmax(att)
100 |
101 | out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att,
102 | h_out, w_out)
103 | out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)
104 |
105 | f_all = self.fc(torch.cat(
106 | [q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w),
107 | v.view(b, self.head, self.head_dim, h * w)], 1))
108 | f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])
109 |
110 | out_conv = self.dep_conv(f_conv)
111 |
112 | return self.rate1 * out_att + self.rate2 * out_conv
--------------------------------------------------------------------------------
/backbone/attention_modules/mobile_vit_attention.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from einops import rearrange
4 |
5 |
6 | class PreNorm(nn.Module):
7 | def __init__(self, dim, fn):
8 | super().__init__()
9 | self.ln = nn.LayerNorm(dim)
10 | self.fn = fn
11 |
12 | def forward(self, x, **kwargs):
13 | return self.fn(self.ln(x), **kwargs)
14 |
15 |
16 | class FeedForward(nn.Module):
17 | def __init__(self, dim, mlp_dim, dropout):
18 | super().__init__()
19 | self.net = nn.Sequential(
20 | nn.Linear(dim, mlp_dim),
21 | nn.SiLU(),
22 | nn.Dropout(dropout),
23 | nn.Linear(mlp_dim, dim),
24 | nn.Dropout(dropout)
25 | )
26 |
27 | def forward(self, x):
28 | return self.net(x)
29 |
30 |
31 | class Attention(nn.Module):
32 | def __init__(self, dim, heads, head_dim, dropout):
33 | super().__init__()
34 | inner_dim = heads * head_dim
35 | project_out = not (heads == 1 and head_dim == dim)
36 |
37 | self.heads = heads
38 | self.scale = head_dim ** -0.5
39 |
40 | self.attend = nn.Softmax(dim=-1)
41 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
42 |
43 | self.to_out = nn.Sequential(
44 | nn.Linear(inner_dim, dim),
45 | nn.Dropout(dropout)
46 | ) if project_out else nn.Identity()
47 |
48 | def forward(self, x):
49 | qkv = self.to_qkv(x).chunk(3, dim=-1)
50 | q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)
51 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
52 | attn = self.attend(dots)
53 | out = torch.matmul(attn, v)
54 | out = rearrange(out, 'b p h n d -> b p n (h d)')
55 | return self.to_out(out)
56 |
57 |
58 | class Transformer(nn.Module):
59 | def __init__(self, dim, depth, heads, head_dim, mlp_dim, dropout=0.):
60 | super().__init__()
61 | self.layers = nn.ModuleList([])
62 | for _ in range(depth):
63 | self.layers.append(nn.ModuleList([
64 | PreNorm(dim, Attention(dim, heads, head_dim, dropout)),
65 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
66 | ]))
67 |
68 | def forward(self, x):
69 | out = x
70 | for att, ffn in self.layers:
71 | out = out + att(out)
72 | out = out + ffn(out)
73 | return out
74 |
75 |
76 | class MobileViTAttention(nn.Module):
77 | def __init__(self, in_channel=3, dim=512, kernel_size=3, patch_size=4):
78 | super().__init__()
79 | self.ph, self.pw = patch_size, patch_size
80 | self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, padding=kernel_size // 2)
81 | self.conv2 = nn.Conv2d(in_channel, dim, kernel_size=1)
82 |
83 | self.trans = Transformer(dim=dim, depth=3, heads=8, head_dim=64, mlp_dim=1024)
84 |
85 | self.conv3 = nn.Conv2d(dim, in_channel, kernel_size=1)
86 | self.conv4 = nn.Conv2d(2 * in_channel, in_channel, kernel_size=kernel_size, padding=kernel_size // 2)
87 |
88 | def forward(self, x):
89 | y = x.clone() # bs,c,h,w
90 |
91 | ## Local Representation
92 | y = self.conv2(self.conv1(x)) # bs,dim,h,w
93 |
94 | ## Global Representation
95 | _, _, h, w = y.shape
96 |         y = rearrange(y, 'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim', ph=self.ph, pw=self.pw)  # bs,(ph*pw),(nh*nw),dim
97 | y = self.trans(y)
98 | y = rearrange(y, 'bs (ph pw) (nh nw) dim -> bs dim (nh ph) (nw pw)', ph=self.ph, pw=self.pw, nh=h // self.ph,
99 | nw=w // self.pw) # bs,dim,h,w
100 |
101 | ## Fusion
102 | y = self.conv3(y) # bs,dim,h,w
103 | y = torch.cat([x, y], 1) # bs,2*dim,h,w
104 | y = self.conv4(y) # bs,c,h,w
105 |
106 | return y
107 |
108 |
109 | if __name__ == '__main__':
110 | m = MobileViTAttention(in_channel=96)
111 | input = torch.randn(16, 96, 80, 80)
112 | output = m(input)
113 | print(output.shape)
--------------------------------------------------------------------------------
/backbone/attention_modules/shuffle_attention.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.nn import init
5 | from torch.nn.parameter import Parameter
6 |
7 |
8 | class ShuffleAttention(nn.Module):
9 |
10 | def __init__(self, channel=512,reduction=16,G=8):
11 | super().__init__()
12 | self.G=G
13 | self.channel=channel
14 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
15 | self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
16 | self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
17 | self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
18 | self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
19 | self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
20 | self.sigmoid=nn.Sigmoid()
21 |
22 | def init_weights(self):
23 | for m in self.modules():
24 | if isinstance(m, nn.Conv2d):
25 | init.kaiming_normal_(m.weight, mode='fan_out')
26 | if m.bias is not None:
27 | init.constant_(m.bias, 0)
28 | elif isinstance(m, nn.BatchNorm2d):
29 | init.constant_(m.weight, 1)
30 | init.constant_(m.bias, 0)
31 | elif isinstance(m, nn.Linear):
32 | init.normal_(m.weight, std=0.001)
33 | if m.bias is not None:
34 | init.constant_(m.bias, 0)
35 |
36 |
37 | @staticmethod
38 | def channel_shuffle(x, groups):
39 | b, c, h, w = x.shape
40 | x = x.reshape(b, groups, -1, h, w)
41 | x = x.permute(0, 2, 1, 3, 4)
42 |
43 | # flatten
44 | x = x.reshape(b, -1, h, w)
45 |
46 | return x
47 |
48 | def forward(self, x):
49 | b, c, h, w = x.size()
50 | #group into subfeatures
51 | x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w
52 |
53 | #channel_split
54 | x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w
55 |
56 | #channel attention
57 | x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1
58 | x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1
59 | x_channel=x_0*self.sigmoid(x_channel)
60 |
61 | #spatial attention
62 | x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w
63 | x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w
64 | x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w
65 |
66 | # concatenate along channel axis
67 | out=torch.cat([x_channel,x_spatial],dim=1) #bs*G,c//G,h,w
68 | out=out.contiguous().view(b,-1,h,w)
69 |
70 | # channel shuffle
71 | out = self.channel_shuffle(out, 2)
72 | return out
73 |
74 |
75 | if __name__ == '__main__':
76 | input=torch.randn(50,32,224,224)
77 | se = ShuffleAttention(channel=32,G=4)
78 | output=se(input)
79 | print(output.shape)
--------------------------------------------------------------------------------
/backbone/conv_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/conv_utils/__init__.py
--------------------------------------------------------------------------------
/backbone/conv_utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/conv_utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/backbone/conv_utils/__pycache__/normal_conv.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/conv_utils/__pycache__/normal_conv.cpython-39.pyc
--------------------------------------------------------------------------------
/backbone/conv_utils/dcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.ops
3 | from torch import nn
4 |
5 |
6 | class DeformableConv2d(nn.Module):
7 | def __init__(self,
8 | in_channels,
9 | out_channels,
10 | kernel_size=3,
11 | stride=1,
12 | padding=1,
13 | bias=False):
14 | super(DeformableConv2d, self).__init__()
15 |
16 | assert type(kernel_size) == tuple or type(kernel_size) == int
17 |
18 | kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
19 | self.stride = stride if type(stride) == tuple else (stride, stride)
20 | self.padding = padding
21 |
22 | self.offset_conv = nn.Conv2d(in_channels,
23 | 2 * kernel_size[0] * kernel_size[1],
24 | kernel_size=kernel_size,
25 | stride=stride,
26 | padding=self.padding,
27 | bias=True)
28 |
29 | nn.init.constant_(self.offset_conv.weight, 0.)
30 | nn.init.constant_(self.offset_conv.bias, 0.)
31 |
32 | self.modulator_conv = nn.Conv2d(in_channels,
33 | 1 * kernel_size[0] * kernel_size[1],
34 | kernel_size=kernel_size,
35 | stride=stride,
36 | padding=self.padding,
37 | bias=True)
38 |
39 | nn.init.constant_(self.modulator_conv.weight, 0.)
40 | nn.init.constant_(self.modulator_conv.bias, 0.)
41 |
42 | self.regular_conv = nn.Conv2d(in_channels=in_channels,
43 | out_channels=out_channels,
44 | kernel_size=kernel_size,
45 | stride=stride,
46 | padding=self.padding,
47 | bias=bias)
48 |
49 | def forward(self, x):
50 | # h, w = x.shape[2:]
51 | # max_offset = max(h, w)/4.
52 |
53 | offset = self.offset_conv(x) # .clamp(-max_offset, max_offset)
54 | modulator = 2. * torch.sigmoid(self.modulator_conv(x))
55 |
56 | x = torchvision.ops.deform_conv2d(input=x,
57 | offset=offset,
58 | weight=self.regular_conv.weight,
59 | bias=self.regular_conv.bias,
60 | padding=self.padding,
61 | mask=modulator,
62 | stride=self.stride,
63 | )
64 | return x
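# --- Added usage sketch (not part of the original file): a minimal smoke test for DeformableConv2d;
# --- it assumes only the class defined above and a random input tensor.
if __name__ == '__main__':
    x = torch.randn(2, 16, 64, 64)
    dcn = DeformableConv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
    print(dcn(x).shape)  # expected: torch.Size([2, 32, 64, 64])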
--------------------------------------------------------------------------------
/backbone/conv_utils/ds_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.autograd
5 |
6 |
7 | class BN_Conv2d(nn.Module):
8 | """
9 |     Conv2d + BatchNorm2d block; the default activation here is ReLU
10 | """
11 |
12 | def __init__(self, in_channels: object, out_channels: object, kernel_size: object, stride: object, padding: object,
13 | dilation=1, groups=1, bias=False, activation=True) -> object:
14 | super(BN_Conv2d, self).__init__()
15 | layers = [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
16 | padding=padding, dilation=dilation, groups=groups, bias=bias),
17 | nn.BatchNorm2d(out_channels)]
18 | if activation:
19 | layers.append(nn.ReLU(inplace=True))
20 | self.seq = nn.Sequential(*layers)
21 |
22 | def forward(self, x):
23 | return self.seq(x)
24 |
25 |
26 | class DWConv(nn.Module):
27 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, activation=True):
28 | super().__init__()
29 |         self.dconv = BN_Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
30 |                                padding=kernel_size // 2, groups=in_channels, activation=activation)  # padding added: required arg was missing; 'same' padding assumed
31 |         self.pconv = BN_Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=1, activation=activation)
32 |
33 | def forward(self, x):
34 | x = self.dconv(x)
35 | return self.pconv(x)
--------------------------------------------------------------------------------
/backbone/conv_utils/dynamic_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.autograd
5 |
6 |
7 | class Attention(nn.Module):
8 | def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):
9 | super(Attention, self).__init__()
10 | attention_channel = max(int(in_planes * reduction), min_channel)
11 | self.kernel_size = kernel_size
12 | self.kernel_num = kernel_num
13 | self.temperature = 1.0
14 |
15 | self.avgpool = nn.AdaptiveAvgPool2d(1)
16 | self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False)
17 | self.bn = nn.BatchNorm2d(attention_channel)
18 | self.relu = nn.ReLU(inplace=True)
19 |
20 | self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)
21 | self.func_channel = self.get_channel_attention
22 |
23 | if in_planes == groups and in_planes == out_planes: # depth-wise convolution
24 | self.func_filter = self.skip
25 | else:
26 | self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)
27 | self.func_filter = self.get_filter_attention
28 |
29 | if kernel_size == 1: # point-wise convolution
30 | self.func_spatial = self.skip
31 | else:
32 | self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)
33 | self.func_spatial = self.get_spatial_attention
34 |
35 | if kernel_num == 1:
36 | self.func_kernel = self.skip
37 | else:
38 | self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)
39 | self.func_kernel = self.get_kernel_attention
40 |
41 | self._initialize_weights()
42 |
43 | def _initialize_weights(self):
44 | for m in self.modules():
45 | if isinstance(m, nn.Conv2d):
46 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
47 | if m.bias is not None:
48 | nn.init.constant_(m.bias, 0)
49 | if isinstance(m, nn.BatchNorm2d):
50 | nn.init.constant_(m.weight, 1)
51 | nn.init.constant_(m.bias, 0)
52 |
53 | def update_temperature(self, temperature):
54 | self.temperature = temperature
55 |
56 | @staticmethod
57 | def skip(_):
58 | return 1.0
59 |
60 | def get_channel_attention(self, x):
61 | channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
62 | return channel_attention
63 |
64 | def get_filter_attention(self, x):
65 | filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
66 | return filter_attention
67 |
68 | def get_spatial_attention(self, x):
69 | spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)
70 | spatial_attention = torch.sigmoid(spatial_attention / self.temperature)
71 | return spatial_attention
72 |
73 | def get_kernel_attention(self, x):
74 | kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)
75 | kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)
76 | return kernel_attention
77 |
78 | def forward(self, x):
79 | x = self.avgpool(x)
80 | x = self.fc(x)
81 | x = self.bn(x)
82 | x = self.relu(x)
83 | return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)
84 |
85 |
86 | class DynamicConv(nn.Module):
87 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
88 | reduction=0.0625, kernel_num=4):
89 | super(DynamicConv, self).__init__()
90 | self.in_planes = in_planes
91 | self.out_planes = out_planes
92 | self.kernel_size = kernel_size
93 | self.stride = stride
94 | self.padding = padding
95 | self.dilation = dilation
96 | self.groups = groups
97 | self.kernel_num = kernel_num
98 | self.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,
99 | reduction=reduction, kernel_num=kernel_num)
100 | self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size),
101 | requires_grad=True)
102 | self._initialize_weights()
103 |
104 | if self.kernel_size == 1 and self.kernel_num == 1:
105 | self._forward_impl = self._forward_impl_pw1x
106 | else:
107 | self._forward_impl = self._forward_impl_common
108 |
109 | def _initialize_weights(self):
110 | for i in range(self.kernel_num):
111 | nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')
112 |
113 | def update_temperature(self, temperature):
114 | self.attention.update_temperature(temperature)
115 |
116 | def _forward_impl_common(self, x):
117 |         # Multiplying the channel attention (or filter attention) into the weights or into the feature maps is equivalent,
118 |         # but we observe that the latter runs faster and uses less GPU memory.
119 | channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
120 | batch_size, in_planes, height, width = x.size()
121 | x = x * channel_attention
122 | x = x.reshape(1, -1, height, width)
123 | aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)
124 | aggregate_weight = torch.sum(aggregate_weight, dim=1).view(
125 | [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])
126 | output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
127 | dilation=self.dilation, groups=self.groups * batch_size)
128 | output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
129 | output = output * filter_attention
130 | return output
131 |
132 | def _forward_impl_pw1x(self, x):
133 | channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
134 | x = x * channel_attention
135 | output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,
136 | dilation=self.dilation, groups=self.groups)
137 | output = output * filter_attention
138 | return output
139 |
140 | def forward(self, x):
141 | return self._forward_impl(x)
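# --- Added usage sketch (not part of the original file): a minimal smoke test for DynamicConv;
# --- it assumes only the class defined above and a random input tensor.
if __name__ == '__main__':
    x = torch.randn(2, 16, 32, 32)
    odconv = DynamicConv(in_planes=16, out_planes=32, kernel_size=3, padding=1)
    print(odconv(x).shape)  # expected: torch.Size([2, 32, 32, 32])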
--------------------------------------------------------------------------------
/backbone/conv_utils/normal_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SiLU(nn.Module):
6 | @staticmethod
7 | def forward(x):
8 | return x * torch.sigmoid(x)
9 |
10 |
11 | def get_activation(name="silu", inplace=True):
12 | if name == "silu":
13 | module = SiLU()
14 | elif name == "relu":
15 | module = nn.ReLU(inplace=inplace)
16 | elif name == "lrelu":
17 | module = nn.LeakyReLU(0.1, inplace=inplace)
18 | else:
19 | raise AttributeError("Unsupported act type: {}".format(name))
20 | return module
21 |
22 |
23 | class DWConv(nn.Module):
24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
25 | super().__init__()
26 | self.dconv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
27 | stride=stride, groups=in_channels, padding=padding, dilation=dilation, bias=bias)
28 | self.pconv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, groups=1,
29 | bias=bias)
30 |
31 | def forward(self, x):
32 | x = self.dconv(x)
33 | return self.pconv(x)
34 |
35 |
36 | class BaseConv(nn.Module):
37 | def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="relu", ds_conv=False):
38 | super().__init__()
39 | pad = (ksize - 1) // 2
40 | if ds_conv is False:
41 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad,
42 | groups=groups, bias=bias)
43 | else:
44 | self.conv = DWConv(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, bias=bias)
45 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
46 | self.act = get_activation(act, inplace=True)
47 |
48 | def forward(self, x):
49 | return self.act(self.bn(self.conv(x)))
50 |
51 | def fuseforward(self, x):
52 | return self.act(self.conv(x))
53 |
--------------------------------------------------------------------------------
/backbone/fusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/fusion/__init__.py
--------------------------------------------------------------------------------
/backbone/fusion/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/fusion/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/backbone/fusion/__pycache__/vr_coc.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/fusion/__pycache__/vr_coc.cpython-39.pyc
--------------------------------------------------------------------------------
/backbone/radar/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/radar/__init__.py
--------------------------------------------------------------------------------
/backbone/radar/contextcluster.py:
--------------------------------------------------------------------------------
1 | """
2 | We add our Context Cluster on top of PointMLP to validate its effectiveness and applicability to point cloud data.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | # from torch import einsum
8 | # from einops import rearrange, repeat
9 | from einops import rearrange
10 | import torch.nn.functional as F
11 |
12 |
13 | try:
14 | from model_utils import get_activation, square_distance, index_points, query_ball_point, knn_point
15 | except:
16 | from .model_utils import get_activation, square_distance, index_points, query_ball_point, knn_point
17 |
18 | try:
19 | from pointnet2_ops.pointnet2_utils import furthest_point_sample as furthest_point_sample
20 | except:
21 | print("==> not using CUDA FPS\n")
22 | try:
23 | from model_utils import farthest_point_sample as furthest_point_sample
24 | except:
25 | from .model_utils import farthest_point_sample as furthest_point_sample
26 |
27 |
28 | class LocalGrouper(nn.Module):
29 | def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
30 | """
31 | Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
32 | :param groups: groups number
33 |         :param kneighbors: k-neighbors
34 | :param kwargs: others
35 | """
36 | super(LocalGrouper, self).__init__()
37 | self.groups = groups
38 | self.kneighbors = kneighbors
39 | self.use_xyz = use_xyz
40 | if normalize is not None:
41 | self.normalize = normalize.lower()
42 | else:
43 | self.normalize = None
44 | if self.normalize not in ["center", "anchor"]:
45 | print(f"Unrecognized normalize parameter (self.normalize), set to None. Should be one of [center, anchor].")
46 | self.normalize = None
47 | if self.normalize is not None:
48 | add_channel=3 if self.use_xyz else 0
49 | self.affine_alpha = nn.Parameter(torch.ones([1,1,1,channel + add_channel]))
50 | self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))
51 |
52 | def forward(self, xyz, points):
53 | B, N, C = xyz.shape
54 | S = self.groups
55 |         xyz = xyz.contiguous()  # xyz [batch, points, xyz]
56 |
57 | # fps_idx = torch.multinomial(torch.linspace(0, N - 1, steps=N).repeat(B, 1).to(xyz.device), num_samples=self.groups, replacement=False).long()
58 | # fps_idx = farthest_point_sample(xyz, self.groups).long()
59 | fps_idx = furthest_point_sample(xyz, self.groups).long() # [B, npoint]
60 | new_xyz = index_points(xyz, fps_idx) # [B, npoint, 3]
61 | new_points = index_points(points, fps_idx) # [B, npoint, d]
62 |
63 | idx = knn_point(self.kneighbors, xyz, new_xyz)
64 | # idx = query_ball_point(radius, nsample, xyz, new_xyz)
65 | grouped_xyz = index_points(xyz, idx) # [B, npoint, k, 3]
66 | grouped_points = index_points(points, idx) # [B, npoint, k, d]
67 | if self.use_xyz:
68 | grouped_points = torch.cat([grouped_points, grouped_xyz],dim=-1) # [B, npoint, k, d+3]
69 | if self.normalize is not None:
70 | if self.normalize =="center":
71 | mean = torch.mean(grouped_points, dim=2, keepdim=True)
72 | if self.normalize =="anchor":
73 | mean = torch.cat([new_points, new_xyz],dim=-1) if self.use_xyz else new_points
74 | mean = mean.unsqueeze(dim=-2) # [B, npoint, 1, d+3]
75 | std = torch.std((grouped_points-mean).reshape(B,-1),dim=-1,keepdim=True).unsqueeze(dim=-1).unsqueeze(dim=-1)
76 | grouped_points = (grouped_points-mean)/(std + 1e-5)
77 | grouped_points = self.affine_alpha*grouped_points + self.affine_beta
78 |
79 | new_points = torch.cat([grouped_points, new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)], dim=-1)
80 | return new_xyz, new_points
81 |
82 |
83 | class ConvBNReLU1D(nn.Module):
84 | def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
85 | super(ConvBNReLU1D, self).__init__()
86 | self.act = get_activation(activation)
87 | self.net = nn.Sequential(
88 | nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
89 | nn.BatchNorm1d(out_channels),
90 | self.act
91 | )
92 |
93 | def forward(self, x):
94 | return self.net(x)
95 |
96 |
97 | class ConvBNReLURes1D(nn.Module):
98 | def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
99 | super(ConvBNReLURes1D, self).__init__()
100 | self.act = get_activation(activation)
101 | self.net1 = nn.Sequential(
102 | nn.Conv1d(in_channels=channel, out_channels=int(channel * res_expansion),
103 | kernel_size=kernel_size, groups=groups, bias=bias),
104 | nn.BatchNorm1d(int(channel * res_expansion)),
105 | self.act
106 | )
107 | if groups > 1:
108 | self.net2 = nn.Sequential(
109 | nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
110 | kernel_size=kernel_size, groups=groups, bias=bias),
111 | nn.BatchNorm1d(channel),
112 | self.act,
113 | nn.Conv1d(in_channels=channel, out_channels=channel,
114 | kernel_size=kernel_size, bias=bias),
115 | nn.BatchNorm1d(channel),
116 | )
117 | else:
118 | self.net2 = nn.Sequential(
119 | nn.Conv1d(in_channels=int(channel * res_expansion), out_channels=channel,
120 | kernel_size=kernel_size, bias=bias),
121 | nn.BatchNorm1d(channel)
122 | )
123 |
124 | def forward(self, x):
125 | return self.act(self.net2(self.net1(x)) + x)
126 |
127 |
128 | def pairwise_cos_sim(x1: torch.Tensor, x2:torch.Tensor):
129 | """
130 | return pair-wise similarity matrix between two tensors
131 | :param x1: [B,...,M,D]
132 | :param x2: [B,...,N,D]
133 | :return: similarity matrix [B,...,M,N]
134 | """
135 | x1 = F.normalize(x1,dim=-1)
136 | x2 = F.normalize(x2,dim=-1)
137 |
138 | sim = torch.matmul(x1, x2.transpose(-2, -1))
139 | return sim
140 |
141 |
142 | class ContextCluster(nn.Module):
143 | def __init__(self, dim, heads=4, head_dim=24):
144 | super(ContextCluster, self).__init__()
145 | self.heads = heads
146 | self.head_dim=head_dim
147 | self.fc1 = nn.Linear(dim, heads*head_dim)
148 | self.fc2 = nn.Linear(heads*head_dim, dim)
149 | self.fc_v = nn.Linear(dim, heads*head_dim)
150 | self.sim_alpha = nn.Parameter(torch.ones(1))
151 | self.sim_beta = nn.Parameter(torch.zeros(1))
152 |
153 | def forward(self, x): #[b,d,k]
154 | res = x
155 | x = rearrange(x, "b d k -> b k d")
156 | value = self.fc_v(x) # [b,k,head*head_d]
157 | x = self.fc1(x) # [b,k,head*head_d]
158 | x = rearrange(x, "b k (h d) -> (b h) k d", h=self.heads) # [b,k,d]
159 | value = rearrange(value, "b k (h d) -> (b h) k d", h=self.heads) # [b,k,d]
160 | center = x.mean(dim=1, keepdim=True) # [b,1,d]
161 | value_center = value.mean(dim=1, keepdim=True) # [b,1,d]
162 | sim = torch.sigmoid(self.sim_beta + self.sim_alpha * pairwise_cos_sim(center, x) )#[B,1,k]
163 | # out [b, 1, d]
164 | out = ( (value.unsqueeze(dim=1)*sim.unsqueeze(dim=-1) ).sum(dim=2) + value_center)/ (sim.sum(dim=-1,keepdim=True)+ 1.0) # [B,M,D]
165 | out = out*(sim.squeeze(dim=1).unsqueeze(dim=-1)) # [b,k,d]
166 | out = rearrange(out, "(b h) k d -> b k (h d)", h=self.heads) # [b,k,d]
167 | out = self.fc2(out)
168 | out = rearrange(out, "b k d -> b d k")
169 | return res + out
170 |
171 |
172 | class PreExtraction(nn.Module):
173 | def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True,
174 | activation='relu', use_xyz=True):
175 | """
176 | input: [b,g,k,d]: output:[b,d,g]
177 | :param channels:
178 | :param blocks:
179 | """
180 | super(PreExtraction, self).__init__()
181 | in_channels = 3+2*channels if use_xyz else 2*channels
182 | self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
183 | operation = []
184 | for _ in range(blocks):
185 | operation.append(
186 | ContextCluster(out_channels, heads=1,head_dim=max(out_channels//4, 32))
187 | )
188 | operation.append(
189 | # add context cluster here.
190 | ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion,
191 | bias=bias, activation=activation)
192 | )
193 | self.operation = nn.Sequential(*operation)
194 |
195 | def forward(self, x):
196 | b, n, s, d = x.size() # torch.Size([32, 512, 32, 6])
197 | x = x.permute(0, 1, 3, 2)
198 | x = x.reshape(-1, d, s)
199 | x = self.transfer(x)
200 | batch_size, _, _ = x.size()
201 | x = self.operation(x) # [b, d, k]
202 | x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
203 | x = x.reshape(b, n, -1).permute(0, 2, 1)
204 | return x
205 |
206 |
207 | class PosExtraction(nn.Module):
208 | def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'):
209 | """
210 | input[b,d,g]; output[b,d,g]
211 | :param channels:
212 | :param blocks:
213 | """
214 | super(PosExtraction, self).__init__()
215 | operation = []
216 | for _ in range(blocks):
217 | operation.append(
218 | ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion, bias=bias, activation=activation)
219 | )
220 | self.operation = nn.Sequential(*operation)
221 |
222 | def forward(self, x): # [b, d, g]
223 | return self.operation(x)
224 |
225 |
226 |
227 |
228 | class Model(nn.Module):
229 | def __init__(self, points=1024, class_num=40, embed_dim=64, groups=1, res_expansion=1.0,
230 | activation="relu", bias=True, use_xyz=True, normalize="center",
231 | dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
232 | k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs):
233 | super(Model, self).__init__()
234 | self.stages = len(pre_blocks)
235 | self.class_num = class_num
236 | self.points = points
237 | self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
238 | assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
239 | "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
240 | self.local_grouper_list = nn.ModuleList()
241 | self.pre_blocks_list = nn.ModuleList()
242 | self.pos_blocks_list = nn.ModuleList()
243 | last_channel = embed_dim
244 | anchor_points = self.points
245 | for i in range(len(pre_blocks)):
246 | out_channel = last_channel * dim_expansion[i]
247 | pre_block_num = pre_blocks[i]
248 | pos_block_num = pos_blocks[i]
249 | kneighbor = k_neighbors[i]
250 | reduce = reducers[i]
251 | anchor_points = anchor_points // reduce
252 | # append local_grouper_list
253 | local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize) # [b,g,k,d]
254 | self.local_grouper_list.append(local_grouper)
255 | # append pre_block_list
256 | pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups,
257 | res_expansion=res_expansion,
258 | bias=bias, activation=activation, use_xyz=use_xyz)
259 | self.pre_blocks_list.append(pre_block_module)
260 | # append pos_block_list
261 | pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups,
262 | res_expansion=res_expansion, bias=bias, activation=activation)
263 | self.pos_blocks_list.append(pos_block_module)
264 |
265 | last_channel = out_channel
266 |
267 | self.act = get_activation(activation)
268 | self.classifier = nn.Sequential(
269 | nn.Linear(last_channel, 512),
270 | nn.BatchNorm1d(512),
271 | self.act,
272 | nn.Dropout(0.5),
273 | nn.Linear(512, 256),
274 | nn.BatchNorm1d(256),
275 | self.act,
276 | nn.Dropout(0.5),
277 | nn.Linear(256, self.class_num)
278 | )
279 |
280 | def forward(self, x):
281 | xyz = x.permute(0, 2, 1)
282 | batch_size, _, _ = x.size()
283 | x = self.embedding(x) # B,D,N
284 | for i in range(self.stages):
285 | # Give xyz[b, p, 3] and fea[b, p, d], return new_xyz[b, g, 3] and new_fea[b, g, k, d]
286 | xyz, x = self.local_grouper_list[i](xyz, x.permute(0, 2, 1)) # [b,g,3] [b,g,k,d]
287 | x = self.pre_blocks_list[i](x) # [b,d,g]
288 | x = self.pos_blocks_list[i](x) # [b,d,g]
289 |
290 | x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)
291 | x = self.classifier(x)
292 | return x
293 |
294 |
295 | def pointMLP_CoC(num_classes=40, **kwargs) -> Model:
296 | return Model(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=1,
297 | activation="relu", bias=False, use_xyz=False, normalize="anchor",
298 | dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
299 | k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs)
300 |
301 |
302 | if __name__ == '__main__':
303 | data = torch.rand(2, 3, 1024)
304 | print("===> testing pointMLP ...")
305 |     model = pointMLP_CoC()
306 | out = model(data)
307 | print(out.shape)
308 | parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
309 | print(f"Paramter number is: {parameters}")
310 |
311 |
--------------------------------------------------------------------------------
/backbone/radar/model_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Model utils for our Point Cloud Efficient Transformer for TPAMI.
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | # from torch import einsum
8 | # from einops import rearrange, repeat
9 |
10 |
11 | def get_activation(activation):
12 | if activation.lower() == 'gelu':
13 | return nn.GELU()
14 | elif activation.lower() == 'rrelu':
15 | return nn.RReLU(inplace=True)
16 | elif activation.lower() == 'selu':
17 | return nn.SELU(inplace=True)
18 | elif activation.lower() == 'silu':
19 | return nn.SiLU(inplace=True)
20 | elif activation.lower() == 'hardswish':
21 | return nn.Hardswish(inplace=True)
22 | elif activation.lower() == 'leakyrelu':
23 | return nn.LeakyReLU(inplace=True)
24 | elif activation.lower() == 'leakyrelu0.2':
25 | return nn.LeakyReLU(0.2, inplace=True)
26 | else:
27 | return nn.ReLU(inplace=True)
28 |
29 |
30 | def square_distance(src, dst):
31 | """
32 |     Calculate the squared Euclidean distance between each pair of points.
33 | src^T * dst = xn * xm + yn * ym + zn * zm;
34 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
35 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
36 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
37 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
38 | Input:
39 | src: source points, [B, N, C]
40 | dst: target points, [B, M, C]
41 | Output:
42 | dist: per-point square distance, [B, N, M]
43 | """
44 | B, N, _ = src.shape
45 | _, M, _ = dst.shape
46 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
47 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
48 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
49 | return dist
50 |
51 |
52 | def index_points(points, idx):
53 | """
54 | Input:
55 | points: input points data, [B, N, C]
56 | idx: sample index data, [B, S]
57 | Return:
58 | new_points:, indexed points data, [B, S, C]
59 | """
60 | device = points.device
61 | B = points.shape[0]
62 | view_shape = list(idx.shape)
63 | view_shape[1:] = [1] * (len(view_shape) - 1)
64 | repeat_shape = list(idx.shape)
65 | repeat_shape[0] = 1
66 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
67 | new_points = points[batch_indices, idx, :]
68 | return new_points
69 |
70 |
71 | def farthest_point_sample(xyz, npoint):
72 | """
73 | Input:
74 | xyz: pointcloud data, [B, N, 3]
75 | npoint: number of samples
76 | Return:
77 | centroids: sampled pointcloud index, [B, npoint]
78 | """
79 | device = xyz.device
80 | B, N, C = xyz.shape
81 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
82 | distance = torch.ones(B, N).to(device) * 1e10
83 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
84 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
85 | for i in range(npoint):
86 | centroids[:, i] = farthest
87 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
88 | dist = torch.sum((xyz - centroid) ** 2, -1)
89 | distance = torch.min(distance, dist)
90 | farthest = torch.max(distance, -1)[1]
91 | return centroids
92 |
93 |
94 | def query_ball_point(radius, nsample, xyz, new_xyz):
95 | """
96 | Input:
97 | radius: local region radius
98 | nsample: max sample number in local region
99 | xyz: all points, [B, N, 3]
100 | new_xyz: query points, [B, S, 3]
101 | Return:
102 | group_idx: grouped points index, [B, S, nsample]
103 | """
104 | device = xyz.device
105 | B, N, C = xyz.shape
106 | _, S, _ = new_xyz.shape
107 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
108 | sqrdists = square_distance(new_xyz, xyz)
109 | group_idx[sqrdists > radius ** 2] = N
110 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
111 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
112 | mask = group_idx == N
113 | group_idx[mask] = group_first[mask]
114 | return group_idx
115 |
116 |
117 | def knn_point(nsample, xyz, new_xyz):
118 | """
119 | Input:
120 | nsample: max sample number in local region
121 | xyz: all points, [B, N, C]
122 | new_xyz: query points, [B, S, C]
123 | Return:
124 | group_idx: grouped points index, [B, S, nsample]
125 | """
126 | sqrdists = square_distance(new_xyz, xyz)
127 | _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
128 | return group_idx
129 |
130 |
131 |
--------------------------------------------------------------------------------
/backbone/radar/pointnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.parallel
5 | import torch.utils.data
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 |
10 |
11 | class STN3d(nn.Module):
12 | def __init__(self):
13 | super(STN3d, self).__init__()
14 | self.conv1 = torch.nn.Conv1d(3, 64, 1)
15 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
16 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
17 | self.fc1 = nn.Linear(1024, 512)
18 | self.fc2 = nn.Linear(512, 256)
19 | self.fc3 = nn.Linear(256, 9)
20 | self.relu = nn.ReLU()
21 |
22 | self.bn1 = nn.BatchNorm1d(64)
23 | self.bn2 = nn.BatchNorm1d(128)
24 | self.bn3 = nn.BatchNorm1d(1024)
25 | self.bn4 = nn.BatchNorm1d(512)
26 | self.bn5 = nn.BatchNorm1d(256)
27 |
28 |
29 | def forward(self, x):
30 | batchsize = x.size()[0]
31 | x = F.relu(self.bn1(self.conv1(x)))
32 | x = F.relu(self.bn2(self.conv2(x)))
33 | x = F.relu(self.bn3(self.conv3(x)))
34 | x = torch.max(x, 2, keepdim=True)[0]
35 | x = x.view(-1, 1024)
36 |
37 | x = F.relu(self.bn4(self.fc1(x)))
38 | x = F.relu(self.bn5(self.fc2(x)))
39 | x = self.fc3(x)
40 |
41 | iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
42 | if x.is_cuda:
43 | iden = iden.cuda()
44 | x = x + iden
45 | x = x.view(-1, 3, 3)
46 | return x
47 |
48 |
49 | class STNkd(nn.Module):
50 | def __init__(self, k=64):
51 | super(STNkd, self).__init__()
52 | self.conv1 = torch.nn.Conv1d(k, 64, 1)
53 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
54 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
55 | self.fc1 = nn.Linear(1024, 512)
56 | self.fc2 = nn.Linear(512, 256)
57 | self.fc3 = nn.Linear(256, k*k)
58 | self.relu = nn.ReLU()
59 |
60 | self.bn1 = nn.BatchNorm1d(64)
61 | self.bn2 = nn.BatchNorm1d(128)
62 | self.bn3 = nn.BatchNorm1d(1024)
63 | self.bn4 = nn.BatchNorm1d(512)
64 | self.bn5 = nn.BatchNorm1d(256)
65 |
66 | self.k = k
67 |
68 | def forward(self, x):
69 | batchsize = x.size()[0]
70 | x = F.relu(self.bn1(self.conv1(x)))
71 | x = F.relu(self.bn2(self.conv2(x)))
72 | x = F.relu(self.bn3(self.conv3(x)))
73 | x = torch.max(x, 2, keepdim=True)[0]
74 | x = x.view(-1, 1024)
75 |
76 | x = F.relu(self.bn4(self.fc1(x)))
77 | x = F.relu(self.bn5(self.fc2(x)))
78 | x = self.fc3(x)
79 |
80 | iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
81 | if x.is_cuda:
82 | iden = iden.cuda()
83 | x = x + iden
84 | x = x.view(-1, self.k, self.k)
85 | return x
86 |
87 | class PointNetfeat(nn.Module):
88 | def __init__(self, global_feat = True, feature_transform = False):
89 | super(PointNetfeat, self).__init__()
90 | self.stn = STN3d()
91 | self.conv1 = torch.nn.Conv1d(3, 64, 1)
92 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
93 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
94 | self.bn1 = nn.BatchNorm1d(64)
95 | self.bn2 = nn.BatchNorm1d(128)
96 | self.bn3 = nn.BatchNorm1d(1024)
97 | self.global_feat = global_feat
98 | self.feature_transform = feature_transform
99 | if self.feature_transform:
100 | self.fstn = STNkd(k=64)
101 |
102 | def forward(self, x):
103 | n_pts = x.size()[2]
104 | trans = self.stn(x)
105 | x = x.transpose(2, 1)
106 | x = torch.bmm(x, trans)
107 | x = x.transpose(2, 1)
108 | x = F.relu(self.bn1(self.conv1(x)))
109 |
110 | if self.feature_transform:
111 | trans_feat = self.fstn(x)
112 | x = x.transpose(2,1)
113 | x = torch.bmm(x, trans_feat)
114 | x = x.transpose(2,1)
115 | else:
116 | trans_feat = None
117 |
118 | pointfeat = x
119 | x = F.relu(self.bn2(self.conv2(x)))
120 | x = self.bn3(self.conv3(x))
121 | x = torch.max(x, 2, keepdim=True)[0]
122 | x = x.view(-1, 1024)
123 | if self.global_feat:
124 | return x, trans, trans_feat
125 | else:
126 | x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
127 | return torch.cat([x, pointfeat], 1), trans, trans_feat
128 |
129 | class PointNetCls(nn.Module):
130 | def __init__(self, k=2, feature_transform=False):
131 | super(PointNetCls, self).__init__()
132 | self.feature_transform = feature_transform
133 | self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
134 | self.fc1 = nn.Linear(1024, 512)
135 | self.fc2 = nn.Linear(512, 256)
136 | self.fc3 = nn.Linear(256, k)
137 | self.dropout = nn.Dropout(p=0.3)
138 | self.bn1 = nn.BatchNorm1d(512)
139 | self.bn2 = nn.BatchNorm1d(256)
140 | self.relu = nn.ReLU()
141 |
142 | def forward(self, x):
143 | x, trans, trans_feat = self.feat(x)
144 | x = F.relu(self.bn1(self.fc1(x)))
145 | x = F.relu(self.bn2(self.dropout(self.fc2(x))))
146 | x = self.fc3(x)
147 | return F.log_softmax(x, dim=1), trans, trans_feat
148 |
149 |
150 | class PointNetDenseCls(nn.Module):
151 | def __init__(self, k = 2, feature_transform=False):
152 | super(PointNetDenseCls, self).__init__()
153 | self.k = k
154 | self.feature_transform=feature_transform
155 | self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
156 | self.conv1 = torch.nn.Conv1d(1088, 512, 1)
157 | self.conv2 = torch.nn.Conv1d(512, 256, 1)
158 | self.conv3 = torch.nn.Conv1d(256, 128, 1)
159 | self.conv4 = torch.nn.Conv1d(128, self.k, 1)
160 | self.bn1 = nn.BatchNorm1d(512)
161 | self.bn2 = nn.BatchNorm1d(256)
162 | self.bn3 = nn.BatchNorm1d(128)
163 |
164 | def forward(self, x):
165 | batchsize = x.size()[0]
166 | n_pts = x.size()[2]
167 | x, trans, trans_feat = self.feat(x)
168 | x = F.relu(self.bn1(self.conv1(x)))
169 | x = F.relu(self.bn2(self.conv2(x)))
170 | x = F.relu(self.bn3(self.conv3(x)))
171 | x = self.conv4(x)
172 | x = x.transpose(2,1).contiguous()
173 | x = F.log_softmax(x.view(-1,self.k), dim=-1)
174 | x = x.view(batchsize, n_pts, self.k)
175 | return x, trans, trans_feat
176 |
177 | def feature_transform_regularizer(trans):
178 | d = trans.size()[1]
179 | batchsize = trans.size()[0]
180 | I = torch.eye(d)[None, :, :]
181 | if trans.is_cuda:
182 | I = I.cuda()
183 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
184 | return loss
185 |
186 | if __name__ == '__main__':
187 | sim_data = Variable(torch.rand(32,3,2500))
188 | trans = STN3d()
189 | out = trans(sim_data)
190 | print('stn', out.size())
191 | print('loss', feature_transform_regularizer(out))
192 |
193 | sim_data_64d = Variable(torch.rand(32, 64, 2500))
194 | trans = STNkd(k=64)
195 | out = trans(sim_data_64d)
196 | print('stn64d', out.size())
197 | print('loss', feature_transform_regularizer(out))
198 |
199 | pointfeat = PointNetfeat(global_feat=True)
200 | out, _, _ = pointfeat(sim_data)
201 | print('global feat', out.size())
202 |
203 | pointfeat = PointNetfeat(global_feat=False)
204 | out, _, _ = pointfeat(sim_data)
205 | print('point feat', out.size())
206 |
207 | cls = PointNetCls(k = 5)
208 | out, _, _ = cls(sim_data)
209 | print('class', out.size())
210 |
211 | seg = PointNetDenseCls(k = 3)
212 | out, _, _ = seg(sim_data)
213 | print('seg', out.size())
--------------------------------------------------------------------------------
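
Beyond the shape checks in the demo above, the feature-transform regularizer is normally added to the classification loss during training. A minimal sketch on a random batch, using the conventional 0.001 weight (an assumption, not fixed by this file):

import torch
import torch.nn.functional as F
from backbone.radar.pointnet import PointNetCls, feature_transform_regularizer

model = PointNetCls(k=5, feature_transform=True)
points = torch.rand(8, 3, 2500)            # B=8 clouds, 3 coords, 2500 points
target = torch.randint(0, 5, (8,))

pred, trans, trans_feat = model(points)    # pred holds log-probabilities
loss = F.nll_loss(pred, target)
if trans_feat is not None:
    loss = loss + 0.001 * feature_transform_regularizer(trans_feat)
loss.backward()
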
/backbone/vision/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/vision/__init__.py
--------------------------------------------------------------------------------
/backbone/vision/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/vision/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/backbone/vision/__pycache__/context_cluster.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/backbone/vision/__pycache__/context_cluster.cpython-39.pyc
--------------------------------------------------------------------------------
/get_miou.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from PIL import Image
4 | from tqdm import tqdm
5 |
6 | from deeplab import DeeplabV3
7 | from utils_seg.utils_metrics import compute_mIoU, show_results
8 |
9 | '''
10 | Notes on metric evaluation:
11 | 1. The prediction images this file writes are grayscale; the values are small, so they look nearly black when viewed as PNGs, which is expected.
12 | 2. This file computes mIoU on the validation set; the repo currently uses the test set as the validation set and does not keep a separate test split.
13 | '''
14 | if __name__ == "__main__":
15 | #---------------------------------------------------------------------------#
16 | # miou_mode selects what this script computes when run
17 | # miou_mode = 0: run the whole pipeline, i.e. generate predictions and then compute mIoU.
18 | # miou_mode = 1: only generate the prediction results.
19 | # miou_mode = 2: only compute mIoU.
20 | #---------------------------------------------------------------------------#
21 | miou_mode = 0
22 | #------------------------------#
23 | # number of classes + 1, e.g. 2 + 1
24 | #------------------------------#
25 | num_classes = 21
26 | #--------------------------------------------#
27 | # classes to distinguish, same as in json_to_dataset
28 | #--------------------------------------------#
29 | name_classes = ["background","aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
30 | # name_classes = ["_background_","cat","dog"]
31 | #-------------------------------------------------------#
32 | # Path to the folder containing the VOC dataset
33 | # Defaults to the VOC dataset under the repository root
34 | #-------------------------------------------------------#
35 | VOCdevkit_path = 'E:/Big_Datasets/voc_seg/VOCdevkit/VOCdevkit'
36 |
37 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),'r').read().splitlines()
38 | gt_dir = os.path.join(VOCdevkit_path, "VOC2007/SegmentationClass/")
39 | miou_out_path = "miou_out"
40 | pred_dir = os.path.join(miou_out_path, 'detection-results')
41 |
42 | if miou_mode == 0 or miou_mode == 1:
43 | if not os.path.exists(pred_dir):
44 | os.makedirs(pred_dir)
45 |
46 | print("Load model.")
47 | deeplab = DeeplabV3()
48 | print("Load model done.")
49 |
50 | print("Get predict result.")
51 | for image_id in tqdm(image_ids):
52 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg")
53 | image = Image.open(image_path)
54 | image = deeplab.get_miou_png(image)
55 | image.save(os.path.join(pred_dir, image_id + ".png"))
56 | print("Get predict result done.")
57 |
58 | if miou_mode == 0 or miou_mode == 2:
59 | print("Get miou.")
60 | hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes)  # run the mIoU computation
61 | print("Get miou done.")
62 | show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)
--------------------------------------------------------------------------------
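
The script ships with the 21-class VOC defaults shown above. If it were pointed at the WaterScenes segmentation labels used elsewhere in this repo (see predict_seg.py), the settings would presumably change along these lines; the path and the exact class ordering are assumptions:

num_classes = 9   # assumed: 8 foreground classes + background, matching num_seg_classes=9
name_classes = ["background", "free-space", "pier", "vessel", "ship",
                "boat", "buoy", "sailor", "kayak"]
VOCdevkit_path = "path/to/WaterScenes_VOCdevkit"   # placeholder path
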
/head/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/head/__init__.py
--------------------------------------------------------------------------------
/head/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/head/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/head/__pycache__/decouplehead.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/head/__pycache__/decouplehead.cpython-39.pyc
--------------------------------------------------------------------------------
/head/decouplehead.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from backbone.conv_utils.normal_conv import BaseConv
5 |
6 |
7 | class DecoupleHead(nn.Module):
8 | def __init__(self, num_classes, width=1.0, in_channels=[128, 320, 512], act="relu", depthwise=False):
9 | super().__init__()
10 | Conv = BaseConv
11 |
12 | self.cls_convs = nn.ModuleList()
13 | self.reg_convs = nn.ModuleList()
14 | self.cls_preds = nn.ModuleList()
15 | self.reg_preds = nn.ModuleList()
16 | self.obj_preds = nn.ModuleList()
17 | self.stems = nn.ModuleList()
18 |
19 | for i in range(len(in_channels)):
20 | self.stems.append(
21 | Conv(in_channels=int(in_channels[i] * width), out_channels=int(256 * width), ksize=1, stride=1,
22 | act=act))
23 | self.cls_convs.append(nn.Sequential(*[
24 | Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act, ds_conv=True),
25 | Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act, ds_conv=True),
26 | ]))
27 | self.cls_preds.append(
28 | nn.Conv2d(in_channels=int(256 * width), out_channels=num_classes, kernel_size=1, stride=1, padding=0)
29 | )
30 |
31 | self.reg_convs.append(nn.Sequential(*[
32 | Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act, ds_conv=True),
33 | Conv(in_channels=int(256 * width), out_channels=int(256 * width), ksize=3, stride=1, act=act, ds_conv=True)
34 | ]))
35 | self.reg_preds.append(
36 | nn.Conv2d(in_channels=int(256 * width), out_channels=4, kernel_size=1, stride=1, padding=0)
37 | )
38 | self.obj_preds.append(
39 | nn.Conv2d(in_channels=int(256 * width), out_channels=1, kernel_size=1, stride=1, padding=0)
40 | )
41 |
42 | def forward(self, inputs):
43 | # ---------------------------------------------------#
44 | # inputs
45 | # P3_out 80, 80, 256
46 | # P4_out 40, 40, 512
47 | # P5_out 20, 20, 1024
48 | # ---------------------------------------------------#
49 | outputs = []
50 | for k, x in enumerate(inputs):
51 | # ---------------------------------------------------#
52 | # channel integration with a 1x1 convolution
53 | # ---------------------------------------------------#
54 | x = self.stems[k](x)
55 | # ---------------------------------------------------#
56 | # feature extraction with two conv-BN-activation blocks
57 | # ---------------------------------------------------#
58 | cls_feat = self.cls_convs[k](x)
59 | # ---------------------------------------------------#
60 | # predict the class of each feature point
61 | # 80, 80, num_classes
62 | # 40, 40, num_classes
63 | # 20, 20, num_classes
64 | # ---------------------------------------------------#
65 | cls_output = self.cls_preds[k](cls_feat)
66 |
67 | # ---------------------------------------------------#
68 | # feature extraction with two conv-BN-activation blocks
69 | # ---------------------------------------------------#
70 | reg_feat = self.reg_convs[k](x)
71 | # ---------------------------------------------------#
72 | # regression coefficients of each feature point
73 | # reg_pred 80, 80, 4
74 | # reg_pred 40, 40, 4
75 | # reg_pred 20, 20, 4
76 | # ---------------------------------------------------#
77 | reg_output = self.reg_preds[k](reg_feat)
78 | # ---------------------------------------------------#
79 | # predict whether each feature point contains an object
80 | # obj_pred 80, 80, 1
81 | # obj_pred 40, 40, 1
82 | # obj_pred 20, 20, 1
83 | # ---------------------------------------------------#
84 | obj_output = self.obj_preds[k](reg_feat)
85 |
86 | output = torch.cat([reg_output, obj_output, cls_output], 1)
87 | outputs.append(output)
88 | return outputs
--------------------------------------------------------------------------------
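
A minimal smoke test of the head on dummy FPN features, assuming the repo root is on PYTHONPATH and that BaseConv accepts the ds_conv flag the way it is used here; the spatial sizes are only illustrative:

import torch
from head.decouplehead import DecoupleHead

head = DecoupleHead(num_classes=4, width=1.0, in_channels=[128, 320, 512])
feats = [
    torch.rand(1, 128, 64, 64),   # P3-level feature
    torch.rand(1, 320, 32, 32),   # P4-level feature
    torch.rand(1, 512, 16, 16),   # P5-level feature
]
outs = head(feats)
for o in outs:
    print(o.shape)   # [1, 4 + 1 + num_classes, H, W] per level
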
/image_augmentation_test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/image_augmentation_test/__init__.py
--------------------------------------------------------------------------------
/image_augmentation_test/dark_channel.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import cv2
3 | import math
4 | import numpy as np
5 |
6 |
7 | def DarkChannel(im, sz):
8 | b, g, r = cv2.split(im)
9 | dc = cv2.min(cv2.min(r, g), b)
10 | kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (sz, sz))
11 | dark = cv2.erode(dc, kernel)
12 | return dark
13 |
14 |
15 | def AtmLight(im, dark):
16 | [h, w] = im.shape[:2]
17 | imsz = h * w
18 | numpx = int(max(math.floor(imsz / 1000), 1))
19 | darkvec = dark.reshape(imsz, 1)
20 | imvec = im.reshape(imsz, 3)
21 |
22 | indices = darkvec.argsort()
23 | indices = indices[imsz - numpx::]
24 |
25 | atmsum = np.zeros([1, 3])
26 | for ind in range(1, numpx):
27 | atmsum = atmsum + imvec[indices[ind]]
28 |
29 | A = atmsum / numpx
30 | return A
31 |
32 |
33 | def TransmissionEstimate(im, A, sz):
34 | omega = 0.95
35 | im3 = np.empty(im.shape, im.dtype)
36 |
37 | for ind in range(0, 3):
38 | im3[:, :, ind] = im[:, :, ind] / A[0, ind]
39 |
40 | transmission = 1 - omega * DarkChannel(im3, sz)
41 | return transmission
42 |
43 |
44 | def Guidedfilter(im, p, r, eps):
45 | mean_I = cv2.boxFilter(im, cv2.CV_64F, (r, r))
46 | mean_p = cv2.boxFilter(p, cv2.CV_64F, (r, r))
47 | mean_Ip = cv2.boxFilter(im * p, cv2.CV_64F, (r, r))
48 | cov_Ip = mean_Ip - mean_I * mean_p
49 |
50 | mean_II = cv2.boxFilter(im * im, cv2.CV_64F, (r, r))
51 | var_I = mean_II - mean_I * mean_I
52 |
53 | a = cov_Ip / (var_I + eps)
54 | b = mean_p - a * mean_I
55 |
56 | mean_a = cv2.boxFilter(a, cv2.CV_64F, (r, r))
57 | mean_b = cv2.boxFilter(b, cv2.CV_64F, (r, r))
58 |
59 | q = mean_a * im + mean_b
60 | return q
61 |
62 |
63 | def TransmissionRefine(im, et):
64 | gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
65 | gray = np.float64(gray) / 255
66 | r = 60
67 | eps = 0.0001
68 | t = Guidedfilter(gray, et, r, eps)
69 |
70 | return t
71 |
72 |
73 | def Recover(im, t, A, tx=0.1):
74 | res = np.empty(im.shape, im.dtype)
75 | t = cv2.max(t, tx)
76 |
77 | for ind in range(0, 3):
78 | res[:, :, ind] = (im[:, :, ind] - A[0, ind]) / t + A[0, ind]
79 |
80 | return res
81 |
82 |
83 | if __name__ == '__main__':
84 | fn = '../images/fog_image.png'
85 | src = cv2.imread(fn)
86 | I = src.astype('float64') / 255
87 |
88 | dark = DarkChannel(I, 15)
89 | A = AtmLight(I, dark)
90 | te = TransmissionEstimate(I, A, 15)
91 | t = TransmissionRefine(src, te)
92 | J = Recover(I, t, A, 0.1)
93 |
94 | arr = np.hstack((I, J))
95 | cv2.imshow("contrast", arr)
96 | cv2.imwrite("../images/car-02-dehaze.png", J * 255)
97 | cv2.imwrite("../images/car-02-contrast.png", arr * 255)
98 | cv2.waitKey()
--------------------------------------------------------------------------------
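
The __main__ block above runs the dehazing steps one by one; the same pipeline can be folded into a single helper. A minimal sketch, where the patch size of 15 and tx of 0.1 simply mirror the values used above:

import cv2
from image_augmentation_test.dark_channel import (
    DarkChannel, AtmLight, TransmissionEstimate, TransmissionRefine, Recover)

def dehaze(path, patch=15, tx=0.1):
    """Dark-channel-prior dehazing for a single image file."""
    src = cv2.imread(path)
    I = src.astype('float64') / 255
    dark = DarkChannel(I, patch)
    A = AtmLight(I, dark)
    te = TransmissionEstimate(I, A, patch)
    t = TransmissionRefine(src, te)
    return Recover(I, t, A, tx)   # float image in roughly [0, 1]
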
/image_augmentation_test/sharpen.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import math
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | # Linear stretching
8 | # Drop the top and bottom 0.5% of pixel values, then linearly stretch to [0, 1]
9 | def stretchImage(data, s=0.005, bins=2000):
10 | ht = np.histogram(data, bins)
11 | d = np.cumsum(ht[0]) / float(data.size)
12 | lmin = 0
13 | lmax = bins - 1
14 | while lmin < bins:
15 | if d[lmin] >= s:
16 | break
17 | lmin += 1
18 | while lmax >= 0:
19 | if d[lmax] <= 1 - s:
20 | break
21 | lmax -= 1
22 | return np.clip((data - ht[1][lmin]) / (ht[1][lmax] - ht[1][lmin]), 0, 1)
23 |
24 |
25 | # Compute the weight matrix for a given radius
26 | g_para = {}
27 |
28 |
29 | def getPara(radius=5):
30 | global g_para
31 | m = g_para.get(radius, None)
32 | if m is not None:
33 | return m
34 | size = radius * 2 + 1
35 | m = np.zeros((size, size))
36 | for h in range(-radius, radius + 1):
37 | for w in range(-radius, radius + 1):
38 | if h == 0 and w == 0:
39 | continue
40 | m[radius + h, radius + w] = 1.0 / math.sqrt(h ** 2 + w ** 2)
41 | m /= m.sum()
42 | g_para[radius] = m
43 | return m
44 |
45 |
46 | # Standard ACE (automatic color equalization) implementation
47 | def zmIce(I, ratio=4, radius=300):
48 | para = getPara(radius)
49 | height, width = I.shape
50 | zh = []
51 | zw = []
52 | n = 0
53 | while n < radius:
54 | zh.append(0)
55 | zw.append(0)
56 | n += 1
57 | for n in range(height):
58 | zh.append(n)
59 | for n in range(width):
60 | zw.append(n)
61 | n = 0
62 | while n < radius:
63 | zh.append(height - 1)
64 | zw.append(width - 1)
65 | n += 1
66 | # print(zh)
67 | # print(zw)
68 |
69 | Z = I[np.ix_(zh, zw)]
70 | res = np.zeros(I.shape)
71 | for h in range(radius * 2 + 1):
72 | for w in range(radius * 2 + 1):
73 | if para[h][w] == 0:
74 | continue
75 | res += (para[h][w] * np.clip((I - Z[h:h + height, w:w + width]) * ratio, -1, 1))
76 | return res
77 |
78 |
79 | # Fast single-channel ACE enhancement
80 | def zmIceFast(I, ratio, radius):
81 | # print(I)  # leftover debug output, disabled
82 | height, width = I.shape[:2]
83 | if min(height, width) <= 2:
84 | return np.zeros(I.shape) + 0.5
85 | Rs = cv2.resize(I, (int((width + 1) / 2), int((height + 1) / 2)))
86 | Rf = zmIceFast(Rs, ratio, radius)  # recursive call on the downscaled image
87 | Rf = cv2.resize(Rf, (width, height))
88 | Rs = cv2.resize(Rs, (width, height))
89 |
90 | return Rf + zmIce(I, ratio, radius) - zmIce(Rs, ratio, radius)
91 |
92 |
93 | # Enhance the three RGB channels separately; ratio is the contrast gain, radius is the kernel radius
94 | def zmIceColor(I, ratio=4, radius=3):
95 | res = np.zeros(I.shape)
96 | for k in range(3):
97 | res[:, :, k] = stretchImage(zmIceFast(I[:, :, k], ratio, radius))
98 | return res
99 |
100 |
101 | # Main entry point
102 | if __name__ == '__main__':
103 | img = cv2.imread('../images/rain_image.png')
104 | res = zmIceColor(img / 255.0) * 255
105 | cv2.imwrite('../images/ship-remove-rain.jpg', res)
--------------------------------------------------------------------------------
/model_data/heatmap_vision.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/model_data/heatmap_vision.png
--------------------------------------------------------------------------------
/model_data/voc_classes.txt:
--------------------------------------------------------------------------------
1 | aeroplane
2 | bicycle
3 | bird
4 | boat
5 | bottle
6 | bus
7 | car
8 | cat
9 | chair
10 | cow
11 | diningtable
12 | dog
13 | horse
14 | motorbike
15 | person
16 | pottedplant
17 | sheep
18 | sofa
19 | train
20 | tvmonitor
--------------------------------------------------------------------------------
/model_data/waterscenes.txt:
--------------------------------------------------------------------------------
1 | pier
2 | vessel
3 | ship
4 | boat
--------------------------------------------------------------------------------
/neck/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/neck/__init__.py
--------------------------------------------------------------------------------
/neck/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/neck/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/neck/__pycache__/coc_fpn_dual.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/neck/__pycache__/coc_fpn_dual.cpython-39.pyc
--------------------------------------------------------------------------------
/neck/coc_fpn_dual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from thop import profile
5 | from thop import clever_format
6 | import math
7 | from backbone.attention_modules.eca import eca_block
8 | from backbone.attention_modules.shuffle_attention import ShuffleAttention
9 | from backbone.conv_utils.normal_conv import DWConv, BaseConv
10 | from backbone.vision.context_cluster import ClusterBlock
11 | from backbone.fusion.vr_coc import coc_medium, coc_small
12 | from torchinfo import summary
13 |
14 |
15 | class CoCUpsample(nn.Module):
16 | def __init__(self, in_channels, out_channels, scale=2, ds_conv=False):
17 | super().__init__()
18 |
19 | self.upsample = nn.Sequential(
20 | BaseConv(in_channels, out_channels, 1, 1, act='relu', ds_conv=ds_conv),
21 | nn.Upsample(scale_factor=scale, mode='bilinear', align_corners=True)
22 | )
23 |
24 | def forward(self, x,):
25 | x = self.upsample(x)
26 | return x
27 |
28 |
29 | class CoC_Conv(nn.Module):
30 | def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="relu", ds_conv=False):
31 | super(CoC_Conv, self).__init__()
32 |
33 | self.coc = ClusterBlock(dim=in_channels)
34 | self.conv_att = BaseConv(in_channels, out_channels, ksize=ksize, stride=stride, act=act, ds_conv=ds_conv)
35 |
36 | def forward(self, x):
37 | x = self.coc(x)
38 | x = self.conv_att(x)
39 | return x
40 |
41 |
42 | # -----------------------------------------#
43 | # ASPP feature extraction module
44 | # Extracts features with dilated convolutions at different dilation rates
45 | # -----------------------------------------#
46 | class ASPP(nn.Module):
47 | def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
48 | super(ASPP, self).__init__()
49 | self.branch1 = nn.Sequential(
50 | nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
51 | nn.BatchNorm2d(dim_out, momentum=bn_mom),
52 | nn.ReLU(inplace=True),
53 | )
54 | self.branch2 = nn.Sequential(
55 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
56 | nn.BatchNorm2d(dim_out, momentum=bn_mom),
57 | nn.ReLU(inplace=True),
58 | )
59 | self.branch3 = nn.Sequential(
60 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
61 | nn.BatchNorm2d(dim_out, momentum=bn_mom),
62 | nn.ReLU(inplace=True),
63 | )
64 | self.branch4 = nn.Sequential(
65 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
66 | nn.BatchNorm2d(dim_out, momentum=bn_mom),
67 | nn.ReLU(inplace=True),
68 | )
69 | self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
70 | self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
71 | self.branch5_relu = nn.ReLU(inplace=True)
72 |
73 | self.conv_cat = nn.Sequential(
74 | nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
75 | nn.BatchNorm2d(dim_out, momentum=bn_mom),
76 | nn.ReLU(inplace=True),
77 | )
78 |
79 | def forward(self, x):
80 | [b, c, row, col] = x.size()
81 | # -----------------------------------------#
82 | # five branches in total
83 | # -----------------------------------------#
84 | conv1x1 = self.branch1(x)
85 | conv3x3_1 = self.branch2(x)
86 | conv3x3_2 = self.branch3(x)
87 | conv3x3_3 = self.branch4(x)
88 | # -----------------------------------------#
89 | # fifth branch: global average pooling + convolution
90 | # -----------------------------------------#
91 | global_feature = torch.mean(x, 2, True)
92 | global_feature = torch.mean(global_feature, 3, True)
93 | global_feature = self.branch5_conv(global_feature)
94 | global_feature = self.branch5_bn(global_feature)
95 | global_feature = self.branch5_relu(global_feature)
96 | global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
97 |
98 | # -----------------------------------------#
99 | # Concatenate the five branches,
100 | # then fuse the features with a 1x1 convolution.
101 | # -----------------------------------------#
102 | feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
103 | result = self.conv_cat(feature_cat)
104 | return result
105 |
106 |
107 | class SpatialPyramidPooling(nn.Module):
108 | def __init__(self, pool_sizes=[5, 9, 13]):
109 | super(SpatialPyramidPooling, self).__init__()
110 |
111 | self.maxpools = nn.ModuleList([nn.MaxPool2d(pool_size, 1, pool_size//2) for pool_size in pool_sizes])
112 |
113 | def forward(self, x):
114 | features = [maxpool(x) for maxpool in self.maxpools[::-1]]
115 | features = torch.cat(features + [x], dim=1)
116 |
117 | return features
118 |
119 |
120 | def shuffle_channels(x, groups=2):
121 | """Channel Shuffle"""
122 |
123 | batch_size, channels, h, w = x.data.size()
124 | if channels % groups:
125 | return x
126 | channels_per_group = channels // groups
127 | x = x.view(batch_size, groups, channels_per_group, h, w)
128 | x = torch.transpose(x, 1, 2).contiguous()
129 | x = x.view(batch_size, -1, h, w)
130 | return x
131 |
132 |
133 | class CoCFpnDual(nn.Module):
134 | def __init__(self, num_seg_class=9, depth=1.0, width=1.0, in_features=("dark2", "dark3", "dark4", "dark5"),
135 | in_channels=[64, 128, 320, 512], aspp_channel=1024):
136 | super().__init__()
137 |
138 | Conv = CoC_Conv
139 |
140 | self.backbone = coc_small(pretrained=False, width=width)
141 | self.in_features = in_features
142 | self.num_seg_class = num_seg_class
143 | in_channels = [int(item*width) for item in in_channels]
144 |
145 | self.aspp = ASPP(dim_in=in_channels[-1], dim_out=in_channels[-1])
146 |
147 | # ================================= segmentation modules =================================== #
148 | # ----------------------- 20*20*512 -> 40*40*320 -> 40*40*640 ------------------------ #
149 | self.upsample5_4 = CoCUpsample(in_channels=in_channels[-1], out_channels=in_channels[-2])
150 | self.sc_attn_seg4 = ShuffleAttention(channel=in_channels[-2]*2)
151 | # ------------------------------------------------------------------------------------- #
152 |
153 | # ----------------------- 40*40*640 -> 80*80*128 -> 80*80*256 ------------------------ #
154 | self.upsample4_3 = CoCUpsample(in_channels=in_channels[-2]*2, out_channels=in_channels[-3])
155 | self.sc_attn_seg3 = ShuffleAttention(channel=in_channels[-3] * 2)
156 | # ------------------------------------------------------------------------------------ #
157 |
158 | # ----------------------- 80*80*256 -> 160*160*64 -> 160*160*128 ------------------------ #
159 | self.upsample3_2 = CoCUpsample(in_channels=in_channels[-3] * 2, out_channels=in_channels[0])
160 | self.sc_attn_seg2 = ShuffleAttention(channel=in_channels[0] * 2)
161 | # ------------------------------------------------------------------------------------ #
162 |
163 | # ----------------------- 80*80*256 -> 160*160*64 -> 640*640*9 ------------------------ #
164 | self.upsample2_0 = CoCUpsample(in_channels=in_channels[0] * 2, out_channels=self.num_seg_class, scale=4)
165 | # ------------------------------------------------------------------------------------ #
166 | # ========================================================================================== #
167 |
168 | # ================================= detection modules ====================================== #
169 | # ----------------------- 20*20*512 -> 20*20*512 ----------------------- #
170 | self.p5_out_det = Conv(in_channels=in_channels[-1], out_channels=in_channels[-1])
171 | # ----------------------------------------------------------------------- #
172 |
173 | # ----------------------- 20*20*512 -> 40*40*320 -> 40*40*640 -> 40*40*320 ------------------------ #
174 | self.p5_4_det = CoCUpsample(in_channels=in_channels[-1], out_channels=in_channels[-2])
175 | self.p4_out_det = Conv(in_channels=in_channels[-2]*2, out_channels=in_channels[-2])
176 | # ------------------------------------------------------------------------------------------------- #
177 |
178 | # ----------------------- 40*40*320 -> 80*80*128 -> 80*80*256 -> 80*80*128 ------------------------ #
179 | self.p4_3_det = CoCUpsample(in_channels=in_channels[-2], out_channels=in_channels[-3])
180 | self.p3_out_det = Conv(in_channels=in_channels[-3]*2, out_channels=in_channels[-3])
181 | # ------------------------------------------------------------------------------------------------- #
182 | # ========================================================================================== #
183 |
184 | def forward(self, x, x_radar):
185 |
186 | x_out, x_radar_out = self.backbone(x, x_radar)
187 |
188 | x_stage2, x_stage3, x_stage4, x_stage5 = x_out
189 | x_stage5 = self.aspp(x_stage5)
190 |
191 | x_radar_stage2, x_radar_stage3, x_radar_stage4, x_radar_stage5 = x_radar_out
192 |
193 | # ---------------------------- segmentation ------------------------------- #
194 | x_stage5_4 = self.upsample5_4(x_stage5)
195 | x_stage4_concat_5 = torch.cat([x_stage4, x_stage5_4], dim=1)
196 | x_stage4_concat_5 = shuffle_channels(x_stage4_concat_5)
197 | x_stage4_concat_5 = self.sc_attn_seg4(x_stage4_concat_5)
198 |
199 | x_stage4_3 = self.upsample4_3(x_stage4_concat_5)
200 | x_stage3_concat_4 = torch.cat([x_stage4_3, x_stage3], dim=1)
201 | x_stage3_concat_4 = shuffle_channels(x_stage3_concat_4)
202 | x_stage3_concat_4 = self.sc_attn_seg3(x_stage3_concat_4)
203 |
204 | x_stage3_2 = self.upsample3_2(x_stage3_concat_4)
205 | x_stage2_concat_3 = torch.cat([x_stage3_2, x_stage2], dim=1)
206 | x_stage2_concat_3 = shuffle_channels(x_stage2_concat_3)
207 | x_stage2_concat_3 = self.sc_attn_seg2(x_stage2_concat_3)
208 |
209 | x_segmentation_out = self.upsample2_0(x_stage2_concat_3)
210 | # ------------------------------------------------------------------------ #
211 |
212 | # ----------------------------- detection -------------------------------- #
213 | p5_out = self.p5_out_det(x_radar_stage5)
214 |
215 | p5_4_upsample = self.p5_4_det(p5_out)
216 | p4_concat_5 = torch.cat([x_radar_stage4, p5_4_upsample], dim=1)
217 | p4_out = self.p4_out_det(p4_concat_5)
218 |
219 | p4_3_upsample = self.p4_3_det(p4_out)
220 | p3_concat_4 = torch.cat([x_radar_stage3, p4_3_upsample], dim=1)
221 | p3_out = self.p3_out_det(p3_concat_4)
222 | # ------------------------------------------------------------------------ #
223 |
224 | return (p3_out, p4_out, p5_out), x_segmentation_out
225 |
226 |
227 | if __name__ == '__main__':
228 | # input_map = torch.randn((1, 512, 20, 20)).cuda()
229 | # aspp = ASPP(dim_in=512, dim_out=1024).cuda()
230 | model = CoCFpnDual(width=1.0).cuda()
231 | model.eval()
232 | input = torch.rand(1, 3, 512, 512).cuda()
233 | input_radar = torch.rand(1, 4, 512, 512).cuda()
234 | output = model(input, input_radar)
235 | print(summary(model, input_size=[(1, 3, 512, 512), (1, 4, 512, 512)]))
236 | macs, params = profile(model, inputs=(input, input_radar))
237 | macs, params = clever_format([macs, params], "%.3f")
238 | print("params:", params)
239 | print("macs:", macs)
240 | print(output[1].shape)
241 | print(output[0][0].shape)
242 | print(output[0][1].shape)
243 | print(output[0][2].shape)
244 | # model = SpatialPyramidPooling()
245 | # input = torch.rand(1, 512, 20, 20)
246 | # output = model(input)
247 | # print(output.shape)
248 |
249 |
250 |
251 |
252 |
--------------------------------------------------------------------------------
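
shuffle_channels is the only helper in this file without a demonstration in the __main__ block; a tiny sketch of what it does to the channel ordering (with groups=2 the two halves are interleaved), assuming the repo root is on PYTHONPATH:

import torch
from neck.coc_fpn_dual import shuffle_channels

x = torch.arange(8, dtype=torch.float32).view(1, 8, 1, 1)   # channels labelled 0..7
y = shuffle_channels(x, groups=2)
print(y.view(-1).tolist())   # [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]
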
/nets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/nets/__init__.py
--------------------------------------------------------------------------------
/nets/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/nets/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/deeplabv3_training.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/nets/__pycache__/deeplabv3_training.cpython-39.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/efficient_vrnet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/nets/__pycache__/efficient_vrnet.cpython-39.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/yolo_training.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/nets/__pycache__/yolo_training.cpython-39.pyc
--------------------------------------------------------------------------------
/nets/deeplabv3_training.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | def CE_Loss(inputs, target, cls_weights, num_classes=21):
10 | n, c, h, w = inputs.size()
11 | nt, ht, wt = target.size()
12 | if h != ht and w != wt:
13 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
14 |
15 | temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
16 | temp_target = target.view(-1)
17 |
18 | CE_loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(temp_inputs, temp_target)
19 | return CE_loss
20 |
21 |
22 | def Focal_Loss(inputs, target, cls_weights, num_classes=21, alpha=0.5, gamma=2):
23 | n, c, h, w = inputs.size()
24 | nt, ht, wt = target.size()
25 | if h != ht and w != wt:
26 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
27 |
28 | temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
29 | temp_target = target.view(-1)
30 |
31 | logpt = -nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes, reduction='none')(temp_inputs,
32 | temp_target)
33 | pt = torch.exp(logpt)
34 | if alpha is not None:
35 | logpt *= alpha
36 | loss = -((1 - pt) ** gamma) * logpt
37 | loss = loss.mean()
38 | return loss
39 |
40 |
41 | def Dice_loss(inputs, target, beta=1, smooth=1e-5):
42 | n, c, h, w = inputs.size()
43 | nt, ht, wt, ct = target.size()
44 | if h != ht and w != wt:
45 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
46 |
47 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c), -1)
48 | temp_target = target.view(n, -1, ct)
49 |
50 | # --------------------------------------------#
51 | # compute the dice loss
52 | # --------------------------------------------#
53 | tp = torch.sum(temp_target[..., :-1] * temp_inputs, axis=[0, 1])
54 | fp = torch.sum(temp_inputs, axis=[0, 1]) - tp
55 | fn = torch.sum(temp_target[..., :-1], axis=[0, 1]) - tp
56 |
57 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
58 | dice_loss = 1 - torch.mean(score)
59 | return dice_loss
60 |
61 |
62 | def weights_init(net, init_type='normal', init_gain=0.02):
63 | def init_func(m):
64 | classname = m.__class__.__name__
65 | if hasattr(m, 'weight') and classname.find('Conv') != -1:
66 | if init_type == 'normal':
67 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
68 | elif init_type == 'xavier':
69 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
70 | elif init_type == 'kaiming':
71 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
72 | elif init_type == 'orthogonal':
73 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
74 | else:
75 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
76 | elif classname.find('BatchNorm2d') != -1:
77 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
78 | torch.nn.init.constant_(m.bias.data, 0.0)
79 |
80 | print('initialize network with %s type' % init_type)
81 | net.apply(init_func)
82 |
83 |
84 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio=0.1, warmup_lr_ratio=0.1,
85 | no_aug_iter_ratio=0.3, step_num=10):
86 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
87 | if iters <= warmup_total_iters:
88 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
89 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
90 | elif iters >= total_iters - no_aug_iter:
91 | lr = min_lr
92 | else:
93 | lr = min_lr + 0.5 * (lr - min_lr) * (
94 | 1.0 + math.cos(
95 | math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
96 | )
97 | return lr
98 |
99 | def step_lr(lr, decay_rate, step_size, iters):
100 | if step_size < 1:
101 | raise ValueError("step_size must be above 1.")
102 | n = iters // step_size
103 | out_lr = lr * decay_rate ** n
104 | return out_lr
105 |
106 | if lr_decay_type == "cos":
107 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
108 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
109 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
110 | func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
111 | else:
112 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
113 | step_size = total_iters / step_num
114 | func = partial(step_lr, lr, decay_rate, step_size)
115 |
116 | return func
117 |
118 |
119 | def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
120 | lr = lr_scheduler_func(epoch)
121 | for param_group in optimizer.param_groups:
122 | param_group['lr'] = lr
123 |
--------------------------------------------------------------------------------
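
A minimal sketch of how the scheduler helpers above are wired into a training loop; the model, optimizer and epoch count are placeholders:

import torch
from nets.deeplabv3_training import get_lr_scheduler, set_optimizer_lr

model = torch.nn.Linear(10, 2)                                   # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

epochs = 100
lr_scheduler_func = get_lr_scheduler("cos", lr=1e-2, min_lr=1e-4, total_iters=epochs)

for epoch in range(epochs):
    set_optimizer_lr(optimizer, lr_scheduler_func, epoch)        # update lr before the epoch
    # ... one epoch of training / validation goes here ...
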
/nets/efficient_vrnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from neck.coc_fpn_dual import CoCFpnDual
5 | from head.decouplehead import DecoupleHead
6 | from torchinfo import summary
7 | from thop import profile
8 | from thop import clever_format
9 | # from torchsummary import summary
10 | import time
11 |
12 |
13 | class EfficientVRNet(nn.Module):
14 | def __init__(self, num_classes, num_seg_classes, phi):
15 | super().__init__()
16 | depth_dict = {'nano': 0.33, 'tiny': 0.33, 's' : 0.33, 'm' : 0.67, 'l' : 1.00}
17 | width_dict = {'nano': 0.25, 'tiny': 0.375, 's' : 0.50, 'm' : 0.75, 'l' : 1.00}
18 | depth, width = depth_dict[phi], width_dict[phi]
19 |
20 | self.backbone = CoCFpnDual(width=width, num_seg_class=num_seg_classes)
21 |
22 | self.head = DecoupleHead(num_classes, width, depthwise=True)
23 |
24 | def forward(self, x, x_radar):
25 | fpn_outs, seg_outputs = self.backbone.forward(x, x_radar)
26 | det_outputs = self.head.forward(fpn_outs)
27 | return det_outputs, seg_outputs
28 |
29 |
30 | if __name__ == '__main__':
31 | model = EfficientVRNet(num_classes=4, phi='l', num_seg_classes=9).cuda()
32 | model.eval()
33 | input_map1 = torch.randn((1, 3, 512, 512)).cuda()
34 | input_map2 = torch.randn((1, 4, 512, 512)).cuda()
35 | t1 = time.time()
36 | test_times = 300
37 | for i in range(test_times):
38 | output_map, output_seg = model(input_map1, input_map2)
39 | t2 = time.time()
40 | print("fps:", (1 / ((t2 - t1) / test_times)))
41 |
42 | output_map, output_seg = model(input_map1, input_map2)
43 | print(output_map[0].shape)
44 | print(output_map[1].shape)
45 | print(output_map[2].shape)
46 | print(output_seg.shape)
47 | print(summary(model, input_size=((1, 3, 512, 512), (1, 4, 512, 512))))
48 |
49 | macs, params = profile(model, inputs=([input_map1, input_map2]))
50 | flops = macs * 2
51 | flops, params = clever_format([flops, params], "%.3f")
52 | print("params:", params)
53 | print("flops:", flops)
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # -----------------------------------------------------------------------#
2 | # predict.py bundles single-image prediction, camera/video detection, FPS testing
3 | # and directory-wide detection into one script; switch between them via the mode variable.
4 | # -----------------------------------------------------------------------#
5 | import time
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | from yolo import YOLO
12 |
13 | if __name__ == "__main__":
14 | yolo = YOLO()
15 | # ----------------------------------------------------------------------------------------------------------#
16 | # mode selects the test mode:
17 | # 'predict' runs single-image prediction; to customize the prediction step (saving images, cropping objects, etc.) see the detailed notes below
18 | # 'video' runs detection on a camera stream or a video file, see the notes below.
19 | # 'fps' measures FPS on the image given by fps_image_path, see the notes below.
20 | # 'dir_predict' walks a folder, detects every image and saves the results. By default it reads img/ and writes to img_out/, see the notes below.
21 | # 'heatmap' visualizes the prediction heatmap, see the notes below.
22 | # 'export_onnx' exports the model to ONNX; requires PyTorch 1.7.1 or newer.
23 | # ----------------------------------------------------------------------------------------------------------#
24 | mode = "predict"
25 | # -------------------------------------------------------------------------#
26 | # crop   whether to crop the detected objects after single-image prediction
27 | # count  whether to count the detected objects
28 | # crop and count only take effect when mode='predict'
29 | # -------------------------------------------------------------------------#
30 | crop = False
31 | count = False
32 | # ----------------------------------------------------------------------------------------------------------#
33 | # video_path       path of the input video; video_path=0 means use the camera
34 | #                  to detect a video file, set e.g. video_path = "xxx.mp4" to read xxx.mp4 from the repository root.
35 | # video_save_path  where to save the output video; "" means do not save
36 | #                  to save the video, set e.g. video_save_path = "yyy.mp4" to write yyy.mp4 under the repository root.
37 | # video_fps        fps of the saved video
38 | #
39 | # video_path, video_save_path and video_fps only take effect when mode='video'
40 | # When saving a video, press ctrl+c to quit or let it run to the last frame to finish writing the file.
41 | # ----------------------------------------------------------------------------------------------------------#
42 | video_path = 'images/video2.mp4'
43 | video_save_path = "images/video_det_out.mp4"
44 | video_fps = 33.0
45 | # ----------------------------------------------------------------------------------------------------------#
46 | # test_interval   number of detections used to measure FPS; in theory, larger values give a more accurate FPS.
47 | # fps_image_path  image used for the FPS test
48 | #
49 | # test_interval and fps_image_path only take effect when mode='fps'
50 | # ----------------------------------------------------------------------------------------------------------#
51 | test_interval = 100
52 | fps_image_path = "images/example1.jpg"
53 | # -------------------------------------------------------------------------#
54 | # dir_origin_path  folder containing the images to detect
55 | # dir_save_path    folder where the detected images are saved
56 | #
57 | # dir_origin_path and dir_save_path only take effect when mode='dir_predict'
58 | # -------------------------------------------------------------------------#
59 | dir_origin_path = "img/"
60 | dir_save_path = "img_out/"
61 | # -------------------------------------------------------------------------#
62 | # heatmap_save_path  where to save the heatmap, model_data/ by default
63 | #
64 | # heatmap_save_path only takes effect when mode='heatmap'
65 | # -------------------------------------------------------------------------#
66 | heatmap_save_path = "model_data/heatmap_vision.png"
67 | # -------------------------------------------------------------------------#
68 | # simplify        whether to run onnx-simplifier on the exported model
69 | # onnx_save_path  where to save the ONNX model
70 | # -------------------------------------------------------------------------#
71 | simplify = True
72 | onnx_save_path = "model_data/models.onnx"
73 |
74 | if mode == "predict":
75 | '''
76 | 1. To save the detected image, call r_image.save("img.jpg"); modify predict.py directly.
77 | 2. To get the coordinates of the predicted boxes, go into yolo.detect_image and read top, left, bottom, right in the drawing section.
78 | 3. To crop the detected objects, go into yolo.detect_image and use the obtained top, left, bottom, right values
79 |    to slice the corresponding region out of the original image as an array.
80 | 4. To write extra text on the result, e.g. the number of a particular class, go into yolo.detect_image and test predicted_class in the drawing section,
81 |    e.g. if predicted_class == 'car': checks whether the current object is a car; keep a counter and write the text with draw.text.
82 | '''
83 | while True:
84 | img = input('Input image filename:')
85 | try:
86 | image = Image.open(img)
87 | image_id = img[-20:-4]
88 | except:
89 | print('Open Error! Try again!')
90 | continue
91 | else:
92 | r_image = yolo.detect_image(image, image_id, crop=crop, count=count)
93 | r_image.show()
94 |
95 | elif mode == "video":
96 | capture = cv2.VideoCapture(video_path)
97 | if video_save_path != "":
98 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
99 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
100 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
101 |
102 | ref, frame = capture.read()
103 | if not ref:
104 | raise ValueError("Failed to read the camera (or video). Check that the camera is connected correctly (or that the video path is correct).")
105 |
106 | fps = 0.0
107 | while (True):
108 | t1 = time.time()
109 | # read one frame
110 | ref, frame = capture.read()
111 | if not ref:
112 | break
113 | # convert BGR to RGB
114 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
115 | # convert to a PIL Image
116 | frame = Image.fromarray(np.uint8(frame))
117 | # run detection
118 | frame = np.array(yolo.detect_image(frame))
119 | # convert RGB back to BGR for OpenCV display
120 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
121 |
122 | fps = (fps + (1. / (time.time() - t1))) / 2
123 | print("fps= %.2f" % (fps))
124 | # frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
125 |
126 | # cv2.imshow("video", frame)
127 | c = cv2.waitKey(1) & 0xff
128 | if video_save_path != "":
129 | out.write(frame)
130 |
131 | if c == 27:
132 | capture.release()
133 | break
134 |
135 | print("Video Detection Done!")
136 | capture.release()
137 | if video_save_path != "":
138 | print("Save processed video to the path :" + video_save_path)
139 | out.release()
140 | cv2.destroyAllWindows()
141 |
142 | elif mode == "fps":
143 | img = Image.open(fps_image_path)
144 | tact_time = yolo.get_FPS(img, test_interval)
145 | print(str(tact_time) + ' seconds, ' + str(1 / tact_time) + 'FPS, @batch_size 1')
146 |
147 | elif mode == "dir_predict":
148 | import os
149 |
150 | from tqdm import tqdm
151 |
152 | img_names = os.listdir(dir_origin_path)
153 | for img_name in tqdm(img_names):
154 | if img_name.lower().endswith(
155 | ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
156 | image_path = os.path.join(dir_origin_path, img_name)
157 | image = Image.open(image_path)
158 | r_image = yolo.detect_image(image)
159 | if not os.path.exists(dir_save_path):
160 | os.makedirs(dir_save_path)
161 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
162 |
163 | elif mode == "heatmap":
164 | while True:
165 | img = input('Input image filename:')
166 | try:
167 | image = Image.open(img)
168 | image_id = img[-20:-4]
169 | except:
170 | print('Open Error! Try again!')
171 | continue
172 | else:
173 | yolo.detect_heatmap(image, image_id, heatmap_save_path)
174 |
175 | elif mode == "export_onnx":
176 | yolo.convert_to_onnx(simplify, onnx_save_path)
177 |
178 | else:
179 | raise AssertionError(
180 | "Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")
181 |
--------------------------------------------------------------------------------
/predict_seg.py:
--------------------------------------------------------------------------------
1 | #----------------------------------------------------#
2 | # Bundles single-image prediction, camera/video detection and FPS testing
3 | # into one script; switch between them via the mode variable.
4 | #----------------------------------------------------#
5 | import time
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | from deeplab import DeeplabV3
12 |
13 | if __name__ == "__main__":
14 | #-------------------------------------------------------------------------#
15 | # To change the color of a class, edit self.colors in the __init__ function
16 | #-------------------------------------------------------------------------#
17 | deeplab = DeeplabV3()
18 | #----------------------------------------------------------------------------------------------------------#
19 | # mode selects the test mode:
20 | # 'predict' runs single-image prediction; to customize the prediction step (saving images, cropping objects, etc.) see the detailed notes below
21 | # 'video' runs detection on a camera stream or a video file, see the notes below.
22 | # 'fps' measures FPS on the image given by fps_image_path, see the notes below.
23 | # 'dir_predict' walks a folder, detects every image and saves the results. By default it reads img/ and writes to img_out/, see the notes below.
24 | # 'export_onnx' exports the model to ONNX; requires PyTorch 1.7.1 or newer.
25 | #----------------------------------------------------------------------------------------------------------#
26 | mode = "predict"
27 | #-------------------------------------------------------------------------#
28 | # count         whether to count the pixels (i.e. area) of each class and compute their ratios
29 | # name_classes  classes to distinguish, same as in json_to_dataset, used when printing class names and counts
30 | #
31 | # count and name_classes only take effect when mode='predict'
32 | #-------------------------------------------------------------------------#
33 | count = False
34 | name_classes = [ "free-space", "pier", "vessel", "ship", "boat", "buoy", "sailor", "kayak"]
35 | # name_classes = ["background","cat","dog"]
36 | #----------------------------------------------------------------------------------------------------------#
37 | # video_path       path of the input video; video_path=0 means use the camera
38 | #                  to detect a video file, set e.g. video_path = "xxx.mp4" to read xxx.mp4 from the repository root.
39 | # video_save_path  where to save the output video; "" means do not save
40 | #                  to save the video, set e.g. video_save_path = "yyy.mp4" to write yyy.mp4 under the repository root.
41 | # video_fps        fps of the saved video
42 | #
43 | # video_path, video_save_path and video_fps only take effect when mode='video'
44 | # When saving a video, press ctrl+c to quit or let it run to the last frame to finish writing the file.
45 | #----------------------------------------------------------------------------------------------------------#
46 | video_path = 'images/video2.mp4'
47 | video_save_path = "images/video_seg_out.mp4"
48 | video_fps = 33.0
49 | #----------------------------------------------------------------------------------------------------------#
50 | # test_interval   number of detections used to measure FPS; in theory, larger values give a more accurate FPS.
51 | # fps_image_path  image used for the FPS test
52 | #
53 | # test_interval and fps_image_path only take effect when mode='fps'
54 | #----------------------------------------------------------------------------------------------------------#
55 | test_interval = 100
56 | fps_image_path = "img/street.jpg"
57 | #-------------------------------------------------------------------------#
58 | # dir_origin_path  folder containing the images to detect
59 | # dir_save_path    folder where the detected images are saved
60 | #
61 | # dir_origin_path and dir_save_path only take effect when mode='dir_predict'
62 | #-------------------------------------------------------------------------#
63 | dir_origin_path = "img/"
64 | dir_save_path = "img_out/"
65 | #-------------------------------------------------------------------------#
66 | # simplify        whether to run onnx-simplifier on the exported model
67 | # onnx_save_path  where to save the ONNX model
68 | #-------------------------------------------------------------------------#
69 | simplify = True
70 | onnx_save_path = "model_data/models.onnx"
71 |
72 | if mode == "predict":
73 | '''
74 | A few notes on predict mode:
75 | 1. This code cannot batch-predict directly; to batch-predict, walk a folder with os.listdir() and open each image with Image.open for prediction.
76 |    See get_miou_prediction.py for a concrete example of such a traversal.
77 | 2. To save the result, call r_image.save("img.jpg").
78 | 3. To keep the original image and the segmentation map from being blended, set the blend parameter to False.
79 | 4. To extract a region from the mask, refer to the drawing part of detect_image: check the class of every pixel and pick the region for that class, e.g.
80 | seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3))
81 | for c in range(self.num_classes):
82 | seg_img[:, :, 0] += ((pr == c)*( self.colors[c][0] )).astype('uint8')
83 | seg_img[:, :, 1] += ((pr == c)*( self.colors[c][1] )).astype('uint8')
84 | seg_img[:, :, 2] += ((pr == c)*( self.colors[c][2] )).astype('uint8')
85 | '''
86 | while True:
87 | img = input('Input image filename:')
88 | try:
89 | image = Image.open(img)
90 | image_id = img[-20:-4]
91 | except Exception as e:
92 | print('Open Error! Try again!')
93 | print(e)
94 | continue
95 | else:
96 | r_image = deeplab.detect_image(image, image_id, count=count, name_classes=name_classes)
97 | r_image.show()
98 |
99 | elif mode == "video":
100 | capture=cv2.VideoCapture(video_path)
101 | if video_save_path!="":
102 | fourcc = cv2.VideoWriter_fourcc(*'XVID')
103 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
104 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
105 |
106 | ref, frame = capture.read()
107 | if not ref:
108 | raise ValueError("Failed to read the camera (or video). Check that the camera is connected correctly (or that the video path is correct).")
109 |
110 | fps = 0.0
111 | while(True):
112 | t1 = time.time()
113 | # read one frame
114 | ref, frame = capture.read()
115 | if not ref:
116 | break
117 | # convert BGR to RGB
118 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
119 | # convert to a PIL Image
120 | frame = Image.fromarray(np.uint8(frame))
121 | # run detection
122 | frame = np.array(deeplab.detect_image(frame))
123 | # convert RGB back to BGR for OpenCV display
124 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
125 |
126 | fps = ( fps + (1./(time.time()-t1)) ) / 2
127 | print("fps= %.2f"%(fps))
128 | # frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
129 | #
130 | # cv2.imshow("video",frame)
131 | c= cv2.waitKey(1) & 0xff
132 | if video_save_path!="":
133 | out.write(frame)
134 |
135 | if c==27:
136 | capture.release()
137 | break
138 | print("Video Detection Done!")
139 | capture.release()
140 | if video_save_path!="":
141 | print("Save processed video to the path :" + video_save_path)
142 | out.release()
143 | cv2.destroyAllWindows()
144 |
145 | elif mode == "fps":
146 | img = Image.open(fps_image_path)
147 | tact_time = deeplab.get_FPS(img, test_interval)
148 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
149 |
150 | elif mode == "dir_predict":
151 | import os
152 | from tqdm import tqdm
153 |
154 | img_names = os.listdir(dir_origin_path)
155 | for img_name in tqdm(img_names):
156 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
157 | image_path = os.path.join(dir_origin_path, img_name)
158 | image = Image.open(image_path)
159 | r_image = deeplab.detect_image(image)
160 | if not os.path.exists(dir_save_path):
161 | os.makedirs(dir_save_path)
162 | r_image.save(os.path.join(dir_save_path, img_name))
163 | elif mode == "export_onnx":
164 | deeplab.convert_to_onnx(simplify, onnx_save_path)
165 |
166 | else:
167 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps' or 'dir_predict'.")
168 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==1.1.0
2 | einops==0.4.1
3 | matplotlib==3.4.2
4 | numpy==1.21.6
5 | onnx==1.11.0
6 | onnxsim==0.4.17
7 | opencv_python_headless==4.5.5.64
8 | pandas==1.3.5
9 | Pillow==9.4.0
10 | pycocotools==2.0.5
11 | scikit_learn==1.0.2
12 | scipy==1.7.3
13 | thop==0.0.31.post2005241907
14 | timm==0.6.7
15 | torch==1.9.0
16 | torchinfo==1.7.0
17 | torchvision==0.10.0
18 | tqdm==4.64.0
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/callbacks.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/callbacks.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/multitaskloss.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/multitaskloss.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_bbox.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/utils_bbox.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_fit.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/utils_fit.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_map.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils/__pycache__/utils_map.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import torch
4 | import matplotlib
5 |
6 | matplotlib.use('Agg')
7 | import scipy.signal
8 | from matplotlib import pyplot as plt
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | import shutil
12 | import numpy as np
13 |
14 | from PIL import Image
15 | from tqdm import tqdm
16 | from .utils import cvtColor, preprocess_input, resize_image
17 | from .utils_bbox import decode_outputs, non_max_suppression
18 | from .utils_map import get_coco_map, get_map
19 |
20 |
21 | class LossHistory():
22 | def __init__(self, log_dir, model, input_shape):
23 | self.log_dir = log_dir
24 | self.losses = []
25 | self.val_loss = []
26 |
27 | os.makedirs(self.log_dir)
28 | self.writer = SummaryWriter(self.log_dir)
29 | try:
30 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
31 | self.writer.add_graph(model, dummy_input)
32 | except:
33 | pass
34 |
35 | def append_loss(self, epoch, loss, val_loss):
36 | if not os.path.exists(self.log_dir):
37 | os.makedirs(self.log_dir)
38 |
39 | self.losses.append(loss)
40 | self.val_loss.append(val_loss)
41 |
42 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
43 | f.write(str(loss))
44 | f.write("\n")
45 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
46 | f.write(str(val_loss))
47 | f.write("\n")
48 |
49 | self.writer.add_scalar('loss', loss, epoch)
50 | self.writer.add_scalar('val_loss', val_loss, epoch)
51 | self.loss_plot()
52 |
53 | def loss_plot(self):
54 | iters = range(len(self.losses))
55 |
56 | plt.figure()
57 | plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
58 | plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')
59 | try:
60 | if len(self.losses) < 25:
61 | num = 5
62 | else:
63 | num = 15
64 |
65 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle='--', linewidth=2,
66 | label='smooth train loss')
67 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle='--', linewidth=2,
68 | label='smooth val loss')
69 | except:
70 | pass
71 |
72 | plt.grid(True)
73 | plt.xlabel('Epoch')
74 | plt.ylabel('Loss')
75 | plt.legend(loc="upper right")
76 |
77 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
78 |
79 | plt.cla()
80 | plt.close("all")
81 |
82 |
83 | class EvalCallback():
84 | def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, local_rank, radar_path, \
85 | map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True,
86 | MINOVERLAP=0.5, eval_flag=True, period=1):
87 | super(EvalCallback, self).__init__()
88 |
89 | self.net = net
90 | self.input_shape = input_shape
91 | self.class_names = class_names
92 | self.num_classes = num_classes
93 | self.val_lines = val_lines
94 | self.log_dir = log_dir
95 | self.cuda = cuda
96 | self.local_rank = local_rank
97 | self.map_out_path = map_out_path
98 | self.max_boxes = max_boxes
99 | self.confidence = confidence
100 | self.nms_iou = nms_iou
101 | self.letterbox_image = letterbox_image
102 | self.MINOVERLAP = MINOVERLAP
103 | self.eval_flag = eval_flag
104 | self.period = period
105 | self.radar_path = radar_path
106 |
107 | self.maps = [0]
108 | self.epoches = [0]
109 | if self.eval_flag:
110 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
111 | f.write(str(0))
112 | f.write("\n")
113 |
114 | def get_map_txt(self, image_id, image, radar_data, class_names, map_out_path):
115 | f = open(os.path.join(map_out_path, "detection-results/" + image_id + ".txt"), "w")
116 | image_shape = np.array(np.shape(image)[0:2])
117 | # ---------------------------------------------------------#
118 | # Convert the image to RGB here to avoid errors when predicting on grayscale images.
119 | # The code only supports prediction on RGB images; every other image type is converted to RGB.
120 | # ---------------------------------------------------------#
121 | image = cvtColor(image)
122 | # ---------------------------------------------------------#
123 | # Add gray bars (letterbox) to the image for a distortion-free resize.
124 | # A plain resize could also be used for recognition.
125 | # ---------------------------------------------------------#
126 | image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
127 | # ---------------------------------------------------------#
128 | # Add the batch_size dimension
129 | # ---------------------------------------------------------#
130 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
131 |
132 | with torch.no_grad():
133 | images = torch.from_numpy(image_data)
134 | if self.cuda:
135 | images = images.cuda(self.local_rank)
136 | # ---------------------------------------------------------#
137 | # Feed the image into the network for prediction
138 | # ---------------------------------------------------------#
139 | outputs, _ = self.net(images, radar_data)
140 | outputs = decode_outputs(outputs, self.input_shape, self.local_rank)
141 | # ---------------------------------------------------------#
142 | # Stack the predicted boxes, then apply non-maximum suppression
143 | # ---------------------------------------------------------#
144 | results = non_max_suppression(outputs, self.num_classes, self.input_shape,
145 | image_shape, self.letterbox_image, conf_thres=self.confidence,
146 | nms_thres=self.nms_iou)
147 |
148 | if results[0] is None:
149 | return
150 |
151 | top_label = np.array(results[0][:, 6], dtype='int32')
152 | top_conf = results[0][:, 4] * results[0][:, 5]
153 | top_boxes = results[0][:, :4]
154 |
155 | top_100 = np.argsort(top_conf)[::-1][:self.max_boxes]
156 | top_boxes = top_boxes[top_100]
157 | top_conf = top_conf[top_100]
158 | top_label = top_label[top_100]
159 |
160 | for i, c in list(enumerate(top_label)):
161 | predicted_class = self.class_names[int(c)]
162 | box = top_boxes[i]
163 | score = str(top_conf[i])
164 |
165 | top, left, bottom, right = box
166 | if predicted_class not in class_names:
167 | continue
168 |
169 | f.write("%s %s %s %s %s %s\n" % (
170 | predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))
171 |
172 | f.close()
173 | return
174 |
175 | def on_epoch_end(self, epoch, model_eval):
176 | if epoch % self.period == 0 and self.eval_flag:
177 | self.net = model_eval
178 | if not os.path.exists(self.map_out_path):
179 | os.makedirs(self.map_out_path)
180 | if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
181 | os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
182 | if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
183 | os.makedirs(os.path.join(self.map_out_path, "detection-results"))
184 | print("Get map.")
185 | for annotation_line in tqdm(self.val_lines):
186 | line = annotation_line.split()
187 |
188 | # ------------------------------#
189 | # Read the radar feature map
190 | # ------------------------------#
191 | pattern_string = r"\d{10}\.\d{5}"
192 | pattern = re.compile(pattern_string)  # match the timestamp-style numeric file name
193 | name = pattern.findall(annotation_line)[-1]
194 |
195 | radar_path = os.path.join(self.radar_path, name + '.npz')
196 | radar_data = np.load(radar_path)['arr_0']
197 | radar_data = torch.from_numpy(radar_data).type(torch.FloatTensor).unsqueeze(0).cuda(self.local_rank)
198 |
199 | image_id = os.path.basename(line[0]).split('.')[0]
200 | # ------------------------------#
201 | # Read the image and convert it to RGB
202 | # ------------------------------#
203 | image = Image.open(line[0])
204 | # ------------------------------#
205 | # Get the ground-truth boxes
206 | # ------------------------------#
207 | gt_boxes = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])
208 | # ------------------------------#
209 | # Generate the prediction txt file
210 | # ------------------------------#
211 | self.get_map_txt(image_id, image, radar_data, self.class_names, self.map_out_path)
212 |
213 | # ------------------------------#
214 | # Write the ground-truth txt file
215 | # ------------------------------#
216 | with open(os.path.join(self.map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f:
217 | for box in gt_boxes:
218 | left, top, right, bottom, obj = box
219 | obj_name = self.class_names[obj]
220 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
221 |
222 | print("Calculate Map.")
223 | try:
224 | temp_map = get_coco_map(class_names=self.class_names, path=self.map_out_path)[1]
225 | except:
226 | temp_map = get_map(self.MINOVERLAP, False, path=self.map_out_path)
227 | self.maps.append(temp_map)
228 | self.epoches.append(epoch)
229 |
230 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
231 | f.write(str(temp_map))
232 | f.write("\n")
233 |
234 | plt.figure()
235 | plt.plot(self.epoches, self.maps, 'red', linewidth=2, label='train map')
236 |
237 | plt.grid(True)
238 | plt.xlabel('Epoch')
239 | plt.ylabel('Map %s' % str(self.MINOVERLAP))
240 | plt.title('A Map Curve')
241 | plt.legend(loc="upper right")
242 |
243 | plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
244 | plt.cla()
245 | plt.close("all")
246 |
247 | print("Get map done.")
248 | shutil.rmtree(self.map_out_path)
249 |
--------------------------------------------------------------------------------
/utils/multitaskloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | class MultiTaskLossWrapper(nn.Module):
7 | def __init__(self, task_num):
8 | super().__init__()
9 | self.task_num = task_num
10 | self.log_vars = nn.Parameter(torch.zeros(task_num-1))
11 |
12 | def forward(self, loss_seg, loss_det):
13 | loss0 = loss_det
14 |
15 | precision1 = torch.exp(-self.log_vars[0])
16 | loss1 = precision1 * loss_seg + self.log_vars[0]
17 |
18 | return loss0 + loss1
19 |
20 |
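21 | # ------------------------------------------------------------------------#
22 | # Usage sketch (illustrative addition, not part of the original training
23 | # loop): the wrapper applies homoscedastic-uncertainty weighting, scaling
24 | # loss_seg by exp(-s) and adding s as a regularizer, where s = log_vars[0]
25 | # is learnable; loss_det is passed through unweighted. The loss values
26 | # below are placeholders.
27 | # ------------------------------------------------------------------------#
28 | if __name__ == "__main__":
29 |     mtl = MultiTaskLossWrapper(task_num=2)
30 |     loss_seg = torch.tensor(1.3)       # hypothetical segmentation loss
31 |     loss_det = torch.tensor(2.1)       # hypothetical detection loss
32 |     total = mtl(loss_seg, loss_det)    # loss_det + exp(-s) * loss_seg + s
33 |     print(total.item())               # 3.4 while s is still 0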
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 |
5 | #---------------------------------------------------------#
6 | # Convert the image to RGB to avoid errors when predicting on grayscale images.
7 | # The code only supports prediction on RGB images; every other image type is converted to RGB.
8 | #---------------------------------------------------------#
9 | def cvtColor(image):
10 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
11 | return image
12 | else:
13 | image = image.convert('RGB')
14 | return image
15 |
16 | #---------------------------------------------------#
17 | # Resize the input image
18 | #---------------------------------------------------#
19 | def resize_image(image, size, letterbox_image):
20 | iw, ih = image.size
21 | w, h = size
22 | if letterbox_image:
23 | scale = min(w/iw, h/ih)
24 | nw = int(iw*scale)
25 | nh = int(ih*scale)
26 |
27 | image = image.resize((nw,nh), Image.BICUBIC)
28 | new_image = Image.new('RGB', size, (128,128,128))
29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
30 | else:
31 | new_image = image.resize((w, h), Image.BICUBIC)
32 | return new_image
33 |
34 | #---------------------------------------------------#
35 | # Read the class names
36 | #---------------------------------------------------#
37 | def get_classes(classes_path):
38 | with open(classes_path, encoding='utf-8') as f:
39 | class_names = f.readlines()
40 | class_names = [c.strip() for c in class_names]
41 | return class_names, len(class_names)
42 |
43 | def preprocess_input(image):
44 | image /= 255.0
45 | image -= np.array([0.485, 0.456, 0.406])
46 | image /= np.array([0.229, 0.224, 0.225])
47 | return image
48 |
49 |
50 | def preprocess_input_radar(data):
51 | _range = np.max(data) - np.min(data)
52 | data = (data - np.min(data)) / (_range + 1e-12)  # min-max normalize; small epsilon guards against division by zero
53 | return data
54 |
55 | #---------------------------------------------------#
56 | # Get the current learning rate
57 | #---------------------------------------------------#
58 | def get_lr(optimizer):
59 | for param_group in optimizer.param_groups:
60 | return param_group['lr']
61 |
62 | def show_config(**kwargs):
63 | print('Configurations:')
64 | print('-' * 70)
65 | print('|%25s | %40s|' % ('keys', 'values'))
66 | print('-' * 70)
67 | for key, value in kwargs.items():
68 | print('|%25s | %40s|' % (str(key), str(value)))
69 | print('-' * 70)
70 |
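71 | # ------------------------------------------------------------------------#
72 | # Worked example (illustrative addition, assuming a 1280x720 image and a
73 | # 640x640 network input): letterboxing scales by min(640/1280, 640/720)=0.5,
74 | # giving a 640x360 image pasted onto a gray 640x640 canvas with 140-pixel
75 | # bars at the top and bottom.
76 | # ------------------------------------------------------------------------#
77 | if __name__ == "__main__":
78 |     img = Image.new('RGB', (1280, 720), (255, 0, 0))   # dummy image
79 |     boxed = resize_image(img, (640, 640), letterbox_image=True)
80 |     print(boxed.size)                                    # (640, 640)
81 |     x = preprocess_input(np.array(boxed, dtype='float32'))
82 |     print(x.shape)                                       # (640, 640, 3)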
--------------------------------------------------------------------------------
/utils/utils_bbox.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.ops import nms, boxes
4 |
5 | def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image):
6 | #-----------------------------------------------------------------#
7 | # Put the y axis first so the boxes can be multiplied directly by the image height and width
8 | #-----------------------------------------------------------------#
9 | box_yx = box_xy[..., ::-1]
10 | box_hw = box_wh[..., ::-1]
11 | input_shape = np.array(input_shape)
12 | image_shape = np.array(image_shape)
13 |
14 | if letterbox_image:
15 | #-----------------------------------------------------------------#
16 | # The offset here is the shift of the valid image area relative to the top-left corner of the image
17 | # new_shape is the scaled width and height
18 | #-----------------------------------------------------------------#
19 | new_shape = np.round(image_shape * np.min(input_shape/image_shape))
20 | offset = (input_shape - new_shape)/2./input_shape
21 | scale = input_shape/new_shape
22 |
23 | box_yx = (box_yx - offset) * scale
24 | box_hw *= scale
25 |
26 | box_mins = box_yx - (box_hw / 2.)
27 | box_maxes = box_yx + (box_hw / 2.)
28 | boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
29 | boxes *= np.concatenate([image_shape, image_shape], axis=-1)
30 | return boxes
31 |
32 | def decode_outputs(outputs, input_shape, local_rank):
33 | grids = []
34 | strides = []
35 | hw = [x.shape[-2:] for x in outputs]
36 | #---------------------------------------------------#
37 | # Before stacking, outputs holds the prediction of each feature level
38 | # batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400
39 | # batch_size, 5 + num_classes, 40, 40
40 | # batch_size, 5 + num_classes, 20, 20
41 | # batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400
42 | # After stacking: batch_size, 8400, 5 + num_classes
43 | #---------------------------------------------------#
44 | outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
45 | #---------------------------------------------------#
46 | # Probability of each feature point belonging to each class
47 | #---------------------------------------------------#
48 | outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:])
49 | for h, w in hw:
50 | #---------------------------#
51 | # Generate grid points from the feature level height and width
52 | #---------------------------#
53 | grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)])
54 | #---------------------------#
55 | # 1, 6400, 2
56 | # 1, 1600, 2
57 | # 1, 400, 2
58 | #---------------------------#
59 | grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2)
60 | shape = grid.shape[:2]
61 |
62 | grids.append(grid)
63 | strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h))
64 | #---------------------------#
65 | # Stack the grid points together
66 | # 1, 6400, 2
67 | # 1, 1600, 2
68 | # 1, 400, 2
69 | #
70 | # 1, 8400, 2
71 | #---------------------------#
72 | grids = torch.cat(grids, dim=1).type(outputs.type()).cuda(local_rank)
73 | strides = torch.cat(strides, dim=1).type(outputs.type()).cuda(local_rank)
74 | #------------------------#
75 | # Decode relative to the grid points
76 | #------------------------#
77 | outputs[..., :2] = (outputs[..., :2] + grids) * strides
78 | outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
79 | #-----------------#
80 | # Normalize
81 | #-----------------#
82 | outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1]
83 | outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0]
84 | return outputs
85 |
86 | def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
87 | #----------------------------------------------------------#
88 | # Convert the predictions to top-left / bottom-right corner format.
89 | # prediction [batch_size, num_anchors, 85]
90 | #----------------------------------------------------------#
91 | box_corner = prediction.new(prediction.shape)
92 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
93 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
94 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
95 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
96 | prediction[:, :, :4] = box_corner[:, :, :4]
97 |
98 | output = [None for _ in range(len(prediction))]
99 | #----------------------------------------------------------#
100 | # Loop over the input images; usually only one iteration
101 | #----------------------------------------------------------#
102 | for i, image_pred in enumerate(prediction):
103 | #----------------------------------------------------------#
104 | # Take the max over the class predictions.
105 | # class_conf [num_anchors, 1] class confidence
106 | # class_pred [num_anchors, 1] class index
107 | #----------------------------------------------------------#
108 | class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
109 |
110 | #----------------------------------------------------------#
111 | # First round of filtering by confidence
112 | #----------------------------------------------------------#
113 | conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
114 |
115 | if not image_pred.size(0):
116 | continue
117 | #-------------------------------------------------------------------------#
118 | # detections [num_anchors, 7]
119 | # The 7 values are: x1, y1, x2, y2, obj_conf, class_conf, class_pred
120 | #-------------------------------------------------------------------------#
121 | detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
122 | detections = detections[conf_mask]
123 |
124 | nms_out_index = boxes.batched_nms(
125 | detections[:, :4],
126 | detections[:, 4] * detections[:, 5],
127 | detections[:, 6],
128 | nms_thres,
129 | )
130 |
131 | output[i] = detections[nms_out_index]
132 |
133 | # #------------------------------------------#
134 | # # Collect every class present in the predictions
135 | # #------------------------------------------#
136 | # unique_labels = detections[:, -1].cpu().unique()
137 |
138 | # if prediction.is_cuda:
139 | # unique_labels = unique_labels.cuda()
140 | # detections = detections.cuda()
141 |
142 | # for c in unique_labels:
143 | # #------------------------------------------#
144 | # # Get all predictions of one class after score filtering
145 | # #------------------------------------------#
146 | # detections_class = detections[detections[:, -1] == c]
147 |
148 | # #------------------------------------------#
149 | # # Using the built-in non-maximum suppression is faster!
150 | # #------------------------------------------#
151 | # keep = nms(
152 | # detections_class[:, :4],
153 | # detections_class[:, 4] * detections_class[:, 5],
154 | # nms_thres
155 | # )
156 | # max_detections = detections_class[keep]
157 |
158 | # # # Sort by objectness confidence
159 | # # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
160 | # # detections_class = detections_class[conf_sort_index]
161 | # # # Apply non-maximum suppression
162 | # # max_detections = []
163 | # # while detections_class.size(0):
164 | # # # Take the highest-confidence box of this class, then walk down the list and drop boxes whose overlap exceeds nms_thres
165 | # # max_detections.append(detections_class[0].unsqueeze(0))
166 | # # if len(detections_class) == 1:
167 | # # break
168 | # # ious = bbox_iou(max_detections[-1], detections_class[1:])
169 | # # detections_class = detections_class[1:][ious < nms_thres]
170 | # # # Stack the results
171 | # # max_detections = torch.cat(max_detections).data
172 |
173 | # # Add max detections to outputs
174 | # output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
175 |
176 | if output[i] is not None:
177 | output[i] = output[i].cpu().numpy()
178 | box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
179 | output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
180 | return output
181 |
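182 | # ------------------------------------------------------------------------#
183 | # Decoding example (illustrative addition): with a 640x640 input, the 80x80
184 | # feature level has stride 640 / 80 = 8. A prediction at grid cell
185 | # (gx, gy) = (10, 20) with raw offsets (0.3, -0.2) and raw sizes (1.0, 0.5)
186 | # decodes (see decode_outputs above) to
187 | #   cx = (0.3 + 10) * 8 = 82.4      cy = (-0.2 + 20) * 8 = 158.4
188 | #   w  = exp(1.0) * 8  ≈ 21.75      h  = exp(0.5) * 8   ≈ 13.19
189 | # before the final normalization by the input width/height.
190 | # ------------------------------------------------------------------------#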
--------------------------------------------------------------------------------
/utils/utils_fit.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | from utils.utils import get_lr
7 |
8 | from nets.deeplabv3_training import (CE_Loss, Dice_loss, Focal_Loss,
9 | weights_init)
10 |
11 | from utils_seg.utils import get_lr
12 | from utils_seg.utils_metrics import f_score
13 |
14 | from utils.multitaskloss import MultiTaskLossWrapper
15 |
16 |
17 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, loss_history_seg, eval_callback, eval_callback_seg, optimizer, epoch, epoch_step,
18 | epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, dice_loss, focal_loss, cls_weights, num_class_seg, local_rank=0):
19 | total_loss_det = 0
20 | total_loss_seg = 0
21 | total_f_score = 0
22 |
23 | val_loss_det = 0
24 | val_loss_seg = 0
25 | val_f_score = 0
26 |
27 | total_loss = 0
28 | val_total_loss = 0
29 |
30 | if local_rank >= 0:
31 | print('Start Train')
32 | pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
33 | model_train.train()
34 | for iteration, batch in enumerate(gen):
35 | if iteration >= epoch_step:
36 | break
37 |
38 | images, targets, radars, pngs, seg_labels = batch[0], batch[1], batch[2], batch[3], batch[4]
39 |
40 | with torch.no_grad():
41 | weights = torch.from_numpy(cls_weights)
42 | if cuda:
43 | images = images.cuda(local_rank)
44 | targets = [ann.cuda(local_rank) for ann in targets]
45 | radars = radars.cuda(local_rank)
46 | pngs = pngs.cuda(local_rank)
47 | seg_labels = seg_labels.cuda(local_rank)
48 | weights = weights.cuda(local_rank)
49 |
50 | # ----------------------#
51 | # Zero the gradients
52 | # ----------------------#
53 | optimizer.zero_grad()
54 | if not fp16:
55 | # ----------------------#
56 | # Forward pass
57 | # ----------------------#
58 | outputs, outputs_seg = model_train(images, radars)
59 |
60 | if focal_loss:
61 | loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
62 | else:
63 | loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
64 |
65 | if dice_loss:
66 | main_dice = Dice_loss(outputs_seg, seg_labels)
67 | loss_seg = loss_seg + main_dice
68 |
69 | # ----------------------#
70 | # Compute the detection loss
71 | # ----------------------#
72 | loss_det = yolo_loss(outputs, targets)
73 |
74 | mtl = MultiTaskLossWrapper(task_num=2)
75 | total_loss = mtl(loss_seg, loss_det)
76 |
77 | with torch.no_grad():
78 | train_f_score = f_score(outputs_seg, seg_labels)
79 |
80 | # ----------------------#
81 | # Backward pass
82 | # ----------------------#
83 | total_loss.backward()
84 | optimizer.step()
85 | else:
86 | from torch.cuda.amp import autocast
87 | with autocast():
88 | outputs, outputs_seg = model_train(images, radars)
89 |
90 | if focal_loss:
91 | loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
92 | else:
93 | loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
94 |
95 | if dice_loss:
96 | main_dice = Dice_loss(outputs_seg, seg_labels)
97 | loss_seg = loss_seg + main_dice
98 |
99 | # ----------------------#
100 | # calculate loss
101 | # ----------------------#
102 | loss_det = yolo_loss(outputs, targets)
103 |
104 | # mtl = MultiTaskLossWrapper(task_num=2)
105 | # total_loss = mtl(loss_seg, loss_det)
106 | total_loss = loss_det + 5 * loss_seg
107 |
108 | with torch.no_grad():
109 | train_f_score = f_score(outputs_seg, seg_labels)
110 |
111 | # ----------------------#
112 | # back-propagation
113 | # ----------------------#
114 | scaler.scale(total_loss).backward()
115 | scaler.step(optimizer)
116 | scaler.update()
117 | if ema:
118 | ema.update(model_train)
119 |
120 | total_loss_det += loss_det.item()
121 | total_loss_seg += loss_seg.item()
122 | total_loss = total_loss_det + total_loss_seg  # running total used only for logging
123 | total_f_score += train_f_score.item()
124 |
125 | if local_rank >= 0:
126 | pbar.set_postfix(**{'detection loss': total_loss_det / (iteration + 1),
127 | 'segmentation loss': total_loss_seg / (iteration + 1),
128 | 'total loss': total_loss / (iteration + 1),
129 | 'f score': total_f_score / (iteration + 1),
130 | 'lr': get_lr(optimizer)})
131 | pbar.update(1)
132 |
133 | if local_rank >= 0:
134 | pbar.close()
135 | print('Finish Train')
136 | print('Start Validation')
137 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
138 |
139 | if ema:
140 | model_train_eval = ema.ema
141 | else:
142 | model_train_eval = model_train.eval()
143 |
144 | for iteration, batch in enumerate(gen_val):
145 | if iteration >= epoch_step_val:
146 | break
147 | images, targets, radars, pngs, seg_labels = batch[0], batch[1], batch[2], batch[3], batch[4]
148 | with torch.no_grad():
149 | if cuda:
150 | images = images.cuda(local_rank)
151 | targets = [ann.cuda(local_rank) for ann in targets]
152 | radars = radars.cuda(local_rank)
153 | pngs = pngs.cuda(local_rank)
154 | seg_labels = seg_labels.cuda(local_rank)
155 | weights = weights.cuda(local_rank)
156 | # ----------------------#
157 | # Zero the gradients
158 | # ----------------------#
159 | optimizer.zero_grad()
160 | # ----------------------#
161 | # Forward pass
162 | # ----------------------#
163 | outputs, outputs_seg = model_train(images, radars)
164 |
165 | if focal_loss:
166 | loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
167 | else:
168 | loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
169 |
170 | if dice_loss:
171 | main_dice = Dice_loss(outputs_seg, seg_labels)
172 | loss_seg = loss_seg + main_dice
173 |
174 | # -------------------------------#
175 | # Compute the f_score
176 | # -------------------------------#
177 | _f_score = f_score(outputs_seg, seg_labels)
178 |
179 | # ----------------------#
180 | # Compute the losses
181 | # ----------------------#
182 | loss_value = yolo_loss(outputs, targets)
183 | loss_value_seg = loss_seg
184 | val_f_score += _f_score.item()
185 |
186 | val_loss_det += loss_value.item()
187 | val_loss_seg += loss_value_seg.item()
188 | val_total_loss = val_loss_det + val_loss_seg
189 |
190 | if local_rank >= 0:
191 | pbar.set_postfix(**{'detection val_loss': val_loss_det / (iteration + 1),
192 | 'segmentation val_loss': val_loss_seg / (iteration + 1),
193 | 'val loss': val_total_loss / (iteration + 1),
194 | 'f_score': val_f_score / (iteration + 1),
195 | })
196 | pbar.update(1)
197 |
198 | if local_rank >= 0:
199 | pbar.close()
200 | print('Finish Validation')
201 | loss_history.append_loss(epoch + 1, total_loss_det / epoch_step, val_loss_det / epoch_step_val)
202 | loss_history_seg.append_loss(epoch + 1, total_loss_seg / epoch_step, val_loss_seg / epoch_step_val)
203 | eval_callback.on_epoch_end(epoch + 1, model_train_eval)
204 | eval_callback_seg.on_epoch_end(epoch + 1, model_train_eval)
205 | print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
206 | print('Total Loss: %.3f || Val Loss Det: %.3f || Val Loss Seg: %.3f' % ((total_loss / epoch_step,
207 | val_loss_det / epoch_step_val,
208 | val_loss_seg / epoch_step_val)))
209 |
210 | # -----------------------------------------------#
211 | # Save the weights
212 | # -----------------------------------------------#
213 | if ema:
214 | save_state_dict = ema.ema.state_dict()
215 | else:
216 | save_state_dict = model.state_dict()
217 |
218 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
219 | torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-det_val_loss%.3f-seg_val_loss%.3f.pth" % (
220 | epoch + 1, val_total_loss / epoch_step, val_loss_det / epoch_step_val, val_loss_seg / epoch_step_val)))
221 |
222 | if len(loss_history.val_loss) <= 1 or (val_total_loss / epoch_step_val) <= min(loss_history.val_loss) + min(loss_history_seg.val_loss):
223 | print('Save best model to best_epoch_weights.pth')
224 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))
225 |
226 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth"))
--------------------------------------------------------------------------------
/utils_seg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__init__.py
--------------------------------------------------------------------------------
/utils_seg/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/utils_seg/__pycache__/callbacks.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/callbacks.cpython-39.pyc
--------------------------------------------------------------------------------
/utils_seg/__pycache__/dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/utils_seg/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/utils_seg/__pycache__/utils_fit.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/utils_fit.cpython-39.pyc
--------------------------------------------------------------------------------
/utils_seg/__pycache__/utils_metrics.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/utils_seg/__pycache__/utils_metrics.cpython-39.pyc
--------------------------------------------------------------------------------
/utils_seg/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import matplotlib
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | matplotlib.use('Agg')
8 | from matplotlib import pyplot as plt
9 | import scipy.signal
10 |
11 | import cv2
12 | import shutil
13 | import numpy as np
14 |
15 | from PIL import Image
16 | from tqdm import tqdm
17 | from torch.utils.tensorboard import SummaryWriter
18 | from .utils import cvtColor, preprocess_input, resize_image
19 | from .utils_metrics import compute_mIoU
20 |
21 |
22 | class LossHistory():
23 | def __init__(self, log_dir, model, input_shape):
24 | self.log_dir = log_dir
25 | self.losses = []
26 | self.val_loss = []
27 |
28 | os.makedirs(self.log_dir)
29 | self.writer = SummaryWriter(self.log_dir)
30 | try:
31 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
32 | self.writer.add_graph(model, dummy_input)
33 | except:
34 | pass
35 |
36 | def append_loss(self, epoch, loss, val_loss):
37 | if not os.path.exists(self.log_dir):
38 | os.makedirs(self.log_dir)
39 |
40 | self.losses.append(loss)
41 | self.val_loss.append(val_loss)
42 |
43 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
44 | f.write(str(loss))
45 | f.write("\n")
46 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
47 | f.write(str(val_loss))
48 | f.write("\n")
49 |
50 | self.writer.add_scalar('loss', loss, epoch)
51 | self.writer.add_scalar('val_loss', val_loss, epoch)
52 | self.loss_plot()
53 |
54 | def loss_plot(self):
55 | iters = range(len(self.losses))
56 |
57 | plt.figure()
58 | plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
59 | plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')
60 | try:
61 | if len(self.losses) < 25:
62 | num = 5
63 | else:
64 | num = 15
65 |
66 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle='--', linewidth=2,
67 | label='smooth train loss')
68 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle='--', linewidth=2,
69 | label='smooth val loss')
70 | except:
71 | pass
72 |
73 | plt.grid(True)
74 | plt.xlabel('Epoch')
75 | plt.ylabel('Loss')
76 | plt.legend(loc="upper right")
77 |
78 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
79 |
80 | plt.cla()
81 | plt.close("all")
82 |
83 |
84 | class EvalCallback():
85 | def __init__(self, net, input_shape, num_classes, image_ids, dataset_path, log_dir, cuda, local_rank, radar_path, \
86 | miou_out_path=".temp_miou_out", eval_flag=True, period=1):
87 | super(EvalCallback, self).__init__()
88 |
89 | self.net = net
90 | self.input_shape = input_shape
91 | self.num_classes = num_classes
92 | self.image_ids = image_ids
93 | self.dataset_path = dataset_path
94 | self.log_dir = log_dir
95 | self.cuda = cuda
96 | self.local_rank = local_rank
97 | self.miou_out_path = miou_out_path
98 | self.eval_flag = eval_flag
99 | self.period = period
100 | self.radar_path = radar_path
101 |
102 | pattern_string = r"\d{10}\.\d{5}"
103 | pattern = re.compile(pattern_string)  # match the timestamp-style numeric file name
104 | self.image_ids = [pattern.findall(image_id)[-1] for image_id in image_ids]
105 |
106 | self.mious = [0]
107 | self.epoches = [0]
108 | if self.eval_flag:
109 | with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
110 | f.write(str(0))
111 | f.write("\n")
112 |
113 | def get_miou_png(self, image, radar_data):
114 | # ---------------------------------------------------------#
115 | # Convert the image to RGB here to avoid errors when predicting on grayscale images.
116 | # The code only supports prediction on RGB images; every other image type is converted to RGB.
117 | # ---------------------------------------------------------#
118 | image = cvtColor(image)
119 | orininal_h = np.array(image).shape[0]
120 | orininal_w = np.array(image).shape[1]
121 | # ---------------------------------------------------------#
122 | # Add gray bars (letterbox) to the image for a distortion-free resize.
123 | # A plain resize could also be used for recognition.
124 | # ---------------------------------------------------------#
125 | image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))
126 | # ---------------------------------------------------------#
127 | # Add the batch_size dimension
128 | # ---------------------------------------------------------#
129 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)
130 |
131 | with torch.no_grad():
132 | images = torch.from_numpy(image_data)
133 | if self.cuda:
134 | images = images.cuda(self.local_rank)
135 | radar_data = radar_data.cuda(self.local_rank)
136 |
137 | # ---------------------------------------------------#
138 | # Feed the image into the network for prediction
139 | # ---------------------------------------------------#
140 | pr = self.net(images, radar_data)[1][0]
141 | # ---------------------------------------------------#
142 | # Get the per-pixel class probabilities
143 | # ---------------------------------------------------#
144 | pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()
145 | # --------------------------------------#
146 | # Crop out the gray-bar (letterbox) padding
147 | # --------------------------------------#
148 | pr = pr[int((self.input_shape[0] - nh) // 2): int((self.input_shape[0] - nh) // 2 + nh), \
149 | int((self.input_shape[1] - nw) // 2): int((self.input_shape[1] - nw) // 2 + nw)]
150 | # ---------------------------------------------------#
151 | # Resize back to the original image size
152 | # ---------------------------------------------------#
153 | pr = cv2.resize(pr, (orininal_w, orininal_h), interpolation=cv2.INTER_LINEAR)
154 | # ---------------------------------------------------#
155 | # 取出每一个像素点的种类
156 | # ---------------------------------------------------#
157 | pr = pr.argmax(axis=-1)
158 |
159 | image = Image.fromarray(np.uint8(pr))
160 | return image
161 |
162 | def on_epoch_end(self, epoch, model_eval):
163 | if epoch % self.period == 0 and self.eval_flag:
164 | self.net = model_eval
165 | gt_dir = os.path.join(self.dataset_path, "VOC2007/SegmentationClass/")
166 | pred_dir = os.path.join(self.miou_out_path, 'detection-results')
167 | if not os.path.exists(self.miou_out_path):
168 | os.makedirs(self.miou_out_path)
169 | if not os.path.exists(pred_dir):
170 | os.makedirs(pred_dir)
171 | print("Get miou.")
172 | for image_id in tqdm(self.image_ids):
173 | # ------------------------------#
174 | # Read the radar feature map
175 | # ------------------------------#
176 | radar_path = os.path.join(self.radar_path, image_id + '.npz')
177 | radar_data = np.load(radar_path)['arr_0']
178 | radar_data = torch.from_numpy(radar_data).type(torch.FloatTensor).unsqueeze(0).cuda(self.local_rank)
179 |
180 | # -------------------------------#
181 | # Read the image from file
182 | # -------------------------------#
183 | image_path = os.path.join(self.dataset_path, "VOC2007/JPEGImages/" + image_id + ".jpg")
184 | image = Image.open(image_path)
185 | # ------------------------------#
186 | # Generate the prediction png
187 | # ------------------------------#
188 | image = self.get_miou_png(image, radar_data)
189 | image.save(os.path.join(pred_dir, image_id + ".png"))
190 |
191 | print("Calculate miou.")
192 | _, IoUs, _, _ = compute_mIoU(gt_dir, pred_dir, self.image_ids, self.num_classes, None)  # run the mIoU computation
193 | temp_miou = np.nanmean(IoUs) * 100
194 |
195 | self.mious.append(temp_miou)
196 | self.epoches.append(epoch)
197 |
198 | with open(os.path.join(self.log_dir, "epoch_miou.txt"), 'a') as f:
199 | f.write(str(temp_miou))
200 | f.write("\n")
201 |
202 | plt.figure()
203 | plt.plot(self.epoches, self.mious, 'red', linewidth=2, label='train miou')
204 |
205 | plt.grid(True)
206 | plt.xlabel('Epoch')
207 | plt.ylabel('Miou')
208 | plt.title('A Miou Curve')
209 | plt.legend(loc="upper right")
210 |
211 | plt.savefig(os.path.join(self.log_dir, "epoch_miou.png"))
212 | plt.cla()
213 | plt.close("all")
214 |
215 | print("Get miou done.")
216 | shutil.rmtree(self.miou_out_path)
217 |
--------------------------------------------------------------------------------
/utils_seg/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data.dataset import Dataset
8 |
9 | from utils_seg.utils import cvtColor, preprocess_input
10 |
11 |
12 | class DeeplabDataset(Dataset):
13 | def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
14 | super(DeeplabDataset, self).__init__()
15 | self.annotation_lines = annotation_lines
16 | self.length = len(annotation_lines)
17 | self.input_shape = input_shape
18 | self.num_classes = num_classes
19 | self.train = train
20 | self.dataset_path = dataset_path
21 |
22 | def __len__(self):
23 | return self.length
24 |
25 | def __getitem__(self, index):
26 | annotation_line = self.annotation_lines[index]
27 | name = annotation_line.split()[0]
28 |
29 | #-------------------------------#
30 | # Read the image and label from file
31 | #-------------------------------#
32 | jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg"))
33 | png = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png"))
34 | #-------------------------------#
35 | # Data augmentation
36 | #-------------------------------#
37 | jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)
38 |
39 | jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
40 | png = np.array(png)
41 | png[png >= self.num_classes] = self.num_classes
42 | #-------------------------------------------------------#
43 | # Convert to one-hot form
44 | # +1 is needed because some VOC labels have a white border,
45 | # which should be ignored; the extra class makes ignoring it easy.
46 | #-------------------------------------------------------#
47 | seg_labels = np.eye(self.num_classes+1)[png.reshape([-1])]
48 | seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
49 |
50 | return jpg, png, seg_labels
51 |
52 | def rand(self, a=0, b=1):
53 | return np.random.rand() * (b - a) + a
54 |
55 | def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
56 | image = cvtColor(image)
57 | label = Image.fromarray(np.array(label))
58 | #------------------------------#
59 | # Get the image size and the target size
60 | #------------------------------#
61 | iw, ih = image.size
62 | h, w = input_shape
63 |
64 | if not random:
65 | iw, ih = image.size
66 | scale = min(w/iw, h/ih)
67 | nw = int(iw*scale)
68 | nh = int(ih*scale)
69 |
70 | image = image.resize((nw,nh), Image.BICUBIC)
71 | new_image = Image.new('RGB', [w, h], (128,128,128))
72 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
73 |
74 | label = label.resize((nw,nh), Image.NEAREST)
75 | new_label = Image.new('L', [w, h], (0))
76 | new_label.paste(label, ((w-nw)//2, (h-nh)//2))
77 | return new_image, new_label
78 |
79 | #------------------------------------------#
80 | # Scale the image and jitter its width/height (aspect ratio)
81 | #------------------------------------------#
82 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
83 | scale = self.rand(0.25, 2)
84 | if new_ar < 1:
85 | nh = int(scale*h)
86 | nw = int(nh*new_ar)
87 | else:
88 | nw = int(scale*w)
89 | nh = int(nw/new_ar)
90 | image = image.resize((nw,nh), Image.BICUBIC)
91 | label = label.resize((nw,nh), Image.NEAREST)
92 |
93 | #------------------------------------------#
94 | # Flip the image
95 | #------------------------------------------#
96 | flip = self.rand()<.5
97 | if flip:
98 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
99 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
100 |
101 | #------------------------------------------#
102 | # Pad the remaining area with gray bars
103 | #------------------------------------------#
104 | dx = int(self.rand(0, w-nw))
105 | dy = int(self.rand(0, h-nh))
106 | new_image = Image.new('RGB', (w,h), (128,128,128))
107 | new_label = Image.new('L', (w,h), (0))
108 | new_image.paste(image, (dx, dy))
109 | new_label.paste(label, (dx, dy))
110 | image = new_image
111 | label = new_label
112 |
113 | image_data = np.array(image, np.uint8)
114 |
115 | #------------------------------------------#
116 | # Gaussian blur
117 | #------------------------------------------#
118 | blur = self.rand() < 0.25
119 | if blur:
120 | image_data = cv2.GaussianBlur(image_data, (5, 5), 0)
121 |
122 | #------------------------------------------#
123 | # Rotation
124 | #------------------------------------------#
125 | rotate = self.rand() < 0.25
126 | if rotate:
127 | center = (w // 2, h // 2)
128 | rotation = np.random.randint(-10, 11)
129 | M = cv2.getRotationMatrix2D(center, -rotation, scale=1)
130 | image_data = cv2.warpAffine(image_data, M, (w, h), flags=cv2.INTER_CUBIC, borderValue=(128,128,128))
131 | label = cv2.warpAffine(np.array(label, np.uint8), M, (w, h), flags=cv2.INTER_NEAREST, borderValue=(0))
132 |
133 | #---------------------------------#
134 | # Apply color-space (HSV) augmentation to the image
135 | # Compute the augmentation gains
136 | #---------------------------------#
137 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
138 | #---------------------------------#
139 | # Convert the image to HSV
140 | #---------------------------------#
141 | hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
142 | dtype = image_data.dtype
143 | #---------------------------------#
144 | # Apply the transform via lookup tables
145 | #---------------------------------#
146 | x = np.arange(0, 256, dtype=r.dtype)
147 | lut_hue = ((x * r[0]) % 180).astype(dtype)
148 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
149 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
150 |
151 | image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
152 | image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
153 |
154 | return image_data, label
155 |
156 |
157 | # Used as the collate_fn of the DataLoader
158 | def deeplab_dataset_collate(batch):
159 | images = []
160 | pngs = []
161 | seg_labels = []
162 | for img, png, labels in batch:
163 | images.append(img)
164 | pngs.append(png)
165 | seg_labels.append(labels)
166 | images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
167 | pngs = torch.from_numpy(np.array(pngs)).long()
168 | seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
169 | return images, pngs, seg_labels
170 |
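171 | # ------------------------------------------------------------------------#
172 | # One-hot sketch (illustrative addition): np.eye(num_classes + 1)[png] turns
173 | # the label map into per-pixel one-hot vectors; the extra channel holds the
174 | # border pixels that were clamped to num_classes in __getitem__ above.
175 | # ------------------------------------------------------------------------#
176 | if __name__ == "__main__":
177 |     num_classes = 2
178 |     png = np.array([[0, 1], [2, 1]])   # 2 marks ignored border pixels
179 |     seg_labels = np.eye(num_classes + 1)[png.reshape([-1])]
180 |     print(seg_labels.reshape((2, 2, num_classes + 1)))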
--------------------------------------------------------------------------------
/utils_seg/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 |
5 | #---------------------------------------------------------#
6 | # Convert the image to RGB to avoid errors when predicting on grayscale images.
7 | # The code only supports prediction on RGB images; every other image type is converted to RGB.
8 | #---------------------------------------------------------#
9 | def cvtColor(image):
10 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
11 | return image
12 | else:
13 | image = image.convert('RGB')
14 | return image
15 |
16 | #---------------------------------------------------#
17 | # Resize the input image
18 | #---------------------------------------------------#
19 | def resize_image(image, size):
20 | iw, ih = image.size
21 | w, h = size
22 |
23 | scale = min(w/iw, h/ih)
24 | nw = int(iw*scale)
25 | nh = int(ih*scale)
26 |
27 | image = image.resize((nw,nh), Image.BICUBIC)
28 | new_image = Image.new('RGB', size, (128,128,128))
29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
30 |
31 | return new_image, nw, nh
32 |
33 | #---------------------------------------------------#
34 | # Get the current learning rate
35 | #---------------------------------------------------#
36 | def get_lr(optimizer):
37 | for param_group in optimizer.param_groups:
38 | return param_group['lr']
39 |
40 | def preprocess_input(image):
41 | image /= 255.0
42 | image -= np.array([0.485, 0.456, 0.406])
43 | image /= np.array([0.229, 0.224, 0.225])
44 | return image
45 |
46 | def show_config(**kwargs):
47 | print('Configurations:')
48 | print('-' * 70)
49 | print('|%25s | %40s|' % ('keys', 'values'))
50 | print('-' * 70)
51 | for key, value in kwargs.items():
52 | print('|%25s | %40s|' % (str(key), str(value)))
53 | print('-' * 70)
54 |
55 | def download_weights(backbone, model_dir="./model_data"):
56 | import os
57 | from torch.hub import load_state_dict_from_url
58 |
59 | download_urls = {
60 | 'mobilenet' : 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/mobilenet_v2.pth.tar',
61 | 'xception' : 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/xception_pytorch_imagenet.pth',
62 | }
63 | url = download_urls[backbone]
64 |
65 | if not os.path.exists(model_dir):
66 | os.makedirs(model_dir)
67 | load_state_dict_from_url(url, model_dir)
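68 |
69 | # Usage sketch (illustrative addition): fetch the DeepLab backbone weights
70 | # listed above into ./model_data; 'mobilenet' is one of the two keys defined
71 | # in download_urls.
72 | if __name__ == "__main__":
73 |     download_weights('mobilenet')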
--------------------------------------------------------------------------------
/utils_seg/utils_fit.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from nets.deeplabv3_training import (CE_Loss, Dice_loss, Focal_Loss,
5 | weights_init)
6 | from tqdm import tqdm
7 |
8 | from utils_seg.utils import get_lr
9 | from utils_seg.utils_metrics import f_score
10 |
11 |
12 |
13 | def fit_one_epoch(model_train, model, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, dice_loss, focal_loss, cls_weights, num_classes, \
14 | fp16, scaler, save_period, save_dir, local_rank=0):
15 | total_loss = 0
16 | total_f_score = 0
17 |
18 | val_loss = 0
19 | val_f_score = 0
20 |
21 | if local_rank == 0:
22 | print('Start Train')
23 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
24 | model_train.train()
25 | for iteration, batch in enumerate(gen):
26 | if iteration >= epoch_step:
27 | break
28 | imgs, targets, radars, pngs, labels = batch
29 |
30 | with torch.no_grad():
31 | weights = torch.from_numpy(cls_weights)
32 | if cuda:
33 | imgs = imgs.cuda(local_rank)
34 | targets = [ann.cuda(local_rank) for ann in targets]
35 | radars = radars.cuda(local_rank)
36 | pngs = pngs.cuda(local_rank)
37 | labels = labels.cuda(local_rank)
38 | weights = weights.cuda(local_rank)
39 |
40 | optimizer.zero_grad()
41 | if not fp16:
42 | _, outputs = model_train(imgs, radars)
43 | if focal_loss:
44 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
45 | else:
46 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
47 |
48 | if dice_loss:
49 | main_dice = Dice_loss(outputs, labels)
50 | loss = loss + main_dice
51 |
52 | with torch.no_grad():
53 | #-------------------------------#
54 | # Compute the f_score
55 | #-------------------------------#
56 | _f_score = f_score(outputs, labels)
57 |
58 | loss.backward()
59 | optimizer.step()
60 | else:
61 | from torch.cuda.amp import autocast
62 | with autocast():
63 | _, outputs = model_train(imgs, radars)
64 | if focal_loss:
65 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
66 | else:
67 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
68 |
69 | if dice_loss:
70 | main_dice = Dice_loss(outputs, labels)
71 | loss = loss + main_dice
72 |
73 | with torch.no_grad():
74 | #-------------------------------#
75 | # Compute the f_score
76 | #-------------------------------#
77 | _f_score = f_score(outputs, labels)
78 |
79 | #----------------------#
80 | # Backward pass
81 | #----------------------#
82 | scaler.scale(loss).backward()
83 | scaler.step(optimizer)
84 | scaler.update()
85 |
86 | total_loss += loss.item()
87 | total_f_score += _f_score.item()
88 |
89 | if local_rank == 0:
90 | pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1),
91 | 'f_score' : total_f_score / (iteration + 1),
92 | 'lr' : get_lr(optimizer)})
93 | pbar.update(1)
94 |
95 | if local_rank == 0:
96 | pbar.close()
97 | print('Finish Train')
98 | print('Start Validation')
99 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
100 |
101 | model_train.eval()
102 | for iteration, batch in enumerate(gen_val):
103 | if iteration >= epoch_step_val:
104 | break
105 | imgs, targets, radars, pngs, labels = batch
106 | with torch.no_grad():
107 | weights = torch.from_numpy(cls_weights)
108 | if cuda:
109 | imgs = imgs.cuda(local_rank)
110 | targets = [ann.cuda(local_rank) for ann in targets]
111 | radars = radars.cuda(local_rank)
112 | pngs = pngs.cuda(local_rank)
113 | labels = labels.cuda(local_rank)
114 | weights = weights.cuda(local_rank)
115 |
116 | _, outputs = model_train(imgs, radars)
117 | if focal_loss:
118 | loss = Focal_Loss(outputs, pngs, weights, num_classes = num_classes)
119 | else:
120 | loss = CE_Loss(outputs, pngs, weights, num_classes = num_classes)
121 |
122 | if dice_loss:
123 | main_dice = Dice_loss(outputs, labels)
124 | loss = loss + main_dice
125 | #-------------------------------#
126 | # Compute the f_score
127 | #-------------------------------#
128 | _f_score = f_score(outputs, labels)
129 |
130 | val_loss += loss.item()
131 | val_f_score += _f_score.item()
132 |
133 | if local_rank == 0:
134 | pbar.set_postfix(**{'val_loss' : val_loss / (iteration + 1),
135 | 'f_score' : val_f_score / (iteration + 1),
136 | 'lr' : get_lr(optimizer)})
137 | pbar.update(1)
138 |
139 | if local_rank == 0:
140 | pbar.close()
141 | print('Finish Validation')
142 | loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
143 | eval_callback.on_epoch_end(epoch + 1, model_train)
144 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
145 | print('Total Loss: %.3f || Val Loss: %.3f ' % (total_loss / epoch_step, val_loss / epoch_step_val))
146 |
147 | #-----------------------------------------------#
148 | # Save the weights
149 | #-----------------------------------------------#
150 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
151 | torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth' % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)))
152 |
153 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
154 | print('Save best model to best_epoch_weights.pth')
155 | torch.save(model.state_dict(), os.path.join(save_dir, "best_epoch_weights.pth"))
156 |
157 | torch.save(model.state_dict(), os.path.join(save_dir, "last_epoch_weights.pth"))
--------------------------------------------------------------------------------
/utils_seg/utils_metrics.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | from os.path import join
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from PIL import Image
10 |
11 |
12 | def f_score(inputs, target, beta=1, smooth=1e-5, threhold=0.5):
13 | n, c, h, w = inputs.size()
14 | nt, ht, wt, ct = target.size()
15 | if h != ht and w != wt:
16 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
17 |
18 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c), -1)
19 | temp_target = target.view(n, -1, ct)
20 |
21 | # --------------------------------------------#
22 | # Compute the Dice coefficient
23 | # --------------------------------------------#
24 | temp_inputs = torch.gt(temp_inputs, threhold).float()
25 | tp = torch.sum(temp_target[..., :-1] * temp_inputs, axis=[0, 1])
26 | fp = torch.sum(temp_inputs, axis=[0, 1]) - tp
27 | fn = torch.sum(temp_target[..., :-1], axis=[0, 1]) - tp
28 |
29 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
30 | score = torch.mean(score)
31 | return score
32 |
33 |
34 | # Let the label have width W and height H
35 | def fast_hist(a, b, n):
36 | # --------------------------------------------------------------------------------#
37 |     #   a: the label flattened to a 1-D array, shape (H×W,); b: the prediction flattened to a 1-D array, shape (H×W,)
38 | # --------------------------------------------------------------------------------#
39 | k = (a >= 0) & (a < n)
40 | # --------------------------------------------------------------------------------#
41 |     #   np.bincount counts how many times each value from 0 to n**2-1 occurs; reshaped, the result has shape (n, n)
42 |     #   The diagonal of the returned matrix holds the correctly classified pixels
43 | # --------------------------------------------------------------------------------#
44 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
45 |
46 |
47 | def per_class_iu(hist):
48 | return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1)
49 |
50 |
51 | def per_class_PA_Recall(hist):
52 | return np.diag(hist) / np.maximum(hist.sum(1), 1)
53 |
54 |
55 | def per_class_Precision(hist):
56 | return np.diag(hist) / np.maximum(hist.sum(0), 1)
57 |
58 |
59 | def per_Accuracy(hist):
60 | return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1)
61 |
62 |
63 | def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes=None):
64 | print('Num classes', num_classes)
65 | # -----------------------------------------#
66 |     #   Create an all-zero matrix to use as the confusion matrix
67 | # -----------------------------------------#
68 | hist = np.zeros((num_classes, num_classes))
69 |
70 | # ------------------------------------------------#
71 |     #   Build the list of validation-set label paths for direct reading
72 |     #   Build the list of validation-set segmentation-result paths for direct reading
73 | # ------------------------------------------------#
74 | gt_imgs = [join(gt_dir, x + ".png") for x in png_name_list]
75 | pred_imgs = [join(pred_dir, x + ".png") for x in png_name_list]
76 |
77 | # ------------------------------------------------#
78 |     #   Read each (image, label) pair
79 | # ------------------------------------------------#
80 | for ind in range(len(gt_imgs)):
81 | # ------------------------------------------------#
82 |         #   Read one segmentation result and convert it to a numpy array
83 | # ------------------------------------------------#
84 |
85 | pred = np.array(Image.open(pred_imgs[ind]))
86 | # ------------------------------------------------#
87 |         #   Read the corresponding label and convert it to a numpy array
88 | # ------------------------------------------------#
89 | label = np.array(Image.open(gt_imgs[ind]))
90 |
91 |         # If the segmentation result and the label differ in size, skip this image
92 | if len(label.flatten()) != len(pred.flatten()):
93 | print(
94 | 'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
95 | len(label.flatten()), len(pred.flatten()), gt_imgs[ind],
96 | pred_imgs[ind]))
97 | continue
98 |
99 | # ------------------------------------------------#
100 |         #   Compute the num_classes × num_classes hist matrix for this image and accumulate it
101 | # ------------------------------------------------#
102 | hist += fast_hist(label.flatten(), pred.flatten(), num_classes)
103 |         # Every 10 images, print the mean IoU over all classes for the images processed so far
104 | if name_classes is not None and ind > 0 and ind % 10 == 0:
105 | print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
106 | ind,
107 | len(gt_imgs),
108 | 100 * np.nanmean(per_class_iu(hist)),
109 | 100 * np.nanmean(per_class_PA_Recall(hist)),
110 | 100 * per_Accuracy(hist)
111 | )
112 | )
113 | # ------------------------------------------------#
114 |     #   Compute the per-class IoU over all validation-set images
115 | # ------------------------------------------------#
116 | IoUs = per_class_iu(hist)
117 | PA_Recall = per_class_PA_Recall(hist)
118 | Precision = per_class_Precision(hist)
119 | # ------------------------------------------------#
120 |     #   Print the IoU value for each class
121 | # ------------------------------------------------#
122 | if name_classes is not None:
123 | for ind_class in range(num_classes):
124 | print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \
125 | + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2)) + '; Precision-' + str(
126 | round(Precision[ind_class] * 100, 2)))
127 |
128 | # -----------------------------------------------------------------#
129 |     #   Average the per-class IoU over all classes and validation images, ignoring NaN values
130 | # -----------------------------------------------------------------#
131 | print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(
132 | round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))
133 |     return np.array(hist, int), IoUs, PA_Recall, Precision
134 |
135 |
136 | def adjust_axes(r, t, fig, axes):
137 | bb = t.get_window_extent(renderer=r)
138 | text_width_inches = bb.width / fig.dpi
139 | current_fig_width = fig.get_figwidth()
140 | new_fig_width = current_fig_width + text_width_inches
141 |     proportion = new_fig_width / current_fig_width
142 |     x_lim = axes.get_xlim()
143 |     axes.set_xlim([x_lim[0], x_lim[1] * proportion])
144 |
145 |
146 | def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size=12, plt_show=True):
147 | fig = plt.gcf()
148 | axes = plt.gca()
149 | plt.barh(range(len(values)), values, color='royalblue')
150 | plt.title(plot_title, fontsize=tick_font_size + 2)
151 | plt.xlabel(x_label, fontsize=tick_font_size)
152 | plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size)
153 | r = fig.canvas.get_renderer()
154 | for i, val in enumerate(values):
155 | str_val = " " + str(val)
156 | if val < 1.0:
157 | str_val = " {0:.2f}".format(val)
158 | t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold')
159 | if i == (len(values) - 1):
160 | adjust_axes(r, t, fig, axes)
161 |
162 | fig.tight_layout()
163 | fig.savefig(output_path)
164 | if plt_show:
165 | plt.show()
166 | plt.close()
167 |
168 |
169 | def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size=12):
170 | draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs) * 100), "Intersection over Union", \
171 | os.path.join(miou_out_path, "mIoU.png"), tick_font_size=tick_font_size, plt_show=True)
172 | print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png"))
173 |
174 | draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall) * 100), "Pixel Accuracy", \
175 | os.path.join(miou_out_path, "mPA.png"), tick_font_size=tick_font_size, plt_show=False)
176 | print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png"))
177 |
178 | draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall) * 100), "Recall", \
179 | os.path.join(miou_out_path, "Recall.png"), tick_font_size=tick_font_size, plt_show=False)
180 | print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png"))
181 |
182 | draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision) * 100), "Precision", \
183 | os.path.join(miou_out_path, "Precision.png"), tick_font_size=tick_font_size, plt_show=False)
184 | print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png"))
185 |
186 | with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f:
187 | writer = csv.writer(f)
188 | writer_list = []
189 | writer_list.append([' '] + [str(c) for c in name_classes])
190 | for i in range(len(hist)):
191 | writer_list.append([name_classes[i]] + [str(x) for x in hist[i]])
192 | writer.writerows(writer_list)
193 | print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv"))
194 |
--------------------------------------------------------------------------------
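
Note: a quick hand-checkable example of the helpers above. fast_hist builds an n×n confusion matrix whose entry (i, j) counts pixels with ground-truth class i predicted as class j; the per-class metrics are then read off its rows, columns, and diagonal. The toy values below (4 pixels, 3 classes) are worked out by hand; the only assumption is that the script is run from the repository root so that utils_seg is importable.

import numpy as np
from utils_seg.utils_metrics import fast_hist, per_class_iu, per_class_PA_Recall, per_class_Precision

# 4 pixels, 3 classes: ground truth vs. prediction (already flattened).
label = np.array([0, 1, 2, 1])
pred  = np.array([0, 2, 2, 1])

hist = fast_hist(label, pred, n=3)
# hist[i, j] counts pixels whose label is i and whose prediction is j:
# [[1, 0, 0],
#  [0, 1, 1],
#  [0, 0, 1]]

print(per_class_iu(hist))         # IoU per class:       [1.0, 0.5, 0.5]
print(per_class_PA_Recall(hist))  # recall per class:    [1.0, 0.5, 1.0]
print(per_class_Precision(hist))  # precision per class: [1.0, 1.0, 0.5]
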
/venv/Scripts/Activate.ps1:
--------------------------------------------------------------------------------
1 | <#
2 | .Synopsis
3 | Activate a Python virtual environment for the current PowerShell session.
4 |
5 | .Description
6 | Pushes the python executable for a virtual environment to the front of the
7 | $Env:PATH environment variable and sets the prompt to signify that you are
8 | in a Python virtual environment. Makes use of the command line switches as
9 | well as the `pyvenv.cfg` file values present in the virtual environment.
10 |
11 | .Parameter VenvDir
12 | Path to the directory that contains the virtual environment to activate. The
13 | default value for this is the parent of the directory that the Activate.ps1
14 | script is located within.
15 |
16 | .Parameter Prompt
17 | The prompt prefix to display when this virtual environment is activated. By
18 | default, this prompt is the name of the virtual environment folder (VenvDir)
19 | surrounded by parentheses and followed by a single space (ie. '(.venv) ').
20 |
21 | .Example
22 | Activate.ps1
23 | Activates the Python virtual environment that contains the Activate.ps1 script.
24 |
25 | .Example
26 | Activate.ps1 -Verbose
27 | Activates the Python virtual environment that contains the Activate.ps1 script,
28 | and shows extra information about the activation as it executes.
29 |
30 | .Example
31 | Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv
32 | Activates the Python virtual environment located in the specified location.
33 |
34 | .Example
35 | Activate.ps1 -Prompt "MyPython"
36 | Activates the Python virtual environment that contains the Activate.ps1 script,
37 | and prefixes the current prompt with the specified string (surrounded in
38 | parentheses) while the virtual environment is active.
39 |
40 | .Notes
41 | On Windows, it may be required to enable this Activate.ps1 script by setting the
42 | execution policy for the user. You can do this by issuing the following PowerShell
43 | command:
44 |
45 | PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
46 |
47 | For more information on Execution Policies:
48 | https://go.microsoft.com/fwlink/?LinkID=135170
49 |
50 | #>
51 | Param(
52 | [Parameter(Mandatory = $false)]
53 | [String]
54 | $VenvDir,
55 | [Parameter(Mandatory = $false)]
56 | [String]
57 | $Prompt
58 | )
59 |
60 | <# Function declarations --------------------------------------------------- #>
61 |
62 | <#
63 | .Synopsis
64 | Remove all shell session elements added by the Activate script, including the
65 | addition of the virtual environment's Python executable from the beginning of
66 | the PATH variable.
67 |
68 | .Parameter NonDestructive
69 | If present, do not remove this function from the global namespace for the
70 | session.
71 |
72 | #>
73 | function global:deactivate ([switch]$NonDestructive) {
74 | # Revert to original values
75 |
76 | # The prior prompt:
77 | if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
78 | Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
79 | Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
80 | }
81 |
82 | # The prior PYTHONHOME:
83 | if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
84 | Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
85 | Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
86 | }
87 |
88 | # The prior PATH:
89 | if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
90 | Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
91 | Remove-Item -Path Env:_OLD_VIRTUAL_PATH
92 | }
93 |
94 | # Just remove the VIRTUAL_ENV altogether:
95 | if (Test-Path -Path Env:VIRTUAL_ENV) {
96 | Remove-Item -Path env:VIRTUAL_ENV
97 | }
98 |
99 | # Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
100 | if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
101 | Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
102 | }
103 |
104 | # Leave deactivate function in the global namespace if requested:
105 | if (-not $NonDestructive) {
106 | Remove-Item -Path function:deactivate
107 | }
108 | }
109 |
110 | <#
111 | .Description
112 | Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the
113 | given folder, and returns them in a map.
114 |
115 | For each line in the pyvenv.cfg file, if that line can be parsed into exactly
116 | two strings separated by `=` (with any amount of whitespace surrounding the =)
117 | then it is considered a `key = value` line. The left hand string is the key,
118 | the right hand is the value.
119 |
120 | If the value starts with a `'` or a `"` then the first and last character is
121 | stripped from the value before being captured.
122 |
123 | .Parameter ConfigDir
124 | Path to the directory that contains the `pyvenv.cfg` file.
125 | #>
126 | function Get-PyVenvConfig(
127 | [String]
128 | $ConfigDir
129 | ) {
130 | Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"
131 |
132 | # Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
133 | $pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue
134 |
135 | # An empty map will be returned if no config file is found.
136 | $pyvenvConfig = @{ }
137 |
138 | if ($pyvenvConfigPath) {
139 |
140 | Write-Verbose "File exists, parse `key = value` lines"
141 | $pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath
142 |
143 | $pyvenvConfigContent | ForEach-Object {
144 | $keyval = $PSItem -split "\s*=\s*", 2
145 | if ($keyval[0] -and $keyval[1]) {
146 | $val = $keyval[1]
147 |
148 | # Remove extraneous quotations around a string value.
149 | if ("'""".Contains($val.Substring(0, 1))) {
150 | $val = $val.Substring(1, $val.Length - 2)
151 | }
152 |
153 | $pyvenvConfig[$keyval[0]] = $val
154 | Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
155 | }
156 | }
157 | }
158 | return $pyvenvConfig
159 | }
160 |
161 |
162 | <# Begin Activate script --------------------------------------------------- #>
163 |
164 | # Determine the containing directory of this script
165 | $VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
166 | $VenvExecDir = Get-Item -Path $VenvExecPath
167 |
168 | Write-Verbose "Activation script is located in path: '$VenvExecPath'"
169 | Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)"
170 | Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)"
171 |
172 | # Set values required in priority: CmdLine, ConfigFile, Default
173 | # First, get the location of the virtual environment, it might not be
174 | # VenvExecDir if specified on the command line.
175 | if ($VenvDir) {
176 | Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values"
177 | }
178 | else {
179 | Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir."
180 | $VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/")
181 | Write-Verbose "VenvDir=$VenvDir"
182 | }
183 |
184 | # Next, read the `pyvenv.cfg` file to determine any required value such
185 | # as `prompt`.
186 | $pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir
187 |
188 | # Next, set the prompt from the command line, or the config file, or
189 | # just use the name of the virtual environment folder.
190 | if ($Prompt) {
191 | Write-Verbose "Prompt specified as argument, using '$Prompt'"
192 | }
193 | else {
194 | Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value"
195 | if ($pyvenvCfg -and $pyvenvCfg['prompt']) {
196 | Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'"
197 | $Prompt = $pyvenvCfg['prompt'];
198 | }
199 | else {
200 | Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)"
201 | Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'"
202 | $Prompt = Split-Path -Path $venvDir -Leaf
203 | }
204 | }
205 |
206 | Write-Verbose "Prompt = '$Prompt'"
207 | Write-Verbose "VenvDir='$VenvDir'"
208 |
209 | # Deactivate any currently active virtual environment, but leave the
210 | # deactivate function in place.
211 | deactivate -nondestructive
212 |
213 | # Now set the environment variable VIRTUAL_ENV, used by many tools to determine
214 | # that there is an activated venv.
215 | $env:VIRTUAL_ENV = $VenvDir
216 |
217 | if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) {
218 |
219 | Write-Verbose "Setting prompt to '$Prompt'"
220 |
221 | # Set the prompt to include the env name
222 | # Make sure _OLD_VIRTUAL_PROMPT is global
223 | function global:_OLD_VIRTUAL_PROMPT { "" }
224 | Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT
225 | New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt
226 |
227 | function global:prompt {
228 | Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) "
229 | _OLD_VIRTUAL_PROMPT
230 | }
231 | }
232 |
233 | # Clear PYTHONHOME
234 | if (Test-Path -Path Env:PYTHONHOME) {
235 | Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME
236 | Remove-Item -Path Env:PYTHONHOME
237 | }
238 |
239 | # Add the venv to the PATH
240 | Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH
241 | $Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH"
242 |
--------------------------------------------------------------------------------
/venv/Scripts/activate:
--------------------------------------------------------------------------------
1 | # This file must be used with "source bin/activate" *from bash*
2 | # you cannot run it directly
3 |
4 | deactivate () {
5 | # reset old environment variables
6 | if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then
7 | PATH="${_OLD_VIRTUAL_PATH:-}"
8 | export PATH
9 | unset _OLD_VIRTUAL_PATH
10 | fi
11 | if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then
12 | PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}"
13 | export PYTHONHOME
14 | unset _OLD_VIRTUAL_PYTHONHOME
15 | fi
16 |
17 | # This should detect bash and zsh, which have a hash command that must
18 | # be called to get it to forget past commands. Without forgetting
19 | # past commands the $PATH changes we made may not be respected
20 | if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
21 | hash -r 2> /dev/null
22 | fi
23 |
24 | if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then
25 | PS1="${_OLD_VIRTUAL_PS1:-}"
26 | export PS1
27 | unset _OLD_VIRTUAL_PS1
28 | fi
29 |
30 | unset VIRTUAL_ENV
31 | if [ ! "${1:-}" = "nondestructive" ] ; then
32 | # Self destruct!
33 | unset -f deactivate
34 | fi
35 | }
36 |
37 | # unset irrelevant variables
38 | deactivate nondestructive
39 |
40 | VIRTUAL_ENV="E:\Normal_Workspace_Collection\Efficient-VRNet-beta\Efficient-VRNet-beta\venv"
41 | export VIRTUAL_ENV
42 |
43 | _OLD_VIRTUAL_PATH="$PATH"
44 | PATH="$VIRTUAL_ENV/Scripts:$PATH"
45 | export PATH
46 |
47 | # unset PYTHONHOME if set
48 | # this will fail if PYTHONHOME is set to the empty string (which is bad anyway)
49 | # could use `if (set -u; : $PYTHONHOME) ;` in bash
50 | if [ -n "${PYTHONHOME:-}" ] ; then
51 | _OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}"
52 | unset PYTHONHOME
53 | fi
54 |
55 | if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then
56 | _OLD_VIRTUAL_PS1="${PS1:-}"
57 | PS1="(venv) ${PS1:-}"
58 | export PS1
59 | fi
60 |
61 | # This should detect bash and zsh, which have a hash command that must
62 | # be called to get it to forget past commands. Without forgetting
63 | # past commands the $PATH changes we made may not be respected
64 | if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
65 | hash -r 2> /dev/null
66 | fi
67 |
--------------------------------------------------------------------------------
/venv/Scripts/activate.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | rem This file is UTF-8 encoded, so we need to update the current code page while executing it
4 | for /f "tokens=2 delims=:." %%a in ('"%SystemRoot%\System32\chcp.com"') do (
5 | set _OLD_CODEPAGE=%%a
6 | )
7 | if defined _OLD_CODEPAGE (
8 | "%SystemRoot%\System32\chcp.com" 65001 > nul
9 | )
10 |
11 | set VIRTUAL_ENV=E:\Normal_Workspace_Collection\Efficient-VRNet-beta\Efficient-VRNet-beta\venv
12 |
13 | if not defined PROMPT set PROMPT=$P$G
14 |
15 | if defined _OLD_VIRTUAL_PROMPT set PROMPT=%_OLD_VIRTUAL_PROMPT%
16 | if defined _OLD_VIRTUAL_PYTHONHOME set PYTHONHOME=%_OLD_VIRTUAL_PYTHONHOME%
17 |
18 | set _OLD_VIRTUAL_PROMPT=%PROMPT%
19 | set PROMPT=(venv) %PROMPT%
20 |
21 | if defined PYTHONHOME set _OLD_VIRTUAL_PYTHONHOME=%PYTHONHOME%
22 | set PYTHONHOME=
23 |
24 | if defined _OLD_VIRTUAL_PATH set PATH=%_OLD_VIRTUAL_PATH%
25 | if not defined _OLD_VIRTUAL_PATH set _OLD_VIRTUAL_PATH=%PATH%
26 |
27 | set PATH=%VIRTUAL_ENV%\Scripts;%PATH%
28 |
29 | :END
30 | if defined _OLD_CODEPAGE (
31 | "%SystemRoot%\System32\chcp.com" %_OLD_CODEPAGE% > nul
32 | set _OLD_CODEPAGE=
33 | )
34 |
--------------------------------------------------------------------------------
/venv/Scripts/deactivate.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | if defined _OLD_VIRTUAL_PROMPT (
4 | set "PROMPT=%_OLD_VIRTUAL_PROMPT%"
5 | )
6 | set _OLD_VIRTUAL_PROMPT=
7 |
8 | if defined _OLD_VIRTUAL_PYTHONHOME (
9 | set "PYTHONHOME=%_OLD_VIRTUAL_PYTHONHOME%"
10 | set _OLD_VIRTUAL_PYTHONHOME=
11 | )
12 |
13 | if defined _OLD_VIRTUAL_PATH (
14 | set "PATH=%_OLD_VIRTUAL_PATH%"
15 | )
16 |
17 | set _OLD_VIRTUAL_PATH=
18 |
19 | set VIRTUAL_ENV=
20 |
21 | :END
22 |
--------------------------------------------------------------------------------
/venv/Scripts/python.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/venv/Scripts/python.exe
--------------------------------------------------------------------------------
/venv/Scripts/pythonw.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuanRunwei/ASY-VRNet/0e5d34a753be06cf81f4f11d2a5f6e3b27a286f5/venv/Scripts/pythonw.exe
--------------------------------------------------------------------------------
/venv/pyvenv.cfg:
--------------------------------------------------------------------------------
1 | home = D:\Anaconda\Software
2 | include-system-site-packages = false
3 | version = 3.9.12
4 |
--------------------------------------------------------------------------------
/voc_annotation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import xml.etree.ElementTree as ET
4 |
5 | from utils.utils import get_classes
6 |
7 | # --------------------------------------------------------------------------------------------------------------------------------#
8 | #   annotation_mode controls what this script computes when it is run
9 | #   annotation_mode = 0: run the whole label pipeline, generating the txt files in VOCdevkit/VOC2007/ImageSets as well as the 2007_train.txt and 2007_val.txt used for training
10 | #   annotation_mode = 1: only generate the txt files in VOCdevkit/VOC2007/ImageSets
11 | #   annotation_mode = 2: only generate the 2007_train.txt and 2007_val.txt used for training
12 | # --------------------------------------------------------------------------------------------------------------------------------#
13 | annotation_mode = 0
14 | # -------------------------------------------------------------------#
15 | #   Must be modified: it is used to generate the object information in 2007_train.txt and 2007_val.txt
16 | #   It only needs to match the classes_path used for training and prediction
17 | #   If the generated 2007_train.txt contains no object information,
18 | #   the classes have not been set correctly
19 | #   Only effective when annotation_mode is 0 or 2
20 | # -------------------------------------------------------------------#
21 | classes_path = 'model_data/waterscenes.txt'
22 | # --------------------------------------------------------------------------------------------------------------------------------#
23 | #   trainval_percent sets the ratio of (train + val) to test; by default (train + val) : test = 9 : 1
24 | #   train_percent sets the ratio of train to val within (train + val); by default train : val = 9 : 1
25 | #   Only effective when annotation_mode is 0 or 1
26 | # --------------------------------------------------------------------------------------------------------------------------------#
27 | trainval_percent = 0.8
28 | train_percent = 0.8
29 | # -------------------------------------------------------#
30 | #   Points to the folder where the VOC dataset is located
31 | #   By default it points to the VOC dataset under the repository root
32 | # -------------------------------------------------------#
33 | VOCdevkit_path = "E:/dataset_collection/WaterScenes/all-1114-voc/all-1114/all/VOCdevkit"
34 |
35 | VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]
36 | classes, _ = get_classes(classes_path)
37 |
38 |
39 | def convert_annotation(year, image_id, list_file):
40 | # print(year)
41 | # print(image_id)
42 | in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml' % (year, image_id)), encoding='utf-8')
43 | tree = ET.parse(in_file)
44 | root = tree.getroot()
45 |
46 | for obj in root.iter('object'):
47 | difficult = 0
48 |         if obj.find('difficult') is not None:
49 | difficult = obj.find('difficult').text
50 | cls = obj.find('name').text
51 | if cls not in classes or int(difficult) == 1:
52 | continue
53 | cls_id = classes.index(cls)
54 | xmlbox = obj.find('bndbox')
55 | b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)),
56 | int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
57 | list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
58 |
59 |
60 | if __name__ == "__main__":
61 | random.seed(0)
62 | if annotation_mode == 0 or annotation_mode == 1:
63 | print("Generate txt in ImageSets.")
64 | xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
65 | saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
66 | temp_xml = os.listdir(xmlfilepath)
67 | total_xml = []
68 | for xml in temp_xml:
69 | if xml.endswith(".xml"):
70 | total_xml.append(xml)
71 |
72 | num = len(total_xml)
73 | list = range(num)
74 | tv = int(num * trainval_percent)
75 | tr = int(tv * train_percent)
76 | trainval = random.sample(list, tv)
77 | train = random.sample(trainval, tr)
78 |
79 | print("train and val size", tv)
80 | print("train size", tr)
81 | ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
82 | ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
83 | ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
84 | fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')
85 |
86 | for i in list:
87 | name = total_xml[i][:-4] + '\n'
88 | if i in trainval:
89 | ftrainval.write(name)
90 | if i in train:
91 | ftrain.write(name)
92 | else:
93 | fval.write(name)
94 | else:
95 | ftest.write(name)
96 |
97 | ftrainval.close()
98 | ftrain.close()
99 | fval.close()
100 | ftest.close()
101 | print("Generate txt in ImageSets done.")
102 |
103 | if annotation_mode == 0 or annotation_mode == 2:
104 | print("Generate 2007_train.txt and 2007_val.txt for train.")
105 | for year, image_set in VOCdevkit_sets:
106 | image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt' % (year, image_set)),
107 | encoding='utf-8').read().strip().split()
108 | list_file = open('%s_%s.txt' % (year, image_set), 'w', encoding='utf-8')
109 | for image_id in image_ids:
110 | list_file.write('%s/VOC%s/JPEGImages/%s.jpg' % (os.path.abspath(VOCdevkit_path), year, image_id))
111 |
112 |             try:
113 |                 convert_annotation(year, image_id, list_file)
114 |             except Exception:
115 |                 pass  # still write the trailing newline below so the list file keeps one image per line
116 | list_file.write('\n')
117 | list_file.close()
118 | print("Generate 2007_train.txt and 2007_val.txt for train done.")
119 |
--------------------------------------------------------------------------------
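
Note: each line that voc_annotation.py writes to 2007_train.txt / 2007_val.txt is the absolute JPEG path followed by zero or more space-separated boxes, each encoded as xmin,ymin,xmax,ymax,class_id (see convert_annotation above). The sketch below shows one way such a line could be parsed back into a path and a (num_boxes, 5) numpy array; parse_annotation_line and the example box values are illustrative, not part of the repo, and it assumes image paths contain no spaces.

import numpy as np

def parse_annotation_line(line):
    # Example line produced by voc_annotation.py (illustrative values):
    #   /path/to/VOCdevkit/VOC2007/JPEGImages/000001.jpg 48,240,195,371,3 8,12,352,498,0
    parts = line.strip().split()
    image_path = parts[0]
    boxes = np.array([list(map(int, box.split(','))) for box in parts[1:]],
                     dtype=np.int32).reshape(-1, 5)  # columns: xmin, ymin, xmax, ymax, class_id
    return image_path, boxes

if __name__ == "__main__":
    with open('2007_train.txt', encoding='utf-8') as f:
        path, boxes = parse_annotation_line(f.readline())
        print(path, boxes.shape)
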
/voc_annotation_seg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | from PIL import Image
6 | from tqdm import tqdm
7 |
8 | #-------------------------------------------------------#
9 | #   To add a test set, modify trainval_percent
10 | #   Modify train_percent to change the train/val ratio (9:1 by default)
11 | #
12 | #   This repo currently uses the test set as the validation set and does not split out a separate test set
13 | #-------------------------------------------------------#
14 | trainval_percent = 1
15 | train_percent = 0.8
16 | #-------------------------------------------------------#
17 | #   Points to the folder where the VOC dataset is located
18 | #   By default it points to the VOC dataset under the repository root
19 | #-------------------------------------------------------#
20 | VOCdevkit_path = 'E:/Big_Datasets/voc_seg/VOCdevkit/VOCdevkit'
21 |
22 | if __name__ == "__main__":
23 | random.seed(0)
24 | print("Generate txt in ImageSets.")
25 | segfilepath = os.path.join(VOCdevkit_path, 'VOC2007/SegmentationClass')
26 | saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Segmentation')
27 |
28 | temp_seg = os.listdir(segfilepath)
29 | total_seg = []
30 | for seg in temp_seg:
31 | if seg.endswith(".png"):
32 | total_seg.append(seg)
33 |
34 | num = len(total_seg)
35 | list = range(num)
36 | tv = int(num*trainval_percent)
37 | tr = int(tv*train_percent)
38 | trainval= random.sample(list,tv)
39 | train = random.sample(trainval,tr)
40 |
41 | print("train and val size",tv)
42 |     print("train size", tr)
43 | ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')
44 | ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')
45 | ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')
46 | fval = open(os.path.join(saveBasePath,'val.txt'), 'w')
47 |
48 | for i in list:
49 | name = total_seg[i][:-4]+'\n'
50 | if i in trainval:
51 | ftrainval.write(name)
52 | if i in train:
53 | ftrain.write(name)
54 | else:
55 | fval.write(name)
56 | else:
57 | ftest.write(name)
58 |
59 | ftrainval.close()
60 | ftrain.close()
61 | fval.close()
62 | ftest.close()
63 | print("Generate txt in ImageSets done.")
64 |
65 | print("Check datasets format, this may take a while.")
66 |     print("Checking whether the dataset format meets the requirements; this may take a while.")
67 |     classes_nums = np.zeros([256], int)
68 | for i in tqdm(list):
69 | name = total_seg[i]
70 | png_file_name = os.path.join(segfilepath, name)
71 | if not os.path.exists(png_file_name):
72 |             raise ValueError("Label image %s was not found; please check that the file exists at that path and that its extension is png." % (png_file_name))
73 |
74 | png = np.array(Image.open(png_file_name), np.uint8)
75 | if len(np.shape(png)) > 2:
76 |             print("The shape of label image %s is %s, which is not a grayscale or 8-bit palette image; please check the dataset format carefully." % (name, str(np.shape(png))))
77 |             print("Label images must be grayscale or 8-bit palette images, where the value of each pixel is the class index that pixel belongs to.")
78 |
79 | classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
80 |
81 |     print("Pixel values and their counts:")
82 | print('-' * 37)
83 | print("| %15s | %15s |"%("Key", "Value"))
84 | print('-' * 37)
85 | for i in range(256):
86 | if classes_nums[i] > 0:
87 | print("| %15s | %15s |"%(str(i), str(classes_nums[i])))
88 | print('-' * 37)
89 |
90 | if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0:
91 |         print("The label pixel values contain only 0 and 255; the data format is incorrect.")
92 |         print("For binary segmentation, labels must be changed so that background pixels have value 0 and target pixels have value 1.")
93 | elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0:
94 |         print("The labels contain only background pixels; the data format is incorrect, please check the dataset format carefully.")
95 |
--------------------------------------------------------------------------------
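
Note: both annotation scripts draw their splits with random.sample over an index range, so the split sizes follow directly from trainval_percent and train_percent. The short check below reproduces that arithmetic with the values configured in voc_annotation_seg.py (trainval_percent = 1, train_percent = 0.8); num = 1000 is a made-up folder size used purely for illustration.

import random

num = 1000                # pretend SegmentationClass holds 1000 png labels (illustrative)
trainval_percent = 1      # as configured in voc_annotation_seg.py: no separate test split
train_percent = 0.8

tv = int(num * trainval_percent)   # train + val size -> 1000
tr = int(tv * train_percent)       # train size       -> 800

random.seed(0)                     # same seed the scripts use, so splits are reproducible
trainval = random.sample(range(num), tv)
train = set(random.sample(trainval, tr))

val = [i for i in trainval if i not in train]
print(tv, tr, len(val))            # 1000 800 200 -> train : val = 8 : 2
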