├── logs └── .gitkeep ├── configs ├── .gitkeep └── tan.yaml ├── plb ├── __init__.py ├── models │ ├── __init__.py │ ├── self_supervised │ │ ├── __init__.py │ │ └── tan │ │ │ ├── __init__.py │ │ │ ├── transforms.py │ │ │ └── tan_module.py │ └── encoder.py └── datamodules │ ├── __init__.py │ ├── dataset.py │ ├── data_transform.py │ └── seq_datamodule.py ├── src ├── __init__.py ├── data │ ├── distance │ │ ├── __init__.py │ │ └── base.py │ └── dataset │ │ ├── __init__.py │ │ ├── split.py │ │ ├── cluster_misc.py │ │ ├── base.py │ │ ├── visualizer.py │ │ ├── splits │ │ ├── split_wseed_1234.json │ │ └── split_wseed_4321.json │ │ ├── loader.py │ │ └── utils.py ├── algo │ ├── __init__.py │ └── kmeans_skl.py └── metrics.py ├── show.gif ├── alata.ttf ├── README.md ├── .gitignore ├── pretrain.py └── cluster.py /logs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /plb/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /plb/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /show.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/likenneth/acton/HEAD/show.gif -------------------------------------------------------------------------------- /alata.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/likenneth/acton/HEAD/alata.ttf -------------------------------------------------------------------------------- /src/data/distance/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SnippetDistance as base -------------------------------------------------------------------------------- /src/data/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader import AISTDataset 2 | from .base import SnippetDataset as base -------------------------------------------------------------------------------- /plb/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from plb.datamodules.seq_datamodule import SeqDataModule 2 | 3 | __all__ = [ 4 | 'SeqDataModule', 5 | ] 6 | -------------------------------------------------------------------------------- /plb/models/self_supervised/__init__.py: -------------------------------------------------------------------------------- 1 | from plb.models.self_supervised.tan.tan_module import TAN 2 | 3 | __all__ = [ 4 | "TAN", 5 | ] 6 | -------------------------------------------------------------------------------- /src/algo/__init__.py: -------------------------------------------------------------------------------- 1 | from .kmeans_skl import get_best_clusterer as kmeans_skl 2 | from .kmeans_skl import Clusterer as 
kmeans_skl_clusterer 3 | -------------------------------------------------------------------------------- /plb/models/self_supervised/tan/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import TrainDataTransform as TANTrainDataTransform 2 | from .transforms import EvalDataTransform as TANEvalDataTransform 3 | from .transforms import FinetuneTransform as TANFinetuneDataTransform 4 | -------------------------------------------------------------------------------- /src/data/dataset/split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os 4 | from pathlib import Path 5 | from loader import AISTDataset 6 | 7 | genre_list = ["gBR", "gPO", "gLO", "gMH", "gLH", "gHO", "gWA", "gKR", "gJS", "gJB"] 8 | data_dir = Path("../../../../aistplusplus") 9 | assert data_dir.exists() 10 | validation_number = 42 11 | 12 | def main(): 13 | seed = 4321 14 | np.random.seed(seed) 15 | 16 | official_loader = AISTDataset(data_dir / "annotations") 17 | seq_container = {k: [] for k in genre_list} 18 | for seq_name in official_loader.mapping_seq2env.keys(): 19 | genre = seq_name.split("_")[0] 20 | seq_container[genre].append(seq_name) 21 | 22 | val = {k: [] for k in genre_list} 23 | total = 0 24 | for genre in genre_list: 25 | total += len(seq_container[genre]) 26 | lucky = np.random.permutation(len(seq_container[genre]))[:validation_number].tolist() 27 | val[genre] = [seq_container[genre][_] for _ in lucky] 28 | print(f"There are {total} sequences in total") 29 | 30 | dump_path = os.path.join("./", "splits", f"split_wseed_{seed}.json") 31 | with open(dump_path, "w") as f: 32 | json.dump(val, f) 33 | print(f"Dumped the validation set generated with seed {seed} to {dump_path}") 34 | 35 | 36 | if __name__ == "__main__": 37 | main() -------------------------------------------------------------------------------- /configs/tan.yaml: -------------------------------------------------------------------------------- 1 | NAME: tan # will create a folder under ./logs with this name 2 | PRETRAIN: 3 | GPUS: -1 # -1 uses all available GPUs; the default setup runs on 4 GPUs 4 | ALGO: TAN 5 | EPOCH: 500 6 | WARMUP: 50 7 | ARCH: 8 | ARCH: Transformer 9 | LAYER: 3 10 | DIM: 512 11 | DROPOUT: 0.0 # no need to change 12 | PROTECTION: 2 # 0 for no protection, 1 for half protection, 2 for full protection 13 | DATA: # look up this: https://aistdancedb.ongaaccel.jp/data_formats/#dance-genres 14 | DATA_DIR: ../aistplusplus 15 | GENRE: 10 # number of genres to use, maximum 10, no need to change 16 | SPLIT: 1234 # 1234 or 4321 17 | BS: 32 18 | MIN_LENGTH: 64 19 | MAX_LENGTH: 64 20 | AUG_SHIFT_PROB: 1 21 | AUG_ROT_PROB: 1 22 | AUG_TIME_PROB: 1 23 | NUM_WORKERS: 4 24 | AUG_SHIFT_RANGE: 0.4 # in meters, length of an interval centered at 0 25 | AUG_ROT_RANGE: 0.2 # takes values from 0 to 2 26 | AUG_TIME_RATE: 1.99 # takes values from 1 to 1.99 27 | TRAINER: 28 | VAL_STEP: 5 29 | ACCELERATOR: ddp 30 | LR: 2.5e-5 31 | OPTIM: adam # sgd or adam 32 | LARS: False # use only when OPTIM is sgd 33 | VALIDATION: 34 | CLUSTER: 35 | TYPE: kmeans_skl 36 | GENRE: 10 # the index of the genre to use (0~9); 10 means all genres 37 | K_MIN: 150 38 | K_MAX: 160 # with a step of 10 39 | TIMES: 10 # how many times k-means is run 40 | CKPT: -1 # NAME of the run to evaluate; if -1, use the NAME in this config 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Video 
Tokenization 2 | 3 | Codebase for video tokenization, based on our paper [Towards Tokenized Human Dynamics Representation](https://arxiv.org/pdf/2111.11433.pdf). 4 | 5 | ![](show.gif) 6 | 7 | ### Prerequisites (tested under Python 3.8 and CUDA 11.1) 8 | 9 | ```console 10 | apt-get install ffmpeg 11 | pip install torch==1.8 12 | pip install torchvision 13 | pip install pytorch-lightning 14 | pip install pytorch-lightning-bolts 15 | pip install aniposelib wandb gym test-tube ffmpeg-python matplotlib easydict scikit-learn 16 | ``` 17 | 18 | ### Data Preparation 19 | 20 | 1. Make a directory beside this repo and name it `aistplusplus`. 21 | 2. Download from the [AIST++ website](https://google.github.io/aistplusplus_dataset/download.html) until the directory looks like 22 | ```angular2html 23 | ├── annotations 24 | │   ├── cameras 25 | │   ├── ignore_list.txt 26 | │   ├── keypoints2d 27 | │   ├── keypoints3d 28 | │   ├── motions 29 | │   └── splits 30 | └── video_list.txt 31 | ``` 32 | 33 | ### How to run 34 | 35 | 1. Write a configuration file, e.g., `configs/tan.yaml`. 36 | 37 | 2. Run `python pretrain.py --cfg configs/tan.yaml` on a GPU machine; this creates a folder under `logs` for the run, named after `NAME` in the configuration file. Then run `python cluster.py --cfg configs/tan.yaml` (CPU-only) and check the results in `demo.ipynb`. 38 | 39 | 3. Alternatively, download my training result from [here](https://drive.google.com/file/d/1a40_wDAY_LsUZq9VYlx2qI7nRN--eDm6/view?usp=sharing) and unzip it into the `logs` folder. 40 | -------------------------------------------------------------------------------- /src/algo/kmeans_skl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.cluster import KMeans 3 | 4 | 5 | class Clusterer: 6 | def __init__(self, TIMES, K, TOL): 7 | # cannot use custom distance 8 | self.K = K 9 | self.num_init = TIMES 10 | self.kmeans = KMeans(n_clusters=self.K, n_init=self.num_init, verbose=False, tol=TOL) 11 | # tol: Relative tolerance with regards to Frobenius norm of the difference 12 | # in the cluster centers of two consecutive iterations to declare convergence 13 | 14 | def fit(self, x): # x: (num_sample, num_feat) 15 | new_kmeans = self.kmeans.fit(x) 16 | self.kmeans = new_kmeans 17 | # sort so that the best-scoring centroids come first 18 | score_container = [] 19 | for i in range(self.K): 20 | score = self.kmeans.score(x[self.kmeans.labels_ == i]) # the bigger the better 21 | score_container.append(score) 22 | indices = np.argsort(np.array(score_container))[::-1] 23 | self.kmeans.cluster_centers_ = self.kmeans.cluster_centers_[indices] 24 | 25 | def get_assignment(self, x): 26 | # x: (num_sample, num_feat) 27 | # returns the cluster index of each sample 28 | idx = self.kmeans.predict(x) 29 | # centroids = self.kmeans.cluster_centers_[idx] 30 | return idx 31 | 32 | def get_centroids(self, ): 33 | for idx in range(self.K): 34 | yield self.kmeans.cluster_centers_[idx] 35 | 36 | def get_best_clusterer(nodes, times, argument_dict): 37 | c = Clusterer(TIMES=times, K=argument_dict["K"], TOL=argument_dict["TOL"]) 38 | c.fit(nodes) 39 | return c -------------------------------------------------------------------------------- /src/data/distance/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class SnippetDistance(): 4 | def __init__(self, TYPE, TRANSLATION, ROTATION): 5 | self.type = TYPE 6 | self.translation = TRANSLATION # 
TRANSLATION True means action, to normalize it by translation invariance 7 | self.rotation = ROTATION 8 | def __call__(self, a, b): 9 | ttl = a.shape[0] 10 | pcd = [] 11 | for x in [a, b]: 12 | if self.translation: 13 | # 33: left hip x coordinates, before it there are 11 joints 14 | # then is x, y, z and right hip 15 | body_centre_x = (x[33] + x[36]) / 2 16 | body_centre_y = (x[34] + x[37]) / 2 17 | body_centre_z = (x[35] + x[38]) / 2 18 | shift = np.tile(np.array([body_centre_x, body_centre_y, body_centre_z]), ttl // 3) 19 | x = x - shift 20 | if self.rotation: 21 | # using Euler–Rodrigues formula, partially adopted from https://stackoverflow.com/questions/6802577/rotation-of-3d-vector 22 | lh = x[33:36] 23 | axis = np.array([0, -lh[2], lh[1]]) 24 | theta = -np.arccos(lh[0] / np.sqrt(np.dot(lh, lh))) 25 | axis = axis / np.sqrt(np.dot(axis, axis)) 26 | a = np.cos(theta / 2.0) 27 | b, c, d = -axis * np.sin(theta / 2.0) 28 | aa, bb, cc, dd = a * a, b * b, c * c, d * d 29 | bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d 30 | ttt = np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], 31 | [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], 32 | [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) 33 | x = (ttt @ x.reshape(-1, 3).T).T.flatten() 34 | pcd.append(x) 35 | return np.linalg.norm(pcd[0] - pcd[1]) 36 | -------------------------------------------------------------------------------- /plb/datamodules/dataset.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | from src.data.dataset.loader import AISTDataset 3 | 4 | genre_list = ["gBR", "gPO", "gLO", "gMH", "gLH", "gHO", "gWA", "gKR", "gJS", "gJB"] 5 | 6 | # 17 joints of COCO: 7 | # 0 - nose, 1 - left_eye, 2 - right_eye, 3 - left_ear, 4 - right_ear 8 | # 5 - left_shoulder, 6 - right_shoulder, 7 - left_elbow, 8 - right_elbow, 9 - left_wrist, 10 - right_wrist 9 | # 11 - left_hip, 12 - right_hip, 13 - left_knee, 14 - right_knee. 
15 - left_ankle, 16 - right_ankle 10 | 11 | 12 | class SkeletonDataset(): 13 | name = 'stl10' 14 | 15 | def __del__(self, ): 16 | if hasattr(self, "official_loader"): 17 | del self.official_loader 18 | 19 | def __init__(self, DATA_DIR, GENRE, SPLIT): 20 | self.data_dir = DATA_DIR 21 | self.genre = genre_list[:GENRE] 22 | self.split = SPLIT 23 | assert os.path.isdir(self.data_dir) 24 | self.official_loader = AISTDataset(os.path.join(self.data_dir, "annotations")) 25 | all_seq = [] 26 | for seq_name_np in self.official_loader.mapping_seq2env.keys(): 27 | seq_name = str(seq_name_np) 28 | if seq_name.split("_")[0] in self.genre: 29 | all_seq.append(seq_name) 30 | 31 | with open(os.path.join("src/data/dataset", "splits", f"split_wseed_{self.split}.json"), "r") as f: 32 | ldd = json.load(f) 33 | self.validation_split = [] 34 | for genre in self.genre: 35 | self.validation_split += ldd[genre] 36 | 37 | self.train_split = [_ for _ in all_seq if _ not in self.validation_split] 38 | # remove files officially deemed as broken 39 | bad_vids = self.official_loader.filter_file + ["gHO_sFM_cAll_d20_mHO5_ch13", ] 40 | self.validation_split = [_ for _ in self.validation_split if _ not in bad_vids] 41 | self.train_split = [_ for _ in self.train_split if _ not in bad_vids] 42 | 43 | print( 44 | f"{self.genre}-style dances loaded with {len(self.train_split)} training videos and {len(self.validation_split)} validation videos") 45 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle, time 3 | import json, math 4 | import numpy as np 5 | import pandas as pd 6 | from itertools import combinations 7 | from sklearn.metrics import normalized_mutual_info_score 8 | from scipy.stats import entropy 9 | from src.data.dataset.cluster_misc import lexicon, get_names, genre_list, vidn_parse 10 | from src.data.dataset.loader import AISTDataset 11 | from src.data.dataset.utils import save_paired_keypoints3d_as_video, rigid_align, rigid_align_sequence 12 | from src.data.distance.nndtw import DTW 13 | 14 | data_dir = "../aistplusplus" 15 | 16 | def preprocess(df): 17 | # input: a df with cols: idx, word, length, y, name 18 | # split advanced dance into multiple rows or remove them 19 | # give each snippet a tag of corresponding base dance 20 | res = pd.DataFrame(columns=["idx", "word", "length", "y", "label", "name"]) 21 | for index, row in df.iterrows(): 22 | if "sBM" in row["name"]: 23 | parsed = vidn_parse(row["name"]) 24 | tba = dict(row) 25 | tba["label"] = int(parsed["choreo"][2:4]) 26 | res = res.append(tba, ignore_index=True) 27 | else: 28 | raise NotImplementedError 29 | return res 30 | 31 | def metric_nmi(df): 32 | # input: a df with cols: idx, word, length, y, name, label 33 | df = preprocess(df) 34 | gt, pd = [], [] 35 | for index, row in df.iterrows(): 36 | gt += [row["label"], ] * row["length"] 37 | pd += [lexicon.index(row["word"]), ] * row["length"] 38 | return normalized_mutual_info_score(gt, pd) 39 | 40 | def ngram_ent(df, n=4, lb=1): # this is not n-gram entropy, this is a pre-processing function 41 | # input: a df with cols: idx, word, length, y, name 42 | # input is not expected to have gone through preprocessing 43 | # lb: filter out instance <= lb frames 44 | bins = {} 45 | dfs = {_: list(x[x["length"] > lb]["word"]) for _, x in df.groupby('y') if len(x) > 1} 46 | for k, v in dfs.items(): 47 | if len(v) >= n: 48 | for i in range(len(v) - n + 1): 49 
| pattern = "".join(v[i:i + n]) 50 | if pattern in bins: 51 | bins[pattern] += 1 52 | else: 53 | bins[pattern] = 1 54 | return bins 55 | 56 | def nge(df, K, n=2, lb=5): # this is n-gram entropy 57 | bins = ngram_ent(df, n, lb) 58 | if not len(bins): 59 | return 0. 60 | assert K ** n >= len(bins) 61 | dist = [v for k, v in bins.items()] + [0] * (K ** n - len(bins)) # compensate for n-gram that did not appear 62 | ent = entropy(np.array(dist) / sum(dist), base=2) 63 | return ent 64 | 65 | def metric_f2(df, K): # this calculates the F_2 in paper, nge returns the K_n in paper 66 | return nge(df, K, n=2) - nge(df, K, n=1) 67 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | share/python-wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | *.py,cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | cover/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | .pybuilder/ 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | # For a library or package, you might want to ignore these files since the code is 86 | # intended to run in multiple environments; otherwise, check them in: 87 | # .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | .idea 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzers 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # custom 141 | logs/* 142 | !logs/.gitkeep 143 | *.sh 144 | -------------------------------------------------------------------------------- /src/data/dataset/cluster_misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import string 4 | import itertools 5 | import pandas as pd 6 | from tqdm import tqdm 7 | 8 | all26 = list(string.ascii_lowercase) 9 | lexicon = [] 10 | for (c, v, s) in itertools.permutations(all26, 3): 11 | lexicon.append(c + v + s) 12 | 13 | 14 | keys = ["genre", "situ", "dancer", "tempo", "choreo", "name"] 15 | genre_list = ["gBR", "gPO", "gLO", "gMH", "gLH", "gHO", "gWA", "gKR", "gJS", "gJB"] 16 | 17 | def vidn_parse(s): 18 | res = {} 19 | if s.endswith("pkl"): 20 | s = s[:-4] 21 | for seg in s.split("_"): 22 | if seg.startswith("g"): 23 | res["genre"] = seg 24 | elif seg.startswith("s"): 25 | res["situ"] = seg 26 | elif seg.startswith("d"): 27 | res["dancer"] = seg 28 | elif seg.startswith("m"): 29 | res["tempo"] = 10 * (int(seg[3:]) + 8) 30 | elif seg.startswith("ch"): 31 | res["choreo"] = res["genre"][1:] + seg[2:] + res["situ"][1:] 32 | else: 33 | pass 34 | res["name"] = s # currently does not support camera variation 35 | return res 36 | 37 | def get_names(genre, trval="train", seed=1234): 38 | validation_split = [] 39 | with open(os.path.join("src/data/dataset", "splits", f"split_wseed_{seed}.json"), "r") as f: 40 | ldd = json.load(f) 41 | validation_split += ldd[genre] 42 | 43 | annot_3d = list(os.listdir("../aistplusplus/annotations/keypoints3d")) 44 | filter_file = os.path.join("../aistplusplus/annotations/", 'ignore_list.txt') 45 | with open(filter_file, "r") as f: 46 | filter_file = [_[:-1] for _ in f.readlines()] 47 | annot_3d = [_ for _ in annot_3d if _.startswith(genre) and _[:-4] not in filter_file] 48 | 49 | res = pd.DataFrame(columns=keys) 50 | if trval == "train" or trval == "tr": 51 | for s in annot_3d: 52 | if not s[:-4] in validation_split: # only keep samples in the training set 53 | res = res.append(vidn_parse(s), ignore_index=True) 54 | elif trval == "val": 55 | for s in annot_3d: 56 | if s[:-4] in validation_split: # only keep samples in the validation set 57 | res = res.append(vidn_parse(s), ignore_index=True) 58 | else: 59 | raise NotImplementedError 60 | 61 | return res.sort_values(["situ", "choreo"]) 62 | 63 | def get_num_in_table(df, names, signs): 64 | # a df filtered to have only one exp 65 | dfs = {_: x for _, x in df.groupby("genre") if len(x) > 1} 66 | containers = [[] for name in names] 67 | 68 | for genre, little_df in dfs.items(): 69 | for i, name in enumerate(names): 70 | nmi_list = list(little_df[little_df["type"]==name]["value"]) 71 | if len(nmi_list): 72 | containers[i].append(sum(nmi_list) / len(nmi_list)) 73 | else: 74 | 
containers[i].append(-9999) 75 | 76 | res = [] 77 | for container, sign in zip(containers, signs): 78 | nmi = sum(container) / len(container) * sign 79 | res.append(nmi) 80 | return res 81 | 82 | if __name__ == "__main__": 83 | print(len(lexicon)) 84 | print(lexicon) -------------------------------------------------------------------------------- /src/data/dataset/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | from src.data.dataset.loader import AISTDataset 5 | 6 | genre_list = ["gBR", "gPO", "gLO", "gMH", "gLH", "gHO", "gWA", "gKR", "gJS", "gJB"] 7 | 8 | class SnippetDataset(): 9 | def __del__(self, ): 10 | del self.official_loader 11 | 12 | def __init__(self, TYPE, DATA_DIR, GENRE, USE_OPTIM, LENGTH, OVERLAP, SPLIT, BS, MIN_LENGTH, MAX_LENGTH): 13 | self.type = TYPE 14 | self.data_dir = DATA_DIR 15 | self.genre = genre_list[:GENRE] 16 | self.use_optim = USE_OPTIM 17 | self.length = LENGTH 18 | self.overlap = OVERLAP 19 | self.split = SPLIT 20 | assert os.path.isdir(self.data_dir) 21 | self.official_loader = AISTDataset(os.path.join(self.data_dir, "annotations")) 22 | all_seq = [] 23 | for seq_name_np in self.official_loader.mapping_seq2env.keys(): 24 | seq_name = str(seq_name_np) 25 | if seq_name.split("_")[0] in self.genre: 26 | all_seq.append(seq_name) 27 | 28 | with open(os.path.join("src/data/dataset", "splits", f"split_wseed_{self.split}.json"), "r") as f: 29 | ldd = json.load(f) 30 | self.validation_split = [] 31 | for genre in self.genre: 32 | self.validation_split += ldd[genre] 33 | 34 | self.train_split = [_ for _ in all_seq if _ not in self.validation_split] 35 | 36 | # remove files officially deemed as broken 37 | bad_vids = self.official_loader.filter_file #+ ["gHO_sFM_cAll_d20_mHO5_ch13", ] 38 | self.validation_split = [_ for _ in self.validation_split if _ not in bad_vids] 39 | self.train_split = [_ for _ in self.train_split if _ not in bad_vids] 40 | 41 | print(f"{self.genre}-style dances loaded with {len(self.train_split)} training videos and {len(self.validation_split)} validation videos") 42 | # TODO: change to use official splits, 868 train, 70 validation, 470 test 43 | 44 | def get_train(self, ): 45 | for seq_name in self.train_split: 46 | full_data = self.official_loader.load_keypoint3d(seq_name, use_optim=self.use_optim) 47 | if full_data.max() > 500: 48 | print(f"in {seq_name} for train, max number being {full_data.max()}") 49 | duration = full_data.shape[0] 50 | hop = 1 - self.overlap 51 | total_pieces = math.floor((duration / self.length - 1) / hop) # + 1 52 | for i in range(total_pieces): 53 | start = int(i * hop * self.length) 54 | tbr = full_data[start: start + self.length] 55 | yield tbr 56 | 57 | def get_validation(self, ): 58 | for seq_name in self.validation_split: 59 | per_seq = [] 60 | full_data = self.official_loader.load_keypoint3d(seq_name, use_optim=self.use_optim) 61 | if full_data.max() > 500: 62 | print(f"in {seq_name} for val, max number being {full_data.max()}") 63 | duration = full_data.shape[0] 64 | for i in range(duration // self.length): 65 | per_seq.append(full_data[i * self.length: (i + 1) * self.length]) 66 | yield {seq_name: per_seq} 67 | -------------------------------------------------------------------------------- /plb/models/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | from torch import nn 7 | 8 | 
inter_dim_dict = {2048: 512, 1024: 256, 512: 128, 256: 128, 128: 64, 64: 64} 9 | 10 | class PositionalEncoding(nn.Module): 11 | def __init__(self, d_model, dropout=0.1, max_len=5000): 12 | super(PositionalEncoding, self).__init__() 13 | self.dropout = nn.Dropout(p=dropout) 14 | 15 | pe = torch.zeros(max_len, d_model) 16 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 17 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 18 | pe[:, 0::2] = torch.sin(position * div_term) 19 | pe[:, 1::2] = torch.cos(position * div_term) 20 | pe = pe.unsqueeze(0).transpose(0, 1) 21 | self.register_buffer('pe', pe) 22 | 23 | def forward(self, x): 24 | x = x + self.pe[:x.size(0)] 25 | return self.dropout(x) 26 | 27 | 28 | class Transformer(nn.Module): 29 | def __init__(self, tr_layer=6, tr_dim=512, j=51): 30 | super(Transformer, self).__init__() 31 | self.layers = tr_layer 32 | self.d_model = tr_dim 33 | self.pos_encoder = PositionalEncoding(self.d_model, 0) 34 | self.embedder = nn.Sequential(nn.Linear(j, inter_dim_dict[self.d_model]), nn.Linear(inter_dim_dict[self.d_model], self.d_model)) 35 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=8) 36 | self.encoder = nn.TransformerEncoder(encoder_layer, tr_layer) # [L, B, F] --> [L, B, F] 37 | 38 | def forward(self, point_bank, len): 39 | # point_bank, [N, T, 51] 40 | max_len = point_bank.size(1) # the collate function collates across GPUs, so the padded length might not be larger than the max in this GPU 41 | points = self.embedder(point_bank) * math.sqrt(self.d_model) 42 | points = self.pos_encoder(points.transpose(0, 1)) # [T, N, 512], transpose should be before pos_encoder 43 | mask = torch.stack([(torch.arange(max_len, device=len.device) >= _) for _ in len]).to(points.device) 44 | points = self.encoder(points, src_key_padding_mask=mask) # [T, N, 512] 45 | # points = points[:, 0] # use the first token as a summary of the whole sequence 46 | return points 47 | 48 | class Transformer_wote(nn.Module): 49 | def __init__(self, tr_layer=6, tr_dim=512, j=51): 50 | super(Transformer_wote, self).__init__() 51 | self.layers = tr_layer 52 | self.d_model = tr_dim 53 | # self.pos_encoder = PositionalEncoding(self.d_model, 0) 54 | self.embedder = nn.Sequential(nn.Linear(j, inter_dim_dict[self.d_model]), nn.Linear(inter_dim_dict[self.d_model], self.d_model)) 55 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=8) 56 | self.encoder = nn.TransformerEncoder(encoder_layer, tr_layer) # [L, B, F] --> [L, B, F] 57 | 58 | def forward(self, point_bank, len): 59 | # point_bank, [N, T, 51] 60 | max_len = point_bank.size(1) # the collate function collates across GPUs, so the padded length might not be larger than the max in this GPU 61 | points = self.embedder(point_bank) * math.sqrt(self.d_model) 62 | points = points.transpose(0, 1) # [T, N, 512], transpose should be before pos_encoder 63 | mask = torch.stack([(torch.arange(max_len, device=len.device) >= _) for _ in len]).to(points.device) 64 | points = self.encoder(points, src_key_padding_mask=mask) # [T, N, 512] 65 | # points = points[:, 0] # use the first token as a summary of the whole sequence 66 | return points 67 | -------------------------------------------------------------------------------- /src/data/dataset/visualizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google AI Perception Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Visualize the AIST++ Dataset.""" 16 | 17 | from . import utils 18 | import cv2 19 | import numpy as np 20 | 21 | _COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], 22 | [170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85], 23 | [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], 24 | [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255], 25 | [255, 0, 170], [255, 0, 85]] 26 | 27 | _COLORS_BONE = [ 28 | ([0, 1], [255, 0, 0]), ([0, 2], [255, 0, 0]), 29 | ([1, 3], [255, 170, 0]), ([2, 4], [255, 170, 0]), 30 | ([5, 6], [0, 255, 255]), ([11, 12], [0, 255, 255]), # the two bridge 31 | ([6, 8], [0, 255, 255]), ([8, 10], [0, 255, 255]), 32 | ([6, 12], [0, 255, 255]), ([12, 14], [0, 255, 255]), ([14, 16], [0, 255, 255]), 33 | ([5, 7], [255, 0, 85]), ([7, 9], [255, 0, 85]), 34 | ([5, 11], [255, 0, 85]), ([11, 13], [255, 0, 85]), ([13, 15], [255, 0, 85]), 35 | ] 36 | 37 | _COLORS_BONE_PLAIN = [ 38 | ([0, 1], [0, 0, 0]), ([0, 2], [0, 0, 0]), 39 | ([1, 3], [0, 0, 0]), ([2, 4], [0, 0, 0]), 40 | ([5, 6], [0, 0, 0]), ([11, 12], [0, 0, 0]), # the two bridge 41 | ([6, 8], [0, 0, 0]), ([8, 10], [0, 0, 0]), 42 | ([6, 12], [0, 0, 0]), ([12, 14], [0, 0, 0]), ([14, 16], [0, 0, 0]), 43 | ([5, 7], [0, 0, 0]), ([7, 9], [0, 0, 0]), 44 | ([5, 11], [0, 0, 0]), ([11, 13], [0, 0, 0]), ([13, 15], [0, 0, 0]), 45 | ] 46 | 47 | 48 | def plot_kpt(keypoint, canvas, bones=True): 49 | if bones: 50 | for j, c in _COLORS_BONE: 51 | cv2.line(canvas, 52 | tuple(keypoint[:, 0:2][j[0]].astype(int).tolist()), 53 | tuple(keypoint[:, 0:2][j[1]].astype(int).tolist()), 54 | tuple(c), thickness=2) 55 | else: 56 | for i, (x, y) in enumerate(keypoint[:, 0:2]): 57 | if np.isnan(x) or np.isnan(y): 58 | continue 59 | cv2.circle(canvas, (int(x), int(y)), 7, _COLORS[i % len(_COLORS)], thickness=-1) 60 | return canvas 61 | 62 | def plot_kpt_plain(keypoint, canvas, c, bones=True): 63 | if bones: 64 | for j, _ in _COLORS_BONE_PLAIN: 65 | cv2.line(canvas, 66 | tuple(keypoint[:, 0:2][j[0]].astype(int).tolist()), 67 | tuple(keypoint[:, 0:2][j[1]].astype(int).tolist()), 68 | tuple(c), thickness=2) 69 | else: 70 | for i, (x, y) in enumerate(keypoint[:, 0:2]): 71 | if np.isnan(x) or np.isnan(y): 72 | continue 73 | cv2.circle(canvas, (int(x), int(y)), 7, _COLORS[i % len(_COLORS)], thickness=-1) 74 | return canvas 75 | 76 | 77 | def plot_on_video(keypoints2d, video_path, save_path, fps=60): 78 | assert len(keypoints2d.shape) == 3, ( 79 | f'Input shape is not valid! 
Got {keypoints2d.shape}') 80 | video = utils.ffmpeg_video_read(video_path, fps=fps) 81 | for iframe, keypoint in enumerate(keypoints2d): 82 | if iframe >= video.shape[0]: 83 | break 84 | video[iframe] = plot_kpt(keypoint, video[iframe]) 85 | utils.ffmpeg_video_write(video, save_path, fps=fps) 86 | -------------------------------------------------------------------------------- /plb/models/self_supervised/tan/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE 4 | from pl_bolts.utils.warnings import warn_missing_pkg 5 | 6 | if _TORCHVISION_AVAILABLE: 7 | from torchvision import transforms as transforms 8 | else: # pragma: no cover 9 | warn_missing_pkg('torchvision') 10 | 11 | if _OPENCV_AVAILABLE: 12 | import cv2 13 | else: # pragma: no cover 14 | warn_missing_pkg('cv2', pypi_name='opencv-python') 15 | 16 | from plb.datamodules.seq_datamodule import SkeletonTransform 17 | 18 | 19 | class TrainDataTransform(object): 20 | def __init__(self, aug_shift_prob, aug_shift_range, aug_rot_prob, aug_rot_range, min_length, max_length, aug_time_prob, aug_time_rate) -> None: 21 | self.train_transform = SkeletonTransform(aug_shift_prob, aug_shift_range, aug_rot_prob, aug_rot_range, min_length, max_length, aug_time_prob, aug_time_rate) 22 | self.min_length = min_length 23 | self.max_length = max_length 24 | self.aug_time_prob = aug_time_prob 25 | self.aug_time_rate = aug_time_rate 26 | 27 | def __call__(self, sample): 28 | transform = self.train_transform 29 | ttl = sample.size(0) 30 | 31 | # let's random crop 32 | len = random.randint(self.min_length, min(self.max_length, ttl)) 33 | start = random.randint(0, ttl - len) 34 | sample = sample[start:start + len] 35 | 36 | xi, veloi = transform(sample)#, shut=True) # not do transform on one branch 37 | xj, veloj = transform(sample) 38 | 39 | return xi, xj, veloi, veloj # self.online_transform(sample) 40 | 41 | 42 | class EvalDataTransform(TrainDataTransform): 43 | def __init__(self, *args, **kwargs): 44 | super().__init__(*args, **kwargs) 45 | 46 | 47 | class FinetuneTransform(object): 48 | def __init__( 49 | self, 50 | input_height: int = 224, 51 | jitter_strength: float = 1., 52 | normalize=None, 53 | eval_transform: bool = False 54 | ) -> None: 55 | 56 | self.jitter_strength = jitter_strength 57 | self.input_height = input_height 58 | self.normalize = normalize 59 | 60 | self.color_jitter = transforms.ColorJitter( 61 | 0.8 * self.jitter_strength, 62 | 0.8 * self.jitter_strength, 63 | 0.8 * self.jitter_strength, 64 | 0.2 * self.jitter_strength, 65 | ) 66 | 67 | if not eval_transform: 68 | data_transforms = [ 69 | transforms.RandomResizedCrop(size=self.input_height), 70 | transforms.RandomHorizontalFlip(p=0.5), 71 | transforms.RandomApply([self.color_jitter], p=0.8), 72 | transforms.RandomGrayscale(p=0.2) 73 | ] 74 | else: 75 | data_transforms = [ 76 | transforms.Resize(int(self.input_height + 0.1 * self.input_height)), 77 | transforms.CenterCrop(self.input_height) 78 | ] 79 | 80 | if normalize is None: 81 | final_transform = transforms.ToTensor() 82 | else: 83 | final_transform = transforms.Compose([transforms.ToTensor(), normalize]) 84 | 85 | data_transforms.append(final_transform) 86 | self.transform = transforms.Compose(data_transforms) 87 | 88 | def __call__(self, sample): 89 | return self.transform(sample) 90 | 91 | 92 | class GaussianBlur(object): 93 | # Implements Gaussian blur as described in 
the SimCLR paper 94 | def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0): 95 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 96 | raise ModuleNotFoundError('You want to use `GaussianBlur` from `cv2` which is not installed yet.') 97 | 98 | self.min = min 99 | self.max = max 100 | 101 | # kernel size is set to be 10% of the image height/width 102 | self.kernel_size = kernel_size 103 | self.p = p 104 | 105 | def __call__(self, sample): 106 | sample = np.array(sample) 107 | 108 | # blur the image with a 50% chance 109 | prob = np.random.random_sample() 110 | 111 | if prob < self.p: 112 | sigma = (self.max - self.min) * np.random.random_sample() + self.min 113 | sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma) 114 | 115 | return sample 116 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pprint 4 | import sys 5 | 6 | import json 7 | import yaml 8 | import shutil 9 | import time 10 | import logging, json 11 | from pathlib import Path 12 | 13 | import pytorch_lightning as pl 14 | from pytorch_lightning.callbacks import ModelCheckpoint 15 | from pytorch_lightning.loggers import TestTubeLogger 16 | from plb.models.self_supervised import TAN 17 | from plb.models.self_supervised.tan import TANEvalDataTransform, TANTrainDataTransform 18 | from plb.datamodules import SeqDataModule 19 | from pytorch_lightning.plugins import DDPPlugin 20 | 21 | KEYPOINT_NAME = ["nose", "left_eye", "right_eye", "left_ear", "right_ear", 22 | "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", "left_wrist", "right_wrist", 23 | "left_hip", "right_hip", "left_knee", "right_knee", "left_ankle", "right_ankle"] 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description='Train classification network') 27 | 28 | parser.add_argument('--cfg', 29 | help='experiment configure file name', 30 | required=True, 31 | type=str) 32 | 33 | parser.add_argument('--data_dir', 34 | help='path to aistplusplus data directory from repo root', 35 | type=str) 36 | 37 | parser.add_argument('--seed', 38 | help='seed for this run', 39 | default=1, 40 | type=int) 41 | 42 | args, _ = parser.parse_known_args() 43 | pl.utilities.seed.seed_everything(args.seed) 44 | with open(args.cfg, 'r') as stream: 45 | ldd = yaml.safe_load(stream) 46 | 47 | if args.data_dir: 48 | ldd["PRETRAIN"]["DATA"]["DATA_DIR"] = args.data_dir 49 | pprint.pprint(ldd) 50 | return ldd 51 | 52 | 53 | def main(): 54 | args = parse_args() 55 | debug = args["NAME"] == "debug" 56 | log_dir = os.path.join("./logs", args["NAME"]) 57 | 58 | dirpath = Path(log_dir) 59 | dirpath.mkdir(parents=True, exist_ok=True) 60 | 61 | timed = time.strftime("%Y%m%d_%H%M%S") 62 | with open(os.path.join(log_dir, f"config_used_{timed}.yaml"), "w") as stream: 63 | yaml.dump(args, stream, default_flow_style=False) 64 | video_dir = os.path.join(log_dir, "saved_videos") 65 | Path(video_dir).mkdir(parents=True, exist_ok=True) 66 | 67 | # log 68 | tt_logger = TestTubeLogger( 69 | save_dir=log_dir, 70 | name="default", 71 | debug=False, 72 | create_git_tag=False 73 | ) 74 | 75 | # trainer 76 | trainer = pl.Trainer( 77 | gpus=args["PRETRAIN"]["GPUS"], 78 | check_val_every_n_epoch=args["PRETRAIN"]["TRAINER"]["VAL_STEP"], 79 | logger=tt_logger, 80 | accelerator=args["PRETRAIN"]["TRAINER"]["ACCELERATOR"], 81 | max_epochs=args["PRETRAIN"]["EPOCH"], 82 | gradient_clip_val=0.5, 83 | 
num_sanity_val_steps=0, 84 | plugins=DDPPlugin(find_unused_parameters=False), 85 | ) 86 | 87 | j = 17 88 | dm = SeqDataModule(**args["PRETRAIN"]["DATA"]) 89 | transform_args = {"min_length": args["PRETRAIN"]["DATA"]["MIN_LENGTH"], 90 | "max_length": args["PRETRAIN"]["DATA"]["MAX_LENGTH"], 91 | "aug_shift_prob": args["PRETRAIN"]["DATA"]["AUG_SHIFT_PROB"], 92 | "aug_shift_range": args["PRETRAIN"]["DATA"]["AUG_SHIFT_RANGE"], 93 | "aug_rot_prob": args["PRETRAIN"]["DATA"]["AUG_ROT_PROB"], 94 | "aug_rot_range": args["PRETRAIN"]["DATA"]["AUG_ROT_RANGE"], 95 | "aug_time_prob": args["PRETRAIN"]["DATA"]["AUG_TIME_PROB"], 96 | "aug_time_rate": args["PRETRAIN"]["DATA"]["AUG_TIME_RATE"], } 97 | dm.train_transforms = eval(args["PRETRAIN"]["ALGO"] + "TrainDataTransform")(**transform_args) 98 | dm.val_transforms = eval(args["PRETRAIN"]["ALGO"] + "EvalDataTransform")(**transform_args) 99 | model = eval(args["PRETRAIN"]["ALGO"])( 100 | gpus=args["PRETRAIN"]["GPUS"], 101 | num_samples=dm.num_samples, 102 | batch_size=dm.batch_size, 103 | length=dm.min_length, 104 | dataset=dm.name, 105 | max_epochs=args["PRETRAIN"]["EPOCH"], 106 | warmup_epochs=args["PRETRAIN"]["WARMUP"], 107 | arch=args["PRETRAIN"]["ARCH"]["ARCH"], 108 | val_configs=args["PRETRAIN"]["VALIDATION"], 109 | learning_rate=float(args["PRETRAIN"]["TRAINER"]["LR"]), 110 | log_dir=log_dir, 111 | protection=args["PRETRAIN"]["PROTECTION"], 112 | optim=args["PRETRAIN"]["TRAINER"]["OPTIM"], 113 | lars_wrapper=args["PRETRAIN"]["TRAINER"]["LARS"], 114 | tr_layer=args["PRETRAIN"]["ARCH"]["LAYER"], 115 | tr_dim=args["PRETRAIN"]["ARCH"]["DIM"], 116 | neg_dp=args["PRETRAIN"]["ARCH"]["DROPOUT"], 117 | j=j*3, 118 | ) 119 | 120 | trainer.fit(model, datamodule=dm) 121 | 122 | if __name__ == '__main__': 123 | main() 124 | -------------------------------------------------------------------------------- /src/data/dataset/splits/split_wseed_1234.json: -------------------------------------------------------------------------------- 1 | {"gBR": ["gBR_sBM_cAll_d04_mBR3_ch01", "gBR_sBM_cAll_d04_mBR2_ch07", "gBR_sFM_cAll_d04_mBR3_ch04", "gBR_sBM_cAll_d04_mBR2_ch09", "gBR_sBM_cAll_d06_mBR3_ch02", "gBR_sFM_cAll_d05_mBR1_ch08", "gBR_sBM_cAll_d04_mBR3_ch06", "gBR_sBM_cAll_d05_mBR5_ch09", "gBR_sBM_cAll_d04_mBR2_ch05", "gBR_sBM_cAll_d05_mBR4_ch03", "gBR_sFM_cAll_d05_mBR2_ch09", "gBR_sBM_cAll_d05_mBR4_ch04", "gBR_sBM_cAll_d05_mBR4_ch02", "gBR_sBM_cAll_d05_mBR0_ch10", "gBR_sBM_cAll_d05_mBR1_ch10", "gBR_sBM_cAll_d06_mBR3_ch09", "gBR_sBM_cAll_d05_mBR0_ch02", "gBR_sBM_cAll_d06_mBR4_ch04", "gBR_sBM_cAll_d04_mBR3_ch07", "gBR_sBM_cAll_d04_mBR1_ch07"], "gPO": ["gPO_sBM_cAll_d10_mPO3_ch10", "gPO_sBM_cAll_d12_mPO4_ch03", "gPO_sBM_cAll_d11_mPO0_ch06", "gPO_sFM_cAll_d12_mPO1_ch16", "gPO_sBM_cAll_d10_mPO2_ch09", "gPO_sBM_cAll_d12_mPO2_ch01", "gPO_sBM_cAll_d12_mPO2_ch02", "gPO_sBM_cAll_d12_mPO4_ch04", "gPO_sBM_cAll_d11_mPO5_ch02", "gPO_sFM_cAll_d12_mPO5_ch21", "gPO_sFM_cAll_d12_mPO0_ch15", "gPO_sBM_cAll_d11_mPO1_ch07", "gPO_sBM_cAll_d12_mPO5_ch02", "gPO_sBM_cAll_d10_mPO0_ch04", "gPO_sBM_cAll_d12_mPO2_ch05", "gPO_sBM_cAll_d11_mPO5_ch10", "gPO_sBM_cAll_d11_mPO1_ch10", "gPO_sBM_cAll_d11_mPO1_ch09", "gPO_sBM_cAll_d10_mPO2_ch03", "gPO_sBM_cAll_d11_mPO0_ch05"], "gLO": ["gLO_sBM_cAll_d15_mLO2_ch02", "gLO_sBM_cAll_d15_mLO3_ch10", "gLO_sBM_cAll_d15_mLO2_ch09", "gLO_sBM_cAll_d15_mLO3_ch06", "gLO_sBM_cAll_d13_mLO0_ch01", "gLO_sBM_cAll_d13_mLO2_ch02", "gLO_sBM_cAll_d15_mLO5_ch07", "gLO_sBM_cAll_d13_mLO0_ch06", "gLO_sFM_cAll_d15_mLO4_ch19", "gLO_sBM_cAll_d13_mLO2_ch10", "gLO_sBM_cAll_d14_mLO0_ch04", 
"gLO_sFM_cAll_d14_mLO4_ch12", "gLO_sBM_cAll_d13_mLO2_ch01", "gLO_sBM_cAll_d13_mLO2_ch04", "gLO_sBM_cAll_d15_mLO4_ch08", "gLO_sBM_cAll_d13_mLO3_ch08", "gLO_sBM_cAll_d13_mLO0_ch03", "gLO_sBM_cAll_d14_mLO1_ch05", "gLO_sBM_cAll_d15_mLO3_ch02", "gLO_sBM_cAll_d15_mLO3_ch08"], "gMH": ["gMH_sBM_cAll_d24_mMH5_ch01", "gMH_sBM_cAll_d24_mMH4_ch03", "gMH_sBM_cAll_d24_mMH3_ch05", "gMH_sBM_cAll_d23_mMH4_ch07", "gMH_sBM_cAll_d22_mMH1_ch10", "gMH_sBM_cAll_d22_mMH3_ch09", "gMH_sFM_cAll_d22_mMH1_ch02", "gMH_sBM_cAll_d24_mMH3_ch09", "gMH_sBM_cAll_d22_mMH1_ch02", "gMH_sFM_cAll_d24_mMH1_ch16", "gMH_sBM_cAll_d23_mMH4_ch06", "gMH_sBM_cAll_d22_mMH3_ch01", "gMH_sBM_cAll_d22_mMH2_ch04", "gMH_sBM_cAll_d24_mMH5_ch08", "gMH_sBM_cAll_d22_mMH3_ch10", "gMH_sBM_cAll_d22_mMH1_ch06", "gMH_sFM_cAll_d23_mMH3_ch11", "gMH_sFM_cAll_d23_mMH0_ch14", "gMH_sBM_cAll_d22_mMH3_ch07", "gMH_sBM_cAll_d23_mMH0_ch06"], "gLH": ["gLH_sBM_cAll_d16_mLH3_ch08", "gLH_sFM_cAll_d17_mLH0_ch14", "gLH_sBM_cAll_d17_mLH5_ch05", "gLH_sBM_cAll_d17_mLH5_ch04", "gLH_sBM_cAll_d17_mLH0_ch08", "gLH_sBM_cAll_d17_mLH1_ch05", "gLH_sBM_cAll_d17_mLH1_ch10", "gLH_sBM_cAll_d16_mLH2_ch07", "gLH_sBM_cAll_d18_mLH5_ch02", "gLH_sFM_cAll_d17_mLH4_ch12", "gLH_sBM_cAll_d18_mLH2_ch05", "gLH_sBM_cAll_d17_mLH0_ch10", "gLH_sBM_cAll_d18_mLH4_ch09", "gLH_sBM_cAll_d16_mLH1_ch04", "gLH_sFM_cAll_d17_mLH1_ch09", "gLH_sBM_cAll_d16_mLH2_ch04", "gLH_sBM_cAll_d17_mLH4_ch01", "gLH_sBM_cAll_d16_mLH2_ch09", "gLH_sBM_cAll_d17_mLH1_ch06", "gLH_sBM_cAll_d17_mLH5_ch06"], "gHO": ["gHO_sBM_cAll_d19_mHO0_ch02", "gHO_sFM_cAll_d21_mHO2_ch17", "gHO_sBM_cAll_d21_mHO4_ch07", "gHO_sBM_cAll_d20_mHO4_ch01", "gHO_sBM_cAll_d19_mHO2_ch06", "gHO_sBM_cAll_d21_mHO4_ch05", "gHO_sBM_cAll_d19_mHO2_ch09", "gHO_sBM_cAll_d21_mHO5_ch06", "gHO_sBM_cAll_d20_mHO4_ch09", "gHO_sBM_cAll_d20_mHO1_ch01", "gHO_sFM_cAll_d19_mHO1_ch02", "gHO_sBM_cAll_d19_mHO0_ch10", "gHO_sBM_cAll_d21_mHO2_ch07", "gHO_sBM_cAll_d21_mHO4_ch04", "gHO_sBM_cAll_d21_mHO3_ch03", "gHO_sBM_cAll_d21_mHO5_ch08", "gHO_sBM_cAll_d21_mHO5_ch07", "gHO_sBM_cAll_d19_mHO0_ch01", "gHO_sBM_cAll_d21_mHO2_ch04", "gHO_sBM_cAll_d21_mHO5_ch10"], "gWA": ["gWA_sBM_cAll_d27_mWA4_ch07", "gWA_sFM_cAll_d27_mWA0_ch15", "gWA_sBM_cAll_d26_mWA1_ch04", "gWA_sBM_cAll_d26_mWA0_ch10", "gWA_sBM_cAll_d26_mWA1_ch05", "gWA_sBM_cAll_d26_mWA0_ch01", "gWA_sBM_cAll_d27_mWA4_ch10", "gWA_sBM_cAll_d27_mWA5_ch03", "gWA_sBM_cAll_d27_mWA4_ch01", "gWA_sBM_cAll_d27_mWA5_ch01", "gWA_sBM_cAll_d27_mWA3_ch01", "gWA_sBM_cAll_d26_mWA5_ch07", "gWA_sBM_cAll_d27_mWA3_ch07", "gWA_sFM_cAll_d25_mWA3_ch04", "gWA_sBM_cAll_d27_mWA4_ch09", "gWA_sBM_cAll_d25_mWA1_ch02", "gWA_sBM_cAll_d27_mWA2_ch02", "gWA_sBM_cAll_d27_mWA3_ch10", "gWA_sBM_cAll_d27_mWA2_ch09", "gWA_sFM_cAll_d26_mWA3_ch11"], "gKR": ["gKR_sBM_cAll_d28_mKR2_ch04", "gKR_sBM_cAll_d30_mKR4_ch01", "gKR_sBM_cAll_d29_mKR1_ch10", "gKR_sBM_cAll_d30_mKR5_ch06", "gKR_sBM_cAll_d30_mKR4_ch03", "gKR_sBM_cAll_d29_mKR4_ch04", "gKR_sBM_cAll_d29_mKR0_ch01", "gKR_sBM_cAll_d30_mKR4_ch10", "gKR_sFM_cAll_d30_mKR4_ch19", "gKR_sBM_cAll_d28_mKR3_ch08", "gKR_sBM_cAll_d29_mKR0_ch03", "gKR_sBM_cAll_d28_mKR2_ch08", "gKR_sBM_cAll_d29_mKR5_ch03", "gKR_sBM_cAll_d29_mKR1_ch06", "gKR_sFM_cAll_d29_mKR5_ch13", "gKR_sBM_cAll_d30_mKR2_ch07", "gKR_sFM_cAll_d30_mKR3_ch21", "gKR_sBM_cAll_d28_mKR0_ch04", "gKR_sBM_cAll_d30_mKR4_ch05", "gKR_sBM_cAll_d29_mKR1_ch05"], "gJS": ["gJS_sBM_cAll_d03_mJS3_ch05", "gJS_sBM_cAll_d01_mJS2_ch04", "gJS_sFM_cAll_d01_mJS1_ch02", "gJS_sBM_cAll_d01_mJS3_ch10", "gJS_sBM_cAll_d03_mJS4_ch04", "gJS_sBM_cAll_d01_mJS0_ch09", "gJS_sBM_cAll_d03_mJS4_ch06", 
"gJS_sBM_cAll_d02_mJS4_ch02", "gJS_sBM_cAll_d02_mJS1_ch05", "gJS_sBM_cAll_d03_mJS2_ch05", "gJS_sBM_cAll_d03_mJS3_ch08", "gJS_sBM_cAll_d02_mJS4_ch01", "gJS_sBM_cAll_d02_mJS4_ch10", "gJS_sBM_cAll_d02_mJS1_ch03", "gJS_sBM_cAll_d01_mJS1_ch06", "gJS_sBM_cAll_d02_mJS5_ch01", "gJS_sBM_cAll_d02_mJS5_ch07", "gJS_sBM_cAll_d02_mJS0_ch05", "gJS_sBM_cAll_d03_mJS4_ch03", "gJS_sFM_cAll_d03_mJS0_ch01"], "gJB": ["gJB_sBM_cAll_d07_mJB0_ch03", "gJB_sBM_cAll_d07_mJB1_ch09", "gJB_sBM_cAll_d08_mJB1_ch03", "gJB_sBM_cAll_d07_mJB3_ch02", "gJB_sBM_cAll_d09_mJB4_ch05", "gJB_sBM_cAll_d07_mJB2_ch05", "gJB_sBM_cAll_d09_mJB5_ch10", "gJB_sFM_cAll_d07_mJB3_ch07", "gJB_sBM_cAll_d09_mJB5_ch05", "gJB_sBM_cAll_d09_mJB3_ch07", "gJB_sBM_cAll_d07_mJB0_ch02", "gJB_sFM_cAll_d08_mJB1_ch09", "gJB_sBM_cAll_d07_mJB0_ch04", "gJB_sFM_cAll_d07_mJB4_ch05", "gJB_sBM_cAll_d09_mJB3_ch08", "gJB_sBM_cAll_d08_mJB5_ch09", "gJB_sBM_cAll_d08_mJB1_ch09", "gJB_sBM_cAll_d08_mJB4_ch03", "gJB_sBM_cAll_d08_mJB4_ch08", "gJB_sBM_cAll_d09_mJB2_ch04"]} -------------------------------------------------------------------------------- /src/data/dataset/loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google AI Perception Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """AIST++ Dataset Loader.""" 16 | import json 17 | import os 18 | import pickle 19 | 20 | import aniposelib 21 | import numpy as np 22 | 23 | # 17 joints of COCO: 24 | # 0 - nose, 1 - left_eye, 2 - right_eye, 3 - left_ear, 4 - right_ear 25 | # 5 - left_shoulder, 6 - right_shoulder, 7 - left_elbow, 8 - right_elbow, 9 - left_wrist, 10 - right_wrist 26 | # 11 - left_hip, 12 - right_hip, 13 - left_knee, 14 - right_knee. 15 - left_ankle, 16 - right_ankle 27 | 28 | class AISTDataset: 29 | """A dataset class for loading, processing and plotting AIST++.""" 30 | # use this link to check naming method: https://aistdancedb.ongaaccel.jp/data_formats/ 31 | 32 | VIEWS = ['c01', 'c02', 'c03', 'c04', 'c05', 'c06', 'c07', 'c08', 'c09'] 33 | 34 | def __init__(self, anno_dir): 35 | assert os.path.exists(anno_dir), f'Data does not exist at {anno_dir}!' 
36 | 37 | # Init paths 38 | self.camera_dir = os.path.join(anno_dir, 'cameras/') 39 | self.motion_dir = os.path.join(anno_dir, 'motions/') 40 | self.keypoint3d_dir = os.path.join(anno_dir, 'keypoints3d/') 41 | self.keypoint2d_dir = os.path.join(anno_dir, 'keypoints2d/') 42 | self.splits_dir = os.path.join(anno_dir, 'splits/') 43 | filter_file = os.path.join(anno_dir, 'ignore_list.txt') 44 | with open(filter_file, "r") as f: 45 | self.filter_file = [_[:-1] for _ in f.readlines()] 46 | 47 | # Load environment setting mapping 48 | self.mapping_seq2env = {} # sequence name -> env name 49 | self.mapping_env2seq = {} # env name -> a list of sequence names 50 | env_mapping_file = os.path.join(self.camera_dir, 'mapping.txt') 51 | env_mapping = np.loadtxt(env_mapping_file, dtype=str) 52 | for seq_name, env_name in env_mapping: 53 | self.mapping_seq2env[seq_name] = env_name 54 | if env_name not in self.mapping_env2seq: 55 | self.mapping_env2seq[env_name] = [] 56 | self.mapping_env2seq[env_name].append(seq_name) 57 | 58 | @classmethod 59 | def get_video_name(cls, seq_name, view): 60 | """Get AIST video name from AIST++ sequence name.""" 61 | return seq_name.replace('cAll', view) 62 | 63 | @classmethod 64 | def get_seq_name(cls, video_name): 65 | """Get AIST++ sequence name from AIST video name.""" 66 | tags = video_name.split('_') 67 | if len(tags) == 3: 68 | view = tags[1] 69 | tags[1] = 'cAll' 70 | else: 71 | view = tags[2] 72 | tags[2] = 'cAll' 73 | return '_'.join(tags), view 74 | 75 | @classmethod 76 | def load_camera_group(cls, camera_dir, env_name): 77 | """Load a set of cameras in the environment.""" 78 | file_path = os.path.join(camera_dir, f'{env_name}.json') 79 | assert os.path.exists(file_path), f'File {file_path} does not exist!' 80 | with open(file_path, 'r') as f: 81 | params = json.load(f) 82 | cameras = [] 83 | for param_dict in params: 84 | camera = aniposelib.cameras.Camera(name=param_dict['name'], 85 | size=param_dict['size'], 86 | matrix=param_dict['matrix'], 87 | rvec=param_dict['rotation'], 88 | tvec=param_dict['translation'], 89 | dist=param_dict['distortions']) 90 | cameras.append(camera) 91 | camera_group = aniposelib.cameras.CameraGroup(cameras) 92 | return camera_group 93 | 94 | @classmethod 95 | def load_motion(cls, motion_dir, seq_name): 96 | """Load a motion sequence represented using SMPL format.""" 97 | file_path = os.path.join(motion_dir, f'{seq_name}.pkl') 98 | assert os.path.exists(file_path), f'File {file_path} does not exist!' 99 | with open(file_path, 'rb') as f: 100 | data = pickle.load(f) 101 | smpl_poses = data['smpl_poses'] # (N, 24, 3) 102 | smpl_scaling = data['smpl_scaling'] # (1,) 103 | smpl_trans = data['smpl_trans'] # (N, 3) 104 | return smpl_poses, smpl_scaling, smpl_trans 105 | 106 | # @classmethod 107 | def load_keypoint3d(self, seq_name, use_optim=True): 108 | """Load a 3D keypoint sequence represented using COCO format.""" 109 | file_path = os.path.join(self.keypoint3d_dir, f'{seq_name}.pkl') 110 | assert os.path.exists(file_path), f'File {file_path} does not exist!' 111 | with open(file_path, 'rb') as f: 112 | data = pickle.load(f) 113 | if use_optim: 114 | return data['keypoints3d_optim'] # (N, 17, 3) 115 | else: 116 | return data['keypoints3d'] # (N, 17, 3) 117 | 118 | @classmethod 119 | def load_keypoint2d(cls, keypoint_dir, seq_name): 120 | """Load a 2D keypoint sequence represented using COCO format.""" 121 | file_path = os.path.join(keypoint_dir, f'{seq_name}.pkl') 122 | assert os.path.exists(file_path), f'File {file_path} does not exist!' 
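        # The 2D keypoint pickle unpacked below stores per-view detections; the leading nviews axis of
        # keypoints2d and det_scores presumably follows the 9 camera views listed in VIEWS (c01-c09).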
123 | with open(file_path, 'rb') as f: 124 | data = pickle.load(f) 125 | keypoints2d = data['keypoints2d'] # (nviews, N, 17, 3) 126 | det_scores = data['det_scores'] # (nviews, N) 127 | timestamps = data['timestamps'] # (N,) 128 | return keypoints2d, det_scores, timestamps 129 | -------------------------------------------------------------------------------- /plb/datamodules/data_transform.py: -------------------------------------------------------------------------------- 1 | import math, random 2 | import numpy as np 3 | import torch 4 | 5 | def body_center(joint, dim=3, lhip_id=11, rhip_id=12): 6 | # TODO import lhip_id and rhip_id from dataset 7 | lhip = joint[lhip_id * dim:lhip_id * dim + dim] 8 | rhip = joint[rhip_id * dim:rhip_id * dim + dim] 9 | body_center = (lhip + rhip) * 0.5 10 | return body_center 11 | 12 | 13 | def body_unit(joint, dim=3, lhip_id=11, rhip_id=12): 14 | # TODO import lhip_id and rhip_id from dataset 15 | lhip = joint[lhip_id * dim:lhip_id * dim + dim] 16 | rhip = joint[rhip_id * dim:rhip_id * dim + dim] 17 | unit = torch.linalg.norm(lhip - rhip, ord=2) # positive hip width 18 | return unit 19 | 20 | 21 | def euler_rodrigues_rotation(theta, axis): 22 | # https://en.wikipedia.org/wiki/Euler%E2%80%93Rodrigues_formula 23 | a = np.cos(theta / 2.0) 24 | b, c, d = - axis * np.sin(theta / 2.0) 25 | aa, bb, cc, dd = a * a, b * b, c * c, d * d 26 | bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d 27 | rot_matrix = torch.tensor([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], 28 | [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], 29 | [2 * (bd + ac), 2 * (cd - ab), 30 | aa + dd - bb - cc]]).float() # originally cosine returns double 31 | return rot_matrix 32 | 33 | 34 | def time_distortion_func(a, b): 35 | pass 36 | 37 | 38 | def time_scaling_func(s): 39 | assert s > 0 40 | 41 | 42 | class SkeletonTransform(object): 43 | def __init__(self, aug_shift_prob, aug_shift_range, aug_rot_prob, aug_rot_range, min_length, max_length, aug_time_prob, aug_time_rate): 44 | # TODO: move to config if necessary 45 | self.axis = np.array([0, 0, 1]) 46 | # axis = axis / math.sqrt(np.dot(axis, axis)) 47 | self.norm_frame = 0 48 | self.dim = 3 49 | self.aug_shift_prob = aug_shift_prob 50 | self.aug_shift_range = aug_shift_range 51 | self.aug_rot_prob = aug_rot_prob 52 | self.aug_rot_range = aug_rot_range 53 | self.aug_time_prob = aug_time_prob 54 | self.aug_time_faster_prob = 0.5 55 | assert 1 <= aug_time_rate < 2 56 | self.aug_time_rate = aug_time_rate 57 | self.min_length = min_length 58 | self.max_length = max_length 59 | 60 | def __call__(self, x, shut=False, seed=None): 61 | if seed: 62 | random.seed(seed) 63 | # input and output, torch tensor of shape [T, 51] 64 | # returns transformed tensor, and a list of int depicting its correspondence to original index 65 | ttl = x.shape[0] 66 | joint_num = int(x.shape[-1] / self.dim) 67 | 68 | norm_frame = random.randint(0, ttl-1) 69 | # print("Norm Frame is", norm_frame) 70 | # norm_frame = self.norm_frame 71 | # spatial translation normalization by first frame 72 | if joint_num == 17: 73 | ct = body_center(x[norm_frame]) 74 | else: 75 | ct = body_center(x[norm_frame], lhip_id=12, rhip_id=16) 76 | x -= ct.repeat(joint_num).unsqueeze(0) 77 | 78 | assert not x.isnan().any(), "After Translation Normalization" 79 | 80 | if joint_num == 17: 81 | # spatial rotation normalization by first frame 82 | if joint_num == 17: 83 | lh = x[norm_frame, 33:35] # left hip x, left hip y 84 | else: 85 | lh = x[norm_frame, 36:38] # left hip x, 
left hip y 86 | theta = float(-np.arccos(lh[0] / np.sqrt(np.dot(lh, lh)))) 87 | ttt = euler_rodrigues_rotation(theta, self.axis) 88 | x = (x.reshape(-1, self.dim) @ ttt.transpose(1, 0)).reshape(ttl, joint_num * self.dim) 89 | 90 | assert not x.isnan().any(), "After Rotation Normalization" 91 | 92 | # spatial rotation augmentation 93 | if random.random() < self.aug_rot_prob and not shut: # let's rotate 94 | theta = 2 * math.pi * (random.random() - .5) * self.aug_rot_range 95 | # Euler-Rodrigues formula 96 | ttt = euler_rodrigues_rotation(theta, self.axis) 97 | x = (x.reshape(-1, self.dim) @ ttt.transpose(1, 0)).reshape(ttl, joint_num * self.dim) 98 | 99 | assert not x.isnan().any(), "After Rotation Augmentation" 100 | 101 | # TODO: spatial scaling augmentation 102 | 103 | # spatial translation augmentation 104 | if random.random() < self.aug_shift_prob and not shut: # let's translate 105 | if joint_num == 17: 106 | unit = body_unit(x[self.norm_frame]) 107 | else: 108 | unit = body_unit(x[self.norm_frame], lhip_id=12, rhip_id=16) 109 | move_x = (random.random() - .5) * self.aug_shift_range * unit 110 | move_y = (random.random() - .5) * self.aug_shift_range * unit 111 | move_z = 0 #(random.random() - .5) * self.aug_shift_range * unit # TODO: this is very suspicious! 112 | shift = torch.Tensor([move_x, move_y, move_z]) 113 | x = (x.reshape(-1, self.dim) - shift.unsqueeze(0)).reshape(ttl, joint_num * self.dim) 114 | 115 | assert not x.isnan().any(), "After Translation Augmentation" 116 | 117 | # temporal distortion augmentation 118 | # The index of velo is the index of distorted video while the value of velo is the index of original video. 119 | # It's an index to index pair 120 | velo = np.arange(ttl) 121 | # requirement 1: len(velo) == x.size(0) 122 | # requirement 2: velo[0] == 0 and velo[-1] == ttl - 1 123 | if random.random() < self.aug_time_prob and not shut: # let's do velocity augmentation 124 | # uniform (1, self.aug_time_rate) 125 | t_scale_ = (1 - self.aug_time_rate) * random.random() + self.aug_time_rate 126 | if random.random() < 0.5: # slower 127 | t_scale = t_scale_ 128 | else: # faster 129 | t_scale = 1.0 / t_scale_ 130 | 131 | # slower or faster 132 | # t_scale = 2 * (self.aug_time_rate - 1) * random.random() + 2 - self.aug_time_rate 133 | # only slower 134 | # t_scale = (self.aug_time_rate - 1) * random.random() + 1 135 | # only faster 136 | # t_scale = (self.aug_time_rate - 1) * random.random() + 2 - self.aug_time_rate 137 | 138 | new_ttl = int(ttl * t_scale + 0.5) 139 | new_velo = np.arange(new_ttl) 140 | # TODO: this is not exact nearest neighbor. +0.5 may cause new_velo[-1] == ttl. 
deal with it later 141 | new_velo = new_velo / t_scale 142 | # assert new_velo[-1] <= ttl - 1 143 | new_x = x[np.floor(new_velo).astype(int)] 144 | 145 | # if self.aug_time_rate == 1: 146 | # assert ttl == new_ttl 147 | # assert (velo - new_velo).sum().item() < 1e-12 148 | # assert (x - new_x).sum().item() < 1e-12 149 | 150 | ttl = new_ttl 151 | velo = new_velo 152 | x = new_x 153 | assert not x.isnan().any(), "After Temporal Augmentation" 154 | 155 | return x, velo 156 | -------------------------------------------------------------------------------- /cluster.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pprint 4 | import shutil 5 | import time 6 | import sys 7 | import yaml 8 | from tqdm import tqdm 9 | import numpy as np 10 | import torch 11 | import pandas as pd 12 | from pathlib import Path 13 | 14 | from src.data.dataset.loader import AISTDataset 15 | from src import algo 16 | from src.data.dataset.cluster_misc import lexicon, get_names, genre_list 17 | 18 | from plb.models.self_supervised import TAN 19 | from plb.models.self_supervised.tan import TANEvalDataTransform, TANTrainDataTransform 20 | from plb.datamodules import SeqDataModule 21 | from plb.datamodules.data_transform import body_center, euler_rodrigues_rotation 22 | 23 | KEYPOINT_NAME = ["nose", "left_eye", "right_eye", "left_ear", "right_ear", 24 | "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", "left_wrist", "right_wrist", 25 | "left_hip", "right_hip", "left_knee", "right_knee", "left_ankle", "right_ankle"] 26 | 27 | import pytorch_lightning as pl 28 | pl.utilities.seed.seed_everything(0) 29 | 30 | def plain_distance(a, b): 31 | return np.linalg.norm(a - b, ord=2) 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser(description='Train classification network') 35 | 36 | parser.add_argument('--cfg', 37 | help='experiment configure file name', 38 | required=True, 39 | type=str) 40 | 41 | parser.add_argument('--data_dir', 42 | help='path to aistplusplus data directory from repo root', 43 | type=str) 44 | 45 | parser.add_argument('--seed', 46 | help='seed for this run', 47 | default=1, 48 | type=int) 49 | 50 | args, _ = parser.parse_known_args() 51 | pl.utilities.seed.seed_everything(args.seed) 52 | np.random.seed(args.seed) 53 | torch.manual_seed(args.seed) 54 | with open(args.cfg, 'r') as stream: 55 | ldd = yaml.safe_load(stream) 56 | 57 | if args.data_dir: 58 | ldd["PRETRAIN"]["DATA"]["DATA_DIR"] = args.data_dir 59 | pprint.pprint(ldd) 60 | return ldd 61 | 62 | def main(): 63 | args = parse_args() 64 | debug = args["NAME"] == "debug" 65 | log_dir = os.path.join("./logs", args["NAME"]) 66 | 67 | dirpath = Path(log_dir) 68 | dirpath.mkdir(parents=True, exist_ok=True) 69 | 70 | timed = time.strftime("%Y%m%d_%H%M%S") 71 | with open(os.path.join(log_dir, f"config_used_{timed}.yaml"), "w") as stream: 72 | yaml.dump(args, stream, default_flow_style=False) 73 | video_dir = os.path.join(log_dir, "saved_videos") 74 | Path(video_dir).mkdir(parents=True, exist_ok=True) 75 | official_loader = AISTDataset(os.path.join(args["PRETRAIN"]["DATA"]["DATA_DIR"], "annotations")) 76 | 77 | if 1: # change this to 0 for skeleton experiment 78 | # get model 79 | load_name = args["CLUSTER"]["CKPT"] if args["CLUSTER"]["CKPT"] != -1 else args["NAME"] 80 | with open(os.path.join(log_dir, f"val_cluster_zrsc_scores.txt"), "a") as f: 81 | f.write(f"EXP: {load_name}\n") 82 | cfg = None 83 | for fn in os.listdir(os.path.join("./logs", load_name)): 
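            # This scan keeps the last "*.yaml" that os.listdir returns inside
            # ./logs/<load_name> -- presumably the config_used_<timestamp>.yaml
            # dumped when that experiment was launched. It is reloaded below as
            # old_args to look up PRETRAIN.ALGO, i.e. which LightningModule class
            # should call load_from_checkpoint; if several configs coexist,
            # whichever one is listed last wins.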
84 | if fn.endswith(".yaml"): 85 | cfg = fn 86 | with open(os.path.join("./logs", load_name, cfg), 'r') as stream: 87 | old_args = yaml.safe_load(stream) 88 | cpt_name = os.listdir(os.path.join("./logs", load_name, "default/version_0/checkpoints"))[0] 89 | print(f"We are using checkpoint: {cpt_name}") 90 | model = eval(old_args["PRETRAIN"]["ALGO"]).load_from_checkpoint(os.path.join("./logs", load_name, "default/version_0/checkpoints", cpt_name)) 91 | model.eval() 92 | def ske2feat(ldd): 93 | ldd1 = torch.Tensor(ldd).flatten(1, -1) / 100 # [T, 51] 94 | ttl = ldd1.shape[0] 95 | ct = body_center(ldd1[0]) 96 | ldd1 -= ct.repeat(17).unsqueeze(0) 97 | res1 = model(ldd1.unsqueeze(0), torch.tensor([ttl])) 98 | forward_feat = res1[:, 0] # [T1, f] 99 | forward_feat /= torch.linalg.norm(forward_feat, dim=-1, keepdim=True, ord=2) 100 | return forward_feat 101 | else: 102 | # to get results for using raw skeleton, swap with 103 | def ske2feat(ldd): 104 | ldd1 = torch.Tensor(ldd).flatten(1, -1) / 100 # [T, 51] 105 | ttl = ldd1.shape[0] 106 | ct = body_center(ldd1[0]) 107 | ldd1 -= ct.repeat(17).unsqueeze(0) 108 | return ldd1 109 | 110 | # get data 111 | tr_kpt_container = [] 112 | tr_len_container = [] 113 | tr_feat_container = [] 114 | tr_name_container = [] 115 | val_kpt_container = [] 116 | val_len_container = [] 117 | val_feat_container = [] 118 | val_name_container = [] 119 | for genre in genre_list: # mix every genre together 120 | # train data, we only have training set in this setting 121 | tr_df = get_names(genre, trval="train", seed=4321) 122 | tr_df = tr_df[tr_df["situ"] == "sFM"] 123 | val_df = get_names(genre, trval="val", seed=4321) 124 | val_df = val_df[val_df["situ"] == "sFM"] 125 | for reference_name in tqdm(list(tr_df["name"]), desc='Loading training set features'): 126 | ldd = official_loader.load_keypoint3d(reference_name) 127 | tr_kpt_container.append(ldd) 128 | tr_len_container.append(ldd.shape[0]) 129 | tr_feat_container.append(ske2feat(ldd).detach().cpu().numpy()) 130 | tr_name_container.append(reference_name) 131 | for reference_name in tqdm(list(val_df["name"]), desc='Loading validation set features'): 132 | ldd = official_loader.load_keypoint3d(reference_name) 133 | val_kpt_container.append(ldd) 134 | val_len_container.append(ldd.shape[0]) 135 | val_feat_container.append(ske2feat(ldd).detach().cpu().numpy()) 136 | val_name_container.append(reference_name) 137 | 138 | tr_where_to_cut = [0, ] + list(np.cumsum(np.array(tr_len_container))) 139 | tr_stacked = np.vstack(tr_feat_container) 140 | val_where_to_cut = [0, ] + list(np.cumsum(np.array(val_len_container))) 141 | val_stacked = np.vstack(val_feat_container) 142 | 143 | for K in range(args["CLUSTER"]["K_MIN"], args["CLUSTER"]["K_MAX"], 10): 144 | argument_dict = {"distance": plain_distance, "TYPE": "vanilla", "K": K, "TOL": 1e-4} 145 | if not os.path.exists(os.path.join(log_dir, f"advanced_centers_{K}.npy")): 146 | c = getattr(algo, args["CLUSTER"]["TYPE"])(tr_stacked, times=args["CLUSTER"]["TIMES"], argument_dict=argument_dict) 147 | np.save(os.path.join(log_dir, f"advanced_centers_{K}.npy"), c.kmeans.cluster_centers_) 148 | else: 149 | ctrs = np.load(os.path.join(log_dir, f"advanced_centers_{K}.npy")) 150 | c = getattr(algo, args["CLUSTER"]["TYPE"] + "_clusterer")(TIMES=args["CLUSTER"]["TIMES"], K=K, TOL=1e-4) 151 | c.fit(tr_stacked[:K]) 152 | c.kmeans.cluster_centers_ = ctrs 153 | # infer on training set and save 154 | y = np.concatenate([np.ones((l,)) * i for i, l in enumerate(tr_len_container)], axis=0) 155 | s = 
np.concatenate([np.arange(l) for i, l in enumerate(tr_len_container)], axis=0) 156 | tr_res_df = pd.DataFrame(y, columns=["y"]) # from which sequence 157 | cluster_l = c.get_assignment(tr_stacked) # assigned to which cluster 158 | tr_res_df['cluster'] = cluster_l 159 | tr_res_df['frame_index'] = s # the frame index in home sequence 160 | tr_word_df = pd.DataFrame(columns=["idx", "word", "length", "y", "name"]) # word index in home sequence 161 | for sequence_idx in range(len(tr_len_container)): 162 | name = tr_name_container[sequence_idx] 163 | cluster_seq = list(cluster_l[tr_where_to_cut[sequence_idx]: tr_where_to_cut[sequence_idx + 1]]) + [-1, ] 164 | running_idx = 0 165 | prev = -1 166 | current_len = 0 167 | for cc in cluster_seq: 168 | if cc == prev: 169 | current_len += 1 170 | else: 171 | tr_word_df = tr_word_df.append( 172 | {"idx": int(running_idx), "word": lexicon[prev], "length": current_len, "y": sequence_idx, 173 | "name": name}, ignore_index=True) 174 | running_idx += 1 175 | current_len = 1 176 | prev = cc 177 | tr_word_df = tr_word_df[tr_word_df["idx"] > 0] 178 | tr_word_df.to_pickle(dirpath / f"advanced_tr_{K}.pkl") 179 | print(f"advanced_tr_{K}.pkl dumped to {log_dir}") # saved tokenization of training set 180 | 181 | # infer on validation set and save 182 | y = np.concatenate([np.ones((l,)) * i for i, l in enumerate(val_len_container)], axis=0) 183 | s = np.concatenate([np.arange(l) for i, l in enumerate(val_len_container)], axis=0) 184 | val_res_df = pd.DataFrame(y, columns=["y"]) # from which sequence 185 | cluster_l = c.get_assignment(val_stacked) # assigned to which cluster 186 | val_res_df['cluster'] = cluster_l 187 | val_res_df['frame_index'] = s # the frame index in home sequence 188 | val_word_df = pd.DataFrame(columns=["idx", "word", "length", "y", "name"]) # word index in home sequence 189 | for sequence_idx in range(len(val_len_container)): 190 | name = val_name_container[sequence_idx] 191 | cluster_seq = list(cluster_l[val_where_to_cut[sequence_idx]: val_where_to_cut[sequence_idx + 1]]) + [-1, ] 192 | running_idx = 0 193 | prev = -1 194 | current_len = 0 195 | for cc in cluster_seq: 196 | if cc == prev: 197 | current_len += 1 198 | else: 199 | val_word_df = val_word_df.append( 200 | {"idx": int(running_idx), "word": lexicon[prev], "length": current_len, "y": sequence_idx, 201 | "name": name}, ignore_index=True) 202 | running_idx += 1 203 | current_len = 1 204 | prev = cc 205 | val_word_df = val_word_df[val_word_df["idx"] > 0] 206 | val_word_df.to_pickle(dirpath / f"advanced_val_{K}.pkl") 207 | print(f"advanced_val_{K}.pkl dumped to {log_dir}") # saved tokenization of validation set 208 | 209 | if __name__ == '__main__': 210 | main() 211 | -------------------------------------------------------------------------------- /plb/datamodules/seq_datamodule.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import random 3 | import os 4 | import numpy as np 5 | from typing import Any, Callable, Optional 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from pytorch_lightning import LightningDataModule 10 | from torch.utils.data import DataLoader, Dataset, IterableDataset 11 | 12 | from .dataset import SkeletonDataset 13 | from .data_transform import SkeletonTransform 14 | 15 | num_cpu = multiprocessing.cpu_count() 16 | preciser = True 17 | 18 | # mp worker 19 | def get_from_name(loader, seq_name): 20 | return seq_name, loader.load_keypoint3d(seq_name) 21 | 22 | 23 | def 
select_valid_3D_skeletons(skeletons, thre): 24 | # fresh out of official loader is [T, 17, 3] in centinetre 25 | max_per_frame = np.max(np.max(skeletons, axis=-1), axis=-1) # [T] 26 | sel = max_per_frame < thre 27 | kicks = np.sum(~(sel)).item() 28 | return sel, kicks 29 | 30 | miu = 3 31 | scale = 1 32 | def gaussian(m): 33 | # input a torch tensor with the un-abs-ed distance 34 | # returns a Gaussian value of the same shape 35 | m = torch.abs(m) 36 | sh = - torch.square(m / miu) / 2 37 | ex = torch.exp(sh) * scale 38 | return ex 39 | 40 | def two_trans_collate(lop): 41 | first_len = torch.tensor([len(_[2]) for _ in lop]) 42 | second_len = torch.tensor([len(_[3]) for _ in lop]) 43 | first_max = torch.max(first_len).item() 44 | second_max = torch.max(second_len).item() 45 | 46 | first_container = [] 47 | second_container = [] 48 | # first_container_velo = [] 49 | # second_container_velo = [] 50 | rect_container = [] 51 | view1_container = [] 52 | view2_container = [] 53 | indices1 = [] 54 | indices2 = [] 55 | chopped_bs = [] 56 | for b, (i1, i2, velo1, velo2) in enumerate(lop): 57 | first_container.append(torch.cat([i1, torch.zeros(size=(first_max - len(velo1), i1.shape[-1]))], dim=0)) 58 | second_container.append(torch.cat([i2, torch.zeros(size=(second_max - len(velo2), i2.shape[-1]))], dim=0)) 59 | # first_container_velo.append(torch.cat([torch.tensor(velo1), torch.zeros(size=(first_max - len(velo1),))], dim=0)) 60 | # second_container_velo.append(torch.cat([torch.tensor(velo2), torch.zeros(size=(second_max - len(velo2),))], dim=0)) 61 | # creating a rectangle for loss calculation, pending a lot of heuristic design 62 | dist = torch.tensor(velo1).unsqueeze(-1) - torch.tensor(velo2).unsqueeze(0) # [t1, t2] 63 | rect = gaussian(dist).float() 64 | rect_container.append(rect) 65 | 66 | dist = torch.tensor(velo1).unsqueeze(-1) - torch.tensor(velo1).unsqueeze(0) # [t1, t1] 67 | rect = gaussian(dist).float() 68 | view1_container.append(rect) 69 | 70 | dist = torch.tensor(velo2).unsqueeze(-1) - torch.tensor(velo2).unsqueeze(0) # [t2, t2] 71 | rect = gaussian(dist).float() 72 | view2_container.append(rect) 73 | 74 | floor_velo_dist = torch.tensor(np.floor(velo1).astype(int)).unsqueeze(-1) - torch.tensor(np.floor(velo2).astype(int)).unsqueeze(0) 75 | loc1, loc2 = torch.where(floor_velo_dist == 1) 76 | if preciser: 77 | # here we remove even more positives 78 | keep = [] 79 | prev = -1 80 | luck = [] 81 | for i in range(loc1.shape[0]): 82 | if loc1[i] == prev: 83 | luck.append(i) 84 | else: 85 | if len(luck): 86 | chosen = random.choice(luck) 87 | keep.append(chosen) 88 | luck = [i] 89 | prev = loc1[i] 90 | chosen = random.choice(luck) 91 | keep.append(chosen) 92 | loc1, loc2 = loc1[keep], loc2[keep] 93 | keep = [] 94 | prev = -1 95 | luck = [] 96 | for i in range(loc2.shape[0]): 97 | if loc2[i] == prev: 98 | luck.append(i) 99 | else: 100 | if len(luck): 101 | chosen = random.choice(luck) 102 | keep.append(chosen) 103 | luck = [i] 104 | prev = loc1[i] 105 | chosen = random.choice(luck) 106 | keep.append(chosen) 107 | loc1, loc2 = loc1[keep], loc2[keep] 108 | 109 | indices1.append(loc1 + b * first_max) 110 | indices2.append(loc2 + b * second_max) 111 | chopped_bs.append(loc1.size(0)) 112 | indices1 = torch.cat(indices1, dim=0) 113 | indices2 = torch.cat(indices2, dim=0) 114 | assert len(indices1) == len(indices2) 115 | 116 | first_view = torch.stack(first_container) 117 | second_view = torch.stack(second_container) 118 | # first_velo = torch.stack(first_container_velo) 119 | # second_velo = 
torch.stack(second_container_velo) 120 | m = torch.block_diag(*rect_container) 121 | v1 = torch.block_diag(*view1_container) 122 | v2 = torch.block_diag(*view2_container) 123 | # m = torch.cat([torch.cat([v1, m], dim=1), torch.cat([m.t(), v2], dim=1)], dim=0) 124 | m = torch.cat([torch.cat([v1.new_zeros(v1.shape), m], dim=1), torch.cat([m.t(), v2.new_zeros(v2.shape)], dim=1)], dim=0) 125 | chopped_bs = torch.tensor(chopped_bs) 126 | 127 | return first_view, second_view, first_len, second_len, m, indices1, indices2, chopped_bs 128 | 129 | 130 | class SeqDataset(Dataset): 131 | def __init__(self, data, transform): 132 | # data: a list of torch tensor, each of shape [T, 51] 133 | self.data = data 134 | self.transform = transform 135 | 136 | def __len__(self): 137 | return len(self.data) 138 | 139 | def __getitem__(self, item): 140 | # innately has some randomness, will crop a continuous chunk from a video of batch size length 141 | tbc = self.data[item] 142 | return self.transform(tbc) 143 | 144 | 145 | class SeqDataModule(LightningDataModule): 146 | name = 'seq' 147 | 148 | def __init__( 149 | self, 150 | DATA_DIR, 151 | GENRE, 152 | SPLIT, 153 | BS, 154 | AUG_SHIFT_PROB, 155 | AUG_SHIFT_RANGE, 156 | AUG_ROT_PROB, 157 | AUG_ROT_RANGE, 158 | MIN_LENGTH, 159 | MAX_LENGTH, 160 | NUM_WORKERS, 161 | AUG_TIME_PROB, 162 | AUG_TIME_RATE, 163 | *args: Any, 164 | **kwargs: Any, 165 | ) -> None: 166 | super().__init__(*args, **kwargs) 167 | self.dataset = SkeletonDataset(DATA_DIR, GENRE, SPLIT) 168 | self.batch_size = BS # the batch size to show to outer modules 169 | self.aug_shift_prob = AUG_SHIFT_PROB 170 | self.aug_shift_range = AUG_SHIFT_RANGE 171 | self.aug_rot_prob = AUG_ROT_PROB 172 | self.aug_rot_range = AUG_ROT_RANGE 173 | self.min_length = MIN_LENGTH 174 | self.max_length = MAX_LENGTH 175 | self.train_data = [] 176 | self.val_data = [] 177 | self.num_samples = 0 # not the real num samples, but the total frame number in training set 178 | self.num_samples_valid = 0 179 | self.num_workers = NUM_WORKERS 180 | self.aug_time_prob = AUG_TIME_PROB 181 | self.aug_time_rate = AUG_TIME_RATE 182 | self.prepare_data() 183 | 184 | # TODO: move data washing into self.dataset? 
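        # The block below loads every sequence in train_split + validation_split
        # through a multiprocessing pool, drops frames whose largest coordinate
        # reaches self.threshold (500, with the skeletons still in centimetres at
        # that point), rescales them to metres, and fills the train/val containers:
        # SPLIT 4321 keeps only "sFM" sequences on both sides of the dumped split,
        # while SPLIT 1234 trains on every "sFM" sequence and validates on the
        # non-"sFM" names listed in the validation split.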
185 | # TODO: deal with constant 186 | self.num_proc = self.num_workers if self.num_workers > 0 else 1 187 | self.threshold = 500 188 | name_list = self.dataset.train_split + self.dataset.validation_split 189 | with multiprocessing.Pool(self.num_proc) as p: 190 | for name, res in p.starmap(get_from_name, 191 | tqdm(zip([self.dataset.official_loader] * len(name_list), name_list), 192 | total=len(name_list), desc='Loading training data...', leave=True)): 193 | sel, kicks = select_valid_3D_skeletons(res, self.threshold) 194 | if kicks > 0: 195 | print(f"kicking out {kicks} frames out of {name} for train, threshold is {self.threshold}") 196 | res = res[sel] 197 | # let's normalize for numerical stability 198 | res = res / 100.0 # originally in cm, now in m 199 | 200 | if SPLIT == 4321: 201 | if name in self.dataset.train_split and "sFM" in name: 202 | self.train_data.append(torch.tensor(res).float().flatten(1)) 203 | self.num_samples += 1 # related to simCLR or scheduling, important 204 | elif name in self.dataset.validation_split and "sFM" in name: 205 | self.val_data.append(torch.tensor(res).float().flatten(1)) 206 | self.num_samples_valid += 1 207 | else: 208 | pass 209 | elif SPLIT == 1234: 210 | if "sFM" in name: 211 | self.train_data.append(torch.tensor(res).float().flatten(1)) 212 | self.num_samples += 1 # related to simCLR or scheduling, important 213 | elif name in self.dataset.validation_split: 214 | self.val_data.append(torch.tensor(res).float().flatten(1)) 215 | self.num_samples_valid += 1 216 | else: 217 | pass 218 | else: 219 | assert 0, 'unknown split, should be in [1234, 4321]' 220 | 221 | print( 222 | f"SPLIT {SPLIT} dances loaded with {self.num_samples} training videos and {self.num_samples_valid} validation videos") 223 | 224 | def train_dataloader(self) -> DataLoader: 225 | train_dataset = SeqDataset(self.train_data, transform=self.train_transforms) 226 | train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, collate_fn=two_trans_collate, 227 | num_workers=self.num_workers) 228 | return train_dataloader 229 | 230 | def val_dataloader(self) -> DataLoader: 231 | val_dataset = SeqDataset(self.val_data, transform=self.val_transforms) 232 | val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=two_trans_collate, 233 | num_workers=self.num_workers) 234 | return val_dataloader 235 | 236 | def _default_transforms(self) -> Callable: 237 | data_transforms = SkeletonTransform(self.aug_shift_prob, self.aug_shift_range, self.aug_rot_prob, self.aug_rot_range, self.min_length, self.max_length, self.aug_time_prob, self.aug_time_rate) 238 | return data_transforms 239 | -------------------------------------------------------------------------------- /src/data/dataset/splits/split_wseed_4321.json: -------------------------------------------------------------------------------- 1 | {"gBR": ["gBR_sBM_cAll_d06_mBR4_ch02", "gBR_sBM_cAll_d04_mBR0_ch01", "gBR_sFM_cAll_d04_mBR2_ch03", "gBR_sBM_cAll_d05_mBR4_ch03", "gBR_sFM_cAll_d04_mBR4_ch05", "gBR_sBM_cAll_d04_mBR0_ch05", "gBR_sFM_cAll_d04_mBR5_ch06", "gBR_sBM_cAll_d04_mBR2_ch09", "gBR_sBM_cAll_d04_mBR3_ch08", "gBR_sBM_cAll_d05_mBR0_ch01", "gBR_sBM_cAll_d04_mBR0_ch02", "gBR_sFM_cAll_d05_mBR4_ch13", "gBR_sBM_cAll_d04_mBR1_ch02", "gBR_sBM_cAll_d05_mBR4_ch08", "gBR_sBM_cAll_d05_mBR0_ch02", "gBR_sBM_cAll_d05_mBR5_ch05", "gBR_sBM_cAll_d06_mBR3_ch04", "gBR_sBM_cAll_d06_mBR4_ch04", "gBR_sBM_cAll_d04_mBR3_ch07", "gBR_sBM_cAll_d06_mBR3_ch05", "gBR_sFM_cAll_d05_mBR5_ch12", "gBR_sBM_cAll_d04_mBR2_ch04", 
"gBR_sBM_cAll_d06_mBR3_ch10", "gBR_sFM_cAll_d05_mBR4_ch11", "gBR_sFM_cAll_d06_mBR1_ch15", "gBR_sBM_cAll_d05_mBR1_ch05", "gBR_sBM_cAll_d04_mBR3_ch03", "gBR_sBM_cAll_d04_mBR2_ch06", "gBR_sBM_cAll_d04_mBR0_ch06", "gBR_sBM_cAll_d04_mBR0_ch08", "gBR_sBM_cAll_d06_mBR4_ch10", "gBR_sBM_cAll_d04_mBR0_ch04", "gBR_sBM_cAll_d06_mBR4_ch03", "gBR_sBM_cAll_d06_mBR4_ch01", "gBR_sBM_cAll_d06_mBR2_ch01", "gBR_sBM_cAll_d05_mBR5_ch06", "gBR_sBM_cAll_d05_mBR4_ch02", "gBR_sBM_cAll_d05_mBR5_ch09", "gBR_sBM_cAll_d04_mBR1_ch04", "gBR_sBM_cAll_d05_mBR5_ch10", "gBR_sBM_cAll_d05_mBR1_ch08", "gBR_sBM_cAll_d06_mBR2_ch10"], "gPO": ["gPO_sBM_cAll_d12_mPO3_ch04", "gPO_sBM_cAll_d10_mPO2_ch08", "gPO_sBM_cAll_d10_mPO0_ch08", "gPO_sBM_cAll_d10_mPO2_ch03", "gPO_sBM_cAll_d10_mPO3_ch06", "gPO_sFM_cAll_d11_mPO3_ch11", "gPO_sBM_cAll_d11_mPO1_ch08", "gPO_sBM_cAll_d10_mPO1_ch05", "gPO_sBM_cAll_d10_mPO1_ch06", "gPO_sBM_cAll_d10_mPO2_ch06", "gPO_sBM_cAll_d11_mPO5_ch04", "gPO_sFM_cAll_d12_mPO1_ch16", "gPO_sBM_cAll_d10_mPO3_ch02", "gPO_sBM_cAll_d11_mPO1_ch10", "gPO_sBM_cAll_d11_mPO5_ch06", "gPO_sBM_cAll_d10_mPO2_ch07", "gPO_sBM_cAll_d10_mPO3_ch03", "gPO_sBM_cAll_d12_mPO5_ch06", "gPO_sBM_cAll_d11_mPO5_ch05", "gPO_sBM_cAll_d11_mPO0_ch06", "gPO_sBM_cAll_d12_mPO3_ch02", "gPO_sBM_cAll_d10_mPO1_ch08", "gPO_sBM_cAll_d10_mPO1_ch09", "gPO_sBM_cAll_d10_mPO2_ch10", "gPO_sBM_cAll_d11_mPO0_ch07", "gPO_sBM_cAll_d12_mPO2_ch07", "gPO_sBM_cAll_d11_mPO4_ch02", "gPO_sBM_cAll_d12_mPO4_ch07", "gPO_sBM_cAll_d12_mPO4_ch04", "gPO_sBM_cAll_d10_mPO1_ch04", "gPO_sBM_cAll_d12_mPO3_ch03", "gPO_sBM_cAll_d11_mPO0_ch03", "gPO_sFM_cAll_d12_mPO3_ch18", "gPO_sFM_cAll_d10_mPO2_ch03", "gPO_sBM_cAll_d10_mPO0_ch09", "gPO_sBM_cAll_d10_mPO0_ch04", "gPO_sBM_cAll_d12_mPO5_ch10", "gPO_sBM_cAll_d12_mPO2_ch06", "gPO_sBM_cAll_d11_mPO4_ch05", "gPO_sBM_cAll_d10_mPO1_ch03", "gPO_sBM_cAll_d12_mPO5_ch02", "gPO_sFM_cAll_d10_mPO4_ch05"], "gLO": ["gLO_sBM_cAll_d14_mLO5_ch02", "gLO_sBM_cAll_d13_mLO0_ch08", "gLO_sFM_cAll_d13_mLO5_ch06", "gLO_sBM_cAll_d13_mLO2_ch06", "gLO_sBM_cAll_d13_mLO3_ch06", "gLO_sBM_cAll_d13_mLO1_ch08", "gLO_sBM_cAll_d15_mLO2_ch05", "gLO_sBM_cAll_d15_mLO4_ch03", "gLO_sBM_cAll_d14_mLO0_ch04", "gLO_sBM_cAll_d15_mLO2_ch03", "gLO_sBM_cAll_d15_mLO4_ch02", "gLO_sBM_cAll_d13_mLO0_ch03", "gLO_sBM_cAll_d14_mLO4_ch01", "gLO_sBM_cAll_d14_mLO1_ch08", "gLO_sFM_cAll_d15_mLO0_ch15", "gLO_sBM_cAll_d13_mLO1_ch01", "gLO_sBM_cAll_d13_mLO0_ch06", "gLO_sBM_cAll_d15_mLO2_ch07", "gLO_sFM_cAll_d15_mLO4_ch21", "gLO_sBM_cAll_d15_mLO3_ch03", "gLO_sBM_cAll_d15_mLO4_ch06", "gLO_sBM_cAll_d15_mLO2_ch09", "gLO_sBM_cAll_d13_mLO1_ch09", "gLO_sBM_cAll_d15_mLO4_ch01", "gLO_sBM_cAll_d15_mLO2_ch10", "gLO_sBM_cAll_d15_mLO4_ch08", "gLO_sBM_cAll_d15_mLO4_ch09", "gLO_sBM_cAll_d15_mLO5_ch03", "gLO_sBM_cAll_d14_mLO4_ch08", "gLO_sBM_cAll_d14_mLO5_ch05", "gLO_sBM_cAll_d14_mLO0_ch09", "gLO_sBM_cAll_d15_mLO3_ch07", "gLO_sBM_cAll_d13_mLO2_ch01", "gLO_sBM_cAll_d14_mLO4_ch09", "gLO_sBM_cAll_d15_mLO5_ch08", "gLO_sBM_cAll_d13_mLO1_ch03", "gLO_sBM_cAll_d14_mLO4_ch06", "gLO_sBM_cAll_d15_mLO5_ch01", "gLO_sBM_cAll_d13_mLO0_ch09", "gLO_sBM_cAll_d14_mLO5_ch08", "gLO_sBM_cAll_d14_mLO4_ch10", "gLO_sBM_cAll_d13_mLO2_ch08"], "gMH": ["gMH_sBM_cAll_d23_mMH5_ch09", "gMH_sBM_cAll_d24_mMH5_ch07", "gMH_sBM_cAll_d23_mMH5_ch06", "gMH_sBM_cAll_d23_mMH1_ch05", "gMH_sBM_cAll_d24_mMH3_ch06", "gMH_sBM_cAll_d24_mMH5_ch10", "gMH_sBM_cAll_d24_mMH5_ch04", "gMH_sBM_cAll_d22_mMH0_ch03", "gMH_sBM_cAll_d24_mMH2_ch07", "gMH_sBM_cAll_d24_mMH5_ch01", "gMH_sBM_cAll_d23_mMH0_ch01", "gMH_sBM_cAll_d22_mMH2_ch02", "gMH_sBM_cAll_d24_mMH3_ch04", 
"gMH_sBM_cAll_d24_mMH5_ch08", "gMH_sBM_cAll_d22_mMH1_ch04", "gMH_sFM_cAll_d22_mMH1_ch02", "gMH_sBM_cAll_d24_mMH2_ch10", "gMH_sBM_cAll_d22_mMH3_ch01", "gMH_sBM_cAll_d23_mMH4_ch03", "gMH_sFM_cAll_d23_mMH3_ch11", "gMH_sBM_cAll_d24_mMH4_ch05", "gMH_sBM_cAll_d23_mMH0_ch03", "gMH_sBM_cAll_d23_mMH4_ch05", "gMH_sBM_cAll_d24_mMH2_ch09", "gMH_sBM_cAll_d24_mMH4_ch10", "gMH_sFM_cAll_d24_mMH5_ch20", "gMH_sBM_cAll_d23_mMH5_ch10", "gMH_sBM_cAll_d22_mMH1_ch07", "gMH_sBM_cAll_d23_mMH5_ch01", "gMH_sBM_cAll_d22_mMH2_ch03", "gMH_sBM_cAll_d23_mMH1_ch04", "gMH_sFM_cAll_d22_mMH0_ch01", "gMH_sBM_cAll_d24_mMH5_ch02", "gMH_sBM_cAll_d22_mMH3_ch10", "gMH_sFM_cAll_d22_mMH3_ch04", "gMH_sBM_cAll_d23_mMH5_ch05", "gMH_sBM_cAll_d24_mMH4_ch07", "gMH_sBM_cAll_d23_mMH5_ch07", "gMH_sBM_cAll_d23_mMH1_ch10", "gMH_sBM_cAll_d22_mMH1_ch01", "gMH_sBM_cAll_d23_mMH0_ch02", "gMH_sBM_cAll_d23_mMH0_ch10"], "gLH": ["gLH_sBM_cAll_d16_mLH3_ch01", "gLH_sBM_cAll_d17_mLH1_ch03", "gLH_sBM_cAll_d18_mLH3_ch10", "gLH_sBM_cAll_d16_mLH3_ch06", "gLH_sBM_cAll_d16_mLH0_ch10", "gLH_sBM_cAll_d16_mLH1_ch01", "gLH_sBM_cAll_d17_mLH0_ch01", "gLH_sFM_cAll_d17_mLH3_ch11", "gLH_sBM_cAll_d18_mLH5_ch04", "gLH_sFM_cAll_d16_mLH3_ch04", "gLH_sBM_cAll_d17_mLH5_ch02", "gLH_sBM_cAll_d17_mLH0_ch09", "gLH_sBM_cAll_d16_mLH1_ch06", "gLH_sBM_cAll_d16_mLH1_ch02", "gLH_sBM_cAll_d16_mLH0_ch08", "gLH_sBM_cAll_d18_mLH2_ch03", "gLH_sBM_cAll_d18_mLH2_ch01", "gLH_sBM_cAll_d16_mLH2_ch10", "gLH_sBM_cAll_d17_mLH1_ch09", "gLH_sBM_cAll_d17_mLH0_ch07", "gLH_sFM_cAll_d17_mLH4_ch12", "gLH_sBM_cAll_d18_mLH2_ch09", "gLH_sBM_cAll_d18_mLH5_ch05", "gLH_sFM_cAll_d16_mLH4_ch05", "gLH_sBM_cAll_d18_mLH5_ch09", "gLH_sFM_cAll_d16_mLH3_ch07", "gLH_sBM_cAll_d17_mLH4_ch09", "gLH_sFM_cAll_d17_mLH1_ch09", "gLH_sBM_cAll_d18_mLH2_ch10", "gLH_sFM_cAll_d18_mLH0_ch15", "gLH_sBM_cAll_d16_mLH3_ch02", "gLH_sFM_cAll_d16_mLH2_ch03", "gLH_sBM_cAll_d18_mLH3_ch05", "gLH_sBM_cAll_d16_mLH0_ch06", "gLH_sFM_cAll_d16_mLH5_ch06", "gLH_sFM_cAll_d17_mLH0_ch14", "gLH_sBM_cAll_d16_mLH2_ch04", "gLH_sBM_cAll_d17_mLH0_ch04", "gLH_sBM_cAll_d17_mLH5_ch03", "gLH_sBM_cAll_d16_mLH1_ch09", "gLH_sFM_cAll_d17_mLH2_ch10", "gLH_sBM_cAll_d17_mLH1_ch10"], "gHO": ["gHO_sBM_cAll_d21_mHO3_ch02", "gHO_sBM_cAll_d20_mHO4_ch05", "gHO_sBM_cAll_d21_mHO2_ch09", "gHO_sBM_cAll_d20_mHO5_ch05", "gHO_sBM_cAll_d19_mHO0_ch05", "gHO_sBM_cAll_d21_mHO4_ch01", "gHO_sBM_cAll_d21_mHO5_ch09", "gHO_sBM_cAll_d21_mHO3_ch07", "gHO_sBM_cAll_d19_mHO0_ch01", "gHO_sBM_cAll_d20_mHO5_ch08", "gHO_sBM_cAll_d20_mHO4_ch07", "gHO_sBM_cAll_d20_mHO0_ch06", "gHO_sBM_cAll_d20_mHO1_ch05", "gHO_sBM_cAll_d19_mHO3_ch05", "gHO_sBM_cAll_d21_mHO3_ch10", "gHO_sBM_cAll_d21_mHO4_ch07", "gHO_sBM_cAll_d20_mHO1_ch04", "gHO_sBM_cAll_d20_mHO0_ch02", "gHO_sBM_cAll_d19_mHO2_ch01", "gHO_sFM_cAll_d20_mHO3_ch11", "gHO_sBM_cAll_d20_mHO0_ch05", "gHO_sBM_cAll_d19_mHO2_ch09", "gHO_sBM_cAll_d20_mHO5_ch04", "gHO_sBM_cAll_d20_mHO1_ch08", "gHO_sBM_cAll_d19_mHO0_ch04", "gHO_sBM_cAll_d19_mHO1_ch03", "gHO_sBM_cAll_d21_mHO3_ch08", "gHO_sBM_cAll_d20_mHO0_ch01", "gHO_sFM_cAll_d19_mHO2_ch03", "gHO_sBM_cAll_d21_mHO2_ch06", "gHO_sBM_cAll_d20_mHO1_ch03", "gHO_sBM_cAll_d21_mHO2_ch01", "gHO_sFM_cAll_d20_mHO2_ch10", "gHO_sBM_cAll_d19_mHO0_ch09", "gHO_sBM_cAll_d19_mHO1_ch10", "gHO_sBM_cAll_d19_mHO1_ch08", "gHO_sBM_cAll_d20_mHO5_ch09", "gHO_sFM_cAll_d20_mHO0_ch08", "gHO_sBM_cAll_d20_mHO4_ch02", "gHO_sBM_cAll_d21_mHO2_ch10", "gHO_sBM_cAll_d21_mHO5_ch03", "gHO_sBM_cAll_d21_mHO4_ch04"], "gWA": ["gWA_sBM_cAll_d27_mWA4_ch01", "gWA_sBM_cAll_d26_mWA4_ch08", "gWA_sBM_cAll_d26_mWA1_ch04", "gWA_sFM_cAll_d27_mWA1_ch16", 
"gWA_sBM_cAll_d26_mWA4_ch06", "gWA_sBM_cAll_d27_mWA2_ch10", "gWA_sBM_cAll_d25_mWA1_ch10", "gWA_sFM_cAll_d26_mWA1_ch09", "gWA_sBM_cAll_d27_mWA2_ch09", "gWA_sBM_cAll_d27_mWA5_ch03", "gWA_sBM_cAll_d26_mWA4_ch01", "gWA_sBM_cAll_d26_mWA0_ch02", "gWA_sBM_cAll_d25_mWA3_ch10", "gWA_sFM_cAll_d25_mWA4_ch05", "gWA_sBM_cAll_d26_mWA4_ch07", "gWA_sBM_cAll_d26_mWA4_ch09", "gWA_sBM_cAll_d25_mWA2_ch02", "gWA_sBM_cAll_d27_mWA5_ch10", "gWA_sFM_cAll_d25_mWA0_ch01", "gWA_sBM_cAll_d27_mWA3_ch02", "gWA_sBM_cAll_d27_mWA3_ch06", "gWA_sBM_cAll_d27_mWA5_ch04", "gWA_sBM_cAll_d25_mWA0_ch09", "gWA_sBM_cAll_d25_mWA2_ch04", "gWA_sBM_cAll_d25_mWA0_ch02", "gWA_sBM_cAll_d27_mWA3_ch07", "gWA_sFM_cAll_d27_mWA2_ch17", "gWA_sBM_cAll_d25_mWA3_ch08", "gWA_sBM_cAll_d26_mWA1_ch08", "gWA_sBM_cAll_d25_mWA0_ch04", "gWA_sBM_cAll_d27_mWA5_ch02", "gWA_sFM_cAll_d26_mWA0_ch08", "gWA_sBM_cAll_d25_mWA1_ch02", "gWA_sFM_cAll_d25_mWA1_ch02", "gWA_sBM_cAll_d27_mWA5_ch07", "gWA_sBM_cAll_d25_mWA3_ch02", "gWA_sBM_cAll_d26_mWA5_ch01", "gWA_sBM_cAll_d26_mWA5_ch09", "gWA_sBM_cAll_d26_mWA4_ch03", "gWA_sBM_cAll_d26_mWA0_ch09", "gWA_sBM_cAll_d27_mWA5_ch06", "gWA_sBM_cAll_d26_mWA5_ch10"], "gKR": ["gKR_sBM_cAll_d30_mKR4_ch04", "gKR_sFM_cAll_d30_mKR5_ch20", "gKR_sBM_cAll_d28_mKR3_ch05", "gKR_sBM_cAll_d30_mKR3_ch08", "gKR_sBM_cAll_d28_mKR1_ch01", "gKR_sBM_cAll_d30_mKR4_ch03", "gKR_sBM_cAll_d28_mKR2_ch10", "gKR_sFM_cAll_d29_mKR3_ch11", "gKR_sBM_cAll_d30_mKR5_ch10", "gKR_sBM_cAll_d28_mKR2_ch03", "gKR_sBM_cAll_d28_mKR0_ch03", "gKR_sBM_cAll_d29_mKR4_ch05", "gKR_sBM_cAll_d29_mKR1_ch07", "gKR_sBM_cAll_d30_mKR2_ch07", "gKR_sBM_cAll_d28_mKR0_ch07", "gKR_sBM_cAll_d28_mKR2_ch07", "gKR_sBM_cAll_d30_mKR2_ch02", "gKR_sBM_cAll_d28_mKR2_ch05", "gKR_sBM_cAll_d28_mKR2_ch02", "gKR_sBM_cAll_d30_mKR2_ch01", "gKR_sBM_cAll_d30_mKR3_ch07", "gKR_sBM_cAll_d28_mKR1_ch10", "gKR_sBM_cAll_d28_mKR1_ch09", "gKR_sBM_cAll_d28_mKR3_ch06", "gKR_sBM_cAll_d29_mKR5_ch10", "gKR_sBM_cAll_d28_mKR0_ch10", "gKR_sBM_cAll_d28_mKR3_ch08", "gKR_sBM_cAll_d29_mKR4_ch08", "gKR_sBM_cAll_d30_mKR4_ch06", "gKR_sBM_cAll_d29_mKR5_ch01", "gKR_sBM_cAll_d28_mKR3_ch03", "gKR_sBM_cAll_d29_mKR4_ch03", "gKR_sBM_cAll_d28_mKR2_ch01", "gKR_sBM_cAll_d29_mKR4_ch06", "gKR_sBM_cAll_d28_mKR2_ch04", "gKR_sBM_cAll_d30_mKR4_ch09", "gKR_sBM_cAll_d30_mKR2_ch08", "gKR_sBM_cAll_d28_mKR0_ch02", "gKR_sBM_cAll_d28_mKR1_ch05", "gKR_sFM_cAll_d28_mKR3_ch04", "gKR_sBM_cAll_d28_mKR3_ch02", "gKR_sFM_cAll_d28_mKR1_ch02"], "gJS": ["gJS_sBM_cAll_d03_mJS2_ch04", "gJS_sFM_cAll_d01_mJS1_ch02", "gJS_sFM_cAll_d01_mJS1_ch07", "gJS_sBM_cAll_d02_mJS0_ch08", "gJS_sBM_cAll_d02_mJS0_ch03", "gJS_sBM_cAll_d01_mJS2_ch07", "gJS_sBM_cAll_d03_mJS3_ch03", "gJS_sBM_cAll_d01_mJS2_ch04", "gJS_sFM_cAll_d02_mJS2_ch03", "gJS_sBM_cAll_d01_mJS2_ch06", "gJS_sBM_cAll_d03_mJS3_ch01", "gJS_sBM_cAll_d02_mJS1_ch10", "gJS_sBM_cAll_d02_mJS0_ch07", "gJS_sBM_cAll_d02_mJS1_ch02", "gJS_sBM_cAll_d02_mJS1_ch03", "gJS_sBM_cAll_d03_mJS4_ch02", "gJS_sBM_cAll_d02_mJS0_ch05", "gJS_sFM_cAll_d01_mJS3_ch04", "gJS_sFM_cAll_d03_mJS2_ch03", "gJS_sFM_cAll_d01_mJS2_ch03", "gJS_sBM_cAll_d03_mJS2_ch06", "gJS_sBM_cAll_d03_mJS4_ch08", "gJS_sBM_cAll_d03_mJS4_ch09", "gJS_sBM_cAll_d03_mJS4_ch01", "gJS_sBM_cAll_d02_mJS5_ch02", "gJS_sBM_cAll_d02_mJS0_ch01", "gJS_sBM_cAll_d01_mJS3_ch02", "gJS_sBM_cAll_d02_mJS4_ch08", "gJS_sFM_cAll_d01_mJS5_ch06", "gJS_sBM_cAll_d01_mJS2_ch09", "gJS_sBM_cAll_d02_mJS0_ch06", "gJS_sFM_cAll_d02_mJS5_ch10", "gJS_sBM_cAll_d02_mJS5_ch09", "gJS_sBM_cAll_d02_mJS5_ch04", "gJS_sBM_cAll_d02_mJS1_ch07", "gJS_sBM_cAll_d03_mJS5_ch02", "gJS_sBM_cAll_d01_mJS1_ch06", 
"gJS_sBM_cAll_d03_mJS4_ch07", "gJS_sFM_cAll_d03_mJS1_ch02", "gJS_sBM_cAll_d01_mJS1_ch03", "gJS_sBM_cAll_d03_mJS5_ch07", "gJS_sBM_cAll_d03_mJS2_ch10"], "gJB": ["gJB_sBM_cAll_d07_mJB0_ch01", "gJB_sBM_cAll_d07_mJB0_ch10", "gJB_sFM_cAll_d09_mJB4_ch19", "gJB_sBM_cAll_d07_mJB3_ch09", "gJB_sBM_cAll_d09_mJB3_ch04", "gJB_sFM_cAll_d08_mJB5_ch13", "gJB_sBM_cAll_d09_mJB4_ch10", "gJB_sBM_cAll_d09_mJB2_ch02", "gJB_sBM_cAll_d07_mJB1_ch06", "gJB_sBM_cAll_d08_mJB0_ch06", "gJB_sBM_cAll_d08_mJB0_ch03", "gJB_sFM_cAll_d08_mJB2_ch10", "gJB_sBM_cAll_d08_mJB4_ch01", "gJB_sFM_cAll_d07_mJB0_ch01", "gJB_sBM_cAll_d08_mJB4_ch03", "gJB_sBM_cAll_d08_mJB4_ch06", "gJB_sBM_cAll_d09_mJB3_ch05", "gJB_sFM_cAll_d08_mJB5_ch14", "gJB_sBM_cAll_d07_mJB2_ch03", "gJB_sBM_cAll_d08_mJB4_ch08", "gJB_sBM_cAll_d07_mJB0_ch07", "gJB_sBM_cAll_d09_mJB3_ch06", "gJB_sBM_cAll_d07_mJB0_ch09", "gJB_sFM_cAll_d07_mJB1_ch02", "gJB_sFM_cAll_d08_mJB0_ch08", "gJB_sBM_cAll_d09_mJB5_ch08", "gJB_sBM_cAll_d07_mJB2_ch07", "gJB_sBM_cAll_d07_mJB2_ch06", "gJB_sBM_cAll_d07_mJB1_ch03", "gJB_sBM_cAll_d08_mJB0_ch07", "gJB_sFM_cAll_d07_mJB4_ch05", "gJB_sBM_cAll_d08_mJB4_ch09", "gJB_sBM_cAll_d08_mJB4_ch10", "gJB_sBM_cAll_d08_mJB1_ch04", "gJB_sBM_cAll_d09_mJB3_ch03", "gJB_sFM_cAll_d09_mJB3_ch18", "gJB_sBM_cAll_d09_mJB4_ch01", "gJB_sBM_cAll_d07_mJB1_ch10", "gJB_sBM_cAll_d08_mJB4_ch05", "gJB_sBM_cAll_d08_mJB5_ch10", "gJB_sBM_cAll_d08_mJB0_ch04", "gJB_sBM_cAll_d08_mJB0_ch05"]} -------------------------------------------------------------------------------- /src/data/dataset/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google AI Perception Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Utils for AIST++ Dataset.""" 16 | import os 17 | import json 18 | import ffmpeg 19 | import numpy as np 20 | import contextlib 21 | from PIL import Image, ImageDraw, ImageFont 22 | 23 | import aniposelib 24 | 25 | from src.data.dataset.visualizer import plot_kpt, plot_kpt_plain 26 | from src.data.dataset.cluster_misc import lexicon, get_names, genre_list, vidn_parse 27 | 28 | font = ImageFont.truetype("alata.ttf", 12) 29 | font_large = ImageFont.truetype("alata.ttf", 24) 30 | CAP_COL = (16, 16, 16) 31 | 32 | parse_keys = ["genre", "situ", "dancer", "tempo", "choreo", "name"] 33 | def vidn_parse(s): 34 | res = {} 35 | if s.endswith("pkl"): 36 | s = s[:-4] 37 | for seg in s.split("_"): 38 | if seg.startswith("g"): 39 | res["genre"] = seg 40 | elif seg.startswith("s"): 41 | res["situ"] = seg 42 | elif seg.startswith("d"): 43 | res["dancer"] = seg 44 | elif seg.startswith("m"): 45 | res["tempo"] = 10 * (int(seg[3:]) + 8) 46 | elif seg.startswith("ch"): 47 | res["choreo"] = res["genre"][1:] + seg[2:] + res["situ"][1:] 48 | else: 49 | pass 50 | res["name"] = s # currently does not support camera variation 51 | return res 52 | 53 | def ffmpeg_video_read(video_path, fps=None): 54 | """Video reader based on FFMPEG. 
55 | 56 | This function supports setting fps for video reading. It is critical 57 | as AIST++ Dataset are constructed under exact 60 fps, while some of 58 | the AIST dance videos are not percisely 60 fps. 59 | 60 | Args: 61 | video_path: A video file. 62 | fps: Use specific fps for video reading. (optional) 63 | Returns: 64 | A `np.array` with the shape of [seq_len, height, width, 3] 65 | """ 66 | assert os.path.exists(video_path), f'{video_path} does not exist!' 67 | try: 68 | probe = ffmpeg.probe(video_path) 69 | except ffmpeg.Error as e: 70 | print('stdout:', e.stdout.decode('utf8')) 71 | print('stderr:', e.stderr.decode('utf8')) 72 | raise e 73 | video_info = next(stream for stream in probe['streams'] 74 | if stream['codec_type'] == 'video') 75 | width = int(video_info['width']) 76 | height = int(video_info['height']) 77 | stream = ffmpeg.input(video_path) 78 | if fps: 79 | stream = ffmpeg.filter(stream, 'fps', fps=fps, round='up') 80 | stream = ffmpeg.output(stream, 'pipe:', format='rawvideo', pix_fmt='rgb24') 81 | out, _ = ffmpeg.run(stream, capture_stdout=True) 82 | out = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3]) 83 | return out.copy() 84 | 85 | 86 | def ffmpeg_video_write(data, video_path, fps=25): 87 | """Video writer based on FFMPEG. 88 | 89 | Args: 90 | data: A `np.array` with the shape of [seq_len, height, width, 3] 91 | video_path: A video file. 92 | fps: Use specific fps for video writing. (optional) 93 | """ 94 | assert len(data.shape) == 4, f'input shape is not valid! Got {data.shape}!' 95 | _, height, width, _ = data.shape 96 | # import pdb; pdb.set_trace() 97 | os.makedirs(os.path.dirname(video_path), exist_ok=True) 98 | writer = ( 99 | ffmpeg 100 | .input('pipe:', framerate=fps, format='rawvideo', 101 | pix_fmt='rgb24', s='{}x{}'.format(width, height)) 102 | .output(video_path, pix_fmt='yuv420p') 103 | .overwrite_output() 104 | .run_async(pipe_stdin=True) 105 | ) 106 | for frame in data: 107 | writer.stdin.write(frame.astype(np.uint8).tobytes()) 108 | writer.stdin.close() 109 | 110 | 111 | def save_keypoints3d_as_video(keypoints3d, captions, data_root, video_path): 112 | assert len(captions) == keypoints3d.shape[0] 113 | file_path = os.path.join(data_root, "annotations/cameras/setting1.json") 114 | assert os.path.exists(file_path), f'File {file_path} does not exist!' 
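    # The camera parameters read from annotations/cameras/setting1.json are
    # wrapped into aniposelib Camera objects and a CameraGroup, which projects the
    # [T, 17, 3] keypoints into all nine AIST++ camera views; only view 0 is kept,
    # at half resolution, and each frame is drawn with plot_kpt plus its caption
    # before the clip is written out at 30 fps. Illustrative call (argument
    # values are made up):
    #   save_keypoints3d_as_video(kpts3d, [f"frame {t}" for t in range(len(kpts3d))],
    #                             "../aistplusplus", "logs/tan/saved_videos/demo.mp4")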
115 | with open(file_path, 'r') as f: 116 | params = json.load(f) 117 | cameras = [] 118 | for param_dict in params: 119 | camera = aniposelib.cameras.Camera(name=param_dict['name'], 120 | size=param_dict['size'], 121 | matrix=param_dict['matrix'], 122 | rvec=param_dict['rotation'], 123 | tvec=param_dict['translation'], 124 | dist=param_dict['distortions']) 125 | cameras.append(camera) 126 | cgroup = aniposelib.cameras.CameraGroup(cameras) 127 | length = keypoints3d.shape[0] 128 | keypoints2d = (cgroup.project(keypoints3d) // 2).reshape(9, length, 17, 2)[0] 129 | blank_video = np.ones((length, 1080 // 2, 1920 // 2, 3), dtype=np.uint8) * 255 130 | for iframe, (keypoint, cap) in enumerate(zip(keypoints2d, captions)): 131 | if iframe >= blank_video.shape[0]: 132 | break 133 | tmp = plot_kpt(keypoint, blank_video[iframe]) 134 | tmp = Image.fromarray(tmp, 'RGB') 135 | ImageDraw.Draw(tmp).text((25, 25), cap, CAP_COL, font=font) 136 | blank_video[iframe] = np.array(tmp) 137 | ffmpeg_video_write(blank_video, video_path, fps=30) # play it slowly 138 | 139 | def plot_cool(kpt, cap, data_root, c=None): 140 | # kpt: T, 17, 3 141 | file_path = os.path.join(data_root, "annotations/cameras/setting1.json") 142 | assert os.path.exists(file_path), f'File {file_path} does not exist!' 143 | with open(file_path, 'r') as f: 144 | params = json.load(f) 145 | cameras = [] 146 | for param_dict in params: 147 | camera = aniposelib.cameras.Camera(name=param_dict['name'], 148 | size=param_dict['size'], 149 | matrix=param_dict['matrix'], 150 | rvec=param_dict['rotation'], 151 | tvec=param_dict['translation'], 152 | dist=param_dict['distortions']) 153 | cameras.append(camera) 154 | cgroup = aniposelib.cameras.CameraGroup(cameras) 155 | length = kpt.shape[0] 156 | keypoints2d = (cgroup.project(kpt) // 2).reshape(9, length, 17, 2)[0] 157 | blank_video = np.ones((1080 // 2, 1920 // 2, 3), dtype=np.uint8) * 255 158 | for iframe, keypoint in enumerate(keypoints2d): 159 | if c is not None: 160 | blank_video = plot_kpt_plain(keypoint, blank_video, c) 161 | else: 162 | blank_video = plot_kpt_plain(keypoint, blank_video, [0, 0, 0]) 163 | blank_video = (255 - (255 - blank_video) * 0.8).astype(np.uint8) 164 | tmp = Image.fromarray(blank_video, 'RGB') 165 | ImageDraw.Draw(tmp).text((25, 25), cap, CAP_COL, font=font) 166 | return np.array(tmp) 167 | 168 | def save_paired_keypoints3d_as_video(keypoints3d_raw, keypoints3d_gen, cap1, cap2, data_root, video_path, align=False): 169 | # align option will align the generated video to 170 | assert len(cap1) == keypoints3d_raw.shape[0] == len(cap2) == keypoints3d_gen.shape[0] 171 | 172 | file_path = os.path.join(data_root, "annotations/cameras/setting1.json") 173 | assert os.path.exists(file_path), f'File {file_path} does not exist!' 
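    # Same camera setup as in save_keypoints3d_as_video above: both the raw and
    # the generated sequence are projected through view 0 at half resolution,
    # given "Raw video:" / "Gnd video:" captions (frames captioned "no matched"
    # become black placeholders), and the two renders are concatenated along the
    # height axis before being written at 30 fps. The align flag is accepted but
    # does not appear to be used in the body shown here.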
174 | with open(file_path, 'r') as f: 175 | params = json.load(f) 176 | cameras = [] 177 | for param_dict in params: 178 | camera = aniposelib.cameras.Camera(name=param_dict['name'], 179 | size=param_dict['size'], 180 | matrix=param_dict['matrix'], 181 | rvec=param_dict['rotation'], 182 | tvec=param_dict['translation'], 183 | dist=param_dict['distortions']) 184 | cameras.append(camera) 185 | cgroup = aniposelib.cameras.CameraGroup(cameras) 186 | length = keypoints3d_raw.shape[0] 187 | keypoints2d_raw = cgroup.project(keypoints3d_raw).reshape(9, length, 17, 2)[0] // 2 # there are nine cameras there 188 | keypoints2d_gen = cgroup.project(keypoints3d_gen).reshape(9, length, 17, 2)[0] // 2 189 | 190 | blank_video_raw = np.ones((length, 1080 // 2, 1920 // 2, 3), dtype=np.uint8) * 255 191 | for iframe, (keypoint, cap) in enumerate(zip(keypoints2d_raw, cap1)): 192 | if iframe >= blank_video_raw.shape[0]: 193 | break 194 | tmp = plot_kpt(keypoint, blank_video_raw[iframe]) 195 | # import pdb; pdb.set_trace() 196 | tmp = Image.fromarray(tmp, 'RGB') 197 | ImageDraw.Draw(tmp).text((25, 25), "Raw video: "+cap, CAP_COL, font=font) 198 | blank_video_raw[iframe] = np.array(tmp) 199 | 200 | blank_video_gen = np.ones((length, 1080 // 2, 1920 // 2, 3), dtype=np.uint8) * 255 201 | for iframe, (keypoint, cap) in enumerate(zip(keypoints2d_gen, cap2)): 202 | if iframe >= blank_video_gen.shape[0]: 203 | break 204 | if cap == "no matched": 205 | tmp = np.ones((1080 // 2, 1920 // 2, 3), dtype=np.uint8) * 0 206 | tmp = Image.fromarray(tmp, 'RGB') 207 | ImageDraw.Draw(tmp).text((25, 25), "Gnd video: "+cap, CAP_COL, font=font) 208 | blank_video_gen[iframe] = np.array(tmp) 209 | else: 210 | tmp = plot_kpt(keypoint, blank_video_gen[iframe]) 211 | tmp = Image.fromarray(tmp, 'RGB') 212 | ImageDraw.Draw(tmp).text((25, 25), "Gnd video: "+cap, CAP_COL, font=font) 213 | blank_video_gen[iframe] = np.array(tmp) 214 | 215 | blank_video = np.concatenate([blank_video_raw, blank_video_gen], axis=1) 216 | ffmpeg_video_write(blank_video, video_path, fps=30) 217 | 218 | 219 | def save_centroids_as_video(skes, caps, data_root, video_path): 220 | # skes: dict of word: list of skeleton motions [T, 17 * 3] 221 | # caps: dict of word: list of strings 222 | file_path = os.path.join(data_root, "annotations/cameras/setting1.json") 223 | assert os.path.exists(file_path), f'File {file_path} does not exist!' 
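    # For every word of the lexicon this routine takes the matching snippets in
    # chunks of nine (any leftover chunk smaller than nine is skipped, as the
    # TODO below notes), projects each snippet through camera view 0 at half
    # resolution, downsamples the rendered frames by 3 in both spatial axes and
    # tiles them into a 3x3 grid per chunk; the grids are then concatenated along
    # time and written out at 15 fps.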
224 | with open(file_path, 'r') as f: 225 | params = json.load(f) 226 | cameras = [] 227 | for param_dict in params: 228 | camera = aniposelib.cameras.Camera(name=param_dict['name'], 229 | size=param_dict['size'], 230 | matrix=param_dict['matrix'], 231 | rvec=param_dict['rotation'], 232 | tvec=param_dict['translation'], 233 | dist=param_dict['distortions']) 234 | cameras.append(camera) 235 | cgroup = aniposelib.cameras.CameraGroup(cameras) 236 | 237 | blank_video_container = [] 238 | for word in lexicon: 239 | if word in skes: 240 | for chunk_idx in range(len(skes[word])//9): # TODO: here we sometime discard some 241 | near_points = [] 242 | len_contatiner = [] 243 | for _ in skes[word][chunk_idx*9:(chunk_idx+1)*9]: 244 | length = _.shape[0] 245 | len_contatiner.append(length) 246 | near_points.append(cgroup.project(_).reshape(9, length, 17, 2)[0] // 2) 247 | length = max(len_contatiner) 248 | blank_video_raw = np.ones((length, 1080 // 2, 1920 // 2, 3), dtype=np.uint8) * 255 249 | plotted = [] 250 | for idx, point in enumerate(near_points): 251 | blank_video_gen = np.ones((length, 1080 // 2, 1920 // 2, 3), dtype=np.uint8) * 255 252 | for iframe, keypoint in enumerate(point): 253 | if iframe >= blank_video_gen.shape[0]: 254 | break 255 | tmp = plot_kpt(keypoint, blank_video_gen[iframe]) 256 | tmp = Image.fromarray(tmp, 'RGB') 257 | ImageDraw.Draw(tmp).text((25, 25), f"Near-by ground-truth point {idx}", CAP_COL, font=font_large) 258 | blank_video_gen[iframe] = np.array(tmp) 259 | plotted.append(blank_video_gen[:, ::3, ::3, :]) 260 | 261 | blank_video_gen = np.concatenate([ 262 | np.concatenate([plotted[0], plotted[1], plotted[2]], axis=2), 263 | np.concatenate([plotted[3], plotted[4], plotted[5]], axis=2), 264 | np.concatenate([plotted[6], plotted[7], plotted[8]], axis=2), 265 | ], axis=1) 266 | blank_video_container.append(blank_video_gen) 267 | blank_video = np.concatenate(blank_video_container, axis=0) 268 | ffmpeg_video_write(blank_video, video_path, fps=15) 269 | 270 | def spatial_align(input_video, reconstructed_video): 271 | # [num_snippets, LENGTH*17*3] 272 | num_snippets = input_video.shape[0] 273 | assert reconstructed_video.shape[0] == num_snippets 274 | ttl = input_video.shape[1] 275 | aligned = [] 276 | for i in range(num_snippets): 277 | x = input_video[i] 278 | body_centre_x_to = (x[33] + x[36]) / 2 279 | body_centre_y_to = (x[34] + x[37]) / 2 280 | body_centre_z_to = (x[35] + x[38]) / 2 281 | y = reconstructed_video[i] 282 | body_centre_x_fm = (y[33] + y[36]) / 2 283 | body_centre_y_fm = (y[34] + y[37]) / 2 284 | body_centre_z_fm = (y[35] + y[38]) / 2 285 | shift = np.tile(np.array([body_centre_x_to - body_centre_x_fm, 286 | body_centre_y_to - body_centre_y_fm, 287 | body_centre_z_to - body_centre_z_fm]), ttl // 3) 288 | aligned.append(y + shift) 289 | return np.vstack(aligned) 290 | 291 | def rigid_transform_3D(A, B): 292 | centroid_A = np.mean(A, axis = 0) 293 | centroid_B = np.mean(B, axis = 0) 294 | H = np.dot(np.transpose(A - centroid_A), B - centroid_B) 295 | U, s, V = np.linalg.svd(H) 296 | R = np.dot(np.transpose(V), np.transpose(U)) 297 | if np.linalg.det(R) < 0: 298 | V[2] = -V[2] 299 | R = np.dot(np.transpose(V), np.transpose(U)) 300 | t = -np.dot(R, np.transpose(centroid_A)) + np.transpose(centroid_B) 301 | return R, t 302 | 303 | def rigid_align(A, B): 304 | # both numpy array of [J, 3], align A to B 305 | R, t = rigid_transform_3D(A, B) 306 | A2 = np.transpose(np.dot(R, np.transpose(A))) + t 307 | return A2 308 | 309 | def rigid_align_sequence(A, B): 310 | # 
align A's first frame to B 311 | # A, [T1, J, 3] 312 | # B, [T2, J, 3] 313 | R, t = rigid_transform_3D(A[0], B[-1]) 314 | a_container = [] 315 | for a in A: 316 | a2 = np.transpose(np.dot(R, np.transpose(a))) + t 317 | a_container.append(a2) 318 | A2 = np.stack(a_container, axis=0) 319 | return A2 320 | -------------------------------------------------------------------------------- /plb/models/self_supervised/tan/tan_module.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math, bisect 3 | from argparse import ArgumentParser 4 | from typing import Callable, Optional 5 | 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | import torch 9 | from pytorch_lightning.core.optimizer import LightningOptimizer 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from torch.optim.optimizer import Optimizer 13 | 14 | from pl_bolts.optimizers.lars_scheduling import LARSWrapper 15 | from pl_bolts.transforms.dataset_normalizations import ( 16 | cifar10_normalization, 17 | imagenet_normalization, 18 | stl10_normalization, 19 | ) 20 | 21 | from plb.models.encoder import Transformer, Transformer_wote 22 | 23 | dump_time = False 24 | big_number = 2 ** 13 # a number >> T 25 | 26 | class LargeMarginInSoftmaxLoss(nn.CrossEntropyLoss): 27 | # from https://github.com/tk1980/LargeMarginInSoftmax/blob/master/models/modules/myloss.py 28 | def __init__(self, reg_lambda=0.3, deg_logit=None, 29 | weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'): 30 | super(LargeMarginInSoftmaxLoss, self).__init__(weight=weight, size_average=size_average, 31 | ignore_index=ignore_index, reduce=reduce, reduction=reduction) 32 | self.reg_lambda = reg_lambda 33 | self.deg_logit = deg_logit 34 | 35 | def forward(self, input, target): 36 | N = input.size(0) # number of samples 37 | C = input.size(1) # number of classes 38 | Mask = torch.zeros_like(input, requires_grad=False) 39 | Mask[range(N), target] = 1 40 | 41 | if self.deg_logit is not None: 42 | input = input - self.deg_logit * Mask 43 | 44 | loss = F.cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) 45 | 46 | X = input - 1.e6 * Mask # [N x C], excluding the target class 47 | reg = 0.5 * ((F.softmax(X, dim=1) - 1.0 / (C - 1)) * F.log_softmax(X, dim=1) * (1.0 - Mask)).sum(dim=1) 48 | if self.reduction == 'sum': 49 | reg = reg.sum() 50 | elif self.reduction == 'mean': 51 | reg = reg.mean() 52 | elif self.reduction == 'none': 53 | reg = reg 54 | 55 | return loss + self.reg_lambda * reg 56 | 57 | class SyncFunction(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, tensor): 60 | # gather sizes on different GPU's 61 | size = torch.tensor(tensor.size(0), device=tensor.device) 62 | gathered_size = [torch.zeros_like(size) for _ in range(torch.distributed.get_world_size())] 63 | torch.distributed.all_gather(gathered_size, size) 64 | ctx.sizes = [_.item() for _ in gathered_size] 65 | max_bs = max(ctx.sizes) 66 | 67 | gathered_tensor = [tensor.new_zeros((max_bs, ) + tensor.shape[1:]) for _ in range(torch.distributed.get_world_size())] 68 | tbg = torch.cat([tensor, tensor.new_zeros((max_bs-tensor.size(0), ) + tensor.shape[1:])], dim=0) 69 | torch.distributed.all_gather(gathered_tensor, tbg) 70 | gathered_tensor = torch.cat([_[:s] for (_, s) in zip(gathered_tensor, ctx.sizes)], 0) 71 | return gathered_tensor 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | grad_input = grad_output.clone() 
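        # Backward of the variable-size all_gather: every rank holds the gradient
        # w.r.t. the full gathered tensor, so the all_reduce below sums those
        # contributions, and this rank then returns only the slice of rows it
        # originally contributed, using the per-rank sizes saved in ctx.sizes
        # during forward (e.g. with sizes [3, 2] and rank 1, rows 3:5 are kept).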
76 | torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) 77 | my_rank = torch.distributed.get_rank() 78 | idx_from = sum(ctx.sizes[:my_rank]) 79 | idx_to = idx_from + ctx.sizes[my_rank] 80 | return grad_input[idx_from:idx_to] 81 | 82 | 83 | class Projection(nn.Module): 84 | 85 | def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128): 86 | super().__init__() 87 | self.output_dim = output_dim 88 | self.input_dim = input_dim 89 | self.hidden_dim = hidden_dim 90 | 91 | self.model = nn.Sequential( 92 | nn.Linear(self.input_dim, self.hidden_dim), nn.BatchNorm1d(self.hidden_dim), nn.ReLU(), 93 | nn.Linear(self.hidden_dim, self.output_dim, bias=False) 94 | ) 95 | 96 | def forward(self, x): 97 | x = self.model(x) 98 | return F.normalize(x, dim=1) 99 | 100 | 101 | class TAN(pl.LightningModule): 102 | def __init__( 103 | self, 104 | gpus: int, 105 | num_samples: int, 106 | batch_size: int, 107 | length: int, 108 | dataset: str, 109 | num_nodes: int = 1, 110 | arch: str = 'resnet50', 111 | hidden_mlp: int = 512, # 2048, this is revised 112 | feat_dim: int = 128, 113 | warmup_epochs: int = 10, 114 | max_epochs: int = 100, 115 | temperature: float = 0.1, 116 | first_conv: bool = True, 117 | maxpool1: bool = True, 118 | optimizer: str = 'adam', 119 | lars_wrapper: bool = True, 120 | exclude_bn_bias: bool = False, 121 | start_lr: float = 0., 122 | learning_rate: float = 1e-3, 123 | final_lr: float = 0., 124 | weight_decay: float = 1e-6, 125 | val_configs=None, 126 | log_dir=None, 127 | protection=0, 128 | tr_layer=6, 129 | tr_dim=512, 130 | neg_dp=0.0, 131 | j=51, 132 | **kwargs 133 | ): 134 | """ 135 | Args: 136 | batch_size: the batch size 137 | num_samples: num samples in the dataset 138 | warmup_epochs: epochs to warmup the lr for 139 | lr: the optimizer learning rate 140 | opt_weight_decay: the optimizer weight decay 141 | loss_temperature: the loss temperature 142 | """ 143 | super().__init__() 144 | self.save_hyperparameters() 145 | 146 | self.gpus = gpus 147 | self.num_nodes = num_nodes 148 | self.arch = arch 149 | self.dataset = dataset 150 | self.num_samples = num_samples 151 | self.batch_size = batch_size # batch size from the view of scheduler 152 | self.real_batch_size = batch_size * length # batch size from the view of optimizer 153 | 154 | self.hidden_mlp = hidden_mlp 155 | self.feat_dim = feat_dim 156 | self.first_conv = first_conv 157 | self.maxpool1 = maxpool1 158 | 159 | self.optim = optimizer 160 | self.lars_wrapper = lars_wrapper 161 | self.exclude_bn_bias = exclude_bn_bias 162 | self.weight_decay = weight_decay 163 | self.temperature = temperature 164 | 165 | self.start_lr = start_lr / 256 * self.real_batch_size 166 | self.final_lr = final_lr / 256 * self.real_batch_size 167 | self.learning_rate = learning_rate / 256 * self.real_batch_size 168 | self.warmup_epochs = warmup_epochs 169 | self.max_epochs = max_epochs 170 | self.log_dir = log_dir 171 | # select which level of protection is used 172 | self.loss_calculator = [self.nt_xent_loss, self.nt_xent_loss_halfprotect, self.nt_xent_loss_protection, self.nt_xent_loss_rectangle, self.large_margin_loss_protection][protection] 173 | self.is_rectangle = protection == 3 174 | if protection == 4: 175 | self.lml = LargeMarginInSoftmaxLoss(reg_lambda=0.3) 176 | 177 | self.encoder = self.init_model(tr_layer=tr_layer, tr_dim=tr_dim, j=j) 178 | 179 | self.projection = Projection(input_dim=tr_dim, hidden_dim=tr_dim, output_dim=self.feat_dim) 180 | # originally using hidden_mlp for input_dim and 
hidden_dim 181 | self.dropout = nn.Dropout(p=neg_dp) 182 | 183 | # compute iters per epoch 184 | global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size * torch.cuda.device_count() 185 | if global_batch_size != 0: 186 | self.train_iters_per_epoch = math.ceil(self.num_samples / global_batch_size) 187 | else: 188 | self.train_iters_per_epoch = 0 189 | 190 | # define LR schedule 191 | warmup_lr_schedule = np.linspace(self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs) 192 | iters = np.arange(self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)) 193 | cosine_lr_schedule = np.array([ 194 | self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * 195 | (1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)))) 196 | for t in iters 197 | ]) 198 | 199 | self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) 200 | 201 | # construct validator 202 | self.validators = [] 203 | if val_configs is not None and torch.cuda.device_count() != 0: 204 | for k, val in val_configs.items(): 205 | val["log_dir"] = self.log_dir 206 | val["rank"] = self.global_rank 207 | val["world_size"] = torch.cuda.device_count() 208 | self.validators.append(construct_validator(k, val)) 209 | 210 | def init_model(self, tr_layer, tr_dim, j): 211 | if self.arch == "Transformer": 212 | return Transformer(tr_layer, tr_dim, j) 213 | elif self.arch == "Tconv": 214 | # TODO: move to config 215 | config = get_default_tconv_net_config() 216 | config.tempconv_dim_in = 51 217 | config.tempconv_dim_out = 512 218 | config.tempconv_filter_widths = [5, ] * 5 219 | config.tempconv_channels = 1024 220 | return get_tconv_net(config) 221 | elif self.arch == "Transformer_wote": 222 | return Transformer_wote(tr_layer, tr_dim, j) 223 | else: 224 | assert 0, "Unknown model!" 
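    # Note on init_model above: only Transformer and Transformer_wote are
    # imported from plb.models.encoder at the top of this file, so the "Tconv"
    # branch assumes get_default_tconv_net_config / get_tconv_net are made
    # available elsewhere; selecting arch="Tconv" without them would raise a
    # NameError at runtime.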
225 | 226 | def forward(self, *args): 227 | x = self.encoder(*args) # [N, T, f] 228 | if self.arch == "Tconv": 229 | x = x.permute(2, 0, 1).contiguous() 230 | return x 231 | 232 | def shared_step(self, batch): 233 | # img1, img2: [B, maxT1, 51], [B, maxT2, 51], maxT1 >= l1b, maxT2 >= l2b, any b in B 234 | # len1, len2: [B] of ints, real lengths, l1B, l2B 235 | # velo1, velo2: [B, maxT1], [B, maxT2], corresponding indices to video before temporal augmentation 236 | # m: [t1, t2]: real number between 0 and 1, t1 = sum_B l1b, t1 = sum_B l2b, composed of diagonal rectangles 237 | # chopped_bs: the batch size after reducing length difference to squares 238 | img1, img2, len1, len2, m, indices1, indices2, chopped_bs = batch 239 | # len1 and len2 actually the same 240 | h1_ = self(img1, len1) # [maxT1, B, f=512] 241 | h2_ = self(img2, len2) 242 | h1_ = h1_.permute(1, 0, 2).contiguous() # [B, maxT1, f=512] 243 | h2_ = h2_.permute(1, 0, 2).contiguous() 244 | 245 | if self.is_rectangle: 246 | dev = len1.device 247 | bs, maxT1, f = h1_.shape 248 | _, maxT2, _ = h2_.shape 249 | # big_boy = torch.arange(bs).to(dev).unsqueeze(-1) * big_number 250 | 251 | indices1 = torch.cat([torch.arange(l1b).to(dev) + b * maxT1 for b, l1b in enumerate(len1)], dim=0) 252 | h1 = torch.gather(h1_.flatten(0, 1), 0, indices1.unsqueeze(-1).repeat(1, f)) # [t1, f] 253 | # v1 = torch.gather((velo1 + big_boy).flatten(), 0, indices1) # [t1] 254 | indices2 = torch.cat([torch.arange(l2b).to(dev) + b * maxT2 for b, l2b in enumerate(len2)], dim=0) 255 | h2 = torch.gather(h2_.flatten(0, 1), 0, indices2.unsqueeze(-1).repeat(1, f)) # [t2, f] 256 | # v2 = torch.gather((velo2 + big_boy).flatten(), 0, indices2) # [t2] 257 | z1 = self.projection(h1) 258 | z2 = self.projection(h2) 259 | 260 | loss = self.loss_calculator(z1, z2, m, self.temperature) 261 | else: 262 | h1 = h1_.flatten(0, 1)[indices1] 263 | h2 = h2_.flatten(0, 1)[indices2] 264 | z1 = self.projection(h1) 265 | z2 = self.projection(h2) 266 | loss = self.loss_calculator(z1, z2, chopped_bs, self.temperature) 267 | return loss 268 | 269 | def training_step(self, batch, batch_idx): 270 | loss = self.shared_step(batch) 271 | 272 | # log LR (LearningRateLogger callback doesn't work with LARSWrapper) 273 | self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False) 274 | 275 | self.log('train_loss', loss, on_step=True, on_epoch=False) 276 | return loss 277 | 278 | def validation_step(self, batch, batch_idx): 279 | loss = self.shared_step(batch) 280 | self.log('val_loss', loss, on_step=False, on_epoch=True, sync_dist=True) 281 | return loss 282 | 283 | def on_validation_epoch_end(self, ): 284 | device = self.device 285 | self.eval() 286 | # self.cpu() 287 | for validator in self.validators: 288 | if self.global_rank != 0: 289 | save = -1 290 | else: 291 | if self.current_epoch % 500 == 474: # note it has to be a subset of 5Z - 1, debug time use 0 292 | save = self.current_epoch 293 | else: 294 | save = -1 295 | 296 | metric_dict = validator(self, save=save) 297 | for name, metric in metric_dict.items(): 298 | metric = torch.tensor([metric], device=device) 299 | self.log(name, metric, on_step=False, on_epoch=True, sync_dist=True) 300 | # self.to(device) 301 | if dump_time: 302 | if self.current_epoch % 1 == 0: 303 | torch.save(self.encoder, os.path.join(self.log_dir, f"dumped_at epoch{self.current_epoch}.ckpt")) 304 | 305 | def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=['bias', 'bn']): 306 | params = [] 307 | 
excluded_params = [] 308 | 309 | for name, param in named_params: 310 | if not param.requires_grad: 311 | continue 312 | elif any(layer_name in name for layer_name in skip_list): 313 | excluded_params.append(param) 314 | else: 315 | params.append(param) 316 | 317 | return [{ 318 | 'params': params, 319 | 'weight_decay': weight_decay 320 | }, { 321 | 'params': excluded_params, 322 | 'weight_decay': 0., 323 | }] 324 | 325 | def configure_optimizers(self): 326 | if self.exclude_bn_bias: 327 | params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay) 328 | else: 329 | params = self.parameters() 330 | 331 | if self.optim == 'sgd': 332 | optimizer = torch.optim.SGD(params, lr=self.learning_rate, momentum=0.9, weight_decay=self.weight_decay) 333 | elif self.optim == 'adam': 334 | # optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) 335 | optimizer = torch.optim.AdamW(params, lr=self.learning_rate, weight_decay=self.weight_decay) 336 | 337 | if self.lars_wrapper: 338 | optimizer = LARSWrapper( 339 | optimizer, 340 | eta=0.001, # trust coefficient 341 | clip=False 342 | ) 343 | 344 | return optimizer 345 | 346 | def optimizer_step( 347 | self, 348 | epoch: int = None, 349 | batch_idx: int = None, 350 | optimizer: Optimizer = None, 351 | optimizer_idx: int = None, 352 | optimizer_closure: Optional[Callable] = None, 353 | on_tpu: bool = None, 354 | using_native_amp: bool = None, 355 | using_lbfgs: bool = None, 356 | ) -> None: 357 | # warm-up + decay schedule placed here since LARSWrapper is not an optimizer class 358 | # adjust the LR of the optimizer contained within LARSWrapper 359 | for param_group in optimizer.param_groups: 360 | param_group["lr"] = self.lr_schedule[self.trainer.global_step] # // torch.cuda.device_count()] 361 | 362 | # rank = torch.distributed.get_rank() 363 | # print(f"I am with rank {rank} and I am at global step {self.trainer.global_step}") 364 | 365 | # from lightning 366 | if not isinstance(optimizer, LightningOptimizer): 367 | # wrap into LightningOptimizer only for running the step 368 | optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer) 369 | optimizer.step(closure=optimizer_closure) 370 | 371 | def nt_xent_loss(self, out_1, out_2, len, temperature, eps=1e-6): 372 | """ 373 | assume out_1 and out_2 are normalized 374 | out_1: [batch_size, dim] 375 | out_2: [batch_size, dim] 376 | """ 377 | # gather representations in case of distributed training 378 | # out_1_dist: [batch_size * world_size, dim] 379 | # out_2_dist: [batch_size * world_size, dim] 380 | del len 381 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 382 | out_1_dist = SyncFunction.apply(out_1) 383 | out_2_dist = SyncFunction.apply(out_2) 384 | else: 385 | out_1_dist = out_1 386 | out_2_dist = out_2 387 | 388 | # out: [2 * batch_size, dim] 389 | # out_dist: [2 * batch_size * world_size, dim] 390 | out = torch.cat([out_1, out_2], dim=0) 391 | out_dist = torch.cat([out_1_dist, out_2_dist], dim=0) 392 | 393 | # Positive similarity, pos becomes [2 * batch_size] 394 | inner_sim = torch.exp(out_1 @ out_2.t().contiguous() / temperature) 395 | pos = torch.diagonal(inner_sim) 396 | pos = torch.cat([pos, pos], dim=0) 397 | 398 | # cov: [2 * batch_size, 2 * batch_size * world_size] 399 | # neg: [2 * batch_size] 400 | cov = torch.exp(out @ out_dist.t().contiguous() / temperature) 401 | neg = cov.sum(dim=-1) # length: \sum_i (l_1i + l_2i) 402 | 403 | # from each row, subtract e^(1/t) so that denominator has
only t1 + t2 - 1 classes 404 | row_sub = torch.ones_like(neg) * math.exp(1 / temperature) 405 | neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability 406 | 407 | loss = -torch.log(pos / (neg + eps)).mean() 408 | 409 | return loss 410 | 411 | def nt_xent_loss_protection(self, out_1, out_2, len, temperature, eps=1e-6): 412 | """ 413 | assume out_1 and out_2 are normalized 414 | out_1: [batch_size, dim] 415 | out_2: [batch_size, dim] 416 | """ 417 | # gather representations in case of distributed training 418 | # out_1_dist: [batch_size * world_size, dim] 419 | # out_2_dist: [batch_size * world_size, dim] 420 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 421 | out_1_dist = SyncFunction.apply(out_1) 422 | out_2_dist = SyncFunction.apply(out_2) 423 | else: 424 | out_1_dist = out_1 425 | out_2_dist = out_2 426 | 427 | # Bg, total frame number on a certain GPU, = \sum bi, i over number of videos per GPU 428 | out = torch.cat([out_1, out_2], dim=0) # [2Bg] 429 | out_dist = torch.cat([out_1_dist, out_2_dist], dim=0) # [2 \sum Bg] 430 | 431 | # Positive similarity, pos becomes [2 * batch_size] 432 | inner_sim = torch.exp(out_1 @ out_2.t().contiguous() / temperature) # [Bg, Bg] 433 | pos = torch.diagonal(inner_sim) 434 | pos = torch.cat([pos, pos], dim=0) # [2Bg] 435 | 436 | cov = torch.exp(out @ out_dist.t().contiguous() / temperature) # [2Bg, 2 \sum Bg] 437 | neg = cov.sum(dim=-1) # [2Bg] 438 | 439 | # from each row, subtract similarity to frame from the same video 440 | mask = torch.block_diag(*[pos.new_ones(_, _) for _ in len]) 441 | mask = torch.cat([mask, mask], dim=0) 442 | mask = torch.cat([mask, mask], dim=1) 443 | outer_sim = torch.exp(out @ out.t().contiguous() / temperature) # [2Bg, 2Bg] 444 | masked = outer_sim * mask 445 | row_sub = masked.sum(dim=0) 446 | neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability 447 | 448 | loss = -torch.log(pos / (neg + pos + eps)).mean() 449 | 450 | return loss 451 | 452 | def large_margin_loss_protection(self, out_1, out_2, len, temperature, eps=1e-6): 453 | """ 454 | assume out_1 and out_2 are normalized 455 | out_1: [batch_size, dim] 456 | out_2: [batch_size, dim] 457 | """ 458 | # gather representations in case of distributed training 459 | # out_1_dist: [batch_size * world_size, dim] 460 | # out_2_dist: [batch_size * world_size, dim] 461 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 462 | out_1_dist = SyncFunction.apply(out_1) 463 | out_2_dist = SyncFunction.apply(out_2) 464 | else: 465 | out_1_dist = out_1 466 | out_2_dist = out_2 467 | 468 | # Bg, total frame number on a certain GPU, = \sum bi, i over number of videos per GPU 469 | out = torch.cat([out_1, out_2], dim=0) # [2Bg] 470 | out_dist = torch.cat([out_1_dist, out_2_dist], dim=0) # [2 \sum Bg] 471 | 472 | # Positive similarity, pos becomes [2 * batch_size] 473 | inner_sim = torch.exp(out_1 @ out_2.t().contiguous() / temperature) # [Bg, Bg] 474 | pos = torch.diagonal(inner_sim) 475 | pos = torch.cat([pos, pos], dim=0) # [2Bg] 476 | 477 | cov = torch.exp(out @ out_dist.t().contiguous() / temperature) # [2Bg, 2 \sum Bg] 478 | neg = cov.sum(dim=-1) # [2Bg] 479 | 480 | # from each row, subtract similarity to frame from the same video 481 | mask = torch.block_diag(*[pos.new_ones(_, _) for _ in len]) 482 | mask = torch.cat([mask, mask], dim=0) 483 | mask = torch.cat([mask, mask], dim=1) 484 | outer_sim = torch.exp(out @ out.t().contiguous() / temperature) # [2Bg, 2Bg] 485 | masked = outer_sim * 
mask 486 | row_sub = masked.sum(dim=0) 487 | neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability 488 | 489 | loss = -torch.log(pos / (neg + pos + eps)).mean() 490 | 491 | return loss 492 | 493 | def nt_xent_loss_halfprotect(self, out_1, out_2, len, temperature, eps=1e-6): 494 | """ 495 | assume out_1 and out_2 are normalized 496 | out_1: [batch_size, dim] 497 | out_2: [batch_size, dim] 498 | """ 499 | # gather representations in case of distributed training 500 | # out_1_dist: [batch_size * world_size, dim] 501 | # out_2_dist: [batch_size * world_size, dim] 502 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 503 | out_1_dist = SyncFunction.apply(out_1) 504 | out_2_dist = SyncFunction.apply(out_2) 505 | else: 506 | out_1_dist = out_1 507 | out_2_dist = out_2 508 | 509 | # Bg, total frame number on a certain GPU, = \sum bi, i over number of videos per GPU 510 | out = torch.cat([out_1, out_2], dim=0) # [2Bg] 511 | out_dist = torch.cat([out_1_dist, out_2_dist], dim=0) # [2 \sum Bg] 512 | 513 | # Positive similarity, pos becomes [2 * batch_size] 514 | inner_sim = torch.exp(out_1 @ out_2.t().contiguous() / temperature) # [Bg, Bg] 515 | pos = torch.diagonal(inner_sim) 516 | pos = torch.cat([pos, pos], dim=0) # [2Bg] 517 | 518 | cov = torch.exp(out @ out_dist.t().contiguous() / temperature) # [2Bg, 2 \sum Bg] 519 | neg = cov.sum(dim=-1) # [2Bg] 520 | 521 | # from each row, subtract similarity to frame from the same video 522 | mask = torch.block_diag(*[pos.new_ones(_, _) for _ in len]) 523 | mask_l = torch.cat([mask, mask.new_zeros(mask.shape)], dim=0) 524 | mask_r = torch.cat([mask.new_zeros(mask.shape), mask], dim=0) 525 | mask = torch.cat([mask_l, mask_r], dim=1) 526 | outer_sim = torch.exp(out @ out.t().contiguous() / temperature) # [2Bg, 2Bg] 527 | masked = outer_sim * mask 528 | row_sub = masked.sum(dim=0) 529 | neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability 530 | 531 | loss = -torch.log(pos / (neg + eps)).mean() 532 | 533 | return loss 534 | 535 | def nt_xent_loss_rectangle(self, out_1, out_2, m, temperature, eps=1e-6): 536 | """ 537 | assume out_1 and out_2 are normalized 538 | out_1: [t1, f] 539 | out_2: [t2, f] 540 | m: [t1, t2] 541 | """ 542 | # gather representations in case of distributed training 543 | # out_1_dist: [batch_size * world_size, dim] 544 | # out_2_dist: [batch_size * world_size, dim] 545 | t1, t2 = m.shape 546 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 547 | out_1_dist = SyncFunction.apply(out_1).detach() 548 | out_2_dist = SyncFunction.apply(out_2).detach() 549 | else: 550 | out_1_dist = out_1 551 | out_2_dist = out_2 552 | 553 | # Bg, total frame number on a certain GPU, = \sum bi, i over number of videos per GPU 554 | out = torch.cat([out_1, out_2], dim=0) # [2Bg] 555 | out_dist = torch.cat([out_1_dist, out_2_dist], dim=0) # [2 \sum Bg] 556 | 557 | cov = torch.exp(torch.mm(out, out_dist.t().contiguous())/ temperature) # [2Bg, 2 \sum Bg] 558 | cov = self.dropout(cov) 559 | neg = cov.sum(dim=-1) # [2Bg] 560 | 561 | # # from each row, subtract e^(1/t) so that denominator has only t1 + t2 - 1 classes 562 | # row_sub = torch.ones_like(neg) * math.exp(1 / temperature) 563 | # neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability 564 | 565 | # calculate positive 566 | inner_sim = torch.exp(torch.mm(out, out.t().contiguous())/ temperature) # [2Bg, 2Bg] 567 | pos = (inner_sim * m).sum(dim=-1) 568 | 569 | loss = -torch.log(pos / (neg + 
eps)).mean() 570 | return loss 571 | --------------------------------------------------------------------------------
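The `*_protection` losses above differ only in how similarities between frames of the same video are masked out of the negative term. A minimal runnable sketch of the full-protection mask used in `nt_xent_loss_protection`, with toy frame counts assumed purely for illustration:

```python
import torch

# two videos on this GPU with (assumed) 2 and 3 valid frames each
len_per_video = [2, 3]
Bg = sum(len_per_video)  # 5 frames per augmented view

# [Bg, Bg] block-diagonal mask: 1 where two frames belong to the same video
mask = torch.block_diag(*[torch.ones(l, l) for l in len_per_video])
# tile to [2Bg, 2Bg] so the mask covers both augmented views, as in the module
mask = torch.cat([mask, mask], dim=0)
mask = torch.cat([mask, mask], dim=1)
print(mask.shape)  # torch.Size([10, 10])

# with pairwise similarities sim = exp(out @ out.T / temperature), the module
# subtracts (sim * mask).sum(dim=0) from the summed negatives, so frames of the
# same video never contribute to the denominator as negatives
```

The half-protection variant (`nt_xent_loss_halfprotect`) zeroes the off-diagonal quadrants of this tiled mask, so only same-view frames from the same video are removed, while `nt_xent_loss_rectangle` weights the positive similarities with the correspondence matrix `m` from the batch instead of taking a strict diagonal.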