├── DATASET.md
├── images
│   ├── teaser.png
│   ├── object_interaction_dataframe.png
│   └── locomotion_prediction_dataframe.png
├── decoder
│   └── utils.py
├── INSTALL.md
├── masking_generator.py
├── dpvo
│   └── plot_utils.py
├── functional.py
├── VPP.md
├── README.md
├── volume_transforms.py
├── VIP.md
├── LP.md
├── random_erasing.py
├── engine_for_pretraining.py
├── modeling_video_teacher.py
├── optim_factory.py
├── engine_locomotion_prediction_for_finetuning.py
├── object_interaction
│   ├── object_interaction_utils.py
│   └── object_interaction_dataloader.py
├── locomotion_prediction
│   ├── locomotion_prediction_utils.py
│   └── locomotion_prediction_dataloader.py
├── modeling_teacher.py
├── datasets.py
├── transforms.py
├── cms_dataset.py
├── mixup.py
├── ssv2.py
└── modeling_student.py
/DATASET.md:
--------------------------------------------------------------------------------
1 | # Data Preparation
2 |
--------------------------------------------------------------------------------
/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DannyTran123/egopet/HEAD/images/teaser.png
--------------------------------------------------------------------------------
/images/object_interaction_dataframe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DannyTran123/egopet/HEAD/images/object_interaction_dataframe.png
--------------------------------------------------------------------------------
/images/locomotion_prediction_dataframe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DannyTran123/egopet/HEAD/images/locomotion_prediction_dataframe.png
--------------------------------------------------------------------------------
/decoder/utils.py:
--------------------------------------------------------------------------------
1 | import ffmpeg
2 | import numpy as np
3 | import torch
4 |
5 | def decode_ffmpeg(video_path, start_seek, num_sec=2, num_frames=16, fps=5):
6 |
7 | probe = ffmpeg.probe(video_path, v='error', select_streams='v:0', show_entries='stream=width,height,duration,r_frame_rate')
8 | video_info = next((s for s in probe['streams'] if 'width' in s and 'height' in s), None)
9 |
10 | if video_info is None:
11 | raise ValueError("No video stream information found in the input video.")
12 |
13 | width = int(video_info['width'])
14 | height = int(video_info['height'])
15 | r_frame_rate = video_info['r_frame_rate'].split('/')
16 |
17 | if fps is None:
18 | fps = int(r_frame_rate[0]) / int(r_frame_rate[1])
19 |
20 | cmd = (
21 | ffmpeg
22 | .input(video_path, ss=start_seek, t=num_sec + 0.1)
23 | .filter('fps', fps=fps)
24 | )
25 | out, _ = (
26 | cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
27 | .run(capture_stdout=True, quiet=True)
28 | )
29 |
30 | video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
31 | video_copy = video.copy()
32 | video = torch.from_numpy(video_copy)
33 | return video[:num_frames].type(torch.float32)
34 |
--------------------------------------------------------------------------------
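Usage sketch for `decode_ffmpeg` above (added for illustration; the video path is a placeholder and the import assumes the repository root is on `PYTHONPATH`):

```python
from decoder.utils import decode_ffmpeg

# Decode roughly 2 seconds starting at t=10s, resampled to 5 fps, keeping at most 16 frames.
frames = decode_ffmpeg("path/to/clip.mp4", start_seek=10, num_sec=2, num_frames=16, fps=5)

# Frames come back as float32 RGB in (T, H, W, 3) layout, with T <= num_frames.
print(frames.shape, frames.dtype)  # e.g. torch.Size([10, 720, 1280, 3]) torch.float32
```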
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | ## Requirements
4 | - Python >= 3.8
5 | - [PyTorch==1.10](https://pytorch.org/)
6 | - [torchvision](https://github.com/pytorch/vision) that matches the PyTorch installation.
7 | - [timm==0.4.12](https://github.com/rwightman/pytorch-image-models)
8 | - [deepspeed==0.5.8](https://github.com/microsoft/DeepSpeed)
9 | - [TensorboardX](https://github.com/lanpa/tensorboardX)
10 | - [decord](https://github.com/dmlc/decord)
11 | - [einops](https://github.com/arogozhnikov/einops)
12 | - [scipy](https://github.com/scipy/scipy)
13 | - [ffmpeg-python](https://github.com/kkroening/ffmpeg-python)
14 | - [opencv-python](https://pypi.org/project/opencv-python/)
15 | - [iopath](https://github.com/facebookresearch/iopath)
16 | - [numpy==1.23.0](https://numpy.org/install/)
17 | - [pandas](https://pandas.pydata.org/docs/getting_started/install.html)
18 | - [matplotlib](https://matplotlib.org/stable/users/installing/index.html)
19 | - [sklearn](https://scikit-learn.org/stable/install.html)
20 | - [torchmetrics](https://pypi.org/project/torchmetrics/)
21 |
22 | ## Extras
23 | For ffmpeg-python, you additionally need to download the openh264 binary from [here](https://github.com/cisco/openh264/releases/tag/v2.1.1), rename it to `libopenh264.so.5`, and place it in your environment's `lib` directory.
24 |
--------------------------------------------------------------------------------
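A quick sanity check (not part of the original instructions) that the pinned packages above are importable; the expected versions in the comments follow the list above, but your environment may differ:

```python
# Minimal environment check for the requirements listed in INSTALL.md.
import numpy as np
import timm
import torch
import torchvision

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())  # expect 1.10.x
print("torchvision:", torchvision.__version__)
print("timm:", timm.__version__)    # expect 0.4.12
print("numpy:", np.__version__)     # expect 1.23.0
```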
/masking_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class TubeMaskingGenerator:
5 | def __init__(self, input_size, mask_ratio):
6 | self.frames, self.height, self.width = input_size
7 | self.num_patches_per_frame = self.height * self.width
8 | self.total_patches = self.frames * self.num_patches_per_frame
9 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame)
10 | self.total_masks = self.frames * self.num_masks_per_frame
11 |
12 | def __repr__(self):
13 |         repr_str = "Mask: total patches {}, mask patches {}".format(
14 | self.total_patches, self.total_masks
15 | )
16 | return repr_str
17 |
18 | def __call__(self):
19 | mask_per_frame = np.hstack([
20 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame),
21 | np.ones(self.num_masks_per_frame),
22 | ])
23 | np.random.shuffle(mask_per_frame)
24 | mask = np.tile(mask_per_frame, (self.frames, 1))
25 | mask = mask.flatten()
26 | return mask
27 |
28 |
29 | class RandomMaskingGenerator:
30 | def __init__(self, input_size, mask_ratio):
31 | self.frames, self.height, self.width = input_size
32 | self.total_patches = self.frames * self.height * self.width
33 | self.num_masks = int(mask_ratio * self.total_patches)
34 | self.total_masks = self.num_masks
35 |
36 | def __repr__(self):
37 |         repr_str = "Mask: total patches {}, mask patches {}".format(
38 | self.total_patches, self.total_masks
39 | )
40 | return repr_str
41 |
42 | def __call__(self):
43 | mask = np.hstack([
44 | np.zeros(self.total_patches - self.num_masks),
45 | np.ones(self.num_masks),
46 | ])
47 | np.random.shuffle(mask)
48 | return mask
49 |
--------------------------------------------------------------------------------
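Usage sketch for the masking generators above. The token grid (8, 14, 14) is illustrative: 16 frames with tubelet size 2 and a 14x14 patch grid for 224px input, matching the pretraining defaults used elsewhere in this repo:

```python
from masking_generator import TubeMaskingGenerator, RandomMaskingGenerator

# Tube masking: the same spatial patches are masked in every frame.
tube_gen = TubeMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.9)
mask = tube_gen()                    # flat 0/1 numpy array, one entry per token
print(mask.shape, int(mask.sum()))   # (1568,) 1408 -> 8 * int(0.9 * 196) masked tokens

# Random masking: tokens are masked independently of their frame.
rand_gen = RandomMaskingGenerator(input_size=(8, 14, 14), mask_ratio=0.9)
print(int(rand_gen().sum()))         # int(0.9 * 1568) = 1411 masked tokens
```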
/dpvo/plot_utils.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from evo.core import sync
6 | from evo.core.trajectory import PoseTrajectory3D
7 | from evo.tools import plot
8 | from pathlib import Path
9 |
10 |
11 | def make_traj(args) -> PoseTrajectory3D:
12 | if isinstance(args, tuple):
13 | traj, tstamps = args
14 | return PoseTrajectory3D(positions_xyz=traj[:,:3], orientations_quat_wxyz=traj[:,3:], timestamps=tstamps)
15 | assert isinstance(args, PoseTrajectory3D), type(args)
16 | return deepcopy(args)
17 |
18 | def best_plotmode(traj):
19 | # _, i1, i2 = np.argsort(np.var(traj.positions_xyz, axis=0))
20 | # plot_axes = "xyz"[i2] + "xyz"[i1]
21 | plot_axes = "xz"
22 | return getattr(plot.PlotMode, plot_axes)
23 |
24 | def plot_trajectory(pred_traj, gt_traj=None, title="", filename="", align=True, correct_scale=True):
25 | pred_traj = make_traj(pred_traj)
26 |
27 | if gt_traj is not None:
28 | gt_traj = make_traj(gt_traj)
29 | gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj)
30 |
31 | if align:
32 | pred_traj.align(gt_traj, correct_scale=correct_scale)
33 |
34 | plot_collection = plot.PlotCollection("PlotCol")
35 | fig = plt.figure(figsize=(8, 8))
36 | plot_mode = best_plotmode(gt_traj if (gt_traj is not None) else pred_traj)
37 | ax = plot.prepare_axis(fig, plot_mode)
38 | ax.set_title(title)
39 | if gt_traj is not None:
40 | plot.traj(ax, plot_mode, gt_traj, '--', 'gray', "Ground Truth")
41 | plot.traj(ax, plot_mode, pred_traj, '-', 'blue', "Predicted")
42 | plot_collection.add_figure("traj (error)", fig)
43 | plot_collection.export(filename, confirm_overwrite=False)
44 | plt.close(fig=fig)
45 | # print(f"Saved {filename}")
46 |
47 | def save_trajectory_tum_format(traj, filename):
48 | traj = make_traj(traj)
49 | tostr = lambda a: ' '.join(map(str, a))
50 | with Path(filename).open('w') as f:
51 | for i in range(traj.num_poses):
52 | f.write(f"{traj.timestamps[i]} {tostr(traj.positions_xyz[i])} {tostr(traj.orientations_quat_wxyz[i][[1,2,3,0]])}\n")
53 | print(f"Saved {filename}")
54 |
--------------------------------------------------------------------------------
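Usage sketch for the helpers above with a synthetic trajectory (requires the `evo` package and assumes `dpvo` is importable; pose rows follow `make_traj`'s `[x, y, z, qw, qx, qy, qz]` layout):

```python
import numpy as np

from dpvo.plot_utils import plot_trajectory, save_trajectory_tum_format

N = 100
tstamps = np.arange(N, dtype=np.float64) / 30.0   # timestamps at 30 fps
poses = np.zeros((N, 7))
poses[:, 0] = np.linspace(0.0, 5.0, N)            # translate along x
poses[:, 3] = 1.0                                 # identity quaternion (w = 1)

# Writes "t x y z qx qy qz qw" per line and renders an XZ (bird's-eye) plot.
save_trajectory_tum_format((poses, tstamps), "synthetic_traj.txt")
plot_trajectory((poses, tstamps), title="synthetic trajectory", filename="synthetic_traj.pdf")
```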
/functional.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import cv2
3 | import numpy as np
4 | import PIL
5 | import torch
6 |
7 |
8 | def _is_tensor_clip(clip):
9 | return torch.is_tensor(clip) and clip.ndimension() == 4
10 |
11 |
12 | def crop_clip(clip, min_h, min_w, h, w):
13 | if isinstance(clip[0], np.ndarray):
14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
15 |
16 | elif isinstance(clip[0], PIL.Image.Image):
17 | cropped = [
18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
19 | ]
20 | else:
21 |         raise TypeError('Expected numpy.ndarray or PIL.Image ' +
22 |                         'but got list of {0}'.format(type(clip[0])))
23 | return cropped
24 |
25 |
26 | def resize_clip(clip, size, interpolation='bilinear'):
27 | if isinstance(clip[0], np.ndarray):
28 | if isinstance(size, numbers.Number):
29 | im_h, im_w, im_c = clip[0].shape
30 | # Min spatial dim already matches minimal size
31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
32 | and im_h == size):
33 | return clip
34 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
35 | size = (new_w, new_h)
36 | else:
37 | size = size[0], size[1]
38 | if interpolation == 'bilinear':
39 | np_inter = cv2.INTER_LINEAR
40 | else:
41 | np_inter = cv2.INTER_NEAREST
42 | scaled = [
43 | cv2.resize(img, size, interpolation=np_inter) for img in clip
44 | ]
45 | elif isinstance(clip[0], PIL.Image.Image):
46 | if isinstance(size, numbers.Number):
47 | im_w, im_h = clip[0].size
48 | # Min spatial dim already matches minimal size
49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w
50 | and im_h == size):
51 | return clip
52 | new_h, new_w = get_resize_sizes(im_h, im_w, size)
53 | size = (new_w, new_h)
54 | else:
55 | size = size[1], size[0]
56 | if interpolation == 'bilinear':
57 | pil_inter = PIL.Image.BILINEAR
58 | else:
59 | pil_inter = PIL.Image.NEAREST
60 | scaled = [img.resize(size, pil_inter) for img in clip]
61 | else:
62 |         raise TypeError('Expected numpy.ndarray or PIL.Image ' +
63 |                         'but got list of {0}'.format(type(clip[0])))
64 | return scaled
65 |
66 |
67 | def get_resize_sizes(im_h, im_w, size):
68 | if im_w < im_h:
69 | ow = size
70 | oh = int(size * im_h / im_w)
71 | else:
72 | oh = size
73 | ow = int(size * im_w / im_h)
74 | return oh, ow
75 |
76 |
77 | def normalize(clip, mean, std, inplace=False):
78 | if not _is_tensor_clip(clip):
79 | raise TypeError('tensor is not a torch clip.')
80 |
81 | if not inplace:
82 | clip = clip.clone()
83 |
84 | dtype = clip.dtype
85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device)
87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
88 |
89 | return clip
90 |
--------------------------------------------------------------------------------
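Usage sketch for the clip helpers above on a list of random uint8 RGB frames (the ImageNet mean/std values are an assumption, mirroring the timm constants used elsewhere in the repo):

```python
import numpy as np
import torch

from functional import crop_clip, normalize, resize_clip

clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(8)]

clip = resize_clip(clip, 256, interpolation='bilinear')    # short side -> 256
clip = crop_clip(clip, min_h=16, min_w=16, h=224, w=224)   # 224x224 crop at offset (16, 16)

# `normalize` expects a 4-D float tensor laid out channels-first: (C, T, H, W).
tensor = torch.from_numpy(np.stack(clip)).permute(3, 0, 1, 2).float() / 255.0
tensor = normalize(tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
print(tensor.shape)   # torch.Size([3, 8, 224, 224])
```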
/VPP.md:
--------------------------------------------------------------------------------
1 | # Fine-tuning for Vision to Proprioception Prediction (VPP)
2 |
3 | ## Description
4 | We evaluate the usefulness of EgoPet on a robotic task, focusing on the problem of vision-based locomotion. Specifically, the task consists of predicting the parameters of the terrain a quadrupedal robot is walking on. The accurate prediction of these parameters is correlated with higher performance in the downstream walking task.
5 |
6 | The parameters we predict are the local terrain geometry, the terrain's friction, and parameters related to the robot's walking behavior on the terrain, including the robot's speed, motor efficiency, and high-level command. We aim to predict a latent representation $z_t$ of these terrain parameters. This latent representation is the hidden layer of a neural network trained end-to-end in simulation together with the walking policy.
7 |
8 | To collect the dataset, we let the robot walk in many environments. The training dataset contains 120 thousand frames from 3 environments (approximately 2.3 hours of total walking time): an office, a park, and a beach. Each environment contains different terrain geometries, including flat ground, steps, and slopes. Each sample contains an image from a forward-looking camera mounted on the robot and the latent terrain parameters $z_t$ below the robot's center of mass, estimated from a history of proprioception.
9 |
10 | The task consists of predicting $z_t$ from a history of images. We generate several sub-tasks by predicting the future terrain parameters $z_{t+0.8}, z_{t+1.5}$ and the past ones $z_{t-0.8}, z_{t-1.5}$. These time intervals were selected to differentiate between forecasting and estimation. We divide the dataset into training and testing sets per episode, i.e., two different policy runs. We construct three test sets: one in distribution (same location and lighting conditions as training), one out of distribution due to lighting (same location at a very different time of day, i.e., night), and one out of distribution due to terrain (e.g., sand, which was never experienced during training).
11 |
12 | For more information about the VPP task refer to our paper!
13 |
14 | ## Dataset Setup
15 | Download and extract `CMS_data.tar.gz` from [here](https://drive.google.com/file/d/1ZKSWwCoZP1mHjpksIEAwh3sTeeSJNF3B/view?usp=sharing).
16 |
17 | ## Linear Probing
18 |
19 | Run the following command to train a linear probe that predicts the proprioception latent (past, present, or future, depending on `LOOKHEAD`) from the video input.
20 |
21 | ```
22 | DATA_PATH='./datasets/CMS/up_down_policy_data'
23 | LOOKHEAD=0.8 # in -1.5 -0.8 0 0.8 1.5
24 | MASTER_PORT=29500
25 | DATA_ROOT=${DATA_PATH}
26 | OUTPUT_DIR="./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/finetune_on_cms_lookahead_${LOOKHEAD}_8frames_update_freq_4"
27 | MODEL_PATH="./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/checkpoint-2669.pth"
28 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
29 | run_fs_domain_adaptation.py \
30 | --model vit_base_patch16_224 \
31 | --nb_classes 10 \
32 | --latent_dim 10 \
33 | --data_path ${DATA_PATH} \
34 | --data_root ${DATA_ROOT} \
35 | --finetune ${MODEL_PATH} \
36 | --log_dir ${OUTPUT_DIR} \
37 | --output_dir ${OUTPUT_DIR} \
38 | --input_size 224 --short_side_size 224 \
39 | --opt adamw --opt_betas 0.9 0.999 --weight_decay 0.05 \
40 | --batch_size 256 --update_freq 4 --num_sample 4 \
41 | --num_frames 8 --sampling_rate 4 \
42 | --lr 5e-4 --epochs 50 \
43 | --input_use_imgs 8 \
44 | --lookhead ${LOOKHEAD} \
45 | --num_workers 5 \
46 | --save_ckpt_freq 20
47 | ```
48 |
49 | ### Pretrained Models
50 | | Model | Lookahead | Link |
51 | |-------------------|-----------|------|
52 | | MVD (ViT-B) | -1.5 | [link](https://drive.google.com/file/d/12tO7LwjZ66voCTxp6lNcSaeLaU-9YlYE/view?usp=sharing) |
53 | | MVD (ViT-B) | -0.8 | [link](https://drive.google.com/file/d/135ndczYtWKF04ZNl-5zs_O6yTdLh5T7O/view?usp=sharing) |
54 | | MVD (ViT-B) | 0 | [link](https://drive.google.com/file/d/1r_8ZbJtuI_6ImFQFzjgIb1ERVBjq-eY-/view?usp=sharing) |
55 | | MVD (ViT-B) | 0.8 | [link](https://drive.google.com/file/d/1f78c-ascoWKa3_29rq-4NhCruidMsFoI/view?usp=sharing) |
56 | | MVD (ViT-B) | 1.5 | [link](https://drive.google.com/file/d/1BO5gzVs5TiNilRpF5OzuTtyVSjfMO33Z/view?usp=sharing) |
--------------------------------------------------------------------------------
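The linear probing command above trains a probe on top of frozen video features to regress the 10-dimensional latent $z_t$ (`--latent_dim 10`). A minimal sketch of that kind of probe, with an assumed 768-dim ViT-B token dimension and an illustrative token count; the repository's actual head and loss may differ:

```python
import torch
import torch.nn as nn

class LatentProbe(nn.Module):
    """Linear probe: frozen backbone tokens -> 10-dim proprioception latent."""
    def __init__(self, feat_dim: int = 768, latent_dim: int = 10):
        super().__init__()
        self.head = nn.Linear(feat_dim, latent_dim)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: (B, N, feat_dim) tokens from a frozen video ViT; mean-pool over tokens.
        return self.head(feats.mean(dim=1))

probe = LatentProbe()
z_pred = probe(torch.randn(4, 784, 768))                    # illustrative token count
loss = nn.functional.mse_loss(z_pred, torch.randn(4, 10))   # regress the target latent z_t
print(z_pred.shape, loss.item())
```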
/README.md:
--------------------------------------------------------------------------------
1 | # EgoPet: Egomotion and Interaction Data from an Animal's Perspective (ECCV 2024)
2 | ### [Amir Bar](https://amirbar.net), Arya Bakhtiar, [Antonio Loquercio](https://antonilo.github.io/), [Jathushan Rajasegaran](https://people.eecs.berkeley.edu/~jathushan/), Danny Tran, [Yann LeCun](https://yann.lecun.com/), [Amir Globerson](http://www.cs.tau.ac.il/~gamir/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/)
3 |
4 |
5 | ![EgoPet teaser](images/teaser.png)
6 |
7 | This repository contains the implementation of the pretraining and linear probing experiments in the paper.
8 |
9 | ## Abstract
10 | Animals perceive the world to plan their actions and interact with other agents to accomplish complex tasks, demonstrating capabilities that are still unmatched by AI systems. To advance our understanding and reduce the gap between the capabilities of animals and AI systems, we introduce a dataset of pet egomotion imagery with diverse examples of simultaneous egomotion and multi-agent interaction. Current video datasets separately contain egomotion and interaction examples, but rarely both at the same time. In addition, EgoPet offers a radically distinct perspective from existing egocentric datasets of humans or vehicles. We define two in-domain benchmark tasks that capture animal behavior, and a third benchmark to assess the utility of EgoPet as a pretraining resource to robotic quadruped locomotion, showing that models trained from EgoPet outperform those trained from prior datasets. This work provides evidence that today's pets could be a valuable resource for training future AI systems and robotic assistants.
11 |
12 | ## EgoPet Dataset
13 | Download the dataset [here](https://huggingface.co/datasets/amirbar1/egopet).
14 |
15 | ## Pre-training
16 | ### Installation
17 | Please follow the instructions in [INSTALL.md](INSTALL.md).
18 |
19 | ### Pretrain an MVD model on the EgoPet dataset:
20 | * Download VideoMAE ViT-B checkpoint from [here](https://drive.google.com/file/d/1tEhLyskjb755TJ65ptsrafUG2llSwQE1/view)
21 | * Download MAE ViT-B checkpoint from [here](https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth)
22 |
23 | Run the following command:
24 | ```
25 | OUTPUT_DIR='./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet'
26 | IMAGE_TEACHER="path/to/mae/checkpoint"
27 | VIDEO_TEACHER="path/to/kinetics/checkpoint"
28 | DATA_PATH='egopet_pretrain.csv'
29 | GPUS=8
30 | NODE_COUNT=4
31 | RANK=0
32 | MASTER_PORT=29500
33 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=${GPUS} --use_env \
34 | --master_port ${MASTER_PORT} --nnodes=${NODE_COUNT} \
35 | --node_rank=${RANK} --master_addr=${MASTER_ADDR} \
36 | run_mvd_pretraining.py \
37 | --data_path ${DATA_PATH} \
38 | --data_root ${DATA_ROOT} \
39 | --model pretrain_masked_video_student_base_patch16_224 \
40 | --opt adamw --opt_betas 0.9 0.95 \
41 | --log_dir ${OUTPUT_DIR} \
42 | --output_dir ${OUTPUT_DIR} \
43 | --image_teacher_model mae_teacher_vit_base_patch16 \
44 | --distillation_target_dim 768 \
45 | --distill_loss_func SmoothL1 \
46 | --image_teacher_model_ckpt_path ${IMAGE_TEACHER} \
47 | --video_teacher_model pretrain_videomae_teacher_base_patch16_224 \
48 | --video_distillation_target_dim 768 \
49 | --video_distill_loss_func SmoothL1 \
50 | --video_teacher_model_ckpt_path ${VIDEO_TEACHER} \
51 | --mask_type tube --mask_ratio 0.9 --decoder_depth 2 \
52 | --batch_size 16 --update_freq 2 --save_ckpt_freq 10 \
53 | --num_frames 16 --sampling_rate 4 \
54 | --lr 1.5e-4 --min_lr 1e-4 --drop_path 0.1 --warmup_epochs 268 --epochs 2680 \
55 | --auto_resume
56 | ```
57 | We set `RANK` (`--node_rank`) to `0` on the first node. On the other nodes, run the same command with `RANK=1`, ..., `RANK=3`, respectively. `--master_addr` is set to the IP of node 0.
58 |
59 | ### Pretrained Models
60 | | Model | Pretraining | Epochs | Link |
61 | |-------------------|-------------|--------|------|
62 | | MVD (ViT-B) | EgoPet | 2670 | [link](https://drive.google.com/file/d/1_Ky73DpjvSh5k4g-xhWE6QTndRlfUTHp/view?usp=sharing) |
63 |
64 | ## Visual Interaction Prediction (VIP)
65 |
66 | The fine-tuning instructions for the VIP task are in [VIP.md](VIP.md).
67 |
68 | ## Locomotion Prediction (LP)
69 |
70 | The fine-tuning instructions for the LP task are in [LP.md](LP.md).
71 |
72 | ## Vision to Proprioception Prediction (VPP)
73 |
74 | The fine-tuning instructions for the VPP task are in [VPP.md](VPP.md).
75 |
76 | ## Acknowledgements
77 |
78 | This project is built upon [MVD](https://github.com/ruiwang2021/mvd/tree/main), [MAE_ST](https://github.com/facebookresearch/mae_st), and [DPVO](https://github.com/princeton-vl/DPVO). Thank you to the contributors of these codebases!
79 |
--------------------------------------------------------------------------------
/volume_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 |
5 |
6 | def convert_img(img):
7 |     """Converts (H, W, C) numpy.ndarray to (C, H, W) format
8 | """
9 | if len(img.shape) == 3:
10 | img = img.transpose(2, 0, 1)
11 | if len(img.shape) == 2:
12 | img = np.expand_dims(img, 0)
13 | return img
14 |
15 |
16 | class ClipToTensor(object):
17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
19 | """
20 |
21 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
22 | self.channel_nb = channel_nb
23 | self.div_255 = div_255
24 | self.numpy = numpy
25 |
26 | def __call__(self, clip):
27 | """
28 | Args: clip (list of numpy.ndarray): clip (list of images)
29 | to be converted to tensor.
30 | """
31 | # Retrieve shape
32 | if isinstance(clip[0], np.ndarray):
33 | h, w, ch = clip[0].shape
34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
35 | ch)
36 | elif isinstance(clip[0], Image.Image):
37 | w, h = clip[0].size
38 | else:
39 | raise TypeError('Expected numpy.ndarray or PIL.Image\
40 | but got list of {0}'.format(type(clip[0])))
41 |
42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
43 |
44 | # Convert
45 | for img_idx, img in enumerate(clip):
46 | if isinstance(img, np.ndarray):
47 | pass
48 | elif isinstance(img, Image.Image):
49 | img = np.array(img, copy=False)
50 | else:
51 | raise TypeError('Expected numpy.ndarray or PIL.Image\
52 | but got list of {0}'.format(type(clip[0])))
53 | img = convert_img(img)
54 | np_clip[:, img_idx, :, :] = img
55 | if self.numpy:
56 | if self.div_255:
57 | np_clip = np_clip / 255.0
58 | return np_clip
59 |
60 | else:
61 | tensor_clip = torch.from_numpy(np_clip)
62 |
63 | if not isinstance(tensor_clip, torch.FloatTensor):
64 | tensor_clip = tensor_clip.float()
65 | if self.div_255:
66 | tensor_clip = torch.div(tensor_clip, 255)
67 | return tensor_clip
68 |
69 |
70 | # Note this norms data to -1/1
71 | class ClipToTensor_K(object):
72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
73 |     to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0]
74 | """
75 |
76 | def __init__(self, channel_nb=3, div_255=True, numpy=False):
77 | self.channel_nb = channel_nb
78 | self.div_255 = div_255
79 | self.numpy = numpy
80 |
81 | def __call__(self, clip):
82 | """
83 | Args: clip (list of numpy.ndarray): clip (list of images)
84 | to be converted to tensor.
85 | """
86 | # Retrieve shape
87 | if isinstance(clip[0], np.ndarray):
88 | h, w, ch = clip[0].shape
89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format(
90 | ch)
91 | elif isinstance(clip[0], Image.Image):
92 | w, h = clip[0].size
93 | else:
94 | raise TypeError('Expected numpy.ndarray or PIL.Image\
95 | but got list of {0}'.format(type(clip[0])))
96 |
97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
98 |
99 | # Convert
100 | for img_idx, img in enumerate(clip):
101 | if isinstance(img, np.ndarray):
102 | pass
103 | elif isinstance(img, Image.Image):
104 | img = np.array(img, copy=False)
105 | else:
106 | raise TypeError('Expected numpy.ndarray or PIL.Image\
107 | but got list of {0}'.format(type(clip[0])))
108 | img = convert_img(img)
109 | np_clip[:, img_idx, :, :] = img
110 | if self.numpy:
111 | if self.div_255:
112 | np_clip = (np_clip - 127.5) / 127.5
113 | return np_clip
114 |
115 | else:
116 | tensor_clip = torch.from_numpy(np_clip)
117 |
118 | if not isinstance(tensor_clip, torch.FloatTensor):
119 | tensor_clip = tensor_clip.float()
120 | if self.div_255:
121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
122 | return tensor_clip
123 |
124 |
125 | class ToTensor(object):
126 | """Converts numpy array to tensor
127 | """
128 |
129 | def __call__(self, array):
130 | tensor = torch.from_numpy(array)
131 | return tensor
132 |
--------------------------------------------------------------------------------
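Usage sketch for the clip-to-tensor transforms above on a list of uint8 RGB frames:

```python
import numpy as np

from volume_transforms import ClipToTensor, ClipToTensor_K

clip = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(16)]

x = ClipToTensor()(clip)     # scales to [0, 1]
y = ClipToTensor_K()(clip)   # scales to [-1, 1]

print(x.shape)                                        # torch.Size([3, 16, 224, 224])
print(float(x.min()) >= 0.0, float(x.max()) <= 1.0)   # True True
print(float(y.min()) >= -1.0, float(y.max()) <= 1.0)  # True True
```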
/VIP.md:
--------------------------------------------------------------------------------
1 | # Fine-tuning for Visual Interaction Prediction (VIP) Classification
2 |
3 | ## Description
4 | Studying animal interactions from an egocentric perspective provides insight into how animals navigate, manipulate their surroundings, and communicate, which is applicable to designing effective systems for real-world settings.
5 |
6 | The input for this task is a video clip from the egocentric perspective of an animal. The output is twofold: a binary classification indicating whether an interaction is taking place or not, and identification of the subject of the interaction. In the EgoPet dataset, an “interaction” is defined as a discernible event where the agent demonstrates clear attention (physical contact, proximity, orientation, or vocalization) to an object or another agent with visual evidence. Aimless movements are not labeled as interactions.
7 |
8 | The beginning of an interaction is marked at the first time-step where the agent begins to give attention to a target, and the endpoint is marked at the last time-step before the attention ceases. There are 644 positive interaction segments with 17 distinct interaction subjects and 556 segments where no interaction occurs (“negative segments”). The segments were then split into train and test. For both training and testing, we sampled additional negative segments from the videos, ensuring these did not overlap with the positive segments. We sampled 125 additional negative training segments and 124 additional negative test segments, leaving us with 754 training segments and 695 test segments, for a total of 1,449 annotated segments.
9 |
10 | For more information about the VIP task refer to our paper!
11 |
12 | ## Dataset Information
13 | `object_interaction_train.csv` and `object_interaction_validation.csv` are csv files in which every row represents a single training or validation sample.
14 | 
15 |
16 | ### Columns documentation:
17 | ```
18 | animal - source animal of ego footage
19 | ds_type - train/validation
20 | video_id - hashing of video
21 | segment_id - id for segment within video
22 | start_time - start time of interaction, if NONE there is no interaction in the entire clip
23 | end_time - end time of interaction, if NONE there is no interaction in the entire clip
24 | total_time - total time of segment video
25 | interacting_object - object being interacted with
26 | video_path - path to segment video
27 | ```
28 |
29 | ### Interacting Object Categories
30 | Chosen from common interaction objects.
31 | ```
32 | person
33 | ball
34 | bench
35 | bird
36 | dog
37 | cat
38 | other animal
39 | toy
40 | door
41 | floor
42 | food
43 | plant
44 | filament
45 | plastic
46 | water
47 | vehicle
48 | other
49 | ```
50 |
51 | ## Evaluation
52 |
53 | As a sanity check, run evaluation using our MVD **fine-tuned** models:
54 | 
55 | |                                             | MVD (EgoPet) |
56 | |---------------------------------------------|--------------|
57 | | fine-tuned checkpoint                       | download     |
58 | | reference Interaction accuracy              | 68.75        |
59 | | reference Interaction AUROC                 | 74.50        |
60 | | reference Subject Prediction Top-1 accuracy | 35.38        |
61 | | reference Subject Prediction Top-3 accuracy | 66.43        |
62 | 
81 | Evaluate VideoMAE/MVD on a single GPU (`{EGOPET_DIR}` is a directory containing the `{training_set, validation_set}` sets of EgoPet, and `{CSV_PATH}` is the directory containing `{object_interaction_train.csv, object_interaction_validation.csv}`, i.e., `csv/`):
82 | ```
83 | EGOPET_DIR='your_path/egopet/training_and_validation_test_set'
84 | CSV_PATH='csv/'
85 | FINETUNE_PATH='path/to/model'
86 | python run_object_interaction_finetuning.py --num_workers 5 --model vit_base_patch16_224 --latent_dim 18 --data_path ${EGOPET_DIR} --csv_path ${CSV_PATH} --finetune ${FINETUNE_PATH} --input_size 224 --batch_size 64 --num_frames 8 --num_sec 2 --fps 4 --alpha 1 --eval
87 | ```
88 | Evaluating on MVD (EgoPet) should give:
89 | ```
90 | * loss 1.736
91 | * Acc@1_interaction 68.750 auroc_interaction 0.745
92 | * Acc@1_object 35.379 Acc@3_object 66.426 auroc_object 0.690
93 | Loss of the network on the 695 test images: 1.7%
94 | Min loss: 1.000%
95 | ```
96 |
97 | ## Linear Probing
98 |
99 | To fine-tune a pre-trained ViT-Base VideoMAE/MVD with **single-node training**, run the following on 1 node with 8 GPUs:
100 | ```
101 | OUTPUT_DIR='./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/finetune_on_object_interaction'
102 | EGOPET_DIR='your_path/egopet/training_and_validation_test_set'
103 | CSV_PATH='csv/'
104 | FINETUNE_PATH='./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/checkpoint-2669.pth'
105 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
106 | run_object_interaction_finetuning.py \
107 | --output_dir ${OUTPUT_DIR} \
108 | --num_workers 5 \
109 | --model vit_base_patch16_224 \
110 | --latent_dim 18 \
111 | --data_path ${EGOPET_DIR} \
112 | --csv_path ${CSV_PATH} \
113 | --finetune ${FINETUNE_PATH} \
114 | --input_size 224 \
115 | --opt adamw --opt_betas 0.9 0.999 --weight_decay 0.05 \
116 | --batch_size 64 --update_freq 2 --num_sample 2 \
117 | --save_ckpt_freq 10 --auto_resume \
118 | --num_frames 8 --num_sec 2 --fps 4 --object_interaction_ratio 0.5 \
119 | --alpha 1 \
120 | --lr 5e-4 --epochs 15
121 | ```
122 | To train ViT-Large or ViT-Huge, set `--model vit_large_patch16` or `--model vit_huge_patch14`.
--------------------------------------------------------------------------------
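A sketch of loading the VIP csv described in VIP.md with pandas (the path is a placeholder and the `NONE` handling follows the column documentation above):

```python
import pandas as pd

# Placeholder path; point this at the csv/ directory described in VIP.md.
df = pd.read_csv("csv/object_interaction_train.csv")

# Per the column documentation, start_time is NONE when no interaction occurs in the clip.
no_interaction = df["start_time"].isna() | df["start_time"].astype(str).str.upper().eq("NONE")
positives = df[~no_interaction]

print(f"{len(df)} segments, {len(positives)} with an interaction")
print(positives["interacting_object"].value_counts().head())
```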
/LP.md:
--------------------------------------------------------------------------------
1 | # Fine-tuning for Locomotion Prediction (LP)
2 |
3 | ## Description
4 | Planning where to move involves a complex interplay of both perception and foresight. It requires the ability to anticipate potential obstacles, consider various courses of action, and select the most efficient and effective strategy to achieve a desired goal. EgoPet contains examples where animals plan a future trajectory to achieve a certain goal (e.g., a dog following its owner).
5 |
6 | Given a sequence of past $m$ video frames $\{x_i\}^{t}_{i=t-m}$, the goal is to predict the unit normalized future trajectory of the agent $\{v_j\}^{t+k}_{j=t+1}$, where $v_j\in \mathbb{R}^{3}$ represents the relative location of the agent at timestep $j$. We predict the unit normalized relative location due to the scale ambiguity of the extracted trajectories. In practice, we condition models on $m=16$ frames and predict $k=40$ future locations which correspond to $4$ seconds into the future.
7 |
8 | Given an input sequence of frames, DPVO returns the location and orientation of the camera for each frame. For our EgoPet data we found that a stride (the rate at which we keep frames) of $5$ worked best, though in some cases a stride of $10$ or $15$ worked better. While DPVO worked well overall, some trajectories were inaccurate, so two human annotators were trained to evaluate each trajectory from a bird's-eye (XZ) view and to choose which DPVO stride produced the poses that best matched the trajectory, if any.
9 |
10 | For more information about the LP task refer to our paper!
11 |
12 | ## Dataset Information
13 | `locomotion_prediction_train.csv` and `locomotion_prediction_train_25.csv` are csv files in which every row represents a single training sample; the `25` refers to the 25% subset of the data used for the experiments in the paper. `locomotion_prediction_val.csv` is the csv file for the validation set.
14 | 
15 |
16 | The DPVO generated trajectories can be downloaded from [here](https://drive.google.com/file/d/1kgwhAnjrSoOSPx9s4F013NgmNEi5K0KV/view?usp=sharing).
17 |
18 | ### Columns documentation:
19 | ```
20 | animal - source animal of ego footage
21 | ds_type - train/validation
22 | segment_id - id for segment within video
23 | stride - the DPVO stride for this trajectory
24 | start_time - start time of this trajectory
25 | end_time - end time of this trajectory
26 | video_path - path to segment video
27 | ```
28 |
29 | ## Evaluation
30 |
31 | As a sanity check, run evaluation using our MVD **fine-tuned** models:
32 | 
33 | |                       | MVD (EgoPet) |
34 | |-----------------------|--------------|
35 | | fine-tuned checkpoint | download     |
36 | | reference ATE         | 0.474        |
37 | | reference RPE         | 0.171        |
38 | 
51 | Evaluate VideoMAE/MVD on 1 node with 8 GPUs (`{EGOPET_DIR}` is a directory containing the `{training_set, validation_set}` sets of EgoPet, `{TRAIN_CSV}` is the path to the train csv to use, and `{RESUME_PATH}` is the path to the model you want to evaluate):
52 | ```
53 | EGOPET_DIR='your_path/egopet/training_and_validation_test_set'
54 | TRAIN_CSV='locomotion_prediction_train_25.csv'
55 | RESUME_PATH='path/to/model/egopet_lp_linearprobing_vitb_checkpoint-00014.pth'
56 | PATH_TO_TRAJECTORIES_DIR='your_path/interp_trajectories'
57 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --master_port=29101 \
58 |     --nproc_per_node=8 --use_env \
58 | run_locomotion_prediction_finetuning.py \
59 | --num_workers 7 \
60 | --model vit_base_patch16_224 \
61 | --input_size 224 \
62 | --validation_batch_size 8 \
63 | --path_to_data_dir ${EGOPET_DIR} \
64 | --train_csv ${TRAIN_CSV} \
65 | --path_to_trajectories_dir ${PATH_TO_TRAJECTORIES_DIR} \
66 | --num_pred 3 \
67 | --resume ${RESUME_PATH} \
68 | --animals cat,dog --num_condition_frames 16 --num_pose_prediction 120 --pps 10 --fps 30 \
69 | --scale_invariance dir \
70 | --dist_eval --eval
71 | ```
72 | Evaluating on MVD (EgoPet) should give:
73 | ```
74 | * loss 0.36006
75 | * ate 0.47448
76 | * rpe_trans 0.17098
77 | * rpe_rot 0.00000
78 | Loss of the network on the 167 test images: 0.4%
79 | Min loss: 0.360%
80 | ```
81 |
82 | ## Linear Probing
83 |
84 | To fine-tune a pre-trained ViT-Base VideoMAE/MVD with **single-node training**, run the following on 1 node with 8 GPUs:
85 | ```
86 | OUTPUT_DIR='./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/finetune_on_locomotion_prediction'
87 | EGOPET_DIR='your_path/egopet/training_and_validation_test_set'
88 | CSV_PATH='csv/locomotion_prediction_train_25.csv'
89 | FINETUNE_PATH='./logs_dir/mvd_vit_base_with_vit_base_teacher_egopet/checkpoint-2669.pth'
90 | PATH_TO_TRAJECTORIES_DIR='your_path/interp_trajectories'
91 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=8 --use_env \
92 | run_locomotion_prediction_finetuning.py \
93 | --output_dir ${OUTPUT_DIR} \
94 | --num_workers 5 \
95 | --model vit_base_patch16_224 \
96 | --path_to_data_dir ${EGOPET_DIR} \
97 | --path_to_trajectories_dir ${PATH_TO_TRAJECTORIES_DIR} \
98 | --train_csv ${CSV_PATH} \
99 | --finetune ${FINETUNE_PATH} \
100 | --input_size 224 \
101 | --opt adamw --opt_betas 0.9 0.999 --weight_decay 0.05 \
102 | --batch_size 16 --validation_batch_size 8 --update_freq 8 --num_sample 2 \
103 | --save_ckpt_freq 1 --auto_resume \
104 | --num_pred 3 --num_condition_frames 16 --num_pose_prediction 120 \
105 | --animals cat,dog --pps 10 --fps 30 --scale_invariance dir \
106 | --dist_eval \
107 | --lr 5e-4 --epochs 15
108 | ```
109 | To train ViT-Large or ViT-Huge, set `--model vit_large_patch16` or `--model vit_huge_patch14`.
--------------------------------------------------------------------------------
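The LP target described in LP.md is a unit-normalized sequence of future relative locations. A sketch of that normalization on a synthetic trajectory; it mirrors the task description (and the `--scale_invariance dir` flag), not necessarily the repository's exact preprocessing:

```python
import numpy as np

def unit_normalized_future(positions: np.ndarray, t: int, k: int = 40) -> np.ndarray:
    """positions: (N, 3) absolute XYZ locations. Returns the k future locations
    relative to timestep t, each normalized to unit length (direction only)."""
    rel = positions[t + 1 : t + 1 + k] - positions[t]          # v_j for j = t+1 .. t+k
    norms = np.linalg.norm(rel, axis=-1, keepdims=True)
    return rel / np.maximum(norms, 1e-8)

traj = np.cumsum(np.random.randn(100, 3) * 0.05, axis=0)       # synthetic camera trajectory
targets = unit_normalized_future(traj, t=16, k=40)
print(targets.shape, np.allclose(np.linalg.norm(targets, axis=-1), 1.0))   # (40, 3) True
```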
/random_erasing.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
4 | published under an Apache License 2.0.
5 | """
6 | import math
7 | import random
8 | import torch
9 |
10 |
11 | def _get_pixels(
12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
13 | ):
14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
15 | # paths, flip the order so normal is run on CPU if this becomes a problem
16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
17 | if per_pixel:
18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
19 | elif rand_color:
20 | return torch.empty(
21 | (patch_size[0], 1, 1), dtype=dtype, device=device
22 | ).normal_()
23 | else:
24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
25 |
26 |
27 | class RandomErasing:
28 | """Randomly selects a rectangle region in an image and erases its pixels.
29 | 'Random Erasing Data Augmentation' by Zhong et al.
30 | See https://arxiv.org/pdf/1708.04896.pdf
31 | This variant of RandomErasing is intended to be applied to either a batch
32 | or single image tensor after it has been normalized by dataset mean and std.
33 | Args:
34 | probability: Probability that the Random Erasing operation will be performed.
35 | min_area: Minimum percentage of erased area wrt input image area.
36 | max_area: Maximum percentage of erased area wrt input image area.
37 | min_aspect: Minimum aspect ratio of erased area.
38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
39 | 'const' - erase block is constant color of 0 for all channels
40 | 'rand' - erase block is same per-channel random (normal) color
41 | 'pixel' - erase block is per-pixel random (normal) color
42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count.
43 | per-image count is randomly chosen between 1 and this value.
44 | """
45 |
46 | def __init__(
47 | self,
48 | probability=0.5,
49 | min_area=0.02,
50 | max_area=1 / 3,
51 | min_aspect=0.3,
52 | max_aspect=None,
53 | mode="const",
54 | min_count=1,
55 | max_count=None,
56 | num_splits=0,
57 | device="cuda",
58 | cube=True,
59 | ):
60 | self.probability = probability
61 | self.min_area = min_area
62 | self.max_area = max_area
63 | max_aspect = max_aspect or 1 / min_aspect
64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
65 | self.min_count = min_count
66 | self.max_count = max_count or min_count
67 | self.num_splits = num_splits
68 | mode = mode.lower()
69 | self.rand_color = False
70 | self.per_pixel = False
71 | self.cube = cube
72 | if mode == "rand":
73 | self.rand_color = True # per block random normal
74 | elif mode == "pixel":
75 | self.per_pixel = True # per pixel random normal
76 | else:
77 | assert not mode or mode == "const"
78 | self.device = device
79 |
80 | def _erase(self, img, chan, img_h, img_w, dtype):
81 | if random.random() > self.probability:
82 | return
83 | area = img_h * img_w
84 | count = (
85 | self.min_count
86 | if self.min_count == self.max_count
87 | else random.randint(self.min_count, self.max_count)
88 | )
89 | for _ in range(count):
90 | for _ in range(10):
91 | target_area = (
92 | random.uniform(self.min_area, self.max_area) * area / count
93 | )
94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
95 | h = int(round(math.sqrt(target_area * aspect_ratio)))
96 | w = int(round(math.sqrt(target_area / aspect_ratio)))
97 | if w < img_w and h < img_h:
98 | top = random.randint(0, img_h - h)
99 | left = random.randint(0, img_w - w)
100 | img[:, top : top + h, left : left + w] = _get_pixels(
101 | self.per_pixel,
102 | self.rand_color,
103 | (chan, h, w),
104 | dtype=dtype,
105 | device=self.device,
106 | )
107 | break
108 |
109 | def _erase_cube(
110 | self,
111 | img,
112 | batch_start,
113 | batch_size,
114 | chan,
115 | img_h,
116 | img_w,
117 | dtype,
118 | ):
119 | if random.random() > self.probability:
120 | return
121 | area = img_h * img_w
122 | count = (
123 | self.min_count
124 | if self.min_count == self.max_count
125 | else random.randint(self.min_count, self.max_count)
126 | )
127 | for _ in range(count):
128 | for _ in range(100):
129 | target_area = (
130 | random.uniform(self.min_area, self.max_area) * area / count
131 | )
132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
133 | h = int(round(math.sqrt(target_area * aspect_ratio)))
134 | w = int(round(math.sqrt(target_area / aspect_ratio)))
135 | if w < img_w and h < img_h:
136 | top = random.randint(0, img_h - h)
137 | left = random.randint(0, img_w - w)
138 | for i in range(batch_start, batch_size):
139 | img_instance = img[i]
140 | img_instance[
141 | :, top : top + h, left : left + w
142 | ] = _get_pixels(
143 | self.per_pixel,
144 | self.rand_color,
145 | (chan, h, w),
146 | dtype=dtype,
147 | device=self.device,
148 | )
149 | break
150 |
151 | def __call__(self, input):
152 | if len(input.size()) == 3:
153 | self._erase(input, *input.size(), input.dtype)
154 | else:
155 | batch_size, chan, img_h, img_w = input.size()
156 | # skip first slice of batch if num_splits is set (for clean portion of samples)
157 | batch_start = (
158 | batch_size // self.num_splits if self.num_splits > 1 else 0
159 | )
160 | if self.cube:
161 | self._erase_cube(
162 | input,
163 | batch_start,
164 | batch_size,
165 | chan,
166 | img_h,
167 | img_w,
168 | input.dtype,
169 | )
170 | else:
171 | for i in range(batch_start, batch_size):
172 | self._erase(input[i], chan, img_h, img_w, input.dtype)
173 | return input
174 |
--------------------------------------------------------------------------------
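Usage sketch for `RandomErasing` above, treating a clip as a (T, C, H, W) batch of frames; `device="cpu"` is used here only for illustration since the class defaults to CUDA:

```python
import torch

from random_erasing import RandomErasing

eraser = RandomErasing(probability=1.0, mode="pixel", max_count=1, device="cpu", cube=True)

clip = torch.zeros(16, 3, 224, 224)    # 16 already-normalized frames
out = eraser(clip)                     # cube=True erases the same region in every frame
print(out.shape, bool((out != 0).any()))   # torch.Size([16, 3, 224, 224]) True
```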
/engine_for_pretraining.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from typing import Iterable
4 | import torch
5 | import torch.nn as nn
6 | import torchvision
7 | import utils
8 | from einops import rearrange
9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
10 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
11 |
12 | Loss_func_choice = {'L1': torch.nn.L1Loss, 'L2': torch.nn.MSELoss, 'SmoothL1': torch.nn.SmoothL1Loss}
13 |
14 |
15 | def train_one_epoch(args, model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer,
16 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
17 | log_writer=None, lr_scheduler=None, start_steps=None, lr_schedule_values=None,
18 | wd_schedule_values=None, update_freq=None, time_stride_loss=True, lr_scale=1.0,
19 | image_teacher_model=None, video_teacher_model=None, norm_feature=False):
20 |
21 | model.train()
22 | metric_logger = utils.MetricLogger(delimiter=" ")
23 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
24 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
25 | header = 'Epoch: [{}]'.format(epoch)
26 | print_freq = 10
27 | LN_img = nn.LayerNorm(args.distillation_target_dim, eps=1e-6, elementwise_affine=False).cuda()
28 | LN_vid = nn.LayerNorm(args.video_distillation_target_dim, eps=1e-6, elementwise_affine=False).cuda()
29 |
30 | loss_func_img_feat = Loss_func_choice[args.distill_loss_func]()
31 | loss_func_vid_feat = Loss_func_choice[args.video_distill_loss_func]()
32 | image_loss_weight = args.image_teacher_loss_weight
33 | video_loss_weight = args.video_teacher_loss_weight
34 |
35 | tubelet_size = args.tubelet_size
36 |
37 | for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
38 | # assign learning rate & weight decay for each step
39 | update_step = step // update_freq
40 | it = start_steps + update_step # global training iteration
41 | if lr_schedule_values is not None or wd_schedule_values is not None and step % update_freq == 0:
42 | for i, param_group in enumerate(optimizer.param_groups):
43 | if lr_schedule_values is not None:
44 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] * lr_scale
45 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
46 | param_group["weight_decay"] = wd_schedule_values[it]
47 |
48 | videos, videos_for_teacher, bool_masked_pos = batch
49 | videos = videos.to(device, non_blocking=True)
50 | videos_for_teacher = videos_for_teacher.to(device, non_blocking=True)
51 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool)
52 | _, _, T, _, _ = videos.shape
53 |
54 | with torch.cuda.amp.autocast():
55 | output_features, output_video_features = model(videos, bool_masked_pos)
56 | with torch.no_grad():
57 | image_teacher_model.eval()
58 | if time_stride_loss:
59 | teacher_features = image_teacher_model(
60 | rearrange(videos_for_teacher[:, :, ::tubelet_size, :, :], 'b c t h w -> (b t) c h w'),
61 | )
62 | teacher_features = rearrange(teacher_features, '(b t) l c -> b (t l) c', t=T//tubelet_size)
63 | else:
64 | teacher_features = image_teacher_model(
65 | rearrange(videos_for_teacher, 'b c t h w -> (b t) c h w'),
66 | )
67 | teacher_features = rearrange(teacher_features, '(b t d) l c -> b (t l) (d c)', t=T//tubelet_size, d=tubelet_size)
68 | if norm_feature:
69 | teacher_features = LN_img(teacher_features)
70 |
71 | video_teacher_model.eval()
72 | videos_for_video_teacher = videos if args.video_teacher_input_size == args.input_size \
73 | else videos_for_teacher
74 |
75 | video_teacher_features = video_teacher_model(videos_for_video_teacher)
76 | if norm_feature:
77 | video_teacher_features = LN_vid(video_teacher_features)
78 |
79 | B, _, D = output_features.shape
80 | loss_img_feat = loss_func_img_feat(
81 | input=output_features,
82 | target=teacher_features[bool_masked_pos].reshape(B, -1, D)
83 | )
84 | loss_value_img_feat = loss_img_feat.item()
85 |
86 | B, _, D = output_video_features.shape
87 | loss_vid_feat = loss_func_vid_feat(
88 | input=output_video_features,
89 | target=video_teacher_features[bool_masked_pos].reshape(B, -1, D)
90 | )
91 | loss_value_vid_feat = loss_vid_feat.item()
92 |
93 | loss = image_loss_weight * loss_img_feat + video_loss_weight * loss_vid_feat
94 |
95 | loss_value = loss.item()
96 |
97 | if not math.isfinite(loss_value):
98 | print("Loss is {}, stopping training".format(loss_value))
99 | sys.exit(1)
100 |
101 | # this attribute is added by timm on one optimizer (adahessian)
102 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
103 | loss /= update_freq
104 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
105 | parameters=model.parameters(), create_graph=is_second_order,
106 | update_grad=(step + 1) % update_freq == 0)
107 | if (step + 1) % update_freq == 0:
108 | optimizer.zero_grad()
109 | loss_scale_value = loss_scaler.state_dict()["scale"]
110 |
111 | torch.cuda.synchronize()
112 |
113 | metric_logger.update(loss=loss_value)
114 | metric_logger.update(loss_img_feat=loss_value_img_feat)
115 | metric_logger.update(loss_vid_feat=loss_value_vid_feat)
116 | metric_logger.update(loss_scale=loss_scale_value)
117 | min_lr = 10.
118 | max_lr = 0.
119 | for group in optimizer.param_groups:
120 | min_lr = min(min_lr, group["lr"])
121 | max_lr = max(max_lr, group["lr"])
122 |
123 | metric_logger.update(lr=max_lr)
124 | metric_logger.update(min_lr=min_lr)
125 | weight_decay_value = None
126 | for group in optimizer.param_groups:
127 | if group["weight_decay"] > 0:
128 | weight_decay_value = group["weight_decay"]
129 | metric_logger.update(weight_decay=weight_decay_value)
130 | metric_logger.update(grad_norm=grad_norm)
131 |
132 | if log_writer is not None:
133 | log_writer.update(loss=loss_value, head="loss")
134 | log_writer.update(loss_img_feat=loss_value_img_feat, head="loss_img_feat")
135 | log_writer.update(loss_vid_feat=loss_value_vid_feat, head="loss_vid_feat")
136 | log_writer.update(loss_scale=loss_scale_value, head="opt")
137 | log_writer.update(lr=max_lr, head="opt")
138 | log_writer.update(min_lr=min_lr, head="opt")
139 | log_writer.update(weight_decay=weight_decay_value, head="opt")
140 | log_writer.update(grad_norm=grad_norm, head="opt")
141 | log_writer.set_step()
142 |
143 | if lr_scheduler is not None:
144 | lr_scheduler.step_update(start_steps + step)
145 | # gather the stats from all processes
146 | metric_logger.synchronize_between_processes()
147 | print("Averaged stats:", metric_logger)
148 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
149 |
--------------------------------------------------------------------------------
/modeling_video_teacher.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from functools import partial
7 |
8 | from modeling_finetune import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table, get_3d_sincos_pos_embed
9 | from timm.models.registry import register_model
10 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_
11 |
12 |
13 | def trunc_normal_(tensor, mean=0., std=1.):
14 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
15 |
16 |
17 | __all__ = [
18 | 'pretrain_videomae_teacher_base_patch16_224',
19 | 'pretrain_videomae_teacher_large_patch16_224',
20 | 'pretrain_videomae_teacher_huge_patch16_224',
21 | ]
22 |
23 |
24 | # --------------------------------------------------------
25 | # VideoMAE encoder
26 | # References:
27 | # VideoMAE: https://github.com/MCG-NJU/VideoMAE/blob/main/modeling_pretrain.py
28 | # --------------------------------------------------------
29 | class PretrainVisionTransformerEncoder(nn.Module):
30 | """ Vision Transformer with support for patch or hybrid CNN input stage
31 | """
32 |
33 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
34 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
35 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2):
36 | super().__init__()
37 | self.num_classes = num_classes
38 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
39 | self.patch_embed = PatchEmbed(
40 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tubelet_size=tubelet_size)
41 | num_patches = self.patch_embed.num_patches
42 |
43 | self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
44 |
45 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
46 | self.blocks = nn.ModuleList([
47 | Block(
48 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
49 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
50 | init_values=init_values)
51 | for i in range(depth)])
52 | self.norm = norm_layer(embed_dim)
53 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
54 |
55 | self.apply(self._init_weights)
56 |
57 | def _init_weights(self, m):
58 | if isinstance(m, nn.Linear):
59 | nn.init.xavier_uniform_(m.weight)
60 | if isinstance(m, nn.Linear) and m.bias is not None:
61 | nn.init.constant_(m.bias, 0)
62 | elif isinstance(m, nn.LayerNorm):
63 | nn.init.constant_(m.bias, 0)
64 | nn.init.constant_(m.weight, 1.0)
65 |
66 | def get_num_layers(self):
67 | return len(self.blocks)
68 |
69 | @torch.jit.ignore
70 | def no_weight_decay(self):
71 | return {'pos_embed', 'cls_token'}
72 |
73 | def get_classifier(self):
74 | return self.head
75 |
76 | def reset_classifier(self, num_classes, global_pool=''):
77 | self.num_classes = num_classes
78 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
79 |
80 | def forward_features(self, x):
81 | x = self.patch_embed(x)
82 |
83 | x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
84 |
85 | for i, blk in enumerate(self.blocks):
86 | x = blk(x)
87 |
88 | return x
89 |
90 | def forward(self, x):
91 | x = self.forward_features(x)
92 | return x
93 |
94 |
95 | class PretrainVideoTransformerTeacher(nn.Module):
96 | """ Vision Transformer with support for patch or hybrid CNN input stage
97 | """
98 |
99 | def __init__(self,
100 | img_size=224,
101 | patch_size=16,
102 | encoder_in_chans=3,
103 | encoder_num_classes=0,
104 | encoder_embed_dim=768,
105 | encoder_depth=12,
106 | encoder_num_heads=12,
107 | mlp_ratio=4.,
108 | qkv_bias=False,
109 | qk_scale=None,
110 | drop_rate=0.,
111 | attn_drop_rate=0.,
112 | drop_path_rate=0.,
113 | norm_layer=nn.LayerNorm,
114 | init_values=0.,
115 | tubelet_size=2,
116 | ):
117 | super().__init__()
118 | self.encoder = PretrainVisionTransformerEncoder(
119 | img_size=img_size,
120 | patch_size=patch_size,
121 | in_chans=encoder_in_chans,
122 | num_classes=encoder_num_classes,
123 | embed_dim=encoder_embed_dim,
124 | depth=encoder_depth,
125 | num_heads=encoder_num_heads,
126 | mlp_ratio=mlp_ratio,
127 | qkv_bias=qkv_bias,
128 | qk_scale=qk_scale,
129 | drop_rate=drop_rate,
130 | attn_drop_rate=attn_drop_rate,
131 | drop_path_rate=drop_path_rate,
132 | norm_layer=norm_layer,
133 | init_values=init_values,
134 | tubelet_size=tubelet_size,
135 | )
136 |
137 | def _init_weights(self, m):
138 | if isinstance(m, nn.Linear):
139 | nn.init.xavier_uniform_(m.weight)
140 | if isinstance(m, nn.Linear) and m.bias is not None:
141 | nn.init.constant_(m.bias, 0)
142 | elif isinstance(m, nn.LayerNorm):
143 | nn.init.constant_(m.bias, 0)
144 | nn.init.constant_(m.weight, 1.0)
145 |
146 | def get_num_layers(self):
147 | return len(self.blocks)
148 |
149 | @torch.jit.ignore
150 | def no_weight_decay(self):
151 | return {'pos_embed', 'cls_token'}
152 |
153 | def forward(self, x):
154 | x = self.encoder(x)
155 | return x
156 |
157 |
158 | @register_model
159 | def pretrain_videomae_teacher_base_patch16_224(pretrained=False, **kwargs):
160 | model = PretrainVideoTransformerTeacher(
161 | patch_size=16,
162 | encoder_embed_dim=768,
163 | encoder_depth=12,
164 | encoder_num_heads=12,
165 | encoder_num_classes=0,
166 | mlp_ratio=4,
167 | qkv_bias=True,
168 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
169 | **kwargs)
170 | model.default_cfg = _cfg()
171 | if pretrained:
172 | checkpoint = torch.load(
173 | kwargs["init_ckpt"], map_location="cpu"
174 | )
175 | model.load_state_dict(checkpoint["model"])
176 | return model
177 |
178 |
179 | @register_model
180 | def pretrain_videomae_teacher_large_patch16_224(pretrained=False, **kwargs):
181 | model = PretrainVideoTransformerTeacher(
182 | patch_size=16,
183 | encoder_embed_dim=1024,
184 | encoder_depth=24,
185 | encoder_num_heads=16,
186 | encoder_num_classes=0,
187 | mlp_ratio=4,
188 | qkv_bias=True,
189 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
190 | **kwargs)
191 | model.default_cfg = _cfg()
192 | if pretrained:
193 | checkpoint = torch.load(
194 | kwargs["init_ckpt"], map_location="cpu"
195 | )
196 | model.load_state_dict(checkpoint["model"])
197 | return model
198 |
199 |
200 | @register_model
201 | def pretrain_videomae_teacher_huge_patch16_224(pretrained=False, **kwargs):
202 | model = PretrainVideoTransformerTeacher(
203 | patch_size=16,
204 | encoder_embed_dim=1280,
205 | encoder_depth=32,
206 | encoder_num_heads=16,
207 | encoder_num_classes=0,
208 | mlp_ratio=4,
209 | qkv_bias=True,
210 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
211 | **kwargs)
212 | model.default_cfg = _cfg()
213 | if pretrained:
214 | checkpoint = torch.load(
215 | kwargs["init_ckpt"], map_location="cpu"
216 | )
217 | model.load_state_dict(checkpoint["model"])
218 | return model
219 |
--------------------------------------------------------------------------------
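The teacher variants above are registered with timm, so they can be built by name once this module is imported. A sketch (assumes the repository's `modeling_finetune` module is importable; the expected output shape follows from 16 frames, tubelet size 2, and a 14x14 patch grid):

```python
import torch
from timm.models import create_model

import modeling_video_teacher  # noqa: F401  (importing registers the teacher entrypoints)

teacher = create_model("pretrain_videomae_teacher_base_patch16_224", pretrained=False)
teacher.eval()

with torch.no_grad():
    feats = teacher(torch.randn(1, 3, 16, 224, 224))   # (B, C, T, H, W) clip

print(feats.shape)   # expected torch.Size([1, 1568, 768]) = (16/2) * 14 * 14 tokens
```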
/optim_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim as optim
3 |
4 | from timm.optim.adafactor import Adafactor
5 | from timm.optim.adahessian import Adahessian
6 | from timm.optim.adamp import AdamP
7 | from timm.optim.lookahead import Lookahead
8 | from timm.optim.nadam import Nadam
9 | from timm.optim.novograd import NovoGrad
10 | from timm.optim.nvnovograd import NvNovoGrad
11 | from timm.optim.radam import RAdam
12 | from timm.optim.rmsprop_tf import RMSpropTF
13 | from timm.optim.sgdp import SGDP
14 |
15 | import json
16 |
17 | try:
18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
19 | has_apex = True
20 | except ImportError:
21 | has_apex = False
22 |
23 |
24 | class LARS(torch.optim.Optimizer):
25 | """
26 |     LARS optimizer; no trust-ratio scaling or weight decay is applied to parameters with <= 1 dimension (biases, norms).
27 | """
28 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
29 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
30 | super().__init__(params, defaults)
31 |
32 | @torch.no_grad()
33 | def step(self):
34 | for g in self.param_groups:
35 | for p in g['params']:
36 | dp = p.grad
37 |
38 | if dp is None:
39 | continue
40 |
41 | if p.ndim > 1: # if not normalization gamma/beta or bias
42 | dp = dp.add(p, alpha=g['weight_decay'])
43 | param_norm = torch.norm(p)
44 | update_norm = torch.norm(dp)
45 | one = torch.ones_like(param_norm)
46 | q = torch.where(
47 | param_norm > 0.,
48 | torch.where(update_norm > 0, (g['trust_coefficient'] * param_norm / update_norm), one),
49 | one
50 | )
51 | dp = dp.mul(q)
52 |
53 | param_state = self.state[p]
54 | if 'mu' not in param_state:
55 | param_state['mu'] = torch.zeros_like(p)
56 | mu = param_state['mu']
57 | mu.mul_(g['momentum']).add_(dp)
58 | p.add_(mu, alpha=-g['lr'])
59 |
60 |
61 | def get_num_layer_for_vit(var_name, num_max_layer):
62 | if var_name in ("cls_token", "mask_token", "pos_embed"):
63 | return 0
64 | elif var_name.startswith("patch_embed"):
65 | return 0
66 | elif var_name.startswith("rel_pos_bias"):
67 | return num_max_layer - 1
68 | elif var_name.startswith("blocks"):
69 | layer_id = int(var_name.split('.')[1])
70 | return layer_id + 1
71 | else:
72 | return num_max_layer - 1
73 |
74 |
75 | class LayerDecayValueAssigner(object):
76 | def __init__(self, values):
77 | self.values = values
78 |
79 | def get_scale(self, layer_id):
80 | return self.values[layer_id]
81 |
82 | def get_layer_id(self, var_name):
83 | return get_num_layer_for_vit(var_name, len(self.values))
84 |
85 |
86 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
87 | parameter_group_names = {}
88 | parameter_group_vars = {}
89 |
90 | for name, param in model.named_parameters():
91 | if not param.requires_grad:
92 | continue # frozen weights
93 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
94 | group_name = "no_decay"
95 | this_weight_decay = 0.
96 | else:
97 | group_name = "decay"
98 | this_weight_decay = weight_decay
99 | if get_num_layer is not None:
100 | layer_id = get_num_layer(name)
101 | group_name = "layer_%d_%s" % (layer_id, group_name)
102 | else:
103 | layer_id = None
104 |
105 | if group_name not in parameter_group_names:
106 | if get_layer_scale is not None:
107 | scale = get_layer_scale(layer_id)
108 | else:
109 | scale = 1.
110 |
111 | parameter_group_names[group_name] = {
112 | "weight_decay": this_weight_decay,
113 | "params": [],
114 | "lr_scale": scale
115 | }
116 | parameter_group_vars[group_name] = {
117 | "weight_decay": this_weight_decay,
118 | "params": [],
119 | "lr_scale": scale
120 | }
121 |
122 | parameter_group_vars[group_name]["params"].append(param)
123 | parameter_group_names[group_name]["params"].append(name)
124 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
125 | return list(parameter_group_vars.values())
126 |
127 |
128 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
129 | opt_lower = args.opt.lower()
130 | weight_decay = args.weight_decay
131 | if filter_bias_and_bn:
132 | skip = {}
133 | if skip_list is not None:
134 | skip = skip_list
135 | elif hasattr(model, 'no_weight_decay'):
136 | skip = model.no_weight_decay()
137 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
138 | weight_decay = 0.
139 | else:
140 | parameters = model.parameters()
141 |
142 | if 'fused' in opt_lower:
143 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
144 |
145 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
146 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
147 | opt_args['eps'] = args.opt_eps
148 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
149 | opt_args['betas'] = args.opt_betas
150 |
151 | print("optimizer settings:", opt_args)
152 |
153 | opt_split = opt_lower.split('_')
154 | opt_lower = opt_split[-1]
155 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
156 | opt_args.pop('eps', None)
157 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
158 | elif opt_lower == 'momentum':
159 | opt_args.pop('eps', None)
160 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
161 | elif opt_lower == 'adam':
162 | optimizer = optim.Adam(parameters, **opt_args)
163 | elif opt_lower == 'adamw':
164 | optimizer = optim.AdamW(parameters, **opt_args)
165 | elif opt_lower == 'nadam':
166 | optimizer = Nadam(parameters, **opt_args)
167 | elif opt_lower == 'radam':
168 | optimizer = RAdam(parameters, **opt_args)
169 | elif opt_lower == 'adamp':
170 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
171 | elif opt_lower == 'sgdp':
172 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
173 | elif opt_lower == 'adadelta':
174 | optimizer = optim.Adadelta(parameters, **opt_args)
175 | elif opt_lower == 'adafactor':
176 | if not args.lr:
177 | opt_args['lr'] = None
178 | optimizer = Adafactor(parameters, **opt_args)
179 | elif opt_lower == 'adahessian':
180 | optimizer = Adahessian(parameters, **opt_args)
181 | elif opt_lower == 'rmsprop':
182 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
183 | elif opt_lower == 'rmsproptf':
184 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
185 | elif opt_lower == 'novograd':
186 | optimizer = NovoGrad(parameters, **opt_args)
187 | elif opt_lower == 'nvnovograd':
188 | optimizer = NvNovoGrad(parameters, **opt_args)
189 | elif opt_lower == 'fusedsgd':
190 | opt_args.pop('eps', None)
191 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
192 | elif opt_lower == 'fusedmomentum':
193 | opt_args.pop('eps', None)
194 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
195 | elif opt_lower == 'fusedadam':
196 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
197 | elif opt_lower == 'fusedadamw':
198 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
199 | elif opt_lower == 'fusedlamb':
200 | optimizer = FusedLAMB(parameters, **opt_args)
201 | elif opt_lower == 'fusednovograd':
202 | opt_args.setdefault('betas', (0.95, 0.98))
203 | optimizer = FusedNovoGrad(parameters, **opt_args)
204 | elif opt_lower == 'lars':
205 | opt_args.pop('eps', None)
206 | optimizer = LARS(parameters, **opt_args)
207 | else:
208 |         # unknown optimizer name: fail loudly with a clear message
209 |         raise ValueError("Invalid optimizer: %s" % opt_lower)
210 |
211 | if len(opt_split) > 1:
212 | if opt_split[0] == 'lookahead':
213 | optimizer = Lookahead(optimizer)
214 |
215 | return optimizer
216 |
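217 | # ---------------------------------------------------------------------------
218 | # Usage sketch (illustrative; the hyperparameter values are placeholders and
219 | # `model` is assumed to expose get_num_layers(), as the ViT models here do):
220 | #
221 | #   from argparse import Namespace
222 | #   args = Namespace(opt='adamw', lr=1.5e-4, weight_decay=0.05,
223 | #                    opt_eps=1e-8, opt_betas=(0.9, 0.95), momentum=0.9)
224 | #   num_layers = model.get_num_layers()
225 | #   assigner = LayerDecayValueAssigner(
226 | #       [0.75 ** (num_layers + 1 - i) for i in range(num_layers + 2)])
227 | #   optimizer = create_optimizer(args, model,
228 | #                                get_num_layer=assigner.get_layer_id,
229 | #                                get_layer_scale=assigner.get_scale)
230 | # ---------------------------------------------------------------------------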
--------------------------------------------------------------------------------
/engine_locomotion_prediction_for_finetuning.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import math
4 | import sys
5 | from typing import Iterable, Optional
6 | import torch
7 | from timm.utils import accuracy, ModelEma
8 | import utils
9 | from scipy.special import softmax
10 | from einops import rearrange
11 | from torch.utils.data._utils.collate import default_collate
12 | import torch.nn.functional as F
13 | import pandas as pd
14 | from matplotlib import pyplot as plt
15 | import os
16 |
17 | parent = os.path.dirname(os.path.abspath(__file__))
18 | parent_parent = os.path.join(parent, '../')
19 | sys.path.append(os.path.dirname(parent_parent))
20 |
21 | from locomotion_prediction.locomotion_prediction_utils import *
22 |
23 |
24 | def train_class_batch(model, samples, dP_tensor, criterion, args):
25 | outputs = model(samples)
26 | dP_pred = outputs.unflatten(1, (args.act_pose_prediction, args.num_pred))
27 |
28 | loss = criterion(dP_pred, dP_tensor)
29 |
30 | return loss, dP_pred
31 |
32 |
33 | def get_loss_scale_for_deepspeed(model):
34 | optimizer = model.optimizer
35 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
36 |
37 |
38 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
39 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
40 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
41 | model_ema: Optional[ModelEma] = None, mixup_fn=None, log_writer=None,
42 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
43 | num_training_steps_per_epoch=None, update_freq=None, args=None):
44 | model.train(True)
45 | metric_logger = utils.MetricLogger(delimiter=" ")
46 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
47 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
48 | header = 'Epoch: [{}]'.format(epoch)
49 | print_freq = 10
50 |
51 | if loss_scaler is None:
52 | model.zero_grad()
53 | model.micro_steps = 0
54 | else:
55 | optimizer.zero_grad()
56 |
57 | for data_iter_step, (samples, dP_tensor, start_idx, end_idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
58 | step = data_iter_step // update_freq
59 | if step >= num_training_steps_per_epoch:
60 | continue
61 | it = start_steps + step # global training iteration
62 |         # update LR & WD only on the first step of each gradient-accumulation cycle
63 |         if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
64 | for i, param_group in enumerate(optimizer.param_groups):
65 | if lr_schedule_values is not None and 'lr_scale' in param_group:
66 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
67 | if wd_schedule_values is not None and 'weight_decay' in param_group and param_group["weight_decay"] > 0:
68 | param_group["weight_decay"] = wd_schedule_values[it]
69 |
70 | samples = samples.to(device, non_blocking=True)
71 | samples = samples.permute(0, 2, 1, 3, 4)
72 | dP_tensor = dP_tensor.to(device, non_blocking=True)
73 |
74 | if mixup_fn is not None:
75 |             raise NotImplementedError("mixup is not supported for pose regression (no class targets to mix)")
76 |
77 | if loss_scaler is None:
78 | samples = samples.half()
79 | loss, dP_pred = train_class_batch(
80 | model, samples, dP_tensor, criterion, args)
81 | else:
82 | with torch.cuda.amp.autocast():
83 | loss, dP_pred = train_class_batch(
84 | model, samples, dP_tensor, criterion, args)
85 |
86 | loss_value = loss.item()
87 |
88 | if not math.isfinite(loss_value):
89 | print("Loss is {}, stopping training".format(loss_value))
90 | sys.exit(1)
91 |
92 | if loss_scaler is None:
93 | loss /= update_freq
94 | model.backward(loss)
95 | model.step()
96 |
97 | if (data_iter_step + 1) % update_freq == 0:
98 | # model.zero_grad()
99 |                 # DeepSpeed calls step() & model.zero_grad() automatically
100 | if model_ema is not None:
101 | model_ema.update(model)
102 | grad_norm = None
103 | loss_scale_value = get_loss_scale_for_deepspeed(model)
104 | else:
105 | # this attribute is added by timm on one optimizer (adahessian)
106 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
107 | loss /= update_freq
108 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
109 | parameters=model.parameters(), create_graph=is_second_order,
110 | update_grad=(data_iter_step + 1) % update_freq == 0)
111 | if (data_iter_step + 1) % update_freq == 0:
112 | optimizer.zero_grad()
113 | if model_ema is not None:
114 | model_ema.update(model)
115 | loss_scale_value = loss_scaler.state_dict()["scale"]
116 |
117 | torch.cuda.synchronize()
118 |
119 | metric_logger.update(loss=loss_value)
120 | metric_logger.update(loss_scale=loss_scale_value)
121 | min_lr = 10.
122 | max_lr = 0.
123 | for group in optimizer.param_groups:
124 | min_lr = min(min_lr, group["lr"])
125 | max_lr = max(max_lr, group["lr"])
126 |
127 | metric_logger.update(lr=max_lr)
128 | metric_logger.update(min_lr=min_lr)
129 | weight_decay_value = None
130 | for group in optimizer.param_groups:
131 | if group["weight_decay"] > 0:
132 | weight_decay_value = group["weight_decay"]
133 | metric_logger.update(weight_decay=weight_decay_value)
134 | metric_logger.update(grad_norm=grad_norm)
135 |
136 | if log_writer is not None:
137 | log_writer.update(loss=loss_value, head="loss")
138 | log_writer.update(loss_scale=loss_scale_value, head="opt")
139 | log_writer.update(lr=max_lr, head="opt")
140 | log_writer.update(min_lr=min_lr, head="opt")
141 | log_writer.update(weight_decay=weight_decay_value, head="opt")
142 | log_writer.update(grad_norm=grad_norm, head="opt")
143 |
144 | log_writer.set_step()
145 |
146 | # gather the stats from all processes
147 | metric_logger.synchronize_between_processes()
148 | print("Averaged stats:", metric_logger)
149 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
150 |
151 |
152 | @torch.no_grad()
153 | def validation_one_epoch(dataset_val, model, criterion, device, args):
154 | num_tasks = utils.get_world_size()
155 | global_rank = utils.get_rank()
156 |
157 | if args.dist_eval:
158 | if len(dataset_val) % num_tasks != 0:
159 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
160 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
161 | 'equal num of samples per-process.')
162 | sampler_val = torch.utils.data.DistributedSampler(
163 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
164 | else:
165 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
166 |
167 | data_loader = torch.utils.data.DataLoader(
168 | dataset_val, sampler=sampler_val,
169 | batch_size=1,
170 | num_workers=args.num_workers,
171 | pin_memory=args.pin_mem,
172 | drop_last=True
173 | )
174 |
175 | metric_logger = utils.MetricLogger(delimiter=" ")
176 | header = 'Test:'
177 |
178 | # switch to evaluation mode
179 | model.eval()
180 |
181 | for (video_path, trajectory_path, start_idx, end_idx) in metric_logger.log_every(data_loader, 10, header):
182 | try:
183 | video_path, trajectory_path, start_idx, end_idx = video_path[0], trajectory_path[0], start_idx[0].item(), end_idx[0].item()
184 | with torch.cuda.amp.autocast():
185 | loss, ate, rpe_trans, rpe_rot = evaluate_segment(model, dataset_val, device, video_path, trajectory_path, criterion, start_idx, end_idx, args, save_plot=True)
186 | metric_logger.meters['loss'].update(loss.item(), n=1)
187 | metric_logger.meters['ate'].update(ate, n=1)
188 | metric_logger.meters['rpe_trans'].update(rpe_trans, n=1)
189 | metric_logger.meters['rpe_rot'].update(rpe_rot, n=1)
190 | except Exception as e:
191 | print('Evaluation Error: ', e)
192 |
193 | # gather the stats from all processes
194 | metric_logger.synchronize_between_processes()
195 | print('* loss {losses.global_avg:.5f}'
196 | .format(losses=metric_logger.loss))
197 | print('* ate {ate.global_avg:.5f}'
198 | .format(ate=metric_logger.ate))
199 | print('* rpe_trans {rpe_trans.global_avg:.5f}'
200 | .format(rpe_trans=metric_logger.rpe_trans))
201 | print('* rpe_rot {rpe_rot.global_avg:.5f}'
202 | .format(rpe_rot=metric_logger.rpe_rot))
203 |
204 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
205 |
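206 | # ---------------------------------------------------------------------------
207 | # Shape sketch for train_class_batch (illustrative; the numbers are
208 | # hypothetical, e.g. args.act_pose_prediction = 4 future poses and
209 | # args.num_pred = 3 values per pose, i.e. an (x, y, z) translation):
210 | #
211 | #   outputs = model(samples)                 # (B, 4 * 3)
212 | #   dP_pred = outputs.unflatten(1, (4, 3))   # (B, 4, 3)
213 | #   loss = criterion(dP_pred, dP_tensor)     # dP_tensor is also (B, 4, 3)
214 | # ---------------------------------------------------------------------------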
--------------------------------------------------------------------------------
/object_interaction/object_interaction_utils.py:
--------------------------------------------------------------------------------
1 | import os, math, random  # math/random are used by get_start_end_idx below
2 | import torch
3 |
4 | objects_dict = {
5 | 'person': 0,
6 | 'ball': 1,
7 | 'bench': 2,
8 | 'bird': 3,
9 | 'dog': 4,
10 | 'cat': 5,
11 | 'other animal': 6,
12 | 'toy': 7,
13 | 'door': 8,
14 | 'floor': 9,
15 | 'food': 10,
16 | 'plant': 11,
17 | 'filament': 12,
18 | 'plastic': 13,
19 | 'water': 14,
20 | 'vehicle': 15,
21 | 'other': 16,
22 | }
23 |
24 | # Decoding Helper Functions
25 | def spatial_sampling(
26 | frames,
27 | spatial_idx=-1,
28 | min_scale=256,
29 | max_scale=320,
30 | crop_size=224,
31 | random_horizontal_flip=True,
32 | inverse_uniform_sampling=False,
33 | aspect_ratio=None,
34 | scale=None,
35 | motion_shift=False,
36 | ):
37 | """
38 | Perform spatial sampling on the given video frames. If spatial_idx is
39 | -1, perform random scale, random crop, and random flip on the given
40 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
41 | with the given spatial_idx.
42 | Args:
43 | frames (tensor): frames of images sampled from the video. The
44 | dimension is `num frames` x `height` x `width` x `channel`.
45 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
46 | or 2, perform left, center, right crop if width is larger than
47 |             height, and perform top, center, bottom crop if height is larger
48 | than width.
49 | min_scale (int): the minimal size of scaling.
50 | max_scale (int): the maximal size of scaling.
51 | crop_size (int): the size of height and width used to crop the
52 | frames.
53 | inverse_uniform_sampling (bool): if True, sample uniformly in
54 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
55 | scale. If False, take a uniform sample from [min_scale,
56 | max_scale].
57 | aspect_ratio (list): Aspect ratio range for resizing.
58 | scale (list): Scale range for resizing.
59 | motion_shift (bool): Whether to apply motion shift for resizing.
60 | Returns:
61 | frames (tensor): spatially sampled frames.
62 | """
63 |
64 | assert spatial_idx in [-1, 0, 1, 2]
65 | if spatial_idx == -1:
66 | if aspect_ratio is None and scale is None:
67 | frames = transform.random_short_side_scale_jitter(
68 | images=frames,
69 | min_size=min_scale,
70 | max_size=max_scale,
71 | inverse_uniform_sampling=inverse_uniform_sampling,
72 | )
73 | frames = transform.random_crop(frames, crop_size)
74 | else:
75 | transform_func = (
76 | transform.random_resized_crop_with_shift
77 | if motion_shift
78 | else transform.random_resized_crop
79 | )
80 | frames = transform_func(
81 | images=frames,
82 | target_height=crop_size,
83 | target_width=crop_size,
84 | scale=scale,
85 | ratio=aspect_ratio,
86 | )
87 | if random_horizontal_flip:
88 | frames = transform.horizontal_flip(0.5, frames)
89 | else:
90 | # The testing is deterministic and no jitter should be performed.
91 |         # min_scale, max_scale, and crop_size are expected to be the same.
92 | assert len({min_scale, max_scale}) == 1
93 | frames = transform.random_short_side_scale_jitter(frames, min_scale, max_scale)
94 | frames = transform.uniform_crop(frames, crop_size, spatial_idx)
95 | return frames
96 |
97 | def temporal_sampling(frames, start_idx, end_idx, num_samples):
98 | """
99 | Given the start and end frame index, sample num_samples frames between
100 | the start and end with equal interval.
101 | Args:
102 | frames (tensor): a tensor of video frames, dimension is
103 | `num video frames` x `channel` x `height` x `width`.
104 | start_idx (int): the index of the start frame.
105 | end_idx (int): the index of the end frame.
106 | num_samples (int): number of frames to sample.
107 | Returns:
108 |         frames (tensor): a tensor of temporally sampled video frames, dimension is
109 | `num clip frames` x `channel` x `height` x `width`.
110 | """
111 | index = torch.linspace(start_idx, end_idx, num_samples)
112 | index = torch.clamp(index, 0, frames.shape[0] - 1).long()
113 | new_frames = torch.index_select(frames, 0, index)
114 | return new_frames
115 |
116 | def get_start_end_idx(video_size, clip_size, clip_idx, num_clips, use_offset=False):
117 | """
118 | Sample a clip of size clip_size from a video of size video_size and
119 | return the indices of the first and last frame of the clip. If clip_idx is
120 | -1, the clip is randomly sampled, otherwise uniformly split the video to
121 | num_clips clips, and select the start and end index of clip_idx-th video
122 | clip.
123 | Args:
124 | video_size (int): number of overall frames.
125 | clip_size (int): size of the clip to sample from the frames.
126 | clip_idx (int): if clip_idx is -1, perform random jitter sampling. If
127 | clip_idx is larger than -1, uniformly split the video to num_clips
128 | clips, and select the start and end index of the clip_idx-th video
129 | clip.
130 | num_clips (int): overall number of clips to uniformly sample from the
131 | given video for testing.
132 | Returns:
133 | start_idx (int): the start frame index.
134 | end_idx (int): the end frame index.
135 | """
136 | delta = max(video_size - clip_size, 0)
137 | if clip_idx == -1:
138 | # Random temporal sampling.
139 | start_idx = random.uniform(0, delta)
140 | else:
141 | if use_offset:
142 | if num_clips == 1:
143 | # Take the center clip if num_clips is 1.
144 | start_idx = math.floor(delta / 2)
145 | else:
146 | # Uniformly sample the clip with the given index.
147 | start_idx = clip_idx * math.floor(delta / (num_clips - 1))
148 | else:
149 | # Uniformly sample the clip with the given index.
150 | start_idx = delta * clip_idx / num_clips
151 | end_idx = start_idx + clip_size - 1
152 | return start_idx, end_idx
153 |
154 |
155 | # Helper Functions
156 | def get_video_path(path_to_data_dir, animal, video_id, segment_id):
157 | video_name = 'edited_{video_id}_segment_{segment_id}.mp4'.format(video_id=video_id, segment_id=segment_id.zfill(6))
158 | video_path = os.path.join(path_to_data_dir, animal, video_name)
159 | return video_path
160 |
161 | def process_time(time):
162 |     # Converts a time given as ##:##:##, ##:##, # (seconds), or NONE into seconds
163 | if time == 'NONE':
164 | return 0
165 |
166 | time_split = time.split(':')
167 | if len(time_split) == 1:
168 | time_sec = int(time_split[0])
169 | elif len(time_split) == 2:
170 | time_sec = int(time_split[0]) * 60 + int(time_split[1])
171 | elif len(time_split) == 3:
172 | time_sec = int(time_split[0]) * 3600 + int(time_split[1]) * 60 + int(time_split[2])
173 | else:
174 | raise ValueError('Invalid Time was {time} but expected in format ##:##:##, ##:##, or #'.format(time=time))
175 | return time_sec
176 |
177 | def process_end_time(time, total_time):
178 | # Get end time for processing NONE times
179 | if time == 'NONE':
180 | return process_time(total_time)
181 | else:
182 | return process_time(time)
183 |
184 | def get_none_segments(start_time, end_time, total_time, interacting_object):
185 |     min_beg_end_length = 4 + 2 # the clip itself must be at least 4 seconds, plus a 2-second gap to the adjacent clip
186 |     min_mid_length = 2 + 4 + 2 # 2-second gap after the previous clip, at least 4 seconds of clip, 2-second gap before the next clip
187 |
188 | new_start_time = []
189 | new_end_time = []
190 | new_interacting_object = []
191 |
192 | # Checking middle portions
193 | for i in range(len(start_time) - 1):
194 | idx = i+1
195 | if start_time[idx] - end_time[i] >= min_mid_length:
196 | new_start_time.append(end_time[i]+2)
197 | new_end_time.append(start_time[idx]-2)
198 | new_interacting_object.append('NONE')
199 |
200 | # Checking beginning
201 | if start_time[0] >= min_beg_end_length:
202 | new_start_time.append(0)
203 | new_end_time.append(start_time[0]-2)
204 | new_interacting_object.append('NONE')
205 |
206 | # Checking ending
207 | if total_time - end_time[-1] >= min_beg_end_length:
208 | new_start_time.append(end_time[-1]+2)
209 | new_end_time.append(total_time)
210 | new_interacting_object.append('NONE')
211 |
212 | start_time.extend(new_start_time)
213 | end_time.extend(new_end_time)
214 | interacting_object.extend(new_interacting_object)
215 | return start_time, end_time, interacting_object
216 |
217 | def get_label(interacting_object):
218 | if interacting_object == 'NONE':
219 | return -1
220 | elif interacting_object in objects_dict:
221 | object_label_idx = objects_dict[interacting_object]
222 | else:
223 | raise ValueError("Invalid Object Type: {curr_object}".format(curr_object=interacting_object))
224 |
225 | return object_label_idx
226 |
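227 |
228 | if __name__ == "__main__":
229 |     # Smoke test (illustrative, not part of the training pipeline); the
230 |     # expected values follow directly from the helpers above.
231 |     assert process_time('01:02:03') == 3723  # hours:minutes:seconds -> seconds
232 |     assert process_time('NONE') == 0
233 |     assert get_label('dog') == 4
234 |     assert get_label('NONE') == -1            # non-interaction segments map to -1
235 |     print('object_interaction_utils smoke test passed')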
--------------------------------------------------------------------------------
/locomotion_prediction/locomotion_prediction_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | import os
5 | import random
6 | import math
7 |
8 | import torch
9 | import numpy as np
10 | import torch.utils.data
11 | from iopath.common.file_io import g_pathmgr as pathmgr
12 |
13 | import sys
14 | from decoder.utils import decode_ffmpeg
15 | from torchvision import transforms
16 | from torch.nn.functional import normalize
17 |
18 | ### evo evaluation library ###
19 | import evo
20 | from evo.core.trajectory import PoseTrajectory3D
21 | from evo.tools import file_interface
22 | from evo.core import sync, metrics
23 | import evo.main_ape as main_ape
24 | import evo.main_rpe as main_rpe
25 | from evo.core.metrics import PoseRelation
26 | from evo.tools import plot
27 | from dpvo.plot_utils import plot_trajectory, save_trajectory_tum_format
28 |
29 | # (sys is already imported above)
30 | parent = os.path.dirname(os.path.abspath(__file__))
31 | parent_parent = os.path.join(parent, '../')
32 | sys.path.append(os.path.dirname(parent_parent))
33 |
34 | # plot_trajectory and save_trajectory_tum_format are already imported from dpvo.plot_utils above
35 | from pathlib import Path
36 |
37 |
38 | def get_poses(trajectory_path, start_idx, end_idx):
39 | """
40 | Gets the camera poses for the frames start_idx to end_idx.
41 | Args:
42 | trajectory_path (string): the path to the file containing the trajectory
43 | in TUM format.
44 | start_idx (int): the start index to get poses
45 | end_idx (int): the end index to get poses
46 | Returns:
47 |         traj_poses (tensor): the poses, (end_idx - start_idx) x 7 (xyz translation + xyzw quaternion)
48 | """
49 | traj = file_interface.read_tum_trajectory_file(trajectory_path)
50 | traj_poses = np.concatenate((traj.positions_xyz, traj.orientations_quat_wxyz[:, [1, 2, 3, 0]]), axis=1)
51 | traj_poses = torch.from_numpy(traj_poses)
52 |
53 | return traj_poses[start_idx:end_idx]
54 |
55 |
56 | def relative_poses(trajectory_path, start_idx, end_idx, scale_invariance, pose_skip):
57 | """
58 | Gets dP, the relative pose to the prior frame, between camera poses for the
59 | frames start_idx to end_idx.
60 | Args:
61 | trajectory_path (string): the path to the file containing the trajectory
62 | in TUM format.
63 | start_idx (int): the start index to get relative poses
64 | end_idx (int): the end index to get relative poses
65 | scale_invariance (str): scale invariance type ('dir' or None)
66 | pose_skip (int): stride of pose prediction
67 | Returns:
68 |         dP_tensor (tensor): the relative translations, summed over groups of
69 |         pose_skip frames; the dimension is ((end_idx - start_idx) / pose_skip) x 3.
70 | scale (tensor): if scale_invariance=='dir' return the scale of the
71 | poses
72 | """
73 | traj = file_interface.read_tum_trajectory_file(trajectory_path)
74 | translation = traj.positions_xyz[1:] - traj.positions_xyz[:-1]
75 |
76 | dP_tensor = torch.from_numpy(translation)
77 |     zero_tensor = torch.tensor([[0, 0, 0]]) # prepend a zero 0-th relative pose so indices align with frame indices
78 | dP_tensor = torch.cat([zero_tensor, dP_tensor], dim=0)
79 |
80 | dP_tensor = dP_tensor[start_idx:end_idx]
81 | dP_tensor = dP_tensor.unflatten(0, (-1, pose_skip))
82 | dP_tensor = torch.sum(dP_tensor, dim=1)
83 |
84 | if scale_invariance == 'dir':
85 | scale = torch.norm(dP_tensor, p=2, dim=1, keepdim=True)
86 | dP_tensor = dP_tensor / (scale + 1e-10)
87 | return dP_tensor, scale
88 |
89 | return dP_tensor
90 |
91 |
92 | def make_traj_from_tensor(traj_tensor, start_idx, end_idx, num_condition_frames, num_pose_prediction, act_pose_prediction, pose_skip):
93 | tstamps = np.arange(0, end_idx, 1, dtype=np.float64)
94 | tstamps = tstamps.reshape((-1, num_condition_frames + num_pose_prediction))[:, num_condition_frames+(pose_skip-1)::pose_skip]
95 | tstamps = tstamps.flatten()
96 |
97 | tstamps_idx = np.arange(0, len(traj_tensor), 1, dtype=np.float64)
98 | tstamps_idx = tstamps_idx.reshape((-1, num_condition_frames + act_pose_prediction))[:, num_condition_frames:]
99 | tstamps_idx = tstamps_idx.flatten()
100 |
101 | traj = PoseTrajectory3D(positions_xyz=traj_tensor[tstamps_idx.astype(int),:3], orientations_quat_wxyz=traj_tensor[tstamps_idx.astype(int),3:][:, [3, 0, 1, 2]], timestamps=tstamps)
102 |
103 | return traj
104 |
105 |
106 | def eval_metrics(traj_ref, traj_pred):
107 | traj_ref, traj_pred = sync.associate_trajectories(traj_ref, traj_pred)
108 |
109 | result = main_ape.ape(traj_ref, traj_pred, est_name='traj',
110 | pose_relation=PoseRelation.translation_part, align=True, correct_scale=True)
111 | ate = result.stats['rmse']
112 |
113 | result = main_rpe.rpe(traj_ref, traj_pred, est_name='traj',
114 | pose_relation=PoseRelation.rotation_angle_deg, align=True, correct_scale=True,
115 | delta=1.0, delta_unit=metrics.Unit.frames, rel_delta_tol=0.1)
116 | rpe_rot = result.stats['rmse']
117 |
118 | result = main_rpe.rpe(traj_ref, traj_pred, est_name='traj',
119 | pose_relation=PoseRelation.translation_part, align=True, correct_scale=True,
120 | delta=1.0, delta_unit=metrics.Unit.frames, rel_delta_tol=0.1)
121 | rpe_trans = result.stats['rmse']
122 |
123 | return ate, rpe_trans, rpe_rot
124 |
125 |
126 | def evaluate_segment(model, dataset_val, device, video_path, trajectory_path, criterion, start_idx, end_idx, args, save_plot=False, image_model=False):
127 | # Assume batch of 1
128 | loss = 0
129 | total_idx = end_idx - start_idx
130 |
131 | frames_per_sub_clip = args.num_condition_frames + args.num_pose_prediction
132 | act_frames_per_sub_clip = args.num_condition_frames + args.act_pose_prediction
133 | num_sub_clips = total_idx // frames_per_sub_clip
134 | num_sub_batch = math.ceil(num_sub_clips / args.validation_batch_size)
135 |
136 | all_cond_poses = get_poses(trajectory_path, start_idx, end_idx).cpu()
137 | all_cond_poses = all_cond_poses.unflatten(0, (-1, frames_per_sub_clip))
138 | all_cond_poses = torch.cat((all_cond_poses[:, :args.num_condition_frames ,:], all_cond_poses[:, args.num_condition_frames+(args.pose_skip-1)::args.pose_skip]), 1)
139 | all_cond_poses = all_cond_poses.flatten(0, 1)
140 | traj_ref = make_traj_from_tensor(all_cond_poses, start_idx, end_idx, args.num_condition_frames, args.num_pose_prediction, args.act_pose_prediction, args.pose_skip)
141 |
142 | traj_pred_poses = all_cond_poses.clone()
143 | traj_pred_poses = traj_pred_poses.unflatten(0, (-1, act_frames_per_sub_clip))
144 | traj_pred_poses[:, args.num_condition_frames:] = 0
145 | traj_pred_poses = traj_pred_poses.flatten(0, 1)
146 | traj_pred_poses[:, 3:] = all_cond_poses[:, 3:]
147 | traj_pred_poses = traj_pred_poses.to(device=device, dtype=torch.float32)
148 |
149 | for i in range(num_sub_batch):
150 | sub_start_idx = i * args.validation_batch_size * frames_per_sub_clip
151 | sub_end_idx = min((i + 1) * args.validation_batch_size * frames_per_sub_clip, total_idx)
152 |
153 | frames, dP_tensor = dataset_val._get_sub_clip(video_path, trajectory_path, sub_start_idx, sub_end_idx)
154 |
155 | frames = frames.to(device, non_blocking=True)
156 | dP_tensor = dP_tensor.to(device= device, dtype=torch.float32, non_blocking=True)
157 |
158 | frames = frames.unflatten(0, (-1, frames_per_sub_clip))
159 | frames = frames[:, :args.num_condition_frames, :, :, :]
160 | frames = frames.permute(0, 2, 1, 3, 4)
161 |
162 | dP_tensor = dP_tensor.unflatten(0, (-1, frames_per_sub_clip))
163 | dP_tensor = dP_tensor[:, args.num_condition_frames:, :].unflatten(1, (-1, args.pose_skip))
164 | dP_tensor = torch.sum(dP_tensor, dim=2)
165 |
166 | if args.scale_invariance == 'dir':
167 | scale = torch.norm(dP_tensor, p=2, dim=2, keepdim=True)
168 | dP_tensor = dP_tensor / (scale + 1e-10)
169 |
170 | if image_model:
171 | frames = frames[:, :, -1, :, :]
172 |
173 | outputs = model(frames)
174 | dP_pred = outputs.unflatten(1, (args.act_pose_prediction, args.num_pred))
175 |
176 | loss += criterion(dP_pred, dP_tensor)
177 |
178 | if args.scale_invariance == 'dir':
179 | dP_pred_unnorm = dP_pred * scale
180 |
181 | i_start_idx = i*args.validation_batch_size*act_frames_per_sub_clip
182 | for j in range(dP_pred_unnorm.shape[0]):
183 | j_start_idx = i_start_idx + j * act_frames_per_sub_clip + args.num_condition_frames
184 |
185 | start_xyz = traj_pred_poses[j_start_idx - 1, :3].unsqueeze(0)
186 |
187 | pred_xyz = torch.cat([start_xyz, dP_pred_unnorm[j]], dim=0)
188 | pred_xyz = torch.cumsum(pred_xyz, dim=0)[1:]
189 | traj_pred_poses[j_start_idx:j_start_idx + args.act_pose_prediction, :3] = pred_xyz
190 |
191 | traj_pred = make_traj_from_tensor(traj_pred_poses.detach().cpu(), start_idx, end_idx, args.num_condition_frames, args.num_pose_prediction, args.act_pose_prediction, args.pose_skip)
192 | ate, rpe_trans, rpe_rot = eval_metrics(traj_ref, traj_pred)
193 |
194 | segment_id = video_path.split('/')[-1][:-4]
195 | if save_plot and args.output_dir:
196 |         # segment_id was already computed above
197 | title = 'Reconstruction: {}'.format(segment_id)
198 | filename = os.path.join(args.output_dir, 'traj_viz', '{}_reconstruction.png'.format(segment_id))
199 | plot_trajectory(traj_pred, gt_traj=traj_ref, title=title, filename=filename, align=False, correct_scale=True)
200 |
201 | return loss, ate, rpe_trans, rpe_rot
202 |
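203 | # ---------------------------------------------------------------------------
204 | # Sketch of the relative-pose computation (illustrative; the trajectory file
205 | # path below is hypothetical):
206 | #
207 | #   dP, scale = relative_poses('trajectories/segment_000001.txt',  # TUM format
208 | #                              start_idx=0, end_idx=32,
209 | #                              scale_invariance='dir', pose_skip=4)
210 | #   # dP:    (8, 3) unit-norm translation directions, one per group of 4 frames
211 | #   # scale: (8, 1) magnitudes of those translations
212 | # ---------------------------------------------------------------------------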
--------------------------------------------------------------------------------
/modeling_teacher.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from functools import partial, reduce
8 | from operator import mul
9 |
10 | from modeling_finetune import _cfg
11 | from timm.models.registry import register_model
12 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_
13 | from timm.models.vision_transformer import PatchEmbed, Block
14 | from timm.models.layers import drop_path, to_2tuple
15 |
16 |
17 | def trunc_normal_(tensor, mean=0., std=1.):
18 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
19 |
20 |
21 | __all__ = [
22 |     'mae_teacher_vit_base_patch16',
23 |     'mae_teacher_vit_large_patch16',
24 |     'mae_teacher_vit_huge_patch14',
25 | ]
26 |
27 |
28 | # --------------------------------------------------------
29 | # 2D sine-cosine position embedding
30 | # References:
31 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
32 | # MoCo v3: https://github.com/facebookresearch/moco-v3
33 | # --------------------------------------------------------
34 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
35 | """
36 | grid_size: int of the grid height and width
37 | return:
38 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
39 | """
40 | grid_h = np.arange(grid_size, dtype=np.float32)
41 | grid_w = np.arange(grid_size, dtype=np.float32)
42 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
43 | grid = np.stack(grid, axis=0)
44 |
45 | grid = grid.reshape([2, 1, grid_size, grid_size])
46 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
47 | if cls_token:
48 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
49 | return pos_embed
50 |
51 |
52 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
53 | assert embed_dim % 2 == 0
54 |
55 | # use half of dimensions to encode grid_h
56 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
57 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
58 |
59 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
60 | return emb
61 |
62 |
63 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
64 | """
65 | embed_dim: output dimension for each position
66 | pos: a list of positions to be encoded: size (M,)
67 | out: (M, D)
68 | """
69 | assert embed_dim % 2 == 0
70 |     omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy >= 1.24
71 | omega /= embed_dim / 2.
72 | omega = 1. / 10000**omega # (D/2,)
73 |
74 | pos = pos.reshape(-1) # (M,)
75 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
76 |
77 | emb_sin = np.sin(out) # (M, D/2)
78 | emb_cos = np.cos(out) # (M, D/2)
79 |
80 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
81 | return emb
82 |
83 |
84 | # --------------------------------------------------------
85 | # Interpolate position embeddings for high-resolution
86 | # References:
87 | # DeiT: https://github.com/facebookresearch/deit
88 | # --------------------------------------------------------
89 | def interpolate_pos_embed(model, checkpoint_model):
90 | if 'pos_embed' in checkpoint_model:
91 | pos_embed_checkpoint = checkpoint_model['pos_embed']
92 | embedding_size = pos_embed_checkpoint.shape[-1]
93 | num_patches = model.patch_embed.num_patches
94 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
95 | # height (== width) for the checkpoint position embedding
96 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
97 | # height (== width) for the new position embedding
98 | new_size = int(num_patches ** 0.5)
99 | # class_token and dist_token are kept unchanged
100 | if orig_size != new_size:
101 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
102 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
103 | # only the position tokens are interpolated
104 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
105 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
106 | pos_tokens = torch.nn.functional.interpolate(
107 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
108 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
109 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
110 | checkpoint_model['pos_embed'] = new_pos_embed
111 |
112 |
113 | # --------------------------------------------------------
114 | # MAE encoder
115 | # References:
116 | # MAE: https://github.com/facebookresearch/mae
117 | # --------------------------------------------------------
118 | class MaskedAutoencoderViT(nn.Module):
119 | """ Masked Autoencoder with VisionTransformer backbone
120 | """
121 |
122 | def __init__(self, img_size=224, patch_size=16, in_chans=3,
123 | embed_dim=1024, depth=24, num_heads=16,
124 | mlp_ratio=4., norm_layer=nn.LayerNorm,
125 | ):
126 | super().__init__()
127 |
128 | # --------------------------------------------------------------------------
129 | # MAE encoder specifics
130 | self.img_size = img_size
131 | self.patch_size = patch_size
132 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
133 | num_patches = self.patch_embed.num_patches
134 |
135 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
136 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
137 | requires_grad=False) # fixed sin-cos embedding
138 |
139 | self.blocks = nn.ModuleList([
140 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
141 | for i in range(depth)])
142 | self.norm = norm_layer(embed_dim)
143 | # --------------------------------------------------------------------------
144 |
145 | self.initialize_weights()
146 |
147 | def initialize_weights(self):
148 | # initialization
149 | # initialize (and freeze) pos_embed by sin-cos embedding
150 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** .5),
151 | cls_token=True)
152 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
153 |
154 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
155 | w = self.patch_embed.proj.weight.data
156 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
157 |
158 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
159 | torch.nn.init.normal_(self.cls_token, std=.02)
160 |
161 | # initialize nn.Linear and nn.LayerNorm
162 | self.apply(self._init_weights)
163 |
164 | def _init_weights(self, m):
165 | if isinstance(m, nn.Linear):
166 | # we use xavier_uniform following official JAX ViT:
167 | torch.nn.init.xavier_uniform_(m.weight)
168 | if isinstance(m, nn.Linear) and m.bias is not None:
169 | nn.init.constant_(m.bias, 0)
170 | elif isinstance(m, nn.LayerNorm):
171 | nn.init.constant_(m.bias, 0)
172 | nn.init.constant_(m.weight, 1.0)
173 |
174 | def patchify(self, imgs):
175 | """
176 | imgs: (N, 3, H, W)
177 | x: (N, L, patch_size**2 *3)
178 | """
179 | p = self.patch_embed.patch_size[0]
180 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
181 |
182 | h = w = imgs.shape[2] // p
183 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
184 | x = torch.einsum('nchpwq->nhwpqc', x)
185 | x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
186 | return x
187 |
188 | def forward_encoder(self, x):
189 | # embed patches
190 | x = self.patch_embed(x)
191 |
192 | # add pos embed w/o cls token
193 | x = x + self.pos_embed[:, 1:, :]
194 |
195 | B, _, C = x.shape
196 |
197 | # append cls token
198 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
199 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
200 | x = torch.cat((cls_tokens, x), dim=1)
201 |
202 | # apply Transformer blocks
203 | for i, blk in enumerate(self.blocks):
204 | x = blk(x)
205 |
206 | return x[:, 1:]
207 |
208 | def forward(self, imgs):
209 | latent = self.forward_encoder(imgs)
210 | return latent
211 |
212 |
213 | @register_model
214 | def mae_teacher_vit_base_patch16(pretrained=False, **kwargs):
215 | model = MaskedAutoencoderViT(
216 | patch_size=16, embed_dim=768, depth=12, num_heads=12,
217 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
218 | return model
219 |
220 |
221 | @register_model
222 | def mae_teacher_vit_large_patch16(pretrained=False, **kwargs):
223 | model = MaskedAutoencoderViT(
224 | patch_size=16, embed_dim=1024, depth=24, num_heads=16,
225 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
226 | return model
227 |
228 |
229 | @register_model
230 | def mae_teacher_vit_huge_patch14(pretrained=False, **kwargs):
231 | model = MaskedAutoencoderViT(
232 | patch_size=14, embed_dim=1280, depth=32, num_heads=16,
233 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
234 | return model
235 |
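236 |
237 | if __name__ == "__main__":
238 |     # Shape check (illustrative): grid_size = 14 corresponds to a 224x224
239 |     # image split into 16x16 patches.
240 |     assert get_2d_sincos_pos_embed(768, 14).shape == (196, 768)
241 |     assert get_2d_sincos_pos_embed(768, 14, cls_token=True).shape == (197, 768)
242 |     print('sin-cos positional embedding shape check passed')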
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torchvision import transforms
3 | from transforms import *
4 | import video_transforms
5 | from masking_generator import TubeMaskingGenerator, RandomMaskingGenerator
6 | from kinetics import VideoClsDataset, VideoDistillation
7 | from ssv2 import SSVideoClsDataset
8 |
9 |
10 | class DataAugmentationForVideoDistillation(object):
11 | def __init__(self, args, num_frames=None):
12 | self.input_mean = [0.485, 0.456, 0.406] # IMAGENET_DEFAULT_MEAN
13 | self.input_std = [0.229, 0.224, 0.225] # IMAGENET_DEFAULT_STD
14 | normalize = GroupNormalize(self.input_mean, self.input_std)
15 | self.train_augmentation = GroupMultiScaleTwoResizedCrop(
16 | args.input_size, args.teacher_input_size, [1, .875, .75, .66]
17 | )
18 | self.transform = transforms.Compose([
19 | Stack(roll=False),
20 | ToTorchFormatTensor(div=True),
21 | normalize,
22 | ])
23 | window_size = args.window_size if num_frames is None else (num_frames // args.tubelet_size, args.window_size[1], args.window_size[2])
24 | if args.mask_type == 'tube':
25 | self.masked_position_generator = TubeMaskingGenerator(
26 | window_size, args.mask_ratio
27 | )
28 | elif args.mask_type == 'random':
29 | self.masked_position_generator = RandomMaskingGenerator(
30 | window_size, args.mask_ratio
31 | )
32 |
33 | def __call__(self, images):
34 | process_data_0, process_data_1, labels = self.train_augmentation(images)
35 | process_data_0, _ = self.transform((process_data_0, labels))
36 | process_data_1, _ = self.transform((process_data_1, labels))
37 | return process_data_0, process_data_1, self.masked_position_generator()
38 |
39 | def __repr__(self):
40 |         repr_str = "(DataAugmentationForVideoDistillation,\n"
41 |         repr_str += " transform = %s,\n" % str(self.transform)
42 |         repr_str += " Masked position generator = %s,\n" % str(self.masked_position_generator)
43 |         repr_str += ")"
44 |         return repr_str
45 |
46 |
47 | def build_distillation_dataset(args, num_frames=None):
48 | if num_frames is None:
49 | num_frames = args.num_frames
50 | transform = DataAugmentationForVideoDistillation(args, num_frames=num_frames)
51 | dataset = VideoDistillation(
52 | root=args.data_root,
53 | setting=args.data_path,
54 | video_ext='mp4',
55 | is_color=True,
56 | modality='rgb',
57 | new_length=num_frames,
58 | new_step=args.sampling_rate,
59 | transform=transform,
60 | temporal_jitter=False,
61 | video_loader=True,
62 | use_decord=True,
63 | lazy_init=False,
64 | num_sample=args.num_sample,
65 | num_segments=args.num_sample,
66 | # ds_size=5893
67 | )
68 | print("Data Aug = %s" % str(transform))
69 | return dataset
70 |
71 |
72 | def build_dataset(is_train, test_mode, args):
73 | if args.data_set == 'Kinetics-400':
74 | mode = None
75 | anno_path = None
76 | if is_train is True:
77 | mode = 'train'
78 | anno_path = os.path.join(args.data_path, 'train.csv')
79 | elif test_mode is True:
80 | mode = 'test'
81 | anno_path = os.path.join(args.data_path, 'val.csv')
82 | else:
83 | mode = 'validation'
84 | anno_path = os.path.join(args.data_path, 'val.csv')
85 |
86 | dataset = VideoClsDataset(
87 | anno_path=anno_path,
88 | data_path=args.data_root,
89 | mode=mode,
90 | clip_len=args.num_frames,
91 | frame_sample_rate=args.sampling_rate,
92 | num_segment=1,
93 | test_num_segment=args.test_num_segment,
94 | test_num_crop=args.test_num_crop,
95 | num_crop=1 if not test_mode else 3,
96 | keep_aspect_ratio=True,
97 | crop_size=args.input_size,
98 | short_side_size=args.short_side_size,
99 | new_height=256,
100 | new_width=320,
101 | args=args,
102 | )
103 | nb_classes = 400
104 |
105 | elif args.data_set == 'SSV2':
106 | mode = None
107 | anno_path = None
108 | if is_train is True:
109 | mode = 'train'
110 | anno_path = os.path.join(args.data_path, 'train.csv')
111 | elif test_mode is True:
112 | mode = 'test'
113 | anno_path = os.path.join(args.data_path, 'val.csv')
114 | else:
115 | mode = 'validation'
116 | anno_path = os.path.join(args.data_path, 'val.csv')
117 |
118 | dataset = SSVideoClsDataset(
119 | anno_path=anno_path,
120 | data_path=args.data_root,
121 | mode=mode,
122 | clip_len=1,
123 | num_segment=args.num_frames,
124 | test_num_segment=args.test_num_segment,
125 | test_num_crop=args.test_num_crop,
126 | num_crop=1 if not test_mode else 3,
127 | keep_aspect_ratio=True,
128 | crop_size=args.input_size,
129 | short_side_size=args.short_side_size,
130 | new_height=256,
131 | new_width=320,
132 | args=args,
133 | )
134 | nb_classes = 174
135 |
136 | elif args.data_set == 'UCF101':
137 | mode = None
138 | anno_path = None
139 | if is_train is True:
140 | mode = 'train'
141 | anno_path = os.path.join(args.data_path, 'train.csv')
142 | elif test_mode is True:
143 | mode = 'test'
144 | anno_path = os.path.join(args.data_path, 'val.csv')
145 | else:
146 | mode = 'validation'
147 | anno_path = os.path.join(args.data_path, 'test.csv')
148 |
149 | dataset = VideoClsDataset(
150 | anno_path=anno_path,
151 | data_path=args.data_root,
152 | mode=mode,
153 | clip_len=args.num_frames,
154 | frame_sample_rate=args.sampling_rate,
155 | num_segment=1,
156 | test_num_segment=args.test_num_segment,
157 | test_num_crop=args.test_num_crop,
158 | num_crop=1 if not test_mode else 3,
159 | keep_aspect_ratio=True,
160 | crop_size=args.input_size,
161 | short_side_size=args.short_side_size,
162 | new_height=256,
163 | new_width=320,
164 | args=args)
165 | nb_classes = 101
166 |
167 | elif args.data_set == 'HMDB51':
168 | mode = None
169 | anno_path = None
170 | if is_train is True:
171 | mode = 'train'
172 | anno_path = os.path.join(args.data_path, 'train.csv')
173 | elif test_mode is True:
174 | mode = 'test'
175 | anno_path = os.path.join(args.data_path, 'val.csv')
176 | else:
177 | mode = 'validation'
178 | anno_path = os.path.join(args.data_path, 'test.csv')
179 |
180 | dataset = VideoClsDataset(
181 | anno_path=anno_path,
182 | data_path=args.data_root,
183 | mode=mode,
184 | clip_len=args.num_frames,
185 | frame_sample_rate=args.sampling_rate,
186 | num_segment=1,
187 | test_num_segment=args.test_num_segment,
188 | test_num_crop=args.test_num_crop,
189 | num_crop=1 if not test_mode else 3,
190 | keep_aspect_ratio=True,
191 | crop_size=args.input_size,
192 | short_side_size=args.short_side_size,
193 | new_height=256,
194 | new_width=320,
195 | args=args)
196 | nb_classes = 51
197 | elif args.data_set == 'egopet':
198 | mode = None
199 | anno_path = None
200 | if is_train is True:
201 | mode = 'train'
202 | anno_path = '/private/home/amirbar/datasets/egopet/egopet_df_v2_mvd_train.csv'
203 | elif test_mode is True:
204 | mode = 'test'
205 | anno_path = '/private/home/amirbar/datasets/egopet/egopet_df_v2_mvd_val.csv'
206 | else:
207 | mode = 'validation'
208 | anno_path = '/private/home/amirbar/datasets/egopet/egopet_df_v2_mvd_val.csv'
209 |
210 | dataset = VideoClsDataset(
211 | anno_path=anno_path,
212 | data_path=args.data_root,
213 | mode=mode,
214 | clip_len=args.num_frames,
215 | frame_sample_rate=args.sampling_rate,
216 | num_segment=1,
217 | test_num_segment=args.test_num_segment,
218 | test_num_crop=args.test_num_crop,
219 | num_crop=1 if not test_mode else 3,
220 | keep_aspect_ratio=True,
221 | crop_size=args.input_size,
222 | short_side_size=args.short_side_size,
223 | new_height=256,
224 | new_width=320,
225 | args=args,
226 | ds_size=5893
227 | )
228 | nb_classes = 400
229 |
230 | elif args.data_set == 'ego4d':
231 | mode = None
232 | anno_path = None
233 | if is_train is True:
234 | mode = 'train'
235 | anno_path = '/private/home/amirbar/datasets/egopet/ego4d_df_v2_mvd_train.csv'
236 | elif test_mode is True:
237 | mode = 'test'
238 | anno_path = None
239 | else:
240 | mode = 'validation'
241 | anno_path = None
242 |
243 | dataset = VideoClsDataset(
244 | anno_path=anno_path,
245 | data_path=args.data_root,
246 | mode=mode,
247 | clip_len=args.num_frames,
248 | frame_sample_rate=args.sampling_rate,
249 | num_segment=1,
250 | test_num_segment=args.test_num_segment,
251 | test_num_crop=args.test_num_crop,
252 | num_crop=1 if not test_mode else 3,
253 | keep_aspect_ratio=True,
254 | crop_size=args.input_size,
255 | short_side_size=args.short_side_size,
256 | new_height=256,
257 | new_width=320,
258 | args=args,
259 | ds_size=5893
260 | )
261 | nb_classes = 400
262 |
263 | elif args.data_set == 'egomix':
264 |         raise NotImplementedError("'egomix' dataset is not implemented yet")
265 |
266 | else:
267 | raise NotImplementedError()
268 | assert nb_classes == args.nb_classes
269 | print("Number of the class = %d" % args.nb_classes)
270 |
271 | return dataset, nb_classes
272 |
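273 | # ---------------------------------------------------------------------------
274 | # Usage sketch (illustrative; paths and values are placeholders, and
275 | # VideoClsDataset may read additional fields from `args`):
276 | #
277 | #   from argparse import Namespace
278 | #   args = Namespace(data_set='Kinetics-400', data_path='/path/to/annotations',
279 | #                    data_root='/path/to/videos', num_frames=16,
280 | #                    sampling_rate=4, test_num_segment=5, test_num_crop=3,
281 | #                    input_size=224, short_side_size=224, nb_classes=400)
282 | #   train_set, nb_classes = build_dataset(is_train=True, test_mode=False,
283 | #                                         args=args)
284 | # ---------------------------------------------------------------------------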
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | import warnings
4 | import random
5 | import numpy as np
6 | import torchvision
7 | from PIL import Image, ImageOps
8 | import numbers
9 |
10 |
11 | class GroupRandomCrop(object):
12 | def __init__(self, size):
13 | if isinstance(size, numbers.Number):
14 | self.size = (int(size), int(size))
15 | else:
16 | self.size = size
17 |
18 | def __call__(self, img_tuple):
19 | img_group, label = img_tuple
20 |
21 | w, h = img_group[0].size
22 | th, tw = self.size
23 |
24 | out_images = list()
25 |
26 | x1 = random.randint(0, w - tw)
27 | y1 = random.randint(0, h - th)
28 |
29 | for img in img_group:
30 | assert(img.size[0] == w and img.size[1] == h)
31 | if w == tw and h == th:
32 | out_images.append(img)
33 | else:
34 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
35 |
36 | return (out_images, label)
37 |
38 |
39 | class GroupCenterCrop(object):
40 | def __init__(self, size):
41 | self.worker = torchvision.transforms.CenterCrop(size)
42 |
43 | def __call__(self, img_tuple):
44 | img_group, label = img_tuple
45 | return ([self.worker(img) for img in img_group], label)
46 |
47 |
48 | class GroupNormalize(object):
49 | def __init__(self, mean, std):
50 | self.mean = mean
51 | self.std = std
52 |
53 | def __call__(self, tensor_tuple):
54 | tensor, label = tensor_tuple
55 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
56 | rep_std = self.std * (tensor.size()[0]//len(self.std))
57 |
58 | # TODO: make efficient
59 | for t, m, s in zip(tensor, rep_mean, rep_std):
60 | t.sub_(m).div_(s)
61 |
62 | return (tensor,label)
63 |
64 |
65 | class GroupGrayScale(object):
66 | def __init__(self, size):
67 | self.worker = torchvision.transforms.Grayscale(size)
68 |
69 | def __call__(self, img_tuple):
70 | img_group, label = img_tuple
71 | return ([self.worker(img) for img in img_group], label)
72 |
73 |
74 | class GroupScale(object):
75 | """ Rescales the input PIL.Image to the given 'size'.
76 | 'size' will be the size of the smaller edge.
77 | For example, if height > width, then image will be
78 | rescaled to (size * height / width, size)
79 | size: size of the smaller edge
80 | interpolation: Default: PIL.Image.BILINEAR
81 | """
82 |
83 | def __init__(self, size, interpolation=Image.BILINEAR):
84 | self.worker = torchvision.transforms.Resize(size, interpolation)
85 |
86 | def __call__(self, img_tuple):
87 | img_group, label = img_tuple
88 | return ([self.worker(img) for img in img_group], label)
89 |
90 |
91 | class GroupMultiScaleCrop(object):
92 |
93 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
94 | self.scales = scales if scales is not None else [1, .875, .75, .66]
95 | self.max_distort = max_distort
96 | self.fix_crop = fix_crop
97 | self.more_fix_crop = more_fix_crop
98 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
99 | self.interpolation = Image.BILINEAR
100 |
101 | def __call__(self, img_tuple):
102 | img_group, label = img_tuple
103 |
104 | im_size = img_group[0].size
105 |
106 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
107 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
108 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in crop_img_group]
109 | return (ret_img_group, label)
110 |
111 | def _sample_crop_size(self, im_size):
112 | image_w, image_h = im_size[0], im_size[1]
113 |
114 | # find a crop size
115 | base_size = min(image_w, image_h)
116 | crop_sizes = [int(base_size * x) for x in self.scales]
117 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
118 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
119 |
120 | pairs = []
121 | for i, h in enumerate(crop_h):
122 | for j, w in enumerate(crop_w):
123 | if abs(i - j) <= self.max_distort:
124 | pairs.append((w, h))
125 |
126 | crop_pair = random.choice(pairs)
127 | if not self.fix_crop:
128 | w_offset = random.randint(0, image_w - crop_pair[0])
129 | h_offset = random.randint(0, image_h - crop_pair[1])
130 | else:
131 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
132 |
133 | return crop_pair[0], crop_pair[1], w_offset, h_offset
134 |
135 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
136 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
137 | return random.choice(offsets)
138 |
139 | @staticmethod
140 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
141 | w_step = (image_w - crop_w) // 4
142 | h_step = (image_h - crop_h) // 4
143 |
144 | ret = list()
145 | ret.append((0, 0)) # upper left
146 | ret.append((4 * w_step, 0)) # upper right
147 | ret.append((0, 4 * h_step)) # lower left
148 | ret.append((4 * w_step, 4 * h_step)) # lower right
149 | ret.append((2 * w_step, 2 * h_step)) # center
150 |
151 | if more_fix_crop:
152 | ret.append((0, 2 * h_step)) # center left
153 | ret.append((4 * w_step, 2 * h_step)) # center right
154 | ret.append((2 * w_step, 4 * h_step)) # lower center
155 | ret.append((2 * w_step, 0 * h_step)) # upper center
156 |
157 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
158 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
159 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
160 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
161 | return ret
162 |
163 |
164 | class GroupMultiScaleTwoResizedCrop(object):
165 |
166 | def __init__(self, input_size, input_size_1, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
167 | self.scales = scales if scales is not None else [1, .875, .75, .66]
168 | self.max_distort = max_distort
169 | self.fix_crop = fix_crop
170 | self.more_fix_crop = more_fix_crop
171 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
172 | self.input_size_1 = input_size_1 if not isinstance(input_size_1, int) else [input_size_1, input_size_1]
173 | self.interpolation = Image.BILINEAR
174 |
175 | def __call__(self, img_tuple):
176 | img_group, label = img_tuple
177 |
178 | im_size = img_group[0].size
179 |
180 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
181 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group]
182 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) for img in
183 | crop_img_group]
184 | ret_img_group_1 = [img.resize((self.input_size_1[0], self.input_size_1[1]), self.interpolation) for img in
185 | crop_img_group]
186 | return (ret_img_group, ret_img_group_1, label)
187 |
188 | def _sample_crop_size(self, im_size):
189 | image_w, image_h = im_size[0], im_size[1]
190 |
191 | # find a crop size
192 | base_size = min(image_w, image_h)
193 | crop_sizes = [int(base_size * x) for x in self.scales]
194 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
195 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]
196 |
197 | pairs = []
198 | for i, h in enumerate(crop_h):
199 | for j, w in enumerate(crop_w):
200 | if abs(i - j) <= self.max_distort:
201 | pairs.append((w, h))
202 |
203 | crop_pair = random.choice(pairs)
204 | if not self.fix_crop:
205 | w_offset = random.randint(0, image_w - crop_pair[0])
206 | h_offset = random.randint(0, image_h - crop_pair[1])
207 | else:
208 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
209 |
210 | return crop_pair[0], crop_pair[1], w_offset, h_offset
211 |
212 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
213 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
214 | return random.choice(offsets)
215 |
216 | @staticmethod
217 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
218 | w_step = (image_w - crop_w) // 4
219 | h_step = (image_h - crop_h) // 4
220 |
221 | ret = list()
222 | ret.append((0, 0)) # upper left
223 | ret.append((4 * w_step, 0)) # upper right
224 | ret.append((0, 4 * h_step)) # lower left
225 | ret.append((4 * w_step, 4 * h_step)) # lower right
226 | ret.append((2 * w_step, 2 * h_step)) # center
227 |
228 | if more_fix_crop:
229 | ret.append((0, 2 * h_step)) # center left
230 | ret.append((4 * w_step, 2 * h_step)) # center right
231 | ret.append((2 * w_step, 4 * h_step)) # lower center
232 | ret.append((2 * w_step, 0 * h_step)) # upper center
233 |
234 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter
235 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter
236 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter
237 |             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
238 | return ret
239 |
240 |
241 | class Stack(object):
242 |
243 | def __init__(self, roll=False):
244 | self.roll = roll
245 |
246 | def __call__(self, img_tuple):
247 | img_group, label = img_tuple
248 |
249 | if img_group[0].mode == 'L':
250 | return (np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2), label)
251 | elif img_group[0].mode == 'RGB':
252 | if self.roll:
253 | return (np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2), label)
254 | else:
255 | return (np.concatenate(img_group, axis=2), label)
256 |
257 |
258 | class ToTorchFormatTensor(object):
259 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
260 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
261 | def __init__(self, div=True):
262 | self.div = div
263 |
264 | def __call__(self, pic_tuple):
265 | pic, label = pic_tuple
266 |
267 | if isinstance(pic, np.ndarray):
268 | # handle numpy array
269 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
270 | else:
271 | # handle PIL Image
272 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
273 | img = img.view(pic.size[1], pic.size[0], len(pic.mode))
274 | # put it from HWC to CHW format
275 | # yikes, this transpose takes 80% of the loading time/CPU
276 | img = img.transpose(0, 1).transpose(0, 2).contiguous()
277 | return (img.float().div(255.) if self.div else img.float(), label)
278 |
279 |
280 | class IdentityTransform(object):
281 |
282 | def __call__(self, data):
283 | return data
284 |
--------------------------------------------------------------------------------
/object_interaction/object_interaction_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | import os
5 | import random
6 |
7 | import torch
8 | import torch.utils.data
9 | from iopath.common.file_io import g_pathmgr as pathmgr
10 |
11 | import sys
12 | parent = os.path.dirname(os.path.abspath(__file__))
13 | parent_parent = os.path.join(parent, '../')
14 | sys.path.append(os.path.dirname(parent_parent))
15 |
16 | from decoder.utils import decode_ffmpeg
17 | from object_interaction.object_interaction_utils import *
18 | from torchvision import transforms
19 |
20 | from pathlib import Path
21 |
22 |
23 | class Object_Interaction(torch.utils.data.Dataset):
24 | """
25 | Object Interaction video loader. Construct the Object Interaction video loader,
26 | then sample clips from the videos. For training and validation, a single clip
27 | is randomly sampled from every video with normalization.
28 | """
29 |
30 | def __init__(
31 | self,
32 | mode,
33 | path_to_data_dir,
34 | path_to_csv,
35 | # transformation
36 | transform,
37 | # decoding settings
38 | num_frames=8,
39 | target_fps=30,
40 | num_sec=2,
41 | fps=4,
42 | # frame aug settings
43 | crop_size=224,
44 | # other parameters
45 | enable_multi_thread_decode=False,
46 | use_offset_sampling=True,
47 | inverse_uniform_sampling=False,
48 | num_retries=10,
49 | # object or no object
50 | object_interaction_ratio=0.5,
51 | ):
52 | """
53 | Construct the Object Interaction video loader with a given csv file. The format of
54 | the csv file is:
55 | ```
56 | animal_1, ds_type_1, video_id_1, segment_id_1, start_time_1, end_time_1, total_time_1, interacting_object_1, video_path_1
57 | animal_2, ds_type_2, video_id_2, segment_id_2, start_time_2, end_time_2, total_time_2, interacting_object_2, video_path_2
58 | ...
59 | animal_N, ds_type_N, video_id_N, segment_id_N, start_time_N, end_time_N, total_time_N, interacting_object_N, video_path_N
60 | ```
61 | Args:
62 |             mode (string): Options include `train` or `test`.
63 |                 For both modes, the data loader takes data from the
64 |                 corresponding split and samples one clip per video.
65 | path_to_data_dir (string): Path to EgoPet Dataset
66 | path_to_csv (string): Path to Object Interaction data
67 | num_frames (int): number of frames used for model
68 | object_interaction_ratio (float): ratio of clips with interactions during training
69 | """
70 |         # Only support train and test mode.
71 | assert mode in [
72 | "train",
73 | "test",
74 | ], "mode has to be 'train' or 'test'"
75 | self.mode = mode
76 |
77 | self._num_retries = num_retries
78 | self._path_to_data_dir = path_to_data_dir
79 | self._path_to_csv = path_to_csv
80 | self.object_interaction_ratio = object_interaction_ratio
81 |
82 | self._crop_size = crop_size
83 |
84 | self._num_frames = num_frames
85 | self._num_sec = num_sec
86 | self._target_fps = target_fps
87 | self._fps=fps
88 |
89 | self.transform = transform
90 |
91 | self._enable_multi_thread_decode = enable_multi_thread_decode
92 | self._inverse_uniform_sampling = inverse_uniform_sampling
93 | self._use_offset_sampling = use_offset_sampling
94 |
95 | print(self)
96 | print(locals())
97 | self._construct_loader()
98 |
99 | def _construct_loader(self):
100 | """
101 | Construct the video loader.
102 | """
103 | self._object_path_to_videos = []
104 | self._no_object_path_to_videos = []
105 | self._object_labels_times = []
106 | self._no_object_labels_times = []
107 |
108 | with pathmgr.open(self._path_to_csv, "r") as f:
109 | for curr_clip in f.read().splitlines():
110 | curr_values = curr_clip.split(',')
111 | if curr_values[0] != 'animal':
112 | animal, ds_type, video_id, segment_id, start_time, end_time, total_time, interacting_object, video_path = curr_values
113 | video_path = os.path.join(self._path_to_data_dir, video_path)
114 | start_time, end_time, interacting_object = start_time.split(';'), end_time.split(';'), interacting_object.split(';')
115 | start_time, end_time, total_time = [process_time(time) for time in start_time], [process_end_time(time, total_time) for time in end_time], process_time(total_time)
116 |
117 | assert len(start_time) == len(end_time) == len(interacting_object), "Error with csv on {curr_row}".format(curr_row=curr_values)
118 |
119 |                     # Get the None segments in between labeled interaction segments
120 | start_time, end_time, interacting_object = get_none_segments(start_time, end_time, total_time, interacting_object)
121 |
122 | object_labels = [get_label(curr_interacting_object) for curr_interacting_object in interacting_object]
123 | assert len(start_time) == len(end_time) == len(interacting_object), "Error with processing"
124 |
125 | for i in range(len(start_time)):
126 | interaction = 0. if object_labels[i] == -1 else 1.
127 | if object_labels[i] == -1:
128 | object_labels[i] = 0
129 |
130 | labels_times = (interaction, object_labels[i], start_time[i], end_time[i], total_time)
131 | if interaction:
132 | self._object_path_to_videos.append(video_path)
133 | self._object_labels_times.append(labels_times)
134 | else:
135 | self._no_object_path_to_videos.append(video_path)
136 | self._no_object_labels_times.append(labels_times)
137 |
138 | self._path_to_videos = self._object_path_to_videos + self._no_object_path_to_videos
139 | self._labels_times = self._object_labels_times + self._no_object_labels_times
140 |
141 | assert (
142 | len(self._path_to_videos) > 0
143 | ), "Failed to load Object Interaction from {}".format(
144 | self._path_to_csv
145 | )
146 | print(
147 | "Constructing Object Interaction dataloader (size: {}) from {}".format(
148 | len(self._path_to_videos), self._path_to_csv
149 | )
150 | )
151 |
152 | def __getitem__(self, index):
153 | """
154 | With probability self.object_interaction_ratio randomly choose a clip
155 | with an object interaction, otherwise randomly choose a clip without
156 | an object interaction. If the video cannot be fetched and decoded
157 | successfully, find a random video that can be decoded as a replacement.
158 | Args:
159 | index (int): the video index provided by the pytorch sampler. (not used)
160 | Returns:
161 |             frames (tensor): the frames sampled from the video. The dimension
162 |                 is `num frames` x `channel` x `height` x `width`.
163 |             interaction (int): whether there is an interaction in the video.
164 |                 0 for no interaction, 1 for an interaction.
165 | object_label (int): the label of the current video.
166 | """
167 |         # Try to decode and sample a clip from a video. If the video cannot be
168 |         # decoded, repeatedly find a random replacement video that can be decoded.
169 | for i_try in range(self._num_retries):
170 | if self.mode == 'train':
171 | # If training sample random video with object interaction
172 | # with prob self.object_interaction_ratio
173 | if random.random() < self.object_interaction_ratio:
174 | # Get sample with object interaction
175 | index = random.randint(0, len(self._object_path_to_videos) - 1)
176 | # Get a clip with an object interaction
177 | video_path = self._object_path_to_videos[index]
178 | label_time = self._object_labels_times[index]
179 | interaction, object_label, start_time, end_time, total_time = label_time
180 | else:
181 | # Get sample without object interaction
182 | index = random.randint(0, len(self._no_object_path_to_videos) - 1)
183 | # Get a clip without object interaction
184 | video_path = self._no_object_path_to_videos[index]
185 | label_time = self._no_object_labels_times[index]
186 | interaction, object_label, start_time, end_time, total_time = label_time
187 | elif self.mode == 'test':
188 | # If test sample provided index
189 | video_path = self._path_to_videos[index]
190 | label_time = self._labels_times[index]
191 | interaction, object_label, start_time, end_time, total_time = label_time
192 |
193 | # Decode Video
194 | try:
195 | if self.mode == 'train':
196 |                     # For training, randomly select the start of the clip
197 | start_seek = min(max(int((start_time + end_time - self._num_sec) / 2), start_time), max(end_time - self._num_sec, 0))
198 | elif self.mode == 'test':
199 |                     # Get as close to the center of the current clip as possible
200 | start_seek = min(max(int((start_time + end_time - self._num_sec) / 2), start_time), max(end_time - self._num_sec, 0))
201 |
202 | num_sec = min(self._num_sec, end_time-start_time)
203 | frames = decode_ffmpeg(video_path, start_seek=start_seek, num_sec=num_sec, num_frames=self._num_frames, fps=self._fps)
204 | if frames.shape[0] == 0:
205 | raise ValueError('Decoder Error, 0 frames decoded at video path {video_path}, start_seek: {start_seek}, num_sec: {num_sec}, self._num_frames: {num_frames}, fps: {fps}, start_time: {start_time}, end_time: {end_time}, total_time: {total_time}'.format(video_path=video_path, start_seek=start_seek, num_sec=num_sec, num_frames=self._num_frames, fps=self._fps, start_time=start_time, end_time=end_time, total_time=total_time))
206 | except Exception as e:
207 | print(
208 | "Failed to decode video idx {} from {} with error {}".format(
209 | index, video_path, e
210 | )
211 | )
212 |                 # Random selection logic lives in __getitem__, so a random replacement video will be decoded
213 | return self.__getitem__(0)
214 |
215 | start_idx, end_idx = get_start_end_idx(
216 | frames.shape[0], self._num_frames, 0, 1
217 | )
218 | frames = temporal_sampling(
219 | frames, start_idx, end_idx, self._num_frames
220 | )
221 |
222 | frames = frames.permute(0, 3, 1, 2) / 255.
223 | frames = torch.stack([self.transform(f) for f in frames])
224 | return frames, torch.tensor([interaction]), object_label
225 | else:
226 | raise RuntimeError(
227 | "Failed to fetch video after {} retries.".format(self._num_retries)
228 | )
229 |
230 | def __len__(self):
231 | """
232 | Returns:
233 | (int): the number of videos in the dataset.
234 | """
235 | if self.mode == 'train':
236 | return 20000
237 | elif self.mode == 'test':
238 | return self.num_videos
239 |
240 | @property
241 | def num_videos(self):
242 | """
243 | Returns:
244 | (int): the number of videos in the dataset.
245 | """
246 | return len(self._path_to_videos)
247 |
--------------------------------------------------------------------------------
/cms_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torch.utils.data import Dataset, DataLoader
8 | from torchvision.utils import save_image, make_grid
9 | from torchvision import transforms, utils
10 | from torchvision.transforms.functional import rgb_to_grayscale
11 |
12 | import os
13 | from PIL import Image, ImageOps
14 | import cv2
15 | import random
16 | import copy
17 | import pandas as pd
18 | from scipy.signal import savgol_filter
19 | import numpy as np
20 | from tqdm import tqdm
21 | import warnings
22 | import matplotlib.pyplot as plt
23 | warnings.filterwarnings("ignore")
24 |
25 |
26 | class VisionPropDataset(Dataset):
27 |
28 | def __init__(self, root_path, config, transforms, mode='train', debug=False):
29 |
30 | self.config = config
31 | self.mode = mode
32 | self.root_path = root_path
33 | self.normalization_coeffs = None
34 | self.debug = debug
35 | self.experiments = []
36 | self.last_frame_idx = 0
37 | self.input_prop_fts = []
38 | self.images_path = []
39 | self.label_features = []
40 | self.inputs = []
41 | self.targets = []
42 | self._add_features()
43 | self.rollout_boundaries = [0] # boundaries between different runs
44 | self.rollout_ts = []
45 | self.prop_latent = []
46 | self.setup_data_pipeline(transforms)
47 | if mode == 'deploy':
48 | return
49 | self.num_samples = 0
50 | file_rootname = 'rollout_'
51 | print("------------------------------------------")
52 | print('Building %s Dataset' % ("Training" if mode == 'train' else "Validation"))
53 | print("------------------------------------------")
54 | if os.path.isdir(root_path):
55 | for root, dirs, files in os.walk(root_path, topdown=True, followlinks=True):
56 | for name in dirs:
57 | if name.startswith(file_rootname):
58 | self.experiments.append(os.path.join(root, name))
59 | else:
60 | assert False, "Provided dataset root is neither a file nor a directory!"
61 |
62 | self.num_experiments = len(self.experiments)
63 | assert self.num_experiments > 0, 'No valid data found!'
64 | print('Dataset contains %d experiments.' % self.num_experiments)
65 |
66 | # assert self.config.lookhead[0] > 0.0, "Lookhead should be larger than 0.0 for training"
67 |
68 | # numpy arrays to store the raw data and compute mean/std for normalization
69 | self.raw_prop_inputs = None
70 |
71 | print("Decoding data...")
72 | self.experiments = sorted(self.experiments)
73 | for exp in tqdm(self.experiments):
74 | try:
75 | self._decode_experiment(exp)
76 | except Exception as e:
77 | print(e)
78 |
79 | if len(self.inputs) == 0:
80 | raise IOError("Did not find any file in the dataset folder")
81 | self.inputs = torch.from_numpy(np.vstack(self.inputs).astype(np.float32))
82 | self.targets = torch.from_numpy(np.vstack(self.targets).astype(np.float32))
83 | self.rollout_ts = np.vstack(self.rollout_ts).astype(np.float32)
84 | self.prop_latent = np.vstack(self.prop_latent).astype(np.float32)
85 |
86 | # this computes the normalization
87 | if self.mode == 'train':
88 | self._preprocess_dataset()
89 |
90 | print('Found {} samples belonging to {} experiments:'.format(
91 | self.num_samples, self.num_experiments))
92 |
93 | def __len__(self):
94 | return self.num_samples
95 |
96 | def _add_features(self):
97 | self.input_prop_fts += ["rpy_0",
98 | "rpy_1"]
99 |
100 | joint_dim = 12
101 | action_dim = 12
102 | latent_dim = self.config.latent_dim
103 | self.latent_dim = latent_dim
104 |
105 | for i in range(joint_dim):
106 | self.input_prop_fts.append("joint_angles_{}".format(i))
107 | for i in range(joint_dim):
108 | self.input_prop_fts.append("joint_vel_{}".format(i))
109 | for i in range(action_dim):
110 | self.input_prop_fts.append("last_action_{}".format(i))
111 | self.input_prop_fts.extend(["command_0", "command_1"])
112 |
113 | if self.config.input_use_depth:
114 | self.input_frame_fts = ["depth_frame_counter"]
115 | else:
116 | self.input_frame_fts = ["frame_counter"]
117 | self.input_fts = self.input_prop_fts + self.input_frame_fts
118 | self.target_fts = []
119 | for i in range(self.config.latent_dim):
120 | self.target_fts.append(f"prop_latent_{i}")
121 |
122 | def process_raw_latent(self, data, ts):
123 |
124 | data_freq = 90
125 | lookheads = self.config.lookhead
126 | self.num_lookheads = len(lookheads)
127 |
128 | predictive_latent = []
129 | # Latent smoothing
130 | for i in range(data.shape[1]):
131 | data[:, i] = savgol_filter(data[:, i], 41, 3)
132 |
133 |         # Time-shift the smoothed latents by each lookahead
134 | for i in range(data.shape[1]):
135 | time_shifted_data = np.zeros_like(data[:,i])
136 | for k in range(len(lookheads)):
137 | # Advance only for future geometry, without any smoothing
138 | current_lookhead = int(lookheads[k]*data_freq)
139 | time_shifted_data = np.roll(data[:,i],
140 | int(-current_lookhead))
141 | if current_lookhead >= 0:
142 | time_shifted_data[-current_lookhead:] = data[-current_lookhead:, i]
143 | else:
144 | time_shifted_data[:-current_lookhead] = data[:-current_lookhead, i]
145 | predictive_latent.append(np.expand_dims(time_shifted_data,1))
146 |
147 | predictive_latent = np.hstack(predictive_latent)
148 |
149 | return predictive_latent
150 |
151 | def _decode_experiment(self, dir_subpath):
152 | propr_file = os.path.join(dir_subpath, "proprioception.csv")
153 |         assert os.path.isfile(propr_file), "Proprioception file not found"
154 | df_prop = pd.read_csv(propr_file, delimiter=',')
155 |
156 | current_img_paths = []
157 | if self.config.input_use_imgs:
158 | if self.config.input_use_depth:
159 | r_ext = '.tiff'
160 | else:
161 | r_ext = '.jpg'
162 | img_dir = os.path.join(dir_subpath, "img")
163 | for f in os.listdir(img_dir):
164 | ext = os.path.splitext(f)[1]
165 | if ext.lower() not in [r_ext]:
166 | continue
167 | current_img_paths.append(os.path.join(img_dir,f))
168 | if len(current_img_paths) == 0:
169 |                 raise IOError("No images found")
170 | self.images_path.extend(sorted(current_img_paths))
171 |
172 | if self.debug:
173 | print("Average sampling frequency of proprioception is %.6f" % (
174 | 1.0 / np.mean(np.diff(np.unique(df_prop["time_from_start"].values)))))
175 | inputs, targets = df_prop[self.input_fts].values, df_prop[self.target_fts].values
176 |
177 | inputs[:,-1] += self.last_frame_idx
178 |
179 |
180 | ts = df_prop["time_from_start"]
181 | if self.config.input_use_depth:
182 | fc = df_prop["depth_frame_counter"]
183 | else:
184 | fc = df_prop["frame_counter"]
185 |
186 | prop_l = targets
187 | targets = self.process_raw_latent(prop_l, ts)
188 |
189 | self.inputs.append(inputs)
190 | self.targets.append(targets)
191 |
192 | final_idx = len(self.input_prop_fts)
193 | input_prop_features_v = inputs[:, :final_idx]
194 |
195 | if self.raw_prop_inputs is None:
196 | self.raw_prop_inputs = input_prop_features_v
197 | self.raw_latent_target = targets
198 | else:
199 | self.raw_prop_inputs = np.concatenate([self.raw_prop_inputs, input_prop_features_v], axis=0)
200 | self.raw_latent_target = np.concatenate([self.raw_latent_target, targets], axis=0)
201 | self.last_frame_idx += len(current_img_paths)
202 | self.num_samples += inputs.shape[0]
203 | #self.num_samples += len(idxs)
204 | self.rollout_boundaries.append(self.num_samples)
205 | self.rollout_ts.append(np.expand_dims(fc[:inputs.shape[0]], axis=1))
206 | self.prop_latent.append(prop_l)
207 |
208 | def _preprocess_dataset(self):
209 | if self.normalization_coeffs is None:
210 | self.input_prop_mean = np.mean(self.raw_prop_inputs, axis=0).astype(np.float32)
211 | self.input_prop_std = np.std(self.raw_prop_inputs, axis=0).astype(np.float32)
212 | self.target_latent_mean = np.mean(self.raw_latent_target, axis=0).astype(np.float32)
213 | self.target_latent_std = np.std(self.raw_latent_target, axis=0).astype(np.float32)
214 | else:
215 | self.input_prop_mean = self.normalization_coeffs[0]
216 | self.input_prop_std = self.normalization_coeffs[1]
217 | self.target_latent_mean = self.normalization_coeffs[2]
218 | self.target_latent_std = self.normalization_coeffs[3]
219 |
220 | def get_normalization(self):
221 | return self.input_prop_mean, self.input_prop_std, self.target_latent_mean, self.target_latent_std
222 |
223 | def setup_data_pipeline(self, transforms):
224 | if self.config.input_use_depth:
225 | self.preprocess_pipeline = transforms.Compose([
226 | transforms.ToTensor(),
227 | transforms.CenterCrop((240,320)),
228 | ])
229 | else:
230 | # self.preprocess_pipeline = transforms.Compose([
231 | # transforms.ToTensor(),
232 | # transforms.CenterCrop((240,320)),
233 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
234 | # ])
235 |
236 | self.preprocess_pipeline = transforms
237 |
238 | def image_processing(self, fname):
239 | if self.config.input_use_depth:
240 | input_image = cv2.imread(fname, cv2.IMREAD_ANYDEPTH)
241 | input_array = np.asarray(input_image, dtype=np.float32)
242 | input_array = np.minimum(input_array, 4000)
243 | input_array = (input_array - 2000) / 2000
244 | else:
245 | input_array = Image.open(fname)
246 | return input_array
247 |
248 | def preprocess_data(self, prop_numpy, img_numpy):
249 | prop = torch.from_numpy(prop_numpy.astype(np.float32))
250 | try:
251 | img_numpy = np.stack(img_numpy, axis=-1).astype(np.float32)
252 |         except Exception:
253 | print(img_numpy[0].shape)
254 | print(img_numpy[1].shape)
255 | print(img_numpy[2].shape)
256 | if self.config.input_use_depth:
257 | img_numpy = np.minimum(img_numpy, 4000)
258 | img_numpy = (img_numpy - 2000) / 2000
259 | img = self.preprocess_pipeline(img_numpy)
260 | return prop, img
261 |
262 | def __getitem__(self, idx):
263 | prop_data = np.zeros((self.config.history_len,len(self.input_prop_fts)), dtype=np.float32)
264 | start_idx = np.maximum(0, idx - self.config.history_len)
265 | actual_history_length = idx - start_idx
266 | if actual_history_length > 0:
267 | prop_data[-actual_history_length:] = self.inputs[start_idx:idx, :len(self.input_prop_fts)]
268 | else:
269 | prop_data[-1] = self.inputs[idx, :len(self.input_prop_fts)]
270 |
271 | target = self.targets[idx]
272 | frame_skip = self.config.frame_skip # take one every n
273 | if self.config.input_use_imgs > 0:
274 | frame_idx_start = int(self.inputs[idx][-1].numpy())
275 | imgs = []
276 | for i in reversed(range(self.config.input_use_imgs)):
277 | frame_idx = np.maximum(0, frame_idx_start - i*frame_skip)
278 | frame_path = self.images_path[frame_idx]
279 | img = self.image_processing(frame_path)
280 | img = self.preprocess_pipeline(img)
281 | imgs.append(img)
282 |
283 |
284 | if self.config.grayscale:
285 | imgs = rgb_to_grayscale(imgs, num_output_channels=1).squeeze(1)
286 |
287 | imgs = torch.stack(imgs, dim=1)
288 |
289 | return prop_data, imgs, target
290 | else:
291 | return prop_data, 0.0, target
292 |
--------------------------------------------------------------------------------
/locomotion_prediction/locomotion_prediction_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | import os
5 | import random
6 | import math
7 |
8 | import torch
9 | import numpy as np
10 | import torch.utils.data
11 | from iopath.common.file_io import g_pathmgr as pathmgr
12 |
13 | import sys
14 | parent = os.path.dirname(os.path.abspath(__file__))
15 | parent_parent = os.path.join(parent, '../')
16 | sys.path.append(os.path.dirname(parent_parent))
17 |
18 | from locomotion_prediction.locomotion_prediction_utils import *
19 | from decoder.utils import decode_ffmpeg
20 | from torchvision import transforms
21 | from evo.tools import file_interface
22 |
23 | from pathlib import Path
24 |
25 |
26 | class Locomotion_Prediction_dataloader(torch.utils.data.Dataset):
27 | """
28 | Locomotion Prediction video loader. Construct the Locomotion Prediction video loader,
29 | then sample clips from the videos. For training, a single clip is randomly sampled
30 | from every video with normalization. For validation, the video path and trajectories are
31 | returned.
32 | """
33 |
34 | def __init__(
35 | self,
36 | mode,
37 | path_to_data_dir,
38 | path_to_trajectories_dir,
39 | path_to_csv,
40 | # transformation
41 | transform,
42 | # decoding settings
43 | fps=30,
44 | # frame aug settings
45 | crop_size=224,
46 | # pose estimation settings
47 | animals=['cat', 'dog'],
48 | num_condition_frames=16,
49 | num_pose_prediction=16,
50 | scale_invariance='None',
51 | pps=30,
52 | # other parameters
53 | enable_multi_thread_decode=False,
54 | use_offset_sampling=True,
55 | inverse_uniform_sampling=False,
56 | num_retries=10,
57 | ):
58 | """
59 |         Construct the Locomotion Prediction video loader with a given csv file. The format of
60 | the csv file is:
61 | ```
62 | animal_1, ds_type_1, segment_id_1, stride_1, start_time_1, end_time_1, video_path_1
63 | animal_2, ds_type_2, segment_id_2, stride_2, start_time_2, end_time_2, video_path_2
64 | ...
65 | animal_N, ds_type_N, segment_id_N, stride_N, start_time_N, end_time_N, video_path_N
66 | ```
67 | Args:
68 |             mode (string): Options include `train` or `test`.
69 |                 For both modes, the data loader takes data from the
70 |                 corresponding split and samples one clip per video.
71 | path_to_data_dir (string): Path to EgoPet Dataset
72 |             path_to_csv (string): Path to Locomotion Prediction data
73 | num_frames (int): number of frames used for model
74 | num_condition_frames (int): number of frames to condition model on
75 | num_pose_prediction (int): number of future poses to predict
76 | """
77 |         # Only support train and test mode.
78 | assert mode in [
79 | "train",
80 | "test",
81 | ], "mode has to be 'train' or 'test'"
82 | self.mode = mode
83 |
84 | self._num_retries = num_retries
85 | self._path_to_data_dir = path_to_data_dir
86 | self._path_to_trajectories_dir = path_to_trajectories_dir
87 | self._path_to_csv = path_to_csv
88 |
89 | self._crop_size = crop_size
90 |
91 | self._num_frames = num_condition_frames
92 | self._num_sec = math.ceil((num_condition_frames + num_pose_prediction) / fps)
93 | self._fps=fps
94 | self._pps = pps
95 | self._pose_skip = fps // pps
96 |
97 | self.transform = transform
98 |
99 | # Pose Estimation Settings
100 | self._animals = animals
101 | self._num_condition_frames = num_condition_frames
102 | self._num_pose_prediction = num_pose_prediction
103 | self._scale_invariance = scale_invariance
104 |
105 | self._enable_multi_thread_decode = enable_multi_thread_decode
106 | self._inverse_uniform_sampling = inverse_uniform_sampling
107 | self._use_offset_sampling = use_offset_sampling
108 | self._num_retries = num_retries
109 |
110 | print(self)
111 | print(locals())
112 | self._construct_loader()
113 |
114 | def _construct_loader(self):
115 | """
116 | Construct the video loader.
117 | """
118 | self._path_to_videos = []
119 | self._path_to_trajectories = []
120 | self._start_times = []
121 | self._end_times = []
122 | self._stride = []
123 |
124 | with pathmgr.open(self._path_to_csv, "r") as f:
125 | for curr_clip in f.read().splitlines():
126 | curr_values = curr_clip.split(',')
127 |
128 | animal, ds_type, segment_id, stride, start_time, end_time, video_path = curr_values # NEED TO GENERATE THIS CSV
129 | video_path = os.path.join(self._path_to_data_dir, video_path)
130 |                 if ((self.mode == 'train' and ds_type == 'training_set') or (self.mode == 'test' and ds_type == 'validation_set')) and stride != '-1' and animal in self._animals:  # stride is read from the csv as a string
131 | start_time, end_time = int(start_time), int(end_time)
132 |
133 | if self.mode == 'test':
134 | end_time = min(end_time, 50) # Clip Evaluation Videos to 50 Seconds
135 |
136 | trajectory = "{}_calib_eth_stride_{}_interp.txt".format(segment_id, stride)
137 | trajectory_path = os.path.join(self._path_to_trajectories_dir, trajectory)
138 |
139 | if os.path.isfile(trajectory_path) and os.path.isfile(video_path) and (end_time - start_time) >= self._num_sec:
140 | video_path = os.path.join(self._path_to_data_dir, video_path)
141 | self._path_to_videos.append(video_path)
142 | self._path_to_trajectories.append(trajectory_path)
143 | self._start_times.append(start_time)
144 | self._end_times.append(end_time)
145 | self._stride.append(stride)
146 |
147 | assert (
148 | len(self._path_to_videos) > 0
149 | ), "Failed to load Pose Trajectories from {}".format(
150 | self._path_to_csv
151 | )
152 | print(
153 |             "Constructing Locomotion Prediction dataloader (size: {}) from {}".format(
154 | len(self._path_to_videos), self._path_to_csv
155 | )
156 | )
157 |
158 |
159 | def _get_sub_clip(self, video_path, trajectory_path, sub_start_idx, sub_end_idx):
160 | """
161 |         Get a sub-clip of the video from sub_start_idx to sub_end_idx.
162 | Args:
163 | video_path (string): path to video
164 | trajectory_path (string): path to associated trajectory
165 | sub_start_idx (int): start idx for this sub clip
166 | sub_end_idx (int): end idx for this sub clip
167 | Returns:
168 |             frames (tensor): the frames sampled from the video. The dimension
169 | is `num_condition_frames` x `channel` x `height` x `width`.
170 | dP_tensor (tensor): the relative poses for the entire clip.
171 | """
172 | # For validation should get the entire segment
173 | start_seek = sub_start_idx // self._fps
174 | num_sec = (sub_end_idx - sub_start_idx) / self._fps + 1.0
175 | dP_tensor = relative_poses(trajectory_path, sub_start_idx, sub_end_idx, scale_invariance=None, pose_skip=1)
176 |
177 | real_num_frames = (sub_end_idx - sub_start_idx)
178 | frame_relative_start_idx = sub_start_idx - start_seek * self._fps
179 | num_frames = int(num_sec * self._fps + 512) # Set the num_frames to be more frames than decoded in order to get all the decoded frames
180 | frames = decode_ffmpeg(video_path, start_seek=start_seek, num_sec=num_sec, num_frames=num_frames, fps=self._fps)
181 | frames = frames[frame_relative_start_idx:frame_relative_start_idx + real_num_frames]
182 |
183 | frames = frames.permute(0, 3, 1, 2) / 255.
184 | frames = torch.stack([self.transform(f) for f in frames])
185 |
186 | return frames, dP_tensor
187 |
188 |
189 | def __getitem__(self, index):
190 | """
191 |         If self.mode is train, randomly choose a self.num_condition_frames +
192 | self.num_pose_prediction frame clip. If self.mode is test, choose the
193 | entire video segment. If the video cannot be fetched and decoded
194 | successfully, find a random video that can be decoded as a replacement.
195 | Args:
196 | index (int): the video index provided by the pytorch sampler. (not used)
197 | Returns:
198 |             frames (tensor): the frames sampled from the video (train mode). The dimension
199 |                 is `num_condition_frames` x `channel` x `height` x `width`.
200 |             dP_tensor (tensor): the relative poses over the prediction horizon (train mode).
201 |             video_path (string), trajectory_path (string): paths to the video and its
202 |                 trajectory file (test mode).
203 |             start_idx (int): the start index of this clip
204 |             end_idx (int): the end index of this clip
205 | """
206 | for i_try in range(self._num_retries):
207 | if self.mode == 'train': # If 'train', sample a random video but if 'test' use index
208 | index = random.randint(0, len(self._path_to_videos) - 1)
209 |
210 | video_path = self._path_to_videos[index]
211 | trajectory_path = self._path_to_trajectories[index]
212 | start_time = self._start_times[index]
213 | end_time = self._end_times[index]
214 | stride = self._stride[index]
215 |
216 | # Decode Video
217 | try:
218 | if self.mode == 'train':
219 |                     # For training, randomly select the start of the clip
220 | start_seek = np.random.randint(start_time, max(end_time - self._num_sec, 1))
221 | start_idx = (start_seek * self._fps)
222 | end_idx = start_idx + self._num_condition_frames + self._num_pose_prediction
223 |
224 | if self._scale_invariance == 'dir':
225 | dP_tensor, _ = relative_poses(trajectory_path, start_idx + self._num_condition_frames, end_idx, self._scale_invariance, self._pose_skip)
226 |
227 | frames = decode_ffmpeg(video_path, start_seek=start_seek, num_sec=self._num_sec, num_frames=self._num_frames, fps=self._fps)
228 | elif self.mode == 'test':
229 | # For validation should get the entire segment
230 | start_idx = start_time * self._fps
231 | num_sub_clips = ((end_time - start_time) * self._fps) // (self._num_condition_frames + self._num_pose_prediction)
232 | end_idx = start_idx + num_sub_clips * (self._num_condition_frames + self._num_pose_prediction)
233 | return video_path, trajectory_path, start_idx, end_idx
234 |
235 | if frames.shape[0] == 0:
236 |                     raise ValueError('Decoder Error, 0 frames decoded at video path {video_path}, start_seek: {start_seek}, num_sec: {num_sec}, self._num_frames: {num_frames}, fps: {fps}, start_time: {start_time}, end_time: {end_time}'.format(video_path=video_path, start_seek=start_seek, num_sec=self._num_sec, num_frames=self._num_frames, fps=self._fps, start_time=start_time, end_time=end_time))
237 | except Exception as e:
238 | print(
239 | "Failed to decode video idx {} from {} with error {}".format(
240 | index, video_path, e
241 | )
242 | )
243 |                 # Random selection logic lives in __getitem__, so a random replacement video will be decoded
244 | return self.__getitem__(0)
245 |
246 | # Keep just the first num_condition_frames frames
247 | if self.mode == 'train':
248 | frames = frames[:self._num_condition_frames]
249 | elif self.mode == 'test':
250 | frames = frames[start_idx:end_idx]
251 | frames = frames.permute(0, 3, 1, 2) / 255.
252 | frames = torch.stack([self.transform(f) for f in frames])
253 | return frames, dP_tensor.to(torch.float32), start_idx, end_idx
254 | else:
255 | raise RuntimeError(
256 | "Failed to fetch video after {} retries.".format(self._num_retries)
257 | )
258 |
259 |
260 | def __len__(self):
261 | """
262 | Returns:
263 | (int): the number of videos in the dataset.
264 | """
265 | if self.mode == 'train':
266 | return 20000
267 | elif self.mode == 'test':
268 | return self.num_videos
269 |
270 |
271 | @property
272 | def num_videos(self):
273 | """
274 | Returns:
275 | (int): the number of videos in the dataset.
276 | """
277 | return len(self._path_to_videos)
--------------------------------------------------------------------------------
/mixup.py:
--------------------------------------------------------------------------------
1 | """ Mixup and Cutmix
2 |
3 | Papers:
4 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5 |
6 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7 |
8 | Code Reference:
9 | CutMix: https://github.com/clovaai/CutMix-PyTorch
10 |
11 | Hacked together by / Copyright 2019, Ross Wightman
12 | """
13 | import numpy as np
14 | import torch
15 |
16 |
17 | def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
18 | x = x.long().view(-1, 1)
19 | return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
20 |
21 |
22 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
23 | off_value = smoothing / num_classes
24 | on_value = 1. - smoothing + off_value
25 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
27 | return y1 * lam + y2 * (1. - lam)
28 |
29 |
30 | def rand_bbox(img_shape, lam, margin=0., count=None):
31 | """ Standard CutMix bounding-box
32 | Generates a random square bbox based on lambda value. This impl includes
33 | support for enforcing a border margin as percent of bbox dimensions.
34 |
35 | Args:
36 | img_shape (tuple): Image shape as tuple
37 | lam (float): Cutmix lambda value
38 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
39 | count (int): Number of bbox to generate
40 | """
41 | ratio = np.sqrt(1 - lam)
42 | img_h, img_w = img_shape[-2:]
43 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
44 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
45 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
46 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
47 | yl = np.clip(cy - cut_h // 2, 0, img_h)
48 | yh = np.clip(cy + cut_h // 2, 0, img_h)
49 | xl = np.clip(cx - cut_w // 2, 0, img_w)
50 | xh = np.clip(cx + cut_w // 2, 0, img_w)
51 | return yl, yh, xl, xh
52 |
53 |
54 | def rand_bbox_minmax(img_shape, minmax, count=None):
55 | """ Min-Max CutMix bounding-box
56 | Inspired by Darknet cutmix impl, generates a random rectangular bbox
57 | based on min/max percent values applied to each dimension of the input image.
58 |
59 | Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
60 |
61 | Args:
62 | img_shape (tuple): Image shape as tuple
63 | minmax (tuple or list): Min and max bbox ratios (as percent of image size)
64 | count (int): Number of bbox to generate
65 | """
66 | assert len(minmax) == 2
67 | img_h, img_w = img_shape[-2:]
68 | cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
69 | cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
70 | yl = np.random.randint(0, img_h - cut_h, size=count)
71 | xl = np.random.randint(0, img_w - cut_w, size=count)
72 | yu = yl + cut_h
73 | xu = xl + cut_w
74 | return yl, yu, xl, xu
75 |
76 |
77 | def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
78 | """ Generate bbox and apply lambda correction.
79 | """
80 | if ratio_minmax is not None:
81 | yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
82 | else:
83 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
84 | if correct_lam or ratio_minmax is not None:
85 | bbox_area = (yu - yl) * (xu - xl)
86 | lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
87 | return (yl, yu, xl, xu), lam
88 |
89 |
90 | class Mixup:
91 | """ Mixup/Cutmix that applies different params to each element or whole batch
92 |
93 | Args:
94 | mixup_alpha (float): mixup alpha value, mixup is active if > 0.
95 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
96 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
97 | prob (float): probability of applying mixup or cutmix per batch or element
98 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active
99 |         mode (str): how to apply mixup/cutmix params: per 'batch', 'pair' (pair of elements), or 'elem' (element)
100 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
101 | label_smoothing (float): apply label smoothing to the mixed target tensor
102 | num_classes (int): number of classes for target
103 | """
104 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
105 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
106 | self.mixup_alpha = mixup_alpha
107 | self.cutmix_alpha = cutmix_alpha
108 | self.cutmix_minmax = cutmix_minmax
109 | if self.cutmix_minmax is not None:
110 | assert len(self.cutmix_minmax) == 2
111 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
112 | self.cutmix_alpha = 1.0
113 | self.mix_prob = prob
114 | self.switch_prob = switch_prob
115 | self.label_smoothing = label_smoothing
116 | self.num_classes = num_classes
117 | self.mode = mode
118 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
119 |         self.mixup_enabled = True # set to false to disable mixing (intended to be set by train loop)
120 |
121 | def _params_per_elem(self, batch_size):
122 | lam = np.ones(batch_size, dtype=np.float32)
123 |         use_cutmix = np.zeros(batch_size, dtype=bool)  # np.bool is removed in newer numpy
124 | if self.mixup_enabled:
125 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
126 | use_cutmix = np.random.rand(batch_size) < self.switch_prob
127 | lam_mix = np.where(
128 | use_cutmix,
129 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
130 | np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
131 | elif self.mixup_alpha > 0.:
132 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
133 | elif self.cutmix_alpha > 0.:
134 |                 use_cutmix = np.ones(batch_size, dtype=bool)
135 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
136 | else:
137 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
138 | lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
139 | return lam, use_cutmix
140 |
141 | def _params_per_batch(self):
142 | lam = 1.
143 | use_cutmix = False
144 | if self.mixup_enabled and np.random.rand() < self.mix_prob:
145 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
146 | use_cutmix = np.random.rand() < self.switch_prob
147 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
148 | np.random.beta(self.mixup_alpha, self.mixup_alpha)
149 | elif self.mixup_alpha > 0.:
150 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
151 | elif self.cutmix_alpha > 0.:
152 | use_cutmix = True
153 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
154 | else:
155 | assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
156 | lam = float(lam_mix)
157 | return lam, use_cutmix
158 |
159 | def _mix_elem(self, x):
160 | batch_size = len(x)
161 | lam_batch, use_cutmix = self._params_per_elem(batch_size)
162 | x_orig = x.clone() # need to keep an unmodified original for mixing source
163 | for i in range(batch_size):
164 | j = batch_size - i - 1
165 | lam = lam_batch[i]
166 | if lam != 1.:
167 | if use_cutmix[i]:
168 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
169 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
170 | x[i][..., yl:yh, xl:xh] = x_orig[j][..., yl:yh, xl:xh]
171 | lam_batch[i] = lam
172 | else:
173 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
174 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
175 |
176 | def _mix_pair(self, x):
177 | batch_size = len(x)
178 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
179 | x_orig = x.clone() # need to keep an unmodified original for mixing source
180 | for i in range(batch_size // 2):
181 | j = batch_size - i - 1
182 | lam = lam_batch[i]
183 | if lam != 1.:
184 | if use_cutmix[i]:
185 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
186 | x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
187 | x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
188 | x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
189 | lam_batch[i] = lam
190 | else:
191 | x[i] = x[i] * lam + x_orig[j] * (1 - lam)
192 | x[j] = x[j] * lam + x_orig[i] * (1 - lam)
193 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
194 | return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
195 |
196 | def _mix_batch(self, x):
197 | lam, use_cutmix = self._params_per_batch()
198 | if lam == 1.:
199 | return 1.
200 | if use_cutmix:
201 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
202 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
203 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
204 | else:
205 | x_flipped = x.flip(0).mul_(1. - lam)
206 | x.mul_(lam).add_(x_flipped)
207 | return lam
208 |
209 | def __call__(self, x, target):
210 | assert len(x) % 2 == 0, 'Batch size should be even when using this'
211 | if self.mode == 'elem':
212 | lam = self._mix_elem(x)
213 | elif self.mode == 'pair':
214 | lam = self._mix_pair(x)
215 | else:
216 | lam = self._mix_batch(x)
217 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
218 | return x, target
219 |
220 |
221 | class FastCollateMixup(Mixup):
222 | """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
223 |
224 | A Mixup impl that's performed while collating the batches.
225 | """
226 |
227 | def _mix_elem_collate(self, output, batch, half=False):
228 | batch_size = len(batch)
229 | num_elem = batch_size // 2 if half else batch_size
230 | assert len(output) == num_elem
231 | lam_batch, use_cutmix = self._params_per_elem(num_elem)
232 | for i in range(num_elem):
233 | j = batch_size - i - 1
234 | lam = lam_batch[i]
235 | mixed = batch[i][0]
236 | if lam != 1.:
237 | if use_cutmix[i]:
238 | if not half:
239 | mixed = mixed.copy()
240 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
241 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
242 | mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
243 | lam_batch[i] = lam
244 | else:
245 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
246 | np.rint(mixed, out=mixed)
247 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
248 | if half:
249 | lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
250 | return torch.tensor(lam_batch).unsqueeze(1)
251 |
252 | def _mix_pair_collate(self, output, batch):
253 | batch_size = len(batch)
254 | lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
255 | for i in range(batch_size // 2):
256 | j = batch_size - i - 1
257 | lam = lam_batch[i]
258 | mixed_i = batch[i][0]
259 | mixed_j = batch[j][0]
260 | assert 0 <= lam <= 1.0
261 | if lam < 1.:
262 | if use_cutmix[i]:
263 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
264 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
265 | patch_i = mixed_i[:, yl:yh, xl:xh].copy()
266 | mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
267 | mixed_j[:, yl:yh, xl:xh] = patch_i
268 | lam_batch[i] = lam
269 | else:
270 | mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
271 | mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
272 | mixed_i = mixed_temp
273 | np.rint(mixed_j, out=mixed_j)
274 | np.rint(mixed_i, out=mixed_i)
275 | output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
276 | output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
277 | lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
278 | return torch.tensor(lam_batch).unsqueeze(1)
279 |
280 | def _mix_batch_collate(self, output, batch):
281 | batch_size = len(batch)
282 | lam, use_cutmix = self._params_per_batch()
283 | if use_cutmix:
284 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
285 | output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
286 | for i in range(batch_size):
287 | j = batch_size - i - 1
288 | mixed = batch[i][0]
289 | if lam != 1.:
290 | if use_cutmix:
291 | mixed = mixed.copy() # don't want to modify the original while iterating
292 | mixed[..., yl:yh, xl:xh] = batch[j][0][..., yl:yh, xl:xh]
293 | else:
294 | mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
295 | np.rint(mixed, out=mixed)
296 | output[i] += torch.from_numpy(mixed.astype(np.uint8))
297 | return lam
298 |
299 | def __call__(self, batch, _=None):
300 | batch_size = len(batch)
301 | assert batch_size % 2 == 0, 'Batch size should be even when using this'
302 | half = 'half' in self.mode
303 | if half:
304 | batch_size //= 2
305 | output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
306 | if self.mode == 'elem' or self.mode == 'half':
307 | lam = self._mix_elem_collate(output, batch, half=half)
308 | elif self.mode == 'pair':
309 | lam = self._mix_pair_collate(output, batch)
310 | else:
311 | lam = self._mix_batch_collate(output, batch)
312 | target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
313 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
314 | target = target[:batch_size]
315 | return output, target
316 |
--------------------------------------------------------------------------------
/ssv2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torchvision import transforms
5 | from random_erasing import RandomErasing
6 | import warnings
7 | from decord import VideoReader, cpu
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 | import video_transforms as video_transforms
11 | import volume_transforms as volume_transforms
12 |
13 |
14 | class SSVideoClsDataset(Dataset):
15 | """Load your own video classification dataset."""
16 |
17 | def __init__(self, anno_path, data_path, mode='train', clip_len=8,
18 | crop_size=224, short_side_size=256, new_height=256,
19 | new_width=340, keep_aspect_ratio=True, num_segment=1,
20 | num_crop=1, test_num_segment=10, test_num_crop=3, args=None):
21 | self.anno_path = anno_path
22 | self.data_path = data_path
23 | self.mode = mode
24 | self.clip_len = clip_len
25 | self.crop_size = crop_size
26 | self.short_side_size = short_side_size
27 | self.new_height = new_height
28 | self.new_width = new_width
29 | self.keep_aspect_ratio = keep_aspect_ratio
30 | self.num_segment = num_segment
31 | self.test_num_segment = test_num_segment
32 | self.num_crop = num_crop
33 | self.test_num_crop = test_num_crop
34 | self.args = args
35 | self.aug = False
36 | self.rand_erase = False
37 |
38 | if self.mode in ['train']:
39 | self.aug = True
40 | if self.args.reprob > 0:
41 | self.rand_erase = True
42 | if VideoReader is None:
43 | raise ImportError("Unable to import `decord` which is required to read videos.")
44 |
45 | import pandas as pd
46 | cleaned = pd.read_csv(self.anno_path, header=None, delimiter=' ')
47 | self.dataset_samples = list(cleaned.values[:, 0])
48 | self.label_array = list(cleaned.values[:, 1])
49 | if self.data_path is not None:
50 | self.dataset_samples = [os.path.join(self.data_path, p) for p in self.dataset_samples]
51 |
52 |
53 | if (mode == 'train'):
54 | pass
55 |
56 | elif (mode == 'validation'):
57 | self.data_transform = video_transforms.Compose([
58 | video_transforms.Resize(self.short_side_size, interpolation='bilinear'),
59 | video_transforms.CenterCrop(size=(self.crop_size, self.crop_size)),
60 | volume_transforms.ClipToTensor(),
61 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
62 | std=[0.229, 0.224, 0.225])
63 | ])
64 | elif mode == 'test':
65 | self.data_resize = video_transforms.Compose([
66 | video_transforms.Resize(size=(short_side_size), interpolation='bilinear')
67 | ])
68 | self.data_transform = video_transforms.Compose([
69 | volume_transforms.ClipToTensor(),
70 | video_transforms.Normalize(mean=[0.485, 0.456, 0.406],
71 | std=[0.229, 0.224, 0.225])
72 | ])
73 | self.test_seg = []
74 | self.test_dataset = []
75 | self.test_label_array = []
76 | for ck in range(self.test_num_segment):
77 | for cp in range(self.test_num_crop):
78 | for idx in range(len(self.label_array)):
79 | sample_label = self.label_array[idx]
80 | self.test_label_array.append(sample_label)
81 | self.test_dataset.append(self.dataset_samples[idx])
82 | self.test_seg.append((ck, cp))
83 |
84 | def __getitem__(self, index):
85 | if self.mode == 'train':
86 | args = self.args
87 | scale_t = 1
88 |
89 | sample = self.dataset_samples[index]
90 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t) # T H W C
91 | if len(buffer) == 0:
92 | while len(buffer) == 0:
93 | warnings.warn("video {} not correctly loaded during training".format(sample))
94 | index = np.random.randint(self.__len__())
95 | sample = self.dataset_samples[index]
96 | buffer = self.loadvideo_decord(sample, sample_rate_scale=scale_t)
97 |
98 | if args.num_sample > 1:
99 | frame_list = []
100 | label_list = []
101 | index_list = []
102 | for _ in range(args.num_sample):
103 | new_frames = self._aug_frame(buffer, args)
104 | label = self.label_array[index]
105 | frame_list.append(new_frames)
106 | label_list.append(label)
107 | index_list.append(index)
108 | return frame_list, label_list, index_list, {}
109 | else:
110 | buffer = self._aug_frame(buffer, args)
111 | return buffer, self.label_array[index], index, {}
112 |
113 | elif self.mode == 'validation':
114 | sample = self.dataset_samples[index]
115 | buffer = self.loadvideo_decord(sample)
116 | if len(buffer) == 0:
117 | while len(buffer) == 0:
118 | warnings.warn("video {} not correctly loaded during validation".format(sample))
119 | index = np.random.randint(self.__len__())
120 | sample = self.dataset_samples[index]
121 | buffer = self.loadvideo_decord(sample)
122 | buffer = self.data_transform(buffer)
123 | return buffer, self.label_array[index], sample.split("/")[-1].split(".")[0]
124 |
125 | elif self.mode == 'test':
126 | sample = self.test_dataset[index]
127 | chunk_nb, split_nb = self.test_seg[index]
128 | buffer = self.loadvideo_decord(sample)
129 |
130 | while len(buffer) == 0:
131 | warnings.warn("video {}, temporal {}, spatial {} not found during testing".format(\
132 | str(self.test_dataset[index]), chunk_nb, split_nb))
133 | index = np.random.randint(self.__len__())
134 | sample = self.test_dataset[index]
135 | chunk_nb, split_nb = self.test_seg[index]
136 | buffer = self.loadvideo_decord(sample)
137 |
138 | buffer = self.data_resize(buffer)
139 | if isinstance(buffer, list):
140 | buffer = np.stack(buffer, 0)
141 |
142 | spatial_step = 1.0 * (max(buffer.shape[1], buffer.shape[2]) - self.short_side_size) \
143 | / (self.test_num_crop - 1)
144 | temporal_start = chunk_nb # 0/1
145 | spatial_start = int(split_nb * spatial_step)
146 | if buffer.shape[1] >= buffer.shape[2]:
147 | buffer = buffer[temporal_start::2, \
148 | spatial_start:spatial_start + self.short_side_size, :, :]
149 | else:
150 | buffer = buffer[temporal_start::2, \
151 | :, spatial_start:spatial_start + self.short_side_size, :]
152 |
153 | buffer = self.data_transform(buffer)
154 | return buffer, self.test_label_array[index], sample.split("/")[-1].split(".")[0], \
155 | chunk_nb, split_nb
156 | else:
157 |             raise NameError('mode {} unknown'.format(self.mode))
158 |
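    # Worked example for the test-time cropping above (illustrative numbers only):
    # with a clip resized to 224x320, short_side_size=224 and test_num_crop=3,
    # spatial_step = (320 - 224) / 2 = 48, so split_nb in {0, 1, 2} starts the crop
    # at columns 0, 48 and 96; chunk_nb in {0, 1} then keeps every second frame
    # starting from frame 0 or 1.
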
159 | def _aug_frame(
160 | self,
161 | buffer,
162 | args,
163 | ):
164 | aug_transform = video_transforms.create_random_augment(
165 | input_size=(self.crop_size, self.crop_size),
166 | auto_augment=args.aa,
167 | interpolation=args.train_interpolation,
168 | )
169 |
170 | buffer = [
171 | transforms.ToPILImage()(frame) for frame in buffer
172 | ]
173 |
174 | buffer = aug_transform(buffer)
175 |
176 | buffer = [transforms.ToTensor()(img) for img in buffer]
177 | buffer = torch.stack(buffer) # T C H W
178 | buffer = buffer.permute(0, 2, 3, 1) # T H W C
179 |
180 | # T H W C
181 | buffer = tensor_normalize(
182 | buffer, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
183 | )
184 | # T H W C -> C T H W.
185 | buffer = buffer.permute(3, 0, 1, 2)
186 | # Perform data augmentation.
187 | scl, asp = (
188 | [0.08, 1.0],
189 | [0.75, 1.3333],
190 | )
191 |
192 | buffer = spatial_sampling(
193 | buffer,
194 | spatial_idx=-1,
195 | min_scale=256,
196 | max_scale=320,
197 | crop_size=self.crop_size,
198 | random_horizontal_flip=False if args.data_set == 'SSV2' else True,
199 | inverse_uniform_sampling=False,
200 | aspect_ratio=asp,
201 | scale=scl,
202 | motion_shift=False
203 | )
204 |
205 | if self.rand_erase:
206 | erase_transform = RandomErasing(
207 | args.reprob,
208 | mode=args.remode,
209 | max_count=args.recount,
210 | num_splits=args.recount,
211 | device="cpu",
212 | )
213 | buffer = buffer.permute(1, 0, 2, 3)
214 | buffer = erase_transform(buffer)
215 | buffer = buffer.permute(1, 0, 2, 3)
216 |
217 | return buffer
218 |
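    # Summary of the _aug_frame pipeline above: RandAugment on PIL frames ->
    # ToTensor + stack to (T, C, H, W) -> permute to (T, H, W, C) and normalize with
    # ImageNet statistics -> permute to (C, T, H, W) -> random resized crop and
    # optional horizontal flip via spatial_sampling -> optional RandomErasing,
    # applied with the clip temporarily permuted back to (T, C, H, W).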
219 |
220 | def loadvideo_decord(self, sample, sample_rate_scale=1):
221 | """Load video content using Decord"""
222 | fname = sample
223 |
224 |         if not os.path.exists(fname):
225 | return []
226 |
227 | # avoid hanging issue
228 | if os.path.getsize(fname) < 1 * 1024:
229 | print('SKIP: ', fname, " - ", os.path.getsize(fname))
230 | return []
231 | try:
232 | if self.keep_aspect_ratio:
233 | vr = VideoReader(fname, num_threads=1, ctx=cpu(0))
234 | else:
235 | vr = VideoReader(fname, width=self.new_width, height=self.new_height,
236 | num_threads=1, ctx=cpu(0))
237 |         except Exception:
238 | print("video cannot be loaded by decord: ", fname)
239 | return []
240 |
241 | if self.mode == 'test':
242 | all_index = []
243 | tick = float(len(vr) - 1) / float(self.num_segment)
244 | interval_cross_view = tick / self.test_num_segment
245 | for i in range(self.num_segment):
246 | for ci in range(self.test_num_segment):
247 | start = int(np.round(tick * i))
248 | end = int(np.round(tick * (i + 1)))
249 | if self.test_num_segment > 1:
250 | start = int(np.round(tick * i + interval_cross_view * ci))
251 | end = int(np.round(tick * i + interval_cross_view * (ci + 1)))
252 | all_index.append((start + end) // 2)
253 | else:
254 | all_index.append((start + end) // 2)
255 | all_index.append((start + end) // 2)
256 |
257 | vr.seek(0)
258 | buffer = vr.get_batch(all_index).asnumpy()
259 | return buffer
260 |
261 | # handle temporal segments
262 | average_duration = float(len(vr) - 1) / self.num_segment
263 | all_index = []
264 | for i in range(self.num_segment):
265 | start = int(np.round(average_duration * i))
266 | end = int(np.round(average_duration * (i + 1)))
267 | all_index.append(int(np.random.randint(start, end + 1)))
268 |
269 | all_index = list(np.array(all_index))
270 | vr.seek(0)
271 | buffer = vr.get_batch(all_index).asnumpy()
272 | return buffer
273 |
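    # Worked example for the train/validation sampling above (illustrative numbers):
    # for a 121-frame video with num_segment=16, average_duration = 120 / 16 = 7.5,
    # so segment i covers [round(7.5 * i), round(7.5 * (i + 1))] and one index is
    # drawn uniformly from that window. Note that sample_rate_scale is accepted but
    # not used in this branch.
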
274 | def __len__(self):
275 | if self.mode != 'test':
276 | return len(self.dataset_samples)
277 | else:
278 | return len(self.test_dataset)
279 |
280 |
281 | def spatial_sampling(
282 | frames,
283 | spatial_idx=-1,
284 | min_scale=256,
285 | max_scale=320,
286 | crop_size=224,
287 | random_horizontal_flip=True,
288 | inverse_uniform_sampling=False,
289 | aspect_ratio=None,
290 | scale=None,
291 | motion_shift=False,
292 | ):
293 | """
294 | Perform spatial sampling on the given video frames. If spatial_idx is
295 | -1, perform random scale, random crop, and random flip on the given
296 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
297 | with the given spatial_idx.
298 | Args:
299 | frames (tensor): frames of images sampled from the video. The
300 | dimension is `num frames` x `height` x `width` x `channel`.
301 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
302 | or 2, perform left, center, right crop if width is larger than
303 |             height, and perform top, center, bottom crop if height is larger
304 | than width.
305 | min_scale (int): the minimal size of scaling.
306 | max_scale (int): the maximal size of scaling.
307 | crop_size (int): the size of height and width used to crop the
308 | frames.
309 | inverse_uniform_sampling (bool): if True, sample uniformly in
310 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
311 | scale. If False, take a uniform sample from [min_scale,
312 | max_scale].
313 | aspect_ratio (list): Aspect ratio range for resizing.
314 | scale (list): Scale range for resizing.
315 | motion_shift (bool): Whether to apply motion shift for resizing.
316 | Returns:
317 | frames (tensor): spatially sampled frames.
318 | """
319 | assert spatial_idx in [-1, 0, 1, 2]
320 | if spatial_idx == -1:
321 | if aspect_ratio is None and scale is None:
322 | frames, _ = video_transforms.random_short_side_scale_jitter(
323 | images=frames,
324 | min_size=min_scale,
325 | max_size=max_scale,
326 | inverse_uniform_sampling=inverse_uniform_sampling,
327 | )
328 | frames, _ = video_transforms.random_crop(frames, crop_size)
329 | else:
330 | transform_func = (
331 | video_transforms.random_resized_crop_with_shift
332 | if motion_shift
333 | else video_transforms.random_resized_crop
334 | )
335 | frames = transform_func(
336 | images=frames,
337 | target_height=crop_size,
338 | target_width=crop_size,
339 | scale=scale,
340 | ratio=aspect_ratio,
341 | )
342 | if random_horizontal_flip:
343 | frames, _ = video_transforms.horizontal_flip(0.5, frames)
344 | else:
345 | # The testing is deterministic and no jitter should be performed.
346 |         # min_scale, max_scale, and crop_size are expected to be the same.
347 | assert len({min_scale, max_scale, crop_size}) == 1
348 | frames, _ = video_transforms.random_short_side_scale_jitter(
349 | frames, min_scale, max_scale
350 | )
351 | frames, _ = video_transforms.uniform_crop(frames, crop_size, spatial_idx)
352 | return frames
353 |
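# Illustrative sketch (hypothetical helper, not called anywhere in this file): how the
# training path in _aug_frame invokes spatial_sampling on a clip laid out as
# (C, T, H, W). The dummy tensor shape is an assumption chosen for concreteness.
def _example_spatial_sampling():
    clip = torch.rand(3, 16, 256, 320)   # C, T, H, W
    out = spatial_sampling(
        clip,
        spatial_idx=-1,                   # training path: random scale, crop and flip
        min_scale=256,
        max_scale=320,
        crop_size=224,
        random_horizontal_flip=True,
        inverse_uniform_sampling=False,
        aspect_ratio=[0.75, 1.3333],
        scale=[0.08, 1.0],
        motion_shift=False,
    )
    return out                            # spatial size is now 224 x 224
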
354 |
355 | def tensor_normalize(tensor, mean, std):
356 | """
357 | Normalize a given tensor by subtracting the mean and dividing the std.
358 | Args:
359 | tensor (tensor): tensor to normalize.
360 | mean (tensor or list): mean value to subtract.
361 | std (tensor or list): std to divide.
362 | """
363 | if tensor.dtype == torch.uint8:
364 | tensor = tensor.float()
365 | tensor = tensor / 255.0
366 |     if isinstance(mean, list):
367 |         mean = torch.tensor(mean)
368 |     if isinstance(std, list):
369 |         std = torch.tensor(std)
370 | tensor = tensor - mean
371 | tensor = tensor / std
372 | return tensor
373 |
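# Worked example for tensor_normalize (illustrative): a uint8 value of 255 in the red
# channel becomes 255 / 255 = 1.0, then (1.0 - 0.485) / 0.229 ≈ 2.249 with the
# ImageNet statistics used throughout this file.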
--------------------------------------------------------------------------------
/modeling_student.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from functools import partial, reduce
7 | from operator import mul
8 | from einops import rearrange
9 | import torch.utils.checkpoint as checkpoint
10 |
11 | from modeling_finetune import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table, get_3d_sincos_pos_embed
12 | from timm.models.registry import register_model
13 | from timm.models.layers import trunc_normal_ as __call_trunc_normal_
14 | from timm.models.layers import drop_path, to_2tuple
15 |
16 |
17 | def trunc_normal_(tensor, mean=0., std=1.):
18 | __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
19 |
20 |
21 | __all__ = [
22 | 'pretrain_masked_video_student_small_patch16_224',
23 | 'pretrain_masked_video_student_base_patch16_224',
24 | 'pretrain_masked_video_student_large_patch16_224',
25 | 'pretrain_masked_video_student_huge_patch16_224',
26 | ]
27 |
28 |
29 | class DropPath(nn.Module):
30 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
31 | """
32 |
33 | def __init__(self, drop_prob=None):
34 | super(DropPath, self).__init__()
35 | self.drop_prob = drop_prob
36 |
37 | def forward(self, x):
38 | return drop_path(x, self.drop_prob, self.training)
39 |
40 | def extra_repr(self) -> str:
41 | return 'p={}'.format(self.drop_prob)
42 |
43 |
44 | class Mlp(nn.Module):
45 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
46 | super().__init__()
47 | out_features = out_features or in_features
48 | hidden_features = hidden_features or in_features
49 | self.fc1 = nn.Linear(in_features, hidden_features)
50 | self.act = act_layer()
51 | self.fc2 = nn.Linear(hidden_features, out_features)
52 | self.drop = nn.Dropout(drop)
53 |
54 | def forward(self, x):
55 | x = self.fc1(x)
56 | x = self.act(x)
57 | x = self.fc2(x)
58 | x = self.drop(x)
59 | return x
60 |
61 |
62 | class PretrainVisionTransformerDecoder(nn.Module):
63 | """ Vision Transformer with support for patch or hybrid CNN input stage
64 | """
65 |
66 | def __init__(self, patch_size=16, num_classes=768, embed_dim=768, depth=12,
67 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
68 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, use_checkpoint=False
69 | ):
70 | super().__init__()
71 | self.num_classes = num_classes
72 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
73 | self.patch_size = patch_size
74 | self.use_checkpoint = use_checkpoint
75 |
76 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
77 | self.blocks = nn.ModuleList([
78 | Block(
79 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
80 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
81 | init_values=init_values)
82 | for i in range(depth)])
83 | self.norm = norm_layer(embed_dim)
84 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
85 |
86 | self.apply(self._init_weights)
87 |
88 | def _init_weights(self, m):
89 | if isinstance(m, nn.Linear):
90 | nn.init.xavier_uniform_(m.weight)
91 | if isinstance(m, nn.Linear) and m.bias is not None:
92 | nn.init.constant_(m.bias, 0)
93 | elif isinstance(m, nn.LayerNorm):
94 | nn.init.constant_(m.bias, 0)
95 | nn.init.constant_(m.weight, 1.0)
96 |
97 | def get_num_layers(self):
98 | return len(self.blocks)
99 |
100 | @torch.jit.ignore
101 | def no_weight_decay(self):
102 | return {'pos_embed', 'cls_token'}
103 |
104 | def get_classifier(self):
105 | return self.head
106 |
107 | def reset_classifier(self, num_classes, global_pool=''):
108 | self.num_classes = num_classes
109 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
110 |
111 | def forward(self, x, return_token_num):
112 | if self.use_checkpoint:
113 | for blk in self.blocks:
114 | x = checkpoint.checkpoint(blk, x)
115 | else:
116 | for blk in self.blocks:
117 | x = blk(x)
118 |
119 | if return_token_num > 0:
120 |             x = self.head(self.norm(x[:, -return_token_num:]))  # only return predictions for the masked tokens
121 | else:
122 | x = self.head(self.norm(x))
123 |
124 | return x
125 |
126 |
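# Illustrative sketch (hypothetical helper, not used elsewhere): the decoder consumes a
# sequence of visible tokens followed by mask tokens and, when return_token_num > 0,
# returns predictions only for the trailing masked positions. The 160/40 token split
# below is made up for the example.
def _example_decoder_shapes():
    dec = PretrainVisionTransformerDecoder(num_classes=768, embed_dim=384, depth=2, num_heads=6)
    tokens = torch.randn(2, 160 + 40, 384)   # [B, N_visible + N_masked, C]
    out = dec(tokens, return_token_num=40)
    return out.shape                          # expected: torch.Size([2, 40, 768])

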
127 | class PretrainMaskedVideoStudent(nn.Module):
128 |
129 | def __init__(self,
130 | img_size=224,
131 | patch_size=16,
132 | encoder_in_chans=3,
133 | encoder_embed_dim=768,
134 | encoder_depth=12,
135 | encoder_num_heads=12,
136 | decoder_depth=4,
137 | feat_decoder_embed_dim=None,
138 | feat_decoder_num_heads=None,
139 | mlp_ratio=4.,
140 | qkv_bias=False,
141 | qk_scale=None,
142 | drop_rate=0.,
143 | attn_drop_rate=0.,
144 | drop_path_rate=0.,
145 | norm_layer=nn.LayerNorm,
146 | init_values=0.,
147 | tubelet_size=2,
148 | num_frames=16,
149 | use_cls_token=False,
150 | target_feature_dim=768,
151 | target_video_feature_dim=768,
152 | use_checkpoint=False,
153 | ):
154 | super().__init__()
155 | self.use_cls_token = use_cls_token
156 | self.patch_embed = PatchEmbed(
157 | img_size=img_size, patch_size=patch_size, in_chans=encoder_in_chans,
158 | embed_dim=encoder_embed_dim, tubelet_size=tubelet_size, num_frames=num_frames)
159 | self.patch_size = self.patch_embed.patch_size
160 | num_patches = self.patch_embed.num_patches
161 | self.encoder_embed_dim = encoder_embed_dim
162 | self.tubelet_size = tubelet_size
163 | self.num_frames = num_frames
164 | self.use_checkpoint = use_checkpoint
165 |
166 | if use_cls_token:
167 | self.cls_token = nn.Parameter(torch.zeros(1, 1, encoder_embed_dim))
168 | else:
169 | self.cls_token = None
170 |
171 | # sine-cosine positional embeddings
172 | self.pos_embed = get_3d_sincos_pos_embed(embed_dim=encoder_embed_dim,
173 | grid_size=self.patch_embed.num_patches_h,
174 | t_size=self.patch_embed.num_patches_t)
175 | self.pos_embed = nn.Parameter(self.pos_embed, requires_grad=False)
176 | self.pos_embed.requires_grad = False
177 |
178 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, encoder_depth)] # stochastic depth decay rule
179 | self.blocks = nn.ModuleList([
180 | Block(
181 | dim=encoder_embed_dim, num_heads=encoder_num_heads,
182 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
183 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
184 | init_values=init_values)
185 | for i in range(encoder_depth)])
186 | self.norm = norm_layer(encoder_embed_dim)
187 |
188 | if feat_decoder_embed_dim is None:
189 | feat_decoder_embed_dim = encoder_embed_dim
190 | if feat_decoder_num_heads is None:
191 | feat_decoder_num_heads = encoder_num_heads
192 | self.mask_token_img = nn.Parameter(torch.zeros(1, 1, feat_decoder_embed_dim))
193 | self.down_img = nn.Linear(encoder_embed_dim, feat_decoder_embed_dim)
194 | self.decoder_img = PretrainVisionTransformerDecoder(
195 | patch_size=patch_size,
196 | num_classes=target_feature_dim,
197 | embed_dim=feat_decoder_embed_dim,
198 | depth=decoder_depth,
199 | num_heads=feat_decoder_num_heads,
200 | mlp_ratio=mlp_ratio,
201 | qkv_bias=qkv_bias,
202 | qk_scale=qk_scale,
203 | drop_rate=drop_rate,
204 | attn_drop_rate=attn_drop_rate,
205 | drop_path_rate=drop_path_rate,
206 | norm_layer=norm_layer,
207 | init_values=init_values,
208 | use_checkpoint=use_checkpoint,
209 | )
210 | self.pos_embed_img = get_3d_sincos_pos_embed(
211 | embed_dim=feat_decoder_embed_dim,
212 | grid_size=self.patch_embed.num_patches_h,
213 | t_size=self.patch_embed.num_patches_t
214 | )
215 | self.pos_embed_img = nn.Parameter(self.pos_embed_img, requires_grad=False)
216 | self.pos_embed_img.requires_grad = False
217 | trunc_normal_(self.mask_token_img, std=.02)
218 |
219 | self.mask_token_vid = nn.Parameter(torch.zeros(1, 1, feat_decoder_embed_dim))
220 | self.down_vid = nn.Linear(encoder_embed_dim, feat_decoder_embed_dim)
221 | self.decoder_vid = PretrainVisionTransformerDecoder(
222 | patch_size=patch_size,
223 | num_classes=target_video_feature_dim,
224 | embed_dim=feat_decoder_embed_dim,
225 | depth=decoder_depth,
226 | num_heads=feat_decoder_num_heads,
227 | mlp_ratio=mlp_ratio,
228 | qkv_bias=qkv_bias,
229 | qk_scale=qk_scale,
230 | drop_rate=drop_rate,
231 | attn_drop_rate=attn_drop_rate,
232 | drop_path_rate=drop_path_rate,
233 | norm_layer=norm_layer,
234 | init_values=init_values,
235 | use_checkpoint=use_checkpoint,
236 | )
237 | self.pos_embed_vid = get_3d_sincos_pos_embed(
238 | embed_dim=feat_decoder_embed_dim,
239 | grid_size=self.patch_embed.num_patches_h,
240 | t_size=self.patch_embed.num_patches_t
241 | )
242 | self.pos_embed_vid = nn.Parameter(self.pos_embed_vid, requires_grad=False)
243 | self.pos_embed_vid.requires_grad = False
244 | trunc_normal_(self.mask_token_vid, std=.02)
245 |
246 | self.apply(self._init_weights)
247 |
248 | if self.use_cls_token:
249 | nn.init.normal_(self.cls_token, std=1e-6)
250 |
251 | def _init_weights(self, m):
252 | if isinstance(m, nn.Linear):
253 | nn.init.xavier_uniform_(m.weight)
254 | if isinstance(m, nn.Linear) and m.bias is not None:
255 | nn.init.constant_(m.bias, 0)
256 | elif isinstance(m, nn.LayerNorm):
257 | nn.init.constant_(m.bias, 0)
258 | nn.init.constant_(m.weight, 1.0)
259 |
260 | def get_num_layers(self):
261 | return len(self.blocks)
262 |
263 | @torch.jit.ignore
264 | def no_weight_decay(self):
265 | return {'pos_embed', 'cls_token', 'mask_token'}
266 |
267 | def forward_encoder(self, x, mask):
268 | # embed patches
269 | # x: B, C, T, H, W
270 | x = self.patch_embed(x)
271 |
272 | # add pos embed w/o cls token
273 | x = x + self.pos_embed.type_as(x).detach()
274 | # x: B, L, C
275 |
276 | # masking: length -> length * mask_ratio
277 | B, _, C = x.shape
278 | x = x[~mask].reshape(B, -1, C) # ~mask means visible
279 |
280 | # append cls token
281 | if self.use_cls_token:
282 | cls_tokens = self.cls_token.expand(B, -1, -1)
283 | x = torch.cat((cls_tokens, x), dim=1)
284 |
285 | # apply Transformer blocks
286 | if self.use_checkpoint:
287 | for blk in self.blocks:
288 | x = checkpoint.checkpoint(blk, x)
289 | else:
290 | for blk in self.blocks:
291 | x = blk(x)
292 |
293 | x = self.norm(x)
294 |
295 | return x
296 |
297 | def forward(self, x, mask):
298 | x = self.forward_encoder(x, mask)
299 | s = 1 if self.use_cls_token else 0
300 |
301 | x_vis_img = self.down_img(x)
302 | B, N, C = x_vis_img.shape
303 |
304 | expand_pos_embed_img = self.pos_embed_img.type_as(x_vis_img).detach().expand(B, -1, -1)
305 | pos_emd_vis_img = expand_pos_embed_img[~mask].reshape(B, -1, C)
306 | pos_emd_mask_img = expand_pos_embed_img[mask].reshape(B, -1, C)
307 | x_img = torch.cat(
308 | [x_vis_img[:, s:, :] + pos_emd_vis_img, self.mask_token_img + pos_emd_mask_img],
309 | dim=1) # [B, N, C_d]
310 | x_img = torch.cat([x_vis_img[:, :s, :], x_img], dim=1)
311 |
312 | x_img = self.decoder_img(x_img, pos_emd_mask_img.shape[1])
313 |
314 | x_vis_vid = self.down_vid(x)
315 | B, N, C = x_vis_vid.shape
316 |
317 | expand_pos_embed_vid = self.pos_embed_vid.type_as(x_vis_vid).detach().expand(B, -1, -1)
318 | pos_emd_vis_vid = expand_pos_embed_vid[~mask].reshape(B, -1, C)
319 | pos_emd_mask_vid = expand_pos_embed_vid[mask].reshape(B, -1, C)
320 | x_vid = torch.cat(
321 | [x_vis_vid[:, s:, :] + pos_emd_vis_vid, self.mask_token_vid + pos_emd_mask_vid],
322 | dim=1) # [B, N, C_d]
323 | x_vid = torch.cat([x_vis_vid[:, :s, :], x_vid], dim=1)
324 |
325 | x_vid = self.decoder_vid(x_vid, pos_emd_mask_vid.shape[1])
326 |
327 | return x_img, x_vid
328 |
329 |
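# Note on PretrainMaskedVideoStudent.forward: the shared encoder processes only the
# visible patches (x[~mask]); each decoder then receives the visible tokens plus
# learnable mask tokens, both with fixed 3D sin-cos positional embeddings, and predicts
# features only at the masked positions. The two heads target different feature
# dimensions (target_feature_dim vs. target_video_feature_dim), presumably matching the
# image teacher and video teacher defined in modeling_teacher.py and
# modeling_video_teacher.py.

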
330 | @register_model
331 | def pretrain_masked_video_student_small_patch16_224(pretrained=False, **kwargs):
332 | model = PretrainMaskedVideoStudent(
333 | img_size=224,
334 | patch_size=16,
335 | encoder_embed_dim=384,
336 | encoder_depth=12,
337 | encoder_num_heads=6,
338 | mlp_ratio=4,
339 | qkv_bias=True,
340 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
341 | **kwargs)
342 | model.default_cfg = _cfg()
343 | if pretrained:
344 | checkpoint = torch.load(
345 | kwargs["init_ckpt"], map_location="cpu"
346 | )
347 | model.load_state_dict(checkpoint["model"])
348 | return model
349 |
350 |
351 | @register_model
352 | def pretrain_masked_video_student_base_patch16_224(pretrained=False, **kwargs):
353 | model = PretrainMaskedVideoStudent(
354 | img_size=224,
355 | patch_size=16,
356 | encoder_embed_dim=768,
357 | encoder_depth=12,
358 | encoder_num_heads=12,
359 | mlp_ratio=4,
360 | qkv_bias=True,
361 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
362 | **kwargs)
363 | model.default_cfg = _cfg()
364 | if pretrained:
365 | checkpoint = torch.load(
366 | kwargs["init_ckpt"], map_location="cpu"
367 | )
368 | model.load_state_dict(checkpoint["model"])
369 | return model
370 |
371 |
372 | @register_model
373 | def pretrain_masked_video_student_large_patch16_224(pretrained=False, **kwargs):
374 | model = PretrainMaskedVideoStudent(
375 | img_size=224,
376 | patch_size=16,
377 | encoder_embed_dim=1024,
378 | encoder_depth=24,
379 | encoder_num_heads=16,
380 | mlp_ratio=4,
381 | qkv_bias=True,
382 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
383 | **kwargs)
384 | model.default_cfg = _cfg()
385 | if pretrained:
386 | checkpoint = torch.load(
387 | kwargs["init_ckpt"], map_location="cpu"
388 | )
389 | model.load_state_dict(checkpoint["model"])
390 | return model
391 |
392 |
393 | @register_model
394 | def pretrain_masked_video_student_huge_patch16_224(pretrained=False, **kwargs):
395 | model = PretrainMaskedVideoStudent(
396 | img_size=224,
397 | patch_size=16,
398 | encoder_embed_dim=1280,
399 | encoder_depth=32,
400 | encoder_num_heads=16,
401 | mlp_ratio=4,
402 | qkv_bias=True,
403 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
404 | **kwargs)
405 | model.default_cfg = _cfg()
406 | if pretrained:
407 | checkpoint = torch.load(
408 | kwargs["init_ckpt"], map_location="cpu"
409 | )
410 | model.load_state_dict(checkpoint["model"])
411 | return model
412 |
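# Minimal smoke-test sketch (illustrative only): the batch size, mask layout and the
# 25% visible ratio below are assumptions, not values taken from this repo's training
# scripts.
if __name__ == "__main__":
    model = pretrain_masked_video_student_small_patch16_224(decoder_depth=2)
    videos = torch.randn(1, 3, 16, 224, 224)        # B, C, T, H, W
    num_patches = model.patch_embed.num_patches     # (224/16)^2 * (16/2) = 1568 tokens
    mask = torch.ones(1, num_patches, dtype=torch.bool)
    mask[:, :392] = False                           # first 392 tokens (25%) stay visible
    with torch.no_grad():
        pred_img, pred_vid = model(videos, mask)
    print(pred_img.shape, pred_vid.shape)           # predictions at masked positions only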
--------------------------------------------------------------------------------