├── configs
│   ├── data
│   │   ├── __init__.py
│   │   ├── debug
│   │   │   └── .gitignore
│   │   ├── pennaction_train.py
│   │   ├── pennaction_trainval.py
│   │   └── base.py
│   └── sa
│       ├── sa_ds_penn.py
│       ├── sa_ds_ikea.py
│       └── sa_ds_h2o.py
├── baseball_swing.gif
├── src
│   ├── casa
│   │   ├── casa_module
│   │   │   ├── __init__.py
│   │   │   ├── linear_attention.py
│   │   │   └── transformer.py
│   │   ├── __init__.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   ├── fcl.py
│   │   │   ├── projection_head.py
│   │   │   └── resnet_fpn.py
│   │   ├── utils
│   │   │   ├── position_encoding.py
│   │   │   ├── supervision.py
│   │   │   └── matching.py
│   │   └── casa.py
│   ├── utils
│   │   ├── dataset.py
│   │   ├── plotting.py
│   │   ├── profiler.py
│   │   ├── augment.py
│   │   ├── dataloader.py
│   │   └── misc.py
│   ├── optimizers
│   │   └── __init__.py
│   ├── evaluation
│   │   ├── task_utils.py
│   │   ├── kendalls_tau.py
│   │   ├── classification.py
│   │   └── event_completion.py
│   ├── config
│   │   └── default.py
│   ├── losses
│   │   └── casa_loss.py
│   ├── lightning
│   │   ├── data.py
│   │   └── lightning_casa.py
│   └── datasets
│       └── pennaction.py
├── requirements.txt
├── download_extra_data.sh
├── env.yml
├── scripts
│   ├── train
│   │   ├── h2o_train.sh
│   │   ├── ikea_train.sh
│   │   └── pennaction_train.sh
│   └── eval
│       └── pennaction_eval.sh
├── dataset_splits.py
├── dataset_preparation
│   ├── penn_action
│   │   ├── mat2json.py
│   │   └── read_pose.py
│   ├── preprocess_norm.py
│   └── preprocess_norm_mat.py
├── joint_ids.py
├── README.md
├── eval.py
├── train.py
└── License
/configs/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configs/data/debug/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | */
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/baseball_swing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taeinkwon/CASA/HEAD/baseball_swing.gif
--------------------------------------------------------------------------------
/src/casa/casa_module/__init__.py:
--------------------------------------------------------------------------------
1 | from .transformer import LocalFeatureTransformer
2 |
3 |
--------------------------------------------------------------------------------
/src/casa/__init__.py:
--------------------------------------------------------------------------------
1 | from .casa import CASA
2 | #from .classification import Classification
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv_python==4.4.0.46
2 | tqdm
3 | pytorch-lightning==1.4.8
4 | loguru==0.5.3
5 | yacs==0.1.8
6 |
7 | joblib
8 | albumentations
9 |
--------------------------------------------------------------------------------
/src/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import io
2 | from loguru import logger
3 |
4 | import cv2
5 | import numpy as np
6 | import h5py
7 | import torch
8 | from numpy.linalg import inv
9 |
10 | # --- DATA IO ---
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/download_extra_data.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 |
3 | mkdir -p extra_data/body_module
4 | cd extra_data/body_module
5 |
6 | echo "J_regressor_extra_smplx"
7 | wget https://dl.fbaipublicfiles.com/eft/fairmocap_data/body_module/J_regressor_extra_smplx.npy
8 |
9 | echo "Done"
--------------------------------------------------------------------------------
/configs/data/pennaction_train.py:
--------------------------------------------------------------------------------
1 | from configs.data.base import cfg
2 |
3 |
4 | TRAIN_BASE_PATH = "npyrecords/"
5 | cfg.DATASET.TRAINVAL_DATA_SOURCE = "PennAction"
6 | cfg.DATASET.TRAIN_DATA_ROOT = f"{TRAIN_BASE_PATH}/baseball_pitch_val.npy"
7 | cfg.DATASET.TRAIN_NPZ_ROOT = ""
8 |
9 |
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: CASA
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - python=3.8
8 | - pip
9 | - numpy
10 | - pytorch==1.7.1
11 | - torchvision==0.8.2
12 | - pip:
13 | - opencv_python==4.4.0.46
14 | - tqdm
15 | - pytorch-lightning==1.4.8
16 | - loguru==0.5.3
17 | - yacs==0.1.8
18 | - joblib
19 | - albumentations
20 | - torchgeometry
21 | - smplx
22 | - dotmap
23 | - einops
24 | - seaborn
25 | - easydict
26 | - chumpy
27 | - torchsummary
28 | - torchmetrics==0.6.0
29 |
--------------------------------------------------------------------------------
/src/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import matplotlib
4 | import seaborn as sn
5 | import pandas as pd
6 | import os
7 |
8 |
9 | def vis_conf_matrix(conf_matrix, save_path):
10 |
11 | conf_matrix = np.around(conf_matrix, decimals=5)
12 | sn.set(font_scale=0.05)
13 | df_cm = pd.DataFrame(conf_matrix, index=[i for i in range(np.shape(conf_matrix)[0])],
14 | columns=[i for i in range(np.shape(conf_matrix)[1])])
15 | df_cm = df_cm[::-1]
16 | svm = sn.heatmap(df_cm, annot=False, cmap="OrRd")
17 | figure = svm.get_figure()
18 | figure.savefig(save_path, dpi=400)
19 | figure.clf()
20 |
--------------------------------------------------------------------------------
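Usage sketch for vis_conf_matrix, assuming the project root is on PYTHONPATH as the training scripts arrange; the matrix values and output filename below are illustrative:

    import numpy as np
    from src.utils.plotting import vis_conf_matrix

    # hypothetical 20x20 frame-to-frame confidence matrix, row-normalized for readability
    conf = np.random.rand(20, 20)
    conf = conf / conf.sum(axis=1, keepdims=True)
    vis_conf_matrix(conf, "conf_matrix_debug.png")  # writes a seaborn heatmap at dpi=400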
/src/casa/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
2 | from .fcl import FCL
3 | from .projection_head import ProjectionHead
4 |
5 | def build_backbone(config):
6 | if config['backbone_type'] == 'ResNetFPN':
7 | if config['resolution'] == (8, 2):
8 | return ResNetFPN_8_2(config['resnetfpn'])
9 | elif config['resolution'] == (16, 4):
10 | return ResNetFPN_16_4(config['resnetfpn'])
11 | elif config['backbone_type'] == 'FCL':
12 | return FCL(config['fcl'])
13 |
14 | elif config['backbone_type'] == 'PH':
15 | return ProjectionHead(config)
16 |
17 | else:
18 | raise ValueError(
19 | f"CASA.BACKBONE_TYPE {config['backbone_type']} not supported.")
20 |
--------------------------------------------------------------------------------
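A minimal sketch of calling build_backbone through its FCL branch; the lower-cased dict keys mirror how lower_config (src/utils/misc.py) flattens the yacs config, and the literal 75 (PennAction pose dimension) is taken from the configs in this repository:

    import torch
    from src.casa.backbone import build_backbone

    config = {'backbone_type': 'FCL', 'fcl': {'initial_dim': 75}}
    backbone = build_backbone(config)

    x = torch.randn(4, 20, 75)   # [batch, sampled frames, pose dimension]
    out = backbone(x)            # nn.Linear acts on the last dim, so the shape is preserved
    print(out.shape)             # torch.Size([4, 20, 75])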
/configs/data/pennaction_trainval.py:
--------------------------------------------------------------------------------
1 | from configs.data.base import cfg
2 | #from config import ENVCONFIG
3 |
4 | TRAIN_BASE_PATH = "npyrecords/"
5 | cfg.DATASET.TRAINVAL_DATA_SOURCE = "PennAction"
6 | #DATASET_NAME = ENVCONFIG.DATASETNAME
7 | #cfg.DATASET.TRAIN_DATA_ROOT = f"{TRAIN_BASE_PATH}/{DATASET_NAME}_train.npy"
8 | cfg.DATASET.TRAIN_DATA_ROOT = f"{TRAIN_BASE_PATH}"
9 | cfg.DATASET.TRAIN_NPZ_ROOT = ""
10 |
11 |
12 | TEST_BASE_PATH = "npyrecords"
13 | cfg.DATASET.TEST_DATA_SOURCE = "PennAction"
14 | #cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = [f"{TEST_BASE_PATH}/{DATASET_NAME}_train.npy",f"{TEST_BASE_PATH}/{DATASET_NAME}_val.npy"]
15 | cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = [f"{TEST_BASE_PATH}",f"{TEST_BASE_PATH}"]
16 |
17 | cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = ""
18 |
--------------------------------------------------------------------------------
/configs/data/base.py:
--------------------------------------------------------------------------------
1 | """
2 | The data config will be the last one merged into the main config.
3 | Settings in data configs will override all existing settings!
4 | """
5 |
6 | from yacs.config import CfgNode as CN
7 | _CN = CN()
8 | _CN.DATASET = CN()
9 | _CN.TRAINER = CN()
10 |
11 | # training data config
12 | _CN.DATASET.TRAIN_DATA_ROOT = None
13 | _CN.DATASET.TRAIN_POSE_ROOT = None
14 | _CN.DATASET.TRAIN_NPZ_ROOT = None
15 | _CN.DATASET.TRAIN_LIST_PATH = None
16 | # validation set config
17 | _CN.DATASET.VAL_DATA_ROOT = None
18 | _CN.DATASET.VAL_POSE_ROOT = None
19 | _CN.DATASET.VAL_NPZ_ROOT = None
20 | _CN.DATASET.VAL_LIST_PATH = None
21 |
22 | # testing data config
23 | _CN.DATASET.TEST_DATA_ROOT = None
24 | _CN.DATASET.TEST_POSE_ROOT = None
25 | _CN.DATASET.TEST_NPZ_ROOT = None
26 | _CN.DATASET.TEST_LIST_PATH = None
27 |
28 |
29 | cfg = _CN
30 |
--------------------------------------------------------------------------------
/src/casa/backbone/fcl.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class FCL(nn.Module):
6 | """
7 | Fully connected network: two linear layers whose dimension
8 | matches the dimension of the input.
9 | """
10 |
11 | def __init__(self, config):
12 | super().__init__()
13 |
14 | # Config
15 | #print("config['initial_dim']",config)
16 | initial_dim = config['initial_dim']
17 | #block_dims = config['block_dims']
18 |
19 | # Networks
20 | self.fc1 = nn.Linear(initial_dim, initial_dim)
21 | self.fc2 = nn.Linear(initial_dim, initial_dim)
22 | #self.bn1 = nn.BatchNorm2d(initial_dim)
23 | self.relu = nn.ReLU(inplace=True)
24 |
25 | def forward(self, x):
26 | # FCL Backbone
27 | x = self.fc1(x)
28 | x = F.relu(x)
29 | x = self.fc2(x)
30 | output = F.relu(x)
31 |
32 | return output
33 |
--------------------------------------------------------------------------------
/src/casa/backbone/projection_head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class ProjectionHead(nn.Module):
6 | """
7 | Projection head: a two-layer fully connected network mapping
8 | input_dim -> hidden_dim -> output_dim.
9 | """
10 |
11 | def __init__(self, config):
12 | super().__init__()
13 |
14 | # Config
15 | # print("config['initial_dim']",config)
16 | input_dim = config['input_dim']
17 | hidden_dim = config['hidden_dim']
18 | output_dim = config['output_dim']
19 | #block_dims = config['block_dims']
20 |
21 | # Networks
22 | self.fc1 = nn.Linear(input_dim, hidden_dim)
23 | self.fc2 = nn.Linear(hidden_dim, output_dim)
24 | #self.bn1 = nn.BatchNorm2d(initial_dim)
25 | self.relu = nn.ReLU(inplace=True)
26 |
27 | def forward(self, x):
28 | # FCL Backbone
29 | x = self.fc1(x)
30 | x = F.relu(x)
31 | output = self.fc2(x)
32 | #output = F.relu(x)
33 |
34 | return output
35 |
--------------------------------------------------------------------------------
/scripts/train/h2o_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 |
3 | SCRIPTPATH=$(dirname $(readlink -f "$0"))
4 | PROJECT_DIR="${SCRIPTPATH}/../../"
5 |
6 | # conda activate skeletal_alignment
7 | export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
8 | cd $PROJECT_DIR
9 |
10 | data_cfg_path="configs/data/pennaction_trainval.py"
11 | main_cfg_path="configs/sa/sa_ds_h2o.py"
12 |
13 | n_nodes=1
14 | n_gpus_per_node=1
15 | torch_num_workers=0
16 | batch_size=32
17 | pin_memory=true
18 | val_steps=1
19 | exp_name="CASA=$(($n_gpus_per_node * $n_nodes * $batch_size))"
20 | #--benchmark=True \
21 |
22 | python -u ./train.py \
23 | --data_cfg_path=${data_cfg_path} \
24 | --main_cfg_path=${main_cfg_path} \
25 | --exp_name=${exp_name} \
26 | --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
27 | --batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
28 | --check_val_every_n_epoch=${val_steps} \
29 | --log_every_n_steps=100 \
30 | --flush_logs_every_n_steps=${val_steps} \
31 | --limit_val_batches=1. \
32 | --num_sanity_val_steps=0 \
33 | --max_epochs=100 \
34 | --data_folder="./" \
35 | --dataset_name="pouring_milk"
--------------------------------------------------------------------------------
/scripts/train/ikea_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 |
3 | SCRIPTPATH=$(dirname $(readlink -f "$0"))
4 | PROJECT_DIR="${SCRIPTPATH}/../../"
5 |
6 | # conda activate skeletal_alignment
7 | export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
8 | cd $PROJECT_DIR
9 |
10 | data_cfg_path="configs/data/pennaction_trainval.py"
11 | main_cfg_path="configs/sa/sa_ds_ikea.py"
12 |
13 | n_nodes=1
14 | n_gpus_per_node=1
15 | torch_num_workers=0
16 | batch_size=4
17 | pin_memory=true
18 | val_steps=1
19 | exp_name="CASA=$(($n_gpus_per_node * $n_nodes * $batch_size))"
20 | #--benchmark=True \
21 |
22 | python -u ./train.py \
23 | --data_cfg_path=${data_cfg_path} \
24 | --main_cfg_path=${main_cfg_path} \
25 | --exp_name=${exp_name} \
26 | --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
27 | --batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
28 | --check_val_every_n_epoch=${val_steps} \
29 | --log_every_n_steps=100 \
30 | --flush_logs_every_n_steps=${val_steps} \
31 | --limit_val_batches=1. \
32 | --num_sanity_val_steps=0 \
33 | --max_epochs=100 \
34 | --data_folder="./" \
35 | --dataset_name="kallax_shelf_drawer"
--------------------------------------------------------------------------------
/scripts/train/pennaction_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 | echo ${1}
3 |
4 | SCRIPTPATH=$(dirname $(readlink -f "$0"))
5 | PROJECT_DIR="${SCRIPTPATH}/../../"
6 |
7 | # conda activate skeletal_alignment
8 | export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
9 | cd $PROJECT_DIR
10 |
11 | data_cfg_path="configs/data/pennaction_trainval.py"
12 | main_cfg_path="configs/sa/sa_ds_penn.py"
13 |
14 | n_nodes=1
15 | n_gpus_per_node=1
16 | torch_num_workers=0
17 | batch_size=64
18 | pin_memory=true
19 | val_steps=1
20 | exp_name="CASA=$(($n_gpus_per_node * $n_nodes * $batch_size))"
21 | #--benchmark=True \
22 |
23 | python -u ./train.py \
24 | --data_cfg_path=${data_cfg_path} \
25 | --main_cfg_path=${main_cfg_path} \
26 | --exp_name=${exp_name} \
27 | --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
28 | --batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
29 | --check_val_every_n_epoch=${val_steps} \
30 | --log_every_n_steps=100 \
31 | --flush_logs_every_n_steps=${val_steps} \
32 | --limit_val_batches=1. \
33 | --num_sanity_val_steps=0 \
34 | --max_epochs=200 \
35 | --data_folder="./" \
36 | --dataset_name=${1}
--------------------------------------------------------------------------------
/scripts/eval/pennaction_eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -l
2 | echo ${1}
3 | echo ${2}
4 |
5 | SCRIPTPATH=$(dirname $(readlink -f "$0"))
6 | PROJECT_DIR="${SCRIPTPATH}/../../"
7 |
8 | # conda activate skeletal_alignment
9 | export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
10 | cd $PROJECT_DIR
11 |
12 | data_cfg_path="configs/data/pennaction_trainval.py"
13 | main_cfg_path="configs/sa/sa_ds_penn.py"
14 | ckpt_path="./"${2}
15 |
16 | n_nodes=1
17 | n_gpus_per_node=1
18 | torch_num_workers=0
19 | batch_size=1
20 | pin_memory=true
21 | val_steps=1
22 | exp_name="CASA=$(($n_gpus_per_node * $n_nodes * $batch_size))"
23 | #--benchmark=True \
24 |
25 | python -u ./eval.py \
26 | --data_cfg_path=${data_cfg_path} \
27 | --main_cfg_path=${main_cfg_path} \
28 | --exp_name=${exp_name} \
29 | --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \
30 | --batch_size=${batch_size} --num_workers=${torch_num_workers} --pin_memory=${pin_memory} \
31 | --check_val_every_n_epoch=${val_steps} \
32 | --log_every_n_steps=100 \
33 | --flush_logs_every_n_steps=${val_steps} \
34 | --limit_val_batches=1. \
35 | --num_sanity_val_steps=0 \
36 | --max_epochs=200 \
37 | --data_folder="./" \
38 | --ckpt_path=${ckpt_path} \
39 | --videos_dir=${videos_dir} \
40 | --dataset_name=${1}
--------------------------------------------------------------------------------
/src/utils/profiler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler
3 | from contextlib import contextmanager
4 | from pytorch_lightning.utilities import rank_zero_only
5 |
6 |
7 | class InferenceProfiler(SimpleProfiler):
8 | """
9 | This profiler records duration of actions with cuda.synchronize()
10 | Use this at test time.
11 | """
12 |
13 | def __init__(self):
14 | super().__init__()
15 | self.start = rank_zero_only(self.start)
16 | self.stop = rank_zero_only(self.stop)
17 | self.summary = rank_zero_only(self.summary)
18 |
19 | @contextmanager
20 | def profile(self, action_name: str) -> None:
21 | try:
22 | torch.cuda.synchronize()
23 | self.start(action_name)
24 | yield action_name
25 | finally:
26 | torch.cuda.synchronize()
27 | self.stop(action_name)
28 |
29 |
30 | def build_profiler(name):
31 | if name == 'inference':
32 | return InferenceProfiler()
33 | elif name == 'pytorch':
34 | from pytorch_lightning.profiler import PyTorchProfiler
35 | return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
36 | elif name is None:
37 | return PassThroughProfiler()
38 | else:
39 | raise ValueError(f'Invalid profiler: {name}')
40 |
--------------------------------------------------------------------------------
/dataset_splits.py:
--------------------------------------------------------------------------------
1 | """List of subsets."""
2 |
3 | DATASETS = {
4 | 'pouring': {'train': 70, 'val': 14, 'test': 32},
5 | 'baseball_pitch': {'train': 103, 'val': 63, 'max_len':243},
6 | 'baseball_swing': {'train': 113, 'val': 57,'max_len':95},
7 | 'bench_press': {'train': 113, 'val': 57, 'max_len':218},
8 | 'bowling': {'train': 134, 'val': 84,'max_len':564},
9 | 'clean_and_jerk': {'train': 40, 'val': 42,'max_len':663},
10 | 'golf_swing': {'train': 87, 'val': 77,'max_len':95},
11 | 'jumping_jacks': {'train': 56, 'val': 56,'max_len':42},
12 | 'pushups': {'train': 102, 'val': 105,'max_len':189},
13 | 'pullups': {'train': 98, 'val': 100,'max_len':301},
14 | 'situp': {'train': 50, 'val': 50,'max_len':242},
15 | 'squats': {'train': 114, 'val': 115,'max_len':178},
16 | 'tennis_forehand': {'train': 79, 'val': 74,'max_len':95},
17 | 'tennis_serve': {'train': 115, 'val': 68,'max_len':100},
18 | 'kallax_shelf_drawer': {'train': 61, 'val': 29,'max_len':4078},
19 | 'pouring_milk': {'train': 27, 'val': 11,'max_len':865},
20 | }
21 |
22 |
23 | DATASET_TO_NUM_CLASSES = {
24 | 'pouring': 5,
25 | 'baseball_pitch': 4,
26 | 'baseball_swing': 3,
27 | 'bench_press': 2,
28 | 'bowling': 3,
29 | 'clean_and_jerk': 6,
30 | 'golf_swing': 3,
31 | 'jumping_jacks': 4,
32 | 'pushups': 2,
33 | 'pullups': 2,
34 | 'situp': 2,
35 | 'squats': 4,
36 | 'tennis_forehand': 3,
37 | 'tennis_serve': 4,
38 | 'kallax_shelf_drawer': 17,
39 | 'pouring_milk': 10,
40 | }
41 |
--------------------------------------------------------------------------------
/src/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
3 |
4 |
5 | def build_optimizer(model, config):
6 | name = config.TRAINER.OPTIMIZER
7 | lr = config.TRAINER.TRUE_LR
8 |
9 | if name == "adam":
10 | return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
11 | elif name == "adamw":
12 | return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
13 | else:
14 | raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
15 |
16 |
17 | def build_scheduler(config, optimizer):
18 | """
19 | Returns:
20 | scheduler (dict):{
21 | 'scheduler': lr_scheduler,
22 | 'interval': 'step', # or 'epoch'
23 | 'monitor': 'val_f1', (optional)
24 | 'frequency': x, (optional)
25 | }
26 | """
27 | scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
28 | name = config.TRAINER.SCHEDULER
29 |
30 | if name == 'MultiStepLR':
31 | scheduler.update(
32 | {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
33 | elif name == 'CosineAnnealing':
34 | scheduler.update(
35 | {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
36 | elif name == 'ExponentialLR':
37 | scheduler.update(
38 | {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
39 | else:
40 | raise NotImplementedError()
41 |
42 | return scheduler
43 |
--------------------------------------------------------------------------------
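A minimal sketch of building the optimizer and scheduler from a yacs config; the field values are illustrative and a plain nn.Linear stands in for the CASA model:

    import torch.nn as nn
    from yacs.config import CfgNode as CN
    from src.optimizers import build_optimizer, build_scheduler

    cfg = CN()
    cfg.TRAINER = CN()
    cfg.TRAINER.OPTIMIZER = "adam"
    cfg.TRAINER.TRUE_LR = 3e-4
    cfg.TRAINER.ADAM_DECAY = 0.0
    cfg.TRAINER.SCHEDULER = "MultiStepLR"
    cfg.TRAINER.SCHEDULER_INTERVAL = "epoch"
    cfg.TRAINER.MSLR_MILESTONES = [30, 60, 90]
    cfg.TRAINER.MSLR_GAMMA = 0.5

    model = nn.Linear(75, 75)                    # stand-in for the real model
    optimizer = build_optimizer(model, cfg)
    scheduler = build_scheduler(cfg, optimizer)  # {'interval': 'epoch', 'scheduler': MultiStepLR(...)}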
/src/evaluation/task_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 |
7 |
8 | def regression_labels_for_class(labels, class_idx,LAST=False):
9 | # Assumes labels are ordered. Find the last occurrence of particular class.
10 | #print("labels",labels)
11 | #print("class_idx",class_idx)
12 | #print("np.argwhere(labels == class_idx)",np.argwhere(labels == class_idx))
13 | if LAST:
14 | transition_frame = len(labels)
15 | else:
16 | transition_frame = np.argwhere(labels == class_idx)[-1, 0]
17 | return (np.arange(float(len(labels))) - transition_frame) / len(labels)
18 |
19 |
20 |
21 | def get_regression_labels(class_labels, num_classes):
22 | regression_labels = []
23 | for i in range(num_classes - 1):
24 | if i in class_labels:
25 | regression_labels.append(regression_labels_for_class(class_labels, i))
26 | else:
27 | if i == num_classes - 2:
28 | regression_labels.append(regression_labels_for_class(class_labels, i,LAST=True))
29 | print("last",regression_labels_for_class(class_labels, i,LAST=True))
30 | else:
31 | regression_labels.append(regression_labels[i-1])
32 | return np.stack(regression_labels, axis=1)
33 |
34 |
35 | def get_targets_from_labels(all_class_labels, num_classes):
36 | all_regression_labels = []
37 | for class_labels in all_class_labels:
38 | all_regression_labels.append(get_regression_labels(class_labels,
39 | num_classes))
40 | return all_regression_labels
41 |
42 |
43 | def unnormalize(preds):
44 | seq_len = len(preds)
45 | return np.mean([i - pred * seq_len for i, pred in enumerate(preds)])
46 |
--------------------------------------------------------------------------------
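A small worked example of get_regression_labels on hypothetical, ordered phase labels; it returns one regression target per non-final phase, counting frames relative to that phase's last occurrence and normalizing by sequence length:

    import numpy as np
    from src.evaluation.task_utils import get_regression_labels

    class_labels = np.array([0, 0, 1, 1, 2, 2])  # 6 frames, 3 phases
    targets = get_regression_labels(class_labels, num_classes=3)
    print(targets.shape)                         # (6, 2)
    # column 0: (frame index - 1) / 6, since the last frame labelled 0 is index 1
    # column 1: (frame index - 3) / 6, since the last frame labelled 1 is index 3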
/src/utils/augment.py:
--------------------------------------------------------------------------------
1 | import albumentations as A
2 |
3 |
4 | class DarkAug(object):
5 | """
6 | Extreme dark augmentation aiming at Aachen Day-Night
7 | """
8 |
9 | def __init__(self) -> None:
10 | self.augmentor = A.Compose([
11 | A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
12 | A.Blur(p=0.1, blur_limit=(3, 9)),
13 | A.MotionBlur(p=0.2, blur_limit=(3, 25)),
14 | A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
15 | A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
16 | ], p=0.75)
17 |
18 | def __call__(self, x):
19 | return self.augmentor(image=x)['image']
20 |
21 |
22 | class MobileAug(object):
23 | """
24 | Random augmentations aiming at images of mobile/handhold devices.
25 | """
26 |
27 | def __init__(self):
28 | self.augmentor = A.Compose([
29 | A.MotionBlur(p=0.25),
30 | A.ColorJitter(p=0.5),
31 | A.RandomRain(p=0.1), # random occlusion
32 | A.RandomSunFlare(p=0.1),
33 | A.JpegCompression(p=0.25),
34 | A.ISONoise(p=0.25)
35 | ], p=1.0)
36 |
37 | def __call__(self, x):
38 | return self.augmentor(image=x)['image']
39 |
40 |
41 | def build_augmentor(method=None, **kwargs):
42 | if method is not None:
43 | raise NotImplementedError('Use of augmentation functions is not supported yet!')
44 | if method == 'dark':
45 | return DarkAug()
46 | elif method == 'mobile':
47 | return MobileAug()
48 | elif method is None:
49 | return None
50 | else:
51 | raise ValueError(f'Invalid augmentation method: {method}')
52 |
53 |
54 | if __name__ == '__main__':
55 | augmentor = build_augmentor('FDA')
56 |
--------------------------------------------------------------------------------
/dataset_preparation/penn_action/mat2json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import scipy.io
3 | import os
4 | import tqdm
5 | if __name__ == "__main__":
6 | CHECK_EMPTY = True
7 | dataset_path = ''
8 | label_path = os.path.join(dataset_path, 'labels')
9 | bbox_path = os.path.join(dataset_path, 'bbox')
10 | mocap_path = os.path.join(dataset_path, 'mocap')
11 |
12 | for action in tqdm.tqdm(range(1, 2327)): # 2135)):#2327
13 | #action = 27
14 | mat = scipy.io.loadmat(os.path.join(
15 | label_path, '{0:04d}.mat'.format(action)))
16 | # print("mat",mat)
17 | dict_frank = {}
18 | action_path = os.path.join(dataset_path, 'bbox/{0:04d}'.format(action))
19 |
20 | if CHECK_EMPTY:
21 | bbox_len = len(os.listdir(os.path.join(
22 | mocap_path, "{0:04d}".format(action), "bbox")))
23 | # print("matframes",mat['nframes'][0][0])
24 | # print("bbox_len",bbox_len)
25 | if mat['nframes'][0][0] != bbox_len:
26 | print("action", action)
27 | else:
28 | if not os.path.exists(action_path):
29 | os.mkdir(action_path)
30 | for frame, bbox in enumerate(mat['bbox']):
31 | x = float(bbox[0])
32 | y = float(bbox[1])
33 | w = float(bbox[2]-bbox[0])
34 | h = float(bbox[3]-bbox[1])
35 | dict_frank = {"image_path": "{0}/frames/{1:04d}/{2:06d}.jpg".format(
36 | dataset_path, action, frame+1), "body_bbox_list": [[x, y, w, h]]}
37 | # print("dict_frank",dict_frank)
38 | with open(os.path.join(dataset_path, 'bbox', '{0:04d}/{1:06d}.json'.format(action, frame+1)), 'w') as outfile:
39 | json.dump(dict_frank, outfile)
40 | #{"image_path": "xxx.jpg", "hand_bbox_list":[{"left_hand":[x,y,w,h], "right_hand":[x,y,w,h]}], "body_bbox_list":[[x,y,w,h]]}
41 |
--------------------------------------------------------------------------------
/src/casa/utils/position_encoding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn, Tensor
4 |
5 | class PositionalEncoding(nn.Module):
6 |
7 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
8 | super().__init__()
9 |
10 | self.d_model = d_model  # embedding dimension
11 | self.max_len = max_len #maximum length of the sequence
12 | self.dropout = nn.Dropout(p=dropout)
13 |
14 | position = torch.arange(max_len).unsqueeze(1)
15 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
16 | pe = torch.zeros(max_len,1, d_model)
17 | pe[:,0, 0::2] = torch.sin(position * div_term)
18 | pe[:,0, 1::2] = torch.cos(position * div_term)
19 |
20 | self.register_buffer('pe', pe)
21 |
22 | def forward(self, x: Tensor, steps, len) -> Tensor:
23 | """
24 | Args:
25 | x: Tensor, shape [batch_size, seq_len, embedding_dim]
26 | steps: Normalized steps in the sequence. # batch(512),seq length (20)
27 | len: valid length of each sequence in the batch (only used when NORM is True).
28 | self.pe has shape [max_len, 1, d_model], e.g. [5000, 1, 75].
29 | """
30 | #x = torch.transpose(x,0,1) # now, seq_len is the first location
31 |
32 | NORM = False
33 | batch, seq_len,dim = x.shape
34 | x = x * math.sqrt(self.d_model)
35 | if NORM:
36 | len = torch.unsqueeze(len,1)
37 | steps = steps/(len+2)
38 | recv_steps = steps * self.max_len
39 | recv_steps = torch.round(recv_steps)
40 | recv_steps = torch.minimum(recv_steps, torch.tensor(self.max_len-1).cuda())
41 | emb_steps = self.pe[recv_steps.type(torch.LongTensor),:,:x.size(2)]
42 | else:
43 | emb_steps = self.pe[steps.type(torch.LongTensor),:,:x.size(2)]
44 | emb_steps = emb_steps[:,:,0,:]
45 | #print("emb_steps",emb_steps.shape)
46 | x = x + emb_steps
47 |
48 | #before
49 | #x = x + self.pe[:x.size(0),:,:x.size(2)]
50 |
51 | #x = torch.transpose(x,0,1)
52 | #return self.dropout(x)
53 | return x
--------------------------------------------------------------------------------
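A minimal forward-pass sketch for PositionalEncoding on the default NORM=False path, where steps are used directly as indices into the precomputed table; the shapes and values are illustrative:

    import torch
    from src.casa.utils.position_encoding import PositionalEncoding

    pe = PositionalEncoding(d_model=75, max_len=5000)
    x = torch.randn(2, 20, 75)              # [batch, sampled frames, pose dimension]
    steps = torch.arange(20).repeat(2, 1)   # frame indices of the sampled steps
    seq_len = torch.tensor([20, 20])        # only used when NORM is True
    out = pe(x, steps, seq_len)             # x * sqrt(d_model) + pe[steps]
    print(out.shape)                        # torch.Size([2, 20, 75])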
/configs/sa/sa_ds_penn.py:
--------------------------------------------------------------------------------
1 | from pickle import TRUE
2 | from src.config.default import _CN as cfg
3 |
4 | # 'dual_softmax' # 'dual_bicross'
5 | cfg.CASA.MATCH.MATCH_TYPE = 'dual_softmax'
6 | cfg.CASA.MATCH.MATCH_ALGO = None
7 |
8 | #cfg.DATASET.TRAIN_PAIR = False
9 | #cfg.DATASET.VAL_PAIR = False
10 | cfg.DATASET.MAX_LENGTH = 250
11 | cfg.DATASET.VAL_BATCH_SIZE = 1
12 |
13 | cfg.EVAL.KENDALLS_TAU = True
14 | cfg.EVAL.KENDALLS_TAU_STRIDE = 2 # 5 for Pouring, 2 for PennAction
15 | cfg.EVAL.KENDALLS_TAU_DISTANCE = 'sqeuclidean' # cosine, sqeuclidean
16 | cfg.EVAL.EVENT_COMPLETION = True
17 |
18 | #cfg.DATASET.NUM_STEPS = 1
19 |
20 | # cfg.TRAINER.SCALING = None # this will be calculated automatically
21 | cfg.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
22 | cfg.CASA.MATCH.NHEAD = 15
23 | cfg.CASA.MATCH.D_MODEL = 75 # dimension of the pose data, 75 for PennAction
24 | cfg.TRAINER.CANONICAL_LR = 3e-4 # 3e-3 # 1e-3 0.00005#
25 | cfg.CONSTRASTIVE.TRAIN = True
26 | cfg.CASA.MATCH.USE_PRIOR = True
27 | cfg.DATASET.ATT_STYLE = True
28 |
29 | cfg.TRAINER.OPTIMIZER = "adam" # adamw
30 |
31 | cfg.DATASET.NUM_FRAMES = 20 # 20
32 | cfg.CLASSIFICATION.ACC_LIST = [0.1, 0.5, 1.0]
33 |
34 | # Parameters from TCC,
35 | # number of frames that will be embedded jointly, #2 for conv_embedder 1 for casa
36 | cfg.DATASET.NUM_STEPS = 1
37 | cfg.DATASET.FRAME_STRIDE = 15 # stride between context frames
38 |
39 | cfg.DATASET.NAME = ""
40 | cfg.DATASET.SMPL = True
41 | cfg.DATASET.USE_NORM = True
42 | cfg.CASA.MATCH.VIS_CONF_TRAIN = False
43 | cfg.CASA.MATCH.VIS_CONF_VALIDATION = False
44 | cfg.TRAINER.WARMUP_STEP = 0
45 | cfg.TRAINER.SCHEDULER_INTERVAL = 'epoch'
46 | # [20, 40, 60, 80, 100, 120, 140] # MSLR: MultiStepLR[70, 100, 120, 150] #
47 | cfg.TRAINER.MSLR_MILESTONES = [30, 40, 50, 60, 70, 80, 90, 100]
48 | cfg.CONSTRASTIVE.AUGMENTATION_STRATEGY = [
49 | 'fast', 'noise_angle', 'flip', 'noise_translation', 'noise_vposer']
50 | #'noise_vposer' 'translation', 'scale', 'shuffle', 'crop', 'rotation','flip' 'center' 'noise_translation' 'noise_angle' ,'fast'
51 |
52 |
53 | cfg.CASA.PH.TRUE = True # projection head
54 | cfg.CASA.MATCH.PE = True # positional encoding
55 | cfg.CASA.MATCH.LAYER_NAMES = ['self', 'cross'] * 4
56 |
57 | #cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]
58 | cfg.CASA.EMBEDDER_TYPE = 'casa' # casa conv_embedder
59 | cfg.CASA.NUM_FRAMES = cfg.DATASET.NUM_FRAMES
60 | # classification regression regression_var
61 | cfg.CASA.LOSS.LOSS_TYPE = 'regression'
62 |
--------------------------------------------------------------------------------
/src/casa/utils/supervision.py:
--------------------------------------------------------------------------------
1 | from math import log
2 | from loguru import logger
3 |
4 | import torch
5 | from einops import repeat
6 | #from kornia.utils import create_meshgrid
7 | import numpy as np
8 | import scipy.stats as stats
9 | import math
10 |
11 | @torch.no_grad()
12 | def mask_pts_at_padded_regions(grid_pt, mask):
13 | mask = repeat(mask, 'n h w -> n (h w) c', c=2)
14 | grid_pt[~mask.bool()] = 0
15 | return grid_pt
16 |
17 |
18 | @torch.no_grad()
19 | def spvs_coarse(data, config):
20 | """
21 | Update:
22 | data (dict): {
23 | 'conf_matrix_prior': [N, T0, T1]
24 | }
25 | where conf_matrix_prior is the prior/ground-truth frame-matching matrix
26 | between the two sequences: during contrastive training its rows are
27 | one-hot vectors built from the known 'matching' indices; otherwise it
28 | is a diagonal (identity) prior repeated over the batch.
29 |
30 | """
31 | # 1. misc
32 | device = data['keypoints0'].device
33 | N0, T0, K0, D0 = data['keypoints0'].shape
34 | N1, T1, K1, D1 = data['keypoints1'].shape
35 | max_len_t0 = data['len0']
36 | max_len_t1 = data['len1']
37 |
38 | #print("N0, T0, K0, D0", N0, T0, K0, D0)
39 | # Gaussian prior matrix
40 | mu = 0
41 | variance = 1
42 | sigma = math.sqrt(variance)
43 | #x = np.linspace(mu - 2*sigma, mu + 2*sigma, T0*2)
44 | #y = stats.norm.pdf(x, mu, sigma)
45 | if config.CONSTRASTIVE.TRAIN:
46 | conf_matrix_prior = torch.zeros(N0, T0, T1, device=device)
47 | for nn in range(N0):
48 | for ii in range(max_len_t0[nn]):
49 | if data['matching'][nn][ii] == -1:
50 | continue
51 | jj_index = data['matching'][nn][ii]
52 | conf_matrix_prior[nn][ii][jj_index] = 1
53 |
54 | else:
55 | conf_matrix_prior = torch.ones(T0, T1, device=device)
56 | for ii in range(T0):
57 | for jj in range(T1):
58 | # Put a Gaussian distribution
59 | #conf_matrix_prior[ii][jj] = y[T0+abs(ii-jj)]
60 | # Or, just diagonal matrix
61 | if (ii != jj): # and abs(ii-jj) != 1:
62 | conf_matrix_prior[ii][jj] = 0
63 | conf_matrix_prior = conf_matrix_prior.repeat(N0, 1, 1)
64 |
65 | data.update({'conf_matrix_prior': conf_matrix_prior})
66 |
67 |
68 | def compute_supervision_coarse(data, config):
69 | assert len(set(data['dataset_name'])
70 | ) == 1, "Do not support mixed datasets training!"
71 | data_source = data['dataset_name'][0]
72 | if data_source.lower() in ['pennaction', 'h2o', 'ikea_asm']:
73 | spvs_coarse(data, config)
74 | else:
75 | raise ValueError(f'Unknown data source: {data_source}')
76 |
77 |
--------------------------------------------------------------------------------
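A minimal sketch of spvs_coarse building the prior matching matrix for one hypothetical pair of 5-frame skeleton sequences; the 'matching' entry (frame i of sequence 0 matches frame i of sequence 1) and all shapes are illustrative assumptions:

    import torch
    from yacs.config import CfgNode as CN
    from src.casa.utils.supervision import spvs_coarse

    data = {
        'keypoints0': torch.randn(1, 5, 25, 3),
        'keypoints1': torch.randn(1, 5, 25, 3),
        'len0': torch.tensor([5]),
        'len1': torch.tensor([5]),
        'matching': torch.tensor([[0, 1, 2, 3, 4]]),
    }
    config = CN()
    config.CONSTRASTIVE = CN()
    config.CONSTRASTIVE.TRAIN = True

    spvs_coarse(data, config)
    print(data['conf_matrix_prior'].shape)  # torch.Size([1, 5, 5]), one-hot rows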
/dataset_preparation/penn_action/read_pose.py:
--------------------------------------------------------------------------------
1 | import json
2 | import scipy.io
3 | import os
4 | import tqdm
5 | import pickle
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from mpl_toolkits.mplot3d import Axes3D
9 |
10 |
11 | def get_penn_connectivity():
12 | return [
13 | [0, 1],
14 | [1, 4],
15 | [4, 7],
16 | [7, 10],
17 | [0, 2],
18 | [2, 5],
19 | [5, 8],
20 | [8, 11],
21 | [0, 3],
22 | [3, 6],
23 | [6, 9],
24 | [9, 13],
25 | [13, 16],
26 | [16, 18],
27 | [18, 20],
28 | [20, 22],
29 | [9, 12],
30 | [12, 15],
31 | [9, 14],
32 | [14, 17],
33 | [17, 19],
34 | [19, 21],
35 | [21, 23]
36 | ]
37 |
38 |
39 | def get_openpose_connectivity():
40 | return [
41 | [0, 1],
42 | [1, 2],
43 | [2, 3],
44 | [3, 4],
45 | [1, 5],
46 | [5, 6],
47 | [6, 7],
48 | [1, 8],
49 | [8, 9],
50 | [9, 10],
51 | [10, 11],
52 | [11, 24],
53 | [11, 22],
54 | [22, 23],
55 | [8, 12],
56 | [12, 13],
57 | [13, 14],
58 | [14, 21],
59 | [14, 19],
60 | [19, 20],
61 | [0, 15],
62 | [15, 17],
63 | [0, 16],
64 | [16, 18]
65 | ]
66 |
67 |
68 | if __name__ == "__main__":
69 | VIS = True
70 | dataset_path = ''
71 | label_path = os.path.join(dataset_path, 'labels')
72 | bbox_path = os.path.join(dataset_path, 'bbox')
73 | mocap_path = os.path.join(dataset_path, 'mocap')
74 |
75 | # frame = 1
76 | pose_path = os.path.join(mocap_path, "0001/mocap")
77 | for frame in range(1, 100):
78 | with open(os.path.join(pose_path, '{0:06d}_prediction_result.pkl'.format(frame)), 'rb') as f:
79 | data = pickle.load(f)
80 |
81 | # body_pose = np.reshape(
82 | # (data['pred_output_list'][0]['pred_body_pose'][0]), (24, 3)) # (24, 3)
83 | # 25,3
84 | body_pose = data['pred_output_list'][0]['pred_body_joints_img'][:25]
85 | # print("body_pose", body_pose)
86 |
87 | if VIS:
88 | s = body_pose
89 | fig = plt.figure()
90 | ax = plt.axes(projection='3d')
91 | ax.set_xlim3d(-200, 200)
92 | ax.set_ylim3d(-200, 200)
93 | ax.set_zlim3d(-200, 200)
94 | connectivity = get_openpose_connectivity()
95 | for limb in connectivity:
96 | ax.plot3D(s[limb, 0], s[limb, 1], s[limb, 2])
97 | ax.scatter3D(s[:, 0], s[:, 1],
98 | s[:, 2], cmap='Greens')
99 | plt.show(block=False)
100 | plt.pause(0.5)
101 | plt.close()
102 | # visualize
103 |
--------------------------------------------------------------------------------
/configs/sa/sa_ds_ikea.py:
--------------------------------------------------------------------------------
1 | from pickle import TRUE
2 | from src.config.default import _CN as cfg
3 |
4 | # 'dual_softmax' # 'dual_bicross'
5 | cfg.CASA.MATCH.MATCH_TYPE = 'dual_softmax'
6 | cfg.CASA.MATCH.MATCH_ALGO = None
7 |
8 | cfg.DATASET.MAX_LENGTH = 250
9 | cfg.DATASET.VAL_BATCH_SIZE = 256
10 |
11 | cfg.EVAL.KENDALLS_TAU = True
12 | cfg.EVAL.KENDALLS_TAU_STRIDE = 2 # 5 for Pouring, 2 for PennAction
13 | cfg.EVAL.KENDALLS_TAU_DISTANCE = 'sqeuclidean' # cosine, sqeuclidean
14 | cfg.EVAL.EVENT_COMPLETION = True
15 |
16 | #cfg.DATASET.NUM_STEPS = 1
17 |
18 | # cfg.TRAINER.SCALING = None # this will be calculated automatically
19 | cfg.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
20 | cfg.CASA.MATCH.NHEAD = 17 # ikea 17 or 3
21 | cfg.CASA.MATCH.D_MODEL = 51 # dimension of the data, ikea_asm 51 pennaction 75
22 | cfg.TRAINER.CANONICAL_LR = 3e-2 # 3e-3 # 1e-3 0.00005#
23 | cfg.CONSTRASTIVE.TRAIN = True
24 | cfg.CASA.MATCH.USE_PRIOR = True
25 | cfg.DATASET.ATT_STYLE = True
26 |
27 |
28 | cfg.CASA.FCL.INITIAL_DIM = 51 # For pose of IKEA ASM => 51, PENN ACTION =>75
29 | # For pose of IKEA ASM => 51, PENN ACTION =>75
30 | cfg.CASA.PH.OUTPUT_DIM = cfg.CASA.PH.INPUT_DIM = 51
31 | cfg.CASA.PH.HIDDEN_DIM = 51
32 |
33 | cfg.TRAINER.OPTIMIZER = "adam" # adamw
34 |
35 | cfg.DATASET.NUM_FRAMES = 20 # 20
36 | cfg.CLASSIFICATION.ACC_LIST = [0.1, 0.5, 1.0]
37 |
38 | cfg.EVAL.KENDALLS_TAU = False
39 | cfg.EVAL.EVENT_COMPLETION = False
40 |
41 | # Parameters from TCC,
42 | # number of frames that will be embedded jointly, #2 for conv_embedder 1 for casa
43 | cfg.DATASET.NUM_STEPS = 1
44 | cfg.DATASET.FRAME_STRIDE = 15 # stride between context frames
45 |
46 | cfg.DATASET.NAME = "kallax_shelf_drawer" # 20
47 | cfg.DATASET.SMPL = False
48 | cfg.DATASET.USE_NORM = True
49 | cfg.CASA.MATCH.VIS_CONF_TRAIN = False
50 | cfg.CASA.MATCH.VIS_CONF_VALIDATION = False
51 | cfg.TRAINER.WARMUP_STEP = 0
52 | cfg.TRAINER.SCHEDULER_INTERVAL = 'epoch'
53 | # [20, 40, 60, 80, 100, 120, 140] # MSLR: MultiStepLR[70, 100, 120, 150] #
54 | cfg.TRAINER.MSLR_MILESTONES = [30, 40, 50, 60, 70, 80, 90, 100]
55 | # cfg.CASA.MATCH.VIS_CONF_TRAIN
56 | # ['fast','noise_vposer', 'noise_translation' 'noise_angle','flip']
57 | cfg.CONSTRASTIVE.AUGMENTATION_STRATEGY = ['fast']
58 | #'noise_vposer' 'translation', 'scale', 'shuffle', 'crop', 'rotation','flip' 'center' 'noise_translation' 'noise_angle' ,'fast'
59 |
60 |
61 | cfg.CASA.PH.TRUE = True # projection head
62 | cfg.CASA.MATCH.PE = True # positional encoding
63 | cfg.CASA.MATCH.LAYER_NAMES = ['self', 'cross'] * 4
64 | cfg.CASA.MATCH.SIMILARITY = True
65 |
66 | #cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]
67 |
68 |
69 | # TCC embbed network parameters
70 | # List of conv layers defined as (channels, kernel_size, activate).
71 | cfg.CASA.EMBEDDER_TYPE = 'casa' # casa conv_embedder
72 |
73 | cfg.CASA.NUM_FRAMES = cfg.DATASET.NUM_FRAMES
74 |
75 | # classification regression regression_var
76 | cfg.CASA.LOSS.LOSS_TYPE = 'regression'
77 |
--------------------------------------------------------------------------------
/configs/sa/sa_ds_h2o.py:
--------------------------------------------------------------------------------
1 | from pickle import TRUE
2 | from src.config.default import _CN as cfg
3 |
4 | # 'dual_softmax' # 'dual_bicross'
5 | cfg.CASA.MATCH.MATCH_TYPE = 'dual_softmax'
6 | cfg.CASA.MATCH.MATCH_ALGO = None
7 |
8 | #cfg.DATASET.TRAIN_PAIR = False
9 | #cfg.DATASET.VAL_PAIR = False
10 | cfg.DATASET.MAX_LENGTH = 250
11 | cfg.DATASET.VAL_BATCH_SIZE = 1#256
12 |
13 | cfg.EVAL.KENDALLS_TAU = True
14 | cfg.EVAL.KENDALLS_TAU_STRIDE = 2 # 5 for Pouring, 2 for PennAction
15 | cfg.EVAL.KENDALLS_TAU_DISTANCE = 'sqeuclidean' # cosine, sqeuclidean
16 | cfg.EVAL.EVENT_COMPLETION = True
17 |
18 | #cfg.DATASET.NUM_STEPS = 1
19 |
20 | # cfg.TRAINER.SCALING = None # this will be calculated automatically
21 | cfg.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
22 | cfg.CASA.MATCH.NHEAD = 21
23 | cfg.CASA.MATCH.D_MODEL = 126 # dimension of the pose data, 126 for H2O
24 | cfg.TRAINER.CANONICAL_LR = 3e-4 # 3e-3 # 1e-3 0.00005#
25 | cfg.CONSTRASTIVE.TRAIN = True
26 | cfg.CASA.MATCH.USE_PRIOR = True
27 | cfg.DATASET.ATT_STYLE = True
28 |
29 |
30 | cfg.CASA.FCL.INITIAL_DIM = 126 # For pose of IKEA ASM => 51, PENN ACTION =>75, H2O => 126
31 | # For pose of IKEA ASM => 51, PENN ACTION =>75
32 | cfg.CASA.PH.OUTPUT_DIM = cfg.CASA.PH.INPUT_DIM = 126
33 | cfg.CASA.PH.HIDDEN_DIM = 126
34 |
35 | cfg.TRAINER.OPTIMIZER = "adam" # adamw
36 |
37 | cfg.DATASET.NUM_FRAMES = 20 # 20
38 | cfg.CLASSIFICATION.ACC_LIST = [0.1, 0.5, 1.0]
39 |
40 | # Parameters from TCC,
41 | # number of frames that will be embedded jointly, #2 for conv_embedder 1 for casa
42 | cfg.DATASET.NUM_STEPS = 1
43 | cfg.DATASET.FRAME_STRIDE = 15 # stride between context frames
44 |
45 | cfg.DATASET.NAME = "pouring_milk" # 20
46 | cfg.DATASET.SMPL = False
47 | cfg.DATASET.MANO = True
48 | cfg.DATASET.USE_NORM = True
49 | cfg.CASA.MATCH.VIS_CONF_TRAIN = False
50 | cfg.CASA.MATCH.VIS_CONF_VALIDATION = False
51 | cfg.TRAINER.WARMUP_STEP = 0
52 | cfg.TRAINER.SCHEDULER_INTERVAL = 'epoch'
53 | # [20, 40, 60, 80, 100, 120, 140] # MSLR: MultiStepLR[70, 100, 120, 150] #
54 | cfg.TRAINER.MSLR_MILESTONES = [30, 40, 50, 60, 70, 80, 90, 100]
55 | # cfg.CASA.MATCH.VIS_CONF_TRAIN
56 | # ['fast','noise_vposer', 'noise_translation' 'noise_angle','flip']
57 | cfg.CONSTRASTIVE.AUGMENTATION_STRATEGY = ['fast','flip','noise_translation', 'noise_angle']#,'noise_translation','noise_vposer']
58 | #'noise_vposer' 'translation', 'scale', 'shuffle', 'crop', 'rotation','flip' 'center' 'noise_translation' 'noise_angle' ,'fast'
59 |
60 |
61 |
62 | cfg.CASA.PH.TRUE = True #projection head
63 | cfg.CASA.MATCH.PE = True #positional encoding
64 | cfg.CASA.MATCH.LAYER_NAMES = ['self', 'cross'] * 4
65 |
66 | #cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29]
67 |
68 |
69 | # TCC embbed network parameters
70 | # List of conv layers defined as (channels, kernel_size, activate).
71 | cfg.CASA.EMBEDDER_TYPE = 'casa' # casa conv_embedder
72 |
73 | cfg.CASA.NUM_FRAMES = cfg.DATASET.NUM_FRAMES
74 |
75 | # classification regression regression_var
76 | cfg.CASA.LOSS.LOSS_TYPE = 'regression'
77 | #cfg.CASA.L2_REG_WEIGHT = 0.00001
--------------------------------------------------------------------------------
/src/casa/casa_module/linear_attention.py:
--------------------------------------------------------------------------------
1 | """
2 | Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
3 | Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
4 | """
5 |
6 | import torch
7 | from torch.nn import Module, Dropout
8 |
9 |
10 | def elu_feature_map(x):
11 | return torch.nn.functional.elu(x) + 1
12 |
13 |
14 | class LinearAttention(Module):
15 | def __init__(self, eps=1e-6):
16 | super().__init__()
17 | self.feature_map = elu_feature_map
18 | self.eps = eps
19 |
20 | def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
21 | """ Multi-Head linear attention proposed in "Transformers are RNNs"
22 | Args:
23 | queries: [N, L, H, D]
24 | keys: [N, S, H, D]
25 | values: [N, S, H, D]
26 | q_mask: [N, L]
27 | kv_mask: [N, S]
28 | Returns:
29 | queried_values: (N, L, H, D)
30 | """
31 | Q = self.feature_map(queries)
32 | K = self.feature_map(keys)
33 |
34 | # set padded position to zero
35 | if q_mask is not None:
36 | Q = Q * q_mask[:, :, None, None]
37 | if kv_mask is not None:
38 | K = K * kv_mask[:, :, None, None]
39 | values = values * kv_mask[:, :, None, None]
40 |
41 | v_length = values.size(1)
42 | values = values / v_length # prevent fp16 overflow
43 | KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
44 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
45 | queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
46 |
47 | return queried_values.contiguous()
48 |
49 |
50 | class FullAttention(Module):
51 | def __init__(self, use_dropout=False, attention_dropout=0.1):
52 | super().__init__()
53 | self.use_dropout = use_dropout
54 | self.dropout = Dropout(attention_dropout)
55 |
56 | def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
57 | """ Multi-head scaled dot-product attention, a.k.a full attention.
58 | Args:
59 | queries: [N, L, H, D]
60 | keys: [N, S, H, D]
61 | values: [N, S, H, D]
62 | q_mask: [N, L]
63 | kv_mask: [N, S]
64 | Returns:
65 | queried_values: (N, L, H, D)
66 | """
67 |
68 | # Compute the unnormalized attention and apply the masks
69 | QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
70 | if kv_mask is not None:
71 | QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
72 |
73 | # Compute the attention and the weighted average
74 | softmax_temp = 1. / queries.size(3)**.5 # 1 / sqrt(D)
75 | A = torch.softmax(softmax_temp * QK, dim=2)
76 | if self.use_dropout:
77 | A = self.dropout(A)
78 |
79 | queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
80 |
81 | return queried_values.contiguous()
82 |
--------------------------------------------------------------------------------
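A quick shape check of both attention variants on random tensors; the sizes (15 heads of dimension 5, matching d_model = 75) follow the PennAction config and are otherwise arbitrary:

    import torch
    from src.casa.casa_module.linear_attention import LinearAttention, FullAttention

    N, L, S, H, D = 2, 20, 20, 15, 5   # batch, query len, key len, heads, head dim
    q = torch.randn(N, L, H, D)
    k = torch.randn(N, S, H, D)
    v = torch.randn(N, S, H, D)

    out_linear = LinearAttention()(q, k, v)  # kernelized attention, linear in sequence length
    out_full = FullAttention()(q, k, v)      # standard softmax attention
    print(out_linear.shape, out_full.shape)  # torch.Size([2, 20, 15, 5]) each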
/src/evaluation/kendalls_tau.py:
--------------------------------------------------------------------------------
1 | r"""Evaluation train and val loss using the algo.
2 | """
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | from absl import flags
9 | from absl import logging
10 |
11 | import numpy as np
12 | from scipy.spatial.distance import cdist
13 | from scipy.stats import kendalltau
14 |
15 | from loguru import logger as loguru_logger
16 | import copy
17 | #FLAGS = flags.FLAGS
18 |
19 |
20 | def _get_kendalls_tau(embs_list, stride, tau_dist):
21 | """Get nearest neighbours in embedding space and calculate Kendall's Tau."""
22 | num_seqs = len(embs_list)
23 |
24 | taus = np.zeros((num_seqs * (num_seqs - 1)))
25 | idx = 0
26 | for i in range(num_seqs):
27 | query_feats = embs_list[i][::stride]
28 | for j in range(num_seqs):
29 | if i == j:
30 | continue
31 | candidate_feats = embs_list[j][::stride]
32 | dists = cdist(query_feats, candidate_feats,
33 | tau_dist)
34 | nns = np.argmin(dists, axis=1)
35 | taus[idx] = kendalltau(np.arange(len(nns)), nns).correlation
36 | idx += 1
37 |
38 | # Remove NaNs.
39 | taus = taus[~np.isnan(taus)]
40 | tau = np.mean(taus)
41 |
42 | # logging.info('Iter[{}/{}] {} set alignment tau: {:.4f}'.format(
43 | # global_step.numpy(), CONFIG.TRAIN.MAX_ITERS, split, tau))
44 |
45 | #tf.summary.scalar('kendalls_tau/%s_align_tau' % split, tau, step=global_step)
46 | return tau
47 |
48 |
49 | class KendallsTau():
50 | """Calculate Kendall's Tau."""
51 |
52 | def __init__(self, conf):
53 | super(KendallsTau, self).__init__()
54 | self.conf = conf
55 |
56 | def evaluate_embeddings(self, datasets_ori):
57 | """Labeled evaluation."""
58 |
59 | datasets = copy.deepcopy(datasets_ori)
60 |
61 |
62 | train_emb = []
63 | train_label = []
64 | val_emb = []
65 | val_label = []
66 |
67 | for key, emb in datasets['train_dataset']['embs'].items():
68 | train_emb.append(np.average(np.array(emb), axis=0))
69 | train_label.append(
70 | datasets['train_dataset']['labels'][key][0])
71 |
72 | for key, emb in datasets['val_dataset']['embs'].items():
73 | val_emb.append(np.average(np.array(emb), axis=0))
74 | val_label.append(datasets['val_dataset']['labels'][key][0])
75 |
76 | datasets['train_dataset']['embs'] = train_emb
77 | datasets['train_dataset']['labels'] = train_label
78 | datasets['val_dataset']['embs'] = val_emb
79 | datasets['val_dataset']['labels'] = val_label
80 |
81 | train_embs = datasets['train_dataset']['embs']
82 |
83 | train_tau = _get_kendalls_tau(
84 | train_embs,
85 | self.conf.EVAL.KENDALLS_TAU_STRIDE, self.conf.EVAL.KENDALLS_TAU_DISTANCE)
86 |
87 | val_embs = datasets['val_dataset']['embs']
88 |
89 | val_tau = _get_kendalls_tau(
90 | val_embs, self.conf.EVAL.KENDALLS_TAU_STRIDE, self.conf.EVAL.KENDALLS_TAU_DISTANCE)
91 |
92 | loguru_logger.info('train set alignment tau: {:.5f}'.format(train_tau))
93 | loguru_logger.info('val set alignment tau: {:.5f}'.format(val_tau))
94 | return train_tau, val_tau
95 |
--------------------------------------------------------------------------------
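A minimal sketch of _get_kendalls_tau on random embeddings (so the resulting tau will be near zero); real usage passes per-video embedding sequences collected during validation:

    import numpy as np
    from src.evaluation.kendalls_tau import _get_kendalls_tau

    embs_list = [np.random.rand(n, 8) for n in (40, 50, 60)]  # 3 sequences of 8-dim embeddings
    tau = _get_kendalls_tau(embs_list, stride=2, tau_dist='sqeuclidean')
    print(tau)  # average Kendall's Tau over all ordered sequence pairs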
/src/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import random
4 |
5 |
6 | def collate_stack(batch):
7 | rt_dataset = {}
8 | for key_name in batch[0].keys():
9 | elem_list = []
10 | for item in batch:
11 | if type(item[key_name]) == str or type(item[key_name]) == tuple:
12 | elem_list.append(item[key_name])
13 | else:
14 | elem_list.append(torch.FloatTensor(item[key_name]))
15 | # np.shape(elem_list)
16 | rt_dataset[key_name] = elem_list
17 | return rt_dataset
18 |
19 |
20 | def collate_fixed_len(batch, num_frames, sampling_strategy):
21 | if sampling_strategy == 'offset_uniform':
22 | # To access the dataset directly
23 | def _sample_random(item_len):
24 | steps = random.sample(
25 | range(1, item_len), num_frames)
26 | return sorted(steps)
27 |
28 | def _sample_all():
29 | return list(range(0, num_frames))
30 |
31 | def sampled_num(nparray, steps):
32 | return nparray[steps]
33 |
34 | rt_dataset = {}
35 |
36 | for key_name in batch[0].keys():
37 | rt_dataset[key_name] = []
38 |
39 |
40 | for item in batch:
41 |
42 | len0 = len(item['label0'])
43 | check0 = (num_frames <= len0)
44 | if check0:
45 | steps0 = _sample_random(len0)
46 | else:
47 | steps0 = _sample_all()
48 |
49 | check1 = (num_frames <= len(item['label1']))
50 | len1 = len(item['label1'])
51 | if check1:
52 | steps1 = _sample_random(len1)
53 | else:
54 | steps1 = _sample_all()
55 |
56 | elem = sampled_num(np.array(item["keypoints0"]), steps0)
57 | rt_dataset["keypoints0"].append(elem)
58 | elem = sampled_num(np.array(item["keypoints1"]), steps1)
59 | rt_dataset["keypoints1"].append(elem)
60 | elem = sampled_num(np.array(item["label0"]), steps0)
61 | rt_dataset["label0"].append(elem)
62 | elem = sampled_num(np.array(item["label1"]), steps1)
63 | rt_dataset["label1"].append(elem)
64 | rt_dataset["dataset_name"].append(item["dataset_name"])
65 | rt_dataset["pair_id"].append(item["pair_id"])
66 | rt_dataset["pair_names"].append(item["pair_names"])
67 |
68 |
69 | rt_dataset["keypoints0"] = torch.FloatTensor(np.array(rt_dataset["keypoints0"], dtype=float))
70 | rt_dataset["keypoints1"] = torch.FloatTensor(np.array(rt_dataset["keypoints1"], dtype=float))
71 | rt_dataset["label0"] = np.array(rt_dataset["label0"], dtype=int)
72 | rt_dataset["label1"] = np.array(rt_dataset["label1"], dtype=int)
73 | rt_dataset["pair_id"] = np.array(rt_dataset["pair_id"], dtype=int)
74 |
75 | else:
76 | raise NotImplementedError(f'Unknown sampling strategy: {sampling_strategy}')
77 | return rt_dataset
78 |
79 |
80 | def get_local_split(items: list, world_size: int, rank: int, seed: int):
81 | """ The local rank only loads a split of the dataset. """
82 | n_items = len(items)
83 | items_permute = np.random.RandomState(seed).permutation(items)
84 | if n_items % world_size == 0:
85 | padded_items = items_permute
86 | else:
87 | padding = np.random.RandomState(seed).choice(
88 | items,
89 | world_size - (n_items % world_size),
90 | replace=True)
91 | padded_items = np.concatenate([items_permute, padding])
92 | assert len(padded_items) % world_size == 0, \
93 | f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
94 | n_per_rank = len(padded_items) // world_size
95 | local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
96 |
97 | return local_items
98 |
--------------------------------------------------------------------------------
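A minimal sketch of get_local_split distributing 10 hypothetical sequence ids over 4 ranks; ids are re-sampled to pad the list to a multiple of world_size, and every rank sees the same seeded permutation:

    from src.utils.dataloader import get_local_split

    items = list(range(10))
    for rank in range(4):
        local = get_local_split(items, world_size=4, rank=rank, seed=66)
        print(rank, local)  # 3 items per rank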
/src/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import contextlib
3 | import joblib
4 | from typing import Union
5 | from loguru import _Logger, logger
6 | from itertools import chain
7 |
8 | import torch
9 | from yacs.config import CfgNode as CN
10 | from pytorch_lightning.utilities import rank_zero_only
11 |
12 |
13 | def lower_config(yacs_cfg):
14 | if not isinstance(yacs_cfg, CN):
15 | return yacs_cfg
16 | return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
17 |
18 |
19 | def upper_config(dict_cfg):
20 | if not isinstance(dict_cfg, dict):
21 | return dict_cfg
22 | return {k.upper(): upper_config(v) for k, v in dict_cfg.items()}
23 |
24 |
25 | def log_on(condition, message, level):
26 | if condition:
27 | assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
28 | logger.log(level, message)
29 |
30 |
31 | def get_rank_zero_only_logger(logger: _Logger):
32 | if rank_zero_only.rank == 0:
33 | return logger
34 | else:
35 | for _level in logger._core.levels.keys():
36 | level = _level.lower()
37 | setattr(logger, level,
38 | lambda x: None)
39 | logger._log = lambda x: None
40 | return logger
41 |
42 |
43 | def setup_gpus(gpus: Union[str, int]) -> int:
44 | """ A temporary fix for pytorch-lighting 1.3.x """
45 | gpus = str(gpus)
46 | gpu_ids = []
47 |
48 | if ',' not in gpus:
49 | n_gpus = int(gpus)
50 | return n_gpus if n_gpus != -1 else torch.cuda.device_count()
51 | else:
52 | gpu_ids = [i.strip() for i in gpus.split(',') if i != '']
53 |
54 | # setup environment variables
55 | visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
56 | if visible_devices is None:
57 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
58 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
59 | visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
60 | logger.warning(
61 | f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
62 | else:
63 | logger.warning(
64 | '[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
65 | return len(gpu_ids)
66 |
67 |
68 | def flattenList(x):
69 | return list(chain(*x))
70 |
71 |
72 | @contextlib.contextmanager
73 | def tqdm_joblib(tqdm_object):
74 | """Context manager to patch joblib to report into tqdm progress bar given as argument
75 |
76 | Usage:
77 | with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
78 | Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
79 |
80 | When iterating over a generator, direct use of tqdm is also a solution (but it monitors task queuing instead of completion)
81 | ret_vals = Parallel(n_jobs=args.world_size)(
82 | delayed(lambda x: _compute_cov_score(pid, *x))(param)
83 | for param in tqdm(combinations(image_ids, 2),
84 | desc=f'Computing cov_score of [{pid}]',
85 | total=len(image_ids)*(len(image_ids)-1)/2))
86 | Src: https://stackoverflow.com/a/58936697
87 | """
88 | class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
89 | def __init__(self, *args, **kwargs):
90 | super().__init__(*args, **kwargs)
91 |
92 | def __call__(self, *args, **kwargs):
93 | tqdm_object.update(n=self.batch_size)
94 | return super().__call__(*args, **kwargs)
95 |
96 | old_batch_callback = joblib.parallel.BatchCompletionCallBack
97 | joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
98 | try:
99 | yield tqdm_object
100 | finally:
101 | joblib.parallel.BatchCompletionCallBack = old_batch_callback
102 | tqdm_object.close()
103 |
--------------------------------------------------------------------------------
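A minimal sketch of lower_config / upper_config converting between the yacs config and the plain nested dict the model code consumes; the keys below are illustrative:

    from yacs.config import CfgNode as CN
    from src.utils.misc import lower_config, upper_config

    cfg = CN()
    cfg.CASA = CN()
    cfg.CASA.MATCH = CN()
    cfg.CASA.MATCH.NHEAD = 15

    plain = lower_config(cfg)       # {'casa': {'match': {'nhead': 15}}}
    restored = upper_config(plain)  # {'CASA': {'MATCH': {'NHEAD': 15}}}
    print(plain, restored)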
/src/casa/casa_module/transformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | from .linear_attention import LinearAttention, FullAttention
5 |
6 |
7 | class CASAEncoderLayer(nn.Module):
8 | def __init__(self,
9 | d_model,
10 | nhead,
11 | attention='linear'):
12 | super(CASAEncoderLayer, self).__init__()
13 |
14 | self.dim = d_model // nhead
15 | self.nhead = nhead
16 |
17 | # multi-head attention
18 | self.q_proj = nn.Linear(d_model, d_model, bias=False)
19 | self.k_proj = nn.Linear(d_model, d_model, bias=False)
20 | self.v_proj = nn.Linear(d_model, d_model, bias=False)
21 | self.attention = LinearAttention() if attention == 'linear' else FullAttention()
22 | self.merge = nn.Linear(d_model, d_model, bias=False)
23 |
24 | # feed-forward network
25 | self.mlp = nn.Sequential(
26 | nn.Linear(d_model*2, d_model*2, bias=False),
27 | nn.ReLU(True),
28 | nn.Linear(d_model*2, d_model, bias=False),
29 | )
30 |
31 | # norm and dropout
32 | self.norm1 = nn.LayerNorm(d_model)
33 | self.norm2 = nn.LayerNorm(d_model)
34 |
35 | def forward(self, x, source, x_mask=None, source_mask=None):
36 | """
37 | Args:
38 | x (torch.Tensor): [N, L, C]
39 | source (torch.Tensor): [N, S, C]
40 | x_mask (torch.Tensor): [N, L] (optional)
41 | source_mask (torch.Tensor): [N, S] (optional)
42 | """
43 | bs = x.size(0)
44 |
45 | #print("x.shape",x.shape)
46 | query, key, value = x, source, source
47 | #print("query.shape",query.shape)
48 | # multi-head attention
49 | #print("self.q_proj(query).shape",self.q_proj(query).shape)
50 | query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
51 | key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
52 | value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
53 | message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
54 | message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
55 | message = self.norm1(message)
56 |
57 | # feed-forward network
58 | message = self.mlp(torch.cat([x, message], dim=2))
59 | message = self.norm2(message)
60 |
61 | return x + message
62 |
63 |
64 | class LocalFeatureTransformer(nn.Module):
65 | """A Local Feature Transformer (CASA) module."""
66 |
67 | def __init__(self, config):
68 | super(LocalFeatureTransformer, self).__init__()
69 |
70 | self.config = config
71 | self.d_model = config['d_model']
72 | self.nhead = config['nhead']
73 | self.layer_names = config['layer_names']
74 | encoder_layer = CASAEncoderLayer(config['d_model'], config['nhead'], config['attention'])
75 | self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
76 | self._reset_parameters()
77 |
78 | def _reset_parameters(self):
79 | for p in self.parameters():
80 | if p.dim() > 1:
81 | nn.init.xavier_uniform_(p)
82 |
83 | def forward(self, feat0, feat1, mask0=None, mask1=None):
84 | """
85 | Args:
86 | feat0 (torch.Tensor): [N, L, C]
87 | feat1 (torch.Tensor): [N, S, C]
88 | mask0 (torch.Tensor): [N, L] (optional)
89 | mask1 (torch.Tensor): [N, S] (optional)
90 | """
91 | #print("feat0.size(2)",feat0.size(2))
92 |
93 | assert self.d_model == feat0.size(2), "the feature dimension of the inputs must equal the transformer's d_model"
94 |
95 | for layer, name in zip(self.layers, self.layer_names):
96 | if name == 'self':
97 | feat0 = layer(feat0, feat0, mask0, mask0)
98 | feat1 = layer(feat1, feat1, mask1, mask1)
99 | elif name == 'cross':
100 | feat0 = layer(feat0, feat1, mask0, mask1)
101 | feat1 = layer(feat1, feat0, mask1, mask0)
102 | else:
103 | raise KeyError
104 |
105 | return feat0, feat1
106 |
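107 | # Usage sketch (illustration only; not part of the original file). The config keys below
108 | # mirror the _CN.CASA.MATCH defaults in src/config/default.py (d_model=75, nhead=5, linear attention).
109 | #   cfg = {'d_model': 75, 'nhead': 5, 'layer_names': ['self', 'cross'] * 4, 'attention': 'linear'}
110 | #   transformer = LocalFeatureTransformer(cfg)
111 | #   feat0, feat1 = transformer(torch.rand(2, 20, 75), torch.rand(2, 20, 75))  # [N, T, C] in and out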
--------------------------------------------------------------------------------
/joint_ids.py:
--------------------------------------------------------------------------------
1 | # Joint IDs and their connectivity
2 |
3 | import numpy as np
4 |
5 | def get_joint_names_dict(joint_names):
6 | return {name: i for i, name in enumerate(joint_names)}
7 |
8 | def get_ikea_joint_names():
9 | return [
10 | "nose", # 0
11 | "left eye", # 1
12 | "right eye", # 2
13 | "left ear", # 3
14 | "right ear", # 4
15 | "left shoulder", # 5
16 | "right shoulder", # 6
17 | "left elbow", # 7
18 | "right elbow", # 8
19 | "left wrist", # 9
20 | "right wrist", # 10
21 | "left hip", # 11
22 | "right hip", # 12
23 | "left knee", # 13
24 | "right knee", # 14
25 | "left ankle", # 15
26 | "right ankle", # 16
27 | ]
28 |
29 | def get_ikea_connectivity():
30 | return [
31 | [0, 1],
32 | [0, 2],
33 | [1, 3],
34 | [2, 4],
35 | [0, 5],
36 | [0, 6],
37 | [5, 6],
38 | [5, 7],
39 | [6, 8],
40 | [7, 9],
41 | [8, 10],
42 | [5, 11],
43 | [6, 12],
44 | [11, 12],
45 | [11, 13],
46 | [12, 14],
47 | [13, 15],
48 | [14, 16]
49 | ]
50 |
51 | def get_body25_joint_names():
52 | return [
53 | "nose", # 0
54 | "neck", # 1
55 | "right shoulder", # 2
56 | "right elbow", # 3
57 | "right wrist", # 4
58 | "left shoulder", # 5
59 | "left elbow", # 6
60 | "left wrist", # 7
61 | "mid hip", # 8
62 | "right hip", # 9
63 | "right knee", # 10
64 | "right ankle", # 11
65 | "left hip", # 12
66 | "left knee", # 13
67 | "left ankle", # 14
68 | "right eye", # 15
69 | "left eye", # 16
70 | "right ear", # 17
71 | "left ear", # 18
72 | "left big toe", # 19
73 | "left small toe", # 20
74 | "left heel", # 21
75 | "right big toe", # 22
76 | "right small toe", # 23
77 | "right heel", # 24
78 | "background", # 25
79 | ]
80 |
81 | def get_body25_connectivity():
82 | return [
83 | [0, 1],
84 | [1, 2],
85 | [2, 3],
86 | [3, 4],
87 | [1, 5],
88 | [5, 6],
89 | [6, 7],
90 | [1, 8],
91 | [8, 9],
92 | [9, 10],
93 | [10, 11],
94 | [8, 12],
95 | [12, 13],
96 | [13, 14],
97 | [0, 15],
98 | [0, 16],
99 | [15, 17],
100 | [16, 18],
101 | [2, 9],
102 | [5, 12],
103 | [11, 22],
104 | [11, 23],
105 | [11, 24],
106 | [14, 19],
107 | [14, 20],
108 | [14, 21],
109 | ]
110 |
111 |
112 | def get_body21_joint_names():
113 | return [
114 | "nose", # 0
115 | "neck", # 1
116 | "right shoulder", # 2
117 | "right elbow", # 3
118 | "right wrist", # 4
119 | "left shoulder", # 5
120 | "left elbow", # 6
121 | "left wrist", # 7
122 | "mid hip", # 8
123 | "right hip", # 9
124 | "right knee", # 10
125 | "right ankle", # 11
126 | "left hip", # 12
127 | "left knee", # 13
128 | "left ankle", # 14
129 | "right eye", # 15
130 | "left eye", # 16
131 | "right ear", # 17
132 | "left ear", # 18
133 | "neck (lsp)", # 19
134 | "top of head (lsp)", # 20
135 | ]
136 |
137 | def get_hmmr_joint_names():
138 | return [
139 | "right ankle", # 0
140 | "right knee", # 1
141 | "right hip", # 2
142 | "left hip", # 3
143 | "left knee", # 4
144 | "left ankle", # 5
145 | "right wrist", # 6
146 | "right elbow", # 7
147 | "right shoulder", # 8
148 | "left shoulder", # 9
149 | "left elbow", # 10
150 | "left wrist", # 11
151 | "neck", # 12
152 | "top of head", # 13
153 | "nose", # 14
154 | "left eye", # 15
155 | "right eye", # 16
156 | "left ear", # 17
157 | "right ear", # 18
158 | "left big toe", # 19
159 | "right big toe", # 20
160 | "left small toe", # 21
161 | "right small toe", # 22
162 | "left heel", # 23
163 | "right heel", # 24
164 | ]
165 |
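166 | # Usage sketch (illustration only; not part of the original file):
167 | #   ids = get_joint_names_dict(get_ikea_joint_names())
168 | #   ids["left wrist"]            # -> 9
169 | #   get_ikea_connectivity()[10]  # -> [8, 10], i.e. right elbow -> right wrist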
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Context-Aware Sequence Alignment using 4D Skeletal Augmentation
2 | Taein Kwon, Bugra Tekin, Siyu Tang, and Marc Pollefeys
3 |
4 |
5 | This is the code for Context-Aware Sequence Alignment using 4D Skeletal Augmentation, CVPR 2022 (oral). You can find more details about the paper on our [project page](https://taeinkwon.com/projects/casa/). Note that we referred to [LoFTR](https://github.com/zju3dv/LoFTR) when implementing our framework.
6 |
7 | ## Environment Setup
8 | To set up the environment:
9 | ```
10 | git clone https://github.com/taeinkwon/CASA.git
11 | cd CASA
12 | conda env create -f env.yml
13 | conda activate CASA
14 | ```
15 |
16 | ## External Dependencies
17 |
18 | ### Folder Structures
19 |
20 | .
21 | ├── bodymocap
22 | ├── extra_data
23 | │   └── body_module
24 | │       └── J_regressor_extra_smplx.npy
25 | ├── human_body_prior
26 | │   └── ...
27 | ├── manopth
28 | │   ├── __init__.py
29 | │   ├── argutils.py
30 | │   └── ...
31 | ├── smpl
32 | │   └── models
33 | │       └── SMPLX_NEUTRAL.pkl
34 | ├── mano
35 | │   ├── models
36 | │   │   ├── MANO_LEFT.pkl
37 | │   │   └── MANO_RIGHT.pkl
38 | │   ├── webuser
39 | │   ├── __init__.py
40 | │   └── LICENSE.txt
41 | ├── npyrecords
42 | ├── scripts
43 | ├── src
44 | └── ...
45 |
46 |
47 | ### Install MANO and Manopth
48 | In this repository, we use the [MANO](https://mano.is.tue.mpg.de/) model from MPI and parts of [Yana](https://hassony2.github.io/)'s code for hand pose alignment.
49 | - Clone [manopth](https://github.com/hassony2/manopth) with ```git clone https://github.com/hassony2/manopth.git``` and copy the ```manopth``` and ```mano``` folders (inside) into the CASA folder.
50 | - Go to the [MANO website](https://mano.is.tue.mpg.de/), download the models and code, and put them in ```CASA/mano/models```.
51 | - In ```smpl_handpca_wrapper_HAND_only.py```, please change the following lines so it runs under Python 3: L23: ```import cPickle as pickle``` -> ```import pickle```; L30: ```dd = pickle.load(open(fname_or_dict))``` -> ```dd = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1')```; L144: ```print 'FINITO'``` -> ```print('FINITO')```.
52 |
53 | ### Install VPoser
54 | We used VPoser to augment body poses.
55 | - Clone the [VPoser](https://github.com/nghorbani/human_body_prior) repository: ```git clone https://github.com/nghorbani/human_body_prior.git```
56 | - Copy the human_body_prior folder into the CASA folder.
57 | - Go into the human_body_prior folder and run setup.py:
58 | ```
59 | cd human_body_prior
60 | python setup.py develop
61 | ```
62 |
63 | ### Install SMPL
64 | We use the SMPL model for body pose alignment.
65 | - Download [SMPL-X](https://smpl-x.is.tue.mpg.de/) ver. 1.1 and VPoser v2.0.
66 | - Put ```SMPLX_NEUTRAL.pkl``` into the ```CASA/smpl/models``` folder.
67 | - Copy the VPoser files into ```CASA/human_body_prior/support_data/downloads/vposer_v2_05/```.
68 | - In order to use the joints from FrankMocap, you additionally need to clone [FrankMocap](https://github.com/facebookresearch/frankmocap) with ```git clone https://github.com/facebookresearch/frankmocap.git``` and put the ```bodymocap``` folder into the ```CASA``` folder.
69 | - Then, run ```sh download_extra_data.sh``` to get the J_regressor_extra_smplx.npy file.
70 |
71 | ## Datasets
72 | You can download the npy files [here](https://drive.google.com/file/d/16Kgy8iESC-0YwqELxfE9mWu24Jxzsu1C/view?usp=sharing). The npy files contain normalized joints and labels. In order to get the original data, you should go to each dataset's website and download the dataset there.
73 |
74 | ### H2O
75 | We selected the ```pouring milk``` sequences and manually divided them into train and test sets with the new labels we defined. Please go to the [H2O project page](https://taeinkwon.com/projects/h2o/) and download the dataset there.
76 |
77 | ### PennAction
78 | We estimated 3D joints using [FrankMocap](https://github.com/facebookresearch/frankmocap) for the [Penn Action dataset](https://dreamdragon.github.io/PennAction/). Penn Action has 13 different actions: baseball_pitch, baseball_swing, bench_press, bowling, clean_and_jerk, golf_swing, jumping_jacks, pushups, pullups, situp, squats, tennis_forehand, tennis_serve.
79 |
80 | ### IkeaASM
81 | We downloaded and used the 3D joints from triangulation of 2D poses in the [IkeaASM dataset](https://ikeaasm.github.io/).
82 |
83 | ## Train
84 | To train on the Penn Action dataset,
85 | ```
86 | sh scripts/train/pennaction_train.sh ${dataset_name}
87 | ```
88 | For example,
89 | ```
90 | sh scripts/train/pennaction_train.sh tennis_serve
91 | ```
92 |
93 | ## Eval
94 | We also provide pre-trained models. To evaluate a pre-trained model,
95 |
96 | ```
97 | sh scripts/eval/pennaction_eval.sh ${dataset_name} ${eval_model_path}
98 | ```
99 | For example,
100 | ```
101 | sh scripts/eval/pennaction_eval.sh tennis_serve logs/tennis_serve/CASA=64/version_0/checkpoints/last.ckpt
102 | ```
103 |
104 | ## License
105 | Note that our code follows the Apache License 2.0. However, external libraries follow their own licenses.
--------------------------------------------------------------------------------
/src/casa/casa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops.einops import rearrange
4 |
5 | from .backbone import build_backbone
6 | from .utils.position_encoding import PositionalEncoding
7 | from .casa_module import LocalFeatureTransformer
8 | from .utils.matching import Matching
9 | import tqdm
10 |
11 |
12 | class CASA(nn.Module):
13 |     def __init__(self, config):
14 |         super().__init__()
15 |         # Misc
16 |         self.config = config
17 |
18 |         # Modules
19 |         self.backbone = build_backbone(config)
20 |         self.ph = config['ph']['true']
21 |         if self.ph:
22 |             self.backbone_ph = build_backbone(config['ph'])
23 |         if config['match']['d_model'] % 2 == 0:
24 |             self.pos_encoding = PositionalEncoding(
25 |                 config['match']['d_model'])  # To make it an even number
26 |         else:
27 |             self.pos_encoding = PositionalEncoding(
28 |                 config['match']['d_model'] + 1)  # To make it an even number
29 |         self.casa_coarse = LocalFeatureTransformer(config['match'])
30 |         self.matching = Matching(config['match'])
31 |
32 |
33 |     def forward(self, data, train=True):
34 |         """
35 |         Update:
36 |             data (dict): {
37 |                 'keypoints0': (torch.Tensor): (N, T, K, D), e.g. (1, 85, 25, 3)
38 |                 'keypoints1': (torch.Tensor): (N, T, K, D)
39 |             }
40 |         """
41 |         if train:
42 |             # For the training set, we have two inputs: the original sequence and the 4D-augmented sequence.
43 |
44 |             # 1. Local Feature FCL
45 |             data.update({
46 |                 'bs': data['keypoints0'].size(0),
47 |                 'hw0_i': data['keypoints0'].shape[1], 'hw1_i': data['keypoints1'].shape[1]
48 |             })
49 |             # else:  # handle different input shapes
50 |             kp0 = torch.reshape(
51 |                 data['keypoints0'], (data['keypoints0'].shape[0], data['keypoints0'].shape[1], -1)).float()
52 |             kp1 = torch.reshape(
53 |                 data['keypoints1'], (data['keypoints1'].shape[0], data['keypoints1'].shape[1], -1)).float()
54 |             feat_f0, feat_f1 = self.backbone(kp0), self.backbone(kp1)
55 |
56 |             data.update({
57 |                 # N T (K *D)
58 |                 'len_t0': feat_f0.shape[1], 'len_t1': feat_f1.shape[1], 'len_d': feat_f0.shape[2]
59 |             })
60 |
61 |             # 2. Matching
62 |             # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
63 |             # positional encoding for CASA
64 |             if self.config['match']['pe']:
65 |                 feat_f0 = self.pos_encoding(
66 |                     feat_f0, data['steps0'], data['len0'])
67 |                 feat_f1 = self.pos_encoding(
68 |                     feat_f1, data['steps1'], data['len1'])
69 |             # CASA Matching
70 |             feat_f0, feat_f1 = self.casa_coarse(
71 |                 feat_f0, feat_f1)
72 |
73 |             data.update({
74 |                 # N T (K *D)
75 |                 'emb0': feat_f0, 'emb1': feat_f1
76 |             })
77 |
78 |             # 3.
Projection Head 79 | # emb N T K*D 80 | if self.ph: 81 | z0 = self.backbone_ph(data['emb0']) 82 | z1 = self.backbone_ph(data['emb1']) 83 | data.update({ 84 | # N T (K *D) 85 | 'z0': z0, 86 | 'z1': z1 87 | }) 88 | self.matching(z0, z1, data) 89 | else: 90 | self.matching(feat_f0, feat_f1, data) 91 | else: 92 | # For val set, we only need one input, the original sequence. 93 | data.update({ 94 | 'bs': data['keypoints0'].size(0), 95 | 'hw0_i': data['keypoints0'].shape[1] 96 | }) 97 | 98 | kp0 = torch.reshape( 99 | data['keypoints0'], (data['keypoints0'].shape[0], data['keypoints0'].shape[1], -1)).float() 100 | feat_f0 = self.backbone(kp0) 101 | data.update({ 102 | # N T (K *D) 103 | 'len_t0': feat_f0.shape[1], 'len_d': feat_f0.shape[2] 104 | }) 105 | # add featmap with positional encoding, then flatten it to sequence [N, HW, C] 106 | if self.config['match']['pe']: 107 | feat_f0 = self.pos_encoding( 108 | feat_f0, data['steps0'], data['len0']) 109 | # CASA Matching, we use the same sequence to get the features. 110 | feat_f0, feat_f1 = self.casa_coarse( 111 | feat_f0, feat_f0) 112 | # For val, we only use the latent space before projection head. 113 | data.update({ 114 | # N T (K *D) 115 | 'emb0': feat_f0 116 | }) 117 | 118 | 119 | 120 | def load_state_dict(self, state_dict, *args, **kwargs): 121 | for k in list(state_dict.keys()): 122 | if k.startswith('matcher.'): 123 | state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) 124 | return super().load_state_dict(state_dict, *args, **kwargs) 125 | -------------------------------------------------------------------------------- /src/config/default.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _CN = CN() 4 | 5 | ############## ↓ CASA Pipeline ↓ ############## 6 | _CN.CASA = CN() 7 | _CN.CASA.BACKBONE_TYPE = 'FCL' 8 | _CN.CASA.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] 9 | _CN.CASA.NUM_FRAMES = 20 10 | # 1. CASA-backbone (local feature CNN) config 11 | _CN.CASA.RESNETFPN = CN() 12 | _CN.CASA.RESNETFPN.INITIAL_DIM = 128 13 | _CN.CASA.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 14 | # FCL config 15 | _CN.CASA.FCL = CN() 16 | _CN.CASA.FCL.INITIAL_DIM = 75 # For pose of IKEA ASM => 51, PENN ACTION =>75 17 | # Projection Head config 18 | _CN.CASA.PH = CN() 19 | _CN.CASA.PH.TRUE = False 20 | _CN.CASA.PH.BACKBONE_TYPE = 'PH' 21 | # For pose of IKEA ASM => 51, PENN ACTION =>75 22 | _CN.CASA.PH.OUTPUT_DIM = _CN.CASA.PH.INPUT_DIM = 75 23 | _CN.CASA.PH.HIDDEN_DIM = 75 # For pose of IKEA ASM => 51, PENN ACTION =>75 24 | 25 | # 2. CASA module config 26 | _CN.CASA.MATCH = CN() 27 | _CN.CASA.MATCH.PE = True # should be enven number #256 28 | _CN.CASA.MATCH.D_MODEL = 75 # should be enven number #256 29 | _CN.CASA.MATCH.D_FFN = 256 30 | _CN.CASA.MATCH.NHEAD = 5 # 8 31 | _CN.CASA.MATCH.LAYER_NAMES = ['self', 'cross'] * 4 32 | _CN.CASA.MATCH.ATTENTION = 'linear' # options: ['linear', 'full'] 33 | _CN.CASA.MATCH.TEMP_BUG_FIX = True 34 | _CN.CASA.MATCH.VIS_CONF_TRAIN = False 35 | _CN.CASA.MATCH.VIS_CONF_VALIDATION = True 36 | _CN.CASA.MATCH.USE_PRIOR = False 37 | _CN.CASA.MATCH.THR = 0.1 38 | _CN.CASA.MATCH.BORDER_RM = 2 39 | # options: ['dual_softmax, 'bicross'] 40 | _CN.CASA.MATCH.MATCH_TYPE = 'dual_softmax' 41 | _CN.CASA.MATCH.MATCH_ALGO = None 42 | _CN.CASA.MATCH.DSMAX_TEMPERATURE = 0.1 43 | _CN.CASA.MATCH.SIMILARITY = False 44 | 45 | # 3. 
CASA Losses 46 | # -- # coarse-level 47 | _CN.CASA.LOSS = CN() 48 | _CN.CASA.LOSS.TYPE = 'cross_entropy' # ['focal', 'cross_entropy'] 49 | # ['classification', 'regression','regression_var] 50 | _CN.CASA.LOSS.LOSS_TYPE = 'regression' 51 | _CN.CASA.LOSS.WEIGHT = 1.0 52 | # -- - -- # focal loss (coarse) 53 | _CN.CASA.LOSS.FOCAL_ALPHA = 0.25 54 | _CN.CASA.LOSS.FOCAL_GAMMA = 2.0 55 | _CN.CASA.LOSS.POS_WEIGHT = 1.0 56 | _CN.CASA.LOSS.NEG_WEIGHT = 1.0 57 | 58 | _CN.CASA.EMBEDDER_TYPE = 'casa' # casa, conv_embedder 59 | 60 | _CN.CONSTRASTIVE = CN() 61 | _CN.CONSTRASTIVE.TRAIN = False 62 | _CN.CONSTRASTIVE.AUGMENTATION_STRATEGY = ['shuffle'] 63 | 64 | _CN.CLASSIFICATION = CN() 65 | _CN.CLASSIFICATION.ACC_LIST = [0.1, 0.5, 1.0] 66 | 67 | ############## Dataset ############## 68 | _CN.DATASET = CN() 69 | # 1. data config 70 | _CN.DATASET.NUM_FRAMES = 20 71 | _CN.DATASET.NAME = "name" 72 | _CN.DATASET.SAMPLING_STRATEGY = 'offset_uniform' 73 | _CN.DATASET.FOLDER = "./" 74 | _CN.DATASET.LOGDIR = './logs' 75 | 76 | # Parameters from TCC, 77 | _CN.DATASET.NUM_STEPS = 1 # number of frames that will be embedded jointly, 78 | _CN.DATASET.FRAME_STRIDE = 15 # stride between context frames 79 | 80 | _CN.DATASET.MAX_LENGTH = 250 81 | _CN.DATASET.ATT_STYLE = False 82 | _CN.DATASET.USE_NORM = True 83 | _CN.DATASET.MANO = False 84 | _CN.DATASET.SMPL = False 85 | _CN.DATASET.TRAINVAL_DATA_SOURCE = None 86 | _CN.DATASET.TRAIN_DATA_ROOT = None 87 | _CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses) 88 | _CN.DATASET.TRAIN_NPZ_ROOT = None 89 | _CN.DATASET.TRAIN_LIST_PATH = None 90 | _CN.DATASET.VAL_DATA_ROOT = None 91 | _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) 92 | _CN.DATASET.VAL_NPZ_ROOT = None 93 | # None if val data from all scenes are bundled into a single npz file 94 | _CN.DATASET.VAL_LIST_PATH = None 95 | _CN.DATASET.VAL_BATCH_SIZE = 1 96 | # testing 97 | _CN.DATASET.TEST_DATA_SOURCE = None 98 | _CN.DATASET.TEST_DATA_ROOT = None 99 | _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) 100 | _CN.DATASET.TEST_NPZ_ROOT = None 101 | # None if test data from all scenes are bundled into a single npz file 102 | _CN.DATASET.TEST_LIST_PATH = None 103 | 104 | # 2. dataset config 105 | # general options 106 | _CN.DATASET.AUGMENTATION_TYPE = None 107 | 108 | _CN.EVAL = CN() 109 | _CN.EVAL.EVENT_COMPLETION = False 110 | _CN.EVAL.KENDALLS_TAU = False 111 | _CN.EVAL.KENDALLS_TAU_STRIDE = 2 # 5 for Pouring, 2 for PennAction 112 | _CN.EVAL.KENDALLS_TAU_DISTANCE = 'sqeuclidean' # cosine, sqeuclidean 113 | ############## Trainer ############## 114 | _CN.TRAINER = CN() 115 | _CN.TRAINER.WORLD_SIZE = 1 116 | _CN.TRAINER.CANONICAL_BS = 64 117 | _CN.TRAINER.CANONICAL_LR = 1e-3 # 6e-3 118 | _CN.TRAINER.SCALING = None # this will be calculated automatically 119 | _CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning 120 | 121 | # optimizer 122 | _CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw] 123 | _CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime 124 | _CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam 125 | _CN.TRAINER.ADAMW_DECAY = 0.1 126 | 127 | # step-based warm-up 128 | _CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant] 129 | _CN.TRAINER.WARMUP_RATIO = 0. 
130 | _CN.TRAINER.WARMUP_STEP = 100 131 | 132 | # learning rate scheduler 133 | # [MultiStepLR, CosineAnnealing, ExponentialLR] 134 | _CN.TRAINER.SCHEDULER = 'MultiStepLR' 135 | _CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step] 136 | _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR 137 | _CN.TRAINER.MSLR_GAMMA = 0.5 138 | _CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing 139 | # ELR: ExponentialLR, this value for 'step' interval 140 | _CN.TRAINER.ELR_GAMMA = 0.999992 141 | 142 | # gradient clipping 143 | _CN.TRAINER.GRADIENT_CLIPPING = 0.5 144 | _CN.TRAINER.SEED = 50 145 | 146 | 147 | def get_cfg_defaults(): 148 | """Get a yacs CfgNode object with default values for my_project.""" 149 | # Return a clone so that the defaults will not be altered 150 | # This is for the "local variable" use pattern 151 | return _CN.clone() 152 | -------------------------------------------------------------------------------- /src/evaluation/classification.py: -------------------------------------------------------------------------------- 1 | r"""Evaluation on per-frame labels for classification. 2 | """ 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | from absl import flags 9 | from loguru import logger as loguru_logger 10 | import pytorch_lightning as pl 11 | 12 | import concurrent.futures as cf 13 | 14 | import numpy as np 15 | from sklearn.linear_model import LogisticRegression 16 | from sklearn.svm import SVC, LinearSVC 17 | import copy 18 | 19 | FLAGS = flags.FLAGS 20 | 21 | 22 | def fit_linear_model(train_embs, train_labels, 23 | val_embs, val_labels): 24 | """Fit a linear classifier.""" 25 | lin_model = LogisticRegression(max_iter=100000, solver='lbfgs', 26 | multi_class='multinomial', verbose=0) 27 | lin_model.fit(train_embs, train_labels) 28 | train_acc = lin_model.score(train_embs, train_labels) 29 | val_acc = lin_model.score(val_embs, val_labels) 30 | return lin_model, train_acc, val_acc 31 | 32 | 33 | def fit_svm_model(train_embs, train_labels, 34 | val_embs, val_labels): 35 | """Fit a SVM classifier.""" 36 | # svm_model = LinearSVC(verbose=0) 37 | svm_model = SVC(decision_function_shape='ovo', verbose=0) 38 | svm_model.fit(train_embs, train_labels) 39 | train_acc = svm_model.score(train_embs, train_labels) 40 | val_acc = svm_model.score(val_embs, val_labels) 41 | return svm_model, train_acc, val_acc 42 | 43 | 44 | def fit_linear_models(train_embs, train_labels, val_embs, val_labels, 45 | model_type='linear'): 46 | """Fit Log Regression and SVM Models.""" 47 | if model_type == 'linear': 48 | _, train_acc, val_acc = fit_linear_model(train_embs, train_labels, 49 | val_embs, val_labels) 50 | elif model_type == 'svm': 51 | _, train_acc, val_acc = fit_svm_model(train_embs, train_labels, 52 | val_embs, val_labels) 53 | else: 54 | raise ValueError('%s model type not supported' % model_type) 55 | return train_acc, val_acc 56 | 57 | 58 | class Classification(): 59 | """Classification using small linear models.""" 60 | 61 | def __init__(self, config): 62 | self.config = config 63 | 64 | def evaluate_embeddings(self, datasets_ori, emb_mean=False, DICT=False, acc_list=[0.1, 0.5, 1.0]): 65 | """Labeled evaluation.""" 66 | fractions = acc_list # CONFIG.EVAL.CLASSIFICATION_FRACTIONS # [0.1, 0.5, 1.0] 67 | datasets = copy.deepcopy(datasets_ori) 68 | 69 | if self.config.DATASET.NAME == 'kallax_shelf_drawer': 70 | BACKGROUND_LABEL = True 71 | else: 72 | BACKGROUND_LABEL = False 73 | 74 | if 
datasets['train_dataset']['embs'] == [] or datasets['val_dataset']['embs'] == []: 75 | loguru_logger.info( 76 | 'Empty embeddings') 77 | return (0.0, 0.0) 78 | 79 | if DICT: 80 | if emb_mean: 81 | train_emb = [] 82 | train_label = [] 83 | val_emb = [] 84 | val_label = [] 85 | for key, emb in datasets['train_dataset']['embs'].items(): 86 | train_emb.append(np.average(np.array(emb), axis=0)) 87 | train_label.append( 88 | datasets['train_dataset']['labels'][key][0]) 89 | 90 | for key, emb in datasets['val_dataset']['embs'].items(): 91 | val_emb.append(np.average(np.array(emb), axis=0)) 92 | val_label.append(datasets['val_dataset']['labels'][key][0]) 93 | datasets['train_dataset']['embs'] = train_emb 94 | datasets['train_dataset']['labels'] = train_label 95 | datasets['val_dataset']['embs'] = val_emb 96 | datasets['val_dataset']['labels'] = val_label 97 | 98 | else: 99 | train_emb = [] 100 | train_label = [] 101 | val_emb = [] 102 | val_label = [] 103 | for key, emb in datasets['train_dataset']['embs'].items(): 104 | train_emb.append(datasets['train_dataset']['embs'][key][0]) 105 | train_label.append( 106 | datasets['train_dataset']['labels'][key][0]) 107 | 108 | for key, emb in datasets['val_dataset']['embs'].items(): 109 | val_emb.append(datasets['val_dataset']['embs'][key][0]) 110 | val_label.append(datasets['val_dataset']['labels'][key][0]) 111 | datasets['train_dataset']['embs'] = train_emb 112 | datasets['train_dataset']['labels'] = train_label 113 | datasets['val_dataset']['embs'] = val_emb 114 | datasets['val_dataset']['labels'] = val_label 115 | 116 | val_embs = np.concatenate(datasets['val_dataset']['embs']) 117 | val_labels = np.concatenate(datasets['val_dataset']['labels']) 118 | 119 | if BACKGROUND_LABEL: 120 | val_embs = val_embs[val_labels.astype(bool)] 121 | val_labels = val_labels[val_labels.astype(bool)]-1 122 | 123 | val_accs = [] 124 | train_accs = [] 125 | train_dataset = datasets['train_dataset'] 126 | num_samples = len(train_dataset['embs']) 127 | 128 | def worker(fraction_used): 129 | num_samples_used = max(1, int(fraction_used * num_samples)) 130 | train_embs = np.concatenate( 131 | train_dataset['embs'][:num_samples_used]) 132 | train_labels = np.concatenate( 133 | train_dataset['labels'][:num_samples_used]) 134 | 135 | if BACKGROUND_LABEL: 136 | train_embs = train_embs[train_labels.astype(bool)] 137 | train_labels = train_labels[train_labels.astype(bool)]-1 138 | return fit_linear_models(train_embs, train_labels, val_embs, val_labels) 139 | 140 | with cf.ThreadPoolExecutor(max_workers=len(fractions)) as executor: 141 | results = executor.map(worker, fractions) 142 | for (fraction, (train_acc, val_acc)) in zip(fractions, results): 143 | loguru_logger.info('[Global step: Classification {} Fraction' 144 | 'Train Accuracy: {:.5f},'.format(fraction, train_acc)) 145 | loguru_logger.info('[Global step: ] Classification {} Fraction' 146 | 'Val Accuracy: {:.5f},'.format(fraction, 147 | val_acc)) 148 | train_accs.append(train_acc) 149 | val_accs.append(val_acc) 150 | 151 | return (train_accs, val_accs) 152 | -------------------------------------------------------------------------------- /src/losses/casa_loss.py: -------------------------------------------------------------------------------- 1 | from numpy.lib.twodim_base import mask_indices 2 | from loguru import logger 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from scipy.optimize import linear_sum_assignment 9 | 10 | 11 | class CASALoss(nn.Module): 12 | def 
__init__(self, config): 13 | super().__init__() 14 | self.config = config # config under the global namespace 15 | self.loss_config = config['casa']['loss'] 16 | self.match_type = self.config['casa']['match']['match_type'] 17 | self.match_algo = self.config['casa']['match']['match_algo'] 18 | self.mse_loss = nn.MSELoss() 19 | # coarse-level 20 | self.c_pos_w = self.loss_config['pos_weight'] 21 | self.c_neg_w = self.loss_config['neg_weight'] 22 | self.use_prior = self.config['casa']['match']['use_prior'] 23 | self.DEBUG = False 24 | 25 | def compute_coarse_loss(self, data, weight=None): 26 | steps0 = data['steps0'] 27 | steps1 = data['steps1'] 28 | 29 | conf = data['conf_matrix'] 30 | conf_prior = data['conf_matrix_prior'] 31 | i_mask = data['i_mask'] 32 | j_mask = data['j_mask'] 33 | gt_mask_float = (i_mask.float()+j_mask.float())/2.0 34 | gt_mask = i_mask + j_mask 35 | 36 | if self.match_type == 'dual_softmax': 37 | 38 | conf0 = data['conf_matrix0'] 39 | conf1 = data['conf_matrix1'] 40 | len0 = data['len0'] 41 | len1 = data['len1'] 42 | sim_matrix = data['sim_matrix'] 43 | 44 | 45 | neg_mask = gt_mask == False 46 | mask = i_mask * j_mask 47 | 48 | c_pos_w, c_neg_w = self.c_pos_w, self.c_neg_w 49 | 50 | # mask = (conf_prior ==1) 51 | 52 | if self.loss_config['type'] == 'cross_entropy': 53 | 54 | if self.match_type == 'dual_softmax': 55 | conf = torch.clamp(conf, 1e-6, 1-1e-6) 56 | 57 | if self.use_prior: 58 | conf_prior = conf_prior.bool() 59 | gt_mask = conf_prior 60 | 61 | loss_pos = [] 62 | if self.loss_config['loss_type'] == 'regression': 63 | # cal row and col seperately 64 | 65 | steps_i = torch.unsqueeze(steps0, 2) 66 | 67 | true_time_i = torch.sum(steps_i*gt_mask.float(), 1) 68 | pred_time_i = torch.sum(steps_i*conf0, 1) 69 | gt_mask_sum_i = torch.sum(gt_mask.float(), dim=1) == 1 70 | loss_pos = self.mse_loss( 71 | true_time_i[gt_mask_sum_i], pred_time_i[gt_mask_sum_i]) 72 | 73 | elif self.loss_config['loss_type'] == 'regression_var': 74 | 75 | coeff_lambda = 0.00001 76 | num_list_i = torch.unsqueeze(steps0, 2) 77 | conf_num_i = conf0 * num_list_i 78 | mean_conf_i = torch.sum(conf_num_i, 1) 79 | diff_i = num_list_i - torch.unsqueeze(mean_conf_i, 1) 80 | sigma_square_i = torch.sum(conf0 * diff_i**2, 1) 81 | mask_loc_i = torch.argmax(gt_mask.float(), dim=1) 82 | gt_mask_sum_i = torch.sum(gt_mask.float(), dim=1) == 1 83 | diff_regression_i = torch.gather( 84 | steps0, 1, mask_loc_i) - mean_conf_i 85 | 86 | num_list_j = torch.unsqueeze(steps1, 1) 87 | conf_num_j = conf1 * num_list_j 88 | mean_conf_j = torch.sum(conf_num_j, 2) 89 | diff_j = num_list_j - torch.unsqueeze(mean_conf_j, 2) 90 | sigma_square_j = torch.sum(conf1 * diff_j**2, 2) 91 | mask_loc_j = torch.argmax(gt_mask.float(), dim=2) 92 | gt_mask_sum_j = torch.sum(gt_mask.float(), dim=2) == 1 93 | diff_regression_j = torch.gather( 94 | steps1, 1, mask_loc_j) - mean_conf_j 95 | 96 | if self.DEBUG: # debug purpose 97 | print("gt_mask_sum", gt_mask_sum_j.shape) 98 | print("gt_mask[0]", gt_mask[0]) 99 | print("mask_loc_i[0]", mask_loc_j[0]) 100 | print("num_list_i", num_list_j.shape) 101 | print("conf0", conf1.shape) 102 | print("conf_num_i", conf_num_j.shape) 103 | print("num_list_i", num_list_j.shape) 104 | print("torch.unsqueeze(mean_conf_i,1)", 105 | torch.unsqueeze(mean_conf_j, 2).shape) 106 | print("diff_i", diff_j.shape) 107 | print("num_list_i", num_list_j.shape) 108 | print("conf0", conf1.shape) 109 | print("diff_i", diff_j.shape) 110 | print("mask_loc_i", mask_loc_j.shape) 111 | print("mean_conf_i", mean_conf_j.shape) 112 
| print("sigma_square_i", sigma_square_j.shape) 113 | print("diff_regression_i", diff_regression_j.shape) 114 | print("(diff_regression_i[gt_mask_sum]**2)/sigma_square_i[gt_mask_sum]", (( 115 | diff_regression_j[gt_mask_sum_j]**2)/sigma_square_j[gt_mask_sum_j]).shape) 116 | # loss_pos += conf0[gt_mask_sum] 117 | loss_pos = (diff_regression_i[gt_mask_sum_i]**2)/sigma_square_i[gt_mask_sum_i] + \ 118 | coeff_lambda * torch.log(sigma_square_i[gt_mask_sum_i]) 119 | loss_pos += (diff_regression_j[gt_mask_sum_j]**2)/sigma_square_j[gt_mask_sum_j] + \ 120 | coeff_lambda * torch.log(sigma_square_j[gt_mask_sum_j]) 121 | else: 122 | raise ValueError('Supported loss types: regression and ' 123 | 'regression_var.') 124 | 125 | return loss_pos.mean() # + loss_count.mean() 126 | 127 | elif self.loss_config['type'] == 'mse': 128 | loss = self.mse_loss(conf, gt_mask_float) 129 | return loss # .mean() 130 | else: 131 | raise ValueError('Unknown coarse loss: {type}'.format( 132 | type=self.loss_config['type'])) 133 | 134 | @ torch.no_grad() 135 | def compute_c_weight(self, data): 136 | if 'mask0' in data: 137 | c_weight = (data['mask0'].flatten(-2)[..., None] 138 | * data['mask1'].flatten(-2)[:, None]).float() 139 | else: 140 | c_weight = None 141 | return c_weight 142 | 143 | def forward(self, data): 144 | loss_scalars = {} 145 | # compute element-wise loss weight 146 | c_weight = self.compute_c_weight(data) 147 | 148 | # computer loss loss 149 | loss_c = self.compute_coarse_loss(data, 150 | weight=c_weight) 151 | 152 | loss = loss_c * self.loss_config['weight'] 153 | loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()}) 154 | 155 | loss_scalars.update({'loss': loss.clone().detach().cpu()}) 156 | data.update({"loss": loss, "loss_scalars": loss_scalars}) 157 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import pprint 4 | from distutils.util import strtobool 5 | from pathlib import Path 6 | from loguru import logger as loguru_logger 7 | import os 8 | import tqdm 9 | import cv2 10 | import glob 11 | import numpy as np 12 | 13 | import pytorch_lightning as pl 14 | from pytorch_lightning.utilities import rank_zero_only 15 | from pytorch_lightning.loggers import TensorBoardLogger 16 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 17 | from pytorch_lightning.plugins import DDPPlugin 18 | 19 | from src.config.default import get_cfg_defaults 20 | from src.utils.misc import get_rank_zero_only_logger, setup_gpus 21 | from src.utils.profiler import build_profiler 22 | from src.lightning.data import MultiSceneDataModule 23 | from src.lightning.lightning_casa import PL_CASA 24 | from torchsummary import summary 25 | 26 | loguru_logger = get_rank_zero_only_logger(loguru_logger) 27 | 28 | 29 | def parse_args(): 30 | # init a costum parser which will be added into pl.Trainer parser 31 | # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags 32 | parser = argparse.ArgumentParser( 33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 34 | parser.add_argument( 35 | '--data_cfg_path', type=str, help='data config path') 36 | parser.add_argument( 37 | '--main_cfg_path', type=str, help='main config path') 38 | parser.add_argument( 39 | '--exp_name', type=str, default='default_exp_name') 40 | parser.add_argument( 41 | '--batch_size', type=int, default=4, 
help='batch_size per gpu') 42 | parser.add_argument( 43 | '--num_workers', type=int, default=4) 44 | parser.add_argument( 45 | '--pin_memory', type=lambda x: bool(strtobool(x)), 46 | nargs='?', default=True, help='whether loading data to pinned memory or not') 47 | parser.add_argument( 48 | '--ckpt_path', type=str, default=None, 49 | help='pretrained checkpoint path, helpful for using a pre-trained CASA') 50 | parser.add_argument( 51 | '--disable_ckpt', action='store_true', 52 | help='disable checkpoint saving (useful for debugging).') 53 | parser.add_argument( 54 | '--profiler_name', type=str, default=None, 55 | help='options: [inference, pytorch], or leave it unset') 56 | parser.add_argument( 57 | '--parallel_load_data', action='store_true', 58 | help='load datasets in with multiple processes.') 59 | parser.add_argument( 60 | '--data_folder', type=str, default=None, 61 | help='data folder path') 62 | parser.add_argument( 63 | '--videos_dir', type=str, default=None, 64 | help='directory of videos') 65 | parser.add_argument( 66 | '--dataset_name', type=str, default=None, 67 | help='name of the dataset') 68 | 69 | parser = pl.Trainer.add_argparse_args(parser) 70 | return parser.parse_args() 71 | 72 | 73 | def preprocess(im, rotate, resize, width, height): 74 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 75 | if resize: 76 | im = cv2.resize(im, (width, height)) 77 | if rotate: 78 | im = cv2.transpose(im) 79 | im = cv2.flip(im, 1) 80 | return im 81 | 82 | 83 | def get_frames_in_folder(path, rotate, resize, width, height): 84 | """Returns all frames from a video in a given folder. 85 | 86 | Args: 87 | path: string, directory containing frames of a video. 88 | rotate: Boolean, if True rotates an image by 90 degrees. 89 | resize: Boolean, if True resizes images to given size. 90 | width: Integer, Width of image. 91 | height: Integer, Height of image. 92 | Returns: 93 | frames: list, list of frames in a video. 94 | Raises: 95 | ValueError: When provided directory doesn't exist. 
96 | """ 97 | if not os.path.isdir(path): 98 | raise ValueError('Provided path %s is not a directory' % path) 99 | else: 100 | im_list = sorted( 101 | glob.glob(os.path.join(path, '*.%s' % 'jpg'))) 102 | 103 | frames = [preprocess(cv2.imread(im), rotate, resize, width, height) 104 | for im in im_list] 105 | return frames 106 | 107 | 108 | def main(): 109 | 110 | # parse arguments 111 | 112 | args = parse_args() 113 | 114 | # Load images 115 | videos_dir = args.videos_dir 116 | rank_zero_only(pprint.pprint)(vars(args)) 117 | 118 | # init default-cfg and merge it with the main- and data-cfg 119 | config = get_cfg_defaults() 120 | 121 | config.merge_from_file(args.main_cfg_path) 122 | config.merge_from_file(args.data_cfg_path) 123 | 124 | config.DATASET.LOGDIR = args.data_folder 125 | config.DATASET.NAME = args.dataset_name 126 | config.DATASET.TRAIN_DATA_ROOT = config.DATASET.TRAIN_DATA_ROOT + "/" + config.DATASET.NAME +"_train.npy" 127 | config.DATASET.VAL_DATA_ROOT = config.DATASET.TEST_DATA_ROOT = [config.DATASET.VAL_DATA_ROOT[0] + "/" + config.DATASET.NAME +"_train.npy", 128 | config.DATASET.VAL_DATA_ROOT[1] + "/" + config.DATASET.NAME +"_val.npy" ] 129 | # To reproduce the results, 130 | pl.seed_everything(config.TRAINER.SEED) # reproducibility 131 | 132 | # scale lr and warmup-step automatically 133 | args.gpus = _n_gpus = setup_gpus(args.gpus) 134 | config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes 135 | config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size 136 | _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS 137 | config.TRAINER.SCALING = _scaling 138 | config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling 139 | config.TRAINER.WARMUP_STEP = math.floor( 140 | config.TRAINER.WARMUP_STEP / _scaling) 141 | 142 | config.DATASET.PATH = args.data_folder 143 | 144 | print("config.TRAINER.TRUE_BATCH_SIZE", config.TRAINER.TRUE_BATCH_SIZE) 145 | print("_scaling", _scaling) 146 | print("config.TRAINER.WARMUP_STEP", config.TRAINER.WARMUP_STEP) 147 | profiler = build_profiler(args.profiler_name) 148 | model = PL_CASA(config, pretrained_ckpt=args.ckpt_path, profiler=profiler) 149 | 150 | # lightning data 151 | data_module = MultiSceneDataModule(args, config) 152 | 153 | # TensorBoard Logger 154 | 155 | save_dir = os.path.join(config.DATASET.LOGDIR,'logs') 156 | 157 | logger = TensorBoardLogger( 158 | save_dir=save_dir, name=os.path.join(config.DATASET.NAME, args.exp_name), default_hp_metric=False) 159 | 160 | experiment_path = os.path.join(save_dir,config.DATASET.NAME ,args.exp_name) 161 | print("experiment_path", experiment_path) 162 | 163 | # Lightning Trainer #fast_dev_run=True should be here 164 | trainer = pl.Trainer.from_argparse_args( 165 | args, 166 | plugins=DDPPlugin(find_unused_parameters=True, 167 | num_nodes=args.num_nodes, 168 | sync_batchnorm=config.TRAINER.WORLD_SIZE > 0), 169 | gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING, 170 | logger=logger, 171 | sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, 172 | replace_sampler_ddp=False, # use custom sampler 173 | reload_dataloaders_every_epoch=False, # avoid repeated samples! 
174 | weights_summary='full', 175 | profiler=profiler, 176 | fast_dev_run=False) 177 | loguru_logger.info("Trainer initialized!") 178 | loguru_logger.info("Start predict!") 179 | trainer.test(model, datamodule=data_module) 180 | 181 | 182 | if __name__ == '__main__': 183 | main() 184 | -------------------------------------------------------------------------------- /src/casa/backbone/resnet_fpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def conv1x1(in_planes, out_planes, stride=1): 6 | """1x1 convolution without padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | def __init__(self, in_planes, planes, stride=1): 17 | super().__init__() 18 | self.conv1 = conv3x3(in_planes, planes, stride) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | 24 | if stride == 1: 25 | self.downsample = None 26 | else: 27 | self.downsample = nn.Sequential( 28 | conv1x1(in_planes, planes, stride=stride), 29 | nn.BatchNorm2d(planes) 30 | ) 31 | 32 | def forward(self, x): 33 | y = x 34 | y = self.relu(self.bn1(self.conv1(y))) 35 | y = self.bn2(self.conv2(y)) 36 | 37 | if self.downsample is not None: 38 | x = self.downsample(x) 39 | 40 | return self.relu(x+y) 41 | 42 | 43 | class ResNetFPN_8_2(nn.Module): 44 | """ 45 | ResNet+FPN, output resolution are 1/8 and 1/2. 46 | Each block has 2 layers. 47 | """ 48 | 49 | def __init__(self, config): 50 | super().__init__() 51 | # Config 52 | block = BasicBlock 53 | initial_dim = config['initial_dim'] 54 | block_dims = config['block_dims'] 55 | 56 | # Class Variable 57 | self.in_planes = initial_dim 58 | 59 | # Networks 60 | self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) 61 | self.bn1 = nn.BatchNorm2d(initial_dim) 62 | self.relu = nn.ReLU(inplace=True) 63 | 64 | self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 65 | self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 66 | self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 67 | 68 | # 3. 
FPN upsample 69 | self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) 70 | self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) 71 | self.layer2_outconv2 = nn.Sequential( 72 | conv3x3(block_dims[2], block_dims[2]), 73 | nn.BatchNorm2d(block_dims[2]), 74 | nn.LeakyReLU(), 75 | conv3x3(block_dims[2], block_dims[1]), 76 | ) 77 | self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) 78 | self.layer1_outconv2 = nn.Sequential( 79 | conv3x3(block_dims[1], block_dims[1]), 80 | nn.BatchNorm2d(block_dims[1]), 81 | nn.LeakyReLU(), 82 | conv3x3(block_dims[1], block_dims[0]), 83 | ) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 88 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 89 | nn.init.constant_(m.weight, 1) 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, block, dim, stride=1): 93 | layer1 = block(self.in_planes, dim, stride=stride) 94 | layer2 = block(dim, dim, stride=1) 95 | layers = (layer1, layer2) 96 | 97 | self.in_planes = dim 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | # ResNet Backbone 102 | x0 = self.relu(self.bn1(self.conv1(x))) 103 | x1 = self.layer1(x0) # 1/2 104 | x2 = self.layer2(x1) # 1/4 105 | x3 = self.layer3(x2) # 1/8 106 | 107 | # FPN 108 | x3_out = self.layer3_outconv(x3) 109 | 110 | x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) 111 | x2_out = self.layer2_outconv(x2) 112 | x2_out = self.layer2_outconv2(x2_out+x3_out_2x) 113 | 114 | x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) 115 | x1_out = self.layer1_outconv(x1) 116 | x1_out = self.layer1_outconv2(x1_out+x2_out_2x) 117 | 118 | return [x3_out, x1_out] 119 | 120 | 121 | class ResNetFPN_16_4(nn.Module): 122 | """ 123 | ResNet+FPN, output resolution are 1/16 and 1/4. 124 | Each block has 2 layers. 125 | """ 126 | 127 | def __init__(self, config): 128 | super().__init__() 129 | # Config 130 | block = BasicBlock 131 | initial_dim = config['initial_dim'] 132 | block_dims = config['block_dims'] 133 | 134 | # Class Variable 135 | self.in_planes = initial_dim 136 | 137 | # Networks 138 | self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) 139 | self.bn1 = nn.BatchNorm2d(initial_dim) 140 | self.relu = nn.ReLU(inplace=True) 141 | 142 | self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 143 | self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 144 | self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 145 | self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16 146 | 147 | # 3. 
FPN upsample 148 | self.layer4_outconv = conv1x1(block_dims[3], block_dims[3]) 149 | self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) 150 | self.layer3_outconv2 = nn.Sequential( 151 | conv3x3(block_dims[3], block_dims[3]), 152 | nn.BatchNorm2d(block_dims[3]), 153 | nn.LeakyReLU(), 154 | conv3x3(block_dims[3], block_dims[2]), 155 | ) 156 | 157 | self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) 158 | self.layer2_outconv2 = nn.Sequential( 159 | conv3x3(block_dims[2], block_dims[2]), 160 | nn.BatchNorm2d(block_dims[2]), 161 | nn.LeakyReLU(), 162 | conv3x3(block_dims[2], block_dims[1]), 163 | ) 164 | 165 | for m in self.modules(): 166 | if isinstance(m, nn.Conv2d): 167 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 168 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 169 | nn.init.constant_(m.weight, 1) 170 | nn.init.constant_(m.bias, 0) 171 | 172 | def _make_layer(self, block, dim, stride=1): 173 | layer1 = block(self.in_planes, dim, stride=stride) 174 | layer2 = block(dim, dim, stride=1) 175 | layers = (layer1, layer2) 176 | 177 | self.in_planes = dim 178 | return nn.Sequential(*layers) 179 | 180 | def forward(self, x): 181 | # ResNet Backbone 182 | x0 = self.relu(self.bn1(self.conv1(x))) 183 | x1 = self.layer1(x0) # 1/2 184 | x2 = self.layer2(x1) # 1/4 185 | x3 = self.layer3(x2) # 1/8 186 | x4 = self.layer4(x3) # 1/16 187 | 188 | # FPN 189 | x4_out = self.layer4_outconv(x4) 190 | 191 | x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) 192 | x3_out = self.layer3_outconv(x3) 193 | x3_out = self.layer3_outconv2(x3_out+x4_out_2x) 194 | 195 | x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) 196 | x2_out = self.layer2_outconv(x2) 197 | x2_out = self.layer2_outconv2(x2_out+x3_out_2x) 198 | 199 | return [x4_out, x2_out] 200 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import argparse 3 | import pprint 4 | from distutils.util import strtobool 5 | from pathlib import Path 6 | from loguru import logger as loguru_logger 7 | import os 8 | 9 | import pytorch_lightning as pl 10 | from pytorch_lightning.utilities import rank_zero_only 11 | from pytorch_lightning.loggers import TensorBoardLogger 12 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 13 | from pytorch_lightning.plugins import DDPPlugin 14 | 15 | from src.config.default import get_cfg_defaults 16 | from src.utils.misc import get_rank_zero_only_logger, setup_gpus 17 | from src.utils.profiler import build_profiler 18 | from src.lightning.data import MultiSceneDataModule 19 | from src.lightning.lightning_casa import PL_CASA 20 | 21 | loguru_logger = get_rank_zero_only_logger(loguru_logger) 22 | 23 | 24 | def parse_args(): 25 | # init a costum parser which will be added into pl.Trainer parser 26 | # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags 27 | parser = argparse.ArgumentParser( 28 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 29 | parser.add_argument( 30 | '--data_cfg_path', type=str, help='data config path') 31 | parser.add_argument( 32 | '--main_cfg_path', type=str, help='main config path') 33 | parser.add_argument( 34 | '--exp_name', type=str, default='default_exp_name') 35 | parser.add_argument( 36 | '--batch_size', type=int, default=4, help='batch_size per gpu') 37 
| parser.add_argument( 38 | '--num_workers', type=int, default=4) 39 | parser.add_argument( 40 | '--pin_memory', type=lambda x: bool(strtobool(x)), 41 | nargs='?', default=True, help='whether loading data to pinned memory or not') 42 | parser.add_argument( 43 | '--ckpt_path', type=str, default=None, 44 | help='pretrained checkpoint path, helpful for using a pre-trained CASA') 45 | parser.add_argument( 46 | '--disable_ckpt', action='store_true', 47 | help='disable checkpoint saving (useful for debugging).') 48 | parser.add_argument( 49 | '--profiler_name', type=str, default=None, 50 | help='options: [inference, pytorch], or leave it unset') 51 | parser.add_argument( 52 | '--parallel_load_data', action='store_true', 53 | help='load datasets in with multiple processes.') 54 | parser.add_argument( 55 | '--dataset_name', type=str, default=None, 56 | help='name of the dataset') 57 | parser.add_argument( 58 | '--data_folder', type=str, default=None, 59 | help='data folder path') 60 | 61 | parser = pl.Trainer.add_argparse_args(parser) 62 | return parser.parse_args() 63 | 64 | 65 | def main(): 66 | # parse arguments 67 | 68 | 69 | args = parse_args() 70 | rank_zero_only(pprint.pprint)(vars(args)) 71 | 72 | # init default-cfg and merge it with the main- and data-cfg 73 | config = get_cfg_defaults() 74 | 75 | # Load main and data cfgs. 76 | config.merge_from_file(args.main_cfg_path) 77 | config.merge_from_file(args.data_cfg_path) 78 | 79 | # Change the dataset folder and the name of the dataset 80 | #config.DATASET.NAME = args.DATASET.NAME 81 | config.DATASET.LOGDIR = args.data_folder 82 | config.DATASET.NAME = args.dataset_name 83 | print("config.DATASET.TRAIN_DATA_ROOT",config.DATASET.TRAIN_DATA_ROOT) 84 | print("config.DATASET.NAME",config.DATASET.NAME) 85 | 86 | config.DATASET.TRAIN_DATA_ROOT = config.DATASET.TRAIN_DATA_ROOT + "/" + config.DATASET.NAME +"_train.npy" 87 | config.DATASET.VAL_DATA_ROOT = config.DATASET.TEST_DATA_ROOT = [config.DATASET.VAL_DATA_ROOT[0] + "/" + config.DATASET.NAME +"_train.npy", 88 | config.DATASET.VAL_DATA_ROOT[1] + "/" + config.DATASET.NAME +"_val.npy" ] # To reproduce the results, 89 | pl.seed_everything(config.TRAINER.SEED) # reproducibility 90 | 91 | # scale lr and warmup-step automatically 92 | args.gpus = _n_gpus = setup_gpus(args.gpus) 93 | config.TRAINER.WORLD_SIZE = _n_gpus * args.num_nodes 94 | config.TRAINER.TRUE_BATCH_SIZE = config.TRAINER.WORLD_SIZE * args.batch_size 95 | _scaling = config.TRAINER.TRUE_BATCH_SIZE / config.TRAINER.CANONICAL_BS 96 | config.TRAINER.SCALING = _scaling 97 | config.TRAINER.TRUE_LR = config.TRAINER.CANONICAL_LR * _scaling 98 | config.TRAINER.WARMUP_STEP = math.floor( 99 | config.TRAINER.WARMUP_STEP / _scaling) 100 | 101 | config.DATASET.PATH = args.data_folder 102 | 103 | 104 | print("config.TRAINER.TRUE_BATCH_SIZE",config.TRAINER.TRUE_BATCH_SIZE) 105 | print("_scaling",_scaling) 106 | print("config.TRAINER.WARMUP_STEP",config.TRAINER.WARMUP_STEP) 107 | #config_par = configparser.RawConfigParser(allow_no_value=True) 108 | #config_par.readfp(config) 109 | 110 | # lightning module 111 | profiler = build_profiler(args.profiler_name) 112 | model = PL_CASA(config, pretrained_ckpt=args.ckpt_path, profiler=profiler) 113 | loguru_logger.info("CASA LightningModule initialized!") 114 | 115 | # lightning data 116 | data_module = MultiSceneDataModule(args, config) 117 | loguru_logger.info("CASA DataModule initialized!") 118 | 119 | # TensorBoard Logger 120 | save_dir = os.path.join(config.DATASET.LOGDIR,'logs') 121 | 122 | logger = 
TensorBoardLogger( 123 | save_dir=save_dir, name=os.path.join(config.DATASET.NAME, args.exp_name), default_hp_metric=False) 124 | ckpt_dir = Path(logger.log_dir) / 'checkpoints' 125 | 126 | 127 | if not os.path.exists(save_dir): 128 | os.mkdir(save_dir) 129 | dataset_log_path = os.path.join(save_dir,config.DATASET.NAME) 130 | if not os.path.exists(dataset_log_path): 131 | os.mkdir(dataset_log_path) 132 | experiment_path = os.path.join(save_dir,config.DATASET.NAME ,args.exp_name) 133 | if not os.path.exists(experiment_path): 134 | os.mkdir(experiment_path) 135 | 136 | print("experiment_path",experiment_path) 137 | 138 | 139 | if not os.path.exists(logger.log_dir): 140 | os.mkdir(logger.log_dir) 141 | 142 | with open("{}/config.yaml".format(logger.log_dir), "w") as f: 143 | f.write(config.dump()) 144 | # Callbacks 145 | # TODO: update ModelCheckpoint to monitor multiple metrics 146 | ckpt_callback = ModelCheckpoint(verbose=True, mode='max', 147 | save_top_k =-1, 148 | save_last=True, 149 | dirpath=str(ckpt_dir), 150 | filename='{epoch}-{auc@5:.3f}-{auc@10:.3f}-{auc@20:.3f}', 151 | every_n_epochs=1) 152 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 153 | callbacks = [lr_monitor] 154 | if not args.disable_ckpt: 155 | callbacks.append(ckpt_callback) 156 | 157 | # Lightning Trainer #fast_dev_run=True should be here 158 | trainer = pl.Trainer.from_argparse_args( 159 | args, 160 | plugins=DDPPlugin(find_unused_parameters=True, 161 | num_nodes=args.num_nodes, 162 | sync_batchnorm=config.TRAINER.WORLD_SIZE > 0), 163 | gradient_clip_val=config.TRAINER.GRADIENT_CLIPPING, 164 | callbacks=callbacks, 165 | logger=logger, 166 | sync_batchnorm=config.TRAINER.WORLD_SIZE > 0, 167 | replace_sampler_ddp=False, # use custom sampler 168 | reload_dataloaders_every_epoch=False, # avoid repeated samples! 
169 | weights_summary='full', 170 | profiler=profiler, 171 | fast_dev_run=False) 172 | loguru_logger.info("Trainer initialized!") 173 | loguru_logger.info("Start training!") 174 | trainer.fit(model, datamodule=data_module) 175 | 176 | 177 | if __name__ == '__main__': 178 | main() 179 | -------------------------------------------------------------------------------- /dataset_preparation/preprocess_norm.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from tqdm import tqdm 3 | import sys 4 | import numpy as np 5 | import math 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.mplot3d import Axes3D 8 | 9 | sys.path.extend(['../']) 10 | '''IKEA ASM 11 | "nose", # 0 12 | "left eye", # 1 13 | "right eye", # 2 14 | "left ear", # 3 15 | "right ear", # 4 16 | "left shoulder", # 5 - center 17 | "right shoulder", # 6 18 | "left elbow", # 7 19 | "right elbow", # 8 20 | "left wrist", # 9 21 | "right wrist", # 10 22 | "left hip", # 11 23 | "right hip", # 12 24 | "left knee", # 13 25 | "right knee", # 14 26 | "left ankle", # 15 27 | "right ankle", # 16 28 | ''' 29 | 30 | 31 | def get_openpose_connectivity(): 32 | return [ 33 | [0, 1], 34 | [1, 2], 35 | [2, 3], 36 | [3, 4], 37 | [1, 5], 38 | [5, 6], 39 | [6, 7], 40 | [1, 8], 41 | [8, 9], 42 | [9, 10], 43 | [10, 11], 44 | [11, 24], 45 | [11, 22], 46 | [22, 23], 47 | [8, 12], 48 | [12, 13], 49 | [13, 14], 50 | [14, 21], 51 | [14, 19], 52 | [19, 20], 53 | [0, 15], 54 | [15, 17], 55 | [0, 16], 56 | [16, 18] 57 | ] 58 | 59 | 60 | def get_ikea_connectivity(): 61 | return [ 62 | [0, 1], 63 | [0, 2], 64 | [1, 3], 65 | [2, 4], 66 | [0, 5], 67 | [0, 6], 68 | [5, 6], 69 | [5, 7], 70 | [6, 8], 71 | [7, 9], 72 | [8, 10], 73 | [5, 11], 74 | [6, 12], 75 | [11, 12], 76 | [11, 13], 77 | [12, 14], 78 | [13, 15], 79 | [14, 16] 80 | ] 81 | 82 | 83 | def rotation_matrix(axis, theta): 84 | """ 85 | Return the rotation matrix associated with counterclockwise rotation about 86 | the given axis by theta radians. 87 | """ 88 | if np.abs(axis).sum() < 1e-6 or np.abs(theta) < 1e-6: 89 | return np.eye(3) 90 | axis = np.asarray(axis) 91 | axis = axis / math.sqrt(np.dot(axis, axis)) 92 | a = math.cos(theta / 2.0) 93 | b, c, d = -axis * math.sin(theta / 2.0) 94 | aa, bb, cc, dd = a * a, b * b, c * c, d * d 95 | bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d 96 | return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], 97 | [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], 98 | [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) 99 | 100 | 101 | def unit_vector(vector): 102 | """ Returns the unit vector of the vector. 
""" 103 | return vector / np.linalg.norm(vector) 104 | 105 | 106 | def angle_between(v1, v2): 107 | """ Returns the angle in radians between vectors 'v1' and 'v2':: 108 | >>> angle_between((1, 0, 0), (0, 1, 0)) 109 | 1.5707963267948966 110 | >>> angle_between((1, 0, 0), (1, 0, 0)) 111 | 0.0 112 | >>> angle_between((1, 0, 0), (-1, 0, 0)) 113 | 3.141592653589793 114 | """ 115 | if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6: 116 | return 0 117 | v1_u = unit_vector(v1) 118 | v2_u = unit_vector(v2) 119 | return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) 120 | 121 | 122 | def x_rotation(vector, theta): 123 | """Rotates 3-D vector around x-axis""" 124 | R = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], 125 | [0, np.sin(theta), np.cos(theta)]]) 126 | return np.dot(R, vector) 127 | 128 | 129 | def y_rotation(vector, theta): 130 | """Rotates 3-D vector around y-axis""" 131 | R = np.array([[np.cos(theta), 0, np.sin(theta)], [ 132 | 0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]) 133 | return np.dot(R, vector) 134 | 135 | 136 | def z_rotation(vector, theta): 137 | """Rotates 3-D vector around z-axis""" 138 | R = np.array([[np.cos(theta), -np.sin(theta), 0], 139 | [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) 140 | return np.dot(R, vector) 141 | 142 | 143 | def pre_normalization(data, zaxis=[5, 11], xaxis=[5, 6], NORM_BONE=True): 144 | """1. Normalize 1st frame. 145 | 2. Normalize every frame. 146 | 3. Normalize every frame but leave the first location to give more information. 147 | """ 148 | VIS = False 149 | 150 | # print("data.shape", np.shape(data)) # (993, 17, 3) 151 | data = np.array(data)[:, :, :3] 152 | N, K, C = data.shape # (993, 17, 3) 153 | s = data 154 | # print('sub the center joint #5 (left shoulder)') 155 | for i_s, skeleton in enumerate(s): 156 | 157 | if np.sum(skeleton[zaxis[0]]) == 0: 158 | continue 159 | # Instaed of setting the center as #5, we can also cal mean vals. 160 | main_body_center = skeleton[zaxis[0]:zaxis[0]+1, :].copy() 161 | s[i_s] = s[i_s] - main_body_center 162 | 163 | # initialization 164 | joint_bottom_prev = np.array([0, 0, 0]) 165 | joint_top_prev = np.array([0, 0, 0]) 166 | #print('parallel the edge between left shoulder(5) and right shoulder(6) to the x axis') 167 | 168 | for i_s, skeleton in enumerate(s): 169 | if np.sum(skeleton) == 0: 170 | continue 171 | if np.sum(skeleton[xaxis[1]]) == 0: 172 | #print("x axis is not exist", i_s) 173 | # if joint in xaxis[1] is zero, then we will use from the previous frame's 174 | # one because it will be not much different. 
175 | joint_bottom = joint_bottom_prev 176 | joint_top = joint_top_prev 177 | else: 178 | joint_bottom = skeleton[xaxis[0]] 179 | joint_top = skeleton[xaxis[1]] 180 | joint_bottom_prev = joint_bottom 181 | joint_top_prev = joint_top 182 | joint_rshoulder = skeleton[xaxis[0]] 183 | joint_lshoulder = skeleton[xaxis[1]] 184 | axis = np.cross(joint_rshoulder - joint_lshoulder, [1, 0, 0]) 185 | angle = angle_between(joint_rshoulder - joint_lshoulder, [1, 0, 0]) 186 | matrix_x = rotation_matrix(axis, angle) 187 | 188 | for i_j, joint in enumerate(skeleton): 189 | s[i_s, i_j] = np.dot(matrix_x, joint) 190 | 191 | #print('parallel the edge between left shoulder(5) and left heap(11) to the z axis') 192 | 193 | for i_s, skeleton in enumerate(s): 194 | if np.sum(skeleton) == 0: 195 | continue 196 | if np.sum(skeleton[zaxis[1]]) == 0: 197 | #print("z axis is not exist", i_s) 198 | # if joint in zaxis[1] is zero, then we will use from the previous frame's 199 | # one because it will be not much different. 200 | joint_bottom = joint_bottom_prev 201 | joint_top = joint_top_prev 202 | else: 203 | joint_bottom = skeleton[zaxis[0]] 204 | joint_top = skeleton[zaxis[1]] 205 | joint_bottom_prev = joint_bottom 206 | joint_top_prev = joint_top 207 | axis = np.cross(joint_top - joint_bottom, [0, 0, 1]) 208 | # print("joint_top", joint_top) 209 | # print("joint_bottom", joint_bottom) 210 | angle = angle_between(joint_top - joint_bottom, [0, 0, 1]) 211 | matrix_z = rotation_matrix(axis, angle) 212 | for i_j, joint in enumerate(skeleton): 213 | s[i_s, i_j] = np.dot(matrix_z, joint) 214 | 215 | # Normalize bones 216 | if NORM_BONE: 217 | for i_s, skeleton in enumerate(s): 218 | bone_length = np.linalg.norm( 219 | skeleton[zaxis[0]] - skeleton[zaxis[1]]) 220 | if (np.sum(skeleton) == 0) or (bone_length == 0): 221 | continue 222 | 223 | #print("bone_length", bone_length) 224 | #print("s[i_s] before", s[i_s]) 225 | s[i_s] = s[i_s]/bone_length 226 | #print("s[i_s] after", s[i_s]) 227 | if VIS: 228 | fig = plt.figure() 229 | ax = plt.axes(projection='3d') 230 | ax.set_xlim3d(-1, 1) 231 | ax.set_ylim3d(-1, 1) 232 | ax.set_zlim3d(-1, 1) 233 | connectivity = get_openpose_connectivity() 234 | for limb in connectivity: 235 | ax.plot3D(s[i_s, limb, 0], 236 | s[i_s, limb, 1], s[i_s, limb, 2]) 237 | ax.scatter3D(s[i_s, :, 0], s[i_s, :, 1], 238 | s[i_s, :, 2], cmap='Greens') 239 | plt.show(block=False) 240 | plt.pause(0.5) 241 | plt.close() 242 | 243 | # print("s", np.shape(s)) 244 | return s 245 | 246 | 247 | if __name__ == '__main__': 248 | data = np.load('../data/ntu/xview/val_data.npy') 249 | pre_normalization(data) 250 | np.save('../data/ntu/xview/data_val_pre.npy', data) 251 | -------------------------------------------------------------------------------- /src/evaluation/event_completion.py: -------------------------------------------------------------------------------- 1 | r"""Evaluation on detecting key events using a RNN. 
2 | """ 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | from absl import flags 9 | from absl import logging 10 | 11 | import concurrent.futures as cf 12 | from loguru import logger as loguru_logger 13 | import numpy as np 14 | import sklearn 15 | import copy 16 | from dataset_splits import DATASET_TO_NUM_CLASSES 17 | 18 | from src.evaluation.task_utils import get_targets_from_labels, unnormalize 19 | #from config import ENVCONFIG 20 | FLAGS = flags.FLAGS 21 | 22 | 23 | class VectorRegression(sklearn.base.BaseEstimator): 24 | """Class to perform regression on multiple outputs.""" 25 | 26 | def __init__(self, estimator): 27 | self.estimator = estimator 28 | 29 | def fit(self, x, y): 30 | _, m = y.shape 31 | # Fit a separate regressor for each column of y 32 | self.estimators_ = [sklearn.base.clone(self.estimator).fit(x, y[:, i]) 33 | for i in range(m)] 34 | return self 35 | 36 | def predict(self, x): 37 | # Join regressors' predictions 38 | res = [est.predict(x)[:, np.newaxis] for est in self.estimators_] 39 | return np.hstack(res) 40 | 41 | def score(self, x, y): 42 | # Join regressors' scores 43 | res = [est.score(x, y[:, i]) for i, est in enumerate(self.estimators_)] 44 | return np.mean(res) 45 | 46 | 47 | def get_error(predictions, labels, seq_lens, num_classes, prefix): 48 | """Get error based on predictions.""" 49 | errs = [] 50 | for i in range(num_classes - 1): 51 | abs_err = 0 52 | for j in range(len(predictions)): 53 | # Choose last seq_len steps as our preprocessing pads sequences in 54 | # front with zeros. 55 | unnorm_preds = unnormalize(predictions[j][:, i]) 56 | unnorm_labels = unnormalize(labels[j][:, i]) 57 | 58 | abs_err += abs(unnorm_labels - unnorm_preds) / seq_lens[j] 59 | 60 | err = abs_err / len(predictions) 61 | logging.info('{} {} Fraction Error: ' 62 | '{:.3f},'.format(prefix, i, err)) 63 | # tf.summary.scalar('event_completion/%s_%d_error' % (prefix, i), 64 | # err, step=global_step) 65 | errs.append(err) 66 | 67 | avg_err = np.mean(errs) 68 | 69 | logging.info(' {} Fraction Error: ' 70 | '{:.3f},'.format(prefix, avg_err)) 71 | # tf.summary.scalar('event_completion/avg_error_%s' % prefix, 72 | # avg_err, step=global_step) 73 | 74 | return avg_err 75 | 76 | 77 | def fit_model(train_embs, train_labels, val_embs, val_labels, 78 | num_classes, prefix, report_error=False): 79 | """Linear Regression to regress to fraction completed.""" 80 | 81 | train_seq_lens = [len(x) for x in train_labels] 82 | val_seq_lens = [len(x) for x in val_labels] 83 | 84 | train_embs = np.concatenate(train_embs, axis=0) 85 | train_labels = np.concatenate(train_labels, axis=0) 86 | val_embs = np.concatenate(val_embs, axis=0) 87 | val_labels = np.concatenate(val_labels, axis=0) 88 | 89 | lin_model = VectorRegression(sklearn.linear_model.LinearRegression()) 90 | lin_model.fit(train_embs, train_labels) 91 | 92 | train_score = lin_model.score(train_embs, train_labels) 93 | val_score = lin_model.score(val_embs, val_labels) 94 | 95 | # To debug linear regression 96 | val_predictions = lin_model.predict(val_embs) 97 | train_predictions = lin_model.predict(train_embs) 98 | 99 | # print("train_predictions",train_predictions) 100 | # print("train_labels",train_labels) 101 | # print("val_predictions",val_predictions) 102 | # print("val_labels",val_labels) 103 | 104 | # Not used for evaluation right now. 
105 | if report_error: 106 | val_predictions = lin_model.predict(val_embs) 107 | train_predictions = lin_model.predict(train_embs) 108 | 109 | train_labels = np.array_split(train_labels, 110 | np.cumsum(train_seq_lens))[:-1] 111 | train_predictions = np.array_split(train_predictions, 112 | np.cumsum(train_seq_lens))[:-1] 113 | val_labels = np.array_split(val_labels, 114 | np.cumsum(val_seq_lens))[:-1] 115 | val_predictions = np.array_split(val_predictions, 116 | np.cumsum(val_seq_lens))[:-1] 117 | 118 | get_error(train_predictions, train_labels, train_seq_lens, 119 | num_classes, 'train_' + prefix) 120 | get_error(val_predictions, val_labels, val_seq_lens, 121 | num_classes, 'val_' + prefix) 122 | 123 | return train_score, val_score 124 | 125 | 126 | class EventCompletion(): 127 | """Predict event completion using linear regression.""" 128 | 129 | def __init__(self, config): 130 | super(EventCompletion, self).__init__() 131 | self.config = config 132 | 133 | def evaluate_embeddings(self, datasets_ori, DICT=True, emb_mean=True): 134 | """Labeled evaluation.""" 135 | 136 | datasets = copy.deepcopy(datasets_ori) 137 | 138 | num_classes = DATASET_TO_NUM_CLASSES[self.config.DATASET.NAME] # 4 139 | # print("num_class",num_classes) 140 | 141 | #DICT = True 142 | #emb_mean = True 143 | 144 | if DICT: 145 | if emb_mean: 146 | train_emb = [] 147 | train_label = [] 148 | val_emb = [] 149 | val_label = [] 150 | 151 | for key, emb in datasets['train_dataset']['embs'].items(): 152 | train_emb.append(np.average(np.array(emb), axis=0)) 153 | train_label.append( 154 | datasets['train_dataset']['labels'][key][0]) 155 | 156 | for key, emb in datasets['val_dataset']['embs'].items(): 157 | val_emb.append(np.average(np.array(emb), axis=0)) 158 | val_label.append(datasets['val_dataset']['labels'][key][0]) 159 | datasets['train_dataset']['embs'] = train_emb 160 | datasets['train_dataset']['labels'] = train_label 161 | datasets['val_dataset']['embs'] = val_emb 162 | datasets['val_dataset']['labels'] = val_label 163 | 164 | else: 165 | train_emb = [] 166 | train_label = [] 167 | val_emb = [] 168 | val_label = [] 169 | for key, emb in datasets['train_dataset']['embs'].items(): 170 | # for ii in range(len(emb)): 171 | train_emb.append(datasets['train_dataset']['embs'][key][0]) 172 | train_label.append( 173 | datasets['train_dataset']['labels'][key][0]) 174 | 175 | for key, emb in datasets['val_dataset']['embs'].items(): 176 | # for ii in range(len(emb)): 177 | val_emb.append(datasets['val_dataset']['embs'][key][0]) 178 | val_label.append(datasets['val_dataset']['labels'][key][0]) 179 | datasets['train_dataset']['embs'] = train_emb 180 | datasets['train_dataset']['labels'] = train_label 181 | datasets['val_dataset']['embs'] = val_emb 182 | datasets['val_dataset']['labels'] = val_label 183 | 184 | train_embs = datasets['train_dataset']['embs'] 185 | val_embs = datasets['val_dataset']['embs'] 186 | 187 | # print("train_embs",np.size(train_embs)) 188 | 189 | if not train_embs or not val_embs: 190 | logging.warn( 191 | 'All embeddings are NAN. 
Something is wrong with model.') 192 | return 1.0 193 | 194 | val_labels = get_targets_from_labels( 195 | datasets['val_dataset']['labels'], num_classes) 196 | train_labels = get_targets_from_labels( 197 | datasets['train_dataset']['labels'], num_classes) 198 | 199 | #print("train_labels", np.shape(train_labels)) 200 | #print("train_embs", np.shape(train_embs)) 201 | #print("val_labels", val_labels) 202 | 203 | results = fit_model(train_embs, train_labels, val_embs, val_labels, 204 | num_classes, '%s_%s' % ('Penn', str(1))) 205 | train_score, val_score = results 206 | 207 | prefix = '%s_%s' % ('Penn', str(1)) 208 | loguru_logger.info('Event Completion {} Fraction Train ' 209 | 'Score: {:.5f},'.format(prefix, 210 | train_score)) 211 | loguru_logger.info('Event Completion {} Fraction Val ' 212 | 'Score: {:.5f},'.format(prefix, 213 | val_score)) 214 | 215 | return train_score, val_score 216 | -------------------------------------------------------------------------------- /src/casa/utils/matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange 5 | 6 | INF = 1e9 7 | 8 | 9 | def mask_border(m, b: int, v): 10 | """ Mask borders with value 11 | Args: 12 | m (torch.Tensor): [N, H0, W0, H1, W1] 13 | m (torch.Tensor): [N, L0, L1] 14 | b (int) 15 | v (m.dtype) 16 | """ 17 | if b <= 0: 18 | return 19 | 20 | m[:, :b] = v 21 | m[:, :, :b] = v 22 | # m[:, :, :, :b] = v 23 | # m[:, :, :, :, :b] = v 24 | m[:, -b:] = v 25 | m[:, :, -b:] = v 26 | # m[:, :, :, -b:] = v 27 | # m[:, :, :, :, -b:] = v 28 | 29 | 30 | def mask_border_with_padding(m, bd, v, p_m0, p_m1): 31 | if bd <= 0: 32 | return 33 | 34 | m[:, :bd] = v 35 | m[:, :, :bd] = v 36 | m[:, :, :, :bd] = v 37 | m[:, :, :, :, :bd] = v 38 | 39 | h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() 40 | h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() 41 | for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): 42 | m[b_idx, h0 - bd:] = v 43 | m[b_idx, :, w0 - bd:] = v 44 | m[b_idx, :, :, h1 - bd:] = v 45 | m[b_idx, :, :, :, w1 - bd:] = v 46 | 47 | 48 | def compute_max_candidates(p_m0, p_m1): 49 | """Compute the max candidates of all pairs within a batch 50 | 51 | Args: 52 | p_m0, p_m1 (torch.Tensor): padded maskszzz 53 | """ 54 | h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] 55 | h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] 56 | max_cand = torch.sum( 57 | torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) 58 | return max_cand 59 | 60 | 61 | def pairwise_l2_distance(embs1, embs2): 62 | """Computes pairwise distances between all rows of embs1 and embs2.""" 63 | n, l, c = embs1.shape 64 | norm1 = torch.sum(torch.square(embs1), 2) 65 | norm1 = torch.reshape(norm1, [n, -1, 1]) 66 | norm2 = torch.sum(torch.square(embs2), 2) 67 | norm2 = torch.reshape(norm2, [n, 1, -1]) 68 | 69 | # Max to ensure matmul doesn't produce anything negative due to floating 70 | # point approximations. 71 | dist = torch.maximum( 72 | norm1 + norm2 - 2.0 * torch.einsum("nlc,nsc->nls", embs1, embs2), torch.zeros_like(norm1)) 73 | 74 | return dist 75 | 76 | 77 | def get_scaled_similarity(embs1, embs2, temperature): 78 | 79 | B, M, C = embs1.shape 80 | # Go for embs1 to embs2. 
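    # The score below is the negative squared L2 distance, scaled by the embedding
    # dimension C and a softmax temperature:
    #   similarity[n, l, s] = -||embs1[n, l] - embs2[n, s]||^2 / (C * temperature)
    # A softmax over l (or s) then turns these scores into soft correspondences.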
81 | similarity = -pairwise_l2_distance(embs1, embs2) 82 | similarity /= C 83 | similarity /= temperature 84 | 85 | return similarity 86 | 87 | 88 | def dual_softmax(feat_c0, feat_c1, temperature, SIM=False): 89 | # print("feat_c0", feat_c0) 90 | N, L, S, C = feat_c0.size(0), feat_c0.size( 91 | 1), feat_c1.size(1), feat_c0.size(2) 92 | 93 | # normalize 94 | feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, 95 | [feat_c0, feat_c1]) 96 | 97 | if SIM: 98 | # feat_c0 = F.normalize( 99 | # feat_c0, p=2.0, dim=1, eps=1e-12, out=None) 100 | # feat_c1 = F.normalize( 101 | # feat_c1, p=2.0, dim=1, eps=1e-12, out=None) 102 | #print("feat_c0", feat_c0.shape) 103 | norm_c0 = torch.unsqueeze(torch.norm(feat_c0, dim=2), 2) 104 | norm_c1 = torch.unsqueeze(torch.norm(feat_c1, dim=2), 2) 105 | #print("norm_c0", norm_c0.shape) 106 | feat_c0 = feat_c0 / norm_c0 107 | feat_c1 = feat_c1/norm_c1 108 | sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, 109 | feat_c1) / temperature 110 | else: # distnace matrix 111 | # Temperature helps with how soft the alignment should be. 112 | sim_matrix = -torch.cdist(feat_c0, feat_c1) 113 | sim_matrix /= temperature 114 | # Scale the distance by number of channels. This normalization helps with optimization. 115 | sim_matrix /= C 116 | # print("sim_matrix",sim_matrix.shape) 117 | # = torch.einsum("nlc,nsc->nls", feat_c0, 118 | # feat_c1) / temperature 119 | conf_matrix0 = F.softmax(sim_matrix, 1) 120 | conf_matrix1 = F.softmax(sim_matrix, 2) 121 | 122 | conf_matrix = conf_matrix0 * conf_matrix1 123 | return conf_matrix, conf_matrix0, conf_matrix1, sim_matrix 124 | 125 | 126 | 127 | def dual_bicross(feat_c0, feat_c1): 128 | # print("feat_c0", feat_c0) 129 | N, L, S, C = feat_c0.size(0), feat_c0.size( 130 | 1), feat_c1.size(1), feat_c0.size(2) 131 | 132 | # normalize 133 | feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, 134 | [feat_c0, feat_c1]) 135 | sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, 136 | feat_c1) 137 | sigmoid_f = nn.Sigmoid() 138 | conf_matrix = sigmoid_f(sim_matrix) 139 | return conf_matrix 140 | 141 | 142 | class Matching(nn.Module): 143 | def __init__(self, config): 144 | super().__init__() 145 | self.config = config 146 | # general config 147 | self.thr = config['thr'] 148 | self.border_rm = config['border_rm'] 149 | 150 | # we provide 2 options for differentiable matching 151 | self.match_type = config['match_type'] 152 | self.temperature = config['dsmax_temperature'] 153 | self.sim = config['similarity'] 154 | 155 | 156 | def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None): 157 | """ 158 | Args: 159 | feat0 (torch.Tensor): [N, L, C] 160 | feat1 (torch.Tensor): [N, S, C] 161 | data (dict) 162 | mask_c0 (torch.Tensor): [N, L] (optional) 163 | mask_c1 (torch.Tensor): [N, S] (optional) 164 | Update: 165 | data (dict): { 166 | 'b_ids' (torch.Tensor): [M'], 167 | 'i_ids' (torch.Tensor): [M'], 168 | 'j_ids' (torch.Tensor): [M'], 169 | 'gt_mask' (torch.Tensor): [M'], 170 | 'mkpts0_c' (torch.Tensor): [M, 2], 171 | 'mkpts1_c' (torch.Tensor): [M, 2], 172 | 'mconf' (torch.Tensor): [M]} 173 | NOTE: M' != M during training. 
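        With 'dual_softmax', the confidence matrix is the element-wise product of
        a softmax over dim 1 and a softmax over dim 2 of the similarity matrix, so
        a high confidence requires the pair to be preferred in both matching
        directions.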
174 | """ 175 | if self.match_type == 'dual_softmax': 176 | conf_matrix, conf_matrix0, conf_matrix1, sim_matrix = dual_softmax( 177 | feat_c0, feat_c1, self.temperature, SIM=self.sim) 178 | data.update({'conf_matrix': conf_matrix, 'conf_matrix0': conf_matrix0, 179 | 'conf_matrix1': conf_matrix1, 'sim_matrix': sim_matrix}) 180 | # predict coarse matches from conf_matrix 181 | data.update(**self.get_coarse_match(conf_matrix, data)) 182 | elif self.match_type == 'dual_bicross': 183 | conf_matrix = dual_bicross(feat_c0, feat_c1) 184 | data.update({'conf_matrix': conf_matrix}) 185 | # predict coarse matches from conf_matrix 186 | data.update(**self.get_coarse_match(conf_matrix, data)) 187 | 188 | 189 | @ torch.no_grad() 190 | def get_coarse_match(self, conf_matrix, data): 191 | """ 192 | Args: 193 | conf_matrix (torch.Tensor): [N, L, S] 194 | data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] 195 | Returns: 196 | coarse_matches (dict): { 197 | 'b_ids' (torch.Tensor): [M'], 198 | 'i_ids' (torch.Tensor): [M'], 199 | 'j_ids' (torch.Tensor): [M'], 200 | 'gt_mask' (torch.Tensor): [M'], 201 | 'm_bids' (torch.Tensor): [M], 202 | 'mkpts0_c' (torch.Tensor): [M, 2], 203 | 'mkpts1_c' (torch.Tensor): [M, 2], 204 | 'mconf' (torch.Tensor): [M]} 205 | """ 206 | axes_lengths = { 207 | 'len0': data['len_t0'], 208 | # 'w0c': data['hw0_c'][1], 209 | 'len1': data['len_t1'], 210 | # 'w1c': data['hw1_c'][1] 211 | } 212 | _device = conf_matrix.device 213 | # 1. confidence thresholding 214 | mask = conf_matrix > self.thr 215 | 216 | # 2. mutual nearest 217 | i_mask = conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0] 218 | j_mask = conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0] 219 | mask = mask * i_mask * j_mask 220 | 221 | # 3. find all valid coarse matches 222 | # this only works when at most one `True` in each row 223 | mask_v, all_j_ids = mask.max(dim=2) 224 | b_ids, i_ids = torch.where(mask_v) 225 | j_ids = all_j_ids[b_ids, i_ids] 226 | mconf = conf_matrix[b_ids, i_ids, j_ids] 227 | 228 | # These matches select patches that feed into fine-level network 229 | coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} 230 | 231 | # These matches is the current prediction (for visualization) 232 | coarse_matches.update({ 233 | 'gt_mask': mconf == 0, 234 | 'mask': mask, 235 | 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches 236 | 'mconf': mconf[mconf != 0], 237 | 'i_mask': i_mask, 238 | 'j_mask': j_mask, 239 | }) 240 | 241 | return coarse_matches 242 | -------------------------------------------------------------------------------- /dataset_preparation/preprocess_norm_mat.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from tqdm import tqdm 3 | import sys 4 | import numpy as np 5 | import math 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.mplot3d import Axes3D 8 | import torch 9 | from scipy.spatial.transform import Rotation as R 10 | 11 | sys.path.extend(['../']) 12 | '''IKEA ASM 13 | "nose", # 0 14 | "left eye", # 1 15 | "right eye", # 2 16 | "left ear", # 3 17 | "right ear", # 4 18 | "left shoulder", # 5 - center 19 | "right shoulder", # 6 20 | "left elbow", # 7 21 | "right elbow", # 8 22 | "left wrist", # 9 23 | "right wrist", # 10 24 | "left hip", # 11 25 | "right hip", # 12 26 | "left knee", # 13 27 | "right knee", # 14 28 | "left ankle", # 15 29 | "right ankle", # 16 30 | ''' 31 | 32 | 33 | def get_openpose_connectivity(): 34 | return [ 35 | [0, 1], 36 | [1, 2], 37 | [2, 3], 38 | [3, 4], 39 | [1, 5], 40 | 
[5, 6], 41 | [6, 7], 42 | [1, 8], 43 | [8, 9], 44 | [9, 10], 45 | [10, 11], 46 | [11, 24], 47 | [11, 22], 48 | [22, 23], 49 | [8, 12], 50 | [12, 13], 51 | [13, 14], 52 | [14, 21], 53 | [14, 19], 54 | [19, 20], 55 | [0, 15], 56 | [15, 17], 57 | [0, 16], 58 | [16, 18] 59 | ] 60 | 61 | 62 | def get_ikea_connectivity(): 63 | return [ 64 | [0, 1], 65 | [0, 2], 66 | [1, 3], 67 | [2, 4], 68 | [0, 5], 69 | [0, 6], 70 | [5, 6], 71 | [5, 7], 72 | [6, 8], 73 | [7, 9], 74 | [8, 10], 75 | [5, 11], 76 | [6, 12], 77 | [11, 12], 78 | [11, 13], 79 | [12, 14], 80 | [13, 15], 81 | [14, 16] 82 | ] 83 | 84 | 85 | def get_h2o_connectivity(): 86 | offset = 21 87 | return [ 88 | [1,2], 89 | [2,3], 90 | [3,4], 91 | [5,6], 92 | [6,7], 93 | [7,8], 94 | [9,10], 95 | [10,11], 96 | [11,12], 97 | 98 | [13,14], 99 | [14,15], 100 | [15,16], 101 | [17,18], 102 | [18,19], 103 | [19,20], 104 | 105 | [0,1], 106 | [0,5], 107 | [0,9], 108 | [0,13], 109 | [0,17], 110 | 111 | [1+offset,2+offset], 112 | [2+offset,3+offset], 113 | [3+offset,4+offset], 114 | [5+offset,6+offset], 115 | [6+offset,7+offset], 116 | [7+offset,8+offset], 117 | [9+offset,10+offset], 118 | [10+offset,11+offset], 119 | [11+offset,12+offset], 120 | 121 | [13+offset,14+offset], 122 | [14+offset,15+offset], 123 | [15+offset,16+offset], 124 | [17+offset,18+offset], 125 | [18+offset,19+offset], 126 | [19+offset,20+offset], 127 | 128 | [0+offset,1+offset], 129 | [0+offset,5+offset], 130 | [0+offset,9+offset], 131 | [0+offset,13+offset], 132 | [0+offset,17+offset], 133 | 134 | 135 | ] 136 | 137 | 138 | def rotation_matrix(axis, theta): 139 | """ 140 | Return the rotation matrix associated with counterclockwise rotation about 141 | the given axis by theta radians. 142 | """ 143 | if np.abs(axis).sum() < 1e-6 or np.abs(theta) < 1e-6: 144 | return np.eye(3) 145 | axis = np.asarray(axis) 146 | axis = axis / math.sqrt(np.dot(axis, axis)) 147 | a = math.cos(theta / 2.0) 148 | b, c, d = -axis * math.sin(theta / 2.0) 149 | aa, bb, cc, dd = a * a, b * b, c * c, d * d 150 | bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d 151 | return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], 152 | [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], 153 | [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) 154 | 155 | 156 | def unit_vector(vector): 157 | """ Returns the unit vector of the vector. """ 158 | # print("np.linalg.norm(vector)", np.linalg.norm(vector)) 159 | return vector / np.linalg.norm(vector) 160 | 161 | 162 | def angle_between(v1, v2): 163 | """ Returns the angle in radians between vectors 'v1' and 'v2':: 164 | >>> angle_between((1, 0, 0), (0, 1, 0)) 165 | 1.5707963267948966 166 | >>> angle_between((1, 0, 0), (1, 0, 0)) 167 | 0.0 168 | >>> angle_between((1, 0, 0), (-1, 0, 0)) 169 | 3.141592653589793 170 | """ 171 | if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6: 172 | return 0 173 | # print("v1", v1) 174 | v1_u = unit_vector(v1) 175 | v2_u = unit_vector(v2) 176 | # print("v1_u", v1_u) 177 | # print("v2_u", v2_u) 178 | # print("np.dot(v1_u, v2_u)", np.dot(v1_u, v2_u)) 179 | return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) 180 | 181 | 182 | def unit_vector_mat(vector): 183 | """ Returns the unit vector of the vector. 
""" 184 | # print("matamt") 185 | # print("np.linalg.norm(vector)", np.linalg.norm(vector, axis=1)) 186 | # print("np.linalg.norm(vector, axis=1)[0]", 187 | # np.linalg.norm(vector, axis=1)) 188 | return (vector.T/np.linalg.norm(vector, axis=1).T).T 189 | 190 | 191 | def angle_between_mat(v1, v2): 192 | """ Returns the angle in radians between vectors 'v1' and 'v2':: 193 | >>> angle_between((1, 0, 0), (0, 1, 0)) 194 | 1.5707963267948966 195 | >>> angle_between((1, 0, 0), (1, 0, 0)) 196 | 0.0 197 | >>> angle_between((1, 0, 0), (-1, 0, 0)) 198 | 3.141592653589793 199 | """ 200 | if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6: 201 | return 0 202 | # print("v1", v1) 203 | v1_u = unit_vector_mat(v1) 204 | v2_u = unit_vector_mat(v2) 205 | 206 | # print("v1_u", v1_u) 207 | # print("v2_u", v2_u) 208 | # print("np.multiply(v1_u, v2_u)", np.sum(np.multiply(v1_u, v2_u), axis=1)) 209 | return np.arccos(np.clip(np.sum(np.multiply(v1_u, v2_u), axis=1), -1.0, 1.0)) 210 | 211 | 212 | def x_rotation(vector, theta): 213 | """Rotates 3-D vector around x-axis""" 214 | R = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], 215 | [0, np.sin(theta), np.cos(theta)]]) 216 | return np.dot(R, vector) 217 | 218 | 219 | def y_rotation(vector, theta): 220 | """Rotates 3-D vector around y-axis""" 221 | R = np.array([[np.cos(theta), 0, np.sin(theta)], [ 222 | 0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]) 223 | return np.dot(R, vector) 224 | 225 | 226 | def z_rotation(vector, theta): 227 | """Rotates 3-D vector around z-axis""" 228 | R = np.array([[np.cos(theta), -np.sin(theta), 0], 229 | [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) 230 | return np.dot(R, vector) 231 | 232 | # ikea 5,11 / 5,6 openpose 1,8/1,5 233 | def pre_normalization_mat(data, zaxis=[5, 11], xaxis=[5, 6], NORM_BONE=True,ERR_BORN=False): 234 | """1. Normalize 1st frame. 235 | 2. Normalize every frame. 236 | 3. Normalize every frame but leave the first location to give more information. 237 | """ 238 | VIS = False 239 | 240 | # print("data.shape", np.shape(data)) # (993, 17, 3) 241 | data = np.array(data)[:, :, :3] 242 | N, K, C = data.shape # (993, 17, 3) 243 | 244 | # print('sub the center joint #5 (left shoulder)') 245 | # for i_s, skeleton in enumerate(s): 246 | 247 | # if np.sum(skeleton[zaxis[0]]) == 0: 248 | # continue 249 | # Instaed of setting the center as #5, we can also cal mean vals. 
250 | 251 | main_body_center = data[:, zaxis[0]:zaxis[0]+1, :] 252 | data = data - main_body_center 253 | 254 | # print('parallel the edge between left shoulder(5) and right shoulder(6) to the x axis') 255 | joint_bottom = data[:, xaxis[0]] 256 | joint_top = data[:, xaxis[1]] 257 | joint_rshoulder = data[:, xaxis[0]] 258 | joint_lshoulder = data[:, xaxis[1]] 259 | axis = np.cross(joint_rshoulder - joint_lshoulder, [1, 0, 0]) 260 | input_axis = np.tile([1, 0, 0], (N, 1)) 261 | angle = angle_between_mat(joint_rshoulder - joint_lshoulder, input_axis) 262 | #print("angle", angle) 263 | # r = R.from_rotvec((input_axis.T*angle.T).T) 264 | 265 | for ii in range(N): 266 | matrix_x = rotation_matrix(axis[ii], angle[ii]) 267 | data[ii] = np.dot(matrix_x, data[ii].T).T 268 | 269 | # print('parallel the edge between left shoulder(5) and left heap(11) to the z axis') 270 | joint_bottom = data[:, zaxis[0]] 271 | joint_top = data[:, zaxis[1]] 272 | # oint_rshoulder = data[:, xaxis[0]] 273 | # joint_lshoulder = data[:, xaxis[1]] 274 | axis = np.cross(joint_top - joint_bottom, [0, 0, 1]) 275 | input_axis = np.tile([0, 0, 1], (N, 1)) 276 | angle = angle_between_mat(joint_top - joint_bottom, input_axis) 277 | # print("angle", angle) 278 | # r = R.from_rotvec((input_axis.T*angle.T).T) 279 | 280 | for ii in range(N): 281 | matrix_z = rotation_matrix(axis[ii], angle[ii]) 282 | data[ii] = np.dot(matrix_z, data[ii].T).T 283 | 284 | # Normalize bones 285 | if NORM_BONE: 286 | if ERR_BORN: 287 | s = copy.deepcopy(data) 288 | for i_s, skeleton in enumerate(s): 289 | bone_length = np.linalg.norm( 290 | skeleton[zaxis[0]] - skeleton[zaxis[1]]) 291 | if (np.sum(skeleton) == 0) or (bone_length == 0): 292 | continue 293 | 294 | data = s 295 | else: 296 | 297 | bone_length = np.linalg.norm( 298 | data[:, zaxis[0]] - data[:, zaxis[1]], axis=1) 299 | data = data/np.tile(np.expand_dims(bone_length, 300 | axis=(1, 2)), (1, K, C)) 301 | # print("bone_length", bone_length) 302 | 303 | # print("bone_length", bone_length) 304 | # print("s[i_s] before", s[i_s]) 305 | # s[i_s] = s[i_s]/bone_length 306 | # print("s[i_s] after", s[i_s]) 307 | 308 | 309 | # print("np.expand_dims(bone_lengthaxis(1, 2)", np.expand_dims(bone_length, 310 | # axis=(1, 2))) 311 | 312 | #print("s", s) 313 | #print("data", data) 314 | 315 | # if VIS: 316 | # for i_s, skeleton in enumerate(s): 317 | # fig = plt.figure() 318 | # ax = plt.axes(projection='3d') 319 | # ax.set_xlim3d(-1, 1) 320 | # ax.set_ylim3d(-1, 1) 321 | # ax.set_zlim3d(-1, 1) 322 | # connectivity = get_openpose_connectivity() 323 | # for limb in connectivity: 324 | # ax.plot3D(s[i_s, limb, 0], 325 | # s[i_s, limb, 1], s[i_s, limb, 2]) 326 | # ax.scatter3D(s[i_s, :, 0], s[i_s, :, 1], 327 | # s[i_s, :, 2], cmap='Greens') 328 | # plt.show(block=False) 329 | # plt.pause(0.5) 330 | # plt.close() 331 | 332 | # print("s", np.shape(s)) 333 | return data 334 | 335 | 336 | if __name__ == '__main__': 337 | data = np.load('../data/ntu/xview/val_data.npy') 338 | pre_normalization(data) 339 | np.save('../data/ntu/xview/data_val_pre.npy', data) 340 | -------------------------------------------------------------------------------- /src/lightning/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from collections import abc 4 | from loguru import logger 5 | from torch.utils.data.dataset import Dataset 6 | from tqdm import tqdm 7 | from os import path as osp 8 | from pathlib import Path 9 | from joblib import Parallel, delayed 10 | from 
functools import partial 11 | 12 | import pytorch_lightning as pl 13 | from torch import distributed as dist 14 | from torch.utils.data import ( 15 | Dataset, 16 | DataLoader, 17 | ConcatDataset, 18 | DistributedSampler, 19 | ChainDataset, 20 | RandomSampler, 21 | dataloader 22 | ) 23 | 24 | from src.utils.augment import build_augmentor 25 | from src.utils.dataloader import get_local_split, collate_fixed_len 26 | from src.utils.misc import tqdm_joblib 27 | from src.datasets.pennaction import PennActionDataset 28 | 29 | 30 | class MultiSceneDataModule(pl.LightningDataModule): 31 | """ 32 | For distributed training, each training process is assgined 33 | only a part of the training scenes to reduce memory overhead. 34 | """ 35 | 36 | def __init__(self, args, config): 37 | super().__init__() 38 | # 1. data config 39 | # Train and Val should from the same data source 40 | self.config = config 41 | self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE 42 | self.test_data_source = config.DATASET.TEST_DATA_SOURCE 43 | # training and validating 44 | 45 | self.train_data_root = config.DATASET.TRAIN_DATA_ROOT 46 | self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional) 47 | self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT 48 | self.train_list_path = config.DATASET.TRAIN_LIST_PATH 49 | self.val_data_root = config.DATASET.VAL_DATA_ROOT 50 | self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional) 51 | self.val_npz_root = config.DATASET.VAL_NPZ_ROOT 52 | self.val_list_path = config.DATASET.VAL_LIST_PATH 53 | self.val_batch_size = config.DATASET.VAL_BATCH_SIZE 54 | # testing 55 | self.test_data_root = config.DATASET.TEST_DATA_ROOT 56 | self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional) 57 | self.test_npz_root = config.DATASET.TEST_NPZ_ROOT 58 | self.test_list_path = config.DATASET.TEST_LIST_PATH 59 | 60 | self.sampling_strategy = config.DATASET.SAMPLING_STRATEGY 61 | self.num_frames = config.DATASET.NUM_FRAMES 62 | # 2. dataset config 63 | self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) 64 | self.use_norm = config.DATASET.USE_NORM 65 | 66 | # 0.125. for training casa. 67 | self.coarse_scale = 1 / config.CASA.RESOLUTION[0] 68 | 69 | self.contrastive = config.CONSTRASTIVE.TRAIN 70 | self.augmentation_strategy = config.CONSTRASTIVE.AUGMENTATION_STRATEGY 71 | 72 | # 3.loader parameters 73 | self.train_loader_params = { 74 | 'batch_size': args.batch_size, 75 | 'num_workers': args.num_workers, 76 | 'pin_memory': getattr(args, 'pin_memory', True) 77 | } 78 | self.val_loader_params = { 79 | 'batch_size': self.val_batch_size, # 1 80 | 'shuffle': False, 81 | 'num_workers': args.num_workers, 82 | 'pin_memory': getattr(args, 'pin_memory', True) 83 | } 84 | self.test_loader_params = { 85 | 'batch_size': self.val_batch_size, 86 | 'shuffle': False, 87 | 'num_workers': args.num_workers, 88 | 'pin_memory': True 89 | } 90 | 91 | 92 | 93 | # (optional) RandomSampler for debugging 94 | 95 | # misc configurations 96 | self.parallel_load_data = getattr(args, 'parallel_load_data', False) 97 | self.seed = config.TRAINER.SEED # 66 98 | 99 | def setup(self, stage=None): 100 | """ 101 | Setup train / val / test dataset. This method will be called by PL automatically. 102 | Args: 103 | stage (str): 'fit' in training phase, and 'test' in testing phase. 
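            The 'predict' stage is also accepted and, like 'fit', builds both the
            train and val datasets so embeddings can be extracted for evaluation.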
104 | """ 105 | 106 | assert stage in ['fit', 'test', 107 | 'predict'], "stage must be either fit or test" 108 | 109 | try: 110 | self.world_size = dist.get_world_size() 111 | self.rank = dist.get_rank() 112 | logger.info(f"[rank:{self.rank}] world_size: {self.world_size}") 113 | except AssertionError as ae: 114 | self.world_size = 1 115 | self.rank = 0 116 | logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)") 117 | 118 | if stage == 'fit' or stage == 'predict': 119 | self.train_dataset = self._setup_dataset( 120 | self.train_data_root, 121 | self.train_npz_root, 122 | self.train_list_path, 123 | mode='train', 124 | pose_dir=self.train_pose_root) 125 | 126 | # setup multiple (optional) validation subsets 127 | if isinstance(self.val_list_path, (list, tuple)): 128 | self.val_dataset = [] 129 | if not isinstance(self.val_npz_root, (list, tuple)): 130 | self.val_npz_root = [ 131 | self.val_npz_root for _ in range(len(self.val_list_path))] 132 | for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root): 133 | self.val_dataset.append(self._setup_dataset( 134 | self.val_data_root, 135 | npz_root, 136 | npz_list, 137 | mode='val', 138 | pose_dir=self.val_pose_root)) 139 | else: 140 | self.val_dataset = self._setup_dataset( 141 | self.val_data_root, 142 | self.val_npz_root, 143 | self.val_list_path, 144 | mode='val', 145 | pose_dir=self.val_pose_root) 146 | logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!') 147 | 148 | else: # stage == 'test 149 | self.test_dataset = self._setup_dataset( 150 | self.test_data_root, 151 | self.test_npz_root, 152 | self.test_list_path, 153 | mode='test', 154 | pose_dir=self.test_pose_root) 155 | logger.info(f'[rank:{self.rank}]: Test Dataset loaded!') 156 | 157 | def _setup_dataset(self, 158 | data_root, 159 | split_npz_root, 160 | scene_list_path, 161 | mode='train', 162 | pose_dir=None): 163 | """ Setup train / val / test set""" 164 | 165 | dataset_builder = self._build_concat_dataset 166 | local_npz_names = data_root 167 | return dataset_builder(data_root, local_npz_names, split_npz_root, 168 | mode=mode, pose_dir=pose_dir) 169 | 170 | def _build_concat_dataset( 171 | self, 172 | data_root, 173 | npz_names, 174 | npz_dir, 175 | mode, 176 | pose_dir=None 177 | ): 178 | 179 | augment_fn = self.augment_fn if mode == 'train' else None 180 | data_source = self.trainval_data_source if mode in [ 181 | 'train'] else self.test_data_source 182 | 183 | npz_path = data_root # osp.join(npz_dir, npz_name) 184 | # print("npz_path",npz_path) 185 | if mode in ['train']: 186 | datasets = [] 187 | if data_source == 'PennAction': 188 | datasets.append( 189 | PennActionDataset(npz_path, 190 | num_frames=self.num_frames, 191 | sampling_strategy=self.sampling_strategy, 192 | mode=mode, 193 | augment_fn=augment_fn, 194 | coarse_scale=self.coarse_scale, 195 | contrastive=self.contrastive, 196 | augmentation_strategy=self.augmentation_strategy, 197 | use_norm=self.use_norm, 198 | config=self.config)) 199 | else: 200 | raise NotImplementedError() 201 | else: 202 | datasets = [] 203 | 204 | if data_source == 'PennAction': 205 | print("npz_path", npz_path) 206 | 207 | 208 | 209 | for npz_path_elem in npz_path: 210 | # Append train and validation sets together. 
211 | datasets.append(PennActionDataset(npz_path_elem, 212 | num_frames=self.num_frames, 213 | sampling_strategy=self.sampling_strategy, 214 | mode=mode, 215 | augment_fn=augment_fn, 216 | coarse_scale=self.coarse_scale, 217 | contrastive=self.contrastive, 218 | val=True, 219 | augmentation_strategy=self.augmentation_strategy, 220 | use_norm=self.use_norm, 221 | config=self.config)) 222 | 223 | else: 224 | raise NotImplementedError() 225 | return ConcatDataset(datasets) 226 | 227 | def train_dataloader(self): 228 | logger.info( 229 | f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).') 230 | 231 | dataloader = DataLoader( 232 | self.train_dataset, shuffle=True, **self.train_loader_params) 233 | print("self.train_dataset", len(self.train_dataset)) 234 | print("dataloader.__len__", len(dataloader.dataset)) 235 | return dataloader 236 | 237 | def val_dataloader(self): 238 | logger.info( 239 | f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.') 240 | dataloader = DataLoader( 241 | self.val_dataset, **self.val_loader_params) 242 | 243 | print("self.val_taset", len(self.val_dataset)) 244 | print("dataloader.__len__", len(dataloader.dataset)) 245 | return dataloader 246 | 247 | def predict_dataloader(self): 248 | """ Build validation dataloader for H2O/Penn Action/IKEA ASM. """ 249 | logger.info( 250 | f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.') 251 | dataloader = DataLoader( 252 | self.val_dataset, **self.val_loader_params) 253 | 254 | print("predict datset length", len(self.val_dataset)) 255 | print("dataloader.__len__", len(dataloader.dataset)) 256 | return dataloader 257 | 258 | def test_dataloader(self, *args, **kwargs): 259 | logger.info( 260 | f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.') 261 | sampler = DistributedSampler(self.test_dataset, shuffle=False) 262 | return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params) 263 | 264 | 265 | def _build_dataset(dataset: Dataset, *args, **kwargs): 266 | return dataset(*args, **kwargs) 267 | -------------------------------------------------------------------------------- /License: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/lightning/lightning_casa.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import defaultdict 3 | import pprint 4 | from loguru import logger 5 | from pathlib import Path 6 | import copy 7 | import torch 8 | import numpy as np 9 | import pytorch_lightning as pl 10 | from matplotlib import pyplot as plt 11 | from src.evaluation.event_completion import EventCompletion 12 | from src.evaluation.classification import Classification 13 | from src.evaluation.kendalls_tau import KendallsTau 14 | # from src.utils.classification import Classification 15 | import os 16 | 17 | from src.casa.utils.matching import dual_softmax, dual_bicross 18 | 19 | from src.casa import CASA 20 | from src.casa.utils.supervision import compute_supervision_coarse 21 | from src.losses.casa_loss import CASALoss 22 | from src.optimizers import build_optimizer, build_scheduler 23 | from src.utils.plotting import vis_conf_matrix 24 | from src.utils.misc import lower_config 25 | from src.utils.profiler import PassThroughProfiler 26 | 27 | 28 | class PL_CASA(pl.LightningModule): 29 | def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None): 30 | """ 31 | TODO: 32 | - use the new version of PL logging API. 33 | """ 34 | super().__init__() 35 | 36 | # Misc 37 | self.config = config # full config 38 | _config = lower_config(self.config) 39 | self.casa_cfg = lower_config(_config['casa']) 40 | # CLASSIFICATION.ACC_LIST 41 | self.acc_list = _config['classification']['acc_list'] 42 | self.profiler = profiler or PassThroughProfiler() 43 | # print("_config",_config) 44 | self.vis_conf_train = _config['casa']['match']['vis_conf_train'] 45 | self.vis_conf_val = _config['casa']['match']['vis_conf_validation'] 46 | # Matcher: CASA 47 | self.matcher = CASA(config=_config['casa']) 48 | self.loss = CASALoss(_config) 49 | self.temperature = _config['casa']['match']['dsmax_temperature'] 50 | self.thr = _config['casa']['match']['thr'] 51 | self.match_type = _config['casa']['match']['match_type'] 52 | # Pretrained weights 53 | if pretrained_ckpt: 54 | state_dict = torch.load(pretrained_ckpt, map_location='cpu')[ 55 | 'state_dict'] 56 | self.matcher.load_state_dict(state_dict, strict=True) 57 | logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint") 58 | 59 | # Testing 60 | self.dump_dir = dump_dir 61 | self.classification = Classification(self.config) 62 | self.eventcompletion = EventCompletion(self.config) 63 | self.kendallstau = KendallsTau(self.config) 64 | 65 | def configure_optimizers(self): 66 | # FIXME: The scheduler did not work properly when `--resume_from_checkpoint` 67 | optimizer = build_optimizer(self, self.config) 68 | scheduler = build_scheduler(self.config, optimizer) 69 | return [optimizer], [scheduler] 70 | 71 | def optimizer_step( 72 | self, epoch, batch_idx, optimizer, optimizer_idx, 73 | optimizer_closure, on_tpu, using_native_amp, using_lbfgs): 74 | # learning rate warm up 75 | warmup_step = self.config.TRAINER.WARMUP_STEP 76 | if self.trainer.global_step < warmup_step: 77 | if self.config.TRAINER.WARMUP_TYPE == 'linear': 78 | base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR 79 | lr = base_lr + \ 80 | (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \ 81 | abs(self.config.TRAINER.TRUE_LR - base_lr) 82 | for pg in optimizer.param_groups: 83 | pg['lr'] = lr 84 | elif self.config.TRAINER.WARMUP_TYPE == 
'constant': 85 | pass 86 | else: 87 | raise ValueError( 88 | f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}') 89 | 90 | # update params 91 | optimizer.step(closure=optimizer_closure) 92 | optimizer.zero_grad() 93 | 94 | def _trainval_inference(self, batch): 95 | # print("batch.shape",batch) 96 | with self.profiler.profile("Compute coarse supervision"): 97 | compute_supervision_coarse(batch, self.config) 98 | 99 | with self.profiler.profile("CASA"): 100 | self.matcher(batch) 101 | 102 | 103 | with self.profiler.profile("Compute losses"): 104 | self.loss(batch) 105 | 106 | def _valtest_inference(self, batch): 107 | with self.profiler.profile("CASA"): 108 | self.matcher(batch, False) 109 | 110 | def _test_inference(self, batch): 111 | 112 | with self.profiler.profile("CASA"): 113 | self.matcher(batch, True) 114 | 115 | 116 | def _compute_metrics(self, batch): 117 | with self.profiler.profile("Copmute metrics"): 118 | # compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match 119 | # compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair 120 | 121 | rel_pair_names = list(zip(*batch['pair_names'])) 122 | bs = batch['keypoints0'].size(0) 123 | metrics = { 124 | # to filter duplicate pairs caused by DistributedSampler 125 | 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)]} 126 | # 'inliers': batch['inliers']} 127 | ret_dict = {'metrics': metrics} 128 | return ret_dict, rel_pair_names 129 | 130 | def on_train_epoch_start(self): 131 | pass 132 | 133 | 134 | def training_step(self, batch, batch_idx): 135 | self._trainval_inference(batch) 136 | return {'loss': batch['loss']} 137 | 138 | def training_epoch_end(self, outputs): 139 | 140 | avg_loss = torch.stack([x['loss'] for x in outputs]).mean() 141 | if self.trainer.global_rank == 0: 142 | self.logger.experiment.add_scalar( 143 | 'train/avg_loss_on_epoch', avg_loss, 144 | global_step=self.current_epoch) 145 | if self.vis_conf_train: 146 | target_value = 0 147 | for ii in range(len(outputs)): 148 | index_id = (outputs[ii]['pair_id'] == 149 | target_value).nonzero(as_tuple=True) 150 | # print("index_id[0]",index_id[0]) 151 | if len(index_id[0]): 152 | batch_num = ii 153 | index_in_batch = index_id 154 | conf_matrix = outputs[batch_num]['conf_matrix'][index_in_batch].cpu().detach().numpy()[ 155 | 0] 156 | # print("conf_matrix",np.shape(conf_matrix)) 157 | if not os.path.exists("{}/vis".format(self.logger.log_dir)): 158 | os.mkdir("{}/vis".format(self.logger.log_dir)) 159 | if not os.path.exists("{}/vis/conf_matrix_train".format(self.logger.log_dir)): 160 | os.mkdir("{}/vis/conf_matrix_train".format(self.logger.log_dir)) 161 | save_path = "{0}/vis/conf_matrix_train/{1:06d}.png".format( 162 | self.logger.log_dir, self.current_epoch) 163 | vis_conf_matrix(conf_matrix, save_path) 164 | 165 | def validation_step(self, batch, batch_idx): 166 | self._valtest_inference(batch) 167 | 168 | 169 | return {'emb0': batch['emb0'], 'emb1': batch['emb0'], 'len0': batch['len0'], 'len1': batch['len0'], 'label0': batch['label0'], 170 | 'label1': batch['label0'], 'mode': batch['mode'], 'pair_names': batch['pair_names']} 171 | 172 | 173 | def predict_step(self, batch, batch_idx): 174 | self._valtest_inference(batch) 175 | return {'emb0': batch['emb0'], 'label0': batch['label0'], 'steps0': batch['steps0'], 'len0': batch['len0'], 176 | 'mode': batch['mode'], 'keypoints0': batch['keypoints0'], 'pair_names': batch['pair_names']} 177 | 178 | def validation_epoch_end(self, outputs): 179 | # 
handle multiple validation sets 180 | 181 | MEAN_EMB = True 182 | if MEAN_EMB: 183 | train_embs = {} 184 | val_embs = {} 185 | train_labels = {} 186 | val_labels = {} 187 | else: 188 | train_embs = [] 189 | val_embs = [] 190 | train_labels = [] 191 | val_labels = [] 192 | # print("outputs",outputs) 193 | 194 | for output in outputs: 195 | emb0 = output['emb0'].cpu().detach().numpy() 196 | emb1 = output['emb1'].cpu().detach().numpy() 197 | label0 = output['label0'].cpu().detach().numpy() 198 | label1 = output['label1'].cpu().detach().numpy() 199 | len0 = output['len0'].cpu().detach().numpy() 200 | len1 = output['len1'].cpu().detach().numpy() 201 | 202 | len_data = len(output['pair_names']) 203 | # print("len_data",len_data) 204 | 205 | for ii in range(len_data): 206 | 207 | if self.config.DATASET.NAME == "kallax_shelf_drawer": 208 | key1 = output['pair_names'][ii] 209 | key2 = -1 210 | else: 211 | # key1 = int(output['pair_names'][ii]) 212 | key1 = output['pair_names'][ii] 213 | key2 = -1 214 | if MEAN_EMB: 215 | 216 | if output['mode'][ii] == 'train': 217 | if key1 in train_embs.keys(): 218 | train_embs[key1].append(emb0[ii][:len0[ii]]) 219 | train_labels[key1].append( 220 | label0[ii][:len0[ii]]) 221 | else: 222 | train_embs[key1] = [emb0[ii][:len0[ii]]] 223 | train_labels[key1] = [label0[ii][:len0[ii]]] 224 | 225 | if key2 in train_embs.keys(): 226 | train_embs[key2].append(emb1[ii][:len1[ii]]) 227 | train_labels[key2].append( 228 | label1[ii][:len1[ii]]) 229 | elif key2 is not -1: 230 | train_embs[key2] = [emb1[ii][:len1[ii]]] 231 | train_labels[key2] = [label1[ii][:len1[ii]]] 232 | 233 | elif output['mode'][ii] == 'val': 234 | if key1 in val_embs.keys(): 235 | val_embs[key1].append(emb0[ii][:len0[ii]]) 236 | val_labels[key1].append(label0[ii][:len0[ii]]) 237 | else: 238 | val_embs[key1] = [emb0[ii][:len0[ii]]] 239 | val_labels[key1] = [label0[ii][:len0[ii]]] 240 | 241 | if key2 in val_embs.keys(): 242 | val_embs[key2].append(emb1[ii][:len1[ii]]) 243 | val_labels[key2].append(label1[ii][:len1[ii]]) 244 | elif key2 is not -1: 245 | val_embs[key2] = [emb1[ii][:len1[ii]]] 246 | val_labels[key2] = [label1[ii][:len1[ii]]] 247 | 248 | else: 249 | if output['mode'][ii] == 'train': 250 | # print("train_hi") 251 | train_embs.append(emb0[ii]) 252 | train_labels.append(label0[ii]) 253 | train_embs.append(emb1[ii]) 254 | train_labels.append(label1[ii]) 255 | elif output['mode'][ii] == 'val': 256 | val_embs.append(emb0[ii]) 257 | val_labels.append(label0[ii]) 258 | val_embs.append(emb1[ii]) 259 | val_labels.append(label1[ii]) 260 | # print() 261 | datasets = {'train_dataset': {'embs': train_embs, 'labels': train_labels}, 262 | 'val_dataset': {'embs': val_embs, 'labels': val_labels}} 263 | print(self.profiler.summary()) 264 | # datasets_event = copy.deepcopy(datasets) 265 | (train_accs, val_accs) = self.classification.evaluate_embeddings( 266 | datasets, emb_mean=False, DICT=True, acc_list=self.acc_list) 267 | if self.config.EVAL.EVENT_COMPLETION: 268 | train_completion_score, val_completion_score = self.eventcompletion.evaluate_embeddings( 269 | datasets, emb_mean=False, DICT=True) 270 | 271 | if self.config.EVAL.KENDALLS_TAU: 272 | datasets_pair = {} 273 | train_dataset = [] 274 | val_dataset = [] 275 | datasets_pair['train_dataset'] = {} 276 | datasets_pair['val_dataset'] = {} 277 | for output in outputs: 278 | # print("output['mode']", output['mode'][0]) 279 | # print("output['pair_names'][",output['pair_names']) 280 | emb0 = output['emb0'].cpu().detach().numpy() 281 | emb1 = 
output['emb1'].cpu().detach().numpy() 282 | label0 = output['label0'].cpu().detach().numpy() 283 | label1 = output['label1'].cpu().detach().numpy() 284 | len0 = output['len0'].cpu().detach().numpy() 285 | len1 = output['len1'].cpu().detach().numpy() 286 | len_data = len(output['pair_names']) 287 | for ii in range(len_data): 288 | if output['mode'][ii] == 'train': 289 | train_dataset.append( 290 | [emb0[ii][:len0[ii]], emb1[ii][:len1[ii]]]) 291 | else: 292 | val_dataset.append( 293 | [emb0[ii][:len0[ii]], emb1[ii][:len1[ii]]]) 294 | datasets_pair['train_dataset']['embs'] = train_dataset 295 | datasets_pair['val_dataset']['embs'] = val_dataset 296 | 297 | 298 | (train_tau, val_tau) = self.kendallstau.evaluate_embeddings( 299 | datasets) 300 | # print("(train_tau, val_tau)", (train_tau, val_tau)) 301 | 302 | if val_accs != 0: 303 | acc_list = self.acc_list 304 | for ii, train_acc in enumerate(train_accs): 305 | self.logger.experiment.add_scalar('classification/train_{}_accuracy'.format(acc_list[ii]), 306 | train_acc, global_step=self.current_epoch) 307 | for ii, val_acc in enumerate(val_accs): 308 | self.logger.experiment.add_scalar('classification/val_{}_accuracy'.format(acc_list[ii]), 309 | val_acc, global_step=self.current_epoch) 310 | if self.config.EVAL.KENDALLS_TAU: 311 | self.logger.experiment.add_scalar( 312 | 'kendalls_tau/train', train_tau, global_step=self.current_epoch) 313 | self.logger.experiment.add_scalar( 314 | 'kendalls_tau/val', val_tau, global_step=self.current_epoch) 315 | if self.config.EVAL.EVENT_COMPLETION: 316 | self.logger.experiment.add_scalar( 317 | 'event_progress/train', train_completion_score, global_step=self.current_epoch) 318 | self.logger.experiment.add_scalar( 319 | 'event_progress/val', val_completion_score, global_step=self.current_epoch) 320 | 321 | # Visualize matrix 322 | if self.vis_conf_val: 323 | # Calculate conf_matrix for visualization 324 | if self.match_type == "dual_softmax": 325 | conf_matrix, _, _ = dual_softmax(torch.tensor(np.expand_dims(datasets['train_dataset']['embs'][0], axis=0)), 326 | torch.tensor(np.expand_dims( 327 | datasets['train_dataset']['embs'][1], axis=0)), self.temperature) 328 | elif self.match_type == "dual_bicross": 329 | conf_matrix = dual_bicross(torch.tensor(np.expand_dims(datasets['train_dataset']['embs'][0], axis=0)), 330 | torch.tensor(np.expand_dims( 331 | datasets['train_dataset']['embs'][1], axis=0))) 332 | # make folders 333 | if not os.path.exists("{}/vis".format(self.logger.log_dir)): 334 | os.mkdir("{}/vis".format(self.logger.log_dir)) 335 | if not os.path.exists("{}/vis/conf_matrix_val".format(self.logger.log_dir)): 336 | os.mkdir("{}/vis/conf_matrix_val".format(self.logger.log_dir)) 337 | 338 | save_path = "{0}/vis/conf_matrix_val/{1:06d}.png".format( 339 | self.logger.log_dir, self.current_epoch) 340 | 341 | vis_conf_matrix(conf_matrix.cpu().detach().numpy()[0], save_path) 342 | 343 | mask = conf_matrix > 0.0 344 | 345 | if self.match_type == "dual_softmax": 346 | i_mask = conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0] 347 | j_mask = conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0] 348 | mask = mask * i_mask * j_mask 349 | elif self.match_type == "dual_bicross": 350 | thres = 0.5 351 | mask = conf_matrix > thres 352 | 353 | if not os.path.exists("{}/vis/mask_train".format(self.logger.log_dir)): 354 | os.mkdir("{}/vis/mask_train".format(self.logger.log_dir)) 355 | save_path_mask = "{0}/vis/mask_train/{1:06d}.png".format(self.logger.log_dir, 356 | self.current_epoch) 357 | 358 | 
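# The mask saved below is the binarized match matrix computed just above: for
# dual_softmax it keeps only entries that are maxima of both their row and
# their column of the confidence matrix, for dual_bicross it keeps entries
# above the 0.5 threshold.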
vis_conf_matrix(mask.cpu().detach().numpy()[ 359 | 0].astype(float), save_path_mask) 360 | 361 | # Todo: create matching from both videos and visualize it. 362 | 363 | def test_step(self, batch, batch_idx): 364 | self._test_inference(batch) 365 | return {'emb0': batch['emb0'], 'emb1': batch['emb1'], 'len0': batch['len0'], 'len1': batch['len1'], 'label0': batch['label0'], 366 | 'label1': batch['label1'], 'mode': batch['mode'], 'pair_names': batch['pair_names']} 367 | 368 | def test_epoch_end(self, outputs): 369 | # metrics: dict of list, numpy 370 | MEAN_EMB = True 371 | if MEAN_EMB: 372 | train_embs = {} 373 | val_embs = {} 374 | train_labels = {} 375 | val_labels = {} 376 | else: 377 | train_embs = [] 378 | val_embs = [] 379 | train_labels = [] 380 | val_labels = [] 381 | 382 | for output in outputs: 383 | emb0 = output['emb0'].cpu().detach().numpy() 384 | emb1 = output['emb1'].cpu().detach().numpy() 385 | label0 = output['label0'].cpu().detach().numpy() 386 | label1 = output['label1'].cpu().detach().numpy() 387 | len0 = output['len0'].cpu().detach().numpy() 388 | len1 = output['len1'].cpu().detach().numpy() 389 | 390 | len_data = len(output['pair_names']) 391 | # print("len_data",len_data) 392 | 393 | for ii in range(len_data): 394 | 395 | # key1 = int(output['pair_names'][ii]) 396 | key1 = output['pair_names'][ii] 397 | key2 = -1 398 | if MEAN_EMB: 399 | if output['mode'][ii] == 'train': 400 | if key1 in train_embs.keys(): 401 | train_embs[key1].append(emb0[ii][:len0[ii]]) 402 | train_labels[key1].append( 403 | label0[ii][:len0[ii]]) 404 | else: 405 | train_embs[key1] = [emb0[ii][:len0[ii]]] 406 | train_labels[key1] = [label0[ii][:len0[ii]]] 407 | 408 | if key2 in train_embs.keys(): 409 | train_embs[key2].append(emb1[ii][:len1[ii]]) 410 | train_labels[key2].append( 411 | label1[ii][:len1[ii]]) 412 | elif key2 is not -1: 413 | train_embs[key2] = [emb1[ii][:len1[ii]]] 414 | train_labels[key2] = [label1[ii][:len1[ii]]] 415 | 416 | elif output['mode'][ii] == 'val': 417 | if key1 in val_embs.keys(): 418 | val_embs[key1].append(emb0[ii][:len0[ii]]) 419 | val_labels[key1].append(label0[ii][:len0[ii]]) 420 | else: 421 | val_embs[key1] = [emb0[ii][:len0[ii]]] 422 | val_labels[key1] = [label0[ii][:len0[ii]]] 423 | 424 | if key2 in val_embs.keys(): 425 | val_embs[key2].append(emb1[ii][:len1[ii]]) 426 | val_labels[key2].append(label1[ii][:len1[ii]]) 427 | elif key2 is not -1: 428 | val_embs[key2] = [emb1[ii][:len1[ii]]] 429 | val_labels[key2] = [label1[ii][:len1[ii]]] 430 | datasets = {'train_dataset': {'embs': train_embs, 'labels': train_labels}, 431 | 'val_dataset': {'embs': val_embs, 'labels': val_labels}} 432 | datasets_pair = {} 433 | train_dataset = [] 434 | val_dataset = [] 435 | datasets_pair['train_dataset'] = {} 436 | datasets_pair['val_dataset'] = {} 437 | 438 | (train_accs, val_accs) = self.classification.evaluate_embeddings( 439 | datasets, acc_list=self.acc_list, DICT=True) 440 | train_completion_score, val_completion_score = self.eventcompletion.evaluate_embeddings( 441 | datasets, emb_mean=False, DICT=True) 442 | (train_tau, val_tau) = self.kendallstau.evaluate_embeddings( 443 | datasets) 444 | -------------------------------------------------------------------------------- /src/datasets/pennaction.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from typing import Dict 3 | from unicodedata import name 4 | import itertools 5 | import copy 6 | from numpy.lib.function_base import insert 7 | from 
scipy.spatial.transform import Rotation as R 8 | import numpy as np 9 | import torch 10 | import torch.utils as utils 11 | from numpy.linalg import inv 12 | import random 13 | from dataset_preparation.preprocess_norm_mat import pre_normalization_mat, get_openpose_connectivity 14 | from bodymocap.models import SMPLX 15 | import time 16 | 17 | # Loading VPoser Body Pose Prior 18 | from human_body_prior.tools.model_loader import load_model 19 | from human_body_prior.models.vposer_model import VPoser 20 | 21 | # Mano 22 | from manopth.manolayer import ManoLayer 23 | 24 | # Drawing tools. 25 | import matplotlib.pyplot as plt 26 | from mpl_toolkits.mplot3d import Axes3D 27 | import sys 28 | import os 29 | import subprocess 30 | 31 | from dataset_splits import DATASETS 32 | 33 | 34 | def moving_average(x, w): 35 | return np.convolve(x, np.ones(w), 'same') / w 36 | 37 | 38 | class PennActionDataset(utils.data.Dataset): 39 | def __init__(self, 40 | npz_path, 41 | num_frames=None, 42 | sampling_strategy=None, 43 | augmentation_strategy=None, 44 | mode='train', 45 | augment_fn=None, 46 | pose_dir=None, 47 | contrastive=False, 48 | val=False, 49 | use_norm=True, 50 | config=None, 51 | **kwargs): 52 | 53 | super().__init__() 54 | # self.pose_dir = pose_dir if pose_dir is not None else root_dir 55 | self.config = config 56 | self.mode = mode 57 | self.contrastive = contrastive 58 | self.num_frames = num_frames 59 | self.sampling_strategy = sampling_strategy 60 | self.use_norm = use_norm 61 | self.smpl = self.config.DATASET.SMPL 62 | self.mano = self.config.DATASET.MANO 63 | self.augmentation_strategy = augmentation_strategy 64 | self.val = val 65 | self.max_len = DATASETS[self.config.DATASET.NAME]['max_len'] 66 | self.dataset_path = self.config.DATASET.PATH 67 | 68 | # prepare data_names, intrinsics and extrinsics(T) 69 | print("npz_path", npz_path) 70 | self.mode_str = npz_path.split('.')[-2].split('_')[-1] 71 | npz_path = os.path.join(self.dataset_path, npz_path) 72 | if not self.use_norm: 73 | npz_path = npz_path + 'nn' 74 | elif self.smpl: 75 | npz_path = npz_path # + 'nn2' 76 | smpl_dir = os.path.join(self.dataset_path, 'smpl/models') 77 | print("self.dataset_path", self.dataset_path) 78 | print("smpl_dir", smpl_dir) 79 | self.smpl_model = SMPLX(smpl_dir, 80 | batch_size=self.max_len, 81 | num_betas=10, 82 | use_pca=False, 83 | create_transl=False) 84 | 85 | self.device = torch.device('cuda') # cpu cuda 86 | self.smpl_model.to(device=self.device) 87 | self.smpl_keypoints = {} 88 | 89 | support_dir = os.path.join( 90 | self.dataset_path, "human_body_prior/support_data/downloads/") 91 | expr_dir = osp.join(support_dir, 'vposer_v2_05') 92 | vp, ps = load_model(expr_dir, model_code=VPoser, 93 | remove_words_in_model_weights='vp_model.', 94 | disable_grad=True) 95 | self.vp = vp.to(device=self.device) 96 | elif self.mano: 97 | ncomps = 45 98 | mano_dir = os.path.join(self.dataset_path, 'mano/models') 99 | self.mano_layer_l = ManoLayer( 100 | mano_root=mano_dir, use_pca=False, ncomps=ncomps, flat_hand_mean=True, side='left') 101 | self.mano_layer_r = ManoLayer( 102 | mano_root=mano_dir, use_pca=False, ncomps=ncomps, flat_hand_mean=True, side='right') 103 | 104 | self.device = torch.device('cuda') # cpu cuda 105 | self.mano_layer_l.to(device=self.device) 106 | self.mano_layer_r.to(device=self.device) 107 | 108 | self.mano_keypoints = {} 109 | 110 | data = np.load(npz_path, allow_pickle=True) 111 | self.data = data.item() 112 | 113 | 114 | 115 | if self.val: 116 | self.data_names = list(self.data.keys()) 
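# Validation keeps each sequence key exactly once; training (the branch below)
# repeats the key list 50x, presumably so that one epoch draws many
# differently augmented samples from every sequence.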
117 | else: 118 | self.data_names = list(self.data.keys()) * 50 119 | 120 | self.augment_fn = augment_fn if mode == 'train' else None 121 | 122 | def __len__(self): 123 | return len(self.data_names) 124 | 125 | def get_keypoints_from_smpl(self, pose, beta): 126 | 127 | M, D1 = beta.shape 128 | M, D2 = pose.shape 129 | betas = torch.zeros(self.max_len, D1) 130 | body_pose = torch.zeros(self.max_len, D2) 131 | betas[:M, :] = torch.from_numpy(beta.astype(np.float32)) 132 | body_pose[:M, :] = torch.from_numpy(pose.astype(np.float32)) 133 | 134 | body_pose = body_pose.to(device=self.device) 135 | betas = betas.to(device=self.device) 136 | smpl_output = self.smpl_model( 137 | global_orient=body_pose[:, :3], 138 | body_pose=body_pose[:, 3:], 139 | betas=betas) 140 | #middle_t = time.time() 141 | 142 | keypoints = pre_normalization_mat(np.reshape(smpl_output.joints.detach().cpu().numpy()[ 143 | :M, :25], (-1, 25, 3)), zaxis=[1, 8], xaxis=[1, 5]) 144 | return keypoints 145 | 146 | def get_keypoints_from_mano(self, pose, beta): 147 | 148 | M, D1 = beta.shape 149 | M, D2 = pose.shape 150 | betas = torch.zeros(self.max_len, D1) 151 | mano_pose = torch.zeros(self.max_len, D2) 152 | betas[:M, :] = torch.from_numpy(beta.astype(np.float32)) 153 | mano_pose[:M, :] = torch.from_numpy(pose.astype(np.float32)) 154 | 155 | mano_pose = mano_pose.to(device=self.device) 156 | betas = betas.to(device=self.device) 157 | 158 | trans_l = torch.unsqueeze(mano_pose[:, :3], 1) 159 | trans_r = torch.unsqueeze(mano_pose[:, 51:54], 1) 160 | 161 | _, mano_output_l = self.mano_layer_l(mano_pose[:, 3:51], betas[:, :10]) 162 | _, mano_output_r = self.mano_layer_l(mano_pose[:, 54:], betas[:, 10:]) 163 | 164 | 165 | mano_output = torch.cat( 166 | (mano_output_l+trans_l, mano_output_r+trans_r), axis=1) 167 | keypoints = pre_normalization_mat(np.reshape(mano_output.detach().cpu().numpy()[ 168 | :M, :42], (-1, 42, 3)), zaxis=[0, 1], xaxis=[0, 9]) 169 | return keypoints 170 | 171 | 172 | def sample_steps(self, data_name): 173 | def _sample_uniform(item_len): 174 | interval = item_len // self.num_frames 175 | steps = range(0, item_len, interval)[:self.num_frames] 176 | return sorted(steps) 177 | 178 | def _sample_random(item_len): 179 | steps = random.sample( 180 | range(1, item_len), self.num_frames) 181 | return sorted(steps) 182 | 183 | def _sample_all(): 184 | return list(range(0, self.num_frames)) 185 | 186 | len0 = len(self.data[data_name[0]]['labels']) 187 | check0 = (self.num_frames <= len0) 188 | if check0: 189 | steps0 = _sample_random(len0) 190 | else: 191 | steps0 = _sample_all() 192 | len1 = len(self.data[data_name[1]]['labels']) 193 | check1 = (self.num_frames <= len1) 194 | if check1: 195 | steps1 = _sample_random(len1) 196 | else: 197 | steps1 = _sample_all() 198 | 199 | return torch.IntTensor(steps0), torch.IntTensor(steps1) 200 | 201 | def sample_steps_one(self, len0): 202 | # num_frames=20 203 | def _sample_uniform(item_len): 204 | interval = item_len // self.num_frames 205 | steps = range(0, item_len, interval)[:self.num_frames] 206 | return sorted(steps) 207 | 208 | def _sample_random(item_len): 209 | steps = random.sample( 210 | range(1, item_len), self.num_frames) 211 | return sorted(steps) 212 | 213 | def _sample_all(): 214 | return list(range(0, self.num_frames)) 215 | 216 | check0 = (self.num_frames <= len0) 217 | if check0: 218 | steps0 = _sample_random(len0) 219 | else: 220 | steps0 = _sample_all() 221 | 222 | return torch.IntTensor(steps0) 223 | 224 | def get_steps(self, step): 225 | """Sample multiple 
context steps for a given step.""" 226 | 227 | num_steps = self.config.DATASET.NUM_STEPS 228 | stride = self.config.DATASET.FRAME_STRIDE 229 | 230 | if num_steps < 1: 231 | return step 232 | if stride < 1: 233 | raise ValueError('stride should be >= 1.') 234 | steps = torch.arange(step - (num_steps - 1) * 235 | stride, step + stride, stride) 236 | 237 | return steps 238 | 239 | def __getitem__(self, idx): 240 | data_name = self.data_names[idx] 241 | 242 | 243 | if self.contrastive: 244 | 245 | len_st = len(self.augmentation_strategy) 246 | strategy = self.augmentation_strategy 247 | 248 | len0 = len(self.data[data_name]['labels']) 249 | len1 = len0 250 | vposer_prob0 = 1 251 | # Todo: Get the value directly from the setting. 252 | dim, channel = self.config.CASA.MATCH.D_MODEL//3, 3 # 25, 3 253 | 254 | if self.config.DATASET.ATT_STYLE: 255 | 256 | # print("hihi") 257 | NO_TIME_AUG = ('fast' not in strategy) and ( 258 | 'slow' not in strategy) 259 | steps0 = np.array(list(range(self.max_len))) 260 | steps1 = np.array(list(range(self.max_len))) 261 | 262 | keypoints0 = np.zeros([self.max_len, dim, channel]) 263 | keypoints1 = np.zeros([self.max_len, dim, channel]) 264 | 265 | label0 = np.ones([self.max_len]) * (self.max_len-1) 266 | label1 = np.ones([self.max_len]) * (self.max_len-1) 267 | 268 | # if NO_TIME_AUG: 269 | #steps0 = self.sample_steps_one(len0) 270 | matched_list = list(range(self.max_len)) 271 | if self.smpl: 272 | T, D = self.data[data_name]['pose'].shape 273 | if data_name in self.smpl_keypoints.keys(): 274 | keypoints0[:len0, :, 275 | :] = self.smpl_keypoints[data_name] 276 | else: 277 | keypoints0[:len0, :, :] = self.get_keypoints_from_smpl( 278 | self.data[data_name]['pose'], self.data[data_name]['beta']) 279 | self.smpl_keypoints[data_name] = keypoints0[:len0, :, :] 280 | elif self.mano: 281 | T, D = self.data[data_name]['pose'].shape 282 | if data_name in self.mano_keypoints.keys(): 283 | keypoints0[:len0, :, 284 | :] = self.mano_keypoints[data_name] 285 | else: 286 | keypoints0[:len0, :, :] = self.get_keypoints_from_mano( 287 | self.data[data_name]['pose'], self.data[data_name]['beta']) 288 | self.mano_keypoints[data_name] = keypoints0[:len0, :, :] 289 | else: 290 | keypoints0[:len0, :, :] = self.data[data_name]['pose'] 291 | label0[:len0] = self.data[data_name]['labels'] 292 | 293 | keypoints1 = copy.deepcopy(keypoints0) 294 | label1 = copy.deepcopy(label0) 295 | steps1 = copy.deepcopy(steps0) 296 | 297 | steps1 = np.array(steps1) 298 | steps0 = np.array(steps0) 299 | 300 | if self.val: 301 | # If it is val, need to put back the same seq as the original 302 | # No 4D augmentation. 303 | data = { 304 | 'len0': len0, 305 | 'len1': len1, 306 | 'steps0': steps0, 307 | 'steps1': steps1, 308 | 'mode': self.mode_str, 309 | 'keypoints0': keypoints0, # (1, h, w) 310 | 'label0': label0, # (h, w) 311 | 'keypoints1': keypoints1, 312 | 'label1': label1, 313 | 'matching': torch.IntTensor(matched_list), 314 | 'dataset_name': 'PennAction', 315 | 'pair_id': idx, 316 | 'pair_names': data_name 317 | } 318 | return data 319 | 320 | # time augmentation before translation noise and flipping. 
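# 'fast' drops a random fraction of frames from the augmented copy, while
# 'slow' (further below) inserts frames interpolated in Euler-angle space; in
# both cases matched_list records, for every original frame, the index of its
# mutual nearest neighbour in the re-timed copy, or -1 when the match is not
# mutual.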
321 | slow_fast_prob = np.random.uniform(low=-0, high=1) 322 | if 'fast' in strategy: 323 | if ('slow' not in strategy) or slow_fast_prob > 0.5: 324 | fast_coeff = np.random.uniform(low=0, high=0.5) 325 | steps0_eff = np.array(list(range(len0))) 326 | steps1_eff = np.array(list(range(len1))) 327 | sampled_steps = random.sample( 328 | range(len0), int(fast_coeff*len0)) 329 | steps1_eff = np.delete(steps1_eff, sampled_steps) 330 | len1 = len(steps1_eff) 331 | 332 | # Initialize keypoints1 and lable 1 333 | keypoints1 = np.zeros([self.max_len, dim, channel]) 334 | label1 = np.ones([self.max_len]) * (self.max_len-1) 335 | 336 | # Assign numbers to keypoints1 and labels based on steps1 we calculated above. 337 | keypoints1[:len1, :, :] = keypoints0[steps1_eff] 338 | label1[:len1] = label0[steps1_eff] 339 | 340 | #len_new_label0 = len(label0) 341 | matched_list = [-1] * self.max_len 342 | 343 | nn0 = [np.argmin(abs(steps1_eff - b)) 344 | for b in steps0_eff] 345 | nn1 = [np.argmin(abs(steps0_eff - b)) 346 | for b in steps1_eff] 347 | 348 | if self.smpl or self.mano: 349 | # for the geometric augmentation 350 | pose_values = self.data[data_name]['pose'][steps1_eff] 351 | beta_values = self.data[data_name]['beta'][steps1_eff] 352 | 353 | for ii in range(len0): 354 | if ii == nn1[nn0[ii]]: 355 | matched_list[ii] = nn0[ii] 356 | 357 | if 'slow' in strategy: 358 | if ('fast' not in strategy) or slow_fast_prob < 0.5: 359 | def cal_interpolation(pose1, sampled_steps, len1): 360 | pose_vals = [] 361 | for step in sampled_steps: 362 | if step == len1-1: 363 | pose_val = pose1[step] 364 | else: 365 | # print("pose_val",pose_val.shape) 366 | euler_angle0 = R.from_rotvec(np.reshape( 367 | pose1[step], (-1, 3))).as_euler('zyx', degrees=True) 368 | #M, K = euler_angle.shape 369 | euler_angle1 = R.from_rotvec(np.reshape( 370 | pose1[step+1], (-1, 3))).as_euler('zyx', degrees=True) 371 | 372 | euler_angle = ( 373 | euler_angle0 + euler_angle1)/2 374 | pose_val = np.reshape(R.from_euler( 375 | 'zyx', euler_angle, degrees=True).as_rotvec(), (-1)) 376 | pose_vals.append(pose_val) 377 | 378 | return np.array(pose_vals) 379 | 380 | slow_coeff = np.random.uniform(low=0, high=0.5) 381 | steps0_eff = np.array(list(range(len0))) 382 | steps1_eff = np.array(list(range(len1))) 383 | 384 | # Select duplicated frames. 385 | select_num = self.max_len - len0 386 | sampled_steps = np.array(random.sample( 387 | range(len0), min(int(slow_coeff*len0), select_num))) 388 | 389 | if len(sampled_steps) > 0: 390 | # Initialize keypoints1, pose1 and lable 1 391 | keypoints1 = np.zeros([self.max_len, dim, channel]) 392 | pose_values = np.zeros([self.max_len, D]) 393 | label1 = np.ones([self.max_len]) * (self.max_len-1) 394 | 395 | # Define step1 396 | steps1_eff = np.concatenate( 397 | (steps1_eff, sampled_steps), axis=0) 398 | steps1_eff = np.sort(steps1_eff) 399 | # Assign numbers to keypoints1 and labels based on steps1 we calculated above. 400 | interpolated_pose = cal_interpolation( 401 | self.data[data_name]['pose'], sampled_steps, len1) 402 | 403 | len1 = len(steps1_eff) 404 | 405 | # Assign numbers to keypoints1 and labels based on steps1 we calculated above. 
406 | pose_values = np.insert( 407 | self.data[data_name]['pose'], sampled_steps+1, interpolated_pose, axis=0)[:len1] 408 | 409 | beta_values = self.data[data_name]['beta'][steps1_eff] 410 | label1[:len1] = label0[steps1_eff] 411 | 412 | keypoints1[:len1, :, :] = self.get_keypoints_from_smpl( 413 | pose_values, beta_values) 414 | 415 | # added 1 to sampled steps so that it is located right after the number. 416 | matched_list = [-1] * self.max_len 417 | nn0 = [np.argmin(abs(steps1_eff - b)) 418 | for b in steps0_eff] 419 | nn1 = [np.argmin(abs(steps0_eff - b)) 420 | for b in steps1_eff] 421 | 422 | for ii in range(len0): 423 | if ii == nn1[nn0[ii]]: 424 | matched_list[ii] = nn0[ii] 425 | else: 426 | pose_values = self.data[data_name]['pose'] 427 | beta_values = self.data[data_name]['beta'] 428 | 429 | if NO_TIME_AUG: 430 | if self.smpl or self.mano: 431 | pose_values = self.data[data_name]['pose'] 432 | beta_values = self.data[data_name]['beta'] 433 | 434 | # geometric augmentation. 435 | ma_window = 10 436 | IID = False 437 | if 'noise_angle' in strategy: 438 | # 25 joints, x,y,z, 75 gaussiain 439 | vposer_prob0 = np.random.uniform(low=-0, high=1) 440 | if vposer_prob0 < 0.3: 441 | mu, sigma = 0, 10.0 442 | 443 | #print("pose_values", pose_values.shape) 444 | T, D = pose_values.shape 445 | # Augment angles in the Euler space. 446 | 447 | euler_angle = R.from_rotvec(np.reshape( 448 | pose_values, (-1, 3))).as_euler('zyx', degrees=True) 449 | M, K = euler_angle.shape 450 | euler_angle = np.reshape(euler_angle, (T, -1)) 451 | 452 | T, P = euler_angle.shape 453 | 454 | cov = ( 455 | T - np.abs(np.arange(T)[:, np.newaxis] - np.arange(T)[np.newaxis, :])/2) / T 456 | if IID: 457 | noise_angle = np.random.normal(mu, sigma, (P, T)) 458 | else: 459 | noise_angle = np.random.multivariate_normal( 460 | [mu]*T, cov*sigma, P) 461 | for ii in range(P): 462 | noise_angle[ii] = moving_average( 463 | noise_angle[ii], ma_window) 464 | euler_angle = np.reshape(euler_angle, (M, K)) # .T 465 | 466 | aug_rotvec = np.reshape(R.from_euler( 467 | 'zyx', euler_angle, degrees=True).as_rotvec(), (len1, -1)) 468 | 469 | pose_values = aug_rotvec 470 | if self.smpl: 471 | keypoints1[:len1, :, :] = self.get_keypoints_from_smpl( 472 | aug_rotvec, beta_values) 473 | elif self.mano: 474 | keypoints1[:len1, :, :] = self.get_keypoints_from_mano( 475 | aug_rotvec, beta_values) 476 | 477 | if "noise_vposer" in strategy: 478 | 479 | vposer_prob0 = np.random.uniform(low=-0, high=1) 480 | if vposer_prob0 < 0.1: 481 | 482 | mu, sigma = 0, 0.1 # 0.1 483 | DIM_VPOSER = 32 484 | 485 | body_pose = torch.from_numpy(pose_values[:, 3:66]).type( 486 | torch.float).to(device=self.device) 487 | 488 | cov = (len1 - np.abs(np.arange(len1) 489 | [:, np.newaxis] - np.arange(len1)[np.newaxis, :])/2) / len1 490 | if IID: 491 | #noise_angle = np.random.normal(mu, sigma, (P, T)) 492 | noise_vposer = np.random.normal( 493 | mu, sigma, DIM_VPOSER) 494 | else: 495 | noise_vposer = np.random.multivariate_normal( 496 | [mu]*len1, cov*sigma, DIM_VPOSER) 497 | for ii in range(DIM_VPOSER): 498 | noise_vposer[ii] = moving_average( 499 | noise_vposer[ii], ma_window) 500 | 501 | noise_vposer = noise_vposer.T 502 | 503 | body_poZ = self.vp.encode(body_pose).mean 504 | body_poZ = body_poZ.T 505 | vposer_window = 3 506 | vposer_edge_interval = vposer_window//2 + 1 507 | temp_start = copy.deepcopy( 508 | body_poZ[:, :vposer_edge_interval]) 509 | temp_end = copy.deepcopy( 510 | body_poZ[:, -vposer_edge_interval:]) 511 | # print("temp_end",temp_end.shape) 512 | 
body_poZ = body_poZ.detach().cpu().numpy() 513 | if not IID: 514 | for ii in range(DIM_VPOSER): 515 | body_poZ[ii] = moving_average( 516 | body_poZ[ii], vposer_window) 517 | body_poZ = torch.Tensor(body_poZ).type( 518 | torch.float).to(device=self.device) 519 | body_poZ[:, :vposer_edge_interval] = temp_start 520 | body_poZ[:, -vposer_edge_interval:] = temp_end 521 | body_poZ = body_poZ.T 522 | 523 | # add noise to enconded body poses. 524 | body_poZ = body_poZ + \ 525 | torch.Tensor(noise_vposer).type( 526 | torch.float).to(device=self.device) 527 | body_pose_rec = self.vp.decode( 528 | body_poZ)['pose_body'].contiguous().view(-1, 63) 529 | body_pose_rec = body_pose_rec.detach().cpu().numpy() 530 | 531 | body_pose_rec = np.concatenate( 532 | (pose_values[:, :3], body_pose_rec, pose_values[:, 66:]), axis=1) 533 | pose_values = body_pose_rec 534 | # print("body_pose_rec",body_pose_rec.shape) 535 | keypoints1[:len1, :, :] = self.get_keypoints_from_smpl( 536 | body_pose_rec, beta_values) 537 | 538 | 539 | if ('flip' in strategy): 540 | 541 | flip_prob1 = np.random.uniform(low=-0, high=1) 542 | if flip_prob1 < 0.3: 543 | keypoints1 = keypoints1 * [-1, 1, 1] 544 | keypoints1 = keypoints1.astype(np.float32) 545 | if self.smpl: 546 | keypoints1[:len1] = pre_normalization_mat(keypoints1[:len1], 547 | zaxis=[1, 8], xaxis=[1, 5]) 548 | elif self.mano: 549 | keypoints1[:len1] = pre_normalization_mat(keypoints1[:len1], 550 | zaxis=[0, 1], xaxis=[0, 9]) 551 | else: 552 | keypoints1[:len1] = pre_normalization_mat(keypoints1[:len1], 553 | zaxis=[5, 11], xaxis=[5, 6], ERR_BORN=True) 554 | 555 | # and (vposer_prob0 > 0.3): 556 | if ('noise_translation' in strategy): 557 | # make noise translation and vposer not working together. 558 | vposer_prob0 = np.random.uniform(low=-0, high=1) 559 | if vposer_prob0 < 0.3: 560 | # 25 joints, x,y,z, 75 gaussiain 561 | T, K, D = keypoints1.shape 562 | mu, sigma = 0, 0.1 # 0.1 563 | noise_trans = np.random.normal(mu, sigma, (T, K, D)) 564 | keypoints1 = keypoints1 + noise_trans 565 | 566 | VIS = False 567 | if VIS: 568 | 569 | vis_folder = 'vis/motion/key0' 570 | s = copy.deepcopy(keypoints0) 571 | s = s * [1, -1, -1] 572 | s1 = copy.deepcopy(keypoints1) 573 | s1 = s1 * [1, -1, -1] 574 | fig = plt.figure() 575 | connectivity = get_openpose_connectivity() 576 | ax = plt.axes(projection='3d') 577 | 578 | for i_s in range(len0): 579 | ax.set_xlim3d(-2, 2) 580 | ax.set_ylim3d(-2, 2) 581 | ax.set_zlim3d(-2, 2) 582 | 583 | for limb in connectivity: 584 | ax.plot3D(s[i_s, limb, 0], 585 | s[i_s, limb, 1], s[i_s, limb, 2], 'red') 586 | ax.scatter3D(s[i_s, :, 0], s[i_s, :, 1], 587 | s[i_s, :, 2], cmap='Greens', s=2) 588 | 589 | for limb in connectivity: 590 | ax.plot3D(s1[i_s, limb, 0], 591 | s1[i_s, limb, 1], s1[i_s, limb, 2], 'blue') 592 | ax.scatter3D(s1[i_s, :, 0], s1[i_s, :, 1], 593 | s1[i_s, :, 2], cmap='Greens', s=2) 594 | 595 | plt.show(block=False) 596 | plt.savefig(vis_folder + "/file%04d.png" % i_s) 597 | plt.cla() 598 | 599 | os.chdir("vis/motion/key0") 600 | subprocess.call([ 601 | 'ffmpeg', '-framerate', '8', '-i', 'file%04d.png', '-r', '30', '-pix_fmt', 'yuv420p', 602 | 'video_name.mp4' 603 | ]) 604 | sys.exit(0) 605 | 606 | data = { 607 | # 'valid_mask' : mask, 608 | 'len0': len0, 609 | 'len1': len1, 610 | 'steps0': steps0, 611 | 'steps1': steps1, 612 | 'mode': self.mode_str, 613 | 'keypoints0': keypoints0, # (1, h, w) 614 | 'label0': label0, # (h, w) 615 | 'keypoints1': keypoints1, 616 | 'label1': label1, 617 | 'matching': torch.IntTensor(matched_list), 618 | 
'dataset_name': 'PennAction', 619 | 'pair_id': idx, 620 | 'pair_names': data_name 621 | } 622 | return data 623 | 624 | 625 | else: 626 | 627 | label0 = self.data[data_name]['labels'] 628 | len0 = len(label0) 629 | if self.sampling_strategy == 'offset_uniform': 630 | steps0 = list(range(len0)) 631 | steps0 = torch.reshape(torch.stack( 632 | list(map(self.get_steps, steps0))), [-1]) 633 | steps0 = torch.maximum(steps0, torch.tensor(0)) 634 | steps0 = torch.minimum(steps0, torch.tensor(len0-1)) 635 | keypoints0 = self.data[data_name]['pose'][steps0] 636 | label0 = self.data[data_name]['labels'][steps0] 637 | data = { 638 | 'len0': len0, 639 | 'steps0': steps0, 640 | 'mode': self.mode_str, 641 | 'keypoints0': keypoints0, # (num_sampled_steps, num_joints, 3) 642 | 'label0': label0, # (num_sampled_steps,) 643 | 'dataset_name': 'PennAction', 644 | 'pair_id': idx, 645 | 'pair_names': data_name 646 | } 647 | 648 | return data 649 | --------------------------------------------------------------------------------
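For reference, a minimal self-contained sketch (not part of the repository; names are illustrative) of the mutual nearest-neighbour pairing that PennActionDataset.__getitem__ uses to fill the 'matching' entry between the original frames and their time-augmented copy:

import numpy as np

def mutual_nn_matching(steps0, steps1):
    """Pair each index in steps0 with its mutual nearest neighbour in steps1.

    Returns a list of len(steps0) whose entries index into steps1, or -1 where
    the nearest-neighbour relation is not mutual.
    """
    steps0 = np.asarray(steps0)
    steps1 = np.asarray(steps1)
    nn0 = [int(np.argmin(np.abs(steps1 - s))) for s in steps0]  # original -> copy
    nn1 = [int(np.argmin(np.abs(steps0 - s))) for s in steps1]  # copy -> original
    matched = [-1] * len(steps0)
    for i in range(len(steps0)):
        if i == nn1[nn0[i]]:  # keep only mutual matches
            matched[i] = nn0[i]
    return matched

# Frames 2 and 3 dropped by the 'fast' augmentation:
print(mutual_nn_matching(range(6), [0, 1, 4, 5]))  # -> [0, 1, -1, -1, 2, 3]

Requiring the match to be mutual avoids many-to-one pairings when the two time axes have different lengths.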