├── DATASET.md ├── images ├── teaser.png ├── object_interaction_dataframe.png └── locomotion_prediction_dataframe.png ├── decoder └── utils.py ├── INSTALL.md ├── masking_generator.py ├── dpvo └── plot_utils.py ├── functional.py ├── VPP.md ├── README.md ├── volume_transforms.py ├── VIP.md ├── LP.md ├── random_erasing.py ├── engine_for_pretraining.py ├── modeling_video_teacher.py ├── optim_factory.py ├── engine_locomotion_prediction_for_finetuning.py ├── object_interaction ├── object_interaction_utils.py └── object_interaction_dataloader.py ├── locomotion_prediction ├── locomotion_prediction_utils.py └── locomotion_prediction_dataloader.py ├── modeling_teacher.py ├── datasets.py ├── transforms.py ├── cms_dataset.py ├── mixup.py ├── ssv2.py └── modeling_student.py /DATASET.md: -------------------------------------------------------------------------------- 1 | # Data Preparation 2 | -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DannyTran123/egopet/HEAD/images/teaser.png -------------------------------------------------------------------------------- /images/object_interaction_dataframe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DannyTran123/egopet/HEAD/images/object_interaction_dataframe.png -------------------------------------------------------------------------------- /images/locomotion_prediction_dataframe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DannyTran123/egopet/HEAD/images/locomotion_prediction_dataframe.png -------------------------------------------------------------------------------- /decoder/utils.py: -------------------------------------------------------------------------------- 1 | import ffmpeg 2 | import numpy as np 3 | import torch 4 | 5 | def decode_ffmpeg(video_path, start_seek, num_sec=2, num_frames=16, fps=5): 6 | 7 | probe = ffmpeg.probe(video_path, v='error', select_streams='v:0', show_entries='stream=width,height,duration,r_frame_rate') 8 | video_info = next((s for s in probe['streams'] if 'width' in s and 'height' in s), None) 9 | 10 | if video_info is None: 11 | raise ValueError("No video stream information found in the input video.") 12 | 13 | width = int(video_info['width']) 14 | height = int(video_info['height']) 15 | r_frame_rate = video_info['r_frame_rate'].split('/') 16 | 17 | if fps is None: 18 | fps = int(r_frame_rate[0]) / int(r_frame_rate[1]) 19 | 20 | cmd = ( 21 | ffmpeg 22 | .input(video_path, ss=start_seek, t=num_sec + 0.1) 23 | .filter('fps', fps=fps) 24 | ) 25 | out, _ = ( 26 | cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24') 27 | .run(capture_stdout=True, quiet=True) 28 | ) 29 | 30 | video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3]) 31 | video_copy = video.copy() 32 | video = torch.from_numpy(video_copy) 33 | return video[:num_frames].type(torch.float32) 34 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Requirements 4 | - Python >= 3.8 5 | - [PyTorch==1.10](https://pytorch.org/) 6 | - [torchvision](https://github.com/pytorch/vision) that matches the PyTorch installation. 
7 | - [timm==0.4.12](https://github.com/rwightman/pytorch-image-models) 8 | - [deepspeed==0.5.8](https://github.com/microsoft/DeepSpeed) 9 | - [TensorboardX](https://github.com/lanpa/tensorboardX) 10 | - [decord](https://github.com/dmlc/decord) 11 | - [einops](https://github.com/arogozhnikov/einops) 12 | - [scipy](https://github.com/scipy/scipy) 13 | - [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) 14 | - [opencv-python](https://pypi.org/project/opencv-python/) 15 | - [iopath](https://github.com/facebookresearch/iopath) 16 | - [numpy==1.23.0](https://numpy.org/install/) 17 | - [pandas](https://pandas.pydata.org/docs/getting_started/install.html) 18 | - [matplotlib](https://matplotlib.org/stable/users/installing/index.html) 19 | - [sklearn](https://scikit-learn.org/stable/install.html) 20 | - [torchmetrics](https://pypi.org/project/torchmetrics/) 21 | 22 | ## Extras 23 | For ffmpeg-python you additionally need to download the openh264 binary from [here](https://github.com/cisco/openh264/releases/tag/v2.1.1) and rename the file as 'libopenh264.so.5' and place it in your environment's lib directory. 24 | -------------------------------------------------------------------------------- /masking_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class TubeMaskingGenerator: 5 | def __init__(self, input_size, mask_ratio): 6 | self.frames, self.height, self.width = input_size 7 | self.num_patches_per_frame = self.height * self.width 8 | self.total_patches = self.frames * self.num_patches_per_frame 9 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) 10 | self.total_masks = self.frames * self.num_masks_per_frame 11 | 12 | def __repr__(self): 13 | repr_str = "Maks: total patches {}, mask patches {}".format( 14 | self.total_patches, self.total_masks 15 | ) 16 | return repr_str 17 | 18 | def __call__(self): 19 | mask_per_frame = np.hstack([ 20 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), 21 | np.ones(self.num_masks_per_frame), 22 | ]) 23 | np.random.shuffle(mask_per_frame) 24 | mask = np.tile(mask_per_frame, (self.frames, 1)) 25 | mask = mask.flatten() 26 | return mask 27 | 28 | 29 | class RandomMaskingGenerator: 30 | def __init__(self, input_size, mask_ratio): 31 | self.frames, self.height, self.width = input_size 32 | self.total_patches = self.frames * self.height * self.width 33 | self.num_masks = int(mask_ratio * self.total_patches) 34 | self.total_masks = self.num_masks 35 | 36 | def __repr__(self): 37 | repr_str = "Maks: total patches {}, mask patches {}".format( 38 | self.total_patches, self.total_masks 39 | ) 40 | return repr_str 41 | 42 | def __call__(self): 43 | mask = np.hstack([ 44 | np.zeros(self.total_patches - self.num_masks), 45 | np.ones(self.num_masks), 46 | ]) 47 | np.random.shuffle(mask) 48 | return mask 49 | -------------------------------------------------------------------------------- /dpvo/plot_utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from evo.core import sync 6 | from evo.core.trajectory import PoseTrajectory3D 7 | from evo.tools import plot 8 | from pathlib import Path 9 | 10 | 11 | def make_traj(args) -> PoseTrajectory3D: 12 | if isinstance(args, tuple): 13 | traj, tstamps = args 14 | return PoseTrajectory3D(positions_xyz=traj[:,:3], orientations_quat_wxyz=traj[:,3:], timestamps=tstamps) 15 | 
assert isinstance(args, PoseTrajectory3D), type(args) 16 | return deepcopy(args) 17 | 18 | def best_plotmode(traj): 19 | # _, i1, i2 = np.argsort(np.var(traj.positions_xyz, axis=0)) 20 | # plot_axes = "xyz"[i2] + "xyz"[i1] 21 | plot_axes = "xz" 22 | return getattr(plot.PlotMode, plot_axes) 23 | 24 | def plot_trajectory(pred_traj, gt_traj=None, title="", filename="", align=True, correct_scale=True): 25 | pred_traj = make_traj(pred_traj) 26 | 27 | if gt_traj is not None: 28 | gt_traj = make_traj(gt_traj) 29 | gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj) 30 | 31 | if align: 32 | pred_traj.align(gt_traj, correct_scale=correct_scale) 33 | 34 | plot_collection = plot.PlotCollection("PlotCol") 35 | fig = plt.figure(figsize=(8, 8)) 36 | plot_mode = best_plotmode(gt_traj if (gt_traj is not None) else pred_traj) 37 | ax = plot.prepare_axis(fig, plot_mode) 38 | ax.set_title(title) 39 | if gt_traj is not None: 40 | plot.traj(ax, plot_mode, gt_traj, '--', 'gray', "Ground Truth") 41 | plot.traj(ax, plot_mode, pred_traj, '-', 'blue', "Predicted") 42 | plot_collection.add_figure("traj (error)", fig) 43 | plot_collection.export(filename, confirm_overwrite=False) 44 | plt.close(fig=fig) 45 | # print(f"Saved {filename}") 46 | 47 | def save_trajectory_tum_format(traj, filename): 48 | traj = make_traj(traj) 49 | tostr = lambda a: ' '.join(map(str, a)) 50 | with Path(filename).open('w') as f: 51 | for i in range(traj.num_poses): 52 | f.write(f"{traj.timestamps[i]} {tostr(traj.positions_xyz[i])} {tostr(traj.orientations_quat_wxyz[i][[1,2,3,0]])}\n") 53 | print(f"Saved {filename}") 54 | -------------------------------------------------------------------------------- /functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import cv2 3 | import numpy as np 4 | import PIL 5 | import torch 6 | 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [ 18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 19 | ] 20 | else: 21 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 22 | 'but got list of {0}'.format(type(clip[0]))) 23 | return cropped 24 | 25 | 26 | def resize_clip(clip, size, interpolation='bilinear'): 27 | if isinstance(clip[0], np.ndarray): 28 | if isinstance(size, numbers.Number): 29 | im_h, im_w, im_c = clip[0].shape 30 | # Min spatial dim already matches minimal size 31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 32 | and im_h == size): 33 | return clip 34 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 35 | size = (new_w, new_h) 36 | else: 37 | size = size[0], size[1] 38 | if interpolation == 'bilinear': 39 | np_inter = cv2.INTER_LINEAR 40 | else: 41 | np_inter = cv2.INTER_NEAREST 42 | scaled = [ 43 | cv2.resize(img, size, interpolation=np_inter) for img in clip 44 | ] 45 | elif isinstance(clip[0], PIL.Image.Image): 46 | if isinstance(size, numbers.Number): 47 | im_w, im_h = clip[0].size 48 | # Min spatial dim already matches minimal size 49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 50 | and im_h == size): 51 | return clip 52 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 53 | size = (new_w, new_h) 54 | else: 55 | size = size[1], size[0] 56 | if interpolation == 'bilinear': 57 | pil_inter = 
PIL.Image.BILINEAR 58 | else: 59 | pil_inter = PIL.Image.NEAREST 60 | scaled = [img.resize(size, pil_inter) for img in clip] 61 | else: 62 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 63 | 'but got list of {0}'.format(type(clip[0]))) 64 | return scaled 65 | 66 | 67 | def get_resize_sizes(im_h, im_w, size): 68 | if im_w < im_h: 69 | ow = size 70 | oh = int(size * im_h / im_w) 71 | else: 72 | oh = size 73 | ow = int(size * im_w / im_h) 74 | return oh, ow 75 | 76 | 77 | def normalize(clip, mean, std, inplace=False): 78 | if not _is_tensor_clip(clip): 79 | raise TypeError('tensor is not a torch clip.') 80 | 81 | if not inplace: 82 | clip = clip.clone() 83 | 84 | dtype = clip.dtype 85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 88 | 89 | return clip 90 | -------------------------------------------------------------------------------- /VPP.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning for Vision to Proprioception prediction (VPP) Classification 2 | 3 | ## Description 4 | We evaluate the usefulness of EgoPet on a robotic task, focusing on the problem of vision-based locomotion. Specifically, the task consists of predicting the parameters of the terrain a quadrupedal robot is walking on. The accurate prediction of these parameters is correlated with higher performance in the downstream walking task. 5 | 6 | The parameters we predict are the local terrain geometry, the terrain's friction, and the parameters related to the robot's walking behavior on the terrain, including the robot's speed, motor efficiency, and high-level command. We aim to predict a latent representation $z_t$ of the terrain parameters. This latent representation consists of the hidden layer of a neural network trained in simulation to encode end-to-end together with the walking policy. 7 | 8 | To collect the dataset, we let the robot walk in many environments. The training dataset contains 120 thousand frames from 3 environments (approximately 2.3 hours of total walking time): an office, a park, and a beach. Each of them contains different terrain geometries, including flat, steps, and slopes. Each sample contains an image collected from a forward-looking camera mounted on the robot and the (latent) parameters of the terrain below the center of mass of the robot $z_t$ estimated with a history of proprioception. 9 | 10 | The task consists of predicting $z_t$ from a (history) of images. We generate several sub-tasks by predicting the future terrain parameters $z_{t+0.8}, z_{t+1.5}$ and the past ones $z_{t-0.8}, z_{t-1.5}$. These time intervals were selected to differentiate between forecasting and estimation. We divide the datasets into a training and testing set per episode, i.e., two different policy runs. We construct three test datasets, one in distribution (same location and lighting conditions as training), out of distribution due to lighting conditions (same location but very different time (night)), and out of distribution data such as sand which was never experienced during training. 11 | 12 | For more information about the VPP task refer to our paper! 
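Conceptually, the probing recipe below reduces to regressing the 10-dimensional latent $z_t$ (or a past/future latent, depending on the lookahead) from a frozen clip-level feature with a single linear head. The snippet below is only a minimal sketch of that idea, not code from this repository; the pooled-feature input, the MSE objective, and the 768-dimensional ViT-B feature size are assumptions.

```
# Minimal linear-probing sketch (illustrative only): a frozen backbone's
# pooled clip feature -> one linear layer -> 10-d terrain latent z_t.
import torch
import torch.nn as nn

feat_dim, latent_dim = 768, 10           # assumed ViT-B feature size; z_t dimension

class LatentProbe(nn.Module):
    def __init__(self, feat_dim, latent_dim):
        super().__init__()
        self.head = nn.Linear(feat_dim, latent_dim)

    def forward(self, clip_features):     # (B, feat_dim) pooled video features
        return self.head(clip_features)   # (B, latent_dim) predicted latent

probe = LatentProbe(feat_dim, latent_dim)
criterion = nn.MSELoss()                  # regression to the proprioceptive latent (assumed)
optimizer = torch.optim.AdamW(probe.parameters(), lr=5e-4, weight_decay=0.05)

# One illustrative step with random stand-ins for features and targets.
features = torch.randn(8, feat_dim)       # would come from the frozen video encoder
z_target = torch.randn(8, latent_dim)     # z at the chosen lookahead, from proprioception
loss = criterion(probe(features), z_target)
loss.backward()
optimizer.step()
```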
13 | 14 | ## Dataset Setup 15 | Download and extract CMS_data.tar.gz from from [here](https://drive.google.com/file/d/1ZKSWwCoZP1mHjpksIEAwh3sTeeSJNF3B/view?usp=sharing) 16 | 17 | ## Linear Probing 18 | 19 | Run the following command to train a linear probing layer to predict the proprioception (past, present, future) given the video input. 20 | 21 | ``` 22 | DATA_PATH='./datasets/CMS/up_down_policy_data' 23 | LOOKHEAD=0.8 # in -1.5 -0.8 0 0.8 1.5 24 | MASTER_PORT=29500 25 | DATA_ROOT=${DATA_PATH} 26 | OUTPUT_DIR="./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/finetune_on_cms_lookahead_${LOOKHEAD}_8frames_update_freq_4" 27 | MODEL_PATH="./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/checkpoint-2669.pth" 28 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \ 29 | run_fs_domain_adaptation.py \ 30 | --model vit_base_patch16_224 \ 31 | --nb_classes 10 \ 32 | --latent_dim 10 \ 33 | --data_path ${DATA_PATH} \ 34 | --data_root ${DATA_ROOT} \ 35 | --finetune ${MODEL_PATH} \ 36 | --log_dir ${OUTPUT_DIR} \ 37 | --output_dir ${OUTPUT_DIR} \ 38 | --input_size 224 --short_side_size 224 \ 39 | --opt adamw --opt_betas 0.9 0.999 --weight_decay 0.05 \ 40 | --batch_size 256 --update_freq 4 --num_sample 4 \ 41 | --num_frames 8 --sampling_rate 4 \ 42 | --lr 5e-4 --epochs 50 \ 43 | --input_use_imgs 8 \ 44 | --lookhead ${LOOKHEAD} \ 45 | --num_workers 5 \ 46 | --save_ckpt_freq 20 47 | ``` 48 | 49 | ### Pretrained Models 50 | | Model | Lookahead | Link | 51 | |-------------------|-----------|------| 52 | | MVD (ViT-B) | -1.5 | [link](https://drive.google.com/file/d/12tO7LwjZ66voCTxp6lNcSaeLaU-9YlYE/view?usp=sharing) | 53 | | MVD (ViT-B) | -0.8 | [link](https://drive.google.com/file/d/135ndczYtWKF04ZNl-5zs_O6yTdLh5T7O/view?usp=sharing) | 54 | | MVD (ViT-B) | 0 | [link](https://drive.google.com/file/d/1r_8ZbJtuI_6ImFQFzjgIb1ERVBjq-eY-/view?usp=sharing) | 55 | | MVD (ViT-B) | 0.8 | [link](https://drive.google.com/file/d/1f78c-ascoWKa3_29rq-4NhCruidMsFoI/view?usp=sharing) | 56 | | MVD (ViT-B) | 1.5 | [link](https://drive.google.com/file/d/1BO5gzVs5TiNilRpF5OzuTtyVSjfMO33Z/view?usp=sharing) | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EgoPet: Egomotion and Interaction Data from an Animal's Perspective (ECCV 2022) 2 | ### [Amir Bar](https://amirbar.net), [Arya Bakhtiar](), [Antonio Loquercio](https://antonilo.github.io/), [Jathushan Rajasegaran](https://people.eecs.berkeley.edu/~jathushan/), [Danny Tran](), [Yann LeCun](https://yann.lecun.com/), [Amir Globerson](http://www.cs.tau.ac.il/~gamir/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/) 3 |
4 | 5 | 6 |
58 | 59 | ### Pretrained Models 60 | | Model | Pretraining | Epochs | Link | 61 | |-------------------|-------------|--------|------| 62 | | MVD (ViT-B) | EgoPet | 2670 | [link](https://drive.google.com/file/d/1_Ky73DpjvSh5k4g-xhWE6QTndRlfUTHp/view?usp=sharing) | 63 | 64 | ## Visual Interaction Prediction (VIP) 65 | 66 | The fine-tuning instructions for the VIP task is in [VIP.md](VIP.md). 67 | 68 | ## Locomotion Prediction (LP) 69 | 70 | The fine-tuning instructions for the LP task is in [LP.md](LP.md). 71 | 72 | ## Vision to Proprioception Prediction (VPP) 73 | 74 | The fine-tuning instructions for the VPP task is in [VPP.md](VPP.md). 75 | 76 | ## Acknowledgements 77 | 78 | This project is built upon [MVD](https://github.com/ruiwang2021/mvd/tree/main), [MAE_ST](https://github.com/facebookresearch/mae_st), and [DPVO](https://github.com/princeton-vl/DPVO). Thank you to the contributors of these codebases! 79 | -------------------------------------------------------------------------------- /volume_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def convert_img(img): 7 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format 8 | """ 9 | if len(img.shape) == 3: 10 | img = img.transpose(2, 0, 1) 11 | if len(img.shape) == 2: 12 | img = np.expand_dims(img, 0) 13 | return img 14 | 15 | 16 | class ClipToTensor(object): 17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 19 | """ 20 | 21 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 22 | self.channel_nb = channel_nb 23 | self.div_255 = div_255 24 | self.numpy = numpy 25 | 26 | def __call__(self, clip): 27 | """ 28 | Args: clip (list of numpy.ndarray): clip (list of images) 29 | to be converted to tensor. 
30 | """ 31 | # Retrieve shape 32 | if isinstance(clip[0], np.ndarray): 33 | h, w, ch = clip[0].shape 34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 35 | ch) 36 | elif isinstance(clip[0], Image.Image): 37 | w, h = clip[0].size 38 | else: 39 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 40 | but got list of {0}'.format(type(clip[0]))) 41 | 42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 43 | 44 | # Convert 45 | for img_idx, img in enumerate(clip): 46 | if isinstance(img, np.ndarray): 47 | pass 48 | elif isinstance(img, Image.Image): 49 | img = np.array(img, copy=False) 50 | else: 51 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 52 | but got list of {0}'.format(type(clip[0]))) 53 | img = convert_img(img) 54 | np_clip[:, img_idx, :, :] = img 55 | if self.numpy: 56 | if self.div_255: 57 | np_clip = np_clip / 255.0 58 | return np_clip 59 | 60 | else: 61 | tensor_clip = torch.from_numpy(np_clip) 62 | 63 | if not isinstance(tensor_clip, torch.FloatTensor): 64 | tensor_clip = tensor_clip.float() 65 | if self.div_255: 66 | tensor_clip = torch.div(tensor_clip, 255) 67 | return tensor_clip 68 | 69 | 70 | # Note this norms data to -1/1 71 | class ClipToTensor_K(object): 72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 74 | """ 75 | 76 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 77 | self.channel_nb = channel_nb 78 | self.div_255 = div_255 79 | self.numpy = numpy 80 | 81 | def __call__(self, clip): 82 | """ 83 | Args: clip (list of numpy.ndarray): clip (list of images) 84 | to be converted to tensor. 85 | """ 86 | # Retrieve shape 87 | if isinstance(clip[0], np.ndarray): 88 | h, w, ch = clip[0].shape 89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 90 | ch) 91 | elif isinstance(clip[0], Image.Image): 92 | w, h = clip[0].size 93 | else: 94 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 95 | but got list of {0}'.format(type(clip[0]))) 96 | 97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 98 | 99 | # Convert 100 | for img_idx, img in enumerate(clip): 101 | if isinstance(img, np.ndarray): 102 | pass 103 | elif isinstance(img, Image.Image): 104 | img = np.array(img, copy=False) 105 | else: 106 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 107 | but got list of {0}'.format(type(clip[0]))) 108 | img = convert_img(img) 109 | np_clip[:, img_idx, :, :] = img 110 | if self.numpy: 111 | if self.div_255: 112 | np_clip = (np_clip - 127.5) / 127.5 113 | return np_clip 114 | 115 | else: 116 | tensor_clip = torch.from_numpy(np_clip) 117 | 118 | if not isinstance(tensor_clip, torch.FloatTensor): 119 | tensor_clip = tensor_clip.float() 120 | if self.div_255: 121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 122 | return tensor_clip 123 | 124 | 125 | class ToTensor(object): 126 | """Converts numpy array to tensor 127 | """ 128 | 129 | def __call__(self, array): 130 | tensor = torch.from_numpy(array) 131 | return tensor 132 | -------------------------------------------------------------------------------- /VIP.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning for Visual Interaction Prediction (VIP) Classification 2 | 3 | ## Description 4 | Studying animal interactions from an egocentric perspective provides insights into how they navigate, manipulate their surroundings, and communicate which is 
applicable in designing effective systems for real-world settings. 5 | 6 | The input for this task is a video clip from the egocentric perspective of an animal. The output is twofold: a binary classification indicating whether an interaction is taking place or not, and identification of the subject of the interaction. In the EgoPet dataset, an “interaction” is defined as a discernible event where the agent demonstrates clear attention (physical contact, proximity, orientation, or vocalization) to an object or another agent with visual evidence. Aimless movements are not labeled as interactions. 7 | 8 | The beginning of an interaction is marked at the first time-step where the agent begins to give attention to a target, and the endpoint is marked at the last time-step before the attention ceases. There are 644 positive interaction segments with 17 distinct interaction subjects and 556 segments where no interaction occurs (“negative segments”). The segments were then split into train and test. For both training and testing, we sampled additional negative segments from the videos, ensuring these did not overlap with the positive segments. We sampled 125 additional negative training segments and 124 additional negative test segments, leaving us with 754 training segments and 695 test segments, for a total of 1,449 annotated segments. 9 | 10 | For more information about the VIP task refer to our paper! 11 | 12 | ## Dataset Information 13 | `object_interaciton_train.csv` and `object_interaciton_validation.csv` are csv files in which every row represents a single training or validation sample. 14 | ![alt text](images/object_interaction_dataframe.png "") 15 | 16 | ### Columns documentation: 17 | ``` 18 | animal - source animal of ego footage 19 | ds_type - train/validation 20 | video_id - hashing of video 21 | segment_id - id for segment within video 22 | start_time - start time of interaction, if NONE there is no interaction in the entire clip 23 | end_time - end time of interaction, if NONE there is no interaction in the entire clip 24 | total_time - total time of segment video 25 | interacting_object - object being interacted with 26 | video_path - path to segment video 27 | ``` 28 | 29 | ### Interacting Object Categories 30 | Chosen from common interaction objects. 31 | ``` 32 | person 33 | ball 34 | bench 35 | bird 36 | dog 37 | cat 38 | other animal 39 | toy 40 | door 41 | floor 42 | food 43 | plant 44 | filament 45 | plastic 46 | water 47 | vehicle 48 | other 49 | ``` 50 | 51 | ## Evaluation 52 | 53 | As a sanity check, run evaluation using our MVD **fine-tuned** models: 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 |
|                                             | MVD (EgoPet) |
|---------------------------------------------|--------------|
| fine-tuned checkpoint                       | download     |
| reference Interaction accuracy              | 68.75        |
| reference Interaction AUROC                 | 74.50        |
| reference Subject Prediction Top-1 accuracy | 35.38        |
| reference Subject Prediction Top-3 accuracy | 66.43        |
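The metrics above correspond to the task's two outputs: the binary interaction decision and the 17-way subject classification. As a rough, hypothetical sketch of how the columns documented above map to these two targets (not code from this repository; the NONE/NaN handling and the category ordering are assumptions), the labels could be derived as follows:

```
# Hypothetical label extraction from object_interaction_validation.csv
# (column semantics as documented above; not part of this repository).
import pandas as pd

SUBJECTS = ['person', 'ball', 'bench', 'bird', 'dog', 'cat', 'other animal',
            'toy', 'door', 'floor', 'food', 'plant', 'filament', 'plastic',
            'water', 'vehicle', 'other']

df = pd.read_csv('csv/object_interaction_validation.csv')

# Binary target: an interaction exists iff start_time is present (not NONE/NaN).
has_interaction = df['start_time'].notna() & (df['start_time'].astype(str) != 'NONE')

# Subject target: index into the 17 categories for positive segments only.
subject_idx = df.loc[has_interaction, 'interacting_object'].map(
    {name: i for i, name in enumerate(SUBJECTS)})

print(f"{has_interaction.mean():.2%} positive segments,"
      f" {subject_idx.nunique()} distinct subjects")
```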
80 | 81 | Evaluate VideoMAE/MVD on a single GPU (`{EGOPET_DIR}` is a directory containing `{training_set, validation_set}` sets of EgoPet, `{CSV_PATH}` is a directory containing `{object_interaction_train.csv, object_interaction_validation.csv}` which is `csv/`): 82 | ``` 83 | EGOPET_DIR='your_path/egopet/training_and_validation_test_set' 84 | CSV_PATH='csv/' 85 | FINETUNE_PATH='path/to/model' 86 | python run_object_interaction_finetuning.py --num_workers 5 --model vit_base_patch16_224 --latent_dim 18 --data_path ${EGOPET_DIR} --csv_path ${CSV_PATH} --finetune ${FINETUNE_PATH} --input_size 224 --batch_size 64 --num_frames 8 --num_sec 2 --fps 4 --alpha 1 --eval 87 | ``` 88 | Evaluating on MVD (EgoPet) should give: 89 | ``` 90 | * loss 1.736 91 | * Acc@1_interaction 68.750 auroc_interaction 0.745 92 | * Acc@1_object 35.379 Acc@3_object 66.426 auroc_object 0.690 93 | Loss of the network on the 695 test images: 1.7% 94 | Min loss: 1.000% 95 | ``` 96 | 97 | ## Linear Probing 98 | 99 | To fine-tune a pre-trained ViT-Base VideoMAE/MVD with **single-node training**, run the following on 1 node with 8 GPUs: 100 | ``` 101 | OUTPUT_DIR='./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/finetune_on_object_interaction' 102 | EGOPET_DIR='your_path/egopet/training_and_validation_test_set' 103 | CSV_PATH='csv/' 104 | FINETUNE_PATH='./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/checkpoint-2669.pth' 105 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \ 106 | run_object_interaction_finetuning.py \ 107 | --output_dir ${OUTPUT_DIR} \ 108 | --num_workers 5 \ 109 | --model vit_base_patch16_224 \ 110 | --latent_dim 18 \ 111 | --data_path ${EGOPET_DIR} \ 112 | --csv_path ${CSV_PATH} \ 113 | --finetune ${FINETUNE_PATH} \ 114 | --input_size 224 \ 115 | --opt adamw --opt_betas 0.9 0.999 --weight_decay 0.05 \ 116 | --batch_size 64 --update_freq 2 --num_sample 2 \ 117 | --save_ckpt_freq 10 --auto_resume \ 118 | --num_frames 8 --num_sec 2 --fps 4 --object_interaction_ratio 0.5 \ 119 | --alpha 1 \ 120 | --lr 5e-4 --epochs 15 121 | ``` 122 | To train ViT-Large or ViT-Huge, set `--model vit_large_patch16` or `--model vit_huge_patch14`. -------------------------------------------------------------------------------- /LP.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning for Locomotion Prediction (LP) 2 | 3 | ## Description 4 | Planning where to move involves a complex interplay of both perception and foresight. It requires the ability to anticipate potential obstacles, consider various courses of action, and select the most efficient and effective strategy to achieve a desired goal. EgoPet contains examples where animals plan a future trajectory to achieve a certain goal (e.g., a dog following its owner). 5 | 6 | Given a sequence of past $m$ video frames $\{x_i\}^{t}_{i=t-m}$, the goal is to predict the unit normalized future trajectory of the agent $\{v_j\}^{t+k}_{j=t+1}$, where $v_j\in \mathbb{R}^{3}$ represents the relative location of the agent at timestep $j$. We predict the unit normalized relative location due to the scale ambiguity of the extracted trajectories. In practice, we condition models on $m=16$ frames and predict $k=40$ future locations which correspond to $4$ seconds into the future. 7 | 8 | Given an input sequence of frames, DPVO returns the location and orientation of the camera for each frame. 
For our EgoPet data we found that a stride (the rate at which we kept frames) of $5$ worked best but in some cases a stride of $10$ and $15$ worked better. While DPVO worked well, there were some inaccurate trajectories, so two human annotators were trained to evaluate a trajectory from an eagle's eye view of the trajectory (XZ view). Annotators were trained to choose which stride of DPVO produced the poses that best matched the trajectory if any at all. 9 | 10 | For more information about the LP task refer to our paper! 11 | 12 | ## Dataset Information 13 | `locomotion_prediction_train.csv` and `locomotion_prediction_train_25.csv` are csv files in which every row represents a single training sample and the 25 refers to 25% of the data which we use in the experiments in the paper. `locomotion_prediction_val.csv` is the csv file for the validation set. 14 | ![alt text](images/locomotion_prediction_dataframe.png "") 15 | 16 | The DPVO generated trajectories can be downloaded from [here](https://drive.google.com/file/d/1kgwhAnjrSoOSPx9s4F013NgmNEi5K0KV/view?usp=sharing). 17 | 18 | ### Columns documentation: 19 | ``` 20 | animal - source animal of ego footage 21 | ds_type - train/validation 22 | segment_id - id for segment within video 23 | stride - the DPVO stride for this trajectory 24 | start_time - start time of this trajectory 25 | end_time - end time of this trajectory 26 | video_path - path to segment video 27 | ``` 28 | 29 | ## Evaluation 30 | 31 | As a sanity check, run evaluation using our MVD **fine-tuned** models: 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 |
|                       | MVD (EgoPet) |
|-----------------------|--------------|
| fine-tuned checkpoint | download     |
| reference ATE         | 0.474        |
| reference RPE         | 0.171        |
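The ATE and RPE values above are computed on unit-normalized relative trajectories, as described in the task definition. The snippet below is only an illustrative sketch under one reading of that target, with an unaligned ATE-style error; it is not the evaluation code in this repository, and the helper functions and normalization choice are assumptions.

```
# Illustrative sketch (assumptions, not repository code): build a
# unit-normalized relative trajectory target and a simple ATE-style error.
import numpy as np

def to_unit_relative(positions):
    """positions: (k+1, 3) camera locations starting at the current frame.
    Returns (k, 3) displacements from the start, normalized to unit length,
    which removes the scale ambiguity of monocular odometry."""
    rel = positions[1:] - positions[:1]
    norms = np.linalg.norm(rel, axis=1, keepdims=True)
    return rel / np.clip(norms, 1e-8, None)

def ate(pred, gt):
    """Mean Euclidean distance between predicted and ground-truth points
    (no alignment step in this simplified version)."""
    return float(np.mean(np.linalg.norm(pred - gt, axis=1)))

# 40 future poses ~ 4 seconds at 10 poses per second, as described above.
gt_traj = to_unit_relative(np.cumsum(np.random.randn(41, 3), axis=0))
pred_traj = gt_traj + 0.05 * np.random.randn(*gt_traj.shape)
print("ATE:", ate(pred_traj, gt_traj))
```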
50 | 51 | Evaluate VideoMAE/MVD on a 1 node with 8 GPUs. (`{EGOPET_DIR}` is a directory containing `{training_set, validation_set}` sets of EgoPet, `{TRAIN_CSV}` is the path to the train csv to use. `{RESUME_PATH}` is the path to the model you want to evaluate.): 52 | ``` 53 | EGOPET_DIR='your_path/egopet/training_and_validation_test_set' 54 | TRAIN_CSV='locomotion_prediction_train_25.csv' 55 | RESUME_PATH='path/to/model/egopet_lp_linearprobing_vitb_checkpoint-00014.pth' 56 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --master_port=29101 57 | --nproc_per_node=8 --use_env \ 58 | run_locomotion_prediction_finetuning.py \ 59 | --num_workers 7 \ 60 | --model vit_base_patch16_224 \ 61 | --input_size 224 \ 62 | --validation_batch_size 8 \ 63 | --path_to_data_dir ${EGOPET_DIR} \ 64 | --train_csv ${TRAIN_CSV} \ 65 | --path_to_trajectories_dir ${PATH_TO_TRAJECTORIES_DIR} \ 66 | --num_pred 3 \ 67 | --resume ${RESUME_PATH} \ 68 | --animals cat,dog --num_condition_frames 16 --num_pose_prediction 120 --pps 10 --fps 30 \ 69 | --scale_invariance dir \ 70 | --dist_eval --eval 71 | ``` 72 | Evaluating on MVD (EgoPet) should give: 73 | ``` 74 | * loss 0.36006 75 | * ate 0.47448 76 | * rpe_trans 0.17098 77 | * rpe_rot 0.00000 78 | Loss of the network on the 167 test images: 0.4% 79 | Min loss: 0.360% 80 | ``` 81 | 82 | ## Linear Probing 83 | 84 | To fine-tune a pre-trained ViT-Base VideoMAE/MVD with **single-node training**, run the following on 1 node with 8 GPUs: 85 | ``` 86 | OUTPUT_DIR='./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/finetune_on_object_interaction' 87 | EGOPET_DIR='your_path/egopet/training_and_validation_test_set' 88 | CSV_PATH='csv/locomotion_prediction_train_25.csv' 89 | FINETUNE_PATH='./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/checkpoint-2669.pth' 90 | PATH_TO_TRAJECTORIES_DIR='your_path/interp_trajectories' 91 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \ 92 | run_locomotion_prediction_finetuning.py \ 93 | --output_dir ${OUTPUT_DIR} \ 94 | --num_workers 5 \ 95 | --model vit_base_patch16_224 \ 96 | --path_to_data_dir ${EGOPET_DIR} \ 97 | --path_to_trajectories_dir ${PATH_TO_TRAJECTORIES_DIR} \ 98 | --train_csv ${CSV_PATH} \ 99 | --finetune ${FINETUNE_PATH} \ 100 | --input_size 224 \ 101 | --opt adamw --opt_betas 0.9 0.999 --weight_decay 0.05 \ 102 | --batch_size 16 --validation_batch_size 8 --update_freq 8 --num_sample 2 \ 103 | --save_ckpt_freq 1 --auto_resume \ 104 | --num_pred 3 --num_condition_frames 16 --num_pose_prediction 120 \ 105 | --animals cat,dog --pps 10 --fps 30 --scale_invariance dir \ 106 | --dist_eval \ 107 | --lr 5e-4 --epochs 15 108 | ``` 109 | To train ViT-Large or ViT-Huge, set `--model vit_large_patch16` or `--model vit_huge_patch14`. -------------------------------------------------------------------------------- /random_erasing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 4 | pulished under an Apache License 2.0. 
5 | """ 6 | import math 7 | import random 8 | import torch 9 | 10 | 11 | def _get_pixels( 12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" 13 | ): 14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 15 | # paths, flip the order so normal is run on CPU if this becomes a problem 16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 17 | if per_pixel: 18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 19 | elif rand_color: 20 | return torch.empty( 21 | (patch_size[0], 1, 1), dtype=dtype, device=device 22 | ).normal_() 23 | else: 24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 25 | 26 | 27 | class RandomErasing: 28 | """Randomly selects a rectangle region in an image and erases its pixels. 29 | 'Random Erasing Data Augmentation' by Zhong et al. 30 | See https://arxiv.org/pdf/1708.04896.pdf 31 | This variant of RandomErasing is intended to be applied to either a batch 32 | or single image tensor after it has been normalized by dataset mean and std. 33 | Args: 34 | probability: Probability that the Random Erasing operation will be performed. 35 | min_area: Minimum percentage of erased area wrt input image area. 36 | max_area: Maximum percentage of erased area wrt input image area. 37 | min_aspect: Minimum aspect ratio of erased area. 38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 39 | 'const' - erase block is constant color of 0 for all channels 40 | 'rand' - erase block is same per-channel random (normal) color 41 | 'pixel' - erase block is per-pixel random (normal) color 42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 43 | per-image count is randomly chosen between 1 and this value. 
44 | """ 45 | 46 | def __init__( 47 | self, 48 | probability=0.5, 49 | min_area=0.02, 50 | max_area=1 / 3, 51 | min_aspect=0.3, 52 | max_aspect=None, 53 | mode="const", 54 | min_count=1, 55 | max_count=None, 56 | num_splits=0, 57 | device="cuda", 58 | cube=True, 59 | ): 60 | self.probability = probability 61 | self.min_area = min_area 62 | self.max_area = max_area 63 | max_aspect = max_aspect or 1 / min_aspect 64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 65 | self.min_count = min_count 66 | self.max_count = max_count or min_count 67 | self.num_splits = num_splits 68 | mode = mode.lower() 69 | self.rand_color = False 70 | self.per_pixel = False 71 | self.cube = cube 72 | if mode == "rand": 73 | self.rand_color = True # per block random normal 74 | elif mode == "pixel": 75 | self.per_pixel = True # per pixel random normal 76 | else: 77 | assert not mode or mode == "const" 78 | self.device = device 79 | 80 | def _erase(self, img, chan, img_h, img_w, dtype): 81 | if random.random() > self.probability: 82 | return 83 | area = img_h * img_w 84 | count = ( 85 | self.min_count 86 | if self.min_count == self.max_count 87 | else random.randint(self.min_count, self.max_count) 88 | ) 89 | for _ in range(count): 90 | for _ in range(10): 91 | target_area = ( 92 | random.uniform(self.min_area, self.max_area) * area / count 93 | ) 94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 95 | h = int(round(math.sqrt(target_area * aspect_ratio))) 96 | w = int(round(math.sqrt(target_area / aspect_ratio))) 97 | if w < img_w and h < img_h: 98 | top = random.randint(0, img_h - h) 99 | left = random.randint(0, img_w - w) 100 | img[:, top : top + h, left : left + w] = _get_pixels( 101 | self.per_pixel, 102 | self.rand_color, 103 | (chan, h, w), 104 | dtype=dtype, 105 | device=self.device, 106 | ) 107 | break 108 | 109 | def _erase_cube( 110 | self, 111 | img, 112 | batch_start, 113 | batch_size, 114 | chan, 115 | img_h, 116 | img_w, 117 | dtype, 118 | ): 119 | if random.random() > self.probability: 120 | return 121 | area = img_h * img_w 122 | count = ( 123 | self.min_count 124 | if self.min_count == self.max_count 125 | else random.randint(self.min_count, self.max_count) 126 | ) 127 | for _ in range(count): 128 | for _ in range(100): 129 | target_area = ( 130 | random.uniform(self.min_area, self.max_area) * area / count 131 | ) 132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 133 | h = int(round(math.sqrt(target_area * aspect_ratio))) 134 | w = int(round(math.sqrt(target_area / aspect_ratio))) 135 | if w < img_w and h < img_h: 136 | top = random.randint(0, img_h - h) 137 | left = random.randint(0, img_w - w) 138 | for i in range(batch_start, batch_size): 139 | img_instance = img[i] 140 | img_instance[ 141 | :, top : top + h, left : left + w 142 | ] = _get_pixels( 143 | self.per_pixel, 144 | self.rand_color, 145 | (chan, h, w), 146 | dtype=dtype, 147 | device=self.device, 148 | ) 149 | break 150 | 151 | def __call__(self, input): 152 | if len(input.size()) == 3: 153 | self._erase(input, *input.size(), input.dtype) 154 | else: 155 | batch_size, chan, img_h, img_w = input.size() 156 | # skip first slice of batch if num_splits is set (for clean portion of samples) 157 | batch_start = ( 158 | batch_size // self.num_splits if self.num_splits > 1 else 0 159 | ) 160 | if self.cube: 161 | self._erase_cube( 162 | input, 163 | batch_start, 164 | batch_size, 165 | chan, 166 | img_h, 167 | img_w, 168 | input.dtype, 169 | ) 170 | else: 171 | for i in 
range(batch_start, batch_size): 172 | self._erase(input[i], chan, img_h, img_w, input.dtype) 173 | return input 174 | -------------------------------------------------------------------------------- /engine_for_pretraining.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Iterable 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | import utils 8 | from einops import rearrange 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 11 | 12 | Loss_func_choice = {'L1': torch.nn.L1Loss, 'L2': torch.nn.MSELoss, 'SmoothL1': torch.nn.SmoothL1Loss} 13 | 14 | 15 | def train_one_epoch(args, model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, 16 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 17 | log_writer=None, lr_scheduler=None, start_steps=None, lr_schedule_values=None, 18 | wd_schedule_values=None, update_freq=None, time_stride_loss=True, lr_scale=1.0, 19 | image_teacher_model=None, video_teacher_model=None, norm_feature=False): 20 | 21 | model.train() 22 | metric_logger = utils.MetricLogger(delimiter=" ") 23 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 24 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 25 | header = 'Epoch: [{}]'.format(epoch) 26 | print_freq = 10 27 | LN_img = nn.LayerNorm(args.distillation_target_dim, eps=1e-6, elementwise_affine=False).cuda() 28 | LN_vid = nn.LayerNorm(args.video_distillation_target_dim, eps=1e-6, elementwise_affine=False).cuda() 29 | 30 | loss_func_img_feat = Loss_func_choice[args.distill_loss_func]() 31 | loss_func_vid_feat = Loss_func_choice[args.video_distill_loss_func]() 32 | image_loss_weight = args.image_teacher_loss_weight 33 | video_loss_weight = args.video_teacher_loss_weight 34 | 35 | tubelet_size = args.tubelet_size 36 | 37 | for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 38 | # assign learning rate & weight decay for each step 39 | update_step = step // update_freq 40 | it = start_steps + update_step # global training iteration 41 | if lr_schedule_values is not None or wd_schedule_values is not None and step % update_freq == 0: 42 | for i, param_group in enumerate(optimizer.param_groups): 43 | if lr_schedule_values is not None: 44 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] * lr_scale 45 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 46 | param_group["weight_decay"] = wd_schedule_values[it] 47 | 48 | videos, videos_for_teacher, bool_masked_pos = batch 49 | videos = videos.to(device, non_blocking=True) 50 | videos_for_teacher = videos_for_teacher.to(device, non_blocking=True) 51 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool) 52 | _, _, T, _, _ = videos.shape 53 | 54 | with torch.cuda.amp.autocast(): 55 | output_features, output_video_features = model(videos, bool_masked_pos) 56 | with torch.no_grad(): 57 | image_teacher_model.eval() 58 | if time_stride_loss: 59 | teacher_features = image_teacher_model( 60 | rearrange(videos_for_teacher[:, :, ::tubelet_size, :, :], 'b c t h w -> (b t) c h w'), 61 | ) 62 | teacher_features = rearrange(teacher_features, '(b t) l c -> b (t l) c', t=T//tubelet_size) 63 | else: 64 | teacher_features = image_teacher_model( 65 | 
rearrange(videos_for_teacher, 'b c t h w -> (b t) c h w'), 66 | ) 67 | teacher_features = rearrange(teacher_features, '(b t d) l c -> b (t l) (d c)', t=T//tubelet_size, d=tubelet_size) 68 | if norm_feature: 69 | teacher_features = LN_img(teacher_features) 70 | 71 | video_teacher_model.eval() 72 | videos_for_video_teacher = videos if args.video_teacher_input_size == args.input_size \ 73 | else videos_for_teacher 74 | 75 | video_teacher_features = video_teacher_model(videos_for_video_teacher) 76 | if norm_feature: 77 | video_teacher_features = LN_vid(video_teacher_features) 78 | 79 | B, _, D = output_features.shape 80 | loss_img_feat = loss_func_img_feat( 81 | input=output_features, 82 | target=teacher_features[bool_masked_pos].reshape(B, -1, D) 83 | ) 84 | loss_value_img_feat = loss_img_feat.item() 85 | 86 | B, _, D = output_video_features.shape 87 | loss_vid_feat = loss_func_vid_feat( 88 | input=output_video_features, 89 | target=video_teacher_features[bool_masked_pos].reshape(B, -1, D) 90 | ) 91 | loss_value_vid_feat = loss_vid_feat.item() 92 | 93 | loss = image_loss_weight * loss_img_feat + video_loss_weight * loss_vid_feat 94 | 95 | loss_value = loss.item() 96 | 97 | if not math.isfinite(loss_value): 98 | print("Loss is {}, stopping training".format(loss_value)) 99 | sys.exit(1) 100 | 101 | # this attribute is added by timm on one optimizer (adahessian) 102 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 103 | loss /= update_freq 104 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 105 | parameters=model.parameters(), create_graph=is_second_order, 106 | update_grad=(step + 1) % update_freq == 0) 107 | if (step + 1) % update_freq == 0: 108 | optimizer.zero_grad() 109 | loss_scale_value = loss_scaler.state_dict()["scale"] 110 | 111 | torch.cuda.synchronize() 112 | 113 | metric_logger.update(loss=loss_value) 114 | metric_logger.update(loss_img_feat=loss_value_img_feat) 115 | metric_logger.update(loss_vid_feat=loss_value_vid_feat) 116 | metric_logger.update(loss_scale=loss_scale_value) 117 | min_lr = 10. 118 | max_lr = 0. 
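        # The sentinel values above are refined in the loop below: with layer-wise lr
        # decay and per-group lr_scale, each parameter group can have a different lr,
        # so the smallest and largest values across groups are tracked for logging.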
119 | for group in optimizer.param_groups: 120 | min_lr = min(min_lr, group["lr"]) 121 | max_lr = max(max_lr, group["lr"]) 122 | 123 | metric_logger.update(lr=max_lr) 124 | metric_logger.update(min_lr=min_lr) 125 | weight_decay_value = None 126 | for group in optimizer.param_groups: 127 | if group["weight_decay"] > 0: 128 | weight_decay_value = group["weight_decay"] 129 | metric_logger.update(weight_decay=weight_decay_value) 130 | metric_logger.update(grad_norm=grad_norm) 131 | 132 | if log_writer is not None: 133 | log_writer.update(loss=loss_value, head="loss") 134 | log_writer.update(loss_img_feat=loss_value_img_feat, head="loss_img_feat") 135 | log_writer.update(loss_vid_feat=loss_value_vid_feat, head="loss_vid_feat") 136 | log_writer.update(loss_scale=loss_scale_value, head="opt") 137 | log_writer.update(lr=max_lr, head="opt") 138 | log_writer.update(min_lr=min_lr, head="opt") 139 | log_writer.update(weight_decay=weight_decay_value, head="opt") 140 | log_writer.update(grad_norm=grad_norm, head="opt") 141 | log_writer.set_step() 142 | 143 | if lr_scheduler is not None: 144 | lr_scheduler.step_update(start_steps + step) 145 | # gather the stats from all processes 146 | metric_logger.synchronize_between_processes() 147 | print("Averaged stats:", metric_logger) 148 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 149 | -------------------------------------------------------------------------------- /modeling_video_teacher.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from functools import partial 7 | 8 | from modeling_finetune import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table, get_3d_sincos_pos_embed 9 | from timm.models.registry import register_model 10 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 11 | 12 | 13 | def trunc_normal_(tensor, mean=0., std=1.): 14 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 15 | 16 | 17 | __all__ = [ 18 | 'pretrain_videomae_teacher_base_patch16_224', 19 | 'pretrain_videomae_teacher_large_patch16_224', 20 | 'pretrain_videomae_teacher_huge_patch16_224', 21 | ] 22 | 23 | 24 | # -------------------------------------------------------- 25 | # VideoMAE encoder 26 | # References: 27 | # VideoMAE: https://github.com/MCG-NJU/VideoMAE/blob/main/modeling_pretrain.py 28 | # -------------------------------------------------------- 29 | class PretrainVisionTransformerEncoder(nn.Module): 30 | """ Vision Transformer with support for patch or hybrid CNN input stage 31 | """ 32 | 33 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 34 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 35 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2): 36 | super().__init__() 37 | self.num_classes = num_classes 38 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 39 | self.patch_embed = PatchEmbed( 40 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tubelet_size=tubelet_size) 41 | num_patches = self.patch_embed.num_patches 42 | 43 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim) 44 | 45 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 46 | self.blocks = nn.ModuleList([ 47 | Block( 48 | 
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 49 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 50 | init_values=init_values) 51 | for i in range(depth)]) 52 | self.norm = norm_layer(embed_dim) 53 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 54 | 55 | self.apply(self._init_weights) 56 | 57 | def _init_weights(self, m): 58 | if isinstance(m, nn.Linear): 59 | nn.init.xavier_uniform_(m.weight) 60 | if isinstance(m, nn.Linear) and m.bias is not None: 61 | nn.init.constant_(m.bias, 0) 62 | elif isinstance(m, nn.LayerNorm): 63 | nn.init.constant_(m.bias, 0) 64 | nn.init.constant_(m.weight, 1.0) 65 | 66 | def get_num_layers(self): 67 | return len(self.blocks) 68 | 69 | @torch.jit.ignore 70 | def no_weight_decay(self): 71 | return {'pos_embed', 'cls_token'} 72 | 73 | def get_classifier(self): 74 | return self.head 75 | 76 | def reset_classifier(self, num_classes, global_pool=''): 77 | self.num_classes = num_classes 78 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 79 | 80 | def forward_features(self, x): 81 | x = self.patch_embed(x) 82 | 83 | x = x + self.pos_embed.type_as(x).to(x.device).clone().detach() 84 | 85 | for i, blk in enumerate(self.blocks): 86 | x = blk(x) 87 | 88 | return x 89 | 90 | def forward(self, x): 91 | x = self.forward_features(x) 92 | return x 93 | 94 | 95 | class PretrainVideoTransformerTeacher(nn.Module): 96 | """ Vision Transformer with support for patch or hybrid CNN input stage 97 | """ 98 | 99 | def __init__(self, 100 | img_size=224, 101 | patch_size=16, 102 | encoder_in_chans=3, 103 | encoder_num_classes=0, 104 | encoder_embed_dim=768, 105 | encoder_depth=12, 106 | encoder_num_heads=12, 107 | mlp_ratio=4., 108 | qkv_bias=False, 109 | qk_scale=None, 110 | drop_rate=0., 111 | attn_drop_rate=0., 112 | drop_path_rate=0., 113 | norm_layer=nn.LayerNorm, 114 | init_values=0., 115 | tubelet_size=2, 116 | ): 117 | super().__init__() 118 | self.encoder = PretrainVisionTransformerEncoder( 119 | img_size=img_size, 120 | patch_size=patch_size, 121 | in_chans=encoder_in_chans, 122 | num_classes=encoder_num_classes, 123 | embed_dim=encoder_embed_dim, 124 | depth=encoder_depth, 125 | num_heads=encoder_num_heads, 126 | mlp_ratio=mlp_ratio, 127 | qkv_bias=qkv_bias, 128 | qk_scale=qk_scale, 129 | drop_rate=drop_rate, 130 | attn_drop_rate=attn_drop_rate, 131 | drop_path_rate=drop_path_rate, 132 | norm_layer=norm_layer, 133 | init_values=init_values, 134 | tubelet_size=tubelet_size, 135 | ) 136 | 137 | def _init_weights(self, m): 138 | if isinstance(m, nn.Linear): 139 | nn.init.xavier_uniform_(m.weight) 140 | if isinstance(m, nn.Linear) and m.bias is not None: 141 | nn.init.constant_(m.bias, 0) 142 | elif isinstance(m, nn.LayerNorm): 143 | nn.init.constant_(m.bias, 0) 144 | nn.init.constant_(m.weight, 1.0) 145 | 146 | def get_num_layers(self): 147 | return len(self.blocks) 148 | 149 | @torch.jit.ignore 150 | def no_weight_decay(self): 151 | return {'pos_embed', 'cls_token'} 152 | 153 | def forward(self, x): 154 | x = self.encoder(x) 155 | return x 156 | 157 | 158 | @register_model 159 | def pretrain_videomae_teacher_base_patch16_224(pretrained=False, **kwargs): 160 | model = PretrainVideoTransformerTeacher( 161 | patch_size=16, 162 | encoder_embed_dim=768, 163 | encoder_depth=12, 164 | encoder_num_heads=12, 165 | encoder_num_classes=0, 166 | mlp_ratio=4, 167 | qkv_bias=True, 168 | norm_layer=partial(nn.LayerNorm, 
eps=1e-6), 169 | **kwargs) 170 | model.default_cfg = _cfg() 171 | if pretrained: 172 | checkpoint = torch.load( 173 | kwargs["init_ckpt"], map_location="cpu" 174 | ) 175 | model.load_state_dict(checkpoint["model"]) 176 | return model 177 | 178 | 179 | @register_model 180 | def pretrain_videomae_teacher_large_patch16_224(pretrained=False, **kwargs): 181 | model = PretrainVideoTransformerTeacher( 182 | patch_size=16, 183 | encoder_embed_dim=1024, 184 | encoder_depth=24, 185 | encoder_num_heads=16, 186 | encoder_num_classes=0, 187 | mlp_ratio=4, 188 | qkv_bias=True, 189 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 190 | **kwargs) 191 | model.default_cfg = _cfg() 192 | if pretrained: 193 | checkpoint = torch.load( 194 | kwargs["init_ckpt"], map_location="cpu" 195 | ) 196 | model.load_state_dict(checkpoint["model"]) 197 | return model 198 | 199 | 200 | @register_model 201 | def pretrain_videomae_teacher_huge_patch16_224(pretrained=False, **kwargs): 202 | model = PretrainVideoTransformerTeacher( 203 | patch_size=16, 204 | encoder_embed_dim=1280, 205 | encoder_depth=32, 206 | encoder_num_heads=16, 207 | encoder_num_classes=0, 208 | mlp_ratio=4, 209 | qkv_bias=True, 210 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 211 | **kwargs) 212 | model.default_cfg = _cfg() 213 | if pretrained: 214 | checkpoint = torch.load( 215 | kwargs["init_ckpt"], map_location="cpu" 216 | ) 217 | model.load_state_dict(checkpoint["model"]) 218 | return model 219 | -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | 4 | from timm.optim.adafactor import Adafactor 5 | from timm.optim.adahessian import Adahessian 6 | from timm.optim.adamp import AdamP 7 | from timm.optim.lookahead import Lookahead 8 | from timm.optim.nadam import Nadam 9 | from timm.optim.novograd import NovoGrad 10 | from timm.optim.nvnovograd import NvNovoGrad 11 | from timm.optim.radam import RAdam 12 | from timm.optim.rmsprop_tf import RMSpropTF 13 | from timm.optim.sgdp import SGDP 14 | 15 | import json 16 | 17 | try: 18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 19 | has_apex = True 20 | except ImportError: 21 | has_apex = False 22 | 23 | 24 | class LARS(torch.optim.Optimizer): 25 | """ 26 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
27 | """ 28 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 29 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 30 | super().__init__(params, defaults) 31 | 32 | @torch.no_grad() 33 | def step(self): 34 | for g in self.param_groups: 35 | for p in g['params']: 36 | dp = p.grad 37 | 38 | if dp is None: 39 | continue 40 | 41 | if p.ndim > 1: # if not normalization gamma/beta or bias 42 | dp = dp.add(p, alpha=g['weight_decay']) 43 | param_norm = torch.norm(p) 44 | update_norm = torch.norm(dp) 45 | one = torch.ones_like(param_norm) 46 | q = torch.where( 47 | param_norm > 0., 48 | torch.where(update_norm > 0, (g['trust_coefficient'] * param_norm / update_norm), one), 49 | one 50 | ) 51 | dp = dp.mul(q) 52 | 53 | param_state = self.state[p] 54 | if 'mu' not in param_state: 55 | param_state['mu'] = torch.zeros_like(p) 56 | mu = param_state['mu'] 57 | mu.mul_(g['momentum']).add_(dp) 58 | p.add_(mu, alpha=-g['lr']) 59 | 60 | 61 | def get_num_layer_for_vit(var_name, num_max_layer): 62 | if var_name in ("cls_token", "mask_token", "pos_embed"): 63 | return 0 64 | elif var_name.startswith("patch_embed"): 65 | return 0 66 | elif var_name.startswith("rel_pos_bias"): 67 | return num_max_layer - 1 68 | elif var_name.startswith("blocks"): 69 | layer_id = int(var_name.split('.')[1]) 70 | return layer_id + 1 71 | else: 72 | return num_max_layer - 1 73 | 74 | 75 | class LayerDecayValueAssigner(object): 76 | def __init__(self, values): 77 | self.values = values 78 | 79 | def get_scale(self, layer_id): 80 | return self.values[layer_id] 81 | 82 | def get_layer_id(self, var_name): 83 | return get_num_layer_for_vit(var_name, len(self.values)) 84 | 85 | 86 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 87 | parameter_group_names = {} 88 | parameter_group_vars = {} 89 | 90 | for name, param in model.named_parameters(): 91 | if not param.requires_grad: 92 | continue # frozen weights 93 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 94 | group_name = "no_decay" 95 | this_weight_decay = 0. 96 | else: 97 | group_name = "decay" 98 | this_weight_decay = weight_decay 99 | if get_num_layer is not None: 100 | layer_id = get_num_layer(name) 101 | group_name = "layer_%d_%s" % (layer_id, group_name) 102 | else: 103 | layer_id = None 104 | 105 | if group_name not in parameter_group_names: 106 | if get_layer_scale is not None: 107 | scale = get_layer_scale(layer_id) 108 | else: 109 | scale = 1. 
110 | 111 | parameter_group_names[group_name] = { 112 | "weight_decay": this_weight_decay, 113 | "params": [], 114 | "lr_scale": scale 115 | } 116 | parameter_group_vars[group_name] = { 117 | "weight_decay": this_weight_decay, 118 | "params": [], 119 | "lr_scale": scale 120 | } 121 | 122 | parameter_group_vars[group_name]["params"].append(param) 123 | parameter_group_names[group_name]["params"].append(name) 124 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 125 | return list(parameter_group_vars.values()) 126 | 127 | 128 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 129 | opt_lower = args.opt.lower() 130 | weight_decay = args.weight_decay 131 | if filter_bias_and_bn: 132 | skip = {} 133 | if skip_list is not None: 134 | skip = skip_list 135 | elif hasattr(model, 'no_weight_decay'): 136 | skip = model.no_weight_decay() 137 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 138 | weight_decay = 0. 139 | else: 140 | parameters = model.parameters() 141 | 142 | if 'fused' in opt_lower: 143 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 144 | 145 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 146 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 147 | opt_args['eps'] = args.opt_eps 148 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 149 | opt_args['betas'] = args.opt_betas 150 | 151 | print("optimizer settings:", opt_args) 152 | 153 | opt_split = opt_lower.split('_') 154 | opt_lower = opt_split[-1] 155 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 156 | opt_args.pop('eps', None) 157 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 158 | elif opt_lower == 'momentum': 159 | opt_args.pop('eps', None) 160 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 161 | elif opt_lower == 'adam': 162 | optimizer = optim.Adam(parameters, **opt_args) 163 | elif opt_lower == 'adamw': 164 | optimizer = optim.AdamW(parameters, **opt_args) 165 | elif opt_lower == 'nadam': 166 | optimizer = Nadam(parameters, **opt_args) 167 | elif opt_lower == 'radam': 168 | optimizer = RAdam(parameters, **opt_args) 169 | elif opt_lower == 'adamp': 170 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 171 | elif opt_lower == 'sgdp': 172 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 173 | elif opt_lower == 'adadelta': 174 | optimizer = optim.Adadelta(parameters, **opt_args) 175 | elif opt_lower == 'adafactor': 176 | if not args.lr: 177 | opt_args['lr'] = None 178 | optimizer = Adafactor(parameters, **opt_args) 179 | elif opt_lower == 'adahessian': 180 | optimizer = Adahessian(parameters, **opt_args) 181 | elif opt_lower == 'rmsprop': 182 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 183 | elif opt_lower == 'rmsproptf': 184 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 185 | elif opt_lower == 'novograd': 186 | optimizer = NovoGrad(parameters, **opt_args) 187 | elif opt_lower == 'nvnovograd': 188 | optimizer = NvNovoGrad(parameters, **opt_args) 189 | elif opt_lower == 'fusedsgd': 190 | opt_args.pop('eps', None) 191 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 192 | elif opt_lower == 'fusedmomentum': 193 | opt_args.pop('eps', None) 194 | optimizer = 
FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 195 | elif opt_lower == 'fusedadam': 196 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 197 | elif opt_lower == 'fusedadamw': 198 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 199 | elif opt_lower == 'fusedlamb': 200 | optimizer = FusedLAMB(parameters, **opt_args) 201 | elif opt_lower == 'fusednovograd': 202 | opt_args.setdefault('betas', (0.95, 0.98)) 203 | optimizer = FusedNovoGrad(parameters, **opt_args) 204 | elif opt_lower == 'lars': 205 | opt_args.pop('eps', None) 206 | optimizer = LARS(parameters, **opt_args) 207 | else: 208 | assert False and "Invalid optimizer" 209 | raise ValueError 210 | 211 | if len(opt_split) > 1: 212 | if opt_split[0] == 'lookahead': 213 | optimizer = Lookahead(optimizer) 214 | 215 | return optimizer 216 | -------------------------------------------------------------------------------- /engine_locomotion_prediction_for_finetuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import sys 5 | from typing import Iterable, Optional 6 | import torch 7 | from timm.utils import accuracy, ModelEma 8 | import utils 9 | from scipy.special import softmax 10 | from einops import rearrange 11 | from torch.utils.data._utils.collate import default_collate 12 | import torch.nn.functional as F 13 | import pandas as pd 14 | from matplotlib import pyplot as plt 15 | import os 16 | 17 | parent = os.path.dirname(os.path.abspath(__file__)) 18 | parent_parent = os.path.join(parent, '../') 19 | sys.path.append(os.path.dirname(parent_parent)) 20 | 21 | from locomotion_prediction.locomotion_prediction_utils import * 22 | 23 | 24 | def train_class_batch(model, samples, dP_tensor, criterion, args): 25 | outputs = model(samples) 26 | dP_pred = outputs.unflatten(1, (args.act_pose_prediction, args.num_pred)) 27 | 28 | loss = criterion(dP_pred, dP_tensor) 29 | 30 | return loss, dP_pred 31 | 32 | 33 | def get_loss_scale_for_deepspeed(model): 34 | optimizer = model.optimizer 35 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 36 | 37 | 38 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 39 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 40 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 41 | model_ema: Optional[ModelEma] = None, mixup_fn=None, log_writer=None, 42 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 43 | num_training_steps_per_epoch=None, update_freq=None, args=None): 44 | model.train(True) 45 | metric_logger = utils.MetricLogger(delimiter=" ") 46 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 47 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 48 | header = 'Epoch: [{}]'.format(epoch) 49 | print_freq = 10 50 | 51 | if loss_scaler is None: 52 | model.zero_grad() 53 | model.micro_steps = 0 54 | else: 55 | optimizer.zero_grad() 56 | 57 | for data_iter_step, (samples, dP_tensor, start_idx, end_idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 58 | step = data_iter_step // update_freq 59 | if step >= num_training_steps_per_epoch: 60 | continue 61 | it = start_steps + step # global training iteration 62 | # Update LR & WD for the first acc 63 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 64 | 
for i, param_group in enumerate(optimizer.param_groups): 65 | if lr_schedule_values is not None and 'lr_scale' in param_group: 66 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 67 | if wd_schedule_values is not None and 'weight_decay' in param_group and param_group["weight_decay"] > 0: 68 | param_group["weight_decay"] = wd_schedule_values[it] 69 | 70 | samples = samples.to(device, non_blocking=True) 71 | samples = samples.permute(0, 2, 1, 3, 4) 72 | dP_tensor = dP_tensor.to(device, non_blocking=True) 73 | 74 | if mixup_fn is not None: 75 | samples, targets = mixup_fn(samples, targets) 76 | 77 | if loss_scaler is None: 78 | samples = samples.half() 79 | loss, dP_pred = train_class_batch( 80 | model, samples, dP_tensor, criterion, args) 81 | else: 82 | with torch.cuda.amp.autocast(): 83 | loss, dP_pred = train_class_batch( 84 | model, samples, dP_tensor, criterion, args) 85 | 86 | loss_value = loss.item() 87 | 88 | if not math.isfinite(loss_value): 89 | print("Loss is {}, stopping training".format(loss_value)) 90 | sys.exit(1) 91 | 92 | if loss_scaler is None: 93 | loss /= update_freq 94 | model.backward(loss) 95 | model.step() 96 | 97 | if (data_iter_step + 1) % update_freq == 0: 98 | # model.zero_grad() 99 | # Deepspeed will call step() & model.zero_grad() automatic 100 | if model_ema is not None: 101 | model_ema.update(model) 102 | grad_norm = None 103 | loss_scale_value = get_loss_scale_for_deepspeed(model) 104 | else: 105 | # this attribute is added by timm on one optimizer (adahessian) 106 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 107 | loss /= update_freq 108 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 109 | parameters=model.parameters(), create_graph=is_second_order, 110 | update_grad=(data_iter_step + 1) % update_freq == 0) 111 | if (data_iter_step + 1) % update_freq == 0: 112 | optimizer.zero_grad() 113 | if model_ema is not None: 114 | model_ema.update(model) 115 | loss_scale_value = loss_scaler.state_dict()["scale"] 116 | 117 | torch.cuda.synchronize() 118 | 119 | metric_logger.update(loss=loss_value) 120 | metric_logger.update(loss_scale=loss_scale_value) 121 | min_lr = 10. 122 | max_lr = 0. 
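# Each optimizer param group can carry its own lr (the lr_scale from layer-wise lr decay
# is applied when the schedules are written above), so the loop below scans all groups
# to record the smallest and largest lr in effect this step for logging.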
123 | for group in optimizer.param_groups: 124 | min_lr = min(min_lr, group["lr"]) 125 | max_lr = max(max_lr, group["lr"]) 126 | 127 | metric_logger.update(lr=max_lr) 128 | metric_logger.update(min_lr=min_lr) 129 | weight_decay_value = None 130 | for group in optimizer.param_groups: 131 | if group["weight_decay"] > 0: 132 | weight_decay_value = group["weight_decay"] 133 | metric_logger.update(weight_decay=weight_decay_value) 134 | metric_logger.update(grad_norm=grad_norm) 135 | 136 | if log_writer is not None: 137 | log_writer.update(loss=loss_value, head="loss") 138 | log_writer.update(loss_scale=loss_scale_value, head="opt") 139 | log_writer.update(lr=max_lr, head="opt") 140 | log_writer.update(min_lr=min_lr, head="opt") 141 | log_writer.update(weight_decay=weight_decay_value, head="opt") 142 | log_writer.update(grad_norm=grad_norm, head="opt") 143 | 144 | log_writer.set_step() 145 | 146 | # gather the stats from all processes 147 | metric_logger.synchronize_between_processes() 148 | print("Averaged stats:", metric_logger) 149 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 150 | 151 | 152 | @torch.no_grad() 153 | def validation_one_epoch(dataset_val, model, criterion, device, args): 154 | num_tasks = utils.get_world_size() 155 | global_rank = utils.get_rank() 156 | 157 | if args.dist_eval: 158 | if len(dataset_val) % num_tasks != 0: 159 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 160 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 161 | 'equal num of samples per-process.') 162 | sampler_val = torch.utils.data.DistributedSampler( 163 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 164 | else: 165 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 166 | 167 | data_loader = torch.utils.data.DataLoader( 168 | dataset_val, sampler=sampler_val, 169 | batch_size=1, 170 | num_workers=args.num_workers, 171 | pin_memory=args.pin_mem, 172 | drop_last=True 173 | ) 174 | 175 | metric_logger = utils.MetricLogger(delimiter=" ") 176 | header = 'Test:' 177 | 178 | # switch to evaluation mode 179 | model.eval() 180 | 181 | for (video_path, trajectory_path, start_idx, end_idx) in metric_logger.log_every(data_loader, 10, header): 182 | try: 183 | video_path, trajectory_path, start_idx, end_idx = video_path[0], trajectory_path[0], start_idx[0].item(), end_idx[0].item() 184 | with torch.cuda.amp.autocast(): 185 | loss, ate, rpe_trans, rpe_rot = evaluate_segment(model, dataset_val, device, video_path, trajectory_path, criterion, start_idx, end_idx, args, save_plot=True) 186 | metric_logger.meters['loss'].update(loss.item(), n=1) 187 | metric_logger.meters['ate'].update(ate, n=1) 188 | metric_logger.meters['rpe_trans'].update(rpe_trans, n=1) 189 | metric_logger.meters['rpe_rot'].update(rpe_rot, n=1) 190 | except Exception as e: 191 | print('Evaluation Error: ', e) 192 | 193 | # gather the stats from all processes 194 | metric_logger.synchronize_between_processes() 195 | print('* loss {losses.global_avg:.5f}' 196 | .format(losses=metric_logger.loss)) 197 | print('* ate {ate.global_avg:.5f}' 198 | .format(ate=metric_logger.ate)) 199 | print('* rpe_trans {rpe_trans.global_avg:.5f}' 200 | .format(rpe_trans=metric_logger.rpe_trans)) 201 | print('* rpe_rot {rpe_rot.global_avg:.5f}' 202 | .format(rpe_rot=metric_logger.rpe_rot)) 203 | 204 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 205 | 
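# Minimal shape sketch for train_class_batch above: the head output comes out flat and is
# unflattened into (act_pose_prediction, num_pred) before the loss is applied. The sizes
# below are hypothetical, and nn.MSELoss only stands in for whatever criterion the caller
# actually constructs; this is an illustration, not part of the training pipeline.
if __name__ == "__main__":
    B, act_pose_prediction, num_pred = 4, 8, 3                        # hypothetical sizes
    outputs = torch.randn(B, act_pose_prediction * num_pred)          # fake flat model output
    dP_pred = outputs.unflatten(1, (act_pose_prediction, num_pred))   # (B, T_pred, num_pred)
    dP_tensor = torch.randn(B, act_pose_prediction, num_pred)         # fake relative-pose targets
    criterion = torch.nn.MSELoss()                                    # assumed stand-in criterion
    print('demo loss:', criterion(dP_pred, dP_tensor).item())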
-------------------------------------------------------------------------------- /object_interaction/object_interaction_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | objects_dict = { 5 | 'person': 0, 6 | 'ball': 1, 7 | 'bench': 2, 8 | 'bird': 3, 9 | 'dog': 4, 10 | 'cat': 5, 11 | 'other animal': 6, 12 | 'toy': 7, 13 | 'door': 8, 14 | 'floor': 9, 15 | 'food': 10, 16 | 'plant': 11, 17 | 'filament': 12, 18 | 'plastic': 13, 19 | 'water': 14, 20 | 'vehicle': 15, 21 | 'other': 16, 22 | } 23 | 24 | # Decoding Helper Functions 25 | def spatial_sampling( 26 | frames, 27 | spatial_idx=-1, 28 | min_scale=256, 29 | max_scale=320, 30 | crop_size=224, 31 | random_horizontal_flip=True, 32 | inverse_uniform_sampling=False, 33 | aspect_ratio=None, 34 | scale=None, 35 | motion_shift=False, 36 | ): 37 | """ 38 | Perform spatial sampling on the given video frames. If spatial_idx is 39 | -1, perform random scale, random crop, and random flip on the given 40 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 41 | with the given spatial_idx. 42 | Args: 43 | frames (tensor): frames of images sampled from the video. The 44 | dimension is `num frames` x `height` x `width` x `channel`. 45 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 46 | or 2, perform left, center, right crop if width is larger than 47 | height, and perform top, center, buttom crop if height is larger 48 | than width. 49 | min_scale (int): the minimal size of scaling. 50 | max_scale (int): the maximal size of scaling. 51 | crop_size (int): the size of height and width used to crop the 52 | frames. 53 | inverse_uniform_sampling (bool): if True, sample uniformly in 54 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 55 | scale. If False, take a uniform sample from [min_scale, 56 | max_scale]. 57 | aspect_ratio (list): Aspect ratio range for resizing. 58 | scale (list): Scale range for resizing. 59 | motion_shift (bool): Whether to apply motion shift for resizing. 60 | Returns: 61 | frames (tensor): spatially sampled frames. 62 | """ 63 | 64 | assert spatial_idx in [-1, 0, 1, 2] 65 | if spatial_idx == -1: 66 | if aspect_ratio is None and scale is None: 67 | frames = transform.random_short_side_scale_jitter( 68 | images=frames, 69 | min_size=min_scale, 70 | max_size=max_scale, 71 | inverse_uniform_sampling=inverse_uniform_sampling, 72 | ) 73 | frames = transform.random_crop(frames, crop_size) 74 | else: 75 | transform_func = ( 76 | transform.random_resized_crop_with_shift 77 | if motion_shift 78 | else transform.random_resized_crop 79 | ) 80 | frames = transform_func( 81 | images=frames, 82 | target_height=crop_size, 83 | target_width=crop_size, 84 | scale=scale, 85 | ratio=aspect_ratio, 86 | ) 87 | if random_horizontal_flip: 88 | frames = transform.horizontal_flip(0.5, frames) 89 | else: 90 | # The testing is deterministic and no jitter should be performed. 91 | # min_scale, max_scale, and crop_size are expect to be the same. 92 | assert len({min_scale, max_scale}) == 1 93 | frames = transform.random_short_side_scale_jitter(frames, min_scale, max_scale) 94 | frames = transform.uniform_crop(frames, crop_size, spatial_idx) 95 | return frames 96 | 97 | def temporal_sampling(frames, start_idx, end_idx, num_samples): 98 | """ 99 | Given the start and end frame index, sample num_samples frames between 100 | the start and end with equal interval. 
101 | Args: 102 | frames (tensor): a tensor of video frames, dimension is 103 | `num video frames` x `channel` x `height` x `width`. 104 | start_idx (int): the index of the start frame. 105 | end_idx (int): the index of the end frame. 106 | num_samples (int): number of frames to sample. 107 | Returns: 108 | frames (tersor): a tensor of temporal sampled video frames, dimension is 109 | `num clip frames` x `channel` x `height` x `width`. 110 | """ 111 | index = torch.linspace(start_idx, end_idx, num_samples) 112 | index = torch.clamp(index, 0, frames.shape[0] - 1).long() 113 | new_frames = torch.index_select(frames, 0, index) 114 | return new_frames 115 | 116 | def get_start_end_idx(video_size, clip_size, clip_idx, num_clips, use_offset=False): 117 | """ 118 | Sample a clip of size clip_size from a video of size video_size and 119 | return the indices of the first and last frame of the clip. If clip_idx is 120 | -1, the clip is randomly sampled, otherwise uniformly split the video to 121 | num_clips clips, and select the start and end index of clip_idx-th video 122 | clip. 123 | Args: 124 | video_size (int): number of overall frames. 125 | clip_size (int): size of the clip to sample from the frames. 126 | clip_idx (int): if clip_idx is -1, perform random jitter sampling. If 127 | clip_idx is larger than -1, uniformly split the video to num_clips 128 | clips, and select the start and end index of the clip_idx-th video 129 | clip. 130 | num_clips (int): overall number of clips to uniformly sample from the 131 | given video for testing. 132 | Returns: 133 | start_idx (int): the start frame index. 134 | end_idx (int): the end frame index. 135 | """ 136 | delta = max(video_size - clip_size, 0) 137 | if clip_idx == -1: 138 | # Random temporal sampling. 139 | start_idx = random.uniform(0, delta) 140 | else: 141 | if use_offset: 142 | if num_clips == 1: 143 | # Take the center clip if num_clips is 1. 144 | start_idx = math.floor(delta / 2) 145 | else: 146 | # Uniformly sample the clip with the given index. 147 | start_idx = clip_idx * math.floor(delta / (num_clips - 1)) 148 | else: 149 | # Uniformly sample the clip with the given index. 
150 | start_idx = delta * clip_idx / num_clips 151 | end_idx = start_idx + clip_size - 1 152 | return start_idx, end_idx 153 | 154 | 155 | # Helper Functions 156 | def get_video_path(path_to_data_dir, animal, video_id, segment_id): 157 | video_name = 'edited_{video_id}_segment_{segment_id}.mp4'.format(video_id=video_id, segment_id=segment_id.zfill(6)) 158 | video_path = os.path.join(path_to_data_dir, animal, video_name) 159 | return video_path 160 | 161 | def process_time(time): 162 | # Assumes in format ##:##:## or ##:## or NONE to get time in seconds 163 | if time == 'NONE': 164 | return 0 165 | 166 | time_split = time.split(':') 167 | if len(time_split) == 1: 168 | time_sec = int(time_split[0]) 169 | elif len(time_split) == 2: 170 | time_sec = int(time_split[0]) * 60 + int(time_split[1]) 171 | elif len(time_split) == 3: 172 | time_sec = int(time_split[0]) * 3600 + int(time_split[1]) * 60 + int(time_split[2]) 173 | else: 174 | raise ValueError('Invalid Time was {time} but expected in format ##:##:##, ##:##, or #'.format(time=time)) 175 | return time_sec 176 | 177 | def process_end_time(time, total_time): 178 | # Get end time for processing NONE times 179 | if time == 'NONE': 180 | return process_time(total_time) 181 | else: 182 | return process_time(time) 183 | 184 | def get_none_segments(start_time, end_time, total_time, interacting_object): 185 | min_beg_end_length = 4 + 2 # the actual clip must be at least 4 seconds, 2 seconds from the next clip 186 | min_mid_length = 2 + 4 + 2 # 2 seconds after the last clip, the actual clip must be at least 4 seconds, 2 clips before next clip 187 | 188 | new_start_time = [] 189 | new_end_time = [] 190 | new_interacting_object = [] 191 | 192 | # Checking middle portions 193 | for i in range(len(start_time) - 1): 194 | idx = i+1 195 | if start_time[idx] - end_time[i] >= min_mid_length: 196 | new_start_time.append(end_time[i]+2) 197 | new_end_time.append(start_time[idx]-2) 198 | new_interacting_object.append('NONE') 199 | 200 | # Checking beginning 201 | if start_time[0] >= min_beg_end_length: 202 | new_start_time.append(0) 203 | new_end_time.append(start_time[0]-2) 204 | new_interacting_object.append('NONE') 205 | 206 | # Checking ending 207 | if total_time - end_time[-1] >= min_beg_end_length: 208 | new_start_time.append(end_time[-1]+2) 209 | new_end_time.append(total_time) 210 | new_interacting_object.append('NONE') 211 | 212 | start_time.extend(new_start_time) 213 | end_time.extend(new_end_time) 214 | interacting_object.extend(new_interacting_object) 215 | return start_time, end_time, interacting_object 216 | 217 | def get_label(interacting_object): 218 | if interacting_object == 'NONE': 219 | return -1 220 | elif interacting_object in objects_dict: 221 | object_label_idx = objects_dict[interacting_object] 222 | else: 223 | raise ValueError("Invalid Object Type: {curr_object}".format(curr_object=interacting_object)) 224 | 225 | return object_label_idx 226 | -------------------------------------------------------------------------------- /locomotion_prediction/locomotion_prediction_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
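# The trajectory files consumed in this module are in TUM format (one
# "timestamp tx ty tz qx qy qz qw" line per frame); evo's file_interface parses them,
# and get_poses() below reorders the quaternion from evo's wxyz convention to xyzw.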
3 | 4 | import os 5 | import random 6 | import math 7 | 8 | import torch 9 | import numpy as np 10 | import torch.utils.data 11 | from iopath.common.file_io import g_pathmgr as pathmgr 12 | 13 | import sys 14 | from decoder.utils import decode_ffmpeg 15 | from torchvision import transforms 16 | from torch.nn.functional import normalize 17 | 18 | ### evo evaluation library ### 19 | import evo 20 | from evo.core.trajectory import PoseTrajectory3D 21 | from evo.tools import file_interface 22 | from evo.core import sync, metrics 23 | import evo.main_ape as main_ape 24 | import evo.main_rpe as main_rpe 25 | from evo.core.metrics import PoseRelation 26 | from evo.tools import plot 27 | from dpvo.plot_utils import plot_trajectory, save_trajectory_tum_format 28 | 29 | import sys 30 | parent = os.path.dirname(os.path.abspath(__file__)) 31 | parent_parent = os.path.join(parent, '../') 32 | sys.path.append(os.path.dirname(parent_parent)) 33 | 34 | from dpvo.plot_utils import * 35 | from pathlib import Path 36 | 37 | 38 | def get_poses(trajectory_path, start_idx, end_idx): 39 | """ 40 | Gets the camera poses for the frames start_idx to end_idx. 41 | Args: 42 | trajectory_path (string): the path to the file containing the trajectory 43 | in TUM format. 44 | start_idx (int): the start index to get poses 45 | end_idx (int): the end index to get poses 46 | Returns: 47 | traj_poses (tensor): The poses (end_idx - start_idx) x 7 48 | """ 49 | traj = file_interface.read_tum_trajectory_file(trajectory_path) 50 | traj_poses = np.concatenate((traj.positions_xyz, traj.orientations_quat_wxyz[:, [1, 2, 3, 0]]), axis=1) 51 | traj_poses = torch.from_numpy(traj_poses) 52 | 53 | return traj_poses[start_idx:end_idx] 54 | 55 | 56 | def relative_poses(trajectory_path, start_idx, end_idx, scale_invariance, pose_skip): 57 | """ 58 | Gets dP, the relative pose to the prior frame, between camera poses for the 59 | frames start_idx to end_idx. 60 | Args: 61 | trajectory_path (string): the path to the file containing the trajectory 62 | in TUM format. 63 | start_idx (int): the start index to get relative poses 64 | end_idx (int): the end index to get relative poses 65 | scale_invariance (str): scale invariance type ('dir' or None) 66 | pose_skip (int): stride of pose prediction 67 | Returns: 68 | dP_tensor (tensor): the dPs . The dimension 69 | is (end_idx - start_idx) x 7. 
70 | scale (tensor): if scale_invariance=='dir' return the scale of the 71 | poses 72 | """ 73 | traj = file_interface.read_tum_trajectory_file(trajectory_path) 74 | translation = traj.positions_xyz[1:] - traj.positions_xyz[:-1] 75 | 76 | dP_tensor = torch.from_numpy(translation) 77 | zero_tensor = torch.tensor([[0, 0, 0]]) # Hacky adding 0'th relative pose for indexing 78 | dP_tensor = torch.cat([zero_tensor, dP_tensor], dim=0) 79 | 80 | dP_tensor = dP_tensor[start_idx:end_idx] 81 | dP_tensor = dP_tensor.unflatten(0, (-1, pose_skip)) 82 | dP_tensor = torch.sum(dP_tensor, dim=1) 83 | 84 | if scale_invariance == 'dir': 85 | scale = torch.norm(dP_tensor, p=2, dim=1, keepdim=True) 86 | dP_tensor = dP_tensor / (scale + 1e-10) 87 | return dP_tensor, scale 88 | 89 | return dP_tensor 90 | 91 | 92 | def make_traj_from_tensor(traj_tensor, start_idx, end_idx, num_condition_frames, num_pose_prediction, act_pose_prediction, pose_skip): 93 | tstamps = np.arange(0, end_idx, 1, dtype=np.float64) 94 | tstamps = tstamps.reshape((-1, num_condition_frames + num_pose_prediction))[:, num_condition_frames+(pose_skip-1)::pose_skip] 95 | tstamps = tstamps.flatten() 96 | 97 | tstamps_idx = np.arange(0, len(traj_tensor), 1, dtype=np.float64) 98 | tstamps_idx = tstamps_idx.reshape((-1, num_condition_frames + act_pose_prediction))[:, num_condition_frames:] 99 | tstamps_idx = tstamps_idx.flatten() 100 | 101 | traj = PoseTrajectory3D(positions_xyz=traj_tensor[tstamps_idx.astype(int),:3], orientations_quat_wxyz=traj_tensor[tstamps_idx.astype(int),3:][:, [3, 0, 1, 2]], timestamps=tstamps) 102 | 103 | return traj 104 | 105 | 106 | def eval_metrics(traj_ref, traj_pred): 107 | traj_ref, traj_pred = sync.associate_trajectories(traj_ref, traj_pred) 108 | 109 | result = main_ape.ape(traj_ref, traj_pred, est_name='traj', 110 | pose_relation=PoseRelation.translation_part, align=True, correct_scale=True) 111 | ate = result.stats['rmse'] 112 | 113 | result = main_rpe.rpe(traj_ref, traj_pred, est_name='traj', 114 | pose_relation=PoseRelation.rotation_angle_deg, align=True, correct_scale=True, 115 | delta=1.0, delta_unit=metrics.Unit.frames, rel_delta_tol=0.1) 116 | rpe_rot = result.stats['rmse'] 117 | 118 | result = main_rpe.rpe(traj_ref, traj_pred, est_name='traj', 119 | pose_relation=PoseRelation.translation_part, align=True, correct_scale=True, 120 | delta=1.0, delta_unit=metrics.Unit.frames, rel_delta_tol=0.1) 121 | rpe_trans = result.stats['rmse'] 122 | 123 | return ate, rpe_trans, rpe_rot 124 | 125 | 126 | def evaluate_segment(model, dataset_val, device, video_path, trajectory_path, criterion, start_idx, end_idx, args, save_plot=False, image_model=False): 127 | # Assume batch of 1 128 | loss = 0 129 | total_idx = end_idx - start_idx 130 | 131 | frames_per_sub_clip = args.num_condition_frames + args.num_pose_prediction 132 | act_frames_per_sub_clip = args.num_condition_frames + args.act_pose_prediction 133 | num_sub_clips = total_idx // frames_per_sub_clip 134 | num_sub_batch = math.ceil(num_sub_clips / args.validation_batch_size) 135 | 136 | all_cond_poses = get_poses(trajectory_path, start_idx, end_idx).cpu() 137 | all_cond_poses = all_cond_poses.unflatten(0, (-1, frames_per_sub_clip)) 138 | all_cond_poses = torch.cat((all_cond_poses[:, :args.num_condition_frames ,:], all_cond_poses[:, args.num_condition_frames+(args.pose_skip-1)::args.pose_skip]), 1) 139 | all_cond_poses = all_cond_poses.flatten(0, 1) 140 | traj_ref = make_traj_from_tensor(all_cond_poses, start_idx, end_idx, args.num_condition_frames, 
args.num_pose_prediction, args.act_pose_prediction, args.pose_skip) 141 | 142 | traj_pred_poses = all_cond_poses.clone() 143 | traj_pred_poses = traj_pred_poses.unflatten(0, (-1, act_frames_per_sub_clip)) 144 | traj_pred_poses[:, args.num_condition_frames:] = 0 145 | traj_pred_poses = traj_pred_poses.flatten(0, 1) 146 | traj_pred_poses[:, 3:] = all_cond_poses[:, 3:] 147 | traj_pred_poses = traj_pred_poses.to(device=device, dtype=torch.float32) 148 | 149 | for i in range(num_sub_batch): 150 | sub_start_idx = i * args.validation_batch_size * frames_per_sub_clip 151 | sub_end_idx = min((i + 1) * args.validation_batch_size * frames_per_sub_clip, total_idx) 152 | 153 | frames, dP_tensor = dataset_val._get_sub_clip(video_path, trajectory_path, sub_start_idx, sub_end_idx) 154 | 155 | frames = frames.to(device, non_blocking=True) 156 | dP_tensor = dP_tensor.to(device= device, dtype=torch.float32, non_blocking=True) 157 | 158 | frames = frames.unflatten(0, (-1, frames_per_sub_clip)) 159 | frames = frames[:, :args.num_condition_frames, :, :, :] 160 | frames = frames.permute(0, 2, 1, 3, 4) 161 | 162 | dP_tensor = dP_tensor.unflatten(0, (-1, frames_per_sub_clip)) 163 | dP_tensor = dP_tensor[:, args.num_condition_frames:, :].unflatten(1, (-1, args.pose_skip)) 164 | dP_tensor = torch.sum(dP_tensor, dim=2) 165 | 166 | if args.scale_invariance == 'dir': 167 | scale = torch.norm(dP_tensor, p=2, dim=2, keepdim=True) 168 | dP_tensor = dP_tensor / (scale + 1e-10) 169 | 170 | if image_model: 171 | frames = frames[:, :, -1, :, :] 172 | 173 | outputs = model(frames) 174 | dP_pred = outputs.unflatten(1, (args.act_pose_prediction, args.num_pred)) 175 | 176 | loss += criterion(dP_pred, dP_tensor) 177 | 178 | if args.scale_invariance == 'dir': 179 | dP_pred_unnorm = dP_pred * scale 180 | 181 | i_start_idx = i*args.validation_batch_size*act_frames_per_sub_clip 182 | for j in range(dP_pred_unnorm.shape[0]): 183 | j_start_idx = i_start_idx + j * act_frames_per_sub_clip + args.num_condition_frames 184 | 185 | start_xyz = traj_pred_poses[j_start_idx - 1, :3].unsqueeze(0) 186 | 187 | pred_xyz = torch.cat([start_xyz, dP_pred_unnorm[j]], dim=0) 188 | pred_xyz = torch.cumsum(pred_xyz, dim=0)[1:] 189 | traj_pred_poses[j_start_idx:j_start_idx + args.act_pose_prediction, :3] = pred_xyz 190 | 191 | traj_pred = make_traj_from_tensor(traj_pred_poses.detach().cpu(), start_idx, end_idx, args.num_condition_frames, args.num_pose_prediction, args.act_pose_prediction, args.pose_skip) 192 | ate, rpe_trans, rpe_rot = eval_metrics(traj_ref, traj_pred) 193 | 194 | segment_id = video_path.split('/')[-1][:-4] 195 | if save_plot and args.output_dir: 196 | segment_id = video_path.split('/')[-1][:-4] 197 | title = 'Reconstruction: {}'.format(segment_id) 198 | filename = os.path.join(args.output_dir, 'traj_viz', '{}_reconstruction.png'.format(segment_id)) 199 | plot_trajectory(traj_pred, gt_traj=traj_ref, title=title, filename=filename, align=False, correct_scale=True) 200 | 201 | return loss, ate, rpe_trans, rpe_rot 202 | -------------------------------------------------------------------------------- /modeling_teacher.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from functools import partial, reduce 8 | from operator import mul 9 | 10 | from modeling_finetune import _cfg 11 | from timm.models.registry import register_model 12 | from timm.models.layers import trunc_normal_ as 
__call_trunc_normal_ 13 | from timm.models.vision_transformer import PatchEmbed, Block 14 | from timm.models.layers import drop_path, to_2tuple 15 | 16 | 17 | def trunc_normal_(tensor, mean=0., std=1.): 18 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 19 | 20 | 21 | __all__ = [ 22 | 'mae_vit_base_patch16_dec512d8b', 23 | 'mae_vit_large_patch16_dec512d8b', 24 | 'mae_vit_huge_patch14_dec512d8b', 25 | ] 26 | 27 | 28 | # -------------------------------------------------------- 29 | # 2D sine-cosine position embedding 30 | # References: 31 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 32 | # MoCo v3: https://github.com/facebookresearch/moco-v3 33 | # -------------------------------------------------------- 34 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 35 | """ 36 | grid_size: int of the grid height and width 37 | return: 38 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 39 | """ 40 | grid_h = np.arange(grid_size, dtype=np.float32) 41 | grid_w = np.arange(grid_size, dtype=np.float32) 42 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 43 | grid = np.stack(grid, axis=0) 44 | 45 | grid = grid.reshape([2, 1, grid_size, grid_size]) 46 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 47 | if cls_token: 48 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 49 | return pos_embed 50 | 51 | 52 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 53 | assert embed_dim % 2 == 0 54 | 55 | # use half of dimensions to encode grid_h 56 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 57 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 58 | 59 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 60 | return emb 61 | 62 | 63 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 64 | """ 65 | embed_dim: output dimension for each position 66 | pos: a list of positions to be encoded: size (M,) 67 | out: (M, D) 68 | """ 69 | assert embed_dim % 2 == 0 70 | omega = np.arange(embed_dim // 2, dtype=np.float) 71 | omega /= embed_dim / 2. 72 | omega = 1. 
/ 10000**omega # (D/2,) 73 | 74 | pos = pos.reshape(-1) # (M,) 75 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 76 | 77 | emb_sin = np.sin(out) # (M, D/2) 78 | emb_cos = np.cos(out) # (M, D/2) 79 | 80 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 81 | return emb 82 | 83 | 84 | # -------------------------------------------------------- 85 | # Interpolate position embeddings for high-resolution 86 | # References: 87 | # DeiT: https://github.com/facebookresearch/deit 88 | # -------------------------------------------------------- 89 | def interpolate_pos_embed(model, checkpoint_model): 90 | if 'pos_embed' in checkpoint_model: 91 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 92 | embedding_size = pos_embed_checkpoint.shape[-1] 93 | num_patches = model.patch_embed.num_patches 94 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 95 | # height (== width) for the checkpoint position embedding 96 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 97 | # height (== width) for the new position embedding 98 | new_size = int(num_patches ** 0.5) 99 | # class_token and dist_token are kept unchanged 100 | if orig_size != new_size: 101 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 102 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 103 | # only the position tokens are interpolated 104 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 105 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 106 | pos_tokens = torch.nn.functional.interpolate( 107 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 108 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 109 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 110 | checkpoint_model['pos_embed'] = new_pos_embed 111 | 112 | 113 | # -------------------------------------------------------- 114 | # MAE encoder 115 | # References: 116 | # MAE: https://github.com/facebookresearch/mae 117 | # -------------------------------------------------------- 118 | class MaskedAutoencoderViT(nn.Module): 119 | """ Masked Autoencoder with VisionTransformer backbone 120 | """ 121 | 122 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 123 | embed_dim=1024, depth=24, num_heads=16, 124 | mlp_ratio=4., norm_layer=nn.LayerNorm, 125 | ): 126 | super().__init__() 127 | 128 | # -------------------------------------------------------------------------- 129 | # MAE encoder specifics 130 | self.img_size = img_size 131 | self.patch_size = patch_size 132 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 133 | num_patches = self.patch_embed.num_patches 134 | 135 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 136 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), 137 | requires_grad=False) # fixed sin-cos embedding 138 | 139 | self.blocks = nn.ModuleList([ 140 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 141 | for i in range(depth)]) 142 | self.norm = norm_layer(embed_dim) 143 | # -------------------------------------------------------------------------- 144 | 145 | self.initialize_weights() 146 | 147 | def initialize_weights(self): 148 | # initialization 149 | # initialize (and freeze) pos_embed by sin-cos embedding 150 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5), 151 | cls_token=True) 152 
| self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 153 | 154 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 155 | w = self.patch_embed.proj.weight.data 156 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 157 | 158 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 159 | torch.nn.init.normal_(self.cls_token, std=.02) 160 | 161 | # initialize nn.Linear and nn.LayerNorm 162 | self.apply(self._init_weights) 163 | 164 | def _init_weights(self, m): 165 | if isinstance(m, nn.Linear): 166 | # we use xavier_uniform following official JAX ViT: 167 | torch.nn.init.xavier_uniform_(m.weight) 168 | if isinstance(m, nn.Linear) and m.bias is not None: 169 | nn.init.constant_(m.bias, 0) 170 | elif isinstance(m, nn.LayerNorm): 171 | nn.init.constant_(m.bias, 0) 172 | nn.init.constant_(m.weight, 1.0) 173 | 174 | def patchify(self, imgs): 175 | """ 176 | imgs: (N, 3, H, W) 177 | x: (N, L, patch_size**2 *3) 178 | """ 179 | p = self.patch_embed.patch_size[0] 180 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 181 | 182 | h = w = imgs.shape[2] // p 183 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 184 | x = torch.einsum('nchpwq->nhwpqc', x) 185 | x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) 186 | return x 187 | 188 | def forward_encoder(self, x): 189 | # embed patches 190 | x = self.patch_embed(x) 191 | 192 | # add pos embed w/o cls token 193 | x = x + self.pos_embed[:, 1:, :] 194 | 195 | B, _, C = x.shape 196 | 197 | # append cls token 198 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 199 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 200 | x = torch.cat((cls_tokens, x), dim=1) 201 | 202 | # apply Transformer blocks 203 | for i, blk in enumerate(self.blocks): 204 | x = blk(x) 205 | 206 | return x[:, 1:] 207 | 208 | def forward(self, imgs): 209 | latent = self.forward_encoder(imgs) 210 | return latent 211 | 212 | 213 | @register_model 214 | def mae_teacher_vit_base_patch16(pretrained=False, **kwargs): 215 | model = MaskedAutoencoderViT( 216 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 217 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 218 | return model 219 | 220 | 221 | @register_model 222 | def mae_teacher_vit_large_patch16(pretrained=False, **kwargs): 223 | model = MaskedAutoencoderViT( 224 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 225 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 226 | return model 227 | 228 | 229 | @register_model 230 | def mae_teacher_vit_huge_patch14(pretrained=False, **kwargs): 231 | model = MaskedAutoencoderViT( 232 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 233 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 234 | return model 235 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchvision import transforms 3 | from transforms import * 4 | import video_transforms 5 | from masking_generator import TubeMaskingGenerator, RandomMaskingGenerator 6 | from kinetics import VideoClsDataset, VideoDistillation 7 | from ssv2 import SSVideoClsDataset 8 | 9 | 10 | class DataAugmentationForVideoDistillation(object): 11 | def __init__(self, args, num_frames=None): 12 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN 13 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD 
14 | normalize = GroupNormalize(self.input_mean, self.input_std) 15 | self.train_augmentation = GroupMultiScaleTwoResizedCrop( 16 | args.input_size, args.teacher_input_size, [1, .875, .75, .66] 17 | ) 18 | self.transform = transforms.Compose([ 19 | Stack(roll=False), 20 | ToTorchFormatTensor(div=True), 21 | normalize, 22 | ]) 23 | window_size = args.window_size if num_frames is None else (num_frames // args.tubelet_size, args.window_size[1], args.window_size[2]) 24 | if args.mask_type == 'tube': 25 | self.masked_position_generator = TubeMaskingGenerator( 26 | window_size, args.mask_ratio 27 | ) 28 | elif args.mask_type == 'random': 29 | self.masked_position_generator = RandomMaskingGenerator( 30 | window_size, args.mask_ratio 31 | ) 32 | 33 | def __call__(self, images): 34 | process_data_0, process_data_1, labels = self.train_augmentation(images) 35 | process_data_0, _ = self.transform((process_data_0, labels)) 36 | process_data_1, _ = self.transform((process_data_1, labels)) 37 | return process_data_0, process_data_1, self.masked_position_generator() 38 | 39 | def __repr__(self): 40 | repr = "(DataAugmentationForVideoDistillation,\n" 41 | repr += " transform = %s,\n" % str(self.transform) 42 | repr += " Masked position generator = %s,\n" % str(self.masked_position_generator) 43 | repr += ")" 44 | return repr 45 | 46 | 47 | def build_distillation_dataset(args, num_frames=None): 48 | if num_frames is None: 49 | num_frames = args.num_frames 50 | transform = DataAugmentationForVideoDistillation(args, num_frames=num_frames) 51 | dataset = VideoDistillation( 52 | root=args.data_root, 53 | setting=args.data_path, 54 | video_ext='mp4', 55 | is_color=True, 56 | modality='rgb', 57 | new_length=num_frames, 58 | new_step=args.sampling_rate, 59 | transform=transform, 60 | temporal_jitter=False, 61 | video_loader=True, 62 | use_decord=True, 63 | lazy_init=False, 64 | num_sample=args.num_sample, 65 | num_segments=args.num_sample, 66 | # ds_size=5893 67 | ) 68 | print("Data Aug = %s" % str(transform)) 69 | return dataset 70 | 71 | 72 | def build_dataset(is_train, test_mode, args): 73 | if args.data_set == 'Kinetics-400': 74 | mode = None 75 | anno_path = None 76 | if is_train is True: 77 | mode = 'train' 78 | anno_path = os.path.join(args.data_path, 'train.csv') 79 | elif test_mode is True: 80 | mode = 'test' 81 | anno_path = os.path.join(args.data_path, 'val.csv') 82 | else: 83 | mode = 'validation' 84 | anno_path = os.path.join(args.data_path, 'val.csv') 85 | 86 | dataset = VideoClsDataset( 87 | anno_path=anno_path, 88 | data_path=args.data_root, 89 | mode=mode, 90 | clip_len=args.num_frames, 91 | frame_sample_rate=args.sampling_rate, 92 | num_segment=1, 93 | test_num_segment=args.test_num_segment, 94 | test_num_crop=args.test_num_crop, 95 | num_crop=1 if not test_mode else 3, 96 | keep_aspect_ratio=True, 97 | crop_size=args.input_size, 98 | short_side_size=args.short_side_size, 99 | new_height=256, 100 | new_width=320, 101 | args=args, 102 | ) 103 | nb_classes = 400 104 | 105 | elif args.data_set == 'SSV2': 106 | mode = None 107 | anno_path = None 108 | if is_train is True: 109 | mode = 'train' 110 | anno_path = os.path.join(args.data_path, 'train.csv') 111 | elif test_mode is True: 112 | mode = 'test' 113 | anno_path = os.path.join(args.data_path, 'val.csv') 114 | else: 115 | mode = 'validation' 116 | anno_path = os.path.join(args.data_path, 'val.csv') 117 | 118 | dataset = SSVideoClsDataset( 119 | anno_path=anno_path, 120 | data_path=args.data_root, 121 | mode=mode, 122 | clip_len=1, 123 | 
num_segment=args.num_frames, 124 | test_num_segment=args.test_num_segment, 125 | test_num_crop=args.test_num_crop, 126 | num_crop=1 if not test_mode else 3, 127 | keep_aspect_ratio=True, 128 | crop_size=args.input_size, 129 | short_side_size=args.short_side_size, 130 | new_height=256, 131 | new_width=320, 132 | args=args, 133 | ) 134 | nb_classes = 174 135 | 136 | elif args.data_set == 'UCF101': 137 | mode = None 138 | anno_path = None 139 | if is_train is True: 140 | mode = 'train' 141 | anno_path = os.path.join(args.data_path, 'train.csv') 142 | elif test_mode is True: 143 | mode = 'test' 144 | anno_path = os.path.join(args.data_path, 'val.csv') 145 | else: 146 | mode = 'validation' 147 | anno_path = os.path.join(args.data_path, 'test.csv') 148 | 149 | dataset = VideoClsDataset( 150 | anno_path=anno_path, 151 | data_path=args.data_root, 152 | mode=mode, 153 | clip_len=args.num_frames, 154 | frame_sample_rate=args.sampling_rate, 155 | num_segment=1, 156 | test_num_segment=args.test_num_segment, 157 | test_num_crop=args.test_num_crop, 158 | num_crop=1 if not test_mode else 3, 159 | keep_aspect_ratio=True, 160 | crop_size=args.input_size, 161 | short_side_size=args.short_side_size, 162 | new_height=256, 163 | new_width=320, 164 | args=args) 165 | nb_classes = 101 166 | 167 | elif args.data_set == 'HMDB51': 168 | mode = None 169 | anno_path = None 170 | if is_train is True: 171 | mode = 'train' 172 | anno_path = os.path.join(args.data_path, 'train.csv') 173 | elif test_mode is True: 174 | mode = 'test' 175 | anno_path = os.path.join(args.data_path, 'val.csv') 176 | else: 177 | mode = 'validation' 178 | anno_path = os.path.join(args.data_path, 'test.csv') 179 | 180 | dataset = VideoClsDataset( 181 | anno_path=anno_path, 182 | data_path=args.data_root, 183 | mode=mode, 184 | clip_len=args.num_frames, 185 | frame_sample_rate=args.sampling_rate, 186 | num_segment=1, 187 | test_num_segment=args.test_num_segment, 188 | test_num_crop=args.test_num_crop, 189 | num_crop=1 if not test_mode else 3, 190 | keep_aspect_ratio=True, 191 | crop_size=args.input_size, 192 | short_side_size=args.short_side_size, 193 | new_height=256, 194 | new_width=320, 195 | args=args) 196 | nb_classes = 51 197 | elif args.data_set == 'egopet': 198 | mode = None 199 | anno_path = None 200 | if is_train is True: 201 | mode = 'train' 202 | anno_path = '/private/home/amirbar/datasets/egopet/egopet_df_v2_mvd_train.csv' 203 | elif test_mode is True: 204 | mode = 'test' 205 | anno_path = '/private/home/amirbar/datasets/egopet/egopet_df_v2_mvd_val.csv' 206 | else: 207 | mode = 'validation' 208 | anno_path = '/private/home/amirbar/datasets/egopet/egopet_df_v2_mvd_val.csv' 209 | 210 | dataset = VideoClsDataset( 211 | anno_path=anno_path, 212 | data_path=args.data_root, 213 | mode=mode, 214 | clip_len=args.num_frames, 215 | frame_sample_rate=args.sampling_rate, 216 | num_segment=1, 217 | test_num_segment=args.test_num_segment, 218 | test_num_crop=args.test_num_crop, 219 | num_crop=1 if not test_mode else 3, 220 | keep_aspect_ratio=True, 221 | crop_size=args.input_size, 222 | short_side_size=args.short_side_size, 223 | new_height=256, 224 | new_width=320, 225 | args=args, 226 | ds_size=5893 227 | ) 228 | nb_classes = 400 229 | 230 | elif args.data_set == 'ego4d': 231 | mode = None 232 | anno_path = None 233 | if is_train is True: 234 | mode = 'train' 235 | anno_path = '/private/home/amirbar/datasets/egopet/ego4d_df_v2_mvd_train.csv' 236 | elif test_mode is True: 237 | mode = 'test' 238 | anno_path = None 239 | else: 240 | mode = 
'validation' 241 | anno_path = None 242 | 243 | dataset = VideoClsDataset( 244 | anno_path=anno_path, 245 | data_path=args.data_root, 246 | mode=mode, 247 | clip_len=args.num_frames, 248 | frame_sample_rate=args.sampling_rate, 249 | num_segment=1, 250 | test_num_segment=args.test_num_segment, 251 | test_num_crop=args.test_num_crop, 252 | num_crop=1 if not test_mode else 3, 253 | keep_aspect_ratio=True, 254 | crop_size=args.input_size, 255 | short_side_size=args.short_side_size, 256 | new_height=256, 257 | new_width=320, 258 | args=args, 259 | ds_size=5893 260 | ) 261 | nb_classes = 400 262 | 263 | elif args.data_set == 'egomix': 264 | pass 265 | 266 | else: 267 | raise NotImplementedError() 268 | assert nb_classes == args.nb_classes 269 | print("Number of the class = %d" % args.nb_classes) 270 | 271 | return dataset, nb_classes 272 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | import warnings 4 | import random 5 | import numpy as np 6 | import torchvision 7 | from PIL import Image, ImageOps 8 | import numbers 9 | 10 | 11 | class GroupRandomCrop(object): 12 | def __init__(self, size): 13 | if isinstance(size, numbers.Number): 14 | self.size = (int(size), int(size)) 15 | else: 16 | self.size = size 17 | 18 | def __call__(self, img_tuple): 19 | img_group, label = img_tuple 20 | 21 | w, h = img_group[0].size 22 | th, tw = self.size 23 | 24 | out_images = list() 25 | 26 | x1 = random.randint(0, w - tw) 27 | y1 = random.randint(0, h - th) 28 | 29 | for img in img_group: 30 | assert(img.size[0] == w and img.size[1] == h) 31 | if w == tw and h == th: 32 | out_images.append(img) 33 | else: 34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 35 | 36 | return (out_images, label) 37 | 38 | 39 | class GroupCenterCrop(object): 40 | def __init__(self, size): 41 | self.worker = torchvision.transforms.CenterCrop(size) 42 | 43 | def __call__(self, img_tuple): 44 | img_group, label = img_tuple 45 | return ([self.worker(img) for img in img_group], label) 46 | 47 | 48 | class GroupNormalize(object): 49 | def __init__(self, mean, std): 50 | self.mean = mean 51 | self.std = std 52 | 53 | def __call__(self, tensor_tuple): 54 | tensor, label = tensor_tuple 55 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 56 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 57 | 58 | # TODO: make efficient 59 | for t, m, s in zip(tensor, rep_mean, rep_std): 60 | t.sub_(m).div_(s) 61 | 62 | return (tensor,label) 63 | 64 | 65 | class GroupGrayScale(object): 66 | def __init__(self, size): 67 | self.worker = torchvision.transforms.Grayscale(size) 68 | 69 | def __call__(self, img_tuple): 70 | img_group, label = img_tuple 71 | return ([self.worker(img) for img in img_group], label) 72 | 73 | 74 | class GroupScale(object): 75 | """ Rescales the input PIL.Image to the given 'size'. 76 | 'size' will be the size of the smaller edge. 
77 | For example, if height > width, then image will be 78 | rescaled to (size * height / width, size) 79 | size: size of the smaller edge 80 | interpolation: Default: PIL.Image.BILINEAR 81 | """ 82 | 83 | def __init__(self, size, interpolation=Image.BILINEAR): 84 | self.worker = torchvision.transforms.Resize(size, interpolation) 85 | 86 | def __call__(self, img_tuple): 87 | img_group, label = img_tuple 88 | return ([self.worker(img) for img in img_group], label) 89 | 90 | 91 | class GroupMultiScaleCrop(object): 92 | 93 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 94 | self.scales = scales if scales is not None else [1, .875, .75, .66] 95 | self.max_distort = max_distort 96 | self.fix_crop = fix_crop 97 | self.more_fix_crop = more_fix_crop 98 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 99 | self.interpolation = Image.BILINEAR 100 | 101 | def __call__(self, img_tuple): 102 | img_group, label = img_tuple 103 | 104 | im_size = img_group[0].size 105 | 106 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 107 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 108 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group] 109 | return (ret_img_group, label) 110 | 111 | def _sample_crop_size(self, im_size): 112 | image_w, image_h = im_size[0], im_size[1] 113 | 114 | # find a crop size 115 | base_size = min(image_w, image_h) 116 | crop_sizes = [int(base_size * x) for x in self.scales] 117 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 118 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 119 | 120 | pairs = [] 121 | for i, h in enumerate(crop_h): 122 | for j, w in enumerate(crop_w): 123 | if abs(i - j) <= self.max_distort: 124 | pairs.append((w, h)) 125 | 126 | crop_pair = random.choice(pairs) 127 | if not self.fix_crop: 128 | w_offset = random.randint(0, image_w - crop_pair[0]) 129 | h_offset = random.randint(0, image_h - crop_pair[1]) 130 | else: 131 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 132 | 133 | return crop_pair[0], crop_pair[1], w_offset, h_offset 134 | 135 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 136 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 137 | return random.choice(offsets) 138 | 139 | @staticmethod 140 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 141 | w_step = (image_w - crop_w) // 4 142 | h_step = (image_h - crop_h) // 4 143 | 144 | ret = list() 145 | ret.append((0, 0)) # upper left 146 | ret.append((4 * w_step, 0)) # upper right 147 | ret.append((0, 4 * h_step)) # lower left 148 | ret.append((4 * w_step, 4 * h_step)) # lower right 149 | ret.append((2 * w_step, 2 * h_step)) # center 150 | 151 | if more_fix_crop: 152 | ret.append((0, 2 * h_step)) # center left 153 | ret.append((4 * w_step, 2 * h_step)) # center right 154 | ret.append((2 * w_step, 4 * h_step)) # lower center 155 | ret.append((2 * w_step, 0 * h_step)) # upper center 156 | 157 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 158 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 159 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 160 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 161 | return ret 
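# Usage sketch for the Group* transforms in this file: they operate on an (img_group, label)
# tuple, where img_group is a list of PIL frames from one clip, and the same randomly chosen
# crop is applied to every frame before resizing. The frame count and sizes here are
# hypothetical and the helper name is illustrative only.
def _demo_group_multi_scale_crop():
    frames = [Image.new('RGB', (320, 256)) for _ in range(16)]    # dummy 16-frame clip
    crop = GroupMultiScaleCrop(input_size=224, scales=[1, .875, .75, .66])
    cropped, label = crop((frames, 0))                            # one crop shared by all frames
    assert all(img.size == (224, 224) for img in cropped)         # every frame resized to 224x224
    return cropped, label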
162 | 163 | 164 | class GroupMultiScaleTwoResizedCrop(object): 165 | 166 | def __init__(self, input_size, input_size_1, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 167 | self.scales = scales if scales is not None else [1, .875, .75, .66] 168 | self.max_distort = max_distort 169 | self.fix_crop = fix_crop 170 | self.more_fix_crop = more_fix_crop 171 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 172 | self.input_size_1 = input_size_1 if not isinstance(input_size_1, int) else [input_size_1, input_size_1] 173 | self.interpolation = Image.BILINEAR 174 | 175 | def __call__(self, img_tuple): 176 | img_group, label = img_tuple 177 | 178 | im_size = img_group[0].size 179 | 180 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 181 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 182 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in 183 | crop_img_group] 184 | ret_img_group_1 = [img.resize((self.input_size_1[0], self.input_size_1[1]), self.interpolation) for img in 185 | crop_img_group] 186 | return (ret_img_group, ret_img_group_1, label) 187 | 188 | def _sample_crop_size(self, im_size): 189 | image_w, image_h = im_size[0], im_size[1] 190 | 191 | # find a crop size 192 | base_size = min(image_w, image_h) 193 | crop_sizes = [int(base_size * x) for x in self.scales] 194 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 195 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 196 | 197 | pairs = [] 198 | for i, h in enumerate(crop_h): 199 | for j, w in enumerate(crop_w): 200 | if abs(i - j) <= self.max_distort: 201 | pairs.append((w, h)) 202 | 203 | crop_pair = random.choice(pairs) 204 | if not self.fix_crop: 205 | w_offset = random.randint(0, image_w - crop_pair[0]) 206 | h_offset = random.randint(0, image_h - crop_pair[1]) 207 | else: 208 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 209 | 210 | return crop_pair[0], crop_pair[1], w_offset, h_offset 211 | 212 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 213 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 214 | return random.choice(offsets) 215 | 216 | @staticmethod 217 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 218 | w_step = (image_w - crop_w) // 4 219 | h_step = (image_h - crop_h) // 4 220 | 221 | ret = list() 222 | ret.append((0, 0)) # upper left 223 | ret.append((4 * w_step, 0)) # upper right 224 | ret.append((0, 4 * h_step)) # lower left 225 | ret.append((4 * w_step, 4 * h_step)) # lower right 226 | ret.append((2 * w_step, 2 * h_step)) # center 227 | 228 | if more_fix_crop: 229 | ret.append((0, 2 * h_step)) # center left 230 | ret.append((4 * w_step, 2 * h_step)) # center right 231 | ret.append((2 * w_step, 4 * h_step)) # lower center 232 | ret.append((2 * w_step, 0 * h_step)) # upper center 233 | 234 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 235 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 236 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 237 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 238 | return ret 239 | 240 | 241 | class Stack(object): 242 | 243 | def __init__(self, roll=False): 244 | self.roll = roll 245 | 246 | def __call__(self, img_tuple): 247 | img_group, 
label = img_tuple 248 | 249 | if img_group[0].mode == 'L': 250 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label) 251 | elif img_group[0].mode == 'RGB': 252 | if self.roll: 253 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label) 254 | else: 255 | return (np.concatenate(img_group, axis=2), label) 256 | 257 | 258 | class ToTorchFormatTensor(object): 259 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 260 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 261 | def __init__(self, div=True): 262 | self.div = div 263 | 264 | def __call__(self, pic_tuple): 265 | pic, label = pic_tuple 266 | 267 | if isinstance(pic, np.ndarray): 268 | # handle numpy array 269 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 270 | else: 271 | # handle PIL Image 272 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 273 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 274 | # put it from HWC to CHW format 275 | # yikes, this transpose takes 80% of the loading time/CPU 276 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 277 | return (img.float().div(255.) if self.div else img.float(), label) 278 | 279 | 280 | class IdentityTransform(object): 281 | 282 | def __call__(self, data): 283 | return data 284 | -------------------------------------------------------------------------------- /object_interaction/object_interaction_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | import os 5 | import random 6 | 7 | import torch 8 | import torch.utils.data 9 | from iopath.common.file_io import g_pathmgr as pathmgr 10 | 11 | import sys 12 | parent = os.path.dirname(os.path.abspath(__file__)) 13 | parent_parent = os.path.join(parent, '../') 14 | sys.path.append(os.path.dirname(parent_parent)) 15 | 16 | from decoder.utils import decode_ffmpeg 17 | from object_interaction.object_interaction_utils import * 18 | from torchvision import transforms 19 | 20 | from pathlib import Path 21 | 22 | 23 | class Object_Interaction(torch.utils.data.Dataset): 24 | """ 25 | Object Interaction video loader. Construct the Object Interaction video loader, 26 | then sample clips from the videos. For training and validation, a single clip 27 | is randomly sampled from every video with normalization. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | mode, 33 | path_to_data_dir, 34 | path_to_csv, 35 | # transformation 36 | transform, 37 | # decoding settings 38 | num_frames=8, 39 | target_fps=30, 40 | num_sec=2, 41 | fps=4, 42 | # frame aug settings 43 | crop_size=224, 44 | # other parameters 45 | enable_multi_thread_decode=False, 46 | use_offset_sampling=True, 47 | inverse_uniform_sampling=False, 48 | num_retries=10, 49 | # object or no object 50 | object_interaction_ratio=0.5, 51 | ): 52 | """ 53 | Construct the Object Interaction video loader with a given csv file. The format of 54 | the csv file is: 55 | ``` 56 | animal_1, ds_type_1, video_id_1, segment_id_1, start_time_1, end_time_1, total_time_1, interacting_object_1, video_path_1 57 | animal_2, ds_type_2, video_id_2, segment_id_2, start_time_2, end_time_2, total_time_2, interacting_object_2, video_path_2 58 | ... 
59 | animal_N, ds_type_N, video_id_N, segment_id_N, start_time_N, end_time_N, total_time_N, interacting_object_N, video_path_N 60 | ``` 61 | Args: 62 | mode (string): Options includes `train` or `test`. 63 | For the train and val mode, the data loader will take data 64 | from the train or val set, and sample one clip per video. 65 | path_to_data_dir (string): Path to EgoPet Dataset 66 | path_to_csv (string): Path to Object Interaction data 67 | num_frames (int): number of frames used for model 68 | object_interaction_ratio (float): ratio of clips with interactions during training 69 | """ 70 | # Only support train, val, and test mode. 71 | assert mode in [ 72 | "train", 73 | "test", 74 | ], "mode has to be 'train' or 'test'" 75 | self.mode = mode 76 | 77 | self._num_retries = num_retries 78 | self._path_to_data_dir = path_to_data_dir 79 | self._path_to_csv = path_to_csv 80 | self.object_interaction_ratio = object_interaction_ratio 81 | 82 | self._crop_size = crop_size 83 | 84 | self._num_frames = num_frames 85 | self._num_sec = num_sec 86 | self._target_fps = target_fps 87 | self._fps=fps 88 | 89 | self.transform = transform 90 | 91 | self._enable_multi_thread_decode = enable_multi_thread_decode 92 | self._inverse_uniform_sampling = inverse_uniform_sampling 93 | self._use_offset_sampling = use_offset_sampling 94 | 95 | print(self) 96 | print(locals()) 97 | self._construct_loader() 98 | 99 | def _construct_loader(self): 100 | """ 101 | Construct the video loader. 102 | """ 103 | self._object_path_to_videos = [] 104 | self._no_object_path_to_videos = [] 105 | self._object_labels_times = [] 106 | self._no_object_labels_times = [] 107 | 108 | with pathmgr.open(self._path_to_csv, "r") as f: 109 | for curr_clip in f.read().splitlines(): 110 | curr_values = curr_clip.split(',') 111 | if curr_values[0] != 'animal': 112 | animal, ds_type, video_id, segment_id, start_time, end_time, total_time, interacting_object, video_path = curr_values 113 | video_path = os.path.join(self._path_to_data_dir, video_path) 114 | start_time, end_time, interacting_object = start_time.split(';'), end_time.split(';'), interacting_object.split(';') 115 | start_time, end_time, total_time = [process_time(time) for time in start_time], [process_end_time(time, total_time) for time in end_time], process_time(total_time) 116 | 117 | assert len(start_time) == len(end_time) == len(interacting_object), "Error with csv on {curr_row}".format(curr_row=curr_values) 118 | 119 | # Getting None Segments inbetween videos 120 | start_time, end_time, interacting_object = get_none_segments(start_time, end_time, total_time, interacting_object) 121 | 122 | object_labels = [get_label(curr_interacting_object) for curr_interacting_object in interacting_object] 123 | assert len(start_time) == len(end_time) == len(interacting_object), "Error with processing" 124 | 125 | for i in range(len(start_time)): 126 | interaction = 0. if object_labels[i] == -1 else 1. 
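# Added note on the labelling convention (inferred from get_label and the code
# below): get_label returns -1 for segments with no interacting object, so
# `interaction` is 0.0 for those clips and 1.0 otherwise. The -1 placeholder is
# then remapped to class 0 so the object-classification target is always a valid
# index, and each clip is routed into either the object or the no-object pool
# that object_interaction_ratio samples from at training time.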
127 | if object_labels[i] == -1:
128 | object_labels[i] = 0
129 | 
130 | labels_times = (interaction, object_labels[i], start_time[i], end_time[i], total_time)
131 | if interaction:
132 | self._object_path_to_videos.append(video_path)
133 | self._object_labels_times.append(labels_times)
134 | else:
135 | self._no_object_path_to_videos.append(video_path)
136 | self._no_object_labels_times.append(labels_times)
137 | 
138 | self._path_to_videos = self._object_path_to_videos + self._no_object_path_to_videos
139 | self._labels_times = self._object_labels_times + self._no_object_labels_times
140 | 
141 | assert (
142 | len(self._path_to_videos) > 0
143 | ), "Failed to load Object Interaction from {}".format(
144 | self._path_to_csv
145 | )
146 | print(
147 | "Constructing Object Interaction dataloader (size: {}) from {}".format(
148 | len(self._path_to_videos), self._path_to_csv
149 | )
150 | )
151 | 
152 | def __getitem__(self, index):
153 | """
154 | With probability self.object_interaction_ratio randomly choose a clip
155 | with an object interaction, otherwise randomly choose a clip without
156 | an object interaction. If the video cannot be fetched and decoded
157 | successfully, find a random video that can be decoded as a replacement.
158 | Args:
159 | index (int): the video index provided by the pytorch sampler. (not used)
160 | Returns:
161 | frames (tensor): the frames sampled from the video. The dimension
162 | is `num frames` x `channel` x `height` x `width`.
163 | interaction (int): whether there is an interaction in the video.
164 | 0 for no interaction, 1 for an interaction.
165 | object_label (int): the label of the current video.
166 | """
167 | # Try to decode and sample a clip from a video. If the video cannot be
168 | # decoded, repeatedly find a random video replacement that can be decoded.
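# Rough usage sketch (hypothetical paths, added for illustration only): the dataset
# is normally wrapped in a torch DataLoader, e.g.
#
#   dataset = Object_Interaction(
#       mode='train',
#       path_to_data_dir='/path/to/egopet',                    # hypothetical path
#       path_to_csv='/path/to/object_interaction_train.csv',   # hypothetical path
#       transform=transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                      std=[0.229, 0.224, 0.225]),
#       num_frames=8, num_sec=2, fps=4,
#   )
#   frames, interaction, object_label = dataset[0]
#   # frames: (num_frames, 3, H, W) float tensor (scaled to [0, 1], then passed
#   # through `transform`), interaction: tensor([0.]) or tensor([1.]),
#   # object_label: int class index.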
169 | for i_try in range(self._num_retries): 170 | if self.mode == 'train': 171 | # If training sample random video with object interaction 172 | # with prob self.object_interaction_ratio 173 | if random.random() < self.object_interaction_ratio: 174 | # Get sample with object interaction 175 | index = random.randint(0, len(self._object_path_to_videos) - 1) 176 | # Get a clip with an object interaction 177 | video_path = self._object_path_to_videos[index] 178 | label_time = self._object_labels_times[index] 179 | interaction, object_label, start_time, end_time, total_time = label_time 180 | else: 181 | # Get sample without object interaction 182 | index = random.randint(0, len(self._no_object_path_to_videos) - 1) 183 | # Get a clip without object interaction 184 | video_path = self._no_object_path_to_videos[index] 185 | label_time = self._no_object_labels_times[index] 186 | interaction, object_label, start_time, end_time, total_time = label_time 187 | elif self.mode == 'test': 188 | # If test sample provided index 189 | video_path = self._path_to_videos[index] 190 | label_time = self._labels_times[index] 191 | interaction, object_label, start_time, end_time, total_time = label_time 192 | 193 | # Decode Video 194 | try: 195 | if self.mode == 'train': 196 | # For training randomnly select start of clip 197 | start_seek = min(max(int((start_time + end_time - self._num_sec) / 2), start_time), max(end_time - self._num_sec, 0)) 198 | elif self.mode == 'test': 199 | # Gets as to close to the center as possible of current clip 200 | start_seek = min(max(int((start_time + end_time - self._num_sec) / 2), start_time), max(end_time - self._num_sec, 0)) 201 | 202 | num_sec = min(self._num_sec, end_time-start_time) 203 | frames = decode_ffmpeg(video_path, start_seek=start_seek, num_sec=num_sec, num_frames=self._num_frames, fps=self._fps) 204 | if frames.shape[0] == 0: 205 | raise ValueError('Decoder Error, 0 frames decoded at video path {video_path}, start_seek: {start_seek}, num_sec: {num_sec}, self._num_frames: {num_frames}, fps: {fps}, start_time: {start_time}, end_time: {end_time}, total_time: {total_time}'.format(video_path=video_path, start_seek=start_seek, num_sec=num_sec, num_frames=self._num_frames, fps=self._fps, start_time=start_time, end_time=end_time, total_time=total_time)) 206 | except Exception as e: 207 | print( 208 | "Failed to decode video idx {} from {} with error {}".format( 209 | index, video_path, e 210 | ) 211 | ) 212 | # Random selection logic in getitem so random video will be decoded 213 | return self.__getitem__(0) 214 | 215 | start_idx, end_idx = get_start_end_idx( 216 | frames.shape[0], self._num_frames, 0, 1 217 | ) 218 | frames = temporal_sampling( 219 | frames, start_idx, end_idx, self._num_frames 220 | ) 221 | 222 | frames = frames.permute(0, 3, 1, 2) / 255. 223 | frames = torch.stack([self.transform(f) for f in frames]) 224 | return frames, torch.tensor([interaction]), object_label 225 | else: 226 | raise RuntimeError( 227 | "Failed to fetch video after {} retries.".format(self._num_retries) 228 | ) 229 | 230 | def __len__(self): 231 | """ 232 | Returns: 233 | (int): the number of videos in the dataset. 234 | """ 235 | if self.mode == 'train': 236 | return 20000 237 | elif self.mode == 'test': 238 | return self.num_videos 239 | 240 | @property 241 | def num_videos(self): 242 | """ 243 | Returns: 244 | (int): the number of videos in the dataset. 
245 | """ 246 | return len(self._path_to_videos) 247 | -------------------------------------------------------------------------------- /cms_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision.utils import save_image, make_grid 9 | from torchvision import transforms, utils 10 | from torchvision.transforms.functional import rgb_to_grayscale 11 | 12 | import os 13 | from PIL import Image, ImageOps 14 | import cv2 15 | import random 16 | import copy 17 | import pandas as pd 18 | from scipy.signal import savgol_filter 19 | import numpy as np 20 | from tqdm import tqdm 21 | import warnings 22 | import matplotlib.pyplot as plt 23 | warnings.filterwarnings("ignore") 24 | 25 | 26 | class VisionPropDataset(Dataset): 27 | 28 | def __init__(self, root_path, config, transforms, mode='train', debug=False): 29 | 30 | self.config = config 31 | self.mode = mode 32 | self.root_path = root_path 33 | self.normalization_coeffs = None 34 | self.debug = debug 35 | self.experiments = [] 36 | self.last_frame_idx = 0 37 | self.input_prop_fts = [] 38 | self.images_path = [] 39 | self.label_features = [] 40 | self.inputs = [] 41 | self.targets = [] 42 | self._add_features() 43 | self.rollout_boundaries = [0] # boundaries between different runs 44 | self.rollout_ts = [] 45 | self.prop_latent = [] 46 | self.setup_data_pipeline(transforms) 47 | if mode == 'deploy': 48 | return 49 | self.num_samples = 0 50 | file_rootname = 'rollout_' 51 | print("------------------------------------------") 52 | print('Building %s Dataset' % ("Training" if mode == 'train' else "Validation")) 53 | print("------------------------------------------") 54 | if os.path.isdir(root_path): 55 | for root, dirs, files in os.walk(root_path, topdown=True, followlinks=True): 56 | for name in dirs: 57 | if name.startswith(file_rootname): 58 | self.experiments.append(os.path.join(root, name)) 59 | else: 60 | assert False, "Provided dataset root is neither a file nor a directory!" 61 | 62 | self.num_experiments = len(self.experiments) 63 | assert self.num_experiments > 0, 'No valid data found!' 64 | print('Dataset contains %d experiments.' 
% self.num_experiments) 65 | 66 | # assert self.config.lookhead[0] > 0.0, "Lookhead should be larger than 0.0 for training" 67 | 68 | # numpy arrays to store the raw data and compute mean/std for normalization 69 | self.raw_prop_inputs = None 70 | 71 | print("Decoding data...") 72 | self.experiments = sorted(self.experiments) 73 | for exp in tqdm(self.experiments): 74 | try: 75 | self._decode_experiment(exp) 76 | except Exception as e: 77 | print(e) 78 | 79 | if len(self.inputs) == 0: 80 | raise IOError("Did not find any file in the dataset folder") 81 | self.inputs = torch.from_numpy(np.vstack(self.inputs).astype(np.float32)) 82 | self.targets = torch.from_numpy(np.vstack(self.targets).astype(np.float32)) 83 | self.rollout_ts = np.vstack(self.rollout_ts).astype(np.float32) 84 | self.prop_latent = np.vstack(self.prop_latent).astype(np.float32) 85 | 86 | # this computes the normalization 87 | if self.mode == 'train': 88 | self._preprocess_dataset() 89 | 90 | print('Found {} samples belonging to {} experiments:'.format( 91 | self.num_samples, self.num_experiments)) 92 | 93 | def __len__(self): 94 | return self.num_samples 95 | 96 | def _add_features(self): 97 | self.input_prop_fts += ["rpy_0", 98 | "rpy_1"] 99 | 100 | joint_dim = 12 101 | action_dim = 12 102 | latent_dim = self.config.latent_dim 103 | self.latent_dim = latent_dim 104 | 105 | for i in range(joint_dim): 106 | self.input_prop_fts.append("joint_angles_{}".format(i)) 107 | for i in range(joint_dim): 108 | self.input_prop_fts.append("joint_vel_{}".format(i)) 109 | for i in range(action_dim): 110 | self.input_prop_fts.append("last_action_{}".format(i)) 111 | self.input_prop_fts.extend(["command_0", "command_1"]) 112 | 113 | if self.config.input_use_depth: 114 | self.input_frame_fts = ["depth_frame_counter"] 115 | else: 116 | self.input_frame_fts = ["frame_counter"] 117 | self.input_fts = self.input_prop_fts + self.input_frame_fts 118 | self.target_fts = [] 119 | for i in range(self.config.latent_dim): 120 | self.target_fts.append(f"prop_latent_{i}") 121 | 122 | def process_raw_latent(self, data, ts): 123 | 124 | data_freq = 90 125 | lookheads = self.config.lookhead 126 | self.num_lookheads = len(lookheads) 127 | 128 | predictive_latent = [] 129 | # Latent smoothing 130 | for i in range(data.shape[1]): 131 | data[:, i] = savgol_filter(data[:, i], 41, 3) 132 | 133 | # Latent smoothing 134 | for i in range(data.shape[1]): 135 | time_shifted_data = np.zeros_like(data[:,i]) 136 | for k in range(len(lookheads)): 137 | # Advance only for future geometry, without any smoothing 138 | current_lookhead = int(lookheads[k]*data_freq) 139 | time_shifted_data = np.roll(data[:,i], 140 | int(-current_lookhead)) 141 | if current_lookhead >= 0: 142 | time_shifted_data[-current_lookhead:] = data[-current_lookhead:, i] 143 | else: 144 | time_shifted_data[:-current_lookhead] = data[:-current_lookhead, i] 145 | predictive_latent.append(np.expand_dims(time_shifted_data,1)) 146 | 147 | predictive_latent = np.hstack(predictive_latent) 148 | 149 | return predictive_latent 150 | 151 | def _decode_experiment(self, dir_subpath): 152 | propr_file = os.path.join(dir_subpath, "proprioception.csv") 153 | assert os.path.isfile(propr_file), "Not Found proprioception file" 154 | df_prop = pd.read_csv(propr_file, delimiter=',') 155 | 156 | current_img_paths = [] 157 | if self.config.input_use_imgs: 158 | if self.config.input_use_depth: 159 | r_ext = '.tiff' 160 | else: 161 | r_ext = '.jpg' 162 | img_dir = os.path.join(dir_subpath, "img") 163 | for f in 
os.listdir(img_dir): 164 | ext = os.path.splitext(f)[1] 165 | if ext.lower() not in [r_ext]: 166 | continue 167 | current_img_paths.append(os.path.join(img_dir,f)) 168 | if len(current_img_paths) == 0: 169 | raise IOError("Not found images") 170 | self.images_path.extend(sorted(current_img_paths)) 171 | 172 | if self.debug: 173 | print("Average sampling frequency of proprioception is %.6f" % ( 174 | 1.0 / np.mean(np.diff(np.unique(df_prop["time_from_start"].values))))) 175 | inputs, targets = df_prop[self.input_fts].values, df_prop[self.target_fts].values 176 | 177 | inputs[:,-1] += self.last_frame_idx 178 | 179 | 180 | ts = df_prop["time_from_start"] 181 | if self.config.input_use_depth: 182 | fc = df_prop["depth_frame_counter"] 183 | else: 184 | fc = df_prop["frame_counter"] 185 | 186 | prop_l = targets 187 | targets = self.process_raw_latent(prop_l, ts) 188 | 189 | self.inputs.append(inputs) 190 | self.targets.append(targets) 191 | 192 | final_idx = len(self.input_prop_fts) 193 | input_prop_features_v = inputs[:, :final_idx] 194 | 195 | if self.raw_prop_inputs is None: 196 | self.raw_prop_inputs = input_prop_features_v 197 | self.raw_latent_target = targets 198 | else: 199 | self.raw_prop_inputs = np.concatenate([self.raw_prop_inputs, input_prop_features_v], axis=0) 200 | self.raw_latent_target = np.concatenate([self.raw_latent_target, targets], axis=0) 201 | self.last_frame_idx += len(current_img_paths) 202 | self.num_samples += inputs.shape[0] 203 | #self.num_samples += len(idxs) 204 | self.rollout_boundaries.append(self.num_samples) 205 | self.rollout_ts.append(np.expand_dims(fc[:inputs.shape[0]], axis=1)) 206 | self.prop_latent.append(prop_l) 207 | 208 | def _preprocess_dataset(self): 209 | if self.normalization_coeffs is None: 210 | self.input_prop_mean = np.mean(self.raw_prop_inputs, axis=0).astype(np.float32) 211 | self.input_prop_std = np.std(self.raw_prop_inputs, axis=0).astype(np.float32) 212 | self.target_latent_mean = np.mean(self.raw_latent_target, axis=0).astype(np.float32) 213 | self.target_latent_std = np.std(self.raw_latent_target, axis=0).astype(np.float32) 214 | else: 215 | self.input_prop_mean = self.normalization_coeffs[0] 216 | self.input_prop_std = self.normalization_coeffs[1] 217 | self.target_latent_mean = self.normalization_coeffs[2] 218 | self.target_latent_std = self.normalization_coeffs[3] 219 | 220 | def get_normalization(self): 221 | return self.input_prop_mean, self.input_prop_std, self.target_latent_mean, self.target_latent_std 222 | 223 | def setup_data_pipeline(self, transforms): 224 | if self.config.input_use_depth: 225 | self.preprocess_pipeline = transforms.Compose([ 226 | transforms.ToTensor(), 227 | transforms.CenterCrop((240,320)), 228 | ]) 229 | else: 230 | # self.preprocess_pipeline = transforms.Compose([ 231 | # transforms.ToTensor(), 232 | # transforms.CenterCrop((240,320)), 233 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 234 | # ]) 235 | 236 | self.preprocess_pipeline = transforms 237 | 238 | def image_processing(self, fname): 239 | if self.config.input_use_depth: 240 | input_image = cv2.imread(fname, cv2.IMREAD_ANYDEPTH) 241 | input_array = np.asarray(input_image, dtype=np.float32) 242 | input_array = np.minimum(input_array, 4000) 243 | input_array = (input_array - 2000) / 2000 244 | else: 245 | input_array = Image.open(fname) 246 | return input_array 247 | 248 | def preprocess_data(self, prop_numpy, img_numpy): 249 | prop = torch.from_numpy(prop_numpy.astype(np.float32)) 250 | try: 251 | img_numpy = 
np.stack(img_numpy, axis=-1).astype(np.float32) 252 | except: 253 | print(img_numpy[0].shape) 254 | print(img_numpy[1].shape) 255 | print(img_numpy[2].shape) 256 | if self.config.input_use_depth: 257 | img_numpy = np.minimum(img_numpy, 4000) 258 | img_numpy = (img_numpy - 2000) / 2000 259 | img = self.preprocess_pipeline(img_numpy) 260 | return prop, img 261 | 262 | def __getitem__(self, idx): 263 | prop_data = np.zeros((self.config.history_len,len(self.input_prop_fts)), dtype=np.float32) 264 | start_idx = np.maximum(0, idx - self.config.history_len) 265 | actual_history_length = idx - start_idx 266 | if actual_history_length > 0: 267 | prop_data[-actual_history_length:] = self.inputs[start_idx:idx, :len(self.input_prop_fts)] 268 | else: 269 | prop_data[-1] = self.inputs[idx, :len(self.input_prop_fts)] 270 | 271 | target = self.targets[idx] 272 | frame_skip = self.config.frame_skip # take one every n 273 | if self.config.input_use_imgs > 0: 274 | frame_idx_start = int(self.inputs[idx][-1].numpy()) 275 | imgs = [] 276 | for i in reversed(range(self.config.input_use_imgs)): 277 | frame_idx = np.maximum(0, frame_idx_start - i*frame_skip) 278 | frame_path = self.images_path[frame_idx] 279 | img = self.image_processing(frame_path) 280 | img = self.preprocess_pipeline(img) 281 | imgs.append(img) 282 | 283 | 284 | if self.config.grayscale: 285 | imgs = rgb_to_grayscale(imgs, num_output_channels=1).squeeze(1) 286 | 287 | imgs = torch.stack(imgs, dim=1) 288 | 289 | return prop_data, imgs, target 290 | else: 291 | return prop_data, 0.0, target 292 | -------------------------------------------------------------------------------- /locomotion_prediction/locomotion_prediction_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | import os 5 | import random 6 | import math 7 | 8 | import torch 9 | import numpy as np 10 | import torch.utils.data 11 | from iopath.common.file_io import g_pathmgr as pathmgr 12 | 13 | import sys 14 | parent = os.path.dirname(os.path.abspath(__file__)) 15 | parent_parent = os.path.join(parent, '../') 16 | sys.path.append(os.path.dirname(parent_parent)) 17 | 18 | from locomotion_prediction.locomotion_prediction_utils import * 19 | from decoder.utils import decode_ffmpeg 20 | from torchvision import transforms 21 | from evo.tools import file_interface 22 | 23 | from pathlib import Path 24 | 25 | 26 | class Locomotion_Prediction_dataloader(torch.utils.data.Dataset): 27 | """ 28 | Locomotion Prediction video loader. Construct the Locomotion Prediction video loader, 29 | then sample clips from the videos. For training, a single clip is randomly sampled 30 | from every video with normalization. For validation, the video path and trajectories are 31 | returned. 32 | """ 33 | 34 | def __init__( 35 | self, 36 | mode, 37 | path_to_data_dir, 38 | path_to_trajectories_dir, 39 | path_to_csv, 40 | # transformation 41 | transform, 42 | # decoding settings 43 | fps=30, 44 | # frame aug settings 45 | crop_size=224, 46 | # pose estimation settings 47 | animals=['cat', 'dog'], 48 | num_condition_frames=16, 49 | num_pose_prediction=16, 50 | scale_invariance='None', 51 | pps=30, 52 | # other parameters 53 | enable_multi_thread_decode=False, 54 | use_offset_sampling=True, 55 | inverse_uniform_sampling=False, 56 | num_retries=10, 57 | ): 58 | """ 59 | Construct the Object Interaction video loader with a given csv file. 
The format of 60 | the csv file is: 61 | ``` 62 | animal_1, ds_type_1, segment_id_1, stride_1, start_time_1, end_time_1, video_path_1 63 | animal_2, ds_type_2, segment_id_2, stride_2, start_time_2, end_time_2, video_path_2 64 | ... 65 | animal_N, ds_type_N, segment_id_N, stride_N, start_time_N, end_time_N, video_path_N 66 | ``` 67 | Args: 68 | mode (string): Options includes `train` or `test`. 69 | For the train and val mode, the data loader will take data 70 | from the train or val set, and sample one clip per video. 71 | path_to_data_dir (string): Path to EgoPet Dataset 72 | path_to_csv (string): Path to Object Interaction data 73 | num_frames (int): number of frames used for model 74 | num_condition_frames (int): number of frames to condition model on 75 | num_pose_prediction (int): number of future poses to predict 76 | """ 77 | # Only support train, val, and test mode. 78 | assert mode in [ 79 | "train", 80 | "test", 81 | ], "mode has to be 'train' or 'test'" 82 | self.mode = mode 83 | 84 | self._num_retries = num_retries 85 | self._path_to_data_dir = path_to_data_dir 86 | self._path_to_trajectories_dir = path_to_trajectories_dir 87 | self._path_to_csv = path_to_csv 88 | 89 | self._crop_size = crop_size 90 | 91 | self._num_frames = num_condition_frames 92 | self._num_sec = math.ceil((num_condition_frames + num_pose_prediction) / fps) 93 | self._fps=fps 94 | self._pps = pps 95 | self._pose_skip = fps // pps 96 | 97 | self.transform = transform 98 | 99 | # Pose Estimation Settings 100 | self._animals = animals 101 | self._num_condition_frames = num_condition_frames 102 | self._num_pose_prediction = num_pose_prediction 103 | self._scale_invariance = scale_invariance 104 | 105 | self._enable_multi_thread_decode = enable_multi_thread_decode 106 | self._inverse_uniform_sampling = inverse_uniform_sampling 107 | self._use_offset_sampling = use_offset_sampling 108 | self._num_retries = num_retries 109 | 110 | print(self) 111 | print(locals()) 112 | self._construct_loader() 113 | 114 | def _construct_loader(self): 115 | """ 116 | Construct the video loader. 
117 | """ 118 | self._path_to_videos = [] 119 | self._path_to_trajectories = [] 120 | self._start_times = [] 121 | self._end_times = [] 122 | self._stride = [] 123 | 124 | with pathmgr.open(self._path_to_csv, "r") as f: 125 | for curr_clip in f.read().splitlines(): 126 | curr_values = curr_clip.split(',') 127 | 128 | animal, ds_type, segment_id, stride, start_time, end_time, video_path = curr_values # NEED TO GENERATE THIS CSV 129 | video_path = os.path.join(self._path_to_data_dir, video_path) 130 | if ((self.mode == 'train' and ds_type == 'training_set') or (self.mode == 'test' and ds_type == 'validation_set')) and stride != -1 and animal in self._animals: 131 | start_time, end_time = int(start_time), int(end_time) 132 | 133 | if self.mode == 'test': 134 | end_time = min(end_time, 50) # Clip Evaluation Videos to 50 Seconds 135 | 136 | trajectory = "{}_calib_eth_stride_{}_interp.txt".format(segment_id, stride) 137 | trajectory_path = os.path.join(self._path_to_trajectories_dir, trajectory) 138 | 139 | if os.path.isfile(trajectory_path) and os.path.isfile(video_path) and (end_time - start_time) >= self._num_sec: 140 | video_path = os.path.join(self._path_to_data_dir, video_path) 141 | self._path_to_videos.append(video_path) 142 | self._path_to_trajectories.append(trajectory_path) 143 | self._start_times.append(start_time) 144 | self._end_times.append(end_time) 145 | self._stride.append(stride) 146 | 147 | assert ( 148 | len(self._path_to_videos) > 0 149 | ), "Failed to load Pose Trajectories from {}".format( 150 | self._path_to_csv 151 | ) 152 | print( 153 | "Constructing Object Interaction dataloader (size: {}) from {}".format( 154 | len(self._path_to_videos), self._path_to_csv 155 | ) 156 | ) 157 | 158 | 159 | def _get_sub_clip(self, video_path, trajectory_path, sub_start_idx, sub_end_idx): 160 | """ 161 | Get sub clip of video of from sub_start_idx to sub_end_idx. 162 | Args: 163 | video_path (string): path to video 164 | trajectory_path (string): path to associated trajectory 165 | sub_start_idx (int): start idx for this sub clip 166 | sub_end_idx (int): end idx for this sub clip 167 | Returns: 168 | frames (tensor): the frames of sampled from the video. The dimension 169 | is `num_condition_frames` x `channel` x `height` x `width`. 170 | dP_tensor (tensor): the relative poses for the entire clip. 171 | """ 172 | # For validation should get the entire segment 173 | start_seek = sub_start_idx // self._fps 174 | num_sec = (sub_end_idx - sub_start_idx) / self._fps + 1.0 175 | dP_tensor = relative_poses(trajectory_path, sub_start_idx, sub_end_idx, scale_invariance=None, pose_skip=1) 176 | 177 | real_num_frames = (sub_end_idx - sub_start_idx) 178 | frame_relative_start_idx = sub_start_idx - start_seek * self._fps 179 | num_frames = int(num_sec * self._fps + 512) # Set the num_frames to be more frames than decoded in order to get all the decoded frames 180 | frames = decode_ffmpeg(video_path, start_seek=start_seek, num_sec=num_sec, num_frames=num_frames, fps=self._fps) 181 | frames = frames[frame_relative_start_idx:frame_relative_start_idx + real_num_frames] 182 | 183 | frames = frames.permute(0, 3, 1, 2) / 255. 184 | frames = torch.stack([self.transform(f) for f in frames]) 185 | 186 | return frames, dP_tensor 187 | 188 | 189 | def __getitem__(self, index): 190 | """ 191 | If self.mode is train, randomnly choose a self.num_condition_frames + 192 | self.num_pose_prediction frame clip. If self.mode is test, choose the 193 | entire video segment. 
193 | If the video cannot be fetched and decoded
194 | successfully, find a random video that can be decoded as a replacement.
195 | Args:
196 | index (int): the video index provided by the pytorch sampler. (not used)
197 | Returns:
198 | frames (tensor): the frames sampled from the video. The dimension
199 | is `num_condition_frames` x `channel` x `height` x `width`.
200 | dP_tensor (tensor): the relative poses for the sampled clip (train mode).
201 | In test mode the video path and trajectory path are returned instead of
202 | frames and poses, so the whole segment can be evaluated in sub clips.
203 | start_idx (int): the start index of this clip
204 | end_idx (int): the end index of this clip
205 | """
206 | for i_try in range(self._num_retries):
207 | if self.mode == 'train': # If 'train', sample a random video but if 'test' use index
208 | index = random.randint(0, len(self._path_to_videos) - 1)
209 | 
210 | video_path = self._path_to_videos[index]
211 | trajectory_path = self._path_to_trajectories[index]
212 | start_time = self._start_times[index]
213 | end_time = self._end_times[index]
214 | stride = self._stride[index]
215 | 
216 | # Decode Video
217 | try:
218 | if self.mode == 'train':
219 | # For training randomly select start of clip
220 | start_seek = np.random.randint(start_time, max(end_time - self._num_sec, 1))
221 | start_idx = (start_seek * self._fps)
222 | end_idx = start_idx + self._num_condition_frames + self._num_pose_prediction
223 | 
224 | if self._scale_invariance == 'dir':
225 | dP_tensor, _ = relative_poses(trajectory_path, start_idx + self._num_condition_frames, end_idx, self._scale_invariance, self._pose_skip)
226 | 
227 | frames = decode_ffmpeg(video_path, start_seek=start_seek, num_sec=self._num_sec, num_frames=self._num_frames, fps=self._fps)
228 | elif self.mode == 'test':
229 | # For validation, use the entire segment
230 | start_idx = start_time * self._fps
231 | num_sub_clips = ((end_time - start_time) * self._fps) // (self._num_condition_frames + self._num_pose_prediction)
232 | end_idx = start_idx + num_sub_clips * (self._num_condition_frames + self._num_pose_prediction)
233 | return video_path, trajectory_path, start_idx, end_idx
234 | 
235 | if frames.shape[0] == 0:
236 | raise ValueError('Decoder Error, 0 frames decoded at video path {video_path}, start_seek: {start_seek}, num_sec: {num_sec}, self._num_frames: {num_frames}, fps: {fps}, start_time: {start_time}, end_time: {end_time}'.format(video_path=video_path, start_seek=start_seek, num_sec=self._num_sec, num_frames=self._num_frames, fps=self._fps, start_time=start_time, end_time=end_time))
237 | except Exception as e:
238 | print(
239 | "Failed to decode video idx {} from {} with error {}".format(
240 | index, video_path, e
241 | )
242 | )
243 | # Random selection logic is in __getitem__, so a random video will be decoded on retry
244 | return self.__getitem__(0)
245 | 
246 | # Keep just the first num_condition_frames frames
247 | if self.mode == 'train':
248 | frames = frames[:self._num_condition_frames]
249 | elif self.mode == 'test':
250 | frames = frames[start_idx:end_idx]
251 | frames = frames.permute(0, 3, 1, 2) / 255.
252 | frames = torch.stack([self.transform(f) for f in frames])
253 | return frames, dP_tensor.to(torch.float32), start_idx, end_idx
254 | else:
255 | raise RuntimeError(
256 | "Failed to fetch video after {} retries.".format(self._num_retries)
257 | )
258 | 
259 | 
260 | def __len__(self):
261 | """
262 | Returns:
263 | (int): the number of videos in the dataset.
264 | """ 265 | if self.mode == 'train': 266 | return 20000 267 | elif self.mode == 'test': 268 | return self.num_videos 269 | 270 | 271 | @property 272 | def num_videos(self): 273 | """ 274 | Returns: 275 | (int): the number of videos in the dataset. 276 | """ 277 | return len(self._path_to_videos) -------------------------------------------------------------------------------- /mixup.py: -------------------------------------------------------------------------------- 1 | """ Mixup and Cutmix 2 | 3 | Papers: 4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 5 | 6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) 7 | 8 | Code Reference: 9 | CutMix: https://github.com/clovaai/CutMix-PyTorch 10 | 11 | Hacked together by / Copyright 2019, Ross Wightman 12 | """ 13 | import numpy as np 14 | import torch 15 | 16 | 17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): 18 | x = x.long().view(-1, 1) 19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) 20 | 21 | 22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): 23 | off_value = smoothing / num_classes 24 | on_value = 1. - smoothing + off_value 25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 27 | return y1 * lam + y2 * (1. - lam) 28 | 29 | 30 | def rand_bbox(img_shape, lam, margin=0., count=None): 31 | """ Standard CutMix bounding-box 32 | Generates a random square bbox based on lambda value. This impl includes 33 | support for enforcing a border margin as percent of bbox dimensions. 34 | 35 | Args: 36 | img_shape (tuple): Image shape as tuple 37 | lam (float): Cutmix lambda value 38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 39 | count (int): Number of bbox to generate 40 | """ 41 | ratio = np.sqrt(1 - lam) 42 | img_h, img_w = img_shape[-2:] 43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 47 | yl = np.clip(cy - cut_h // 2, 0, img_h) 48 | yh = np.clip(cy + cut_h // 2, 0, img_h) 49 | xl = np.clip(cx - cut_w // 2, 0, img_w) 50 | xh = np.clip(cx + cut_w // 2, 0, img_w) 51 | return yl, yh, xl, xh 52 | 53 | 54 | def rand_bbox_minmax(img_shape, minmax, count=None): 55 | """ Min-Max CutMix bounding-box 56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox 57 | based on min/max percent values applied to each dimension of the input image. 58 | 59 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. 
60 | 61 | Args: 62 | img_shape (tuple): Image shape as tuple 63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size) 64 | count (int): Number of bbox to generate 65 | """ 66 | assert len(minmax) == 2 67 | img_h, img_w = img_shape[-2:] 68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) 69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) 70 | yl = np.random.randint(0, img_h - cut_h, size=count) 71 | xl = np.random.randint(0, img_w - cut_w, size=count) 72 | yu = yl + cut_h 73 | xu = xl + cut_w 74 | return yl, yu, xl, xu 75 | 76 | 77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): 78 | """ Generate bbox and apply lambda correction. 79 | """ 80 | if ratio_minmax is not None: 81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) 82 | else: 83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 84 | if correct_lam or ratio_minmax is not None: 85 | bbox_area = (yu - yl) * (xu - xl) 86 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) 87 | return (yl, yu, xl, xu), lam 88 | 89 | 90 | class Mixup: 91 | """ Mixup/Cutmix that applies different params to each element or whole batch 92 | 93 | Args: 94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. 96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 97 | prob (float): probability of applying mixup or cutmix per batch or element 98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active 99 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element) 100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders 101 | label_smoothing (float): apply label smoothing to the mixed target tensor 102 | num_classes (int): number of classes for target 103 | """ 104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, 105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): 106 | self.mixup_alpha = mixup_alpha 107 | self.cutmix_alpha = cutmix_alpha 108 | self.cutmix_minmax = cutmix_minmax 109 | if self.cutmix_minmax is not None: 110 | assert len(self.cutmix_minmax) == 2 111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe 112 | self.cutmix_alpha = 1.0 113 | self.mix_prob = prob 114 | self.switch_prob = switch_prob 115 | self.label_smoothing = label_smoothing 116 | self.num_classes = num_classes 117 | self.mode = mode 118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix 119 | self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) 120 | 121 | def _params_per_elem(self, batch_size): 122 | lam = np.ones(batch_size, dtype=np.float32) 123 | use_cutmix = np.zeros(batch_size, dtype=np.bool) 124 | if self.mixup_enabled: 125 | if self.mixup_alpha > 0. 
and self.cutmix_alpha > 0.: 126 | use_cutmix = np.random.rand(batch_size) < self.switch_prob 127 | lam_mix = np.where( 128 | use_cutmix, 129 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), 130 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) 131 | elif self.mixup_alpha > 0.: 132 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) 133 | elif self.cutmix_alpha > 0.: 134 | use_cutmix = np.ones(batch_size, dtype=np.bool) 135 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) 136 | else: 137 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 138 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) 139 | return lam, use_cutmix 140 | 141 | def _params_per_batch(self): 142 | lam = 1. 143 | use_cutmix = False 144 | if self.mixup_enabled and np.random.rand() < self.mix_prob: 145 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: 146 | use_cutmix = np.random.rand() < self.switch_prob 147 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ 148 | np.random.beta(self.mixup_alpha, self.mixup_alpha) 149 | elif self.mixup_alpha > 0.: 150 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) 151 | elif self.cutmix_alpha > 0.: 152 | use_cutmix = True 153 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 154 | else: 155 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 156 | lam = float(lam_mix) 157 | return lam, use_cutmix 158 | 159 | def _mix_elem(self, x): 160 | batch_size = len(x) 161 | lam_batch, use_cutmix = self._params_per_elem(batch_size) 162 | x_orig = x.clone() # need to keep an unmodified original for mixing source 163 | for i in range(batch_size): 164 | j = batch_size - i - 1 165 | lam = lam_batch[i] 166 | if lam != 1.: 167 | if use_cutmix[i]: 168 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 169 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 170 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh] 171 | lam_batch[i] = lam 172 | else: 173 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 174 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 175 | 176 | def _mix_pair(self, x): 177 | batch_size = len(x) 178 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 179 | x_orig = x.clone() # need to keep an unmodified original for mixing source 180 | for i in range(batch_size // 2): 181 | j = batch_size - i - 1 182 | lam = lam_batch[i] 183 | if lam != 1.: 184 | if use_cutmix[i]: 185 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 186 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 187 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] 188 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] 189 | lam_batch[i] = lam 190 | else: 191 | x[i] = x[i] * lam + x_orig[j] * (1 - lam) 192 | x[j] = x[j] * lam + x_orig[i] * (1 - lam) 193 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 194 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) 195 | 196 | def _mix_batch(self, x): 197 | lam, use_cutmix = self._params_per_batch() 198 | if lam == 1.: 199 | return 1. 
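# Added clarification: the whole-batch variant below mixes the batch with its own
# reversed order. For CutMix a single box is cut from x.flip(0) and pasted into
# every sample; otherwise plain Mixup is applied in place as
#   x = lam * x + (1 - lam) * x.flip(0).
# The returned lam is later used by mixup_target to blend the one-hot labels the
# same way, e.g. with lam = 0.7 and no smoothing the target becomes
#   0.7 * one_hot(y) + 0.3 * one_hot(y.flip(0)).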
200 | if use_cutmix: 201 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 202 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 203 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh] 204 | else: 205 | x_flipped = x.flip(0).mul_(1. - lam) 206 | x.mul_(lam).add_(x_flipped) 207 | return lam 208 | 209 | def __call__(self, x, target): 210 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 211 | if self.mode == 'elem': 212 | lam = self._mix_elem(x) 213 | elif self.mode == 'pair': 214 | lam = self._mix_pair(x) 215 | else: 216 | lam = self._mix_batch(x) 217 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) 218 | return x, target 219 | 220 | 221 | class FastCollateMixup(Mixup): 222 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch 223 | 224 | A Mixup impl that's performed while collating the batches. 225 | """ 226 | 227 | def _mix_elem_collate(self, output, batch, half=False): 228 | batch_size = len(batch) 229 | num_elem = batch_size // 2 if half else batch_size 230 | assert len(output) == num_elem 231 | lam_batch, use_cutmix = self._params_per_elem(num_elem) 232 | for i in range(num_elem): 233 | j = batch_size - i - 1 234 | lam = lam_batch[i] 235 | mixed = batch[i][0] 236 | if lam != 1.: 237 | if use_cutmix[i]: 238 | if not half: 239 | mixed = mixed.copy() 240 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 241 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 242 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] 243 | lam_batch[i] = lam 244 | else: 245 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 246 | np.rint(mixed, out=mixed) 247 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 248 | if half: 249 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) 250 | return torch.tensor(lam_batch).unsqueeze(1) 251 | 252 | def _mix_pair_collate(self, output, batch): 253 | batch_size = len(batch) 254 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) 255 | for i in range(batch_size // 2): 256 | j = batch_size - i - 1 257 | lam = lam_batch[i] 258 | mixed_i = batch[i][0] 259 | mixed_j = batch[j][0] 260 | assert 0 <= lam <= 1.0 261 | if lam < 1.: 262 | if use_cutmix[i]: 263 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 264 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 265 | patch_i = mixed_i[:, yl:yh, xl:xh].copy() 266 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] 267 | mixed_j[:, yl:yh, xl:xh] = patch_i 268 | lam_batch[i] = lam 269 | else: 270 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) 271 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) 272 | mixed_i = mixed_temp 273 | np.rint(mixed_j, out=mixed_j) 274 | np.rint(mixed_i, out=mixed_i) 275 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) 276 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) 277 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) 278 | return torch.tensor(lam_batch).unsqueeze(1) 279 | 280 | def _mix_batch_collate(self, output, batch): 281 | batch_size = len(batch) 282 | lam, use_cutmix = self._params_per_batch() 283 | if use_cutmix: 284 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 285 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 286 | for i in range(batch_size): 287 | j = batch_size - i - 1 288 | mixed = batch[i][0] 
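# Added note: `batch` here is the raw list of (sample, label) pairs coming straight
# from the dataset, with samples still stored as uint8 numpy arrays. Mixing is done
# on those arrays and the result is written into the preallocated uint8 `output`
# tensor, which is what lets this collate-time variant avoid a separate float
# conversion pass (normalization typically happens later, e.g. in a GPU prefetcher).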
289 | if lam != 1.: 290 | if use_cutmix: 291 | mixed = mixed.copy() # don't want to modify the original while iterating 292 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh] 293 | else: 294 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) 295 | np.rint(mixed, out=mixed) 296 | output[i] += torch.from_numpy(mixed.astype(np.uint8)) 297 | return lam 298 | 299 | def __call__(self, batch, _=None): 300 | batch_size = len(batch) 301 | assert batch_size % 2 == 0, 'Batch size should be even when using this' 302 | half = 'half' in self.mode 303 | if half: 304 | batch_size //= 2 305 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) 306 | if self.mode == 'elem' or self.mode == 'half': 307 | lam = self._mix_elem_collate(output, batch, half=half) 308 | elif self.mode == 'pair': 309 | lam = self._mix_pair_collate(output, batch) 310 | else: 311 | lam = self._mix_batch_collate(output, batch) 312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64) 313 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') 314 | target = target[:batch_size] 315 | return output, target 316 | -------------------------------------------------------------------------------- /ssv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torchvision import transforms 5 | from random_erasing import RandomErasing 6 | import warnings 7 | from decord import VideoReader, cpu 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | import video_transforms as video_transforms 11 | import volume_transforms as volume_transforms 12 | 13 | 14 | class SSVideoClsDataset(Dataset): 15 | """Load your own video classification dataset.""" 16 | 17 | def __init__(self, anno_path, data_path, mode='train', clip_len=8, 18 | crop_size=224, short_side_size=256, new_height=256, 19 | new_width=340, keep_aspect_ratio=True, num_segment=1, 20 | num_crop=1, test_num_segment=10, test_num_crop=3, args=None): 21 | self.anno_path = anno_path 22 | self.data_path = data_path 23 | self.mode = mode 24 | self.clip_len = clip_len 25 | self.crop_size = crop_size 26 | self.short_side_size = short_side_size 27 | self.new_height = new_height 28 | self.new_width = new_width 29 | self.keep_aspect_ratio = keep_aspect_ratio 30 | self.num_segment = num_segment 31 | self.test_num_segment = test_num_segment 32 | self.num_crop = num_crop 33 | self.test_num_crop = test_num_crop 34 | self.args = args 35 | self.aug = False 36 | self.rand_erase = False 37 | 38 | if self.mode in ['train']: 39 | self.aug = True 40 | if self.args.reprob > 0: 41 | self.rand_erase = True 42 | if VideoReader is None: 43 | raise ImportError("Unable to import `decord` which is required to read videos.") 44 | 45 | import pandas as pd 46 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ') 47 | self.dataset_samples = list(cleaned.values[:, 0]) 48 | self.label_array = list(cleaned.values[:, 1]) 49 | if self.data_path is not None: 50 | self.dataset_samples = [os.path.join(self.data_path, p) for p in self.dataset_samples] 51 | 52 | 53 | if (mode == 'train'): 54 | pass 55 | 56 | elif (mode == 'validation'): 57 | self.data_transform = video_transforms.Compose([ 58 | video_transforms.Resize(self.short_side_size, interpolation='bilinear'), 59 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)), 60 | volume_transforms.ClipToTensor(), 61 | 
video_transforms.Normalize(mean=[0.485, 0.456, 0.406], 62 | std=[0.229, 0.224, 0.225]) 63 | ]) 64 | elif mode == 'test': 65 | self.data_resize = video_transforms.Compose([ 66 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear') 67 | ]) 68 | self.data_transform = video_transforms.Compose([ 69 | volume_transforms.ClipToTensor(), 70 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406], 71 | std=[0.229, 0.224, 0.225]) 72 | ]) 73 | self.test_seg = [] 74 | self.test_dataset = [] 75 | self.test_label_array = [] 76 | for ck in range(self.test_num_segment): 77 | for cp in range(self.test_num_crop): 78 | for idx in range(len(self.label_array)): 79 | sample_label = self.label_array[idx] 80 | self.test_label_array.append(sample_label) 81 | self.test_dataset.append(self.dataset_samples[idx]) 82 | self.test_seg.append((ck, cp)) 83 | 84 | def __getitem__(self, index): 85 | if self.mode == 'train': 86 | args = self.args 87 | scale_t = 1 88 | 89 | sample = self.dataset_samples[index] 90 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C 91 | if len(buffer) == 0: 92 | while len(buffer) == 0: 93 | warnings.warn("video {} not correctly loaded during training".format(sample)) 94 | index = np.random.randint(self.__len__()) 95 | sample = self.dataset_samples[index] 96 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) 97 | 98 | if args.num_sample > 1: 99 | frame_list = [] 100 | label_list = [] 101 | index_list = [] 102 | for _ in range(args.num_sample): 103 | new_frames = self._aug_frame(buffer, args) 104 | label = self.label_array[index] 105 | frame_list.append(new_frames) 106 | label_list.append(label) 107 | index_list.append(index) 108 | return frame_list, label_list, index_list, {} 109 | else: 110 | buffer = self._aug_frame(buffer, args) 111 | return buffer, self.label_array[index], index, {} 112 | 113 | elif self.mode == 'validation': 114 | sample = self.dataset_samples[index] 115 | buffer = self.loadvideo_decord(sample) 116 | if len(buffer) == 0: 117 | while len(buffer) == 0: 118 | warnings.warn("video {} not correctly loaded during validation".format(sample)) 119 | index = np.random.randint(self.__len__()) 120 | sample = self.dataset_samples[index] 121 | buffer = self.loadvideo_decord(sample) 122 | buffer = self.data_transform(buffer) 123 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0] 124 | 125 | elif self.mode == 'test': 126 | sample = self.test_dataset[index] 127 | chunk_nb, split_nb = self.test_seg[index] 128 | buffer = self.loadvideo_decord(sample) 129 | 130 | while len(buffer) == 0: 131 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\ 132 | str(self.test_dataset[index]), chunk_nb, split_nb)) 133 | index = np.random.randint(self.__len__()) 134 | sample = self.test_dataset[index] 135 | chunk_nb, split_nb = self.test_seg[index] 136 | buffer = self.loadvideo_decord(sample) 137 | 138 | buffer = self.data_resize(buffer) 139 | if isinstance(buffer, list): 140 | buffer = np.stack(buffer, 0) 141 | 142 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \ 143 | / (self.test_num_crop - 1) 144 | temporal_start = chunk_nb # 0/1 145 | spatial_start = int(split_nb * spatial_step) 146 | if buffer.shape[1] >= buffer.shape[2]: 147 | buffer = buffer[temporal_start::2, \ 148 | spatial_start:spatial_start + self.short_side_size, :, :] 149 | else: 150 | buffer = buffer[temporal_start::2, \ 151 | :, spatial_start:spatial_start + self.short_side_size, :] 
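# Added note on the test-time protocol above: __init__ enumerates every
# (temporal chunk, spatial crop) pair as a separate dataset entry, so each video is
# evaluated test_num_segment x test_num_crop times. chunk_nb picks the temporal
# offset (frames are then taken with a stride of 2), while split_nb slides a
# short_side_size window along the longer spatial dimension by spatial_step.
# Scores from all views are typically averaged downstream to produce the final
# video-level prediction.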
152 | 153 | buffer = self.data_transform(buffer) 154 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \ 155 | chunk_nb, split_nb 156 | else: 157 | raise NameError('mode {} unkown'.format(self.mode)) 158 | 159 | def _aug_frame( 160 | self, 161 | buffer, 162 | args, 163 | ): 164 | aug_transform = video_transforms.create_random_augment( 165 | input_size=(self.crop_size, self.crop_size), 166 | auto_augment=args.aa, 167 | interpolation=args.train_interpolation, 168 | ) 169 | 170 | buffer = [ 171 | transforms.ToPILImage()(frame) for frame in buffer 172 | ] 173 | 174 | buffer = aug_transform(buffer) 175 | 176 | buffer = [transforms.ToTensor()(img) for img in buffer] 177 | buffer = torch.stack(buffer) # T C H W 178 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 179 | 180 | # T H W C 181 | buffer = tensor_normalize( 182 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 183 | ) 184 | # T H W C -> C T H W. 185 | buffer = buffer.permute(3, 0, 1, 2) 186 | # Perform data augmentation. 187 | scl, asp = ( 188 | [0.08, 1.0], 189 | [0.75, 1.3333], 190 | ) 191 | 192 | buffer = spatial_sampling( 193 | buffer, 194 | spatial_idx=-1, 195 | min_scale=256, 196 | max_scale=320, 197 | crop_size=self.crop_size, 198 | random_horizontal_flip=False if args.data_set == 'SSV2' else True, 199 | inverse_uniform_sampling=False, 200 | aspect_ratio=asp, 201 | scale=scl, 202 | motion_shift=False 203 | ) 204 | 205 | if self.rand_erase: 206 | erase_transform = RandomErasing( 207 | args.reprob, 208 | mode=args.remode, 209 | max_count=args.recount, 210 | num_splits=args.recount, 211 | device="cpu", 212 | ) 213 | buffer = buffer.permute(1, 0, 2, 3) 214 | buffer = erase_transform(buffer) 215 | buffer = buffer.permute(1, 0, 2, 3) 216 | 217 | return buffer 218 | 219 | 220 | def loadvideo_decord(self, sample, sample_rate_scale=1): 221 | """Load video content using Decord""" 222 | fname = sample 223 | 224 | if not (os.path.exists(fname)): 225 | return [] 226 | 227 | # avoid hanging issue 228 | if os.path.getsize(fname) < 1 * 1024: 229 | print('SKIP: ', fname, " - ", os.path.getsize(fname)) 230 | return [] 231 | try: 232 | if self.keep_aspect_ratio: 233 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0)) 234 | else: 235 | vr = VideoReader(fname, width=self.new_width, height=self.new_height, 236 | num_threads=1, ctx=cpu(0)) 237 | except: 238 | print("video cannot be loaded by decord: ", fname) 239 | return [] 240 | 241 | if self.mode == 'test': 242 | all_index = [] 243 | tick = float(len(vr) - 1) / float(self.num_segment) 244 | interval_cross_view = tick / self.test_num_segment 245 | for i in range(self.num_segment): 246 | for ci in range(self.test_num_segment): 247 | start = int(np.round(tick * i)) 248 | end = int(np.round(tick * (i + 1))) 249 | if self.test_num_segment > 1: 250 | start = int(np.round(tick * i + interval_cross_view * ci)) 251 | end = int(np.round(tick * i + interval_cross_view * (ci + 1))) 252 | all_index.append((start + end) // 2) 253 | else: 254 | all_index.append((start + end) // 2) 255 | all_index.append((start + end) // 2) 256 | 257 | vr.seek(0) 258 | buffer = vr.get_batch(all_index).asnumpy() 259 | return buffer 260 | 261 | # handle temporal segments 262 | average_duration = float(len(vr) - 1) / self.num_segment 263 | all_index = [] 264 | for i in range(self.num_segment): 265 | start = int(np.round(average_duration * i)) 266 | end = int(np.round(average_duration * (i + 1))) 267 | all_index.append(int(np.random.randint(start, end + 1))) 268 | 269 | all_index = 
list(np.array(all_index)) 270 | vr.seek(0) 271 | buffer = vr.get_batch(all_index).asnumpy() 272 | return buffer 273 | 274 | def __len__(self): 275 | if self.mode != 'test': 276 | return len(self.dataset_samples) 277 | else: 278 | return len(self.test_dataset) 279 | 280 | 281 | def spatial_sampling( 282 | frames, 283 | spatial_idx=-1, 284 | min_scale=256, 285 | max_scale=320, 286 | crop_size=224, 287 | random_horizontal_flip=True, 288 | inverse_uniform_sampling=False, 289 | aspect_ratio=None, 290 | scale=None, 291 | motion_shift=False, 292 | ): 293 | """ 294 | Perform spatial sampling on the given video frames. If spatial_idx is 295 | -1, perform random scale, random crop, and random flip on the given 296 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 297 | with the given spatial_idx. 298 | Args: 299 | frames (tensor): frames of images sampled from the video. The 300 | dimension is `num frames` x `height` x `width` x `channel`. 301 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 302 | or 2, perform left, center, right crop if width is larger than 303 | height, and perform top, center, buttom crop if height is larger 304 | than width. 305 | min_scale (int): the minimal size of scaling. 306 | max_scale (int): the maximal size of scaling. 307 | crop_size (int): the size of height and width used to crop the 308 | frames. 309 | inverse_uniform_sampling (bool): if True, sample uniformly in 310 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 311 | scale. If False, take a uniform sample from [min_scale, 312 | max_scale]. 313 | aspect_ratio (list): Aspect ratio range for resizing. 314 | scale (list): Scale range for resizing. 315 | motion_shift (bool): Whether to apply motion shift for resizing. 316 | Returns: 317 | frames (tensor): spatially sampled frames. 318 | """ 319 | assert spatial_idx in [-1, 0, 1, 2] 320 | if spatial_idx == -1: 321 | if aspect_ratio is None and scale is None: 322 | frames, _ = video_transforms.random_short_side_scale_jitter( 323 | images=frames, 324 | min_size=min_scale, 325 | max_size=max_scale, 326 | inverse_uniform_sampling=inverse_uniform_sampling, 327 | ) 328 | frames, _ = video_transforms.random_crop(frames, crop_size) 329 | else: 330 | transform_func = ( 331 | video_transforms.random_resized_crop_with_shift 332 | if motion_shift 333 | else video_transforms.random_resized_crop 334 | ) 335 | frames = transform_func( 336 | images=frames, 337 | target_height=crop_size, 338 | target_width=crop_size, 339 | scale=scale, 340 | ratio=aspect_ratio, 341 | ) 342 | if random_horizontal_flip: 343 | frames, _ = video_transforms.horizontal_flip(0.5, frames) 344 | else: 345 | # The testing is deterministic and no jitter should be performed. 346 | # min_scale, max_scale, and crop_size are expect to be the same. 347 | assert len({min_scale, max_scale, crop_size}) == 1 348 | frames, _ = video_transforms.random_short_side_scale_jitter( 349 | frames, min_scale, max_scale 350 | ) 351 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx) 352 | return frames 353 | 354 | 355 | def tensor_normalize(tensor, mean, std): 356 | """ 357 | Normalize a given tensor by subtracting the mean and dividing the std. 358 | Args: 359 | tensor (tensor): tensor to normalize. 360 | mean (tensor or list): mean value to subtract. 361 | std (tensor or list): std to divide. 
362 | """ 363 | if tensor.dtype == torch.uint8: 364 | tensor = tensor.float() 365 | tensor = tensor / 255.0 366 | if type(mean) == list: 367 | mean = torch.tensor(mean) 368 | if type(std) == list: 369 | std = torch.tensor(std) 370 | tensor = tensor - mean 371 | tensor = tensor / std 372 | return tensor 373 | -------------------------------------------------------------------------------- /modeling_student.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from functools import partial, reduce 7 | from operator import mul 8 | from einops import rearrange 9 | import torch.utils.checkpoint as checkpoint 10 | 11 | from modeling_finetune import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table, get_3d_sincos_pos_embed 12 | from timm.models.registry import register_model 13 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_ 14 | from timm.models.layers import drop_path, to_2tuple 15 | 16 | 17 | def trunc_normal_(tensor, mean=0., std=1.): 18 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 19 | 20 | 21 | __all__ = [ 22 | 'pretrain_masked_video_student_small_patch16_224', 23 | 'pretrain_masked_video_student_base_patch16_224', 24 | 'pretrain_masked_video_student_large_patch16_224', 25 | 'pretrain_masked_video_student_huge_patch16_224', 26 | ] 27 | 28 | 29 | class DropPath(nn.Module): 30 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 31 | """ 32 | 33 | def __init__(self, drop_prob=None): 34 | super(DropPath, self).__init__() 35 | self.drop_prob = drop_prob 36 | 37 | def forward(self, x): 38 | return drop_path(x, self.drop_prob, self.training) 39 | 40 | def extra_repr(self) -> str: 41 | return 'p={}'.format(self.drop_prob) 42 | 43 | 44 | class Mlp(nn.Module): 45 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 46 | super().__init__() 47 | out_features = out_features or in_features 48 | hidden_features = hidden_features or in_features 49 | self.fc1 = nn.Linear(in_features, hidden_features) 50 | self.act = act_layer() 51 | self.fc2 = nn.Linear(hidden_features, out_features) 52 | self.drop = nn.Dropout(drop) 53 | 54 | def forward(self, x): 55 | x = self.fc1(x) 56 | x = self.act(x) 57 | x = self.fc2(x) 58 | x = self.drop(x) 59 | return x 60 | 61 | 62 | class PretrainVisionTransformerDecoder(nn.Module): 63 | """ Vision Transformer with support for patch or hybrid CNN input stage 64 | """ 65 | 66 | def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12, 67 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 68 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, use_checkpoint=False 69 | ): 70 | super().__init__() 71 | self.num_classes = num_classes 72 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 73 | self.patch_size = patch_size 74 | self.use_checkpoint = use_checkpoint 75 | 76 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 77 | self.blocks = nn.ModuleList([ 78 | Block( 79 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 80 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 81 | init_values=init_values) 82 | for i in range(depth)]) 83 | self.norm = norm_layer(embed_dim) 84 | 
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 85 | 86 | self.apply(self._init_weights) 87 | 88 | def _init_weights(self, m): 89 | if isinstance(m, nn.Linear): 90 | nn.init.xavier_uniform_(m.weight) 91 | if isinstance(m, nn.Linear) and m.bias is not None: 92 | nn.init.constant_(m.bias, 0) 93 | elif isinstance(m, nn.LayerNorm): 94 | nn.init.constant_(m.bias, 0) 95 | nn.init.constant_(m.weight, 1.0) 96 | 97 | def get_num_layers(self): 98 | return len(self.blocks) 99 | 100 | @torch.jit.ignore 101 | def no_weight_decay(self): 102 | return {'pos_embed', 'cls_token'} 103 | 104 | def get_classifier(self): 105 | return self.head 106 | 107 | def reset_classifier(self, num_classes, global_pool=''): 108 | self.num_classes = num_classes 109 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 110 | 111 | def forward(self, x, return_token_num): 112 | if self.use_checkpoint: 113 | for blk in self.blocks: 114 | x = checkpoint.checkpoint(blk, x) 115 | else: 116 | for blk in self.blocks: 117 | x = blk(x) 118 | 119 | if return_token_num > 0: 120 | x = self.head(self.norm(x[:, -return_token_num:])) # only return the mask tokens predict pixels 121 | else: 122 | x = self.head(self.norm(x)) 123 | 124 | return x 125 | 126 | 127 | class PretrainMaskedVideoStudent(nn.Module): 128 | 129 | def __init__(self, 130 | img_size=224, 131 | patch_size=16, 132 | encoder_in_chans=3, 133 | encoder_embed_dim=768, 134 | encoder_depth=12, 135 | encoder_num_heads=12, 136 | decoder_depth=4, 137 | feat_decoder_embed_dim=None, 138 | feat_decoder_num_heads=None, 139 | mlp_ratio=4., 140 | qkv_bias=False, 141 | qk_scale=None, 142 | drop_rate=0., 143 | attn_drop_rate=0., 144 | drop_path_rate=0., 145 | norm_layer=nn.LayerNorm, 146 | init_values=0., 147 | tubelet_size=2, 148 | num_frames=16, 149 | use_cls_token=False, 150 | target_feature_dim=768, 151 | target_video_feature_dim=768, 152 | use_checkpoint=False, 153 | ): 154 | super().__init__() 155 | self.use_cls_token = use_cls_token 156 | self.patch_embed = PatchEmbed( 157 | img_size=img_size, patch_size=patch_size, in_chans=encoder_in_chans, 158 | embed_dim=encoder_embed_dim, tubelet_size=tubelet_size, num_frames=num_frames) 159 | self.patch_size = self.patch_embed.patch_size 160 | num_patches = self.patch_embed.num_patches 161 | self.encoder_embed_dim = encoder_embed_dim 162 | self.tubelet_size = tubelet_size 163 | self.num_frames = num_frames 164 | self.use_checkpoint = use_checkpoint 165 | 166 | if use_cls_token: 167 | self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_embed_dim)) 168 | else: 169 | self.cls_token = None 170 | 171 | # sine-cosine positional embeddings 172 | self.pos_embed = get_3d_sincos_pos_embed(embed_dim=encoder_embed_dim, 173 | grid_size=self.patch_embed.num_patches_h, 174 | t_size=self.patch_embed.num_patches_t) 175 | self.pos_embed = nn.Parameter(self.pos_embed, requires_grad=False) 176 | self.pos_embed.requires_grad = False 177 | 178 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_depth)] # stochastic depth decay rule 179 | self.blocks = nn.ModuleList([ 180 | Block( 181 | dim=encoder_embed_dim, num_heads=encoder_num_heads, 182 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 183 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 184 | init_values=init_values) 185 | for i in range(encoder_depth)]) 186 | self.norm = norm_layer(encoder_embed_dim) 187 | 188 | if feat_decoder_embed_dim is None: 189 | 
feat_decoder_embed_dim = encoder_embed_dim 190 | if feat_decoder_num_heads is None: 191 | feat_decoder_num_heads = encoder_num_heads 192 | self.mask_token_img = nn.Parameter(torch.zeros(1, 1, feat_decoder_embed_dim)) 193 | self.down_img = nn.Linear(encoder_embed_dim, feat_decoder_embed_dim) 194 | self.decoder_img = PretrainVisionTransformerDecoder( 195 | patch_size=patch_size, 196 | num_classes=target_feature_dim, 197 | embed_dim=feat_decoder_embed_dim, 198 | depth=decoder_depth, 199 | num_heads=feat_decoder_num_heads, 200 | mlp_ratio=mlp_ratio, 201 | qkv_bias=qkv_bias, 202 | qk_scale=qk_scale, 203 | drop_rate=drop_rate, 204 | attn_drop_rate=attn_drop_rate, 205 | drop_path_rate=drop_path_rate, 206 | norm_layer=norm_layer, 207 | init_values=init_values, 208 | use_checkpoint=use_checkpoint, 209 | ) 210 | self.pos_embed_img = get_3d_sincos_pos_embed( 211 | embed_dim=feat_decoder_embed_dim, 212 | grid_size=self.patch_embed.num_patches_h, 213 | t_size=self.patch_embed.num_patches_t 214 | ) 215 | self.pos_embed_img = nn.Parameter(self.pos_embed_img, requires_grad=False) 216 | self.pos_embed_img.requires_grad = False 217 | trunc_normal_(self.mask_token_img, std=.02) 218 | 219 | self.mask_token_vid = nn.Parameter(torch.zeros(1, 1, feat_decoder_embed_dim)) 220 | self.down_vid = nn.Linear(encoder_embed_dim, feat_decoder_embed_dim) 221 | self.decoder_vid = PretrainVisionTransformerDecoder( 222 | patch_size=patch_size, 223 | num_classes=target_video_feature_dim, 224 | embed_dim=feat_decoder_embed_dim, 225 | depth=decoder_depth, 226 | num_heads=feat_decoder_num_heads, 227 | mlp_ratio=mlp_ratio, 228 | qkv_bias=qkv_bias, 229 | qk_scale=qk_scale, 230 | drop_rate=drop_rate, 231 | attn_drop_rate=attn_drop_rate, 232 | drop_path_rate=drop_path_rate, 233 | norm_layer=norm_layer, 234 | init_values=init_values, 235 | use_checkpoint=use_checkpoint, 236 | ) 237 | self.pos_embed_vid = get_3d_sincos_pos_embed( 238 | embed_dim=feat_decoder_embed_dim, 239 | grid_size=self.patch_embed.num_patches_h, 240 | t_size=self.patch_embed.num_patches_t 241 | ) 242 | self.pos_embed_vid = nn.Parameter(self.pos_embed_vid, requires_grad=False) 243 | self.pos_embed_vid.requires_grad = False 244 | trunc_normal_(self.mask_token_vid, std=.02) 245 | 246 | self.apply(self._init_weights) 247 | 248 | if self.use_cls_token: 249 | nn.init.normal_(self.cls_token, std=1e-6) 250 | 251 | def _init_weights(self, m): 252 | if isinstance(m, nn.Linear): 253 | nn.init.xavier_uniform_(m.weight) 254 | if isinstance(m, nn.Linear) and m.bias is not None: 255 | nn.init.constant_(m.bias, 0) 256 | elif isinstance(m, nn.LayerNorm): 257 | nn.init.constant_(m.bias, 0) 258 | nn.init.constant_(m.weight, 1.0) 259 | 260 | def get_num_layers(self): 261 | return len(self.blocks) 262 | 263 | @torch.jit.ignore 264 | def no_weight_decay(self): 265 | return {'pos_embed', 'cls_token', 'mask_token'} 266 | 267 | def forward_encoder(self, x, mask): 268 | # embed patches 269 | # x: B, C, T, H, W 270 | x = self.patch_embed(x) 271 | 272 | # add pos embed w/o cls token 273 | x = x + self.pos_embed.type_as(x).detach() 274 | # x: B, L, C 275 | 276 | # masking: length -> length * mask_ratio 277 | B, _, C = x.shape 278 | x = x[~mask].reshape(B, -1, C) # ~mask means visible 279 | 280 | # append cls token 281 | if self.use_cls_token: 282 | cls_tokens = self.cls_token.expand(B, -1, -1) 283 | x = torch.cat((cls_tokens, x), dim=1) 284 | 285 | # apply Transformer blocks 286 | if self.use_checkpoint: 287 | for blk in self.blocks: 288 | x = checkpoint.checkpoint(blk, x) 289 | else: 290 | 
for blk in self.blocks: 291 | x = blk(x) 292 | 293 | x = self.norm(x) 294 | 295 | return x 296 | 297 | def forward(self, x, mask): 298 | x = self.forward_encoder(x, mask) 299 | s = 1 if self.use_cls_token else 0 300 | 301 | x_vis_img = self.down_img(x) 302 | B, N, C = x_vis_img.shape 303 | 304 | expand_pos_embed_img = self.pos_embed_img.type_as(x_vis_img).detach().expand(B, -1, -1) 305 | pos_emd_vis_img = expand_pos_embed_img[~mask].reshape(B, -1, C) 306 | pos_emd_mask_img = expand_pos_embed_img[mask].reshape(B, -1, C) 307 | x_img = torch.cat( 308 | [x_vis_img[:, s:, :] + pos_emd_vis_img, self.mask_token_img + pos_emd_mask_img], 309 | dim=1) # [B, N, C_d] 310 | x_img = torch.cat([x_vis_img[:, :s, :], x_img], dim=1) 311 | 312 | x_img = self.decoder_img(x_img, pos_emd_mask_img.shape[1]) 313 | 314 | x_vis_vid = self.down_vid(x) 315 | B, N, C = x_vis_vid.shape 316 | 317 | expand_pos_embed_vid = self.pos_embed_vid.type_as(x_vis_vid).detach().expand(B, -1, -1) 318 | pos_emd_vis_vid = expand_pos_embed_vid[~mask].reshape(B, -1, C) 319 | pos_emd_mask_vid = expand_pos_embed_vid[mask].reshape(B, -1, C) 320 | x_vid = torch.cat( 321 | [x_vis_vid[:, s:, :] + pos_emd_vis_vid, self.mask_token_vid + pos_emd_mask_vid], 322 | dim=1) # [B, N, C_d] 323 | x_vid = torch.cat([x_vis_vid[:, :s, :], x_vid], dim=1) 324 | 325 | x_vid = self.decoder_vid(x_vid, pos_emd_mask_vid.shape[1]) 326 | 327 | return x_img, x_vid 328 | 329 | 330 | @register_model 331 | def pretrain_masked_video_student_small_patch16_224(pretrained=False, **kwargs): 332 | model = PretrainMaskedVideoStudent( 333 | img_size=224, 334 | patch_size=16, 335 | encoder_embed_dim=384, 336 | encoder_depth=12, 337 | encoder_num_heads=6, 338 | mlp_ratio=4, 339 | qkv_bias=True, 340 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 341 | **kwargs) 342 | model.default_cfg = _cfg() 343 | if pretrained: 344 | checkpoint = torch.load( 345 | kwargs["init_ckpt"], map_location="cpu" 346 | ) 347 | model.load_state_dict(checkpoint["model"]) 348 | return model 349 | 350 | 351 | @register_model 352 | def pretrain_masked_video_student_base_patch16_224(pretrained=False, **kwargs): 353 | model = PretrainMaskedVideoStudent( 354 | img_size=224, 355 | patch_size=16, 356 | encoder_embed_dim=768, 357 | encoder_depth=12, 358 | encoder_num_heads=12, 359 | mlp_ratio=4, 360 | qkv_bias=True, 361 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 362 | **kwargs) 363 | model.default_cfg = _cfg() 364 | if pretrained: 365 | checkpoint = torch.load( 366 | kwargs["init_ckpt"], map_location="cpu" 367 | ) 368 | model.load_state_dict(checkpoint["model"]) 369 | return model 370 | 371 | 372 | @register_model 373 | def pretrain_masked_video_student_large_patch16_224(pretrained=False, **kwargs): 374 | model = PretrainMaskedVideoStudent( 375 | img_size=224, 376 | patch_size=16, 377 | encoder_embed_dim=1024, 378 | encoder_depth=24, 379 | encoder_num_heads=16, 380 | mlp_ratio=4, 381 | qkv_bias=True, 382 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 383 | **kwargs) 384 | model.default_cfg = _cfg() 385 | if pretrained: 386 | checkpoint = torch.load( 387 | kwargs["init_ckpt"], map_location="cpu" 388 | ) 389 | model.load_state_dict(checkpoint["model"]) 390 | return model 391 | 392 | 393 | @register_model 394 | def pretrain_masked_video_student_huge_patch16_224(pretrained=False, **kwargs): 395 | model = PretrainMaskedVideoStudent( 396 | img_size=224, 397 | patch_size=16, 398 | encoder_embed_dim=1280, 399 | encoder_depth=32, 400 | encoder_num_heads=16, 401 | mlp_ratio=4, 402 | qkv_bias=True, 403 | 
norm_layer=partial(nn.LayerNorm, eps=1e-6), 404 | **kwargs) 405 | model.default_cfg = _cfg() 406 | if pretrained: 407 | checkpoint = torch.load( 408 | kwargs["init_ckpt"], map_location="cpu" 409 | ) 410 | model.load_state_dict(checkpoint["model"]) 411 | return model 412 | --------------------------------------------------------------------------------
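
The train-time branch of `loadvideo_decord` above splits each clip into `num_segment` equal temporal spans and draws one frame index uniformly from each span. The standalone sketch below restates that sampling step so it can be inspected in isolation; the function name `sample_segment_indices` is illustrative and not part of the repository.

```python
# Standalone illustration of the train-time frame sampling in loadvideo_decord:
# the clip is split into num_segment equal spans and one frame index is drawn
# uniformly from each span. Names here are illustrative only.
import numpy as np


def sample_segment_indices(num_video_frames, num_segment, rng=np.random):
    average_duration = float(num_video_frames - 1) / num_segment
    indices = []
    for i in range(num_segment):
        start = int(np.round(average_duration * i))
        end = int(np.round(average_duration * (i + 1)))
        indices.append(int(rng.randint(start, end + 1)))  # inclusive of end
    return indices


print(sample_segment_indices(300, 16))  # 16 indices, one per temporal segment
```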
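
In test mode, `__getitem__` enumerates `test_num_segment x test_num_crop` views per video and returns the view indices `chunk_nb` and `split_nb` alongside each clip. The evaluation loop that consumes these views is not shown in this excerpt, so the snippet below is only a sketch of how per-view logits are commonly reduced to a single score per video; `merge_test_views` and its arguments are assumed names, not repository code.

```python
# Illustrative only: averaging the logits of all (chunk_nb, split_nb) views of a video.
from collections import defaultdict

import numpy as np


def merge_test_views(video_ids, view_logits):
    """video_ids: one id string per test view; view_logits: array [N, num_classes]."""
    per_video = defaultdict(list)
    for vid, logits in zip(video_ids, view_logits):
        per_video[vid].append(logits)
    return {vid: np.mean(np.stack(chunks), axis=0) for vid, chunks in per_video.items()}


if __name__ == "__main__":
    # Two videos, each seen under 2 temporal x 3 spatial = 6 views.
    ids = ["clip_a"] * 6 + ["clip_b"] * 6
    logits = np.random.randn(12, 5)
    merged = merge_test_views(ids, logits)
    print({k: v.shape for k, v in merged.items()})  # {'clip_a': (5,), 'clip_b': (5,)}
```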
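
`_aug_frame` converts the augmented clip to a float tensor, normalizes it with ImageNet statistics via `tensor_normalize`, and permutes it from `T H W C` to `C T H W` before calling `spatial_sampling`. The sketch below reproduces that shape flow inline, rather than importing the dataset module (whose import path is not shown in this excerpt), so the expected layouts are easy to verify.

```python
# Inline illustration of the normalization / layout steps used in _aug_frame:
# a uint8 clip in T x H x W x C is scaled to [0, 1], normalized with ImageNet
# statistics, then permuted to C x T x H x W. Mirrors tensor_normalize() above.
import torch

clip = torch.randint(0, 256, (16, 224, 224, 3), dtype=torch.uint8)  # T H W C

x = clip.float() / 255.0
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
x = (x - mean) / std          # broadcasts over the trailing channel dimension
x = x.permute(3, 0, 1, 2)     # T H W C -> C T H W, the layout spatial_sampling expects

print(x.shape)                # torch.Size([3, 16, 224, 224])
```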
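
The registered `pretrain_masked_video_student_*` constructors can be built through timm's model registry, and `PretrainMaskedVideoStudent.forward` takes a `B x C x T x H x W` clip plus a boolean token mask whose `True` entries mark masked patches; because visible tokens are gathered with `x[~mask].reshape(B, -1, C)`, every sample must mask the same number of tokens. Below is a minimal forward-pass sketch, assuming the repository root is on `PYTHONPATH` and the default `num_frames=16`, `tubelet_size=2`, `patch_size=16` layout (8 * 14 * 14 = 1568 tokens).

```python
# Minimal pre-training forward pass with PretrainMaskedVideoStudent.
# Assumptions: modeling_student (and its modeling_finetune dependency) are importable,
# and the default 16-frame, tubelet-2, patch-16 configuration is used.
import torch
from timm.models import create_model

import modeling_student  # noqa: F401  (registers the pretrain_* model names)

model = create_model(
    'pretrain_masked_video_student_small_patch16_224',
    pretrained=False,
    decoder_depth=4,
)

B, num_tokens = 2, 8 * 14 * 14
num_masked = int(0.9 * num_tokens)

# Boolean mask (True = masked) with the same count of masked tokens per sample,
# matching the x[~mask].reshape(B, -1, C) gather in forward_encoder.
mask = torch.zeros(B, num_tokens, dtype=torch.bool)
for b in range(B):
    mask[b, torch.randperm(num_tokens)[:num_masked]] = True

video = torch.randn(B, 3, 16, 224, 224)   # B C T H W
pred_img, pred_vid = model(video, mask)   # predictions for the masked tokens only
print(pred_img.shape, pred_vid.shape)     # [B, num_masked, target_feature_dim], likewise for the video head
```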