├── README.md ├── clr_head.py ├── demo_onnx.py ├── demo_onnx_new.py ├── demo_trt.py ├── grid_sample.py ├── imgs ├── output_onnx.png └── output_trt.png ├── my_log └── test_onnx.log ├── test.jpg └── torch2onnx.py /README.md: -------------------------------------------------------------------------------- 1 | # CLRNet-onnxruntime-and-tensorrt-demo 2 | This is the onnxruntime and tensorrt inference code for CLRNet: Cross Layer Refinement Network for Lane Detection (CVPR 2022). Official code: https://github.com/Turoad/CLRNet 3 | 4 | [skip to onnx demo](#onnx) and [skip to tensorrt demo](#tensorrt) 5 | 6 | ## Note 7 | 1、Replaces the grid_sampler op with an implementation that can be exported to ONNX.
8 | 2、With this code you can convert the model to ONNX and run an onnxruntime inference demo. A newer demo uses only numpy for post-processing, which is easier to deploy but makes NMS slower.
9 | 3、The modifications described below affect the training code; this code is intended for ONNX inference only.
10 | 4、It mainly includes two parts: model inference and post-processing.
11 | 5、Conversion to a TensorRT engine is supported. 12 | 13 | ## convert and test onnx 14 | 1、Clone the official code and install its original environment by following https://github.com/Turoad/CLRNet
15 | 2、git clone this code
16 | 3、cp clr_head.py to your_path/CLRNet/clrnet/models/heads/
17 | 4、mkdir your_path/CLRNet/modules/ and cp grid_sample.py to your_path/CLRNet/modules/
18 | 5、cp torch2onnx.py to your_path/CLRNet/
19 | 6、For example, run 20 | ``` 21 | python torch2onnx.py configs/clrnet/clr_resnet18_tusimple.py --load_from tusimple_r18.pth 22 | ``` 23 | my deployment log is here https://github.com/xuanandsix/CLRNet-onnxruntime-and-tensorrt-demo/blob/main/my_log/test_onnx.log
24 | 7、cp test.jpg to your_path/CLRNet/ and run 25 | 26 | 1) NMS based on torch and cpython. 27 | ``` 28 | python demo_onnx.py 29 | ``` 30 | 2) NMS based on numpy. 31 | ``` 32 | python demo_onnx_new.py 33 | ``` 34 | ## onnx output 35 | 36 | 37 | 38 | ## convert and test tensorrt 39 | The TensorRT version needs to be 8.4 or later. This code was implemented with TensorRT-8.4.0.6.
40 | The *GatherElements error*, *IShuffleLayer error*, and *`is_tensor()` failed* issues have been resolved.
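The TensorRT version requirement stated above can be checked before trying to build an engine. The following is a minimal sketch, not part of this repository: the helper names `parse_version`, `meets_min_version`, and `installed_tensorrt_version` are hypothetical, and it assumes the `tensorrt` Python package is installed and exposes `__version__`. The minimum of 8.4 is taken from the note above, which the author's TensorRT-8.4.0.6 satisfies.

```python
import re


def parse_version(version):
    """Turn a version string such as '8.4.0.6' into a tuple of ints."""
    return tuple(int(part) for part in re.findall(r"\d+", version))


def meets_min_version(version, minimum=(8, 4)):
    """Return True when the leading components of `version` are at least `minimum`."""
    parsed = parse_version(version)
    return parsed[:len(minimum)] >= tuple(minimum)


def installed_tensorrt_version():
    """Return the installed TensorRT version string, or None if the package is missing."""
    try:
        import tensorrt  # assumption: the TensorRT Python bindings are installed
    except ImportError:
        return None
    return tensorrt.__version__


if __name__ == "__main__":
    version = installed_tensorrt_version()
    if version is None:
        print("tensorrt Python package not found")
    else:
        status = "OK" if meets_min_version(version) else "too old for this demo"
        print(version, status)
```

Only the leading components are compared, so a four-part build string such as `8.4.0.6` is accepted against the two-part minimum.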
41 | 1、Install TensorRT and its command-line tool *trtexec*.
42 | 2、install *polygraphy* to help modify the onnx model. You can install by 43 | ``` 44 | pip install nvidia-pyindex 45 | pip install polygraphy 46 | pip install onnx-graphsurgeon 47 | ``` 48 | and run 49 | ``` 50 | polygraphy surgeon sanitize your_path/tusimple_r18.onnx --fold-constants --output your_path/tusimple_r18.onnx 51 | ``` 52 | 3、 convert to tensorrt and get tusimple_r18.engine 53 | ``` 54 | ./trtexec --onnx=your_path/tusimple_r18.onnx --saveEngine=your_path/tusimple_r18.engine --verbose 55 | ``` 56 | 4、test tensorrt demo. 57 | ``` 58 | python demo_trt.py 59 | ``` 60 | 61 | ## tensorrt output 62 | 63 | 64 | 65 | 66 | ## TO DO 67 | - [x] Optimize post-processing. 68 | - [x] Tensorrt demo. 69 | - [ ] Cpp demo. 70 | 71 | -------------------------------------------------------------------------------- /clr_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import cv2 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from mmcv.cnn import ConvModule 9 | 10 | from clrnet.utils.lane import Lane 11 | from clrnet.models.losses.focal_loss import FocalLoss 12 | from clrnet.models.losses.accuracy import accuracy 13 | from clrnet.ops import nms 14 | 15 | from clrnet.models.utils.roi_gather import ROIGather, LinearModule 16 | from clrnet.models.utils.seg_decoder import SegDecoder 17 | from clrnet.models.utils.dynamic_assign import assign 18 | from clrnet.models.losses.lineiou_loss import liou_loss 19 | from ..registry import HEADS 20 | 21 | from modules.grid_sample import bilinear_grid_sample 22 | 23 | @HEADS.register_module 24 | class CLRHead(nn.Module): 25 | def __init__(self, 26 | num_points=72, 27 | prior_feat_channels=64, 28 | fc_hidden_dim=64, 29 | num_priors=192, 30 | num_fc=2, 31 | refine_layers=3, 32 | sample_points=36, 33 | cfg=None): 34 | super(CLRHead, self).__init__() 35 | self.cfg = cfg 36 | self.img_w = self.cfg.img_w 37 | self.img_h = 
self.cfg.img_h 38 | self.n_strips = num_points - 1 39 | self.n_offsets = num_points 40 | self.num_priors = num_priors 41 | self.sample_points = sample_points 42 | self.refine_layers = refine_layers 43 | self.fc_hidden_dim = fc_hidden_dim 44 | 45 | 46 | self.sample_x_indexs = (torch.linspace(0, 1, steps=self.sample_points, dtype=torch.float32) * self.n_strips).long() 47 | self.prior_feat_ys = torch.flip((1 - self.sample_x_indexs.float() / self.n_strips), dims=[-1]) 48 | self.prior_ys = torch.linspace(1, 0, steps=self.n_offsets, dtype=torch.float32) 49 | 50 | # self.register_buffer(name='sample_x_indexs', tensor=(torch.linspace( 51 | # 0, 1, steps=self.sample_points, dtype=torch.float32) * 52 | # self.n_strips).long()) 53 | # self.register_buffer(name='prior_feat_ys', tensor=torch.flip( 54 | # (1 - self.sample_x_indexs.float() / self.n_strips), dims=[-1])) 55 | # self.register_buffer(name='prior_ys', tensor=torch.linspace(1, 56 | # 0, 57 | # steps=self.n_offsets, 58 | # dtype=torch.float32)) 59 | 60 | self.prior_feat_channels = prior_feat_channels 61 | 62 | self._init_prior_embeddings() 63 | init_priors, priors_on_featmap = self.generate_priors_from_embeddings() #None, None 64 | self.register_buffer(name='priors', tensor=init_priors) 65 | self.register_buffer(name='priors_on_featmap', tensor=priors_on_featmap) 66 | 67 | # generate xys for feature map 68 | self.seg_decoder = SegDecoder(self.img_h, self.img_w, 69 | self.cfg.num_classes, 70 | self.prior_feat_channels, 71 | self.refine_layers) 72 | 73 | reg_modules = list() 74 | cls_modules = list() 75 | for _ in range(num_fc): 76 | reg_modules += [*LinearModule(self.fc_hidden_dim)] 77 | cls_modules += [*LinearModule(self.fc_hidden_dim)] 78 | self.reg_modules = nn.ModuleList(reg_modules) 79 | self.cls_modules = nn.ModuleList(cls_modules) 80 | 81 | self.roi_gather = ROIGather(self.prior_feat_channels, self.num_priors, 82 | self.sample_points, self.fc_hidden_dim, 83 | self.refine_layers) 84 | 85 | self.reg_layers = 
nn.Linear( 86 | self.fc_hidden_dim, self.n_offsets + 1 + 2 + 87 | 1) # n offsets + 1 length + start_x + start_y + theta 88 | self.cls_layers = nn.Linear(self.fc_hidden_dim, 2) 89 | 90 | weights = torch.ones(self.cfg.num_classes) 91 | weights[0] = self.cfg.bg_weight 92 | # self.criterion = torch.nn.NLLLoss(ignore_index=self.cfg.ignore_label, 93 | # weight=weights) 94 | 95 | # init the weights here 96 | self.init_weights() 97 | 98 | # function to init layer weights 99 | def init_weights(self): 100 | # initialize heads 101 | for m in self.cls_layers.parameters(): 102 | nn.init.normal_(m, mean=0., std=1e-3) 103 | 104 | for m in self.reg_layers.parameters(): 105 | nn.init.normal_(m, mean=0., std=1e-3) 106 | 107 | def pool_prior_features(self, batch_features, num_priors, prior_xs): 108 | ''' 109 | pool prior feature from feature map. 110 | Args: 111 | batch_features (Tensor): Input feature maps, shape: (B, C, H, W) 112 | ''' 113 | 114 | batch_size = batch_features.shape[0] 115 | 116 | prior_xs = prior_xs.view(batch_size, num_priors, -1, 1) 117 | prior_ys = self.prior_feat_ys.repeat(batch_size * num_priors).view( 118 | batch_size, num_priors, -1, 1) 119 | 120 | prior_xs = prior_xs * 2. - 1. 121 | prior_ys = prior_ys * 2. - 1. 
122 | grid = torch.cat((prior_xs, prior_ys), dim=-1) 123 | # feature = F.grid_sample(batch_features, grid, 124 | # align_corners=True).permute(0, 2, 1, 3) 125 | 126 | feature = bilinear_grid_sample(batch_features, grid, 127 | align_corners=True).permute(0, 2, 1, 3) 128 | 129 | feature = feature.reshape(batch_size * num_priors, 130 | self.prior_feat_channels, self.sample_points, 131 | 1) 132 | return feature 133 | 134 | def generate_priors_from_embeddings(self): 135 | predictions = self.prior_embeddings.weight # (num_prop, 3) 136 | 137 | # 2 scores, 1 start_y, 1 start_x, 1 theta, 1 length, 72 coordinates, score[0] = negative prob, score[1] = positive prob 138 | priors = predictions.new_zeros( 139 | (self.num_priors, 2 + 2 + 2 + self.n_offsets), device=predictions.device) 140 | 141 | priors[:, 2:5] = predictions.clone() 142 | priors[:, 6:] = ( 143 | priors[:, 3].unsqueeze(1).clone().repeat(1, self.n_offsets) * 144 | (self.img_w - 1) + 145 | ((1 - self.prior_ys.repeat(self.num_priors, 1) - 146 | priors[:, 2].unsqueeze(1).clone().repeat(1, self.n_offsets)) * 147 | self.img_h / torch.tan(priors[:, 4].unsqueeze(1).clone().repeat( 148 | 1, self.n_offsets) * math.pi + 1e-5))) / (self.img_w - 1) 149 | 150 | # init priors on feature map 151 | priors_on_featmap = priors.clone()[..., 6 + self.sample_x_indexs] 152 | 153 | return priors, priors_on_featmap 154 | 155 | def _init_prior_embeddings(self): 156 | # [start_y, start_x, theta] -> all normalize 157 | self.prior_embeddings = nn.Embedding(self.num_priors, 3) 158 | 159 | bottom_priors_nums = self.num_priors * 3 // 4 160 | left_priors_nums, _ = self.num_priors // 8, self.num_priors // 8 161 | 162 | strip_size = 0.5 / (left_priors_nums // 2 - 1) 163 | bottom_strip_size = 1 / (bottom_priors_nums // 4 + 1) 164 | for i in range(left_priors_nums): 165 | nn.init.constant_(self.prior_embeddings.weight[i, 0], 166 | (i // 2) * strip_size) 167 | nn.init.constant_(self.prior_embeddings.weight[i, 1], 0.) 
168 | nn.init.constant_(self.prior_embeddings.weight[i, 2], 169 | 0.16 if i % 2 == 0 else 0.32) 170 | 171 | for i in range(left_priors_nums, 172 | left_priors_nums + bottom_priors_nums): 173 | nn.init.constant_(self.prior_embeddings.weight[i, 0], 0.) 174 | nn.init.constant_(self.prior_embeddings.weight[i, 1], 175 | ((i - left_priors_nums) // 4 + 1) * 176 | bottom_strip_size) 177 | nn.init.constant_(self.prior_embeddings.weight[i, 2], 178 | 0.2 * (i % 4 + 1)) 179 | 180 | for i in range(left_priors_nums + bottom_priors_nums, self.num_priors): 181 | nn.init.constant_( 182 | self.prior_embeddings.weight[i, 0], 183 | ((i - left_priors_nums - bottom_priors_nums) // 2) * 184 | strip_size) 185 | nn.init.constant_(self.prior_embeddings.weight[i, 1], 1.) 186 | nn.init.constant_(self.prior_embeddings.weight[i, 2], 187 | 0.68 if i % 2 == 0 else 0.84) 188 | 189 | # forward function here 190 | def forward(self, x, **kwargs): 191 | ''' 192 | Take pyramid features as input to perform Cross Layer Refinement and finally output the prediction lanes. 193 | Each feature is a 4D tensor. 
194 | Args: 195 | x: input features (list[Tensor]) 196 | Return: 197 | prediction_list: each layer's prediction result 198 | seg: segmentation result for auxiliary loss 199 | ''' 200 | batch_features = list(x[len(x) - self.refine_layers:]) 201 | batch_features.reverse() 202 | batch_size = batch_features[-1].shape[0] 203 | 204 | if self.training: 205 | self.priors, self.priors_on_featmap = self.generate_priors_from_embeddings() 206 | 207 | priors, priors_on_featmap = self.priors.repeat(batch_size, 1, 208 | 1), self.priors_on_featmap.repeat( 209 | batch_size, 1, 1) 210 | 211 | predictions_lists = [] 212 | 213 | # iterative refine 214 | prior_features_stages = [] 215 | for stage in range(self.refine_layers): 216 | num_priors = priors_on_featmap.shape[1] 217 | prior_xs = torch.flip(priors_on_featmap, dims=[2]) 218 | 219 | batch_prior_features = self.pool_prior_features( 220 | batch_features[stage], num_priors, prior_xs) 221 | prior_features_stages.append(batch_prior_features) 222 | 223 | fc_features = self.roi_gather(prior_features_stages, 224 | batch_features[stage], stage) 225 | 226 | fc_features = fc_features.view(num_priors, batch_size, 227 | -1).reshape(batch_size * num_priors, 228 | self.fc_hidden_dim) 229 | 230 | cls_features = fc_features.clone() 231 | reg_features = fc_features.clone() 232 | for cls_layer in self.cls_modules: 233 | cls_features = cls_layer(cls_features) 234 | for reg_layer in self.reg_modules: 235 | reg_features = reg_layer(reg_features) 236 | 237 | cls_logits = self.cls_layers(cls_features) 238 | reg = self.reg_layers(reg_features) 239 | 240 | cls_logits = cls_logits.reshape( 241 | batch_size, -1, cls_logits.shape[1]) # (B, num_priors, 2) 242 | reg = reg.reshape(batch_size, -1, reg.shape[1]) 243 | 244 | predictions = priors.clone() 245 | predictions[:, :, :2] = cls_logits 246 | 247 | predictions[:, :, 248 | 2:5] += reg[:, :, :3] # also reg theta angle here 249 | predictions[:, :, 5] = reg[:, :, 3] # length 250 | 251 | def tran_tensor(t): 252 
| return t.unsqueeze(2).clone().repeat(1, 1, self.n_offsets) 253 | 254 | predictions[..., 6:] = ( 255 | tran_tensor(predictions[..., 3]) * (self.img_w - 1) + 256 | ((1 - self.prior_ys.repeat(batch_size, num_priors, 1) - 257 | tran_tensor(predictions[..., 2])) * self.img_h / 258 | torch.tan(tran_tensor(predictions[..., 4]) * math.pi + 1e-5))) / (self.img_w - 1) 259 | 260 | prediction_lines = predictions.clone() 261 | predictions[..., 6:] += reg[..., 4:] 262 | 263 | predictions_lists.append(predictions) 264 | 265 | if stage != self.refine_layers - 1: 266 | priors = prediction_lines.detach().clone() 267 | priors_on_featmap = priors[..., 6 + self.sample_x_indexs] 268 | 269 | if self.training: 270 | seg = None 271 | seg_features = torch.cat([ 272 | F.interpolate(feature, 273 | size=[ 274 | batch_features[-1].shape[2], 275 | batch_features[-1].shape[3] 276 | ], 277 | mode='bilinear', 278 | align_corners=False) 279 | for feature in batch_features 280 | ], 281 | dim=1) 282 | seg = self.seg_decoder(seg_features) 283 | output = {'predictions_lists': predictions_lists, 'seg': seg} 284 | return self.loss(output, kwargs['batch']) 285 | 286 | return predictions_lists[-1] 287 | 288 | def predictions_to_pred(self, predictions): 289 | ''' 290 | Convert predictions to internal Lane structure for evaluation. 291 | ''' 292 | self.prior_ys = self.prior_ys.to(predictions.device) 293 | self.prior_ys = self.prior_ys.double() 294 | 295 | lanes = [] 296 | for lane in predictions: 297 | lane_xs = lane[6:] # normalized value 298 | start = min(max(0, int(round(lane[2].item() * self.n_strips))), 299 | self.n_strips) 300 | length = int(round(lane[5].item())) 301 | end = start + length - 1 302 | end = min(end, len(self.prior_ys) - 1) 303 | # end = label_end 304 | # if the prediction does not start at the bottom of the image, 305 | # extend its prediction until the x is outside the image 306 | mask = ~((((lane_xs[:start] >= 0.) & (lane_xs[:start] <= 1.) 
307 | ).cpu().numpy()[::-1].cumprod()[::-1]).astype(bool)) 308 | lane_xs[end + 1:] = -2 309 | lane_xs[:start][mask] = -2 310 | lane_ys = self.prior_ys[lane_xs >= 0] 311 | lane_xs = lane_xs[lane_xs >= 0] 312 | lane_xs = lane_xs.flip(0).double() 313 | lane_ys = lane_ys.flip(0) 314 | 315 | lane_ys = (lane_ys * (self.cfg.ori_img_h - self.cfg.cut_height) + 316 | self.cfg.cut_height) / self.cfg.ori_img_h 317 | if len(lane_xs) <= 1: 318 | continue 319 | points = torch.stack( 320 | (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), 321 | dim=1).squeeze(2) 322 | 323 | lane = Lane(points=points.cpu().numpy(), 324 | metadata={ 325 | 'start_x': lane[3], 326 | 'start_y': lane[2], 327 | 'conf': lane[1] 328 | }) 329 | lanes.append(lane) 330 | return lanes 331 | 332 | def loss(self, 333 | output, 334 | batch, 335 | cls_loss_weight=2., 336 | xyt_loss_weight=0.5, 337 | iou_loss_weight=2., 338 | seg_loss_weight=1.): 339 | if self.cfg.haskey('cls_loss_weight'): 340 | cls_loss_weight = self.cfg.cls_loss_weight 341 | if self.cfg.haskey('xyt_loss_weight'): 342 | xyt_loss_weight = self.cfg.xyt_loss_weight 343 | if self.cfg.haskey('iou_loss_weight'): 344 | iou_loss_weight = self.cfg.iou_loss_weight 345 | if self.cfg.haskey('seg_loss_weight'): 346 | seg_loss_weight = self.cfg.seg_loss_weight 347 | 348 | predictions_lists = output['predictions_lists'] 349 | targets = batch['lane_line'].clone() 350 | cls_criterion = FocalLoss(alpha=0.25, gamma=2.) 
351 | cls_loss = 0 352 | reg_xytl_loss = 0 353 | iou_loss = 0 354 | cls_acc = [] 355 | 356 | cls_acc_stage = [] 357 | for stage in range(self.refine_layers): 358 | predictions_list = predictions_lists[stage] 359 | for predictions, target in zip(predictions_list, targets): 360 | target = target[target[:, 1] == 1] 361 | 362 | if len(target) == 0: 363 | # If there are no targets, all predictions have to be negatives (i.e., 0 confidence) 364 | cls_target = predictions.new_zeros(predictions.shape[0]).long() 365 | cls_pred = predictions[:, :2] 366 | cls_loss = cls_loss + cls_criterion( 367 | cls_pred, cls_target).sum() 368 | continue 369 | 370 | with torch.no_grad(): 371 | matched_row_inds, matched_col_inds = assign( 372 | predictions, target, self.img_w, self.img_h) 373 | 374 | # classification targets 375 | cls_target = predictions.new_zeros(predictions.shape[0]).long() 376 | cls_target[matched_row_inds] = 1 377 | cls_pred = predictions[:, :2] 378 | 379 | # regression targets -> [start_y, start_x, theta] (all transformed to absolute values), only on matched pairs 380 | reg_yxtl = predictions[matched_row_inds, 2:6] 381 | reg_yxtl[:, 0] *= self.n_strips 382 | reg_yxtl[:, 1] *= (self.img_w - 1) 383 | reg_yxtl[:, 2] *= 180 384 | reg_yxtl[:, 3] *= self.n_strips 385 | 386 | target_yxtl = target[matched_col_inds, 2:6].clone() 387 | 388 | # regression targets -> S coordinates (all transformed to absolute values) 389 | reg_pred = predictions[matched_row_inds, 6:] 390 | reg_pred *= (self.img_w - 1) 391 | reg_targets = target[matched_col_inds, 6:].clone() 392 | 393 | with torch.no_grad(): 394 | predictions_starts = torch.clamp( 395 | (predictions[matched_row_inds, 2] * 396 | self.n_strips).round().long(), 0, 397 | self.n_strips) # ensure the predictions starts is valid 398 | target_starts = (target[matched_col_inds, 2] * 399 | self.n_strips).round().long() 400 | target_yxtl[:, -1] -= (predictions_starts - target_starts 401 | ) # reg length 402 | 403 | # Loss calculation 404 | 
cls_loss = cls_loss + cls_criterion(cls_pred, cls_target).sum( 405 | ) / target.shape[0] 406 | 407 | target_yxtl[:, 0] *= self.n_strips 408 | target_yxtl[:, 2] *= 180 409 | reg_xytl_loss = reg_xytl_loss + F.smooth_l1_loss( 410 | reg_yxtl, target_yxtl, 411 | reduction='none').mean() 412 | 413 | iou_loss = iou_loss + liou_loss( 414 | reg_pred, reg_targets, 415 | self.img_w, length=15) 416 | 417 | # calculate acc 418 | cls_accuracy = accuracy(cls_pred, cls_target) 419 | cls_acc_stage.append(cls_accuracy) 420 | 421 | cls_acc.append(sum(cls_acc_stage) / len(cls_acc_stage)) 422 | 423 | # extra segmentation loss 424 | seg_loss = self.criterion(F.log_softmax(output['seg'], dim=1), 425 | batch['seg'].long()) 426 | 427 | cls_loss /= (len(targets) * self.refine_layers) 428 | reg_xytl_loss /= (len(targets) * self.refine_layers) 429 | iou_loss /= (len(targets) * self.refine_layers) 430 | 431 | loss = cls_loss * cls_loss_weight + reg_xytl_loss * xyt_loss_weight \ 432 | + seg_loss * seg_loss_weight + iou_loss * iou_loss_weight 433 | 434 | return_value = { 435 | 'loss': loss, 436 | 'loss_stats': { 437 | 'loss': loss, 438 | 'cls_loss': cls_loss * cls_loss_weight, 439 | 'reg_xytl_loss': reg_xytl_loss * xyt_loss_weight, 440 | 'seg_loss': seg_loss * seg_loss_weight, 441 | 'iou_loss': iou_loss * iou_loss_weight 442 | } 443 | } 444 | 445 | for i in range(self.refine_layers): 446 | return_value['loss_stats']['stage_{}_acc'.format(i)] = cls_acc[i] 447 | 448 | return return_value 449 | 450 | 451 | def get_lanes(self, output, as_lanes=True): 452 | ''' 453 | Convert model output to lanes. 
454 | ''' 455 | 456 | softmax = nn.Softmax(dim=1) 457 | 458 | decoded = [] 459 | for predictions in output: 460 | # filter out the conf lower than conf threshold 461 | threshold = self.cfg.test_parameters.conf_threshold 462 | scores = softmax(predictions[:, :2])[:, 1] 463 | keep_inds = scores >= threshold 464 | predictions = predictions[keep_inds] 465 | scores = scores[keep_inds] 466 | 467 | if predictions.shape[0] == 0: 468 | decoded.append([]) 469 | continue 470 | nms_predictions = predictions.detach().clone() 471 | nms_predictions = torch.cat( 472 | [nms_predictions[..., :4], nms_predictions[..., 5:]], dim=-1) 473 | nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips 474 | nms_predictions[..., 475 | 5:] = nms_predictions[..., 5:] * (self.img_w - 1) 476 | 477 | keep, num_to_keep, _ = nms( 478 | nms_predictions, 479 | scores, 480 | overlap=self.cfg.test_parameters.nms_thres, 481 | top_k=self.cfg.max_lanes) 482 | 483 | keep = keep[:num_to_keep] 484 | predictions = predictions[keep] 485 | 486 | if predictions.shape[0] == 0: 487 | decoded.append([]) 488 | continue 489 | 490 | predictions[:, 5] = torch.round(predictions[:, 5] * self.n_strips) 491 | #print("SDDDKLLLJJKLLLL") 492 | #as_lanes = False 493 | if as_lanes: 494 | pred = self.predictions_to_pred(predictions) 495 | else: 496 | pred = predictions 497 | 498 | decoded.append(pred) 499 | 500 | return decoded 501 | -------------------------------------------------------------------------------- /demo_onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import numpy as np 5 | import timeit 6 | import onnxruntime 7 | import torch 8 | from clrnet.ops import nms_impl 9 | 10 | from scipy.interpolate import InterpolatedUnivariateSpline 11 | import numpy as np 12 | 13 | COLORS = [ 14 | (255, 0, 0), 15 | (0, 255, 0), 16 | (0, 0, 255), 17 | (255, 255, 0), 18 | (255, 0, 255), 19 | (0, 255, 255), 20 | (128, 255, 0), 21 | (255, 
128, 0), 22 | (128, 0, 255), 23 | (255, 0, 128), 24 | (0, 128, 255), 25 | (0, 255, 128), 26 | (128, 255, 255), 27 | (255, 128, 255), 28 | (255, 255, 128), 29 | (60, 180, 0), 30 | (180, 60, 0), 31 | (0, 60, 180), 32 | (0, 180, 60), 33 | (60, 0, 180), 34 | (180, 0, 60), 35 | (255, 0, 0), 36 | (0, 255, 0), 37 | (0, 0, 255), 38 | (255, 255, 0), 39 | (255, 0, 255), 40 | (0, 255, 255), 41 | (128, 255, 0), 42 | (255, 128, 0), 43 | (128, 0, 255), 44 | ] 45 | 46 | class Lane: 47 | def __init__(self, points=None, invalid_value=-2., metadata=None): 48 | super(Lane, self).__init__() 49 | self.curr_iter = 0 50 | self.points = points 51 | self.invalid_value = invalid_value 52 | self.function = InterpolatedUnivariateSpline(points[:, 1], 53 | points[:, 0], 54 | k=min(3, 55 | len(points) - 1)) 56 | self.min_y = points[:, 1].min() - 0.01 57 | self.max_y = points[:, 1].max() + 0.01 58 | 59 | self.metadata = metadata or {} 60 | 61 | self.sample_y = range(710, 150, -10) 62 | self.ori_img_w = 1280 63 | self.ori_img_h = 720 64 | 65 | def __repr__(self): 66 | return '[Lane]\n' + str(self.points) + '\n[/Lane]' 67 | 68 | def __call__(self, lane_ys): 69 | lane_xs = self.function(lane_ys) 70 | 71 | lane_xs[(lane_ys < self.min_y) | 72 | (lane_ys > self.max_y)] = self.invalid_value 73 | return lane_xs 74 | 75 | def to_array(self): 76 | sample_y = self.sample_y 77 | img_w, img_h = self.ori_img_w, self.ori_img_h 78 | ys = np.array(sample_y) / float(img_h) 79 | xs = self(ys) 80 | valid_mask = (xs >= 0) & (xs < 1) 81 | lane_xs = xs[valid_mask] * img_w 82 | lane_ys = ys[valid_mask] * img_h 83 | lane = np.concatenate((lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), 84 | axis=1) 85 | return lane 86 | 87 | def __iter__(self): 88 | return self 89 | 90 | def __next__(self): 91 | if self.curr_iter < len(self.points): 92 | self.curr_iter += 1 93 | return self.points[self.curr_iter - 1] 94 | self.curr_iter = 0 95 | raise StopIteration 96 | 97 | def nms(boxes, scores, overlap, top_k): 98 | return 
nms_impl.nms_forward(boxes, scores, overlap, top_k) 99 | 100 | class CLRNetDemo(): 101 | def __init__(self, model_path): 102 | self.ort_session = onnxruntime.InferenceSession(model_path) 103 | self.conf_threshold = 0.4 104 | self.nms_thres = 50 105 | self.max_lanes = 5 106 | self.sample_points = 36 107 | self.num_points = 72 108 | self.n_offsets = 72 109 | self.n_strips = 71 110 | self.img_w = 1280 111 | self.img_h = 720 112 | self.ori_img_w = 1280 113 | self.ori_img_h = 720 114 | self.cut_height = 160 115 | 116 | self.input_width = 800 117 | self.input_height = 320 118 | 119 | self.sample_x_indexs = (np.linspace(0, 1, self.sample_points) * self.n_strips) 120 | self.prior_feat_ys = np.flip((1 - self.sample_x_indexs / self.n_strips)) 121 | self.prior_ys = np.linspace(1,0, self.n_offsets) 122 | 123 | def softmax(self, x, axis=None): 124 | x = x - x.max(axis=axis, keepdims=True) 125 | y = np.exp(x) 126 | return y / y.sum(axis=axis, keepdims=True) 127 | 128 | def predictions_to_pred(self, predictions): 129 | lanes = [] 130 | for lane in predictions: 131 | lane_xs = lane[6:] # normalized value 132 | start = min(max(0, int(round(lane[2].item() * self.n_strips))), 133 | self.n_strips) 134 | length = int(round(lane[5].item())) 135 | end = start + length - 1 136 | end = min(end, len(self.prior_ys) - 1) 137 | # end = label_end 138 | # if the prediction does not start at the bottom of the image, 139 | # extend its prediction until the x is outside the image 140 | mask = ~((((lane_xs[:start] >= 0.) & (lane_xs[:start] <= 1.) 
141 | )[::-1].cumprod()[::-1]).astype(bool)) 142 | 143 | lane_xs[end + 1:] = -2 144 | lane_xs[:start][mask] = -2 145 | lane_ys = self.prior_ys[lane_xs >= 0] 146 | lane_xs = lane_xs[lane_xs >= 0] 147 | 148 | lane_xs = np.double(lane_xs) 149 | lane_xs = np.flip(lane_xs, axis=0) 150 | lane_ys = np.flip(lane_ys, axis=0) 151 | lane_ys = (lane_ys * (self.ori_img_h - self.cut_height) + 152 | self.cut_height) / self.ori_img_h 153 | if len(lane_xs) <= 1: 154 | continue 155 | 156 | points = np.stack( 157 | (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), 158 | axis=1).squeeze(2) 159 | 160 | lane = Lane(points=points, 161 | metadata={ 162 | 'start_x': lane[3], 163 | 'start_y': lane[2], 164 | 'conf': lane[1] 165 | }) 166 | lanes.append(lane) 167 | return lanes 168 | 169 | def get_lanes(self, output, as_lanes=True): 170 | ''' 171 | Convert model output to lanes. 172 | ''' 173 | decoded = [] 174 | for predictions in output: 175 | # filter out the conf lower than conf threshold 176 | scores = self.softmax(predictions[:, :2], 1)[:, 1] 177 | 178 | keep_inds = scores >= self.conf_threshold 179 | predictions = predictions[keep_inds] 180 | scores = scores[keep_inds] 181 | 182 | if predictions.shape[0] == 0: 183 | decoded.append([]) 184 | continue 185 | nms_predictions = predictions 186 | 187 | nms_predictions = np.concatenate( 188 | [nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1) 189 | 190 | nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips 191 | nms_predictions[..., 192 | 5:] = nms_predictions[..., 5:] * (self.img_w - 1) 193 | 194 | keep, num_to_keep, _ = nms( 195 | torch.tensor(nms_predictions).cuda(), 196 | torch.tensor(scores).cuda(), 197 | overlap=self.nms_thres, 198 | top_k=self.max_lanes) 199 | 200 | keep = keep[:num_to_keep].cpu().numpy() 201 | predictions = predictions[keep] 202 | 203 | predictions = predictions 204 | 205 | if predictions.shape[0] == 0: 206 | decoded.append([]) 207 | continue 208 | 209 | predictions[:, 5] = 
np.round(predictions[:, 5] * self.n_strips) 210 | pred = self.predictions_to_pred(predictions) 211 | decoded.append(pred) 212 | 213 | return decoded 214 | 215 | def imshow_lanes(self, img, lanes, show=False, out_file=None, width=4): 216 | lanes = [lane.to_array() for lane in lanes] 217 | 218 | lanes_xys = [] 219 | for _, lane in enumerate(lanes): 220 | xys = [] 221 | for x, y in lane: 222 | if x <= 0 or y <= 0: 223 | continue 224 | x, y = int(x), int(y) 225 | xys.append((x, y)) 226 | lanes_xys.append(xys) 227 | lanes_xys.sort(key=lambda xys : xys[0][0]) 228 | 229 | for idx, xys in enumerate(lanes_xys): 230 | for i in range(1, len(xys)): 231 | cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width) 232 | return img 233 | 234 | def forward(self, img): 235 | img_ = img.copy() 236 | h, w = img.shape[:2] 237 | img = img[self.cut_height:, :, :] 238 | img = cv2.resize(img, (self.input_width, self.input_height), cv2.INTER_CUBIC) 239 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 240 | 241 | img = img.astype(np.float32) / 255.0 242 | 243 | img = np.transpose(np.float32(img[:,:,:,np.newaxis]), (3,2,0,1)) 244 | 245 | ort_inputs = {self.ort_session.get_inputs()[0].name: img} 246 | ort_outs = self.ort_session.run(None, ort_inputs) 247 | output = ort_outs[0] 248 | 249 | output = self.get_lanes(output) 250 | 251 | res = self.imshow_lanes(img_, output[0]) 252 | return res 253 | 254 | if __name__ == "__main__": 255 | clr = CLRNetDemo('./tusimple_r18.onnx') 256 | img = cv2.imread('./test.jpg') 257 | output = clr.forward(img) 258 | cv2.imwrite('output_onnx.png', output) 259 | -------------------------------------------------------------------------------- /demo_onnx_new.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import numpy as np 5 | import onnxruntime 6 | from scipy.interpolate import InterpolatedUnivariateSpline 7 | 8 | COLORS = [ 9 | (255, 0, 0), 10 | (0, 255, 0), 11 | (0, 0, 255), 
12 | (255, 255, 0), 13 | (255, 0, 255), 14 | (0, 255, 255), 15 | (128, 255, 0), 16 | (255, 128, 0), 17 | (128, 0, 255), 18 | (255, 0, 128), 19 | (0, 128, 255), 20 | (0, 255, 128), 21 | (128, 255, 255), 22 | (255, 128, 255), 23 | (255, 255, 128), 24 | (60, 180, 0), 25 | (180, 60, 0), 26 | (0, 60, 180), 27 | (0, 180, 60), 28 | (60, 0, 180), 29 | (180, 0, 60), 30 | (255, 0, 0), 31 | (0, 255, 0), 32 | (0, 0, 255), 33 | (255, 255, 0), 34 | (255, 0, 255), 35 | (0, 255, 255), 36 | (128, 255, 0), 37 | (255, 128, 0), 38 | (128, 0, 255), 39 | ] 40 | 41 | class Lane: 42 | def __init__(self, points=None, invalid_value=-2., metadata=None): 43 | super(Lane, self).__init__() 44 | self.curr_iter = 0 45 | self.points = points 46 | self.invalid_value = invalid_value 47 | self.function = InterpolatedUnivariateSpline(points[:, 1], 48 | points[:, 0], 49 | k=min(3, 50 | len(points) - 1)) 51 | self.min_y = points[:, 1].min() - 0.01 52 | self.max_y = points[:, 1].max() + 0.01 53 | 54 | self.metadata = metadata or {} 55 | 56 | self.sample_y = range(710, 150, -10) 57 | self.ori_img_w = 1280 58 | self.ori_img_h = 720 59 | 60 | def __repr__(self): 61 | return '[Lane]\n' + str(self.points) + '\n[/Lane]' 62 | 63 | def __call__(self, lane_ys): 64 | lane_xs = self.function(lane_ys) 65 | 66 | lane_xs[(lane_ys < self.min_y) | 67 | (lane_ys > self.max_y)] = self.invalid_value 68 | return lane_xs 69 | 70 | def to_array(self): 71 | sample_y = self.sample_y 72 | img_w, img_h = self.ori_img_w, self.ori_img_h 73 | ys = np.array(sample_y) / float(img_h) 74 | xs = self(ys) 75 | valid_mask = (xs >= 0) & (xs < 1) 76 | lane_xs = xs[valid_mask] * img_w 77 | lane_ys = ys[valid_mask] * img_h 78 | lane = np.concatenate((lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), 79 | axis=1) 80 | return lane 81 | 82 | def __iter__(self): 83 | return self 84 | 85 | def __next__(self): 86 | if self.curr_iter < len(self.points): 87 | self.curr_iter += 1 88 | return self.points[self.curr_iter - 1] 89 | self.curr_iter = 0 90 | 
raise StopIteration 91 | 92 | 93 | 94 | class CLRNetDemo(): 95 | def __init__(self, model_path): 96 | self.ort_session = onnxruntime.InferenceSession(model_path) 97 | self.conf_threshold = 0.4 98 | self.nms_thres = 50 99 | self.max_lanes = 5 100 | self.sample_points = 36 101 | self.num_points = 72 102 | self.n_offsets = 72 103 | self.n_strips = 71 104 | self.img_w = 1280 105 | self.img_h = 720 106 | self.ori_img_w = 1280 107 | self.ori_img_h = 720 108 | self.cut_height = 160 109 | 110 | self.input_width = 800 111 | self.input_height = 320 112 | 113 | self.sample_x_indexs = (np.linspace(0, 1, self.sample_points) * self.n_strips) 114 | self.prior_feat_ys = np.flip((1 - self.sample_x_indexs / self.n_strips)) 115 | self.prior_ys = np.linspace(1,0, self.n_offsets) 116 | 117 | def softmax(self, x, axis=None): 118 | x = x - x.max(axis=axis, keepdims=True) 119 | y = np.exp(x) 120 | return y / y.sum(axis=axis, keepdims=True) 121 | 122 | 123 | def Lane_nms(self, proposals,scores,overlap=50, top_k=4): 124 | keep_index = [] 125 | sorted_score = np.sort(scores)[-1] # from big to small 126 | indices = np.argsort(-scores) # from big to small 127 | 128 | r_filters = np.zeros(len(scores)) 129 | 130 | for i,indice in enumerate(indices): 131 | if r_filters[i]==1: # continue if this proposal is filted by nms before 132 | continue 133 | keep_index.append(indice) 134 | if len(keep_index)>top_k: # break if more than top_k 135 | break 136 | if i == (len(scores)-1):# break if indice is the last one 137 | break 138 | sub_indices = indices[i+1:] 139 | for sub_i,sub_indice in enumerate(sub_indices): 140 | r_filter = self.Lane_IOU(proposals[indice,:],proposals[sub_indice,:],overlap) 141 | if r_filter: r_filters[i+1+sub_i]=1 142 | num_to_keep = len(keep_index) 143 | keep_index = list(map(lambda x: x.item(), keep_index)) 144 | return keep_index, num_to_keep 145 | 146 | def Lane_IOU(self, parent_box, compared_box, threshold): 147 | ''' 148 | calculate distance one pair of proposal lines 149 | 
        Return True if the distance is less than the threshold.
        '''
        n_offsets = 72
        n_strips = n_offsets - 1

        # Adding 0.5 before the int cast makes it behave like rounding.
        start_a = (parent_box[2] * n_strips + 0.5).astype(int)
        start_b = (compared_box[2] * n_strips + 0.5).astype(int)
        start = max(start_a, start_b)
        end_a = start_a + parent_box[4] - 1 + 0.5 - (((parent_box[4] - 1) < 0).astype(int))
        end_b = start_b + compared_box[4] - 1 + 0.5 - (((compared_box[4] - 1) < 0).astype(int))
        end = min(min(end_a, end_b), 71)
        if (end - start) < 0:
            return False
        dist = 0
        # int() works for both NumPy scalars and the plain int 71 cap;
        # end.astype(int) would fail when min() returns the literal 71.
        for i in range(5 + start, 5 + int(end)):
            dist += abs(parent_box[i] - compared_box[i])
        return dist < (threshold * (end - start + 1))

    def predictions_to_pred(self, predictions):
        lanes = []
        for lane in predictions:
            lane_xs = lane[6:]  # normalized value
            start = min(max(0, int(round(lane[2].item() * self.n_strips))),
                        self.n_strips)
            length = int(round(lane[5].item()))
            end = start + length - 1
            end = min(end, len(self.prior_ys) - 1)
            # end = label_end
            # if the prediction does not start at the bottom of the image,
            # extend its prediction until the x is outside the image
            mask = ~((((lane_xs[:start] >= 0.) & (lane_xs[:start] <= 1.)
                       )[::-1].cumprod()[::-1]).astype(bool))  # np.bool was removed in NumPy 1.24

            lane_xs[end + 1:] = -2
            lane_xs[:start][mask] = -2
            lane_ys = self.prior_ys[lane_xs >= 0]
            lane_xs = lane_xs[lane_xs >= 0]

            lane_xs = np.double(lane_xs)
            lane_xs = np.flip(lane_xs, axis=0)
            lane_ys = np.flip(lane_ys, axis=0)
            lane_ys = (lane_ys * (self.ori_img_h - self.cut_height) +
                       self.cut_height) / self.ori_img_h
            if len(lane_xs) <= 1:
                continue

            points = np.stack(
                (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)),
                axis=1).squeeze(2)

            lane = Lane(points=points,
                        metadata={
                            'start_x': lane[3],
                            'start_y': lane[2],
                            'conf': lane[1]
                        })
            lanes.append(lane)
        return lanes

    def get_lanes(self, output, as_lanes=True):
        '''
        Convert model output to lanes.
        '''
        decoded = []
        for predictions in output:
            # filter out the conf lower than conf threshold
            scores = self.softmax(predictions[:, :2], 1)[:, 1]

            keep_inds = scores >= self.conf_threshold
            predictions = predictions[keep_inds]
            scores = scores[keep_inds]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue
            nms_predictions = predictions

            nms_predictions = np.concatenate(
                [nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1)

            nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
            nms_predictions[...,
                            5:] = nms_predictions[..., 5:] * (self.img_w - 1)

            keep, num_to_keep = self.Lane_nms(
                nms_predictions,
                scores,
                self.nms_thres,
                self.max_lanes)

            keep = keep[:num_to_keep]
            predictions = predictions[keep]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue

            predictions[:, 5] = np.round(predictions[:, 5] * self.n_strips)
            pred = self.predictions_to_pred(predictions)
            decoded.append(pred)

        return decoded

    def imshow_lanes(self, img, lanes, show=False, out_file=None, width=4):
        lanes = [lane.to_array() for lane in lanes]

        lanes_xys = []
        for _, lane in enumerate(lanes):
            xys = []
            for x, y in lane:
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                xys.append((x, y))
            if xys:  # skip lanes with no drawable points so the sort key below cannot fail
                lanes_xys.append(xys)
        lanes_xys.sort(key=lambda xys: xys[0][0])

        for idx, xys in enumerate(lanes_xys):
            for i in range(1, len(xys)):
                cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width)
        return img

    def forward(self, img):
        img_ = img.copy()
        h, w = img.shape[:2]
        img = img[self.cut_height:, :, :]
        # Pass the interpolation flag by keyword: the third positional
        # argument of cv2.resize is dst, not interpolation.
        img = cv2.resize(img, (self.input_width, self.input_height),
                         interpolation=cv2.INTER_CUBIC)
        # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0

        # HWC -> NCHW with a leading batch dimension of 1
        img = np.transpose(np.float32(img[:, :, :, np.newaxis]), (3, 2, 0, 1))

        ort_inputs = {self.ort_session.get_inputs()[0].name: img}
        ort_outs = self.ort_session.run(None, ort_inputs)
        output = ort_outs[0]

        output = self.get_lanes(output)
        res = self.imshow_lanes(img_, output[0])
        return res


if __name__ == "__main__":
    clr = CLRNetDemo('./tusimple_r18.onnx')
    img = cv2.imread('./test.jpg')
    output = clr.forward(img)
    cv2.imwrite('output_onnx.png', output)
    print("Done!")

--------------------------------------------------------------------------------
/demo_trt.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import cv2
import numpy as np
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # imported for its side effect: initializes the CUDA context

from scipy.interpolate import InterpolatedUnivariateSpline

COLORS = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (128, 255, 0),
    (255, 128, 0),
    (128, 0, 255),
    (255, 0, 128),
    (0, 128, 255),
    (0, 255, 128),
    (128, 255, 255),
    (255, 128, 255),
    (255, 255, 128),
    (60, 180, 0),
    (180, 60, 0),
    (0, 60, 180),
    (0, 180, 60),
    (60, 0, 180),
    (180, 0, 60),
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (128, 255, 0),
    (255, 128, 0),
    (128, 0, 255),
]


class Lane:
    def __init__(self, points=None, invalid_value=-2., metadata=None):
        super(Lane, self).__init__()
        self.curr_iter = 0
        self.points = points
        self.invalid_value = invalid_value
        self.function = InterpolatedUnivariateSpline(points[:, 1],
                                                     points[:, 0],
                                                     k=min(3,
                                                           len(points) - 1))
        self.min_y = points[:, 1].min() - 0.01
        self.max_y = points[:, 1].max() + 0.01

        self.metadata = metadata or {}

        self.sample_y = range(710, 150, -10)
        self.ori_img_w = 1280
        self.ori_img_h = 720

    def __repr__(self):
        return '[Lane]\n' + str(self.points) + '\n[/Lane]'

    def __call__(self, lane_ys):
        lane_xs = self.function(lane_ys)

        lane_xs[(lane_ys < self.min_y) |
                (lane_ys > self.max_y)] = self.invalid_value
        return lane_xs

    def to_array(self):
        sample_y = self.sample_y
        img_w, img_h = self.ori_img_w, self.ori_img_h
        ys = np.array(sample_y) / float(img_h)
        xs = self(ys)
        valid_mask = (xs >= 0) & (xs < 1)
        lane_xs = xs[valid_mask] * img_w
        lane_ys = ys[valid_mask] * img_h
        lane = np.concatenate((lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)),
                              axis=1)
        return lane

    def __iter__(self):
        return self

    def __next__(self):
        if self.curr_iter < len(self.points):
            self.curr_iter += 1
            return self.points[self.curr_iter - 1]
        self.curr_iter = 0
        raise StopIteration


class CLRNetDemo:
    def __init__(self, engine_path):
        self.logger = trt.Logger(trt.Logger.ERROR)
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        self.context = self.engine.create_execution_context()

        self.inputs = []
        self.outputs = []
        self.allocations = []
        for i in range(self.engine.num_bindings):
            is_input = False
            if self.engine.binding_is_input(i):
                is_input = True
            name = self.engine.get_binding_name(i)
            dtype = self.engine.get_binding_dtype(i)
            shape = self.engine.get_binding_shape(i)
            if is_input:
                self.batch_size = shape[0]
            size = np.dtype(trt.nptype(dtype)).itemsize
            for s in shape:
                size *= s
            allocation = cuda.mem_alloc(size)
            binding = {
                'index': i,
                'name': name,
                'dtype': np.dtype(trt.nptype(dtype)),
                'shape': list(shape),
                'allocation': allocation,
            }
            self.allocations.append(allocation)
            if self.engine.binding_is_input(i):
                self.inputs.append(binding)
            else:
                self.outputs.append(binding)

        self.conf_threshold = 0.4
        self.nms_thres = 50
        self.max_lanes = 5
        self.sample_points = 36
        self.num_points = 72
        self.n_offsets = 72
        self.n_strips = 71
        self.img_w = 1280
        self.img_h = 720
        self.ori_img_w = 1280
        self.ori_img_h = 720
        self.cut_height = 160

        self.input_width = 800
        self.input_height = 320

        self.sample_x_indexs = (np.linspace(0, 1, self.sample_points) * self.n_strips)
        self.prior_feat_ys = np.flip((1 - self.sample_x_indexs / self.n_strips))
        self.prior_ys = np.linspace(1, 0, self.n_offsets)

    def softmax(self, x, axis=None):
        x = x - x.max(axis=axis, keepdims=True)
        y = np.exp(x)
        return y / y.sum(axis=axis,
                         keepdims=True)

    def Lane_nms(self, proposals, scores, overlap=50, top_k=4):
        keep_index = []
        indices = np.argsort(-scores)  # proposal indices, highest score first

        r_filters = np.zeros(len(scores))

        for i, indice in enumerate(indices):
            if r_filters[i] == 1:  # skip proposals already suppressed by NMS
                continue
            keep_index.append(indice)
            if len(keep_index) >= top_k:  # stop once top_k proposals are kept
                break
            if i == (len(scores) - 1):  # stop if this is the last proposal
                break
            sub_indices = indices[i + 1:]
            for sub_i, sub_indice in enumerate(sub_indices):
                r_filter = self.Lane_IOU(proposals[indice, :], proposals[sub_indice, :], overlap)
                if r_filter:
                    r_filters[i + 1 + sub_i] = 1
        num_to_keep = len(keep_index)
        keep_index = list(map(lambda x: x.item(), keep_index))
        return keep_index, num_to_keep

    def Lane_IOU(self, parent_box, compared_box, threshold):
        '''
        Calculate the distance between a pair of proposal lines.
        Return True if the distance is less than the threshold.
        '''
        n_offsets = 72
        n_strips = n_offsets - 1

        # Adding 0.5 before the int cast makes it behave like rounding.
        start_a = (parent_box[2] * n_strips + 0.5).astype(int)
        start_b = (compared_box[2] * n_strips + 0.5).astype(int)
        start = max(start_a, start_b)
        end_a = start_a + parent_box[4] - 1 + 0.5 - (((parent_box[4] - 1) < 0).astype(int))
        end_b = start_b + compared_box[4] - 1 + 0.5 - (((compared_box[4] - 1) < 0).astype(int))
        end = min(min(end_a, end_b), 71)
        if (end - start) < 0:
            return False
        dist = 0
        # int() works for both NumPy scalars and the plain int 71 cap;
        # end.astype(int) would fail when min() returns the literal 71.
        for i in range(5 + start, 5 + int(end)):
            dist += abs(parent_box[i] - compared_box[i])
        return dist < (threshold * (end - start + 1))

    def predictions_to_pred(self, predictions):
        lanes = []
        for lane in predictions:
            lane_xs = lane[6:]  # normalized value
            start = min(max(0, int(round(lane[2].item() * self.n_strips))),
                        self.n_strips)
            length = int(round(lane[5].item()))
            end = start + length - 1
            end = min(end, len(self.prior_ys) - 1)
            # end = label_end
            # if the prediction does not start at the bottom of the image,
            # extend its prediction until the x is outside the image
            mask = ~((((lane_xs[:start] >= 0.) & (lane_xs[:start] <= 1.)
                       )[::-1].cumprod()[::-1]).astype(bool))  # np.bool was removed in NumPy 1.24

            lane_xs[end + 1:] = -2
            lane_xs[:start][mask] = -2
            lane_ys = self.prior_ys[lane_xs >= 0]
            lane_xs = lane_xs[lane_xs >= 0]

            lane_xs = np.double(lane_xs)
            lane_xs = np.flip(lane_xs, axis=0)
            lane_ys = np.flip(lane_ys, axis=0)
            lane_ys = (lane_ys * (self.ori_img_h - self.cut_height) +
                       self.cut_height) / self.ori_img_h
            if len(lane_xs) <= 1:
                continue

            points = np.stack(
                (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)),
                axis=1).squeeze(2)

            lane = Lane(points=points,
                        metadata={
                            'start_x': lane[3],
                            'start_y': lane[2],
                            'conf': lane[1]
                        })
            lanes.append(lane)
        return lanes

    def get_lanes(self, output, as_lanes=True):
        '''
        Convert model output to lanes.
        '''
        decoded = []
        for predictions in output:
            # filter out the conf lower than conf threshold
            scores = self.softmax(predictions[:, :2], 1)[:, 1]

            keep_inds = scores >= self.conf_threshold
            predictions = predictions[keep_inds]
            scores = scores[keep_inds]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue
            nms_predictions = predictions

            nms_predictions = np.concatenate(
                [nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1)

            nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
            nms_predictions[...,
                            5:] = nms_predictions[..., 5:] * (self.img_w - 1)

            keep, num_to_keep = self.Lane_nms(
                nms_predictions,
                scores,
                self.nms_thres,
                self.max_lanes)

            keep = keep[:num_to_keep]
            predictions = predictions[keep]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue

            predictions[:, 5] = np.round(predictions[:, 5] * self.n_strips)
            pred = self.predictions_to_pred(predictions)
            decoded.append(pred)

        return decoded

    def imshow_lanes(self, img, lanes, show=False, out_file=None, width=4):
        lanes = [lane.to_array() for lane in lanes]

        lanes_xys = []
        for _, lane in enumerate(lanes):
            xys = []
            for x, y in lane:
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                xys.append((x, y))
            if xys:  # skip lanes with no drawable points so the sort key below cannot fail
                lanes_xys.append(xys)
        lanes_xys.sort(key=lambda xys: xys[0][0])

        for idx, xys in enumerate(lanes_xys):
            for i in range(1, len(xys)):
                cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width)
        return img

    def forward(self, img):
        img_ = img.copy()
        h, w = img.shape[:2]
        img = img[self.cut_height:, :, :]
        # Pass the interpolation flag by keyword: the third positional
        # argument of cv2.resize is dst, not interpolation.
        img = cv2.resize(img, (self.input_width, self.input_height),
                         interpolation=cv2.INTER_CUBIC)
        # img = cv2.cvtColor(img,
        #                    cv2.COLOR_BGR2RGB)

        img = img.astype(np.float32) / 255.0

        img = np.transpose(np.float32(img[:, :, :, np.newaxis]), (3, 2, 0, 1))
        img = np.ascontiguousarray(img)
        cuda.memcpy_htod(self.inputs[0]['allocation'], img)
        self.context.execute_v2(self.allocations)
        outputs = []
        for out in self.outputs:
            output = np.zeros(out['shape'], out['dtype'])
            cuda.memcpy_dtoh(output, out['allocation'])
            outputs.append(output)

        output = outputs[0]
        output = self.get_lanes(output)
        res = self.imshow_lanes(img_, output[0])
        return res


if __name__ == "__main__":
    isnet = CLRNetDemo('tusimple_r18.engine')
    image = cv2.imread('test.jpg')
    output = isnet.forward(image)
    cv2.imwrite('output_trt.png', output)

--------------------------------------------------------------------------------
/grid_sample.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import itertools
import operator


def gather(input, dim, index):
    indices = [torch.arange(size, device=index.device) for size in index.shape]
    indices = list(torch.meshgrid(*indices))
    indices[dim] = index
    sizes = list(reversed(list(itertools.accumulate(reversed(input.shape), operator.mul))))
    index = sum((index * size for index, size in zip(indices, sizes[1:] + [1])))
    output = input.flatten()[index]
    return output


def bilinear_grid_sample(im, grid, align_corners=False):
    """Given an input and a flow-field grid, computes the output using input
    values and pixel locations from grid. Supported only bilinear interpolation
    method to sample the input pixels.
    Args:
        im (torch.Tensor): Input feature map, shape (N, C, H, W)
        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
        align_corners (bool): If set to True, the extrema (-1 and 1) are
            considered as referring to the center points of the input’s
            corner pixels. If set to False, they are instead considered as
            referring to the corner points of the input’s corner pixels,
            making the sampling more resolution agnostic.
    Returns:
        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    """

    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    # Map normalized grid coordinates in [-1, 1] to pixel coordinates.
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    x = x.view(n, -1)
    y = y.view(n, -1)

    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear weights of the four neighbouring pixels.
    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # Apply default for grid_sample function zero padding
    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)

    padded_h = h + 2
    padded_w = w + 2
    # save points positions after padding
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Clip coordinates to padded image size
    x0 = torch.where(x0 < 0, torch.tensor(0, device=x0.device), x0)
    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=x0.device), x0)
    x1 = torch.where(x1 < 0, torch.tensor(0, device=x1.device), x1)
    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=x1.device), x1)
    y0 = torch.where(y0 < 0, torch.tensor(0, device=y0.device), y0)
    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1,
                                                    device=y0.device), y0)
    y1 = torch.where(y1 < 0, torch.tensor(0, device=y1.device), y1)
    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=y1.device), y1)

    im_padded = im_padded.view(n, c, -1)

    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    # if (x0 + y0 * padded_w).shape[0] == 1:
    #     x0_y0 = torch.add(x0, y0 * padded_w).repeat(c, 1).unsqueeze(0)
    #     x0_y1 = torch.add(x0, y1 * padded_w).repeat(c, 1).unsqueeze(0)
    #     x1_y0 = torch.add(x1, y0 * padded_w).repeat(c, 1).unsqueeze(0)
    #     x1_y1 = torch.add(x1, y1 * padded_w).repeat(c, 1).unsqueeze(0)
    # else:
    #     x0_y0 = torch.add(x0, y0 * padded_w).unsqueeze(0).repeat(c, 1, 1).transpose(1, 0)
    #     x0_y1 = torch.add(x0, y1 * padded_w).unsqueeze(0).repeat(c, 1, 1).transpose(1, 0)
    #     x1_y0 = torch.add(x1, y0 * padded_w).unsqueeze(0).repeat(c, 1, 1).transpose(1, 0)
    #     x1_y1 = torch.add(x1, y1 * padded_w).unsqueeze(0).repeat(c, 1, 1).transpose(1, 0)

    # if (x0 + y0 * padded_w).shape[0] == 1:
    #     # x0_y0 = (x0 + y0 * padded_w).squeeze().repeat(c).view(1, 256, 4096)
    #     # x0_y1 = (x0 + y1 * padded_w).squeeze().repeat(c).view(1, 256, 4096)
    #     # x1_y0 = (x1 + y0 * padded_w).squeeze().repeat(c).view(1, 256, 4096)
    #     # x1_y1 = (x1 + y1 * padded_w).squeeze().repeat(c).view(1, 256, 4096)
    #     x0_y0 = torch.ones(1, 256, 4096, dtype=x0.dtype)
    #     x0_y1 = torch.ones(1, 256, 4096, dtype=x0.dtype)
    #     x1_y0 = torch.ones(1, 256, 4096, dtype=x0.dtype)
    #     x1_y1 = torch.ones(1, 256, 4096, dtype=x0.dtype)
    # else:
    #     # x0_y0 = torch.stack([(x0 + y0 * padded_w) for _ in range(c)]).transpose(1, 0)
    #     # x0_y1 = torch.stack([(x0 + y1 * padded_w) for _ in range(c)]).transpose(1, 0)
    #     # x1_y0 =
    #     #     torch.stack([(x1 + y0 * padded_w) for _ in range(c)]).transpose(1, 0)
    #     # x1_y1 = torch.stack([(x1 + y1 * padded_w) for _ in range(c)]).transpose(1, 0)
    #     x0_y0 = torch.ones(11, 3, 4096, dtype=x0.dtype)
    #     x0_y1 = torch.ones(11, 3, 4096, dtype=x0.dtype)
    #     x1_y0 = torch.ones(11, 3, 4096, dtype=x0.dtype)
    #     x1_y1 = torch.ones(11, 3, 4096, dtype=x0.dtype)

    # x0_y0 = torch.cat([(x0 + y0 * padded_w).unsqueeze(1) for _ in range(c)], axis=1)
    # x0_y1 = torch.cat([(x0 + y1 * padded_w).unsqueeze(1) for _ in range(c)], axis=1)
    # x1_y0 = torch.cat([(x1 + y0 * padded_w).unsqueeze(1) for _ in range(c)], axis=1)
    # x1_y1 = torch.cat([(x1 + y1 * padded_w).unsqueeze(1) for _ in range(c)], axis=1)

    Ia = torch.gather(im_padded, 2, x0_y0)
    Ib = torch.gather(im_padded, 2, x0_y1)
    Ic = torch.gather(im_padded, 2, x1_y0)
    Id = torch.gather(im_padded, 2, x1_y1)
    # Ia = gather(im_padded, 2, x0_y0)
    # Ib = gather(im_padded, 2, x0_y1)
    # Ic = gather(im_padded, 2, x1_y0)
    # Id = gather(im_padded, 2, x1_y1)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)

--------------------------------------------------------------------------------
/imgs/output_onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/CLRNet-onnxruntime-and-tensorrt-demo/07a47e666d4bfca9a4c60a777c62c23828fca663/imgs/output_onnx.png

--------------------------------------------------------------------------------
/imgs/output_trt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/CLRNet-onnxruntime-and-tensorrt-demo/07a47e666d4bfca9a4c60a777c62c23828fca663/imgs/output_trt.png

--------------------------------------------------------------------------------
/my_log/test_onnx.log:
-------------------------------------------------------------------------------- 1 | (base) [***@ai02 temp]$ git clone https://github.com/Turoad/clrnet 2 | Cloning into 'clrnet'... 3 | remote: Enumerating objects: 105, done. 4 | remote: Counting objects: 100% (105/105), done. 5 | remote: Compressing objects: 100% (81/81), done. 6 | remote: Total 105 (delta 22), reused 95 (delta 17), pack-reused 0 7 | Receiving objects: 100% (105/105), 286.98 KiB | 0 bytes/s, done. 8 | Resolving deltas: 100% (22/22), done. 9 | (base) [***@ai02 temp]$ ls 10 | clrnet 11 | (base) [***@ai02 temp]$ cd clrnet/ 12 | (base) [***@ai02 clrnet]$ ls 13 | clrnet configs LICENSE main.py README.md requirements.txt setup.py tools 14 | (base) [***@ai02 clrnet]$ conda create -n clrnet python=3.8 -y 15 | Collecting package metadata (current_repodata.json): done 16 | Solving environment: done 17 | 18 | 19 | ==> WARNING: A newer version of conda exists. <== 20 | current version: 4.11.0 21 | latest version: 4.14.0 22 | 23 | Please update conda by running 24 | 25 | $ conda update -n base -c defaults conda 26 | 27 | 28 | 29 | ## Package Plan ## 30 | 31 | environment location: /data3/***/anaconda3/envs/clrnet 32 | 33 | added / updated specs: 34 | - python=3.8 35 | 36 | 37 | The following NEW packages will be INSTALLED: 38 | 39 | _libgcc_mutex anaconda/pkgs/main/linux-64::_libgcc_mutex-0.1-main 40 | _openmp_mutex anaconda/pkgs/main/linux-64::_openmp_mutex-5.1-1_gnu 41 | ca-certificates anaconda/pkgs/main/linux-64::ca-certificates-2022.07.19-h06a4308_0 42 | certifi anaconda/pkgs/main/linux-64::certifi-2022.6.15-py38h06a4308_0 43 | ld_impl_linux-64 anaconda/pkgs/main/linux-64::ld_impl_linux-64-2.38-h1181459_1 44 | libffi anaconda/pkgs/main/linux-64::libffi-3.3-he6710b0_2 45 | libgcc-ng anaconda/pkgs/main/linux-64::libgcc-ng-11.2.0-h1234567_1 46 | libgomp anaconda/pkgs/main/linux-64::libgomp-11.2.0-h1234567_1 47 | libstdcxx-ng anaconda/pkgs/main/linux-64::libstdcxx-ng-11.2.0-h1234567_1 48 | ncurses 
anaconda/pkgs/main/linux-64::ncurses-6.3-h5eee18b_3 49 | openssl anaconda/pkgs/main/linux-64::openssl-1.1.1q-h7f8727e_0 50 | pip anaconda/pkgs/main/linux-64::pip-22.1.2-py38h06a4308_0 51 | python anaconda/pkgs/main/linux-64::python-3.8.13-h12debd9_0 52 | readline anaconda/pkgs/main/linux-64::readline-8.1.2-h7f8727e_1 53 | setuptools anaconda/pkgs/main/linux-64::setuptools-63.4.1-py38h06a4308_0 54 | sqlite anaconda/pkgs/main/linux-64::sqlite-3.39.2-h5082296_0 55 | tk anaconda/pkgs/main/linux-64::tk-8.6.12-h1ccaba5_0 56 | wheel anaconda/pkgs/main/noarch::wheel-0.37.1-pyhd3eb1b0_0 57 | xz anaconda/pkgs/main/linux-64::xz-5.2.5-h7f8727e_1 58 | zlib anaconda/pkgs/main/linux-64::zlib-1.2.12-h7f8727e_2 59 | 60 | 61 | Preparing transaction: done 62 | Verifying transaction: done 63 | Executing transaction: done 64 | # 65 | # To activate this environment, use 66 | # 67 | # $ conda activate clrnet 68 | # 69 | # To deactivate an active environment, use 70 | # 71 | # $ conda deactivate 72 | 73 | (base) [***@ai02 clrnet]$ conda activate clrnet 74 | (clrnet) [***@ai02 clrnet]$ pip install torch==1.8.0 torchvision==0.9.0 75 | Collecting torch==1.8.0 76 | Using cached torch-1.8.0-cp38-cp38-manylinux1_x86_64.whl (735.5 MB) 77 | Collecting torchvision==0.9.0 78 | Using cached torchvision-0.9.0-cp38-cp38-manylinux1_x86_64.whl (17.3 MB) 79 | Collecting numpy 80 | Downloading numpy-1.23.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB) 81 | ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17.1/17.1 MB 11.5 MB/s eta 0:00:00 82 | Collecting typing-extensions 83 | Using cached typing_extensions-4.3.0-py3-none-any.whl (25 kB) 84 | Collecting pillow>=4.1.1 85 | Using cached Pillow-9.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB) 86 | Installing collected packages: typing-extensions, pillow, numpy, torch, torchvision 87 | Successfully installed numpy-1.23.2 pillow-9.2.0 torch-1.8.0 torchvision-0.9.0 typing-extensions-4.3.0 88 | (clrnet) [***@ai02 clrnet]$ 
python setup.py build develop 89 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/installer.py:27: SetuptoolsDeprecationWarning: setuptools.installer is deprecated. Requirements should be satisfied by a PEP 517 installer. 90 | warnings.warn( 91 | running build 92 | running build_py 93 | creating build 94 | creating build/lib.linux-x86_64-cpython-38 95 | creating build/lib.linux-x86_64-cpython-38/clrnet 96 | copying clrnet/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet 97 | creating build/lib.linux-x86_64-cpython-38/clrnet/utils 98 | copying clrnet/utils/tusimple_metric.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 99 | copying clrnet/utils/logger.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 100 | copying clrnet/utils/registry.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 101 | copying clrnet/utils/llamas_utils.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 102 | copying clrnet/utils/llamas_metric.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 103 | copying clrnet/utils/recorder.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 104 | copying clrnet/utils/visualization.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 105 | copying clrnet/utils/culane_metric.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 106 | copying clrnet/utils/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 107 | copying clrnet/utils/config.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 108 | copying clrnet/utils/lane.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 109 | copying clrnet/utils/net_utils.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils 110 | creating build/lib.linux-x86_64-cpython-38/clrnet/engine 111 | copying clrnet/engine/scheduler.py -> build/lib.linux-x86_64-cpython-38/clrnet/engine 112 | copying clrnet/engine/registry.py -> build/lib.linux-x86_64-cpython-38/clrnet/engine 113 | copying clrnet/engine/runner.py -> build/lib.linux-x86_64-cpython-38/clrnet/engine 114 | 
copying clrnet/engine/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/engine 115 | copying clrnet/engine/optimizer.py -> build/lib.linux-x86_64-cpython-38/clrnet/engine 116 | creating build/lib.linux-x86_64-cpython-38/clrnet/models 117 | copying clrnet/models/registry.py -> build/lib.linux-x86_64-cpython-38/clrnet/models 118 | copying clrnet/models/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models 119 | creating build/lib.linux-x86_64-cpython-38/clrnet/datasets 120 | copying clrnet/datasets/llamas.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets 121 | copying clrnet/datasets/base_dataset.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets 122 | copying clrnet/datasets/tusimple.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets 123 | copying clrnet/datasets/registry.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets 124 | copying clrnet/datasets/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets 125 | copying clrnet/datasets/culane.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets 126 | creating build/lib.linux-x86_64-cpython-38/clrnet/ops 127 | copying clrnet/ops/nms.py -> build/lib.linux-x86_64-cpython-38/clrnet/ops 128 | copying clrnet/ops/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/ops 129 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/utils 130 | copying clrnet/models/utils/roi_gather.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/utils 131 | copying clrnet/models/utils/seg_decoder.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/utils 132 | copying clrnet/models/utils/dynamic_assign.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/utils 133 | copying clrnet/models/utils/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/utils 134 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/backbones 135 | copying clrnet/models/backbones/dla34.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/backbones 136 | copying 
clrnet/models/backbones/resnet.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/backbones 137 | copying clrnet/models/backbones/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/backbones 138 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/heads 139 | copying clrnet/models/heads/clr_head.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/heads 140 | copying clrnet/models/heads/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/heads 141 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/losses 142 | copying clrnet/models/losses/focal_loss.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/losses 143 | copying clrnet/models/losses/accuracy.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/losses 144 | copying clrnet/models/losses/lineiou_loss.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/losses 145 | copying clrnet/models/losses/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/losses 146 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/necks 147 | copying clrnet/models/necks/fpn.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/necks 148 | copying clrnet/models/necks/pafpn.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/necks 149 | copying clrnet/models/necks/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/necks 150 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/nets 151 | copying clrnet/models/nets/detector.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/nets 152 | copying clrnet/models/nets/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/nets 153 | creating build/lib.linux-x86_64-cpython-38/clrnet/datasets/process 154 | copying clrnet/datasets/process/transforms.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets/process 155 | copying clrnet/datasets/process/process.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets/process 156 | copying clrnet/datasets/process/__init__.py -> 
build/lib.linux-x86_64-cpython-38/clrnet/datasets/process 157 | copying clrnet/datasets/process/generate_lane_line.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets/process 158 | running egg_info 159 | creating clrnet.egg-info 160 | writing clrnet.egg-info/PKG-INFO 161 | writing dependency_links to clrnet.egg-info/dependency_links.txt 162 | writing requirements to clrnet.egg-info/requires.txt 163 | writing top-level names to clrnet.egg-info/top_level.txt 164 | writing manifest file 'clrnet.egg-info/SOURCES.txt' 165 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/utils/cpp_extension.py:369: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend. 166 | warnings.warn(msg.format('we could not find ninja.')) 167 | reading manifest file 'clrnet.egg-info/SOURCES.txt' 168 | adding license file 'LICENSE' 169 | writing manifest file 'clrnet.egg-info/SOURCES.txt' 170 | running build_ext 171 | building 'clrnet.ops.nms_impl' extension 172 | creating build/temp.linux-x86_64-cpython-38 173 | creating build/temp.linux-x86_64-cpython-38/clrnet 174 | creating build/temp.linux-x86_64-cpython-38/clrnet/ops 175 | creating build/temp.linux-x86_64-cpython-38/clrnet/ops/csrc 176 | /data1/cv/softs/gcc-5.1.0/install/bin/gcc -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/TH -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/data3/***/anaconda3/envs/clrnet/include/python3.8 -c ./clrnet/ops/csrc/nms.cpp -o build/temp.linux-x86_64-cpython-38/./clrnet/ops/csrc/nms.o -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" 
-DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -DTORCH_EXTENSION_NAME=nms_impl -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14 177 | cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++ 178 | In file included from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Parallel.h:140:0, 179 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/utils.h:3, 180 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn/cloneable.h:5, 181 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn.h:3, 182 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:13, 183 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/extension.h:4, 184 | from ./clrnet/ops/csrc/nms.cpp:30: 185 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/ParallelOpenMP.h:83:0: warning: ignoring #pragma omp parallel [-Wunknown-pragmas] 186 | #pragma omp parallel for if ((end - begin) >= grain_size) 187 | ^ 188 | In file included from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/core/Device.h:5:0, 189 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/core/Allocator.h:6, 190 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/ATen.h:7, 191 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/types.h:3, 192 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h:4, 193 | from 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h:3, 194 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h:3, 195 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h:3, 196 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data.h:3, 197 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:8, 198 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/extension.h:4, 199 | from ./clrnet/ops/csrc/nms.cpp:30: 200 | ./clrnet/ops/csrc/nms.cpp: In function ‘std::vector<at::Tensor> nms_forward(at::Tensor, at::Tensor, float, long unsigned int)’: 201 | ./clrnet/ops/csrc/nms.cpp:40:41: warning: ‘at::DeprecatedTypeProperties& at::Tensor::type() const’ is deprecated: Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device(). 
[-Wdeprecated-declarations] 202 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 203 | ^ 204 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:225:39: note: in definition of macro ‘C10_EXPAND_MSVC_WORKAROUND’ 205 | #define C10_EXPAND_MSVC_WORKAROUND(x) x 206 | ^ 207 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:244:34: note: in expansion of macro ‘C10_UNLIKELY’ 208 | #define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e) 209 | ^ 210 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:291:7: note: in expansion of macro ‘C10_UNLIKELY_OR_CONST’ 211 | if (C10_UNLIKELY_OR_CONST(!(cond))) { \ 212 | ^ 213 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:484:32: note: in expansion of macro ‘TORCH_INTERNAL_ASSERT’ 214 | C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__)); \ 215 | ^ 216 | ./clrnet/ops/csrc/nms.cpp:40:23: note: in expansion of macro ‘AT_ASSERTM’ 217 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 218 | ^ 219 | ./clrnet/ops/csrc/nms.cpp:42:24: note: in expansion of macro ‘CHECK_CUDA’ 220 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 221 | ^ 222 | ./clrnet/ops/csrc/nms.cpp:53:5: note: in expansion of macro ‘CHECK_INPUT’ 223 | CHECK_INPUT(boxes); 224 | ^ 225 | In file included from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Tensor.h:3:0, 226 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Context.h:4, 227 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/ATen.h:9, 228 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/types.h:3, 229 | from 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h:4, 230 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h:3, 231 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h:3, 232 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h:3, 233 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data.h:3, 234 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:8, 235 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/extension.h:4, 236 | from ./clrnet/ops/csrc/nms.cpp:30: 237 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:303:30: note: declared here 238 | DeprecatedTypeProperties & type() const { 239 | ^ 240 | In file included from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/core/Device.h:5:0, 241 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/core/Allocator.h:6, 242 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/ATen.h:7, 243 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/types.h:3, 244 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h:4, 245 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h:3, 246 | from 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h:3, 247 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h:3, 248 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data.h:3, 249 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:8, 250 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/extension.h:4, 251 | from ./clrnet/ops/csrc/nms.cpp:30: 252 | ./clrnet/ops/csrc/nms.cpp:40:41: warning: ‘at::DeprecatedTypeProperties& at::Tensor::type() const’ is deprecated: Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device(). 
[-Wdeprecated-declarations] 253 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 254 | ^ 255 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:225:39: note: in definition of macro ‘C10_EXPAND_MSVC_WORKAROUND’ 256 | #define C10_EXPAND_MSVC_WORKAROUND(x) x 257 | ^ 258 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:244:34: note: in expansion of macro ‘C10_UNLIKELY’ 259 | #define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e) 260 | ^ 261 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:291:7: note: in expansion of macro ‘C10_UNLIKELY_OR_CONST’ 262 | if (C10_UNLIKELY_OR_CONST(!(cond))) { \ 263 | ^ 264 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:484:32: note: in expansion of macro ‘TORCH_INTERNAL_ASSERT’ 265 | C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__)); \ 266 | ^ 267 | ./clrnet/ops/csrc/nms.cpp:40:23: note: in expansion of macro ‘AT_ASSERTM’ 268 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 269 | ^ 270 | ./clrnet/ops/csrc/nms.cpp:42:24: note: in expansion of macro ‘CHECK_CUDA’ 271 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 272 | ^ 273 | ./clrnet/ops/csrc/nms.cpp:54:5: note: in expansion of macro ‘CHECK_INPUT’ 274 | CHECK_INPUT(idx); 275 | ^ 276 | In file included from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Tensor.h:3:0, 277 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Context.h:4, 278 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/ATen.h:9, 279 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/types.h:3, 280 | from 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h:4, 281 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h:3, 282 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h:3, 283 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h:3, 284 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data.h:3, 285 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:8, 286 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/extension.h:4, 287 | from ./clrnet/ops/csrc/nms.cpp:30: 288 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:303:30: note: declared here 289 | DeprecatedTypeProperties & type() const { 290 | ^ 291 | /usr/local/cuda-11.3/bin/nvcc -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/TH -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/data3/***/anaconda3/envs/clrnet/include/python3.8 -c ./clrnet/ops/csrc/nms_kernel.cu -o build/temp.linux-x86_64-cpython-38/./clrnet/ops/csrc/nms_kernel.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" 
-DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -DTORCH_EXTENSION_NAME=nms_impl -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_70,code=compute_70 -gencode=arch=compute_70,code=sm_70 -ccbin /data1/cv/softs/gcc-5.1.0/install/bin/gcc -std=c++14 292 | ./clrnet/ops/csrc/nms_kernel.cu: In lambda function: 293 | ./clrnet/ops/csrc/nms_kernel.cu:171:43: warning: ‘at::DeprecatedTypeProperties& at::Tensor::type() const’ is deprecated: Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device(). [-Wdeprecated-declarations] 294 | AT_DISPATCH_FLOATING_TYPES(boxes.type(), "nms_cuda_forward", ([&] { 295 | ^ 296 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:303:1: note: declared here 297 | DeprecatedTypeProperties & type() const { 298 | ^ 299 | ./clrnet/ops/csrc/nms_kernel.cu:171:98: warning: ‘c10::ScalarType detail::scalar_type(const at::DeprecatedTypeProperties&)’ is deprecated: passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, pass an at::ScalarType instead [-Wdeprecated-declarations] 300 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Dispatch.h:109:1: note: declared here 301 | inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) { 302 | ^ 303 | ./clrnet/ops/csrc/nms_kernel.cu: In lambda function: 304 | ./clrnet/ops/csrc/nms_kernel.cu:171:856: warning: ‘T* at::Tensor::data() const [with T = double]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. 
[-Wdeprecated-declarations] 305 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 306 | T * data() const { 307 | ^ 308 | ./clrnet/ops/csrc/nms_kernel.cu:171:878: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations] 309 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 310 | T * data() const { 311 | ^ 312 | ./clrnet/ops/csrc/nms_kernel.cu:171:901: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations] 313 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 314 | T * data() const { 315 | ^ 316 | ./clrnet/ops/csrc/nms_kernel.cu: In lambda function: 317 | ./clrnet/ops/csrc/nms_kernel.cu:171:1648: warning: ‘T* at::Tensor::data() const [with T = float]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations] 318 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 319 | T * data() const { 320 | ^ 321 | ./clrnet/ops/csrc/nms_kernel.cu:171:1670: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations] 322 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 323 | T * data() const { 324 | ^ 325 | ./clrnet/ops/csrc/nms_kernel.cu:171:1693: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. 
[-Wdeprecated-declarations] 326 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 327 | T * data() const { 328 | ^ 329 | ./clrnet/ops/csrc/nms_kernel.cu: In function ‘std::vector<at::Tensor> nms_cuda_forward(at::Tensor, at::Tensor, float, long unsigned int)’: 330 | ./clrnet/ops/csrc/nms_kernel.cu:183:108: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations] 331 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 332 | T * data() const { 333 | ^ 334 | ./clrnet/ops/csrc/nms_kernel.cu:183:129: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations] 335 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 336 | T * data() const { 337 | ^ 338 | ./clrnet/ops/csrc/nms_kernel.cu:183:150: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations] 339 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 340 | T * data() const { 341 | ^ 342 | ./clrnet/ops/csrc/nms_kernel.cu:183:186: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations] 343 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 344 | T * data() const { 345 | ^ 346 | ./clrnet/ops/csrc/nms_kernel.cu:183:214: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. 
Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations] 347 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here 348 | T * data() const { 349 | ^ 350 | /data1/cv/softs/gcc-5.1.0/install/bin/g++ -pthread -shared -B /data3/***/anaconda3/envs/clrnet/compiler_compat -L/data3/***/anaconda3/envs/clrnet/lib -Wl,-rpath=/data3/***/anaconda3/envs/clrnet/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-cpython-38/./clrnet/ops/csrc/nms.o build/temp.linux-x86_64-cpython-38/./clrnet/ops/csrc/nms_kernel.o -L/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/lib -L/usr/local/cuda-11.3/lib64 -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-cpython-38/clrnet/ops/nms_impl.cpython-38-x86_64-linux-gnu.so 351 | running develop 352 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools. 353 | warnings.warn( 354 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools. 355 | warnings.warn( 356 | running build_ext 357 | copying build/lib.linux-x86_64-cpython-38/clrnet/ops/nms_impl.cpython-38-x86_64-linux-gnu.so -> clrnet/ops 358 | Creating /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/clrnet.egg-link (link to .) 
359 | Adding clrnet 1.0 to easy-install.pth file 360 | 361 | Installed /data3/***/project/temp/clrnet 362 | Processing dependencies for clrnet==1.0 363 | Searching for ptflops 364 | Reading https://pypi.org/simple/ptflops/ 365 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release 366 | warnings.warn( 367 | Downloading https://files.pythonhosted.org/packages/d5/16/b6992b799d14bdc7a78fbfff8825777dd92728ed761090852c43f4792ce1/ptflops-0.6.9.tar.gz#sha256=95423006d7520eff5cc2fcbbd149329d39d81c9bff361c9b3b13bfbd5d1efe75 368 | Best match: ptflops 0.6.9 369 | Processing ptflops-0.6.9.tar.gz 370 | Writing /tmp/easy_install-h2cxdo1q/ptflops-0.6.9/setup.cfg 371 | Running ptflops-0.6.9/setup.py -q bdist_egg --dist-dir /tmp/easy_install-h2cxdo1q/ptflops-0.6.9/egg-dist-tmp-pwoqahab 372 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools. 
373 | warnings.warn( 374 | Moving ptflops-0.6.9-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 375 | Adding ptflops 0.6.9 to easy-install.pth file 376 | 377 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ptflops-0.6.9-py3.8.egg 378 | Searching for pathspec 379 | Reading https://pypi.org/simple/pathspec/ 380 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release 381 | warnings.warn( 382 | Downloading https://files.pythonhosted.org/packages/42/ba/a9d64c7bcbc7e3e8e5f93a52721b377e994c22d16196e2b0f1236774353a/pathspec-0.9.0-py2.py3-none-any.whl#sha256=7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a 383 | Best match: pathspec 0.9.0 384 | Processing pathspec-0.9.0-py2.py3-none-any.whl 385 | Installing pathspec-0.9.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 386 | Adding pathspec 0.9.0 to easy-install.pth file 387 | 388 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pathspec-0.9.0-py3.8.egg 389 | Searching for albumentations==0.4.6 390 | Reading https://pypi.org/simple/albumentations/ 391 | Downloading https://files.pythonhosted.org/packages/92/33/1c459c2c9a4028ec75527eff88bc4e2d256555189f42af4baf4d7bd89233/albumentations-0.4.6.tar.gz#sha256=510ac855a6dc7f80723bba7de98c9ee575997a76cf49192c44c707c904a020f8 392 | Best match: albumentations 0.4.6 393 | Processing albumentations-0.4.6.tar.gz 394 | Writing /tmp/easy_install-egwsde9u/albumentations-0.4.6/setup.cfg 395 | Running albumentations-0.4.6/setup.py -q bdist_egg --dist-dir /tmp/easy_install-egwsde9u/albumentations-0.4.6/egg-dist-tmp-mma_63qf 396 | no previously-included directories found matching 'docs/_build' 397 | warning: no previously-included files matching 'docs/augs_overview/*/images/*.jpg' found anywhere in distribution 398 | warning: no 
previously-included files matching '*.py[co]' found anywhere in distribution 399 | warning: no previously-included files matching '.DS_Store' found anywhere in distribution 400 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools. 401 | warnings.warn( 402 | zip_safe flag not set; analyzing archive contents... 403 | Moving albumentations-0.4.6-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 404 | Adding albumentations 0.4.6 to easy-install.pth file 405 | 406 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/albumentations-0.4.6-py3.8.egg 407 | Searching for mmcv==1.2.5 408 | Reading https://pypi.org/simple/mmcv/ 409 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release 410 | warnings.warn( 411 | Downloading https://files.pythonhosted.org/packages/0c/ce/0574933eb3a02eade4af331579b3cc8840e59e5d757ee67d243f08fd9b75/mmcv-1.2.5.tar.gz#sha256=8d2f3fb61b75cbcfbc0e9756e905d6b8678a3a14ca9f8603726ca86846c41b28 412 | Best match: mmcv 1.2.5 413 | Processing mmcv-1.2.5.tar.gz 414 | Writing /tmp/easy_install-f2kyyeji/mmcv-1.2.5/setup.cfg 415 | Running mmcv-1.2.5/setup.py -q bdist_egg --dist-dir /tmp/easy_install-f2kyyeji/mmcv-1.2.5/egg-dist-tmp-oqa8zdgb 416 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools. 417 | warnings.warn( 418 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/build_py.py:153: SetuptoolsDeprecationWarning: Installing 'mmcv.model_zoo' as data is deprecated, please list it in `packages`. 419 | !! 
420 | 421 | 422 | ############################ 423 | # Package would be ignored # 424 | ############################ 425 | Python recognizes 'mmcv.model_zoo' as an importable package, 426 | but it is not listed in the `packages` configuration of setuptools. 427 | 428 | 'mmcv.model_zoo' has been automatically added to the distribution only 429 | because it may contain data files, but this behavior is likely to change 430 | in future versions of setuptools (and therefore is considered deprecated). 431 | 432 | Please make sure that 'mmcv.model_zoo' is included as a package by using 433 | the `packages` configuration field or the proper discovery methods 434 | (for example by using `find_namespace_packages(...)`/`find_namespace:` 435 | instead of `find_packages(...)`/`find:`). 436 | 437 | You can read more about "package discovery" and "data files" on setuptools 438 | documentation page. 439 | 440 | 441 | !! 442 | 443 | check.warn(importable) 444 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/build_py.py:153: SetuptoolsDeprecationWarning: Installing 'mmcv.ops.csrc' as data is deprecated, please list it in `packages`. 445 | !! 446 | 447 | 448 | ############################ 449 | # Package would be ignored # 450 | ############################ 451 | Python recognizes 'mmcv.ops.csrc' as an importable package, 452 | but it is not listed in the `packages` configuration of setuptools. 453 | 454 | 'mmcv.ops.csrc' has been automatically added to the distribution only 455 | because it may contain data files, but this behavior is likely to change 456 | in future versions of setuptools (and therefore is considered deprecated). 457 | 458 | Please make sure that 'mmcv.ops.csrc' is included as a package by using 459 | the `packages` configuration field or the proper discovery methods 460 | (for example by using `find_namespace_packages(...)`/`find_namespace:` 461 | instead of `find_packages(...)`/`find:`). 
462 | 463 | You can read more about "package discovery" and "data files" on setuptools 464 | documentation page. 465 | 466 | 467 | !! 468 | 469 | check.warn(importable) 470 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/build_py.py:153: SetuptoolsDeprecationWarning: Installing 'mmcv.ops.csrc.parrots' as data is deprecated, please list it in `packages`. 471 | !! 472 | 473 | 474 | ############################ 475 | # Package would be ignored # 476 | ############################ 477 | Python recognizes 'mmcv.ops.csrc.parrots' as an importable package, 478 | but it is not listed in the `packages` configuration of setuptools. 479 | 480 | 'mmcv.ops.csrc.parrots' has been automatically added to the distribution only 481 | because it may contain data files, but this behavior is likely to change 482 | in future versions of setuptools (and therefore is considered deprecated). 483 | 484 | Please make sure that 'mmcv.ops.csrc.parrots' is included as a package by using 485 | the `packages` configuration field or the proper discovery methods 486 | (for example by using `find_namespace_packages(...)`/`find_namespace:` 487 | instead of `find_packages(...)`/`find:`). 488 | 489 | You can read more about "package discovery" and "data files" on setuptools 490 | documentation page. 491 | 492 | 493 | !! 494 | 495 | check.warn(importable) 496 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/build_py.py:153: SetuptoolsDeprecationWarning: Installing 'mmcv.ops.csrc.pytorch' as data is deprecated, please list it in `packages`. 497 | !! 498 | 499 | 500 | ############################ 501 | # Package would be ignored # 502 | ############################ 503 | Python recognizes 'mmcv.ops.csrc.pytorch' as an importable package, 504 | but it is not listed in the `packages` configuration of setuptools. 
505 | 506 | 'mmcv.ops.csrc.pytorch' has been automatically added to the distribution only 507 | because it may contain data files, but this behavior is likely to change 508 | in future versions of setuptools (and therefore is considered deprecated). 509 | 510 | Please make sure that 'mmcv.ops.csrc.pytorch' is included as a package by using 511 | the `packages` configuration field or the proper discovery methods 512 | (for example by using `find_namespace_packages(...)`/`find_namespace:` 513 | instead of `find_packages(...)`/`find:`). 514 | 515 | You can read more about "package discovery" and "data files" on setuptools 516 | documentation page. 517 | 518 | 519 | !! 520 | 521 | check.warn(importable) 522 | creating /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/mmcv-1.2.5-py3.8.egg 523 | Extracting mmcv-1.2.5-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 524 | Adding mmcv 1.2.5 to easy-install.pth file 525 | 526 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/mmcv-1.2.5-py3.8.egg 527 | Searching for timm 528 | Reading https://pypi.org/simple/timm/ 529 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release 530 | warnings.warn( 531 | Downloading https://files.pythonhosted.org/packages/72/ed/358a8bc5685c31c0fe7765351b202cf6a8c087893b5d2d64f63c950f8beb/timm-0.6.7-py3-none-any.whl#sha256=4bbd7a5c9ae462ec7fec3d99ffc62ac2012010d755248e3de778d50bce5f6186 532 | Best match: timm 0.6.7 533 | Processing timm-0.6.7-py3-none-any.whl 534 | Installing timm-0.6.7-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 535 | Adding timm 0.6.7 to easy-install.pth file 536 | 537 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/timm-0.6.7-py3.8.egg 538 | Searching for yapf 539 | Reading https://pypi.org/simple/yapf/ 540 | Downloading 
https://files.pythonhosted.org/packages/47/88/843c2e68f18a5879b4fbf37cb99fbabe1ffc4343b2e63191c8462235c008/yapf-0.32.0-py2.py3-none-any.whl#sha256=8fea849025584e486fd06d6ba2bed717f396080fd3cc236ba10cb97c4c51cf32 541 | Best match: yapf 0.32.0 542 | Processing yapf-0.32.0-py2.py3-none-any.whl 543 | Installing yapf-0.32.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 544 | Adding yapf 0.32.0 to easy-install.pth file 545 | Installing yapf script to /data3/***/anaconda3/envs/clrnet/bin 546 | Installing yapf-diff script to /data3/***/anaconda3/envs/clrnet/bin 547 | 548 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/yapf-0.32.0-py3.8.egg 549 | Searching for ujson==1.35 550 | Reading https://pypi.org/simple/ujson/ 551 | Downloading https://files.pythonhosted.org/packages/16/c4/79f3409bc710559015464e5f49b9879430d8f87498ecdc335899732e5377/ujson-1.35.tar.gz#sha256=f66073e5506e91d204ab0c614a148d5aa938bdbf104751be66f8ad7a222f5f86 552 | Best match: ujson 1.35 553 | Processing ujson-1.35.tar.gz 554 | Writing /tmp/easy_install-ui57loz0/ujson-1.35/setup.cfg 555 | Running ujson-1.35/setup.py -q bdist_egg --dist-dir /tmp/easy_install-ui57loz0/ujson-1.35/egg-dist-tmp-7dvt_ppm 556 | Warning: 'classifiers' should be a list, got type 'filter' 557 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools. 
558 | warnings.warn( 559 | ./lib/ultrajsonenc.c:156:23: warning: ‘g_hexChars’ is static but used in inline function ‘Buffer_AppendShortHexUnchecked’ which is not static 560 | *(outputOffset++) = g_hexChars[(value & 0x000f) >> 0]; 561 | ^ 562 | ./lib/ultrajsonenc.c:155:23: warning: ‘g_hexChars’ is static but used in inline function ‘Buffer_AppendShortHexUnchecked’ which is not static 563 | *(outputOffset++) = g_hexChars[(value & 0x00f0) >> 4]; 564 | ^ 565 | ./lib/ultrajsonenc.c:154:23: warning: ‘g_hexChars’ is static but used in inline function ‘Buffer_AppendShortHexUnchecked’ which is not static 566 | *(outputOffset++) = g_hexChars[(value & 0x0f00) >> 8]; 567 | ^ 568 | ./lib/ultrajsonenc.c:153:23: warning: ‘g_hexChars’ is static but used in inline function ‘Buffer_AppendShortHexUnchecked’ which is not static 569 | *(outputOffset++) = g_hexChars[(value & 0xf000) >> 12]; 570 | ^ 571 | ./python/objToJSON.c: In function ‘PyUnicodeToUTF8’: 572 | ./python/objToJSON.c:154:18: warning: initialization discards ‘const’ qualifier from pointer target type [-Wdiscarded-qualifiers] 573 | char *data = PyUnicode_AsUTF8AndSize(obj, &len); 574 | ^ 575 | zip_safe flag not set; analyzing archive contents... 
576 | __pycache__.ujson.cpython-38: module references __file__ 577 | creating /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ujson-1.35-py3.8-linux-x86_64.egg 578 | Extracting ujson-1.35-py3.8-linux-x86_64.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 579 | Adding ujson 1.35 to easy-install.pth file 580 | 581 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ujson-1.35-py3.8-linux-x86_64.egg 582 | Searching for Shapely==1.7.0 583 | Reading https://pypi.org/simple/Shapely/ 584 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release 585 | warnings.warn( 586 | Downloading https://files.pythonhosted.org/packages/da/93/da69f1f278c02b4dfcf27b33e36bed43d2a6c8213d57b9a21840af4a407a/Shapely-1.7.0-cp38-cp38-manylinux1_x86_64.whl#sha256=2154b9f25c5f13785cb05ce80b2c86e542bc69671193743f29c9f4c791c35db3 587 | Best match: Shapely 1.7.0 588 | Processing Shapely-1.7.0-cp38-cp38-manylinux1_x86_64.whl 589 | Installing Shapely-1.7.0-cp38-cp38-manylinux1_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 590 | Adding Shapely 1.7.0 to easy-install.pth file 591 | 592 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/Shapely-1.7.0-py3.8-linux-x86_64.egg 593 | Searching for imgaug>=0.4.0 594 | Reading https://pypi.org/simple/imgaug/ 595 | Downloading https://files.pythonhosted.org/packages/66/b1/af3142c4a85cba6da9f4ebb5ff4e21e2616309552caca5e8acefe9840622/imgaug-0.4.0-py2.py3-none-any.whl#sha256=ce61e65b4eb7405fc62c1b0a79d2fa92fd47f763aaecb65152d29243592111f9 596 | Best match: imgaug 0.4.0 597 | Processing imgaug-0.4.0-py2.py3-none-any.whl 598 | Installing imgaug-0.4.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 599 | Adding imgaug 0.4.0 to easy-install.pth file 600 | 601 | Installed 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/imgaug-0.4.0-py3.8.egg 602 | Searching for p_tqdm 603 | Reading https://pypi.org/simple/p_tqdm/ 604 | Downloading https://files.pythonhosted.org/packages/b9/c4/ce6abe2fa3868b1ea9216a81522a9ece36f47bdbb966f8f31f76e2967178/p_tqdm-1.3.3.tar.gz#sha256=8b9316d8bae43279e03ea01c8849422e5072a931f1b94b79890b0da426802c6e 605 | Best match: p-tqdm 1.3.3 606 | Processing p_tqdm-1.3.3.tar.gz 607 | Writing /tmp/easy_install-wruk1oj0/p_tqdm-1.3.3/setup.cfg 608 | Running p_tqdm-1.3.3/setup.py -q bdist_egg --dist-dir /tmp/easy_install-wruk1oj0/p_tqdm-1.3.3/egg-dist-tmp-mqacu6ap 609 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/dist.py:771: UserWarning: Usage of dash-separated 'description-file' will not be supported in future versions. Please use the underscore name 'description_file' instead 610 | warnings.warn( 611 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools. 612 | warnings.warn( 613 | zip_safe flag not set; analyzing archive contents... 
614 | Moving p_tqdm-1.3.3-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 615 | Adding p-tqdm 1.3.3 to easy-install.pth file 616 | 617 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/p_tqdm-1.3.3-py3.8.egg 618 | Searching for tqdm 619 | Reading https://pypi.org/simple/tqdm/ 620 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release 621 | warnings.warn( 622 | Downloading https://files.pythonhosted.org/packages/8a/c4/d15f1e627fff25443ded77ea70a7b5532d6371498f9285d44d62587e209c/tqdm-4.64.0-py2.py3-none-any.whl#sha256=74a2cdefe14d11442cedf3ba4e21a3b84ff9a2dbdc6cfae2c34addb2a14a5ea6 623 | Best match: tqdm 4.64.0 624 | Processing tqdm-4.64.0-py2.py3-none-any.whl 625 | Installing tqdm-4.64.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 626 | Adding tqdm 4.64.0 to easy-install.pth file 627 | Installing tqdm script to /data3/***/anaconda3/envs/clrnet/bin 628 | 629 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/tqdm-4.64.0-py3.8.egg 630 | Searching for scikit-image 631 | Reading https://pypi.org/simple/scikit-image/ 632 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.7.2 is an invalid version and will not be supported in a future release 633 | warnings.warn( 634 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.8.0 is an invalid version and will not be supported in a future release 635 | warnings.warn( 636 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.8.1 is an invalid version and will not be supported in a future release 637 | warnings.warn( 638 | 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.8.2 is an invalid version and will not be supported in a future release 639 | warnings.warn( 640 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.9.0 is an invalid version and will not be supported in a future release 641 | warnings.warn( 642 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.9.1 is an invalid version and will not be supported in a future release 643 | warnings.warn( 644 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.9.3 is an invalid version and will not be supported in a future release 645 | warnings.warn( 646 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.10.0 is an invalid version and will not be supported in a future release 647 | warnings.warn( 648 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.10.1 is an invalid version and will not be supported in a future release 649 | warnings.warn( 650 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.11.2 is an invalid version and will not be supported in a future release 651 | warnings.warn( 652 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.11.3 is an invalid version and will not be supported in a future release 653 | warnings.warn( 654 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.12.0 is an 
invalid version and will not be supported in a future release 655 | warnings.warn( 656 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.12.1 is an invalid version and will not be supported in a future release 657 | warnings.warn( 658 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.12.2 is an invalid version and will not be supported in a future release 659 | warnings.warn( 660 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.12.3 is an invalid version and will not be supported in a future release 661 | warnings.warn( 662 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.13.0 is an invalid version and will not be supported in a future release 663 | warnings.warn( 664 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.13.1 is an invalid version and will not be supported in a future release 665 | warnings.warn( 666 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.14.0 is an invalid version and will not be supported in a future release 667 | warnings.warn( 668 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.14.1 is an invalid version and will not be supported in a future release 669 | warnings.warn( 670 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.14.2 is an invalid version and will not be supported in a future release 671 | warnings.warn( 672 | 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.14.3 is an invalid version and will not be supported in a future release 673 | warnings.warn( 674 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.14.5 is an invalid version and will not be supported in a future release 675 | warnings.warn( 676 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.15.0 is an invalid version and will not be supported in a future release 677 | warnings.warn( 678 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.16.2 is an invalid version and will not be supported in a future release 679 | warnings.warn( 680 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.17.1 is an invalid version and will not be supported in a future release 681 | warnings.warn( 682 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.17.2 is an invalid version and will not be supported in a future release 683 | warnings.warn( 684 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.18.0 is an invalid version and will not be supported in a future release 685 | warnings.warn( 686 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.18.1 is an invalid version and will not be supported in a future release 687 | warnings.warn( 688 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.18.2 is an 
invalid version and will not be supported in a future release 689 | warnings.warn( 690 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.18.3 is an invalid version and will not be supported in a future release 691 | warnings.warn( 692 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.19.0rc0 is an invalid version and will not be supported in a future release 693 | warnings.warn( 694 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.19.0 is an invalid version and will not be supported in a future release 695 | warnings.warn( 696 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.19.1 is an invalid version and will not be supported in a future release 697 | warnings.warn( 698 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.19.2 is an invalid version and will not be supported in a future release 699 | warnings.warn( 700 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.19.3 is an invalid version and will not be supported in a future release 701 | warnings.warn( 702 | Downloading https://files.pythonhosted.org/packages/96/11/878ee6757f75835c396fbdd934ca8e1a1681553ac0925fbf77065c9618e5/scikit_image-0.19.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=e207c6ce5ce121d7d9b9d2b61b9adca57d1abed112c902d8ffbfdc20fb42c12b 703 | Best match: scikit-image 0.19.3 704 | Processing scikit_image-0.19.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 705 | Installing scikit_image-0.19.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 706 | Adding scikit-image 0.19.3 to easy-install.pth file 707 | Installing skivi script to /data3/***/anaconda3/envs/clrnet/bin 708 | 709 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/scikit_image-0.19.3-py3.8-linux-x86_64.egg 710 | Searching for pytorch_warmup 711 | Reading https://pypi.org/simple/pytorch_warmup/ 712 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: warmup-0.0.3 is an invalid version and will not be supported in a future release 713 | warnings.warn( 714 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: warmup-0.0.4 is an invalid version and will not be supported in a future release 715 | warnings.warn( 716 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: warmup-0.1.0 is an invalid version and will not be supported in a future release 717 | warnings.warn( 718 | Downloading https://files.pythonhosted.org/packages/19/95/a9fb4ddf493a3df4a0ca947cc03ab452c569e5c5d5c66e31bfbdcfd2ae99/pytorch-warmup-0.1.0.tar.gz#sha256=6462c34b6942a61df54bca72c3ed7c8df8ef7089ade1f39724ec15e47022fb45 719 | Best match: pytorch-warmup 0.1.0 720 | Processing pytorch-warmup-0.1.0.tar.gz 721 | Writing /tmp/easy_install-m9mhpmsj/pytorch-warmup-0.1.0/setup.cfg 722 | Running pytorch-warmup-0.1.0/setup.py -q bdist_egg --dist-dir /tmp/easy_install-m9mhpmsj/pytorch-warmup-0.1.0/egg-dist-tmp-4vd6i_sa 723 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools. 724 | warnings.warn( 725 | zip_safe flag not set; analyzing archive contents... 
726 | Moving pytorch_warmup-0.1.0-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 727 | Adding pytorch-warmup 0.1.0 to easy-install.pth file 728 | 729 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pytorch_warmup-0.1.0-py3.8.egg 730 | Searching for opencv-python 731 | Reading https://pypi.org/simple/opencv-python/ 732 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.0.14 is an invalid version and will not be supported in a future release 733 | warnings.warn( 734 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release 735 | warnings.warn( 736 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.10.37 is an invalid version and will not be supported in a future release 737 | warnings.warn( 738 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.11.39 is an invalid version and will not be supported in a future release 739 | warnings.warn( 740 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.11.41 is an invalid version and will not be supported in a future release 741 | warnings.warn( 742 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.11.43 is an invalid version and will not be supported in a future release 743 | warnings.warn( 744 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.11.45 is an invalid version and will not be supported in a future release 745 | 
warnings.warn( 746 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.13.47 is an invalid version and will not be supported in a future release 747 | warnings.warn( 748 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.15.55 is an invalid version and will not be supported in a future release 749 | warnings.warn( 750 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.16.57 is an invalid version and will not be supported in a future release 751 | warnings.warn( 752 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.16.59 is an invalid version and will not be supported in a future release 753 | warnings.warn( 754 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.17.61 is an invalid version and will not be supported in a future release 755 | warnings.warn( 756 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.17.63 is an invalid version and will not be supported in a future release 757 | warnings.warn( 758 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.18.65 is an invalid version and will not be supported in a future release 759 | warnings.warn( 760 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.3.0.38 is an invalid version and will not be supported in a future release 761 | warnings.warn( 762 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: 
PkgResourcesDeprecationWarning: python-4.4.0.40 is an invalid version and will not be supported in a future release 763 | warnings.warn( 764 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.4.0.42 is an invalid version and will not be supported in a future release 765 | warnings.warn( 766 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.4.0.44 is an invalid version and will not be supported in a future release 767 | warnings.warn( 768 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.4.0.46 is an invalid version and will not be supported in a future release 769 | warnings.warn( 770 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.1.48 is an invalid version and will not be supported in a future release 771 | warnings.warn( 772 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.3.56 is an invalid version and will not be supported in a future release 773 | warnings.warn( 774 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.4.58 is an invalid version and will not be supported in a future release 775 | warnings.warn( 776 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.4.60 is an invalid version and will not be supported in a future release 777 | warnings.warn( 778 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.5.62 is an invalid version and will not be supported in a future release 779 
| warnings.warn( 780 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.5.64 is an invalid version and will not be supported in a future release 781 | warnings.warn( 782 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.6.0.66 is an invalid version and will not be supported in a future release 783 | warnings.warn( 784 | Downloading https://files.pythonhosted.org/packages/af/bf/8d189a5c43460f6b5c8eb81ead8732e94b9f73ef8d9abba9e8f5a61a6531/opencv_python-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=dbdc84a9b4ea2cbae33861652d25093944b9959279200b7ae0badd32439f74de 785 | Best match: opencv-python 4.6.0.66 786 | Processing opencv_python-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 787 | Installing opencv_python-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 788 | Adding opencv-python 4.6.0.66 to easy-install.pth file 789 | 790 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/opencv_python-4.6.0.66-py3.8-linux-x86_64.egg 791 | Searching for sklearn 792 | Reading https://pypi.org/simple/sklearn/ 793 | Downloading https://files.pythonhosted.org/packages/1e/7a/dbb3be0ce9bd5c8b7e3d87328e79063f8b263b2b1bfa4774cb1147bfcd3f/sklearn-0.0.tar.gz#sha256=e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31 794 | Best match: sklearn 0.0 795 | Processing sklearn-0.0.tar.gz 796 | Writing /tmp/easy_install-0b3r3r4j/sklearn-0.0/setup.cfg 797 | Running sklearn-0.0/setup.py -q bdist_egg --dist-dir /tmp/easy_install-0b3r3r4j/sklearn-0.0/egg-dist-tmp-px69reo0 798 | file wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141.py (for module wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141) not found 799 | 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools. 800 | warnings.warn( 801 | file wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141.py (for module wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141) not found 802 | file wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141.py (for module wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141) not found 803 | warning: install_lib: 'build/lib' does not exist -- no Python modules to install 804 | 805 | creating /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/sklearn-0.0-py3.8.egg 806 | Extracting sklearn-0.0-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 807 | Adding sklearn 0.0 to easy-install.pth file 808 | 809 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/sklearn-0.0-py3.8.egg 810 | Searching for addict 811 | Reading https://pypi.org/simple/addict/ 812 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release 813 | warnings.warn( 814 | Downloading https://files.pythonhosted.org/packages/6a/00/b08f23b7d7e1e14ce01419a467b583edbb93c6cdb8654e54a9cc579cd61f/addict-2.4.0-py3-none-any.whl#sha256=249bb56bbfd3cdc2a004ea0ff4c2b6ddc84d53bc2194761636eb314d5cfa5dfc 815 | Best match: addict 2.4.0 816 | Processing addict-2.4.0-py3-none-any.whl 817 | Installing addict-2.4.0-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 818 | Adding addict 2.4.0 to easy-install.pth file 819 | 820 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/addict-2.4.0-py3.8.egg 821 | Searching for pandas 822 | Reading https://pypi.org/simple/pandas/ 823 | Downloading 
https://files.pythonhosted.org/packages/76/59/3451a9898bf236a86494829a889ea901930dc1c493f403912192158e4390/pandas-1.5.0rc0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=0a3cfd8777f3eb9fa6df12ea6f4aea338a9205507eb12d8dca3b3f00cf4ffe17 824 | Best match: pandas 1.5.0rc0 825 | Processing pandas-1.5.0rc0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 826 | Installing pandas-1.5.0rc0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 827 | Adding pandas 1.5.0rc0 to easy-install.pth file 828 | 829 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pandas-1.5.0rc0-py3.8-linux-x86_64.egg 830 | Searching for opencv-python-headless>=4.1.1 831 | Reading https://pypi.org/simple/opencv-python-headless/ 832 | Download error on https://pypi.org/simple/opencv-python-headless/: The read operation timed out -- Some packages may not be found! 833 | Couldn't find index page for 'opencv-python-headless' (maybe misspelled?) 
834 | Scanning index of all packages (this may take a while) 835 | Reading https://pypi.org/simple/ 836 | No local packages or working download links found for opencv-python-headless>=4.1.1 837 | error: Could not find suitable distribution for Requirement.parse('opencv-python-headless>=4.1.1') 838 | (clrnet) [***@ai02 clrnet]$ pip list 839 | Package Version Editable project location 840 | ----------------- --------- -------------------------------- 841 | addict 2.4.0 842 | albumentations 0.4.6 843 | certifi 2022.6.15 844 | clrnet 1.0 /data3/***/project/temp/clrnet 845 | imgaug 0.4.0 846 | mmcv 1.2.5 847 | numpy 1.23.2 848 | opencv-python 4.6.0.66 849 | p-tqdm 1.3.3 850 | pandas 1.5.0rc0 851 | pathspec 0.9.0 852 | Pillow 9.2.0 853 | pip 22.1.2 854 | ptflops 0.6.9 855 | pytorch-warmup 0.1.0 856 | scikit-image 0.19.3 857 | setuptools 63.4.1 858 | Shapely 1.7.0 859 | sklearn 0.0 860 | timm 0.6.7 861 | torch 1.8.0 862 | torchvision 0.9.0 863 | tqdm 4.64.0 864 | typing_extensions 4.3.0 865 | ujson 1.35 866 | wheel 0.37.1 867 | yapf 0.32.0 868 | (clrnet) [***@ai02 clrnet]$ pip install opencv-python-headless 869 | Collecting opencv-python-headless 870 | Using cached opencv_python_headless-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (48.3 MB) 871 | Requirement already satisfied: numpy>=1.14.5 in /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages (from opencv-python-headless) (1.23.2) 872 | Installing collected packages: opencv-python-headless 873 | ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. 874 | albumentations 0.4.6 requires PyYAML, which is not installed. 875 | albumentations 0.4.6 requires scipy, which is not installed. 
876 | Successfully installed opencv-python-headless-4.6.0.66 877 | (clrnet) [***@ai02 clrnet]$ python setup.py build develop 878 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/installer.py:27: SetuptoolsDeprecationWarning: setuptools.installer is deprecated. Requirements should be satisfied by a PEP 517 installer. 879 | warnings.warn( 880 | running build 881 | running build_py 882 | running egg_info 883 | writing clrnet.egg-info/PKG-INFO 884 | writing dependency_links to clrnet.egg-info/dependency_links.txt 885 | writing requirements to clrnet.egg-info/requires.txt 886 | writing top-level names to clrnet.egg-info/top_level.txt 887 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/utils/cpp_extension.py:369: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend. 888 | warnings.warn(msg.format('we could not find ninja.')) 889 | reading manifest file 'clrnet.egg-info/SOURCES.txt' 890 | adding license file 'LICENSE' 891 | writing manifest file 'clrnet.egg-info/SOURCES.txt' 892 | running build_ext 893 | running develop 894 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools. 895 | warnings.warn( 896 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools. 897 | warnings.warn( 898 | running build_ext 899 | copying build/lib.linux-x86_64-cpython-38/clrnet/ops/nms_impl.cpython-38-x86_64-linux-gnu.so -> clrnet/ops 900 | Creating /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/clrnet.egg-link (link to .) 
901 | clrnet 1.0 is already the active version in easy-install.pth 902 | 903 | Installed /data3/***/project/temp/clrnet 904 | Processing dependencies for clrnet==1.0 905 | Searching for PyYAML 906 | Reading https://pypi.org/simple/PyYAML/ 907 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release 908 | warnings.warn( 909 | Downloading https://files.pythonhosted.org/packages/d7/42/7ad4b6d67a16229496d4f6e74201bdbebcf4bc1e87d5a70c9297d4961bd2/PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl#sha256=277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287 910 | Best match: PyYAML 6.0 911 | Processing PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl 912 | Installing PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 913 | Adding PyYAML 6.0 to easy-install.pth file 914 | 915 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/PyYAML-6.0-py3.8-linux-x86_64.egg 916 | Searching for scipy 917 | Reading https://pypi.org/simple/scipy/ 918 | Downloading https://files.pythonhosted.org/packages/f7/38/f5fad7b9a50a26a6a709549f815db7026a97d1a5652e548d5bbe587bb954/scipy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=73b704c5eea9be811919cae4caacf3180dd9212d9aed08477c1d2ba14900a9de 919 | Best match: scipy 1.9.0 920 | Processing scipy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 921 | Installing scipy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 922 | Adding scipy 1.9.0 to easy-install.pth file 923 | 924 | Installed 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/scipy-1.9.0-py3.8-linux-x86_64.egg 925 | Searching for imageio 926 | Reading https://pypi.org/simple/imageio/ 927 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.1-linux32 is an invalid version and will not be supported in a future release 928 | warnings.warn( 929 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: linux32 is an invalid version and will not be supported in a future release 930 | warnings.warn( 931 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.1-linux64 is an invalid version and will not be supported in a future release 932 | warnings.warn( 933 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: linux64 is an invalid version and will not be supported in a future release 934 | warnings.warn( 935 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.1-osx64 is an invalid version and will not be supported in a future release 936 | warnings.warn( 937 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: osx64 is an invalid version and will not be supported in a future release 938 | warnings.warn( 939 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.1-win32 is an invalid version and will not be supported in a future release 940 | warnings.warn( 941 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: win32 is an invalid version and will not be supported in a future release 942 | warnings.warn( 943 
| /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.1-win64 is an invalid version and will not be supported in a future release 944 | warnings.warn( 945 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: win64 is an invalid version and will not be supported in a future release 946 | warnings.warn( 947 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.2-linux32 is an invalid version and will not be supported in a future release 948 | warnings.warn( 949 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.2-linux64 is an invalid version and will not be supported in a future release 950 | warnings.warn( 951 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.2-osx64 is an invalid version and will not be supported in a future release 952 | warnings.warn( 953 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.2-win32 is an invalid version and will not be supported in a future release 954 | warnings.warn( 955 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.2-win64 is an invalid version and will not be supported in a future release 956 | warnings.warn( 957 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.3-linux32 is an invalid version and will not be supported in a future release 958 | warnings.warn( 959 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.3-linux64 is an invalid version and 
will not be supported in a future release 960 | warnings.warn( 961 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.3-osx64 is an invalid version and will not be supported in a future release 962 | warnings.warn( 963 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.3-win32 is an invalid version and will not be supported in a future release 964 | warnings.warn( 965 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.3-win64 is an invalid version and will not be supported in a future release 966 | warnings.warn( 967 | Downloading https://files.pythonhosted.org/packages/19/ea/fe0cf58fe26999fc22b7cc75215a6093e1f9e146b747b1c39ad7ed320e92/imageio-2.21.1-py3-none-any.whl#sha256=ea8770d082cea02de6ca5500ab3ad649a8c09832528152efd07da5c225b13722 968 | Best match: imageio 2.21.1 969 | Processing imageio-2.21.1-py3-none-any.whl 970 | Installing imageio-2.21.1-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 971 | Adding imageio 2.21.1 to easy-install.pth file 972 | Installing imageio_download_bin script to /data3/***/anaconda3/envs/clrnet/bin 973 | Installing imageio_remove_bin script to /data3/***/anaconda3/envs/clrnet/bin 974 | 975 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/imageio-2.21.1-py3.8.egg 976 | Searching for matplotlib 977 | Reading https://pypi.org/simple/matplotlib/ 978 | Downloading https://files.pythonhosted.org/packages/45/5e/5f45fbab5d259403348bebcf1aac514c64ecc0c27a8a24dca9a4a7476403/matplotlib-3.6.0rc1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl#sha256=daff0aff6e90fddd8fb53aa9e7116e0dbd48861573070c69faf7342ba2c20a44 979 | Best match: matplotlib 3.6.0rc1 980 | Processing matplotlib-3.6.0rc1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl 981 | 
Installing matplotlib-3.6.0rc1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 982 | Adding matplotlib 3.6.0rc1 to easy-install.pth file 983 | 984 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/matplotlib-3.6.0rc1-py3.8-linux-x86_64.egg 985 | Searching for six 986 | Reading https://pypi.org/simple/six/ 987 | Downloading https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl#sha256=8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 988 | Best match: six 1.16.0 989 | Processing six-1.16.0-py2.py3-none-any.whl 990 | Installing six-1.16.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 991 | Adding six 1.16.0 to easy-install.pth file 992 | 993 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/six-1.16.0-py3.8.egg 994 | Searching for pathos 995 | Reading https://pypi.org/simple/pathos/ 996 | Downloading https://files.pythonhosted.org/packages/57/78/3f3a6c8f9447a947af2f2a191f71cc870b0fa673432044f38d966f3fd9c0/pathos-0.2.9-py3-none-any.whl#sha256=1c44373d8692897d5d15a8aa3b3a442ddc0814c5e848f4ff0ded5491f34b1dac 997 | Best match: pathos 0.2.9 998 | Processing pathos-0.2.9-py3-none-any.whl 999 | Installing pathos-0.2.9-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1000 | Adding pathos 0.2.9 to easy-install.pth file 1001 | Installing portpicker script to /data3/***/anaconda3/envs/clrnet/bin 1002 | Installing pathos_connect script to /data3/***/anaconda3/envs/clrnet/bin 1003 | 1004 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pathos-0.2.9-py3.8.egg 1005 | Searching for packaging>=20.0 1006 | Reading https://pypi.org/simple/packaging/ 1007 | Downloading 
https://files.pythonhosted.org/packages/05/8e/8de486cbd03baba4deef4142bd643a3e7bbe954a784dc1bb17142572d127/packaging-21.3-py3-none-any.whl#sha256=ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522 1008 | Best match: packaging 21.3 1009 | Processing packaging-21.3-py3-none-any.whl 1010 | Installing packaging-21.3-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1011 | Adding packaging 21.3 to easy-install.pth file 1012 | 1013 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/packaging-21.3-py3.8.egg 1014 | Searching for PyWavelets>=1.1.1 1015 | Reading https://pypi.org/simple/PyWavelets/ 1016 | Downloading https://files.pythonhosted.org/packages/4b/20/04a0a3e43a45a459c2bcde756833b2eca9430729e89d65da35f70e99e997/PyWavelets-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=d7369597e1b1d125eb4b458a36cef052beed188444e55ed21445c1196008e200 1017 | Best match: PyWavelets 1.3.0 1018 | Processing PyWavelets-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 1019 | Installing PyWavelets-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1020 | Adding PyWavelets 1.3.0 to easy-install.pth file 1021 | 1022 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/PyWavelets-1.3.0-py3.8-linux-x86_64.egg 1023 | Searching for tifffile>=2019.7.26 1024 | Reading https://pypi.org/simple/tifffile/ 1025 | Downloading https://files.pythonhosted.org/packages/45/84/31e59ef72ac4149bb27ab9ccb3aa2b0d294abd97cf61dafd599bddf50a69/tifffile-2022.8.12-py3-none-any.whl#sha256=1456f9f6943c85082ef4d73f5329038826da67f70d5d513873a06f3b1598d23e 1026 | Best match: tifffile 2022.8.12 1027 | Processing tifffile-2022.8.12-py3-none-any.whl 1028 | Installing tifffile-2022.8.12-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1029 | Adding tifffile 2022.8.12 to easy-install.pth file 1030 | 
Installing lsm2bin script to /data3/***/anaconda3/envs/clrnet/bin 1031 | Installing tiff2fsspec script to /data3/***/anaconda3/envs/clrnet/bin 1032 | Installing tiffcomment script to /data3/***/anaconda3/envs/clrnet/bin 1033 | Installing tifffile script to /data3/***/anaconda3/envs/clrnet/bin 1034 | 1035 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/tifffile-2022.8.12-py3.8.egg 1036 | Searching for networkx>=2.2 1037 | Reading https://pypi.org/simple/networkx/ 1038 | Downloading https://files.pythonhosted.org/packages/be/25/5b0fc262a2f2d7d11c22cb7785edf2befc756ae076b383034e79e255eb11/networkx-2.8.6-py3-none-any.whl#sha256=2a30822761f34d56b9a370d96a4bf4827a535f5591a4078a453425caeba0c5bb 1039 | Best match: networkx 2.8.6 1040 | Processing networkx-2.8.6-py3-none-any.whl 1041 | Installing networkx-2.8.6-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1042 | Adding networkx 2.8.6 to easy-install.pth file 1043 | 1044 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/networkx-2.8.6-py3.8.egg 1045 | Searching for scikit-learn 1046 | Reading https://pypi.org/simple/scikit-learn/ 1047 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.9 is an invalid version and will not be supported in a future release 1048 | warnings.warn( 1049 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.10 is an invalid version and will not be supported in a future release 1050 | warnings.warn( 1051 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.11 is an invalid version and will not be supported in a future release 1052 | warnings.warn( 1053 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 
learn-0.12 is an invalid version and will not be supported in a future release 1054 | warnings.warn( 1055 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.12.1 is an invalid version and will not be supported in a future release 1056 | warnings.warn( 1057 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.13 is an invalid version and will not be supported in a future release 1058 | warnings.warn( 1059 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.13.1 is an invalid version and will not be supported in a future release 1060 | warnings.warn( 1061 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.14 is an invalid version and will not be supported in a future release 1062 | warnings.warn( 1063 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.14.1 is an invalid version and will not be supported in a future release 1064 | warnings.warn( 1065 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.15.0b1 is an invalid version and will not be supported in a future release 1066 | warnings.warn( 1067 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.15.0b2 is an invalid version and will not be supported in a future release 1068 | warnings.warn( 1069 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.15.0 is an invalid version and will not be supported in a future release 1070 | warnings.warn( 1071 | 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.15.1 is an invalid version and will not be supported in a future release 1072 | warnings.warn( 1073 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.15.2 is an invalid version and will not be supported in a future release 1074 | warnings.warn( 1075 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.16b1 is an invalid version and will not be supported in a future release 1076 | warnings.warn( 1077 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.16.0 is an invalid version and will not be supported in a future release 1078 | warnings.warn( 1079 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.16.1 is an invalid version and will not be supported in a future release 1080 | warnings.warn( 1081 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.17b1 is an invalid version and will not be supported in a future release 1082 | warnings.warn( 1083 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.17 is an invalid version and will not be supported in a future release 1084 | warnings.warn( 1085 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.17.1 is an invalid version and will not be supported in a future release 1086 | warnings.warn( 1087 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 
learn-0.18rc2 is an invalid version and will not be supported in a future release 1088 | warnings.warn( 1089 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.18 is an invalid version and will not be supported in a future release 1090 | warnings.warn( 1091 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.18.1 is an invalid version and will not be supported in a future release 1092 | warnings.warn( 1093 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.18.2 is an invalid version and will not be supported in a future release 1094 | warnings.warn( 1095 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.19b2 is an invalid version and will not be supported in a future release 1096 | warnings.warn( 1097 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.19.0 is an invalid version and will not be supported in a future release 1098 | warnings.warn( 1099 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.19.1 is an invalid version and will not be supported in a future release 1100 | warnings.warn( 1101 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.19.2 is an invalid version and will not be supported in a future release 1102 | warnings.warn( 1103 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20rc1 is an invalid version and will not be supported in a future release 1104 | warnings.warn( 1105 | 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20.0 is an invalid version and will not be supported in a future release 1106 | warnings.warn( 1107 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20.1 is an invalid version and will not be supported in a future release 1108 | warnings.warn( 1109 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20.2 is an invalid version and will not be supported in a future release 1110 | warnings.warn( 1111 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20.3 is an invalid version and will not be supported in a future release 1112 | warnings.warn( 1113 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20.4 is an invalid version and will not be supported in a future release 1114 | warnings.warn( 1115 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.21rc2 is an invalid version and will not be supported in a future release 1116 | warnings.warn( 1117 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.21.1 is an invalid version and will not be supported in a future release 1118 | warnings.warn( 1119 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.21.2 is an invalid version and will not be supported in a future release 1120 | warnings.warn( 1121 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 
learn-0.21.3 is an invalid version and will not be supported in a future release 1122 | warnings.warn( 1123 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.22rc2.post1 is an invalid version and will not be supported in a future release 1124 | warnings.warn( 1125 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.22rc3 is an invalid version and will not be supported in a future release 1126 | warnings.warn( 1127 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.22 is an invalid version and will not be supported in a future release 1128 | warnings.warn( 1129 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.22.1 is an invalid version and will not be supported in a future release 1130 | warnings.warn( 1131 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.22.2.post1 is an invalid version and will not be supported in a future release 1132 | warnings.warn( 1133 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.23.0rc1 is an invalid version and will not be supported in a future release 1134 | warnings.warn( 1135 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.23.0 is an invalid version and will not be supported in a future release 1136 | warnings.warn( 1137 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.23.1 is an invalid version and will not be supported in a future release 1138 | warnings.warn( 1139 | 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.23.2 is an invalid version and will not be supported in a future release 1140 | warnings.warn( 1141 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.24.dev0 is an invalid version and will not be supported in a future release 1142 | warnings.warn( 1143 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.24.0rc1 is an invalid version and will not be supported in a future release 1144 | warnings.warn( 1145 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.24.0 is an invalid version and will not be supported in a future release 1146 | warnings.warn( 1147 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.24.1 is an invalid version and will not be supported in a future release 1148 | warnings.warn( 1149 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.24.2 is an invalid version and will not be supported in a future release 1150 | warnings.warn( 1151 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.0rc1 is an invalid version and will not be supported in a future release 1152 | warnings.warn( 1153 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.0rc2 is an invalid version and will not be supported in a future release 1154 | warnings.warn( 1155 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: 
PkgResourcesDeprecationWarning: learn-1.0 is an invalid version and will not be supported in a future release 1156 | warnings.warn( 1157 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.0.1 is an invalid version and will not be supported in a future release 1158 | warnings.warn( 1159 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.0.2 is an invalid version and will not be supported in a future release 1160 | warnings.warn( 1161 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.1.0rc1 is an invalid version and will not be supported in a future release 1162 | warnings.warn( 1163 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.1.0 is an invalid version and will not be supported in a future release 1164 | warnings.warn( 1165 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.1.1 is an invalid version and will not be supported in a future release 1166 | warnings.warn( 1167 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.1.2 is an invalid version and will not be supported in a future release 1168 | warnings.warn( 1169 | Downloading https://files.pythonhosted.org/packages/91/d1/50eb92222e8b2f315ec5499b97a926d271305e19e254fdced4db899647d6/scikit_learn-1.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=f94c0146bad51daef919c402a3da8c1c6162619653e1c00c92baa168fda292f2 1170 | Best match: scikit-learn 1.1.2 1171 | Processing scikit_learn-1.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 1172 | Installing 
scikit_learn-1.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1173 | Adding scikit-learn 1.1.2 to easy-install.pth file 1174 | 1175 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/scikit_learn-1.1.2-py3.8-linux-x86_64.egg 1176 | Searching for pytz>=2020.1 1177 | Reading https://pypi.org/simple/pytz/ 1178 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2004d is an invalid version and will not be supported in a future release 1179 | warnings.warn( 1180 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2005e is an invalid version and will not be supported in a future release 1181 | warnings.warn( 1182 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2005i is an invalid version and will not be supported in a future release 1183 | warnings.warn( 1184 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2005k is an invalid version and will not be supported in a future release 1185 | warnings.warn( 1186 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2005m is an invalid version and will not be supported in a future release 1187 | warnings.warn( 1188 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2006g is an invalid version and will not be supported in a future release 1189 | warnings.warn( 1190 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2006j is an invalid version and will not be supported in a future release 1191 | warnings.warn( 1192 | 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2006p is an invalid version and will not be supported in a future release 1193 | warnings.warn( 1194 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2007d is an invalid version and will not be supported in a future release 1195 | warnings.warn( 1196 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2007f is an invalid version and will not be supported in a future release 1197 | warnings.warn( 1198 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2007g is an invalid version and will not be supported in a future release 1199 | warnings.warn( 1200 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2007i is an invalid version and will not be supported in a future release 1201 | warnings.warn( 1202 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2007k is an invalid version and will not be supported in a future release 1203 | warnings.warn( 1204 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2008g is an invalid version and will not be supported in a future release 1205 | warnings.warn( 1206 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2008h is an invalid version and will not be supported in a future release 1207 | warnings.warn( 1208 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2008i is an invalid version and will not be supported in a 
future release 1209 | warnings.warn( 1210 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009d is an invalid version and will not be supported in a future release 1211 | warnings.warn( 1212 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009e is an invalid version and will not be supported in a future release 1213 | warnings.warn( 1214 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009f is an invalid version and will not be supported in a future release 1215 | warnings.warn( 1216 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009g is an invalid version and will not be supported in a future release 1217 | warnings.warn( 1218 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009i is an invalid version and will not be supported in a future release 1219 | warnings.warn( 1220 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009j is an invalid version and will not be supported in a future release 1221 | warnings.warn( 1222 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009l is an invalid version and will not be supported in a future release 1223 | warnings.warn( 1224 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009n is an invalid version and will not be supported in a future release 1225 | warnings.warn( 1226 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009p is an 
invalid version and will not be supported in a future release 1227 | warnings.warn( 1228 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009u is an invalid version and will not be supported in a future release 1229 | warnings.warn( 1230 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010e is an invalid version and will not be supported in a future release 1231 | warnings.warn( 1232 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010g is an invalid version and will not be supported in a future release 1233 | warnings.warn( 1234 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010h is an invalid version and will not be supported in a future release 1235 | warnings.warn( 1236 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010k is an invalid version and will not be supported in a future release 1237 | warnings.warn( 1238 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010l is an invalid version and will not be supported in a future release 1239 | warnings.warn( 1240 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010o is an invalid version and will not be supported in a future release 1241 | warnings.warn( 1242 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011d is an invalid version and will not be supported in a future release 1243 | warnings.warn( 1244 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: 
PkgResourcesDeprecationWarning: 2011e is an invalid version and will not be supported in a future release 1245 | warnings.warn( 1246 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011g is an invalid version and will not be supported in a future release 1247 | warnings.warn( 1248 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011h is an invalid version and will not be supported in a future release 1249 | warnings.warn( 1250 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011j is an invalid version and will not be supported in a future release 1251 | warnings.warn( 1252 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011k is an invalid version and will not be supported in a future release 1253 | warnings.warn( 1254 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011n is an invalid version and will not be supported in a future release 1255 | warnings.warn( 1256 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2012d is an invalid version and will not be supported in a future release 1257 | warnings.warn( 1258 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2012f is an invalid version and will not be supported in a future release 1259 | warnings.warn( 1260 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2012g is an invalid version and will not be supported in a future release 1261 | warnings.warn( 1262 | 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2012h is an invalid version and will not be supported in a future release 1263 | warnings.warn( 1264 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2012j is an invalid version and will not be supported in a future release 1265 | warnings.warn( 1266 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2013d is an invalid version and will not be supported in a future release 1267 | warnings.warn( 1268 | Downloading https://files.pythonhosted.org/packages/d5/50/54451e88e3da4616286029a3a17fc377de817f66a0f50e1faaee90161724/pytz-2022.2.1-py2.py3-none-any.whl#sha256=220f481bdafa09c3955dfbdddb7b57780e9a94f5127e35456a48589b9e0c0197 1269 | Best match: pytz 2022.2.1 1270 | Processing pytz-2022.2.1-py2.py3-none-any.whl 1271 | Installing pytz-2022.2.1-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1272 | Adding pytz 2022.2.1 to easy-install.pth file 1273 | 1274 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pytz-2022.2.1-py3.8.egg 1275 | Searching for python-dateutil>=2.8.1 1276 | Reading https://pypi.org/simple/python-dateutil/ 1277 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-1.4 is an invalid version and will not be supported in a future release 1278 | warnings.warn( 1279 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-1.4.1 is an invalid version and will not be supported in a future release 1280 | warnings.warn( 1281 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-1.5 is an invalid 
version and will not be supported in a future release 1282 | warnings.warn( 1283 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.1 is an invalid version and will not be supported in a future release 1284 | warnings.warn( 1285 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.2 is an invalid version and will not be supported in a future release 1286 | warnings.warn( 1287 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.3 is an invalid version and will not be supported in a future release 1288 | warnings.warn( 1289 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.4.0 is an invalid version and will not be supported in a future release 1290 | warnings.warn( 1291 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.4.1.post1 is an invalid version and will not be supported in a future release 1292 | warnings.warn( 1293 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.4.1 is an invalid version and will not be supported in a future release 1294 | warnings.warn( 1295 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.4.2 is an invalid version and will not be supported in a future release 1296 | warnings.warn( 1297 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.5.0 is an invalid version and will not be supported in a future release 1298 | warnings.warn( 1299 | 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.5.1 is an invalid version and will not be supported in a future release 1300 | warnings.warn( 1301 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.5.2 is an invalid version and will not be supported in a future release 1302 | warnings.warn( 1303 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.5.3 is an invalid version and will not be supported in a future release 1304 | warnings.warn( 1305 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.6.0 is an invalid version and will not be supported in a future release 1306 | warnings.warn( 1307 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.6.1 is an invalid version and will not be supported in a future release 1308 | warnings.warn( 1309 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.0 is an invalid version and will not be supported in a future release 1310 | warnings.warn( 1311 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.1 is an invalid version and will not be supported in a future release 1312 | warnings.warn( 1313 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.2 is an invalid version and will not be supported in a future release 1314 | warnings.warn( 1315 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: 
PkgResourcesDeprecationWarning: dateutil-2.7.3 is an invalid version and will not be supported in a future release 1316 | warnings.warn( 1317 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.4 is an invalid version and will not be supported in a future release 1318 | warnings.warn( 1319 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.5 is an invalid version and will not be supported in a future release 1320 | warnings.warn( 1321 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.8.0 is an invalid version and will not be supported in a future release 1322 | warnings.warn( 1323 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.8.1 is an invalid version and will not be supported in a future release 1324 | warnings.warn( 1325 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.8.2 is an invalid version and will not be supported in a future release 1326 | warnings.warn( 1327 | Downloading https://files.pythonhosted.org/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl#sha256=961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9 1328 | Best match: python-dateutil 2.8.2 1329 | Processing python_dateutil-2.8.2-py2.py3-none-any.whl 1330 | Installing python_dateutil-2.8.2-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1331 | Adding python-dateutil 2.8.2 to easy-install.pth file 1332 | 1333 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/python_dateutil-2.8.2-py3.8.egg 1334 | Searching for 
pyparsing>=2.2.1 1335 | Reading https://pypi.org/simple/pyparsing/ 1336 | Downloading https://files.pythonhosted.org/packages/6c/10/a7d0fa5baea8fe7b50f448ab742f26f52b80bfca85ac2be9d35cdd9a3246/pyparsing-3.0.9-py3-none-any.whl#sha256=5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc 1337 | Best match: pyparsing 3.0.9 1338 | Processing pyparsing-3.0.9-py3-none-any.whl 1339 | Installing pyparsing-3.0.9-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1340 | Adding pyparsing 3.0.9 to easy-install.pth file 1341 | 1342 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pyparsing-3.0.9-py3.8.egg 1343 | Searching for kiwisolver>=1.0.1 1344 | Reading https://pypi.org/simple/kiwisolver/ 1345 | Downloading https://files.pythonhosted.org/packages/86/7a/6b438da7534dacd232ed4e19f74f4edced2cda9494d7e6536f54edfdf4a5/kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl#sha256=2e407cb4bd5a13984a6c2c0fe1845e4e41e96f183e5e5cd4d77a857d9693494c 1346 | Best match: kiwisolver 1.4.4 1347 | Processing kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl 1348 | Installing kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1349 | Adding kiwisolver 1.4.4 to easy-install.pth file 1350 | 1351 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/kiwisolver-1.4.4-py3.8-linux-x86_64.egg 1352 | Searching for fonttools>=4.22.0 1353 | Reading https://pypi.org/simple/fonttools/ 1354 | Downloading https://files.pythonhosted.org/packages/99/e2/90421047dfd94ed78064b41345d63985c576dc98b7c87bec92a1a8cb88e5/fonttools-4.37.1-py3-none-any.whl#sha256=fff6b752e326c15756c819fe2fe7ceab69f96a1dbcfe8911d0941cdb49905007 1355 | Best match: fonttools 4.37.1 1356 | Processing fonttools-4.37.1-py3-none-any.whl 1357 | Installing fonttools-4.37.1-py3-none-any.whl to 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1358 | Adding fonttools 4.37.1 to easy-install.pth file 1359 | Installing fonttools script to /data3/***/anaconda3/envs/clrnet/bin 1360 | Installing pyftmerge script to /data3/***/anaconda3/envs/clrnet/bin 1361 | Installing pyftsubset script to /data3/***/anaconda3/envs/clrnet/bin 1362 | Installing ttx script to /data3/***/anaconda3/envs/clrnet/bin 1363 | 1364 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/fonttools-4.37.1-py3.8.egg 1365 | Searching for cycler>=0.10 1366 | Reading https://pypi.org/simple/cycler/ 1367 | Downloading https://files.pythonhosted.org/packages/5c/f9/695d6bedebd747e5eb0fe8fad57b72fdf25411273a39791cde838d5a8f51/cycler-0.11.0-py3-none-any.whl#sha256=3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3 1368 | Best match: cycler 0.11.0 1369 | Processing cycler-0.11.0-py3-none-any.whl 1370 | Installing cycler-0.11.0-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1371 | Adding cycler 0.11.0 to easy-install.pth file 1372 | 1373 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/cycler-0.11.0-py3.8.egg 1374 | Searching for contourpy>=1.0.1 1375 | Reading https://pypi.org/simple/contourpy/ 1376 | Downloading https://files.pythonhosted.org/packages/c1/e2/7cabfe4858cd8f2f71946ac5510421ade781074012666bc0627ef95f4105/contourpy-1.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=2b20bf416a693744f1ebbbd6f24e2300b613c3570ed67e518a97dc052bf3ad11 1377 | Best match: contourpy 1.0.4 1378 | Processing contourpy-1.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 1379 | Installing contourpy-1.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1380 | Adding contourpy 1.0.4 to easy-install.pth file 1381 | 1382 | Installed 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/contourpy-1.0.4-py3.8-linux-x86_64.egg 1383 | Searching for multiprocess>=0.70.13 1384 | Reading https://pypi.org/simple/multiprocess/ 1385 | Downloading https://files.pythonhosted.org/packages/4b/28/915e6d4943eab538963a4fac7c409aa4a75213f5e6b98d91334e9a284c5c/multiprocess-0.70.13-py37-none-any.whl#sha256=62e556a0c31ec7176e28aa331663ac26c276ee3536b5e9bb5e850681e7a00f11 1386 | Best match: multiprocess 0.70.13 1387 | Processing multiprocess-0.70.13-py37-none-any.whl 1388 | Installing multiprocess-0.70.13-py37-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1389 | Adding multiprocess 0.70.13 to easy-install.pth file 1390 | 1391 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/multiprocess-0.70.13-py3.8.egg 1392 | Searching for pox>=0.3.1 1393 | Reading https://pypi.org/simple/pox/ 1394 | Downloading https://files.pythonhosted.org/packages/ee/ae/47dde73e5fa582669d08d0c8284e1863d8e8353947da264c38bc2831e8bb/pox-0.3.1-py2.py3-none-any.whl#sha256=541b5c845aacb806c1364d4142003efb809d654c9ca8db82e650ee86c81e680b 1395 | Best match: pox 0.3.1 1396 | Processing pox-0.3.1-py2.py3-none-any.whl 1397 | Installing pox-0.3.1-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1398 | Adding pox 0.3.1 to easy-install.pth file 1399 | Installing pox script to /data3/***/anaconda3/envs/clrnet/bin 1400 | 1401 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pox-0.3.1-py3.8.egg 1402 | Searching for dill>=0.3.5.1 1403 | Reading https://pypi.org/simple/dill/ 1404 | Downloading https://files.pythonhosted.org/packages/12/ff/3b1a8f5d59600393506c64fa14d13afdfe6fe79ed65a18d64026fe9f8356/dill-0.3.5.1-py2.py3-none-any.whl#sha256=33501d03270bbe410c72639b350e941882a8b0fd55357580fbc873fba0c59302 1405 | Best match: dill 0.3.5.1 1406 | Processing dill-0.3.5.1-py2.py3-none-any.whl 1407 | Installing dill-0.3.5.1-py2.py3-none-any.whl to 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1408 | Adding dill 0.3.5.1 to easy-install.pth file 1409 | Installing get_objgraph script to /data3/***/anaconda3/envs/clrnet/bin 1410 | Installing undill script to /data3/***/anaconda3/envs/clrnet/bin 1411 | 1412 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/dill-0.3.5.1-py3.8.egg 1413 | Searching for ppft>=1.7.6.5 1414 | Reading https://pypi.org/simple/ppft/ 1415 | Downloading https://files.pythonhosted.org/packages/dc/e9/5060afcfad6e19f4725d59e64398e31ec9439863244b401cb3437becbbff/ppft-1.7.6.5-py2.py3-none-any.whl#sha256=07166097d7dd45af7b98859654390d579d11dadf20780f6baca4bded3f55a580 1416 | Best match: ppft 1.7.6.5 1417 | Processing ppft-1.7.6.5-py2.py3-none-any.whl 1418 | Installing ppft-1.7.6.5-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1419 | Adding ppft 1.7.6.5 to easy-install.pth file 1420 | Installing ppserver script to /data3/***/anaconda3/envs/clrnet/bin 1421 | 1422 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ppft-1.7.6.5-py3.8.egg 1423 | Searching for threadpoolctl>=2.0.0 1424 | Reading https://pypi.org/simple/threadpoolctl/ 1425 | Downloading https://files.pythonhosted.org/packages/61/cf/6e354304bcb9c6413c4e02a747b600061c21d38ba51e7e544ac7bc66aecc/threadpoolctl-3.1.0-py3-none-any.whl#sha256=8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b 1426 | Best match: threadpoolctl 3.1.0 1427 | Processing threadpoolctl-3.1.0-py3-none-any.whl 1428 | Installing threadpoolctl-3.1.0-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1429 | Adding threadpoolctl 3.1.0 to easy-install.pth file 1430 | 1431 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/threadpoolctl-3.1.0-py3.8.egg 1432 | Searching for joblib>=1.0.0 1433 | Reading https://pypi.org/simple/joblib/ 1434 | 
/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.3.2d.dev is an invalid version and will not be supported in a future release 1435 | warnings.warn( 1436 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.3.2e.dev is an invalid version and will not be supported in a future release 1437 | warnings.warn( 1438 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.3.2f.dev is an invalid version and will not be supported in a future release 1439 | warnings.warn( 1440 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.3.2g.dev is an invalid version and will not be supported in a future release 1441 | warnings.warn( 1442 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.7.0d is an invalid version and will not be supported in a future release 1443 | warnings.warn( 1444 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: r1 is an invalid version and will not be supported in a future release 1445 | warnings.warn( 1446 | Downloading https://files.pythonhosted.org/packages/3e/d5/0163eb0cfa0b673aa4fe1cd3ea9d8a81ea0f32e50807b0c295871e4aab2e/joblib-1.1.0-py2.py3-none-any.whl#sha256=f21f109b3c7ff9d95f8387f752d0d9c34a02aa2f7060c2135f465da0e5160ff6 1447 | Best match: joblib 1.1.0 1448 | Processing joblib-1.1.0-py2.py3-none-any.whl 1449 | Installing joblib-1.1.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1450 | Adding joblib 1.1.0 to easy-install.pth file 1451 | 1452 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/joblib-1.1.0-py3.8.egg 1453 | Searching for 
ptflops==0.6.9 1454 | Best match: ptflops 0.6.9 1455 | Processing ptflops-0.6.9-py3.8.egg 1456 | ptflops 0.6.9 is already the active version in easy-install.pth 1457 | 1458 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ptflops-0.6.9-py3.8.egg 1459 | Searching for pathspec==0.9.0 1460 | Best match: pathspec 0.9.0 1461 | Processing pathspec-0.9.0-py3.8.egg 1462 | pathspec 0.9.0 is already the active version in easy-install.pth 1463 | 1464 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pathspec-0.9.0-py3.8.egg 1465 | Searching for albumentations==0.4.6 1466 | Best match: albumentations 0.4.6 1467 | Processing albumentations-0.4.6-py3.8.egg 1468 | albumentations 0.4.6 is already the active version in easy-install.pth 1469 | 1470 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/albumentations-0.4.6-py3.8.egg 1471 | Searching for mmcv==1.2.5 1472 | Best match: mmcv 1.2.5 1473 | Processing mmcv-1.2.5-py3.8.egg 1474 | mmcv 1.2.5 is already the active version in easy-install.pth 1475 | 1476 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/mmcv-1.2.5-py3.8.egg 1477 | Searching for timm==0.6.7 1478 | Best match: timm 0.6.7 1479 | Processing timm-0.6.7-py3.8.egg 1480 | timm 0.6.7 is already the active version in easy-install.pth 1481 | 1482 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/timm-0.6.7-py3.8.egg 1483 | Searching for yapf==0.32.0 1484 | Best match: yapf 0.32.0 1485 | Processing yapf-0.32.0-py3.8.egg 1486 | yapf 0.32.0 is already the active version in easy-install.pth 1487 | Installing yapf script to /data3/***/anaconda3/envs/clrnet/bin 1488 | Installing yapf-diff script to /data3/***/anaconda3/envs/clrnet/bin 1489 | 1490 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/yapf-0.32.0-py3.8.egg 1491 | Searching for ujson==1.35 1492 | Best match: ujson 1.35 1493 | Processing ujson-1.35-py3.8-linux-x86_64.egg 1494 | ujson 1.35 is already the 
active version in easy-install.pth 1495 | 1496 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ujson-1.35-py3.8-linux-x86_64.egg 1497 | Searching for Shapely==1.7.0 1498 | Best match: Shapely 1.7.0 1499 | Processing Shapely-1.7.0-py3.8-linux-x86_64.egg 1500 | Shapely 1.7.0 is already the active version in easy-install.pth 1501 | 1502 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/Shapely-1.7.0-py3.8-linux-x86_64.egg 1503 | Searching for imgaug==0.4.0 1504 | Best match: imgaug 0.4.0 1505 | Processing imgaug-0.4.0-py3.8.egg 1506 | imgaug 0.4.0 is already the active version in easy-install.pth 1507 | 1508 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/imgaug-0.4.0-py3.8.egg 1509 | Searching for p-tqdm==1.3.3 1510 | Best match: p-tqdm 1.3.3 1511 | Processing p_tqdm-1.3.3-py3.8.egg 1512 | p-tqdm 1.3.3 is already the active version in easy-install.pth 1513 | 1514 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/p_tqdm-1.3.3-py3.8.egg 1515 | Searching for tqdm==4.64.0 1516 | Best match: tqdm 4.64.0 1517 | Processing tqdm-4.64.0-py3.8.egg 1518 | tqdm 4.64.0 is already the active version in easy-install.pth 1519 | Installing tqdm script to /data3/***/anaconda3/envs/clrnet/bin 1520 | 1521 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/tqdm-4.64.0-py3.8.egg 1522 | Searching for scikit-image==0.19.3 1523 | Best match: scikit-image 0.19.3 1524 | Processing scikit_image-0.19.3-py3.8-linux-x86_64.egg 1525 | scikit-image 0.19.3 is already the active version in easy-install.pth 1526 | Installing skivi script to /data3/***/anaconda3/envs/clrnet/bin 1527 | 1528 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/scikit_image-0.19.3-py3.8-linux-x86_64.egg 1529 | Searching for pytorch-warmup==0.1.0 1530 | Best match: pytorch-warmup 0.1.0 1531 | Processing pytorch_warmup-0.1.0-py3.8.egg 1532 | pytorch-warmup 0.1.0 is already the active version in easy-install.pth 
1533 | 1534 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pytorch_warmup-0.1.0-py3.8.egg 1535 | Searching for opencv-python==4.6.0.66 1536 | Best match: opencv-python 4.6.0.66 1537 | Processing opencv_python-4.6.0.66-py3.8-linux-x86_64.egg 1538 | opencv-python 4.6.0.66 is already the active version in easy-install.pth 1539 | 1540 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/opencv_python-4.6.0.66-py3.8-linux-x86_64.egg 1541 | Searching for sklearn==0.0 1542 | Best match: sklearn 0.0 1543 | Processing sklearn-0.0-py3.8.egg 1544 | sklearn 0.0 is already the active version in easy-install.pth 1545 | 1546 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/sklearn-0.0-py3.8.egg 1547 | Searching for addict==2.4.0 1548 | Best match: addict 2.4.0 1549 | Processing addict-2.4.0-py3.8.egg 1550 | addict 2.4.0 is already the active version in easy-install.pth 1551 | 1552 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/addict-2.4.0-py3.8.egg 1553 | Searching for pandas==1.5.0rc0 1554 | Best match: pandas 1.5.0rc0 1555 | Processing pandas-1.5.0rc0-py3.8-linux-x86_64.egg 1556 | pandas 1.5.0rc0 is already the active version in easy-install.pth 1557 | 1558 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pandas-1.5.0rc0-py3.8-linux-x86_64.egg 1559 | Searching for torchvision==0.9.0 1560 | Best match: torchvision 0.9.0 1561 | Adding torchvision 0.9.0 to easy-install.pth file 1562 | 1563 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1564 | Searching for torch==1.8.0 1565 | Best match: torch 1.8.0 1566 | Adding torch 1.8.0 to easy-install.pth file 1567 | Installing convert-caffe2-to-onnx script to /data3/***/anaconda3/envs/clrnet/bin 1568 | Installing convert-onnx-to-caffe2 script to /data3/***/anaconda3/envs/clrnet/bin 1569 | 1570 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1571 | Searching for opencv-python-headless==4.6.0.66 1572 
| Best match: opencv-python-headless 4.6.0.66 1573 | Adding opencv-python-headless 4.6.0.66 to easy-install.pth file 1574 | 1575 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1576 | Searching for numpy==1.23.2 1577 | Best match: numpy 1.23.2 1578 | Adding numpy 1.23.2 to easy-install.pth file 1579 | Installing f2py script to /data3/***/anaconda3/envs/clrnet/bin 1580 | Installing f2py3 script to /data3/***/anaconda3/envs/clrnet/bin 1581 | Installing f2py3.8 script to /data3/***/anaconda3/envs/clrnet/bin 1582 | 1583 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1584 | Searching for Pillow==9.2.0 1585 | Best match: Pillow 9.2.0 1586 | Adding Pillow 9.2.0 to easy-install.pth file 1587 | 1588 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1589 | Searching for typing-extensions==4.3.0 1590 | Best match: typing-extensions 4.3.0 1591 | Adding typing-extensions 4.3.0 to easy-install.pth file 1592 | 1593 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages 1594 | Finished processing dependencies for clrnet==1.0 1595 | (clrnet) [***@ai02 clrnet]$ pip list 1596 | Package Version Editable project location 1597 | ---------------------- --------- -------------------------------- 1598 | addict 2.4.0 1599 | albumentations 0.4.6 1600 | certifi 2022.6.15 1601 | clrnet 1.0 /data3/***/project/temp/clrnet 1602 | contourpy 1.0.4 1603 | cycler 0.11.0 1604 | dill 0.3.5.1 1605 | fonttools 4.37.1 1606 | imageio 2.21.1 1607 | imgaug 0.4.0 1608 | joblib 1.1.0 1609 | kiwisolver 1.4.4 1610 | matplotlib 3.6.0rc1 1611 | mmcv 1.2.5 1612 | multiprocess 0.70.13 1613 | networkx 2.8.6 1614 | numpy 1.23.2 1615 | opencv-python 4.6.0.66 1616 | opencv-python-headless 4.6.0.66 1617 | p-tqdm 1.3.3 1618 | packaging 21.3 1619 | pandas 1.5.0rc0 1620 | pathos 0.2.9 1621 | pathspec 0.9.0 1622 | Pillow 9.2.0 1623 | pip 22.1.2 1624 | pox 0.3.1 1625 | ppft 1.7.6.5 1626 | ptflops 0.6.9 1627 | pyparsing 3.0.9 1628 | python-dateutil 
2.8.2 1629 | pytorch-warmup 0.1.0 1630 | pytz 2022.2.1 1631 | PyWavelets 1.3.0 1632 | PyYAML 6.0 1633 | scikit-image 0.19.3 1634 | scikit-learn 1.1.2 1635 | scipy 1.9.0 1636 | setuptools 63.4.1 1637 | Shapely 1.7.0 1638 | six 1.16.0 1639 | sklearn 0.0 1640 | threadpoolctl 3.1.0 1641 | tifffile 2022.8.12 1642 | timm 0.6.7 1643 | torch 1.8.0 1644 | torchvision 0.9.0 1645 | tqdm 4.64.0 1646 | typing_extensions 4.3.0 1647 | ujson 1.35 1648 | wheel 0.37.1 1649 | yapf 0.32.0 1650 | (clrnet) [***@ai02 clrnet]$ git clone https://github.com/xuanandsix/CLRNet-onnxruntime-and-tensorrt-demo.git 1651 | Cloning into 'CLRNet-onnxruntime-and-tensorrt-demo'... 1652 | remote: Enumerating objects: 68, done. 1653 | remote: Counting objects: 100% (68/68), done. 1654 | remote: Compressing objects: 100% (65/65), done. 1655 | remote: Total 68 (delta 36), reused 0 (delta 0), pack-reused 0 1656 | Unpacking objects: 100% (68/68), done. 1657 | (clrnet) [***@ai02 clrnet]$ ls 1658 | build clrnet clrnet.egg-info CLRNet-onnxruntime-and-tensorrt-demo configs LICENSE main.py README.md requirements.txt setup.py tools 1659 | (clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/clr_head.py ./clrnet/models/heads/ 1660 | (clrnet) [***@ai02 clrnet]$ mkdir modules 1661 | (clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/grid_sample.py ./modules/ 1662 | (clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/torch2onnx.py ./ 1663 | (clrnet) [***@ai02 clrnet]$ python torch2onnx.py configs/clrnet/clr_resnet18_tusimple.py --load_from tusimple_r18.pth 1664 | pretrained model: https://download.pytorch.org/models/resnet18-5c106cde.pth 1665 | Traceback (most recent call last): 1666 | File "torch2onnx.py", line 49, in <module> 1667 | main() 1668 | File "torch2onnx.py", line 22, in main 1669 | state_dict = torch.load(args.load_from, map_location='cpu')['net'] 1670 | File "/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/serialization.py", line 579, in load 1671 | with 
_open_file_like(f, 'rb') as opened_file: 1672 | File "/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/serialization.py", line 230, in _open_file_like 1673 | return _open_file(name_or_buffer, mode) 1674 | File "/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/serialization.py", line 211, in __init__ 1675 | super(_open_file, self).__init__(open(name, mode)) 1676 | FileNotFoundError: [Errno 2] No such file or directory: 'tusimple_r18.pth' 1677 | (clrnet) [***@ai02 clrnet]$ wget https://github.com/Turoad/CLRNet/releases/download/models/tusimple_r18.pth.zip 1678 | --2022-08-25 13:33:12-- https://github.com/Turoad/CLRNet/releases/download/models/tusimple_r18.pth.zip 1679 | Resolving github.com (github.com)... 20.205.243.166 1680 | Connecting to github.com (github.com)|20.205.243.166|:443... connected. 1681 | Unable to establish SSL connection. 1682 | (clrnet) [***@ai02 clrnet]$ cp /data3/***/project/tusimple_r18.pth.zip ./ 1683 | (clrnet) [***@ai02 clrnet]$ unzip tusimple_r18.pth.zip 1684 | Archive: tusimple_r18.pth.zip 1685 | inflating: tusimple_r18.pth 1686 | (clrnet) [***@ai02 clrnet]$ python torch2onnx.py configs/clrnet/clr_resnet18_tusimple.py --load_from tusimple_r18.pth 1687 | pretrained model: https://download.pytorch.org/models/resnet18-5c106cde.pth 1688 | /data3/***/project/temp/clrnet/modules/grid_sample.py:35: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! 1689 | assert n == gn 1690 | /data3/***/project/temp/clrnet/modules/grid_sample.py:69: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect. 
  x0 = torch.where(x0 < 0, torch.tensor(0, device=x0.device), x0)
/data3/***/project/temp/clrnet/modules/grid_sample.py:70: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=x0.device), x0)
/data3/***/project/temp/clrnet/modules/grid_sample.py:70: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=x0.device), x0)
/data3/***/project/temp/clrnet/modules/grid_sample.py:71: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  x1 = torch.where(x1 < 0, torch.tensor(0, device=x1.device), x1)
/data3/***/project/temp/clrnet/modules/grid_sample.py:72: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=x1.device), x1)
/data3/***/project/temp/clrnet/modules/grid_sample.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=x1.device), x1)
/data3/***/project/temp/clrnet/modules/grid_sample.py:73: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  y0 = torch.where(y0 < 0, torch.tensor(0, device=y0.device), y0)
/data3/***/project/temp/clrnet/modules/grid_sample.py:74: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=y0.device), y0)
/data3/***/project/temp/clrnet/modules/grid_sample.py:74: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=y0.device), y0)
/data3/***/project/temp/clrnet/modules/grid_sample.py:75: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  y1 = torch.where(y1 < 0, torch.tensor(0, device=y1.device), y1)
/data3/***/project/temp/clrnet/modules/grid_sample.py:76: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=y1.device), y1)
/data3/***/project/temp/clrnet/modules/grid_sample.py:76: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=y1.device), y1)
(clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/test.jpg ./
(clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/demo_onnx ./
cp: cannot stat 'CLRNet-onnxruntime-and-tensorrt-demo/demo_onnx': No such file or directory
(clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/demo_onnx.py ./
(clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/demo_onnx_new.py ./
(clrnet) [***@ai02 clrnet]$ pytho demo_onnx.py
bash: pytho: command not found...
(clrnet) [***@ai02 clrnet]$ python demo_onnx.py
Traceback (most recent call last):
  File "demo_onnx.py", line 6, in <module>
    import onnxruntime
ModuleNotFoundError: No module named 'onnxruntime'
(clrnet) [***@ai02 clrnet]$ pip install onnxruntime
Collecting onnxruntime
  Using cached onnxruntime-1.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.5 MB)
Collecting protobuf
  Downloading protobuf-4.21.5-cp37-abi3-manylinux2014_x86_64.whl (408 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 408.4/408.4 kB 499.8 kB/s eta 0:00:00
Collecting coloredlogs
  Using cached coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
Collecting flatbuffers
  Downloading flatbuffers-2.0.7-py2.py3-none-any.whl (26 kB)
Requirement already satisfied: numpy>=1.21.0 in /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages (from onnxruntime) (1.23.2)
Collecting sympy
  Downloading sympy-1.11-py3-none-any.whl (6.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.5/6.5 MB 6.1 MB/s eta 0:00:00
Requirement already satisfied: packaging in /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/packaging-21.3-py3.8.egg (from onnxruntime) (21.3)
Collecting humanfriendly>=9.1
  Using cached humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pyparsing-3.0.9-py3.8.egg (from packaging->onnxruntime) (3.0.9)
Collecting mpmath>=0.19
  Using cached mpmath-1.2.1-py3-none-any.whl (532 kB)
Installing collected packages: mpmath, flatbuffers, sympy, protobuf, humanfriendly, coloredlogs, onnxruntime
Successfully installed coloredlogs-15.0.1 flatbuffers-2.0.7 humanfriendly-10.0 mpmath-1.2.1 onnxruntime-1.12.1 protobuf-4.21.5 sympy-1.11
(clrnet) [***@ai02 clrnet]$ ls
build   clrnet.egg-info                       configs           demo_onnx.py  main.py  README.md         setup.py  tools          tusimple_r18.onnx  tusimple_r18.pth.zip
clrnet  CLRNet-onnxruntime-and-tensorrt-demo  demo_onnx_new.py  LICENSE       modules  requirements.txt  test.jpg  torch2onnx.py  tusimple_r18.pth
(clrnet) [***@ai02 clrnet]$ python demo_onnx.py
demo_onnx.py:141: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  )[::-1].cumprod()[::-1]).astype(np.bool))
(clrnet) [***@ai02 clrnet]$ ls
build   clrnet.egg-info                       configs           demo_onnx.py  main.py  output_onnx.png  requirements.txt  test.jpg  torch2onnx.py      tusimple_r18.pth
clrnet  CLRNet-onnxruntime-and-tensorrt-demo  demo_onnx_new.py  LICENSE       modules  README.md        setup.py          tools     tusimple_r18.onnx  tusimple_r18.pth.zip
(clrnet) [***@ai02 clrnet]$ python demo_onnx_new.py
demo_onnx_new.py:186: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  )[::-1].cumprod()[::-1]).astype(np.bool))
Done!
(clrnet) [***@ai02 clrnet]$

--------------------------------------------------------------------------------
/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/CLRNet-onnxruntime-and-tensorrt-demo/07a47e666d4bfca9a4c60a777c62c23828fca663/test.jpg
--------------------------------------------------------------------------------
/torch2onnx.py:
--------------------------------------------------------------------------------
import os
import cv2
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import argparse
import numpy as np
import random
from clrnet.utils.config import Config
from clrnet.datasets import build_dataloader
from clrnet.models.registry import build_net

def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    cfg.load_from = args.load_from
    net = build_net(cfg)
    net = net.cpu()

    from collections import OrderedDict
    new_state_dict = OrderedDict()
    state_dict = torch.load(args.load_from, map_location='cpu')['net']

    for k, v in state_dict.items():
        namekey = k[7:]  # strip the leading 'module.' prefix (assumes every key has it)
        new_state_dict[namekey] = v

    net.load_state_dict(new_state_dict)

    dummy_input = torch.randn(1, 3, 320, 800, device='cpu')
    torch.onnx.export(net, dummy_input, 'tusimple_r18.onnx',
                      export_params=True, opset_version=11, do_constant_folding=True,
                      input_names=['input'])

def parse_args():
    parser = argparse.ArgumentParser(description='Export a CLRNet checkpoint to ONNX')
    parser.add_argument('config', help='config file path')

    parser.add_argument('--load_from',
                        default=None,
                        help='the checkpoint file to load from')

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    main()

# Usage: python torch2onnx.py configs/clrnet/clr_resnet18_tusimple.py --load_from tusimple_r18.pth

--------------------------------------------------------------------------------
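The key-renaming loop in torch2onnx.py drops the first seven characters of every checkpoint key. That is correct only when each key really starts with `module.`, the prefix PyTorch adds when a model is wrapped in `DataParallel`. Below is a minimal sketch of a guarded variant. `strip_prefix` is a hypothetical helper name, not part of this repository or of CLRNet, and plain integers stand in for tensors so the example runs without torch.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Hypothetical helper for illustration; keys without the prefix are kept unchanged.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in checkpoint: one key with the prefix, one without.
example = OrderedDict([
    ("module.backbone.conv1.weight", 1),
    ("heads.cls.bias", 2),
])
print(list(strip_prefix(example)))
# ['backbone.conv1.weight', 'heads.cls.bias']
```

With a helper like this, the loop in `main()` could be replaced by `net.load_state_dict(strip_prefix(state_dict))`, and a checkpoint saved without `DataParallel` would load without its key names being truncated.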