├── __init__.py ├── models ├── __init__.py ├── modules.py └── nets.py ├── utils ├── __init__.py ├── options.py ├── analysis.py ├── videos.py ├── criterions.py └── transforms.py ├── test.py ├── configs ├── split.json ├── BehaveNet_model_config.json ├── DisAE_model_config.json └── DBE_model_config.json ├── README.md ├── .gitignore ├── evaluate.py ├── train_behavenet.py ├── train_disae.py └── train_dbe.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from utils.options import * -------------------------------------------------------------------------------- /configs/split.json: -------------------------------------------------------------------------------- 1 | { 2 | "train":[ 3 | "/data2/Dropbox/Shahar/learning project/experiments/Videos/PT3/2018-03-13/EPC_PT3_2018-03-13_001/movie_comb.avi" 4 | ], 5 | "test":[ 6 | "/data2/Dropbox/Shahar/learning project/experiments/Videos/PT3/2018-03-13/EPC_PT3_2018-03-13_002/movie_comb.avi" 7 | ] 8 | } -------------------------------------------------------------------------------- /configs/BehaveNet_model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "BehaveNet", 3 | "latent": 16, 4 | "frames_per_clip": 300, 5 | 6 | "encoder": 7 | { 8 | "blocks": 4, 9 | "input_size": [128, 128, 2], 10 | "channels": [32, 64, 128, 256], 11 | "kernel_size": [3, 3, 3, 3], 12 | "stride": [1, 1, 1, 1], 13 | "padding": [1, 1, 1, 1], 14 | "last_conv_size": [8, 8, 256], 15 | "kernel_size1d": [3, 3, 3, 3], 16 | "mean_pool": true, 17 | "down": "maxpool_unpool", 18 | "motion_only": false 19 | }, 20 | 21 | "decoder": 22 | { 23 | "blocks": 4, 24 | "first_deconv_size": [256, 8, 8], 25 | "channels": [128, 64, 32, 2], 26 | "kernel_size": [3, 3, 3, 3], 27 | "stride": [2, 2, 2, 2], 28 | "padding": [1, 1, 1, 1], 29 | "up": "unpool" 30 | } 31 | } -------------------------------------------------------------------------------- /configs/DisAE_model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "DisAE", 3 | "pose": 8, 4 | "content": 512, 5 | "frames_per_clip": 300, 6 | "first_frame": true, 7 | 8 | "encoder_ps": 9 | { 10 | "blocks": 4, 11 | "input_size": [128, 128, 1], 12 | "channels": [32, 64, 128, 256], 13 | "kernel_size": [5, 5, 5, 5], 14 | "stride": [1, 1, 1, 1], 15 | "padding": [2, 2, 2, 2], 16 | "last_conv_size": [8, 8, 256], 17 | "mean_pool": true, 18 | "down": "maxpool", 19 | "motion_only": false 20 | }, 21 | 22 | "encoder_ct": 23 | { 24 | "blocks": 4, 25 | "input_size": [128, 128, 1], 26 | "channels": [32, 64, 128, 256], 27 | "kernel_size": [5, 5, 5, 5], 28 | "stride": [1, 1, 1, 1], 29 | "padding": [2, 2, 2, 2], 30 | "last_conv_size": [8, 8, 256], 31 | "mean_pool": false, 32 | "down": "conv", 33 | "motion_only": false 34 | }, 35 | 36 | "decoder": 37 | { 38 | "blocks": 4, 39 | "first_deconv_size": [256, 8, 8], 40 | "channels": [128, 64, 32, 1], 41 | "kernel_size": [5, 
5, 5, 5], 42 | "stride": [2, 2, 2, 2], 43 | "padding": [2, 2, 2, 2], 44 | "up": "upsample" 45 | } 46 | } -------------------------------------------------------------------------------- /configs/DBE_model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "DBE", 3 | "pose": 8, 4 | "content": 1024, 5 | "frames_per_clip": 20, 6 | 7 | "state_dim": 8, 8 | "gaussian_dim": 4, 9 | "n_clusters": 30, 10 | "straight_through": true, 11 | "ct_cond": false, 12 | "independent_cluster": false, 13 | "different_vars": false, 14 | "context_substraction": false, 15 | "tau": 1, 16 | "bidirectional": true, 17 | "rnn_layers": 1, 18 | "beta": [0.00005, 0.00005, 0.00001], 19 | "n_past": 10, 20 | 21 | "encoder_ps": 22 | { 23 | "blocks": 4, 24 | "input_size": [128, 128, 1], 25 | "channels": [32, 64, 128, 256], 26 | "kernel_size": [5, 5, 5, 5], 27 | "stride": [1, 1, 1, 1], 28 | "padding": [2, 2, 2, 2], 29 | "last_conv_size": [8, 8, 256], 30 | "mean_pool": true, 31 | "down": "maxpool", 32 | "motion_only": false 33 | }, 34 | 35 | "encoder_ct": 36 | { 37 | "blocks": 4, 38 | "input_size": [128, 128, 1], 39 | "channels": [32, 64, 128, 256], 40 | "kernel_size": [5, 5, 5, 5], 41 | "stride": [1, 1, 1, 1], 42 | "padding": [2, 2, 2, 2], 43 | "last_conv_size": [8, 8, 256], 44 | "mean_pool": false, 45 | "down": "conv", 46 | "motion_only": false 47 | }, 48 | 49 | "decoder": 50 | { 51 | "blocks": 4, 52 | "first_deconv_size": [256, 8, 8], 53 | "channels": [128, 64, 32, 1], 54 | "kernel_size": [5, 5, 5, 5], 55 | "stride": [2, 2, 2, 2], 56 | "padding": [2, 2, 2, 2], 57 | "up": "upsample" 58 | } 59 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DBE: Disentangled-Behavior-Embedding 2 | 3 | Official implementation of [Learning Disentangled Behavior Embeddings (NeurIPS 2021)](https://openreview.net/forum?id=ThbM9_6DNU). 4 | 5 | ## Environment requirements 6 | The whole experiment pipeline is based on PyTorch: 7 | - PyTorch 1.3.0 8 | - Torchvision 0.4.1 9 | 10 | Note that this is neither the minimum required version nor the latest one. 11 | Other versions may work just as well. 12 | 13 | ## Create video dataset 14 | 1. Add the train and test video paths to "./configs/split.json". 15 | 2. Note that the videos are assumed to be multiview, with 2 views concatenated horizontally. 16 | 17 | ## Train your models 18 | There are multiple training scripts in this repo as we are trying different models. 19 | 1. The model parameters are stored in "./configs/model_name_model_config.json". To change the architecture, edit the json file. The model config file will be saved for each run together with the training hyperparameters. 20 | 2. Video frames are resized to 128 by 128 by default. Use the argparse options to specify other data-related settings, e.g. frame rate, frames per clip, crop range, etc. 21 | 3. Run a bash command to train a DBE model. An example: 22 | 23 | ```bash 24 | python3 train_dbe.py -n name_of_exp -u id_of_gpu -l recon -bs num_of_batch_size -ep num_of_epochs --lr num_of_lr -fpc frame_per_clip 25 | ``` 26 | The choice of these parameters depends on your computing power. A batch size larger than 8 is recommended. The suitable number of epochs depends on the size of the dataset. 27 | 4. Results will be saved in "outputs/name_of_exp". 28 | 29 | ## Evaluate your models 30 | 1. Evaluate the trained model by specifying the name of the experiment. 
Note that the time of the experiment being created is added as the prefix of the experiment name before training. 31 | 32 | ```bash 33 | python3 evaluate.py -u id_of_gpu -bs num_of_batch_size -n time-name_of_exp -fpc frame_per_clip -md test 34 | ``` 35 | 2. The latent embeddings will be saved in the experiment directory. For DBE model, behavioral state estimation will also be saved. 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | 4 | def parse_args(): 5 | 6 | parser = argparse.ArgumentParser(description='PyTorch Video Compression Training') 7 | 8 | parser.add_argument('-cf', '--config-dir', default='../DBE-Disentangled-Behavior-Embedding/configs', help='config dir') 9 | parser.add_argument('-sd', '--save-dir', default='../DBE-Disentangled-Behavior-Embedding/outputs', help='path to save') 10 | parser.add_argument('-n', '--name', default='', help='folder in save path') 11 | parser.add_argument('-r', '--resume', dest="resume", help='resume from checkpoint', action="store_true") 12 | parser.add_argument('-md', '--mode', default='test', help='test mode') 13 | 14 | parser.add_argument('--rnn_layer', default=1, type=int, metavar='N', help='hidden layer of rnn') 15 | parser.add_argument('--rnn_hidden', default=8, type=int, metavar='N', help='hidden dim of rnn') 16 | parser.add_argument('-l', '--criterion', default='recon', help='training criterion') 17 | 18 | parser.add_argument('-fpc', '--frames_per_clip', default=50, type=int, metavar='N', help='# of frames per clip') 19 | parser.add_argument('-fps', '--frame_rate', default=1, type=int, metavar='N', help='# of frames per second') 20 | parser.add_argument('-fs', '--frame_size', default=128, type=int, metavar='N', help='# of frames dimensions') 21 | parser.add_argument('-np', '--n_past', default=10, type=int, metavar='N', help='# of past frames to infer z0 (DBE only)') 22 | parser.add_argument('--start', default=750, help='selected interval of videos') 23 | parser.add_argument('--end', default=1150, help='selected interval of videos') 24 | 25 | parser.add_argument('-d', '--device', default='cuda', help='device') 26 | parser.add_argument('-u', '--gpus', default='', help='index of specified gpus') 27 | parser.add_argument('-bs', '--batch-size', default=8, type=int) 28 | parser.add_argument('-ep', '--epochs', default=50, type=int, metavar='N', help='# of total epochs') 29 | parser.add_argument('-wp', '--warmups', default=50, type=int, metavar='N', help='# of total warm up epochs') 30 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='# of data loaders (default: 16)') 31 | parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate (default: 0.001)') 32 | parser.add_argument('--lr-step', default=10000, type=int, help='decrease lr every these iterations') 33 | parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') 34 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M') 35 | parser.add_argument('--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)', dest='weight_decay') 36 | parser.add_argument('--anneal', default=0, type=float, 
help='turning point for linear annealing') 37 | parser.add_argument('--alpha', default=1.0, type=float, help='ratio of rnn loss') 38 | parser.add_argument('--validate-freq', default=1, type=int, help='validation frequency') 39 | 40 | args = parser.parse_args() 41 | 42 | return args -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pickle 3 | from glob import glob 4 | from tqdm import tqdm 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | from utils.options import * 12 | from utils.videos import * 13 | from utils.criterions import * 14 | from utils.analysis import * 15 | 16 | 17 | def evaluate(args): 18 | 19 | save_dir = os.path.join(args.save_dir, args.name) 20 | assert os.path.exists(save_dir) 21 | sys.path.append(save_dir) 22 | 23 | # model building 24 | saved_model_config = glob('{}/*_model_config.json'.format(save_dir))[0] 25 | with open(saved_model_config, 'r') as f: 26 | model_config = json.load(f) 27 | model_config["frames_per_clip"] = args.frames_per_clip 28 | use_first_frame = model_config["first_frame"] 29 | 30 | # data loading 31 | transform = ResizeVideo(args.frame_size) 32 | mod = args.mode 33 | eval_set = BehavioralVideo(args.frames_per_clip, frame_rate=args.frame_rate, transform=transform, mod=mod, rand_crop=False, jitter=False, return_first_frame=use_first_frame) 34 | eval_loader = DataLoader(eval_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 35 | 36 | # criterion & device 37 | device = args.device 38 | 39 | print('Model: {}'.format(model_config["model"])) 40 | from models.nets import BehaveNet, DisAE, DBE 41 | latest_model = eval(model_config["model"])(model_config) 42 | 43 | latest_dict = torch.load('{}/latest_model.pth'.format(save_dir), map_location='cpu') 44 | latest_model.load_state_dict(latest_dict) 45 | 46 | # read criterion 47 | criterion = globals()[model_config["criterion"]] 48 | 49 | if args.device == 'cuda': 50 | if args.gpus: 51 | device_ids = [int(idx) for idx in list(args.gpus)] 52 | device = '{}:{}'.format(args.device, device_ids[0]) 53 | latest_model = nn.DataParallel(latest_model, device_ids=device_ids).to(device) 54 | else: 55 | device = torch.device('cuda:0') 56 | latest_model = nn.DataParallel(latest_model).to(device) 57 | elif args.device == 'cpu': 58 | device = torch.device('cpu') 59 | 60 | with open('{}/{}_trials.pkl'.format(save_dir, mod), 'wb') as f: 61 | pickle.dump(eval_set.trials, f) 62 | 63 | latest_model.eval() 64 | 65 | print('start evaluating...') 66 | latents, contents, states = [], [], [] 67 | loss_track = [] 68 | with torch.no_grad(): 69 | for front, side in eval_loader: 70 | 71 | front, side = front.to(device), side.to(device) 72 | if model_config["model"] == 'DBE': 73 | try: 74 | latest_model.set_steps(args.n_past, args.frames_per_clip-args.n_past) 75 | except: 76 | latest_model.module.set_steps(args.n_past, args.frames_per_clip-args.n_past) 77 | (output1, output2), _, _, probs = latest_model(front, side) 78 | states.append(probs[0].detach()) 79 | else: 80 | output1, output2 = latest_model(front, side) 81 | 82 | if use_first_frame: 83 | loss = criterion(output1, front[:, 1:]) + criterion(output2, side[:, 1:]) 84 | else: 85 | loss = criterion(output1, front) + criterion(output2, side) 86 | loss_track.append(loss.item()) 87 | 88 | try: 89 | 
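# the model may be wrapped in nn.DataParallel, so fall back to .module below when the custom .latent attribute is not exposed by the wrapper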
latents.append(latest_model.latent.detach()) 90 | except: 91 | latents.append(latest_model.module.latent.detach()) 92 | 93 | print('Saving latent embedding...') 94 | save_file(latents, file_dir='{}/{}_latents.pkl'.format(save_dir, mod)) 95 | if len(states) > 0: 96 | save_file(states, file_dir='{}/{}_states.pkl'.format(save_dir, mod)) 97 | 98 | print('loss: ', torch.mean(torch.FloatTensor(loss_track))) 99 | 100 | 101 | if __name__ == "__main__": 102 | 103 | args = parse_args() 104 | evaluate(args) -------------------------------------------------------------------------------- /utils/analysis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import cv2 5 | import pickle 6 | import datetime 7 | from tqdm import tqdm 8 | 9 | 10 | def save_file(buffer, file_dir): 11 | buffer = torch.cat(buffer, dim=0).cpu().numpy() 12 | with open(file_dir, 'wb') as f: 13 | pickle.dump(buffer, f) 14 | 15 | def convert_date(date_str): 16 | return datetime.date.strftime(datetime.datetime.strptime(date_str, '%Y-%m-%d'), "%b%d") 17 | 18 | def generate_labels(gt, classes, full_length=2400): 19 | labels = np.zeros((len(gt), full_length), dtype=int) 20 | for i, (n, t) in enumerate(gt.items()): 21 | for j, k in enumerate(classes): 22 | if k in t.keys(): 23 | labels[i, t[k]] = j+1 24 | return labels 25 | 26 | def get_states(all_states, target_trials, gt): 27 | # get the state estimation of the labeled video frames 28 | all_states = all_states.argmax(2).copy() 29 | s = [] 30 | trial_nums = [k.rsplit('/', 1)[1] for k in target_trials] 31 | for j, (k, v) in enumerate(gt.items()): 32 | idx = k.rsplit('_', 1)[1] 33 | if idx in trial_nums: 34 | s.append(all_states[trial_nums.index(idx)]) 35 | return np.stack(s) 36 | 37 | def generate_mapping(states, labels, n_states=None, n_labels=None): 38 | # generate a mapping from the state estimation to ground truth labeling 39 | assert states.shape==labels.shape 40 | mapping = [] 41 | if n_states is None: 42 | n_states = states.max()+1 43 | if n_labels is None: 44 | n_labels = labels.max()+1 45 | for i in range(n_states): 46 | tar, max_count = 0, 0 47 | for j in range(1, n_labels): 48 | count = labels[(labels==j)&(states==i)].sum() 49 | if count > max_count: 50 | max_count, tar = count, j 51 | mapping.append(tar) 52 | states_map = states.copy() 53 | for i, c in enumerate(mapping): 54 | states_map[states==i] = c 55 | return states_map, mapping 56 | 57 | def gather_motifs(trials, states, tar, start=750, resolution=400, win=0, video_dir='/data2/changhao/Dataset/Behavioral-Videos/Videos/'): 58 | images = [] 59 | info = [] 60 | for trial, indices in tqdm(zip(trials, states.argmax(-1))): 61 | img = np.zeros((resolution, resolution*2, 3)) 62 | for k in range(1, indices.shape[0]-1): 63 | if indices[k-1] != tar and indices[k] == tar: 64 | k0 = k 65 | i0 = cv2.imread('{}/img_{:04d}.jpg'.format(os.path.join(video_dir, trial), start+(k+1)+win//2), 0) 66 | img[:, :, 1] = i0 67 | img[:, :, 2] = i0 68 | if indices[k-1] == tar and indices[k] != tar: 69 | if k - k0 <=3: 70 | img = np.zeros((resolution, resolution*2, 3)) 71 | continue 72 | i1 = cv2.imread('{}/img_{:04d}.jpg'.format(os.path.join(video_dir, trial), start+(k+1)+win//2), 0) 73 | img[:, :, 0] = i1 74 | images.append(img) 75 | info.append([trial, k0, k]) 76 | if img[:, :, 0].sum() > 0: 77 | img = np.zeros((resolution, resolution*2, 3)) 78 | return images, info 79 | 80 | def dlc_regression(latents, trials, dlcs, dlc_trials, markers, session): 81 | 
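# fit one linear regression per DLC marker from the learned latents to the 2-D keypoint coordinates, using only frames whose DLC likelihood exceeds 0.95; returns the per-marker predictions and R^2 scores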
assert len(latents)==len(trials) 82 | dlc_trial_nums = [n.rsplit('_', 1)[1] for n in dlc_trials] 83 | dlc_indices = [dlc_trial_nums.index(t.rsplit('/', 1)[1]) for t in trials if session in t] 84 | dlc_session = dlcs[dlc_indices] 85 | X = np.stack([l for l, t in zip(list(latents), trials) if session in t], axis=0) 86 | Y = dlc_session[:, 750:1150] 87 | score, dlc_reg = [], [] 88 | from sklearn.linear_model import LinearRegression 89 | for i in range(len(markers)): 90 | mask = (Y[:, :, 3*i+2].flatten()>0.95) 91 | x = X.reshape((-1, X.shape[-1]))[mask] 92 | y = Y[:, :, 3*i:3*i+2].reshape((-1, 2))[mask] 93 | reg = LinearRegression().fit(x, y) 94 | y_hat = reg.predict(X.reshape((-1, X.shape[-1]))).reshape((*X.shape[:2], 2)) 95 | score.append(reg.score(x, y)) 96 | dlc_reg.append(y_hat) 97 | return dlc_reg, score -------------------------------------------------------------------------------- /utils/videos.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torchvision import io 4 | from torchvision.datasets.video_utils import VideoClips 5 | from torchvision.transforms.functional import to_tensor 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | # from utils import * 9 | # from utils.transforms import * 10 | 11 | import os 12 | import time 13 | import json 14 | import pickle 15 | import h5py 16 | import random 17 | import cv2 18 | 19 | 20 | class BehavioralVideo(Dataset): 21 | """ 22 | Assuming video format = cat[view1, view2] 23 | """ 24 | def __init__(self, frames_per_clip, frame_rate=4, resolution=400, anchor=600, interval=[750,1150], grayscale=True, 25 | mod='train', transform=None, rand_crop=False, jitter=True, return_first_frame=False, first_frame_jitter_range=15): 26 | self.mod = mod 27 | self.step = frame_rate 28 | self.fpc = frames_per_clip 29 | self.size = resolution 30 | self.anchor = anchor 31 | self.interval = interval 32 | self.rand_crop = rand_crop 33 | self.jitter = jitter 34 | self.return_first_frame = return_first_frame 35 | self.first_frame_jitter_range = first_frame_jitter_range 36 | self.grayscale = grayscale 37 | 38 | with open("./configs/split.json", "rb") as f: 39 | self.trials = json.load(f)[mod] 40 | self.transforms = torchvision.transforms.Compose([transform]) 41 | print('Total {} trials: '.format(mod), len(self.trials)) 42 | 43 | def get_clip(self, trial_dir, rand_crop=False, jitter=False): 44 | 45 | # read from raw videos 46 | vidcap = cv2.VideoCapture(trial_dir) 47 | raw_video_flow, success = [], True 48 | while success: 49 | success, frame = vidcap.read() 50 | raw_video_flow.append(frame) 51 | raw_video_flow = raw_video_flow[:-1] 52 | if self.grayscale: 53 | raw_video_flow = [frame[:, :, 0] for frame in raw_video_flow] 54 | 55 | # random crop 56 | rand = random.randint(-self.step, self.step) if jitter else 0 57 | if rand_crop: 58 | start = self.interval[0] + random.randint(0, self.interval[1]-self.interval[0]-self.fpc*self.step) 59 | else: 60 | start = self.interval[0] 61 | start, end = rand+start, rand+start+self.fpc*self.step 62 | clip_indices = range(start, end, self.step) 63 | clip = [to_tensor(raw_video_flow[idx]) for idx in clip_indices] 64 | clip = torch.stack(clip, dim=0) 65 | 66 | # read context frames 67 | if self.return_first_frame: 68 | first_frame_jitter = random.randint(-self.first_frame_jitter_range, self.first_frame_jitter_range) if jitter else 0 69 | anchor = max(0, self.anchor+first_frame_jitter) 70 | first_frame = to_tensor(raw_video_flow[anchor]) 71 | 
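# prepend the (optionally jittered) context frame so that index 0 of the returned clip is the anchor frame used as the first-frame/content input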
clip = torch.cat([first_frame.unsqueeze(0), clip], dim=0) 72 | 73 | return clip 74 | 75 | def __len__(self): 76 | return len(self.trials) 77 | 78 | def __getitem__(self, idx): 79 | 80 | clip = self.get_clip(self.trials[idx], rand_crop=self.rand_crop, jitter=self.jitter) 81 | front, side = clip[:, :, :, self.size:], clip[:, :, :, :self.size] 82 | 83 | if self.transforms is not None: 84 | front, side = self.transforms(front), self.transforms(side) 85 | 86 | return front, side 87 | 88 | 89 | class ResizeVideo(object): 90 | """ 91 | Resize the tensors 92 | Args: 93 | target_size: int or tuple 94 | """ 95 | def __init__(self, target_size, interpolation_mode='bilinear'): 96 | assert isinstance(target_size, int) or (isinstance(target_size, Iterable) and len(target_size) == 2) 97 | self.target_size = target_size 98 | self.interpolation_mode = interpolation_mode 99 | 100 | def __call__(self, clip): 101 | """ 102 | Args: 103 | clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) / (T, C, H, W) 104 | """ 105 | return resize(clip, self.target_size, self.interpolation_mode) 106 | 107 | def __repr__(self): 108 | return self.__class__.__name__ 109 | 110 | def resize(clip, target_size, interpolation_mode): 111 | # assert len(target_size) == 2, "target size should be tuple (height, width)" 112 | return torch.nn.functional.interpolate( 113 | clip, size=target_size, mode=interpolation_mode, align_corners=True 114 | ) 115 | -------------------------------------------------------------------------------- /utils/criterions.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def recon(output, view, normalize=False): 7 | return F.mse_loss(output, view) 8 | 9 | # def pred(output, view, normalize=False): 10 | # if normalize: 11 | # beta = view.std((0,1)) 12 | # return F.mse_loss(output[:, :-1]/beta, view[:, 1:]/beta) 13 | # if output.shape[1] == view.shape[1]: 14 | # return F.mse_loss(output[:, :-1], view[:, 1:]) 15 | # elif output.shape[1] == view.shape[1]-1: 16 | # return F.mse_loss(output, view[:, 1:]) 17 | 18 | def ct_sim(content): 19 | return F.mse_loss(content.mean(dim=1, keepdim=True), content) 20 | 21 | def semi_cls(output, states, print_acc=False): 22 | # states: list of dicts 23 | labeled_frames, tar_states = get_labeled_frames(output, states) 24 | if len(labeled_frames) == 0: 25 | return torch.tensor(0.) 
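# otherwise stack the labeled frames and score them against their target states with cross-entropy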
26 | labeled_frames = torch.cat(labeled_frames, dim=0) 27 | tar_states = torch.LongTensor(tar_states).to(output.device) 28 | if print_acc: 29 | acc = (labeled_frames.argmax(dim=1) == tar_states).sum().float() / len(tar_states) 30 | print('state acc: ', acc.item()) 31 | return F.cross_entropy(labeled_frames, tar_states) 32 | 33 | def get_labeled_frames(output, states): 34 | labeled_frames, tar_states = [], [] 35 | for i, vid in enumerate(states): 36 | for tar_state, state_indices in vid.items(): 37 | labeled_frames.append(output[i, state_indices]) 38 | tar_states += [tar_state] * len(state_indices) 39 | return labeled_frames, tar_states 40 | 41 | def kl(mu, logvar): 42 | # kld of fixed prior 43 | mu = mu.flatten(end_dim=1) if mu.dim()==3 else mu 44 | logvar = logvar.flatten(end_dim=1) if logvar.dim()==3 else logvar 45 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 46 | KLD /= mu.shape[0] 47 | return KLD 48 | 49 | def kl_general(mu1, logvar1, mu2, logvar2): 50 | # kld of learned prior 51 | mu1 = mu1.flatten(end_dim=1) if mu1.dim()==3 else mu1 52 | mu2 = mu2.flatten(end_dim=1) if mu2.dim()==3 else mu2 53 | sigma1 = logvar1.mul(0.5).exp() 54 | sigma2 = logvar2.mul(0.5).exp() 55 | kld = torch.log(sigma2/sigma1) + (torch.exp(logvar1) + (mu1 - mu2)**2)/(2*torch.exp(logvar2)) - 1/2 56 | KLD = kld.sum() / mu1.shape[0] 57 | return KLD 58 | 59 | def kl_cat(prob): 60 | # kld of uniform prior 61 | prob = prob.flatten(end_dim=1) if prob.dim()==3 else prob 62 | c = prob.shape[1] 63 | KLD = torch.sum(prob * (prob * c).log()) 64 | KLD /= prob.shape[0] 65 | return KLD 66 | 67 | def kl_cat_log(prob): 68 | # kld of uniform prior 69 | prob = prob.flatten(end_dim=1) if prob.dim()==3 else prob 70 | c = prob.shape[1] 71 | KLD = torch.sum(prob.exp() * prob) 72 | KLD /= prob.shape[0] 73 | return KLD 74 | 75 | def kl_cat_general_log(prob1, prob2): 76 | # kld of learned prior 77 | prob1 = prob1.flatten(end_dim=1) if prob1.dim()==3 else prob1 78 | prob2 = prob2.flatten(end_dim=1) if prob2.dim()==3 else prob2 79 | KLD = torch.sum(prob1 * (prob1 / prob2).log()) 80 | KLD /= prob1.shape[0] 81 | return KLD 82 | 83 | def kl_cat_general_log(prob1, prob2): 84 | # kld of learned prior probs are log_softmax 85 | prob1 = prob1.flatten(end_dim=1) if prob1.dim()==3 else prob1 86 | prob2 = prob2.flatten(end_dim=1) if prob2.dim()==3 else prob2 87 | KLD = torch.sum(prob1.exp() * (prob1 - prob2)) 88 | KLD /= prob1.shape[0] 89 | return KLD 90 | 91 | def vq(mu, quantized, commitment_cost=0.25): 92 | e_latent_loss = F.mse_loss(quantized.detach(), mu) 93 | q_latent_loss = F.mse_loss(quantized, mu.detach()) 94 | return q_latent_loss + commitment_cost * e_latent_loss 95 | 96 | def kmeans(centroids): 97 | # encourage separation between centroids (vector quantization) 98 | z_dim = centroids.shape[1] 99 | return -torch.cdist(centroids, centroids, p=1).mean() * z_dim / 2 100 | # return -(torch.cdist(centroids, centroids, p=2)**2).mean() * z_dim / 2 101 | 102 | def svd(z, k=30, lmbda=1): 103 | batch_size = z.shape[0] 104 | gram_matrix = (z @ z.T) / batch_size 105 | _ ,sv_2, _ = torch.svd(gram_matrix) 106 | sv = torch.sqrt(sv_2[:k]) 107 | return lmbda * torch.sum(sv) 108 | 109 | def beta_annealing(epoch, step=1000): 110 | return (epoch+1)/step if epoch 0: 137 | anneal = beta_annealing(epoch, args.anneal) 138 | loss += (klg_loss * beta[0] + kls_loss * beta[1] + klc_loss * beta[2]) * anneal 139 | else: 140 | loss += klg_loss * beta[0] + kls_loss * beta[1] + klc_loss * beta[2] 141 | 142 | optimizer.zero_grad() 143 | loss.backward() 
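# gradients of the total loss (reconstruction criterion plus the beta-weighted, optionally annealed KL terms above) are applied by the optimizer step that follows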
144 | optimizer.step() 145 | 146 | print('train loss: {}, kl loss: {}, kls loss: {}, klc loss: {}'.format(sum(meter)/len(meter), sum(klg_meter)/len(klg_meter), sum(kls_meter)/len(kls_meter), sum(klc_meter)/len(klc_meter))) 147 | 148 | scheduler.step() 149 | # model validation 150 | if (epoch+1) % args.validate_freq == 0: 151 | 152 | print('save model') 153 | 154 | # save checkpoint 155 | try: 156 | torch.save({ 157 | 'epoch': epoch, 158 | 'model_state_dict': model.state_dict(), 159 | 'optimizer_state_dict': optimizer.state_dict(), 160 | }, '{}/checkpoint.pth'.format(save_dir)) 161 | except: 162 | torch.save({ 163 | 'epoch': epoch, 164 | 'model_state_dict': model.module.state_dict(), 165 | 'optimizer_state_dict': optimizer.state_dict(), 166 | }, '{}/checkpoint.pth'.format(save_dir)) 167 | 168 | try: 169 | torch.save(model.state_dict(), '{}/latest_model.pth'.format(save_dir)) 170 | except: 171 | torch.save(model.module.state_dict(), '{}/latest_model.pth'.format(save_dir)) 172 | file.close() 173 | 174 | 175 | if __name__ == "__main__": 176 | 177 | args = parse_args() 178 | print('exp: ', args.name) 179 | train(args) -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | 7 | def conv_block(in_channels, out_channels, kernel_size=5, stride=1, padding=2, down='maxpool'): 8 | ''' 9 | returns a conv block conv-bn-relu-pool 10 | set return indices to true for unpooling later 11 | ''' 12 | if down == 'conv': 13 | stride = 2 14 | return nn.Sequential( 15 | nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding), 16 | Down2dWrapper(down), 17 | nn.BatchNorm2d(out_channels), 18 | ) 19 | 20 | 21 | def deconv_block(in_channels, out_channels, kernel_size=5, stride=1, padding=2, up='unpool', last_layer=False): 22 | ''' 23 | returns a deconv block conv-bn-relu-unpool 24 | ''' 25 | return nn.Sequential( 26 | Up2dWrapper(up), 27 | nn.ConvTranspose2d(in_channels, out_channels*(2**2) if up=='subpixel' else out_channels, kernel_size, stride=stride, padding=padding), # asymmetric padding? 
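# with the default kernel_size=5, stride=1, padding=2 this transposed conv preserves spatial size; the 2x upsampling is done by Up2dWrapper above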
28 | nn.BatchNorm2d(out_channels), 29 | ) 30 | 31 | 32 | class Down2dWrapper(nn.Module): 33 | ''' 34 | workaround for sequential not taking multiple inputs 35 | ''' 36 | def __init__(self, down='maxpool'): 37 | 38 | super(Down2dWrapper, self).__init__() 39 | self.down = down 40 | if down == 'maxpool': 41 | self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False) 42 | elif down == 'maxpool_unpool': 43 | self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=True) 44 | elif down == 'conv': 45 | pass 46 | else: 47 | raise ValueError("Downsample method not supported!") 48 | 49 | def forward(self, x): 50 | if self.down == 'maxpool_unpool': 51 | x, pool_idx = self.maxpool(x) 52 | self.pool_idx = pool_idx 53 | return x 54 | elif self.down == 'maxpool': 55 | return self.maxpool(x) 56 | elif self.down == 'conv': 57 | return x 58 | 59 | 60 | class Up2dWrapper(nn.Module): 61 | ''' 62 | workaround for sequential not taking multiple inputs 63 | ''' 64 | def __init__(self, up='unpool'): 65 | 66 | super(Up2dWrapper, self).__init__() 67 | self.up = up 68 | if up == 'unpool': 69 | self.unpool = nn.MaxUnpool2d(2, stride=2) 70 | elif up == 'upsample': 71 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') 72 | elif up == 'subpixel': 73 | self.subpixel = nn.PixelShuffle(upscale_factor=2) 74 | else: 75 | raise ValueError("Upsample method not supported!") 76 | 77 | def forward(self, x): 78 | if self.up == 'unpool': 79 | x, pool_idx, hidden_size = x 80 | return self.unpool(x, pool_idx, hidden_size) 81 | elif self.up == 'upsample': 82 | return self.upsample(x) 83 | elif self.up == 'subpixel': 84 | return self.subpixel(x) 85 | 86 | 87 | class MLP(nn.Module): 88 | 89 | def __init__(self, in_dim, out_dim, hidden_dim=128, n_layer=2): 90 | super(MLP, self).__init__() 91 | self.linear1 = nn.Linear(in_dim, hidden_dim) 92 | self.linear2 = nn.Linear(hidden_dim, out_dim) 93 | 94 | def forward(self, x): 95 | x = F.relu(self.linear1(x)) 96 | x = self.linear2(x) 97 | return x 98 | 99 | class Gaussian_MLP(nn.Module): 100 | 101 | def __init__(self, in_dim, out_dim, hidden_dim=128, n_layer=2): 102 | super(Gaussian_MLP, self).__init__() 103 | self.linear = nn.Linear(in_dim, hidden_dim) 104 | self.mu_net = nn.Linear(hidden_dim, out_dim) 105 | self.logvar_net = nn.Linear(hidden_dim, out_dim) 106 | 107 | def reparameterize(self, mu, logvar): 108 | logvar = logvar.mul(0.5).exp_() 109 | eps = Variable(logvar.data.new(logvar.size()).normal_()) 110 | return eps.mul(logvar).add_(mu) 111 | 112 | def forward(self, x): 113 | h = F.relu(self.linear(x)) 114 | mu = self.mu_net(h) 115 | logvar = self.logvar_net(h) 116 | if self.training: 117 | z = self.reparameterize(mu, logvar) 118 | return z, mu, logvar 119 | return mu, mu, logvar 120 | 121 | class Recurrent(nn.Module): 122 | 123 | def __init__(self, in_dim, out_dim, hidden_dim=128, n_layer=1, mod='GRU'): 124 | super(Recurrent, self).__init__() 125 | self.rnn = getattr(nn, mod)(in_dim, hidden_dim, n_layer, batch_first=True) 126 | self.linear = nn.Linear(hidden_dim, out_dim) 127 | 128 | def forward(self, x): 129 | h, _ = self.rnn(x) 130 | x = self.linear(h) 131 | return x 132 | 133 | class Recurrent_Cell(nn.Module): 134 | 135 | def __init__(self, in_dim, out_dim, hidden_dim=128, n_layer=1, mod='GRU'): 136 | super(Recurrent_Cell, self).__init__() 137 | self.mod = mod 138 | self.n_layer = n_layer 139 | self.rnn_cell = nn.ModuleList([getattr(nn, '{}Cell'.format(mod))(in_dim if i==0 else hidden_dim, hidden_dim, n_layer, batch_first=True) for i in range(n_layer)]) 140 | 
self.linear = nn.Linear(hidden_dim, out_dim) 141 | 142 | def init_hidden(self, device): 143 | hidden = [] 144 | for i in range(self.n_layer): 145 | hidden.append((Variable(torch.zeros(self.batch_size, self.hidden_size).to(device)), 146 | # Variable(torch.zeros(self.batch_size, self.hidden_size).to(device)) 147 | )) 148 | return hidden 149 | 150 | def forward(self, x): 151 | h = x 152 | for i in range(self.n_layer): 153 | self.hidden[i] = self.rnn_cell[i](h, self.hidden[i]) 154 | h = self.hidden[i][0] 155 | x = self.linear(h) 156 | return x 157 | 158 | 159 | class Gaussian_Recurrent(nn.Module): 160 | 161 | def __init__(self, in_dim, out_dim, hidden_dim=128, n_layer=1, mod='GRU'): 162 | super(Gaussian_Recurrent, self).__init__() 163 | self.rnn = getattr(nn, mod)(in_dim, hidden_dim, n_layer, batch_first=True) 164 | self.mu_net = nn.Linear(hidden_dim, out_dim) 165 | self.logvar_net = nn.Linear(hidden_dim, out_dim) 166 | 167 | def reparameterize(self, mu, logvar): 168 | logvar = logvar.mul(0.5).exp_() 169 | eps = Variable(logvar.data.new(logvar.size()).normal_()) 170 | return eps.mul(logvar).add_(mu) 171 | 172 | def forward(self, x): 173 | h, _ = self.rnn(x) 174 | mu = self.mu_net(h) 175 | logvar = self.logvar_net(h) 176 | if self.training: 177 | z = self.reparameterize(mu, logvar) 178 | return z, mu, logvar 179 | return mu, mu, logvar 180 | 181 | class Gaussian_Recurrent_Cell(nn.Module): 182 | 183 | def __init__(self, in_dim, out_dim, hidden_dim=128, n_layer=1, mod='GRU'): 184 | super(Gaussian_Recurrent_Cell, self).__init__() 185 | self.mod = mod 186 | self.n_layer = n_layer 187 | self.rnn_cell = nn.ModuleList([getattr(nn, '{}Cell'.format(mod))(in_dim if i==0 else hidden_dim, hidden_dim, n_layer, batch_first=True) for i in range(n_layer)]) 188 | self.mu_net = nn.Linear(hidden_dim, out_dim) 189 | self.logvar_net = nn.Linear(hidden_dim, out_dim) 190 | 191 | def init_hidden(self, device): 192 | hidden = [] 193 | for i in range(self.n_layer): 194 | hidden.append((Variable(torch.zeros(self.batch_size, self.hidden_size).to(device)), 195 | # Variable(torch.zeros(self.batch_size, self.hidden_size).to(device)) 196 | )) 197 | return hidden 198 | 199 | def reparameterize(self, mu, logvar): 200 | logvar = logvar.mul(0.5).exp_() 201 | eps = Variable(logvar.data.new(logvar.size()).normal_()) 202 | return eps.mul(logvar).add_(mu) 203 | 204 | def forward(self, x): 205 | h = x 206 | for i in range(self.n_layer): 207 | self.hidden[i] = self.rnn_cell[i](h, self.hidden[i]) 208 | h = self.hidden[i][0] 209 | mu = self.mu_net(h) 210 | logvar = self.logvar_net(h) 211 | if self.training: 212 | z = self.reparameterize(mu, logvar) 213 | return z, mu, logvar 214 | return mu, mu, logvar 215 | 216 | 217 | class GM_Recurrent(nn.Module): 218 | 219 | def __init__(self, in_dim, out_dim, n_clusters, tau=1, hidden_dim=128, n_layer=1, mod='GRU', bid=False): 220 | super(GM_Recurrent, self).__init__() 221 | self.mod = mod 222 | self.tau = tau 223 | self.n_layer = n_layer 224 | self.n_clusters = n_clusters 225 | self.rnn = getattr(nn, mod)(in_dim, hidden_dim, n_layer, batch_first=True, bidirectional=bid) 226 | 227 | if bid: 228 | self.mu_net = nn.Linear(hidden_dim*2, out_dim) 229 | self.logvar_net = nn.Linear(hidden_dim*2, out_dim) 230 | 231 | self.c_net = nn.Linear(hidden_dim*2, n_clusters) 232 | else: 233 | self.mu_net = nn.Linear(hidden_dim, out_dim) 234 | self.logvar_net = nn.Linear(hidden_dim, out_dim) 235 | 236 | self.c_net = nn.Linear(hidden_dim, n_clusters) 237 | 238 | def reparameterize_gumbel(self, p): 239 | # 
Gumbel-Softmax for discrete sampling 240 | sampled_one_hot = F.gumbel_softmax(p, tau=self.tau, hard=True, dim=-1) 241 | return sampled_one_hot 242 | 243 | def reparameterize_gaussian(self, mu, logvar): 244 | # Gaussian sampling 245 | logvar = logvar.mul(0.5).exp_() 246 | eps = Variable(logvar.data.new(logvar.size()).normal_()) 247 | return eps.mul(logvar).add_(mu) 248 | 249 | def forward(self, x): 250 | h, _ = self.rnn(x) 251 | # gaussian inference 252 | mu = self.mu_net(h) 253 | logvar = self.logvar_net(h) 254 | # gumbel inference 255 | p = self.c_net(h) 256 | prob = F.log_softmax(p, dim=-1) 257 | if self.training: 258 | z = self.reparameterize_gaussian(mu, logvar) 259 | sampled_one_hot = self.reparameterize_gumbel(p) 260 | return z, mu, logvar, sampled_one_hot, prob 261 | size = prob.shape 262 | prob = prob.flatten(end_dim=1) 263 | one_hot = torch.zeros_like(prob) 264 | one_hot[torch.arange(prob.shape[0]), prob.argmax(dim=-1)] = 1 265 | prob, one_hot = prob.view(*size[:2], self.n_clusters), one_hot.view(*size[:2], self.n_clusters) 266 | return mu, mu, logvar, one_hot, prob 267 | -------------------------------------------------------------------------------- /models/nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from functools import reduce 6 | 7 | from models.modules import * 8 | 9 | 10 | class Encoder(nn.Module): 11 | """Convolutional Encoder.""" 12 | 13 | def __init__(self, latent_dim, encoder_config): 14 | 15 | super(Encoder, self).__init__() 16 | self.config = encoder_config 17 | 18 | for k in range(self.config["blocks"]): 19 | if k == 0: 20 | in_channel, out_channel = self.config['input_size'][-1], self.config['channels'][0] 21 | else: 22 | in_channel, out_channel = self.config['channels'][k-1], self.config['channels'][k] 23 | setattr(self, 'conv{}'.format(k+1), conv_block(in_channel, out_channel, down=self.config["down"])) 24 | 25 | if self.config["mean_pool"]: 26 | self.fc_in_dim = self.config['last_conv_size'][-1] 27 | else: 28 | self.fc_in_dim = reduce(lambda a, b: a*b, self.config['last_conv_size']) 29 | self.fc = nn.Linear(self.fc_in_dim, latent_dim) 30 | 31 | self.relu = nn.LeakyReLU(0.2, inplace=True) 32 | 33 | def forward(self, x): 34 | 35 | x = x.reshape(-1, *x.shape[2:]) 36 | if self.config["down"] == 'maxpool_unpool': 37 | pool_idx, hidden_size = [], [x.shape] 38 | 39 | for k in range(self.config["blocks"]): 40 | x = getattr(self, 'conv{}'.format(k+1))(x) 41 | x = self.relu(x) 42 | if self.config["down"] == 'maxpool_unpool': 43 | pool_idx.append(getattr(self, 'conv{}'.format(k+1))[1].pool_idx) 44 | hidden_size.append(x.shape) 45 | 46 | if self.config["down"] == 'maxpool_unpool': 47 | return x, pool_idx, hidden_size 48 | return x 49 | 50 | 51 | class Decoder(nn.Module): 52 | """Convolutional Decoder.""" 53 | 54 | def __init__(self, latent_dim, decoder_config): 55 | 56 | super(Decoder, self).__init__() 57 | self.config = decoder_config 58 | 59 | fc_output = reduce(lambda a, b: a*b, self.config['first_deconv_size']) 60 | self.fc = nn.Linear(latent_dim, fc_output) 61 | self.relu = nn.LeakyReLU(0.5, inplace=True) 62 | self.unpool = (self.config["up"]=='unpool') 63 | 64 | for k in range(self.config["blocks"]): 65 | if k == 0: 66 | in_channel, out_channel = self.config['first_deconv_size'][0], self.config['channels'][0] 67 | else: 68 | in_channel, out_channel = self.config['channels'][k-1], self.config['channels'][k] 69 | 70 | if k != 
self.config["blocks"]-1: 71 | setattr(self, 'deconv{}'.format(k+1), deconv_block(in_channel, out_channel, up=self.config["up"])) 72 | else: 73 | setattr(self, 'deconv{}'.format(k+1), deconv_block(in_channel, out_channel, up=self.config["up"], last_layer=True)) 74 | 75 | def forward(self, x): 76 | 77 | if self.unpool: 78 | x, pool_idx, hidden_size = x 79 | x = x.view(x.shape[0], *self.config['first_deconv_size']) 80 | for k in range(self.config["blocks"]): 81 | if self.unpool: 82 | x = getattr(self, 'deconv{}'.format(k+1))((x, pool_idx[-k-1], hidden_size[-k-2])) 83 | else: 84 | x = getattr(self, 'deconv{}'.format(k+1))(x) 85 | if k < self.config["blocks"]-1: 86 | x = self.relu(x) 87 | else: 88 | x = torch.sigmoid(x) 89 | return x 90 | 91 | 92 | class BehaveNet(nn.Module): 93 | 94 | def __init__(self, config): 95 | super(BehaveNet, self).__init__() 96 | assert config["encoder"]["down"] == "maxpool_unpool" 97 | assert config["decoder"]["up"] == "unpool" 98 | assert config["encoder"]["input_size"][-1] == 2 99 | assert config["decoder"]["channels"][-1] == 2 100 | self.unpool = True 101 | self.config = config 102 | self.latent_dim = config["latent"] 103 | self.fpc = config["frames_per_clip"] 104 | 105 | self.encoder = Encoder(config["latent"], config["encoder"]) 106 | self.decoder = Decoder(config["latent"], config["decoder"]) 107 | self.dense1 = nn.Linear(reduce(lambda a, b: a*b, config["encoder"]['last_conv_size']), config["latent"]) 108 | self.dense2 = nn.Linear(config["latent"], reduce(lambda a, b: a*b, config["decoder"]['first_deconv_size'])) 109 | 110 | self.relu = nn.LeakyReLU(0.5, inplace=True) 111 | 112 | def forward(self, x1, x2): 113 | x = torch.cat([x1, x2], dim=2) 114 | x, pool_idx, hidden_size = self.encoder(x) 115 | x = self.relu(self.dense1(x.flatten(start_dim=1))) 116 | self.latent = x 117 | x = self.relu(self.dense2(x)) 118 | x = self.decoder((x, pool_idx, hidden_size)) 119 | x = x.view(-1, self.config["frames_per_clip"], *x.shape[1:]) 120 | return x[:, :, :1], x[:, :, 1:] 121 | 122 | 123 | class DisAE(nn.Module): 124 | """Multi-View Disentangled Convolutional Encoder-Decoder.""" 125 | 126 | def __init__(self, config): 127 | 128 | super(DisAE, self).__init__() 129 | if config["decoder"]["up"] == 'unpool': 130 | assert config["encoder"]["pool"] 131 | self.unpool = (config["decoder"]["up"] == 'unpool') 132 | self.config = config 133 | self.pose_dim = config["pose"] 134 | self.ct_dim = config["content"] 135 | self.fpc = config["frames_per_clip"] 136 | self.first_frame = config["first_frame"] 137 | 138 | # rollout / dynamic / pose encoder 139 | self.encoder1 = Encoder(config["pose"], config["encoder_ps"]) 140 | self.decoder1 = Decoder(config["pose"], config["decoder"]) 141 | 142 | self.encoder2 = Encoder(config["pose"], config["encoder_ps"]) 143 | self.decoder2 = Decoder(config["pose"], config["decoder"]) 144 | 145 | self.view_fuse = nn.Conv2d(256*2, config["pose"], 8, stride=1) 146 | 147 | # context / style / content encoder 148 | self.encoder1_ct = Encoder(config["content"], config["encoder_ct"]) 149 | self.encoder2_ct = Encoder(config["content"], config["encoder_ct"]) 150 | self.ct_fuse = nn.Conv2d(256*2, config["content"], 1, stride=1) 151 | 152 | # combine pose & content 153 | self.combine = nn.Conv2d(config["pose"]+config["content"], 256*2, 1, stride=1) 154 | 155 | self.relu = nn.LeakyReLU(0.2, inplace=True) 156 | 157 | 158 | def forward(self, x1, x2): 159 | 160 | # context / style / content encoding 161 | if self.first_frame: 162 | c1, c2 = self.encoder1_ct(x1[:, :1]), 
self.encoder2_ct(x2[:, :1]) 163 | c = self.relu(self.ct_fuse(torch.cat([c1, c2], dim=1))) 164 | self.content = c 165 | c = c.unsqueeze(dim=1).repeat_interleave(self.fpc, dim=1) 166 | x1, x2 = x1[:, 1:], x2[:, 1:] 167 | else: 168 | c1, c2 = self.encoder1_ct(x1[:, :n_past]), self.encoder2_ct(x2[:, :n_past]) 169 | c = self.relu(self.ct_fuse(torch.cat([c1, c2], dim=1))) 170 | c = c.view(-1, n_past, *c.shape[1:]) 171 | self.content = c 172 | c = torch.cat([c[:, :-1], c[:, -1:].repeat_interleave(n_future, dim=1)], dim=1) 173 | 174 | # rollout / dynamic / pose encoding 175 | p1, p2 = self.encoder1(x1), self.encoder2(x2) 176 | p = self.relu(self.view_fuse(torch.cat([p1, p2], dim=1))) 177 | self.latent = p.view(-1, self.fpc, self.pose_dim) 178 | 179 | # rollout / dynamic / pose decoding 180 | p = self.relu(self.combine(torch.cat([p.repeat(1, 1, 8, 8), c.flatten(end_dim=1)], dim=1))) 181 | p1, p2 = p[:, :256], p[:, 256:] 182 | 183 | if self.unpool: 184 | p1 = self.decoder1((p1, pool_idx1, hidden_size1)) 185 | p2 = self.decoder2((p2, pool_idx2, hidden_size2)) 186 | else: 187 | p1, p2 = self.decoder1(p1), self.decoder2(p2) 188 | p1, p2 = p1.view(-1, self.fpc, *p1.shape[1:]), p2.view(-1, self.fpc, *p2.shape[1:]) 189 | 190 | return p1, p2 191 | 192 | 193 | class DBE(nn.Module): 194 | # assume uniform prior and same covariance 195 | def __init__(self, config): 196 | super(DBE, self).__init__() 197 | if config["decoder"]["up"] == 'unpool': 198 | assert config["encoder"]["pool"] 199 | self.unpool = (config["decoder"]["up"] == 'unpool') 200 | self.config = config 201 | self.pose_dim = config["pose"] 202 | self.ct_dim = config["content"] 203 | self.state_dim = config["state_dim"] 204 | self.fpc = config["frames_per_clip"] 205 | self.indep = config["independent_cluster"] 206 | self.diffvar = config["different_vars"] 207 | self.first_frame = config["first_frame"] 208 | self.cond = config["ct_cond"] 209 | self.straight_through = config["straight_through"] 210 | self.context_sub = config["context_substraction"] 211 | 212 | # rollout / dynamic / pose encoder 213 | self.encoder1 = Encoder(config["pose"], config["encoder_ps"]) 214 | self.decoder1 = Decoder(config["state_dim"], config["decoder"]) 215 | 216 | self.encoder2 = Encoder(config["pose"], config["encoder_ps"]) 217 | self.decoder2 = Decoder(config["state_dim"], config["decoder"]) 218 | 219 | self.view_fuse = nn.Conv2d(256*2, config["pose"], 8, stride=1) 220 | if config["ct_cond"]: 221 | self.ct_cond = nn.Conv2d(config["content"], config["cond_dim"], 8, stride=1) 222 | self.posterior = GM_Recurrent(in_dim=config["pose"]+config["cond_dim"], out_dim=config["gaussian_dim"], n_clusters=config["n_clusters"], n_layer=config["rnn_layers"], tau=config["tau"], bid=config["bidirectional"]) 223 | else: 224 | self.posterior = GM_Recurrent(in_dim=config["pose"], out_dim=config["gaussian_dim"], n_clusters=config["n_clusters"], n_layer=config["rnn_layers"], tau=config["tau"], bid=config["bidirectional"]) 225 | self.init_state_posterior = Gaussian_MLP(in_dim=config["pose"]*config["n_past"], out_dim=config["state_dim"]) 226 | self.state_model = nn.Linear(config["state_dim"]+config["gaussian_dim"], config["state_dim"]) 227 | nn.init.normal_(self.state_model.weight.data, 0.0, 0.01) 228 | nn.init.constant_(self.state_model.bias.data, 0.0) 229 | self.centroids = nn.Embedding(config["n_clusters"], config["gaussian_dim"]) 230 | if config["different_vars"]: 231 | self.logvars = nn.Embedding(config["n_clusters"], config["gaussian_dim"]) 232 | 233 | if not 
config["independent_cluster"]: 234 | self.c_net = nn.Linear(config["state_dim"], config["n_clusters"]) 235 | 236 | # context / style / content encoder 237 | self.encoder1_ct = Encoder(config["content"], config["encoder_ct"]) 238 | self.encoder2_ct = Encoder(config["content"], config["encoder_ct"]) 239 | if config["straight_through"]: 240 | self.ct_fuse = nn.Conv2d(256*2, config["content"], 1, stride=1) 241 | else: 242 | self.ct_fuse = nn.Conv2d(256*2, config["content"], 8, stride=1) 243 | 244 | # combine pose & content 245 | if config["straight_through"]: 246 | self.combine = nn.ConvTranspose2d(config["state_dim"]+config["content"], 256*2, 1, stride=1) 247 | else: 248 | self.combine = nn.ConvTranspose2d(config["state_dim"]+config["content"], 256*2, 8, stride=1) 249 | 250 | self.relu = nn.LeakyReLU(0.2, inplace=True) 251 | 252 | def set_steps(self, n_past, n_future): 253 | self.n_past, self.n_future = n_past, n_future 254 | 255 | def forward(self, x1, x2): 256 | 257 | n_past, n_future = self.n_past, self.n_future 258 | assert self.fpc == n_past + n_future 259 | if self.first_frame: 260 | c1, c2 = self.encoder1_ct(x1[:, :1]), self.encoder2_ct(x2[:, :1]) 261 | c = self.relu(self.ct_fuse(torch.cat([c1, c2], dim=1))) 262 | self.content = c 263 | c = c.unsqueeze(dim=1).repeat_interleave(self.fpc, dim=1) 264 | if self.context_sub: 265 | x1, x2 = x1[:, 1:]-x1[:, :1], x2[:, 1:]-x2[:, :1] 266 | else: 267 | x1, x2 = x1[:, 1:], x2[:, 1:] 268 | else: 269 | c1, c2 = self.encoder1_ct(x1[:, :n_past]), self.encoder2_ct(x2[:, :n_past]) 270 | c = self.relu(self.ct_fuse(torch.cat([c1, c2], dim=1))) 271 | c = c.view(-1, n_past, *c.shape[1:]) 272 | self.content = c 273 | c = torch.cat([c[:, :-1], c[:, -1:].repeat_interleave(n_future, dim=1)], dim=1) 274 | 275 | # rollout / dynamic / pose encoding 276 | p1, p2 = self.encoder1(x1), self.encoder2(x2) 277 | p = self.relu(self.view_fuse(torch.cat([p1, p2], dim=1))) 278 | p = p.view(-1, self.fpc, self.pose_dim) 279 | 280 | if self.cond: 281 | cc = self.relu(self.ct_cond(self.content))[:, :, 0, 0] 282 | g, mu, logvar, one_hot, prob1 = self.posterior(torch.cat([p, cc.unsqueeze(dim=1).repeat_interleave(self.fpc, dim=1)], dim=2)) 283 | else: 284 | g, mu, logvar, one_hot, prob1 = self.posterior(p) 285 | 286 | if self.diffvar: 287 | g = g * torch.matmul(one_hot, self.logvars.weight.mul(0.5).exp()) + torch.matmul(one_hot, self.centroids.weight) 288 | else: 289 | g = g + torch.matmul(one_hot, self.centroids.weight) 290 | self.cluster = one_hot.detach() 291 | 292 | s = [] 293 | s0, mu_s0, logvar_s0 = self.init_state_posterior(p[:, :n_past].flatten(start_dim=1)) 294 | si = s0 295 | s.append(si) 296 | for i in range(1, n_past+n_future): 297 | si = self.state_model(torch.cat([si, g[:, i]], dim=1)) 298 | s.append(si) 299 | s = torch.stack(s, dim=1) 300 | if not self.indep: 301 | pre_softmax = self.c_net(s[:, :-1]) 302 | if torch.any(torch.isnan(pre_softmax)): 303 | print('pre-softmax tensors contain nan!') 304 | if torch.any(torch.isinf(pre_softmax)): 305 | print('pre-softmax tensors contain inf!') 306 | prob2 = F.log_softmax(pre_softmax, dim=-1) 307 | p = s 308 | self.latent = p.detach() 309 | 310 | # rollout / dynamic / pose decoding 311 | if self.straight_through: 312 | p = self.relu(self.combine(torch.cat([p.view(-1, self.state_dim, 1, 1).repeat(1, 1, 8, 8), c.flatten(end_dim=1)], dim=1))) 313 | else: 314 | p = self.relu(self.combine(torch.cat([p.view(-1, self.state_dim, 1, 1), c.flatten(end_dim=1)], dim=1))) 315 | p1, p2 = p[:, :256], p[:, 256:] 316 | 317 | if self.unpool: 
318 | p1 = self.decoder1((p1, pool_idx1, hidden_size1)) 319 | p2 = self.decoder2((p2, pool_idx2, hidden_size2)) 320 | else: 321 | p1, p2 = self.decoder1(p1), self.decoder2(p2) 322 | p1, p2 = p1.view(-1, self.fpc, *p1.shape[1:]), p2.view(-1, self.fpc, *p2.shape[1:]) 323 | 324 | if not self.indep: 325 | return (p1, p2), (mu, logvar), (mu_s0, logvar_s0), (prob1[:, 1:], prob2) 326 | return (p1, p2), (mu, logvar), (mu_s0, logvar_s0), (prob1) 327 | 328 | def generate(self, x1, x2, n_past=None, n_future=None, c=None): 329 | assert self.fpc == n_past + n_future 330 | if self.first_frame: 331 | c1, c2 = self.encoder1_ct(x1[:, :1]), self.encoder2_ct(x2[:, :1]) 332 | ct = self.relu(self.ct_fuse(torch.cat([c1, c2], dim=1))) 333 | self.content = ct 334 | ct = ct.unsqueeze(dim=1).repeat_interleave(self.fpc, dim=1) 335 | x1, x2 = x1[:, 1:], x2[:, 1:] 336 | else: 337 | c1, c2 = self.encoder1_ct(x1[:, :n_past]), self.encoder2_ct(x2[:, :n_past]) 338 | ct = self.relu(self.ct_fuse(torch.cat([c1, c2], dim=1))) 339 | ct = ct.view(-1, n_past, *c.shape[1:]) 340 | self.content = ct 341 | ct = torch.cat([ct[:, :-1], ct[:, -1:].repeat_interleave(n_future, dim=1)], dim=1) 342 | 343 | # rollout / dynamic / pose encoding 344 | p1, p2 = self.encoder1(x1), self.encoder2(x2) 345 | p = self.relu(self.view_fuse(torch.cat([p1, p2], dim=1))) 346 | p = p.view(-1, n_past, self.pose_dim) 347 | 348 | g, mu, logvar, one_hot, prob1 = self.posterior(p) 349 | if self.diffvar: 350 | g = g * torch.matmul(one_hot, self.logvars.weight.mul(0.5).exp()) + torch.matmul(one_hot, self.centroids.weight) 351 | else: 352 | g = g + torch.matmul(one_hot, self.centroids.weight) 353 | self.cluster = one_hot.detach() 354 | 355 | s = [] 356 | s0, mu_s0, logvar_s0 = self.init_state_posterior(p[:, :n_past].flatten(start_dim=1)) 357 | si = s0 358 | s.append(si) 359 | for i in range(1, n_past): 360 | si = self.state_model(torch.cat([si, g[:, i]], dim=1)) 361 | s.append(si) 362 | 363 | for i in range(0, n_future): 364 | if c is None: 365 | prob_dist = torch.distributions.Categorical(F.softmax(self.c_net(si), dim=-1)) # probs should be of size batch x classes 366 | print(F.softmax(self.c_net(si), dim=-1)[0].argmax()) 367 | ci = prob_dist.sample() 368 | else: 369 | ci = c[:, i] 370 | gi = torch.normal(0, 1, size=(ci.shape[0], self.centroids.weight.shape[1])).to(si.device) + self.centroids.weight[ci] 371 | si = self.state_model(torch.cat([si, gi], dim=1)) 372 | s.append(si) 373 | s = torch.stack(s, dim=1) 374 | p = s 375 | self.latent = p.detach() 376 | 377 | # rollout / dynamic / pose decoding 378 | p = self.relu(self.combine(torch.cat([p.view(-1, self.state_dim, 1, 1).repeat(1, 1, 8, 8), ct.flatten(end_dim=1)], dim=1))) 379 | p1, p2 = p[:, :256], p[:, 256:] 380 | 381 | if self.unpool: 382 | p1 = self.decoder1((p1, pool_idx1, hidden_size1)) 383 | p2 = self.decoder2((p2, pool_idx2, hidden_size2)) 384 | else: 385 | p1, p2 = self.decoder1(p1), self.decoder2(p2) 386 | p1, p2 = p1.view(-1, self.fpc, *p1.shape[1:]), p2.view(-1, self.fpc, *p2.shape[1:]) 387 | 388 | return p1, p2 389 | --------------------------------------------------------------------------------
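Usage note for the evaluation outputs described in the README: the sketch below shows one possible way to load the pickled arrays that `evaluate.py` writes via `utils.analysis.save_file`. It is a minimal, illustrative example and not part of the repository; the helper name `load_eval_outputs` and the comments on array contents are assumptions.

```python
# Minimal sketch (not part of the repo): load the files that evaluate.py saves
# under outputs/<time-name_of_exp>/. Helper name and content notes are assumptions.
import pickle

def load_eval_outputs(exp_dir, mode="test"):
    with open(f"{exp_dir}/{mode}_trials.pkl", "rb") as f:
        trials = pickle.load(f)          # list of evaluated video paths
    with open(f"{exp_dir}/{mode}_latents.pkl", "rb") as f:
        latents = pickle.load(f)         # numpy array of latent embeddings
    states = None
    try:
        with open(f"{exp_dir}/{mode}_states.pkl", "rb") as f:
            states = pickle.load(f)      # DBE only: per-frame state estimates
    except FileNotFoundError:
        pass
    return trials, latents, states
```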