├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── check_dataset_stat.py ├── dataset ├── __init__.py ├── base_dataset.py ├── data_model.py ├── import_data.py ├── importers.py ├── real_dataset.py └── synth_dataset.py ├── model ├── __init__.py ├── base_model.py ├── hpe_model.py ├── pipeline │ ├── __init__.py │ ├── base_pipeline.py │ └── deep_fisheye_pipeline.py └── pix2depth_model.py ├── network ├── __init__.py ├── base_net.py ├── basic_block.py ├── hand_depth_net.py ├── hand_module_net.py ├── hand_pose_net.py ├── helper.py ├── norm.py ├── projection.py ├── resnet.py └── unprojection.py ├── option ├── __init__.py ├── base.py ├── general.py ├── hpe.py ├── options.py └── pix2depth.py ├── plot_result_loss.py ├── preset ├── __init__.py ├── base_preset.py └── presets.py ├── requirements.txt ├── run ├── __init__.py ├── base │ ├── __init__.py │ ├── base_run.py │ ├── hpe_base_util.py │ ├── hpe_test_base_run.py │ ├── hpe_train_base_run.py │ ├── test_base_run.py │ └── train_base_run.py ├── empty_run.py ├── pipeline │ ├── helper.py │ ├── pipeline_test_run.py │ └── pipeline_train_run.py ├── pix2depth │ ├── helper.py │ └── pix2depth_train_run.py └── pix2joint │ ├── __init__.py │ ├── pix2joint_test_run.py │ └── pix2joint_train_run.py ├── run_board.sh ├── scripts ├── check_synth_dataset_stat.sh ├── pipeline │ ├── test_pipeline_synth.sh │ ├── train_pipeline_real.sh │ ├── train_pipeline_real_local.sh │ ├── train_pipeline_synth.sh │ └── train_pipeline_synth_local.sh ├── pix2depth │ ├── train_pix2depth.sh │ └── train_pix2depth_local.sh └── pix2joint │ ├── test_pix2joint.sh │ ├── test_real_pix2joint.sh │ ├── train_pix2joint.sh │ ├── train_pix2joint_local.sh │ ├── train_real_pix2joint.sh │ └── train_real_pix2joint_local.sh ├── technical_concept.png ├── test.py ├── train.py └── util ├── __init__.py ├── debug.py ├── filter.py ├── fisheye.py ├── hooks.py ├── image.py ├── io.py ├── joint.py ├── math.py ├── package.py ├── projector.py ├── unwarp.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 KAIST HCI Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepFisheye Network 2 | This is a codebase for training and testing the DeepFisheye network. It is based on Python3 and Pytorch. 
3 | Our work, **'DeepFisheye: Near-Surface Multi-Finger Tracking Technology Using Fisheye Camera'**, was presented at UIST 2020. 4 | 5 | ![technical concept](technical_concept.png) 6 | 7 | Near-surface multi-finger tracking (NMFT) technology expands the input space of touchscreens by enabling novel interactions such as mid-air and finger-aware interactions. We present DeepFisheye, a practical NMFT solution for mobile devices that utilizes a fisheye camera attached at the bottom of a touchscreen. DeepFisheye acquires the image of an interacting hand positioned above the touchscreen using the camera and employs deep learning to estimate the 3D position of each fingertip. Additionally, we created simple rule-based classifiers that estimate the contact finger and hand posture from DeepFisheye’s output. 8 | 9 | ## Related Links 10 | - [Paper](https://dl.acm.org/doi/abs/10.1145/3379337.3415818) 11 | - [Project page](http://kwpark.io/deepfisheye) 12 | - [Dataset project page](https://github.com/KeunwooPark/DeepFisheyeDataset) 13 | - [Pretrained model weight](https://drive.google.com/file/d/1C_kbaw1Ull4D_JHgDkhrLwdjCITzj-8E/view?usp=sharing) 14 | 15 | ## Demo 16 | If you want to try out the pretrained model, please see [DeepFisheye_Realtime](https://github.com/KeunwooPark/DeepFisheye_Realtime). 17 | 18 | ## Short Explanations of the Folders 19 | 20 | |Folder Name|What is it?| 21 | |-------------|-------------| 22 | |dataset | Dataloaders and file paths of datasets.| 23 | |model | Network models. A model contains a network and an optimizer. | 24 | | +- pipeline | A pipeline connects multiple models into one.| 25 | |network | Actual deep learning networks.| 26 | |options | Options for models and pipelines. | 27 | |preset| A preset is a set of predefined options.| 28 | |run| A run contains the actual training or testing logic.| 29 | |scripts| Scripts for running training and testing.| 30 | |results (will be created dynamically)| All the results go in here.| 31 | |util| Useful functions. | 32 | |```train.py``` and ```test.py``` | Main functions of the training and testing process. You do not have to touch these files.| 33 | 34 | 35 | ## Quick Start: How to Train and Test 36 | ### 1. Install the requirements. 37 | 38 | ```shell 39 | $ pip install -r requirements.txt 40 | ``` 41 | 42 | ### 2. Download and Unzip a Dataset. 43 | 44 | There are two datasets: one is the *synthetic* dataset and the other is the *real* dataset. It is possible to run training and testing with only one dataset. Please visit the [dataset project page](https://github.com/KeunwooPark/DeepFisheyeDataset) to download the datasets. Then, unzip the dataset file wherever you want. 45 | 46 | 47 | ### 3. Import a Dataset. 48 | 49 | Importing a dataset means creating text files that hold the file paths of the dataset. For example, you can import the synthetic dataset by running the script below. You have to pass the root directory of the dataset. 50 | 51 | ```shell 52 | $ python dataset/import_data.py --data_type synth --data_dir root/path/of/dataset --test_ratio 0.2 53 | ``` 54 | 55 | Then two text files are created under a directory named after the dataset type (e.g. synth, real). 56 | 57 | ### 4. Run Training 58 | 59 | There are three training steps in our paper. You have to run the corresponding script for each step. If you want to use [tensorboard](https://www.tensorflow.org/tensorboard), run ```run_board.sh``` before running a script. 60 | 61 | A script with the ```_local``` postfix is for testing. It runs with a very small dataset.
If you have a local computer with a small GPU and want to run the main training on a large GPU server, use a local script to test on your local computer before running the training on the server. 62 | 63 | 1. Create a directory for saving trained models (e.g. trained_models). 64 | 2. Run ```pix2depth/train_pix2depth.sh```. **Change the parameters for your hardware environment (e.g. batch_size, gpu_ids).** 65 | 3. Choose the weight file (```*.pth.tar```) of the Pix2Depth network that you want to use for the next step, and move the file to ```trained_models```. 66 | 4. In ```pipeline/train_pipeline_synth.sh```, set the ```--p2d_pretrained``` parameter to the location of the weight file. 67 | 5. Choose and move the weight file of the pipeline as in Step 3. 68 | 6. In ```pipeline/train_pipeline_real.sh```, set the ```--pipeline_pretrained``` parameter to the location of the trained pipeline's weight file. 69 | 70 | ### (5. Run Testing) 71 | 72 | You can run a test for each trained model with the ```test_*.sh``` scripts. All you have to do is set the ```--*_pretrained``` parameter correctly. 73 | 74 | ## Using the pretrained model 75 | You can use the pretrained model that we used in our paper. 76 | 1. Download the model weight file from [here](https://drive.google.com/file/d/1C_kbaw1Ull4D_JHgDkhrLwdjCITzj-8E/view?usp=sharing). 77 | 2. Set the ```pipeline_pretrained``` parameter in the option to the location of the downloaded weight file. 78 | 3. Create the pipeline (a loading sketch is shown at the end of this README). 79 | 80 | ## FAQ 81 | 82 | ### Is the real dataset the same as the dataset used for user testing in the paper? 83 | No. The data used in the user tests is different from the real dataset. We used the real dataset only for training. We cannot share the actual test data because of privacy issues. 84 | 85 | ### Can I create a virtual fisheye camera? 86 | Please check out the [fisheye camera mesh generator](https://github.com/KeunwooPark/fisheye_mesh_generator). 87 | 88 | # Contact 89 | - [Keunwoo Park](http://kwpark.io) 90 | 91 | # Citation 92 | Please cite our paper in your publication if it helped your research. 93 | 94 | ``` 95 | @inproceedings{Park2020DeepFisheye, 96 | author = {Park, Keunwoo and Kim, Sunbum and Yoon, Youngwoo and Kim, Tae-Kyun and Lee, Geehyuk}, 97 | title = {DeepFisheye: Near-Surface Multi-Finger Tracking Technology Using Fisheye Camera}, 98 | year = {2020}, 99 | isbn = {9781450375146}, 100 | publisher = {Association for Computing Machinery}, 101 | address = {New York, NY, USA}, 102 | url = {https://doi.org/10.1145/3379337.3415818}, 103 | doi = {10.1145/3379337.3415818}, 104 | booktitle = {Proceedings of the 33rd Annual ACM Symposium on User Interface Software and Technology}, 105 | pages = {1132–1146}, 106 | numpages = {15}, 107 | keywords = {near-surface, deep learning, touchscreen, computer vision, finger tracking}, 108 | location = {Virtual Event, USA}, 109 | series = {UIST '20} 110 | } 111 | ``` 112 | 113 | # Acknowledgments 114 | Our code is inspired by [cyclegan](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
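## Loading the pretrained pipeline in Python (sketch)

The snippet below is a minimal, untested sketch of how the pieces above fit together, following the same option-handling pattern as ```check_dataset_stat.py```. The pipeline name ```'deep_fisheye'```, the weight file path, and the assumption that the loaded options put every model in evaluation mode are not confirmed here; adjust them to your preset and scripts.

```python
import torch

from option.options import Options
from preset import modify_options
from model.pipeline import find_pipeline_using_name

# Build the options the same way check_dataset_stat.py does;
# presets and scripts normally set the modes and pretrained paths.
options = Options()
options.initialize()
modify_options(options)
options.parse()
options.general.pipeline_pretrained = 'trained_models/pipeline.pth.tar'  # hypothetical path

# 'deep_fisheye' is the assumed registered name of DeepFisheyePipeline.
pipeline_cls = find_pipeline_using_name('deep_fisheye')
pipeline = pipeline_cls(options)
pipeline.setup()  # loads the weights via check_and_load_pretrained()

# Run one dummy fisheye image through the pipeline.
with torch.no_grad():
    img_size = options.general.img_size
    pipeline.set_input({'fish': torch.zeros(1, 3, img_size, img_size)})
    pipeline.forward()
    results = pipeline.get_detached_current_results()
    print(results['joint'].shape)
```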
115 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-HCIL/DeepFisheyeNet/ecdf9265a0f7c5048ace0636c0ce26260a39f352/__init__.py -------------------------------------------------------------------------------- /check_dataset_stat.py: -------------------------------------------------------------------------------- 1 | from option.options import Options 2 | from preset import modify_options 3 | from dataset import * 4 | from run import find_run_using_name 5 | from run.empty_run import EmptyRun 6 | import torch 7 | 8 | def get_number_of_sum_and_dim(data, data_type): 9 | if data_type == 'joint': 10 | b, j, c = data.shape 11 | nb_data = b * j 12 | dim_to_reduce = [0, 1] 13 | else: 14 | b, c, h, w = data.shape 15 | nb_data = b * h * w 16 | dim_to_reduce = [0,2,3] 17 | 18 | return nb_data, dim_to_reduce 19 | 20 | def online_stat(loader, data_type): 21 | """Compute the mean and sd in an online fashion 22 | 23 | Var[x] = E[X^2] - E^2[X] 24 | """ 25 | cnt = 0 26 | fst_moment = None # mean 27 | snd_moment = None 28 | 29 | _min = float("inf") 30 | _max = -float("inf") 31 | 32 | for data_packet in loader: 33 | data = data_packet[data_type] 34 | nb_data, dim_to_reduce = get_number_of_sum_and_dim(data, data_type) 35 | sum_ = torch.sum(data, dim=dim_to_reduce) 36 | sum_of_square = torch.sum(data ** 2, dim=dim_to_reduce) 37 | 38 | if fst_moment is None: 39 | fst_moment = torch.zeros(sum_.shape) 40 | snd_moment = torch.zeros(sum_.shape) 41 | 42 | fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_data) 43 | snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_data) 44 | 45 | cnt += nb_data 46 | 47 | data_max = data.max() 48 | data_min = data.min() 49 | 50 | if data_max > _max: 51 | _max = data_max 52 | 53 | if data_min < _min: 54 | _min = data_min 55 | 56 | return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2), _max, _min 57 | 58 | def main(): 59 | 60 | dataset_type = 'joint' 61 | 62 | options = Options() 63 | options.initialize() 64 | modify_options(options) 65 | options.parse() 66 | options.general.dataset = 'synth' 67 | print(options.pretty_str()) 68 | run = EmptyRun(options) 69 | 70 | train_loader = run.get_train_loader() 71 | train_mean, train_std, max, min = online_stat(train_loader, dataset_type) 72 | print("train mean:", train_mean) 73 | print("train std:", train_std) 74 | print("train max:", max) 75 | print("train min:", min) 76 | 77 | test_loader = run.get_test_loader(shuffle = True) 78 | test_mean, test_std, max, min = online_stat(test_loader, dataset_type) 79 | print("test mean:", test_mean) 80 | print("test std:", test_std) 81 | print("test max:", max) 82 | print("test min:", min) 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import inspect 3 | from dataset.base_dataset import BaseDataset 4 | from torch.utils.data import DataLoader 5 | 6 | from util.package import find_class_using_name 7 | 8 | def find_dataset_by_name(dataset_name): 9 | dataset_cls = find_class_using_name('dataset', dataset_name, 'dataset') 10 | if inspect.isclass(dataset_cls) and issubclass(dataset_cls, BaseDataset): 11 | return dataset_cls 12 | 13 | raise Exception("{} is not correctely implemented as BaseRun class".format(dataset_name)) 
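# --- Hedged usage sketch (added; not part of the original file) ---------------
# How these factory helpers can be combined. The option values
# (opt.dataset == 'synth', batch size, worker count) are assumptions
# to adapt to your configuration:
#
#     train_set = create_train_dataset(opt)
#     loader = create_dataloader(train_set, batch_size=32, num_workers=4, shuffle=True)
#     for batch in loader:
#         fish_img, joint = batch['fish'], batch['joint']
# -------------------------------------------------------------------------------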
14 | 15 | def create_train_dataset(opt): 16 | dataset_cls = find_dataset_by_name(opt.dataset) 17 | return dataset_cls(opt, True) 18 | 19 | def create_test_dataset(opt): 20 | dataset_cls = find_dataset_by_name(opt.dataset) 21 | return dataset_cls(opt, False) 22 | 23 | def create_dataloader(dataset, batch_size, num_workers, shuffle): 24 | kwargs = {'num_workers':num_workers, 'pin_memory':True} 25 | dataloader = DataLoader(\ 26 | dataset,\ 27 | batch_size = batch_size, shuffle = shuffle, **kwargs) 28 | return dataloader 29 | -------------------------------------------------------------------------------- /dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import pathlib 3 | import random 4 | import os 5 | 6 | from torch.utils.data.dataset import Dataset 7 | from torchvision import transforms 8 | 9 | class BaseDataset(ABC, Dataset): 10 | def __init__(self, opt, is_train): 11 | self.is_train = is_train 12 | self.img_size = opt.img_size 13 | self.no_flip = opt.no_flip 14 | 15 | transform_list = [] 16 | transform_list += [transforms.Resize((self.img_size, self.img_size))] 17 | transform_list += [transforms.ToTensor()] 18 | 19 | self.transform = transforms.Compose(transform_list) 20 | self.color_transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.1, hue=0.15) 21 | self.hand_list = None 22 | self.set_hand_list() 23 | if not opt.max_data == float("inf"): 24 | self.hand_list = self.hand_list[:opt.max_data] 25 | 26 | @abstractmethod 27 | def set_hand_list(self): 28 | """ This method should initialize self.hand_list . """ 29 | pass 30 | 31 | def _load_filenames(self, root_name): 32 | this_file = pathlib.Path(os.path.abspath(__file__)) 33 | this_dir = this_file.parents[0] 34 | if self.is_train: 35 | filename_txt = this_dir.joinpath(root_name,'train.txt') 36 | else: 37 | filename_txt = this_dir.joinpath(root_name, 'test.txt') 38 | 39 | filenames = [] 40 | with open(str(filename_txt), 'r') as f: 41 | lines = f.readlines() 42 | for line in lines: 43 | names = line.strip().split(',') 44 | filenames.append(names) 45 | return filenames 46 | 47 | def __len__(self): 48 | return len(self.hand_list) 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | pass 53 | 54 | def toss_coin(self): 55 | return random.random() > 0.5 56 | -------------------------------------------------------------------------------- /dataset/data_model.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | from torchvision import transforms 5 | 6 | class HandDataModel: 7 | 8 | def __init__(self, fish_fn, fish_depth_img_fn=None, fn_3d_joints=None): 9 | 10 | self.fish_fn = fish_fn 11 | self.fish_depth_img_fn = fish_depth_img_fn 12 | self.fn_3d_joints = fn_3d_joints 13 | 14 | def load_data(self, img_size, is_flip): 15 | self.load_imgs() 16 | 17 | self.joints_3d = self._parse_joints(self.fn_3d_joints) 18 | 19 | if is_flip: 20 | self._flip() 21 | 22 | def unload_data(self): 23 | del self.fish_img 24 | del self.fish_depth_img 25 | del self.joints_3d 26 | 27 | def load_imgs(self): 28 | self.fish_img = self._load_img(self.fish_fn) 29 | self.fish_depth_img = self._load_img(self.fish_depth_img_fn) 30 | 31 | def _load_img(self, fn): 32 | if fn is None: 33 | return None 34 | return Image.open(fn) 35 | 36 | def _parse_joints(self, filename): 37 | if filename is None: 38 | return None 39 | with open(filename, 'r') as f: 
40 | line = f.readline() 41 | data = [float(d) for d in line.split(',')] 42 | joints_3d = np.array(data).reshape((-1,3)) 43 | joints_3d = torch.FloatTensor(joints_3d) 44 | return joints_3d 45 | 46 | def _flip(self): 47 | self.fish_img = self._flip_img(self.fish_img) 48 | self.fish_depth_img = self._flip_img(self.fish_depth_img) 49 | self.joints_3d[:,0] *= -1 50 | 51 | def _flip_img(self, img): 52 | if img is None: 53 | return None 54 | return transforms.functional.hflip(img) 55 | -------------------------------------------------------------------------------- /dataset/import_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | import random 4 | import os 5 | from importers import * 6 | 7 | description = """ Import the filenames of a dataset. 8 | This script creates two text files, 'train.txt' and 'test.txt'. 9 | The two files contain the filenames of the data. 10 | Dataloaders read these files to access and load the data. 11 | """ 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description = description) 15 | parser.add_argument("--data_type", dest="data_type", required=True, type=str, choices=['synth', 'real']) 16 | parser.add_argument("--data_dir", dest="data_dir", required=True, type=str, help="root directory of the dataset") 17 | parser.add_argument("--test_ratio", dest="test_ratio", default=0.2, type=float, help="ratio of test dataset size (only for synth data)") 18 | 19 | return parser.parse_args() 20 | 21 | def write_filenames(filename, filenames_to_write): 22 | with open(filename, 'w') as f: 23 | for fn in filenames_to_write: 24 | line = ','.join(fn) + "\n" 25 | f.write(line) 26 | 27 | def main(args): 28 | if args.data_type == 'synth': 29 | importer = SynthImporter() 30 | elif args.data_type == 'real': 31 | importer = RealImporter() 32 | else: 33 | raise ValueError('wrong data type') 34 | 35 | data_filenames_for_train, data_filenames_for_test = importer.get_file_names(args) 36 | 37 | this_file = pathlib.Path(os.path.abspath(__file__)) 38 | this_dir = this_file.parents[0] 39 | import_root = this_dir.joinpath(importer.get_import_root()) 40 | import_root.mkdir(exist_ok=True) 41 | 42 | train_filename = str(import_root.joinpath('train.txt')) 43 | write_filenames(train_filename, data_filenames_for_train) 44 | test_filename = str(import_root.joinpath('test.txt')) 45 | write_filenames(test_filename, data_filenames_for_test) 46 | 47 | print("number of train data: {}".format(len(data_filenames_for_train))) 48 | print("number of test data: {}".format(len(data_filenames_for_test))) 49 | 50 | if __name__ == "__main__": 51 | args = parse_args() 52 | main(args) 53 | -------------------------------------------------------------------------------- /dataset/importers.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | class BaseImporter(ABC): 3 | """ The main purpose of an Importer is to 4 | 5 | - split the test and train datasets 6 | - unify the dataset structure for different datasets. 7 | """ 8 | def __init__(self): 9 | pass 10 | 11 | @abstractmethod 12 | def get_file_names(self, args): 13 | """ This method should generate text files that contain the filenames of the data in a dataset. 14 | Also, it should split train and test data by different directories. 15 | 16 | This should return two lists 17 | - filenames_for_train 18 | - filenames_for_test 19 | 20 | All the filenames in the lists should be absolute paths.
21 | """ 22 | pass 23 | 24 | @abstractmethod 25 | def get_import_root(self): 26 | """ This method should return the root directory of where the data would be imported to. """ 27 | pass 28 | 29 | def get_num_train_for_even_batch(self, num_total, test_ratio, batch_size): 30 | assert type(batch_size) == int, "batch_size should be int" 31 | 32 | train_ratio = 1 - test_ratio 33 | rough_num_train = num_total * train_ratio 34 | 35 | num_iter = int(rough_num_train / batch_size) 36 | 37 | num_train = int(num_iter * batch_size) 38 | return num_train 39 | 40 | import pathlib 41 | import random 42 | 43 | class SynthImporter(BaseImporter): 44 | 45 | def get_file_names(self, args): 46 | data_root = pathlib.Path(args.data_dir) 47 | fish_root = data_root.joinpath('fish') 48 | heatmap_root = fish_root 49 | joints_root = data_root.joinpath('joints') 50 | fish_depth_root = data_root.joinpath('fish_depth') 51 | 52 | img_paths = fish_root.glob("**/*.png") 53 | 54 | total_filenames = [] 55 | for img_path in img_paths: 56 | fish_depth_img_path = self.get_depth_img_path(fish_depth_root, img_path) 57 | joints_path = self.get_joints_path(joints_root, img_path) 58 | 59 | fns = (str(img_path), str(fish_depth_img_path), str(joints_path)) 60 | total_filenames.append(fns) 61 | 62 | random.shuffle(total_filenames) 63 | 64 | chunk_size = 256 65 | num_train = self.get_num_train_for_even_batch(len(total_filenames), args.test_ratio, chunk_size) 66 | 67 | filenames_for_train = total_filenames[:num_train] 68 | filenames_for_test = total_filenames[num_train:] 69 | 70 | return filenames_for_train, filenames_for_test 71 | 72 | def get_import_root(self): 73 | return 'synth' 74 | 75 | def get_matching_path_in_subroot(self, subroot, img_path, fn_pattern): 76 | id = img_path.stem 77 | subdir = img_path.parts[-2] 78 | 79 | return subroot.joinpath(subdir, fn_pattern.format(id)) 80 | def get_joints_path(self, joints_root, img_path): 81 | return self.get_matching_path_in_subroot(joints_root, img_path, '{}_joint_pos.txt') 82 | 83 | def get_depth_img_path(self, depth_root, img_path): 84 | return self.get_matching_path_in_subroot(depth_root, img_path, '{}.png') 85 | 86 | class RealImporter(BaseImporter): 87 | def get_file_names(self, args): 88 | data_root = pathlib.Path(args.data_dir) 89 | 90 | img_paths = data_root.glob("**/*.png") 91 | total_filenames = [] 92 | for img_path in img_paths: 93 | joints_path = self.get_joints_path(data_root, img_path) 94 | fns = (str(img_path), str(joints_path)) 95 | total_filenames.append(fns) 96 | 97 | if args.test_ratio > 0: 98 | num_train = self.get_num_train_for_even_batch(len(total_filenames), args.test_ratio, 256) 99 | filenames_for_train = total_filenames[:num_train] 100 | filenames_for_test = total_filenames[num_train:] 101 | 102 | return filenames_for_train, filenames_for_test 103 | 104 | return total_filenames, [] 105 | 106 | def get_import_root(self): 107 | return 'real' 108 | 109 | def get_joints_path(self, data_root, img_path): 110 | id = img_path.stem 111 | user = img_path.parts[-3] 112 | 113 | return data_root.joinpath(user, 'joints', '{}_leap.txt'.format(id)) 114 | -------------------------------------------------------------------------------- /dataset/real_dataset.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import os 3 | import random 4 | 5 | from dataset.base_dataset import BaseDataset 6 | from dataset.data_model import HandDataModel 7 | 8 | class RealDataset(BaseDataset): 9 | 10 | def set_hand_list(self): 11 | self.hand_list 
= [] 12 | filenames = self._load_filenames('real') 13 | for fn in filenames: 14 | img_fn, fn_3d = fn 15 | hand = HandDataModel(img_fn, None, fn_3d_joints = fn_3d) 16 | self.hand_list.append(hand) 17 | 18 | def __getitem__(self, index): 19 | hand = self.hand_list[index] 20 | 21 | if self.is_train: 22 | is_flip = (not self.no_flip) and self.toss_coin() 23 | else: 24 | is_flip = False 25 | 26 | hand.load_data(self.img_size, is_flip = is_flip) 27 | img = hand.fish_img 28 | 29 | if self.is_train: 30 | img = self.color_transform(img) 31 | pix = self.transform(img) 32 | 33 | joint = hand.joints_3d 34 | 35 | item = {'fish': pix, \ 36 | 'joint': joint 37 | } 38 | 39 | # this prevents memory explosion. 40 | hand.unload_data() 41 | return item 42 | -------------------------------------------------------------------------------- /dataset/synth_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import scipy.ndimage 6 | 7 | from dataset.data_model import HandDataModel 8 | from dataset.base_dataset import BaseDataset 9 | import util.filter as filter 10 | from util.image import merge_channel 11 | from util.image import get_center_circle_mask 12 | 13 | class SynthDataset(BaseDataset): 14 | 15 | def __init__(self, opt, is_train): 16 | super().__init__(opt, is_train) 17 | img_shape = (self.img_size, self.img_size) 18 | self.fish_mask = get_center_circle_mask(img_shape, dataformats='CHW') 19 | 20 | self.min_depth_thrs = opt.min_depth_thrs 21 | 22 | self.blur = filter.GaussianFilter(channels = 3, kernel_size = 5, sigma = 3, peak_to_one = False) 23 | self.threshold_depth = nn.Threshold(self.min_depth_thrs, 0) 24 | 25 | def set_hand_list(self): 26 | self.hand_list = [] 27 | filenames = self._load_filenames('synth') 28 | for fn in filenames: 29 | fish_fn, fish_depth_img_fn, fn_joint = fn 30 | hand = HandDataModel(fish_fn, fish_depth_img_fn, fn_joint) 31 | self.hand_list.append(hand) 32 | 33 | def __getitem__(self, index): 34 | hand = self.hand_list[index] 35 | 36 | is_flip = (not self.no_flip) and self.toss_coin() 37 | 38 | hand.load_data(self.img_size, is_flip) 39 | fish_img = hand.fish_img 40 | fish_depth_img = hand.fish_depth_img 41 | 42 | if self.is_train: 43 | fish_img = self.color_transform(fish_img) 44 | 45 | fish_img = self.transform(fish_img) 46 | fish_img = self._blur_img(fish_img) 47 | 48 | fish_depth = self.transform(fish_depth_img) 49 | fish_depth = self._mask_fish_area(fish_depth) 50 | fish_depth = merge_channel(fish_depth, dataformats='CHW') 51 | fish_depth = self.threshold_depth(fish_depth) 52 | 53 | joint = hand.joints_3d 54 | 55 | fish_segment = self._to_binary(fish_depth, self.min_depth_thrs) 56 | 57 | item = {'fish': fish_img, \ 58 | 'fish_depth': fish_depth, \ 59 | 'joint': joint 60 | } 61 | 62 | # this prevents memory explosion. 
63 | hand.unload_data() 64 | return item 65 | 66 | def _blur_img(self, img): 67 | img = img.unsqueeze(0) 68 | img = self.blur(img) 69 | return img.squeeze(0) 70 | 71 | def _mask_fish_area(self, img): 72 | return self.fish_mask * img 73 | 74 | def _to_binary(self, x, threshold): 75 | zero = torch.zeros(x.shape) 76 | one = torch.ones(x.shape) 77 | x = torch.where(x > threshold, one, zero) 78 | return x 79 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-HCIL/DeepFisheyeNet/ecdf9265a0f7c5048ace0636c0ce26260a39f352/model/__init__.py -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | 6 | class BaseModel(ABC): 7 | """ Model has a network, optimizers and losses. 8 | """ 9 | def __init__(self, opt, gpu_ids): 10 | 11 | self.set_opt_and_mode(opt) 12 | 13 | self.gpu_ids = gpu_ids 14 | self.loss_names = [] 15 | self.optimizers = [] 16 | self.loss = None 17 | self.networks = None 18 | self.device = torch.device('cuda:{}'.format(gpu_ids[0])) if gpu_ids else torch.device('cpu') 19 | self.is_setup = False 20 | self.metric = 0 # used for learning rate policy 'plateau' 21 | 22 | self.schedulers = [] 23 | self.loss_names = [] 24 | self.visual_names = [] 25 | 26 | def set_opt_and_mode(self, opt): 27 | if opt: 28 | self.opt = opt 29 | if hasattr(opt, 'mode'): 30 | self.mode = opt.mode 31 | 32 | def check_and_load_pretrained(self): 33 | opt = self.opt 34 | 35 | if opt.pretrained: 36 | print("Load pretrained weights from {}".format(opt.pretrained)) 37 | checkpoint = torch.load(opt.pretrained) 38 | self.load_from_checkpoint(checkpoint, model_only = True) 39 | 40 | @abstractmethod 41 | def setup(self): 42 | """ Basic setup steps for a model. All the initialization should be done in here. 43 | Here are essential tasks for the setup. 44 | 1. make networks 45 | 2. send the networks to devices 46 | 3. make optimizer 47 | 4. make schedulers 48 | """ 49 | 50 | @abstractmethod 51 | def set_input(self, data): 52 | """ Hold data in the model as an input. 53 | """ 54 | pass 55 | 56 | @abstractmethod 57 | def forward(self): 58 | """ Run forward process with input. 59 | """ 60 | pass 61 | 62 | @abstractmethod 63 | def optimize_parameters(self): 64 | """ Where actual parameters are optimized. 65 | """ 66 | pass 67 | 68 | @abstractmethod 69 | def pack_as_checkpoint(self): 70 | """ Pack the model parameters as a checkpoint. 71 | """ 72 | pass 73 | 74 | @abstractmethod 75 | def load_from_checkpoint(self, checkpoint, model_only): 76 | """ Load the model from a checkpoint. 77 | """ 78 | pass 79 | 80 | @abstractmethod 81 | def get_current_results(self): 82 | """ Return current results. 
83 | """ 84 | pass 85 | 86 | def get_detached_current_results(self): 87 | results = self.get_current_results() 88 | detached_results = {} 89 | for key, value in results.items(): 90 | detached_results[key] = self._detach_value(value) 91 | return detached_results 92 | 93 | def _detach_value(self, value): 94 | if value is None: 95 | return value 96 | if isinstance(value, list): 97 | new_value = [] 98 | for v in value: 99 | _v = v.detach().cpu() 100 | new_value.append(_v) 101 | return new_value 102 | 103 | return value.detach().cpu() 104 | 105 | def optimize(self): 106 | self.check_setup() 107 | assert self.mode.is_train(), "BaseModel: The model should be in training mode to be optimized." 108 | self.forward() 109 | self.optimize_parameters() 110 | 111 | def check_setup(self): 112 | assert self.is_setup, "BaseModel: call 'setup()' before use the model" 113 | 114 | def set_requires_grad(self, networks, requires_grad): 115 | if not isinstance(networks, list): 116 | networks = [networks] 117 | 118 | for network in networks: 119 | if network is not None: 120 | for p in network.parameters(): 121 | p.requires_grad = requires_grad 122 | 123 | def get_current_losses(self): 124 | losses_ret = OrderedDict() 125 | for name in self.loss_names: 126 | if isinstance(name, str): 127 | loss_name = 'loss_' + name 128 | if hasattr(self, loss_name): 129 | losses_ret[name] = float(getattr(self, loss_name)) # float(...) works for both scalar tensor and float number 130 | return losses_ret 131 | 132 | def extract_weights(self, network): 133 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 134 | weights = network.module.cpu().state_dict() 135 | network.cuda(self.gpu_ids[0]) 136 | return weights 137 | else: 138 | return network.cpu().state_dict() 139 | 140 | def apply_weights(self, network, state_dict): 141 | 142 | # resolve data parallel crush. 
143 | if isinstance(network, torch.nn.DataParallel): 144 | network = network.module 145 | 146 | if hasattr(state_dict, '_metadata'): 147 | del state_dict._metadata 148 | 149 | network.load_state_dict(state_dict) 150 | 151 | def send_tensor_to_device(self, tensor): 152 | if (not tensor.is_cuda) and self.device.type.startswith('cuda'): 153 | tensor = tensor.to(self.device) 154 | return tensor 155 | 156 | @abstractmethod 157 | def get_grads(self): 158 | """ Return gradients for visualization 159 | """ 160 | pass 161 | -------------------------------------------------------------------------------- /model/hpe_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from model.base_model import BaseModel 5 | from network.hand_pose_net import create_hpe_net 6 | from network.projection import create_projection_net 7 | import util.image as image 8 | from util.joint import JointConverter 9 | 10 | 11 | class HPEModel(BaseModel): 12 | def setup(self, make_optimizer = True): 13 | opt = self.opt 14 | 15 | self.make_optimizer = make_optimizer 16 | self.joint_converter = JointConverter(opt.num_joints) 17 | 18 | self.loss_names = [ 'heatmap', \ 19 | 'joint'] 20 | 21 | self.net = create_hpe_net(opt, self.gpu_ids) 22 | 23 | self.criterionL2 = nn.MSELoss() 24 | 25 | self.run_interm = self.mode.is_train() or self.mode.is_test() 26 | if self.run_interm: 27 | self.set_requires_grad(self.net, True) 28 | 29 | self.projection_net = create_projection_net(opt, self.gpu_ids) 30 | 31 | self.heatmap_loss_weight = opt.heatmap_loss_weight 32 | self.heatmap_interm_loss_weight = opt.heatmap_interm_loss_weight 33 | self.joint_loss_weight = opt.joint_loss_weight 34 | self.joint_interm_loss_weight = opt.joint_interm_loss_weight 35 | 36 | if make_optimizer: 37 | self.optimizer = torch.optim.Adam(self.net.parameters(), lr=opt.lr) 38 | self.optimizers.append(self.optimizer) 39 | 40 | else: 41 | self.joint_converter.joint_scale = opt.joint_scale 42 | 43 | self.check_and_load_pretrained() 44 | 45 | self.is_setup = True 46 | 47 | def set_input(self, data): 48 | self.input = self.send_tensor_to_device(data['img']) 49 | self.input = image.normalize_img(self.input) 50 | 51 | self.joint_true = None 52 | self.heatmap_seed = None 53 | self.heatmap_true = None 54 | if self.run_interm: 55 | joint_true = self.send_tensor_to_device(data['joint']) 56 | normalized_joint = self.joint_converter.normalize(joint_true) 57 | self.heatmap_true = self.projection_net(normalized_joint) 58 | self.heatmap_true.requires_grad = False 59 | self.joint_true = self.joint_converter.convert_for_training(joint_true) 60 | 61 | def forward(self): 62 | assert self.is_setup 63 | 64 | result = self.net(self.input) 65 | self.joint_out = result['joint'] 66 | self.heatmap_out = None 67 | self.heatmap_interms = [] 68 | self.heatmap_out = result['heatmap'] 69 | self.reprojected_heatmap = None 70 | 71 | if self.run_interm: 72 | self.heatmap_interms = result['heatmap_interms'] 73 | self.joint_interms = result['joint_interms'] 74 | 75 | unflat_joint_out = self.joint_converter.convert_for_output(self.joint_out, no_unnormalize = True) 76 | self.reprojected_heatmap = self.projection_net(unflat_joint_out) 77 | 78 | def optimize_parameters(self): 79 | self.update_losses() 80 | total_loss = self.get_total_loss() 81 | self.optimizer.zero_grad() 82 | total_loss.backward() 83 | self.optimizer.step() 84 | 85 | def update_losses(self): 86 | self._update_heatmap_losses() 87 | self._update_joint_losses() 
88 | 89 | def _update_heatmap_losses(self): 90 | self.loss_heatmap = 0 91 | for i, interm in enumerate(self.heatmap_interms): 92 | interm_loss = self.criterionL2(self.heatmap_true, interm) * self.heatmap_interm_loss_weight 93 | loss_name = 'heatmap_interm{}'.format(i+1) 94 | setattr(self, 'loss_' + loss_name, interm_loss) 95 | self._add_loss_name_if_not_exists(loss_name) 96 | 97 | self.loss_heatmap = self.criterionL2(self.heatmap_true, self.heatmap_out) * self.heatmap_loss_weight 98 | 99 | def _update_joint_losses(self): 100 | self.loss_joint = 0 101 | for i, interm in enumerate(self.joint_interms): 102 | interm_loss = self.criterionL2(self.joint_true, interm) * self.joint_interm_loss_weight 103 | loss_name = 'joint_interm{}'.format(i+1) 104 | setattr(self, "loss_"+loss_name, interm_loss) 105 | self._add_loss_name_if_not_exists(loss_name) 106 | 107 | self.loss_joint = self.criterionL2(self.joint_true, self.joint_out) * self.joint_loss_weight 108 | 109 | def _add_loss_name_if_not_exists(self, loss_name): 110 | if loss_name not in self.loss_names: 111 | self.loss_names.append(loss_name) 112 | 113 | def get_total_loss(self): 114 | heatmap_loss = self._sum_heatmap_losses() 115 | joint_loss = self._sum_joint_losses() 116 | 117 | total_loss = heatmap_loss + joint_loss 118 | return total_loss 119 | 120 | def _sum_heatmap_losses(self): 121 | return self._sum_losses_by_part_of_name('heatmap') 122 | 123 | def _sum_joint_losses(self): 124 | return self._sum_losses_by_part_of_name('joint') 125 | 126 | def _sum_losses_by_part_of_name(self, part_of_name): 127 | losses = [] 128 | 129 | for loss_name in self.loss_names: 130 | if loss_name.startswith(part_of_name): 131 | losses.append(getattr(self, "loss_" + loss_name)) 132 | 133 | return sum(losses) 134 | 135 | def pack_as_checkpoint(self): 136 | checkpoint = {} 137 | checkpoint['net'] = self.extract_weights(self.net) 138 | if self.make_optimizer: 139 | checkpoint['optim'] = self.optimizer.state_dict() 140 | 141 | return checkpoint 142 | 143 | def load_from_checkpoint(self, checkpoint, model_only): 144 | self.apply_weights(self.net, checkpoint['net']) 145 | if not model_only: 146 | self.optimizer.load_state_dict(checkpoint['optim']) 147 | 148 | def get_current_results(self): 149 | img = image.unnormalize_as_img(self.input) 150 | joint_out = self.joint_converter.convert_for_output(self.joint_out) 151 | heatmap_interms = self.heatmap_interms 152 | heatmap_out = self.heatmap_out 153 | heatmap_true = self.heatmap_true 154 | reprojected_heatmap = self.reprojected_heatmap 155 | 156 | if self.run_interm: 157 | return {'img': img, \ 158 | 'joint': joint_out, \ 159 | 'heatmap': heatmap_out, \ 160 | 'heatmap_true': heatmap_true, \ 161 | 'heatmap_reprojected': reprojected_heatmap, \ 162 | 'heatmap_interms': heatmap_interms} 163 | else: 164 | return {'img': img, \ 165 | 'joint': joint_out, \ 166 | 'heatmap': heatmap_out} 167 | 168 | def get_net_parameters(self): 169 | return self.net.parameters() 170 | 171 | def get_name_parameters(self): 172 | return self.net.named_parameters() 173 | 174 | def get_grads(self): 175 | grads = {} 176 | for tag, param in self.net.named_parameters(): 177 | grads[tag] = param.grad.data.detach().cpu() 178 | 179 | return grads 180 | -------------------------------------------------------------------------------- /model/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from util.package import find_class_using_name 3 | from .base_pipeline import BasePipeline 4 | 5 | 
def find_pipeline_using_name(pipeline_name): 6 | if not pipeline_name: 7 | raise Exception("'pipeline_name' is empty") 8 | pipeline_cls = find_class_using_name('model.pipeline', pipeline_name, 'pipeline') 9 | if inspect.isclass(pipeline_cls) and issubclass(pipeline_cls, BasePipeline): 10 | return pipeline_cls 11 | 12 | raise Exception("{} is not correctely implemented as PipelineModel class".format(pipeline_name)) 13 | -------------------------------------------------------------------------------- /model/pipeline/base_pipeline.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections import OrderedDict 3 | import torch 4 | 5 | from model.base_model import BaseModel 6 | 7 | class BasePipeline(BaseModel): 8 | """ The Pipeline is a special form of the Model. 9 | It is a collection and a wrapper of models. 10 | The roles of a pipeline are 11 | - connect models 12 | - init models with a correct option 13 | - (additional) optimize models collectively 14 | - (additional) manipulate models 15 | 16 | SOME IMPLEMENTATION TIPS : 17 | - A pipeline should not have its own option. Then it becomes same with a model. 18 | - A pipeline can have own optimizers and losses. 19 | """ 20 | 21 | def __init__(self, options): 22 | """ 23 | Override constructor. 24 | """ 25 | self.options = options 26 | gpu_ids = options.general.gpu_ids 27 | self.model_names = [] 28 | 29 | super().__init__(None, gpu_ids) 30 | 31 | def check_and_load_pretrained(self): 32 | """ Override its super class. 33 | A pretrained weights for a pipeline consists multiple weights for multiple modules. 34 | """ 35 | 36 | opt = self.options.general 37 | 38 | if opt.pipeline_pretrained: 39 | print("Load pretrained weights from {}".format(opt.pipeline_pretrained)) 40 | checkpoint = torch.load(opt.pipeline_pretrained) 41 | self.load_from_checkpoint(checkpoint, model_only = True) 42 | 43 | def _get_model_by_name(self, name): 44 | model = getattr(self, "{}_model".format(name)) 45 | assert isinstance(model, BaseModel), "No module with '{}'".format(name) 46 | return model 47 | 48 | def _has_model_with_name(self, name): 49 | return name in self.model_names 50 | 51 | def pack_as_checkpoint(self): 52 | """ Implement abstract method. 53 | Basically, a pipeline just packs all the modules as checkpoints and collects them. 54 | However, it is not a strcit rule and this method can be overridden. 55 | """ 56 | collected = {} 57 | for model_name in self.model_names: 58 | model = self._get_model_by_name(model_name) 59 | checkpoint = model.pack_as_checkpoint() 60 | collected[model_name] = checkpoint 61 | 62 | return collected 63 | 64 | def load_from_checkpoint(self, checkpoint, model_only): 65 | """ Implement abstract method. 66 | Reverse process of 'pack_as_checkpoint'. 67 | """ 68 | for model_name, cp in checkpoint.items(): 69 | if not self._has_model_with_name(model_name): 70 | continue 71 | model = self._get_model_by_name(model_name) 72 | model.load_from_checkpoint(cp, model_only = model_only) 73 | 74 | def get_current_losses(self): 75 | """ Override its super class. 76 | Usually, a pipeline doesn't have its own losses, so just return losses of each models with a prefix. 77 | However, it is not a strcit rule and this method can be override. 
78 | """ 79 | losses_ret = OrderedDict() 80 | for model_name in self.model_names: 81 | model = self._get_model_by_name(model_name) 82 | losses = model.get_current_losses() 83 | for k, v in losses.items(): 84 | new_k = "{}_{}".format(model_name, k) 85 | losses_ret[new_k] = v 86 | return losses_ret 87 | 88 | def get_current_results_wo_detach(self): 89 | raise NotImplementedError("Don't call this for pipeline.") 90 | 91 | def get_grads(self): 92 | combined_grads = {} 93 | for model_name in self.model_names: 94 | model = getattr(self, "{}_model".format(model_name)) 95 | grads = model.get_grads() 96 | for tag, val in grads.items(): 97 | tag = "{}_{}".format(model_name, tag) 98 | combined_grads[tag] = val 99 | 100 | return combined_grads 101 | -------------------------------------------------------------------------------- /model/pipeline/deep_fisheye_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from model.pipeline.base_pipeline import BasePipeline 5 | from model.pix2depth_model import Pix2DepthModel 6 | from model.hpe_model import HPEModel 7 | from util.projector import FisheyeProjector 8 | 9 | class DeepFisheyePipeline(BasePipeline): 10 | def setup(self): 11 | 12 | self.model_names = ['pix2depth', 'hpe'] 13 | self.pix2depth_model = Pix2DepthModel(self.options.pix2depth, self.gpu_ids) 14 | self.hpe_model = HPEModel(self.options.hpe, self.gpu_ids) 15 | 16 | self.img_size = self.options.general.img_size 17 | projector = FisheyeProjector(self.img_size) 18 | 19 | self.hpe_model.setup(make_optimizer = False) 20 | self.pix2depth_model.setup(make_optimizer = False) 21 | 22 | # removes very small depths to create a correct hand segmentation mask. 23 | self.depth_threshold = nn.Threshold(threshold = self.options.general.min_depth_thrs, value = 0, inplace=True) 24 | 25 | self.is_train = self._is_all_model_train_mode() 26 | if self.is_train: 27 | pix2depth_opt = self.options.pix2depth 28 | hpe_opt = self.options.hpe 29 | 30 | self.pix2depth_optimizer = torch.optim.Adam(self.pix2depth_model.get_net_parameters(), lr=pix2depth_opt.lr, betas = (pix2depth_opt.beta1, 0.999)) 31 | self.hpe_optimizer = torch.optim.Adam(self._get_hpe_parameters(), lr=hpe_opt.lr) 32 | 33 | self.mode = self.pix2depth_model.mode 34 | self.is_setup = True 35 | 36 | self.check_and_load_pretrained() 37 | 38 | def _is_all_model_train_mode(self): 39 | p2d_train = self.pix2depth_model.mode.is_train() 40 | hpe_train = self.hpe_model.mode.is_train() 41 | return p2d_train 42 | #return p2d_train and hpe_train 43 | 44 | def _get_hpe_parameters(self): 45 | return self.hpe_model.get_net_parameters() 46 | 47 | def set_input(self, data): 48 | assert self.is_setup 49 | 50 | self.pix = self.send_tensor_to_device(data['fish']) 51 | 52 | self.segment = None 53 | self.fish_depth = None 54 | self.joint = None 55 | 56 | if 'joint' in data: 57 | self.joint = self.send_tensor_to_device(data['joint']) 58 | 59 | if 'fish_depth' in data: 60 | self.fish_depth = self.send_tensor_to_device(data['fish_depth']) 61 | 62 | def forward(self): 63 | assert self.is_setup 64 | 65 | pix2depth_results = self._forward_pix2depth(self.pix, self.fish_depth) 66 | self.fake_fish_depth = pix2depth_results['fake_depth'] 67 | self.fake_fish_depth = self.depth_threshold(self.fake_fish_depth) 68 | self.hand_pix = self._segment_hand(self.pix, self.fake_fish_depth) 69 | 70 | self._forward_hpe(self.hand_pix, self.fake_fish_depth, self.joint) 71 | 72 | def _forward_pix2depth(self, pix, 
depth): 73 | pix2depth_input = {"pix" : pix, 'depth': depth, 'joint': self.joint} 74 | self.pix2depth_model.set_input(pix2depth_input) 75 | self.pix2depth_model.forward() 76 | return self.pix2depth_model.get_current_results() 77 | 78 | def _segment_hand(self, pix, depth): 79 | segment = self._depth_to_segment(depth) 80 | pix = pix * segment 81 | return pix 82 | 83 | def _depth_to_segment(self, depth): 84 | """ Make segment from depth with maintaining autograds """ 85 | segment = depth.clone() 86 | segment[segment > 0] = 1 87 | return segment 88 | 89 | def _forward_hpe(self, pix, depth, joint): 90 | img = self._combine_pix_and_depth(pix, depth) 91 | if joint is None: 92 | hpe_input = {'img': img} 93 | else: 94 | hpe_input = {'joint' : joint, 'img' : img} 95 | self.hpe_model.set_input(hpe_input) 96 | self.hpe_model.forward() 97 | 98 | def _combine_pix_and_depth(self, pix, depth): 99 | return torch.cat((pix, depth), 1) 100 | 101 | def optimize_parameters(self): 102 | self.pix2depth_optimizer.zero_grad() 103 | self.hpe_optimizer.zero_grad() 104 | 105 | self.pix2depth_model.update_loss() 106 | p2d_loss = self.pix2depth_model.get_total_loss() 107 | 108 | self.hpe_model.update_losses() 109 | hpe_loss = self.hpe_model.get_total_loss() 110 | 111 | loss = hpe_loss + p2d_loss 112 | 113 | loss.backward() 114 | 115 | self.pix2depth_optimizer.step() 116 | self.hpe_optimizer.step() 117 | 118 | def _should_optimize_gan_discreminators(self): 119 | return (self.fish_depth is not None) 120 | 121 | 122 | def get_current_results(self): 123 | results = {} 124 | 125 | results['input'] = self.pix 126 | results['hand_pix'] = self.hand_pix 127 | results['fake_fish_depth'] = self.fake_fish_depth 128 | 129 | hpe_results = self.hpe_model.get_current_results() 130 | results['heatmap'] = hpe_results['heatmap'] 131 | results['joint'] = hpe_results['joint'] 132 | if 'heatmap_true' in hpe_results: 133 | results['heatmap_true'] = hpe_results['heatmap_true'] 134 | 135 | if 'heatmap_reprojected' in hpe_results: 136 | results['heatmap_reprojected'] = hpe_results['heatmap_reprojected'] 137 | 138 | return results 139 | -------------------------------------------------------------------------------- /model/pix2depth_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from model.base_model import BaseModel 5 | import network.hand_depth_net as depth_net 6 | from network.projection import create_projection_net 7 | import util.image as image 8 | from util.joint import JointConverter 9 | 10 | class Pix2DepthModel(BaseModel): 11 | def setup(self, make_optimizer = True): 12 | opt = self.opt 13 | self.make_optimizer = make_optimizer 14 | 15 | self.loss_names = [] 16 | 17 | self.network = depth_net.create_hdg_net(opt, self.gpu_ids) 18 | self.joint_converter = JointConverter(opt.num_joints) 19 | 20 | self.projection_net = create_projection_net(opt, self.gpu_ids) 21 | 22 | if self.mode.is_train(): 23 | 24 | self.criterionL2 = nn.MSELoss() 25 | self.heatmap_loss_weight = opt.heatmap_loss_weight 26 | self.heatmap_interm_loss_weight = opt.heatmap_interm_loss_weight 27 | self.joint_loss_weight = opt.joint_loss_weight 28 | self.joint_interm_loss_weight = opt.joint_interm_loss_weight 29 | self.depth_loss_weight = opt.depth_loss_weight 30 | 31 | if make_optimizer: 32 | # initialize optimizers 33 | self.optimizer = torch.optim.Adam(self.network.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 34 | self.optimizers.append(self.optimizer) 35 | 36 | else: 37 | 
self.network.eval() 38 | print("Pix2Depth network set to eval mode.") 39 | 40 | self.check_and_load_pretrained() 41 | 42 | self.run_interm = not self.mode.is_eval() 43 | 44 | self.is_setup = True 45 | 46 | def set_input(self, data): 47 | 48 | pix = data['pix'] 49 | self.real_pix = pix.to(self.device) 50 | self.real_pix = image.normalize_img(self.real_pix) 51 | 52 | if 'depth' in data and (data['depth'] is not None): 53 | self.real_depth = data['depth'].to(self.device) 54 | self.real_depth = image.normalize_img(self.real_depth) 55 | else: 56 | self.real_depth = None 57 | 58 | self.joint_true = None 59 | self.heatmap_true = None 60 | 61 | if self.run_interm: 62 | joint_true = self.send_tensor_to_device(data['joint']) 63 | 64 | normalized_joint = self.joint_converter.normalize(joint_true) 65 | self.heatmap_true = self.projection_net(normalized_joint) 66 | self.heatmap_true.requires_grad = False 67 | 68 | self.joint_true = self.joint_converter.convert_for_training(joint_true) 69 | 70 | def forward(self): 71 | assert self.is_setup 72 | result = self.network(self.real_pix) 73 | self.fake_depth = result['fake'] 74 | if self.run_interm: 75 | self.fake_interms = result['interms'] 76 | self.joint_interms = result['joint_interms'] 77 | self.heatmap_interms = result['heatmap_interms'] 78 | 79 | def optimize_parameters(self): 80 | self.optimizer.zero_grad() 81 | 82 | self.update_loss() 83 | self.loss_total.backward() 84 | 85 | self.optimizer.step() 86 | 87 | def update_loss(self): 88 | self.update_depth_loss() 89 | self.update_depth_interm_loss() 90 | self.update_joint_losses() 91 | self.update_heatmap_losses() 92 | self.loss_total = self.get_total_loss() 93 | 94 | def get_total_loss(self): 95 | total_G_loss = self.loss_depth + self.loss_depth_interm + self.loss_joint + self.loss_heatmap 96 | return total_G_loss 97 | 98 | def update_depth_loss(self): 99 | self.loss_depth = 0 100 | if self.real_depth is not None: 101 | self.loss_depth = self.criterionL2(self.fake_depth, self.real_depth) * self.depth_loss_weight 102 | self.add_loss_name('depth') 103 | 104 | def update_depth_interm_loss(self): 105 | self.loss_depth_interm = 0 106 | 107 | if self.real_depth is None: 108 | return 109 | 110 | for i, interm in enumerate(self.fake_interms): 111 | loss = self.criterionL2(interm, self.real_depth) * self.depth_loss_weight * 0.5 112 | loss_name = "depth_interm_{}".format(i) 113 | setattr(self, "loss_"+loss_name, loss) 114 | self.add_loss_name(loss_name) 115 | 116 | self.loss_depth_interm += loss 117 | 118 | def update_joint_losses(self): 119 | self.loss_joint = 0 120 | 121 | for i, interm in enumerate(self.joint_interms): 122 | interm_loss = self.criterionL2(interm, self.joint_true) * self.joint_loss_weight 123 | loss_name = "joint_interm_{}".format(i) 124 | setattr(self, 'loss_'+loss_name, interm_loss) 125 | self.add_loss_name(loss_name) 126 | 127 | self.loss_joint += interm_loss 128 | 129 | def update_heatmap_losses(self): 130 | self.loss_heatmap = 0 131 | 132 | for i, interm in enumerate(self.heatmap_interms): 133 | interm_loss = self.criterionL2(interm, self.heatmap_true) * self.heatmap_loss_weight 134 | loss_name = "heatmap_interm{}".format(i) 135 | setattr(self, 'loss_'+loss_name, interm_loss) 136 | self.add_loss_name(loss_name) 137 | 138 | self.loss_heatmap += interm_loss 139 | 140 | def add_loss_name(self, loss_name): 141 | if loss_name not in self.loss_names: 142 | self.loss_names.append(loss_name) 143 | 144 | def pack_as_checkpoint(self): 145 | checkpoint = {} 146 | checkpoint['network'] = 
self.extract_weights(self.network) 147 | 148 | if self.make_optimizer: 149 | checkpoint['optim'] = self.optimizer.state_dict() 150 | 151 | return checkpoint 152 | 153 | def add_loss_name(self, loss_name): 154 | if loss_name not in self.loss_names: 155 | self.loss_names.append(loss_name) 156 | 157 | def load_from_checkpoint(self, checkpoint, model_only): 158 | self.apply_weights(self.network, checkpoint['network']) 159 | 160 | if not model_only: 161 | self.optimizer.load_state_dict(checkpoint['optim']) 162 | 163 | def get_fake(self): 164 | return self.fake_depth 165 | 166 | def get_current_results(self): 167 | pix = image.unnormalize_as_img(self.real_pix) 168 | fake_depth = image.unnormalize_as_img(self.fake_depth) 169 | results = {'pix': pix, 'fake_depth': fake_depth} 170 | 171 | if not self.mode.is_eval(): 172 | depth = image.unnormalize_as_img(self.real_depth) 173 | results['depth'] = depth 174 | results['heatmap_interms'] = self.heatmap_interms 175 | 176 | return results 177 | 178 | def get_grads(self): 179 | grads = {} 180 | for tag, param in self.network.named_parameters(): 181 | tag = "{}".format(tag) 182 | grads[tag] = param.grad.data.detach().cpu() 183 | return grads 184 | 185 | def get_net_parameters(self): 186 | return self.network.parameters() 187 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-HCIL/DeepFisheyeNet/ecdf9265a0f7c5048ace0636c0ce26260a39f352/network/__init__.py -------------------------------------------------------------------------------- /network/base_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import network.norm as norm 3 | 4 | class BaseNet(nn.Module): 5 | def __init__(self, name, input_nc, output_nc, mode, norm_type): 6 | 7 | super().__init__() 8 | 9 | self.name = name 10 | self.input_nc = input_nc 11 | self.output_nc = output_nc 12 | self.mode = mode 13 | 14 | self.norm_layer = norm.get_norm_layer(norm_type) 15 | 16 | def set_mode(self): 17 | if self.mode.is_train(): 18 | print("{}: in train mode".format(self.name)) 19 | self.train() 20 | else: 21 | print("{}: in test mode".format(self.name)) 22 | self.eval() 23 | -------------------------------------------------------------------------------- /network/basic_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | def down_conv1x1(inplanes, planes, norm_layer, stride = 1, acti_layer = nn.ReLU): 6 | return ConvBlock(nn.Conv2d, inplanes, planes, 1, stride, 0, norm_layer = norm_layer, acti_layer = acti_layer) 7 | 8 | def down_conv3x3(inplanes, planes, norm_layer, stride = 1, acti_layer = nn.ReLU): 9 | return ConvBlock(nn.Conv2d, inplanes, planes, 3, stride, 1, norm_layer = norm_layer, acti_layer = acti_layer) 10 | 11 | def down_conv7x7(inplanes, planes, norm_layer, stride = 1, acti_layer = nn.ReLU): 12 | return ConvBlock(nn.Conv2d, inplanes, planes, 7, stride, 3, norm_layer = norm_layer, acti_layer = acti_layer) 13 | 14 | def up_conv4x4(inplanes, planes, norm_layer, stride = 1, acti_layer = nn.ReLU): 15 | return ConvBlock(nn.ConvTranspose2d, inplanes, planes, 4, stride, 1, norm_layer = norm_layer, acti_layer = acti_layer) 16 | 17 | def up_conv6x6(inplanes, planes, norm_layer, stride = 1, acti_layer = nn.ReLU): 18 | return ConvBlock(nn.ConvTranspose2d, 
inplanes, planes, 6, stride, 2, norm_layer = norm_layer, acti_layer = acti_layer) 19 | 20 | class ConvBlock(nn.Module): 21 | def __init__(self, conv_class, input_nc, output_nc, kernel_size, stride, padding, bias = False, norm_layer = None, acti_layer = None): 22 | super().__init__() 23 | 24 | self.input_nc = input_nc 25 | self.output_nc = output_nc 26 | 27 | layers = [] 28 | conv = conv_class(input_nc, output_nc, \ 29 | kernel_size = kernel_size, stride = stride, padding = padding, bias = bias) 30 | layers.append(conv) 31 | 32 | if norm_layer is not None: 33 | layers.append(norm_layer(output_nc)) 34 | 35 | if acti_layer is not None: 36 | if acti_layer == nn.ReLU or acti_layer == nn.LeakyReLU: 37 | layers.append(acti_layer(inplace=True)) 38 | else: 39 | layers.append(acti_layer()) 40 | 41 | self.model = nn.Sequential(*layers) 42 | 43 | def forward(self, x): 44 | return self.model(x) 45 | 46 | class O2OBlock(nn.Module): 47 | # rather than using fully connected, use one-to-one convolution 48 | def __init__(self, in_channel, out_channel, global_pool, acti_layer = None): 49 | super().__init__() 50 | 51 | self.conv = ConvBlock(nn.Conv2d, in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=True, acti_layer = acti_layer) 52 | self.global_pool = global_pool 53 | 54 | def forward(self, x): 55 | # Assuming x as (N, C, H, W) 56 | 57 | x = self.conv(x) 58 | # avg pooling 59 | if self.global_pool: 60 | x = x.mean(dim=(2,3)) 61 | return x 62 | 63 | class FCBlock(nn.Module): 64 | def __init__(self, in_channel, out_channel, acti_layer = None, bias = True): 65 | super().__init__() 66 | self.fc = nn.Linear(in_channel, out_channel, bias = bias) 67 | 68 | if acti_layer: 69 | if acti_layer == nn.ReLU or acti_layer == nn.LeakyReLU: 70 | self.activation_layer = acti_layer(inplace = True) 71 | else: 72 | self.activation_layer = acti_layer() 73 | 74 | def forward(self, x): 75 | x = self.fc(x) 76 | 77 | if self.activation_layer: 78 | x = self.activation_layer(x) 79 | 80 | return x 81 | 82 | class Sigmoid6(nn.Module): 83 | value_range = 6 84 | def __init__(self): 85 | super().__init__() 86 | self.acti_layer = nn.Sigmoid() 87 | 88 | def forward(self, x): 89 | return self.acti_layer(x) * self.value_range 90 | -------------------------------------------------------------------------------- /network/hand_depth_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import network.resnet as resnet 4 | import network.helper as helper 5 | import network.basic_block as bb 6 | import network.norm as norm 7 | import network.unprojection as unproj 8 | 9 | from network.hand_pose_net import HandDataEncoder 10 | from network.base_net import BaseNet 11 | 12 | def create_hdg_net(opt, gpu_ids): 13 | unproject_net_gen = unproj.create_unprojection_net_generator(opt) 14 | net = HandDepthGenerateNet(opt, unproject_net_gen) 15 | return helper.init_net(net, opt.init_type, opt.init_gain, gpu_ids) 16 | 17 | class HandDepthGenerateNet(BaseNet): 18 | def __init__(self, opt, unproject_net_gen): 19 | 20 | input_nc = opt.input_nc 21 | output_nc = 0 # multiple output 22 | mode = opt.mode 23 | norm_type = opt.norm 24 | 25 | super().__init__("HandDepthGenerateNet", input_nc, output_nc, mode, norm_type) 26 | 27 | self.input_nc = opt.input_nc 28 | self.output_nc = opt.output_nc 29 | self.img_size = opt.img_size 30 | self.img_shape = (opt.img_size, opt.img_size) 31 | self.n_joints = opt.num_joints 32 | self.norm_type = opt.norm 33 | self.base_nc = opt.base_nc 34 | 
self.net_type = opt.net_type 35 | self.unproject_net_gen = unproject_net_gen 36 | 37 | self.norm_layer = norm.get_norm_layer(norm_type) 38 | self._make_network() 39 | self.interm_upscale = nn.Upsample(size = self.img_shape, mode = "bilinear") 40 | 41 | self.set_mode() 42 | 43 | self.run_interm = not self.mode.is_eval() 44 | 45 | if opt.train_only_encoder: 46 | self.backbone.requires_grad = False 47 | self.decoder.requires_grad = False 48 | 49 | def _make_network(self): 50 | encoder_base_nc = 64 51 | n_blocks_list = [3, 3, 3] 52 | n_stride_list = [2, 2, 2] 53 | self.encoder = HandDataEncoder(self.input_nc, self.n_joints, encoder_base_nc, n_blocks_list, n_stride_list, self.img_size, self.mode, self.norm_type, self.unproject_net_gen) 54 | self.backbone = self._make_resnet_backbone(self.encoder) 55 | 56 | decoder_base_nc = 512 57 | decoder_n_blocks = [3, 3, 3] 58 | self.decoder = DepthDecoder(self.backbone.output_nc, self.output_nc, decoder_base_nc, decoder_n_blocks, self.img_shape, self.norm_layer) 59 | 60 | def _make_resnet_backbone(self, encoder): 61 | if self.net_type == "resnet_6blocks": 62 | planes = 256 63 | n_blocks = 6 64 | elif self.net_type == "resnet_3blocks": 65 | planes = 256 66 | n_blocks = 3 67 | else: 68 | raise NotImplementedError("{} is not implemented".format(self.netG)) 69 | 70 | input_nc = encoder.output_nc 71 | return resnet.SimpleResnetLayer(input_nc, n_blocks, self.norm_layer) 72 | 73 | def forward(self, x): 74 | result = self.encoder(x) 75 | x = result['output'] 76 | x = self.backbone(x) 77 | 78 | decoder_result = self.decoder(x) 79 | 80 | result['fake'] = decoder_result['output'] 81 | result['interms'] = decoder_result['depth_interms'] 82 | 83 | return result 84 | 85 | class DepthDecoder(nn.Module): 86 | def __init__(self, input_nc, output_nc, base_nc, n_blocks, img_shape, norm_layer): 87 | super().__init__() 88 | self.input_nc = input_nc 89 | self.output_nc = output_nc 90 | self.base_nc = base_nc 91 | self.img_shape = img_shape 92 | self.num_layers = len(n_blocks) 93 | self.n_blocks = n_blocks 94 | self.norm_layer = norm_layer 95 | self._make_layers() 96 | self._make_interm_convs() 97 | self.end_conv = self._make_end_conv() 98 | 99 | def _make_layers(self): 100 | input_nc = self.input_nc 101 | num_features = self.base_nc 102 | for i, nb in enumerate(self.n_blocks): 103 | layer = resnet.UpsampleRensetLayer(input_nc, num_features, nb, stride = 2, norm_layer = self.norm_layer) 104 | self._set_layer(i, layer) 105 | 106 | num_features = int(num_features/2) 107 | input_nc = layer.output_nc 108 | 109 | def _make_interm_convs(self): 110 | for i in range(1, self.num_layers): 111 | layer = self._get_layer(i) 112 | input_nc = layer.output_nc 113 | interm_conv = DepthIntermConv(input_nc, self.output_nc, self.img_shape, self.norm_layer) 114 | self._set_interm_conv(i, interm_conv) 115 | 116 | def _set_layer(self, id, layer): 117 | setattr(self, "layer_{}".format(id), layer) 118 | 119 | def _get_layer(self, id): 120 | return getattr(self, "layer_{}".format(id)) 121 | 122 | def _get_last_layer(self): 123 | last_id = self.num_layers - 1 124 | return self._get_layer(last_id) 125 | 126 | def _set_interm_conv(self, id, conv): 127 | setattr(self, "interm_conv_{}".format(id), conv) 128 | 129 | def _get_interm_conv(self, id): 130 | return getattr(self, "interm_conv_{}".format(id)) 131 | 132 | def _make_end_conv(self): 133 | last_layer = self._get_last_layer() 134 | conv = bb.ConvBlock(nn.Conv2d, last_layer.output_nc, self.output_nc, kernel_size = 1, norm_layer = None, stride = 1, padding 
= 0, acti_layer = nn.Tanh, bias = True) 135 | upsample = nn.Upsample(size = self.img_shape) 136 | return nn.Sequential(conv, upsample) 137 | 138 | def forward(self, x): 139 | interms = [] 140 | 141 | first_layer = self._get_layer(0) 142 | x, upsampled = first_layer(x) 143 | 144 | for i in range(1, self.num_layers): 145 | layer = self._get_layer(i) 146 | x, upsampled = layer(x) 147 | interm_conv = self._get_interm_conv(i) 148 | interm = interm_conv(upsampled) 149 | interms.append(interm) 150 | 151 | result = {} 152 | result['output'] = self.end_conv(x) 153 | result['depth_interms'] = interms 154 | return result 155 | 156 | def _forward_layers(self, i, x): 157 | layer = getattr(self, "layer_{}".format(i)) 158 | return layer(x) 159 | 160 | class DepthIntermConv(nn.Module): 161 | def __init__(self, input_nc, output_nc, img_shape, norm_layer): 162 | super().__init__() 163 | self.upconv = bb.up_conv4x4(input_nc, output_nc, norm_layer) 164 | self.scale_conv = bb.ConvBlock(nn.Conv2d, output_nc, output_nc, kernel_size = 1, norm_layer = None, stride = 1, padding = 0, acti_layer = nn.Tanh, bias = True) 165 | self.upscale = nn.Upsample(size = img_shape, mode = "bilinear") 166 | 167 | def forward(self, x): 168 | x = self.upconv(x) 169 | x = self.scale_conv(x) 170 | x = self.upscale(x) 171 | return x 172 | -------------------------------------------------------------------------------- /network/hand_module_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from functools import partial 3 | 4 | import network.basic_block as bb 5 | import network.resnet as resnet 6 | """ Submodules for hand pose estimation networks. """ 7 | class HeatmapConv(nn.Module): 8 | """ Decode heatmaps from encoded features. 
""" 9 | def __init__(self, input_nc, num_joints, n_deconv, deconv_nc, img_shape, norm_layer): 10 | super().__init__() 11 | self.input_nc = input_nc 12 | self.output_nc = num_joints 13 | self.norm_layer = norm_layer 14 | self.deconv_nc = deconv_nc 15 | self.n_deconv = n_deconv 16 | self.img_shape = img_shape 17 | 18 | self.net = self._make_network() 19 | self.upsample = nn.Upsample(size = img_shape, mode = 'bilinear') 20 | 21 | def _make_network(self): 22 | norm_layer = self.norm_layer 23 | 24 | conv1 = bb.down_conv3x3(self.input_nc, int(self.input_nc / 4), norm_layer) 25 | 26 | deconv1 = bb.up_conv4x4(int(self.input_nc / 4), self.deconv_nc, norm_layer, stride = 2) 27 | blocks = [conv1, deconv1] 28 | deconv_input_nc = deconv1.output_nc 29 | for _ in range(self.n_deconv - 1): 30 | deconv = bb.up_conv4x4(self.deconv_nc, self.deconv_nc, norm_layer, stride = 2) 31 | blocks.append(deconv) 32 | 33 | final_conv = bb.ConvBlock(nn.Conv2d, self.deconv_nc, self.output_nc, 34 | kernel_size = 1, stride = 1, padding = 0, bias = True, 35 | norm_layer = None, acti_layer = nn.Sigmoid) 36 | 37 | blocks.append(final_conv) 38 | return nn.Sequential(*blocks) 39 | 40 | def forward(self, x): 41 | x = self.net(x) 42 | return self.upsample(x) 43 | 44 | class IntermHeatmapConv(nn.Module): 45 | def __init__(self, input_nc, num_joints, img_shape, norm_layer): 46 | super().__init__() 47 | 48 | self.input_nc = input_nc 49 | self.output_nc = num_joints 50 | 51 | deconv = bb.up_conv4x4(input_nc, num_joints, norm_layer, stride = 2) 52 | conv1 = bb.ConvBlock(nn.Conv2d, num_joints, num_joints, kernel_size = 1, 53 | stride = 1, padding = 0, bias = True, 54 | norm_layer = None, acti_layer = nn.Sigmoid) 55 | 56 | upsample = nn.Upsample(size = img_shape, mode = 'bilinear') 57 | blocks = [deconv, conv1, upsample] 58 | self.net = nn.Sequential(*blocks) 59 | 60 | def forward(self, x): 61 | return self.net(x) 62 | 63 | class DistConv(nn.Module): 64 | """ Decode distance vectors from encoded features. 
""" 65 | def __init__(self, input_nc, num_joints, norm_layer): 66 | super().__init__() 67 | self.input_nc = input_nc 68 | self.output_nc = num_joints 69 | self.norm_layer = norm_layer 70 | 71 | self.net = self._make_network() 72 | 73 | def _make_network(self): 74 | norm_layer = self.norm_layer 75 | 76 | conv1 = bb.down_conv3x3(self.input_nc, 128, norm_layer, stride = 2, acti_layer = nn.Sigmoid) 77 | conv2 = bb.down_conv3x3(128, 256, norm_layer, stride = 2, acti_layer = nn.Sigmoid) 78 | inner_product = bb.O2OBlock(256, self.output_nc, global_pool = True, acti_layer = nn.Sigmoid) 79 | # acti layer is sigmoid because of relu's dead neuron problem 80 | fc1 = bb.FCBlock(self.output_nc, self.output_nc, acti_layer = bb.Sigmoid6, bias = True) 81 | layers = [conv1, conv2, inner_product, fc1] 82 | 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | return self.net(x) 87 | 88 | class IntermDistConv(nn.Module): 89 | def __init__(self, input_nc, num_joints, norm_layer): 90 | super().__init__() 91 | self.input_nc = input_nc 92 | self.output_nc = num_joints 93 | self.norm_layer = norm_layer 94 | 95 | self.net = self._make_network() 96 | 97 | def _make_network(self): 98 | norm_layer = self.norm_layer 99 | conv1 = bb.down_conv3x3(self.input_nc, 256, norm_layer, stride = 2, acti_layer = nn.Sigmoid) 100 | one_by_one = bb.O2OBlock(256, self.output_nc, global_pool = True, acti_layer = nn.Sigmoid) 101 | 102 | # acti layer is sigmoid because of relu's dead neuron problem 103 | small_fc = bb.FCBlock(self.output_nc, self.output_nc, acti_layer = bb.Sigmoid6, bias = True) 104 | layers = [conv1, one_by_one, small_fc] 105 | 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, x): 109 | return self.net(x) 110 | 111 | def leaky_relu(): 112 | return nn.LeakyReLU(negative_slope = 0.1, inplace = True) 113 | 114 | class HandConv(nn.Module): 115 | """ Conv layers that decodes joint data from encoded features. 116 | HandConv = HeatmapConv + DistConv + Unprojection Network 117 | """ 118 | def __init__(self, img_size, heatmap_conv, dist_conv, unproject_net): 119 | super().__init__() 120 | self.heatmap_conv = heatmap_conv 121 | self.dist_conv = dist_conv 122 | self.unproject_net = unproject_net 123 | 124 | def forward(self, x): 125 | heatmap = self.heatmap_conv(x) 126 | dist = self.dist_conv(x) 127 | joint = self.unproject_net(heatmap, dist) 128 | return joint, heatmap 129 | -------------------------------------------------------------------------------- /network/hand_pose_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import network.helper as helper 5 | import network.norm as norm 6 | import network.resnet as resnet 7 | import network.basic_block as bb 8 | import network.hand_module_net as handnet 9 | import network.unprojection as unproj 10 | 11 | from network.base_net import BaseNet 12 | 13 | def create_hpe_net(opt, gpu_ids): 14 | unproject_net_gen = unproj.create_unprojection_net_generator(opt) 15 | if opt.network == 'basic': # What we used in the paper. 16 | net = HandPoseNetBasic(opt, unproject_net_gen) 17 | elif opt.network == 'advance': # Same resnet structure, but deeper. 
18 | net = HandPoseNetAdvance(opt, unproject_net_gen) 19 | else: 20 | raise Exception("Hpe network {} is not impelemented.".format(opt.network)) 21 | 22 | return helper.init_net(net, opt.init_type, opt.init_gain, gpu_ids) 23 | 24 | class HandPoseNetBasic(BaseNet): 25 | def __init__(self, opt, unproject_net_gen): 26 | input_nc = opt.input_channel 27 | output_nc = 0 # multiple output 28 | mode = opt.mode 29 | norm_type = opt.norm_type 30 | 31 | super().__init__("HandPoseNetBasic", input_nc, output_nc, mode, norm_type) 32 | 33 | self.num_joints = opt.num_joints 34 | self.img_size = opt.img_size 35 | self.img_shape = (opt.img_size, opt.img_size) 36 | self.norm_type = opt.norm_type 37 | 38 | self.deconv_nc = opt.deconv_channel 39 | 40 | self.unproject_net_gen = unproject_net_gen 41 | 42 | self.run_interm = not self.mode.is_eval() 43 | self._make_network() 44 | 45 | self.set_mode() 46 | 47 | def _make_network(self): 48 | base_nc = 64 49 | n_blocks_list = [3, 3, 4] 50 | n_stride_list = [2, 2, 2] 51 | n_heatmap_deconv = 2 52 | conv4_nc = 512 53 | 54 | self.encoder = HandDataEncoder(self.input_nc, self.num_joints, base_nc, n_blocks_list, n_stride_list, self.img_size, self.mode, self.norm_type, self.unproject_net_gen) 55 | 56 | input_nc = self.encoder.output_nc 57 | self.conv4e = bb.down_conv3x3(input_nc, conv4_nc, self.norm_layer) 58 | 59 | # create intermediate hand conv part. 60 | heatmap_interm_conv = handnet.IntermHeatmapConv(conv4_nc, self.num_joints, self.img_shape, self.norm_layer) 61 | dist_interm_conv = handnet.IntermDistConv(conv4_nc, self.num_joints, self.norm_layer) 62 | unproject_net = self.unproject_net_gen.create() 63 | self.hand_interm_conv = handnet.HandConv(self.img_size, heatmap_interm_conv, dist_interm_conv, unproject_net) 64 | 65 | self.conv4f = bb.down_conv3x3(conv4_nc, int(conv4_nc/2), self.norm_layer) 66 | 67 | # create final hand conv part. 
68 | heatmap_conv = handnet.HeatmapConv(int(conv4_nc/2), self.num_joints, n_heatmap_deconv, self.deconv_nc, self.img_shape, self.norm_layer) 69 | dist_conv = handnet.DistConv(int(conv4_nc/2), self.num_joints, self.norm_layer) 70 | unproject_net = self.unproject_net_gen.create() 71 | self.hand_conv = handnet.HandConv(self.img_size, heatmap_conv, dist_conv, unproject_net) 72 | 73 | def forward(self, x): 74 | result = self.encoder(x) 75 | x = result['output'] 76 | 77 | x = self.conv4e(x) 78 | if self.run_interm: 79 | joint_interm, heatmap_interm = self.hand_interm_conv(x) 80 | 81 | x = self.conv4f(x) 82 | joint, heatmap = self.hand_conv(x) 83 | result['joint'] = joint 84 | result['heatmap'] = heatmap 85 | 86 | if self.run_interm: 87 | result['joint_interms'].append(joint_interm) 88 | result['heatmap_interms'].append(heatmap_interm) 89 | 90 | return result 91 | 92 | class HandPoseNetAdvance(BaseNet): 93 | def __init__(self, opt, unproject_net_gen): 94 | 95 | input_nc = opt.input_channel 96 | output_nc = 0 # multiple output 97 | mode = opt.mode 98 | norm_type = opt.norm_type 99 | 100 | super().__init__("HandPoseNetAdvance", input_nc, output_nc, mode, norm_type) 101 | 102 | self.num_joints = opt.num_joints 103 | self.img_size = opt.img_size 104 | self.img_shape = (opt.img_size, opt.img_size) 105 | self.norm_type = opt.norm_type 106 | self.unproject_net_gen = unproject_net_gen 107 | 108 | self.deconv_nc = opt.deconv_channel 109 | 110 | self.run_interm = not self.mode.is_eval() 111 | 112 | self._make_network() 113 | 114 | self.set_mode() 115 | 116 | def _make_network(self): 117 | base_nc = 64 118 | n_blocks_list = [3, 4, 4, 3] 119 | n_stride_list = [2, 2, 2, 2] 120 | n_heatmap_deconv = 3 121 | conv6_nc = 1024 122 | 123 | self.encoder = HandDataEncoder(self.input_nc, self.num_joints, base_nc, n_blocks_list, n_stride_list, self.img_size, self.mode, self.norm_type, self.unproject_net_gen) 124 | input_nc = self.encoder.output_nc 125 | 126 | self.conv6a = bb.down_conv3x3(input_nc, conv6_nc, self.norm_layer) 127 | heatmap_interm_conv = handnet.IntermHeatmapConv(conv6_nc, self.num_joints, self.img_shape, self.norm_layer) 128 | dist_interm_conv = handnet.IntermDistConv(conv6_nc, self.num_joints, self.norm_layer) 129 | unproject_net = self.unproject_net_gen.create() 130 | self.hand_interm_conv = handnet.HandConv(self.img_size, heatmap_interm_conv, dist_interm_conv, unproject_net) 131 | 132 | self.conv6b = bb.down_conv3x3(conv6_nc, int(conv6_nc/2), self.norm_layer) 133 | heatmap_conv = handnet.HeatmapConv(int(conv6_nc/2), self.num_joints, n_heatmap_deconv, self.deconv_nc, self.img_shape, self.norm_layer) 134 | dist_conv = handnet.DistConv(int(conv6_nc/2), self.num_joints, self.norm_layer) 135 | unproject_net = self.unproject_net_gen.create() 136 | self.hand_conv = handnet.HandConv(self.img_size, heatmap_conv, dist_conv, unproject_net) 137 | 138 | def forward(self, x): 139 | result = self.encoder(x) 140 | x = result['output'] 141 | 142 | x = self.conv6a(x) 143 | joint_interm, heatmap_interm = self.hand_interm_conv(x) 144 | x = self.conv6b(x) 145 | joint, heatmap = self.hand_conv(x) 146 | result['joint'] = joint 147 | result['heatmap'] = heatmap 148 | 149 | if self.run_interm: 150 | result['joint_interms'].append(joint_interm) 151 | result['heatmap_interms'].append(heatmap_interm) 152 | 153 | return result 154 | 155 | class HandDataEncoder(BaseNet): 156 | """ Encodes hand joint information. 
""" 157 | feature_inc_ratio = 2 158 | def __init__(self, input_nc, num_joints, base_nc, n_blocks_list, n_stride_list, img_size, mode, norm, unproject_net_gen): 159 | self.input_nc = input_nc 160 | self.num_joints = num_joints 161 | self.base_nc = base_nc 162 | self.n_blocks_list = n_blocks_list 163 | self.n_stride_list = n_stride_list 164 | self.n_layers = len(self.n_blocks_list) 165 | self.img_size = img_size 166 | self.img_shape = (img_size, img_size) 167 | self.unproject_net_gen = unproject_net_gen 168 | super().__init__("HandDataEncoderPolar", input_nc, 0, mode, norm) 169 | 170 | self._make_network() 171 | 172 | last_layer = self._get_last_layer() 173 | self.output_nc = last_layer.output_nc 174 | self.run_interm = not self.mode.is_eval() 175 | 176 | self.set_mode() 177 | 178 | def _make_network(self): 179 | self.front_conv = bb.ConvBlock(nn.Conv2d, self.input_nc, self.base_nc, kernel_size = 7, norm_layer = self.norm_layer, stride = 1, acti_layer = nn.ReLU, padding = 3) 180 | self.pooling = nn.MaxPool2d(3, 2, padding=1) 181 | self._make_resnet_layers() 182 | self._set_hand_interm_convs() 183 | 184 | def _make_resnet_layers(self): 185 | input_nc = self.base_nc 186 | num_channels = self.base_nc 187 | next_layer_feature_ratio = self.feature_inc_ratio 188 | 189 | for i in range(self.n_layers): 190 | n_blocks = self.n_blocks_list[i] 191 | stride = self.n_stride_list[i] 192 | layer = resnet.DownsampleResnetLayer(input_nc, num_channels, n_blocks, stride, self.norm_layer) 193 | self._set_layer(i, layer) 194 | num_channels = int(num_channels * next_layer_feature_ratio) 195 | input_nc = layer.output_nc 196 | 197 | def _set_layer(self, id, layer): 198 | setattr(self, "layer_{}".format(id), layer) 199 | 200 | def _get_layer(self, id): 201 | return getattr(self, "layer_{}".format(id)) 202 | 203 | def _get_last_layer(self): 204 | return self._get_layer(len(self.n_blocks_list)-1) 205 | 206 | def _set_hand_interm_convs(self): 207 | for i in range(1, self.n_layers): 208 | layer = self._get_layer(i) 209 | input_nc = layer.output_nc 210 | dist_conv = handnet.IntermDistConv(input_nc, self.num_joints, self.norm_layer) 211 | heatmap_conv = handnet.IntermHeatmapConv(input_nc, self.num_joints, self.img_shape, self.norm_layer) 212 | unproject_net = self.unproject_net_gen.create() 213 | 214 | hand_conv = handnet.HandConv(self.img_size, heatmap_conv, dist_conv, unproject_net) 215 | 216 | setattr(self, "hand_conv_{}".format(i), hand_conv) 217 | 218 | def _get_hand_interm_convs(self, i): 219 | return getattr(self, "hand_conv_{}".format(i)) 220 | 221 | def forward(self, x): 222 | x = self.front_conv(x) 223 | x = self.pooling(x) 224 | 225 | result = {} 226 | if self.run_interm: 227 | result['heatmap_interms'] = [] 228 | result['joint_interms'] = [] 229 | 230 | first_layer = self._get_layer(0) 231 | x, downsampeled = first_layer(x) 232 | 233 | for i in range(1, self.n_layers): 234 | x, joint_interm, heatmap_interm = self.forward_layer(x, i) 235 | if self.run_interm: 236 | result['joint_interms'].append(joint_interm) 237 | result['heatmap_interms'].append(heatmap_interm) 238 | 239 | result['output'] = x 240 | 241 | return result 242 | 243 | def forward_layer(self, x, i): 244 | layer = self._get_layer(i) 245 | hand_conv = self._get_hand_interm_convs(i) 246 | 247 | x, downsampeled = layer(x) 248 | joint_interm = None 249 | heatmap_interm = None 250 | if self.run_interm: 251 | joint_interm, heatmap_interm = hand_conv(downsampeled) 252 | 253 | return x, joint_interm, heatmap_interm 254 | 
-------------------------------------------------------------------------------- /network/helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | 7 | import network.norm as norm 8 | import network.basic_block as bb 9 | import util.filter as filter 10 | 11 | def is_conv_or_linear(m): 12 | classname = m.__class__.__name__ 13 | return classname.startswith('Conv') or classname.startswith('Linear') 14 | 15 | def is_batch_or_group_norm(m): 16 | classname = m.__class__.__name__ 17 | return classname.startswith('BatchNorm') or classname.startswith('GroupNorm') 18 | 19 | def init_weights(net, init_type='normal', init_gain = 0.1): 20 | 21 | initializer = None 22 | if init_type == 'normal': 23 | initializer = functools.partial(init.normal_, mean=0.0, std=init_gain) 24 | elif init_type == 'xavier': 25 | initializer = functools.partial(init.xavier_normal_, gain = init_gain) 26 | elif init_type == 'kaiming': 27 | initializer = functools.partial(init.kaiming_normal_, a = 0, mode = 'fan_in') # kaiming_normal_ takes no gain argument 28 | else: 29 | raise ValueError("init_type with {} is not valid".format(init_type)) 30 | 31 | def weights_init(m): 32 | if hasattr(m, 'weight') and is_conv_or_linear(m): 33 | initializer(m.weight.data) 34 | elif is_batch_or_group_norm(m): 35 | # Instance norm is implemented with GroupNorm in our case. 36 | init.normal_(m.weight.data, mean=1.0, std=init_gain) 37 | init.constant_(m.bias.data, 0.0) 38 | 39 | if isinstance(net, nn.DataParallel): 40 | name = net.module.name 41 | else: 42 | name = net.name 43 | 44 | print("{} is initialized with {}".format(name, init_type)) 45 | net.apply(weights_init) 46 | 47 | def send_net_to_device(net, gpu_ids = []): 48 | if len(gpu_ids) > 0: 49 | assert(torch.cuda.is_available()) 50 | net.to(gpu_ids[0]) 51 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 52 | return net 53 | 54 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], initialize_weights = True): 55 | net = send_net_to_device(net, gpu_ids) 56 | if initialize_weights: 57 | init_weights(net, init_type, init_gain=init_gain) 58 | return net 59 | -------------------------------------------------------------------------------- /network/norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import functools 3 | 4 | class InstanceNorm(nn.GroupNorm): 5 | def __init__(self, channels): 6 | super().__init__(channels, channels) 7 | 8 | class LayerNorm(nn.GroupNorm): 9 | def __init__(self, channels): 10 | super().__init__(1, channels) 11 | 12 | def get_norm_layer(norm_type, group = 4): 13 | if norm_type == "batch": 14 | return nn.BatchNorm2d 15 | elif norm_type == "group": 16 | return create_group_norm(group) 17 | elif norm_type == "instance": 18 | return InstanceNorm 19 | elif norm_type == "layer": 20 | return LayerNorm 21 | else: 22 | raise NotImplementedError("norm_type '{}' is not implemented".format(norm_type)) 23 | 24 | def create_group_norm(num_group): 25 | class GroupNorm(nn.GroupNorm): 26 | def __init__(self, channels): 27 | super().__init__(num_group, channels) 28 | return GroupNorm 29 | -------------------------------------------------------------------------------- /network/projection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import network.basic_block as bb 5 | import 
util.fisheye as fisheye 6 | import network.helper as helper 7 | from util.debug import viz_grad, viz_grad_mean 8 | import util.image as image 9 | import util.io as io 10 | import torchvision 11 | import util.math as umath 12 | import util.filter as filter 13 | 14 | from abc import ABC, abstractmethod 15 | 16 | def create_projection_net(opt, gpu_ids): 17 | net = FisheyeProjectionNet(opt.num_joints, opt.img_size, opt.gauss_kernel_size, opt.gauss_kernel_sigma) 18 | return helper.init_net(net, gpu_ids = gpu_ids, initialize_weights = False) 19 | 20 | class ProjectionNet(nn.Module, ABC): 21 | def __init__(self, num_joints, img_size, g_size, g_sigma): 22 | super().__init__() 23 | 24 | self.num_joints = num_joints 25 | self.img_size = img_size 26 | self.img_shape = (img_size, img_size) 27 | self.gen_heatmap_seed = GenHeatmapSeed.apply 28 | self.gaussian_filter = filter.GaussianFilter(channels = num_joints, kernel_size = g_size, sigma = g_sigma, peak_to_one = True) 29 | 30 | def forward(self, joint): 31 | assert len(joint.shape) == 3, "joint shape should be (batch_size, num_joints, 3)" 32 | 33 | mapped_joint = self._convert_to_uv(joint) 34 | 35 | heatmap_seed = self.gen_heatmap_seed(mapped_joint, self.img_size) 36 | heatmap = self.gaussian_filter(heatmap_seed) 37 | 38 | self.saved_for_backward = heatmap 39 | return heatmap 40 | 41 | @abstractmethod 42 | def _convert_to_uv(self, xyz): 43 | pass 44 | 45 | class GenHeatmapSeed(torch.autograd.Function): 46 | """ A differentiable function that generates a heatmap seed matrix from uv coordinates """ 47 | @staticmethod 48 | def forward(ctx, uv, img_size): 49 | # uv: N, n_joints, 2 50 | n_sample, n_joints = uv.shape[:2] 51 | 52 | int_uv_mat = torch.round(uv).long() 53 | int_uv_mat = torch.clamp(int_uv_mat, min = 0, max = img_size - 1) 54 | 55 | batch_joint_index = np.arange(0, n_sample * n_joints) 56 | 57 | flatten_idx = int_uv_mat.view(-1, 2) 58 | x_idx = flatten_idx[:,0] 59 | y_idx = flatten_idx[:,1] 60 | 61 | heatmap_seeds = torch.zeros(n_sample * n_joints, img_size, img_size) 62 | heatmap_seeds[batch_joint_index, y_idx, x_idx] = 1 63 | heatmap_seeds = heatmap_seeds.view(n_sample, n_joints, img_size, img_size) 64 | 65 | heatmap_seeds = heatmap_seeds.to(uv.device) 66 | 67 | ctx.save_for_backward(uv, heatmap_seeds) 68 | 69 | return heatmap_seeds 70 | 71 | @staticmethod 72 | def backward(ctx, grad_output): 73 | pred_uv, heatmap_seeds = ctx.saved_tensors 74 | 75 | # Just pass 0. Backward is never called in the project. 76 | # However, you might want to have a proper backward function. 77 | # Then the backward function can be implemented in here. 
78 | grad_input = torch.zeros(pred_uv.shape) 79 | 80 | return grad_input, None # None is for img_size input 81 | 82 | class DifferentiableRound(torch.autograd.Function): 83 | @staticmethod 84 | def forward(ctx, x): 85 | return x.round() 86 | 87 | @staticmethod 88 | def backward(ctx, grad_output): 89 | return grad_output 90 | 91 | class FisheyeProjectionNet(ProjectionNet): 92 | def __init__(self, num_joints, img_size, g_size, g_sigma, fisheye_type = 'equidistant'): 93 | super().__init__(num_joints, img_size, g_size, g_sigma) 94 | self.name = "Fisheye Projection Network" 95 | self.radius = int(self.img_size / 2) 96 | self.center = (self.radius, self.radius) 97 | self.fisheye_type = fisheye_type 98 | self.differentiable_round = DifferentiableRound.apply 99 | 100 | def _convert_to_uv(self, xyz): 101 | #xyz shape: N, num_joint, 3 102 | n_sample, n_joints = xyz.shape[:2] 103 | 104 | xyz = xyz.view(-1,3) 105 | x = xyz[:,0] 106 | y = xyz[:,1] 107 | z = xyz[:,2] 108 | 109 | theta = torch.atan2(torch.sqrt(x*x + y*y), z) 110 | 111 | phi = torch.atan2(y, x) 112 | 113 | r = fisheye.r_function(self.radius, theta, self.fisheye_type) 114 | 115 | _x = r * torch.cos(phi) 116 | _y = r * torch.sin(phi) 117 | 118 | fish_x = self.differentiable_round(self.center[0] + _x).unsqueeze(1) 119 | fish_y = self.differentiable_round(self.center[1] + _y).unsqueeze(1) 120 | out = torch.cat((fish_x, fish_y), dim = 1) 121 | out = out.view(n_sample, n_joints, -1) 122 | 123 | return out 124 | -------------------------------------------------------------------------------- /network/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import network.basic_block as bb 3 | from network.basic_block import ConvBlock 4 | 5 | class DownBottleneck(nn.Module): 6 | expansion = 4 7 | 8 | def __init__(self, inplanes, planes, norm_layer, stride=1, downsample=False): 9 | super().__init__() 10 | self.stride = stride 11 | self.input_nc = inplanes 12 | self.output_nc = planes * self.expansion 13 | self.conv1 = bb.down_conv1x1(inplanes, planes, norm_layer) 14 | self.conv2 = bb.down_conv3x3(planes, planes, norm_layer, stride) 15 | self.conv3 = bb.down_conv1x1(planes, planes * self.expansion, norm_layer, acti_layer = None) 16 | self.relu = nn.ReLU(inplace=True) 17 | 18 | self.downsample = downsample or (stride > 1) 19 | if self.downsample: 20 | self.downsample_block = bb.down_conv1x1(inplanes, planes * self.expansion, norm_layer, stride, acti_layer = None) 21 | 22 | def forward(self, x): 23 | residual = x 24 | 25 | out = self.conv1(x) 26 | out = self.conv2(out) 27 | out = self.conv3(out) 28 | 29 | if self.downsample: 30 | residual = self.downsample_block(x) 31 | 32 | out += residual 33 | out = self.relu(out) 34 | 35 | return out 36 | 37 | class UniBottleneck(nn.Module): 38 | def __init__(self, num_features, norm_layer): 39 | super().__init__() 40 | self.input_nc = num_features 41 | self.output_nc = num_features 42 | self.conv1 = bb.down_conv1x1(num_features, num_features, norm_layer) 43 | self.conv2 = bb.down_conv3x3(num_features, num_features, norm_layer) 44 | self.conv3 = bb.down_conv1x1(num_features, num_features, norm_layer, acti_layer = None) 45 | self.relu = nn.ReLU(inplace=True) 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | out = self.conv2(out) 52 | out = self.conv3(out) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | class UpBottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, 
inplanes, planes, norm_layer=nn.BatchNorm2d, stride=1, upsample=False): 63 | super().__init__() 64 | 65 | self.stride = stride 66 | 67 | self.input_nc = inplanes 68 | self.output_nc = int(planes / self.expansion) 69 | 70 | if self.stride > 1: 71 | self.conv1 = bb.up_conv4x4(inplanes, planes, norm_layer, stride) 72 | else: 73 | self.conv1 = bb.down_conv1x1(inplanes, planes, norm_layer) 74 | 75 | self.conv2 = bb.down_conv3x3(planes, planes, norm_layer) 76 | 77 | self.conv3 = bb.down_conv1x1(planes, self.output_nc, norm_layer, acti_layer = None) 78 | self.relu = nn.ReLU(inplace=True) 79 | 80 | self.upsample = upsample or (stride > 1) 81 | if self.upsample: 82 | self.upsample_block = bb.up_conv4x4(inplanes, self.output_nc, norm_layer, stride, acti_layer = None) 83 | 84 | def forward(self, x): 85 | residual = x 86 | 87 | out = self.conv1(x) 88 | out = self.conv2(out) 89 | out = self.conv3(out) 90 | 91 | if self.upsample: 92 | residual = self.upsample_block(x) 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | class DownsampleResnetLayer(nn.Module): 99 | def __init__(self, input_nc, planes, n_blocks, stride, norm_layer): 100 | super().__init__() 101 | 102 | self.downsample_block = DownBottleneck(input_nc, planes, norm_layer=norm_layer, stride = stride, downsample = True) 103 | blocks = [] 104 | 105 | in_plane = self.downsample_block.output_nc 106 | for _ in range(n_blocks - 1): 107 | b = DownBottleneck(in_plane, planes, norm_layer = norm_layer, stride = 1, downsample = False) 108 | in_plane = b.output_nc 109 | blocks.append(b) 110 | 111 | self.identity_blocks = nn.Sequential(*blocks) 112 | self.input_nc = input_nc 113 | self.planes = planes 114 | self.output_nc = blocks[-1].output_nc 115 | 116 | def forward(self, x): 117 | downsampled = self.downsample_block(x) 118 | x = self.identity_blocks(downsampled) 119 | return x, downsampled 120 | 121 | class UpsampleRensetLayer(nn.Module): 122 | def __init__(self, input_nc, planes, n_blocks, stride, norm_layer): 123 | super().__init__() 124 | 125 | self.upsample_block = UpBottleneck(input_nc, planes, norm_layer=norm_layer, stride = stride, upsample = True) 126 | blocks = [] 127 | 128 | in_plane = self.upsample_block.output_nc 129 | for _ in range(n_blocks - 1): 130 | b = UniBottleneck(in_plane, norm_layer) 131 | in_plane = b.output_nc 132 | blocks.append(b) 133 | 134 | self.input_nc = input_nc 135 | self.planes = planes 136 | self.n_blocks = n_blocks 137 | 138 | if n_blocks > 1: 139 | self.identity_blocks = nn.Sequential(*blocks) 140 | self.output_nc = blocks[-1].output_nc 141 | else: 142 | self.identity_blocks = None 143 | self.output_nc = self.upsample_block.output_nc 144 | 145 | def forward(self, x): 146 | upsampled = self.upsample_block(x) 147 | if self.n_blocks > 1: 148 | x = self.identity_blocks(upsampled) 149 | else: 150 | x = upsampled 151 | return x, upsampled 152 | 153 | class SimpleResnetLayer(nn.Module): 154 | def __init__(self, input_nc, n_blocks, norm_layer): 155 | super().__init__() 156 | 157 | blocks = [] 158 | for _ in range(n_blocks): 159 | b = UniBottleneck(input_nc, norm_layer) 160 | input_nc = b.output_nc 161 | blocks.append(b) 162 | 163 | self.net = nn.Sequential(*blocks) 164 | self.input_nc = input_nc 165 | self.output_nc = blocks[-1].output_nc 166 | 167 | def forward(self, x): 168 | return self.net(x) 169 | -------------------------------------------------------------------------------- /network/unprojection.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn as nn 3 | from abc import ABC, abstractmethod 4 | import util.fisheye as fisheye 5 | from util.debug import viz_grad, viz_grad_mean 6 | import network.basic_block as bb 7 | 8 | """ Unprojection networks 9 | 1) get keypoints from heatmaps 10 | 2) convert the keypoints to direction vectors 11 | 3) convert the vectors and distance values to 3D points in the cartesian coordinate. 12 | """ 13 | 14 | def unravel_index(idx, H, W): 15 | row = (idx / W).long() 16 | col = (idx % W).long() 17 | return row, col 18 | 19 | class ArgMax2d(torch.autograd.Function): 20 | @staticmethod 21 | def forward(ctx, tensor): 22 | N, C, H, W = tensor.shape 23 | tensor = tensor.reshape(N, C, H*W) 24 | max_idx = tensor.argmax(dim = -1) 25 | row, col = unravel_index(max_idx, H, W) 26 | idx = torch.stack([row, col], dim = 2) 27 | idx = idx.float() 28 | ctx.input_shape = (N, C, H, W) 29 | ctx.save_for_backward(idx) 30 | return idx 31 | 32 | @staticmethod 33 | def backward(ctx, grad_output): 34 | 35 | idx, = ctx.saved_tensors 36 | N, C, H, W = ctx.input_shape 37 | 38 | 39 | grad_input = torch.zeros(N, C, H, W) 40 | grad_input = grad_input.to(grad_output.device) 41 | 42 | return grad_input 43 | 44 | def create_unprojection_net_generator(opt): 45 | return UnprojectionNetGenerator(opt.img_size) 46 | 47 | class UnprojectionNetGenerator: 48 | def __init__(self, img_size): 49 | self.img_size = img_size 50 | 51 | def create(self): 52 | return FisheyeUnprojectionNet(self.img_size) 53 | 54 | class FisheyeUnprojectionNet(nn.Module): 55 | def __init__(self, img_size, fisheye_type = "equidistant"): 56 | super().__init__() 57 | self.img_size = img_size 58 | self.img_radius = img_size / 2 59 | self.argmax_2d = ArgMax2d.apply 60 | self.fisheye_type = fisheye_type 61 | 62 | def forward(self, heatmaps, dists): 63 | max_idx = self.argmax_2d(heatmaps) 64 | theta_phi = self._convert_to_theta_phi(max_idx) 65 | r = dists 66 | theta = theta_phi[:,:,0] 67 | phi = theta_phi[:,:,1] 68 | xyz = self._convert_to_cartesian(r, theta, phi) 69 | return xyz 70 | 71 | def _convert_to_theta_phi(self, max_idx): 72 | max_idx = max_idx.float() 73 | img_coord_idx = max_idx - self.img_radius 74 | u = img_coord_idx[:,:,1] 75 | v = img_coord_idx[:,:,0] 76 | 77 | phi = torch.atan2(v, u) 78 | 79 | r = (u**2 + v**2).sqrt() 80 | 81 | theta = fisheye.inverse_r_function(self.img_radius, r, self.fisheye_type) 82 | 83 | theta_phi = torch.stack([theta, phi], dim = 2) 84 | return theta_phi 85 | 86 | def _convert_to_cartesian(self, r, theta, phi): 87 | z = r * torch.cos(theta) 88 | xy = r * torch.sin(theta) 89 | x = xy * torch.cos(phi) 90 | y = xy * torch.sin(phi) 91 | 92 | xyz = torch.stack([x, y, z], dim = 2) 93 | return xyz 94 | -------------------------------------------------------------------------------- /option/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-HCIL/DeepFisheyeNet/ecdf9265a0f7c5048ace0636c0ce26260a39f352/option/__init__.py -------------------------------------------------------------------------------- /option/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class BaseOption(ABC): 4 | """ This is basic abstract class for options. 
""" 5 | def __init__(self, name, prefix): 6 | self.__name__ = name 7 | self.__prefix__ = prefix 8 | self.arguments = [] 9 | self.set_arguments() 10 | self.initialized = False 11 | 12 | @abstractmethod 13 | def set_arguments(self): 14 | """ Set (custom) arguments that this option needs. """ 15 | pass 16 | 17 | def add_to_parser(self, parser): 18 | for argument in self.arguments: 19 | argument.add_to_parser(parser, self.__prefix__) 20 | 21 | def initialize(self, args): 22 | if self.initialized: 23 | return 24 | 25 | for argument in self.arguments: 26 | arg_name = argument.update_arg_name(self.__prefix__) 27 | v = getattr(args, arg_name) 28 | setattr(self, argument.name, v) 29 | 30 | del self.arguments 31 | self.initialized = True 32 | 33 | def initialize_with_defaults(self): 34 | if self.initialized: 35 | return 36 | 37 | for argument in self.arguments: 38 | setattr(self, argument.name, argument.default) 39 | 40 | del self.arguments 41 | self.initialized = True 42 | 43 | def __str__(self): 44 | message = "" 45 | for k, v in sorted(vars(self).items()): 46 | message += "{}: {}\n".format(str(k), str(v)) 47 | 48 | return message 49 | 50 | def convert_mode_str_to_mode(self): 51 | if hasattr(self, 'mode'): 52 | self.mode = Mode(self.mode) 53 | 54 | 55 | class Argument: 56 | """ This is a wrapper class for python argument. """ 57 | def __init__(self, name, type=None, default=None, help=None, action=None): 58 | self.name = name 59 | self.type = type 60 | self.default = default 61 | self.help = help 62 | self.action = action 63 | 64 | def add_to_parser(self, parser, prefix): 65 | parser_tag = self.get_parser_tag(prefix) 66 | helper_msg = self.update_helper_msg(prefix) 67 | if not self.action is None: 68 | self._add_action_to_parser(parser, parser_tag, prefix) 69 | return 70 | 71 | parser.add_argument(parser_tag, \ 72 | type = self.type, \ 73 | default = self.default, \ 74 | help = helper_msg, \ 75 | action = self.action \ 76 | ) 77 | 78 | def _add_action_to_parser(self, parser, parser_tag, prefix): 79 | parser.add_argument(parser_tag, \ 80 | default = self.default, \ 81 | help = self.help, \ 82 | action = self.action \ 83 | ) 84 | 85 | def get_parser_tag(self, prefix): 86 | return "--" + self.update_arg_name(prefix) 87 | 88 | def update_arg_name(self, prefix): 89 | if prefix: 90 | return '{}_{}'.format(prefix, self.name) 91 | else: 92 | return '{}'.format(self.name) 93 | 94 | def update_helper_msg(self, prefix): 95 | if prefix: 96 | return '{}: {}'.format(prefix, self.help) 97 | else: 98 | return '{}'.format(self.help) 99 | 100 | class Mode: 101 | """ Mode class defines the mode of models and networks. 102 | A model or a network might have to work differently for different modes. 103 | For example, a network might have to calculate intermediate losses for 'train' and 'test', but not for 'eval'. 104 | At the sametime, the network should be swtiched to eval mode for 'test' and 'eval', but not for 'train'. 105 | 106 | There are three modes. 
107 | - train 108 | - test 109 | - eval (evaluate) 110 | """ 111 | modes = ['train', 'test', 'eval'] 112 | def __init__(self, mode_str): 113 | self.mode_str = self.clean_up(mode_str) 114 | assert self.mode_str in self.modes, "Mode: mode should be one of {}".format(','.join(modes)) 115 | 116 | def clean_up(self, mode_str): 117 | mode_str = mode_str.strip() 118 | mode_str = mode_str.lower() 119 | 120 | return mode_str 121 | 122 | def is_train(self): 123 | return self.mode_str == 'train' 124 | 125 | def is_test(self): 126 | return self.mode_str == 'test' 127 | 128 | def is_eval(self): 129 | return self.mode_str == 'eval' 130 | 131 | def __str__(self): 132 | return "{} (mode class)".format(self.mode_str) 133 | 134 | def to_train(self): 135 | self.mode_str = 'train' 136 | 137 | def to_test(self): 138 | self.mode_str = 'test' 139 | 140 | def to_eval(self): 141 | self.mode_str = 'eval' 142 | -------------------------------------------------------------------------------- /option/general.py: -------------------------------------------------------------------------------- 1 | from .base import BaseOption, Argument 2 | 3 | class GeneralOption(BaseOption): 4 | 5 | def set_arguments(self): 6 | 7 | self._set_basic_params() 8 | self._set_dataset_params() 9 | self._set_pipeline_params() 10 | self._set_train_params() 11 | self._set_train_visual_params() 12 | self._set_test_params() 13 | 14 | def _set_basic_params(self): 15 | self.arguments += [Argument('name', type=str, default='sample', help='name of the run. model and samples will be stored.')] 16 | self.arguments += [Argument('gpu_ids', type=str, default='0', help="gpu ids to use. -1 for cpu. (e.g. 0 0,1,2).")] 17 | self.arguments += [Argument('num_workers', type=int, default=0, help="number of data loader workers.")] 18 | self.arguments += [Argument('max_data', type=int, default = float("inf"), help="number of data to use.")] 19 | self.arguments += [Argument('preset', type=str, default = '', help="name of preset. you should write Preset(*this_part*) of preset classes in presets.py")] 20 | self.arguments += [Argument('run', type=str, default = '', help="name of the train or test code to run. please see 'run' module")] 21 | 22 | def _set_dataset_params(self): 23 | self.arguments += [Argument('dataset', type=str, default='', help="dataset to use.")] 24 | self.arguments += [Argument('img_size', type=int, default=256, help="the size of images (defaul: 256). 
all images should be square images")] 25 | self.arguments += [Argument('no_flip', action='store_true', help='does not flip during a run')] 26 | self.arguments += [Argument('min_depth_thrs', type=float, default=0.01171875, help=" too small depth to consider (default: 3 / 256)")] 27 | 28 | def _set_pipeline_params(self): 29 | self.arguments += [Argument('pipeline_pretrained', type=str, default='', help="pretrained_weight for the pipeline")] 30 | 31 | def _set_train_params(self): 32 | self.arguments += [Argument('batch_size', type=int, default=1, help='batch size.')] 33 | self.arguments += [Argument('epoch', type=int, default=100, help='number of epochs to train.')] 34 | self.arguments += [Argument('speed_diagnose', action='store_true', default=False, help='measure time for important steps.')] 35 | 36 | # save 37 | self.arguments += [Argument('save_epoch', type=int, default=5, help='epoch period to save checkpoints')] 38 | 39 | def _set_train_visual_params(self): 40 | self.arguments += [Argument('print_iter', type=int, default=100, help='iteration period to print to tensorboard and terminal')] 41 | self.arguments += [Argument('tensorboard_dir', type=str, default='tensorboard_logs', help='where the tensorboard logs stays')] 42 | self.arguments += [Argument('show_grad', action='store_true', default = False, help='show gradient histogram while training')] 43 | 44 | def _set_test_params(self): 45 | self.arguments += [Argument('no_save_image', action='store_true', default=False, help='do not save result images for test')] 46 | -------------------------------------------------------------------------------- /option/hpe.py: -------------------------------------------------------------------------------- 1 | from .base import BaseOption, Argument 2 | 3 | class HPEOption(BaseOption): 4 | 5 | def set_arguments(self): 6 | 7 | # network params 8 | self.arguments += [Argument('mode', type=str, default='train', help="mode of the model [train | test | eval]")] 9 | self.arguments += [Argument('network', type=str, default='basic', help="hand pose estimator network [basic]")] 10 | self.arguments += [Argument('num_joints', type=int, default=21, help="the number of hand joints.")] 11 | self.arguments += [Argument('gauss_kernel_size', type=int, default=31, help="the size of gaussian kernel for heatmaps")] 12 | self.arguments += [Argument('gauss_kernel_sigma', type=int, default=5, help="the sigma of gaussian kernel for heatmaps")] 13 | self.arguments += [Argument('deconv_channel', type=int, default=64, help="channel for deconv layer of the heatmap network")] 14 | self.arguments += [Argument('init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')] 15 | self.arguments += [Argument('init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')] 16 | 17 | self.arguments += [Argument('img_size', type=int, default=256, help='width and height of input and output image')] 18 | self.arguments += [Argument('input_channel', type=int, default=3, help='the number of channels for the input images. 1 for depth, 3 for rgb image')] 19 | self.arguments += [Argument('norm_type', type=str, default='group', help='the type of normaliztion layer for the network. 
[batch | group | instance]')] 20 | 21 | # training 22 | self.arguments += [Argument('lr', type=float, default=0.1, help='initial learning rate for adam')] 23 | self.arguments += [Argument('heatmap_loss_weight', type=float, default=1.0, help='weight of heatmap losses')] 24 | self.arguments += [Argument('heatmap_interm_loss_weight', type=float, default=1.0, help='weight of intermediate heatmap losses')] 25 | self.arguments += [Argument('joint_loss_weight', type=float, default=1.0, help='weight of joint losses')] 26 | self.arguments += [Argument('joint_interm_loss_weight', type=float, default=1.0, help='weight of intermediate joint losses')] 27 | 28 | # pretrained 29 | self.arguments += [Argument('pretrained', type=str, default = "", help="pretrained model weights to load. it should be a checkpoint.")] 30 | -------------------------------------------------------------------------------- /option/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from .general import GeneralOption 3 | from .hpe import HPEOption 4 | from .pix2depth import Pix2DepthOption 5 | 6 | class Options(): 7 | """ This class holds all the options. 8 | You have to modify '_set_options' to add/remove options. 9 | """ 10 | def __init__(self): 11 | self.initialized = False 12 | self._set_options() 13 | 14 | def _set_options(self): 15 | # Add or remove options in here. 16 | self.general = GeneralOption(name = 'general', prefix = None) 17 | self.hpe = HPEOption(name = 'hand_pose_estimator', prefix = 'hpe') 18 | self.pix2depth = Pix2DepthOption(name = 'pix2depth', prefix = "p2d") 19 | 20 | self.options = [] 21 | self.options.append(self.general) 22 | self.options.append(self.hpe) 23 | self.options.append(self.pix2depth) 24 | 25 | def initialize(self): 26 | if not self.initialized: 27 | parser = argparse.ArgumentParser() 28 | parser = self._init_parser(parser) 29 | self.initialized = True 30 | 31 | args = parser.parse_args() 32 | 33 | for opt in self.options: 34 | opt.initialize(args) 35 | 36 | assert self.general.img_size == self.hpe.img_size and self.general.img_size == self.pix2depth.img_size, "all image size should be same." 
37 | 38 | def initialize_with_defaults(self): 39 | for option in self.options: 40 | option.initialize_with_defaults() 41 | 42 | def parse(self): 43 | self.general.gpu_ids = [int(i) for i in self.general.gpu_ids.split(',')] 44 | 45 | for opt in self.options: 46 | opt.convert_mode_str_to_mode() 47 | 48 | def _init_parser(self, parser): 49 | for option in self.options: 50 | option.add_to_parser(parser) 51 | 52 | return parser 53 | 54 | def pretty_str(self): 55 | message = "" 56 | message += "-------------- Options --------------\n" 57 | for opt in self.options: 58 | message += (str(opt) + "\n") 59 | message += "----------------------------------\n" 60 | 61 | return message 62 | 63 | def get_gpu_ids(self): 64 | return self.general.gpu_ids 65 | -------------------------------------------------------------------------------- /option/pix2depth.py: -------------------------------------------------------------------------------- 1 | from .base import BaseOption, Argument 2 | 3 | class Pix2DepthOption(BaseOption): 4 | 5 | def set_arguments(self): 6 | 7 | # model parameters 8 | self.arguments += [Argument('mode', type=str, default='train', help="mode of the model [train | test | eval]")] 9 | self.arguments += [Argument('input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')] 10 | self.arguments += [Argument('output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale')] 11 | self.arguments += [Argument('num_joints', type=int, default=21, help="the number of hand joints.")] 12 | self.arguments += [Argument('base_nc', type=int, default=64, help='# of filters for the first conv layer')] 13 | self.arguments += [Argument('net_type', type=str, default='resnet_3blocks', help='type of the generator backbone')] 14 | self.arguments += [Argument('norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')] 15 | self.arguments += [Argument('init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')] 16 | self.arguments += [Argument('init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')] 17 | self.arguments += [Argument('img_size', type=int, default=256, help='width and height of input and output image')] 18 | self.arguments += [Argument('gauss_kernel_size', type=int, default=31, help="the size of gaussian kernel for heatmaps")] 19 | self.arguments += [Argument('gauss_kernel_sigma', type=int, default=5, help="the sigma of gaussian kernel for heatmaps")] 20 | 21 | # train and test details 22 | self.arguments += [Argument('niter', type=int, default=1000, help='# of iter(epoch) at starting learning rate')] 23 | self.arguments += [Argument('niter_decay', type=int, default=1000, help='# of iter(epoch) to linearly decay learning rate to zero')] 24 | self.arguments += [Argument('beta1', type=float, default=0.5, help='momentum term of adam')] 25 | self.arguments += [Argument('lr', type=float, default=0.0002, help='initial learning rate for adam')] 26 | self.arguments += [Argument('train_only_encoder', action='store_true', default = False, help="train only encoder part of pix2depth.")] 27 | 28 | self.arguments += [Argument('depth_loss_weight', type=float, default=1.0, help='weight of depth error')] 29 | self.arguments += [Argument('heatmap_loss_weight', type=float, default=1.0, help='weight of heatmap losses')] 30 | self.arguments += [Argument('heatmap_interm_loss_weight', type=float, 
default=1.0, help='weight of intermediate heatmap losses')] 31 | self.arguments += [Argument('joint_loss_weight', type=float, default=1.0, help='weight of joint losses')] 32 | self.arguments += [Argument('joint_interm_loss_weight', type=float, default=1.0, help='weight of intermediate joint losses')] 33 | 34 | 35 | # pretrained 36 | self.arguments += [Argument('pretrained', type=str, default = "", help="pretrained model weights to load. it should be a checkpoint.")] 37 | -------------------------------------------------------------------------------- /plot_result_loss.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from collections import defaultdict 3 | import pathlib 4 | import argparse 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--result_path", required=True, type=str, help="path of the result folder.") 9 | parser.add_argument("--nrows", type=int, default = 2, help="number of rows for the plot.") 10 | parser.add_argument("--log_scale", action='store_true', default = False, help="Draw y axis in log scale.") 11 | 12 | return parser.parse_args() 13 | 14 | class LossParser: 15 | 16 | data_to_ignore = ('timestamp', 'tag') 17 | 18 | def __init__(self, args): 19 | self.result_dir = args.result_path 20 | self.loss_path = pathlib.Path(self.result_dir, 'losses.txt') 21 | assert self.loss_path.is_file(), "'losses.txt' does not exist in {}".format(self.result_dir) 22 | 23 | def parse(self): 24 | with open(str(self.loss_path), 'r') as f: 25 | lines = f.readlines() 26 | 27 | parsed_data = defaultdict(list) 28 | 29 | for line in lines: 30 | line_data = self._parse_line(line) 31 | for k, v in line_data.items(): 32 | parsed_data[k].append(v) 33 | 34 | return parsed_data 35 | 36 | def _parse_line(self, line): 37 | line = line.strip() 38 | data_str = line.split(',') 39 | data_str = self._remove_data_to_ignore(data_str) 40 | data = {} 41 | for d in data_str: 42 | k, v = d.split(':') 43 | data[k] = float(v) 44 | 45 | return data 46 | 47 | def _remove_data_to_ignore(self, data_str): 48 | filtered = [] 49 | for d in data_str: 50 | if not d.startswith(self.data_to_ignore): 51 | filtered.append(d) 52 | return filtered 53 | 54 | class Plotter: 55 | data_type_to_not_plot = ('epoch', 'iter') 56 | def __init__(self, data, args): 57 | self.n_row = args.nrows 58 | self.log_scale = args.log_scale 59 | self.data = data 60 | 61 | def plot(self): 62 | self.data = self._change_iter_to_epoch(self.data) 63 | self._plot_in_subplots(self.data) 64 | 65 | plt.show() 66 | 67 | def _change_iter_to_epoch(self, data): 68 | iter_per_epoch = self._count_iter_per_epoch(data) 69 | new_iter = [] 70 | for i, iter in enumerate(data['iter']): 71 | new_iter.append((i+1) * iter_per_epoch) 72 | 73 | data['iter'] = new_iter 74 | return data 75 | 76 | def _count_iter_per_epoch(self, data): 77 | epoch_set = set() 78 | for epoch in data['epoch']: 79 | epoch_set.add(epoch) 80 | 81 | num_epoch = len(epoch_set) 82 | num_iter = len(data['iter']) 83 | 84 | return num_epoch / num_iter 85 | 86 | def _plot_in_subplots(self, data): 87 | data_type_to_plot = self._get_data_type_to_plot(data) 88 | fig, axises = self._create_subplots(data_type_to_plot) 89 | 90 | x = data['iter'] 91 | for i, data_type in enumerate(data_type_to_plot): 92 | y = data[data_type] 93 | self._plot_data_in_ax(axises[i], x, y, data_type) 94 | 95 | def _create_subplots(self, data_types): 96 | n_col = self._get_n_col(data_types) 97 | fig, ax = plt.subplots(nrows = 
self.n_row, ncols = n_col) 98 | 99 | flatten_ax = [] 100 | for row in ax: 101 | for col in row: 102 | flatten_ax.append(col) 103 | 104 | return fig, flatten_ax 105 | 106 | def _get_data_type_to_plot(self, data): 107 | filtered = [] 108 | for key in data.keys(): 109 | if not key in self.data_type_to_not_plot: 110 | filtered.append(key) 111 | 112 | return filtered 113 | 114 | def _get_n_col(self, data_types): 115 | n_cols = int(len(data_types) / self.n_row) 116 | if len(data_types) % self.n_row > 0: 117 | n_cols += 1 118 | 119 | return n_cols 120 | 121 | def _plot_data_in_ax(self, ax, x, y, title): 122 | ax.set_title(title) 123 | ax.set_xlabel('epoch') 124 | ax.set_ylabel('loss') 125 | if self.log_scale: 126 | ax.set_yscale('log') 127 | ax.plot(x,y) 128 | 129 | def main(args): 130 | 131 | loss_parser = LossParser(args) 132 | parsed_data = loss_parser.parse() 133 | 134 | plotter = Plotter(parsed_data, args) 135 | plotter.plot() 136 | 137 | if __name__ == "__main__": 138 | args = parse_args() 139 | main(args) 140 | -------------------------------------------------------------------------------- /preset/__init__.py: -------------------------------------------------------------------------------- 1 | from preset.presets import modify_options 2 | -------------------------------------------------------------------------------- /preset/base_preset.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class BasePreset(ABC): 4 | """ A preset is a kind of recipe. 5 | It has a set of options that should be predefined. 6 | A preset overrides default options and arguments. 7 | """ 8 | def __init__(self): 9 | pass 10 | 11 | @classmethod 12 | @abstractmethod 13 | def modify_options(cls, opt): 14 | pass 15 | -------------------------------------------------------------------------------- /preset/presets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from preset.base_preset import BasePreset 4 | 5 | def modify_options(options): 6 | if options.general.preset: 7 | preset = find_preset_by(options.general.preset) 8 | preset.modify_options(options) 9 | 10 | def find_preset_by(preset_name): 11 | class_name = "Preset{}".format(preset_name) 12 | try: 13 | preset_class = globals()[class_name] 14 | except KeyError as e: 15 | if preset_name: 16 | print("No preset named {}".format(preset_name)) 17 | raise 18 | return preset_class 19 | 20 | ###################### Some Common Settings ################# 21 | joint_loss_weight = 1.0 22 | joint_interm_loss_weight = 0.5 23 | heatmap_loss_weight = 250.0 24 | heatmap_interm_loss_weight = 125 25 | reprojection_loss_weight = 125 26 | 27 | ###################### Pix2Joint ################### 28 | class PresetPix2Joint(BasePreset): 29 | @classmethod 30 | def modify_options(cls, opt): 31 | opt.general.dataset = "synth" 32 | 33 | opt.hpe.norm_type = 'instance' 34 | opt.hpe.img_size = opt.general.img_size 35 | opt.hpe.input_channel = 3 36 | opt.hpe.init_gain = 0.2 37 | opt.hpe.joint_loss_weight = joint_loss_weight 38 | opt.hpe.joint_interm_loss_weight = joint_interm_loss_weight 39 | opt.hpe.heatmap_loss_weight = heatmap_loss_weight 40 | opt.hpe.heatmap_interm_loss_weight = heatmap_interm_loss_weight 41 | opt.hpe.reprojection_loss_weight = reprojection_loss_weight 42 | opt.hpe.lr_policy = "step" 43 | 44 | class PresetPix2JointTrain(PresetPix2Joint): 45 | @classmethod 46 | def modify_options(cls, opt): 47 | 
PresetPix2Joint.modify_options(opt) 48 | opt.hpe.mode = 'train' 49 | 50 | class PresetPix2JointTest(PresetPix2Joint): 51 | @classmethod 52 | def modify_options(cls, opt): 53 | PresetPix2Joint.modify_options(opt) 54 | opt.hpe.mode = 'test' 55 | 56 | class PresetRealPix2JointTrain(PresetPix2Joint): 57 | @classmethod 58 | def modify_options(cls, opt): 59 | PresetPix2Joint.modify_options(opt) 60 | opt.general.dataset = "real" 61 | opt.hpe.mode = 'train' 62 | 63 | class PresetRealPix2JointTest(PresetPix2Joint): 64 | @classmethod 65 | def modify_options(cls, opt): 66 | PresetPix2Joint.modify_options(opt) 67 | opt.general.dataset = "real" 68 | opt.hpe.mode = 'test' 69 | 70 | ###### Pix2depth encoder ##### 71 | class PresetPix2Depth(BasePreset): 72 | @classmethod 73 | def modify_options(cls, opt): 74 | 75 | opt.general.dataset = 'synth' 76 | opt.pix2depth.norm = 'instance' 77 | opt.pix2depth.init_gain = 0.1 78 | opt.pix2depth.net_type = 'resnet_3blocks' 79 | opt.pix2depth.depth_loss_weight = heatmap_loss_weight 80 | opt.pix2depth.output_nc = 1 81 | 82 | opt.pix2depth.joint_loss_weight = joint_loss_weight 83 | opt.pix2depth.joint_interm_loss_weight = joint_interm_loss_weight 84 | opt.pix2depth.heatmap_loss_weight = heatmap_loss_weight 85 | opt.pix2depth.heatmap_interm_loss_weight = heatmap_interm_loss_weight 86 | opt.pix2depth.reprojection_loss_weight = reprojection_loss_weight 87 | 88 | 89 | class PresetPix2DepthTrain(PresetPix2Depth): 90 | @classmethod 91 | def modify_options(cls, opt): 92 | PresetPix2Depth.modify_options(opt) 93 | opt.general.run = 'pix2depth_train' 94 | opt.pix2depth.mode = 'train' 95 | return opt 96 | 97 | class PresetPix2DepthTest(PresetPix2Depth): 98 | @classmethod 99 | def modify_options(cls, opt): 100 | PresetPix2Depth.modify_options(opt) 101 | opt.general.run = 'pix2depth_test' 102 | opt.pix2depth.mode = 'test' 103 | return opt 104 | 105 | class PresetRealPix2DepthTrain(PresetPix2Depth): 106 | @classmethod 107 | def modify_options(cls, opt): 108 | PresetPix2Depth.modify_options(opt) 109 | opt.general.run = 'pix2depth_train' 110 | opt.general.dataset = 'real' 111 | opt.pix2depth.mode = 'train' 112 | opt.pix2depth.train_only_encoder = True 113 | opt.pix2depth.depth_loss_weight = 0 114 | return opt 115 | 116 | ###################### Pipeline ################### 117 | class PresetPipeline(BasePreset): 118 | @classmethod 119 | def modify_options(cls, opt): 120 | opt.general.dataset = "synth" 121 | 122 | opt.hpe.mode = 'train' 123 | opt.hpe.network = 'basic' 124 | opt.hpe.img_size = opt.general.img_size 125 | opt.hpe.input_channel = 4 126 | opt.hpe.init_gain = 1.0 127 | opt.hpe.joint_loss_weight = joint_loss_weight 128 | opt.hpe.joint_interm_loss_weight = joint_interm_loss_weight 129 | opt.hpe.heatmap_loss_weight = heatmap_loss_weight 130 | opt.hpe.heatmap_interm_loss_weight = heatmap_interm_loss_weight 131 | opt.hpe.lr_policy = "step" 132 | opt.hpe.norm_type = 'instance' 133 | 134 | opt.pix2depth.model = 'resnet_encoder' 135 | opt.pix2depth.norm = 'instance' 136 | opt.pix2depth.n_layers_G = 3 137 | opt.pix2depth.netG = 'resnet_3blocks' 138 | opt.pix2depth.netD = 'basic' 139 | opt.pix2depth.init_gain = 0.1 140 | opt.pix2depth.depth_loss_weight = heatmap_loss_weight * 10 141 | 142 | opt.pix2depth.joint_loss_weight = joint_loss_weight 143 | opt.pix2depth.joint_interm_loss_weight = joint_interm_loss_weight 144 | opt.pix2depth.heatmap_loss_weight = heatmap_loss_weight 145 | opt.pix2depth.heatmap_interm_loss_weight = heatmap_interm_loss_weight 146 | 147 | opt.pix2depth.output_nc = 1 
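# Illustrative sketch (hypothetical, not one of the original recipes in this file): a preset name given
# on the command line is resolved by find_preset_by() above, which prepends "Preset" to the name and looks
# the class up in this module's globals, so `--preset PipelineSynthTrain` selects PresetPipelineSynthTrain
# below. A new recipe only needs to subclass an existing preset and override modify_options, for example:
#
#     class PresetPipelineSynthTrainSmall(PresetPipeline):   # hypothetical, not defined in this repo
#         @classmethod
#         def modify_options(cls, opt):
#             PresetPipeline.modify_options(opt)
#             opt.general.run = "pipeline_train"
#             opt.general.img_size = 128   # assumes img_size is a general option, as used above
#
# It would then be selected with `--preset PipelineSynthTrainSmall`.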
148 | 149 | ###################### Pix2Depth2JointLite (pipeline)################### 150 | class PresetPipelineSynthTrain(PresetPipeline): 151 | @classmethod 152 | def modify_options(cls, opt): 153 | PresetPipeline.modify_options(opt) 154 | opt.general.run = "pipeline_train" 155 | 156 | class PresetPipelineSynthTest(PresetPipeline): 157 | @classmethod 158 | def modify_options(cls, opt): 159 | PresetPipeline.modify_options(opt) 160 | opt.general.run = "pipeline_test" 161 | 162 | class PresetPipelineRealTrain(PresetPipelineSynthTrain): 163 | @classmethod 164 | def modify_options(cls, opt): 165 | PresetPipelineSynthTrain.modify_options(opt) 166 | opt.general.dataset = "real" 167 | 168 | class PresetPipelineRealTest(PresetPipelineSynthTest): 169 | @classmethod 170 | def modify_options(cls, opt): 171 | PresetPipelineSynthTest.modify_options(opt) 172 | opt.general.dataset = "real" 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | opencv-python 3 | pip-chill 4 | scipy 5 | tb-nightly 6 | tensorboardx 7 | torchvision 8 | -------------------------------------------------------------------------------- /run/__init__.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from util.package import find_class_using_name 3 | from .base.base_run import BaseRun 4 | 5 | def find_run_using_name(run_name): 6 | subpackage_name = find_subpackage(run_name) 7 | package_name = 'run.{}'.format(subpackage_name) 8 | run_cls = find_class_using_name(package_name, run_name, 'run') 9 | if inspect.isclass(run_cls) and issubclass(run_cls, BaseRun): 10 | return run_cls 11 | 12 | raise Exception("{} is not correctly implemented as a BaseRun subclass".format(run_name)) 13 | 14 | def find_subpackage(run_name): 15 | subpackage_name = run_name.split('_')[:-1] 16 | subpackage_name = '_'.join(subpackage_name) 17 | 18 | return subpackage_name 19 | -------------------------------------------------------------------------------- /run/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-HCIL/DeepFisheyeNet/ecdf9265a0f7c5048ace0636c0ce26260a39f352/run/base/__init__.py -------------------------------------------------------------------------------- /run/base/base_run.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from util.io import * 3 | from util.visualizer import Visualizer 4 | from dataset import * 5 | import random 6 | import torch 7 | import torch.nn as nn 8 | 9 | class BaseRun(ABC): 10 | """ This is the basic abstract Run class. 11 | The run class contains the core algorithm of train or test running. 
12 | """ 13 | 14 | def __init__(self, options): 15 | self.options = options 16 | self.logger = self._get_logger(options) 17 | self.visualizer = self._get_visualizer(self.logger) 18 | 19 | self.gpu_ids = self.options.general.gpu_ids 20 | 21 | def _get_logger(self, options): 22 | logger = Logger(options.general) 23 | logger.save_options(options) 24 | return logger 25 | 26 | def _get_visualizer(self, logger): 27 | tensorboard_path = logger.get_tensorboard_path() 28 | return Visualizer(str(tensorboard_path)) 29 | 30 | def get_train_loader(self): 31 | opt = self.options.general 32 | train_dataset = create_train_dataset(opt) 33 | print("BaseRun: total number of training data: {}".format(len(train_dataset))) 34 | train_loader = create_dataloader(train_dataset, batch_size = opt.batch_size, num_workers = opt.num_workers, shuffle=True) 35 | return train_loader 36 | 37 | def get_test_loader(self, shuffle = False): 38 | opt = self.options.general 39 | test_dataset = create_test_dataset(opt) 40 | 41 | print("BaseRun: total number of test data: {}".format(len(test_dataset))) 42 | test_loader = create_dataloader(test_dataset, batch_size = 1, num_workers = opt.num_workers, shuffle=shuffle) 43 | return test_loader 44 | 45 | @abstractmethod 46 | def setup(self): 47 | pass 48 | 49 | def toss_coin(self): 50 | return random.random() > 0.5 51 | -------------------------------------------------------------------------------- /run/base/hpe_base_util.py: -------------------------------------------------------------------------------- 1 | from dataset.data_model import HandDataModel 2 | 3 | def unpack_data(results, is_eval = False): 4 | joint_out = results['joint'] 5 | heatmap = None 6 | heatmap_true = None 7 | if not is_eval: 8 | heatmap = results['heatmap'] 9 | heatmap_true = results['heatmap_true'] 10 | heatmap_reprojected = results['heatmap_reprojected'] 11 | 12 | return joint_out, heatmap, heatmap_true, heatmap_reprojected 13 | -------------------------------------------------------------------------------- /run/base/hpe_test_base_run.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from abc import abstractmethod 3 | import torch 4 | 5 | from run.base.test_base_run import TestBaseRun 6 | import run.base.hpe_base_util as hpe_util 7 | 8 | from util.image import convert_to_colormap 9 | from util import StatDict, cal_L1_diff, cal_RMS_diff 10 | 11 | class HPETestBaseRun(TestBaseRun): 12 | 13 | def setup(self): 14 | 15 | # move some paramters from the options 16 | self.img_size = self.options.hpe.img_size 17 | 18 | self.model = self.make_model() 19 | 20 | self.heatmap_max = 1 21 | 22 | self.stat_dict = StatDict() 23 | self.no_save_image = self.options.general.no_save_image 24 | 25 | self.euclidean_errors = [] 26 | self.pck_thresholds = np.linspace(0, 0.5, 11) 27 | 28 | @abstractmethod 29 | def make_model(self): 30 | pass 31 | 32 | def end_test(self): 33 | result = {} 34 | result['avg'] = self.stat_dict.get_avg() 35 | result['std'] = self.stat_dict.get_std() 36 | 37 | print(result) 38 | 39 | self.logger.write_loss(result) 40 | pck_results = self._cal_pck(self.euclidean_errors, self.pck_thresholds) 41 | self.logger.write_pck(pck_results) 42 | 43 | def test(self, data, current_iter): 44 | data = self.arrange_data(data) 45 | 46 | self.model.set_input(data) 47 | self.model.forward() 48 | results = self.model.get_detached_current_results() 49 | 50 | joint = data['joint'] 51 | joint_out, heatmap_out, heatmap_true, heatmap_reprojected = 
hpe_util.unpack_data(results, self.model.mode.is_eval()) 52 | 53 | if (not self.no_save_image) and (heatmap_out is not None): 54 | img = results['img'] 55 | if img.size(1) == 1: 56 | img = convert_to_colormap(img, 1) 57 | out_heatmap_img = convert_to_colormap(heatmap_out, self.heatmap_max) 58 | true_heatmap_img = convert_to_colormap(heatmap_true, self.heatmap_max) 59 | reprojected_heatmap_img = convert_to_colormap(heatmap_reprojected, self.heatmap_max) 60 | stacked_img = torch.cat((img, out_heatmap_img, reprojected_heatmap_img, true_heatmap_img), 3) # horizontal_stack 61 | stacked_img = stacked_img.squeeze() 62 | self.save_img(stacked_img) 63 | 64 | losses = {} 65 | losses['Joint L1'] = cal_L1_diff(joint_out, joint, reduction = 'mean') 66 | losses['Joint RMS'] = cal_RMS_diff(joint_out, joint, reduction = 'mean') 67 | 68 | if not self.model.mode.is_eval(): 69 | losses['Heatmap L1'] = cal_L1_diff(heatmap_out, heatmap_true, reduction = 'mean') 70 | losses['Heatmap RMS'] = cal_RMS_diff(heatmap_out, heatmap_true, reduction = 'mean') 71 | 72 | self.stat_dict.add(losses) 73 | 74 | euclidean_error = self._cal_eucliean_error(joint_out, joint) 75 | self.euclidean_errors.append(euclidean_error) 76 | 77 | @abstractmethod 78 | def arrange_data(self, data): 79 | pass 80 | 81 | def _cal_eucliean_error(self, joint1, joint2): 82 | diff = joint1 - joint2 83 | 84 | diff = diff.numpy() 85 | diff = diff.reshape((-1, 3)) 86 | error = np.linalg.norm(diff, axis = 1) 87 | 88 | return error 89 | 90 | def _cal_pck(self, error, thresholds): 91 | results = [] 92 | for thrs in thresholds: 93 | acc = np.mean(error < thrs) 94 | results.append((thrs, acc)) 95 | 96 | return results 97 | -------------------------------------------------------------------------------- /run/base/hpe_train_base_run.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from .train_base_run import TrainBaseRun 3 | 4 | from util.io import LossLog 5 | from util import Timer 6 | 7 | from util.image import * 8 | import torch 9 | 10 | import run.base.hpe_base_util as hpe_util 11 | from dataset.data_model import HandDataModel 12 | 13 | class HPETrainBaseRun(TrainBaseRun): 14 | 15 | def setup(self): 16 | super().setup() 17 | 18 | self.img_size = self.options.hpe.img_size 19 | self.speed_diagnose = self.options.general.speed_diagnose 20 | 21 | self.model = self.make_model() 22 | self.heatmap_max = 1 23 | 24 | self.last_results = None 25 | self.timer = Timer() 26 | 27 | @abstractmethod 28 | def make_model(self): 29 | pass 30 | 31 | def iterate(self, data): 32 | if self.speed_diagnose: 33 | self.timer.start('preprocess') 34 | 35 | data = self.arrange_data(data) 36 | 37 | if self.speed_diagnose: 38 | self.timer.stop('preprocess') 39 | self.timer.start('setting input') 40 | 41 | self.model.set_input(data) 42 | 43 | if self.speed_diagnose: 44 | self.timer.stop('setting input') 45 | self.timer.start('optimize') 46 | 47 | self.model.optimize() 48 | if self.speed_diagnose: 49 | self.timer.stop('optimize') 50 | self.timer.print_elapsed_times() 51 | 52 | self.avg_dict.add(self.model.get_current_losses()) 53 | 54 | # save the result for visualization 55 | self.last_results = self.model.get_detached_current_results() 56 | self.last_data = data 57 | 58 | def save_checkpoint(self, epoch): 59 | checkpoint = self.model.pack_as_checkpoint() 60 | self.logger.save_checkpoint(checkpoint, epoch) 61 | 62 | def end_epoch(self): 63 | pass 64 | 65 | @abstractmethod 66 | def arrange_data(self, data): 67 | """ 
reshape the data for the model. """ 68 | 69 | def _visualize_results_as_image(self, results, cur_iter): 70 | 71 | if results is None: 72 | return 73 | 74 | results = self._select_first_in_batch(results) 75 | img = results['img'] 76 | joint_out, heatmap_out, heatmap_true, heatmap_reprojected = hpe_util.unpack_data(results) 77 | 78 | out_heatmap_img = convert_to_colormap(heatmap_out, 1.0) 79 | true_heatmap_img = convert_to_colormap(heatmap_true, 1.0) 80 | reprojected_heatmap_img = convert_to_colormap(heatmap_reprojected, 1.0) 81 | img = expand_channel(img) 82 | 83 | stacked_img = torch.cat((img, out_heatmap_img, reprojected_heatmap_img, true_heatmap_img), 3) # horizontal_stack 84 | self.visualizer.add_image('train sample', stacked_img, cur_iter) 85 | 86 | def _visualize_network_grad(self, epoch, current_iter): 87 | grads = self.model.get_grads() 88 | for tag, val in grads.items(): 89 | self.visualizer.add_histogram(tag, val, epoch) 90 | -------------------------------------------------------------------------------- /run/base/test_base_run.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from .base_run import BaseRun 3 | 4 | class TestBaseRun(BaseRun): 5 | 6 | @abstractmethod 7 | def test(self, data, current_iter): 8 | """ Runs at every iteration. """ 9 | 10 | def save_img(self, img): 11 | if self.options.general.no_save_image: 12 | return 13 | self.logger.save_image_tensor(img) 14 | 15 | @abstractmethod 16 | def end_test(self): 17 | """ Runs once at the end of the whole test. """ 18 | -------------------------------------------------------------------------------- /run/base/train_base_run.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from .base_run import BaseRun 3 | 4 | from util import AverageDict 5 | from util.io import LossLog 6 | 7 | class TrainBaseRun(BaseRun): 8 | 9 | def setup(self): 10 | super().setup() 11 | 12 | self.avg_dict = AverageDict() 13 | # it is for intermediate visualizations 14 | self.last_results = None 15 | self.last_data = None 16 | self.show_grad = self.options.general.show_grad 17 | 18 | @abstractmethod 19 | def iterate(self, data): 20 | """ Runs at every iteration. """ 21 | pass 22 | 23 | @abstractmethod 24 | def end_epoch(self): 25 | """ Runs at every end of epoch. """ 26 | pass 27 | 28 | @abstractmethod 29 | def save_checkpoint(self, epoch): 30 | pass 31 | 32 | def log_and_visualize_iteration(self, epoch, current_iter): 33 | self._log_and_vis_scalar(epoch, current_iter) 34 | self._visualize_results_as_image(self.last_results, current_iter) 35 | if self.show_grad: 36 | self._visualize_network_grad(epoch, current_iter) 37 | 38 | def _log_and_vis_scalar(self, epoch, current_iter): 39 | losses = self.avg_dict.to_dict() 40 | self.avg_dict.reset() 41 | loss_log = LossLog(losses, epoch, current_iter, 'train') 42 | self.logger.write_loss(loss_log) 43 | self.visualizer.add_losses('train', losses, current_iter) 44 | 45 | @abstractmethod 46 | def _visualize_results_as_image(self, results, current_iter): 47 | """ This can be different by a subclasse's purpose. """ 48 | pass 49 | 50 | @abstractmethod 51 | def _visualize_network_grad(self, epoch, current_iter): 52 | pass 53 | 54 | def _select_first_in_batch(self, results): 55 | first_results = {} 56 | 57 | for k, v in results.items(): 58 | first_results[k] = v[0].unsqueeze(0) 59 | if 'interm' in k: 60 | # interm results are in list. select the last one. 
61 | first_results[k] = v[-1][0].unsqueeze(0) 62 | 63 | return first_results 64 | -------------------------------------------------------------------------------- /run/empty_run.py: -------------------------------------------------------------------------------- 1 | from run.base.base_run import BaseRun 2 | 3 | class EmptyRun(BaseRun): 4 | def setup(self): 5 | pass 6 | -------------------------------------------------------------------------------- /run/pipeline/helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def unpack_data(results): 4 | pix = results['input'] 5 | hand_pix = results['hand_pix'] 6 | fake_fish_depth = results['fake_fish_depth'] 7 | heatmap = results['heatmap'] 8 | heatmap_true = results['heatmap_true'] 9 | heatmap_reprojected = results['heatmap_reprojected'] 10 | joint = results['joint'] 11 | 12 | return pix, hand_pix, fake_fish_depth, heatmap, heatmap_true, heatmap_reprojected, joint 13 | -------------------------------------------------------------------------------- /run/pipeline/pipeline_test_run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from run.base.test_base_run import TestBaseRun 5 | import run.pipeline.helper as helper 6 | from model.pipeline.deep_fisheye_pipeline import DeepFisheyePipeline 7 | from util import StatDict, cal_L1_diff, cal_RMS_diff 8 | from util.image import convert_to_colormap 9 | 10 | class PipelineTestRun(TestBaseRun): 11 | def setup(self): 12 | super().setup() 13 | 14 | self.model = DeepFisheyePipeline(self.options) 15 | self.model.setup() 16 | 17 | self.stat_dict = StatDict() 18 | 19 | self.last_results = None 20 | 21 | self.euclidean_errors = [] 22 | self.pck_thresholds = np.linspace(0, 0.5, 11) 23 | 24 | def test(self, data, current_iter): 25 | 26 | self.model.set_input(data) 27 | self.model.forward() 28 | 29 | results = self.model.get_detached_current_results() 30 | 31 | joint_true = data['joint'].detach().cpu() 32 | pix, hand_pix, fake_fish_depth, heatmap, heatmap_true, heatmap_reprojected, joint = helper.unpack_data(results) 33 | fake_fish_depth_img = convert_to_colormap(fake_fish_depth) 34 | out_heatmap_img = convert_to_colormap(heatmap) 35 | true_heatmap_img = convert_to_colormap(heatmap_true) 36 | 37 | stacked_img = torch.cat((pix, fake_fish_depth_img, out_heatmap_img, true_heatmap_img), 3) 38 | stacked_img = stacked_img.squeeze() 39 | self.save_img(stacked_img) 40 | 41 | losses = {} 42 | losses['heatmap L1'] = cal_L1_diff(heatmap, heatmap_true, reduction = 'mean') 43 | losses['heatmap RMS'] = cal_RMS_diff(heatmap, heatmap_true, reduction = 'mean') 44 | losses['joint L1'] = cal_L1_diff(joint, joint_true, reduction = 'mean') 45 | losses['joint RMS'] = cal_RMS_diff(joint, joint_true, reduction = 'mean') 46 | 47 | self.stat_dict.add(losses) 48 | 49 | euclidean_error = self._cal_eucliean_error(joint, joint_true) 50 | self.euclidean_errors.append(euclidean_error) 51 | 52 | def end_test(self): 53 | result = {} 54 | result['avg'] = self.stat_dict.get_avg() 55 | result['std'] = self.stat_dict.get_std() 56 | 57 | print(result) 58 | 59 | self.logger.write_loss(result) 60 | self.euclidean_errors = np.array(self.euclidean_errors) 61 | pck_results = self._cal_pck(self.euclidean_errors, self.pck_thresholds) 62 | self.logger.write_pck(pck_results) 63 | 64 | def _cal_eucliean_error(self, joint1, joint2): 65 | diff = joint1 - joint2 66 | 67 | diff = diff.numpy() 68 | diff = diff.reshape((-1, 3)) 
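# each row of diff is now the (x, y, z) offset of one joint, so the row-wise L2 norm
# below gives a single Euclidean error value per joint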
69 | error = np.linalg.norm(diff, axis = 1) 70 | 71 | return error 72 | 73 | def _cal_pck(self, error, thresholds): 74 | results = [] 75 | for thrs in thresholds: 76 | acc = np.mean(error < thrs) 77 | results.append((thrs, acc)) 78 | 79 | return results 80 | -------------------------------------------------------------------------------- /run/pipeline/pipeline_train_run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from run.base.train_base_run import TrainBaseRun 4 | import run.pipeline.helper as helper 5 | from model.pipeline.deep_fisheye_pipeline import DeepFisheyePipeline 6 | from util import AverageDict 7 | from util.io import LossLog 8 | from util.image import convert_to_colormap 9 | 10 | class PipelineTrainRun(TrainBaseRun): 11 | def setup(self): 12 | super().setup() 13 | 14 | self.pipeline = DeepFisheyePipeline(self.options) 15 | self.pipeline.setup() 16 | 17 | self.avg_dict = AverageDict() 18 | 19 | self.last_results = None 20 | 21 | def iterate(self, data): 22 | self.pipeline.set_input(data) 23 | self.pipeline.optimize() 24 | self.avg_dict.add(self.pipeline.get_current_losses()) 25 | 26 | # save the result for visualization 27 | self.last_results = self.pipeline.get_detached_current_results() 28 | 29 | def _visualize_results_as_image(self, results, cur_iter): 30 | if results is None: 31 | return 32 | 33 | results = self._select_first_in_batch(results) 34 | pix, hand_pix, fake_fish_depth, heatmap, heatmap_true, heatmap_reprojected, joint = helper.unpack_data(results) 35 | 36 | fake_fish_depth_img = convert_to_colormap(fake_fish_depth) 37 | out_heatmap_img = convert_to_colormap(heatmap) 38 | true_heatmap_img = convert_to_colormap(heatmap_true) 39 | reprojected_img = convert_to_colormap(heatmap_reprojected) 40 | 41 | stacked_img = torch.cat((pix, fake_fish_depth_img, hand_pix, out_heatmap_img, reprojected_img, true_heatmap_img), 3) 42 | self.visualizer.add_image('train sample', stacked_img, cur_iter) 43 | 44 | def _visualize_network_grad(self, epoch, current_iter): 45 | grads = self.pipeline.get_grads() 46 | for tag, val in grads.items(): 47 | self.visualizer.add_histogram(tag, val, epoch) 48 | 49 | def save_checkpoint(self, epoch): 50 | checkpoint = self.pipeline.pack_as_checkpoint() 51 | self.logger.save_checkpoint(checkpoint, epoch) 52 | 53 | def end_epoch(self): 54 | pass 55 | -------------------------------------------------------------------------------- /run/pix2depth/helper.py: -------------------------------------------------------------------------------- 1 | def arrange_data(data): 2 | _data = {'pix': data['fish']} 3 | if 'fish_depth' in data: 4 | _data['depth'] = data['fish_depth'] 5 | 6 | if 'joint' in data: 7 | _data['joint'] = data['joint'] 8 | 9 | return _data 10 | -------------------------------------------------------------------------------- /run/pix2depth/pix2depth_train_run.py: -------------------------------------------------------------------------------- 1 | from run.base.train_base_run import TrainBaseRun 2 | import run.pix2depth.helper as helper 3 | from util.io import LossLog 4 | from util import AverageDict 5 | from util.image import convert_to_colormap 6 | from util.projector import FisheyeProjector 7 | from model.pix2depth_model import Pix2DepthModel 8 | 9 | import torch 10 | 11 | class Pix2DepthTrainRun(TrainBaseRun): 12 | 13 | def setup(self): 14 | super().setup() 15 | 16 | opt = self.options.pix2depth 17 | self.img_size = opt.img_size 18 | self.model = Pix2DepthModel(opt, self.gpu_ids) 19 
| 20 | projector = self.get_projector() 21 | self.model.setup(projector) 22 | 23 | self.last_results = None 24 | 25 | def get_projector(self): 26 | return FisheyeProjector(self.img_size) 27 | 28 | def iterate(self, data): 29 | data = helper.arrange_data(data) 30 | self.model.set_input(data) 31 | self.model.optimize() 32 | self.avg_dict.add(self.model.get_current_losses()) 33 | 34 | # save the result for visualization 35 | self.last_results = self.model.get_detached_current_results() 36 | 37 | def log_and_visualize_iteration(self, epoch, current_iter): 38 | self._log_and_vis_scalar(epoch, current_iter) 39 | self._visualize_results_as_image(self.last_results, current_iter) 40 | 41 | def _visualize_results_as_image(self, results, cur_iter): 42 | if results is None: 43 | return 44 | 45 | results = self._select_first_in_batch(results) 46 | pix = results['pix'] 47 | fake_depth = convert_to_colormap(results['fake_depth'], 1.0) 48 | if 'depth' in results: 49 | depth = convert_to_colormap(results['depth'], 1.0) 50 | else: 51 | depth = torch.zeros(fake_depth.shape) 52 | 53 | if 'heatmap_interms' in results: 54 | last_heatmap = convert_to_colormap(results['heatmap_interms'], 1.0) 55 | else: 56 | last_heatmap = torch.zeros(fake_depth.shape) 57 | 58 | img = torch.cat((pix, last_heatmap, fake_depth, depth), 3) # stack horizontally 59 | 60 | self.visualizer.add_image('train sample', img, cur_iter) 61 | 62 | def save_checkpoint(self, epoch): 63 | checkpoint = self.model.pack_as_checkpoint() 64 | self.logger.save_checkpoint(checkpoint, epoch) 65 | 66 | def end_epoch(self): 67 | pass 68 | 69 | def _visualize_network_grad(self, epoch, current_iter): 70 | grads = self.model.get_grads() 71 | for tag, val in grads.items(): 72 | self.visualizer.add_histogram(tag, val, epoch) 73 | -------------------------------------------------------------------------------- /run/pix2joint/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-HCIL/DeepFisheyeNet/ecdf9265a0f7c5048ace0636c0ce26260a39f352/run/pix2joint/__init__.py -------------------------------------------------------------------------------- /run/pix2joint/pix2joint_test_run.py: -------------------------------------------------------------------------------- 1 | from run.base.hpe_test_base_run import HPETestBaseRun 2 | from model.hpe_model import HPEModel 3 | from util.projector import FisheyeProjector 4 | 5 | class Pix2JointTestRun(HPETestBaseRun): 6 | 7 | def make_model(self): 8 | opt = self.options.hpe 9 | model = HPEModel(opt, self.gpu_ids) 10 | model.setup() 11 | return model 12 | 13 | def arrange_data(self, data): 14 | _data = {'img': data['fish'], \ 15 | 'joint': data['joint']} 16 | return _data 17 | -------------------------------------------------------------------------------- /run/pix2joint/pix2joint_train_run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from model.pipeline import find_pipeline_using_name 4 | from util.io import LossLog 5 | from model.hpe_model import HPEModel 6 | from run.base.hpe_train_base_run import HPETrainBaseRun 7 | 8 | class Pix2JointTrainRun(HPETrainBaseRun): 9 | 10 | def make_model(self): 11 | opt = self.options.hpe 12 | model = HPEModel(opt, self.gpu_ids) 13 | model.setup() 14 | return model 15 | 16 | def arrange_data(self, data): 17 | fish_data = data['fish'] 18 | _data = {'img': fish_data, \ 19 | 'joint': data['joint']} 20 | return _data 21 | 
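# Usage sketch (illustrative only): run classes such as Pix2JointTrainRun are not executed directly.
# train.py resolves them by name through run.find_run_using_name() and drives them roughly as follows:
#
#     options = Options(); options.initialize(); modify_options(options); options.parse()
#     run_cls = find_run_using_name(options.general.run)   # e.g. 'pix2joint_train' -> Pix2JointTrainRun
#     run = run_cls(options)
#     run.setup()
#     for data in run.get_train_loader():
#         run.iterate(data)
#
# See train.py later in this repository for the full loop with logging and checkpointing.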
-------------------------------------------------------------------------------- /run_board.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | rm -r tensorboard_logs 3 | mkdir tensorboard_logs 4 | tensorboard --logdir tensorboard_logs --port 8008 --bind_all 5 | -------------------------------------------------------------------------------- /scripts/check_synth_dataset_stat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python check_dataset_stat.py --dataset synth 3 | -------------------------------------------------------------------------------- /scripts/pipeline/test_pipeline_synth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python test.py \ 3 | --preset PipelineSynthTest \ 4 | --name test_pipeline_synth \ 5 | --batch_size 1 \ 6 | --gpu_ids 0 \ 7 | --num_workers 0 \ 8 | --pipeline_pretrained path/to/weight.pth.tar \ 9 | --no_save_image \ 10 | -------------------------------------------------------------------------------- /scripts/pipeline/train_pipeline_real.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --name pipeline_real_train \ 4 | --run pipeline_train \ 5 | --preset PipelineRealTrain \ 6 | --batch_size 16 \ 7 | --gpu_ids 0,1,2 \ 8 | --num_workers 2 \ 9 | --save_epoch 4 \ 10 | --p2d_lr 0.00001 \ 11 | --hpe_lr 0.0001 \ 12 | --epoch 20 \ 13 | --print_iter 10 \ 14 | --pipeline_pretrained path/to/weight.pth.tar \ 15 | -------------------------------------------------------------------------------- /scripts/pipeline/train_pipeline_real_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --name pipeline_real_local_train \ 4 | --run pipeline_train \ 5 | --preset PipelineRealTrain \ 6 | --batch_size 1 \ 7 | --gpu_ids 0 \ 8 | --max_data 20 \ 9 | --num_workers 0 \ 10 | --save_epoch 1 \ 11 | --p2d_lr 0.00002 \ 12 | --hpe_lr 0.001 \ 13 | --epoch 30 \ 14 | --print_iter 10 \ 15 | --pipeline_pretrained path/to/weight.pth.tar \ 16 | -------------------------------------------------------------------------------- /scripts/pipeline/train_pipeline_synth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --name pipeline_synth_train \ 4 | --run pipeline_train \ 5 | --preset PipelineSynthTrain \ 6 | --batch_size 16 \ 7 | --gpu_ids 0,1,2 \ 8 | --num_workers 2 \ 9 | --save_epoch 2 \ 10 | --p2d_lr 0.0001 \ 11 | --hpe_lr 0.01 \ 12 | --epoch 6 \ 13 | --print_iter 10 \ 14 | --p2d_pretrained path/to/weight.pth.tar \ 15 | -------------------------------------------------------------------------------- /scripts/pipeline/train_pipeline_synth_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --name pipeline_synth_local_train \ 4 | --run pipeline_train \ 5 | --preset PipelineSynthTrain \ 6 | --batch_size 1 \ 7 | --gpu_ids 0 \ 8 | --max_data 20 \ 9 | --num_workers 0 \ 10 | --save_epoch 100 \ 11 | --p2d_lr 0.0005 \ 12 | --hpe_lr 0.005 \ 13 | --epoch 30 \ 14 | --p2d_pretrained path/to/weight.pth.tar \ 15 | -------------------------------------------------------------------------------- /scripts/pix2depth/train_pix2depth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python 
train.py \ 3 | --preset Pix2DepthTrain \ 4 | --name pix2depth_train \ 5 | --epoch 12 \ 6 | --batch_size 32 \ 7 | --gpu_ids 0,1,2 \ 8 | --print_iter 10 \ 9 | --num_workers 2 \ 10 | --save_epoch 2 \ 11 | --p2d_lr 0.001 \ 12 | -------------------------------------------------------------------------------- /scripts/pix2depth/train_pix2depth_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --preset Pix2DepthTrain \ 4 | --name pix2depth_train_local \ 5 | --max_data 20\ 6 | --epoch 6 \ 7 | --batch_size 2 \ 8 | --gpu_ids 0 \ 9 | --print_iter 10 \ 10 | --num_workers 2 \ 11 | --save_epoch 1 \ 12 | --p2d_lr 0.0002 \ 13 | -------------------------------------------------------------------------------- /scripts/pix2joint/test_pix2joint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python test.py \ 3 | --name pix2joint_test \ 4 | --run pix2joint_test \ 5 | --preset Pix2JointTest \ 6 | --batch_size 1 \ 7 | --gpu_ids 0 \ 8 | --num_workers 0 \ 9 | --hpe_network basic \ 10 | --hpe_pretrained path/to/weight.pth.tar \ 11 | --no_save_image \ 12 | --max_data 5000 13 | -------------------------------------------------------------------------------- /scripts/pix2joint/test_real_pix2joint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python test.py \ 3 | --name real_pix2joint_test \ 4 | --run pix2joint_test \ 5 | --preset RealPix2JointTest \ 6 | --batch_size 1 \ 7 | --gpu_ids 0 \ 8 | --num_workers 0 \ 9 | --hpe_network basic \ 10 | --hpe_pretrained path/to/weight.pth.tar \ 11 | --no_save_image 12 | -------------------------------------------------------------------------------- /scripts/pix2joint/train_pix2joint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --name pix2joint_train \ 4 | --run pix2joint_train \ 5 | --preset Pix2JointTrain \ 6 | --batch_size 16 \ 7 | --print_iter 10 \ 8 | --gpu_ids 0,1,2 \ 9 | --num_workers 16 \ 10 | --save_epoch 2 \ 11 | --epoch 6 \ 12 | --hpe_network basic \ 13 | --hpe_lr 0.01 \ 14 | #--show_grad 15 | -------------------------------------------------------------------------------- /scripts/pix2joint/train_pix2joint_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --name pix2joint_local_train \ 4 | --run pix2joint_train \ 5 | --preset Pix2JointTrain \ 6 | --batch_size 3 \ 7 | --gpu_ids 0 \ 8 | --max_data 500 \ 9 | --num_workers 0 \ 10 | --save_epoch 1 \ 11 | --hpe_network basic \ 12 | --hpe_lr 0.1 \ 13 | -------------------------------------------------------------------------------- /scripts/pix2joint/train_real_pix2joint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --name fish_pix2joint_train \ 4 | --run pix2joint_train \ 5 | --preset RealPix2JointTrain \ 6 | --batch_size 16 \ 7 | --print_iter 10 \ 8 | --gpu_ids 0,1,2 \ 9 | --num_workers 4 \ 10 | --save_epoch 4 \ 11 | --hpe_lr 0.0001 \ 12 | --epoch 12 \ 13 | --hpe_network basic \ 14 | --hpe_pretrained path/to/weight.pth.tar \ 15 | -------------------------------------------------------------------------------- /scripts/pix2joint/train_real_pix2joint_local.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --name 
pix2joint_local_train \ 4 | --run pix2joint_train \ 5 | --preset RealPix2JointTrain \ 6 | --batch_size 3 \ 7 | --print_iter 10 \ 8 | --gpu_ids 0 \ 9 | --max_data 500 \ 10 | --num_workers 0 \ 11 | --save_epoch 1 \ 12 | --hpe_lr 0.0005 \ 13 | --epoch 12 \ 14 | --hpe_network basic \ 15 | --hpe_pretrained path/to/weight.pth.tar \ 16 | -------------------------------------------------------------------------------- /technical_concept.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-HCIL/DeepFisheyeNet/ecdf9265a0f7c5048ace0636c0ce26260a39f352/technical_concept.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from option.options import Options 2 | from preset import modify_options 3 | from dataset import * 4 | from run import find_run_using_name 5 | 6 | def main(): 7 | options = Options() 8 | options.initialize() 9 | modify_options(options) 10 | options.parse() 11 | 12 | print(options.pretty_str()) 13 | 14 | run_cls = find_run_using_name(options.general.run) 15 | run = run_cls(options) 16 | 17 | test_loader = run.get_test_loader(shuffle = True) 18 | 19 | run.setup() 20 | 21 | for i, data in enumerate(test_loader): 22 | print("testing {}th data".format(i)) 23 | run.test(data, i) 24 | 25 | run.end_test() 26 | 27 | if __name__ == '__main__': 28 | main() 29 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from option.options import Options 2 | from preset import modify_options 3 | from dataset import * 4 | from run import find_run_using_name 5 | 6 | def iter_to_epoch(cur_iter, batch_size, num_data): 7 | print(cur_iter, batch_size, num_data) 8 | return float(cur_iter * batch_size) / num_data 9 | 10 | def epoch_to_iter(epoch, total_epoch, total_iter): 11 | return (epoch / total_epoch) * total_iter 12 | 13 | def main(): 14 | options = Options() 15 | options.initialize() 16 | modify_options(options) 17 | options.parse() 18 | 19 | print(options.pretty_str()) 20 | run_cls = find_run_using_name(options.general.run) 21 | run = run_cls(options) 22 | 23 | train_loader = run.get_train_loader() 24 | num_iter = len(train_loader) 25 | 26 | general_opt = options.general 27 | cur_iter = 0 28 | 29 | run.setup() 30 | 31 | for epoch in range(1, general_opt.epoch+1): 32 | for i, data in enumerate(train_loader): 33 | cur_iter += 1 34 | run.iterate(data) 35 | 36 | if cur_iter % general_opt.print_iter == 0: 37 | float_epoch = cur_iter / num_iter 38 | run.log_and_visualize_iteration(epoch, cur_iter) 39 | print("training progress: {}/{}".format(float_epoch, general_opt.epoch)) 40 | 41 | if epoch % general_opt.save_epoch == 0: 42 | run.save_checkpoint(epoch) 43 | print("checkpoint saved at {}th epoch".format(epoch)) 44 | 45 | run.end_epoch() 46 | 47 | if __name__ == '__main__': 48 | main() 49 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, OrderedDict 2 | import torch 3 | import numpy as np 4 | 5 | def cal_L1_diff(estimated, real, reduction = 'mean', dim = []): 6 | diff = (estimated - real).abs() 7 | if reduction == 'mean': 8 | return diff.mean(dim = dim) 9 | elif reduction == 'sum': 10 | return diff.sum(dim = dim) 11 | 
12 | raise NotImplementedError("reduction should be either 'mean' or 'sum'") 13 | 14 | def cal_RMS_diff(estimated, real, reduction = 'mean', dim = []): 15 | diff = estimated - real 16 | if reduction == 'mean': 17 | return (diff ** 2).mean(dim = dim).sqrt() 18 | elif reduction == 'sum': 19 | return (diff ** 2).sum(dim = dim).sqrt() 20 | 21 | raise NotImplementedError("reduction should be either 'mean' or 'sum'") 22 | 23 | class StatDict: 24 | def __init__(self): 25 | self.data_group = defaultdict(list) 26 | 27 | def add(self, losses): 28 | for k, v in losses.items(): 29 | v = self._if_single_then_multi_dim(v) 30 | self.data_group[k].append(v) 31 | 32 | def _if_single_then_multi_dim(self, tensor): 33 | if len(tensor.shape) == 1: 34 | return tensor.unsqueeze(0) 35 | 36 | return tensor 37 | 38 | def get_avg(self): 39 | avg = OrderedDict() 40 | for k, v in self.data_group.items(): 41 | combined = torch.stack(v, 0) 42 | avg[k] = torch.mean(combined, dim = [0]) 43 | return avg 44 | 45 | def get_std(self): 46 | deviation = OrderedDict() 47 | for k, v in self.data_group.items(): 48 | combined = torch.stack(v, 0) 49 | deviation[k] = torch.std(combined, dim = [0]) 50 | 51 | return deviation 52 | 53 | 54 | class AverageDict: 55 | def __init__(self): 56 | self.meters = {} 57 | 58 | def add(self, losses): 59 | for k, v in losses.items(): 60 | if k not in self.meters: 61 | self.meters[k] = AverageMeter() 62 | self.meters[k].update(v) 63 | 64 | def to_dict(self): 65 | result = {} 66 | for k, m in self.meters.items(): 67 | result[k] = m.avg 68 | return result 69 | 70 | def reset(self): 71 | for k, m in self.meters.items(): 72 | m.reset() 73 | 74 | class AverageMeter: 75 | def __init__(self): 76 | self.reset() 77 | 78 | def reset(self): 79 | self.val = 0 80 | self.avg = 0 81 | self.sum = 0 82 | self.count = 0 83 | 84 | def update(self, val, n=1): 85 | self.val = val 86 | self.sum += (val * n) 87 | self.count += n 88 | self.avg = self.sum / self.count 89 | 90 | class Timer: 91 | def __init__(self): 92 | self.stopwatches = {} 93 | 94 | def start(self, key): 95 | 96 | if not key in self.stopwatches: 97 | self.stopwatches[key] = CudaStopwatch() 98 | 99 | self.stopwatches[key].start() 100 | 101 | def stop(self, key): 102 | self.stopwatches[key].stop() 103 | 104 | def print_elapsed_times(self): 105 | for key, sw in self.stopwatches.items(): 106 | print("{}: {} ms".format(key, sw.get_elapsed_time())) 107 | 108 | class CudaStopwatch: 109 | def __init__(self): 110 | self.start_event = torch.cuda.Event(enable_timing = True) 111 | self.end_event = torch.cuda.Event(enable_timing = True) 112 | 113 | def start(self): 114 | self.start_event.record() 115 | 116 | def stop(self): 117 | self.end_event.record() 118 | torch.cuda.synchronize() 119 | 120 | def get_elapsed_time(self): 121 | return self.start_event.elapsed_time(self.end_event) 122 | -------------------------------------------------------------------------------- /util/debug.py: -------------------------------------------------------------------------------- 1 | def viz_grad_sum(name, abs = True): 2 | def hook(grad): 3 | 4 | if abs: 5 | print(name, "(abs sum):", grad.abs().sum()) 6 | else: 7 | print(name, "(sum):", grad.sum()) 8 | 9 | return hook 10 | 11 | def viz_grad_mean(name, abs = True): 12 | def hook(grad): 13 | if abs: 14 | print(name, "(abs mean):", grad.abs().mean(), grad.shape) 15 | else: 16 | print(name, "(mean):", grad.mean()) 17 | return hook 18 | 19 | def viz_grad(name): 20 | def hook(grad): 21 | print(name, grad) 22 | return hook 23 | 
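# Usage sketch (illustrative only): the helpers above build gradient hooks for debugging.
# They can be attached to any tensor that requires grad via PyTorch's Tensor.register_hook, e.g.:
#
#     heatmap = model(img)                               # hypothetical intermediate tensor
#     heatmap.register_hook(viz_grad_mean('heatmap'))    # prints mean |grad| during backward()
#     heatmap.register_hook(nan_to_zero_hook)            # from util.hooks, zeroes NaN gradients
#
# Each hook fires when loss.backward() reaches that tensor.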
-------------------------------------------------------------------------------- /util/filter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import scipy.ndimage 3 | import torch 4 | import numpy as np 5 | 6 | class GaussianFilter(nn.Module): 7 | def __init__(self, channels, kernel_size, sigma, peak_to_one = False): 8 | super().__init__() 9 | padding = int(kernel_size/2) 10 | self.pad = nn.ZeroPad2d(padding) 11 | 12 | kernel = self._make_gaussian_kernel(kernel_size, sigma, peak_to_one) 13 | self.kernel_max = kernel.max() 14 | self.conv = self._define_conv(channels, kernel, kernel_size) 15 | self.conv.weight.requires_grad = False 16 | 17 | def forward(self, x): 18 | x = self.pad(x) 19 | x = self.conv(x) 20 | return x 21 | 22 | def _define_conv(self, channels, kernel, kernel_size): 23 | conv = nn.Conv2d(channels, channels, groups = channels, kernel_size = kernel_size, padding = 0, stride = 1, bias = False) 24 | conv.weight.data.copy_(kernel) 25 | return conv 26 | 27 | def _make_gaussian_kernel(self, kernel_size, sigma, peak_to_one): 28 | g_kernel = np.zeros((kernel_size, kernel_size)).astype(np.float64) 29 | center = int(kernel_size / 2) 30 | g_kernel[center, center] = 1 31 | g_kernel = scipy.ndimage.gaussian_filter(g_kernel, sigma) 32 | if peak_to_one: 33 | g_kernel = g_kernel / g_kernel.max() 34 | return torch.from_numpy(g_kernel) 35 | -------------------------------------------------------------------------------- /util/fisheye.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | def get_focal_length(max_radius, fisheye_type): 6 | """ Calculate focal length of a 180 degree fov fisheye camera, which means 7 | theta is pi/2 at the max radius from the center of the image. 
8 | """ 9 | if fisheye_type == 'orthographic': 10 | return max_radius 11 | elif fisheye_type == 'equidistant': 12 | return (max_radius * 2 / math.pi) 13 | else: 14 | raise ValueError("fisheye type {} is not implemented".format(fish_type)) 15 | 16 | def r_function(max_radius, theta, fisheye_type, use_np = False): 17 | f = get_focal_length(max_radius, fisheye_type) 18 | 19 | math_module = torch 20 | if use_np: 21 | math_module = np 22 | 23 | if fisheye_type == 'orthographic': 24 | return f * math_module.sin(theta) 25 | elif fisheye_type == 'equidistant': 26 | return f * theta 27 | else: 28 | raise ValueError("fisheye type {} is not implemented".format(fish_type)) 29 | 30 | def inverse_r_function(max_radius, r, fish_type, use_np = False): 31 | f = get_focal_length(max_radius, fish_type) 32 | 33 | math_module = torch 34 | if use_np: 35 | math_module = np 36 | 37 | if fish_type == 'orthographic': 38 | return math_module.arcsin(r/f) 39 | elif fish_type == 'equidistant': 40 | return r/f 41 | else: 42 | raise ValueError("fisheye type {} is not implemented".format(fish_type)) 43 | 44 | def make_theta_phi_meshgrid(img_shape, fisheye_type): 45 | height, width = img_shape 46 | v = np.linspace(0, height, height) 47 | u = np.linspace(0, width, width) 48 | 49 | vv, uu = np.meshgrid(v, u) 50 | vv = vv-height/2 51 | uu = uu-width/2 52 | 53 | r = np.sqrt(np.power(vv, 2) + np.power(uu, 2)) 54 | max_radius = width / 2 55 | theta = inverse_r_function(max_radius, r, fisheye_type, use_np = True) 56 | theta = np.nan_to_num(theta, nan=0) 57 | 58 | phi = np.arctan2(uu, vv) 59 | return torch.from_numpy(theta), torch.from_numpy(phi) 60 | -------------------------------------------------------------------------------- /util/hooks.py: -------------------------------------------------------------------------------- 1 | def nan_to_zero_hook(grad): 2 | grad[grad != grad] = 0 # NaN != NaN 3 | return grad 4 | -------------------------------------------------------------------------------- /util/image.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torchvision 4 | import torch 5 | from PIL import Image 6 | 7 | ############### Colormap ####################### 8 | 9 | def convert_to_colormap(heatmap, max_value = 1.0): 10 | heatmap = scale_to_make_visible(heatmap, max_value) 11 | heatmap = merge_channels(heatmap) 12 | heatmap = torch.clamp(heatmap, max = 1.0, min = 0.0) 13 | heatmap_img = conver_each_channel_to_colormap(heatmap) 14 | 15 | return heatmap_img 16 | 17 | def merge_channels(multi_channel_img): 18 | single_channel_img = multi_channel_img.sum(dim = 1).unsqueeze(1) 19 | #heatmap_img = heatmap[:,0,:,:].unsqueeze(1) 20 | return single_channel_img 21 | 22 | def scale_to_make_visible(heatmap, max_value): 23 | return heatmap / max_value 24 | 25 | def blend_to_image(img, heatmap, ratio = 0.45): 26 | img = img.cpu().detach() 27 | heatmap = heatmap.cpu().detach() 28 | blended = blend(img, heatmap) 29 | blended_img = torchvision.transforms.ToPILImage()(blended) 30 | return blended_img 31 | 32 | def blend(images, heatmaps, ratio = 0.45): 33 | 34 | merged_heatmaps = heatmaps.sum(dim = 1).unsqueeze(1) 35 | 36 | colormaps = convert_to_colormap(merged_heatmaps) 37 | num_sample = images.size(0) 38 | 39 | blended = [] 40 | for i in range(num_sample): 41 | img = images[i] 42 | img = torchvision.transforms.ToPILImage()(img) 43 | 44 | cm = colormaps[i].squeeze() 45 | cm = torchvision.transforms.ToPILImage()(cm) 46 | 47 | b = Image.blend(img, cm, ratio) 
48 | 49 | b = torchvision.transforms.ToTensor()(b).unsqueeze(0) 50 | 51 | blended.append(b) 52 | blended = torch.cat(blended) 53 | 54 | return blended.squeeze() 55 | 56 | def conver_each_channel_to_colormap(multi_channel_img): 57 | colormaps = [] 58 | 59 | for img in multi_channel_img: 60 | img = img.squeeze() 61 | img *= 255 62 | img = img.numpy().astype(np.uint8) 63 | cm = cv2.applyColorMap(img, cv2.COLORMAP_JET) 64 | cm = cv2.cvtColor(cm, cv2.COLOR_BGR2RGB) 65 | cm = torchvision.transforms.ToTensor()(cm).unsqueeze(0) 66 | colormaps.append(cm) 67 | 68 | return torch.cat(colormaps) 69 | 70 | def gray_to_rgb(img_tensor): 71 | shape = img_tensor.shape 72 | assert shape[1] == 1, "input should have a single channel" 73 | 74 | return torch.cat((img_tensor, img_tensor, img_tensor), 1) 75 | 76 | ############### Normalizations ####################### 77 | 78 | def normalize_img(img, mean=0.5, std=0.5): 79 | if img is None: 80 | return None 81 | """ 82 | ref : https://github.com/pytorch/vision/issues/528 83 | normalize: image = (image - mean)/std 84 | unnormalize: image = (image * std) + mean = (image - (-mean/std)) * (1/std) 85 | mean = 0.5, std = 0.5 86 | Please see 'base_dataset' 87 | """ 88 | return (img - mean) / std # (img + 1) / 2.0 89 | 90 | def unnormalize_as_img(img, mean=0.5, std=0.5): 91 | if img is None: 92 | return None 93 | """ 94 | ref : https://github.com/pytorch/vision/issues/528 95 | normalize: image = (image - mean)/std 96 | unnormalize: image = (image * std) + mean = (image - (-mean/std)) * (1/std) 97 | mean = 0.5, std = 0.5 98 | Please see 'base_dataset' 99 | """ 100 | unnormalized = (img * std) + mean # (img + 1) / 2.0 101 | return unnormalized 102 | 103 | ############### Image Processing ######################## 104 | def get_center_circle_mask(img_size, dataformats = 'NCHW'): 105 | radius = int(np.min(img_size) / 2) 106 | 107 | center = (radius, radius) 108 | x = np.linspace(-img_size[0]/2, img_size[0]/2, img_size[0]) 109 | y = np.linspace(-img_size[1]/2, img_size[1]/2, img_size[1]) 110 | 111 | xx, yy = np.meshgrid(x, y) 112 | rr = np.sqrt(xx**2 + yy**2) 113 | 114 | mask = np.zeros(img_size, dtype = int) 115 | mask[rr <= radius] = 1.0 116 | if dataformats == 'NCHW': 117 | return torch.from_numpy(mask).float().unsqueeze(0).unsqueeze(0) 118 | elif dataformats == 'CHW': 119 | return torch.from_numpy(mask).float().unsqueeze(0) 120 | else: 121 | raise Exception("dataformats should be either 'NCHW' or 'CHW'") 122 | 123 | def expand_channel(img, dataformats = 'NCHW'): 124 | """ Expand channel to 3 125 | """ 126 | if dataformats == 'NCHW': 127 | target_dim = 1 128 | elif dataformats == 'CHW': 129 | target_dim = 0 130 | else: 131 | raise Exception("dataformats should be either 'NCHW' or 'CHW'") 132 | 133 | if img.size(target_dim) == 3: 134 | return img 135 | 136 | return torch.cat((img, img, img), target_dim) 137 | 138 | def merge_channel(img, dataformats = 'NCHW'): 139 | """ Merge channel to 1 140 | """ 141 | if dataformats == 'NCHW': 142 | target_dim = 1 143 | elif dataformats == 'CHW': 144 | target_dim = 0 145 | else: 146 | raise Exception("dataformats should be either 'NCHW' or 'CHW'") 147 | 148 | if img.size(target_dim) == 1: 149 | return img 150 | 151 | return img.mean(dim = target_dim).unsqueeze(target_dim) 152 | -------------------------------------------------------------------------------- /util/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import torch 4 | import torchvision 5 | 6 | from 
datetime import datetime 7 | 8 | from .image import blend_to_image 9 | 10 | def save_image_hot(filename, tensor): 11 | img = torchvision.transforms.ToPILImage()(tensor) 12 | img.save(str(filename)) 13 | 14 | def load_checkpoint(filename): 15 | checkpoint = torch.load(filename) 16 | return checkpoint 17 | 18 | class Logger: 19 | """ 20 | Create filenames for saving the results. 21 | """ 22 | def __init__(self, opt): 23 | self.opt = opt 24 | self.create_paths(opt) 25 | self.loss_file = self.get_loss_file() 26 | self.pck_file = self.get_pck_file() 27 | self.cnt = 0 28 | 29 | def create_paths(self, opt): 30 | this_file = pathlib.Path(os.path.abspath(__file__)) 31 | proj_root = this_file.parents[1] 32 | results_root = proj_root.joinpath('results') 33 | results_root.mkdir(exist_ok = True) 34 | 35 | result_path = results_root.joinpath(opt.name) 36 | 37 | dup_cnt = 0 38 | while result_path.exists(): 39 | dup_cnt += 1 40 | new_name = "{}{}".format(opt.name, dup_cnt) 41 | result_path = results_root.joinpath(new_name) 42 | 43 | result_path.mkdir(exist_ok = True) 44 | self.result_path = result_path 45 | 46 | def get_heatmap_path(self): 47 | 48 | heatmap_path = self.result_path.joinpath('heatmaps') 49 | heatmap_path.mkdir(exist_ok = True) 50 | return heatmap_path 51 | 52 | def get_image_path(self): 53 | image_path = self.result_path.joinpath('images') 54 | image_path.mkdir(exist_ok = True) 55 | return image_path 56 | 57 | def get_tensorboard_path(self): 58 | dir_name = self.result_path.parts[-1] 59 | return self.result_path.parents[1].joinpath(self.opt.tensorboard_dir, dir_name) 60 | 61 | def get_checkpoint_path(self, epoch): 62 | checkpoint_path = self.result_path.joinpath('checkpoints') 63 | checkpoint_path.mkdir(exist_ok = True) 64 | checkpoint_path = checkpoint_path.joinpath("{}.pth.tar".format(epoch)) 65 | 66 | return checkpoint_path 67 | 68 | def get_loss_file(self): 69 | loss_path = self.result_path.joinpath('losses.txt') 70 | return open(str(loss_path), 'w') 71 | 72 | def get_pck_file(self): 73 | pck_path = self.result_path.joinpath('pck.txt') 74 | return open(str(pck_path), 'w') 75 | 76 | def save_image_tensor(self, tensor): 77 | img = torchvision.transforms.ToPILImage()(tensor) 78 | img_path = self.get_image_path() 79 | img_file_path = img_path.joinpath('{}.png'.format(self.cnt)) 80 | self.cnt += 1 81 | img.save(str(img_file_path)) 82 | pass 83 | 84 | def save_heatmap(self, img, heatmap): 85 | blended_img = blend_to_image(img, heatmap) 86 | heatmap_path = self.get_heatmap_path() 87 | img_path = heatmap_path.joinpath('{}.png'.format(self.cnt)) 88 | self.cnt += 1 89 | 90 | blended_img.save(str(img_path)) 91 | 92 | def save_checkpoint(self, checkpoint, epoch): 93 | checkpoint['epoch'] = epoch 94 | cp_path = self.get_checkpoint_path(epoch) 95 | torch.save(checkpoint, str(cp_path)) 96 | 97 | def save_options(self, options): 98 | options_path = self.result_path.joinpath('options.txt') 99 | options_path.write_text(options.pretty_str()) 100 | 101 | def write_loss(self, loss): 102 | loss_str = "{}\n".format(str(loss)) 103 | self.loss_file.write(loss_str) 104 | 105 | def write_pck(self, pck_results): 106 | for thrs, acc in pck_results: 107 | line = "{}, {}\n".format(thrs, acc) 108 | self.pck_file.write(line) 109 | 110 | def close(self): 111 | self.loss_file.close() 112 | self.pck_file.close() 113 | 114 | class LossLog: 115 | def __init__(self, loss_dict, epoch, total_iter, tag): 116 | self.loss_dict = self._copy(loss_dict) 117 | 118 | self.loss_dict['timestamp'] = str(datetime.now()) 119 | 
self.loss_dict['epoch'] = str(epoch) 120 | self.loss_dict['iter'] = str(total_iter) 121 | self.loss_dict['tag'] = tag 122 | 123 | def _copy(self, loss_dict): 124 | new_dict = {} 125 | for k, v in loss_dict.items(): 126 | new_dict[k] = v 127 | 128 | return new_dict 129 | 130 | def __str__(self): 131 | log_str_list = [] 132 | for k, v in self.loss_dict.items(): 133 | log_str_list.append("{}:{}".format(k, v)) 134 | 135 | log_str = ','.join(log_str_list) 136 | return log_str 137 | -------------------------------------------------------------------------------- /util/joint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class JointConverter: 4 | def __init__(self, num_joints): 5 | self.num_joints = num_joints 6 | self.joint_scale = 0 7 | 8 | def convert_for_training(self, joint): 9 | joint = self.normalize(joint) 10 | #joint = self._to_spherical_coord(joint) 11 | #joint = self._flatten(joint) 12 | 13 | return joint 14 | 15 | def convert_for_output(self, joint, no_unnormalize = False): 16 | #joint = self._unflatten(joint) 17 | #joint = self._to_cartesian_coord(joint) 18 | if no_unnormalize: 19 | return joint 20 | 21 | joint = self.unnormalize(joint) 22 | return joint 23 | 24 | def normalize(self, joint): 25 | """ make length between wrist and middle mcp as 1.0 26 | joint: (N, n_joint, 3) 27 | """ 28 | 29 | wrist = joint[:, 0, :].unsqueeze(1) 30 | middle_mcp = joint[:, 9, :].unsqueeze(1) 31 | diff = wrist - middle_mcp 32 | scale = diff.norm(dim=2, keepdim = True) 33 | joint = joint / scale 34 | 35 | self.joint_scale = scale 36 | 37 | return joint 38 | 39 | def unnormalize(self, joint): 40 | """ change joint back to its original scale 41 | joint: (N, n_joint *3) 42 | """ 43 | return joint * self.joint_scale 44 | 45 | def _to_spherical_coord(self, joint): 46 | X = joint[:,:,0] 47 | Y = joint[:,:,1] 48 | Z = joint[:,:,2] 49 | 50 | R = torch.sqrt(X**2 + Y**2 + Z**2) 51 | T = torch.acos(Z/R) 52 | P = torch.atan2(Y, X) 53 | 54 | joint = torch.stack([R, T, P], dim = 2) 55 | return joint 56 | 57 | def _to_cartesian_coord(self, joint): 58 | R = joint[:,:,0] 59 | T = joint[:,:,1] 60 | P = joint[:,:,2] 61 | 62 | Z = R * torch.cos(T) 63 | XY = R * torch.sin(T) 64 | X = XY * torch.cos(P) 65 | Y = XY * torch.sin(P) 66 | 67 | joint = torch.stack([X,Y,Z], dim = 2) 68 | return joint 69 | 70 | def _flatten(self, joint): 71 | if joint is None: 72 | return joint 73 | joint = joint.view(-1, 3*self.num_joints) 74 | return joint 75 | 76 | def _unflatten(self, joint): 77 | return joint.view(-1, self.num_joints, 3) 78 | -------------------------------------------------------------------------------- /util/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def argmax_2d(tensor): 4 | assert len(tensor.shape) == 4 5 | N, C, H, W = tensor.shape 6 | tensor = tensor.reshape(N, C, H*W) 7 | _, idx = tensor.max(dim = -1) 8 | 9 | row, col = unravel_index(idx, H, W) 10 | return torch.stack([row, col], dim = 2) 11 | 12 | def unravel_index(idx, H, W): 13 | row = (idx / W).long() 14 | col = (idx % W).long() 15 | 16 | return row, col 17 | 18 | def argmin_2d(tensor): 19 | return argmax_2d(-tensor) 20 | -------------------------------------------------------------------------------- /util/package.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | def find_class_using_name(module_name, name, postfix): 4 | """ 5 | Imports class with name. 
6 | The class in that module should be named '<name><postfix>' (matched case-insensitively, ignoring underscores). 7 | """ 8 | filename = '{}.{}_{}'.format(module_name, name, postfix) 9 | module_lib = importlib.import_module(filename) 10 | target_name = (name + postfix).replace('_', '') 11 | 12 | target_cls = None 13 | for cls_name, cls in module_lib.__dict__.items(): 14 | if cls_name.lower() == target_name.lower(): 15 | target_cls = cls 16 | 17 | if target_cls is None: 18 | raise NotImplementedError('Class {} is not implemented in module {}'.format(target_name, module_name)) 19 | 20 | return target_cls 21 | -------------------------------------------------------------------------------- /util/projector.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | import numpy as np 4 | import cv2 5 | import math 6 | import util.fisheye as fisheye 7 | 8 | class BaseProjector(ABC): 9 | 10 | def __init__(self, img_size): 11 | self.img_size = img_size 12 | self.img_shape = (img_size, img_size) 13 | 14 | def make_heatmap_seed(self, joint, data_format = 'NJC'): 15 | mapped_joint = self.convert_to_uv(joint, data_format = data_format) 16 | heatmap_seed = self.convert_to_heatmap_seed(mapped_joint, self.img_shape) 17 | return heatmap_seed 18 | 19 | @abstractmethod 20 | def convert_to_uv(self, xyz): 21 | pass 22 | 23 | def convert_to_heatmap_seed(self, uv_mat, img_shape): 24 | assert len(uv_mat.shape) == 3 # N, n_joints, 2 25 | 26 | n_sample, n_joints = uv_mat.shape[:2] 27 | 28 | int_uv_mat = torch.round(uv_mat).long() 29 | int_uv_mat = torch.clamp(int_uv_mat, min = 0, max = img_shape[0] - 1) 30 | 31 | batch_joint_index = np.arange(0, n_sample * n_joints) 32 | 33 | flatten_idx = int_uv_mat.view(-1, 2) 34 | x_idx = flatten_idx[:,0] 35 | y_idx = flatten_idx[:,1] 36 | 37 | heatmaps = torch.zeros(n_sample * n_joints, img_shape[0], img_shape[1]) 38 | heatmaps[batch_joint_index, y_idx, x_idx] = 1 39 | heatmaps = heatmaps.view(n_sample, n_joints, img_shape[0], img_shape[1]) 40 | 41 | return heatmaps 42 | 43 | def check_data_format(self, xyz, data_format): 44 | if data_format == 'NJC': 45 | assert len(xyz.shape) == 3, "for 'NJC' data format, the shape should be (N,n_joints,3)" 46 | elif data_format == 'NC': 47 | assert len(xyz.shape) == 2, "for 'NC' data format, the shape should be (N,3)" 48 | else: 49 | raise Exception("Wrong Data format. 
The data format should be 'NJC' or 'NC'") 50 | 51 | class FisheyeProjector(BaseProjector): 52 | def __init__(self, img_size, fisheye_type = 'equidistant'): 53 | super().__init__(img_size) 54 | 55 | self.radius = int(self.img_size / 2) 56 | self.center = (self.radius, self.radius) 57 | self.fisheye_type = fisheye_type 58 | 59 | def convert_to_uv(self, xyz, data_format='NJC'): 60 | self.check_data_format(xyz, data_format) 61 | 62 | if data_format == 'NJC': 63 | n_sample, n_joints = xyz.shape[:2] 64 | 65 | xyz = xyz.view(-1,3) 66 | x = xyz[:,0] 67 | y = xyz[:,1] 68 | z = xyz[:,2] 69 | 70 | theta = torch.atan2(torch.sqrt(x*x + y*y), z) 71 | phi = torch.atan2(y, x) 72 | 73 | r = fisheye.r_function(self.radius, theta, self.fisheye_type) 74 | 75 | _x = r * torch.cos(phi) 76 | _y = r * torch.sin(phi) 77 | 78 | fish_x = torch.round(self.center[0] + _x).unsqueeze(1) 79 | fish_y = torch.round(self.center[1] + _y).unsqueeze(1) 80 | 81 | out = torch.cat((fish_x, fish_y), dim = 1) 82 | if data_format == 'NJC': 83 | out = out.view(n_sample, n_joints, -1) 84 | return out 85 | 86 | class FlatProjector(BaseProjector): 87 | def __init__(self, img_size, space_to_img_ratio): 88 | super().__init__(img_size) 89 | self.space_to_img_ratio = space_to_img_ratio 90 | self.range_3d = img_size / space_to_img_ratio 91 | 92 | def convert_to_uv(self, xyz, data_format = 'NC'): 93 | self.check_data_format(xyz, data_format) 94 | 95 | if data_format == 'NJC': 96 | n_sample, n_joints = xyz.shape[:2] 97 | 98 | xyz = xyz.view(-1, 3) 99 | 100 | xy = xyz[:,0:2] 101 | 102 | uv = self.proj_to_img_plane(xy) 103 | if data_format == 'NJC': 104 | uv = uv.view(n_sample, n_joints, -1) 105 | return uv 106 | 107 | def proj_to_img_plane(self, val): 108 | return torch.round((val / self.range_3d) * self.img_size + (self.img_size/2)) 109 | -------------------------------------------------------------------------------- /util/unwarp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from util.projector import FlatProjector 3 | import util.fisheye as fisheye 4 | import util.image as image 5 | 6 | class Unwarper: 7 | def __init__(self, opt): 8 | 9 | self.fisheye_type = opt.fisheye_type 10 | self.out_img_size = opt.out_size 11 | 12 | # 3d joint space predefined parameters 13 | self.near_distance = opt.near_distance 14 | self.far_distance = opt.far_distance 15 | 16 | #projection camera model parameters (camera faces along the z axis) 17 | self.max_depth = opt.max_depth 18 | self.projector = FlatProjector(opt.out_size, opt.space_to_img_ratio) 19 | 20 | # Thresholds 21 | self.min_depth_thrs = opt.min_depth_thrs 22 | self.initialized = False 23 | 24 | def initialize(self, input_img_size, is_cuda): 25 | 26 | theta_map, phi_map = fisheye.make_theta_phi_meshgrid(input_img_size, self.fisheye_type) 27 | theta_map.requires_grad = False 28 | phi_map.requires_grad = False 29 | fisheye_mask = image.get_center_circle_mask(input_img_size) 30 | fisheye_mask.requires_grad = False 31 | 32 | self.fisheye_mask = fisheye_mask 33 | self.theta_map = theta_map 34 | self.phi_map = phi_map 35 | 36 | if is_cuda: 37 | self.fisheye_mask = self.fisheye_mask.to(device='cuda') 38 | self.theta_map = self.theta_map.to(device='cuda') 39 | self.phi_map = self.phi_map.to(device='cuda') 40 | 41 | self.is_cuda = is_cuda 42 | 43 | self.initialized = True 44 | 45 | def unwarp(self, fish_depth_img): 46 | assert len(fish_depth_img.shape) == 4 47 | 48 | self._initialize_if_not(fish_depth_img) 49 | 50 | fish_depth_img = 
self._preprocess(fish_depth_img) 51 | 52 | n_sample = fish_depth_img.size(0) 53 | 54 | unwarped_imgs = [] 55 | for i in range(n_sample): 56 | unwarped = self._unwarp_single_img(fish_depth_img[i]) 57 | unwarped_imgs.append(unwarped) 58 | 59 | result = torch.stack(unwarped_imgs, 0) 60 | del unwarped_imgs # save memory 61 | 62 | return result 63 | 64 | def _initialize_if_not(self, depth_img): 65 | n_sample, n_channel, height, width = depth_img.shape 66 | if not self.initialized: 67 | input_img_size = (height, width) 68 | self.initialize(input_img_size, depth_img.is_cuda) 69 | 70 | def _preprocess(self, depth_img): 71 | depth_img = depth_img * self.fisheye_mask 72 | return depth_img 73 | 74 | def _unwarp_single_img(self, depth_img): 75 | depth = depth_img.view(-1) # H, W 76 | theta = self.theta_map.view(-1) 77 | phi = self.phi_map.view(-1) 78 | 79 | try: # DEBUG: an error occurs here intermittently 80 | valid_index = self._filter_valid_index(depth) 81 | 82 | if self._check_if_img_empty(depth): 83 | return self._get_empty_img() 84 | depth = depth[valid_index] 85 | theta = theta[valid_index] 86 | phi = phi[valid_index] 87 | 88 | valid_index = self._sort_index_depth_desc_order(depth) 89 | if self._check_if_img_empty(depth): 90 | return self._get_empty_img() 91 | depth = depth[valid_index] 92 | theta = theta[valid_index] 93 | phi = phi[valid_index] 94 | 95 | except Exception as e: 96 | print(valid_index) 97 | print(depth.shape) 98 | print(e) 99 | raise Exception() 100 | 101 | distance = self._scale_reversed_depth_to_distance(depth) 102 | 103 | x, y, z = self._convert_to_cartesian(distance, theta, phi) 104 | xyz = torch.stack((x,y,z), dim=1) 105 | 106 | uv = self.projector.convert_to_uv(xyz, data_format='NC') 107 | uv = uv.long() 108 | # the loss should not backprop through uv, because uv is used as an index 109 | uv = uv.detach() 110 | 111 | valid_index = self._filter_valid_uv_index(uv) 112 | uv = uv[valid_index] 113 | xyz = xyz[valid_index] 114 | 115 | converted_img = self._map_to_img(uv, xyz) 116 | return converted_img 117 | 118 | def _check_if_img_empty(self, depth_img): 119 | return len(depth_img.size()) == 0 or depth_img.nelement() == 0 120 | 121 | def _get_empty_img(self): 122 | out_img_shape = (1, self.out_img_size, self.out_img_size) 123 | out_img = torch.zeros(out_img_shape) 124 | if self.is_cuda: 125 | return out_img.to(device='cuda') 126 | 127 | return out_img 128 | 129 | def _filter_valid_index(self, depth): 130 | valid_index = (depth > self.min_depth_thrs).nonzero() 131 | return valid_index.squeeze() 132 | 133 | def _sort_index_depth_desc_order(self, depth): 134 | """ 135 | Sort the points in descending order by their distance from the camera. 136 | The result is used for implementing occlusion of point clouds. 137 | 'projection' will only consider the point nearest to the camera (which has a larger index) if multiple points are projected to the same pixel. 
138 | """ 139 | reversed_depth = depth 140 | sorted_idx = torch.argsort(reversed_depth, descending = True) 141 | return sorted_idx 142 | 143 | def _convert_to_cartesian(self, distance, theta, phi): 144 | distance = distance.double() 145 | theta = theta.double() 146 | phi = phi.double() 147 | 148 | z = distance * torch.cos(theta) 149 | x = distance * torch.sin(theta) * torch.cos(phi) 150 | y = distance * torch.sin(theta) * torch.sin(phi) 151 | 152 | return x, y, z 153 | 154 | def _filter_valid_uv_index(self, uv): 155 | # select indices that u and v is both in image (0 < u,v < img_size) 156 | mask_0 = uv >= 0 157 | mask_1 = uv <= (self.out_img_size - 1) 158 | 159 | mask = mask_0 * mask_1 # and operation 160 | mask = mask[:, 0] * mask[:, 1] 161 | 162 | valid_index = mask.nonzero() 163 | return valid_index.squeeze() 164 | 165 | def _map_to_img(self, uv, xyz): 166 | empty_img = self._get_empty_img() 167 | 168 | u = uv[:,0] 169 | v = uv[:,1] 170 | d = xyz[:,2] # z == d, orthogonal projection (Flat Projector) 171 | #d = torch.norm(xyz[:,:2], dim=1) # z == d, orthogonal projection (Flat Projector) 172 | 173 | d = self._scale_3d_to_grayscale(d) 174 | empty_img[0, v, u] = d.float() 175 | 176 | return empty_img 177 | 178 | def _scale_reversed_depth_to_distance(self, reversed_depth): 179 | # depth image follows the inverse depth scale. A nearer point has brighter color. 180 | depth_01 = (1.0 - reversed_depth) 181 | return depth_01 * (self.far_distance - self.near_distance) + self.near_distance 182 | 183 | def _scale_3d_to_grayscale(self, val): 184 | depth_01 = (val - self.near_distance) / (self.far_distance - self.near_distance) 185 | return (1.0 - depth_01) # convert to the inverse depth scale 186 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | import time 5 | from torch.utils.tensorboard import SummaryWriter 6 | from .image import blend 7 | 8 | from util.image import unnormalize_as_img 9 | 10 | class Visualizer: 11 | def __init__(self, log_dir): 12 | self.log_dir = log_dir 13 | self.writer = SummaryWriter(log_dir) 14 | 15 | def add_losses(self, tag, losses, epoch): 16 | for k, v in losses.items(): 17 | self.writer.add_scalar("{}/{}".format(tag, k), v, epoch, walltime = time.time()) 18 | 19 | def add_images(self, tag, imgs, epoch, nrow = 5, dataformats = 'NCHW'): 20 | self.writer.add_images(tag, imgs, epoch, walltime = time.time(), dataformats = dataformats) 21 | 22 | def add_image(self, tag, img, epoch, dataformats = 'CHW'): 23 | self.writer.add_image(tag, img.squeeze(0), epoch, walltime = time.time(), dataformats = dataformats) 24 | 25 | def add_histogram(self, tag, val, epoch): 26 | self.writer.add_histogram(tag, val, epoch, walltime = time.time()) 27 | 28 | def blend_heatmap(self, img, heatmap): 29 | img = img.cpu().detach() 30 | heatmap = heatmap.cpu().detach() 31 | blended = blend(img, heatmap) 32 | blended_img = torchvision.transforms.ToPILImage()(blended) 33 | return blended_img 34 | --------------------------------------------------------------------------------
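The utilities above are easiest to read from how they compose at call sites. Below is a minimal sketch of the util/image.py normalization helpers; the tensor sizes are illustrative placeholders, and only functions defined in util/image.py are used.

import torch
from util.image import normalize_img, unnormalize_as_img, get_center_circle_mask, expand_channel

img = torch.rand(4, 1, 64, 64)                     # grayscale batch in [0, 1]
normed = normalize_img(img)                        # (img - 0.5) / 0.5, roughly in [-1, 1]
restored = unnormalize_as_img(normed)              # (img * 0.5) + 0.5, back to [0, 1]
assert torch.allclose(img, restored, atol=1e-6)    # normalize/unnormalize round-trip

mask = get_center_circle_mask((64, 64))            # (1, 1, 64, 64) binary circle mask
masked = img * mask                                # zero out pixels outside the fisheye circle
rgb = expand_channel(masked)                       # (4, 3, 64, 64) via channel repetition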
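FisheyeProjector maps 3D points to fisheye pixel coordinates and turns them into one-hot heatmap seeds. A hedged sketch follows, assuming util/fisheye.py provides the r_function used internally and assuming a 21-joint hand with the camera looking along +z; the joint values are random placeholders.

import torch
from util.projector import FisheyeProjector

projector = FisheyeProjector(img_size=64, fisheye_type='equidistant')
joints = torch.rand(2, 21, 3)                      # (N, n_joints, 3)
joints[:, :, 2] += 0.1                             # keep the points in front of the camera
seeds = projector.make_heatmap_seed(joints)        # (2, 21, 64, 64), one nonzero pixel per joint
print(seeds.shape, int(seeds.sum()))               # sum equals N * n_joints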
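Unwarper converts a fisheye inverse-depth image into a flat orthographic depth map using the spherical-to-Cartesian conversion and the FlatProjector shown above. A minimal sketch under stated assumptions: the option values below are illustrative only, not the project's presets (the real values come from option/ and the scripts/ presets), and util/fisheye.py is assumed to provide make_theta_phi_meshgrid as it is used by Unwarper.initialize.

import torch
from types import SimpleNamespace
from util.unwarp import Unwarper

opt = SimpleNamespace(
    fisheye_type='equidistant',
    out_size=64,
    near_distance=0.1,        # assumed metric range of the inverse-depth encoding
    far_distance=1.0,
    max_depth=1.0,
    space_to_img_ratio=2.0,   # assumed 3D-extent-to-image scale
    min_depth_thrs=0.1,
)

unwarper = Unwarper(opt)
fish_depth = torch.rand(2, 1, 64, 64)              # (N, 1, H, W), inverse depth in [0, 1]
flat_depth = unwarper.unwarp(fish_depth)           # (2, 1, out_size, out_size) flat depth map
print(flat_depth.shape)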