├── .DS_Store ├── deep_sort ├── .DS_Store ├── configs │ ├── deep_sort.yaml │ └── parser.py └── deep_sort │ ├── .DS_Store │ ├── __init__.py │ ├── deep │ ├── __init__.py │ ├── evaluate.py │ ├── feature_extractor.py │ ├── model.py │ ├── original_model.py │ ├── test.py │ └── train.py │ ├── deep_sort.py │ └── sort │ ├── __init__.py │ ├── detection.py │ ├── iou_matching.py │ ├── kalman_filter.py │ ├── linear_assignment.py │ ├── nn_matching.py │ ├── preprocessing.py │ ├── track.py │ └── tracker.py ├── readme.md ├── requirements.txt ├── requirements2.txt ├── selfutils ├── ava_action_list.pbtxt ├── coco_names.txt ├── slowfast_detection.py ├── temp.pbtxt └── visualization.py └── yolo_slowfast.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whiffe/yolov5-slowfast-deepsort-PytorchVideo/61eb82f6ec079c1ab6cb6d2ae1078ee1f903f836/.DS_Store -------------------------------------------------------------------------------- /deep_sort/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whiffe/yolov5-slowfast-deepsort-PytorchVideo/61eb82f6ec079c1ab6cb6d2ae1078ee1f903f836/deep_sort/.DS_Store -------------------------------------------------------------------------------- /deep_sort/configs/deep_sort.yaml: -------------------------------------------------------------------------------- 1 | DEEPSORT: 2 | REID_CKPT: "deep_sort/deep_sort/deep/checkpoint/ckpt.t7" 3 | MAX_DIST: 0.2 4 | MIN_CONFIDENCE: 0.3 5 | NMS_MAX_OVERLAP: 0.4 6 | MAX_IOU_DISTANCE: 0.7 7 | MAX_AGE: 70 8 | N_INIT: 3 9 | NN_BUDGET: 100 10 | 11 | -------------------------------------------------------------------------------- /deep_sort/configs/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from easydict import EasyDict as edict 4 | 5 | class YamlParser(edict): 6 | """ 7 | This is yaml parser based on EasyDict. 
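# A minimal sketch of how the DEEPSORT block in deep_sort.yaml above is typically loaded and read;
# the relative path assumes the repository root is the working directory, and attribute-style access
# comes from EasyDict (the same pattern the YamlParser below wraps):
import yaml
from easydict import EasyDict as edict

with open("deep_sort/configs/deep_sort.yaml", "r") as f:
    cfg = edict(yaml.load(f.read(), Loader=yaml.FullLoader))
print(cfg.DEEPSORT.MAX_DIST, cfg.DEEPSORT.REID_CKPT)   # 0.2  deep_sort/deep_sort/deep/checkpoint/ckpt.t7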
8 | """ 9 | def __init__(self, cfg_dict=None, config_file=None): 10 | if cfg_dict is None: 11 | cfg_dict = {} 12 | 13 | if config_file is not None: 14 | assert(os.path.isfile(config_file)) 15 | with open(config_file, 'r') as fo: 16 | cfg_dict.update(yaml.load(fo.read(),Loader=yaml.FullLoader)) 17 | 18 | super(YamlParser, self).__init__(cfg_dict) 19 | 20 | 21 | def merge_from_file(self, config_file): 22 | with open(config_file, 'r') as fo: 23 | self.update(yaml.load(fo.read(),Loader=yaml.FullLoader)) 24 | 25 | 26 | def merge_from_dict(self, config_dict): 27 | self.update(config_dict) 28 | 29 | 30 | def get_config(config_file=None): 31 | return YamlParser(config_file=config_file) 32 | 33 | 34 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whiffe/yolov5-slowfast-deepsort-PytorchVideo/61eb82f6ec079c1ab6cb6d2ae1078ee1f903f836/deep_sort/deep_sort/.DS_Store -------------------------------------------------------------------------------- /deep_sort/deep_sort/__init__.py: -------------------------------------------------------------------------------- 1 | from .deep_sort import DeepSort 2 | 3 | 4 | __all__ = ['DeepSort', 'build_tracker'] 5 | 6 | 7 | def build_tracker(cfg, use_cuda): 8 | return DeepSort(cfg.DEEPSORT.REID_CKPT, 9 | max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE, 10 | nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE, 11 | max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET, use_cuda=use_cuda) 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/deep/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whiffe/yolov5-slowfast-deepsort-PytorchVideo/61eb82f6ec079c1ab6cb6d2ae1078ee1f903f836/deep_sort/deep_sort/deep/__init__.py -------------------------------------------------------------------------------- /deep_sort/deep_sort/deep/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | features = torch.load("features.pth") 4 | qf = features["qf"] 5 | ql = features["ql"] 6 | gf = features["gf"] 7 | gl = features["gl"] 8 | 9 | scores = qf.mm(gf.t()) 10 | res = scores.topk(5, dim=1)[1][:,0] 11 | top1correct = gl[res].eq(ql).sum().item() 12 | 13 | print("Acc top1:{:.3f}".format(top1correct/ql.size(0))) 14 | 15 | 16 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/deep/feature_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | import cv2 5 | import logging 6 | 7 | from .model import Net 8 | 9 | class Extractor(object): 10 | def __init__(self, model_path, use_cuda=True): 11 | self.net = Net(reid=True) 12 | self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" 13 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict'] 14 | self.net.load_state_dict(state_dict) 15 | logger = logging.getLogger("root.tracker") 16 | logger.info("Loading weights from {}... 
Done!".format(model_path)) 17 | self.net.to(self.device) 18 | self.size = (64, 128) 19 | self.norm = transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 22 | ]) 23 | 24 | 25 | 26 | def _preprocess(self, im_crops): 27 | """ 28 | TODO: 29 | 1. to float with scale from 0 to 1 30 | 2. resize to (64, 128) as Market1501 dataset did 31 | 3. concatenate to a numpy array 32 | 3. to torch Tensor 33 | 4. normalize 34 | """ 35 | def _resize(im, size): 36 | return cv2.resize(im.astype(np.float32)/255., size) 37 | 38 | im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float() 39 | return im_batch 40 | 41 | 42 | def __call__(self, im_crops): 43 | im_batch = self._preprocess(im_crops) 44 | with torch.no_grad(): 45 | im_batch = im_batch.to(self.device) 46 | features = self.net(im_batch) 47 | return features.cpu().numpy() 48 | 49 | 50 | if __name__ == '__main__': 51 | img = cv2.imread("demo.jpg")[:,:,(2,1,0)] 52 | extr = Extractor("checkpoint/ckpt.t7") 53 | feature = extr(img) 54 | print(feature.shape) 55 | 56 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/deep/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicBlock(nn.Module): 6 | def __init__(self, c_in, c_out,is_downsample=False): 7 | super(BasicBlock,self).__init__() 8 | self.is_downsample = is_downsample 9 | if is_downsample: 10 | self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False) 11 | else: 12 | self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(c_out) 14 | self.relu = nn.ReLU(True) 15 | self.conv2 = nn.Conv2d(c_out,c_out,3,stride=1,padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(c_out) 17 | if is_downsample: 18 | self.downsample = nn.Sequential( 19 | nn.Conv2d(c_in, c_out, 1, stride=2, bias=False), 20 | nn.BatchNorm2d(c_out) 21 | ) 22 | elif c_in != c_out: 23 | self.downsample = nn.Sequential( 24 | nn.Conv2d(c_in, c_out, 1, stride=1, bias=False), 25 | nn.BatchNorm2d(c_out) 26 | ) 27 | self.is_downsample = True 28 | 29 | def forward(self,x): 30 | y = self.conv1(x) 31 | y = self.bn1(y) 32 | y = self.relu(y) 33 | y = self.conv2(y) 34 | y = self.bn2(y) 35 | if self.is_downsample: 36 | x = self.downsample(x) 37 | return F.relu(x.add(y),True) 38 | 39 | def make_layers(c_in,c_out,repeat_times, is_downsample=False): 40 | blocks = [] 41 | for i in range(repeat_times): 42 | if i ==0: 43 | blocks += [BasicBlock(c_in,c_out, is_downsample=is_downsample),] 44 | else: 45 | blocks += [BasicBlock(c_out,c_out),] 46 | return nn.Sequential(*blocks) 47 | 48 | class Net(nn.Module): 49 | def __init__(self, num_classes=751 ,reid=False): 50 | super(Net,self).__init__() 51 | # 3 128 64 52 | self.conv = nn.Sequential( 53 | nn.Conv2d(3,64,3,stride=1,padding=1), 54 | nn.BatchNorm2d(64), 55 | nn.ReLU(inplace=True), 56 | # nn.Conv2d(32,32,3,stride=1,padding=1), 57 | # nn.BatchNorm2d(32), 58 | # nn.ReLU(inplace=True), 59 | nn.MaxPool2d(3,2,padding=1), 60 | ) 61 | # 32 64 32 62 | self.layer1 = make_layers(64,64,2,False) 63 | # 32 64 32 64 | self.layer2 = make_layers(64,128,2,True) 65 | # 64 32 16 66 | self.layer3 = make_layers(128,256,2,True) 67 | # 128 16 8 68 | self.layer4 = make_layers(256,512,2,True) 69 | # 256 8 4 70 | self.avgpool = nn.AvgPool2d((8,4),1) 71 | # 256 1 1 72 | self.reid = reid 
73 | self.classifier = nn.Sequential( 74 | nn.Linear(512, 256), 75 | nn.BatchNorm1d(256), 76 | nn.ReLU(inplace=True), 77 | nn.Dropout(), 78 | nn.Linear(256, num_classes), 79 | ) 80 | 81 | def forward(self, x): 82 | x = self.conv(x) 83 | x = self.layer1(x) 84 | x = self.layer2(x) 85 | x = self.layer3(x) 86 | x = self.layer4(x) 87 | x = self.avgpool(x) 88 | x = x.view(x.size(0),-1) 89 | # B x 128 90 | if self.reid: 91 | x = x.div(x.norm(p=2,dim=1,keepdim=True)) 92 | return x 93 | # classifier 94 | x = self.classifier(x) 95 | return x 96 | 97 | 98 | if __name__ == '__main__': 99 | net = Net() 100 | x = torch.randn(4,3,128,64) 101 | y = net(x) 102 | import ipdb; ipdb.set_trace() 103 | 104 | 105 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/deep/original_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicBlock(nn.Module): 6 | def __init__(self, c_in, c_out,is_downsample=False): 7 | super(BasicBlock,self).__init__() 8 | self.is_downsample = is_downsample 9 | if is_downsample: 10 | self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False) 11 | else: 12 | self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(c_out) 14 | self.relu = nn.ReLU(True) 15 | self.conv2 = nn.Conv2d(c_out,c_out,3,stride=1,padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(c_out) 17 | if is_downsample: 18 | self.downsample = nn.Sequential( 19 | nn.Conv2d(c_in, c_out, 1, stride=2, bias=False), 20 | nn.BatchNorm2d(c_out) 21 | ) 22 | elif c_in != c_out: 23 | self.downsample = nn.Sequential( 24 | nn.Conv2d(c_in, c_out, 1, stride=1, bias=False), 25 | nn.BatchNorm2d(c_out) 26 | ) 27 | self.is_downsample = True 28 | 29 | def forward(self,x): 30 | y = self.conv1(x) 31 | y = self.bn1(y) 32 | y = self.relu(y) 33 | y = self.conv2(y) 34 | y = self.bn2(y) 35 | if self.is_downsample: 36 | x = self.downsample(x) 37 | return F.relu(x.add(y),True) 38 | 39 | def make_layers(c_in,c_out,repeat_times, is_downsample=False): 40 | blocks = [] 41 | for i in range(repeat_times): 42 | if i ==0: 43 | blocks += [BasicBlock(c_in,c_out, is_downsample=is_downsample),] 44 | else: 45 | blocks += [BasicBlock(c_out,c_out),] 46 | return nn.Sequential(*blocks) 47 | 48 | class Net(nn.Module): 49 | def __init__(self, num_classes=625 ,reid=False): 50 | super(Net,self).__init__() 51 | # 3 128 64 52 | self.conv = nn.Sequential( 53 | nn.Conv2d(3,32,3,stride=1,padding=1), 54 | nn.BatchNorm2d(32), 55 | nn.ELU(inplace=True), 56 | nn.Conv2d(32,32,3,stride=1,padding=1), 57 | nn.BatchNorm2d(32), 58 | nn.ELU(inplace=True), 59 | nn.MaxPool2d(3,2,padding=1), 60 | ) 61 | # 32 64 32 62 | self.layer1 = make_layers(32,32,2,False) 63 | # 32 64 32 64 | self.layer2 = make_layers(32,64,2,True) 65 | # 64 32 16 66 | self.layer3 = make_layers(64,128,2,True) 67 | # 128 16 8 68 | self.dense = nn.Sequential( 69 | nn.Dropout(p=0.6), 70 | nn.Linear(128*16*8, 128), 71 | nn.BatchNorm1d(128), 72 | nn.ELU(inplace=True) 73 | ) 74 | # 256 1 1 75 | self.reid = reid 76 | self.batch_norm = nn.BatchNorm1d(128) 77 | self.classifier = nn.Sequential( 78 | nn.Linear(128, num_classes), 79 | ) 80 | 81 | def forward(self, x): 82 | x = self.conv(x) 83 | x = self.layer1(x) 84 | x = self.layer2(x) 85 | x = self.layer3(x) 86 | 87 | x = x.view(x.size(0),-1) 88 | if self.reid: 89 | x = self.dense[0](x) 90 | x = self.dense[1](x) 91 | x = 
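# A quick shape check for the Net defined in model.py above: reid=True returns an L2-normalized
# 512-d embedding, reid=False returns logits over num_classes (751 by default, matching the
# Market-1501 training identities):
import torch
from deep_sort.deep_sort.deep.model import Net

x = torch.randn(4, 3, 128, 64)                      # four 128x64 RGB crops
print(Net(reid=True).eval()(x).shape)               # torch.Size([4, 512])
print(Net(num_classes=751).eval()(x).shape)         # torch.Size([4, 751])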
x.div(x.norm(p=2,dim=1,keepdim=True)) 92 | return x 93 | x = self.dense(x) 94 | # B x 128 95 | # classifier 96 | x = self.classifier(x) 97 | return x 98 | 99 | 100 | if __name__ == '__main__': 101 | net = Net(reid=True) 102 | x = torch.randn(4,3,128,64) 103 | y = net(x) 104 | import ipdb; ipdb.set_trace() 105 | 106 | 107 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/deep/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | import torchvision 4 | 5 | import argparse 6 | import os 7 | 8 | from model import Net 9 | 10 | parser = argparse.ArgumentParser(description="Train on market1501") 11 | parser.add_argument("--data-dir",default='data',type=str) 12 | parser.add_argument("--no-cuda",action="store_true") 13 | parser.add_argument("--gpu-id",default=0,type=int) 14 | args = parser.parse_args() 15 | 16 | # device 17 | device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu" 18 | if torch.cuda.is_available() and not args.no_cuda: 19 | cudnn.benchmark = True 20 | 21 | # data loader 22 | root = args.data_dir 23 | query_dir = os.path.join(root,"query") 24 | gallery_dir = os.path.join(root,"gallery") 25 | transform = torchvision.transforms.Compose([ 26 | torchvision.transforms.Resize((128,64)), 27 | torchvision.transforms.ToTensor(), 28 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 29 | ]) 30 | queryloader = torch.utils.data.DataLoader( 31 | torchvision.datasets.ImageFolder(query_dir, transform=transform), 32 | batch_size=64, shuffle=False 33 | ) 34 | galleryloader = torch.utils.data.DataLoader( 35 | torchvision.datasets.ImageFolder(gallery_dir, transform=transform), 36 | batch_size=64, shuffle=False 37 | ) 38 | 39 | # net definition 40 | net = Net(reid=True) 41 | assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!" 
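# The same check for the original 128-d variant above (original_model.Net); in reid mode only the
# Dropout and Linear layers of `dense` run before L2 normalization, so the embedding is 128-d:
import torch
from deep_sort.deep_sort.deep.original_model import Net

net = Net(reid=True).eval()                         # eval() disables Dropout for a stable output
emb = net(torch.randn(2, 3, 128, 64))
print(emb.shape, emb.norm(dim=1))                   # torch.Size([2, 128]), each row has unit norm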
42 | print('Loading from checkpoint/ckpt.t7') 43 | checkpoint = torch.load("./checkpoint/ckpt.t7") 44 | net_dict = checkpoint['net_dict'] 45 | net.load_state_dict(net_dict, strict=False) 46 | net.eval() 47 | net.to(device) 48 | 49 | # compute features 50 | query_features = torch.tensor([]).float() 51 | query_labels = torch.tensor([]).long() 52 | gallery_features = torch.tensor([]).float() 53 | gallery_labels = torch.tensor([]).long() 54 | 55 | with torch.no_grad(): 56 | for idx,(inputs,labels) in enumerate(queryloader): 57 | inputs = inputs.to(device) 58 | features = net(inputs).cpu() 59 | query_features = torch.cat((query_features, features), dim=0) 60 | query_labels = torch.cat((query_labels, labels)) 61 | 62 | for idx,(inputs,labels) in enumerate(galleryloader): 63 | inputs = inputs.to(device) 64 | features = net(inputs).cpu() 65 | gallery_features = torch.cat((gallery_features, features), dim=0) 66 | gallery_labels = torch.cat((gallery_labels, labels)) 67 | 68 | gallery_labels -= 2 69 | 70 | # save features 71 | features = { 72 | "qf": query_features, 73 | "ql": query_labels, 74 | "gf": gallery_features, 75 | "gl": gallery_labels 76 | } 77 | torch.save(features,"features.pth") -------------------------------------------------------------------------------- /deep_sort/deep_sort/deep/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torchvision 10 | 11 | from model import Net 12 | 13 | parser = argparse.ArgumentParser(description="Train on market1501") 14 | parser.add_argument("--data-dir",default='data',type=str) 15 | parser.add_argument("--no-cuda",action="store_true") 16 | parser.add_argument("--gpu-id",default=0,type=int) 17 | parser.add_argument("--lr",default=0.1, type=float) 18 | parser.add_argument("--interval",'-i',default=20,type=int) 19 | parser.add_argument('--resume', '-r',action='store_true') 20 | args = parser.parse_args() 21 | 22 | # device 23 | device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu" 24 | if torch.cuda.is_available() and not args.no_cuda: 25 | cudnn.benchmark = True 26 | 27 | # data loading 28 | root = args.data_dir 29 | train_dir = os.path.join(root,"train") 30 | test_dir = os.path.join(root,"test") 31 | transform_train = torchvision.transforms.Compose([ 32 | torchvision.transforms.RandomCrop((128,64),padding=4), 33 | torchvision.transforms.RandomHorizontalFlip(), 34 | torchvision.transforms.ToTensor(), 35 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 36 | ]) 37 | transform_test = torchvision.transforms.Compose([ 38 | torchvision.transforms.Resize((128,64)), 39 | torchvision.transforms.ToTensor(), 40 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 41 | ]) 42 | trainloader = torch.utils.data.DataLoader( 43 | torchvision.datasets.ImageFolder(train_dir, transform=transform_train), 44 | batch_size=64,shuffle=True 45 | ) 46 | testloader = torch.utils.data.DataLoader( 47 | torchvision.datasets.ImageFolder(test_dir, transform=transform_test), 48 | batch_size=64,shuffle=True 49 | ) 50 | num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes)) 51 | 52 | # net definition 53 | start_epoch = 0 54 | net = Net(num_classes=num_classes) 55 | if args.resume: 56 | assert os.path.isfile("./checkpoint/ckpt.t7"), 
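# test.py above (and train.py below) builds its loaders with torchvision.datasets.ImageFolder, so
# --data-dir is expected to follow a one-sub-folder-per-identity layout; the directory names here
# are hypothetical:
#
#   data/query/0001/*.jpg   data/gallery/0001/*.jpg   data/train/0002/*.jpg   data/test/0002/*.jpg
#
# A quick way to confirm the identities are picked up as classes:
import torchvision

ds = torchvision.datasets.ImageFolder("data/query")
print(len(ds.classes), ds.class_to_idx)             # identities found under data/query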
"Error: no checkpoint file found!" 57 | print('Loading from checkpoint/ckpt.t7') 58 | checkpoint = torch.load("./checkpoint/ckpt.t7") 59 | # import ipdb; ipdb.set_trace() 60 | net_dict = checkpoint['net_dict'] 61 | net.load_state_dict(net_dict) 62 | best_acc = checkpoint['acc'] 63 | start_epoch = checkpoint['epoch'] 64 | net.to(device) 65 | 66 | # loss and optimizer 67 | criterion = torch.nn.CrossEntropyLoss() 68 | optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4) 69 | best_acc = 0. 70 | 71 | # train function for each epoch 72 | def train(epoch): 73 | print("\nEpoch : %d"%(epoch+1)) 74 | net.train() 75 | training_loss = 0. 76 | train_loss = 0. 77 | correct = 0 78 | total = 0 79 | interval = args.interval 80 | start = time.time() 81 | for idx, (inputs, labels) in enumerate(trainloader): 82 | # forward 83 | inputs,labels = inputs.to(device),labels.to(device) 84 | outputs = net(inputs) 85 | loss = criterion(outputs, labels) 86 | 87 | # backward 88 | optimizer.zero_grad() 89 | loss.backward() 90 | optimizer.step() 91 | 92 | # accumurating 93 | training_loss += loss.item() 94 | train_loss += loss.item() 95 | correct += outputs.max(dim=1)[1].eq(labels).sum().item() 96 | total += labels.size(0) 97 | 98 | # print 99 | if (idx+1)%interval == 0: 100 | end = time.time() 101 | print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format( 102 | 100.*(idx+1)/len(trainloader), end-start, training_loss/interval, correct, total, 100.*correct/total 103 | )) 104 | training_loss = 0. 105 | start = time.time() 106 | 107 | return train_loss/len(trainloader), 1.- correct/total 108 | 109 | def test(epoch): 110 | global best_acc 111 | net.eval() 112 | test_loss = 0. 113 | correct = 0 114 | total = 0 115 | start = time.time() 116 | with torch.no_grad(): 117 | for idx, (inputs, labels) in enumerate(testloader): 118 | inputs, labels = inputs.to(device), labels.to(device) 119 | outputs = net(inputs) 120 | loss = criterion(outputs, labels) 121 | 122 | test_loss += loss.item() 123 | correct += outputs.max(dim=1)[1].eq(labels).sum().item() 124 | total += labels.size(0) 125 | 126 | print("Testing ...") 127 | end = time.time() 128 | print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format( 129 | 100.*(idx+1)/len(testloader), end-start, test_loss/len(testloader), correct, total, 100.*correct/total 130 | )) 131 | 132 | # saving checkpoint 133 | acc = 100.*correct/total 134 | if acc > best_acc: 135 | best_acc = acc 136 | print("Saving parameters to checkpoint/ckpt.t7") 137 | checkpoint = { 138 | 'net_dict':net.state_dict(), 139 | 'acc':acc, 140 | 'epoch':epoch, 141 | } 142 | if not os.path.isdir('checkpoint'): 143 | os.mkdir('checkpoint') 144 | torch.save(checkpoint, './checkpoint/ckpt.t7') 145 | 146 | return test_loss/len(testloader), 1.- correct/total 147 | 148 | # plot figure 149 | x_epoch = [] 150 | record = {'train_loss':[], 'train_err':[], 'test_loss':[], 'test_err':[]} 151 | fig = plt.figure() 152 | ax0 = fig.add_subplot(121, title="loss") 153 | ax1 = fig.add_subplot(122, title="top1err") 154 | def draw_curve(epoch, train_loss, train_err, test_loss, test_err): 155 | global record 156 | record['train_loss'].append(train_loss) 157 | record['train_err'].append(train_err) 158 | record['test_loss'].append(test_loss) 159 | record['test_err'].append(test_err) 160 | 161 | x_epoch.append(epoch) 162 | ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train') 163 | ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val') 164 | 
ax1.plot(x_epoch, record['train_err'], 'bo-', label='train') 165 | ax1.plot(x_epoch, record['test_err'], 'ro-', label='val') 166 | if epoch == 0: 167 | ax0.legend() 168 | ax1.legend() 169 | fig.savefig("train.jpg") 170 | 171 | # lr decay 172 | def lr_decay(): 173 | global optimizer 174 | for params in optimizer.param_groups: 175 | params['lr'] *= 0.1 176 | lr = params['lr'] 177 | print("Learning rate adjusted to {}".format(lr)) 178 | 179 | def main(): 180 | for epoch in range(start_epoch, start_epoch+40): 181 | train_loss, train_err = train(epoch) 182 | test_loss, test_err = test(epoch) 183 | draw_curve(epoch, train_loss, train_err, test_loss, test_err) 184 | if (epoch+1)%20==0: 185 | lr_decay() 186 | 187 | 188 | if __name__ == '__main__': 189 | main() 190 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/deep_sort.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .deep.feature_extractor import Extractor 5 | from .sort.nn_matching import NearestNeighborDistanceMetric 6 | from .sort.preprocessing import non_max_suppression 7 | from .sort.detection import Detection 8 | from .sort.tracker import Tracker 9 | 10 | 11 | __all__ = ['DeepSort'] 12 | 13 | 14 | class DeepSort(object): 15 | def __init__(self, model_path, max_dist=0.2, min_confidence=0.3, nms_max_overlap=3.0, max_iou_distance=0.7, max_age=70, n_init=2, nn_budget=100, use_cuda=True, use_appearence=True): 16 | self.min_confidence = min_confidence 17 | self.nms_max_overlap = nms_max_overlap 18 | self.use_appearence=use_appearence 19 | self.extractor = Extractor(model_path, use_cuda=use_cuda) 20 | 21 | max_cosine_distance = max_dist 22 | nn_budget = nn_budget 23 | metric = NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget) 24 | self.tracker = Tracker(metric, max_iou_distance=max_iou_distance, max_age=max_age, n_init=n_init) 25 | 26 | def update(self, bbox_xywh, confidences, labels, ori_img): 27 | self.height, self.width = ori_img.shape[:2] 28 | # generate detections 29 | 30 | if self.use_appearence: 31 | features = self._get_features(bbox_xywh, ori_img) 32 | else: 33 | features = np.array([np.array([0.5,0.5]) for _ in range(len(bbox_xywh))]) 34 | bbox_tlwh = self._xywh_to_tlwh(bbox_xywh) 35 | detections = [Detection(bbox_tlwh[i], conf, labels[i], features[i]) for i,conf in enumerate(confidences) if conf>self.min_confidence] 36 | 37 | # run on non-maximum supression 38 | # boxes = np.array([d.tlwh for d in detections]) 39 | # scores = np.array([d.confidence for d in detections]) 40 | # indices = non_max_suppression(boxes, self.nms_max_overlap, scores) 41 | # detections = [detections[i] for i in indices] 42 | 43 | # update tracker 44 | self.tracker.predict() 45 | self.tracker.update(detections) 46 | 47 | # output bbox identities 48 | outputs = [] 49 | for track in self.tracker.tracks: 50 | if not track.is_confirmed() or track.time_since_update > 1: 51 | continue 52 | box = track.to_tlwh() 53 | x1,y1,x2,y2 = self._tlwh_to_xyxy(box) 54 | track_id = track.track_id 55 | label=track.label 56 | Vx=10*track.mean[4] 57 | Vy=10*track.mean[5] 58 | outputs.append(np.array([x1,y1,x2,y2,label,track_id,Vx,Vy], dtype=np.int)) 59 | if len(outputs) > 0: 60 | outputs = np.stack(outputs,axis=0) 61 | return outputs 62 | 63 | 64 | """ 65 | TODO: 66 | Convert bbox from xc_yc_w_h to xtl_ytl_w_h 67 | Thanks JieChen91@github.com for reporting this bug! 
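# A hedged end-to-end sketch for DeepSort.update() above. Detections come from any detector as
# (center x, center y, w, h) boxes; each returned row is [x1, y1, x2, y2, label, track_id, Vx, Vy].
# Paths are placeholders, and the np.int cast used when building `outputs` requires numpy < 1.24
# (or replacing np.int with int):
import cv2
import numpy as np
from deep_sort.deep_sort.deep_sort import DeepSort

tracker = DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7", use_cuda=False)
frame = cv2.imread("frame.jpg")                       # placeholder video frame
xywh = np.array([[320.0, 240.0, 60.0, 120.0]])        # one hypothetical person detection
outputs = tracker.update(xywh, np.array([0.9]), np.array([0]), frame)
for x1, y1, x2, y2, label, track_id, vx, vy in outputs:
    print(track_id, (x1, y1, x2, y2), (vx, vy))       # Vx, Vy are 10x the Kalman velocity terms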
68 | """ 69 | @staticmethod 70 | def _xywh_to_tlwh(bbox_xywh): 71 | if isinstance(bbox_xywh, np.ndarray): 72 | bbox_tlwh = bbox_xywh.copy() 73 | elif isinstance(bbox_xywh, torch.Tensor): 74 | bbox_tlwh = bbox_xywh.clone() 75 | bbox_tlwh[:,0] = bbox_xywh[:,0] - bbox_xywh[:,2]/2. 76 | bbox_tlwh[:,1] = bbox_xywh[:,1] - bbox_xywh[:,3]/2. 77 | return bbox_tlwh 78 | 79 | 80 | def _xywh_to_xyxy(self, bbox_xywh): 81 | x,y,w,h = bbox_xywh 82 | x1 = max(int(x-w/2),0) 83 | x2 = min(int(x+w/2),self.width-1) 84 | y1 = max(int(y-h/2),0) 85 | y2 = min(int(y+h/2),self.height-1) 86 | return x1,y1,x2,y2 87 | 88 | def _tlwh_to_xyxy(self, bbox_tlwh): 89 | """ 90 | TODO: 91 | Convert bbox from xtl_ytl_w_h to xc_yc_w_h 92 | Thanks JieChen91@github.com for reporting this bug! 93 | """ 94 | x,y,w,h = bbox_tlwh 95 | x1 = max(int(x),0) 96 | x2 = min(int(x+w),self.width-1) 97 | y1 = max(int(y),0) 98 | y2 = min(int(y+h),self.height-1) 99 | return x1,y1,x2,y2 100 | 101 | def _xyxy_to_tlwh(self, bbox_xyxy): 102 | x1,y1,x2,y2 = bbox_xyxy 103 | 104 | t = x1 105 | l = y1 106 | w = int(x2-x1) 107 | h = int(y2-y1) 108 | return t,l,w,h 109 | 110 | def _get_features(self, bbox_xywh, ori_img): 111 | im_crops = [] 112 | for box in bbox_xywh: 113 | x1,y1,x2,y2 = self._xywh_to_xyxy(box) 114 | im = ori_img[y1:y2,x1:x2] 115 | im_crops.append(im) 116 | if im_crops: 117 | features = self.extractor(im_crops) 118 | else: 119 | features = np.array([]) 120 | return features 121 | 122 | 123 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/sort/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Whiffe/yolov5-slowfast-deepsort-PytorchVideo/61eb82f6ec079c1ab6cb6d2ae1078ee1f903f836/deep_sort/deep_sort/sort/__init__.py -------------------------------------------------------------------------------- /deep_sort/deep_sort/sort/detection.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | 4 | 5 | class Detection(object): 6 | """ 7 | This class represents a bounding box detection in a single image. 8 | 9 | Parameters 10 | ---------- 11 | tlwh : array_like 12 | Bounding box in format `(x, y, w, h)`. 13 | confidence : float 14 | Detector confidence score. 15 | feature : array_like 16 | A feature vector that describes the object contained in this image. 17 | 18 | Attributes 19 | ---------- 20 | tlwh : ndarray 21 | Bounding box in format `(top left x, top left y, width, height)`. 22 | confidence : ndarray 23 | Detector confidence score. 24 | feature : ndarray | NoneType 25 | A feature vector that describes the object contained in this image. 26 | 27 | """ 28 | 29 | def __init__(self, tlwh, confidence, label, feature): 30 | self.tlwh = np.asarray(tlwh, dtype=np.float) 31 | self.confidence = float(confidence) 32 | self.label=label 33 | self.feature = np.asarray(feature, dtype=np.float32) 34 | 35 | def to_tlbr(self): 36 | """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., 37 | `(top left, bottom right)`. 38 | """ 39 | ret = self.tlwh.copy() 40 | ret[2:] += ret[:2] 41 | return ret 42 | 43 | def to_xyah(self): 44 | """Convert bounding box to format `(center x, center y, aspect ratio, 45 | height)`, where the aspect ratio is `width / height`. 
46 | """ 47 | ret = self.tlwh.copy() 48 | ret[:2] += ret[2:] / 2 49 | ret[2] /= ret[3] 50 | return ret 51 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/sort/iou_matching.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | from __future__ import absolute_import 3 | import numpy as np 4 | from . import linear_assignment 5 | 6 | 7 | def diou(box,candis): 8 | """Computer distance intersection over union, called diou(imporved on iou) 9 | 10 | Parameters 11 | ---------- 12 | bbox : ndarray 13 | A bounding box in format `(top left x, top left y, width, height)`. 14 | candidates : ndarray 15 | A matrix of candidate bounding boxes (one per row) in the same format 16 | as `bbox`. 17 | 18 | Returns 19 | ------- 20 | ndarray 21 | The intersection over union in [0, 1] between the `bbox` and each 22 | candidate. A higher score means a larger fraction of the `bbox` is 23 | occluded by the candidate. 24 | """ 25 | centerA=np.array([box[0]+box[2]/2,box[1]+box[3]/2]) 26 | centerB=np.array([candis[:,0]+candis[:,2]/2,candis[:,1]+candis[:,3]/2]).T 27 | center_dis=np.sum((centerA-centerB)**2,axis=-1) 28 | max_x=np.maximum(box[0]+box[2], candis[:,0]+candis[:,2]) 29 | min_x=np.minimum(box[0], candis[:,0]) 30 | max_y=np.maximum(box[1]+box[3], candis[:,1]+candis[:,3]) 31 | min_y=np.minimum(box[1], candis[:,1]) 32 | max_dis=(max_x-min_x)**2+(max_y-min_y)**2 33 | rela_dis=center_dis/max_dis 34 | return (1+iou(box,candis)-rela_dis)/2 35 | 36 | # %% 37 | def iou(bbox, candidates): 38 | """Computer intersection over union. 39 | 40 | Parameters 41 | ---------- 42 | bbox : ndarray 43 | A bounding box in format `(top left x, top left y, width, height)`. 44 | candidates : ndarray 45 | A matrix of candidate bounding boxes (one per row) in the same format 46 | as `bbox`. 47 | 48 | Returns 49 | ------- 50 | ndarray 51 | The intersection over union in [0, 1] between the `bbox` and each 52 | candidate. A higher score means a larger fraction of the `bbox` is 53 | occluded by the candidate. 54 | 55 | """ 56 | bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:] 57 | candidates_tl = candidates[:, :2] 58 | candidates_br = candidates[:, :2] + candidates[:, 2:] 59 | 60 | tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis], 61 | np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]] 62 | br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis], 63 | np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]] 64 | wh = np.maximum(0., br - tl) 65 | 66 | area_intersection = wh.prod(axis=1) 67 | area_bbox = bbox[2:].prod() 68 | area_candidates = candidates[:, 2:].prod(axis=1) 69 | return area_intersection / (area_bbox + area_candidates - area_intersection) 70 | # %% 71 | 72 | def iou_cost(tracks, detections, track_indices=None, 73 | detection_indices=None): 74 | """An intersection over union distance metric. 75 | 76 | Parameters 77 | ---------- 78 | tracks : List[deep_sort.track.Track] 79 | A list of tracks. 80 | detections : List[deep_sort.detection.Detection] 81 | A list of detections. 82 | track_indices : Optional[List[int]] 83 | A list of indices to tracks that should be matched. Defaults to 84 | all `tracks`. 85 | detection_indices : Optional[List[int]] 86 | A list of indices to detections that should be matched. Defaults 87 | to all `detections`. 
88 | 89 | Returns 90 | ------- 91 | ndarray 92 | Returns a cost matrix of shape 93 | len(track_indices), len(detection_indices) where entry (i, j) is 94 | `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`. 95 | 96 | """ 97 | if track_indices is None: 98 | track_indices = np.arange(len(tracks)) 99 | if detection_indices is None: 100 | detection_indices = np.arange(len(detections)) 101 | 102 | cost_matrix = np.zeros((len(track_indices), len(detection_indices))) 103 | for row, track_idx in enumerate(track_indices): 104 | if tracks[track_idx].time_since_update > 1: 105 | cost_matrix[row, :] = linear_assignment.INFTY_COST 106 | continue 107 | 108 | bbox = tracks[track_idx].to_tlwh() 109 | candidates = np.asarray([detections[i].tlwh for i in detection_indices]) 110 | cost_matrix[row, :] = 1. - iou(bbox, candidates) 111 | return cost_matrix 112 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/sort/kalman_filter.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | import scipy.linalg 4 | 5 | 6 | """ 7 | Table for the 0.95 quantile of the chi-square distribution with N degrees of 8 | freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv 9 | function and used as Mahalanobis gating threshold. 10 | """ 11 | chi2inv95 = { 12 | 1: 3.8415, 13 | 2: 5.9915, 14 | 3: 7.8147, 15 | 4: 9.4877, 16 | 5: 11.070, 17 | 6: 12.592, 18 | 7: 14.067, 19 | 8: 15.507, 20 | 9: 16.919} 21 | 22 | 23 | class KalmanFilter(object): 24 | """ 25 | A simple Kalman filter for tracking bounding boxes in image space. 26 | 27 | The 8-dimensional state space 28 | 29 | x, y, a, h, vx, vy, va, vh 30 | 31 | contains the bounding box center position (x, y), aspect ratio a, height h, 32 | and their respective velocities. 33 | 34 | Object motion follows a constant velocity model. The bounding box location 35 | (x, y, a, h) is taken as direct observation of the state space (linear 36 | observation model). 37 | 38 | """ 39 | 40 | def __init__(self): 41 | ndim, dt = 4, 1. 42 | 43 | # Create Kalman filter model matrices. 44 | self._motion_mat = np.eye(2 * ndim, 2 * ndim) 45 | for i in range(ndim): 46 | self._motion_mat[i, ndim + i] = dt 47 | self._update_mat = np.eye(ndim, 2 * ndim) 48 | 49 | # Motion and observation uncertainty are chosen relative to the current 50 | # state estimate. These weights control the amount of uncertainty in 51 | # the model. This is a bit hacky. 52 | self._std_weight_position = 1. / 20 53 | self._std_weight_velocity = 1. / 160 54 | 55 | def initiate(self, measurement): 56 | """Create track from unassociated measurement. 57 | 58 | Parameters 59 | ---------- 60 | measurement : ndarray 61 | Bounding box coordinates (x, y, a, h) with center position (x, y), 62 | aspect ratio a, and height h. 63 | 64 | Returns 65 | ------- 66 | (ndarray, ndarray) 67 | Returns the mean vector (8 dimensional) and covariance matrix (8x8 68 | dimensional) of the new track. Unobserved velocities are initialized 69 | to 0 mean. 
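# The constant-velocity model behind KalmanFilter above: the state is (x, y, a, h, vx, vy, va, vh),
# the motion matrix is identity with dt on the velocity block, and the update matrix observes only
# the first four components, exactly as built in __init__:
import numpy as np

ndim, dt = 4, 1.0
F = np.eye(2 * ndim)
for i in range(ndim):
    F[i, ndim + i] = dt            # x_{k+1} = x_k + dt * vx_k, and likewise for y, a, h
H = np.eye(ndim, 2 * ndim)         # measurement picks out (x, y, a, h)
print(F[0], H.shape)               # [1. 0. 0. 0. 1. 0. 0. 0.]  (4, 8)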
70 | 71 | """ 72 | mean_pos = measurement 73 | mean_vel = np.zeros_like(mean_pos) 74 | mean = np.r_[mean_pos, mean_vel] 75 | 76 | std = [ 77 | 2 * self._std_weight_position * measurement[3], 78 | 2 * self._std_weight_position * measurement[3], 79 | 1e-2, 80 | 2 * self._std_weight_position * measurement[3], 81 | 10 * self._std_weight_velocity * measurement[3], 82 | 10 * self._std_weight_velocity * measurement[3], 83 | 1e-5, 84 | 10 * self._std_weight_velocity * measurement[3]] 85 | covariance = np.diag(np.square(std)) 86 | return mean, covariance 87 | 88 | def predict(self, mean, covariance): 89 | """Run Kalman filter prediction step. 90 | 91 | Parameters 92 | ---------- 93 | mean : ndarray 94 | The 8 dimensional mean vector of the object state at the previous 95 | time step. 96 | covariance : ndarray 97 | The 8x8 dimensional covariance matrix of the object state at the 98 | previous time step. 99 | 100 | Returns 101 | ------- 102 | (ndarray, ndarray) 103 | Returns the mean vector and covariance matrix of the predicted 104 | state. Unobserved velocities are initialized to 0 mean. 105 | 106 | """ 107 | std_pos = [ 108 | self._std_weight_position * mean[3], 109 | self._std_weight_position * mean[3], 110 | 1e-2, 111 | self._std_weight_position * mean[3]] 112 | std_vel = [ 113 | self._std_weight_velocity * mean[3], 114 | self._std_weight_velocity * mean[3], 115 | 1e-5, 116 | self._std_weight_velocity * mean[3]] 117 | motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) 118 | 119 | mean = np.dot(self._motion_mat, mean) 120 | covariance = np.linalg.multi_dot(( 121 | self._motion_mat, covariance, self._motion_mat.T)) + motion_cov 122 | 123 | return mean, covariance 124 | 125 | def project(self, mean, covariance): 126 | """Project state distribution to measurement space. 127 | 128 | Parameters 129 | ---------- 130 | mean : ndarray 131 | The state's mean vector (8 dimensional array). 132 | covariance : ndarray 133 | The state's covariance matrix (8x8 dimensional). 134 | 135 | Returns 136 | ------- 137 | (ndarray, ndarray) 138 | Returns the projected mean and covariance matrix of the given state 139 | estimate. 140 | 141 | """ 142 | std = [ 143 | self._std_weight_position * mean[3], 144 | self._std_weight_position * mean[3], 145 | 1e-1, 146 | self._std_weight_position * mean[3]] 147 | innovation_cov = np.diag(np.square(std)) 148 | 149 | mean = np.dot(self._update_mat, mean) 150 | covariance = np.linalg.multi_dot(( 151 | self._update_mat, covariance, self._update_mat.T)) 152 | return mean, covariance + innovation_cov 153 | 154 | def update(self, mean, covariance, measurement): 155 | """Run Kalman filter correction step. 156 | 157 | Parameters 158 | ---------- 159 | mean : ndarray 160 | The predicted state's mean vector (8 dimensional). 161 | covariance : ndarray 162 | The state's covariance matrix (8x8 dimensional). 163 | measurement : ndarray 164 | The 4 dimensional measurement vector (x, y, a, h), where (x, y) 165 | is the center position, a the aspect ratio, and h the height of the 166 | bounding box. 167 | 168 | Returns 169 | ------- 170 | (ndarray, ndarray) 171 | Returns the measurement-corrected state distribution. 
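# One predict/update cycle with the filter above; the measurement is (x, y, a, h) with
# a = width / height, matching Detection.to_xyah():
import numpy as np
from deep_sort.deep_sort.sort.kalman_filter import KalmanFilter

kf = KalmanFilter()
mean, cov = kf.initiate(np.array([100.0, 50.0, 0.5, 80.0]))    # new track from a single detection
mean, cov = kf.predict(mean, cov)                              # propagate one frame forward
mean, cov = kf.update(mean, cov, np.array([102.0, 51.0, 0.5, 80.0]))
print(mean[:4])                                                # corrected (x, y, a, h)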
172 | 173 | """ 174 | projected_mean, projected_cov = self.project(mean, covariance) 175 | 176 | chol_factor, lower = scipy.linalg.cho_factor( 177 | projected_cov, lower=True, check_finite=False) 178 | kalman_gain = scipy.linalg.cho_solve( 179 | (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, 180 | check_finite=False).T 181 | innovation = measurement - projected_mean 182 | 183 | new_mean = mean + np.dot(innovation, kalman_gain.T) 184 | new_covariance = covariance - np.linalg.multi_dot(( 185 | kalman_gain, projected_cov, kalman_gain.T)) 186 | return new_mean, new_covariance 187 | 188 | def gating_distance(self, mean, covariance, measurements, 189 | only_position=False): 190 | """Compute gating distance between state distribution and measurements. 191 | 192 | A suitable distance threshold can be obtained from `chi2inv95`. If 193 | `only_position` is False, the chi-square distribution has 4 degrees of 194 | freedom, otherwise 2. 195 | 196 | Parameters 197 | ---------- 198 | mean : ndarray 199 | Mean vector over the state distribution (8 dimensional). 200 | covariance : ndarray 201 | Covariance of the state distribution (8x8 dimensional). 202 | measurements : ndarray 203 | An Nx4 dimensional matrix of N measurements, each in 204 | format (x, y, a, h) where (x, y) is the bounding box center 205 | position, a the aspect ratio, and h the height. 206 | only_position : Optional[bool] 207 | If True, distance computation is done with respect to the bounding 208 | box center position only. 209 | 210 | Returns 211 | ------- 212 | ndarray 213 | Returns an array of length N, where the i-th element contains the 214 | squared Mahalanobis distance between (mean, covariance) and 215 | `measurements[i]`. 216 | 217 | """ 218 | mean, covariance = self.project(mean, covariance) 219 | if only_position: 220 | mean, covariance = mean[:2], covariance[:2, :2] 221 | measurements = measurements[:, :2] 222 | 223 | cholesky_factor = np.linalg.cholesky(covariance) 224 | d = measurements - mean 225 | z = scipy.linalg.solve_triangular( 226 | cholesky_factor, d.T, lower=True, check_finite=False, 227 | overwrite_b=True) 228 | squared_maha = np.sum(z * z, axis=0) 229 | return squared_maha 230 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/sort/linear_assignment.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | from __future__ import absolute_import 3 | import numpy as np 4 | # from sklearn.utils.linear_assignment_ import linear_assignment 5 | from scipy.optimize import linear_sum_assignment as linear_assignment 6 | from . import kalman_filter 7 | 8 | 9 | INFTY_COST = 1e+5 10 | 11 | 12 | def min_cost_matching( 13 | distance_metric, max_distance, tracks, detections, track_indices=None, 14 | detection_indices=None): 15 | """Solve linear assignment problem. 16 | 17 | Parameters 18 | ---------- 19 | distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray 20 | The distance metric is given a list of tracks and detections as well as 21 | a list of N track indices and M detection indices. The metric should 22 | return the NxM dimensional cost matrix, where element (i, j) is the 23 | association cost between the i-th track in the given track indices and 24 | the j-th detection in the given detection_indices. 25 | max_distance : float 26 | Gating threshold. Associations with cost larger than this value are 27 | disregarded. 
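# Gating sketch: a measurement is feasible for a track when its squared Mahalanobis distance
# (gating_distance above) falls below chi2inv95[4] = 9.4877:
import numpy as np
from deep_sort.deep_sort.sort.kalman_filter import KalmanFilter, chi2inv95

kf = KalmanFilter()
mean, cov = kf.initiate(np.array([100.0, 50.0, 0.5, 80.0]))
candidates = np.array([[101.0, 51.0, 0.5, 80.0],      # nearby box -> small distance
                       [400.0, 300.0, 0.5, 80.0]])    # far box    -> gated out
d2 = kf.gating_distance(mean, cov, candidates)
print(d2, d2 < chi2inv95[4])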
28 | tracks : List[track.Track] 29 | A list of predicted tracks at the current time step. 30 | detections : List[detection.Detection] 31 | A list of detections at the current time step. 32 | track_indices : List[int] 33 | List of track indices that maps rows in `cost_matrix` to tracks in 34 | `tracks` (see description above). 35 | detection_indices : List[int] 36 | List of detection indices that maps columns in `cost_matrix` to 37 | detections in `detections` (see description above). 38 | 39 | Returns 40 | ------- 41 | (List[(int, int)], List[int], List[int]) 42 | Returns a tuple with the following three entries: 43 | * A list of matched track and detection indices. 44 | * A list of unmatched track indices. 45 | * A list of unmatched detection indices. 46 | 47 | """ 48 | if track_indices is None: 49 | track_indices = np.arange(len(tracks)) 50 | if detection_indices is None: 51 | detection_indices = np.arange(len(detections)) 52 | 53 | if len(detection_indices) == 0 or len(track_indices) == 0: 54 | return [], track_indices, detection_indices # Nothing to match. 55 | 56 | cost_matrix = distance_metric( 57 | tracks, detections, track_indices, detection_indices) 58 | cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5 59 | 60 | row_indices, col_indices = linear_assignment(cost_matrix) 61 | 62 | matches, unmatched_tracks, unmatched_detections = [], [], [] 63 | for col, detection_idx in enumerate(detection_indices): 64 | if col not in col_indices: 65 | unmatched_detections.append(detection_idx) 66 | for row, track_idx in enumerate(track_indices): 67 | if row not in row_indices: 68 | unmatched_tracks.append(track_idx) 69 | for row, col in zip(row_indices, col_indices): 70 | track_idx = track_indices[row] 71 | detection_idx = detection_indices[col] 72 | if cost_matrix[row, col] > max_distance: 73 | unmatched_tracks.append(track_idx) 74 | unmatched_detections.append(detection_idx) 75 | else: 76 | matches.append((track_idx, detection_idx)) 77 | return matches, unmatched_tracks, unmatched_detections 78 | 79 | 80 | def matching_cascade( 81 | distance_metric, max_distance, cascade_depth, tracks, detections, 82 | track_indices=None, detection_indices=None): 83 | """Run matching cascade. 84 | 85 | Parameters 86 | ---------- 87 | distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray 88 | The distance metric is given a list of tracks and detections as well as 89 | a list of N track indices and M detection indices. The metric should 90 | return the NxM dimensional cost matrix, where element (i, j) is the 91 | association cost between the i-th track in the given track indices and 92 | the j-th detection in the given detection indices. 93 | max_distance : float 94 | Gating threshold. Associations with cost larger than this value are 95 | disregarded. 96 | cascade_depth: int 97 | The cascade depth, should be se to the maximum track age. 98 | tracks : List[track.Track] 99 | A list of predicted tracks at the current time step. 100 | detections : List[detection.Detection] 101 | A list of detections at the current time step. 102 | track_indices : Optional[List[int]] 103 | List of track indices that maps rows in `cost_matrix` to tracks in 104 | `tracks` (see description above). Defaults to all tracks. 105 | detection_indices : Optional[List[int]] 106 | List of detection indices that maps columns in `cost_matrix` to 107 | detections in `detections` (see description above). Defaults to all 108 | detections. 
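# The core of min_cost_matching below is scipy's Hungarian solver; a toy cost matrix shows how the
# max_distance gate turns a too-costly pairing into an unmatched track and detection:
import numpy as np
from scipy.optimize import linear_sum_assignment

max_distance = 0.7
cost = np.array([[0.1, 0.9],
                 [0.8, 0.95]])
cost[cost > max_distance] = max_distance + 1e-5
rows, cols = linear_sum_assignment(cost)
for r, c in zip(rows, cols):
    ok = cost[r, c] <= max_distance
    print(("match" if ok else "unmatched"), r, c)     # track 0 matches det 0; track 1 stays unmatched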
109 | 110 | Returns 111 | ------- 112 | (List[(int, int)], List[int], List[int]) 113 | Returns a tuple with the following three entries: 114 | * A list of matched track and detection indices. 115 | * A list of unmatched track indices. 116 | * A list of unmatched detection indices. 117 | 118 | """ 119 | if track_indices is None: 120 | track_indices = list(range(len(tracks))) 121 | if detection_indices is None: 122 | detection_indices = list(range(len(detections))) 123 | 124 | unmatched_detections = detection_indices 125 | matches = [] 126 | for level in range(cascade_depth): 127 | if len(unmatched_detections) == 0: # No detections left 128 | break 129 | 130 | track_indices_l = [ 131 | k for k in track_indices 132 | if tracks[k].time_since_update == 1 + level 133 | ] 134 | if len(track_indices_l) == 0: # Nothing to match at this level 135 | continue 136 | 137 | matches_l, _, unmatched_detections = \ 138 | min_cost_matching( 139 | distance_metric, max_distance, tracks, detections, 140 | track_indices_l, unmatched_detections) 141 | matches += matches_l 142 | unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches)) 143 | return matches, unmatched_tracks, unmatched_detections 144 | 145 | 146 | def gate_cost_matrix( 147 | kf, cost_matrix, tracks, detections, track_indices, detection_indices, 148 | gated_cost=INFTY_COST, only_position=False): 149 | """Invalidate infeasible entries in cost matrix based on the state 150 | distributions obtained by Kalman filtering. 151 | 152 | Parameters 153 | ---------- 154 | kf : The Kalman filter. 155 | cost_matrix : ndarray 156 | The NxM dimensional cost matrix, where N is the number of track indices 157 | and M is the number of detection indices, such that entry (i, j) is the 158 | association cost between `tracks[track_indices[i]]` and 159 | `detections[detection_indices[j]]`. 160 | tracks : List[track.Track] 161 | A list of predicted tracks at the current time step. 162 | detections : List[detection.Detection] 163 | A list of detections at the current time step. 164 | track_indices : List[int] 165 | List of track indices that maps rows in `cost_matrix` to tracks in 166 | `tracks` (see description above). 167 | detection_indices : List[int] 168 | List of detection indices that maps columns in `cost_matrix` to 169 | detections in `detections` (see description above). 170 | gated_cost : Optional[float] 171 | Entries in the cost matrix corresponding to infeasible associations are 172 | set this value. Defaults to a very large value. 173 | only_position : Optional[bool] 174 | If True, only the x, y position of the state distribution is considered 175 | during gating. Defaults to False. 176 | 177 | Returns 178 | ------- 179 | ndarray 180 | Returns the modified cost matrix. 
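# matching_cascade above tries tracks in order of staleness: level 0 only considers tracks with
# time_since_update == 1, level 1 those with == 2, and so on, so recently seen tracks get first pick
# of the detections. A toy view of the per-level filter (track ages are hypothetical):
time_since_update = {7: 1, 9: 1, 12: 3}               # track_id -> frames since last match
for level in range(3):
    eligible = [tid for tid, t in time_since_update.items() if t == 1 + level]
    print(level, eligible)                            # 0 -> [7, 9], 1 -> [], 2 -> [12]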
181 | 182 | """ 183 | gating_dim = 2 if only_position else 4 184 | gating_threshold = kalman_filter.chi2inv95[gating_dim] 185 | measurements = np.asarray( 186 | [detections[i].to_xyah() for i in detection_indices]) 187 | for row, track_idx in enumerate(track_indices): 188 | track = tracks[track_idx] 189 | gating_distance = kf.gating_distance( 190 | track.mean, track.covariance, measurements, only_position) 191 | cost_matrix[row, gating_distance > gating_threshold] = gated_cost 192 | return cost_matrix 193 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/sort/nn_matching.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | 4 | 5 | def _pdist(a, b): 6 | """Compute pair-wise squared distance between points in `a` and `b`. 7 | 8 | Parameters 9 | ---------- 10 | a : array_like 11 | An NxM matrix of N samples of dimensionality M. 12 | b : array_like 13 | An LxM matrix of L samples of dimensionality M. 14 | 15 | Returns 16 | ------- 17 | ndarray 18 | Returns a matrix of size len(a), len(b) such that eleement (i, j) 19 | contains the squared distance between `a[i]` and `b[j]`. 20 | 21 | """ 22 | a, b = np.asarray(a), np.asarray(b) 23 | if len(a) == 0 or len(b) == 0: 24 | return np.zeros((len(a), len(b))) 25 | a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1) 26 | r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :] 27 | r2 = np.clip(r2, 0., float(np.inf)) 28 | return r2 29 | 30 | 31 | def _cosine_distance(a, b, data_is_normalized=False): 32 | """Compute pair-wise cosine distance between points in `a` and `b`. 33 | 34 | Parameters 35 | ---------- 36 | a : array_like 37 | An NxM matrix of N samples of dimensionality M. 38 | b : array_like 39 | An LxM matrix of L samples of dimensionality M. 40 | data_is_normalized : Optional[bool] 41 | If True, assumes rows in a and b are unit length vectors. 42 | Otherwise, a and b are explicitly normalized to lenght 1. 43 | 44 | Returns 45 | ------- 46 | ndarray 47 | Returns a matrix of size len(a), len(b) such that eleement (i, j) 48 | contains the squared distance between `a[i]` and `b[j]`. 49 | 50 | """ 51 | if not data_is_normalized: 52 | a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True) 53 | b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True) 54 | return 1. - np.dot(a, b.T) 55 | 56 | 57 | def _nn_euclidean_distance(x, y): 58 | """ Helper function for nearest neighbor distance metric (Euclidean). 59 | 60 | Parameters 61 | ---------- 62 | x : ndarray 63 | A matrix of N row-vectors (sample points). 64 | y : ndarray 65 | A matrix of M row-vectors (query points). 66 | 67 | Returns 68 | ------- 69 | ndarray 70 | A vector of length M that contains for each entry in `y` the 71 | smallest Euclidean distance to a sample in `x`. 72 | 73 | """ 74 | distances = _pdist(x, y) 75 | return np.maximum(0.0, distances.min(axis=0)) 76 | 77 | 78 | def _nn_cosine_distance(x, y): 79 | """ Helper function for nearest neighbor distance metric (cosine). 80 | 81 | Parameters 82 | ---------- 83 | x : ndarray 84 | A matrix of N row-vectors (sample points). 85 | y : ndarray 86 | A matrix of M row-vectors (query points). 87 | 88 | Returns 89 | ------- 90 | ndarray 91 | A vector of length M that contains for each entry in `y` the 92 | smallest cosine distance to a sample in `x`. 
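# _cosine_distance above reduces to 1 - a.b for unit vectors: identical embeddings give 0,
# orthogonal ones give 1:
import numpy as np
from deep_sort.deep_sort.sort.nn_matching import _cosine_distance

a = np.array([[1.0, 0.0], [0.0, 1.0]])
b = np.array([[1.0, 0.0]])
print(_cosine_distance(a, b))                         # [[0.], [1.]]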
93 | 94 | """ 95 | distances = _cosine_distance(x, y) 96 | return distances.min(axis=0) 97 | 98 | 99 | class NearestNeighborDistanceMetric(object): 100 | """ 101 | A nearest neighbor distance metric that, for each target, returns 102 | the closest distance to any sample that has been observed so far. 103 | 104 | Parameters 105 | ---------- 106 | metric : str 107 | Either "euclidean" or "cosine". 108 | matching_threshold: float 109 | The matching threshold. Samples with larger distance are considered an 110 | invalid match. 111 | budget : Optional[int] 112 | If not None, fix samples per class to at most this number. Removes 113 | the oldest samples when the budget is reached. 114 | 115 | Attributes 116 | ---------- 117 | samples : Dict[int -> List[ndarray]] 118 | A dictionary that maps from target identities to the list of samples 119 | that have been observed so far. 120 | 121 | """ 122 | 123 | def __init__(self, metric, matching_threshold, budget=None): 124 | 125 | 126 | if metric == "euclidean": 127 | self._metric = _nn_euclidean_distance 128 | elif metric == "cosine": 129 | self._metric = _nn_cosine_distance 130 | else: 131 | raise ValueError( 132 | "Invalid metric; must be either 'euclidean' or 'cosine'") 133 | self.matching_threshold = matching_threshold 134 | self.budget = budget 135 | self.samples = {} 136 | 137 | def partial_fit(self, features, targets, active_targets): 138 | """Update the distance metric with new data. 139 | 140 | Parameters 141 | ---------- 142 | features : ndarray 143 | An NxM matrix of N features of dimensionality M. 144 | targets : ndarray 145 | An integer array of associated target identities. 146 | active_targets : List[int] 147 | A list of targets that are currently present in the scene. 148 | 149 | """ 150 | for feature, target in zip(features, targets): 151 | self.samples.setdefault(target, []).append(feature) 152 | if self.budget is not None: 153 | self.samples[target] = self.samples[target][-self.budget:] 154 | self.samples = {k: self.samples[k] for k in active_targets} 155 | 156 | def distance(self, features, targets): 157 | """Compute distance between features and targets. 158 | 159 | Parameters 160 | ---------- 161 | features : ndarray 162 | An NxM matrix of N features of dimensionality M. 163 | targets : List[int] 164 | A list of targets to match the given `features` against. 165 | 166 | Returns 167 | ------- 168 | ndarray 169 | Returns a cost matrix of shape len(targets), len(features), where 170 | element (i, j) contains the closest squared distance between 171 | `targets[i]` and `features[j]`. 172 | 173 | """ 174 | cost_matrix = np.zeros((len(targets), len(features))) 175 | for i, target in enumerate(targets): 176 | cost_matrix[i, :] = self._metric(self.samples[target], features) 177 | return cost_matrix 178 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/sort/preprocessing.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def non_max_suppression(boxes, max_bbox_overlap, scores=None): 7 | """Suppress overlapping detections. 8 | 9 | Original code from [1]_ has been adapted to include confidence score. 10 | 11 | .. 
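# The appearance gallery above in action: partial_fit stores per-track features (trimmed to
# `budget`), and distance returns the smallest cosine distance per (track, detection) pair:
import numpy as np
from deep_sort.deep_sort.sort.nn_matching import NearestNeighborDistanceMetric

metric = NearestNeighborDistanceMetric("cosine", matching_threshold=0.2, budget=100)
feats = np.array([[1.0, 0.0], [0.0, 1.0]])
metric.partial_fit(feats, targets=np.array([1, 2]), active_targets=[1, 2])
cost = metric.distance(np.array([[1.0, 0.0]]), targets=[1, 2])
print(cost)                                           # [[0.], [1.]] -> the detection matches track 1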
[1] http://www.pyimagesearch.com/2015/02/16/ 12 | faster-non-maximum-suppression-python/ 13 | 14 | Examples 15 | -------- 16 | 17 | >>> boxes = [d.roi for d in detections] 18 | >>> scores = [d.confidence for d in detections] 19 | >>> indices = non_max_suppression(boxes, max_bbox_overlap, scores) 20 | >>> detections = [detections[i] for i in indices] 21 | 22 | Parameters 23 | ---------- 24 | boxes : ndarray 25 | Array of ROIs (x, y, width, height). 26 | max_bbox_overlap : float 27 | ROIs that overlap more than this values are suppressed. 28 | scores : Optional[array_like] 29 | Detector confidence score. 30 | 31 | Returns 32 | ------- 33 | List[int] 34 | Returns indices of detections that have survived non-maxima suppression. 35 | 36 | """ 37 | if len(boxes) == 0: 38 | return [] 39 | 40 | boxes = boxes.astype(np.float) 41 | pick = [] 42 | 43 | x1 = boxes[:, 0] 44 | y1 = boxes[:, 1] 45 | x2 = boxes[:, 2] + boxes[:, 0] 46 | y2 = boxes[:, 3] + boxes[:, 1] 47 | 48 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 49 | if scores is not None: 50 | idxs = np.argsort(scores) 51 | else: 52 | idxs = np.argsort(y2) 53 | 54 | while len(idxs) > 0: 55 | last = len(idxs) - 1 56 | i = idxs[last] 57 | pick.append(i) 58 | 59 | xx1 = np.maximum(x1[i], x1[idxs[:last]]) 60 | yy1 = np.maximum(y1[i], y1[idxs[:last]]) 61 | xx2 = np.minimum(x2[i], x2[idxs[:last]]) 62 | yy2 = np.minimum(y2[i], y2[idxs[:last]]) 63 | 64 | w = np.maximum(0, xx2 - xx1 + 1) 65 | h = np.maximum(0, yy2 - yy1 + 1) 66 | 67 | overlap = (w * h) / area[idxs[:last]] 68 | 69 | idxs = np.delete( 70 | idxs, np.concatenate( 71 | ([last], np.where(overlap > max_bbox_overlap)[0]))) 72 | 73 | return pick 74 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/sort/track.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | 3 | 4 | class TrackState: 5 | """ 6 | Enumeration type for the single target track state. Newly created tracks are 7 | classified as `tentative` until enough evidence has been collected. Then, 8 | the track state is changed to `confirmed`. Tracks that are no longer alive 9 | are classified as `deleted` to mark them for removal from the set of active 10 | tracks. 11 | 12 | """ 13 | 14 | Tentative = 1 15 | Confirmed = 2 16 | Deleted = 3 17 | 18 | 19 | class Track: 20 | """ 21 | A single target track with state space `(x, y, a, h)` and associated 22 | velocities, where `(x, y)` is the center of the bounding box, `a` is the 23 | aspect ratio and `h` is the height. 24 | 25 | Parameters 26 | ---------- 27 | mean : ndarray 28 | Mean vector of the initial state distribution. 29 | covariance : ndarray 30 | Covariance matrix of the initial state distribution. 31 | track_id : int 32 | A unique track identifier. 33 | n_init : int 34 | Number of consecutive detections before the track is confirmed. The 35 | track state is set to `Deleted` if a miss occurs within the first 36 | `n_init` frames. 37 | max_age : int 38 | The maximum number of consecutive misses before the track state is 39 | set to `Deleted`. 40 | feature : Optional[ndarray] 41 | Feature vector of the detection this track originates from. If not None, 42 | this feature is added to the `features` cache. 43 | 44 | Attributes 45 | ---------- 46 | mean : ndarray 47 | Mean vector of the initial state distribution. 48 | covariance : ndarray 49 | Covariance matrix of the initial state distribution. 50 | track_id : int 51 | A unique track identifier. 
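# non_max_suppression above works on (x, y, w, h) ROIs and keeps the highest-scoring box of each
# overlapping group; note that the np.float cast inside it requires numpy < 1.24:
import numpy as np
from deep_sort.deep_sort.sort.preprocessing import non_max_suppression

boxes = np.array([[0, 0, 10, 10],
                  [1, 1, 10, 10],      # heavy overlap with the first box
                  [50, 50, 10, 10]])   # disjoint box
scores = np.array([0.9, 0.8, 0.7])
print(non_max_suppression(boxes, max_bbox_overlap=0.5, scores=scores))   # [0, 2]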
52 | hits : int 53 | Total number of measurement updates. 54 | age : int 55 | Total number of frames since first occurance. 56 | time_since_update : int 57 | Total number of frames since last measurement update. 58 | state : TrackState 59 | The current track state. 60 | features : List[ndarray] 61 | A cache of features. On each measurement update, the associated feature 62 | vector is added to this list. 63 | 64 | """ 65 | 66 | def __init__(self, mean, covariance, track_id, n_init, max_age, 67 | feature=None,label=None): 68 | self.mean = mean 69 | self.covariance = covariance 70 | self.track_id = track_id 71 | self.hits = 1 72 | self.age = 1 73 | self.time_since_update = 0 74 | 75 | self.state = TrackState.Tentative 76 | self.features = [] 77 | self.label=label if label is not None else -1 78 | if feature is not None: 79 | self.features.append(feature) 80 | 81 | self._n_init = n_init 82 | self._max_age = max_age 83 | 84 | def to_tlwh(self): 85 | """Get current position in bounding box format `(top left x, top left y, 86 | width, height)`. 87 | 88 | Returns 89 | ------- 90 | ndarray 91 | The bounding box. 92 | 93 | """ 94 | ret = self.mean[:4].copy() 95 | ret[2] *= ret[3] 96 | ret[:2] -= ret[2:] / 2 97 | return ret 98 | 99 | def to_tlbr(self): 100 | """Get current position in bounding box format `(min x, miny, max x, 101 | max y)`. 102 | 103 | Returns 104 | ------- 105 | ndarray 106 | The bounding box. 107 | 108 | """ 109 | ret = self.to_tlwh() 110 | ret[2:] = ret[:2] + ret[2:] 111 | return ret 112 | 113 | def predict(self, kf): 114 | """Propagate the state distribution to the current time step using a 115 | Kalman filter prediction step. 116 | 117 | Parameters 118 | ---------- 119 | kf : kalman_filter.KalmanFilter 120 | The Kalman filter. 121 | 122 | """ 123 | self.mean, self.covariance = kf.predict(self.mean, self.covariance) 124 | self.age += 1 125 | self.time_since_update += 1 126 | 127 | def update(self, kf, detection): 128 | """Perform Kalman filter measurement update step and update the feature 129 | cache. 130 | 131 | Parameters 132 | ---------- 133 | kf : kalman_filter.KalmanFilter 134 | The Kalman filter. 135 | detection : Detection 136 | The associated detection. 137 | 138 | """ 139 | self.mean, self.covariance = kf.update( 140 | self.mean, self.covariance, detection.to_xyah()) 141 | self.features.append(detection.feature) 142 | self.label=detection.label 143 | self.hits += 1 144 | self.time_since_update = 0 145 | if self.state == TrackState.Tentative and self.hits >= self._n_init: 146 | self.state = TrackState.Confirmed 147 | 148 | def mark_missed(self): 149 | """Mark this track as missed (no association at the current time step). 150 | """ 151 | if self.state == TrackState.Tentative: 152 | self.state = TrackState.Deleted 153 | elif self.time_since_update > self._max_age: 154 | self.state = TrackState.Deleted 155 | 156 | def is_tentative(self): 157 | """Returns True if this track is tentative (unconfirmed). 
158 | """ 159 | return self.state == TrackState.Tentative 160 | 161 | def is_confirmed(self): 162 | """Returns True if this track is confirmed.""" 163 | return self.state == TrackState.Confirmed 164 | 165 | def is_deleted(self): 166 | """Returns True if this track is dead and should be deleted.""" 167 | return self.state == TrackState.Deleted 168 | -------------------------------------------------------------------------------- /deep_sort/deep_sort/sort/tracker.py: -------------------------------------------------------------------------------- 1 | # vim: expandtab:ts=4:sw=4 2 | from __future__ import absolute_import 3 | import numpy as np 4 | from . import kalman_filter 5 | from . import linear_assignment 6 | from . import iou_matching 7 | from .track import Track 8 | 9 | 10 | class Tracker: 11 | """ 12 | This is the multi-target tracker. 13 | 14 | Parameters 15 | ---------- 16 | metric : nn_matching.NearestNeighborDistanceMetric 17 | A distance metric for measurement-to-track association. 18 | max_age : int 19 | Maximum number of missed misses before a track is deleted. 20 | n_init : int 21 | Number of consecutive detections before the track is confirmed. The 22 | track state is set to `Deleted` if a miss occurs within the first 23 | `n_init` frames. 24 | 25 | Attributes 26 | ---------- 27 | metric : nn_matching.NearestNeighborDistanceMetric 28 | The distance metric used for measurement to track association. 29 | max_age : int 30 | Maximum number of missed misses before a track is deleted. 31 | n_init : int 32 | Number of frames that a track remains in initialization phase. 33 | kf : kalman_filter.KalmanFilter 34 | A Kalman filter to filter target trajectories in image space. 35 | tracks : List[Track] 36 | The list of active tracks at the current time step. 37 | 38 | """ 39 | 40 | def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3): 41 | self.metric = metric 42 | self.max_iou_distance = max_iou_distance 43 | self.max_age = max_age 44 | self.n_init = n_init 45 | 46 | self.kf = kalman_filter.KalmanFilter() 47 | self.tracks = [] 48 | self._next_id = 1 49 | 50 | def predict(self): 51 | """Propagate track state distributions one time step forward. 52 | 53 | This function should be called once every time step, before `update`. 54 | """ 55 | for track in self.tracks: 56 | track.predict(self.kf) 57 | 58 | def update(self, detections): 59 | """Perform measurement update and track management. 60 | 61 | Parameters 62 | ---------- 63 | detections : List[deep_sort.detection.Detection] 64 | A list of detections at the current time step. 65 | 66 | """ 67 | # Run matching cascade. 68 | matches, unmatched_tracks, unmatched_detections = \ 69 | self._match(detections) 70 | 71 | # Update track set. 72 | for track_idx, detection_idx in matches: 73 | self.tracks[track_idx].update( 74 | self.kf, detections[detection_idx]) 75 | for track_idx in unmatched_tracks: 76 | self.tracks[track_idx].mark_missed() 77 | for detection_idx in unmatched_detections: 78 | self._initiate_track(detections[detection_idx]) 79 | self.tracks = [t for t in self.tracks if not t.is_deleted()] 80 | 81 | # Update distance metric. 
82 | active_targets = [t.track_id for t in self.tracks if t.is_confirmed()] 83 | features, targets = [], [] 84 | for track in self.tracks: 85 | if not track.is_confirmed(): 86 | continue 87 | features += track.features 88 | targets += [track.track_id for _ in track.features] 89 | track.features = [] 90 | self.metric.partial_fit( 91 | np.asarray(features), np.asarray(targets), active_targets) 92 | 93 | def _match(self, detections): 94 | 95 | def gated_metric(tracks, dets, track_indices, detection_indices): 96 | features = np.array([dets[i].feature for i in detection_indices]) 97 | targets = np.array([tracks[i].track_id for i in track_indices]) 98 | cost_matrix = self.metric.distance(features, targets) 99 | cost_matrix = linear_assignment.gate_cost_matrix( 100 | self.kf, cost_matrix, tracks, dets, track_indices, 101 | detection_indices) 102 | 103 | return cost_matrix 104 | 105 | # Split track set into confirmed and unconfirmed tracks. 106 | confirmed_tracks = [ 107 | i for i, t in enumerate(self.tracks) if t.is_confirmed()] 108 | unconfirmed_tracks = [ 109 | i for i, t in enumerate(self.tracks) if not t.is_confirmed()] 110 | 111 | # Associate confirmed tracks using appearance features. 112 | matches_a, unmatched_tracks_a, unmatched_detections = \ 113 | linear_assignment.matching_cascade( 114 | gated_metric, self.metric.matching_threshold, self.max_age, 115 | self.tracks, detections, confirmed_tracks) 116 | 117 | # Associate remaining tracks together with unconfirmed tracks using IOU. 118 | iou_track_candidates = unconfirmed_tracks + [ 119 | k for k in unmatched_tracks_a if 120 | self.tracks[k].time_since_update == 1] 121 | unmatched_tracks_a = [ 122 | k for k in unmatched_tracks_a if 123 | self.tracks[k].time_since_update != 1] 124 | matches_b, unmatched_tracks_b, unmatched_detections = \ 125 | linear_assignment.min_cost_matching( 126 | iou_matching.iou_cost, self.max_iou_distance, self.tracks, 127 | detections, iou_track_candidates, unmatched_detections) 128 | 129 | matches = matches_a + matches_b 130 | unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b)) 131 | return matches, unmatched_tracks, unmatched_detections 132 | 133 | def _initiate_track(self, detection): 134 | mean, covariance = self.kf.initiate(detection.to_xyah()) 135 | self.tracks.append(Track( 136 | mean, covariance, self._next_id, self.n_init, self.max_age, 137 | detection.feature,detection.label)) 138 | self._next_id += 1 139 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Yolov5+SlowFast+deepsort: Action Detection(PytorchVideo) 2 | 3 | 4 | ### A realtime action detection frame work based on PytorchVideo. 5 | 6 | #### Here are some details about our modification: 7 | 8 | - we choose yolov5 as an object detector instead of Faster R-CNN, it is faster and more convenient 9 | - we use a tracker(deepsort) to allocate action labels to all objects(with same ids) in different frames 10 | - our processing speed reached 24.2 FPS at 30 inference batch size (on a single RTX 2080Ti GPU) 11 | 12 | > 参考: 13 | > 14 | > [FAIR/PytorchVideo](https://github.com/facebookresearch/pytorchvideo); 15 | > 16 | > [Ultralytics/Yolov5](https://github.com/ultralytics/yolov5) 17 | > 18 | > [yolo_slowfast](https://github.com/wufan-tb/yolo_slowfast) 19 | 20 | 21 | #### Demo comparison between original and ours. 
22 | 23 | 24 | 25 | ![image](https://img-blog.csdnimg.cn/92e00516f2984dfcb3ba4888fddde9dd.gif) 26 | ![image](https://img-blog.csdnimg.cn/c01edc763a744b9d8114b3973a4d0385.gif) 27 | 28 | 29 | ## Installation 30 | 31 | 使用AI平台:[https://cloud.videojj.com/auth/register?inviter=18452&activityChannel=student_invite](https://cloud.videojj.com/auth/register?inviter=18452&activityChannel=student_invite) 32 | 33 | 0. environment 环境 34 | ``` 35 | Pytorch 1.10.1 36 | Python 3.8 37 | Cuda 11.1 38 | ``` 39 | 1. 安装PytorchVideo: 40 | ``` 41 | cd /home 42 | git clone https://gitee.com/YFwinston/pytorchvideo.git 43 | cd pytorchvideo 44 | pip install -e . 45 | ``` 46 | 47 | ``` 48 | apt update 49 | apt install libgl1-mesa-glx 50 | 51 | ``` 52 | 53 | 3. clone this repo: 54 | 55 | 使用github 56 | ``` 57 | cd /home 58 | git clone https://github.com/Whiffe/yolov5-slowfast-deepsort-PytorchVideo.git 59 | ``` 60 | 61 | 或者使用gitee 62 | 63 | ``` 64 | cd /home 65 | git clone https://gitee.com/YFwinston/yolov5-slowfast-deepsort-PytorchVideo.git 66 | ``` 67 | 68 | 69 | 2. create a new python environment (optional 可选): 70 | 71 | ``` 72 | conda create -n {your_env_name} python=3.8.12 73 | conda activate {your_env_name} 74 | ``` 75 | 76 | 3. install requiments: 77 | 78 | ``` 79 | cd /home/yolov5-slowfast-deepsort-PytorchVideo 80 | pip install -r requirements2.txt 81 | ``` 82 | 83 | 4. download weights file(ckpt.t7) from [[yolov5_file]](https://share.weiyun.com/xCgma1LG) to this folder: 84 | 85 | ``` 86 | ./deep_sort/deep_sort/deep/checkpoint/ 87 | ``` 88 | 89 | 我是将ckpt.t7放在了:/user-data/yolov5_file/ 90 | 91 | 所以执行: 92 | 93 | ``` 94 | mkdir -p /home/yolov5-slowfast-deepsort-PytorchVideo/deep_sort/deep_sort/deep/checkpoint/ 95 | cp /user-data/yolov5_file/ckpt.t7 /home/yolov5-slowfast-deepsort-PytorchVideo/deep_sort/deep_sort/deep/checkpoint/ckpt.t7 96 | ``` 97 | 5. download file(SLOWFAST_8x8_R50_DETECTION.pyth) from [[slowfast_file]](https://share.weiyun.com/EUi4NvnM) to this folder: 98 | 99 | 我是将SLOWFAST_8x8_R50_DETECTION.pyth放在了:/user-data/slowfast_file/ 100 | 101 | 所以执行: 102 | ``` 103 | mkdir -p /root/.cache/torch/hub/checkpoints/ 104 | cp /user-data/slowfast_file/SLOWFAST_8x8_R50_DETECTION.pyth /root/.cache/torch/hub/checkpoints/SLOWFAST_8x8_R50_DETECTION.pyth 105 | ``` 106 | 107 | 6. download file(yolov5l6.pt) from [[yolov5_file]](https://share.weiyun.com/xCgma1LG) to this folder: 108 | 109 | 我是将yolov5l6.pt放在了:/user-data/yolov5_file/ 110 | 111 | 所以执行: 112 | ``` 113 | cp /user-data/yolov5_file/yolov5l6.pt /home/yolov5-slowfast-deepsort-PytorchVideo/yolov5l6.pt 114 | ``` 115 | 7. download file(master.zip) from [[yolov5_file]](https://share.weiyun.com/xCgma1LG) to this folder: 116 | 117 | 我是将yolov5-master.zip放在了:/user-data/yolov5_file/ 118 | 119 | 所以执行: 120 | ``` 121 | cp /user-data/yolov5_file/yolov5-master.zip /root/.cache/torch/hub/master.zip 122 | ``` 123 | 124 | 125 | 8. test on your video: 126 | 127 | 128 | ``` 129 | python yolo_slowfast.py --input {path to your video} 130 | ``` 131 | 132 | 我将1.mp4存放在了/home/yolov5-slowfast-deepsort-PytorchVideo/demo/中 133 | 134 | 所以执行: 135 | 136 | 137 | ``` 138 | cd /home/yolov5-slowfast-deepsort-PytorchVideo 139 | mkdir demo 140 | ``` 141 | 142 | ``` 143 | cd /home/yolov5-slowfast-deepsort-PytorchVideo 144 | python yolo_slowfast.py --input ./demo/1.mp4 145 | ``` 146 | 147 | The first time execute this command may take some times to download the yolov5 code and it's weights file from torch.hub, keep your network connection. 
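The bundled DeepSort wrapper can also be driven directly from your own detector. The sketch below mirrors how `yolo_slowfast.py` uses it (construct `DeepSort` with the ReID checkpoint from step 4, then call `update` once per frame with center-format boxes, confidences, class labels and the RGB frame). Treat it as a minimal, untested sketch: `boxes_xywh`, `confs` and `labels` are placeholders you must fill from a real detector.

```
import cv2
import numpy as np
from deep_sort.deep_sort import DeepSort

# Same ReID checkpoint as in step 4 above.
tracker = DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")

cap = cv2.VideoCapture("./demo/1.mp4")
while True:
    ok, frame_bgr = cap.read()
    if not ok:
        break
    # Placeholder detections -- replace with your own detector's output:
    # boxes_xywh: (N, 4) array of box centers and sizes (cx, cy, w, h)
    # confs:      (N, 1) array of confidence scores
    # labels:     list of N class ids
    boxes_xywh = np.empty((0, 4), dtype=np.float32)
    confs = np.empty((0, 1), dtype=np.float32)
    labels = []
    outputs = tracker.update(boxes_xywh, confs, labels,
                             cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    # yolo_slowfast.py unpacks each output row as (*box, cls, track_id, vx, vy).
cap.release()
```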
148 | 149 | ## References 150 | 151 | Thanks for these great works: 152 | 153 | [1] [Ultralytics/Yolov5](https://github.com/ultralytics/yolov5) 154 | 155 | [2] [ZQPei/deepsort](https://github.com/ZQPei/deep_sort_pytorch) 156 | 157 | [3] [FAIR/PytorchVideo](https://github.com/facebookresearch/pytorchvideo) 158 | 159 | [4] AVA: A Video Dataset of Spatio-temporally Localized Atomic Visual Actions. [paper](https://arxiv.org/pdf/1705.08421.pdf) 160 | 161 | [5] SlowFast Networks for Video Recognition. [paper](https://arxiv.org/pdf/1812.03982.pdf) 162 | 163 | ## Citation 164 | 165 | If you find our work useful, please cite as follow: 166 | 167 | ``` 168 | { yolo_slowfast, 169 | author = {Wu Fan}, 170 | title = { A realtime action detection frame work based on PytorchVideo}, 171 | year = {2021}, 172 | url = {\url{https://github.com/wufan-tb/yolo_slowfast}} 173 | } 174 | ``` 175 | 176 | ### Stargazers over time 177 | 178 | 179 | ## Stargazers over time 180 | 181 | [![Stargazers over time](https://starchart.cc/Whiffe/yolov5-slowfast-deepsort-PytorchVideo.svg)](https://starchart.cc/Whiffe/yolov5-slowfast-deepsort-PytorchVideo) 182 | 183 | 184 | 185 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.5.3.56 2 | torch==1.8.0 3 | torchvision==0.9.0 4 | natsort 5 | -------------------------------------------------------------------------------- /requirements2.txt: -------------------------------------------------------------------------------- 1 | natsort 2 | scipy 3 | pandas 4 | matplotlib 5 | seaborn 6 | -------------------------------------------------------------------------------- /selfutils/ava_action_list.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | name: "bend/bow (at the waist)" 3 | id: 1 4 | } 5 | item { 6 | name: "crawl" 7 | id: 2 8 | } 9 | item { 10 | name: "crouch/kneel" 11 | id: 3 12 | } 13 | item { 14 | name: "dance" 15 | id: 4 16 | } 17 | item { 18 | name: "fall down" 19 | id: 5 20 | } 21 | item { 22 | name: "get up" 23 | id: 6 24 | } 25 | item { 26 | name: "jump/leap" 27 | id: 7 28 | } 29 | item { 30 | name: "lie/sleep" 31 | id: 8 32 | } 33 | item { 34 | name: "martial art" 35 | id: 9 36 | } 37 | item { 38 | name: "run/jog" 39 | id: 10 40 | } 41 | item { 42 | name: "sit" 43 | id: 11 44 | } 45 | item { 46 | name: "stand" 47 | id: 12 48 | } 49 | item { 50 | name: "swim" 51 | id: 13 52 | } 53 | item { 54 | name: "walk" 55 | id: 14 56 | } 57 | item { 58 | name: "answer phone" 59 | id: 15 60 | } 61 | item { 62 | name: "brush teeth" 63 | id: 16 64 | } 65 | item { 66 | name: "carry/hold (an object)" 67 | id: 17 68 | } 69 | item { 70 | name: "catch (an object)" 71 | id: 18 72 | } 73 | item { 74 | name: "chop" 75 | id: 19 76 | } 77 | item { 78 | name: "climb (e.g., a mountain)" 79 | id: 20 80 | } 81 | item { 82 | name: "clink glass" 83 | id: 21 84 | } 85 | item { 86 | name: "close (e.g., a door, a box)" 87 | id: 22 88 | } 89 | item { 90 | name: "cook" 91 | id: 23 92 | } 93 | item { 94 | name: "cut" 95 | id: 24 96 | } 97 | item { 98 | name: "dig" 99 | id: 25 100 | } 101 | item { 102 | name: "dress/put on clothing" 103 | id: 26 104 | } 105 | item { 106 | name: "drink" 107 | id: 27 108 | } 109 | item { 110 | name: "drive (e.g., a car, a truck)" 111 | id: 28 112 | } 113 | item { 114 | name: "eat" 115 | id: 29 116 | } 117 | item { 118 | name: "enter" 119 | id: 30 120 | } 121 | item { 122 | name: 
"exit" 123 | id: 31 124 | } 125 | item { 126 | name: "extract" 127 | id: 32 128 | } 129 | item { 130 | name: "fishing" 131 | id: 33 132 | } 133 | item { 134 | name: "hit (an object)" 135 | id: 34 136 | } 137 | item { 138 | name: "kick (an object)" 139 | id: 35 140 | } 141 | item { 142 | name: "lift/pick up" 143 | id: 36 144 | } 145 | item { 146 | name: "listen (e.g., to music)" 147 | id: 37 148 | } 149 | item { 150 | name: "open (e.g., a window, a car door)" 151 | id: 38 152 | } 153 | item { 154 | name: "paint" 155 | id: 39 156 | } 157 | item { 158 | name: "play board game" 159 | id: 40 160 | } 161 | item { 162 | name: "play musical instrument" 163 | id: 41 164 | } 165 | item { 166 | name: "play with pets" 167 | id: 42 168 | } 169 | item { 170 | name: "point to (an object)" 171 | id: 43 172 | } 173 | item { 174 | name: "press" 175 | id: 44 176 | } 177 | item { 178 | name: "pull (an object)" 179 | id: 45 180 | } 181 | item { 182 | name: "push (an object)" 183 | id: 46 184 | } 185 | item { 186 | name: "put down" 187 | id: 47 188 | } 189 | item { 190 | name: "read" 191 | id: 48 192 | } 193 | item { 194 | name: "ride (e.g., a bike, a car, a horse)" 195 | id: 49 196 | } 197 | item { 198 | name: "row boat" 199 | id: 50 200 | } 201 | item { 202 | name: "sail boat" 203 | id: 51 204 | } 205 | item { 206 | name: "shoot" 207 | id: 52 208 | } 209 | item { 210 | name: "shovel" 211 | id: 53 212 | } 213 | item { 214 | name: "smoke" 215 | id: 54 216 | } 217 | item { 218 | name: "stir" 219 | id: 55 220 | } 221 | item { 222 | name: "take a photo" 223 | id: 56 224 | } 225 | item { 226 | name: "text on/look at a cellphone" 227 | id: 57 228 | } 229 | item { 230 | name: "throw" 231 | id: 58 232 | } 233 | item { 234 | name: "touch (an object)" 235 | id: 59 236 | } 237 | item { 238 | name: "turn (e.g., a screwdriver)" 239 | id: 60 240 | } 241 | item { 242 | name: "watch (e.g., TV)" 243 | id: 61 244 | } 245 | item { 246 | name: "work on a computer" 247 | id: 62 248 | } 249 | item { 250 | name: "write" 251 | id: 63 252 | } 253 | item { 254 | name: "fight/hit (a person)" 255 | id: 64 256 | } 257 | item { 258 | name: "give/serve (an object) to (a person)" 259 | id: 65 260 | } 261 | item { 262 | name: "grab (a person)" 263 | id: 66 264 | } 265 | item { 266 | name: "hand clap" 267 | id: 67 268 | } 269 | item { 270 | name: "hand shake" 271 | id: 68 272 | } 273 | item { 274 | name: "hand wave" 275 | id: 69 276 | } 277 | item { 278 | name: "hug (a person)" 279 | id: 70 280 | } 281 | item { 282 | name: "kick (a person)" 283 | id: 71 284 | } 285 | item { 286 | name: "kiss (a person)" 287 | id: 72 288 | } 289 | item { 290 | name: "lift (a person)" 291 | id: 73 292 | } 293 | item { 294 | name: "listen to (a person)" 295 | id: 74 296 | } 297 | item { 298 | name: "play with kids" 299 | id: 75 300 | } 301 | item { 302 | name: "push (another person)" 303 | id: 76 304 | } 305 | item { 306 | name: "sing to (e.g., self, a person, a group)" 307 | id: 77 308 | } 309 | item { 310 | name: "take (an object) from (a person)" 311 | id: 78 312 | } 313 | item { 314 | name: "talk to (e.g., self, a person, a group)" 315 | id: 79 316 | } 317 | item { 318 | name: "watch (a person)" 319 | id: 80 320 | } 321 | -------------------------------------------------------------------------------- /selfutils/coco_names.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | 
parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /selfutils/slowfast_detection.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore",category=UserWarning) 3 | 4 | import os,sys,torch,cv2,pytorchvideo,time 5 | 6 | from functools import partial 7 | import numpy as np 8 | 9 | from detectron2.config import get_cfg 10 | from detectron2 import model_zoo 11 | from detectron2.engine import DefaultPredictor 12 | 13 | import pytorchvideo 14 | from pytorchvideo.transforms.functional import ( 15 | uniform_temporal_subsample, 16 | short_side_scale_with_boxes, 17 | clip_boxes_to_image, 18 | ) 19 | from torchvision.transforms._functional_video import normalize 20 | from pytorchvideo.data.ava import AvaLabeledVideoFramePaths 21 | from pytorchvideo.models.hub import slowfast_r50_detection # Another option is slowfast_r50_detection, slow_r50_detection 22 | 23 | from visualization import VideoVisualizer 24 | 25 | # This method takes in an image and generates the bounding boxes for people in the image. 26 | def get_person_bboxes(inp_img, predictor): 27 | with torch.no_grad(): 28 | predictions = predictor(inp_img.cpu().detach().numpy())['instances'].to('cpu') 29 | boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None 30 | scores = predictions.scores if predictions.has("scores") else None 31 | classes = np.array(predictions.pred_classes.tolist() if predictions.has("pred_classes") else None) 32 | # print(predictions._fields.keys()) 33 | # print("ROI info:",boxes.tensor.shape,scores.shape,predictions.pred_classes.shape) 34 | # for i in range(predictions.pred_classes.shape[0]): 35 | # conf,pred=predictions.onehot_labels[i].max(-1) 36 | # print(predictions.pred_classes[i],conf,pred) 37 | predicted_boxes = boxes[np.logical_and(classes!=-1, scores>0.5 )].tensor.cpu() # only person 38 | return predicted_boxes 39 | 40 | # ## Define the transformations for the input required by the model 41 | def ava_inference_transform( 42 | clip, 43 | boxes, 44 | num_frames = 32, #if using slowfast_r50_detection, change this to 32, 4 for slow 45 | crop_size = 640, 46 | data_mean = [0.45, 0.45, 0.45], 47 | data_std = [0.225, 0.225, 0.225], 48 | slow_fast_alpha = 4, #if using slowfast_r50_detection, change this to 4, None for slow 49 | ): 50 | 51 | boxes = np.array(boxes) 52 | roi_boxes = boxes.copy() 53 | 54 | # Image [0, 255] -> [0, 1]. 55 | clip = uniform_temporal_subsample(clip, num_frames) 56 | clip = clip.float() 57 | clip = clip / 255.0 58 | 59 | height, width = clip.shape[2], clip.shape[3] 60 | # The format of boxes is [x1, y1, x2, y2]. 
The input boxes are in the 61 | # range of [0, width] for x and [0,height] for y 62 | boxes = clip_boxes_to_image(boxes, height, width) 63 | 64 | # Resize short side to crop_size. Non-local and STRG uses 256. 65 | clip, boxes = short_side_scale_with_boxes( 66 | clip, 67 | size=crop_size, 68 | boxes=boxes, 69 | ) 70 | 71 | # Normalize images by mean and std. 72 | clip = normalize( 73 | clip, 74 | np.array(data_mean, dtype=np.float32), 75 | np.array(data_std, dtype=np.float32), 76 | ) 77 | 78 | boxes = clip_boxes_to_image( 79 | boxes, clip.shape[2], clip.shape[3] 80 | ) 81 | 82 | # Incase of slowfast, generate both pathways 83 | if slow_fast_alpha is not None: 84 | fast_pathway = clip 85 | # Perform temporal sampling from the fast pathway. 86 | slow_pathway = torch.index_select( 87 | clip, 88 | 1, 89 | torch.linspace( 90 | 0, clip.shape[1] - 1, clip.shape[1] // slow_fast_alpha 91 | ).long(), 92 | ) 93 | clip = [slow_pathway, fast_pathway] 94 | 95 | return clip, torch.from_numpy(boxes), roi_boxes 96 | 97 | def main(args): 98 | # ## load slow faster model 99 | device = args.device # or 'cpu' 100 | video_model = slowfast_r50_detection(True) # Another option is slowfast_r50_detection 101 | video_model = video_model.eval().to(device) 102 | 103 | # ## Load an off-the-shelf Detectron2 object detector 104 | cfg = get_cfg() 105 | cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")) 106 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.threshold # set threshold for this model 107 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml") 108 | predictor = DefaultPredictor(cfg) 109 | 110 | # Create an id to label name mapping 111 | label_map, allowed_class_ids = AvaLabeledVideoFramePaths.read_label_map('selfutils/ava_action_list.pbtxt') 112 | # Create a video visualizer that can plot bounding boxes and visualize actions on bboxes. 113 | video_visualizer = VideoVisualizer(81, label_map, top_k=3, mode="thres",thres=0.7) 114 | 115 | # Load the video 116 | encoded_vid = pytorchvideo.data.encoded_video.EncodedVideo.from_path(args.input) 117 | print('Completed loading encoded video.') 118 | 119 | # Video predictions are generated at an internal of 1 sec from 90 seconds to 100 seconds in the video. 120 | time_stamp_range = range(0,int(encoded_vid.duration//1),1) # time stamps in video for which clip is sampled. 121 | clip_duration = 1 # Duration of clip used for each inference step. 122 | gif_imgs = [] 123 | a=time.time() 124 | for time_stamp in time_stamp_range: 125 | print("processing for {}th sec".format(time_stamp)) 126 | 127 | # Generate clip around the designated time stamps 128 | inp_imgs = encoded_vid.get_clip( 129 | time_stamp , # start second 130 | time_stamp + clip_duration # end second 131 | ) 132 | inp_imgs = inp_imgs['video'] 133 | # print("clips shape for slowfaster:",inp_imgs.shape) 134 | # Generate people bbox predictions using Detectron2's off the self pre-trained predictor 135 | # We use the the middle image in each clip to generate the bounding boxes. 
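# The clip tensor returned by EncodedVideo.get_clip() is channel-first (C, T, H, W);
# indexing the time axis at its midpoint selects the keyframe, and the permute
# below rearranges it to (H, W, C) before handing it to the Detectron2 predictor.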
136 | inp_img = inp_imgs[:,inp_imgs.shape[1]//2,:,:] 137 | inp_img = inp_img.permute(1,2,0) 138 | # print("img shape for faster rcnn:",inp_img.shape) 139 | 140 | # Predicted boxes are of the form List[(x_1, y_1, x_2, y_2)] 141 | predicted_boxes = get_person_bboxes(inp_img, predictor) 142 | # print("ROI boxes (only person):",predicted_boxes) 143 | if len(predicted_boxes) == 0: 144 | print("no detected at time stamp: ", time_stamp) 145 | continue 146 | 147 | # Preprocess clip and bounding boxes for video action recognition. 148 | # print(predicted_boxes) 149 | inputs, inp_boxes, _ = ava_inference_transform(inp_imgs, predicted_boxes.numpy(), crop_size=args.imsize) 150 | # Prepend data sample id for each bounding box. 151 | # For more details refere to the RoIAlign in Detectron2 152 | inp_boxes = torch.cat([torch.zeros(inp_boxes.shape[0],1), inp_boxes], dim=1) 153 | 154 | # Generate actions predictions for the bounding boxes in the clip. 155 | # The model here takes in the pre-processed video clip and the detected bounding boxes. 156 | if isinstance(inputs, list): 157 | inputs = [inp.unsqueeze(0).to(device) for inp in inputs] 158 | else: 159 | inputs = inputs.unsqueeze(0).to(device) 160 | # print("slowfaster's inputs shape:",len(inputs),inputs[0].shape,inputs[1].shape) 161 | with torch.no_grad(): 162 | preds = video_model(inputs, inp_boxes.to(device)) 163 | 164 | preds = preds.to('cpu') 165 | # The model is trained on AVA and AVA labels are 1 indexed so, prepend 0 to convert to 0 index. 166 | preds = torch.cat([torch.zeros(preds.shape[0],1), preds], dim=1) 167 | 168 | # Plot predictions on the video and save for later visualization. 169 | inp_imgs = inp_imgs.permute(1,2,3,0) 170 | 171 | inp_imgs = inp_imgs/255.0 172 | # print("pred shapes:",preds.shape,predicted_boxes.shape) 173 | out_img_pred = video_visualizer.draw_clip_range(inp_imgs, preds, predicted_boxes,repeat_frame=1) 174 | gif_imgs += out_img_pred 175 | 176 | print("Finished generating predictions.") 177 | print("total cost: {:.3f}s, video clips length: {}s".format(time.time()-a,len(time_stamp_range))) 178 | 179 | # ## Save predictions as video 180 | height, width = gif_imgs[0].shape[0], gif_imgs[0].shape[1] 181 | 182 | vide_save_path = os.path.join(args.output,'output.mp4') 183 | video = cv2.VideoWriter(vide_save_path,cv2.VideoWriter_fourcc(*'mp4v'), 25, (width,height)) 184 | 185 | for image in gif_imgs: 186 | img = (255*image).astype(np.uint8) 187 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 188 | video.write(img) 189 | video.release() 190 | 191 | print('video saved to:', vide_save_path) 192 | 193 | 194 | if __name__=="__main__": 195 | import argparse 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument('--input', type=str, default='/data/VAD/SHTech/testing/videos/01_0015.mp4') 198 | parser.add_argument('--output', type=str, default='videos') 199 | parser.add_argument('--device', default='cuda', help='cuda or cpu') 200 | parser.add_argument('--threshold', type=float, default=0.7) 201 | parser.add_argument('--imsize', type=int, default=384) 202 | args = parser.parse_args() 203 | 204 | main(args) 205 | 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /selfutils/temp.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | name: "bend/bow (at the waist)" 3 | id: 1 4 | } 5 | item { 6 | name: "crawl" 7 | id: 2 8 | } 9 | item { 10 | name: "crouch/kneel" 11 | id: 3 12 | } 13 | item { 14 | name: "dance" 15 | id: 4 16 | } 17 | 
item { 18 | name: "fall down" 19 | id: 5 20 | } 21 | item { 22 | name: "get up" 23 | id: 6 24 | } 25 | item { 26 | name: "jump/leap" 27 | id: 7 28 | } 29 | item { 30 | name: "lie/sleep" 31 | id: 8 32 | } 33 | item { 34 | name: "martial art" 35 | id: 9 36 | } 37 | item { 38 | name: "run" 39 | id: 10 40 | } 41 | item { 42 | name: "sit" 43 | id: 11 44 | } 45 | item { 46 | name: "stand" 47 | id: 12 48 | } 49 | item { 50 | name: "swim" 51 | id: 13 52 | } 53 | item { 54 | name: "walk" 55 | id: 14 56 | } 57 | item { 58 | name: "answer phone" 59 | id: 15 60 | } 61 | item { 62 | name: "brush teeth" 63 | id: 16 64 | } 65 | item { 66 | name: "carry/hold (an object)" 67 | id: 17 68 | } 69 | item { 70 | name: "catch (an object)" 71 | id: 18 72 | } 73 | item { 74 | name: "chop" 75 | id: 19 76 | } 77 | item { 78 | name: "climb (e.g., a mountain)" 79 | id: 20 80 | } 81 | item { 82 | name: "clink glass" 83 | id: 21 84 | } 85 | item { 86 | name: "close (e.g., a door, a box)" 87 | id: 22 88 | } 89 | item { 90 | name: "cook" 91 | id: 23 92 | } 93 | item { 94 | name: "cut" 95 | id: 24 96 | } 97 | item { 98 | name: "dig" 99 | id: 25 100 | } 101 | item { 102 | name: "dress/put on clothing" 103 | id: 26 104 | } 105 | item { 106 | name: "drink" 107 | id: 27 108 | } 109 | item { 110 | name: "drive (e.g., a car, a truck)" 111 | id: 28 112 | } 113 | item { 114 | name: "eat" 115 | id: 29 116 | } 117 | item { 118 | name: "enter" 119 | id: 30 120 | } 121 | item { 122 | name: "exit" 123 | id: 31 124 | } 125 | item { 126 | name: "extract" 127 | id: 32 128 | } 129 | item { 130 | name: "fishing" 131 | id: 33 132 | } 133 | item { 134 | name: "hit (an object)" 135 | id: 34 136 | } 137 | item { 138 | name: "kick (an object)" 139 | id: 35 140 | } 141 | item { 142 | name: "lift/pick up" 143 | id: 36 144 | } 145 | item { 146 | name: "listen (e.g., to music)" 147 | id: 37 148 | } 149 | item { 150 | name: "open (e.g., a window, a car door)" 151 | id: 38 152 | } 153 | item { 154 | name: "paint" 155 | id: 39 156 | } 157 | item { 158 | name: "play board game" 159 | id: 40 160 | } 161 | item { 162 | name: "play musical instrument" 163 | id: 41 164 | } 165 | item { 166 | name: "play with pets" 167 | id: 42 168 | } 169 | item { 170 | name: "point to (an object)" 171 | id: 43 172 | } 173 | item { 174 | name: "press" 175 | id: 44 176 | } 177 | item { 178 | name: "pull (an object)" 179 | id: 45 180 | } 181 | item { 182 | name: "push (an object)" 183 | id: 46 184 | } 185 | item { 186 | name: "put down" 187 | id: 47 188 | } 189 | item { 190 | name: "read" 191 | id: 48 192 | } 193 | item { 194 | name: "sit" 195 | id: 49 196 | } 197 | item { 198 | name: "row boat" 199 | id: 50 200 | } 201 | item { 202 | name: "sail boat" 203 | id: 51 204 | } 205 | item { 206 | name: "shoot" 207 | id: 52 208 | } 209 | item { 210 | name: "shovel" 211 | id: 53 212 | } 213 | item { 214 | name: "smoke" 215 | id: 54 216 | } 217 | item { 218 | name: "stir" 219 | id: 55 220 | } 221 | item { 222 | name: "take a photo" 223 | id: 56 224 | } 225 | item { 226 | name: "text on/look at a cellphone" 227 | id: 57 228 | } 229 | item { 230 | name: "throw" 231 | id: 58 232 | } 233 | item { 234 | name: "touch (an object)" 235 | id: 59 236 | } 237 | item { 238 | name: "turn (e.g., a screwdriver)" 239 | id: 60 240 | } 241 | item { 242 | name: "stand" 243 | id: 61 244 | } 245 | item { 246 | name: "work on a computer" 247 | id: 62 248 | } 249 | item { 250 | name: "write" 251 | id: 63 252 | } 253 | item { 254 | name: "fight/hit (a person)" 255 | id: 64 256 | } 257 | item { 258 | name: 
"give/serve (an object) to (a person)" 259 | id: 65 260 | } 261 | item { 262 | name: "grab (a person)" 263 | id: 66 264 | } 265 | item { 266 | name: "hand clap" 267 | id: 67 268 | } 269 | item { 270 | name: "hand shake" 271 | id: 68 272 | } 273 | item { 274 | name: "hand wave" 275 | id: 69 276 | } 277 | item { 278 | name: "hug (a person)" 279 | id: 70 280 | } 281 | item { 282 | name: "kick (a person)" 283 | id: 71 284 | } 285 | item { 286 | name: "kiss (a person)" 287 | id: 72 288 | } 289 | item { 290 | name: "lift (a person)" 291 | id: 73 292 | } 293 | item { 294 | name: "listen to (a person)" 295 | id: 74 296 | } 297 | item { 298 | name: "play with kids" 299 | id: 75 300 | } 301 | item { 302 | name: "push (another person)" 303 | id: 76 304 | } 305 | item { 306 | name: "sing to (e.g., self, a person, a group)" 307 | id: 77 308 | } 309 | item { 310 | name: "take (an object) from (a person)" 311 | id: 78 312 | } 313 | item { 314 | name: "talk to (e.g., self, a person, a group)" 315 | id: 79 316 | } 317 | item { 318 | name: "stand" 319 | id: 80 320 | } 321 | -------------------------------------------------------------------------------- /selfutils/visualization.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import annotations 3 | 4 | import itertools 5 | import logging 6 | from types import SimpleNamespace 7 | from typing import Dict, List, Optional, Tuple, Union 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | from detectron2.utils.visualizer import Visualizer 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def _create_text_labels( 19 | classes: List[int], 20 | scores: List[float], 21 | class_names: List[str], 22 | ground_truth: bool = False, 23 | ) -> List[str]: 24 | """ 25 | Create text labels. 26 | Args: 27 | classes (list[int]): a list of class ids for each example. 28 | scores (list[float] or None): list of scores for each example. 29 | class_names (list[str]): a list of class names, ordered by their ids. 30 | ground_truth (bool): whether the labels are ground truth. 31 | Returns: 32 | labels (list[str]): formatted text labels. 33 | """ 34 | try: 35 | labels = [class_names.get(c, "n/a") for c in classes] 36 | except IndexError: 37 | logger.error("Class indices get out of range: {}".format(classes)) 38 | return None 39 | 40 | if ground_truth: 41 | labels = ["[{}] {}".format("GT", label) for label in labels] 42 | elif scores is not None: 43 | assert len(classes) == len(scores) 44 | labels = ["[{:.2f}] {}".format(s, label) for s, label in zip(scores, labels)] 45 | return labels 46 | 47 | 48 | class ImgVisualizer(Visualizer): 49 | def __init__( 50 | self, img_rgb: torch.Tensor, meta: Optional[SimpleNamespace] = None, **kwargs 51 | ) -> None: 52 | """ 53 | See https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/visualizer.py 54 | for more details. 55 | Args: 56 | img_rgb: a tensor or numpy array of shape (H, W, C), where H and W correspond to 57 | the height and width of the image respectively. C is the number of 58 | color channels. The image is required to be in RGB format since that 59 | is a requirement of the Matplotlib library. The image is also expected 60 | to be in the range [0, 255]. 61 | meta (MetadataCatalog): image metadata. 
62 | See https://github.com/facebookresearch/detectron2/blob/81d5a87763bfc71a492b5be89b74179bd7492f6b/detectron2/data/catalog.py#L90 63 | """ 64 | super(ImgVisualizer, self).__init__(img_rgb, meta, **kwargs) 65 | 66 | def draw_text( 67 | self, 68 | text: str, 69 | position: List[int], 70 | *, 71 | font_size: Optional[int] = None, 72 | color: str = "w", 73 | horizontal_alignment: str = "center", 74 | vertical_alignment: str = "bottom", 75 | box_facecolor: str = "black", 76 | alpha: float = 0.5, 77 | ) -> None: 78 | """ 79 | Draw text at the specified position. 80 | Args: 81 | text (str): the text to draw on image. 82 | position (list of 2 ints): the x,y coordinate to place the text. 83 | font_size (Optional[int]): font of the text. If not provided, a font size 84 | proportional to the image width is calculated and used. 85 | color (str): color of the text. Refer to `matplotlib.colors` for full list 86 | of formats that are accepted. 87 | horizontal_alignment (str): see `matplotlib.text.Text`. 88 | vertical_alignment (str): see `matplotlib.text.Text`. 89 | box_facecolor (str): color of the box wrapped around the text. Refer to 90 | `matplotlib.colors` for full list of formats that are accepted. 91 | alpha (float): transparency level of the box. 92 | """ 93 | if not font_size: 94 | font_size = self._default_font_size 95 | x, y = position 96 | self.output.ax.text( 97 | x, 98 | y, 99 | text, 100 | size=font_size * self.output.scale, 101 | family="monospace", 102 | bbox={ 103 | "facecolor": box_facecolor, 104 | "alpha": alpha, 105 | "pad": 0.7, 106 | "edgecolor": "none", 107 | }, 108 | verticalalignment=vertical_alignment, 109 | horizontalalignment=horizontal_alignment, 110 | color=color, 111 | zorder=10, 112 | ) 113 | 114 | def draw_multiple_text( 115 | self, 116 | text_ls: List[str], 117 | box_coordinate: torch.Tensor, 118 | *, 119 | top_corner: bool = True, 120 | font_size: Optional[int] = None, 121 | color: str = "w", 122 | box_facecolors: str = "black", 123 | alpha: float = 0.5, 124 | ) -> None: 125 | """ 126 | Draw a list of text labels for some bounding box on the image. 127 | Args: 128 | text_ls (list of strings): a list of text labels. 129 | box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom) 130 | coordinates of the box. 131 | top_corner (bool): If True, draw the text labels at (x_left, y_top) of the box. 132 | Else, draw labels at (x_left, y_bottom). 133 | font_size (Optional[int]): font of the text. If not provided, a font size 134 | proportional to the image width is calculated and used. 135 | color (str): color of the text. Refer to `matplotlib.colors` for full list 136 | of formats that are accepted. 137 | box_facecolors (str): colors of the box wrapped around the text. Refer to 138 | `matplotlib.colors` for full list of formats that are accepted. 139 | alpha (float): transparency level of the box. 140 | """ 141 | if not isinstance(box_facecolors, list): 142 | box_facecolors = [box_facecolors] * len(text_ls) 143 | assert len(box_facecolors) == len( 144 | text_ls 145 | ), "Number of colors provided is not equal to the number of text labels." 146 | if not font_size: 147 | font_size = self._default_font_size 148 | text_box_width = font_size + font_size // 2 149 | # If the texts does not fit in the assigned location, 150 | # we split the text and draw it in another place. 
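# num_text_split counts how many labels fit between the chosen box corner and
# the image border; those are stacked outward from that corner, and any labels
# left over are stacked in the opposite direction from the same corner.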
151 | if top_corner: 152 | num_text_split = self._align_y_top( 153 | box_coordinate, len(text_ls), text_box_width 154 | ) 155 | y_corner = 1 156 | else: 157 | num_text_split = len(text_ls) - self._align_y_bottom( 158 | box_coordinate, len(text_ls), text_box_width 159 | ) 160 | y_corner = 3 161 | 162 | text_color_sorted = sorted( 163 | zip(text_ls, box_facecolors), key=lambda x: x[0], reverse=True 164 | ) 165 | if len(text_color_sorted) != 0: 166 | text_ls, box_facecolors = zip(*text_color_sorted) 167 | else: 168 | text_ls, box_facecolors = [], [] 169 | text_ls, box_facecolors = list(text_ls), list(box_facecolors) 170 | self.draw_multiple_text_upward( 171 | text_ls[:num_text_split][::-1], 172 | box_coordinate, 173 | y_corner=y_corner, 174 | font_size=font_size, 175 | color=color, 176 | box_facecolors=box_facecolors[:num_text_split][::-1], 177 | alpha=alpha, 178 | ) 179 | self.draw_multiple_text_downward( 180 | text_ls[num_text_split:], 181 | box_coordinate, 182 | y_corner=y_corner, 183 | font_size=font_size, 184 | color=color, 185 | box_facecolors=box_facecolors[num_text_split:], 186 | alpha=alpha, 187 | ) 188 | 189 | def draw_multiple_text_upward( 190 | self, 191 | text_ls: List[str], 192 | box_coordinate: torch.Tensor, 193 | *, 194 | y_corner: int = 1, 195 | font_size: Optional[int] = None, 196 | color: str = "w", 197 | box_facecolors: str = "black", 198 | alpha: float = 0.5, 199 | ) -> None: 200 | """ 201 | Draw a list of text labels for some bounding box on the image in upward direction. 202 | The next text label will be on top of the previous one. 203 | Args: 204 | text_ls (list of strings): a list of text labels. 205 | box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom) 206 | coordinates of the box. 207 | y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of 208 | the box to draw labels around. 209 | font_size (Optional[int]): font of the text. If not provided, a font size 210 | proportional to the image width is calculated and used. 211 | color (str): color of the text. Refer to `matplotlib.colors` for full list 212 | of formats that are accepted. 213 | box_facecolors (str or list of strs): colors of the box wrapped around the 214 | text. Refer to `matplotlib.colors` for full list of formats that 215 | are accepted. 216 | alpha (float): transparency level of the box. 217 | """ 218 | if not isinstance(box_facecolors, list): 219 | box_facecolors = [box_facecolors] * len(text_ls) 220 | assert len(box_facecolors) == len( 221 | text_ls 222 | ), "Number of colors provided is not equal to the number of text labels." 
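# Labels are spaced one text-box height apart (font_size + font_size // 2),
# starting at the chosen y_corner (1 = top edge of the box, 3 = bottom edge)
# and moving upward with each subsequent label.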
223 | 224 | assert y_corner in [1, 3], "Y_corner must be either 1 or 3" 225 | if not font_size: 226 | font_size = self._default_font_size 227 | 228 | x, horizontal_alignment = self._align_x_coordinate(box_coordinate) 229 | y = box_coordinate[y_corner].item() 230 | for i, text in enumerate(text_ls): 231 | self.draw_text( 232 | text, 233 | (x, y), 234 | font_size=font_size, 235 | color=color, 236 | horizontal_alignment=horizontal_alignment, 237 | vertical_alignment="bottom", 238 | box_facecolor=box_facecolors[i], 239 | alpha=alpha, 240 | ) 241 | y -= font_size + font_size // 2 242 | 243 | def draw_multiple_text_downward( 244 | self, 245 | text_ls: List[str], 246 | box_coordinate: torch.Tensor, 247 | *, 248 | y_corner: int = 1, 249 | font_size: Optional[int] = None, 250 | color: str = "w", 251 | box_facecolors: str = "black", 252 | alpha: float = 0.5, 253 | ) -> None: 254 | """ 255 | Draw a list of text labels for some bounding box on the image in downward direction. 256 | The next text label will be below the previous one. 257 | Args: 258 | text_ls (list of strings): a list of text labels. 259 | box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom) 260 | coordinates of the box. 261 | y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of 262 | the box to draw labels around. 263 | font_size (Optional[int]): font of the text. If not provided, a font size 264 | proportional to the image width is calculated and used. 265 | color (str): color of the text. Refer to `matplotlib.colors` for full list 266 | of formats that are accepted. 267 | box_facecolors (str): colors of the box wrapped around the text. Refer to 268 | `matplotlib.colors` for full list of formats that are accepted. 269 | alpha (float): transparency level of the box. 270 | """ 271 | if not isinstance(box_facecolors, list): 272 | box_facecolors = [box_facecolors] * len(text_ls) 273 | assert len(box_facecolors) == len( 274 | text_ls 275 | ), "Number of colors provided is not equal to the number of text labels." 276 | 277 | assert y_corner in [1, 3], "Y_corner must be either 1 or 3" 278 | if not font_size: 279 | font_size = self._default_font_size 280 | 281 | x, horizontal_alignment = self._align_x_coordinate(box_coordinate) 282 | y = box_coordinate[y_corner].item() 283 | for i, text in enumerate(text_ls): 284 | self.draw_text( 285 | text, 286 | (x, y), 287 | font_size=font_size, 288 | color=color, 289 | horizontal_alignment=horizontal_alignment, 290 | vertical_alignment="top", 291 | box_facecolor=box_facecolors[i], 292 | alpha=alpha, 293 | ) 294 | y += font_size + font_size // 2 295 | 296 | def _align_x_coordinate(self, box_coordinate: torch.Tensor) -> Tuple[float, str]: 297 | """ 298 | Choose an x-coordinate from the box to make sure the text label 299 | does not go out of frames. By default, the left x-coordinate is 300 | chosen and text is aligned left. If the box is too close to the 301 | right side of the image, then the right x-coordinate is chosen 302 | instead and the text is aligned right. 303 | Args: 304 | box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom) 305 | coordinates of the box. 306 | Returns: 307 | x_coordinate (float): the chosen x-coordinate. 308 | alignment (str): whether to align left or right. 309 | """ 310 | # If the x-coordinate is greater than 5/6 of the image width, 311 | # then we align test to the right of the box. This is 312 | # chosen by heuristics. 
313 | if box_coordinate[0] > (self.output.width * 5) // 6: 314 | return box_coordinate[2], "right" 315 | 316 | return box_coordinate[0], "left" 317 | 318 | def _align_y_top( 319 | self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float 320 | ) -> int: 321 | """ 322 | Calculate the number of text labels to plot on top of the box 323 | without going out of frames. 324 | Args: 325 | box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom) 326 | coordinates of the box. 327 | num_text (int): the number of text labels to plot. 328 | textbox_width (float): the width of the box wrapped around text label. 329 | """ 330 | dist_to_top = box_coordinate[1] 331 | num_text_top = dist_to_top // textbox_width 332 | 333 | if isinstance(num_text_top, torch.Tensor): 334 | num_text_top = int(num_text_top.item()) 335 | 336 | return min(num_text, num_text_top) 337 | 338 | def _align_y_bottom( 339 | self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float 340 | ) -> int: 341 | """ 342 | Calculate the number of text labels to plot at the bottom of the box 343 | without going out of frames. 344 | Args: 345 | box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom) 346 | coordinates of the box. 347 | num_text (int): the number of text labels to plot. 348 | textbox_width (float): the width of the box wrapped around text label. 349 | """ 350 | dist_to_bottom = self.output.height - box_coordinate[3] 351 | num_text_bottom = dist_to_bottom // textbox_width 352 | 353 | if isinstance(num_text_bottom, torch.Tensor): 354 | num_text_bottom = int(num_text_bottom.item()) 355 | 356 | return min(num_text, num_text_bottom) 357 | 358 | 359 | class VideoVisualizer: 360 | def __init__( 361 | self, 362 | num_classes: int, 363 | class_names: Dict, 364 | top_k: int = 1, 365 | colormap: str = "rainbow", 366 | thres: float = 0.7, 367 | lower_thres: float = 0.3, 368 | common_class_names: Optional[List[str]] = None, 369 | mode: str = "top-k", 370 | ) -> None: 371 | """ 372 | Args: 373 | num_classes (int): total number of classes. 374 | class_names (dict): Dict mapping classID to name. 375 | top_k (int): number of top predicted classes to plot. 376 | colormap (str): the colormap to choose color for class labels from. 377 | See https://matplotlib.org/tutorials/colors/colormaps.html 378 | thres (float): threshold for picking predicted classes to visualize. 379 | lower_thres (Optional[float]): If `common_class_names` if given, 380 | this `lower_thres` will be applied to uncommon classes and 381 | `thres` will be applied to classes in `common_class_names`. 382 | common_class_names (Optional[list of str]): list of common class names 383 | to apply `thres`. Class names not included in `common_class_names` will 384 | have `lower_thres` as a threshold. If None, all classes will have 385 | `thres` as a threshold. This is helpful for model trained on 386 | highly imbalanced dataset. 387 | mode (str): Supported modes are {"top-k", "thres"}. 388 | This is used for choosing predictions for visualization. 
389 | 390 | """ 391 | assert mode in ["top-k", "thres"], "Mode {} is not supported.".format(mode) 392 | self.mode = mode 393 | self.num_classes = num_classes 394 | self.class_names = class_names 395 | self.top_k = top_k 396 | self.thres = thres 397 | self.lower_thres = lower_thres 398 | 399 | if mode == "thres": 400 | self._get_thres_array(common_class_names=common_class_names) 401 | 402 | self.color_map = plt.get_cmap(colormap) 403 | 404 | def _get_color(self, class_id: int) -> List[float]: 405 | """ 406 | Get color for a class id. 407 | Args: 408 | class_id (int): class id. 409 | """ 410 | return self.color_map(class_id / self.num_classes)[:3] 411 | 412 | def draw_one_frame( 413 | self, 414 | frame: Union[torch.Tensor, np.ndarray], 415 | preds: Union[torch.Tensor, List[float]], 416 | bboxes: Optional[torch.Tensor] = None, 417 | alpha: float = 0.5, 418 | text_alpha: float = 0.7, 419 | ground_truth: bool = False, 420 | ) -> np.ndarray: 421 | """ 422 | Draw labels and bouding boxes for one image. By default, predicted 423 | labels are drawn in the top left corner of the image or corresponding 424 | bounding boxes. For ground truth labels (setting True for ground_truth flag), 425 | labels will be drawn in the bottom left corner. 426 | Args: 427 | frame (array-like): a tensor or numpy array of shape (H, W, C), 428 | where H and W correspond to 429 | the height and width of the image respectively. C is the number of 430 | color channels. The image is required to be in RGB format since that 431 | is a requirement of the Matplotlib library. The image is also expected 432 | to be in the range [0, 255]. 433 | preds (tensor or list): If ground_truth is False, provide a float tensor of 434 | shape (num_boxes, num_classes) that contains all of the confidence 435 | scores of the model. For recognition task, input shape can be (num_classes,). 436 | To plot true label (ground_truth is True), preds is a list contains int32 437 | of the shape (num_boxes, true_class_ids) or (true_class_ids,). 438 | bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates 439 | of the bounding boxes. 440 | alpha (Optional[float]): transparency level of the bounding boxes. 441 | text_alpha (Optional[float]): transparency level of the box wrapped around 442 | text labels. 443 | ground_truth (bool): whether the prodived bounding boxes are ground-truth. 444 | Returns: 445 | An image with bounding box annotations and corresponding bbox 446 | labels plotted on it. 447 | """ 448 | if isinstance(preds, torch.Tensor): 449 | if preds.ndim == 1: 450 | preds = preds.unsqueeze(0) 451 | n_instances = preds.shape[0] 452 | elif isinstance(preds, list): 453 | n_instances = len(preds) 454 | else: 455 | logger.error("Unsupported type of prediction input.") 456 | return 457 | 458 | if ground_truth: 459 | top_scores, top_classes = [None] * n_instances, preds 460 | 461 | elif self.mode == "top-k": 462 | top_scores, top_classes = torch.topk(preds, k=self.top_k) 463 | top_scores, top_classes = top_scores.tolist(), top_classes.tolist() 464 | elif self.mode == "thres": 465 | top_scores, top_classes = [], [] 466 | for pred in preds: 467 | mask = pred >= self.thres 468 | top_scores.append(pred[mask].tolist()) 469 | top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist() 470 | top_classes.append(top_class) 471 | 472 | # Create labels top k predicted classes with their scores. 
473 | text_labels = [] 474 | for i in range(n_instances): 475 | text_labels.append( 476 | _create_text_labels( 477 | top_classes[i], 478 | top_scores[i], 479 | self.class_names, 480 | ground_truth=ground_truth, 481 | ) 482 | ) 483 | frame_visualizer = ImgVisualizer(frame, meta=None) 484 | font_size = min(max(np.sqrt(frame.shape[0] * frame.shape[1]) // 25, 5), 9) 485 | top_corner = not ground_truth 486 | if bboxes is not None: 487 | assert len(preds) == len( 488 | bboxes 489 | ), "Encounter {} predictions and {} bounding boxes".format( 490 | len(preds), len(bboxes) 491 | ) 492 | for i, box in enumerate(bboxes): 493 | text = text_labels[i] 494 | pred_class = top_classes[i] 495 | colors = [self._get_color(pred) for pred in pred_class] 496 | 497 | box_color = "r" if ground_truth else "g" 498 | line_style = "--" if ground_truth else "-." 499 | frame_visualizer.draw_box( 500 | box, 501 | alpha=alpha, 502 | edge_color=box_color, 503 | line_style=line_style, 504 | ) 505 | frame_visualizer.draw_multiple_text( 506 | text, 507 | box, 508 | top_corner=top_corner, 509 | font_size=font_size, 510 | box_facecolors=colors, 511 | alpha=text_alpha, 512 | ) 513 | else: 514 | text = text_labels[0] 515 | pred_class = top_classes[0] 516 | colors = [self._get_color(pred) for pred in pred_class] 517 | frame_visualizer.draw_multiple_text( 518 | text, 519 | torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]), 520 | top_corner=top_corner, 521 | font_size=font_size, 522 | box_facecolors=colors, 523 | alpha=text_alpha, 524 | ) 525 | 526 | return frame_visualizer.output.get_image() 527 | 528 | def draw_clip_range( 529 | self, 530 | frames: Union[torch.Tensor, np.ndarray], 531 | preds: Union[torch.Tensor, List[float]], 532 | bboxes: Optional[torch.Tensor] = None, 533 | text_alpha: float = 0.5, 534 | ground_truth: bool = False, 535 | keyframe_idx: Optional[int] = None, 536 | draw_range: Optional[List[int]] = None, 537 | repeat_frame: int = 1, 538 | ) -> List[np.ndarray]: 539 | """ 540 | Draw predicted labels or ground truth classes to clip. 541 | Draw bouding boxes to clip if bboxes is provided. Boxes will gradually 542 | fade in and out the clip, centered around the clip's central frame, 543 | within the provided `draw_range`. 544 | Args: 545 | frames (array-like): video data in the shape (T, H, W, C). 546 | preds (tensor): a tensor of shape (num_boxes, num_classes) that 547 | contains all of the confidence scores of the model. For recognition 548 | task or for ground_truth labels, input shape can be (num_classes,). 549 | bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates 550 | of the bounding boxes. 551 | text_alpha (float): transparency label of the box wrapped around text labels. 552 | ground_truth (bool): whether the prodived bounding boxes are ground-truth. 553 | keyframe_idx (int): the index of keyframe in the clip. 554 | draw_range (Optional[list[ints]): only draw frames in range 555 | [start_idx, end_idx] inclusively in the clip. If None, draw on 556 | the entire clip. 557 | repeat_frame (int): repeat each frame in draw_range for `repeat_frame` 558 | time for slow-motion effect. 559 | Returns: 560 | A list of frames with bounding box annotations and corresponding 561 | bbox labels ploted on them. 
562 | """ 563 | if draw_range is None: 564 | draw_range = [0, len(frames) - 1] 565 | if draw_range is not None: 566 | draw_range[0] = max(0, draw_range[0]) 567 | left_frames = frames[: draw_range[0]] 568 | right_frames = frames[draw_range[1] + 1 :] 569 | 570 | draw_frames = frames[draw_range[0] : draw_range[1] + 1] 571 | if keyframe_idx is None: 572 | keyframe_idx = len(frames) // 2 573 | 574 | img_ls = ( 575 | list(left_frames) 576 | + self.draw_clip( 577 | draw_frames, 578 | preds, 579 | bboxes=bboxes, 580 | text_alpha=text_alpha, 581 | ground_truth=ground_truth, 582 | keyframe_idx=keyframe_idx - draw_range[0], 583 | repeat_frame=repeat_frame, 584 | ) 585 | + list(right_frames) 586 | ) 587 | 588 | return img_ls 589 | 590 | def draw_clip( 591 | self, 592 | frames: Union[torch.Tensor, np.ndarray], 593 | preds: Union[torch.Tensor, List[float]], 594 | bboxes: Optional[torch.Tensor] = None, 595 | text_alpha: float = 0.5, 596 | ground_truth: bool = False, 597 | keyframe_idx: Optional[int] = None, 598 | repeat_frame: int = 1, 599 | ) -> List[np.ndarray]: 600 | """ 601 | Draw predicted labels or ground truth classes to clip. Draw bouding boxes to clip 602 | if bboxes is provided. Boxes will gradually fade in and out the clip, centered 603 | around the clip's central frame. 604 | Args: 605 | frames (array-like): video data in the shape (T, H, W, C). 606 | preds (tensor): a tensor of shape (num_boxes, num_classes) that contains 607 | all of the confidence scores of the model. For recognition task or for 608 | ground_truth labels, input shape can be (num_classes,). 609 | bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates 610 | of the bounding boxes. 611 | text_alpha (float): transparency label of the box wrapped around text labels. 612 | ground_truth (bool): whether the prodived bounding boxes are ground-truth. 613 | keyframe_idx (int): the index of keyframe in the clip. 614 | repeat_frame (int): repeat each frame in draw_range for `repeat_frame` 615 | time for slow-motion effect. 616 | Returns: 617 | A list of frames with bounding box annotations and corresponding 618 | bbox labels plotted on them. 619 | """ 620 | assert repeat_frame >= 1, "`repeat_frame` must be a positive integer." 621 | 622 | repeated_seq = range(0, len(frames)) 623 | repeated_seq = list( 624 | itertools.chain.from_iterable( 625 | itertools.repeat(x, repeat_frame) for x in repeated_seq 626 | ) 627 | ) 628 | 629 | frames, adjusted = self._adjust_frames_type(frames) 630 | if keyframe_idx is None: 631 | half_left = len(repeated_seq) // 2 632 | half_right = (len(repeated_seq) + 1) // 2 633 | else: 634 | mid = int((keyframe_idx / len(frames)) * len(repeated_seq)) 635 | half_left = mid 636 | half_right = len(repeated_seq) - mid 637 | 638 | alpha_ls = np.concatenate( 639 | [ 640 | np.linspace(0, 1, num=half_left), 641 | np.linspace(1, 0, num=half_right), 642 | ] 643 | ) 644 | text_alpha = text_alpha 645 | frames = frames[repeated_seq] 646 | img_ls = [] 647 | for alpha, frame in zip(alpha_ls, frames): 648 | draw_img = self.draw_one_frame( 649 | frame, 650 | preds, 651 | bboxes, 652 | alpha=alpha, 653 | text_alpha=text_alpha, 654 | ground_truth=ground_truth, 655 | ) 656 | if adjusted: 657 | draw_img = draw_img.astype("float32") / 255 658 | 659 | img_ls.append(draw_img) 660 | 661 | return img_ls 662 | 663 | def _adjust_frames_type( 664 | self, frames: torch.Tensor 665 | ) -> Tuple[List[np.ndarray], bool]: 666 | """ 667 | Modify video data to have dtype of uint8 and values range in [0, 255]. 
663 |     def _adjust_frames_type(
664 |         self, frames: torch.Tensor
665 |     ) -> Tuple[List[np.ndarray], bool]:
666 |         """
667 |         Modify video data to have dtype of uint8 and values in the range [0, 255].
668 |         Args:
669 |             frames (array-like): 4D array of shape (T, H, W, C).
670 |         Returns:
671 |             frames (ndarray): the frames as a uint8 array with values in [0, 255].
672 |             adjusted (bool): whether the original frames needed to be adjusted.
673 |         """
674 |         assert (
675 |             frames is not None and len(frames) != 0
676 |         ), "Frames do not contain any values"
677 |         frames = np.array(frames)
678 |         assert np.array(frames).ndim == 4, "Frames must have 4 dimensions"
679 |         adjusted = False
680 |         if frames.dtype in [np.float32, np.float64]:
681 |             frames *= 255
682 |             frames = frames.astype(np.uint8)
683 |             adjusted = True
684 | 
685 |         return frames, adjusted
686 | 
687 |     def _get_thres_array(self, common_class_names: Optional[List[str]] = None) -> None:
688 |         """
689 |         Compute a per-class threshold array based on `self.thres` and `self.lower_thres`.
690 |         Args:
691 |             common_class_names (Optional[list of str]): a list of common class names.
692 |         """
693 |         common_class_ids = []
694 |         if common_class_names is not None:
695 |             common_classes = set(common_class_names)
696 | 
697 |             for key, name in self.class_names.items():
698 |                 if name in common_classes:
699 |                     common_class_ids.append(key)
700 |         else:
701 |             common_class_ids = list(range(self.num_classes))
702 | 
703 |         thres_array = np.full(shape=(self.num_classes,), fill_value=self.lower_thres)
704 |         thres_array[common_class_ids] = self.thres
705 |         self.thres = torch.from_numpy(thres_array)
706 | 
--------------------------------------------------------------------------------
/yolo_slowfast.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,cv2,time,torch,random,pytorchvideo,warnings,argparse,math
3 | warnings.filterwarnings("ignore",category=UserWarning)
4 | 
5 | from pytorchvideo.transforms.functional import (
6 |     uniform_temporal_subsample,
7 |     short_side_scale_with_boxes,
8 |     clip_boxes_to_image,)
9 | from torchvision.transforms._functional_video import normalize
10 | from pytorchvideo.data.ava import AvaLabeledVideoFramePaths
11 | from pytorchvideo.models.hub import slowfast_r50_detection
12 | from deep_sort.deep_sort import DeepSort
13 | 
14 | 
15 | def tensor_to_numpy(tensor):
16 |     img = tensor.cpu().numpy().transpose((1, 2, 0))
17 |     return img
18 | 
19 | def ava_inference_transform(clip, boxes,
20 |                             num_frames = 32, # 32 for slowfast_r50_detection, 4 for slow-only models
21 |                             crop_size = 640,
22 |                             data_mean = [0.45, 0.45, 0.45],
23 |                             data_std = [0.225, 0.225, 0.225],
24 |                             slow_fast_alpha = 4, # 4 for slowfast_r50_detection, None for slow-only models
25 |                             ):
26 |     boxes = np.array(boxes)
27 |     roi_boxes = boxes.copy()
28 |     clip = uniform_temporal_subsample(clip, num_frames)
29 |     clip = clip.float()
30 |     clip = clip / 255.0
31 |     height, width = clip.shape[2], clip.shape[3]
32 |     boxes = clip_boxes_to_image(boxes, height, width)
33 |     clip, boxes = short_side_scale_with_boxes(clip,size=crop_size,boxes=boxes,)
34 |     clip = normalize(clip,
35 |                      np.array(data_mean, dtype=np.float32),
36 |                      np.array(data_std, dtype=np.float32),)
37 |     boxes = clip_boxes_to_image(boxes, clip.shape[2], clip.shape[3])
38 |     if slow_fast_alpha is not None:
39 |         fast_pathway = clip
40 |         slow_pathway = torch.index_select(clip,1,
41 |             torch.linspace(0, clip.shape[1] - 1, clip.shape[1] // slow_fast_alpha).long())
42 |         clip = [slow_pathway, fast_pathway]
43 | 
44 |     return clip, torch.from_numpy(boxes), roi_boxes
45 | 
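# A quick sketch of what ava_inference_transform returns with its defaults (the
# concrete numbers are illustrative): the clip is subsampled to num_frames=32 frames,
# rescaled so its short side equals crop_size, and normalized; with slow_fast_alpha=4
# the slow pathway keeps 32 // 4 = 8 frames at the indices
# torch.linspace(0, 31, 8).long() -> [0, 4, 8, 13, 17, 22, 26, 31],
# while the fast pathway keeps all 32 frames, so the model receives [slow, fast].
# roi_boxes keeps the box coordinates in the original frame size, unscaled.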
46 | def plot_one_box(x, img, color=[100,100,100], text_info="None",
47 |                  velocity=None,thickness=1,fontsize=0.5,fontthickness=1):
48 |     # Plots one bounding box and its label text on image img
49 |     c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
50 |     cv2.rectangle(img, c1, c2, color, thickness, lineType=cv2.LINE_AA)
51 |     t_size = cv2.getTextSize(text_info, cv2.FONT_HERSHEY_TRIPLEX, fontsize, fontthickness+2)[0]
52 |     cv2.rectangle(img, c1, (c1[0] + int(t_size[0]), c1[1] + int(t_size[1]*1.45)), color, -1)
53 |     cv2.putText(img, text_info, (c1[0], c1[1]+t_size[1]+2),
54 |                 cv2.FONT_HERSHEY_TRIPLEX, fontsize, [255,255,255], fontthickness)
55 |     return img
56 | 
57 | def deepsort_update(Tracker,pred,xywh,np_img):
58 |     outputs = Tracker.update(xywh, pred[:,4:5],pred[:,5].tolist(),cv2.cvtColor(np_img,cv2.COLOR_BGR2RGB))
59 |     return outputs
60 | 
61 | def save_yolopreds_tovideo(yolo_preds,id_to_ava_labels,color_map,output_video):
62 |     for i, (im, pred) in enumerate(zip(yolo_preds.ims, yolo_preds.pred)):
63 |         im=cv2.cvtColor(im,cv2.COLOR_BGR2RGB)
64 |         if pred.shape[0]:
65 |             for j, (*box, cls, trackid, vx, vy) in enumerate(pred):
66 |                 if int(cls) != 0:
67 |                     ava_label = ''
68 |                 elif trackid in id_to_ava_labels.keys():
69 |                     ava_label = id_to_ava_labels[trackid].split(' ')[0]
70 |                 else:
71 |                     ava_label = 'Unknown'
72 |                 text = '{} {} {}'.format(int(trackid),yolo_preds.names[int(cls)],ava_label)
73 |                 color = color_map[int(cls)]
74 |                 im = plot_one_box(box,im,color,text)
75 |         output_video.write(im.astype(np.uint8))
76 | 
77 | def main(config):
78 |     model = torch.hub.load('ultralytics/yolov5', 'yolov5l6')
79 |     model.conf = config.conf
80 |     model.iou = config.iou
81 |     model.max_det = 200
82 |     if config.classes:
83 |         model.classes = config.classes
84 |     device = config.device
85 |     imsize = config.imsize
86 |     video_model = slowfast_r50_detection(True).eval().to(device)
87 |     deepsort_tracker = DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")
88 |     ava_labelnames,_ = AvaLabeledVideoFramePaths.read_label_map("selfutils/temp.pbtxt")
89 |     coco_color_map = [[random.randint(0, 255) for _ in range(3)] for _ in range(80)]
90 | 
91 |     video_save_path = config.output
92 |     video=cv2.VideoCapture(config.input)
93 |     width,height = int(video.get(3)),int(video.get(4))
94 |     video.release()
95 |     outputvideo = cv2.VideoWriter(video_save_path,cv2.VideoWriter_fourcc(*'mp4v'), 25, (width,height))
96 |     print("processing...")
97 | 
98 |     video = pytorchvideo.data.encoded_video.EncodedVideo.from_path(config.input)
99 |     a=time.time()
100 |     for i in range(0,math.ceil(video.duration),1):
101 |         video_clips=video.get_clip(i, i+1-0.04)
102 |         video_clips=video_clips['video']
103 |         if video_clips is None:
104 |             continue
105 |         img_num=video_clips.shape[1]
106 |         imgs=[]
107 |         for j in range(img_num):
108 |             imgs.append(tensor_to_numpy(video_clips[:,j,:,:]))
109 |         yolo_preds=model(imgs, size=imsize)
110 |         yolo_preds.files=[f"img_{i*25+k}.jpg" for k in range(img_num)]
111 | 
112 |         print(i,video_clips.shape,img_num)
113 |         deepsort_outputs=[]
114 |         for j in range(len(yolo_preds.pred)):
115 |             temp=deepsort_update(deepsort_tracker,yolo_preds.pred[j].cpu(),yolo_preds.xywh[j][:,0:4].cpu(),yolo_preds.ims[j])
116 |             if len(temp)==0:
117 |                 temp=np.ones((0,8))
118 |             deepsort_outputs.append(temp.astype(np.float32))
119 |         yolo_preds.pred=deepsort_outputs
120 |         id_to_ava_labels={}
121 |         if yolo_preds.pred[img_num//2].shape[0]:
122 |             inputs,inp_boxes,_=ava_inference_transform(video_clips,yolo_preds.pred[img_num//2][:,0:4],crop_size=imsize)
123 |             inp_boxes = torch.cat([torch.zeros(inp_boxes.shape[0],1), inp_boxes], dim=1)
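            # After deepsort_update, each row of yolo_preds.pred[j] has 8 columns:
            # [x1, y1, x2, y2, class, track_id, vx, vy] (hence the (0, 8) placeholder
            # above). Only the boxes of the clip's middle (key) frame are passed to
            # SlowFast, and the detection head expects a leading clip index per box,
            # which is why a zero column is prepended to inp_boxes here.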
124 |             if isinstance(inputs, list):
125 |                 inputs = [inp.unsqueeze(0).to(device) for inp in inputs]
126 |             else:
127 |                 inputs = inputs.unsqueeze(0).to(device)
128 |             with torch.no_grad():
129 |                 slowfaster_preds = video_model(inputs, inp_boxes.to(device))
130 |                 slowfaster_preds = slowfaster_preds.cpu()
131 |             for tid,avalabel in zip(yolo_preds.pred[img_num//2][:,5].tolist(),np.argmax(slowfaster_preds,axis=1).tolist()):
132 |                 id_to_ava_labels[tid]=ava_labelnames[avalabel+1]
133 |         save_yolopreds_tovideo(yolo_preds,id_to_ava_labels,coco_color_map,outputvideo)
134 |     print("total cost: {:.3f}s, video length: {}s".format(time.time()-a,video.duration))
135 | 
136 |     outputvideo.release()
137 |     print('saved video to:', video_save_path)
138 | 
139 | 
140 | if __name__=="__main__":
141 |     parser = argparse.ArgumentParser()
142 |     parser.add_argument('--input', type=str, default="/home/wufan/images/video/vad.mp4", help='path to the input video')
143 |     parser.add_argument('--output', type=str, default="output.mp4", help='path to save the result video; cannot be the same as the input')
144 |     # object detection config
145 |     parser.add_argument('--imsize', type=int, default=640, help='inference size (pixels)')
146 |     parser.add_argument('--conf', type=float, default=0.4, help='object confidence threshold')
147 |     parser.add_argument('--iou', type=float, default=0.4, help='IOU threshold for NMS')
148 |     parser.add_argument('--device', default='cuda', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
149 |     parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
150 |     config = parser.parse_args()
151 | 
152 |     print(config)
153 |     main(config)
154 | 
--------------------------------------------------------------------------------
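A minimal way to run the script above, assuming the DeepSort re-identification
checkpoint is present at deep_sort/deep_sort/deep/checkpoint/ckpt.t7 (the path
hard-coded in main) and that torch.hub can download the yolov5l6 weights; the
file names below are illustrative:

    python yolo_slowfast.py --input demo.mp4 --output demo_output.mp4 --device cuda
    # CPU-only run that detects and tracks only the COCO "person" class (id 0):
    python yolo_slowfast.py --input demo.mp4 --output demo_output.mp4 --device cpu --classes 0

Note that the result video is written at a fixed 25 fps with the mp4v codec (see
cv2.VideoWriter in main), independent of the input video's frame rate.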