├── .DS_Store
├── README.md
├── bundle_adjustment.py
├── cotracker
│   ├── .DS_Store
│   ├── checkpoints
│   │   ├── .DS_Store
│   │   └── checkpoint_here
│   ├── cotracker
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   └── predictor.cpython-310.pyc
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── dataclass_utils.py
│   │   │   ├── dr_dataset.py
│   │   │   ├── kubric_movif_dataset.py
│   │   │   ├── tap_vid_datasets.py
│   │   │   └── utils.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   ├── configs
│   │   │   │   ├── eval_dynamic_replica.yaml
│   │   │   │   ├── eval_tapvid_davis_first.yaml
│   │   │   │   ├── eval_tapvid_davis_strided.yaml
│   │   │   │   └── eval_tapvid_kinetics_first.yaml
│   │   │   ├── core
│   │   │   │   ├── __init__.py
│   │   │   │   ├── eval_utils.py
│   │   │   │   └── evaluator.py
│   │   │   └── evaluate.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   └── build_cotracker.cpython-310.pyc
│   │   │   ├── build_cotracker.py
│   │   │   ├── core
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   │   ├── embeddings.cpython-310.pyc
│   │   │   │   │   └── model_utils.cpython-310.pyc
│   │   │   │   ├── cotracker
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── __pycache__
│   │   │   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   │   │   ├── blocks.cpython-310.pyc
│   │   │   │   │   │   └── cotracker.cpython-310.pyc
│   │   │   │   │   ├── blocks.py
│   │   │   │   │   ├── cotracker.py
│   │   │   │   │   └── losses.py
│   │   │   │   ├── embeddings.py
│   │   │   │   └── model_utils.py
│   │   │   └── evaluation_predictor.py
│   │   ├── predictor.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   └── visualizer.cpython-310.pyc
│   │   │   └── visualizer.py
│   │   └── version.py
│   └── track_and_filter_keypoints.py
├── extract_audio_visual.py
├── face_parsing
│   ├── 79999_iter.pth
│   ├── __pycache__
│   │   ├── model.cpython-310.pyc
│   │   └── resnet.cpython-310.pyc
│   ├── logger.py
│   ├── model.py
│   ├── resnet.py
│   └── test.py
├── face_tracking
│   ├── .DS_Store
│   ├── 3DMM
│   │   ├── .DS_Store
│   │   ├── exp_info.npy
│   │   ├── keys_info.npy
│   │   ├── sub_mesh.obj
│   │   └── topology_info.npy
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── data_loader.cpython-310.pyc
│   │   ├── facemodel.cpython-310.pyc
│   │   ├── render_3dmm.cpython-310.pyc
│   │   └── util.cpython-310.pyc
│   ├── convert_BFM.py
│   ├── data_loader.py
│   ├── face_tracker.py
│   ├── facemodel.py
│   ├── geo_transform.py
│   ├── render_3dmm.py
│   ├── render_land.py
│   └── util.py
├── process.py
├── wav2mel.py
└── wav2mel_hparams.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/.DS_Store
--------------------------------------------------------------------------------
/bundle_adjustment.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from tqdm import tqdm
7 | from util import forward_transform
8 |
9 |
10 | # Placeholder function for initialization stage
11 | def initialize_keypoints(
12 |     keypoints: torch.Tensor,
13 |     euler_angle: torch.Tensor,
14 |     trans: torch.Tensor,
15 |     focal_length: torch.Tensor,
16 |     cxy: torch.Tensor,
17 | ):
18 |     num_keypoints = keypoints.shape[1]
19 |     euler_angle = euler_angle.cuda()
20 |     trans = trans.cuda()
21 |     focal_length = focal_length.cuda()
22 |     spatial_coords = torch.randn(
23 |         keypoints.shape[0], num_keypoints, 3, requires_grad=True, device="cuda"
24 |     )
25 |
26 |     optimizer = torch.optim.Adam([spatial_coords], lr=0.005)
27 |
28 |     losses = []
29 |     for i in tqdm(range(20000)):
30 |         proj_spatial_coords = forward_transform(
31 |             spatial_coords,
euler_angle, trans, focal_length, cxy 32 | ) 33 | print("proj_spatial_coords[:, :, :2]: ", proj_spatial_coords[:, :, :2].shape, keypoints.shape) 34 | loss_init = F.mse_loss(proj_spatial_coords[:, :, :2], keypoints) 35 | losses.append(loss_init) 36 | 37 | optimizer.zero_grad() 38 | loss_init.backward() 39 | optimizer.step() 40 | 41 | if i % 100 == 0: 42 | print(sum(losses) / len(losses)) 43 | losses = [] 44 | 45 | return spatial_coords, euler_angle, focal_length, trans 46 | 47 | 48 | # Placeholder function for comprehensive optimization stage 49 | def optimize_keypoints_and_pose( 50 | keypoints: torch.Tensor, 51 | id_para: torch.Tensor, 52 | exp_para: torch.Tensor, 53 | euler_angle: torch.Tensor, 54 | trans: torch.Tensor, 55 | focal_length: torch.Tensor, 56 | cxy: torch.Tensor, 57 | spatial_coords: torch.Tensor, 58 | ): 59 | id_para = id_para.new_tensor(id_para.data, device="cuda", requires_grad=True) 60 | exp_para = exp_para.new_tensor(exp_para.data, device="cuda", requires_grad=True) 61 | euler_angle = euler_angle.new_tensor( 62 | euler_angle.data, device="cuda", requires_grad=True 63 | ) 64 | trans = trans.new_tensor(trans.data, device="cuda", requires_grad=True) 65 | optimizer = torch.optim.Adam([euler_angle, trans, spatial_coords], lr=0.0001) 66 | 67 | losses = [] 68 | for i in tqdm(range(10000)): 69 | proj_spatial_coords = forward_transform( 70 | spatial_coords, euler_angle, trans, focal_length, cxy 71 | ) 72 | 73 | loss_sec = F.mse_loss(proj_spatial_coords[:, :, :2], keypoints) 74 | losses.append(loss_sec) 75 | 76 | optimizer.zero_grad() 77 | loss_sec.backward() 78 | optimizer.step() 79 | 80 | if i % 100 == 0: 81 | print(sum(losses) / len(losses)) 82 | losses = [] 83 | 84 | return id_para, exp_para, euler_angle, trans, spatial_coords 85 | 86 | 87 | def parse_args() -> argparse.Namespace: 88 | parser = argparse.ArgumentParser( 89 | description="Bundle Adjustment of refined Rotation and Translation parameters." 
90 | ) 91 | parser.add_argument("--keypoints-path", type=str, required=True) 92 | parser.add_argument("--track-params-path", type=str, required=True) 93 | return parser.parse_args() 94 | 95 | 96 | def process(keypoints_path: str, track_params_path: str) -> None: 97 | keypoints = torch.load(keypoints_path).squeeze() 98 | 99 | w, h = 512, 512 100 | cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda() 101 | 102 | track_params = torch.load(track_params_path) 103 | id_para = track_params["id"].cuda() 104 | exp_para = track_params["exp"].cuda() 105 | euler_angle = track_params["euler"].cuda() 106 | trans = track_params["trans"].cuda() 107 | focal_length = track_params["focal"] 108 | 109 | spatial_coords, euler_angle, focal_length, trans = initialize_keypoints( 110 | keypoints, euler_angle, trans, focal_length, cxy 111 | ) 112 | 113 | print(euler_angle) 114 | 115 | id_para, exp_para, euler_angle, trans, spatial_coords = optimize_keypoints_and_pose( 116 | keypoints, 117 | id_para, 118 | exp_para, 119 | euler_angle, 120 | trans, 121 | focal_length, 122 | cxy, 123 | spatial_coords, 124 | ) 125 | 126 | print(euler_angle) 127 | 128 | # new track_params 129 | track_params = { 130 | "id": id_para.detach().cpu(), 131 | "exp": exp_para.detach().cpu(), 132 | "euler": euler_angle.detach().cpu(), 133 | "trans": trans.detach().cpu(), 134 | "focal": focal_length.detach().cpu(), 135 | } 136 | 137 | torch.save( 138 | track_params, 139 | os.path.join(os.path.dirname(track_params_path), "track_params_init.pt"), 140 | ) 141 | 142 | torch.save( 143 | { 144 | "id": id_para.detach().cpu(), 145 | "exp": exp_para.detach().cpu(), 146 | "euler": euler_angle.detach().cpu(), 147 | "trans": trans.detach().cpu(), 148 | "focal": focal_length.detach().cpu(), 149 | }, 150 | track_params_path, 151 | ) 152 | 153 | 154 | def main() -> None: 155 | args = parse_args() 156 | 157 | keypoints_path: str = args.keypoints_path 158 | track_params_path: str = args.track_params_path 159 | 160 | assert os.path.exists(keypoints_path) 161 | assert os.path.exists(track_params_path) 162 | 163 | process(keypoints_path, track_params_path) 164 | 165 | 166 | if __name__ == "__main__": 167 | main() 168 | -------------------------------------------------------------------------------- /cotracker/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/.DS_Store -------------------------------------------------------------------------------- /cotracker/checkpoints/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/checkpoints/.DS_Store -------------------------------------------------------------------------------- /cotracker/checkpoints/checkpoint_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/checkpoints/checkpoint_here -------------------------------------------------------------------------------- /cotracker/cotracker/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/cotracker/cotracker/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/cotracker/cotracker/__pycache__/predictor.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/__pycache__/predictor.cpython-310.pyc
--------------------------------------------------------------------------------
/cotracker/cotracker/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/cotracker/cotracker/datasets/dataclass_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import json
9 | import dataclasses
10 | import numpy as np
11 | from dataclasses import Field, MISSING
12 | from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
13 |
14 | _X = TypeVar("_X")
15 |
16 |
17 | def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
18 |     """
19 |     Loads to a @dataclass or collection hierarchy including dataclasses
20 |     from a json recursively.
21 |     Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
22 |     raises KeyError if json has keys not mapping to the dataclass fields.
23 |
24 |     Args:
25 |         f: Either a path to a file, or a file opened for reading.
26 |         cls: The class of the loaded dataclass.
27 |         binary: Set to True if `f` is a file handle, else False.
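    Example (an illustrative sketch of typical usage; ``annotation_path`` is a
    placeholder, and the call mirrors how ``dr_dataset.py`` loads its frame
    annotations):

        import gzip
        from typing import List

        with gzip.open(annotation_path, "rt", encoding="utf8") as f:
            frame_annots = load_dataclass(f, List[DynamicReplicaFrameAnnotation])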
28 | """ 29 | if binary: 30 | asdict = json.loads(f.read().decode("utf8")) 31 | else: 32 | asdict = json.load(f) 33 | 34 | # in the list case, run a faster "vectorized" version 35 | cls = get_args(cls)[0] 36 | res = list(_dataclass_list_from_dict_list(asdict, cls)) 37 | 38 | return res 39 | 40 | 41 | def _resolve_optional(type_: Any) -> Tuple[bool, Any]: 42 | """Check whether `type_` is equivalent to `typing.Optional[T]` for some T.""" 43 | if get_origin(type_) is Union: 44 | args = get_args(type_) 45 | if len(args) == 2 and args[1] == type(None): # noqa E721 46 | return True, args[0] 47 | if type_ is Any: 48 | return True, Any 49 | 50 | return False, type_ 51 | 52 | 53 | def _unwrap_type(tp): 54 | # strips Optional wrapper, if any 55 | if get_origin(tp) is Union: 56 | args = get_args(tp) 57 | if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721 58 | # this is typing.Optional 59 | return args[0] if args[1] is type(None) else args[1] # noqa: E721 60 | return tp 61 | 62 | 63 | def _get_dataclass_field_default(field: Field) -> Any: 64 | if field.default_factory is not MISSING: 65 | # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE, 66 | # dataclasses._DefaultFactory[typing.Any]]` is not a function. 67 | return field.default_factory() 68 | elif field.default is not MISSING: 69 | return field.default 70 | else: 71 | return None 72 | 73 | 74 | def _dataclass_list_from_dict_list(dlist, typeannot): 75 | """ 76 | Vectorised version of `_dataclass_from_dict`. 77 | The output should be equivalent to 78 | `[_dataclass_from_dict(d, typeannot) for d in dlist]`. 79 | 80 | Args: 81 | dlist: list of objects to convert. 82 | typeannot: type of each of those objects. 83 | Returns: 84 | iterator or list over converted objects of the same length as `dlist`. 85 | 86 | Raises: 87 | ValueError: it assumes the objects have None's in consistent places across 88 | objects, otherwise it would ignore some values. This generally holds for 89 | auto-generated annotations, but otherwise use `_dataclass_from_dict`. 
90 | """ 91 | 92 | cls = get_origin(typeannot) or typeannot 93 | 94 | if typeannot is Any: 95 | return dlist 96 | if all(obj is None for obj in dlist): # 1st recursion base: all None nodes 97 | return dlist 98 | if any(obj is None for obj in dlist): 99 | # filter out Nones and recurse on the resulting list 100 | idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None] 101 | idx, notnone = zip(*idx_notnone) 102 | converted = _dataclass_list_from_dict_list(notnone, typeannot) 103 | res = [None] * len(dlist) 104 | for i, obj in zip(idx, converted): 105 | res[i] = obj 106 | return res 107 | 108 | is_optional, contained_type = _resolve_optional(typeannot) 109 | if is_optional: 110 | return _dataclass_list_from_dict_list(dlist, contained_type) 111 | 112 | # otherwise, we dispatch by the type of the provided annotation to convert to 113 | if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple 114 | # For namedtuple, call the function recursively on the lists of corresponding keys 115 | types = cls.__annotations__.values() 116 | dlist_T = zip(*dlist) 117 | res_T = [ 118 | _dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types) 119 | ] 120 | return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)] 121 | elif issubclass(cls, (list, tuple)): 122 | # For list/tuple, call the function recursively on the lists of corresponding positions 123 | types = get_args(typeannot) 124 | if len(types) == 1: # probably List; replicate for all items 125 | types = types * len(dlist[0]) 126 | dlist_T = zip(*dlist) 127 | res_T = ( 128 | _dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types) 129 | ) 130 | if issubclass(cls, tuple): 131 | return list(zip(*res_T)) 132 | else: 133 | return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)] 134 | elif issubclass(cls, dict): 135 | # For the dictionary, call the function recursively on concatenated keys and vertices 136 | key_t, val_t = get_args(typeannot) 137 | all_keys_res = _dataclass_list_from_dict_list( 138 | [k for obj in dlist for k in obj.keys()], key_t 139 | ) 140 | all_vals_res = _dataclass_list_from_dict_list( 141 | [k for obj in dlist for k in obj.values()], val_t 142 | ) 143 | indices = np.cumsum([len(obj) for obj in dlist]) 144 | assert indices[-1] == len(all_keys_res) 145 | 146 | keys = np.split(list(all_keys_res), indices[:-1]) 147 | all_vals_res_iter = iter(all_vals_res) 148 | return [cls(zip(k, all_vals_res_iter)) for k in keys] 149 | elif not dataclasses.is_dataclass(typeannot): 150 | return dlist 151 | 152 | # dataclass node: 2nd recursion base; call the function recursively on the lists 153 | # of the corresponding fields 154 | assert dataclasses.is_dataclass(cls) 155 | fieldtypes = { 156 | f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f)) 157 | for f in dataclasses.fields(typeannot) 158 | } 159 | 160 | # NOTE the default object is shared here 161 | key_lists = ( 162 | _dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_) 163 | for k, (type_, default) in fieldtypes.items() 164 | ) 165 | transposed = zip(*key_lists) 166 | return [cls(*vals_as_tuple) for vals_as_tuple in transposed] 167 | -------------------------------------------------------------------------------- /cotracker/cotracker/datasets/dr_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | import gzip 10 | import torch 11 | import numpy as np 12 | import torch.utils.data as data 13 | from collections import defaultdict 14 | from dataclasses import dataclass 15 | from typing import List, Optional, Any, Dict, Tuple 16 | 17 | from cotracker.datasets.utils import CoTrackerData 18 | from cotracker.datasets.dataclass_utils import load_dataclass 19 | 20 | 21 | @dataclass 22 | class ImageAnnotation: 23 | # path to jpg file, relative w.r.t. dataset_root 24 | path: str 25 | # H x W 26 | size: Tuple[int, int] 27 | 28 | 29 | @dataclass 30 | class DynamicReplicaFrameAnnotation: 31 | """A dataclass used to load annotations from json.""" 32 | 33 | # can be used to join with `SequenceAnnotation` 34 | sequence_name: str 35 | # 0-based, continuous frame number within sequence 36 | frame_number: int 37 | # timestamp in seconds from the video start 38 | frame_timestamp: float 39 | 40 | image: ImageAnnotation 41 | meta: Optional[Dict[str, Any]] = None 42 | 43 | camera_name: Optional[str] = None 44 | trajectories: Optional[str] = None 45 | 46 | 47 | class DynamicReplicaDataset(data.Dataset): 48 | def __init__( 49 | self, 50 | root, 51 | split="valid", 52 | traj_per_sample=256, 53 | crop_size=None, 54 | sample_len=-1, 55 | only_first_n_samples=-1, 56 | rgbd_input=False, 57 | ): 58 | super(DynamicReplicaDataset, self).__init__() 59 | self.root = root 60 | self.sample_len = sample_len 61 | self.split = split 62 | self.traj_per_sample = traj_per_sample 63 | self.rgbd_input = rgbd_input 64 | self.crop_size = crop_size 65 | frame_annotations_file = f"frame_annotations_{split}.jgz" 66 | self.sample_list = [] 67 | with gzip.open( 68 | os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8" 69 | ) as zipfile: 70 | frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation]) 71 | seq_annot = defaultdict(list) 72 | for frame_annot in frame_annots_list: 73 | if frame_annot.camera_name == "left": 74 | seq_annot[frame_annot.sequence_name].append(frame_annot) 75 | 76 | for seq_name in seq_annot.keys(): 77 | seq_len = len(seq_annot[seq_name]) 78 | 79 | step = self.sample_len if self.sample_len > 0 else seq_len 80 | counter = 0 81 | 82 | for ref_idx in range(0, seq_len, step): 83 | sample = seq_annot[seq_name][ref_idx : ref_idx + step] 84 | self.sample_list.append(sample) 85 | counter += 1 86 | if only_first_n_samples > 0 and counter >= only_first_n_samples: 87 | break 88 | 89 | def __len__(self): 90 | return len(self.sample_list) 91 | 92 | def crop(self, rgbs, trajs): 93 | T, N, _ = trajs.shape 94 | 95 | S = len(rgbs) 96 | H, W = rgbs[0].shape[:2] 97 | assert S == T 98 | 99 | H_new = H 100 | W_new = W 101 | 102 | # simple random crop 103 | y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2 104 | x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2 105 | rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs] 106 | 107 | trajs[:, :, 0] -= x0 108 | trajs[:, :, 1] -= y0 109 | 110 | return rgbs, trajs 111 | 112 | def __getitem__(self, index): 113 | sample = self.sample_list[index] 114 | T = len(sample) 115 | rgbs, visibilities, traj_2d = [], [], [] 116 | 117 | H, W = sample[0].image.size 118 | image_size = (H, W) 119 | 120 | for i in range(T): 121 | traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"]) 122 | traj = 
torch.load(traj_path) 123 | 124 | visibilities.append(traj["verts_inds_vis"].numpy()) 125 | 126 | rgbs.append(traj["img"].numpy()) 127 | traj_2d.append(traj["traj_2d"].numpy()[..., :2]) 128 | 129 | traj_2d = np.stack(traj_2d) 130 | visibility = np.stack(visibilities) 131 | T, N, D = traj_2d.shape 132 | # subsample trajectories for augmentations 133 | visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample] 134 | 135 | traj_2d = traj_2d[:, visible_inds_sampled] 136 | visibility = visibility[:, visible_inds_sampled] 137 | 138 | if self.crop_size is not None: 139 | rgbs, traj_2d = self.crop(rgbs, traj_2d) 140 | H, W, _ = rgbs[0].shape 141 | image_size = self.crop_size 142 | 143 | visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False 144 | visibility[traj_2d[:, :, 0] < 0] = False 145 | visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False 146 | visibility[traj_2d[:, :, 1] < 0] = False 147 | 148 | # filter out points that're visible for less than 10 frames 149 | visible_inds_resampled = visibility.sum(0) > 10 150 | traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled]) 151 | visibility = torch.from_numpy(visibility[:, visible_inds_resampled]) 152 | 153 | rgbs = np.stack(rgbs, 0) 154 | video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float() 155 | return CoTrackerData( 156 | video=video, 157 | trajectory=traj_2d, 158 | visibility=visibility, 159 | valid=torch.ones(T, N), 160 | seq_name=sample[0].sequence_name, 161 | ) 162 | -------------------------------------------------------------------------------- /cotracker/cotracker/datasets/tap_vid_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import io 9 | import glob 10 | import torch 11 | import pickle 12 | import numpy as np 13 | import mediapy as media 14 | 15 | from PIL import Image 16 | from typing import Mapping, Tuple, Union 17 | 18 | from cotracker.datasets.utils import CoTrackerData 19 | 20 | DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]] 21 | 22 | 23 | def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray: 24 | """Resize a video to output_size.""" 25 | # If you have a GPU, consider replacing this with a GPU-enabled resize op, 26 | # such as a jitted jax.image.resize. It will make things faster. 27 | return media.resize_video(video, output_size) 28 | 29 | 30 | def sample_queries_first( 31 | target_occluded: np.ndarray, 32 | target_points: np.ndarray, 33 | frames: np.ndarray, 34 | ) -> Mapping[str, np.ndarray]: 35 | """Package a set of frames and tracks for use in TAPNet evaluations. 36 | Given a set of frames and tracks with no query points, use the first 37 | visible point in each track as the query. 38 | Args: 39 | target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], 40 | where True indicates occluded. 41 | target_points: Position, of shape [n_tracks, n_frames, 2], where each point 42 | is [x,y] scaled between 0 and 1. 43 | frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between 44 | -1 and 1. 
45 | Returns: 46 | A dict with the keys: 47 | video: Video tensor of shape [1, n_frames, height, width, 3] 48 | query_points: Query points of shape [1, n_queries, 3] where 49 | each point is [t, y, x] scaled to the range [-1, 1] 50 | target_points: Target points of shape [1, n_queries, n_frames, 2] where 51 | each point is [x, y] scaled to the range [-1, 1] 52 | """ 53 | valid = np.sum(~target_occluded, axis=1) > 0 54 | target_points = target_points[valid, :] 55 | target_occluded = target_occluded[valid, :] 56 | 57 | query_points = [] 58 | for i in range(target_points.shape[0]): 59 | index = np.where(target_occluded[i] == 0)[0][0] 60 | x, y = target_points[i, index, 0], target_points[i, index, 1] 61 | query_points.append(np.array([index, y, x])) # [t, y, x] 62 | query_points = np.stack(query_points, axis=0) 63 | 64 | return { 65 | "video": frames[np.newaxis, ...], 66 | "query_points": query_points[np.newaxis, ...], 67 | "target_points": target_points[np.newaxis, ...], 68 | "occluded": target_occluded[np.newaxis, ...], 69 | } 70 | 71 | 72 | def sample_queries_strided( 73 | target_occluded: np.ndarray, 74 | target_points: np.ndarray, 75 | frames: np.ndarray, 76 | query_stride: int = 5, 77 | ) -> Mapping[str, np.ndarray]: 78 | """Package a set of frames and tracks for use in TAPNet evaluations. 79 | 80 | Given a set of frames and tracks with no query points, sample queries 81 | strided every query_stride frames, ignoring points that are not visible 82 | at the selected frames. 83 | 84 | Args: 85 | target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames], 86 | where True indicates occluded. 87 | target_points: Position, of shape [n_tracks, n_frames, 2], where each point 88 | is [x,y] scaled between 0 and 1. 89 | frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between 90 | -1 and 1. 91 | query_stride: When sampling query points, search for un-occluded points 92 | every query_stride frames and convert each one into a query. 93 | 94 | Returns: 95 | A dict with the keys: 96 | video: Video tensor of shape [1, n_frames, height, width, 3]. The video 97 | has floats scaled to the range [-1, 1]. 98 | query_points: Query points of shape [1, n_queries, 3] where 99 | each point is [t, y, x] scaled to the range [-1, 1]. 100 | target_points: Target points of shape [1, n_queries, n_frames, 2] where 101 | each point is [x, y] scaled to the range [-1, 1]. 102 | trackgroup: Index of the original track that each query point was 103 | sampled from. This is useful for visualization. 
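    Example (an illustrative sketch with made-up shapes, showing only how the
    number of queries follows from ``query_stride``):

        occluded = np.zeros((3, 12), dtype=bool)              # 3 tracks, 12 frames, never occluded
        points = np.random.rand(3, 12, 2)                     # [x, y] scaled to [0, 1]
        frames = np.zeros((12, 64, 64, 3), dtype=np.float32)  # video scaled to [-1, 1]
        out = sample_queries_strided(occluded, points, frames, query_stride=5)
        # queries are sampled at frames 0, 5 and 10, so:
        # out["query_points"].shape == (1, 9, 3)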
104 | """ 105 | tracks = [] 106 | occs = [] 107 | queries = [] 108 | trackgroups = [] 109 | total = 0 110 | trackgroup = np.arange(target_occluded.shape[0]) 111 | for i in range(0, target_occluded.shape[1], query_stride): 112 | mask = target_occluded[:, i] == 0 113 | query = np.stack( 114 | [ 115 | i * np.ones(target_occluded.shape[0:1]), 116 | target_points[:, i, 1], 117 | target_points[:, i, 0], 118 | ], 119 | axis=-1, 120 | ) 121 | queries.append(query[mask]) 122 | tracks.append(target_points[mask]) 123 | occs.append(target_occluded[mask]) 124 | trackgroups.append(trackgroup[mask]) 125 | total += np.array(np.sum(target_occluded[:, i] == 0)) 126 | 127 | return { 128 | "video": frames[np.newaxis, ...], 129 | "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...], 130 | "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...], 131 | "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...], 132 | "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...], 133 | } 134 | 135 | 136 | class TapVidDataset(torch.utils.data.Dataset): 137 | def __init__( 138 | self, 139 | data_root, 140 | dataset_type="davis", 141 | resize_to_256=True, 142 | queried_first=True, 143 | ): 144 | self.dataset_type = dataset_type 145 | self.resize_to_256 = resize_to_256 146 | self.queried_first = queried_first 147 | if self.dataset_type == "kinetics": 148 | all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl")) 149 | points_dataset = [] 150 | for pickle_path in all_paths: 151 | with open(pickle_path, "rb") as f: 152 | data = pickle.load(f) 153 | points_dataset = points_dataset + data 154 | self.points_dataset = points_dataset 155 | else: 156 | with open(data_root, "rb") as f: 157 | self.points_dataset = pickle.load(f) 158 | if self.dataset_type == "davis": 159 | self.video_names = list(self.points_dataset.keys()) 160 | print("found %d unique videos in %s" % (len(self.points_dataset), data_root)) 161 | 162 | def __getitem__(self, index): 163 | if self.dataset_type == "davis": 164 | video_name = self.video_names[index] 165 | else: 166 | video_name = index 167 | video = self.points_dataset[video_name] 168 | frames = video["video"] 169 | 170 | if isinstance(frames[0], bytes): 171 | # TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s. 
172 | def decode(frame): 173 | byteio = io.BytesIO(frame) 174 | img = Image.open(byteio) 175 | return np.array(img) 176 | 177 | frames = np.array([decode(frame) for frame in frames]) 178 | 179 | target_points = self.points_dataset[video_name]["points"] 180 | if self.resize_to_256: 181 | frames = resize_video(frames, [256, 256]) 182 | target_points *= np.array([255, 255]) # 1 should be mapped to 256-1 183 | else: 184 | target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1]) 185 | 186 | target_occ = self.points_dataset[video_name]["occluded"] 187 | if self.queried_first: 188 | converted = sample_queries_first(target_occ, target_points, frames) 189 | else: 190 | converted = sample_queries_strided(target_occ, target_points, frames) 191 | assert converted["target_points"].shape[1] == converted["query_points"].shape[1] 192 | 193 | trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D 194 | 195 | rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float() 196 | visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute( 197 | 1, 0 198 | ) # T, N 199 | query_points = torch.from_numpy(converted["query_points"])[0] # T, N 200 | return CoTrackerData( 201 | rgbs, 202 | trajs, 203 | visibles, 204 | seq_name=str(video_name), 205 | query_points=query_points, 206 | ) 207 | 208 | def __len__(self): 209 | return len(self.points_dataset) 210 | -------------------------------------------------------------------------------- /cotracker/cotracker/datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import torch 9 | import dataclasses 10 | import torch.nn.functional as F 11 | from dataclasses import dataclass 12 | from typing import Any, Optional 13 | 14 | 15 | @dataclass(eq=False) 16 | class CoTrackerData: 17 | """ 18 | Dataclass for storing video tracks data. 19 | """ 20 | 21 | video: torch.Tensor # B, S, C, H, W 22 | trajectory: torch.Tensor # B, S, N, 2 23 | visibility: torch.Tensor # B, S, N 24 | # optional data 25 | valid: Optional[torch.Tensor] = None # B, S, N 26 | segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W 27 | seq_name: Optional[str] = None 28 | query_points: Optional[torch.Tensor] = None # TapVID evaluation format 29 | 30 | 31 | def collate_fn(batch): 32 | """ 33 | Collate function for video tracks data. 34 | """ 35 | video = torch.stack([b.video for b in batch], dim=0) 36 | trajectory = torch.stack([b.trajectory for b in batch], dim=0) 37 | visibility = torch.stack([b.visibility for b in batch], dim=0) 38 | query_points = segmentation = None 39 | if batch[0].query_points is not None: 40 | query_points = torch.stack([b.query_points for b in batch], dim=0) 41 | if batch[0].segmentation is not None: 42 | segmentation = torch.stack([b.segmentation for b in batch], dim=0) 43 | seq_name = [b.seq_name for b in batch] 44 | 45 | return CoTrackerData( 46 | video=video, 47 | trajectory=trajectory, 48 | visibility=visibility, 49 | segmentation=segmentation, 50 | seq_name=seq_name, 51 | query_points=query_points, 52 | ) 53 | 54 | 55 | def collate_fn_train(batch): 56 | """ 57 | Collate function for video tracks data during training. 
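    Example (an illustrative sketch; ``train_dataset`` is a placeholder for any
    dataset whose items are ``(CoTrackerData, bool)`` pairs):

        loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=4,
            collate_fn=collate_fn_train,
        )
        batch, gotit = next(iter(loader))  # CoTrackerData with a leading batch dim, list of bools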
58 | """ 59 | gotit = [gotit for _, gotit in batch] 60 | video = torch.stack([b.video for b, _ in batch], dim=0) 61 | trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0) 62 | visibility = torch.stack([b.visibility for b, _ in batch], dim=0) 63 | valid = torch.stack([b.valid for b, _ in batch], dim=0) 64 | seq_name = [b.seq_name for b, _ in batch] 65 | return ( 66 | CoTrackerData( 67 | video=video, 68 | trajectory=trajectory, 69 | visibility=visibility, 70 | valid=valid, 71 | seq_name=seq_name, 72 | ), 73 | gotit, 74 | ) 75 | 76 | 77 | def try_to_cuda(t: Any) -> Any: 78 | """ 79 | Try to move the input variable `t` to a cuda device. 80 | 81 | Args: 82 | t: Input. 83 | 84 | Returns: 85 | t_cuda: `t` moved to a cuda device, if supported. 86 | """ 87 | try: 88 | t = t.float().cuda() 89 | except AttributeError: 90 | pass 91 | return t 92 | 93 | 94 | def dataclass_to_cuda_(obj): 95 | """ 96 | Move all contents of a dataclass to cuda inplace if supported. 97 | 98 | Args: 99 | batch: Input dataclass. 100 | 101 | Returns: 102 | batch_cuda: `batch` moved to a cuda device, if supported. 103 | """ 104 | for f in dataclasses.fields(obj): 105 | setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) 106 | return obj 107 | -------------------------------------------------------------------------------- /cotracker/cotracker/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/cotracker/evaluation/configs/eval_dynamic_replica.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | exp_dir: ./outputs/cotracker 4 | dataset_name: dynamic_replica 5 | 6 | -------------------------------------------------------------------------------- /cotracker/cotracker/evaluation/configs/eval_tapvid_davis_first.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | exp_dir: ./outputs/cotracker 4 | dataset_name: tapvid_davis_first 5 | 6 | -------------------------------------------------------------------------------- /cotracker/cotracker/evaluation/configs/eval_tapvid_davis_strided.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | exp_dir: ./outputs/cotracker 4 | dataset_name: tapvid_davis_strided 5 | 6 | -------------------------------------------------------------------------------- /cotracker/cotracker/evaluation/configs/eval_tapvid_kinetics_first.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_config_eval 3 | exp_dir: ./outputs/cotracker 4 | dataset_name: tapvid_kinetics_first 5 | 6 | -------------------------------------------------------------------------------- /cotracker/cotracker/evaluation/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /cotracker/cotracker/evaluation/core/eval_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | 9 | from typing import Iterable, Mapping, Tuple, Union 10 | 11 | 12 | def compute_tapvid_metrics( 13 | query_points: np.ndarray, 14 | gt_occluded: np.ndarray, 15 | gt_tracks: np.ndarray, 16 | pred_occluded: np.ndarray, 17 | pred_tracks: np.ndarray, 18 | query_mode: str, 19 | ) -> Mapping[str, np.ndarray]: 20 | """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.) 21 | See the TAP-Vid paper for details on the metric computation. All inputs are 22 | given in raster coordinates. The first three arguments should be the direct 23 | outputs of the reader: the 'query_points', 'occluded', and 'target_points'. 24 | The paper metrics assume these are scaled relative to 256x256 images. 25 | pred_occluded and pred_tracks are your algorithm's predictions. 26 | This function takes a batch of inputs, and computes metrics separately for 27 | each video. The metrics for the full benchmark are a simple mean of the 28 | metrics across the full set of videos. These numbers are between 0 and 1, 29 | but the paper multiplies them by 100 to ease reading. 30 | Args: 31 | query_points: The query points, an in the format [t, y, x]. Its size is 32 | [b, n, 3], where b is the batch size and n is the number of queries 33 | gt_occluded: A boolean array of shape [b, n, t], where t is the number 34 | of frames. True indicates that the point is occluded. 35 | gt_tracks: The target points, of shape [b, n, t, 2]. Each point is 36 | in the format [x, y] 37 | pred_occluded: A boolean array of predicted occlusions, in the same 38 | format as gt_occluded. 39 | pred_tracks: An array of track predictions from your algorithm, in the 40 | same format as gt_tracks. 41 | query_mode: Either 'first' or 'strided', depending on how queries are 42 | sampled. If 'first', we assume the prior knowledge that all points 43 | before the query point are occluded, and these are removed from the 44 | evaluation. 45 | Returns: 46 | A dict with the following keys: 47 | occlusion_accuracy: Accuracy at predicting occlusion. 48 | pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points 49 | predicted to be within the given pixel threshold, ignoring occlusion 50 | prediction. 
51 | jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given 52 | threshold 53 | average_pts_within_thresh: average across pts_within_{x} 54 | average_jaccard: average across jaccard_{x} 55 | """ 56 | 57 | metrics = {} 58 | # Fixed bug is described in: 59 | # https://github.com/facebookresearch/co-tracker/issues/20 60 | eye = np.eye(gt_tracks.shape[2], dtype=np.int32) 61 | 62 | if query_mode == "first": 63 | # evaluate frames after the query frame 64 | query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye 65 | elif query_mode == "strided": 66 | # evaluate all frames except the query frame 67 | query_frame_to_eval_frames = 1 - eye 68 | else: 69 | raise ValueError("Unknown query mode " + query_mode) 70 | 71 | query_frame = query_points[..., 0] 72 | query_frame = np.round(query_frame).astype(np.int32) 73 | evaluation_points = query_frame_to_eval_frames[query_frame] > 0 74 | 75 | # Occlusion accuracy is simply how often the predicted occlusion equals the 76 | # ground truth. 77 | occ_acc = np.sum( 78 | np.equal(pred_occluded, gt_occluded) & evaluation_points, 79 | axis=(1, 2), 80 | ) / np.sum(evaluation_points) 81 | metrics["occlusion_accuracy"] = occ_acc 82 | 83 | # Next, convert the predictions and ground truth positions into pixel 84 | # coordinates. 85 | visible = np.logical_not(gt_occluded) 86 | pred_visible = np.logical_not(pred_occluded) 87 | all_frac_within = [] 88 | all_jaccard = [] 89 | for thresh in [1, 2, 4, 8, 16]: 90 | # True positives are points that are within the threshold and where both 91 | # the prediction and the ground truth are listed as visible. 92 | within_dist = np.sum( 93 | np.square(pred_tracks - gt_tracks), 94 | axis=-1, 95 | ) < np.square(thresh) 96 | is_correct = np.logical_and(within_dist, visible) 97 | 98 | # Compute the frac_within_threshold, which is the fraction of points 99 | # within the threshold among points that are visible in the ground truth, 100 | # ignoring whether they're predicted to be visible. 101 | count_correct = np.sum( 102 | is_correct & evaluation_points, 103 | axis=(1, 2), 104 | ) 105 | count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2)) 106 | frac_correct = count_correct / count_visible_points 107 | metrics["pts_within_" + str(thresh)] = frac_correct 108 | all_frac_within.append(frac_correct) 109 | 110 | true_positives = np.sum( 111 | is_correct & pred_visible & evaluation_points, axis=(1, 2) 112 | ) 113 | 114 | # The denominator of the jaccard metric is the true positives plus 115 | # false positives plus false negatives. However, note that true positives 116 | # plus false negatives is simply the number of points in the ground truth 117 | # which is easier to compute than trying to compute all three quantities. 118 | # Thus we just add the number of points in the ground truth to the number 119 | # of false positives. 120 | # 121 | # False positives are simply points that are predicted to be visible, 122 | # but the ground truth is not visible or too far from the prediction. 
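        # Worked example (illustrative numbers only): for one video and one
        # threshold, suppose 5 visible ground-truth points are within the
        # threshold and predicted visible (true positives), 3 visible points
        # are predicted occluded (false negatives), and 2 occluded points are
        # wrongly predicted visible (false positives). Then
        # gt_positives = 5 + 3 = 8, false_positives = 2, and
        # jaccard = 5 / (8 + 2) = 0.5.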
123 | gt_positives = np.sum(visible & evaluation_points, axis=(1, 2)) 124 | false_positives = (~visible) & pred_visible 125 | false_positives = false_positives | ((~within_dist) & pred_visible) 126 | false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2)) 127 | jaccard = true_positives / (gt_positives + false_positives) 128 | metrics["jaccard_" + str(thresh)] = jaccard 129 | all_jaccard.append(jaccard) 130 | metrics["average_jaccard"] = np.mean( 131 | np.stack(all_jaccard, axis=1), 132 | axis=1, 133 | ) 134 | metrics["average_pts_within_thresh"] = np.mean( 135 | np.stack(all_frac_within, axis=1), 136 | axis=1, 137 | ) 138 | return metrics 139 | -------------------------------------------------------------------------------- /cotracker/cotracker/evaluation/core/evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections import defaultdict 8 | import os 9 | from typing import Optional 10 | import torch 11 | from tqdm import tqdm 12 | import numpy as np 13 | 14 | from torch.utils.tensorboard import SummaryWriter 15 | from cotracker.datasets.utils import dataclass_to_cuda_ 16 | from cotracker.utils.visualizer import Visualizer 17 | from cotracker.models.core.model_utils import reduce_masked_mean 18 | from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics 19 | 20 | import logging 21 | 22 | 23 | class Evaluator: 24 | """ 25 | A class defining the CoTracker evaluator. 26 | """ 27 | 28 | def __init__(self, exp_dir) -> None: 29 | # Visualization 30 | self.exp_dir = exp_dir 31 | os.makedirs(exp_dir, exist_ok=True) 32 | self.visualization_filepaths = defaultdict(lambda: defaultdict(list)) 33 | self.visualize_dir = os.path.join(exp_dir, "visualisations") 34 | 35 | def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name): 36 | if isinstance(pred_trajectory, tuple): 37 | pred_trajectory, pred_visibility = pred_trajectory 38 | else: 39 | pred_visibility = None 40 | if "tapvid" in dataset_name: 41 | B, T, N, D = sample.trajectory.shape 42 | traj = sample.trajectory.clone() 43 | thr = 0.9 44 | 45 | if pred_visibility is None: 46 | logging.warning("visibility is NONE") 47 | pred_visibility = torch.zeros_like(sample.visibility) 48 | 49 | if not pred_visibility.dtype == torch.bool: 50 | pred_visibility = pred_visibility > thr 51 | 52 | query_points = sample.query_points.clone().cpu().numpy() 53 | 54 | pred_visibility = pred_visibility[:, :, :N] 55 | pred_trajectory = pred_trajectory[:, :, :N] 56 | 57 | gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy() 58 | gt_occluded = ( 59 | torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy() 60 | ) 61 | 62 | pred_occluded = ( 63 | torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy() 64 | ) 65 | pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy() 66 | 67 | out_metrics = compute_tapvid_metrics( 68 | query_points, 69 | gt_occluded, 70 | gt_tracks, 71 | pred_occluded, 72 | pred_tracks, 73 | query_mode="strided" if "strided" in dataset_name else "first", 74 | ) 75 | 76 | metrics[sample.seq_name[0]] = out_metrics 77 | for metric_name in out_metrics.keys(): 78 | if "avg" not in metrics: 79 | metrics["avg"] = {} 80 | metrics["avg"][metric_name] = np.mean( 81 | [v[metric_name] for k, v in 
metrics.items() if k != "avg"] 82 | ) 83 | 84 | logging.info(f"Metrics: {out_metrics}") 85 | logging.info(f"avg: {metrics['avg']}") 86 | print("metrics", out_metrics) 87 | print("avg", metrics["avg"]) 88 | elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey": 89 | *_, N, _ = sample.trajectory.shape 90 | B, T, N = sample.visibility.shape 91 | H, W = sample.video.shape[-2:] 92 | device = sample.video.device 93 | 94 | out_metrics = {} 95 | 96 | d_vis_sum = d_occ_sum = d_sum_all = 0.0 97 | thrs = [1, 2, 4, 8, 16] 98 | sx_ = (W - 1) / 255.0 99 | sy_ = (H - 1) / 255.0 100 | sc_py = np.array([sx_, sy_]).reshape([1, 1, 2]) 101 | sc_pt = torch.from_numpy(sc_py).float().to(device) 102 | __, first_visible_inds = torch.max(sample.visibility, dim=1) 103 | 104 | frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N) 105 | start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1)) 106 | 107 | for thr in thrs: 108 | d_ = ( 109 | torch.norm( 110 | pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt, 111 | dim=-1, 112 | ) 113 | < thr 114 | ).float() # B,S-1,N 115 | d_occ = ( 116 | reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item() 117 | * 100.0 118 | ) 119 | d_occ_sum += d_occ 120 | out_metrics[f"accuracy_occ_{thr}"] = d_occ 121 | 122 | d_vis = ( 123 | reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0 124 | ) 125 | d_vis_sum += d_vis 126 | out_metrics[f"accuracy_vis_{thr}"] = d_vis 127 | 128 | d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0 129 | d_sum_all += d_all 130 | out_metrics[f"accuracy_{thr}"] = d_all 131 | 132 | d_occ_avg = d_occ_sum / len(thrs) 133 | d_vis_avg = d_vis_sum / len(thrs) 134 | d_all_avg = d_sum_all / len(thrs) 135 | 136 | sur_thr = 50 137 | dists = torch.norm( 138 | pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt, 139 | dim=-1, 140 | ) # B,S,N 141 | dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N 142 | survival = torch.cumprod(dist_ok, dim=1) # B,S,N 143 | out_metrics["survival"] = torch.mean(survival).item() * 100.0 144 | 145 | out_metrics["accuracy_occ"] = d_occ_avg 146 | out_metrics["accuracy_vis"] = d_vis_avg 147 | out_metrics["accuracy"] = d_all_avg 148 | 149 | metrics[sample.seq_name[0]] = out_metrics 150 | for metric_name in out_metrics.keys(): 151 | if "avg" not in metrics: 152 | metrics["avg"] = {} 153 | metrics["avg"][metric_name] = float( 154 | np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"]) 155 | ) 156 | 157 | logging.info(f"Metrics: {out_metrics}") 158 | logging.info(f"avg: {metrics['avg']}") 159 | print("metrics", out_metrics) 160 | print("avg", metrics["avg"]) 161 | 162 | @torch.no_grad() 163 | def evaluate_sequence( 164 | self, 165 | model, 166 | test_dataloader: torch.utils.data.DataLoader, 167 | dataset_name: str, 168 | train_mode=False, 169 | visualize_every: int = 1, 170 | writer: Optional[SummaryWriter] = None, 171 | step: Optional[int] = 0, 172 | ): 173 | metrics = {} 174 | 175 | vis = Visualizer( 176 | save_dir=self.exp_dir, 177 | fps=7, 178 | ) 179 | 180 | for ind, sample in enumerate(tqdm(test_dataloader)): 181 | if isinstance(sample, tuple): 182 | sample, gotit = sample 183 | if not all(gotit): 184 | print("batch is None") 185 | continue 186 | if torch.cuda.is_available(): 187 | dataclass_to_cuda_(sample) 188 | device = torch.device("cuda") 189 | else: 190 | device = torch.device("cpu") 191 | 192 | if ( 193 | not train_mode 194 | and 
hasattr(model, "sequence_len") 195 | and (sample.visibility[:, : model.sequence_len].sum() == 0) 196 | ): 197 | print(f"skipping batch {ind}") 198 | continue 199 | 200 | if "tapvid" in dataset_name: 201 | queries = sample.query_points.clone().float() 202 | 203 | queries = torch.stack( 204 | [ 205 | queries[:, :, 0], 206 | queries[:, :, 2], 207 | queries[:, :, 1], 208 | ], 209 | dim=2, 210 | ).to(device) 211 | else: 212 | queries = torch.cat( 213 | [ 214 | torch.zeros_like(sample.trajectory[:, 0, :, :1]), 215 | sample.trajectory[:, 0], 216 | ], 217 | dim=2, 218 | ).to(device) 219 | 220 | pred_tracks = model(sample.video, queries) 221 | if "strided" in dataset_name: 222 | inv_video = sample.video.flip(1).clone() 223 | inv_queries = queries.clone() 224 | inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 225 | 226 | pred_trj, pred_vsb = pred_tracks 227 | inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries) 228 | 229 | inv_pred_trj = inv_pred_trj.flip(1) 230 | inv_pred_vsb = inv_pred_vsb.flip(1) 231 | 232 | mask = pred_trj == 0 233 | 234 | pred_trj[mask] = inv_pred_trj[mask] 235 | pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]] 236 | 237 | pred_tracks = pred_trj, pred_vsb 238 | 239 | if dataset_name == "badja" or dataset_name == "fastcapture": 240 | seq_name = sample.seq_name[0] 241 | else: 242 | seq_name = str(ind) 243 | if ind % visualize_every == 0: 244 | vis.visualize( 245 | sample.video, 246 | pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks, 247 | filename=dataset_name + "_" + seq_name, 248 | writer=writer, 249 | step=step, 250 | ) 251 | 252 | self.compute_metrics(metrics, sample, pred_tracks, dataset_name) 253 | return metrics 254 | -------------------------------------------------------------------------------- /cotracker/cotracker/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import os 9 | from dataclasses import dataclass, field 10 | 11 | import hydra 12 | import numpy as np 13 | 14 | import torch 15 | from omegaconf import OmegaConf 16 | 17 | from cotracker.datasets.tap_vid_datasets import TapVidDataset 18 | from cotracker.datasets.dr_dataset import DynamicReplicaDataset 19 | from cotracker.datasets.utils import collate_fn 20 | 21 | from cotracker.models.evaluation_predictor import EvaluationPredictor 22 | 23 | from cotracker.evaluation.core.evaluator import Evaluator 24 | from cotracker.models.build_cotracker import ( 25 | build_cotracker, 26 | ) 27 | 28 | 29 | @dataclass(eq=False) 30 | class DefaultConfig: 31 | # Directory where all outputs of the experiment will be saved. 32 | exp_dir: str = "./outputs" 33 | 34 | # Name of the dataset to be used for the evaluation. 35 | dataset_name: str = "tapvid_davis_first" 36 | # The root directory of the dataset. 37 | dataset_root: str = "./" 38 | 39 | # Path to the pre-trained model checkpoint to be used for the evaluation. 40 | # The default value is the path to a specific CoTracker model checkpoint. 41 | checkpoint: str = "./checkpoints/cotracker2.pth" 42 | 43 | # EvaluationPredictor parameters 44 | # The size (N) of the support grid used in the predictor. 45 | # The total number of points is (N*N). 46 | grid_size: int = 5 47 | # The size (N) of the local support grid. 
48 | local_grid_size: int = 8 49 | # A flag indicating whether to evaluate one ground truth point at a time. 50 | single_point: bool = True 51 | # The number of iterative updates for each sliding window. 52 | n_iters: int = 6 53 | 54 | seed: int = 0 55 | gpu_idx: int = 0 56 | 57 | # Override hydra's working directory to current working dir, 58 | # also disable storing the .hydra logs: 59 | hydra: dict = field( 60 | default_factory=lambda: { 61 | "run": {"dir": "."}, 62 | "output_subdir": None, 63 | } 64 | ) 65 | 66 | 67 | def run_eval(cfg: DefaultConfig): 68 | """ 69 | The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration. 70 | 71 | Args: 72 | cfg (DefaultConfig): An instance of DefaultConfig class which includes: 73 | - exp_dir (str): The directory path for the experiment. 74 | - dataset_name (str): The name of the dataset to be used. 75 | - dataset_root (str): The root directory of the dataset. 76 | - checkpoint (str): The path to the CoTracker model's checkpoint. 77 | - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time. 78 | - n_iters (int): The number of iterative updates for each sliding window. 79 | - seed (int): The seed for setting the random state for reproducibility. 80 | - gpu_idx (int): The index of the GPU to be used. 81 | """ 82 | # Creating the experiment directory if it doesn't exist 83 | os.makedirs(cfg.exp_dir, exist_ok=True) 84 | 85 | # Saving the experiment configuration to a .yaml file in the experiment directory 86 | cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml") 87 | with open(cfg_file, "w") as f: 88 | OmegaConf.save(config=cfg, f=f) 89 | 90 | evaluator = Evaluator(cfg.exp_dir) 91 | cotracker_model = build_cotracker(cfg.checkpoint) 92 | 93 | # Creating the EvaluationPredictor object 94 | predictor = EvaluationPredictor( 95 | cotracker_model, 96 | grid_size=cfg.grid_size, 97 | local_grid_size=cfg.local_grid_size, 98 | single_point=cfg.single_point, 99 | n_iters=cfg.n_iters, 100 | ) 101 | if torch.cuda.is_available(): 102 | predictor.model = predictor.model.cuda() 103 | 104 | # Setting the random seeds 105 | torch.manual_seed(cfg.seed) 106 | np.random.seed(cfg.seed) 107 | 108 | # Constructing the specified dataset 109 | curr_collate_fn = collate_fn 110 | if "tapvid" in cfg.dataset_name: 111 | dataset_type = cfg.dataset_name.split("_")[1] 112 | if dataset_type == "davis": 113 | data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl") 114 | elif dataset_type == "kinetics": 115 | data_root = os.path.join( 116 | cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics" 117 | ) 118 | test_dataset = TapVidDataset( 119 | dataset_type=dataset_type, 120 | data_root=data_root, 121 | queried_first=not "strided" in cfg.dataset_name, 122 | ) 123 | elif cfg.dataset_name == "dynamic_replica": 124 | test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1) 125 | 126 | # Creating the DataLoader object 127 | test_dataloader = torch.utils.data.DataLoader( 128 | test_dataset, 129 | batch_size=1, 130 | shuffle=False, 131 | num_workers=14, 132 | collate_fn=curr_collate_fn, 133 | ) 134 | 135 | # Timing and conducting the evaluation 136 | import time 137 | 138 | start = time.time() 139 | evaluate_result = evaluator.evaluate_sequence( 140 | predictor, 141 | test_dataloader, 142 | dataset_name=cfg.dataset_name, 143 | ) 144 | end = time.time() 145 | print(end - start) 146 | 147 | # Saving the evaluation results to a .json file 148 | 
evaluate_result = evaluate_result["avg"] 149 | print("evaluate_result", evaluate_result) 150 | result_file = os.path.join(cfg.exp_dir, f"result_eval_.json") 151 | evaluate_result["time"] = end - start 152 | print(f"Dumping eval results to {result_file}.") 153 | with open(result_file, "w") as f: 154 | json.dump(evaluate_result, f) 155 | 156 | 157 | cs = hydra.core.config_store.ConfigStore.instance() 158 | cs.store(name="default_config_eval", node=DefaultConfig) 159 | 160 | 161 | @hydra.main(config_path="./configs/", config_name="default_config_eval") 162 | def evaluate(cfg: DefaultConfig) -> None: 163 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 164 | os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx) 165 | run_eval(cfg) 166 | 167 | 168 | if __name__ == "__main__": 169 | evaluate() 170 | -------------------------------------------------------------------------------- /cotracker/cotracker/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/cotracker/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /cotracker/cotracker/models/__pycache__/build_cotracker.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/models/__pycache__/build_cotracker.cpython-310.pyc -------------------------------------------------------------------------------- /cotracker/cotracker/models/build_cotracker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from cotracker.models.core.cotracker.cotracker import CoTracker2 10 | 11 | 12 | def build_cotracker( 13 | checkpoint: str, 14 | ): 15 | if checkpoint is None: 16 | return build_cotracker() 17 | model_name = checkpoint.split("/")[-1].split(".")[0] 18 | if model_name == "cotracker": 19 | return build_cotracker(checkpoint=checkpoint) 20 | else: 21 | raise ValueError(f"Unknown model name {model_name}") 22 | 23 | 24 | def build_cotracker(checkpoint=None): 25 | cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True) 26 | 27 | if checkpoint is not None: 28 | with open(checkpoint, "rb") as f: 29 | state_dict = torch.load(f, map_location="cpu") 30 | if "model" in state_dict: 31 | state_dict = state_dict["model"] 32 | cotracker.load_state_dict(state_dict) 33 | return cotracker 34 | -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/models/core/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/__pycache__/embeddings.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/models/core/__pycache__/embeddings.cpython-310.pyc -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/__pycache__/model_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/models/core/__pycache__/model_utils.cpython-310.pyc -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/cotracker/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/cotracker/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/models/core/cotracker/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/cotracker/__pycache__/blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/models/core/cotracker/__pycache__/blocks.cpython-310.pyc -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/cotracker/__pycache__/cotracker.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/models/core/cotracker/__pycache__/cotracker.cpython-310.pyc -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/cotracker/blocks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from functools import partial 11 | from typing import Callable 12 | import collections 13 | from torch import Tensor 14 | from itertools import repeat 15 | 16 | from cotracker.models.core.model_utils import bilinear_sampler 17 | 18 | 19 | # From PyTorch internals 20 | def _ntuple(n): 21 | def parse(x): 22 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 23 | return tuple(x) 24 | return tuple(repeat(x, n)) 25 | 26 | return parse 27 | 28 | 29 | def exists(val): 30 | return val is not None 31 | 32 | 33 | def default(val, d): 34 | return val if exists(val) else d 35 | 36 | 37 | to_2tuple = _ntuple(2) 38 | 39 | 40 | class Mlp(nn.Module): 41 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 42 | 43 | def __init__( 44 | self, 45 | in_features, 46 | hidden_features=None, 47 | out_features=None, 48 | act_layer=nn.GELU, 49 | norm_layer=None, 50 | bias=True, 51 | drop=0.0, 52 | use_conv=False, 53 | ): 54 | super().__init__() 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | bias = to_2tuple(bias) 58 | drop_probs = to_2tuple(drop) 59 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 60 | 61 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) 62 | self.act = act_layer() 63 | self.drop1 = nn.Dropout(drop_probs[0]) 64 | self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() 65 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) 66 | self.drop2 = nn.Dropout(drop_probs[1]) 67 | 68 | def forward(self, x): 69 | x = self.fc1(x) 70 | x = self.act(x) 71 | x = self.drop1(x) 72 | x = self.fc2(x) 73 | x = self.drop2(x) 74 | return x 75 | 76 | 77 | class ResidualBlock(nn.Module): 78 | def __init__(self, in_planes, planes, norm_fn="group", stride=1): 79 | super(ResidualBlock, self).__init__() 80 | 81 | self.conv1 = nn.Conv2d( 82 | in_planes, 83 | planes, 84 | kernel_size=3, 85 | padding=1, 86 | stride=stride, 87 | padding_mode="zeros", 88 | ) 89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros") 90 | self.relu = nn.ReLU(inplace=True) 91 | 92 | num_groups = planes // 8 93 | 94 | if norm_fn == "group": 95 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 96 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 97 | if not stride == 1: 98 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 99 | 100 | elif norm_fn == "batch": 101 | self.norm1 = nn.BatchNorm2d(planes) 102 | self.norm2 = nn.BatchNorm2d(planes) 103 | if not stride == 1: 104 | self.norm3 = nn.BatchNorm2d(planes) 105 | 106 | elif norm_fn == "instance": 107 | self.norm1 = nn.InstanceNorm2d(planes) 108 | self.norm2 = nn.InstanceNorm2d(planes) 109 | if not stride == 1: 110 | self.norm3 = nn.InstanceNorm2d(planes) 111 | 112 | elif norm_fn == "none": 113 | self.norm1 = nn.Sequential() 114 | self.norm2 = nn.Sequential() 115 | if not stride == 1: 116 | self.norm3 = nn.Sequential() 117 | 118 | if stride == 1: 119 | self.downsample = None 120 | 121 | else: 122 | self.downsample = nn.Sequential( 123 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 124 | ) 125 | 126 | def forward(self, x): 127 | y = x 128 | y = 
self.relu(self.norm1(self.conv1(y))) 129 | y = self.relu(self.norm2(self.conv2(y))) 130 | 131 | if self.downsample is not None: 132 | x = self.downsample(x) 133 | 134 | return self.relu(x + y) 135 | 136 | 137 | class BasicEncoder(nn.Module): 138 | def __init__(self, input_dim=3, output_dim=128, stride=4): 139 | super(BasicEncoder, self).__init__() 140 | self.stride = stride 141 | self.norm_fn = "instance" 142 | self.in_planes = output_dim // 2 143 | 144 | self.norm1 = nn.InstanceNorm2d(self.in_planes) 145 | self.norm2 = nn.InstanceNorm2d(output_dim * 2) 146 | 147 | self.conv1 = nn.Conv2d( 148 | input_dim, 149 | self.in_planes, 150 | kernel_size=7, 151 | stride=2, 152 | padding=3, 153 | padding_mode="zeros", 154 | ) 155 | self.relu1 = nn.ReLU(inplace=True) 156 | self.layer1 = self._make_layer(output_dim // 2, stride=1) 157 | self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) 158 | self.layer3 = self._make_layer(output_dim, stride=2) 159 | self.layer4 = self._make_layer(output_dim, stride=2) 160 | 161 | self.conv2 = nn.Conv2d( 162 | output_dim * 3 + output_dim // 4, 163 | output_dim * 2, 164 | kernel_size=3, 165 | padding=1, 166 | padding_mode="zeros", 167 | ) 168 | self.relu2 = nn.ReLU(inplace=True) 169 | self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) 170 | for m in self.modules(): 171 | if isinstance(m, nn.Conv2d): 172 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 173 | elif isinstance(m, (nn.InstanceNorm2d)): 174 | if m.weight is not None: 175 | nn.init.constant_(m.weight, 1) 176 | if m.bias is not None: 177 | nn.init.constant_(m.bias, 0) 178 | 179 | def _make_layer(self, dim, stride=1): 180 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 181 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 182 | layers = (layer1, layer2) 183 | 184 | self.in_planes = dim 185 | return nn.Sequential(*layers) 186 | 187 | def forward(self, x): 188 | _, _, H, W = x.shape 189 | 190 | x = self.conv1(x) 191 | x = self.norm1(x) 192 | x = self.relu1(x) 193 | 194 | a = self.layer1(x) 195 | b = self.layer2(a) 196 | c = self.layer3(b) 197 | d = self.layer4(c) 198 | 199 | def _bilinear_intepolate(x): 200 | return F.interpolate( 201 | x, 202 | (H // self.stride, W // self.stride), 203 | mode="bilinear", 204 | align_corners=True, 205 | ) 206 | 207 | a = _bilinear_intepolate(a) 208 | b = _bilinear_intepolate(b) 209 | c = _bilinear_intepolate(c) 210 | d = _bilinear_intepolate(d) 211 | 212 | x = self.conv2(torch.cat([a, b, c, d], dim=1)) 213 | x = self.norm2(x) 214 | x = self.relu2(x) 215 | x = self.conv3(x) 216 | return x 217 | 218 | 219 | class CorrBlock: 220 | def __init__( 221 | self, 222 | fmaps, 223 | num_levels=4, 224 | radius=4, 225 | multiple_track_feats=False, 226 | padding_mode="zeros", 227 | ): 228 | B, S, C, H, W = fmaps.shape 229 | self.S, self.C, self.H, self.W = S, C, H, W 230 | self.padding_mode = padding_mode 231 | self.num_levels = num_levels 232 | self.radius = radius 233 | self.fmaps_pyramid = [] 234 | self.multiple_track_feats = multiple_track_feats 235 | 236 | self.fmaps_pyramid.append(fmaps) 237 | for i in range(self.num_levels - 1): 238 | fmaps_ = fmaps.reshape(B * S, C, H, W) 239 | fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) 240 | _, _, H, W = fmaps_.shape 241 | fmaps = fmaps_.reshape(B, S, C, H, W) 242 | self.fmaps_pyramid.append(fmaps) 243 | 244 | def sample(self, coords): 245 | r = self.radius 246 | B, S, N, D = coords.shape 247 | assert D == 2 248 | 249 | H, W = self.H, self.W 250 | out_pyramid = [] 251 
| for i in range(self.num_levels): 252 | corrs = self.corrs_pyramid[i] # B, S, N, H, W 253 | *_, H, W = corrs.shape 254 | 255 | dx = torch.linspace(-r, r, 2 * r + 1) 256 | dy = torch.linspace(-r, r, 2 * r + 1) 257 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device) 258 | 259 | centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i 260 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 261 | coords_lvl = centroid_lvl + delta_lvl 262 | 263 | corrs = bilinear_sampler( 264 | corrs.reshape(B * S * N, 1, H, W), 265 | coords_lvl, 266 | padding_mode=self.padding_mode, 267 | ) 268 | corrs = corrs.view(B, S, N, -1) 269 | out_pyramid.append(corrs) 270 | 271 | out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2 272 | out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float() 273 | return out 274 | 275 | def corr(self, targets): 276 | B, S, N, C = targets.shape 277 | if self.multiple_track_feats: 278 | targets_split = targets.split(C // self.num_levels, dim=-1) 279 | B, S, N, C = targets_split[0].shape 280 | 281 | assert C == self.C 282 | assert S == self.S 283 | 284 | fmap1 = targets 285 | 286 | self.corrs_pyramid = [] 287 | for i, fmaps in enumerate(self.fmaps_pyramid): 288 | *_, H, W = fmaps.shape 289 | fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W) 290 | if self.multiple_track_feats: 291 | fmap1 = targets_split[i] 292 | corrs = torch.matmul(fmap1, fmap2s) 293 | corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W 294 | corrs = corrs / torch.sqrt(torch.tensor(C).float()) 295 | self.corrs_pyramid.append(corrs) 296 | 297 | 298 | class Attention(nn.Module): 299 | def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False): 300 | super().__init__() 301 | inner_dim = dim_head * num_heads 302 | context_dim = default(context_dim, query_dim) 303 | self.scale = dim_head**-0.5 304 | self.heads = num_heads 305 | 306 | self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias) 307 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias) 308 | self.to_out = nn.Linear(inner_dim, query_dim) 309 | 310 | def forward(self, x, context=None, attn_bias=None): 311 | B, N1, C = x.shape 312 | h = self.heads 313 | 314 | q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3) 315 | context = default(context, x) 316 | k, v = self.to_kv(context).chunk(2, dim=-1) 317 | 318 | N2 = context.shape[1] 319 | k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) 320 | v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) 321 | 322 | sim = (q @ k.transpose(-2, -1)) * self.scale 323 | 324 | if attn_bias is not None: 325 | sim = sim + attn_bias 326 | attn = sim.softmax(dim=-1) 327 | 328 | x = (attn @ v).transpose(1, 2).reshape(B, N1, C) 329 | return self.to_out(x) 330 | 331 | 332 | class AttnBlock(nn.Module): 333 | def __init__( 334 | self, 335 | hidden_size, 336 | num_heads, 337 | attn_class: Callable[..., nn.Module] = Attention, 338 | mlp_ratio=4.0, 339 | **block_kwargs 340 | ): 341 | super().__init__() 342 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 343 | self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 344 | 345 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 346 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 347 | approx_gelu = lambda: nn.GELU(approximate="tanh") 348 | self.mlp = Mlp( 349 | in_features=hidden_size, 350 | hidden_features=mlp_hidden_dim, 351 | act_layer=approx_gelu, 352 | drop=0, 353 | ) 354 | 355 | 
def forward(self, x, mask=None): 356 | attn_bias = mask 357 | if mask is not None: 358 | mask = ( 359 | (mask[:, None] * mask[:, :, None]) 360 | .unsqueeze(1) 361 | .expand(-1, self.attn.num_heads, -1, -1) 362 | ) 363 | max_neg_value = -torch.finfo(x.dtype).max 364 | attn_bias = (~mask) * max_neg_value 365 | x = x + self.attn(self.norm1(x), attn_bias=attn_bias) 366 | x = x + self.mlp(self.norm2(x)) 367 | return x 368 | -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/cotracker/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from cotracker.models.core.model_utils import reduce_masked_mean 10 | 11 | EPS = 1e-6 12 | 13 | 14 | def balanced_ce_loss(pred, gt, valid=None): 15 | total_balanced_loss = 0.0 16 | for j in range(len(gt)): 17 | B, S, N = gt[j].shape 18 | # pred and gt are the same shape 19 | for (a, b) in zip(pred[j].size(), gt[j].size()): 20 | assert a == b # some shape mismatch! 21 | # if valid is not None: 22 | for (a, b) in zip(pred[j].size(), valid[j].size()): 23 | assert a == b # some shape mismatch! 24 | 25 | pos = (gt[j] > 0.95).float() 26 | neg = (gt[j] < 0.05).float() 27 | 28 | label = pos * 2.0 - 1.0 29 | a = -label * pred[j] 30 | b = F.relu(a) 31 | loss = b + torch.log(torch.exp(-b) + torch.exp(a - b)) 32 | 33 | pos_loss = reduce_masked_mean(loss, pos * valid[j]) 34 | neg_loss = reduce_masked_mean(loss, neg * valid[j]) 35 | 36 | balanced_loss = pos_loss + neg_loss 37 | total_balanced_loss += balanced_loss / float(N) 38 | return total_balanced_loss 39 | 40 | 41 | def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8): 42 | """Loss function defined over sequence of flow predictions""" 43 | total_flow_loss = 0.0 44 | for j in range(len(flow_gt)): 45 | B, S, N, D = flow_gt[j].shape 46 | assert D == 2 47 | B, S1, N = vis[j].shape 48 | B, S2, N = valids[j].shape 49 | assert S == S1 50 | assert S == S2 51 | n_predictions = len(flow_preds[j]) 52 | flow_loss = 0.0 53 | for i in range(n_predictions): 54 | i_weight = gamma ** (n_predictions - i - 1) 55 | flow_pred = flow_preds[j][i] 56 | i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2 57 | i_loss = torch.mean(i_loss, dim=3) # B, S, N 58 | flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j]) 59 | flow_loss = flow_loss / n_predictions 60 | total_flow_loss += flow_loss / float(N) 61 | return total_flow_loss 62 | -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Tuple, Union 8 | import torch 9 | 10 | 11 | def get_2d_sincos_pos_embed( 12 | embed_dim: int, grid_size: Union[int, Tuple[int, int]] 13 | ) -> torch.Tensor: 14 | """ 15 | This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. 16 | It is a wrapper of get_2d_sincos_pos_embed_from_grid. 17 | Args: 18 | - embed_dim: The embedding dimension. 
19 | - grid_size: The grid size. 20 | Returns: 21 | - pos_embed: The generated 2D positional embedding. 22 | """ 23 | if isinstance(grid_size, tuple): 24 | grid_size_h, grid_size_w = grid_size 25 | else: 26 | grid_size_h = grid_size_w = grid_size 27 | grid_h = torch.arange(grid_size_h, dtype=torch.float) 28 | grid_w = torch.arange(grid_size_w, dtype=torch.float) 29 | grid = torch.meshgrid(grid_w, grid_h, indexing="xy") 30 | grid = torch.stack(grid, dim=0) 31 | grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) 34 | 35 | 36 | def get_2d_sincos_pos_embed_from_grid( 37 | embed_dim: int, grid: torch.Tensor 38 | ) -> torch.Tensor: 39 | """ 40 | This function generates a 2D positional embedding from a given grid using sine and cosine functions. 41 | 42 | Args: 43 | - embed_dim: The embedding dimension. 44 | - grid: The grid to generate the embedding from. 45 | 46 | Returns: 47 | - emb: The generated 2D positional embedding. 48 | """ 49 | assert embed_dim % 2 == 0 50 | 51 | # use half of dimensions to encode grid_h 52 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 53 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 54 | 55 | emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) 56 | return emb 57 | 58 | 59 | def get_1d_sincos_pos_embed_from_grid( 60 | embed_dim: int, pos: torch.Tensor 61 | ) -> torch.Tensor: 62 | """ 63 | This function generates a 1D positional embedding from a given grid using sine and cosine functions. 64 | 65 | Args: 66 | - embed_dim: The embedding dimension. 67 | - pos: The position to generate the embedding from. 68 | 69 | Returns: 70 | - emb: The generated 1D positional embedding. 71 | """ 72 | assert embed_dim % 2 == 0 73 | omega = torch.arange(embed_dim // 2, dtype=torch.double) 74 | omega /= embed_dim / 2.0 75 | omega = 1.0 / 10000**omega # (D/2,) 76 | 77 | pos = pos.reshape(-1) # (M,) 78 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 79 | 80 | emb_sin = torch.sin(out) # (M, D/2) 81 | emb_cos = torch.cos(out) # (M, D/2) 82 | 83 | emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 84 | return emb[None].float() 85 | 86 | 87 | def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: 88 | """ 89 | This function generates a 2D positional embedding from given coordinates using sine and cosine functions. 90 | 91 | Args: 92 | - xy: The coordinates to generate the embedding from. 93 | - C: The size of the embedding. 94 | - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. 95 | 96 | Returns: 97 | - pe: The generated 2D positional embedding. 
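    Example (illustrative shapes; the values are made up):
        >>> xy = torch.zeros(2, 16, 2)       # B=2 batches, N=16 points
        >>> pe = get_2d_embedding(xy, C=64)  # sin/cos embedding for x and y, plus the raw coords
        >>> pe.shape
        torch.Size([2, 16, 130])             # 2*C + 2 when cat_coords=True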
98 | """ 99 | B, N, D = xy.shape 100 | assert D == 2 101 | 102 | x = xy[:, :, 0:1] 103 | y = xy[:, :, 1:2] 104 | div_term = ( 105 | torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) 106 | ).reshape(1, 1, int(C / 2)) 107 | 108 | pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 109 | pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) 110 | 111 | pe_x[:, :, 0::2] = torch.sin(x * div_term) 112 | pe_x[:, :, 1::2] = torch.cos(x * div_term) 113 | 114 | pe_y[:, :, 0::2] = torch.sin(y * div_term) 115 | pe_y[:, :, 1::2] = torch.cos(y * div_term) 116 | 117 | pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) 118 | if cat_coords: 119 | pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) 120 | return pe 121 | -------------------------------------------------------------------------------- /cotracker/cotracker/models/core/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from typing import Optional, Tuple 10 | 11 | EPS = 1e-6 12 | 13 | 14 | def smart_cat(tensor1, tensor2, dim): 15 | if tensor1 is None: 16 | return tensor2 17 | return torch.cat([tensor1, tensor2], dim=dim) 18 | 19 | 20 | def get_points_on_a_grid( 21 | size: int, 22 | extent: Tuple[float, ...], 23 | center: Optional[Tuple[float, ...]] = None, 24 | device: Optional[torch.device] = torch.device("cpu"), 25 | ): 26 | r"""Get a grid of points covering a rectangular region 27 | 28 | `get_points_on_a_grid(size, extent)` generates a :attr:`size` by 29 | :attr:`size` grid fo points distributed to cover a rectangular area 30 | specified by `extent`. 31 | 32 | The `extent` is a pair of integer :math:`(H,W)` specifying the height 33 | and width of the rectangle. 34 | 35 | Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)` 36 | specifying the vertical and horizontal center coordinates. The center 37 | defaults to the middle of the extent. 38 | 39 | Points are distributed uniformly within the rectangle leaving a margin 40 | :math:`m=W/64` from the border. 41 | 42 | It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of 43 | points :math:`P_{ij}=(x_i, y_i)` where 44 | 45 | .. math:: 46 | P_{ij} = \left( 47 | c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~ 48 | c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i 49 | \right) 50 | 51 | Points are returned in row-major order. 52 | 53 | Args: 54 | size (int): grid size. 55 | extent (tuple): height and with of the grid extent. 56 | center (tuple, optional): grid center. 57 | device (str, optional): Defaults to `"cpu"`. 58 | 59 | Returns: 60 | Tensor: grid. 
61 | """ 62 | if size == 1: 63 | return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None] 64 | 65 | if center is None: 66 | center = [extent[0] / 2, extent[1] / 2] 67 | 68 | margin = extent[1] / 64 69 | range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin) 70 | range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin) 71 | grid_y, grid_x = torch.meshgrid( 72 | torch.linspace(*range_y, size, device=device), 73 | torch.linspace(*range_x, size, device=device), 74 | indexing="ij", 75 | ) 76 | return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2) 77 | 78 | 79 | def reduce_masked_mean(input, mask, dim=None, keepdim=False): 80 | r"""Masked mean 81 | 82 | `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input` 83 | over a mask :attr:`mask`, returning 84 | 85 | .. math:: 86 | \text{output} = 87 | \frac 88 | {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i} 89 | {\epsilon + \sum_{i=1}^N \text{mask}_i} 90 | 91 | where :math:`N` is the number of elements in :attr:`input` and 92 | :attr:`mask`, and :math:`\epsilon` is a small constant to avoid 93 | division by zero. 94 | 95 | `reduced_masked_mean(x, mask, dim)` computes the mean of a tensor 96 | :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`. 97 | Optionally, the dimension can be kept in the output by setting 98 | :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to 99 | the same dimension as :attr:`input`. 100 | 101 | The interface is similar to `torch.mean()`. 102 | 103 | Args: 104 | inout (Tensor): input tensor. 105 | mask (Tensor): mask. 106 | dim (int, optional): Dimension to sum over. Defaults to None. 107 | keepdim (bool, optional): Keep the summed dimension. Defaults to False. 108 | 109 | Returns: 110 | Tensor: mean tensor. 111 | """ 112 | 113 | mask = mask.expand_as(input) 114 | 115 | prod = input * mask 116 | 117 | if dim is None: 118 | numer = torch.sum(prod) 119 | denom = torch.sum(mask) 120 | else: 121 | numer = torch.sum(prod, dim=dim, keepdim=keepdim) 122 | denom = torch.sum(mask, dim=dim, keepdim=keepdim) 123 | 124 | mean = numer / (EPS + denom) 125 | return mean 126 | 127 | 128 | def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): 129 | r"""Sample a tensor using bilinear interpolation 130 | 131 | `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at 132 | coordinates :attr:`coords` using bilinear interpolation. It is the same 133 | as `torch.nn.functional.grid_sample()` but with a different coordinate 134 | convention. 135 | 136 | The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where 137 | :math:`B` is the batch size, :math:`C` is the number of channels, 138 | :math:`H` is the height of the image, and :math:`W` is the width of the 139 | image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is 140 | interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. 141 | 142 | Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, 143 | in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note 144 | that in this case the order of the components is slightly different 145 | from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. 146 | 147 | If `align_corners` is `True`, the coordinate :math:`x` is assumed to be 148 | in the range :math:`[0,W-1]`, with 0 corresponding to the center of the 149 | left-most image pixel :math:`W-1` to the center of the right-most 150 | pixel. 
151 | 152 | If `align_corners` is `False`, the coordinate :math:`x` is assumed to 153 | be in the range :math:`[0,W]`, with 0 corresponding to the left edge of 154 | the left-most pixel :math:`W` to the right edge of the right-most 155 | pixel. 156 | 157 | Similar conventions apply to the :math:`y` for the range 158 | :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range 159 | :math:`[0,T-1]` and :math:`[0,T]`. 160 | 161 | Args: 162 | input (Tensor): batch of input images. 163 | coords (Tensor): batch of coordinates. 164 | align_corners (bool, optional): Coordinate convention. Defaults to `True`. 165 | padding_mode (str, optional): Padding mode. Defaults to `"border"`. 166 | 167 | Returns: 168 | Tensor: sampled points. 169 | """ 170 | 171 | sizes = input.shape[2:] 172 | 173 | assert len(sizes) in [2, 3] 174 | 175 | if len(sizes) == 3: 176 | # t x y -> x y t to match dimensions T H W in grid_sample 177 | coords = coords[..., [1, 2, 0]] 178 | 179 | if align_corners: 180 | coords = coords * torch.tensor( 181 | [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device 182 | ) 183 | else: 184 | coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) 185 | 186 | coords -= 1 187 | 188 | return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) 189 | 190 | 191 | def sample_features4d(input, coords): 192 | r"""Sample spatial features 193 | 194 | `sample_features4d(input, coords)` samples the spatial features 195 | :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. 196 | 197 | The field is sampled at coordinates :attr:`coords` using bilinear 198 | interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, 199 | 3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the 200 | same convention as :func:`bilinear_sampler` with `align_corners=True`. 201 | 202 | The output tensor has one feature per point, and has shape :math:`(B, 203 | R, C)`. 204 | 205 | Args: 206 | input (Tensor): spatial features. 207 | coords (Tensor): points. 208 | 209 | Returns: 210 | Tensor: sampled features. 211 | """ 212 | 213 | B, _, _, _ = input.shape 214 | 215 | # B R 2 -> B R 1 2 216 | coords = coords.unsqueeze(2) 217 | 218 | # B C R 1 219 | feats = bilinear_sampler(input, coords) 220 | 221 | return feats.permute(0, 2, 1, 3).view( 222 | B, -1, feats.shape[1] * feats.shape[3] 223 | ) # B C R 1 -> B R C 224 | 225 | 226 | def sample_features5d(input, coords): 227 | r"""Sample spatio-temporal features 228 | 229 | `sample_features5d(input, coords)` works in the same way as 230 | :func:`sample_features4d` but for spatio-temporal features and points: 231 | :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is 232 | a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i, 233 | x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`. 234 | 235 | Args: 236 | input (Tensor): spatio-temporal features. 237 | coords (Tensor): spatio-temporal points. 238 | 239 | Returns: 240 | Tensor: sampled features. 
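    Example (shape sketch only; the tensors are random placeholders)::

        >>> feats = torch.randn(1, 8, 128, 48, 64)   # B, T, C, H, W
        >>> pts = torch.zeros(1, 5, 9, 3)            # B, R1, R2, (t, x, y)
        >>> sample_features5d(feats, pts).shape
        torch.Size([1, 5, 9, 128])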
241 | """ 242 | 243 | B, T, _, _, _ = input.shape 244 | 245 | # B T C H W -> B C T H W 246 | input = input.permute(0, 2, 1, 3, 4) 247 | 248 | # B R1 R2 3 -> B R1 R2 1 3 249 | coords = coords.unsqueeze(3) 250 | 251 | # B C R1 R2 1 252 | feats = bilinear_sampler(input, coords) 253 | 254 | return feats.permute(0, 2, 3, 1, 4).view( 255 | B, feats.shape[2], feats.shape[3], feats.shape[1] 256 | ) # B C R1 R2 1 -> B R1 R2 C 257 | -------------------------------------------------------------------------------- /cotracker/cotracker/models/evaluation_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from typing import Tuple 10 | 11 | from cotracker.models.core.cotracker.cotracker import CoTracker2 12 | from cotracker.models.core.model_utils import get_points_on_a_grid 13 | 14 | 15 | class EvaluationPredictor(torch.nn.Module): 16 | def __init__( 17 | self, 18 | cotracker_model: CoTracker2, 19 | interp_shape: Tuple[int, int] = (384, 512), 20 | grid_size: int = 5, 21 | local_grid_size: int = 8, 22 | single_point: bool = True, 23 | n_iters: int = 6, 24 | ) -> None: 25 | super(EvaluationPredictor, self).__init__() 26 | self.grid_size = grid_size 27 | self.local_grid_size = local_grid_size 28 | self.single_point = single_point 29 | self.interp_shape = interp_shape 30 | self.n_iters = n_iters 31 | 32 | self.model = cotracker_model 33 | self.model.eval() 34 | 35 | def forward(self, video, queries): 36 | queries = queries.clone() 37 | B, T, C, H, W = video.shape 38 | B, N, D = queries.shape 39 | 40 | assert D == 3 41 | 42 | video = video.reshape(B * T, C, H, W) 43 | video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True) 44 | video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) 45 | 46 | device = video.device 47 | 48 | queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1) 49 | queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1) 50 | 51 | if self.single_point: 52 | traj_e = torch.zeros((B, T, N, 2), device=device) 53 | vis_e = torch.zeros((B, T, N), device=device) 54 | for pind in range((N)): 55 | query = queries[:, pind : pind + 1] 56 | 57 | t = query[0, 0, 0].long() 58 | 59 | traj_e_pind, vis_e_pind = self._process_one_point(video, query) 60 | traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1] 61 | vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1] 62 | else: 63 | if self.grid_size > 0: 64 | xy = get_points_on_a_grid(self.grid_size, video.shape[3:]) 65 | xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) # 66 | queries = torch.cat([queries, xy], dim=1) # 67 | 68 | traj_e, vis_e, __ = self.model( 69 | video=video, 70 | queries=queries, 71 | iters=self.n_iters, 72 | ) 73 | 74 | traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1) 75 | traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1) 76 | return traj_e, vis_e 77 | 78 | def _process_one_point(self, video, query): 79 | t = query[0, 0, 0].long() 80 | 81 | device = query.device 82 | if self.local_grid_size > 0: 83 | xy_target = get_points_on_a_grid( 84 | self.local_grid_size, 85 | (50, 50), 86 | [query[0, 0, 2].item(), query[0, 0, 1].item()], 87 | ) 88 | 89 | xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], 
dim=2).to( 90 | device 91 | ) # 92 | query = torch.cat([query, xy_target], dim=1) # 93 | 94 | if self.grid_size > 0: 95 | xy = get_points_on_a_grid(self.grid_size, video.shape[3:]) 96 | xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) # 97 | query = torch.cat([query, xy], dim=1) # 98 | # crop the video to start from the queried frame 99 | query[0, 0, 0] = 0 100 | traj_e_pind, vis_e_pind, __ = self.model( 101 | video=video[:, t:], queries=query, iters=self.n_iters 102 | ) 103 | 104 | return traj_e_pind, vis_e_pind 105 | -------------------------------------------------------------------------------- /cotracker/cotracker/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid 11 | from cotracker.models.build_cotracker import build_cotracker 12 | 13 | 14 | class CoTrackerPredictor(torch.nn.Module): 15 | def __init__(self, checkpoint="./checkpoints/cotracker2.pth"): 16 | super().__init__() 17 | self.support_grid_size = 6 18 | model = build_cotracker(checkpoint) 19 | self.interp_shape = model.model_resolution 20 | self.model = model 21 | self.model.eval() 22 | 23 | @torch.no_grad() 24 | def forward( 25 | self, 26 | video, # (B, T, 3, H, W) 27 | # input prompt types: 28 | # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. 29 | # *backward_tracking=True* will compute tracks in both directions. 30 | # - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates. 31 | # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask. 32 | # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. 
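        # A minimal call sketch (hedged illustration; the sizes and coordinates below are made up):
        #   predictor = CoTrackerPredictor(checkpoint="./checkpoints/cotracker2.pth")
        #   video = torch.randn(1, 24, 3, 384, 512)        # (B, T, 3, H, W)
        #   queries = torch.tensor([[[0., 400., 350.]]])   # (B, N, 3) as (t, x, y)
        #   tracks, visibilities = predictor(video, queries=queries)
        #   # or, for a regular grid from the first frame:
        #   tracks, visibilities = predictor(video, grid_size=10)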
33 | queries: torch.Tensor = None, 34 | segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W) 35 | grid_size: int = 0, 36 | grid_query_frame: int = 0, # only for dense and regular grid tracks 37 | backward_tracking: bool = False, 38 | ): 39 | if queries is None and grid_size == 0: 40 | tracks, visibilities = self._compute_dense_tracks( 41 | video, 42 | grid_query_frame=grid_query_frame, 43 | backward_tracking=backward_tracking, 44 | ) 45 | else: 46 | tracks, visibilities = self._compute_sparse_tracks( 47 | video, 48 | queries, 49 | segm_mask, 50 | grid_size, 51 | add_support_grid=(grid_size == 0 or segm_mask is not None), 52 | grid_query_frame=grid_query_frame, 53 | backward_tracking=backward_tracking, 54 | ) 55 | 56 | return tracks, visibilities 57 | 58 | def _compute_dense_tracks( 59 | self, video, grid_query_frame, grid_size=80, backward_tracking=False 60 | ): 61 | *_, H, W = video.shape 62 | grid_step = W // grid_size 63 | grid_width = W // grid_step 64 | grid_height = H // grid_step 65 | tracks = visibilities = None 66 | grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device) 67 | grid_pts[0, :, 0] = grid_query_frame 68 | for offset in range(grid_step * grid_step): 69 | print(f"step {offset} / {grid_step * grid_step}") 70 | ox = offset % grid_step 71 | oy = offset // grid_step 72 | grid_pts[0, :, 1] = ( 73 | torch.arange(grid_width).repeat(grid_height) * grid_step + ox 74 | ) 75 | grid_pts[0, :, 2] = ( 76 | torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy 77 | ) 78 | tracks_step, visibilities_step = self._compute_sparse_tracks( 79 | video=video, 80 | queries=grid_pts, 81 | backward_tracking=backward_tracking, 82 | ) 83 | tracks = smart_cat(tracks, tracks_step, dim=2) 84 | visibilities = smart_cat(visibilities, visibilities_step, dim=2) 85 | 86 | return tracks, visibilities 87 | 88 | def _compute_sparse_tracks( 89 | self, 90 | video, 91 | queries, 92 | segm_mask=None, 93 | grid_size=0, 94 | add_support_grid=False, 95 | grid_query_frame=0, 96 | backward_tracking=False, 97 | ): 98 | B, T, C, H, W = video.shape 99 | 100 | video = video.reshape(B * T, C, H, W) 101 | video = F.interpolate( 102 | video, tuple(self.interp_shape), mode="bilinear", align_corners=True 103 | ) 104 | video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) 105 | 106 | if queries is not None: 107 | B, N, D = queries.shape 108 | assert D == 3 109 | queries = queries.clone() 110 | queries[:, :, 1:] *= queries.new_tensor( 111 | [ 112 | (self.interp_shape[1] - 1) / (W - 1), 113 | (self.interp_shape[0] - 1) / (H - 1), 114 | ] 115 | ) 116 | elif grid_size > 0: 117 | grid_pts = get_points_on_a_grid( 118 | grid_size, self.interp_shape, device=video.device 119 | ) 120 | if segm_mask is not None: 121 | segm_mask = F.interpolate( 122 | segm_mask, tuple(self.interp_shape), mode="nearest" 123 | ) 124 | point_mask = segm_mask[0, 0][ 125 | (grid_pts[0, :, 1]).round().long().cpu(), 126 | (grid_pts[0, :, 0]).round().long().cpu(), 127 | ].bool() 128 | grid_pts = grid_pts[:, point_mask] 129 | 130 | queries = torch.cat( 131 | [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], 132 | dim=2, 133 | ).repeat(B, 1, 1) 134 | 135 | if add_support_grid: 136 | grid_pts = get_points_on_a_grid( 137 | self.support_grid_size, self.interp_shape, device=video.device 138 | ) 139 | grid_pts = torch.cat( 140 | [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 141 | ) 142 | grid_pts = grid_pts.repeat(B, 1, 1) 143 | queries = torch.cat([queries, 
grid_pts], dim=1) 144 | 145 | tracks, visibilities, __ = self.model.forward( 146 | video=video, queries=queries, iters=6 147 | ) 148 | 149 | if backward_tracking: 150 | tracks, visibilities = self._compute_backward_tracks( 151 | video, queries, tracks, visibilities 152 | ) 153 | if add_support_grid: 154 | queries[:, -(self.support_grid_size**2) :, 0] = T - 1 155 | if add_support_grid: 156 | tracks = tracks[:, :, : -(self.support_grid_size**2)] 157 | visibilities = visibilities[:, :, : -(self.support_grid_size**2)] 158 | thr = 0.9 159 | visibilities = visibilities > thr 160 | 161 | # correct query-point predictions 162 | # see https://github.com/facebookresearch/co-tracker/issues/28 163 | 164 | # TODO: batchify 165 | for i in range(len(queries)): 166 | queries_t = queries[i, : tracks.size(2), 0].to(torch.int64) 167 | arange = torch.arange(0, len(queries_t)) 168 | 169 | # overwrite the predictions with the query points 170 | tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:] 171 | 172 | # correct visibilities, the query points should be visible 173 | visibilities[i, queries_t, arange] = True 174 | 175 | tracks *= tracks.new_tensor( 176 | [(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)] 177 | ) 178 | return tracks, visibilities 179 | 180 | def _compute_backward_tracks(self, video, queries, tracks, visibilities): 181 | inv_video = video.flip(1).clone() 182 | inv_queries = queries.clone() 183 | inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 184 | 185 | inv_tracks, inv_visibilities, __ = self.model( 186 | video=inv_video, queries=inv_queries, iters=6 187 | ) 188 | 189 | inv_tracks = inv_tracks.flip(1) 190 | inv_visibilities = inv_visibilities.flip(1) 191 | arange = torch.arange(video.shape[1], device=queries.device)[None, :, None] 192 | 193 | mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2) 194 | 195 | tracks[mask] = inv_tracks[mask] 196 | visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]] 197 | return tracks, visibilities 198 | 199 | 200 | class CoTrackerOnlinePredictor(torch.nn.Module): 201 | def __init__(self, checkpoint="./checkpoints/cotracker2.pth"): 202 | super().__init__() 203 | self.support_grid_size = 6 204 | model = build_cotracker(checkpoint) 205 | self.interp_shape = model.model_resolution 206 | self.step = model.window_len // 2 207 | self.model = model 208 | self.model.eval() 209 | 210 | @torch.no_grad() 211 | def forward( 212 | self, 213 | video_chunk, 214 | is_first_step: bool = False, 215 | queries: torch.Tensor = None, 216 | grid_size: int = 10, 217 | grid_query_frame: int = 0, 218 | add_support_grid=False, 219 | segm_mask: torch.Tensor = None, 220 | ): 221 | B, T, C, H, W = video_chunk.shape 222 | # Initialize online video processing and save queried points 223 | # This needs to be done before processing *each new video* 224 | if is_first_step: 225 | self.model.init_video_online_processing() 226 | if queries is not None: 227 | B, N, D = queries.shape 228 | assert D == 3 229 | queries = queries.clone() 230 | queries[:, :, 1:] *= queries.new_tensor( 231 | [ 232 | (self.interp_shape[1] - 1) / (W - 1), 233 | (self.interp_shape[0] - 1) / (H - 1), 234 | ] 235 | ) 236 | elif grid_size > 0: 237 | grid_pts = get_points_on_a_grid( 238 | grid_size, self.interp_shape, device=video_chunk.device 239 | ) 240 | if segm_mask is not None: 241 | segm_mask = F.interpolate( 242 | segm_mask, tuple(self.interp_shape), mode="nearest" 243 | ) 244 | point_mask = segm_mask[0, 0][ 245 | (grid_pts[0, 
:, 1]).round().long().cpu(), 246 | (grid_pts[0, :, 0]).round().long().cpu(), 247 | ].bool() 248 | grid_pts = grid_pts[:, point_mask] 249 | 250 | queries = torch.cat( 251 | [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], 252 | dim=2, 253 | ) 254 | if add_support_grid: 255 | grid_pts = get_points_on_a_grid( 256 | self.support_grid_size, self.interp_shape, device=video_chunk.device 257 | ) 258 | grid_pts = torch.cat( 259 | [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 260 | ) 261 | queries = torch.cat([queries, grid_pts], dim=1) 262 | self.queries = queries 263 | return (None, None) 264 | 265 | video_chunk = video_chunk.reshape(B * T, C, H, W) 266 | video_chunk = F.interpolate( 267 | video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True 268 | ) 269 | video_chunk = video_chunk.reshape( 270 | B, T, 3, self.interp_shape[0], self.interp_shape[1] 271 | ) 272 | 273 | tracks, visibilities, __ = self.model( 274 | video=video_chunk, 275 | queries=self.queries, 276 | iters=6, 277 | is_online=True, 278 | ) 279 | thr = 0.9 280 | return ( 281 | tracks 282 | * tracks.new_tensor( 283 | [ 284 | (W - 1) / (self.interp_shape[1] - 1), 285 | (H - 1) / (self.interp_shape[0] - 1), 286 | ] 287 | ), 288 | visibilities > thr, 289 | ) 290 | -------------------------------------------------------------------------------- /cotracker/cotracker/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /cotracker/cotracker/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /cotracker/cotracker/utils/__pycache__/visualizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/cotracker/cotracker/utils/__pycache__/visualizer.cpython-310.pyc -------------------------------------------------------------------------------- /cotracker/cotracker/utils/visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
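# Typical usage (an illustrative sketch; shapes follow the visualize() signature below):
#   vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3, fps=25)
#   vis.visualize(video, tracks, visibility, filename="demo")
# with `video` of shape (B, T, C, H, W), `tracks` of shape (B, T, N, 2) and an optional
# boolean `visibility` of shape (B, T, N, 1).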
6 | import os 7 | import numpy as np 8 | import imageio 9 | import torch 10 | 11 | from matplotlib import cm 12 | import torch.nn.functional as F 13 | import torchvision.transforms as transforms 14 | import matplotlib.pyplot as plt 15 | from PIL import Image, ImageDraw 16 | 17 | 18 | def read_video_from_path(path): 19 | try: 20 | reader = imageio.get_reader(path) 21 | except Exception as e: 22 | print("Error opening video file: ", e) 23 | return None 24 | frames = [] 25 | for i, im in enumerate(reader): 26 | frames.append(np.array(im)) 27 | return np.stack(frames) 28 | 29 | 30 | def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True): 31 | # Create a draw object 32 | draw = ImageDraw.Draw(rgb) 33 | # Calculate the bounding box of the circle 34 | left_up_point = (coord[0] - radius, coord[1] - radius) 35 | right_down_point = (coord[0] + radius, coord[1] + radius) 36 | # Draw the circle 37 | draw.ellipse( 38 | [left_up_point, right_down_point], 39 | fill=tuple(color) if visible else None, 40 | outline=tuple(color), 41 | ) 42 | return rgb 43 | 44 | 45 | def draw_line(rgb, coord_y, coord_x, color, linewidth): 46 | draw = ImageDraw.Draw(rgb) 47 | draw.line( 48 | (coord_y[0], coord_y[1], coord_x[0], coord_x[1]), 49 | fill=tuple(color), 50 | width=linewidth, 51 | ) 52 | return rgb 53 | 54 | 55 | def add_weighted(rgb, alpha, original, beta, gamma): 56 | return (rgb * alpha + original * beta + gamma).astype("uint8") 57 | 58 | 59 | class Visualizer: 60 | def __init__( 61 | self, 62 | save_dir: str = "./results", 63 | grayscale: bool = False, 64 | pad_value: int = 0, 65 | fps: int = 10, 66 | mode: str = "rainbow", # 'cool', 'optical_flow' 67 | linewidth: int = 2, 68 | show_first_frame: int = 10, 69 | tracks_leave_trace: int = 0, # -1 for infinite 70 | ): 71 | self.mode = mode 72 | self.save_dir = save_dir 73 | if mode == "rainbow": 74 | self.color_map = cm.get_cmap("gist_rainbow") 75 | elif mode == "cool": 76 | self.color_map = cm.get_cmap(mode) 77 | self.show_first_frame = show_first_frame 78 | self.grayscale = grayscale 79 | self.tracks_leave_trace = tracks_leave_trace 80 | self.pad_value = pad_value 81 | self.linewidth = linewidth 82 | self.fps = fps 83 | 84 | def visualize( 85 | self, 86 | video: torch.Tensor, # (B,T,C,H,W) 87 | tracks: torch.Tensor, # (B,T,N,2) 88 | visibility: torch.Tensor = None, # (B, T, N, 1) bool 89 | gt_tracks: torch.Tensor = None, # (B,T,N,2) 90 | segm_mask: torch.Tensor = None, # (B,1,H,W) 91 | filename: str = "video", 92 | writer=None, # tensorboard Summary Writer, used for visualization during training 93 | step: int = 0, 94 | query_frame: int = 0, 95 | save_video: bool = True, 96 | compensate_for_camera_motion: bool = False, 97 | ): 98 | if compensate_for_camera_motion: 99 | assert segm_mask is not None 100 | if segm_mask is not None: 101 | coords = tracks[0, query_frame].round().long() 102 | segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long() 103 | 104 | video = F.pad( 105 | video, 106 | (self.pad_value, self.pad_value, self.pad_value, self.pad_value), 107 | "constant", 108 | 255, 109 | ) 110 | tracks = tracks + self.pad_value 111 | 112 | if self.grayscale: 113 | transform = transforms.Grayscale() 114 | video = transform(video) 115 | video = video.repeat(1, 1, 3, 1, 1) 116 | 117 | res_video = self.draw_tracks_on_video( 118 | video=video, 119 | tracks=tracks, 120 | visibility=visibility, 121 | segm_mask=segm_mask, 122 | gt_tracks=gt_tracks, 123 | query_frame=query_frame, 124 | compensate_for_camera_motion=compensate_for_camera_motion, 
125 | ) 126 | if save_video: 127 | self.save_video(res_video, filename=filename, writer=writer, step=step) 128 | return res_video 129 | 130 | def save_video(self, video, filename, writer=None, step=0): 131 | if writer is not None: 132 | writer.add_video( 133 | filename, 134 | video.to(torch.uint8), 135 | global_step=step, 136 | fps=self.fps, 137 | ) 138 | else: 139 | os.makedirs(self.save_dir, exist_ok=True) 140 | wide_list = list(video.unbind(1)) 141 | wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list] 142 | 143 | # Prepare the video file path 144 | save_path = os.path.join(self.save_dir, f"{filename}.mp4") 145 | 146 | # Create a writer object 147 | video_writer = imageio.get_writer(save_path, fps=self.fps) 148 | 149 | # Write frames to the video file 150 | for frame in wide_list[2:-1]: 151 | video_writer.append_data(frame) 152 | 153 | video_writer.close() 154 | 155 | print(f"Video saved to {save_path}") 156 | 157 | def draw_tracks_on_video( 158 | self, 159 | video: torch.Tensor, 160 | tracks: torch.Tensor, 161 | visibility: torch.Tensor = None, 162 | segm_mask: torch.Tensor = None, 163 | gt_tracks=None, 164 | query_frame: int = 0, 165 | compensate_for_camera_motion=False, 166 | ): 167 | B, T, C, H, W = video.shape 168 | _, _, N, D = tracks.shape 169 | 170 | assert D == 2 171 | assert C == 3 172 | video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C 173 | tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2 174 | if gt_tracks is not None: 175 | gt_tracks = gt_tracks[0].detach().cpu().numpy() 176 | 177 | res_video = [] 178 | 179 | # process input video 180 | for rgb in video: 181 | res_video.append(rgb.copy()) 182 | vector_colors = np.zeros((T, N, 3)) 183 | 184 | if self.mode == "optical_flow": 185 | import flow_vis 186 | 187 | vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None]) 188 | elif segm_mask is None: 189 | if self.mode == "rainbow": 190 | y_min, y_max = ( 191 | tracks[query_frame, :, 1].min(), 192 | tracks[query_frame, :, 1].max(), 193 | ) 194 | norm = plt.Normalize(y_min, y_max) 195 | for n in range(N): 196 | color = self.color_map(norm(tracks[query_frame, n, 1])) 197 | color = np.array(color[:3])[None] * 255 198 | vector_colors[:, n] = np.repeat(color, T, axis=0) 199 | else: 200 | # color changes with time 201 | for t in range(T): 202 | color = np.array(self.color_map(t / T)[:3])[None] * 255 203 | vector_colors[t] = np.repeat(color, N, axis=0) 204 | else: 205 | if self.mode == "rainbow": 206 | vector_colors[:, segm_mask <= 0, :] = 255 207 | 208 | y_min, y_max = ( 209 | tracks[0, segm_mask > 0, 1].min(), 210 | tracks[0, segm_mask > 0, 1].max(), 211 | ) 212 | norm = plt.Normalize(y_min, y_max) 213 | for n in range(N): 214 | if segm_mask[n] > 0: 215 | color = self.color_map(norm(tracks[0, n, 1])) 216 | color = np.array(color[:3])[None] * 255 217 | vector_colors[:, n] = np.repeat(color, T, axis=0) 218 | 219 | else: 220 | # color changes with segm class 221 | segm_mask = segm_mask.cpu() 222 | color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32) 223 | color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0 224 | color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0 225 | vector_colors = np.repeat(color[None], T, axis=0) 226 | 227 | # draw tracks 228 | if self.tracks_leave_trace != 0: 229 | for t in range(query_frame + 1, T): 230 | first_ind = ( 231 | max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0 232 | ) 233 | curr_tracks = tracks[first_ind : t + 1] 
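                # Note: when compensate_for_camera_motion is set, the block below subtracts the
                # mean displacement of the background points (segm_mask <= 0) and only draws the
                # foreground tracks (segm_mask > 0).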
234 | curr_colors = vector_colors[first_ind : t + 1] 235 | if compensate_for_camera_motion: 236 | diff = ( 237 | tracks[first_ind : t + 1, segm_mask <= 0] 238 | - tracks[t : t + 1, segm_mask <= 0] 239 | ).mean(1)[:, None] 240 | 241 | curr_tracks = curr_tracks - diff 242 | curr_tracks = curr_tracks[:, segm_mask > 0] 243 | curr_colors = curr_colors[:, segm_mask > 0] 244 | 245 | res_video[t] = self._draw_pred_tracks( 246 | res_video[t], 247 | curr_tracks, 248 | curr_colors, 249 | ) 250 | if gt_tracks is not None: 251 | res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1]) 252 | 253 | # draw points 254 | for t in range(query_frame, T): 255 | img = Image.fromarray(np.uint8(res_video[t])) 256 | for i in range(N): 257 | coord = (tracks[t, i, 0], tracks[t, i, 1]) 258 | visibile = True 259 | if visibility is not None: 260 | visibile = visibility[0, t, i] 261 | if coord[0] != 0 and coord[1] != 0: 262 | if not compensate_for_camera_motion or ( 263 | compensate_for_camera_motion and segm_mask[i] > 0 264 | ): 265 | img = draw_circle( 266 | img, 267 | coord=coord, 268 | radius=int(self.linewidth * 2), 269 | color=vector_colors[t, i].astype(int), 270 | visible=visibile, 271 | ) 272 | res_video[t] = np.array(img) 273 | 274 | # construct the final rgb sequence 275 | if self.show_first_frame > 0: 276 | res_video = [res_video[0]] * self.show_first_frame + res_video[1:] 277 | return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte() 278 | 279 | def _draw_pred_tracks( 280 | self, 281 | rgb: np.ndarray, # H x W x 3 282 | tracks: np.ndarray, # T x 2 283 | vector_colors: np.ndarray, 284 | alpha: float = 0.5, 285 | ): 286 | T, N, _ = tracks.shape 287 | rgb = Image.fromarray(np.uint8(rgb)) 288 | for s in range(T - 1): 289 | vector_color = vector_colors[s] 290 | original = rgb.copy() 291 | alpha = (s / T) ** 2 292 | for i in range(N): 293 | coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1])) 294 | coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1])) 295 | if coord_y[0] != 0 and coord_y[1] != 0: 296 | rgb = draw_line( 297 | rgb, 298 | coord_y, 299 | coord_x, 300 | vector_color[i].astype(int), 301 | self.linewidth, 302 | ) 303 | if self.tracks_leave_trace > 0: 304 | rgb = Image.fromarray( 305 | np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0)) 306 | ) 307 | rgb = np.array(rgb) 308 | return rgb 309 | 310 | def _draw_gt_tracks( 311 | self, 312 | rgb: np.ndarray, # H x W x 3, 313 | gt_tracks: np.ndarray, # T x 2 314 | ): 315 | T, N, _ = gt_tracks.shape 316 | color = np.array((211, 0, 0)) 317 | rgb = Image.fromarray(np.uint8(rgb)) 318 | for t in range(T): 319 | for i in range(N): 320 | gt_tracks = gt_tracks[t][i] 321 | # draw a red cross 322 | if gt_tracks[0] > 0 and gt_tracks[1] > 0: 323 | length = self.linewidth * 3 324 | coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length) 325 | coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length) 326 | rgb = draw_line( 327 | rgb, 328 | coord_y, 329 | coord_x, 330 | color, 331 | self.linewidth, 332 | ) 333 | coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length) 334 | coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length) 335 | rgb = draw_line( 336 | rgb, 337 | coord_y, 338 | coord_x, 339 | color, 340 | self.linewidth, 341 | ) 342 | rgb = np.array(rgb) 343 | return rgb 344 | -------------------------------------------------------------------------------- /cotracker/cotracker/version.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | __version__ = "2.0.0" 9 | -------------------------------------------------------------------------------- /cotracker/track_and_filter_keypoints.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import torch 9 | import argparse 10 | import numpy as np 11 | from tqdm import tqdm 12 | from PIL import Image 13 | import torch.nn.functional as F 14 | import glob 15 | import torchvision 16 | from pathlib import Path 17 | 18 | from cotracker.utils.visualizer import Visualizer 19 | from cotracker.predictor import CoTrackerOnlinePredictor 20 | 21 | # Unfortunately MPS acceleration does not support all the features we require, 22 | # but we may be able to enable it in the future 23 | 24 | DEFAULT_DEVICE = ( 25 | # "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" 26 | "cuda" if torch.cuda.is_available() else "cpu" 27 | ) 28 | 29 | 30 | def select_significant_keypoints(tracked_keypoints, threshold=0.5): 31 | """ 32 | Select keypoints with significant flow changes using Laplacian filtering. 33 | 34 | Args: 35 | tracked_keypoints (torch.Tensor): Tensor of tracked keypoints of shape (num_frames, num_keypoints, 2). 36 | threshold (float): Threshold value for selecting significant flow changes. 37 | 38 | Returns: 39 | torch.Tensor: Indices of keypoints with significant flow changes. 
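    Example (illustrative; note that the function returns a boolean mask over keypoints):
        >>> kps = torch.zeros(30, 100, 2)   # 30 frames, 100 perfectly static keypoints
        >>> mask = select_significant_keypoints(kps, threshold=0.5)
        >>> mask.shape, bool(mask.any())
        (torch.Size([100]), False)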
40 | """ 41 | # Convert tracked keypoints to PyTorch tensor 42 | tracked_keypoints_tensor = torch.tensor(tracked_keypoints) 43 | 44 | # Compute displacement vectors between consecutive frames 45 | displacement_vectors = tracked_keypoints_tensor[1:] - tracked_keypoints_tensor[:-1] 46 | 47 | # Apply Laplacian filter to displacement vectors 48 | laplacian_kernel = torch.tensor( 49 | [[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=torch.float 50 | ).view(1, 1, 3, 3) 51 | laplacian_filtered_vectors = F.conv2d( 52 | displacement_vectors.permute(2, 0, 1).unsqueeze(1), 53 | laplacian_kernel, 54 | padding=1, 55 | ) 56 | 57 | # Compute magnitude of Laplacian-filtered vectors 58 | laplacian_magnitudes = torch.norm(laplacian_filtered_vectors, dim=0) 59 | 60 | # Threshold Laplacian magnitudes to select keypoints with significant flow changes 61 | significant_flow_mask = laplacian_magnitudes > threshold 62 | 63 | result = significant_flow_mask[0, 0] 64 | for i in range(significant_flow_mask.shape[1]): 65 | result = torch.logical_or(result, significant_flow_mask[0, i]) 66 | 67 | # significant_change_mask = ( 68 | # torch.norm(tracked_keypoints_tensor[-1] - tracked_keypoints_tensor[0], dim=-1) 69 | # ) > 14 70 | 71 | # result = torch.logical_and(result, ~significant_change_mask) 72 | 73 | return result 74 | 75 | 76 | if __name__ == "__main__": 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument( 79 | "--img_path", 80 | default="./assets/apple.mp4", 81 | help="path to a imgs", 82 | ) 83 | parser.add_argument( 84 | "--checkpoint", 85 | default=None, 86 | help="CoTracker model parameters", 87 | ) 88 | parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size") 89 | parser.add_argument( 90 | "--grid_query_frame", 91 | type=int, 92 | default=0, 93 | help="Compute dense and grid tracks starting from this frame", 94 | ) 95 | parser.add_argument("--mask-path", type=str, default=None) 96 | 97 | args = parser.parse_args() 98 | 99 | if args.checkpoint is not None: 100 | model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint) 101 | else: 102 | model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online") 103 | model = model.to(DEFAULT_DEVICE) 104 | 105 | segm_mask = ( 106 | np.array(Image.open(os.path.join(args.mask_path))).sum(-1).astype(np.float32) 107 | ) 108 | segm_mask = (torch.from_numpy(segm_mask)[None, None] > 0).to(torch.float32) 109 | 110 | # print(f"{segm_mask.shape=}") 111 | # assert False 112 | 113 | window_frames = [] 114 | 115 | def _process_step(window_frames, is_first_step, grid_size, grid_query_frame): 116 | video_chunk = torch.tensor( 117 | np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE 118 | ).float()[None] # (1, T, 3, H, W) 119 | return model( 120 | video_chunk, 121 | is_first_step=is_first_step, 122 | grid_size=grid_size, 123 | grid_query_frame=grid_query_frame, 124 | segm_mask=segm_mask, 125 | ) 126 | 127 | # Iterating over video frames, processing one window at a time: 128 | img_dir: str = args.img_path 129 | assert os.path.exists(img_dir) 130 | 131 | img_paths = list( 132 | sorted(glob.glob(f"{img_dir}/*.jpg"), key=lambda x: int(Path(x).stem)) 133 | ) 134 | is_first_step = True 135 | for i, path in tqdm(enumerate(img_paths), total=len(img_paths)): 136 | frame = torchvision.io.read_image(path) 137 | if i % model.step == 0 and i != 0: 138 | pred_tracks, pred_visibility = _process_step( 139 | window_frames, 140 | is_first_step, 141 | grid_size=args.grid_size, 142 | grid_query_frame=args.grid_query_frame, 143 | ) 144 | 
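                # The online predictor is fed overlapping windows: each call sees the
                # last 2 * model.step frames, consecutive calls overlap by model.step
                # frames, and only the very first call runs with is_first_step=True,
                # so tracks are chained across window boundaries.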
is_first_step = False 145 | window_frames.append(frame) 146 | 147 | # Processing the final video frames in case video length is not a multiple of model.step 148 | pred_tracks, pred_visibility = _process_step( 149 | window_frames[-(i % model.step) - model.step - 1 :], 150 | is_first_step, 151 | grid_size=args.grid_size, 152 | grid_query_frame=args.grid_query_frame, 153 | ) 154 | 155 | print(" [INFO] Tracks are computed") 156 | 157 | # save a video with predicted tracks 158 | video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE)[None] 159 | 160 | vis = Visualizer( 161 | save_dir="./saved_videos", 162 | pad_value=120, 163 | linewidth=3, 164 | fps=25, 165 | mode="optical_flow", 166 | ) 167 | # vis.visualize( 168 | # video, 169 | # pred_tracks, 170 | # pred_visibility, 171 | # query_frame=args.grid_query_frame, 172 | # filename="initial", 173 | # ) 174 | 175 | laplacian_filtered_mask = select_significant_keypoints( 176 | pred_tracks[0].cpu(), threshold=5 177 | ) 178 | 179 | visibility_mask = pred_visibility.squeeze().sum(0) > (len(window_frames) * 3 // 4) 180 | 181 | laplacian_filtered_mask = torch.logical_and( 182 | laplacian_filtered_mask, visibility_mask.cpu() 183 | ) 184 | 185 | print(" [INFO] Laplacian filter applied.") 186 | 187 | pred_tracks = pred_tracks[:, :, laplacian_filtered_mask] 188 | pred_visibility = pred_visibility[:, :, laplacian_filtered_mask] 189 | 190 | vis.visualize( 191 | video, 192 | pred_tracks, 193 | pred_visibility, 194 | query_frame=args.grid_query_frame, 195 | filename="filtered", 196 | ) 197 | 198 | torch.save(pred_tracks, os.path.join(os.path.dirname(img_dir), "keypoints.pt")) 199 | print(" [INFO] Done!") 200 | -------------------------------------------------------------------------------- /extract_audio_visual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | import librosa 6 | import numpy as np 7 | import librosa.filters 8 | from scipy import signal 9 | from os.path import basename 10 | 11 | 12 | class Conv2d(nn.Module): 13 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, leakyReLU=False, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.conv_block = nn.Sequential( 16 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 17 | nn.BatchNorm2d(cout) 18 | ) 19 | if leakyReLU: 20 | self.act = nn.LeakyReLU(0.02) 21 | else: 22 | self.act = nn.ReLU() 23 | self.residual = residual 24 | 25 | def forward(self, x): 26 | out = self.conv_block(x) 27 | if self.residual: 28 | out += x 29 | return self.act(out) 30 | 31 | 32 | class AudioEncoder(nn.Module): 33 | def __init__(self): 34 | super(AudioEncoder, self).__init__() 35 | 36 | self.audio_encoder = nn.Sequential( 37 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 38 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 39 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 40 | 41 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 42 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 43 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 44 | 45 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 46 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 47 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 48 | 49 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 50 | 
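            # The strided convolutions progressively collapse the (80, 16) mel input:
            # strides (3, 1), 3 and (3, 2) shrink it towards a 1x1 spatial map, and the
            # final padding=0 3x3 conv plus 1x1 conv leave a 512-dim embedding per
            # audio chunk, which forward() squeezes to shape (batch, 512).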
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 51 | 52 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 53 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0), ) 54 | 55 | def forward(self, x): 56 | out = self.audio_encoder(x) 57 | out = out.squeeze(2).squeeze(2) 58 | 59 | return out 60 | 61 | 62 | 63 | def load_wav(path, sr): 64 | return librosa.core.load(path, sr=sr)[0] 65 | 66 | 67 | def preemphasis(wav, k): 68 | return signal.lfilter([1, -k], [1], wav) 69 | 70 | 71 | def melspectrogram(wav): 72 | D = _stft(preemphasis(wav, 0.97)) 73 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - 20 74 | 75 | return _normalize(S) 76 | 77 | 78 | def _stft(y): 79 | return librosa.stft(y=y, n_fft=800, hop_length=200, win_length=800) 80 | 81 | 82 | def _linear_to_mel(spectogram): 83 | global _mel_basis 84 | _mel_basis = _build_mel_basis() 85 | return np.dot(_mel_basis, spectogram) 86 | 87 | 88 | def _build_mel_basis(): 89 | return librosa.filters.mel(sr=16000, n_fft=800, n_mels=80, fmin=55, fmax=7600) 90 | 91 | 92 | def _amp_to_db(x): 93 | min_level = np.exp(-5 * np.log(10)) 94 | return 20 * np.log10(np.maximum(min_level, x)) 95 | 96 | 97 | def _normalize(S): 98 | return np.clip((2 * 4.) * ((S - -100) / (--100)) - 4., -4., 4.) 99 | 100 | 101 | 102 | class AudDataset(object): 103 | def __init__(self, wavpath): 104 | wav = load_wav(wavpath, 16000) 105 | 106 | self.orig_mel = melspectrogram(wav).T 107 | self.data_len = int((self.orig_mel.shape[0] - 16) / 80. * float(25)) 108 | 109 | def get_frame_id(self, frame): 110 | return int(basename(frame).split('.')[0]) 111 | 112 | def crop_audio_window(self, spec, start_frame): 113 | if type(start_frame) == int: 114 | start_frame_num = start_frame 115 | else: 116 | start_frame_num = self.get_frame_id(start_frame) 117 | start_idx = int(80. 
* (start_frame_num / float(25))) 118 | 119 | end_idx = start_idx + 16 120 | 121 | return spec[start_idx: end_idx, :] 122 | 123 | def __len__(self): 124 | return self.data_len 125 | 126 | def __getitem__(self, idx): 127 | 128 | mel = self.crop_audio_window(self.orig_mel.copy(), idx) 129 | if (mel.shape[0] != 16): 130 | raise Exception('mel.shape[0] != 16') 131 | mel = torch.FloatTensor(mel.T).unsqueeze(0) 132 | 133 | return mel 134 | 135 | 136 | import argparse 137 | import os 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument('--aud', type=str, required=True) 140 | 141 | args = parser.parse_args() 142 | 143 | aud_path = args.aud 144 | 145 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 146 | model = AudioEncoder().to(device).eval() 147 | ckpt = torch.load('/home/ssm-user/codes/sam/nerf/SyncTalk/nerf_triplane/checkpoints/audio_visual_encoder.pth') 148 | model.load_state_dict({f'audio_encoder.{k}': v for k, v in ckpt.items()}) 149 | dataset = AudDataset(aud_path) 150 | data_loader = DataLoader(dataset, batch_size=64, shuffle=False) 151 | outputs = [] 152 | for mel in data_loader: 153 | mel = mel.to(device) 154 | with torch.no_grad(): 155 | out = model(mel) 156 | outputs.append(out) 157 | outputs = torch.cat(outputs, dim=0).cpu() 158 | first_frame, last_frame = outputs[:1], outputs[-1:] 159 | aud_features = torch.cat([first_frame.repeat(2, 1), outputs, last_frame.repeat(2, 1)], dim=0).numpy() 160 | output_aud_path = aud_path.replace('.wav', '_ave.npy') 161 | print("aud_features: ", aud_features.shape) 162 | np.save(output_aud_path, aud_features) -------------------------------------------------------------------------------- /face_parsing/79999_iter.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_parsing/79999_iter.pth -------------------------------------------------------------------------------- /face_parsing/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_parsing/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /face_parsing/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_parsing/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /face_parsing/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os.path as osp 6 | import time 7 | import sys 8 | import logging 9 | 10 | import torch.distributed as dist 11 | 12 | 13 | def setup_logger(logpth): 14 | logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) 15 | logfile = osp.join(logpth, logfile) 16 | FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' 17 | log_level = logging.INFO 18 | if dist.is_initialized() and not dist.get_rank()==0: 19 | log_level = logging.ERROR 20 | logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) 21 | logging.root.addHandler(logging.StreamHandler()) 22 | 23 | 24 | 
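# Illustrative usage (not part of the original file): setup_logger only needs an
# existing directory; it writes a timestamped BiSeNet-*.log there and also echoes
# records to the console via the extra StreamHandler, e.g.
#
#     import logging
#     from logger import setup_logger
#     setup_logger('./logs')          # assumes ./logs already exists
#     logging.info('face parsing started')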
-------------------------------------------------------------------------------- /face_parsing/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | from resnet import Resnet18 11 | # from modules.bn import InPlaceABNSync as BatchNorm2d 12 | 13 | 14 | class ConvBNReLU(nn.Module): 15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 16 | super(ConvBNReLU, self).__init__() 17 | self.conv = nn.Conv2d(in_chan, 18 | out_chan, 19 | kernel_size = ks, 20 | stride = stride, 21 | padding = padding, 22 | bias = False) 23 | self.bn = nn.BatchNorm2d(out_chan) 24 | self.init_weight() 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = F.relu(self.bn(x)) 29 | return x 30 | 31 | def init_weight(self): 32 | for ly in self.children(): 33 | if isinstance(ly, nn.Conv2d): 34 | nn.init.kaiming_normal_(ly.weight, a=1) 35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 36 | 37 | class BiSeNetOutput(nn.Module): 38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 39 | super(BiSeNetOutput, self).__init__() 40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) 42 | self.init_weight() 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.conv_out(x) 47 | return x 48 | 49 | def init_weight(self): 50 | for ly in self.children(): 51 | if isinstance(ly, nn.Conv2d): 52 | nn.init.kaiming_normal_(ly.weight, a=1) 53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 54 | 55 | def get_params(self): 56 | wd_params, nowd_params = [], [] 57 | for name, module in self.named_modules(): 58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 59 | wd_params.append(module.weight) 60 | if not module.bias is None: 61 | nowd_params.append(module.bias) 62 | elif isinstance(module, nn.BatchNorm2d): 63 | nowd_params += list(module.parameters()) 64 | return wd_params, nowd_params 65 | 66 | 67 | class AttentionRefinementModule(nn.Module): 68 | def __init__(self, in_chan, out_chan, *args, **kwargs): 69 | super(AttentionRefinementModule, self).__init__() 70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) 72 | self.bn_atten = nn.BatchNorm2d(out_chan) 73 | self.sigmoid_atten = nn.Sigmoid() 74 | self.init_weight() 75 | 76 | def forward(self, x): 77 | feat = self.conv(x) 78 | atten = F.avg_pool2d(feat, feat.size()[2:]) 79 | atten = self.conv_atten(atten) 80 | atten = self.bn_atten(atten) 81 | atten = self.sigmoid_atten(atten) 82 | out = torch.mul(feat, atten) 83 | return out 84 | 85 | def init_weight(self): 86 | for ly in self.children(): 87 | if isinstance(ly, nn.Conv2d): 88 | nn.init.kaiming_normal_(ly.weight, a=1) 89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 90 | 91 | 92 | class ContextPath(nn.Module): 93 | def __init__(self, *args, **kwargs): 94 | super(ContextPath, self).__init__() 95 | self.resnet = Resnet18() 96 | self.arm16 = AttentionRefinementModule(256, 128) 97 | self.arm32 = AttentionRefinementModule(512, 128) 98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, 
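                                    # global-context branch: feat32 is average-pooled
                                    # to 1x1, squeezed to 128 channels by this conv,
                                    # then upsampled and added to the ARM output below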
padding=0) 101 | 102 | self.init_weight() 103 | 104 | def forward(self, x): 105 | H0, W0 = x.size()[2:] 106 | feat8, feat16, feat32 = self.resnet(x) 107 | H8, W8 = feat8.size()[2:] 108 | H16, W16 = feat16.size()[2:] 109 | H32, W32 = feat32.size()[2:] 110 | 111 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 112 | avg = self.conv_avg(avg) 113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 114 | 115 | feat32_arm = self.arm32(feat32) 116 | feat32_sum = feat32_arm + avg_up 117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 118 | feat32_up = self.conv_head32(feat32_up) 119 | 120 | feat16_arm = self.arm16(feat16) 121 | feat16_sum = feat16_arm + feat32_up 122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 123 | feat16_up = self.conv_head16(feat16_up) 124 | 125 | return feat8, feat16_up, feat32_up # x8, x8, x16 126 | 127 | def init_weight(self): 128 | for ly in self.children(): 129 | if isinstance(ly, nn.Conv2d): 130 | nn.init.kaiming_normal_(ly.weight, a=1) 131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 132 | 133 | def get_params(self): 134 | wd_params, nowd_params = [], [] 135 | for name, module in self.named_modules(): 136 | if isinstance(module, (nn.Linear, nn.Conv2d)): 137 | wd_params.append(module.weight) 138 | if not module.bias is None: 139 | nowd_params.append(module.bias) 140 | elif isinstance(module, nn.BatchNorm2d): 141 | nowd_params += list(module.parameters()) 142 | return wd_params, nowd_params 143 | 144 | 145 | ### This is not used, since I replace this with the resnet feature with the same size 146 | class SpatialPath(nn.Module): 147 | def __init__(self, *args, **kwargs): 148 | super(SpatialPath, self).__init__() 149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 153 | self.init_weight() 154 | 155 | def forward(self, x): 156 | feat = self.conv1(x) 157 | feat = self.conv2(feat) 158 | feat = self.conv3(feat) 159 | feat = self.conv_out(feat) 160 | return feat 161 | 162 | def init_weight(self): 163 | for ly in self.children(): 164 | if isinstance(ly, nn.Conv2d): 165 | nn.init.kaiming_normal_(ly.weight, a=1) 166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 167 | 168 | def get_params(self): 169 | wd_params, nowd_params = [], [] 170 | for name, module in self.named_modules(): 171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 172 | wd_params.append(module.weight) 173 | if not module.bias is None: 174 | nowd_params.append(module.bias) 175 | elif isinstance(module, nn.BatchNorm2d): 176 | nowd_params += list(module.parameters()) 177 | return wd_params, nowd_params 178 | 179 | 180 | class FeatureFusionModule(nn.Module): 181 | def __init__(self, in_chan, out_chan, *args, **kwargs): 182 | super(FeatureFusionModule, self).__init__() 183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 184 | self.conv1 = nn.Conv2d(out_chan, 185 | out_chan//4, 186 | kernel_size = 1, 187 | stride = 1, 188 | padding = 0, 189 | bias = False) 190 | self.conv2 = nn.Conv2d(out_chan//4, 191 | out_chan, 192 | kernel_size = 1, 193 | stride = 1, 194 | padding = 0, 195 | bias = False) 196 | self.relu = nn.ReLU(inplace=True) 197 | self.sigmoid = nn.Sigmoid() 198 | self.init_weight() 199 | 200 | def forward(self, fsp, fcp): 201 | fcat = torch.cat([fsp, fcp], dim=1) 202 | feat = self.convblk(fcat) 
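        # squeeze-and-excitation style channel attention: global-average-pool the fused
        # feature, bottleneck it with a 1x1 conv + ReLU, expand with another 1x1 conv,
        # squash to (0, 1) with sigmoid, then reweight the channels and add back the
        # unweighted feature as a residual (feat_out = feat * atten + feat).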
203 | atten = F.avg_pool2d(feat, feat.size()[2:]) 204 | atten = self.conv1(atten) 205 | atten = self.relu(atten) 206 | atten = self.conv2(atten) 207 | atten = self.sigmoid(atten) 208 | feat_atten = torch.mul(feat, atten) 209 | feat_out = feat_atten + feat 210 | return feat_out 211 | 212 | def init_weight(self): 213 | for ly in self.children(): 214 | if isinstance(ly, nn.Conv2d): 215 | nn.init.kaiming_normal_(ly.weight, a=1) 216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 217 | 218 | def get_params(self): 219 | wd_params, nowd_params = [], [] 220 | for name, module in self.named_modules(): 221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 222 | wd_params.append(module.weight) 223 | if not module.bias is None: 224 | nowd_params.append(module.bias) 225 | elif isinstance(module, nn.BatchNorm2d): 226 | nowd_params += list(module.parameters()) 227 | return wd_params, nowd_params 228 | 229 | 230 | class BiSeNet(nn.Module): 231 | def __init__(self, n_classes, *args, **kwargs): 232 | super(BiSeNet, self).__init__() 233 | self.cp = ContextPath() 234 | ## here self.sp is deleted 235 | self.ffm = FeatureFusionModule(256, 256) 236 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 239 | self.init_weight() 240 | 241 | def forward(self, x): 242 | H, W = x.size()[2:] 243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 245 | feat_fuse = self.ffm(feat_sp, feat_cp8) 246 | 247 | feat_out = self.conv_out(feat_fuse) 248 | feat_out16 = self.conv_out16(feat_cp8) 249 | feat_out32 = self.conv_out32(feat_cp16) 250 | 251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) 253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 254 | 255 | # return feat_out, feat_out16, feat_out32 256 | return feat_out 257 | 258 | def init_weight(self): 259 | for ly in self.children(): 260 | if isinstance(ly, nn.Conv2d): 261 | nn.init.kaiming_normal_(ly.weight, a=1) 262 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 263 | 264 | def get_params(self): 265 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 266 | for name, child in self.named_children(): 267 | child_wd_params, child_nowd_params = child.get_params() 268 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 269 | lr_mul_wd_params += child_wd_params 270 | lr_mul_nowd_params += child_nowd_params 271 | else: 272 | wd_params += child_wd_params 273 | nowd_params += child_nowd_params 274 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 275 | 276 | 277 | if __name__ == "__main__": 278 | net = BiSeNet(19) 279 | net.cuda() 280 | net.eval() 281 | in_ten = torch.randn(16, 3, 640, 480).cuda() 282 | out, out16, out32 = net(in_ten) 283 | print(out.shape) 284 | 285 | net.get_params() 286 | -------------------------------------------------------------------------------- /face_parsing/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as 
BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /face_parsing/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | import numpy as np 4 | from model 
import BiSeNet 5 | 6 | import torch 7 | 8 | import os 9 | import os.path as osp 10 | 11 | from PIL import Image 12 | import torchvision.transforms as transforms 13 | import cv2 14 | from pathlib import Path 15 | import configargparse 16 | import tqdm 17 | 18 | # import ttach as tta 19 | 20 | 21 | import numpy as np 22 | from scipy import ndimage 23 | 24 | # Example mask and label indices for demonstration 25 | 26 | def get_top_20_percent_of_neck(mask): 27 | neck_labels = [14] 28 | 29 | # Isolate the neck 30 | neck_mask = np.isin(mask, neck_labels) 31 | 32 | # Identify unique x-coordinates (columns) where the neck is present 33 | unique_x_coords = np.unique(np.where(neck_mask)[1]) 34 | 35 | # Initialize an empty mask for the top 10% of the neck across all x-coordinates 36 | top_10_percent_neck_mask = np.zeros_like(mask, dtype=bool) 37 | 38 | # Iterate over each unique x-coordinate 39 | for x in unique_x_coords: 40 | # Find y-coordinates (rows) of neck pixels at this x-coordinate 41 | y_coords = np.where(neck_mask[:, x])[0] 42 | 43 | # Calculate the number of pixels that make up the top 10% 44 | top_10_percent_count = int(np.ceil(0.4 * len(y_coords))) 45 | 46 | # If there are neck pixels at this x-coordinate 47 | if top_10_percent_count > 0: 48 | # Sort y-coordinates in ascending order (top of the image has lower y-values) 49 | sorted_y_coords = np.sort(y_coords) 50 | 51 | # Select the top 10% based on y-coordinates 52 | top_y_coords = sorted_y_coords[:top_10_percent_count] 53 | 54 | # Mark these pixels in the top 10% mask 55 | top_10_percent_neck_mask[top_y_coords, x] = True 56 | # Label connected components 57 | label_im, nb_labels = ndimage.label(top_10_percent_neck_mask) 58 | 59 | # Find the size of each component 60 | sizes = ndimage.sum(top_10_percent_neck_mask, label_im, range(nb_labels + 1)) 61 | 62 | # Exclude the background label by setting its size to 0 63 | sizes[0] = 0 64 | 65 | # Find the label of the largest component 66 | largest_component_label = np.argmax(sizes) 67 | 68 | # Create a mask of the largest component 69 | largest_region_mask = (label_im == largest_component_label) 70 | return largest_region_mask 71 | 72 | 73 | def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg', img_size=(512, 512)): 74 | im = np.array(im) 75 | vis_im = im.copy().astype(np.uint8) 76 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8) 77 | vis_parsing_anno = cv2.resize( 78 | vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) 79 | vis_parsing_anno_color = np.zeros( 80 | (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255 81 | vis_parsing_anno_color_onlyface = vis_parsing_anno_color.copy() 82 | torso_vis_parsing_anno_color_onlyface = np.zeros( 83 | (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) 84 | num_of_class = np.max(vis_parsing_anno) 85 | # ['bg', 'skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 86 | # 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat'] 87 | # print(num_of_class) 88 | 89 | for pi in range(1, 14): 90 | index = np.where(vis_parsing_anno == pi) 91 | vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) 92 | torso_vis_parsing_anno_color_onlyface[index[0], index[1], :] = np.array([255, 255, 255]) 93 | # only face 94 | # vis_parsing_anno_color_onlyface = vis_parsing_anno_color.copy() 95 | 96 | for pi in range(1, 7): 97 | index = np.where(vis_parsing_anno == pi) 98 | 
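            # paint the inner-face classes red in the face-only mask: labels 1-6
            # (skin, brows, eyes, glasses) here and 10-13 (nose, mouth, lips) in the
            # next loop; ears (7-9), neck, cloth and hair are handled separately below.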
vis_parsing_anno_color_onlyface[index[0], index[1], :] = np.array([255, 0, 0]) 99 | for pi in range(10, 14): 100 | index = np.where(vis_parsing_anno == pi) 101 | vis_parsing_anno_color_onlyface[index[0], index[1], :] = np.array([255, 0, 0]) 102 | 103 | top_20_percent_of_neck = get_top_20_percent_of_neck(vis_parsing_anno) 104 | index = np.where(top_20_percent_of_neck == True) 105 | vis_parsing_anno_color_onlyface[index[0], index[1], :] = np.array([255, 0, 0]) 106 | 107 | for pi in range(14, 16): 108 | index = np.where(vis_parsing_anno == pi) 109 | vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0]) 110 | for pi in range(16, 17): 111 | index = np.where(vis_parsing_anno == pi) 112 | vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255]) 113 | for pi in range(17, num_of_class+1): 114 | index = np.where(vis_parsing_anno == pi) 115 | vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) 116 | torso_vis_parsing_anno_color_onlyface[index[0], index[1], :] = np.array([255, 255, 255]) 117 | 118 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) 119 | vis_parsing_anno_color_onlyface = vis_parsing_anno_color_onlyface.astype(np.uint8) 120 | torso_vis_parsing_anno_color_onlyface = torso_vis_parsing_anno_color_onlyface.astype(np.uint8) 121 | index = np.where(vis_parsing_anno == num_of_class-1) 122 | vis_im = cv2.resize(vis_parsing_anno_color, img_size) 123 | torso_vis_parsing_anno_color_onlyface = cv2.resize(torso_vis_parsing_anno_color_onlyface, img_size) 124 | vis_im_onlyface = cv2.resize(vis_parsing_anno_color_onlyface, img_size) 125 | if save_im: 126 | save_face_mask_torso = save_path.replace('parsing', 'face_mask') 127 | print("save_face_mask_torso: ", save_face_mask_torso) 128 | cv2.imwrite(save_face_mask_torso, torso_vis_parsing_anno_color_onlyface) 129 | save_path_face = save_path.replace('.png', '_face.png') 130 | cv2.imwrite(save_path, vis_im) 131 | blurred_face = cv2.GaussianBlur(vis_im_onlyface, (99, 99), 2) 132 | cv2.imwrite(save_path_face, blurred_face) 133 | 134 | 135 | def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): 136 | 137 | Path(respth).mkdir(parents=True, exist_ok=True) 138 | print(f'[INFO] output path: {respth} from {dspth}') 139 | print(f'[INFO] loading model...') 140 | n_classes = 19 141 | net = BiSeNet(n_classes=n_classes) 142 | net.cuda() 143 | net.load_state_dict(torch.load(cp)) 144 | net.eval() 145 | 146 | to_tensor = transforms.Compose([ 147 | transforms.ToTensor(), 148 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 149 | ]) 150 | 151 | image_paths = os.listdir(dspth) 152 | print("image_paths: ", image_paths) 153 | with torch.no_grad(): 154 | for image_path in tqdm.tqdm(image_paths): 155 | if image_path.endswith('.jpg') or image_path.endswith('.png'): 156 | img = Image.open(osp.join(dspth, image_path)) 157 | ori_size = img.size 158 | image = img.resize((512, 512), Image.BILINEAR) 159 | image = image.convert("RGB") 160 | img = to_tensor(image) 161 | 162 | # test-time augmentation. 
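                    # (the ttach import above is commented out, so no augmentation is
                    # actually applied: each image is resized to 512x512, run through
                    # the network once, and outputs.mean(0) over the singleton batch
                    # is a no-op before the per-pixel argmax)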
163 | inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512] 164 | outputs = net(inputs.cuda()) 165 | parsing = outputs.mean(0).cpu().numpy().argmax(0) 166 | 167 | image_path = int(image_path[:-4]) 168 | image_path = str(image_path) + '.png' 169 | 170 | vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size) 171 | 172 | 173 | if __name__ == "__main__": 174 | parser = configargparse.ArgumentParser() 175 | parser.add_argument('--respath', type=str, default='./result/', help='result path for label') 176 | parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images') 177 | parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth') 178 | parser.add_argument('--resolution', type=int, default=512, help='resolution of input image') 179 | args = parser.parse_args() 180 | evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath) 181 | -------------------------------------------------------------------------------- /face_tracking/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_tracking/.DS_Store -------------------------------------------------------------------------------- /face_tracking/3DMM/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_tracking/3DMM/.DS_Store -------------------------------------------------------------------------------- /face_tracking/3DMM/exp_info.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_tracking/3DMM/exp_info.npy -------------------------------------------------------------------------------- /face_tracking/3DMM/keys_info.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_tracking/3DMM/keys_info.npy -------------------------------------------------------------------------------- /face_tracking/3DMM/topology_info.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_tracking/3DMM/topology_info.npy -------------------------------------------------------------------------------- /face_tracking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_tracking/__init__.py -------------------------------------------------------------------------------- /face_tracking/__pycache__/data_loader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_tracking/__pycache__/data_loader.cpython-310.pyc -------------------------------------------------------------------------------- /face_tracking/__pycache__/facemodel.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_tracking/__pycache__/facemodel.cpython-310.pyc -------------------------------------------------------------------------------- /face_tracking/__pycache__/render_3dmm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_tracking/__pycache__/render_3dmm.cpython-310.pyc -------------------------------------------------------------------------------- /face_tracking/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christopherohit/nerf_data_preprocessing/b30b51dd915df31752ba5006fa807fde4c47f30d/face_tracking/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /face_tracking/convert_BFM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io import loadmat 3 | 4 | original_BFM = loadmat("3DMM/01_MorphableModel.mat") 5 | sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"] 6 | 7 | shapePC = original_BFM["shapePC"] 8 | shapeEV = original_BFM["shapeEV"] 9 | shapeMU = original_BFM["shapeMU"] 10 | texPC = original_BFM["texPC"] 11 | texEV = original_BFM["texEV"] 12 | texMU = original_BFM["texMU"] 13 | 14 | b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) 15 | mu_shape = shapeMU.reshape(-1, 3) 16 | 17 | b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) 18 | mu_tex = texMU.reshape(-1, 3) 19 | 20 | b_shape = b_shape[:, sub_inds, :].reshape(199, -1) 21 | mu_shape = mu_shape[sub_inds, :].reshape(-1) 22 | b_tex = b_tex[:, sub_inds, :].reshape(199, -1) 23 | mu_tex = mu_tex[sub_inds, :].reshape(-1) 24 | 25 | exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item() 26 | np.save( 27 | "3DMM/3DMM_info.npy", 28 | { 29 | "mu_shape": mu_shape, 30 | "b_shape": b_shape, 31 | "sig_shape": shapeEV.reshape(-1), 32 | "mu_exp": exp_info["mu_exp"], 33 | "b_exp": exp_info["base_exp"], 34 | "sig_exp": exp_info["sig_exp"], 35 | "mu_tex": mu_tex, 36 | "b_tex": b_tex, 37 | "sig_tex": texEV.reshape(-1), 38 | }, 39 | ) 40 | -------------------------------------------------------------------------------- /face_tracking/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def load_dir(path, start, end): 7 | lmss = [] 8 | imgs_paths = [] 9 | for i in range(start, end): 10 | if os.path.isfile(os.path.join(path, str(i) + ".lms")): 11 | lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32) 12 | lmss.append(lms) 13 | imgs_paths.append(os.path.join(path, str(i) + ".jpg")) 14 | lmss = np.stack(lmss) 15 | lmss = torch.as_tensor(lmss).cuda() 16 | return lmss, imgs_paths 17 | -------------------------------------------------------------------------------- /face_tracking/face_tracker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import argparse 5 | from pathlib import Path 6 | import torch 7 | import numpy as np 8 | from data_loader import load_dir 9 | from facemodel 
import Face_3DMM 10 | from util import * 11 | from render_3dmm import Render_3DMM 12 | 13 | 14 | # torch.autograd.set_detect_anomaly(True) 15 | 16 | dir_path = os.path.dirname(os.path.realpath(__file__)) 17 | 18 | 19 | def set_requires_grad(tensor_list): 20 | for tensor in tensor_list: 21 | tensor.requires_grad = True 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--path", type=str, default="obama/ori_imgs", help="idname of target person" 27 | ) 28 | parser.add_argument("--img_h", type=int, default=512, help="image height") 29 | parser.add_argument("--img_w", type=int, default=512, help="image width") 30 | parser.add_argument("--frame_num", type=int, default=11000, help="image number") 31 | args = parser.parse_args() 32 | 33 | start_id = 0 34 | end_id = args.frame_num 35 | 36 | lms, img_paths = load_dir(args.path, start_id, end_id) 37 | print("lms: ", lms.shape) 38 | num_frames = lms.shape[0] 39 | h, w = args.img_h, args.img_w 40 | cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda() 41 | id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650 42 | model_3dmm = Face_3DMM( 43 | os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num 44 | ) 45 | 46 | # only use one image per 40 to do fit the focal length 47 | sel_ids = np.arange(0, num_frames, 40) 48 | sel_num = sel_ids.shape[0] 49 | arg_focal = 1600 50 | arg_landis = 1e5 51 | 52 | print(f'[INFO] fitting focal length...') 53 | 54 | # fit the focal length 55 | for focal in range(600, 1500, 100): 56 | id_para = lms.new_zeros((1, id_dim), requires_grad=True) 57 | exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True) 58 | euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True) 59 | trans = lms.new_zeros((sel_num, 3), requires_grad=True) 60 | trans.data[:, 2] -= 7 61 | focal_length = lms.new_zeros(1, requires_grad=False) 62 | focal_length.data += focal 63 | set_requires_grad([id_para, exp_para, euler_angle, trans]) 64 | 65 | optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1) 66 | optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1) 67 | 68 | for iter in range(2000): 69 | id_para_batch = id_para.expand(sel_num, -1) 70 | geometry = model_3dmm.get_3dlandmarks( 71 | id_para_batch, exp_para, euler_angle, trans, focal_length, cxy 72 | ) 73 | proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) 74 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach()) 75 | loss = loss_lan 76 | optimizer_frame.zero_grad() 77 | loss.backward() 78 | optimizer_frame.step() 79 | # if iter % 100 == 0: 80 | # print(focal, 'pose', iter, loss.item()) 81 | 82 | for iter in range(2500): 83 | id_para_batch = id_para.expand(sel_num, -1) 84 | geometry = model_3dmm.get_3dlandmarks( 85 | id_para_batch, exp_para, euler_angle, trans, focal_length, cxy 86 | ) 87 | print("3dm landmark: ", geometry.shape) 88 | 89 | 90 | proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) 91 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach()) 92 | loss_regid = torch.mean(id_para * id_para) 93 | loss_regexp = torch.mean(exp_para * exp_para) 94 | loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4 95 | optimizer_idexp.zero_grad() 96 | optimizer_frame.zero_grad() 97 | loss.backward() 98 | optimizer_idexp.step() 99 | optimizer_frame.step() 100 | # if iter % 100 == 0: 101 | # print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item()) 102 | 103 | if iter % 1500 == 0 and iter >= 1500: 104 | for 
param_group in optimizer_idexp.param_groups: 105 | param_group["lr"] *= 0.2 106 | for param_group in optimizer_frame.param_groups: 107 | param_group["lr"] *= 0.2 108 | 109 | print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item()) 110 | 111 | if loss_lan.item() < arg_landis: 112 | arg_landis = loss_lan.item() 113 | arg_focal = focal 114 | 115 | print("[INFO] find best focal:", arg_focal) 116 | 117 | print(f'[INFO] coarse fitting...') 118 | 119 | # for all frames, do a coarse fitting ??? 120 | id_para = lms.new_zeros((1, id_dim), requires_grad=True) 121 | exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True) 122 | tex_para = lms.new_zeros( 123 | (1, tex_dim), requires_grad=True 124 | ) # not optimized in this block ??? 125 | euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True) 126 | trans = lms.new_zeros((num_frames, 3), requires_grad=True) 127 | light_para = lms.new_zeros((num_frames, 27), requires_grad=True) 128 | trans.data[:, 2] -= 7 # ??? 129 | focal_length = lms.new_zeros(1, requires_grad=True) 130 | focal_length.data += arg_focal 131 | 132 | set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para]) 133 | 134 | optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1) 135 | optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1) 136 | 137 | for iter in range(1500): 138 | id_para_batch = id_para.expand(num_frames, -1) 139 | geometry = model_3dmm.get_3dlandmarks( 140 | id_para_batch, exp_para, euler_angle, trans, focal_length, cxy 141 | ) 142 | proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) 143 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach()) 144 | loss = loss_lan 145 | optimizer_frame.zero_grad() 146 | loss.backward() 147 | optimizer_frame.step() 148 | if iter == 1000: 149 | for param_group in optimizer_frame.param_groups: 150 | param_group["lr"] = 0.1 151 | # if iter % 100 == 0: 152 | # print('pose', iter, loss.item()) 153 | 154 | for param_group in optimizer_frame.param_groups: 155 | param_group["lr"] = 0.1 156 | 157 | for iter in range(2000): 158 | id_para_batch = id_para.expand(num_frames, -1) 159 | geometry = model_3dmm.get_3dlandmarks( 160 | id_para_batch, exp_para, euler_angle, trans, focal_length, cxy 161 | ) 162 | proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) 163 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach()) 164 | loss_regid = torch.mean(id_para * id_para) 165 | loss_regexp = torch.mean(exp_para * exp_para) 166 | loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4 167 | optimizer_idexp.zero_grad() 168 | optimizer_frame.zero_grad() 169 | loss.backward() 170 | optimizer_idexp.step() 171 | optimizer_frame.step() 172 | # if iter % 100 == 0: 173 | # print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item()) 174 | if iter % 1000 == 0 and iter >= 1000: 175 | for param_group in optimizer_idexp.param_groups: 176 | param_group["lr"] *= 0.2 177 | for param_group in optimizer_frame.param_groups: 178 | param_group["lr"] *= 0.2 179 | 180 | print(loss_lan.item(), torch.mean(trans[:, 2]).item()) 181 | 182 | print(f'[INFO] fitting light...') 183 | 184 | batch_size = 32 185 | 186 | device_default = torch.device("cuda:0") 187 | device_render = torch.device("cuda:0") 188 | renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render) 189 | 190 | sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size] 191 | imgs = [] 192 | for sel_id in sel_ids: 193 | 
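    # cv2.imread returns BGR; the [:, :, ::-1] slice flips the channels to RGB so the
    # photos are compared against the renderer's RGB output in the same colour space.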
imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) 194 | imgs = np.stack(imgs) 195 | sel_imgs = torch.as_tensor(imgs).cuda() 196 | sel_lms = lms[sel_ids] 197 | sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) 198 | set_requires_grad([sel_light]) 199 | 200 | optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1) 201 | optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01) 202 | 203 | for iter in range(71): 204 | sel_exp_para, sel_euler, sel_trans = ( 205 | exp_para[sel_ids], 206 | euler_angle[sel_ids], 207 | trans[sel_ids], 208 | ) 209 | sel_id_para = id_para.expand(batch_size, -1) 210 | geometry = model_3dmm.get_3dlandmarks( 211 | sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy 212 | ) 213 | proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy) 214 | 215 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) 216 | loss_regid = torch.mean(id_para * id_para) 217 | loss_regexp = torch.mean(sel_exp_para * sel_exp_para) 218 | 219 | sel_tex_para = tex_para.expand(batch_size, -1) 220 | sel_texture = model_3dmm.forward_tex(sel_tex_para) 221 | geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) 222 | rott_geo = forward_rott(geometry, sel_euler, sel_trans) 223 | render_imgs = renderer( 224 | rott_geo.to(device_render), 225 | sel_texture.to(device_render), 226 | sel_light.to(device_render), 227 | ) 228 | render_imgs = render_imgs.to(device_default) 229 | 230 | mask = (render_imgs[:, :, :, 3]).detach() > 0.0 231 | render_proj = sel_imgs.clone() 232 | render_proj[mask] = render_imgs[mask][..., :3].byte() 233 | loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) 234 | 235 | if iter > 50: 236 | loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8 237 | else: 238 | loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0 239 | 240 | optimizer_tl.zero_grad() 241 | optimizer_id_frame.zero_grad() 242 | loss.backward() 243 | 244 | optimizer_tl.step() 245 | optimizer_id_frame.step() 246 | 247 | if iter % 50 == 0 and iter > 0: 248 | for param_group in optimizer_id_frame.param_groups: 249 | param_group["lr"] *= 0.2 250 | for param_group in optimizer_tl.param_groups: 251 | param_group["lr"] *= 0.2 252 | # print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item()) 253 | 254 | 255 | light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1) 256 | light_para.data = light_mean 257 | 258 | exp_para = exp_para.detach() 259 | euler_angle = euler_angle.detach() 260 | trans = trans.detach() 261 | light_para = light_para.detach() 262 | 263 | print(f'[INFO] fine frame-wise fitting...') 264 | 265 | for i in range(int((num_frames - 1) / batch_size + 1)): 266 | 267 | if (i + 1) * batch_size > num_frames: 268 | start_n = num_frames - batch_size 269 | sel_ids = np.arange(num_frames - batch_size, num_frames) 270 | else: 271 | start_n = i * batch_size 272 | sel_ids = np.arange(i * batch_size, i * batch_size + batch_size) 273 | 274 | imgs = [] 275 | for sel_id in sel_ids: 276 | imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) 277 | imgs = np.stack(imgs) 278 | sel_imgs = torch.as_tensor(imgs).cuda() 279 | sel_lms = lms[sel_ids] 280 | 281 | sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True) 282 | sel_exp_para.data = exp_para[sel_ids].clone() 283 | sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True) 284 | sel_euler.data = euler_angle[sel_ids].clone() 285 | 
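    # each window re-creates fresh leaf tensors (sel_exp_para, sel_euler, sel_trans,
    # sel_light) initialised from the coarse estimates, so Adam only touches the
    # current batch of frames; the refined values are copied back into the global
    # exp_para / euler_angle / trans / light_para arrays after the 50 inner iterations.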
sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True) 286 | sel_trans.data = trans[sel_ids].clone() 287 | sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) 288 | sel_light.data = light_para[sel_ids].clone() 289 | 290 | set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light]) 291 | 292 | optimizer_cur_batch = torch.optim.Adam( 293 | [sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005 294 | ) 295 | 296 | sel_id_para = id_para.expand(batch_size, -1).detach() 297 | sel_tex_para = tex_para.expand(batch_size, -1).detach() 298 | 299 | pre_num = 5 300 | 301 | if i > 0: 302 | pre_ids = np.arange(start_n - pre_num, start_n) 303 | 304 | for iter in range(50): 305 | 306 | geometry = model_3dmm.get_3dlandmarks( 307 | sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy 308 | ) 309 | proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy) 310 | loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) 311 | loss_regexp = torch.mean(sel_exp_para * sel_exp_para) 312 | 313 | sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) 314 | sel_texture = model_3dmm.forward_tex(sel_tex_para) 315 | geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) 316 | rott_geo = forward_rott(geometry, sel_euler, sel_trans) 317 | render_imgs = renderer( 318 | rott_geo.to(device_render), 319 | sel_texture.to(device_render), 320 | sel_light.to(device_render), 321 | ) 322 | render_imgs = render_imgs.to(device_default) 323 | 324 | mask = (render_imgs[:, :, :, 3]).detach() > 0.0 325 | 326 | loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) 327 | 328 | if i > 0: 329 | geometry_lap = model_3dmm.forward_geo_sub( 330 | id_para.expand(batch_size + pre_num, -1).detach(), 331 | torch.cat((exp_para[pre_ids].detach(), sel_exp_para)), 332 | model_3dmm.rigid_ids, 333 | ) 334 | rott_geo_lap = forward_rott( 335 | geometry_lap, 336 | torch.cat((euler_angle[pre_ids].detach(), sel_euler)), 337 | torch.cat((trans[pre_ids].detach(), sel_trans)), 338 | ) 339 | loss_lap = cal_lap_loss( 340 | [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0] 341 | ) 342 | else: 343 | geometry_lap = model_3dmm.forward_geo_sub( 344 | id_para.expand(batch_size, -1).detach(), 345 | sel_exp_para, 346 | model_3dmm.rigid_ids, 347 | ) 348 | rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans) 349 | loss_lap = cal_lap_loss( 350 | [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0] 351 | ) 352 | 353 | 354 | if iter > 30: 355 | loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0 356 | else: 357 | loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0 358 | 359 | optimizer_cur_batch.zero_grad() 360 | loss.backward() 361 | optimizer_cur_batch.step() 362 | 363 | # if iter % 10 == 0: 364 | # print( 365 | # i, 366 | # iter, 367 | # loss_col.item(), 368 | # loss_lan.item(), 369 | # loss_lap.item(), 370 | # loss_regexp.item(), 371 | # ) 372 | 373 | print(str(i) + " of " + str(int((num_frames - 1) / batch_size + 1)) + " done") 374 | 375 | render_proj = sel_imgs.clone() 376 | render_proj[mask] = render_imgs[mask][..., :3].byte() 377 | 378 | exp_para[sel_ids] = sel_exp_para.clone() 379 | euler_angle[sel_ids] = sel_euler.clone() 380 | trans[sel_ids] = sel_trans.clone() 381 | light_para[sel_ids] = sel_light.clone() 382 | 383 | torch.save( 384 | { 385 | "id": id_para.detach().cpu(), 386 | "exp": exp_para.detach().cpu(), 387 | "euler": euler_angle.detach().cpu(), 
388 | "trans": trans.detach().cpu(), 389 | "focal": focal_length.detach().cpu(), 390 | }, 391 | os.path.join(os.path.dirname(args.path), "track_params.pt"), 392 | ) 393 | 394 | print("params saved") 395 | -------------------------------------------------------------------------------- /face_tracking/facemodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | from util import * 6 | 7 | 8 | class Face_3DMM(nn.Module): 9 | def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num): 10 | super(Face_3DMM, self).__init__() 11 | # id_dim = 100 12 | # exp_dim = 79 13 | # tex_dim = 100 14 | self.point_num = point_num 15 | DMM_info = np.load( 16 | os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True 17 | ).item() 18 | base_id = DMM_info["b_shape"][:id_dim, :] 19 | mu_id = DMM_info["mu_shape"] 20 | base_exp = DMM_info["b_exp"][:exp_dim, :] 21 | mu_exp = DMM_info["mu_exp"] 22 | mu = mu_id + mu_exp 23 | mu = mu.reshape(-1, 3) 24 | for i in range(3): 25 | mu[:, i] -= np.mean(mu[:, i]) 26 | mu = mu.reshape(-1) 27 | self.base_id = torch.as_tensor(base_id).cuda() / 100000.0 28 | self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0 29 | self.mu = torch.as_tensor(mu).cuda() / 100000.0 30 | base_tex = DMM_info["b_tex"][:tex_dim, :] 31 | mu_tex = DMM_info["mu_tex"] 32 | self.base_tex = torch.as_tensor(base_tex).cuda() 33 | self.mu_tex = torch.as_tensor(mu_tex).cuda() 34 | sig_id = DMM_info["sig_shape"][:id_dim] 35 | sig_tex = DMM_info["sig_tex"][:tex_dim] 36 | sig_exp = DMM_info["sig_exp"][:exp_dim] 37 | self.sig_id = torch.as_tensor(sig_id).cuda() 38 | self.sig_tex = torch.as_tensor(sig_tex).cuda() 39 | self.sig_exp = torch.as_tensor(sig_exp).cuda() 40 | 41 | keys_info = np.load( 42 | os.path.join(modelpath, "keys_info.npy"), allow_pickle=True 43 | ).item() 44 | self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda() 45 | self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda() 46 | self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda() 47 | self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda() 48 | 49 | def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy): 50 | id_para = id_para * self.sig_id 51 | exp_para = exp_para * self.sig_exp 52 | batch_size = id_para.shape[0] 53 | num_per_contour = self.left_contours.shape[1] 54 | left_contours_flat = self.left_contours.reshape(-1) 55 | right_contours_flat = self.right_contours.reshape(-1) 56 | sel_index = torch.cat( 57 | ( 58 | 3 * left_contours_flat.unsqueeze(1), 59 | 3 * left_contours_flat.unsqueeze(1) + 1, 60 | 3 * left_contours_flat.unsqueeze(1) + 2, 61 | ), 62 | dim=1, 63 | ).reshape(-1) 64 | left_geometry = ( 65 | torch.mm(id_para, self.base_id[:, sel_index]) 66 | + torch.mm(exp_para, self.base_exp[:, sel_index]) 67 | + self.mu[sel_index] 68 | ) 69 | left_geometry = left_geometry.view(batch_size, -1, 3) 70 | proj_x = forward_transform( 71 | left_geometry, euler_angle, trans, focal_length, cxy 72 | )[:, :, 0] 73 | proj_x = proj_x.reshape(batch_size, 8, num_per_contour) 74 | arg_min = proj_x.argmin(dim=2) 75 | left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3) 76 | left_3dlands = left_geometry[ 77 | torch.arange(batch_size * 8), arg_min.view(-1), : 78 | ].view(batch_size, 8, 3) 79 | 80 | sel_index = torch.cat( 81 | ( 82 | 3 * right_contours_flat.unsqueeze(1), 83 | 3 * right_contours_flat.unsqueeze(1) + 1, 84 | 3 * 
right_contours_flat.unsqueeze(1) + 2, 85 | ), 86 | dim=1, 87 | ).reshape(-1) 88 | right_geometry = ( 89 | torch.mm(id_para, self.base_id[:, sel_index]) 90 | + torch.mm(exp_para, self.base_exp[:, sel_index]) 91 | + self.mu[sel_index] 92 | ) 93 | right_geometry = right_geometry.view(batch_size, -1, 3) 94 | proj_x = forward_transform( 95 | right_geometry, euler_angle, trans, focal_length, cxy 96 | )[:, :, 0] 97 | proj_x = proj_x.reshape(batch_size, 8, num_per_contour) 98 | arg_max = proj_x.argmax(dim=2) 99 | right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3) 100 | right_3dlands = right_geometry[ 101 | torch.arange(batch_size * 8), arg_max.view(-1), : 102 | ].view(batch_size, 8, 3) 103 | 104 | sel_index = torch.cat( 105 | ( 106 | 3 * self.keyinds.unsqueeze(1), 107 | 3 * self.keyinds.unsqueeze(1) + 1, 108 | 3 * self.keyinds.unsqueeze(1) + 2, 109 | ), 110 | dim=1, 111 | ).reshape(-1) 112 | geometry = ( 113 | torch.mm(id_para, self.base_id[:, sel_index]) 114 | + torch.mm(exp_para, self.base_exp[:, sel_index]) 115 | + self.mu[sel_index] 116 | ) 117 | lands_3d = geometry.view(-1, self.keyinds.shape[0], 3) 118 | lands_3d[:, :8, :] = left_3dlands 119 | lands_3d[:, 9:17, :] = right_3dlands 120 | return lands_3d 121 | 122 | def forward_geo_sub(self, id_para, exp_para, sub_index): 123 | id_para = id_para * self.sig_id 124 | exp_para = exp_para * self.sig_exp 125 | sel_index = torch.cat( 126 | ( 127 | 3 * sub_index.unsqueeze(1), 128 | 3 * sub_index.unsqueeze(1) + 1, 129 | 3 * sub_index.unsqueeze(1) + 2, 130 | ), 131 | dim=1, 132 | ).reshape(-1) 133 | geometry = ( 134 | torch.mm(id_para, self.base_id[:, sel_index]) 135 | + torch.mm(exp_para, self.base_exp[:, sel_index]) 136 | + self.mu[sel_index] 137 | ) 138 | return geometry.reshape(-1, sub_index.shape[0], 3) 139 | 140 | def forward_geo(self, id_para, exp_para): 141 | id_para = id_para * self.sig_id 142 | exp_para = exp_para * self.sig_exp 143 | geometry = ( 144 | torch.mm(id_para, self.base_id) 145 | + torch.mm(exp_para, self.base_exp) 146 | + self.mu 147 | ) 148 | return geometry.reshape(-1, self.point_num, 3) 149 | 150 | def forward_tex(self, tex_para): 151 | tex_para = tex_para * self.sig_tex 152 | texture = torch.mm(tex_para, self.base_tex) + self.mu_tex 153 | return texture.reshape(-1, self.point_num, 3) 154 | -------------------------------------------------------------------------------- /face_tracking/geo_transform.py: -------------------------------------------------------------------------------- 1 | """This module contains functions for geometry transform and camera projection""" 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | def euler2rot(euler_angle): 8 | batch_size = euler_angle.shape[0] 9 | theta = euler_angle[:, 0].reshape(-1, 1, 1) 10 | phi = euler_angle[:, 1].reshape(-1, 1, 1) 11 | psi = euler_angle[:, 2].reshape(-1, 1, 1) 12 | one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) 13 | zero = torch.zeros( 14 | (batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device 15 | ) 16 | rot_x = torch.cat( 17 | ( 18 | torch.cat((one, zero, zero), 1), 19 | torch.cat((zero, theta.cos(), theta.sin()), 1), 20 | torch.cat((zero, -theta.sin(), theta.cos()), 1), 21 | ), 22 | 2, 23 | ) 24 | rot_y = torch.cat( 25 | ( 26 | torch.cat((phi.cos(), zero, -phi.sin()), 1), 27 | torch.cat((zero, one, zero), 1), 28 | torch.cat((phi.sin(), zero, phi.cos()), 1), 29 | ), 30 | 2, 31 | ) 32 | rot_z = torch.cat( 33 | ( 34 | torch.cat((psi.cos(), -psi.sin(), zero), 1), 35 | 
torch.cat((psi.sin(), psi.cos(), zero), 1), 36 | torch.cat((zero, zero, one), 1), 37 | ), 38 | 2, 39 | ) 40 | return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) 41 | 42 | 43 | def rot_trans_geo(geometry, rot, trans): 44 | rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1) 45 | return rott_geo.permute(0, 2, 1) 46 | 47 | 48 | def euler_trans_geo(geometry, euler, trans): 49 | rot = euler2rot(euler) 50 | return rot_trans_geo(geometry, rot, trans) 51 | 52 | 53 | def proj_geo(rott_geo, camera_para): 54 | fx = camera_para[:, 0] 55 | fy = camera_para[:, 0] 56 | cx = camera_para[:, 1] 57 | cy = camera_para[:, 2] 58 | 59 | X = rott_geo[:, :, 0] 60 | Y = rott_geo[:, :, 1] 61 | Z = rott_geo[:, :, 2] 62 | 63 | fxX = fx[:, None] * X 64 | fyY = fy[:, None] * Y 65 | 66 | proj_x = -fxX / Z + cx[:, None] 67 | proj_y = fyY / Z + cy[:, None] 68 | 69 | return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) 70 | -------------------------------------------------------------------------------- /face_tracking/render_3dmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | from pytorch3d.structures import Meshes 6 | from pytorch3d.renderer import ( 7 | look_at_view_transform, 8 | PerspectiveCameras, 9 | FoVPerspectiveCameras, 10 | PointLights, 11 | DirectionalLights, 12 | Materials, 13 | RasterizationSettings, 14 | MeshRenderer, 15 | MeshRasterizer, 16 | SoftPhongShader, 17 | TexturesUV, 18 | TexturesVertex, 19 | blending, 20 | ) 21 | 22 | from pytorch3d.ops import interpolate_face_attributes 23 | 24 | from pytorch3d.renderer.blending import ( 25 | BlendParams, 26 | hard_rgb_blend, 27 | sigmoid_alpha_blend, 28 | softmax_rgb_blend, 29 | ) 30 | 31 | 32 | class SoftSimpleShader(nn.Module): 33 | """ 34 | Per pixel lighting - the lighting model is applied using the interpolated 35 | coordinates and normals for each pixel. The blending function returns the 36 | soft aggregated color using all the faces per pixel. 37 | 38 | To use the default values, simply initialize the shader with the desired 39 | device e.g. 
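        shader = SoftSimpleShader(device=torch.device("cuda:0"))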
40 | 41 | """ 42 | 43 | def __init__( 44 | self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None 45 | ): 46 | super().__init__() 47 | self.lights = lights if lights is not None else PointLights(device=device) 48 | self.materials = ( 49 | materials if materials is not None else Materials(device=device) 50 | ) 51 | self.cameras = cameras 52 | self.blend_params = blend_params if blend_params is not None else BlendParams() 53 | 54 | def to(self, device): 55 | # Manually move to device modules which are not subclasses of nn.Module 56 | self.cameras = self.cameras.to(device) 57 | self.materials = self.materials.to(device) 58 | self.lights = self.lights.to(device) 59 | return self 60 | 61 | def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: 62 | 63 | texels = meshes.sample_textures(fragments) 64 | blend_params = kwargs.get("blend_params", self.blend_params) 65 | 66 | cameras = kwargs.get("cameras", self.cameras) 67 | if cameras is None: 68 | msg = "Cameras must be specified either at initialization \ 69 | or in the forward pass of SoftPhongShader" 70 | raise ValueError(msg) 71 | znear = kwargs.get("znear", getattr(cameras, "znear", 1.0)) 72 | zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) 73 | images = softmax_rgb_blend( 74 | texels, fragments, blend_params, znear=znear, zfar=zfar 75 | ) 76 | return images 77 | 78 | 79 | class Render_3DMM(nn.Module): 80 | def __init__( 81 | self, 82 | focal=1015, 83 | img_h=500, 84 | img_w=500, 85 | batch_size=1, 86 | device=torch.device("cuda:0"), 87 | ): 88 | super(Render_3DMM, self).__init__() 89 | 90 | self.focal = focal 91 | self.img_h = img_h 92 | self.img_w = img_w 93 | self.device = device 94 | self.renderer = self.get_render(batch_size) 95 | 96 | dir_path = os.path.dirname(os.path.realpath(__file__)) 97 | topo_info = np.load( 98 | os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True 99 | ).item() 100 | self.tris = torch.as_tensor(topo_info["tris"]).to(self.device) 101 | self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device) 102 | 103 | def compute_normal(self, geometry): 104 | vert_1 = torch.index_select(geometry, 1, self.tris[:, 0]) 105 | vert_2 = torch.index_select(geometry, 1, self.tris[:, 1]) 106 | vert_3 = torch.index_select(geometry, 1, self.tris[:, 2]) 107 | nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) 108 | tri_normal = nn.functional.normalize(nnorm, dim=2) 109 | v_norm = tri_normal[:, self.vert_tris, :].sum(2) 110 | vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2) 111 | return vert_normal 112 | 113 | def get_render(self, batch_size=1): 114 | half_s = self.img_w * 0.5 115 | R, T = look_at_view_transform(10, 0, 0) 116 | R = R.repeat(batch_size, 1, 1) 117 | T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device) 118 | 119 | cameras = FoVPerspectiveCameras( 120 | device=self.device, 121 | R=R, 122 | T=T, 123 | znear=0.01, 124 | zfar=20, 125 | fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi, 126 | ) 127 | lights = PointLights( 128 | device=self.device, 129 | location=[[0.0, 0.0, 1e5]], 130 | ambient_color=[[1, 1, 1]], 131 | specular_color=[[0.0, 0.0, 0.0]], 132 | diffuse_color=[[0.0, 0.0, 0.0]], 133 | ) 134 | sigma = 1e-4 135 | raster_settings = RasterizationSettings( 136 | image_size=(self.img_h, self.img_w), 137 | blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0, 138 | faces_per_pixel=2, 139 | perspective_correct=False, 140 | ) 141 | blend_params = blending.BlendParams(background_color=[0, 0, 0]) 142 | renderer 
= MeshRenderer( 143 | rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras), 144 | shader=SoftSimpleShader( 145 | lights=lights, blend_params=blend_params, cameras=cameras 146 | ), 147 | ) 148 | return renderer.to(self.device) 149 | 150 | @staticmethod 151 | def Illumination_layer(face_texture, norm, gamma): 152 | 153 | n_b, num_vertex, _ = face_texture.size() 154 | n_v_full = n_b * num_vertex 155 | gamma = gamma.view(-1, 3, 9).clone() 156 | gamma[:, :, 0] += 0.8 157 | 158 | gamma = gamma.permute(0, 2, 1) 159 | 160 | a0 = np.pi 161 | a1 = 2 * np.pi / np.sqrt(3.0) 162 | a2 = 2 * np.pi / np.sqrt(8.0) 163 | c0 = 1 / np.sqrt(4 * np.pi) 164 | c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi) 165 | c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi) 166 | d0 = 0.5 / np.sqrt(3.0) 167 | 168 | Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0 169 | norm = norm.view(-1, 3) 170 | nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2] 171 | arrH = [] 172 | 173 | arrH.append(Y0) 174 | arrH.append(-a1 * c1 * ny) 175 | arrH.append(a1 * c1 * nz) 176 | arrH.append(-a1 * c1 * nx) 177 | arrH.append(a2 * c2 * nx * ny) 178 | arrH.append(-a2 * c2 * ny * nz) 179 | arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1)) 180 | arrH.append(-a2 * c2 * nx * nz) 181 | arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2))) 182 | 183 | H = torch.stack(arrH, 1) 184 | Y = H.view(n_b, num_vertex, 9) 185 | lighting = Y.bmm(gamma) 186 | 187 | face_color = face_texture * lighting 188 | return face_color 189 | 190 | def forward(self, rott_geometry, texture, diffuse_sh): 191 | face_normal = self.compute_normal(rott_geometry) 192 | face_color = self.Illumination_layer(texture, face_normal, diffuse_sh) 193 | face_color = TexturesVertex(face_color) 194 | mesh = Meshes( 195 | rott_geometry, 196 | self.tris.float().repeat(rott_geometry.shape[0], 1, 1), 197 | face_color, 198 | ) 199 | rendered_img = self.renderer(mesh) 200 | rendered_img = torch.clamp(rendered_img, 0, 255) 201 | 202 | return rendered_img 203 | -------------------------------------------------------------------------------- /face_tracking/render_land.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import render_util 4 | import geo_transform 5 | import numpy as np 6 | 7 | 8 | def compute_tri_normal(geometry, tris): 9 | geometry = geometry.permute(0, 2, 1) 10 | tri_1 = tris[:, 0] 11 | tri_2 = tris[:, 1] 12 | tri_3 = tris[:, 2] 13 | 14 | vert_1 = torch.index_select(geometry, 2, tri_1) 15 | vert_2 = torch.index_select(geometry, 2, tri_2) 16 | vert_3 = torch.index_select(geometry, 2, tri_3) 17 | 18 | nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1) 19 | normal = nn.functional.normalize(nnorm).permute(0, 2, 1) 20 | return normal 21 | 22 | 23 | class Compute_normal_base(torch.autograd.Function): 24 | @staticmethod 25 | def forward(ctx, normal): 26 | (normal_b,) = render_util.normal_base_forward(normal) 27 | ctx.save_for_backward(normal) 28 | return normal_b 29 | 30 | @staticmethod 31 | def backward(ctx, grad_normal_b): 32 | (normal,) = ctx.saved_tensors 33 | (grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal) 34 | return grad_normal 35 | 36 | 37 | class Normal_Base(torch.nn.Module): 38 | def __init__(self): 39 | super(Normal_Base, self).__init__() 40 | 41 | def forward(self, normal): 42 | return Compute_normal_base.apply(normal) 43 | 44 | 45 | def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img): 46 | point_num = geometry.shape[1] 47 | 
rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans) 48 | proj_geo = geo_transform.proj_geo(rott_geo, cam) 49 | rot_tri_normal = compute_tri_normal(rott_geo, tris) 50 | rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris) 51 | is_visible = -torch.bmm( 52 | rot_vert_normal.reshape(-1, 1, 3), 53 | nn.functional.normalize(rott_geo.reshape(-1, 3, 1)), 54 | ).reshape(-1, point_num) 55 | is_visible[is_visible < 0.01] = -1 56 | pixel_valid = torch.zeros( 57 | (ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]), 58 | dtype=torch.float32, 59 | device=ori_img.device, 60 | ) 61 | return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid 62 | 63 | 64 | class Render_Face(torch.autograd.Function): 65 | @staticmethod 66 | def forward( 67 | ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid 68 | ): 69 | batch_size, h, w, _ = ori_img.shape 70 | ori_img = ori_img.view(batch_size, -1, 3) 71 | ori_size = torch.cat( 72 | ( 73 | torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) 74 | * h, 75 | torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) 76 | * w, 77 | ), 78 | dim=1, 79 | ).view(-1) 80 | tri_index, tri_coord, render, real = render_util.render_face_forward( 81 | proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid 82 | ) 83 | ctx.save_for_backward( 84 | ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord 85 | ) 86 | return render, real 87 | 88 | @staticmethod 89 | def backward(ctx, grad_render, grad_real): 90 | ( 91 | ori_img, 92 | ori_size, 93 | proj_geo, 94 | texture, 95 | nbl, 96 | tri_inds, 97 | tri_index, 98 | tri_coord, 99 | ) = ctx.saved_tensors 100 | grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward( 101 | grad_render, 102 | grad_real, 103 | ori_img, 104 | ori_size, 105 | proj_geo, 106 | texture, 107 | nbl, 108 | tri_inds, 109 | tri_index, 110 | tri_coord, 111 | ) 112 | return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None 113 | 114 | 115 | class Render_RGB(nn.Module): 116 | def __init__(self): 117 | super(Render_RGB, self).__init__() 118 | 119 | def forward( 120 | self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid 121 | ): 122 | return Render_Face.apply( 123 | proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid 124 | ) 125 | 126 | 127 | def cal_land(proj_geo, is_visible, lands_info, land_num): 128 | (land_index,) = render_util.update_contour(lands_info, is_visible, land_num) 129 | proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[ 130 | :, :2 131 | ].reshape(-1, land_num, 2) 132 | return proj_land 133 | 134 | 135 | class Render_Land(nn.Module): 136 | def __init__(self): 137 | super(Render_Land, self).__init__() 138 | lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32) 139 | self.lands_info = torch.as_tensor(lands_info).cuda() 140 | tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64) 141 | self.tris = torch.as_tensor(tris).cuda() - 1 142 | vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64) 143 | self.vert_tris = torch.as_tensor(vert_tris).cuda() 144 | self.normal_baser = Normal_Base().cuda() 145 | self.renderer = Render_RGB().cuda() 146 | 147 | def render_mesh(self, geometry, euler, trans, cam, ori_img, light): 148 | batch_size, h, w, _ = ori_img.shape 149 | ori_img = ori_img.view(batch_size, -1, 3) 150 | ori_size = torch.cat( 151 | ( 152 | torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) 153 | * h, 
154 | torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) 155 | * w, 156 | ), 157 | dim=1, 158 | ).view(-1) 159 | rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render( 160 | geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img 161 | ) 162 | tri_nb = self.normal_baser(rot_tri_normal.contiguous()) 163 | nbl = torch.bmm( 164 | tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3) 165 | ) 166 | texture = torch.ones_like(geometry) * 200 167 | (render,) = render_util.render_mesh( 168 | proj_geo, ori_img, ori_size, texture, nbl, self.tris 169 | ) 170 | return render.view(batch_size, h, w, 3).byte() 171 | 172 | def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands): 173 | rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render( 174 | geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img 175 | ) 176 | tri_nb = self.normal_baser(rot_tri_normal.contiguous()) 177 | nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3)) 178 | render, real = self.renderer( 179 | proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid 180 | ) 181 | proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1]) 182 | col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape( 183 | ori_img.shape[0], -1 184 | ) 185 | col_dis = torch.mean(col_minus * pixel_valid) / ( 186 | torch.mean(pixel_valid) + 0.00001 187 | ) 188 | land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape( 189 | ori_img.shape[0], -1 190 | ) 191 | lan_dis = torch.mean(land_dists) 192 | return col_dis, lan_dis 193 | -------------------------------------------------------------------------------- /face_tracking/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def compute_tri_normal(geometry, tris): 7 | tri_1 = tris[:, 0] 8 | tri_2 = tris[:, 1] 9 | tri_3 = tris[:, 2] 10 | vert_1 = torch.index_select(geometry, 1, tri_1) 11 | vert_2 = torch.index_select(geometry, 1, tri_2) 12 | vert_3 = torch.index_select(geometry, 1, tri_3) 13 | nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) 14 | normal = nn.functional.normalize(nnorm) 15 | return normal 16 | 17 | 18 | def euler2rot(euler_angle): 19 | batch_size = euler_angle.shape[0] 20 | theta = euler_angle[:, 0].reshape(-1, 1, 1) 21 | phi = euler_angle[:, 1].reshape(-1, 1, 1) 22 | psi = euler_angle[:, 2].reshape(-1, 1, 1) 23 | one = torch.ones(batch_size, 1, 1).to(euler_angle.device) 24 | zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device) 25 | rot_x = torch.cat( 26 | ( 27 | torch.cat((one, zero, zero), 1), 28 | torch.cat((zero, theta.cos(), theta.sin()), 1), 29 | torch.cat((zero, -theta.sin(), theta.cos()), 1), 30 | ), 31 | 2, 32 | ) 33 | rot_y = torch.cat( 34 | ( 35 | torch.cat((phi.cos(), zero, -phi.sin()), 1), 36 | torch.cat((zero, one, zero), 1), 37 | torch.cat((phi.sin(), zero, phi.cos()), 1), 38 | ), 39 | 2, 40 | ) 41 | rot_z = torch.cat( 42 | ( 43 | torch.cat((psi.cos(), -psi.sin(), zero), 1), 44 | torch.cat((psi.sin(), psi.cos(), zero), 1), 45 | torch.cat((zero, zero, one), 1), 46 | ), 47 | 2, 48 | ) 49 | return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) 50 | 51 | 52 | def rot_trans_pts(geometry, rot, trans): 53 | rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None] 54 | return rott_geo.permute(0, 2, 1) 55 | 56 | 57 | def cal_lap_loss(tensor_list, weight_list): 58 | lap_kernel = ( 59 
| torch.Tensor((-0.5, 1.0, -0.5)) 60 | .unsqueeze(0) 61 | .unsqueeze(0) 62 | .float() 63 | .to(tensor_list[0].device) 64 | ) 65 | loss_lap = 0 66 | for i in range(len(tensor_list)): 67 | in_tensor = tensor_list[i] 68 | in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1]) 69 | out_tensor = F.conv1d(in_tensor, lap_kernel) 70 | loss_lap += torch.mean(out_tensor ** 2) * weight_list[i] 71 | return loss_lap 72 | 73 | 74 | def proj_pts(rott_geo, focal_length, cxy): 75 | cx, cy = cxy[0], cxy[1] 76 | X = rott_geo[:, :, 0] 77 | Y = rott_geo[:, :, 1] 78 | Z = rott_geo[:, :, 2] 79 | fxX = focal_length * X 80 | fyY = focal_length * Y 81 | proj_x = -fxX / Z + cx 82 | proj_y = fyY / Z + cy 83 | return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) 84 | 85 | 86 | def forward_rott(geometry, euler_angle, trans): 87 | rot = euler2rot(euler_angle) 88 | rott_geo = rot_trans_pts(geometry, rot, trans) 89 | return rott_geo 90 | 91 | 92 | def forward_transform(geometry, euler_angle, trans, focal_length, cxy): 93 | rot = euler2rot(euler_angle) 94 | rott_geo = rot_trans_pts(geometry, rot, trans) 95 | proj_geo = proj_pts(rott_geo, focal_length, cxy) 96 | return proj_geo 97 | 98 | 99 | def cal_lan_loss(proj_lan, gt_lan): 100 | return torch.mean((proj_lan - gt_lan) ** 2) 101 | 102 | 103 | def cal_col_loss(pred_img, gt_img, img_mask): 104 | pred_img = pred_img.float() 105 | # loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255 106 | loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255 107 | loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2)) 108 | loss = torch.mean(loss) 109 | return loss 110 | -------------------------------------------------------------------------------- /wav2mel.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import numpy as np 4 | from scipy import signal 5 | from wav2mel_hparams import hparams as hp 6 | from librosa.core.audio import resample 7 | import soundfile as sf 8 | 9 | def load_wav(path, sr): 10 | return librosa.core.load(path, sr=sr) 11 | 12 | def preemphasis(wav, k, preemphasize=True): 13 | if preemphasize: 14 | return signal.lfilter([1, -k], [1], wav) 15 | return wav 16 | 17 | def inv_preemphasis(wav, k, inv_preemphasize=True): 18 | if inv_preemphasize: 19 | return signal.lfilter([1], [1, -k], wav) 20 | return wav 21 | 22 | def get_hop_size(): 23 | hop_size = hp.hop_size 24 | if hop_size is None: 25 | assert hp.frame_shift_ms is not None 26 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) 27 | return hop_size 28 | 29 | def linearspectrogram(wav): 30 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 31 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db 32 | 33 | if hp.signal_normalization: 34 | return _normalize(S) 35 | return S 36 | 37 | def melspectrogram(wav): 38 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 39 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db 40 | 41 | if hp.signal_normalization: 42 | return _normalize(S) 43 | return S 44 | 45 | def _stft(y): 46 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size) 47 | 48 | ########################################################## 49 | #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 
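# Worked example of the padding arithmetic used by the two helpers below:
# with the default hparams (fsize = n_fft = 800, fshift = hop_size = 200) and
# a 1-second clip at 16 kHz (length = 16000), pad = 800 - 200 = 600, so
# num_frames returns (16000 + 2*600 - 800) // 200 + 1 = 83 frames and
# pad_lr returns (600, 600); librosa_pad_lr further below returns (0, 200).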
50 | def num_frames(length, fsize, fshift): 51 | """Compute number of time frames of spectrogram 52 | """ 53 | pad = (fsize - fshift) 54 | if length % fshift == 0: 55 | M = (length + pad * 2 - fsize) // fshift + 1 56 | else: 57 | M = (length + pad * 2 - fsize) // fshift + 2 58 | return M 59 | 60 | 61 | def pad_lr(x, fsize, fshift): 62 | """Compute left and right padding 63 | """ 64 | M = num_frames(len(x), fsize, fshift) 65 | pad = (fsize - fshift) 66 | T = len(x) + 2 * pad 67 | r = (M - 1) * fshift + fsize - T 68 | return pad, pad + r 69 | ########################################################## 70 | #Librosa correct padding 71 | def librosa_pad_lr(x, fsize, fshift): 72 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 73 | 74 | # Conversions 75 | _mel_basis = None 76 | 77 | def _linear_to_mel(spectogram): 78 | global _mel_basis 79 | if _mel_basis is None: 80 | _mel_basis = _build_mel_basis() 81 | return np.dot(_mel_basis, spectogram) 82 | 83 | def _build_mel_basis(): 84 | assert hp.fmax <= hp.sample_rate // 2 85 | return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, 86 | fmin=hp.fmin, fmax=hp.fmax) 87 | 88 | def _amp_to_db(x): 89 | min_level = np.exp(hp.min_level_db / 20 * np.log(10)) 90 | return 20 * np.log10(np.maximum(min_level, x)) 91 | 92 | def _db_to_amp(x): 93 | return np.power(10.0, (x) * 0.05) 94 | 95 | def _normalize(S): 96 | if hp.allow_clipping_in_normalization: 97 | if hp.symmetric_mels: 98 | return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value, 99 | -hp.max_abs_value, hp.max_abs_value) 100 | else: 101 | return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value) 102 | 103 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 104 | if hp.symmetric_mels: 105 | return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value 106 | else: 107 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) 108 | 109 | def _denormalize(D): 110 | if hp.allow_clipping_in_normalization: 111 | if hp.symmetric_mels: 112 | return (((np.clip(D, -hp.max_abs_value, 113 | hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) 114 | + hp.min_level_db) 115 | else: 116 | return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 117 | 118 | if hp.symmetric_mels: 119 | return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db) 120 | else: 121 | return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db) 122 | 123 | 124 | 125 | def wav2mel(wav, sr): 126 | wav16k = resample(wav, orig_sr=sr, target_sr=16000) 127 | # print('wav16k', wav16k.shape, wav16k.dtype) 128 | mel = melspectrogram(wav16k) 129 | # print('mel', mel.shape, mel.dtype) 130 | if np.isnan(mel.reshape(-1)).sum() > 0: 131 | raise ValueError( 132 | 'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again') 133 | # mel.dtype = np.float32 134 | mel_chunks = [] 135 | mel_idx_multiplier = 80. 
/ 25 136 | mel_step_size = 8 137 | i = start_idx = 0 138 | while start_idx < len(mel[0]): 139 | start_idx = int(i * mel_idx_multiplier) 140 | if start_idx + mel_step_size // 2 > len(mel[0]): 141 | mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) 142 | elif start_idx - mel_step_size // 2 < 0: 143 | mel_chunks.append(mel[:, :mel_step_size]) 144 | else: 145 | mel_chunks.append(mel[:, start_idx - mel_step_size // 2 : start_idx + mel_step_size // 2]) 146 | i += 1 147 | return mel_chunks 148 | 149 | 150 | 151 | if __name__ == '__main__': 152 | import argparse 153 | 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument('--wav', type=str, default='') 156 | parser.add_argument('--save_feats', action='store_true') 157 | 158 | opt = parser.parse_args() 159 | 160 | wav, sr = librosa.core.load(opt.wav) 161 | mel_chunks = np.array(wav2mel(wav.T, sr)) 162 | print(mel_chunks.shape, mel_chunks.transpose(0,2,1).shape) 163 | 164 | if opt.save_feats: 165 | save_path = opt.wav.replace('.wav', '_mel.npy') 166 | np.save(save_path, mel_chunks.transpose(0,2,1)) 167 | print(f"[INFO] saved logits to {save_path}") -------------------------------------------------------------------------------- /wav2mel_hparams.py: -------------------------------------------------------------------------------- 1 | class HParams: 2 | def __init__(self, **kwargs): 3 | self.data = {} 4 | 5 | for key, value in kwargs.items(): 6 | self.data[key] = value 7 | 8 | def __getattr__(self, key): 9 | if key not in self.data: 10 | raise AttributeError("'HParams' object has no attribute %s" % key) 11 | return self.data[key] 12 | 13 | def set_hparam(self, key, value): 14 | self.data[key] = value 15 | 16 | # Default hyperparameters 17 | hparams = HParams( 18 | num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality 19 | # network 20 | rescale=True, # Whether to rescale audio prior to preprocessing 21 | rescaling_max=0.9, # Rescaling value 22 | 23 | # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction 24 | # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder 25 | # Does not work if n_ffit is not multiple of hop_size!! 26 | use_lws=False, 27 | 28 | n_fft=800, # Extra window size is filled with 0 paddings to match this parameter 29 | hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 30 | win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 31 | sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) 32 | 33 | frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) 34 | 35 | # Mel and Linear spectrograms normalization/scaling and clipping 36 | signal_normalization=True, 37 | # Whether to normalize mel spectrograms to some predefined range (following below parameters) 38 | allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True 39 | symmetric_mels=True, 40 | # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 41 | # faster and cleaner convergence) 42 | max_abs_value=4., 43 | # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 44 | # be too big to avoid gradient explosion, 45 | # not too small for fast convergence) 46 | # Contribution by @begeekmyfriend 47 | # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 48 | # levels. 
Also allows for better G&L phase reconstruction) 49 | preemphasize=True, # whether to apply filter 50 | preemphasis=0.97, # filter coefficient. 51 | 52 | # Limits 53 | min_level_db=-100, 54 | ref_level_db=20, 55 | fmin=65, 56 | # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 57 | # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 58 | fmax=6000, # To be increased/reduced depending on data. 59 | 60 | ###################### Our training parameters ################################# 61 | img_size=96, 62 | fps=25, 63 | 64 | batch_size=16, 65 | initial_learning_rate=1e-4, 66 | nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs 67 | num_workers=16, 68 | checkpoint_interval=3000, 69 | eval_interval=3000, 70 | save_optimizer_state=True, 71 | 72 | syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 73 | syncnet_batch_size=64, 74 | syncnet_lr=1e-4, 75 | syncnet_eval_interval=10000, 76 | syncnet_checkpoint_interval=10000, 77 | 78 | disc_wt=0.07, 79 | disc_initial_learning_rate=1e-4, 80 | ) 81 | --------------------------------------------------------------------------------
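A minimal usage sketch for the saved tracking parameters: the dictionary written by the torch.save call above can be read back as shown here. The keys ("id", "exp", "euler", "trans", "focal") come from that call; the file path is a placeholder for wherever track_params.pt was written next to args.path.

    import torch

    # Load the tracked parameters on CPU; adjust the path to your output folder.
    params = torch.load("track_params.pt", map_location="cpu")
    id_para = params["id"]          # identity coefficients
    exp_para = params["exp"]        # per-frame expression coefficients
    euler_angle = params["euler"]   # per-frame head rotation (Euler angles)
    trans = params["trans"]         # per-frame head translation
    focal_length = params["focal"]  # estimated focal length
    print(exp_para.shape, euler_angle.shape, trans.shape)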