├── README.md
├── clr_head.py
├── demo_onnx.py
├── demo_onnx_new.py
├── demo_trt.py
├── grid_sample.py
├── imgs
│   ├── output_onnx.png
│   └── output_trt.png
├── my_log
│   └── test_onnx.log
├── test.jpg
└── torch2onnx.py
/README.md:
--------------------------------------------------------------------------------
1 | # CLRNet-onnxruntime-and-tensorrt-demo
2 | This is the onnxruntime and tensorrt inference code for CLRNet: Cross Layer Refinement Network for Lane Detection (CVPR 2022). Official code: https://github.com/Turoad/CLRNet
3 |
4 | [skip to onnx demo](#onnx) and [skip to tensorrt demo](#tensorrt)
5 |
6 | ## Note
7 | 1、Implements the grid_sampler op with ONNX-supported operators (`bilinear_grid_sample` in grid_sample.py) so the model can be exported to ONNX.
8 | 2、With this code you can convert the model to ONNX and run an onnxruntime demo. A newer demo (demo_onnx_new.py) does all post-processing in NumPy only: it is easier to deploy, but its NMS is slower.
9 | 3、The modifications described below affect the training code; use this code for ONNX inference only.
10 | 4、It consists of two main parts: model inference and post-processing.
11 | 5、Supports conversion to a TensorRT engine.
12 |
13 | ## convert and test onnx
14 | 1、Clone the official code and install its environment by following https://github.com/Turoad/CLRNet
15 | 2、Clone this repository.
16 | 3、cp clr_head.py to your_path/CLRNet/clrnet/models/heads/
17 | 4、mkdir your_path/CLRNet/modules/ and cp grid_sample.py to your_path/CLRNet/modules/
18 | 5、cp torch2onnx.py to your_path/CLRNet/
19 | 6、For example, run
20 | ```
21 | python torch2onnx.py configs/clrnet/clr_resnet18_tusimple.py --load_from tusimple_r18.pth
22 | ```
23 | My deployment log is here: https://github.com/xuanandsix/CLRNet-onnxruntime-and-tensorrt-demo/blob/main/my_log/test_onnx.log
24 | 7、cp test.jpg to your_path/CLRNet/ and run
25 |
26 | 1) NMS based on torch and CLRNet's compiled NMS extension (`clrnet.ops.nms_impl`, runs on CUDA).
27 | ```
28 | python demo_onnx.py
29 | ```
30 | 2) NMS based on numpy.
31 | ```
32 | python demo_onnx_new.py
33 | ```
34 | ## onnx output
35 |
36 |
37 |
38 | ## convert and test tensorrt
39 | TensorRT 8.4 or newer is required. This code was implemented with TensorRT-8.4.0.6.
40 | The *GatherElements*, *IShuffleLayer*, and *`is_tensor()` failed* errors have been resolved.
41 | 1、Install TensorRT and the *trtexec* conversion tool.
42 | 2、Install *polygraphy* to help modify the ONNX model. You can install it with
43 | ```
44 | pip install nvidia-pyindex
45 | pip install polygraphy
46 | pip install onnx-graphsurgeon
47 | ```
48 | and run
49 | ```
50 | polygraphy surgeon sanitize your_path/tusimple_r18.onnx --fold-constants --output your_path/tusimple_r18.onnx
51 | ```
52 | 3、 convert to tensorrt and get tusimple_r18.engine
53 | ```
54 | ./trtexec --onnx=your_path/tusimple_r18.onnx --saveEngine=your_path/tusimple_r18.engine --verbose
55 | ```
56 | 4、test tensorrt demo.
57 | ```
58 | python demo_trt.py
59 | ```
60 |
61 | ## tensorrt output
62 |
63 |
64 |
65 |
66 | ## TO DO
67 | - [x] Optimize post-processing.
68 | - [x] Tensorrt demo.
69 | - [ ] Cpp demo.
70 |
71 |
--------------------------------------------------------------------------------
/clr_head.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import cv2
4 | import torch
5 | import numpy as np
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from mmcv.cnn import ConvModule
9 |
10 | from clrnet.utils.lane import Lane
11 | from clrnet.models.losses.focal_loss import FocalLoss
12 | from clrnet.models.losses.accuracy import accuracy
13 | from clrnet.ops import nms
14 |
15 | from clrnet.models.utils.roi_gather import ROIGather, LinearModule
16 | from clrnet.models.utils.seg_decoder import SegDecoder
17 | from clrnet.models.utils.dynamic_assign import assign
18 | from clrnet.models.losses.lineiou_loss import liou_loss
19 | from ..registry import HEADS
20 |
21 | from modules.grid_sample import bilinear_grid_sample
22 |
@HEADS.register_module
class CLRHead(nn.Module):
    """CLRNet lane-detection head, modified for ONNX export.

    Prior features are pooled with ``bilinear_grid_sample`` (from
    ``modules.grid_sample``) instead of ``F.grid_sample``. The fixed sampling
    tensors ``sample_x_indexs``, ``prior_feat_ys`` and ``prior_ys`` are plain
    attributes rather than registered buffers, so ``.to()`` / ``.cuda()``
    does not move them.
    NOTE(review): forward() combines these tensors with device tensors, so
    running on a GPU presumably requires moving them by hand — confirm.
    """

    def __init__(self,
                 num_points=72,
                 prior_feat_channels=64,
                 fc_hidden_dim=64,
                 num_priors=192,
                 num_fc=2,
                 refine_layers=3,
                 sample_points=36,
                 cfg=None):
        super(CLRHead, self).__init__()
        self.cfg = cfg
        self.img_w = self.cfg.img_w
        self.img_h = self.cfg.img_h
        self.n_strips = num_points - 1
        self.n_offsets = num_points
        self.num_priors = num_priors
        self.sample_points = sample_points
        self.refine_layers = refine_layers
        self.fc_hidden_dim = fc_hidden_dim

        # Row indices (in strips) at which prior features are sampled, their
        # normalized y coordinates (bottom-up), and the y of every offset.
        # Kept as plain tensors, not buffers (see class docstring).
        self.sample_x_indexs = (torch.linspace(
            0, 1, steps=self.sample_points, dtype=torch.float32) *
                                self.n_strips).long()
        self.prior_feat_ys = torch.flip(
            (1 - self.sample_x_indexs.float() / self.n_strips), dims=[-1])
        self.prior_ys = torch.linspace(1, 0, steps=self.n_offsets,
                                       dtype=torch.float32)

        self.prior_feat_channels = prior_feat_channels

        self._init_prior_embeddings()
        init_priors, priors_on_featmap = self.generate_priors_from_embeddings()
        self.register_buffer(name='priors', tensor=init_priors)
        self.register_buffer(name='priors_on_featmap', tensor=priors_on_featmap)

        # generate xys for feature map
        self.seg_decoder = SegDecoder(self.img_h, self.img_w,
                                      self.cfg.num_classes,
                                      self.prior_feat_channels,
                                      self.refine_layers)

        reg_modules = list()
        cls_modules = list()
        for _ in range(num_fc):
            reg_modules += [*LinearModule(self.fc_hidden_dim)]
            cls_modules += [*LinearModule(self.fc_hidden_dim)]
        self.reg_modules = nn.ModuleList(reg_modules)
        self.cls_modules = nn.ModuleList(cls_modules)

        self.roi_gather = ROIGather(self.prior_feat_channels, self.num_priors,
                                    self.sample_points, self.fc_hidden_dim,
                                    self.refine_layers)

        self.reg_layers = nn.Linear(
            self.fc_hidden_dim, self.n_offsets + 1 + 2 +
            1)  # n offsets + 1 length + start_x + start_y + theta
        self.cls_layers = nn.Linear(self.fc_hidden_dim, 2)

        # The segmentation class weights (cfg.bg_weight for class 0) are built
        # in loss() so no criterion module/state_dict entry is added here.

        # init the weights here
        self.init_weights()

    def init_weights(self):
        '''Initialize the final classification/regression layers with N(0, 1e-3).'''
        for m in self.cls_layers.parameters():
            nn.init.normal_(m, mean=0., std=1e-3)

        for m in self.reg_layers.parameters():
            nn.init.normal_(m, mean=0., std=1e-3)

    def pool_prior_features(self, batch_features, num_priors, prior_xs):
        '''
        Pool prior features from a feature map.
        Args:
            batch_features (Tensor): Input feature maps, shape: (B, C, H, W)
            num_priors (int): number of priors per image.
            prior_xs (Tensor): normalized x of each prior at the sampled rows;
                viewed as (B, num_priors, -1, 1).
        Returns:
            Tensor of shape (B * num_priors, prior_feat_channels,
            sample_points, 1).
        '''

        batch_size = batch_features.shape[0]

        prior_xs = prior_xs.view(batch_size, num_priors, -1, 1)
        prior_ys = self.prior_feat_ys.repeat(batch_size * num_priors).view(
            batch_size, num_priors, -1, 1)

        # map [0, 1] coordinates to the [-1, 1] grid convention
        prior_xs = prior_xs * 2. - 1.
        prior_ys = prior_ys * 2. - 1.
        grid = torch.cat((prior_xs, prior_ys), dim=-1)

        # ONNX-exportable replacement for F.grid_sample(..., align_corners=True)
        feature = bilinear_grid_sample(batch_features, grid,
                                       align_corners=True).permute(0, 2, 1, 3)

        feature = feature.reshape(batch_size * num_priors,
                                  self.prior_feat_channels, self.sample_points,
                                  1)
        return feature

    def generate_priors_from_embeddings(self):
        '''
        Build priors (num_priors, 6 + n_offsets) from the prior embeddings.

        Returns:
            (priors, priors_on_featmap): the full priors and their x
            coordinates at the `sample_x_indexs` rows.
        '''
        predictions = self.prior_embeddings.weight  # (num_prop, 3)

        # 2 scores, 1 start_y, 1 start_x, 1 theta, 1 length, 72 coordinates, score[0] = negative prob, score[1] = positive prob
        priors = predictions.new_zeros(
            (self.num_priors, 2 + 2 + 2 + self.n_offsets), device=predictions.device)

        priors[:, 2:5] = predictions.clone()
        priors[:, 6:] = (
            priors[:, 3].unsqueeze(1).clone().repeat(1, self.n_offsets) *
            (self.img_w - 1) +
            ((1 - self.prior_ys.repeat(self.num_priors, 1) -
              priors[:, 2].unsqueeze(1).clone().repeat(1, self.n_offsets)) *
             self.img_h / torch.tan(priors[:, 4].unsqueeze(1).clone().repeat(
                 1, self.n_offsets) * math.pi + 1e-5))) / (self.img_w - 1)

        # init priors on feature map
        priors_on_featmap = priors.clone()[..., 6 + self.sample_x_indexs]

        return priors, priors_on_featmap

    def _init_prior_embeddings(self):
        '''Create and initialize the (num_priors, 3) prior embedding table.'''
        # [start_y, start_x, theta] -> all normalize
        self.prior_embeddings = nn.Embedding(self.num_priors, 3)

        bottom_priors_nums = self.num_priors * 3 // 4
        left_priors_nums, _ = self.num_priors // 8, self.num_priors // 8

        strip_size = 0.5 / (left_priors_nums // 2 - 1)
        bottom_strip_size = 1 / (bottom_priors_nums // 4 + 1)
        for i in range(left_priors_nums):
            nn.init.constant_(self.prior_embeddings.weight[i, 0],
                              (i // 2) * strip_size)
            nn.init.constant_(self.prior_embeddings.weight[i, 1], 0.)
            nn.init.constant_(self.prior_embeddings.weight[i, 2],
                              0.16 if i % 2 == 0 else 0.32)

        for i in range(left_priors_nums,
                       left_priors_nums + bottom_priors_nums):
            nn.init.constant_(self.prior_embeddings.weight[i, 0], 0.)
            nn.init.constant_(self.prior_embeddings.weight[i, 1],
                              ((i - left_priors_nums) // 4 + 1) *
                              bottom_strip_size)
            nn.init.constant_(self.prior_embeddings.weight[i, 2],
                              0.2 * (i % 4 + 1))

        for i in range(left_priors_nums + bottom_priors_nums, self.num_priors):
            nn.init.constant_(
                self.prior_embeddings.weight[i, 0],
                ((i - left_priors_nums - bottom_priors_nums) // 2) *
                strip_size)
            nn.init.constant_(self.prior_embeddings.weight[i, 1], 1.)
            nn.init.constant_(self.prior_embeddings.weight[i, 2],
                              0.68 if i % 2 == 0 else 0.84)

    def forward(self, x, **kwargs):
        '''
        Take pyramid features as input to perform Cross Layer Refinement and finally output the prediction lanes.
        Each feature is a 4D tensor.
        Args:
            x: input features (list[Tensor])
        Return:
            training: the loss dict from self.loss (needs kwargs['batch']).
            inference: the last stage's predictions, (B, num_priors, 6 + n_offsets).
        '''
        batch_features = list(x[len(x) - self.refine_layers:])
        batch_features.reverse()
        batch_size = batch_features[-1].shape[0]

        if self.training:
            self.priors, self.priors_on_featmap = self.generate_priors_from_embeddings()

        priors = self.priors.repeat(batch_size, 1, 1)
        priors_on_featmap = self.priors_on_featmap.repeat(batch_size, 1, 1)

        def tran_tensor(t):
            # (B, N) -> (B, N, n_offsets) by repeating along a new last axis
            return t.unsqueeze(2).clone().repeat(1, 1, self.n_offsets)

        predictions_lists = []

        # iterative refine
        prior_features_stages = []
        for stage in range(self.refine_layers):
            num_priors = priors_on_featmap.shape[1]
            prior_xs = torch.flip(priors_on_featmap, dims=[2])

            batch_prior_features = self.pool_prior_features(
                batch_features[stage], num_priors, prior_xs)
            prior_features_stages.append(batch_prior_features)

            fc_features = self.roi_gather(prior_features_stages,
                                          batch_features[stage], stage)

            fc_features = fc_features.view(num_priors, batch_size,
                                           -1).reshape(batch_size * num_priors,
                                                       self.fc_hidden_dim)

            cls_features = fc_features.clone()
            reg_features = fc_features.clone()
            for cls_layer in self.cls_modules:
                cls_features = cls_layer(cls_features)
            for reg_layer in self.reg_modules:
                reg_features = reg_layer(reg_features)

            cls_logits = self.cls_layers(cls_features)
            reg = self.reg_layers(reg_features)

            cls_logits = cls_logits.reshape(
                batch_size, -1, cls_logits.shape[1])  # (B, num_priors, 2)
            reg = reg.reshape(batch_size, -1, reg.shape[1])

            predictions = priors.clone()
            predictions[:, :, :2] = cls_logits

            predictions[:, :, 2:5] += reg[:, :, :3]  # also reg theta angle here
            predictions[:, :, 5] = reg[:, :, 3]  # length

            predictions[..., 6:] = (
                tran_tensor(predictions[..., 3]) * (self.img_w - 1) +
                ((1 - self.prior_ys.repeat(batch_size, num_priors, 1) -
                  tran_tensor(predictions[..., 2])) * self.img_h /
                 torch.tan(tran_tensor(predictions[..., 4]) * math.pi + 1e-5))) / (self.img_w - 1)

            prediction_lines = predictions.clone()
            predictions[..., 6:] += reg[..., 4:]

            predictions_lists.append(predictions)

            if stage != self.refine_layers - 1:
                priors = prediction_lines.detach().clone()
                priors_on_featmap = priors[..., 6 + self.sample_x_indexs]

        if self.training:
            seg_features = torch.cat([
                F.interpolate(feature,
                              size=[
                                  batch_features[-1].shape[2],
                                  batch_features[-1].shape[3]
                              ],
                              mode='bilinear',
                              align_corners=False)
                for feature in batch_features
            ],
                                     dim=1)
            seg = self.seg_decoder(seg_features)
            output = {'predictions_lists': predictions_lists, 'seg': seg}
            return self.loss(output, kwargs['batch'])

        return predictions_lists[-1]

    def predictions_to_pred(self, predictions):
        '''
        Convert predictions to internal Lane structure for evaluation.
        '''
        # Local copy: reassigning self.prior_ys here would change the dtype and
        # device used by later forward()/generate_priors_from_embeddings calls.
        prior_ys = self.prior_ys.to(predictions.device).double()

        lanes = []
        for lane in predictions:
            lane_xs = lane[6:]  # normalized value
            start = min(max(0, int(round(lane[2].item() * self.n_strips))),
                        self.n_strips)
            length = int(round(lane[5].item()))
            end = start + length - 1
            end = min(end, len(prior_ys) - 1)
            # if the prediction does not start at the bottom of the image,
            # extend its prediction until the x is outside the image
            # (builtin bool: np.bool was removed in NumPy 1.24)
            mask = ~((((lane_xs[:start] >= 0.) & (lane_xs[:start] <= 1.)
                       ).cpu().numpy()[::-1].cumprod()[::-1]).astype(bool))
            lane_xs[end + 1:] = -2
            lane_xs[:start][mask] = -2
            lane_ys = prior_ys[lane_xs >= 0]
            lane_xs = lane_xs[lane_xs >= 0]
            lane_xs = lane_xs.flip(0).double()
            lane_ys = lane_ys.flip(0)

            lane_ys = (lane_ys * (self.cfg.ori_img_h - self.cfg.cut_height) +
                       self.cfg.cut_height) / self.cfg.ori_img_h
            if len(lane_xs) <= 1:
                continue
            points = torch.stack(
                (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)),
                dim=1).squeeze(2)

            lane = Lane(points=points.cpu().numpy(),
                        metadata={
                            'start_x': lane[3],
                            'start_y': lane[2],
                            'conf': lane[1]
                        })
            lanes.append(lane)
        return lanes

    def loss(self,
             output,
             batch,
             cls_loss_weight=2.,
             xyt_loss_weight=0.5,
             iou_loss_weight=2.,
             seg_loss_weight=1.):
        '''
        Compute the training loss over all refinement stages.

        Args:
            output: dict with 'predictions_lists' (one tensor per stage) and
                'seg' (segmentation logits).
            batch: dict with 'lane_line' targets and 'seg' labels.
            *_loss_weight: defaults, overridden by same-named cfg keys.
        Returns:
            dict with 'loss' and 'loss_stats' (including per-stage accuracy).
        '''
        if self.cfg.haskey('cls_loss_weight'):
            cls_loss_weight = self.cfg.cls_loss_weight
        if self.cfg.haskey('xyt_loss_weight'):
            xyt_loss_weight = self.cfg.xyt_loss_weight
        if self.cfg.haskey('iou_loss_weight'):
            iou_loss_weight = self.cfg.iou_loss_weight
        if self.cfg.haskey('seg_loss_weight'):
            seg_loss_weight = self.cfg.seg_loss_weight

        predictions_lists = output['predictions_lists']
        targets = batch['lane_line'].clone()
        cls_criterion = FocalLoss(alpha=0.25, gamma=2.)
        cls_loss = 0
        reg_xytl_loss = 0
        iou_loss = 0
        cls_acc = []

        for stage in range(self.refine_layers):
            # Reset per stage so `stage_{i}_acc` reports stage i only
            # (it previously averaged over all earlier stages too).
            cls_acc_stage = []
            predictions_list = predictions_lists[stage]
            for predictions, target in zip(predictions_list, targets):
                target = target[target[:, 1] == 1]

                if len(target) == 0:
                    # If there are no targets, all predictions have to be negatives (i.e., 0 confidence)
                    cls_target = predictions.new_zeros(predictions.shape[0]).long()
                    cls_pred = predictions[:, :2]
                    cls_loss = cls_loss + cls_criterion(
                        cls_pred, cls_target).sum()
                    continue

                with torch.no_grad():
                    matched_row_inds, matched_col_inds = assign(
                        predictions, target, self.img_w, self.img_h)

                # classification targets
                cls_target = predictions.new_zeros(predictions.shape[0]).long()
                cls_target[matched_row_inds] = 1
                cls_pred = predictions[:, :2]

                # regression targets -> [start_y, start_x, theta] (all transformed to absolute values), only on matched pairs
                reg_yxtl = predictions[matched_row_inds, 2:6]
                reg_yxtl[:, 0] *= self.n_strips
                reg_yxtl[:, 1] *= (self.img_w - 1)
                reg_yxtl[:, 2] *= 180
                reg_yxtl[:, 3] *= self.n_strips

                target_yxtl = target[matched_col_inds, 2:6].clone()

                # regression targets -> S coordinates (all transformed to absolute values)
                reg_pred = predictions[matched_row_inds, 6:]
                reg_pred *= (self.img_w - 1)
                reg_targets = target[matched_col_inds, 6:].clone()

                with torch.no_grad():
                    predictions_starts = torch.clamp(
                        (predictions[matched_row_inds, 2] *
                         self.n_strips).round().long(), 0,
                        self.n_strips)  # ensure the predictions starts is valid
                    target_starts = (target[matched_col_inds, 2] *
                                     self.n_strips).round().long()
                    target_yxtl[:, -1] -= (predictions_starts - target_starts
                                           )  # reg length

                # Loss calculation
                cls_loss = cls_loss + cls_criterion(cls_pred, cls_target).sum(
                ) / target.shape[0]

                target_yxtl[:, 0] *= self.n_strips
                target_yxtl[:, 2] *= 180
                reg_xytl_loss = reg_xytl_loss + F.smooth_l1_loss(
                    reg_yxtl, target_yxtl,
                    reduction='none').mean()

                iou_loss = iou_loss + liou_loss(
                    reg_pred, reg_targets,
                    self.img_w, length=15)

                # calculate acc
                cls_accuracy = accuracy(cls_pred, cls_target)
                cls_acc_stage.append(cls_accuracy)

            # a stage with no targets in the whole batch has no accuracy samples
            cls_acc.append(sum(cls_acc_stage) / len(cls_acc_stage)
                           if cls_acc_stage else 0.)

        # extra segmentation loss: weighted NLL with class 0 weighted by
        # cfg.bg_weight and cfg.ignore_label ignored (previously referenced a
        # never-created self.criterion and raised AttributeError).
        seg_weights = output['seg'].new_ones(self.cfg.num_classes)
        seg_weights[0] = self.cfg.bg_weight
        seg_loss = F.nll_loss(F.log_softmax(output['seg'], dim=1),
                              batch['seg'].long(),
                              weight=seg_weights,
                              ignore_index=self.cfg.ignore_label)

        cls_loss /= (len(targets) * self.refine_layers)
        reg_xytl_loss /= (len(targets) * self.refine_layers)
        iou_loss /= (len(targets) * self.refine_layers)

        loss = cls_loss * cls_loss_weight + reg_xytl_loss * xyt_loss_weight \
            + seg_loss * seg_loss_weight + iou_loss * iou_loss_weight

        return_value = {
            'loss': loss,
            'loss_stats': {
                'loss': loss,
                'cls_loss': cls_loss * cls_loss_weight,
                'reg_xytl_loss': reg_xytl_loss * xyt_loss_weight,
                'seg_loss': seg_loss * seg_loss_weight,
                'iou_loss': iou_loss * iou_loss_weight
            }
        }

        for i in range(self.refine_layers):
            return_value['loss_stats']['stage_{}_acc'.format(i)] = cls_acc[i]

        return return_value

    def get_lanes(self, output, as_lanes=True):
        '''
        Convert model output to lanes.

        Args:
            output: iterable of per-image prediction tensors.
            as_lanes: if True convert to Lane objects, else return the kept
                prediction rows (with column 5 rounded to strips).
        Returns:
            list with one entry per image.
        '''
        threshold = self.cfg.test_parameters.conf_threshold

        decoded = []
        for predictions in output:
            # filter out the conf lower than conf threshold
            scores = F.softmax(predictions[:, :2], dim=1)[:, 1]
            keep_inds = scores >= threshold
            predictions = predictions[keep_inds]
            scores = scores[keep_inds]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue
            # NMS input: drop column 4 (theta), length -> strips, xs -> pixels
            nms_predictions = predictions.detach().clone()
            nms_predictions = torch.cat(
                [nms_predictions[..., :4], nms_predictions[..., 5:]], dim=-1)
            nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
            nms_predictions[...,
                            5:] = nms_predictions[..., 5:] * (self.img_w - 1)

            keep, num_to_keep, _ = nms(
                nms_predictions,
                scores,
                overlap=self.cfg.test_parameters.nms_thres,
                top_k=self.cfg.max_lanes)

            keep = keep[:num_to_keep]
            predictions = predictions[keep]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue

            predictions[:, 5] = torch.round(predictions[:, 5] * self.n_strips)
            if as_lanes:
                pred = self.predictions_to_pred(predictions)
            else:
                pred = predictions

            decoded.append(pred)

        return decoded
501 |
--------------------------------------------------------------------------------
/demo_onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import numpy as np
5 | import timeit
6 | import onnxruntime
7 | import torch
8 | from clrnet.ops import nms_impl
9 |
10 | from scipy.interpolate import InterpolatedUnivariateSpline
11 | import numpy as np
12 |
# Line colors for cv2.line, indexed by lane order. The 21 distinct entries
# below are followed by a repeat of the first nine, giving 30 entries total.
COLORS = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 255, 0), (255, 0, 255), (0, 255, 255),
    (128, 255, 0), (255, 128, 0), (128, 0, 255),
    (255, 0, 128), (0, 128, 255), (0, 255, 128),
    (128, 255, 255), (255, 128, 255), (255, 255, 128),
    (60, 180, 0), (180, 60, 0), (0, 60, 180),
    (0, 180, 60), (60, 0, 180), (180, 0, 60),
]
COLORS += COLORS[:9]
45 |
class Lane:
    """A lane polyline whose x is interpolated as a spline of y.

    ``points`` is an (N, 2) array of normalized ``(x, y)`` pairs, ordered by
    increasing y. ``to_array`` resamples the lane at image rows 710 down to
    160 (step 10) and returns pixel coordinates for a 1280x720 image.
    """

    def __init__(self, points=None, invalid_value=-2., metadata=None):
        super(Lane, self).__init__()
        self.curr_iter = 0
        self.points = points
        self.invalid_value = invalid_value

        ys_col = points[:, 1]
        xs_col = points[:, 0]
        degree = min(3, len(points) - 1)
        self.function = InterpolatedUnivariateSpline(ys_col, xs_col, k=degree)
        # small margin so the end points themselves stay valid
        self.min_y = ys_col.min() - 0.01
        self.max_y = ys_col.max() + 0.01

        self.metadata = metadata or {}

        self.sample_y = range(710, 150, -10)
        self.ori_img_w = 1280
        self.ori_img_h = 720

    def __repr__(self):
        return ''.join(('[Lane]\n', str(self.points), '\n[/Lane]'))

    def __call__(self, lane_ys):
        """Return x for each y; ys outside [min_y, max_y] get invalid_value."""
        lane_xs = self.function(lane_ys)
        outside = (lane_ys < self.min_y) | (lane_ys > self.max_y)
        lane_xs[outside] = self.invalid_value
        return lane_xs

    def to_array(self):
        """Sample the lane at ``sample_y`` rows; return (K, 2) pixel (x, y)."""
        width, height = self.ori_img_w, self.ori_img_h
        ys = np.array(self.sample_y) / float(height)
        xs = self(ys)
        inside = (xs >= 0) & (xs < 1)
        return np.stack((xs[inside] * width, ys[inside] * height), axis=1)

    def __iter__(self):
        return self

    def __next__(self):
        # Iterates over the stored points; the cursor resets when exhausted.
        if self.curr_iter >= len(self.points):
            self.curr_iter = 0
            raise StopIteration
        self.curr_iter += 1
        return self.points[self.curr_iter - 1]
96 |
def nms(boxes, scores, overlap, top_k):
    """Thin wrapper around CLRNet's compiled ``nms_impl.nms_forward``.

    Arguments are passed through positionally; the result is returned
    unchanged (callers unpack it as ``keep, num_to_keep, _``).
    """
    nms_forward = nms_impl.nms_forward
    result = nms_forward(boxes, scores, overlap, top_k)
    return result
99 |
class CLRNetDemo():
    """ONNX Runtime inference demo for CLRNet on 1280x720 images.

    Pipeline: crop the top ``cut_height`` rows, resize to the network input,
    run the ONNX model, filter by confidence, run CLRNet's compiled lane NMS
    (on CUDA tensors), convert to :class:`Lane` and draw the result.
    """

    def __init__(self, model_path):
        self.ort_session = onnxruntime.InferenceSession(model_path)
        self.conf_threshold = 0.4
        self.nms_thres = 50
        self.max_lanes = 5
        self.sample_points = 36
        self.num_points = 72
        self.n_offsets = 72
        self.n_strips = 71
        self.img_w = 1280
        self.img_h = 720
        self.ori_img_w = 1280
        self.ori_img_h = 720
        self.cut_height = 160

        self.input_width = 800
        self.input_height = 320

        self.sample_x_indexs = (np.linspace(0, 1, self.sample_points) * self.n_strips)
        self.prior_feat_ys = np.flip((1 - self.sample_x_indexs / self.n_strips))
        self.prior_ys = np.linspace(1, 0, self.n_offsets)

    def softmax(self, x, axis=None):
        """Numerically stable softmax of ``x`` along ``axis`` (all elements if None)."""
        x = x - x.max(axis=axis, keepdims=True)
        y = np.exp(x)
        return y / y.sum(axis=axis, keepdims=True)

    def predictions_to_pred(self, predictions):
        """Convert kept prediction rows into :class:`Lane` objects.

        Each row is [score0, score1, start_y, start_x, theta, length(strips),
        xs...]; xs are normalized. Lanes with fewer than two valid points are
        dropped. Note: ``lane_xs`` is a view, so rows of ``predictions`` are
        modified in place.
        """
        lanes = []
        for lane in predictions:
            lane_xs = lane[6:]  # normalized value
            start = min(max(0, int(round(lane[2].item() * self.n_strips))),
                        self.n_strips)
            length = int(round(lane[5].item()))
            end = start + length - 1
            end = min(end, len(self.prior_ys) - 1)
            # if the prediction does not start at the bottom of the image,
            # extend its prediction until the x is outside the image
            # (builtin bool: np.bool was removed in NumPy 1.24)
            mask = ~((((lane_xs[:start] >= 0.) & (lane_xs[:start] <= 1.)
                       )[::-1].cumprod()[::-1]).astype(bool))

            lane_xs[end + 1:] = -2
            lane_xs[:start][mask] = -2
            lane_ys = self.prior_ys[lane_xs >= 0]
            lane_xs = lane_xs[lane_xs >= 0]

            lane_xs = np.double(lane_xs)
            lane_xs = np.flip(lane_xs, axis=0)
            lane_ys = np.flip(lane_ys, axis=0)
            # map y from the cropped region back to the original image
            lane_ys = (lane_ys * (self.ori_img_h - self.cut_height) +
                       self.cut_height) / self.ori_img_h
            if len(lane_xs) <= 1:
                continue

            points = np.stack(
                (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)),
                axis=1).squeeze(2)

            lane = Lane(points=points,
                        metadata={
                            'start_x': lane[3],
                            'start_y': lane[2],
                            'conf': lane[1]
                        })
            lanes.append(lane)
        return lanes

    def get_lanes(self, output, as_lanes=True):
        '''
        Convert model output to lanes.

        Args:
            output: iterable of per-image prediction arrays (presumably
                (num_priors, 6 + n_offsets) each — confirm against the model).
            as_lanes: if True return Lane objects, else the kept prediction
                rows (with column 5 converted to strips).
        Returns:
            list with one entry per image.
        '''
        decoded = []
        for predictions in output:
            # filter out the conf lower than conf threshold
            scores = self.softmax(predictions[:, :2], 1)[:, 1]

            keep_inds = scores >= self.conf_threshold
            predictions = predictions[keep_inds]
            scores = scores[keep_inds]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue

            # NMS input: drop column 4 (theta), length -> strips, xs -> pixels
            nms_predictions = np.concatenate(
                [predictions[..., :4], predictions[..., 5:]], axis=-1)
            nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
            nms_predictions[..., 5:] = nms_predictions[..., 5:] * (self.img_w - 1)

            # the compiled NMS kernel expects CUDA tensors
            keep, num_to_keep, _ = nms(
                torch.tensor(nms_predictions).cuda(),
                torch.tensor(scores).cuda(),
                overlap=self.nms_thres,
                top_k=self.max_lanes)

            keep = keep[:num_to_keep].cpu().numpy()
            predictions = predictions[keep]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue

            predictions[:, 5] = np.round(predictions[:, 5] * self.n_strips)
            if as_lanes:
                decoded.append(self.predictions_to_pred(predictions))
            else:
                decoded.append(predictions)

        return decoded

    def imshow_lanes(self, img, lanes, show=False, out_file=None, width=4):
        """Draw ``lanes`` on ``img`` in place and return it.

        Lanes are colored left to right by their first drawable point. Lanes
        with no point having x > 0 and y > 0 are skipped (previously they
        raised IndexError). ``show`` and ``out_file`` are accepted but unused.
        """
        lanes_xys = []
        for lane in lanes:
            xys = [(int(x), int(y)) for x, y in lane.to_array()
                   if x > 0 and y > 0]
            if xys:
                lanes_xys.append(xys)
        lanes_xys.sort(key=lambda xys: xys[0][0])

        for idx, xys in enumerate(lanes_xys):
            color = COLORS[idx % len(COLORS)]
            for i in range(1, len(xys)):
                cv2.line(img, xys[i - 1], xys[i], color, thickness=width)
        return img

    def forward(self, img):
        """Run the full pipeline on one image and return a copy with lanes drawn.

        ``img`` is an HxWx3 array as returned by cv2.imread; its channel order
        is passed to the model unchanged.
        """
        img_ = img.copy()
        img = img[self.cut_height:, :, :]
        # interpolation must be passed by keyword: the third positional
        # argument of cv2.resize is `dst`, not the interpolation flag.
        img = cv2.resize(img, (self.input_width, self.input_height),
                         interpolation=cv2.INTER_CUBIC)

        img = img.astype(np.float32) / 255.0

        # HWC -> NCHW with N = 1
        img = np.transpose(np.float32(img[:, :, :, np.newaxis]), (3, 2, 0, 1))

        ort_inputs = {self.ort_session.get_inputs()[0].name: img}
        ort_outs = self.ort_session.run(None, ort_inputs)
        output = ort_outs[0]

        output = self.get_lanes(output)

        res = self.imshow_lanes(img_, output[0])
        return res
253 |
if __name__ == "__main__":
    # Demo entry point: run the ONNX model on ./test.jpg, write output_onnx.png.
    clr = CLRNetDemo('./tusimple_r18.onnx')
    img = cv2.imread('./test.jpg')
    if img is None:
        # cv2.imread returns None instead of raising on unreadable files
        raise FileNotFoundError("could not read image './test.jpg'")
    output = clr.forward(img)
    cv2.imwrite('output_onnx.png', output)
259 |
--------------------------------------------------------------------------------
/demo_onnx_new.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import numpy as np
5 | import onnxruntime
6 | from scipy.interpolate import InterpolatedUnivariateSpline
7 |
# Line colors for cv2.line, indexed by lane order. The 21 distinct entries
# below are followed by a repeat of the first nine, giving 30 entries total.
COLORS = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 255, 0), (255, 0, 255), (0, 255, 255),
    (128, 255, 0), (255, 128, 0), (128, 0, 255),
    (255, 0, 128), (0, 128, 255), (0, 255, 128),
    (128, 255, 255), (255, 128, 255), (255, 255, 128),
    (60, 180, 0), (180, 60, 0), (0, 60, 180),
    (0, 180, 60), (60, 0, 180), (180, 0, 60),
]
COLORS += COLORS[:9]
40 |
class Lane:
    """A lane polyline whose x is interpolated as a spline of y.

    ``points`` is an (N, 2) array of normalized ``(x, y)`` pairs, ordered by
    increasing y. ``to_array`` resamples the lane at image rows 710 down to
    160 (step 10) and returns pixel coordinates for a 1280x720 image.
    """

    def __init__(self, points=None, invalid_value=-2., metadata=None):
        super(Lane, self).__init__()
        self.curr_iter = 0
        self.points = points
        self.invalid_value = invalid_value

        ys_col = points[:, 1]
        xs_col = points[:, 0]
        degree = min(3, len(points) - 1)
        self.function = InterpolatedUnivariateSpline(ys_col, xs_col, k=degree)
        # small margin so the end points themselves stay valid
        self.min_y = ys_col.min() - 0.01
        self.max_y = ys_col.max() + 0.01

        self.metadata = metadata or {}

        self.sample_y = range(710, 150, -10)
        self.ori_img_w = 1280
        self.ori_img_h = 720

    def __repr__(self):
        return ''.join(('[Lane]\n', str(self.points), '\n[/Lane]'))

    def __call__(self, lane_ys):
        """Return x for each y; ys outside [min_y, max_y] get invalid_value."""
        lane_xs = self.function(lane_ys)
        outside = (lane_ys < self.min_y) | (lane_ys > self.max_y)
        lane_xs[outside] = self.invalid_value
        return lane_xs

    def to_array(self):
        """Sample the lane at ``sample_y`` rows; return (K, 2) pixel (x, y)."""
        width, height = self.ori_img_w, self.ori_img_h
        ys = np.array(self.sample_y) / float(height)
        xs = self(ys)
        inside = (xs >= 0) & (xs < 1)
        return np.stack((xs[inside] * width, ys[inside] * height), axis=1)

    def __iter__(self):
        return self

    def __next__(self):
        # Iterates over the stored points; the cursor resets when exhausted.
        if self.curr_iter >= len(self.points):
            self.curr_iter = 0
            raise StopIteration
        self.curr_iter += 1
        return self.points[self.curr_iter - 1]
91 |
92 |
93 |
94 | class CLRNetDemo():
def __init__(self, model_path):
    """Open the ONNX model and set the decoding constants used by the demo."""
    self.ort_session = onnxruntime.InferenceSession(model_path)

    # post-processing thresholds
    self.conf_threshold = 0.4
    self.nms_thres = 50
    self.max_lanes = 5

    # lane representation: 72 offsets -> 71 strips, 36 sampled rows
    self.sample_points = 36
    self.num_points = 72
    self.n_offsets = self.num_points
    self.n_strips = self.n_offsets - 1

    # original image geometry and the rows cropped from its top
    self.img_w, self.img_h = 1280, 720
    self.ori_img_w, self.ori_img_h = 1280, 720
    self.cut_height = 160

    # network input size
    self.input_width, self.input_height = 800, 320

    self.sample_x_indexs = np.linspace(0, 1, self.sample_points) * self.n_strips
    self.prior_feat_ys = np.flip(1 - self.sample_x_indexs / self.n_strips)
    self.prior_ys = np.linspace(1, 0, self.n_offsets)
116 |
def softmax(self, x, axis=None):
    """Numerically stable softmax of ``x`` along ``axis`` (all elements if None)."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
121 |
122 |
def Lane_nms(self, proposals, scores, overlap=50, top_k=4):
    """Greedy lane NMS in pure NumPy.

    Proposals are visited in descending score order. Each proposal not yet
    suppressed is kept and suppresses every lower-scored proposal for which
    ``self.Lane_IOU(kept, other, overlap)`` is true. At most ``top_k``
    proposals are kept (the previous ``> top_k`` test let ``top_k + 1``
    through).

    Args:
        proposals: 2-D array, one row per proposal, in the layout expected by
            ``self.Lane_IOU``.
        scores: 1-D array of scores, one per proposal row.
        overlap: distance threshold forwarded to ``self.Lane_IOU``.
        top_k: maximum number of proposals to keep.

    Returns:
        ``(keep_index, num_to_keep)``: kept row indices as Python ints
        (highest score first) and their count.
    """
    order = np.argsort(-scores)  # from big to small
    suppressed = np.zeros(len(order), dtype=bool)
    keep_index = []
    for rank, idx in enumerate(order):
        if suppressed[rank]:  # already filtered by a higher-scored lane
            continue
        keep_index.append(int(idx))
        if len(keep_index) >= top_k:
            break
        for later in range(rank + 1, len(order)):
            # suppressed lanes never suppress others, so they can be skipped
            if suppressed[later]:
                continue
            if self.Lane_IOU(proposals[idx, :], proposals[order[later], :], overlap):
                suppressed[later] = True
    return keep_index, len(keep_index)
145 |
146 | def Lane_IOU(self, parent_box, compared_box, threshold):
147 | '''
148 | calculate distance one pair of proposal lines
149 | return True if distance less than threshold
150 | '''
151 | n_offsets=72
152 | n_strips = n_offsets - 1
153 |
154 | start_a = (parent_box[2] * n_strips + 0.5).astype(int) # add 0.5 trick to make int() like round
155 | start_b = (compared_box[2] * n_strips + 0.5).astype(int)
156 | start = max(start_a,start_b)
157 | end_a = start_a + parent_box[4] - 1 + 0.5 - (((parent_box[4] - 1)<0).astype(int))
158 | end_b = start_b + compared_box[4] - 1 + 0.5 - (((compared_box[4] - 1)<0).astype(int))
159 | end = min(min(end_a,end_b),71)
160 | if (end - start)<0:
161 | return False
162 | dist = 0
163 | for i in range(5+start,5 + end.astype(int)):
164 | if i>(5+end):
165 | break
166 | if parent_box[i] < compared_box[i]:
167 | dist += compared_box[i] - parent_box[i]
168 | else:
169 | dist += parent_box[i] - compared_box[i]
170 | return dist < (threshold * (end - start + 1))
171 |
172 |
173 | def predictions_to_pred(self, predictions):
174 | lanes = []
175 | for lane in predictions:
176 | lane_xs = lane[6:] # normalized value
177 | start = min(max(0, int(round(lane[2].item() * self.n_strips))),
178 | self.n_strips)
179 | length = int(round(lane[5].item()))
180 | end = start + length - 1
181 | end = min(end, len(self.prior_ys) - 1)
182 | # end = label_end
183 | # if the prediction does not start at the bottom of the image,
184 | # extend its prediction until the x is outside the image
185 | mask = ~((((lane_xs[:start] >= 0.) & (lane_xs[:start] <= 1.)
186 | )[::-1].cumprod()[::-1]).astype(np.bool))
187 |
188 | lane_xs[end + 1:] = -2
189 | lane_xs[:start][mask] = -2
190 | lane_ys = self.prior_ys[lane_xs >= 0]
191 | lane_xs = lane_xs[lane_xs >= 0]
192 |
193 | lane_xs = np.double(lane_xs)
194 | lane_xs = np.flip(lane_xs, axis=0)
195 | lane_ys = np.flip(lane_ys, axis=0)
196 | lane_ys = (lane_ys * (self.ori_img_h - self.cut_height) +
197 | self.cut_height) / self.ori_img_h
198 | if len(lane_xs) <= 1:
199 | continue
200 |
201 | points = np.stack(
202 | (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)),
203 | axis=1).squeeze(2)
204 |
205 | lane = Lane(points=points,
206 | metadata={
207 | 'start_x': lane[3],
208 | 'start_y': lane[2],
209 | 'conf': lane[1]
210 | })
211 | lanes.append(lane)
212 | return lanes
213 |
214 | def get_lanes(self, output, as_lanes=True):
215 | '''
216 | Convert model output to lanes.
217 | '''
218 | decoded = []
219 | for predictions in output:
220 | # filter out the conf lower than conf threshold
221 | scores = self.softmax(predictions[:, :2], 1)[:, 1]
222 |
223 | keep_inds = scores >= self.conf_threshold
224 | predictions = predictions[keep_inds]
225 | scores = scores[keep_inds]
226 |
227 | if predictions.shape[0] == 0:
228 | decoded.append([])
229 | continue
230 | nms_predictions = predictions
231 |
232 | nms_predictions = np.concatenate(
233 | [nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1)
234 |
235 | nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
236 | nms_predictions[...,
237 | 5:] = nms_predictions[..., 5:] * (self.img_w - 1)
238 |
239 |
240 | keep, num_to_keep = self.Lane_nms(
241 | nms_predictions,
242 | scores,
243 | self.nms_thres,
244 | self.max_lanes)
245 |
246 | keep = keep[:num_to_keep]
247 | predictions = predictions[keep]
248 |
249 | if predictions.shape[0] == 0:
250 | decoded.append([])
251 | continue
252 |
253 | predictions[:, 5] = np.round(predictions[:, 5] * self.n_strips)
254 | pred = self.predictions_to_pred(predictions)
255 | decoded.append(pred)
256 |
257 | return decoded
258 |
259 | def imshow_lanes(self, img, lanes, show=False, out_file=None, width=4):
260 | lanes = [lane.to_array() for lane in lanes]
261 |
262 | lanes_xys = []
263 | for _, lane in enumerate(lanes):
264 | xys = []
265 | for x, y in lane:
266 | if x <= 0 or y <= 0:
267 | continue
268 | x, y = int(x), int(y)
269 | xys.append((x, y))
270 | lanes_xys.append(xys)
271 | lanes_xys.sort(key=lambda xys : xys[0][0])
272 |
273 | for idx, xys in enumerate(lanes_xys):
274 | for i in range(1, len(xys)):
275 | cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width)
276 | return img
277 |
278 | def forward(self, img):
279 | img_ = img.copy()
280 | h, w = img.shape[:2]
281 | img = img[self.cut_height:, :, :]
282 | img = cv2.resize(img, (self.input_width, self.input_height), cv2.INTER_CUBIC)
283 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
284 | img = img.astype(np.float32) / 255.0
285 |
286 | img = np.transpose(np.float32(img[:,:,:,np.newaxis]), (3,2,0,1))
287 |
288 | ort_inputs = {self.ort_session.get_inputs()[0].name: img}
289 | ort_outs = self.ort_session.run(None, ort_inputs)
290 | output = ort_outs[0]
291 |
292 | output = self.get_lanes(output)
293 | res = self.imshow_lanes(img_, output[0])
294 | return res
295 |
if __name__ == "__main__":
    # Demo: run the ONNX model on ./test.jpg and save the drawn lanes.
    clr = CLRNetDemo('./tusimple_r18.onnx')
    img = cv2.imread('./test.jpg')
    if img is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError("could not read input image './test.jpg'")
    output = clr.forward(img)
    cv2.imwrite('output_onnx.png', output)
    print("Done!")
302 |
--------------------------------------------------------------------------------
/demo_trt.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import os
4 | import sys
5 | import cv2
6 | import numpy as np
7 | import tensorrt as trt
8 | import pycuda.driver as cuda
9 | import pycuda.autoinit
10 |
11 | from scipy.interpolate import InterpolatedUnivariateSpline
12 |
# Colour per lane index, passed to cv2.line (channel order follows the image,
# i.e. BGR for images read with cv2.imread).
COLORS = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 255, 0), (255, 0, 255), (0, 255, 255),
    (128, 255, 0), (255, 128, 0), (128, 0, 255),
    (255, 0, 128), (0, 128, 255), (0, 255, 128),
    (128, 255, 255), (255, 128, 255), (255, 255, 128),
    (60, 180, 0), (180, 60, 0), (0, 60, 180),
    (0, 180, 60), (60, 0, 180), (180, 0, 60),
]
# The palette then repeats its first nine entries, giving 30 colours in total.
COLORS += COLORS[:9]
45 |
class Lane:
    """A single lane modelled as a spline x = f(y) in normalised coordinates.

    ``points`` is an (N, 2) array of normalised ``(x, y)`` pairs with N >= 2.
    The spline is fitted with scipy's InterpolatedUnivariateSpline, which
    expects the y values to be increasing.
    """

    def __init__(self, points=None, invalid_value=-2., metadata=None):
        super(Lane, self).__init__()
        self.curr_iter = 0  # cursor used by __iter__/__next__
        self.points = points
        self.invalid_value = invalid_value
        # Spline degree is capped by the number of available points.
        self.function = InterpolatedUnivariateSpline(points[:, 1],
                                                     points[:, 0],
                                                     k=min(3, len(points) - 1))
        # Valid y-range of the fit, widened by a 0.01 margin on each side.
        self.min_y = points[:, 1].min() - 0.01
        self.max_y = points[:, 1].max() + 0.01

        self.metadata = metadata or {}

        # Rows (original-image pixels) at which to_array samples the lane.
        self.sample_y = range(710, 150, -10)
        self.ori_img_w = 1280
        self.ori_img_h = 720

    def __repr__(self):
        return '[Lane]\n' + str(self.points) + '\n[/Lane]'

    def __call__(self, lane_ys):
        """Evaluate x at normalised ``lane_ys``; rows outside the fit get ``invalid_value``."""
        lane_xs = self.function(lane_ys)
        lane_xs[(lane_ys < self.min_y) |
                (lane_ys > self.max_y)] = self.invalid_value
        return lane_xs

    def to_array(self):
        """Return an (M, 2) array of pixel ``(x, y)`` samples with normalised x in [0, 1)."""
        img_w, img_h = self.ori_img_w, self.ori_img_h
        ys = np.array(self.sample_y) / float(img_h)
        xs = self(ys)
        valid_mask = (xs >= 0) & (xs < 1)
        lane_xs = xs[valid_mask] * img_w
        lane_ys = ys[valid_mask] * img_h
        return np.concatenate((lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)),
                              axis=1)

    def __iter__(self):
        """Iterate over ``self.points`` from the beginning.

        The cursor is reset so a loop that stopped early does not make the
        next iteration resume mid-way.
        """
        self.curr_iter = 0
        return self

    def __next__(self):
        if self.curr_iter < len(self.points):
            self.curr_iter += 1
            return self.points[self.curr_iter - 1]
        self.curr_iter = 0
        raise StopIteration
96 |
class CLRNetDemo:
    """CLRNet lane detector running a serialized TensorRT engine via pycuda.

    Pre-processing and the numpy post-processing (score filtering, lane NMS,
    spline fitting) match the onnxruntime demo; detected lanes are drawn
    onto a copy of the input image.
    """

    def __init__(self, engine_path):
        """Deserialize the engine and allocate one device buffer per binding.

        Args:
            engine_path: path to the serialized ``.engine`` file.

        Raises:
            RuntimeError: if TensorRT fails to deserialize the engine.
        """
        self.logger = trt.Logger(trt.Logger.ERROR)
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        if self.engine is None:  # deserialization failure returns None
            raise RuntimeError(f"failed to deserialize TensorRT engine from {engine_path!r}")
        self.context = self.engine.create_execution_context()

        # NOTE(review): the binding-index API used below (num_bindings,
        # binding_is_input, ...) is deprecated in newer TensorRT releases.
        self.inputs = []
        self.outputs = []
        self.allocations = []
        for i in range(self.engine.num_bindings):
            is_input = self.engine.binding_is_input(i)
            name = self.engine.get_binding_name(i)
            np_dtype = np.dtype(trt.nptype(self.engine.get_binding_dtype(i)))
            shape = self.engine.get_binding_shape(i)
            if is_input:
                self.batch_size = shape[0]
            size = np_dtype.itemsize
            for s in shape:
                size *= s
            allocation = cuda.mem_alloc(size)
            binding = {
                'index': i,
                'name': name,
                'dtype': np_dtype,
                'shape': list(shape),
                'allocation': allocation,
            }
            self.allocations.append(allocation)
            if is_input:
                self.inputs.append(binding)
            else:
                self.outputs.append(binding)

        self.conf_threshold = 0.4  # minimum softmax probability of class 1
        # Lanes whose mean per-row x distance (pixels) is below this are duplicates.
        self.nms_thres = 50
        self.max_lanes = 5
        self.sample_points = 36
        self.num_points = 72
        self.n_offsets = 72
        self.n_strips = 71
        self.img_w = 1280
        self.img_h = 720
        self.ori_img_w = 1280
        self.ori_img_h = 720
        self.cut_height = 160  # rows removed from the top before resizing

        self.input_width = 800
        self.input_height = 320

        self.sample_x_indexs = (np.linspace(0, 1, self.sample_points) * self.n_strips)
        self.prior_feat_ys = np.flip((1 - self.sample_x_indexs / self.n_strips))
        self.prior_ys = np.linspace(1, 0, self.n_offsets)

    def softmax(self, x, axis=None):
        """Numerically stable softmax of ``x`` along ``axis``."""
        x = x - x.max(axis=axis, keepdims=True)
        y = np.exp(x)
        return y / y.sum(axis=axis, keepdims=True)

    def Lane_nms(self, proposals, scores, overlap=50, top_k=4):
        """Greedy lane NMS.

        Args:
            proposals: 2-D array in the NMS layout built by ``get_lanes``.
            scores: 1-D array of confidences, one per proposal row.
            overlap: per-row distance threshold passed to ``Lane_IOU``.
            top_k: maximum number of proposals to keep.

        Returns:
            ``(keep_index, num_to_keep)``: original row indices of the kept
            proposals in descending score order (Python ints) and their count.
            At most ``top_k`` proposals are kept.
        """
        order = np.argsort(-scores)  # highest score first
        suppressed = np.zeros(len(order), dtype=bool)
        keep_index = []
        for rank, idx in enumerate(order):
            if suppressed[rank]:
                continue
            keep_index.append(int(idx))
            if len(keep_index) >= top_k:
                break
            # Suppress every lower-scored proposal that duplicates this one.
            for later in range(rank + 1, len(order)):
                if not suppressed[later] and self.Lane_IOU(
                        proposals[idx, :], proposals[order[later], :], overlap):
                    suppressed[later] = True
        return keep_index, len(keep_index)

    def Lane_IOU(self, parent_box, compared_box, threshold):
        """Return True if two lane proposals are near-duplicates.

        Both boxes use the NMS layout built in ``get_lanes``: index 2 is the
        normalised start row, index 4 the length in strips and indices 5..76
        the x coordinates in pixels. The lanes are compared over their common
        row range ``[start, end]`` (inclusive). They are duplicates when the
        summed absolute x difference is below ``threshold`` per compared row.

        NOTE: start/end are truncated to ints after the ``+ 0.5`` rounding
        trick, and ``end`` is inclusive. This follows the compiled lane-NMS
        kernel used by the original CLRNet code.
        """
        n_offsets = 72
        n_strips = n_offsets - 1

        start_a = int(parent_box[2] * n_strips + 0.5)  # +0.5 makes int() round
        start_b = int(compared_box[2] * n_strips + 0.5)
        start = max(start_a, start_b)
        # The "- (length - 1 < 0)" term adjusts the rounding for zero length.
        end_a = int(start_a + parent_box[4] - 1 + 0.5 - int(parent_box[4] - 1 < 0))
        end_b = int(start_b + compared_box[4] - 1 + 0.5 - int(compared_box[4] - 1 < 0))
        end = min(end_a, end_b, n_offsets - 1)
        if end < start:
            return False
        rows = slice(5 + start, 5 + end + 1)  # x offsets start at column 5
        dist = np.abs(parent_box[rows] - compared_box[rows]).sum()
        return dist < (threshold * (end - start + 1))

    def predictions_to_pred(self, predictions):
        """Convert kept prediction rows into ``Lane`` objects.

        Row layout as used in ``get_lanes``: column 1 is the score stored as
        ``conf``, columns 2/3 are start_y/start_x, column 5 the length
        (already converted to strips) and columns 6: the normalised x
        coordinates at ``self.prior_ys``. Lanes with fewer than two valid
        points are dropped.
        """
        lanes = []
        for pred in predictions:
            lane_xs = pred[6:].copy()  # normalized value; copy keeps `predictions` intact
            start = min(max(0, int(round(pred[2].item() * self.n_strips))),
                        self.n_strips)
            length = int(round(pred[5].item()))
            end = min(start + length - 1, len(self.prior_ys) - 1)
            # If the prediction does not start at the bottom of the image,
            # extend its prediction until the x is outside the image.
            head = lane_xs[:start]
            inside = (head >= 0.) & (head <= 1.)
            mask = ~(inside[::-1].cumprod()[::-1].astype(bool))

            lane_xs[end + 1:] = -2
            head[mask] = -2
            valid = lane_xs >= 0
            lane_ys = self.prior_ys[valid]
            lane_xs = lane_xs[valid]

            lane_xs = np.flip(np.double(lane_xs), axis=0)
            lane_ys = np.flip(lane_ys, axis=0)
            # Map rows of the cropped image back to normalised rows of the original.
            lane_ys = (lane_ys * (self.ori_img_h - self.cut_height) +
                       self.cut_height) / self.ori_img_h
            if len(lane_xs) <= 1:
                continue

            points = np.column_stack((lane_xs, lane_ys))
            lanes.append(Lane(points=points,
                              metadata={
                                  'start_x': pred[3],
                                  'start_y': pred[2],
                                  'conf': pred[1]
                              }))
        return lanes

    def get_lanes(self, output, as_lanes=True):
        """Convert raw engine output into lanes, one list per batch item.

        Args:
            output: iterable of 2-D prediction arrays (one per image).
            as_lanes: accepted for API compatibility; ignored.

        Returns:
            List with one list of ``Lane`` objects per batch item (possibly empty).
        """
        decoded = []
        for predictions in output:
            # filter out the conf lower than conf threshold
            scores = self.softmax(predictions[:, :2], 1)[:, 1]
            keep_inds = scores >= self.conf_threshold
            predictions = predictions[keep_inds]
            scores = scores[keep_inds]
            if predictions.shape[0] == 0:
                decoded.append([])
                continue

            # NMS layout: drop column 4, length -> strips, x offsets -> pixels.
            nms_predictions = np.concatenate(
                [predictions[..., :4], predictions[..., 5:]], axis=-1)
            nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
            nms_predictions[..., 5:] = nms_predictions[..., 5:] * (self.img_w - 1)

            keep, num_to_keep = self.Lane_nms(
                nms_predictions, scores, self.nms_thres, self.max_lanes)
            predictions = predictions[keep[:num_to_keep]]
            if predictions.shape[0] == 0:
                decoded.append([])
                continue

            predictions[:, 5] = np.round(predictions[:, 5] * self.n_strips)
            decoded.append(self.predictions_to_pred(predictions))
        return decoded

    def imshow_lanes(self, img, lanes, show=False, out_file=None, width=4):
        """Draw ``lanes`` on ``img`` in place and return it.

        ``show`` and ``out_file`` are accepted for API compatibility and ignored.
        Lanes without any positive-coordinate sample are skipped.
        """
        lanes_xys = []
        for lane in lanes:
            xys = [(int(x), int(y)) for x, y in lane.to_array() if x > 0 and y > 0]
            if xys:  # an empty lane would break the sort key below
                lanes_xys.append(xys)
        lanes_xys.sort(key=lambda xys: xys[0][0])  # colour lanes left to right

        for idx, xys in enumerate(lanes_xys):
            color = COLORS[idx % len(COLORS)]
            for p0, p1 in zip(xys[:-1], xys[1:]):
                cv2.line(img, p0, p1, color, thickness=width)
        return img

    def forward(self, img):
        """Run detection on an image as read by ``cv2.imread`` and return the drawing."""
        canvas = img.copy()
        img = img[self.cut_height:, :, :]
        # interpolation must be passed by keyword: the third positional
        # argument of cv2.resize is `dst`.
        img = cv2.resize(img, (self.input_width, self.input_height),
                         interpolation=cv2.INTER_CUBIC)
        # Channels are fed in the order cv2 read them (no BGR->RGB conversion).
        img = img.astype(np.float32) / 255.0
        img = np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis])  # HWC -> 1xCxHxW

        cuda.memcpy_htod(self.inputs[0]['allocation'], img)
        self.context.execute_v2(self.allocations)
        outputs = []
        for out in self.outputs:
            host = np.zeros(out['shape'], out['dtype'])
            cuda.memcpy_dtoh(host, out['allocation'])
            outputs.append(host)

        lanes = self.get_lanes(outputs[0])
        return self.imshow_lanes(canvas, lanes[0])
335 |
if __name__ == "__main__":
    # Demo: run the TensorRT engine on test.jpg and save the drawn lanes.
    isnet = CLRNetDemo('tusimple_r18.engine')
    image = cv2.imread('test.jpg')
    if image is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError("could not read input image 'test.jpg'")
    output = isnet.forward(image)
    cv2.imwrite('output_trt.png', output)
--------------------------------------------------------------------------------
/grid_sample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import itertools
4 | import operator
5 |
6 |
def gather(input, dim, index):
    """Equivalent of ``torch.gather(input, dim, index)`` built from flat indexing.

    Each output element is looked up in ``input.flatten()`` at the row-major
    offset formed from the coordinate grid of ``index``, with the ``dim``
    coordinate replaced by ``index`` itself.
    """
    coords = list(torch.meshgrid(*[torch.arange(extent, device=index.device)
                                   for extent in index.shape]))
    coords[dim] = index

    # Row-major strides of `input`, innermost dimension last.
    strides = []
    running = 1
    for extent in reversed(input.shape):
        strides.append(running)
        running *= extent
    strides.reverse()

    flat_offset = sum(coord * stride for coord, stride in zip(coords, strides))
    return input.flatten()[flat_offset]
15 |
16 |
def bilinear_grid_sample(im, grid, align_corners=False):
    """Bilinearly sample ``im`` at the normalised locations in ``grid``.

    Computes the same result as ``F.grid_sample`` with bilinear mode and
    zero padding, using only pad / floor / where / gather operations.

    Args:
        im (torch.Tensor): input feature map, shape (N, C, H, W).
        grid (torch.Tensor): sampling coordinates in [-1, 1], shape (N, Hg, Wg, 2).
        align_corners (bool): if True, -1 and 1 refer to the centres of the
            corner pixels; otherwise to their outer edges.

    Returns:
        torch.Tensor: sampled values, shape (N, C, Hg, Wg).
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    px = grid[:, :, :, 0]
    py = grid[:, :, :, 1]
    if align_corners:
        px = ((px + 1) / 2) * (w - 1)
        py = ((py + 1) / 2) * (h - 1)
    else:
        px = ((px + 1) * w - 1) / 2
        py = ((py + 1) * h - 1) / 2
    px = px.view(n, -1)
    py = py.view(n, -1)

    left = torch.floor(px).long()
    top = torch.floor(py).long()
    right = left + 1
    bottom = top + 1

    # Bilinear weights, computed from the unclipped neighbour coordinates.
    w_tl = ((right - px) * (bottom - py)).unsqueeze(1)
    w_bl = ((right - px) * (py - top)).unsqueeze(1)
    w_tr = ((px - left) * (bottom - py)).unsqueeze(1)
    w_br = ((px - left) * (py - top)).unsqueeze(1)

    # A one-pixel zero border emulates padding_mode='zeros'.
    padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2

    def _clip(coord, upper):
        # Clamp into [0, upper]; anything beyond lands on the zero border.
        coord = torch.where(coord < 0, torch.tensor(0, device=coord.device), coord)
        return torch.where(coord > upper, torch.tensor(upper, device=coord.device), coord)

    # Shift by one to account for the border, then clip.
    left = _clip(left + 1, padded_w - 1)
    right = _clip(right + 1, padded_w - 1)
    top = _clip(top + 1, padded_h - 1)
    bottom = _clip(bottom + 1, padded_h - 1)

    padded = padded.view(n, c, -1)

    def _pick(cols, rows):
        # Gather the padded pixels at (rows, cols) for every channel.
        flat = (cols + rows * padded_w).unsqueeze(1).expand(-1, c, -1)
        return torch.gather(padded, 2, flat)

    sampled = (_pick(left, top) * w_tl + _pick(left, bottom) * w_bl
               + _pick(right, top) * w_tr + _pick(right, bottom) * w_br)
    return sampled.reshape(n, c, gh, gw)
--------------------------------------------------------------------------------
/imgs/output_onnx.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/CLRNet-onnxruntime-and-tensorrt-demo/07a47e666d4bfca9a4c60a777c62c23828fca663/imgs/output_onnx.png
--------------------------------------------------------------------------------
/imgs/output_trt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/CLRNet-onnxruntime-and-tensorrt-demo/07a47e666d4bfca9a4c60a777c62c23828fca663/imgs/output_trt.png
--------------------------------------------------------------------------------
/my_log/test_onnx.log:
--------------------------------------------------------------------------------
1 | (base) [***@ai02 temp]$ git clone https://github.com/Turoad/clrnet
2 | 正克隆到 'clrnet'...
3 | remote: Enumerating objects: 105, done.
4 | remote: Counting objects: 100% (105/105), done.
5 | remote: Compressing objects: 100% (81/81), done.
6 | remote: Total 105 (delta 22), reused 95 (delta 17), pack-reused 0
7 | 接收对象中: 100% (105/105), 286.98 KiB | 0 bytes/s, done.
8 | 处理 delta 中: 100% (22/22), done.
9 | (base) [***@ai02 temp]$ ls
10 | clrnet
11 | (base) [***@ai02 temp]$ cd clrnet/
12 | (base) [***@ai02 clrnet]$ ls
13 | clrnet configs LICENSE main.py README.md requirements.txt setup.py tools
14 | (base) [***@ai02 clrnet]$ conda create -n clrnet python=3.8 -y
15 | Collecting package metadata (current_repodata.json): done
16 | Solving environment: done
17 |
18 |
19 | ==> WARNING: A newer version of conda exists. <==
20 | current version: 4.11.0
21 | latest version: 4.14.0
22 |
23 | Please update conda by running
24 |
25 | $ conda update -n base -c defaults conda
26 |
27 |
28 |
29 | ## Package Plan ##
30 |
31 | environment location: /data3/***/anaconda3/envs/clrnet
32 |
33 | added / updated specs:
34 | - python=3.8
35 |
36 |
37 | The following NEW packages will be INSTALLED:
38 |
39 | _libgcc_mutex anaconda/pkgs/main/linux-64::_libgcc_mutex-0.1-main
40 | _openmp_mutex anaconda/pkgs/main/linux-64::_openmp_mutex-5.1-1_gnu
41 | ca-certificates anaconda/pkgs/main/linux-64::ca-certificates-2022.07.19-h06a4308_0
42 | certifi anaconda/pkgs/main/linux-64::certifi-2022.6.15-py38h06a4308_0
43 | ld_impl_linux-64 anaconda/pkgs/main/linux-64::ld_impl_linux-64-2.38-h1181459_1
44 | libffi anaconda/pkgs/main/linux-64::libffi-3.3-he6710b0_2
45 | libgcc-ng anaconda/pkgs/main/linux-64::libgcc-ng-11.2.0-h1234567_1
46 | libgomp anaconda/pkgs/main/linux-64::libgomp-11.2.0-h1234567_1
47 | libstdcxx-ng anaconda/pkgs/main/linux-64::libstdcxx-ng-11.2.0-h1234567_1
48 | ncurses anaconda/pkgs/main/linux-64::ncurses-6.3-h5eee18b_3
49 | openssl anaconda/pkgs/main/linux-64::openssl-1.1.1q-h7f8727e_0
50 | pip anaconda/pkgs/main/linux-64::pip-22.1.2-py38h06a4308_0
51 | python anaconda/pkgs/main/linux-64::python-3.8.13-h12debd9_0
52 | readline anaconda/pkgs/main/linux-64::readline-8.1.2-h7f8727e_1
53 | setuptools anaconda/pkgs/main/linux-64::setuptools-63.4.1-py38h06a4308_0
54 | sqlite anaconda/pkgs/main/linux-64::sqlite-3.39.2-h5082296_0
55 | tk anaconda/pkgs/main/linux-64::tk-8.6.12-h1ccaba5_0
56 | wheel anaconda/pkgs/main/noarch::wheel-0.37.1-pyhd3eb1b0_0
57 | xz anaconda/pkgs/main/linux-64::xz-5.2.5-h7f8727e_1
58 | zlib anaconda/pkgs/main/linux-64::zlib-1.2.12-h7f8727e_2
59 |
60 |
61 | Preparing transaction: done
62 | Verifying transaction: done
63 | Executing transaction: done
64 | #
65 | # To activate this environment, use
66 | #
67 | # $ conda activate clrnet
68 | #
69 | # To deactivate an active environment, use
70 | #
71 | # $ conda deactivate
72 |
73 | (base) [***@ai02 clrnet]$ conda activate clrnet
74 | (clrnet) [***@ai02 clrnet]$ pip install torch==1.8.0 torchvision==0.9.0
75 | Collecting torch==1.8.0
76 | Using cached torch-1.8.0-cp38-cp38-manylinux1_x86_64.whl (735.5 MB)
77 | Collecting torchvision==0.9.0
78 | Using cached torchvision-0.9.0-cp38-cp38-manylinux1_x86_64.whl (17.3 MB)
79 | Collecting numpy
80 | Downloading numpy-1.23.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
81 | ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17.1/17.1 MB 11.5 MB/s eta 0:00:00
82 | Collecting typing-extensions
83 | Using cached typing_extensions-4.3.0-py3-none-any.whl (25 kB)
84 | Collecting pillow>=4.1.1
85 | Using cached Pillow-9.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
86 | Installing collected packages: typing-extensions, pillow, numpy, torch, torchvision
87 | Successfully installed numpy-1.23.2 pillow-9.2.0 torch-1.8.0 torchvision-0.9.0 typing-extensions-4.3.0
88 | (clrnet) [***@ai02 clrnet]$ python setup.py build develop
89 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/installer.py:27: SetuptoolsDeprecationWarning: setuptools.installer is deprecated. Requirements should be satisfied by a PEP 517 installer.
90 | warnings.warn(
91 | running build
92 | running build_py
93 | creating build
94 | creating build/lib.linux-x86_64-cpython-38
95 | creating build/lib.linux-x86_64-cpython-38/clrnet
96 | copying clrnet/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet
97 | creating build/lib.linux-x86_64-cpython-38/clrnet/utils
98 | copying clrnet/utils/tusimple_metric.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
99 | copying clrnet/utils/logger.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
100 | copying clrnet/utils/registry.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
101 | copying clrnet/utils/llamas_utils.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
102 | copying clrnet/utils/llamas_metric.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
103 | copying clrnet/utils/recorder.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
104 | copying clrnet/utils/visualization.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
105 | copying clrnet/utils/culane_metric.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
106 | copying clrnet/utils/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
107 | copying clrnet/utils/config.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
108 | copying clrnet/utils/lane.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
109 | copying clrnet/utils/net_utils.py -> build/lib.linux-x86_64-cpython-38/clrnet/utils
110 | creating build/lib.linux-x86_64-cpython-38/clrnet/engine
111 | copying clrnet/engine/scheduler.py -> build/lib.linux-x86_64-cpython-38/clrnet/engine
112 | copying clrnet/engine/registry.py -> build/lib.linux-x86_64-cpython-38/clrnet/engine
113 | copying clrnet/engine/runner.py -> build/lib.linux-x86_64-cpython-38/clrnet/engine
114 | copying clrnet/engine/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/engine
115 | copying clrnet/engine/optimizer.py -> build/lib.linux-x86_64-cpython-38/clrnet/engine
116 | creating build/lib.linux-x86_64-cpython-38/clrnet/models
117 | copying clrnet/models/registry.py -> build/lib.linux-x86_64-cpython-38/clrnet/models
118 | copying clrnet/models/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models
119 | creating build/lib.linux-x86_64-cpython-38/clrnet/datasets
120 | copying clrnet/datasets/llamas.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets
121 | copying clrnet/datasets/base_dataset.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets
122 | copying clrnet/datasets/tusimple.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets
123 | copying clrnet/datasets/registry.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets
124 | copying clrnet/datasets/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets
125 | copying clrnet/datasets/culane.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets
126 | creating build/lib.linux-x86_64-cpython-38/clrnet/ops
127 | copying clrnet/ops/nms.py -> build/lib.linux-x86_64-cpython-38/clrnet/ops
128 | copying clrnet/ops/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/ops
129 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/utils
130 | copying clrnet/models/utils/roi_gather.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/utils
131 | copying clrnet/models/utils/seg_decoder.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/utils
132 | copying clrnet/models/utils/dynamic_assign.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/utils
133 | copying clrnet/models/utils/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/utils
134 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/backbones
135 | copying clrnet/models/backbones/dla34.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/backbones
136 | copying clrnet/models/backbones/resnet.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/backbones
137 | copying clrnet/models/backbones/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/backbones
138 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/heads
139 | copying clrnet/models/heads/clr_head.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/heads
140 | copying clrnet/models/heads/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/heads
141 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/losses
142 | copying clrnet/models/losses/focal_loss.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/losses
143 | copying clrnet/models/losses/accuracy.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/losses
144 | copying clrnet/models/losses/lineiou_loss.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/losses
145 | copying clrnet/models/losses/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/losses
146 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/necks
147 | copying clrnet/models/necks/fpn.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/necks
148 | copying clrnet/models/necks/pafpn.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/necks
149 | copying clrnet/models/necks/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/necks
150 | creating build/lib.linux-x86_64-cpython-38/clrnet/models/nets
151 | copying clrnet/models/nets/detector.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/nets
152 | copying clrnet/models/nets/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/models/nets
153 | creating build/lib.linux-x86_64-cpython-38/clrnet/datasets/process
154 | copying clrnet/datasets/process/transforms.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets/process
155 | copying clrnet/datasets/process/process.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets/process
156 | copying clrnet/datasets/process/__init__.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets/process
157 | copying clrnet/datasets/process/generate_lane_line.py -> build/lib.linux-x86_64-cpython-38/clrnet/datasets/process
158 | running egg_info
159 | creating clrnet.egg-info
160 | writing clrnet.egg-info/PKG-INFO
161 | writing dependency_links to clrnet.egg-info/dependency_links.txt
162 | writing requirements to clrnet.egg-info/requires.txt
163 | writing top-level names to clrnet.egg-info/top_level.txt
164 | writing manifest file 'clrnet.egg-info/SOURCES.txt'
165 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/utils/cpp_extension.py:369: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
166 | warnings.warn(msg.format('we could not find ninja.'))
167 | reading manifest file 'clrnet.egg-info/SOURCES.txt'
168 | adding license file 'LICENSE'
169 | writing manifest file 'clrnet.egg-info/SOURCES.txt'
170 | running build_ext
171 | building 'clrnet.ops.nms_impl' extension
172 | creating build/temp.linux-x86_64-cpython-38
173 | creating build/temp.linux-x86_64-cpython-38/clrnet
174 | creating build/temp.linux-x86_64-cpython-38/clrnet/ops
175 | creating build/temp.linux-x86_64-cpython-38/clrnet/ops/csrc
176 | /data1/cv/softs/gcc-5.1.0/install/bin/gcc -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/TH -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/data3/***/anaconda3/envs/clrnet/include/python3.8 -c ./clrnet/ops/csrc/nms.cpp -o build/temp.linux-x86_64-cpython-38/./clrnet/ops/csrc/nms.o -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -DTORCH_EXTENSION_NAME=nms_impl -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14
177 | cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
178 | In file included from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Parallel.h:140:0,
179 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/utils.h:3,
180 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn/cloneable.h:5,
181 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/nn.h:3,
182 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:13,
183 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/extension.h:4,
184 | from ./clrnet/ops/csrc/nms.cpp:30:
185 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/ParallelOpenMP.h:83:0: warning: ignoring #pragma omp parallel [-Wunknown-pragmas]
186 | #pragma omp parallel for if ((end - begin) >= grain_size)
187 | ^
188 | In file included from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/core/Device.h:5:0,
189 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/core/Allocator.h:6,
190 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/ATen.h:7,
191 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/types.h:3,
192 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h:4,
193 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h:3,
194 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h:3,
195 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h:3,
196 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data.h:3,
197 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:8,
198 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/extension.h:4,
199 | from ./clrnet/ops/csrc/nms.cpp:30:
200 | ./clrnet/ops/csrc/nms.cpp: In function ‘std::vector nms_forward(at::Tensor, at::Tensor, float, long unsigned int)’:
201 | ./clrnet/ops/csrc/nms.cpp:40:41: warning: ‘at::DeprecatedTypeProperties& at::Tensor::type() const’ is deprecated: Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device(). [-Wdeprecated-declarations]
202 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
203 | ^
204 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:225:39: note: in definition of macro ‘C10_EXPAND_MSVC_WORKAROUND’
205 | #define C10_EXPAND_MSVC_WORKAROUND(x) x
206 | ^
207 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:244:34: note: in expansion of macro ‘C10_UNLIKELY’
208 | #define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e)
209 | ^
210 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:291:7: note: in expansion of macro ‘C10_UNLIKELY_OR_CONST’
211 | if (C10_UNLIKELY_OR_CONST(!(cond))) { \
212 | ^
213 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:484:32: note: in expansion of macro ‘TORCH_INTERNAL_ASSERT’
214 | C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__)); \
215 | ^
216 | ./clrnet/ops/csrc/nms.cpp:40:23: note: in expansion of macro ‘AT_ASSERTM’
217 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
218 | ^
219 | ./clrnet/ops/csrc/nms.cpp:42:24: note: in expansion of macro ‘CHECK_CUDA’
220 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
221 | ^
222 | ./clrnet/ops/csrc/nms.cpp:53:5: note: in expansion of macro ‘CHECK_INPUT’
223 | CHECK_INPUT(boxes);
224 | ^
225 | In file included from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Tensor.h:3:0,
226 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Context.h:4,
227 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/ATen.h:9,
228 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/types.h:3,
229 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h:4,
230 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h:3,
231 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h:3,
232 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h:3,
233 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data.h:3,
234 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:8,
235 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/extension.h:4,
236 | from ./clrnet/ops/csrc/nms.cpp:30:
237 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:303:30: note: declared here
238 | DeprecatedTypeProperties & type() const {
239 | ^
240 | In file included from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/core/Device.h:5:0,
241 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/core/Allocator.h:6,
242 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/ATen.h:7,
243 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/types.h:3,
244 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h:4,
245 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h:3,
246 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h:3,
247 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h:3,
248 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data.h:3,
249 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:8,
250 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/extension.h:4,
251 | from ./clrnet/ops/csrc/nms.cpp:30:
252 | ./clrnet/ops/csrc/nms.cpp:40:41: warning: ‘at::DeprecatedTypeProperties& at::Tensor::type() const’ is deprecated: Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device(). [-Wdeprecated-declarations]
253 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
254 | ^
255 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:225:39: note: in definition of macro ‘C10_EXPAND_MSVC_WORKAROUND’
256 | #define C10_EXPAND_MSVC_WORKAROUND(x) x
257 | ^
258 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:244:34: note: in expansion of macro ‘C10_UNLIKELY’
259 | #define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e)
260 | ^
261 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:291:7: note: in expansion of macro ‘C10_UNLIKELY_OR_CONST’
262 | if (C10_UNLIKELY_OR_CONST(!(cond))) { \
263 | ^
264 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/c10/util/Exception.h:484:32: note: in expansion of macro ‘TORCH_INTERNAL_ASSERT’
265 | C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(cond, __VA_ARGS__)); \
266 | ^
267 | ./clrnet/ops/csrc/nms.cpp:40:23: note: in expansion of macro ‘AT_ASSERTM’
268 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
269 | ^
270 | ./clrnet/ops/csrc/nms.cpp:42:24: note: in expansion of macro ‘CHECK_CUDA’
271 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
272 | ^
273 | ./clrnet/ops/csrc/nms.cpp:54:5: note: in expansion of macro ‘CHECK_INPUT’
274 | CHECK_INPUT(idx);
275 | ^
276 | In file included from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Tensor.h:3:0,
277 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Context.h:4,
278 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/ATen.h:9,
279 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/types.h:3,
280 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h:4,
281 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h:3,
282 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h:3,
283 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h:3,
284 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/data.h:3,
285 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include/torch/all.h:8,
286 | from /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/extension.h:4,
287 | from ./clrnet/ops/csrc/nms.cpp:30:
288 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:303:30: note: declared here
289 | DeprecatedTypeProperties & type() const {
290 | ^
291 | /usr/local/cuda-11.3/bin/nvcc -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/TH -I/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda-11.3/include -I/data3/***/anaconda3/envs/clrnet/include/python3.8 -c ./clrnet/ops/csrc/nms_kernel.cu -o build/temp.linux-x86_64-cpython-38/./clrnet/ops/csrc/nms_kernel.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -DTORCH_EXTENSION_NAME=nms_impl -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_70,code=compute_70 -gencode=arch=compute_70,code=sm_70 -ccbin /data1/cv/softs/gcc-5.1.0/install/bin/gcc -std=c++14
292 | ./clrnet/ops/csrc/nms_kernel.cu: In lambda function:
293 | ./clrnet/ops/csrc/nms_kernel.cu:171:43: warning: ‘at::DeprecatedTypeProperties& at::Tensor::type() const’ is deprecated: Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device(). [-Wdeprecated-declarations]
294 | AT_DISPATCH_FLOATING_TYPES(boxes.type(), "nms_cuda_forward", ([&] {
295 | ^
296 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:303:1: note: declared here
297 | DeprecatedTypeProperties & type() const {
298 | ^
299 | ./clrnet/ops/csrc/nms_kernel.cu:171:98: warning: ‘c10::ScalarType detail::scalar_type(const at::DeprecatedTypeProperties&)’ is deprecated: passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, pass an at::ScalarType instead [-Wdeprecated-declarations]
300 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/Dispatch.h:109:1: note: declared here
301 | inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
302 | ^
303 | ./clrnet/ops/csrc/nms_kernel.cu: In lambda function:
304 | ./clrnet/ops/csrc/nms_kernel.cu:171:856: warning: ‘T* at::Tensor::data() const [with T = double]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
305 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
306 | T * data() const {
307 | ^
308 | ./clrnet/ops/csrc/nms_kernel.cu:171:878: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
309 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
310 | T * data() const {
311 | ^
312 | ./clrnet/ops/csrc/nms_kernel.cu:171:901: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
313 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
314 | T * data() const {
315 | ^
316 | ./clrnet/ops/csrc/nms_kernel.cu: In lambda function:
317 | ./clrnet/ops/csrc/nms_kernel.cu:171:1648: warning: ‘T* at::Tensor::data() const [with T = float]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
318 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
319 | T * data() const {
320 | ^
321 | ./clrnet/ops/csrc/nms_kernel.cu:171:1670: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
322 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
323 | T * data() const {
324 | ^
325 | ./clrnet/ops/csrc/nms_kernel.cu:171:1693: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
326 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
327 | T * data() const {
328 | ^
329 | ./clrnet/ops/csrc/nms_kernel.cu: In function ‘std::vector nms_cuda_forward(at::Tensor, at::Tensor, float, long unsigned int)’:
330 | ./clrnet/ops/csrc/nms_kernel.cu:183:108: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
331 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
332 | T * data() const {
333 | ^
334 | ./clrnet/ops/csrc/nms_kernel.cu:183:129: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
335 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
336 | T * data() const {
337 | ^
338 | ./clrnet/ops/csrc/nms_kernel.cu:183:150: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
339 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
340 | T * data() const {
341 | ^
342 | ./clrnet/ops/csrc/nms_kernel.cu:183:186: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
343 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
344 | T * data() const {
345 | ^
346 | ./clrnet/ops/csrc/nms_kernel.cu:183:214: warning: ‘T* at::Tensor::data() const [with T = long int]’ is deprecated: Tensor.data() is deprecated. Please use Tensor.data_ptr() instead. [-Wdeprecated-declarations]
347 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/include/ATen/core/TensorBody.h:395:1: note: declared here
348 | T * data() const {
349 | ^
350 | /data1/cv/softs/gcc-5.1.0/install/bin/g++ -pthread -shared -B /data3/***/anaconda3/envs/clrnet/compiler_compat -L/data3/***/anaconda3/envs/clrnet/lib -Wl,-rpath=/data3/***/anaconda3/envs/clrnet/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-cpython-38/./clrnet/ops/csrc/nms.o build/temp.linux-x86_64-cpython-38/./clrnet/ops/csrc/nms_kernel.o -L/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/lib -L/usr/local/cuda-11.3/lib64 -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-cpython-38/clrnet/ops/nms_impl.cpython-38-x86_64-linux-gnu.so
351 | running develop
352 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
353 | warnings.warn(
354 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
355 | warnings.warn(
356 | running build_ext
357 | copying build/lib.linux-x86_64-cpython-38/clrnet/ops/nms_impl.cpython-38-x86_64-linux-gnu.so -> clrnet/ops
358 | Creating /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/clrnet.egg-link (link to .)
359 | Adding clrnet 1.0 to easy-install.pth file
360 |
361 | Installed /data3/***/project/temp/clrnet
362 | Processing dependencies for clrnet==1.0
363 | Searching for ptflops
364 | Reading https://pypi.org/simple/ptflops/
365 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release
366 | warnings.warn(
367 | Downloading https://files.pythonhosted.org/packages/d5/16/b6992b799d14bdc7a78fbfff8825777dd92728ed761090852c43f4792ce1/ptflops-0.6.9.tar.gz#sha256=95423006d7520eff5cc2fcbbd149329d39d81c9bff361c9b3b13bfbd5d1efe75
368 | Best match: ptflops 0.6.9
369 | Processing ptflops-0.6.9.tar.gz
370 | Writing /tmp/easy_install-h2cxdo1q/ptflops-0.6.9/setup.cfg
371 | Running ptflops-0.6.9/setup.py -q bdist_egg --dist-dir /tmp/easy_install-h2cxdo1q/ptflops-0.6.9/egg-dist-tmp-pwoqahab
372 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
373 | warnings.warn(
374 | Moving ptflops-0.6.9-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
375 | Adding ptflops 0.6.9 to easy-install.pth file
376 |
377 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ptflops-0.6.9-py3.8.egg
378 | Searching for pathspec
379 | Reading https://pypi.org/simple/pathspec/
380 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release
381 | warnings.warn(
382 | Downloading https://files.pythonhosted.org/packages/42/ba/a9d64c7bcbc7e3e8e5f93a52721b377e994c22d16196e2b0f1236774353a/pathspec-0.9.0-py2.py3-none-any.whl#sha256=7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a
383 | Best match: pathspec 0.9.0
384 | Processing pathspec-0.9.0-py2.py3-none-any.whl
385 | Installing pathspec-0.9.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
386 | Adding pathspec 0.9.0 to easy-install.pth file
387 |
388 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pathspec-0.9.0-py3.8.egg
389 | Searching for albumentations==0.4.6
390 | Reading https://pypi.org/simple/albumentations/
391 | Downloading https://files.pythonhosted.org/packages/92/33/1c459c2c9a4028ec75527eff88bc4e2d256555189f42af4baf4d7bd89233/albumentations-0.4.6.tar.gz#sha256=510ac855a6dc7f80723bba7de98c9ee575997a76cf49192c44c707c904a020f8
392 | Best match: albumentations 0.4.6
393 | Processing albumentations-0.4.6.tar.gz
394 | Writing /tmp/easy_install-egwsde9u/albumentations-0.4.6/setup.cfg
395 | Running albumentations-0.4.6/setup.py -q bdist_egg --dist-dir /tmp/easy_install-egwsde9u/albumentations-0.4.6/egg-dist-tmp-mma_63qf
396 | no previously-included directories found matching 'docs/_build'
397 | warning: no previously-included files matching 'docs/augs_overview/*/images/*.jpg' found anywhere in distribution
398 | warning: no previously-included files matching '*.py[co]' found anywhere in distribution
399 | warning: no previously-included files matching '.DS_Store' found anywhere in distribution
400 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
401 | warnings.warn(
402 | zip_safe flag not set; analyzing archive contents...
403 | Moving albumentations-0.4.6-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
404 | Adding albumentations 0.4.6 to easy-install.pth file
405 |
406 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/albumentations-0.4.6-py3.8.egg
407 | Searching for mmcv==1.2.5
408 | Reading https://pypi.org/simple/mmcv/
409 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release
410 | warnings.warn(
411 | Downloading https://files.pythonhosted.org/packages/0c/ce/0574933eb3a02eade4af331579b3cc8840e59e5d757ee67d243f08fd9b75/mmcv-1.2.5.tar.gz#sha256=8d2f3fb61b75cbcfbc0e9756e905d6b8678a3a14ca9f8603726ca86846c41b28
412 | Best match: mmcv 1.2.5
413 | Processing mmcv-1.2.5.tar.gz
414 | Writing /tmp/easy_install-f2kyyeji/mmcv-1.2.5/setup.cfg
415 | Running mmcv-1.2.5/setup.py -q bdist_egg --dist-dir /tmp/easy_install-f2kyyeji/mmcv-1.2.5/egg-dist-tmp-oqa8zdgb
416 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
417 | warnings.warn(
418 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/build_py.py:153: SetuptoolsDeprecationWarning: Installing 'mmcv.model_zoo' as data is deprecated, please list it in `packages`.
419 | !!
420 |
421 |
422 | ############################
423 | # Package would be ignored #
424 | ############################
425 | Python recognizes 'mmcv.model_zoo' as an importable package,
426 | but it is not listed in the `packages` configuration of setuptools.
427 |
428 | 'mmcv.model_zoo' has been automatically added to the distribution only
429 | because it may contain data files, but this behavior is likely to change
430 | in future versions of setuptools (and therefore is considered deprecated).
431 |
432 | Please make sure that 'mmcv.model_zoo' is included as a package by using
433 | the `packages` configuration field or the proper discovery methods
434 | (for example by using `find_namespace_packages(...)`/`find_namespace:`
435 | instead of `find_packages(...)`/`find:`).
436 |
437 | You can read more about "package discovery" and "data files" on setuptools
438 | documentation page.
439 |
440 |
441 | !!
442 |
443 | check.warn(importable)
444 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/build_py.py:153: SetuptoolsDeprecationWarning: Installing 'mmcv.ops.csrc' as data is deprecated, please list it in `packages`.
445 | !!
446 |
447 |
448 | ############################
449 | # Package would be ignored #
450 | ############################
451 | Python recognizes 'mmcv.ops.csrc' as an importable package,
452 | but it is not listed in the `packages` configuration of setuptools.
453 |
454 | 'mmcv.ops.csrc' has been automatically added to the distribution only
455 | because it may contain data files, but this behavior is likely to change
456 | in future versions of setuptools (and therefore is considered deprecated).
457 |
458 | Please make sure that 'mmcv.ops.csrc' is included as a package by using
459 | the `packages` configuration field or the proper discovery methods
460 | (for example by using `find_namespace_packages(...)`/`find_namespace:`
461 | instead of `find_packages(...)`/`find:`).
462 |
463 | You can read more about "package discovery" and "data files" on setuptools
464 | documentation page.
465 |
466 |
467 | !!
468 |
469 | check.warn(importable)
470 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/build_py.py:153: SetuptoolsDeprecationWarning: Installing 'mmcv.ops.csrc.parrots' as data is deprecated, please list it in `packages`.
471 | !!
472 |
473 |
474 | ############################
475 | # Package would be ignored #
476 | ############################
477 | Python recognizes 'mmcv.ops.csrc.parrots' as an importable package,
478 | but it is not listed in the `packages` configuration of setuptools.
479 |
480 | 'mmcv.ops.csrc.parrots' has been automatically added to the distribution only
481 | because it may contain data files, but this behavior is likely to change
482 | in future versions of setuptools (and therefore is considered deprecated).
483 |
484 | Please make sure that 'mmcv.ops.csrc.parrots' is included as a package by using
485 | the `packages` configuration field or the proper discovery methods
486 | (for example by using `find_namespace_packages(...)`/`find_namespace:`
487 | instead of `find_packages(...)`/`find:`).
488 |
489 | You can read more about "package discovery" and "data files" on setuptools
490 | documentation page.
491 |
492 |
493 | !!
494 |
495 | check.warn(importable)
496 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/build_py.py:153: SetuptoolsDeprecationWarning: Installing 'mmcv.ops.csrc.pytorch' as data is deprecated, please list it in `packages`.
497 | !!
498 |
499 |
500 | ############################
501 | # Package would be ignored #
502 | ############################
503 | Python recognizes 'mmcv.ops.csrc.pytorch' as an importable package,
504 | but it is not listed in the `packages` configuration of setuptools.
505 |
506 | 'mmcv.ops.csrc.pytorch' has been automatically added to the distribution only
507 | because it may contain data files, but this behavior is likely to change
508 | in future versions of setuptools (and therefore is considered deprecated).
509 |
510 | Please make sure that 'mmcv.ops.csrc.pytorch' is included as a package by using
511 | the `packages` configuration field or the proper discovery methods
512 | (for example by using `find_namespace_packages(...)`/`find_namespace:`
513 | instead of `find_packages(...)`/`find:`).
514 |
515 | You can read more about "package discovery" and "data files" on setuptools
516 | documentation page.
517 |
518 |
519 | !!
520 |
521 | check.warn(importable)
522 | creating /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/mmcv-1.2.5-py3.8.egg
523 | Extracting mmcv-1.2.5-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
524 | Adding mmcv 1.2.5 to easy-install.pth file
525 |
526 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/mmcv-1.2.5-py3.8.egg
527 | Searching for timm
528 | Reading https://pypi.org/simple/timm/
529 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release
530 | warnings.warn(
531 | Downloading https://files.pythonhosted.org/packages/72/ed/358a8bc5685c31c0fe7765351b202cf6a8c087893b5d2d64f63c950f8beb/timm-0.6.7-py3-none-any.whl#sha256=4bbd7a5c9ae462ec7fec3d99ffc62ac2012010d755248e3de778d50bce5f6186
532 | Best match: timm 0.6.7
533 | Processing timm-0.6.7-py3-none-any.whl
534 | Installing timm-0.6.7-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
535 | Adding timm 0.6.7 to easy-install.pth file
536 |
537 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/timm-0.6.7-py3.8.egg
538 | Searching for yapf
539 | Reading https://pypi.org/simple/yapf/
540 | Downloading https://files.pythonhosted.org/packages/47/88/843c2e68f18a5879b4fbf37cb99fbabe1ffc4343b2e63191c8462235c008/yapf-0.32.0-py2.py3-none-any.whl#sha256=8fea849025584e486fd06d6ba2bed717f396080fd3cc236ba10cb97c4c51cf32
541 | Best match: yapf 0.32.0
542 | Processing yapf-0.32.0-py2.py3-none-any.whl
543 | Installing yapf-0.32.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
544 | Adding yapf 0.32.0 to easy-install.pth file
545 | Installing yapf script to /data3/***/anaconda3/envs/clrnet/bin
546 | Installing yapf-diff script to /data3/***/anaconda3/envs/clrnet/bin
547 |
548 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/yapf-0.32.0-py3.8.egg
549 | Searching for ujson==1.35
550 | Reading https://pypi.org/simple/ujson/
551 | Downloading https://files.pythonhosted.org/packages/16/c4/79f3409bc710559015464e5f49b9879430d8f87498ecdc335899732e5377/ujson-1.35.tar.gz#sha256=f66073e5506e91d204ab0c614a148d5aa938bdbf104751be66f8ad7a222f5f86
552 | Best match: ujson 1.35
553 | Processing ujson-1.35.tar.gz
554 | Writing /tmp/easy_install-ui57loz0/ujson-1.35/setup.cfg
555 | Running ujson-1.35/setup.py -q bdist_egg --dist-dir /tmp/easy_install-ui57loz0/ujson-1.35/egg-dist-tmp-7dvt_ppm
556 | Warning: 'classifiers' should be a list, got type 'filter'
557 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
558 | warnings.warn(
559 | ./lib/ultrajsonenc.c:156:23: warning: ‘g_hexChars’ is static but used in inline function ‘Buffer_AppendShortHexUnchecked’ which is not static
560 | *(outputOffset++) = g_hexChars[(value & 0x000f) >> 0];
561 | ^
562 | ./lib/ultrajsonenc.c:155:23: warning: ‘g_hexChars’ is static but used in inline function ‘Buffer_AppendShortHexUnchecked’ which is not static
563 | *(outputOffset++) = g_hexChars[(value & 0x00f0) >> 4];
564 | ^
565 | ./lib/ultrajsonenc.c:154:23: warning: ‘g_hexChars’ is static but used in inline function ‘Buffer_AppendShortHexUnchecked’ which is not static
566 | *(outputOffset++) = g_hexChars[(value & 0x0f00) >> 8];
567 | ^
568 | ./lib/ultrajsonenc.c:153:23: warning: ‘g_hexChars’ is static but used in inline function ‘Buffer_AppendShortHexUnchecked’ which is not static
569 | *(outputOffset++) = g_hexChars[(value & 0xf000) >> 12];
570 | ^
571 | ./python/objToJSON.c: In function ‘PyUnicodeToUTF8’:
572 | ./python/objToJSON.c:154:18: warning: initialization discards ‘const’ qualifier from pointer target type [-Wdiscarded-qualifiers]
573 | char *data = PyUnicode_AsUTF8AndSize(obj, &len);
574 | ^
575 | zip_safe flag not set; analyzing archive contents...
576 | __pycache__.ujson.cpython-38: module references __file__
577 | creating /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ujson-1.35-py3.8-linux-x86_64.egg
578 | Extracting ujson-1.35-py3.8-linux-x86_64.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
579 | Adding ujson 1.35 to easy-install.pth file
580 |
581 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ujson-1.35-py3.8-linux-x86_64.egg
582 | Searching for Shapely==1.7.0
583 | Reading https://pypi.org/simple/Shapely/
584 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release
585 | warnings.warn(
586 | Downloading https://files.pythonhosted.org/packages/da/93/da69f1f278c02b4dfcf27b33e36bed43d2a6c8213d57b9a21840af4a407a/Shapely-1.7.0-cp38-cp38-manylinux1_x86_64.whl#sha256=2154b9f25c5f13785cb05ce80b2c86e542bc69671193743f29c9f4c791c35db3
587 | Best match: Shapely 1.7.0
588 | Processing Shapely-1.7.0-cp38-cp38-manylinux1_x86_64.whl
589 | Installing Shapely-1.7.0-cp38-cp38-manylinux1_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
590 | Adding Shapely 1.7.0 to easy-install.pth file
591 |
592 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/Shapely-1.7.0-py3.8-linux-x86_64.egg
593 | Searching for imgaug>=0.4.0
594 | Reading https://pypi.org/simple/imgaug/
595 | Downloading https://files.pythonhosted.org/packages/66/b1/af3142c4a85cba6da9f4ebb5ff4e21e2616309552caca5e8acefe9840622/imgaug-0.4.0-py2.py3-none-any.whl#sha256=ce61e65b4eb7405fc62c1b0a79d2fa92fd47f763aaecb65152d29243592111f9
596 | Best match: imgaug 0.4.0
597 | Processing imgaug-0.4.0-py2.py3-none-any.whl
598 | Installing imgaug-0.4.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
599 | Adding imgaug 0.4.0 to easy-install.pth file
600 |
601 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/imgaug-0.4.0-py3.8.egg
602 | Searching for p_tqdm
603 | Reading https://pypi.org/simple/p_tqdm/
604 | Downloading https://files.pythonhosted.org/packages/b9/c4/ce6abe2fa3868b1ea9216a81522a9ece36f47bdbb966f8f31f76e2967178/p_tqdm-1.3.3.tar.gz#sha256=8b9316d8bae43279e03ea01c8849422e5072a931f1b94b79890b0da426802c6e
605 | Best match: p-tqdm 1.3.3
606 | Processing p_tqdm-1.3.3.tar.gz
607 | Writing /tmp/easy_install-wruk1oj0/p_tqdm-1.3.3/setup.cfg
608 | Running p_tqdm-1.3.3/setup.py -q bdist_egg --dist-dir /tmp/easy_install-wruk1oj0/p_tqdm-1.3.3/egg-dist-tmp-mqacu6ap
609 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/dist.py:771: UserWarning: Usage of dash-separated 'description-file' will not be supported in future versions. Please use the underscore name 'description_file' instead
610 | warnings.warn(
611 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
612 | warnings.warn(
613 | zip_safe flag not set; analyzing archive contents...
614 | Moving p_tqdm-1.3.3-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
615 | Adding p-tqdm 1.3.3 to easy-install.pth file
616 |
617 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/p_tqdm-1.3.3-py3.8.egg
618 | Searching for tqdm
619 | Reading https://pypi.org/simple/tqdm/
620 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release
621 | warnings.warn(
622 | Downloading https://files.pythonhosted.org/packages/8a/c4/d15f1e627fff25443ded77ea70a7b5532d6371498f9285d44d62587e209c/tqdm-4.64.0-py2.py3-none-any.whl#sha256=74a2cdefe14d11442cedf3ba4e21a3b84ff9a2dbdc6cfae2c34addb2a14a5ea6
623 | Best match: tqdm 4.64.0
624 | Processing tqdm-4.64.0-py2.py3-none-any.whl
625 | Installing tqdm-4.64.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
626 | Adding tqdm 4.64.0 to easy-install.pth file
627 | Installing tqdm script to /data3/***/anaconda3/envs/clrnet/bin
628 |
629 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/tqdm-4.64.0-py3.8.egg
630 | Searching for scikit-image
631 | Reading https://pypi.org/simple/scikit-image/
632 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.7.2 is an invalid version and will not be supported in a future release
633 | warnings.warn(
634 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.8.0 is an invalid version and will not be supported in a future release
635 | warnings.warn(
636 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.8.1 is an invalid version and will not be supported in a future release
637 | warnings.warn(
638 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.8.2 is an invalid version and will not be supported in a future release
639 | warnings.warn(
640 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.9.0 is an invalid version and will not be supported in a future release
641 | warnings.warn(
642 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.9.1 is an invalid version and will not be supported in a future release
643 | warnings.warn(
644 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.9.3 is an invalid version and will not be supported in a future release
645 | warnings.warn(
646 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.10.0 is an invalid version and will not be supported in a future release
647 | warnings.warn(
648 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.10.1 is an invalid version and will not be supported in a future release
649 | warnings.warn(
650 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.11.2 is an invalid version and will not be supported in a future release
651 | warnings.warn(
652 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.11.3 is an invalid version and will not be supported in a future release
653 | warnings.warn(
654 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.12.0 is an invalid version and will not be supported in a future release
655 | warnings.warn(
656 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.12.1 is an invalid version and will not be supported in a future release
657 | warnings.warn(
658 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.12.2 is an invalid version and will not be supported in a future release
659 | warnings.warn(
660 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.12.3 is an invalid version and will not be supported in a future release
661 | warnings.warn(
662 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.13.0 is an invalid version and will not be supported in a future release
663 | warnings.warn(
664 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.13.1 is an invalid version and will not be supported in a future release
665 | warnings.warn(
666 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.14.0 is an invalid version and will not be supported in a future release
667 | warnings.warn(
668 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.14.1 is an invalid version and will not be supported in a future release
669 | warnings.warn(
670 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.14.2 is an invalid version and will not be supported in a future release
671 | warnings.warn(
672 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.14.3 is an invalid version and will not be supported in a future release
673 | warnings.warn(
674 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.14.5 is an invalid version and will not be supported in a future release
675 | warnings.warn(
676 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.15.0 is an invalid version and will not be supported in a future release
677 | warnings.warn(
678 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.16.2 is an invalid version and will not be supported in a future release
679 | warnings.warn(
680 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.17.1 is an invalid version and will not be supported in a future release
681 | warnings.warn(
682 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.17.2 is an invalid version and will not be supported in a future release
683 | warnings.warn(
684 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.18.0 is an invalid version and will not be supported in a future release
685 | warnings.warn(
686 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.18.1 is an invalid version and will not be supported in a future release
687 | warnings.warn(
688 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.18.2 is an invalid version and will not be supported in a future release
689 | warnings.warn(
690 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.18.3 is an invalid version and will not be supported in a future release
691 | warnings.warn(
692 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.19.0rc0 is an invalid version and will not be supported in a future release
693 | warnings.warn(
694 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.19.0 is an invalid version and will not be supported in a future release
695 | warnings.warn(
696 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.19.1 is an invalid version and will not be supported in a future release
697 | warnings.warn(
698 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.19.2 is an invalid version and will not be supported in a future release
699 | warnings.warn(
700 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: image-0.19.3 is an invalid version and will not be supported in a future release
701 | warnings.warn(
702 | Downloading https://files.pythonhosted.org/packages/96/11/878ee6757f75835c396fbdd934ca8e1a1681553ac0925fbf77065c9618e5/scikit_image-0.19.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=e207c6ce5ce121d7d9b9d2b61b9adca57d1abed112c902d8ffbfdc20fb42c12b
703 | Best match: scikit-image 0.19.3
704 | Processing scikit_image-0.19.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
705 | Installing scikit_image-0.19.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
706 | Adding scikit-image 0.19.3 to easy-install.pth file
707 | Installing skivi script to /data3/***/anaconda3/envs/clrnet/bin
708 |
709 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/scikit_image-0.19.3-py3.8-linux-x86_64.egg
710 | Searching for pytorch_warmup
711 | Reading https://pypi.org/simple/pytorch_warmup/
712 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: warmup-0.0.3 is an invalid version and will not be supported in a future release
713 | warnings.warn(
714 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: warmup-0.0.4 is an invalid version and will not be supported in a future release
715 | warnings.warn(
716 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: warmup-0.1.0 is an invalid version and will not be supported in a future release
717 | warnings.warn(
718 | Downloading https://files.pythonhosted.org/packages/19/95/a9fb4ddf493a3df4a0ca947cc03ab452c569e5c5d5c66e31bfbdcfd2ae99/pytorch-warmup-0.1.0.tar.gz#sha256=6462c34b6942a61df54bca72c3ed7c8df8ef7089ade1f39724ec15e47022fb45
719 | Best match: pytorch-warmup 0.1.0
720 | Processing pytorch-warmup-0.1.0.tar.gz
721 | Writing /tmp/easy_install-m9mhpmsj/pytorch-warmup-0.1.0/setup.cfg
722 | Running pytorch-warmup-0.1.0/setup.py -q bdist_egg --dist-dir /tmp/easy_install-m9mhpmsj/pytorch-warmup-0.1.0/egg-dist-tmp-4vd6i_sa
723 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
724 | warnings.warn(
725 | zip_safe flag not set; analyzing archive contents...
726 | Moving pytorch_warmup-0.1.0-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
727 | Adding pytorch-warmup 0.1.0 to easy-install.pth file
728 |
729 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pytorch_warmup-0.1.0-py3.8.egg
730 | Searching for opencv-python
731 | Reading https://pypi.org/simple/opencv-python/
732 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.0.14 is an invalid version and will not be supported in a future release
733 | warnings.warn(
734 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release
735 | warnings.warn(
736 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.10.37 is an invalid version and will not be supported in a future release
737 | warnings.warn(
738 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.11.39 is an invalid version and will not be supported in a future release
739 | warnings.warn(
740 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.11.41 is an invalid version and will not be supported in a future release
741 | warnings.warn(
742 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.11.43 is an invalid version and will not be supported in a future release
743 | warnings.warn(
744 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.11.45 is an invalid version and will not be supported in a future release
745 | warnings.warn(
746 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.13.47 is an invalid version and will not be supported in a future release
747 | warnings.warn(
748 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.15.55 is an invalid version and will not be supported in a future release
749 | warnings.warn(
750 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.16.57 is an invalid version and will not be supported in a future release
751 | warnings.warn(
752 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.16.59 is an invalid version and will not be supported in a future release
753 | warnings.warn(
754 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.17.61 is an invalid version and will not be supported in a future release
755 | warnings.warn(
756 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.17.63 is an invalid version and will not be supported in a future release
757 | warnings.warn(
758 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-3.4.18.65 is an invalid version and will not be supported in a future release
759 | warnings.warn(
760 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.3.0.38 is an invalid version and will not be supported in a future release
761 | warnings.warn(
762 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.4.0.40 is an invalid version and will not be supported in a future release
763 | warnings.warn(
764 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.4.0.42 is an invalid version and will not be supported in a future release
765 | warnings.warn(
766 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.4.0.44 is an invalid version and will not be supported in a future release
767 | warnings.warn(
768 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.4.0.46 is an invalid version and will not be supported in a future release
769 | warnings.warn(
770 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.1.48 is an invalid version and will not be supported in a future release
771 | warnings.warn(
772 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.3.56 is an invalid version and will not be supported in a future release
773 | warnings.warn(
774 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.4.58 is an invalid version and will not be supported in a future release
775 | warnings.warn(
776 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.4.60 is an invalid version and will not be supported in a future release
777 | warnings.warn(
778 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.5.62 is an invalid version and will not be supported in a future release
779 | warnings.warn(
780 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.5.5.64 is an invalid version and will not be supported in a future release
781 | warnings.warn(
782 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: python-4.6.0.66 is an invalid version and will not be supported in a future release
783 | warnings.warn(
784 | Downloading https://files.pythonhosted.org/packages/af/bf/8d189a5c43460f6b5c8eb81ead8732e94b9f73ef8d9abba9e8f5a61a6531/opencv_python-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=dbdc84a9b4ea2cbae33861652d25093944b9959279200b7ae0badd32439f74de
785 | Best match: opencv-python 4.6.0.66
786 | Processing opencv_python-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
787 | Installing opencv_python-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
788 | Adding opencv-python 4.6.0.66 to easy-install.pth file
789 |
790 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/opencv_python-4.6.0.66-py3.8-linux-x86_64.egg
791 | Searching for sklearn
792 | Reading https://pypi.org/simple/sklearn/
793 | Downloading https://files.pythonhosted.org/packages/1e/7a/dbb3be0ce9bd5c8b7e3d87328e79063f8b263b2b1bfa4774cb1147bfcd3f/sklearn-0.0.tar.gz#sha256=e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31
794 | Best match: sklearn 0.0
795 | Processing sklearn-0.0.tar.gz
796 | Writing /tmp/easy_install-0b3r3r4j/sklearn-0.0/setup.cfg
797 | Running sklearn-0.0/setup.py -q bdist_egg --dist-dir /tmp/easy_install-0b3r3r4j/sklearn-0.0/egg-dist-tmp-px69reo0
798 | file wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141.py (for module wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141) not found
799 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
800 | warnings.warn(
801 | file wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141.py (for module wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141) not found
802 | file wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141.py (for module wheel-platform-tag-is-broken-on-empty-wheels-see-issue-141) not found
803 | warning: install_lib: 'build/lib' does not exist -- no Python modules to install
804 |
805 | creating /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/sklearn-0.0-py3.8.egg
806 | Extracting sklearn-0.0-py3.8.egg to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
807 | Adding sklearn 0.0 to easy-install.pth file
808 |
809 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/sklearn-0.0-py3.8.egg
810 | Searching for addict
811 | Reading https://pypi.org/simple/addict/
812 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release
813 | warnings.warn(
814 | Downloading https://files.pythonhosted.org/packages/6a/00/b08f23b7d7e1e14ce01419a467b583edbb93c6cdb8654e54a9cc579cd61f/addict-2.4.0-py3-none-any.whl#sha256=249bb56bbfd3cdc2a004ea0ff4c2b6ddc84d53bc2194761636eb314d5cfa5dfc
815 | Best match: addict 2.4.0
816 | Processing addict-2.4.0-py3-none-any.whl
817 | Installing addict-2.4.0-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
818 | Adding addict 2.4.0 to easy-install.pth file
819 |
820 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/addict-2.4.0-py3.8.egg
821 | Searching for pandas
822 | Reading https://pypi.org/simple/pandas/
823 | Downloading https://files.pythonhosted.org/packages/76/59/3451a9898bf236a86494829a889ea901930dc1c493f403912192158e4390/pandas-1.5.0rc0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=0a3cfd8777f3eb9fa6df12ea6f4aea338a9205507eb12d8dca3b3f00cf4ffe17
824 | Best match: pandas 1.5.0rc0
825 | Processing pandas-1.5.0rc0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
826 | Installing pandas-1.5.0rc0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
827 | Adding pandas 1.5.0rc0 to easy-install.pth file
828 |
829 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pandas-1.5.0rc0-py3.8-linux-x86_64.egg
830 | Searching for opencv-python-headless>=4.1.1
831 | Reading https://pypi.org/simple/opencv-python-headless/
832 | Download error on https://pypi.org/simple/opencv-python-headless/: The read operation timed out -- Some packages may not be found!
833 | Couldn't find index page for 'opencv-python-headless' (maybe misspelled?)
834 | Scanning index of all packages (this may take a while)
835 | Reading https://pypi.org/simple/
836 | No local packages or working download links found for opencv-python-headless>=4.1.1
837 | error: Could not find suitable distribution for Requirement.parse('opencv-python-headless>=4.1.1')
838 | (clrnet) [***@ai02 clrnet]$ pip list
839 | Package Version Editable project location
840 | ----------------- --------- --------------------------------
841 | addict 2.4.0
842 | albumentations 0.4.6
843 | certifi 2022.6.15
844 | clrnet 1.0 /data3/***/project/temp/clrnet
845 | imgaug 0.4.0
846 | mmcv 1.2.5
847 | numpy 1.23.2
848 | opencv-python 4.6.0.66
849 | p-tqdm 1.3.3
850 | pandas 1.5.0rc0
851 | pathspec 0.9.0
852 | Pillow 9.2.0
853 | pip 22.1.2
854 | ptflops 0.6.9
855 | pytorch-warmup 0.1.0
856 | scikit-image 0.19.3
857 | setuptools 63.4.1
858 | Shapely 1.7.0
859 | sklearn 0.0
860 | timm 0.6.7
861 | torch 1.8.0
862 | torchvision 0.9.0
863 | tqdm 4.64.0
864 | typing_extensions 4.3.0
865 | ujson 1.35
866 | wheel 0.37.1
867 | yapf 0.32.0
868 | (clrnet) [***@ai02 clrnet]$ pip install opencv-python-headless
869 | Collecting opencv-python-headless
870 | Using cached opencv_python_headless-4.6.0.66-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (48.3 MB)
871 | Requirement already satisfied: numpy>=1.14.5 in /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages (from opencv-python-headless) (1.23.2)
872 | Installing collected packages: opencv-python-headless
873 | ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
874 | albumentations 0.4.6 requires PyYAML, which is not installed.
875 | albumentations 0.4.6 requires scipy, which is not installed.
876 | Successfully installed opencv-python-headless-4.6.0.66
877 | (clrnet) [***@ai02 clrnet]$ python setup.py build develop
878 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/installer.py:27: SetuptoolsDeprecationWarning: setuptools.installer is deprecated. Requirements should be satisfied by a PEP 517 installer.
879 | warnings.warn(
880 | running build
881 | running build_py
882 | running egg_info
883 | writing clrnet.egg-info/PKG-INFO
884 | writing dependency_links to clrnet.egg-info/dependency_links.txt
885 | writing requirements to clrnet.egg-info/requires.txt
886 | writing top-level names to clrnet.egg-info/top_level.txt
887 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/utils/cpp_extension.py:369: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
888 | warnings.warn(msg.format('we could not find ninja.'))
889 | reading manifest file 'clrnet.egg-info/SOURCES.txt'
890 | adding license file 'LICENSE'
891 | writing manifest file 'clrnet.egg-info/SOURCES.txt'
892 | running build_ext
893 | running develop
894 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
895 | warnings.warn(
896 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
897 | warnings.warn(
898 | running build_ext
899 | copying build/lib.linux-x86_64-cpython-38/clrnet/ops/nms_impl.cpython-38-x86_64-linux-gnu.so -> clrnet/ops
900 | Creating /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/clrnet.egg-link (link to .)
901 | clrnet 1.0 is already the active version in easy-install.pth
902 |
903 | Installed /data3/***/project/temp/clrnet
904 | Processing dependencies for clrnet==1.0
905 | Searching for PyYAML
906 | Reading https://pypi.org/simple/PyYAML/
907 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: is an invalid version and will not be supported in a future release
908 | warnings.warn(
909 | Downloading https://files.pythonhosted.org/packages/d7/42/7ad4b6d67a16229496d4f6e74201bdbebcf4bc1e87d5a70c9297d4961bd2/PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl#sha256=277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287
910 | Best match: PyYAML 6.0
911 | Processing PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl
912 | Installing PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
913 | Adding PyYAML 6.0 to easy-install.pth file
914 |
915 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/PyYAML-6.0-py3.8-linux-x86_64.egg
916 | Searching for scipy
917 | Reading https://pypi.org/simple/scipy/
918 | Downloading https://files.pythonhosted.org/packages/f7/38/f5fad7b9a50a26a6a709549f815db7026a97d1a5652e548d5bbe587bb954/scipy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=73b704c5eea9be811919cae4caacf3180dd9212d9aed08477c1d2ba14900a9de
919 | Best match: scipy 1.9.0
920 | Processing scipy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
921 | Installing scipy-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
922 | Adding scipy 1.9.0 to easy-install.pth file
923 |
924 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/scipy-1.9.0-py3.8-linux-x86_64.egg
925 | Searching for imageio
926 | Reading https://pypi.org/simple/imageio/
927 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.1-linux32 is an invalid version and will not be supported in a future release
928 | warnings.warn(
929 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: linux32 is an invalid version and will not be supported in a future release
930 | warnings.warn(
931 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.1-linux64 is an invalid version and will not be supported in a future release
932 | warnings.warn(
933 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: linux64 is an invalid version and will not be supported in a future release
934 | warnings.warn(
935 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.1-osx64 is an invalid version and will not be supported in a future release
936 | warnings.warn(
937 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: osx64 is an invalid version and will not be supported in a future release
938 | warnings.warn(
939 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.1-win32 is an invalid version and will not be supported in a future release
940 | warnings.warn(
941 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: win32 is an invalid version and will not be supported in a future release
942 | warnings.warn(
943 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.1-win64 is an invalid version and will not be supported in a future release
944 | warnings.warn(
945 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: win64 is an invalid version and will not be supported in a future release
946 | warnings.warn(
947 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.2-linux32 is an invalid version and will not be supported in a future release
948 | warnings.warn(
949 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.2-linux64 is an invalid version and will not be supported in a future release
950 | warnings.warn(
951 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.2-osx64 is an invalid version and will not be supported in a future release
952 | warnings.warn(
953 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.2-win32 is an invalid version and will not be supported in a future release
954 | warnings.warn(
955 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.2-win64 is an invalid version and will not be supported in a future release
956 | warnings.warn(
957 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.3-linux32 is an invalid version and will not be supported in a future release
958 | warnings.warn(
959 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.3-linux64 is an invalid version and will not be supported in a future release
960 | warnings.warn(
961 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.3-osx64 is an invalid version and will not be supported in a future release
962 | warnings.warn(
963 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.3-win32 is an invalid version and will not be supported in a future release
964 | warnings.warn(
965 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 1.3-win64 is an invalid version and will not be supported in a future release
966 | warnings.warn(
967 | Downloading https://files.pythonhosted.org/packages/19/ea/fe0cf58fe26999fc22b7cc75215a6093e1f9e146b747b1c39ad7ed320e92/imageio-2.21.1-py3-none-any.whl#sha256=ea8770d082cea02de6ca5500ab3ad649a8c09832528152efd07da5c225b13722
968 | Best match: imageio 2.21.1
969 | Processing imageio-2.21.1-py3-none-any.whl
970 | Installing imageio-2.21.1-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
971 | Adding imageio 2.21.1 to easy-install.pth file
972 | Installing imageio_download_bin script to /data3/***/anaconda3/envs/clrnet/bin
973 | Installing imageio_remove_bin script to /data3/***/anaconda3/envs/clrnet/bin
974 |
975 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/imageio-2.21.1-py3.8.egg
976 | Searching for matplotlib
977 | Reading https://pypi.org/simple/matplotlib/
978 | Downloading https://files.pythonhosted.org/packages/45/5e/5f45fbab5d259403348bebcf1aac514c64ecc0c27a8a24dca9a4a7476403/matplotlib-3.6.0rc1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl#sha256=daff0aff6e90fddd8fb53aa9e7116e0dbd48861573070c69faf7342ba2c20a44
979 | Best match: matplotlib 3.6.0rc1
980 | Processing matplotlib-3.6.0rc1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl
981 | Installing matplotlib-3.6.0rc1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
982 | Adding matplotlib 3.6.0rc1 to easy-install.pth file
983 |
984 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/matplotlib-3.6.0rc1-py3.8-linux-x86_64.egg
985 | Searching for six
986 | Reading https://pypi.org/simple/six/
987 | Downloading https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl#sha256=8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254
988 | Best match: six 1.16.0
989 | Processing six-1.16.0-py2.py3-none-any.whl
990 | Installing six-1.16.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
991 | Adding six 1.16.0 to easy-install.pth file
992 |
993 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/six-1.16.0-py3.8.egg
994 | Searching for pathos
995 | Reading https://pypi.org/simple/pathos/
996 | Downloading https://files.pythonhosted.org/packages/57/78/3f3a6c8f9447a947af2f2a191f71cc870b0fa673432044f38d966f3fd9c0/pathos-0.2.9-py3-none-any.whl#sha256=1c44373d8692897d5d15a8aa3b3a442ddc0814c5e848f4ff0ded5491f34b1dac
997 | Best match: pathos 0.2.9
998 | Processing pathos-0.2.9-py3-none-any.whl
999 | Installing pathos-0.2.9-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1000 | Adding pathos 0.2.9 to easy-install.pth file
1001 | Installing portpicker script to /data3/***/anaconda3/envs/clrnet/bin
1002 | Installing pathos_connect script to /data3/***/anaconda3/envs/clrnet/bin
1003 |
1004 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pathos-0.2.9-py3.8.egg
1005 | Searching for packaging>=20.0
1006 | Reading https://pypi.org/simple/packaging/
1007 | Downloading https://files.pythonhosted.org/packages/05/8e/8de486cbd03baba4deef4142bd643a3e7bbe954a784dc1bb17142572d127/packaging-21.3-py3-none-any.whl#sha256=ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522
1008 | Best match: packaging 21.3
1009 | Processing packaging-21.3-py3-none-any.whl
1010 | Installing packaging-21.3-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1011 | Adding packaging 21.3 to easy-install.pth file
1012 |
1013 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/packaging-21.3-py3.8.egg
1014 | Searching for PyWavelets>=1.1.1
1015 | Reading https://pypi.org/simple/PyWavelets/
1016 | Downloading https://files.pythonhosted.org/packages/4b/20/04a0a3e43a45a459c2bcde756833b2eca9430729e89d65da35f70e99e997/PyWavelets-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=d7369597e1b1d125eb4b458a36cef052beed188444e55ed21445c1196008e200
1017 | Best match: PyWavelets 1.3.0
1018 | Processing PyWavelets-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
1019 | Installing PyWavelets-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1020 | Adding PyWavelets 1.3.0 to easy-install.pth file
1021 |
1022 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/PyWavelets-1.3.0-py3.8-linux-x86_64.egg
1023 | Searching for tifffile>=2019.7.26
1024 | Reading https://pypi.org/simple/tifffile/
1025 | Downloading https://files.pythonhosted.org/packages/45/84/31e59ef72ac4149bb27ab9ccb3aa2b0d294abd97cf61dafd599bddf50a69/tifffile-2022.8.12-py3-none-any.whl#sha256=1456f9f6943c85082ef4d73f5329038826da67f70d5d513873a06f3b1598d23e
1026 | Best match: tifffile 2022.8.12
1027 | Processing tifffile-2022.8.12-py3-none-any.whl
1028 | Installing tifffile-2022.8.12-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1029 | Adding tifffile 2022.8.12 to easy-install.pth file
1030 | Installing lsm2bin script to /data3/***/anaconda3/envs/clrnet/bin
1031 | Installing tiff2fsspec script to /data3/***/anaconda3/envs/clrnet/bin
1032 | Installing tiffcomment script to /data3/***/anaconda3/envs/clrnet/bin
1033 | Installing tifffile script to /data3/***/anaconda3/envs/clrnet/bin
1034 |
1035 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/tifffile-2022.8.12-py3.8.egg
1036 | Searching for networkx>=2.2
1037 | Reading https://pypi.org/simple/networkx/
1038 | Downloading https://files.pythonhosted.org/packages/be/25/5b0fc262a2f2d7d11c22cb7785edf2befc756ae076b383034e79e255eb11/networkx-2.8.6-py3-none-any.whl#sha256=2a30822761f34d56b9a370d96a4bf4827a535f5591a4078a453425caeba0c5bb
1039 | Best match: networkx 2.8.6
1040 | Processing networkx-2.8.6-py3-none-any.whl
1041 | Installing networkx-2.8.6-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1042 | Adding networkx 2.8.6 to easy-install.pth file
1043 |
1044 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/networkx-2.8.6-py3.8.egg
1045 | Searching for scikit-learn
1046 | Reading https://pypi.org/simple/scikit-learn/
1047 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.9 is an invalid version and will not be supported in a future release
1048 | warnings.warn(
1049 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.10 is an invalid version and will not be supported in a future release
1050 | warnings.warn(
1051 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.11 is an invalid version and will not be supported in a future release
1052 | warnings.warn(
1053 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.12 is an invalid version and will not be supported in a future release
1054 | warnings.warn(
1055 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.12.1 is an invalid version and will not be supported in a future release
1056 | warnings.warn(
1057 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.13 is an invalid version and will not be supported in a future release
1058 | warnings.warn(
1059 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.13.1 is an invalid version and will not be supported in a future release
1060 | warnings.warn(
1061 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.14 is an invalid version and will not be supported in a future release
1062 | warnings.warn(
1063 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.14.1 is an invalid version and will not be supported in a future release
1064 | warnings.warn(
1065 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.15.0b1 is an invalid version and will not be supported in a future release
1066 | warnings.warn(
1067 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.15.0b2 is an invalid version and will not be supported in a future release
1068 | warnings.warn(
1069 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.15.0 is an invalid version and will not be supported in a future release
1070 | warnings.warn(
1071 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.15.1 is an invalid version and will not be supported in a future release
1072 | warnings.warn(
1073 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.15.2 is an invalid version and will not be supported in a future release
1074 | warnings.warn(
1075 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.16b1 is an invalid version and will not be supported in a future release
1076 | warnings.warn(
1077 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.16.0 is an invalid version and will not be supported in a future release
1078 | warnings.warn(
1079 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.16.1 is an invalid version and will not be supported in a future release
1080 | warnings.warn(
1081 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.17b1 is an invalid version and will not be supported in a future release
1082 | warnings.warn(
1083 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.17 is an invalid version and will not be supported in a future release
1084 | warnings.warn(
1085 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.17.1 is an invalid version and will not be supported in a future release
1086 | warnings.warn(
1087 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.18rc2 is an invalid version and will not be supported in a future release
1088 | warnings.warn(
1089 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.18 is an invalid version and will not be supported in a future release
1090 | warnings.warn(
1091 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.18.1 is an invalid version and will not be supported in a future release
1092 | warnings.warn(
1093 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.18.2 is an invalid version and will not be supported in a future release
1094 | warnings.warn(
1095 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.19b2 is an invalid version and will not be supported in a future release
1096 | warnings.warn(
1097 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.19.0 is an invalid version and will not be supported in a future release
1098 | warnings.warn(
1099 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.19.1 is an invalid version and will not be supported in a future release
1100 | warnings.warn(
1101 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.19.2 is an invalid version and will not be supported in a future release
1102 | warnings.warn(
1103 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20rc1 is an invalid version and will not be supported in a future release
1104 | warnings.warn(
1105 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20.0 is an invalid version and will not be supported in a future release
1106 | warnings.warn(
1107 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20.1 is an invalid version and will not be supported in a future release
1108 | warnings.warn(
1109 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20.2 is an invalid version and will not be supported in a future release
1110 | warnings.warn(
1111 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20.3 is an invalid version and will not be supported in a future release
1112 | warnings.warn(
1113 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.20.4 is an invalid version and will not be supported in a future release
1114 | warnings.warn(
1115 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.21rc2 is an invalid version and will not be supported in a future release
1116 | warnings.warn(
1117 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.21.1 is an invalid version and will not be supported in a future release
1118 | warnings.warn(
1119 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.21.2 is an invalid version and will not be supported in a future release
1120 | warnings.warn(
1121 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.21.3 is an invalid version and will not be supported in a future release
1122 | warnings.warn(
1123 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.22rc2.post1 is an invalid version and will not be supported in a future release
1124 | warnings.warn(
1125 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.22rc3 is an invalid version and will not be supported in a future release
1126 | warnings.warn(
1127 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.22 is an invalid version and will not be supported in a future release
1128 | warnings.warn(
1129 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.22.1 is an invalid version and will not be supported in a future release
1130 | warnings.warn(
1131 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.22.2.post1 is an invalid version and will not be supported in a future release
1132 | warnings.warn(
1133 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.23.0rc1 is an invalid version and will not be supported in a future release
1134 | warnings.warn(
1135 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.23.0 is an invalid version and will not be supported in a future release
1136 | warnings.warn(
1137 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.23.1 is an invalid version and will not be supported in a future release
1138 | warnings.warn(
1139 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.23.2 is an invalid version and will not be supported in a future release
1140 | warnings.warn(
1141 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.24.dev0 is an invalid version and will not be supported in a future release
1142 | warnings.warn(
1143 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.24.0rc1 is an invalid version and will not be supported in a future release
1144 | warnings.warn(
1145 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.24.0 is an invalid version and will not be supported in a future release
1146 | warnings.warn(
1147 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.24.1 is an invalid version and will not be supported in a future release
1148 | warnings.warn(
1149 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-0.24.2 is an invalid version and will not be supported in a future release
1150 | warnings.warn(
1151 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.0rc1 is an invalid version and will not be supported in a future release
1152 | warnings.warn(
1153 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.0rc2 is an invalid version and will not be supported in a future release
1154 | warnings.warn(
1155 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.0 is an invalid version and will not be supported in a future release
1156 | warnings.warn(
1157 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.0.1 is an invalid version and will not be supported in a future release
1158 | warnings.warn(
1159 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.0.2 is an invalid version and will not be supported in a future release
1160 | warnings.warn(
1161 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.1.0rc1 is an invalid version and will not be supported in a future release
1162 | warnings.warn(
1163 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.1.0 is an invalid version and will not be supported in a future release
1164 | warnings.warn(
1165 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.1.1 is an invalid version and will not be supported in a future release
1166 | warnings.warn(
1167 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: learn-1.1.2 is an invalid version and will not be supported in a future release
1168 | warnings.warn(
1169 | Downloading https://files.pythonhosted.org/packages/91/d1/50eb92222e8b2f315ec5499b97a926d271305e19e254fdced4db899647d6/scikit_learn-1.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=f94c0146bad51daef919c402a3da8c1c6162619653e1c00c92baa168fda292f2
1170 | Best match: scikit-learn 1.1.2
1171 | Processing scikit_learn-1.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
1172 | Installing scikit_learn-1.1.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1173 | Adding scikit-learn 1.1.2 to easy-install.pth file
1174 |
1175 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/scikit_learn-1.1.2-py3.8-linux-x86_64.egg
1176 | Searching for pytz>=2020.1
1177 | Reading https://pypi.org/simple/pytz/
1178 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2004d is an invalid version and will not be supported in a future release
1179 | warnings.warn(
1180 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2005e is an invalid version and will not be supported in a future release
1181 | warnings.warn(
1182 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2005i is an invalid version and will not be supported in a future release
1183 | warnings.warn(
1184 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2005k is an invalid version and will not be supported in a future release
1185 | warnings.warn(
1186 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2005m is an invalid version and will not be supported in a future release
1187 | warnings.warn(
1188 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2006g is an invalid version and will not be supported in a future release
1189 | warnings.warn(
1190 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2006j is an invalid version and will not be supported in a future release
1191 | warnings.warn(
1192 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2006p is an invalid version and will not be supported in a future release
1193 | warnings.warn(
1194 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2007d is an invalid version and will not be supported in a future release
1195 | warnings.warn(
1196 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2007f is an invalid version and will not be supported in a future release
1197 | warnings.warn(
1198 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2007g is an invalid version and will not be supported in a future release
1199 | warnings.warn(
1200 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2007i is an invalid version and will not be supported in a future release
1201 | warnings.warn(
1202 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2007k is an invalid version and will not be supported in a future release
1203 | warnings.warn(
1204 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2008g is an invalid version and will not be supported in a future release
1205 | warnings.warn(
1206 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2008h is an invalid version and will not be supported in a future release
1207 | warnings.warn(
1208 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2008i is an invalid version and will not be supported in a future release
1209 | warnings.warn(
1210 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009d is an invalid version and will not be supported in a future release
1211 | warnings.warn(
1212 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009e is an invalid version and will not be supported in a future release
1213 | warnings.warn(
1214 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009f is an invalid version and will not be supported in a future release
1215 | warnings.warn(
1216 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009g is an invalid version and will not be supported in a future release
1217 | warnings.warn(
1218 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009i is an invalid version and will not be supported in a future release
1219 | warnings.warn(
1220 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009j is an invalid version and will not be supported in a future release
1221 | warnings.warn(
1222 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009l is an invalid version and will not be supported in a future release
1223 | warnings.warn(
1224 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009n is an invalid version and will not be supported in a future release
1225 | warnings.warn(
1226 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009p is an invalid version and will not be supported in a future release
1227 | warnings.warn(
1228 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2009u is an invalid version and will not be supported in a future release
1229 | warnings.warn(
1230 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010e is an invalid version and will not be supported in a future release
1231 | warnings.warn(
1232 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010g is an invalid version and will not be supported in a future release
1233 | warnings.warn(
1234 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010h is an invalid version and will not be supported in a future release
1235 | warnings.warn(
1236 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010k is an invalid version and will not be supported in a future release
1237 | warnings.warn(
1238 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010l is an invalid version and will not be supported in a future release
1239 | warnings.warn(
1240 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2010o is an invalid version and will not be supported in a future release
1241 | warnings.warn(
1242 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011d is an invalid version and will not be supported in a future release
1243 | warnings.warn(
1244 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011e is an invalid version and will not be supported in a future release
1245 | warnings.warn(
1246 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011g is an invalid version and will not be supported in a future release
1247 | warnings.warn(
1248 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011h is an invalid version and will not be supported in a future release
1249 | warnings.warn(
1250 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011j is an invalid version and will not be supported in a future release
1251 | warnings.warn(
1252 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011k is an invalid version and will not be supported in a future release
1253 | warnings.warn(
1254 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2011n is an invalid version and will not be supported in a future release
1255 | warnings.warn(
1256 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2012d is an invalid version and will not be supported in a future release
1257 | warnings.warn(
1258 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2012f is an invalid version and will not be supported in a future release
1259 | warnings.warn(
1260 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2012g is an invalid version and will not be supported in a future release
1261 | warnings.warn(
1262 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2012h is an invalid version and will not be supported in a future release
1263 | warnings.warn(
1264 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2012j is an invalid version and will not be supported in a future release
1265 | warnings.warn(
1266 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 2013d is an invalid version and will not be supported in a future release
1267 | warnings.warn(
1268 | Downloading https://files.pythonhosted.org/packages/d5/50/54451e88e3da4616286029a3a17fc377de817f66a0f50e1faaee90161724/pytz-2022.2.1-py2.py3-none-any.whl#sha256=220f481bdafa09c3955dfbdddb7b57780e9a94f5127e35456a48589b9e0c0197
1269 | Best match: pytz 2022.2.1
1270 | Processing pytz-2022.2.1-py2.py3-none-any.whl
1271 | Installing pytz-2022.2.1-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1272 | Adding pytz 2022.2.1 to easy-install.pth file
1273 |
1274 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pytz-2022.2.1-py3.8.egg
1275 | Searching for python-dateutil>=2.8.1
1276 | Reading https://pypi.org/simple/python-dateutil/
1277 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-1.4 is an invalid version and will not be supported in a future release
1278 | warnings.warn(
1279 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-1.4.1 is an invalid version and will not be supported in a future release
1280 | warnings.warn(
1281 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-1.5 is an invalid version and will not be supported in a future release
1282 | warnings.warn(
1283 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.1 is an invalid version and will not be supported in a future release
1284 | warnings.warn(
1285 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.2 is an invalid version and will not be supported in a future release
1286 | warnings.warn(
1287 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.3 is an invalid version and will not be supported in a future release
1288 | warnings.warn(
1289 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.4.0 is an invalid version and will not be supported in a future release
1290 | warnings.warn(
1291 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.4.1.post1 is an invalid version and will not be supported in a future release
1292 | warnings.warn(
1293 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.4.1 is an invalid version and will not be supported in a future release
1294 | warnings.warn(
1295 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.4.2 is an invalid version and will not be supported in a future release
1296 | warnings.warn(
1297 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.5.0 is an invalid version and will not be supported in a future release
1298 | warnings.warn(
1299 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.5.1 is an invalid version and will not be supported in a future release
1300 | warnings.warn(
1301 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.5.2 is an invalid version and will not be supported in a future release
1302 | warnings.warn(
1303 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.5.3 is an invalid version and will not be supported in a future release
1304 | warnings.warn(
1305 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.6.0 is an invalid version and will not be supported in a future release
1306 | warnings.warn(
1307 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.6.1 is an invalid version and will not be supported in a future release
1308 | warnings.warn(
1309 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.0 is an invalid version and will not be supported in a future release
1310 | warnings.warn(
1311 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.1 is an invalid version and will not be supported in a future release
1312 | warnings.warn(
1313 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.2 is an invalid version and will not be supported in a future release
1314 | warnings.warn(
1315 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.3 is an invalid version and will not be supported in a future release
1316 | warnings.warn(
1317 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.4 is an invalid version and will not be supported in a future release
1318 | warnings.warn(
1319 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.7.5 is an invalid version and will not be supported in a future release
1320 | warnings.warn(
1321 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.8.0 is an invalid version and will not be supported in a future release
1322 | warnings.warn(
1323 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.8.1 is an invalid version and will not be supported in a future release
1324 | warnings.warn(
1325 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: dateutil-2.8.2 is an invalid version and will not be supported in a future release
1326 | warnings.warn(
1327 | Downloading https://files.pythonhosted.org/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl#sha256=961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9
1328 | Best match: python-dateutil 2.8.2
1329 | Processing python_dateutil-2.8.2-py2.py3-none-any.whl
1330 | Installing python_dateutil-2.8.2-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1331 | Adding python-dateutil 2.8.2 to easy-install.pth file
1332 |
1333 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/python_dateutil-2.8.2-py3.8.egg
1334 | Searching for pyparsing>=2.2.1
1335 | Reading https://pypi.org/simple/pyparsing/
1336 | Downloading https://files.pythonhosted.org/packages/6c/10/a7d0fa5baea8fe7b50f448ab742f26f52b80bfca85ac2be9d35cdd9a3246/pyparsing-3.0.9-py3-none-any.whl#sha256=5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc
1337 | Best match: pyparsing 3.0.9
1338 | Processing pyparsing-3.0.9-py3-none-any.whl
1339 | Installing pyparsing-3.0.9-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1340 | Adding pyparsing 3.0.9 to easy-install.pth file
1341 |
1342 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pyparsing-3.0.9-py3.8.egg
1343 | Searching for kiwisolver>=1.0.1
1344 | Reading https://pypi.org/simple/kiwisolver/
1345 | Downloading https://files.pythonhosted.org/packages/86/7a/6b438da7534dacd232ed4e19f74f4edced2cda9494d7e6536f54edfdf4a5/kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl#sha256=2e407cb4bd5a13984a6c2c0fe1845e4e41e96f183e5e5cd4d77a857d9693494c
1346 | Best match: kiwisolver 1.4.4
1347 | Processing kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl
1348 | Installing kiwisolver-1.4.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1349 | Adding kiwisolver 1.4.4 to easy-install.pth file
1350 |
1351 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/kiwisolver-1.4.4-py3.8-linux-x86_64.egg
1352 | Searching for fonttools>=4.22.0
1353 | Reading https://pypi.org/simple/fonttools/
1354 | Downloading https://files.pythonhosted.org/packages/99/e2/90421047dfd94ed78064b41345d63985c576dc98b7c87bec92a1a8cb88e5/fonttools-4.37.1-py3-none-any.whl#sha256=fff6b752e326c15756c819fe2fe7ceab69f96a1dbcfe8911d0941cdb49905007
1355 | Best match: fonttools 4.37.1
1356 | Processing fonttools-4.37.1-py3-none-any.whl
1357 | Installing fonttools-4.37.1-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1358 | Adding fonttools 4.37.1 to easy-install.pth file
1359 | Installing fonttools script to /data3/***/anaconda3/envs/clrnet/bin
1360 | Installing pyftmerge script to /data3/***/anaconda3/envs/clrnet/bin
1361 | Installing pyftsubset script to /data3/***/anaconda3/envs/clrnet/bin
1362 | Installing ttx script to /data3/***/anaconda3/envs/clrnet/bin
1363 |
1364 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/fonttools-4.37.1-py3.8.egg
1365 | Searching for cycler>=0.10
1366 | Reading https://pypi.org/simple/cycler/
1367 | Downloading https://files.pythonhosted.org/packages/5c/f9/695d6bedebd747e5eb0fe8fad57b72fdf25411273a39791cde838d5a8f51/cycler-0.11.0-py3-none-any.whl#sha256=3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3
1368 | Best match: cycler 0.11.0
1369 | Processing cycler-0.11.0-py3-none-any.whl
1370 | Installing cycler-0.11.0-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1371 | Adding cycler 0.11.0 to easy-install.pth file
1372 |
1373 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/cycler-0.11.0-py3.8.egg
1374 | Searching for contourpy>=1.0.1
1375 | Reading https://pypi.org/simple/contourpy/
1376 | Downloading https://files.pythonhosted.org/packages/c1/e2/7cabfe4858cd8f2f71946ac5510421ade781074012666bc0627ef95f4105/contourpy-1.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl#sha256=2b20bf416a693744f1ebbbd6f24e2300b613c3570ed67e518a97dc052bf3ad11
1377 | Best match: contourpy 1.0.4
1378 | Processing contourpy-1.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
1379 | Installing contourpy-1.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1380 | Adding contourpy 1.0.4 to easy-install.pth file
1381 |
1382 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/contourpy-1.0.4-py3.8-linux-x86_64.egg
1383 | Searching for multiprocess>=0.70.13
1384 | Reading https://pypi.org/simple/multiprocess/
1385 | Downloading https://files.pythonhosted.org/packages/4b/28/915e6d4943eab538963a4fac7c409aa4a75213f5e6b98d91334e9a284c5c/multiprocess-0.70.13-py37-none-any.whl#sha256=62e556a0c31ec7176e28aa331663ac26c276ee3536b5e9bb5e850681e7a00f11
1386 | Best match: multiprocess 0.70.13
1387 | Processing multiprocess-0.70.13-py37-none-any.whl
1388 | Installing multiprocess-0.70.13-py37-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1389 | Adding multiprocess 0.70.13 to easy-install.pth file
1390 |
1391 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/multiprocess-0.70.13-py3.8.egg
1392 | Searching for pox>=0.3.1
1393 | Reading https://pypi.org/simple/pox/
1394 | Downloading https://files.pythonhosted.org/packages/ee/ae/47dde73e5fa582669d08d0c8284e1863d8e8353947da264c38bc2831e8bb/pox-0.3.1-py2.py3-none-any.whl#sha256=541b5c845aacb806c1364d4142003efb809d654c9ca8db82e650ee86c81e680b
1395 | Best match: pox 0.3.1
1396 | Processing pox-0.3.1-py2.py3-none-any.whl
1397 | Installing pox-0.3.1-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1398 | Adding pox 0.3.1 to easy-install.pth file
1399 | Installing pox script to /data3/***/anaconda3/envs/clrnet/bin
1400 |
1401 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pox-0.3.1-py3.8.egg
1402 | Searching for dill>=0.3.5.1
1403 | Reading https://pypi.org/simple/dill/
1404 | Downloading https://files.pythonhosted.org/packages/12/ff/3b1a8f5d59600393506c64fa14d13afdfe6fe79ed65a18d64026fe9f8356/dill-0.3.5.1-py2.py3-none-any.whl#sha256=33501d03270bbe410c72639b350e941882a8b0fd55357580fbc873fba0c59302
1405 | Best match: dill 0.3.5.1
1406 | Processing dill-0.3.5.1-py2.py3-none-any.whl
1407 | Installing dill-0.3.5.1-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1408 | Adding dill 0.3.5.1 to easy-install.pth file
1409 | Installing get_objgraph script to /data3/***/anaconda3/envs/clrnet/bin
1410 | Installing undill script to /data3/***/anaconda3/envs/clrnet/bin
1411 |
1412 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/dill-0.3.5.1-py3.8.egg
1413 | Searching for ppft>=1.7.6.5
1414 | Reading https://pypi.org/simple/ppft/
1415 | Downloading https://files.pythonhosted.org/packages/dc/e9/5060afcfad6e19f4725d59e64398e31ec9439863244b401cb3437becbbff/ppft-1.7.6.5-py2.py3-none-any.whl#sha256=07166097d7dd45af7b98859654390d579d11dadf20780f6baca4bded3f55a580
1416 | Best match: ppft 1.7.6.5
1417 | Processing ppft-1.7.6.5-py2.py3-none-any.whl
1418 | Installing ppft-1.7.6.5-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1419 | Adding ppft 1.7.6.5 to easy-install.pth file
1420 | Installing ppserver script to /data3/***/anaconda3/envs/clrnet/bin
1421 |
1422 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ppft-1.7.6.5-py3.8.egg
1423 | Searching for threadpoolctl>=2.0.0
1424 | Reading https://pypi.org/simple/threadpoolctl/
1425 | Downloading https://files.pythonhosted.org/packages/61/cf/6e354304bcb9c6413c4e02a747b600061c21d38ba51e7e544ac7bc66aecc/threadpoolctl-3.1.0-py3-none-any.whl#sha256=8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b
1426 | Best match: threadpoolctl 3.1.0
1427 | Processing threadpoolctl-3.1.0-py3-none-any.whl
1428 | Installing threadpoolctl-3.1.0-py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1429 | Adding threadpoolctl 3.1.0 to easy-install.pth file
1430 |
1431 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/threadpoolctl-3.1.0-py3.8.egg
1432 | Searching for joblib>=1.0.0
1433 | Reading https://pypi.org/simple/joblib/
1434 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.3.2d.dev is an invalid version and will not be supported in a future release
1435 | warnings.warn(
1436 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.3.2e.dev is an invalid version and will not be supported in a future release
1437 | warnings.warn(
1438 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.3.2f.dev is an invalid version and will not be supported in a future release
1439 | warnings.warn(
1440 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.3.2g.dev is an invalid version and will not be supported in a future release
1441 | warnings.warn(
1442 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: 0.7.0d is an invalid version and will not be supported in a future release
1443 | warnings.warn(
1444 | /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pkg_resources/__init__.py:123: PkgResourcesDeprecationWarning: r1 is an invalid version and will not be supported in a future release
1445 | warnings.warn(
1446 | Downloading https://files.pythonhosted.org/packages/3e/d5/0163eb0cfa0b673aa4fe1cd3ea9d8a81ea0f32e50807b0c295871e4aab2e/joblib-1.1.0-py2.py3-none-any.whl#sha256=f21f109b3c7ff9d95f8387f752d0d9c34a02aa2f7060c2135f465da0e5160ff6
1447 | Best match: joblib 1.1.0
1448 | Processing joblib-1.1.0-py2.py3-none-any.whl
1449 | Installing joblib-1.1.0-py2.py3-none-any.whl to /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1450 | Adding joblib 1.1.0 to easy-install.pth file
1451 |
1452 | Installed /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/joblib-1.1.0-py3.8.egg
1453 | Searching for ptflops==0.6.9
1454 | Best match: ptflops 0.6.9
1455 | Processing ptflops-0.6.9-py3.8.egg
1456 | ptflops 0.6.9 is already the active version in easy-install.pth
1457 |
1458 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ptflops-0.6.9-py3.8.egg
1459 | Searching for pathspec==0.9.0
1460 | Best match: pathspec 0.9.0
1461 | Processing pathspec-0.9.0-py3.8.egg
1462 | pathspec 0.9.0 is already the active version in easy-install.pth
1463 |
1464 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pathspec-0.9.0-py3.8.egg
1465 | Searching for albumentations==0.4.6
1466 | Best match: albumentations 0.4.6
1467 | Processing albumentations-0.4.6-py3.8.egg
1468 | albumentations 0.4.6 is already the active version in easy-install.pth
1469 |
1470 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/albumentations-0.4.6-py3.8.egg
1471 | Searching for mmcv==1.2.5
1472 | Best match: mmcv 1.2.5
1473 | Processing mmcv-1.2.5-py3.8.egg
1474 | mmcv 1.2.5 is already the active version in easy-install.pth
1475 |
1476 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/mmcv-1.2.5-py3.8.egg
1477 | Searching for timm==0.6.7
1478 | Best match: timm 0.6.7
1479 | Processing timm-0.6.7-py3.8.egg
1480 | timm 0.6.7 is already the active version in easy-install.pth
1481 |
1482 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/timm-0.6.7-py3.8.egg
1483 | Searching for yapf==0.32.0
1484 | Best match: yapf 0.32.0
1485 | Processing yapf-0.32.0-py3.8.egg
1486 | yapf 0.32.0 is already the active version in easy-install.pth
1487 | Installing yapf script to /data3/***/anaconda3/envs/clrnet/bin
1488 | Installing yapf-diff script to /data3/***/anaconda3/envs/clrnet/bin
1489 |
1490 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/yapf-0.32.0-py3.8.egg
1491 | Searching for ujson==1.35
1492 | Best match: ujson 1.35
1493 | Processing ujson-1.35-py3.8-linux-x86_64.egg
1494 | ujson 1.35 is already the active version in easy-install.pth
1495 |
1496 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/ujson-1.35-py3.8-linux-x86_64.egg
1497 | Searching for Shapely==1.7.0
1498 | Best match: Shapely 1.7.0
1499 | Processing Shapely-1.7.0-py3.8-linux-x86_64.egg
1500 | Shapely 1.7.0 is already the active version in easy-install.pth
1501 |
1502 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/Shapely-1.7.0-py3.8-linux-x86_64.egg
1503 | Searching for imgaug==0.4.0
1504 | Best match: imgaug 0.4.0
1505 | Processing imgaug-0.4.0-py3.8.egg
1506 | imgaug 0.4.0 is already the active version in easy-install.pth
1507 |
1508 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/imgaug-0.4.0-py3.8.egg
1509 | Searching for p-tqdm==1.3.3
1510 | Best match: p-tqdm 1.3.3
1511 | Processing p_tqdm-1.3.3-py3.8.egg
1512 | p-tqdm 1.3.3 is already the active version in easy-install.pth
1513 |
1514 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/p_tqdm-1.3.3-py3.8.egg
1515 | Searching for tqdm==4.64.0
1516 | Best match: tqdm 4.64.0
1517 | Processing tqdm-4.64.0-py3.8.egg
1518 | tqdm 4.64.0 is already the active version in easy-install.pth
1519 | Installing tqdm script to /data3/***/anaconda3/envs/clrnet/bin
1520 |
1521 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/tqdm-4.64.0-py3.8.egg
1522 | Searching for scikit-image==0.19.3
1523 | Best match: scikit-image 0.19.3
1524 | Processing scikit_image-0.19.3-py3.8-linux-x86_64.egg
1525 | scikit-image 0.19.3 is already the active version in easy-install.pth
1526 | Installing skivi script to /data3/***/anaconda3/envs/clrnet/bin
1527 |
1528 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/scikit_image-0.19.3-py3.8-linux-x86_64.egg
1529 | Searching for pytorch-warmup==0.1.0
1530 | Best match: pytorch-warmup 0.1.0
1531 | Processing pytorch_warmup-0.1.0-py3.8.egg
1532 | pytorch-warmup 0.1.0 is already the active version in easy-install.pth
1533 |
1534 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pytorch_warmup-0.1.0-py3.8.egg
1535 | Searching for opencv-python==4.6.0.66
1536 | Best match: opencv-python 4.6.0.66
1537 | Processing opencv_python-4.6.0.66-py3.8-linux-x86_64.egg
1538 | opencv-python 4.6.0.66 is already the active version in easy-install.pth
1539 |
1540 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/opencv_python-4.6.0.66-py3.8-linux-x86_64.egg
1541 | Searching for sklearn==0.0
1542 | Best match: sklearn 0.0
1543 | Processing sklearn-0.0-py3.8.egg
1544 | sklearn 0.0 is already the active version in easy-install.pth
1545 |
1546 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/sklearn-0.0-py3.8.egg
1547 | Searching for addict==2.4.0
1548 | Best match: addict 2.4.0
1549 | Processing addict-2.4.0-py3.8.egg
1550 | addict 2.4.0 is already the active version in easy-install.pth
1551 |
1552 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/addict-2.4.0-py3.8.egg
1553 | Searching for pandas==1.5.0rc0
1554 | Best match: pandas 1.5.0rc0
1555 | Processing pandas-1.5.0rc0-py3.8-linux-x86_64.egg
1556 | pandas 1.5.0rc0 is already the active version in easy-install.pth
1557 |
1558 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pandas-1.5.0rc0-py3.8-linux-x86_64.egg
1559 | Searching for torchvision==0.9.0
1560 | Best match: torchvision 0.9.0
1561 | Adding torchvision 0.9.0 to easy-install.pth file
1562 |
1563 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1564 | Searching for torch==1.8.0
1565 | Best match: torch 1.8.0
1566 | Adding torch 1.8.0 to easy-install.pth file
1567 | Installing convert-caffe2-to-onnx script to /data3/***/anaconda3/envs/clrnet/bin
1568 | Installing convert-onnx-to-caffe2 script to /data3/***/anaconda3/envs/clrnet/bin
1569 |
1570 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1571 | Searching for opencv-python-headless==4.6.0.66
1572 | Best match: opencv-python-headless 4.6.0.66
1573 | Adding opencv-python-headless 4.6.0.66 to easy-install.pth file
1574 |
1575 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1576 | Searching for numpy==1.23.2
1577 | Best match: numpy 1.23.2
1578 | Adding numpy 1.23.2 to easy-install.pth file
1579 | Installing f2py script to /data3/***/anaconda3/envs/clrnet/bin
1580 | Installing f2py3 script to /data3/***/anaconda3/envs/clrnet/bin
1581 | Installing f2py3.8 script to /data3/***/anaconda3/envs/clrnet/bin
1582 |
1583 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1584 | Searching for Pillow==9.2.0
1585 | Best match: Pillow 9.2.0
1586 | Adding Pillow 9.2.0 to easy-install.pth file
1587 |
1588 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1589 | Searching for typing-extensions==4.3.0
1590 | Best match: typing-extensions 4.3.0
1591 | Adding typing-extensions 4.3.0 to easy-install.pth file
1592 |
1593 | Using /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages
1594 | Finished processing dependencies for clrnet==1.0
1595 | (clrnet) [***@ai02 clrnet]$ pip list
1596 | Package Version Editable project location
1597 | ---------------------- --------- --------------------------------
1598 | addict 2.4.0
1599 | albumentations 0.4.6
1600 | certifi 2022.6.15
1601 | clrnet 1.0 /data3/***/project/temp/clrnet
1602 | contourpy 1.0.4
1603 | cycler 0.11.0
1604 | dill 0.3.5.1
1605 | fonttools 4.37.1
1606 | imageio 2.21.1
1607 | imgaug 0.4.0
1608 | joblib 1.1.0
1609 | kiwisolver 1.4.4
1610 | matplotlib 3.6.0rc1
1611 | mmcv 1.2.5
1612 | multiprocess 0.70.13
1613 | networkx 2.8.6
1614 | numpy 1.23.2
1615 | opencv-python 4.6.0.66
1616 | opencv-python-headless 4.6.0.66
1617 | p-tqdm 1.3.3
1618 | packaging 21.3
1619 | pandas 1.5.0rc0
1620 | pathos 0.2.9
1621 | pathspec 0.9.0
1622 | Pillow 9.2.0
1623 | pip 22.1.2
1624 | pox 0.3.1
1625 | ppft 1.7.6.5
1626 | ptflops 0.6.9
1627 | pyparsing 3.0.9
1628 | python-dateutil 2.8.2
1629 | pytorch-warmup 0.1.0
1630 | pytz 2022.2.1
1631 | PyWavelets 1.3.0
1632 | PyYAML 6.0
1633 | scikit-image 0.19.3
1634 | scikit-learn 1.1.2
1635 | scipy 1.9.0
1636 | setuptools 63.4.1
1637 | Shapely 1.7.0
1638 | six 1.16.0
1639 | sklearn 0.0
1640 | threadpoolctl 3.1.0
1641 | tifffile 2022.8.12
1642 | timm 0.6.7
1643 | torch 1.8.0
1644 | torchvision 0.9.0
1645 | tqdm 4.64.0
1646 | typing_extensions 4.3.0
1647 | ujson 1.35
1648 | wheel 0.37.1
1649 | yapf 0.32.0
1650 | (clrnet) [***@ai02 clrnet]$ git clone https://github.com/xuanandsix/CLRNet-onnxruntime-and-tensorrt-demo.git
1651 | 正克隆到 'CLRNet-onnxruntime-and-tensorrt-demo'...
1652 | remote: Enumerating objects: 68, done.
1653 | remote: Counting objects: 100% (68/68), done.
1654 | remote: Compressing objects: 100% (65/65), done.
1655 | remote: Total 68 (delta 36), reused 0 (delta 0), pack-reused 0
1656 | Unpacking objects: 100% (68/68), done.
1657 | (clrnet) [***@ai02 clrnet]$ ls
1658 | build clrnet clrnet.egg-info CLRNet-onnxruntime-and-tensorrt-demo configs LICENSE main.py README.md requirements.txt setup.py tools
1659 | (clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/clr_head.py ./clrnet/models/heads/
1660 | (clrnet) [***@ai02 clrnet]$ mkdir modules
1661 | (clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/grid_sample.py ./modules/
1662 | (clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/torch2onnx.py ./
1663 | (clrnet) [***@ai02 clrnet]$ python torch2onnx.py configs/clrnet/clr_resnet18_tusimple.py --load_from tusimple_r18.pth
1664 | pretrained model: https://download.pytorch.org/models/resnet18-5c106cde.pth
1665 | Traceback (most recent call last):
1666 | File "torch2onnx.py", line 49, in
1667 | main()
1668 | File "torch2onnx.py", line 22, in main
1669 | state_dict = torch.load(args.load_from, map_location='cpu')['net']
1670 | File "/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/serialization.py", line 579, in load
1671 | with _open_file_like(f, 'rb') as opened_file:
1672 | File "/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/serialization.py", line 230, in _open_file_like
1673 | return _open_file(name_or_buffer, mode)
1674 | File "/data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/torch/serialization.py", line 211, in __init__
1675 | super(_open_file, self).__init__(open(name, mode))
1676 | FileNotFoundError: [Errno 2] No such file or directory: 'tusimple_r18.pth'
1677 | (clrnet) [***@ai02 clrnet]$ wget https://github.com/Turoad/CLRNet/releases/download/models/tusimple_r18.pth.zip
1678 | --2022-08-25 13:33:12-- https://github.com/Turoad/CLRNet/releases/download/models/tusimple_r18.pth.zip
1679 | 正在解析主机 github.com (github.com)... 20.205.243.166
1680 | 正在连接 github.com (github.com)|20.205.243.166|:443... 已连接。
1681 | 无法建立 SSL 连接。
1682 | (clrnet) [***@ai02 clrnet]$ cp /data3/***/project/tusimple_r18.pth.zip ./
1683 | (clrnet) [***@ai02 clrnet]$ unzip tusimple_r18.pth.zip
1684 | Archive: tusimple_r18.pth.zip
1685 | inflating: tusimple_r18.pth
1686 | (clrnet) [***@ai02 clrnet]$ python torch2onnx.py configs/clrnet/clr_resnet18_tusimple.py --load_from tusimple_r18.pth
1687 | pretrained model: https://download.pytorch.org/models/resnet18-5c106cde.pth
1688 | /data3/***/project/temp/clrnet/modules/grid_sample.py:35: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
1689 | assert n == gn
1690 | /data3/***/project/temp/clrnet/modules/grid_sample.py:69: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
1691 | x0 = torch.where(x0 < 0, torch.tensor(0, device=x0.device), x0)
1692 | /data3/***/project/temp/clrnet/modules/grid_sample.py:70: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
1693 | x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=x0.device), x0)
1694 | /data3/***/project/temp/clrnet/modules/grid_sample.py:70: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
1695 | x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=x0.device), x0)
1696 | /data3/***/project/temp/clrnet/modules/grid_sample.py:71: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
1697 | x1 = torch.where(x1 < 0, torch.tensor(0, device=x1.device), x1)
1698 | /data3/***/project/temp/clrnet/modules/grid_sample.py:72: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
1699 | x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=x1.device), x1)
1700 | /data3/***/project/temp/clrnet/modules/grid_sample.py:72: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
1701 | x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=x1.device), x1)
1702 | /data3/***/project/temp/clrnet/modules/grid_sample.py:73: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
1703 | y0 = torch.where(y0 < 0, torch.tensor(0, device=y0.device), y0)
1704 | /data3/***/project/temp/clrnet/modules/grid_sample.py:74: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
1705 | y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=y0.device), y0)
1706 | /data3/***/project/temp/clrnet/modules/grid_sample.py:74: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
1707 | y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=y0.device), y0)
1708 | /data3/***/project/temp/clrnet/modules/grid_sample.py:75: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
1709 | y1 = torch.where(y1 < 0, torch.tensor(0, device=y1.device), y1)
1710 | /data3/***/project/temp/clrnet/modules/grid_sample.py:76: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
1711 | y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=y1.device), y1)
1712 | /data3/***/project/temp/clrnet/modules/grid_sample.py:76: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
1713 | y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=y1.device), y1)
1714 | (clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/test.jpg ./
1715 | (clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/demo_onnx ./
1716 | cp: 无法获取"CLRNet-onnxruntime-and-tensorrt-demo/demo_onnx" 的文件状态(stat): 没有那个文件或目录
1717 | (clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/demo_onnx.py ./
1718 | (clrnet) [***@ai02 clrnet]$ cp CLRNet-onnxruntime-and-tensorrt-demo/demo_onnx_new.py ./
1719 | (clrnet) [***@ai02 clrnet]$ pytho demo_onnx.py
1720 | bash: pytho: 未找到命令...
1721 | (clrnet) [***@ai02 clrnet]$ python demo_onnx.py
1722 | Traceback (most recent call last):
1723 | File "demo_onnx.py", line 6, in
1724 | import onnxruntime
1725 | ModuleNotFoundError: No module named 'onnxruntime'
1726 | (clrnet) [***@ai02 clrnet]$ pip install onnxruntime
1727 | Collecting onnxruntime
1728 | Using cached onnxruntime-1.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.5 MB)
1729 | Collecting protobuf
1730 | Downloading protobuf-4.21.5-cp37-abi3-manylinux2014_x86_64.whl (408 kB)
1731 | ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 408.4/408.4 kB 499.8 kB/s eta 0:00:00
1732 | Collecting coloredlogs
1733 | Using cached coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
1734 | Collecting flatbuffers
1735 | Downloading flatbuffers-2.0.7-py2.py3-none-any.whl (26 kB)
1736 | Requirement already satisfied: numpy>=1.21.0 in /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages (from onnxruntime) (1.23.2)
1737 | Collecting sympy
1738 | Downloading sympy-1.11-py3-none-any.whl (6.5 MB)
1739 | ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.5/6.5 MB 6.1 MB/s eta 0:00:00
1740 | Requirement already satisfied: packaging in /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/packaging-21.3-py3.8.egg (from onnxruntime) (21.3)
1741 | Collecting humanfriendly>=9.1
1742 | Using cached humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
1743 | Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /data3/***/anaconda3/envs/clrnet/lib/python3.8/site-packages/pyparsing-3.0.9-py3.8.egg (from packaging->onnxruntime) (3.0.9)
1744 | Collecting mpmath>=0.19
1745 | Using cached mpmath-1.2.1-py3-none-any.whl (532 kB)
1746 | Installing collected packages: mpmath, flatbuffers, sympy, protobuf, humanfriendly, coloredlogs, onnxruntime
1747 | Successfully installed coloredlogs-15.0.1 flatbuffers-2.0.7 humanfriendly-10.0 mpmath-1.2.1 onnxruntime-1.12.1 protobuf-4.21.5 sympy-1.11
1748 | (clrnet) [***@ai02 clrnet]$ ls
1749 | build clrnet.egg-info configs demo_onnx.py main.py README.md setup.py tools tusimple_r18.onnx tusimple_r18.pth.zip
1750 | clrnet CLRNet-onnxruntime-and-tensorrt-demo demo_onnx_new.py LICENSE modules requirements.txt test.jpg torch2onnx.py tusimple_r18.pth
1751 | (clrnet) [***@ai02 clrnet]$ python demo_onnx.py
1752 | demo_onnx.py:141: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
1753 | Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
1754 | )[::-1].cumprod()[::-1]).astype(np.bool))
1755 | (clrnet) [***@ai02 clrnet]$ ls
1756 | build clrnet.egg-info configs demo_onnx.py main.py output_onnx.png requirements.txt test.jpg torch2onnx.py tusimple_r18.pth
1757 | clrnet CLRNet-onnxruntime-and-tensorrt-demo demo_onnx_new.py LICENSE modules README.md setup.py tools tusimple_r18.onnx tusimple_r18.pth.zip
1758 | (clrnet) [***@ai02 clrnet]$ python demo_onnx_new.py
1759 | demo_onnx_new.py:186: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.
1760 | Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
1761 | )[::-1].cumprod()[::-1]).astype(np.bool))
1762 | Done!
1763 | (clrnet) [***@ai02 clrnet]$
1764 |
--------------------------------------------------------------------------------
/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/CLRNet-onnxruntime-and-tensorrt-demo/07a47e666d4bfca9a4c60a777c62c23828fca663/test.jpg
--------------------------------------------------------------------------------
/torch2onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import torch.nn.parallel
5 | import torch.backends.cudnn as cudnn
6 | import argparse
7 | import numpy as np
8 | import random
9 | from clrnet.utils.config import Config
10 | from clrnet.datasets import build_dataloader
11 | from clrnet.models.registry import build_net
12 |
def main():
    """Build CLRNet from a config, load a checkpoint, and export it to ONNX.

    Reads the command line via ``parse_args()``, builds the network from the
    config file, loads the checkpoint given by ``--load_from`` on the CPU and
    traces the model with a single random dummy image of shape
    (1, 3, 320, 800). The result is written to ``tusimple_r18.onnx`` in the
    current working directory (opset 11, graph input named ``input``).

    Raises:
        SystemExit: if no ``--load_from`` checkpoint path was given.
    """
    args = parse_args()
    if args.load_from is None:
        # torch.load(None) would fail with an unrelated error; report the cause.
        raise SystemExit('torch2onnx.py: --load_from <checkpoint.pth> is required')

    cfg = Config.fromfile(args.config)
    cfg.load_from = args.load_from
    net = build_net(cfg)
    net = net.cpu()

    from collections import OrderedDict
    checkpoint = torch.load(args.load_from, map_location='cpu')
    # Training checkpoints keep the weights under 'net'; also accept a file
    # that already is a bare state_dict.
    state_dict = checkpoint['net'] if 'net' in checkpoint else checkpoint

    # Weights saved from a DataParallel-wrapped model carry a 'module.' prefix
    # on every key. Strip it only where present, so checkpoints saved from an
    # unwrapped model are not corrupted by blindly dropping 7 characters.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    net.load_state_dict(new_state_dict)

    # Export the inference graph. torch.onnx.export defaults to eval mode as
    # well; setting it explicitly makes the intent obvious.
    net.eval()

    # NOTE(review): 320x800 matches the TuSimple config this script was
    # written for; other configs may use a different input size - confirm.
    dummy_input = torch.randn(1, 3, 320, 800, device='cpu')
    torch.onnx.export(net, dummy_input, 'tusimple_r18.onnx',
                      export_params=True, opset_version=11,
                      do_constant_folding=True,
                      input_names=['input'])
34 |
def parse_args(argv=None):
    """Parse the command-line arguments of the ONNX export script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``config`` (path of the CLRNet config file)
        and ``load_from`` (checkpoint path, or ``None`` when not given).
    """
    parser = argparse.ArgumentParser(
        description='Export a CLRNet model to ONNX')
    parser.add_argument('config', help='model config file path')

    parser.add_argument('--load_from',
                        default=None,
                        help='the checkpoint file to load from')

    return parser.parse_args(argv)
46 |
47 |
if __name__ == '__main__':
    # Run as a script, e.g.:
    #   python torch2onnx.py configs/clrnet/clr_resnet18_tusimple.py --load_from tusimple_r18.pth
    main()
52 |
53 |
--------------------------------------------------------------------------------