├── utils
│   ├── bremm.png
│   ├── saveload.py
│   ├── basic.py
│   ├── data.py
│   ├── samp.py
│   ├── loss.py
│   ├── misc.py
│   └── py.py
├── .gitignore
├── demo_video
│   └── download_video.sh
├── download_reference_model.sh
├── requirements.txt
├── LICENSE
├── datasets
│   ├── davisdataset.py
│   ├── rgbstackingdataset.py
│   ├── kineticsdataset.py
│   ├── robotapdataset.py
│   ├── egopointsdataset.py
│   ├── crohddataset.py
│   ├── badjadataset.py
│   ├── horsedataset.py
│   ├── drivetrackdataset.py
│   ├── kubric_movif_dataset.py
│   ├── exportdataset.py
│   ├── pointdataset.py
│   └── dynrep_dataset.py
├── README.md
├── demo.py
├── test_dense_on_sparse.py
└── train_stage1.py
/utils/bremm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aharley/alltracker/HEAD/utils/bremm.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | *#*
3 | *pyc
4 | #.*
5 | .DS_Store
6 | *.out
7 | *.gif
8 | stock_videos/
9 | logs*
10 | *.mp4
11 | temp_*
12 |
--------------------------------------------------------------------------------
/demo_video/download_video.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | FILE="monkey.mp4"
3 | echo "downloading ${FILE} from dropbox"
4 | wget --max-redirect=20 -O "${FILE}" "https://www.dropbox.com/scl/fi/fm2m3ylhzmqae05bzwm8q/monkey.mp4?rlkey=ibf81gaqpxkh334rccu7zrioe&st=mli9bqb6&dl=1"
5 |
--------------------------------------------------------------------------------
/download_reference_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | FILE="alltracker_reference.tar.gz"
3 | echo "downloading ${FILE} from dropbox"
4 | wget --max-redirect=20 -O "${FILE}" "https://www.dropbox.com/scl/fi/ng66ceortfy07bgie3r54/alltracker_reference.tar.gz?rlkey=o781im2v0sl7035hy8fcuv1d5&st=u5mcttcx&dl=1"
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchaudio
3 | torchvision
4 | pytorch_lightning==2.4.0
5 | opencv-python==4.10.0.84
6 | einops==0.8.0
7 | moviepy==1.0.3
8 | h5py==3.12.1
9 | matplotlib==3.9.2
10 | scikit-learn==1.5.2
11 | scikit-image==0.24.0
12 | tensorboardX==2.6.2.2
13 | prettytable==3.12.0
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Adam W. Harley
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/datasets/davisdataset.py:
--------------------------------------------------------------------------------
1 | from numpy import random
2 | import torch
3 | import numpy as np
4 | import pickle
5 | from datasets.pointdataset import PointDataset
6 | import utils.data
7 | import cv2
8 |
9 | class DavisDataset(PointDataset):
10 | def __init__(
11 | self,
12 | data_root='../datasets/tapvid_davis',
13 | crop_size=(384,512),
14 | seq_len=None,
15 | only_first=False,
16 | ):
17 | super(DavisDataset, self).__init__(
18 | data_root=data_root,
19 | crop_size=crop_size,
20 | seq_len=seq_len,
21 | )
22 |
23 | print('loading TAPVID-DAVIS dataset...')
24 |
25 | self.dname = 'davis'
26 | self.only_first = only_first
27 |
28 | input_path = '%s/tapvid_davis.pkl' % data_root
29 | with open(input_path, 'rb') as f:
30 | data = pickle.load(f)
31 | if isinstance(data, dict):
32 | data = list(data.values())
33 | self.data = data
34 | print('found %d videos in %s' % (len(self.data), data_root))
35 |
36 | def __getitem__(self, index):
37 | dat = self.data[index]
38 | rgbs = dat['video'] # list of H,W,C uint8 images
39 | trajs = dat['points'] # N,S,2 array
40 | visibs = 1-dat['occluded'] # N,S array
41 | # note the annotations are only valid when not occluded
42 |
43 | trajs = trajs.transpose(1,0,2) # S,N,2
44 | visibs = visibs.transpose(1,0) # S,N
45 | valids = visibs.copy()
46 |
47 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
48 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
49 |
50 | rgbs = [cv2.resize(rgb, (self.crop_size[1], self.crop_size[0]), interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
51 | # in this data, 1.0,1.0 should lie at the bottom-right corner pixel
52 | H, W = rgbs[0].shape[:2]
53 | trajs[:,:,0] *= W
54 | trajs[:,:,1] *= H
55 |
56 | rgbs = torch.from_numpy(np.stack(rgbs,0)).permute(0,3,1,2).contiguous().float() # S,C,H,W
57 | trajs = torch.from_numpy(trajs).float() # S,N,2
58 | valids = torch.from_numpy(valids).float() # S,N
59 | visibs = torch.from_numpy(visibs).float() # S,N
60 |
61 | sample = utils.data.VideoData(
62 | video=rgbs,
63 | trajs=trajs,
64 | visibs=visibs,
65 | valids=valids,
66 | dname=self.dname,
67 | )
68 | return sample, True
69 |
70 | def __len__(self):
71 | return len(self.data)
72 |
73 |
74 |
--------------------------------------------------------------------------------
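A minimal usage sketch for the loader above; it is illustrative only, and assumes `tapvid_davis.pkl` has been downloaded to the default `data_root`. The shapes follow the comments in `__getitem__`.

```python
from datasets.davisdataset import DavisDataset

# data_root is an assumption; point it at your TAP-Vid DAVIS pickle.
dataset = DavisDataset(data_root='../datasets/tapvid_davis', crop_size=(384, 512))
sample, ok = dataset[0]
print(sample.video.shape)   # S,C,H,W float frames, resized to crop_size
print(sample.trajs.shape)   # S,N,2 tracks, in pixel coords of the resized frames
print(sample.visibs.shape)  # S,N visibility flags
```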
/datasets/rgbstackingdataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pickle
4 | from datasets.pointdataset import PointDataset
5 | import utils.data
6 | import cv2
7 |
8 | class RGBStackingDataset(PointDataset):
9 | def __init__(
10 | self,
11 | data_root='../datasets/tapvid_rgbstacking',
12 | crop_size=(384,512),
13 | seq_len=None,
14 | only_first=False,
15 | ):
16 | super(RGBStackingDataset, self).__init__(
17 | data_root=data_root,
18 | crop_size=crop_size,
19 | seq_len=seq_len,
20 | )
21 |
22 | print('loading TAPVID-RGB-Stacking dataset...')
23 |
24 | self.dname = 'rgbstacking'
25 | self.only_first = only_first
26 |
27 | input_path = '%s/tapvid_rgb_stacking.pkl' % data_root
28 | with open(input_path, 'rb') as f:
29 | data = pickle.load(f)
30 | if isinstance(data, dict):
31 | data = list(data.values())
32 | self.data = data
33 | print('found %d videos in %s' % (len(self.data), data_root))
34 |
35 | def __getitem__(self, index):
36 | dat = self.data[index]
37 | rgbs = dat['video'] # list of H,W,C uint8 images
38 | trajs = dat['points'] # N,S,2 array
39 | visibs = 1-dat['occluded'] # N,S array
40 | # note the annotations are only valid when visib
41 | valids = visibs.copy()
42 |
43 | trajs = trajs.transpose(1,0,2) # S,N,2
44 | visibs = visibs.transpose(1,0) # S,N
45 | valids = valids.transpose(1,0) # S,N
46 |
47 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
48 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
49 |
50 | rgbs = [cv2.resize(rgb, (self.crop_size[1], self.crop_size[0]), interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
51 | # 1.0,1.0 should lie at the bottom-right corner pixel
52 | H, W = rgbs[0].shape[:2]
53 | trajs[:,:,0] *= W-1
54 | trajs[:,:,1] *= H-1
55 |
56 | rgbs = torch.from_numpy(np.stack(rgbs,0)).permute(0,3,1,2).contiguous().float() # S,C,H,W
57 | trajs = torch.from_numpy(trajs).float() # S,N,2
58 | visibs = torch.from_numpy(visibs).float() # S,N
59 | valids = torch.from_numpy(valids).float() # S,N
60 |
61 | sample = utils.data.VideoData(
62 | video=rgbs,
63 | trajs=trajs,
64 | visibs=visibs,
65 | valids=valids,
66 | dname=self.dname,
67 | )
68 | return sample, True
69 |
70 | def __len__(self):
71 | return len(self.data)
72 |
73 |
74 |
--------------------------------------------------------------------------------
/datasets/kineticsdataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pickle
4 | from datasets.pointdataset import PointDataset
5 | import utils.data
6 | import cv2
7 | from pathlib import Path
8 | import io
9 | from PIL import Image
10 |
11 | def decode(frame):
12 | byteio = io.BytesIO(frame)
13 | img = Image.open(byteio)
14 | return np.array(img)
15 |
16 | class KineticsDataset(PointDataset):
17 | def __init__(
18 | self,
19 | data_root='../datasets/tapvid_kinetics',
20 | crop_size=(384,512),
21 | seq_len=None,
22 | only_first=False,
23 | ):
24 | super(KineticsDataset, self).__init__(
25 | data_root=data_root,
26 | crop_size=crop_size,
27 | seq_len=seq_len,
28 | )
29 |
30 | print('loading TAPVID-Kinetics dataset...')
31 |
32 | self.dname = 'kinetics'
33 | self.only_first = only_first
34 |
35 | self.data = []
36 | for vid_pkl in sorted(list(Path(data_root).glob('*.pkl')))[:]:
37 | vid_pkl = vid_pkl.name
38 | print(vid_pkl)
39 | input_path = "%s/%s" % (data_root, vid_pkl)
40 | with open(input_path, "rb") as f:
41 | data = pickle.load(f)
42 | self.data += data
43 | print("found %d videos in %s" % (len(self.data), data_root))
44 |
45 | def __getitem__(self, index):
46 | dat = self.data[index]
47 | rgbs = dat['video'] # list of H,W,C uint8 images
48 | if isinstance(rgbs[0], bytes): # decode if needed
49 | rgbs = [decode(frame) for frame in rgbs]
50 | trajs = dat['points'] # N,S,2 array
51 | visibs = 1-dat['occluded'] # N,S array
52 | # note the annotations are only valid when visib
53 |
54 | trajs = trajs.transpose(1,0,2) # S,N,2
55 | visibs = visibs.transpose(1,0) # S,N
56 | valids = visibs.copy()
57 |
58 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
59 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
60 |
61 | rgbs = [cv2.resize(rgb, (self.crop_size[1], self.crop_size[0]), interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
62 | H, W = rgbs[0].shape[:2]
63 | trajs[:,:,0] *= W-1
64 | trajs[:,:,1] *= H-1
65 |
66 | rgbs = torch.from_numpy(np.stack(rgbs,0)).permute(0,3,1,2).contiguous().float() # S,C,H,W
67 | trajs = torch.from_numpy(trajs).float() # S,N,2
68 | visibs = torch.from_numpy(visibs).float() # S,N
69 | valids = torch.from_numpy(valids).float() # S,N
70 |
71 | sample = utils.data.VideoData(
72 | video=rgbs,
73 | trajs=trajs,
74 | visibs=visibs,
75 | valids=valids,
76 | dname=self.dname,
77 | )
78 | return sample, True
79 |
80 | def __len__(self):
81 | return len(self.data)
82 |
83 |
84 |
--------------------------------------------------------------------------------
/datasets/robotapdataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pickle
4 | from datasets.pointdataset import PointDataset
5 | import utils.data
6 | import cv2
7 |
8 | class RobotapDataset(PointDataset):
9 | def __init__(
10 | self,
11 | data_root='../datasets/robotap',
12 | crop_size=(384,512),
13 | seq_len=None,
14 | only_first=False,
15 | ):
16 | super(RobotapDataset, self).__init__(
17 | data_root=data_root,
18 | crop_size=crop_size,
19 | seq_len=seq_len,
20 | )
21 |
22 | self.dname = 'robo'
23 | self.only_first = only_first
24 |
25 | # self.train_pkls = ['robotap_split0.pkl', 'robotap_split1.pkl', 'robotap_split2.pkl']
26 | self.val_pkls = ['robotap_split3.pkl', 'robotap_split4.pkl']
27 |
28 | print("loading robotap dataset...")
29 | # self.vid_pkls = self.train_pkls if is_training else self.val_pkls
30 | self.data = []
31 | for vid_pkl in self.val_pkls:
32 | print(vid_pkl)
33 | input_path = "%s/%s" % (data_root, vid_pkl)
34 | with open(input_path, "rb") as f:
35 | data = pickle.load(f)
36 | keys = list(data.keys())
37 | self.data += [data[key] for key in keys]
38 | print("found %d videos in %s" % (len(self.data), data_root))
39 |
40 | def __len__(self):
41 | return len(self.data)
42 |
43 | def getitem_helper(self, index):
44 | dat = self.data[index]
45 | rgbs = dat["video"] # list of H,W,C uint8 images
46 | trajs = dat["points"] # N,S,2 array
47 | visibs = 1 - dat["occluded"] # N,S array
48 |
49 | # note the annotations are only valid when not occluded
50 | trajs = trajs.transpose(1,0,2) # S,N,2
51 | visibs = visibs.transpose(1,0) # S,N
52 | valids = visibs.copy()
53 |
54 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
55 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
56 |
57 | rgbs = [cv2.resize(rgb, (self.crop_size[1], self.crop_size[0]), interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
58 | # 1.0,1.0 should lie at the bottom-right corner pixel
59 | H, W = rgbs[0].shape[:2]
60 | trajs[:,:,0] *= W-1
61 | trajs[:,:,1] *= H-1
62 |
63 | rgbs = torch.from_numpy(np.stack(rgbs,0)).permute(0,3,1,2).contiguous().float() # S,C,H,W
64 | trajs = torch.from_numpy(trajs).float() # S,N,2
65 | visibs = torch.from_numpy(visibs).float() # S,N
66 | valids = torch.from_numpy(valids).float() # S,N
67 |
68 | if self.seq_len is not None:
69 | rgbs = rgbs[:self.seq_len]
70 | trajs = trajs[:self.seq_len]
71 | valids = valids[:self.seq_len]
72 | visibs = visibs[:self.seq_len]
73 |
74 | sample = utils.data.VideoData(
75 | video=rgbs,
76 | trajs=trajs,
77 | visibs=visibs,
78 | valids=valids,
79 | dname=self.dname,
80 | )
81 | return sample, True
82 |
--------------------------------------------------------------------------------
/utils/saveload.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import os
3 | import torch
4 |
5 | def save(ckpt_dir, module, optimizer, scheduler, global_step, keep_latest=2, model_name='model'):
6 | pathlib.Path(ckpt_dir).mkdir(exist_ok=True, parents=True)
7 | prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*pth' % model_name))
8 | prev_ckpts.sort(key=lambda p: p.stat().st_mtime,reverse=True)
9 | if len(prev_ckpts) > keep_latest-1:
10 | for f in prev_ckpts[keep_latest-1:]:
11 | f.unlink()
12 | save_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step)
13 | save_dict = {
14 | "model": module.state_dict(),
15 | "optimizer": optimizer.state_dict(),
16 | "global_step": global_step,
17 | }
18 | if scheduler is not None:
19 | save_dict['scheduler'] = scheduler.state_dict()
20 | print(f"saving {save_path}")
21 | torch.save(save_dict, save_path)
22 | return False
23 |
24 | def load(fabric, ckpt_path, model, optimizer=None, scheduler=None, model_ema=None, step=0, model_name='model', ignore_load=None, strict=True, verbose=True, weights_only=False):
25 | if verbose:
26 | print('reading ckpt from %s' % ckpt_path)
27 | if not os.path.exists(ckpt_path):
28 | print('...there is no full checkpoint in %s' % ckpt_path)
29 | print('-- note this function no longer appends "saved_checkpoints/" before the ckpt_path --')
30 | assert(False)
31 | else:
32 | if os.path.isfile(ckpt_path):
33 | path = ckpt_path
34 | print('...found checkpoint %s' % (path))
35 | else:
36 | prev_ckpts = list(pathlib.Path(ckpt_path).glob('%s-*pth' % model_name))
37 | prev_ckpts.sort(key=lambda p: p.stat().st_mtime,reverse=True)
38 | if len(prev_ckpts):
39 | path = prev_ckpts[0]
40 | # e.g., './checkpoints/2Ai4_5e-4_base18_1539/model-000050000.pth'
41 | # OR ./whatever.pth
42 | step = int(str(path).split('-')[-1].split('.')[0])
43 | if verbose:
44 | print('...found checkpoint %s; (parsed step %d from path)' % (path, step))
45 | else:
46 | print('...there is no full checkpoint here!')
47 | return 0
48 | if fabric is not None:
49 | checkpoint = fabric.load(path)
50 | else:
51 | checkpoint = torch.load(path, weights_only=weights_only)
52 | if optimizer is not None:
53 | optimizer.load_state_dict(checkpoint['optimizer'])
54 | if scheduler is not None:
55 | scheduler.load_state_dict(checkpoint['scheduler'])
56 | assert ignore_load is None # not ready yet
57 | if 'model' in checkpoint:
58 | state_dict = checkpoint['model']
59 | else:
60 | state_dict = checkpoint
61 | model.load_state_dict(state_dict, strict=strict)
62 | return step
63 |
64 |
65 |
66 |
--------------------------------------------------------------------------------
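A minimal usage sketch for the checkpoint helpers above, run without Lightning Fabric (`fabric=None`); the directory name is illustrative.

```python
import torch
from utils import saveload

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# writes ./checkpoints/example/model-000001000.pth, pruning older checkpoints beyond keep_latest
saveload.save('./checkpoints/example', model, optimizer, scheduler=None,
              global_step=1000, keep_latest=2, model_name='model')

# finds the newest model-*.pth in the folder, restores model/optimizer state, and returns the parsed step
step = saveload.load(None, './checkpoints/example', model, optimizer=optimizer)
print(step)  # 1000
```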
/datasets/egopointsdataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from datasets.pointdataset import PointDataset
4 | import utils.data
5 | import cv2
6 | from pathlib import Path
7 |
8 | class EgoPointsDataset(PointDataset):
9 | def __init__(
10 | self,
11 | data_root='../datasets/ego_points',
12 | crop_size=(384,512),
13 | seq_len=None,
14 | only_first=False,
15 | ):
16 | super(EgoPointsDataset, self).__init__(
17 | data_root=data_root,
18 | crop_size=crop_size,
19 | seq_len=seq_len,
20 | )
21 |
22 | print('loading egopoints dataset...')
23 |
24 | self.dname = 'egopoints'
25 | self.only_first = only_first
26 |
27 | self.data = []
28 | for subfolder in Path(data_root).iterdir():
29 | if subfolder.is_dir():
30 | annot_fn = subfolder / 'annot.npz'
31 | if not annot_fn.exists():
32 | continue
33 | data = np.load(annot_fn)
34 | trajs_2d, valids, visibs, vis_valids = data['trajs_2d'], data['valids'], data['visibs'], data['vis_valids']
35 |
36 | self.data.append({
37 | 'rgb_paths': sorted(subfolder.glob('rgbs/*.jpg')),
38 | 'trajs_2d': trajs_2d,
39 | 'valids': valids,
40 | 'visibs': visibs,
41 | 'vis_valids': vis_valids,
42 | })
43 |
44 | print('found %d videos in %s' % (len(self.data), data_root))
45 |
46 | def __getitem__(self, index):
47 | dat = self.data[index]
48 | rgb_paths = dat['rgb_paths']
49 | trajs = dat['trajs_2d'] # S,N,2
50 | valids = dat['valids'] # S,N
51 | visibs = valids.copy() # we don't use this
52 |
53 | rgbs = [cv2.imread(str(rgb_path))[..., ::-1] for rgb_path in rgb_paths]
54 |
55 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
56 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
57 |
58 | # resize
59 | rgbs = [cv2.resize(rgb, (self.crop_size[1], self.crop_size[0]), interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
60 | rgb0_raw = cv2.imread(str(rgb_paths[0]))[..., ::-1]
61 | trajs = trajs / (np.array([rgb0_raw.shape[1], rgb0_raw.shape[0]]) - 1)
62 | trajs = np.maximum(np.minimum(trajs, 1.), 0.)
63 | # 1.0,1.0 should map to the bottom-right corner pixel
64 | H, W = rgbs[0].shape[:2]
65 | trajs[:,:,0] *= W-1
66 | trajs[:,:,1] *= H-1
67 |
68 | rgbs = torch.from_numpy(np.stack(rgbs,0)).permute(0,3,1,2).contiguous().float() # S,C,H,W
69 | trajs = torch.from_numpy(trajs).float() # S,N,2
70 | valids = torch.from_numpy(valids).float() # S,N
71 | visibs = torch.from_numpy(visibs).float()
72 |
73 | if self.seq_len is not None:
74 | rgbs = rgbs[:self.seq_len]
75 | trajs = trajs[:self.seq_len]
76 | valids = valids[:self.seq_len]
77 | visibs = visibs[:self.seq_len]
78 |
79 | # req at least one timestep valid (after cutting)
80 | val_ok = torch.sum(valids, axis=0) > 0
81 | trajs = trajs[:,val_ok]
82 | valids = valids[:,val_ok]
83 | visibs = visibs[:,val_ok]
84 |
85 | sample = utils.data.VideoData(
86 | video=rgbs,
87 | trajs=trajs,
88 | valids=valids,
89 | visibs=visibs,
90 | dname=self.dname,
91 | )
92 | return sample, True
93 |
94 | def __len__(self):
95 | return len(self.data)
96 |
97 |
98 |
--------------------------------------------------------------------------------
/datasets/crohddataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from PIL import Image
5 | import cv2
6 | from datasets.pointdataset import PointDataset
7 | import utils.data
8 |
9 | class CrohdDataset(PointDataset):
10 | def __init__(
11 | self,
12 | data_root='../datasets/crohd',
13 | crop_size=(384, 512),
14 | seq_len=None,
15 | only_first=False,
16 | ):
17 | super(CrohdDataset, self).__init__(
18 | data_root=data_root,
19 | crop_size=crop_size,
20 | seq_len=seq_len,
21 | )
22 |
23 | self.dname = 'crohd'
24 | self.seq_len = seq_len
25 | self.only_first = only_first
26 |
27 | dataset_dir = "%s/HT21/train" % self.data_root
28 | label_location = "%s/HT21Labels/train" % self.data_root
29 | subfolders = ["HT21-01", "HT21-02", "HT21-03", "HT21-04"]
30 |
31 | print("loading data from {0}".format(dataset_dir))
32 | self.dataset_dir = dataset_dir
33 | self.subfolders = subfolders
34 | print("found %d samples" % len(self.subfolders))
35 |
36 | def getitem_helper(self, index):
37 | subfolder = self.subfolders[index]
38 |
39 | label_path = os.path.join(self.dataset_dir, subfolder, "gt/gt.txt")
40 | labels = np.loadtxt(label_path, delimiter=",")
41 |
42 | n_frames = int(labels[-1, 0])
43 | n_heads = int(labels[:, 1].max())
44 |
45 | bboxes = np.zeros((n_frames, n_heads, 4))
46 | visibs = np.zeros((n_frames, n_heads))
47 |
48 | for i in range(labels.shape[0]):
49 | (
50 | frame_id,
51 | head_id,
52 | bb_left,
53 | bb_top,
54 | bb_width,
55 | bb_height,
56 | conf,
57 | cid,
58 | vis,
59 | ) = labels[i]
60 | frame_id = int(frame_id) - 1 # convert 1-indexing to 0-indexing
61 | head_id = int(head_id) - 1 # convert 1-indexing to 0-indexing
62 |
63 | visibs[frame_id, head_id] = vis
64 | box_cur = np.array(
65 | [bb_left, bb_top, bb_left + bb_width, bb_top + bb_height]
66 | ) # convert xywh to x1, y1, x2, y2
67 | bboxes[frame_id, head_id] = box_cur
68 |
69 | prescale = 0.75 # to save memory
70 |
71 | # take the center of each head box as a coordinate
72 | trajs = np.stack([bboxes[:, :, [0, 2]].mean(2), bboxes[:, :, [1, 3]].mean(2)], axis=2) # S,N,2
73 | trajs = trajs * prescale
74 | valids = visibs.copy()
75 |
76 | S, N = valids.shape
77 |
78 | rgbs = []
79 | for ii in range(S):
80 | rgb_path = os.path.join(self.dataset_dir, subfolder, "img1", str(ii + 1).zfill(6) + ".jpg")
81 | rgb = Image.open(rgb_path) # 1920x1080
82 | rgb = rgb.resize((int(rgb.size[0] * prescale), int(rgb.size[1] * prescale)), Image.BILINEAR) # save memory by downsampling here
83 | rgbs.append(rgb)
84 | rgbs = np.stack(rgbs) # S,H,W,3
85 |
86 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
87 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
88 |
89 | H, W = rgbs[0].shape[:2]
90 | S, N = trajs.shape[:2]
91 | sx = W / self.crop_size[1]
92 | sy = H / self.crop_size[0]
93 | trajs[:,:,0] /= sx
94 | trajs[:,:,1] /= sy
95 | rgbs = [cv2.resize(rgb, (self.crop_size[1], self.crop_size[0]), interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
96 | rgbs = np.stack(rgbs)
97 | H,W = rgbs[0].shape[:2]
98 |
99 | rgbs = torch.from_numpy(rgbs).permute(0, 3, 1, 2).float()
100 | trajs = torch.from_numpy(trajs).float()
101 | visibs = torch.from_numpy(visibs).float()
102 | valids = torch.from_numpy(valids).float()
103 |
104 | sample = utils.data.VideoData(
105 | video=rgbs,
106 | trajs=trajs,
107 | visibs=visibs,
108 | valids=valids,
109 | dname=self.dname,
110 | )
111 | return sample, True
112 |
113 | def __len__(self):
114 | return len(self.subfolders)
115 |
116 |
117 |
--------------------------------------------------------------------------------
/datasets/badjadataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from datasets.pointdataset import PointDataset
4 | import utils.data
5 | import cv2
6 | import glob
7 | import imageio
8 | import pandas as pd
9 |
10 | # this is badja with annotations hand-cleaned by adam in 2022
11 |
12 | class BadjaDataset(PointDataset):
13 | def __init__(
14 | self,
15 | data_root='../datasets/badja',
16 | crop_size=(384,512),
17 | seq_len=None,
18 | only_first=False,
19 | ):
20 | super(BadjaDataset, self).__init__(
21 | data_root=data_root,
22 | crop_size=crop_size,
23 | seq_len=seq_len,
24 | )
25 |
26 | self.dname = 'badja'
27 | self.seq_len = seq_len
28 | self.only_first = only_first
29 |
30 | npzs = glob.glob('%s/complete_aa/*.npz' % self.data_root)
31 | npzs = sorted(npzs)
32 | df = pd.read_csv('%s/picks_and_coords.txt' % self.data_root, sep=' ', header=None)
33 | track_names = df[0].tolist()
34 | pick_frames = np.array(df[1])
35 |
36 | self.animal_names = []
37 | self.animal_trajs = []
38 | self.animal_visibs = []
39 | self.animal_valids = []
40 | self.animal_picks = []
41 |
42 | for ind in range(len(npzs)):
43 | o = np.load(npzs[ind])
44 |
45 | animal_name = o['animal_name']
46 | trajs = o['trajs_g']
47 | valids = o['valids_g']
48 |
49 | S, N, D = trajs.shape
50 |
51 | assert(D==2)
52 |
53 | N = trajs.shape[1]
54 |
55 | # hand-picked frame where it's fair to start tracking this kp
56 | pick_g = np.zeros((N), dtype=np.int32)
57 |
58 | for n in range(N):
59 | short_name = '%s_%02d' % (animal_name, n)
60 | txt_id = track_names.index(short_name)
61 | pick_id = pick_frames[txt_id]
62 | pick_g[n] = pick_id
63 |
64 | # discard annotations before the pick
65 | valids[:pick_id,n] = 0
66 | valids[pick_id,n] = 2
67 |
68 | self.animal_names.append(animal_name)
69 | self.animal_trajs.append(trajs)
70 | self.animal_valids.append(valids)
71 |
72 | def __getitem__(self, index):
73 | animal_name = self.animal_names[index]
74 | trajs = self.animal_trajs[index].copy()
75 | valids = self.animal_valids[index]
76 | valids = (valids==2) * 1.0
77 | visibs = valids.copy()
78 |
79 | S,N,D = trajs.shape
80 |
81 | filenames = glob.glob('%s/videos/%s/*.png' % (self.data_root, animal_name)) + glob.glob('%s/videos/%s/*.jpg' % (self.data_root, animal_name))
82 | filenames = sorted(filenames)
83 | S = len(filenames)
84 | filenames_short = [fn.split('/')[-1] for fn in filenames]
85 |
86 | rgbs = []
87 | for s in range(S):
88 | filename_actual = filenames[s]
89 | rgb = imageio.imread(filename_actual)
90 | rgbs.append(rgb)
91 |
92 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
93 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
94 |
95 | S = len(rgbs)
96 | H, W, C = rgbs[0].shape
97 | N = trajs.shape[1]
98 | rgbs = [cv2.resize(rgb, (self.crop_size[1], self.crop_size[0]), interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
99 | sx = W / self.crop_size[1]
100 | sy = H / self.crop_size[0]
101 | trajs[:,:,0] /= sx
102 | trajs[:,:,1] /= sy
103 | rgbs = np.stack(rgbs, 0)
104 | H, W, C = rgbs[0].shape
105 |
106 | rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0,3,1,2).float()
107 | trajs = torch.from_numpy(trajs).reshape(S, N, 2).float()
108 | visibs = torch.from_numpy(visibs).reshape(S, N).float()
109 | valids = torch.from_numpy(valids).reshape(S, N).float()
110 |
111 | sample = utils.data.VideoData(
112 | video=rgbs,
113 | trajs=trajs,
114 | visibs=visibs,
115 | valids=valids,
116 | dname=self.dname,
117 | )
118 | return sample, True
119 |
120 | def __len__(self):
121 | return len(self.animal_names)
122 |
123 |
--------------------------------------------------------------------------------
/utils/basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | EPS = 1e-6
5 |
6 | def sub2ind(height, width, y, x):
7 | return y*width + x
8 |
9 | def ind2sub(height, width, ind):
10 | y = ind // width
11 | x = ind % width
12 | return y, x
13 |
14 | def get_lr_str(lr):
15 | lrn = "%.1e" % lr # e.g., 5.0e-04
16 | lrn = lrn[0] + lrn[3:5] + lrn[-1] # e.g., 5e-4
17 | return lrn
18 |
19 | def strnum(x):
20 | s = '%g' % x
21 | if '.' in s:
22 | if x < 1.0:
23 | s = s[s.index('.'):]
24 | s = s[:min(len(s),4)]
25 | return s
26 |
27 | def assert_same_shape(t1, t2):
28 | for (x, y) in zip(list(t1.shape), list(t2.shape)):
29 | assert(x==y)
30 |
31 | def mkdir(path):
32 | if not os.path.exists(path):
33 | os.makedirs(path)
34 |
35 | def print_stats(name, tensor):
36 | shape = tensor.shape
37 | tensor = tensor.detach().cpu().numpy()
38 | print('%s (%s) min = %.2f, mean = %.2f, max = %.2f' % (name, tensor.dtype, np.min(tensor), np.mean(tensor), np.max(tensor)), shape)
39 |
40 | def normalize_single(d):
41 | # d is a whatever shape torch tensor
42 | dmin = torch.min(d)
43 | dmax = torch.max(d)
44 | d = (d-dmin)/(EPS+(dmax-dmin))
45 | return d
46 |
47 | def normalize(d):
48 | # d is B x whatever. normalize within each element of the batch
49 | out = torch.zeros(d.size(), dtype=d.dtype, device=d.device)
50 | B = list(d.size())[0]
51 | for b in list(range(B)):
52 | out[b] = normalize_single(d[b])
53 | return out
54 |
55 | def meshgrid2d(B, Y, X, stack=False, norm=False, device='cuda', on_chans=False):
56 | # returns a meshgrid sized B x Y x X
57 |
58 | grid_y = torch.linspace(0.0, Y-1, Y, device=torch.device(device))
59 | grid_y = torch.reshape(grid_y, [1, Y, 1])
60 | grid_y = grid_y.repeat(B, 1, X)
61 |
62 | grid_x = torch.linspace(0.0, X-1, X, device=torch.device(device))
63 | grid_x = torch.reshape(grid_x, [1, 1, X])
64 | grid_x = grid_x.repeat(B, Y, 1)
65 |
66 | if norm:
67 | grid_y, grid_x = normalize_grid2d(
68 | grid_y, grid_x, Y, X)
69 |
70 | if stack:
71 | # note we stack in xy order
72 | # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
73 | if on_chans:
74 | grid = torch.stack([grid_x, grid_y], dim=1)
75 | else:
76 | grid = torch.stack([grid_x, grid_y], dim=-1)
77 | return grid
78 | else:
79 | return grid_y, grid_x
80 |
81 | def gridcloud2d(B, Y, X, norm=False, device='cuda'):
82 | # we want to sample for each location in the grid
83 | grid_y, grid_x = meshgrid2d(B, Y, X, norm=norm, device=device)
84 | x = torch.reshape(grid_x, [B, -1])
85 | y = torch.reshape(grid_y, [B, -1])
86 | # these are B x N
87 | xy = torch.stack([x, y], dim=2)
88 | # this is B x N x 2
89 | return xy
90 |
91 | def reduce_masked_mean(x, mask, dim=None, keepdim=False, broadcast=False):
 92 |     # x and mask are the same shape, or at least broadcastable (though it is safer to disallow broadcasting)
93 | # returns shape-1
94 | # axis can be a list of axes
95 | if not broadcast:
96 | for (a,b) in zip(x.size(), mask.size()):
97 | if not a==b:
98 | print('some shape mismatch:', x.shape, mask.shape)
99 | assert(a==b) # some shape mismatch!
100 | # assert(x.size() == mask.size())
101 | prod = x*mask
102 | if dim is None:
103 | numer = torch.sum(prod)
104 | denom = EPS+torch.sum(mask)
105 | else:
106 | numer = torch.sum(prod, dim=dim, keepdim=keepdim)
107 | denom = EPS+torch.sum(mask, dim=dim, keepdim=keepdim)
108 | mean = numer/denom
109 | return mean
110 |
111 | def reduce_masked_median(x, mask, keep_batch=False):
112 | # x and mask are the same shape
113 | assert(x.size() == mask.size())
114 | device = x.device
115 |
116 | B = list(x.shape)[0]
117 | x = x.detach().cpu().numpy()
118 | mask = mask.detach().cpu().numpy()
119 |
120 | if keep_batch:
121 | x = np.reshape(x, [B, -1])
122 | mask = np.reshape(mask, [B, -1])
123 | meds = np.zeros([B], np.float32)
124 | for b in list(range(B)):
125 | xb = x[b]
126 | mb = mask[b]
127 | if np.sum(mb) > 0:
128 | xb = xb[mb > 0]
129 | meds[b] = np.median(xb)
130 | else:
131 | meds[b] = np.nan
132 | meds = torch.from_numpy(meds).to(device)
133 | return meds.float()
134 | else:
135 | x = np.reshape(x, [-1])
136 | mask = np.reshape(mask, [-1])
137 | if np.sum(mask) > 0:
138 | x = x[mask > 0]
139 | med = np.median(x)
140 | else:
141 | med = np.nan
142 | med = np.array([med], np.float32)
143 | med = torch.from_numpy(med).to(device)
144 | return med.float()
145 |
--------------------------------------------------------------------------------
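A small worked example of `reduce_masked_mean` from the file above, which averages `x` only where `mask` is nonzero:

```python
import torch
from utils.basic import reduce_masked_mean

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])

print(reduce_masked_mean(x, mask))         # ~1.5: only the first two entries count
print(reduce_masked_mean(x, mask, dim=1))  # per-batch masked mean, shape (1,)
```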
/datasets/horsedataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from PIL import Image
5 | import cv2
6 | from datasets.pointdataset import PointDataset
7 | import pickle
8 | import utils.data
9 |
10 | class HorseDataset(PointDataset):
11 | def __init__(
12 | self,
13 | data_root='../datasets/horse10',
14 | crop_size=(384, 512),
15 | seq_len=None,
16 | only_first=False,
17 | ):
18 | super(HorseDataset, self).__init__(
19 | data_root=data_root,
20 | crop_size=crop_size,
21 | seq_len=seq_len,
22 | )
23 | print("loading horse dataset...")
24 |
25 | self.seq_len = seq_len
26 | self.only_first = only_first
27 | self.dname = 'hor'
28 |
29 | self.dataset_location = data_root
30 | self.anno_path = os.path.join(self.dataset_location, "seq_annotation.pkl")
31 | with open(self.anno_path, "rb") as f:
32 | self.annotation = pickle.load(f)
33 |
34 | self.video_names = []
35 |
36 | for video_name in list(self.annotation.keys()):
37 | video = self.annotation[video_name]
38 |
39 | rgbs = []
40 | trajs = []
41 | visibs = []
42 | for sample in video:
43 | img_path = sample["img_path"]
44 | img_path = self.dataset_location + '/' + img_path
45 | rgb = Image.open(img_path)
46 | rgbs.append(rgb)
47 | trajs.append(np.squeeze(sample["keypoints"], 0))
48 | visibs.append(np.squeeze(sample["keypoints_visible"], 0))
49 |
50 | rgbs = np.stack(rgbs, axis=0)
51 | trajs = np.stack(trajs, axis=0)
52 | visibs = np.stack(visibs, axis=0)
53 | valids = visibs.copy()
54 |
55 | S, H, W, C = rgbs.shape
56 | _, N, D = trajs.shape
57 |
58 | for si in range(S):
59 | # avoid 2px edge, since these are not really visible (according to adam)
60 | oob_inds = np.logical_or(
61 | np.logical_or(trajs[si, :, 0] < 2, trajs[si, :, 0] >= W-2),
62 | np.logical_or(trajs[si, :, 1] < 2, trajs[si, :, 1] >= H-2),
63 | )
64 | visibs[si, oob_inds] = 0
65 |
66 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
67 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
68 |
69 | N = trajs.shape[1]
70 | if N > 0:
71 | self.video_names.append(video_name)
72 |
73 | print(f"found {len(self.annotation)} unique videos in {self.dataset_location}")
74 |
75 | def getitem_helper(self, index):
76 | video_name = self.video_names[index]
77 | video = self.annotation[video_name]
78 |
79 | rgbs = []
80 | trajs = []
81 | visibs = []
82 | for sample in video:
83 | img_path = sample["img_path"]
84 | img_path = self.dataset_location + '/' + img_path
85 | rgb = Image.open(img_path)
86 | rgbs.append(rgb)
87 | trajs.append(np.squeeze(sample["keypoints"], 0))
88 | visibs.append(np.squeeze(sample["keypoints_visible"], 0))
89 |
90 | rgbs = np.stack(rgbs, axis=0)
91 | trajs = np.stack(trajs, axis=0)
92 | visibs = np.stack(visibs, axis=0)
93 | valids = visibs.copy()
94 |
95 | S, H, W, C = rgbs.shape
96 | _, N, D = trajs.shape
97 |
98 | for si in range(S):
99 | # avoid 2px edge, since these are not really visible (according to adam)
100 | oob_inds = np.logical_or(
101 | np.logical_or(trajs[si, :, 0] < 2, trajs[si, :, 0] >= W-2),
102 | np.logical_or(trajs[si, :, 1] < 2, trajs[si, :, 1] >= H-2),
103 | )
104 | visibs[si, oob_inds] = 0
105 |
106 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
107 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
108 |
109 | H, W = rgbs[0].shape[:2]
110 | trajs[:,:,0] /= W-1
111 | trajs[:,:,1] /= H-1
112 | rgbs = [cv2.resize(rgb, (self.crop_size[1], self.crop_size[0]), interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
113 | rgbs = np.stack(rgbs)
114 | H,W = rgbs[0].shape[:2]
115 | trajs[:,:,0] *= W-1
116 | trajs[:,:,1] *= H-1
117 |
118 | rgbs = torch.from_numpy(rgbs).permute(0, 3, 1, 2).float()
119 | trajs = torch.from_numpy(trajs)
120 | visibs = torch.from_numpy(visibs)
121 | valids = torch.from_numpy(valids)
122 |
123 | sample = utils.data.VideoData(
124 | video=rgbs,
125 | trajs=trajs,
126 | visibs=visibs,
127 | valids=valids,
128 | dname=self.dname,
129 | )
130 | return sample, True
131 |
132 | def __len__(self):
133 | return len(self.video_names)
134 |
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dataclasses
3 | import torch.nn.functional as F
4 | from dataclasses import dataclass
5 | from typing import Any, Optional, Dict
6 | import utils.misc
7 | import numpy as np
8 |
9 | def replace_invalid_xys_with_nearest(xys, valids):
10 | # replace invalid xys with nearby ones
11 | invalid_idx = np.where(valids==0)[0]
12 | valid_idx = np.where(valids==1)[0]
13 | for idx in invalid_idx:
14 | nearest = valid_idx[np.argmin(np.abs(valid_idx - idx))]
15 | xys[idx] = xys[nearest]
16 | return xys
17 |
18 | def standardize_test_data(rgbs, trajs, visibs, valids, S_cap=600, only_first=False, seq_len=None):
19 | trajs = trajs.astype(np.float32) # S,N,2
20 | visibs = visibs.astype(np.float32) # S,N
21 | valids = valids.astype(np.float32) # S,N
22 |
23 | # only take tracks that make sense
24 | visval_ok = np.sum(valids*visibs, axis=0) > 1
25 | trajs = trajs[:,visval_ok]
26 | visibs = visibs[:,visval_ok]
27 | valids = valids[:,visval_ok]
28 |
29 | # fill in missing data (for visualization)
30 | N = trajs.shape[1]
31 | for ni in range(N):
32 | trajs[:,ni] = replace_invalid_xys_with_nearest(trajs[:,ni], valids[:,ni])
33 |
34 | # use S_cap or seq_len (legacy)
35 | if seq_len is not None:
36 | S = min(len(rgbs), seq_len)
37 | else:
38 | S = len(rgbs)
39 | S = min(S, S_cap)
40 |
41 | if only_first:
42 | # we'll find the best frame to start on
43 | best_count = 0
44 | best_ind = 0
45 |
46 | for si in range(0,len(rgbs)-64):
47 | # try this slice
48 | visibs_ = visibs[si:min(si+S,len(rgbs)+1)] # S,N
49 | valids_ = valids[si:min(si+S,len(rgbs)+1)] # S,N
50 | visval_ok0 = (visibs_[0]*valids_[0]) > 0 # N
51 | visval_okA = np.sum(visibs_*valids_, axis=0) > 1 # N
52 | all_ok = visval_ok0 & visval_okA
53 | # print('- slicing %d to %d; sum(ok) %d' % (si, min(si+S,len(rgbs)+1), np.sum(all_ok)))
54 | if np.sum(all_ok) > best_count:
55 | best_count = np.sum(all_ok)
56 | best_ind = si
57 | si = best_ind
58 | rgbs = rgbs[si:si+S]
59 | trajs = trajs[si:si+S]
60 | visibs = visibs[si:si+S]
61 | valids = valids[si:si+S]
62 | vis_ok0 = visibs[0] > 0 # N
63 | trajs = trajs[:,vis_ok0]
64 | visibs = visibs[:,vis_ok0]
65 | valids = valids[:,vis_ok0]
66 | # print('- best_count', best_count, 'best_ind', best_ind)
67 |
68 | if seq_len is not None:
69 | rgbs = rgbs[:seq_len]
70 | trajs = trajs[:seq_len]
71 | valids = valids[:seq_len]
72 |
73 | # req two timesteps valid (after seqlen trim)
74 | visval_ok = np.sum(visibs*valids, axis=0) > 1
75 | trajs = trajs[:,visval_ok]
76 | valids = valids[:,visval_ok]
77 | visibs = visibs[:,visval_ok]
78 |
79 | return rgbs, trajs, visibs, valids
80 |
81 |
82 | @dataclass(eq=False)
83 | class VideoData:
84 | """
85 | Dataclass for storing video tracks data.
86 | """
87 |
88 | video: torch.Tensor # B,S,C,H,W
89 | trajs: torch.Tensor # B,S,N,2
90 | visibs: torch.Tensor # B,S,N
91 | valids: Optional[torch.Tensor] = None # B,S,N
92 | dname: Optional[str] = None
93 |
94 |
95 | def collate_fn(batch):
96 | """
97 | Collate function for video tracks data.
98 | """
99 | video = torch.stack([b.video for b in batch], dim=0)
100 | trajs = torch.stack([b.trajs for b in batch], dim=0)
101 | visibs = torch.stack([b.visibs for b in batch], dim=0)
102 | dname = [b.dname for b in batch]
103 |
104 | return VideoData(
105 | video=video,
106 | trajs=trajs,
107 | visibs=visibs,
108 | dname=dname,
109 | )
110 |
111 |
112 | def collate_fn_train(batch):
113 | """
114 | Collate function for video tracks data during training.
115 | """
116 | gotit = [gotit for _, gotit in batch]
117 | video = torch.stack([b.video for b, _ in batch], dim=0)
118 | trajs = torch.stack([b.trajs for b, _ in batch], dim=0)
119 | visibs = torch.stack([b.visibs for b, _ in batch], dim=0)
120 | valids = torch.stack([b.valids for b, _ in batch], dim=0)
121 | dname = [b.dname for b, _ in batch]
122 |
123 | return (
124 | VideoData(
125 | video=video,
126 | trajs=trajs,
127 | visibs=visibs,
128 | valids=valids,
129 | dname=dname,
130 | ),
131 | gotit,
132 | )
133 |
134 |
135 | def try_to_cuda(t: Any) -> Any:
136 | """
137 | Try to move the input variable `t` to a cuda device.
138 |
139 | Args:
140 | t: Input.
141 |
142 | Returns:
143 | t_cuda: `t` moved to a cuda device, if supported.
144 | """
145 | try:
146 | t = t.float().cuda()
147 | except AttributeError:
148 | pass
149 | return t
150 |
151 |
152 | def dataclass_to_cuda_(obj):
153 | """
154 | Move all contents of a dataclass to cuda inplace if supported.
155 |
156 | Args:
157 | batch: Input dataclass.
158 |
159 | Returns:
160 | batch_cuda: `batch` moved to a cuda device, if supported.
161 | """
162 | for f in dataclasses.fields(obj):
163 | setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
164 | return obj
165 |
--------------------------------------------------------------------------------
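A minimal sketch of how `VideoData` samples are batched with `collate_fn_train`. The toy tensors below stand in for what the dataset loaders in `datasets/` return (a `(VideoData, gotit)` pair per sample, with matching shapes across the batch):

```python
import torch
from utils.data import VideoData, collate_fn_train, dataclass_to_cuda_

def toy_sample():
    # shapes chosen for illustration: S frames, N tracks
    S, N, C, H, W = 4, 3, 3, 384, 512
    sample = VideoData(
        video=torch.zeros(S, C, H, W),
        trajs=torch.zeros(S, N, 2),
        visibs=torch.ones(S, N),
        valids=torch.ones(S, N),
        dname='toy',
    )
    return sample, True

batch, gotit = collate_fn_train([toy_sample(), toy_sample()])
print(batch.video.shape)  # torch.Size([2, 4, 3, 384, 512]), i.e., B,S,C,H,W
print(gotit)              # [True, True]
# batch = dataclass_to_cuda_(batch)  # optionally move all tensor fields to GPU
```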
/README.md:
--------------------------------------------------------------------------------
1 | # AllTracker: Efficient Dense Point Tracking at High Resolution
2 |
3 | **[[Paper](https://arxiv.org/abs/2506.07310)] [[Project Page](https://alltracker.github.io/)] [[Gradio Demo](https://huggingface.co/spaces/aharley/alltracker)]**
4 |
5 |
6 |
7 | **AllTracker is a point tracking model that is faster and more accurate than comparable models, while also producing dense output at high resolution.**
8 |
9 | AllTracker estimates long-range point tracks by way of estimating the flow field between a query frame and every other frame of a video. Unlike existing point tracking methods, our approach delivers high-resolution and dense (all-pixel) correspondence fields, which can be visualized as flow maps. Unlike existing optical flow methods, our approach corresponds one frame to hundreds of subsequent frames, rather than just the next frame.
10 |
11 | We are actively adding to this repo, but please ping or open an issue if you notice something missing or broken. The demo (at least) should work for everyone!
12 |
13 |
14 | ## Env setup
15 |
16 | Install miniconda:
17 | ```
18 | mkdir -p ~/miniconda3
19 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
20 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
21 | rm ~/miniconda3/miniconda.sh
22 | source ~/miniconda3/bin/activate
23 | conda init
24 | ```
25 |
26 | Set up a fresh conda environment for AllTracker:
27 |
28 | ```
29 | conda create -n alltracker python=3.12.8
30 | conda activate alltracker
31 | pip install -r requirements.txt
32 | ```
33 |
34 | ## Running the demo
35 |
36 | Download the sample video:
37 | ```
38 | cd demo_video
39 | sh download_video.sh
40 | cd ..
41 | ```
42 |
43 | Run the demo:
44 | ```
45 | python demo.py --mp4_path ./demo_video/monkey.mp4
46 | ```
47 | The demo script will automatically download the model weights from [huggingface](https://huggingface.co/aharley/alltracker/tree/main) if needed.
48 |
49 | For a fancier visualization, giving a side-by-side view of the input and output, try this:
50 | ```
51 | python demo.py --mp4_path ./demo_video/monkey.mp4 --query_frame 32 --conf_thr 0.01 --bkg_opacity 0.0 --rate 2 --hstack
52 | ```
53 |
54 |
55 |
56 |
57 | ## Training
58 |
59 | AllTracker is trained in two stages: Stage 1 uses Kubric alone; Stage 2 uses a mix of datasets. This two-stage regime enables fair comparisons with models that train only on Kubric.
60 |
61 | ### Data prep
62 |
63 | Start by downloading Kubric.
64 |
65 | - 24-frame data: [kubric_au.tar.gz](https://huggingface.co/datasets/aharley/alltracker_data/resolve/main/kubric_au.tar.gz?download=true)
66 |
67 | - 64-frame data: [part1](https://huggingface.co/datasets/aharley/alltracker_data/resolve/main/ce64_kub_aa?download=true), [part2](https://huggingface.co/datasets/aharley/alltracker_data/resolve/main/ce64_kub_ab?download=true), [part3](https://huggingface.co/datasets/aharley/alltracker_data/resolve/main/ce64_kub_ac?download=true)
68 |
69 | Merge the parts by concatenating:
70 | ```
71 | cat ce64_kub_aa ce64_kub_ab ce64_kub_ac > ce64_kub.tar.gz
72 | ```
73 |
74 | The 24-frame Kubric data is a torch export of the official `kubric-public/tfds/movi_f/512x512` data.
75 |
76 | With Kubric, you can skip the other datasets and start training Stage 1.
77 |
78 | Download the rest of the point tracking datasets from [here](https://huggingface.co/datasets/aharley/alltracker_data/tree/main). There you will find 24-frame datasets, `ce24*.tar.gz`, and 64-frame datasets, `ce64*.tar.gz`. Some of the datasets are large, and they are split into parts, so you need to create the full files by concatenating.
79 |
80 | On disk, the point tracking datasets should look like this:
81 | ```
82 | data/
83 | ├── ce24/
84 | │ ├── drivingpt/
85 | │ ├── fltpt/
86 | │ ├── monkapt/
87 | │ ├── springpt/
88 | ├── ce64/
89 | │ ├── drivingpt/
90 | │ ├── kublong/
91 | │ ├── monkapt/
92 | │ ├── podlong/
93 | │ ├── springpt/
94 | ├── dynamicreplica/
95 | ├── kubric_au/
96 | ```
97 |
98 | Download the optical flow datasets from the official websites: [FlyingChairs, FlyingThings3D, Monkaa, Driving](https://lmb.informatik.uni-freiburg.de/resources/datasets), [AutoFlow](https://autoflow-google.github.io/), [SPRING](https://spring-benchmark.org/), [VIPER](https://playing-for-benchmarks.org/download/), [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/), [KITTI](https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow), [TartanAir](https://theairlab.org/tartanair-dataset/).
99 |
100 |
101 | ### Stage 1
102 |
103 | Stage 1 is to train the model for 200k steps on Kubric.
104 |
105 | ```
106 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7; python train_stage1.py --mixed_precision --lr 5e-4 --max_steps 200000 --data_dir /data --exp "stage1abc"
107 | ```
108 |
109 | This should produce a tensorboard log in `./logs_train/`, and checkpoints in `./checkpoints/`, in folder names similar to "64Ai4i3_5e-4m_stage1abc_1318". (The 4-digit string at the end is a timecode indicating when the run began, to help make the filepaths unique.)
110 |
111 | ### Stage 2
112 |
113 | Stage 2 is to train the model for 400k steps on a mix of point tracking datasets and optical flow datasets. This stage initializes from the output of Stage 1.
114 |
115 | ```
116 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7; python train_stage2.py --mixed_precision --init_dir '64Ai4i3_5e-4m_stage1abc_1318' --lr 1e-5 --max_steps 400000 --exp 'stage2abc'
117 | ```
118 |
119 | ## Evaluation
120 |
121 | Test the model on point tracking datasets with a command like:
122 |
123 | ```
124 | python test_dense_on_sparse.py --dname 'dav'
125 | ```
126 |
127 | At the end, you should see:
128 | ```
129 | da: 76.3,
130 | aj: 63.3,
131 | oa: 90.0,
132 | ```
133 | which represent `d_avg` (accuracy), Average Jaccard, and Occlusion Accuracy. We find that small numerical issues (even across GPUs) may cause +- 0.1 fluctuation on these metrics.
134 |
135 | Test the model at higher resolution with the `image_size` arg: `--image_size 448 768` should produce `da: 78.8, aj: 65.9, oa: 90.2`, and `--image_size 768 1024` should produce `da: 80.6, aj: 67.2, oa: 89.7`.
136 |
137 | Note that AJ and OA are not reliable metrics in all datasets, because not all datasets follow the same rules about visibility annotation.
138 |
139 | Dataloaders for all test datasets are in the repo, and can be run in sequence with a command like:
140 | ```
141 | python test_dense_on_sparse.py --dname 'bad,cro,dav,dri,ego,hor,kin,rgb,rob'
142 | ```
143 | but if you have multiple GPUs, we recommend running the tests in parallel.
144 |
145 |
146 | ## Citation
147 |
148 | If you use this code for your research, please cite:
149 |
150 | ```
151 | Adam W. Harley, Yang You, Xinglong Sun, Yang Zheng, Nikhil Raghuraman, Yunqi Gu, Sheldon Liang, Wen-Hsuan Chu, Achal Dave, Pavel Tokmakov, Suya You, Rares Ambrus, Katerina Fragkiadaki, Leonidas J. Guibas. AllTracker: Efficient Dense Point Tracking at High Resolution. ICCV 2025.
152 | ```
153 |
154 | Bibtex:
155 | ```
156 | @inproceedings{harley2025alltracker,
157 | author = {Adam W. Harley and Yang You and Xinglong Sun and Yang Zheng and Nikhil Raghuraman and Yunqi Gu and Sheldon Liang and Wen-Hsuan Chu and Achal Dave and Pavel Tokmakov and Suya You and Rares Ambrus and Katerina Fragkiadaki and Leonidas J. Guibas},
158 |   title = {All{T}racker: {E}fficient Dense Point Tracking at High Resolution},
159 | booktitle = {ICCV},
160 | year = {2025}
161 | }
162 | ```
163 |
--------------------------------------------------------------------------------
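The `d_avg` number reported in the README's Evaluation section corresponds to the TAP-Vid-style position accuracy: the fraction of visible points predicted within {1, 2, 4, 8, 16} pixels of the ground truth, averaged over the five thresholds. Below is a rough sketch of that computation; the exact evaluation used by this repo lives in `test_dense_on_sparse.py`.

```python
import numpy as np

def d_avg(gt_trajs, pred_trajs, visibs):
    # gt_trajs, pred_trajs: S,N,2 arrays in pixels; visibs: S,N with 1 = visible
    err = np.linalg.norm(pred_trajs - gt_trajs, axis=-1)  # S,N pixel errors
    accs = []
    for thr in [1, 2, 4, 8, 16]:
        correct = (err <= thr) * visibs
        accs.append(correct.sum() / max(visibs.sum(), 1))
    return 100.0 * float(np.mean(accs))
```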
/utils/samp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import utils.basic
3 | import torch.nn.functional as F
4 |
5 | def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
6 | r"""Sample a tensor using bilinear interpolation
7 |
8 | `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
9 | coordinates :attr:`coords` using bilinear interpolation. It is the same
10 | as `torch.nn.functional.grid_sample()` but with a different coordinate
11 | convention.
12 |
13 | The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
14 | :math:`B` is the batch size, :math:`C` is the number of channels,
15 | :math:`H` is the height of the image, and :math:`W` is the width of the
16 | image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
17 | interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
18 |
19 | Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
20 | in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
21 | that in this case the order of the components is slightly different
22 | from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
23 |
24 | If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
25 | in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
 26 |     in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
 27 |     left-most image pixel, and :math:`W-1` to the center of the right-most
27 | pixel.
28 |
29 | If `align_corners` is `False`, the coordinate :math:`x` is assumed to
30 | be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
 31 |     the left-most pixel, and :math:`W` to the right edge of the right-most
32 | pixel.
33 |
34 | Similar conventions apply to the :math:`y` for the range
35 | :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
36 | :math:`[0,T-1]` and :math:`[0,T]`.
37 |
38 | Args:
39 | input (Tensor): batch of input images.
40 | coords (Tensor): batch of coordinates.
41 | align_corners (bool, optional): Coordinate convention. Defaults to `True`.
42 | padding_mode (str, optional): Padding mode. Defaults to `"border"`.
43 |
44 | Returns:
45 | Tensor: sampled points.
46 | """
47 |
48 | sizes = input.shape[2:]
49 |
50 | assert len(sizes) in [2, 3]
51 |
52 | if len(sizes) == 3:
53 | # t x y -> x y t to match dimensions T H W in grid_sample
54 | coords = coords[..., [1, 2, 0]]
55 |
56 | if align_corners:
57 | coords = coords * torch.tensor(
58 | [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
59 | )
60 | else:
61 | coords = coords * torch.tensor(
62 | [2 / size for size in reversed(sizes)], device=coords.device
63 | )
64 |
65 | coords -= 1
66 |
67 | return F.grid_sample(
68 | input, coords, align_corners=align_corners, padding_mode=padding_mode
69 | )
70 |
71 |
72 | def sample_features4d(input, coords):
73 | r"""Sample spatial features
74 |
75 | `sample_features4d(input, coords)` samples the spatial features
76 | :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
77 |
78 | The field is sampled at coordinates :attr:`coords` using bilinear
79 | interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
 80 |     2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
81 | same convention as :func:`bilinear_sampler` with `align_corners=True`.
82 |
83 | The output tensor has one feature per point, and has shape :math:`(B,
84 | R, C)`.
85 |
86 | Args:
87 | input (Tensor): spatial features.
88 | coords (Tensor): points.
89 |
90 | Returns:
91 | Tensor: sampled features.
92 | """
93 |
94 | B, _, _, _ = input.shape
95 |
96 | # B R 2 -> B R 1 2
97 | coords = coords.unsqueeze(2)
98 |
99 | # B C R 1
100 | feats = bilinear_sampler(input, coords)
101 |
102 | return feats.permute(0, 2, 1, 3).view(
103 | B, -1, feats.shape[1] * feats.shape[3]
104 | ) # B C R 1 -> B R C
105 |
106 |
107 | def sample_features5d(input, coords):
108 | r"""Sample spatio-temporal features
109 |
110 | `sample_features5d(input, coords)` works in the same way as
111 | :func:`sample_features4d` but for spatio-temporal features and points:
112 | :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
113 | a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
114 | x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
115 |
116 | Args:
117 | input (Tensor): spatio-temporal features.
118 | coords (Tensor): spatio-temporal points.
119 |
120 | Returns:
121 | Tensor: sampled features.
122 | """
123 |
124 | B, T, _, _, _ = input.shape
125 |
126 | # B T C H W -> B C T H W
127 | input = input.permute(0, 2, 1, 3, 4)
128 |
129 | # B R1 R2 3 -> B R1 R2 1 3
130 | coords = coords.unsqueeze(3)
131 |
132 | # B C R1 R2 1
133 | feats = bilinear_sampler(input, coords)
134 |
135 | return feats.permute(0, 2, 3, 1, 4).view(
136 | B, feats.shape[2], feats.shape[3], feats.shape[1]
137 | ) # B C R1 R2 1 -> B R1 R2 C
138 |
139 |
140 | def bilinear_sample2d(im, x, y, return_inbounds=False):
141 | # x and y are each B, N
142 | # output is B, C, N
143 | B, C, H, W = list(im.shape)
144 | N = list(x.shape)[1]
145 |
146 | x = x.float()
147 | y = y.float()
148 | H_f = torch.tensor(H, dtype=torch.float32)
149 | W_f = torch.tensor(W, dtype=torch.float32)
150 |
151 |     # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<float(W_f-0.5)).float()*(y<float(H_f-0.5)).float()
207 |     x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
208 | y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
209 | inbounds = (x_valid & y_valid).float()
210 | inbounds = inbounds.reshape(B, N) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
211 | return output, inbounds
212 |
213 | return output # B, C, N
214 |
--------------------------------------------------------------------------------
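To make the `align_corners=True` convention documented in `bilinear_sampler` concrete: coordinates are in pixel units, in `[0, W-1] x [0, H-1]`, with `(0, 0)` at the center of the top-left pixel, so sampling at integer `(x, y)` reproduces the pixel value at that location. A small sketch:

```python
import torch
from utils.samp import bilinear_sampler

B, C, H, W = 1, 3, 5, 7
img = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)

coords = torch.tensor([[[[2.0, 3.0]]]])  # B,H_o,W_o,2, in (x, y) order
out = bilinear_sampler(img, coords)      # B,C,H_o,W_o

print(torch.allclose(out[0, :, 0, 0], img[0, :, 3, 2]))  # True: (x=2, y=3) hits pixel row 3, col 2
```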
/datasets/drivetrackdataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import glob
5 | import cv2
6 | from datasets.pointdataset import PointDataset
7 | import pickle
8 | import utils.data
9 | from pathlib import Path
10 |
11 | class DrivetrackDataset(PointDataset):
12 | def __init__(
13 | self,
14 | data_root='../datasets/drivetrack',
15 | crop_size=(384, 512),
16 | seq_len=None,
17 | traj_per_sample=512,
18 | only_first=False,
19 | ):
20 | super(DrivetrackDataset, self).__init__(
21 | data_root=data_root,
22 | crop_size=crop_size,
23 | seq_len=seq_len,
24 | traj_per_sample=traj_per_sample,
25 | )
26 | print("loading drivetrack dataset...")
27 |
28 | self.dname = 'drivetrack'
29 | self.only_first = only_first
30 | S = seq_len
31 |
32 | self.dataset_location = Path(data_root)
33 | self.S = S
34 | video_fns = sorted(list(self.dataset_location.glob('*.npz')))
35 |
36 | self.video_fns = []
37 | for video_fn in video_fns[:100]: # drivetrack is huge and self-similar, so we trim to 100
38 | ds = np.load(video_fn, allow_pickle=True)
39 | rgbs, trajs, visibs = ds['video'], ds['tracks'], ds['visibles']
40 | # rgbs is T,1280,1920,3
41 | # trajs is N,T,2
42 | # visibs is N,T
43 |
44 | trajs = np.transpose(trajs, (1,0,2)).astype(np.float32) # S,N,2
45 | visibs = np.transpose(visibs, (1,0)).astype(np.float32) # S,N
46 | valids = visibs.copy()
47 | # print('N0', trajs.shape[1])
48 |
49 | # discard tracks with any inf/nan
50 | idx = np.nonzero(np.isfinite(trajs.sum(0).sum(1)))[0] # N
51 | trajs = trajs[:,idx]
52 | visibs = visibs[:,idx]
53 | valids = valids[:,idx]
54 | # print('N1', trajs.shape[1])
55 |
56 | if trajs.shape[1] < self.traj_per_sample:
57 | continue
58 |
59 | # shuffle and trim
60 | inds = np.random.permutation(trajs.shape[1])
61 | inds = inds[:10000]
62 | trajs = trajs[:,inds]
63 | visibs = visibs[:,inds]
64 | valids = valids[:,inds]
65 | # print('N2', trajs.shape[1])
66 |
67 | S,H,W,C = rgbs.shape
68 |
69 | # set OOB to invisible
70 | visibs[trajs[:, :, 0] > W-1] = False
71 | visibs[trajs[:, :, 0] < 0] = False
72 | visibs[trajs[:, :, 1] > H-1] = False
73 | visibs[trajs[:, :, 1] < 0] = False
74 |
75 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
76 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
77 | # print('N3', trajs.shape[1])
78 |
79 | if trajs.shape[1] < self.traj_per_sample:
80 | continue
81 |
82 | trajs = torch.from_numpy(trajs)
83 | visibs = torch.from_numpy(visibs)
84 | valids = torch.from_numpy(valids)
85 | # discard tracks that go far OOB
86 | crop_tensor = torch.tensor(self.crop_size).flip(0)[None, None] / 2.0
87 | close_pts_inds = torch.all(
88 | torch.linalg.vector_norm(trajs[..., :2] - crop_tensor, dim=-1) < max(H,W)*2,
89 | dim=0,
90 | )
91 | trajs = trajs[:, close_pts_inds]
92 | visibs = visibs[:, close_pts_inds]
93 | valids = valids[:, close_pts_inds]
94 | # print('N4', trajs.shape[1])
95 |
96 | if trajs.shape[1] < self.traj_per_sample:
97 | continue
98 |
99 | visible_inds = (valids[0]*visibs[0]).nonzero(as_tuple=False)[:, 0]
100 | trajs = trajs[:, visible_inds].float()
101 | visibs = visibs[:, visible_inds].float()
102 | valids = valids[:, visible_inds].float()
103 | # print('N5', trajs.shape[1])
104 |
105 | if trajs.shape[1] >= self.traj_per_sample:
106 | self.video_fns.append(video_fn)
107 |
108 | print(f"found {len(self.video_fns)} unique videos in {self.dataset_location}")
109 |
110 | def getitem_helper(self, index):
111 | video_fn = self.video_fns[index]
112 | ds = np.load(video_fn, allow_pickle=True)
113 | rgbs, trajs, visibs = ds['video'], ds['tracks'], ds['visibles']
114 | # rgbs is T,1280,1920,3
115 | # trajs is N,T,2
116 | # visibs is N,T
117 |
118 | trajs = np.transpose(trajs, (1,0,2)).astype(np.float32) # S,N,2
119 | visibs = np.transpose(visibs, (1,0)).astype(np.float32) # S,N
120 | valids = visibs.copy()
121 |
122 | # discard inf/nan
123 | idx = np.nonzero(np.isfinite(trajs.sum(0).sum(1)))[0] # N
124 | trajs = trajs[:,idx]
125 | visibs = visibs[:,idx]
126 | valids = valids[:,idx]
127 |
128 | # shuffle and trim
129 | inds = np.random.permutation(trajs.shape[1])
130 | inds = inds[:10000]
131 | trajs = trajs[:,inds]
132 | visibs = visibs[:,inds]
133 | valids = valids[:,inds]
134 | N = trajs.shape[1]
135 | # print('N2', trajs.shape[1])
136 |
137 | S,H,W,C = rgbs.shape
138 | # set OOB to invisible
139 | visibs[trajs[:, :, 0] > W-1] = False
140 | visibs[trajs[:, :, 0] < 0] = False
141 | visibs[trajs[:, :, 1] > H-1] = False
142 | visibs[trajs[:, :, 1] < 0] = False
143 |
144 | rgbs, trajs, visibs, valids = utils.data.standardize_test_data(
145 | rgbs, trajs, visibs, valids, only_first=self.only_first, seq_len=self.seq_len)
146 |
147 | H, W = rgbs[0].shape[:2]
148 | trajs[:,:,0] /= W-1
149 | trajs[:,:,1] /= H-1
150 | rgbs = [cv2.resize(rgb, (self.crop_size[1], self.crop_size[0]), interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
151 | rgbs = np.stack(rgbs)
152 | H,W = rgbs[0].shape[:2]
153 | trajs[:,:,0] *= W-1
154 | trajs[:,:,1] *= H-1
155 |
156 | trajs = torch.from_numpy(trajs)
157 | visibs = torch.from_numpy(visibs)
158 | valids = torch.from_numpy(valids)
159 |
160 | # discard tracks that go far OOB
161 | crop_tensor = torch.tensor(self.crop_size).flip(0)[None, None] / 2.0
162 | close_pts_inds = torch.all(
163 | torch.linalg.vector_norm(trajs[..., :2] - crop_tensor, dim=-1) < max(H,W)*2,
164 | dim=0,
165 | )
166 | trajs = trajs[:, close_pts_inds]
167 | visibs = visibs[:, close_pts_inds]
168 | valids = valids[:, close_pts_inds]
169 | # print('N3', trajs.shape[1])
170 |
171 | visible_pts_inds = (valids[0]*visibs[0]).nonzero(as_tuple=False)[:, 0]
172 | point_inds = torch.randperm(len(visible_pts_inds))[: self.traj_per_sample]
173 | if len(point_inds) < self.traj_per_sample:
174 | return None, False
175 | visible_inds_sampled = visible_pts_inds[point_inds]
176 | trajs = trajs[:, visible_inds_sampled].float()
177 | visibs = visibs[:, visible_inds_sampled].float()
178 | valids = valids[:, visible_inds_sampled].float()
179 | # print('N4', trajs.shape[1])
180 |
181 | trajs = trajs[:, :self.traj_per_sample]
182 | visibs = visibs[:, :self.traj_per_sample]
183 | valids = valids[:, :self.traj_per_sample]
184 | # print('N5', trajs.shape[1])
185 |
186 | rgbs = torch.from_numpy(rgbs).permute(0, 3, 1, 2).float()
187 |
188 | sample = utils.data.VideoData(
189 | video=rgbs,
190 | trajs=trajs,
191 | visibs=visibs,
192 | valids=valids,
193 | dname=self.dname,
194 | )
195 | return sample, True
196 |
197 | def __len__(self):
198 | return len(self.video_fns)
199 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from typing import List
5 | import utils.basic
6 |
7 |
8 | def sequence_loss(
9 | flow_preds,
10 | flow_gt,
11 | valids,
12 | vis=None,
13 | gamma=0.8,
14 | use_huber_loss=False,
15 | loss_only_for_visible=False,
16 | ):
17 | """Loss function defined over sequence of flow predictions"""
18 | total_flow_loss = 0.0
19 | for j in range(len(flow_gt)):
20 | B, S, N, D = flow_gt[j].shape
21 | B, S2, N = valids[j].shape
22 | assert S == S2
23 | n_predictions = len(flow_preds[j])
24 | flow_loss = 0.0
25 | for i in range(n_predictions):
26 | i_weight = gamma ** (n_predictions - i - 1)
27 | flow_pred = flow_preds[j][i]
28 | if use_huber_loss:
29 | i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
30 | else:
31 | i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
32 | i_loss = torch.mean(i_loss, dim=3) # B, S, N
33 | valid_ = valids[j].clone()
34 | if loss_only_for_visible:
35 | valid_ = valid_ * vis[j]
36 | flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss, valid_)
37 | flow_loss = flow_loss / n_predictions
38 | total_flow_loss += flow_loss
39 | return total_flow_loss / len(flow_gt)
40 |
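# Nesting convention for the arguments above (illustrative note, not from the
# original file): flow_gt[j] and valids[j] index one group of supervised tracks
# (e.g. one window), and flow_preds[j][i] is the i-th iterative refinement for
# that group, so shapes are roughly:
#   flow_gt[j]       : B,S,N,2
#   valids[j]        : B,S,N
#   flow_preds[j][i] : B,S,N,2
# Later refinements receive exponentially larger weight via gamma**(n_predictions-i-1).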
41 | def sequence_loss_dense(
42 | flow_preds,
43 | flow_gt,
44 | valids,
45 | vis=None,
46 | gamma=0.8,
47 | use_huber_loss=False,
48 | loss_only_for_visible=False,
49 | ):
50 | """Loss function defined over sequence of flow predictions"""
51 | total_flow_loss = 0.0
52 | for j in range(len(flow_gt)):
53 | # print('flow_gt[j]', flow_gt[j].shape)
54 | B, S, D, H, W = flow_gt[j].shape
55 | B, S2, _, H, W = valids[j].shape
56 | assert S == S2
57 | n_predictions = len(flow_preds[j])
58 | flow_loss = 0.0
59 | # import ipdb; ipdb.set_trace()
60 | for i in range(n_predictions):
61 | # print('flow_e[j][i]', flow_preds[j][i].shape)
62 | i_weight = gamma ** (n_predictions - i - 1)
63 | flow_pred = flow_preds[j][i] # B,S,2,H,W
64 | if use_huber_loss:
65 | i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0) # B,S,2,H,W
66 | else:
67 | i_loss = (flow_pred - flow_gt[j]).abs() # B,S,2,H,W
68 | i_loss_ = torch.mean(i_loss, dim=2) # B,S,H,W
69 | valid_ = valids[j].reshape(B,S,H,W)
70 | # print(' (%d,%d) i_loss_' % (i,j), i_loss_.shape)
71 | # print(' (%d,%d) valid_' % (i,j), valid_.shape)
72 | if loss_only_for_visible:
73 | valid_ = valid_ * vis[j].reshape(B,-1,H,W) # usually B,S,H,W, but maybe B,1,H,W
74 | flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss_, valid_, broadcast=True)
75 | # import ipdb; ipdb.set_trace()
76 | flow_loss = flow_loss / n_predictions
77 | total_flow_loss += flow_loss
78 | return total_flow_loss / len(flow_gt)
79 |
80 |
81 | def huber_loss(x, y, delta=1.0):
82 | """Calculate element-wise Huber loss between x and y"""
83 | diff = x - y
84 | abs_diff = diff.abs()
85 | flag = (abs_diff <= delta).float()
86 | return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
87 |
88 |
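# Sanity-check sketch (illustrative, assuming PyTorch >= 1.9): this elementwise
# form matches torch.nn.functional.huber_loss with reduction='none', e.g.
#   x, y = torch.randn(4, 3), torch.randn(4, 3)
#   assert torch.allclose(huber_loss(x, y, delta=2.0),
#                         F.huber_loss(x, y, delta=2.0, reduction='none'))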
89 | def sequence_BCE_loss(vis_preds, vis_gts, valids=None, use_logits=False):
90 | total_bce_loss = 0.0
91 | # all_vis_preds = [torch.stack(vp) for vp in vis_preds]
92 | # all_vis_preds = torch.stack(all_vis_preds)
93 | # utils.basic.print_stats('all_vis_preds', all_vis_preds)
94 | for j in range(len(vis_preds)):
95 | n_predictions = len(vis_preds[j])
96 | bce_loss = 0.0
97 | for i in range(n_predictions):
98 | # utils.basic.print_stats('vis_preds[%d][%d]' % (j,i), vis_preds[j][i])
99 | # utils.basic.print_stats('vis_gts[%d]' % (i), vis_gts[i])
100 | if use_logits:
101 | loss = F.binary_cross_entropy_with_logits(vis_preds[j][i], vis_gts[j], reduction='none')
102 | else:
103 | loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j], reduction='none')
104 | if valids is None:
105 | bce_loss += loss.mean()
106 | else:
107 | bce_loss += (loss * valids[j]).mean()
108 | bce_loss = bce_loss / n_predictions
109 | total_bce_loss += bce_loss
110 | return total_bce_loss / len(vis_preds)
111 |
112 |
113 | # def sequence_BCE_loss_dense(vis_preds, vis_gts):
114 | # total_bce_loss = 0.0
115 | # for j in range(len(vis_preds)):
116 | # n_predictions = len(vis_preds[j])
117 | # bce_loss = 0.0
118 | # for i in range(n_predictions):
119 | # vis_e = vis_preds[j][i]
120 | # vis_g = vis_gts[j]
121 | # print('vis_e', vis_e.shape, 'vis_g', vis_g.shape)
122 | # vis_loss = F.binary_cross_entropy(vis_e, vis_g)
123 | # bce_loss += vis_loss
124 | # bce_loss = bce_loss / n_predictions
125 | # total_bce_loss += bce_loss
126 | # return total_bce_loss / len(vis_preds)
127 |
128 |
129 | def sequence_prob_loss(
130 | tracks: torch.Tensor,
131 | confidence: torch.Tensor,
132 | target_points: torch.Tensor,
133 | visibility: torch.Tensor,
134 | expected_dist_thresh: float = 12.0,
135 | use_logits=False,
136 | ):
137 | """Loss for classifying if a point is within pixel threshold of its target."""
138 | # Points with an error larger than 12 pixels are likely to be useless; marking
139 | # them as occluded will actually improve Jaccard metrics and give
140 | # qualitatively better results.
141 | total_logprob_loss = 0.0
142 | for j in range(len(tracks)):
143 | n_predictions = len(tracks[j])
144 | logprob_loss = 0.0
145 | for i in range(n_predictions):
146 | err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
147 | valid = (err <= expected_dist_thresh**2).float()
148 | if use_logits:
149 | loss = F.binary_cross_entropy_with_logits(confidence[j][i], valid, reduction="none")
150 | else:
151 | loss = F.binary_cross_entropy(confidence[j][i], valid, reduction="none")
152 | loss *= visibility[j]
153 | loss = torch.mean(loss, dim=[1, 2])
154 | logprob_loss += loss
155 | logprob_loss = logprob_loss / n_predictions
156 | total_logprob_loss += logprob_loss
157 | return total_logprob_loss / len(tracks)
158 |
159 | def sequence_prob_loss_dense(
160 | tracks: torch.Tensor,
161 | confidence: torch.Tensor,
162 | target_points: torch.Tensor,
163 | visibility: torch.Tensor,
164 | expected_dist_thresh: float = 12.0,
165 | use_logits=False,
166 | ):
167 | """Loss for classifying if a point is within pixel threshold of its target."""
168 | # Points with an error larger than 12 pixels are likely to be useless; marking
169 | # them as occluded will actually improve Jaccard metrics and give
170 | # qualitatively better results.
171 |
172 | # all_confidence = [torch.stack(vp) for vp in confidence]
173 | # all_confidence = torch.stack(all_confidence)
174 | # utils.basic.print_stats('all_confidence', all_confidence)
175 |
176 | total_logprob_loss = 0.0
177 | for j in range(len(tracks)):
178 | n_predictions = len(tracks[j])
179 | logprob_loss = 0.0
180 | for i in range(n_predictions):
181 | # print('trajs_e', tracks[j][i].shape)
182 | # print('trajs_g', target_points[j].shape)
183 | err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=2)
184 | positive = (err <= expected_dist_thresh**2).float()
185 | # print('conf', confidence[j][i].shape, 'positive', positive.shape)
186 | if use_logits:
187 | loss = F.binary_cross_entropy_with_logits(confidence[j][i].squeeze(2), positive, reduction="none")
188 | else:
189 | loss = F.binary_cross_entropy(confidence[j][i].squeeze(2), positive, reduction="none")
190 | loss *= visibility[j].squeeze(2) # B,S,H,W
191 | loss = torch.mean(loss, dim=[1,2,3])
192 | logprob_loss += loss
193 | logprob_loss = logprob_loss / n_predictions
194 | total_logprob_loss += logprob_loss
195 | return total_logprob_loss / len(tracks)
196 |
197 |
198 | def masked_mean(data, mask, dim):
199 | if mask is None:
200 | return data.mean(dim=dim, keepdim=True)
201 | mask = mask.float()
202 | mask_sum = torch.sum(mask, dim=dim, keepdim=True)
203 | mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
204 | mask_sum, min=1.0
205 | )
206 | return mask_mean
207 |
208 |
209 | def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
210 | if mask is None:
211 | return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
212 | mask = mask.float()
213 | mask_sum = torch.sum(mask, dim=dim, keepdim=True)
214 | mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
215 | mask_sum, min=1.0
216 | )
217 | mask_var = torch.sum(
218 | mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
219 | ) / torch.clamp(mask_sum, min=1.0)
220 | return mask_mean.squeeze(dim), mask_var.squeeze(dim)
221 |
--------------------------------------------------------------------------------
/datasets/kubric_movif_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import cv2
4 | import imageio
5 | import numpy as np
6 | import glob
7 | from pathlib import Path
8 | import utils.data
9 | from datasets.pointdataset import PointDataset
10 | import random
11 |
12 | class KubricMovifDataset(PointDataset):
13 | def __init__(
14 | self,
15 | data_root,
16 | crop_size=(384, 512),
17 | seq_len=24,
18 | traj_per_sample=768,
19 | traj_max_factor=24, # multiplier on traj_per_sample
20 | use_augs=False,
21 | random_seq_len=False,
22 | random_first_frame=False,
23 | random_frame_rate=False,
24 | random_number_traj=False,
25 | shuffle_frames=False,
26 | shuffle=True,
27 | only_first=False,
28 | ):
29 | super(KubricMovifDataset, self).__init__(
30 | data_root=data_root,
31 | crop_size=crop_size,
32 | seq_len=seq_len,
33 | traj_per_sample=traj_per_sample,
34 | use_augs=use_augs,
35 | )
36 | print('loading kubric S%d dataset...' % seq_len)
37 |
38 | self.dname = 'kubric%d' % seq_len
39 |
40 | self.only_first = only_first
41 | self.traj_max_factor = traj_max_factor
42 |
43 | self.random_seq_len = random_seq_len
44 | self.random_first_frame = random_first_frame
45 | self.random_frame_rate = random_frame_rate
46 | self.random_number_traj = random_number_traj
47 | self.shuffle_frames = shuffle_frames
48 | self.pad_bounds = [10, 100]
49 | self.resize_lim = [0.25, 2.0] # sample resizes from here
50 | self.resize_delta = 0.2
51 | self.max_crop_offset = 50
52 |
53 | folder_names = Path(data_root).glob('*/*/')
54 | folder_names = [str(fn) for fn in folder_names]
55 | folder_names = sorted(folder_names)
56 | # print('folder_names', folder_names)
57 | if shuffle:
58 | random.shuffle(folder_names)
59 |
60 | self.seq_names = []
61 | for fi, fol in enumerate(folder_names):
62 | npy_path = os.path.join(fol, "annot.npy")
63 | rgb_path = os.path.join(fol, "frames")
64 | if os.path.isdir(rgb_path) and os.path.isfile(npy_path):
65 | img_paths = sorted(os.listdir(rgb_path))
66 | if len(img_paths)>=seq_len:
67 | self.seq_names.append(fol)
68 | else:
69 | pass
70 | else:
71 | pass
72 |
73 | print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
74 |
75 | def getitem_helper(self, index):
76 | gotit = True
77 | fol = self.seq_names[index]
78 | npy_path = os.path.join(fol, "annot.npy")
79 | rgb_path = os.path.join(fol, "frames")
80 |
81 | seq_name = fol.split('/')[-1]
82 |
83 | img_paths = sorted(os.listdir(rgb_path))
84 | rgbs = []
85 | for i, img_path in enumerate(img_paths):
86 | rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
87 |
88 | rgbs = np.stack(rgbs)
89 | annot_dict = np.load(npy_path, allow_pickle=True).item()
90 | trajs = annot_dict["point_xys"]
91 | visibs = annot_dict["point_visibs"]
92 | valids = annot_dict["point_xys_valid"]
93 |
94 | S = len(rgbs)
95 | # ensure all valid, and then discard valids tensor
96 | all_valid = np.nonzero(np.sum(valids, axis=0)==S)[0]
97 | trajs = trajs[:,all_valid]
98 | visibs = visibs[:,all_valid]
99 |
100 | if self.use_augs and np.random.rand() < 0.5: # time flip
101 | # time flip
102 | rgbs = np.flip(rgbs, axis=[0]).copy()
103 | trajs = np.flip(trajs, axis=[0]).copy()
104 | visibs = np.flip(visibs, axis=[0]).copy()
105 |
106 | if self.shuffle_frames and np.random.rand() < 0.01:
107 | # shuffle the frames
108 | perm = np.random.permutation(rgbs.shape[0])
109 | rgbs = rgbs[perm]
110 | trajs = trajs[perm]
111 | visibs = visibs[perm]
112 |
113 | frame_rate = 1
114 | final_num_traj = self.traj_per_sample
115 | crop_size = self.crop_size
116 |
117 | # randomize time slice
118 | min_num_traj = 1
119 | assert self.traj_per_sample >= min_num_traj
120 | if self.random_seq_len and self.random_number_traj:
121 | final_num_traj = np.random.randint(min_num_traj, self.traj_per_sample)
122 | alpha = final_num_traj / float(self.traj_per_sample)
123 | seq_len = int(alpha * 10 + (1 - alpha) * self.seq_len)
124 | seq_len = np.random.randint(seq_len - 2, seq_len + 2)
125 | if self.random_frame_rate:
126 | frame_rate = np.random.randint(1, int((120 / seq_len)) + 1)
127 | elif self.random_number_traj:
128 | final_num_traj = np.random.randint(min_num_traj, self.traj_per_sample)
129 | alpha = final_num_traj / float(self.traj_per_sample)
130 | seq_len = 8 * int(alpha * 2 + (1 - alpha) * self.seq_len // 8)
131 | # seq_len = np.random.randint(seq_len , seq_len + 2)
132 | if self.random_frame_rate:
133 | frame_rate = np.random.randint(1, int((120 / seq_len)) + 1)
134 | elif self.random_seq_len:
135 | seq_len = np.random.randint(int(self.seq_len / 2), self.seq_len)
136 | if self.random_frame_rate:
137 | frame_rate = np.random.randint(1, int((120 / seq_len)) + 1)
138 | else:
139 | seq_len = self.seq_len
140 | if self.random_frame_rate:
141 | frame_rate = np.random.randint(1, int((120 / seq_len)) + 1)
142 | if seq_len < len(rgbs):
143 | if self.random_first_frame:
144 | ind0 = np.random.choice(len(rgbs))
145 | rgb0 = rgbs[ind0]
146 | traj0 = trajs[ind0]
147 | visib0 = visibs[ind0]
148 |
149 | if seq_len * frame_rate < len(rgbs):
150 | start_ind = np.random.choice(len(rgbs) - (seq_len * frame_rate), 1)[0]
151 | else:
152 | start_ind = 0
153 | # print('slice %d:%d:%d' % (start_ind, start_ind+seq_len*frame_rate, frame_rate))
154 | rgbs = rgbs[start_ind : start_ind + seq_len * frame_rate : frame_rate]
155 | trajs = trajs[start_ind : start_ind + seq_len * frame_rate : frame_rate]
156 | visibs = visibs[start_ind : start_ind + seq_len * frame_rate : frame_rate]
157 |
158 | if self.random_first_frame:
159 | rgbs[0] = rgb0
160 | trajs[0] = traj0
161 | visibs[0] = visib0
162 |
163 | assert seq_len == len(rgbs)
164 |
165 | # ensure no crazy values
166 | all_valid = np.nonzero(np.sum(np.sum(np.abs(trajs).astype(np.float64), axis=-1)<100000, axis=0)==seq_len)[0]
167 | trajs = trajs[:,all_valid]
168 | visibs = visibs[:,all_valid]
169 |
170 | if self.use_augs and np.random.rand() < 0.98:
171 | rgbs, trajs, visibs = self.add_photometric_augs(rgbs, trajs, visibs, replace=False)
172 | rgbs, trajs = self.add_spatial_augs(rgbs, trajs, visibs, crop_size)
173 | else:
174 | rgbs, trajs = self.crop(rgbs, trajs, crop_size)
175 |
176 | visibs[trajs[:, :, 0] > crop_size[1] - 1] = False
177 | visibs[trajs[:, :, 0] < 0] = False
178 | visibs[trajs[:, :, 1] > crop_size[0] - 1] = False
179 | visibs[trajs[:, :, 1] < 0] = False
180 |
181 | # ensure no crazy values
182 | all_valid = np.nonzero(np.sum(np.sum(np.abs(trajs), axis=-1)<100000, axis=0)==seq_len)[0]
183 | trajs = trajs[:,all_valid]
184 | visibs = visibs[:,all_valid]
185 |
186 | if self.shuffle_frames and np.random.rand() < 0.01:
187 | # shuffle the frames (again)
188 | perm = np.random.permutation(rgbs.shape[0])
189 | rgbs = rgbs[perm]
190 | trajs = trajs[perm]
191 | visibs = visibs[perm]
192 |
193 | if self.only_first:
194 | vis_ok = np.nonzero(visibs[0]==1)[0]
195 | trajs = trajs[:,vis_ok]
196 | visibs = visibs[:,vis_ok]
197 |
198 | visibs = torch.from_numpy(visibs)
199 | trajs = torch.from_numpy(trajs)
200 |
201 | crop_tensor = torch.tensor(crop_size).flip(0)[None, None] / 2.0
202 | close_pts_inds = torch.all(
203 | torch.linalg.vector_norm(trajs[..., :2] - crop_tensor, dim=-1) < 1000.0,
204 | dim=0,
205 | )
206 | trajs = trajs[:, close_pts_inds]
207 | visibs = visibs[:, close_pts_inds]
208 | N = trajs.shape[1]
209 |
210 | assert self.only_first
211 |
212 | N = trajs.shape[1]
213 | point_inds = torch.randperm(N)[:self.traj_per_sample*self.traj_max_factor]
214 |
215 | if len(point_inds) < self.traj_per_sample:
216 | gotit = False
217 |
218 | trajs = trajs[:, point_inds]
219 | visibs = visibs[:, point_inds]
220 | valids = torch.ones_like(visibs)
221 |
222 | rgbs = torch.from_numpy(rgbs).permute(0, 3, 1, 2).float()
223 |
224 | trajs = trajs[:, :self.traj_per_sample*self.traj_max_factor]
225 | visibs = visibs[:, :self.traj_per_sample*self.traj_max_factor]
226 | valids = valids[:, :self.traj_per_sample*self.traj_max_factor]
227 |
228 | sample = utils.data.VideoData(
229 | video=rgbs,
230 | trajs=trajs,
231 | visibs=visibs,
232 | valids=valids,
233 | seq_name=seq_name,
234 | dname=self.dname,
235 | )
236 | return sample, gotit
237 |
238 | def __len__(self):
239 | return len(self.seq_names)
240 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def get_1d_sincos_pos_embed_from_grid(embed_dim, positions):
5 | assert embed_dim % 2 == 0
6 | omega = torch.arange(embed_dim // 2, dtype=torch.double)
7 | omega /= embed_dim / 2.0
8 | omega = 1.0 / 10000**omega # (D/2,)
9 |
10 | positions = positions.reshape(-1) # (M,)
11 | out = torch.einsum("m,d->md", positions, omega) # (M, D/2), outer product
12 |
13 | emb_sin = torch.sin(out) # (M, D/2)
14 | emb_cos = torch.cos(out) # (M, D/2)
15 |
16 | emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
17 | return emb[None].float()
18 |
19 |
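# Quick shape example (illustrative, not from the original file):
#   emb = get_1d_sincos_pos_embed_from_grid(8, torch.arange(5, dtype=torch.double))
#   # emb.shape == torch.Size([1, 5, 8]); first 4 channels are sines, last 4 cosines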
20 | class SimplePool():
21 | def __init__(self, pool_size, version='pt', min_size=1):
22 | self.pool_size = pool_size
23 | self.version = version
24 | self.items = []
25 | self.min_size = min_size
26 |
27 | if not (version=='pt' or version=='np'):
28 |             print('version = %s; please choose pt or np' % version)
29 | assert(False) # please choose pt or np
30 |
31 | def __len__(self):
32 | return len(self.items)
33 |
34 | def mean(self, min_size=None):
35 | if min_size is None:
36 | pool_size_thresh = self.min_size
37 | elif min_size=='half':
38 | pool_size_thresh = self.pool_size/2
39 | else:
40 | pool_size_thresh = min_size
41 |
42 | if self.version=='np':
43 | if len(self.items) >= pool_size_thresh:
44 | return np.sum(self.items)/float(len(self.items))
45 | else:
46 | return np.nan
47 | if self.version=='pt':
48 | if len(self.items) >= pool_size_thresh:
49 |                 return torch.stack(self.items).sum()/float(len(self.items))
50 | else:
51 |                 return torch.tensor(np.nan)
52 |
53 | def sample(self, with_replacement=True):
54 | idx = np.random.randint(len(self.items))
55 | if with_replacement:
56 | return self.items[idx]
57 | else:
58 | return self.items.pop(idx)
59 |
60 | def fetch(self, num=None):
61 | if self.version=='pt':
62 | item_array = torch.stack(self.items)
63 | elif self.version=='np':
64 | item_array = np.stack(self.items)
65 | if num is not None:
66 | # there better be some items
67 | assert(len(self.items) >= num)
68 |
69 | # if there are not that many elements just return however many there are
70 | if len(self.items) < num:
71 | return item_array
72 | else:
73 | idxs = np.random.randint(len(self.items), size=num)
74 | return item_array[idxs]
75 | else:
76 | return item_array
77 |
78 | def is_full(self):
79 | full = len(self.items)==self.pool_size
80 | return full
81 |
82 | def empty(self):
83 | self.items = []
84 |
85 | def have_min_size(self):
86 | return len(self.items) >= self.min_size
87 |
88 |
89 | def update(self, items):
90 | for item in items:
91 | if len(self.items) < self.pool_size:
92 | # the pool is not full, so let's add this in
93 | self.items.append(item)
94 | else:
95 | # the pool is full
96 | # pop from the front
97 | self.items.pop(0)
98 | # add to the back
99 | self.items.append(item)
100 | return self.items
101 |
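# Typical usage sketch (illustrative, not from the original file): SimplePool acts
# as a fixed-size FIFO buffer for smoothing scalar metrics over recent iterations.
#   pool = SimplePool(100, version='np')
#   pool.update([0.5, 0.7, 0.6])
#   smoothed = pool.mean()  # running average over whatever is currently pooled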
102 | def compute_tapvid_metrics(
103 | query_points: np.ndarray,
104 | gt_occluded: np.ndarray,
105 | gt_tracks: np.ndarray,
106 | pred_occluded: np.ndarray,
107 | pred_tracks: np.ndarray,
108 | query_mode: str,
109 | crop_size: tuple = (256, 256),
110 | ):
111 | """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
112 | See the TAP-Vid paper for details on the metric computation. All inputs are
113 | given in raster coordinates. The first three arguments should be the direct
114 | outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
115 | The paper metrics assume these are scaled relative to 256x256 images.
116 | pred_occluded and pred_tracks are your algorithm's predictions.
117 | This function takes a batch of inputs, and computes metrics separately for
118 | each video. The metrics for the full benchmark are a simple mean of the
119 | metrics across the full set of videos. These numbers are between 0 and 1,
120 | but the paper multiplies them by 100 to ease reading.
121 | Args:
122 |     query_points: The query points, in the format [t, y, x]. Its size is
123 |       [b, n, 3], where b is the batch size and n is the number of queries.
124 | gt_occluded: A boolean array of shape [b, n, t], where t is the number
125 | of frames. True indicates that the point is occluded.
126 | gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
127 | in the format [x, y]
128 | pred_occluded: A boolean array of predicted occlusions, in the same
129 | format as gt_occluded.
130 | pred_tracks: An array of track predictions from your algorithm, in the
131 | same format as gt_tracks.
132 | query_mode: Either 'first' or 'strided', depending on how queries are
133 | sampled. If 'first', we assume the prior knowledge that all points
134 | before the query point are occluded, and these are removed from the
135 | evaluation.
136 | Returns:
137 | A dict with the following keys:
138 | occlusion_accuracy: Accuracy at predicting occlusion.
139 | pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
140 | predicted to be within the given pixel threshold, ignoring occlusion
141 | prediction.
142 | jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
143 | threshold
144 | average_pts_within_thresh: average across pts_within_{x}
145 | average_jaccard: average across jaccard_{x}
146 | """
147 |
148 | metrics = {}
149 |     # The bug fixed here is described in:
150 | # https://github.com/facebookresearch/co-tracker/issues/20
151 | eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
152 |
153 | if query_mode == "first":
154 | # evaluate frames after the query frame
155 | query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
156 | elif query_mode == "strided":
157 | # evaluate all frames except the query frame
158 | query_frame_to_eval_frames = 1 - eye
159 | else:
160 | raise ValueError("Unknown query mode " + query_mode)
161 |
162 | query_frame = query_points[..., 0]
163 | query_frame = np.round(query_frame).astype(np.int32)
164 | evaluation_points = query_frame_to_eval_frames[query_frame] > 0
165 |
166 | # Occlusion accuracy is simply how often the predicted occlusion equals the
167 | # ground truth.
168 | occ_acc = np.sum(
169 | np.equal(pred_occluded, gt_occluded) & evaluation_points,
170 | axis=(1, 2),
171 | ) / np.sum(evaluation_points)
172 | metrics["occlusion_accuracy"] = occ_acc
173 |
174 | # Next, convert the predictions and ground truth positions into pixel
175 | # coordinates.
176 | visible = np.logical_not(gt_occluded)
177 | pred_visible = np.logical_not(pred_occluded)
178 | all_frac_within = []
179 | all_jaccard = []
180 | sx_ = (crop_size[1] - 1) / 255.0
181 | sy_ = (crop_size[0] - 1) / 255.0
182 | sc_pt = np.array([sx_, sy_]).reshape([1, 1, 1, 2])
183 |
184 | for thresh in [1, 2, 4, 8, 16]:
185 | # True positives are points that are within the threshold and where both
186 | # the prediction and the ground truth are listed as visible.
187 | within_dist = np.sum(
188 | np.square(pred_tracks / sc_pt - gt_tracks / sc_pt),
189 | axis=-1,
190 | ) < np.square(thresh)
191 | is_correct = np.logical_and(within_dist, visible)
192 |
193 | # Compute the frac_within_threshold, which is the fraction of points
194 | # within the threshold among points that are visible in the ground truth,
195 | # ignoring whether they're predicted to be visible.
196 | count_correct = np.sum(
197 | is_correct & evaluation_points,
198 | axis=(1, 2),
199 | )
200 | count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
201 | frac_correct = count_correct / count_visible_points
202 | metrics["pts_within_" + str(thresh)] = frac_correct
203 | all_frac_within.append(frac_correct)
204 |
205 | true_positives = np.sum(
206 | is_correct & pred_visible & evaluation_points, axis=(1, 2)
207 | )
208 |
209 | # The denominator of the jaccard metric is the true positives plus
210 | # false positives plus false negatives. However, note that true positives
211 | # plus false negatives is simply the number of points in the ground truth
212 | # which is easier to compute than trying to compute all three quantities.
213 | # Thus we just add the number of points in the ground truth to the number
214 | # of false positives.
215 | #
216 | # False positives are simply points that are predicted to be visible,
217 | # but the ground truth is not visible or too far from the prediction.
218 | gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
219 | false_positives = (~visible) & pred_visible
220 | false_positives = false_positives | ((~within_dist) & pred_visible)
221 | false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
222 | jaccard = true_positives / (gt_positives + false_positives)
223 | metrics["jaccard_" + str(thresh)] = jaccard
224 | all_jaccard.append(jaccard)
225 | metrics["average_jaccard"] = np.mean(
226 | np.stack(all_jaccard, axis=1),
227 | axis=1,
228 | )
229 | metrics["average_pts_within_thresh"] = np.mean(
230 | np.stack(all_frac_within, axis=1),
231 | axis=1,
232 | )
233 | return metrics
234 |
235 |
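if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original file):
    # random predictions on a toy batch, using the shapes documented above.
    rng = np.random.default_rng(0)
    b, n, t = 2, 5, 8
    query_points = np.zeros((b, n, 3), dtype=np.float32)  # all queries on frame 0
    gt_occluded = rng.random((b, n, t)) < 0.2
    gt_tracks = rng.random((b, n, t, 2)) * 255.0
    pred_occluded = rng.random((b, n, t)) < 0.2
    pred_tracks = gt_tracks + rng.normal(0.0, 4.0, size=gt_tracks.shape)
    out = compute_tapvid_metrics(
        query_points, gt_occluded, gt_tracks, pred_occluded, pred_tracks, query_mode="first")
    print({k: v.shape for k, v in out.items()})  # each metric is per-video, shape (b,)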
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import argparse
4 | import utils.saveload
5 | import utils.basic
6 | import utils.improc
7 | import PIL.Image
8 | import numpy as np
9 | import os
10 | from prettytable import PrettyTable
11 | import time
12 |
13 | def read_mp4(name_path):
14 | vidcap = cv2.VideoCapture(name_path)
15 | framerate = int(round(vidcap.get(cv2.CAP_PROP_FPS)))
16 | print('framerate', framerate)
17 | frames = []
18 | while vidcap.isOpened():
19 | ret, frame = vidcap.read()
20 | if ret == False:
21 | break
22 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
23 | frames.append(frame)
24 | vidcap.release()
25 | return frames, framerate
26 |
27 | def draw_pts_gpu(rgbs, trajs, visibs, colormap, rate=1, bkg_opacity=0.5):
28 | device = rgbs.device
29 | T, C, H, W = rgbs.shape
30 | trajs = trajs.permute(1,0,2) # N,T,2
31 | visibs = visibs.permute(1,0) # N,T
32 | N = trajs.shape[0]
33 | colors = torch.tensor(colormap, dtype=torch.float32, device=device) # [N,3]
34 |
35 | rgbs = rgbs * bkg_opacity # darken, to see the point tracks better
36 |
37 | opacity = 1.0
38 | if rate==1:
39 | radius = 1
40 | opacity = 0.9
41 | elif rate==2:
42 | radius = 1
43 | elif rate== 4:
44 | radius = 2
45 | elif rate== 8:
46 | radius = 4
47 | else:
48 | radius = 6
49 | sharpness = 0.15 + 0.05 * np.log2(rate)
50 |
51 | D = radius * 2 + 1
52 | y = torch.arange(D, device=device).float()[:, None] - radius
53 | x = torch.arange(D, device=device).float()[None, :] - radius
54 | dist2 = x**2 + y**2
55 | icon = torch.clamp(1 - (dist2 - (radius**2) / 2.0) / (radius * 2 * sharpness), 0, 1) # [D,D]
56 | icon = icon.view(1, D, D)
57 | dx = torch.arange(-radius, radius + 1, device=device)
58 | dy = torch.arange(-radius, radius + 1, device=device)
59 | disp_y, disp_x = torch.meshgrid(dy, dx, indexing="ij") # [D,D]
60 | for t in range(T):
61 | mask = visibs[:, t] # [N]
62 | if mask.sum() == 0:
63 | continue
64 | xy = trajs[mask, t] + 0.5 # [N,2]
65 | xy[:, 0] = xy[:, 0].clamp(0, W - 1)
66 | xy[:, 1] = xy[:, 1].clamp(0, H - 1)
67 | colors_now = colors[mask] # [N,3]
68 | N = xy.shape[0]
69 | cx = xy[:, 0].long() # [N]
70 | cy = xy[:, 1].long()
71 | x_grid = cx[:, None, None] + disp_x # [N,D,D]
72 | y_grid = cy[:, None, None] + disp_y # [N,D,D]
73 | valid = (x_grid >= 0) & (x_grid < W) & (y_grid >= 0) & (y_grid < H)
74 | x_valid = x_grid[valid] # [K]
75 | y_valid = y_grid[valid]
76 | icon_weights = icon.expand(N, D, D)[valid] # [K]
77 | colors_valid = colors_now[:, :, None, None].expand(N, 3, D, D).permute(1, 0, 2, 3)[
78 | :, valid
79 | ] # [3, K]
80 | idx_flat = (y_valid * W + x_valid).long() # [K]
81 |
82 | accum = torch.zeros_like(rgbs[t]) # [3, H, W]
83 | weight = torch.zeros(1, H * W, device=device) # [1, H*W]
84 | img_flat = accum.view(C, -1) # [3, H*W]
85 | weighted_colors = colors_valid * icon_weights # [3, K]
86 | img_flat.scatter_add_(1, idx_flat.unsqueeze(0).expand(C, -1), weighted_colors)
87 | weight.scatter_add_(1, idx_flat.unsqueeze(0), icon_weights.unsqueeze(0))
88 | weight = weight.view(1, H, W)
89 |
90 | alpha = weight.clamp(0, 1) * opacity
91 | accum = accum / (weight + 1e-6) # [3, H, W]
92 | rgbs[t] = rgbs[t] * (1 - alpha) + accum * alpha
93 | rgbs = rgbs.clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy() # T,H,W,3
94 | if bkg_opacity==0.0:
95 | for t in range(T):
96 | hsv_frame = cv2.cvtColor(rgbs[t], cv2.COLOR_RGB2HSV)
97 | saturation_factor = 1.5
98 | hsv_frame[..., 1] = np.clip(hsv_frame[..., 1] * saturation_factor, 0, 255)
99 | rgbs[t] = cv2.cvtColor(hsv_frame, cv2.COLOR_HSV2RGB)
100 | return rgbs
101 |
102 | def count_parameters(model):
103 | table = PrettyTable(["Modules", "Parameters"])
104 | total_params = 0
105 | for name, parameter in model.named_parameters():
106 | if not parameter.requires_grad:
107 | continue
108 | param = parameter.numel()
109 | if param > 100000:
110 | table.add_row([name, param])
111 | total_params+=param
112 | print(table)
113 | print('total params: %.2f M' % (total_params/1000000.0))
114 | return total_params
115 |
116 | def forward_video(rgbs, framerate, model, args):
117 |
118 | B,T,C,H,W = rgbs.shape
119 | assert C == 3
120 | device = rgbs.device
121 | assert(B==1)
122 |
123 | grid_xy = utils.basic.gridcloud2d(1, H, W, norm=False, device='cuda:0').float() # 1,H*W,2
124 | grid_xy = grid_xy.permute(0,2,1).reshape(1,1,2,H,W) # 1,1,2,H,W
125 |
126 | torch.cuda.empty_cache()
127 | print('starting forward...')
128 | f_start_time = time.time()
129 |
130 | flows_e, visconf_maps_e, _, _ = \
131 | model.forward_sliding(rgbs[:, args.query_frame:], iters=args.inference_iters, sw=None, is_training=False)
132 | traj_maps_e = flows_e.cuda() + grid_xy # B,Tf,2,H,W
133 | if args.query_frame > 0:
134 | backward_flows_e, backward_visconf_maps_e, _, _ = \
135 | model.forward_sliding(rgbs[:, :args.query_frame+1].flip([1]), iters=args.inference_iters, sw=None, is_training=False)
136 | backward_traj_maps_e = backward_flows_e.cuda() + grid_xy # B,Tb,2,H,W, reversed
137 | backward_traj_maps_e = backward_traj_maps_e.flip([1])[:, :-1] # flip time and drop the overlapped frame
138 | backward_visconf_maps_e = backward_visconf_maps_e.flip([1])[:, :-1] # flip time and drop the overlapped frame
139 | traj_maps_e = torch.cat([backward_traj_maps_e, traj_maps_e], dim=1) # B,T,2,H,W
140 | visconf_maps_e = torch.cat([backward_visconf_maps_e, visconf_maps_e], dim=1) # B,T,2,H,W
141 | ftime = time.time()-f_start_time
142 | print('finished forward; %.2f seconds / %d frames; %d fps' % (ftime, T, round(T/ftime)))
143 | utils.basic.print_stats('traj_maps_e', traj_maps_e)
144 | utils.basic.print_stats('visconf_maps_e', visconf_maps_e)
145 |
146 | # subsample to make the vis more readable
147 | rate = args.rate
148 | trajs_e = traj_maps_e[:,:,:,::rate,::rate].reshape(B,T,2,-1).permute(0,1,3,2) # B,T,N,2
149 | visconfs_e = visconf_maps_e[:,:,:,::rate,::rate].reshape(B,T,2,-1).permute(0,1,3,2) # B,T,N,2
150 |
151 | xy0 = trajs_e[0,0].cpu().numpy()
152 | colors = utils.improc.get_2d_colors(xy0, H, W)
153 |
154 | fn = args.mp4_path.split('/')[-1].split('.')[0]
155 | rgb_out_f = './pt_vis_%s_rate%d_q%d.mp4' % (fn, rate, args.query_frame)
156 | print('rgb_out_f', rgb_out_f)
157 | temp_dir = 'temp_pt_vis_%s_rate%d_q%d' % (fn, rate, args.query_frame)
158 | utils.basic.mkdir(temp_dir)
159 | vis = []
160 |
161 | frames = draw_pts_gpu(rgbs[0].to('cuda:0'), trajs_e[0], visconfs_e[0,:,:,1] > args.conf_thr,
162 | colors, rate=rate, bkg_opacity=args.bkg_opacity)
163 | print('frames', frames.shape)
164 |
165 | if args.vstack:
166 | frames_top = rgbs[0].clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy() # T,H,W,3
167 | frames = np.concatenate([frames_top, frames], axis=1)
168 | elif args.hstack:
169 | frames_left = rgbs[0].clamp(0, 255).byte().permute(0, 2, 3, 1).cpu().numpy() # T,H,W,3
170 | frames = np.concatenate([frames_left, frames], axis=2)
171 |
172 | print('writing frames to disk')
173 | f_start_time = time.time()
174 | for ti in range(T):
175 | temp_out_f = '%s/%03d.jpg' % (temp_dir, ti)
176 | im = PIL.Image.fromarray(frames[ti])
177 | im.save(temp_out_f)#, "PNG", subsampling=0, quality=80)
178 | ftime = time.time()-f_start_time
179 | print('finished writing; %.2f seconds / %d frames; %d fps' % (ftime, T, round(T/ftime)))
180 |
181 | print('writing mp4')
182 | os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate %d -pattern_type glob -i "./%s/*.jpg" -c:v libx264 -crf 20 -pix_fmt yuv420p %s' % (framerate, temp_dir, rgb_out_f))
183 |
184 | # # flow vis
185 | # rgb_out_f = './flow_vis.mp4'
186 | # temp_dir = 'temp_flow_vis'
187 | # utils.basic.mkdir(temp_dir)
188 | # vis = []
189 | # for ti in range(T):
190 | # flow_vis = utils.improc.flow2color(flows_e[0:1,ti])
191 | # vis.append(flow_vis)
192 | # for ti in range(T):
193 | # temp_out_f = '%s/%03d.png' % (temp_dir, ti)
194 | # im = PIL.Image.fromarray(vis[ti][0].permute(1,2,0).cpu().numpy())
195 | # im.save(temp_out_f, "PNG", subsampling=0, quality=100)
196 | # os.system('/usr/bin/ffmpeg -y -hide_banner -loglevel error -f image2 -framerate 24 -pattern_type glob -i "./%s/*.png" -c:v libx264 -crf 1 -pix_fmt yuv420p %s' % (temp_dir, rgb_out_f))
197 |
198 | return None
199 |
200 | def run(model, args):
201 | log_dir = './logs_demo'
202 |
203 | global_step = 0
204 |
205 | if args.ckpt_init:
206 | _ = utils.saveload.load(
207 | None,
208 | args.ckpt_init,
209 | model,
210 | optimizer=None,
211 | scheduler=None,
212 | ignore_load=None,
213 | strict=True,
214 | verbose=False,
215 | weights_only=False,
216 | )
217 | print('loaded weights from', args.ckpt_init)
218 | else:
219 | if args.tiny:
220 | url = "https://huggingface.co/aharley/alltracker/resolve/main/alltracker_tiny.pth"
221 | else:
222 | url = "https://huggingface.co/aharley/alltracker/resolve/main/alltracker.pth"
223 | state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
224 | model.load_state_dict(state_dict['model'], strict=True)
225 | print('loaded weights from', url)
226 |
227 | model.cuda()
228 | for n, p in model.named_parameters():
229 | p.requires_grad = False
230 | model.eval()
231 |
232 | rgbs, framerate = read_mp4(args.mp4_path)
233 | print('rgbs[0]', rgbs[0].shape)
234 | H,W = rgbs[0].shape[:2]
235 |
236 | # shorten & shrink the video, in case the gpu is small
237 | if args.max_frames:
238 | rgbs = rgbs[:args.max_frames]
239 | scale = min(int(args.image_size)/H, int(args.image_size)/W)
240 | H, W = int(H*scale), int(W*scale)
241 | H, W = H//8 * 8, W//8 * 8 # make it divisible by 8
242 | rgbs = [cv2.resize(rgb, dsize=(W, H), interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
243 | print('rgbs[0]', rgbs[0].shape)
244 |
245 | # move to gpu
246 | rgbs = [torch.from_numpy(rgb).permute(2,0,1) for rgb in rgbs]
247 | rgbs = torch.stack(rgbs, dim=0).unsqueeze(0).float() # 1,T,C,H,W
248 | print('rgbs', rgbs.shape)
249 |
250 | with torch.no_grad():
251 | metrics = forward_video(rgbs, framerate, model, args)
252 |
253 | return None
254 |
255 | if __name__ == "__main__":
256 | torch.set_grad_enabled(False)
257 |
258 | parser = argparse.ArgumentParser()
259 | parser.add_argument("--ckpt_init", type=str, default='') # the ckpt we want (else default)
260 | parser.add_argument("--mp4_path", type=str, default='./demo_video/monkey.mp4') # input video
261 | parser.add_argument("--query_frame", type=int, default=0) # which frame to track from
262 | parser.add_argument("--image_size", type=int, default=1024) # max dimension of a video frame (upsample to this)
263 | parser.add_argument("--max_frames", type=int, default=400) # trim the video to this length
264 | parser.add_argument("--inference_iters", type=int, default=4) # number of inference steps per forward
265 | parser.add_argument("--window_len", type=int, default=16) # model hyperparam
266 | parser.add_argument("--rate", type=int, default=2) # vis hyp
267 | parser.add_argument("--conf_thr", type=float, default=0.1) # vis hyp
268 | parser.add_argument("--bkg_opacity", type=float, default=0.5) # vis hyp
269 | parser.add_argument("--vstack", action='store_true', default=False) # whether to stack the input and output in the mp4
270 | parser.add_argument("--hstack", action='store_true', default=False) # whether to stack the input and output in the mp4
271 | parser.add_argument("--tiny", action='store_true', default=False) # whether to use the tiny model
272 | args = parser.parse_args()
273 |
274 | from nets.alltracker import Net
275 | if args.tiny:
276 | model = Net(args.window_len, use_basicencoder=True, no_split=True)
277 | else:
278 | model = Net(args.window_len)
279 | count_parameters(model)
280 |
281 | run(model, args)
282 |
283 |
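# Example invocation (illustrative; flags correspond to the argparse setup above):
#   python demo.py --mp4_path ./demo_video/monkey.mp4 --query_frame 0 --rate 2 --vstack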
--------------------------------------------------------------------------------
/datasets/exportdataset.py:
--------------------------------------------------------------------------------
1 | from numpy import random
2 | import torch
3 | import numpy as np
4 | import os
5 | import random
6 | import imageio
7 | from pathlib import Path
8 | import matplotlib.pyplot as plt
9 | from utils.basic import print_stats
10 | from PIL import Image
11 | import cv2
12 | import utils.py
13 | import torch.nn.functional as F
14 | import torchvision.transforms.functional as tvF
15 | from torchvision.transforms import ColorJitter, GaussianBlur
16 | from datasets.pointdataset import PointDataset
17 |
18 | class ExportDataset(PointDataset):
19 | def __init__(self,
20 | data_root='../datasets/alltrack_export',
21 | version='bs',
22 | dsets=None,
23 | dsets_exclude=None,
24 | seq_len=64,
25 | crop_size=(384,512),
26 | shuffle_frames=False,
27 | shuffle=True,
28 | use_augs=False,
29 | is_training=True,
30 | backwards=False,
31 | traj_per_sample=256, # min number of trajs
32 | traj_max_factor=24, # multiplier on traj_per_sample
33 | random_seq_len=False,
34 | random_frame_rate=False,
35 | random_number_traj=False,
36 | only_first=False,
37 | ):
38 | super(ExportDataset, self).__init__(
39 | data_root=data_root,
40 | crop_size=crop_size,
41 | seq_len=seq_len,
42 | traj_per_sample=traj_per_sample,
43 | use_augs=use_augs,
44 | )
45 | print('loading export...')
46 |
47 | self.shuffle_frames = shuffle_frames
48 | self.pad_bounds = [10, 100]
49 | self.resize_lim = [0.25, 2.0] # sample resizes from here
50 | self.resize_delta = 0.2
51 | self.max_crop_offset = 50
52 | self.only_first = only_first
53 | self.traj_max_factor = traj_max_factor
54 |
55 | self.S = seq_len
56 |
57 | self.use_augs = use_augs
58 | self.is_training = is_training
59 |
60 | self.dataset_location = Path(data_root) / version
61 |
62 | dataset_names = self.dataset_location.glob('*/')
63 | self.dataset_names = [str(fn.stem) for fn in dataset_names]
64 | self.dataset_names = ['%s%d' % (dname, self.S) for dname in self.dataset_names]
65 | print('dataset_names', self.dataset_names)
66 |
67 | folder_names = self.dataset_location.glob('*/*/*/')
68 | folder_names = [str(fn) for fn in folder_names]
69 | # print('folder_names', folder_names)
70 |
71 | print('found {:d} {} folders in {}'.format(len(folder_names), version, self.dataset_location))
72 |
73 | if dsets is not None:
74 | print('dsets', dsets)
75 | new_folder_names = []
76 | for fn in folder_names:
77 | for dset in dsets:
78 | if dset in fn:
79 | new_folder_names.append(fn)
80 | break
81 | folder_names = new_folder_names
82 | print('filtered to %d folders' % len(folder_names))
83 |
84 | if backwards:
85 | new_folder_names = []
86 | for fn in folder_names:
87 | chunk = fn.split('/')[-2]
88 | if 'b' in chunk:
89 | new_folder_names.append(fn)
90 | folder_names = new_folder_names
91 | print('filtered to %d folders with backward motion' % len(folder_names))
92 |
93 | if dsets_exclude is not None:
94 | print('dsets_exclude', dsets_exclude)
95 | new_folder_names = []
96 | for fn in folder_names:
97 | keep = True
98 | for dset in dsets_exclude:
99 | if dset in fn:
100 | keep = False
101 | break
102 | if keep:
103 | new_folder_names.append(fn)
104 | folder_names = new_folder_names
105 | print('filtered to %d folders' % len(folder_names))
106 |
107 | # if quick:
108 | # folder_names = sorted(folder_names)
109 | # folder_names = folder_names[:201]
110 | # print('folder_names', folder_names)
111 |
112 | if shuffle:
113 | random.shuffle(folder_names)
114 | else:
115 | folder_names = sorted(list(folder_names))
116 |
117 | self.all_folders = folder_names
118 | # # step through once and make sure all of the npys are there
119 | # print('stepping through...')
120 | # self.all_folders = []
121 | # for fi, fol in enumerate(folder_names):
122 | # npy_path = os.path.join(fol, "annot.npy")
123 | # rgb_path = os.path.join(fol, "frames")
124 | # if os.path.isdir(rgb_path) and os.path.isfile(npy_path):
125 | # img_paths = sorted(os.listdir(rgb_path))
126 | # if len(img_paths)>=self.S:
127 | # self.all_folders.append(fol)
128 | # else:
129 | # pass
130 | # else:
131 | # pass
132 | # print('ok done stepping; got %d' % len(self.all_folders))
133 |
134 |
135 | def getitem_helper(self, index):
136 | # cH, cW = self.cH, self.cW
137 |
138 | fol = self.all_folders[index]
139 | npy_path = os.path.join(fol, "annot.npy")
140 | rgb_path = os.path.join(fol, "frames")
141 |
142 | mid = str(fol)[len(str(self.dataset_location))+1:]
143 | dname = mid.split('/')[0]
144 | # print('dname', dname)
145 |
146 | img_paths = sorted(os.listdir(rgb_path))
147 | rgbs = []
148 | try:
149 | for i, img_path in enumerate(img_paths):
150 | rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
151 | except:
152 | print('some exception when reading rgbs')
153 |
154 |         if len(rgbs) < self.S:
188 |         if trajs.shape[0] > self.S:
189 | surplus = trajs.shape[0] - self.S
190 | ind = np.random.randint(surplus)+1
191 | rgbs = rgbs[ind:ind+self.S]
192 | trajs = trajs[ind:ind+self.S]
193 | visibs = visibs[ind:ind+self.S]
194 | assert(trajs.shape[0] == self.S)
195 |
196 | if self.only_first:
197 | vis_ok = np.nonzero(visibs[0]==1)[0]
198 | trajs = trajs[:,vis_ok]
199 | visibs = visibs[:,vis_ok]
200 |
201 | N = trajs.shape[1]
202 | if N < self.traj_per_sample:
203 | print('exp: %s; N after vis0: %d' % (dname, N))
204 | return None, False
205 |
206 | # print('rgbs', rgbs.shape)
207 | if H > self.crop_size[0]*2 and W > self.crop_size[1]*2 and np.random.rand() < 0.5:
208 | scale = 0.5
209 | rgbs = [cv2.resize(rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
210 | H, W = rgbs[0].shape[:2]
211 | rgbs = np.stack(rgbs, axis=0) # S,H,W,3
212 | trajs = trajs * scale
213 | # print('resized rgbs', rgbs.shape)
214 |
215 | if H > self.crop_size[0]*2 and W > self.crop_size[1]*2 and np.random.rand() < 0.5:
216 | scale = 0.5
217 | rgbs = [cv2.resize(rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
218 | H, W = rgbs[0].shape[:2]
219 | rgbs = np.stack(rgbs, axis=0) # S,H,W,3
220 | trajs = trajs * scale
221 | # print('resized rgbs', rgbs.shape)
222 |
223 | if self.use_augs and np.random.rand() < 0.98:
224 | rgbs, trajs, visibs = self.add_photometric_augs(
225 | rgbs, trajs, visibs, replace=False,
226 | )
227 | if np.random.rand() < 0.2:
228 | rgbs, trajs = self.add_spatial_augs(rgbs, trajs, visibs, self.crop_size)
229 | else:
230 | rgbs, trajs = self.follow_crop(rgbs, trajs, visibs, self.crop_size)
231 | if np.random.rand() < self.rot_prob:
232 | # note this is OK since B==1
233 | # otw we would do it before this func
234 | rgbs = [np.transpose(rgb, (1,0,2)).copy() for rgb in rgbs]
235 | rgbs = np.stack(rgbs)
236 | trajs = np.flip(trajs, axis=2).copy()
237 | H, W = rgbs[0].shape[:2]
238 | if np.random.rand() < self.h_flip_prob:
239 | rgbs = [rgb[:, ::-1].copy() for rgb in rgbs]
240 | trajs[:, :, 0] = W - trajs[:, :, 0]
241 | rgbs = np.stack(rgbs)
242 | if np.random.rand() < self.v_flip_prob:
243 | rgbs = [rgb[::-1].copy() for rgb in rgbs]
244 | trajs[:, :, 1] = H - trajs[:, :, 1]
245 | rgbs = np.stack(rgbs)
246 | else:
247 | rgbs, trajs = self.crop(rgbs, trajs, self.crop_size)
248 |
249 | if self.shuffle_frames and np.random.rand() < 0.01:
250 | # shuffle the frames (again)
251 | perm = np.random.permutation(rgbs.shape[0])
252 | rgbs = rgbs[perm]
253 | trajs = trajs[perm]
254 | visibs = visibs[perm]
255 |
256 | H,W = rgbs[0].shape[:2]
257 |
258 | visibs[trajs[:, :, 0] > W-1] = False
259 | visibs[trajs[:, :, 0] < 0] = False
260 | visibs[trajs[:, :, 1] > H-1] = False
261 | visibs[trajs[:, :, 1] < 0] = False
262 |
263 | N = trajs.shape[1]
264 | # print('N8', N)
265 |
266 | # ensure no crazy values
267 | all_valid = np.nonzero(np.sum(np.sum(np.abs(trajs), axis=-1)<100000, axis=0)==self.S)[0]
268 | trajs = trajs[:,all_valid]
269 | visibs = visibs[:,all_valid]
270 |
271 | if self.only_first:
272 | vis_ok = np.nonzero(visibs[0]==1)[0]
273 | trajs = trajs[:,vis_ok]
274 | visibs = visibs[:,vis_ok]
275 | N = trajs.shape[1]
276 |
277 | if N < self.traj_per_sample:
278 | print('exp: %s; N after aug: %d' % (dname, N))
279 | return None, False
280 |
281 | N = trajs.shape[1]
282 |
283 | seq_len = S
284 | visibs = torch.from_numpy(visibs)
285 | trajs = torch.from_numpy(trajs)
286 |
287 | # discard tracks that go far OOB
288 | crop_tensor = torch.tensor(self.crop_size).flip(0)[None, None] / 2.0
289 | close_pts_inds = torch.all(
290 | torch.linalg.vector_norm(trajs[..., :2] - crop_tensor, dim=-1) < max(H,W)*2,
291 | dim=0,
292 | )
293 | trajs = trajs[:, close_pts_inds]
294 | visibs = visibs[:, close_pts_inds]
295 |
296 | visible_pts_inds = (visibs[0]).nonzero(as_tuple=False)[:, 0]
297 | point_inds = torch.randperm(len(visible_pts_inds))[:self.traj_per_sample*self.traj_max_factor]
298 | if len(point_inds) < self.traj_per_sample:
299 | # print('not enough trajs')
300 | # gotit = False
301 | return None, False
302 |
303 | visible_inds_sampled = visible_pts_inds[point_inds]
304 |
305 | trajs = trajs[:, visible_inds_sampled].float()
306 | visibs = visibs[:, visible_inds_sampled].float()
307 | valids = torch.ones_like(visibs).float()
308 |
309 | trajs = trajs[:, :self.traj_per_sample*self.traj_max_factor]
310 | visibs = visibs[:, :self.traj_per_sample*self.traj_max_factor]
311 | valids = valids[:, :self.traj_per_sample*self.traj_max_factor]
312 |
313 | rgbs = torch.from_numpy(rgbs).permute(0, 3, 1, 2).float()
314 |
315 | dname += '%d' % (self.S)
316 |
317 | sample = utils.data.VideoData(
318 | video=rgbs,
319 | trajs=trajs,
320 | visibs=visibs,
321 | valids=valids,
322 | seq_name=None,
323 | dname=dname,
324 | )
325 | return sample, True
326 |
327 | def __len__(self):
328 | return len(self.all_folders)
329 |
--------------------------------------------------------------------------------
/datasets/pointdataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import cv2
4 | import imageio
5 | import numpy as np
6 | from torchvision.transforms import ColorJitter, GaussianBlur
7 | from PIL import Image
8 | import utils.data
9 |
10 | class PointDataset(torch.utils.data.Dataset):
11 | def __init__(
12 | self,
13 | data_root,
14 | crop_size=(384, 512),
15 | seq_len=24,
16 | traj_per_sample=768,
17 | use_augs=False,
18 | ):
19 | super(PointDataset, self).__init__()
20 | self.data_root = data_root
21 | self.seq_len = seq_len
22 | self.traj_per_sample = traj_per_sample
23 | self.use_augs = use_augs
24 | # photometric augmentation
25 | self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
26 | self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
27 | self.blur_aug_prob = 0.25
28 | self.color_aug_prob = 0.25
29 |
30 | # occlusion augmentation
31 | self.eraser_aug_prob = 0.5
32 | self.eraser_bounds = [2, 100]
33 | self.eraser_max = 10
34 |
35 | # occlusion augmentation
36 | self.replace_aug_prob = 0.5
37 | self.replace_bounds = [2, 100]
38 | self.replace_max = 10
39 |
40 | # spatial augmentations
41 | self.pad_bounds = [10, 100]
42 | self.crop_size = crop_size
43 | self.resize_lim = [0.25, 2.0] # sample resizes from here
44 | self.resize_delta = 0.2
45 | self.max_crop_offset = 50
46 |
47 | self.h_flip_prob = 0.5
48 | self.v_flip_prob = 0.5
49 | self.rot_prob = 0.5
50 |
51 | def getitem_helper(self, index):
52 |         raise NotImplementedError
53 |
54 | def __getitem__(self, index):
55 | gotit = False
56 | fails = 0
57 | while not gotit and fails < 4:
58 | sample, gotit = self.getitem_helper(index)
59 | if gotit:
60 | return sample, gotit
61 | else:
62 | fails += 1
63 | index = np.random.randint(len(self))
64 | del sample
65 | if fails > 1:
66 | print('note: sampling failed %d times' % fails)
67 |
68 | if self.seq_len is not None:
69 | S = self.seq_len
70 | else:
71 | S = 11
72 | # fake sample, so we can still collate
73 | sample = utils.data.VideoData(
74 | video=torch.zeros((S, 3, self.crop_size[0], self.crop_size[1])),
75 | trajs=torch.zeros((S, self.traj_per_sample, 2)),
76 | visibs=torch.zeros((S, self.traj_per_sample)),
77 | valids=torch.zeros((S, self.traj_per_sample)),
78 | )
79 | return sample, gotit
80 |
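    # Note on the zero-filled fallback above (illustrative, not from the original
    # file): returning a dummy VideoData with gotit=False keeps the default
    # collate_fn happy; downstream code can then drop failed samples by checking
    # the flag, e.g.
    #   sample, gotit = dataset[i]
    #   if not gotit:
    #       continue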
81 | def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True, augscale=1.0):
82 | T, N, _ = trajs.shape
83 |
84 | S = len(rgbs)
85 | H, W = rgbs[0].shape[:2]
86 | assert S == T
87 |
88 | if eraser:
89 | ############ eraser transform (per image after the first) ############
90 | eraser_bounds = [eb*augscale for eb in self.eraser_bounds]
91 | rgbs = [rgb.astype(np.float32) for rgb in rgbs]
92 | for i in range(1, S):
93 | if np.random.rand() < self.eraser_aug_prob:
94 | for _ in range(np.random.randint(1, self.eraser_max + 1)):
95 | # number of times to occlude
96 | xc = np.random.randint(0, W)
97 | yc = np.random.randint(0, H)
98 | dx = np.random.randint(eraser_bounds[0], eraser_bounds[1])
99 | dy = np.random.randint(eraser_bounds[0], eraser_bounds[1])
100 | x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
101 | x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
102 | y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
103 | y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
104 |
105 | mean_color = np.mean(
106 | rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0
107 | )
108 | rgbs[i][y0:y1, x0:x1, :] = mean_color
109 |
110 | occ_inds = np.logical_and(
111 | np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
112 | np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
113 | )
114 | visibles[i, occ_inds] = 0
115 | rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
116 |
117 | if replace:
118 | rgbs_alt = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
119 | rgbs_alt = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt]
120 |
121 | ############ replace transform (per image after the first) ############
122 | rgbs = [rgb.astype(np.float32) for rgb in rgbs]
123 | rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
124 | replace_bounds = [rb*augscale for rb in self.replace_bounds]
125 | for i in range(1, S):
126 | if np.random.rand() < self.replace_aug_prob:
127 | for _ in range(
128 | np.random.randint(1, self.replace_max + 1)
129 | ): # number of times to occlude
130 | xc = np.random.randint(0, W)
131 | yc = np.random.randint(0, H)
132 | dx = np.random.randint(replace_bounds[0], replace_bounds[1])
133 | dy = np.random.randint(replace_bounds[0], replace_bounds[1])
134 | x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
135 | x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
136 | y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
137 | y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
138 |
139 | wid = x1 - x0
140 | hei = y1 - y0
141 | y00 = np.random.randint(0, H - hei)
142 | x00 = np.random.randint(0, W - wid)
143 | fr = np.random.randint(0, S)
144 | rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
145 | rgbs[i][y0:y1, x0:x1, :] = rep
146 |
147 | occ_inds = np.logical_and(
148 | np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
149 | np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
150 | )
151 | visibles[i, occ_inds] = 0
152 | rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
153 |
154 | ############ photometric augmentation ############
155 | if np.random.rand() < self.color_aug_prob:
156 | # random per-frame amount of aug
157 | rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
158 |
159 | if np.random.rand() < self.blur_aug_prob:
160 | # random per-frame amount of blur
161 | rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
162 |
163 | return rgbs, trajs, visibles
164 |
165 |
166 | def add_spatial_augs(self, rgbs, trajs, visibles, crop_size, augscale=1.0):
167 | T, N, __ = trajs.shape
168 |
169 | S = len(rgbs)
170 | H, W = rgbs[0].shape[:2]
171 | assert S == T
172 |
173 | rgbs = [rgb.astype(np.float32) for rgb in rgbs]
174 |
175 | trajs = trajs.astype(np.float64)
176 |
177 | target_H, target_W = crop_size
178 | if target_H > H or target_W > W:
179 | scale = max(target_H / H, target_W / W)
180 | new_H, new_W = int(np.ceil(H * scale)), int(np.ceil(W * scale))
181 | rgbs = [cv2.resize(rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
182 | trajs = trajs * scale
183 |
184 | ############ spatial transform ############
185 |
186 | # padding
187 | pad_bounds = [int(pb*augscale) for pb in self.pad_bounds]
188 | pad_x0 = np.random.randint(pad_bounds[0], pad_bounds[1])
189 | pad_x1 = np.random.randint(pad_bounds[0], pad_bounds[1])
190 | pad_y0 = np.random.randint(pad_bounds[0], pad_bounds[1])
191 | pad_y1 = np.random.randint(pad_bounds[0], pad_bounds[1])
192 |
193 | rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
194 | trajs[:, :, 0] += pad_x0
195 | trajs[:, :, 1] += pad_y0
196 | H, W = rgbs[0].shape[:2]
197 |
198 | # scaling + stretching
199 | scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
200 | scale_x = scale
201 | scale_y = scale
202 | H_new = H
203 | W_new = W
204 |
205 | scale_delta_x = 0.0
206 | scale_delta_y = 0.0
207 |
208 | rgbs_scaled = []
209 | resize_delta = self.resize_delta * augscale
210 | for s in range(S):
211 | if s == 1:
212 | scale_delta_x = np.random.uniform(-resize_delta, resize_delta)
213 | scale_delta_y = np.random.uniform(-resize_delta, resize_delta)
214 | elif s > 1:
215 | scale_delta_x = (
216 | scale_delta_x * 0.8
217 | + np.random.uniform(-resize_delta, resize_delta) * 0.2
218 | )
219 | scale_delta_y = (
220 | scale_delta_y * 0.8
221 | + np.random.uniform(-resize_delta, resize_delta) * 0.2
222 | )
223 | scale_x = scale_x + scale_delta_x
224 | scale_y = scale_y + scale_delta_y
225 |
226 | # bring h/w closer
227 | scale_xy = (scale_x + scale_y) * 0.5
228 | scale_x = scale_x * 0.5 + scale_xy * 0.5
229 | scale_y = scale_y * 0.5 + scale_xy * 0.5
230 |
231 | # don't get too crazy
232 | scale_x = np.clip(scale_x, 0.2, 2.0)
233 | scale_y = np.clip(scale_y, 0.2, 2.0)
234 |
235 | H_new = int(H * scale_y)
236 | W_new = int(W * scale_x)
237 |
238 | # make it at least slightly bigger than the crop area,
239 | # so that the random cropping can add diversity
240 | H_new = np.clip(H_new, crop_size[0] + 10, None)
241 | W_new = np.clip(W_new, crop_size[1] + 10, None)
242 | # recompute scale in case we clipped
243 | scale_x = (W_new - 1) / float(W - 1)
244 | scale_y = (H_new - 1) / float(H - 1)
245 | rgbs_scaled.append(
246 | cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR)
247 | )
248 | trajs[s, :, 0] *= scale_x
249 | trajs[s, :, 1] *= scale_y
250 | rgbs = rgbs_scaled
251 | ok_inds = visibles[0, :] > 0
252 | vis_trajs = trajs[:, ok_inds] # S,?,2
253 |
254 | if vis_trajs.shape[1] > 0:
255 | mid_x = np.mean(vis_trajs[0, :, 0])
256 | mid_y = np.mean(vis_trajs[0, :, 1])
257 | else:
258 | mid_y = crop_size[0]
259 | mid_x = crop_size[1]
260 |
261 | x0 = int(mid_x - crop_size[1] // 2)
262 | y0 = int(mid_y - crop_size[0] // 2)
263 |
264 | offset_x = 0
265 | offset_y = 0
266 | max_crop_offset = int(self.max_crop_offset*augscale)
267 |
268 | for s in range(S):
269 | # on each frame, shift a bit more
270 | if s == 1:
271 | offset_x = np.random.randint(-max_crop_offset, max_crop_offset)
272 | offset_y = np.random.randint(-max_crop_offset, max_crop_offset)
273 | elif s > 1:
274 | offset_x = int(
275 | offset_x * 0.8
276 | + np.random.randint(-max_crop_offset, max_crop_offset + 1)
277 | * 0.2
278 | )
279 | offset_y = int(
280 | offset_y * 0.8
281 | + np.random.randint(-max_crop_offset, max_crop_offset + 1)
282 | * 0.2
283 | )
284 | x0 = x0 + offset_x
285 | y0 = y0 + offset_y
286 |
287 | H_new, W_new = rgbs[s].shape[:2]
288 | if H_new == crop_size[0]:
289 | y0 = 0
290 | else:
291 | y0 = min(max(0, y0), H_new - crop_size[0] - 1)
292 |
293 | if W_new == crop_size[1]:
294 | x0 = 0
295 | else:
296 | x0 = min(max(0, x0), W_new - crop_size[1] - 1)
297 |
298 | rgbs[s] = rgbs[s][y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]]
299 | trajs[s, :, 0] -= x0
300 | trajs[s, :, 1] -= y0
301 |
302 | H = crop_size[0]
303 | W = crop_size[1]
304 |
305 | if np.random.rand() < self.h_flip_prob:
306 | rgbs = [rgb[:, ::-1].copy() for rgb in rgbs]
307 | trajs[:, :, 0] = W-1 - trajs[:, :, 0]
308 | if np.random.rand() < self.v_flip_prob:
309 | rgbs = [rgb[::-1].copy() for rgb in rgbs]
310 | trajs[:, :, 1] = H-1 - trajs[:, :, 1]
311 | return np.stack(rgbs), trajs.astype(np.float32)
312 |
313 | def crop(self, rgbs, trajs, crop_size):
314 | T, N, _ = trajs.shape
315 |
316 | S = len(rgbs)
317 | H, W = rgbs[0].shape[:2]
318 | assert S == T
319 |
320 | target_H, target_W = crop_size
321 | if target_H > H or target_W > W:
322 | scale = max(target_H / H, target_W / W)
323 | new_H, new_W = int(np.ceil(H * scale)), int(np.ceil(W * scale))
324 | rgbs = [cv2.resize(rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
325 | trajs = trajs * scale
326 | H, W = rgbs[0].shape[:2]
327 |
328 | # simple random crop
329 | y0 = 0 if crop_size[0] >= H else (H - crop_size[0]) // 2
330 | x0 = 0 if crop_size[1] >= W else np.random.randint(0, W - crop_size[1])
331 | rgbs = [rgb[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]] for rgb in rgbs]
332 |
333 | trajs[:, :, 0] -= x0
334 | trajs[:, :, 1] -= y0
335 |
336 | return np.stack(rgbs), trajs
337 |
338 | def follow_crop(self, rgbs, trajs, visibs, crop_size):
339 | T, N, _ = trajs.shape
340 |
341 | rgbs = [rgb for rgb in rgbs] # unstack so we can change them one by one
342 |
343 | S = len(rgbs)
344 | H, W = rgbs[0].shape[:2]
345 | assert S == T
346 |
347 | vels = trajs[1:]-trajs[:-1]
348 | accels = vels[1:]-vels[:-1]
349 | vis_ = visibs[1:]*visibs[:-1]
350 | vis__ = vis_[1:]*vis_[:-1]
351 | travel = np.sum(np.sum(np.abs(accels)*vis__[:,:,None], axis=2), axis=0)
352 | num_interesting = np.sum(travel > 0).round()
353 | inds = np.argsort(-travel)[:max(num_interesting//32,32)]
354 |
355 | trajs_interesting = trajs[:,inds] # S,?,2
356 |
357 | # pick a random one to focus on, for variety
358 | smooth_xys = trajs_interesting[:,np.random.randint(len(inds))]
359 |
360 | crop_H, crop_W = crop_size
361 |
362 | smooth_xys = np.clip(smooth_xys, [crop_W // 2, crop_H // 2], [W - crop_W // 2, H - crop_H // 2])
363 |
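# smooth_path below repeatedly filters the followed point's path with an edge-padded
# [0.25, 0.5, 0.25] kernel; more passes give a smoother, more linear crop trajectory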
364 | def smooth_path(xys, num_passes):
365 | kernel = np.array([0.25, 0.5, 0.25])
366 | for _ in range(num_passes):
367 | padded = np.pad(xys, ((1, 1), (0, 0)), mode='edge')
368 | xys = (
369 | kernel[0] * padded[:-2] +
370 | kernel[1] * padded[1:-1] +
371 | kernel[2] * padded[2:]
372 | )
373 | return xys
374 | num_passes = np.random.randint(4, S) # 1 is perfect follow; S is near linear
375 | smooth_xys = smooth_path(smooth_xys, num_passes)
376 |
377 | for si in range(S):
378 | x0, y0 = smooth_xys[si].round().astype(np.int32)
379 | x0 -= crop_W//2
380 | y0 -= crop_H//2
381 | rgbs[si] = rgbs[si][y0:y0+crop_H, x0:x0+crop_W]
382 | trajs[si,:,0] -= x0
383 | trajs[si,:,1] -= y0
384 |
385 | return np.stack(rgbs), trajs
386 |
387 |
--------------------------------------------------------------------------------
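Each dataset's __getitem__ above returns a (sample, gotit) pair, substituting a correctly shaped dummy sample when loading fails so that batching still works; the caller then drops batches whose gotit flags are not all true. Below is a minimal sketch of such a collate function (the repo's actual one is utils.data.collate_fn_train, which may differ); it assumes every sample in the batch carries the same number of tracks.

import torch
import utils.data

def collate_sketch(batch):
    # batch is a list of (VideoData, gotit) pairs produced by PointDataset.__getitem__
    samples, gotits = zip(*batch)
    out = utils.data.VideoData(
        video=torch.stack([s.video for s in samples], dim=0),    # B,S,3,H,W
        trajs=torch.stack([s.trajs for s in samples], dim=0),    # B,S,N,2
        visibs=torch.stack([s.visibs for s in samples], dim=0),  # B,S,N
        valids=torch.stack([s.valids for s in samples], dim=0),  # B,S,N
    )
    return out, torch.tensor(gotits, dtype=torch.bool)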
/datasets/dynrep_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gzip
3 | import torch
4 | import numpy as np
5 | import torch.utils.data as data
6 | from collections import defaultdict
7 | from typing import List, Optional, Any, Dict, Tuple, IO, TypeVar, Type, get_args, get_origin, Union
8 | from datasets.pointdataset import PointDataset
9 | import json
10 | import dataclasses
11 | from dataclasses import dataclass, Field, MISSING
12 | import utils.data
13 | import cv2
14 | import random
15 |
16 | _X = TypeVar("_X")
17 |
18 |
19 | def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
20 | """
21 | Loads to a @dataclass or collection hierarchy including dataclasses
22 | from a json recursively.
23 | Call it like load_dataclass(f, typing.List[DynamicReplicaFrameAnnotation]).
24 | Raises KeyError if the json has keys that do not map to the dataclass fields.
25 |
26 | Args:
27 | f: Either a path to a file, or a file handle opened for reading.
28 | cls: The class of the loaded dataclass.
29 | binary: Set to True if `f` is opened in binary mode, else False.
30 | """
31 | if binary:
32 | asdict = json.loads(f.read().decode("utf8"))
33 | else:
34 | asdict = json.load(f)
35 |
36 | # in the list case, run a faster "vectorized" version
37 | cls = get_args(cls)[0]
38 | res = list(_dataclass_list_from_dict_list(asdict, cls))
39 |
40 | return res
41 |
42 |
43 | def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
44 | """Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
45 | if get_origin(type_) is Union:
46 | args = get_args(type_)
47 | if len(args) == 2 and args[1] == type(None): # noqa E721
48 | return True, args[0]
49 | if type_ is Any:
50 | return True, Any
51 |
52 | return False, type_
53 |
54 |
55 | def _unwrap_type(tp):
56 | # strips Optional wrapper, if any
57 | if get_origin(tp) is Union:
58 | args = get_args(tp)
59 | if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
60 | # this is typing.Optional
61 | return args[0] if args[1] is type(None) else args[1] # noqa: E721
62 | return tp
63 |
64 |
65 | def _get_dataclass_field_default(field: Field) -> Any:
66 | if field.default_factory is not MISSING:
67 | # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
68 | # dataclasses._DefaultFactory[typing.Any]]` is not a function.
69 | return field.default_factory()
70 | elif field.default is not MISSING:
71 | return field.default
72 | else:
73 | return None
74 |
75 |
76 | def _dataclass_list_from_dict_list(dlist, typeannot):
77 | """
78 | Vectorised version of `_dataclass_from_dict`.
79 | The output should be equivalent to
80 | `[_dataclass_from_dict(d, typeannot) for d in dlist]`.
81 |
82 | Args:
83 | dlist: list of objects to convert.
84 | typeannot: type of each of those objects.
85 | Returns:
86 | iterator or list over converted objects of the same length as `dlist`.
87 |
88 | Raises:
89 | ValueError: it assumes the objects have None's in consistent places across
90 | objects, otherwise it would ignore some values. This generally holds for
91 | auto-generated annotations, but otherwise use `_dataclass_from_dict`.
92 | """
93 |
94 | cls = get_origin(typeannot) or typeannot
95 |
96 | if typeannot is Any:
97 | return dlist
98 | if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
99 | return dlist
100 | if any(obj is None for obj in dlist):
101 | # filter out Nones and recurse on the resulting list
102 | idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
103 | idx, notnone = zip(*idx_notnone)
104 | converted = _dataclass_list_from_dict_list(notnone, typeannot)
105 | res = [None] * len(dlist)
106 | for i, obj in zip(idx, converted):
107 | res[i] = obj
108 | return res
109 |
110 | is_optional, contained_type = _resolve_optional(typeannot)
111 | if is_optional:
112 | return _dataclass_list_from_dict_list(dlist, contained_type)
113 |
114 | # otherwise, we dispatch by the type of the provided annotation to convert to
115 | if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
116 | # For namedtuple, call the function recursively on the lists of corresponding keys
117 | types = cls.__annotations__.values()
118 | dlist_T = zip(*dlist)
119 | res_T = [
120 | _dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
121 | ]
122 | return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
123 | elif issubclass(cls, (list, tuple)):
124 | # For list/tuple, call the function recursively on the lists of corresponding positions
125 | types = get_args(typeannot)
126 | if len(types) == 1: # probably List; replicate for all items
127 | types = types * len(dlist[0])
128 | dlist_T = zip(*dlist)
129 | res_T = (
130 | _dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
131 | )
132 | if issubclass(cls, tuple):
133 | return list(zip(*res_T))
134 | else:
135 | return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
136 | elif issubclass(cls, dict):
137 | # For the dictionary, call the function recursively on the concatenated keys and values
138 | key_t, val_t = get_args(typeannot)
139 | all_keys_res = _dataclass_list_from_dict_list(
140 | [k for obj in dlist for k in obj.keys()], key_t
141 | )
142 | all_vals_res = _dataclass_list_from_dict_list(
143 | [k for obj in dlist for k in obj.values()], val_t
144 | )
145 | indices = np.cumsum([len(obj) for obj in dlist])
146 | assert indices[-1] == len(all_keys_res)
147 |
148 | keys = np.split(list(all_keys_res), indices[:-1])
149 | all_vals_res_iter = iter(all_vals_res)
150 | return [cls(zip(k, all_vals_res_iter)) for k in keys]
151 | elif not dataclasses.is_dataclass(typeannot):
152 | return dlist
153 |
154 | # dataclass node: 2nd recursion base; call the function recursively on the lists
155 | # of the corresponding fields
156 | assert dataclasses.is_dataclass(cls)
157 | fieldtypes = {
158 | f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
159 | for f in dataclasses.fields(typeannot)
160 | }
161 |
162 | # NOTE the default object is shared here
163 | key_lists = (
164 | _dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
165 | for k, (type_, default) in fieldtypes.items()
166 | )
167 | transposed = zip(*key_lists)
168 | return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
169 |
170 |
171 | @dataclass
172 | class ImageAnnotation:
173 | # path to jpg file, relative w.r.t. dataset_root
174 | path: str
175 | # H x W
176 | size: Tuple[int, int]
177 |
178 | @dataclass
179 | class DynamicReplicaFrameAnnotation:
180 | """A dataclass used to load annotations from json."""
181 |
182 | # can be used to join with `SequenceAnnotation`
183 | sequence_name: str
184 | # 0-based, continuous frame number within sequence
185 | frame_number: int
186 | # timestamp in seconds from the video start
187 | frame_timestamp: float
188 |
189 | image: ImageAnnotation
190 | meta: Optional[Dict[str, Any]] = None
191 |
192 | camera_name: Optional[str] = None
193 | trajectories: Optional[str] = None
194 |
195 |
196 | class DynamicReplicaDataset(PointDataset):
197 | def __init__(
198 | self,
199 | data_root,
200 | split="train",
201 | traj_per_sample=256,
202 | traj_max_factor=24, # multiplier on traj_per_sample
203 | crop_size=None,
204 | use_augs=False,
205 | seq_len=64,
206 | strides=[2,3],
207 | shuffle_frames=False,
208 | shuffle=False,
209 | only_first=False,
210 | ):
211 | super(DynamicReplicaDataset, self).__init__(
212 | data_root=data_root,
213 | crop_size=crop_size,
214 | seq_len=seq_len,
215 | traj_per_sample=traj_per_sample,
216 | use_augs=use_augs,
217 | )
218 | print('loading dynamicreplica dataset...')
219 | self.data_root = data_root
220 | self.only_first = only_first
221 | self.traj_max_factor = traj_max_factor
222 | self.seq_len = seq_len
223 | self.split = split
224 | self.traj_per_sample = traj_per_sample
225 | self.crop_size = crop_size
226 | self.shuffle_frames = shuffle_frames
227 | frame_annotations_file = f"frame_annotations_{split}.jgz"
228 | self.sample_list = []
229 | with gzip.open(
230 | os.path.join(data_root, split, frame_annotations_file), "rt", encoding="utf8"
231 | ) as zipfile:
232 | frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
233 | seq_annot = defaultdict(list)
234 | for frame_annot in frame_annots_list:
235 | if frame_annot.camera_name == "left":
236 | seq_annot[frame_annot.sequence_name].append(frame_annot)
237 | # if os.path.isfile(traj_2d_file) and os.path.isfile(visib_file) and os.path.isfile(valid_file):
238 | # self.sequences.append(seq)
239 |
240 | clip_step = 64
241 |
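# each sample is a clip of seq_len frames taken at one of the given strides,
# with clip start positions spaced clip_step frames apart within each sequence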
242 | for seq_name in seq_annot.keys():
243 | S_local = len(seq_annot[seq_name])
244 | # print(seq_name, 'S_local', S_local)
245 |
246 | traj_path = os.path.join(self.data_root, self.split, seq_annot[seq_name][0].trajectories['path'])
247 | if os.path.isfile(traj_path):
248 | for stride in strides:
249 | for ref_idx in range(0, S_local-seq_len*stride, clip_step):
250 | full_idx = ref_idx + np.arange(seq_len)*stride
251 | full_idx = [ij for ij in full_idx if ij < S_local]
252 | full_idx = np.array(full_idx).astype(np.int32)
253 | if len(full_idx)==seq_len:
254 | sample = [seq_annot[seq_name][fi] for fi in full_idx]
255 | self.sample_list.append(sample)
256 | print('found %d unique videos in %s (split=%s)' % (len(self.sample_list), data_root, split))
257 | self.dname = 'dynrep%d' % seq_len
258 |
259 | if shuffle:
260 | random.shuffle(self.sample_list)
261 |
262 | def __len__(self):
263 | return len(self.sample_list)
264 |
265 | def getitem_helper(self, index):
266 | sample = self.sample_list[index]
267 | T = len(sample)
268 | rgbs, visibs, trajs = [], [], []
269 |
270 | H, W = sample[0].image.size
271 | image_size = (H, W)
272 |
273 | for i in range(T):
274 | traj_path = os.path.join(self.data_root, self.split, sample[i].trajectories["path"])
275 | traj = torch.load(traj_path, weights_only=False)
276 |
277 | visibs.append(traj["verts_inds_vis"].numpy())
278 |
279 | rgbs.append(traj["img"].numpy())
280 | trajs.append(traj["traj_2d"].numpy()[..., :2])
281 |
282 | rgbs = np.stack(rgbs, axis=0) # S,H,W,3
283 | trajs = np.stack(trajs)
284 | visibs = np.stack(visibs)
285 | T, N, D = trajs.shape
286 |
287 | H,W = rgbs[0].shape[:2]
288 | visibs[trajs[:, :, 0] > W-1] = False
289 | visibs[trajs[:, :, 0] < 0] = False
290 | visibs[trajs[:, :, 1] > H-1] = False
291 | visibs[trajs[:, :, 1] < 0] = False
292 |
293 |
294 | if self.use_augs and np.random.rand() < 0.5:
295 | # time flip
296 | rgbs = np.flip(rgbs, axis=[0]).copy()
297 | trajs = np.flip(trajs, axis=[0]).copy()
298 | visibs = np.flip(visibs, axis=[0]).copy()
299 |
300 | if self.shuffle_frames and np.random.rand() < 0.01:
301 | # shuffle the frames
302 | perm = np.random.permutation(rgbs.shape[0])
303 | rgbs = rgbs[perm]
304 | trajs = trajs[perm]
305 | visibs = visibs[perm]
306 |
307 | assert(trajs.shape[0] == self.seq_len)
308 |
309 | if self.only_first:
310 | vis_ok = np.nonzero(visibs[0]==1)[0]
311 | trajs = trajs[:,vis_ok]
312 | visibs = visibs[:,vis_ok]
313 |
314 | N = trajs.shape[1]
315 | if N < self.traj_per_sample:
316 | print('dyn: N after vis0', N)
317 | return None, False
318 |
319 | # the data is quite big: 720x1280
320 | if H > self.crop_size[0]*2 and W > self.crop_size[1]*2 and np.random.rand() < 0.5:
321 | scale = 0.5
322 | rgbs = [cv2.resize(rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
323 | H, W = rgbs[0].shape[:2]
324 | rgbs = np.stack(rgbs, axis=0) # S,H,W,3
325 | trajs = trajs * scale
326 | # print('resized rgbs', rgbs.shape)
327 |
328 | if H > self.crop_size[0]*2 and W > self.crop_size[1]*2 and np.random.rand() < 0.5:
329 | scale = 0.5
330 | rgbs = [cv2.resize(rgb, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) for rgb in rgbs]
331 | H, W = rgbs[0].shape[:2]
332 | rgbs = np.stack(rgbs, axis=0) # S,H,W,3
333 | trajs = trajs * scale
334 | # print('resized rgbs', rgbs.shape)
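# note: the halving block above appears twice, so very large frames can be downscaled
# by up to 4x; each halving is guarded by the >2x-crop-size check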
335 |
336 | if self.use_augs and np.random.rand() < 0.98:
337 | H, W = rgbs[0].shape[:2]
338 |
339 | rgbs, trajs, visibs = self.add_photometric_augs(
340 | rgbs, trajs, visibs, replace=False,
341 | )
342 | if np.random.rand() < 0.2:
343 | rgbs, trajs = self.add_spatial_augs(rgbs, trajs, visibs, self.crop_size)
344 | else:
345 | rgbs, trajs = self.follow_crop(rgbs, trajs, visibs, self.crop_size)
346 | if np.random.rand() < self.rot_prob:
347 | # note: this is OK since B==1;
348 | # otherwise we would do it before this function
349 | rgbs = [np.transpose(rgb, (1,0,2)).copy() for rgb in rgbs]
350 | rgbs = np.stack(rgbs)
351 | trajs = np.flip(trajs, axis=2).copy()
352 | H, W = rgbs[0].shape[:2]
353 | if np.random.rand() < self.h_flip_prob:
354 | rgbs = [rgb[:, ::-1].copy() for rgb in rgbs]
355 | trajs[:, :, 0] = W - trajs[:, :, 0]
356 | rgbs = np.stack(rgbs)
357 | if np.random.rand() < self.v_flip_prob:
358 | rgbs = [rgb[::-1].copy() for rgb in rgbs]
359 | trajs[:, :, 1] = H - trajs[:, :, 1]
360 | rgbs = np.stack(rgbs)
361 | else:
362 | rgbs, trajs = self.crop(rgbs, trajs, self.crop_size)
363 |
364 | if self.shuffle_frames and np.random.rand() < 0.01:
365 | # shuffle the frames (again)
366 | perm = np.random.permutation(rgbs.shape[0])
367 | rgbs = rgbs[perm]
368 | trajs = trajs[perm]
369 | visibs = visibs[perm]
370 |
371 | H,W = rgbs[0].shape[:2]
372 |
373 | visibs[trajs[:, :, 0] > W-1] = False
374 | visibs[trajs[:, :, 0] < 0] = False
375 | visibs[trajs[:, :, 1] > H-1] = False
376 | visibs[trajs[:, :, 1] < 0] = False
377 |
378 | # ensure no crazy values
379 | all_valid = np.nonzero(np.sum(np.sum(np.abs(trajs), axis=-1)<100000, axis=0)==self.seq_len)[0]
380 | trajs = trajs[:,all_valid]
381 | visibs = visibs[:,all_valid]
382 |
383 | if self.only_first:
384 | vis_ok = np.nonzero(visibs[0]==1)[0]
385 | trajs = trajs[:,vis_ok]
386 | visibs = visibs[:,vis_ok]
387 |
388 | N = trajs.shape[1]
389 | if N < self.traj_per_sample:
390 | print('dyn: N after aug', N)
391 | return None, False
392 |
393 | trajs = torch.from_numpy(trajs)
394 | visibs = torch.from_numpy(visibs)
395 |
396 | rgbs = np.stack(rgbs, 0)
397 | rgbs = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
398 |
399 | # discard tracks that go far OOB
400 | crop_tensor = torch.tensor(self.crop_size).flip(0)[None, None] / 2.0
401 | close_pts_inds = torch.all(
402 | torch.linalg.vector_norm(trajs[..., :2] - crop_tensor, dim=-1) < max(H,W)*2,
403 | dim=0,
404 | )
405 | trajs = trajs[:, close_pts_inds]
406 | visibs = visibs[:, close_pts_inds]
407 |
408 | visible_pts_inds = (visibs[0]).nonzero(as_tuple=False)[:, 0]
409 | point_inds = torch.randperm(len(visible_pts_inds))[:self.traj_per_sample*self.traj_max_factor]
410 | if len(point_inds) < self.traj_per_sample:
411 | # print('not enough trajs')
412 | return None, False
413 |
414 | visible_inds_sampled = visible_pts_inds[point_inds]
415 | trajs = trajs[:, visible_inds_sampled].float()
416 | visibs = visibs[:, visible_inds_sampled]
417 | valids = torch.ones_like(visibs)
418 |
419 | trajs = trajs[:, :self.traj_per_sample*self.traj_max_factor]
420 | visibs = visibs[:, :self.traj_per_sample*self.traj_max_factor]
421 | valids = valids[:, :self.traj_per_sample*self.traj_max_factor]
422 |
423 | sample = utils.data.VideoData(
424 | video=rgbs,
425 | trajs=trajs,
426 | visibs=visibs,
427 | valids=valids,
428 | dname=self.dname,
429 | )
430 | return sample, True
431 |
--------------------------------------------------------------------------------
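A usage sketch for the loader above; the data_root path is hypothetical, and the printed shapes follow the VideoData fields assembled in getitem_helper.

from datasets.dynrep_dataset import DynamicReplicaDataset

dataset = DynamicReplicaDataset(
    data_root='./data/dynamic_replica',  # hypothetical location of the Dynamic Replica dump
    split='train',
    crop_size=(384, 512),
    seq_len=64,
    traj_per_sample=256,
    use_augs=True,
)
sample, gotit = dataset[0]
if gotit:
    print(sample.video.shape)   # torch.Size([64, 3, 384, 512])
    print(sample.trajs.shape)   # torch.Size([64, N, 2]), with N >= traj_per_sample
    print(sample.visibs.shape)  # torch.Size([64, N])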
/test_dense_on_sparse.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import argparse
7 | import utils.loss
8 | import utils.data
9 | import utils.improc
10 | import utils.misc
11 | import utils.saveload
12 | from tensorboardX import SummaryWriter
13 | import datetime
14 | import time
15 |
16 | torch.set_float32_matmul_precision('medium')
17 |
18 | from prettytable import PrettyTable
19 | def count_parameters(model):
20 | table = PrettyTable(["Modules", "Parameters"])
21 | total_params = 0
22 | for name, parameter in model.named_parameters():
23 | if not parameter.requires_grad:
24 | continue
25 | param = parameter.numel()
26 | if param > 100000:
27 | table.add_row([name, param])
28 | total_params += param
29 | # print(table)
30 | print('total params: %.2f M' % (total_params/1000000.0))
31 | return total_params
32 |
33 | def get_parameter_names(model, forbidden_layer_types):
34 | result = []
35 | for name, child in model.named_children():
36 | result += [
37 | f"{name}.{n}"
38 | for n in get_parameter_names(child, forbidden_layer_types)
39 | if not isinstance(child, tuple(forbidden_layer_types))
40 | ]
41 | result += list(model._parameters.keys())
42 | return result
43 |
44 | def get_dataset(dname, args):
45 | if dname=='bad':
46 | dataset_names = ['bad']
47 | from datasets import badjadataset
48 | dataset = badjadataset.BadjaDataset(
49 | data_root=os.path.join(args.dataset_root, 'badja2'),
50 | crop_size=args.image_size,
51 | only_first=False,
52 | )
53 | elif dname=='cro':
54 | dataset_names = ['cro']
55 | from datasets import crohddataset
56 | dataset = crohddataset.CrohdDataset(
57 | data_root=os.path.join(args.dataset_root, 'crohd'),
58 | crop_size=args.image_size,
59 | seq_len=None,
60 | only_first=True,
61 | )
62 | elif dname=='dav':
63 | dataset_names = ['dav']
64 | from datasets import davisdataset
65 | dataset = davisdataset.DavisDataset(
66 | data_root=os.path.join(args.dataset_root, 'tapvid_davis'),
67 | crop_size=args.image_size,
68 | only_first=False,
69 | )
70 | elif dname=='dri':
71 | dataset_names = ['dri']
72 | from datasets import drivetrackdataset
73 | dataset = drivetrackdataset.DrivetrackDataset(
74 | data_root=os.path.join(args.dataset_root, 'drivetrack'),
75 | crop_size=args.image_size,
76 | seq_len=None,
77 | traj_per_sample=768,
78 | only_first=True,
79 | )
80 | elif dname=='ego':
81 | dataset_names = ['ego']
82 | from datasets import egopointsdataset
83 | dataset = egopointsdataset.EgoPointsDataset(
84 | data_root=os.path.join(args.dataset_root, 'ego_points'),
85 | crop_size=args.image_size,
86 | only_first=True,
87 | )
88 | elif dname=='hor':
89 | dataset_names = ['hor']
90 | from datasets import horsedataset
91 | dataset = horsedataset.HorseDataset(
92 | data_root=os.path.join(args.dataset_root, 'horse10'),
93 | crop_size=args.image_size,
94 | seq_len=None,
95 | only_first=True,
96 | )
97 | elif dname=='kin':
98 | dataset_names = ['kin']
99 | from datasets import kineticsdataset
100 | dataset = kineticsdataset.KineticsDataset(
101 | data_root=os.path.join(args.dataset_root, 'tapvid_kinetics'),
102 | crop_size=args.image_size,
103 | only_first=True,
104 | )
105 | elif dname=='rgb':
106 | dataset_names = ['rgb']
107 | from datasets import rgbstackingdataset
108 | dataset = rgbstackingdataset.RGBStackingDataset(
109 | data_root=os.path.join(args.dataset_root, 'tapvid_rgb_stacking'),
110 | crop_size=args.image_size,
111 | only_first=False,
112 | )
113 | elif dname=='rob':
114 | dataset_names = ['rob']
115 | from datasets import robotapdataset
116 | dataset = robotapdataset.RobotapDataset(
117 | data_root=os.path.join(args.dataset_root, 'robotap'),
118 | crop_size=args.image_size,
119 | only_first=True,
120 | )
121 | return dataset, dataset_names
122 |
123 | def create_pools(args, n_pool=10000, min_size=1):
124 | pools = {}
125 | n_pool = max(n_pool, 10)
126 | thrs = [1,2,4,8,16]
127 | for thr in thrs:
128 | pools['d_%d' % thr] = utils.misc.SimplePool(n_pool, version='np', min_size=min_size)
129 | pools['jac_%d' % thr] = utils.misc.SimplePool(n_pool, version='np', min_size=min_size)
130 | pools['d_avg'] = utils.misc.SimplePool(n_pool, version='np', min_size=min_size)
131 | pools['aj'] = utils.misc.SimplePool(n_pool, version='np', min_size=min_size)
132 | pools['oa'] = utils.misc.SimplePool(n_pool, version='np', min_size=min_size)
133 | return pools
134 |
135 | def forward_batch(batch, model, args, sw):
136 | rgbs = batch.video
137 | trajs_g = batch.trajs
138 | vis_g = batch.visibs
139 | valids = batch.valids
140 | dname = batch.dname
141 | # print('rgbs', rgbs.shape, rgbs.dtype, rgbs.device)
142 | # print('trajs_g', trajs_g.shape, trajs_g.device)
143 | # print('vis_g', vis_g.shape, vis_g.device)
144 |
145 | B, T, C, H, W = rgbs.shape
146 | assert C == 3
147 | B, T, N, D = trajs_g.shape
148 | device = rgbs.device
149 | assert(B==1)
150 |
151 | trajs_g = trajs_g.cuda()
152 | vis_g = vis_g.cuda()
153 | valids = valids.cuda()
154 | __, first_positive_inds = torch.max(vis_g, dim=1)
155 |
156 | grid_xy = utils.basic.gridcloud2d(1, H, W, norm=False, device='cuda:0').float() # 1,H*W,2
157 | grid_xy = grid_xy.permute(0,2,1).reshape(1,1,2,H,W) # 1,1,2,H,W
158 |
159 | trajs_e = torch.zeros([B, T, N, 2], device='cuda:0')
160 | visconfs_e = torch.zeros([B, T, N, 2], device='cuda:0')
161 | query_points_all = []
162 | with torch.no_grad():
163 | for first_positive_ind in torch.unique(first_positive_inds):
164 | chunk_pt_idxs = torch.nonzero(first_positive_inds[0]==first_positive_ind, as_tuple=False)[:, 0] # K
165 | chunk_pts = trajs_g[:, first_positive_ind[None].repeat(chunk_pt_idxs.shape[0]), chunk_pt_idxs] # B, K, 2
166 | query_points_all.append(torch.cat([first_positive_inds[:, chunk_pt_idxs, None], chunk_pts], dim=2))
167 |
168 | traj_maps_e = grid_xy.repeat(1,T,1,1,1) # B,T,2,H,W
169 | visconf_maps_e = torch.zeros_like(traj_maps_e)
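# traj_maps_e starts as the identity mapping (each pixel tracks to its own coordinates)
# and visconf_maps_e as zeros; only frames from the query frame onward are overwritten
# with model predictions below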
170 | if first_positive_ind < T-1:
171 | if T > 128: # forward_sliding is a little safer memory-wise
172 | forward_flow_e, forward_visconf_e, forward_flow_preds, forward_visconf_preds = \
173 | model.forward_sliding(rgbs[:, first_positive_ind:], iters=args.inference_iters, sw=sw, is_training=False)
174 | else:
175 | forward_flow_e, forward_visconf_e, forward_flow_preds, forward_visconf_preds = \
176 | model(rgbs[:, first_positive_ind:], iters=args.inference_iters, sw=sw, is_training=False)
177 |
178 | del forward_flow_preds
179 | del forward_visconf_preds
180 | forward_traj_maps_e = forward_flow_e.cuda() + grid_xy # B,Tf,2,H,W; when T==2 the flow has no time dim, but broadcasting handles it
181 | traj_maps_e[:,first_positive_ind:] = forward_traj_maps_e
182 | visconf_maps_e[:,first_positive_ind:] = forward_visconf_e
183 | if not sw.save_this:
184 | del forward_flow_e
185 | del forward_visconf_e
186 | del forward_traj_maps_e
187 |
188 | xyt = trajs_g[:,first_positive_ind].round().long()[0, chunk_pt_idxs] # K,2
189 | trajs_e_chunk = traj_maps_e[:, :, :, xyt[:,1], xyt[:,0]] # B,T,2,K
190 | trajs_e_chunk = trajs_e_chunk.permute(0,1,3,2) # B,T,K,2
191 | trajs_e.scatter_add_(2, chunk_pt_idxs[None, None, :, None].repeat(1, trajs_e_chunk.shape[1], 1, 2), trajs_e_chunk)
192 |
193 | visconfs_e_chunk = visconf_maps_e[:, :, :, xyt[:,1], xyt[:,0]] # B,T,2,K
194 | visconfs_e_chunk = visconfs_e_chunk.permute(0,1,3,2) # B,T,K,2
195 | visconfs_e.scatter_add_(2, chunk_pt_idxs[None, None, :, None].repeat(1, visconfs_e_chunk.shape[1], 1, 2), visconfs_e_chunk)
196 |
197 | visconfs_e[..., 0] *= visconfs_e[..., 1]
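# fold the second visconf channel into the first, so that the occlusion decision
# below thresholds their product against vis_thr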
198 | assert (torch.all(visconfs_e >= 0) and torch.all(visconfs_e <= 1))
199 | vis_thr = 0.6
200 | query_points_all = torch.cat(query_points_all, dim=1)[..., [0, 2, 1]]
201 | gt_occluded = (vis_g < .5).bool().transpose(1, 2)
202 | gt_tracks = trajs_g.transpose(1, 2)
203 | pred_occluded = (visconfs_e[..., 0] < vis_thr).bool().transpose(1, 2)
204 | pred_tracks = trajs_e.transpose(1, 2)
205 |
206 | metrics = utils.misc.compute_tapvid_metrics(
207 | query_points=query_points_all.cpu().numpy(),
208 | gt_occluded=gt_occluded.cpu().numpy(),
209 | gt_tracks=gt_tracks.cpu().numpy(),
210 | pred_occluded=pred_occluded.cpu().numpy(),
211 | pred_tracks=pred_tracks.cpu().numpy(),
212 | query_mode='first',
213 | crop_size=args.image_size
214 | )
215 | for thr in [1, 2, 4, 8, 16]:
216 | metrics['d_%d' % thr] = metrics['pts_within_' + str(thr)]
217 | metrics['jac_%d' % thr] = metrics['jaccard_' + str(thr)]
218 | metrics['d_avg'] = metrics['average_pts_within_thresh']
219 | metrics['aj'] = metrics['average_jaccard']
220 | metrics['oa'] = metrics['occlusion_accuracy']
221 |
222 | return metrics
223 |
224 |
225 | def run(dname, model, args):
226 | def seed_everything(seed: int):
227 | random.seed(seed)
228 | os.environ["PYTHONHASHSEED"] = str(seed)
229 | np.random.seed(seed)
230 | torch.manual_seed(seed)
231 | torch.cuda.manual_seed(seed)
232 | torch.backends.cudnn.deterministic = True
233 | torch.backends.cudnn.benchmark = False
234 | def seed_worker(worker_id):
235 | worker_seed = torch.initial_seed() % 2**32
236 | np.random.seed(worker_seed + worker_id)
237 | random.seed(worker_seed + worker_id)
238 | seed = 42
239 | seed_everything(seed)
240 | g = torch.Generator()
241 | g.manual_seed(seed)
242 |
243 | B_ = args.batch_size * torch.cuda.device_count()
244 | assert(B_==1)
245 | model_name = "%dx%d" % (int(args.image_size[0]), int(args.image_size[1]))
246 | model_name += "i%d" % (args.inference_iters)
247 | model_name += "_%s" % args.init_dir
248 | model_name += "_%s" % dname
249 | if args.only_first:
250 | model_name += "_first"
251 | model_name += "_%s" % args.exp
252 | model_date = datetime.datetime.now().strftime('%M%S')
253 | model_name = model_name + '_' + model_date
254 |
255 | save_dir = '%s/%s' % (args.ckpt_dir, model_name)
256 |
257 | dataset, dataset_names = get_dataset(dname, args)
258 | dataloader = torch.utils.data.DataLoader(
259 | dataset,
260 | batch_size=1,
261 | shuffle=False,
262 | num_workers=args.num_workers,
263 | worker_init_fn=seed_worker,
264 | generator=g,
265 | pin_memory=True,
266 | drop_last=True,
267 | collate_fn=utils.data.collate_fn_train,
268 | )
269 | iterloader = iter(dataloader)
270 | print('len(dataloader)', len(dataloader))
271 |
272 | log_dir = './logs_test_dense_on_sparse'
273 | overpools_t = create_pools(args)
274 | writer_t = SummaryWriter(log_dir + '/' + args.model_type + '-' + model_name + '/t', max_queue=10, flush_secs=60)
275 |
276 | global_step = 0
277 | if args.init_dir:
278 | load_dir = '%s/%s' % (args.ckpt_dir, args.init_dir)
279 | _ = utils.saveload.load(
280 | None,
281 | load_dir,
282 | model,
283 | optimizer=None,
284 | scheduler=None,
285 | ignore_load=None,
286 | strict=True,
287 | verbose=False,
288 | weights_only=False,
289 | )
290 | model.cuda()
291 | for n, p in model.named_parameters():
292 | p.requires_grad = False
293 | model.eval()
294 |
295 | max_steps = min(args.max_steps, len(dataloader))
296 |
297 | while global_step < max_steps:
298 | torch.cuda.empty_cache()
299 | iter_start_time = time.time()
300 | try:
301 | batch = next(iterloader)
302 | except StopIteration:
303 | iterloader = iter(dataloader)
304 | batch = next(iterloader)
305 |
306 | batch, gotit = batch
307 | if not all(gotit):
308 | continue
309 |
310 | sw_t = utils.improc.Summ_writer(
311 | writer=writer_t,
312 | global_step=global_step,
313 | log_freq=args.log_freq,
314 | fps=8,
315 | scalar_freq=1,
316 | just_gif=True)
317 | if args.log_freq == 9999:
318 | sw_t.save_this = False
319 |
320 | rtime = time.time()-iter_start_time
321 |
322 | if batch.trajs.shape[2] == 0:
323 | global_step += 1
324 | continue
325 |
326 | metrics = forward_batch(batch, model, args, sw_t)
327 |
328 | # update stats
329 | for key in list(overpools_t.keys()):
330 | if key in metrics:
331 | overpools_t[key].update([metrics[key]])
332 | # plot stats
333 | for key in list(overpools_t.keys()):
334 | sw_t.summ_scalar('_/%s' % (key), overpools_t[key].mean())
335 |
336 | global_step += 1
337 |
338 | itime = time.time()-iter_start_time
339 |
340 | info_str = '%s; step %06d/%d; rtime %.2f; itime %.2f' % (
341 | model_name, global_step, max_steps, rtime, itime)
342 | info_str += '; dname %s; d_avg %.1f aj %.1f oa %.1f' % (
343 | dname, overpools_t['d_avg'].mean()*100.0, overpools_t['aj'].mean()*100.0, overpools_t['oa'].mean()*100.0
344 | )
345 | if sw_t.save_this:
346 | print('model_name', model_name)
347 |
348 | if not args.print_less:
349 | print(info_str, flush=True)
350 |
351 |
352 | if args.print_less:
353 | print(info_str, flush=True)
354 |
355 | writer_t.close()
356 |
357 | del iterloader
358 | del dataloader
359 | del dataset
360 |
361 | return overpools_t['d_avg'].mean()*100.0, overpools_t['aj'].mean()*100.0, overpools_t['oa'].mean()*100.0
362 |
363 | if __name__ == "__main__":
364 | torch.set_grad_enabled(False)
365 | init_dir = ''
366 |
367 | exp = ''
368 |
369 | parser = argparse.ArgumentParser()
370 | parser.add_argument("--exp", default=exp)
371 | parser.add_argument("--dname", type=str, nargs='+', default=None, help="Dataset names, written as a single string or list of strings")
372 | parser.add_argument("--init_dir", type=str, default=init_dir)
373 | parser.add_argument("--ckpt_dir", type=str, default='')
374 | parser.add_argument("--batch_size", type=int, default=1)
375 | parser.add_argument("--num_workers", type=int, default=1)
376 | parser.add_argument("--max_steps", type=int, default=1500)
377 | parser.add_argument("--log_freq", type=int, default=9999)
378 | parser.add_argument("--dataset_root", type=str, default='/orion/group')
379 | parser.add_argument("--inference_iters", type=int, default=4)
380 | parser.add_argument("--window_len", type=int, default=16)
381 | parser.add_argument("--stride", type=int, default=8)
382 | parser.add_argument("--image_size", nargs="+", default=[384, 512]) # resizing arg
383 | parser.add_argument("--backwards", default=False)
384 | parser.add_argument("--mixed_precision", action='store_true', default=False)
385 | parser.add_argument("--only_first", action='store_true', default=False)
386 | parser.add_argument("--no_split", action='store_true', default=False)
387 | parser.add_argument("--print_less", action='store_true', default=False)
388 | parser.add_argument("--use_basicencoder", action='store_true', default=False)
389 | parser.add_argument("--conf", action='store_true', default=False)
390 | parser.add_argument("--model_type", choices=['ours', 'raft', 'searaft', 'accflow', 'delta'], default='ours')
391 | args = parser.parse_args()
392 | # allow dname to be a comma-separated string (e.g., "rgb,bad,dav")
393 | if args.dname is not None and len(args.dname) == 1 and ',' in args.dname[0]:
394 | args.dname = args.dname[0].split(',')
395 | if args.dname is None:
396 | args.dname = ['bad', 'cro', 'dav', 'dri', 'hor', 'kin', 'rgb', 'rob']
397 | dataset_names = args.dname
398 | args.image_size = [int(args.image_size[0]), int(args.image_size[1])]
399 | full_start_time = time.time()
400 |
401 | from nets.alltracker import Net; model = Net(16)
402 | url = "https://huggingface.co/aharley/alltracker/resolve/main/alltracker.pth"
403 | state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
404 | model.load_state_dict(state_dict['model'], strict=True)
405 | print('loaded weights from', url)
406 |
407 | das, ajs, oas = [], [], []
408 | for dname in dataset_names:
409 | if dname==dataset_names[0]:
410 | count_parameters(model)
411 | da, aj, oa = run(dname, model, args)
412 | das.append(da)
413 | ajs.append(aj)
414 | oas.append(oa)
415 | for (data, name) in zip([dataset_names, das, ajs, oas], ['dn', 'da', 'aj', 'oa']):
416 | st = name + ': '
417 | for dat in data:
418 | if isinstance(dat, str):
419 | st += '%s,' % dat
420 | else:
421 | st += '%.1f,' % dat
422 | print(st)
423 | full_time = time.time()-full_start_time
424 | print('full_time %.1f' % full_time)
425 |
--------------------------------------------------------------------------------
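The dense-to-sparse readout in forward_batch above indexes the per-pixel trajectory maps at the query pixel coordinates. A tiny self-contained illustration of that indexing step, with arbitrary shapes:

import torch

B, T, H, W = 1, 8, 4, 5
traj_maps_e = torch.rand(B, T, 2, H, W)              # per-pixel (x,y) location in every frame
xyt = torch.tensor([[1, 2], [4, 0], [3, 3]])         # K query points, stored as (x, y)
tracks = traj_maps_e[:, :, :, xyt[:, 1], xyt[:, 0]]  # B,T,2,K: read out the K query pixels
tracks = tracks.permute(0, 1, 3, 2)                  # B,T,K,2, as used for the metrics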
/utils/py.py:
--------------------------------------------------------------------------------
1 | import glob, math
2 | import numpy as np
3 | # from scipy import misc
4 | # from scipy import linalg
5 | from PIL import Image
6 | import io
7 | import matplotlib.pyplot as plt
8 | EPS = 1e-6
9 |
10 |
11 | XMIN = -64.0 # right (neg is left)
12 | XMAX = 64.0 # right
13 | YMIN = -64.0 # down (neg is up)
14 | YMAX = 64.0 # down
15 | ZMIN = -64.0 # forward
16 | ZMAX = 64.0 # forward
17 |
18 | def print_stats(name, tensor):
19 | tensor = tensor.astype(np.float32)
20 | print('%s min = %.2f, mean = %.2f, max = %.2f' % (name, np.min(tensor), np.mean(tensor), np.max(tensor)), tensor.shape)
21 |
22 | def reduce_masked_mean(x, mask, axis=None, keepdims=False):
23 | # x and mask are the same shape
24 | # returns shape-1
25 | # axis can be a list of axes
26 | prod = x*mask
27 | numer = np.sum(prod, axis=axis, keepdims=keepdims)
28 | denom = EPS+np.sum(mask, axis=axis, keepdims=keepdims)
29 | mean = numer/denom
30 | return mean
31 |
32 | def reduce_masked_sum(x, mask, axis=None, keepdims=False):
33 | # x and mask are the same shape
34 | # returns shape-1
35 | # axis can be a list of axes
36 | prod = x*mask
37 | numer = np.sum(prod, axis=axis, keepdims=keepdims)
38 | return numer
39 |
40 | def reduce_masked_median(x, mask, keep_batch=False):
41 | # x and mask are the same shape
42 | # returns shape-1
43 | # axis can be a list of axes
44 |
45 | if not (x.shape == mask.shape):
46 | print('reduce_masked_median: these shapes should match:', x.shape, mask.shape)
47 | assert(False)
48 | # assert(x.shape == mask.shape)
49 |
50 | B = list(x.shape)[0]
51 |
52 | if keep_batch:
53 | x = np.reshape(x, [B, -1])
54 | mask = np.reshape(mask, [B, -1])
55 | meds = np.zeros([B], np.float32)
56 | for b in list(range(B)):
57 | xb = x[b]
58 | mb = mask[b]
59 | if np.sum(mb) > 0:
60 | xb = xb[mb > 0]
61 | meds[b] = np.median(xb)
62 | else:
63 | meds[b] = np.nan
64 | return meds
65 | else:
66 | x = np.reshape(x, [-1])
67 | mask = np.reshape(mask, [-1])
68 | if np.sum(mask) > 0:
69 | x = x[mask > 0]
70 | med = np.median(x)
71 | else:
72 | med = np.nan
73 | med = np.array([med], np.float32)
74 | return med
75 |
76 | def get_nFiles(path):
77 | return len(glob.glob(path))
78 |
79 | def get_file_list(path):
80 | return glob.glob(path)
81 |
82 | def rotm2eul(R):
83 | # R is 3x3
84 | sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0])
85 | if sy > 1e-6: # singular
86 | x = math.atan2(R[2,1] , R[2,2])
87 | y = math.atan2(-R[2,0], sy)
88 | z = math.atan2(R[1,0], R[0,0])
89 | else:
90 | x = math.atan2(-R[1,2], R[1,1])
91 | y = math.atan2(-R[2,0], sy)
92 | z = 0
93 | return x, y, z
94 |
95 | def rad2deg(rad):
96 | return rad*180.0/np.pi
97 |
98 | def deg2rad(deg):
99 | return deg/180.0*np.pi
100 |
101 | def eul2rotm(rx, ry, rz):
102 | # copy of matlab, but order of inputs is different
103 | # R = [ cy*cz sy*sx*cz-sz*cx sy*cx*cz+sz*sx
104 | # cy*sz sy*sx*sz+cz*cx sy*cx*sz-cz*sx
105 | # -sy cy*sx cy*cx]
106 | sinz = np.sin(rz)
107 | siny = np.sin(ry)
108 | sinx = np.sin(rx)
109 | cosz = np.cos(rz)
110 | cosy = np.cos(ry)
111 | cosx = np.cos(rx)
112 | r11 = cosy*cosz
113 | r12 = sinx*siny*cosz - cosx*sinz
114 | r13 = cosx*siny*cosz + sinx*sinz
115 | r21 = cosy*sinz
116 | r22 = sinx*siny*sinz + cosx*cosz
117 | r23 = cosx*siny*sinz - sinx*cosz
118 | r31 = -siny
119 | r32 = sinx*cosy
120 | r33 = cosx*cosy
121 | r1 = np.stack([r11,r12,r13],axis=-1)
122 | r2 = np.stack([r21,r22,r23],axis=-1)
123 | r3 = np.stack([r31,r32,r33],axis=-1)
124 | r = np.stack([r1,r2,r3],axis=0)
125 | return r
126 |
127 | def wrap2pi(rad_angle):
128 | # puts the angle into the range [-pi, pi]
129 | return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))
130 |
131 | def rot2view(rx,ry,rz,x,y,z):
132 | # takes rot angles and 3d position as input
133 | # returns viewpoint angles as output
134 | # (all in radians)
135 | # it will perform strangely if z <= 0
136 | az = wrap2pi(ry - (-np.arctan2(z, x) - 1.5*np.pi))
137 | el = -wrap2pi(rx - (-np.arctan2(z, y) - 1.5*np.pi))
138 | th = -rz
139 | return az, el, th
140 |
141 | def invAxB(a,b):
142 | """
143 | Compute the relative 3d transformation between a and b.
144 |
145 | Input:
146 | a -- first pose (homogeneous 4x4 matrix)
147 | b -- second pose (homogeneous 4x4 matrix)
148 |
149 | Output:
150 | Relative 3d transformation from a to b.
151 | """
152 | return np.dot(np.linalg.inv(a),b)
153 |
154 | def merge_rt(r, t):
155 | # r is 3 x 3
156 | # t is 3 or maybe 3 x 1
157 | t = np.reshape(t, [3, 1])
158 | rt = np.concatenate((r,t), axis=1)
159 | # rt is 3 x 4
160 | br = np.reshape(np.array([0,0,0,1], np.float32), [1, 4])
161 | # br is 1 x 4
162 | rt = np.concatenate((rt, br), axis=0)
163 | # rt is 4 x 4
164 | return rt
165 |
166 | def split_rt(rt):
167 | r = rt[:3,:3]
168 | t = rt[:3,3]
169 | r = np.reshape(r, [3, 3])
170 | t = np.reshape(t, [3, 1])
171 | return r, t
172 |
173 | def split_intrinsics(K):
174 | # K is 3 x 4 or 4 x 4
175 | fx = K[0,0]
176 | fy = K[1,1]
177 | x0 = K[0,2]
178 | y0 = K[1,2]
179 | return fx, fy, x0, y0
180 |
181 | def merge_intrinsics(fx, fy, x0, y0):
182 | # inputs are shaped []
183 | K = np.eye(4)
184 | K[0,0] = fx
185 | K[1,1] = fy
186 | K[0,2] = x0
187 | K[1,2] = y0
188 | # K is shaped 4 x 4
189 | return K
190 |
191 | def scale_intrinsics(K, sx, sy):
192 | fx, fy, x0, y0 = split_intrinsics(K)
193 | fx *= sx
194 | fy *= sy
195 | x0 *= sx
196 | y0 *= sy
197 | return merge_intrinsics(fx, fy, x0, y0)
198 |
199 | # def meshgrid(H, W):
200 | # x = np.linspace(0, W-1, W)
201 | # y = np.linspace(0, H-1, H)
202 | # xv, yv = np.meshgrid(x, y)
203 | # return xv, yv
204 |
205 | def compute_distance(transform):
206 | """
207 | Compute the distance of the translational component of a 4x4 homogeneous matrix.
208 | """
209 | return np.linalg.norm(transform[0:3,3])
210 |
211 | def radian_l1_dist(e, g):
212 | # if our angles are in [0, 360] we can follow this stack overflow answer:
213 | # https://gamedev.stackexchange.com/questions/4467/comparing-angles-and-working-out-the-difference
214 | # wrap2pi brings the angles to [-180, 180]; adding pi puts them in [0, 360]
215 | e = wrap2pi(e)+np.pi
216 | g = wrap2pi(g)+np.pi
217 | l = np.abs(np.pi - np.abs(np.abs(e-g) - np.pi))
218 | return l
219 |
220 | def apply_pix_T_cam(pix_T_cam, xyz):
221 | fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
222 | # xyz is shaped B x H*W x 3
223 | # returns xy, shaped B x H*W x 2
224 | N, C = xyz.shape
225 | x, y, z = np.split(xyz, 3, axis=-1)
226 | EPS = 1e-4
227 | z = np.clip(z, EPS, None)
228 | x = (x*fx)/(z)+x0
229 | y = (y*fy)/(z)+y0
230 | xy = np.concatenate([x, y], axis=-1)
231 | return xy
232 |
233 | def apply_4x4(RT, XYZ):
234 | # RT is 4 x 4
235 | # XYZ is N x 3
236 |
237 | # put into homogeneous coords
238 | X, Y, Z = np.split(XYZ, 3, axis=1)
239 | ones = np.ones_like(X)
240 | XYZ1 = np.concatenate([X, Y, Z, ones], axis=1)
241 | # XYZ1 is N x 4
242 |
243 | XYZ1_t = np.transpose(XYZ1)
244 | # this is 4 x N
245 |
246 | XYZ2_t = np.dot(RT, XYZ1_t)
247 | # this is 4 x N
248 |
249 | XYZ2 = np.transpose(XYZ2_t)
250 | # this is N x 4
251 |
252 | XYZ2 = XYZ2[:,:3]
253 | # this is N x 3
254 |
255 | return XYZ2
256 |
257 | def Ref2Mem(xyz, Z, Y, X):
258 | # xyz is N x 3, in ref coordinates
259 | # transforms ref coordinates into mem coordinates
260 | N, C = xyz.shape
261 | assert(C==3)
262 | mem_T_ref = get_mem_T_ref(Z, Y, X)
263 | xyz = apply_4x4(mem_T_ref, xyz)
264 | return xyz
265 |
266 | # def Mem2Ref(xyz_mem, MH, MW, MD):
267 | # # xyz is B x N x 3, in mem coordinates
268 | # # transforms mem coordinates into ref coordinates
269 | # B, N, C = xyz_mem.get_shape().as_list()
270 | # ref_T_mem = get_ref_T_mem(B, MH, MW, MD)
271 | # xyz_ref = utils_geom.apply_4x4(ref_T_mem, xyz_mem)
272 | # return xyz_ref
273 |
274 | def get_mem_T_ref(Z, Y, X):
275 | # sometimes we want the mat itself
276 | # note this is not a rigid transform
277 |
278 | # for interpretability, let's construct this in two steps...
279 |
280 | # translation
281 | center_T_ref = np.eye(4, dtype=np.float32)
282 | center_T_ref[0,3] = -XMIN
283 | center_T_ref[1,3] = -YMIN
284 | center_T_ref[2,3] = -ZMIN
285 |
286 | VOX_SIZE_X = (XMAX-XMIN)/float(X)
287 | VOX_SIZE_Y = (YMAX-YMIN)/float(Y)
288 | VOX_SIZE_Z = (ZMAX-ZMIN)/float(Z)
289 |
290 | # scaling
291 | mem_T_center = np.eye(4, dtype=np.float32)
292 | mem_T_center[0,0] = 1./VOX_SIZE_X
293 | mem_T_center[1,1] = 1./VOX_SIZE_Y
294 | mem_T_center[2,2] = 1./VOX_SIZE_Z
295 |
296 | mem_T_ref = np.dot(mem_T_center, center_T_ref)
297 | return mem_T_ref
298 |
299 | def safe_inverse(a):
300 | r, t = split_rt(a)
301 | t = np.reshape(t, [3, 1])
302 | r_transpose = r.T
303 | inv = np.concatenate([r_transpose, -np.matmul(r_transpose, t)], 1)
304 | bottom_row = a[3:4, :] # this is [0, 0, 0, 1]
305 | inv = np.concatenate([inv, bottom_row], 0)
306 | return inv
307 |
308 | def get_ref_T_mem(Z, Y, X):
309 | mem_T_ref = get_mem_T_ref(Z, Y, X)
310 | # note safe_inverse is inapplicable here,
311 | # since the transform is nonrigid
312 | ref_T_mem = np.linalg.inv(mem_T_ref)
313 | return ref_T_mem
314 |
315 | def voxelize_xyz(xyz_ref, Z, Y, X):
316 | # xyz_ref is N x 3
317 | xyz_mem = Ref2Mem(xyz_ref, Z, Y, X)
318 | # this is N x 3
319 | voxels = get_occupancy(xyz_mem, Z, Y, X)
320 | voxels = np.reshape(voxels, [Z, Y, X, 1])
321 | return voxels
322 |
323 | def get_inbounds(xyz, Z, Y, X, already_mem=False):
324 | # xyz is H*W x 3
325 |
326 | if not already_mem:
327 | xyz = Ref2Mem(xyz, Z, Y, X)
328 |
329 | x_valid = np.logical_and(
330 | np.greater_equal(xyz[:,0], -0.5),
331 | np.less(xyz[:,0], float(X)-0.5))
332 | y_valid = np.logical_and(
333 | np.greater_equal(xyz[:,1], -0.5),
334 | np.less(xyz[:,1], float(Y)-0.5))
335 | z_valid = np.logical_and(
336 | np.greater_equal(xyz[:,2], -0.5),
337 | np.less(xyz[:,2], float(Z)-0.5))
338 | inbounds = np.logical_and(np.logical_and(x_valid, y_valid), z_valid)
339 | return inbounds
340 |
341 | def sub2ind3d_zyx(depth, height, width, d, h, w):
342 | # same as sub2ind3d, but inputs in zyx order
343 | # when gathering/scattering with these inds, the tensor should be Z x Y x X
344 | return d*height*width + h*width + w
345 |
346 | def sub2ind3d_yxz(height, width, depth, h, w, d):
347 | return h*width*depth + w*depth + d
348 |
349 | # def ind2sub(height, width, ind):
350 | # # int input
351 | # y = int(ind / height)
352 | # x = ind % height
353 | # return y, x
354 |
355 | def get_occupancy(xyz_mem, Z, Y, X):
356 | # xyz_mem is N x 3
357 | # we want to fill a voxel tensor with 1's at these inds
358 |
359 | inbounds = get_inbounds(xyz_mem, Z, Y, X, already_mem=True)
360 | inds = np.where(inbounds)
361 |
362 | xyz_mem = np.reshape(xyz_mem[inds], [-1, 3])
363 | # xyz_mem is N x 3
364 |
365 | # this is more accurate than a cast/floor, but runs into issues when Y==0
366 | xyz_mem = np.round(xyz_mem).astype(np.int32)
367 | x = xyz_mem[:,0]
368 | y = xyz_mem[:,1]
369 | z = xyz_mem[:,2]
370 |
371 | voxels = np.zeros([Z, Y, X], np.float32)
372 | voxels[z, y, x] = 1.0
373 |
374 | return voxels
375 |
376 | def pixels2camera(x,y,z,fx,fy,x0,y0):
377 | # x and y are locations in pixel coordinates, z is a depth image in meters
378 | # their shapes are H x W
379 | # fx, fy, x0, y0 are scalar camera intrinsics
380 | # returns xyz, shaped H*W x 3
381 |
382 | H, W = z.shape
383 |
384 | fx = np.reshape(fx, [1,1])
385 | fy = np.reshape(fy, [1,1])
386 | x0 = np.reshape(x0, [1,1])
387 | y0 = np.reshape(y0, [1,1])
388 |
389 | # unproject
390 | x = ((z+EPS)/fx)*(x-x0)
391 | y = ((z+EPS)/fy)*(y-y0)
392 |
393 | x = np.reshape(x, [-1])
394 | y = np.reshape(y, [-1])
395 | z = np.reshape(z, [-1])
396 | xyz = np.stack([x,y,z], axis=1)
397 | return xyz
398 |
399 | def depth2pointcloud(z, pix_T_cam):
400 | H = z.shape[0]
401 | W = z.shape[1]
402 | y, x = meshgrid2d(H, W)
403 | z = np.reshape(z, [H, W])
404 |
405 | fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
406 | xyz = pixels2camera(x, y, z, fx, fy, x0, y0)
407 | return xyz
408 |
409 | def meshgrid2d(Y, X):
410 | grid_y = np.linspace(0.0, Y-1, Y)
411 | grid_y = np.reshape(grid_y, [Y, 1])
412 | grid_y = np.tile(grid_y, [1, X])
413 |
414 | grid_x = np.linspace(0.0, X-1, X)
415 | grid_x = np.reshape(grid_x, [1, X])
416 | grid_x = np.tile(grid_x, [Y, 1])
417 |
418 | # outputs are Y x X
419 | return grid_y, grid_x
420 |
421 | def gridcloud3d(Y, X, Z):
422 | x_ = np.linspace(0, X-1, X)
423 | y_ = np.linspace(0, Y-1, Y)
424 | z_ = np.linspace(0, Z-1, Z)
425 | y, x, z = np.meshgrid(y_, x_, z_, indexing='ij')
426 | x = np.reshape(x, [-1])
427 | y = np.reshape(y, [-1])
428 | z = np.reshape(z, [-1])
429 | xyz = np.stack([x,y,z], axis=1).astype(np.float32)
430 | return xyz
431 |
432 | def gridcloud2d(Y, X):
433 | x_ = np.linspace(0, X-1, X)
434 | y_ = np.linspace(0, Y-1, Y)
435 | y, x = np.meshgrid(y_, x_, indexing='ij')
436 | x = np.reshape(x, [-1])
437 | y = np.reshape(y, [-1])
438 | xy = np.stack([x,y], axis=1).astype(np.float32)
439 | return xy
440 |
441 | def normalize(im):
442 | im = im - np.min(im)
443 | im = im / np.max(im)
444 | return im
445 |
446 | def wrap2pi(rad_angle):
447 | # rad_angle can be any shape
448 | # puts the angle into the range [-pi, pi]
449 | return np.arctan2(np.sin(rad_angle), np.cos(rad_angle))
450 |
451 | def convert_occ_to_height(occ):
452 | Z, Y, X, C = occ.shape
453 | assert(C==1)
454 |
455 | height = np.linspace(float(Y), 1.0, Y)
456 | height = np.reshape(height, [1, Y, 1, 1])
457 | height = np.max(occ*height, axis=1)/float(Y)
458 | height = np.reshape(height, [Z, X, C])
459 | return height
460 |
461 | def create_depth_image(xy, Z, H, W):
462 |
463 | # turn the xy coordinates into image inds
464 | xy = np.round(xy)
465 |
466 | # lidar reports a sphere of measurements
467 | # only use the inds that are within the image bounds
468 | # also, only use forward-pointing depths (Z > 0)
469 | valid = (xy[:,0] < W-1) & (xy[:,1] < H-1) & (xy[:,0] >= 0) & (xy[:,1] >= 0) & (Z[:] > 0)
470 |
471 | # gather these up
472 | xy = xy[valid]
473 | Z = Z[valid]
474 |
475 | inds = sub2ind(H,W,xy[:,1],xy[:,0])
476 | depth = np.zeros((H*W), np.float32)
477 |
478 | for (index, replacement) in zip(inds, Z):
479 | depth[index] = replacement
480 | depth[np.where(depth == 0.0)] = 70.0
481 | depth = np.reshape(depth, [H, W])
482 |
483 | return depth
484 |
485 | def vis_depth(depth, maxdepth=80.0, log_vis=True):
486 | depth[depth<=0.0] = maxdepth
487 | if log_vis:
488 | depth = np.log(depth)
489 | depth = np.clip(depth, 0, np.log(maxdepth))
490 | else:
491 | depth = np.clip(depth, 0, maxdepth)
492 | depth = (depth*255.0).astype(np.uint8)
493 | return depth
494 |
495 | def preprocess_color(x):
496 | return x.astype(np.float32) * 1./255 - 0.5
497 |
498 | def convert_box_to_ref_T_obj(boxes):
499 | shape = boxes.shape
500 | boxes = boxes.reshape(-1,9)
501 | rots = [eul2rotm(rx,ry,rz)
502 | for rx,ry,rz in boxes[:,6:]]
503 | rots = np.stack(rots,axis=0)
504 | trans = boxes[:,:3]
505 | ref_T_objs = [merge_rt(rot,tran)
506 | for rot,tran in zip(rots,trans)]
507 | ref_T_objs = np.stack(ref_T_objs,axis=0)
508 | ref_T_objs = ref_T_objs.reshape(shape[:-1]+(4,4))
509 | ref_T_objs = ref_T_objs.astype(np.float32)
510 | return ref_T_objs
511 |
512 | def get_rot_from_delta(delta, yaw_only=False):
513 | dx = delta[:,0]
514 | dy = delta[:,1]
515 | dz = delta[:,2]
516 |
517 | bot_hyp = np.sqrt(dz**2 + dx**2)
518 | # top_hyp = np.sqrt(bot_hyp**2 + dy**2)
519 |
520 | pitch = -np.arctan2(dy, bot_hyp)
521 | yaw = np.arctan2(dz, dx)
522 |
523 | if yaw_only:
524 | rot = [eul2rotm(0,y,0) for y in yaw]
525 | else:
526 | rot = [eul2rotm(0,y,p) for (p,y) in zip(pitch,yaw)]
527 |
528 | rot = np.stack(rot)
529 | # rot is B x 3 x 3
530 | return rot
531 |
532 | def im2col(im, psize):
533 | n_channels = 1 if len(im.shape) == 2 else im.shape[0]
534 | (n_channels, rows, cols) = (1,) * (3 - len(im.shape)) + im.shape
535 |
536 | im_pad = np.zeros((n_channels,
537 | int(math.ceil(1.0 * rows / psize) * psize),
538 | int(math.ceil(1.0 * cols / psize) * psize)))
539 | im_pad[:, 0:rows, 0:cols] = im
540 |
541 | final = np.zeros((im_pad.shape[1], im_pad.shape[2], n_channels,
542 | psize, psize))
543 | for c in np.arange(n_channels):
544 | for x in np.arange(psize):
545 | for y in np.arange(psize):
546 | im_shift = np.vstack(
547 | (im_pad[c, x:], im_pad[c, :x]))
548 | im_shift = np.column_stack(
549 | (im_shift[:, y:], im_shift[:, :y]))
550 | final[x::psize, y::psize, c] = np.swapaxes(
551 | im_shift.reshape(int(im_pad.shape[1] / psize), psize,
552 | int(im_pad.shape[2] / psize), psize), 1, 2)
553 |
554 | return np.squeeze(final[0:rows - psize + 1, 0:cols - psize + 1])
555 |
556 | def filter_discontinuities(depth, filter_size=9, thresh=10):
557 | H, W = list(depth.shape)
558 |
559 | # Ensure that filter sizes are okay
560 | assert filter_size % 2 == 1, "Can only use odd filter sizes."
561 |
562 | # Compute discontinuities
563 | offset = int((filter_size - 1) / 2)
564 | patches = 1.0 * im2col(depth, filter_size)
565 | mids = patches[:, :, offset, offset]
566 | mins = np.min(patches, axis=(2, 3))
567 | maxes = np.max(patches, axis=(2, 3))
568 |
569 | discont = np.maximum(np.abs(mins - mids),
570 | np.abs(maxes - mids))
571 | mark = discont > thresh
572 |
573 | # Account for offsets
574 | final_mark = np.zeros((H, W), dtype=np.uint16)
575 | final_mark[offset:offset + mark.shape[0],
576 | offset:offset + mark.shape[1]] = mark
577 |
578 | return depth * (1 - final_mark)
579 |
580 | def argmax2d(tensor):
581 | Y, X = list(tensor.shape)
582 | # flatten the Tensor along the height and width axes
583 | flat_tensor = tensor.reshape(-1)
584 | # argmax of the flat tensor
585 | argmax = np.argmax(flat_tensor)
586 |
587 | # convert the indices into 2d coordinates
588 | argmax_y = argmax // X # row
589 | argmax_x = argmax % X # col
590 |
591 | return argmax_y, argmax_x
592 |
593 | def plot_traj_3d(traj):
594 | # traj is S x 3
595 |
596 | # print('traj', traj.shape)
597 | S, C = list(traj.shape)
598 | assert(C==3)
599 |
600 | fig = plt.figure()
601 | ax = fig.add_subplot(111, projection='3d')
602 |
603 | colors = [plt.cm.RdYlBu(i) for i in np.linspace(0,1,S)]
604 | # print('colors', colors)
605 |
606 | xs = traj[:,0]
607 | ys = -traj[:,1]
608 | zs = traj[:,2]
609 |
610 |     ax.scatter(xs, zs, ys, s=30, c=colors, marker='o', alpha=1.0, edgecolors=(0,0,0))
611 |
612 | ax.set_xlabel('X')
613 | ax.set_ylabel('Z')
614 | ax.set_zlabel('Y')
615 |
616 | ax.set_xlim(0,1)
617 | ax.set_ylim(0,1) # this is really Z
618 | ax.set_zlim(-1,0) # this is really Y
619 |
620 | buf = io.BytesIO()
621 | plt.savefig(buf, format='png')
622 | buf.seek(0)
623 | image = np.array(Image.open(buf)) # H x W x 4
624 | image = image[:,:,:3]
625 |
626 | plt.close()
627 | return image
628 |
629 | def camera2pixels(xyz, pix_T_cam):
630 | # xyz is shaped N x 3
631 | # returns xy, shaped N x 2
632 |
633 | fx, fy, x0, y0 = split_intrinsics(pix_T_cam)
634 | x, y, z = xyz[:,0], xyz[:,1], xyz[:,2]
635 |
636 | EPS = 1e-4
637 | z = np.clip(z, EPS, None)
638 | x = (x*fx)/z + x0
639 | y = (y*fy)/z + y0
640 | xy = np.stack([x, y], axis=-1)
641 | return xy
642 |
643 | def make_colorwheel():
644 | """
645 | Generates a color wheel for optical flow visualization as presented in:
646 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
647 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
648 |
649 | Code follows the original C++ source code of Daniel Scharstein.
650 |     Code follows the Matlab source code of Deqing Sun.
651 |
652 | Returns:
653 | np.ndarray: Color wheel
654 | """
655 |
656 | RY = 15
657 | YG = 6
658 | GC = 4
659 | CB = 11
660 | BM = 13
661 | MR = 6
662 |
663 | ncols = RY + YG + GC + CB + BM + MR
664 | colorwheel = np.zeros((ncols, 3))
665 | col = 0
666 |
667 | # RY
668 | colorwheel[0:RY, 0] = 255
669 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
670 | col = col+RY
671 | # YG
672 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
673 | colorwheel[col:col+YG, 1] = 255
674 | col = col+YG
675 | # GC
676 | colorwheel[col:col+GC, 1] = 255
677 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
678 | col = col+GC
679 | # CB
680 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
681 | colorwheel[col:col+CB, 2] = 255
682 | col = col+CB
683 | # BM
684 | colorwheel[col:col+BM, 2] = 255
685 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
686 | col = col+BM
687 | # MR
688 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
689 | colorwheel[col:col+MR, 0] = 255
690 | return colorwheel
691 |
692 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
693 | """
694 | Applies the flow color wheel to (possibly clipped) flow components u and v.
695 |
696 | According to the C++ source code of Daniel Scharstein
697 | According to the Matlab source code of Deqing Sun
698 |
699 | Args:
700 | u (np.ndarray): Input horizontal flow of shape [H,W]
701 | v (np.ndarray): Input vertical flow of shape [H,W]
702 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
703 |
704 | Returns:
705 | np.ndarray: Flow visualization image of shape [H,W,3]
706 | """
707 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
708 | colorwheel = make_colorwheel() # shape [55x3]
709 | ncols = colorwheel.shape[0]
710 | rad = np.sqrt(np.square(u) + np.square(v))
711 | a = np.arctan2(-v, -u)/np.pi
712 | fk = (a+1) / 2*(ncols-1)
713 | k0 = np.floor(fk).astype(np.int32)
714 | k1 = k0 + 1
715 | k1[k1 == ncols] = 0
716 | f = fk - k0
717 | for i in range(colorwheel.shape[1]):
718 | tmp = colorwheel[:,i]
719 | col0 = tmp[k0] / 255.0
720 | col1 = tmp[k1] / 255.0
721 | col = (1-f)*col0 + f*col1
722 | idx = (rad <= 1)
723 | col[idx] = 1 - rad[idx] * (1-col[idx])
724 | col[~idx] = col[~idx] * 0.75 # out of range
725 | # Note the 2-i => BGR instead of RGB
726 | ch_idx = 2-i if convert_to_bgr else i
727 | flow_image[:,:,ch_idx] = np.floor(255 * col)
728 | return flow_image
729 |
730 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
731 | """
732 |     Expects a 2D flow field as an array of shape [H,W,2].
733 |
734 | Args:
735 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
736 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
737 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
738 |
739 | Returns:
740 | np.ndarray: Flow visualization image of shape [H,W,3]
741 | """
742 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
743 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
744 | if clip_flow is not None:
745 | flow_uv = np.clip(flow_uv, -clip_flow, clip_flow) / clip_flow
746 |         # (clipping also rescales the flow into [-1,1])
747 |
748 | u = flow_uv[:,:,0]
749 | v = flow_uv[:,:,1]
750 | rad = np.sqrt(np.square(u) + np.square(v))
751 | rad_max = np.max(rad)
752 | epsilon = 1e-5
753 | u = u / (rad_max + epsilon)
754 | v = v / (rad_max + epsilon)
755 | return flow_uv_to_colors(u, v, convert_to_bgr)
756 |
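757 | 
758 | # Minimal usage sketch (added for illustration, not part of the original utilities):
759 | # visualize a synthetic flow field with flow_to_image. H and W are arbitrary assumptions,
760 | # and the snippet relies on the numpy import at the top of this file.
761 | if __name__ == "__main__":
762 |     H, W = 64, 96
763 |     yy, xx = np.meshgrid(np.arange(H, dtype=np.float32), np.arange(W, dtype=np.float32), indexing='ij')
764 |     flow = np.stack([xx - W / 2.0, yy - H / 2.0], axis=-1)  # H,W,2 flow pointing away from the image center
765 |     vis = flow_to_image(flow)  # H,W,3 uint8 color visualization
766 |     print('flow vis', vis.shape, vis.dtype)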
--------------------------------------------------------------------------------
/train_stage1.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import argparse
7 | from lightning_fabric import Fabric
8 | import utils.loss
9 | import utils.samp
10 | import utils.data
11 | import utils.improc
12 | import utils.misc
13 | import utils.saveload
14 | from tensorboardX import SummaryWriter
15 | import datetime
16 | import time
17 | import utils.basic # used below for gridcloud2d, normalize, print_stats, get_lr_str
18 |
19 | torch.set_float32_matmul_precision('medium')
20 |
21 | def get_parameter_names(model, forbidden_layer_types):
22 | """
23 | Returns the names of the model parameters that are not inside a forbidden layer.
24 | """
25 | result = []
26 | for name, child in model.named_children():
27 | result += [
28 | f"{name}.{n}"
29 | for n in get_parameter_names(child, forbidden_layer_types)
30 | if not isinstance(child, tuple(forbidden_layer_types))
31 | ]
32 | result += list(model._parameters.keys())
33 | return result
34 |
35 | def get_sparse_dataset(args, crop_size, N, T, random_first=False, version='kubric_au'):
36 | from datasets import kubric_movif_dataset
37 | dataset = kubric_movif_dataset.KubricMovifDataset(
38 | data_root=os.path.join(args.data_dir, version),
39 | crop_size=crop_size,
40 | seq_len=T,
41 | traj_per_sample=N,
42 | use_augs=args.use_augs,
43 | random_seq_len=args.random_seq_len,
44 | random_first_frame=random_first,
45 | random_frame_rate=args.random_frame_rate,
46 | random_number_traj=args.random_number_traj,
47 | shuffle_frames=args.shuffle_frames,
48 | shuffle=True,
49 | only_first=True,
50 | )
51 | dataset_names = [dataset.dname]
52 | return dataset, dataset_names
53 |
54 | def create_pools(n_pool=50, min_size=10):
55 | pools = {}
56 |
57 | n_pool = max(n_pool, 10)
58 |
59 | thrs = [1,2,4,8,16]
60 | for thr in thrs:
61 | pools['d_%d' % thr] = utils.misc.SimplePool(n_pool, version='np', min_size=min_size)
62 | pools['d_avg'] = utils.misc.SimplePool(n_pool, version='np', min_size=min_size)
63 |
64 | pool_names = [
65 | 'seq_loss_visible',
66 | 'seq_loss_invisible',
67 | 'vis_loss',
68 | 'conf_loss',
69 | 'total_loss',
70 | ]
71 | for pool_name in pool_names:
72 | pools[pool_name] = utils.misc.SimplePool(n_pool, version='np', min_size=min_size)
73 |
74 | return pools
75 |
76 | def fetch_optimizer(args, model):
77 | """Create the optimizer and learning rate scheduler"""
78 |
79 | total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
80 | print(f"Total number of parameters: {total_params}")
81 |
82 | decay_parameters = get_parameter_names(model, [torch.nn.LayerNorm])
83 | decay_parameters = [name for name in decay_parameters if not ("bias" in name)]
84 | nondecay_parameters = [n for n, p in model.named_parameters() if n not in decay_parameters]
85 | optimizer_grouped_parameters = [
86 | {
87 | "params": [p for n, p in model.named_parameters() if n in decay_parameters],
88 | "lr": args.lr,
89 | "weight_decay": args.wdecay,
90 | },
91 | {
92 | "params": [p for n, p in model.named_parameters() if n in nondecay_parameters],
93 | "lr": args.lr,
94 | "weight_decay": 0.0,
95 | },
96 | ]
97 | optimizer = torch.optim.AdamW(params=optimizer_grouped_parameters, lr=args.lr, weight_decay=args.wdecay)
98 |
99 | if args.use_scheduler:
100 | scheduler = torch.optim.lr_scheduler.OneCycleLR(
101 | optimizer,
102 | args.lr,
103 | args.max_steps+100,
104 | pct_start=0.05,
105 | cycle_momentum=False,
106 | anneal_strategy="cos",
107 | )
108 | else:
109 | scheduler = None
110 |
111 | return optimizer, scheduler
112 |
113 | def forward_batch_sparse(batch, model, args, sw, inference_iters):
114 | rgbs = batch.video
115 | trajs_g = batch.trajs
116 |     vis_g = batch.visibs # B,T,N
117 | valids = batch.valids
118 | dname = batch.dname
119 |
120 | B, T, C, H, W = rgbs.shape
121 | assert C == 3
122 | B, T, N, D = trajs_g.shape
123 | device = rgbs.device
124 | S = 16
125 |
126 | if N==0 or torch.any(vis_g[:,0].reshape(-1)==0):
127 | print('vis_g', vis_g.shape)
128 | return None, None
129 |
130 | all_flow_e, all_visconf_e, all_flow_preds, all_visconf_preds = model(rgbs, iters=inference_iters, sw=sw, is_training=True)
131 |
132 | grid_xy = utils.basic.gridcloud2d(1, H, W, norm=False, device=device).float() # 1,H*W,2
133 | grid_xy = grid_xy.permute(0,2,1).reshape(1,1,2,H,W) # 1,1,2,H,W
134 | traj_maps_e = all_flow_e + grid_xy # B,T,2,H,W
135 | xy0 = trajs_g[:,0] # B,N,2
136 | xy0[:,:,0] = xy0[:,:,0].clamp(0,W-1)
137 | xy0[:,:,1] = xy0[:,:,1].clamp(0,H-1)
138 |
139 | traj_maps_e_ = traj_maps_e.reshape(B*T,2,H,W)
140 | xy0_ = xy0.reshape(B,1,N,2).repeat(1,T,1,1).reshape(B*T,N,1,2)
141 | trajs_e_ = utils.samp.bilinear_sampler(traj_maps_e_, xy0_) # B*T,2,N,1
142 | trajs_e = trajs_e_.reshape(B,T,2,N).permute(0,1,3,2) # B,T,N,2
143 |
144 | xy0_ = xy0.reshape(B,1,N,2).repeat(1,S,1,1).reshape(B*S,N,1,2)
145 |
146 | coord_predictions = []
147 | assert(B==1)
148 | for fpl in all_flow_preds:
149 | cps = []
150 | for fp in fpl:
151 | traj_map = fp + grid_xy # B,S,2,H,W
152 | traj_map_ = traj_map.reshape(B*S,2,H,W)
153 |             traj_e_ = utils.samp.bilinear_sampler(traj_map_, xy0_) # B*S,2,N,1
154 | traj_e = traj_e_.reshape(B,S,2,N).permute(0,1,3,2) # B,S,N,2
155 | cps.append(traj_e)
156 | coord_predictions.append(cps)
157 |
158 |     # bilinear sampling can push the visconf values outside [0,1],
159 |     # so we use nearest-neighbor sampling here instead
160 | assert(B==1)
161 | x0, y0 = xy0[0,:,0], xy0[0,:,1] # N
162 | x0 = torch.clamp(x0, 0, W-1).round().long()
163 | y0 = torch.clamp(y0, 0, H-1).round().long()
164 | vis_predictions, confidence_predictions = [], []
165 | for vcl in all_visconf_preds:
166 | vps = []
167 | cps = []
168 | for vc in vcl:
169 | vc = vc[:,:,:,y0,x0] # B,S,2,N
170 | vps.append(vc[:,:,0]) # B,S,N
171 | cps.append(vc[:,:,1]) # B,S,N
172 | vis_predictions.append(vps)
173 | confidence_predictions.append(cps)
174 |
175 | vis_gts = []
176 | invis_gts = []
177 | traj_gts = []
178 | valids_gts = []
179 |
180 | for ind in range(0, T - S // 2, S // 2):
181 | vis_gts.append(vis_g[:, ind : ind + S])
182 | invis_gts.append(1 - vis_g[:, ind : ind + S])
183 | traj_gts.append(trajs_g[:, ind : ind + S, :, :2])
184 | valids_gts.append(valids[:, ind : ind + S])
185 |
186 | total_loss = torch.tensor(0.0, requires_grad=True, device=device)
187 | metrics = {}
188 |
189 | seq_loss_visible = utils.loss.sequence_loss(
190 | coord_predictions,
191 | traj_gts,
192 | valids_gts,
193 | vis=vis_gts,
194 | gamma=0.8,
195 | use_huber_loss=args.use_huber_loss,
196 | loss_only_for_visible=True,
197 | )
198 | confidence_loss = utils.loss.sequence_prob_loss(
199 | coord_predictions, confidence_predictions, traj_gts, vis_gts
200 | )
201 | vis_loss = utils.loss.sequence_BCE_loss(vis_predictions, vis_gts)
202 |
203 | seq_loss_invisible = utils.loss.sequence_loss(
204 | coord_predictions,
205 | traj_gts,
206 | valids_gts,
207 | vis=invis_gts,
208 | gamma=0.8,
209 | use_huber_loss=args.use_huber_loss,
210 | loss_only_for_visible=True,
211 | )
212 |
213 | total_loss = seq_loss_visible.mean()*0.05 + seq_loss_invisible.mean()*0.01 + vis_loss.mean() + confidence_loss.mean()
214 |
215 | if sw is not None and sw.save_scalar:
216 | dn = dname[0]
217 | metrics['dname'] = dn
218 | metrics['seq_loss_visible'] = seq_loss_visible.mean().item()
219 | metrics['seq_loss_invisible'] = seq_loss_invisible.mean().item()
220 | metrics['vis_loss'] = vis_loss.item()
221 | metrics['conf_loss'] = confidence_loss.mean().item()
222 | thrs = [1,2,4,8,16]
223 | sx_ = (W-1) / 255.0
224 | sy_ = (H-1) / 255.0
225 | sc_py = np.array([sx_, sy_]).reshape([1,1,1,2])
226 | sc_pt = torch.from_numpy(sc_py).float().to(device)
227 | d_sum = 0.0
228 | for thr in thrs:
229 | d_ = (torch.norm(trajs_e/sc_pt - trajs_g/sc_pt, dim=-1) < thr).float().mean().item()
230 | d_sum += d_
231 | metrics['d_%d' % thr] = d_
232 | d_avg = d_sum / len(thrs)
233 | metrics['d_avg'] = d_avg
234 | metrics['total_loss'] = total_loss.item()
235 |
236 | if sw is not None and sw.save_this:
237 | utils.basic.print_stats('rgbs', rgbs)
238 | prep_rgbs = utils.basic.normalize(rgbs[0:1])-0.5
239 | prep_grays = prep_rgbs.mean(dim=2, keepdim=True).repeat(1,1,3,1,1)
240 | sw.summ_rgb('0_inputs/rgb0', prep_rgbs[:,0], frame_str=dname[0], frame_id=torch.sum(vis_g[0,0]).item())
241 | sw.summ_traj2ds_on_rgb('0_inputs/trajs_g_on_rgb0', trajs_g[0:1], prep_rgbs[:,0], cmap='winter', linewidth=1)
242 | trajs_clamp = trajs_g.clone()
243 | trajs_clamp[:,:,:,0] = trajs_clamp[:,:,:,0].clip(0,W-1)
244 | trajs_clamp[:,:,:,1] = trajs_clamp[:,:,:,1].clip(0,H-1)
245 | inds = np.random.choice(trajs_g.shape[2], 1024)
246 | outs = sw.summ_pts_on_rgbs(
247 | '',
248 | trajs_clamp[0:1,:,inds],
249 | prep_grays[0:1],
250 | valids=valids[0:1,:,inds],
251 | cmap='winter', linewidth=3, only_return=True)
252 | sw.summ_pts_on_rgbs(
253 | '0_inputs/kps_gv_on_rgbs',
254 | trajs_clamp[0:1,:,inds],
255 | utils.improc.preprocess_color(outs),
256 | valids=valids[0:1,:,inds]*vis_g[0:1,:,inds],
257 | cmap='spring', linewidth=2,
258 | frame_ids=list(range(T)))
259 |
260 | out = utils.improc.preprocess_color(sw.summ_traj2ds_on_rgb('', trajs_g[0:1,:,inds], prep_rgbs[:,0], cmap='winter', linewidth=1, only_return=True))
261 | sw.summ_traj2ds_on_rgb('2_outputs/trajs_e_on_g', trajs_e[0:1,:,inds], out, cmap='spring', linewidth=1)
262 |
263 | trajs_e_clamp = trajs_e.clone()
264 | trajs_e_clamp[:,:,:,0] = trajs_e_clamp[:,:,:,0].clip(0,W-1)
265 | trajs_e_clamp[:,:,:,1] = trajs_e_clamp[:,:,:,1].clip(0,H-1)
266 | inds_e = np.random.choice(trajs_e.shape[2], 1024)
267 | outs = sw.summ_pts_on_rgbs(
268 | '',
269 | trajs_clamp[0:1,:,inds],
270 | prep_grays[0:1],
271 | valids=valids[0:1,:,inds]*vis_g[0:1,:,inds],
272 | cmap='winter', linewidth=2,
273 | only_return=True)
274 | sw.summ_pts_on_rgbs(
275 | '2_outputs/kps_ge_on_rgbs',
276 | trajs_e_clamp[0:1,:,inds],
277 | utils.improc.preprocess_color(outs),
278 | valids=valids[0:1,:,inds]*vis_g[0:1,:,inds],
279 | cmap='spring', linewidth=2,
280 | frame_ids=list(range(T)))
281 |
282 | return total_loss, metrics
283 |
284 |
285 | def run(model, args):
286 | fabric = Fabric(
287 | devices="auto",
288 | num_nodes=1,
289 | strategy="ddp",
290 | accelerator="cuda",
291 | precision="bf16-mixed" if args.mixed_precision else "32-true",
292 | )
293 | fabric.launch() # enable multi-gpu
294 |
295 | def seed_everything(seed: int):
296 | random.seed(seed)
297 | os.environ["PYTHONHASHSEED"] = str(seed)
298 | np.random.seed(seed)
299 | torch.manual_seed(seed)
300 | torch.cuda.manual_seed(seed)
301 | torch.backends.cudnn.deterministic = True
302 | torch.backends.cudnn.benchmark = False
303 | def seed_worker(worker_id):
304 | worker_seed = torch.initial_seed() % 2**32
305 | np.random.seed(worker_seed + worker_id)
306 | random.seed(worker_seed + worker_id)
307 | random_data = os.urandom(4)
308 | seed = int.from_bytes(random_data, byteorder="big")
309 | seed_everything(seed)
310 | g = torch.Generator()
311 | g.manual_seed(seed)
312 |
313 | B_ = args.batch_size * torch.cuda.device_count()
314 | model_name = "%d" % (B_)
315 | if args.use_augs:
316 | model_name += "A"
317 | if args.only_short:
318 | model_name += "i%d" % (args.inference_iters_24)
319 | elif args.no_short:
320 | model_name += "i%d" % (args.inference_iters_56)
321 | else:
322 | model_name += "i%d" % (args.inference_iters_24)
323 | model_name += "i%d" % (args.inference_iters_56)
324 | lrn = utils.basic.get_lr_str(args.lr) # e.g., 5e-4
325 | model_name += "_%s" % lrn
326 | if args.mixed_precision:
327 | model_name += "m"
328 | if args.use_huber_loss:
329 | model_name += "h"
330 | model_name += "_%s" % args.exp
331 | model_date = datetime.datetime.now().strftime('%M%S')
332 | model_name = model_name + '_' + model_date
333 | print('model_name', model_name)
334 |
335 | save_dir = '%s/%s' % (args.ckpt_dir, model_name)
336 |
337 | model.cuda()
338 |
339 | dataset_names = []
340 |
341 | if fabric.world_size==2:
342 | if args.only_short:
343 | ranks_56 = []
344 | ranks_24 = [0,1]
345 | log_ranks = [0]
346 | elif args.no_short:
347 | ranks_56 = [0,1]
348 | ranks_24 = []
349 | log_ranks = [0]
350 | else:
351 | ranks_56 = [0]
352 | ranks_24 = [1]
353 | log_ranks = [0,1]
354 | elif fabric.world_size==8:
355 | ranks_56 = [0,1,2,3]
356 | ranks_24 = [4,5,6,7]
357 | log_ranks = [0,4]
358 | else:
359 | ranks_24 = [0]
360 | ranks_56 = []
361 | log_ranks = [0]
362 | print('assuming we are debugging with 1 gpu...')
363 |
364 | if fabric.global_rank in ranks_56:
365 | sparse_dataset56, sparse_dataset_names56 = get_sparse_dataset(
366 | args, crop_size=args.crop_size_56, T=56, N=args.traj_per_sample_56, random_first=False, version='ce64/kublong')
367 | sparse_loader56 = torch.utils.data.DataLoader(
368 | sparse_dataset56,
369 | batch_size=args.batch_size,
370 | shuffle=True,
371 | num_workers=args.num_workers_56,
372 | worker_init_fn=seed_worker,
373 | generator=g,
374 | pin_memory=False,
375 | collate_fn=utils.data.collate_fn_train,
376 | drop_last=True,
377 | )
378 | sparse_loader56 = fabric.setup_dataloaders(sparse_loader56, move_to_device=False)
379 | print('len(sparse_loader56)', len(sparse_loader56))
380 | sparse_iterloader56 = iter(sparse_loader56)
381 | dataset_names += sparse_dataset_names56
382 | else:
383 | sparse_dataset24, sparse_dataset_names24 = get_sparse_dataset(
384 | args, crop_size=args.crop_size_24, T=24, N=args.traj_per_sample_24, random_first=args.random_first_frame, version='kubric_au')
385 | sparse_loader24 = torch.utils.data.DataLoader(
386 | sparse_dataset24,
387 | batch_size=args.batch_size,
388 | shuffle=True,
389 | num_workers=args.num_workers_24,
390 | worker_init_fn=seed_worker,
391 | generator=g,
392 | pin_memory=False,
393 | collate_fn=utils.data.collate_fn_train,
394 | drop_last=True,
395 | )
396 | sparse_loader24 = fabric.setup_dataloaders(sparse_loader24, move_to_device=False)
397 | print('len(sparse_loader24)', len(sparse_loader24))
398 | sparse_iterloader24 = iter(sparse_loader24)
399 | dataset_names += sparse_dataset_names24
400 |
401 | optimizer, scheduler = fetch_optimizer(args, model)
402 |
403 | if fabric.global_rank in log_ranks:
404 | log_dir = './logs_train'
405 | pools_t = {}
406 | for dname in dataset_names:
407 | if not (dname in pools_t):
408 | print('creating pools for', dname)
409 | pools_t[dname] = create_pools()
410 | overpools_t = create_pools()
411 | writer_t = SummaryWriter(log_dir + '/' + model_name + '/t', max_queue=10, flush_secs=60)
412 |
413 | global_step = 0
414 | if args.init_dir:
415 | load_dir = '%s/%s' % (args.ckpt_dir, args.init_dir)
416 | loaded_global_step = utils.saveload.load(
417 | fabric,
418 | load_dir,
419 | model,
420 | optimizer=optimizer if args.load_optimizer else None,
421 | scheduler=scheduler if args.load_scheduler else None,
422 | ignore_load=None,
423 | strict=True)
424 | if args.load_optimizer and not args.use_scheduler:
425 | assert(optimizer.param_groups[-1]["lr"] == args.lr)
426 | if args.load_step:
427 | global_step = loaded_global_step
428 | if args.use_scheduler and args.load_step and (not args.load_optimizer):
429 | # advance the scheduler to catch up with global_step
430 | for ii in range(global_step):
431 | scheduler.step()
432 |
433 | model, optimizer = fabric.setup(model, optimizer, move_to_device=False)
434 | model.train()
435 |
436 | while global_step < args.max_steps+10:
437 | global_step += 1
438 |
439 | f_start_time = time.time()
440 |
441 | optimizer.zero_grad(set_to_none=True)
442 | assert model.training
443 |
444 | if fabric.global_rank in log_ranks:
445 | sw_t = utils.improc.Summ_writer(
446 | writer=writer_t,
447 | global_step=global_step,
448 | log_freq=args.log_freq,
449 | fps=8,
450 | scalar_freq=min(53,args.log_freq),
451 | just_gif=True)
452 | else:
453 | sw_t = None
454 |
455 | metrics = None
456 |
457 | gotit = [False, False]
458 | while not all(gotit):
459 | if fabric.global_rank in ranks_56:
460 | try:
461 | batch = next(sparse_iterloader56)
462 | except StopIteration:
463 | sparse_iterloader56 = iter(sparse_loader56)
464 | batch = next(sparse_iterloader56)
465 | else:
466 | try:
467 | batch = next(sparse_iterloader24)
468 | except StopIteration:
469 | sparse_iterloader24 = iter(sparse_loader24)
470 | batch = next(sparse_iterloader24)
471 | batch, gotit = batch
472 |
473 | rtime = time.time()-f_start_time
474 |
475 | if fabric.global_rank in ranks_56:
476 | inference_iters = args.inference_iters_56
477 | else:
478 | inference_iters = args.inference_iters_24
479 |
480 | utils.data.dataclass_to_cuda_(batch)
481 | total_loss, metrics = forward_batch_sparse(batch, model, args, sw_t, inference_iters)
482 | ftime = time.time()-f_start_time
483 | fabric.barrier() # wait for all gpus to finish their fwd
484 |
485 | b_start_time = time.time()
486 | if metrics is not None:
487 | if fabric.global_rank in log_ranks and sw_t.save_scalar:
488 | sw_t.summ_scalar('_/current_lr', optimizer.param_groups[-1]["lr"])
489 | sw_t.summ_scalar('total_loss', metrics['total_loss'])
490 |
491 | # update stats
492 | dname = metrics['dname']
493 | if dname in dataset_names:
494 | for key in list(pools_t[dname].keys()):
495 | if key in metrics:
496 | pools_t[dname][key].update([metrics[key]])
497 | overpools_t[key].update([metrics[key]])
498 | # plot stats
499 | for key in list(overpools_t.keys()):
500 | for dname in dataset_names:
501 | sw_t.summ_scalar('%s/%s' % (dname, key), pools_t[dname][key].mean())
502 | sw_t.summ_scalar('_/%s' % (key), overpools_t[key].mean())
503 |
504 | if args.mixed_precision:
505 | fabric.backward(total_loss)
506 | fabric.clip_gradients(model, optimizer, max_norm=1.0, norm_type=2, error_if_nonfinite=False)
507 | optimizer.step()
508 | if args.use_scheduler:
509 | scheduler.step()
510 | else:
511 | fabric.backward(total_loss)
512 | fabric.clip_gradients(model, optimizer, max_norm=1.0, norm_type=2, error_if_nonfinite=False)
513 | optimizer.step()
514 | if args.use_scheduler:
515 | scheduler.step()
516 | btime = time.time()-b_start_time
517 | fabric.barrier() # wait for all gpus to finish their bwd
518 |
519 | itime = ftime + btime
520 |
521 | if global_step % args.save_freq == 0:
522 | if fabric.global_rank == 0:
523 | utils.saveload.save(save_dir, model.module, optimizer, scheduler, global_step, keep_latest=2)
524 |
525 | info_str = '%s; step %06d/%d; rtime %.2f; itime %.2f' % (
526 | model_name, global_step, args.max_steps, rtime, itime)
527 | if fabric.global_rank in log_ranks:
528 | if overpools_t['total_loss'].have_min_size():
529 | info_str += '; loss_t %.2f; d_avg %.1f' % (
530 | overpools_t['total_loss'].mean(),
531 | overpools_t['d_avg'].mean()*100.0,
532 | )
533 | info_str += '; (rank %d)' % fabric.global_rank
534 | print(info_str)
535 | print('done!')
536 |
537 | if fabric.global_rank in log_ranks:
538 | writer_t.close()
539 |
540 | if __name__ == "__main__":
541 | init_dir = ''
542 |
543 |     # this file trains alltracker in "stage 1",
544 |     # which uses kubric data only.
545 |     # it is also the file used to run all ablations.
546 |
547 | from nets.alltracker import Net; exp = 'stage1' # clean up for release
548 |
549 | parser = argparse.ArgumentParser()
550 | parser.add_argument("--exp", default=exp)
551 | parser.add_argument("--init_dir", type=str, default=init_dir)
552 | parser.add_argument("--load_optimizer", default=False, action='store_true')
553 | parser.add_argument("--load_scheduler", default=False, action='store_true')
554 | parser.add_argument("--load_step", default=False, action='store_true')
555 | parser.add_argument("--ckpt_dir", type=str, default='./checkpoints')
556 | parser.add_argument("--data_dir", type=str, default='/data')
557 | parser.add_argument("--batch_size", type=int, default=1)
558 | parser.add_argument("--num_nodes", type=int, default=1)
559 | parser.add_argument("--num_workers_24", type=int, default=2)
560 | parser.add_argument("--num_workers_56", type=int, default=6)
561 | parser.add_argument("--mixed_precision", default=False, action='store_true')
562 | parser.add_argument("--lr", type=float, default=5e-4)
563 | parser.add_argument("--wdecay", type=float, default=0.0005)
564 | parser.add_argument("--max_steps", type=int, default=100000)
565 | parser.add_argument("--use_scheduler", default=True)
566 | parser.add_argument("--save_freq", type=int, default=2000)
567 | parser.add_argument("--log_freq", type=int, default=997) # prime number
568 | parser.add_argument("--traj_per_sample_24", type=int, default=256) # note we allow 1-24x this amount
569 | parser.add_argument("--traj_per_sample_56", type=int, default=256) # note we allow 1-24x this amount
570 | parser.add_argument("--inference_iters_24", type=int, default=4)
571 | parser.add_argument("--inference_iters_56", type=int, default=3)
572 | parser.add_argument("--random_frame_rate", default=False, action='store_true')
573 | parser.add_argument("--random_first_frame", default=False, action='store_true')
574 | parser.add_argument("--shuffle_frames", default=False, action='store_true')
575 | parser.add_argument("--use_augs", default=True)
576 | parser.add_argument("--seqlen", type=int, default=16)
577 | parser.add_argument("--crop_size_24", nargs="+", default=[384,512])
578 | parser.add_argument("--crop_size_56", nargs="+", default=[256,384])
579 | parser.add_argument("--random_number_traj", default=False, action='store_true')
580 | parser.add_argument("--use_huber_loss", default=False, action='store_true')
581 | parser.add_argument("--debug", default=False, action='store_true')
582 | parser.add_argument("--random_seq_len", default=False, action='store_true')
583 | parser.add_argument("--no_attn", default=False, action='store_true')
584 | parser.add_argument("--use_mixer", default=False, action='store_true')
585 | parser.add_argument("--use_conv", default=False, action='store_true')
586 | parser.add_argument("--use_convb", default=False, action='store_true')
587 | parser.add_argument("--use_basicencoder", default=False, action='store_true')
588 | parser.add_argument("--only_short", default=False, action='store_true')
589 | parser.add_argument("--no_short", default=False, action='store_true')
590 | parser.add_argument("--no_space", default=False, action='store_true')
591 | parser.add_argument("--no_time", default=False, action='store_true')
592 | parser.add_argument("--no_split", default=False, action='store_true')
593 | parser.add_argument("--no_ctx", default=False, action='store_true')
594 | parser.add_argument("--full_split", default=False, action='store_true')
595 | parser.add_argument("--use_sinmotion", default=False, action='store_true')
596 | parser.add_argument("--use_relmotion", default=False, action='store_true')
597 | parser.add_argument("--use_sinrelmotion", default=False, action='store_true')
598 | parser.add_argument("--use_feats8", default=False, action='store_true')
599 | parser.add_argument("--no_init", default=False, action='store_true')
600 | parser.add_argument("--num_blocks", type=int, default=3)
601 | parser.add_argument("--corr_radius", type=int, default=4)
602 | parser.add_argument("--corr_levels", type=int, default=5)
603 | parser.add_argument("--dim", type=int, default=128)
604 | parser.add_argument("--hdim", type=int, default=128)
605 |
606 | args = parser.parse_args()
607 |
608 | # fix up integers
609 | args.crop_size_24 = (int(args.crop_size_24[0]), int(args.crop_size_24[1]))
610 | args.crop_size_56 = (int(args.crop_size_56[0]), int(args.crop_size_56[1]))
611 |
612 | model = Net(
613 | 16,
614 | dim=args.dim,
615 | hdim=args.hdim,
616 | use_attn=(not args.no_attn),
617 | use_mixer=args.use_mixer,
618 | use_conv=args.use_conv,
619 | use_convb=args.use_convb,
620 | use_basicencoder=args.use_basicencoder,
621 | no_space=args.no_space,
622 | no_time=args.no_time,
623 | use_sinmotion=args.use_sinmotion,
624 | use_relmotion=args.use_relmotion,
625 | use_sinrelmotion=args.use_sinrelmotion,
626 | use_feats8=args.use_feats8,
627 | no_split=args.no_split,
628 | no_ctx=args.no_ctx,
629 | full_split=args.full_split,
630 | num_blocks=args.num_blocks,
631 | corr_radius=args.corr_radius,
632 | corr_levels=args.corr_levels,
633 | init_weights=(not args.no_init),
634 | )
635 |
636 | run(model, args)
637 |
638 |
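639 | # Example launch (a sketch; the data path below is an assumption, adjust to your setup):
640 | #   python train_stage1.py --data_dir /path/to/kubric_data --batch_size 1 --mixed_precision
641 | # With multiple visible GPUs, fabric.launch() with the "ddp" strategy in run() is expected to
642 | # spawn one process per GPU; for 2- or 8-GPU runs, the ranks are split between the
643 | # 24-frame and 56-frame Kubric loaders as configured in run().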
--------------------------------------------------------------------------------