├── README.md ├── centroid ├── centroids_h.npy ├── centroids_w.npy ├── centroids_x.npy └── centroids_y.npy ├── checkpoint ├── ae │ └── ae_8.pth └── ar │ └── ar_110.pth ├── data ├── demo_test_subset.npy ├── kernel.npy ├── postp_combined_path_mot_train.npy ├── postp_mot_val.npy └── val_half_mot17.npy ├── demo_inpainting.py ├── demo_scoring.py ├── models ├── __pycache__ │ ├── ae.cpython-39.pyc │ ├── ar.cpython-39.pyc │ └── residual_block.cpython-39.pyc ├── ae.py ├── ar.py ├── model_utils.py └── residual_block.py ├── temp └── ar │ ├── inpainting │ ├── inpainting_00.jpg │ ├── inpainting_01.jpg │ ├── inpainting_02.jpg │ ├── inpainting_03.jpg │ ├── inpainting_04.jpg │ ├── inpainting_05.jpg │ ├── inpainting_06.jpg │ ├── inpainting_07.jpg │ ├── inpainting_08.jpg │ ├── inpainting_09.jpg │ ├── inpainting_10.jpg │ ├── inpainting_11.jpg │ ├── inpainting_12.jpg │ ├── inpainting_13.jpg │ ├── inpainting_14.jpg │ ├── inpainting_15.jpg │ ├── inpainting_16.jpg │ ├── inpainting_17.jpg │ ├── inpainting_18.jpg │ ├── inpainting_19.jpg │ ├── inpainting_20.jpg │ ├── inpainting_21.jpg │ ├── inpainting_22.jpg │ ├── inpainting_23.jpg │ ├── inpainting_24.jpg │ ├── inpainting_25.jpg │ ├── inpainting_26.jpg │ ├── inpainting_27.jpg │ ├── inpainting_28.jpg │ ├── inpainting_29.jpg │ ├── inpainting_30.jpg │ ├── inpainting_31.jpg │ ├── inpainting_32.jpg │ ├── inpainting_33.jpg │ ├── inpainting_34.jpg │ ├── inpainting_35.jpg │ ├── inpainting_36.jpg │ ├── inpainting_37.jpg │ ├── inpainting_38.jpg │ └── inpainting_39.jpg │ └── log_p │ └── likelihood.jpg ├── train_ae.py ├── train_ar.py └── utils ├── __pycache__ ├── clustering.cpython-39.pyc ├── dataloader_ar.cpython-39.pyc └── utils_ar.cpython-39.pyc ├── clustering.py ├── create_demo_test_subset.py ├── dataloader_ae.py ├── dataloader_ar.py ├── post_process.py ├── prepare_data.py └── utils_ar.py /README.md: -------------------------------------------------------------------------------- 1 | # [Probabilistic Tracklet Scoring and Inpainting for Multiple Object Tracking](https://openaccess.thecvf.com/content/CVPR2021/papers/Saleh_Probabilistic_Tracklet_Scoring_and_Inpainting_for_Multiple_Object_Tracking_CVPR_2021_paper.pdf) (CVPR 2021) 2 | Pytorch implementation of the ArTIST motion model. In this repo, there are 3 | 4 | - Training [script](https://github.com/fatemeh-slh/ArTIST/blob/main/train_ae.py) for the Moving Agent network 5 | - Training [script](https://github.com/fatemeh-slh/ArTIST/blob/main/train_ar.py) for the ArTIST motion model 6 | - Demo [script](https://github.com/fatemeh-slh/ArTIST/blob/main/demo_scoring.py) for Inferring the likelihood of current observations (detections) 7 | - Demo [script](https://github.com/fatemeh-slh/ArTIST/blob/main/demo_inpainting.py) for Inpainting the missing observation/detections 8 | 9 | 10 | ## Demo 1: Likelihood estimation of observation 11 | Run: 12 | ``` 13 | python3 demo_scoring.py 14 | ``` 15 | This will generate the output in the `temp/ar/log_p` directory, look like this: 16 | ![scoring demo](https://github.com/fatemeh-slh/ArTIST/blob/main/temp/ar/log_p/likelihood.jpg) 17 | 18 | This demo gets as input a pretrained model of the Moving Agent Network (MA-Net), a pretrained model of ArTIST, the centroids (obtain centroids via the [script](https://github.com/fatemeh-slh/ArTIST/blob/main/utils/clustering.py) in the utils), a demo test sample index and the number of clusters. 19 | 20 | The model then evaluates the log-likelihood (lower the better) of all detections as the continuation of the observed sequence. 
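The demo parameters are hard-coded in the script's entry point rather than passed on the command line. For reference, this is the `if __name__ == '__main__'` block of `demo_scoring.py` (shown in full later in this repo); the last two arguments of `test` are the demo test sample index and the number of clusters, so edit them in place to score a different sequence:
```
model_ae = motion_ae(256).cuda()
model_ae.load_state_dict(torch.load('checkpoint/ae/ae_8.pth'))
model_ae.eval()
model_ar = motion_ar(512, 1024).cuda()
model_ar.load_state_dict(torch.load('checkpoint/ar/ar_110.pth'))
model_ar.eval()
centroid_x, centroid_y, centroid_w, centroid_h = load_clusters()
test(model_ar, centroid_x, centroid_y, centroid_w, centroid_h, 100, 1024)  # sample index 100, 1024 clusters
```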
21 | 22 | ## Demo 2: Sequence inpainting 23 | Run: 24 | ``` 25 | python3 demo_inpainting.py 26 | ``` 27 | This will generate multiple plausible continuations of an observed motion, stored in the `temp/ar/inpainting` directory. One example looks like this: 28 | ![inpainting demo](https://github.com/fatemeh-slh/ArTIST/blob/main/temp/ar/inpainting/inpainting_25.jpg) 29 | 30 | This demo gets as input a pretrained model of the Moving Agent Network (MA-Net), a pretrained model of ArTIST, the centroids (obtained via the clustering [script](https://github.com/fatemeh-slh/ArTIST/blob/main/utils/clustering.py) in `utils`), a demo test sample index, and the number of samples to generate. 31 | 32 | For each generated future sequence, it computes the IoU between the last generated bounding box and the last groundtruth bounding box, as well as the mean IoU between the entire generated sequence and the groundtruth sequence. 33 | 34 | 35 | ### Utilities 36 | This repo contains a number of scripts for generating the data required to train and evaluate ArTIST. 37 | - [`prepare_data`](https://github.com/fatemeh-slh/ArTIST/blob/main/utils/prepare_data.py): Given the annotations of a dataset (e.g., MOT17), it extracts the motion sequences as well as the IDs of the social tracklets that live within the life span of the corresponding sequence, and stores them as a dictionary. If you wish to combine multiple tracking datasets, you can use the `merge_datasets()` function inside this script. 38 | - [`clustering`](https://github.com/fatemeh-slh/ArTIST/blob/main/utils/clustering.py): Given the output dictionary of the `prepare_data` script, this script performs K-Means clustering and stores the centroids, which are then used by the ArTIST model. 39 | - [`dataloader_ae`](https://github.com/fatemeh-slh/ArTIST/blob/main/utils/dataloader_ae.py) and [`dataloader_ar`](https://github.com/fatemeh-slh/ArTIST/blob/main/utils/dataloader_ar.py): Given the post-processed version of the dataset dictionary (produced by running the [`post_process`](https://github.com/fatemeh-slh/ArTIST/blob/main/utils/post_process.py) script), these two scripts define the dataloaders for training MA-Net and ArTIST. Note that the dataloader of ArTIST uses MA-Net to compute the social information. This can also be done jointly in an end-to-end fashion, for which we observed almost no difference. 40 | - [`create_demo_test_subset`](https://github.com/fatemeh-slh/ArTIST/blob/main/utils/create_demo_test_subset.py): This script produces the test subset used by the demo scripts. A pre-generated subset is already provided in [`data/demo_test_subset.npy`](https://github.com/fatemeh-slh/ArTIST/blob/main/data/demo_test_subset.npy), so you only need to run it if you want to build your own. 41 | 42 | ### Data 43 | You can download the required data from the [Release](https://github.com/fatemeh-slh/ArTIST/releases/tag/data-release) page and put it in the `data/` directory.
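The released files are standard NumPy archives. Below is a minimal sketch of how they are consumed, mirroring the calls made in the demo scripts (`load_clusters` comes from `utils/clustering.py` and returns the four centroid arrays, presumably read from the files in `centroid/`):
```
import numpy as np
from utils.clustering import load_clusters

# Demo test subset: each entry holds 'data', 'social', 'wh' and 'seq_len'
# (loaded with pickling enabled, as in demo_scoring.py / demo_inpainting.py).
test_set = np.load('data/demo_test_subset.npy', allow_pickle=True)

# Gaussian kernel used to smooth the predicted distributions in demo_scoring.py.
kernel = np.load('data/kernel.npy')

# K-Means centroids for the x/y/w/h components.
centroid_x, centroid_y, centroid_w, centroid_h = load_clusters()
```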
44 | 45 | ## Citation 46 | If you find this work useful in your own research, please consider citing: 47 | ``` 48 | @inproceedings{saleh2021probabilistic, 49 | author={Saleh, Fatemeh and Aliakbarian, Sadegh and Rezatofighi, Hamid and Salzmann, Mathieu and Gould, Stephen}, 50 | title = {Probabilistic Tracklet Scoring and Inpainting for Multiple Object Tracking}, 51 | booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, 52 | year = {2021} 53 | } 54 | ``` 55 | -------------------------------------------------------------------------------- /centroid/centroids_h.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/centroid/centroids_h.npy -------------------------------------------------------------------------------- /centroid/centroids_w.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/centroid/centroids_w.npy -------------------------------------------------------------------------------- /centroid/centroids_x.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/centroid/centroids_x.npy -------------------------------------------------------------------------------- /centroid/centroids_y.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/centroid/centroids_y.npy -------------------------------------------------------------------------------- /checkpoint/ae/ae_8.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/checkpoint/ae/ae_8.pth -------------------------------------------------------------------------------- /checkpoint/ar/ar_110.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/checkpoint/ar/ar_110.pth -------------------------------------------------------------------------------- /data/demo_test_subset.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/data/demo_test_subset.npy -------------------------------------------------------------------------------- /data/kernel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/data/kernel.npy -------------------------------------------------------------------------------- /data/postp_combined_path_mot_train.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/data/postp_combined_path_mot_train.npy -------------------------------------------------------------------------------- /data/postp_mot_val.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/data/postp_mot_val.npy -------------------------------------------------------------------------------- /data/val_half_mot17.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/data/val_half_mot17.npy -------------------------------------------------------------------------------- /demo_inpainting.py: -------------------------------------------------------------------------------- 1 | """ 2 | ArTIST evaluation: Inpainting the missing observation/detections 3 | author: Fatemeh Saleh 4 | """ 5 | import cv2 6 | from utils.utils_ar import iou 7 | import numpy as np 8 | import torch 9 | from utils.clustering import load_clusters 10 | from models.ar import motion_ar 11 | from models.ae import motion_ae 12 | 13 | torch.backends.cudnn.enabled = False 14 | 15 | def test(model, centroid_x, centroid_y, centroid_w, centroid_h, cnt, n_sampling): 16 | """ This function generates multiple plausible continuations of an observed sequence 17 | 18 | Args: 19 | model (nn.Module): ArTIST model 20 | centroid_x (list): The centroids of x coordinate of the bounding boxes in the training set 21 | centroid_y (list): The centroids of y coordinate of the bounding boxes in the training set 22 | centroid_w (list): The centroids of width of the bounding boxes in the training set 23 | centroid_h ((list): The centroids of height of the bounding boxes in the training set 24 | cnt (int): The test sequence index 25 | n_sampling (int): The preferred number of inpaintings 26 | """ 27 | 28 | test_set = np.load('data/demo_test_subset.npy', allow_pickle=True) 29 | rand_len = int(test_set[cnt]['seq_len']) 30 | 31 | # we consider 75% of the sequence as observation, and aim to generate the rest 25% 32 | mask_index = int(rand_len * 0.75) 33 | data = test_set[cnt]['data'][:, :, :] 34 | # the observed sequence 35 | masked_data = test_set[cnt]['data'][:, :mask_index, :] 36 | image_wh = test_set[cnt]['wh'][0] 37 | width = image_wh[0] 38 | height = image_wh[1] 39 | # social information (see section 3.3 of https://arxiv.org/pdf/2012.02337v1.pdf) 40 | social = test_set[cnt]['social'] 41 | gap = rand_len - mask_index 42 | 43 | self_data = masked_data.cuda() 44 | self_data = self_data.cuda().float() 45 | 46 | # computing the motion velocity 47 | self_delta_tmp = self_data[:, 1:, :] - self_data[:, 0:-1, :] 48 | self_delta = torch.zeros(self_delta_tmp.shape[0], self_delta_tmp.shape[1] + 1, 49 | self_delta_tmp.shape[2]).cuda() 50 | self_delta[:, 1:, :] = self_delta_tmp 51 | 52 | inpainted_boxes = [] 53 | 54 | # generating multiple plausible continuations via the batch_inference function 55 | # To generate n_sampling sequences, this function considers the data as a batch of size n_sampling 56 | with torch.no_grad(): 57 | dist_x, dist_y, dist_w, dist_h, sampled_boxes, sampled_deltas, sampled_detection = model.batch_inference( 58 | self_data.repeat(n_sampling, 1, 1), social.repeat(n_sampling, 1, 1), gap, centroid_x, centroid_y, centroid_w, centroid_h) 59 | 60 | # visualization 61 | for n in range(n_sampling): 62 | I1 = np.ones((int(image_wh[1].item()), int(image_wh[0].item()), 3)) 63 | I1 = I1 * 255 64 | 65 | # visualizing the observed groundtruth sequence (Color: Blue) 66 | for i in range(0, mask_index): 67 | I1 = cv2.rectangle(I1, (int(data[0, i, 0].item() * image_wh[0]), int(data[0, i, 1].item() * image_wh[1])), 68 | 
(int(data[0, i, 0].item() * image_wh[0] + data[0, i, 2].item() * image_wh[0]), 69 | int(data[0, i, 1].item() * image_wh[1] + data[0, i, 3].item() * image_wh[1])), 70 | (255, 0, 0), 1) 71 | 72 | sequence_iou = [] 73 | # visualizing the future sequence (groundtruth and inpainted) 74 | for i in range(gap - 1): 75 | 76 | # groundtruth future sequence (Color: Red) 77 | gt_box = [ 78 | int(data[0, mask_index + i, 0].item() * image_wh[0]), 79 | int(data[0, mask_index + i, 1].item() * image_wh[1]), 80 | int(data[0, mask_index + i, 0].item() * image_wh[0] + data[0, mask_index + i, 2].item() * image_wh[0]), 81 | int(data[0, mask_index + i, 1].item() * image_wh[1] + data[0, mask_index + i, 3].item() * image_wh[1])] 82 | 83 | # inpainted future sequence (Color: Green) 84 | inpainted_box = [ 85 | int(sampled_boxes[n, i, 0].item() * image_wh[0]), 86 | int(sampled_boxes[n, i, 1].item() * image_wh[1]), 87 | int(sampled_boxes[n, i, 0].item() * image_wh[0] + sampled_boxes[n, i, 2].item() * image_wh[0]), 88 | int(sampled_boxes[n, i, 1].item() * image_wh[1] + sampled_boxes[n, i, 3].item() * image_wh[1])] 89 | 90 | sequence_iou.append(iou(gt_box, inpainted_box)) 91 | I1 = cv2.rectangle(I1, (gt_box[0], gt_box[1]), (gt_box[2], gt_box[3]), (0, 0, 255), 1) 92 | I1 = cv2.rectangle(I1, (inpainted_box[0], inpainted_box[1]), (inpainted_box[2], inpainted_box[3]), (0, 255, 0), 1) 93 | 94 | # additional information: last time IoU and future sequence mean IoU 95 | cv2.putText(I1, 'Last time IoU: ' + str(round(sequence_iou[-1], 2)), (10, 50), cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 0, 0), 1 , lineType=cv2.LINE_AA) 96 | cv2.putText(I1, 'Future sequence mIoU: ' + str(round(np.mean(sequence_iou), 2)), (10, 100), cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 0, 0), 1 , lineType=cv2.LINE_AA) 97 | cv2.imwrite('temp/ar/inpainting/inpainting_' + str(n).zfill(2) + '.jpg', I1) 98 | 99 | 100 | if __name__ == '__main__': 101 | model_ae = motion_ae(256).cuda() 102 | model_ae.load_state_dict(torch.load('checkpoint/ae/ae_8.pth')) 103 | model_ae.eval() 104 | model_ar = motion_ar(512, 1024).cuda() 105 | model_ar.load_state_dict(torch.load('checkpoint/ar/ar_110.pth')) 106 | model_ar.eval() 107 | centroid_x, centroid_y, centroid_w, centroid_h = load_clusters() 108 | test(model_ar, centroid_x, centroid_y, centroid_w, centroid_h, 100, 40) -------------------------------------------------------------------------------- /demo_scoring.py: -------------------------------------------------------------------------------- 1 | """ 2 | ArTIST evaluation: Inferring the likelihood of current observations (detections) 3 | author: Fatemeh Saleh 4 | """ 5 | 6 | import cv2 7 | from utils.utils_ar import infer_log_likelihood, iou 8 | import numpy as np 9 | import torch 10 | from utils.clustering import clustering, load_clusters 11 | from models.ar import motion_ar 12 | from models.ae import motion_ae 13 | 14 | torch.backends.cudnn.enabled = False 15 | 16 | 17 | def test(model, centroid_x, centroid_y, centroid_w, centroid_h, cnt, num_cluster): 18 | """[summary] 19 | 20 | Args: 21 | model (nn.Module): ArTIST model 22 | centroid_x (list): The centroids of x coordinate of the bounding boxes in the training set 23 | centroid_y (list): The centroids of y coordinate of the bounding boxes in the training set 24 | centroid_w (list): The centroids of width of the bounding boxes in the training set 25 | centroid_h ((list): The centroids of height of the bounding boxes in the training set 26 | cnt (int): The test sequence index 27 | num_cluster (int): number of clusters 28 | """ 29 | 30 | 
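    # What follows: take the first 75% of the demo sequence as the observation, run ArTIST
    # once to obtain distributions over the next delta (velocity) for x/y/w/h, smooth them
    # with a Gaussian kernel, and score a handful of candidate boxes (three synthetic ones
    # plus the true next box) as possible continuations of the track. The visualization is
    # written to temp/ar/log_p/likelihood.jpg.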
test_set = np.load('data/demo_test_subset.npy', allow_pickle=True) 31 | rand_len = int(test_set[cnt]['seq_len']) 32 | 33 | # we consider the sequence to be observed up until mask_index 34 | mask_index = int(rand_len * 0.75) 35 | data = test_set[cnt]['data'][:, :mask_index, :] 36 | BBOX = [] 37 | image_wh = test_set[cnt]['wh'][0] 38 | width = image_wh[0] 39 | height = image_wh[1] 40 | # social information (see section 3.3 of https://arxiv.org/pdf/2012.02337v1.pdf) 41 | social = test_set[cnt]['social'] 42 | valid_box = test_set[cnt]['data'][:, mask_index, :] 43 | 44 | # creating some imaginary detections at current time step 45 | instance_options = [[820/width, 830/height, 80/width, 170/height], [400/width, 310/height, 110/width, 330/height], [1000/width, 300/height, 170/width, 300/height]] 46 | instance_options.append([valid_box[0, 0], valid_box[0, 1], valid_box[0, 2], valid_box[0, 3]]) 47 | 48 | scores, ious = np.zeros(len(instance_options)), np.zeros(len(instance_options)) 49 | gap = 1 50 | gaussian_kernel = np.load("data/kernel.npy") 51 | gaussian_kernel = torch.autograd.Variable(torch.from_numpy(gaussian_kernel).float()).cuda() 52 | gaussian_kernel = gaussian_kernel.unsqueeze(0).unsqueeze(0) 53 | 54 | self_data = data.cuda() 55 | self_data = self_data.cuda().float() 56 | 57 | # computing the motion velocity 58 | self_delta_tmp = self_data[:, 1:, :] - self_data[:, 0:-1, :] 59 | self_delta = torch.zeros(self_delta_tmp.shape[0], self_delta_tmp.shape[1] + 1, 60 | self_delta_tmp.shape[2]).cuda() 61 | self_delta[:, 1:, :] = self_delta_tmp 62 | 63 | # computing the distribution over the next plausible bounding box 64 | dist_x, dist_y, dist_w, dist_h, sampled_boxes, sampled_deltas, sampled_detection = model.inference( 65 | self_data, 66 | social[:, :mask_index, :], 67 | gap, 68 | centroid_x, 69 | centroid_y, 70 | centroid_w, 71 | centroid_h) 72 | 73 | # making it a probability distribution 74 | dist_x = torch.nn.Softmax(dim=-1)(dist_x) 75 | dist_y = torch.nn.Softmax(dim=-1)(dist_y) 76 | dist_w = torch.nn.Softmax(dim=-1)(dist_w) 77 | dist_h = torch.nn.Softmax(dim=-1)(dist_h) 78 | 79 | # smoothing the distributions using a gaussian kernel 80 | y_g1d_x = torch.nn.functional.conv1d(dist_x, gaussian_kernel.repeat(dist_x.shape[1], dist_x.shape[1], 1), 81 | padding=24) 82 | y_g1d_y = torch.nn.functional.conv1d(dist_y, gaussian_kernel.repeat(dist_y.shape[1], dist_y.shape[1], 1), 83 | padding=24) 84 | y_g1d_w = torch.nn.functional.conv1d(dist_w, gaussian_kernel.repeat(dist_w.shape[1], dist_w.shape[1], 1), 85 | padding=24) 86 | y_g1d_h = torch.nn.functional.conv1d(dist_h, gaussian_kernel.repeat(dist_h.shape[1], dist_h.shape[1], 1), 87 | padding=24) 88 | 89 | extended_track = torch.zeros(1, self_delta.shape[1] + len(sampled_boxes) + 1, 4).cuda() 90 | extended_track[0, :self_delta.shape[1], :] = self_delta[0, :, :] 91 | extended_track[0, self_delta.shape[1]:-1, :] = sampled_deltas[0, :-1, :] 92 | 93 | observation_last = [self_data[0, -1, 0].item() * width, self_data[0, -1, 1].item() * height, 94 | (self_data[0, -1, 0].item() * width) + (self_data[0, -1, 2].item() * width), 95 | (self_data[0, -1, 1].item() * height) + (self_data[0, -1, 3].item() * height)] 96 | 97 | # loop over the detections... 
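    # ArTIST scores velocities rather than absolute boxes, so each candidate detection is
    # first converted to a delta with respect to the last observed box, appended as the
    # final step of `extended_track`, and then evaluated with infer_log_likelihood against
    # the smoothed per-component distributions; the result is reduced to the log p value
    # stored in `scores` and drawn in the visualization.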
98 | for opt_idx, option in enumerate(instance_options): 99 | opt = [(option[0]), (option[1]), (option[2]), (option[3])] 100 | 101 | option_unnorm = [ 102 | option[0] * width, 103 | option[1] * height, 104 | (option[2] + option[0]) * width, 105 | (option[3] + option[1]) * height 106 | ] 107 | iou_validate = iou(option_unnorm, observation_last) 108 | ious[opt_idx] = iou_validate 109 | 110 | last_delta = torch.from_numpy(np.array(opt) - self_data[0, -1].cpu().detach().numpy()).cuda() 111 | 112 | extended_track[0, -1, :] = last_delta 113 | 114 | # inferring the likelihood of each detection (considered as the last observation of the sequence) 115 | likelihoods_smooth = infer_log_likelihood(y_g1d_x, y_g1d_y, y_g1d_w, y_g1d_h, 116 | extended_track[:, 1:, 0:1], 117 | extended_track[:, 1:, 1:2], 118 | extended_track[:, 1:, 2:3], 119 | extended_track[:, 1:, 3:4], 120 | centroid_x, centroid_y, centroid_w, centroid_h, num_cluster) 121 | all_scores = np.array(likelihoods_smooth[-1]) 122 | likelihoods_smooth = np.sum(all_scores) 123 | score = likelihoods_smooth 124 | scores[opt_idx] = score 125 | 126 | # visualization 127 | I1 = np.ones((int(image_wh[1].item()), int(image_wh[0].item()), 3)) 128 | I1 = I1 * 255 129 | 130 | for i in range(0, mask_index): 131 | I1 = cv2.rectangle(I1, ( 132 | int(data[0, i, 0].item() * image_wh[0]), int(data[0, i, 1].item() * image_wh[1])), 133 | (int(data[0, i, 0].item() * image_wh[0] + data[0, i, 2].item() * image_wh[0]), 134 | int(data[0, i, 1].item() * image_wh[1] + data[0, i, 3].item() * image_wh[1])), 135 | (255, 0, 0), 1) #### GT 136 | 137 | for i in range(len(instance_options)): 138 | I1 = cv2.rectangle(I1, ( 139 | int(instance_options[i][0]*width), int(instance_options[i][1]*height), 140 | int(instance_options[i][2]*width), int(instance_options[i][3]*height)), (0, 0, 0), 2) 141 | cv2.putText(I1, 'log p: ' + str(round(scores[i], 2)), (int(instance_options[i][0]*width), int(instance_options[i][1]*height)-40), cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 0, 0), 142 | 1, lineType=cv2.LINE_AA) 143 | cv2.putText(I1, 'IoU w/ last bbox: ' + str(round(ious[i], 2)), (int(instance_options[i][0]*width), int(instance_options[i][1]*height)-15), cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 0, 0), 144 | 1 , lineType=cv2.LINE_AA) 145 | cv2.imwrite('temp/ar/log_p/likelihood.jpg', I1) 146 | 147 | 148 | 149 | if __name__ == '__main__': 150 | model_ae = motion_ae(256).cuda() 151 | model_ae.load_state_dict(torch.load('checkpoint/ae/ae_8.pth')) 152 | model_ae.eval() 153 | model_ar = motion_ar(512, 1024).cuda() 154 | model_ar.load_state_dict(torch.load('checkpoint/ar/ar_110.pth')) 155 | model_ar.eval() 156 | centroid_x, centroid_y, centroid_w, centroid_h = load_clusters() 157 | test(model_ar, centroid_x, centroid_y, centroid_w, centroid_h, 100, 1024) -------------------------------------------------------------------------------- /models/__pycache__/ae.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/models/__pycache__/ae.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/ar.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/models/__pycache__/ar.cpython-39.pyc -------------------------------------------------------------------------------- 
/models/__pycache__/residual_block.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/models/__pycache__/residual_block.cpython-39.pyc -------------------------------------------------------------------------------- /models/ae.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class motion_ae(nn.Module): 7 | 8 | def __init__(self, hidden_state): 9 | 10 | super(motion_ae, self).__init__() 11 | self.hidden_size = hidden_state 12 | self.encoder_fc = nn.Linear(4, self.hidden_size // 2) 13 | self.encoder = nn.GRU(self.hidden_size // 2, self.hidden_size, num_layers=1, batch_first=True) 14 | self.decoder = nn.GRU(4, self.hidden_size, num_layers=1, batch_first=True) 15 | self.decoder_fc = nn.Sequential( 16 | nn.Linear(self.hidden_size, self.hidden_size // 4), 17 | nn.ReLU(), 18 | nn.Linear(self.hidden_size // 4, 4) 19 | ) 20 | 21 | def forward(self, observation, tf): 22 | observation_enc = nn.ReLU()(self.encoder_fc(observation)) 23 | _, encoder_h = self.encoder(observation_enc) 24 | 25 | T = observation.shape[1] 26 | mask = np.random.uniform(size=T - 1) < tf 27 | 28 | reconstructed = [] 29 | init_motion = observation[:, 0:1, :] 30 | 31 | x, h = self.decoder(init_motion, encoder_h) 32 | x = self.decoder_fc(x) 33 | x = x + init_motion 34 | 35 | reconstructed.append(x) 36 | 37 | for t in range(1, T): 38 | if mask[t - 1]: 39 | x_t, h = self.decoder(observation[:, t:t + 1, :], h) 40 | else: 41 | x_t, h = self.decoder(x, h) 42 | 43 | x_t = self.decoder_fc(x_t) 44 | x = x_t + x 45 | 46 | reconstructed.append(x) 47 | 48 | return torch.cat(reconstructed, dim=1) 49 | 50 | def inference(self, observation): 51 | observation = self.encoder_fc(observation) 52 | hiddens = [] 53 | h = None 54 | for t in range(observation.shape[1]): 55 | _, h = self.encoder(observation[:, t:t+1, :], h) 56 | hiddens.append(h) 57 | hiddens = torch.cat(hiddens, dim=0) 58 | hiddens = hiddens.permute(1, 0, 2) 59 | return hiddens 60 | -------------------------------------------------------------------------------- /models/ar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch import nn 4 | 5 | from .residual_block import ResidualBlock 6 | 7 | 8 | class motion_ar(nn.Module): 9 | 10 | def __init__(self, hidden_state=512, num_clusters=1024): 11 | 12 | super(motion_ar, self).__init__() 13 | self.hidden_size = hidden_state 14 | self.num_clusters = num_clusters 15 | self.fc_embedding = ResidualBlock(4, self.hidden_size) 16 | self.embedding = nn.Sequential(nn.Linear(self.hidden_size + 256, self.hidden_size), nn.ReLU()) 17 | self.rnn = nn.LSTM(self.hidden_size, self.hidden_size, num_layers=1, batch_first=True) 18 | 19 | self.fc_x = nn.Linear(self.hidden_size, num_clusters) 20 | self.fc_y = nn.Linear(self.hidden_size, num_clusters) 21 | self.fc_w = nn.Linear(self.hidden_size, num_clusters) 22 | self.fc_h = nn.Linear(self.hidden_size, num_clusters) 23 | 24 | def init_hidden(self, bsz): 25 | return (nn.Parameter(torch.zeros(self.rnn.num_layers, bsz, self.hidden_size).normal_(0, 0.01), requires_grad=True).cuda(), 26 | nn.Parameter(torch.zeros(self.rnn.num_layers, bsz, self.hidden_size).normal_(0, 0.01), 27 | requires_grad=True).cuda()) 28 | 29 | def sample_cluster(self, dist, 
centroid): 30 | indices = (torch.topk(nn.Softmax(dim=-1)(dist), k=1)[1]).squeeze(1) 31 | return centroid[0, list(indices.squeeze(1))].unsqueeze(1) 32 | 33 | def sample_cluster_multinomial(self, dist, centroid): 34 | m = torch.distributions.multinomial.Multinomial(total_count=self.num_clusters, probs=nn.Softmax(dim=-1)(dist)) 35 | indices = torch.multinomial(nn.Softmax(dim=-1)(dist)[:, 0], 1) 36 | 37 | return centroid[0, list(indices)].unsqueeze(1) 38 | 39 | def single_forward(self, x, s, h): 40 | 41 | x = self.fc_embedding(x) 42 | x_s = self.embedding(torch.cat([x, s], dim=2)) 43 | x_orig = x.clone() 44 | 45 | x, h = self.rnn(x_s, h) 46 | x += x_orig 47 | 48 | featureX = self.fc_x(x) 49 | featureY = self.fc_y(x) 50 | featureW = self.fc_w(x) 51 | featureH = self.fc_h(x) 52 | 53 | return h, featureX, featureY, featureW, featureH 54 | 55 | def forward(self, observation, social, mask_index, centroid_x, centroid_y, centroid_w, centroid_h): 56 | 57 | # observation and last box are bboxes not delta! 58 | B, T = observation.shape[0], observation.shape[1] 59 | sampled_delta = torch.zeros(observation.shape).cuda() 60 | 61 | # h = init_h 62 | h = self.init_hidden(B) 63 | 64 | reconstructed_x = [] 65 | reconstructed_y = [] 66 | reconstructed_w = [] 67 | reconstructed_h = [] 68 | 69 | init_motion = observation[:, 0:1, :] 70 | init_social = social[:, 0:1, :] 71 | sampled_delta[:, 0:1, :] = observation[:, 0:1, :].clone() 72 | pos = torch.zeros(B, 1, 1).cuda() 73 | 74 | h, featureX, featureY, featureW, featureH = self.single_forward(init_motion, init_social, h) 75 | 76 | featureX = featureX.view(B, 1, -1) 77 | featureY = featureY.view(B, 1, -1) 78 | featureW = featureW.view(B, 1, -1) 79 | featureH = featureH.view(B, 1, -1) 80 | 81 | reconstructed_x.append(featureX) 82 | reconstructed_y.append(featureY) 83 | reconstructed_w.append(featureW) 84 | reconstructed_h.append(featureH) 85 | 86 | for t in range(1, T - 1): 87 | if t < mask_index: 88 | sampled_delta[:, t:t + 1, :] = observation[:, t:t + 1, :].clone() 89 | h, featureX, featureY, featureW, featureH = self.single_forward(observation[:, t:t+1, :], social[:, t:t+1, :], h) 90 | else: 91 | sampled_x = torch.zeros(B, 1, 4).cuda() 92 | 93 | sampled_x[:, :, 0] = self.sample_cluster(featureX, centroid_x) 94 | sampled_x[:, :, 1] = self.sample_cluster(featureY, centroid_y) 95 | sampled_x[:, :, 2] = self.sample_cluster(featureW, centroid_w) 96 | sampled_x[:, :, 3] = self.sample_cluster(featureH, centroid_h) 97 | 98 | sampled_delta[:, t:t+1, :] = sampled_x.clone() 99 | h, featureX, featureY, featureW, featureH = self.single_forward(sampled_x, social[:, t:t+1, :], h) 100 | 101 | featureX = featureX.view(B, 1, -1) 102 | featureY = featureY.view(B, 1, -1) 103 | featureW = featureW.view(B, 1, -1) 104 | featureH = featureH.view(B, 1, -1) 105 | 106 | reconstructed_x.append(featureX) 107 | reconstructed_y.append(featureY) 108 | reconstructed_w.append(featureW) 109 | reconstructed_h.append(featureH) 110 | 111 | sampled_delta[:, -1:, :] = observation[:, -1:, :].clone() 112 | 113 | return sampled_delta, torch.cat(reconstructed_x, dim=1), torch.cat(reconstructed_y, dim=1),\ 114 | torch.cat(reconstructed_w, dim=1), torch.cat(reconstructed_h, dim=1) 115 | 116 | def inference(self, observation, social, gap, centroid_x, centroid_y, centroid_w, centroid_h): 117 | # observation and last box are bboxes not delta! 
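        # Autoregressive sampling: the observed boxes are first converted to deltas and encoded
        # in one pass; then, for each of the `gap` future steps, a delta is sampled per component
        # (x/y/w/h) from the predicted cluster distributions via sample_cluster_multinomial,
        # accumulated onto the running box, and fed back into the model together with the
        # corresponding social feature.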
118 | B, mask_index = observation.shape[0], observation.shape[1] 119 | 120 | T = mask_index + gap 121 | 122 | # h = init_h 123 | h = self.init_hidden(B) 124 | 125 | reconstructed_x = [] 126 | reconstructed_y = [] 127 | reconstructed_w = [] 128 | reconstructed_h = [] 129 | 130 | generated_bbox = [] 131 | 132 | last_bbox = [observation[0, -1, 0].item(), 133 | observation[0, -1, 1].item(), 134 | observation[0, -1, 2].item(), 135 | observation[0, -1, 3].item()] 136 | 137 | generated_delta = torch.zeros(1, gap, 4) 138 | delta_cnt = 0 139 | 140 | # new alternative 141 | if observation.shape[1] == 1: 142 | current_obs = observation[:, 0:1, :] * 0.0 143 | h, featureX_obs, featureY_obs, featureW_obs, featureH_obs = self.single_forward(current_obs, social[:, :observation.shape[1], :], h) 144 | featureX, featureY, featureW, featureH = featureX_obs[:, -1, :].unsqueeze(1), featureY_obs[:, -1, :].unsqueeze(1), featureW_obs[:, -1, :].unsqueeze(1), featureH_obs[:, -1, :].unsqueeze(1) 145 | else: 146 | current_obs = observation[:, 1:, :] - observation[:, :-1, :] 147 | # current_obs = torch.cat([torch.zeros(current_obs.shape[0], 1, current_obs.shape[2]).cuda(), current_obs], dim=1) 148 | h, featureX_obs, featureY_obs, featureW_obs, featureH_obs = self.single_forward(current_obs, social[:, :current_obs.shape[1], :], h) 149 | featureX, featureY, featureW, featureH = featureX_obs[:, -1, :].unsqueeze(1), featureY_obs[:, -1, :].unsqueeze(1), featureW_obs[:, -1, :].unsqueeze(1), featureH_obs[:, -1, :].unsqueeze(1) 150 | 151 | 152 | # new alternative 153 | for t in range(observation.shape[1], observation.shape[1] + gap - 1): 154 | sampled_x = torch.zeros(B, 1, 4).cuda() 155 | sampled_x[:, :, 0] = self.sample_cluster_multinomial(featureX, centroid_x) 156 | sampled_x[:, :, 1] = self.sample_cluster_multinomial(featureY, centroid_y) 157 | sampled_x[:, :, 2] = self.sample_cluster_multinomial(featureW, centroid_w) 158 | sampled_x[:, :, 3] = self.sample_cluster_multinomial(featureH, centroid_h) 159 | 160 | # sampled_x[:, :, 0] = self.sample_cluster(featureX, centroid_x) 161 | # sampled_x[:, :, 1] = self.sample_cluster(featureY, centroid_y) 162 | # sampled_x[:, :, 2] = self.sample_cluster(featureW, centroid_w) 163 | # sampled_x[:, :, 3] = self.sample_cluster(featureH, centroid_h) 164 | 165 | generated_delta[0, delta_cnt] = sampled_x[0, 0] 166 | delta_cnt += 1 167 | 168 | last_bbox = [last_bbox[0] + sampled_x[0, 0, 0].item(), 169 | last_bbox[1] + sampled_x[0, 0, 1].item(), 170 | last_bbox[2] + sampled_x[0, 0, 2].item(), 171 | last_bbox[3] + sampled_x[0, 0, 3].item()] 172 | 173 | generated_bbox.append(last_bbox) 174 | current_social = social[:, t:t + 1, :] 175 | h, featureX, featureY, featureW, featureH = self.single_forward(sampled_x, current_social, h) 176 | 177 | featureX = featureX.view(B, 1, -1) 178 | featureY = featureY.view(B, 1, -1) 179 | featureW = featureW.view(B, 1, -1) 180 | featureH = featureH.view(B, 1, -1) 181 | 182 | reconstructed_x.append(featureX) 183 | reconstructed_y.append(featureY) 184 | reconstructed_w.append(featureW) 185 | reconstructed_h.append(featureH) 186 | 187 | sampled_detection = torch.zeros(B, 1, 4).cuda() 188 | sampled_detection[:, :, 0] = self.sample_cluster_multinomial(featureX, centroid_x) 189 | sampled_detection[:, :, 1] = self.sample_cluster_multinomial(featureY, centroid_y) 190 | sampled_detection[:, :, 2] = self.sample_cluster_multinomial(featureW, centroid_w) 191 | sampled_detection[:, :, 3] = self.sample_cluster_multinomial(featureH, centroid_h) 192 | # sampled_detection[:, :, 0] = 
self.sample_cluster(featureX, centroid_x) 193 | # sampled_detection[:, :, 1] = self.sample_cluster(featureY, centroid_y) 194 | # sampled_detection[:, :, 2] = self.sample_cluster(featureW, centroid_w) 195 | # sampled_detection[:, :, 3] = self.sample_cluster(featureH, centroid_h) 196 | 197 | last_detection = [last_bbox[0] + sampled_detection[0, 0, 0].item(), 198 | last_bbox[1] + sampled_detection[0, 0, 1].item(), 199 | last_bbox[2] + sampled_detection[0, 0, 2].item(), 200 | last_bbox[3] + sampled_detection[0, 0, 3].item()] 201 | 202 | if len(reconstructed_x) > 0: 203 | reconstructed_x, reconstructed_y, reconstructed_w, reconstructed_h = torch.cat(reconstructed_x, dim=1),\ 204 | torch.cat(reconstructed_y, dim=1), torch.cat(reconstructed_w, dim=1), torch.cat(reconstructed_h, dim=1) 205 | 206 | reconstructed_x, reconstructed_y, reconstructed_w, reconstructed_h = torch.cat([featureX_obs, reconstructed_x], dim=1), \ 207 | torch.cat([featureY_obs, reconstructed_y], dim=1), \ 208 | torch.cat([featureW_obs, reconstructed_w], dim=1), \ 209 | torch.cat([featureH_obs, reconstructed_h], dim=1) 210 | else: 211 | reconstructed_x, reconstructed_y, reconstructed_w, reconstructed_h = featureX_obs, featureY_obs, featureW_obs, featureH_obs 212 | return reconstructed_x, reconstructed_y, reconstructed_w, reconstructed_h, \ 213 | generated_bbox, generated_delta, last_detection 214 | 215 | def batch_inference(self, observation, social, gap, centroid_x, centroid_y, centroid_w, centroid_h): 216 | # observation and last box are bboxes not delta! 217 | B, mask_index = observation.shape[0], observation.shape[1] 218 | 219 | T = mask_index + gap 220 | 221 | # h = init_h 222 | h = self.init_hidden(B) 223 | 224 | reconstructed_x = [] 225 | reconstructed_y = [] 226 | reconstructed_w = [] 227 | reconstructed_h = [] 228 | 229 | generated_bbox = torch.zeros(B, gap - 1, 4).cuda() 230 | 231 | last_bbox = torch.zeros(B, 1, 4).cuda() 232 | last_bbox[:, 0, 0] = observation[:, -1, 0] 233 | last_bbox[:, 0, 1] = observation[:, -1, 1] 234 | last_bbox[:, 0, 2] = observation[:, -1, 2] 235 | last_bbox[:, 0, 3] = observation[:, -1, 3] 236 | 237 | generated_delta = torch.zeros(B, gap - 1, 4) 238 | delta_cnt = 0 239 | 240 | # new alternative 241 | if observation.shape[1] == 1: 242 | current_obs = observation[:, 0:1, :] * 0.0 243 | h, featureX_obs, featureY_obs, featureW_obs, featureH_obs = self.single_forward(current_obs, social[:, :observation.shape[1], :], h) 244 | featureX, featureY, featureW, featureH = featureX_obs[:, -1, :].unsqueeze(1), featureY_obs[:, -1, :].unsqueeze(1), featureW_obs[:, -1, :].unsqueeze(1), featureH_obs[:, -1, :].unsqueeze(1) 245 | else: 246 | current_obs = observation[:, 1:, :] - observation[:, :-1, :] 247 | # current_obs = torch.cat([torch.zeros(current_obs.shape[0], 1, current_obs.shape[2]).cuda(), current_obs], dim=1) 248 | h, featureX_obs, featureY_obs, featureW_obs, featureH_obs = self.single_forward(current_obs, social[:, :current_obs.shape[1], :], h) 249 | featureX, featureY, featureW, featureH = featureX_obs[:, -1, :].unsqueeze(1), featureY_obs[:, -1, :].unsqueeze(1), featureW_obs[:, -1, :].unsqueeze(1), featureH_obs[:, -1, :].unsqueeze(1) 250 | 251 | 252 | # new alternative 253 | for t in range(observation.shape[1], observation.shape[1] + gap - 1): 254 | sampled_x = torch.zeros(B, 1, 4).cuda() 255 | sampled_x[:, :, 0] = self.sample_cluster_multinomial(featureX, centroid_x) 256 | sampled_x[:, :, 1] = self.sample_cluster_multinomial(featureY, centroid_y) 257 | sampled_x[:, :, 2] = 
self.sample_cluster_multinomial(featureW, centroid_w) 258 | sampled_x[:, :, 3] = self.sample_cluster_multinomial(featureH, centroid_h) 259 | 260 | generated_delta[:, delta_cnt] = sampled_x[:, 0] 261 | 262 | last_bbox[:, 0, 0] += sampled_x[:, 0, 0] 263 | last_bbox[:, 0, 1] += sampled_x[:, 0, 1] 264 | last_bbox[:, 0, 2] += sampled_x[:, 0, 2] 265 | last_bbox[:, 0, 3] += sampled_x[:, 0, 3] 266 | 267 | generated_bbox[:, delta_cnt] = last_bbox[:, 0] 268 | delta_cnt += 1 269 | 270 | current_social = social[:, t:t + 1, :] 271 | h, featureX, featureY, featureW, featureH = self.single_forward(sampled_x, current_social, h) 272 | 273 | featureX = featureX.view(B, 1, -1) 274 | featureY = featureY.view(B, 1, -1) 275 | featureW = featureW.view(B, 1, -1) 276 | featureH = featureH.view(B, 1, -1) 277 | 278 | reconstructed_x.append(featureX) 279 | reconstructed_y.append(featureY) 280 | reconstructed_w.append(featureW) 281 | reconstructed_h.append(featureH) 282 | 283 | sampled_detection = torch.zeros(B, 1, 4).cuda() 284 | sampled_detection[:, :, 0] = self.sample_cluster_multinomial(featureX, centroid_x) 285 | sampled_detection[:, :, 1] = self.sample_cluster_multinomial(featureY, centroid_y) 286 | sampled_detection[:, :, 2] = self.sample_cluster_multinomial(featureW, centroid_w) 287 | sampled_detection[:, :, 3] = self.sample_cluster_multinomial(featureH, centroid_h) 288 | # sampled_detection[:, :, 0] = self.sample_cluster(featureX, centroid_x) 289 | # sampled_detection[:, :, 1] = self.sample_cluster(featureY, centroid_y) 290 | # sampled_detection[:, :, 2] = self.sample_cluster(featureW, centroid_w) 291 | # sampled_detection[:, :, 3] = self.sample_cluster(featureH, centroid_h) 292 | 293 | last_detection = torch.zeros(B, 1, 4).cuda() 294 | last_detection[:, 0, 0] = last_bbox[:, 0, 0] + sampled_detection[:, 0, 0] 295 | last_detection[:, 0, 1] = last_bbox[:, 0, 1] + sampled_detection[:, 0, 1] 296 | last_detection[:, 0, 2] = last_bbox[:, 0, 2] + sampled_detection[:, 0, 2] 297 | last_detection[:, 0, 3] = last_bbox[:, 0, 3] + sampled_detection[:, 0, 3] 298 | 299 | if len(reconstructed_x) > 0: 300 | reconstructed_x, reconstructed_y, reconstructed_w, reconstructed_h = torch.cat(reconstructed_x, dim=1),\ 301 | torch.cat(reconstructed_y, dim=1), torch.cat(reconstructed_w, dim=1), torch.cat(reconstructed_h, dim=1) 302 | 303 | reconstructed_x, reconstructed_y, reconstructed_w, reconstructed_h = torch.cat([featureX_obs, reconstructed_x], dim=1), \ 304 | torch.cat([featureY_obs, reconstructed_y], dim=1), \ 305 | torch.cat([featureW_obs, reconstructed_w], dim=1), \ 306 | torch.cat([featureH_obs, reconstructed_h], dim=1) 307 | else: 308 | reconstructed_x, reconstructed_y, reconstructed_w, reconstructed_h = featureX_obs, featureY_obs, featureW_obs, featureH_obs 309 | return reconstructed_x, reconstructed_y, reconstructed_w, reconstructed_h, \ 310 | generated_bbox, generated_delta, last_detection -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | 6 | 7 | def to_var(x, volatile=False): 8 | if torch.cuda.is_available(): 9 | x = x.cuda() 10 | return torch.autograd.Variable(x, volatile=volatile) 11 | -------------------------------------------------------------------------------- /models/residual_block.py: -------------------------------------------------------------------------------- 1 | 
from __future__ import print_function 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | """ Residual Network that is then used for the VAE encoder and the VAE decoder. """ 8 | 9 | def __init__(self, input_size, embedding_size): 10 | super().__init__() 11 | 12 | self.shortcut = nn.Linear(input_size, embedding_size) 13 | self.deep1 = nn.Linear(input_size, embedding_size // 2) 14 | self.deep2 = nn.Linear(embedding_size // 2, embedding_size // 2) 15 | self.deep3 = nn.Linear(embedding_size // 2, embedding_size) 16 | 17 | def forward(self, input_tensor, activation=nn.ReLU()): 18 | 19 | if activation is not None: 20 | 21 | shortcut = self.shortcut(input_tensor) 22 | 23 | deep1 = activation(self.deep1(input_tensor)) 24 | deep2 = activation(self.deep2(deep1)) 25 | deep3 = self.deep3(deep2) 26 | 27 | else: 28 | 29 | shortcut = self.shortcut(input_tensor) 30 | 31 | deep1 = self.deep1(input_tensor) 32 | deep2 = self.deep2(deep1) 33 | deep3 = self.deep3(deep2) 34 | 35 | output = shortcut + deep3 36 | 37 | return output 38 | -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_00.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_01.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_02.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_03.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_04.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_05.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_06.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_07.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_07.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_08.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_09.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_09.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_10.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_11.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_12.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_13.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_14.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_15.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_16.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_17.jpg 
-------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_18.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_19.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_20.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_21.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_22.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_23.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_24.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_24.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_25.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_25.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_26.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_26.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_27.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_28.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_28.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_29.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_29.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_30.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_30.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_31.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_31.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_32.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_33.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_33.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_34.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_34.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_35.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_35.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_36.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_36.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_37.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_37.jpg -------------------------------------------------------------------------------- /temp/ar/inpainting/inpainting_38.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_38.jpg -------------------------------------------------------------------------------- 
/temp/ar/inpainting/inpainting_39.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/inpainting/inpainting_39.jpg -------------------------------------------------------------------------------- /temp/ar/log_p/likelihood.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/temp/ar/log_p/likelihood.jpg -------------------------------------------------------------------------------- /train_ae.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.utils.data 4 | from torch import nn, optim 5 | from torch.nn import functional as F 6 | import numpy as np 7 | from .dataloader_ae import dataloader 8 | from models.ae import motion_ae 9 | 10 | 11 | torch.backends.cudnn.benchmark = True 12 | 13 | 14 | def train(model): 15 | 16 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 17 | tf = 1 18 | gradient_clip = 0.1 19 | max_iter = 80001 20 | model.train() 21 | 22 | for batch_idx in range(max_iter): 23 | seq_len = np.random.randint(3, 100) 24 | data, wh = next(iter(train_loader.generate(seq_len))) 25 | 26 | data = data.cuda() 27 | data_vel = data[:, 1:, :] - data[:, :-1, :] 28 | optimizer.zero_grad() 29 | 30 | reconstructed = model(data_vel, tf) 31 | loss = F.mse_loss(reconstructed, data_vel, reduction="sum") 32 | loss.backward() 33 | 34 | nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) 35 | optimizer.step() 36 | 37 | if batch_idx % 50 == 0: 38 | tf *= 0.99 39 | model.eval() 40 | 41 | with torch.no_grad(): 42 | test_sample, wh = next(iter(val_loader.generate(100))) 43 | test_sample = test_sample.cuda() 44 | test_sample_vel = test_sample[:, 1:, :] - test_sample[:, :-1, :] 45 | reconstructed = model(test_sample_vel, 0) 46 | 47 | np.save("temp/ae/train/reconstructed.npy", reconstructed.cpu().detach().numpy()[0]) 48 | np.save("temp/ae/train/original.npy", test_sample.cpu().detach().numpy()[0]) 49 | np.save("temp/ae/train/wh.npy", wh.numpy()[0]) 50 | 51 | print("[", batch_idx, "]\tLoss: ", round(loss.item() / seq_len, 4), "\ttf: ", round(tf, 4)) 52 | model.train() 53 | if batch_idx % 10000 == 0: 54 | torch.save(model.state_dict(), "checkpoint/ae/ae_" + str(batch_idx//10000) + ".pth") 55 | 56 | 57 | def evaluate(model): 58 | with torch.no_grad(): 59 | test_sample, wh = next(iter(val_loader.generate(100))) 60 | test_sample = test_sample.cuda() 61 | test_sample_vel = test_sample[:, 1:, :] - test_sample[:, :-1, :] 62 | reconstructed = model(test_sample_vel, 0) 63 | 64 | np.save("temp/ae/eval_reconstructed.npy", reconstructed.cpu().detach().numpy()[0]) 65 | np.save("temp/ae/eval_original.npy", test_sample.cpu().detach().numpy()[0]) 66 | np.save("temp/ae/eval_wh.npy", wh.numpy()[0]) 67 | 68 | 69 | if __name__ == "__main__": 70 | 71 | train_loader = dataloader(split="train", batch_size=512) 72 | val_loader = dataloader(split="val", batch_size=1) 73 | 74 | model = motion_ae(256).cuda() 75 | train(model) 76 | 77 | #model.load_state_dict(torch.load("checkpoints/ae_weights/ae_combined_path_mot_8.pth")) 78 | #model.eval() 79 | #evaluate(model) 80 | 81 | -------------------------------------------------------------------------------- /train_ar.py: -------------------------------------------------------------------------------- 1 | from torch 
import nn, optim 2 | from utils.utils_ar import * 3 | from utils.dataloader_ar import dataloader as data_loader 4 | from utils.clustering import load_clusters 5 | from models.ar import motion_ar 6 | from models.ae import motion_ae 7 | 8 | torch.backends.cudnn.enabled = False 9 | 10 | 11 | def jitter_seq(seq, jitter): 12 | new_seq = seq.clone() 13 | for t in range(new_seq.shape[1]): 14 | new_seq[:, t, 0] += np.random.randint(-1 * jitter, jitter) 15 | new_seq[:, t, 1] += np.random.randint(-1 * jitter, jitter) 16 | new_seq[:, t, 2] += np.random.randint(-1 * jitter, jitter) 17 | new_seq[:, t, 3] += np.random.randint(-1 * jitter, jitter) 18 | 19 | return new_seq 20 | 21 | 22 | def train(model, model_ae, centroid_x, centroid_y, centroid_w, centroid_h, num_cluster): 23 | 24 | train_loader = data_loader(model_ae, split="train", batch_size=256) 25 | model.train() 26 | model_optimizer = optim.Adam(model.parameters(), lr=1e-3) 27 | 28 | video_len = 100 29 | gradient_clip = 0.1 30 | MAX_ITER = 200000 31 | criterion = nn.NLLLoss().cuda() 32 | 33 | for batch_idx in range(MAX_ITER): 34 | 35 | data, image_wh, social = next(train_loader.generate(video_len)) 36 | self_data = data.cuda() 37 | image_wh = image_wh.cuda() 38 | social = social.cuda() 39 | 40 | if_jitter = np.random.randint(1, 100) > 50 41 | 42 | if if_jitter: 43 | self_data = jitter_seq(self_data, 10) 44 | 45 | # computing delta 46 | self_delta_tmp = self_data[:, 1:, :] - self_data[:, 0:-1, :] 47 | self_delta = torch.zeros(self_delta_tmp.shape[0], self_delta_tmp.shape[1] + 1, self_delta_tmp.shape[2]).cuda() 48 | self_delta[:, 1:, :] = self_delta_tmp 49 | 50 | rand_len = np.random.randint(5, video_len - 1) 51 | 52 | if_mask = np.random.randint(1, 10) < 8 53 | mask_index = rand_len + 1 54 | if if_mask: 55 | mask_index = int(rand_len * 0.7) 56 | 57 | 58 | # 2. 
Train model 59 | model_optimizer.zero_grad() 60 | 61 | sampled_deltas, dist_x, dist_y, dist_w, dist_h = model(self_delta[:, :rand_len, :], social[:, :rand_len, :], mask_index, centroid_x, 62 | centroid_y, centroid_w, centroid_h) 63 | 64 | nll_loss = 0 65 | dist_x = nn.LogSoftmax(dim=-1)(dist_x) 66 | dist_y = nn.LogSoftmax(dim=-1)(dist_y) 67 | dist_w = nn.LogSoftmax(dim=-1)(dist_w) 68 | dist_h = nn.LogSoftmax(dim=-1)(dist_h) 69 | for i in range(1, rand_len): 70 | GT_delta = self_data[:, i: i + 1, :] - self_data[:, i - 1: i, :] 71 | GT_delta_norm = torch.zeros(GT_delta.shape).cuda() 72 | GT_delta_norm[:, :, 0] = GT_delta[:, :, 0] 73 | GT_delta_norm[:, :, 1] = GT_delta[:, :, 1] 74 | GT_delta_norm[:, :, 2] = GT_delta[:, :, 2] 75 | GT_delta_norm[:, :, 3] = GT_delta[:, :, 3] 76 | delta_x = GT_delta_norm[:, :, 0:1] 77 | delta_y = GT_delta_norm[:, :, 1:2] 78 | delta_w = GT_delta_norm[:, :, 2:3] 79 | delta_h = GT_delta_norm[:, :, 3:4] 80 | min_idx_x, min_idx_y, min_idx_w, min_idx_h = quantize_cluster(delta_x, delta_y, delta_w, delta_h, 81 | centroid_x, centroid_y, centroid_w, centroid_h, num_cluster) 82 | 83 | nll_loss += criterion(dist_x[:, i - 1:i, :].view(dist_x.shape[0], -1), min_idx_x.view(-1).long()) 84 | nll_loss += criterion(dist_y[:, i - 1:i, :].view(dist_y.shape[0], -1), min_idx_y.view(-1).long()) 85 | nll_loss += criterion(dist_w[:, i - 1:i, :].view(dist_w.shape[0], -1), min_idx_w.view(-1).long()) 86 | nll_loss += criterion(dist_h[:, i - 1:i, :].view(dist_h.shape[0], -1), min_idx_h.view(-1).long()) 87 | 88 | loss = nll_loss #+ euc_dist_loss 89 | loss.backward() 90 | nn.utils.clip_grad_norm_(model.parameters(), gradient_clip) 91 | model_optimizer.step() 92 | 93 | if batch_idx % 250 == 0: 94 | for param_group in model_optimizer.param_groups: 95 | param_group['lr'] *= 0.999 96 | 97 | print(batch_idx, 98 | "\t nll:", round(nll_loss.item()/rand_len, 2)) 99 | 100 | if batch_idx % 1000 == 0: 101 | torch.save(model.state_dict(), "checkpoint/ar/ar_" + str(batch_idx // 1000) + ".pth") 102 | 103 | 104 | if __name__ == '__main__': 105 | centroid_x, centroid_y, centroid_w, centroid_h = load_clusters() 106 | num_cluster = 1024 107 | model_ae = motion_ae(256).cuda() 108 | model_ae.load_state_dict(torch.load("checkpoint/ae/ae_8.pth")) 109 | model_ae.eval() 110 | model_social = motion_ar(num_clusters=num_cluster).cuda() 111 | train(model_social, model_ae, centroid_x, centroid_y, centroid_w, centroid_h, num_cluster)  # pass the pretrained MA-Net instance, not the class 112 | -------------------------------------------------------------------------------- /utils/__pycache__/clustering.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/utils/__pycache__/clustering.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader_ar.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/utils/__pycache__/dataloader_ar.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_ar.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fatemeh-slh/ArTIST/3dece1b9593003a7293715a2bc089c39689306cd/utils/__pycache__/utils_ar.cpython-39.pyc 
-------------------------------------------------------------------------------- /utils/clustering.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | from sklearn.cluster import KMeans 4 | from random import shuffle 5 | import torch 6 | import argparse 7 | 8 | 9 | def prepare_clustering_data(data_path="train_path_mot.npy"): 10 | dataset = np.load(data_path, allow_pickle=True).item() 11 | samples_x = [] 12 | samples_y = [] 13 | samples_w = [] 14 | samples_h = [] 15 | for dir in dataset.keys(): 16 | for tracklet_key, tracklet_value in dataset[dir]['tracklets'].items(): 17 | for t in range(len(tracklet_value['sequence']) - 1): 18 | items1 = np.array(tracklet_value['sequence'][t]) 19 | items2 = np.array(tracklet_value['sequence'][t + 1]) 20 | 21 | x1 = float(items1[2]) / int(dataset[dir]['imWidth']) 22 | y1 = float(items1[3]) / int(dataset[dir]['imHeight']) 23 | w1 = float(items1[4]) / int(dataset[dir]['imWidth']) 24 | h1 = float(items1[5]) / int(dataset[dir]['imHeight']) 25 | x2 = float(items2[2]) / int(dataset[dir]['imWidth']) 26 | y2 = float(items2[3]) / int(dataset[dir]['imHeight']) 27 | w2 = float(items2[4]) / int(dataset[dir]['imWidth']) 28 | h2 = float(items2[5]) / int(dataset[dir]['imHeight']) 29 | samples_x.append(x2 - x1) 30 | samples_y.append(y2 - y1) 31 | samples_w.append(w2 - w1) 32 | samples_h.append(h2 - h1) 33 | print(dir, tracklet_key) 34 | 35 | shuffle(samples_x) 36 | shuffle(samples_y) 37 | shuffle(samples_w) 38 | shuffle(samples_h) 39 | 40 | np.save('centroid/samples_x_motpath.npy', np.array(samples_x, dtype='float32')) 41 | np.save('centroid/samples_y_motpath.npy', np.array(samples_y, dtype='float32')) 42 | np.save('centroid/samples_w_motpath.npy', np.array(samples_w, dtype='float32')) 43 | np.save('centroid/samples_h_motpath.npy', np.array(samples_h, dtype='float32')) 44 | 45 | 46 | def clustering_all(num_cluster=1024, stride=4): 47 | 48 | samples_x = np.load('centroid/samples_x_motpath.npy') 49 | samples_y = np.load('centroid/samples_y_motpath.npy') 50 | samples_w = np.load('centroid/samples_w_motpath.npy') 51 | samples_h = np.load('centroid/samples_h_motpath.npy') 52 | 53 | samples_x = samples_x[0::stride] 54 | kmeans_x = KMeans(n_clusters=num_cluster, verbose=1) 55 | kmeans_x.fit(np.expand_dims(samples_x, 1)) 56 | np.save('cluster/centroids_x.npy', kmeans_x.cluster_centers_) 57 | 58 | samples_y = samples_y[0::stride] 59 | kmeans_y = KMeans(n_clusters=num_cluster, verbose=1) 60 | kmeans_y.fit(np.expand_dims(samples_y, 1)) 61 | np.save('cluster/centroids_y.npy', kmeans_y.cluster_centers_) 62 | 63 | samples_w = samples_w[0::stride] 64 | kmeans_w = KMeans(n_clusters=num_cluster, verbose=1) 65 | kmeans_w.fit(np.expand_dims(samples_w, 1)) 66 | np.save('cluster/centroids_w.npy', kmeans_w.cluster_centers_) 67 | # 68 | samples_h = samples_h[0::stride] 69 | kmeans_h = KMeans(n_clusters=num_cluster, verbose=1) 70 | kmeans_h.fit(np.expand_dims(samples_h, 1)) 71 | np.save('cluster/centroids_h.npy', kmeans_h.cluster_centers_) 72 | 73 | 74 | def clustering(num_cluster=1024, stride=4, data_component='x'): 75 | if data_component == 'x': 76 | samples = np.load('centroid/samples_x_motpath.npy') 77 | if data_component == 'y': 78 | samples = np.load('centroid/samples_y_motpath.npy') 79 | if data_component == 'w': 80 | samples = np.load('centroid/samples_w_motpath.npy') 81 | if data_component == 'h': 82 | samples = np.load('centroid/samples_h_motpath.npy') 83 | 84 | samples = 
samples[0::stride] 85 | kmeans = KMeans(n_clusters=num_cluster, verbose=1) 86 | kmeans.fit(np.expand_dims(samples, 1)) 87 | np.save('cluster/centroids_' + data_component + '.npy', kmeans.cluster_centers_) 88 | 89 | 90 | def load_clusters(): 91 | centroid_x = np.load("centroid/centroids_x.npy") 92 | centroid_x = np.sort(centroid_x, 0) 93 | centroid_x = torch.from_numpy(centroid_x).transpose(1, 0).cuda() 94 | 95 | centroid_y = np.load("centroid/centroids_y.npy") 96 | centroid_y = np.sort(centroid_y, 0) 97 | centroid_y = torch.from_numpy(centroid_y).transpose(1, 0).cuda() 98 | 99 | centroid_w = np.load("centroid/centroids_w.npy") 100 | centroid_w = np.sort(centroid_w, 0) 101 | centroid_w = torch.from_numpy(centroid_w).transpose(1, 0).cuda() 102 | 103 | centroid_h = np.load("centroid/centroids_h.npy") 104 | centroid_h = np.sort(centroid_h, 0) 105 | centroid_h = torch.from_numpy(centroid_h).transpose(1, 0).cuda() 106 | 107 | return centroid_x, centroid_y, centroid_w, centroid_h 108 | 109 | 110 | if __name__ == '__main__': 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument("-t", "--task", type=str, default='do_all', 113 | choices=['prepare', 'cluster_x', 'cluster_y', 'cluster_w', 'cluster_h', 'cluster_all', 'do_all']) 114 | parser.add_argument("-d", "--data", type=str, required=True) 115 | parser.add_argument("-s", "--stride", type=int, default=4) 116 | parser.add_argument("-c", "--cluster", type=int, default=1024) 117 | args = parser.parse_args() 118 | 119 | if args.task == "prepare": 120 | prepare_clustering_data(args.data) 121 | if args.task == "cluster_x": 122 | clustering(num_cluster=args.cluster, stride=args.stride, data_component='x') 123 | if args.task == "cluster_y": 124 | clustering(num_cluster=args.cluster, stride=args.stride, data_component='y') 125 | if args.task == "cluster_w": 126 | clustering(num_cluster=args.cluster, stride=args.stride, data_component='w') 127 | if args.task == "cluster_h": 128 | clustering(num_cluster=args.cluster, stride=args.stride, data_component='h') 129 | if args.task == "cluster_all": 130 | clustering_all(num_cluster=args.cluster, stride=args.stride) 131 | if args.task == "do_all": 132 | prepare_clustering_data(args.data) 133 | clustering_all(num_cluster=args.cluster, stride=args.stride) -------------------------------------------------------------------------------- /utils/create_demo_test_subset.py: -------------------------------------------------------------------------------- 1 | from utils.dataloader_ar import dataloader 2 | from models.ae import motion_ae 3 | import torch 4 | import numpy as np 5 | 6 | if __name__ == "__main__": 7 | 8 | model_ae = motion_ae(256).cuda() 9 | model_ae.load_state_dict(torch.load("checkpoint/ae/ae_8.pth")) 10 | model_ae.eval() 11 | val_loader = dataloader(model_ae, split="val", batch_size=1) 12 | test_set = [] 13 | for i in range(1000): 14 | seq_len = np.random.randint(5, 100) 15 | obs_len = np.random.randint(1, seq_len - 1) 16 | data, wh, social = next(iter(val_loader.generate(seq_len))) 17 | test_set.append({'data': data, 'wh': wh, 'social': social, 'seq_len': seq_len, 'obs_len': obs_len}) 18 | 19 | np.save('data/demo_test_subset.npy', test_set) 20 | -------------------------------------------------------------------------------- /utils/dataloader_ae.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch 4 | import random 5 | 6 | 7 | class dataloader: 8 | def __init__(self, split="train", 
batch_size=64): 9 | self.batch_size = batch_size 10 | if split == "train": 11 | self.dataset = np.load('data/postp_combined_path_mot_train.npy', allow_pickle=True).item() 12 | else: 13 | self.dataset = np.load('data/postp_mot_val.npy', allow_pickle=True).item() 14 | 15 | def generate(self, seq_len=180): 16 | while True: 17 | current_batch_motion = np.zeros((self.batch_size, seq_len, 4)) 18 | current_image_wh = torch.zeros((self.batch_size, 2)) 19 | 20 | for batch in range(self.batch_size): 21 | while True: 22 | dir = random.choice(list(self.dataset.keys())) 23 | selected_seq = random.choice(list(self.dataset[dir]['tracklets'].keys())) 24 | len_selected_seq = len(self.dataset[dir]['tracklets'][selected_seq]['sequence']) 25 | if len_selected_seq > seq_len + 1: 26 | break 27 | 28 | # select a random time 29 | offset = np.random.randint(0, len_selected_seq - seq_len) 30 | 31 | for t in range(offset, offset + seq_len): 32 | items = np.array(self.dataset[dir]['tracklets'][selected_seq]['sequence'][t]) 33 | current_batch_motion[batch, t - offset] = items[2:6] 34 | current_image_wh[batch, 0] = int(self.dataset[dir]['imWidth']) 35 | current_image_wh[batch, 1] = int(self.dataset[dir]['imHeight']) 36 | 37 | yield torch.from_numpy(current_batch_motion).float(), current_image_wh 38 | 39 | 40 | if __name__ == '__main__': 41 | loader = dataloader() 42 | data, wh = next(iter(loader.generate(20))) 43 | -------------------------------------------------------------------------------- /utils/dataloader_ar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch 4 | import random 5 | 6 | 7 | def find_frame_idx(frame_num, sequences): 8 | for i in range(len(sequences['sequence'])): 9 | if int(sequences['sequence'][i][0]) == frame_num: 10 | return i 11 | return -1 12 | 13 | 14 | class dataloader: 15 | def __init__(self, model_ae, split="train", batch_size=64): 16 | self.batch_size = batch_size 17 | if split == "train": 18 | self.dataset = np.load('data/postp_combined_path_mot_train.npy', allow_pickle=True).item() 19 | else: 20 | self.dataset = np.load('data/postp_mot_val.npy', allow_pickle=True).item() 21 | self.model_ae = model_ae 22 | 23 | def generate(self, seq_len): 24 | while True: 25 | current_batch_motion = np.zeros((self.batch_size, seq_len, 4)) 26 | current_image_wh = torch.zeros((self.batch_size, 2)) 27 | social_list_h = [] 28 | for batch in range(self.batch_size): 29 | 30 | while True: 31 | dir = random.choice(list(self.dataset.keys())) 32 | selected_seq = random.choice(list(self.dataset[dir]['tracklets'].keys())) 33 | len_selected_seq = len(self.dataset[dir]['tracklets'][selected_seq]['sequence']) 34 | if len_selected_seq > seq_len + 1: 35 | break 36 | 37 | # select a random time 38 | offset = np.random.randint(0, len_selected_seq - seq_len) 39 | sequence = self.dataset[dir]['tracklets'][selected_seq]['sequence'] 40 | current_batch_motion[batch, :] = sequence[offset: offset + seq_len] 41 | 42 | current_image_wh[batch, 0] = int(self.dataset[dir]['imWidth']) 43 | current_image_wh[batch, 1] = int(self.dataset[dir]['imHeight']) 44 | 45 | social_ids = self.dataset[dir]['tracklets'][selected_seq]['social_ids'] 46 | social_ids = np.unique(social_ids) 47 | start_frame = int(self.dataset[dir]['tracklets'][selected_seq]['start']) + offset 48 | end_frame = int(self.dataset[dir]['tracklets'][selected_seq]['start']) + offset + seq_len 49 | social_dict = {} 50 | for sid in social_ids: 51 | start_s = 
self.dataset[dir]['tracklets'][sid]['start'] 52 | end_s = self.dataset[dir]['tracklets'][sid]['end'] 53 | SEQ = np.zeros((seq_len, 4)) 54 | START, END = -1, -1 55 | 56 | if start_s == start_frame and end_frame == end_s: 57 | SEQ = self.dataset[dir]['tracklets'][sid]['sequence'][0:-1] 58 | START = start_s 59 | END = end_s 60 | elif start_frame < end_s <= end_frame: 61 | if start_s >= start_frame: 62 | SEQ = self.dataset[dir]['tracklets'][sid]['sequence'][0:-1] 63 | START = start_s 64 | END = end_s 65 | else: 66 | idx = start_frame - start_s 67 | SEQ = self.dataset[dir]['tracklets'][sid]['sequence'][idx:-1] 68 | START = start_frame 69 | END = end_s 70 | elif end_s > start_frame and start_s < end_frame: 71 | if start_s >= start_frame: 72 | idx_end = end_frame - start_s 73 | SEQ = self.dataset[dir]['tracklets'][sid]['sequence'][:idx_end] 74 | START = start_s 75 | END = end_frame 76 | elif start_s < start_frame: 77 | idx_end = end_frame - start_s 78 | idx_start = start_frame - start_s 79 | SEQ = self.dataset[dir]['tracklets'][sid]['sequence'][idx_start:idx_end] 80 | START = start_frame 81 | END = end_frame 82 | if START != -1 and END != -1 and START != END: 83 | social_dict[sid] = {'seq': SEQ, 'start': START, 'end': END} 84 | 85 | current_batch_social = torch.zeros(len(social_dict.keys()), seq_len, 4) 86 | idx = 0 87 | for key, value in social_dict.items(): 88 | start_social = int(value['start']) 89 | end_social = int(value['end']) 90 | if start_social > end_social: 91 | pass 92 | else: 93 | 94 | if start_social == start_frame and end_social == end_frame: 95 | current_batch_social[idx] = torch.from_numpy(value['seq']) 96 | 97 | if start_social > start_frame and end_social < end_frame: 98 | current_batch_social[idx, :(start_social - start_frame)] = torch.from_numpy(value['seq'][0, :]) 99 | current_batch_social[idx, (start_social - start_frame):-(end_frame - end_social)] = torch.from_numpy(value['seq'][:, :]) 100 | current_batch_social[idx, -(end_frame - end_social):] = torch.from_numpy(value['seq'][-1, :]) 101 | elif end_social < end_frame and start_social == start_frame: 102 | current_batch_social[idx, -(end_frame - end_social):] = torch.from_numpy(value['seq'][-1, :]) 103 | current_batch_social[idx, :-(end_frame - end_social)] = torch.from_numpy(value['seq'][:, :]) 104 | elif end_social == end_frame and start_social > start_frame: 105 | current_batch_social[idx, :(start_social - start_frame)] = torch.from_numpy(value['seq'][0, :]) 106 | current_batch_social[idx, (start_social - start_frame):] = torch.from_numpy(value['seq'][:, :]) 107 | else: 108 | current_batch_social[idx] = torch.from_numpy(value['seq']) 109 | 110 | idx += 1 111 | social_vel = torch.zeros(current_batch_social.shape) 112 | social_vel[:, 1:, :] = current_batch_social[:, 1:, :] - current_batch_social[:, :-1, :] 113 | with torch.no_grad(): 114 | social_vel = social_vel.float().cuda() 115 | try: 116 | h = self.model_ae.inference(social_vel) 117 | h = torch.max(h, dim=0)[0].unsqueeze(0) 118 | except: 119 | h = torch.zeros(1, seq_len, 256).float().cuda() 120 | social_list_h.append(h) 121 | social_rep = torch.cat(social_list_h, dim=0) 122 | yield torch.from_numpy(current_batch_motion).float(), current_image_wh, social_rep.float() 123 | -------------------------------------------------------------------------------- /utils/post_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def postprocess(data='data/train_mot17.npy', 
output_file="data/postp_train_mot17.npy"): 5 | dataset = np.load(data, allow_pickle=True).item() 6 | 7 | for key, value in dataset.items(): 8 | for keyt, valuet in value['tracklets'].items(): 9 | seq = valuet['sequence'] 10 | new_seq = np.zeros((len(seq), 4)) 11 | for idx, item in enumerate(seq): 12 | new_seq[idx, 0] = float(seq[idx][2]) / int(value['imWidth']) 13 | new_seq[idx, 1] = float(seq[idx][3]) / int(value['imHeight']) 14 | new_seq[idx, 2] = float(seq[idx][4]) / int(value['imWidth']) 15 | new_seq[idx, 3] = float(seq[idx][5]) / int(value['imHeight']) 16 | valuet['sequence'] = new_seq 17 | 18 | np.save(output_file, dataset) -------------------------------------------------------------------------------- /utils/prepare_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import configparser 4 | import argparse 5 | 6 | 7 | def prepare_data(data_path='data/MOT17Labels', split='train'): 8 | 9 | data_dict = {} 10 | dirs = glob.glob(data_path + split + '/*') 11 | dirs.sort() 12 | 13 | for dir in dirs: 14 | max_id = 500000 15 | tracklets = {} 16 | if 'FRCNN' in dir: 17 | config = configparser.ConfigParser() 18 | config.read(dir + '/seqinfo.ini') 19 | seq = config['Sequence']['name'] 20 | if seq not in data_dict: 21 | data_dict[seq] = {} 22 | data_dict[seq]['frameRate'] = config['Sequence']['frameRate'] 23 | data_dict[seq]['imWidth'] = config['Sequence']['imWidth'] 24 | data_dict[seq]['imHeight'] = config['Sequence']['imHeight'] 25 | data_dict[seq]['tracklets'] = {} 26 | # change to gt_val_half if split = val 27 | # change to gt_train_half if split = train and you would like to keep half of the train for validation 28 | gt_file = dir + '/gt/gt.txt' 29 | fid = open(gt_file) 30 | lines = str.split(fid.read(), '\n')[:-1] 31 | for line in lines: 32 | items = str.split(line, ',') 33 | current_id = int(items[1]) 34 | if int(items[6]) != 1 or float(items[-1]) < 0.25: 35 | continue 36 | else: 37 | if int(items[1]) not in tracklets: 38 | tracklets[int(items[1])] = {} 39 | tracklets[int(items[1])]['sequence'] = [] 40 | tracklets[int(items[1])]['social_ids'] = [] 41 | tracklets[int(items[1])]['start'] = int(items[0]) 42 | tracklets[int(items[1])]['end'] = int(items[0]) 43 | current_id = int(items[1]) 44 | if int(items[0]) - tracklets[current_id]['end'] > 1: 45 | current_id = max_id + 1 46 | max_id += 1 47 | if current_id not in tracklets: 48 | tracklets[current_id] = {} 49 | tracklets[current_id]['sequence'] = [] 50 | tracklets[current_id]['social_ids'] = [] 51 | tracklets[current_id]['start'] = int(items[0]) 52 | tracklets[current_id]['end'] = int(items[0]) 53 | else: 54 | tracklets[current_id]['end'] = int(items[0]) 55 | tracklets[current_id]['sequence'].append(items) 56 | 57 | data_dict[seq]['tracklets'] = tracklets 58 | 59 | 60 | for key, value in data_dict.items(): 61 | for key_t, value_t in value['tracklets'].items(): 62 | id = key_t 63 | start = value_t['start'] 64 | end = value_t['end'] 65 | for key_s, value_s in value['tracklets'].items(): 66 | if key_s != key_t: 67 | if start<=start<=value_s['end'] <= end: 68 | value_t['social_ids'].append(key_s) 69 | if start<=value_s['start'] <= end: 70 | value_t['social_ids'].append(key_s) 71 | if value_s['start'] < start and value_s['end'] > end: 72 | value_t['social_ids'].append(key_s) 73 | 74 | if args.dataset == "mot": 75 | np.save('data/train_mot17.npy', data_dict) 76 | if args.dataset == "pathtrack": 77 | if split == "train": 78 | np.save('data/path_track_train.npy', 
data_dict) 79 | else: 80 | np.save('data/path_track_test.npy', data_dict) 81 | 82 | 83 | def merge_datasets(): 84 | 85 | dicts = [] 86 | mot_train = np.load('data/train_mot17.npy', allow_pickle=True).item() 87 | path_train = np.load('data/path_track_train.npy', allow_pickle=True).item() 88 | path_test = np.load('data/path_track_test.npy', allow_pickle=True).item() 89 | 90 | dicts.append(mot_train) 91 | dicts.append(path_train) 92 | dicts.append(path_test) 93 | 94 | super_dict = {} 95 | for d in dicts: 96 | for k, v in d.items(): 97 | super_dict[k] = v 98 | 99 | np.save('data/train_path_mot.npy', super_dict) 100 | 101 | 102 | if __name__ == '__main__': 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument("-t", "--task", type=str, default='prepare', choices=['prepare', 'merge']) 105 | parser.add_argument("-d", "--dataset", type=str, choices=['mot', 'pathtrack'], default='mot') 106 | parser.add_argument("-s", "--split", type=str, default='train') 107 | args = parser.parse_args() 108 | 109 | if args.task == "prepare": 110 | if args.dataset == 'mot': 111 | data_path = 'data/MOT17Labels' 112 | prepare_data(data_path=data_path, split=args.split) 113 | elif args.dataset == 'pathtrack': 114 | data_path = 'data/pathtrack_release_v1.0/pathtrack_release/' 115 | prepare_data(data_path=data_path, split=args.split) 116 | 117 | if args.task == "merge": 118 | merge_datasets() -------------------------------------------------------------------------------- /utils/utils_ar.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def iou(bb_test, bb_gt): 7 | """ 8 | Computes IoU between two bboxes in the form [x1,y1,x2,y2] 9 | """ 10 | xx1 = np.maximum(bb_test[0], bb_gt[0]) 11 | yy1 = np.maximum(bb_test[1], bb_gt[1]) 12 | xx2 = np.minimum(bb_test[2], bb_gt[2]) 13 | yy2 = np.minimum(bb_test[3], bb_gt[3]) 14 | w = np.maximum(0., xx2 - xx1) 15 | h = np.maximum(0., yy2 - yy1) 16 | wh = w * h 17 | o = wh / ((bb_test[2] - bb_test[0]) * (bb_test[3] - bb_test[1]) 18 | + (bb_gt[2] - bb_gt[0]) * (bb_gt[3] - bb_gt[1]) - wh) 19 | return o 20 | 21 | 22 | def quantize_cluster(delta_x, delta_y, delta_w, delta_h, centroid_x, centroid_y, centroid_w, centroid_h, num_cluster): 23 | 24 | B, T = delta_x.shape[0], delta_x.shape[1] 25 | # compute distances to the cluster centroids 26 | batch_centroid_x = centroid_x.unsqueeze(0).repeat(B, T, 1) 27 | batch_gt_x = delta_x.repeat(1, 1, num_cluster) 28 | dist_gt_x = torch.abs(batch_centroid_x.float() - batch_gt_x.float()) 29 | min_val_x, min_idx_x = torch.min(dist_gt_x, dim=-1) 30 | 31 | batch_centroid_y = centroid_y.unsqueeze(0).repeat(B, T, 1) 32 | batch_gt_y = delta_y.repeat(1, 1, num_cluster) 33 | dist_gt_y = torch.abs(batch_centroid_y.float() - batch_gt_y.float()) 34 | min_val_y, min_idx_y = torch.min(dist_gt_y, dim=-1) 35 | 36 | batch_centroid_w = centroid_w.unsqueeze(0).repeat(B, T, 1) 37 | batch_gt_w = delta_w.repeat(1, 1, num_cluster) 38 | dist_gt_w = torch.abs(batch_centroid_w.float() - batch_gt_w.float()) 39 | min_val_w, min_idx_w = torch.min(dist_gt_w, dim=-1) 40 | 41 | batch_centroid_h = centroid_h.unsqueeze(0).repeat(B, T, 1) 42 | batch_gt_h = delta_h.repeat(1, 1, num_cluster) 43 | dist_gt_h = torch.abs(batch_centroid_h.float() - batch_gt_h.float()) 44 | min_val_h, min_idx_h = torch.min(dist_gt_h, dim=-1) 45 | 46 | return min_idx_x, min_idx_y, min_idx_w, min_idx_h 47 | 48 | 49 | def quantize_bin(delta, bin): 50 | 51 | B, T = delta.shape[0], delta.shape[1] 52 | # compute 
distances to the cluster centroids 53 | batch_centroid = bin.unsqueeze(0).repeat(B, T, 1) 54 | batch_gt = delta.repeat(1, 1, bin.shape[1]) 55 | dist_gt = torch.abs(batch_centroid.float() - batch_gt.float()) 56 | min_val, min_idx = torch.min(dist_gt, dim=-1) 57 | 58 | return min_idx 59 | 60 | 61 | def sample_cluster(dist, centroid): 62 | indices = (torch.topk(dist, k=1)[1]).squeeze(1) 63 | return centroid[0, list(indices.squeeze())].unsqueeze(1) 64 | 65 | 66 | def infer_log_likelihood(dist_x, dist_y, dist_w, dist_h, x, y, w, h, centroid_x, centroid_y, centroid_w, centroid_h, num_clusters): 67 | probs = [] 68 | idx_x, idx_y, idx_w, idx_h = quantize_cluster(x[:, -1:, :], y[:, -1:, :], w[:, -1:, :], h[:, -1:, :], centroid_x, centroid_y, centroid_w, centroid_h, num_clusters) 69 | prob_t = [dist_x[:, -1:, idx_x.long().item()].item(), dist_y[:, -1:, idx_y.long().item()].item(), 70 | dist_w[:, -1:, idx_w.long().item()].item(), dist_h[:, -1:, idx_h.long().item()].item()] 71 | probs.append(torch.sum(torch.log(torch.tensor(prob_t)))) 72 | return probs 73 | --------------------------------------------------------------------------------
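As a quick reference for how the pieces above fit together, the following is a minimal sketch of the delta-quantization step that `train_ar.py` relies on to build its `nn.NLLLoss` targets via `quantize_cluster` from `utils/utils_ar.py`. It is a CPU-only illustration: the synthetic `centroid` tensor, the toy `num_cluster = 8`, and the random `delta` values are stand-ins for the real `centroid/centroids_*.npy` files and ground-truth box deltas, and it assumes the snippet is run from the repository root so that `utils.utils_ar` is importable.

```python
# Minimal sketch of the quantization step used in train_ar.py (illustrative values only).
import torch

from utils.utils_ar import quantize_cluster

num_cluster = 8                      # the repository uses 1024 clusters
B, T = 2, 5                          # 2 tracklets, 5 time steps

# One sorted centroid tensor per box component, shape (1, num_cluster),
# standing in for what load_clusters() returns from centroid/centroids_*.npy.
centroid = torch.linspace(-0.05, 0.05, num_cluster).view(1, -1)
centroid_x = centroid_y = centroid_w = centroid_h = centroid

# Fake normalized box deltas, split into per-component tensors of shape (B, T, 1).
delta = 0.02 * torch.randn(B, T, 4)
delta_x, delta_y = delta[:, :, 0:1], delta[:, :, 1:2]
delta_w, delta_h = delta[:, :, 2:3], delta[:, :, 3:4]

idx_x, idx_y, idx_w, idx_h = quantize_cluster(
    delta_x, delta_y, delta_w, delta_h,
    centroid_x, centroid_y, centroid_w, centroid_h, num_cluster)

# Each index tensor has shape (B, T): the nearest-centroid id per time step,
# which train_ar.py feeds to nn.NLLLoss as the target for dist_x/y/w/h.
print(idx_x.shape, idx_x.dtype)
```

In the actual training loop, the centroids come from `load_clusters()` (sorted, shape `(1, 1024)`, moved to the GPU) and the deltas are the differences of consecutive ground-truth boxes within a tracklet.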