├── img
│   └── teaser.png
├── Pipfile
├── LICENSE
├── README.md
├── config
│   └── config_TV.cfg
├── src
│   ├── plot_latent_space.py
│   ├── main.py
│   ├── capspose_flags.py
│   ├── datamodule.py
│   ├── utils.py
│   ├── models.py
│   └── layers.py
└── .gitignore

/img/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmlab-cv/DECA/HEAD/img/teaser.png
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | name = "pypi"
3 | url = "https://pypi.org/simple"
4 | verify_ssl = true
5 | 
6 | [dev-packages]
7 | 
8 | [packages]
9 | numpy = "*"
10 | scipy = "*"
11 | tqdm = "*"
12 | opencv-python = "*"
13 | torchvision = "*"
14 | rope = "*"
15 | adabound = "*"
16 | pytorch-lightning-bolts = "*"
17 | matplotlib = "*"
18 | pytorch-lightning = "==0.10.0"
19 | open3d = "*"
20 | pillow = "*"
21 | 
22 | [requires]
23 | python_version = "3.8"
24 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 mmlab-cv
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deca-deep-viewpoint-equivariant-human-pose/pose-estimation-on-itop-top-view)](https://paperswithcode.com/sota/pose-estimation-on-itop-top-view?p=deca-deep-viewpoint-equivariant-human-pose)
2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deca-deep-viewpoint-equivariant-human-pose/pose-estimation-on-itop-front-view)](https://paperswithcode.com/sota/pose-estimation-on-itop-front-view?p=deca-deep-viewpoint-equivariant-human-pose)
3 | [![arXiv](https://img.shields.io/badge/arXiv-2108.08557-00ff00.svg)](https://arxiv.org/abs/2108.08557)
4 | 
5 | # DECA
6 | Official code for the ICCV 2021 paper "DECA: Deep viewpoint-Equivariant human pose estimation using Capsule Autoencoders".
7 | All the code is written using PyTorch Lightning. Please use [Pipenv](https://pipenv.pypa.io/en/latest/) to configure the virtual environment required to run the code.
8 | 9 | ![Teaser Image](/img/teaser.png) 10 | 11 | ## How to run 12 | Use the following command to configure the virtual environment: 13 | ``` 14 | pipenv install 15 | ``` 16 | To configure all the network parameters, including the dataset paths and hyperparameters, please edit the file: 17 | ``` 18 | config/config_TV.cfg 19 | ``` 20 | or add each parameter as a runtime flag while executing the main.py file as follows: 21 | ``` 22 | python main.py --flagfile config/config_TV.cfg 23 | ``` 24 | As an example, to run the network in training mode with a dataset stored in , you can run the following command: 25 | ``` 26 | python main.py --flagfile config/config_TV.cfg --mode train --dataset_dir 27 | ``` 28 | -------------------------------------------------------------------------------- /config/config_TV.cfg: -------------------------------------------------------------------------------- 1 | # DIRECTORIES 2 | 3 | --dataset_dir=/media/disi/New Volume/Datasets/PANOPTIC_CAPS 4 | --dataset=front 5 | 6 | # NETWORK PARAMETERS 7 | 8 | # input-output 9 | --n_channels=3 10 | --n_classes=19 11 | --input_width=256 12 | --input_height=256 13 | --input_channels=3 14 | --class_names=0 15 | --class_names=1 16 | --class_names=2 17 | --class_names=3 18 | --class_names=4 19 | --class_names=5 20 | --class_names=6 21 | --class_names=7 22 | --class_names=8 23 | --class_names=9 24 | --class_names=10 25 | --class_names=11 26 | --class_names=12 27 | --class_names=13 28 | --class_names=14 29 | --class_names=15 30 | --class_names=16 31 | --class_names=17 32 | --class_names=18 33 | --experiments_dir=/home/nicolagarau/Experiments 34 | --summaries_dir='NA' 35 | --checkpoint_dir='NA' 36 | --step=0 37 | 38 | # CapsNet 39 | --arch=64 40 | --arch=16 41 | --arch=16 42 | --arch=16 43 | --arch=19 44 | --F=5 45 | --K=3 46 | --P=4 47 | 48 | # convolution 49 | --conv_stride=2 50 | --conv_padding=0 51 | --conv_kernel=3 52 | 53 | # caps_conv_1 54 | --caps_conv_1_stride=2 55 | --caps_conv_1_padding=0 56 | 57 | # caps_conv_2 58 | --caps_conv_2_stride=1 59 | --caps_conv_2_padding=0 60 | 61 | # class_caps 62 | --class_caps_stride=1 63 | --class_caps_padding=0 64 | 65 | # TRAINING PARAMETERS 66 | 67 | --batch_size=128 68 | --seed=42 69 | --n_epochs=1237 70 | --learning_rate=1e-5 71 | --weight_decay=1e-2 72 | --routing_iter=3 73 | --pose_dim=4 74 | --padding=4 75 | --brightness=0.5 76 | --contrast=0.5 77 | --hue=0.5 78 | --patience=1e-3 79 | --crop_dim=32 80 | --load_checkpoint_dir=lightning_logs/TOP_NEW/checkpoints/TOP_NEW.ckpt 81 | --mode=train 82 | --test_affNIST=False 83 | --resume_training=False 84 | --dataset_iterations=10 85 | --stddev=0.01 86 | --num_workers=8 87 | --accumulation=1 -------------------------------------------------------------------------------- /src/plot_latent_space.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.manifold import TSNE 3 | import matplotlib.pyplot as plt 4 | from mpl_toolkits.mplot3d import Axes3D 5 | import matplotlib.cm as cm 6 | 7 | colors = ['b', 'g', 'r', 'c', 'm', 'y', 'black', 'tab:orange', 'lime', 'tab:brown', 'fuchsia', 'tab:gray', 'yellow', 'aqua', 'tab:blue', 'indigo', 'navy', 'lightcoral', 'darkolivegreen'] 8 | # colors = cm.rainbow(np.linspace(0, 1, 15)) 9 | labels = ['0','1','2','3','4','5','6','7','8','9','10','11','12','13','14','15','16','17','18'] 10 | 11 | X = np.load("output/features.npy") # (7680, 304) --> (7680*19, 16) 12 | 13 | allX = None 14 | 15 | for i in range(19): 16 | Xnew = 
X[:,(16*i):16*(i+1)] 17 | # print(len(Xnew[0,:])) 18 | # allX.append(Xnew, axis=0) 19 | if i == 0: 20 | allX = Xnew 21 | else: 22 | allX = np.vstack([allX, Xnew]) 23 | 24 | # X = np.array([[0, 0, 0, 0], [0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1], [6, 6, 6, 6]]) 25 | y = np.array([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18]) 26 | #y = ['a', 'b', 'c', 'd'] 27 | X_embedded = TSNE(n_components=3).fit_transform(allX) 28 | 29 | print(X_embedded[:,0]) 30 | print(X_embedded.shape) 31 | 32 | fig = plt.figure() 33 | ax = fig.add_subplot(111, projection='3d') 34 | for i, color in enumerate(colors): 35 | left = i*7680 36 | right = (i+1)*7680 37 | ax.scatter(X_embedded[left:right,0], X_embedded[left:right,1], X_embedded[left:right,2], zdir='z', c= color, label = labels[i]) 38 | 39 | 40 | plt.show() 41 | 42 | X_embedded = TSNE(n_components=2).fit_transform(allX) 43 | 44 | print(X_embedded[:,0]) 45 | print(X_embedded) 46 | 47 | fig = plt.figure() 48 | ax = fig.add_subplot(111) 49 | for i, color in enumerate(colors): 50 | left = i*7680 51 | right = (i+1)*7680 52 | ax.scatter(X_embedded[left:right,0], X_embedded[left:right,1], c= color, label = labels[i]) 53 | 54 | 55 | plt.show() -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | lightning_logs/ 132 | output/ 133 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | from utils import eval_image 2 | import sys 3 | from datamodule import CapsulePoseDataModule 4 | from models import CapsulePose 5 | import torch 6 | import torchvision 7 | from torch import nn 8 | import pytorch_lightning as pl 9 | from pytorch_lightning.callbacks import ModelCheckpoint 10 | import warnings 11 | import capspose_flags 12 | import numpy as np 13 | import os 14 | from absl import app 15 | from absl import flags 16 | FLAGS = flags.FLAGS 17 | 18 | 19 | def init_all(): 20 | warnings.filterwarnings("ignore") 21 | 22 | # enable cudnn and its inbuilt auto-tuner to find the best algorithm to use for your hardware 23 | torch.backends.cudnn.enabled = True 24 | torch.backends.cudnn.benchmark = True 25 | 26 | # useful for run-time 27 | #torch.backends.cudnn.deterministic = True 28 | 29 | pl.seed_everything(FLAGS.seed) 30 | torch.cuda.empty_cache() 31 | 32 | 33 | def main(argv): 34 | init_all() 35 | print("Capsules architecture: ", FLAGS.arch) 36 | 37 | if FLAGS.mode == "train": 38 | dm = CapsulePoseDataModule(FLAGS) 39 | model = CapsulePose(FLAGS) 40 | if(FLAGS.resume_training): 41 | trainer = pl.Trainer(gpus=1, distributed_backend=None, resume_from_checkpoint=os.path.join( 42 | os.getcwd(), FLAGS.load_checkpoint_dir), max_epochs=10000) 43 | else: 44 | trainer = pl.Trainer(gpus=1, distributed_backend=None, max_epochs=10000) 45 | trainer.fit(model, dm) 46 | elif FLAGS.mode == "test": 47 | # Create modules 48 | dm = CapsulePoseDataModule(FLAGS) 49 | model = CapsulePose(FLAGS) 50 | model = model.load_from_checkpoint(os.path.join( 51 | os.getcwd(), FLAGS.load_checkpoint_dir), FLAGS=FLAGS) 52 | model.configure_optimizers() 53 | 54 | # Manually run prep methods on DataModule 55 | dm.prepare_data() 56 | dm.setup() 57 | 58 | # Run test on validation dataset 59 | trainer = pl.Trainer(gpus=1, distributed_backend=None, resume_from_checkpoint=os.path.join( 60 | os.getcwd(), FLAGS.load_checkpoint_dir), max_epochs=10000) 61 | trainer.test(model, test_dataloaders=dm.val_dataloader()) 62 | print(np.array(model.features).shape) 63 | np.save('output/features', np.array(model.features)) 64 | 65 | elif FLAGS.mode == "demo": 66 | model = CapsulePose(FLAGS) 67 | model = model.load_from_checkpoint(os.path.join( 68 | os.getcwd(), FLAGS.load_checkpoint_dir), FLAGS=FLAGS) 69 | model.configure_optimizers() 70 | model = model.cuda() 71 | 72 | eval_image(model) 73 | 74 | 75 | if __name__ == '__main__': 76 | app.run(main) 77 | -------------------------------------------------------------------------------- /src/capspose_flags.py: -------------------------------------------------------------------------------- 1 | from absl import flags 2 | 3 | # DIRECTORIES 4 | 5 | flags.DEFINE_string('dataset_dir', "D:/Datasets", 
'Dataset directory.') 6 | flags.DEFINE_string('dataset', "panoptic_tv", 'Dataset name.') 7 | 8 | # NETWORK PARAMETERS 9 | 10 | # input-output 11 | flags.DEFINE_integer('n_channels', 3, 'Number of image channels.') 12 | flags.DEFINE_integer('n_classes', 19, 'Number of classes.') 13 | flags.DEFINE_integer('input_width', 256, 'Images width.') 14 | flags.DEFINE_integer('input_height', 256, 'Images height.') 15 | flags.DEFINE_integer('input_channels', 3, 'Images channels.') 16 | flags.DEFINE_multi_string('class_names', [''], 'Class names.') 17 | flags.DEFINE_string('experiments_dir', "D:/Experiments", 18 | 'Experiments directory.') 19 | flags.DEFINE_string('summaries_dir', "NA", 'Summaries directory.') 20 | flags.DEFINE_string('checkpoint_dir', "NA", 'Checkpoint directory.') 21 | flags.DEFINE_integer('step', 0, 'Training step (changed at runtime).') 22 | 23 | # CapsNet 24 | flags.DEFINE_multi_integer( 25 | 'arch', [64, 8, 16, 16, 19], 'CapsNet parameters A, B, C, D, F.') 26 | flags.DEFINE_integer('F', 5, 'CapsNet parameter F.') 27 | flags.DEFINE_integer('K', 3, 'CapsNet parameter K.') 28 | flags.DEFINE_integer('P', 4, 'CapsNet parameter P.') 29 | 30 | # convolution 31 | flags.DEFINE_integer('conv_stride', 2, 'CNN stride.') 32 | flags.DEFINE_integer('conv_padding', 0, 'CNN padding.') 33 | flags.DEFINE_integer('conv_kernel', 3, 'CNN kernel size.') 34 | 35 | # caps_conv_1 36 | flags.DEFINE_integer('caps_conv_1_stride', 2, 'ConvCaps1 stride.') 37 | flags.DEFINE_integer('caps_conv_1_padding', 0, 'ConvCaps1 padding.') 38 | 39 | # caps_conv_1 40 | flags.DEFINE_integer('caps_conv_2_stride', 1, 'ConvCaps2 stride.') 41 | flags.DEFINE_integer('caps_conv_2_padding', 0, 'ConvCaps2 padding.') 42 | 43 | # class_caps 44 | flags.DEFINE_integer('class_caps_stride', 1, 'ClassCaps stride.') 45 | flags.DEFINE_integer('class_caps_padding', 0, 'ClassCaps padding.') 46 | 47 | # TRAINING PARAMETERS 48 | 49 | flags.DEFINE_integer('batch_size', 20, 'Batch size.') 50 | flags.DEFINE_integer('seed', 7, 'Seed.') 51 | flags.DEFINE_integer('n_epochs', 300, 'Number of training epochs.') 52 | flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate.') 53 | flags.DEFINE_float('weight_decay', 0, 'Weight decay.') 54 | flags.DEFINE_integer('routing_iter', 3, 'Number of routing iterations.') 55 | flags.DEFINE_integer('pose_dim', 4, 'Pose matrix size.') 56 | flags.DEFINE_integer('padding', 4, 'Paffing.') 57 | flags.DEFINE_float('brightness', 0, 'Brightness.') 58 | flags.DEFINE_float('contrast', 0.5, 'Contrast.') 59 | flags.DEFINE_float('hue', 0, 'Hue.') 60 | flags.DEFINE_float( 61 | 'patience', 1e-3, 'Number of epochs with no improvement after which learning rate will be reduced.') 62 | flags.DEFINE_integer('crop_dim', 32, 'Default crop size.') 63 | flags.DEFINE_string('load_checkpoint_dir', 'NA', 64 | 'Load previous existing checkpoint.') 65 | flags.DEFINE_string('mode', 'demo', 'train/test/demo.') 66 | flags.DEFINE_boolean('test_affNIST', False, 'Test affnist.') 67 | flags.DEFINE_boolean('resume_training', False, 68 | 'Resume training using a checkpoint.') 69 | flags.DEFINE_integer('dataset_iterations', 10, 'Dataset iterations.') 70 | flags.DEFINE_float('stddev', 0.01, 'Standard deviation.') 71 | flags.DEFINE_integer('num_workers', 10, 'Number of workers.') 72 | flags.DEFINE_integer('accumulation', 1, 'Gradient accumulation iterations.') 73 | -------------------------------------------------------------------------------- /src/datamodule.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import transforms 7 | import os 8 | import re 9 | import cv2 10 | from random import shuffle 11 | 12 | from utils import compute_distances 13 | from utils import * 14 | from layers import * 15 | 16 | import pytorch_lightning as pl 17 | 18 | 19 | class CapsulePoseDataModule(pl.LightningDataModule): 20 | 21 | def __init__(self, FLAGS): 22 | super().__init__() 23 | self.FLAGS = FLAGS 24 | 25 | def prepare_data(self): 26 | dataset_path = os.path.join(self.FLAGS.dataset_dir, self.FLAGS.dataset) 27 | if os.path.isfile(dataset_path+'train_files.npy'): 28 | train_files = np.load(dataset_path+'train_files.npy') 29 | valid_files = np.load(dataset_path+'valid_files.npy') 30 | test_files = np.load(dataset_path+'test_files.npy') 31 | print('Files have been loaded') 32 | else: 33 | train_files = [f for f in os.listdir( 34 | os.path.join(dataset_path, 'train')) if not re.search(r'.npy', f)] # --> npy = (32,3) 35 | np.save('train_files', train_files) 36 | valid_files = [f for f in os.listdir( 37 | os.path.join(dataset_path, 'validation')) if not re.search(r'.npy', f)] 38 | np.save('valid_files', valid_files) 39 | test_files = [f for f in os.listdir( 40 | os.path.join(dataset_path, 'validation')) if not re.search(r'.npy', f)] 41 | np.save('test_files', test_files) 42 | train_files = np.load('train_files.npy') 43 | valid_files = np.load('valid_files.npy') 44 | test_files = np.load('test_files.npy') 45 | print('Files have been created') 46 | 47 | num_images = len(train_files) 48 | print('Loaded Training samples: ' + str(num_images)) 49 | 50 | self.num_train_examples = len(train_files) 51 | self.train_indices = list(range(self.num_train_examples)) 52 | self.num_valid_examples = len(valid_files) 53 | self.valid_indices = list(range(self.num_valid_examples)) 54 | self.num_test_examples = len(test_files) 55 | self.test_indices = list(range(self.num_test_examples)) 56 | 57 | shuffle(self.train_indices) 58 | shuffle(self.valid_indices) 59 | shuffle(self.test_indices) 60 | 61 | self.train_files = train_files[self.train_indices] 62 | self.validation_files = valid_files[self.valid_indices] 63 | 64 | def train_dataloader(self): 65 | working_dir = os.path.join(self.FLAGS.dataset_dir, self.FLAGS.dataset) 66 | dataset_path = os.path.join(working_dir, 'train') 67 | 68 | transform = transforms.Compose([ 69 | transforms.ToPILImage(), 70 | transforms.ColorJitter( 71 | brightness=self.FLAGS.brightness, contrast=self.FLAGS.contrast, hue=self.FLAGS.hue), 72 | transforms.ToTensor(), 73 | Standardize()]) 74 | 75 | dataset = poseDATA(self.FLAGS, dataset_path, self.train_files, 76 | self.train_indices, transform) 77 | 78 | capsulepose_train = FastDataLoader(dataset, shuffle=True, pin_memory=True, 79 | num_workers=self.FLAGS.num_workers, batch_size=self.FLAGS.batch_size, drop_last=True) 80 | 81 | return capsulepose_train 82 | 83 | def val_dataloader(self): 84 | working_dir = os.path.join(self.FLAGS.dataset_dir, self.FLAGS.dataset) 85 | dataset_path = os.path.join(working_dir, 'validation') 86 | 87 | transform = transforms.Compose([ 88 | transforms.ToPILImage(), 89 | transforms.ToTensor(), 90 | Standardize()]) 91 | 92 | dataset = poseDATA(self.FLAGS, dataset_path, self.validation_files, 93 | self.valid_indices, transform) 94 | 95 | capsulepose_val = FastDataLoader(dataset, shuffle=False, pin_memory=True, 96 | num_workers=self.FLAGS.num_workers, batch_size=self.FLAGS.batch_size, 
drop_last=True) 97 | 98 | return capsulepose_val 99 | 100 | def test_dataloader(self): 101 | working_dir = os.path.join(self.FLAGS.dataset_dir, self.FLAGS.dataset) 102 | dataset_path = os.path.join(working_dir, 'validation') 103 | 104 | transform = transforms.Compose([ 105 | transforms.ToPILImage(), 106 | transforms.ToTensor(), 107 | Standardize()]) 108 | 109 | dataset = poseDATA(self.FLAGS, dataset_path, self.validation_files, 110 | self.valid_indices, transform) 111 | 112 | capsulepose_test = FastDataLoader(dataset, shuffle=False, pin_memory=True, 113 | num_workers=self.FLAGS.num_workers, batch_size=self.FLAGS.batch_size, drop_last=True) 114 | 115 | return capsulepose_test 116 | 117 | 118 | class _RepeatSampler(object): 119 | """ Sampler that repeats forever. 120 | 121 | Args: 122 | sampler (Sampler) 123 | """ 124 | 125 | def __init__(self, sampler): 126 | self.sampler = sampler 127 | 128 | def __iter__(self): 129 | while True: 130 | yield from iter(self.sampler) 131 | 132 | 133 | class FastDataLoader(torch.utils.data.dataloader.DataLoader): 134 | 135 | def __init__(self, *args, **kwargs): 136 | super().__init__(*args, **kwargs) 137 | object.__setattr__(self, 'batch_sampler', 138 | _RepeatSampler(self.batch_sampler)) 139 | self.iterator = super().__iter__() 140 | 141 | def __len__(self): 142 | return len(self.batch_sampler.sampler) 143 | 144 | def __iter__(self): 145 | for i in range(len(self)): 146 | yield next(self.iterator) 147 | 148 | 149 | class poseDATA(Dataset): 150 | ''' In: 151 | data_path (string): path to the dataset split folder, i.e. train/valid/test 152 | transform (callable, optional): transform to be applied on a sample. 153 | Out: 154 | sample (dict): sample data and respective label''' 155 | 156 | def __init__(self, FLAGS, data_path, input_list, indices, transform): 157 | 158 | self.data_path = data_path 159 | self.data, self.labels = [], [] 160 | self.num_images = len(input_list) 161 | self.list_images = input_list 162 | self.list_masks = [] 163 | self.indices = indices 164 | self.dataset_iterations = FLAGS.dataset_iterations 165 | self.art_select_ = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 166 | 12, 13, 14, 15, 16, 17, 18] 167 | self.split_size = self.num_images // FLAGS.num_workers 168 | self.starting_point = -1 169 | self.idx = -1 170 | self.FLAGS = FLAGS 171 | self.transform = transform 172 | 173 | for f in self.list_images: 174 | name = f[:-4] + '.npy' 175 | self.list_masks.append(name) 176 | 177 | self.c = list(zip(self.list_images, self.list_masks)) 178 | print("Data path: ", self.data_path) 179 | if('test' in self.data_path): 180 | print("TODO: implement test loading!") 181 | 182 | if('train' in self.data_path): 183 | shuffle(self.c) 184 | self.list_images, self.list_masks = zip(*(self.c)) 185 | print('Training list has been shuffled (' + 186 | str(self.num_images) + ' images)') 187 | 188 | else: 189 | self.list_images, self.list_masks = zip(*(self.c)) 190 | print('Validation list has not been shuffled (' + 191 | str(self.num_images) + ' images)') 192 | 193 | def __len__(self): 194 | return self.num_images * self.dataset_iterations * self.FLAGS.batch_size // self.FLAGS.n_epochs 195 | 196 | def __getitem__(self, idx): 197 | 198 | if(self.starting_point == -1): 199 | self.starting_point = ( 200 | torch.utils.data.get_worker_info().id) * self.split_size 201 | 202 | if(self.idx == -1): 203 | next_id = self.starting_point 204 | else: 205 | next_id = (self.idx + 1) 206 | 207 | if(self.idx >= self.starting_point + self.split_size - 1): 208 | self.starting_point = ( 209 
| torch.utils.data.get_worker_info().id) * self.split_size 210 | self.idx = self.idx - self.split_size 211 | else: 212 | self.idx = next_id 213 | 214 | im = cv2.imread(self.data_path + '/' + 215 | self.list_images[self.indices[self.idx]]) 216 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 217 | 218 | depth = cv2.imread(self.data_path + 'depth/' + 219 | self.list_images[self.indices[self.idx]].replace("render", "depth")) 220 | depth = cv2.cvtColor(depth, cv2.COLOR_BGR2RGB) 221 | 222 | msk = np.float32( 223 | np.load(self.data_path + '/' + self.list_masks[self.indices[self.idx]])) 224 | bias = np.repeat(np.reshape(msk[0, :], [1, 3]), 19, axis=0) 225 | msk = msk - bias 226 | msk = rotate(msk) 227 | msk = msk[self.art_select_, :] 228 | 229 | msk2d = np.float32( 230 | np.load(self.data_path + '2d' + '/' + self.list_masks[self.indices[self.idx]])) # --> (32,2) 231 | msk2d = msk2d[self.art_select_, :] 232 | 233 | 234 | if(self.transform): 235 | image = self.transform(im) 236 | 237 | msk = np.reshape( 238 | msk, [self.FLAGS.n_classes, 3]) / 100. 239 | 240 | msk2d = np.reshape( 241 | msk2d, [self.FLAGS.n_classes, 2]) / 256. 242 | 243 | depth = 1. - (depth / 255.) 244 | 245 | label = {'msk': msk, 246 | # 'rm': rm, 247 | 'msk2d': msk2d, 248 | 'depth': depth} 249 | 250 | 251 | return image, label, self.indices[self.idx] # (X, Y) 252 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | import os, glob 6 | import re 7 | import cv2 8 | import math 9 | from random import shuffle 10 | import torch.nn.functional as F 11 | from torchvision import transforms 12 | from tqdm import tqdm 13 | 14 | from PIL import Image 15 | import scipy.io as io 16 | import matplotlib.pyplot as plt 17 | import matplotlib.animation as manimation 18 | from mpl_toolkits.mplot3d import Axes3D 19 | 20 | import time 21 | import open3d as o3d 22 | from queue import Queue 23 | 24 | class Standardize(object): 25 | """ Standardizes a 'PIL Image' such that each channel 26 | gets zero mean and unit variance. """ 27 | def __call__(self, img): 28 | return (img - img.mean(dim=(1,2), keepdim=True)) \ 29 | / torch.clamp(img.std(dim=(1,2), keepdim=True), min=1e-8) 30 | 31 | def __repr__(self): 32 | return self.__class__.__name__ + '()' 33 | 34 | 35 | def rotate(xyz): 36 | def dotproduct(v1, v2): 37 | return sum((a * b) for a, b in zip(v1, v2)) 38 | 39 | def length(v): 40 | return math.sqrt(dotproduct(v, v)) 41 | 42 | def angle(v1, v2): 43 | num = dotproduct(v1, v2) 44 | den = (length(v1) * length(v2)) 45 | if den == 0: 46 | print('den = 0') 47 | print(length(v1)) 48 | print(length(v2)) 49 | print(num) 50 | ratio = num/den 51 | ratio = np.minimum(1, ratio) 52 | ratio = np.maximum(-1, ratio) 53 | 54 | return math.acos(ratio) 55 | 56 | p1 = np.float32(xyz[1, :]) 57 | p2 = np.float32(xyz[6, :]) 58 | v1 = np.subtract(p2, p1) 59 | mod_v1 = np.sqrt(np.sum(v1 ** 2)) 60 | x = np.float32([1., 0., 0.]) 61 | y = np.float32([0., 1., 0.]) 62 | z = np.float32([0., 0., 1.]) 63 | theta = math.acos(np.sum(v1 * z) / (mod_v1 * 1)) * 360 / (2 * math.pi) 64 | # M = cv2.getAffineTransform() 65 | p = np.cross(v1, z) 66 | # if sum(p)==0: 67 | # p = np.cross(v1,y) 68 | p[2] = 0. 
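# p = cross(v1, z) already lies in the ground plane by construction, so zeroing its z-component only removes numerical residue before the angle between p and the x-axis is measured below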
69 | # ang = -np.minimum(np.abs(angle(p, x)), 2 * math.pi - np.abs(angle(p, x))) 70 | ang = angle(x, p) 71 | 72 | if p[1] < 0: 73 | ang = -ang 74 | 75 | M = [[np.cos(ang), np.sin(ang), 0.], 76 | [-np.sin(ang), np.cos(ang), 0.], [0., 0., 1.]] 77 | M = np.reshape(M, [3, 3]) 78 | xyz = np.transpose(xyz) 79 | xyz_ = np.matmul(M, xyz) 80 | xyz_ = np.transpose(xyz_) 81 | 82 | return xyz_ 83 | 84 | 85 | def flip_3d(msk): 86 | msk[:, 1] = -msk[:, 1] 87 | return msk 88 | 89 | 90 | def compute_distances(FLAGS, labels3D, predictions3D, labels2D, predictions2D, labelsD, predictionsD): 91 | ED_list_3d = torch.sum(torch.square(predictions3D - labels3D), dim=2) 92 | ED_3d = torch.mean(ED_list_3d) 93 | EDs_3d = torch.mean(torch.sqrt(ED_list_3d)) 94 | 95 | ED_list_2d = torch.sum(torch.square(predictions2D - labels2D), dim=2) 96 | ED_2d = torch.mean(ED_list_2d) 97 | EDs_2d = torch.mean(torch.sqrt(ED_list_2d)) 98 | 99 | # print("P3D: ", predictions3D.shape) 100 | # print("L3D: ", labels3D.shape) 101 | # print("P2D: ", predictions2D.shape) 102 | # print("L2D: ", labels2D.shape) 103 | 104 | # print(torch.max(labelsD)) 105 | # print(torch.min(labelsD)) 106 | # print(torch.max(predictionsD)) 107 | # print(torch.min(predictionsD)) 108 | valid_mask = (labelsD > 0).detach() 109 | diff = (labelsD - predictionsD).abs() 110 | diff_masked = diff[valid_mask] 111 | ED_D = (diff_masked.mean() + diff.mean()) / 2. 112 | 113 | # cv2.imshow("Predicted", predictionsD.clone()[0].permute(1,2,0).cpu().detach().numpy()) 114 | # cv2.imshow("Real", labelsD.clone()[0].permute(1,2,0).cpu().detach().numpy()) 115 | # cv2.imshow("Diff", diff.clone()[0].permute(1,2,0).cpu().detach().numpy()) 116 | # cv2.waitKey(1) 117 | 118 | return ED_3d, ED_2d, EDs_3d, EDs_2d, ED_D 119 | 120 | 121 | def procrustes(X, Y, scaling=True, reflection='best'): 122 | """ 123 | A port of MATLAB's `procrustes` function to Numpy. 124 | 125 | Procrustes analysis determines a linear transformation (translation, 126 | reflection, orthogonal rotation and scaling) of the points in Y to best 127 | conform them to the points in matrix X, using the sum of squared errors 128 | as the goodness of fit criterion. 129 | 130 | d, Z, [tform] = procrustes(X, Y) 131 | 132 | Inputs: 133 | ------------ 134 | X, Y 135 | matrices of target and input coordinates. they must have equal 136 | numbers of points (rows), but Y may have fewer dimensions 137 | (columns) than X. 138 | 139 | scaling 140 | if False, the scaling component of the transformation is forced 141 | to 1 142 | 143 | reflection 144 | if 'best' (default), the transformation solution may or may not 145 | include a reflection component, depending on which fits the data 146 | best. setting reflection to True or False forces a solution with 147 | reflection or no reflection respectively. 
148 | 149 | Outputs 150 | ------------ 151 | d 152 | the residual sum of squared errors, normalized according to a 153 | measure of the scale of X, ((X - X.mean(0))**2).sum() 154 | 155 | Z 156 | the matrix of transformed Y-values 157 | 158 | tform 159 | a dict specifying the rotation, translation and scaling that 160 | maps X --> Y 161 | 162 | """ 163 | 164 | n, m = X.shape 165 | ny, my = Y.shape 166 | 167 | muX = X.mean(0) 168 | muY = Y.mean(0) 169 | 170 | X0 = X - muX 171 | Y0 = Y - muY 172 | 173 | ssX = (X0 ** 2.).sum() 174 | ssY = (Y0 ** 2.).sum() 175 | 176 | # centred Frobenius norm 177 | normX = np.sqrt(ssX) 178 | normY = np.sqrt(ssY) 179 | 180 | # scale to equal (unit) norm 181 | X0 /= normX 182 | Y0 /= normY 183 | 184 | if my < m: 185 | Y0 = np.concatenate((Y0, np.zeros(n, m - my)), 0) 186 | 187 | # optimum rotation matrix of Y 188 | A = np.dot(X0.T, Y0) 189 | U, s, Vt = np.linalg.svd(A, full_matrices=False) 190 | V = Vt.T 191 | T = np.dot(V, U.T) 192 | 193 | if reflection is not 'best': 194 | 195 | # does the current solution use a reflection? 196 | have_reflection = np.linalg.det(T) < 0 197 | 198 | # if that's not what was specified, force another reflection 199 | if reflection != have_reflection: 200 | V[:, -1] *= -1 201 | s[-1] *= -1 202 | T = np.dot(V, U.T) 203 | 204 | traceTA = s.sum() 205 | 206 | if scaling: 207 | 208 | # optimum scaling of Y 209 | b = traceTA * normX / normY 210 | 211 | # standarised distance between X and b*Y*T + c 212 | d = 1 - traceTA ** 2 213 | 214 | # transformed coords 215 | Z = normX * traceTA * np.dot(Y0, T) + muX 216 | 217 | else: 218 | b = 1 219 | d = 1 + ssY / ssX - 2 * traceTA * normY / normX 220 | Z = normY * np.dot(Y0, T) + muX 221 | 222 | # transformation matrix 223 | if my < m: 224 | T = T[:my, :] 225 | c = muX - b * np.dot(muY, T) 226 | 227 | # transformation values 228 | tform = {'rotation': T, 'scale': b, 'translation': c} 229 | 230 | return d, Z, tform 231 | 232 | def plot_skeletons(FLAGS, fig, images_orig, links, preds_2D, gts_2D, preds_3D, gts_3D, preds_D, gts_D, writer, angle): 233 | plt.rcParams.update({'axes.labelsize': 'small'}) 234 | for index in range(0, FLAGS.batch_size): 235 | 236 | plt.clf() 237 | angle = (angle + 1) % 360 238 | 239 | ax_bb = fig.add_subplot(331) 240 | ax_bb.set_title('Input image') 241 | 242 | ax_hat_3D = fig.add_subplot(338, projection='3d') 243 | ax_hat_3D.set_title('3D prediction') 244 | ax_hat_3D.set_xlabel('X') 245 | ax_hat_3D.set_ylabel('Y') 246 | ax_hat_3D.set_zlabel('Z') 247 | ax_hat_3D.set_xlim([-100, 100]) 248 | ax_hat_3D.set_ylim([-100, 100]) 249 | ax_hat_3D.set_zlim([-100, 100]) 250 | ax_hat_3D.view_init(15, angle) 251 | ax_hat_3D.labelsize = 10 252 | 253 | ax_gt_3D = fig.add_subplot(339, projection='3d') 254 | ax_gt_3D.set_title('3D ground truth') 255 | ax_gt_3D.set_xlabel('X') 256 | ax_gt_3D.set_ylabel('Y') 257 | ax_gt_3D.set_zlabel('Z') 258 | ax_gt_3D.set_xlim([-100, 100]) 259 | ax_gt_3D.set_ylim([-100, 100]) 260 | ax_gt_3D.set_zlim([-100, 100]) 261 | ax_gt_3D.view_init(15, angle) 262 | 263 | ax_hat_2D = fig.add_subplot(335) 264 | ax_hat_2D.set_title('2D prediction') 265 | ax_hat_2D.set_xlabel('X') 266 | ax_hat_2D.set_ylabel('Y') 267 | ax_hat_2D.set_xlim([0, 1]) 268 | ax_hat_2D.set_ylim([0, 1]) 269 | 270 | ax_gt_2D = fig.add_subplot(336) 271 | ax_gt_2D.set_title('2D ground truth') 272 | ax_gt_2D.set_xlabel('X') 273 | ax_gt_2D.set_ylabel('Y') 274 | ax_gt_2D.set_xlim([0, 1]) 275 | ax_gt_2D.set_ylim([0, 1]) 276 | 277 | ax_hat_D = fig.add_subplot(332) 278 | ax_hat_D.set_title('Depth prediction') 
279 | 280 | ax_gt_D = fig.add_subplot(333) 281 | ax_gt_D.set_title('Depth ground truth') 282 | 283 | ax_bb.imshow(np.reshape( 284 | images_orig[index], (FLAGS.input_height, FLAGS.input_width, FLAGS.n_channels))) 285 | colormaps = [ 286 | 'Greys_r', 'Purples_r', 'Blues_r', 'Greens_r', 'Oranges_r', 'Reds_r', 287 | 'YlOrBr_r', 'YlOrRd_r', 'OrRd_r', 'PuRd_r', 'RdPu_r', 'BuPu_r', 288 | 'GnBu_r', 'PuBu_r', 'YlGnBu_r', 'PuBuGn_r', 'BuGn_r', 'YlGn_r'] 289 | 290 | 291 | for i in range(len(links)): 292 | 293 | link = links[i] 294 | 295 | for j in range(len(link)): 296 | P2_hat_3D = preds_3D[index][i, :] 297 | P1_hat_3D = preds_3D[index][link[j], :] 298 | link_hat_3D = [list(x) 299 | for x in list(zip(P1_hat_3D, P2_hat_3D))] 300 | ax_hat_3D.plot( 301 | link_hat_3D[0], link_hat_3D[2], zs=[ -x for x in link_hat_3D[1]]) 302 | P2_gt_3D = gts_3D[index][i, :] 303 | P1_gt_3D = gts_3D[index][link[j], :] 304 | link_gt_3D = [list(x) for x in list(zip(P1_gt_3D, P2_gt_3D))] 305 | ax_gt_3D.plot(link_gt_3D[0], link_gt_3D[2], zs=[ -x for x in link_gt_3D[1]]) 306 | 307 | P2_hat_2D = preds_2D[index][i, :] 308 | P1_hat_2D = preds_2D[index][link[j], :] 309 | link_hat_2D = [list(x) 310 | for x in list(zip(P1_hat_2D, P2_hat_2D))] 311 | ax_hat_2D.plot( 312 | link_hat_2D[0], link_hat_2D[1]) 313 | P2_gt_2D = gts_2D[index][i, :] 314 | P1_gt_2D = gts_2D[index][link[j], :] 315 | link_gt_2D = [list(x) for x in list(zip(P1_gt_2D, P2_gt_2D))] 316 | ax_gt_2D.plot(link_gt_2D[0], link_gt_2D[1]) 317 | 318 | ax_gt_D.imshow(gts_D[index]) 319 | # ax_hat_D.imshow(preds_D[index].cpu()) 320 | ax_hat_D.imshow(preds_D[index]) 321 | 322 | plt.draw() 323 | fig.canvas.flush_events() 324 | plt.show(block=False) 325 | 326 | writer.grab_frame() 327 | 328 | return angle 329 | 330 | 331 | def eval_image(model): 332 | viewpoint = "top" 333 | sample = "05_00000000_rear" 334 | image = cv2.imread("/media/disi/New Volume/Datasets/PANOPTIC_CAPS/"+viewpoint+"/train/"+ sample +".png") 335 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 336 | transform = transforms.Compose([ 337 | transforms.ToPILImage(), 338 | transforms.ToTensor(), 339 | Standardize()]) 340 | image = transform(image) 341 | image_tensor = image.unsqueeze(0) 342 | # image_tensor = image_tensor.permute(0,3,1,2) 343 | input = torch.autograd.Variable(image_tensor) 344 | input = input.cuda() 345 | input = torch.cat(128*[input]) 346 | print("INPUT SHAPE: ", input.shape) 347 | yhat2D, yhat3D, yhatD, W_reg, _ = model(input) 348 | 349 | itop_labels = ['Head','Neck','LShould','RShould',"LElbow","RElbow","LHand","RHand","Torso","LHip","RHip","LKnee","RKnee","LFoot","RFoot"] 350 | 351 | import gzip 352 | msk3D = np.load("/media/disi/New Volume/Datasets/PANOPTIC_CAPS/"+viewpoint+"/train/"+sample+".npy") 353 | msk3D = torch.from_numpy(msk3D).float().unsqueeze(0).unsqueeze(-1) 354 | msk3D = torch.cat(128*[msk3D]) / 100. 
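# repeat the single ground-truth skeleton across the 128-sample batch and apply the same 1/100 rescaling used for the 3D masks in datamodule.py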
355 | msk3D = center_skeleton(msk3D) 356 | msk3D = discretize(msk3D, 0, 1) 357 | print(msk3D.shape) 358 | 359 | pred = yhat3D.cpu().detach().numpy().squeeze(-1) 360 | gt = msk3D.cpu().detach().numpy().squeeze(-1) 361 | 362 | assert(pred.shape == gt.shape) 363 | assert(len(pred.shape) == 3) 364 | 365 | msk3D = msk3D.squeeze(3) 366 | yhat3D = yhat3D.squeeze(3) 367 | for i, p in enumerate(pred): 368 | d, Z, tform = procrustes( 369 | gt[i], pred[i]) 370 | pred[i] = Z 371 | 372 | print(yhat3D.shape) 373 | print(pred.shape) 374 | 375 | yhat3D = torch.from_numpy(pred).float() 376 | 377 | # if(viewpoint=="top"): 378 | msk3D = msk3D[:,:,[2,0,1]] 379 | yhat3D = yhat3D[:,:,[2,0,1]] 380 | 381 | print("GT: ", msk3D.shape) 382 | print("PRED: ", yhat3D.shape) 383 | 384 | print("ERROR: ", np.mean(np.sqrt(np.sum((yhat3D.cpu().detach().numpy() - msk3D.cpu().detach().numpy())**2, axis=2)))) 385 | 386 | save_3d_plot(msk3D, "gt_depth", display_labels=True, viewpoint=viewpoint) 387 | save_3d_plot(yhat3D.cpu().detach().numpy(), "pred_depth", viewpoint=viewpoint) 388 | 389 | index = 10 390 | image_2d = input[index].permute(1,2,0).cpu().detach().numpy() 391 | # # img_kps = np.zeros((256,256,3), np.uint8) 392 | # img_kps = cv2. cvtColor(image_2d, cv2.COLOR_GRAY2BGR)#.astype(np.uint8) 393 | # for i, kps in enumerate(yhat2D[index]): # (15,2,1) 394 | # if(i == 8): 395 | # color = (255,0,0) 396 | # else: 397 | # color = (0,255,0) 398 | # cv2.circle(img_kps, (int(256*kps[0].cpu()), int(256*kps[1].cpu())), 2, color, 8, 0) 399 | # # cv2.putText(img_kps, itop_labels[i], (int(256*kps[0].cpu()) + 10, int(256*kps[1].cpu())), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,0)) 400 | 401 | # cv2.imshow("Kps", img_kps) 402 | cv2.imshow("Input", image_2d) 403 | cv2.waitKey(0) 404 | 405 | def save_3d_plot(itop, name, azim=None, elev=None, gt=None, display_labels=False, viewpoint="top"): 406 | # itop_labels = ['Head','Neck','RShould','LShould',"RElbow","LElbow","RHand","LHand","Torso","RHip","LHip","RKnee","LKnee","RFoot","LFoot"] 407 | itop_labels = ['Head','Neck','LShould','RShould',"LElbow","RElbow","LHand","RHand","Torso","LHip","RHip","LKnee","RKnee","LFoot","RFoot"] 408 | itop_labels = ['0','1','2','3',"4","5","6","7","8","9","10","11","12","13","14"] 409 | 410 | itop_connections = [[0,1],[1,2],[1,3],[2,3],[2,4],[3,5],[4,6],[5,7],[1,8],[8,9],[8,10],[9,10],[9,11],[10,12],[11,13],[12,14]] 411 | fig = plt.figure() 412 | ax = plt.axes(projection='3d') 413 | index = 10 414 | 415 | itop_newjoints = change_format_from_19_joints_to_15_joints(itop[0]) 416 | itop_newjoints = np.expand_dims(itop_newjoints, 0) 417 | itop = np.repeat(itop_newjoints, 128, axis=0) 418 | # print(itop.shape) 419 | 420 | 421 | xdata = itop[index,:,0].flatten() 422 | ydata = itop[index,:,1].flatten() 423 | zdata = itop[index,:,2].flatten() 424 | 425 | 426 | for i in itop_connections: 427 | x1,x2,y1,y2,z1,z2 = connect(xdata,ydata,zdata,i[0],i[1]) 428 | ax.plot([x1,x2],[y1,y2],[z1,z2],'k-') 429 | 430 | ax.scatter3D(xdata, ydata, zdata, c=zdata) 431 | 432 | if(gt is not None): 433 | pred = undiscretize(itop, 0, 1)[index] 434 | gt = undiscretize(gt, 0, 1)[index] 435 | 436 | pred = pred.squeeze() 437 | gt = gt.squeeze() 438 | 439 | assert(pred.shape == gt.shape) 440 | assert(len(pred.shape) == 2) 441 | 442 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=1)) # (N, K) 443 | 444 | errors = (err_dist < 0.1) 445 | 446 | for i, (x, y, z, label) in enumerate(zip(xdata,ydata,zdata, itop_labels)): 447 | error_color='black' 448 | if(gt is not None and not errors[i]): 449 | 
error_color='red' 450 | if(display_labels): 451 | ax.text(x, y, z, label, color=error_color) 452 | 453 | # ax.text2D(0.05, 0.95, "ITOP", transform=ax.transAxes) 454 | 455 | if(azim): 456 | ax.view_init(elev=elev, azim=azim) 457 | 458 | # ax.set_xlabel('x', rotation=0, fontsize=20, labelpad=20) 459 | # ax.set_ylabel('y', rotation=0, fontsize=20, labelpad=20) 460 | # ax.set_zlabel('z', rotation=0, fontsize=20, labelpad=20) 461 | 462 | # ax.set_xlim3d(-1,1) 463 | # ax.set_ylim3d(-2,2) 464 | # ax.set_zlim3d(0,2) 465 | 466 | ax.set_xlim3d(0.2,1) 467 | ax.set_ylim3d(0,0.6) 468 | ax.set_zlim3d(0.8,0.2) 469 | 470 | # plt.show(block=False) 471 | 472 | # redraw the canvas 473 | fig.canvas.draw() 474 | 475 | # convert canvas to image 476 | img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, 477 | sep='') 478 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 479 | 480 | # img is rgb, convert to opencv's default bgr 481 | img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR) 482 | 483 | 484 | # display image with opencv or any operation you like 485 | cv2.imshow(name, img) 486 | cv2.imwrite(name+".png", img) 487 | if(name=="True side"): 488 | cv2.waitKey(1) 489 | else: 490 | cv2.waitKey(1) 491 | 492 | def connect(x,y,z,p1,p2): 493 | x1, x2 = x[p1], x[p2] 494 | y1, y2 = y[p1], y[p2] 495 | z1,z2 = z[p1],z[p2] 496 | return x1,x2,y1,y2,z1,z2 497 | 498 | 499 | def center_skeleton(skeletons): 500 | for b, batch in enumerate(skeletons): 501 | skeletons[b,:,:] = skeletons[b,:,:] - skeletons[b,2,:] 502 | return skeletons 503 | 504 | def change_format_from_19_joints_to_15_joints(joints): 505 | xdata = joints[:,0] 506 | ydata = joints[:,1] 507 | zdata = joints[:,2] 508 | 509 | panoptic_head = [(xdata[16]+xdata[18])/2,(ydata[16]+ydata[18])/2,(zdata[16]+zdata[18])/2] 510 | panoptic_torso = [(xdata[0]+xdata[2])/2,(ydata[0]+ydata[2])/2,(zdata[0]+zdata[2])/2] 511 | 512 | 513 | # head neck r shoulder l shoulder r elbow l elbow r hand l hand torso r hip l hip r knee l knee r foot l foot 514 | #xdata_new = np.array([panoptic_head[0], xdata[0], xdata[9], xdata[3], xdata[10], xdata[4], xdata[11], xdata[5], panoptic_torso[0], xdata[12], xdata[6], xdata[13], xdata[7], xdata[14], xdata[8]]) 515 | #ydata_new = np.array([panoptic_head[1], ydata[0], ydata[9], ydata[3], ydata[10], ydata[4], ydata[11], ydata[5], panoptic_torso[1], ydata[12], ydata[6], ydata[13], ydata[7], ydata[14], ydata[8]]) 516 | #zdata_new = np.array([panoptic_head[2], zdata[0], zdata[9], zdata[3], zdata[10], zdata[4], zdata[11], zdata[5], panoptic_torso[2], zdata[12], zdata[6], zdata[13], zdata[7], zdata[14], zdata[8]]) 517 | 518 | xdata_new = np.array([panoptic_head[0], xdata[0], xdata[3], xdata[9], xdata[4], xdata[10], xdata[5], xdata[11], panoptic_torso[0], xdata[6], xdata[12], xdata[7], xdata[13], xdata[8], xdata[14]]) 519 | ydata_new = np.array([panoptic_head[1], ydata[0], ydata[3], ydata[9], ydata[4], ydata[10], ydata[5], ydata[11], panoptic_torso[1], ydata[6], ydata[12], ydata[7], ydata[13], ydata[8], ydata[14]]) 520 | zdata_new = np.array([panoptic_head[2], zdata[0], zdata[3], zdata[9], zdata[4], zdata[10], zdata[5], zdata[11], panoptic_torso[2], zdata[6], zdata[12], zdata[7], zdata[13], zdata[8], zdata[14]]) 521 | 522 | panoptic_converted = np.empty(shape=(15, 3), dtype=float) 523 | for index in range(len(panoptic_converted)): 524 | panoptic_converted[index,0] = xdata_new[index] 525 | panoptic_converted[index,1] = ydata_new[index] 526 | panoptic_converted[index,2] = zdata_new[index] 527 | 528 | return panoptic_converted 529 | 530 | def 
discretize(coord, a, b): 531 | 532 | normalizers_3D = [[-0.927149999999999, 1.4176299999999982], [-1.1949180000000008, 0.991252999999999], [-0.8993889999999993, 0.8777908000000015]] 533 | 534 | for i in range(3): 535 | coord[:,:,i] = (b - a) * (coord[:,:,i] - normalizers_3D[i][0]) / (normalizers_3D[i][1] - normalizers_3D[i][0]) + a 536 | 537 | return coord 538 | 539 | def undiscretize(coord, a, b): 540 | 541 | normalizers_3D = [[-0.927149999999999, 1.4176299999999982], [-1.1949180000000008, 0.991252999999999], [-0.8993889999999993, 0.8777908000000015]] 542 | 543 | for i in range(3): 544 | coord[:,:,i] = ( (coord[:,:,i] - a) * (normalizers_3D[i][1] - normalizers_3D[i][0]) / (b - a) ) + normalizers_3D[i][0] 545 | 546 | return coord -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import transforms 7 | import os 8 | import re 9 | import cv2 10 | from random import shuffle 11 | 12 | from utils import compute_distances 13 | from utils import * 14 | from layers import * 15 | 16 | import pytorch_lightning as pl 17 | 18 | import matplotlib.pyplot as plt 19 | from datetime import datetime 20 | 21 | 22 | class CapsulePose(pl.LightningModule): 23 | 24 | def __init__(self, FLAGS): 25 | super(CapsulePose, self).__init__() 26 | self.FLAGS = FLAGS 27 | 28 | self.P = self.FLAGS.pose_dim 29 | self.PP = int(np.max([2, self.P*self.P])) 30 | self.A, self.B, self.C, self.D = self.FLAGS.arch[:-1] 31 | self.n_classes = self.FLAGS.n_classes = self.FLAGS.arch[-1] 32 | self.in_channels = self.FLAGS.n_channels 33 | 34 | self.s1 = torch.nn.parameter.Parameter(torch.tensor(1., device=self.device), requires_grad=True) 35 | self.s2 = torch.nn.parameter.Parameter(torch.tensor(1., device=self.device), requires_grad=True) 36 | self.s3 = torch.nn.parameter.Parameter(torch.tensor(1., device=self.device), requires_grad=True) 37 | self.s4 = torch.nn.parameter.Parameter(torch.tensor(1., device=self.device), requires_grad=True) 38 | 39 | self.drop_rate = 0.5 40 | self.features = [] 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | self.Conv_1 = nn.Conv2d(in_channels=self.in_channels, out_channels=64, 45 | kernel_size=9, stride=2, padding=4, bias=False) 46 | # self.Residual_1 = nn.Conv2d(in_channels=self.in_channels, out_channels=32, 47 | # kernel_size=1, stride=2, padding=0, bias=False) 48 | nn.init.xavier_uniform(self.Conv_1.weight) 49 | self.IN_1 = nn.InstanceNorm2d(64) 50 | self.Drop_1 = nn.Dropout(p=float(self.drop_rate)) 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | self.Conv_2 = nn.Conv2d(in_channels=64, out_channels=128, 55 | kernel_size=9, stride=2, padding=4, bias=False) 56 | # self.Residual_2 = nn.Conv2d(in_channels=32, out_channels=64, 57 | # kernel_size=1, stride=2, padding=0, bias=False) 58 | nn.init.xavier_uniform(self.Conv_2.weight) 59 | self.IN_2 = nn.InstanceNorm2d(128) 60 | self.Drop_2 = nn.Dropout(p=float(self.drop_rate)) 61 | 62 | #---------------------------------------------------------------------------- 63 | 64 | self.Conv_3 = nn.Conv2d(in_channels=128, out_channels=256, 65 | kernel_size=9, stride=2, padding=4, bias=False) 66 | # self.Residual_3 = nn.Conv2d(in_channels=64, out_channels=128, 67 | # kernel_size=1, stride=2, 
padding=0, bias=False)
68 | nn.init.xavier_uniform(self.Conv_3.weight)
69 | self.IN_3 = nn.InstanceNorm2d(256)
70 | self.Drop_3 = nn.Dropout(p=float(self.drop_rate))
71 | 
72 | #----------------------------------------------------------------------------
73 | 
74 | self.Conv_4 = nn.Conv2d(in_channels=256, out_channels=self.A,
75 | kernel_size=9, stride=3, padding=7, bias=False)
76 | # self.Residual_4 = nn.Conv2d(in_channels=128, out_channels=self.A,
77 | # kernel_size=1, stride=3, padding=3, bias=False)
78 | nn.init.xavier_uniform(self.Conv_4.weight)
79 | self.IN_4 = nn.InstanceNorm2d(self.A)
80 | self.Drop_4 = nn.Dropout(p=float(self.drop_rate))
81 | 
82 | #----------------------------------------------------------------------------
83 | 
84 | self.PrimaryCaps = PrimaryCapsules2d(in_channels=self.A, out_caps=self.B,
85 | kernel_size=1, stride=1, pose_dim=self.P)
86 | 
87 | #----------------------------------------------------------------------------
88 | 
89 | self.ConvCaps_1 = ConvCapsules2d(in_caps=self.B, out_caps=self.C,
90 | kernel_size=3, stride=2, pose_dim=self.P)
91 | 
92 | self.ConvRouting_1 = VariationalBayesRouting2d(in_caps=self.B, out_caps=self.C,
93 | kernel_size=3, stride=2, pose_dim=self.P,
94 | cov='diag', iter=self.FLAGS.routing_iter,
95 | alpha0=1., m0=torch.zeros(self.PP), kappa0=1.,
96 | Psi0=torch.eye(self.PP), nu0=self.PP+1)
97 | 
98 | #----------------------------------------------------------------------------
99 | 
100 | self.ConvCaps_2 = ConvCapsules2d(in_caps=self.C, out_caps=self.D,
101 | kernel_size=3, stride=1, pose_dim=self.P)
102 | 
103 | self.ConvRouting_2 = VariationalBayesRouting2d(in_caps=self.C, out_caps=self.D,
104 | kernel_size=3, stride=1, pose_dim=self.P,
105 | cov='diag', iter=self.FLAGS.routing_iter,
106 | alpha0=1., m0=torch.zeros(self.PP), kappa0=1.,
107 | Psi0=torch.eye(self.PP), nu0=self.PP+1)
108 | 
109 | #----------------------------------------------------------------------------
110 | 
111 | self.ClassCaps = ConvCapsules2d(in_caps=self.D, out_caps=self.n_classes,
112 | kernel_size=1, stride=1, pose_dim=self.P, share_W_ij=True, coor_add=True)
113 | 
114 | self.ClassRouting = VariationalBayesRouting2d(in_caps=self.D, out_caps=self.n_classes,
115 | # adjust final kernel_size K depending on input H/W, for H=W=32, K=4.
116 | kernel_size=4, stride=1, pose_dim=self.P, 117 | cov='diag', iter=self.FLAGS.routing_iter, 118 | alpha0=1., m0=torch.zeros(self.PP), kappa0=1., 119 | Psi0=torch.eye(self.PP), nu0=self.PP+1, class_caps=True) 120 | 121 | #---------------------------------------------------------------------------- 122 | 123 | self.Entities = nn.Flatten() 124 | 125 | #---------------------------------------------------------------------------- 126 | 127 | 128 | self.FC_2D = FullyConnected2d( 129 | self.FLAGS.batch_size, self.FLAGS.n_classes*self.FLAGS.P*self.FLAGS.P, rate=self.drop_rate) 130 | 131 | self.FC_3D = FullyConnected3d( 132 | self.FLAGS.batch_size, self.FLAGS.n_classes*self.FLAGS.P*self.FLAGS.P, rate=self.drop_rate) 133 | 134 | self.Depth_Recons = DepthReconstruction( 135 | self.FLAGS.batch_size, self.FLAGS.input_width, self.FLAGS.input_height, self.FLAGS.n_classes*self.FLAGS.P*self.FLAGS.P, rate=self.drop_rate) 136 | 137 | def forward(self, v): 138 | 139 | # Out ← [?, A, F, F] 140 | v = self.Conv_1(v) #+ self.Residual_1(v) 141 | v = self.IN_1(v) 142 | # v = self.Drop_1(v) 143 | v = F.gelu(v) 144 | 145 | # print(v.shape) 146 | 147 | v = self.Conv_2(v) #+ self.Residual_2(v) 148 | v = self.IN_2(v) 149 | # v = self.Drop_2(v) 150 | v = F.gelu(v) 151 | 152 | # print(v.shape) 153 | 154 | v = self.Conv_3(v) #+ self.Residual_3(v) 155 | v = self.IN_3(v) 156 | # v = self.Drop_3(v) 157 | v = F.gelu(v) 158 | 159 | # print(v.shape) 160 | 161 | v = self.Conv_4(v) #+ self.Residual_4(v) 162 | v = self.IN_4(v) 163 | # v = self.Drop_4(v) 164 | v = F.gelu(v) 165 | 166 | # print(v.shape) 167 | 168 | # Out ← a [?, B, F, F], v [?, B, P, P, F, F] 169 | a, v = self.PrimaryCaps(v) 170 | # print(v.shape) 171 | 172 | # Out ← a [?, B, 1, 1, 1, F, F, K, K], v [?, B, C, P*P, 1, F, F, K, K] 173 | a, v, _ = self.ConvCaps_1(a, v) 174 | # print(v.shape) 175 | 176 | # Out ← a [?, C, F, F], v [?, C, P, P, F, F] 177 | a, v = self.ConvRouting_1(a, v) 178 | # print(v.shape) 179 | 180 | # Out ← a [?, C, 1, 1, 1, F, F, K, K], v [?, C, D, P*P, 1, F, F, K, K] 181 | a, v, _ = self.ConvCaps_2(a, v) 182 | # print(v.shape) 183 | 184 | # Out ← a [?, D, F, F], v [?, D, P, P, F, F] 185 | a, v = self.ConvRouting_2(a, v) 186 | # print(v.shape) 187 | 188 | # Out ← a [?, D, 1, 1, 1, F, F, K, K], v [?, D, n_classes, P*P, 1, F, F, K, K] 189 | a, v, W_reg = self.ClassCaps(a, v) 190 | # print(v.shape) 191 | 192 | # Out ← yhat [?, n_classes], v [?, n_classes, P, P] 193 | #yhat, v = self.ClassRouting(a, v) 194 | _, v = self.ClassRouting(a, v) 195 | # print(v.shape) 196 | 197 | v = self.Entities(v) # > (10, 272) 198 | # print(v.shape) 199 | 200 | # fc = self.FC_Base(entities) 201 | # print(v.shape) 202 | 203 | yhat2D = self.FC_2D(v) 204 | # print(yhat2D.shape) 205 | 206 | yhat3D = self.FC_3D(v) 207 | # print(yhat3D.shape) 208 | 209 | yhatD = self.Depth_Recons(v) 210 | # print(yhatD.shape) 211 | 212 | return yhat2D, yhat3D, yhatD, W_reg, v 213 | 214 | def training_step(self, train_batch, batch_idx): 215 | inputs, labels, _ = train_batch 216 | 217 | labels['msk'] = center_skeleton(labels['msk']) 218 | labels['msk'] = discretize(labels['msk'], 0, 1) 219 | 220 | yhat2D, yhat3D, yhatD, W_reg, _ = self.forward(inputs) 221 | loss, loss2D, loss3D, lossD, W_reg, ED_3D = self.pose_loss( 222 | yhat2D, yhat3D, yhatD, W_reg, labels) 223 | 224 | # viewpoint = "top" 225 | # save_3d_plot(labels['msk'], "gt_depth", display_labels=True, viewpoint=viewpoint) 226 | # save_3d_plot(yhat3D.cpu().detach().numpy(), "pred_depth", viewpoint=viewpoint) 227 | 228 | 
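# log each loss component together with the 3D error and mAP for this training step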
self.log('Training/loss', loss) 229 | self.log('Training/loss2D', loss2D) 230 | self.log('Training/loss3D', loss3D) 231 | self.log('Training/lossD', lossD) 232 | self.log('Training/W_reg', W_reg) 233 | self.log('Training/ED_3D', ED_3D) 234 | self.log('Training/mAP', calc_mAP(yhat3D, labels['msk'].type( 235 | torch.FloatTensor).unsqueeze(-1).cuda(non_blocking=True))) 236 | 237 | return loss 238 | 239 | def validation_step(self, val_batch, batch_idx): 240 | inputs, labels, _ = val_batch 241 | 242 | labels['msk'] = center_skeleton(labels['msk']) 243 | labels['msk'] = discretize(labels['msk'], 0, 1) 244 | 245 | yhat2D, yhat3D, yhatD, W_reg, _ = self.forward(inputs) 246 | loss, loss2D, loss3D, lossD, W_reg, ED_3D = self.pose_loss( 247 | yhat2D, yhat3D, yhatD, W_reg, labels) 248 | 249 | self.log('Validation/loss', loss) 250 | self.log('Validation/loss2D', loss2D) 251 | self.log('Validation/loss3D', loss3D) 252 | self.log('Validation/lossD', lossD) 253 | self.log('Validation/W_reg', W_reg) 254 | self.log('Validation/ED_3D', ED_3D) 255 | self.log('Validation/mAP', calc_mAP(yhat3D, labels['msk'].type( 256 | torch.FloatTensor).unsqueeze(-1).cuda(non_blocking=True))) 257 | 258 | return loss 259 | 260 | def test_step(self, val_batch, batch_idx): 261 | inputs, labels, _ = val_batch 262 | 263 | labels['msk'] = center_skeleton(labels['msk']) 264 | labels['msk'] = discretize(labels['msk'], 0, 1) 265 | 266 | yhat2D, yhat3D, yhatD, W_reg, feat_vec = self.forward(inputs) 267 | loss, loss2D, loss3D, lossD, W_reg, ED_3D = self.pose_loss( 268 | yhat2D, yhat3D, yhatD, W_reg, labels) 269 | 270 | self.features.extend(list(feat_vec.data.cpu().numpy())) 271 | 272 | # viewpoint = "top" 273 | # save_3d_plot(labels['msk'], "gt_depth", display_labels=True, viewpoint=viewpoint) 274 | # save_3d_plot(yhat3D.cpu().detach().numpy(), "pred_depth", viewpoint=viewpoint) 275 | 276 | self.log('Test/loss', loss) 277 | self.log('Test/loss2D', loss2D) 278 | self.log('Test/loss3D', loss3D) 279 | self.log('Test/lossD', lossD) 280 | self.log('Test/W_reg', W_reg) 281 | self.log('Test/ED_3D', ED_3D) 282 | MPJPE = calc_MPJPE(yhat3D, labels['msk'].type( 283 | torch.FloatTensor).unsqueeze(-1)) 284 | self.log('Test/MPJPE_00_Neck', MPJPE[1]) 285 | self.log('Test/MPJPE_01_Nose', MPJPE[2]) 286 | self.log('Test/MPJPE_02_Body_center', MPJPE[3]) 287 | self.log('Test/MPJPE_03_Shoulders', MPJPE[4]) 288 | self.log('Test/MPJPE_04_Elbows', MPJPE[5]) 289 | self.log('Test/MPJPE_05_Hands', MPJPE[6]) 290 | self.log('Test/MPJPE_06_Hips', MPJPE[7]) 291 | self.log('Test/MPJPE_07_Knees', MPJPE[8]) 292 | self.log('Test/MPJPE_08_Feet', MPJPE[9]) 293 | self.log('Test/MPJPE_09_Eyes', MPJPE[10]) 294 | self.log('Test/MPJPE_10_Ears', MPJPE[11]) 295 | self.log('Test/MPJPE_11_Upper_Body', MPJPE[12]) 296 | self.log('Test/MPJPE_12_Lower_Body', MPJPE[13]) 297 | self.log('Test/MPJPE_13_Mean', MPJPE[0]) 298 | 299 | return loss 300 | 301 | def configure_optimizers(self): 302 | optimizer = optim.AdamW(self.parameters(), 303 | lr=self.FLAGS.learning_rate, weight_decay=self.FLAGS.weight_decay) 304 | # optimizer = adabound.AdaBound( 305 | # self.parameters(), lr=self.FLAGS.learning_rate, final_lr=0.1) 306 | return optimizer 307 | 308 | def pose_loss(self, yhat2D, yhat3D, yhatD, W_reg, labels): 309 | msk2d = labels['msk2d'].type( 310 | torch.FloatTensor).unsqueeze(-1).cuda(non_blocking=True) # > (10,19,2,1) (B,C,2,1) 311 | msk3d = labels['msk'].type( 312 | torch.FloatTensor).unsqueeze(-1).cuda(non_blocking=True) # > (10,19,3,1) (B,C,3,1) 313 | 314 | # if(not 
int(datetime.now().strftime('%S')) % 10): 315 | # plot_skeleton(msk3d.cpu().detach().numpy() * 20, "True") 316 | # plot_skeleton(yhat3D.cpu().detach().numpy() * 20, "Predicted") 317 | 318 | # print("###########") 319 | # print(labels['depth'][0].cpu().numpy().shape) 320 | # cv2.imshow("PRE", labels['depth'][0].cpu().numpy()) 321 | # cv2.waitKey(0) 322 | depth = labels['depth'].permute(0, 3, 1, 2).cuda( 323 | non_blocking=True) #/ 255. # > (10,3,256,256) (B,C,H,W) 324 | # print(depth[0].permute(1, 2, 0).cpu().numpy().shape) 325 | # cv2.imshow("POST", depth[0].permute(1, 2, 0).cpu().numpy()) 326 | # cv2.waitKey(0) 327 | loss3D, loss2D, EDs, EDs_2D, lossD = compute_distances( 328 | self.FLAGS, labels3D=msk3d, predictions3D=yhat3D, labels2D=msk2d, predictions2D=yhat2D, labelsD=depth, predictionsD=yhatD) 329 | 330 | loss = (0.5 * torch.exp(-self.s1)) * loss3D + self.s1 + \ 331 | (0.5 * torch.exp(-self.s2)) * loss2D + self.s2 + \ 332 | (0.5 * torch.exp(-self.s3)) * lossD + self.s3 + \ 333 | (0.5 * torch.exp(-self.s4)) * W_reg/3000 + self.s4 334 | 335 | 336 | # loss = loss3D + loss2D + lossR + W_reg 337 | return loss, loss2D, loss3D, lossD, W_reg, 1000*EDs 338 | 339 | def plot_skeleton(itop, name): 340 | itop_labels = ['Head','Neck','RShould','LShould',"RElbow","LElbow","RHand","LHand","Torso","RHip","LHip","RKnee","LKnee","RFoot","LFoot"] 341 | itop_connections = [[0,1],[1,2],[1,3],[2,3],[2,4],[3,5],[4,6],[5,7],[1,8],[8,9],[8,10],[9,10],[9,11],[10,12],[11,13],[12,14]] 342 | itop = itop[0] 343 | fig = plt.figure() 344 | ax = plt.axes(projection='3d') 345 | xdata = itop[:,0].flatten() 346 | ydata = itop[:,1].flatten() 347 | zdata = itop[:,2].flatten() 348 | 349 | for i in itop_connections: 350 | x1,x2,y1,y2,z1,z2 = connectpoints(xdata,ydata,zdata,i[0],i[1]) 351 | ax.plot([x1,x2],[y1,y2],[z1,z2],'k-') 352 | 353 | ax.scatter3D(xdata, ydata, zdata, c=zdata) 354 | 355 | for x, y, z, label in zip(xdata,ydata,zdata, itop_labels): 356 | ax.text(x, y, z, label) 357 | 358 | ax.text2D(0.05, 0.95, "ITOP", transform=ax.transAxes) 359 | 360 | ax.set_xlabel('x', rotation=0, fontsize=20, labelpad=20) 361 | ax.set_ylabel('y', rotation=0, fontsize=20, labelpad=20) 362 | ax.set_zlabel('z', rotation=0, fontsize=20, labelpad=20) 363 | ax.set_xlim3d(-1,1) 364 | ax.set_ylim3d(-2,2) 365 | ax.set_zlim3d(0,2) 366 | # plt.show(block=False) 367 | 368 | # redraw the canvas 369 | fig.canvas.draw() 370 | 371 | # convert canvas to image 372 | img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, 373 | sep='') 374 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 375 | 376 | # img is rgb, convert to opencv's default bgr 377 | img = cv2.cvtColor(img,cv2.COLOR_RGB2BGR) 378 | 379 | 380 | # display image with opencv or any operation you like 381 | cv2.imshow(name,img) 382 | cv2.waitKey(1) 383 | 384 | def connectpoints(x,y,z,p1,p2): 385 | x1, x2 = x[p1], x[p2] 386 | y1, y2 = y[p1], y[p2] 387 | z1,z2 = z[p1],z[p2] 388 | return x1,x2,y1,y2,z1,z2 389 | 390 | def calc_mAP(pred_or, gt_or, dist=0.1): 391 | ''' 392 | pred: (N, K, 3) 393 | gt: (N, K, 3) 394 | ''' 395 | 396 | pred = pred_or.cpu().detach().numpy() 397 | gt = gt_or.cpu().detach().numpy() 398 | 399 | pred = undiscretize(pred, 0, 1) 400 | gt = undiscretize(gt, 0, 1) 401 | 402 | pred = pred.squeeze() 403 | gt = gt.squeeze() 404 | 405 | assert(pred.shape == gt.shape) 406 | assert(len(pred.shape) == 3) 407 | 408 | N, K = pred.shape[0], pred.shape[1] # [BS, 15, 3] 409 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) # (N, K) 410 | 411 | acc_d = (err_dist < 
dist).sum(axis=0) / N 412 | 413 | return np.mean(acc_d) 414 | 415 | def calc_mAP_procrustes(pred_or, gt_or, dist=0.1): 416 | ''' 417 | pred: (N, K, 3) 418 | gt: (N, K, 3) 419 | ''' 420 | 421 | 422 | pred = pred_or.cpu().detach().numpy() 423 | gt = gt_or.cpu().detach().numpy() 424 | 425 | pred = undiscretize(pred, 0, 1) 426 | gt = undiscretize(gt, 0, 1) 427 | 428 | pred = pred.squeeze(3) 429 | gt = gt.squeeze(3) 430 | 431 | assert(pred.shape == gt.shape) 432 | assert(len(pred.shape) == 3) 433 | 434 | N, K = pred.shape[0], pred.shape[1] # [BS, 15, 3] 435 | # for i, p in enumerate(pred): 436 | # d, Z, tform = procrustes( 437 | # gt[i], pred[i]) 438 | # pred[i] = Z 439 | 440 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) # (N, K) 441 | acc_d = (err_dist < dist).sum(axis=0) / N 442 | 443 | err_dist_neck = np.sqrt(np.sum((pred[:,[0],:] - gt[:,[0],:])**2, axis=2)) # (N, K) 444 | acc_d_neck = (err_dist_neck < dist).sum(axis=0) / N 445 | 446 | err_dist_nose = np.sqrt(np.sum((pred[:,[1],:] - gt[:,[1],:])**2, axis=2)) # (N, K) 447 | acc_d_nose = (err_dist_nose < dist).sum(axis=0) / N 448 | 449 | err_dist_bodycenter = np.sqrt(np.sum((pred[:,[2],:] - gt[:,[2],:])**2, axis=2)) # (N, K) 450 | acc_d_bodycenter = (err_dist_bodycenter < dist).sum(axis=0) / N 451 | 452 | err_dist_shoulders = np.sqrt(np.sum((pred[:,[3,9],:] - gt[:,[3,9],:])**2, axis=2)) # (N, K) 453 | acc_d_shoulders = (err_dist_shoulders < dist).sum(axis=0) / N 454 | 455 | err_dist_elbows = np.sqrt(np.sum((pred[:,[4,10],:] - gt[:,[4,10],:])**2, axis=2)) # (N, K) 456 | acc_d_elbows = (err_dist_elbows < dist).sum(axis=0) / N 457 | 458 | err_dist_wrists = np.sqrt(np.sum((pred[:,[5,11],:] - gt[:,[5,11],:])**2, axis=2)) # (N, K) 459 | acc_d_wrists = (err_dist_wrists < dist).sum(axis=0) / N 460 | 461 | err_dist_hips = np.sqrt(np.sum((pred[:,[6,12],:] - gt[:,[6,12],:])**2, axis=2)) # (N, K) 462 | acc_d_hips = (err_dist_hips < dist).sum(axis=0) / N 463 | 464 | err_dist_knees = np.sqrt(np.sum((pred[:,[7,13],:] - gt[:,[7,13],:])**2, axis=2)) # (N, K) 465 | acc_d_knees = (err_dist_knees < dist).sum(axis=0) / N 466 | 467 | err_dist_ankles = np.sqrt(np.sum((pred[:,[8,14],:] - gt[:,[8,14],:])**2, axis=2)) # (N, K) 468 | acc_d_ankles = (err_dist_ankles < dist).sum(axis=0) / N 469 | 470 | err_dist_eyes = np.sqrt(np.sum((pred[:,[15,17],:] - gt[:,[15,17],:])**2, axis=2)) # (N, K) 471 | acc_d_eyes = (err_dist_eyes < dist).sum(axis=0) / N 472 | 473 | err_dist_ears = np.sqrt(np.sum((pred[:,[16,18],:] - gt[:,[16,18],:])**2, axis=2)) # (N, K) 474 | acc_d_ears = (err_dist_ears < dist).sum(axis=0) / N 475 | 476 | err_dist_upper_body = np.sqrt(np.sum((pred[:,[0,1,2,3,4,5,9,10,11,15,16,17,18],:] - gt[:,[0,1,2,3,4,5,9,10,11,15,16,17,18],:])**2, axis=2)) # (N, K) 477 | acc_d_upper_body = (err_dist_upper_body < dist).sum(axis=0) / N 478 | 479 | err_dist_lower_body = np.sqrt(np.sum((pred[:,[6,7,8,12,13,14],:] - gt[:,[6,7,8,12,13,14],:])**2, axis=2)) # (N, K) 480 | acc_d_lower_body = (err_dist_lower_body < dist).sum(axis=0) / N 481 | 482 | return np.mean(acc_d), np.mean(acc_d_neck), np.mean(acc_d_nose), np.mean(acc_d_bodycenter), \ 483 | np.mean(acc_d_shoulders), np.mean(acc_d_elbows), np.mean(acc_d_wrists), np.mean(acc_d_hips), \ 484 | np.mean(acc_d_knees), np.mean(acc_d_ankles), np.mean(acc_d_eyes), np.mean(acc_d_ears), \ 485 | np.mean(acc_d_upper_body), np.mean(acc_d_lower_body) 486 | 487 | def calc_MPJPE(pred_or, gt_or, procrustes_transform=True): 488 | ''' 489 | pred: (N, K, 3) 490 | gt: (N, K, 3) 491 | ''' 492 | 493 | 494 | pred = pred_or.cpu().detach().numpy() 
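    # Note: as in calc_mAP, the tensors are first detached to NumPy; the steps
    # below map predictions and ground truth back to metric coordinates with
    # undiscretize, optionally align each prediction to its ground truth with a
    # Procrustes transform, and then report per-joint-group MPJPE (the *100
    # factor converts the distances, presumably metres, to centimetres).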
495 | gt = gt_or.cpu().detach().numpy() 496 | 497 | pred = undiscretize(pred, 0, 1) 498 | gt = undiscretize(gt, 0, 1) 499 | 500 | pred = pred.squeeze(3) 501 | gt = gt.squeeze(3) 502 | 503 | assert(pred.shape == gt.shape) 504 | assert(len(pred.shape) == 3) 505 | 506 | N, K = pred.shape[0], pred.shape[1] # [BS, 15, 3] 507 | 508 | if(procrustes_transform): 509 | for i, p in enumerate(pred): 510 | d, Z, tform = procrustes(gt[i], pred[i]) 511 | pred[i] = Z 512 | 513 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) * 100 514 | 515 | err_dist_neck = np.sqrt(np.sum((pred[:,[0],:] - gt[:,[0],:])**2, axis=2)) * 100 516 | 517 | err_dist_nose = np.sqrt(np.sum((pred[:,[1],:] - gt[:,[1],:])**2, axis=2)) * 100 518 | 519 | err_dist_bodycenter = np.sqrt(np.sum((pred[:,[2],:] - gt[:,[2],:])**2, axis=2)) * 100 520 | 521 | err_dist_shoulders = np.sqrt(np.sum((pred[:,[3,9],:] - gt[:,[3,9],:])**2, axis=2)) * 100 522 | 523 | err_dist_elbows = np.sqrt(np.sum((pred[:,[4,10],:] - gt[:,[4,10],:])**2, axis=2)) * 100 524 | 525 | err_dist_wrists = np.sqrt(np.sum((pred[:,[5,11],:] - gt[:,[5,11],:])**2, axis=2)) * 100 526 | 527 | err_dist_hips = np.sqrt(np.sum((pred[:,[6,12],:] - gt[:,[6,12],:])**2, axis=2)) * 100 528 | 529 | err_dist_knees = np.sqrt(np.sum((pred[:,[7,13],:] - gt[:,[7,13],:])**2, axis=2)) * 100 530 | 531 | err_dist_ankles = np.sqrt(np.sum((pred[:,[8,14],:] - gt[:,[8,14],:])**2, axis=2))* 100 532 | 533 | err_dist_eyes = np.sqrt(np.sum((pred[:,[15,17],:] - gt[:,[15,17],:])**2, axis=2)) * 100 534 | 535 | err_dist_ears = np.sqrt(np.sum((pred[:,[16,18],:] - gt[:,[16,18],:])**2, axis=2)) * 100 536 | 537 | err_dist_upper_body = np.sqrt(np.sum((pred[:,[0,1,2,3,4,5,9,10,11,15,16,17,18],:] - gt[:,[0,1,2,3,4,5,9,10,11,15,16,17,18],:])**2, axis=2)) * 100 538 | 539 | err_dist_lower_body = np.sqrt(np.sum((pred[:,[6,7,8,12,13,14],:] - gt[:,[6,7,8,12,13,14],:])**2, axis=2)) * 100 540 | 541 | return err_dist, err_dist_neck, err_dist_nose, err_dist_bodycenter, \ 542 | err_dist_shoulders, err_dist_elbows, err_dist_wrists, err_dist_hips, \ 543 | err_dist_knees, err_dist_ankles, err_dist_eyes, err_dist_ears, \ 544 | err_dist_upper_body, err_dist_lower_body 545 | 546 | # 0: Neck 547 | # 1: Nose 548 | # 2: BodyCenter (center of hips) 549 | # 3: lShoulder 550 | # 4: lElbow 551 | # 5: lWrist, 552 | # 6: lHip 553 | # 7: lKnee 554 | # 8: lAnkle 555 | # 9: rShoulder 556 | # 10: rElbow 557 | # 11: rWrist 558 | # 12: rHip 559 | # 13: rKnee 560 | # 14: rAnkle 561 | # 15: lEye 562 | # 16: lEar 563 | # 17: rEye 564 | # 18: rEar -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import pytorch_lightning as pl 7 | 8 | 9 | class PrimaryCapsules2d(pl.LightningModule): 10 | '''Primary Capsule Layer''' 11 | 12 | def __init__(self, in_channels, out_caps, kernel_size, stride, 13 | padding=0, pose_dim=4, weight_init='xavier_uniform'): 14 | super().__init__() 15 | 16 | self.A = in_channels 17 | self.B = out_caps 18 | self.P = pose_dim 19 | self.K = kernel_size 20 | self.S = stride 21 | self.padding = padding 22 | 23 | w_kernel = torch.empty(self.B*self.P*self.P, self.A, self.K, self.K) 24 | a_kernel = torch.empty(self.B, self.A, self.K, self.K) 25 | 26 | if weight_init == 'kaiming_normal': 27 | nn.init.kaiming_normal_(w_kernel) 28 | nn.init.kaiming_normal_(a_kernel) 29 | elif weight_init == 'kaiming_uniform': 
30 | nn.init.kaiming_uniform_(w_kernel) 31 | nn.init.kaiming_uniform_(a_kernel) 32 | elif weight_init == 'xavier_normal': 33 | nn.init.xavier_normal_(w_kernel) 34 | nn.init.xavier_normal_(a_kernel) 35 | elif weight_init == 'xavier_uniform': 36 | nn.init.xavier_uniform_(w_kernel) 37 | nn.init.xavier_uniform_(a_kernel) 38 | else: 39 | NotImplementedError('{} not implemented.'.format(weight_init)) 40 | 41 | # Out ← [B*(P*P+1), A, K, K] 42 | self.weight = nn.Parameter(torch.cat([w_kernel, a_kernel], dim=0)) 43 | 44 | self.BN_a = nn.BatchNorm2d(self.B, affine=True) 45 | self.BN_p = nn.BatchNorm3d(self.B, affine=True) 46 | 47 | def forward(self, x): # [?, A, F, F] ← In 48 | 49 | # Out ← [?, B*(P*P+1), F, F] 50 | x = F.conv2d(x, weight=self.weight, 51 | stride=self.S, padding=self.padding) 52 | 53 | # Out ← ([?, B*P*P, F, F], [?, B, F, F]) ← [?, B*(P*P+1), F, F] 54 | poses, activations = torch.split( 55 | x, [self.B*self.P*self.P, self.B], dim=1) 56 | 57 | # Out ← [?, B, P*P, F, F] 58 | poses = self.BN_p( 59 | poses.reshape(-1, self.B, self.P*self.P, *x.shape[2:])) 60 | 61 | # Out ← [?, B, P, P, F, F] ← [?, B, P*P, F, F] ← In 62 | poses = poses.reshape(-1, self.B, self.P, self.P, *x.shape[2:]) 63 | 64 | # Out ← [?, B, F, F]) 65 | activations = torch.sigmoid(self.BN_a(activations)) 66 | 67 | return (activations, poses) 68 | 69 | 70 | class ConvCapsules2d(pl.LightningModule): 71 | '''Convolutional Capsule Layer''' 72 | 73 | def __init__(self, in_caps, out_caps, pose_dim, kernel_size, stride, padding=0, 74 | weight_init='xavier_uniform', share_W_ij=False, coor_add=False): 75 | super().__init__() 76 | 77 | self.B = in_caps 78 | self.C = out_caps 79 | self.P = pose_dim 80 | self.PP = np.max([2, self.P*self.P]) 81 | self.K = kernel_size 82 | self.S = stride 83 | self.padding = padding 84 | 85 | # share the transformation matrices across (F*F) 86 | self.share_W_ij = share_W_ij 87 | self.coor_add = coor_add # embed coordinates 88 | 89 | # Out ← [1, B, C, 1, P, P, 1, 1, K, K] 90 | self.W_ij = torch.empty(1, self.B, self.C, 1, 91 | self.P, self.P, 1, 1, self.K, self.K) # .normal_(std=0.01) 92 | 93 | if weight_init.split('_')[0] == 'xavier': 94 | # in_caps types * receptive field size 95 | fan_in = self.B * self.K*self.K * self.PP 96 | # out_caps types * receptive field size 97 | fan_out = self.C * self.K*self.K * self.PP 98 | std = np.sqrt(2. / (fan_in + fan_out)) 99 | bound = np.sqrt(3.) * std 100 | 101 | if weight_init.split('_')[1] == 'normal': 102 | self.W_ij = nn.Parameter(self.W_ij.normal_(0, std)) 103 | elif weight_init.split('_')[1] == 'uniform': 104 | self.W_ij = nn.Parameter(self.W_ij.uniform_(-bound, bound)) 105 | else: 106 | raise NotImplementedError( 107 | '{} not implemented.'.format(weight_init)) 108 | 109 | elif weight_init.split('_')[0] == 'kaiming': 110 | # fan_in preserves magnitude of the variance of the weights in the forward pass. 111 | # in_caps types * receptive field size 112 | fan_in = self.B * self.K*self.K * self.PP 113 | # fan_out has same affect as fan_in for backward pass. 114 | # fan_out = self.C * self.K*self.K * self.PP # out_caps types * receptive field size 115 | std = np.sqrt(2.) / np.sqrt(fan_in) 116 | bound = np.sqrt(3.) 
* std 117 | 118 | if weight_init.split('_')[1] == 'normal': 119 | self.W_ij = nn.Parameter(self.W_ij.normal_(0, std)) 120 | elif weight_init.split('_')[1] == 'uniform': 121 | self.W_ij = nn.Parameter(self.W_ij.uniform_(-bound, bound)) 122 | else: 123 | raise NotImplementedError( 124 | '{} not implemented.'.format(weight_init)) 125 | 126 | elif weight_init == 'noisy_identity' and self.PP > 2: 127 | b = 0.01 # U(0,b) 128 | # Out ← [1, B, C, 1, P, P, 1, 1, K, K] 129 | self.W_ij = nn.Parameter(torch.clamp(.1*torch.eye(self.P, self.P).repeat( 130 | 1, self.B, self.C, 1, 1, 1, self.K, self.K, 1, 1) + 131 | torch.empty(1, self.B, self.C, 1, 1, 1, self.K, 132 | self.K, self.P, self.P).uniform_(0, b), 133 | max=1).permute(0, 1, 2, 3, -2, -1, 4, 5, 6, 7)) 134 | else: 135 | raise NotImplementedError( 136 | '{} not implemented.'.format(weight_init)) 137 | 138 | if self.padding != 0: 139 | if isinstance(self.padding, int): 140 | self.padding = [self.padding]*4 141 | 142 | # ([?, B, F, F], [?, B, P, P, F, F]) ← In 143 | def forward(self, activations, poses): 144 | 145 | self.W_reg = None 146 | if(self.share_W_ij): 147 | self.W_reg = torch.matmul(self.W_ij.squeeze(), self.W_ij.squeeze().permute( 148 | 0, 1, 3, 2)) - torch.eye(4, device=self.device) 149 | self.W_reg = torch.norm(self.W_reg) 150 | 151 | if self.padding != 0: 152 | activations = F.pad(activations, self.padding) # [1,1,1,1] 153 | poses = F.pad(poses, self.padding + [0]*4) # [0,0,1,1,1,1] 154 | 155 | # share the matrices over (F*F), if class caps layer 156 | if self.share_W_ij: 157 | self.K = poses.shape[-1] # out_caps (C) feature map size 158 | 159 | self.F = (poses.shape[-1] - self.K) // self.S + 1 # featuremap size 160 | 161 | # Out ← [?, B, P, P, F', F', K, K] ← [?, B, P, P, F, F] 162 | poses = poses.unfold(4, size=self.K, step=self.S).unfold( 163 | 5, size=self.K, step=self.S) 164 | 165 | # Out ← [?, B, 1, P, P, 1, F', F', K, K] ← [?, B, P, P, F', F', K, K] 166 | poses = poses.unsqueeze(2).unsqueeze(5) 167 | 168 | # Out ← [?, B, F', F', K, K] ← [?, B, F, F] 169 | activations = activations.unfold( 170 | 2, size=self.K, step=self.S).unfold(3, size=self.K, step=self.S) 171 | 172 | # Out ← [?, B, 1, 1, 1, F', F', K, K] ← [?, B, F', F', K, K] 173 | activations = activations.reshape(-1, self.B, 1, 174 | 1, 1, *activations.shape[2:4], self.K, self.K) 175 | 176 | # Out ← [?, B, C, P, P, F', F', K, K] ← ([?, B, 1, P, P, 1, F', F', K, K] * [1, B, C, 1, P, P, 1, 1, K, K]) 177 | V_ji = (poses * self.W_ij).sum(dim=4) # matmul equiv. 178 | 179 | # Out ← [?, B, C, P*P, 1, F', F', K, K] ← [?, B, C, P, P, F', F', K, K] 180 | V_ji = V_ji.reshape(-1, self.B, self.C, self.P*self.P, 181 | 1, *V_ji.shape[-4:-2], self.K, self.K) 182 | 183 | if self.coor_add: 184 | # if class caps layer (featuremap size = 1) 185 | if V_ji.shape[-1] == 1: 186 | self.F = self.K # 1->4 187 | 188 | # coordinates = torch.arange(self.F, dtype=torch.float32) / self.F 189 | coordinates = torch.arange( 190 | self.F, dtype=torch.float32).add(1.) 
/ (self.F*10) 191 | i_vals = torch.zeros(self.P*self.P, self.F, 1, device=self.device) 192 | j_vals = torch.zeros(self.P*self.P, 1, self.F, device=self.device) 193 | i_vals[self.P-1, :, 0] = coordinates 194 | j_vals[2*self.P-1, 0, :] = coordinates 195 | 196 | if V_ji.shape[-1] == 1: # if class caps layer 197 | # Out ← [?, B, C, P*P, 1, 1, 1, K=F, K=F] (class caps) 198 | V_ji = V_ji + (i_vals + j_vals).reshape(1, 1, 1, 199 | self.P*self.P, 1, 1, 1, self.F, self.F) 200 | return activations, V_ji 201 | 202 | # Out ← [?, B, C, P*P, 1, F, F, K, K] 203 | V_ji = V_ji + (i_vals + j_vals).reshape(1, 1, 1, 204 | self.P*self.P, 1, self.F, self.F, 1, 1) 205 | 206 | return activations, V_ji, self.W_reg 207 | 208 | 209 | class FullyConnected2d(pl.LightningModule): 210 | '''Fully Connected 2D Layer''' 211 | 212 | def __init__(self, batch_size, in_features, in_classes=19, dense_1_features=1024, dense_2_features=2048, rate='0.3'): 213 | super().__init__() 214 | self.batch_size = batch_size 215 | self.in_classes = in_classes 216 | 217 | #---------------------------------------------------------------------------- 218 | 219 | self.Drop_1 = nn.Dropout(p=float(rate)) 220 | self.DenseReLU_1 = nn.Linear(in_features, dense_1_features) 221 | nn.init.xavier_uniform(self.DenseReLU_1.weight) 222 | # self.IN_1 = nn.LayerNorm(dense_1_features) 223 | 224 | #---------------------------------------------------------------------------- 225 | 226 | self.Drop_2 = nn.Dropout(p=float(rate)) 227 | self.DenseReLU_2 = nn.Linear(dense_1_features, dense_2_features) 228 | nn.init.xavier_uniform(self.DenseReLU_2.weight) 229 | # self.IN_2 = nn.LayerNorm(dense_2_features) 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | # self.Drop_3 = nn.Dropout(p=float(rate)) 234 | self.DenseSigm = nn.Linear( 235 | dense_2_features, 2*self.in_classes) 236 | nn.init.xavier_uniform(self.DenseSigm.weight) 237 | # self.IN_3 = nn.LayerNorm(2*self.in_classes) 238 | 239 | def forward(self, x): 240 | 241 | x = self.Drop_1(x) 242 | x = self.DenseReLU_1(x) 243 | # x = self.IN_1(x) 244 | x = F.gelu(x) 245 | 246 | #---------------------------------------------------------------------------- 247 | 248 | x = self.Drop_2(x) 249 | x = self.DenseReLU_2(x) 250 | # x = self.IN_2(x) 251 | x = F.gelu(x) 252 | 253 | #---------------------------------------------------------------------------- 254 | 255 | # x = self.Drop_3(x) 256 | x = self.DenseSigm(x) 257 | # x = self.IN_3(x) 258 | # x = F.sigmoid(x) 259 | x = F.gelu(x) 260 | 261 | #---------------------------------------------------------------------------- 262 | 263 | x = torch.reshape( 264 | x, (self.batch_size, self.in_classes, 2, 1)) 265 | 266 | #---------------------------------------------------------------------------- 267 | 268 | return x 269 | 270 | 271 | class FullyConnected3d(pl.LightningModule): 272 | '''Fully Connected 3D Layer''' 273 | 274 | def __init__(self, batch_size, in_features, in_classes=19, dense_1_features=1024, dense_2_features=2048, rate='0.3'): 275 | super().__init__() 276 | self.batch_size = batch_size 277 | self.in_classes = in_classes 278 | 279 | #---------------------------------------------------------------------------- 280 | 281 | self.Drop_1 = nn.Dropout(p=float(rate)) 282 | self.DenseReLU_1 = nn.Linear(in_features, dense_1_features) 283 | nn.init.xavier_uniform(self.DenseReLU_1.weight) 284 | # self.IN_1 = nn.LayerNorm(dense_1_features) 285 | 286 | #---------------------------------------------------------------------------- 287 | 
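        # Same two-layer GELU MLP as FullyConnected2d, except that the last
        # linear layer regresses 3 coordinates per joint and forward() reshapes
        # the output to (batch_size, in_classes, 3, 1).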
288 | self.Drop_2 = nn.Dropout(p=float(rate)) 289 | self.DenseReLU_2 = nn.Linear(dense_1_features, dense_2_features) 290 | nn.init.xavier_uniform(self.DenseReLU_2.weight) 291 | # self.IN_2 = nn.LayerNorm(dense_2_features) 292 | 293 | #---------------------------------------------------------------------------- 294 | 295 | # self.Drop_3 = nn.Dropout(p=float(rate)) 296 | self.DenseTanh = nn.Linear( 297 | dense_2_features, 3*self.in_classes) 298 | nn.init.xavier_uniform(self.DenseTanh.weight) 299 | # self.IN_3 = nn.InstanceNorm1d(3*self.in_classes) 300 | 301 | def forward(self, x): 302 | 303 | x = self.Drop_1(x) 304 | x = self.DenseReLU_1(x) 305 | # x = self.IN_1(x) 306 | x = F.gelu(x) 307 | 308 | #---------------------------------------------------------------------------- 309 | 310 | x = self.Drop_2(x) 311 | x =self.DenseReLU_2(x) 312 | # x = self.IN_2(x) 313 | x = F.gelu(x) 314 | 315 | #---------------------------------------------------------------------------- 316 | 317 | # x = self.Drop_3(x) 318 | x = self.DenseTanh(x) 319 | # x = self.IN_3(x) 320 | # x = F.tanh(x) 321 | x = F.gelu(x) 322 | 323 | #---------------------------------------------------------------------------- 324 | 325 | x = torch.reshape( 326 | x, (self.batch_size, self.in_classes, 3, 1)) 327 | 328 | #---------------------------------------------------------------------------- 329 | 330 | return x 331 | 332 | 333 | class DepthReconstruction(pl.LightningModule): 334 | '''Depth Reconstruction Layer''' 335 | 336 | def __init__(self, batch_size, input_width, input_height, in_features, in_classes=19, dense_1_features=1024, dense_2_features=2048, rate='0.3'): 337 | super().__init__() 338 | self.batch_size = batch_size 339 | self.in_classes = in_classes 340 | self.input_width = input_width 341 | self.input_height = input_height 342 | 343 | #---------------------------------------------------------------------------- 344 | 345 | self.Drop_1 = nn.Dropout(p=float(rate)) 346 | self.DenseReLU_1 = nn.Linear(in_features, dense_1_features) 347 | nn.init.xavier_uniform(self.DenseReLU_1.weight) 348 | # self.IN_1 = nn.LayerNorm(dense_1_features) 349 | 350 | #---------------------------------------------------------------------------- 351 | 352 | self.Drop_2 = nn.Dropout(p=float(rate)) 353 | self.DenseReLU_2 = nn.Linear(dense_1_features, dense_2_features) 354 | nn.init.xavier_uniform(self.DenseReLU_2.weight) 355 | # self.IN_2 = nn.LayerNorm(dense_2_features) 356 | 357 | #---------------------------------------------------------------------------- 358 | 359 | # self.Drop_3 = nn.Dropout(p=float(rate)) 360 | self.DenseReLU_3 = nn.Linear( 361 | dense_2_features, 64*64*3) 362 | nn.init.xavier_uniform(self.DenseReLU_3.weight) 363 | # self.IN_3 = nn.InstanceNorm1d(3*self.in_classes) 364 | 365 | self.ConvTrans = nn.ConvTranspose2d( 366 | 3, 3, 4, stride=4, padding=0) 367 | 368 | #self.BN_3 = nn.BatchNorm1d(64*64*self.in_classes) 369 | 370 | def forward(self, x): 371 | 372 | x = self.Drop_1(x) 373 | x = self.DenseReLU_1(x) 374 | # x = self.IN_1(x) 375 | x = F.gelu(x) 376 | 377 | #---------------------------------------------------------------------------- 378 | 379 | x = self.Drop_2(x) 380 | x =self.DenseReLU_2(x) 381 | # x = self.IN_2(x) 382 | x = F.gelu(x) 383 | 384 | #---------------------------------------------------------------------------- 385 | 386 | # x = self.Drop_3(x) 387 | x = self.DenseReLU_3(x) 388 | # x = self.IN_3(x) 389 | # x = F.tanh(x) 390 | x = F.gelu(x) 391 | 392 | 
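        # The MLP output (64*64*3 units) is reshaped into a coarse 3x64x64 map
        # below and upsampled by the stride-4 transposed convolution, giving a
        # (batch_size, 3, 256, 256) depth reconstruction.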
#---------------------------------------------------------------------------- 393 | 394 | x = torch.reshape( 395 | x, (self.batch_size, 3, 64, 64)) # > (10, 3, 64, 64) 396 | 397 | # x = F.interpolate(x, size=( 398 | # self.input_width, self.input_height), mode='bilinear') # > (10, 19, 256, 256) 399 | 400 | x = self.ConvTrans(x) 401 | 402 | # x = x.permute(0, 2, 3, 1) # > (10, 256, 256, 19) 403 | 404 | return x 405 | 406 | 407 | class VariationalBayesRouting2d(pl.LightningModule): 408 | '''Variational Bayes Capsule Routing Layer''' 409 | 410 | def __init__(self, in_caps, out_caps, pose_dim, 411 | kernel_size, stride, 412 | alpha0, # Dirichlet 413 | m0, kappa0, # Gaussian 414 | Psi0, nu0, # Wishart 415 | cov='diag', iter=3, class_caps=False): 416 | super().__init__() 417 | 418 | self.B = in_caps 419 | self.C = out_caps 420 | self.P = pose_dim 421 | self.D = np.max([2, self.P*self.P]) 422 | self.K = kernel_size 423 | self.S = stride 424 | 425 | self.cov = cov # diag/full 426 | self.iter = iter # routing iters 427 | self.class_caps = class_caps 428 | self.n_classes = out_caps if class_caps else None 429 | 430 | # dirichlet prior parameter 431 | self.alpha0 = torch.tensor(alpha0).type(torch.FloatTensor) 432 | # self.alpha0 = nn.Parameter(torch.zeros(1,1,self.C,1,1,1,1,1,1).fill_(alpha0)) learn it by backprop 433 | 434 | # Out ← [?, 1, C, P*P, 1, 1, 1, 1, 1] 435 | self.register_buffer('m0', m0.unsqueeze(0).repeat( 436 | self.C, 1).reshape(1, 1, self.C, self.D, 1, 1, 1, 1, 1)) # gaussian prior mean parameter 437 | 438 | # precision scaling parameter of gaussian prior over capsule component means 439 | self.kappa0 = kappa0 440 | 441 | # scale matrix of wishart prior over capsule precisions 442 | if self.cov == 'diag': 443 | # Out ← [?, 1, C, P*P, 1, 1, 1, 1, 1] 444 | self.register_buffer('Psi0', torch.diag(Psi0).unsqueeze(0).repeat( 445 | self.C, 1).reshape(1, 1, self.C, self.D, 1, 1, 1, 1, 1)) 446 | 447 | elif self.cov == 'full': 448 | # Out ← [?, 1, C, P*P, P*P, 1, 1, 1, 1] 449 | self.register_buffer('Psi0', Psi0.unsqueeze(0).repeat( 450 | self.C, 1, 1).reshape(1, 1, self.C, self.D, self.D, 1, 1, 1, 1)) 451 | 452 | # degree of freedom parameter of wishart prior capsule precisions 453 | self.nu0 = nu0 454 | 455 | # log determinant = 0, if Psi0 is identity 456 | self.register_buffer('lndet_Psi0', 2*torch.diagonal(torch.cholesky( 457 | Psi0)).log().sum()) 458 | 459 | # pre compute the argument of the digamma function in E[ln|lambda_j|] 460 | self.register_buffer('diga_arg', torch.arange(self.D).reshape( 461 | 1, 1, 1, self.D, 1, 1, 1, 1, 1).type(torch.FloatTensor)) 462 | 463 | # pre define some constants 464 | self.register_buffer('Dlog2', 465 | self.D*torch.log(torch.tensor(2.)).type(torch.FloatTensor)) 466 | self.register_buffer('Dlog2pi', 467 | self.D*torch.log(torch.tensor(2.*np.pi)).type(torch.FloatTensor)) 468 | 469 | # Out ← [K*K, 1, K, K] vote collecting filter 470 | self.register_buffer('filter', 471 | torch.eye(self.K*self.K).reshape(self.K*self.K, 1, self.K, self.K)) 472 | 473 | # Out ← [1, 1, C, 1, 1, 1, 1, 1, 1] optional params 474 | self.beta_u = nn.Parameter(torch.zeros(1, 1, self.C, 1, 1, 1, 1, 1, 1)) 475 | self.beta_a = nn.Parameter(torch.zeros(1, 1, self.C, 1, 1, 1, 1, 1, 1)) 476 | 477 | self.BN_v = nn.BatchNorm3d(self.C, affine=False) 478 | self.BN_a = nn.BatchNorm2d(self.C, affine=False) 479 | 480 | # Out ← [?, B, 1, 1, 1, F, F, K, K], [?, B, C, P*P, 1, F, F, K, K] ← In 481 | def forward(self, a_i, V_ji): 482 | 483 | # input capsule (B) votes feature map size (K) 484 | self.F_i 
= a_i.shape[-2:] 485 | self.F_o = a_i.shape[-4:-2] # output capsule (C) feature map size (F) 486 | # total num of lower level capsules 487 | self.N = self.B*self.F_i[0]*self.F_i[1] 488 | 489 | # Out ← [1, B, C, 1, 1, 1, 1, 1, 1] 490 | R_ij = (1./self.C) * torch.ones(1, self.B, self.C, 491 | 1, 1, 1, 1, 1, 1, requires_grad=False, device=self.device) 492 | 493 | for i in range(self.iter): # routing iters 494 | 495 | # update capsule parameter distributions 496 | self.update_qparam(a_i, V_ji, R_ij) 497 | 498 | if i != self.iter-1: # skip last iter 499 | # update latent variable distributions (child to parent capsule assignments) 500 | R_ij = self.update_qlatent(a_i, V_ji) 501 | 502 | # Out ← [?, 1, C, 1, 1, F, F, 1, 1] 503 | self.Elnlambda_j = self.reduce_poses( 504 | torch.digamma(.5*(self.nu_j - self.diga_arg))) \ 505 | + self.Dlog2 + self.lndet_Psi_j 506 | 507 | # Out ← [?, 1, C, 1, 1, F, F, 1, 1] 508 | self.Elnpi_j = torch.digamma(self.alpha_j) \ 509 | - torch.digamma(self.alpha_j.sum(dim=2, keepdim=True)) 510 | 511 | # subtract "- .5*ln|lmbda|" due to precision matrix, instead of adding "+ .5*ln|sigma|" for covariance matrix 512 | # posterior entropy H[q*(mu_j, sigma_j)] 513 | H_q_j = .5*self.D * \ 514 | torch.log(torch.tensor(2*np.pi*np.e)) - .5*self.Elnlambda_j 515 | 516 | # Out ← [?, 1, C, 1, 1, F, F, 1, 1] weighted negative entropy with optional beta params and R_j weight 517 | a_j = self.beta_a - (torch.exp(self.Elnpi_j) * 518 | H_q_j + self.beta_u) # * self.R_j 519 | 520 | # Out ← [?, C, F, F] 521 | a_j = a_j.squeeze() 522 | 523 | # Out ← [?, C, P*P, F, F] ← [?, 1, C, P*P, 1, F, F, 1, 1] 524 | self.m_j = self.m_j.squeeze() 525 | 526 | # so BN works in the classcaps layer 527 | if self.class_caps: 528 | # Out ← [?, C, 1, 1] ← [?, C] 529 | a_j = a_j[..., None, None] 530 | 531 | # Out ← [?, C, P*P, 1, 1] ← [?, C, P*P] 532 | self.m_j = self.m_j[..., None, None] 533 | # else: 534 | # self.m_j = self.BN_v(self.m_j) 535 | 536 | # Out ← [?, C, P*P, F, F] 537 | # use 'else' above to deactivate BN_v for class_caps 538 | self.m_j = self.BN_v(self.m_j) 539 | 540 | # Out ← [?, C, P, P, F, F] ← [?, C, P*P, F, F] 541 | self.m_j = self.m_j.reshape(-1, self.C, self.P, self.P, *self.F_o) 542 | 543 | # Out ← [?, C, F, F] 544 | a_j = torch.sigmoid(self.BN_a(a_j)) 545 | 546 | # propagate posterior means to next layer 547 | return a_j.squeeze(), self.m_j.squeeze() 548 | 549 | def update_qparam(self, a_i, V_ji, R_ij): 550 | 551 | # Out ← [?, B, C, 1, 1, F, F, K, K] 552 | # broadcast a_i 1->C, and R_ij (1,1,1,1)->(F,F,K,K), 1->batch 553 | R_ij = R_ij * a_i 554 | 555 | # Out ← [?, 1, C, 1, 1, F, F, 1, 1] 556 | self.R_j = self.reduce_icaps(R_ij) 557 | 558 | # Out ← [?, 1, C, 1, 1, F, F, 1, 1] 559 | self.alpha_j = self.alpha0 + self.R_j 560 | # self.alpha_j = torch.exp(self.alpha0) + self.R_j # when alpha's a param 561 | self.kappa_j = self.kappa0 + self.R_j 562 | self.nu_j = self.nu0 + self.R_j 563 | 564 | # Out ← [?, 1, C, P*P, 1, F, F, 1, 1] 565 | mu_j = (1./self.R_j) * self.reduce_icaps(R_ij * V_ji) 566 | 567 | # Out ← [?, 1, C, P*P, 1, F, F, 1, 1] 568 | # self.m_j = (1./self.kappa_j) * (self.R_j * mu_j + self.kappa0 * self.m0) # use this if self.m0 != 0 569 | # priors removed for faster computation 570 | self.m_j = (1./self.kappa_j) * (self.R_j * mu_j) 571 | 572 | if self.cov == 'diag': 573 | # Out ← [?, 1, C, P*P, 1, F, F, 1, 1] (1./R_j) not needed because Psi_j calc 574 | sigma_j = self.reduce_icaps(R_ij * (V_ji - mu_j).pow(2)) 575 | 576 | # Out ← [?, 1, C, P*P, 1, F, F, 1, 1] 577 | # self.invPsi_j = 
self.Psi0 + sigma_j + (self.kappa0*self.R_j / self.kappa_j) \ 578 | # * (mu_j - self.m0).pow(2) # use this if m0 != 0 or kappa0 != 1 579 | # priors removed for faster computation 580 | self.invPsi_j = self.Psi0 + sigma_j + \ 581 | (self.R_j / self.kappa_j) * (mu_j).pow(2) 582 | 583 | # Out ← [?, 1, C, 1, 1, F, F, 1, 1] (-) sign as inv. Psi_j 584 | # log det of diag precision matrix 585 | self.lndet_Psi_j = -self.reduce_poses(torch.log(self.invPsi_j)) 586 | 587 | elif self.cov == 'full': 588 | # [?, B, C, P*P, P*P, F, F, K, K] 589 | sigma_j = self.reduce_icaps( 590 | R_ij * (V_ji - mu_j) * (V_ji - mu_j).transpose(3, 4)) 591 | 592 | # Out ← [?, 1, C, P*P, P*P, F, F, 1, 1] full cov, torch.inverse(self.Psi0) 593 | self.invPsi_j = self.Psi0 + sigma_j + (self.kappa0*self.R_j / self.kappa_j) \ 594 | * (mu_j - self.m0) * (mu_j - self.m0).transpose(3, 4) 595 | 596 | # Out ← [?, 1, C, F, F, 1, 1 , P*P, P*P] 597 | # needed for pytorch (*,n,n) dim requirements in .cholesky and .inverse 598 | self.invPsi_j = self.invPsi_j.permute(0, 1, 2, 5, 6, 7, 8, 3, 4) 599 | 600 | # Out ← [?, 1, 1, 1, C, F, F, 1, 1] (-) sign as inv. Psi_j 601 | self.lndet_Psi_j = -2*torch.diagonal(torch.cholesky( 602 | self.invPsi_j), dim1=-2, dim2=-1).log().sum(-1, keepdim=True)[..., None] 603 | 604 | def update_qlatent(self, a_i, V_ji): 605 | 606 | # Out ← [?, 1, C, 1, 1, F, F, 1, 1] 607 | self.Elnpi_j = torch.digamma(self.alpha_j) \ 608 | - torch.digamma(self.alpha_j.sum(dim=2, keepdim=True)) 609 | 610 | # Out ← [?, 1, C, 1, 1, F, F, 1, 1] broadcasting diga_arg 611 | self.Elnlambda_j = self.reduce_poses( 612 | torch.digamma(.5*(self.nu_j - self.diga_arg))) \ 613 | + self.Dlog2 + self.lndet_Psi_j 614 | 615 | if self.cov == 'diag': 616 | # Out ← [?, B, C, 1, 1, F, F, K, K] 617 | ElnQ = (self.D/self.kappa_j) + self.nu_j \ 618 | * self.reduce_poses((1./self.invPsi_j) * (V_ji - self.m_j).pow(2)) 619 | 620 | elif self.cov == 'full': 621 | # Out ← [?, B, C, 1, 1, F, F, K, K] 622 | Vm_j = V_ji - self.m_j 623 | ElnQ = (self.D/self.kappa_j) + self.nu_j * self.reduce_poses( 624 | Vm_j.transpose(3, 4) * torch.inverse( 625 | self.invPsi_j).permute(0, 1, 2, 7, 8, 3, 4, 5, 6) * Vm_j) 626 | 627 | # Out ← [?, B, C, 1, 1, F, F, K, K] 628 | lnp_j = .5*self.Elnlambda_j - .5*self.Dlog2pi - .5*ElnQ 629 | 630 | # Out ← [?, B, C, 1, 1, F, F, K, K] 631 | p_j = torch.exp(self.Elnpi_j + lnp_j) 632 | 633 | # Out ← [?*B, 1, F', F'] ← [?*B, K*K, F, F] ← [?, B, 1, 1, 1, F, F, K, K] 634 | sum_p_j = F.conv_transpose2d( 635 | input=p_j.sum(dim=2, keepdim=True).reshape( 636 | -1, *self.F_o, self.K*self.K).permute(0, -1, 1, 2).contiguous(), 637 | weight=self.filter, 638 | stride=[self.S, self.S]) 639 | 640 | # Out ← [?*B, 1, F, F, K, K] 641 | sum_p_j = sum_p_j.unfold(2, size=self.K, step=self.S).unfold( 642 | 3, size=self.K, step=self.S) 643 | 644 | # Out ← [?, B, 1, 1, 1, F, F, K, K] 645 | sum_p_j = sum_p_j.reshape( 646 | [-1, self.B, 1, 1, 1, *self.F_o, self.K, self.K]) 647 | 648 | # Out ← [?, B, C, 1, 1, F, F, K, K] # normalise over out_caps j 649 | return 1. / torch.clamp(sum_p_j, min=1e-11) * p_j 650 | 651 | def reduce_icaps(self, x): 652 | return x.sum(dim=(1, -2, -1), keepdim=True) 653 | 654 | def reduce_poses(self, x): 655 | return x.sum(dim=(3, 4), keepdim=True) 656 | --------------------------------------------------------------------------------
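As a quick sanity check of the capsule layers in `src/layers.py`, the following hypothetical snippet pushes a dummy tensor through `PrimaryCapsules2d` and `ConvCapsules2d` and prints the resulting shapes. It assumes `src` is on the import path with the pinned dependencies installed; the capsule counts and feature-map size are illustrative and do not match `config/config_TV.cfg`.
```
# Hypothetical shape check for the capsule layers defined in src/layers.py.
import torch
from layers import PrimaryCapsules2d, ConvCapsules2d

x = torch.randn(2, 64, 12, 12)                      # [?, A, F, F]

prim = PrimaryCapsules2d(in_channels=64, out_caps=8,
                         kernel_size=1, stride=1, pose_dim=4)
a, v = prim(x)
print(a.shape, v.shape)   # [2, 8, 12, 12], [2, 8, 4, 4, 12, 12]

conv = ConvCapsules2d(in_caps=8, out_caps=8, pose_dim=4,
                      kernel_size=3, stride=1)
a, v, _ = conv(a, v)      # activations and votes for the routing layer
print(a.shape, v.shape)   # [2, 8, 1, 1, 1, 10, 10, 3, 3], [2, 8, 8, 16, 1, 10, 10, 3, 3]
```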
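The multi-task objective in `DECA.pose_loss` combines the 3D, 2D, depth and weight-regularisation terms with `s1`..`s4`, which appear to be learnable log-variance scalars in the style of the homoscedastic-uncertainty weighting of Kendall et al.: each term is scaled by `0.5*exp(-s_k)` and `s_k` is added back so the variances cannot grow unbounded. Below is a minimal, self-contained sketch of that weighting; it is not the repository's class, and the task list and initialisation are assumptions for illustration only.
```
# Minimal sketch of uncertainty-based multi-task loss weighting (assumption:
# one learnable log-variance per task, initialised to zero).
import torch
import torch.nn as nn


class UncertaintyWeightedLoss(nn.Module):
    def __init__(self, n_tasks=4):
        super().__init__()
        # one learnable log-variance s_k per task
        self.s = nn.Parameter(torch.zeros(n_tasks))

    def forward(self, task_losses):
        # task_losses: iterable of scalar tensors,
        # e.g. [loss3D, loss2D, lossD, W_reg / 3000]
        total = torch.zeros((), dtype=torch.float32)
        for k, loss_k in enumerate(task_losses):
            total = total + 0.5 * torch.exp(-self.s[k]) * loss_k + self.s[k]
        return total


# toy usage on dummy scalar losses
criterion = UncertaintyWeightedLoss(n_tasks=4)
dummy = [torch.tensor(1.0), torch.tensor(0.5), torch.tensor(2.0), torch.tensor(0.1)]
print(criterion(dummy))
```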