├── preprocessing ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── functional_video.cpython-38.pyc │ └── transforms_video.cpython-38.pyc ├── functional_video.py └── transforms_video.py ├── __pycache__ ├── losses.cpython-36.pyc └── utils.cpython-36.pyc ├── evals ├── __pycache__ │ ├── datasets.cpython-38.pyc │ ├── kendalls_tau.cpython-38.pyc │ ├── phase_progression.cpython-38.pyc │ └── phase_classification.cpython-38.pyc ├── datasets.py ├── .ipynb_checkpoints │ ├── datasets-checkpoint.py │ ├── phase_classification-checkpoint.py │ └── phase_progression-checkpoint.py ├── kendalls_tau.py ├── phase_classification.py └── phase_progression.py ├── video_to_frames.py ├── decode.json ├── encode.json ├── data_prep_script.py ├── lav_env.yml ├── README.md ├── soft_dtw.py ├── align_dataset_test.py ├── wget-log ├── config.py ├── align_dataset.py ├── losses.py ├── models.py ├── visualize_alignment.py ├── train.py ├── evaluations.py └── utils.py /preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__pycache__/losses.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trquhuytin/LAV-CVPR21/HEAD/__pycache__/losses.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trquhuytin/LAV-CVPR21/HEAD/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /evals/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trquhuytin/LAV-CVPR21/HEAD/evals/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /evals/__pycache__/kendalls_tau.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trquhuytin/LAV-CVPR21/HEAD/evals/__pycache__/kendalls_tau.cpython-38.pyc -------------------------------------------------------------------------------- /evals/__pycache__/phase_progression.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trquhuytin/LAV-CVPR21/HEAD/evals/__pycache__/phase_progression.cpython-38.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trquhuytin/LAV-CVPR21/HEAD/preprocessing/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /evals/__pycache__/phase_classification.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trquhuytin/LAV-CVPR21/HEAD/evals/__pycache__/phase_classification.cpython-38.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/functional_video.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/trquhuytin/LAV-CVPR21/HEAD/preprocessing/__pycache__/functional_video.cpython-38.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/transforms_video.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/trquhuytin/LAV-CVPR21/HEAD/preprocessing/__pycache__/transforms_video.cpython-38.pyc -------------------------------------------------------------------------------- /evals/datasets.py: -------------------------------------------------------------------------------- 1 | DATASET_TO_NUM_CLASSES = { 2 | 'pouring': 5, 3 | 'baseball_pitch': 4, 4 | 'baseball_swing': 3, 5 | 'bench_press': 2, 6 | 'bowling': 3, 7 | 'clean_and_jerk': 6, 8 | 'golf_swing': 3, 9 | 'jumping_jacks': 4, 10 | 'pushups': 2, 11 | 'pullups': 2, 12 | 'situp': 2, 13 | 'squats': 4, 14 | 'tennis_forehand': 3, 15 | 'tennis_serve': 4, 16 | 'pouring_milk':5, 17 | } 18 | 19 | # DATASET_TO_NUM_CLASSES = { 20 | # 'videos': 6 21 | # } 22 | 23 | IDX_TO_CLASS = { 24 | 0: 0, 25 | 1: 5, 26 | 2: 13, 27 | 3: 18, 28 | 4: 21, 29 | 5: 23 30 | } -------------------------------------------------------------------------------- /evals/.ipynb_checkpoints/datasets-checkpoint.py: -------------------------------------------------------------------------------- 1 | DATASET_TO_NUM_CLASSES = { 2 | 'pouring': 5, 3 | 'baseball_pitch': 4, 4 | 'baseball_swing': 3, 5 | 'bench_press': 2, 6 | 'bowling': 3, 7 | 'clean_and_jerk': 6, 8 | 'golf_swing': 3, 9 | 'jumping_jacks': 4, 10 | 'pushups': 2, 11 | 'pullups': 2, 12 | 'situp': 2, 13 | 'squats': 4, 14 | 'tennis_forehand': 3, 15 | 'tennis_serve': 4, 16 | 'pouring_milk':5, 17 | } 18 | 19 | # DATASET_TO_NUM_CLASSES = { 20 | # 'videos': 6 21 | # } 22 | 23 | IDX_TO_CLASS = { 24 | 0: 0, 25 | 1: 5, 26 | 2: 13, 27 | 3: 18, 28 | 4: 21, 29 | 5: 23 30 | } -------------------------------------------------------------------------------- /video_to_frames.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import sys 3 | import tqdm 4 | import os 5 | import glob 6 | 7 | if not os.path.exists('Data'): 8 | os.mkdir("Data") 9 | 10 | path = sys.argv[1] 11 | 12 | for video in tqdm.tqdm(glob.glob(path + '*')): 13 | vid_name = video.split('/')[-1].split('.')[0] 14 | vidcap = cv2.VideoCapture(video) 15 | success,image = vidcap.read() 16 | count = 0 17 | if not os.path.exists(f'Data/{vid_name}'): 18 | os.mkdir(f'Data/{vid_name}') 19 | else: 20 | continue 21 | while success: 22 | cv2.imwrite(f"Data/{vid_name}/frame%d.jpg" % count, image) 23 | success,image = vidcap.read() 24 | count += 1 25 | -------------------------------------------------------------------------------- /decode.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "vid1": "subject1/h1/1", 4 | "vid2": "subject1/h2/1", 5 | "vid3": "subject1/k1/1", 6 | "vid4": "subject1/k2/1", 7 | "vid5": "subject1/o1/1", 8 | "vid6": "subject1/o2/1", 9 | "vid7": "subject10/s1/5", 10 | "vid8": "subject10/s2/7", 11 | "vid9": "subject10/s4/4", 12 | "vid10": "subject2/h1/1", 13 | "vid11": "subject2/h2/1", 14 | "vid12": "subject2/k1/1", 15 | "vid13": "subject2/k2/1", 16 | "vid14": "subject5/s1/5", 17 | "vid15": "subject5/s2/4", 18 | "vid16": "subject5/s3/4", 19 | "vid17": "subject5/s4/6", 20 | "vid18": "subject7/s1/6", 21 | "vid19": "subject7/s2/2", 22 | "vid20": "subject7/s3/2", 23 | "vid21": "subject7/s4/1", 24 | 
"vid22": "subject8/s1/2", 25 | "vid23": "subject8/s2/7", 26 | "vid24": "subject9/s1/5", 27 | "vid25": "subject9/s2/1", 28 | "vid26": "subject9/s3/4", 29 | "vid27": "subject9/s4/6" 30 | }, 31 | "test": { 32 | "vid28": "subject3/h1/1", 33 | "vid29": "subject3/h2/1", 34 | "vid30": "subject3/k1/1", 35 | "vid31": "subject3/k2/1", 36 | "vid32": "subject3/o1/1", 37 | "vid33": "subject3/o2/1", 38 | "vid34": "subject4/k1/1", 39 | "vid35": "subject6/s1/7", 40 | "vid36": "subject6/s2/3", 41 | "vid37": "subject6/s3/2", 42 | "vid38": "subject6/s4/2" 43 | } 44 | } -------------------------------------------------------------------------------- /encode.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "subject1/h1/1": "vid1", 4 | "subject1/h2/1": "vid2", 5 | "subject1/k1/1": "vid3", 6 | "subject1/k2/1": "vid4", 7 | "subject1/o1/1": "vid5", 8 | "subject1/o2/1": "vid6", 9 | "subject10/s1/5": "vid7", 10 | "subject10/s2/7": "vid8", 11 | "subject10/s4/4": "vid9", 12 | "subject2/h1/1": "vid10", 13 | "subject2/h2/1": "vid11", 14 | "subject2/k1/1": "vid12", 15 | "subject2/k2/1": "vid13", 16 | "subject5/s1/5": "vid14", 17 | "subject5/s2/4": "vid15", 18 | "subject5/s3/4": "vid16", 19 | "subject5/s4/6": "vid17", 20 | "subject7/s1/6": "vid18", 21 | "subject7/s2/2": "vid19", 22 | "subject7/s3/2": "vid20", 23 | "subject7/s4/1": "vid21", 24 | "subject8/s1/2": "vid22", 25 | "subject8/s2/7": "vid23", 26 | "subject9/s1/5": "vid24", 27 | "subject9/s2/1": "vid25", 28 | "subject9/s3/4": "vid26", 29 | "subject9/s4/6": "vid27" 30 | }, 31 | "test": { 32 | "subject3/h1/1": "vid28", 33 | "subject3/h2/1": "vid29", 34 | "subject3/k1/1": "vid30", 35 | "subject3/k2/1": "vid31", 36 | "subject3/o1/1": "vid32", 37 | "subject3/o2/1": "vid33", 38 | "subject4/k1/1": "vid34", 39 | "subject6/s1/7": "vid35", 40 | "subject6/s2/3": "vid36", 41 | "subject6/s3/2": "vid37", 42 | "subject6/s4/2": "vid38" 43 | } 44 | } -------------------------------------------------------------------------------- /data_prep_script.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import shutil 4 | import numpy as np 5 | 6 | # fileObject = open("D:/Drhuy/LAV/encode.json", "r") 7 | # jsonContent = fileObject.read() 8 | # encode = json.loads(jsonContent) 9 | # fileObject.close() 10 | 11 | # for f in encode['test'].keys(): 12 | # os.makedirs(os.path.join("D:/Drhuy/LAV/data/H20/val", encode['test'][f]), exist_ok=True) 13 | # pth = os.path.join("D:/Drhuy/h2o_CASA",f,"cam4/rgb256") 14 | # for files in os.listdir(pth): 15 | # shutil.copy( 16 | # src=os.path.join(pth, files), 17 | # dst=os.path.join("D:/Drhuy/LAV/data/H20/val", encode['test'][f]) 18 | # ) 19 | 20 | 21 | # for f in encode['test'].keys(): 22 | # pth = os.path.join("D:/Drhuy/h2o_CASA",f,"cam4/action_label") 23 | # np_labs=[] 24 | # for files in os.listdir(pth): 25 | # with open(os.path.join(pth,files), 'r') as k: 26 | # np_labs.append(int(k.read())) 27 | 28 | # labels = np.array(np_labs) 29 | # np.save(os.path.join("D:/Drhuy/LAV/data/H20/labels/val/videos", encode['test'][f]), labels) 30 | 31 | unique_labels = list() 32 | vids = 0 33 | for label in os.listdir(r'D:\Drhuy\LAV\data\H20\labels\train\videos'): 34 | l = np.load(os.path.join(r'D:\Drhuy\LAV\data\H20\labels\train\videos', label)) 35 | unique_labels.extend(np.unique(l)) 36 | vids += 1 37 | 38 | for label in os.listdir(r'D:\Drhuy\LAV\data\H20\labels\val\videos'): 39 | l = 
np.load(os.path.join(r'D:\Drhuy\LAV\data\H20\labels\val\videos', label)) 40 | unique_labels.extend(np.unique(l)) 41 | vids += 1 42 | 43 | print(np.unique(unique_labels), vids) 44 | # print(set((unique_labels))) -------------------------------------------------------------------------------- /evals/kendalls_tau.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from scipy.spatial.distance import cdist 4 | from scipy.stats import kendalltau 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | def softmax(w, t=1.0): 9 | e = np.exp(np.array(w) / t) 10 | dist = e / np.sum(e) 11 | return dist 12 | 13 | def _get_kendalls_tau(embs_list, stride, split, kt_dist, visualize=False): 14 | """Get nearest neighbours in embedding space and calculate Kendall's Tau.""" 15 | num_seqs = len(embs_list) 16 | taus = np.zeros((num_seqs * (num_seqs - 1))) 17 | idx = 0 18 | for i in range(num_seqs): 19 | query_feats = embs_list[i][::stride] 20 | for j in range(num_seqs): 21 | if i == j: 22 | continue 23 | candidate_feats = embs_list[j][::stride] 24 | dists = cdist(query_feats, candidate_feats, 25 | kt_dist) 26 | if visualize: 27 | if (i == 0 and j == 1) or split == 'val': 28 | sim_matrix = [] 29 | for k in range(len(query_feats)): 30 | sim_matrix.append(softmax(-dists[k])) 31 | sim_matrix = np.array(sim_matrix, dtype=np.float32) 32 | # visualize matplotlib 33 | plt.imshow(sim_matrix) 34 | plt.show() 35 | nns = np.argmin(dists, axis=1) 36 | taus[idx] = kendalltau(np.arange(len(nns)), nns).correlation 37 | 38 | idx += 1 39 | # Remove NaNs. 40 | taus = taus[~np.isnan(taus)] 41 | tau = np.mean(taus) 42 | 43 | return tau 44 | 45 | def evaluate_kendalls_tau(train_embs, val_embs, stride, kt_dist, visualize=False): 46 | 47 | train_tau = _get_kendalls_tau(train_embs, stride=stride, split='train', kt_dist=kt_dist, visualize=visualize) 48 | val_tau = _get_kendalls_tau(val_embs, stride=stride, split='val', kt_dist=kt_dist, visualize=visualize) 49 | 50 | return train_tau, val_tau -------------------------------------------------------------------------------- /lav_env.yml: -------------------------------------------------------------------------------- 1 | name: lav 2 | channels: 3 | - defaults 4 | dependencies: 5 | - certifi=2021.5.30 6 | - pip=21.2.2 7 | - python=3.6.13 8 | - sqlite=3.39.2 9 | - wheel=0.37.1 10 | - pip: 11 | - absl-py==1.2.0 12 | - aiohttp==3.8.3 13 | - aiosignal==1.2.0 14 | - async-timeout==4.0.2 15 | - asynctest==0.13.0 16 | - attrs==22.1.0 17 | - cachetools==4.2.4 18 | - charset-normalizer==2.0.12 19 | - colorama==0.4.5 20 | - cycler==0.11.0 21 | - dataclasses==0.8 22 | - dtw==1.4.0 23 | - easydict==1.9 24 | - frozenlist==1.2.0 25 | - fsspec==2022.1.0 26 | - future==0.18.2 27 | - google-auth==2.12.0 28 | - google-auth-oauthlib==0.4.6 29 | - gputil==1.4.0 30 | - grpcio==1.48.2 31 | - idna==3.4 32 | - idna-ssl==1.1.0 33 | - imageio==2.6.1 34 | - importlib-metadata==4.8.3 35 | - importlib-resources==5.4.0 36 | - joblib==1.1.0 37 | - kiwisolver==1.3.1 38 | - llvmlite==0.36.0 39 | - markdown==3.3.7 40 | - matplotlib==3.2.0 41 | - multidict==5.2.0 42 | - natsort==7.0.1 43 | - numba==0.53.1 44 | - numpy==1.18.5 45 | - oauthlib==3.2.1 46 | - packaging==21.3 47 | - pandas==1.1.5 48 | - pillow==7.2.0 49 | - protobuf==3.19.5 50 | - pyasn1==0.4.8 51 | - pyasn1-modules==0.2.8 52 | - pydeprecate==0.3.1 53 | - pyparsing==3.0.9 54 | - python-dateutil==2.8.2 55 | - pytorch-lightning==1.5.10 56 | - pytz==2022.4 57 | - pyyaml==6.0 58 | - requests==2.27.1 59 | - 
requests-oauthlib==1.3.1 60 | - rsa==4.9 61 | - scikit-learn==0.23.2 62 | - scipy==1.5.0 63 | - seaborn==0.11.2 64 | - setuptools==59.5.0 65 | - six==1.16.0 66 | - tensorboard==2.10.1 67 | - tensorboard-data-server==0.6.1 68 | - tensorboard-plugin-wit==1.8.1 69 | - threadpoolctl==3.1.0 70 | - torch-tb-profiler==0.4.0 71 | - torchmetrics==0.8.2 72 | - tqdm==4.64.1 73 | - typing-extensions==4.1.1 74 | - urllib3==1.26.12 75 | - werkzeug==2.0.3 76 | - yarl==1.7.2 77 | - zipp==3.6.0 78 | prefix: C:\Users\hamza\anaconda3\envs\lav 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning by Aligning Videos in Time (CVPR 2021) 2 | 3 | ## Overview 4 | This repository contains the official implementation of our CVPR 2021 paper (https://openaccess.thecvf.com/content/CVPR2021/papers/Haresh_Learning_by_Aligning_Videos_in_Time_CVPR_2021_paper.pdf). 5 | 6 | If you use the code, please cite our paper: 7 | ``` 8 | @inproceedings{haresh2021learning, 9 | title={Learning by aligning videos in time}, 10 | author={Haresh, Sanjay and Kumar, Sateesh and Coskun, Huseyin and Syed, Shahram N and Konin, Andrey and Zia, Zeeshan and Tran, Quoc-Huy}, 11 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 12 | pages={5548--5558}, 13 | year={2021} 14 | } 15 | ``` 16 | 17 | For our recent work, please check out our research page (https://retrocausal.ai/research/). 18 | 19 | 20 | ## Installation 21 | Create an environment and install the required packages: 22 | ``` 23 | conda env create --name LAV --file=lav_env.yml 24 | conda activate LAV 25 | ``` 26 | 27 | If you face any PyTorch-related issues during training, uninstall PyTorch first: 28 | ``` 29 | pip3 uninstall torch torchvision torchaudio 30 | ``` 31 | 32 | Then go to https://pytorch.org/get-started/locally/ and install the PyTorch build that suits your machine, for example: 33 | ``` 34 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 35 | ``` 36 | 37 | 38 | ## Video-to-Frame Conversion 39 | ``` 40 | python video_to_frames.py videos/ 41 | ``` 42 | 43 | 44 | ## Training/Testing Splits 45 | Split your data into train and val sets so that your directory looks like this: 46 | ``` 47 | $YOUR_PATH_TO_DATASET 48 | ├─train 49 | ├──vid1/ 50 | | ├──000001.jpg 51 | | ├──000002.jpg 52 | | ├──... 53 | ├──val 54 | ├──vid2/ 55 | | ├──000001.jpg 56 | | ├──000002.jpg 57 | | ├──... 58 | ├──... 59 | ``` 60 | 61 | 62 | ## Training 63 | ``` 64 | python train.py --description "LAV" --data_path Data 65 | ``` 66 | 67 | 68 | ## Testing 69 | ``` 70 | python evaluations.py --model_path path/to/model --dest path/to/log/dest --device 0 71 | ``` 72 | 73 | The expected directory structure for evaluation is: 74 | ``` 75 | ├── 76 | ├──test 77 | ├──vid2 78 | | ├──000001.jpg 79 | | ├──000002.jpg 80 | | ├──... 
81 | ``` 82 | -------------------------------------------------------------------------------- /evals/phase_classification.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.svm import SVC 4 | from sklearn.neighbors import NearestNeighbors 5 | from sklearn.metrics import accuracy_score, confusion_matrix 6 | 7 | 8 | def fit_svm(train_embs, train_labels): 9 | train_embs = np.concatenate(train_embs) 10 | train_labels = np.concatenate(train_labels) 11 | 12 | svm_model = SVC(decision_function_shape='ovo') 13 | svm_model.fit(train_embs, train_labels) 14 | train_acc = svm_model.score(train_embs, train_labels) 15 | 16 | return svm_model, train_acc 17 | 18 | 19 | def evaluate_svm(svm, val_embs, val_labels): 20 | 21 | val_preds = [] 22 | for vid_embs in val_embs: 23 | vid_preds = svm.predict(vid_embs) 24 | val_preds.append(vid_preds) 25 | 26 | # concatenate labels and preds in one array 27 | val_preds = np.concatenate(val_preds) 28 | val_labels = np.concatenate(val_labels) 29 | 30 | # calculate accuracy and confusion matrix 31 | val_acc = accuracy_score(val_labels, val_preds) 32 | conf_mat = confusion_matrix(val_labels, val_preds) 33 | 34 | return val_acc, conf_mat 35 | 36 | 37 | def evaluate_phase_classification(ckpt_step, train_embs, train_labels, val_embs, val_labels, act_name, CONFIG, writer=None, verbose=False): 38 | 39 | for frac in CONFIG.EVAL.CLASSIFICATION_FRACTIONS: 40 | N_Vids = max(1, int(len(train_embs) * frac)) 41 | embs = train_embs[:N_Vids] 42 | labs = train_labels[:N_Vids] 43 | 44 | if verbose: 45 | print(f'Fraction = {frac}, Total = {len(train_embs)}, Used = {len(embs)}') 46 | 47 | svm_model, train_acc = fit_svm(embs, labs) 48 | val_acc, conf_mat = evaluate_svm(svm_model, val_embs, val_labels) 49 | 50 | print('\n-----------------------------') 51 | print('Fraction: ', frac) 52 | print('Train-Acc: ', train_acc) 53 | print('Val-Acc: ', val_acc) 54 | print('Conf-Mat: ', conf_mat) 55 | 56 | 57 | writer.add_scalar(f'classification/train_{act_name}_{frac}', train_acc, global_step=ckpt_step) 58 | writer.add_scalar(f'classification/val_{act_name}_{frac}', val_acc, global_step=ckpt_step) 59 | 60 | print(f'classification/train_{act_name}_{frac}', train_acc, f"global_step={ckpt_step}") 61 | print(f'classification/val_{act_name}_{frac}', val_acc, f"global_step={ckpt_step}") 62 | 63 | return train_acc, val_acc 64 | 65 | 66 | def _compute_ap(val_embs, val_labels): 67 | results = [] 68 | for k in [5, 10, 15]: 69 | nbrs = NearestNeighbors(n_neighbors=k).fit(val_embs) 70 | distances, indices = nbrs.kneighbors(val_embs) 71 | vals = [] 72 | for i in range(val_embs.shape[0]): 73 | a = np.array([val_labels[i]] * k) 74 | b = indices[i] 75 | b = np.array([val_labels[k] for k in b]) 76 | val = (a==b).sum()/k 77 | vals.append(val) 78 | 79 | results.append(np.mean(vals)) 80 | return results 81 | 82 | def compute_ap(videos, labels): 83 | ap5, ap10, ap15 = 0, 0, 0 84 | for v,l in zip(videos, labels): 85 | a5, a10, a15 = _compute_ap(v,l) 86 | ap5 += a5 87 | ap10 += a10 88 | ap15 += a15 89 | ap5 /= len(videos) 90 | ap10 /= len(videos) 91 | ap15 /= len(videos) 92 | return [ap5, ap10, ap15] -------------------------------------------------------------------------------- /evals/.ipynb_checkpoints/phase_classification-checkpoint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.svm import SVC 4 | from sklearn.neighbors import NearestNeighbors 5 | from 
sklearn.metrics import accuracy_score, confusion_matrix 6 | 7 | 8 | def fit_svm(train_embs, train_labels): 9 | train_embs = np.concatenate(train_embs) 10 | train_labels = np.concatenate(train_labels) 11 | 12 | svm_model = SVC(decision_function_shape='ovo') 13 | svm_model.fit(train_embs, train_labels) 14 | train_acc = svm_model.score(train_embs, train_labels) 15 | 16 | return svm_model, train_acc 17 | 18 | 19 | def evaluate_svm(svm, val_embs, val_labels): 20 | 21 | val_preds = [] 22 | for vid_embs in val_embs: 23 | vid_preds = svm.predict(vid_embs) 24 | val_preds.append(vid_preds) 25 | 26 | # concatenate labels and preds in one array 27 | val_preds = np.concatenate(val_preds) 28 | val_labels = np.concatenate(val_labels) 29 | 30 | # calculate accuracy and confusion matrix 31 | val_acc = accuracy_score(val_labels, val_preds) 32 | conf_mat = confusion_matrix(val_labels, val_preds) 33 | 34 | return val_acc, conf_mat 35 | 36 | 37 | def evaluate_phase_classification(ckpt_step, train_embs, train_labels, val_embs, val_labels, act_name, CONFIG, writer=None, verbose=False): 38 | 39 | for frac in CONFIG.EVAL.CLASSIFICATION_FRACTIONS: 40 | N_Vids = max(1, int(len(train_embs) * frac)) 41 | embs = train_embs[:N_Vids] 42 | labs = train_labels[:N_Vids] 43 | 44 | if verbose: 45 | print(f'Fraction = {frac}, Total = {len(train_embs)}, Used = {len(embs)}') 46 | 47 | svm_model, train_acc = fit_svm(embs, labs) 48 | val_acc, conf_mat = evaluate_svm(svm_model, val_embs, val_labels) 49 | 50 | print('\n-----------------------------') 51 | print('Fraction: ', frac) 52 | print('Train-Acc: ', train_acc) 53 | print('Val-Acc: ', val_acc) 54 | print('Conf-Mat: ', conf_mat) 55 | 56 | 57 | writer.add_scalar(f'classification/train_{act_name}_{frac}', train_acc, global_step=ckpt_step) 58 | writer.add_scalar(f'classification/val_{act_name}_{frac}', val_acc, global_step=ckpt_step) 59 | 60 | print(f'classification/train_{act_name}_{frac}', train_acc, f"global_step={ckpt_step}") 61 | print(f'classification/val_{act_name}_{frac}', val_acc, f"global_step={ckpt_step}") 62 | 63 | return train_acc, val_acc 64 | 65 | 66 | def _compute_ap(val_embs, val_labels): 67 | results = [] 68 | for k in [5, 10, 15]: 69 | nbrs = NearestNeighbors(n_neighbors=k).fit(val_embs) 70 | distances, indices = nbrs.kneighbors(val_embs) 71 | vals = [] 72 | for i in range(val_embs.shape[0]): 73 | a = np.array([val_labels[i]] * k) 74 | b = indices[i] 75 | b = np.array([val_labels[k] for k in b]) 76 | val = (a==b).sum()/k 77 | vals.append(val) 78 | 79 | results.append(np.mean(vals)) 80 | return results 81 | 82 | def compute_ap(videos, labels): 83 | ap5, ap10, ap15 = 0, 0, 0 84 | for v,l in zip(videos, labels): 85 | a5, a10, a15 = _compute_ap(v,l) 86 | ap5 += a5 87 | ap10 += a10 88 | ap15 += a15 89 | ap5 /= len(videos) 90 | ap10 /= len(videos) 91 | ap15 /= len(videos) 92 | return [ap5, ap10, ap15] -------------------------------------------------------------------------------- /preprocessing/functional_video.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _is_tensor_video_clip(clip): 4 | if not torch.is_tensor(clip): 5 | raise TypeError("clip should be Tesnor. Got %s" % type(clip)) 6 | 7 | if not clip.ndimension() == 4: 8 | raise ValueError("clip should be 4D. Got %dD" % clip.dim()) 9 | 10 | return True 11 | 12 | 13 | def crop(clip, i, j, h, w): 14 | """ 15 | Args: 16 | clip (torch.tensor): Video clip to be cropped. 
Size is (C, T, H, W) 17 | """ 18 | assert len(clip.size()) == 4, "clip should be a 4D tensor" 19 | return clip[..., i:i + h, j:j + w] 20 | 21 | 22 | def resize(clip, target_size, interpolation_mode): 23 | assert len(target_size) == 2, "target size should be tuple (height, width)" 24 | return torch.nn.functional.interpolate( 25 | clip, size=target_size, mode=interpolation_mode, align_corners=False 26 | ) 27 | 28 | 29 | def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): 30 | """ 31 | Do spatial cropping and resizing to the video clip 32 | Args: 33 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 34 | i (int): i in (i,j) i.e coordinates of the upper left corner. 35 | j (int): j in (i,j) i.e coordinates of the upper left corner. 36 | h (int): Height of the cropped region. 37 | w (int): Width of the cropped region. 38 | size (tuple(int, int)): height and width of resized clip 39 | Returns: 40 | clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W) 41 | """ 42 | assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor" 43 | clip = crop(clip, i, j, h, w) 44 | clip = resize(clip, size, interpolation_mode) 45 | return clip 46 | 47 | 48 | def center_crop(clip, crop_size): 49 | assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor" 50 | h, w = clip.size(-2), clip.size(-1) 51 | th, tw = crop_size 52 | assert h >= th and w >= tw, "height and width must be no smaller than crop_size" 53 | 54 | i = int(round((h - th) / 2.0)) 55 | j = int(round((w - tw) / 2.0)) 56 | return crop(clip, i, j, th, tw) 57 | 58 | 59 | def to_tensor(clip): 60 | """ 61 | Convert tensor data type from uint8 to float, divide value by 255.0 and 62 | permute the dimenions of clip tensor 63 | Args: 64 | clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) 65 | Return: 66 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) 67 | """ 68 | _is_tensor_video_clip(clip) 69 | if not clip.dtype == torch.uint8: 70 | raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) 71 | return clip.float().permute(3, 0, 1, 2) / 255.0 72 | 73 | 74 | def normalize(clip, mean, std, inplace=False): 75 | """ 76 | Args: 77 | clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) 78 | mean (tuple): pixel RGB mean. Size is (3) 79 | std (tuple): pixel standard deviation. Size is (3) 80 | Returns: 81 | normalized clip (torch.tensor): Size is (T, C, H, W) 82 | """ 83 | assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor" 84 | if not inplace: 85 | clip = clip.clone() 86 | mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) 87 | std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) 88 | clip.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 89 | return clip 90 | 91 | 92 | def hflip(clip): 93 | """ 94 | Args: 95 | clip (torch.tensor): Video clip to be normalized. 
Size is (C, T, H, W) 96 | Returns: 97 | flipped clip (torch.tensor): Size is (C, T, H, W) 98 | """ 99 | assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor" 100 | return clip.flip((-1)) 101 | -------------------------------------------------------------------------------- /soft_dtw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from numba import jit 5 | from torch.autograd import Function 6 | 7 | @jit(nopython = True) 8 | def compute_softdtw(D, gamma): 9 | B = D.shape[0] 10 | N = D.shape[1] 11 | M = D.shape[2] 12 | R = np.ones((B, N + 2, M + 2)) * np.inf 13 | R[:, 0, 0] = 0 14 | for k in range(B): 15 | for j in range(1, M + 1): 16 | for i in range(1, N + 1): 17 | r0 = -R[k, i - 1, j - 1] / gamma 18 | r1 = -R[k, i - 1, j] / gamma 19 | r2 = -R[k, i, j - 1] / gamma 20 | rmax = max(max(r0, r1), r2) 21 | rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax) 22 | softmin = - gamma * (np.log(rsum) + rmax) 23 | R[k, i, j] = D[k, i - 1, j - 1] + softmin 24 | return R 25 | 26 | @jit(nopython = True) 27 | def compute_softdtw_backward(D_, R, gamma): 28 | B = D_.shape[0] 29 | N = D_.shape[1] 30 | M = D_.shape[2] 31 | D = np.zeros((B, N + 2, M + 2)) 32 | E = np.zeros((B, N + 2, M + 2)) 33 | D[:, 1:N + 1, 1:M + 1] = D_ 34 | E[:, -1, -1] = 1 35 | R[:, : , -1] = -np.inf 36 | R[:, -1, :] = -np.inf 37 | R[:, -1, -1] = R[:, -2, -2] 38 | for k in range(B): 39 | for j in range(M, 0, -1): 40 | for i in range(N, 0, -1): 41 | a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma 42 | b0 = (R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) / gamma 43 | c0 = (R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) / gamma 44 | a = np.exp(a0) 45 | b = np.exp(b0) 46 | c = np.exp(c0) 47 | E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c 48 | return E[:, 1:N + 1, 1:M + 1] 49 | 50 | class _SoftDTW(Function): 51 | @staticmethod 52 | def forward(ctx, D, gamma): 53 | dev = D.device 54 | dtype = D.dtype 55 | gamma = torch.Tensor([gamma]).to(dev).type(dtype) # dtype fixed 56 | D_ = D.detach().cpu().numpy() 57 | g_ = gamma.item() 58 | R = torch.Tensor(compute_softdtw(D_, g_)).to(dev).type(dtype) 59 | ctx.save_for_backward(D, R, gamma) 60 | return R[:, -2, -2] 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | dev = grad_output.device 65 | dtype = grad_output.dtype 66 | D, R, gamma = ctx.saved_tensors 67 | D_ = D.detach().cpu().numpy() 68 | R_ = R.detach().cpu().numpy() 69 | g_ = gamma.item() 70 | E = torch.Tensor(compute_softdtw_backward(D_, R_, g_)).to(dev).type(dtype) 71 | return grad_output.view(-1, 1, 1).expand_as(E) * E, None 72 | 73 | class SoftDTW(torch.nn.Module): 74 | def __init__(self, gamma=1.0, normalize=False): 75 | super(SoftDTW, self).__init__() 76 | self.normalize = normalize 77 | self.gamma=gamma 78 | self.func_dtw = _SoftDTW.apply 79 | 80 | def calc_distance_matrix(self, x, y): 81 | n = x.size(1) 82 | m = y.size(1) 83 | d = x.size(2) 84 | x = x.unsqueeze(2).expand(-1, n, m, d) 85 | y = y.unsqueeze(1).expand(-1, n, m, d) 86 | dist = torch.pow(x - y, 2).sum(3) 87 | return dist 88 | 89 | def forward(self, x, y): 90 | assert len(x.shape) == len(y.shape) 91 | squeeze = False 92 | if len(x.shape) < 3: 93 | x = x.unsqueeze(0) 94 | y = y.unsqueeze(0) 95 | squeeze = True 96 | if self.normalize: 97 | D_xy = self.calc_distance_matrix(x, y) 98 | 99 | out_xy = self.func_dtw(D_xy, self.gamma) 100 | D_xx = self.calc_distance_matrix(x, 
x) 101 | out_xx = self.func_dtw(D_xx, self.gamma) 102 | D_yy = self.calc_distance_matrix(y, y) 103 | out_yy = self.func_dtw(D_yy, self.gamma) 104 | result = out_xy - 1/2 * (out_xx + out_yy) # distance 105 | else: 106 | D_xy = self.calc_distance_matrix(x, y) 107 | 108 | out_xy = self.func_dtw(D_xy, self.gamma) 109 | result = out_xy # discrepancy 110 | return result.squeeze(0) if squeeze else result 111 | -------------------------------------------------------------------------------- /align_dataset_test.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import numpy as np 3 | import random 4 | from natsort import natsorted 5 | import utils 6 | 7 | from torch.utils.data import IterableDataset 8 | 9 | def get_steps_with_context(steps, num_context, context_stride): 10 | _context = np.arange(num_context-1, -1, -1) 11 | context_steps = np.maximum(0, steps[:, None] - _context * context_stride) 12 | return context_steps.reshape(-1) 13 | 14 | def sample_frames(frames, num_context, context_stride=15): 15 | seq_len = len(frames) 16 | 17 | chosen_steps = np.arange(seq_len) 18 | steps = get_steps_with_context(chosen_steps, num_context, context_stride) 19 | 20 | frames = np.array(frames)[steps] 21 | return frames, chosen_steps 22 | 23 | 24 | class AlignData(IterableDataset): 25 | 26 | def __init__(self, path, batch_size, data_config, transform=False, flatten=False): 27 | 28 | self.act_sequences = natsorted(glob.glob(os.path.join(path, '*'))) 29 | self.n_classes = len(self.act_sequences) 30 | 31 | self.batch_size = batch_size 32 | self.config = data_config 33 | 34 | self.current_act = -1 35 | 36 | if transform: 37 | self.transform = transform 38 | else: 39 | self.transform = utils.get_totensor_transform(is_video=True) 40 | 41 | self.flatten = flatten 42 | 43 | self.action = None 44 | self.num_seqs = None 45 | self._one_vid = False 46 | 47 | def __len__(self): 48 | return self.n_classes 49 | 50 | def get_action_name(self, i_action): 51 | return os.path.basename(self.act_sequences[i_action]) 52 | 53 | def set_action_seq(self, action, num_seqs=None): 54 | self.action = action 55 | self.num_seqs = num_seqs 56 | 57 | def set_spec_video(self, path): 58 | if not os.path.exists(path): 59 | raise Exception("Video doesn't exist") 60 | self.spec_vid = path 61 | self._one_vid = True 62 | 63 | def __iter__(self): 64 | 65 | if self.action is not None: 66 | if self._one_vid: 67 | act_sequences = [os.path.dirname(self.spec_vid)] 68 | else: 69 | act_sequences = [self.act_sequences[self.action]] 70 | else: 71 | act_sequences = self.act_sequences 72 | 73 | for _action in act_sequences: 74 | if self._one_vid: 75 | sequences = [self.spec_vid] 76 | else: 77 | sequences = natsorted(glob.glob(os.path.join(_action, '*'))) 78 | if self.num_seqs is not None: 79 | sequences = random.sample(sequences, min(self.num_seqs, len(sequences))) 80 | 81 | get_frame_paths = lambda x : sorted(glob.glob(os.path.join(x, '*'))) 82 | 83 | def seq_iter(): 84 | for seq in sequences: 85 | 86 | def frame_iter(): 87 | frames = get_frame_paths(seq) 88 | num_context = self.config.NUM_CONTEXT 89 | batch_step = num_context * self.batch_size 90 | 91 | frames, steps = sample_frames(frames, num_context, self.config.CONTEXT_STRIDE) 92 | for i in range(0, len(frames), batch_step): 93 | a_frames = frames[i:i+batch_step] 94 | a_x = utils.get_pil_images(a_frames) 95 | a_x = self.transform(a_x) 96 | 97 | if self.flatten: 98 | a_x = a_x.view((a_x.shape[0], -1)) 99 | 100 | a_name = 
os.path.join(os.path.basename(_action), os.path.basename(seq)) 101 | 102 | yield a_x, a_name, a_frames[num_context-1::num_context] 103 | 104 | yield frame_iter() 105 | yield seq_iter() -------------------------------------------------------------------------------- /evals/phase_progression.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import sklearn 4 | from sklearn.linear_model import LinearRegression 5 | from evals.datasets import DATASET_TO_NUM_CLASSES, IDX_TO_CLASS 6 | 7 | # def regression_labels_for_class(labels, class_idx): 8 | # # Assumes labels are ordered. Find the last occurrence of particular class. 9 | # class_idx = IDX_TO_CLASS[class_idx] 10 | # if len(np.argwhere(labels == class_idx)) > 0: 11 | # transition_frame = np.argwhere(labels == class_idx)[-1, 0] 12 | # return (np.arange(float(len(labels))) - transition_frame) / len(labels) 13 | # else: 14 | # return [] 15 | 16 | def regression_labels_for_class(labels, class_idx): 17 | # Assumes labels are ordered. Find the last occurrence of particular class. 18 | class_idx = IDX_TO_CLASS[class_idx] 19 | transition_frame = None 20 | if len(np.argwhere(labels == class_idx)) > 0: 21 | transition_frame = np.argwhere(labels == class_idx)[-1, 0] 22 | else: 23 | transition_frame = -1 24 | return (np.arange(float(len(labels))) - transition_frame) / len(labels) 25 | 26 | 27 | 28 | def get_regression_labels(class_labels, num_classes): 29 | regression_labels = [] 30 | for i in range(num_classes - 1): 31 | value=regression_labels_for_class(class_labels, i) 32 | if len(value) > 0: 33 | regression_labels.append(value) 34 | return np.stack(regression_labels, axis=1) 35 | 36 | def get_targets_from_labels(all_class_labels, num_classes): 37 | all_regression_labels = [] 38 | 39 | for class_labels in all_class_labels: 40 | all_regression_labels.append(get_regression_labels(class_labels, 41 | num_classes)) 42 | return all_regression_labels 43 | 44 | 45 | def unnormalize(preds): 46 | seq_len = len(preds) 47 | return np.mean([i - pred * seq_len for i, pred in enumerate(preds)]) 48 | 49 | 50 | class VectorRegression(sklearn.base.BaseEstimator): 51 | """Class to perform regression on multiple outputs.""" 52 | 53 | def __init__(self, estimator): 54 | self.estimator = estimator 55 | 56 | def fit(self, x, y): 57 | _, m = y.shape 58 | # Fit a separate regressor for each column of y 59 | self.estimators_ = [sklearn.base.clone(self.estimator).fit(x, y[:, i]) 60 | for i in range(m)] 61 | return self 62 | 63 | def predict(self, x): 64 | # Join regressors' predictions 65 | res = [est.predict(x)[:, np.newaxis] for est in self.estimators_] 66 | return np.hstack(res) 67 | 68 | def score(self, x, y): 69 | # Join regressors' scores 70 | res = [est.score(x, y[:, i]) for i, est in enumerate(self.estimators_)] 71 | return np.mean(res) 72 | 73 | 74 | def fit_model(train_embs, train_labels, val_embs, val_labels): 75 | """Linear Regression to regress to fraction completed.""" 76 | 77 | train_embs = np.concatenate(train_embs, axis=0) 78 | train_labels = np.concatenate(train_labels, axis=0) 79 | val_embs = np.concatenate(val_embs, axis=0) 80 | val_labels = np.concatenate(val_labels, axis=0) 81 | 82 | lin_model = VectorRegression(LinearRegression()) 83 | lin_model.fit(train_embs, train_labels) 84 | 85 | train_score = lin_model.score(train_embs, train_labels) 86 | val_score = lin_model.score(val_embs, val_labels) 87 | 88 | return train_score, val_score 89 | 90 | 91 | def evaluate_phase_progression(train_data, 
val_data, action, ckpt_step, CONFIG, writer=None, verbose=False): 92 | 93 | train_embs = train_data['embs'] 94 | val_embs = val_data['embs'] 95 | num_classes = DATASET_TO_NUM_CLASSES[action] 96 | 97 | if not train_embs or not val_embs: 98 | raise Exception("All embeddings are NAN. Something is wrong with model.") 99 | 100 | val_labels = get_targets_from_labels(val_data['labels'], 101 | num_classes) 102 | 103 | 104 | num_samples = len(train_data['embs']) 105 | 106 | train_scores = [] 107 | val_scores = [] 108 | for fraction_used in CONFIG.EVAL.CLASSIFICATION_FRACTIONS: 109 | num_samples_used = max(1, int(fraction_used * num_samples)) 110 | train_embs = train_data['embs'][:num_samples_used] 111 | train_labels = get_targets_from_labels( 112 | train_data['labels'][:num_samples_used], num_classes) 113 | 114 | train_score, val_score = fit_model(train_embs, train_labels, val_embs, val_labels) 115 | 116 | if verbose: 117 | print('\n-----------------------------') 118 | print('Fraction: ', fraction_used) 119 | print('Train-Score: ', train_score) 120 | print('Val-Score: ', val_score) 121 | 122 | if writer: 123 | writer.add_scalar(f'phase_progression/train_{action}_{fraction_used}', train_score, global_step=ckpt_step) 124 | writer.add_scalar(f'phase_progression/val_{action}_{fraction_used}', val_score, global_step=ckpt_step) 125 | 126 | 127 | print(f'phase_progression/train_{action}_{fraction_used}', train_score, f"global_step={ckpt_step}") 128 | print(f'phase_progression/val_{action}_{fraction_used}', val_score, f"global_step={ckpt_step}") 129 | train_scores.append(train_score) 130 | val_scores.append(val_score) 131 | 132 | return train_scores, val_scores -------------------------------------------------------------------------------- /evals/.ipynb_checkpoints/phase_progression-checkpoint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import sklearn 4 | from sklearn.linear_model import LinearRegression 5 | from evals.datasets import DATASET_TO_NUM_CLASSES, IDX_TO_CLASS 6 | 7 | # def regression_labels_for_class(labels, class_idx): 8 | # # Assumes labels are ordered. Find the last occurrence of particular class. 9 | # class_idx = IDX_TO_CLASS[class_idx] 10 | # if len(np.argwhere(labels == class_idx)) > 0: 11 | # transition_frame = np.argwhere(labels == class_idx)[-1, 0] 12 | # return (np.arange(float(len(labels))) - transition_frame) / len(labels) 13 | # else: 14 | # return [] 15 | 16 | def regression_labels_for_class(labels, class_idx): 17 | # Assumes labels are ordered. Find the last occurrence of particular class. 
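# Worked example (hypothetical frame labels, using the IDX_TO_CLASS mapping imported above):
# for labels = [0, 0, 5, 5, 13, 13] and class_idx = 1 (IDX_TO_CLASS[1] == 5), the last
# occurrence of label 5 is frame 3, so the targets returned below are
# (np.arange(6.) - 3) / 6 = [-0.50, -0.33, -0.17, 0.00, 0.17, 0.33].
# If the class never occurs, transition_frame falls back to -1 and the targets become
# (np.arange(6.) + 1) / 6, i.e. strictly positive.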
18 | class_idx = IDX_TO_CLASS[class_idx] 19 | transition_frame = None 20 | if len(np.argwhere(labels == class_idx)) > 0: 21 | transition_frame = np.argwhere(labels == class_idx)[-1, 0] 22 | else: 23 | transition_frame = -1 24 | return (np.arange(float(len(labels))) - transition_frame) / len(labels) 25 | 26 | 27 | 28 | def get_regression_labels(class_labels, num_classes): 29 | regression_labels = [] 30 | for i in range(num_classes - 1): 31 | value=regression_labels_for_class(class_labels, i) 32 | if len(value) > 0: 33 | regression_labels.append(value) 34 | return np.stack(regression_labels, axis=1) 35 | 36 | def get_targets_from_labels(all_class_labels, num_classes): 37 | all_regression_labels = [] 38 | 39 | for class_labels in all_class_labels: 40 | all_regression_labels.append(get_regression_labels(class_labels, 41 | num_classes)) 42 | return all_regression_labels 43 | 44 | 45 | def unnormalize(preds): 46 | seq_len = len(preds) 47 | return np.mean([i - pred * seq_len for i, pred in enumerate(preds)]) 48 | 49 | 50 | class VectorRegression(sklearn.base.BaseEstimator): 51 | """Class to perform regression on multiple outputs.""" 52 | 53 | def __init__(self, estimator): 54 | self.estimator = estimator 55 | 56 | def fit(self, x, y): 57 | _, m = y.shape 58 | # Fit a separate regressor for each column of y 59 | self.estimators_ = [sklearn.base.clone(self.estimator).fit(x, y[:, i]) 60 | for i in range(m)] 61 | return self 62 | 63 | def predict(self, x): 64 | # Join regressors' predictions 65 | res = [est.predict(x)[:, np.newaxis] for est in self.estimators_] 66 | return np.hstack(res) 67 | 68 | def score(self, x, y): 69 | # Join regressors' scores 70 | res = [est.score(x, y[:, i]) for i, est in enumerate(self.estimators_)] 71 | return np.mean(res) 72 | 73 | 74 | def fit_model(train_embs, train_labels, val_embs, val_labels): 75 | """Linear Regression to regress to fraction completed.""" 76 | 77 | train_embs = np.concatenate(train_embs, axis=0) 78 | train_labels = np.concatenate(train_labels, axis=0) 79 | val_embs = np.concatenate(val_embs, axis=0) 80 | val_labels = np.concatenate(val_labels, axis=0) 81 | 82 | lin_model = VectorRegression(LinearRegression()) 83 | lin_model.fit(train_embs, train_labels) 84 | 85 | train_score = lin_model.score(train_embs, train_labels) 86 | val_score = lin_model.score(val_embs, val_labels) 87 | 88 | return train_score, val_score 89 | 90 | 91 | def evaluate_phase_progression(train_data, val_data, action, ckpt_step, CONFIG, writer=None, verbose=False): 92 | 93 | train_embs = train_data['embs'] 94 | val_embs = val_data['embs'] 95 | num_classes = DATASET_TO_NUM_CLASSES[action] 96 | 97 | if not train_embs or not val_embs: 98 | raise Exception("All embeddings are NAN. 
Something is wrong with model.") 99 | 100 | val_labels = get_targets_from_labels(val_data['labels'], 101 | num_classes) 102 | 103 | 104 | num_samples = len(train_data['embs']) 105 | 106 | train_scores = [] 107 | val_scores = [] 108 | for fraction_used in CONFIG.EVAL.CLASSIFICATION_FRACTIONS: 109 | num_samples_used = max(1, int(fraction_used * num_samples)) 110 | train_embs = train_data['embs'][:num_samples_used] 111 | train_labels = get_targets_from_labels( 112 | train_data['labels'][:num_samples_used], num_classes) 113 | 114 | train_score, val_score = fit_model(train_embs, train_labels, val_embs, val_labels) 115 | 116 | if verbose: 117 | print('\n-----------------------------') 118 | print('Fraction: ', fraction_used) 119 | print('Train-Score: ', train_score) 120 | print('Val-Score: ', val_score) 121 | 122 | if writer: 123 | writer.add_scalar(f'phase_progression/train_{action}_{fraction_used}', train_score, global_step=ckpt_step) 124 | writer.add_scalar(f'phase_progression/val_{action}_{fraction_used}', val_score, global_step=ckpt_step) 125 | 126 | 127 | print(f'phase_progression/train_{action}_{fraction_used}', train_score, f"global_step={ckpt_step}") 128 | print(f'phase_progression/val_{action}_{fraction_used}', val_score, f"global_step={ckpt_step}") 129 | train_scores.append(train_score) 130 | val_scores.append(val_score) 131 | 132 | return train_scores, val_scores -------------------------------------------------------------------------------- /wget-log: -------------------------------------------------------------------------------- 1 | --2023-02-11 20:41:59-- https://doc-0k-6o-docs.googleusercontent.com/docs/securesc/a25jtu3utvilrqlj9eu9v7jl8vgupv35/icvkmhjaf18838idusj5ossoeehuumim/1676148075000/15729137890663927406/04461041260087607793/1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e=download 2 | Resolving doc-0k-6o-docs.googleusercontent.com (doc-0k-6o-docs.googleusercontent.com)... 142.251.163.132, 2607:f8b0:4004:c1b::84 3 | Connecting to doc-0k-6o-docs.googleusercontent.com (doc-0k-6o-docs.googleusercontent.com)|142.251.163.132|:443... connected. 4 | HTTP request sent, awaiting response... 302 Found 5 | Location: https://docs.google.com/nonceSigner?nonce=jnmn4dnmhu1vm&continue=https://doc-0k-6o-docs.googleusercontent.com/docs/securesc/a25jtu3utvilrqlj9eu9v7jl8vgupv35/icvkmhjaf18838idusj5ossoeehuumim/1676148075000/15729137890663927406/04461041260087607793/1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e%3Ddownload&hash=n2jrt6q4cqhcct2jf0b8qd63vh2q36ak [following] 6 | --2023-02-11 20:41:59-- https://docs.google.com/nonceSigner?nonce=jnmn4dnmhu1vm&continue=https://doc-0k-6o-docs.googleusercontent.com/docs/securesc/a25jtu3utvilrqlj9eu9v7jl8vgupv35/icvkmhjaf18838idusj5ossoeehuumim/1676148075000/15729137890663927406/04461041260087607793/1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e%3Ddownload&hash=n2jrt6q4cqhcct2jf0b8qd63vh2q36ak 7 | Resolving docs.google.com (docs.google.com)... 142.251.163.101, 142.251.163.113, 142.251.163.100, ... 8 | Connecting to docs.google.com (docs.google.com)|142.251.163.101|:443... connected. 9 | HTTP request sent, awaiting response... 
302 Found 10 | Location: https://accounts.google.com/ServiceLogin?service=wise&passive=1209600&continue=https://docs.google.com/nonceSigner?nonce%3Djnmn4dnmhu1vm%26continue%3Dhttps://doc-0k-6o-docs.googleusercontent.com/docs/securesc/a25jtu3utvilrqlj9eu9v7jl8vgupv35/icvkmhjaf18838idusj5ossoeehuumim/1676148075000/15729137890663927406/04461041260087607793/1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e%253Ddownload%26hash%3Dn2jrt6q4cqhcct2jf0b8qd63vh2q36ak&followup=https://docs.google.com/nonceSigner?nonce%3Djnmn4dnmhu1vm%26continue%3Dhttps://doc-0k-6o-docs.googleusercontent.com/docs/securesc/a25jtu3utvilrqlj9eu9v7jl8vgupv35/icvkmhjaf18838idusj5ossoeehuumim/1676148075000/15729137890663927406/04461041260087607793/1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e%253Ddownload%26hash%3Dn2jrt6q4cqhcct2jf0b8qd63vh2q36ak [following] 11 | --2023-02-11 20:41:59-- https://accounts.google.com/ServiceLogin?service=wise&passive=1209600&continue=https://docs.google.com/nonceSigner?nonce%3Djnmn4dnmhu1vm%26continue%3Dhttps://doc-0k-6o-docs.googleusercontent.com/docs/securesc/a25jtu3utvilrqlj9eu9v7jl8vgupv35/icvkmhjaf18838idusj5ossoeehuumim/1676148075000/15729137890663927406/04461041260087607793/1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e%253Ddownload%26hash%3Dn2jrt6q4cqhcct2jf0b8qd63vh2q36ak&followup=https://docs.google.com/nonceSigner?nonce%3Djnmn4dnmhu1vm%26continue%3Dhttps://doc-0k-6o-docs.googleusercontent.com/docs/securesc/a25jtu3utvilrqlj9eu9v7jl8vgupv35/icvkmhjaf18838idusj5ossoeehuumim/1676148075000/15729137890663927406/04461041260087607793/1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e%253Ddownload%26hash%3Dn2jrt6q4cqhcct2jf0b8qd63vh2q36ak 12 | Resolving accounts.google.com (accounts.google.com)... 142.251.163.84, 2607:f8b0:4004:c06::54 13 | Connecting to accounts.google.com (accounts.google.com)|142.251.163.84|:443... connected. 14 | HTTP request sent, awaiting response... 
302 Moved Temporarily 15 | Location: https://accounts.google.com/v3/signin/identifier?dsh=S1108005935%3A1676148119416376&continue=https%3A%2F%2Fdocs.google.com%2FnonceSigner%3Fnonce%3Djnmn4dnmhu1vm%26continue%3Dhttps%3A%2F%2Fdoc-0k-6o-docs.googleusercontent.com%2Fdocs%2Fsecuresc%2Fa25jtu3utvilrqlj9eu9v7jl8vgupv35%2Ficvkmhjaf18838idusj5ossoeehuumim%2F1676148075000%2F15729137890663927406%2F04461041260087607793%2F1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6%3Fe%253Ddownload%26hash%3Dn2jrt6q4cqhcct2jf0b8qd63vh2q36ak&followup=https%3A%2F%2Fdocs.google.com%2FnonceSigner%3Fnonce%3Djnmn4dnmhu1vm%26continue%3Dhttps%3A%2F%2Fdoc-0k-6o-docs.googleusercontent.com%2Fdocs%2Fsecuresc%2Fa25jtu3utvilrqlj9eu9v7jl8vgupv35%2Ficvkmhjaf18838idusj5ossoeehuumim%2F1676148075000%2F15729137890663927406%2F04461041260087607793%2F1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6%3Fe%253Ddownload%26hash%3Dn2jrt6q4cqhcct2jf0b8qd63vh2q36ak&passive=1209600&service=wise&flowName=WebLiteSignIn&flowEntry=ServiceLogin&ifkv=AWnogHcw4HkRvG-fbk3EPpHhWBAIl1ZUUc1VUv2nT3y6o5Jy5nyM-yr13s-W9anozC1S3E598mKiBw [following] 16 | --2023-02-11 20:41:59-- https://accounts.google.com/v3/signin/identifier?dsh=S1108005935%3A1676148119416376&continue=https%3A%2F%2Fdocs.google.com%2FnonceSigner%3Fnonce%3Djnmn4dnmhu1vm%26continue%3Dhttps%3A%2F%2Fdoc-0k-6o-docs.googleusercontent.com%2Fdocs%2Fsecuresc%2Fa25jtu3utvilrqlj9eu9v7jl8vgupv35%2Ficvkmhjaf18838idusj5ossoeehuumim%2F1676148075000%2F15729137890663927406%2F04461041260087607793%2F1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6%3Fe%253Ddownload%26hash%3Dn2jrt6q4cqhcct2jf0b8qd63vh2q36ak&followup=https%3A%2F%2Fdocs.google.com%2FnonceSigner%3Fnonce%3Djnmn4dnmhu1vm%26continue%3Dhttps%3A%2F%2Fdoc-0k-6o-docs.googleusercontent.com%2Fdocs%2Fsecuresc%2Fa25jtu3utvilrqlj9eu9v7jl8vgupv35%2Ficvkmhjaf18838idusj5ossoeehuumim%2F1676148075000%2F15729137890663927406%2F04461041260087607793%2F1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6%3Fe%253Ddownload%26hash%3Dn2jrt6q4cqhcct2jf0b8qd63vh2q36ak&passive=1209600&service=wise&flowName=WebLiteSignIn&flowEntry=ServiceLogin&ifkv=AWnogHcw4HkRvG-fbk3EPpHhWBAIl1ZUUc1VUv2nT3y6o5Jy5nyM-yr13s-W9anozC1S3E598mKiBw 17 | Reusing existing connection to accounts.google.com:443. 18 | HTTP request sent, awaiting response... 
200 OK 19 | Length: unspecified [text/html] 20 | Saving to: ‘1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e=download’ 21 | 22 | 1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e=download [<=> ] 0 --.-KB/s 1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e=download [ <=> ] 142.04K --.-KB/s in 0.008s 23 | 24 | 2023-02-11 20:41:59 (16.7 MB/s) - ‘1eCbrIuw--16xCmI3RtBhRJ-r9K_FVkL6?e=download’ saved [145452] 25 | 26 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | CONFIG = edict() 4 | 5 | CONFIG.DESCRIPTION = 'Exp 1' 6 | 7 | CONFIG.ROOT = '/home/ubuntu/Data' 8 | CONFIG.CKPT_PATH = '/home/ubuntu/Data/LOGS/CKPTS/' 9 | CONFIG.DATA_PATH = '/home/ubuntu/Data' 10 | CONFIG.GPUS = 1 11 | CONFIG.SEED = 0 12 | 13 | CONFIG.TRAIN = edict() 14 | 15 | CONFIG.TRAIN.NUM_FRAMES = 20 16 | CONFIG.TRAIN.EPOCHS = 100 17 | CONFIG.TRAIN.SAVE_INTERVAL_ITERS = 1000 18 | CONFIG.TRAIN.LR = 1e-4 19 | CONFIG.TRAIN.WEIGHT_DECAY = 1e-5 20 | CONFIG.TRAIN.BATCH_SIZE = 1 21 | 22 | CONFIG.TRAIN.FREEZE_BASE = False 23 | CONFIG.TRAIN.FREEZE_BN_ONLY = False 24 | 25 | # CONFIG.TRAIN.VAL_PERCENT = 0.1 26 | CONFIG.TRAIN.VAL_PERCENT = 1 27 | 28 | CONFIG.EVAL = edict() 29 | CONFIG.EVAL.NUM_FRAMES = 20 30 | 31 | CONFIG.EVAL.CLASSIFICATION_FRACTIONS = [0.1, 0.5, 1.0] 32 | CONFIG.EVAL.KENDALLS_TAU_DISTANCE = 'sqeuclidean' # cosine or sqeuclidean 33 | 34 | # DTW Alignment Parameters 35 | CONFIG.DTWALIGNMENT = edict() 36 | CONFIG.DTWALIGNMENT.EMBEDDING_SIZE = 128 37 | CONFIG.DTWALIGNMENT.SDTW_GAMMA = 0.1 38 | CONFIG.DTWALIGNMENT.SDTW_NORMALIZE = False 39 | 40 | CONFIG.LOSSES = edict() 41 | CONFIG.LOSSES.IDM_IDX_MARGIN = 2. 42 | CONFIG.LOSSES.ALPHA = 0.5 43 | CONFIG.LOSSES.SIGMA = 10 # window size 44 | CONFIG.LOSSES.L2_NORMALIZE = True 45 | 46 | # TCC Parameters 47 | CONFIG.TCC = edict() 48 | CONFIG.TCC.EMBEDDING_SIZE = 128 49 | CONFIG.TCC.CYCLE_LENGTH = 2 50 | CONFIG.TCC.LABEL_SMOOTHING = 0.1 51 | CONFIG.TCC.SOFTMAX_TEMPERATURE = 0.1 52 | CONFIG.TCC.LOSS_TYPE = 'regression_mse_var' 53 | CONFIG.TCC.NORMALIZE_INDICES = True 54 | CONFIG.TCC.VARIANCE_LAMBDA = 0.001 55 | CONFIG.TCC.FRACTION = 1.0 56 | CONFIG.TCC.HUBER_DELTA = 0.1 57 | CONFIG.TCC.SIMILARITY_TYPE = 'l2' # l2, cosine 58 | 59 | CONFIG.TCC.TCC_LAMBDA = 1.0 60 | 61 | CONFIG.DATA = edict() 62 | 63 | CONFIG.DATA.IMAGE_SIZE = 224 # For ResNet50 64 | 65 | CONFIG.DATA.SHUFFLE_QUEUE_SIZE = 0 66 | CONFIG.DATA.NUM_PREFETCH_BATCHES = 1 67 | CONFIG.DATA.RANDOM_OFFSET = 1 68 | CONFIG.DATA.FRAME_STRIDE = 16 69 | CONFIG.DATA.SAMPLING_STRATEGY = 'segment_uniform' # offset_uniform, stride, all, segment_uniform 70 | CONFIG.DATA.NUM_CONTEXT = 2 # number of frames that will be embedded jointly, 71 | CONFIG.DATA.CONTEXT_STRIDE = 15 # stride between context frames 72 | 73 | CONFIG.DATA.FRAME_LABELS = True 74 | CONFIG.DATA.PER_DATASET_FRACTION = 1.0 # Use 0 to use only one sample. 75 | CONFIG.DATA.PER_CLASS = False 76 | 77 | # stride of frames while embedding a video during evaluation. 
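# Worked example for the two context settings above (see get_steps_with_context in
# align_dataset.py): with NUM_CONTEXT = 2 and CONTEXT_STRIDE = 15, every sampled step i
# is expanded to the pair (max(0, i - 15), i), so the TRAIN.NUM_FRAMES = 20 steps drawn
# from a sequence become 40 frames that are decoded and embedded jointly.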
78 | CONFIG.DATA.SAMPLE_ALL_STRIDE = 1 79 | 80 | # CONFIG.DATA.TCN = CONFIG.TCN 81 | CONFIG.DATA.WORKERS = 30 82 | 83 | # ****************************************************************************** 84 | # Augmentation params 85 | # ****************************************************************************** 86 | CONFIG.AUGMENTATION = edict() 87 | CONFIG.AUGMENTATION.RANDOM_FLIP = True 88 | CONFIG.AUGMENTATION.RANDOM_CROP = False 89 | CONFIG.AUGMENTATION.BRIGHTNESS_DELTA = 32.0 / 255 # 0 to turn off 90 | CONFIG.AUGMENTATION.CONTRAST_DELTA = 0.5 # 0 to turn off 91 | CONFIG.AUGMENTATION.HUE_DELTA = 0. # 0 to turn off 92 | CONFIG.AUGMENTATION.SATURATION_DELTA = 0. # 0 to turn off 93 | from easydict import EasyDict as edict 94 | 95 | CONFIG = edict() 96 | 97 | CONFIG.DESCRIPTION = 'Exp 1' 98 | 99 | CONFIG.ROOT = '/home/ubuntu/Data' 100 | CONFIG.CKPT_PATH = '/home/ubuntu/Data/LOGS/CKPTS/' 101 | CONFIG.DATA_PATH = '/home/ubuntu/Data' 102 | CONFIG.GPUS = 1 103 | CONFIG.SEED = 0 104 | 105 | CONFIG.TRAIN = edict() 106 | 107 | CONFIG.TRAIN.NUM_FRAMES = 20 108 | CONFIG.TRAIN.EPOCHS = 100 109 | CONFIG.TRAIN.SAVE_INTERVAL_ITERS = 1000 110 | CONFIG.TRAIN.LR = 1e-4 111 | CONFIG.TRAIN.WEIGHT_DECAY = 1e-5 112 | CONFIG.TRAIN.BATCH_SIZE = 1 113 | 114 | CONFIG.TRAIN.FREEZE_BASE = False 115 | CONFIG.TRAIN.FREEZE_BN_ONLY = False 116 | 117 | # CONFIG.TRAIN.VAL_PERCENT = 0.1 118 | CONFIG.TRAIN.VAL_PERCENT = 1 119 | 120 | CONFIG.EVAL = edict() 121 | CONFIG.EVAL.NUM_FRAMES = 20 122 | 123 | CONFIG.EVAL.CLASSIFICATION_FRACTIONS = [0.1, 0.5, 1.0] 124 | CONFIG.EVAL.KENDALLS_TAU_DISTANCE = 'sqeuclidean' # cosine or sqeuclidean 125 | 126 | # DTW Alignment Parameters 127 | CONFIG.DTWALIGNMENT = edict() 128 | CONFIG.DTWALIGNMENT.EMBEDDING_SIZE = 128 129 | CONFIG.DTWALIGNMENT.SDTW_GAMMA = 0.1 130 | CONFIG.DTWALIGNMENT.SDTW_NORMALIZE = False 131 | 132 | CONFIG.LOSSES = edict() 133 | CONFIG.LOSSES.IDM_IDX_MARGIN = 2. 134 | CONFIG.LOSSES.ALPHA = 0.5 135 | CONFIG.LOSSES.SIGMA = 10 # window size 136 | CONFIG.LOSSES.L2_NORMALIZE = True 137 | 138 | # TCC Parameters 139 | CONFIG.TCC = edict() 140 | CONFIG.TCC.EMBEDDING_SIZE = 128 141 | CONFIG.TCC.CYCLE_LENGTH = 2 142 | CONFIG.TCC.LABEL_SMOOTHING = 0.1 143 | CONFIG.TCC.SOFTMAX_TEMPERATURE = 0.1 144 | CONFIG.TCC.LOSS_TYPE = 'regression_mse_var' 145 | CONFIG.TCC.NORMALIZE_INDICES = True 146 | CONFIG.TCC.VARIANCE_LAMBDA = 0.001 147 | CONFIG.TCC.FRACTION = 1.0 148 | CONFIG.TCC.HUBER_DELTA = 0.1 149 | CONFIG.TCC.SIMILARITY_TYPE = 'l2' # l2, cosine 150 | 151 | CONFIG.TCC.TCC_LAMBDA = 1.0 152 | 153 | CONFIG.DATA = edict() 154 | 155 | CONFIG.DATA.IMAGE_SIZE = 224 # For ResNet50 156 | 157 | CONFIG.DATA.SHUFFLE_QUEUE_SIZE = 0 158 | CONFIG.DATA.NUM_PREFETCH_BATCHES = 1 159 | CONFIG.DATA.RANDOM_OFFSET = 1 160 | CONFIG.DATA.FRAME_STRIDE = 16 161 | CONFIG.DATA.SAMPLING_STRATEGY = 'segment_uniform' # offset_uniform, stride, all, segment_uniform 162 | CONFIG.DATA.NUM_CONTEXT = 2 # number of frames that will be embedded jointly, 163 | CONFIG.DATA.CONTEXT_STRIDE = 15 # stride between context frames 164 | 165 | CONFIG.DATA.FRAME_LABELS = True 166 | CONFIG.DATA.PER_DATASET_FRACTION = 1.0 # Use 0 to use only one sample. 167 | CONFIG.DATA.PER_CLASS = False 168 | 169 | # stride of frames while embedding a video during evaluation. 
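# Worked example for SAMPLING_STRATEGY = 'segment_uniform' above (see sample_frames in
# align_dataset.py), assuming a hypothetical 90-frame sequence and NUM_FRAMES = 20:
# the index list is padded to 100 entries by appending the first 10 indices again,
# split into 20 equal segments of 5, and one index is drawn uniformly at random from
# each segment; the chosen indices are then sorted, giving 20 roughly evenly spaced frames.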
170 | CONFIG.DATA.SAMPLE_ALL_STRIDE = 1 171 | 172 | # CONFIG.DATA.TCN = CONFIG.TCN 173 | CONFIG.DATA.WORKERS = 30 174 | 175 | # ****************************************************************************** 176 | # Augmentation params 177 | # ****************************************************************************** 178 | CONFIG.AUGMENTATION = edict() 179 | CONFIG.AUGMENTATION.RANDOM_FLIP = True 180 | CONFIG.AUGMENTATION.RANDOM_CROP = False 181 | CONFIG.AUGMENTATION.BRIGHTNESS_DELTA = 32.0 / 255 # 0 to turn off 182 | CONFIG.AUGMENTATION.CONTRAST_DELTA = 0.5 # 0 to turn off 183 | CONFIG.AUGMENTATION.HUE_DELTA = 0. # 0 to turn off 184 | CONFIG.AUGMENTATION.SATURATION_DELTA = 0. # 0 to turn off 185 | -------------------------------------------------------------------------------- /align_dataset.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import numpy as np 3 | import random 4 | import utils 5 | 6 | from torch.utils.data import Dataset 7 | 8 | def get_steps_with_context(steps, num_context, context_stride): 9 | _context = np.arange(num_context-1, -1, -1) 10 | context_steps = np.maximum(0, steps[:, None] - _context * context_stride) 11 | return context_steps.reshape(-1) 12 | 13 | def sample_frames(frames, num_frames, num_context, frame_stride, 14 | sampling='offset_uniform', random_offset=1, context_stride=15, is_tcn=False, tcn_window=5): 15 | 16 | seq_len = len(frames) 17 | 18 | if sampling == 'stride': 19 | 20 | offset = random.randint(0, max(1, seq_len - frame_stride * num_frames)-1) 21 | steps = np.arange(offset, offset + frame_stride * num_frames + 1, frame_stride) 22 | # cap at max length 23 | steps = np.minimum(steps, seq_len-1) 24 | steps = steps[:num_frames] 25 | 26 | elif sampling == 'offset_uniform': 27 | 28 | def _sample_random(offset): 29 | assert offset <= seq_len, "Offset is greater than the Sequence length" 30 | steps = np.arange(offset, seq_len) 31 | random.shuffle(steps) 32 | steps = steps[:num_frames] 33 | steps = np.sort(steps) 34 | return steps 35 | 36 | def _sample_all(): 37 | return np.arange(0, num_frames) 38 | 39 | if num_frames < seq_len - random_offset: 40 | steps = _sample_random(random_offset) 41 | else: 42 | steps = _sample_all() 43 | elif sampling == 'segment_uniform': 44 | 45 | if num_frames > seq_len: 46 | steps = np.arange(num_frames) 47 | else: 48 | r = num_frames - seq_len % num_frames 49 | 50 | if r < num_frames: 51 | steps = np.concatenate([np.arange(seq_len), np.arange(r)]) 52 | else: 53 | steps = np.arange(seq_len) 54 | f = len(steps) / num_frames 55 | 56 | sampled_idxes = np.arange(num_frames) * f + np.array(random.choices(range(np.int(f)), k=num_frames)) 57 | sampled_idxes = sampled_idxes.astype(np.int32) 58 | 59 | steps = np.sort(steps[sampled_idxes]) 60 | 61 | elif sampling == 'all': 62 | steps = np.arange(0, seq_len) 63 | else: 64 | raise Exception("{} not implemented.".format(sampling)) 65 | 66 | if is_tcn: 67 | pos_steps = steps - np.array(random.choices(range(1, tcn_window + 1), k=len(num_frames))) 68 | steps = np.stack([pos_steps, steps]) 69 | steps = steps.T.reshape((-1, )) 70 | 71 | steps = np.minimum(steps, seq_len-1) 72 | chosen_steps = steps.astype(np.float32) / seq_len 73 | steps = get_steps_with_context(steps, num_context, context_stride) 74 | 75 | frames = np.array(frames)[steps] 76 | 77 | return frames, chosen_steps, float(seq_len) 78 | 79 | 80 | class AlignData(Dataset): 81 | 82 | def __init__(self, path, num_frames, data_config, neg_example=False, transform=False, 
flatten=False): 83 | 84 | self.act_sequences = sorted(glob.glob(os.path.join(path, '*'))) 85 | self.n_sequences = len(self.act_sequences) 86 | 87 | self.n_classes = len(self.act_sequences) 88 | self.num_frames = num_frames 89 | self.config = data_config 90 | 91 | self.neg_example = neg_example 92 | 93 | if transform: 94 | self.transform = transform 95 | else: 96 | self.transform = utils.get_totensor_transform(is_video=True) 97 | 98 | self.flatten = flatten 99 | 100 | def __len__(self): 101 | return self.n_sequences 102 | 103 | def __getitem__(self, idx): 104 | 105 | a = self.act_sequences[idx] 106 | 107 | b = a 108 | while a == b: 109 | b = random.choice(self.act_sequences) 110 | 111 | assert a != b, "Same sequences sampled!" 112 | 113 | config = self.config 114 | get_frame_paths = lambda x : sorted(glob.glob(os.path.join(x, '*'))) 115 | 116 | a_frames = get_frame_paths(a) 117 | # a_frames, a_chosen_steps, a_seq_len = sample_frames(a_frames, num_frames=self.num_frames, num_context=config.NUM_CONTEXT, 118 | # frame_stride=config.FRAME_STRIDE, sampling=config.SAMPLING_STRATEGY, 119 | # random_offset=config.RANDOM_OFFSET, context_stride=config.CONTEXT_STRIDE, 120 | # is_tcn=config.TCN.IS_TCN, tcn_window=config.TCN.POS_WINDOW) 121 | 122 | a_frames, a_chosen_steps, a_seq_len = sample_frames(a_frames, num_frames=self.num_frames, num_context=config.NUM_CONTEXT, 123 | frame_stride=config.FRAME_STRIDE, sampling=config.SAMPLING_STRATEGY, 124 | random_offset=config.RANDOM_OFFSET, context_stride=config.CONTEXT_STRIDE) 125 | 126 | b_frames = get_frame_paths(b) 127 | # b_frames, b_chosen_steps, b_seq_len = sample_frames(b_frames, num_frames=self.num_frames, num_context=config.NUM_CONTEXT, 128 | # frame_stride=config.FRAME_STRIDE, sampling=config.SAMPLING_STRATEGY, 129 | # random_offset=config.RANDOM_OFFSET, context_stride=config.CONTEXT_STRIDE, 130 | # is_tcn=config.TCN.IS_TCN, tcn_window=config.TCN.POS_WINDOW) 131 | 132 | b_frames, b_chosen_steps, b_seq_len = sample_frames(b_frames, num_frames=self.num_frames, num_context=config.NUM_CONTEXT, 133 | frame_stride=config.FRAME_STRIDE, sampling=config.SAMPLING_STRATEGY, 134 | random_offset=config.RANDOM_OFFSET, context_stride=config.CONTEXT_STRIDE) 135 | 136 | a_x = utils.get_pil_images(a_frames) 137 | b_x = utils.get_pil_images(b_frames) 138 | 139 | a_x = self.transform(a_x) 140 | b_x = self.transform(b_x) 141 | 142 | a_name = 'Vid_{}'.format(os.path.basename(a)) 143 | b_name = 'Vid_{}'.format(os.path.basename(b)) 144 | 145 | result = [[a_x, a_name, a_chosen_steps, a_seq_len], [b_x, b_name, b_chosen_steps, b_seq_len]] 146 | 147 | if self.flatten: 148 | for item in result: 149 | item[0] = item[0].view((item[0].shape[0], -1)) 150 | 151 | return result -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import utils 6 | from soft_dtw import SoftDTW 7 | 8 | def calc_distance_matrix(x, y): 9 | n = x.size(1) 10 | m = y.size(1) 11 | d = x.size(2) 12 | x = x.unsqueeze(2).expand(-1, n, m, d) 13 | y = y.unsqueeze(1).expand(-1, n, m, d) 14 | dist = torch.pow(x - y, 2).sum(3) 15 | return dist 16 | 17 | 18 | class Contrastive_IDM(nn.Module): 19 | 20 | def __init__(self, sigma, margin, debug=False): 21 | super(Contrastive_IDM, self).__init__() 22 | 23 | self.sigma = sigma 24 | self.margin = margin 25 | self.debug = debug 26 | 27 | def forward(self, dist, idx, 
seq_len, logger=None): 28 | 29 | grid_x, grid_y = torch.meshgrid(idx, idx) 30 | 31 | prob = F.relu(self.margin - dist) 32 | 33 | weights_orig = 1 + torch.pow(grid_x - grid_y, 2) 34 | 35 | diff = torch.abs(grid_x - grid_y) - (self.sigma / seq_len) 36 | 37 | _ones = torch.ones_like(diff) 38 | _zeros = torch.zeros_like(diff) 39 | weights_neg = torch.where(diff > 0, weights_orig, _zeros) 40 | 41 | weights_pos = torch.where(diff > 0, _zeros, _ones) 42 | 43 | if not self.training and self.debug and logger: 44 | logger.experiment.add_image('idm_diff', utils.plot_to_image(diff), 0, dataformats='CHW') 45 | logger.experiment.add_image('idm_weights_pos', utils.plot_to_image(weights_pos), 0, dataformats='CHW') 46 | logger.experiment.add_image('idm_weights_neg', utils.plot_to_image(weights_neg), 0, dataformats='CHW') 47 | logger.experiment.add_image('idm_prob', utils.plot_to_image(prob), 0, dataformats='CHW') 48 | 49 | idm = weights_neg * prob + weights_pos * dist 50 | 51 | return torch.sum(idm), idm 52 | 53 | class LAV(nn.Module): 54 | 55 | def __init__(self, alpha, sigma, margin, num_frames, dtw_gamma, dtw_normalize, debug=False): 56 | super(LAV, self).__init__() 57 | 58 | self.alpha = alpha 59 | self.debug = debug 60 | self.N = num_frames 61 | 62 | self.dtw_loss = SoftDTW(gamma=dtw_gamma, normalize=dtw_normalize) 63 | 64 | self.inverse_idm = Contrastive_IDM(sigma=sigma, margin=margin, debug=debug) 65 | 66 | def forward(self, a_emb, b_emb, a_idx, b_idx, a_len, b_len, logger=None): 67 | 68 | pos_loss = self.dtw_loss(a_emb, b_emb) 69 | 70 | # frame level loss 71 | dist_a = calc_distance_matrix(a_emb, a_emb).squeeze(0) 72 | dist_b = calc_distance_matrix(b_emb, b_emb).squeeze(0) 73 | 74 | idm_a, _ = self.inverse_idm(dist_a, a_idx, a_len, logger=logger) 75 | idm_b, _ = self.inverse_idm(dist_b, b_idx, b_len, logger=logger) 76 | 77 | total_loss = pos_loss + self.alpha * (idm_a + idm_b) 78 | total_loss = total_loss / self.N 79 | 80 | if not self.training and self.debug and logger: 81 | logger.experiment.add_image('dist_a', utils.plot_to_image(dist_a), 0, dataformats='CHW') 82 | logger.experiment.add_image('dist_b', utils.plot_to_image(dist_b), 0, dataformats='CHW') 83 | 84 | return total_loss 85 | 86 | class TCC(nn.Module): 87 | 88 | def __init__(self, channels, temperature, var_lambda, debug=False): 89 | super(TCC, self).__init__() 90 | 91 | self.debug = debug 92 | self.channels = channels 93 | self.temperature = temperature 94 | self.var_lambda = var_lambda 95 | 96 | def _pairwise_distance(self, x, y): 97 | x = x.unsqueeze(1) 98 | y = y.unsqueeze(0) 99 | dist = torch.pow(x - y, 2).sum(2) 100 | return dist 101 | 102 | def _get_scaled_similarity(self, emb_a, emb_b): 103 | 104 | sim = -1. * self._pairwise_distance(emb_a, emb_b) 105 | 106 | scaled_similarity = sim / self.channels 107 | scaled_similarity = scaled_similarity / self.temperature 108 | 109 | return scaled_similarity 110 | 111 | def _tcc_loss(self, emb_a, emb_b, idxes): 112 | 113 | sim_ab = self._get_scaled_similarity(emb_a, emb_b) 114 | softmaxed_sim_ab = F.softmax(sim_ab, dim=-1) 115 | 116 | soft_nn = torch.matmul(softmaxed_sim_ab, emb_b) 117 | 118 | sim_ba = self._get_scaled_similarity(soft_nn, emb_a) 119 | 120 | labels = idxes 121 | 122 | beta = F.softmax(sim_ba, dim=-1) 123 | preds = torch.sum(beta * idxes, dim=-1, keepdim=True) 124 | 125 | pred_var = torch.sum(torch.pow(idxes - preds, 2.) * beta, axis=1) 126 | pred_var_log = torch.log(pred_var) 127 | 128 | squared_error = torch.pow(labels.squeeze() - preds.squeeze(), 2.) 
129 | return torch.sum(torch.exp(-pred_var_log) * squared_error + self.var_lambda * pred_var_log) 130 | 131 | 132 | def forward(self, emb_a, emb_b, idx_a, idx_b, logger=None): 133 | 134 | emb_a = emb_a.squeeze(0) 135 | emb_b = emb_b.squeeze(0) 136 | 137 | loss_ab = self._tcc_loss(emb_a, emb_b, idx_a) 138 | loss_ba = self._tcc_loss(emb_b, emb_a, idx_b) 139 | 140 | return (loss_ab + loss_ba) / (emb_a.size(0) + emb_b.size(0)) 141 | 142 | 143 | class TCN(nn.Module): 144 | 145 | def __init__(self, reg_lambda=0.002): 146 | super(TCN, self).__init__() 147 | 148 | self.reg_lambda = reg_lambda 149 | 150 | def _npairs_loss(self, labels, embeddings_anchor, embeddings_positive): 151 | """Returns n-pairs metric loss.""" 152 | square = lambda x : torch.pow(x, 2) 153 | reg_anchor = torch.mean(torch.sum(square(embeddings_anchor), 1)) 154 | reg_positive = torch.mean(torch.sum( 155 | square(embeddings_positive), 1)) 156 | l2loss = 0.25 * self.reg_lambda * (reg_anchor + reg_positive) 157 | 158 | # Get per pair similarities. 159 | similarity_matrix = torch.matmul( 160 | embeddings_anchor, embeddings_positive.t()) 161 | 162 | # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. 163 | lshape = labels.shape 164 | 165 | # Add the softmax loss. 166 | xent_loss = F.cross_entropy( 167 | input=similarity_matrix, target=labels) 168 | #xent_loss = tf.reduce_mean(xent_loss) 169 | 170 | return l2loss + xent_loss 171 | 172 | 173 | def single_sequence_loss(self, embs, num_steps): 174 | """Returns n-pairs loss for a single sequence.""" 175 | 176 | labels = torch.arange(num_steps) 177 | embeddings_anchor = embs[0::2] 178 | embeddings_positive = embs[1::2] 179 | loss = self._npairs_loss(labels, embeddings_anchor, embeddings_positive) 180 | return loss 181 | 182 | def forward(self, embs, num_steps): 183 | return self.single_sequence_loss(embs.squeeze(0), num_steps) -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision.models as models 6 | 7 | class ConvEmbedder(nn.Module): 8 | 9 | def __init__(self, emb_size=128, l2_normalize=False): 10 | super(ConvEmbedder, self).__init__() 11 | 12 | self.emb_size = emb_size 13 | self.l2_normalize = l2_normalize 14 | 15 | self.conv1 = nn.Conv3d(in_channels=1024, out_channels=512, kernel_size=3, padding=1) 16 | self.bn1 = nn.BatchNorm1d(512) 17 | 18 | self.conv2 = nn.Conv3d(in_channels=512, out_channels=512, kernel_size=3, padding=1) 19 | self.bn2 = nn.BatchNorm1d(512) 20 | 21 | self.fc1 = nn.Linear(512, 512) 22 | self.dropout1 = nn.Dropout(0.1) 23 | 24 | self.fc2 = nn.Linear(512, 512) 25 | self.dropout2 = nn.Dropout(0.1) 26 | 27 | self.embedding_layer = nn.Linear(512, emb_size) 28 | 29 | def apply_bn(self, bn, x): 30 | N, C, T, H, W = x.shape 31 | x = x.permute(0, 2, 3, 4, 1) 32 | x = torch.reshape(x, (-1, x.shape[-1])) 33 | x = bn(x) 34 | x = torch.reshape(x, (N, T, H, W, C)) 35 | x = x.permute(0, 4, 1, 2, 3) 36 | return x 37 | 38 | def forward(self, x, num_frames): 39 | 40 | batch_size, total_num_steps, c, h, w = x.shape 41 | num_context = total_num_steps // num_frames 42 | x = torch.reshape(x, (batch_size * num_frames, num_context, c, h, w)) 43 | 44 | # TxCxHxW -> CxTxHxW 45 | x = x.transpose(1, 2) 46 | 47 | x = self.conv1(x) 48 | 49 | x = self.apply_bn(self.bn1, x) 50 | x = F.relu(x) 51 | 52 | x = self.conv2(x) 53 | x = self.apply_bn(self.bn2, x) 54 
| x = F.relu(x) 55 | 56 | x = torch.max(x.view(x.size(0), x.size(1), -1), dim=-1)[0] 57 | x = self.dropout1(x) 58 | x = self.fc1(x) 59 | x = F.relu(x) 60 | 61 | x = self.dropout2(x) 62 | x = self.fc2(x) 63 | x = F.relu(x) 64 | 65 | x = self.embedding_layer(x) 66 | 67 | if self.l2_normalize: 68 | x = F.normalize(x, p=2, dim=-1) 69 | 70 | x = torch.reshape(x, (batch_size, num_frames, self.emb_size)) 71 | return x 72 | 73 | class Identity(nn.Module): 74 | def __init__(self): 75 | super(Identity, self).__init__() 76 | def forward(self, x): 77 | return x 78 | 79 | class BaseModel(nn.Module): 80 | 81 | def __init__(self, pretrained=True): 82 | super(BaseModel, self).__init__() 83 | 84 | resnet = models.resnet50(pretrained=pretrained) 85 | layers = list(resnet.children())[:-3] 86 | layers[-1] = nn.Sequential(*list(layers[-1].children())[:-3]) 87 | self.base_model = nn.Sequential(*layers) 88 | 89 | def forward(self, x): 90 | 91 | batch_size, num_steps, c, h, w = x.shape 92 | x = torch.reshape(x, [batch_size * num_steps, c, h, w]) 93 | 94 | x = self.base_model(x) 95 | 96 | _, c, h, w = x.shape 97 | x = torch.reshape(x, [batch_size, num_steps, c, h, w]) 98 | 99 | return x 100 | 101 | class BaseVGGM(nn.Module): 102 | 103 | def __init__(self): 104 | super(BaseVGGM, self).__init__() 105 | 106 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 107 | bias=False) 108 | 109 | self.bn1 = nn.BatchNorm2d(64) 110 | self.relu = nn.ReLU(inplace=True) 111 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 112 | 113 | self.conv_block1 = nn.Sequential( 114 | nn.Conv2d(64, 128, 3, padding=1), 115 | nn.BatchNorm2d(128), 116 | nn.ReLU(), 117 | nn.Conv2d(128, 128, 3, padding=1), 118 | nn.BatchNorm2d(128), 119 | nn.ReLU(), 120 | nn.MaxPool2d(2, 2) 121 | ) 122 | 123 | self.conv_block2 = nn.Sequential( 124 | nn.Conv2d(128, 256, 3, padding=1), 125 | nn.BatchNorm2d(256), 126 | nn.ReLU(), 127 | nn.Conv2d(256, 256, 3, padding=1), 128 | nn.BatchNorm2d(256), 129 | nn.ReLU(), 130 | nn.MaxPool2d(2, 2) 131 | ) 132 | 133 | self.conv_block3 = nn.Sequential( 134 | nn.Conv2d(256, 512, 3, padding=1), 135 | nn.BatchNorm2d(512), 136 | nn.ReLU(), 137 | nn.Conv2d(512, 512, 3, padding=1), 138 | nn.BatchNorm2d(512), 139 | nn.ReLU(), 140 | ) 141 | 142 | def forward(self, x): 143 | 144 | batch_size, num_steps, c, h, w = x.shape 145 | x = torch.reshape(x, [batch_size * num_steps, c, h, w]) 146 | 147 | x = self.conv1(x) 148 | x = self.bn1(x) 149 | x = self.relu(x) 150 | x = self.maxpool(x) 151 | 152 | x = self.conv_block1(x) 153 | x = self.conv_block2(x) 154 | x = self.conv_block3(x) 155 | 156 | _, c, h, w = x.shape 157 | x = torch.reshape(x, [batch_size, num_steps, c, h, w]) 158 | 159 | return x 160 | 161 | class Classifier(nn.Module): 162 | 163 | def __init__(self, input_size, num_classes): 164 | super(Classifier, self).__init__() 165 | 166 | self.input_size = input_size 167 | self.fc = nn.Linear(input_size, num_classes) 168 | 169 | def forward(self, x): 170 | x = torch.reshape(x, (-1, self.input_size)) 171 | x = self.fc(x) 172 | return x 173 | 174 | class CharBaseNet(nn.Module): 175 | def __init__(self, emb_size=64): 176 | super(CharBaseNet, self).__init__() 177 | 178 | self.fc1 = nn.Linear(64*64, 512) 179 | self.fc2 = nn.Linear(512, 256) 180 | self.fc3 = nn.Linear(256, emb_size) 181 | 182 | def forward(self, x): 183 | 184 | batch_size, num_steps, in_size = x.shape 185 | x = torch.reshape(x, [batch_size * num_steps, in_size]) 186 | 187 | x = self.fc1(x) 188 | x = F.relu(x) 189 | 190 | x = self.fc2(x) 191 | x = 
F.relu(x) 192 | 193 | x = self.fc3(x) 194 | 195 | x = torch.reshape(x, [batch_size, num_steps, -1]) 196 | return x 197 | 198 | class CharEmbedder(nn.Module): 199 | def __init__(self, in_size=64, emb_size=64, num_context=2, l2_normalize=True): 200 | super(CharEmbedder, self).__init__() 201 | 202 | self.fc1 = nn.Linear(in_size*num_context, in_size) 203 | self.fc2 = nn.Linear(in_size, emb_size) 204 | 205 | self.l2_normalize = l2_normalize 206 | 207 | def forward(self, x, num_frames): 208 | 209 | batch_size, total_num_steps, emb_size = x.shape 210 | num_context = total_num_steps // num_frames 211 | x = torch.reshape(x, (batch_size * num_frames, num_context * emb_size)) 212 | 213 | x = self.fc1(x) 214 | x = F.relu(x) 215 | 216 | x = self.fc2(x) 217 | 218 | if self.l2_normalize: 219 | x = F.normalize(x, p=2, dim=-1) 220 | 221 | x = torch.reshape(x, (batch_size, num_frames, emb_size)) 222 | return x 223 | 224 | -------------------------------------------------------------------------------- /visualize_alignment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from imageio import imread 4 | import cv2 5 | 6 | import utils 7 | import align_dataset_test as align_dataset 8 | from config import CONFIG 9 | 10 | from train import AlignNet 11 | 12 | import matplotlib 13 | matplotlib.use("Agg") 14 | 15 | import logging 16 | logger = logging.getLogger('matplotlib') 17 | logger.setLevel(logging.INFO) 18 | 19 | import matplotlib.pyplot as plt 20 | import matplotlib.animation as animation 21 | plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg' 22 | 23 | import random 24 | import itertools 25 | import argparse 26 | 27 | from numpy import array, zeros, full, argmin, inf, ndim 28 | from scipy.spatial.distance import cdist 29 | from math import isinf 30 | 31 | from dtw import dtw 32 | 33 | def dist_fn(x, y): 34 | dist = np.sum((x-y)**2) 35 | return dist 36 | 37 | 38 | def get_nn(embs, query_emb): 39 | dist = np.linalg.norm(embs - query_emb, axis=1) 40 | assert len(dist) == len(embs) 41 | return np.argmin(dist), np.min(dist) 42 | 43 | def show_mat_align(D, nns): 44 | plt.imshow(D.T) 45 | plt.plot(nns, 'r', linewidth=1) 46 | plt.colorbar() 47 | plt.show() 48 | 49 | def save_mat_align(D, nns, path): 50 | plt.imshow(D.T) 51 | plt.plot(nns, 'r', linewidth=1) 52 | plt.colorbar() 53 | plt.savefig(path) 54 | plt.close() 55 | 56 | def align(query_feats, candidate_feats, use_dtw): 57 | """Align videos based on nearest neighbor or dynamic time warping.""" 58 | if use_dtw: 59 | _, D, _, path = dtw(query_feats, candidate_feats, dist=dist_fn) 60 | _, uix = np.unique(path[0], return_index=True) 61 | nns = path[1][uix] 62 | 63 | else: 64 | nns = [] 65 | _, D, _, _ = dtw(query_feats, candidate_feats, dist=dist_fn) 66 | for i in range(len(query_feats)): 67 | nn_frame_id, _ = get_nn(candidate_feats, query_feats[i]) 68 | nns.append(nn_frame_id) 69 | return nns, D 70 | 71 | def align_and_video(args, a_emb, b_emb, a_name, b_name, a_frames, b_frames): 72 | nns_a, dist_mat_a = align(a_emb, a_emb, use_dtw=args.use_dtw) 73 | save_mat_align(dist_mat_a, nns_a, args.dest+'Self-{}-align-{}-stride-{}-dtw-{}-bs-{}.png'.format(a_name, args.mode, 74 | args.stride, args.use_dtw, args.batch_size).replace('/', '_')) 75 | 76 | nns_b, dist_mat_b = align(b_emb, b_emb, use_dtw=args.use_dtw) 77 | save_mat_align(dist_mat_b, nns_b, args.dest+'Self-{}-align-{}-stride-{}-dtw-{}-bs-{}.png'.format(b_name, args.mode, 78 | args.stride, args.use_dtw, args.batch_size).replace('/', 
'_')) 79 | 80 | nns, dist_mat = align(a_emb[::args.stride], b_emb[::args.stride], use_dtw=args.use_dtw) 81 | 82 | print(dist_mat.shape) 83 | 84 | save_mat_align(dist_mat, nns, args.dest+'{}-{}-align-{}-stride-{}-dtw-{}-bs-{}.png'.format(a_name, b_name, args.mode, 85 | args.stride, args.use_dtw, args.batch_size).replace('/', '_')) 86 | 87 | aligned_imgs = [] 88 | a_frames = a_frames[::args.stride] 89 | b_frames = b_frames[::args.stride] 90 | 91 | max_len = max(len(a_frames), len(b_frames)) 92 | 93 | for i in range(max_len): 94 | 95 | aimg = imread(a_frames[min(i, len(a_frames)-1)]) 96 | aimg = cv2.resize(aimg, (224, 224)) 97 | bimg_nn = imread(b_frames[nns[min(i, len(nns)-1)]]) 98 | bimg_nn = cv2.resize(bimg_nn, (224, 224)) 99 | 100 | bimg_i = imread(b_frames[min(i, len(b_frames)-1)]) 101 | bimg_i = cv2.resize(bimg_i, (224, 224)) 102 | 103 | print('Aligned {} - {}'.format(min(i, len(a_frames)-1), nns[min(i, len(a_frames)-1)])) 104 | 105 | ab_img_nn = np.concatenate((aimg, bimg_nn), axis=1) 106 | ab_img_i = np.concatenate((aimg, bimg_i), axis=1) 107 | 108 | ab_img = np.concatenate((ab_img_nn, ab_img_i), axis=0) 109 | aligned_imgs.append(ab_img) 110 | 111 | def make_video(img): 112 | 113 | frames = [] # for storing the generated images 114 | fig = plt.figure() 115 | 116 | print('LEN: ', len(img)) 117 | 118 | for i in range(len(img)): 119 | frames.append([plt.imshow(img[i],animated=True)]) 120 | 121 | ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True, 122 | repeat_delay=1000) 123 | ani.save(args.dest+'{}-{}-align-{}-stride-{}-dtw-{}-bs-{}.mp4'.format(a_name, b_name, args.mode, 124 | args.stride, args.use_dtw, args.batch_size).replace('/', '_')) 125 | plt.close(fig) 126 | 127 | make_video(aligned_imgs) 128 | 129 | def main(args): 130 | 131 | model = AlignNet.load_from_checkpoint(args.model_path, map_location=args.device) 132 | model.to(args.device) 133 | 134 | if args.mode == 'train': 135 | model.train() 136 | else: 137 | model.eval() 138 | 139 | eval_transforms = utils.get_transforms(augment=False) 140 | 141 | random.seed(args.seed) 142 | data = align_dataset.AlignData(args.data_path, args.batch_size, CONFIG.DATA, transform=eval_transforms, flatten=False) 143 | 144 | for i in range(data.n_classes): 145 | # get 2 videos of 0th action 146 | data.set_action_seq(action=i, num_seqs=args.num_seqs) 147 | 148 | embeddings = [] 149 | frame_paths = [] 150 | names = [] 151 | 152 | for act_iter in iter(data): 153 | for seq_iter in act_iter: 154 | 155 | seq_embs = [] 156 | seq_fpaths = [] 157 | for _, batch in enumerate(seq_iter): 158 | 159 | a_X, a_name, a_frames = batch 160 | 161 | print(a_X.shape) 162 | print(a_name) 163 | 164 | a_emb = model(a_X.to(args.device).unsqueeze(0)) 165 | print(a_emb.shape) 166 | 167 | seq_embs.append(a_emb.squeeze(0).detach().cpu().numpy()) 168 | seq_fpaths.extend(a_frames) 169 | 170 | seq_embs = np.concatenate(seq_embs, axis=0) 171 | embeddings.append(seq_embs) 172 | frame_paths.append(seq_fpaths) 173 | names.append(a_name) 174 | 175 | print(len(embeddings)) 176 | print(len(frame_paths)) 177 | 178 | print(embeddings[0].shape) 179 | print(embeddings[1].shape) 180 | print(frame_paths[0][-1]) 181 | print(frame_paths[1][-1]) 182 | print(names) 183 | 184 | for i, j in itertools.combinations(range(len(embeddings)), 2): 185 | align_and_video(args, embeddings[i], embeddings[j], names[i], names[j], frame_paths[i], frame_paths[j]) 186 | 187 | if __name__ == "__main__": 188 | 189 | parser = argparse.ArgumentParser() 190 | parser.add_argument('--data_path', type=str, 
required=True) 191 | parser.add_argument('--model_path', type=str, required=True) 192 | parser.add_argument('--batch_size', type=int, default=40) 193 | parser.add_argument('--mode', type=str, default='eval') 194 | parser.add_argument('--dest', type=str, default='./') 195 | parser.add_argument('--stride', type=int, default=1) 196 | parser.add_argument('--use_dtw', dest='use_dtw', action='store_true') 197 | 198 | parser.add_argument('--num_seqs', type=int, default=2) 199 | 200 | parser.add_argument('--device', type=str, default='cuda') 201 | parser.add_argument('--seed', type=int, default=0) 202 | 203 | args = parser.parse_args() 204 | 205 | main(args) 206 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | import losses 7 | 8 | import os 9 | import numpy as np 10 | 11 | import utils 12 | import align_dataset 13 | from models import BaseModel, ConvEmbedder 14 | from config import CONFIG 15 | # from ViViT import ViViT 16 | from pytorch_lightning import Trainer, seed_everything 17 | from pytorch_lightning.core.lightning import LightningModule 18 | 19 | import torch 20 | import torch.nn as nn 21 | import numpy as np 22 | from einops import rearrange, reduce, repeat 23 | from IPython.display import display 24 | import argparse 25 | 26 | class AlignNet(LightningModule): 27 | def __init__(self, config): 28 | super(AlignNet, self).__init__() 29 | 30 | self.base_cnn = BaseModel(pretrained=True) 31 | 32 | if config.TRAIN.FREEZE_BASE: 33 | if config.TRAIN.FREEZE_BN_ONLY: 34 | utils.freeze_bn_only(module=self.base_cnn) 35 | else: 36 | utils.freeze(module=self.base_cnn, train_bn=False) 37 | 38 | self.emb = ConvEmbedder(emb_size=config.DTWALIGNMENT.EMBEDDING_SIZE, l2_normalize=config.LOSSES.L2_NORMALIZE) 39 | 40 | self.lav_loss = losses.LAV(alpha=config.LOSSES.ALPHA, sigma=config.LOSSES.SIGMA, margin=config.LOSSES.IDM_IDX_MARGIN, 41 | num_frames=config.TRAIN.NUM_FRAMES, dtw_gamma=config.DTWALIGNMENT.SDTW_GAMMA, 42 | dtw_normalize=config.DTWALIGNMENT.SDTW_NORMALIZE, debug=False) 43 | 44 | self.description = config.DESCRIPTION 45 | 46 | # params 47 | self.l2_normalize = config.LOSSES.L2_NORMALIZE 48 | self.alpha = config.LOSSES.ALPHA 49 | self.sigma = config.LOSSES.SIGMA 50 | 51 | self.lr = config.TRAIN.LR 52 | self.weight_decay = config.TRAIN.WEIGHT_DECAY 53 | self.batch_size = config.TRAIN.BATCH_SIZE 54 | self.freeze_base = config.TRAIN.FREEZE_BASE 55 | self.freeze_bn_only = config.TRAIN.FREEZE_BN_ONLY 56 | 57 | self.data_path = os.path.abspath(config.DATA_PATH) 58 | 59 | self.hparams.config = config 60 | 61 | self.save_hyperparameters() 62 | 63 | def train(self, mode=True): 64 | super(AlignNet, self).train(mode=mode) 65 | 66 | if self.freeze_base: 67 | if self.freeze_bn_only: 68 | utils.freeze_bn_only(module=self.base_cnn) 69 | else: 70 | utils.freeze(module=self.base_cnn, train_bn=False) 71 | 72 | def forward(self, x): 73 | num_ctxt = self.hparams.config.DATA.NUM_CONTEXT 74 | 75 | num_frames = x.size(1) // num_ctxt 76 | x = self.base_cnn(x) 77 | x = self.emb(x, num_frames) 78 | return x 79 | 80 | def training_step(self, batch, batch_idx): 81 | (a_X, _, a_steps, a_seq_len), (b_X, _, b_steps, b_seq_len) = batch 82 | 83 | X = torch.cat([a_X, b_X]) 84 | embs = self.forward(X) 85 | a_embs, b_embs = torch.split(embs, a_X.size(0), 
dim=0) 86 | 87 | loss = 0. 88 | 89 | for a_emb, a_idx, a_len, b_emb, b_idx, b_len in zip(a_embs.unsqueeze(1), a_steps, a_seq_len, b_embs.unsqueeze(1), b_steps, b_seq_len): 90 | 91 | loss += self.lav_loss(a_emb, b_emb, a_idx, b_idx, a_len, b_len) 92 | 93 | loss = loss / self.batch_size 94 | 95 | tensorboard_logs = {'train_loss': loss} 96 | 97 | return {'loss': loss, 'log': tensorboard_logs} 98 | 99 | def validation_step(self, batch, batch_idx): 100 | 101 | (a_X, _, a_steps, a_seq_len), (b_X, _, b_steps, b_seq_len) = batch 102 | 103 | X = torch.cat([a_X, b_X]) 104 | embs = self.forward(X) 105 | a_embs, b_embs = torch.split(embs, a_X.size(0), dim=0) 106 | 107 | loss = 0. 108 | 109 | for a_emb, a_idx, a_len, b_emb, b_idx, b_len in zip(a_embs.unsqueeze(1), a_steps, a_seq_len, b_embs.unsqueeze(1), b_steps, b_seq_len): 110 | 111 | loss += self.lav_loss(a_emb, b_emb, a_idx, b_idx, a_len, b_len, logger=self.logger) 112 | 113 | loss = loss / self.batch_size 114 | 115 | tensorboard_logs = {'val_loss': loss} 116 | 117 | return {'val_loss': loss, 'log': tensorboard_logs} 118 | 119 | def validation_epoch_end(self, outputs): 120 | 121 | avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() 122 | tensorboard_logs = {} 123 | 124 | for x in outputs: 125 | for k in x['log']: 126 | if k not in tensorboard_logs: 127 | tensorboard_logs[k] = [] 128 | 129 | tensorboard_logs[k].append(x['log'][k]) 130 | 131 | for k, losses in tensorboard_logs.items(): 132 | tensorboard_logs[k] = torch.stack(losses).mean() 133 | 134 | return {'val_loss': avg_loss, 'log': tensorboard_logs} 135 | 136 | def configure_optimizers(self): 137 | 138 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.lr, weight_decay=self.weight_decay) 139 | return optimizer 140 | 141 | def train_dataloader(self): 142 | config = self.hparams.config 143 | train_path = os.path.join(self.data_path, 'train') 144 | 145 | train_transforms = utils.get_transforms(augment=True) 146 | data = align_dataset.AlignData(train_path, config.TRAIN.NUM_FRAMES, config.DATA, transform=train_transforms, flatten=False) 147 | data_loader = DataLoader(data, batch_size=self.batch_size, shuffle=True, pin_memory=True, 148 | num_workers=config.DATA.WORKERS) 149 | 150 | return data_loader 151 | 152 | def val_dataloader(self): 153 | config = self.hparams.config 154 | val_path = os.path.join(self.data_path, 'val') 155 | 156 | val_transforms = utils.get_transforms(augment=False) 157 | data = align_dataset.AlignData(val_path, config.EVAL.NUM_FRAMES, config.DATA, transform=val_transforms, flatten=False) 158 | data_loader = DataLoader(data, batch_size=self.batch_size, shuffle=True, pin_memory=True, 159 | num_workers=config.DATA.WORKERS) 160 | 161 | return data_loader 162 | 163 | def test_step(self, batch, batch_idx): 164 | return self.validation_step(batch, batch_idx) 165 | 166 | def test_epoch_end(self, outputs): 167 | return self.validation_epoch_end(outputs) 168 | 169 | 170 | def main(hparams): 171 | 172 | seed_everything(hparams.SEED) 173 | 174 | model = AlignNet(hparams) 175 | 176 | dd_backend = None 177 | # if hparams.GPUS < 0 or hparams.GPUS > 1: 178 | # model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 179 | # dd_backend = 'ddp' 180 | 181 | try: 182 | 183 | checkpoint_callback = utils.CheckpointEveryNSteps(hparams.TRAIN.SAVE_INTERVAL_ITERS, filepath=os.path.join(hparams.CKPT_PATH, 'STEPS')) 184 | 185 | trainer = Trainer(gpus=hparams.GPUS, max_epochs=hparams.TRAIN.EPOCHS, default_root_dir=hparams.ROOT, 186 | deterministic=True, 187 | 
callbacks=[checkpoint_callback], check_val_every_n_epoch=1) 188 | 189 | trainer.fit(model) 190 | # distributed_backend=dd_backend, row_log_interval=10 limit_val_batches=hparams.TRAIN.VAL_PERCENT 191 | except KeyboardInterrupt: 192 | pass 193 | finally: 194 | trainer.save_checkpoint(os.path.join(os.path.join(hparams.CKPT_PATH, 'STEPS'), 'final_model_l2norm-{}' 195 | '_sigma-{}_alpha-{}' 196 | '_lr-{}_bs-{}.pth'.format(hparams.LOSSES.L2_NORMALIZE, 197 | hparams.LOSSES.SIGMA, 198 | hparams.LOSSES.ALPHA, 199 | hparams.TRAIN.LR, 200 | hparams.TRAIN.BATCH_SIZE))) 201 | trainer.save_checkpoint(os.path.join(hparams.ROOT, 'final_model_l2norm-{}' 202 | '_sigma-{}_alpha-{}' 203 | '_lr-{}_bs-{}.pth'.format(hparams.LOSSES.L2_NORMALIZE, 204 | hparams.LOSSES.SIGMA, 205 | hparams.LOSSES.ALPHA, 206 | hparams.TRAIN.LR, 207 | hparams.TRAIN.BATCH_SIZE))) 208 | 209 | if __name__ == '__main__': 210 | 211 | parser = argparse.ArgumentParser() 212 | parser.add_argument('--description', type=str, required=True, help='Description of the experiment run!') 213 | parser.add_argument('--gpus', type=int, default=1) 214 | parser.add_argument('--root_dir', type=str, default=None) 215 | parser.add_argument('--ckpt_path', type=str, default=None, help='Path to save checkpoints') 216 | parser.add_argument('--data_path', type=str, default=None, help='Path to dataset') 217 | parser.add_argument('--num_frames', type=int, default=None, help='Path to dataset') 218 | parser.add_argument('--workers', type=int, default=30, help='Path to dataset') 219 | 220 | args = parser.parse_args() 221 | 222 | CONFIG.DESCRIPTION = args.description 223 | CONFIG.GPUS = args.gpus 224 | 225 | if args.root_dir: 226 | CONFIG.ROOT = args.root_dir 227 | if args.ckpt_path: 228 | CONFIG.CKPT_PATH = args.ckpt_path 229 | if args.data_path: 230 | CONFIG.DATA_PATH = args.data_path 231 | if args.num_frames: 232 | CONFIG.TRAIN.NUM_FRAMES = args.num_frames 233 | CONFIG.EVAL.NUM_FRAMES = args.num_frames 234 | if args.workers: 235 | CONFIG.DATA.WORKERS = args.workers 236 | 237 | main(CONFIG) 238 | -------------------------------------------------------------------------------- /preprocessing/transforms_video.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | from PIL import Image 4 | import collections 5 | 6 | import torch 7 | from torchvision.transforms import ( 8 | RandomCrop, 9 | RandomResizedCrop, 10 | Resize, 11 | ToTensor, 12 | ColorJitter, 13 | ) 14 | 15 | from . import functional_video as F 16 | 17 | import sys 18 | if sys.version_info < (3, 3): 19 | Sequence = collections.Sequence 20 | Iterable = collections.Iterable 21 | else: 22 | Sequence = collections.abc.Sequence 23 | Iterable = collections.abc.Iterable 24 | 25 | __all__ = [ 26 | "RandomCropVideo", 27 | "RandomResizedCropVideo", 28 | "CenterCropVideo", 29 | "ResizeVideo", 30 | "NormalizeVideo", 31 | "ToTensorVideo", 32 | "RandomHorizontalFlipVideo", 33 | "ColorJitterVideo", 34 | ] 35 | 36 | 37 | class RandomCropVideo(RandomCrop): 38 | def __init__(self, size): 39 | if isinstance(size, numbers.Number): 40 | self.size = (int(size), int(size)) 41 | else: 42 | self.size = size 43 | 44 | def __call__(self, clip): 45 | """ 46 | Args: 47 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 48 | Returns: 49 | torch.tensor: randomly cropped/resized video clip. 
50 | size is (C, T, OH, OW) 51 | """ 52 | i, j, h, w = self.get_params(clip, self.size) 53 | return F.crop(clip, i, j, h, w) 54 | 55 | def __repr__(self): 56 | return self.__class__.__name__ + '(size={0})'.format(self.size) 57 | 58 | 59 | class RandomResizedCropVideo(object): 60 | def __init__( 61 | self, 62 | size, 63 | scale=(0.8, 1.0), 64 | interpolation_mode="bilinear", 65 | ): 66 | if isinstance(size, tuple): 67 | assert len(size) == 2, "size should be tuple (height, width)" 68 | self.size = size 69 | else: 70 | self.size = (size, size) 71 | 72 | self.interpolation_mode = interpolation_mode 73 | self.scale = scale 74 | 75 | def get_params(self, clip, scale): 76 | 77 | H, W = clip.shape[-2:] 78 | min_dim = min(H, W) 79 | 80 | sampled_size = int(min_dim * random.uniform(scale[0], scale[1])) 81 | height_offset = random.randint(0, H - sampled_size) 82 | width_offset = random.randint(0, W - sampled_size) 83 | return height_offset, width_offset, sampled_size, sampled_size 84 | 85 | def __call__(self, clip): 86 | """ 87 | Args: 88 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 89 | Returns: 90 | torch.tensor: randomly cropped/resized video clip. 91 | size is (C, T, H, W) 92 | """ 93 | i, j, h, w = self.get_params(clip, self.scale) 94 | return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode) 95 | 96 | def __repr__(self): 97 | return self.__class__.__name__ + \ 98 | '(size={0}, interpolation_mode={1}, scale={2})'.format( 99 | self.size, self.interpolation_mode, self.scale 100 | ) 101 | 102 | class CenterResizedCropVideo(RandomResizedCropVideo): 103 | def __init__( 104 | self, 105 | size, 106 | interpolation_mode="bilinear", 107 | ): 108 | """ 109 | Returns the maximum square, central crop resized to the given image size. 110 | """ 111 | super(CenterResizedCropVideo, self).__init__(size, scale=1., interpolation_mode=interpolation_mode) 112 | 113 | def get_params(self, clip, scale): 114 | 115 | H, W = clip.shape[-2:] 116 | min_dim = min(H, W) 117 | 118 | height_offset = int((H - min_dim) // 2) 119 | width_offset = int((W - min_dim) // 2) 120 | 121 | return height_offset, width_offset, min_dim, min_dim 122 | 123 | class CenterCropVideo(object): 124 | def __init__(self, crop_size): 125 | if isinstance(crop_size, numbers.Number): 126 | self.crop_size = (int(crop_size), int(crop_size)) 127 | else: 128 | self.crop_size = crop_size 129 | 130 | def __call__(self, clip): 131 | """ 132 | Args: 133 | clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) 134 | Returns: 135 | torch.tensor: central cropping of video clip. Size is 136 | (C, T, crop_size, crop_size) 137 | """ 138 | return F.center_crop(clip, self.crop_size) 139 | 140 | def __repr__(self): 141 | return self.__class__.__name__ + '(crop_size={0})'.format(self.crop_size) 142 | 143 | class ResizeVideo(object): 144 | """Resize the input PIL Images to the given size. 145 | 146 | Args: 147 | size (sequence or int): Desired output size. If size is a sequence like 148 | (h, w), output size will be matched to this. If size is an int, 149 | smaller edge of the image will be matched to this number. 150 | i.e, if height > width, then image will be rescaled to 151 | (size * height / width, size) 152 | interpolation (int, optional): Desired interpolation. 
Default is 153 | ``PIL.Image.BILINEAR`` 154 | """ 155 | 156 | def __init__(self, size, interpolation=Image.BILINEAR): 157 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) 158 | self.size = size 159 | self.interpolation = interpolation 160 | self.resize_img = Resize(self.size, self.interpolation) 161 | 162 | def __call__(self, imgs): 163 | """ 164 | Args: 165 | imgs (List of PIL Image): Images to be scaled. 166 | 167 | Returns: 168 | List of PIL Images: Rescaled images. 169 | """ 170 | return [self.resize_img(x) for x in imgs] 171 | 172 | def __repr__(self): 173 | return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, self.interpolation) 174 | 175 | class NormalizeVideo(object): 176 | """ 177 | Normalize the video clip by mean subtraction and division by standard deviation 178 | Args: 179 | mean (3-tuple): pixel RGB mean 180 | std (3-tuple): pixel RGB standard deviation 181 | inplace (boolean): whether do in-place normalization 182 | """ 183 | 184 | def __init__(self, mean, std, inplace=False): 185 | self.mean = mean 186 | self.std = std 187 | self.inplace = inplace 188 | 189 | def __call__(self, clip): 190 | """ 191 | Args: 192 | clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W) 193 | """ 194 | return F.normalize(clip, self.mean, self.std, self.inplace) 195 | 196 | def __repr__(self): 197 | return self.__class__.__name__ + '(mean={0}, std={1}, inplace={2})'.format( 198 | self.mean, self.std, self.inplace) 199 | 200 | 201 | class ToTensorVideo(object): 202 | """ 203 | Convert a List of PIL Images or numpy ndarrays of each size HxWxC to torch.tensor of size TxCxHxW 204 | """ 205 | 206 | def __init__(self): 207 | self.totensor = ToTensor() 208 | 209 | def __call__(self, clip): 210 | """ 211 | Args: 212 | clip (List of PIL Image or numpy ndarray): Size is Tx (HxWxC) 213 | Return: 214 | clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) 215 | """ 216 | clip = [self.totensor(x) for x in clip] 217 | return torch.stack(clip) 218 | 219 | def __repr__(self): 220 | return self.__class__.__name__ 221 | 222 | 223 | class RandomHorizontalFlipVideo(object): 224 | """ 225 | Flip the video clip along the horizonal direction with a given probability 226 | Args: 227 | p (float): probability of the clip being flipped. Default value is 0.5 228 | """ 229 | 230 | def __init__(self, p=0.5): 231 | self.p = p 232 | 233 | def __call__(self, clip): 234 | """ 235 | Args: 236 | clip (torch.tensor): Size is (C, T, H, W) 237 | Return: 238 | clip (torch.tensor): Size is (C, T, H, W) 239 | """ 240 | if random.random() < self.p: 241 | clip = F.hflip(clip) 242 | return clip 243 | 244 | def __repr__(self): 245 | return self.__class__.__name__ + "(p={0})".format(self.p) 246 | 247 | class ColorJitterVideo(ColorJitter): 248 | """Randomly change the brightness, contrast and saturation of a list of images. 249 | 250 | Args: 251 | brightness (float or tuple of float (min, max)): How much to jitter brightness. 252 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] 253 | or the given [min, max]. Should be non negative numbers. 254 | contrast (float or tuple of float (min, max)): How much to jitter contrast. 255 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] 256 | or the given [min, max]. Should be non negative numbers. 257 | saturation (float or tuple of float (min, max)): How much to jitter saturation. 
258 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] 259 | or the given [min, max]. Should be non negative numbers. 260 | hue (float or tuple of float (min, max)): How much to jitter hue. 261 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. 262 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. 263 | """ 264 | def __init__(self, brightness, contrast, saturation, hue): 265 | 266 | self.brightness = self._check_input(brightness, 'brightness') 267 | self.contrast = self._check_input(contrast, 'contrast') 268 | self.saturation = self._check_input(saturation, 'saturation') 269 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 270 | clip_first_on_zero=False) 271 | 272 | def __call__(self, clip): 273 | """ 274 | Args: 275 | clip (List): List of 'N' PIL Images. 276 | Return: 277 | List of 'N' PIL Images: Color jittered images. 278 | """ 279 | 280 | # transform = self.get_params(self.brightness, self.contrast, self.saturation, self.hue) 281 | # clip = [transform(x) for x in clip] 282 | 283 | clip = [self.forward(x) for x in clip] 284 | return clip -------------------------------------------------------------------------------- /evaluations.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import utils 5 | import align_dataset_test as align_dataset 6 | from config import CONFIG 7 | 8 | from evals.phase_classification import evaluate_phase_classification, compute_ap 9 | from evals.kendalls_tau import evaluate_kendalls_tau 10 | from evals.phase_progression import evaluate_phase_progression 11 | 12 | from train import AlignNet 13 | import torch 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | import random 17 | import argparse 18 | import glob 19 | from natsort import natsorted 20 | 21 | def get_embeddings(model, data, labels_npy, args): 22 | 23 | embeddings = [] 24 | labels = [] 25 | frame_paths = [] 26 | names = [] 27 | 28 | device = f"cuda:{args.device}" 29 | 30 | for act_iter in iter(data): 31 | for i, seq_iter in enumerate(act_iter): 32 | seq_embs = [] 33 | seq_fpaths = [] 34 | original = 0 35 | for _, batch in enumerate(seq_iter): 36 | a_X, a_name, a_frames = batch 37 | a_X = a_X.to(device).unsqueeze(0) 38 | original = a_X.shape[1]//2 39 | 40 | # a_X = a_X[:,:a_X.shape[1]//2,:,:,:] 41 | # original = a_X.shape[1]//2 42 | 43 | # if ((args.num_frames*2)-a_X.shape[1]) > 0: 44 | # b = a_X[:, -1].clone() 45 | # b = torch.stack([b]*((args.num_frames*2)-a_X.shape[1]),axis=1).to(device) 46 | # a_X = torch.concat([a_X,b], axis=1) 47 | 48 | b = a_X[:, -1].clone() 49 | try: 50 | b = torch.stack([b]*((args.num_frames*2)-a_X.shape[1]),axis=1).to(device) 51 | except: 52 | b = torch.from_numpy(np.array([])).float().to(device) 53 | a_X = torch.concat([a_X,b], axis=1) 54 | a_emb = model(a_X)[:, :original,:] 55 | 56 | if args.verbose: 57 | print(f'Seq: {i}, ', a_emb.shape) 58 | 59 | seq_embs.append(a_emb.squeeze(0).detach().cpu().numpy()) 60 | seq_fpaths.extend(a_frames) 61 | 62 | seq_embs = np.concatenate(seq_embs, axis=0) 63 | 64 | name = str(a_name).split('/')[-1] 65 | # name = name[:8] + '/' + name[8:10] + '/' + name[10:] 66 | lab = labels_npy[name]['labels'] 67 | end = min(seq_embs.shape[0], len(lab)) 68 | lab = lab[:end]#.T 69 | seq_embs = seq_embs[:end] 70 | print(seq_embs.shape, len(lab)) 71 | embeddings.append(seq_embs[:end]) 72 | frame_paths.append(seq_fpaths) 73 | names.append(a_name) 74 | labels.append(lab) 75 | 76 | 
return embeddings, names, labels 77 | 78 | 79 | def main(ckpts, args): 80 | 81 | summary_dest = os.path.join(args.dest, 'eval_logs') 82 | os.makedirs(summary_dest, exist_ok=True) 83 | 84 | for ckpt in ckpts: 85 | writer = SummaryWriter(summary_dest, filename_suffix='eval_logs') 86 | 87 | # get ckpt-step from the ckpt name 88 | _, ckpt_step = ckpt.split('.')[0].split('_')[-2:] 89 | ckpt_step = int(ckpt_step.split('=')[1]) 90 | DEST = os.path.join(args.dest, 'eval_step_{}'.format(ckpt_step)) 91 | 92 | device = f"cuda:{args.device}" 93 | model = AlignNet.load_from_checkpoint(ckpt, map_location=device) 94 | model.to(device) 95 | model.eval() 96 | 97 | # grad off 98 | torch.set_grad_enabled(False) 99 | 100 | if args.num_frames: 101 | CONFIG.TRAIN.NUM_FRAMES = args.num_frames 102 | CONFIG.EVAL.NUM_FRAMES = args.num_frames 103 | 104 | CONFIG.update(model.hparams.config) 105 | 106 | print(model.hparams) 107 | if args.data_path: 108 | data_path = args.data_path 109 | else: 110 | data_path = CONFIG.DATA_PATH 111 | data_path = '/home/ubuntu/Data_Test/' 112 | 113 | train_path = os.path.join(data_path, 'Test') 114 | val_path = os.path.join(data_path, 'Test') 115 | # lab_train_path = os.path.join(data_path, 'labels', 'train') 116 | # lab_val_path = os.path.join(data_path, 'labels', 'val') 117 | lab_name = "_".join(args.model_path.split('/')[4].split('_')[:-1]) + '_val' 118 | print(lab_name) 119 | labels = np.load(f"/home/ubuntu/npyrecords/{lab_name}.npy", allow_pickle=True).item() 120 | 121 | # create dataset 122 | _transforms = utils.get_transforms(augment=False) 123 | 124 | random.seed(0) 125 | train_data = align_dataset.AlignData(train_path, args.batch_size, CONFIG.DATA, transform=_transforms, flatten=False) 126 | val_data = align_dataset.AlignData(val_path, args.batch_size, CONFIG.DATA, transform=_transforms, flatten=False) 127 | 128 | 129 | all_classifications = [] 130 | all_kendalls_taus = [] 131 | all_phase_progressions = [] 132 | ap5, ap10, ap15 = 0, 0, 0 133 | for i_action in range(train_data.n_classes): 134 | 135 | train_data.set_action_seq(i_action) 136 | val_data.set_action_seq(i_action) 137 | 138 | train_act_name = train_data.get_action_name(i_action) 139 | val_act_name = val_data.get_action_name(i_action) 140 | 141 | assert train_act_name == val_act_name 142 | 143 | # if args.verbose: 144 | # print(f'Getting embeddings for {train_act_name}...') 145 | # train_embs, train_names, train_labels = get_embeddings(model, train_data, lab_train_path, args) 146 | val_embs, val_names, val_labels = get_embeddings(model, val_data, labels, args) 147 | train_embs, train_names, train_labels = val_embs, val_names, val_labels 148 | 149 | # # save embeddings 150 | os.makedirs(DEST, exist_ok=True) 151 | DEST_TRAIN = os.path.join(DEST, f'train_{train_act_name}_embs.npy') 152 | DEST_VAL = os.path.join(DEST, f'val_{val_act_name}_embs.npy') 153 | 154 | np.save(DEST_TRAIN, {'embs' : train_embs, 'names':train_names, 'labels': train_labels}) 155 | np.save(DEST_VAL, {'embs' : val_embs, 'names':val_names, 'labels': val_labels}) 156 | 157 | train_embeddings = np.load(DEST_TRAIN, allow_pickle=True).tolist() 158 | val_embeddings = np.load(DEST_VAL, allow_pickle=True).tolist() 159 | 160 | train_embs, train_labels, train_names = train_embeddings['embs'], train_embeddings['labels'], train_embeddings['names'] 161 | val_embs, val_labels, val_names = val_embeddings['embs'], val_embeddings['labels'], val_embeddings['names'] 162 | 163 | # Evaluating Classification 164 | train_acc, val_acc = 
evaluate_phase_classification(ckpt_step, train_embs, train_labels, val_embs, val_labels, 165 | act_name=train_act_name, CONFIG=CONFIG, writer=writer, verbose=args.verbose) 166 | ap5, ap10, ap15 = compute_ap(val_embs, val_labels) 167 | 168 | all_classifications.append([train_acc, val_acc]) 169 | 170 | # Evaluating Kendall's Tau 171 | train_tau, val_tau = evaluate_kendalls_tau(train_embs, val_embs, stride=args.stride, 172 | kt_dist=CONFIG.EVAL.KENDALLS_TAU_DISTANCE, visualize=False) 173 | all_kendalls_taus.append([train_tau, val_tau]) 174 | 175 | print(f"Kendal's Tau: Stride = {args.stride} \n") 176 | print(f"Train = {train_tau}\n") 177 | print(f"Val = {val_tau}\n") 178 | 179 | writer.add_scalar(f'kendalls_tau/train_{train_act_name}', train_tau, global_step=ckpt_step) 180 | writer.add_scalar(f'kendalls_tau/val_{val_act_name}', val_tau, global_step=ckpt_step) 181 | 182 | # Evaluating Phase Progression 183 | # _train_dict = {'embs': train_embs, 'labels': train_labels} 184 | # _val_dict = {'embs': val_embs, 'labels': val_labels} 185 | # train_phase_scores, val_phase_scores = evaluate_phase_progression(_train_dict, _val_dict, "_".join(lab_name.split('_')[:-1]), 186 | # ckpt_step, CONFIG, writer, verbose=args.verbose) 187 | 188 | # all_phase_progressions.append([train_phase_scores[-1], val_phase_scores[-1]]) 189 | 190 | train_classification, val_classification = np.mean(all_classifications, axis=0) 191 | train_kendalls_tau, val_kendalls_tau = np.mean(all_kendalls_taus, axis=0) 192 | # train_phase_prog, val_phase_prog = np.mean(all_phase_progressions, axis=0) 193 | 194 | writer.add_scalar('metrics/AP@5_val', ap5, global_step=ckpt_step) 195 | writer.add_scalar('metrics/AP@10_val', ap10, global_step=ckpt_step) 196 | writer.add_scalar('metrics/AP@15_val', ap15, global_step=ckpt_step) 197 | 198 | writer.add_scalar('metrics/all_classification_train', train_classification, global_step=ckpt_step) 199 | writer.add_scalar('metrics/all_classification_val', val_classification, global_step=ckpt_step) 200 | 201 | writer.add_scalar('metrics/all_kendalls_tau_train', train_kendalls_tau, global_step=ckpt_step) 202 | writer.add_scalar('metrics/all_kendalls_tau_val', val_kendalls_tau, global_step=ckpt_step) 203 | 204 | # writer.add_scalar('metrics/all_phase_progression_train', train_phase_prog, global_step=ckpt_step) 205 | # writer.add_scalar('metrics/all_phase_progression_val', val_phase_prog, global_step=ckpt_step) 206 | 207 | print('metrics/AP@5_val', ap5, f"global_step={ckpt_step}") 208 | print('metrics/AP@10_val', ap10, f"global_step={ckpt_step}") 209 | print('metrics/AP@15_val', ap15, f"global_step={ckpt_step}") 210 | 211 | print('metrics/all_classification_train', train_classification, f"global_step={ckpt_step}") 212 | print('metrics/all_classification_val', val_classification, f"global_step={ckpt_step}") 213 | 214 | print('metrics/all_kendalls_tau_train', train_kendalls_tau, f"global_step={ckpt_step}") 215 | print('metrics/all_kendalls_tau_val', val_kendalls_tau, f"global_step={ckpt_step}") 216 | 217 | # print('metrics/all_phase_progression_train', train_phase_prog, f"global_step={ckpt_step}") 218 | # print('metrics/all_phase_progression_val', val_phase_prog, f"global_step={ckpt_step}") 219 | 220 | writer.flush() 221 | 222 | writer.close() 223 | 224 | if __name__ == '__main__': 225 | parser = argparse.ArgumentParser() 226 | parser.add_argument('--model_path', type=str, required=True) 227 | parser.add_argument('--data_path', type=str, default=None) 228 | parser.add_argument('--batch_size', type=int, 
default=20) 229 | parser.add_argument('--dest', type=str, default='./') 230 | 231 | parser.add_argument('--stride', type=int, default=5) 232 | parser.add_argument('--visualize', dest='visualize', action='store_true') 233 | parser.add_argument('--device', type=int, default=0, help='Cuda device to be used') 234 | parser.add_argument('--verbose', action='store_true') 235 | parser.add_argument('--num_frames', type=int, default=None, help='Path to dataset') 236 | 237 | args = parser.parse_args() 238 | 239 | if os.path.isdir(args.model_path): 240 | ckpts = natsorted(glob.glob(os.path.join(args.model_path, '*'))) 241 | else: 242 | ckpts = [args.model_path] 243 | 244 | 245 | ckpt_mul = args.device 246 | main(ckpts, args) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import io 3 | import numpy as np 4 | from imageio import imread 5 | from PIL import Image 6 | 7 | import torch 8 | from torchvision import transforms 9 | 10 | import matplotlib.pyplot as plt 11 | 12 | import preprocessing.transforms_video as tv 13 | from config import CONFIG 14 | 15 | import pytorch_lightning as pl 16 | 17 | BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d) 18 | 19 | def normalize(img): 20 | return img / 255. 21 | 22 | def get_totensor_transform(is_video): 23 | if is_video: 24 | return transforms.Compose([tv.ToTensorVideo()]) 25 | else: 26 | return transforms.Compose([transforms.ToTensor()]) 27 | 28 | def get_images(paths): 29 | 30 | imgs = [] 31 | for p in paths: 32 | imgs.append(normalize(imread(p))) 33 | 34 | return np.array(imgs, dtype=np.float32) 35 | 36 | def get_transformed_images(paths, transform): 37 | 38 | imgs = [] 39 | for p in paths: 40 | imgs.append(transform(Image.open(p))) 41 | 42 | imgs = torch.stack(imgs) 43 | 44 | return imgs 45 | 46 | def get_pil_images(paths): 47 | 48 | imgs = [] 49 | for p in paths: 50 | imgs.append(Image.open(p)) 51 | 52 | return imgs 53 | 54 | def get_transforms(augment): 55 | 56 | seq_transforms = [] 57 | 58 | if augment: 59 | seq_transforms.append(tv.ColorJitterVideo( 60 | CONFIG.AUGMENTATION.BRIGHTNESS_DELTA, 61 | CONFIG.AUGMENTATION.CONTRAST_DELTA, 62 | CONFIG.AUGMENTATION.HUE_DELTA, 63 | CONFIG.AUGMENTATION.SATURATION_DELTA 64 | )) 65 | 66 | if not CONFIG.AUGMENTATION.RANDOM_CROP: 67 | seq_transforms.append(tv.ResizeVideo(size=(CONFIG.DATA.IMAGE_SIZE, CONFIG.DATA.IMAGE_SIZE))) 68 | 69 | seq_transforms.append(tv.ToTensorVideo()) 70 | 71 | if CONFIG.AUGMENTATION.RANDOM_FLIP: 72 | seq_transforms.append(tv.RandomHorizontalFlipVideo(p=0.5)) 73 | 74 | if CONFIG.AUGMENTATION.RANDOM_CROP: 75 | seq_transforms.append(tv.RandomResizedCropVideo(size=CONFIG.DATA.IMAGE_SIZE)) 76 | 77 | else: 78 | if CONFIG.AUGMENTATION.RANDOM_CROP: 79 | seq_transforms.append(tv.ToTensorVideo()) 80 | seq_transforms.append(tv.CenterResizedCropVideo(size=CONFIG.DATA.IMAGE_SIZE)) 81 | else: 82 | seq_transforms.append(tv.ResizeVideo(size=(CONFIG.DATA.IMAGE_SIZE, CONFIG.DATA.IMAGE_SIZE))) 83 | seq_transforms.append(tv.ToTensorVideo()) 84 | 85 | seq_transforms.append(tv.NormalizeVideo(mean=[0.485, 0.456, 0.406], 86 | std=[0.229, 0.224, 0.225])) 87 | 88 | return transforms.Compose(seq_transforms) 89 | 90 | def arg_to_numpy(f): 91 | 92 | def wrapper(x): 93 | if isinstance(x, torch.Tensor): 94 | x = x.detach().cpu().numpy() 95 | return f(x) 96 | return wrapper 97 | 98 | @arg_to_numpy 99 | def plot_to_image(arr): 100 | 101 | arr = 
arr.squeeze() 102 | figure = plt.figure() 103 | plt.imshow(arr) 104 | plt.colorbar() 105 | buf = io.BytesIO() 106 | plt.savefig(buf, format='jpg') 107 | plt.close(figure) 108 | buf.seek(0) 109 | img = imread(buf) 110 | 111 | return img.transpose((2, 0, 1)) 112 | 113 | def _make_trainable(module): 114 | """Unfreeze a given module. 115 | Operates in-place. 116 | Parameters 117 | ---------- 118 | module : instance of `torch.nn.Module` 119 | """ 120 | for param in module.parameters(): 121 | param.requires_grad = True 122 | module.train() 123 | 124 | 125 | def _recursive_freeze(module, train_bn=True): 126 | """Freeze the layers of a given module. 127 | Operates in-place. 128 | Parameters 129 | ---------- 130 | module : instance of `torch.nn.Module` 131 | train_bn : bool (default: True) 132 | If True, the BatchNorm layers will remain in training mode. 133 | Otherwise, they will be set to eval mode along with the other modules. 134 | """ 135 | children = list(module.children()) 136 | if not children: 137 | if not (isinstance(module, BN_TYPES) and train_bn): 138 | for param in module.parameters(): 139 | param.requires_grad = False 140 | module.eval() 141 | else: 142 | # Make the BN layers trainable 143 | _make_trainable(module) 144 | else: 145 | for child in children: 146 | _recursive_freeze(module=child, train_bn=train_bn) 147 | 148 | def freeze(module, n=-1, train_bn=True): 149 | """Freeze the layers up to index n. 150 | Operates in-place. 151 | Parameters 152 | ---------- 153 | module : instance of `torch.nn.Module` 154 | n : int 155 | By default, all the layers will be frozen. Otherwise, an integer 156 | between 0 and `len(module.children())` must be given. 157 | train_bn : bool (default: True) 158 | If True, the BatchNorm layers will remain in training mode. 159 | """ 160 | idx = 0 161 | children = list(module.children()) 162 | n_max = len(children) if n == -1 else int(n) 163 | for child in children: 164 | if idx < n_max: 165 | #print('-----------------',child,'-------------') 166 | _recursive_freeze(module=child, train_bn=train_bn) 167 | else: 168 | _make_trainable(module=child) 169 | 170 | def _recursive_freeze_bn_only(module): 171 | """Freeze the BN-layers of a given module. 172 | Operates in-place. 173 | Parameters 174 | ---------- 175 | module : instance of `torch.nn.Module` 176 | """ 177 | children = list(module.children()) 178 | if not children: 179 | if isinstance(module, BN_TYPES): 180 | print('Froze ',module) 181 | for param in module.parameters(): 182 | param.requires_grad = False 183 | module.eval() 184 | else: 185 | # Make the other layers trainable 186 | _make_trainable(module) 187 | else: 188 | for child in children: 189 | _recursive_freeze_bn_only(module=child) 190 | 191 | def freeze_bn_only(module, n=-1): 192 | """Freeze the BN-layers up to index n. 193 | Operates in-place. 194 | Parameters 195 | ---------- 196 | module : instance of `torch.nn.Module` 197 | n : int 198 | By default, all the BN-layers will be frozen. Otherwise, an integer 199 | between 0 and `len(module.children())` must be given. 
200 | """ 201 | idx = 0 202 | children = list(module.children()) 203 | n_max = len(children) if n == -1 else int(n) 204 | for child in children: 205 | if idx < n_max: 206 | #print('-----------------',child,'-------------') 207 | _recursive_freeze_bn_only(module=child) 208 | else: 209 | _make_trainable(module=child) 210 | 211 | class CheckpointEveryNSteps(pl.Callback): 212 | """ 213 | Save a checkpoint every N steps, instead of Lightning's default that checkpoints 214 | based on validation loss. 215 | """ 216 | 217 | def __init__( 218 | self, 219 | save_step_frequency, 220 | filepath, 221 | prefix="model", 222 | use_modelcheckpoint_filename=False, 223 | ): 224 | """ 225 | Args: 226 | save_step_frequency: how often to save in steps 227 | prefix: add a prefix to the name, only used if 228 | use_modelcheckpoint_filename=False 229 | use_modelcheckpoint_filename: just use the ModelCheckpoint callback's 230 | default filename, don't use ours. 231 | """ 232 | self.save_step_frequency = save_step_frequency 233 | self.prefix = prefix 234 | self.filepath = filepath 235 | self.use_modelcheckpoint_filename = use_modelcheckpoint_filename 236 | 237 | os.makedirs(self.filepath, exist_ok=True) 238 | 239 | def on_batch_end(self, trainer: pl.Trainer, _): 240 | """ Check if we should save a checkpoint after every train batch """ 241 | epoch = trainer.current_epoch 242 | global_step = trainer.global_step 243 | if global_step % self.save_step_frequency == 0: 244 | if self.use_modelcheckpoint_filename: 245 | filename = trainer.checkpoint_callback.filename 246 | else: 247 | filename = f"{self.prefix}lAV_epoch={epoch}_step={global_step}.ckpt" 248 | ckpt_path = os.path.join(self.filepath, filename) 249 | trainer.save_checkpoint(ckpt_path) 250 | import os, glob 251 | import io 252 | import numpy as np 253 | from imageio import imread 254 | from PIL import Image 255 | 256 | import torch 257 | from torchvision import transforms 258 | 259 | import matplotlib.pyplot as plt 260 | 261 | import preprocessing.transforms_video as tv 262 | from config import CONFIG 263 | 264 | import pytorch_lightning as pl 265 | 266 | BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d) 267 | 268 | def normalize(img): 269 | return img / 255. 
--------------------------------------------------------------------------------
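A usage sketch for the freezing helpers above. This is a hypothetical example, not code from this repository: the torchvision ResNet-50 backbone, the split point n=6, and the Adam settings are illustrative, and it assumes the file above is importable as `utils`.

import torch
import torchvision

from utils import freeze, freeze_bn_only

# Placeholder backbone; the repository's own encoder may differ.
encoder = torchvision.models.resnet50(pretrained=True)

# Freeze the first 6 top-level children (conv stem + early residual stages),
# keeping their BatchNorm layers in training mode; later stages stay trainable.
freeze(encoder, n=6, train_bn=True)

# Alternative: leave all weights trainable but lock every BatchNorm layer,
# which can help when fine-tuning with small batch sizes.
# freeze_bn_only(encoder)

# Only parameters that still require gradients go to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in encoder.parameters() if p.requires_grad), lr=1e-4
)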
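CheckpointEveryNSteps is registered like any other Lightning callback. A minimal sketch follows, using the older pytorch_lightning Trainer API implied by the on_batch_end hook above; the step frequency, the checkpoints/ directory, and the prefix are placeholder values, and model / train_loader would come from the rest of the training script.

import pytorch_lightning as pl

from utils import CheckpointEveryNSteps

step_ckpt = CheckpointEveryNSteps(
    save_step_frequency=1000,   # write a checkpoint every 1000 global steps
    filepath="checkpoints/",    # created automatically if it does not exist
    prefix="model",             # filenames look like modellAV_epoch=E_step=S.ckpt
)

trainer = pl.Trainer(max_epochs=100, callbacks=[step_ckpt])
# trainer.fit(model, train_loader)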
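plot_to_image (wrapped with arg_to_numpy earlier in this file) renders a 2-D array as a colour-mapped figure and returns it in (3, H, W) layout, which is convenient for image logging. A sketch of one possible use with torch.utils.tensorboard; the random similarity matrix and the runs/debug log directory are illustrative only, and again the file is assumed to be importable as `utils`.

import torch
from torch.utils.tensorboard import SummaryWriter

from utils import plot_to_image

sim_matrix = torch.rand(40, 50)      # stand-in for a frame-to-frame similarity matrix
img_chw = plot_to_image(sim_matrix)  # arg_to_numpy converts the Tensor; result is a (3, H, W) uint8 array

writer = SummaryWriter("runs/debug")
writer.add_image("alignment/similarity", img_chw, global_step=0, dataformats="CHW")
writer.close()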