├── 001.avi ├── README.md ├── __pycache__ └── func4video.cpython-37.pyc ├── deep_sort ├── README.md ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── deep_sort.cpython-36.pyc │ ├── deep_sort.cpython-37.pyc │ └── deep_sort.cpython-38.pyc ├── deep │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── feature_extractor.cpython-36.pyc │ │ ├── feature_extractor.cpython-37.pyc │ │ ├── feature_extractor.cpython-38.pyc │ │ ├── model.cpython-36.pyc │ │ ├── model.cpython-37.pyc │ │ ├── model.cpython-38.pyc │ │ └── utilsss.cpython-37.pyc │ ├── checkpoint │ │ ├── .gitkeep │ │ └── ckpt.t7 │ ├── evaluate.py │ ├── feature_extractor.py │ ├── model.py │ ├── original_model.py │ ├── test.py │ ├── train.jpg │ ├── train.py │ └── utilsss.py ├── deep_sort.py └── sort │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── detection.cpython-36.pyc │ ├── detection.cpython-37.pyc │ ├── detection.cpython-38.pyc │ ├── iou_matching.cpython-36.pyc │ ├── iou_matching.cpython-37.pyc │ ├── iou_matching.cpython-38.pyc │ ├── kalman_filter.cpython-36.pyc │ ├── kalman_filter.cpython-37.pyc │ ├── kalman_filter.cpython-38.pyc │ ├── linear_assignment.cpython-36.pyc │ ├── linear_assignment.cpython-37.pyc │ ├── linear_assignment.cpython-38.pyc │ ├── nn_matching.cpython-36.pyc │ ├── nn_matching.cpython-37.pyc │ ├── nn_matching.cpython-38.pyc │ ├── preprocessing.cpython-36.pyc │ ├── preprocessing.cpython-37.pyc │ ├── preprocessing.cpython-38.pyc │ ├── track.cpython-36.pyc │ ├── track.cpython-37.pyc │ ├── track.cpython-38.pyc │ ├── tracker.cpython-36.pyc │ ├── tracker.cpython-37.pyc │ └── tracker.cpython-38.pyc │ ├── detection.py │ ├── iou_matching.py │ ├── kalman_filter.py │ ├── linear_assignment.py │ ├── nn_matching.py │ ├── preprocessing.py │ ├── track.py │ └── 
tracker.py ├── detector ├── __init__.py ├── __pycache__ │ └── __init__.cpython-38.pyc └── yolov7 │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── get_map.py │ ├── img │ └── street.jpg │ ├── kmeans_for_anchors.py │ ├── nets │ ├── __init__.py │ ├── backbone.py │ ├── yolo.py │ └── yolo_training.py │ ├── predict.py │ ├── requirements.txt │ ├── summary.py │ ├── train.py │ ├── utils │ ├── __init__.py │ ├── callbacks.py │ ├── dataloader.py │ ├── utils.py │ ├── utils_bbox.py │ ├── utils_fit.py │ └── utils_map.py │ ├── utils_coco │ ├── coco_annotation.py │ └── get_map_coco.py │ ├── voc_annotation.py │ ├── yolo.py │ └── 常见问题汇总.md ├── img ├── 1 └── result.png ├── output ├── results 00_00_00-00_00_30~1.gif └── results.txt ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── draw.cpython-36.pyc │ ├── draw.cpython-37.pyc │ ├── draw.cpython-38.pyc │ ├── io.cpython-36.pyc │ ├── io.cpython-37.pyc │ ├── io.cpython-38.pyc │ ├── log.cpython-36.pyc │ ├── log.cpython-37.pyc │ ├── log.cpython-38.pyc │ ├── parser.cpython-36.pyc │ ├── parser.cpython-37.pyc │ └── parser.cpython-38.pyc ├── asserts.py ├── draw.py ├── evaluation.py ├── io.py ├── json_logger.py ├── log.py ├── parser.py └── tools.py └── yolov7_deepsort.py /001.avi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/001.avi -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Sort with Yolov7 in Pytorch 2 | 3 | [Video](https://www.bilibili.com/video/BV1be4y1R73B/?vd_source=5f91ae5ee1720ed85b51761270b90025) 4 | 5 | - Version 1.1 6 | 7 | Add tracks in img. 
8 | 9 | - Version 1.0 10 | The source code of deep sort is from [Deep Sort](https://github.com/ZQPei/deep_sort_pytorch) and the source code of YoloV7-pytorch is from [Yolov7-pytorch](https://github.com/bubbliiiing/yolov7-pytorch). You can download the corresponding pre-trained weights from the original project pages. 11 | When using this code, you may need to adjust the file paths of the pre-trained weights and of the coco_classes file. 12 | 13 | ## Environment 14 | 15 | torch>=1.2 16 | 17 | ## Reference 18 | 19 | [https://github.com/bubbliiiing/yolov7-pytorch](https://github.com/bubbliiiing/yolov7-pytorch) 20 | 21 | [https://github.com/ZQPei/deep_sort_pytorch](https://github.com/ZQPei/deep_sort_pytorch) 22 | 23 | 24 | -------------------------------------------------------------------------------- /__pycache__/func4video.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/__pycache__/func4video.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/README.md: -------------------------------------------------------------------------------- 1 | # Deep Sort 2 | 3 | This is an implementation of Deep SORT in PyTorch.
-------------------------------------------------------------------------------- /deep_sort/__init__.py: -------------------------------------------------------------------------------- 1 | from .deep_sort import DeepSort 2 | 3 | 4 | __all__ = ['DeepSort', 'build_tracker'] 5 | 6 | # Build a DeepSort tracker with fixed hyper-parameters; `use_cuda` is passed through to DeepSort. 7 | def build_tracker(use_cuda): 8 | return DeepSort('./deep_sort/deep/checkpoint/ckpt.t7',# namesfile=cfg.DEEPSORT.CLASS_NAMES, 9 | max_dist=0.2, min_confidence=0.1, 10 | nms_max_overlap=0.5, max_iou_distance=0.7, 11 | max_age=70, n_init=3, nn_budget=100, use_cuda=use_cuda) 12 | 13 | 14 | # def build_tracker(cfg, use_cuda): 15 | # return DeepSort(cfg.DEEPSORT.REID_CKPT,# namesfile=cfg.DEEPSORT.CLASS_NAMES, 16 | # max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE, 17 | # nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE, 18 | # max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET, use_cuda=use_cuda) 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /deep_sort/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/__pycache__/__init__.cpython-38.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/__pycache__/deep_sort.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/__pycache__/deep_sort.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/__pycache__/deep_sort.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/__pycache__/deep_sort.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/__pycache__/deep_sort.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/__pycache__/deep_sort.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/deep/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__init__.py -------------------------------------------------------------------------------- /deep_sort/deep/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__pycache__/__init__.cpython-36.pyc 
-------------------------------------------------------------------------------- /deep_sort/deep/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/deep/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/deep/__pycache__/feature_extractor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__pycache__/feature_extractor.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/deep/__pycache__/feature_extractor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__pycache__/feature_extractor.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/deep/__pycache__/feature_extractor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__pycache__/feature_extractor.cpython-38.pyc 
-------------------------------------------------------------------------------- /deep_sort/deep/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/deep/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/deep/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/deep/__pycache__/utilsss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/__pycache__/utilsss.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/deep/checkpoint/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/checkpoint/.gitkeep -------------------------------------------------------------------------------- /deep_sort/deep/checkpoint/ckpt.t7: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/checkpoint/ckpt.t7 -------------------------------------------------------------------------------- /deep_sort/deep/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | features = torch.load("features.pth") 4 | qf = features["qf"] 5 | ql = features["ql"] 6 | gf = features["gf"] 7 | gl = features["gl"] 8 | 9 | scores = qf.mm(gf.t()) 10 | res = scores.topk(5, dim=1)[1][:,0] 11 | top1correct = gl[res].eq(ql).sum().item() 12 | 13 | print("Acc top1:{:.3f}".format(top1correct/ql.size(0))) 14 | 15 | 16 | -------------------------------------------------------------------------------- /deep_sort/deep/feature_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | import cv2 5 | import logging 6 | 7 | from .model import Net 8 | 9 | class Extractor(object): 10 | def __init__(self, model_path, use_cuda=True): 11 | self.net = Net(reid=True) 12 | self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" 13 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict'] 14 | self.net.load_state_dict(state_dict) 15 | logger = logging.getLogger("root.tracker") 16 | logger.info("Loading weights from {}... Done!".format(model_path)) 17 | self.net.to(self.device) 18 | self.size = (64, 128) 19 | self.norm = transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 22 | ]) 23 | # Inference only: make BatchNorm use its running statistics instead of per-batch statistics of the crops. 24 | self.net.eval() 25 | 26 | def _preprocess(self, im_crops): 27 | """ 28 | Convert a list of image crops into a normalized float batch tensor: 29 | 1. to float with scale from 0 to 1 30 | 2. resize to (64, 128) (width, height) as the Market1501 dataset did 31 | 3. convert each crop to a torch Tensor (C x H x W) 32 | 4. normalize with the ImageNet mean/std
33 | 5. concatenate along a new leading batch dimension 34 | """ 35 | def _resize(im, size): 36 | return cv2.resize(im.astype(np.float32)/255., size) 37 | 38 | im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float() 39 | return im_batch 40 | 41 | 42 | def __call__(self, im_crops): 43 | im_batch = self._preprocess(im_crops) 44 | with torch.no_grad(): 45 | im_batch = im_batch.to(self.device) 46 | features = self.net(im_batch) 47 | return features.cpu().numpy() 48 | 49 | 50 | # if __name__ == '__main__': 51 | # img = cv2.imread("demo.jpg")[:,:,(2,1,0)] 52 | # extr = Extractor("checkpoint/ckpt.t7") 53 | # feature = extr(img) 54 | # print(feature.shape) 55 | 56 | -------------------------------------------------------------------------------- /deep_sort/deep/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicBlock(nn.Module): 6 | def __init__(self, c_in, c_out,is_downsample=False): 7 | super(BasicBlock,self).__init__() 8 | self.is_downsample = is_downsample 9 | if is_downsample: 10 | self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False) 11 | else: 12 | self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(c_out) 14 | self.relu = nn.ReLU(True) 15 | self.conv2 = nn.Conv2d(c_out,c_out,3,stride=1,padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(c_out) 17 | if is_downsample: 18 | self.downsample = nn.Sequential( 19 | nn.Conv2d(c_in, c_out, 1, stride=2, bias=False), 20 | nn.BatchNorm2d(c_out) 21 | ) 22 | elif c_in != c_out: 23 | self.downsample = nn.Sequential( 24 | nn.Conv2d(c_in, c_out, 1, stride=1, bias=False), 25 | nn.BatchNorm2d(c_out) 26 | ) 27 | self.is_downsample = True 28 | 29 | def forward(self,x): 30 | y = self.conv1(x) 31 | y = self.bn1(y) 32 | y = self.relu(y) 33 | y = self.conv2(y) 34 | y = self.bn2(y) 35 | if self.is_downsample:
36 | x = self.downsample(x) 37 | return F.relu(x.add(y),True) 38 | 39 | def make_layers(c_in,c_out,repeat_times, is_downsample=False): 40 | blocks = [] 41 | for i in range(repeat_times): 42 | if i ==0: 43 | blocks += [BasicBlock(c_in,c_out, is_downsample=is_downsample),] 44 | else: 45 | blocks += [BasicBlock(c_out,c_out),] 46 | return nn.Sequential(*blocks) 47 | 48 | class Net(nn.Module): 49 | def __init__(self, num_classes=751 ,reid=False): 50 | super(Net,self).__init__() 51 | # 3 128 64 52 | self.conv = nn.Sequential( 53 | nn.Conv2d(3,64,3,stride=1,padding=1), 54 | nn.BatchNorm2d(64), 55 | nn.ReLU(inplace=True), 56 | # nn.Conv2d(32,32,3,stride=1,padding=1), 57 | # nn.BatchNorm2d(32), 58 | # nn.ReLU(inplace=True), 59 | nn.MaxPool2d(3,2,padding=1), 60 | ) 61 | # 64 64 32 62 | self.layer1 = make_layers(64,64,2,False) 63 | # 64 64 32 64 | self.layer2 = make_layers(64,128,2,True) 65 | # 128 32 16 66 | self.layer3 = make_layers(128,256,2,True) 67 | # 256 16 8 68 | self.layer4 = make_layers(256,512,2,True) 69 | # 512 8 4 70 | self.avgpool = nn.AvgPool2d((8,4),1) 71 | # 512 1 1 72 | self.reid = reid 73 | self.classifier = nn.Sequential( 74 | nn.Linear(512, 256), 75 | nn.BatchNorm1d(256), 76 | nn.ReLU(inplace=True), 77 | nn.Dropout(), 78 | nn.Linear(256, num_classes), 79 | ) 80 | 81 | def forward(self, x): 82 | x = self.conv(x) 83 | x = self.layer1(x) 84 | x = self.layer2(x) 85 | x = self.layer3(x) 86 | x = self.layer4(x) 87 | x = self.avgpool(x) 88 | x = x.view(x.size(0),-1) 89 | # B x 512 90 | if self.reid: 91 | x = x.div(x.norm(p=2,dim=1,keepdim=True)) 92 | return x 93 | # classifier 94 | x = self.classifier(x) 95 | return x 96 | 97 | 98 | if __name__ == '__main__': 99 | net = Net() 100 | x = torch.randn(4,3,128,64) 101 | y = net(x) 102 | print(y.shape) 103 | 104 | 105 | -------------------------------------------------------------------------------- /deep_sort/deep/original_model.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicBlock(nn.Module): 6 | def __init__(self, c_in, c_out,is_downsample=False): 7 | super(BasicBlock,self).__init__() 8 | self.is_downsample = is_downsample 9 | if is_downsample: 10 | self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False) 11 | else: 12 | self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(c_out) 14 | self.relu = nn.ReLU(True) 15 | self.conv2 = nn.Conv2d(c_out,c_out,3,stride=1,padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(c_out) 17 | if is_downsample: 18 | self.downsample = nn.Sequential( 19 | nn.Conv2d(c_in, c_out, 1, stride=2, bias=False), 20 | nn.BatchNorm2d(c_out) 21 | ) 22 | elif c_in != c_out: 23 | self.downsample = nn.Sequential( 24 | nn.Conv2d(c_in, c_out, 1, stride=1, bias=False), 25 | nn.BatchNorm2d(c_out) 26 | ) 27 | self.is_downsample = True 28 | 29 | def forward(self,x): 30 | y = self.conv1(x) 31 | y = self.bn1(y) 32 | y = self.relu(y) 33 | y = self.conv2(y) 34 | y = self.bn2(y) 35 | if self.is_downsample: 36 | x = self.downsample(x) 37 | return F.relu(x.add(y),True) 38 | 39 | def make_layers(c_in,c_out,repeat_times, is_downsample=False): 40 | blocks = [] 41 | for i in range(repeat_times): 42 | if i ==0: 43 | blocks += [BasicBlock(c_in,c_out, is_downsample=is_downsample),] 44 | else: 45 | blocks += [BasicBlock(c_out,c_out),] 46 | return nn.Sequential(*blocks) 47 | 48 | class Net(nn.Module): 49 | def __init__(self, num_classes=625 ,reid=False): 50 | super(Net,self).__init__() 51 | # 3 128 64 52 | self.conv = nn.Sequential( 53 | nn.Conv2d(3,32,3,stride=1,padding=1), 54 | nn.BatchNorm2d(32), 55 | nn.ELU(inplace=True), 56 | nn.Conv2d(32,32,3,stride=1,padding=1), 57 | nn.BatchNorm2d(32), 58 | nn.ELU(inplace=True), 59 | nn.MaxPool2d(3,2,padding=1), 60 | ) 61 | # 32 64 32 62 | self.layer1 = 
make_layers(32,32,2,False) 63 | # 32 64 32 64 | self.layer2 = make_layers(32,64,2,True) 65 | # 64 32 16 66 | self.layer3 = make_layers(64,128,2,True) 67 | # 128 16 8 68 | self.dense = nn.Sequential( 69 | nn.Dropout(p=0.6), 70 | nn.Linear(128*16*8, 128), 71 | nn.BatchNorm1d(128), 72 | nn.ELU(inplace=True) 73 | ) 74 | # 256 1 1 75 | self.reid = reid 76 | self.batch_norm = nn.BatchNorm1d(128) 77 | self.classifier = nn.Sequential( 78 | nn.Linear(128, num_classes), 79 | ) 80 | 81 | def forward(self, x): 82 | x = self.conv(x) 83 | x = self.layer1(x) 84 | x = self.layer2(x) 85 | x = self.layer3(x) 86 | 87 | x = x.view(x.size(0),-1) 88 | if self.reid: 89 | x = self.dense[0](x) 90 | x = self.dense[1](x) 91 | x = x.div(x.norm(p=2,dim=1,keepdim=True)) 92 | return x 93 | x = self.dense(x) 94 | # B x 128 95 | # classifier 96 | x = self.classifier(x) 97 | return x 98 | 99 | 100 | if __name__ == '__main__': 101 | net = Net(reid=True) 102 | x = torch.randn(4,3,128,64) 103 | y = net(x) 104 | import ipdb; ipdb.set_trace() 105 | 106 | 107 | -------------------------------------------------------------------------------- /deep_sort/deep/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | import torchvision 4 | 5 | import argparse 6 | import os 7 | 8 | from model import Net 9 | 10 | parser = argparse.ArgumentParser(description="Train on market1501") 11 | parser.add_argument("--data-dir",default='data',type=str) 12 | parser.add_argument("--no-cuda",action="store_true") 13 | parser.add_argument("--gpu-id",default=0,type=int) 14 | args = parser.parse_args() 15 | 16 | # device 17 | device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu" 18 | if torch.cuda.is_available() and not args.no_cuda: 19 | cudnn.benchmark = True 20 | 21 | # data loader 22 | root = args.data_dir 23 | query_dir = os.path.join(root,"query") 24 | gallery_dir = 
os.path.join(root,"gallery") 25 | transform = torchvision.transforms.Compose([ 26 | torchvision.transforms.Resize((128,64)), 27 | torchvision.transforms.ToTensor(), 28 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 29 | ]) 30 | queryloader = torch.utils.data.DataLoader( 31 | torchvision.datasets.ImageFolder(query_dir, transform=transform), 32 | batch_size=64, shuffle=False 33 | ) 34 | galleryloader = torch.utils.data.DataLoader( 35 | torchvision.datasets.ImageFolder(gallery_dir, transform=transform), 36 | batch_size=64, shuffle=False 37 | ) 38 | 39 | # net definition 40 | net = Net(reid=True) 41 | assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!" 42 | print('Loading from checkpoint/ckpt.t7') 43 | checkpoint = torch.load("./checkpoint/ckpt.t7") 44 | net_dict = checkpoint['net_dict'] 45 | net.load_state_dict(net_dict, strict=False) 46 | net.eval() 47 | net.to(device) 48 | 49 | # compute features 50 | query_features = torch.tensor([]).float() 51 | query_labels = torch.tensor([]).long() 52 | gallery_features = torch.tensor([]).float() 53 | gallery_labels = torch.tensor([]).long() 54 | 55 | with torch.no_grad(): 56 | for idx,(inputs,labels) in enumerate(queryloader): 57 | inputs = inputs.to(device) 58 | features = net(inputs).cpu() 59 | query_features = torch.cat((query_features, features), dim=0) 60 | query_labels = torch.cat((query_labels, labels)) 61 | 62 | for idx,(inputs,labels) in enumerate(galleryloader): 63 | inputs = inputs.to(device) 64 | features = net(inputs).cpu() 65 | gallery_features = torch.cat((gallery_features, features), dim=0) 66 | gallery_labels = torch.cat((gallery_labels, labels)) 67 | 68 | gallery_labels -= 2 69 | 70 | # save features 71 | features = { 72 | "qf": query_features, 73 | "ql": query_labels, 74 | "gf": gallery_features, 75 | "gl": gallery_labels 76 | } 77 | torch.save(features,"features.pth") 
-------------------------------------------------------------------------------- /deep_sort/deep/train.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/deep/train.jpg -------------------------------------------------------------------------------- /deep_sort/deep/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | from utilsss import MyDataset, validate, show_confMat 12 | from torch.utils.data import DataLoader 13 | 14 | from model import Net 15 | 16 | parser = argparse.ArgumentParser(description="Train on market1501") 17 | parser.add_argument("--data-dir",default='data',type=str) 18 | parser.add_argument("--no-cuda",action="store_true") 19 | parser.add_argument("--gpu-id",default=0,type=int) 20 | parser.add_argument("--lr",default=0.1, type=float) 21 | parser.add_argument("--interval",'-i',default=20,type=int) 22 | parser.add_argument('--resume', '-r',action='store_true') 23 | args = parser.parse_args() 24 | 25 | 26 | train_bs = 128 #batch_size 27 | valid_bs = 128 28 | # device 29 | device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu" 30 | if torch.cuda.is_available() and not args.no_cuda: 31 | cudnn.benchmark = True 32 | 33 | # data loading 34 | print('==> Preparing data..') 35 | 36 | transform = [transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomRotation(45)] 37 | transform_train = transforms.Compose([ 38 | transforms.Resize((128,64)), 39 | transforms.RandomChoice(transform), 40 | transforms.ToTensor(), 41 | 
transforms.Normalize((0.37169233, 0.38456926, 0.3438824), (0.20752552, 0.18884005, 0.18621244)) 42 | 43 | ]) 44 | # evaluation transform: same resize/normalization as training, without augmentation 45 | transform_test = transforms.Compose([ 46 | transforms.Resize((128,64)), 47 | transforms.ToTensor(), 48 | transforms.Normalize((0.37169233, 0.38456926, 0.3438824), (0.20752552, 0.18884005, 0.18621244)) 49 | ]) 50 | 51 | 52 | 53 | ####------------------------------import data list-------------------------------------------#### 54 | 55 | train_txt_path = './data/train.txt' 56 | valid_txt_path = './data/test.txt' 57 | 58 | print(train_txt_path, valid_txt_path) 59 | 60 | # Build the MyDataset instances (training split augmented, validation split not) 61 | valid_data = MyDataset(txt_path=valid_txt_path, transform=transform_test) 62 | train_data = MyDataset(txt_path=train_txt_path, transform=transform_train) 63 | 64 | # Build the DataLoaders 65 | trainloader = DataLoader(dataset=train_data, batch_size=train_bs, shuffle=True) 66 | testloader = DataLoader(dataset=valid_data, batch_size=valid_bs) 67 | 68 | # NOTE(review): num_classes is derived from this 10-name tuple; confirm it matches the number of labels in train.txt 69 | classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 70 | 71 | ''' 72 | root = args.data_dir 73 | train_dir = os.path.join(root,"train") 74 | test_dir = os.path.join(root,"test") 75 | transform_train = torchvision.transforms.Compose([ 76 | torchvision.transforms.RandomCrop((128,64),padding=4), 77 | torchvision.transforms.RandomHorizontalFlip(), 78 | torchvision.transforms.ToTensor(), 79 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 80 | ]) 81 | transform_test = torchvision.transforms.Compose([ 82 | torchvision.transforms.Resize((128,64)), 83 | torchvision.transforms.ToTensor(), 84 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 85 | ]) 86 | trainloader = torch.utils.data.DataLoader( 87 | torchvision.datasets.ImageFolder(train_dir, transform=transform_train), 88 | batch_size=64,shuffle=True 89 | ) 90 | testloader = torch.utils.data.DataLoader( 91 |
torchvision.datasets.ImageFolder(test_dir, transform=transform_test), 92 | batch_size=64,shuffle=True 93 | ) 94 | ''' 95 | num_classes = len(classes) 96 | best_acc = 0. # best validation accuracy so far; restored from the checkpoint when resuming 97 | # net definition 98 | start_epoch = 0 99 | net = Net(num_classes=num_classes) 100 | if args.resume: 101 | assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!" 102 | print('Loading from checkpoint/ckpt.t7') 103 | checkpoint = torch.load("./checkpoint/ckpt.t7") 104 | # import ipdb; ipdb.set_trace() 105 | net_dict = checkpoint['net_dict'] 106 | net.load_state_dict(net_dict) 107 | best_acc = checkpoint['acc'] 108 | start_epoch = checkpoint['epoch'] 109 | net.to(device) 110 | 111 | # loss and optimizer 112 | criterion = torch.nn.CrossEntropyLoss() 113 | optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4) 114 | 115 | 116 | # train function for each epoch 117 | def train(epoch): 118 | print("\nEpoch : %d"%(epoch+1)) 119 | net.train() 120 | training_loss = 0. 121 | train_loss = 0. 122 | correct = 0 123 | total = 0 124 | interval = args.interval 125 | start = time.time() 126 | for idx, (inputs, labels) in enumerate(trainloader): 127 | # forward 128 | inputs,labels = inputs.to(device),labels.to(device) 129 | outputs = net(inputs) 130 | loss = criterion(outputs, labels) 131 | 132 | # backward 133 | optimizer.zero_grad() 134 | loss.backward() 135 | optimizer.step() 136 | 137 | # accumulating 138 | training_loss += loss.item() 139 | train_loss += loss.item() 140 | correct += outputs.max(dim=1)[1].eq(labels).sum().item() 141 | total += labels.size(0) 142 | 143 | # print 144 | if (idx+1)%interval == 0: 145 | end = time.time() 146 | print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format( 147 | 100.*(idx+1)/len(trainloader), end-start, training_loss/interval, correct, total, 100.*correct/total 148 | )) 149 | training_loss = 0.
150 | start = time.time() 151 | 152 | return train_loss/len(trainloader), 1.- correct/total 153 | 154 | def test(epoch): 155 | global best_acc 156 | net.eval() 157 | test_loss = 0. 158 | correct = 0 159 | total = 0 160 | start = time.time() 161 | with torch.no_grad(): 162 | for idx, (inputs, labels) in enumerate(testloader): 163 | inputs, labels = inputs.to(device), labels.to(device) 164 | outputs = net(inputs) 165 | loss = criterion(outputs, labels) 166 | 167 | test_loss += loss.item() 168 | correct += outputs.max(dim=1)[1].eq(labels).sum().item() 169 | total += labels.size(0) 170 | 171 | print("Testing ...") 172 | end = time.time() 173 | print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format( 174 | 100.*(idx+1)/len(testloader), end-start, test_loss/len(testloader), correct, total, 100.*correct/total 175 | )) 176 | 177 | # saving checkpoint 178 | acc = 100.*correct/total 179 | if acc > best_acc: 180 | best_acc = acc 181 | print("Saving parameters to checkpoint/ckpt.t7") 182 | checkpoint = { 183 | 'net_dict':net.state_dict(), 184 | 'acc':acc, 185 | 'epoch':epoch, 186 | } 187 | if not os.path.isdir('checkpoint'): 188 | os.mkdir('checkpoint') 189 | torch.save(checkpoint, './checkpoint/ckpt.t7') 190 | 191 | return test_loss/len(testloader), 1.- correct/total 192 | 193 | # plot figure 194 | x_epoch = [] 195 | record = {'train_loss':[], 'train_err':[], 'test_loss':[], 'test_err':[]} 196 | fig = plt.figure() 197 | ax0 = fig.add_subplot(121, title="loss") 198 | ax1 = fig.add_subplot(122, title="top1err") 199 | def draw_curve(epoch, train_loss, train_err, test_loss, test_err): 200 | global record 201 | record['train_loss'].append(train_loss) 202 | record['train_err'].append(train_err) 203 | record['test_loss'].append(test_loss) 204 | record['test_err'].append(test_err) 205 | 206 | x_epoch.append(epoch) 207 | ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train') 208 | ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val') 209 | 
ax1.plot(x_epoch, record['train_err'], 'bo-', label='train') 210 | ax1.plot(x_epoch, record['test_err'], 'ro-', label='val') 211 | if epoch == 0: 212 | ax0.legend() 213 | ax1.legend() 214 | fig.savefig("train.jpg") 215 | 216 | # lr decay 217 | def lr_decay(): 218 | global optimizer 219 | for params in optimizer.param_groups: 220 | params['lr'] *= 0.1 221 | lr = params['lr'] 222 | print("Learning rate adjusted to {}".format(lr)) 223 | 224 | def main(): 225 | for epoch in range(start_epoch, start_epoch+40): 226 | train_loss, train_err = train(epoch) 227 | test_loss, test_err = test(epoch) 228 | draw_curve(epoch, train_loss, train_err, test_loss, test_err) 229 | if (epoch+1)%20==0: 230 | lr_decay() 231 | 232 | 233 | if __name__ == '__main__': 234 | main() 235 | -------------------------------------------------------------------------------- /deep_sort/deep/utilsss.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | import torch 6 | from torch.autograd import Variable 7 | import os 8 | import matplotlib.pyplot as plt 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import seaborn as sns 12 | 13 | class Net(nn.Module): 14 | def __init__(self): 15 | super(Net, self).__init__() 16 | self.conv1 = nn.Conv2d(3, 6, 5) 17 | self.pool1 = nn.MaxPool2d(2, 2) 18 | self.conv2 = nn.Conv2d(6, 16, 5) 19 | self.pool2 = nn.MaxPool2d(2, 2) 20 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 21 | self.fc2 = nn.Linear(120, 84) 22 | self.fc3 = nn.Linear(84, 10) 23 | 24 | def forward(self, x): 25 | x = self.pool1(F.relu(self.conv1(x))) 26 | x = self.pool2(F.relu(self.conv2(x))) 27 | x = x.view(-1, 16 * 5 * 5) 28 | x = F.relu(self.fc1(x)) 29 | x = F.relu(self.fc2(x)) 30 | x = self.fc3(x) 31 | return x 32 | 33 | # 定义权值初始化 34 | def initialize_weights(self): 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | 
torch.nn.init.xavier_normal_(m.weight.data) 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | elif isinstance(m, nn.BatchNorm2d): 41 | m.weight.data.fill_(1) 42 | m.bias.data.zero_() 43 | elif isinstance(m, nn.Linear): 44 | torch.nn.init.normal_(m.weight.data, 0, 0.01) 45 | m.bias.data.zero_() 46 | 47 | class MyDataset(Dataset): 48 | def __init__(self, txt_path, transform = None, target_transform = None): 49 | fh = open(txt_path, 'r') 50 | imgs = [] 51 | for line in fh: 52 | line = line.rstrip() 53 | words = line.split() 54 | imgs.append((words[0], int(words[1]))) 55 | 56 | self.imgs = imgs # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据 57 | self.transform = transform 58 | self.target_transform = target_transform 59 | 60 | def __getitem__(self, index): 61 | fn, label = self.imgs[index] 62 | img = Image.open(fn).convert('RGB') # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1 63 | 64 | if self.transform is not None: 65 | img = self.transform(img) # 在这里做transform,转为tensor等等 66 | 67 | return img, label 68 | 69 | def __len__(self): 70 | return len(self.imgs) 71 | 72 | 73 | def validate(net, data_loader, set_name, classes_name): 74 | """ 75 | 对一批数据进行预测,返回混淆矩阵以及Accuracy 76 | :param net: 77 | :param data_loader: 78 | :param set_name: eg: 'valid' 'train' 'tesst 79 | :param classes_name: 80 | :return: 81 | """ 82 | net.eval() 83 | cls_num = len(classes_name) 84 | conf_mat = np.zeros([cls_num, cls_num]) 85 | 86 | for data in data_loader: 87 | images, labels = data 88 | images = Variable(images) 89 | labels = Variable(labels) 90 | 91 | outputs = net(images) 92 | outputs.detach_() 93 | 94 | _, predicted = torch.max(outputs.data, 1) 95 | 96 | # 统计混淆矩阵 97 | for i in range(len(labels)): 98 | cate_i = labels[i]#.numpy() 99 | pre_i = predicted[i]#.numpy() 100 | conf_mat[cate_i, pre_i] += 1.0 101 | 102 | for i in range(cls_num): 103 | print('class:{:<10}, total num:{:<6}, correct num:{:<5} Recall: {:.2%} Precision: {:.2%}'.format( 104 | classes_name[i], 
np.sum(conf_mat[i, :]), conf_mat[i, i], conf_mat[i, i] / (1 + np.sum(conf_mat[i, :])), 105 | conf_mat[i, i] / (1 + np.sum(conf_mat[:, i])))) 106 | 107 | print('{} set Accuracy:{:.2%}'.format(set_name, np.trace(conf_mat) / np.sum(conf_mat))) 108 | 109 | return conf_mat, '{:.2}'.format(np.trace(conf_mat) / np.sum(conf_mat)) 110 | 111 | 112 | def show_confMat(confusion_mat, classes, set_name, out_dir): 113 | 114 | # 归一化 115 | confusion_mat_N = confusion_mat.copy() 116 | for i in range(len(classes)): 117 | confusion_mat_N[i, :] = confusion_mat[i, :] / confusion_mat[i, :].sum() 118 | 119 | # 获取颜色 120 | cmap = plt.cm.get_cmap('coolwarm') #bwr 更多颜色: http://matplotlib.org/examples/color/colormaps_reference.html 121 | plt.figure(figsize = (20,20)) 122 | plt.imshow(confusion_mat_N, cmap=cmap) 123 | plt.colorbar(shrink = 0.8) 124 | 125 | # sns.heatmap(confusion_mat_N, vmin = 0, vmax = 200, center = 0) 126 | # plt.show() 127 | # 128 | # 设置文字 129 | xlocations = np.array(range(len(classes))) 130 | plt.xticks(xlocations, list(classes), rotation=90) 131 | plt.yticks(xlocations, list(classes)) 132 | plt.xlabel('Predict label') 133 | plt.ylabel('True label') 134 | plt.title('MLFCNeXt50')#'Confusion_Matrix_' + set_name)# 135 | 136 | # 打印数字 137 | for i in range(confusion_mat_N.shape[0]): 138 | for j in range(confusion_mat_N.shape[1]): 139 | plt.text(x=j, y=i, s=int(confusion_mat[i, j]), va='center', ha='center', color='white', fontsize=10) 140 | # 保存 141 | plt.savefig(os.path.join(out_dir, 'Confusion_Matrix' + set_name + '.png')) 142 | # plt.close() 143 | 144 | 145 | def normalize_invert(tensor, mean, std): 146 | for t, m, s in zip(tensor, mean, std): 147 | t.mul_(s).add_(m) 148 | return tensor 149 | -------------------------------------------------------------------------------- /deep_sort/deep_sort.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .deep.feature_extractor import Extractor 5 | from 
.sort.nn_matching import NearestNeighborDistanceMetric 6 | from .sort.preprocessing import non_max_suppression 7 | from .sort.detection import Detection 8 | from .sort.tracker import Tracker 9 | 10 | 11 | __all__ = ['DeepSort'] 12 | 13 | 14 | class DeepSort(object): 15 | def __init__(self, model_path, max_dist=0.2, \ 16 | min_confidence=0.3, nms_max_overlap=1.0, max_iou_distance=0.7,\ 17 | max_age=70, n_init=3, nn_budget=100, use_cuda=True): 18 | self.min_confidence = min_confidence 19 | self.nms_max_overlap = nms_max_overlap 20 | 21 | self.extractor = Extractor(model_path, use_cuda=use_cuda) 22 | 23 | max_cosine_distance = max_dist 24 | nn_budget = 100 25 | metric = NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget) 26 | self.tracker = Tracker(metric, max_iou_distance=max_iou_distance, max_age=max_age, n_init=n_init) 27 | #----added by deyiwang 28 | namesfile = './detector/yolov7/model_data/coco_classes.txt' 29 | self.class_names = self.load_class_names(namesfile) 30 | # print(type(self.class_names)) #---list 31 | 32 | def update(self, bbox_xywh, confidences, ori_img): 33 | # print('deep_sort, update, bbox_xywh:\n', bbox_xywh) 34 | self.height, self.width = ori_img.shape[:2] 35 | # generate detections 36 | features = self._get_features(bbox_xywh, ori_img) 37 | # print('deep-sort.py, update, features:', type(features), len(features)) 38 | # print(len(features[0])) 39 | bbox_tlwh = self._xywh_to_tlwh(bbox_xywh) 40 | # print('deep_sort, update, bbox_tlwh:\n', bbox_tlwh) 41 | # print('deep-sort.py, update, bbox_tlwh:', type(bbox_tlwh), len(bbox_tlwh)) 42 | detections = [Detection(bbox_tlwh[i], conf, features[i]) for i,conf in enumerate(confidences) if conf>self.min_confidence] 43 | # print('deep-sort.py, update, detections:\n', type(detections), len(detections),detections) 44 | 45 | # print(dir(detections)) 46 | #-----added by deyiwang 47 | # run on non-maximum supression 48 | boxes = np.array([d.tlwh for d in detections]) 49 | # print('deep_sort, 
update, boxes:\n', boxes) 50 | scores = np.array([d.confidence for d in detections]) 51 | # print('deep_sort, update, scores:\n', scores) 52 | indices = non_max_suppression(boxes, self.nms_max_overlap, scores) 53 | # print('deep_sort, update, indices:\n', indices) 54 | detections = [detections[i] for i in indices] 55 | 56 | # update tracker 57 | self.tracker.predict() 58 | self.tracker.update(detections) 59 | 60 | # output bbox identities 61 | outputs = [] 62 | for track in self.tracker.tracks: 63 | if not track.is_confirmed() or track.time_since_update > 1: 64 | continue 65 | box = track.to_tlwh() 66 | x1,y1,x2,y2 = self._tlwh_to_xyxy(box) 67 | track_id = track.track_id 68 | outputs.append(np.array([x1,y1,x2,y2,track_id], dtype=np.int)) 69 | if len(outputs) > 0: 70 | outputs = np.stack(outputs,axis=0) 71 | return outputs 72 | 73 | def load_class_names(self, namesfile): 74 | with open(namesfile, 'r', encoding='utf8') as fp: 75 | class_names = [line.strip() for line in fp.readlines()] 76 | return class_names 77 | 78 | 79 | """ 80 | TODO: 81 | Convert bbox from xc_yc_w_h to xtl_ytl_w_h 82 | Thanks JieChen91@github.com for reporting this bug! 83 | """ 84 | @staticmethod 85 | def _xywh_to_tlwh(bbox_xywh): 86 | if bbox_xywh is not None: 87 | bbox_tlwh = np.zeros((bbox_xywh.shape[0],bbox_xywh.shape[1])) 88 | if isinstance(bbox_xywh, np.ndarray): 89 | bbox_tlwh = bbox_xywh.copy() 90 | elif isinstance(bbox_xywh, torch.Tensor): 91 | bbox_tlwh = bbox_xywh.clone() 92 | 93 | bbox_tlwh[:,0] = bbox_xywh[:,0] - bbox_xywh[:,2]/2. 94 | bbox_tlwh[:,1] = bbox_xywh[:,1] - bbox_xywh[:,3]/2. 
95 | bbox_tlwh[:,2] = bbox_xywh[:,2] 96 | bbox_tlwh[:,3] = bbox_xywh[:,3] 97 | return bbox_tlwh 98 | 99 | 100 | def _xywh_to_xyxy(self, bbox_xywh): 101 | x,y,w,h = bbox_xywh 102 | x1 = max(int(x-w/2),0) 103 | x2 = min(int(x+w/2),self.width-1) 104 | y1 = max(int(y-h/2),0) 105 | y2 = min(int(y+h/2),self.height-1) 106 | return x1,y1,x2,y2 107 | 108 | def _tlwh_to_xyxy(self, bbox_tlwh): 109 | """ 110 | TODO: 111 | Convert bbox from xtl_ytl_w_h to xc_yc_w_h 112 | Thanks JieChen91@github.com for reporting this bug! 113 | """ 114 | x,y,w,h = bbox_tlwh 115 | x1 = max(int(x),0) 116 | x2 = min(int(x+w),self.width-1) 117 | y1 = max(int(y),0) 118 | y2 = min(int(y+h),self.height-1) 119 | return x1,y1,x2,y2 120 | 121 | def _xyxy_to_tlwh(self, bbox_xyxy): 122 | x1,y1,x2,y2 = bbox_xyxy 123 | 124 | t = x1 125 | l = y1 126 | w = int(x2-x1) 127 | h = int(y2-y1) 128 | return t,l,w,h 129 | 130 | def _get_features(self, bbox_xywh, ori_img): 131 | im_crops = [] 132 | for box in bbox_xywh: 133 | x1,y1,x2,y2 = self._xywh_to_xyxy(box) 134 | im = ori_img[y1:y2,x1:x2] 135 | im_crops.append(im) 136 | if im_crops: 137 | features = self.extractor(im_crops) 138 | else: 139 | features = np.array([]) 140 | return features 141 | 142 | 143 | -------------------------------------------------------------------------------- /deep_sort/sort/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__init__.py -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/__init__.cpython-36.pyc 
-------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/detection.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/detection.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/detection.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/detection.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/detection.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/detection.cpython-38.pyc -------------------------------------------------------------------------------- 
/deep_sort/sort/__pycache__/iou_matching.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/iou_matching.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/iou_matching.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/iou_matching.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/iou_matching.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/iou_matching.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/kalman_filter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/kalman_filter.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/kalman_filter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/kalman_filter.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/kalman_filter.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/kalman_filter.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/linear_assignment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/linear_assignment.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/linear_assignment.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/linear_assignment.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/linear_assignment.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/linear_assignment.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/nn_matching.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/nn_matching.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/nn_matching.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/nn_matching.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/nn_matching.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/nn_matching.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/preprocessing.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/preprocessing.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/preprocessing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/preprocessing.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/preprocessing.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/preprocessing.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/track.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/track.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/track.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/track.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/track.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/track.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/tracker.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/tracker.cpython-36.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/tracker.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/tracker.cpython-37.pyc -------------------------------------------------------------------------------- /deep_sort/sort/__pycache__/tracker.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/deep_sort/sort/__pycache__/tracker.cpython-38.pyc -------------------------------------------------------------------------------- /deep_sort/sort/detection.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | 4 | 5 | class Detection(object): 6 | """ 7 | This class represents a bounding box detection in a single image. 8 | 9 | Parameters 10 | ---------- 11 | tlwh : array_like 12 | Bounding box in format `(x, y, w, h)`. 13 | confidence : float 14 | Detector confidence score. 15 | feature : array_like 16 | A feature vector that describes the object contained in this image. 17 | 18 | Attributes 19 | ---------- 20 | tlwh : ndarray 21 | Bounding box in format `(top left x, top left y, width, height)`. 22 | confidence : ndarray 23 | Detector confidence score. 24 | feature : ndarray | NoneType 25 | A feature vector that describes the object contained in this image. 26 | 27 | """ 28 | 29 | def __init__(self, tlwh, confidence, feature): 30 | self.tlwh = np.asarray(tlwh, dtype=np.float) 31 | # print(self.tlwh) 32 | self.confidence = float(confidence) 33 | # print(self.confidence) 34 | self.feature = np.asarray(feature, dtype=np.float32) 35 | # print(len(self.feature)) 36 | 37 | def to_tlbr(self): 38 | """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., 39 | `(top left, bottom right)`. 40 | """ 41 | ret = self.tlwh.copy() 42 | ret[2:] += ret[:2] 43 | return ret 44 | 45 | def to_xyah(self): 46 | """Convert bounding box to format `(center x, center y, aspect ratio, 47 | height)`, where the aspect ratio is `width / height`. 
48 | """ 49 | ret = self.tlwh.copy() 50 | ret[:2] += ret[2:] / 2 51 | ret[2] /= ret[3] 52 | return ret 53 | -------------------------------------------------------------------------------- /deep_sort/sort/iou_matching.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | from __future__ import absolute_import 3 | import numpy as np 4 | from . import linear_assignment 5 | 6 | 7 | def iou(bbox, candidates): 8 | """Computer intersection over union. 9 | 10 | Parameters 11 | ---------- 12 | bbox : ndarray 13 | A bounding box in format `(top left x, top left y, width, height)`. 14 | candidates : ndarray 15 | A matrix of candidate bounding boxes (one per row) in the same format 16 | as `bbox`. 17 | 18 | Returns 19 | ------- 20 | ndarray 21 | The intersection over union in [0, 1] between the `bbox` and each 22 | candidate. A higher score means a larger fraction of the `bbox` is 23 | occluded by the candidate. 24 | 25 | """ 26 | bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:] 27 | candidates_tl = candidates[:, :2] 28 | candidates_br = candidates[:, :2] + candidates[:, 2:] 29 | 30 | tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis], 31 | np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]] 32 | br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis], 33 | np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]] 34 | wh = np.maximum(0., br - tl) 35 | 36 | area_intersection = wh.prod(axis=1) 37 | area_bbox = bbox[2:].prod() 38 | area_candidates = candidates[:, 2:].prod(axis=1) 39 | return area_intersection / (area_bbox + area_candidates - area_intersection) 40 | 41 | 42 | def iou_cost(tracks, detections, track_indices=None, 43 | detection_indices=None): 44 | """An intersection over union distance metric. 45 | 46 | Parameters 47 | ---------- 48 | tracks : List[deep_sort.track.Track] 49 | A list of tracks. 
50 | detections : List[deep_sort.detection.Detection] 51 | A list of detections. 52 | track_indices : Optional[List[int]] 53 | A list of indices to tracks that should be matched. Defaults to 54 | all `tracks`. 55 | detection_indices : Optional[List[int]] 56 | A list of indices to detections that should be matched. Defaults 57 | to all `detections`. 58 | 59 | Returns 60 | ------- 61 | ndarray 62 | Returns a cost matrix of shape 63 | len(track_indices), len(detection_indices) where entry (i, j) is 64 | `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`. 65 | 66 | """ 67 | if track_indices is None: 68 | track_indices = np.arange(len(tracks)) 69 | if detection_indices is None: 70 | detection_indices = np.arange(len(detections)) 71 | 72 | cost_matrix = np.zeros((len(track_indices), len(detection_indices))) 73 | for row, track_idx in enumerate(track_indices): 74 | if tracks[track_idx].time_since_update > 1: 75 | cost_matrix[row, :] = linear_assignment.INFTY_COST 76 | continue 77 | 78 | bbox = tracks[track_idx].to_tlwh() 79 | candidates = np.asarray([detections[i].tlwh for i in detection_indices]) 80 | cost_matrix[row, :] = 1. - iou(bbox, candidates) 81 | return cost_matrix 82 | -------------------------------------------------------------------------------- /deep_sort/sort/kalman_filter.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | import scipy.linalg 4 | 5 | 6 | """ 7 | Table for the 0.95 quantile of the chi-square distribution with N degrees of 8 | freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv 9 | function and used as Mahalanobis gating threshold. 
10 | """ 11 | chi2inv95 = { 12 | 1: 3.8415, 13 | 2: 5.9915, 14 | 3: 7.8147, 15 | 4: 9.4877, 16 | 5: 11.070, 17 | 6: 12.592, 18 | 7: 14.067, 19 | 8: 15.507, 20 | 9: 16.919} 21 | 22 | 23 | class KalmanFilter(object): 24 | """ 25 | A simple Kalman filter for tracking bounding boxes in image space. 26 | 27 | The 8-dimensional state space 28 | 29 | x, y, a, h, vx, vy, va, vh 30 | 31 | contains the bounding box center position (x, y), aspect ratio a, height h, 32 | and their respective velocities. 33 | 34 | Object motion follows a constant velocity model. The bounding box location 35 | (x, y, a, h) is taken as direct observation of the state space (linear 36 | observation model). 37 | 38 | """ 39 | 40 | def __init__(self): 41 | ndim, dt = 4, 1. 42 | 43 | # Create Kalman filter model matrices. 44 | self._motion_mat = np.eye(2 * ndim, 2 * ndim) 45 | for i in range(ndim): 46 | self._motion_mat[i, ndim + i] = dt 47 | self._update_mat = np.eye(ndim, 2 * ndim) 48 | 49 | # Motion and observation uncertainty are chosen relative to the current 50 | # state estimate. These weights control the amount of uncertainty in 51 | # the model. This is a bit hacky. 52 | self._std_weight_position = 1. / 20 53 | self._std_weight_velocity = 1. / 160 54 | 55 | def initiate(self, measurement): 56 | """Create track from unassociated measurement. 57 | 58 | Parameters 59 | ---------- 60 | measurement : ndarray 61 | Bounding box coordinates (x, y, a, h) with center position (x, y), 62 | aspect ratio a, and height h. 63 | 64 | Returns 65 | ------- 66 | (ndarray, ndarray) 67 | Returns the mean vector (8 dimensional) and covariance matrix (8x8 68 | dimensional) of the new track. Unobserved velocities are initialized 69 | to 0 mean. 
70 | 71 | """ 72 | mean_pos = measurement 73 | mean_vel = np.zeros_like(mean_pos) 74 | mean = np.r_[mean_pos, mean_vel] 75 | 76 | std = [ 77 | 2 * self._std_weight_position * measurement[3], 78 | 2 * self._std_weight_position * measurement[3], 79 | 1e-2, 80 | 2 * self._std_weight_position * measurement[3], 81 | 10 * self._std_weight_velocity * measurement[3], 82 | 10 * self._std_weight_velocity * measurement[3], 83 | 1e-5, 84 | 10 * self._std_weight_velocity * measurement[3]] 85 | covariance = np.diag(np.square(std)) 86 | return mean, covariance 87 | 88 | def predict(self, mean, covariance): 89 | """Run Kalman filter prediction step. 90 | 91 | Parameters 92 | ---------- 93 | mean : ndarray 94 | The 8 dimensional mean vector of the object state at the previous 95 | time step. 96 | covariance : ndarray 97 | The 8x8 dimensional covariance matrix of the object state at the 98 | previous time step. 99 | 100 | Returns 101 | ------- 102 | (ndarray, ndarray) 103 | Returns the mean vector and covariance matrix of the predicted 104 | state. Unobserved velocities are initialized to 0 mean. 105 | 106 | """ 107 | std_pos = [ 108 | self._std_weight_position * mean[3], 109 | self._std_weight_position * mean[3], 110 | 1e-2, 111 | self._std_weight_position * mean[3]] 112 | std_vel = [ 113 | self._std_weight_velocity * mean[3], 114 | self._std_weight_velocity * mean[3], 115 | 1e-5, 116 | self._std_weight_velocity * mean[3]] 117 | motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) 118 | 119 | mean = np.dot(self._motion_mat, mean) 120 | covariance = np.linalg.multi_dot(( 121 | self._motion_mat, covariance, self._motion_mat.T)) + motion_cov 122 | 123 | return mean, covariance 124 | 125 | def project(self, mean, covariance): 126 | """Project state distribution to measurement space. 127 | 128 | Parameters 129 | ---------- 130 | mean : ndarray 131 | The state's mean vector (8 dimensional array). 132 | covariance : ndarray 133 | The state's covariance matrix (8x8 dimensional). 
134 | 135 | Returns 136 | ------- 137 | (ndarray, ndarray) 138 | Returns the projected mean and covariance matrix of the given state 139 | estimate. 140 | 141 | """ 142 | std = [ 143 | self._std_weight_position * mean[3], 144 | self._std_weight_position * mean[3], 145 | 1e-1, 146 | self._std_weight_position * mean[3]] 147 | innovation_cov = np.diag(np.square(std)) 148 | 149 | mean = np.dot(self._update_mat, mean) 150 | covariance = np.linalg.multi_dot(( 151 | self._update_mat, covariance, self._update_mat.T)) 152 | return mean, covariance + innovation_cov 153 | 154 | def update(self, mean, covariance, measurement): 155 | """Run Kalman filter correction step. 156 | 157 | Parameters 158 | ---------- 159 | mean : ndarray 160 | The predicted state's mean vector (8 dimensional). 161 | covariance : ndarray 162 | The state's covariance matrix (8x8 dimensional). 163 | measurement : ndarray 164 | The 4 dimensional measurement vector (x, y, a, h), where (x, y) 165 | is the center position, a the aspect ratio, and h the height of the 166 | bounding box. 167 | 168 | Returns 169 | ------- 170 | (ndarray, ndarray) 171 | Returns the measurement-corrected state distribution. 172 | 173 | """ 174 | projected_mean, projected_cov = self.project(mean, covariance) 175 | 176 | chol_factor, lower = scipy.linalg.cho_factor( 177 | projected_cov, lower=True, check_finite=False) 178 | kalman_gain = scipy.linalg.cho_solve( 179 | (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, 180 | check_finite=False).T 181 | innovation = measurement - projected_mean 182 | 183 | new_mean = mean + np.dot(innovation, kalman_gain.T) 184 | new_covariance = covariance - np.linalg.multi_dot(( 185 | kalman_gain, projected_cov, kalman_gain.T)) 186 | return new_mean, new_covariance 187 | 188 | def gating_distance(self, mean, covariance, measurements, 189 | only_position=False): 190 | """Compute gating distance between state distribution and measurements. 
191 | 192 | A suitable distance threshold can be obtained from `chi2inv95`. If 193 | `only_position` is False, the chi-square distribution has 4 degrees of 194 | freedom, otherwise 2. 195 | 196 | Parameters 197 | ---------- 198 | mean : ndarray 199 | Mean vector over the state distribution (8 dimensional). 200 | covariance : ndarray 201 | Covariance of the state distribution (8x8 dimensional). 202 | measurements : ndarray 203 | An Nx4 dimensional matrix of N measurements, each in 204 | format (x, y, a, h) where (x, y) is the bounding box center 205 | position, a the aspect ratio, and h the height. 206 | only_position : Optional[bool] 207 | If True, distance computation is done with respect to the bounding 208 | box center position only. 209 | 210 | Returns 211 | ------- 212 | ndarray 213 | Returns an array of length N, where the i-th element contains the 214 | squared Mahalanobis distance between (mean, covariance) and 215 | `measurements[i]`. 216 | 217 | """ 218 | mean, covariance = self.project(mean, covariance) 219 | if only_position: 220 | mean, covariance = mean[:2], covariance[:2, :2] 221 | measurements = measurements[:, :2] 222 | 223 | cholesky_factor = np.linalg.cholesky(covariance) 224 | d = measurements - mean 225 | z = scipy.linalg.solve_triangular( 226 | cholesky_factor, d.T, lower=True, check_finite=False, 227 | overwrite_b=True) 228 | squared_maha = np.sum(z * z, axis=0) 229 | return squared_maha 230 | -------------------------------------------------------------------------------- /deep_sort/sort/linear_assignment.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | from __future__ import absolute_import 3 | import numpy as np 4 | # from sklearn.utils.linear_assignment_ import linear_assignment 5 | from scipy.optimize import linear_sum_assignment as linear_assignment 6 | from . 
import kalman_filter 7 | 8 | 9 | INFTY_COST = 1e+5 10 | 11 | 12 | def min_cost_matching( 13 | distance_metric, max_distance, tracks, detections, track_indices=None, 14 | detection_indices=None): 15 | """Solve linear assignment problem. 16 | 17 | Parameters 18 | ---------- 19 | distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray 20 | The distance metric is given a list of tracks and detections as well as 21 | a list of N track indices and M detection indices. The metric should 22 | return the NxM dimensional cost matrix, where element (i, j) is the 23 | association cost between the i-th track in the given track indices and 24 | the j-th detection in the given detection_indices. 25 | max_distance : float 26 | Gating threshold. Associations with cost larger than this value are 27 | disregarded. 28 | tracks : List[track.Track] 29 | A list of predicted tracks at the current time step. 30 | detections : List[detection.Detection] 31 | A list of detections at the current time step. 32 | track_indices : List[int] 33 | List of track indices that maps rows in `cost_matrix` to tracks in 34 | `tracks` (see description above). 35 | detection_indices : List[int] 36 | List of detection indices that maps columns in `cost_matrix` to 37 | detections in `detections` (see description above). 38 | 39 | Returns 40 | ------- 41 | (List[(int, int)], List[int], List[int]) 42 | Returns a tuple with the following three entries: 43 | * A list of matched track and detection indices. 44 | * A list of unmatched track indices. 45 | * A list of unmatched detection indices. 46 | 47 | """ 48 | if track_indices is None: 49 | track_indices = np.arange(len(tracks)) 50 | if detection_indices is None: 51 | detection_indices = np.arange(len(detections)) 52 | 53 | if len(detection_indices) == 0 or len(track_indices) == 0: 54 | return [], track_indices, detection_indices # Nothing to match. 
55 | 56 | cost_matrix = distance_metric( 57 | tracks, detections, track_indices, detection_indices) 58 | cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5 59 | 60 | row_indices, col_indices = linear_assignment(cost_matrix) 61 | 62 | matches, unmatched_tracks, unmatched_detections = [], [], [] 63 | for col, detection_idx in enumerate(detection_indices): 64 | if col not in col_indices: 65 | unmatched_detections.append(detection_idx) 66 | for row, track_idx in enumerate(track_indices): 67 | if row not in row_indices: 68 | unmatched_tracks.append(track_idx) 69 | for row, col in zip(row_indices, col_indices): 70 | track_idx = track_indices[row] 71 | detection_idx = detection_indices[col] 72 | if cost_matrix[row, col] > max_distance: 73 | unmatched_tracks.append(track_idx) 74 | unmatched_detections.append(detection_idx) 75 | else: 76 | matches.append((track_idx, detection_idx)) 77 | return matches, unmatched_tracks, unmatched_detections 78 | 79 | 80 | def matching_cascade( 81 | distance_metric, max_distance, cascade_depth, tracks, detections, 82 | track_indices=None, detection_indices=None): 83 | """Run matching cascade. 84 | 85 | Parameters 86 | ---------- 87 | distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray 88 | The distance metric is given a list of tracks and detections as well as 89 | a list of N track indices and M detection indices. The metric should 90 | return the NxM dimensional cost matrix, where element (i, j) is the 91 | association cost between the i-th track in the given track indices and 92 | the j-th detection in the given detection indices. 93 | max_distance : float 94 | Gating threshold. Associations with cost larger than this value are 95 | disregarded. 96 | cascade_depth: int 97 | The cascade depth, should be se to the maximum track age. 98 | tracks : List[track.Track] 99 | A list of predicted tracks at the current time step. 
100 | detections : List[detection.Detection] 101 | A list of detections at the current time step. 102 | track_indices : Optional[List[int]] 103 | List of track indices that maps rows in `cost_matrix` to tracks in 104 | `tracks` (see description above). Defaults to all tracks. 105 | detection_indices : Optional[List[int]] 106 | List of detection indices that maps columns in `cost_matrix` to 107 | detections in `detections` (see description above). Defaults to all 108 | detections. 109 | 110 | Returns 111 | ------- 112 | (List[(int, int)], List[int], List[int]) 113 | Returns a tuple with the following three entries: 114 | * A list of matched track and detection indices. 115 | * A list of unmatched track indices. 116 | * A list of unmatched detection indices. 117 | 118 | """ 119 | if track_indices is None: 120 | track_indices = list(range(len(tracks))) 121 | if detection_indices is None: 122 | detection_indices = list(range(len(detections))) 123 | 124 | unmatched_detections = detection_indices 125 | matches = [] 126 | for level in range(cascade_depth): 127 | if len(unmatched_detections) == 0: # No detections left 128 | break 129 | 130 | track_indices_l = [ 131 | k for k in track_indices 132 | if tracks[k].time_since_update == 1 + level 133 | ] 134 | if len(track_indices_l) == 0: # Nothing to match at this level 135 | continue 136 | 137 | matches_l, _, unmatched_detections = \ 138 | min_cost_matching( 139 | distance_metric, max_distance, tracks, detections, 140 | track_indices_l, unmatched_detections) 141 | matches += matches_l 142 | unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches)) 143 | return matches, unmatched_tracks, unmatched_detections 144 | 145 | 146 | def gate_cost_matrix( 147 | kf, cost_matrix, tracks, detections, track_indices, detection_indices, 148 | gated_cost=INFTY_COST, only_position=False): 149 | """Invalidate infeasible entries in cost matrix based on the state 150 | distributions obtained by Kalman filtering. 
151 | 152 | Parameters 153 | ---------- 154 | kf : The Kalman filter. 155 | cost_matrix : ndarray 156 | The NxM dimensional cost matrix, where N is the number of track indices 157 | and M is the number of detection indices, such that entry (i, j) is the 158 | association cost between `tracks[track_indices[i]]` and 159 | `detections[detection_indices[j]]`. 160 | tracks : List[track.Track] 161 | A list of predicted tracks at the current time step. 162 | detections : List[detection.Detection] 163 | A list of detections at the current time step. 164 | track_indices : List[int] 165 | List of track indices that maps rows in `cost_matrix` to tracks in 166 | `tracks` (see description above). 167 | detection_indices : List[int] 168 | List of detection indices that maps columns in `cost_matrix` to 169 | detections in `detections` (see description above). 170 | gated_cost : Optional[float] 171 | Entries in the cost matrix corresponding to infeasible associations are 172 | set this value. Defaults to a very large value. 173 | only_position : Optional[bool] 174 | If True, only the x, y position of the state distribution is considered 175 | during gating. Defaults to False. 176 | 177 | Returns 178 | ------- 179 | ndarray 180 | Returns the modified cost matrix. 
181 | 182 | """ 183 | gating_dim = 2 if only_position else 4 184 | gating_threshold = kalman_filter.chi2inv95[gating_dim] 185 | measurements = np.asarray( 186 | [detections[i].to_xyah() for i in detection_indices]) 187 | for row, track_idx in enumerate(track_indices): 188 | track = tracks[track_idx] 189 | gating_distance = kf.gating_distance( 190 | track.mean, track.covariance, measurements, only_position) 191 | cost_matrix[row, gating_distance > gating_threshold] = gated_cost 192 | return cost_matrix 193 | -------------------------------------------------------------------------------- /deep_sort/sort/nn_matching.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | 4 | 5 | def _pdist(a, b): 6 | """Compute pair-wise squared distance between points in `a` and `b`. 7 | 8 | Parameters 9 | ---------- 10 | a : array_like 11 | An NxM matrix of N samples of dimensionality M. 12 | b : array_like 13 | An LxM matrix of L samples of dimensionality M. 14 | 15 | Returns 16 | ------- 17 | ndarray 18 | Returns a matrix of size len(a), len(b) such that eleement (i, j) 19 | contains the squared distance between `a[i]` and `b[j]`. 20 | 21 | """ 22 | a, b = np.asarray(a), np.asarray(b) 23 | if len(a) == 0 or len(b) == 0: 24 | return np.zeros((len(a), len(b))) 25 | a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1) 26 | r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :] 27 | r2 = np.clip(r2, 0., float(np.inf)) 28 | return r2 29 | 30 | 31 | def _cosine_distance(a, b, data_is_normalized=False): 32 | """Compute pair-wise cosine distance between points in `a` and `b`. 33 | 34 | Parameters 35 | ---------- 36 | a : array_like 37 | An NxM matrix of N samples of dimensionality M. 38 | b : array_like 39 | An LxM matrix of L samples of dimensionality M. 40 | data_is_normalized : Optional[bool] 41 | If True, assumes rows in a and b are unit length vectors. 
42 | Otherwise, a and b are explicitly normalized to length 1. 43 | 44 | Returns 45 | ------- 46 | ndarray 47 | Returns a matrix of size len(a), len(b) such that element (i, j) 48 | contains the cosine distance (1 - cosine similarity) between `a[i]` and `b[j]`. 49 | 50 | """ 51 | if not data_is_normalized: 52 | a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True) 53 | b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True) 54 | return 1. - np.dot(a, b.T) 55 | 56 | 57 | def _nn_euclidean_distance(x, y): 58 | """ Helper function for nearest neighbor distance metric (Euclidean). 59 | 60 | Parameters 61 | ---------- 62 | x : ndarray 63 | A matrix of N row-vectors (sample points). 64 | y : ndarray 65 | A matrix of M row-vectors (query points). 66 | 67 | Returns 68 | ------- 69 | ndarray 70 | A vector of length M that contains for each entry in `y` the 71 | smallest squared Euclidean distance to a sample in `x`. 72 | 73 | """ 74 | distances = _pdist(x, y) 75 | return np.maximum(0.0, distances.min(axis=0)) 76 | 77 | 78 | def _nn_cosine_distance(x, y): 79 | """ Helper function for nearest neighbor distance metric (cosine). 80 | 81 | Parameters 82 | ---------- 83 | x : ndarray 84 | A matrix of N row-vectors (sample points). 85 | y : ndarray 86 | A matrix of M row-vectors (query points). 87 | 88 | Returns 89 | ------- 90 | ndarray 91 | A vector of length M that contains for each entry in `y` the 92 | smallest cosine distance to a sample in `x`. 93 | 94 | """ 95 | distances = _cosine_distance(x, y) 96 | return distances.min(axis=0) 97 | 98 | 99 | class NearestNeighborDistanceMetric(object): 100 | """ 101 | A nearest neighbor distance metric that, for each target, returns 102 | the closest distance to any sample that has been observed so far. 103 | 104 | Parameters 105 | ---------- 106 | metric : str 107 | Either "euclidean" or "cosine". 108 | matching_threshold: float 109 | The matching threshold. Samples with larger distance are considered an 110 | invalid match.
111 | budget : Optional[int] 112 | If not None, fix samples per class to at most this number. Removes 113 | the oldest samples when the budget is reached. 114 | 115 | Attributes 116 | ---------- 117 | samples : Dict[int -> List[ndarray]] 118 | A dictionary that maps from target identities to the list of samples 119 | that have been observed so far. 120 | 121 | """ 122 | 123 | def __init__(self, metric, matching_threshold, budget=None): 124 | 125 | 126 | if metric == "euclidean": 127 | self._metric = _nn_euclidean_distance 128 | elif metric == "cosine": 129 | self._metric = _nn_cosine_distance 130 | else: 131 | raise ValueError( 132 | "Invalid metric; must be either 'euclidean' or 'cosine'") 133 | self.matching_threshold = matching_threshold 134 | self.budget = budget 135 | self.samples = {} 136 | 137 | def partial_fit(self, features, targets, active_targets): 138 | """Update the distance metric with new data. 139 | 140 | Parameters 141 | ---------- 142 | features : ndarray 143 | An NxM matrix of N features of dimensionality M. 144 | targets : ndarray 145 | An integer array of associated target identities. 146 | active_targets : List[int] 147 | A list of targets that are currently present in the scene. 148 | 149 | """ 150 | for feature, target in zip(features, targets): 151 | self.samples.setdefault(target, []).append(feature) 152 | if self.budget is not None: 153 | self.samples[target] = self.samples[target][-self.budget:] 154 | self.samples = {k: self.samples[k] for k in active_targets} 155 | 156 | def distance(self, features, targets): 157 | """Compute distance between features and targets. 158 | 159 | Parameters 160 | ---------- 161 | features : ndarray 162 | An NxM matrix of N features of dimensionality M. 163 | targets : List[int] 164 | A list of targets to match the given `features` against. 
165 | 166 | Returns 167 | ------- 168 | ndarray 169 | Returns a cost matrix of shape len(targets), len(features), where 170 | element (i, j) contains the closest squared distance between 171 | `targets[i]` and `features[j]`. 172 | 173 | """ 174 | cost_matrix = np.zeros((len(targets), len(features))) 175 | for i, target in enumerate(targets): 176 | cost_matrix[i, :] = self._metric(self.samples[target], features) 177 | return cost_matrix 178 | -------------------------------------------------------------------------------- /deep_sort/sort/preprocessing.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def non_max_suppression(boxes, max_bbox_overlap, scores=None): 7 | """Suppress overlapping detections. 8 | 9 | Original code from [1]_ has been adapted to include confidence score. 10 | 11 | .. [1] http://www.pyimagesearch.com/2015/02/16/ 12 | faster-non-maximum-suppression-python/ 13 | 14 | Examples 15 | -------- 16 | 17 | >>> boxes = [d.roi for d in detections] 18 | >>> scores = [d.confidence for d in detections] 19 | >>> indices = non_max_suppression(boxes, max_bbox_overlap, scores) 20 | >>> detections = [detections[i] for i in indices] 21 | 22 | Parameters 23 | ---------- 24 | boxes : ndarray 25 | Array of ROIs (x, y, width, height). 26 | max_bbox_overlap : float 27 | ROIs that overlap more than this values are suppressed. 28 | scores : Optional[array_like] 29 | Detector confidence score. 30 | 31 | Returns 32 | ------- 33 | List[int] 34 | Returns indices of detections that have survived non-maxima suppression. 
35 | 36 | """ 37 | if len(boxes) == 0: 38 | return [] 39 | 40 | boxes = boxes.astype(np.float) 41 | pick = [] 42 | 43 | x1 = boxes[:, 0] 44 | y1 = boxes[:, 1] 45 | x2 = boxes[:, 2] + boxes[:, 0] 46 | y2 = boxes[:, 3] + boxes[:, 1] 47 | 48 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 49 | if scores is not None: 50 | idxs = np.argsort(scores) 51 | else: 52 | idxs = np.argsort(y2) 53 | 54 | while len(idxs) > 0: 55 | last = len(idxs) - 1 56 | i = idxs[last] 57 | pick.append(i) 58 | 59 | xx1 = np.maximum(x1[i], x1[idxs[:last]]) 60 | yy1 = np.maximum(y1[i], y1[idxs[:last]]) 61 | xx2 = np.minimum(x2[i], x2[idxs[:last]]) 62 | yy2 = np.minimum(y2[i], y2[idxs[:last]]) 63 | 64 | w = np.maximum(0, xx2 - xx1 + 1) 65 | h = np.maximum(0, yy2 - yy1 + 1) 66 | 67 | overlap = (w * h) / area[idxs[:last]] 68 | 69 | idxs = np.delete( 70 | idxs, np.concatenate( 71 | ([last], np.where(overlap > max_bbox_overlap)[0]))) 72 | 73 | return pick 74 | -------------------------------------------------------------------------------- /deep_sort/sort/track.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | 3 | 4 | class TrackState: 5 | """ 6 | Enumeration type for the single target track state. Newly created tracks are 7 | classified as `tentative` until enough evidence has been collected. Then, 8 | the track state is changed to `confirmed`. Tracks that are no longer alive 9 | are classified as `deleted` to mark them for removal from the set of active 10 | tracks. 11 | 12 | """ 13 | 14 | Tentative = 1 15 | Confirmed = 2 16 | Deleted = 3 17 | 18 | 19 | class Track: 20 | """ 21 | A single target track with state space `(x, y, a, h)` and associated 22 | velocities, where `(x, y)` is the center of the bounding box, `a` is the 23 | aspect ratio and `h` is the height. 24 | 25 | Parameters 26 | ---------- 27 | mean : ndarray 28 | Mean vector of the initial state distribution. 
29 | covariance : ndarray 30 | Covariance matrix of the initial state distribution. 31 | track_id : int 32 | A unique track identifier. 33 | n_init : int 34 | Number of consecutive detections before the track is confirmed. The 35 | track state is set to `Deleted` if a miss occurs within the first 36 | `n_init` frames. 37 | max_age : int 38 | The maximum number of consecutive misses before the track state is 39 | set to `Deleted`. 40 | feature : Optional[ndarray] 41 | Feature vector of the detection this track originates from. If not None, 42 | this feature is added to the `features` cache. 43 | 44 | Attributes 45 | ---------- 46 | mean : ndarray 47 | Mean vector of the current state distribution (overwritten by `predict` and `update`). 48 | covariance : ndarray 49 | Covariance matrix of the current state distribution (overwritten by `predict` and `update`). 50 | track_id : int 51 | A unique track identifier. 52 | hits : int 53 | Total number of measurement updates; starts at 1 on creation. 54 | age : int 55 | Total number of frames since first occurrence. 56 | time_since_update : int 57 | Total number of frames since last measurement update. 58 | state : TrackState 59 | The current track state. 60 | features : List[ndarray] 61 | A cache of features. On each measurement update, the associated feature 62 | vector is added to this list. 63 | 64 | """ 65 | 66 | def __init__(self, mean, covariance, track_id, n_init, max_age, 67 | feature=None): 68 | self.mean = mean 69 | self.covariance = covariance 70 | self.track_id = track_id 71 | self.hits = 1 72 | self.age = 1 73 | self.time_since_update = 0 74 | 75 | self.state = TrackState.Tentative 76 | self.features = [] 77 | if feature is not None: 78 | self.features.append(feature) 79 | 80 | self._n_init = n_init 81 | self._max_age = max_age 82 | 83 | def to_tlwh(self): 84 | """Get current position in bounding box format `(top left x, top left y, 85 | width, height)`. 86 | 87 | Returns 88 | ------- 89 | ndarray 90 | The bounding box.
91 | 92 | """ 93 | ret = self.mean[:4].copy() 94 | ret[2] *= ret[3] 95 | ret[:2] -= ret[2:] / 2 96 | return ret 97 | 98 | def to_tlbr(self): 99 | """Get current position in bounding box format `(min x, miny, max x, 100 | max y)`. 101 | 102 | Returns 103 | ------- 104 | ndarray 105 | The bounding box. 106 | 107 | """ 108 | ret = self.to_tlwh() 109 | ret[2:] = ret[:2] + ret[2:] 110 | return ret 111 | 112 | def predict(self, kf): 113 | """Propagate the state distribution to the current time step using a 114 | Kalman filter prediction step. 115 | 116 | Parameters 117 | ---------- 118 | kf : kalman_filter.KalmanFilter 119 | The Kalman filter. 120 | 121 | """ 122 | self.mean, self.covariance = kf.predict(self.mean, self.covariance) 123 | self.age += 1 124 | self.time_since_update += 1 125 | 126 | def update(self, kf, detection): 127 | """Perform Kalman filter measurement update step and update the feature 128 | cache. 129 | 130 | Parameters 131 | ---------- 132 | kf : kalman_filter.KalmanFilter 133 | The Kalman filter. 134 | detection : Detection 135 | The associated detection. 136 | 137 | """ 138 | self.mean, self.covariance = kf.update( 139 | self.mean, self.covariance, detection.to_xyah()) 140 | self.features.append(detection.feature) 141 | 142 | self.hits += 1 143 | self.time_since_update = 0 144 | if self.state == TrackState.Tentative and self.hits >= self._n_init: 145 | self.state = TrackState.Confirmed 146 | 147 | def mark_missed(self): 148 | """Mark this track as missed (no association at the current time step). 149 | """ 150 | if self.state == TrackState.Tentative: 151 | self.state = TrackState.Deleted 152 | elif self.time_since_update > self._max_age: 153 | self.state = TrackState.Deleted 154 | 155 | def is_tentative(self): 156 | """Returns True if this track is tentative (unconfirmed). 
157 | """ 158 | return self.state == TrackState.Tentative 159 | 160 | def is_confirmed(self): 161 | """Returns True if this track is confirmed.""" 162 | return self.state == TrackState.Confirmed 163 | 164 | def is_deleted(self): 165 | """Returns True if this track is dead and should be deleted.""" 166 | return self.state == TrackState.Deleted 167 | -------------------------------------------------------------------------------- /deep_sort/sort/tracker.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | from __future__ import absolute_import 3 | import numpy as np 4 | from . import kalman_filter 5 | from . import linear_assignment 6 | from . import iou_matching 7 | from .track import Track 8 | 9 | 10 | class Tracker: 11 | """ 12 | This is the multi-target tracker. 13 | 14 | Parameters 15 | ---------- 16 | metric : nn_matching.NearestNeighborDistanceMetric 17 | A distance metric for measurement-to-track association. 18 | max_age : int 19 | Maximum number of missed misses before a track is deleted. 20 | n_init : int 21 | Number of consecutive detections before the track is confirmed. The 22 | track state is set to `Deleted` if a miss occurs within the first 23 | `n_init` frames. 24 | 25 | Attributes 26 | ---------- 27 | metric : nn_matching.NearestNeighborDistanceMetric 28 | The distance metric used for measurement to track association. 29 | max_age : int 30 | Maximum number of missed misses before a track is deleted. 31 | n_init : int 32 | Number of frames that a track remains in initialization phase. 33 | kf : kalman_filter.KalmanFilter 34 | A Kalman filter to filter target trajectories in image space. 35 | tracks : List[Track] 36 | The list of active tracks at the current time step. 
37 | 38 | """ 39 | 40 | def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3): 41 | self.metric = metric 42 | # print("tracker.metric:=-=-=--=",self.metric) 43 | self.max_iou_distance = max_iou_distance 44 | self.max_age = max_age 45 | self.n_init = n_init 46 | 47 | self.kf = kalman_filter.KalmanFilter() 48 | self.tracks = [] 49 | self._next_id = 1 50 | 51 | def predict(self): 52 | """Propagate track state distributions one time step forward. 53 | 54 | This function should be called once every time step, before `update`. 55 | """ 56 | # print('tracker.py, predict, predict:', type(self.tracks), len(self.tracks)) 57 | for track in self.tracks: 58 | track.predict(self.kf) 59 | 60 | def update(self, detections): 61 | """Perform measurement update and track management. 62 | 63 | Parameters 64 | ---------- 65 | detections : List[deep_sort.detection.Detection] 66 | A list of detections at the current time step. 67 | 68 | """ 69 | # Run matching cascade. 70 | # print(type(detections)) 71 | matches, unmatched_tracks, unmatched_detections = \ 72 | self._match(detections) 73 | # print("tracker.matches:=-=-=--=",matches) 74 | # Update track set. 75 | for track_idx, detection_idx in matches: 76 | self.tracks[track_idx].update( 77 | self.kf, detections[detection_idx]) 78 | for track_idx in unmatched_tracks: 79 | self.tracks[track_idx].mark_missed() 80 | for detection_idx in unmatched_detections: 81 | self._initiate_track(detections[detection_idx]) 82 | self.tracks = [t for t in self.tracks if not t.is_deleted()] 83 | 84 | # Update distance metric. 
85 | active_targets = [t.track_id for t in self.tracks if t.is_confirmed()] 86 | features, targets = [], [] 87 | for track in self.tracks: 88 | if not track.is_confirmed(): 89 | continue 90 | features += track.features 91 | targets += [track.track_id for _ in track.features] 92 | track.features = [] 93 | self.metric.partial_fit( 94 | np.asarray(features), np.asarray(targets), active_targets) 95 | 96 | def _match(self, detections): 97 | 98 | 99 | def gated_metric(tracks, dets, track_indices, detection_indices): 100 | # print("tracker.gated metric.dets",type(track_indices)) 101 | features = np.array([dets[i].feature for i in detection_indices]) 102 | # print(features) 103 | targets = np.array([tracks[i].track_id for i in track_indices]) 104 | # print("-------------=================") 105 | # print(targets) 106 | cost_matrix = self.metric.distance(features, targets) 107 | cost_matrix = linear_assignment.gate_cost_matrix( 108 | self.kf, cost_matrix, tracks, dets, track_indices, 109 | detection_indices) 110 | 111 | return cost_matrix 112 | 113 | # Split track set into confirmed and unconfirmed tracks. 114 | confirmed_tracks = [ 115 | i for i, t in enumerate(self.tracks) if t.is_confirmed()] 116 | unconfirmed_tracks = [ 117 | i for i, t in enumerate(self.tracks) if not t.is_confirmed()] 118 | 119 | # Associate confirmed tracks using appearance features. 120 | matches_a, unmatched_tracks_a, unmatched_detections = \ 121 | linear_assignment.matching_cascade( 122 | gated_metric, self.metric.matching_threshold, self.max_age, 123 | self.tracks, detections, confirmed_tracks) 124 | 125 | # Associate remaining tracks together with unconfirmed tracks using IOU. 
126 | iou_track_candidates = unconfirmed_tracks + [ 127 | k for k in unmatched_tracks_a if 128 | self.tracks[k].time_since_update == 1] 129 | unmatched_tracks_a = [ 130 | k for k in unmatched_tracks_a if 131 | self.tracks[k].time_since_update != 1] 132 | matches_b, unmatched_tracks_b, unmatched_detections = \ 133 | linear_assignment.min_cost_matching( 134 | iou_matching.iou_cost, self.max_iou_distance, self.tracks, 135 | detections, iou_track_candidates, unmatched_detections) 136 | 137 | matches = matches_a + matches_b 138 | unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b)) 139 | return matches, unmatched_tracks, unmatched_detections 140 | 141 | def _initiate_track(self, detection): 142 | mean, covariance = self.kf.initiate(detection.to_xyah()) 143 | self.tracks.append(Track( 144 | mean, covariance, self._next_id, self.n_init, self.max_age, 145 | detection.feature)) 146 | self._next_id += 1 147 | -------------------------------------------------------------------------------- /detector/__init__.py: -------------------------------------------------------------------------------- 1 | from .yolov7 import yolo 2 | from .yolov7.yolo import YOLO 3 | 4 | 5 | __all__ = ['build_detector'] 6 | 7 | def build_detector(): 8 | return YOLO() 9 | -------------------------------------------------------------------------------- /detector/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/detector/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /detector/yolov7/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore map, miou, datasets 2 | map_out/ 3 | miou_out/ 4 | VOCdevkit/ 5 | datasets/ 6 | Medical_Datasets/ 7 | lfw/ 8 | logs/ 9 | model_data/ 10 | .temp_map_out/ 11 | 12 
| # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /detector/yolov7/README.md: -------------------------------------------------------------------------------- 1 | ## YOLOV7:You Only Look Once目标检测模型在pytorch当中的实现 2 | --- 3 | 4 | ## 目录 5 | 1. [仓库更新 Top News](#仓库更新) 6 | 2. [相关仓库 Related code](#相关仓库) 7 | 3. [性能情况 Performance](#性能情况) 8 | 4. [所需环境 Environment](#所需环境) 9 | 5. [文件下载 Download](#文件下载) 10 | 6. [训练步骤 How2train](#训练步骤) 11 | 7. [预测步骤 How2predict](#预测步骤) 12 | 8. [评估步骤 How2eval](#评估步骤) 13 | 9. 
[参考资料 Reference](#Reference) 14 | 15 | ## Top News 16 | **`2022-07`**:**仓库创建,支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整、新增图片裁剪、支持多GPU训练、支持各个种类目标数量计算、支持heatmap、支持EMA。** 17 | 18 | ## 相关仓库 19 | | 模型 | 路径 | 20 | | :----- | :----- | 21 | YoloV3 | https://github.com/bubbliiiing/yolo3-pytorch 22 | Efficientnet-Yolo3 | https://github.com/bubbliiiing/efficientnet-yolo3-pytorch 23 | YoloV4 | https://github.com/bubbliiiing/yolov4-pytorch 24 | YoloV4-tiny | https://github.com/bubbliiiing/yolov4-tiny-pytorch 25 | Mobilenet-Yolov4 | https://github.com/bubbliiiing/mobilenet-yolov4-pytorch 26 | YoloV5-V5.0 | https://github.com/bubbliiiing/yolov5-pytorch 27 | YoloV5-V6.1 | https://github.com/bubbliiiing/yolov5-v6.1-pytorch 28 | YoloX | https://github.com/bubbliiiing/yolox-pytorch 29 | YoloV7 | https://github.com/bubbliiiing/yolov7-pytorch 30 | 31 | ## 性能情况 32 | | 训练数据集 | 权值文件名称 | 测试数据集 | 输入图片大小 | mAP 0.5:0.95 | mAP 0.5 | 33 | | :-----: | :-----: | :------: | :------: | :------: | :-----: | 34 | | COCO-Train2017 | [yolov7_weights.pth](https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_weights.pth) | COCO-Val2017 | 640x640 | 50.7 | 69.2 35 | | COCO-Train2017 | [yolov7_x_weights.pth](https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_weights.pth) | COCO-Val2017 | 640x640 | 52.4 | 70.5 36 | 37 | ## 所需环境 38 | torch==1.2.0 39 | 为了使用amp混合精度,推荐使用torch1.7.1以上的版本。 40 | 41 | ## 文件下载 42 | 训练所需的权值可在百度网盘中下载。 43 | 链接: https://pan.baidu.com/s/1uYpjWC1uOo3Q-klpUEy9LQ 44 | 提取码: pmua 45 | 46 | VOC数据集下载地址如下,里面已经包括了训练集、测试集、验证集(与测试集一样),无需再次划分: 47 | 链接: https://pan.baidu.com/s/19Mw2u_df_nBzsC2lg20fQA 48 | 提取码: j5ge 49 | 50 | ## 训练步骤 51 | ### a、训练VOC07+12数据集 52 | 1. 数据集的准备 53 | **本文使用VOC格式进行训练,训练前需要下载好VOC07+12的数据集,解压后放在根目录** 54 | 55 | 2. 数据集的处理 56 | 修改voc_annotation.py里面的annotation_mode=2,运行voc_annotation.py生成根目录下的2007_train.txt和2007_val.txt。 57 | 58 | 3. 开始网络训练 59 | train.py的默认参数用于训练VOC数据集,直接运行train.py即可开始训练。 60 | 61 | 4. 
训练结果预测 62 | 训练结果预测需要用到两个文件,分别是yolo.py和predict.py。我们首先需要去yolo.py里面修改model_path以及classes_path,这两个参数必须要修改。 63 | **model_path指向训练好的权值文件,在logs文件夹里。 64 | classes_path指向检测类别所对应的txt。** 65 | 完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。 66 | 67 | ### b、训练自己的数据集 68 | 1. 数据集的准备 69 | **本文使用VOC格式进行训练,训练前需要自己制作好数据集。** 70 | 训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的Annotations中。 71 | 训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。 72 | 73 | 2. 数据集的处理 74 | 在完成数据集的摆放之后,我们需要利用voc_annotation.py获得训练用的2007_train.txt和2007_val.txt。 75 | 修改voc_annotation.py里面的参数。第一次训练可以仅修改classes_path,classes_path用于指向检测类别所对应的txt。 76 | 训练自己的数据集时,可以自己建立一个cls_classes.txt,里面写自己所需要区分的类别。 77 | model_data/cls_classes.txt文件内容为: 78 | ```python 79 | cat 80 | dog 81 | ... 82 | ``` 83 | 修改voc_annotation.py中的classes_path,使其对应cls_classes.txt,并运行voc_annotation.py。 84 | 85 | 3. 开始网络训练 86 | **训练的参数较多,均在train.py中,大家可以在下载库后仔细看注释,其中最重要的部分依然是train.py里的classes_path。** 87 | **classes_path用于指向检测类别所对应的txt,这个txt和voc_annotation.py里面的txt一样!训练自己的数据集必须要修改!** 88 | 修改完classes_path后就可以运行train.py开始训练了,在训练多个epoch后,权值会生成在logs文件夹中。 89 | 90 | 4. 训练结果预测 91 | 训练结果预测需要用到两个文件,分别是yolo.py和predict.py。在yolo.py里面修改model_path以及classes_path。 92 | **model_path指向训练好的权值文件,在logs文件夹里。 93 | classes_path指向检测类别所对应的txt。** 94 | 完成修改后就可以运行predict.py进行检测了。运行后输入图片路径即可检测。 95 | 96 | ## 预测步骤 97 | ### a、使用预训练权重 98 | 1. 下载完库后解压,在百度网盘下载权值,放入model_data,运行predict.py,输入 99 | ```python 100 | img/street.jpg 101 | ``` 102 | 2. 在predict.py里面进行设置可以进行fps测试和video视频检测。 103 | ### b、使用自己训练的权重 104 | 1. 按照训练步骤训练。 105 | 2. 在yolo.py文件里面,在如下部分修改model_path和classes_path使其对应训练好的文件;**model_path对应logs文件夹下面的权值文件,classes_path是model_path所对应的类别txt**。 106 | ```python 107 | _defaults = { 108 | #--------------------------------------------------------------------------# 109 | # 使用自己训练好的模型进行预测一定要修改model_path和classes_path!
110 | # model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt 111 | # 112 | # 训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。 113 | # 验证集损失较低不代表mAP较高,仅代表该权值在验证集上泛化性能较好。 114 | # 如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改 115 | #--------------------------------------------------------------------------# 116 | "model_path" : 'model_data/yolov7_weights.pth', 117 | "classes_path" : 'model_data/coco_classes.txt', 118 | #---------------------------------------------------------------------# 119 | # anchors_path代表先验框对应的txt文件,一般不修改。 120 | # anchors_mask用于帮助代码找到对应的先验框,一般不修改。 121 | #---------------------------------------------------------------------# 122 | "anchors_path" : 'model_data/yolo_anchors.txt', 123 | "anchors_mask" : [[6, 7, 8], [3, 4, 5], [0, 1, 2]], 124 | #---------------------------------------------------------------------# 125 | # 输入图片的大小,必须为32的倍数。 126 | #---------------------------------------------------------------------# 127 | "input_shape" : [640, 640], 128 | #------------------------------------------------------# 129 | # 所使用到的yolov7的版本,本仓库一共提供两个: 130 | # l : 对应yolov7 131 | # x : 对应yolov7_x 132 | #------------------------------------------------------# 133 | "phi" : 'l', 134 | #---------------------------------------------------------------------# 135 | # 只有得分大于置信度的预测框会被保留下来 136 | #---------------------------------------------------------------------# 137 | "confidence" : 0.5, 138 | #---------------------------------------------------------------------# 139 | # 非极大抑制所用到的nms_iou大小 140 | #---------------------------------------------------------------------# 141 | "nms_iou" : 0.3, 142 | #---------------------------------------------------------------------# 143 | # 该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize, 144 | # 在多次测试后,发现关闭letterbox_image直接resize的效果更好 145 | #---------------------------------------------------------------------# 146 | "letterbox_image" : True, 147 | #-------------------------------# 148 | # 是否使用Cuda 149 | # 没有GPU可以设置成False 150 | 
#-------------------------------# 151 | "cuda" : True, 152 | } 153 | ``` 154 | 3. 运行predict.py,输入 155 | ```python 156 | img/street.jpg 157 | ``` 158 | 4. 在predict.py里面进行设置可以进行fps测试和video视频检测。 159 | 160 | ## 评估步骤 161 | ### a、评估VOC07+12的测试集 162 | 1. 本文使用VOC格式进行评估。VOC07+12已经划分好了测试集,无需利用voc_annotation.py生成ImageSets文件夹下的txt。 163 | 2. 在yolo.py里面修改model_path以及classes_path。**model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。** 164 | 3. 运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。 165 | 166 | ### b、评估自己的数据集 167 | 1. 本文使用VOC格式进行评估。 168 | 2. 如果在训练前已经运行过voc_annotation.py文件,代码会自动将数据集划分成训练集、验证集和测试集。如果想要修改测试集的比例,可以修改voc_annotation.py文件下的trainval_percent。trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1。train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1。 169 | 3. 利用voc_annotation.py划分测试集后,前往get_map.py文件修改classes_path,classes_path用于指向检测类别所对应的txt,这个txt和训练时的txt一样。评估自己的数据集必须要修改。 170 | 4. 在yolo.py里面修改model_path以及classes_path。**model_path指向训练好的权值文件,在logs文件夹里。classes_path指向检测类别所对应的txt。** 171 | 5. 
运行get_map.py即可获得评估结果,评估结果会保存在map_out文件夹中。 172 | 173 | ## Reference 174 | https://github.com/WongKinYiu/yolov7 175 | -------------------------------------------------------------------------------- /detector/yolov7/get_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | from PIL import Image 5 | from tqdm import tqdm 6 | 7 | from utils.utils import get_classes 8 | from utils.utils_map import get_coco_map, get_map 9 | from yolo import YOLO 10 | 11 | if __name__ == "__main__": 12 | ''' 13 | Recall和Precision不像AP是一个面积的概念,因此在门限值(Confidence)不同时,网络的Recall和Precision值是不同的。 14 | 默认情况下,本代码计算的Recall和Precision代表的是当门限值(Confidence)为0.5时,所对应的Recall和Precision值。 15 | 16 | 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算不同门限条件下的Recall和Precision值 17 | 因此,本代码获得的map_out/detection-results/里面的txt的框的数量一般会比直接predict多一些,目的是列出所有可能的预测框, 18 | ''' 19 | #------------------------------------------------------------------------------------------------------------------# 20 | # map_mode用于指定该文件运行时计算的内容 21 | # map_mode为0代表整个map计算流程,包括获得预测结果、获得真实框、计算VOC_map。 22 | # map_mode为1代表仅仅获得预测结果。 23 | # map_mode为2代表仅仅获得真实框。 24 | # map_mode为3代表仅仅计算VOC_map。 25 | # map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行 26 | #-------------------------------------------------------------------------------------------------------------------# 27 | map_mode = 0 28 | #--------------------------------------------------------------------------------------# 29 | # 此处的classes_path用于指定需要测量VOC_map的类别 30 | # 一般情况下与训练和预测所用的classes_path一致即可 31 | #--------------------------------------------------------------------------------------# 32 | classes_path = 'model_data/voc_classes.txt' 33 | #--------------------------------------------------------------------------------------# 34 | # MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。 35 | # 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。 36 | # 37 | # 
当某一预测框与真实框重合度大于MINOVERLAP时,该预测框被认为是正样本,否则为负样本。 38 | # 因此MINOVERLAP的值越大,预测框要预测的越准确才能被认为是正样本,此时算出来的mAP值越低, 39 | #--------------------------------------------------------------------------------------# 40 | MINOVERLAP = 0.5 41 | #--------------------------------------------------------------------------------------# 42 | # 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算mAP 43 | # 因此,confidence的值应当设置的尽量小进而获得全部可能的预测框。 44 | # 45 | # 该值一般不调整。因为计算mAP需要获得近乎所有的预测框,此处的confidence不能随便更改。 46 | # 想要获得不同门限值下的Recall和Precision值,请修改下方的score_threhold。 47 | #--------------------------------------------------------------------------------------# 48 | confidence = 0.001 49 | #--------------------------------------------------------------------------------------# 50 | # 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。 51 | # 52 | # 该值一般不调整。 53 | #--------------------------------------------------------------------------------------# 54 | nms_iou = 0.5 55 | #---------------------------------------------------------------------------------------------------------------# 56 | # Recall和Precision不像AP是一个面积的概念,因此在门限值不同时,网络的Recall和Precision值是不同的。 57 | # 58 | # 默认情况下,本代码计算的Recall和Precision代表的是当门限值为0.5(此处定义为score_threhold)时所对应的Recall和Precision值。 59 | # 因为计算mAP需要获得近乎所有的预测框,上面定义的confidence不能随便更改。 60 | # 这里专门定义一个score_threhold用于代表门限值,进而在计算mAP时找到门限值对应的Recall和Precision值。 61 | #---------------------------------------------------------------------------------------------------------------# 62 | score_threhold = 0.5 63 | #-------------------------------------------------------# 64 | # map_vis用于指定是否开启VOC_map计算的可视化 65 | #-------------------------------------------------------# 66 | map_vis = False 67 | #-------------------------------------------------------# 68 | # 指向VOC数据集所在的文件夹 69 | # 默认指向根目录下的VOC数据集 70 | #-------------------------------------------------------# 71 | VOCdevkit_path = 'VOCdevkit' 72 | #-------------------------------------------------------# 73 | # 结果输出的文件夹,默认为map_out 74 | 
#-------------------------------------------------------# 75 | map_out_path = 'map_out' 76 | 77 | image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split() 78 | 79 | if not os.path.exists(map_out_path): 80 | os.makedirs(map_out_path) 81 | if not os.path.exists(os.path.join(map_out_path, 'ground-truth')): 82 | os.makedirs(os.path.join(map_out_path, 'ground-truth')) 83 | if not os.path.exists(os.path.join(map_out_path, 'detection-results')): 84 | os.makedirs(os.path.join(map_out_path, 'detection-results')) 85 | if not os.path.exists(os.path.join(map_out_path, 'images-optional')): 86 | os.makedirs(os.path.join(map_out_path, 'images-optional')) 87 | 88 | class_names, _ = get_classes(classes_path) 89 | 90 | if map_mode == 0 or map_mode == 1: 91 | print("Load model.") 92 | yolo = YOLO(confidence = confidence, nms_iou = nms_iou) 93 | print("Load model done.") 94 | 95 | print("Get predict result.") 96 | for image_id in tqdm(image_ids): 97 | image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/"+image_id+".jpg") 98 | image = Image.open(image_path) 99 | if map_vis: 100 | image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg")) 101 | yolo.get_map_txt(image_id, image, class_names, map_out_path) 102 | print("Get predict result done.") 103 | 104 | if map_mode == 0 or map_mode == 2: 105 | print("Get ground truth result.") 106 | for image_id in tqdm(image_ids): 107 | with open(os.path.join(map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: 108 | root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/"+image_id+".xml")).getroot() 109 | for obj in root.findall('object'): 110 | difficult_flag = False 111 | if obj.find('difficult')!=None: 112 | difficult = obj.find('difficult').text 113 | if int(difficult)==1: 114 | difficult_flag = True 115 | obj_name = obj.find('name').text 116 | if obj_name not in class_names: 117 | continue 118 | bndbox = obj.find('bndbox') 119 | left = 
bndbox.find('xmin').text 120 | top = bndbox.find('ymin').text 121 | right = bndbox.find('xmax').text 122 | bottom = bndbox.find('ymax').text 123 | 124 | if difficult_flag: 125 | new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom)) 126 | else: 127 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 128 | print("Get ground truth result done.") 129 | 130 | if map_mode == 0 or map_mode == 3: 131 | print("Get map.") 132 | get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path) 133 | print("Get map done.") 134 | 135 | if map_mode == 4: 136 | print("Get map.") 137 | get_coco_map(class_names = class_names, path = map_out_path) 138 | print("Get map done.") 139 | -------------------------------------------------------------------------------- /detector/yolov7/img/street.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/detector/yolov7/img/street.jpg -------------------------------------------------------------------------------- /detector/yolov7/kmeans_for_anchors.py: -------------------------------------------------------------------------------- 1 | #-------------------------------------------------------------------------------------------------------# 2 | # kmeans虽然会对数据集中的框进行聚类,但是很多数据集由于框的大小相近,聚类出来的9个框相差不大, 3 | # 这样的框反而不利于模型的训练。因为不同的特征层适合不同大小的先验框,shape越小的特征层适合越大的先验框 4 | # 原始网络的先验框已经按大中小比例分配好了,不进行聚类也会有非常好的效果。 5 | #-------------------------------------------------------------------------------------------------------# 6 | import glob 7 | import xml.etree.ElementTree as ET 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | 14 | def cas_ratio(box,cluster): 15 | ratios_of_box_cluster = box / cluster 16 | ratios_of_cluster_box = cluster / box 17 | ratios = np.concatenate([ratios_of_box_cluster, 
ratios_of_cluster_box], axis = -1) 18 | 19 | return np.max(ratios, -1) 20 | 21 | def avg_ratio(box,cluster): 22 | return np.mean([np.min(cas_ratio(box[i],cluster)) for i in range(box.shape[0])]) 23 | 24 | def kmeans(box,k): 25 | #-------------------------------------------------------------# 26 | # 取出一共有多少框 27 | #-------------------------------------------------------------# 28 | row = box.shape[0] 29 | 30 | #-------------------------------------------------------------# 31 | # 每个框各个点的位置 32 | #-------------------------------------------------------------# 33 | distance = np.empty((row,k)) 34 | 35 | #-------------------------------------------------------------# 36 | # 最后的聚类位置 37 | #-------------------------------------------------------------# 38 | last_clu = np.zeros((row,)) 39 | 40 | np.random.seed() 41 | 42 | #-------------------------------------------------------------# 43 | # 随机选5个当聚类中心 44 | #-------------------------------------------------------------# 45 | cluster = box[np.random.choice(row,k,replace = False)] 46 | 47 | iter = 0 48 | while True: 49 | #-------------------------------------------------------------# 50 | # 计算当前框和先验框的宽高比例 51 | #-------------------------------------------------------------# 52 | for i in range(row): 53 | distance[i] = cas_ratio(box[i],cluster) 54 | 55 | #-------------------------------------------------------------# 56 | # 取出最小点 57 | #-------------------------------------------------------------# 58 | near = np.argmin(distance,axis=1) 59 | 60 | if (last_clu == near).all(): 61 | break 62 | 63 | #-------------------------------------------------------------# 64 | # 求每一个类的中位点 65 | #-------------------------------------------------------------# 66 | for j in range(k): 67 | cluster[j] = np.median( 68 | box[near == j],axis=0) 69 | 70 | last_clu = near 71 | if iter % 5 == 0: 72 | print('iter: {:d}. 
avg_ratio:{:.2f}'.format(iter, avg_ratio(box,cluster))) 73 | iter += 1 74 | 75 | return cluster, near 76 | 77 | def load_data(path): 78 | data = [] 79 | #-------------------------------------------------------------# 80 | # 对于每一个xml都寻找box 81 | #-------------------------------------------------------------# 82 | for xml_file in tqdm(glob.glob('{}/*xml'.format(path))): 83 | tree = ET.parse(xml_file) 84 | height = int(tree.findtext('./size/height')) 85 | width = int(tree.findtext('./size/width')) 86 | if height<=0 or width<=0: 87 | continue 88 | 89 | #-------------------------------------------------------------# 90 | # 对于每一个目标都获得它的宽高 91 | #-------------------------------------------------------------# 92 | for obj in tree.iter('object'): 93 | xmin = int(float(obj.findtext('bndbox/xmin'))) / width 94 | ymin = int(float(obj.findtext('bndbox/ymin'))) / height 95 | xmax = int(float(obj.findtext('bndbox/xmax'))) / width 96 | ymax = int(float(obj.findtext('bndbox/ymax'))) / height 97 | 98 | xmin = np.float64(xmin) 99 | ymin = np.float64(ymin) 100 | xmax = np.float64(xmax) 101 | ymax = np.float64(ymax) 102 | # 得到宽高 103 | data.append([xmax-xmin,ymax-ymin]) 104 | return np.array(data) 105 | 106 | if __name__ == '__main__': 107 | np.random.seed(0) 108 | #-------------------------------------------------------------# 109 | # 运行该程序会计算'./VOCdevkit/VOC2007/Annotations'的xml 110 | # 会生成yolo_anchors.txt 111 | #-------------------------------------------------------------# 112 | input_shape = [640, 640] 113 | anchors_num = 9 114 | #-------------------------------------------------------------# 115 | # 载入数据集,可以使用VOC的xml 116 | #-------------------------------------------------------------# 117 | path = 'VOCdevkit/VOC2007/Annotations' 118 | 119 | #-------------------------------------------------------------# 120 | # 载入所有的xml 121 | # 存储格式为转化为比例后的width,height 122 | #-------------------------------------------------------------# 123 | print('Load xmls.') 124 | data = load_data(path) 125 | 
print('Load xmls done.') 126 | 127 | #-------------------------------------------------------------# 128 | # 使用k聚类算法 129 | #-------------------------------------------------------------# 130 | print('K-means boxes.') 131 | cluster, near = kmeans(data, anchors_num) 132 | print('K-means boxes done.') 133 | data = data * np.array([input_shape[1], input_shape[0]]) 134 | cluster = cluster * np.array([input_shape[1], input_shape[0]]) 135 | 136 | #-------------------------------------------------------------# 137 | # 绘图 138 | #-------------------------------------------------------------# 139 | for j in range(anchors_num): 140 | plt.scatter(data[near == j][:,0], data[near == j][:,1]) 141 | plt.scatter(cluster[j][0], cluster[j][1], marker='x', c='black') 142 | plt.savefig("kmeans_for_anchors.jpg") 143 | plt.show() 144 | print('Save kmeans_for_anchors.jpg in root dir.') 145 | 146 | cluster = cluster[np.argsort(cluster[:, 0] * cluster[:, 1])] 147 | print('avg_ratio:{:.2f}'.format(avg_ratio(data, cluster))) 148 | print(cluster) 149 | 150 | f = open("yolo_anchors.txt", 'w') 151 | row = np.shape(cluster)[0] 152 | for i in range(row): 153 | if i == 0: 154 | x_y = "%d,%d" % (cluster[i][0], cluster[i][1]) 155 | else: 156 | x_y = ", %d,%d" % (cluster[i][0], cluster[i][1]) 157 | f.write(x_y) 158 | f.close() 159 | -------------------------------------------------------------------------------- /detector/yolov7/nets/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /detector/yolov7/nets/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def autopad(k, p=None): 6 | if p is None: 7 | p = k // 2 if isinstance(k, int) else [x // 2 for x in k] 8 | return p 9 | 10 | class SiLU(nn.Module): 11 | @staticmethod 12 | def forward(x): 13 | return x * torch.sigmoid(x) 14 | 
15 | class Conv(nn.Module): 16 | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=SiLU()): # ch_in, ch_out, kernel, stride, padding, groups 17 | super(Conv, self).__init__() 18 | self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) 19 | self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03) 20 | self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 21 | 22 | def forward(self, x): 23 | return self.act(self.bn(self.conv(x))) 24 | 25 | def fuseforward(self, x): 26 | return self.act(self.conv(x)) 27 | 28 | class Block(nn.Module): 29 | def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]): 30 | super(Block, self).__init__() 31 | c_ = int(c2 * e) 32 | 33 | self.ids = ids 34 | self.cv1 = Conv(c1, c_, 1, 1) 35 | self.cv2 = Conv(c1, c_, 1, 1) 36 | self.cv3 = nn.ModuleList( 37 | [Conv(c_ if i ==0 else c2, c2, 3, 1) for i in range(n)] 38 | ) 39 | self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1) 40 | 41 | def forward(self, x): 42 | x_1 = self.cv1(x) 43 | x_2 = self.cv2(x) 44 | 45 | x_all = [x_1, x_2] 46 | for i in range(len(self.cv3)): 47 | x_2 = self.cv3[i](x_2) 48 | x_all.append(x_2) 49 | 50 | out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1)) 51 | return out 52 | 53 | class MP(nn.Module): 54 | def __init__(self, k=2): 55 | super(MP, self).__init__() 56 | self.m = nn.MaxPool2d(kernel_size=k, stride=k) 57 | 58 | def forward(self, x): 59 | return self.m(x) 60 | 61 | class Transition(nn.Module): 62 | def __init__(self, c1, c2): 63 | super(Transition, self).__init__() 64 | self.cv1 = Conv(c1, c2, 1, 1) 65 | self.cv2 = Conv(c1, c2, 1, 1) 66 | self.cv3 = Conv(c2, c2, 3, 2) 67 | 68 | self.mp = MP() 69 | 70 | def forward(self, x): 71 | x_1 = self.mp(x) 72 | x_1 = self.cv1(x_1) 73 | 74 | x_2 = self.cv2(x) 75 | x_2 = self.cv3(x_2) 76 | 77 | return torch.cat([x_2, x_1], 1) 78 | 79 | class Backbone(nn.Module): 80 | def __init__(self, transition_channels, block_channels, n, phi, 
pretrained=False): 81 | super().__init__() 82 | #-----------------------------------------------# 83 | # 输入图片是640, 640, 3 84 | #-----------------------------------------------# 85 | ids = { 86 | 'l' : [-1, -3, -5, -6], 87 | 'x' : [-1, -3, -5, -7, -8], 88 | }[phi] 89 | self.stem = nn.Sequential( 90 | Conv(3, transition_channels, 3, 1), 91 | Conv(transition_channels, transition_channels * 2, 3, 2), 92 | Conv(transition_channels * 2, transition_channels * 2, 3, 1), 93 | ) 94 | self.dark2 = nn.Sequential( 95 | Conv(transition_channels * 2, transition_channels * 4, 3, 2), 96 | Block(transition_channels * 4, block_channels * 2, transition_channels * 8, n=n, ids=ids), 97 | ) 98 | self.dark3 = nn.Sequential( 99 | Transition(transition_channels * 8, transition_channels * 4), 100 | Block(transition_channels * 8, block_channels * 4, transition_channels * 16, n=n, ids=ids), 101 | ) 102 | self.dark4 = nn.Sequential( 103 | Transition(transition_channels * 16, transition_channels * 8), 104 | Block(transition_channels * 16, block_channels * 8, transition_channels * 32, n=n, ids=ids), 105 | ) 106 | self.dark5 = nn.Sequential( 107 | Transition(transition_channels * 32, transition_channels * 16), 108 | Block(transition_channels * 32, block_channels * 8, transition_channels * 32, n=n, ids=ids), 109 | ) 110 | 111 | if pretrained: 112 | url = { 113 | "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone_weights.pth', 114 | "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone_weights.pth', 115 | }[phi] 116 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", model_dir="./model_data") 117 | self.load_state_dict(checkpoint, strict=False) 118 | print("Load weights from " + url.split('/')[-1]) 119 | 120 | def forward(self, x): 121 | x = self.stem(x) 122 | x = self.dark2(x) 123 | #-----------------------------------------------# 124 | # dark3的输出为80, 80, 256,是一个有效特征层 125 | 
#-----------------------------------------------# 126 | x = self.dark3(x) 127 | feat1 = x 128 | #-----------------------------------------------# 129 | # dark4的输出为40, 40, 512,是一个有效特征层 130 | #-----------------------------------------------# 131 | x = self.dark4(x) 132 | feat2 = x 133 | #-----------------------------------------------# 134 | # dark5的输出为20, 20, 1024,是一个有效特征层 135 | #-----------------------------------------------# 136 | x = self.dark5(x) 137 | feat3 = x 138 | return feat1, feat2, feat3 139 | -------------------------------------------------------------------------------- /detector/yolov7/nets/yolo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .backbone import Backbone, Block, Conv, SiLU, Transition, autopad 6 | 7 | 8 | class SPPCSPC(nn.Module): 9 | # CSP https://github.com/WongKinYiu/CrossStagePartialNetworks 10 | def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)): 11 | super(SPPCSPC, self).__init__() 12 | c_ = int(2 * c2 * e) # hidden channels 13 | self.cv1 = Conv(c1, c_, 1, 1) 14 | self.cv2 = Conv(c1, c_, 1, 1) 15 | self.cv3 = Conv(c_, c_, 3, 1) 16 | self.cv4 = Conv(c_, c_, 1, 1) 17 | self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) 18 | self.cv5 = Conv(4 * c_, c_, 1, 1) 19 | self.cv6 = Conv(c_, c_, 3, 1) 20 | self.cv7 = Conv(2 * c_, c2, 1, 1) 21 | 22 | def forward(self, x): 23 | x1 = self.cv4(self.cv3(self.cv1(x))) 24 | y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1))) 25 | y2 = self.cv2(x) 26 | return self.cv7(torch.cat((y1, y2), dim=1)) 27 | 28 | class RepConv(nn.Module): 29 | # Represented convolution 30 | # https://arxiv.org/abs/2101.03697 31 | def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=SiLU(), deploy=False): 32 | super(RepConv, self).__init__() 33 | self.deploy = deploy 34 | self.groups = g 35 | self.in_channels = c1 36 | 
self.out_channels = c2 37 | 38 | assert k == 3 39 | assert autopad(k, p) == 1 40 | 41 | padding_11 = autopad(k, p) - k // 2 42 | self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) 43 | 44 | if deploy: 45 | self.rbr_reparam = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=True) 46 | else: 47 | self.rbr_identity = (nn.BatchNorm2d(num_features=c1, eps=0.001, momentum=0.03) if c2 == c1 and s == 1 else None) 48 | self.rbr_dense = nn.Sequential( 49 | nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False), 50 | nn.BatchNorm2d(num_features=c2, eps=0.001, momentum=0.03), 51 | ) 52 | self.rbr_1x1 = nn.Sequential( 53 | nn.Conv2d( c1, c2, 1, s, padding_11, groups=g, bias=False), 54 | nn.BatchNorm2d(num_features=c2, eps=0.001, momentum=0.03), 55 | ) 56 | 57 | def forward(self, inputs): 58 | if hasattr(self, "rbr_reparam"): 59 | return self.act(self.rbr_reparam(inputs)) 60 | if self.rbr_identity is None: 61 | id_out = 0 62 | else: 63 | id_out = self.rbr_identity(inputs) 64 | return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out) 65 | 66 | def get_equivalent_kernel_bias(self): 67 | kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense) 68 | kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1) 69 | kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity) 70 | return ( 71 | kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, 72 | bias3x3 + bias1x1 + biasid, 73 | ) 74 | 75 | def _pad_1x1_to_3x3_tensor(self, kernel1x1): 76 | if kernel1x1 is None: 77 | return 0 78 | else: 79 | return nn.functional.pad(kernel1x1, [1, 1, 1, 1]) 80 | 81 | def _fuse_bn_tensor(self, branch): 82 | if branch is None: 83 | return 0, 0 84 | if isinstance(branch, nn.Sequential): 85 | kernel = branch[0].weight 86 | running_mean = branch[1].running_mean 87 | running_var = branch[1].running_var 88 | gamma = branch[1].weight 89 | beta = branch[1].bias 90 | eps = branch[1].eps 91 | else: 92 | assert 
isinstance(branch, nn.BatchNorm2d) 93 | if not hasattr(self, "id_tensor"): 94 | input_dim = self.in_channels // self.groups 95 | kernel_value = np.zeros( 96 | (self.in_channels, input_dim, 3, 3), dtype=np.float32 97 | ) 98 | for i in range(self.in_channels): 99 | kernel_value[i, i % input_dim, 1, 1] = 1 100 | self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device) 101 | kernel = self.id_tensor 102 | running_mean = branch.running_mean 103 | running_var = branch.running_var 104 | gamma = branch.weight 105 | beta = branch.bias 106 | eps = branch.eps 107 | std = (running_var + eps).sqrt() 108 | t = (gamma / std).reshape(-1, 1, 1, 1) 109 | return kernel * t, beta - running_mean * gamma / std 110 | 111 | def repvgg_convert(self): 112 | kernel, bias = self.get_equivalent_kernel_bias() 113 | return ( 114 | kernel.detach().cpu().numpy(), 115 | bias.detach().cpu().numpy(), 116 | ) 117 | 118 | def fuse_conv_bn(self, conv, bn): 119 | std = (bn.running_var + bn.eps).sqrt() 120 | bias = bn.bias - bn.running_mean * bn.weight / std 121 | 122 | t = (bn.weight / std).reshape(-1, 1, 1, 1) 123 | weights = conv.weight * t 124 | 125 | bn = nn.Identity() 126 | conv = nn.Conv2d(in_channels = conv.in_channels, 127 | out_channels = conv.out_channels, 128 | kernel_size = conv.kernel_size, 129 | stride=conv.stride, 130 | padding = conv.padding, 131 | dilation = conv.dilation, 132 | groups = conv.groups, 133 | bias = True, 134 | padding_mode = conv.padding_mode) 135 | 136 | conv.weight = torch.nn.Parameter(weights) 137 | conv.bias = torch.nn.Parameter(bias) 138 | return conv 139 | 140 | def fuse_repvgg_block(self): 141 | if self.deploy: 142 | return 143 | print(f"RepConv.fuse_repvgg_block") 144 | self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1]) 145 | 146 | self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1]) 147 | rbr_1x1_bias = self.rbr_1x1.bias 148 | weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1]) 149 | 
150 | # Fuse self.rbr_identity 151 | if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)): 152 | identity_conv_1x1 = nn.Conv2d( 153 | in_channels=self.in_channels, 154 | out_channels=self.out_channels, 155 | kernel_size=1, 156 | stride=1, 157 | padding=0, 158 | groups=self.groups, 159 | bias=False) 160 | identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device) 161 | identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze() # NOTE(review): squeeze + fill_diagonal_ only yield an exact identity kernel when the 2-D weight is square, i.e. groups == 1 and in_channels == out_channels — confirm for grouped RepConv 162 | identity_conv_1x1.weight.data.fill_(0.0) 163 | identity_conv_1x1.weight.data.fill_diagonal_(1.0) 164 | identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3) 165 | 166 | identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity) 167 | bias_identity_expanded = identity_conv_1x1.bias 168 | weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1]) # pad the 1x1 identity kernel by one pixel per side so it can be added to rbr_dense.weight 169 | else: 170 | bias_identity_expanded = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) ) 171 | weight_identity_expanded = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) ) 172 | 173 | self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded) 174 | self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded) 175 | 176 | self.rbr_reparam = self.rbr_dense 177 | self.deploy = True 178 | 179 | if self.rbr_identity is not None: 180 | del self.rbr_identity 181 | self.rbr_identity = None 182 | 183 | if self.rbr_1x1 is not None: 184 | del self.rbr_1x1 185 | self.rbr_1x1 = None 186 | 187 | if self.rbr_dense is not None: # the fused conv stays reachable through self.rbr_reparam 188 | del self.rbr_dense 189 | self.rbr_dense = None 190 | 191 | def fuse_conv_and_bn(conv, bn): # fold a BatchNorm into a new biased Conv2d (requires_grad disabled) for fused inference 192 | fusedconv = nn.Conv2d(conv.in_channels, 193 | conv.out_channels, 194 | kernel_size=conv.kernel_size, 195 | stride=conv.stride, 196 | padding=conv.padding, 197 |
groups=conv.groups, 198 | bias=True).requires_grad_(False).to(conv.weight.device) 199 | 200 | w_conv = conv.weight.clone().view(conv.out_channels, -1) 201 | w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) # diagonal matrix of per-channel scales gamma / sqrt(running_var + eps) 202 | fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) 203 | 204 | b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias 205 | b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) # beta - gamma * running_mean / sqrt(running_var + eps) 206 | fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) 207 | return fusedconv 208 | 209 | #---------------------------------------------------# 210 | # yolo_body 211 | #---------------------------------------------------# 212 | class YoloBody(nn.Module): 213 | def __init__(self, anchors_mask, num_classes, phi, pretrained=False): 214 | super(YoloBody, self).__init__() 215 | #-----------------------------------------------# 216 | # Per-variant hyperparameters; phi selects the 'l' or the 'x' yolov7 variant 217 | #-----------------------------------------------# 218 | transition_channels = {'l' : 32, 'x' : 40}[phi] 219 | block_channels = 32 220 | panet_channels = {'l' : 32, 'x' : 64}[phi] 221 | e = {'l' : 2, 'x' : 1}[phi] 222 | n = {'l' : 4, 'x' : 6}[phi] 223 | ids = {'l' : [-1, -2, -3, -4, -5, -6], 'x' : [-1, -3, -5, -7, -8]}[phi] 224 | conv = {'l' : RepConv, 'x' : Conv}[phi] 225 | #-----------------------------------------------# 226 | # The input image is 640, 640, 3 227 | #-----------------------------------------------# 228 | 229 | #---------------------------------------------------# 230 | # Build the backbone. 231 | # It yields three feature maps; for phi='l' their shapes are: 232 | # 80, 80, 512 233 | # 40, 40, 1024 234 | # 20, 20, 1024 235 | #---------------------------------------------------# 236 | self.backbone = Backbone(transition_channels, block_channels, n, phi, pretrained=pretrained) 237 | 238 | self.upsample = nn.Upsample(scale_factor=2, mode="nearest") 239 | 240 | self.sppcspc = SPPCSPC(transition_channels * 32,
transition_channels * 16) 241 | self.conv_for_P5 = Conv(transition_channels * 16, transition_channels * 8) 242 | self.conv_for_feat2 = Conv(transition_channels * 32, transition_channels * 8) 243 | self.conv3_for_upsample1 = Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids) 244 | 245 | self.conv_for_P4 = Conv(transition_channels * 8, transition_channels * 4) 246 | self.conv_for_feat1 = Conv(transition_channels * 16, transition_channels * 4) 247 | self.conv3_for_upsample2 = Block(transition_channels * 8, panet_channels * 2, transition_channels * 4, e=e, n=n, ids=ids) 248 | 249 | self.down_sample1 = Transition(transition_channels * 4, transition_channels * 4) 250 | self.conv3_for_downsample1 = Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids) 251 | 252 | self.down_sample2 = Transition(transition_channels * 8, transition_channels * 8) 253 | self.conv3_for_downsample2 = Block(transition_channels * 32, panet_channels * 8, transition_channels * 16, e=e, n=n, ids=ids) 254 | 255 | self.rep_conv_1 = conv(transition_channels * 4, transition_channels * 8, 3, 1) 256 | self.rep_conv_2 = conv(transition_channels * 8, transition_channels * 16, 3, 1) 257 | self.rep_conv_3 = conv(transition_channels * 16, transition_channels * 32, 3, 1) 258 | 259 | self.yolo_head_P3 = nn.Conv2d(transition_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1) 260 | self.yolo_head_P4 = nn.Conv2d(transition_channels * 16, len(anchors_mask[1]) * (5 + num_classes), 1) 261 | self.yolo_head_P5 = nn.Conv2d(transition_channels * 32, len(anchors_mask[0]) * (5 + num_classes), 1) 262 | 263 | def fuse(self): # reparameterize RepConv blocks and fold Conv+BN pairs in place; returns self 264 | print('Fusing layers... 
') 265 | for m in self.modules(): 266 | if isinstance(m, RepConv): 267 | m.fuse_repvgg_block() 268 | elif type(m) is Conv and hasattr(m, 'bn'): 269 | m.conv = fuse_conv_and_bn(m.conv, m.bn) 270 | delattr(m, 'bn') 271 | m.forward = m.fuseforward # fuseforward is defined on Conv (not visible here); presumably the conv path without bn — confirm 272 | return self 273 | 274 | def forward(self, x): 275 | # backbone 276 | feat1, feat2, feat3 = self.backbone.forward(x) 277 | 278 | P5 = self.sppcspc(feat3) 279 | P5_conv = self.conv_for_P5(P5) 280 | P5_upsample = self.upsample(P5_conv) 281 | P4 = torch.cat([self.conv_for_feat2(feat2), P5_upsample], 1) 282 | P4 = self.conv3_for_upsample1(P4) 283 | 284 | P4_conv = self.conv_for_P4(P4) 285 | P4_upsample = self.upsample(P4_conv) 286 | P3 = torch.cat([self.conv_for_feat1(feat1), P4_upsample], 1) 287 | P3 = self.conv3_for_upsample2(P3) 288 | 289 | P3_downsample = self.down_sample1(P3) 290 | P4 = torch.cat([P3_downsample, P4], 1) 291 | P4 = self.conv3_for_downsample1(P4) 292 | 293 | P4_downsample = self.down_sample2(P4) 294 | P5 = torch.cat([P4_downsample, P5], 1) 295 | P5 = self.conv3_for_downsample2(P5) 296 | 297 | P3 = self.rep_conv_1(P3) 298 | P4 = self.rep_conv_2(P4) 299 | P5 = self.rep_conv_3(P5) 300 | #---------------------------------------------------# 301 | # Third output level (P3 head) 302 | # y3=(batch_size, 75, 80, 80); channels = len(anchors_mask[2]) * (5 + num_classes), 75 being 3 anchors x (5 + 20 classes) 303 | #---------------------------------------------------# 304 | out2 = self.yolo_head_P3(P3) 305 | #---------------------------------------------------# 306 | # Second output level (P4 head) 307 | # y2=(batch_size, 75, 40, 40) 308 | #---------------------------------------------------# 309 | out1 = self.yolo_head_P4(P4) 310 | #---------------------------------------------------# 311 | # First output level (P5 head) 312 | # y1=(batch_size, 75, 20, 20) 313 | #---------------------------------------------------# 314 | out0 = self.yolo_head_P5(P5) 315 | 316 | return [out0, out1, out2] # ordered P5, P4, P3 317 | -------------------------------------------------------------------------------- /detector/yolov7/predict.py:
-------------------------------------------------------------------------------- 1 | #-----------------------------------------------------------------------# 2 | # predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能 3 | # 整合到了一个py文件中,通过指定mode进行模式的修改。 4 | #-----------------------------------------------------------------------# 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from yolo import YOLO 12 | 13 | if __name__ == "__main__": 14 | yolo = YOLO() 15 | #----------------------------------------------------------------------------------------------------------# 16 | # mode用于指定测试的模式: 17 | # 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 18 | # 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。 19 | # 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 20 | # 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。 21 | # 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。 22 | # 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。 23 | #----------------------------------------------------------------------------------------------------------# 24 | mode = "predict" 25 | #-------------------------------------------------------------------------# 26 | # crop 指定了是否在单张图片预测后对目标进行截取 27 | # count 指定了是否进行目标的计数 28 | # crop、count仅在mode='predict'时有效 29 | #-------------------------------------------------------------------------# 30 | crop = False 31 | count = False 32 | #----------------------------------------------------------------------------------------------------------# 33 | # video_path 用于指定视频的路径,当video_path=0时表示检测摄像头 34 | # 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。 35 | # video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存 36 | # 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。 37 | # video_fps 用于保存的视频的fps 38 | # 39 | # video_path、video_save_path和video_fps仅在mode='video'时有效 40 | # 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。 41 | #----------------------------------------------------------------------------------------------------------# 42 
| video_path = 0 43 | video_save_path = "" 44 | video_fps = 25.0 45 | #----------------------------------------------------------------------------------------------------------# 46 | # test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。 47 | # fps_image_path 用于指定测试的fps图片 48 | # 49 | # test_interval和fps_image_path仅在mode='fps'有效 50 | #----------------------------------------------------------------------------------------------------------# 51 | test_interval = 100 52 | fps_image_path = "img/street.jpg" 53 | #-------------------------------------------------------------------------# 54 | # dir_origin_path 指定了用于检测的图片的文件夹路径 55 | # dir_save_path 指定了检测完图片的保存路径 56 | # 57 | # dir_origin_path和dir_save_path仅在mode='dir_predict'时有效 58 | #-------------------------------------------------------------------------# 59 | dir_origin_path = "img/" 60 | dir_save_path = "img_out/" 61 | #-------------------------------------------------------------------------# 62 | # heatmap_save_path 热力图的保存路径,默认保存在model_data下 63 | # 64 | # heatmap_save_path仅在mode='heatmap'有效 65 | #-------------------------------------------------------------------------# 66 | heatmap_save_path = "model_data/heatmap_vision.png" 67 | #-------------------------------------------------------------------------# 68 | # simplify 使用Simplify onnx 69 | # onnx_save_path 指定了onnx的保存路径 70 | #-------------------------------------------------------------------------# 71 | simplify = True 72 | onnx_save_path = "model_data/models.onnx" 73 | 74 | if mode == "predict": 75 | ''' 76 | 1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。 77 | 2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。 78 | 3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值 79 | 在原图上利用矩阵的方式进行截取。 80 | 4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断, 81 | 比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。 82 | ''' 83 | while True: 84 | img 
= input('Input image filename:') 85 | try: 86 | image = Image.open(img) 87 | except: 88 | print('Open Error! Try again!') 89 | continue 90 | else: 91 | r_image = yolo.detect_image(image, crop = crop, count=count) 92 | r_image.show() 93 | 94 | elif mode == "video": 95 | capture = cv2.VideoCapture(video_path) 96 | if video_save_path!="": 97 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 98 | size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))) 99 | out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size) 100 | 101 | ref, frame = capture.read() 102 | if not ref: 103 | raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。") 104 | 105 | fps = 0.0 106 | while(True): 107 | t1 = time.time() 108 | # 读取某一帧 109 | ref, frame = capture.read() 110 | if not ref: 111 | break 112 | # 格式转变,BGRtoRGB 113 | frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB) 114 | # 转变成Image 115 | frame = Image.fromarray(np.uint8(frame)) 116 | # 进行检测 117 | frame = np.array(yolo.detect_image(frame)) 118 | # RGBtoBGR满足opencv显示格式 119 | frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR) 120 | 121 | fps = ( fps + (1./(time.time()-t1)) ) / 2 122 | print("fps= %.2f"%(fps)) 123 | frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) 124 | 125 | cv2.imshow("video",frame) 126 | c= cv2.waitKey(1) & 0xff 127 | if video_save_path!="": 128 | out.write(frame) 129 | 130 | if c==27: 131 | capture.release() 132 | break 133 | 134 | print("Video Detection Done!") 135 | capture.release() 136 | if video_save_path!="": 137 | print("Save processed video to the path :" + video_save_path) 138 | out.release() 139 | cv2.destroyAllWindows() 140 | 141 | elif mode == "fps": 142 | img = Image.open(fps_image_path) 143 | tact_time = yolo.get_FPS(img, test_interval) 144 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1') 145 | 146 | elif mode == "dir_predict": 147 | import os 148 | 149 | from tqdm import tqdm 150 | 151 | 
img_names = os.listdir(dir_origin_path) 152 | for img_name in tqdm(img_names): 153 | if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')): 154 | image_path = os.path.join(dir_origin_path, img_name) 155 | image = Image.open(image_path) 156 | r_image = yolo.detect_image(image) 157 | if not os.path.exists(dir_save_path): 158 | os.makedirs(dir_save_path) 159 | r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0) 160 | 161 | elif mode == "heatmap": 162 | while True: 163 | img = input('Input image filename:') 164 | try: 165 | image = Image.open(img) 166 | except: 167 | print('Open Error! Try again!') 168 | continue 169 | else: 170 | yolo.detect_heatmap(image, heatmap_save_path) 171 | 172 | elif mode == "export_onnx": 173 | yolo.convert_to_onnx(simplify, onnx_save_path) 174 | 175 | else: 176 | raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.") 177 | -------------------------------------------------------------------------------- /detector/yolov7/requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.2.1 2 | numpy==1.17.0 3 | matplotlib==3.1.2 4 | opencv_python==4.1.2.30 5 | torch==1.2.0 6 | torchvision==0.4.0 7 | tqdm==4.60.0 8 | Pillow==8.2.0 9 | h5py==2.10.0 10 | -------------------------------------------------------------------------------- /detector/yolov7/summary.py: -------------------------------------------------------------------------------- 1 | #--------------------------------------------# 2 | # 该部分代码用于看网络结构 3 | #--------------------------------------------# 4 | import torch 5 | from thop import clever_format, profile 6 | 7 | from nets.yolo import YoloBody 8 | 9 | if __name__ == "__main__": 10 | input_shape = [640, 640] 11 | anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] 12 | num_classes = 80 13 | phi = 'l' 14 | 15 | 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | m = YoloBody(anchors_mask, num_classes, phi, False).to(device) 17 | for i in m.children(): 18 | print(i) 19 | print('==============================') 20 | 21 | dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device) 22 | flops, params = profile(m.to(device), (dummy_input, ), verbose=False) 23 | #--------------------------------------------------------# 24 | # flops * 2是因为profile没有将卷积作为两个operations 25 | # 有些论文将卷积算乘法、加法两个operations。此时乘2 26 | # 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2 27 | # 本代码选择乘2,参考YOLOX。 28 | #--------------------------------------------------------# 29 | flops = flops * 2 30 | flops, params = clever_format([flops, params], "%.3f") 31 | print('Total GFLOPS: %s' % (flops)) 32 | print('Total params: %s' % (params)) 33 | -------------------------------------------------------------------------------- /detector/yolov7/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /detector/yolov7/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | import torch 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import scipy.signal 8 | from matplotlib import pyplot as plt 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | import shutil 12 | import numpy as np 13 | 14 | from PIL import Image 15 | from tqdm import tqdm 16 | from .utils import cvtColor, preprocess_input, resize_image 17 | from .utils_bbox import DecodeBox 18 | from .utils_map import get_coco_map, get_map 19 | 20 | 21 | class LossHistory(): 22 | def __init__(self, log_dir, model, input_shape): 23 | self.log_dir = log_dir 24 | self.losses = [] 25 | self.val_loss = [] 26 | 27 | os.makedirs(self.log_dir) 28 | self.writer = SummaryWriter(self.log_dir) 29 | try: 30 | dummy_input = 
torch.randn(2, 3, input_shape[0], input_shape[1]) 31 | self.writer.add_graph(model, dummy_input) 32 | except: 33 | pass 34 | 35 | def append_loss(self, epoch, loss, val_loss): 36 | if not os.path.exists(self.log_dir): 37 | os.makedirs(self.log_dir) 38 | 39 | self.losses.append(loss) 40 | self.val_loss.append(val_loss) 41 | 42 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: 43 | f.write(str(loss)) 44 | f.write("\n") 45 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: 46 | f.write(str(val_loss)) 47 | f.write("\n") 48 | 49 | self.writer.add_scalar('loss', loss, epoch) 50 | self.writer.add_scalar('val_loss', val_loss, epoch) 51 | self.loss_plot() 52 | 53 | def loss_plot(self): 54 | iters = range(len(self.losses)) 55 | 56 | plt.figure() 57 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') 58 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') 59 | try: 60 | if len(self.losses) < 25: 61 | num = 5 62 | else: 63 | num = 15 64 | 65 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') 66 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') 67 | except: 68 | pass 69 | 70 | plt.grid(True) 71 | plt.xlabel('Epoch') 72 | plt.ylabel('Loss') 73 | plt.legend(loc="upper right") 74 | 75 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) 76 | 77 | plt.cla() 78 | plt.close("all") 79 | 80 | class EvalCallback(): 81 | def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda, \ 82 | map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1): 83 | super(EvalCallback, self).__init__() 84 | 85 | self.net = net 86 | self.input_shape = input_shape 87 | self.anchors = anchors 88 | self.anchors_mask = 
anchors_mask 89 | self.class_names = class_names 90 | self.num_classes = num_classes 91 | self.val_lines = val_lines 92 | self.log_dir = log_dir 93 | self.cuda = cuda 94 | self.map_out_path = map_out_path 95 | self.max_boxes = max_boxes 96 | self.confidence = confidence 97 | self.nms_iou = nms_iou 98 | self.letterbox_image = letterbox_image 99 | self.MINOVERLAP = MINOVERLAP 100 | self.eval_flag = eval_flag 101 | self.period = period 102 | 103 | self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask) 104 | 105 | self.maps = [0] 106 | self.epoches = [0] 107 | if self.eval_flag: 108 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: 109 | f.write(str(0)) 110 | f.write("\n") 111 | 112 | def get_map_txt(self, image_id, image, class_names, map_out_path): 113 | f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"), "w", encoding='utf-8') 114 | image_shape = np.array(np.shape(image)[0:2]) 115 | #---------------------------------------------------------# 116 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 117 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 118 | #---------------------------------------------------------# 119 | image = cvtColor(image) 120 | #---------------------------------------------------------# 121 | # 给图像增加灰条,实现不失真的resize 122 | # 也可以直接resize进行识别 123 | #---------------------------------------------------------# 124 | image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image) 125 | #---------------------------------------------------------# 126 | # 添加上batch_size维度 127 | #---------------------------------------------------------# 128 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0) 129 | 130 | with torch.no_grad(): 131 | images = torch.from_numpy(image_data) 132 | if self.cuda: 133 | images = images.cuda() 134 | #---------------------------------------------------------# 135 | 
# 将图像输入网络当中进行预测! 136 | #---------------------------------------------------------# 137 | outputs = self.net(images) 138 | outputs = self.bbox_util.decode_box(outputs) 139 | #---------------------------------------------------------# 140 | # 将预测框进行堆叠,然后进行非极大抑制 141 | #---------------------------------------------------------# 142 | results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape, 143 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou) 144 | 145 | if results[0] is None: 146 | return 147 | 148 | top_label = np.array(results[0][:, 6], dtype = 'int32') 149 | top_conf = results[0][:, 4] * results[0][:, 5] 150 | top_boxes = results[0][:, :4] 151 | 152 | top_100 = np.argsort(top_label)[::-1][:self.max_boxes] 153 | top_boxes = top_boxes[top_100] 154 | top_conf = top_conf[top_100] 155 | top_label = top_label[top_100] 156 | 157 | for i, c in list(enumerate(top_label)): 158 | predicted_class = self.class_names[int(c)] 159 | box = top_boxes[i] 160 | score = str(top_conf[i]) 161 | 162 | top, left, bottom, right = box 163 | if predicted_class not in class_names: 164 | continue 165 | 166 | f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom)))) 167 | 168 | f.close() 169 | return 170 | 171 | def on_epoch_end(self, epoch, model_eval): 172 | if epoch % self.period == 0 and self.eval_flag: 173 | self.net = model_eval 174 | if not os.path.exists(self.map_out_path): 175 | os.makedirs(self.map_out_path) 176 | if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")): 177 | os.makedirs(os.path.join(self.map_out_path, "ground-truth")) 178 | if not os.path.exists(os.path.join(self.map_out_path, "detection-results")): 179 | os.makedirs(os.path.join(self.map_out_path, "detection-results")) 180 | print("Get map.") 181 | for annotation_line in tqdm(self.val_lines): 182 | line = annotation_line.split() 183 | image_id = 
os.path.basename(line[0]).split('.')[0] 184 | #------------------------------# 185 | # 读取图像并转换成RGB图像 186 | #------------------------------# 187 | image = Image.open(line[0]) 188 | #------------------------------# 189 | # 获得预测框 190 | #------------------------------# 191 | gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) 192 | #------------------------------# 193 | # 获得预测txt 194 | #------------------------------# 195 | self.get_map_txt(image_id, image, self.class_names, self.map_out_path) 196 | 197 | #------------------------------# 198 | # 获得真实框txt 199 | #------------------------------# 200 | with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: 201 | for box in gt_boxes: 202 | left, top, right, bottom, obj = box 203 | obj_name = self.class_names[obj] 204 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 205 | 206 | print("Calculate Map.") 207 | try: 208 | temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1] 209 | except: 210 | temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path) 211 | self.maps.append(temp_map) 212 | self.epoches.append(epoch) 213 | 214 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: 215 | f.write(str(temp_map)) 216 | f.write("\n") 217 | 218 | plt.figure() 219 | plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map') 220 | 221 | plt.grid(True) 222 | plt.xlabel('Epoch') 223 | plt.ylabel('Map %s'%str(self.MINOVERLAP)) 224 | plt.title('A Map Curve') 225 | plt.legend(loc="upper right") 226 | 227 | plt.savefig(os.path.join(self.log_dir, "epoch_map.png")) 228 | plt.cla() 229 | plt.close("all") 230 | 231 | print("Get map done.") 232 | shutil.rmtree(self.map_out_path) 233 | -------------------------------------------------------------------------------- /detector/yolov7/utils/utils.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | from PIL import Image 3 | 4 | 5 | #---------------------------------------------------------# 6 | # Convert the image to RGB so grayscale images do not break prediction. 7 | # Only RGB prediction is supported; every other image mode is converted to RGB. 8 | #---------------------------------------------------------# 9 | def cvtColor(image): # return the image unchanged if it already has 3 channels, otherwise image.convert('RGB') 10 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: 11 | return image 12 | else: 13 | image = image.convert('RGB') 14 | return image 15 | 16 | #---------------------------------------------------# 17 | # Resize the input image; size is (w, h). With letterbox_image the aspect ratio is kept and the rest is padded with gray (128,128,128) 18 | #---------------------------------------------------# 19 | def resize_image(image, size, letterbox_image): 20 | iw, ih = image.size 21 | w, h = size 22 | if letterbox_image: 23 | scale = min(w/iw, h/ih) 24 | nw = int(iw*scale) 25 | nh = int(ih*scale) 26 | 27 | image = image.resize((nw,nh), Image.BICUBIC) 28 | new_image = Image.new('RGB', size, (128,128,128)) 29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 30 | else: 31 | new_image = image.resize((w, h), Image.BICUBIC) 32 | return new_image 33 | 34 | #---------------------------------------------------# 35 | # Read the class names (one per line); returns (names, count) 36 | #---------------------------------------------------# 37 | def get_classes(classes_path): 38 | with open(classes_path, encoding='utf-8') as f: 39 | class_names = f.readlines() 40 | class_names = [c.strip() for c in class_names] 41 | return class_names, len(class_names) 42 | 43 | #---------------------------------------------------# 44 | # Load the prior (anchor) boxes; returns (anchors, count) 45 | #---------------------------------------------------# 46 | def get_anchors(anchors_path): 47 | '''loads the anchors from a file''' 48 | with open(anchors_path, encoding='utf-8') as f: 49 | anchors = f.readline() 50 | anchors = [float(x) for x in anchors.split(',')] 51 | anchors = np.array(anchors).reshape(-1, 2) # first line's comma-separated floats as (num_anchors, 2) pairs 52 | return anchors, len(anchors) 53 | 54 | #---------------------------------------------------# 55 | # Current learning rate: that of the first param group (None if there are no groups) 56 | #---------------------------------------------------# 57 | def get_lr(optimizer): 58 | for param_group in optimizer.param_groups: 59 |
return param_group['lr'] 60 | 61 | def preprocess_input(image): # divides by 255.0 in place (mutates the argument) and returns it 62 | image /= 255.0 63 | return image 64 | 65 | def show_config(**kwargs): # print the keyword arguments as a two-column table 66 | print('Configurations:') 67 | print('-' * 70) 68 | print('|%25s | %40s|' % ('keys', 'values')) 69 | print('-' * 70) 70 | for key, value in kwargs.items(): 71 | print('|%25s | %40s|' % (str(key), str(value))) 72 | print('-' * 70) 73 | 74 | def download_weights(phi, model_dir="./model_data"): # fetch the backbone weights for phi ('l' or 'x') via torch.hub into model_dir 75 | import os 76 | from torch.hub import load_state_dict_from_url 77 | 78 | download_urls = { 79 | "l" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_backbone_weights.pth', 80 | "x" : 'https://github.com/bubbliiiing/yolov7-pytorch/releases/download/v1.0/yolov7_x_backbone_weights.pth', 81 | } 82 | url = download_urls[phi] 83 | 84 | if not os.path.exists(model_dir): 85 | os.makedirs(model_dir) 86 | load_state_dict_from_url(url, model_dir) -------------------------------------------------------------------------------- /detector/yolov7/utils/utils_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from utils.utils import get_lr 7 | 8 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0): # one training pass over gen, one validation pass over gen_val, then (rank 0) loss logging, eval_callback and checkpointing 9 | loss = 0 10 | val_loss = 0 11 | 12 | if local_rank == 0: 13 | print('Start Train') 14 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 15 | model_train.train() 16 | for iteration, batch in enumerate(gen): 17 | if iteration >= epoch_step: 18 | break 19 | 20 | images, targets = batch[0], batch[1] 21 | with torch.no_grad(): 22 | if cuda: 23 | images = images.cuda(local_rank) 24 | targets = targets.cuda(local_rank) 25 | #----------------------# 26 | # Zero the gradients 27 | #----------------------# 28 | optimizer.zero_grad() 29 | if
not fp16: 30 | #----------------------# 31 | # Forward pass 32 | #----------------------# 33 | outputs = model_train(images) 34 | loss_value = yolo_loss(outputs, targets, images) 35 | 36 | #----------------------# 37 | # Backward pass 38 | #----------------------# 39 | loss_value.backward() 40 | optimizer.step() 41 | else: 42 | from torch.cuda.amp import autocast 43 | with autocast(): 44 | #----------------------# 45 | # Forward pass (mixed precision) 46 | #----------------------# 47 | outputs = model_train(images) 48 | loss_value = yolo_loss(outputs, targets, images) 49 | 50 | #----------------------# 51 | # Backward pass through the GradScaler 52 | #----------------------# 53 | scaler.scale(loss_value).backward() 54 | scaler.step(optimizer) 55 | scaler.update() 56 | if ema: 57 | ema.update(model_train) 58 | 59 | loss += loss_value.item() 60 | 61 | if local_rank == 0: 62 | pbar.set_postfix(**{'loss' : loss / (iteration + 1), 63 | 'lr' : get_lr(optimizer)}) 64 | pbar.update(1) 65 | 66 | if local_rank == 0: 67 | pbar.close() 68 | print('Finish Train') 69 | print('Start Validation') 70 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 71 | 72 | if ema: 73 | model_train_eval = ema.ema 74 | else: 75 | model_train_eval = model_train.eval() 76 | 77 | for iteration, batch in enumerate(gen_val): 78 | if iteration >= epoch_step_val: 79 | break 80 | images, targets = batch[0], batch[1] 81 | with torch.no_grad(): 82 | if cuda: 83 | images = images.cuda(local_rank) 84 | targets = targets.cuda(local_rank) 85 | #----------------------# 86 | # Zero the gradients 87 | #----------------------# 88 | optimizer.zero_grad() 89 | #----------------------# 90 | # Forward pass 91 | #----------------------# 92 | outputs = model_train_eval(images) 93 | loss_value = yolo_loss(outputs, targets, images) 94 | 95 | val_loss += loss_value.item() 96 | if local_rank == 0: 97 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)}) 98 | pbar.update(1) 99 | 100 | if local_rank == 0: 101 | pbar.close() 102 | print('Finish Validation') 103 |
loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val) 104 | eval_callback.on_epoch_end(epoch + 1, model_train_eval) 105 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch)) 106 | print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val)) 107 | 108 | #-----------------------------------------------# 109 | # 保存权值 110 | #-----------------------------------------------# 111 | if ema: 112 | save_state_dict = ema.ema.state_dict() 113 | else: 114 | save_state_dict = model.state_dict() 115 | 116 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 117 | torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val))) 118 | 119 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): 120 | print('Save best model to best_epoch_weights.pth') 121 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth")) 122 | 123 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth")) -------------------------------------------------------------------------------- /detector/yolov7/utils_coco/coco_annotation.py: -------------------------------------------------------------------------------- 1 | #-------------------------------------------------------# 2 | # 用于处理COCO数据集,根据json文件生成txt文件用于训练 3 | #-------------------------------------------------------# 4 | import json 5 | import os 6 | from collections import defaultdict 7 | 8 | #-------------------------------------------------------# 9 | # 指向了COCO训练集与验证集图片的路径 10 | #-------------------------------------------------------# 11 | train_datasets_path = "coco_dataset/train2017" 12 | val_datasets_path = "coco_dataset/val2017" 13 | 14 | #-------------------------------------------------------# 15 | # 指向了COCO训练集与验证集标签的路径 16 | #-------------------------------------------------------# 17 | train_annotation_path = 
"coco_dataset/annotations/instances_train2017.json" 18 | val_annotation_path = "coco_dataset/annotations/instances_val2017.json" 19 | 20 | #-------------------------------------------------------# 21 | # 生成的txt文件路径 22 | #-------------------------------------------------------# 23 | train_output_path = "coco_train.txt" 24 | val_output_path = "coco_val.txt" 25 | 26 | if __name__ == "__main__": 27 | name_box_id = defaultdict(list) 28 | id_name = dict() 29 | f = open(train_annotation_path, encoding='utf-8') 30 | data = json.load(f) 31 | 32 | annotations = data['annotations'] 33 | for ant in annotations: 34 | id = ant['image_id'] 35 | name = os.path.join(train_datasets_path, '%012d.jpg' % id) 36 | cat = ant['category_id'] 37 | if cat >= 1 and cat <= 11: 38 | cat = cat - 1 39 | elif cat >= 13 and cat <= 25: 40 | cat = cat - 2 41 | elif cat >= 27 and cat <= 28: 42 | cat = cat - 3 43 | elif cat >= 31 and cat <= 44: 44 | cat = cat - 5 45 | elif cat >= 46 and cat <= 65: 46 | cat = cat - 6 47 | elif cat == 67: 48 | cat = cat - 7 49 | elif cat == 70: 50 | cat = cat - 9 51 | elif cat >= 72 and cat <= 82: 52 | cat = cat - 10 53 | elif cat >= 84 and cat <= 90: 54 | cat = cat - 11 55 | name_box_id[name].append([ant['bbox'], cat]) 56 | 57 | f = open(train_output_path, 'w') 58 | for key in name_box_id.keys(): 59 | f.write(key) 60 | box_infos = name_box_id[key] 61 | for info in box_infos: 62 | x_min = int(info[0][0]) 63 | y_min = int(info[0][1]) 64 | x_max = x_min + int(info[0][2]) 65 | y_max = y_min + int(info[0][3]) 66 | 67 | box_info = " %d,%d,%d,%d,%d" % ( 68 | x_min, y_min, x_max, y_max, int(info[1])) 69 | f.write(box_info) 70 | f.write('\n') 71 | f.close() 72 | 73 | name_box_id = defaultdict(list) 74 | id_name = dict() 75 | f = open(val_annotation_path, encoding='utf-8') 76 | data = json.load(f) 77 | 78 | annotations = data['annotations'] 79 | for ant in annotations: 80 | id = ant['image_id'] 81 | name = os.path.join(val_datasets_path, '%012d.jpg' % id) 82 | cat = 
ant['category_id'] 83 | if cat >= 1 and cat <= 11: 84 | cat = cat - 1 85 | elif cat >= 13 and cat <= 25: 86 | cat = cat - 2 87 | elif cat >= 27 and cat <= 28: 88 | cat = cat - 3 89 | elif cat >= 31 and cat <= 44: 90 | cat = cat - 5 91 | elif cat >= 46 and cat <= 65: 92 | cat = cat - 6 93 | elif cat == 67: 94 | cat = cat - 7 95 | elif cat == 70: 96 | cat = cat - 9 97 | elif cat >= 72 and cat <= 82: 98 | cat = cat - 10 99 | elif cat >= 84 and cat <= 90: 100 | cat = cat - 11 101 | name_box_id[name].append([ant['bbox'], cat]) 102 | 103 | f = open(val_output_path, 'w') 104 | for key in name_box_id.keys(): 105 | f.write(key) 106 | box_infos = name_box_id[key] 107 | for info in box_infos: 108 | x_min = int(info[0][0]) 109 | y_min = int(info[0][1]) 110 | x_max = x_min + int(info[0][2]) 111 | y_max = y_min + int(info[0][3]) 112 | 113 | box_info = " %d,%d,%d,%d,%d" % ( 114 | x_min, y_min, x_max, y_max, int(info[1])) 115 | f.write(box_info) 116 | f.write('\n') 117 | f.close() 118 | -------------------------------------------------------------------------------- /detector/yolov7/utils_coco/get_map_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from pycocotools.coco import COCO 8 | from pycocotools.cocoeval import COCOeval 9 | from tqdm import tqdm 10 | 11 | from utils.utils import cvtColor, preprocess_input, resize_image 12 | from yolo import YOLO 13 | 14 | #---------------------------------------------------------------------------# 15 | # map_mode用于指定该文件运行时计算的内容 16 | # map_mode为0代表整个map计算流程,包括获得预测结果、计算map。 17 | # map_mode为1代表仅仅获得预测结果。 18 | # map_mode为2代表仅仅获得计算map。 19 | #---------------------------------------------------------------------------# 20 | map_mode = 0 21 | #-------------------------------------------------------# 22 | # 指向了验证集标签与图片路径 23 | #-------------------------------------------------------# 24 | cocoGt_path = 
'coco_dataset/annotations/instances_val2017.json' 25 | dataset_img_path = 'coco_dataset/val2017' 26 | #-------------------------------------------------------# 27 | # 结果输出的文件夹,默认为map_out 28 | #-------------------------------------------------------# 29 | temp_save_path = 'map_out/coco_eval' 30 | 31 | class mAP_YOLO(YOLO): 32 | #---------------------------------------------------# 33 | # 检测图片 34 | #---------------------------------------------------# 35 | def detect_image(self, image_id, image, results, clsid2catid): 36 | #---------------------------------------------------# 37 | # 计算输入图片的高和宽 38 | #---------------------------------------------------# 39 | image_shape = np.array(np.shape(image)[0:2]) 40 | #---------------------------------------------------------# 41 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 42 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 43 | #---------------------------------------------------------# 44 | image = cvtColor(image) 45 | #---------------------------------------------------------# 46 | # 给图像增加灰条,实现不失真的resize 47 | # 也可以直接resize进行识别 48 | #---------------------------------------------------------# 49 | image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) 50 | #---------------------------------------------------------# 51 | # 添加上batch_size维度 52 | #---------------------------------------------------------# 53 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0) 54 | 55 | with torch.no_grad(): 56 | images = torch.from_numpy(image_data) 57 | if self.cuda: 58 | images = images.cuda() 59 | #---------------------------------------------------------# 60 | # 将图像输入网络当中进行预测! 
61 | #---------------------------------------------------------# 62 | outputs = self.net(images) 63 | outputs = self.bbox_util.decode_box(outputs) 64 | #---------------------------------------------------------# 65 | # 将预测框进行堆叠,然后进行非极大抑制 66 | #---------------------------------------------------------# 67 | outputs = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape, 68 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou) 69 | 70 | if outputs[0] is None: 71 | return results 72 | 73 | top_label = np.array(outputs[0][:, 6], dtype = 'int32') 74 | top_conf = outputs[0][:, 4] * outputs[0][:, 5] 75 | top_boxes = outputs[0][:, :4] 76 | 77 | for i, c in enumerate(top_label): 78 | result = {} 79 | top, left, bottom, right = top_boxes[i] 80 | 81 | result["image_id"] = int(image_id) 82 | result["category_id"] = clsid2catid[c] 83 | result["bbox"] = [float(left),float(top),float(right-left),float(bottom-top)] 84 | result["score"] = float(top_conf[i]) 85 | results.append(result) 86 | return results 87 | 88 | if __name__ == "__main__": 89 | if not os.path.exists(temp_save_path): 90 | os.makedirs(temp_save_path) 91 | 92 | cocoGt = COCO(cocoGt_path) 93 | ids = list(cocoGt.imgToAnns.keys()) 94 | clsid2catid = cocoGt.getCatIds() 95 | 96 | if map_mode == 0 or map_mode == 1: 97 | yolo = mAP_YOLO(confidence = 0.001, nms_iou = 0.65) 98 | 99 | with open(os.path.join(temp_save_path, 'eval_results.json'),"w") as f: 100 | results = [] 101 | for image_id in tqdm(ids): 102 | image_path = os.path.join(dataset_img_path, cocoGt.loadImgs(image_id)[0]['file_name']) 103 | image = Image.open(image_path) 104 | results = yolo.detect_image(image_id, image, results, clsid2catid) 105 | json.dump(results, f) 106 | 107 | if map_mode == 0 or map_mode == 2: 108 | cocoDt = cocoGt.loadRes(os.path.join(temp_save_path, 'eval_results.json')) 109 | cocoEval = COCOeval(cocoGt, cocoDt, 'bbox') 110 | cocoEval.evaluate() 111 | 
cocoEval.accumulate() 112 | cocoEval.summarize() 113 | print("Get map done.") 114 | -------------------------------------------------------------------------------- /detector/yolov7/voc_annotation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import xml.etree.ElementTree as ET 4 | 5 | import numpy as np 6 | 7 | from utils.utils import get_classes 8 | 9 | #--------------------------------------------------------------------------------------------------------------------------------# 10 | # annotation_mode用于指定该文件运行时计算的内容 11 | # annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt 12 | # annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt 13 | # annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt 14 | #--------------------------------------------------------------------------------------------------------------------------------# 15 | annotation_mode = 0 16 | #-------------------------------------------------------------------# 17 | # 必须要修改,用于生成2007_train.txt、2007_val.txt的目标信息 18 | # 与训练和预测所用的classes_path一致即可 19 | # 如果生成的2007_train.txt里面没有目标信息 20 | # 那么就是因为classes没有设定正确 21 | # 仅在annotation_mode为0和2的时候有效 22 | #-------------------------------------------------------------------# 23 | classes_path = 'model_data/voc_classes.txt' 24 | #--------------------------------------------------------------------------------------------------------------------------------# 25 | # trainval_percent用于指定(训练集+验证集)与测试集的比例,默认情况下 (训练集+验证集):测试集 = 9:1 26 | # train_percent用于指定(训练集+验证集)中训练集与验证集的比例,默认情况下 训练集:验证集 = 9:1 27 | # 仅在annotation_mode为0和1的时候有效 28 | #--------------------------------------------------------------------------------------------------------------------------------# 29 | trainval_percent = 0.9 30 | train_percent = 0.9 31 | #-------------------------------------------------------# 32 | # 指向VOC数据集所在的文件夹 33 | # 默认指向根目录下的VOC数据集 34 | 
#-------------------------------------------------------# 35 | VOCdevkit_path = 'VOCdevkit' 36 | 37 | VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')] 38 | classes, _ = get_classes(classes_path) 39 | 40 | #-------------------------------------------------------# 41 | # 统计目标数量 42 | #-------------------------------------------------------# 43 | photo_nums = np.zeros(len(VOCdevkit_sets)) 44 | nums = np.zeros(len(classes)) 45 | def convert_annotation(year, image_id, list_file): 46 | in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8') 47 | tree=ET.parse(in_file) 48 | root = tree.getroot() 49 | 50 | for obj in root.iter('object'): 51 | difficult = 0 52 | if obj.find('difficult')!=None: 53 | difficult = obj.find('difficult').text 54 | cls = obj.find('name').text 55 | if cls not in classes or int(difficult)==1: 56 | continue 57 | cls_id = classes.index(cls) 58 | xmlbox = obj.find('bndbox') 59 | b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text))) 60 | list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id)) 61 | 62 | nums[classes.index(cls)] = nums[classes.index(cls)] + 1 63 | 64 | if __name__ == "__main__": 65 | random.seed(0) 66 | if " " in os.path.abspath(VOCdevkit_path): 67 | raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。") 68 | 69 | if annotation_mode == 0 or annotation_mode == 1: 70 | print("Generate txt in ImageSets.") 71 | xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations') 72 | saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main') 73 | temp_xml = os.listdir(xmlfilepath) 74 | total_xml = [] 75 | for xml in temp_xml: 76 | if xml.endswith(".xml"): 77 | total_xml.append(xml) 78 | 79 | num = len(total_xml) 80 | list = range(num) 81 | tv = int(num*trainval_percent) 82 | tr = int(tv*train_percent) 83 | trainval= random.sample(list,tv) 84 | train = 
random.sample(trainval,tr) 85 | 86 | print("train and val size",tv) 87 | print("train size",tr) 88 | ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w') 89 | ftest = open(os.path.join(saveBasePath,'test.txt'), 'w') 90 | ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w') 91 | fval = open(os.path.join(saveBasePath,'val.txt'), 'w') 92 | 93 | for i in list: 94 | name=total_xml[i][:-4]+'\n' 95 | if i in trainval: 96 | ftrainval.write(name) 97 | if i in train: 98 | ftrain.write(name) 99 | else: 100 | fval.write(name) 101 | else: 102 | ftest.write(name) 103 | 104 | ftrainval.close() 105 | ftrain.close() 106 | fval.close() 107 | ftest.close() 108 | print("Generate txt in ImageSets done.") 109 | 110 | if annotation_mode == 0 or annotation_mode == 2: 111 | print("Generate 2007_train.txt and 2007_val.txt for train.") 112 | type_index = 0 113 | for year, image_set in VOCdevkit_sets: 114 | image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split() 115 | list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8') 116 | for image_id in image_ids: 117 | list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id)) 118 | 119 | convert_annotation(year, image_id, list_file) 120 | list_file.write('\n') 121 | photo_nums[type_index] = len(image_ids) 122 | type_index += 1 123 | list_file.close() 124 | print("Generate 2007_train.txt and 2007_val.txt for train done.") 125 | 126 | def printTable(List1, List2): 127 | for i in range(len(List1[0])): 128 | print("|", end=' ') 129 | for j in range(len(List1)): 130 | print(List1[j][i].rjust(int(List2[j])), end=' ') 131 | print("|", end=' ') 132 | print() 133 | 134 | str_nums = [str(int(x)) for x in nums] 135 | tableData = [ 136 | classes, str_nums 137 | ] 138 | colWidths = [0]*len(tableData) 139 | len1 = 0 140 | for i in range(len(tableData)): 141 | for j in range(len(tableData[i])): 142 | if 
len(tableData[i][j]) > colWidths[i]: 143 | colWidths[i] = len(tableData[i][j]) 144 | printTable(tableData, colWidths) 145 | 146 | if photo_nums[0] <= 500: 147 | print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。") 148 | 149 | if np.sum(nums) == 0: 150 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") 151 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") 152 | print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!") 153 | print("(重要的事情说三遍)。") 154 | -------------------------------------------------------------------------------- /img/1: -------------------------------------------------------------------------------- 1 | 1 2 | -------------------------------------------------------------------------------- /img/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/img/result.png -------------------------------------------------------------------------------- /output/results 00_00_00-00_00_30~1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/output/results 00_00_00-00_00_30~1.gif -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/draw.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/draw.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/draw.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/draw.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/draw.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/draw.cpython-38.pyc -------------------------------------------------------------------------------- 
/utils/__pycache__/io.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/io.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/io.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/io.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/io.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/io.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/log.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/log.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/log.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/log.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/log.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/log.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/parser.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/parser.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/parser.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/parser.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/parser.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deyiwang89/pytorch-yolov7-deepsort/ba43f1f40c5a9ff54f83906ddef294ed6b80a0e7/utils/__pycache__/parser.cpython-38.pyc -------------------------------------------------------------------------------- /utils/asserts.py: -------------------------------------------------------------------------------- 1 | from os import environ 2 | 3 | 4 | def assert_in(file, files_to_check): 5 | if file not in files_to_check: 6 | raise AssertionError("{} does not exist in the list".format(str(file))) 7 | return True 8 | 9 | 10 | def assert_in_env(check_list: list): 11 | for item in check_list: 12 | assert_in(item, environ.keys()) 13 | return True 14 | -------------------------------------------------------------------------------- /utils/draw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from collections import deque 4 | 5 | palette = 
(2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1) 6 | 7 | 8 | def compute_color_for_labels(label): 9 | """ 10 | Simple function that adds fixed color depending on the class 11 | """ 12 | color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette] 13 | return tuple(color) 14 | 15 | pts = [deque(maxlen = 30) for _ in range(9999)] 16 | def draw_boxes(img, bbox, identities=None, offset=(0,0)): 17 | for i,box in enumerate(bbox): 18 | x1,y1,x2,y2 = [int(i) for i in box] 19 | x1 += offset[0] 20 | x2 += offset[0] 21 | y1 += offset[1] 22 | y2 += offset[1] 23 | # box text and bar 24 | id = int(identities[i]) if identities is not None else 0 25 | color = compute_color_for_labels(id) 26 | label = '{}{:d}'.format("", id) 27 | t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2 , 2)[0] 28 | cv2.rectangle(img,(x1, y1),(x2,y2),color,3) 29 | cv2.rectangle(img,(x1, y1),(x1+t_size[0]+3,y1+t_size[1]+4), color,-1) 30 | cv2.putText(img,label,(x1,y1+t_size[1]+4), cv2.FONT_HERSHEY_PLAIN, 2, [255,255,255], 2) 31 | #-----lines 32 | center = ((round((x1+x2)/2),round((y1+y2)/2))) 33 | pts[id].append(center) 34 | cv2.circle(img,(round((x1+x2)/2),round((y1+y2)/2)),1,color,5) 35 | for j in range(1, len(pts[id])): 36 | if pts[id][j-1] is None or pts[id][j] is None: 37 | continue 38 | thickness = int(np.sqrt(64/float(j+1)) * 2) 39 | cv2.line(img,(pts[id][j-1]),(pts[id][j]), color, thickness) 40 | return img 41 | 42 | 43 | 44 | if __name__ == '__main__': 45 | for i in range(82): 46 | print(compute_color_for_labels(i)) 47 | -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import copy 4 | import motmetrics as mm 5 | mm.lap.default_solver = 'lap' 6 | from utils.io import read_results, unzip_objs 7 | 8 | 9 | class Evaluator(object): 10 | 11 | def __init__(self, data_root, seq_name, data_type): 12 | self.data_root = data_root 
13 | self.seq_name = seq_name 14 | self.data_type = data_type 15 | 16 | self.load_annotations() 17 | self.reset_accumulator() 18 | 19 | def load_annotations(self): 20 | assert self.data_type == 'mot' 21 | 22 | gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt') 23 | self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True) 24 | self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True) 25 | 26 | def reset_accumulator(self): 27 | self.acc = mm.MOTAccumulator(auto_id=True) 28 | 29 | def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False): 30 | # results 31 | trk_tlwhs = np.copy(trk_tlwhs) 32 | trk_ids = np.copy(trk_ids) 33 | 34 | # gts 35 | gt_objs = self.gt_frame_dict.get(frame_id, []) 36 | gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2] 37 | 38 | # ignore boxes 39 | ignore_objs = self.gt_ignore_frame_dict.get(frame_id, []) 40 | ignore_tlwhs = unzip_objs(ignore_objs)[0] 41 | 42 | 43 | # remove ignored results 44 | keep = np.ones(len(trk_tlwhs), dtype=bool) 45 | iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5) 46 | if len(iou_distance) > 0: 47 | match_is, match_js = mm.lap.linear_sum_assignment(iou_distance) 48 | match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js]) 49 | match_ious = iou_distance[match_is, match_js] 50 | 51 | match_js = np.asarray(match_js, dtype=int) 52 | match_js = match_js[np.logical_not(np.isnan(match_ious))] 53 | keep[match_js] = False 54 | trk_tlwhs = trk_tlwhs[keep] 55 | trk_ids = trk_ids[keep] 56 | 57 | # get distance matrix 58 | iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5) 59 | 60 | # acc 61 | self.acc.update(gt_ids, trk_ids, iou_distance) 62 | 63 | if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'): 64 | events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics 65 | else: 66 | events = None 67 | return events 68 | 69 | 
def eval_file(self, filename): 70 | self.reset_accumulator() 71 | 72 | result_frame_dict = read_results(filename, self.data_type, is_gt=False) 73 | frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys()))) 74 | for frame_id in frames: 75 | trk_objs = result_frame_dict.get(frame_id, []) 76 | trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2] 77 | self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False) 78 | 79 | return self.acc 80 | 81 | @staticmethod 82 | def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')): 83 | names = copy.deepcopy(names) 84 | if metrics is None: 85 | metrics = mm.metrics.motchallenge_metrics 86 | metrics = copy.deepcopy(metrics) 87 | 88 | mh = mm.metrics.create() 89 | summary = mh.compute_many( 90 | accs, 91 | metrics=metrics, 92 | names=names, 93 | generate_overall=True 94 | ) 95 | 96 | return summary 97 | 98 | @staticmethod 99 | def save_summary(summary, filename): 100 | import pandas as pd 101 | writer = pd.ExcelWriter(filename) 102 | summary.to_excel(writer) 103 | writer.save() 104 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | import numpy as np 4 | 5 | # from utils.log import get_logger 6 | 7 | 8 | def write_results(filename, results, data_type): 9 | if data_type == 'mot': 10 | save_format = '{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n' 11 | elif data_type == 'kitti': 12 | save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n' 13 | else: 14 | raise ValueError(data_type) 15 | 16 | with open(filename, 'w') as f: 17 | for frame_id, tlwhs, track_ids in results: 18 | if data_type == 'kitti': 19 | frame_id -= 1 20 | for tlwh, track_id in zip(tlwhs, track_ids): 21 | if track_id < 0: 22 | continue 23 | x1, y1, w, h = tlwh 24 | x2, y2 = x1 + 
w, y1 + h 25 | line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h) 26 | f.write(line) 27 | 28 | 29 | # def write_results(filename, results_dict: Dict, data_type: str): 30 | # if not filename: 31 | # return 32 | # path = os.path.dirname(filename) 33 | # if not os.path.exists(path): 34 | # os.makedirs(path) 35 | 36 | # if data_type in ('mot', 'mcmot', 'lab'): 37 | # save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n' 38 | # elif data_type == 'kitti': 39 | # save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n' 40 | # else: 41 | # raise ValueError(data_type) 42 | 43 | # with open(filename, 'w') as f: 44 | # for frame_id, frame_data in results_dict.items(): 45 | # if data_type == 'kitti': 46 | # frame_id -= 1 47 | # for tlwh, track_id in frame_data: 48 | # if track_id < 0: 49 | # continue 50 | # x1, y1, w, h = tlwh 51 | # x2, y2 = x1 + w, y1 + h 52 | # line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0) 53 | # f.write(line) 54 | # logger.info('Save results to {}'.format(filename)) 55 | 56 | 57 | def read_results(filename, data_type: str, is_gt=False, is_ignore=False): 58 | if data_type in ('mot', 'lab'): 59 | read_fun = read_mot_results 60 | else: 61 | raise ValueError('Unknown data type: {}'.format(data_type)) 62 | 63 | return read_fun(filename, is_gt, is_ignore) 64 | 65 | 66 | """ 67 | labels={'ped', ... % 1 68 | 'person_on_vhcl', ... % 2 69 | 'car', ... % 3 70 | 'bicycle', ... % 4 71 | 'mbike', ... % 5 72 | 'non_mot_vhcl', ... % 6 73 | 'static_person', ... % 7 74 | 'distractor', ... % 8 75 | 'occluder', ... % 9 76 | 'occluder_on_grnd', ... %10 77 | 'occluder_full', ... % 11 78 | 'reflection', ... % 12 79 | 'crowd' ... 
% 13 80 | }; 81 | """ 82 | 83 | 84 | def read_mot_results(filename, is_gt, is_ignore): 85 | valid_labels = {1} 86 | ignore_labels = {2, 7, 8, 12} 87 | results_dict = dict() 88 | if os.path.isfile(filename): 89 | with open(filename, 'r') as f: 90 | for line in f.readlines(): 91 | linelist = line.split(',') 92 | if len(linelist) < 7: 93 | continue 94 | fid = int(linelist[0]) 95 | if fid < 1: 96 | continue 97 | results_dict.setdefault(fid, list()) 98 | 99 | if is_gt: 100 | if 'MOT16-' in filename or 'MOT17-' in filename: 101 | label = int(float(linelist[7])) 102 | mark = int(float(linelist[6])) 103 | if mark == 0 or label not in valid_labels: 104 | continue 105 | score = 1 106 | elif is_ignore: 107 | if 'MOT16-' in filename or 'MOT17-' in filename: 108 | label = int(float(linelist[7])) 109 | vis_ratio = float(linelist[8]) 110 | if label not in ignore_labels and vis_ratio >= 0: 111 | continue 112 | else: 113 | continue 114 | score = 1 115 | else: 116 | score = float(linelist[6]) 117 | 118 | tlwh = tuple(map(float, linelist[2:6])) 119 | target_id = int(linelist[1]) 120 | 121 | results_dict[fid].append((tlwh, target_id, score)) 122 | 123 | return results_dict 124 | 125 | 126 | def unzip_objs(objs): 127 | if len(objs) > 0: 128 | tlwhs, ids, scores = zip(*objs) 129 | else: 130 | tlwhs, ids, scores = [], [], [] 131 | tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4) 132 | 133 | return tlwhs, ids, scores -------------------------------------------------------------------------------- /utils/json_logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | References: 3 | https://medium.com/analytics-vidhya/creating-a-custom-logging-mechanism-for-real-time-object-detection-using-tdd-4ca2cfcd0a2f 4 | """ 5 | import json 6 | from os import makedirs 7 | from os.path import exists, join 8 | from datetime import datetime 9 | 10 | 11 | class JsonMeta(object): 12 | HOURS = 3 13 | MINUTES = 59 14 | SECONDS = 59 15 | PATH_TO_SAVE = 'LOGS' 
16 | DEFAULT_FILE_NAME = 'remaining' 17 | 18 | 19 | class BaseJsonLogger(object): 20 | """ 21 | This is the base class that returns __dict__ of its own 22 | it also returns the dicts of objects in the attributes that are list instances 23 | 24 | """ 25 | 26 | def dic(self): 27 | # returns dicts of objects 28 | out = {} 29 | for k, v in self.__dict__.items(): 30 | if hasattr(v, 'dic'): 31 | out[k] = v.dic() 32 | elif isinstance(v, list): 33 | out[k] = self.list(v) 34 | else: 35 | out[k] = v 36 | return out 37 | 38 | @staticmethod 39 | def list(values): 40 | # applies the dic method on items in the list 41 | return [v.dic() if hasattr(v, 'dic') else v for v in values] 42 | 43 | 44 | class Label(BaseJsonLogger): 45 | """ 46 | For each bounding box there are various categories with confidences. Label class keeps track of that information. 47 | """ 48 | 49 | def __init__(self, category: str, confidence: float): 50 | self.category = category 51 | self.confidence = confidence 52 | 53 | 54 | class Bbox(BaseJsonLogger): 55 | """ 56 | This module stores the information for each frame and use them in JsonParser 57 | Attributes: 58 | labels (list): List of label module. 59 | top (int): 60 | left (int): 61 | width (int): 62 | height (int): 63 | 64 | Args: 65 | bbox_id (float): 66 | top (int): 67 | left (int): 68 | width (int): 69 | height (int): 70 | 71 | References: 72 | Check Label module for better understanding. 73 | 74 | 75 | """ 76 | 77 | def __init__(self, bbox_id, top, left, width, height): 78 | self.labels = [] 79 | self.bbox_id = bbox_id 80 | self.top = top 81 | self.left = left 82 | self.width = width 83 | self.height = height 84 | 85 | def add_label(self, category, confidence): 86 | # adds category and confidence only if top_k is not exceeded. 
87 | self.labels.append(Label(category, confidence)) 88 | 89 | def labels_full(self, value): 90 | return len(self.labels) == value 91 | 92 | 93 | class Frame(BaseJsonLogger): 94 | """ 95 | This module stores the information for each frame and use them in JsonParser 96 | Attributes: 97 | timestamp (float): The elapsed time of captured frame 98 | frame_id (int): The frame number of the captured video 99 | bboxes (list of Bbox objects): Stores the list of bbox objects. 100 | 101 | References: 102 | Check Bbox class for better information 103 | 104 | Args: 105 | timestamp (float): 106 | frame_id (int): 107 | 108 | """ 109 | 110 | def __init__(self, frame_id: int, timestamp: float = None): 111 | self.frame_id = frame_id 112 | self.timestamp = timestamp 113 | self.bboxes = [] 114 | 115 | def add_bbox(self, bbox_id: int, top: int, left: int, width: int, height: int): 116 | bboxes_ids = [bbox.bbox_id for bbox in self.bboxes] 117 | if bbox_id not in bboxes_ids: 118 | self.bboxes.append(Bbox(bbox_id, top, left, width, height)) 119 | else: 120 | raise ValueError("Frame with id: {} already has a Bbox with id: {}".format(self.frame_id, bbox_id)) 121 | 122 | def add_label_to_bbox(self, bbox_id: int, category: str, confidence: float): 123 | bboxes = {bbox.id: bbox for bbox in self.bboxes} 124 | if bbox_id in bboxes.keys(): 125 | res = bboxes.get(bbox_id) 126 | res.add_label(category, confidence) 127 | else: 128 | raise ValueError('the bbox with id: {} does not exists!'.format(bbox_id)) 129 | 130 | 131 | class BboxToJsonLogger(BaseJsonLogger): 132 | """ 133 | ُ This module is designed to automate the task of logging jsons. 
An example json is used 134 | to show the contents of json file shortly 135 | Example: 136 | { 137 | "video_details": { 138 | "frame_width": 1920, 139 | "frame_height": 1080, 140 | "frame_rate": 20, 141 | "video_name": "/home/gpu/codes/MSD/pedestrian_2/project/public/camera1.avi" 142 | }, 143 | "frames": [ 144 | { 145 | "frame_id": 329, 146 | "timestamp": 3365.1254 147 | "bboxes": [ 148 | { 149 | "labels": [ 150 | { 151 | "category": "pedestrian", 152 | "confidence": 0.9 153 | } 154 | ], 155 | "bbox_id": 0, 156 | "top": 1257, 157 | "left": 138, 158 | "width": 68, 159 | "height": 109 160 | } 161 | ] 162 | }], 163 | 164 | Attributes: 165 | frames (dict): It's a dictionary that maps each frame_id to json attributes. 166 | video_details (dict): information about video file. 167 | top_k_labels (int): shows the allowed number of labels 168 | start_time (datetime object): we use it to automate the json output by time. 169 | 170 | Args: 171 | top_k_labels (int): shows the allowed number of labels 172 | 173 | """ 174 | 175 | def __init__(self, top_k_labels: int = 1): 176 | self.frames = {} 177 | self.video_details = self.video_details = dict(frame_width=None, frame_height=None, frame_rate=None, 178 | video_name=None) 179 | self.top_k_labels = top_k_labels 180 | self.start_time = datetime.now() 181 | 182 | def set_top_k(self, value): 183 | self.top_k_labels = value 184 | 185 | def frame_exists(self, frame_id: int) -> bool: 186 | """ 187 | Args: 188 | frame_id (int): 189 | 190 | Returns: 191 | bool: true if frame_id is recognized 192 | """ 193 | return frame_id in self.frames.keys() 194 | 195 | def add_frame(self, frame_id: int, timestamp: float = None) -> None: 196 | """ 197 | Args: 198 | frame_id (int): 199 | timestamp (float): opencv captured frame time property 200 | 201 | Raises: 202 | ValueError: if frame_id would not exist in class frames attribute 203 | 204 | Returns: 205 | None 206 | 207 | """ 208 | if not self.frame_exists(frame_id): 209 | self.frames[frame_id] = 
Frame(frame_id, timestamp) 210 | else: 211 | raise ValueError("Frame id: {} already exists".format(frame_id)) 212 | 213 | def bbox_exists(self, frame_id: int, bbox_id: int) -> bool: 214 | """ 215 | Args: 216 | frame_id: 217 | bbox_id: 218 | 219 | Returns: 220 | bool: if bbox exists in frame bboxes list 221 | """ 222 | bboxes = [] 223 | if self.frame_exists(frame_id=frame_id): 224 | bboxes = [bbox.bbox_id for bbox in self.frames[frame_id].bboxes] 225 | return bbox_id in bboxes 226 | 227 | def find_bbox(self, frame_id: int, bbox_id: int): 228 | """ 229 | 230 | Args: 231 | frame_id: 232 | bbox_id: 233 | 234 | Returns: 235 | bbox_id (int): 236 | 237 | Raises: 238 | ValueError: if bbox_id does not exist in the bbox list of specific frame. 239 | """ 240 | if not self.bbox_exists(frame_id, bbox_id): 241 | raise ValueError("frame with id: {} does not contain bbox with id: {}".format(frame_id, bbox_id)) 242 | bboxes = {bbox.bbox_id: bbox for bbox in self.frames[frame_id].bboxes} 243 | return bboxes.get(bbox_id) 244 | 245 | def add_bbox_to_frame(self, frame_id: int, bbox_id: int, top: int, left: int, width: int, height: int) -> None: 246 | """ 247 | 248 | Args: 249 | frame_id (int): 250 | bbox_id (int): 251 | top (int): 252 | left (int): 253 | width (int): 254 | height (int): 255 | 256 | Returns: 257 | None 258 | 259 | Raises: 260 | ValueError: if bbox_id already exist in frame information with frame_id 261 | ValueError: if frame_id does not exist in frames attribute 262 | """ 263 | if self.frame_exists(frame_id): 264 | frame = self.frames[frame_id] 265 | if not self.bbox_exists(frame_id, bbox_id): 266 | frame.add_bbox(bbox_id, top, left, width, height) 267 | else: 268 | raise ValueError( 269 | "frame with frame_id: {} already contains the bbox with id: {} ".format(frame_id, bbox_id)) 270 | else: 271 | raise ValueError("frame with frame_id: {} does not exist".format(frame_id)) 272 | 273 | def add_label_to_bbox(self, frame_id: int, bbox_id: int, category: str, confidence: 
float): 274 | """ 275 | Args: 276 | frame_id: 277 | bbox_id: 278 | category: 279 | confidence: the confidence value returned from yolo detection 280 | 281 | Returns: 282 | None 283 | 284 | Raises: 285 | ValueError: if labels quota (top_k_labels) exceeds. 286 | """ 287 | bbox = self.find_bbox(frame_id, bbox_id) 288 | if not bbox.labels_full(self.top_k_labels): 289 | bbox.add_label(category, confidence) 290 | else: 291 | raise ValueError("labels in frame_id: {}, bbox_id: {} is fulled".format(frame_id, bbox_id)) 292 | 293 | def add_video_details(self, frame_width: int = None, frame_height: int = None, frame_rate: int = None, 294 | video_name: str = None): 295 | self.video_details['frame_width'] = frame_width 296 | self.video_details['frame_height'] = frame_height 297 | self.video_details['frame_rate'] = frame_rate 298 | self.video_details['video_name'] = video_name 299 | 300 | def output(self): 301 | output = {'video_details': self.video_details} 302 | result = list(self.frames.values()) 303 | output['frames'] = [item.dic() for item in result] 304 | return output 305 | 306 | def json_output(self, output_name): 307 | """ 308 | Args: 309 | output_name: 310 | 311 | Returns: 312 | None 313 | 314 | Notes: 315 | It creates the json output with `output_name` name. 316 | """ 317 | if not output_name.endswith('.json'): 318 | output_name += '.json' 319 | with open(output_name, 'w') as file: 320 | json.dump(self.output(), file) 321 | file.close() 322 | 323 | def set_start(self): 324 | self.start_time = datetime.now() 325 | 326 | def schedule_output_by_time(self, output_dir=JsonMeta.PATH_TO_SAVE, hours: int = 0, minutes: int = 0, 327 | seconds: int = 60) -> None: 328 | """ 329 | Notes: 330 | Creates folder and then periodically stores the jsons on that address. 
331 | 332 | Args: 333 | output_dir (str): the directory where output files will be stored 334 | hours (int): 335 | minutes (int): 336 | seconds (int): 337 | 338 | Returns: 339 | None 340 | 341 | """ 342 | end = datetime.now() 343 | interval = 0 344 | interval += abs(min([hours, JsonMeta.HOURS]) * 3600) 345 | interval += abs(min([minutes, JsonMeta.MINUTES]) * 60) 346 | interval += abs(min([seconds, JsonMeta.SECONDS])) 347 | diff = (end - self.start_time).seconds 348 | 349 | if diff > interval: 350 | output_name = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '.json' 351 | if not exists(output_dir): 352 | makedirs(output_dir) 353 | output = join(output_dir, output_name) 354 | self.json_output(output_name=output) 355 | self.frames = {} 356 | self.start_time = datetime.now() 357 | 358 | def schedule_output_by_frames(self, frames_quota, frame_counter, output_dir=JsonMeta.PATH_TO_SAVE): 359 | """ 360 | saves as the number of frames quota increases higher. 361 | :param frames_quota: 362 | :param frame_counter: 363 | :param output_dir: 364 | :return: 365 | """ 366 | pass 367 | 368 | def flush(self, output_dir): 369 | """ 370 | Notes: 371 | We use this function to output jsons whenever possible. 372 | like the time that we exit the while loop of opencv. 
373 | 374 | Args: 375 | output_dir: 376 | 377 | Returns: 378 | None 379 | 380 | """ 381 | filename = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '-remaining.json' 382 | output = join(output_dir, filename) 383 | self.json_output(output_name=output) 384 | -------------------------------------------------------------------------------- /utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def get_logger(name='root'): 5 | formatter = logging.Formatter( 6 | # fmt='%(asctime)s [%(levelname)s]: %(filename)s(%(funcName)s:%(lineno)s) >> %(message)s') 7 | fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 8 | 9 | handler = logging.StreamHandler() 10 | handler.setFormatter(formatter) 11 | 12 | logger = logging.getLogger(name) 13 | logger.setLevel(logging.INFO) 14 | logger.addHandler(handler) 15 | return logger 16 | 17 | 18 | -------------------------------------------------------------------------------- /utils/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from easydict import EasyDict as edict 4 | 5 | class YamlParser(edict): 6 | """ 7 | This is yaml parser based on EasyDict. 
8 | """ 9 | def __init__(self, cfg_dict=None, config_file=None): 10 | if cfg_dict is None: 11 | cfg_dict = {} 12 | 13 | if config_file is not None: 14 | assert(os.path.isfile(config_file)) 15 | with open(config_file, 'r') as fo: 16 | cfg_dict.update(yaml.load(fo.read())) 17 | 18 | super(YamlParser, self).__init__(cfg_dict) 19 | 20 | 21 | def merge_from_file(self, config_file): 22 | with open(config_file, 'r') as fo: 23 | self.update(yaml.load(fo.read())) 24 | 25 | 26 | def merge_from_dict(self, config_dict): 27 | self.update(config_dict) 28 | 29 | 30 | def get_config(config_file=None): 31 | return YamlParser(config_file=config_file) 32 | 33 | 34 | if __name__ == "__main__": 35 | cfg = YamlParser(config_file="../configs/yolov3.yaml") 36 | cfg.merge_from_file("../configs/deep_sort.yaml") 37 | 38 | import ipdb; ipdb.set_trace() -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from time import time 3 | 4 | 5 | def is_video(ext: str): 6 | """ 7 | Returns true if ext exists in 8 | allowed_exts for video files. 9 | 10 | Args: 11 | ext: 12 | 13 | Returns: 14 | 15 | """ 16 | 17 | allowed_exts = ('.mp4', '.webm', '.ogg', '.avi', '.wmv', '.mkv', '.3gp') 18 | return any((ext.endswith(x) for x in allowed_exts)) 19 | 20 | 21 | def tik_tok(func): 22 | """ 23 | keep track of time for each process. 
24 | Args: 25 | func: 26 | 27 | Returns: 28 | 29 | """ 30 | @wraps(func) 31 | def _time_it(*args, **kwargs): 32 | start = time() 33 | try: 34 | return func(*args, **kwargs) 35 | finally: 36 | end_ = time() 37 | print("time: {:.03f}s, fps: {:.03f}".format(end_ - start, 1 / (end_ - start))) 38 | 39 | return _time_it 40 | -------------------------------------------------------------------------------- /yolov7_deepsort.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import argparse 5 | import torch 6 | import warnings 7 | import numpy as np 8 | 9 | from detector import build_detector 10 | from deep_sort import build_tracker 11 | from utils.draw import draw_boxes 12 | from utils.parser import get_config 13 | from utils.log import get_logger 14 | from utils.io import write_results 15 | 16 | from PIL import Image 17 | 18 | 19 | class VideoTracker(object): 20 | def __init__(self, args, video_path): 21 | self.args = args 22 | self.video_path = video_path 23 | self.logger = get_logger("root") 24 | 25 | use_cuda = args.use_cuda and torch.cuda.is_available() 26 | if not use_cuda: 27 | warnings.warn("Running in cpu mode which maybe very slow!", UserWarning) 28 | 29 | if args.display: 30 | cv2.namedWindow("test", cv2.WINDOW_NORMAL) 31 | cv2.resizeWindow("test", args.display_width, args.display_height) 32 | 33 | if args.cam != -1: 34 | print("Using webcam " + str(args.cam)) 35 | self.vdo = cv2.VideoCapture(args.cam) 36 | else: 37 | self.vdo = cv2.VideoCapture() 38 | self.detector = build_detector() 39 | self.deepsort = build_tracker(use_cuda=True) 40 | self.class_names = self.detector.class_names 41 | 42 | def __enter__(self): 43 | if self.args.cam != -1: 44 | ret, frame = self.vdo.read() 45 | assert ret, "Error: Camera error" 46 | self.im_width = frame.shape[0] 47 | self.im_height = frame.shape[1] 48 | 49 | else: 50 | assert os.path.isfile(self.video_path), "Path error" 51 | 
self.vdo.open(self.video_path) 52 | self.im_width = int(self.vdo.get(cv2.CAP_PROP_FRAME_WIDTH)) 53 | self.im_height = int(self.vdo.get(cv2.CAP_PROP_FRAME_HEIGHT)) 54 | assert self.vdo.isOpened() 55 | 56 | if self.args.save_path: 57 | os.makedirs(self.args.save_path, exist_ok=True) 58 | 59 | # path of saved video and results 60 | self.save_video_path = os.path.join(self.args.save_path, "results.avi") 61 | self.save_results_path = os.path.join(self.args.save_path, "results.txt") 62 | 63 | # create video writer 64 | fourcc = cv2.VideoWriter_fourcc(*'MJPG') 65 | self.writer = cv2.VideoWriter(self.save_video_path, fourcc, 20, (self.im_width, self.im_height)) 66 | 67 | # logging 68 | self.logger.info("Save results to {}".format(self.args.save_path)) 69 | 70 | return self 71 | 72 | def __exit__(self, exc_type, exc_value, exc_traceback): 73 | if exc_type: 74 | print(exc_type, exc_value, exc_traceback) 75 | 76 | 77 | def run(self): 78 | results = [] 79 | idx_frame = 0 80 | while self.vdo.grab(): 81 | idx_frame += 1 82 | if idx_frame % self.args.frame_interval: 83 | continue 84 | 85 | start = time.time() 86 | ref, ori_im = self.vdo.retrieve() 87 | 88 | if ref is True: 89 | im = cv2.cvtColor(ori_im, cv2.COLOR_BGR2RGB) 90 | #----- do detection 91 | frame = Image.fromarray(np.uint8(im)) 92 | bbox_xywh, cls_conf, cls_ids = self.detector.new_detect(frame) 93 | if cls_conf is not None: 94 | #-----copy 95 | list_fin = [] 96 | for i in bbox_xywh: 97 | temp = [] 98 | temp.append(i[0]) 99 | temp.append(i[1]) 100 | temp.append(i[2]*1.) 101 | temp.append(i[3]*1.) 
102 | list_fin.append(temp) 103 | new_bbox = np.array(list_fin).astype(np.float32) 104 | 105 | #-----#-----mask processing filter the useless part 106 | mask = [0,1,2,3,5,7]# keep specific classes the indexes are corresponded to coco_classes 107 | mask_filter = [] 108 | for i in cls_ids: 109 | if i in mask: 110 | mask_filter.append(1) 111 | else: 112 | mask_filter.append(0) 113 | new_cls_conf = [] 114 | new_new_bbox = [] 115 | new_cls_ids = [] 116 | for i in range(len(mask_filter)): 117 | if mask_filter[i]==1: 118 | new_cls_conf.append(cls_conf[i]) 119 | new_new_bbox.append(new_bbox[i]) 120 | new_cls_ids.append(cls_ids[i]) 121 | new_bbox = np.array(new_new_bbox).astype(np.float32) 122 | cls_conf = np.array(new_cls_conf).astype(np.float32) 123 | cls_ids = np.array(new_cls_ids).astype(np.float32) 124 | #-----#----- 125 | 126 | # do tracking 127 | outputs = self.deepsort.update(new_bbox, cls_conf, im) 128 | 129 | # draw boxes for visualization 130 | if len(outputs) > 0: 131 | bbox_tlwh = [] 132 | bbox_xyxy = outputs[:, :4] 133 | identities = outputs[:, -1] 134 | ori_im = draw_boxes(ori_im, bbox_xyxy, identities) 135 | for bb_xyxy in bbox_xyxy: 136 | bbox_tlwh.append(self.deepsort._xyxy_to_tlwh(bb_xyxy)) 137 | results.append((idx_frame - 1, bbox_tlwh, identities)) 138 | 139 | end = time.time() 140 | 141 | if self.args.display: 142 | cv2.imshow("test", ori_im) 143 | cv2.waitKey(1) 144 | 145 | if self.args.save_path: 146 | self.writer.write(ori_im) 147 | 148 | # save results 149 | write_results(self.save_results_path, results, 'mot') 150 | 151 | # logging 152 | self.logger.info("time: {:.03f}s, fps: {:.03f}, detection numbers: {}, tracking numbers: {}" \ 153 | .format(end - start, 1 / (end - start), new_bbox.shape[0], len(outputs))) 154 | 155 | 156 | def parse_args(): 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("--display", action="store_true", default=True) 159 | parser.add_argument("--frame_interval", type=int, default=1) 160 | 
parser.add_argument("--display_width", type=int, default=800) 161 | parser.add_argument("--display_height", type=int, default=600) 162 | parser.add_argument("--save_path", type=str, default="./output/") 163 | parser.add_argument("--cpu", dest="use_cuda", action="store_false", default=True) 164 | parser.add_argument("--camera", action="store", dest="cam", type=int, default="-1") 165 | return parser.parse_args() 166 | 167 | 168 | if __name__ == "__main__": 169 | args = parse_args() 170 | 171 | with VideoTracker( args, video_path='./001.avi') as vdo_trk: 172 | vdo_trk.run() 173 | --------------------------------------------------------------------------------