├── EigenTheda.py ├── README.md ├── cfg ├── deploy │ └── dog.jpg └── training │ ├── RotaYolo.yaml │ ├── RotaYolo_RotaConv.yaml │ └── yolov7-w6.yaml ├── data ├── dota.yaml └── hyp.scratch.dota.yaml ├── detect.py ├── export.py ├── figure ├── Experiments_Comparison.png ├── LGBB.png ├── OBB_Comparison.png ├── OBB_Representation.png ├── Orientation_Sensitive_Feature_Extraction.png ├── Overview.png ├── RRC.png ├── RRC_Comparison.png ├── Vis_Comparison.png └── Vis_RRC.png ├── hubconf.py ├── models ├── __init__.py ├── common.py ├── experimental.py └── yolo.py ├── test.py ├── train.py └── utils ├── __init__.py ├── activations.py ├── add_nms.py ├── autoanchor.py ├── aws ├── __init__.py ├── mime.sh ├── resume.py └── userdata.sh ├── datasets.py ├── general.py ├── google_utils.py ├── loss.py ├── metrics.py ├── plots.py ├── torch_utils.py └── wandb_logging ├── __init__.py ├── log_dataset.py └── wandb_utils.py /EigenTheda.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | from copy import deepcopy 5 | from math import cos, sin, pi, atan2 6 | import torch 7 | # x y w h theda → EigenVector、EigenValue → Matrix: 2 × 3 8 | # x y w h: 0~1, theda: 0~180 9 | def xywhTheda2Eigen(input): # [:, xywhtheda] [a1, a2, cx; a3, a4, cy] 10 | 11 | input[..., 4] = input[..., 4] / 180 * pi 12 | # 2D 13 | if len(input.shape) == 2: 14 | output = np.zeros((input.shape[0], 2, 3)) 15 | output_vector = np.zeros((input.shape[0], 6)) # [cx, cy, a1, a2, a3, a4] 16 | output_vector_5 = np.zeros((input.shape[0], 5)) # [cx, cy, z1, z3, z2] 17 | 18 | EigenMatrix = np.zeros((input.shape[0], 2, 3)) 19 | EigenMatrix[:, 0, 2] = input[:, 0] # cx 20 | EigenMatrix[:, 1, 2] = input[:, 1] # cy 21 | 22 | EigenValueMatrix = np.zeros((input.shape[0], 2, 2)) 23 | EigenValueMatrix[:, 0, 0] = input[:, 2] # w 24 | EigenValueMatrix[:, 1, 1] = input[:, 3] # h 25 | 26 | theda = input[:, 4] 27 | EigenVectorMatrix = np.zeros((input.shape[0], 2, 2)) 28 | EigenVectorMatrix[:, 0, 0] = np.cos(theda) 29 | EigenVectorMatrix[:, 1, 0] = -np.sin(theda) 30 | EigenVectorMatrix[:, 0, 1] = -np.sin(theda) 31 | EigenVectorMatrix[:, 1, 1] = -np.cos(theda) 32 | 33 | # EigenMatrix[:2, :2] = np.dot(np.dot(EigenVectorMatrix, EigenValueMatrix), EigenVectorMatrix.T) 34 | E_sin_temp = np.einsum('ijk,ikr->ijr', EigenVectorMatrix, EigenValueMatrix) 35 | EigenMatrix[:, :2, :2] = np.einsum('ijk,ikr->ijr', E_sin_temp, EigenVectorMatrix.transpose(0, 2, 1)) 36 | output[:, :, :] = EigenMatrix[:, :, :] 37 | 38 | output_vector[:, 0] = output[:, 0, 2] # cx 39 | output_vector[:, 1] = output[:, 1, 2] # cy 40 | output_vector[:, 2] = output[:, 0, 0] # a1 41 | output_vector[:, 3] = output[:, 0, 1] # a2 42 | output_vector[:, 4] = output[:, 1, 0] # a3 43 | output_vector[:, 5] = output[:, 1, 1] # a4 44 | 45 | output_vector_5[:, 0] = output[:, 0, 2] # cx 46 | output_vector_5[:, 1] = output[:, 1, 2] # cy 47 | output_vector_5[:, 2] = output[:, 0, 0] # z1 0~1 48 | output_vector_5[:, 3] = output[:, 1, 1] # z3 0~1 49 | output_vector_5[:, 4] = output[:, 0, 1] # z2 -1~1 50 | 51 | elif len(input.shape) == 3: 52 | output_vector_5_3 = np.zeros((input.shape[0], input.shape[1], 5)) 53 | for index in range(input.shape[0]): 54 | output = np.zeros((input.shape[1], 2, 3)) 55 | output_vector = np.zeros((input.shape[1], 6)) # [cx, cy, a1, a2, a3, a4] 56 | output_vector_5 = np.zeros((input.shape[1], 5)) # [cx, cy, z1, z3, z2] 57 | 58 | EigenMatrix = np.zeros((input.shape[1], 2, 3)) 59 | EigenMatrix[:, 0, 2] = input[index, :, 0] # cx 60 | 
EigenMatrix[:, 1, 2] = input[index, :, 1] # cy 61 | 62 | EigenValueMatrix = np.zeros((input.shape[1], 2, 2)) 63 | EigenValueMatrix[:, 0, 0] = input[index, :, 2] # w 64 | EigenValueMatrix[:, 1, 1] = input[index, :, 3] # h 65 | 66 | theda = input[index, :, 4] 67 | EigenVectorMatrix = np.zeros((input.shape[1], 2, 2)) 68 | EigenVectorMatrix[:, 0, 0] = np.cos(theda) 69 | EigenVectorMatrix[:, 1, 0] = -np.sin(theda) 70 | EigenVectorMatrix[:, 0, 1] = -np.sin(theda) 71 | EigenVectorMatrix[:, 1, 1] = -np.cos(theda) 72 | 73 | # EigenMatrix[:2, :2] = np.dot(np.dot(EigenVectorMatrix, EigenValueMatrix), EigenVectorMatrix.T) 74 | E_sin_temp = np.einsum('ijk,ikr->ijr', EigenVectorMatrix, EigenValueMatrix) 75 | EigenMatrix[:, :2, :2] = np.einsum('ijk,ikr->ijr', E_sin_temp, EigenVectorMatrix.transpose(0, 2, 1)) 76 | output[:, :, :] = EigenMatrix[:, :, :] 77 | 78 | output_vector[:, 0] = output[:, 0, 2] # cx 79 | output_vector[:, 1] = output[:, 1, 2] # cy 80 | output_vector[:, 2] = output[:, 0, 0] # a1 81 | output_vector[:, 3] = output[:, 0, 1] # a2 82 | output_vector[:, 4] = output[:, 1, 0] # a3 83 | output_vector[:, 5] = output[:, 1, 1] # a4 84 | 85 | output_vector_5[:, 0] = output[:, 0, 2] # cx 86 | output_vector_5[:, 1] = output[:, 1, 2] # cy 87 | output_vector_5[:, 2] = output[:, 0, 0] # z1 0~1 88 | output_vector_5[:, 3] = output[:, 1, 1] # z3 0~1 89 | output_vector_5[:, 4] = output[:, 0, 1] # z2 -1~1 90 | 91 | output_vector_5_3[index, :, :] = output_vector_5[:, :] 92 | output_vector_5 = output_vector_5_3 93 | 94 | return output_vector_5 95 | 96 | def xywhTheda2Eigen_numpy(input): # x y w h: 0~1, theda: 0~180 to x, y, z1, z3, z2 97 | cx = copy.deepcopy(input[..., 0]) 98 | cy = copy.deepcopy(input[..., 1]) 99 | w = copy.deepcopy(input[..., 2]) 100 | h = copy.deepcopy(input[..., 3]) 101 | angle = copy.deepcopy(input[..., 4]) 102 | 103 | angle = angle / 180 * pi 104 | sin_ = np.sin(angle) 105 | cos_ = np.cos(angle) 106 | 107 | z1 = w * (cos_**2) + h * (sin_**2) 108 | z3 = w * (sin_**2) + h * (cos_**2) 109 | z2 = (h - w) * sin_ * cos_ 110 | 111 | output = np.concatenate([cx.reshape(-1, 1), cy.reshape(-1, 1), z1.reshape(-1, 1), z3.reshape(-1, 1), z2.reshape(-1, 1)], axis=1) 112 | 113 | return output 114 | 115 | # EigenVector、EigenValue → Matrix: 2 × 3 →x y w h theda 116 | # x y w h theda: 0~180 117 | def Eigen2xywhTheda(input_vector): # [:, xywhtheda] [a1, a2, cx; a3, a4, cy] 118 | 119 | # 2D 120 | if len(input_vector.shape) == 2: 121 | output = np.zeros((input_vector.shape[0], 5)) 122 | input = np.zeros((input_vector.shape[0], 2, 3)) # Vector to Matrix 123 | input[:, 0, 2] = input_vector[:, 0] # cx 124 | input[:, 1, 2] = input_vector[:, 1] # cy 125 | input[:, 0, 0] = input_vector[:, 2] # a1 126 | input[:, 0, 1] = input_vector[:, 4] # a2 127 | input[:, 1, 0] = input_vector[:, 4] # a3 128 | input[:, 1, 1] = input_vector[:, 3] # a4 129 | 130 | output[:, 0] = input[:, 0, 2] # cx 131 | output[:, 1] = input[:, 1, 2] # cy 132 | 133 | values, vectors = np.linalg.eig(input[:, :, :2]) 134 | a = values.argmax(axis=1) 135 | output[:, 2] = values[np.arange(a.size), a] 136 | 137 | b = 1 - a 138 | output[:, 3] = values[np.arange(b.size), b] 139 | theda_vector = vectors[np.arange(a.size), :, a] 140 | theda = np.arctan2(theda_vector[:, 1], theda_vector[:, 0]) + pi # 0 ~ 2*pi 141 | theda[theda >= pi] -= pi 142 | theda[theda == pi] -= pi # [0, 2*pi] to [0, pi) 143 | 144 | output_theda = pi - theda 145 | output_theda[output_theda == pi] -= pi # (0, pi] to [0, pi) 146 | output[:, 4] = output_theda 147 | 148 | elif 
len(input_vector.shape) == 3: 149 | output_3 = np.zeros((input_vector.shape[0], input_vector.shape[1], 5)) 150 | for index in range(input_vector.shape[0]): 151 | output = np.zeros((input_vector.shape[1], 5)) 152 | input = np.zeros((input_vector.shape[1], 2, 3)) # Vector to Matrix 153 | input[:, 0, 2] = input_vector[index, :, 0] # cx 154 | input[:, 1, 2] = input_vector[index, :, 1] # cy 155 | input[:, 0, 0] = input_vector[index, :, 2] # a1 156 | input[:, 0, 1] = input_vector[index, :, 4] # a2 157 | input[:, 1, 0] = input_vector[index, :, 4] # a3 158 | input[:, 1, 1] = input_vector[index, :, 3] # a4 159 | 160 | output[:, 0] = input[:, 0, 2] # cx 161 | output[:, 1] = input[:, 1, 2] # cy 162 | 163 | values, vectors = np.linalg.eig(input[:, :, :2]) # np.where 164 | a = values.argmax(axis=1) 165 | output[:, 2] = values[np.arange(a.size), a] 166 | 167 | b = 1 - a 168 | output[:, 3] = values[np.arange(b.size), b] 169 | 170 | theda_vector = vectors[np.arange(a.size), :, a] 171 | theda = np.arctan2(theda_vector[:, 1], theda_vector[:, 0]) + pi # 0 ~ 2*pi 172 | theda[theda >= pi] -= pi 173 | theda[theda == pi] -= pi # [0, 2*pi] to [0, pi) 174 | 175 | output_theda = pi - theda 176 | output_theda[output_theda == pi] -= pi # (0, pi] to [0, pi) 177 | output[:, 4] = output_theda 178 | 179 | output_3[index, :, :] = output 180 | output = output_3 181 | 182 | output[..., 4] = output[..., 4] / pi * 180 183 | return output 184 | 185 | # x y w h theda: 0~180 186 | def Eigen2xywhTheda_numpy(input_vector, h_thred=0, angle_thred=0.95): # input_vector: (n, 2) (cx, cy, z1, z3, z2) to (cx, cy, w, h, theda) 187 | cx = copy.deepcopy(input_vector[..., 0]) 188 | cy = copy.deepcopy(input_vector[..., 1]) 189 | z1 = copy.deepcopy(input_vector[..., 2]) 190 | z2 = copy.deepcopy(input_vector[..., 4]) 191 | z3 = copy.deepcopy(input_vector[..., 3]) 192 | 193 | w = 0.5 * (z1 + z3 + ((z1 - z3) ** 2 + 4 * (z2 ** 2)) ** 0.5) 194 | h = 0.5 * (z1 + z3 - ((z1 - z3) ** 2 + 4 * (z2 ** 2)) ** 0.5) 195 | 196 | 'h >= 0 means positive definiteness' 197 | non_positive_define = (h <= h_thred) 198 | 199 | sin_2theta = -2 * z2 / (w - h) 200 | cos_2theta = (z1 - z3) / (w - h) 201 | 202 | sin_2theta = np.clip(sin_2theta, -1, 1) 203 | cos_2theta = np.clip(cos_2theta, -1, 1) 204 | 205 | arcsin_2_theta = np.arcsin(sin_2theta) # [-pi/2, pi/2] 206 | arccos_2_theta = np.arccos(cos_2theta) # [0, pi] 207 | 208 | _2_theta = np.zeros_like(arcsin_2_theta) 209 | 210 | arcsin_2_theta_1 = (arcsin_2_theta >= 0) * (arcsin_2_theta <= pi / 2) * (arccos_2_theta >= 0) * (arccos_2_theta <= pi / 2) 211 | _2_theta[arcsin_2_theta_1] = arcsin_2_theta[arcsin_2_theta_1] 212 | 213 | arcsin_2_theta_2 = (arcsin_2_theta >= 0) * (arcsin_2_theta <= pi / 2) * (arccos_2_theta > pi / 2) * (arccos_2_theta <= pi) 214 | _2_theta[arcsin_2_theta_2] = pi - arcsin_2_theta[arcsin_2_theta_2] 215 | 216 | arcsin_2_theta_3 = (arcsin_2_theta >= -pi / 2) * (arcsin_2_theta < 0) * (arccos_2_theta > pi / 2) * (arccos_2_theta <= pi) 217 | _2_theta[arcsin_2_theta_3] = pi - arcsin_2_theta[arcsin_2_theta_3] 218 | 219 | arcsin_2_theta_4 = (arcsin_2_theta >= -pi / 2) * (arcsin_2_theta < 0) * (arccos_2_theta >= 0) * (arccos_2_theta <= pi / 2) 220 | _2_theta[arcsin_2_theta_4] = 2 * pi + arcsin_2_theta[arcsin_2_theta_4] 221 | 222 | theta = _2_theta / 2 # [0, pi) 223 | theta = theta / pi * 180 # [0, 180) 224 | theta = np.clip(theta, 1, 179) 225 | 226 | cx, cy, w, h, theta = cx.reshape(-1, 1), cy.reshape(-1, 1), w.reshape(-1, 1), h.reshape(-1, 1), theta.reshape(-1, 1) 227 | output = np.concatenate([cx, cy, w, 
h, theta], axis=1) 228 | 229 | return output, non_positive_define 230 | 231 | def arcsin_taylor(xi, x0): 232 | x = torch.arcsin(x0) + ((1 - x0**2)**(-0.5)) * (xi - x0) + x0 * ((1 - x0**2)**(-1.5)) * (xi - x0)**2 233 | 234 | return x 235 | 236 | def arccos_taylor(xi, x0): 237 | x = torch.arccos(x0) - ((1 - x0**2)**(-0.5)) * (xi - x0) - x0 * ((1 - x0**2)**(-1.5)) * (xi - x0)**2 238 | 239 | return x 240 | 241 | # x y w h theda: 0~180 242 | def Eigen2xywhTheda_numpy_tensor(input_vector, h_thred=0, angle_thred=1.0): # input_vector: (n, 2) (cx, cy, z1, z3, z2) to (cx, cy, w, h, theda) 243 | cx = input_vector[..., 0] 244 | cy = input_vector[..., 1] 245 | z1 = input_vector[..., 2] 246 | z2 = input_vector[..., 4] 247 | z3 = input_vector[..., 3] 248 | 249 | w = 0.5 * (z1 + z3 + ((z1 - z3) ** 2 + 4 * (z2 ** 2)) ** 0.5) 250 | h = 0.5 * (z1 + z3 - ((z1 - z3) ** 2 + 4 * (z2 ** 2)) ** 0.5) 251 | 252 | 'h >= 0 means positive definiteness' 253 | non_positive_define = (h <= h_thred) 254 | w[(w / h < 1.05) & (h > 0)] *= 1.05 255 | 256 | sin_2theta = -2 * z2 / (w - h) 257 | cos_2theta = (z1 - z3) / (w - h) 258 | 259 | non_positive_define |= (sin_2theta > angle_thred) | (sin_2theta < -angle_thred) | (cos_2theta > angle_thred) | (cos_2theta < -angle_thred) 260 | sin_2theta = torch.clip(sin_2theta, -angle_thred, angle_thred) 261 | cos_2theta = torch.clip(cos_2theta, -angle_thred, angle_thred) 262 | 263 | arcsin_2_theta = torch.arcsin(sin_2theta) # [-pi/2, pi/2] 264 | arccos_2_theta = torch.arccos(cos_2theta) # [0, pi] 265 | 266 | _2_theta = torch.zeros_like(arcsin_2_theta) 267 | 268 | arcsin_2_theta_1 = (arcsin_2_theta >= 0) * (arcsin_2_theta <= pi / 2) * (arccos_2_theta >= 0) * (arccos_2_theta <= pi / 2) 269 | _2_theta[arcsin_2_theta_1] = arcsin_2_theta[arcsin_2_theta_1] 270 | 271 | arcsin_2_theta_2 = (arcsin_2_theta >= 0) * (arcsin_2_theta <= pi / 2) * (arccos_2_theta > pi / 2) * (arccos_2_theta <= pi) 272 | _2_theta[arcsin_2_theta_2] = pi - arcsin_2_theta[arcsin_2_theta_2] 273 | 274 | arcsin_2_theta_3 = (arcsin_2_theta >= -pi / 2) * (arcsin_2_theta < 0) * (arccos_2_theta > pi / 2) * (arccos_2_theta <= pi) 275 | _2_theta[arcsin_2_theta_3] = pi - arcsin_2_theta[arcsin_2_theta_3] 276 | 277 | arcsin_2_theta_4 = (arcsin_2_theta >= -pi / 2) * (arcsin_2_theta < 0) * (arccos_2_theta >= 0) * (arccos_2_theta <= pi / 2) 278 | _2_theta[arcsin_2_theta_4] = 2 * pi + arcsin_2_theta[arcsin_2_theta_4] 279 | 280 | theta = _2_theta / 2 # [0, pi) 281 | theta = theta / pi * 180 # [0, 180) 282 | theta = torch.clip(theta, 1, 179) 283 | 284 | # cx, cy, w, h, theta = cx.view(-1, 1), cy.view(-1, 1), w.view(-1, 1), h.view(-1, 1), theta.view(-1, 1) 285 | cx, cy, w, h, theta = cx.unsqueeze(-1), cy.unsqueeze(-1), w.unsqueeze(-1), h.unsqueeze(-1), theta.unsqueeze(-1) 286 | output = torch.cat([cx, cy, w, h, theta], dim=-1) 287 | 288 | # output = output.to(dtype=original_type) 289 | 290 | return output, non_positive_define 291 | 292 | def Eigen2xywhTheda_gpu(input_vector): # [:, xywhtheda] [a1, a2, cx; a3, a4, cy] 293 | 294 | # 2D 295 | if len(input_vector.shape) == 2: 296 | output = torch.zeros((input_vector.shape[0], 5)) 297 | input = torch.zeros((input_vector.shape[0], 2, 3)) # Vector to Matrix 298 | input[:, 0, 2] = input_vector[:, 0] # cx 299 | input[:, 1, 2] = input_vector[:, 1] # cy 300 | input[:, 0, 0] = input_vector[:, 2] # a1 301 | input[:, 0, 1] = input_vector[:, 4] # a2 302 | input[:, 1, 0] = input_vector[:, 4] # a3 303 | input[:, 1, 1] = input_vector[:, 3] # a4 304 | 305 | output[:, 0] = input[:, 0, 2] # cx 306 | output[:, 1] = 
input[:, 1, 2] # cy 307 | 308 | values, vectors = torch.linalg.eigh(input[:, :, :2]) # np.where 309 | a = values.argmax(axis=1) 310 | output[:, 2] = values[torch.arange(len(a)), a] 311 | b = 1 - a 312 | output[:, 3] = values[torch.arange(len(b)), b] 313 | 314 | theda_vector = vectors[torch.arange(len(a)), :, a] 315 | theda = torch.arctan2(theda_vector[:, 1], theda_vector[:, 0]) + pi # 0 ~ 2*pi 316 | theda[theda >= pi] -= pi 317 | theda[theda == pi] -= pi # [0, 2*pi] to [0, pi) 318 | 319 | output_theda = pi - theda 320 | output_theda[output_theda == pi] -= pi # (0, pi] to [0, pi) 321 | output[:, 4] = output_theda 322 | 323 | elif len(input_vector.shape) == 3: 324 | output_3 = torch.zeros((input_vector.shape[0], input_vector.shape[1], 5)) 325 | for index in range(input_vector.shape[0]): 326 | output = torch.zeros((input_vector.shape[1], 5)) 327 | input = torch.zeros((input_vector.shape[1], 2, 3)) # Vector to Matrix 328 | input[:, 0, 2] = input_vector[index, :, 0] # cx 329 | input[:, 1, 2] = input_vector[index, :, 1] # cy 330 | input[:, 0, 0] = input_vector[index, :, 2] # a1 331 | input[:, 0, 1] = input_vector[index, :, 4] # a2 332 | input[:, 1, 0] = input_vector[index, :, 4] # a3 333 | input[:, 1, 1] = input_vector[index, :, 3] # a4 334 | 335 | output[:, 0] = input[:, 0, 2] # cx 336 | output[:, 1] = input[:, 1, 2] # cy 337 | 338 | values, vectors = torch.linalg.eigh(input[:, :, :2]) # np.where 339 | a = values.argmax(axis=1) 340 | output[:, 2] = values[torch.arange(a.size), a] 341 | 342 | b = 1 - a 343 | output[:, 3] = values[torch.arange(b.size), b] 344 | 345 | theda_vector = vectors[torch.arange(a.size), :, a] 346 | theda = torch.arctan2(theda_vector[:, 1], theda_vector[:, 0]) + pi # 0 ~ 2*pi 347 | theda[theda >= pi] -= pi 348 | theda[theda == pi] -= pi # [0, 2*pi] to [0, pi) 349 | 350 | output_theda = pi - theda 351 | output_theda[output_theda == pi] -= pi # (0, pi] to [0, pi) 352 | output[:, 4] = output_theda 353 | 354 | output_3[index, :, :] = output 355 | output = output_3 356 | 357 | output[..., 4] = output[..., 4] / pi * 180 358 | return output 359 | 360 | def Eigen2x(input): # clsid, cx, cy, z1, z3, z2 to clsid, cx, cy, x1, x2, x3 361 | output = deepcopy(input) 362 | output[..., -3] = (input[..., -3] + input[..., -2]) / 2 363 | output[..., -2] = input[..., -3] 364 | output[..., -1] = (input[..., -3] + input[..., -2] + input[..., -1] * 2) / 2 365 | 366 | return output 367 | 368 | def x2Eigen(input): # clsid, cx, cy, x1, x2, x3 to clsid, cx, cy, z1, z3, z2 369 | if torch.is_tensor(input): 370 | output = torch.zeros_like(input).to(input.device) 371 | else: 372 | output = np.zeros_like(input) 373 | output[..., :-3] = input[..., :-3] 374 | 375 | output[..., -3] = input[..., -2] 376 | output[..., -2] = input[..., -3] * 2 - input[..., -2] 377 | output[..., -1] = input[..., -1] - input[..., -3] 378 | 379 | return output 380 | 381 | if __name__ == '__main__': 382 | input = np.array([[[0.3, 0.4, 0.8, 0.5, 179 ], 383 | [0.1, 0.2, 0.41, 0.4, 0 ], 384 | [0.3, 0.4, 0.7, 0.6, 90]], 385 | [[0.3, 0.4, 0.8, 0.5, 134], 386 | [0.1, 0.2, 0.41, 0.4, 50], 387 | [0.3, 0.4, 0.7, 0.6, 70]]] 388 | ) 389 | 390 | output = xywhTheda2Eigen(input) 391 | print("input: ", input.shape, input) 392 | print("output: ", output.shape, output) 393 | 394 | output2 = Eigen2xywhTheda(output) 395 | print("output: ", output.shape, output) 396 | print("output2: ", output2.shape, output2) 397 | 398 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | ## Linear Gaussian Bounding Box Representation and Ring-Shaped Rotated Convolution for Oriented Object Detection 2 |

3 | In oriented object detection, current representations of oriented bounding boxes (OBBs) often suffer from the boundary discontinuity problem. Methods of designing continuous regression losses do not essentially solve this problem. Although Gaussian bounding box (GBB) representation avoids this problem, directly regressing GBB is susceptible to numerical instability. We propose linear GBB (LGBB), a novel OBB representation. By linearly transforming the elements of GBB, LGBB avoids the boundary discontinuity problem and has high numerical stability. In addition, existing convolution-based rotation-sensitive feature extraction methods only have local receptive fields, resulting in slow feature aggregation. We propose ring-shaped rotated convolution (RRC), which adaptively rotates feature maps to arbitrary orientations to extract rotation-sensitive features under a ring-shaped receptive field, rapidly aggregating features and contextual information. Experimental results demonstrate that LGBB and RRC achieve state-of-the-art performance. Furthermore, integrating LGBB and RRC into various models effectively improves detection accuracy. 4 |
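As an illustrative sanity check of the boundary-continuity claim (computed from the closed form in `EigenTheda.py`, not taken from the paper): a box with w = 0.8, h = 0.5 at θ = 0.5° maps to (z1, z3, z2) ≈ (0.800, 0.500, −0.0026), while the same box at θ = 179.5° maps to ≈ (0.800, 0.500, +0.0026). The two LGBB vectors nearly coincide even though the raw angles sit at opposite ends of the [0°, 180°) range.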

5 | 6 | ## Related Work 7 | 1. Comparison with Current OBB Representations 8 |
9 | 10 |
11 | 12 | 2. Comparison with Current Convolution-Based Rotation-Sensitive Feature Extraction Methods 13 |
14 | 15 |
16 | 17 | ## Methods 18 | 1. Overview 19 |
20 | 21 |
22 | 23 | 2. LGBB 24 |
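(The LGBB figure is not embedded in this text dump. For reference, the closed form implemented in `EigenTheda.py` — using that code's sign convention — is)

$$
z_1 = w\cos^2\theta + h\sin^2\theta,\qquad
z_3 = w\sin^2\theta + h\cos^2\theta,\qquad
z_2 = (h - w)\sin\theta\cos\theta,
$$

so an oriented box is stored as $(c_x, c_y, z_1, z_3, z_2)$: the centre plus the entries of a symmetric $2\times 2$ matrix whose eigenvalues are $w$ and $h$ and whose eigenvectors encode $\theta$. Recovering $(w, h, \theta)$ is the $2\times 2$ eigendecomposition performed by `Eigen2xywhTheda`.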
25 | 26 |
27 | 28 | 3. RRC 29 |
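(The RRC figure is likewise not embedded here, and the actual implementation is the `RotaConv` module in `models/common.py`, used throughout `cfg/training/RotaYolo_RotaConv.yaml`. The sketch below shows only the basic primitive such a module builds on — resampling a feature map under a rotation via `affine_grid`/`grid_sample`, as described in the abstract — and is not the repository's RRC implementation.)

```python
import math
import torch
import torch.nn.functional as F

def rotate_feature_map(x, theta):
    """Resample an (N, C, H, W) feature map under a rotation of `theta` radians
    (one angle per sample) about the map centre. Illustrative primitive only;
    the repository's RotaConv block is defined in models/common.py."""
    cos, sin = torch.cos(theta), torch.sin(theta)
    zeros = torch.zeros_like(cos)
    # per-sample 2x3 affine matrices for affine_grid
    mat = torch.stack([
        torch.stack([cos, -sin, zeros], dim=-1),
        torch.stack([sin,  cos, zeros], dim=-1),
    ], dim=-2)                                      # (N, 2, 3)
    grid = F.affine_grid(mat, list(x.shape), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)

feat = torch.randn(2, 16, 64, 64)
angles = torch.tensor([0.0, math.pi / 4])           # e.g. 0 deg and 45 deg
rotated = rotate_feature_map(feat, angles)          # same shape as feat
```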
30 | 31 |
32 | 33 | ## Experiments 34 | 1. Comparison with Current OBB Representations 35 |
36 | 37 |
38 | 39 | 2. Comparison with Current Convolution-Based Rotation-Sensitive Feature Extraction Methods 40 |
41 | 42 | 43 | 44 |
45 | 46 | 3. Comparison with Current Oriented Object Detectors 47 |
48 | 49 | 50 | 51 |
52 | 53 | ## Installation 54 | Refer to both [yolov7](https://github.com/WongKinYiu/yolov7) and [mmrotate](https://github.com/open-mmlab/mmrotate) 55 | 56 | ## Prepare Your Dataset 57 | [DOTA](https://captain-whu.github.io/DOTA/index.html) 58 | [HRSC2016](https://sites.google.com/site/hrsc2016/) 59 | 60 | 61 | ## Training 62 | ``` 63 | # Single GPU training 64 | python train.py --workers 8 --device 0 --batch-size 2 --data data/dota.yaml --img 1024 1024 --cfg cfg/training/RotaYolo_RotaConv.yaml --weights '' --hyp data/hyp.scratch.dota.yaml 65 | 66 | # Multiple GPU training 67 | python -m torch.distributed.launch --nproc_per_node 4 --master_port 9527 train.py --workers 8 --device 0,1,2,3 --sync-bn --batch-size 8 --data data/dota.yaml --img 1024 1024 --cfg cfg/training/RotaYolo_RotaConv.yaml --weights '' --hyp data/hyp.scratch.dota.yaml 68 | ``` 69 | 70 | ## Detecting 71 | ``` 72 | python detect.py --weights 'weights/best.pt' --source 'datasets/DOTA/demo.png' --img-size 1024 --conf-thres 0.5 --iou-thres 0.2 --device 0 73 | ``` 74 | 75 | ## Citation 76 | ``` 77 | @article{ZHOU2024110677, 78 | title = {Linear Gaussian bounding box representation and ring-shaped rotated convolution for oriented object detection}, 79 | journal = {Pattern Recognition}, 80 | volume = {155}, 81 | pages = {110677}, 82 | year = {2024}, 83 | author = {Zhen Zhou and Yunkai Ma and Junfeng Fan and Zhaoyang Liu and Fengshui Jing and Min Tan}, 84 | } 85 | ``` 86 | 87 | ## Acknowledgement 88 | [mmrotate](https://github.com/open-mmlab/mmrotate) 89 | 90 | [yolov7](https://github.com/WongKinYiu/yolov7) 91 | 92 | -------------------------------------------------------------------------------- /cfg/deploy/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/cfg/deploy/dog.jpg -------------------------------------------------------------------------------- /cfg/training/RotaYolo.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 16 # number of classes 3 | depth_multiple: 0.5 # model depth multiple 4 | width_multiple: 0.5 # layer channel multiple 5 | 6 | anchors: 7 | # [16, 32, 64, 128, 256] [2, 5] [0, 45, 90, 135] 8 | - [ 16, 22, 16, 16, 16, 11, 16, 11, 16, 16, 16, 22, 21, 35, 21, 21, 21, 7, 21, 7, 21, 21, 21, 35] # P2/4 9 | - [ 33, 45, 33, 33, 33, 22, 33, 22, 33, 33, 33, 45, 42, 71, 42, 42, 42, 14, 42, 14, 42, 42, 42, 71] # P3/8 10 | - [ 66, 90, 66, 66, 66, 44, 66, 44, 66, 66, 66, 90, 84, 142, 84, 84, 84, 28, 84, 28, 84, 84, 84, 142 ] # P4/16 11 | - [ 132, 180, 132, 132, 132, 88, 132, 88, 132, 132, 132, 180, 168, 284, 168, 168, 168, 56, 168, 56, 168, 168, 168, 284 ] # P5/32 12 | - [ 264, 360, 264, 264, 264, 176, 264, 176, 264, 264, 264, 360, 336, 568, 336, 336, 336, 112, 336, 112, 336, 336, 336, 568 ] # P6/64 13 | 14 | backbone: 15 | # [from, number, module, args] 16 | [[-1, 1, ReOrg, []], # 0 17 | [-1, 1, Conv, [64, 3, 1]], # 1-P1/2 18 | 19 | [-1, 1, Conv, [128, 3, 2]], # 2-P2/4 20 | 21 | [-1, 1, Conv, [64, 1, 1]], 22 | [-2, 1, Conv, [64, 1, 1]], 23 | [-1, 1, Conv, [64, 3, 1]], 24 | [-1, 1, Conv, [64, 3, 1]], 25 | [[-1, -3, -4,], 1, Concat, [1]], 26 | [-1, 1, Conv, [128, 1, 1]], # 9 27 | 28 | [-1, 1, Conv, [256, 3, 2]], # 10-P3/8 29 | 30 | [-1, 1, Conv, [128, 1, 1]], 31 | [-2, 1, Conv, [128, 1, 1]], 32 | [-1, 1, Conv, [128, 3, 1]], 33 | [-1, 1, Conv, [128, 3, 1]], 34 | [[-1, -3, -4,], 1, Concat, [1]], 35 | [-1, 1, Conv, [256, 1, 1]], # 17 36 
| 37 | [-1, 1, Conv, [512, 3, 2]], # 18-P4/16 38 | 39 | [-1, 1, Conv, [256, 1, 1]], 40 | [-2, 1, Conv, [256, 1, 1]], 41 | [-1, 1, Conv, [256, 3, 1]], 42 | [-1, 1, Conv, [256, 3, 1]], 43 | [[-1, -3, -4,], 1, Concat, [1]], 44 | [-1, 1, Conv, [512, 1, 1]], # 25 45 | 46 | [-1, 1, Conv, [768, 3, 2]], # 26-P5/32 47 | 48 | [-1, 1, Conv, [384, 1, 1]], 49 | [-2, 1, Conv, [384, 1, 1]], 50 | [-1, 1, Conv, [384, 3, 1]], 51 | [-1, 1, Conv, [384, 3, 1]], 52 | [[-1, -3, -4,], 1, Concat, [1]], 53 | [-1, 1, Conv, [768, 1, 1]], # 33 54 | 55 | [-1, 1, Conv, [1024, 3, 2]], # 34-P6/64 56 | 57 | [-1, 1, Conv, [512, 1, 1]], 58 | [-2, 1, Conv, [512, 1, 1]], 59 | [-1, 1, Conv, [512, 3, 1]], 60 | [-1, 1, Conv, [512, 3, 1]], 61 | [[-1, -3, -4,], 1, Concat, [1]], 62 | [-1, 1, Conv, [1024, 1, 1]], # 41 63 | ] 64 | 65 | head: 66 | [[-1, 1, SPPCSPC, [512]], # 37 67 | 68 | [-1, 1, Conv, [384, 1, 1]], 69 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 70 | [29, 1, Conv, [384, 1, 1]], # 29 route backbone P5 71 | [[-1, -2], 1, Concat, [1]], 72 | 73 | [-1, 1, Conv, [384, 1, 1]], 74 | [-2, 1, Conv, [384, 1, 1]], 75 | [-1, 1, Conv, [192, 3, 1]], 76 | [-1, 1, Conv, [192, 3, 1]], 77 | [[-1, -2, -3, -4,], 1, Concat, [1]], 78 | [-1, 1, Conv, [384, 1, 1]], # 47 79 | 80 | [-1, 1, Conv, [256, 1, 1]], 81 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 82 | [22, 1, Conv, [256, 1, 1]], #22 route backbone P4 83 | [[-1, -2], 1, Concat, [1]], 84 | 85 | [-1, 1, Conv, [256, 1, 1]], 86 | [-2, 1, Conv, [256, 1, 1]], 87 | [-1, 1, Conv, [128, 3, 1]], 88 | [-1, 1, Conv, [128, 3, 1]], 89 | [[-1, -2, -3, -4], 1, Concat, [1]], 90 | [-1, 1, Conv, [256, 1, 1]], # 57 91 | 92 | [-1, 1, Conv, [192, 1, 1]], 93 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 94 | [15, 1, Conv, [192, 1, 1]], # 15 route backbone P3 95 | [[-1, -2], 1, Concat, [1]], 96 | 97 | [-1, 1, Conv, [192, 1, 1]], 98 | [-2, 1, Conv, [192, 1, 1]], 99 | [-1, 1, Conv, [96, 3, 1]], 100 | [-1, 1, Conv, [96, 3, 1]], 101 | [[-1, -2, -3, -4], 1, Concat, [1]], 102 | [-1, 1, Conv, [192, 1, 1]], # 67 103 | 104 | [ -1, 1, Conv, [ 128, 1, 1 ] ], 105 | [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], 106 | [ 8, 1, Conv, [ 128, 1, 1 ] ], # 8 route backbone P2 107 | [ [ -1, -2 ], 1, Concat, [ 1 ] ], 108 | 109 | [ -1, 1, Conv, [ 128, 1, 1 ] ], 110 | [ -2, 1, Conv, [ 128, 1, 1 ] ], 111 | [ -1, 1, Conv, [ 64, 3, 1 ] ], 112 | [ -1, 1, Conv, [ 64, 3, 1 ] ], 113 | [ [ -1, -2, -3, -4 ], 1, Concat, [ 1 ] ], 114 | [ -1, 1, Conv, [ 128, 1, 1 ] ], # 77 115 | 116 | [-1, 1, Conv, [256, 3, 2]], 117 | [[-1, 67], 1, Concat, [1]], # 67 cat 118 | 119 | [-1, 1, Conv, [256, 1, 1]], 120 | [-2, 1, Conv, [256, 1, 1]], 121 | [-1, 1, Conv, [128, 3, 1]], 122 | [-1, 1, Conv, [128, 3, 1]], 123 | [[-1, -2, -3, -4], 1, Concat, [1]], 124 | [-1, 1, Conv, [256, 1, 1]], # 85 125 | 126 | [-1, 1, Conv, [384, 3, 2]], 127 | [[-1, 57], 1, Concat, [1]], # 57 cat 128 | 129 | [-1, 1, Conv, [384, 1, 1]], 130 | [-2, 1, Conv, [384, 1, 1]], 131 | [-1, 1, Conv, [192, 3, 1]], 132 | [-1, 1, Conv, [192, 3, 1]], 133 | [[-1, -2, -3, -4], 1, Concat, [1]], 134 | [-1, 1, Conv, [384, 1, 1]], # 93 135 | 136 | [-1, 1, Conv, [512, 3, 2]], 137 | [[-1, 47], 1, Concat, [1]], # 47 cat 138 | 139 | [-1, 1, Conv, [512, 1, 1]], 140 | [-2, 1, Conv, [512, 1, 1]], 141 | [-1, 1, Conv, [256, 3, 1]], 142 | [-1, 1, Conv, [256, 3, 1]], 143 | [[-1, -2, -3, -4], 1, Concat, [1]], 144 | [-1, 1, Conv, [512, 1, 1]], # 101 145 | 146 | [ -1, 1, Conv, [ 768, 3, 2 ] ], 147 | [ [ -1, 37 ], 1, Concat, [ 1 ] ], # 37 cat 148 | 149 | [ -1, 1, Conv, [ 768, 1, 1 ] ], 150 | [ -2, 1, Conv, [ 768, 1, 1 
] ], 151 | [ -1, 1, Conv, [ 384, 3, 1 ] ], 152 | [ -1, 1, Conv, [ 384, 3, 1 ] ], 153 | [ [ -1, -2, -3, -4 ], 1, Concat, [ 1 ] ], 154 | [ -1, 1, Conv, [ 768, 1, 1 ] ], # 109 155 | 156 | [77, 1, Conv, [256, 3, 1]], # 110 157 | [85, 1, Conv, [384, 3, 1]], 158 | [93, 1, Conv, [512, 3, 1]], 159 | [101, 1, Conv, [768, 3, 1]], 160 | [109, 1, Conv, [1024, 3, 1]], # 114 161 | 162 | [[110, 111, 112, 113, 114], 1, Detect, [nc, anchors]], # Detect(P2, P3, P4, P5, P6) 163 | ] 164 | -------------------------------------------------------------------------------- /cfg/training/RotaYolo_RotaConv.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 16 # number of classes 3 | depth_multiple: 0.5 # model depth multiple 4 | width_multiple: 0.5 # layer channel multiple 5 | 6 | anchors: 7 | # [16, 32, 64, 128, 256] [2, 5] [0, 45, 90, 135] 8 | - [ 16, 22, 16, 16, 16, 11, 16, 11, 16, 16, 16, 22, 21, 35, 21, 21, 21, 7, 21, 7, 21, 21, 21, 35] # P2/4 9 | - [ 33, 45, 33, 33, 33, 22, 33, 22, 33, 33, 33, 45, 42, 71, 42, 42, 42, 14, 42, 14, 42, 42, 42, 71] # P3/8 10 | - [ 66, 90, 66, 66, 66, 44, 66, 44, 66, 66, 66, 90, 84, 142, 84, 84, 84, 28, 84, 28, 84, 84, 84, 142 ] # P4/16 11 | - [ 132, 180, 132, 132, 132, 88, 132, 88, 132, 132, 132, 180, 168, 284, 168, 168, 168, 56, 168, 56, 168, 168, 168, 284 ] # P5/32 12 | - [ 264, 360, 264, 264, 264, 176, 264, 176, 264, 264, 264, 360, 336, 568, 336, 336, 336, 112, 336, 112, 336, 336, 336, 568 ] # P6/64 13 | 14 | backbone: 15 | # [from, number, module, args] 16 | [[-1, 1, ReOrg, []], # 0 17 | [-1, 1, Conv, [64, 3, 1]], # 1-P1/2 18 | 19 | [-1, 1, Conv, [128, 3, 2]], # 2-P2/4 20 | [ -1, 1, RotaConv, [ 128, 3, 1 ] ], # Angle Convolution P2 21 | 22 | [-1, 1, Conv, [64, 1, 1]], 23 | [-2, 1, Conv, [64, 1, 1]], 24 | [-1, 1, Conv, [64, 3, 1]], 25 | [-1, 1, Conv, [64, 3, 1]], 26 | [[-1, -3, -4,], 1, Concat, [1]], 27 | [-1, 1, Conv, [128, 1, 1]], # 9 28 | 29 | [-1, 1, Conv, [256, 3, 2]], # 10-P3/8 30 | [ -1, 1, RotaConv, [ 256, 3, 1 ] ], # Angle Convolution P3 31 | 32 | [-1, 1, Conv, [128, 1, 1]], 33 | [-2, 1, Conv, [128, 1, 1]], 34 | [-1, 1, Conv, [128, 3, 1]], 35 | [-1, 1, Conv, [128, 3, 1]], 36 | [[-1, -3, -4,], 1, Concat, [1]], 37 | [-1, 1, Conv, [256, 1, 1]], # 17 38 | 39 | [-1, 1, Conv, [512, 3, 2]], # 18-P4/16 40 | [ -1, 1, RotaConv, [ 512, 3, 1 ] ], # Angle Convolution P4 41 | 42 | [-1, 1, Conv, [256, 1, 1]], 43 | [-2, 1, Conv, [256, 1, 1]], 44 | [-1, 1, Conv, [256, 3, 1]], 45 | [-1, 1, Conv, [256, 3, 1]], 46 | [[-1, -3, -4,], 1, Concat, [1]], 47 | [-1, 1, Conv, [512, 1, 1]], # 25 48 | 49 | [-1, 1, Conv, [768, 3, 2]], # 26-P5/32 50 | [ -1, 1, RotaConv, [ 768, 3, 1 ] ], # Angle Convolution P5 51 | 52 | [-1, 1, Conv, [384, 1, 1]], 53 | [-2, 1, Conv, [384, 1, 1]], 54 | [-1, 1, Conv, [384, 3, 1]], 55 | [-1, 1, Conv, [384, 3, 1]], 56 | [[-1, -3, -4,], 1, Concat, [1]], 57 | [-1, 1, Conv, [768, 1, 1]], # 33 58 | 59 | [-1, 1, Conv, [1024, 3, 2]], # 34-P6/64 60 | [ -1, 1, RotaConv, [ 1024, 3, 1 ] ], # Angle Convolution P6 61 | 62 | [-1, 1, Conv, [512, 1, 1]], 63 | [-2, 1, Conv, [512, 1, 1]], 64 | [-1, 1, Conv, [512, 3, 1]], 65 | [-1, 1, Conv, [512, 3, 1]], 66 | [[-1, -3, -4,], 1, Concat, [1]], 67 | [-1, 1, Conv, [1024, 1, 1]], # 41 68 | ] 69 | 70 | head: 71 | [[-1, 1, SPPCSPC, [512]], # 37 72 | 73 | [-1, 1, Conv, [384, 1, 1]], 74 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 75 | [33, 1, Conv, [384, 1, 1]], # 29 route backbone P5 76 | [[-1, -2], 1, Concat, [1]], 77 | 78 | [-1, 1, Conv, [384, 1, 1]], 79 | [-2, 1, Conv, [384, 1, 1]], 
80 | [-1, 1, Conv, [192, 3, 1]], 81 | [-1, 1, Conv, [192, 3, 1]], 82 | [[-1, -2, -3, -4,], 1, Concat, [1]], 83 | [-1, 1, Conv, [384, 1, 1]], # 47 84 | 85 | [-1, 1, Conv, [256, 1, 1]], 86 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 87 | [25, 1, Conv, [256, 1, 1]], #22 route backbone P4 88 | [[-1, -2], 1, Concat, [1]], 89 | 90 | [-1, 1, Conv, [256, 1, 1]], 91 | [-2, 1, Conv, [256, 1, 1]], 92 | [-1, 1, Conv, [128, 3, 1]], 93 | [-1, 1, Conv, [128, 3, 1]], 94 | [[-1, -2, -3, -4], 1, Concat, [1]], 95 | [-1, 1, Conv, [256, 1, 1]], # 57 96 | 97 | [-1, 1, Conv, [192, 1, 1]], 98 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 99 | [17, 1, Conv, [192, 1, 1]], # 15 route backbone P3 100 | [[-1, -2], 1, Concat, [1]], 101 | 102 | [-1, 1, Conv, [192, 1, 1]], 103 | [-2, 1, Conv, [192, 1, 1]], 104 | [-1, 1, Conv, [96, 3, 1]], 105 | [-1, 1, Conv, [96, 3, 1]], 106 | [[-1, -2, -3, -4], 1, Concat, [1]], 107 | [-1, 1, Conv, [192, 1, 1]], # 67 108 | 109 | [ -1, 1, Conv, [ 128, 1, 1 ] ], 110 | [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ], 111 | [ 9, 1, Conv, [ 128, 1, 1 ] ], # 8 route backbone P2 112 | [ [ -1, -2 ], 1, Concat, [ 1 ] ], 113 | 114 | [ -1, 1, Conv, [ 128, 1, 1 ] ], 115 | [ -2, 1, Conv, [ 128, 1, 1 ] ], 116 | [ -1, 1, Conv, [ 64, 3, 1 ] ], 117 | [ -1, 1, Conv, [ 64, 3, 1 ] ], 118 | [ [ -1, -2, -3, -4 ], 1, Concat, [ 1 ] ], 119 | [ -1, 1, Conv, [ 128, 1, 1 ] ], # 77 120 | 121 | [-1, 1, Conv, [256, 3, 2]], 122 | [[-1, 72], 1, Concat, [1]], # 67 cat 123 | 124 | [-1, 1, Conv, [256, 1, 1]], 125 | [-2, 1, Conv, [256, 1, 1]], 126 | [-1, 1, Conv, [128, 3, 1]], 127 | [-1, 1, Conv, [128, 3, 1]], 128 | [[-1, -2, -3, -4], 1, Concat, [1]], 129 | [-1, 1, Conv, [256, 1, 1]], # 85 130 | 131 | [-1, 1, Conv, [384, 3, 2]], 132 | [[-1, 62], 1, Concat, [1]], # 57 cat 133 | 134 | [-1, 1, Conv, [384, 1, 1]], 135 | [-2, 1, Conv, [384, 1, 1]], 136 | [-1, 1, Conv, [192, 3, 1]], 137 | [-1, 1, Conv, [192, 3, 1]], 138 | [[-1, -2, -3, -4], 1, Concat, [1]], 139 | [-1, 1, Conv, [384, 1, 1]], # 93 140 | 141 | [-1, 1, Conv, [512, 3, 2]], 142 | [[-1, 52], 1, Concat, [1]], # 47 cat 143 | 144 | [-1, 1, Conv, [512, 1, 1]], 145 | [-2, 1, Conv, [512, 1, 1]], 146 | [-1, 1, Conv, [256, 3, 1]], 147 | [-1, 1, Conv, [256, 3, 1]], 148 | [[-1, -2, -3, -4], 1, Concat, [1]], 149 | [-1, 1, Conv, [512, 1, 1]], # 101 150 | 151 | [ -1, 1, Conv, [ 768, 3, 2 ] ], 152 | [ [ -1, 42], 1, Concat, [ 1 ] ], # 37 cat 153 | 154 | [ -1, 1, Conv, [ 768, 1, 1 ] ], 155 | [ -2, 1, Conv, [ 768, 1, 1 ] ], 156 | [ -1, 1, Conv, [ 384, 3, 1 ] ], 157 | [ -1, 1, Conv, [ 384, 3, 1 ] ], 158 | [ [ -1, -2, -3, -4 ], 1, Concat, [ 1 ] ], 159 | [ -1, 1, Conv, [ 768, 1, 1 ] ], # 109 160 | 161 | [82, 1, Conv, [256, 3, 1]], # 77 110 162 | [90, 1, Conv, [384, 3, 1]], # 85 163 | [98, 1, Conv, [512, 3, 1]], # 93 164 | [106, 1, Conv, [768, 3, 1]], # 101 165 | [114, 1, Conv, [1024, 3, 1]], # 109 114 166 | 167 | [[115, 116, 117, 118, 119], 1, Detect, [nc, anchors]], # 110, 111, 112, 113, 114 Detect(P2, P3, P4, P5, P6) 168 | ] -------------------------------------------------------------------------------- /cfg/training/yolov7-w6.yaml: -------------------------------------------------------------------------------- 1 | # parameters 2 | nc: 80 # number of classes 3 | depth_multiple: 1.0 # model depth multiple 4 | width_multiple: 1.0 # layer channel multiple 5 | 6 | # anchors 7 | anchors: 8 | - [ 19,27, 44,40, 38,94 ] # P3/8 9 | - [ 96,68, 86,152, 180,137 ] # P4/16 10 | - [ 140,301, 303,264, 238,542 ] # P5/32 11 | - [ 436,615, 739,380, 925,792 ] # P6/64 12 | 13 | # yolov7 backbone 14 | 
backbone: 15 | # [from, number, module, args] 16 | [[-1, 1, ReOrg, []], # 0 17 | [-1, 1, Conv, [64, 3, 1]], # 1-P1/2 18 | 19 | [-1, 1, Conv, [128, 3, 2]], # 2-P2/4 20 | [-1, 1, Conv, [64, 1, 1]], 21 | [-2, 1, Conv, [64, 1, 1]], 22 | [-1, 1, Conv, [64, 3, 1]], 23 | [-1, 1, Conv, [64, 3, 1]], 24 | [-1, 1, Conv, [64, 3, 1]], 25 | [-1, 1, Conv, [64, 3, 1]], 26 | [[-1, -3, -5, -6], 1, Concat, [1]], 27 | [-1, 1, Conv, [128, 1, 1]], # 10 28 | 29 | [-1, 1, Conv, [256, 3, 2]], # 11-P3/8 30 | [-1, 1, Conv, [128, 1, 1]], 31 | [-2, 1, Conv, [128, 1, 1]], 32 | [-1, 1, Conv, [128, 3, 1]], 33 | [-1, 1, Conv, [128, 3, 1]], 34 | [-1, 1, Conv, [128, 3, 1]], 35 | [-1, 1, Conv, [128, 3, 1]], 36 | [[-1, -3, -5, -6], 1, Concat, [1]], 37 | [-1, 1, Conv, [256, 1, 1]], # 19 38 | 39 | [-1, 1, Conv, [512, 3, 2]], # 20-P4/16 40 | [-1, 1, Conv, [256, 1, 1]], 41 | [-2, 1, Conv, [256, 1, 1]], 42 | [-1, 1, Conv, [256, 3, 1]], 43 | [-1, 1, Conv, [256, 3, 1]], 44 | [-1, 1, Conv, [256, 3, 1]], 45 | [-1, 1, Conv, [256, 3, 1]], 46 | [[-1, -3, -5, -6], 1, Concat, [1]], 47 | [-1, 1, Conv, [512, 1, 1]], # 28 48 | 49 | [-1, 1, Conv, [768, 3, 2]], # 29-P5/32 50 | [-1, 1, Conv, [384, 1, 1]], 51 | [-2, 1, Conv, [384, 1, 1]], 52 | [-1, 1, Conv, [384, 3, 1]], 53 | [-1, 1, Conv, [384, 3, 1]], 54 | [-1, 1, Conv, [384, 3, 1]], 55 | [-1, 1, Conv, [384, 3, 1]], 56 | [[-1, -3, -5, -6], 1, Concat, [1]], 57 | [-1, 1, Conv, [768, 1, 1]], # 37 58 | 59 | [-1, 1, Conv, [1024, 3, 2]], # 38-P6/64 60 | [-1, 1, Conv, [512, 1, 1]], 61 | [-2, 1, Conv, [512, 1, 1]], 62 | [-1, 1, Conv, [512, 3, 1]], 63 | [-1, 1, Conv, [512, 3, 1]], 64 | [-1, 1, Conv, [512, 3, 1]], 65 | [-1, 1, Conv, [512, 3, 1]], 66 | [[-1, -3, -5, -6], 1, Concat, [1]], 67 | [-1, 1, Conv, [1024, 1, 1]], # 46 68 | ] 69 | 70 | # yolov7 head 71 | head: 72 | [[-1, 1, SPPCSPC, [512]], # 47 73 | 74 | [-1, 1, Conv, [384, 1, 1]], 75 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 76 | [37, 1, Conv, [384, 1, 1]], # route backbone P5 77 | [[-1, -2], 1, Concat, [1]], 78 | 79 | [-1, 1, Conv, [384, 1, 1]], 80 | [-2, 1, Conv, [384, 1, 1]], 81 | [-1, 1, Conv, [192, 3, 1]], 82 | [-1, 1, Conv, [192, 3, 1]], 83 | [-1, 1, Conv, [192, 3, 1]], 84 | [-1, 1, Conv, [192, 3, 1]], 85 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 86 | [-1, 1, Conv, [384, 1, 1]], # 59 87 | 88 | [-1, 1, Conv, [256, 1, 1]], 89 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 90 | [28, 1, Conv, [256, 1, 1]], # route backbone P4 91 | [[-1, -2], 1, Concat, [1]], 92 | 93 | [-1, 1, Conv, [256, 1, 1]], 94 | [-2, 1, Conv, [256, 1, 1]], 95 | [-1, 1, Conv, [128, 3, 1]], 96 | [-1, 1, Conv, [128, 3, 1]], 97 | [-1, 1, Conv, [128, 3, 1]], 98 | [-1, 1, Conv, [128, 3, 1]], 99 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 100 | [-1, 1, Conv, [256, 1, 1]], # 71 101 | 102 | [-1, 1, Conv, [128, 1, 1]], 103 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 104 | [19, 1, Conv, [128, 1, 1]], # route backbone P3 105 | [[-1, -2], 1, Concat, [1]], 106 | 107 | [-1, 1, Conv, [128, 1, 1]], 108 | [-2, 1, Conv, [128, 1, 1]], 109 | [-1, 1, Conv, [64, 3, 1]], 110 | [-1, 1, Conv, [64, 3, 1]], 111 | [-1, 1, Conv, [64, 3, 1]], 112 | [-1, 1, Conv, [64, 3, 1]], 113 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 114 | [-1, 1, Conv, [128, 1, 1]], # 83 115 | 116 | [-1, 1, Conv, [256, 3, 2]], 117 | [[-1, 71], 1, Concat, [1]], # cat 118 | 119 | [-1, 1, Conv, [256, 1, 1]], 120 | [-2, 1, Conv, [256, 1, 1]], 121 | [-1, 1, Conv, [128, 3, 1]], 122 | [-1, 1, Conv, [128, 3, 1]], 123 | [-1, 1, Conv, [128, 3, 1]], 124 | [-1, 1, Conv, [128, 3, 1]], 125 | [[-1, -2, -3, -4, -5, -6], 1, Concat, 
[1]], 126 | [-1, 1, Conv, [256, 1, 1]], # 93 127 | 128 | [-1, 1, Conv, [384, 3, 2]], 129 | [[-1, 59], 1, Concat, [1]], # cat 130 | 131 | [-1, 1, Conv, [384, 1, 1]], 132 | [-2, 1, Conv, [384, 1, 1]], 133 | [-1, 1, Conv, [192, 3, 1]], 134 | [-1, 1, Conv, [192, 3, 1]], 135 | [-1, 1, Conv, [192, 3, 1]], 136 | [-1, 1, Conv, [192, 3, 1]], 137 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 138 | [-1, 1, Conv, [384, 1, 1]], # 103 139 | 140 | [-1, 1, Conv, [512, 3, 2]], 141 | [[-1, 47], 1, Concat, [1]], # cat 142 | 143 | [-1, 1, Conv, [512, 1, 1]], 144 | [-2, 1, Conv, [512, 1, 1]], 145 | [-1, 1, Conv, [256, 3, 1]], 146 | [-1, 1, Conv, [256, 3, 1]], 147 | [-1, 1, Conv, [256, 3, 1]], 148 | [-1, 1, Conv, [256, 3, 1]], 149 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 150 | [-1, 1, Conv, [512, 1, 1]], # 113 151 | 152 | [83, 1, Conv, [256, 3, 1]], 153 | [93, 1, Conv, [512, 3, 1]], 154 | [103, 1, Conv, [768, 3, 1]], 155 | [113, 1, Conv, [1024, 3, 1]], 156 | 157 | [83, 1, Conv, [320, 3, 1]], 158 | [71, 1, Conv, [640, 3, 1]], 159 | [59, 1, Conv, [960, 3, 1]], 160 | [47, 1, Conv, [1280, 3, 1]], 161 | 162 | [[114,115,116,117,118,119,120,121], 1, IAuxDetect, [nc, anchors]], # Detect(P3, P4, P5, P6) 163 | ] 164 | -------------------------------------------------------------------------------- /data/dota.yaml: -------------------------------------------------------------------------------- 1 | # train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/] 2 | 3 | train: ./split/DOTA_train/images/train/ 4 | val: ./split/DOTA_train/images/val/ 5 | 6 | # number of classes 7 | nc: 16 8 | 9 | # class names 10 | names: ['plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', 11 | 'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter', 'container-crane'] 12 | -------------------------------------------------------------------------------- /data/hyp.scratch.dota.yaml: -------------------------------------------------------------------------------- 1 | lr0: 0.001 2 | lrf: 0.1 3 | momentum: 0.9 4 | weight_decay: 0.0001 5 | warmup_epochs: 3.0 6 | warmup_momentum: 0.8 7 | warmup_bias_lr: 0.1 8 | box: 0.5 9 | cls: 0.1 10 | cls_pw: 1.0 11 | obj: 0.4 12 | obj_pw: 1.0 13 | iou_t: 0.20 14 | anchor_t: 4.0 15 | fl_gamma: 2.0 16 | hsv_h: 0.015 17 | hsv_s: 0.4 18 | hsv_v: 0.4 19 | degrees: 45 20 | translate: 0.0 21 | scale: 0.0 22 | shear: 0.0 23 | perspective: 0.0 24 | flipud: 0.5 25 | fliplr: 0.5 26 | mosaic: 0.0 27 | mixup: 0.0 28 | copy_paste: 0.0 29 | paste_in: 0.0 30 | 31 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from pathlib import Path 4 | 5 | import cv2 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from numpy import random 9 | 10 | from models.experimental import attempt_load 11 | from utils.datasets import LoadStreams, LoadImages, xywhTheda2Points 12 | from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \ 13 | scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, non_max_suppression_obb 14 | from utils.plots import plot_one_box 15 | from utils.torch_utils import select_device, load_classifier, time_synchronized, TracedModel 16 | import numpy as np 17 | import math 18 | 19 | 20 | def 
detect(save_img=False): 21 | source, weights, view_img, save_txt, imgsz, trace = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, opt.no_trace 22 | save_img = not opt.nosave and not source.endswith('.txt') # save inference images 23 | webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith( 24 | ('rtsp://', 'rtmp://', 'http://', 'https://')) 25 | 26 | # Directories 27 | save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run 28 | (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir 29 | 30 | # Initialize 31 | set_logging() 32 | device = select_device(opt.device) 33 | 34 | # Load model 35 | model = attempt_load(weights, map_location=device) # load FP32 model 36 | stride = int(model.stride.max()) # model stride 37 | imgsz = check_img_size(imgsz, s=stride) # check img_size 38 | 39 | if trace: 40 | model = TracedModel(model, device, opt.img_size) 41 | 42 | # Second-stage classifier 43 | classify = False 44 | if classify: 45 | modelc = load_classifier(name='resnet101', n=2) # initialize 46 | modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval() 47 | 48 | # Set Dataloader 49 | vid_path, vid_writer = None, None 50 | if webcam: 51 | view_img = check_imshow() 52 | cudnn.benchmark = True # set True to speed up constant image size inference 53 | dataset = LoadStreams(source, img_size=imgsz, stride=stride) 54 | else: 55 | dataset = LoadImages(source, img_size=imgsz, stride=stride) 56 | 57 | # Get names and colors 58 | names = model.module.names if hasattr(model, 'module') else model.names 59 | colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] 60 | 61 | # Run inference 62 | if device.type != 'cpu': 63 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 64 | old_img_w = old_img_h = imgsz 65 | old_img_b = 1 66 | 67 | t0 = time.time() 68 | t_inference_all = [] 69 | t_nms_all = [] 70 | cnt = 0 71 | 72 | for eval_clear_i in range(16): 73 | eval_clear = 'evaluation/' + str(eval_clear_i) + '.txt' 74 | with open(eval_clear, 'a') as f: 75 | f.truncate(0) 76 | f.close() 77 | 78 | for path, img, im0s, vid_cap in dataset: 79 | # print('path: ', path) 80 | cnt += 1 81 | if cnt % 20 == 0: 82 | print(cnt, ' / ', dataset.nf, (cnt/dataset.nf * 100), '%', '************************************************************************************************************************************') 83 | 84 | img = torch.from_numpy(img).to(device) 85 | img = img.float() # uint8 to fp16/32 86 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 87 | if img.ndimension() == 3: 88 | img = img.unsqueeze(0) 89 | 90 | # Warmup 91 | if device.type != 'cpu' and (old_img_b != img.shape[0] or old_img_h != img.shape[2] or old_img_w != img.shape[3]): 92 | old_img_b = img.shape[0] 93 | old_img_h = img.shape[2] 94 | old_img_w = img.shape[3] 95 | for i in range(3): 96 | model(img, augment=opt.augment)[0] 97 | 98 | half = False 99 | if half: 100 | img = img.half() 101 | model = model.half() 102 | 103 | # Inference 104 | t1 = time_synchronized() 105 | with torch.no_grad(): # Calculating gradients would cause a GPU memory leak 106 | pred = model(img, augment=opt.augment)[0] 107 | t2 = time_synchronized() 108 | 109 | # Apply NMS 110 | 111 | # Remove prediction boxes under extreme conditions 112 | pred = pred.view(-1, 22) 113 | h_limit = (pred[:, 3] < 8) 114 | aspect_ratio_limit = (pred[:, 2] / abs(pred[:, 3] + 1e-6) > 
30) 115 | out_limit = (h_limit | aspect_ratio_limit) 116 | pred = pred[~out_limit] 117 | pred = pred.view(1, -1, 22) 118 | 119 | pred = non_max_suppression_obb(pred, conf_thres=opt.conf_thres, iou_thres=opt.iou_thres, labels=opt.classes, multi_label=False) 120 | t3 = time_synchronized() 121 | 122 | # Save output to txt 123 | eval_path = path.split("/")[-1].split(".")[0] 124 | 125 | for eval_out_i in range(pred[0].shape[0]): 126 | eval_conf = pred[0][eval_out_i, 5].cpu().numpy() 127 | 128 | eval_xywhTheta = pred[0][eval_out_i, :5].cpu().numpy() 129 | eval_xywhTheta[-1] = (eval_xywhTheta[-1] + math.pi / 2) / math.pi * 180 130 | eval_Points = xywhTheda2Points(eval_xywhTheta) 131 | 132 | eval_cls = pred[0][eval_out_i, 6].cpu().numpy() 133 | from decimal import Decimal 134 | eval_write = Decimal(str(eval_cls)).normalize() 135 | eval_write = int(eval_write) 136 | eval_txt = 'evaluation/' + str(eval_write) + '.txt' 137 | with open(eval_txt, 'a') as f: 138 | f.write(eval_path + " ") 139 | f.write(str(eval_conf) + " ") 140 | for eval_Points_i in range(eval_Points.shape[0]): 141 | f.write(str(eval_Points[eval_Points_i][0]) + " ") 142 | f.write(str(eval_Points[eval_Points_i][1]) + " ") 143 | f.write('\n') 144 | f.close() 145 | 146 | # Apply Classifier 147 | if classify: 148 | pred = apply_classifier(pred, modelc, img, im0s) 149 | 150 | # Process detections 151 | for i, det in enumerate(pred): # detections per image 152 | if webcam: # batch_size >= 1 153 | p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count 154 | else: 155 | p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) 156 | 157 | p = Path(p) # to Path 158 | save_path = str(save_dir / p.name) # img.jpg 159 | txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt 160 | gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh 161 | if len(det): 162 | # Rescale boxes from img_size to im0 size 163 | det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() 164 | 165 | # Print results 166 | for c in det[:, 6].unique(): 167 | n = (det[:, 6] == c).sum() # detections per class 168 | s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string 169 | 170 | # Write results 171 | for x, y, w, h, angle, conf, cls in reversed(det): 172 | xywh = np.array([x.cpu(), y.cpu(), w.cpu(), h.cpu()]) 173 | if save_txt: # Write to file 174 | x, y, w, h = x/gn[0], y/gn[1], w/gn[0], h/gn[1] # 0~1 175 | with open(txt_path + '.txt', 'a') as f: 176 | f.write(('%g ' * 7 + '\n') % (conf, cls, x, y, w, h, int((math.pi/2 + angle) / math.pi * 180))) # label format 177 | 178 | if save_img or view_img: # Add bbox to image 179 | angle = int((math.pi/2 + angle) / math.pi * 180) 180 | label = '%s %.2f %s' % (names[int(cls)], conf, angle) 181 | xyxy = xywh2xyxy(xywh) 182 | plot_one_box(xyxy, angle, im0, label=label, color=colors[int(cls)]) 183 | 184 | # Print time (inference + NMS) 185 | # print(f'{s}Done. 
({(1E3 * (t2 - t1)):.1f}ms) Inference, ({(1E3 * (t3 - t2)):.1f}ms) NMS') 186 | t_inference_all.append((t2 - t1)*1000) 187 | t_nms_all.append((t3 - t2) * 1000) 188 | 189 | # Stream results 190 | if view_img: 191 | cv2.imshow(str(p), im0) 192 | cv2.waitKey(1) # 1 millisecond 193 | 194 | # Save results (image with detections) 195 | if save_img: 196 | if dataset.mode == 'image': 197 | cv2.imwrite(save_path, im0) 198 | print(f" The image with the result is saved in: {save_path}") 199 | else: # 'video' or 'stream' 200 | if vid_path != save_path: # new video 201 | vid_path = save_path 202 | if isinstance(vid_writer, cv2.VideoWriter): 203 | vid_writer.release() # release previous video writer 204 | if vid_cap: # video 205 | fps = vid_cap.get(cv2.CAP_PROP_FPS) 206 | w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 207 | h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 208 | else: # stream 209 | fps, w, h = 30, im0.shape[1], im0.shape[0] 210 | save_path += '.mp4' 211 | vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) 212 | vid_writer.write(im0) 213 | 214 | print('inference average time: ', sum(t_inference_all)/len(t_inference_all), 215 | 'nms average time: ', sum(t_nms_all)/len(t_nms_all)) 216 | if save_txt or save_img: 217 | print('Results saved to %s' % Path(save_dir)) 218 | 219 | print(f'Done. ({time.time() - t0:.3f}s)') 220 | 221 | 222 | def xywh2xyxy(x): 223 | # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right 224 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 225 | y[0] = x[0] - x[2] / 2 # top left x 226 | y[1] = x[1] - x[3] / 2 # top left y 227 | y[2] = x[0] + x[2] / 2 # bottom right x 228 | y[3] = x[1] + x[3] / 2 # bottom right y 229 | return y 230 | 231 | 232 | if __name__ == '__main__': 233 | parser = argparse.ArgumentParser() 234 | 235 | parser.add_argument('--weights', default='weights/*.pt') 236 | parser.add_argument('--source', type=str, default='split/DOTA_test/', help='source') 237 | parser.add_argument('--project', default='evaluation/predictions', help='save results to project/name') 238 | parser.add_argument('--img-size', type=int, default=1024, help='inference size (pixels)') 239 | parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold') 240 | parser.add_argument('--iou-thres', type=float, default=0.2, help='IOU threshold for NMS') 241 | parser.add_argument('--device', default='0', help='cuda device, i.e. 
0 or 0,1,2,3 or cpu') 242 | 243 | parser.add_argument('--view-img', action='store_true', help='display results') 244 | parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') 245 | parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') 246 | parser.add_argument('--nosave', action='store_true', help='do not save images/videos') 247 | parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') 248 | parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') 249 | parser.add_argument('--augment', action='store_true', help='augmented inference') 250 | parser.add_argument('--update', action='store_true', help='update all models') 251 | parser.add_argument('--name', default='exp', help='save results to project/name') 252 | parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 253 | parser.add_argument('--no-trace', action='store_true', help='don`t trace model') 254 | opt = parser.parse_args() 255 | print(opt) 256 | #check_requirements(exclude=('pycocotools', 'thop')) 257 | 258 | with torch.no_grad(): 259 | if opt.update: # update all models (to fix SourceChangeWarning) 260 | for opt.weights in ['yolov7.pt']: 261 | detect() 262 | strip_optimizer(opt.weights) 263 | else: 264 | detect() 265 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import time 4 | import warnings 5 | 6 | sys.path.append('./') # to run '$ python *.py' files in subdirectories 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.mobile_optimizer import optimize_for_mobile 11 | 12 | import models 13 | from models.experimental import attempt_load, End2End 14 | from utils.activations import Hardswish, SiLU 15 | from utils.general import set_logging, check_img_size 16 | from utils.torch_utils import select_device 17 | from utils.add_nms import RegisterNMS 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--weights', type=str, default='./yolor-csp-c.pt', help='weights path') 22 | parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size') # height, width 23 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') 24 | parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') 25 | parser.add_argument('--dynamic-batch', action='store_true', help='dynamic batch onnx for tensorrt and onnx-runtime') 26 | parser.add_argument('--grid', action='store_true', help='export Detect() layer grid') 27 | parser.add_argument('--end2end', action='store_true', help='export end2end onnx') 28 | parser.add_argument('--max-wh', type=int, default=None, help='None for tensorrt nms, int value for onnx-runtime nms') 29 | parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every images') 30 | parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS') 31 | parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS') 32 | parser.add_argument('--device', default='cpu', help='cuda device, i.e. 
0 or 0,1,2,3 or cpu') 33 | parser.add_argument('--simplify', action='store_true', help='simplify onnx model') 34 | parser.add_argument('--include-nms', action='store_true', help='export end2end onnx') 35 | parser.add_argument('--fp16', action='store_true', help='CoreML FP16 half-precision export') 36 | parser.add_argument('--int8', action='store_true', help='CoreML INT8 quantization') 37 | opt = parser.parse_args() 38 | opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand 39 | opt.dynamic = opt.dynamic and not opt.end2end 40 | opt.dynamic = False if opt.dynamic_batch else opt.dynamic 41 | print(opt) 42 | set_logging() 43 | t = time.time() 44 | 45 | # Load PyTorch model 46 | device = select_device(opt.device) 47 | model = attempt_load(opt.weights, map_location=device) # load FP32 model 48 | labels = model.names 49 | 50 | # Checks 51 | gs = int(max(model.stride)) # grid size (max stride) 52 | opt.img_size = [check_img_size(x, gs) for x in opt.img_size] # verify img_size are gs-multiples 53 | 54 | # Input 55 | img = torch.zeros(opt.batch_size, 3, *opt.img_size).to(device) # image size(1,3,320,192) iDetection 56 | 57 | # Update model 58 | for k, m in model.named_modules(): 59 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 60 | if isinstance(m, models.common.Conv): # assign export-friendly activations 61 | if isinstance(m.act, nn.Hardswish): 62 | m.act = Hardswish() 63 | elif isinstance(m.act, nn.SiLU): 64 | m.act = SiLU() 65 | # elif isinstance(m, models.yolo.Detect): 66 | # m.forward = m.forward_export # assign forward (optional) 67 | model.model[-1].export = not opt.grid # set Detect() layer grid export 68 | y = model(img) # dry run 69 | if opt.include_nms: 70 | model.model[-1].include_nms = True 71 | y = None 72 | 73 | # TorchScript export 74 | try: 75 | print('\nStarting TorchScript export with torch %s...' % torch.__version__) 76 | f = opt.weights.replace('.pt', '.torchscript.pt') # filename 77 | ts = torch.jit.trace(model, img, strict=False) 78 | ts.save(f) 79 | print('TorchScript export success, saved as %s' % f) 80 | except Exception as e: 81 | print('TorchScript export failure: %s' % e) 82 | 83 | # CoreML export 84 | try: 85 | import coremltools as ct 86 | 87 | print('\nStarting CoreML export with coremltools %s...' % ct.__version__) 88 | # convert model from torchscript and apply pixel scaling as per detect.py 89 | ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])]) 90 | bits, mode = (8, 'kmeans_lut') if opt.int8 else (16, 'linear') if opt.fp16 else (32, None) 91 | if bits < 32: 92 | if sys.platform.lower() == 'darwin': # quantization only supported on macOS 93 | with warnings.catch_warnings(): 94 | warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning 95 | ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) 96 | else: 97 | print('quantization only supported on macOS, skipping...') 98 | 99 | f = opt.weights.replace('.pt', '.mlmodel') # filename 100 | ct_model.save(f) 101 | print('CoreML export success, saved as %s' % f) 102 | except Exception as e: 103 | print('CoreML export failure: %s' % e) 104 | 105 | # TorchScript-Lite export 106 | try: 107 | print('\nStarting TorchScript-Lite export with torch %s...' 
% torch.__version__) 108 | f = opt.weights.replace('.pt', '.torchscript.ptl') # filename 109 | tsl = torch.jit.trace(model, img, strict=False) 110 | tsl = optimize_for_mobile(tsl) 111 | tsl._save_for_lite_interpreter(f) 112 | print('TorchScript-Lite export success, saved as %s' % f) 113 | except Exception as e: 114 | print('TorchScript-Lite export failure: %s' % e) 115 | 116 | # ONNX export 117 | try: 118 | import onnx 119 | 120 | print('\nStarting ONNX export with onnx %s...' % onnx.__version__) 121 | f = opt.weights.replace('.pt', '.onnx') # filename 122 | model.eval() 123 | output_names = ['classes', 'boxes'] if y is None else ['output'] 124 | dynamic_axes = None 125 | if opt.dynamic: 126 | dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640) 127 | 'output': {0: 'batch', 2: 'y', 3: 'x'}} 128 | if opt.dynamic_batch: 129 | opt.batch_size = 'batch' 130 | dynamic_axes = { 131 | 'images': { 132 | 0: 'batch', 133 | }, } 134 | if opt.end2end and opt.max_wh is None: 135 | output_axes = { 136 | 'num_dets': {0: 'batch'}, 137 | 'det_boxes': {0: 'batch'}, 138 | 'det_scores': {0: 'batch'}, 139 | 'det_classes': {0: 'batch'}, 140 | } 141 | else: 142 | output_axes = { 143 | 'output': {0: 'batch'}, 144 | } 145 | dynamic_axes.update(output_axes) 146 | if opt.grid: 147 | if opt.end2end: 148 | print('\nStarting export end2end onnx model for %s...' % 'TensorRT' if opt.max_wh is None else 'onnxruntime') 149 | model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device,len(labels)) 150 | if opt.end2end and opt.max_wh is None: 151 | output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes'] 152 | shapes = [opt.batch_size, 1, opt.batch_size, opt.topk_all, 4, 153 | opt.batch_size, opt.topk_all, opt.batch_size, opt.topk_all] 154 | else: 155 | output_names = ['output'] 156 | else: 157 | model.model[-1].concat = True 158 | 159 | torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'], 160 | output_names=output_names, 161 | dynamic_axes=dynamic_axes) 162 | 163 | # Checks 164 | onnx_model = onnx.load(f) # load onnx model 165 | onnx.checker.check_model(onnx_model) # check onnx model 166 | 167 | if opt.end2end and opt.max_wh is None: 168 | for i in onnx_model.graph.output: 169 | for j in i.type.tensor_type.shape.dim: 170 | j.dim_param = str(shapes.pop(0)) 171 | 172 | # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model 173 | 174 | # # Metadata 175 | # d = {'stride': int(max(model.stride))} 176 | # for k, v in d.items(): 177 | # meta = onnx_model.metadata_props.add() 178 | # meta.key, meta.value = k, str(v) 179 | # onnx.save(onnx_model, f) 180 | 181 | if opt.simplify: 182 | try: 183 | import onnxsim 184 | 185 | print('\nStarting to simplify ONNX...') 186 | onnx_model, check = onnxsim.simplify(onnx_model) 187 | assert check, 'assert check failed' 188 | except Exception as e: 189 | print(f'Simplifier failure: {e}') 190 | 191 | # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model 192 | onnx.save(onnx_model,f) 193 | print('ONNX export success, saved as %s' % f) 194 | 195 | if opt.include_nms: 196 | print('Registering NMS plugin for ONNX...') 197 | mo = RegisterNMS(f) 198 | mo.register_nms() 199 | mo.save(f) 200 | 201 | except Exception as e: 202 | print('ONNX export failure: %s' % e) 203 | 204 | # Finish 205 | print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' 
% (time.time() - t)) 206 | -------------------------------------------------------------------------------- /figure/Experiments_Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/figure/Experiments_Comparison.png -------------------------------------------------------------------------------- /figure/LGBB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/figure/LGBB.png -------------------------------------------------------------------------------- /figure/OBB_Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/figure/OBB_Comparison.png -------------------------------------------------------------------------------- /figure/OBB_Representation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/figure/OBB_Representation.png -------------------------------------------------------------------------------- /figure/Orientation_Sensitive_Feature_Extraction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/figure/Orientation_Sensitive_Feature_Extraction.png -------------------------------------------------------------------------------- /figure/Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/figure/Overview.png -------------------------------------------------------------------------------- /figure/RRC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/figure/RRC.png -------------------------------------------------------------------------------- /figure/RRC_Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/figure/RRC_Comparison.png -------------------------------------------------------------------------------- /figure/Vis_Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/figure/Vis_Comparison.png -------------------------------------------------------------------------------- /figure/Vis_RRC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhen6618/RotaYolo/7083b23e4f702ba2d8dec5d982e045d5700481ac/figure/Vis_RRC.png -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | """PyTorch Hub models 2 | 3 | Usage: 4 | import torch 5 | model = torch.hub.load('repo', 'model') 6 | """ 7 | 8 | from pathlib import Path 9 | 10 | import torch 11 | 12 | from models.yolo import Model 13 | from utils.general import 
check_requirements, set_logging 14 | from utils.google_utils import attempt_download 15 | from utils.torch_utils import select_device 16 | 17 | dependencies = ['torch', 'yaml'] 18 | check_requirements(Path(__file__).parent / 'requirements.txt', exclude=('pycocotools', 'thop')) 19 | set_logging() 20 | 21 | 22 | def create(name, pretrained, channels, classes, autoshape): 23 | """Creates a specified model 24 | 25 | Arguments: 26 | name (str): name of model, i.e. 'yolov7' 27 | pretrained (bool): load pretrained weights into the model 28 | channels (int): number of input channels 29 | classes (int): number of model classes 30 | 31 | Returns: 32 | pytorch model 33 | """ 34 | try: 35 | cfg = list((Path(__file__).parent / 'cfg').rglob(f'{name}.yaml'))[0] # model.yaml path 36 | model = Model(cfg, channels, classes) 37 | if pretrained: 38 | fname = f'{name}.pt' # checkpoint filename 39 | attempt_download(fname) # download if not found locally 40 | ckpt = torch.load(fname, map_location=torch.device('cpu')) # load 41 | msd = model.state_dict() # model state_dict 42 | csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 43 | csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter 44 | model.load_state_dict(csd, strict=False) # load 45 | if len(ckpt['model'].names) == classes: 46 | model.names = ckpt['model'].names # set class names attribute 47 | if autoshape: 48 | model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS 49 | device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available 50 | return model.to(device) 51 | 52 | except Exception as e: 53 | s = 'Cache maybe be out of date, try force_reload=True.' 54 | raise Exception(s) from e 55 | 56 | 57 | def custom(path_or_model='path/to/model.pt', autoshape=True): 58 | """custom mode 59 | 60 | Arguments (3 options): 61 | path_or_model (str): 'path/to/model.pt' 62 | path_or_model (dict): torch.load('path/to/model.pt') 63 | path_or_model (nn.Module): torch.load('path/to/model.pt')['model'] 64 | 65 | Returns: 66 | pytorch model 67 | """ 68 | model = torch.load(path_or_model, map_location=torch.device('cpu')) if isinstance(path_or_model, str) else path_or_model # load checkpoint 69 | if isinstance(model, dict): 70 | model = model['ema' if model.get('ema') else 'model'] # load model 71 | 72 | hub_model = Model(model.yaml).to(next(model.parameters()).device) # create 73 | hub_model.load_state_dict(model.float().state_dict()) # load state_dict 74 | hub_model.names = model.names # class names 75 | if autoshape: 76 | hub_model = hub_model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS 77 | device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available 78 | return hub_model.to(device) 79 | 80 | 81 | def yolov7(pretrained=True, channels=3, classes=80, autoshape=True): 82 | return create('yolov7', pretrained, channels, classes, autoshape) 83 | 84 | 85 | if __name__ == '__main__': 86 | model = custom(path_or_model='yolov7.pt') # custom example 87 | # model = create(name='yolov7', pretrained=True, channels=3, classes=80, autoshape=True) # pretrained example 88 | 89 | # Verify inference 90 | import numpy as np 91 | from PIL import Image 92 | 93 | imgs = [np.zeros((640, 480, 3))] 94 | 95 | results = model(imgs) # batched inference 96 | results.print() 97 | results.save() 98 | -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- 1 | # init -------------------------------------------------------------------------------- /models/experimental.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.common import Conv, DWConv 7 | from utils.google_utils import attempt_download 8 | 9 | 10 | class CrossConv(nn.Module): 11 | # Cross Convolution Downsample 12 | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): 13 | # ch_in, ch_out, kernel, stride, groups, expansion, shortcut 14 | super(CrossConv, self).__init__() 15 | c_ = int(c2 * e) # hidden channels 16 | self.cv1 = Conv(c1, c_, (1, k), (1, s)) 17 | self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) 18 | self.add = shortcut and c1 == c2 19 | 20 | def forward(self, x): 21 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) 22 | 23 | 24 | class Sum(nn.Module): 25 | # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 26 | def __init__(self, n, weight=False): # n: number of inputs 27 | super(Sum, self).__init__() 28 | self.weight = weight # apply weights boolean 29 | self.iter = range(n - 1) # iter object 30 | if weight: 31 | self.w = nn.Parameter(-torch.arange(1., n) / 2, requires_grad=True) # layer weights 32 | 33 | def forward(self, x): 34 | y = x[0] # no weight 35 | if self.weight: 36 | w = torch.sigmoid(self.w) * 2 37 | for i in self.iter: 38 | y = y + x[i + 1] * w[i] 39 | else: 40 | for i in self.iter: 41 | y = y + x[i + 1] 42 | return y 43 | 44 | 45 | class MixConv2d(nn.Module): 46 | # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 47 | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): 48 | super(MixConv2d, self).__init__() 49 | groups = len(k) 50 | if equal_ch: # equal c_ per group 51 | i = torch.linspace(0, groups - 1E-6, c2).floor() # c2 indices 52 | c_ = [(i == g).sum() for g in range(groups)] # intermediate channels 53 | else: # equal weight.numel() per group 54 | b = [c2] + [0] * groups 55 | a = np.eye(groups + 1, groups, k=-1) 56 | a -= np.roll(a, 1, axis=1) 57 | a *= np.array(k) ** 2 58 | a[0] = 1 59 | c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b 60 | 61 | self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) 62 | self.bn = nn.BatchNorm2d(c2) 63 | self.act = nn.LeakyReLU(0.1, inplace=True) 64 | 65 | def forward(self, x): 66 | return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) 67 | 68 | 69 | class Ensemble(nn.ModuleList): 70 | # Ensemble of models 71 | def __init__(self): 72 | super(Ensemble, self).__init__() 73 | 74 | def forward(self, x, augment=False): 75 | y = [] 76 | for module in self: 77 | y.append(module(x, augment)[0]) 78 | # y = torch.stack(y).max(0)[0] # max ensemble 79 | # y = torch.stack(y).mean(0) # mean ensemble 80 | y = torch.cat(y, 1) # nms ensemble 81 | return y, None # inference, train output 82 | 83 | 84 | 85 | 86 | 87 | class ORT_NMS(torch.autograd.Function): 88 | '''ONNX-Runtime NMS operation''' 89 | @staticmethod 90 | def forward(ctx, 91 | boxes, 92 | scores, 93 | max_output_boxes_per_class=torch.tensor([100]), 94 | iou_threshold=torch.tensor([0.45]), 95 | score_threshold=torch.tensor([0.25])): 96 | device = boxes.device 97 | batch = scores.shape[0] 98 | num_det = random.randint(0, 100) 99 | batches = torch.randint(0, batch, 
(num_det,)).sort()[0].to(device) 100 | idxs = torch.arange(100, 100 + num_det).to(device) 101 | zeros = torch.zeros((num_det,), dtype=torch.int64).to(device) 102 | selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous() 103 | selected_indices = selected_indices.to(torch.int64) 104 | return selected_indices 105 | 106 | @staticmethod 107 | def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): 108 | return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold) 109 | 110 | 111 | class TRT_NMS(torch.autograd.Function): 112 | '''TensorRT NMS operation''' 113 | @staticmethod 114 | def forward( 115 | ctx, 116 | boxes, 117 | scores, 118 | background_class=-1, 119 | box_coding=1, 120 | iou_threshold=0.45, 121 | max_output_boxes=100, 122 | plugin_version="1", 123 | score_activation=0, 124 | score_threshold=0.25, 125 | ): 126 | batch_size, num_boxes, num_classes = scores.shape 127 | num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32) 128 | det_boxes = torch.randn(batch_size, max_output_boxes, 4) 129 | det_scores = torch.randn(batch_size, max_output_boxes) 130 | det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32) 131 | return num_det, det_boxes, det_scores, det_classes 132 | 133 | @staticmethod 134 | def symbolic(g, 135 | boxes, 136 | scores, 137 | background_class=-1, 138 | box_coding=1, 139 | iou_threshold=0.45, 140 | max_output_boxes=100, 141 | plugin_version="1", 142 | score_activation=0, 143 | score_threshold=0.25): 144 | out = g.op("TRT::EfficientNMS_TRT", 145 | boxes, 146 | scores, 147 | background_class_i=background_class, 148 | box_coding_i=box_coding, 149 | iou_threshold_f=iou_threshold, 150 | max_output_boxes_i=max_output_boxes, 151 | plugin_version_s=plugin_version, 152 | score_activation_i=score_activation, 153 | score_threshold_f=score_threshold, 154 | outputs=4) 155 | nums, boxes, scores, classes = out 156 | return nums, boxes, scores, classes 157 | 158 | 159 | class ONNX_ORT(nn.Module): 160 | '''onnx module with ONNX-Runtime NMS operation.''' 161 | def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80): 162 | super().__init__() 163 | self.device = device if device else torch.device("cpu") 164 | self.max_obj = torch.tensor([max_obj]).to(device) 165 | self.iou_threshold = torch.tensor([iou_thres]).to(device) 166 | self.score_threshold = torch.tensor([score_thres]).to(device) 167 | self.max_wh = max_wh # if max_wh != 0 : non-agnostic else : agnostic 168 | self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]], 169 | dtype=torch.float32, 170 | device=self.device) 171 | self.n_classes=n_classes 172 | 173 | def forward(self, x): 174 | boxes = x[:, :, :4] 175 | conf = x[:, :, 4:5] 176 | scores = x[:, :, 5:] 177 | if self.n_classes == 1: 178 | scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5, 179 | # so there is no need to multiplicate. 
180 | else: 181 | scores *= conf # conf = obj_conf * cls_conf 182 | boxes @= self.convert_matrix 183 | max_score, category_id = scores.max(2, keepdim=True) 184 | dis = category_id.float() * self.max_wh 185 | nmsbox = boxes + dis 186 | max_score_tp = max_score.transpose(1, 2).contiguous() 187 | selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold) 188 | X, Y = selected_indices[:, 0], selected_indices[:, 2] 189 | selected_boxes = boxes[X, Y, :] 190 | selected_categories = category_id[X, Y, :].float() 191 | selected_scores = max_score[X, Y, :] 192 | X = X.unsqueeze(1).float() 193 | return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1) 194 | 195 | class ONNX_TRT(nn.Module): 196 | '''onnx module with TensorRT NMS operation.''' 197 | def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80): 198 | super().__init__() 199 | assert max_wh is None 200 | self.device = device if device else torch.device('cpu') 201 | self.background_class = -1, 202 | self.box_coding = 1, 203 | self.iou_threshold = iou_thres 204 | self.max_obj = max_obj 205 | self.plugin_version = '1' 206 | self.score_activation = 0 207 | self.score_threshold = score_thres 208 | self.n_classes=n_classes 209 | 210 | def forward(self, x): 211 | boxes = x[:, :, :4] 212 | conf = x[:, :, 4:5] 213 | scores = x[:, :, 5:] 214 | if self.n_classes == 1: 215 | scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5, 216 | # so there is no need to multiplicate. 217 | else: 218 | scores *= conf # conf = obj_conf * cls_conf 219 | num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding, 220 | self.iou_threshold, self.max_obj, 221 | self.plugin_version, self.score_activation, 222 | self.score_threshold) 223 | return num_det, det_boxes, det_scores, det_classes 224 | 225 | 226 | class End2End(nn.Module): 227 | '''export onnx or tensorrt model with NMS operation.''' 228 | def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80): 229 | super().__init__() 230 | device = device if device else torch.device('cpu') 231 | assert isinstance(max_wh,(int)) or max_wh is None 232 | self.model = model.to(device) 233 | self.model.model[-1].end2end = True 234 | self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT 235 | self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes) 236 | self.end2end.eval() 237 | 238 | def forward(self, x): 239 | x = self.model(x) 240 | x = self.end2end(x) 241 | return x 242 | 243 | 244 | def attempt_load(weights, map_location=None): 245 | # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a 246 | model = Ensemble() 247 | for w in weights if isinstance(weights, list) else [weights]: 248 | attempt_download(w) 249 | ckpt = torch.load(w, map_location=map_location) # load 250 | model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model 251 | 252 | # Compatibility updates 253 | for m in model.modules(): 254 | if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: 255 | m.inplace = True # pytorch 1.7.0 compatibility 256 | elif type(m) is nn.Upsample: 257 | m.recompute_scale_factor = None # torch 1.11.0 compatibility 258 | elif type(m) is Conv: 259 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 260 | 261 | if len(model) == 1: 262 | 
return model[-1] # return model 263 | else: 264 | print('Ensemble created with %s\n' % weights) 265 | for k in ['names', 'stride']: 266 | setattr(model, k, getattr(model[-1], k)) 267 | return model # return ensemble 268 | 269 | 270 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from pathlib import Path 5 | from threading import Thread 6 | 7 | import numpy as np 8 | import torch 9 | import yaml 10 | from tqdm import tqdm 11 | import math 12 | from copy import deepcopy 13 | 14 | from models.experimental import attempt_load 15 | from utils.datasets import create_dataloader, xywhTheda2Points 16 | from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \ 17 | box_iou, non_max_suppression, non_max_suppression_obb, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, \ 18 | increment_path, colorstr 19 | from utils.metrics import ap_per_class, ConfusionMatrix 20 | from utils.plots import plot_images, output_to_target, plot_study_txt 21 | from utils.torch_utils import select_device, time_synchronized, TracedModel 22 | from EigenTheda import Eigen2xywhTheda, xywhTheda2Eigen, x2Eigen, Eigen2xywhTheda_numpy 23 | from decimal import Decimal 24 | from mmcv.ops import diff_iou_rotated_2d 25 | import copy 26 | 27 | 28 | def compute_iou_test(tcx, pcx, scale): 29 | """ 30 | input: tcx:(m, 5) pcx:(n, 5) 31 | output: loss_iou:(m, n) 32 | """ 33 | m, n = tcx.shape[0], pcx.shape[0] 34 | device = pcx.device 35 | loss_iou = torch.zeros([tcx.shape[0], pcx.shape[0]]) 36 | 37 | t = copy.deepcopy(tcx).unsqueeze(1).repeat(1, n, 1) # (m, n, 5) 38 | p = copy.deepcopy(pcx).unsqueeze(0).repeat(m, 1, 1) # (m, n, 5) 39 | 40 | p[..., 4] = p[..., 4] / 180 * math.pi 41 | t[..., 4] = t[..., 4] / 180 * math.pi 42 | p[..., :4] = p[..., :4] / scale 43 | t[..., :4] = t[..., :4] / scale 44 | p = p.to(dtype=torch.float32) 45 | t = t.to(dtype=torch.float32) 46 | 47 | iou_mmcv_multi = torch.zeros((0, n)).to(device) 48 | for j in range(m): 49 | t_j, p_j = t[j].unsqueeze(0), p[j].unsqueeze(0) 50 | 51 | iou_mmcv_j = diff_iou_rotated_2d(t_j, p_j) # x y w h angle(rad) 52 | iou_mmcv_multi = torch.cat([iou_mmcv_multi, iou_mmcv_j], dim=0) 53 | 54 | iou_mmcv = iou_mmcv_multi 55 | 56 | return iou_mmcv 57 | 58 | def test(data, 59 | weights=None, 60 | batch_size=4, 61 | imgsz=1024, 62 | conf_thres=0.3, 63 | iou_thres=0.1, # for NMS 64 | save_json=False, 65 | single_cls=False, 66 | augment=False, 67 | verbose=False, 68 | model=None, 69 | dataloader=None, 70 | save_dir=Path(''), # for saving images 71 | save_txt=False, # for auto-labelling 72 | save_hybrid=False, # for hybrid auto-labelling 73 | save_conf=False, # save auto-label confidences 74 | plots=True, 75 | wandb_logger=None, 76 | compute_loss=None, 77 | half_precision=False, 78 | trace=False, 79 | is_coco=False, 80 | v5_metric=False): 81 | # Initialize/load model and set device 82 | training = model is not None 83 | if training: # called by train.py 84 | device = next(model.parameters()).device # get model device 85 | 86 | else: # called directly 87 | set_logging() 88 | device = select_device(opt.device, batch_size=batch_size) 89 | 90 | # Directories 91 | save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) # increment run 92 | (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir 93 | 94 | # 
Load model 95 | model = attempt_load(weights, map_location=device) # load FP32 model 96 | gs = max(int(model.stride.max()), 32) # grid size (max stride) 97 | imgsz = check_img_size(imgsz, s=gs) # check img_size 98 | 99 | if trace: 100 | model = TracedModel(model, device, imgsz) 101 | 102 | # Configure 103 | model.eval() 104 | if isinstance(data, str): 105 | is_coco = data.endswith('coco.yaml') 106 | with open(data) as f: 107 | data = yaml.load(f, Loader=yaml.SafeLoader) 108 | check_dataset(data) # check 109 | nc = 1 if single_cls else int(data['nc']) # number of classes 110 | iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95 111 | niou = iouv.numel() 112 | 113 | # Logging 114 | log_imgs = 0 115 | if wandb_logger and wandb_logger.wandb: 116 | log_imgs = min(wandb_logger.log_imgs, 100) 117 | # Dataloader 118 | if not training: 119 | if device.type != 'cpu': 120 | model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once 121 | task = opt.task if opt.task in ('train', 'val', 'test') else 'val' # path to train/val/test images 122 | dataloader = create_dataloader(data[task], imgsz, batch_size, gs, opt, pad=0.5, rect=True, 123 | prefix=colorstr(f'{task}: '))[0] 124 | 125 | if v5_metric: 126 | print("Testing with YOLOv5 AP metric...") 127 | 128 | seen = 0 129 | confusion_matrix = ConfusionMatrix(nc=nc) 130 | names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)} 131 | coco91class = coco80_to_coco91_class() 132 | s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95') 133 | p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0. 134 | loss = torch.zeros(3, device=device) 135 | jdict, stats, ap, ap_class, wandb_images = [], [], [], [], [] 136 | 137 | 'test' 138 | for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)): 139 | 140 | loss_targets = copy.deepcopy(targets).to(device) 141 | 142 | img = img.to(device, non_blocking=True) 143 | img = img.float() # uint8 to fp16/32 144 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 145 | 146 | "bs_index, cls, cx, cy(0-1), x1, x2, x3 to bs_index, cls, x, y, w, h, angle[0, 180)" 147 | uu = deepcopy(targets[:, 2:7]).detach().cpu().numpy() 148 | uu = x2Eigen(uu) 149 | uu, _ = Eigen2xywhTheda_numpy(uu) 150 | targets[:, 2:7] = torch.from_numpy(uu) 151 | targets = targets.to(device) 152 | 153 | nb, _, height, width = img.shape # batch size, channels, height, width 154 | 155 | # if half: 156 | # img = img.half() 157 | # model = model.half() 158 | 159 | with torch.no_grad(): 160 | # Run model 161 | t = time_synchronized() 162 | out, train_out = model(img, augment=augment) # inference and training outputs 163 | t0 += time_synchronized() - t 164 | 165 | # Compute loss 166 | if compute_loss: 167 | loss += compute_loss([x.float() for x in deepcopy(train_out)], deepcopy(loss_targets), img)[1][:3] # box, obj, cls 168 | 169 | # Run NMS 170 | "after NMS: x y w h angle[-pi/2, pi/2) conf cls" 171 | targets[:, 2:6] *= torch.Tensor([width, height, width, height]).to(device) 172 | lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling 173 | 174 | # Remove prediction boxes under extreme conditions 175 | out = out.view(-1, 22) 176 | off = (out[:, 0] >= img.shape[2]) | (out[:, 0] <= 0) | (out[:, 1] >= img.shape[2]) | (out[:, 1] <= 0) 177 | h_limit = (out[:, 3] < 4) 178 | w_limit = (out[:, 2] > 800) 179 | aspect_ratio_limit = (out[:, 2] / abs(out[:, 3] + 1e-6) > 30) 180 | 
out_limit = (w_limit | h_limit | aspect_ratio_limit | off) 181 | out = out[~out_limit] 182 | out = out.view(1, -1, 22) 183 | 184 | out = non_max_suppression_obb(out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb, multi_label=False) 185 | t1 += time_synchronized() - t 186 | 187 | # Statistics per image 188 | for si, pred in enumerate(out): 189 | labels = targets[targets[:, 0] == si, 1:] 190 | nl = len(labels) 191 | tcls = labels[:, 0].tolist() if nl else [] # target class 192 | path = Path(paths[si]) 193 | seen += 1 194 | 195 | if len(pred) == 0: 196 | if nl: 197 | stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls)) 198 | continue 199 | 200 | # Predictions 201 | # predn = pred.clone() 202 | # scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1]) # native-space pred 203 | 204 | # # Append to text file 205 | # if save_txt: 206 | # gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]] # normalization gain whwh 207 | # for *xyxy, conf, cls in predn.tolist(): 208 | # xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 209 | # line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format 210 | # with open(save_dir / 'labels' / (path.stem + '.txt'), 'a') as f: 211 | # f.write(('%g ' * len(line)).rstrip() % line + '\n') 212 | # 213 | # # W&B logging - Media Panel Plots 214 | # if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0: # Check for test operation 215 | # if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0: 216 | # box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, 217 | # "class_id": int(cls), 218 | # "box_caption": "%s %.3f" % (names[cls], conf), 219 | # "scores": {"class_score": conf}, 220 | # "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()] 221 | # boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space 222 | # wandb_images.append(wandb_logger.wandb.Image(img[si], boxes=boxes, caption=path.name)) 223 | # wandb_logger.log_training_progress(predn, path, names) if wandb_logger and wandb_logger.wandb_run else None 224 | # 225 | # # Append to pycocotools JSON dictionary 226 | # if save_json: 227 | # # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ... 
228 | # image_id = int(path.stem) if path.stem.isnumeric() else path.stem 229 | # box = xyxy2xywh(predn[:, :4]) # xywh 230 | # box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner 231 | # for p, b in zip(pred.tolist(), box.tolist()): 232 | # jdict.append({'image_id': image_id, 233 | # 'category_id': coco91class[int(p[5])] if is_coco else int(p[5]), 234 | # 'bbox': [round(x, 3) for x in b], 235 | # 'score': round(p[4], 5)}) 236 | 237 | 'Assign all predictions as incorrect' 238 | correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool) 239 | if nl: 240 | detected = [] # target indices 241 | tcls_tensor = labels[:, 0] 242 | 243 | tbox = labels[:, 1:6] 244 | ppred = deepcopy(pred) 245 | ppred[:, 4] = (ppred[:, 4] + math.pi / 2) / math.pi * 180 # pred: (n, [xylsθ, conf, cls]) θ[0, 180) 246 | 247 | # Per target class 248 | for cls in torch.unique(tcls_tensor): 249 | ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1) # prediction indices 250 | pi = (cls == ppred[:, 6]).nonzero(as_tuple=False).view(-1) # target indices 251 | 252 | # Search for detections 253 | if pi.shape[0]: 254 | compute_ious = compute_iou_test(tbox[ti], ppred[pi, :5], height) 255 | compute_ious = compute_ious.T 256 | ious, i = compute_ious.max(1) # best ious, indices 257 | ious = ious.cpu() 258 | 259 | # mAP 260 | detected_set = set() 261 | for j in (torch.Tensor(ious) > iouv[0]).nonzero(as_tuple=False): 262 | d = ti[i[j]] # detected target 263 | if d.item() not in detected_set: 264 | detected_set.add(d.item()) 265 | detected.append(d) 266 | correct[pi[j]] = ious[j] > iouv 267 | if len(detected) == nl: 268 | break 269 | 270 | # Append statistics (correct, conf, pcls, tcls) 271 | stats.append((correct.cpu(), pred[:, 5].cpu(), pred[:, 6].cpu(), tcls)) 272 | 273 | # Plot images 274 | plots = True 275 | if plots and batch_i < 20: 276 | f = save_dir / f'test{batch_i}_labels.png' # labels 277 | targets[:, 2:6] /= torch.Tensor([width, height, width, height]).to(device) 278 | "targets: [img_index, clsid cx cy l s(0-1) theta]) θ[0, 180) [n, 7]" 279 | Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start() 280 | 281 | f = save_dir / f'test{batch_i}_pred.png' # predictions 282 | "out: x y w h angle(-pi/2~pi/2) conf cls list{[m, 7]}" 283 | for out_i in range(len(out)): 284 | out[out_i][:, 0] = out[out_i][:, 0] / width 285 | out[out_i][:, 1] = out[out_i][:, 1] / height 286 | out[out_i][:, 2] = out[out_i][:, 2] / width 287 | out[out_i][:, 3] = out[out_i][:, 3] / height 288 | out[out_i][:, 4] = (out[out_i][:, 4] + math.pi/2) / math.pi * 180 289 | "out: x y w h(0-1) angle[0, 180) conf cls" 290 | 291 | "output_to_target: [img_index, class_id, x, y, w, h(0-1), conf, angle[0, 180)]" 292 | Thread(target=plot_images, args=(img, output_to_target(out, width, height), paths, f, names), daemon=True).start() 293 | # Compute statistics 294 | stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy 295 | if len(stats) and stats[0].any(): 296 | p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, v5_metric=v5_metric, save_dir=save_dir, names=names) 297 | ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95 298 | mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() 299 | nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class 300 | else: 301 | nt = torch.zeros(1) 302 | 303 | # Print results 304 | pf = '%20s' + '%12i' * 2 + '%12.3g' * 4 # print format 305 | print(pf % ('all', seen, nt.sum(), mp, mr, map50, map)) 306 | 307 | # Print results per class 308 | 
if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats): 309 | for i, c in enumerate(ap_class): 310 | print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i])) 311 | 312 | # Print speeds 313 | t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size) # tuple 314 | if not training: 315 | print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t) 316 | 317 | # Plots 318 | if plots: 319 | confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) 320 | if wandb_logger and wandb_logger.wandb: 321 | val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))] 322 | wandb_logger.log({"Validation": val_batches}) 323 | if wandb_images: 324 | wandb_logger.log({"Bounding Box Debugger/Images": wandb_images}) 325 | 326 | # Save JSON 327 | if save_json and len(jdict): 328 | w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights 329 | anno_json = './coco/annotations/instances_val2017.json' # annotations json 330 | pred_json = str(save_dir / f"{w}_predictions.json") # predictions json 331 | print('\nEvaluating pycocotools mAP... saving %s...' % pred_json) 332 | with open(pred_json, 'w') as f: 333 | json.dump(jdict, f) 334 | 335 | try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb 336 | from pycocotools.coco import COCO 337 | from pycocotools.cocoeval import COCOeval 338 | 339 | anno = COCO(anno_json) # init annotations api 340 | pred = anno.loadRes(pred_json) # init predictions api 341 | eval = COCOeval(anno, pred, 'bbox') 342 | if is_coco: 343 | eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.img_files] # image IDs to evaluate 344 | eval.evaluate() 345 | eval.accumulate() 346 | eval.summarize() 347 | map, map50 = eval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5) 348 | except Exception as e: 349 | print(f'pycocotools unable to run: {e}') 350 | 351 | # Return results 352 | model.float() # for training 353 | if not training: 354 | s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' 355 | print(f"Results saved to {save_dir}{s}") 356 | maps = np.zeros(nc) + map 357 | for i, c in enumerate(ap_class): 358 | maps[c] = ap[i] 359 | return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t 360 | 361 | 362 | if __name__ == '__main__': 363 | parser = argparse.ArgumentParser(prog='test.py') 364 | parser.add_argument('--weights', nargs='+', type=str, default='weights/*.pt', help='model.pt path(s)') 365 | parser.add_argument('--data', type=str, default='data/dota.yaml', help='*.data path') 366 | parser.add_argument('--batch-size', type=int, default=4, help='size of each image batch') 367 | parser.add_argument('--img-size', type=int, default=1024, help='inference size (pixels)') 368 | parser.add_argument('--conf-thres', type=float, default=0.001, help='object confidence threshold') 369 | parser.add_argument('--iou-thres', type=float, default=0.65, help='IOU threshold for NMS') 370 | parser.add_argument('--task', default='val', help='train, val, test, speed or study') 371 | parser.add_argument('--device', default='', help='cuda device, i.e. 
0 or 0,1,2,3 or cpu') 372 | parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset') 373 | parser.add_argument('--augment', action='store_true', help='augmented inference') 374 | parser.add_argument('--verbose', action='store_true', help='report mAP by class') 375 | parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') 376 | parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt') 377 | parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') 378 | parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file') 379 | parser.add_argument('--project', default='runs/test', help='save to project/name') 380 | parser.add_argument('--name', default='exp', help='save to project/name') 381 | parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 382 | parser.add_argument('--no-trace', action='store_true', help='don`t trace model') 383 | parser.add_argument('--v5-metric', action='store_true', help='assume maximum recall as 1.0 in AP calculation') 384 | opt = parser.parse_args() 385 | opt.save_json |= opt.data.endswith('coco.yaml') 386 | opt.data = check_file(opt.data) # check file 387 | print(opt) 388 | #check_requirements() 389 | 390 | if opt.task in ('train', 'val', 'test'): # run normally 391 | test(opt.data, 392 | opt.weights, 393 | opt.batch_size, 394 | opt.img_size, 395 | opt.conf_thres, 396 | opt.iou_thres, 397 | opt.save_json, 398 | opt.single_cls, 399 | opt.augment, 400 | opt.verbose, 401 | save_txt=opt.save_txt | opt.save_hybrid, 402 | save_hybrid=opt.save_hybrid, 403 | save_conf=opt.save_conf, 404 | trace=not opt.no_trace, 405 | v5_metric=opt.v5_metric 406 | ) 407 | 408 | elif opt.task == 'speed': # speed benchmarks 409 | for w in opt.weights: 410 | test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False, v5_metric=opt.v5_metric) 411 | 412 | elif opt.task == 'study': # run over a range of settings and save/plot 413 | # python test.py --task study --data coco.yaml --iou 0.65 --weights yolov7.pt 414 | x = list(range(256, 1536 + 128, 128)) # x axis (image sizes) 415 | for w in opt.weights: 416 | f = f'study_{Path(opt.data).stem}_{Path(w).stem}.txt' # filename to save to 417 | y = [] # y axis 418 | for i in x: # img-size 419 | print(f'\nRunning {f} point {i}...') 420 | r, _, t = test(opt.data, w, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json, 421 | plots=False, v5_metric=opt.v5_metric) 422 | y.append(r + t) # results and times 423 | np.savetxt(f, y, fmt='%10.4g') # save 424 | os.system('zip -r study.zip study_*.txt') 425 | plot_study_txt(x=x) # plot 426 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # init -------------------------------------------------------------------------------- /utils/activations.py: -------------------------------------------------------------------------------- 1 | # Activation functions 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | # SiLU https://arxiv.org/pdf/1606.08415.pdf ---------------------------------------------------------------------------- 9 | class SiLU(nn.Module): # export-friendly version of nn.SiLU() 10 | @staticmethod 11 | def forward(x): 12 | return x * 
torch.sigmoid(x) 13 | 14 | 15 | class Hardswish(nn.Module): # export-friendly version of nn.Hardswish() 16 | @staticmethod 17 | def forward(x): 18 | # return x * F.hardsigmoid(x) # for torchscript and CoreML 19 | return x * F.hardtanh(x + 3, 0., 6.) / 6. # for torchscript, CoreML and ONNX 20 | 21 | 22 | class MemoryEfficientSwish(nn.Module): 23 | class F(torch.autograd.Function): 24 | @staticmethod 25 | def forward(ctx, x): 26 | ctx.save_for_backward(x) 27 | return x * torch.sigmoid(x) 28 | 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | x = ctx.saved_tensors[0] 32 | sx = torch.sigmoid(x) 33 | return grad_output * (sx * (1 + x * (1 - sx))) 34 | 35 | def forward(self, x): 36 | return self.F.apply(x) 37 | 38 | 39 | # Mish https://github.com/digantamisra98/Mish -------------------------------------------------------------------------- 40 | class Mish(nn.Module): 41 | @staticmethod 42 | def forward(x): 43 | return x * F.softplus(x).tanh() 44 | 45 | 46 | class MemoryEfficientMish(nn.Module): 47 | class F(torch.autograd.Function): 48 | @staticmethod 49 | def forward(ctx, x): 50 | ctx.save_for_backward(x) 51 | return x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x))) 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | x = ctx.saved_tensors[0] 56 | sx = torch.sigmoid(x) 57 | fx = F.softplus(x).tanh() 58 | return grad_output * (fx + x * sx * (1 - fx * fx)) 59 | 60 | def forward(self, x): 61 | return self.F.apply(x) 62 | 63 | 64 | # FReLU https://arxiv.org/abs/2007.11824 ------------------------------------------------------------------------------- 65 | class FReLU(nn.Module): 66 | def __init__(self, c1, k=3): # ch_in, kernel 67 | super().__init__() 68 | self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1, bias=False) 69 | self.bn = nn.BatchNorm2d(c1) 70 | 71 | def forward(self, x): 72 | return torch.max(x, self.bn(self.conv(x))) 73 | -------------------------------------------------------------------------------- /utils/add_nms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | from onnx import shape_inference 4 | try: 5 | import onnx_graphsurgeon as gs 6 | except Exception as e: 7 | print('Import onnx_graphsurgeon failure: %s' % e) 8 | 9 | import logging 10 | 11 | LOGGER = logging.getLogger(__name__) 12 | 13 | class RegisterNMS(object): 14 | def __init__( 15 | self, 16 | onnx_model_path: str, 17 | precision: str = "fp32", 18 | ): 19 | 20 | self.graph = gs.import_onnx(onnx.load(onnx_model_path)) 21 | assert self.graph 22 | LOGGER.info("ONNX graph created successfully") 23 | # Fold constants via ONNX-GS that PyTorch2ONNX may have missed 24 | self.graph.fold_constants() 25 | self.precision = precision 26 | self.batch_size = 1 27 | def infer(self): 28 | """ 29 | Sanitize the graph by cleaning any unconnected nodes, do a topological resort, 30 | and fold constant inputs values. When possible, run shape inference on the 31 | ONNX graph to determine tensor shapes. 
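The passes are repeated up to three times; iteration stops early once the node count no longer changes, i.e. no further folding occurred.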
32 | """ 33 | for _ in range(3): 34 | count_before = len(self.graph.nodes) 35 | 36 | self.graph.cleanup().toposort() 37 | try: 38 | for node in self.graph.nodes: 39 | for o in node.outputs: 40 | o.shape = None 41 | model = gs.export_onnx(self.graph) 42 | model = shape_inference.infer_shapes(model) 43 | self.graph = gs.import_onnx(model) 44 | except Exception as e: 45 | LOGGER.info(f"Shape inference could not be performed at this time:\n{e}") 46 | try: 47 | self.graph.fold_constants(fold_shapes=True) 48 | except TypeError as e: 49 | LOGGER.error( 50 | "This version of ONNX GraphSurgeon does not support folding shapes, " 51 | f"please upgrade your onnx_graphsurgeon module. Error:\n{e}" 52 | ) 53 | raise 54 | 55 | count_after = len(self.graph.nodes) 56 | if count_before == count_after: 57 | # No new folding occurred in this iteration, so we can stop for now. 58 | break 59 | 60 | def save(self, output_path): 61 | """ 62 | Save the ONNX model to the given location. 63 | Args: 64 | output_path: Path pointing to the location where to write 65 | out the updated ONNX model. 66 | """ 67 | self.graph.cleanup().toposort() 68 | model = gs.export_onnx(self.graph) 69 | onnx.save(model, output_path) 70 | LOGGER.info(f"Saved ONNX model to {output_path}") 71 | 72 | def register_nms( 73 | self, 74 | *, 75 | score_thresh: float = 0.25, 76 | nms_thresh: float = 0.45, 77 | detections_per_img: int = 100, 78 | ): 79 | """ 80 | Register the ``EfficientNMS_TRT`` plugin node. 81 | NMS expects these shapes for its input tensors: 82 | - box_net: [batch_size, number_boxes, 4] 83 | - class_net: [batch_size, number_boxes, number_labels] 84 | Args: 85 | score_thresh (float): The scalar threshold for score (low scoring boxes are removed). 86 | nms_thresh (float): The scalar threshold for IOU (new boxes that have high IOU 87 | overlap with previously selected boxes are removed). 88 | detections_per_img (int): Number of best detections to keep after NMS. 89 | """ 90 | 91 | self.infer() 92 | # Find the concat node at the end of the network 93 | op_inputs = self.graph.outputs 94 | op = "EfficientNMS_TRT" 95 | attrs = { 96 | "plugin_version": "1", 97 | "background_class": -1, # no background class 98 | "max_output_boxes": detections_per_img, 99 | "score_threshold": score_thresh, 100 | "iou_threshold": nms_thresh, 101 | "score_activation": False, 102 | "box_coding": 0, 103 | } 104 | 105 | if self.precision == "fp32": 106 | dtype_output = np.float32 107 | elif self.precision == "fp16": 108 | dtype_output = np.float16 109 | else: 110 | raise NotImplementedError(f"Currently not supports precision: {self.precision}") 111 | 112 | # NMS Outputs 113 | output_num_detections = gs.Variable( 114 | name="num_dets", 115 | dtype=np.int32, 116 | shape=[self.batch_size, 1], 117 | ) # A scalar indicating the number of valid detections per batch image. 118 | output_boxes = gs.Variable( 119 | name="det_boxes", 120 | dtype=dtype_output, 121 | shape=[self.batch_size, detections_per_img, 4], 122 | ) 123 | output_scores = gs.Variable( 124 | name="det_scores", 125 | dtype=dtype_output, 126 | shape=[self.batch_size, detections_per_img], 127 | ) 128 | output_labels = gs.Variable( 129 | name="det_classes", 130 | dtype=np.int32, 131 | shape=[self.batch_size, detections_per_img], 132 | ) 133 | 134 | op_outputs = [output_num_detections, output_boxes, output_scores, output_labels] 135 | 136 | # Create the NMS Plugin node with the selected inputs. The outputs of the node will also 137 | # become the final outputs of the graph. 
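# Note: graph.layer() only appends the new node to self.graph.nodes; it is the reassignment of
# self.graph.outputs below that actually exposes num_dets/det_boxes/det_scores/det_classes as the
# final outputs of the exported graph.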
138 | self.graph.layer(op=op, name="batched_nms", inputs=op_inputs, outputs=op_outputs, attrs=attrs) 139 | LOGGER.info(f"Created NMS plugin '{op}' with attributes: {attrs}") 140 | 141 | self.graph.outputs = op_outputs 142 | 143 | self.infer() 144 | 145 | def save(self, output_path): 146 | """ 147 | Save the ONNX model to the given location. 148 | Args: 149 | output_path: Path pointing to the location where to write 150 | out the updated ONNX model. 151 | """ 152 | self.graph.cleanup().toposort() 153 | model = gs.export_onnx(self.graph) 154 | onnx.save(model, output_path) 155 | LOGGER.info(f"Saved ONNX model to {output_path}") 156 | -------------------------------------------------------------------------------- /utils/autoanchor.py: -------------------------------------------------------------------------------- 1 | # Auto-anchor utils 2 | 3 | import numpy as np 4 | import torch 5 | import yaml 6 | from scipy.cluster.vq import kmeans 7 | from tqdm import tqdm 8 | from EigenTheda import Eigen2x, xywhTheda2Eigen_numpy 9 | 10 | from utils.general import colorstr 11 | 12 | 13 | def check_anchor_order(m): 14 | # Check anchor order against stride order for YOLO Detect() module m, and correct if necessary 15 | a = m.anchor_grid.prod(-1).view(-1) # anchor area 16 | da = a[-1] - a[0] # delta a 17 | ds = m.stride[-1] - m.stride[0] # delta s 18 | if da.sign() != ds.sign(): # same order 19 | print('Reversing anchor order') 20 | m.anchors[:] = m.anchors.flip(0) 21 | m.anchor_grid[:] = m.anchor_grid.flip(0) 22 | 23 | 24 | def check_anchors(dataset, model, thr=4.0, imgsz=640): 25 | # Check anchor fit to data, recompute if necessary 26 | prefix = colorstr('autoanchor: ') 27 | print(f'\n{prefix}Analyzing anchors... ', end='') 28 | m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() 29 | shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) 30 | # scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale 31 | 32 | # labels: [cls, x, y, w, h, angle] 33 | transform_labels = np.concatenate(dataset.labels) 34 | transform_labels = transform_labels[..., 1:6] 35 | transform_labels = Eigen2x(xywhTheda2Eigen_numpy(transform_labels)) 36 | 37 | wh = torch.tensor(np.concatenate([l[2:5].reshape(1, -1) * imgsz for l in transform_labels])).float() # wh 38 | 39 | def metric(k): # compute metric 40 | r = wh[:, None] / k[None] 41 | x = torch.min(r, 1. / r).min(2)[0] # ratio metric 42 | best = x.max(1)[0] # best_x 43 | aat = (x > 1. / thr).float().sum(1).mean() # anchors above threshold 44 | bpr = (best > 1. / thr).float().mean() # best possible recall 45 | return bpr, aat 46 | 47 | anchors = m.anchor_grid.clone().cpu().view(-1, 3) # current anchors 48 | bpr, aat = metric(anchors) 49 | print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='') 50 | if bpr < 0.8: # threshold to recompute 51 | print('. 
Attempting to improve anchors, please wait...') 52 | na = m.anchor_grid.numel() // 2 # number of anchors 53 | try: 54 | anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False) 55 | except Exception as e: 56 | print(f'{prefix}ERROR: {e}') 57 | new_bpr = metric(anchors)[0] 58 | if new_bpr > bpr: # replace anchors 59 | anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors) 60 | m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference 61 | check_anchor_order(m) 62 | m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss 63 | print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.') 64 | else: 65 | print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.') 66 | print('') # newline 67 | 68 | 69 | def kmean_anchors(path='./data/coco.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True): 70 | """ Creates kmeans-evolved anchors from training dataset 71 | 72 | Arguments: 73 | path: path to dataset *.yaml, or a loaded dataset 74 | n: number of anchors 75 | img_size: image size used for training 76 | thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0 77 | gen: generations to evolve anchors using genetic algorithm 78 | verbose: print all results 79 | 80 | Return: 81 | k: kmeans evolved anchors 82 | 83 | Usage: 84 | from utils.autoanchor import *; _ = kmean_anchors() 85 | """ 86 | thr = 1. / thr 87 | prefix = colorstr('autoanchor: ') 88 | 89 | def metric(k, wh): # compute metrics 90 | r = wh[:, None] / k[None] 91 | x = torch.min(r, 1. / r).min(2)[0] # ratio metric 92 | # x = wh_iou(wh, torch.tensor(k)) # iou metric 93 | return x, x.max(1)[0] # x, best_x 94 | 95 | def anchor_fitness(k): # mutation fitness 96 | _, best = metric(torch.tensor(k, dtype=torch.float32), wh) 97 | return (best * (best > thr).float()).mean() # fitness 98 | 99 | def print_results(k): 100 | k = k[np.argsort(k.prod(1))] # sort small to large 101 | x, best = metric(k, wh0) 102 | bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr 103 | print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr') 104 | print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' 105 | f'past_thr={x[x > thr].mean():.3f}-mean: ', end='') 106 | for i, x in enumerate(k): 107 | print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg 108 | return k 109 | 110 | if isinstance(path, str): # *.yaml file 111 | with open(path) as f: 112 | data_dict = yaml.load(f, Loader=yaml.SafeLoader) # model dict 113 | from utils.datasets import LoadImagesAndLabels 114 | dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True) 115 | else: 116 | dataset = path # dataset 117 | 118 | # Get label wh 119 | shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True) 120 | wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh 121 | 122 | # Filter 123 | i = (wh0 < 3.0).any(1).sum() 124 | if i: 125 | print(f'{prefix}WARNING: Extremely small objects found. 
{i} of {len(wh0)} labels are < 3 pixels in size.') 126 | wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels 127 | # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1 128 | 129 | # Kmeans calculation 130 | print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...') 131 | s = wh.std(0) # sigmas for whitening 132 | k, dist = kmeans(wh / s, n, iter=30) # points, mean distance 133 | assert len(k) == n, print(f'{prefix}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}') 134 | k *= s 135 | wh = torch.tensor(wh, dtype=torch.float32) # filtered 136 | wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered 137 | k = print_results(k) 138 | 139 | # Plot 140 | # k, d = [None] * 20, [None] * 20 141 | # for i in tqdm(range(1, 21)): 142 | # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance 143 | # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True) 144 | # ax = ax.ravel() 145 | # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.') 146 | # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh 147 | # ax[0].hist(wh[wh[:, 0]<100, 0],400) 148 | # ax[1].hist(wh[wh[:, 1]<100, 1],400) 149 | # fig.savefig('wh.png', dpi=200) 150 | 151 | # Evolve 152 | npr = np.random 153 | f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma 154 | pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar 155 | for _ in pbar: 156 | v = np.ones(sh) 157 | while (v == 1).all(): # mutate until a change occurs (prevent duplicates) 158 | v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) 159 | kg = (k.copy() * v).clip(min=2.0) 160 | fg = anchor_fitness(kg) 161 | if fg > f: 162 | f, k = fg, kg.copy() 163 | pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}' 164 | if verbose: 165 | print_results(k) 166 | 167 | return print_results(k) 168 | -------------------------------------------------------------------------------- /utils/aws/__init__.py: -------------------------------------------------------------------------------- 1 | #init -------------------------------------------------------------------------------- /utils/aws/mime.sh: -------------------------------------------------------------------------------- 1 | # AWS EC2 instance startup 'MIME' script https://aws.amazon.com/premiumsupport/knowledge-center/execute-user-data-ec2/ 2 | # This script will run on every instance restart, not only on first start 3 | # --- DO NOT COPY ABOVE COMMENTS WHEN PASTING INTO USERDATA --- 4 | 5 | Content-Type: multipart/mixed; boundary="//" 6 | MIME-Version: 1.0 7 | 8 | --// 9 | Content-Type: text/cloud-config; charset="us-ascii" 10 | MIME-Version: 1.0 11 | Content-Transfer-Encoding: 7bit 12 | Content-Disposition: attachment; filename="cloud-config.txt" 13 | 14 | #cloud-config 15 | cloud_final_modules: 16 | - [scripts-user, always] 17 | 18 | --// 19 | Content-Type: text/x-shellscript; charset="us-ascii" 20 | MIME-Version: 1.0 21 | Content-Transfer-Encoding: 7bit 22 | Content-Disposition: attachment; filename="userdata.txt" 23 | 24 | #!/bin/bash 25 | # --- paste contents of userdata.sh here --- 26 | --// 27 | -------------------------------------------------------------------------------- /utils/aws/resume.py: -------------------------------------------------------------------------------- 1 | # Resume all interrupted trainings in yolor/ dir including DDP trainings 2 | # Usage: $ python utils/aws/resume.py 
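# The script walks the current working directory recursively for '*/**/last.pt' checkpoints, skips
# checkpoints whose optimizer state is None (runs that already finished), reads each run's opt.yaml to
# recover the device list, and relaunches training in the background: torch.distributed.launch for
# multi-GPU (DDP) runs, plain 'python train.py --resume <last.pt>' otherwise.
# A typical invocation from the repository root (the path shown is illustrative):
#   $ cd /path/to/RotaYolo && python utils/aws/resume.py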
3 | 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | import torch 9 | import yaml 10 | 11 | sys.path.append('./') # to run '$ python *.py' files in subdirectories 12 | 13 | port = 0 # --master_port 14 | path = Path('').resolve() 15 | for last in path.rglob('*/**/last.pt'): 16 | ckpt = torch.load(last) 17 | if ckpt['optimizer'] is None: 18 | continue 19 | 20 | # Load opt.yaml 21 | with open(last.parent.parent / 'opt.yaml') as f: 22 | opt = yaml.load(f, Loader=yaml.SafeLoader) 23 | 24 | # Get device count 25 | d = opt['device'].split(',') # devices 26 | nd = len(d) # number of devices 27 | ddp = nd > 1 or (nd == 0 and torch.cuda.device_count() > 1) # distributed data parallel 28 | 29 | if ddp: # multi-GPU 30 | port += 1 31 | cmd = f'python -m torch.distributed.launch --nproc_per_node {nd} --master_port {port} train.py --resume {last}' 32 | else: # single-GPU 33 | cmd = f'python train.py --resume {last}' 34 | 35 | cmd += ' > /dev/null 2>&1 &' # redirect output to dev/null and run in daemon thread 36 | print(cmd) 37 | os.system(cmd) 38 | -------------------------------------------------------------------------------- /utils/aws/userdata.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # AWS EC2 instance startup script https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/user-data.html 3 | # This script will run only once on first instance start (for a re-start script see mime.sh) 4 | # /home/ubuntu (ubuntu) or /home/ec2-user (amazon-linux) is working dir 5 | # Use >300 GB SSD 6 | 7 | cd home/ubuntu 8 | if [ ! -d yolor ]; then 9 | echo "Running first-time script." # install dependencies, download COCO, pull Docker 10 | git clone -b main https://github.com/WongKinYiu/yolov7 && sudo chmod -R 777 yolov7 11 | cd yolov7 12 | bash data/scripts/get_coco.sh && echo "Data done." & 13 | sudo docker pull nvcr.io/nvidia/pytorch:21.08-py3 && echo "Docker done." & 14 | python -m pip install --upgrade pip && pip install -r requirements.txt && python detect.py && echo "Requirements done." & 15 | wait && echo "All tasks done." # finish background tasks 16 | else 17 | echo "Running re-start script." # resume interrupted runs 18 | i=0 19 | list=$(sudo docker ps -qa) # container list i.e. 
$'one\ntwo\nthree\nfour' 20 | while IFS= read -r id; do 21 | ((i++)) 22 | echo "restarting container $i: $id" 23 | sudo docker start $id 24 | # sudo docker exec -it $id python train.py --resume # single-GPU 25 | sudo docker exec -d $id python utils/aws/resume.py # multi-scenario 26 | done <<<"$list" 27 | fi 28 | -------------------------------------------------------------------------------- /utils/google_utils.py: -------------------------------------------------------------------------------- 1 | # Google utils: https://cloud.google.com/storage/docs/reference/libraries 2 | 3 | import os 4 | import platform 5 | import subprocess 6 | import time 7 | from pathlib import Path 8 | 9 | import requests 10 | import torch 11 | 12 | 13 | def gsutil_getsize(url=''): 14 | # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du 15 | s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8') 16 | return eval(s.split(' ')[0]) if len(s) else 0 # bytes 17 | 18 | 19 | def attempt_download(file, repo='WongKinYiu/yolov7'): 20 | # Attempt file download if does not exist 21 | file = Path(str(file).strip().replace("'", '').lower()) 22 | 23 | if not file.exists(): 24 | try: 25 | response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api 26 | assets = [x['name'] for x in response['assets']] # release assets 27 | tag = response['tag_name'] # i.e. 'v1.0' 28 | except: # fallback plan 29 | assets = ['yolov7.pt', 'yolov7-tiny.pt', 'yolov7x.pt', 'yolov7-d6.pt', 'yolov7-e6.pt', 30 | 'yolov7-e6e.pt', 'yolov7-w6.pt'] 31 | tag = subprocess.check_output('git tag', shell=True).decode().split()[-1] 32 | 33 | name = file.name 34 | if name in assets: 35 | msg = f'{file} missing, try downloading from https://github.com/{repo}/releases/' 36 | redundant = False # second download option 37 | try: # GitHub 38 | url = f'https://github.com/{repo}/releases/download/{tag}/{name}' 39 | print(f'Downloading {url} to {file}...') 40 | torch.hub.download_url_to_file(url, file) 41 | assert file.exists() and file.stat().st_size > 1E6 # check 42 | except Exception as e: # GCP 43 | print(f'Download error: {e}') 44 | assert redundant, 'No secondary mirror' 45 | url = f'https://storage.googleapis.com/{repo}/ckpt/{name}' 46 | print(f'Downloading {url} to {file}...') 47 | os.system(f'curl -L {url} -o {file}') # torch.hub.download_url_to_file(url, weights) 48 | finally: 49 | if not file.exists() or file.stat().st_size < 1E6: # check 50 | file.unlink(missing_ok=True) # remove partial downloads 51 | print(f'ERROR: Download failure: {msg}') 52 | print('') 53 | return 54 | 55 | 56 | def gdrive_download(id='', file='tmp.zip'): 57 | # Downloads a file from Google Drive. from yolov7.utils.google_utils import *; gdrive_download() 58 | t = time.time() 59 | file = Path(file) 60 | cookie = Path('cookie') # gdrive cookie 61 | print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... 
', end='') 62 | file.unlink(missing_ok=True) # remove existing file 63 | cookie.unlink(missing_ok=True) # remove existing cookie 64 | 65 | # Attempt file download 66 | out = "NUL" if platform.system() == "Windows" else "/dev/null" 67 | os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}') 68 | if os.path.exists('cookie'): # large file 69 | s = f'curl -Lb ./cookie "drive.google.com/uc?export=download&confirm={get_token()}&id={id}" -o {file}' 70 | else: # small file 71 | s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"' 72 | r = os.system(s) # execute, capture return 73 | cookie.unlink(missing_ok=True) # remove existing cookie 74 | 75 | # Error check 76 | if r != 0: 77 | file.unlink(missing_ok=True) # remove partial 78 | print('Download error ') # raise Exception('Download error') 79 | return r 80 | 81 | # Unzip if archive 82 | if file.suffix == '.zip': 83 | print('unzipping... ', end='') 84 | os.system(f'unzip -q {file}') # unzip 85 | file.unlink() # remove zip to free space 86 | 87 | print(f'Done ({time.time() - t:.1f}s)') 88 | return r 89 | 90 | 91 | def get_token(cookie="./cookie"): 92 | with open(cookie) as f: 93 | for line in f: 94 | if "download" in line: 95 | return line.split()[-1] 96 | return "" 97 | 98 | # def upload_blob(bucket_name, source_file_name, destination_blob_name): 99 | # # Uploads a file to a bucket 100 | # # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python 101 | # 102 | # storage_client = storage.Client() 103 | # bucket = storage_client.get_bucket(bucket_name) 104 | # blob = bucket.blob(destination_blob_name) 105 | # 106 | # blob.upload_from_filename(source_file_name) 107 | # 108 | # print('File {} uploaded to {}.'.format( 109 | # source_file_name, 110 | # destination_blob_name)) 111 | # 112 | # 113 | # def download_blob(bucket_name, source_blob_name, destination_file_name): 114 | # # Uploads a blob from a bucket 115 | # storage_client = storage.Client() 116 | # bucket = storage_client.get_bucket(bucket_name) 117 | # blob = bucket.blob(source_blob_name) 118 | # 119 | # blob.download_to_filename(destination_file_name) 120 | # 121 | # print('Blob {} downloaded to {}.'.format( 122 | # source_blob_name, 123 | # destination_file_name)) 124 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Model validation metrics 2 | 3 | from pathlib import Path 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | 9 | from . import general 10 | 11 | 12 | def fitness(x): 13 | # Model fitness as a weighted combination of metrics 14 | w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95] 15 | return (x[:, :4] * w).sum(1) 16 | 17 | 18 | def ap_per_class(tp, conf, pred_cls, target_cls, v5_metric=False, plot=False, save_dir='.', names=()): 19 | """ Compute the average precision, given the recall and precision curves. 20 | Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. 21 | # Arguments 22 | tp: True positives (nparray, nx1 or nx10). 23 | conf: Objectness value from 0-1 (nparray). 24 | pred_cls: Predicted object classes (nparray). 25 | target_cls: True object classes (nparray). 26 | plot: Plot precision-recall curve at mAP@0.5 27 | save_dir: Plot save directory 28 | # Returns 29 | The average precision as computed in py-faster-rcnn. 
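        Note: when tp has one column per IoU threshold (e.g. 10 columns for mAP@0.5:0.95), the returned ap has shape (n_classes, n_thresholds).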
30 | """ 31 | 32 | # Sort by objectness 33 | i = np.argsort(-conf) 34 | tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] 35 | 36 | # Find unique classes 37 | unique_classes = np.unique(target_cls) 38 | nc = unique_classes.shape[0] # number of classes, number of detections 39 | 40 | # Create Precision-Recall curve and compute AP for each class 41 | px, py = np.linspace(0, 1, 1000), [] # for plotting 42 | ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) 43 | for ci, c in enumerate(unique_classes): 44 | i = pred_cls == c 45 | n_l = (target_cls == c).sum() # number of labels 46 | n_p = i.sum() # number of predictions 47 | 48 | if n_p == 0 or n_l == 0: 49 | continue 50 | else: 51 | # Accumulate FPs and TPs 52 | fpc = (1 - tp[i]).cumsum(0) 53 | tpc = tp[i].cumsum(0) 54 | 55 | # Recall 56 | recall = tpc / (n_l + 1e-16) # recall curve 57 | r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases 58 | 59 | # Precision 60 | precision = tpc / (tpc + fpc) # precision curve 61 | p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score 62 | 63 | # AP from recall-precision curve 64 | for j in range(tp.shape[1]): 65 | ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j], v5_metric=v5_metric) 66 | if plot and j == 0: 67 | py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5 68 | 69 | # Compute F1 (harmonic mean of precision and recall) 70 | f1 = 2 * p * r / (p + r + 1e-16) 71 | if plot: 72 | plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names) 73 | plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1') 74 | plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision') 75 | plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall') 76 | 77 | i = f1.mean(0).argmax() # max F1 index 78 | return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32') 79 | 80 | 81 | def compute_ap(recall, precision, v5_metric=False): 82 | """ Compute the average precision, given the recall and precision curves 83 | # Arguments 84 | recall: The recall curve (list) 85 | precision: The precision curve (list) 86 | v5_metric: Assume maximum recall to be 1.0, as in YOLOv5, MMDetetion etc. 87 | # Returns 88 | Average precision, precision curve, recall curve 89 | """ 90 | 91 | # Append sentinel values to beginning and end 92 | if v5_metric: # New YOLOv5 metric, same as MMDetection and Detectron2 repositories 93 | mrec = np.concatenate(([0.], recall, [1.0])) 94 | else: # Old YOLOv5 metric, i.e. 
default YOLOv7 metric 95 | mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01])) 96 | mpre = np.concatenate(([1.], precision, [0.])) 97 | 98 | # Compute the precision envelope 99 | mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) 100 | 101 | # Integrate area under curve 102 | method = 'interp' # methods: 'continuous', 'interp' 103 | if method == 'interp': 104 | x = np.linspace(0, 1, 101) # 101-point interp (COCO) 105 | ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate 106 | else: # 'continuous' 107 | i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes 108 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve 109 | 110 | return ap, mpre, mrec 111 | 112 | 113 | class ConfusionMatrix: 114 | # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix 115 | def __init__(self, nc, conf=0.25, iou_thres=0.45): 116 | self.matrix = np.zeros((nc + 1, nc + 1)) 117 | self.nc = nc # number of classes 118 | self.conf = conf 119 | self.iou_thres = iou_thres 120 | 121 | def process_batch(self, detections, labels): 122 | """ 123 | Return intersection-over-union (Jaccard index) of boxes. 124 | Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 125 | Arguments: 126 | detections (Array[N, 6]), x1, y1, x2, y2, conf, class 127 | labels (Array[M, 5]), class, x1, y1, x2, y2 128 | Returns: 129 | None, updates confusion matrix accordingly 130 | """ 131 | detections = detections[detections[:, 4] > self.conf] 132 | gt_classes = labels[:, 0].int() 133 | detection_classes = detections[:, 5].int() 134 | iou = general.box_iou(labels[:, 1:], detections[:, :4]) 135 | 136 | x = torch.where(iou > self.iou_thres) 137 | if x[0].shape[0]: 138 | matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() 139 | if x[0].shape[0] > 1: 140 | matches = matches[matches[:, 2].argsort()[::-1]] 141 | matches = matches[np.unique(matches[:, 1], return_index=True)[1]] 142 | matches = matches[matches[:, 2].argsort()[::-1]] 143 | matches = matches[np.unique(matches[:, 0], return_index=True)[1]] 144 | else: 145 | matches = np.zeros((0, 3)) 146 | 147 | n = matches.shape[0] > 0 148 | m0, m1, _ = matches.transpose().astype(np.int16) 149 | for i, gc in enumerate(gt_classes): 150 | j = m0 == i 151 | if n and sum(j) == 1: 152 | self.matrix[gc, detection_classes[m1[j]]] += 1 # correct 153 | else: 154 | self.matrix[self.nc, gc] += 1 # background FP 155 | 156 | if n: 157 | for i, dc in enumerate(detection_classes): 158 | if not any(m1 == i): 159 | self.matrix[dc, self.nc] += 1 # background FN 160 | 161 | def matrix(self): 162 | return self.matrix 163 | 164 | def plot(self, save_dir='', names=()): 165 | try: 166 | import seaborn as sn 167 | 168 | array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize 169 | array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) 170 | 171 | fig = plt.figure(figsize=(12, 9), tight_layout=True) 172 | sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size 173 | labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels 174 | sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True, 175 | xticklabels=names + ['background FP'] if labels else "auto", 176 | yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1)) 177 | fig.axes[0].set_xlabel('True') 178 | fig.axes[0].set_ylabel('Predicted') 179 | fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) 180 | 
except Exception as e: 181 | pass 182 | 183 | def print(self): 184 | for i in range(self.nc + 1): 185 | print(' '.join(map(str, self.matrix[i]))) 186 | 187 | 188 | # Plots ---------------------------------------------------------------------------------------------------------------- 189 | 190 | def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()): 191 | # Precision-recall curve 192 | fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) 193 | py = np.stack(py, axis=1) 194 | 195 | if 0 < len(names) < 21: # display per-class legend if < 21 classes 196 | for i, y in enumerate(py.T): 197 | ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision) 198 | else: 199 | ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision) 200 | 201 | ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean()) 202 | ax.set_xlabel('Recall') 203 | ax.set_ylabel('Precision') 204 | ax.set_xlim(0, 1) 205 | ax.set_ylim(0, 1) 206 | plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") 207 | fig.savefig(Path(save_dir), dpi=250) 208 | 209 | 210 | def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'): 211 | # Metric-confidence curve 212 | fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) 213 | 214 | if 0 < len(names) < 21: # display per-class legend if < 21 classes 215 | for i, y in enumerate(py): 216 | ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric) 217 | else: 218 | ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric) 219 | 220 | y = py.mean(0) 221 | ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}') 222 | ax.set_xlabel(xlabel) 223 | ax.set_ylabel(ylabel) 224 | ax.set_xlim(0, 1) 225 | ax.set_ylim(0, 1) 226 | plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") 227 | fig.savefig(Path(save_dir), dpi=250) 228 | -------------------------------------------------------------------------------- /utils/plots.py: -------------------------------------------------------------------------------- 1 | # Plotting utils 2 | 3 | import glob 4 | import math 5 | import os 6 | import random 7 | from copy import copy 8 | from pathlib import Path 9 | 10 | import cv2 11 | import matplotlib 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import pandas as pd 15 | import seaborn as sns 16 | import torch 17 | import yaml 18 | from PIL import Image, ImageDraw, ImageFont 19 | from scipy.signal import butter, filtfilt 20 | 21 | from utils.general import xywh2xyxy, xyxy2xywh 22 | from utils.metrics import fitness 23 | 24 | # Settings 25 | matplotlib.rc('font', **{'size': 11}) 26 | matplotlib.use('Agg') # for writing to files only 27 | 28 | 29 | def color_list(): 30 | # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb 31 | def hex2rgb(h): 32 | return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) 33 | 34 | return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949) 35 | 36 | 37 | def hist2d(x, y, n=100): 38 | # 2d histogram used in labels.png and evolve.png 39 | xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n) 40 | hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges)) 41 | xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1) 42 | yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1) 
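    # each sample (x[i], y[i]) now indexes its 2-D bin; the log of that bin's count is returned below as a per-point density value (used to colour the scatter points in plot_evolution)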
43 | return np.log(hist[xidx, yidx]) 44 | 45 | 46 | def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5): 47 | # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy 48 | def butter_lowpass(cutoff, fs, order): 49 | nyq = 0.5 * fs 50 | normal_cutoff = cutoff / nyq 51 | return butter(order, normal_cutoff, btype='low', analog=False) 52 | 53 | b, a = butter_lowpass(cutoff, fs, order=order) 54 | return filtfilt(b, a, data) # forward-backward filter 55 | 56 | def xyxy2xywh_OneShape(x): 57 | # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right 58 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 59 | y[0] = (x[0] + x[2]) / 2 # x center 60 | y[1] = (x[1] + x[3]) / 2 # y center 61 | y[2] = x[2] - x[0] # width 62 | y[3] = x[3] - x[1] # height 63 | return y 64 | 65 | def rotation_boxes(old_boxes, new_boxes, angle): # old_boxes: [x1 y1 x2 y2], new_boxes: [x y w h] 66 | angle = (180 - angle) / 180 * math.pi 67 | 68 | blx = old_boxes[0] 69 | bly = old_boxes[1] 70 | brx = old_boxes[2] 71 | bry = old_boxes[3] 72 | cx = new_boxes[0] 73 | cy = new_boxes[1] 74 | 75 | X1 = (blx - cx) * math.cos(angle) - (cy - bly) * math.sin(angle) + cx 76 | Y1 = -((blx - cx) * math.sin(angle) + (cy - bly) * math.cos(angle)) + cy 77 | X2 = (brx - cx) * math.cos(angle) - (cy - bly) * math.sin(angle) + cx 78 | Y2 = -((brx - cx) * math.sin(angle) + (cy - bly) * math.cos(angle)) + cy 79 | X3 = (blx - cx) * math.cos(angle) - (cy - bry) * math.sin(angle) + cx 80 | Y3 = -((blx - cx) * math.sin(angle) + (cy - bry) * math.cos(angle)) + cy 81 | X4 = (brx - cx) * math.cos(angle) - (cy - bry) * math.sin(angle) + cx 82 | Y4 = -((brx - cx) * math.sin(angle) + (cy - bry) * math.cos(angle)) + cy 83 | 84 | return (int(X1), int(Y1)), (int(X2), int(Y2)), (int(X3), int(Y3)), (int(X4), int(Y4)) 85 | 86 | def plot_one_box(x, angle, img, color=None, label=None, line_thickness=1): 87 | # Plots one bounding box on image img 88 | tl = 2 89 | color = color or [random.randint(0, 255) for _ in range(3)] 90 | 91 | new_boxes = xyxy2xywh_OneShape(x) # xyxy → xywh 92 | P1, P2, P3, P4 = rotation_boxes(x, new_boxes, angle) 93 | cv2.line(img, P1, P2, color, thickness=tl, lineType=cv2.LINE_AA) 94 | cv2.line(img, P2, P4, color, thickness=tl, lineType=cv2.LINE_AA) 95 | cv2.line(img, P4, P3, color, thickness=tl, lineType=cv2.LINE_AA) 96 | cv2.line(img, P3, P1, color, thickness=tl, lineType=cv2.LINE_AA) 97 | 98 | if label: 99 | tf = 1 # font thickness 100 | cv2.putText(img, label, (P1[0], P1[1] - 2), 0, 0.5, [0, 0, 255], thickness=tf, lineType=cv2.LINE_AA) 101 | 102 | def plot_one_box_PIL(box, img, color=None, label=None, line_thickness=None): 103 | img = Image.fromarray(img) 104 | draw = ImageDraw.Draw(img) 105 | line_thickness = line_thickness or max(int(min(img.size) / 200), 2) 106 | draw.rectangle(box, width=line_thickness, outline=tuple(color)) # plot 107 | if label: 108 | fontsize = max(round(max(img.size) / 40), 12) 109 | font = ImageFont.truetype("Arial.ttf", fontsize) 110 | txt_width, txt_height = font.getsize(label) 111 | draw.rectangle([box[0], box[1] - txt_height + 4, box[0] + txt_width, box[1]], fill=tuple(color)) 112 | draw.text((box[0], box[1] - txt_height + 1), label, fill=(255, 255, 255), font=font) 113 | return np.asarray(img) 114 | 115 | 116 | def plot_wh_methods(): # from utils.plots import *; plot_wh_methods() 117 | # Compares the two methods for width-height anchor multiplication 118 | # https://github.com/ultralytics/yolov3/issues/168 119 | x = 
np.arange(-4.0, 4.0, .1) 120 | ya = np.exp(x) 121 | yb = torch.sigmoid(torch.from_numpy(x)).numpy() * 2 122 | 123 | fig = plt.figure(figsize=(6, 3), tight_layout=True) 124 | plt.plot(x, ya, '.-', label='YOLOv3') 125 | plt.plot(x, yb ** 2, '.-', label='YOLOR ^2') 126 | plt.plot(x, yb ** 1.6, '.-', label='YOLOR ^1.6') 127 | plt.xlim(left=-4, right=4) 128 | plt.ylim(bottom=0, top=6) 129 | plt.xlabel('input') 130 | plt.ylabel('output') 131 | plt.grid() 132 | plt.legend() 133 | fig.savefig('comparison.png', dpi=200) 134 | 135 | 136 | def output_to_target(output, width, height): 137 | "out: x y w h(0-1) angle[0, 180) conf cls" 138 | # Convert model output to target format [batch_id, class_id, x, y, w, h, conf, angle] 139 | targets = [] 140 | for i, o in enumerate(output): 141 | for x, y, w, h, angle, conf, cls in o.cpu().numpy(): 142 | targets.append([i, cls, x, y, w, h, conf, int(angle)]) 143 | return np.array(targets) 144 | 145 | 146 | def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16): 147 | # Plot image grid with labels 148 | " " 149 | "out: [img_index, class_id, x, y, w, h(0-1), conf, angle[0, 180)] [n, 8]" 150 | "targets: [img_index, clsid cx cy l s(0-1) theta]) θ[0, 180) [n, 7]" 151 | 152 | if isinstance(images, torch.Tensor): 153 | images = images.cpu().float().numpy() 154 | if isinstance(targets, torch.Tensor): 155 | targets = targets.cpu().numpy() 156 | 157 | # un-normalise 158 | if np.max(images[0]) <= 1: 159 | images *= 255 160 | 161 | tl = 1 # line thickness 162 | tf = max(tl - 1, 1) # font thickness 163 | bs, _, h, w = images.shape # batch size, channel, height, width 164 | bs = min(bs, max_subplots) # limit plot images 165 | ns = np.ceil(bs ** 0.5) # number of subplots (square) 166 | 167 | # Check if we should resize 168 | max_size = max(h, w) 169 | scale_factor = max_size / max(h, w) 170 | if scale_factor < 1: 171 | h = math.ceil(scale_factor * h) 172 | w = math.ceil(scale_factor * w) 173 | 174 | colors = color_list() # list of colors 175 | mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init 176 | for i, img in enumerate(images): # img (3, 512, 512) 177 | if i == max_subplots: # if last batch has fewer images than we expect 178 | break 179 | 180 | block_x = int(w * (i // ns)) 181 | block_y = int(h * (i % ns)) 182 | 183 | img = img.transpose(1, 2, 0) # (512, 512, 3) h w c 184 | if scale_factor < 1: 185 | img = cv2.resize(img, (w, h)) 186 | 187 | mosaic[block_y:block_y + h, block_x:block_x + w, :] = img 188 | if len(targets) > 0: # [img_index, cls, x y w h conf angle] 189 | image_targets = targets[targets[:, 0] == i] 190 | boxes = xywh2xyxy(image_targets[:, 2:6]).T 191 | classes = image_targets[:, 1].astype('int') 192 | angles = image_targets[:, -1].astype('int') 193 | labels = image_targets.shape[1] == 8 # labels if no conf column 194 | conf = None if not labels else image_targets[:, 6] # check for confidence presence (label vs pred) 195 | 196 | if boxes.shape[1]: 197 | if boxes.max() <= 1.9: # if normalized with tolerance 0.01 198 | boxes[[0, 2]] *= w # scale to pixels 199 | boxes[[1, 3]] *= h 200 | elif scale_factor < 1: # absolute coords need scale if image scales 201 | boxes *= scale_factor 202 | boxes[[0, 2]] += block_x 203 | boxes[[1, 3]] += block_y 204 | 205 | for j, box in enumerate(boxes.T): 206 | cls = int(classes[j]) 207 | angl = int(angles[j]) 208 | color = colors[cls % len(colors)] 209 | # cls = names[cls] if names else cls 210 | 211 | if not labels or conf[j] > 0.25: # 0.25 conf thresh 
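                    # note: in this code path `labels` is True for prediction rows (8 columns, including conf and angle) and False for 7-column ground-truth rows, so ground-truth boxes are always drawn and predictions only when conf > 0.25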
212 | label = '%s %s' % (cls, angl) if not labels else '%s %s %.1f' % (cls, angl, conf[j]) 213 | plot_one_box(box, angl, mosaic, label=label, color=color) 214 | 215 | # Draw image filename labels 216 | if paths: 217 | label = Path(paths[i]).name[:40] # trim to 40 char 218 | t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0] 219 | cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, 0.5, [220, 220, 220], thickness=tf, 220 | lineType=cv2.LINE_AA) 221 | 222 | # Image border 223 | cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3) 224 | 225 | if fname: 226 | # r = min((max_size*2) / max(h, w) / ns, 1.0) # ratio to limit image size 227 | # mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA) 228 | # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save 229 | Image.fromarray(mosaic).save(fname) # PIL save 230 | return mosaic 231 | 232 | 233 | def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''): 234 | # Plot LR simulating training for full epochs 235 | optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals 236 | y = [] 237 | for _ in range(epochs): 238 | scheduler.step() 239 | y.append(optimizer.param_groups[0]['lr']) 240 | plt.plot(y, '.-', label='LR') 241 | plt.xlabel('epoch') 242 | plt.ylabel('LR') 243 | plt.grid() 244 | plt.xlim(0, epochs) 245 | plt.ylim(0) 246 | plt.savefig(Path(save_dir) / 'LR.png', dpi=200) 247 | plt.close() 248 | 249 | 250 | def plot_test_txt(): # from utils.plots import *; plot_test() 251 | # Plot test.txt histograms 252 | x = np.loadtxt('test.txt', dtype=np.float32) 253 | box = xyxy2xywh(x[:, :4]) 254 | cx, cy = box[:, 0], box[:, 1] 255 | 256 | fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True) 257 | ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0) 258 | ax.set_aspect('equal') 259 | plt.savefig('hist2d.png', dpi=300) 260 | 261 | fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True) 262 | ax[0].hist(cx, bins=600) 263 | ax[1].hist(cy, bins=600) 264 | plt.savefig('hist1d.png', dpi=200) 265 | 266 | 267 | def plot_targets_txt(): # from utils.plots import *; plot_targets_txt() 268 | # Plot targets.txt histograms 269 | x = np.loadtxt('targets.txt', dtype=np.float32).T 270 | s = ['x targets', 'y targets', 'width targets', 'height targets'] 271 | fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True) 272 | ax = ax.ravel() 273 | for i in range(4): 274 | ax[i].hist(x[i], bins=100, label='%.3g +/- %.3g' % (x[i].mean(), x[i].std())) 275 | ax[i].legend() 276 | ax[i].set_title(s[i]) 277 | plt.savefig('targets.jpg', dpi=200) 278 | 279 | 280 | def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt() 281 | # Plot study.txt generated by test.py 282 | fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True) 283 | # ax = ax.ravel() 284 | 285 | fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True) 286 | # for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolor-p6', 'yolor-w6', 'yolor-e6', 'yolor-d6']]: 287 | for f in sorted(Path(path).glob('study*.txt')): 288 | y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T 289 | x = np.arange(y.shape[1]) if x is None else np.array(x) 290 | s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)'] 291 | # for i in range(7): 292 | # ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8) 293 | # ax[i].set_title(s[i]) 294 | 295 
| j = y[3].argmax() + 1 296 | ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8, 297 | label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO')) 298 | 299 | ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5], 300 | 'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet') 301 | 302 | ax2.grid(alpha=0.2) 303 | ax2.set_yticks(np.arange(20, 60, 5)) 304 | ax2.set_xlim(0, 57) 305 | ax2.set_ylim(30, 55) 306 | ax2.set_xlabel('GPU Speed (ms/img)') 307 | ax2.set_ylabel('COCO AP val') 308 | ax2.legend(loc='lower right') 309 | plt.savefig(str(Path(path).name) + '.png', dpi=300) 310 | 311 | 312 | def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): 313 | # plot dataset labels 314 | print('Plotting labels... ') 315 | c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes 316 | nc = int(c.max() + 1) # number of classes 317 | colors = color_list() 318 | x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height']) 319 | 320 | # seaborn correlogram 321 | sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) 322 | plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) 323 | plt.close() 324 | 325 | # matplotlib labels 326 | matplotlib.use('svg') # faster 327 | ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() 328 | ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) 329 | ax[0].set_ylabel('instances') 330 | if 0 < len(names) < 30: 331 | ax[0].set_xticks(range(len(names))) 332 | ax[0].set_xticklabels(names, rotation=90, fontsize=10) 333 | else: 334 | ax[0].set_xlabel('classes') 335 | sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) 336 | sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) 337 | 338 | # rectangles 339 | labels[:, 1:3] = 0.5 # center 340 | labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000 341 | img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255) 342 | for cls, *box in labels[:1000]: 343 | ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot 344 | ax[1].imshow(img) 345 | ax[1].axis('off') 346 | 347 | for a in [0, 1, 2, 3]: 348 | for s in ['top', 'right', 'left', 'bottom']: 349 | ax[a].spines[s].set_visible(False) 350 | 351 | plt.savefig(save_dir / 'labels.jpg', dpi=200) 352 | matplotlib.use('Agg') 353 | plt.close() 354 | 355 | # loggers 356 | for k, v in loggers.items() or {}: 357 | if k == 'wandb' and v: 358 | v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.jpg')]}, commit=False) 359 | 360 | 361 | def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() 362 | # Plot hyperparameter evolution results in evolve.txt 363 | with open(yaml_file) as f: 364 | hyp = yaml.load(f, Loader=yaml.SafeLoader) 365 | x = np.loadtxt('evolve.txt', ndmin=2) 366 | f = fitness(x) 367 | # weights = (f - f.min()) ** 2 # for weighted results 368 | plt.figure(figsize=(10, 12), tight_layout=True) 369 | matplotlib.rc('font', **{'size': 8}) 370 | for i, (k, v) in enumerate(hyp.items()): 371 | y = x[:, i + 7] 372 | # mu = (y * weights).sum() / weights.sum() # best weighted result 373 | mu = y[f.argmax()] # best single result 374 | plt.subplot(6, 5, i + 1) 375 | plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none') 376 | plt.plot(mu, f.max(), 'k+', markersize=15) 377 | plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters 378 | 
if i % 5 != 0: 379 | plt.yticks([]) 380 | print('%15s: %.3g' % (k, mu)) 381 | plt.savefig('evolve.png', dpi=200) 382 | print('\nPlot saved as evolve.png') 383 | 384 | 385 | def profile_idetection(start=0, stop=0, labels=(), save_dir=''): 386 | # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection() 387 | ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel() 388 | s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS'] 389 | files = list(Path(save_dir).glob('frames*.txt')) 390 | for fi, f in enumerate(files): 391 | try: 392 | results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows 393 | n = results.shape[1] # number of rows 394 | x = np.arange(start, min(stop, n) if stop else n) 395 | results = results[:, x] 396 | t = (results[0] - results[0].min()) # set t0=0s 397 | results[0] = x 398 | for i, a in enumerate(ax): 399 | if i < len(results): 400 | label = labels[fi] if len(labels) else f.stem.replace('frames_', '') 401 | a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5) 402 | a.set_title(s[i]) 403 | a.set_xlabel('time (s)') 404 | # if fi == len(files) - 1: 405 | # a.set_ylim(bottom=0) 406 | for side in ['top', 'right']: 407 | a.spines[side].set_visible(False) 408 | else: 409 | a.remove() 410 | except Exception as e: 411 | print('Warning: Plotting error for %s; %s' % (f, e)) 412 | 413 | ax[1].legend() 414 | plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200) 415 | 416 | 417 | def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay() 418 | # Plot training 'results*.txt', overlaying train and val losses 419 | s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95'] # legends 420 | t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles 421 | for f in sorted(glob.glob('results*.txt') + glob.glob('../../Downloads/results*.txt')): 422 | results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T 423 | n = results.shape[1] # number of rows 424 | x = range(start, min(stop, n) if stop else n) 425 | fig, ax = plt.subplots(1, 5, figsize=(14, 3.5), tight_layout=True) 426 | ax = ax.ravel() 427 | for i in range(5): 428 | for j in [i, i + 5]: 429 | y = results[j, x] 430 | ax[i].plot(x, y, marker='.', label=s[j]) 431 | # y_smooth = butter_lowpass_filtfilt(y) 432 | # ax[i].plot(x, np.gradient(y_smooth), marker='.', label=s[j]) 433 | 434 | ax[i].set_title(t[i]) 435 | ax[i].legend() 436 | ax[i].set_ylabel(f) if i == 0 else None # add filename 437 | fig.savefig(f.replace('.txt', '.png'), dpi=200) 438 | 439 | 440 | def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): 441 | # Plot training 'results*.txt'. 
from utils.plots import *; plot_results(save_dir='runs/train/exp') 442 | fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) 443 | ax = ax.ravel() 444 | s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall', 445 | 'val Box', 'val Objectness', 'val Classification', 'mAP@0.5', 'mAP@0.5:0.95'] 446 | if bucket: 447 | # files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id] 448 | files = ['results%g.txt' % x for x in id] 449 | c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id) 450 | os.system(c) 451 | else: 452 | files = list(Path(save_dir).glob('results*.txt')) 453 | assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir) 454 | for fi, f in enumerate(files): 455 | try: 456 | results = np.loadtxt(f, usecols=[2, 3, 4, 8, 9, 12, 13, 14, 10, 11], ndmin=2).T 457 | n = results.shape[1] # number of rows 458 | x = range(start, min(stop, n) if stop else n) 459 | for i in range(10): 460 | y = results[i, x] 461 | if i in [0, 1, 2, 5, 6, 7]: 462 | y[y == 0] = np.nan # don't show zero loss values 463 | # y /= y[0] # normalize 464 | label = labels[fi] if len(labels) else f.stem 465 | ax[i].plot(x, y, marker='.', label=label, linewidth=2, markersize=8) 466 | ax[i].set_title(s[i]) 467 | # if i in [5, 6, 7]: # share train and val loss y axes 468 | # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) 469 | except Exception as e: 470 | print('Warning: Plotting error for %s; %s' % (f, e)) 471 | 472 | ax[1].legend() 473 | fig.savefig(Path(save_dir) / 'results.png', dpi=200) 474 | 475 | 476 | def output_to_keypoint(output): 477 | # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] 478 | targets = [] 479 | for i, o in enumerate(output): 480 | kpts = o[:,6:] 481 | o = o[:,:6] 482 | for index, (*box, conf, cls) in enumerate(o.detach().cpu().numpy()): 483 | targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf, *list(kpts.detach().cpu().numpy()[index])]) 484 | return np.array(targets) 485 | 486 | 487 | def plot_skeleton_kpts(im, kpts, steps, orig_shape=None): 488 | #Plot the skeleton and keypointsfor coco datatset 489 | palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102], 490 | [230, 230, 0], [255, 153, 255], [153, 204, 255], 491 | [255, 102, 255], [255, 51, 255], [102, 178, 255], 492 | [51, 153, 255], [255, 153, 153], [255, 102, 102], 493 | [255, 51, 51], [153, 255, 153], [102, 255, 102], 494 | [51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0], 495 | [255, 255, 255]]) 496 | 497 | skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], 498 | [7, 13], [6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], 499 | [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]] 500 | 501 | pose_limb_color = palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]] 502 | pose_kpt_color = palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]] 503 | radius = 5 504 | num_kpts = len(kpts) // steps 505 | 506 | for kid in range(num_kpts): 507 | r, g, b = pose_kpt_color[kid] 508 | x_coord, y_coord = kpts[steps * kid], kpts[steps * kid + 1] 509 | if not (x_coord % 640 == 0 or y_coord % 640 == 0): 510 | if steps == 3: 511 | conf = kpts[steps * kid + 2] 512 | if conf < 0.5: 513 | continue 514 | cv2.circle(im, (int(x_coord), int(y_coord)), radius, (int(r), int(g), int(b)), -1) 515 | 516 | for sk_id, sk in enumerate(skeleton): 517 | r, g, b = pose_limb_color[sk_id] 518 | pos1 = (int(kpts[(sk[0]-1)*steps]), 
int(kpts[(sk[0]-1)*steps+1])) 519 | pos2 = (int(kpts[(sk[1]-1)*steps]), int(kpts[(sk[1]-1)*steps+1])) 520 | if steps == 3: 521 | conf1 = kpts[(sk[0]-1)*steps+2] 522 | conf2 = kpts[(sk[1]-1)*steps+2] 523 | if conf1<0.5 or conf2<0.5: 524 | continue 525 | if pos1[0]%640 == 0 or pos1[1]%640==0 or pos1[0]<0 or pos1[1]<0: 526 | continue 527 | if pos2[0] % 640 == 0 or pos2[1] % 640 == 0 or pos2[0]<0 or pos2[1]<0: 528 | continue 529 | cv2.line(im, pos1, pos2, (int(r), int(g), int(b)), thickness=2) 530 | -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | # YOLOR PyTorch utils 2 | 3 | import datetime 4 | import logging 5 | import math 6 | import os 7 | import platform 8 | import subprocess 9 | import time 10 | from contextlib import contextmanager 11 | from copy import deepcopy 12 | from pathlib import Path 13 | 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torchvision 19 | 20 | try: 21 | import thop # for FLOPS computation 22 | except ImportError: 23 | thop = None 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | @contextmanager 28 | def torch_distributed_zero_first(local_rank: int): 29 | """ 30 | Decorator to make all processes in distributed training wait for each local_master to do something. 31 | """ 32 | if local_rank not in [-1, 0]: 33 | torch.distributed.barrier() 34 | yield 35 | if local_rank == 0: 36 | torch.distributed.barrier() 37 | 38 | 39 | def init_torch_seeds(seed=0): 40 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 41 | torch.manual_seed(seed) 42 | if seed == 0: # slower, more reproducible 43 | cudnn.benchmark, cudnn.deterministic = False, True 44 | else: # faster, less reproducible 45 | cudnn.benchmark, cudnn.deterministic = True, False 46 | 47 | 48 | def date_modified(path=__file__): 49 | # return human-readable file modification date, i.e. '2021-3-26' 50 | t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime) 51 | return f'{t.year}-{t.month}-{t.day}' 52 | 53 | 54 | def git_describe(path=Path(__file__).parent): # path must be a directory 55 | # return human-readable git description, i.e. 
v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe 56 | s = f'git -C {path} describe --tags --long --always' 57 | try: 58 | return subprocess.check_output(s, shell=True, stderr=subprocess.STDOUT).decode()[:-1] 59 | except subprocess.CalledProcessError as e: 60 | return '' # not a git repository 61 | 62 | 63 | def select_device(device='', batch_size=None): 64 | # device = 'cpu' or '0' or '0,1,2,3' 65 | s = f'YOLOR 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string 66 | cpu = device.lower() == 'cpu' 67 | if cpu: 68 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False 69 | elif device: # non-cpu device requested 70 | os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable 71 | assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested' # check availability 72 | 73 | cuda = not cpu and torch.cuda.is_available() 74 | if cuda: 75 | n = torch.cuda.device_count() 76 | # if n > 1 and batch_size: # check that batch_size is compatible with device_count 77 | # assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}' 78 | space = ' ' * len(s) 79 | for i, d in enumerate(device.split(',') if device else range(n)): 80 | p = torch.cuda.get_device_properties(i) 81 | s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / 1024 ** 2}MB)\n" # bytes to MB 82 | else: 83 | s += 'CPU\n' 84 | 85 | logger.info(s.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else s) # emoji-safe 86 | 87 | return torch.device('cuda:0' if cuda else 'cpu') 88 | 89 | 90 | def time_synchronized(): 91 | # pytorch-accurate time 92 | if torch.cuda.is_available(): 93 | torch.cuda.synchronize() 94 | return time.time() 95 | 96 | 97 | def profile(x, ops, n=100, device=None): 98 | # profile a pytorch module or list of modules. Example usage: 99 | # x = torch.randn(16, 3, 640, 640) # input 100 | # m1 = lambda x: x * torch.sigmoid(x) 101 | # m2 = nn.SiLU() 102 | # profile(x, [m1, m2], n=100) # profile speed over 100 iterations 103 | 104 | device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 105 | x = x.to(device) 106 | x.requires_grad = True 107 | print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '') 108 | print(f"\n{'Params':>12s}{'GFLOPS':>12s}{'forward (ms)':>16s}{'backward (ms)':>16s}{'input':>24s}{'output':>24s}") 109 | for m in ops if isinstance(ops, list) else [ops]: 110 | m = m.to(device) if hasattr(m, 'to') else m # device 111 | m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m # type 112 | dtf, dtb, t = 0., 0., [0., 0., 0.] 
# dt forward, backward 113 | try: 114 | flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPS 115 | except: 116 | flops = 0 117 | 118 | for _ in range(n): 119 | t[0] = time_synchronized() 120 | y = m(x) 121 | t[1] = time_synchronized() 122 | try: 123 | _ = y.sum().backward() 124 | t[2] = time_synchronized() 125 | except: # no backward method 126 | t[2] = float('nan') 127 | dtf += (t[1] - t[0]) * 1000 / n # ms per op forward 128 | dtb += (t[2] - t[1]) * 1000 / n # ms per op backward 129 | 130 | s_in = tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' 131 | s_out = tuple(y.shape) if isinstance(y, torch.Tensor) else 'list' 132 | p = sum(list(x.numel() for x in m.parameters())) if isinstance(m, nn.Module) else 0 # parameters 133 | print(f'{p:12}{flops:12.4g}{dtf:16.4g}{dtb:16.4g}{str(s_in):>24s}{str(s_out):>24s}') 134 | 135 | 136 | def is_parallel(model): 137 | return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) 138 | 139 | 140 | def intersect_dicts(da, db, exclude=()): 141 | # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values 142 | return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} 143 | 144 | 145 | def initialize_weights(model): 146 | for m in model.modules(): 147 | t = type(m) 148 | if t is nn.Conv2d: 149 | pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 150 | elif t is nn.BatchNorm2d: 151 | m.eps = 1e-3 152 | m.momentum = 0.03 153 | elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]: 154 | m.inplace = True 155 | 156 | 157 | def find_modules(model, mclass=nn.Conv2d): 158 | # Finds layer indices matching module class 'mclass' 159 | return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)] 160 | 161 | 162 | def sparsity(model): 163 | # Return global model sparsity 164 | a, b = 0., 0. 165 | for p in model.parameters(): 166 | a += p.numel() 167 | b += (p == 0).sum() 168 | return b / a 169 | 170 | 171 | def prune(model, amount=0.3): 172 | # Prune model to requested global sparsity 173 | import torch.nn.utils.prune as prune 174 | print('Pruning model... 
', end='') 175 | for name, m in model.named_modules(): 176 | if isinstance(m, nn.Conv2d): 177 | prune.l1_unstructured(m, name='weight', amount=amount) # prune 178 | prune.remove(m, 'weight') # make permanent 179 | print(' %.3g global sparsity' % sparsity(model)) 180 | 181 | 182 | def fuse_conv_and_bn(conv, bn): 183 | # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ 184 | fusedconv = nn.Conv2d(conv.in_channels, 185 | conv.out_channels, 186 | kernel_size=conv.kernel_size, 187 | stride=conv.stride, 188 | padding=conv.padding, 189 | groups=conv.groups, 190 | bias=True).requires_grad_(False).to(conv.weight.device) 191 | 192 | # prepare filters 193 | w_conv = conv.weight.clone().view(conv.out_channels, -1) 194 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) 195 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) 196 | 197 | # prepare spatial bias 198 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias 199 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) 200 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) 201 | 202 | return fusedconv 203 | 204 | 205 | def model_info(model, verbose=False, img_size=1024): 206 | # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320] 207 | n_p = sum(x.numel() for x in model.parameters()) # number parameters 208 | n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients 209 | if verbose: 210 | print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma')) 211 | for i, (name, p) in enumerate(model.named_parameters()): 212 | name = name.replace('module_list.', '') 213 | print('%5g %40s %9s %12g %20s %10.3g %10.3g' % 214 | (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) 215 | 216 | try: # FLOPS 217 | from thop import profile 218 | stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 219 | # img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input 220 | img = torch.zeros((1, model.yaml.get('ch', 3), 640, 640), device=next(model.parameters()).device) # input 221 | flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride GFLOPS 222 | img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float 223 | # fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 1024✖1024 GFLOPS 224 | fs = ', %.1f GFLOPS' % (flops) 225 | except (ImportError, Exception): 226 | fs = '' 227 | 228 | logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}") 229 | 230 | 231 | def load_classifier(name='resnet101', n=2): 232 | # Loads a pretrained model reshaped to n-class output 233 | model = torchvision.models.__dict__[name](pretrained=True) 234 | 235 | # ResNet model properties 236 | # input_size = [3, 224, 224] 237 | # input_space = 'RGB' 238 | # input_range = [0, 1] 239 | # mean = [0.485, 0.456, 0.406] 240 | # std = [0.229, 0.224, 0.225] 241 | 242 | # Reshape output to n classes 243 | filters = model.fc.weight.shape[1] 244 | model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True) 245 | model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True) 246 | model.fc.out_features = n 247 | return model 248 | 249 | 
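# A minimal, illustrative sanity-check sketch for fuse_conv_and_bn() defined earlier in this
# file (it assumes only the torch / nn imports already present here): folding the BatchNorm2d
# inference-time statistics (gamma, beta, running mean/var) into the preceding Conv2d should
# leave the output numerically unchanged.
def _fuse_conv_and_bn_sanity_check():
    conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(16).eval()  # fusion assumes eval-mode (running) statistics
    x = torch.randn(1, 8, 32, 32)
    with torch.no_grad():
        y_ref = bn(conv(x))                      # separate conv -> bn
        y_fused = fuse_conv_and_bn(conv, bn)(x)  # single fused convolution
    assert torch.allclose(y_ref, y_fused, atol=1e-5), 'fused conv does not match conv+bn'
# e.g. call _fuse_conv_and_bn_sanity_check() once after fusing a model to confirm numerical equivalence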
250 | def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416) 251 | # scales img(bs,3,y,x) by ratio constrained to gs-multiple 252 | if ratio == 1.0: 253 | return img 254 | else: 255 | h, w = img.shape[2:] 256 | s = (int(h * ratio), int(w * ratio)) # new size 257 | img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize 258 | if not same_shape: # pad/crop img 259 | h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)] 260 | return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean 261 | 262 | 263 | def copy_attr(a, b, include=(), exclude=()): 264 | # Copy attributes from b to a, options to only include [...] and to exclude [...] 265 | for k, v in b.__dict__.items(): 266 | if (len(include) and k not in include) or k.startswith('_') or k in exclude: 267 | continue 268 | else: 269 | setattr(a, k, v) 270 | 271 | 272 | class ModelEMA: 273 | """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models 274 | Keep a moving average of everything in the model state_dict (parameters and buffers). 275 | This is intended to allow functionality like 276 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 277 | A smoothed version of the weights is necessary for some training schemes to perform well. 278 | This class is sensitive where it is initialized in the sequence of model init, 279 | GPU assignment and distributed training wrappers. 280 | """ 281 | 282 | def __init__(self, model, decay=0.9999, updates=0): 283 | # Create EMA 284 | self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA 285 | # if next(model.parameters()).device.type != 'cpu': 286 | # self.ema.half() # FP16 EMA 287 | self.updates = updates # number of EMA updates 288 | self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) # decay exponential ramp (to help early epochs) 289 | for p in self.ema.parameters(): 290 | p.requires_grad_(False) 291 | 292 | def update(self, model): 293 | # Update EMA parameters 294 | with torch.no_grad(): 295 | self.updates += 1 296 | d = self.decay(self.updates) 297 | 298 | msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict 299 | for k, v in self.ema.state_dict().items(): 300 | if v.dtype.is_floating_point: 301 | v *= d 302 | v += (1. - d) * msd[k].detach() 303 | 304 | def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): 305 | # Update EMA attributes 306 | copy_attr(self.ema, model, include, exclude) 307 | 308 | 309 | class BatchNormXd(torch.nn.modules.batchnorm._BatchNorm): 310 | def _check_input_dim(self, input): 311 | # The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc 312 | # is this method that is overwritten by the sub-class 313 | # This original goal of this method was for tensor sanity checks 314 | # If you're ok bypassing those sanity checks (eg. 
if you trust your inference 315 | # to provide the right dimensional inputs), then you can just use this method 316 | # for easy conversion from SyncBatchNorm 317 | # (unfortunately, SyncBatchNorm does not store the original class - if it did 318 | # we could return the one that was originally created) 319 | return 320 | 321 | def revert_sync_batchnorm(module): 322 | # this is very similar to the function that it is trying to revert: 323 | # https://github.com/pytorch/pytorch/blob/c8b3686a3e4ba63dc59e5dcfe5db3430df256833/torch/nn/modules/batchnorm.py#L679 324 | module_output = module 325 | if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm): 326 | new_cls = BatchNormXd 327 | module_output = BatchNormXd(module.num_features, 328 | module.eps, module.momentum, 329 | module.affine, 330 | module.track_running_stats) 331 | if module.affine: 332 | with torch.no_grad(): 333 | module_output.weight = module.weight 334 | module_output.bias = module.bias 335 | module_output.running_mean = module.running_mean 336 | module_output.running_var = module.running_var 337 | module_output.num_batches_tracked = module.num_batches_tracked 338 | if hasattr(module, "qconfig"): 339 | module_output.qconfig = module.qconfig 340 | for name, child in module.named_children(): 341 | module_output.add_module(name, revert_sync_batchnorm(child)) 342 | del module 343 | return module_output 344 | 345 | 346 | class TracedModel(nn.Module): 347 | 348 | def __init__(self, model=None, device=None, img_size=(640, 640)): 349 | super(TracedModel, self).__init__() 350 | 351 | print(" Convert model to Traced-model... ") 352 | self.stride = model.stride 353 | self.names = model.names 354 | self.model = model 355 | 356 | self.model = revert_sync_batchnorm(self.model) 357 | self.model.to('cpu') 358 | self.model.eval() 359 | 360 | self.detect_layer = self.model.model[-1] 361 | self.model.traced = True 362 | 363 | rand_example = torch.rand(1, 3, img_size, img_size) 364 | 365 | traced_script_module = torch.jit.trace(self.model, rand_example, strict=False) 366 | #traced_script_module = torch.jit.script(self.model) 367 | traced_script_module.save("traced_model.pt") 368 | print(" traced_script_module saved! ") 369 | self.model = traced_script_module 370 | self.model.to(device) 371 | self.detect_layer.to(device) 372 | print(" model is traced! 
\n") 373 | 374 | def forward(self, x, augment=False, profile=False): 375 | out = self.model(x) 376 | out = self.detect_layer(out) 377 | return out -------------------------------------------------------------------------------- /utils/wandb_logging/__init__.py: -------------------------------------------------------------------------------- 1 | # init -------------------------------------------------------------------------------- /utils/wandb_logging/log_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import yaml 4 | 5 | from wandb_utils import WandbLogger 6 | 7 | WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' 8 | 9 | 10 | def create_dataset_artifact(opt): 11 | with open(opt.data) as f: 12 | data = yaml.load(f, Loader=yaml.SafeLoader) # data dict 13 | logger = WandbLogger(opt, '', None, data, job_type='Dataset Creation') 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path') 19 | parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') 20 | parser.add_argument('--project', type=str, default='YOLOR', help='name of W&B Project') 21 | opt = parser.parse_args() 22 | opt.resume = False # Explicitly disallow resume check for dataset upload job 23 | 24 | create_dataset_artifact(opt) 25 | -------------------------------------------------------------------------------- /utils/wandb_logging/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | from pathlib import Path 4 | 5 | import torch 6 | import yaml 7 | from tqdm import tqdm 8 | 9 | sys.path.append(str(Path(__file__).parent.parent.parent)) # add utils/ to path 10 | from utils.datasets import LoadImagesAndLabels 11 | from utils.datasets import img2label_paths 12 | from utils.general import colorstr, xywh2xyxy, check_dataset 13 | 14 | try: 15 | import wandb 16 | from wandb import init, finish 17 | except ImportError: 18 | wandb = None 19 | 20 | WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' 21 | 22 | 23 | def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX): 24 | return from_string[len(prefix):] 25 | 26 | 27 | def check_wandb_config_file(data_config_file): 28 | wandb_config = '_wandb.'.join(data_config_file.rsplit('.', 1)) # updated data.yaml path 29 | if Path(wandb_config).is_file(): 30 | return wandb_config 31 | return data_config_file 32 | 33 | 34 | def get_run_info(run_path): 35 | run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX)) 36 | run_id = run_path.stem 37 | project = run_path.parent.stem 38 | model_artifact_name = 'run_' + run_id + '_model' 39 | return run_id, project, model_artifact_name 40 | 41 | 42 | def check_wandb_resume(opt): 43 | process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None 44 | if isinstance(opt.resume, str): 45 | if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): 46 | if opt.global_rank not in [-1, 0]: # For resuming DDP runs 47 | run_id, project, model_artifact_name = get_run_info(opt.resume) 48 | api = wandb.Api() 49 | artifact = api.artifact(project + '/' + model_artifact_name + ':latest') 50 | modeldir = artifact.download() 51 | opt.weights = str(Path(modeldir) / "last.pt") 52 | return True 53 | return None 54 | 55 | 56 | def process_wandb_config_ddp_mode(opt): 57 | with open(opt.data) as f: 58 | data_dict = yaml.load(f, Loader=yaml.SafeLoader) # data dict 59 | train_dir, 
val_dir = None, None 60 | if isinstance(data_dict['train'], str) and data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX): 61 | api = wandb.Api() 62 | train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias) 63 | train_dir = train_artifact.download() 64 | train_path = Path(train_dir) / 'data/images/' 65 | data_dict['train'] = str(train_path) 66 | 67 | if isinstance(data_dict['val'], str) and data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX): 68 | api = wandb.Api() 69 | val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias) 70 | val_dir = val_artifact.download() 71 | val_path = Path(val_dir) / 'data/images/' 72 | data_dict['val'] = str(val_path) 73 | if train_dir or val_dir: 74 | ddp_data_path = str(Path(val_dir) / 'wandb_local_data.yaml') 75 | with open(ddp_data_path, 'w') as f: 76 | yaml.dump(data_dict, f) 77 | opt.data = ddp_data_path 78 | 79 | 80 | class WandbLogger(): 81 | def __init__(self, opt, name, run_id, data_dict, job_type='Training'): 82 | # Pre-training routine -- 83 | self.job_type = job_type 84 | self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict 85 | # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call 86 | if isinstance(opt.resume, str): # checks resume from artifact 87 | if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): 88 | run_id, project, model_artifact_name = get_run_info(opt.resume) 89 | model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name 90 | assert wandb, 'install wandb to resume wandb runs' 91 | # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config 92 | self.wandb_run = wandb.init(id=run_id, project=project, resume='allow') 93 | opt.resume = model_artifact_name 94 | elif self.wandb: 95 | self.wandb_run = wandb.init(config=opt, 96 | resume="allow", 97 | project='YOLOR' if opt.project == 'runs/train' else Path(opt.project).stem, 98 | name=name, 99 | job_type=job_type, 100 | id=run_id) if not wandb.run else wandb.run 101 | if self.wandb_run: 102 | if self.job_type == 'Training': 103 | if not opt.resume: 104 | wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict 105 | # Info useful for resuming from artifacts 106 | self.wandb_run.config.opt = vars(opt) 107 | self.wandb_run.config.data_dict = wandb_data_dict 108 | self.data_dict = self.setup_training(opt, data_dict) 109 | if self.job_type == 'Dataset Creation': 110 | self.data_dict = self.check_and_upload_dataset(opt) 111 | else: 112 | prefix = colorstr('wandb: ') 113 | print(f"{prefix}Install Weights & Biases for YOLOR logging with 'pip install wandb' (recommended)") 114 | 115 | def check_and_upload_dataset(self, opt): 116 | assert wandb, 'Install wandb to upload dataset' 117 | check_dataset(self.data_dict) 118 | config_path = self.log_dataset_artifact(opt.data, 119 | opt.single_cls, 120 | 'YOLOR' if opt.project == 'runs/train' else Path(opt.project).stem) 121 | print("Created dataset config file ", config_path) 122 | with open(config_path) as f: 123 | wandb_data_dict = yaml.load(f, Loader=yaml.SafeLoader) 124 | return wandb_data_dict 125 | 126 | def setup_training(self, opt, data_dict): 127 | self.log_dict, self.current_epoch, self.log_imgs = {}, 0, 16 # Logging Constants 128 | self.bbox_interval = opt.bbox_interval 129 | if isinstance(opt.resume, str): 130 | modeldir, _ = self.download_model_artifact(opt) 131 | if modeldir: 132 | self.weights = 
Path(modeldir) / "last.pt" 133 | config = self.wandb_run.config 134 | opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp = str( 135 | self.weights), config.save_period, config.total_batch_size, config.bbox_interval, config.epochs, \ 136 | config.opt['hyp'] 137 | data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume 138 | if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download 139 | self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'), 140 | opt.artifact_alias) 141 | self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'), 142 | opt.artifact_alias) 143 | self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None 144 | if self.train_artifact_path is not None: 145 | train_path = Path(self.train_artifact_path) / 'data/images/' 146 | data_dict['train'] = str(train_path) 147 | if self.val_artifact_path is not None: 148 | val_path = Path(self.val_artifact_path) / 'data/images/' 149 | data_dict['val'] = str(val_path) 150 | self.val_table = self.val_artifact.get("val") 151 | self.map_val_table_path() 152 | if self.val_artifact is not None: 153 | self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation") 154 | self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"]) 155 | if opt.bbox_interval == -1: 156 | self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1 157 | return data_dict 158 | 159 | def download_dataset_artifact(self, path, alias): 160 | if isinstance(path, str) and path.startswith(WANDB_ARTIFACT_PREFIX): 161 | dataset_artifact = wandb.use_artifact(remove_prefix(path, WANDB_ARTIFACT_PREFIX) + ":" + alias) 162 | assert dataset_artifact is not None, "'Error: W&B dataset artifact doesn\'t exist'" 163 | datadir = dataset_artifact.download() 164 | return datadir, dataset_artifact 165 | return None, None 166 | 167 | def download_model_artifact(self, opt): 168 | if opt.resume.startswith(WANDB_ARTIFACT_PREFIX): 169 | model_artifact = wandb.use_artifact(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX) + ":latest") 170 | assert model_artifact is not None, 'Error: W&B model artifact doesn\'t exist' 171 | modeldir = model_artifact.download() 172 | epochs_trained = model_artifact.metadata.get('epochs_trained') 173 | total_epochs = model_artifact.metadata.get('total_epochs') 174 | assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' 
% ( 175 | total_epochs) 176 | return modeldir, model_artifact 177 | return None, None 178 | 179 | def log_model(self, path, opt, epoch, fitness_score, best_model=False): 180 | model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={ 181 | 'original_url': str(path), 182 | 'epochs_trained': epoch + 1, 183 | 'save period': opt.save_period, 184 | 'project': opt.project, 185 | 'total_epochs': opt.epochs, 186 | 'fitness_score': fitness_score 187 | }) 188 | model_artifact.add_file(str(path / 'last.pt'), name='last.pt') 189 | wandb.log_artifact(model_artifact, 190 | aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else '']) 191 | print("Saving model artifact on epoch ", epoch + 1) 192 | 193 | def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False): 194 | with open(data_file) as f: 195 | data = yaml.load(f, Loader=yaml.SafeLoader) # data dict 196 | nc, names = (1, ['item']) if single_cls else (int(data['nc']), data['names']) 197 | names = {k: v for k, v in enumerate(names)} # to index dictionary 198 | self.train_artifact = self.create_dataset_table(LoadImagesAndLabels( 199 | data['train']), names, name='train') if data.get('train') else None 200 | self.val_artifact = self.create_dataset_table(LoadImagesAndLabels( 201 | data['val']), names, name='val') if data.get('val') else None 202 | if data.get('train'): 203 | data['train'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'train') 204 | if data.get('val'): 205 | data['val'] = WANDB_ARTIFACT_PREFIX + str(Path(project) / 'val') 206 | path = data_file if overwrite_config else '_wandb.'.join(data_file.rsplit('.', 1)) # updated data.yaml path 207 | data.pop('download', None) 208 | with open(path, 'w') as f: 209 | yaml.dump(data, f) 210 | 211 | if self.job_type == 'Training': # builds correct artifact pipeline graph 212 | self.wandb_run.use_artifact(self.val_artifact) 213 | self.wandb_run.use_artifact(self.train_artifact) 214 | self.val_artifact.wait() 215 | self.val_table = self.val_artifact.get('val') 216 | self.map_val_table_path() 217 | else: 218 | self.wandb_run.log_artifact(self.train_artifact) 219 | self.wandb_run.log_artifact(self.val_artifact) 220 | return path 221 | 222 | def map_val_table_path(self): 223 | self.val_table_map = {} 224 | print("Mapping dataset") 225 | for i, data in enumerate(tqdm(self.val_table.data)): 226 | self.val_table_map[data[3]] = data[0] 227 | 228 | def create_dataset_table(self, dataset, class_to_id, name='dataset'): 229 | # TODO: Explore multiprocessing to split this loop in parallel; this is essential for speeding up the logging 230 | artifact = wandb.Artifact(name=name, type="dataset") 231 | img_files = tqdm([dataset.path]) if isinstance(dataset.path, str) and Path(dataset.path).is_dir() else None 232 | img_files = tqdm(dataset.img_files) if not img_files else img_files 233 | for img_file in img_files: 234 | if Path(img_file).is_dir(): 235 | artifact.add_dir(img_file, name='data/images') 236 | labels_path = 'labels'.join(dataset.path.rsplit('images', 1)) 237 | artifact.add_dir(labels_path, name='data/labels') 238 | else: 239 | artifact.add_file(img_file, name='data/images/' + Path(img_file).name) 240 | label_file = Path(img2label_paths([img_file])[0]) 241 | artifact.add_file(str(label_file), 242 | name='data/labels/' + label_file.name) if label_file.exists() else None 243 | table = wandb.Table(columns=["id", "train_image", "Classes", "name"]) 244 | class_set = wandb.Classes([{'id': id, 'name': name} for id, name in
class_to_id.items()]) 245 | for si, (img, labels, paths, shapes) in enumerate(tqdm(dataset)): 246 | height, width = shapes[0] 247 | labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4))) * torch.Tensor([width, height, width, height]) 248 | box_data, img_classes = [], {} 249 | for cls, *xyxy in labels[:, 1:].tolist(): 250 | cls = int(cls) 251 | box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, 252 | "class_id": cls, 253 | "box_caption": "%s" % (class_to_id[cls]), 254 | "scores": {"acc": 1}, 255 | "domain": "pixel"}) 256 | img_classes[cls] = class_to_id[cls] 257 | boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}} # inference-space 258 | table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes), 259 | Path(paths).name) 260 | artifact.add(table, name) 261 | return artifact 262 | 263 | def log_training_progress(self, predn, path, names): 264 | if self.val_table and self.result_table: 265 | class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()]) 266 | box_data = [] 267 | total_conf = 0 268 | for *xyxy, conf, cls in predn.tolist(): 269 | if conf >= 0.25: 270 | box_data.append( 271 | {"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]}, 272 | "class_id": int(cls), 273 | "box_caption": "%s %.3f" % (names[cls], conf), 274 | "scores": {"class_score": conf}, 275 | "domain": "pixel"}) 276 | total_conf = total_conf + conf 277 | boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space 278 | id = self.val_table_map[Path(path).name] 279 | self.result_table.add_data(self.current_epoch, 280 | id, 281 | wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set), 282 | total_conf / max(1, len(box_data)) 283 | ) 284 | 285 | def log(self, log_dict): 286 | if self.wandb_run: 287 | for key, value in log_dict.items(): 288 | self.log_dict[key] = value 289 | 290 | def end_epoch(self, best_result=False): 291 | if self.wandb_run: 292 | wandb.log(self.log_dict) 293 | self.log_dict = {} 294 | if self.result_artifact: 295 | train_results = wandb.JoinedTable(self.val_table, self.result_table, "id") 296 | self.result_artifact.add(train_results, 'result') 297 | wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch), 298 | ('best' if best_result else '')]) 299 | self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"]) 300 | self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation") 301 | 302 | def finish_run(self): 303 | if self.wandb_run: 304 | if self.log_dict: 305 | wandb.log(self.log_dict) 306 | wandb.run.finish() 307 | --------------------------------------------------------------------------------
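Taken as a whole, utils/wandb_logging/wandb_utils.py is meant to be driven from train.py. The snippet below is a minimal, hedged sketch (not the repository's actual train.py) of how WandbLogger is typically wired into a training loop. The Namespace fields mirror the attributes the class reads (resume, project, upload_dataset, bbox_interval, save_period, epochs, artifact_alias, global_rank); every value, path, dataset entry, and metric name is an illustrative placeholder, and the sketch assumes it is run from the repository root with wandb installed.

# Minimal illustrative sketch of driving WandbLogger from a training script.
# All values below are placeholders, not the repository's real configuration.
import argparse
from pathlib import Path

from utils.wandb_logging.wandb_utils import WandbLogger

opt = argparse.Namespace(
    resume=False,              # or 'wandb-artifact://<project>/<run_id>' to resume a W&B run
    project='runs/train',
    upload_dataset=False,      # True -> the dataset is logged as W&B artifacts before training
    bbox_interval=-1,          # -1 -> validation boxes are logged every (epochs // 10) epochs
    save_period=-1,
    epochs=300,
    artifact_alias='latest',
    global_rank=-1,
)
data_dict = {'train': 'data/train', 'val': 'data/val', 'nc': 15, 'names': ['plane', 'ship']}  # placeholder

wandb_logger = WandbLogger(opt, 'exp', None, data_dict, job_type='Training')

for epoch in range(opt.epochs):
    # ... forward/backward passes and validation happen here ...
    wandb_logger.log({'train/box_loss': 0.05, 'metrics/mAP_0.5': 0.7})  # buffered until end_epoch()
    wandb_logger.end_epoch(best_result=False)  # flushes the buffered metrics and the bbox result table
    if wandb_logger.wandb_run and opt.save_period > 0 and epoch % opt.save_period == 0:
        # log_model() expects a directory containing last.pt; this path is a placeholder
        wandb_logger.log_model(Path('runs/train/exp/weights'), opt, epoch, fitness_score=0.7)

wandb_logger.finish_run()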
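The wandb-artifact:// resume path in check_wandb_resume(), get_run_info(), and download_model_artifact() is driven purely by the string form of opt.resume. A short sketch of that flow follows; the project name 'RotaYolo' and run id 'abc123' are invented placeholders.

# Sketch of resuming from a W&B model artifact; project name and run id are invented.
import argparse

from utils.wandb_logging.wandb_utils import WANDB_ARTIFACT_PREFIX, check_wandb_resume, get_run_info

resume_ref = WANDB_ARTIFACT_PREFIX + 'RotaYolo/abc123'           # wandb-artifact://<project>/<run_id>
run_id, project, model_artifact_name = get_run_info(resume_ref)  # ('abc123', 'RotaYolo', 'run_abc123_model')

opt = argparse.Namespace(resume=resume_ref, global_rank=-1, data='data/dota.yaml')
check_wandb_resume(opt)  # returns True for wandb-artifact:// resumes; on DDP worker ranks
                         # (global_rank not in [-1, 0]) it also downloads last.pt from
                         # '<project>/run_<run_id>_model:latest' and points opt.weights at it.
# When WandbLogger is then constructed with this opt, __init__ re-attaches to the original
# run id and setup_training() pulls the checkpoint via download_model_artifact().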