├── DMIST_100_val.txt
├── DMIST_60_val.txt
├── DMIST_train.txt
├── IRDST_train.txt
├── IRDST_val.txt
├── README.md
├── model_data
│   ├── .DS_Store
│   ├── classes.txt
│   └── simhei.ttf
├── nets
│   ├── .DS_Store
│   ├── LASNet.py
│   ├── RDIAN
│   │   ├── .DS_Store
│   │   ├── BaseConv.py
│   │   ├── cbam.py
│   │   ├── direction.py
│   │   └── segmentation.py
│   ├── __init__.py
│   ├── darknet.py
│   └── training.py
├── predict.py
├── readme
│   ├── PR1.png
│   ├── PR2.png
│   ├── PR3.png
│   └── vis.png
├── results
│   ├── DMIST-100
│   │   ├── ACM.txt
│   │   ├── AGPCNet.txt
│   │   ├── DNANet.txt
│   │   ├── HRNet.txt
│   │   ├── ISNet.txt
│   │   ├── ISTDUNet.txt
│   │   ├── LASNet.txt
│   │   ├── RDIAN.txt
│   │   ├── RISTD.txt
│   │   ├── SANet.txt
│   │   ├── SSTNet.txt
│   │   ├── SwinT.txt
│   │   ├── U2Net.txt
│   │   ├── UCF.txt
│   │   └── UIUNet.txt
│   ├── DMIST-60
│   │   ├── ACM.txt
│   │   ├── AGPCNet.txt
│   │   ├── DNANet.txt
│   │   ├── HRNet.txt
│   │   ├── ISNet.txt
│   │   ├── ISTDUNet.txt
│   │   ├── LASNet.txt
│   │   ├── RDIAN.txt
│   │   ├── RISTD.txt
│   │   ├── SANet.txt
│   │   ├── SSTNet.txt
│   │   ├── SwinT.txt
│   │   ├── U2Net.txt
│   │   ├── UCF.txt
│   │   └── UIUNet.txt
│   └── IRDST
│       ├── ACM.txt
│       ├── AGPCNet.txt
│       ├── DNANet.txt
│       ├── HRNet.txt
│       ├── ISNet.txt
│       ├── ISTDUNet.txt
│       ├── LASNet.txt
│       ├── RDIAN.txt
│       ├── RISTD.txt
│       ├── SANet.txt
│       ├── SSTNet.txt
│       ├── SwinT.txt
│       ├── U2Net.txt
│       ├── UCF.txt
│       └── UIUNet.txt
├── test_DMIST.py
├── train_DMIST.py
├── utils
│   ├── .DS_Store
│   ├── __init__.py
│   ├── callbacks.py
│   ├── dataloader.py
│   ├── dataloader_for_DMIST.py
│   ├── utils.py
│   ├── utils_bbox.py
│   ├── utils_fit.py
│   └── utils_map.py
└── utils_coco
    ├── coco_annotation.py
    └── coco_to_txt.py
/README.md:
--------------------------------------------------------------------------------
1 | # DMIST-Benchmark
2 | ## ***Dense Moving Infrared Small Target Detection***
3 |
4 | This repository provides the DMIST benchmark datasets and the baseline model implementation for the **TGRS 2024** paper [**Towards Dense Moving Infrared Small Target Detection: New Datasets and Baseline**](https://ieeexplore.ieee.org/document/10636251).
5 |
6 |
7 |
8 |
9 | ## Benchmark Datasets (bounding box-based)
10 | - We synthesize two dense moving infrared small target datasets, **DMIST-60** and **DMIST-100**, based on DAUB.
11 | - The datasets are available here: `DMIST` [Baidu](https://pan.baidu.com/s/1LL4rAFfv0Z8HRV4-w8mJjw?pwd=vkcu)/[Google](https://drive.google.com/drive/folders/13CvH9muxs-9fcgeSZJWraw1StWxE3zek?usp=sharing) and `IRDST` [Baidu](https://pan.baidu.com/s/10So3fntJMQxBy-bdSUUD6Q?pwd=t2ti) (code: t2ti). You can also download `IRDST` directly from the [website](https://xzbai.buaa.edu.cn/datasets.html). In addition, we introduce a new drone swarm dataset, `DSISTD` [Baidu](https://pan.baidu.com/s/1-di7v8e1Vmp3PzzRqEGKHg?pwd=r5cg) (code: r5cg), which integrates real targets into simulated infrared backgrounds.
12 |
13 | - You need to reorganize these datasets into a format similar to the `DMIST_train.txt` and `DMIST_60_val.txt`/`DMIST_100_val.txt` files we provide (the `txt files` are used in training). We provide the `txt files` for DMIST and IRDST.
14 | For example:
15 | ```python
16 | train_annotation_path = '/home/LASNet/DMIST_train.txt'
17 | val_annotation_path = '/home/LASNet/DMIST_60_val.txt'
18 | ```
19 | - Alternatively, you can generate new `txt files` that match the paths of your own datasets. `txt files` (e.g., `DMIST_60_val.txt`) can be generated from `json files` (e.g., `60_coco_val.json`). We also provide all `json files` for `DMIST` [Baidu](https://pan.baidu.com/s/1LL4rAFfv0Z8HRV4-w8mJjw?pwd=vkcu)/[Google](https://drive.google.com/drive/folders/13CvH9muxs-9fcgeSZJWraw1StWxE3zek?usp=sharing) and `IRDST` [Baidu](https://pan.baidu.com/s/10So3fntJMQxBy-bdSUUD6Q?pwd=t2ti).
20 |
21 | ```shell
22 | python utils_coco/coco_to_txt.py
23 | ```
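
As a rough illustration of what such a conversion involves, the sketch below groups COCO annotations by image and emits one text line per image. The output line format (`image_path x1,y1,x2,y2,class_id ...`) and the function name `coco_to_lines` are assumptions made for illustration only; check `utils_coco/coco_to_txt.py` and the provided `DMIST_train.txt` for the format the training code actually expects.

```python
import json
from collections import defaultdict


def coco_to_lines(coco, image_root=""):
    """Turn a COCO-style dict into one annotation line per image.

    Assumed (hypothetical) line format:
        <image_path> x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id ...
    COCO stores boxes as [x, y, width, height]; they are converted to corners.
    """
    # Map category ids to contiguous class indices (0, 1, ...).
    cat_ids = sorted(c["id"] for c in coco["categories"])
    cat_to_index = {cid: i for i, cid in enumerate(cat_ids)}

    boxes_per_image = defaultdict(list)
    for ann in coco["annotations"]:
        x, y, w, h = ann["bbox"]
        cls = cat_to_index[ann["category_id"]]
        boxes_per_image[ann["image_id"]].append(
            f"{int(x)},{int(y)},{int(x + w)},{int(y + h)},{cls}"
        )

    lines = []
    for img in coco["images"]:
        path = image_root + img["file_name"]
        # Images without annotations produce a line with the path only.
        lines.append(" ".join([path] + boxes_per_image[img["id"]]))
    return lines


def convert_file(json_path, txt_path, image_root=""):
    """Read a COCO json file and write the assumed txt format."""
    with open(json_path, "r", encoding="utf-8") as f:
        coco = json.load(f)
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write("\n".join(coco_to_lines(coco, image_root)) + "\n")
```

A call such as `convert_file('60_coco_val.json', 'DMIST_60_val.txt', image_root='/home/public/DMIST/')` would then write the file; the `image_root` value here is only a placeholder for your own dataset path.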
24 |
25 | - The folder structure should look like this:
26 | ```
27 | DMIST
28 | ├─coco_train.json
29 | ├─60_coco_val.json
30 | ├─100_coco_val.json
31 | ├─images
32 | │ ├─train
33 | │ │ ├─data5
34 | │ │ │ ├─0.bmp
35 | │ │ │ ├─0.txt
36 | │ │ │ ├─ ...
37 | │ │ │ ├─2999.bmp
38 | │ │ │ ├─2999.txt
39 | │ │ │ ├─ ...
40 | │ │ ├─ ...
41 | │ ├─test60
42 | │ │ ├─data6
43 | │ │ │ ├─0.bmp
44 | │ │ │ ├─0.txt
45 | │ │ │ ├─ ...
46 | │ │ │ ├─398.bmp
47 | │ │ │ ├─398.txt
48 | │ │ │ ├─ ...
49 | │ │ ├─ ...
50 | │ ├─test100
51 | │ │ ├─ ...
52 | ```
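
Before training, it can help to confirm that a local copy matches the layout above. The following is a minimal sketch using only the standard library; the function name `check_dmist_layout` is hypothetical, and the check assumes every `.bmp` frame has a `.txt` file with the same stem next to it, as the tree suggests.

```python
import tempfile
from pathlib import Path


def check_dmist_layout(root):
    """Return a list of problems found under `root` (an empty list means OK)."""
    root = Path(root)
    problems = []

    # The three COCO json files listed at the top of the tree.
    for name in ("coco_train.json", "60_coco_val.json", "100_coco_val.json"):
        if not (root / name).is_file():
            problems.append(f"missing file: {name}")

    # Each split holds sequence folders (data5, data6, ...) of frame/label pairs.
    for split in ("train", "test60", "test100"):
        split_dir = root / "images" / split
        if not split_dir.is_dir():
            problems.append(f"missing folder: images/{split}")
            continue
        for bmp in sorted(split_dir.glob("*/*.bmp")):
            if not bmp.with_suffix(".txt").is_file():
                problems.append(f"no label file for {bmp.relative_to(root).as_posix()}")
    return problems


if __name__ == "__main__":
    # Demo on an empty temporary folder: every expected item is reported missing.
    with tempfile.TemporaryDirectory() as tmp:
        for problem in check_dmist_layout(tmp):
            print(problem)
```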
53 |
54 |
55 | ## Prerequisites
56 |
57 | * python==3.10.11
58 | * pytorch==1.12.0
59 | * torchvision==0.13.0
60 | * numpy==1.24.3
61 | * opencv-python==4.7.0.72
62 | * pillow==9.5.0
63 | * scipy==1.10.1
64 | * Tested on Ubuntu 20.04, with CUDA 11.3, and 1x NVIDIA 3090.
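
To compare an existing environment against the pinned versions above, a small check along the following lines may be useful. It is a sketch, not part of the repository: the helper names are hypothetical, the pip distribution name for PyTorch is `torch`, and local build suffixes such as `+cu113` are ignored when comparing.

```python
from importlib import metadata

# Versions copied from the list above (pip distribution names).
PINNED = {
    "torch": "1.12.0",
    "torchvision": "0.13.0",
    "numpy": "1.24.3",
    "opencv-python": "4.7.0.72",
    "pillow": "9.5.0",
    "scipy": "1.10.1",
}


def installed_versions(names):
    """Look up the installed version of each distribution (None if absent)."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def report(pinned, installed):
    """Return one status line per pinned package."""
    lines = []
    for name, wanted in pinned.items():
        have = installed.get(name)
        if have is None:
            status = "not installed"
        elif have.split("+")[0] == wanted:  # drop build tags like "+cu113"
            status = "ok"
        else:
            status = f"differs (installed {have})"
        lines.append(f"{name}=={wanted}: {status}")
    return lines


if __name__ == "__main__":
    for line in report(PINNED, installed_versions(PINNED)):
        print(line)
```

A differing version is not necessarily a problem; the list above only records the configuration that was tested.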
65 |
66 |
67 | ## Usage of baseline LASNet
68 |
69 | ### Train
70 |
71 | ```shell
72 | CUDA_VISIBLE_DEVICES=0 python train_DMIST.py
73 | ```
74 |
75 | ### Test
76 | - `model_best.pth` is not necessarily the best model. A checkpoint with a lower val_loss or a higher AP50 during validation may perform better, so set `model_path` to the checkpoint you want to evaluate.
77 | ```python
78 | "model_path": '/home/LASNet/logs/model.pth'
79 | ```
80 | - You need to change the path of the `json file` of test sets. For example:
81 | ```python
82 | # Use the DMIST-100 dataset for testing.
83 | cocoGt_path = '/home/public/DMIST/100_coco_val.json'
84 | dataset_img_path = '/home/public/DMIST/'
85 | ```
86 | ```shell
87 | python test_DMIST.py
88 | ```
89 |
90 | ### Visualization
91 | - We support `video` and `single-frame image` prediction.
92 | ```python
93 | # mode = "video"    # Predict a sequence
94 | mode = "predict"    # Predict a single-frame image
95 | ```
96 | ```shell
97 | python predict.py
98 | ```
99 |
100 | ## Results
101 | - We optimized the original code and retrained LASNet, achieving slightly better results than those reported in our paper.
102 |
103 |
104 |
105 | | Method | Dataset | mAP50 (%) | Precision (%) | Recall (%) | F1 (%) | Download |
106 | |--------|---------|-----------|---------------|------------|--------|----------|
107 | | LASNet | DMIST-60 | 76.47 | 95.84 | 80.07 | 87.25 | Baidu (code: y7ki) / Google |
108 | | LASNet | DMIST-100 | 65.70 | 96.52 | 68.68 | 80.25 | |
133 |
134 |
135 |
136 |
137 |
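
The F1 values reported for LASNet are consistent with the harmonic mean of the listed precision and recall. The short check below recomputes them from the numbers in the results above; the function name `f1_score` is just an illustrative helper.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (same units as the inputs)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision %, recall %) for LASNet, taken from the results above.
results = {"DMIST-60": (95.84, 80.07), "DMIST-100": (96.52, 68.68)}

for name, (p, r) in results.items():
    print(f"{name}: F1 = {f1_score(p, r):.2f}")
# DMIST-60: F1 = 87.25
# DMIST-100: F1 = 80.25
```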
138 | - PR curves on the DMIST and IRDST datasets, as reported in the paper.
139 | - We provide the results on [DMIST-60](./results/DMIST-60), [DMIST-100](./results/DMIST-100) and [IRDST](./results/IRDST), and you can plot them using Python.
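
A possible starting point for plotting is sketched below. The layout of the result files is not documented here, so the parser assumes each line holds two numbers (recall, then precision) separated by whitespace or a comma; adjust it after inspecting a file such as `results/DMIST-60/LASNet.txt`. The helper names are hypothetical, and `matplotlib` is an extra dependency that is not in the prerequisites list.

```python
from pathlib import Path


def parse_pr_lines(lines):
    """Parse (recall, precision) pairs; the two-column format is an assumption."""
    points = []
    for line in lines:
        parts = line.replace(",", " ").split()
        if len(parts) < 2:
            continue
        try:
            points.append((float(parts[0]), float(parts[1])))
        except ValueError:
            continue  # skip header or malformed rows
    return sorted(points)


def plot_pr(curves, out_path="pr_curves.png"):
    """Draw one precision-recall curve per method and save the figure."""
    import matplotlib  # imported lazily so parsing works without it
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    for name, pts in curves.items():
        ax.plot([r for r, _ in pts], [p for _, p in pts], label=name)
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.legend()
    fig.savefig(out_path, dpi=200)


if __name__ == "__main__":
    curves = {}
    for txt in sorted(Path("results/DMIST-60").glob("*.txt")):
        curves[txt.stem] = parse_pr_lines(txt.read_text().splitlines())
    plot_pr(curves)
```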
140 |
141 |
142 |
143 |
144 |
145 | ## Contact
146 | If you have any questions, please contact Shengjia Chen by e-mail: csj_uestc@126.com.
147 |
148 | ## References
149 | 1. S. Chen, L. Ji, J. Zhu, M. Ye and X. Yao, "SSTNet: Sliced Spatio-Temporal Network With Cross-Slice ConvLSTM for Moving Infrared Dim-Small Target Detection," in IEEE Transactions on Geoscience and Remote Sensing, vol. 62, pp. 1-12, 2024, Art no. 5000912, doi: 10.1109/TGRS.2024.3350024.
150 | 2. B. Hui et al., "A dataset for infrared image dim-small aircraft target detection and tracking under ground/air background," Sci. Data Bank, CSTR 31253.11.sciencedb.902, Oct. 2019.
151 |
152 | ## Citation
153 |
154 | If you find this repo useful, please cite our paper.
155 |
156 | ```
157 | @ARTICLE{chen2024dmist,
158 | author={Chen, Shengjia and Ji, Luping and Zhu, Sicheng and Ye, Mao and Ren, Haohao and Sang, Yongsheng},
159 | journal={IEEE Transactions on Geoscience and Remote Sensing},
160 | title={Toward Dense Moving Infrared Small Target Detection: New Datasets and Baseline},
161 | year={2024},
162 | volume={62},
163 | pages={1-13},
164 | doi={10.1109/TGRS.2024.3443280}}
165 |
166 | ```
--------------------------------------------------------------------------------
/model_data/classes.txt:
--------------------------------------------------------------------------------
1 | target
--------------------------------------------------------------------------------
/model_data/simhei.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UESTC-nnLab/DMIST/07bb456ae2c4b2a71a0065a30d84953cbfd38844/model_data/simhei.ttf
--------------------------------------------------------------------------------
/nets/LASNet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from scipy.ndimage import gaussian_filter
8 | from .darknet import BaseConv
9 | from .RDIAN.segmentation import RDIAN
10 |
11 |
12 | class LASNet(nn.Module):
13 |     def __init__(self, num_classes, num_frame=5):
14 |         super(LASNet, self).__init__()
15 |
16 |         self.num_frame = num_frame
17 |         self.backbone = RDIAN()
18 |         self.MAF = Motion_Affinity_Fusion_Module(channels=[128], num_frame=num_frame)
19 |         self.head = YOLOXHead(num_classes=num_classes, width = 1.0, in_channels = [128], act = "silu")
20 |         self.mapping0 = nn.Sequential(
21 |             nn.Conv2d(128*num_frame, 128, kernel_size=1, stride=1, padding=0, bias=False),
22 |             nn.LeakyReLU())
23 |         self.LAS = Linking_Aware_Sliced_Module(input_dim=64, hidden_dim=[64,64], kernel_size=(3, 3), num_layers=2, num_slices=2, num_frames=self.num_frame)
24 |         self.mapping1 = nn.Sequential(
25 |             nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=False),
26 |             nn.LeakyReLU(),
27 |             nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0, bias=False),
28 |             nn.LeakyReLU())
29 |         self.mapping2 = nn.Sequential(
30 |             nn.Conv2d(64, 128, kernel_size=1, stride=1, padding=0, bias=False),
31 |             nn.LeakyReLU())
32 |         self.conv_backbone = nn.Sequential(
33 |             nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
34 |             nn.LeakyReLU(),
35 |             nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False),
36 |             nn.LeakyReLU(),
37 |             nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
38 |             nn.LeakyReLU(),
39 |             nn.Conv2d(128, 128, kernel_size=1, stride=1, padding=0, bias=False),
40 |             nn.LeakyReLU())
41 |         self.motion_head = nn.Sequential(
42 |             BaseConv(self.num_frame*128,128,3,1),
43 |             BaseConv(128,64,3,1),
44 |             BaseConv(64,1,1,1))
45 |         self.mm_loss = motion_mask_loss()
46 |
47 |
48 |     def forward(self, inputs, multi_targets=None):
49 |         feat = []
50 |         for i in range(self.num_frame):
51 |             feats = self.backbone(inputs[:,:,i,:,:])
52 |             feats = self.conv_backbone(feats)
53 |             feat.append(feats)
54 |
55 |         multi_feat = torch.stack([self.mapping1(feat[i]) for i in range(self.num_frame)], 1)
56 |         lstm_output, _ = self.LAS(multi_feat)
57 |         motion_relation = lstm_output[-1]
58 |         motion = torch.stack([self.mapping2(motion_relation[:,i,:,:,:]) for i in range(self.num_frame)], 1)
59 |         feat = self.MAF(feat, motion)
60 |         outputs = self.head(feat)
61 |
62 |         if self.training:
63 |             pred_m = self.motion_head(torch.cat([motion[:,i,:,:,:] for i in range(self.num_frame)], 1))
64 |             pred_m = F.interpolate(pred_m, size=[inputs.shape[3], inputs.shape[4]], mode='bilinear', align_corners=True)
65 |             mm_loss = self.mm_loss(pred_m, multi_targets)
66 |
67 |         if self.training:
68 |             return outputs, mm_loss
69 |         else:
70 |             return outputs
71 |
72 | class motion_mask_loss(nn.Module):
73 |     def __init__(self):
74 |         super(motion_mask_loss, self).__init__()
75 |
76 |     def forward(self, pred_m, multi_targets):
77 |         multi_targets = np.array(multi_targets)
78 |         gt_target = torch.tensor(multi_targets)
79 |         heatmap = torch.zeros(pred_m.shape[0], pred_m.shape[2], pred_m.shape[3], device=pred_m.device)  # follow the prediction's device instead of forcing CUDA
80 |         for b in range(gt_target.shape[0]):
81 |             for f in range(gt_target.shape[1]):
82 |                 for t in range(gt_target.shape[2]):
83 |                     x, y = gt_target[b,f,t,:2]
84 |                     s_x, s_y = gt_target[b,f,t,2:4]
85 |                     heatmap[b, int(y):int(y) + int(s_y), int(x):int(x) + int(s_x) ] = 255
86 |         target = heatmap.unsqueeze(1)
87 |         pred = torch.sigmoid(pred_m)
88 |         smooth = 1
89 |         intersection = pred * target
90 |         intersection_sum = torch.sum(intersection, dim=(1,2,3))
91 |         pred_sum = torch.sum(pred, dim=(1,2,3))
92 |         target_sum = torch.sum(target, dim=(1,2,3))
93 |         loss = (intersection_sum + smooth) / \
94 |                (pred_sum + target_sum - intersection_sum + smooth)
95 |         loss = 1 - torch.mean(loss)
96 |         return loss
97 |
98 |
99 | class Linking_Aware_Node(nn.Module):
100 |     def __init__(self, input_dim, hidden_dim, kernel_size, bias):
101 |         super(Linking_Aware_Node, self).__init__()
102 |
103 |         self.input_dim = input_dim
104 |         self.hidden_dim = hidden_dim
105 |         self.kernel_size = kernel_size
106 |         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
107 |         self.bias = bias
108 |         self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
109 |                               out_channels=4 * self.hidden_dim,
110 |                               kernel_size=self.kernel_size,
111 |                               padding=self.padding,
112 |                               bias=self.bias)
113 |         self.conv2 = nn.Conv2d(in_channels=4 * self.input_dim,
114 |                                out_channels=4 * self.hidden_dim,
115 |                                kernel_size=self.kernel_size,
116 |                                padding=self.padding,
117 |                                bias=self.bias)
118 |
119 |     def forward(self, input_tensor, input_head, cur_state, multi_head):
120 |         h_cur, c_cur = cur_state
121 |         combined = torch.cat([input_tensor, h_cur], dim=1)
122 |         combined_conv = self.conv(combined)
123 |         cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
124 |
125 |         i = torch.sigmoid(cc_i)
126 |         f = torch.sigmoid(cc_f)
127 |         o = torch.sigmoid(cc_o)
128 |         g = torch.tanh(cc_g)
129 |
130 |         c_next = f * c_cur + i * g
131 |         h_next = o * torch.tanh(c_next)
132 |
133 |         m_h, m_c = multi_head
134 |         combined2 = torch.cat([input_tensor, h_cur, input_head, m_h], dim=1)
135 |         combined_conv2 = self.conv2(combined2)
136 |         mm_i, mm_f, mm_o, mm_g = torch.split(combined_conv2, self.hidden_dim, dim=1)
137 |
138 |         m_i = torch.sigmoid(mm_i+cc_i)
139 |         m_f = torch.sigmoid(mm_f+cc_f)
140 |         m_o = torch.sigmoid(mm_o+cc_o)
141 |         m_g = torch.tanh(mm_g+cc_g)
142 |
143 |         m_c_next = m_f * m_c + m_i * m_g
144 |         m_h_next = m_o * torch.tanh(m_c_next)
145 |
146 |         return h_next, c_next, m_h_next, m_c_next
147 |
148 |     def init_hidden(self, batch_size, image_size):
149 |         height, width = image_size
150 |         return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
151 |                 torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
152 |
153 | class Linking_Aware_Sliced_Module(nn.Module):
154 |     def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, num_slices, num_frames,
155 |                  batch_first=True, bias=True, return_all_layers=False):
156 |         super(Linking_Aware_Sliced_Module, self).__init__()
157 |
158 |         self._check_kernel_size_consistency(kernel_size)
159 |         kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
160 |         hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
161 |         if not len(kernel_size) == len(hidden_dim) == num_layers:
162 |             raise ValueError('Inconsistent list length.')
163 |
164 |         self.input_dim = input_dim
165 |         self.hidden_dim = hidden_dim
166 |         self.kernel_size = kernel_size
167 |         self.num_layers = num_layers
168 |         self.deep = num_slices
169 |         self.frames = num_frames
170 |         self.batch_first = batch_first
171 |         self.bias = bias
172 |         self.return_all_layers = return_all_layers
173 |
174 |         Node_list = {}
175 |
176 |         for i in range(self.deep):
177 |             cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
178 |
179 |             for j in range(self.num_layers):
180 |                 Node_list.update({'%d%d'%(i,j): Linking_Aware_Node(input_dim=cur_input_dim,
181 |                                                                   hidden_dim=self.hidden_dim[j],
182 |                                                                   kernel_size=self.kernel_size[j],
183 |                                                                   bias=self.bias)})
184 |         self.Node_list = nn.ModuleDict(Node_list)
185 |
186 |         self.linking_weight1 = {}
187 |         for i in range(0,self.deep):
188 |             for j in range(0,self.num_layers):
189 |                 for k in range(0,self.frames):
190 |                     self.linking_weight1.update({'%d%d%d'%(i,j,k): nn.Conv3d(self.frames,self.frames,1,1,0)})
191 |         self.linking_1 = nn.ModuleDict(self.linking_weight1)
192 |
193 |         self.linking_weight2 = {}
194 |         for i in range(0,self.deep):
195 |             for j in range(0,self.num_layers):
196 |                 for k in range(0,self.frames):
197 |                     self.linking_weight2.update({'%d%d%d'%(i,j,k): nn.Conv3d(self.frames,self.frames,1,1,0)})
198 |         self.linking_2 = nn.ModuleDict(self.linking_weight2)
199 |
200 |         self.state_weight1 = {}
201 |         for i in range(0,self.deep):
202 |             for j in range(0,self.num_layers):
203 |                 for t in range(1,self.frames):
204 |                     self.state_weight1.update({'%d%d%d%d'%(i,j,t,1): nn.Conv3d(t+1,t+1,1,1,0)})
205 |                     self.state_weight1.update({'%d%d%d%d'%(i,j,t,2): nn.Conv3d(t+1,t+1,1,1,0)})
206 |         self.state_1 = nn.ModuleDict(self.state_weight1)
207 |
208 |         self.state_weight2 = {}
209 |         for i in range(1,self.deep):
210 |             for j in range(0,self.num_layers):
211 |                 for t in range(0,self.frames):
212 |                     self.state_weight2.update({'%d%d%d%d'%(i,j,t,1): nn.Conv3d(i+1,i+1,1,1,0)})
213 |                     self.state_weight2.update({'%d%d%d%d'%(i,j,t,2): nn.Conv3d(i+1,i+1,1,1,0)})
214 |
215 |         self.state_2 = nn.ModuleDict(self.state_weight2)
216 |
217 |
218 |     def forward(self, input_tensor, hidden_state=None):
219 |
220 |         if not self.batch_first:
221 |             # (t, b, c, h, w) -> (b, t, c, h, w)
222 |             input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
223 |
224 |         b, _, _, h, w = input_tensor.size()
225 |
226 |         if hidden_state is not None:
227 |             raise NotImplementedError()
228 |         else:
229 |             hidden_state = self._init_hidden(batch_size=b,
230 |                                              image_size=(h, w))
231 |             deep_state = self._init_motion_hidden(batch_size=b,
232 |                                                   image_size=(h, w), t_len = input_tensor.shape[1])
233 |
234 |         layer_output_list = []
235 |         last_state_list = []
236 |
237 |         seq_len = input_tensor.size(1)
238 |         cur_layer_input = input_tensor
239 |         head_input = input_tensor
240 |
241 |         input_deep_h = {}
242 |         input_deep_c = {}
243 |         slice_state = {}
244 |
245 |         for deep_idx in range(self.deep):
246 |
247 |             for layer_idx in range(self.num_layers):
248 |
249 |                 past_state = []
250 |                 output_inner = []
251 |                 h, c = hidden_state['%d%d'%(deep_idx,layer_idx)]
252 |
253 |                 for t in range(seq_len):
254 |
255 |                     cur_input = self.linking_1['%d%d%d'%(deep_idx,layer_idx,t)](cur_layer_input)
256 |                     cur_input2 = nn.functional.softmax(cur_input, dim=1)
257 |                     selected_cur_input_index = torch.argmax(cur_input2[0,:,0,0,0])
258 |                     cur_input = cur_input[:,selected_cur_input_index,:,:,:]
259 |
260 |                     s_input = self.linking_2['%d%d%d'%(deep_idx,layer_idx,t)](head_input)
261 |                     s_input2 = nn.functional.softmax(s_input, dim=1)
262 |                     selected_s_input_index = torch.argmax(s_input2[0,:,0,0,0])
263 |                     s_input = s_input[:,selected_s_input_index,:,:,:]
264 |
265 |                     if t ==0:
266 |                         past_state.append([h,c])
267 |                         selected_state = [h,c]
268 |                     else:
269 |                         p_h=[]
270 |                         p_c=[]
271 |                         for i in range(len(past_state)):
272 |                             past_h, past_c = past_state[i]
273 |                             p_h.append(past_h)
274 |                             p_c.append(past_c)
275 |                         pp_h = torch.stack([p_h[i] for i in range(len(p_h))], 1)
276 |                         pp_c = torch.stack([p_c[i] for i in range(len(p_c))], 1)
277 |
278 |                         pp_h = self.state_1['%d%d%d%d'%(deep_idx,layer_idx,t,1)](pp_h)
279 |                         pp_h2 = nn.functional.softmax(pp_h, dim=1)
280 |                         selected_pp_h_index = torch.argmax(pp_h2[0,:,0,0,0])
281 |                         pp_h = pp_h[:,selected_pp_h_index,:,:,:]
282 |
283 |                         pp_c = self.state_1['%d%d%d%d'%(deep_idx,layer_idx,t,2)](pp_c)
284 |                         pp_c2 = nn.functional.softmax(pp_c, dim=1)
285 |                         selected_pp_c_index = torch.argmax(pp_c2[0,:,0,0,0])
286 |                         pp_c = pp_c[:,selected_pp_c_index,:,:,:]
287 |
288 |                         selected_state = [pp_h,pp_c]
289 |
290 |                     if deep_idx == 0:
291 |                         m_h, m_c = deep_state['%d%d'%(layer_idx, t)]
292 |                     else:
293 |                         m_h = input_deep_h['%d%d%d'%(deep_idx-1,layer_idx, t)]
294 |                         m_c = input_deep_c['%d%d%d'%(deep_idx-1,layer_idx, t)]
295 |
296 |                     slice_state.update({'%d%d%d'%(deep_idx,layer_idx,t): [m_h, m_c]})
297 |
298 |                     if deep_idx == 0:
299 |                         selected_slice = slice_state['%d%d%d'%(deep_idx,layer_idx,t)]
300 |                     else:
301 |                         mm_h=[]
302 |                         mm_c=[]
303 |                         for i in range(deep_idx+1):
304 |                             past_mh, past_mc = slice_state['%d%d%d'%(i,layer_idx,t)]
305 |                             mm_h.append(past_mh)
306 |                             mm_c.append(past_mc)
307 |                         pp_mh = torch.stack([mm_h[i] for i in range(len(mm_h))], 1)
308 |                         pp_mc = torch.stack([mm_c[i] for i in range(len(mm_c))], 1)
309 |
310 |                         pp_mh = self.state_2['%d%d%d%d'%(deep_idx,layer_idx,t,1)](pp_mh)
311 |                         pp_mc = self.state_2['%d%d%d%d'%(deep_idx,layer_idx,t,2)](pp_mc)
312 |
313 |                         pp_mh2 = nn.functional.softmax(pp_mh, dim=1)
314 |                         selected_pp_mh_index = torch.argmax(pp_mh2[0,:,0,0,0])
315 |                         pp_mh = pp_mh[:,selected_pp_mh_index,:,:,:]
316 |
317 |                         pp_mc2 = nn.functional.softmax(pp_mc, dim=1)
318 |                         selected_pp_mc_index = torch.argmax(pp_mc2[0,:,0,0,0])
319 |                         pp_mc = pp_mc[:,selected_pp_mc_index,:,:,:]
320 |
321 |                         selected_slice = [pp_mh,pp_mc]
322 |
323 |                     h, c, m_h, m_c = self.Node_list['%d%d'%(deep_idx,layer_idx)](input_tensor=cur_input, input_head = s_input, cur_state=selected_state, multi_head=selected_slice)
324 |
325 |                     past_state.append([h,c])
326 |                     output_inner.append(h+m_h)
327 |
328 |                     input_deep_h.update({'%d%d%d'%(deep_idx,layer_idx,t): m_h})
329 |                     input_deep_c.update({'%d%d%d'%(deep_idx,layer_idx,t): m_c})
330 |
331 |                 layer_output = torch.stack(output_inner, dim=1)
332 |                 head_output = torch.stack(([input_deep_h['%d%d%d'%(deep_idx, layer_idx, t)] for t in range (seq_len)]), dim=1)
333 |
334 |                 cur_layer_input = layer_output
335 |                 head_input = head_output
336 |
337 |                 layer_output_list.append(layer_output)
338 |                 last_state_list.append([h, c])
339 |
340 |         if not self.return_all_layers:
341 |             layer_output_list = layer_output_list[-1:]
342 |             last_state_list = last_state_list[-1:]
343 |
344 |         return layer_output_list, last_state_list
345 |
346 |
347 |     def _init_hidden(self, batch_size, image_size):
348 |         init_states = {}
349 |         for i in range(0,self.deep):
350 |             for j in range(0,self.num_layers):
351 |                 init_states.update({'%d%d'%(i,j): self.Node_list['%d%d'%(i,j)].init_hidden(batch_size, image_size)})
352 |         return init_states
353 |
354 |     def _init_motion_hidden(self, batch_size, image_size, t_len):
355 |
356 |         init_states = {}
357 |         for i in range(0,self.num_layers):
358 |             for j in range(0,t_len):
359 |                 init_states.update({'%d%d'%(i,j): self.Node_list['00'].init_hidden(batch_size, image_size)})
360 |         return init_states
361 |
362 |
363 |     @staticmethod
364 |     def _check_kernel_size_consistency(kernel_size):
365 |         if not (isinstance(kernel_size, tuple) or
366 |                 (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
367 |             raise ValueError('`kernel_size` must be tuple or list of tuples')
368 |
369 |     @staticmethod
370 |     def _extend_for_multilayer(param, num_layers):
371 |         if not isinstance(param, list):
372 |             param = [param] * num_layers
373 |         return param
374 |
375 |
376 | class Motion_Affinity_Fusion_Module(nn.Module):
377 |     def __init__(self, channels=[128,256,512] ,num_frame=5):
378 |         super().__init__()
379 |         self.num_frame = num_frame
380 |         self.weight = nn.ParameterList(torch.nn.Parameter(torch.tensor([0.25]), requires_grad=True) for _ in range(num_frame))
381 |
382 |         self.conv_ref = nn.Sequential(
383 |             BaseConv(channels[0]*(self.num_frame-1), channels[0]*2,3,1),
384 |             BaseConv(channels[0]*2,channels[0],3,1,act='sigmoid')
385 |         )
386 |         self.conv_cur = nn.Sequential(
387 |             BaseConv(channels[0], channels[0],3,1),
388 |             BaseConv(channels[0], channels[0],3,1)
389 |         )
390 |
391 |         self.conv_gl = nn.Sequential(
392 |             BaseConv(channels[0]*2, channels[0]*2,3,1),
393 |             BaseConv(channels[0]*2,channels[0],3,1)
394 |         )
395 |
396 |         self.conv_gl_mix = nn.Sequential(
397 |             BaseConv(channels[0], channels[0],3,1),
398 |             BaseConv(channels[0],channels[0],3,1)
399 |         )
400 |         self.conv_cr_mix = nn.Sequential(
401 |             BaseConv(channels[0]*2, channels[0]*2,3,1),
402 |             BaseConv(channels[0]*2,channels[0],3,1)
403 |         )
404 |         self.conv_final = nn.Sequential(
405 |             BaseConv(channels[0]*2, channels[0]*2,3,1),
406 |             BaseConv(channels[0]*2,channels[0],3,1)
407 |         )
408 |
409 |     def forward(self, feats, motion):
410 |         f_feats = []
411 |         r_feat = torch.cat([feats[i] for i in range(self.num_frame-1)],dim=1)
412 |         r_feat = self.conv_ref(r_feat)
413 |         c_feat = self.conv_cur(r_feat*feats[-1])
414 |         c_feat = self.conv_cr_mix(torch.cat([c_feat, feats[-1]], dim=1))
415 |
416 |         r_feats = torch.stack([self.conv_gl(torch.cat([motion[:,i,:,:,:], feats[-1]], dim=1))*self.weight[i] for i in range(self.num_frame)], dim=1)
417 |         r_feat= self.conv_gl_mix(torch.sum(r_feats, dim=1))
418 |         c_feat = self.conv_final(torch.cat([r_feat,c_feat], dim=1))
419 |         f_feats.append(c_feat)
420 |
421 |         return f_feats
422 |
423 |
424 |
425 | class YOLOXHead(nn.Module):
426 |     def __init__(self, num_classes, width = 1.0, in_channels = [16, 32, 64], act = "silu"):
427 |         super().__init__()
428 |         Conv = BaseConv
429 |
430 |         self.cls_convs = nn.ModuleList()
431 |         self.reg_convs = nn.ModuleList()
432 |         self.cls_preds = nn.ModuleList()
433 |         self.reg_preds = nn.ModuleList()
434 |         self.obj_preds = nn.ModuleList()
435 |         self.stems = nn.ModuleList()
436 |
437 |         for i in range(len(in_channels)):
438 |             self.stems.append(BaseConv(in_channels = int(in_channels[i] * width), out_channels = int(256 * width), ksize = 1, stride = 1, act = act))
439 |             self.cls_convs.append(nn.Sequential(*[
440 |                 Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act),
441 |                 Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act),
442 |             ]))
443 |             self.cls_preds.append(
444 |                 nn.Conv2d(in_channels = int(256 * width), out_channels = num_classes, kernel_size = 1, stride = 1, padding = 0)
445 |             )
446 |
447 |             self.reg_convs.append(nn.Sequential(*[
448 |                 Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act),
449 |                 Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act)
450 |             ]))
451 |             self.reg_preds.append(
452 |                 nn.Conv2d(in_channels = int(256 * width), out_channels = 4, kernel_size = 1, stride = 1, padding = 0)
453 |             )
454 |             self.obj_preds.append(
455 |                 nn.Conv2d(in_channels = int(256 * width), out_channels = 1, kernel_size = 1, stride = 1, padding = 0)
456 |             )
457 |
458 |     def forward(self, inputs):
459 |
460 |         outputs = []
461 |         for k, x in enumerate(inputs):
462 |             x = self.stems[k](x)
463 |             cls_feat = self.cls_convs[k](x)
464 |             cls_output = self.cls_preds[k](cls_feat)
465 |             reg_feat = self.reg_convs[k](x)
466 |             reg_output = self.reg_preds[k](reg_feat)
467 |             obj_output = self.obj_preds[k](reg_feat)
468 |             output = torch.cat([reg_output, obj_output, cls_output], 1)
469 |             outputs.append(output)
470 |         return outputs
471 |
--------------------------------------------------------------------------------
/nets/RDIAN/BaseConv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class SiLU(nn.Module):
5 |     """SiLU activation function."""
6 |     @staticmethod
7 |     def forward(x):
8 |         return x * torch.sigmoid(x)
9 |
10 | def get_activation(name="silu", inplace=True):
11 |     # If inplace is True, the input data is modified in place (reducing GPU memory use); otherwise the input is left unchanged and only a new output is produced.
12 |     if name == "silu":
13 |         module = SiLU()
14 |     elif name == "relu":
15 |         module = nn.ReLU(inplace=inplace)
16 |     elif name == "lrelu":
17 |         module = nn.LeakyReLU(0.1, inplace=inplace)
18 |     else:
19 |         raise AttributeError("Unsupported act type: {}".format(name))
20 |     return module
21 |
22 |
23 | class BaseConv(nn.Module):
24 |     """Standard convolution with normalization and activation that keeps the width and height unchanged."""
25 |     def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
26 |         """
27 |         :param in_channels: number of input channels
28 |         :param out_channels: number of output channels
29 |         :param ksize: convolution kernel size
30 |         :param stride: stride
31 |         :param groups: number of groups (for grouped convolution)
32 |         :param bias: whether to use a bias term
33 |         :param act: name of the activation function to use
34 |         """
35 |         super().__init__()
36 |         pad = (ksize - 1) // 2
37 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias)
38 |         self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
39 |         self.act = get_activation(act, inplace=True)
40 |
41 |     def forward(self, x):
42 |         return self.act(self.bn(self.conv(x)))
43 |
44 |     def fuseforward(self, x):
45 |         return self.act(self.conv(x))
--------------------------------------------------------------------------------
/nets/RDIAN/cbam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class BasicConv(nn.Module):
7 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
8 | super(BasicConv, self).__init__()
9 | self.out_channels = out_planes
10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
11 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
12 | self.relu = nn.ReLU() if relu else None
13 |
14 | def forward(self, x):
15 | x = self.conv(x)
16 | if self.bn is not None:
17 | x = self.bn(x)
18 | if self.relu is not None:
19 | x = self.relu(x)
20 | return x
21 |
22 | class Flatten(nn.Module):
23 | def forward(self, x):
24 | return x.view(x.size(0), -1)
25 |
26 | class ChannelGate(nn.Module):
27 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
28 | super(ChannelGate, self).__init__()
29 | self.gate_channels = gate_channels
30 | self.mlp = nn.Sequential(
31 | Flatten(),
32 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
33 | nn.ReLU(),
34 | nn.Linear(gate_channels // reduction_ratio, gate_channels)
35 | )
36 | self.pool_types = pool_types
37 | def forward(self, x):
38 | channel_att_sum = None
39 | for pool_type in self.pool_types:
40 | if pool_type=='avg':
41 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
42 | channel_att_raw = self.mlp( avg_pool )
43 | elif pool_type=='max':
44 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
45 | channel_att_raw = self.mlp( max_pool )
46 | elif pool_type=='lp':
47 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
48 | channel_att_raw = self.mlp( lp_pool )
49 | elif pool_type=='lse':
50 | # LSE pool only
51 | lse_pool = logsumexp_2d(x)
52 | channel_att_raw = self.mlp( lse_pool )
53 |
54 | if channel_att_sum is None:
55 | channel_att_sum = channel_att_raw
56 | else:
57 | channel_att_sum = channel_att_sum + channel_att_raw
58 |
59 | scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
60 | return x * scale
61 |
62 | def logsumexp_2d(tensor):
63 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
64 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
65 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
66 | return outputs
67 |
68 | class ChannelPool(nn.Module):
69 | def forward(self, x):
70 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
71 |
72 | class SpatialGate(nn.Module):
73 | def __init__(self):
74 | super(SpatialGate, self).__init__()
75 | kernel_size = 7
76 | self.compress = ChannelPool()
77 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
78 | def forward(self, x):
79 | x_compress = self.compress(x)
80 | x_out = self.spatial(x_compress)
81 | scale = F.sigmoid(x_out)
82 | return x * scale
83 |
84 | class CBAM(nn.Module):
85 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
86 | super(CBAM, self).__init__()
87 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
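88 | self.SpatialGate = None if no_spatial else SpatialGate()  # honor the no_spatial flag
89 | def forward(self, x):
90 | x_out = self.ChannelGate(x)
91 | x_out = x_out if self.SpatialGate is None else self.SpatialGate(x_out)  # skip spatial attention when disabled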
92 | return x_out
93 |
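The `logsumexp_2d` helper above subtracts the per-channel maximum before exponentiating. The following is a minimal sketch of why that shift matters, written in plain Python rather than PyTorch so it runs without dependencies. The function name `logsumexp` and the sample values are illustrative assumptions, not part of the repository.

```python
import math

def logsumexp(values):
    """Stable log(sum(exp(v))) using the same max-shift idea as logsumexp_2d."""
    s = max(values)
    # After the shift every exponent is <= 0, so exp() cannot overflow.
    return s + math.log(sum(math.exp(v - s) for v in values))

# The naive formula overflows for large activations.
naive_overflows = False
try:
    math.log(sum(math.exp(v) for v in [1000.0, 1000.0]))
except OverflowError:
    naive_overflows = True

# The shifted formula returns 1000 + log(2) without overflow.
stable = logsumexp([1000.0, 1000.0])
```

The identity used is log(sum(exp(v))) = s + log(sum(exp(v - s))) for any constant s. Choosing s as the maximum keeps the largest term at exp(0) = 1.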
--------------------------------------------------------------------------------
/nets/RDIAN/direction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class Conv_d11(nn.Module):
7 | def __init__(self):
8 | super(Conv_d11, self).__init__()
9 | kernel = [[-1, 0, 0, 0, 0],
10 | [0, 0, 0,0,0],
11 | [0, 0, 1,0,0],
12 | [0, 0, 0,0,0],
13 | [0,0,0,0,0]]
14 |
15 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
16 | self.weight = nn.Parameter(data=kernel, requires_grad=False)
17 |
18 | def forward(self, input):
19 | # input: [B, 1, H, W]; padding=2 keeps H and W unchanged for the 5x5 kernel
20 | return F.conv2d(input, self.weight, padding=2)
21 |
22 | class Conv_d12(nn.Module):
23 | def __init__(self):
24 | super(Conv_d12, self).__init__()
25 | kernel = [[0, 0, -1, 0, 0],
26 | [0, 0, 0,0,0],
27 | [0, 0, 1,0,0],
28 | [0, 0, 0,0,0],
29 | [0,0,0,0,0]]
30 |
31 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
32 | self.weight = nn.Parameter(data=kernel, requires_grad=False)
33 |
34 | def forward(self, input):
35 | return F.conv2d(input, self.weight, padding=2)
36 |
37 |
38 | class Conv_d13(nn.Module):
39 | def __init__(self):
40 | super(Conv_d13, self).__init__()
41 | kernel = [[0, 0, 0, 0, -1],
42 | [0, 0, 0,0,0],
43 | [0, 0, 1,0,0],
44 | [0, 0, 0,0,0],
45 | [0,0,0,0,0]]
46 |
47 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
48 | self.weight = nn.Parameter(data=kernel, requires_grad=False)
49 |
50 | def forward(self, input):
51 | return F.conv2d(input, self.weight, padding=2)
52 |
53 |
54 | class Conv_d14(nn.Module):
55 | def __init__(self):
56 | super(Conv_d14, self).__init__()
57 | kernel = [[0, 0, 0, 0, 0],
58 | [0, 0, 0,0,0],
59 | [0, 0, 1,0,-1],
60 | [0, 0, 0,0,0],
61 | [0,0,0,0,0]]
62 |
63 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
64 | self.weight = nn.Parameter(data=kernel, requires_grad=False)
65 |
66 | def forward(self, input):
67 | return F.conv2d(input, self.weight, padding=2)
68 |
69 |
70 | class Conv_d15(nn.Module):
71 | def __init__(self):
72 | super(Conv_d15, self).__init__()
73 | kernel = [[0, 0, 0, 0, 0],
74 | [0, 0, 0,0,0],
75 | [0, 0, 1,0,0],
76 | [0, 0, 0,0,0],
77 | [0,0,0,0,-1]]
78 |
79 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
80 | self.weight = nn.Parameter(data=kernel, requires_grad=False)
81 |
82 | def forward(self, input):
83 | return F.conv2d(input, self.weight, padding=2)
84 |
85 | class Conv_d16(nn.Module):
86 | def __init__(self):
87 | super(Conv_d16, self).__init__()
88 | kernel = [[0, 0, 0, 0, 0],
89 | [0, 0, 0,0,0],
90 | [0, 0, 1,0,0],
91 | [0, 0, 0,0,0],
92 | [0,0,-1,0,0]]
93 |
94 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
95 | self.weight = nn.Parameter(data=kernel, requires_grad=False)
96 |
97 | def forward(self, input):
98 | return F.conv2d(input, self.weight, padding=2)
99 |
100 | class Conv_d17(nn.Module):
101 | def __init__(self):
102 | super(Conv_d17, self).__init__()
103 | kernel = [[0, 0, 0, 0, 0],
104 | [0, 0, 0,0,0],
105 | [0, 0, 1,0,0],
106 | [0, 0, 0,0,0],
107 | [-1,0,0,0,0]]
108 |
109 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
110 | self.weight = nn.Parameter(data=kernel, requires_grad=False)
111 |
112 | def forward(self, input):
113 | return F.conv2d(input, self.weight, padding=2)
114 |
115 | class Conv_d18(nn.Module):
116 | def __init__(self):
117 | super(Conv_d18, self).__init__()
118 | kernel = [[0, 0, 0, 0, 0],
119 | [0, 0, 0,0,0],
120 | [-1, 0, 1,0,0],
121 | [0, 0, 0,0,0],
122 | [0,0,0,0,0]]
123 |
124 | kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
125 | self.weight = nn.Parameter(data=kernel, requires_grad=False)
126 |
127 | def forward(self, input):
128 | return F.conv2d(input, self.weight, padding=2)
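Each `Conv_d1x` module above is a fixed 5x5 kernel with `+1` at the center and `-1` at one position two pixels away. Because `F.conv2d` computes a cross-correlation with zero padding, the output is the center pixel minus that neighbor. The sketch below reproduces this in plain Python and combines opposite directions the way `RDIAN.forward` builds `md` before the sigmoid. The names `directional_diff`, `OFFSETS`, and `md_response` are illustrative assumptions, and the offsets are read off the kernels in this file.

```python
def directional_diff(img, dy, dx):
    """out[i][j] = img[i][j] - img[i+dy][j+dx], with zeros outside the image."""
    h, w = len(img), len(img[0])

    def at(i, j):
        return img[i][j] if 0 <= i < h and 0 <= j < w else 0.0

    return [[img[i][j] - at(i + dy, j + dx) for j in range(w)] for i in range(h)]

# (dy, dx) of the -1 tap relative to the center, for Conv_d11 .. Conv_d18.
OFFSETS = {
    "d11": (-2, -2), "d12": (-2, 0), "d13": (-2, 2), "d14": (0, 2),
    "d15": (2, 2), "d16": (2, 0), "d17": (2, -2), "d18": (0, -2),
}

# Opposite directions are multiplied, as in the md term of RDIAN.forward.
PAIRS = [("d11", "d15"), ("d12", "d16"), ("d13", "d17"), ("d14", "d18")]

def md_response(img):
    maps = {k: directional_diff(img, dy, dx) for k, (dy, dx) in OFFSETS.items()}
    h, w = len(img), len(img[0])
    return [[sum(maps[a][i][j] * maps[b][i][j] for a, b in PAIRS)
             for j in range(w)] for i in range(h)]

# A single bright pixel on a flat background.
img = [[0.0] * 5 for _ in range(5)]
img[2][2] = 1.0
md = md_response(img)
```

For this input the response is 4 at the bright pixel and 0 everywhere else. A product of two opposite differences is non-zero only where the center differs from both neighbors along that axis, which is the behavior expected of a small isolated target.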
--------------------------------------------------------------------------------
/nets/RDIAN/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import cv2
5 |
6 | from .cbam import *
7 | from .direction import *
8 | from .BaseConv import BaseConv
9 |
10 | def conv_batch(in_num, out_num, kernel_size=3, padding=1, stride=1):
11 | return nn.Sequential(
12 | nn.Conv2d(in_num, out_num, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
13 | nn.BatchNorm2d(out_num),
14 | nn.LeakyReLU())
15 |
16 | class NewBlock(nn.Module):
17 | def __init__(self, in_channels, stride,kernel_size,padding):
18 | super(NewBlock, self).__init__()
19 | reduced_channels = in_channels // 2
20 | self.layer1 = conv_batch(in_channels, reduced_channels, kernel_size=kernel_size, padding=padding, stride=stride)
21 | self.layer2 = conv_batch(reduced_channels, in_channels, kernel_size=kernel_size, padding=padding, stride=stride)
22 |
23 | def forward(self, x):
24 | residual = x
25 | out = self.layer1(x)
26 | out = self.layer2(out)
27 | out += residual
28 | return out
29 |
30 | class RDIAN(nn.Module):
31 | def __init__(self):
32 |
33 | super(RDIAN, self).__init__()
34 | accumulate_params = "none"
35 | self.conv1 = conv_batch(1, 16)
36 | self.conv2 = conv_batch(16, 32, stride=2)
37 | self.residual_block0 = self.make_layer(NewBlock, in_channels=32, num_blocks=1, kernel_size=1,padding=0,stride=1)
38 | self.residual_block1 = self.make_layer(NewBlock, in_channels=32, num_blocks=2, kernel_size=3,padding=1,stride=1)
39 | self.residual_block2 = self.make_layer(NewBlock, in_channels=32, num_blocks=2, kernel_size=5,padding=2,stride=1)
40 | self.residual_block3 = self.make_layer(NewBlock, in_channels=32, num_blocks=2, kernel_size=7,padding=3,stride=1)
41 | self.cbam = CBAM(32, 32)
42 | self.conv_cat = conv_batch(4*32, 32, 3, padding=1)
43 | self.conv_res = conv_batch(16, 32, 1, padding=0)
44 | self.relu = nn.ReLU(True)
45 |
46 | self.d11=Conv_d11()
47 | self.d12=Conv_d12()
48 | self.d13=Conv_d13()
49 | self.d14=Conv_d14()
50 | self.d15=Conv_d15()
51 | self.d16=Conv_d16()
52 | self.d17=Conv_d17()
53 | self.d18=Conv_d18()
54 |
55 | def forward(self, x):
56 |
57 | x = x[:,-1,:,:].unsqueeze(1)
58 |
59 | _, _, hei, wid = x.shape
60 | d11 = self.d11(x)
61 | d12 = self.d12(x)
62 | d13 = self.d13(x)
63 | d14 = self.d14(x)
64 | d15 = self.d15(x)
65 | d16 = self.d16(x)
66 | d17 = self.d17(x)
67 | d18 = self.d18(x)
68 | md = d11.mul(d15) + d12.mul(d16) + d13.mul(d17) + d14.mul(d18)
69 | md = torch.sigmoid(md)
70 |
71 | out1= self.conv1(x)
72 | out2 = out1.mul(md)
73 | out = self.conv2(out1 + out2)
74 |
75 | c0 = self.residual_block0(out)
76 | c1 = self.residual_block1(out)
77 | c2 = self.residual_block2(out)
78 | c3 = self.residual_block3(out)
79 |
80 | x_cat = self.conv_cat(torch.cat((c0, c1, c2, c3), dim=1))
81 | x_a = self.cbam(x_cat)
82 |
83 | temp = F.interpolate(x_a, size=[hei, wid], mode='bilinear', align_corners=False)
84 | temp2 = self.conv_res(out1)
85 | x_new = self.relu( temp + temp2)
86 |
87 | return x_new
88 |
89 | def make_layer(self, block, in_channels, num_blocks, stride, kernel_size, padding):
90 | layers = []
91 | for i in range(0, num_blocks):
92 | layers.append(block(in_channels, stride, kernel_size, padding))
93 | return nn.Sequential(*layers)
94 |
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/nets/darknet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Copyright (c) Megvii, Inc. and its affiliates.
4 |
5 | import torch
6 | from torch import nn
7 |
8 | class SiLU(nn.Module):
9 | @staticmethod
10 | def forward(x):
11 | return x * torch.sigmoid(x)
12 |
13 | def get_activation(name="silu", inplace=True):
14 | if name == "silu":
15 | module = SiLU()
16 | elif name == "relu":
17 | module = nn.ReLU(inplace=inplace)
18 | elif name == "lrelu":
19 | module = nn.LeakyReLU(0.1, inplace=inplace)
20 | elif name == "sigmoid":
21 | module = nn.Sigmoid()
22 | else:
23 | raise AttributeError("Unsupported act type: {}".format(name))
24 | return module
25 |
26 | class Focus(nn.Module):
27 | def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
28 | super().__init__()
29 | self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
30 |
31 | def forward(self, x):
32 | patch_top_left = x[..., ::2, ::2]
33 | patch_bot_left = x[..., 1::2, ::2]
34 | patch_top_right = x[..., ::2, 1::2]
35 | patch_bot_right = x[..., 1::2, 1::2]
36 | x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,)
37 | return self.conv(x)
38 |
39 | class BaseConv(nn.Module):
40 | def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
41 | super().__init__()
42 | pad = (ksize - 1) // 2
43 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias)
44 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
45 | self.act = get_activation(act, inplace=True)
46 |
47 | def forward(self, x):
48 | return self.act(self.bn(self.conv(x)))
49 |
50 | def fuseforward(self, x):
51 | return self.act(self.conv(x))
52 |
53 | class DWConv(nn.Module):
54 | def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
55 | super().__init__()
56 | self.dconv = BaseConv(in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act,)
57 | self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)
58 |
59 | def forward(self, x):
60 | x = self.dconv(x)
61 | return self.pconv(x)
62 |
63 | class SPPBottleneck(nn.Module):
64 | def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
65 | super().__init__()
66 | hidden_channels = in_channels // 2
67 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
68 | self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes])
69 | conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
70 | self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
71 |
72 | def forward(self, x):
73 | x = self.conv1(x)
74 | x = torch.cat([x] + [m(x) for m in self.m], dim=1)
75 | x = self.conv2(x)
76 | return x
77 |
78 |
79 | class Bottleneck(nn.Module):
80 | # Standard bottleneck
81 | def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
82 | super().__init__()
83 | hidden_channels = int(out_channels * expansion)
84 | Conv = DWConv if depthwise else BaseConv
85 |
86 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
87 | self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
88 | self.use_add = shortcut and in_channels == out_channels
89 |
90 | def forward(self, x):
91 | y = self.conv2(self.conv1(x))
92 | if self.use_add:
93 | y = y + x
94 | return y
95 |
96 | class CSPLayer(nn.Module):
97 | def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
98 | # ch_in, ch_out, number, shortcut, groups, expansion
99 | super().__init__()
100 | hidden_channels = int(out_channels * expansion)
101 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
102 | self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
103 | self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
104 | module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)]
105 | self.m = nn.Sequential(*module_list)
106 |
107 | def forward(self, x):
108 |
109 | x_1 = self.conv1(x)
110 | x_2 = self.conv2(x)
111 | x_1 = self.m(x_1)
112 | x = torch.cat((x_1, x_2), dim=1)
113 | return self.conv3(x)
114 |
115 | class CSPDarknet(nn.Module):
116 | def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False, act="silu",):
117 | super().__init__()
118 | assert out_features, "please provide output features of Darknet"
119 | self.out_features = out_features
120 | Conv = DWConv if depthwise else BaseConv
121 | base_channels = int(wid_mul * 64) # 64
122 | base_depth = max(round(dep_mul * 3), 1) # 3
123 |
124 | self.stem = Focus(3, base_channels, ksize=3, act=act)
125 | self.dark2 = nn.Sequential(
126 | Conv(base_channels, base_channels * 2, 3, 2, act=act),
127 | CSPLayer(base_channels * 2, base_channels * 2, n=base_depth, depthwise=depthwise, act=act),
128 | )
129 | self.dark3 = nn.Sequential(
130 | Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
131 | CSPLayer(base_channels * 4, base_channels * 4, n=base_depth * 3, depthwise=depthwise, act=act),
132 | )
133 | self.dark4 = nn.Sequential(
134 | Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
135 | CSPLayer(base_channels * 8, base_channels * 8, n=base_depth * 3, depthwise=depthwise, act=act),
136 | )
137 | self.dark5 = nn.Sequential(
138 | Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
139 | SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
140 | CSPLayer(base_channels * 16, base_channels * 16, n=base_depth, shortcut=False, depthwise=depthwise, act=act),
141 | )
142 |
143 | def forward(self, x):
144 | outputs = {}
145 | x = self.stem(x)
146 | outputs["stem"] = x
147 | x = self.dark2(x)
148 | outputs["dark2"] = x
149 | x = self.dark3(x)
150 | outputs["dark3"] = x
151 | x = self.dark4(x)
152 | outputs["dark4"] = x
153 | x = self.dark5(x)
154 | outputs["dark5"] = x
155 | return {k: v for k, v in outputs.items() if k in self.out_features}
156 |
157 |
158 | if __name__ == '__main__':
159 | print(CSPDarknet(1, 1))
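The `Focus` stem above halves the spatial resolution by moving every 2x2 pixel neighborhood into the channel dimension. The following plain-Python sketch mirrors the four strided slices for a single H x W plane. The function name `focus_patches` and the 4x4 sample are illustrative assumptions. The real module works on batched tensors and then applies a `BaseConv`.

```python
def focus_patches(plane):
    """Split one HxW plane into four (H/2)x(W/2) planes, in Focus concat order."""
    top_left = [row[0::2] for row in plane[0::2]]    # x[..., ::2, ::2]
    bot_left = [row[0::2] for row in plane[1::2]]    # x[..., 1::2, ::2]
    top_right = [row[1::2] for row in plane[0::2]]   # x[..., ::2, 1::2]
    bot_right = [row[1::2] for row in plane[1::2]]   # x[..., 1::2, 1::2]
    return [top_left, bot_left, top_right, bot_right]

plane = [
    [0, 1, 2, 3],
    [4, 5, 6, 7],
    [8, 9, 10, 11],
    [12, 13, 14, 15],
]
patches = focus_patches(plane)
```

No pixel is discarded: the four 2x2 outputs together hold all 16 input values. This is why the following convolution in `Focus` takes `in_channels * 4` input channels.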
--------------------------------------------------------------------------------
/nets/training.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Copyright (c) Megvii, Inc. and its affiliates.
4 | import math
5 | from copy import deepcopy
6 | from functools import partial
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torchvision.ops.focal_loss import sigmoid_focal_loss
12 |
13 |
14 | class IOUloss(nn.Module):
15 | def __init__(self, reduction="none", loss_type="iou"):
16 | super(IOUloss, self).__init__()
17 | self.reduction = reduction
18 | self.loss_type = loss_type
19 |
20 | def forward(self, pred, target):
21 | assert pred.shape[0] == target.shape[0]
22 |
23 | pred = pred.view(-1, 4)
24 | target = target.view(-1, 4)
25 | tl = torch.max(
26 | (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
27 | )
28 | br = torch.min(
29 | (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
30 | )
31 |
32 | area_p = torch.prod(pred[:, 2:], 1)
33 | area_g = torch.prod(target[:, 2:], 1)
34 |
35 | en = (tl < br).type(tl.type()).prod(dim=1)
36 | area_i = torch.prod(br - tl, 1) * en
37 | area_u = area_p + area_g - area_i
38 | iou = (area_i) / (area_u + 1e-16)
39 |
40 | if self.loss_type == "iou":
41 | loss = 1 - iou ** 2
42 | elif self.loss_type == "giou":
43 | c_tl = torch.min(
44 | (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
45 | )
46 | c_br = torch.max(
47 | (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
48 | )
49 | area_c = torch.prod(c_br - c_tl, 1)
50 | giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
51 | loss = 1 - giou.clamp(min=-1.0, max=1.0)
52 | elif self.loss_type == 'ciou':
53 | b1_cxy = pred[:,:2]
54 | b2_cxy = target[:,:2]
55 | center_distance = torch.sum(torch.pow((b1_cxy - b2_cxy), 2), axis=-1)
56 | enclose_mins = torch.min((pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2))
57 | enclose_maxes = torch.max((pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2))
58 | enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(br))
59 | enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1)
60 | ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6)
61 | v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(pred[:, 2]/torch.clamp(pred[:, 3],min = 1e-6)) - torch.atan(target[:, 2]/torch.clamp(target[:, 3],min = 1e-6))), 2)
62 | alpha = v / torch.clamp((1.0 - iou + v),min=1e-6)
63 | ciou = ciou - alpha * v
64 | loss = 1 - ciou.clamp(min=-1.0, max=1.0)
65 |
66 | if self.reduction == "mean":
67 | loss = loss.mean()
68 | elif self.reduction == "sum":
69 | loss = loss.sum()
70 |
71 | return loss
72 |
73 | class YOLOLoss(nn.Module):
74 | def __init__(self, num_classes, fp16, strides=(8, 16, 32)):
75 | super().__init__()
76 | self.num_classes = num_classes
77 | self.strides = strides
78 |
79 | self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
80 | self.iou_loss = IOUloss(reduction="none")
81 | self.grids = [torch.zeros(1)] * len(strides)
82 | self.fp16 = fp16
83 |
84 | def forward(self, inputs, labels=None):
85 | outputs = []
86 | x_shifts = []
87 | y_shifts = []
88 | expanded_strides = []
89 |
90 | #-----------------------------------------------#
91 | # inputs [[batch_size, num_classes + 5, 80, 80] (stride 8)
92 | # [batch_size, num_classes + 5, 40, 40] (stride 16)
93 | # [batch_size, num_classes + 5, 20, 20]] (stride 32)
94 | # outputs [[batch_size, 6400, num_classes + 5]
95 | # [batch_size, 1600, num_classes + 5]
96 | # [batch_size, 400, num_classes + 5]]
97 | # x_shifts [[1, 6400]
98 | # [1, 1600]
99 | # [1, 400]]
100 | #-----------------------------------------------#
101 | for k, (stride, output) in enumerate(zip(self.strides, inputs)):
102 | output, grid = self.get_output_and_grid(output, k, stride)
103 | x_shifts.append(grid[:, :, 0])
104 | y_shifts.append(grid[:, :, 1])
105 | expanded_strides.append(torch.ones_like(grid[:, :, 0]) * stride)
106 |
107 | outputs.append(output)
108 |
109 | return self.get_losses(x_shifts, y_shifts, expanded_strides, labels, torch.cat(outputs, 1))
110 |
111 | def get_output_and_grid(self, output, k, stride):
112 | grid = self.grids[k]
113 | hsize, wsize = output.shape[-2:]
114 | if grid.shape[2:4] != output.shape[2:4]:
115 | yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing='ij')
116 | grid = torch.stack((xv, yv), 2).view(1, hsize, wsize, 2).type(output.type())
117 | self.grids[k] = grid
118 | grid = grid.view(1, -1, 2)
119 |
120 | output = output.flatten(start_dim=2).permute(0, 2, 1)
121 | output[..., :2] = (output[..., :2] + grid.type_as(output)) * stride
122 | output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
123 | return output, grid
124 |
125 | def get_losses(self, x_shifts, y_shifts, expanded_strides, labels, outputs):
126 | #-----------------------------------------------#
127 | # [batch, n_anchors_all, 4]
128 | #-----------------------------------------------#
129 | bbox_preds = outputs[:, :, :4] #4, 4096, 4
130 |
131 | #-----------------------------------------------#
132 | # [batch, n_anchors_all, 1]
133 | #-----------------------------------------------#
134 | obj_preds = outputs[:, :, 4:5]
135 |
136 | #-----------------------------------------------#
137 | # [batch, n_anchors_all, n_cls]
138 | #-----------------------------------------------#
139 | cls_preds = outputs[:, :, 5:]
140 |
141 | total_num_anchors = outputs.shape[1]
142 | #-----------------------------------------------#
143 | # x_shifts [1, n_anchors_all]
144 | # y_shifts [1, n_anchors_all]
145 | # expanded_strides [1, n_anchors_all]
146 | #-----------------------------------------------#
147 | x_shifts = torch.cat(x_shifts, 1).type_as(outputs)
148 | y_shifts = torch.cat(y_shifts, 1).type_as(outputs)
149 | expanded_strides = torch.cat(expanded_strides, 1).type_as(outputs)
150 |
151 | cls_targets = []
152 | reg_targets = []
153 | obj_targets = []
154 | fg_masks = []
155 |
156 | num_fg = 0.0
157 | for batch_idx in range(outputs.shape[0]):
158 | num_gt = len(labels[batch_idx])
159 | if num_gt == 0:
160 | cls_target = outputs.new_zeros((0, self.num_classes))
161 | reg_target = outputs.new_zeros((0, 4))
162 | obj_target = outputs.new_zeros((total_num_anchors, 1))
163 | fg_mask = outputs.new_zeros(total_num_anchors).bool()
164 | else:
165 | #-----------------------------------------------#
166 | # gt_bboxes_per_image [num_gt, 4]
167 | # gt_classes [num_gt]
168 | # bboxes_preds_per_image [n_anchors_all, 4]
169 | # cls_preds_per_image [n_anchors_all, num_classes]
170 | # obj_preds_per_image [n_anchors_all, 1]
171 | #-----------------------------------------------#
172 | gt_bboxes_per_image = labels[batch_idx][..., :4].type_as(outputs)
173 | gt_classes = labels[batch_idx][..., 4].type_as(outputs)
174 | bboxes_preds_per_image = bbox_preds[batch_idx]
175 | cls_preds_per_image = cls_preds[batch_idx]
176 | obj_preds_per_image = obj_preds[batch_idx]
177 |
178 | gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = self.get_assignments(
179 | num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image,
180 | expanded_strides, x_shifts, y_shifts,
181 | )
182 | torch.cuda.empty_cache()
183 | num_fg += num_fg_img
184 | cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes).float() * pred_ious_this_matching.unsqueeze(-1)
185 | obj_target = fg_mask.unsqueeze(-1)
186 | reg_target = gt_bboxes_per_image[matched_gt_inds]
187 | cls_targets.append(cls_target)
188 | reg_targets.append(reg_target)
189 | obj_targets.append(obj_target.type(cls_target.type()))
190 | fg_masks.append(fg_mask)
191 |
192 | cls_targets = torch.cat(cls_targets, 0)
193 | reg_targets = torch.cat(reg_targets, 0)
194 | obj_targets = torch.cat(obj_targets, 0)
195 | fg_masks = torch.cat(fg_masks, 0)
196 |
197 | num_fg = max(num_fg, 1)
198 | loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum()
199 | loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum()
200 | loss_cls = (self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
201 | # loss_obj = (sigmoid_focal_loss(obj_preds.view(-1, 1), obj_targets)).sum()
202 | # loss_cls = (sigmoid_focal_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets)).sum()
203 | reg_weight = 5.0
204 | loss = reg_weight * loss_iou + loss_obj + loss_cls
205 |
206 | return loss / num_fg
207 |
208 | @torch.no_grad()
209 | def get_assignments(self, num_gt, total_num_anchors, gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, cls_preds_per_image, obj_preds_per_image, expanded_strides, x_shifts, y_shifts):
210 | #-------------------------------------------------------#
211 | # fg_mask [n_anchors_all]
212 | # is_in_boxes_and_center [num_gt, len(fg_mask)]
213 | #-------------------------------------------------------#
214 | fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt)
215 |
216 | #-------------------------------------------------------#
217 | # fg_mask [n_anchors_all]
218 | # bboxes_preds_per_image [fg_mask, 4]
219 | # cls_preds_ [fg_mask, num_classes]
220 | # obj_preds_ [fg_mask, 1]
221 | #-------------------------------------------------------#
222 | bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
223 | cls_preds_ = cls_preds_per_image[fg_mask]
224 | obj_preds_ = obj_preds_per_image[fg_mask]
225 | num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
226 |
227 | #-------------------------------------------------------#
228 | # pair_wise_ious [num_gt, fg_mask]
229 | #-------------------------------------------------------#
230 | pair_wise_ious = self.bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
231 | pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
232 |
233 | #-------------------------------------------------------#
234 | # cls_preds_ [num_gt, fg_mask, num_classes]
235 | # gt_cls_per_image [num_gt, fg_mask, num_classes]
236 | #-------------------------------------------------------#
237 | if self.fp16:
238 | with torch.cuda.amp.autocast(enabled=False):
239 | cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
240 | gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
241 | pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
242 | else:
243 | cls_preds_ = cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * obj_preds_.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
244 | gt_cls_per_image = F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, num_in_boxes_anchor, 1)
245 | pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
246 | del cls_preds_
247 |
248 | cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center).float()
249 |
250 | num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
251 | del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
252 | return gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg
253 |
254 | def bboxes_iou(self, bboxes_a, bboxes_b, xyxy=True):
255 | if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
256 | raise IndexError
257 |
258 | if xyxy:
259 | tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
260 | br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
261 | area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
262 | area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
263 | else:
264 | tl = torch.max(
265 | (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
266 | (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
267 | )
268 | br = torch.min(
269 | (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
270 | (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
271 | )
272 |
273 | area_a = torch.prod(bboxes_a[:, 2:], 1)
274 | area_b = torch.prod(bboxes_b[:, 2:], 1)
275 | en = (tl < br).type(tl.type()).prod(dim=2)
276 | area_i = torch.prod(br - tl, 2) * en
277 | return area_i / (area_a[:, None] + area_b - area_i)
278 |
279 | def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts, total_num_anchors, num_gt, center_radius = 2.5):
280 | #-------------------------------------------------------#
281 | # expanded_strides_per_image [n_anchors_all]
282 | # x_centers_per_image [num_gt, n_anchors_all]
283 | # y_centers_per_image [num_gt, n_anchors_all]
284 | #-------------------------------------------------------#
285 | expanded_strides_per_image = expanded_strides[0]
286 | x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
287 | y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
288 |
289 | #-------------------------------------------------------#
290 | # gt_bboxes_per_image_x [num_gt, n_anchors_all]
291 | #-------------------------------------------------------#
292 | gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
293 | gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
294 | gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
295 | gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
296 |
297 | #-------------------------------------------------------#
298 | # bbox_deltas [num_gt, n_anchors_all, 4]
299 | #-------------------------------------------------------#
300 | b_l = x_centers_per_image - gt_bboxes_per_image_l
301 | b_r = gt_bboxes_per_image_r - x_centers_per_image
302 | b_t = y_centers_per_image - gt_bboxes_per_image_t
303 | b_b = gt_bboxes_per_image_b - y_centers_per_image
304 | bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
305 |
306 | #-------------------------------------------------------#
307 | # is_in_boxes [num_gt, n_anchors_all]
308 | # is_in_boxes_all [n_anchors_all]
309 | #-------------------------------------------------------#
310 | is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
311 | is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
312 |
313 | gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
314 | gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
315 | gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) - center_radius * expanded_strides_per_image.unsqueeze(0)
316 | gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(1, total_num_anchors) + center_radius * expanded_strides_per_image.unsqueeze(0)
317 |
318 | #-------------------------------------------------------#
319 | # center_deltas [num_gt, n_anchors_all, 4]
320 | #-------------------------------------------------------#
321 | c_l = x_centers_per_image - gt_bboxes_per_image_l
322 | c_r = gt_bboxes_per_image_r - x_centers_per_image
323 | c_t = y_centers_per_image - gt_bboxes_per_image_t
324 | c_b = gt_bboxes_per_image_b - y_centers_per_image
325 | center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
326 |
327 | #-------------------------------------------------------#
328 | # is_in_centers [num_gt, n_anchors_all]
329 | # is_in_centers_all [n_anchors_all]
330 | #-------------------------------------------------------#
331 | is_in_centers = center_deltas.min(dim=-1).values > 0.0
332 | is_in_centers_all = is_in_centers.sum(dim=0) > 0
333 |
334 | #-------------------------------------------------------#
335 | # is_in_boxes_anchor [n_anchors_all]
336 | # is_in_boxes_and_center [num_gt, is_in_boxes_anchor]
337 | #-------------------------------------------------------#
338 | is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
339 | is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
340 | return is_in_boxes_anchor, is_in_boxes_and_center
341 |
342 | def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
343 | #-------------------------------------------------------#
344 | # cost [num_gt, fg_mask]
345 | # pair_wise_ious [num_gt, fg_mask]
346 | # gt_classes [num_gt]
347 | # fg_mask [n_anchors_all]
348 | # matching_matrix [num_gt, fg_mask]
349 | #-------------------------------------------------------#
350 | matching_matrix = torch.zeros_like(cost)
351 |
352 | #------------------------------------------------------------#
353 | # topk_ious [num_gt, n_candidate_k]
354 | # dynamic_ks [num_gt]
355 | # matching_matrix [num_gt, fg_mask]
356 | #------------------------------------------------------------#
357 | n_candidate_k = min(10, pair_wise_ious.size(1))
358 | topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
359 | dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
360 |
361 | for gt_idx in range(num_gt):
362 | _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
363 | matching_matrix[gt_idx][pos_idx] = 1.0
364 | del topk_ious, dynamic_ks, pos_idx
365 |
366 | #------------------------------------------------------------#
367 | # anchor_matching_gt [fg_mask]
368 | #------------------------------------------------------------#
369 | anchor_matching_gt = matching_matrix.sum(0)
370 | if (anchor_matching_gt > 1).sum() > 0:
371 | _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
372 | matching_matrix[:, anchor_matching_gt > 1] *= 0.0
373 | matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
374 | #------------------------------------------------------------#
375 | # fg_mask_inboxes [fg_mask]
376 | #------------------------------------------------------------#
377 | fg_mask_inboxes = matching_matrix.sum(0) > 0.0
378 | num_fg = fg_mask_inboxes.sum().item()
379 | fg_mask[fg_mask.clone()] = fg_mask_inboxes
380 | matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
381 | gt_matched_classes = gt_classes[matched_gt_inds]
382 |
383 | pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
384 | return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
385 |
386 | def is_parallel(model):
387 | # Returns True if model is of type DP or DDP
388 | return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
389 |
390 | def de_parallel(model):
391 | # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
392 | return model.module if is_parallel(model) else model
393 |
394 | def copy_attr(a, b, include=(), exclude=()):
395 | # Copy attributes from b to a, options to only include [...] and to exclude [...]
396 | for k, v in b.__dict__.items():
397 | if (len(include) and k not in include) or k.startswith('_') or k in exclude:
398 | continue
399 | else:
400 | setattr(a, k, v)
401 |
402 | class ModelEMA:
403 | """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
404 | Keeps a moving average of everything in the model state_dict (parameters and buffers)
405 | For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
406 | """
407 |
408 | def __init__(self, model, decay=0.9999, tau=2000, updates=0):
409 | # Create EMA
410 | self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
411 | # if next(model.parameters()).device.type != 'cpu':
412 | # self.ema.half() # FP16 EMA
413 | self.updates = updates # number of EMA updates
414 | self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
415 | for p in self.ema.parameters():
416 | p.requires_grad_(False)
417 |
418 | def update(self, model):
419 | # Update EMA parameters
420 | with torch.no_grad():
421 | self.updates += 1
422 | d = self.decay(self.updates)
423 |
424 | msd = de_parallel(model).state_dict() # model state_dict
425 | for k, v in self.ema.state_dict().items():
426 | if v.dtype.is_floating_point:
427 | v *= d
428 | v += (1 - d) * msd[k].detach()
429 |
430 | def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
431 | # Update EMA attributes
432 | copy_attr(self.ema, model, include, exclude)
433 |
434 | def weights_init(net, init_type='normal', init_gain = 0.02):
435 | def init_func(m):
436 | classname = m.__class__.__name__
437 | if hasattr(m, 'weight') and classname.find('Conv') != -1:
438 | if init_type == 'normal':
439 | torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
440 | elif init_type == 'xavier':
441 | torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
442 | elif init_type == 'kaiming':
443 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
444 | elif init_type == 'orthogonal':
445 | torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
446 | else:
447 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
448 | elif classname.find('BatchNorm2d') != -1:
449 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
450 | torch.nn.init.constant_(m.bias.data, 0.0)
451 | print('initialize network with %s type' % init_type)
452 | net.apply(init_func)
453 |
454 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
455 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
456 | if iters <= warmup_total_iters:
457 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
458 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
459 | elif iters >= total_iters - no_aug_iter:
460 | lr = min_lr
461 | else:
462 | lr = min_lr + 0.5 * (lr - min_lr) * (
463 | 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
464 | )
465 | return lr
466 |
467 | def step_lr(lr, decay_rate, step_size, iters):
468 | if step_size < 1:
469 | raise ValueError("step_size must be at least 1.")
470 | n = iters // step_size
471 | out_lr = lr * decay_rate ** n
472 | return out_lr
473 |
474 | if lr_decay_type == "cos":
475 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
476 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
477 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
478 | func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
479 | else:
480 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
481 | step_size = total_iters / step_num
482 | func = partial(step_lr, lr, decay_rate, step_size)
483 |
484 | return func
485 |
486 | def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
487 | lr = lr_scheduler_func(epoch)
488 | for param_group in optimizer.param_groups:
489 | param_group['lr'] = lr
490 |
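The "cos" branch of `get_lr_scheduler` above can be pictured with a small standalone sketch. It re-implements the same three phases (quadratic warmup, cosine decay, flat tail at `min_lr`) without importing the repository. The name `warm_cos_lr` and the values `lr=1e-3`, `min_lr=1e-5`, `total_iters=100` are illustrative assumptions, not settings taken from the training script.

```python
import math
from functools import partial

def warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
    # Phase 1: quadratic warmup from warmup_lr_start up to lr.
    if iters <= warmup_total_iters:
        return (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
    # Phase 3: hold the floor value for the last no_aug_iter epochs.
    if iters >= total_iters - no_aug_iter:
        return min_lr
    # Phase 2: cosine decay from lr down towards min_lr.
    progress = (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)
    return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Illustrative hyperparameters (assumptions, not the repository's settings).
lr, min_lr, total_iters = 1e-3, 1e-5, 100

# Same clamping as the "cos" branch: warmup is capped at 3 epochs,
# the flat tail at 15 epochs.
warmup_total_iters = min(max(0.05 * total_iters, 1), 3)
warmup_lr_start = max(0.1 * lr, 1e-6)
no_aug_iter = min(max(0.05 * total_iters, 1), 15)

schedule = partial(warm_cos_lr, lr, min_lr, total_iters,
                   warmup_total_iters, warmup_lr_start, no_aug_iter)

# One learning rate per epoch, as set_optimizer_lr would request them.
lrs = [schedule(epoch) for epoch in range(total_iters)]
for epoch in (0, 3, 50, 95):
    print(epoch, lrs[epoch])
```

With these values the warmup covers epochs 0 to 3 and the last 5 epochs stay at `min_lr`; printing or plotting `lrs` is a quick way to sanity-check a schedule before a long run.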
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import time
2 | import cv2
3 | import numpy as np
4 | from PIL import Image
5 | from test import get_history_imgs
6 | import colorsys
7 | import os
8 | import time
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | from PIL import ImageDraw, ImageFont
13 | from nets.LASNet import LASNet
14 | from utils.utils import (cvtColor, get_classes, preprocess_input, resize_image,
15 | show_config)
16 | from utils.utils_bbox import decode_outputs, non_max_suppression
17 |
18 |
19 | class Pred_vid(object):
20 | _defaults = {
21 |
22 | "model_path" : '/home/LASNet/logs/model.pth',
23 | "classes_path" : 'model_data/classes.txt',
24 | "input_shape" : [512, 512],
25 | "phi" : 's',
26 | "confidence" : 0.5,
27 | "nms_iou" : 0.3,
28 | "letterbox_image" : True,
29 | "cuda" : True,
30 | }
31 |
32 | @classmethod
33 | def get_defaults(cls, n):
34 | if n in cls._defaults:
35 | return cls._defaults[n]
36 | else:
37 | return "Unrecognized attribute name '" + n + "'"
38 |
39 | def __init__(self, **kwargs):
40 | self.__dict__.update(self._defaults)
41 | for name, value in kwargs.items():
42 | setattr(self, name, value)
43 |
44 | self.class_names, self.num_classes = get_classes(self.classes_path)
45 |
46 | hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
47 | self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
48 | self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
49 | self.generate()
50 |
51 | show_config(**self._defaults)
52 |
53 | def generate(self, onnx=False):
54 | self.net = LASNet(self.num_classes, num_frame=5)
55 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
56 | self.net.load_state_dict(torch.load(self.model_path, map_location=device))
57 | self.net = self.net.eval()
58 | print('{} model, and classes loaded.'.format(self.model_path))
59 | if not onnx:
60 | if self.cuda:
61 | self.net = nn.DataParallel(self.net)
62 | self.net = self.net.cuda()
63 |
64 | def detect_image(self, images, crop = False, count = False):
65 |
66 | image_shape = np.array(np.shape(images[0])[0:2])
67 |
68 | images = [cvtColor(image) for image in images]
69 | c_image = images[-1]
70 | image_data = [resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) for image in images]
71 | image_data = [np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)) for image in image_data]
72 | # per-frame (3, H, W) arrays -> stacked (3, num_frames, H, W)
73 | image_data = np.stack(image_data, axis=1)
74 |
75 | image_data = np.expand_dims(image_data, 0)
76 |
77 | with torch.no_grad():
78 | images = torch.from_numpy(image_data)
79 | if self.cuda:
80 | images = images.cuda()
81 | outputs = self.net(images)
82 | outputs = decode_outputs(outputs, self.input_shape)
83 | outputs = non_max_suppression(outputs, self.num_classes, self.input_shape,
84 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
85 |
86 | if outputs[0] is None:
87 | return c_image
88 |
89 | top_label = np.array(outputs[0][:, 6], dtype = 'int32')
90 | top_conf = outputs[0][:, 4] * outputs[0][:, 5]
91 | top_boxes = outputs[0][:, :4]
92 |
93 | font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * c_image.size[1] + 15).astype('int32'))
94 | thickness = int(max((c_image.size[0] + c_image.size[1]) // np.mean(self.input_shape), 1))
95 |
96 | if count:
97 | print("top_label:", len(top_label))
98 | classes_nums = np.zeros([self.num_classes])
99 | for i in range(self.num_classes):
100 | num = np.sum(top_label == i)
101 | if num > 0:
102 | print(self.class_names[i], " : ", num)
103 | classes_nums[i] = num
104 | print("classes_nums:", classes_nums)
105 |
106 | print(len(top_label))
107 | if crop:
108 | for i, c in list(enumerate(top_label)):
109 | top, left, bottom, right = top_boxes[i]
110 |
111 | top = max(0, np.floor(top).astype('int32'))
112 | left = max(0, np.floor(left).astype('int32'))
113 | bottom = min(c_image.size[1], np.floor(bottom).astype('int32'))
114 | right = min(c_image.size[0], np.floor(right).astype('int32'))
115 |
116 | dir_save_path = "img_crop"
117 | if not os.path.exists(dir_save_path):
118 | os.makedirs(dir_save_path)
119 | crop_image = c_image.crop([left, top, right, bottom])
120 | crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
121 | print("save crop_" + str(i) + ".png to " + dir_save_path)
122 |
123 | for i, c in list(enumerate(top_label)):
124 | predicted_class = self.class_names[int(c)]
125 | box = top_boxes[i]
126 | score = top_conf[i]
127 |
128 | top, left, bottom, right = box
129 |
130 | top = max(0, np.floor(top).astype('int32'))
131 | left = max(0, np.floor(left).astype('int32'))
132 | bottom = min(c_image.size[1], np.floor(bottom).astype('int32'))
133 | right = min(c_image.size[0], np.floor(right).astype('int32'))
134 |
135 | label = '{} {:.2f}'.format(predicted_class, score)
136 | draw = ImageDraw.Draw(c_image)
137 | label_size = draw.textbbox((125, 20),label, font)
138 | label = label.encode('utf-8')
139 |
140 |
141 | if top - label_size[1] >= 0:
142 | text_origin = np.array([left, top - label_size[1]])
143 | else:
144 | text_origin = np.array([left, top + 1])
145 |
146 | for i in range(thickness):
147 | draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
148 | del draw
149 |
150 | return c_image
151 |
152 | if __name__ == "__main__":
153 | yolo = Pred_vid()
154 |
155 | # mode = "video"
156 | mode = "predict"
157 |
158 | crop = False
159 | count = False
160 |
161 | if mode == "predict":
162 |
163 | while True:
164 | img = input('Input image filename:')
165 | try:
166 | img = get_history_imgs(img)
167 | images = [Image.open(item) for item in img]
168 | except Exception:
169 | print('Open Error! Try again!')
170 | continue
171 | else:
172 | r_image = yolo.detect_image(images, crop = crop, count=count)
173 | r_image.save("pred.png")
174 | elif mode == "video":
175 | import numpy as np
176 | from tqdm import tqdm
177 | dir_path = '/home/public/DMIST/images/test60/data6/'
178 | images = os.listdir(dir_path)
179 | for file_name in os.listdir(dir_path):
180 | if file_name.endswith('.ipynb_checkpoints'):
181 | images.remove('.ipynb_checkpoints')
182 | images = [fn for fn in images if fn.endswith("bmp")]
183 | images.sort(key=lambda x:int(x[:-4]))
184 | list_img = []
185 | for image in tqdm(images):
186 | image = dir_path+image
187 | img = get_history_imgs(image)
188 | imgs = [Image.open(item) for item in img]
189 | r_image = yolo.detect_image(imgs, crop = crop, count=count)
190 | list_img.append(cv2.cvtColor(np.asarray(r_image), cv2.COLOR_RGB2BGR))
191 |
192 | fourcc = cv2.VideoWriter_fourcc(*'MJPG')# *'XVID'
193 | outfile = cv2.VideoWriter("./output.avi", fourcc, 24, (256, 256), True)
194 |
195 | for i in list_img:
196 | outfile.write(i)
197 | if cv2.waitKey(1) == 27:
198 | break
199 | outfile.release()
200 | cv2.destroyAllWindows()
201 |
202 |
203 | else:
204 | raise AssertionError("Please specify the correct mode: 'predict' or 'video'.")
205 |
--------------------------------------------------------------------------------
/readme/PR1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UESTC-nnLab/DMIST/07bb456ae2c4b2a71a0065a30d84953cbfd38844/readme/PR1.png
--------------------------------------------------------------------------------
/readme/PR2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UESTC-nnLab/DMIST/07bb456ae2c4b2a71a0065a30d84953cbfd38844/readme/PR2.png
--------------------------------------------------------------------------------
/readme/PR3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UESTC-nnLab/DMIST/07bb456ae2c4b2a71a0065a30d84953cbfd38844/readme/PR3.png
--------------------------------------------------------------------------------
/readme/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UESTC-nnLab/DMIST/07bb456ae2c4b2a71a0065a30d84953cbfd38844/readme/vis.png
--------------------------------------------------------------------------------
/results/DMIST-100/ACM.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9770253929866989 0.9636497574110295 0.9636497574110295 0.9636497574110295 0.9636497574110295 0.9636497574110295 0.9636497574110295 0.9635400361228176 0.963394967636468 0.9627324283201929 0.9622151247767011 0.9614583677651803 0.9604543029194298 0.9588760199680398 0.9574426151966259 0.9554380385340778 0.9529741045302615 0.9507181260019412 0.9478226836402587 0.9453390326736152 0.9426294082209555 0.9395788508341858 0.9361280129088647 0.9330326691436688 0.9292451744119115 0.9262490437238863 0.9228109298255519 0.9189469404350478 0.9152700307563989 0.9113186482350285 0.9075642443916505 0.9034658802823058 0.8997393414289494 0.8957260121958517 0.8920611530668632 0.8878636030310437 0.8840027034627016 0.8801099820681411 0.876418675006264 0.8723269165822357 0.8682539682539683 0.8641129340594917 0.8601609159321641 0.856004370320383 0.8513690362250618 0.8467093360901227 0.8420059704142668 0.8374298878565948 0.83235244602197 0.827174874718608 0.8216632125903577 0.8161515925383389 0.8098833522119538 0.8035352712904386 0.7969428329508633 0.7894911504424779 0.7807058477497886 0.7713888690551385 0.7592193540673035 0.7438447478765724 0.721807671069542 0.6901407803252313 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/AGPCNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9749928140270192 0.9739622080047612 0.9739622080047612 0.9739622080047612 0.9739622080047612 0.9739622080047612 0.9739622080047612 0.9739622080047612 0.9737048665620094 0.9734979852855682 0.972941770690526 0.9719736929959271 0.9710942558505227 0.9700120123555658 0.9685013130007065 0.966816367265469 0.9653152097492383 0.9640071587971453 0.9619980554713391 0.9598406706102672 0.9572692984188794 0.9548754257035311 0.9521351157304523 0.9490490572198188 0.9460906121555941 0.9429097743797877 0.9395467804744867 0.9358062178668783 0.9322982355986883 0.9280687188657555 0.9240043573789258 0.9195488231070922 0.9154294682337847 0.9106751764263434 0.905799115063164 0.9008742740218671 0.8962208008481569 0.8913498329055068 0.8859760216143195 0.8806215082211646 0.8748683694555452 0.869310095773733 0.8630928945896102 0.857070459244005 0.8512127485060345 0.844951243841477 0.8384293791553861 0.831833422194868 0.8250270309734975 0.8187037224870677 0.8118081302095659 0.8048584833299243 0.7980214088290434 0.7909724133550695 0.7836125183132403 0.7759242199703609 0.7672396781500577 0.7572764232025968 0.7443943968717517 0.7269138594393384 0.7030110346698523 0.668892126461363 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/DNANet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9994319257716342 0.9966188524590164 0.9933693348827671 0.9882196950379928 0.9843159562797124 0.9819262862741124 0.9808007357801857 0.9794811848683065 0.9788049201263639 0.9777683803540144 0.9763606779165219 0.9752858723718185 0.9742618056951942 0.9725377158263064 0.9708011647476825 0.9687718771877187 0.9664242992846253 0.9641230320183973 0.9615774459503235 0.9590555858040885 0.9563946658647294 0.9535025235350968 0.9499816642076806 0.9464342417013377 0.9428782805077486 0.9388700008946944 0.9344739117695403 0.9302978122019991 0.9254419778467459 0.9206534274525707 0.9151565023924904 0.9093453502085938 0.9034528372282332 0.8972104173477171 0.891633570921333 0.8853337260318492 0.8791322009949663 0.8725594318063287 0.865630599870759 0.8585312887785853 0.8514044859508526 0.8436289432863554 0.8358566600573167 0.8276721938032371 0.8199240779686906 0.8117416430429745 0.8035726894019629 0.7951082365093426 0.7874312298003092 0.7797467050076156 0.7724479360502391 0.765417089642512 0.758297857254698 0.75153023397493 0.7452506057401221 0.737897202745831 0.7298777940531022 0.719566865797051 0.7069068742176019 0.6888616525453864 0.6622385915044621 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/HRNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.893255443441317 0.888963842417701 0.8879546700785962 0.8864976010966416 0.8826458295475196 0.8785295362476571 0.8744906905977616 0.8697360150848523 0.865113232637993 0.8602713935561792 0.8551731886616802 0.8502991384228311 0.8454073209036362 0.8401333151615021 0.8346219684523741 0.8295116363714219 0.8245055330228832 0.8194447055285072 0.8138504400123823 0.8087588193545694 0.8031239882179211 0.7974358590545759 0.7917667628537719 0.7852172620656109 0.779243765084473 0.7721177487691072 0.7657199742343503 0.7585853895940435 0.7520374396264605 0.7444930083367066 0.7364901987716579 0.7285106582998425 0.7204603545974602 0.7120999827019546 0.7041135494161775 0.6953371140516699 0.686017932206764 0.6763507131647905 0.6665749083298924 0.6555668358714044 0.6435787651808084 0.6307984295629199 0.6164527922087533 0.5975887801717973 0.5719680964141273 0.5325810731660671 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/ISNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9862930860713566 0.9727109156337465 0.9670103092783505 0.9624739731213325 0.9606149457167763 0.9580536912751678 0.9560963230318069 0.9541354307484331 0.9524183006535948 0.9500176201104193 0.9470764964178414 0.9437948618012927 0.9401366429977922 0.9365397092726294 0.9329455423295709 0.9292720949754167 0.9251881811032469 0.9209105794116094 0.9160751441058469 0.9122940860670792 0.9076751450245426 0.9029696425241961 0.8981232163299527 0.8930769289868545 0.8881252750476749 0.8832160818118692 0.8779516721386321 0.8726069973809693 0.8674593125598344 0.8623675184257688 0.8572570560659613 0.8518620310567542 0.8463689369747232 0.8411294194563341 0.8358696188174959 0.8303764913199819 0.8250800009208739 0.8198300939517189 0.8145736206770143 0.8093316287736362 0.803783300206855 0.7988571428571428 0.7938057482656096 0.7888364517979753 0.783982416127663 0.77877173479875 0.7734355886590372 0.7684115189337847 0.7632945959594335 0.758138646595826 0.7527122900260867 0.7477323673935957 0.7425057203273386 0.736710893264599 0.7300931648928152 0.7226445408982998 0.7127431868940499 0.6989397411684495 0.6786764461376948 0.6476516948094788 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/ISTDUNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9984898822108125 0.9977352275066914 0.9957615531856713 0.9917173679635973 0.9864419201172591 0.9804846809801476 0.9730497675219563 0.9641710830783011 0.957624884671148 0.9526495793694473 0.9487649813901305 0.9452558305955827 0.9416884767815749 0.9391177896267572 0.9365267608538723 0.9340476563535779 0.9315826393917451 0.9292246690400563 0.9265945129624968 0.9241601709711582 0.9220775688564123 0.9197070216455445 0.9171831403244525 0.9145972138098122 0.9122320587681296 0.9097509174444477 0.9072032634259034 0.9046097611408676 0.9014987644019384 0.8982059619926063 0.8947224142451908 0.8910166490364715 0.8867997803014752 0.883016703475781 0.878500313557609 0.8738660171012139 0.8694861057871932 0.8644673675112502 0.859259562543218 0.8538089277529001 0.8483597802216583 0.8424860937959484 0.8367641327102578 0.8304054370021667 0.8240069267473779 0.8176532982944222 0.8109502365615915 0.8043751319217847 0.7974527858054977 0.7902216797799193 0.7826945320298512 0.7753318883771335 0.7678601878083147 0.7604226638054863 0.7527929842609367 0.743802468864931 0.7329486920781249 0.7179935483541146 0.696433924876244 0.6658928620534773 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/LASNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9999333155508136 0.9999333155508136 0.9999333155508136 0.9999325986587133 0.9999325986587133 0.9999325986587133 0.9998867561293245 0.9998036809815951 0.9997522243495889 0.9997114175581803 0.9996916826870761 0.9996564518345472 0.9996436815442532 0.999618343021549 0.999604487745712 0.9995640075401049 0.9995069982203838 0.9994935284186832 0.9994550641635235 0.9993510841015605 0.9992981703691429 0.9992243353114341 0.999150184723005 0.9991116525800148 0.9990216551293224 0.9989293616616307 0.9988170738222253 0.9987038332658247 0.9985581055069005 0.9983134050021939 0.9980534924199141 0.9978307467509479 0.99755954461477 0.9972390409300073 0.9968794809712983 0.9965166742804934 0.9961255176630812 0.9955557477913904 0.9950904991677026 0.9944156444079454 0.9937462164609588 0.9929177364089848 0.991770919152138 0.9904356918787034 0.9888252695001997 0.9865946874750893 0.9837370616072779 0.9801658746790236 0.9753155426593948 0.9691464751979124 0.9613127208755624 0.9520509883529719 0.9413857348238125 0.9291386140372693 0.9161521766795877 0.9016966283542738 0.8869785106547053 0.872476300520581 0.8583607699887049 0.8446990287875397 0.8311370068815728 0.8183275010628904 0.8052696311310994 0.7926354606736976 0.7805745316587525 0.7659162523213323 0.7458784262162436 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/RDIAN.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9970163316582915 0.9970163316582915 0.9967132292522597 0.9958583887539325 0.9958583887539325 0.9956649371927908 0.9953856869232108 0.994855387765549 0.994556851897294 0.9940307692307693 0.9933260006338435 0.9923540354649872 0.9906877241552885 0.9886885413174351 0.9862613866225004 0.9832269688011469 0.9790074294205052 0.9741744426440185 0.969692179612827 0.9664486409954733 0.9637019982376851 0.9607505455070059 0.9582118310925309 0.9555793021248613 0.9532488800359018 0.950569940363856 0.9481205951448708 0.9451860009618936 0.941953267517974 0.9383900948788664 0.9345251062823282 0.9301555100742439 0.9254571352471831 0.9204393638726628 0.9148976364352759 0.9090482496845574 0.9026465314017449 0.8963014912870265 0.8892262930425694 0.8817997678599877 0.8740678786811639 0.8662314322703734 0.8583611557643955 0.8502966772941115 0.8424633440154007 0.8345526673634289 0.8268840953508817 0.8188251372857057 0.8116460025651988 0.8050405417535181 0.798227074122909 0.7917490638941375 0.7853796749882963 0.779541735348736 0.7737285009251755 0.7678104297604892 0.7619648293774893 0.7557828858329051 0.7489430408446237 0.740507232342178 0.7286818263664326 0.7112691904081715 0.6832439453146869 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/RISTD.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9780332346915548 0.9771434082360884 0.9746109206014244 0.9735930518600331 0.9722422783794625 0.9714314221098843 0.970394362995795 0.9701133889014042 0.9694624601516983 0.9691075743355233 0.9678367346938775 0.9673197193145435 0.9667792447037151 0.9661216946261018 0.9653651062031068 0.9643528138258824 0.9632454436983834 0.961762435204588 0.9606463600492703 0.9591246658084959 0.9573865240883334 0.9555256668908316 0.9538697357378958 0.9513703673402601 0.9491291602000345 0.9464048161259048 0.9430201279360753 0.9395841560835013 0.9356605156218772 0.9311880743940423 0.9267878638229575 0.9214309496064116 0.9156106056786671 0.9101818583826212 0.9038754812358029 0.89696311806438 0.890103121460788 0.8822542449257438 0.8736210620770679 0.8654774030986973 0.8564459569622547 0.8473808750135393 0.838191492787223 0.8282572701196333 0.8180804660695651 0.8083411042855744 0.7975019269847944 0.7870916188973499 0.7761863331283608 0.7654530166336859 0.7544788907929584 0.7429695889732001 0.731050648377845 0.7185341487941181 0.7033930754747374 0.6838016716463069 0.6588769097474992 0.6239437746008334 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/SANet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9301948051948052 0.9293735498839907 0.9258576874205845 0.9244510439164867 0.9225589225589226 0.9196693380450558 0.9171051566421112 0.9130967240526601 0.9094214944726124 0.9061371841155235 0.9032077582991421 0.899656336109477 0.8955079651551109 0.8924787787063236 0.8892029810198613 0.8861484632272228 0.8829490616621984 0.8796638214581189 0.876242620453247 0.8730626441753172 0.869875664859503 0.8665744957709824 0.8632690534666966 0.8596977833004193 0.856005655708731 0.8524206424376911 0.8490773086762854 0.845082447372029 0.8411217751061244 0.8370869702993115 0.8327038164805848 0.8284643634167072 0.8241499316851847 0.8200351563861625 0.8155404267603262 0.8111573208287977 0.8063954351748934 0.8014667456363447 0.7965833663666956 0.7914263069305313 0.7862199767178505 0.7806613549486483 0.7752937409718396 0.7696975607641711 0.7640585339858924 0.7583612751689276 0.7522605293075642 0.7458913295984702 0.7397726174555226 0.7326681148984322 0.7252043548728067 0.7170059505167554 0.7083684052203526 0.6978069625403907 0.6848423298924012 0.6677845408344167 0.6430916033201864 0.61004065605671 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/SSTNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9992678015742266 0.9983068971662805 0.9974027749299432 0.9967702245462935 0.9961359805976898 0.9955322651079983 0.9950062421972534 0.9944907110826393 0.99419002050581 0.9935020657013357 0.9928685807917179 0.9923600334248538 0.9916340362629932 0.9912599828968156 0.9908490849084909 0.9904536741214057 0.9901162720790698 0.9892323527083121 0.9884980937550341 0.9876859091326109 0.987011602322403 0.9861725561334987 0.9852274234186448 0.9841933354769923 0.9830412157719128 0.9818362931718781 0.9803665804565389 0.9787015986431381 0.9766571867633663 0.9742707704687186 0.9712330273567635 0.9676392175471239 0.9631593855277194 0.9574768719072899 0.9513418863858302 0.9439060571280075 0.9353816600633708 0.9257154067091543 0.9150082356360818 0.9033130493576741 0.8908090066082556 0.877404271288504 0.8635436277233635 0.8485111534081407 0.8332275553040672 0.8171704203653437 0.8017011834319526 0.7855462889120179 0.769628430860638 0.7535663604451416 0.7379702471832227 0.7224592134441439 0.706798987765622 0.6888921669862812 0.6659574468085107 0.6365309149627293 0.5974578013639482 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/SwinT.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9292561349693251 0.9273204020573503 0.9273204020573503 0.9273204020573503 0.9273204020573503 0.9273204020573503 0.9273204020573503 0.9273204020573503 0.9273204020573503 0.9273204020573503 0.9273204020573503 0.9271392710749579 0.9262397575650587 0.9250803245657028 0.9237414093166834 0.9212462703422369 0.9187010003457159 0.9161753423505796 0.9131025859844801 0.9096330102922395 0.9061934759558411 0.902293838220476 0.8979743182569303 0.8937297230807079 0.8890463463345865 0.8846834820487599 0.8797871195198751 0.8742601833584774 0.868732721886538 0.8630601625243549 0.8573574088033076 0.8511466731102959 0.8453385796788729 0.8391951603614439 0.8333161266217326 0.8271409324655117 0.8210204623970236 0.8146008728830815 0.8081344851339846 0.8015491430888523 0.795264395253181 0.7888990590845273 0.7826157216891966 0.7758347623780761 0.7696040002542535 0.7629863791574052 0.756392746413582 0.7493987788602265 0.7422695567858372 0.7347681888110402 0.7269444264582094 0.7173572458112335 0.7056123192290036 0.6887948335712517 0.6630662610713094 0.6253193392789619 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/U2Net.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9989853896103896 0.9954076946627207 0.9909757090514316 0.9884390145395799 0.9854347206965296 0.9839314046517909 0.9812543392733164 0.9780441640378549 0.9745618741058655 0.9702894921366323 0.9647448665483649 0.9584590774913835 0.9516604440951071 0.9438382095036606 0.9356654344998132 0.9269173723745275 0.9184850513163766 0.9097730237411948 0.9015627296330769 0.8924987791281593 0.8840317730693689 0.8759011878827736 0.8681120003740619 0.8604222495299291 0.8520249111572429 0.8446482113750243 0.8375210560362772 0.8305760573301075 0.8235225133837214 0.8166065266021426 0.8099561424802965 0.8039502484465606 0.7974283746968835 0.7913050581937356 0.7857113058956043 0.7798193040522408 0.7741924334786384 0.7684537976533466 0.7635670509855921 0.7583797242383671 0.7534292316623853 0.748111368568171 0.7436076882701392 0.7389261465694807 0.734238665381998 0.7297058274401649 0.7245528424229113 0.7196431224924463 0.714468510456371 0.709227502379732 0.703585854845234 0.6970972733135703 0.6890824207833209 0.6794075775209909 0.6657926507677425 0.6465614040943215 0.6194119772114804 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/UCF.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9957730242090431 0.9947206757535035 0.9946091644204852 0.994308864666731 0.993297803617571 0.9918130990415336 0.9894930391384292 0.9871589085072231 0.9844850948509485 0.980981508111721 0.9775820950284352 0.9719374529642947 0.9668637236084453 0.9611050475555997 0.9546075623406994 0.9476850832783036 0.940380332362515 0.933711066363909 0.9258361170362619 0.9177294347274071 0.9094014342686482 0.9012313312923904 0.8927197537988074 0.8848414257220725 0.8770826692938484 0.8695445003038507 0.8623953043592957 0.8548616889830295 0.847230851827499 0.8398041697734774 0.832448946776009 0.825337906206397 0.8181431920796944 0.8114626453774887 0.8049884359847459 0.7979861317710598 0.7913599166199273 0.7845565540644237 0.7780972703978645 0.7717491754433485 0.7654626060138782 0.7596532702915682 0.7531782227992955 0.7467802585605697 0.7405357282417743 0.734320663728628 0.728722127066959 0.7230994152046784 0.7172780249540575 0.7118512029350204 0.7056781559278645 0.6999433005736648 0.6936443208565398 0.6865301772227662 0.6785427700208634 0.6675429204356655 0.6532080784657067 0.6323946579253082 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-100/UIUNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9992280972597453 0.9991074970249901 0.9980977369629387 0.9965104941756042 0.994351669941061 0.9923021901290916 0.9910578609000584 0.989529598038715 0.9882366273798731 0.9875644710822987 0.9868311385230871 0.986183420104125 0.985727476173336 0.9851242790940918 0.9841428629364768 0.9832777192737607 0.9818905804790233 0.980125675873155 0.9782483954605347 0.9761132706289903 0.973427617288794 0.9702806273381819 0.9667511412751038 0.9624811619164334 0.9578800800626577 0.9534337851139547 0.9481549695093212 0.9424413432669855 0.9364635910839663 0.9300162589135973 0.9233674059313381 0.9161449743144106 0.9090076613751799 0.9019582066661189 0.8947415649676956 0.8876511468065066 0.8805405405405405 0.8737187430269711 0.8666779852338584 0.8594987421969627 0.8529148983694438 0.8465114149930503 0.8396629236134912 0.8330961521921011 0.8268725124543278 0.8209520089326843 0.8149578953398448 0.808910880753291 0.8028160437359604 0.7967393271366291 0.7904387646364215 0.7843909823147219 0.7773268807956245 0.7697483443708609 0.761703088986371 0.7530578301276445 0.7429019704937269 0.730983925239482 0.7160811576188018 0.6965350272667346 0.6673127008974454 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/ACM.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9873522133626615 0.9745297153321125 0.9603260869565218 0.9563446314656829 0.9548036864032319 0.9544963751349684 0.9537025249310418 0.9535328853593973 0.9528396836808052 0.9510927808268113 0.9499571905169614 0.9487837837837838 0.9465116857903776 0.9443331721803992 0.9433516070084919 0.941091347603256 0.9390059672180678 0.9369257291129792 0.9344383304254462 0.9325224765669834 0.9300431458632957 0.9276787130068761 0.9257144429615157 0.9232306761696868 0.9207342406425865 0.9177487087898827 0.9151418936927124 0.9121820287789014 0.9099237296317353 0.9070077519379844 0.9040108471331865 0.9009095721642042 0.8981752877573999 0.894872763920383 0.8916828526883313 0.8886854539318259 0.8854272343489217 0.8821340147282885 0.8788887348033098 0.875507898143506 0.8716882663778629 0.8678666798536248 0.8644545398313414 0.8609828871138228 0.8573736321000521 0.8534719535169902 0.8497851954875599 0.8457331140707544 0.841793972712162 0.8375502537024523 0.8326969879652578 0.8282364600690474 0.8230459666372894 0.8178706613988121 0.8120208153520793 0.8059893911212148 0.7993853205730616 0.7926903344578493 0.7858587054526909 0.7776788127744617 0.7689055712611184 0.7597181434669862 0.7492477962463 0.7370744134218992 0.7238809948066585 0.7086748872092127 0.6905261080827763 0.6702375015163973 0.6449635689633133 0.6159935015118011 0.5819285346006411 0.5419293577556703 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/AGPCNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9989128219066385 0.9989128219066385 0.9989128219066385 0.9989128219066385 0.9989128219066385 0.9988662131519275 0.9988320741632907 0.9988320741632907 0.9988320741632907 0.9987423093918896 0.9985374325847832 0.9983599604128376 0.9982895793920232 0.998136405440728 0.9980063617221451 0.9978082774763273 0.9974867856786676 0.9969545643702947 0.9965459849311882 0.995932261084163 0.995368496056744 0.9946695919534014 0.9940776842416187 0.9929981894308023 0.9915129001206633 0.9903891233005157 0.9888446514423077 0.9871040833052114 0.9855811683784914 0.9840646832561426 0.9821923743500867 0.9805359425087473 0.9787678455548345 0.9771563878599712 0.9753948446341139 0.9734878948390139 0.9714442072021962 0.9694897514173572 0.9673772938502765 0.9647094715490728 0.9615233513730207 0.9579613066227902 0.954434048881908 0.9505605695800529 0.9458684545238968 0.9413324378589768 0.9361393521324336 0.930666949478345 0.9247196665677381 0.9186201351724831 0.9119513387944737 0.904855674273612 0.8972911963882618 0.8898485078957301 0.8818450912678836 0.8739775583051879 0.8657561988835518 0.8571803914831695 0.8490968938659634 0.8404221696532042 0.8319864958755532 0.8227802731273536 0.813657136549288 0.8039450628736343 0.7943363748788396 0.7835003429480545 0.7720630818385606 0.759250281526158 0.7442884485600806 0.7277878094208977 0.7071779556838962 0.6822369657192098 0.6537121513639286 0.6154504509627743 0.5711741548459663 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/DNANet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9530912946510183 0.9530912946510183 0.9530912946510183 0.9530912946510183 0.9530912946510183 0.9530912946510183 0.9530912946510183 0.9530912946510183 0.9530429189150555 0.9525654351318151 0.951098658015042 0.9486057068741893 0.945712722570377 0.942403200588614 0.9383648067751663 0.9330556441326531 0.9276281429530702 0.9223861786483222 0.9161118069747511 0.9111028049027395 0.9045789729323972 0.899221664826796 0.8933094384707287 0.8874957341658556 0.8823281010581181 0.8767753389283409 0.8718588250874618 0.8673211127701708 0.8630762812751193 0.8579303873559511 0.852427732079906 0.8484485899980965 0.8444425003280697 0.8400712934180294 0.8364077511273773 0.8324586224884997 0.8281276302701048 0.823856855899728 0.8195553001185387 0.815244505762504 0.8113946394876199 0.8070502755861544 0.8029481802105478 0.799041428420654 0.7948847962429751 0.7908945039056301 0.7869658547754658 0.7831837426854915 0.7791392366528407 0.7752247381588432 0.7712188186635647 0.7671952868066905 0.7630898879169886 0.7590406028189861 0.7549826130620252 0.7502954027094612 0.74491432656435 0.7396356058194945 0.7343617021276596 0.7282849115876052 0.7215470909142379 0.7142412653651198 0.7057633973710818 0.6950108411548057 0.6824535331998018 0.6677839580479571 0.6498736544698462 0.628168524776553 0.6038169377735627 0.5769959137663802 0.5466966069703344 0.5151741244850632 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/HRNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9221838719799185 0.9092802290771834 0.9083906747458149 0.9080857964177784 0.9073970906839988 0.9051672640382318 0.9021304347826087 0.8996543778801843 0.8975511733251592 0.8949960953925632 0.8927573132593474 0.8891870001012453 0.8864481636448164 0.8827678149056036 0.8796793587174349 0.8766350250392406 0.873301737756714 0.8697702021034585 0.8670806096818983 0.8638590485184748 0.8603162597868258 0.8566845916302843 0.8534187492071547 0.8494446729740848 0.8456914290999052 0.8420446619354982 0.8385360259922278 0.8345832484206236 0.8300369740008217 0.826013290912856 0.8217781402936378 0.8174799339720691 0.8129820484321848 0.8079794623489939 0.8031538069273918 0.797952471886272 0.7927336653970114 0.7876123125761741 0.7820388016727223 0.7756702643931157 0.7694192977165845 0.762737330965715 0.755155954489508 0.7474996224749382 0.7399830215940587 0.7317855763844702 0.7227523408461219 0.7130037783375315 0.7026237351949173 0.6915155183174777 0.678569479066213 0.6639572547341953 0.6486345131151177 0.6305158359480412 0.6094967114236353 0.5843208881215174 0.5568514677541642 0.5253512447239744 0.4915013841579421 0.45555711304239477 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/ISNet.txt:
--------------------------------------------------------------------------------
1 | 0.9888364779874214 0.9888364779874214 0.9888364779874214 0.9888364779874214 0.9888364779874214 0.9888364779874214 0.9888364779874214 0.9888364779874214 0.9888364779874214 0.9888364779874214 0.9888364779874214 0.988825819489435 0.988457532164072 0.9880510169623608 0.9876451751660852 0.9869674635987776 0.9858262078261236 0.9841904271948396 0.9824600686669652 0.979962639221768 0.9778368349796921 0.975652725974129 0.9738864681675139 0.9722374693052145 0.9701265682860775 0.9683154956409001 0.9666959030876203 0.9645909645909646 0.9619035342645542 0.9597800631300275 0.9575626630001856 0.9556515782262739 0.9530223891994257 0.9506850124592489 0.9478755647052096 0.9449941846511714 0.9417140608792074 0.9384483326685973 0.9347919311846763 0.9312354597771519 0.9272541389613168 0.9231450191290692 0.91857456894618 0.9143270474999636 0.9097708218461495 0.9049308362896351 0.8994171045067447 0.8938570369696103 0.8883026173439368 0.8818093445063802 0.8753546332762728 0.8682000453960808 0.8606851438983262 0.8529659361640082 0.8450139096939867 0.8372139013499589 0.828692414323441 0.8201643053915781 0.8115740275644981 0.8028966492663 0.793321519948648 0.7834597955527453 0.7731789985291296 0.7624659155325866 0.7513485960152199 0.7393255455832257 0.7254113606559225 0.7101875742904358 0.6924707897558734 0.6734493666662218 0.650589431540148 0.6241550830265234 0.5929169904388661 0.5563685255538358 0.5144105882828264 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/ISTDUNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9959642548284808 0.9959642548284808 0.9956805731142014 0.9954472641429897 0.9946815672546115 0.994276823915097 0.9935956528067537 0.9927309610345701 0.9915115413253909 0.9897158322056834 0.9873460383601936 0.9834757036829576 0.9789194615325217 0.9718801167509077 0.9652237033126567 0.9596071435894281 0.9522414354378507 0.9452920271386007 0.938613961652714 0.9330005422474562 0.9279822335025381 0.9229500437457868 0.9177362814234856 0.9137208454810496 0.9094006042446321 0.9061922861534428 0.9025404843262517 0.8991591103695085 0.8962114649143195 0.892861503398315 0.8894959879146966 0.886190638491252 0.8829006823628414 0.8795213753824929 0.8765421528925266 0.873448905109489 0.8706197509412106 0.8677231742066435 0.8644654779135944 0.8613896373514731 0.8584169613330852 0.8552165347478924 0.8522099073220963 0.8492217963013057 0.8460070058167561 0.8429836981718396 0.8395295439556413 0.8360665499541486 0.8324621924872814 0.8287443404940132 0.8249030895226087 0.8212230176984461 0.8170424480072521 0.8127379751785457 0.8079523082486239 0.8035773857414784 0.7986166329766341 0.793706490006737 0.7885636474964131 0.782560809756315 0.7764297339402432 0.7694442205127987 0.7617850720353873 0.7534553679956854 0.744041702214256 0.7331548365045004 0.7211985382602647 0.7072366783532574 0.6911793916396688 0.671748737684426 0.6483388175352935 0.6193155083988143 0.5835784018978692 0.5411851481435179 0.4916769579376517 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/LASNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 1.0 1.0 0.99995869304804 0.99995869304804 0.99995869304804 0.99995869304804 0.99995869304804 0.99995869304804 0.9998599243591539 0.9998314739290168 0.9997825276500559 0.9997171145685997 0.9996613230977648 0.9996613230977648 0.9996613230977648 0.9996391960438907 0.9996057171020207 0.9995885543295305 0.9995710072392529 0.9994742465614029 0.9993951644993394 0.9993747903260239 0.9993208928782331 0.9992606074394267 0.9991841618622865 0.9990879478827361 0.9990673875508841 0.9989934757827847 0.9988958367233083 0.9988296651478826 0.9987369715873522 0.9986593462717058 0.9985851346187609 0.9984243275792855 0.9982846505462803 0.9981912707506699 0.9980082069251694 0.9978454660526245 0.9977107307312555 0.997545301126775 0.997341331494421 0.997102601977048 0.9968139488805231 0.9964315847324463 0.9961639731251135 0.9958041647549839 0.9953611567436913 0.9949866523158409 0.9944989767958101 0.9940121795394612 0.9932969893031398 0.9925026426716431 0.9913929813791692 0.9901331461096763 0.9885946045597002 0.9867116438768749 0.9843136560339586 0.9810635628256056 0.9770535982199978 0.9714269899258275 0.9642139166900845 0.9552120094811694 0.9434847166043725 0.930228282366504 0.9143526311233372 0.8973462680878163 0.8800334101227283 0.8626366715675432 0.8450269014172964 0.8272071300441184 0.8103071544512424 0.7930567009726945 0.7754997530290264 0.7565220126665455 0.7355527911990638 0.7098307942357002 0.6771883193418805 0.6330259956105181 0.5800873625546016 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/RDIAN.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9953938419685916 0.9953938419685916 0.9953938419685916 0.9953938419685916 0.9953938419685916 0.9953938419685916 0.9953938419685916 0.9953938419685916 0.9953938419685916 0.9953938419685916 0.9953938419685916 0.9952170507943612 0.99482447656638 0.9944960356909052 0.9940670289855073 0.9934844433125345 0.9926061741435319 0.9913573970513472 0.9900302652661563 0.987976222642529 0.9852431669446939 0.9825241532990431 0.9796369881955664 0.9768490177527965 0.9752663368844918 0.9735794472677001 0.9715476131950248 0.9692369582450868 0.9669535480193787 0.9645738530194112 0.961838295344866 0.9592916000491944 0.9562430654620384 0.9534460113700639 0.9502403980025618 0.9468720217223211 0.9434332339557508 0.9394899882513038 0.9351612771015205 0.9310835972436622 0.9263275788823234 0.9212448443944506 0.9161871184860361 0.9106791582343867 0.9045956551918465 0.8981064322567296 0.8919187200747402 0.8851837841245823 0.8789486272177012 0.8723114912500447 0.8661073533253985 0.8595049728752261 0.8521908342815679 0.8452490528073292 0.8384916084645053 0.831542983622869 0.8241251983411022 0.8171925954113901 0.8099600602616029 0.8020712411908266 0.7943363131277157 0.7859910542466323 0.7775649194467184 0.7686530699395169 0.7591932244268834 0.7485314132166995 0.7368248895572892 0.7242810791914439 0.7094423587969798 0.6927144655109415 0.6735905237004794 0.6501793730279772 0.621514709606434 0.587840535351039 0.5476960872430943 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/RISTD.txt:
--------------------------------------------------------------------------------
1 | 0.998139534883721 0.9980351490012007 0.9980351490012007 0.9980351490012007 0.9977188239270024 0.9974253038740195 0.9973953573386949 0.9973953573386949 0.9971116680117232 0.9969420114768952 0.9964623443771685 0.9957607451186682 0.9950805767599661 0.9939968157022421 0.9924921288447566 0.991198573717587 0.9899581413048074 0.9887463962620539 0.9871768433287715 0.9855838283535775 0.9840877361188205 0.9826149540183926 0.9811777974212821 0.9793379410694798 0.9779381041254631 0.9764324151099598 0.9752647911163542 0.9739167950693375 0.9724758659178075 0.9708668985719516 0.9694024080415332 0.9680508112724168 0.9663559313285226 0.9645380149681752 0.9622555707139608 0.9603744910786318 0.9579254757832685 0.9554768820574395 0.9522630501272242 0.9486060182760025 0.9452262849202631 0.9411704690823189 0.9367984382625671 0.9324748115033029 0.9277938780512421 0.9226094514460348 0.9176237110589776 0.9123942259830762 0.9068121193331913 0.9009831902133598 0.8952175163894789 0.8883734225837765 0.8817691152684457 0.8750529173698797 0.8680247958409812 0.8603784281008001 0.8527422755550927 0.844611846238019 0.8362523229505046 0.8270266643663619 0.817508081871454 0.8070964269074361 0.7961462984182175 0.7853033428225628 0.7724241799050959 0.7580117854380467 0.742431351434505 0.7247458404369806 0.703763724307283 0.6806904646268415 0.6517201207012898 0.6170576909363189 0.576515290465462 0.5293741384610807 0.4788629253852859 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/SANet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9777038269550749 0.960582270199542 0.947136563876652 0.939495263115994 0.9314780615169076 0.9251040952933115 0.9183072184277488 0.9132724719101124 0.9081784001931634 0.9026226473310707 0.8968698609136773 0.8914738526401341 0.8866142833826856 0.882522786528475 0.8792133143738359 0.8762589389344416 0.8732142230334515 0.870235694687779 0.8675215675262477 0.8650175215507682 0.8622721724123417 0.8595964047704903 0.8570391225142362 0.8546098427688263 0.8523439867589808 0.8495006813967516 0.8466931826709693 0.844443528186944 0.8414819223329861 0.8387365000477874 0.8358990431592338 0.8332294930219311 0.8302998503328803 0.8268703127858057 0.824064459523637 0.8209806779476068 0.8180169423642231 0.8145748750494672 0.8109994952403295 0.8073138756641137 0.8031544051166996 0.7986789668376468 0.7941895194135216 0.7895873467735008 0.784779938709562 0.7796268397265036 0.773976883103722 0.7678341946244497 0.7612170975760022 0.7545479873285799 0.7464634384335547 0.7374734535157728 0.7276610325421151 0.7170816970467392 0.7051882243613347 0.6918031540625238 0.6777109780536484 0.6612539221609401 0.6441790844208201 0.6238753230502341 0.6017070900555765 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/SSTNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.998046875 0.997616399173685 0.9971590909090909 0.9970233306516493 0.9963333988083546 0.9962744158816329 0.9957899830631503 0.9956763189210631 0.9956257777442589 0.9951785956811082 0.9947789551731595 0.9945767307855267 0.9944712661579194 0.9941952506596307 0.9938622619369979 0.9936068419492782 0.9934299864206406 0.9931887266859308 0.9930399400364065 0.9928737465429187 0.9924978219483076 0.9920347574221579 0.9917628888725454 0.9913729912602199 0.9910967164907241 0.9907418258525723 0.9902377440563986 0.9899212065548412 0.9894323072795773 0.9889015335038478 0.9880722852194372 0.98759990721019 0.986953365438688 0.9862394864326686 0.9853881893475187 0.9845929471597921 0.9835246221931426 0.9819606526674736 0.9806581390550856 0.9785805091833663 0.9762621659451189 0.9735086217826804 0.97010504882225 0.9651797633153845 0.9594370016181466 0.9524321875530891 0.94299100038413 0.931794048156309 0.9187725168916753 0.9044459149407232 0.8890166865315852 0.8727735952993022 0.8565335101386817 0.8399942564495285 0.8225894697012308 0.805212808840735 0.7881728919165513 0.7695811143067111 0.7504935941481913 0.7305719756889518 0.710451184403856 0.6891576063205177 0.6661738017598531 0.6428424255327917 0.6180113119230269 0.5920034591571065 0.5646290059092834 0.5367268076011809 0.5083536728201485 0.47665363761864077 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/SwinT.txt:
--------------------------------------------------------------------------------
1 | 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814843817976159 0.7814176497244515 0.7814176497244515 0.7814176497244515 0.7811305374409715 0.7802614030408109 0.7790161024470648 0.7769423861359851 0.7746932762470363 0.7720106638580785 0.769717383295616 0.7675045155040715 0.7645259346744371 0.7612371555456454 0.7581840117967737 0.754597903751408 0.7507990130326063 0.7465703164880041 0.7422470337647158 0.7375025028365481 0.7328494531294737 0.7279324859518527 0.7223916859677344 0.7158525903598939 0.7095014801363217 0.7024080400079248 0.694559312413849 0.6863400189954746 0.677559419915099 0.6674704793316427 0.6568672462148054 0.6448213395698348 0.6309943453036645 0.6159931204836722 0.5997706757949975 0.5816572871716348 0.5625582175895373 0.5403145304246196 0.5154906909543143 0.4875916464722861 0.4571177709140196 0.4237367936340401 0.3874161721033773 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/U2Net.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.997615894039735 0.9955842391304348 0.993157599551318 0.9902715506302343 0.9885764499121266 0.9857921042286741 0.9823813855508211 0.9792544976416079 0.9758148148148148 0.971259790256206 0.9677273956549175 0.9636303793302066 0.9591373439273553 0.955299055613851 0.9504643962848297 0.9452253705757789 0.9409127877286225 0.935408228522703 0.9302679811860803 0.9254116121277303 0.9202283043458721 0.9151023890784983 0.9103849536567215 0.9058556349841286 0.901170527353376 0.896118586948837 0.8912233982532554 0.8859130096163207 0.8816284532396532 0.8767372386024159 0.8721638180611528 0.867371326935191 0.8628758356869615 0.858274647887324 0.8532353235323532 0.8484711759255529 0.8439949155053691 0.8389855455334566 0.8335269661033049 0.8281473407041436 0.8235677888115154 0.8180825347779472 0.8126033698573496 0.8067929638034579 0.8010906479373852 0.7953008629861685 0.7895825077105985 0.783301718366436 0.7765448486982905 0.7698648741123645 0.7627625631480989 0.7546002530201176 0.7467856850380804 0.7376301201612112 0.7278016295619365 0.7175656790891448 0.7056311501428421 0.6929992973971013 0.6787985824027565 0.6640922403941483 0.6472022112433892 0.6294308948732988 0.6102610976433652 0.5891101572235999 0.5669420008945878 0.5423536914618352 0.5160755996913653 0.4878346095547369 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/UCF.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9887806943268417 0.9887806943268417 0.9887806943268417 0.988277652181194 0.9865943185445261 0.985513756316676 0.9839592738449716 0.9832847019591496 0.9819034852546917 0.9808853118712274 0.9797869101978691 0.9790033517077089 0.9778120184899846 0.9770789292415857 0.976222251879532 0.9749599983375933 0.9734740221604479 0.9719029843885038 0.9700851717397375 0.9682965169190039 0.9669468316177166 0.9646834854821543 0.9620609770049193 0.9591097874200776 0.9554379776601999 0.9520578810334468 0.9486828354174925 0.9441159179721271 0.9394838523963996 0.9343654833145218 0.9293714965003433 0.9241424932367054 0.9187607081667618 0.9136701883325677 0.9074760224794794 0.9015874103121736 0.8953211343105787 0.8894481610703883 0.8832925545094152 0.8776535893781412 0.8713760481304405 0.8647501460619584 0.8587091576892093 0.8517143481109413 0.8451077709361978 0.8386240253304994 0.8320643490049872 0.8251667695201153 0.8183593415440483 0.8112404799281107 0.8052523616734143 0.7986201175105407 0.7912698169642401 0.7840060755271413 0.776768125458689 0.7697057413794723 0.7616500308352938 0.7537141689324874 0.7455290804160764 0.7369302864629261 0.7282779507985943 0.7186111848596642 0.7087193806320599 0.6976502375850182 0.6863222871994802 0.6737374054646411 0.6592513102742998 0.6426335124376823 0.6252757330160854 0.6044513066376959 0.5807577435484412 0.5544871710483011 0.5251208145297148 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/DMIST-60/UIUNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9346492397972793 0.9258851298190401 0.9112977568370378 0.9052664140437707 0.8999631766294341 0.89474216380182 0.8908854730461139 0.8876756838446426 0.8830079828268599 0.8795285770121163 0.8760699967493769 0.8728300763813114 0.869909036888056 0.8669781997335758 0.864809488159502 0.8615271524395847 0.8593391424380985 0.8571637783744203 0.8540699996927142 0.8509316770186336 0.8478734854966467 0.8445555074574004 0.8410275642227639 0.8377965535347749 0.8348229553797396 0.8317475940507436 0.8281376157370999 0.8249876602903106 0.8220290529212819 0.8184048049840518 0.8147738810865209 0.8109880779094752 0.8070886880325087 0.8029534774815509 0.7984682337893275 0.7938999200928731 0.7889350535753334 0.7840922718775352 0.7787152706669397 0.7729083402146986 0.767294712261284 0.761259418996629 0.7545157323255046 0.7473115157708924 0.7389559134774627 0.7297676966551139 0.7199009221778269 0.7088209457071599 0.6966625836055375 0.6832500969404202 0.6687347860975752 0.6513832976445396 0.6335681940788936 0.6135284840857358 0.5891143857387687 0.5628545803795707 0.534972148064483 0.5045113631632477 0.4732565473741944 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/ACM.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9402390438247012 0.8856548856548857 0.8688524590163934 0.8588351431391905 0.8515562649640862 0.8501672240802676 0.8474358974358974 0.8474358974358974 0.8474358974358974 0.8462757527733756 0.8448087431693989 0.8439575033200531 0.8387596899224806 0.8373408769448374 0.834306956983537 0.8339526242452392 0.8339526242452392 0.8322680185144369 0.831110199056023 0.8305727399720615 0.8286255924170616 0.8260869565217391 0.8243313201035375 0.8213114754098361 0.8203470031545741 0.8169588875453446 0.8123019302794584 0.8067238516878804 0.8031737565008668 0.8010269576379975 0.795973817463258 0.7924663249493384 0.7897877223178428 0.7857778766970843 0.7823617659665667 0.7802690582959642 0.7752774974772957 0.771154973617354 0.765617616482374 0.7605749862662516 0.7557326250553342 0.7494414847912012 0.7456396561795877 0.7395774305836639 0.7365807519508158 0.7313169984686064 0.7266840342389281 0.721113521330441 0.7157931762223004 0.710308502633559 0.7062079082483164 0.6991517192255391 0.694384858044164 0.6888124769740882 0.6829924650161464 0.676299569717409 0.6686818464320018 0.659764448096412 0.6541397533763946 0.6483139050791007 0.6404085347355648 0.6305529705637459 0.620040764089681 0.6042102391561335 0.5829553798971967 0.5570052432630167 0.5169249061791699 0.43815168818272093 0.25854477207925797 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/AGPCNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9225589225589226 0.9110169491525424 0.9023136246786633 0.9023136246786633 0.9023136246786633 0.8993666432090077 0.8909090909090909 0.8868729989327642 0.8802559414990859 0.8799832144355854 0.8761503067484663 0.871463499825358 0.8682880606443462 0.8666864608076009 0.865592890863649 0.8584669338677354 0.8562107298211696 0.8530344202898551 0.8505276760715055 0.8487785180698566 0.8477255204317656 0.8463671658347199 0.8446668967897826 0.8433533734134937 0.8412698412698413 0.8374270801350936 0.8363986947493326 0.833309749540116 0.8328101542241027 0.8309765208110993 0.8294314381270903 0.8271359283671185 0.8249097472924187 0.8222506393861893 0.8195760036084799 0.8178146490557799 0.8153454313912767 0.8129570501596457 0.8112377524495101 0.8089865784866758 0.8071239105721865 0.8045320560058954 0.8008592141770339 0.7981507327285415 0.7938354419631485 0.7906014726565732 0.7871243251953912 0.7844942935852027 0.7822679778733866 0.7786694667366684 0.7749158495536368 0.7719085060757684 0.7678571428571429 0.7648182252233817 0.7608073003397056 0.755685510071475 0.7515717279481806 0.7483999254334183 0.7424104708234867 0.7369292642536077 0.7331829486957024 0.7266877398961391 0.7206919347730277 0.7144853533996237 0.7055994144403199 0.6969943548797234 0.6876915472071181 0.6754855994641661 0.6591975706266678 0.6387887843895579 0.6056945642795514 0.556320470360585 0.47824440728236517 0.3220363849765258 0.20245183432568056 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/DNANet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9065656565656566 0.9065656565656566 0.9065656565656566 0.9065656565656566 0.9065656565656566 0.9065656565656566 0.9062111801242236 0.9032091097308489 0.8996157540826129 0.8891741548994437 0.8832819722650231 0.876885303402315 0.8728948204639339 0.8690406121247792 0.8672048997772829 0.8658536585365854 0.863880742913001 0.8611239060340857 0.8603485838779956 0.8593879239040529 0.8588461538461538 0.8584834834834835 0.8547146179996421 0.8534497517548365 0.8502554278416348 0.8496692723992784 0.8496692723992784 0.8491571819622533 0.8483333333333334 0.8477704432170281 0.8475649140937863 0.8467208947635994 0.8451517758387612 0.8434584755403868 0.8434584755403868 0.8427814718809873 0.8388470357025877 0.8346548260915531 0.8338309482137342 0.8308830883088308 0.8285158150851581 0.8265296502309796 0.8254481611532064 0.823286688195192 0.8213061439746858 0.8194837492496355 0.8179991628296358 0.8166038971671852 0.8155511022044089 0.8133813851457223 0.8105770702586866 0.8086221091235686 0.8066525020147997 0.8039131369597936 0.8015251154330488 0.7995600769865273 0.7972663614328037 0.7958968269674781 0.7934205413806502 0.790736040609137 0.7884842826019297 0.7857273559011894 0.7821953844314241 0.7789492180636092 0.7765828972338835 0.7735582891321521 0.770961485175949 0.7674203717795413 0.7631550916063059 0.757149554617909 0.7519893899204244 0.7461323485377782 0.735384018242686 0.7215986474428215 0.6898338870431894 0.5955696441375147 0.4046877768496722 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/HRNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9827586206896551 0.9613636363636363 0.9493293591654247 0.9476534296028881 0.9476534296028881 0.9421088904203997 0.9352148813341886 0.9261487964989059 0.9203539823008849 0.9110039456378781 0.9098360655737705 0.9098360655737705 0.9098360655737705 0.9057542768273716 0.9038128249566725 0.9019448946515397 0.900995151824445 0.8992788461538461 0.8978384527872583 0.8969569779643232 0.8942622950819672 0.8906219535971924 0.88909426987061 0.8872514204545454 0.8868214407647661 0.8840674635663992 0.8832391713747646 0.8830045523520486 0.8813829017792565 0.880433273156506 0.879470913753543 0.8790918690601901 0.8777863182167563 0.875898834614431 0.8747744496571634 0.8718217867972942 0.8692159746577667 0.8678130037614186 0.8673229188957843 0.8652083333333334 0.8644670050761422 0.8631058358061325 0.8606218115314275 0.85962441314554 0.8588354151359294 0.8566196176523138 0.8552348125807842 0.854830421377184 0.8545393466028387 0.8525909501519258 0.8499599037690457 0.8478141433168511 0.8452856265356266 0.843390696625836 0.8410245087215721 0.8385724585436193 0.8345907071846577 0.8313086692435119 0.8295639219934995 0.828009828009828 0.826095461658842 0.8231157706322316 0.8203007518796992 0.8173822118341739 0.814238310708899 0.8111097964978703 0.8073107049608355 0.8037787389027999 0.8004355352057625 0.7970085470085471 0.7937439431463336 0.7896182077414585 0.7840297889946214 0.7789943227899432 0.7694169960474309 0.7579250720461095 0.731950739367303 0.6761574750553166 0.5619197697920592 0.44184067030190183 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/ISNet.txt:
--------------------------------------------------------------------------------
1 | 0.9302325581395349 0.9224137931034483 0.8767395626242545 0.8582995951417004 0.8505711639394875 0.8505711639394875 0.8505711639394875 0.8505711639394875 0.8505711639394875 0.8505711639394875 0.8505711639394875 0.8505711639394875 0.8505711639394875 0.8505711639394875 0.8480606590842811 0.8467638782297263 0.8460559796437659 0.8428435568235854 0.8415754923413566 0.8410877416613555 0.8408040201005025 0.8387406411979267 0.8386041439476554 0.8370796305976651 0.8349530030269237 0.8349530030269237 0.8333842861510241 0.8330655761647436 0.8325842696629213 0.8323786142935079 0.8319493871095295 0.8309166871467191 0.8309166871467191 0.8298078085337846 0.8286384976525821 0.8269864726611345 0.8250137741046832 0.8237564322469982 0.8213802435723951 0.8208232445520581 0.8185666469777515 0.8165005749329245 0.8145311916324244 0.8130679752817157 0.8111857218589856 0.8095114345114345 0.8087339201083277 0.8071452199801522 0.8047738085638255 0.8014178810555337 0.7983393557315291 0.7948968105065666 0.7918766893125867 0.7891341743119266 0.7859946726482546 0.7820609380349195 0.7787823990355636 0.7763201049524434 0.7722842118763628 0.769452992526534 0.764745596268336 0.7614700936824406 0.7563597908466012 0.7511915015790985 0.7458188348860703 0.7397500548125411 0.7314691285554192 0.7224385937581139 0.7141196097265052 0.7000732958709993 0.6854515444470644 0.6678445229681979 0.6419890926267875 0.5977442126434516 0.4711022535643109 0.1946089825298043 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/ISTDUNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9915611814345991 0.9685534591194969 0.9661016949152542 0.9580645161290322 0.945994599459946 0.9312130177514792 0.918918918918919 0.9153182308522114 0.9073714839961202 0.8961654459284791 0.8875192604006163 0.8796753705010586 0.8755541481950602 0.8742586002372479 0.8684503901895206 0.8656024716786818 0.8632292423142096 0.8604775481111903 0.8604775481111903 0.8604775481111903 0.8604775481111903 0.8604775481111903 0.8604775481111903 0.8592167454422688 0.8589503280224929 0.8589503280224929 0.8589283148972848 0.8589283148972848 0.8581196581196581 0.8575353757384256 0.8565562913907285 0.8558918222794591 0.8548266400598653 0.8528630103889828 0.8508648901355774 0.8501761964306014 0.8484881924519974 0.8469464419877643 0.8444883938794628 0.8419825663896209 0.8387731026346834 0.8379412608820435 0.8363874345549738 0.8350242208207659 0.8328884143085958 0.830593964692582 0.8282442748091603 0.8261373995193503 0.824702163870654 0.8216962524654833 0.8200820623983898 0.8178863017840944 0.814642090458213 0.8120654692931634 0.8085528178986123 0.8059019118869493 0.8026449643947101 0.7991373589913736 0.7946573299617454 0.789970838087993 0.7845639246778989 0.7798776424980314 0.7757812963292415 0.7683681137638014 0.762873777627042 0.7581876521354282 0.7502291722836344 0.741933788754598 0.7331661891117478 0.7241207532131115 0.7111432706222865 0.695025569502557 0.6695971019614773 0.6270404831864186 0.5305058763413388 0.26096447707878545 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/LASNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9761904761904762 0.9626749611197511 0.9620637329286799 0.9498327759197325 0.9474153297682709 0.9405120481927711 0.9349904397705545 0.9340413638904416 0.9269767441860465 0.9260016353229763 0.9239280774550485 0.9239280774550485 0.9225871313672922 0.9214444092492873 0.9198113207547169 0.916597853014038 0.9144039306956296 0.9120391973868409 0.9117782909930716 0.9079118028534371 0.9073575129533679 0.905467847516365 0.9047709201060204 0.9041095890410958 0.9019267488283285 0.9007506255212677 0.8989835809225958 0.8989083295947361 0.8987322893363162 0.897432239657632 0.896105702364395 0.8955223880597015 0.8947437067953568 0.8924514011613229 0.89057156814851 0.8904077023653869 0.8874898925724847 0.886232206405694 0.8855428259683579 0.8843819865857554 0.882279792746114 0.8816027314721832 0.8812623274161736 0.8808629490513339 0.8807978172923134 0.880367816091954 0.8791477471184073 0.8791477471184073 0.8778540772532188 0.8775113962518993 0.8770018160805679 0.8756380134489185 0.8746920937624155 0.8722891566265061 0.8712501903456678 0.8697083021690352 0.8680794896238175 0.8663409025582064 0.865041450049178 0.864266888611458 0.8624821294846484 0.8604651162790697 0.859526938239159 0.8576312733840795 0.8564720812182741 0.8546758104738155 0.8533922218132745 0.850611408951268 0.8484220498549352 0.8459300634200267 0.8438072344322345 0.8403011912789391 0.8374930977360574 0.8343281962231387 0.8310918988152417 0.8267588663628268 0.8231751261196335 0.8174843529174238 0.805569197524801 0.7832052430571927 0.7239702147806999 0.5887682412338336 0.41528711535217716 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/RDIAN.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8106096595407759 0.8102043733028441 0.8102043733028441 0.8102043733028441 0.808944737197676 0.8088080168776371 0.8081056584228926 0.8063872255489022 0.8044699414365961 0.803565146006802 0.8019843342036553 0.8019843342036553 0.8019843342036553 0.8019616026711185 0.8012762078395624 0.7986653956148713 0.7986653956148713 0.7985644612128462 0.7979871912168344 0.7975022301516503 0.7954209105945852 0.7942900841192965 0.7929862377715139 0.7918016850291639 0.7903417651256839 0.7885599131917532 0.7867535616853591 0.7839378238341969 0.7825615834717908 0.7803857608849809 0.7781709686372468 0.776034236804565 0.7729115438293348 0.7692857607070904 0.7666369767145947 0.7634924591798579 0.7601903832072248 0.7566083895220479 0.7530120481927711 0.7494130447231289 0.7446093531223746 0.7382671480144405 0.7320980402627223 0.7255801825293351 0.7147051379397621 0.7023529411764706 0.6902364896951277 0.6715867158671587 0.6439400465156344 0.6132961074694505 0.568947641264904 0.4683186479415391 0.27980707851620684 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/RISTD.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9355932203389831 0.931909212283044 0.931909212283044 0.9248197734294542 0.923265306122449 0.9230769230769231 0.9222290263319045 0.9215035931453842 0.9148727984344422 0.9118807540552389 0.9074509803921569 0.9058151980021406 0.9021702838063439 0.8991468616697136 0.896700143472023 0.8939393939393939 0.8924974823766365 0.8906546489563567 0.8902596206960044 0.8902596206960044 0.8902596206960044 0.8902596206960044 0.8902596206960044 0.8897331684043117 0.8891933550265456 0.8881118881118881 0.8868607395751377 0.8857532999544834 0.8839692482915718 0.8839046409930879 0.8830385300973537 0.8818939293744213 0.8791096328514776 0.875897944017835 0.875 0.8725638931030458 0.8711322679360761 0.8682595352915388 0.8671291355389541 0.8651884239017281 0.8635351285685361 0.8609081934846989 0.8585848875216304 0.8577861163227017 0.8559135847674845 0.8549986598767086 0.8533834586466166 0.8510801810263855 0.8498503491852345 0.8479947725230744 0.8461845785057931 0.8424748381056409 0.8409680867307986 0.8386927909063715 0.8354643091985096 0.832414632400315 0.8290718038528897 0.8251575774184708 0.8215242432360033 0.816543226060169 0.8121833921128567 0.8068801303339809 0.8018879490008581 0.7972762179025678 0.7925912880122109 0.7880844211858071 0.7822884453190891 0.7762203166226913 0.7666827206079092 0.7584785557758531 0.7491234310686519 0.7386750975645903 0.7247991583779648 0.7062094957944569 0.6829963598087803 0.6427058727702207 0.5897889921794304 0.48389602629220196 0.3288756115165611 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/SANet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9260869565217391 0.9042016806722689 0.9040333796940194 0.8948979591836734 0.8843881856540085 0.8746001279590531 0.8705673758865248 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8684660961158657 0.8683132838495009 0.8682521706304266 0.8677878713479008 0.867358842743817 0.8671757491611709 0.8661781285231116 0.8639790232710587 0.8634089215373151 0.8623673637581127 0.8619783772860463 0.8599607458292443 0.8593840230991338 0.8581473842115133 0.8561793656580386 0.8541927409261577 0.8520911551558543 0.8496291876225386 0.8479166666666667 0.8461663814878188 0.8439117623636219 0.8424582748401186 0.8399847386493705 0.8379287155346334 0.8356124314442414 0.8333213901827302 0.8305964912280702 0.8283592353183881 0.8254278399137582 0.8231795956125281 0.8188752424046541 0.814873417721519 0.8126358948872461 0.80934112775545 0.806127934061996 0.8020717504535612 0.7957106090935088 0.78859472743521 0.7819561183276935 0.7750293208231155 0.7663357072308572 0.7530348058227976 0.7367065227957617 0.7098309156426129 0.6590909090909091 0.49278620948098184 0.2831438387702334 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/SSTNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.978494623655914 0.9696356275303644 0.9546142208774584 0.9444444444444444 0.941282746160795 0.9320531757754801 0.9240506329113924 0.9193815571507454 0.9140625 0.9060052219321149 0.9043856183326748 0.9012212643678161 0.8923988153998026 0.8900796080832823 0.8896434634974533 0.8879701253667645 0.8864693446088795 0.8864693446088795 0.8864693446088795 0.8864693446088795 0.8844363489172232 0.8826687116564417 0.8820399926618969 0.881002824858757 0.8793833643909876 0.8772668393782384 0.8766449746926971 0.8766449746926971 0.8766449746926971 0.8754706456561149 0.8740520043336945 0.8723181580324437 0.8717467467467468 0.8699644912452553 0.8690661245964367 0.8681870011402508 0.8678905282126366 0.8664763368837158 0.8654553093162483 0.8652951514524967 0.8651073696945256 0.8645142121422205 0.8616847565075401 0.8610587792012058 0.8598638203901362 0.8589513377626145 0.8579964850615114 0.857057057057057 0.8564693997811264 0.8559315057215774 0.855533199195171 0.8546130128863941 0.8538509124652026 0.8531558935361216 0.8521336914353924 0.8501894491401923 0.8499643112062812 0.8492561517309455 0.8479850952249517 0.8469380819683273 0.8457432612756872 0.8444837990292535 0.843109631147541 0.841981879237154 0.8409812589502521 0.8397847489757231 0.8390804597701149 0.8372354928575663 0.8347694457382394 0.8326649862511457 0.8314251875246743 0.8281726028914862 0.8248286367098249 0.8211499412204767 0.81686595342983 0.810144778724715 0.8000500375281461 0.7860228099975735 0.7553764678793461 0.6597688366366128 0.5027500448376876 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/SwinT.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.87890625 0.8689788053949904 0.8677298311444653 0.8677298311444653 0.857439446366782 0.8566529492455418 0.8471741637831603 0.8436893203883495 0.8402154398563735 0.8344716753716352 0.8341030195381883 0.8271484375 0.8246792913866829 0.8204334365325078 0.8198198198198198 0.8198198198198198 0.8181186283595923 0.8160132752540966 0.81575682382134 0.8135394247701037 0.812081784386617 0.8104199337863739 0.8098941098610192 0.8084830824024607 0.805387830933581 0.8035077288941737 0.8010530809733883 0.8001374570446735 0.7954065469904963 0.7923448626653102 0.790894588293042 0.7889614243323442 0.7872242647058824 0.7854998332036028 0.7815919330615748 0.7794577685088634 0.7766548762001011 0.7741049534085336 0.773343517280886 0.7702840225737811 0.767075148675437 0.7624584717607974 0.7592969943963321 0.75591593579348 0.7534808853118712 0.7508253419273699 0.7468267319161951 0.7433067342829442 0.7396772317534167 0.7344976313370573 0.7299421009098428 0.7259647707408902 0.7222550370807902 0.7172647604426534 0.7131705489289952 0.70710203088897 0.701874036294627 0.6964037927844589 0.688725352508286 0.6815074112563584 0.6732568027210885 0.6599682165376531 0.64409867546408 0.6212135377711294 0.5844954525768731 0.5175422032554099 0.4281142294436238 0.33388511835788776 0.23292904987969046 0.1600823815764838 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/U2Net.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9353448275862069 0.9030172413793104 0.8877551020408163 0.8754863813229572 0.8615004122011541 0.8557692307692307 0.8467072495849475 0.8440832910106653 0.838339222614841 0.8327987169206095 0.8277934936350778 0.8254545454545454 0.8224725943970768 0.8212493028443949 0.8193971166448231 0.8182041820418204 0.8164191419141914 0.8164191419141914 0.8164191419141914 0.8146180619161019 0.8146180619161019 0.8141933228456564 0.8136170212765957 0.8112366703803915 0.8096573208722742 0.8075257578020009 0.8071613459879207 0.8062689864678265 0.8054190463540974 0.8052115583075335 0.8046378257075177 0.8034095030830612 0.8011171884091702 0.8002655748589134 0.7995599559955996 0.7978451034777043 0.795842538662382 0.795842538662382 0.794818957904033 0.7941934867729921 0.7926443202979516 0.7901512423478574 0.7897340754483612 0.7881531298499741 0.786790071518721 0.7854793619470481 0.7849393233143133 0.7838773102634683 0.7828729706855428 0.7813668311649469 0.7806190125276344 0.7785457163426925 0.777079508543991 0.7754269515314942 0.7726873437394418 0.7704537926465718 0.7683568136077388 0.767055541465889 0.7657792775903012 0.7641013091887924 0.7626751668972154 0.7607951864086834 0.7596098014167925 0.7583452211126962 0.7554436048138818 0.7522638713572252 0.7494613229907348 0.746392515460648 0.7410261715113272 0.7345360824742269 0.7285439003903355 0.7197458361413305 0.7066337233695399 0.6792361080412006 0.6010420686993438 0.3245244215938303 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/UCF.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.9831223628691983 0.976905311778291 0.9720062208398134 0.9618497109826589 0.94579945799458 0.9367897727272727 0.9345734445157152 0.9249862107004965 0.9184275184275185 0.9067538126361656 0.9005503144654088 0.8951902368987796 0.8899474375821288 0.8851167020309184 0.8811963445029077 0.8776205450733753 0.8752767527675277 0.872639336711193 0.8719298245614036 0.870215417482642 0.870215417482642 0.870215417482642 0.870215417482642 0.8701020744155417 0.8701020744155417 0.869877212565779 0.868461776539771 0.8679410477163286 0.8673250322026621 0.8649173955296404 0.864139863872948 0.8634247284014486 0.8609121748963438 0.8587577488756534 0.8581041199386141 0.8575185820468839 0.8560311284046692 0.8536664503569111 0.8519882179675994 0.8489737567650363 0.8474694242815949 0.8461613028305545 0.8441902240710977 0.8436377038606837 0.8414370610480821 0.8394974962663622 0.8368497480140087 0.8338211926068412 0.8320666830105418 0.8305152336895837 0.828484279680901 0.8268921711180838 0.8246349681767129 0.8217923353117902 0.8196756601607348 0.8169310683700688 0.8135954361124476 0.8102125941872982 0.8078570957748336 0.8042088954877026 0.8007328784432651 0.7971668934801435 0.7916263310745402 0.7864497041420119 0.7819361566537281 0.7773995915588836 0.7716155509955077 0.7653083342364799 0.7594487145507554 0.7512792681035819 0.7417002012072434 0.7310776452180716 0.7173347214992192 0.6968253968253968 0.6659681833732467 0.6098230156890022 0.41462655601659754 0.18056143949697875 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
/results/IRDST/UIUNet.txt:
--------------------------------------------------------------------------------
1 | 1.0 0.92578125 0.9114470842332614 0.8984910836762688 0.8953553441522104 0.8953553441522104 0.8953553441522104 0.8953553441522104 0.887583035258048 0.8833546189868029 0.8833333333333333 0.8783284023668639 0.8773424190800682 0.8760124610591901 0.8721692491060786 0.8701949860724234 0.8672911787665886 0.86597188560349 0.8632794457274827 0.8591334639669062 0.8572605561277034 0.8555858310626703 0.8544315459738345 0.8520963425512935 0.850187265917603 0.849359494081401 0.8477715003138732 0.84666465439179 0.8448125544899738 0.8428969359331476 0.8406422884900823 0.8373634945397815 0.8360184838266517 0.8360116873630388 0.8331403427512737 0.8311643835616438 0.8281665190434012 0.8271379458530297 0.8252265389021977 0.821983152339389 0.8184236453201971 0.8151091535810034 0.8114715998884448 0.8076159537989532 0.8051467348197762 0.80333218825146 0.7999162829635831 0.7957821024346552 0.7923523332538358 0.7865646258503401 0.7836654589371981 0.7796510343812119 0.77502691065662 0.7684842560916009 0.7622315846697472 0.7562913907284768 0.751033324722294 0.7432177844762623 0.735107421875 0.7290092234454032 0.720497253541486 0.7136258660508084 0.7060597751576638 0.6967243675099867 0.6887599108669742 0.6761861816724285 0.6579604378720952 0.6270903763804372 0.5575123395853899 0.38012735473600423 0.23558300106972674 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
--------------------------------------------------------------------------------
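The result files above each hold a single row of precision values. A minimal sketch for reading such a row follows, assuming the values are the IoU=0.5 precision curve sampled at COCO's 101 recall thresholds (0, 0.01, ..., 1), as written by `test_DMIST.py`. The helper names `parse_pr_curve` and `highest_supported_recall` and the synthetic sample row are hypothetical, introduced only for illustration.

```python
import numpy as np


def parse_pr_curve(text):
    """Parse whitespace-separated precision values into a 1-D float array."""
    return np.array([float(tok) for tok in text.split()], dtype=np.float64)


def highest_supported_recall(precision):
    """Return the largest recall threshold whose precision is still non-zero.

    Assumes the first entry corresponds to recall 0 and the last to recall 1.
    """
    recall_thresholds = np.linspace(0.0, 1.0, len(precision))
    nonzero = np.flatnonzero(precision > 0)
    return float(recall_thresholds[nonzero[-1]]) if nonzero.size else 0.0


# Synthetic row for illustration only: 80 non-zero entries followed by 21 zeros.
sample_text = " ".join(["0.9"] * 80 + ["0.0"] * 21)
curve = parse_pr_curve(sample_text)

# Averaging the sampled precision over all recall thresholds gives an AP-style summary.
ap50 = float(curve.mean())
max_recall = highest_supported_recall(curve)
print("AP50-style mean: %.4f, highest supported recall: %.2f" % (ap50, max_recall))
```

To use it on a real file, read the file's contents as text and pass them to `parse_pr_curve`; the trailing zeros in each row mark recall thresholds the detector never reached.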
/test_DMIST.py:
--------------------------------------------------------------------------------
1 | import colorsys
2 | import json
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from PIL import Image
9 | from pycocotools.coco import COCO
10 | from pycocotools.cocoeval import COCOeval
11 | from tqdm import tqdm
12 |
13 | from nets.LASNet import LASNet
14 | from utils.utils import (cvtColor, get_classes, preprocess_input, resize_image,
15 |                          show_config)
16 | from utils.utils_bbox import decode_outputs, non_max_suppression
17 | map_mode = 0  # 0: predict and evaluate; 1: predict only; 2: evaluate only
18 | cocoGt_path = '/home/public/DMIST/100_coco_val.json'  # or 60_coco_val.json
19 | dataset_img_path = '/home/public/DMIST/'
20 | temp_save_path = 'map_out/coco_eval'
21 |
22 | class MAP_vid(object):
23 |     _defaults = {
24 |
25 |         "model_path" : '/home/LASNet/logs/model.pth',
26 |         "classes_path" : 'model_data/classes.txt',
27 |         "input_shape" : [512, 512],
28 |         "phi" : 's',
29 |         "confidence" : 0.5,
30 |         "nms_iou" : 0.3,
31 |         "letterbox_image" : True,
32 |         "cuda" : True,
33 |     }
34 |
35 |     @classmethod
36 |     def get_defaults(cls, n):
37 |         if n in cls._defaults:
38 |             return cls._defaults[n]
39 |         else:
40 |             return "Unrecognized attribute name '" + n + "'"
41 |
42 |     def __init__(self, **kwargs):
43 |         self.__dict__.update(self._defaults)
44 |         for name, value in kwargs.items():
45 |             setattr(self, name, value)
46 |         self.dataset = kwargs.get("dataset")  # optional dataset handle; None unless passed as a keyword
47 |         self.class_names, self.num_classes = get_classes(self.classes_path)
48 |
49 |         hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
50 |         self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
51 |         self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
52 |         self.generate()
53 |
54 |         show_config(**self._defaults)
55 |
56 |     def generate(self, onnx=False):
57 |         # Build LASNet for 5-frame clips and load the trained weights.
58 |         self.net = LASNet(self.num_classes, num_frame=5)
59 |         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60 |         self.net.load_state_dict(torch.load(self.model_path, map_location=device))
61 |         self.net = self.net.eval()
62 |         print('{} model, and classes loaded.'.format(self.model_path))
63 |         if not onnx:
64 |             if self.cuda:
65 |                 self.net = nn.DataParallel(self.net)
66 |                 self.net = self.net.cuda()
67 |
68 |     def detect_image(self, image_id, images, results):
69 |         # `images` is the list of frames in the clip; the last one is the frame being evaluated.
70 |         image_shape = np.array(np.shape(images[0])[0:2])
71 |         images = [cvtColor(image) for image in images]
72 |         image_data = [resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image) for image in images]
73 |         image_data = [np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)) for image in image_data]
74 |         # Stack T frames of shape (3, H, W) into (3, T, H, W), then add the batch dimension.
75 |         image_data = np.stack(image_data, axis=1)
76 |
77 |         image_data = np.expand_dims(image_data, 0)
78 |
79 |         with torch.no_grad():
80 |             images = torch.from_numpy(image_data)
81 |             if self.cuda:
82 |                 images = images.cuda()
83 |             outputs = self.net(images)
84 |             outputs = decode_outputs(outputs, self.input_shape)
85 |             outputs = non_max_suppression(outputs, self.num_classes, self.input_shape,
86 |                         image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
87 |
88 |             if outputs[0] is None:
89 |                 return results
90 |
91 |             top_label = np.array(outputs[0][:, 6], dtype = 'int32')
92 |             top_conf = outputs[0][:, 4] * outputs[0][:, 5]
93 |             top_boxes = outputs[0][:, :4]
94 |
95 |         for i, c in enumerate(top_label):
96 |             result = {}
97 |             top, left, bottom, right = top_boxes[i]
98 |             # COCO results use [x, y, width, height] boxes.
99 |             result["image_id"] = int(image_id)
100 |             result["category_id"] = clsid2catid[c]
101 |             result["bbox"] = [float(left), float(top), float(right - left), float(bottom - top)]
102 |             result["score"] = float(top_conf[i])
103 |             results.append(result)
104 |         return results
105 |
106 | def get_history_imgs(line):
107 |     # Paths of the 5-frame window ending at `line`; indices below 0 are clamped to frame 0.
108 |     dir_path, file_name = os.path.split(line)
109 |     stem, file_type = os.path.splitext(file_name)
110 |     index = int(stem)
111 |     return [os.path.join(dir_path, "%d%s" % (max(i, 0), file_type)) for i in range(index - 4, index + 1)]
112 |
113 |
114 |
115 | if __name__ == "__main__":
116 |     if not os.path.exists(temp_save_path):
117 |         os.makedirs(temp_save_path)
118 |
119 |     cocoGt = COCO(cocoGt_path)
120 |     ids = list(cocoGt.imgToAnns.keys())
121 |     clsid2catid = cocoGt.getCatIds()
122 |
123 |     if map_mode == 0 or map_mode == 1:
124 |         yolo = MAP_vid(confidence = 0.001, nms_iou = 0.65)
125 |
126 |         with open(os.path.join(temp_save_path, 'eval_results.json'),"w") as f:
127 |             results = []
128 |             for image_id in tqdm(ids):
129 |                 image_path = os.path.join(dataset_img_path, cocoGt.loadImgs(image_id)[0]['file_name'])
130 |                 # Load the 5-frame clip that ends at the current image.
131 |                 images = get_history_imgs(image_path)
132 |                 images = [Image.open(item) for item in images]
133 |                 # image = Image.open(image_path)
134 |                 results = yolo.detect_image(image_id, images, results)
135 |             json.dump(results, f)
136 |
137 |     if map_mode == 0 or map_mode == 2:
138 |         cocoDt = cocoGt.loadRes(os.path.join(temp_save_path, 'eval_results.json'))
139 |         cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
140 |         cocoEval.evaluate()
141 |         cocoEval.accumulate()
142 |         cocoEval.summarize()
143 |
144 |         """
145 |         T: iouThrs [0.5:0.05:0.95] T=10 IoU thresholds for evaluation
146 |         R: recThrs [0:0.01:1] R=101 recall thresholds for evaluation
147 |         K: category ids
148 |         A: [all, small, medium, large] A=4
149 |         M: maxDets [1, 10, 100] M=3 max detections per image
150 |         """
151 |         precisions = cocoEval.eval['precision']  # shape [T, R, K, A, M]
152 |         precision_50 = precisions[0, :, 0, 0, -1]  # IoU=0.5, first category, all areas, maxDets=100
153 |         recalls = cocoEval.eval['recall']  # shape [T, K, A, M]
154 |         recall_50 = recalls[0, 0, 0, -1]
155 |
156 |         with open("pr_results.txt", 'w') as f:
157 |             for pred in precision_50:
158 |                 f.write(str(pred) + '\t')
159 |         precision = np.mean(precision_50[:int(recall_50 * 100)])  # mean precision over recall thresholds below the achieved recall
160 |         print("Precision: %.4f, Recall: %.4f, F1: %.4f" % (precision, recall_50, 2 * recall_50 * precision / (recall_50 + precision)))
161 |         print("Get map done.")
162 |
163 |         import matplotlib.pyplot as plt
164 |         plt.figure(1)
165 |         plt.title('PR Curve')
166 |         plt.xlabel('Recall')
167 |         plt.ylabel('Precision')
168 |
169 |         plt.xlim(0, 105)
170 |         plt.ylim(0, 1.05)
171 |         # The x-axis is the recall threshold index (0..100), i.e. recall in percent.
172 |         plt.plot(precision_50)
173 |         plt.savefig('p-r.png')  # save before show(); saving afterwards can write an empty figure
174 |         plt.show()
175 |
176 |
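The clip assembly in `detect_image` can be checked in isolation. The sketch below reproduces only the array reshaping with NumPy, assuming five already-preprocessed frames of size 512 x 512 with three channels; `stack_clip` is a hypothetical helper name and the zero-filled frames stand in for real images.

```python
import numpy as np


def stack_clip(frames):
    """Turn a list of T frames shaped (H, W, 3) into one array shaped (1, 3, T, H, W)."""
    # Move channels first for each frame: (H, W, 3) -> (3, H, W).
    channel_first = [np.transpose(frame, (2, 0, 1)) for frame in frames]
    # Insert the time axis after the channel axis: (3, T, H, W).
    clip = np.stack(channel_first, axis=1)
    # Add the batch axis in front: (1, 3, T, H, W).
    return np.expand_dims(clip, 0)


# Placeholder frames for illustration only.
frames = [np.zeros((512, 512, 3), dtype=np.float32) for _ in range(5)]
clip = stack_clip(frames)
print(clip.shape)  # (1, 3, 5, 512, 512)
```

Stacking on `axis=1` keeps the channel axis first and places time second, which is the layout the script passes to the network.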
--------------------------------------------------------------------------------
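The precision and recall indexing used at the end of `test_DMIST.py` can be tried without running an evaluation. The sketch below uses random arrays with the shapes that COCOeval documents, `[T, R, K, A, M]` for precision and `[T, K, A, M]` for recall, and applies the same summary formula as the script. The function names `pr_at_iou50` and `summarize` are hypothetical, and the random values are placeholders rather than real results.

```python
import numpy as np

# Dimensions as described in the COCOeval docstring: IoU thresholds, recall
# thresholds, categories, area ranges, and max-detection settings.
T, R, K, A, M = 10, 101, 1, 4, 3
rng = np.random.default_rng(0)
precision = rng.random((T, R, K, A, M))  # placeholder for cocoEval.eval['precision']
recall = rng.random((T, K, A, M))        # placeholder for cocoEval.eval['recall']


def pr_at_iou50(precision, recall, category=0):
    """Select IoU=0.5 (index 0), area 'all' (index 0), and the largest maxDets (index -1)."""
    p_curve = precision[0, :, category, 0, -1]
    r = float(recall[0, category, 0, -1])
    return p_curve, r


def summarize(p_curve, r):
    """Mean precision over recall thresholds below r, and the resulting F1 score."""
    n = int(r * 100)
    p = float(np.mean(p_curve[:n])) if n > 0 else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0
    return p, r, f1


p_curve, r = pr_at_iou50(precision, recall)
p, r, f1 = summarize(p_curve, r)
print("Precision: %.4f, Recall: %.4f, F1: %.4f" % (p, r, f1))
```

The guards for `n == 0` and `p + r == 0` are additions in this sketch; they avoid a mean over an empty slice and a division by zero when the detector reaches no recall.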
/train_DMIST.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os,random
3 | import numpy as np
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | import torch.distributed as dist
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from torch.utils.data import DataLoader
10 | import pycocotools.coco as coco
11 | from nets.LASNet import LASNet
12 | from nets.training import (ModelEMA, YOLOLoss, get_lr_scheduler,
13 |                            set_optimizer_lr, weights_init)
14 | from utils.callbacks import EvalCallback, LossHistory
15 | from utils.dataloader import YoloDataset
16 | from utils.dataloader_for_DMIST import seqDataset, dataset_collate
17 | from utils.utils import get_classes, show_config
18 | from utils.utils_fit import fit_one_epoch
19 |
20 |
21 | if __name__ == "__main__":
22 |     # Hardware, precision, model, and augmentation settings
23 |     Cuda = True
24 |     distributed = False
25 |     sync_bn = False
26 |     fp16 = False
27 |     classes_path = 'model_data/classes.txt'
28 |     model_path = ''  # empty string: no checkpoint is loaded
29 |     input_shape = [512, 512]
30 |     phi = 's'
31 |     mosaic = False
32 |     mosaic_prob = 0.5
33 |     mixup = False
34 |     mixup_prob = 0.5
35 |     special_aug_ratio = 0.7
36 |
37 |     # Training schedule and optimizer settings
38 |     Init_Epoch = 0
39 |     Freeze_Epoch = 0
40 |     Freeze_batch_size = 4
41 |     UnFreeze_Epoch = 100
42 |     Unfreeze_batch_size = 4
43 |     Freeze_Train = False
44 |     Init_lr = 1e-2
45 |     Min_lr = Init_lr * 0.01
46 |     optimizer_type = "sgd"
47 |     momentum = 0.937
48 |     weight_decay = 5e-4
49 |     lr_decay_type = "cos"
50 |
51 |     save_period = 1
52 |     save_dir = 'logs'
53 |     eval_flag = True
54 |     eval_period = 1
55 |     num_workers = 4
56 |
57 |     train_annotation_path = '/home/LASNet/DMIST_train.txt'
58 |     val_annotation_path = '/home/LASNet/DMIST_60_val.txt'  # or DMIST_100_val.txt
59 |
60 |     ngpus_per_node = torch.cuda.device_count()
61 |
62 |     if distributed:
63 |         dist.init_process_group(backend="nccl")
64 |         local_rank = int(os.environ["LOCAL_RANK"])
65 |         rank = int(os.environ["RANK"])
66 |         device = torch.device("cuda", local_rank)
67 |         if local_rank == 0:
68 |             print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")
69 |             print("Gpu Device Count : ", ngpus_per_node)
70 |     else:
71 |         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
72 |         local_rank = 0
73 |         rank = 0
74 |
75 |     seed = 2023
76 |     torch.manual_seed(seed)
77 |     torch.cuda.manual_seed_all(seed)
78 |     np.random.seed(seed)
79 |     random.seed(seed)
80 |     torch.backends.cudnn.deterministic = True
81 |
82 | class_names, num_classes = get_classes(classes_path)
83 |
84 | model = LASNet(num_classes=1, num_frame=5)
85 | weights_init(model)
86 | if model_path != '':
87 |
88 | if local_rank == 0:
89 | print('Load weights {}.'.format(model_path))
90 |
91 | model_dict = model.state_dict()
92 | pretrained_dict = torch.load(model_path, map_location = device)
93 | load_key, no_load_key, temp_dict = [], [], {}
94 | for k, v in pretrained_dict.items():
95 | if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
96 | temp_dict[k] = v
97 | load_key.append(k)
98 | else:
99 | no_load_key.append(k)
100 | model_dict.update(temp_dict)
101 | model.load_state_dict(model_dict)
102 |
103 | if local_rank == 0:
104 | print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key))
105 | print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key))
106 | print("\n\033[1;33;44mNote: it is normal for the head weights not to be loaded, but it is an error if the backbone weights are not loaded.\033[0m")
107 |
108 | yolo_loss = YOLOLoss(num_classes, fp16, strides=[8])
109 |
110 | if local_rank == 0:
111 | time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
112 | log_dir = os.path.join(save_dir, "loss_" + str(time_str))
113 | loss_history = LossHistory(log_dir, model, input_shape=input_shape)
114 | else:
115 | loss_history = None
116 |
117 | if fp16:
118 | from torch.cuda.amp import GradScaler
119 | scaler = GradScaler()
120 | else:
121 | scaler = None
122 |
123 | model_train = model.train()
124 |
125 | if sync_bn and ngpus_per_node > 1 and distributed:
126 | model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train)
127 | elif sync_bn:
128 | print("Sync_bn is not supported on a single GPU or in non-distributed mode.")
129 |
130 | if Cuda:
131 | if distributed:
132 |
133 | model_train = model_train.cuda(local_rank)
134 | model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank],find_unused_parameters=True)
135 | else:
136 | model_train = torch.nn.DataParallel(model)
137 | cudnn.benchmark = True
138 | model_train = model_train.cuda()
139 |
140 | ema = ModelEMA(model_train)
141 |
142 | with open(train_annotation_path, encoding='utf-8') as f:
143 | train_lines = f.readlines()
144 | with open(val_annotation_path, encoding='utf-8') as f:
145 | val_lines = f.readlines()
146 | num_train = len(train_lines)
147 | num_val = len(val_lines)
148 |
149 |
150 | if local_rank == 0:
151 | show_config(
152 | classes_path = classes_path, model_path = model_path, input_shape = input_shape, \
153 | Init_Epoch = Init_Epoch, Freeze_Epoch = Freeze_Epoch, UnFreeze_Epoch = UnFreeze_Epoch, Freeze_batch_size = Freeze_batch_size, Unfreeze_batch_size = Unfreeze_batch_size, Freeze_Train = Freeze_Train, \
154 | Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \
155 | save_period = save_period, save_dir = log_dir, num_workers = num_workers, num_train = num_train, num_val = num_val
156 | )
157 |
158 | wanted_step = 5e4 if optimizer_type == "sgd" else 1.5e4
159 | total_step = num_train // Unfreeze_batch_size * UnFreeze_Epoch
160 | if total_step <= wanted_step:
161 | if num_train // Unfreeze_batch_size == 0:
162 | raise ValueError('The dataset is too small for training. Please expand the dataset.')
163 | wanted_epoch = wanted_step // (num_train // Unfreeze_batch_size) + 1
164 | print("\n\033[1;33;44m[Warning] When using the %s optimizer, it is recommended to set the total training step size above %d. \033[0m"%(optimizer_type, wanted_step))
165 | print("\033[1;33;44m[Warning] The total training data amount of this run is %d, the Unfreeze_batch_size is %d, a total of %d epochs are trained, and the total training step size is %d. \033[0m"%(num_train, Unfreeze_batch_size, UnFreeze_Epoch, total_step))
166 | print("\033[1;33;44m[Warning] Since the total training step size is %d, which is less than the recommended total step size %d, it is recommended to set the total epoch to %d. \033[0m"%(total_step, wanted_step, wanted_epoch))
167 |
168 |
169 | if True:
170 | UnFreeze_flag = False
171 | if Freeze_Train:
172 | for param in model.backbone.parameters():
173 | param.requires_grad = False
174 |
175 | batch_size = Freeze_batch_size if Freeze_Train else Unfreeze_batch_size
176 |
177 | nbs = 64
178 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2
179 | lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4
180 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
181 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
182 |
183 | pg0, pg1, pg2 = [], [], []
184 |
185 | for k, v in model.named_modules():
186 |
187 | if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
188 | pg2.append(v.bias)
189 | if 'subnet' not in k and (isinstance(v, nn.BatchNorm2d) or "bn" in k):
190 | #######################################
191 | # if hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
192 | #######################################
193 | pg0.append(v.weight)
194 | elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
195 | pg1.append(v.weight)
196 |
197 |
198 | optimizer = {
199 | 'adam' : optim.Adam(pg0, Init_lr_fit, betas = (momentum, 0.999)),
200 | 'sgd' : optim.SGD(pg0, Init_lr_fit, momentum = momentum, nesterov=True)
201 | }[optimizer_type]
202 | optimizer.add_param_group({"params": pg1, "weight_decay": weight_decay})
203 | optimizer.add_param_group({"params": pg2})
204 |
205 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
206 |
207 | epoch_step = num_train // batch_size
208 | epoch_step_val = num_val // batch_size
209 |
210 | if epoch_step == 0 or epoch_step_val == 0:
211 | raise ValueError("The dataset is too small to continue training. Please expand the dataset. ")
212 |
213 | if ema:
214 | ema.updates = epoch_step * Init_Epoch
215 |
216 | train_dataset = seqDataset(train_annotation_path, input_shape[0], 5, 'train')
217 | val_dataset = seqDataset(val_annotation_path, input_shape[0], 5, 'val')
218 |
219 | if distributed:
220 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,)
221 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False,)
222 | batch_size = batch_size // ngpus_per_node
223 | shuffle = False
224 | else:
225 | train_sampler = None
226 | val_sampler = None
227 | shuffle = True
228 |
229 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
230 | drop_last=True, collate_fn=dataset_collate, sampler=train_sampler)
231 |
232 | gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
233 | drop_last=True, collate_fn=dataset_collate, sampler=val_sampler)
234 |
235 | if local_rank == 0:
236 | eval_callback = EvalCallback(model, input_shape, class_names, num_classes, val_lines, log_dir, Cuda, \
237 | eval_flag=eval_flag, period=eval_period)
238 | else:
239 | eval_callback = None
240 |
241 |
242 | for epoch in range(Init_Epoch, UnFreeze_Epoch):
243 |
244 | if epoch >= Freeze_Epoch and not UnFreeze_flag and Freeze_Train:
245 | batch_size = Unfreeze_batch_size
246 |
247 | nbs = 64
248 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 5e-2
249 | lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4
250 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
251 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
252 |
253 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
254 |
255 | for param in model.backbone.parameters():
256 | param.requires_grad = True
257 |
258 | epoch_step = num_train // batch_size
259 | epoch_step_val = num_val // batch_size
260 |
261 | if epoch_step == 0 or epoch_step_val == 0:
262 | raise ValueError("The dataset is too small to continue training. Please expand the dataset.")
263 |
264 | if distributed:
265 | batch_size = batch_size // ngpus_per_node
266 |
267 | if ema:
268 | ema.updates = epoch_step * epoch
269 |
270 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
271 | drop_last=True, collate_fn=dataset_collate, sampler=train_sampler)
272 | gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
273 | drop_last=True, collate_fn=dataset_collate, sampler=val_sampler)
274 |
275 | UnFreeze_flag = True
276 |
277 | gen.dataset.epoch_now = epoch
278 | gen_val.dataset.epoch_now = epoch
279 |
280 | if distributed:
281 | train_sampler.set_epoch(epoch)
282 |
283 | set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
284 |
285 | fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, log_dir, local_rank)
286 |
287 | if distributed:
288 | dist.barrier()
289 |
290 | if local_rank == 0:
291 | loss_history.writer.close()
292 |
293 |
294 |
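The learning-rate scaling in `train_DMIST.py` (the `Init_lr_fit` / `Min_lr_fit` lines) can be checked in isolation. The sketch below is not part of the repository; `scale_lr` is a hypothetical helper name that only repeats the arithmetic shown in the script, using the script's defaults (`Unfreeze_batch_size = 4`, `Init_lr = 1e-2`, `Min_lr = Init_lr * 0.01`, `nbs = 64`).

```python
def scale_lr(batch_size, init_lr, min_lr, optimizer_type="sgd", nbs=64):
    """Scale the learning rates linearly with batch_size / nbs, then clamp them.

    Hypothetical helper: mirrors the expressions in train_DMIST.py.
    """
    lr_limit_max = 1e-3 if optimizer_type == "adam" else 5e-2
    lr_limit_min = 3e-4 if optimizer_type == "adam" else 5e-4
    init_lr_fit = min(max(batch_size / nbs * init_lr, lr_limit_min), lr_limit_max)
    min_lr_fit = min(max(batch_size / nbs * min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
    return init_lr_fit, min_lr_fit


# Defaults from the training script: SGD, batch size 4.
init_lr_fit, min_lr_fit = scale_lr(batch_size=4, init_lr=1e-2, min_lr=1e-4)
print(init_lr_fit, min_lr_fit)
```

With these defaults the scaled values stay inside the SGD clamp range, so the effective initial learning rate is 4 / 64 of `Init_lr` rather than `Init_lr` itself.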
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/utils/callbacks.py:
--------------------------------------------------------------------------------
2 | import os
3 | import torch
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import scipy.signal
7 | from matplotlib import pyplot as plt
8 | from torch.utils.tensorboard import SummaryWriter
9 | import shutil
10 | import numpy as np
11 | from PIL import Image
12 | from tqdm import tqdm
13 | from .utils import cvtColor, preprocess_input, resize_image
14 | from .utils_bbox import decode_outputs, non_max_suppression
15 | from .utils_map import get_coco_map, get_map
16 |
17 | class LossHistory():
18 | def __init__(self, log_dir, model, input_shape):
19 | self.log_dir = log_dir
20 | self.losses = []
21 | self.val_loss = []
22 |
23 | os.makedirs(self.log_dir)
24 | self.writer = SummaryWriter(self.log_dir)
25 | try:
26 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
27 | self.writer.add_graph(model, dummy_input)
28 | except:
29 | pass
30 |
31 | def append_loss(self, epoch, loss, val_loss):
32 | if not os.path.exists(self.log_dir):
33 | os.makedirs(self.log_dir)
34 |
35 | self.losses.append(loss)
36 | self.val_loss.append(val_loss)
37 |
38 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
39 | f.write(str(loss))
40 | f.write("\n")
41 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
42 | f.write(str(val_loss))
43 | f.write("\n")
44 |
45 | self.writer.add_scalar('loss', loss, epoch)
46 | self.writer.add_scalar('val_loss', val_loss, epoch)
47 | self.loss_plot()
48 |
49 | def loss_plot(self):
50 | iters = range(len(self.losses))
51 |
52 | plt.figure()
53 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
54 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
55 | try:
56 | if len(self.losses) < 25:
57 | num = 5
58 | else:
59 | num = 15
60 |
61 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
62 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
63 | except:
64 | pass
65 |
66 | plt.grid(True)
67 | plt.xlabel('Epoch')
68 | plt.ylabel('Loss')
69 | plt.legend(loc="upper right")
70 |
71 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
72 |
73 | plt.cla()
74 | plt.close("all")
75 |
76 | class EvalCallback():
77 | def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \
78 | map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
79 | super(EvalCallback, self).__init__()
80 |
81 | self.net = net
82 | self.input_shape = input_shape
83 | self.class_names = class_names
84 | self.num_classes = num_classes
85 | self.val_lines = val_lines
86 |
87 | self.log_dir = log_dir
88 | self.cuda = cuda
89 | self.map_out_path = map_out_path
90 | self.max_boxes = max_boxes
91 | self.confidence = confidence
92 | self.nms_iou = nms_iou
93 | self.letterbox_image = letterbox_image
94 | self.MINOVERLAP = MINOVERLAP
95 | self.eval_flag = eval_flag
96 | self.period = period
97 |
98 | self.maps = [0]
99 | self.epoches = [0]
100 | if self.eval_flag:
101 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
102 | f.write(str(0))
103 | f.write("\n")
104 |
105 | def get_history_imgs(self, line):
106 | dir_path = line.replace(line.split('/')[-1],'')
107 | file_type = line.split('.')[-1]
108 | index = int(line.split('/')[-1][:-4])
109 |
110 | return [os.path.join(dir_path, "%d.%s" % (max(id, 0),file_type)) for id in range(index - 4, index + 1)]
111 |
112 |
113 |
114 | def get_map_txt(self, image_id, images, class_names, map_out_path):
115 | f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
116 | image_shape = np.array(np.shape(images[0])[0:2])
117 | images = [cvtColor(image) for image in images]
118 | image_data = [resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) for image in images]
119 | image_data = [np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)) for image in image_data]
120 | image_data = np.stack(image_data, axis=1)
121 | image_data = np.expand_dims(image_data, 0)
122 |
123 | with torch.no_grad():
124 | images = torch.from_numpy(image_data)
125 | if self.cuda:
126 | images = images.cuda()
127 | outputs = self.net(images)
128 | outputs = decode_outputs(outputs, self.input_shape)
129 | results = non_max_suppression(outputs, self.num_classes, self.input_shape,
130 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
131 |
132 | if results[0] is None:
133 | return
134 |
135 | top_label = np.array(results[0][:, 6], dtype = 'int32')
136 | top_conf = results[0][:, 4] * results[0][:, 5]
137 | top_boxes = results[0][:, :4]
138 |
139 | top_100 = np.argsort(top_conf)[::-1][:self.max_boxes]  # keep the highest-confidence boxes
140 | top_boxes = top_boxes[top_100]
141 | top_conf = top_conf[top_100]
142 | top_label = top_label[top_100]
143 |
144 | for i, c in list(enumerate(top_label)):
145 | predicted_class = self.class_names[int(c)]
146 | box = top_boxes[i]
147 | score = str(top_conf[i])
148 |
149 | top, left, bottom, right = box
150 | if predicted_class not in class_names:
151 | continue
152 |
153 | f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
154 |
155 | f.close()
156 | return
157 |
158 | def on_epoch_end(self, epoch, model_eval):
159 | if epoch % self.period == 0 and self.eval_flag:
160 | self.net = model_eval
161 | if not os.path.exists(self.map_out_path):
162 | os.makedirs(self.map_out_path)
163 | if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
164 | os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
165 | if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
166 | os.makedirs(os.path.join(self.map_out_path, "detection-results"))
167 | print("Get map.")
168 | for annotation_line in tqdm(self.val_lines):
169 | line = annotation_line.split()
170 | image_id = "-".join(line[0].split("/")[6:8]).split('.')[0]
171 | # cb update
172 | images = self.get_history_imgs(line[0])
173 | images = [Image.open(item) for item in images]
174 | # image = Image.open(line[0])
175 | gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
176 | self.get_map_txt(image_id, images, self.class_names, self.map_out_path)
177 | with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
178 | for box in gt_boxes:
179 | left, top, right, bottom, obj = box
180 | obj_name = self.class_names[obj]
181 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
182 |
183 | print("Calculate Map.")
184 | try:
185 | temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
186 | except:
187 | temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
188 | self.maps.append(temp_map)
189 | self.epoches.append(epoch)
190 |
191 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
192 | f.write(str(temp_map))
193 | f.write("\n")
194 |
195 | plt.figure()
196 | plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='val map')
197 |
198 | plt.grid(True)
199 | plt.xlabel('Epoch')
200 | plt.ylabel('Map %s'%str(self.MINOVERLAP))
201 | plt.title('A Map Curve')
202 | plt.legend(loc="upper right")
203 |
204 | plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
205 | plt.cla()
206 | plt.close("all")
207 |
208 | print("Get map done.")
209 | shutil.rmtree(self.map_out_path)
210 |
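`EvalCallback.get_history_imgs` builds the five-frame input window by clamping frame indices at 0, so early frames of a sequence repeat frame 0. The sketch below illustrates that clamping. It is a variant rather than the repository's code: `history_frame_paths` is a hypothetical name, it uses `posixpath` helpers instead of string replacement, and the example path is made up.

```python
import posixpath


def history_frame_paths(frame_path, num_frame=5):
    """Return the paths of the current frame and its num_frame - 1 predecessors.

    Assumes frames are named <integer>.<ext> inside one directory, as in the
    annotation files; indices below 0 are clamped to 0.
    """
    dir_path, file_name = posixpath.split(frame_path)
    stem, ext = posixpath.splitext(file_name)
    index = int(stem)
    return [
        posixpath.join(dir_path, f"{max(i, 0)}{ext}")
        for i in range(index - num_frame + 1, index + 1)
    ]


paths = history_frame_paths("/data/seq1/3.bmp")
for p in paths:
    print(p)
```

For frame 3 the window needs index -1, which is clamped, so frame 0 appears twice at the start of the list.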
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | from random import sample, shuffle
2 | import cv2
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 | from torch.utils.data.dataset import Dataset
7 | from utils.utils import cvtColor, preprocess_input
8 |
9 | class YoloDataset(Dataset):
10 | def __init__(self, annotation_lines, input_shape, num_classes, epoch_length, \
11 | mosaic, mixup, mosaic_prob, mixup_prob, train, special_aug_ratio = 0.7):
12 | super(YoloDataset, self).__init__()
13 | self.annotation_lines = annotation_lines
14 | self.input_shape = input_shape
15 | self.num_classes = num_classes
16 | self.epoch_length = epoch_length
17 | self.mosaic = mosaic
18 | self.mosaic_prob = mosaic_prob
19 | self.mixup = mixup
20 | self.mixup_prob = mixup_prob
21 | self.train = train
22 | self.special_aug_ratio = special_aug_ratio
23 |
24 | self.epoch_now = -1
25 | self.length = len(self.annotation_lines)
26 |
27 | def __len__(self):
28 | return self.length
29 |
30 | def __getitem__(self, index):
31 | index = index % self.length
32 | if self.mosaic and self.rand() < self.mosaic_prob and self.epoch_now < self.epoch_length * self.special_aug_ratio:
33 | lines = sample(self.annotation_lines, 3)
34 | lines.append(self.annotation_lines[index])
35 | shuffle(lines)
36 | image, box = self.get_random_data_with_Mosaic(lines, self.input_shape)
37 |
38 | if self.mixup and self.rand() < self.mixup_prob:
39 | lines = sample(self.annotation_lines, 1)
40 | image_2, box_2 = self.get_random_data(lines[0], self.input_shape, random = self.train)
41 | image, box = self.get_random_data_with_MixUp(image, box, image_2, box_2)
42 | else:
43 | image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train)
44 |
45 | image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
46 | box = np.array(box, dtype=np.float32)
47 | if len(box) != 0:
48 | box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
49 | box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
50 | return image, box
51 |
52 | def rand(self, a=0, b=1):
53 | return np.random.rand()*(b-a) + a
54 |
55 | def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
56 | line = annotation_line.split()
57 | image = Image.open(line[0])
58 | image = cvtColor(image)
59 | iw, ih = image.size
60 | h, w = input_shape
61 | box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
62 |
63 | if not random:
64 | scale = min(w/iw, h/ih)
65 | nw = int(iw*scale)
66 | nh = int(ih*scale)
67 | dx = (w-nw)//2
68 | dy = (h-nh)//2
69 |
70 | image = image.resize((nw,nh), Image.BICUBIC)
71 | new_image = Image.new('RGB', (w,h), (128,128,128))
72 | new_image.paste(image, (dx, dy))
73 | image_data = np.array(new_image, np.float32)
74 |
75 | if len(box)>0:
76 | np.random.shuffle(box)
77 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
78 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
79 | box[:, 0:2][box[:, 0:2]<0] = 0
80 | box[:, 2][box[:, 2]>w] = w
81 | box[:, 3][box[:, 3]>h] = h
82 | box_w = box[:, 2] - box[:, 0]
83 | box_h = box[:, 3] - box[:, 1]
84 | box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
85 |
86 | return image_data, box
87 |
88 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
89 | scale = self.rand(.25, 2)
90 | if new_ar < 1:
91 | nh = int(scale*h)
92 | nw = int(nh*new_ar)
93 | else:
94 | nw = int(scale*w)
95 | nh = int(nw/new_ar)
96 | image = image.resize((nw,nh), Image.BICUBIC)
97 |
98 | dx = int(self.rand(0, w-nw))
99 | dy = int(self.rand(0, h-nh))
100 | new_image = Image.new('RGB', (w,h), (128,128,128))
101 | new_image.paste(image, (dx, dy))
102 | image = new_image
103 |
104 | flip = self.rand()<.5
105 | if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
106 | image_data = np.array(image, np.uint8)
107 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
108 | hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
109 | dtype = image_data.dtype
110 | x = np.arange(0, 256, dtype=r.dtype)
111 | lut_hue = ((x * r[0]) % 180).astype(dtype)
112 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
113 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
114 |
115 | image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
116 | image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
117 |
118 | if len(box)>0:
119 | np.random.shuffle(box)
120 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
121 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
122 | if flip: box[:, [0,2]] = w - box[:, [2,0]]
123 | box[:, 0:2][box[:, 0:2]<0] = 0
124 | box[:, 2][box[:, 2]>w] = w
125 | box[:, 3][box[:, 3]>h] = h
126 | box_w = box[:, 2] - box[:, 0]
127 | box_h = box[:, 3] - box[:, 1]
128 | box = box[np.logical_and(box_w>1, box_h>1)]
129 |
130 | return image_data, box
131 |
132 | def merge_bboxes(self, bboxes, cutx, cuty):
133 | merge_bbox = []
134 | for i in range(len(bboxes)):
135 | for box in bboxes[i]:
136 | tmp_box = []
137 | x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
138 |
139 | if i == 0:
140 | if y1 > cuty or x1 > cutx:
141 | continue
142 | if y2 >= cuty and y1 <= cuty:
143 | y2 = cuty
144 | if x2 >= cutx and x1 <= cutx:
145 | x2 = cutx
146 |
147 | if i == 1:
148 | if y2 < cuty or x1 > cutx:
149 | continue
150 | if y2 >= cuty and y1 <= cuty:
151 | y1 = cuty
152 | if x2 >= cutx and x1 <= cutx:
153 | x2 = cutx
154 |
155 | if i == 2:
156 | if y2 < cuty or x2 < cutx:
157 | continue
158 | if y2 >= cuty and y1 <= cuty:
159 | y1 = cuty
160 | if x2 >= cutx and x1 <= cutx:
161 | x1 = cutx
162 |
163 | if i == 3:
164 | if y1 > cuty or x2 < cutx:
165 | continue
166 | if y2 >= cuty and y1 <= cuty:
167 | y2 = cuty
168 | if x2 >= cutx and x1 <= cutx:
169 | x1 = cutx
170 | tmp_box.append(x1)
171 | tmp_box.append(y1)
172 | tmp_box.append(x2)
173 | tmp_box.append(y2)
174 | tmp_box.append(box[-1])
175 | merge_bbox.append(tmp_box)
176 | return merge_bbox
177 |
178 | def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4):
179 | h, w = input_shape
180 | min_offset_x = self.rand(0.3, 0.7)
181 | min_offset_y = self.rand(0.3, 0.7)
182 |
183 | image_datas = []
184 | box_datas = []
185 | index = 0
186 | for line in annotation_line:
187 |
188 | line_content = line.split()
189 | image = Image.open(line_content[0])
190 | image = cvtColor(image)
191 | iw, ih = image.size
192 | box = np.array([np.array(list(map(int,box.split(',')))) for box in line_content[1:]])
193 | flip = self.rand()<.5
194 | if flip and len(box)>0:
195 | image = image.transpose(Image.FLIP_LEFT_RIGHT)
196 | box[:, [0,2]] = iw - box[:, [2,0]]
197 | new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
198 | scale = self.rand(.4, 1)
199 | if new_ar < 1:
200 | nh = int(scale*h)
201 | nw = int(nh*new_ar)
202 | else:
203 | nw = int(scale*w)
204 | nh = int(nw/new_ar)
205 | image = image.resize((nw, nh), Image.BICUBIC)
206 |
207 | if index == 0:
208 | dx = int(w*min_offset_x) - nw
209 | dy = int(h*min_offset_y) - nh
210 | elif index == 1:
211 | dx = int(w*min_offset_x) - nw
212 | dy = int(h*min_offset_y)
213 | elif index == 2:
214 | dx = int(w*min_offset_x)
215 | dy = int(h*min_offset_y)
216 | elif index == 3:
217 | dx = int(w*min_offset_x)
218 | dy = int(h*min_offset_y) - nh
219 |
220 | new_image = Image.new('RGB', (w,h), (128,128,128))
221 | new_image.paste(image, (dx, dy))
222 | image_data = np.array(new_image)
223 |
224 | index = index + 1
225 | box_data = []
226 |
227 | if len(box)>0:
228 | np.random.shuffle(box)
229 | box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
230 | box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
231 | box[:, 0:2][box[:, 0:2]<0] = 0
232 | box[:, 2][box[:, 2]>w] = w
233 | box[:, 3][box[:, 3]>h] = h
234 | box_w = box[:, 2] - box[:, 0]
235 | box_h = box[:, 3] - box[:, 1]
236 | box = box[np.logical_and(box_w>1, box_h>1)]
237 | box_data = np.zeros((len(box),5))
238 | box_data[:len(box)] = box
239 |
240 | image_datas.append(image_data)
241 | box_datas.append(box_data)
242 |
243 | cutx = int(w * min_offset_x)
244 | cuty = int(h * min_offset_y)
245 |
246 | new_image = np.zeros([h, w, 3])
247 | new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :]
248 | new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :]
249 | new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :]
250 | new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :]
251 |
252 | new_image = np.array(new_image, np.uint8)
253 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
254 | hue, sat, val = cv2.split(cv2.cvtColor(new_image, cv2.COLOR_RGB2HSV))
255 | dtype = new_image.dtype
256 | x = np.arange(0, 256, dtype=r.dtype)
257 | lut_hue = ((x * r[0]) % 180).astype(dtype)
258 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
259 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
260 |
261 | new_image = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
262 | new_image = cv2.cvtColor(new_image, cv2.COLOR_HSV2RGB)
263 | new_boxes = self.merge_bboxes(box_datas, cutx, cuty)
264 |
265 | return new_image, new_boxes
266 |
267 | def get_random_data_with_MixUp(self, image_1, box_1, image_2, box_2):
268 | new_image = np.array(image_1, np.float32) * 0.5 + np.array(image_2, np.float32) * 0.5
269 | if len(box_1) == 0:
270 | new_boxes = box_2
271 | elif len(box_2) == 0:
272 | new_boxes = box_1
273 | else:
274 | new_boxes = np.concatenate([box_1, box_2], axis=0)
275 | return new_image, new_boxes
276 |
277 |
278 | def yolo_dataset_collate(batch):
279 | images = []
280 | bboxes = []
281 | for img, box in batch:
282 | images.append(img)
283 | bboxes.append(box)
284 | images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
285 | bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
286 | return images, bboxes
287 |
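`YoloDataset.__getitem__` converts each box from corner format `(x1, y1, x2, y2, class)` to center format `(cx, cy, w, h, class)` with two in-place slice operations. A minimal pure-Python sketch of the same arithmetic for a single box follows; `corners_to_center` is a hypothetical name used only for illustration.

```python
def corners_to_center(box):
    """Convert one (x1, y1, x2, y2, cls) box to (cx, cy, w, h, cls)."""
    x1, y1, x2, y2, cls = box
    w = x2 - x1          # same as box[:, 2:4] - box[:, 0:2]
    h = y2 - y1
    cx = x1 + w / 2      # same as box[:, 0:2] + box[:, 2:4] / 2
    cy = y1 + h / 2
    return [cx, cy, w, h, cls]


converted = corners_to_center([10, 20, 30, 60, 0])
print(converted)
```

The order of the two steps matters in the array version: width and height are written first, and the center is then derived from the still-unchanged top-left corner.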
--------------------------------------------------------------------------------
/utils/dataloader_for_DMIST.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | from torch.utils.data.dataset import Dataset
6 | from torch.utils.data import DataLoader
7 | import xml.etree.ElementTree as ET
8 | import time
9 | import torch
10 | import copy
11 |
12 | # convert to RGB
13 | def cvtColor(image):
14 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
15 | return image
16 | else:
17 | image = image.convert('RGB')
18 | return image
19 |
20 | # normalization
21 | def preprocess(image):
22 | image /= 255.0
23 | image -= np.array([0.485, 0.456, 0.406])
24 | image /= np.array([0.229, 0.224, 0.225])
25 | return image
26 |
27 | def rand(a=0, b=1):
28 | return np.random.rand()*(b-a) + a
29 |
30 | def augmentation(images, boxes,h, w, hue=.1, sat=0.7, val=0.4):
31 | # images [5, w, h, 3], bbox [:,4]
32 | flip = rand()<.5
33 | if flip:
34 | for i in range(len(images)):
35 | images[i] = Image.fromarray(images[i].astype('uint8')).convert('RGB').transpose(Image.Transpose.FLIP_LEFT_RIGHT)
36 | for i in range(len(boxes)):
37 | boxes[i][[0,2]] = w - boxes[i][[2,0]]
38 |
39 | images = np.array(images, np.uint8)
40 | r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
41 | for i in range(len(images)):
42 | hue, sat, val = cv2.split(cv2.cvtColor(images[i], cv2.COLOR_RGB2HSV))
43 | dtype = images[i].dtype
44 | x = np.arange(0, 256, dtype=r.dtype)
45 | lut_hue = ((x * r[0]) % 180).astype(dtype)
46 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
47 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
48 | images[i] = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
49 | images[i] = cv2.cvtColor(images[i], cv2.COLOR_HSV2RGB)
50 |
51 | return np.array(images,dtype=np.float32), np.array(boxes,dtype=np.float32)
52 |
53 |
54 |
55 | class seqDataset(Dataset):
56 | def __init__(self, dataset_path, image_size, num_frame=5 ,type='train'):
57 | super(seqDataset, self).__init__()
58 | self.dataset_path = dataset_path
59 | self.img_idx = []
60 | self.anno_idx = []
61 | self.image_size = image_size
62 | self.num_frame = num_frame
63 |
64 |
65 | if type == 'train':
66 | self.txt_path = dataset_path
67 | self.aug = True
68 | else:
69 | self.txt_path = dataset_path
70 | self.aug = False
71 | with open(self.txt_path) as f:
72 | data_lines = f.readlines()
73 | self.length = len(data_lines)
74 | for line in data_lines:
75 | line = line.strip('\n').split()
76 | self.img_idx.append(line[0])
77 | self.anno_idx.append(np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]))
78 |
79 | def __len__(self):
80 | return self.length
81 |
82 | def __getitem__(self, index):
83 |
84 | images, box, multi_box = self.get_data(index)
85 | images = np.transpose(preprocess(images),(3, 0, 1, 2))
86 | if len(box) != 0:
87 | box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
88 | box[:, 0:2] = box[:, 0:2] + ( box[:, 2:4] / 2 )
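89 | for box in multi_box:  # note: this rebinds `box`, so the `box` returned below is multi_box[-1] (the current frame's labels)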
90 | if len(box) != 0:
91 | box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
92 | box[:, 0:2] = box[:, 0:2] + ( box[:, 2:4] / 2 )
93 | return images, box, multi_box
94 |
95 | def get_data(self, index):
96 | image_data = []
97 | multi_frame_label = []
98 |
99 | h, w = self.image_size, self.image_size
100 | file_name = self.img_idx[index]
101 | image_id = int(file_name.split("/")[-1][:-4])
102 | image_path = file_name.replace(file_name.split("/")[-1], '')
103 | label_data = self.anno_idx[index]
104 |
105 | n = []
106 | m = []
107 |
108 | for id in range(0, self.num_frame):
109 |
110 | with open(image_path + '%d.txt' % max(image_id - id, 0), 'r') as f:
111 | lines = f.readlines()
112 | labels_box = [[int(num) for num in line.strip().split()] for line in lines]
113 | labels = np.array(labels_box)
114 |
115 | img = Image.open(image_path +'%d.bmp' % max(image_id - id, 0))
116 | img = cvtColor(img)
117 | iw, ih = img.size
118 |
119 | scale = min(w/iw, h/ih)
120 | nw = int(iw*scale)
121 | nh = int(ih*scale)
122 | dx = (w-nw)//2
123 | dy = (h-nh)//2
124 |
125 | img = img.resize((nw, nh), Image.Resampling.BICUBIC)
126 | new_img = Image.new('RGB', (w,h), (128, 128, 128))
127 | new_img.paste(img, (dx, dy))
128 | image_data.append(np.array(new_img, np.float32))
129 |
130 | if len(label_data) > 0 and id == 0:
131 | # np.random.shuffle(label_data)
132 | label_data[:, [0, 2]] = label_data[:, [0, 2]]*nw/iw + dx
133 | label_data[:, [1, 3]] = label_data[:, [1, 3]]*nh/ih + dy
134 |
135 | label_data[:, 0:2][label_data[:, 0:2]<0] = 0
136 | label_data[:, 2][label_data[:, 2]>w] = w
137 | label_data[:, 3][label_data[:, 3]>h] = h
138 | # discard invalid box
139 | box_w = label_data[:, 2] - label_data[:, 0]
140 | box_h = label_data[:, 3] - label_data[:, 1]
141 | label_data = label_data[np.logical_and(box_w>1, box_h>1)]
142 |
143 | if len(labels) > 0:
144 | #np.random.shuffle(labels)
145 | labels[:, [0, 2]] = labels[:, [0, 2]]*nw/iw + dx
146 | labels[:, [1, 3]] = labels[:, [1, 3]]*nh/ih + dy
147 |
148 | labels[:, 0:2][labels[:, 0:2]<0] = 0
149 | labels[:, 2][labels[:, 2]>w] = w
150 | labels[:, 3][labels[:, 3]>h] = h
151 | #discard invalid box
152 | # box_w = labels[:, 2] - labels[:, 0]
153 | # box_h = labels[:, 3] - labels[:, 1]
154 | # labels = labels[np.logical_and(box_w>1, box_h>1)]
155 |
156 | multi_frame_label.append(np.array(labels, dtype=np.float32))
157 | multi_frame_label = multi_frame_label[::-1]
158 | image_data = np.array(image_data[::-1])
159 | label_data = np.array(label_data, dtype=np.float32)
160 |
161 | return image_data, label_data, multi_frame_label
162 |
163 | def dataset_collate(batch):
164 | images = []
165 | bboxes = []
166 | multi_boxes = []
167 | for img, box, multi_box in batch:
168 | images.append(img)
169 | bboxes.append(box)
170 | multi_boxes.append(multi_box)
171 | images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
172 | bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
173 |
174 | return images, bboxes, multi_boxes
175 |
176 |
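The letterbox mapping that `get_data` applies to the labels, followed by the corner-to-center conversion in `__getitem__`, can be sketched in isolation as below. This is a minimal NumPy sketch, not the repository's code: the helper names `letterbox_boxes` and `xyxy_to_cxcywh` are hypothetical, and it assumes each box is exactly `[x1, y1, x2, y2]` with no class column.

```python
import numpy as np


def letterbox_boxes(boxes_xyxy, image_size, target_size):
    """Map [x1, y1, x2, y2] boxes from the source image onto a letterboxed canvas.

    image_size and target_size are (width, height) tuples.
    """
    iw, ih = image_size
    w, h = target_size
    scale = min(w / iw, h / ih)              # keep the aspect ratio
    nw, nh = int(iw * scale), int(ih * scale)
    dx, dy = (w - nw) // 2, (h - nh) // 2    # gray padding on each side

    boxes = np.asarray(boxes_xyxy, dtype=np.float32).reshape(-1, 4).copy()
    boxes[:, [0, 2]] = boxes[:, [0, 2]] * nw / iw + dx
    boxes[:, [1, 3]] = boxes[:, [1, 3]] * nh / ih + dy
    # Clip to the canvas, as the loader does.
    boxes[:, 0:2] = np.maximum(boxes[:, 0:2], 0)
    boxes[:, 2] = np.minimum(boxes[:, 2], w)
    boxes[:, 3] = np.minimum(boxes[:, 3], h)
    return boxes


def xyxy_to_cxcywh(boxes):
    """Convert corner boxes to [cx, cy, w, h] without modifying the input."""
    out = boxes.copy()
    out[:, 2:4] = boxes[:, 2:4] - boxes[:, 0:2]     # width, height
    out[:, 0:2] = boxes[:, 0:2] + out[:, 2:4] / 2   # center
    return out


mapped = letterbox_boxes([[10, 20, 14, 26]], image_size=(256, 256), target_size=(512, 512))
centered = xyxy_to_cxcywh(mapped)
```

Returning copies instead of writing in place keeps the annotation arrays held by the dataset unchanged between epochs.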
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | def cvtColor(image):
5 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
6 | return image
7 | else:
8 | image = image.convert('RGB')
9 | return image
10 |
11 | def resize_image(image, size, letterbox_image):
12 | iw, ih = image.size
13 | w, h = size
14 | if letterbox_image:
15 | scale = min(w/iw, h/ih)
16 | nw = int(iw*scale)
17 | nh = int(ih*scale)
18 |
19 | image = image.resize((nw,nh), Image.BICUBIC)
20 | new_image = Image.new('RGB', size, (128,128,128))
21 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
22 | else:
23 | new_image = image.resize((w, h), Image.BICUBIC)
24 | return new_image
25 |
26 | def get_classes(classes_path):
27 | with open(classes_path, encoding='utf-8') as f:
28 | class_names = f.readlines()
29 | class_names = [c.strip() for c in class_names]
30 | return class_names, len(class_names)
31 |
32 | def preprocess_input(image):
33 | image /= 255.0
34 | image -= np.array([0.485, 0.456, 0.406])
35 | image /= np.array([0.229, 0.224, 0.225])
36 | return image
37 |
38 | def get_lr(optimizer):
39 | for param_group in optimizer.param_groups:
40 | return param_group['lr']
41 |
42 | def show_config(**kwargs):
43 | print('Configurations:')
44 | print('-' * 130)
45 | print('|%25s | %100s|' % ('keys', 'values'))
46 | print('-' * 130)
47 | for key, value in kwargs.items():
48 | print('|%25s | %100s|' % (str(key), str(value)))
49 | print('-' * 130)
50 |
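`preprocess_input` above normalizes in place, so it only works on a floating-point array; passing a `uint8` frame would fail on the in-place division. A non-mutating variant might look like the following sketch. The function name `normalize` is hypothetical, and the mean and standard deviation values are the ones used in `preprocess_input`.

```python
import numpy as np

# Per-channel statistics copied from preprocess_input.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize(image):
    """Return a normalized float32 copy of an H x W x 3 image with values in 0..255."""
    image = np.asarray(image, dtype=np.float32) / 255.0
    return (image - MEAN) / STD


frame = np.full((2, 2, 3), 128, dtype=np.uint8)
out = normalize(frame)
```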
--------------------------------------------------------------------------------
/utils/utils_bbox.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.ops import nms, boxes
4 |
5 | def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image):
6 |
7 | box_yx = box_xy[..., ::-1]
8 | box_hw = box_wh[..., ::-1]
9 | input_shape = np.array(input_shape)
10 | image_shape = np.array(image_shape)
11 |
12 | if letterbox_image:
13 |
14 | new_shape = np.round(image_shape * np.min(input_shape/image_shape))
15 | offset = (input_shape - new_shape)/2./input_shape
16 | scale = input_shape/new_shape
17 |
18 | box_yx = (box_yx - offset) * scale
19 | box_hw *= scale
20 |
21 | box_mins = box_yx - (box_hw / 2.)
22 | box_maxes = box_yx + (box_hw / 2.)
23 | boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
24 | boxes *= np.concatenate([image_shape, image_shape], axis=-1)
25 | return boxes
26 |
27 | def decode_outputs(outputs, input_shape):
28 | grids = []
29 | strides = []
30 | hw = [x.shape[-2:] for x in outputs]
31 | outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
32 | outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:])
33 | for h, w in hw:
34 |
35 | grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)], indexing='ij')
36 | grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2)
37 | shape = grid.shape[:2]
38 |
39 | grids.append(grid)
40 | strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h))
41 |
42 | grids = torch.cat(grids, dim=1).type(outputs.type())
43 | strides = torch.cat(strides, dim=1).type(outputs.type())
44 |
45 | outputs[..., :2] = (outputs[..., :2] + grids) * strides
46 | outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
47 |
48 | outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1]
49 | outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0]
50 | return outputs
51 |
52 | def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
53 |
54 | box_corner = prediction.new(prediction.shape)
55 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
56 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
57 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
58 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
59 | prediction[:, :, :4] = box_corner[:, :, :4]
60 |
61 | output = [None for _ in range(len(prediction))]
62 |
63 | for i, image_pred in enumerate(prediction):
64 |
65 | class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
66 | conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
67 |
68 | if not image_pred.size(0):
69 | continue
70 |
71 | detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
72 | detections = detections[conf_mask]
73 |
74 | nms_out_index = boxes.batched_nms(
75 | detections[:, :4],
76 | detections[:, 4] * detections[:, 5],
77 | detections[:, 6],
78 | nms_thres,
79 | )
80 |
81 | output[i] = detections[nms_out_index]
82 |
83 | if output[i] is not None:
84 | output[i] = output[i].cpu().numpy()
85 | box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
86 | output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
87 | return output
88 |
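The letterbox correction in `yolo_correct_boxes` can be checked by hand with a small standalone example. The sketch below follows the same arithmetic for the `letterbox_image=True` case, under the assumption that centers and sizes are normalized to the letterboxed input and that shapes are given as `(height, width)`. The name `undo_letterbox` is hypothetical. As in the original, the result is ordered `[y1, x1, y2, x2]`.

```python
import numpy as np


def undo_letterbox(box_xy, box_wh, input_shape, image_shape):
    """Map normalized center/size boxes on the letterboxed input back to pixel corners.

    input_shape and image_shape are (height, width). Returns [y1, x1, y2, x2].
    """
    box_yx = np.asarray(box_xy, dtype=np.float64)[..., ::-1]
    box_hw = np.asarray(box_wh, dtype=np.float64)[..., ::-1]
    input_shape = np.asarray(input_shape, dtype=np.float64)
    image_shape = np.asarray(image_shape, dtype=np.float64)

    # Size of the resized image inside the padded canvas.
    new_shape = np.round(image_shape * np.min(input_shape / image_shape))
    offset = (input_shape - new_shape) / 2.0 / input_shape   # padding, normalized
    scale = input_shape / new_shape

    box_yx = (box_yx - offset) * scale
    box_hw = box_hw * scale        # not in place, so the caller's array is untouched

    mins = box_yx - box_hw / 2.0
    maxes = box_yx + box_hw / 2.0
    return np.concatenate([mins, maxes], axis=-1) * np.concatenate([image_shape, image_shape])


# A 512 x 256 (width x height) image padded to 512 x 512: 128 rows of padding above and below.
corners = undo_letterbox([[0.5, 0.5]], [[0.1, 0.1]], input_shape=(512, 512), image_shape=(256, 512))
```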
--------------------------------------------------------------------------------
/utils/utils_fit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from tqdm import tqdm
4 | from utils.utils import get_lr
5 |
6 |
7 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
8 | loss = 0
9 | val_loss = 0
10 |
11 | epoch_step = epoch_step // 5  # train on only one fifth of the batches in each epoch
12 |
13 | if local_rank == 0:
14 | print('Start Train')
15 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
16 | model_train.train()
17 | for iteration, batch in enumerate(gen):
18 | if iteration >= epoch_step:
19 | break
20 |
21 | images, targets, multi_targets = batch[0], batch[1], batch[2]
22 | with torch.no_grad():
23 | if cuda:
24 | images = images.cuda(local_rank)
25 | targets = [ann.cuda(local_rank) for ann in targets]
26 | for target in multi_targets:
27 | target = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in target]
28 | target = [ann.cuda(local_rank) for ann in target]
29 |
30 | optimizer.zero_grad()
31 | if not fp16:
32 |
33 | outputs, mm_loss = model_train(images, multi_targets)
34 | loss_value = yolo_loss(outputs, targets) + mm_loss
35 |
36 | loss_value.backward()
37 | optimizer.step()
38 | else:
39 | from torch.cuda.amp import autocast
40 | with autocast():
41 | outputs, mm_loss = model_train(images, multi_targets)
42 | loss_value = yolo_loss(outputs, targets) + mm_loss
43 |
44 | scaler.scale(loss_value).backward()
45 | scaler.step(optimizer)
46 | scaler.update()
47 | if ema:
48 | ema.update(model_train)
49 |
50 | loss += loss_value.item()
51 |
52 | if local_rank == 0:
53 | pbar.set_postfix(**{'loss' : loss / (iteration + 1),
54 | 'lr' : get_lr(optimizer)})
55 | pbar.update(1)
56 |
57 | if local_rank == 0:
58 | pbar.close()
59 | print('Finish Train')
60 | print('Start Validation')
61 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
62 |
63 | if ema:
64 | model_train_eval = ema.ema
65 | else:
66 | model_train_eval = model_train.eval()
67 |
68 | for iteration, batch in enumerate(gen_val):
69 | if iteration >= epoch_step_val:
70 | break
71 | images, targets = batch[0], batch[1]
72 |
73 | with torch.no_grad():
74 | if cuda:
75 | images = images.cuda(local_rank)
76 | targets = [ann.cuda(local_rank) for ann in targets]
77 |
78 | optimizer.zero_grad()
79 | outputs = model_train_eval(images)
80 | loss_value = yolo_loss(outputs, targets)
81 |
82 | val_loss += loss_value.item()
83 | if local_rank == 0:
84 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
85 | pbar.update(1)
86 |
87 | if local_rank == 0:
88 | pbar.close()
89 | print('Finish Validation')
90 | loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
91 | eval_callback.on_epoch_end(epoch + 1, model_train_eval)
92 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
93 | print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
94 |
95 | if ema:
96 | save_state_dict = ema.ema.state_dict()
97 | else:
98 | save_state_dict = model.state_dict()
99 |
100 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
101 | torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val)))
102 |
103 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
104 | print('Save best model to best_epoch_weights.pth')
105 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))
106 |
107 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth"))
108 |
--------------------------------------------------------------------------------
/utils_coco/coco_annotation.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 |
5 | train_datasets_path = "coco_dataset/train2017"
6 | val_datasets_path = "coco_dataset/val2017"
7 |
8 | train_annotation_path = "coco_dataset/annotations/instances_train2017.json"
9 | val_annotation_path = "coco_dataset/annotations/instances_val2017.json"
10 |
11 | train_output_path = "coco_train.txt"
12 | val_output_path = "coco_val.txt"
13 |
14 | if __name__ == "__main__":
15 | name_box_id = defaultdict(list)
16 | id_name = dict()
17 | f = open(train_annotation_path, encoding='utf-8')
18 | data = json.load(f)
19 |
20 | annotations = data['annotations']
21 | for ant in annotations:
22 | id = ant['image_id']
23 | name = os.path.join(train_datasets_path, '%012d.jpg' % id)
24 | cat = ant['category_id']
25 | if cat >= 1 and cat <= 11:
26 | cat = cat - 1
27 | elif cat >= 13 and cat <= 25:
28 | cat = cat - 2
29 | elif cat >= 27 and cat <= 28:
30 | cat = cat - 3
31 | elif cat >= 31 and cat <= 44:
32 | cat = cat - 5
33 | elif cat >= 46 and cat <= 65:
34 | cat = cat - 6
35 | elif cat == 67:
36 | cat = cat - 7
37 | elif cat == 70:
38 | cat = cat - 9
39 | elif cat >= 72 and cat <= 82:
40 | cat = cat - 10
41 | elif cat >= 84 and cat <= 90:
42 | cat = cat - 11
43 | name_box_id[name].append([ant['bbox'], cat])
44 |
45 | f = open(train_output_path, 'w')
46 | for key in name_box_id.keys():
47 | f.write(key)
48 | box_infos = name_box_id[key]
49 | for info in box_infos:
50 | x_min = int(info[0][0])
51 | y_min = int(info[0][1])
52 | x_max = x_min + int(info[0][2])
53 | y_max = y_min + int(info[0][3])
54 |
55 | box_info = " %d,%d,%d,%d,%d" % (
56 | x_min, y_min, x_max, y_max, int(info[1]))
57 | f.write(box_info)
58 | f.write('\n')
59 | f.close()
60 |
61 | name_box_id = defaultdict(list)
62 | id_name = dict()
63 | f = open(val_annotation_path, encoding='utf-8')
64 | data = json.load(f)
65 |
66 | annotations = data['annotations']
67 | for ant in annotations:
68 | id = ant['image_id']
69 | name = os.path.join(val_datasets_path, '%012d.jpg' % id)
70 | cat = ant['category_id']
71 | if cat >= 1 and cat <= 11:
72 | cat = cat - 1
73 | elif cat >= 13 and cat <= 25:
74 | cat = cat - 2
75 | elif cat >= 27 and cat <= 28:
76 | cat = cat - 3
77 | elif cat >= 31 and cat <= 44:
78 | cat = cat - 5
79 | elif cat >= 46 and cat <= 65:
80 | cat = cat - 6
81 | elif cat == 67:
82 | cat = cat - 7
83 | elif cat == 70:
84 | cat = cat - 9
85 | elif cat >= 72 and cat <= 82:
86 | cat = cat - 10
87 | elif cat >= 84 and cat <= 90:
88 | cat = cat - 11
89 | name_box_id[name].append([ant['bbox'], cat])
90 |
91 | f = open(val_output_path, 'w')
92 | for key in name_box_id.keys():
93 | f.write(key)
94 | box_infos = name_box_id[key]
95 | for info in box_infos:
96 | x_min = int(info[0][0])
97 | y_min = int(info[0][1])
98 | x_max = x_min + int(info[0][2])
99 | y_max = y_min + int(info[0][3])
100 |
101 | box_info = " %d,%d,%d,%d,%d" % (
102 | x_min, y_min, x_max, y_max, int(info[1]))
103 | f.write(box_info)
104 | f.write('\n')
105 | f.close()
106 |
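The `if`/`elif` chain in `coco_annotation.py` maps the 80 COCO category ids, which run from 1 to 90 with gaps, onto contiguous indices 0 to 79. The same mapping can be built from a lookup table, as in the sketch below. The names `VALID_COCO_IDS` and `COCO_TO_CONTIGUOUS` are hypothetical; the id ranges are taken from the conditions in the script.

```python
# Category ids covered by the branches in coco_annotation.py, in ascending order.
VALID_COCO_IDS = (
    list(range(1, 12))      # 1..11
    + list(range(13, 26))   # 13..25
    + [27, 28]
    + list(range(31, 45))   # 31..44
    + list(range(46, 66))   # 46..65
    + [67, 70]
    + list(range(72, 83))   # 72..82
    + list(range(84, 91))   # 84..90
)

# Position in the sorted list is the contiguous class index.
COCO_TO_CONTIGUOUS = {coco_id: index for index, coco_id in enumerate(VALID_COCO_IDS)}


def remap_category(coco_id):
    """Return the contiguous class index, or raise KeyError for an id outside the table."""
    return COCO_TO_CONTIGUOUS[coco_id]
```

Unlike the chain, which passes an unlisted id through unchanged, the table lookup raises `KeyError`, so an unexpected category surfaces immediately.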
--------------------------------------------------------------------------------
/utils_coco/coco_to_txt.py:
--------------------------------------------------------------------------------
1 | """
2 | coco2txt
3 | """
4 |
5 | import json
6 | import os
7 | from collections import defaultdict
8 |
9 | train_datasets_path = "/home/public/DMIST-60/"
10 | val_datasets_path = "/home/public/DMIST-60/"
11 |
12 | train_annotation_path = "/home/public/DMIST-60/2_coco_train.json"
13 | val_annotation_path = "/home/public/DMIST-60/60_coco_val.json"
14 |
15 | train_output_path = "DMIST_train.txt"
16 | val_output_path = "DMIST_60_val.txt"
17 |
18 | def get_path(images, id):
19 | for image in images:
20 | if id == image["id"]:
21 | return image['file_name']
22 |
23 | if __name__ == "__main__":
24 | name_box_id = defaultdict(list)
25 | id_name = dict()
26 | f = open(train_annotation_path, encoding='utf-8')
27 | data = json.load(f)
28 |
29 | images = data['images']
30 | annotations = data['annotations']
31 | for ant in annotations:
32 | id = ant['image_id']
33 | name = os.path.join(train_datasets_path, get_path(images, id))
34 | cat = ant['category_id'] - 1
35 | name_box_id[name].append([ant['bbox'], cat])
36 |
37 | f = open(train_output_path, 'w')
38 | for key in name_box_id.keys():
39 | f.write(key)
40 | box_infos = name_box_id[key]
41 | for info in box_infos:
42 | x_min = int(info[0][0])
43 | y_min = int(info[0][1])
44 | x_max = x_min + int(info[0][2])
45 | y_max = y_min + int(info[0][3])
46 |
47 | box_info = " %d,%d,%d,%d,%d" % (
48 | x_min, y_min, x_max, y_max, int(info[1]))
49 | f.write(box_info)
50 | f.write('\n')
51 | f.close()
52 |
53 | name_box_id = defaultdict(list)
54 | id_name = dict()
55 | f = open(val_annotation_path, encoding='utf-8')
56 | data = json.load(f)
57 |
58 | images = data['images']
59 | annotations = data['annotations']
60 | for ant in annotations:
61 | id = ant['image_id']
62 | name = os.path.join(val_datasets_path, get_path(images, id))
63 | cat = ant['category_id']
64 | cat = cat - 1
65 | name_box_id[name].append([ant['bbox'], cat])
66 |
67 | f = open(val_output_path, 'w')
68 | for key in name_box_id.keys():
69 | f.write(key)
70 | box_infos = name_box_id[key]
71 | for info in box_infos:
72 | x_min = int(info[0][0])
73 | y_min = int(info[0][1])
74 | x_max = x_min + int(info[0][2])
75 | y_max = y_min + int(info[0][3])
76 |
77 | box_info = " %d,%d,%d,%d,%d" % (
78 | x_min, y_min, x_max, y_max, int(info[1]))
79 | f.write(box_info)
80 | f.write('\n')
81 | f.close()
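`get_path` scans the whole `images` list for every annotation, which is quadratic for large annotation files. One possible alternative is to build an id-to-file-name dictionary once and share the conversion between the train and val splits, as in this sketch. The function name `coco_to_lines` and the sample data are hypothetical; the output format, the `[x, y, w, h]` to corner conversion, and the `category_id - 1` shift follow the script above.

```python
import os
from collections import defaultdict


def coco_to_lines(coco, image_root):
    """Convert a loaded COCO dict into 'path x1,y1,x2,y2,cls ...' lines, one per image."""
    # One pass over the images replaces the linear search per annotation.
    id_to_file = {image["id"]: image["file_name"] for image in coco["images"]}

    boxes_by_image = defaultdict(list)
    for ann in coco["annotations"]:
        x, y, w, h = ann["bbox"]                 # COCO boxes are [x, y, width, height]
        x_min, y_min = int(x), int(y)
        x_max, y_max = x_min + int(w), y_min + int(h)
        boxes_by_image[ann["image_id"]].append(
            (x_min, y_min, x_max, y_max, ann["category_id"] - 1)
        )

    lines = []
    for image_id, boxes in boxes_by_image.items():
        path = os.path.join(image_root, id_to_file[image_id])
        box_text = "".join(" %d,%d,%d,%d,%d" % box for box in boxes)
        lines.append(path + box_text)
    return lines


# Made-up sample data for illustration only.
sample = {
    "images": [{"id": 7, "file_name": "seq/1.bmp"}],
    "annotations": [
        {"image_id": 7, "bbox": [10, 20, 5, 6], "category_id": 1},
        {"image_id": 7, "bbox": [1.5, 2.5, 3, 4], "category_id": 1},
    ],
}
lines = coco_to_lines(sample, "/data/")
```

Writing the result is then a single `with open(output_path, "w") as f: f.write("\n".join(lines) + "\n")`, which also closes the file handle reliably.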
--------------------------------------------------------------------------------