├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── data ├── __init__.py ├── base_dataset.py ├── inference_dataset.py ├── shapenet_dataset.py └── shapenet_sketch_dataset.py ├── docs ├── .editorconfig ├── index.html └── static │ ├── css │ ├── customize.css │ ├── main.css │ └── normalize.css │ ├── js │ ├── main.js │ ├── plugins.js │ └── vendor │ │ └── modernizr-3.11.2.min.js │ └── media │ ├── 2d-iou.png │ ├── 3d-iou.png │ ├── comparison.png │ ├── gen.png │ ├── setting.png │ └── view-aware.png ├── infer.py ├── load └── inference │ ├── airplane.png │ ├── car.png │ └── table.png ├── models ├── __init__.py ├── base_model.py ├── criterions.py ├── networks.py └── view_disentangle_model.py ├── options ├── __init__.py ├── base_options.py ├── infer_options.py ├── test_options.py └── train_options.py ├── requirements.txt ├── templates └── sphere_642.obj ├── test.py ├── train.py └── utils ├── __init__.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 4 | 5 | ### Python ### 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # pytype static type analyzer 140 | .pytype/ 141 | 142 | # Cython debug symbols 143 | cython_debug/ 144 | 145 | # End of https://www.toptal.com/developers/gitignore/api/python 146 | 147 | load/shapenet-* 148 | checkpoints/ 149 | runs/ 150 | results/ 151 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "SoftRas"] 2 | path = SoftRas 3 | url = https://github.com/ShichenLiu/SoftRas.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yuanchen Guo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sketch2Model: View-Aware 3D Modeling from Single Free-Hand Sketches 2 | ### [Project Page](https://bennyguo.github.io/sketch2model/) | [Video](https://www.youtube.com/watch?v=wqGwcUKBG7E) | [Paper](https://arxiv.org/abs/2105.06663) | [Data](https://drive.google.com/drive/folders/1_DKZV6KtqpLKRoBd0JgOgf60wi1LYm6s?usp=sharing) 3 | Official PyTorch implementation of paper `Sketch2Model: View-Aware 3D Modeling from Single Free-Hand Sketches`, presented at CVPR 2021. The code framework is adapted from [this CycleGAN repository](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). 
4 | 5 | ![](docs/static/media/gen.png) 6 | 7 | ## Environments 8 | - `git clone --recursive https://github.com/bennyguo/sketch2model.git` 9 | - Python>=3.6 10 | - PyTorch>=1.5 11 | - install dependencies: `pip install -r requirements.txt` 12 | - build and install Soft Rasterizer: `cd SoftRas; python setup.py install` 13 | 14 | ## Training 15 | - Download `shapenet-synthetic.zip` [here](https://drive.google.com/drive/folders/1_DKZV6KtqpLKRoBd0JgOgf60wi1LYm6s?usp=sharing), and extract to `load/`: 16 | ``` 17 | load/ 18 | └───shapenet-synthetic / 19 | │ 02691156/ 20 | │ ... ... 21 | ... 22 | ``` 23 | - Train on airplane: 24 | ``` 25 | python train.py --name exp-airplane --class_id 02691156 26 | ``` 27 | 28 | You may specify arguments listed in `options/base_options.py` and `options/train_options.py`. Saved meshes are named with the corresponding (ground truth or predicted) viewpoints in format `e[elevation]a[azimuth]`. `pred` in the filename indicates predicted viewpoint, otherwise the viewpoint is ground truth value (or user-specified when inference, see the Inference section below). 29 | 30 | Supported classes: 31 | ``` 32 | 02691156 Airplane 33 | 02828884 Bench 34 | 02933112 Cabinet 35 | 02958343 Car 36 | 03001627 Chair 37 | 03211117 Display 38 | 03636649 Lamp 39 | 03691459 Loudspeaker 40 | 04090263 Rifle 41 | 04256520 Sofa 42 | 04379243 Table 43 | 04401088 Telephone 44 | 04530566 Watercraft 45 | ``` 46 | 47 | ## Evaluation 48 | - Test on ShapeNet-Synthetic testset: 49 | ``` 50 | python test.py --name [experiment name] --class_id [class id] --test_split test 51 | ``` 52 | - To test on our ShapeNet-Sketch dataset, you need to first download `shapenet-sketch.zip` [here](https://drive.google.com/drive/folders/1_DKZV6KtqpLKRoBd0JgOgf60wi1LYm6s?usp=sharing) and extract to `load/`, then 53 | ``` 54 | python test.py --name [experiment name] --class_id [class id] --dataset_mode shapenet_sketch --dataset_root load/shapenet-sketch 55 | ``` 56 | About file structures of our ShapeNet-Sketch dataset, please see the dataset definition in `data/shapenet_sketch_dataset.py`. 57 | 58 | ## Inference 59 | You can generate a 3D mesh from a given sketch image (black background, white strokes) with predicted viewpoint: 60 | ``` 61 | python infer.py --name [experiment name] --class_id [class id] --image_path [path/to/sketch] 62 | ``` 63 | or with specified viewpoint: 64 | ``` 65 | python infer.py --name [experiment name] --class_id [class id] --image_path [path/to/sketch] --view [elevation] [azimuth] 66 | ``` 67 | Note that elevation is in range [-20, 40], and azimuth is in range [-180, 180]. We provide some example sketches used in the paper in `load/inference`. 68 | 69 | ## Pretrained Weights 70 | We provide pretrained weights for all the 13 classes, and weights trained with domain adaptation for 6 of them (airplane, bench, car, chair, rifle, sofa, table). You can download `checkpoints.zip` [here](https://drive.google.com/drive/folders/1_DKZV6KtqpLKRoBd0JgOgf60wi1LYm6s?usp=sharing) and extract to `checkpoints/`: 71 | ``` 72 | checkpoints/ 73 | └───airplane_pretrained/ 74 | └───airplane_pretrained+da/ 75 | └───... ... 76 | ``` 77 | Note that the code and data for domain adaptation are not contained in this repository. You may implement yourself according to the description in the original paper. 
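If you want to run `infer.py` on your own drawings (see the Inference section above), remember that the expected input is a sketch image with a black background and white strokes. The snippet below is a minimal preprocessing sketch and is not part of this repository: it assumes a typical drawing with dark strokes on a light background, and the threshold value is an arbitrary choice that may need tuning. `infer.py` resizes the image itself (see `data/inference_dataset.py`), so the exact output size is not critical.
```
# Hypothetical helper (not part of this repo): convert a dark-on-light drawing
# into the black-background, white-stroke format expected by infer.py.
from PIL import Image, ImageOps

def prepare_sketch(src_path, dst_path, threshold=128):
    image = Image.open(src_path).convert('L')                    # grayscale
    image = ImageOps.invert(image)                               # dark strokes become white strokes on black
    image = image.point(lambda p: 255 if p > threshold else 0)   # binarize with an arbitrary threshold
    image.convert('RGB').save(dst_path)

prepare_sketch('my_drawing.jpg', 'my_sketch.png')
# then: python infer.py --name [experiment name] --class_id [class id] --image_path my_sketch.png
```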
78 | 79 | ## Citation 80 | ``` 81 | @inproceedings{zhang2021sketch2model, 82 | title={Sketch2Model: View-Aware 3D Modeling from Single Free-Hand Sketches}, 83 | author={Zhang, Song-Hai and Guo, Yuan-Chen and Gu, Qing-Wen}, 84 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 85 | pages={6012--6021}, 86 | year={2021} 87 | } 88 | ``` 89 | 90 | ## Contact 91 | If you have any questions about the implementation or the paper, please feel free to open an issue or contact me at . -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.utils.data 3 | from data.base_dataset import BaseDataset 4 | 5 | 6 | def find_dataset_using_name(dataset_name): 7 | """Import the module "data/[dataset_name]_dataset.py". 8 | 9 | In the file, the class called DatasetNameDataset() will 10 | be instantiated. It has to be a subclass of BaseDataset, 11 | and it is case-insensitive. 12 | """ 13 | dataset_filename = "data." + dataset_name + "_dataset" 14 | datasetlib = importlib.import_module(dataset_filename) 15 | 16 | dataset = None 17 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 18 | for name, cls in datasetlib.__dict__.items(): 19 | if name.lower() == target_dataset_name.lower() \ 20 | and issubclass(cls, BaseDataset): 21 | dataset = cls 22 | 23 | if dataset is None: 24 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 25 | 26 | return dataset 27 | 28 | 29 | def get_option_setter(dataset_name): 30 | """Return the static method of the dataset class.""" 31 | dataset_class = find_dataset_using_name(dataset_name) 32 | return dataset_class.modify_commandline_options 33 | 34 | 35 | def create_dataset(opt, mode, shuffle): 36 | """Create a dataset given the option. 37 | 38 | This function wraps the class CustomDatasetDataLoader. 39 | 40 | Example: 41 | >>> from data import create_dataset 42 | >>> dataset = create_dataset(opt) 43 | """ 44 | data_loader = CustomDatasetDataLoader(opt, mode, shuffle) 45 | dataset = data_loader.load_data() 46 | return dataset 47 | 48 | 49 | class CustomDatasetDataLoader(): 50 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 51 | 52 | def __init__(self, opt, mode, shuffle): 53 | """Initialize this class 54 | 55 | Step 1: create a dataset instance given the name [dataset_mode] 56 | Step 2: create a multi-threaded data loader. 
57 | """ 58 | self.opt = opt 59 | self.mode = mode 60 | dataset_class = find_dataset_using_name(opt.dataset_mode) 61 | self.dataset = dataset_class(opt, mode) 62 | print("dataset [%s] was created" % type(self.dataset).__name__) 63 | self.dataloader = torch.utils.data.DataLoader( 64 | self.dataset, 65 | batch_size=opt.batch_size, 66 | shuffle=shuffle, 67 | num_workers=int(opt.num_threads), 68 | pin_memory=True 69 | ) 70 | 71 | def load_data(self): 72 | return self 73 | 74 | def __len__(self): 75 | """Return the number of data in the dataset""" 76 | return min(len(self.dataset), self.opt.max_dataset_size) 77 | 78 | def __iter__(self): 79 | """Return a batch of data""" 80 | for i, data in enumerate(self.dataloader): 81 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 82 | break 83 | yield data 84 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | """ 3 | import torch.utils.data as data 4 | from abc import ABC, abstractmethod 5 | from options import Configurable 6 | 7 | 8 | class BaseDataset(data.Dataset, ABC, Configurable): 9 | """This class is an abstract base class (ABC) for datasets. 10 | 11 | To create a subclass, you need to implement the following four functions: 12 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 13 | -- <__len__>: return the size of dataset. 14 | -- <__getitem__>: get a data point. 15 | -- : (optionally) add dataset-specific options and set default options. 16 | """ 17 | 18 | def __init__(self, opt, mode): 19 | """Initialize the class; save the options in the class 20 | 21 | Parameters: 22 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 23 | """ 24 | self.opt = opt 25 | self.mode = mode 26 | 27 | @staticmethod 28 | def modify_commandline_options(parser): 29 | """Add new dataset-specific options, and rewrite default values for existing options. 30 | 31 | Parameters: 32 | parser -- original option parser 33 | 34 | Returns: 35 | the modified parser. 36 | """ 37 | return parser 38 | 39 | @abstractmethod 40 | def __len__(self): 41 | """Return the total number of images in the dataset.""" 42 | return 0 43 | 44 | @abstractmethod 45 | def __getitem__(self, index): 46 | """Return a data point and its metadata information. 47 | 48 | Parameters: 49 | index - - a random integer for data indexing 50 | 51 | Returns: 52 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 53 | """ 54 | pass 55 | -------------------------------------------------------------------------------- /data/inference_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as TF 3 | from PIL import Image 4 | from data.base_dataset import BaseDataset 5 | 6 | 7 | class InferenceDataset(BaseDataset): 8 | """ 9 | Dataset for inferencing, containing only a single image with (optional) given view. 
10 | """ 11 | def modify_commandline_options(parser): 12 | return parser 13 | 14 | def __init__(self, opt, mode): 15 | super().__init__(opt, mode) 16 | self.opt = opt 17 | self.image = self.get_image_tensor(opt.image_path) 18 | if opt.view is None: 19 | self.elevation, self.azimuth = 0, 0 20 | else: 21 | self.elevation, self.azimuth = torch.tensor(opt.view[0], dtype=torch.float32), torch.tensor(opt.view[1], dtype=torch.float32) 22 | assert -20 <= self.elevation <= 40 and -180 <= self.azimuth <= 180 23 | 24 | def __len__(self): 25 | return 1 26 | 27 | def __getitem__(self, index): 28 | return { 29 | 'image': self.image, 30 | 'elevation': self.elevation, 31 | 'azimuth': self.azimuth, 32 | } 33 | 34 | def get_image_tensor(self, path): 35 | image = Image.open(path).convert('RGBA') 36 | image = TF.resize(image, (self.opt.image_size, self.opt.image_size)) 37 | image = TF.to_tensor(image) 38 | return image 39 | -------------------------------------------------------------------------------- /data/shapenet_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torchvision.transforms.functional as TF 5 | import numpy as np 6 | from scipy.io import loadmat 7 | from PIL import Image 8 | from data.base_dataset import BaseDataset 9 | import soft_renderer as sr 10 | import soft_renderer.functional as srf 11 | 12 | 13 | class ShapeNetDataset(BaseDataset): 14 | """ 15 | Dataset for loading ShapeNet-Synthetic data. 16 | ShapeNet-Sythetic is used for training and evaluation. 17 | """ 18 | def modify_commandline_options(parser): 19 | parser.add_argument('--n_views_per_obj', type=int, default=20) 20 | return parser 21 | 22 | def __init__(self, opt, mode): 23 | super().__init__(opt, mode) 24 | self.opt = opt 25 | self.root = os.path.join(opt.dataset_root, opt.class_id) 26 | self.class_id = opt.class_id 27 | self.split = mode 28 | assert self.split in ['train', 'val', 'test'] 29 | with open(os.path.join(self.root, self.split + '.lst')) as f: 30 | self.obj_ids = list(filter(None, f.read().split('\n'))) 31 | self.dat = [] 32 | for obj_id in self.obj_ids: 33 | obj_path = os.path.join(self.root, obj_id) 34 | obj_dat = [] 35 | with open(os.path.join(obj_path, 'view.txt')) as f: 36 | obj_cameras = [list(map(float, c.split(' '))) for c in list(filter(None, f.read().split('\n')))] 37 | for i in range(opt.n_views_per_obj): 38 | obj_camera = obj_cameras[i] 39 | elevation, azimuth = self.get_view_tensor(obj_camera[1], obj_camera[0]) 40 | obj_dat.append({ 41 | 'class_id': self.class_id, 42 | 'obj_id': obj_id, 43 | 'image': os.path.join(obj_path, 'sketches', f"render_{i}.png"), 44 | 'camera': self.get_camera_tensor(obj_camera[3], obj_camera[1], obj_camera[0]), # distance, elevation, azimuth 45 | 'elevation': elevation, 46 | 'azimuth': azimuth, 47 | 'voxel': os.path.join(obj_path, 'voxel.mat'), 48 | }) 49 | self.dat.append(obj_dat) 50 | 51 | def __len__(self): 52 | if self.split == 'train': 53 | return len(self.dat) 54 | else: 55 | return len(self.dat) * self.opt.n_views_per_obj 56 | 57 | def __getitem__(self, index): 58 | if self.split == 'train': 59 | obj_dat = self.dat[index] 60 | view_dat = random.sample(obj_dat, k=2) 61 | camera = (view_dat[0]['camera'], view_dat[1]['camera']) 62 | image = (self.get_image_tensor(view_dat[0]['image']), self.get_image_tensor(view_dat[1]['image'])) 63 | elevation = (view_dat[0]['elevation'], view_dat[1]['elevation']) 64 | azimuth = (view_dat[0]['azimuth'], view_dat[1]['azimuth']) 65 | return { 
66 | 'image': image, 67 | 'camera': camera, 68 | 'elevation': elevation, 69 | 'azimuth': azimuth 70 | } 71 | else: 72 | obj_dat = self.dat[index // self.opt.n_views_per_obj] 73 | view_dat = obj_dat[index % self.opt.n_views_per_obj] 74 | camera = view_dat['camera'] 75 | image = self.get_image_tensor(view_dat['image']) 76 | elevation, azimuth = view_dat['elevation'], view_dat['azimuth'] 77 | voxel = self.get_voxel_tensor(view_dat['voxel']) 78 | return { 79 | 'image': image, 80 | 'camera': camera, 81 | 'elevation': elevation, 82 | 'azimuth': azimuth, 83 | 'voxel': voxel 84 | } 85 | 86 | def get_image_tensor(self, path): 87 | image = Image.open(path).convert('RGBA') 88 | image = TF.resize(image, (self.opt.image_size, self.opt.image_size)) 89 | image = TF.to_tensor(image) 90 | torch.where(image[3] > 0.5, torch.tensor(1.), torch.tensor(0.)) 91 | return image 92 | 93 | def get_camera_tensor(self, distance, elevation, azimuth): 94 | camera = srf.get_points_from_angles(distance, elevation, azimuth) 95 | return torch.Tensor(camera).float() 96 | 97 | def get_voxel_tensor(self, path): 98 | # ground truth voxel head to x, up to y 99 | # transform to be able to compare with the voxel converted by SoftRas 100 | voxel = loadmat(path)['Volume'].astype(np.float32) 101 | voxel = np.rot90(np.rot90(voxel, axes=(1, 0)), axes=(2, 1)) 102 | voxel = torch.from_numpy(voxel) 103 | return voxel 104 | 105 | def get_view_tensor(self, elevation, azimuth): 106 | return torch.tensor(elevation, dtype=torch.float32), torch.tensor(azimuth, dtype=torch.float32) 107 | -------------------------------------------------------------------------------- /data/shapenet_sketch_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torchvision.transforms.functional as TF 5 | import numpy as np 6 | from scipy.io import loadmat 7 | from PIL import Image 8 | from data.base_dataset import BaseDataset 9 | import soft_renderer as sr 10 | import soft_renderer.functional as srf 11 | 12 | 13 | class ShapeNetSketchDataset(BaseDataset): 14 | """ 15 | Dataset for loading ShapeNet-Sketch data. 16 | ShapeNet-Sketch is for evaluation only. 
17 | """ 18 | def modify_commandline_options(parser): 19 | return parser 20 | 21 | def __init__(self, opt, mode): 22 | super().__init__(opt, mode) 23 | self.opt = opt 24 | self.root = os.path.join(opt.dataset_root, opt.class_id) 25 | self.class_id = opt.class_id 26 | 27 | self.dat = [] 28 | for obj_id in os.listdir(self.root): 29 | obj_path = os.path.join(self.root, obj_id) 30 | with open(os.path.join(obj_path, 'view.txt')) as f: 31 | obj_cameras = [list(map(float, c.split(' '))) for c in list(filter(None, f.read().split('\n')))] 32 | obj_camera = obj_cameras[0] 33 | elevation, azimuth = self.get_view_tensor(obj_camera[1], obj_camera[0]) 34 | self.dat.append({ 35 | 'class_id': self.class_id, 36 | 'obj_id': obj_id, 37 | 'image': os.path.join(obj_path, 'sketch.png'), 38 | 'camera': self.get_camera_tensor(obj_camera[3], obj_camera[1], obj_camera[0]), # distance, elevation, azimuth 39 | 'elevation': elevation, 40 | 'azimuth': azimuth, 41 | 'voxel': os.path.join(obj_path, 'voxel.mat'), 42 | }) 43 | 44 | def __len__(self): 45 | return len(self.dat) 46 | 47 | def __getitem__(self, index): 48 | view_dat = self.dat[index] 49 | camera = view_dat['camera'] 50 | image = self.get_image_tensor(view_dat['image']) 51 | elevation, azimuth = view_dat['elevation'], view_dat['azimuth'] 52 | voxel = self.get_voxel_tensor(view_dat['voxel']) 53 | return { 54 | 'image': image, 55 | 'camera': camera, 56 | 'elevation': elevation, 57 | 'azimuth': azimuth, 58 | 'voxel': voxel 59 | } 60 | 61 | def get_image_tensor(self, path): 62 | image = Image.open(path).convert('RGBA') 63 | image = TF.resize(image, (self.opt.image_size, self.opt.image_size)) 64 | image = TF.to_tensor(image) 65 | torch.where(image[3] > 0.5, torch.tensor(1.), torch.tensor(0.)) 66 | return image 67 | 68 | def get_camera_tensor(self, distance, elevation, azimuth): 69 | camera = srf.get_points_from_angles(distance, elevation, azimuth) 70 | return torch.Tensor(camera).float() 71 | 72 | def get_voxel_tensor(self, path): 73 | # ground truth voxel head to x, up to y 74 | # transform to be able to compare with the voxel converted by SoftRas 75 | voxel = loadmat(path)['Volume'].astype(np.float32) 76 | voxel = np.rot90(np.rot90(voxel, axes=(1, 0)), axes=(2, 1)) 77 | voxel = torch.from_numpy(voxel) 78 | return voxel 79 | 80 | def get_view_tensor(self, elevation, azimuth): 81 | return torch.tensor(elevation, dtype=torch.float32), torch.tensor(azimuth, dtype=torch.float32) 82 | -------------------------------------------------------------------------------- /docs/.editorconfig: -------------------------------------------------------------------------------- 1 | # editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | charset = utf-8 7 | indent_size = 2 8 | indent_style = space 9 | insert_final_newline = true 10 | trim_trailing_whitespace = true 11 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 |
28 |

Sketch2Model:
View-Aware 3D Modeling from Single Free-Hand Sketches

29 |

Song-Hai Zhang, Yuan-Chen Guo, Qing-Wen Gu
Tsinghua University

30 |

Computer Vision and Pattern Recognition (CVPR) 2021

31 |
32 | arXiv 33 | Code 34 | Data 35 |
36 |
37 |
38 |
39 |
40 |

Abstract

41 |

We investigate the problem of generating 3D meshes from single free-hand sketches, aiming at fast 3D modeling for novice users. It can be regarded as a single-view reconstruction problem, but with unique challenges, brought by the variation and conciseness of sketches. Ambiguities in poorly-drawn sketches could make it hard to determine how the sketched object is posed. In this paper, we address the importance of viewpoint specification for overcoming such ambiguities, and propose a novel view-aware generation approach. By explicitly conditioning the generation process on a given viewpoint, our method can generate plausible shapes automatically with predicted viewpoints, or with specified viewpoints to help users better express their intentions. Extensive evaluations on various datasets demonstrate the effectiveness of our view-aware design in solving sketch ambiguities and improving reconstruction quality.

42 |
43 |
44 |

Problem Setting

45 |

The figure below illustrates our proposed view-aware setting, where the user inputs a free-hand sketch and, optionally, a viewpoint. The viewpoint describes the angle from which the object is sketched, and can also come from network prediction. The output is a 3D mesh that matches the input sketch when viewed from that viewpoint.
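Concretely, a viewpoint in this project is an (elevation, azimuth) pair in degrees, with elevation in [-20, 40] and azimuth in [-180, 180]. The short sketch below (assuming SoftRas is installed) shows how such a pair is turned into a camera position for the soft rasterizer, mirroring get_camera_tensor in data/shapenet_dataset.py; the distance value here is only a placeholder, since in the datasets it is read from each object's view.txt.
```
import torch
import soft_renderer.functional as srf

elevation, azimuth = 20.0, -45.0   # degrees; example values only
distance = 2.5                     # placeholder -- the real value comes from view.txt

# Same conversion the dataset classes use before rendering with SoftRas.
camera = torch.Tensor(srf.get_points_from_angles(distance, elevation, azimuth)).float()
print(camera)  # 3D camera location on a sphere around the object
```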

46 | 47 |
48 |
49 |

View-Aware Property

50 |

Our method outputs different shapes according to different input viewpoints, and ensures consistent silhouettes at these viewpoints. Take the table as an example: seen from the side with a small elevation angle, the synthesized result has a thick top and short legs. As the elevation angle gets larger, the tabletop gets thinner and the legs become longer to satisfy the silhouette constraint.

51 | 52 |
53 |
54 |

3D Modeling Results

55 |

Here we show some 3D modeling results on our collected ShapeNet-Sketch dataset (with predicted viewpoints).

56 | 57 |
58 |
59 |

ShapeNet-Sketch Dataset

60 |

To further evaluate our method, we collected the ShapeNet-Sketch dataset, which contains 1,300 free-hand sketches drawn by novice users together with their corresponding 3D models. The dataset is available on Google Drive to inspire further research. For more details, please refer to the README document included in the dataset.

61 |

Here we show the mean voxel IoU scores on the ShapeNet-Sketch test set, compared with baseline methods.

62 | 63 | We also show comparisons on 2D IoU scores, to demonstrate how well the output shape matches the input sketch. 64 | 65 |
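Both metrics are plain intersection-over-union computations. The sketch below follows the iou and voxel_iou helpers in models/criterions.py (shown later in this dump); the tensors are random stand-ins for predicted and ground-truth silhouettes and voxel grids.
```
import torch

def iou(pred, target, eps=1e-6):
    # per-sample IoU over all non-batch dimensions, averaged over the batch
    dims = tuple(range(pred.ndimension())[1:])
    intersect = (pred * target).sum(dims)
    union = (pred + target - pred * target).sum(dims) + eps
    return (intersect / union).sum() / intersect.nelement()

def voxel_iou(pred, target):
    # IoU between binary occupancy grids, averaged over the batch
    return ((pred * target).sum((1, 2, 3)) / (0 < (pred + target)).sum((1, 2, 3))).mean()

# Random stand-ins: 2D silhouettes (B x H x W) and 3D voxel grids (B x D x H x W)
pred_sil = (torch.rand(4, 64, 64) > 0.5).float()
gt_sil = (torch.rand(4, 64, 64) > 0.5).float()
print('2D IoU:', iou(pred_sil, gt_sil).item())

pred_vox = (torch.rand(4, 32, 32, 32) > 0.5).float()
gt_vox = (torch.rand(4, 32, 32, 32) > 0.5).float()
print('voxel IoU:', voxel_iou(pred_vox, gt_vox).item())
```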
66 |
67 |

Video

68 |
69 | 70 |
71 |
72 |
73 |

Citation

74 |

75 |   @inproceedings{zhang2021sketch2model,
76 |     title={Sketch2Model: View-Aware 3D Modeling from Single Free-Hand Sketches},
77 |     author={Zhang, Song-Hai and Guo, Yuan-Chen and Gu, Qing-Wen},
78 |     booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
79 |     pages={6012--6021},
80 |     year={2021}
81 |   }
82 |           
83 |
84 |
85 |
86 |
87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /docs/static/css/customize.css: -------------------------------------------------------------------------------- 1 | .content-wrapper { 2 | max-width: 800px; 3 | margin-left: auto; 4 | margin-right: auto; 5 | } 6 | 7 | .video-wrapper { 8 | position: relative; 9 | padding-bottom: 56.25%; 10 | /* 16:9 */ 11 | padding-top: 25px; 12 | height: 0; 13 | } 14 | 15 | .video-wrapper iframe { 16 | position: absolute; 17 | top: 0; 18 | left: 0; 19 | width: 100%; 20 | height: 100%; 21 | } -------------------------------------------------------------------------------- /docs/static/css/main.css: -------------------------------------------------------------------------------- 1 | /*! HTML5 Boilerplate v8.0.0 | MIT License | https://html5boilerplate.com/ */ 2 | 3 | /* main.css 2.1.0 | MIT License | https://github.com/h5bp/main.css#readme */ 4 | /* 5 | * What follows is the result of much research on cross-browser styling. 6 | * Credit left inline and big thanks to Nicolas Gallagher, Jonathan Neal, 7 | * Kroc Camen, and the H5BP dev community and team. 8 | */ 9 | 10 | /* ========================================================================== 11 | Base styles: opinionated defaults 12 | ========================================================================== */ 13 | 14 | html { 15 | color: #222; 16 | font-size: 1em; 17 | line-height: 1.4; 18 | } 19 | 20 | /* 21 | * Remove text-shadow in selection highlight: 22 | * https://twitter.com/miketaylr/status/12228805301 23 | * 24 | * Vendor-prefixed and regular ::selection selectors cannot be combined: 25 | * https://stackoverflow.com/a/16982510/7133471 26 | * 27 | * Customize the background color to match your design. 28 | */ 29 | 30 | ::-moz-selection { 31 | background: #b3d4fc; 32 | text-shadow: none; 33 | } 34 | 35 | ::selection { 36 | background: #b3d4fc; 37 | text-shadow: none; 38 | } 39 | 40 | /* 41 | * A better looking default horizontal rule 42 | */ 43 | 44 | hr { 45 | display: block; 46 | height: 1px; 47 | border: 0; 48 | border-top: 1px solid #ccc; 49 | margin: 1em 0; 50 | padding: 0; 51 | } 52 | 53 | /* 54 | * Remove the gap between audio, canvas, iframes, 55 | * images, videos and the bottom of their containers: 56 | * https://github.com/h5bp/html5-boilerplate/issues/440 57 | */ 58 | 59 | audio, 60 | canvas, 61 | iframe, 62 | img, 63 | svg, 64 | video { 65 | vertical-align: middle; 66 | } 67 | 68 | /* 69 | * Remove default fieldset styles. 70 | */ 71 | 72 | fieldset { 73 | border: 0; 74 | margin: 0; 75 | padding: 0; 76 | } 77 | 78 | /* 79 | * Allow only vertical resizing of textareas. 80 | */ 81 | 82 | textarea { 83 | resize: vertical; 84 | } 85 | 86 | /* ========================================================================== 87 | Author's custom styles 88 | ========================================================================== */ 89 | 90 | /* ========================================================================== 91 | Helper classes 92 | ========================================================================== */ 93 | 94 | /* 95 | * Hide visually and from screen readers 96 | */ 97 | 98 | .hidden, 99 | [hidden] { 100 | display: none !important; 101 | } 102 | 103 | /* 104 | * Hide only visually, but have it available for screen readers: 105 | * https://snook.ca/archives/html_and_css/hiding-content-for-accessibility 106 | * 107 | * 1. 
For long content, line feeds are not interpreted as spaces and small width 108 | * causes content to wrap 1 word per line: 109 | * https://medium.com/@jessebeach/beware-smushed-off-screen-accessible-text-5952a4c2cbfe 110 | */ 111 | 112 | .sr-only { 113 | border: 0; 114 | clip: rect(0, 0, 0, 0); 115 | height: 1px; 116 | margin: -1px; 117 | overflow: hidden; 118 | padding: 0; 119 | position: absolute; 120 | white-space: nowrap; 121 | width: 1px; 122 | /* 1 */ 123 | } 124 | 125 | /* 126 | * Extends the .sr-only class to allow the element 127 | * to be focusable when navigated to via the keyboard: 128 | * https://www.drupal.org/node/897638 129 | */ 130 | 131 | .sr-only.focusable:active, 132 | .sr-only.focusable:focus { 133 | clip: auto; 134 | height: auto; 135 | margin: 0; 136 | overflow: visible; 137 | position: static; 138 | white-space: inherit; 139 | width: auto; 140 | } 141 | 142 | /* 143 | * Hide visually and from screen readers, but maintain layout 144 | */ 145 | 146 | .invisible { 147 | visibility: hidden; 148 | } 149 | 150 | /* 151 | * Clearfix: contain floats 152 | * 153 | * For modern browsers 154 | * 1. The space content is one way to avoid an Opera bug when the 155 | * `contenteditable` attribute is included anywhere else in the document. 156 | * Otherwise it causes space to appear at the top and bottom of elements 157 | * that receive the `clearfix` class. 158 | * 2. The use of `table` rather than `block` is only necessary if using 159 | * `:before` to contain the top-margins of child elements. 160 | */ 161 | 162 | .clearfix::before, 163 | .clearfix::after { 164 | content: " "; 165 | display: table; 166 | } 167 | 168 | .clearfix::after { 169 | clear: both; 170 | } 171 | 172 | /* ========================================================================== 173 | EXAMPLE Media Queries for Responsive Design. 174 | These examples override the primary ('mobile first') styles. 175 | Modify as content requires. 176 | ========================================================================== */ 177 | 178 | @media only screen and (min-width: 35em) { 179 | /* Style adjustments for viewports that meet the condition */ 180 | } 181 | 182 | @media print, 183 | (-webkit-min-device-pixel-ratio: 1.25), 184 | (min-resolution: 1.25dppx), 185 | (min-resolution: 120dpi) { 186 | /* Style adjustments for high resolution devices */ 187 | } 188 | 189 | /* ========================================================================== 190 | Print styles. 
191 | Inlined to avoid the additional HTTP request: 192 | https://www.phpied.com/delay-loading-your-print-css/ 193 | ========================================================================== */ 194 | 195 | @media print { 196 | *, 197 | *::before, 198 | *::after { 199 | background: #fff !important; 200 | color: #000 !important; 201 | /* Black prints faster */ 202 | box-shadow: none !important; 203 | text-shadow: none !important; 204 | } 205 | 206 | a, 207 | a:visited { 208 | text-decoration: underline; 209 | } 210 | 211 | a[href]::after { 212 | content: " (" attr(href) ")"; 213 | } 214 | 215 | abbr[title]::after { 216 | content: " (" attr(title) ")"; 217 | } 218 | 219 | /* 220 | * Don't show links that are fragment identifiers, 221 | * or use the `javascript:` pseudo protocol 222 | */ 223 | a[href^="#"]::after, 224 | a[href^="javascript:"]::after { 225 | content: ""; 226 | } 227 | 228 | pre { 229 | white-space: pre-wrap !important; 230 | } 231 | 232 | pre, 233 | blockquote { 234 | border: 1px solid #999; 235 | page-break-inside: avoid; 236 | } 237 | 238 | /* 239 | * Printing Tables: 240 | * https://web.archive.org/web/20180815150934/http://css-discuss.incutio.com/wiki/Printing_Tables 241 | */ 242 | thead { 243 | display: table-header-group; 244 | } 245 | 246 | tr, 247 | img { 248 | page-break-inside: avoid; 249 | } 250 | 251 | p, 252 | h2, 253 | h3 { 254 | orphans: 3; 255 | widows: 3; 256 | } 257 | 258 | h2, 259 | h3 { 260 | page-break-after: avoid; 261 | } 262 | } 263 | 264 | -------------------------------------------------------------------------------- /docs/static/css/normalize.css: -------------------------------------------------------------------------------- 1 | /*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */ 2 | 3 | /* Document 4 | ========================================================================== */ 5 | 6 | /** 7 | * 1. Correct the line height in all browsers. 8 | * 2. Prevent adjustments of font size after orientation changes in iOS. 9 | */ 10 | 11 | html { 12 | line-height: 1.15; /* 1 */ 13 | -webkit-text-size-adjust: 100%; /* 2 */ 14 | } 15 | 16 | /* Sections 17 | ========================================================================== */ 18 | 19 | /** 20 | * Remove the margin in all browsers. 21 | */ 22 | 23 | body { 24 | margin: 0; 25 | } 26 | 27 | /** 28 | * Render the `main` element consistently in IE. 29 | */ 30 | 31 | main { 32 | display: block; 33 | } 34 | 35 | /** 36 | * Correct the font size and margin on `h1` elements within `section` and 37 | * `article` contexts in Chrome, Firefox, and Safari. 38 | */ 39 | 40 | h1 { 41 | font-size: 2em; 42 | margin: 0.67em 0; 43 | } 44 | 45 | /* Grouping content 46 | ========================================================================== */ 47 | 48 | /** 49 | * 1. Add the correct box sizing in Firefox. 50 | * 2. Show the overflow in Edge and IE. 51 | */ 52 | 53 | hr { 54 | box-sizing: content-box; /* 1 */ 55 | height: 0; /* 1 */ 56 | overflow: visible; /* 2 */ 57 | } 58 | 59 | /** 60 | * 1. Correct the inheritance and scaling of font size in all browsers. 61 | * 2. Correct the odd `em` font sizing in all browsers. 62 | */ 63 | 64 | pre { 65 | font-family: monospace, monospace; /* 1 */ 66 | font-size: 1em; /* 2 */ 67 | } 68 | 69 | /* Text-level semantics 70 | ========================================================================== */ 71 | 72 | /** 73 | * Remove the gray background on active links in IE 10. 
74 | */ 75 | 76 | a { 77 | background-color: transparent; 78 | } 79 | 80 | /** 81 | * 1. Remove the bottom border in Chrome 57- 82 | * 2. Add the correct text decoration in Chrome, Edge, IE, Opera, and Safari. 83 | */ 84 | 85 | abbr[title] { 86 | border-bottom: none; /* 1 */ 87 | text-decoration: underline; /* 2 */ 88 | text-decoration: underline dotted; /* 2 */ 89 | } 90 | 91 | /** 92 | * Add the correct font weight in Chrome, Edge, and Safari. 93 | */ 94 | 95 | b, 96 | strong { 97 | font-weight: bolder; 98 | } 99 | 100 | /** 101 | * 1. Correct the inheritance and scaling of font size in all browsers. 102 | * 2. Correct the odd `em` font sizing in all browsers. 103 | */ 104 | 105 | code, 106 | kbd, 107 | samp { 108 | font-family: monospace, monospace; /* 1 */ 109 | font-size: 1em; /* 2 */ 110 | } 111 | 112 | /** 113 | * Add the correct font size in all browsers. 114 | */ 115 | 116 | small { 117 | font-size: 80%; 118 | } 119 | 120 | /** 121 | * Prevent `sub` and `sup` elements from affecting the line height in 122 | * all browsers. 123 | */ 124 | 125 | sub, 126 | sup { 127 | font-size: 75%; 128 | line-height: 0; 129 | position: relative; 130 | vertical-align: baseline; 131 | } 132 | 133 | sub { 134 | bottom: -0.25em; 135 | } 136 | 137 | sup { 138 | top: -0.5em; 139 | } 140 | 141 | /* Embedded content 142 | ========================================================================== */ 143 | 144 | /** 145 | * Remove the border on images inside links in IE 10. 146 | */ 147 | 148 | img { 149 | border-style: none; 150 | } 151 | 152 | /* Forms 153 | ========================================================================== */ 154 | 155 | /** 156 | * 1. Change the font styles in all browsers. 157 | * 2. Remove the margin in Firefox and Safari. 158 | */ 159 | 160 | button, 161 | input, 162 | optgroup, 163 | select, 164 | textarea { 165 | font-family: inherit; /* 1 */ 166 | font-size: 100%; /* 1 */ 167 | line-height: 1.15; /* 1 */ 168 | margin: 0; /* 2 */ 169 | } 170 | 171 | /** 172 | * Show the overflow in IE. 173 | * 1. Show the overflow in Edge. 174 | */ 175 | 176 | button, 177 | input { /* 1 */ 178 | overflow: visible; 179 | } 180 | 181 | /** 182 | * Remove the inheritance of text transform in Edge, Firefox, and IE. 183 | * 1. Remove the inheritance of text transform in Firefox. 184 | */ 185 | 186 | button, 187 | select { /* 1 */ 188 | text-transform: none; 189 | } 190 | 191 | /** 192 | * Correct the inability to style clickable types in iOS and Safari. 193 | */ 194 | 195 | button, 196 | [type="button"], 197 | [type="reset"], 198 | [type="submit"] { 199 | -webkit-appearance: button; 200 | } 201 | 202 | /** 203 | * Remove the inner border and padding in Firefox. 204 | */ 205 | 206 | button::-moz-focus-inner, 207 | [type="button"]::-moz-focus-inner, 208 | [type="reset"]::-moz-focus-inner, 209 | [type="submit"]::-moz-focus-inner { 210 | border-style: none; 211 | padding: 0; 212 | } 213 | 214 | /** 215 | * Restore the focus styles unset by the previous rule. 216 | */ 217 | 218 | button:-moz-focusring, 219 | [type="button"]:-moz-focusring, 220 | [type="reset"]:-moz-focusring, 221 | [type="submit"]:-moz-focusring { 222 | outline: 1px dotted ButtonText; 223 | } 224 | 225 | /** 226 | * Correct the padding in Firefox. 227 | */ 228 | 229 | fieldset { 230 | padding: 0.35em 0.75em 0.625em; 231 | } 232 | 233 | /** 234 | * 1. Correct the text wrapping in Edge and IE. 235 | * 2. Correct the color inheritance from `fieldset` elements in IE. 236 | * 3. 
Remove the padding so developers are not caught out when they zero out 237 | * `fieldset` elements in all browsers. 238 | */ 239 | 240 | legend { 241 | box-sizing: border-box; /* 1 */ 242 | color: inherit; /* 2 */ 243 | display: table; /* 1 */ 244 | max-width: 100%; /* 1 */ 245 | padding: 0; /* 3 */ 246 | white-space: normal; /* 1 */ 247 | } 248 | 249 | /** 250 | * Add the correct vertical alignment in Chrome, Firefox, and Opera. 251 | */ 252 | 253 | progress { 254 | vertical-align: baseline; 255 | } 256 | 257 | /** 258 | * Remove the default vertical scrollbar in IE 10+. 259 | */ 260 | 261 | textarea { 262 | overflow: auto; 263 | } 264 | 265 | /** 266 | * 1. Add the correct box sizing in IE 10. 267 | * 2. Remove the padding in IE 10. 268 | */ 269 | 270 | [type="checkbox"], 271 | [type="radio"] { 272 | box-sizing: border-box; /* 1 */ 273 | padding: 0; /* 2 */ 274 | } 275 | 276 | /** 277 | * Correct the cursor style of increment and decrement buttons in Chrome. 278 | */ 279 | 280 | [type="number"]::-webkit-inner-spin-button, 281 | [type="number"]::-webkit-outer-spin-button { 282 | height: auto; 283 | } 284 | 285 | /** 286 | * 1. Correct the odd appearance in Chrome and Safari. 287 | * 2. Correct the outline style in Safari. 288 | */ 289 | 290 | [type="search"] { 291 | -webkit-appearance: textfield; /* 1 */ 292 | outline-offset: -2px; /* 2 */ 293 | } 294 | 295 | /** 296 | * Remove the inner padding in Chrome and Safari on macOS. 297 | */ 298 | 299 | [type="search"]::-webkit-search-decoration { 300 | -webkit-appearance: none; 301 | } 302 | 303 | /** 304 | * 1. Correct the inability to style clickable types in iOS and Safari. 305 | * 2. Change font properties to `inherit` in Safari. 306 | */ 307 | 308 | ::-webkit-file-upload-button { 309 | -webkit-appearance: button; /* 1 */ 310 | font: inherit; /* 2 */ 311 | } 312 | 313 | /* Interactive 314 | ========================================================================== */ 315 | 316 | /* 317 | * Add the correct display in Edge, IE 10+, and Firefox. 318 | */ 319 | 320 | details { 321 | display: block; 322 | } 323 | 324 | /* 325 | * Add the correct display in all browsers. 326 | */ 327 | 328 | summary { 329 | display: list-item; 330 | } 331 | 332 | /* Misc 333 | ========================================================================== */ 334 | 335 | /** 336 | * Add the correct display in IE 10+. 337 | */ 338 | 339 | template { 340 | display: none; 341 | } 342 | 343 | /** 344 | * Add the correct display in IE 10. 345 | */ 346 | 347 | [hidden] { 348 | display: none; 349 | } 350 | -------------------------------------------------------------------------------- /docs/static/js/main.js: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bennyguo/sketch2model/c1431c22e33889595bfac9e693b889fc000af6ec/docs/static/js/main.js -------------------------------------------------------------------------------- /docs/static/js/plugins.js: -------------------------------------------------------------------------------- 1 | // Avoid `console` errors in browsers that lack a console. 
2 | (function() { 3 | var method; 4 | var noop = function () {}; 5 | var methods = [ 6 | 'assert', 'clear', 'count', 'debug', 'dir', 'dirxml', 'error', 7 | 'exception', 'group', 'groupCollapsed', 'groupEnd', 'info', 'log', 8 | 'markTimeline', 'profile', 'profileEnd', 'table', 'time', 'timeEnd', 9 | 'timeline', 'timelineEnd', 'timeStamp', 'trace', 'warn' 10 | ]; 11 | var length = methods.length; 12 | var console = (window.console = window.console || {}); 13 | 14 | while (length--) { 15 | method = methods[length]; 16 | 17 | // Only stub undefined methods. 18 | if (!console[method]) { 19 | console[method] = noop; 20 | } 21 | } 22 | }()); 23 | 24 | // Place any jQuery/helper plugins in here. 25 | -------------------------------------------------------------------------------- /docs/static/js/vendor/modernizr-3.11.2.min.js: -------------------------------------------------------------------------------- 1 | /*! modernizr 3.11.2 (Custom Build) | MIT * 2 | * https://modernizr.com/download/?-cssanimations-csscolumns-customelements-flexbox-history-picture-pointerevents-postmessage-sizes-srcset-webgl-websockets-webworkers-addtest-domprefixes-hasevent-mq-prefixedcssvalue-prefixes-setclasses-testallprops-testprop-teststyles !*/ 3 | !function(e,t,n,r){function o(e,t){return typeof e===t}function i(e){var t=_.className,n=Modernizr._config.classPrefix||"";if(S&&(t=t.baseVal),Modernizr._config.enableJSClass){var r=new RegExp("(^|\\s)"+n+"no-js(\\s|$)");t=t.replace(r,"$1"+n+"js$2")}Modernizr._config.enableClasses&&(e.length>0&&(t+=" "+n+e.join(" "+n)),S?_.className.baseVal=t:_.className=t)}function s(e,t){if("object"==typeof e)for(var n in e)k(e,n)&&s(n,e[n]);else{e=e.toLowerCase();var r=e.split("."),o=Modernizr[r[0]];if(2===r.length&&(o=o[r[1]]),void 0!==o)return Modernizr;t="function"==typeof t?t():t,1===r.length?Modernizr[r[0]]=t:(!Modernizr[r[0]]||Modernizr[r[0]]instanceof Boolean||(Modernizr[r[0]]=new Boolean(Modernizr[r[0]])),Modernizr[r[0]][r[1]]=t),i([(t&&!1!==t?"":"no-")+r.join("-")]),Modernizr._trigger(e,t)}return Modernizr}function a(){return"function"!=typeof n.createElement?n.createElement(arguments[0]):S?n.createElementNS.call(n,"http://www.w3.org/2000/svg",arguments[0]):n.createElement.apply(n,arguments)}function l(){var e=n.body;return e||(e=a(S?"svg":"body"),e.fake=!0),e}function u(e,t,r,o){var i,s,u,f,c="modernizr",d=a("div"),p=l();if(parseInt(r,10))for(;r--;)u=a("div"),u.id=o?o[r]:c+(r+1),d.appendChild(u);return i=a("style"),i.type="text/css",i.id="s"+c,(p.fake?p:d).appendChild(i),p.appendChild(d),i.styleSheet?i.styleSheet.cssText=e:i.appendChild(n.createTextNode(e)),d.id=c,p.fake&&(p.style.background="",p.style.overflow="hidden",f=_.style.overflow,_.style.overflow="hidden",_.appendChild(p)),s=t(d,e),p.fake?(p.parentNode.removeChild(p),_.style.overflow=f,_.offsetHeight):d.parentNode.removeChild(d),!!s}function f(e,n,r){var o;if("getComputedStyle"in t){o=getComputedStyle.call(t,e,n);var i=t.console;if(null!==o)r&&(o=o.getPropertyValue(r));else if(i){var s=i.error?"error":"log";i[s].call(i,"getComputedStyle returning null, its possible modernizr test results are inaccurate")}}else o=!n&&e.currentStyle&&e.currentStyle[r];return o}function c(e,t){return!!~(""+e).indexOf(t)}function d(e){return e.replace(/([A-Z])/g,function(e,t){return"-"+t.toLowerCase()}).replace(/^ms-/,"-ms-")}function p(e,n){var o=e.length;if("CSS"in t&&"supports"in t.CSS){for(;o--;)if(t.CSS.supports(d(e[o]),n))return!0;return!1}if("CSSSupportsRule"in t){for(var i=[];o--;)i.push("("+d(e[o])+":"+n+")");return i=i.join(" or 
"),u("@supports ("+i+") { #modernizr { position: absolute; } }",function(e){return"absolute"===f(e,null,"position")})}return r}function m(e){return e.replace(/([a-z])-([a-z])/g,function(e,t,n){return t+n.toUpperCase()}).replace(/^-/,"")}function h(e,t,n,i){function s(){u&&(delete N.style,delete N.modElem)}if(i=!o(i,"undefined")&&i,!o(n,"undefined")){var l=p(e,n);if(!o(l,"undefined"))return l}for(var u,f,d,h,A,v=["modernizr","tspan","samp"];!N.style&&v.length;)u=!0,N.modElem=a(v.shift()),N.style=N.modElem.style;for(d=e.length,f=0;f of the model class.""" 30 | model_class = find_model_using_name(model_name) 31 | return model_class.modify_commandline_options 32 | 33 | 34 | def create_model(opt): 35 | """Create a model given the option. 36 | 37 | This function warps the class CustomDatasetDataLoader. 38 | 39 | Example: 40 | >>> from models import create_model 41 | >>> model = create_model(opt) 42 | """ 43 | model = find_model_using_name(opt.model) 44 | instance = model(opt) 45 | print("model [%s] was created" % type(instance).__name__) 46 | return instance 47 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import torch 5 | from collections import OrderedDict 6 | from abc import ABC, abstractmethod 7 | from options import Configurable 8 | from . import networks 9 | 10 | 11 | class BaseModel(ABC, Configurable): 12 | """This class is an abstract base class (ABC) for models. 13 | """ 14 | 15 | def __init__(self, opt): 16 | self.opt = opt 17 | self.isTrain, self.isTest, self.isInfer = opt.isTrain, opt.isTest, opt.isInfer 18 | self.device = opt.device 19 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 20 | 21 | # losses 22 | self.train_loss_names = [] 23 | self.val_loss_names = [] 24 | self.test_loss_names = [] 25 | self.infer_loss_names = [] 26 | 27 | # models 28 | self.model_names = [] 29 | 30 | self.optimizers = [] 31 | 32 | @abstractmethod 33 | def set_input(self, input): 34 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 35 | 36 | Parameters: 37 | input (dict): includes the data itself and its metadata information. 
38 | """ 39 | pass 40 | 41 | @abstractmethod 42 | def forward(self): 43 | """Run forward pass.""" 44 | pass 45 | 46 | @abstractmethod 47 | def optimize_parameters(self): 48 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 49 | pass 50 | 51 | def setup(self, opt): 52 | """Load and print networks; create schedulers 53 | 54 | Parameters: 55 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 56 | """ 57 | 58 | current_epoch = 0 59 | 60 | if self.isTrain and opt.init_weights: 61 | init_weights_name, init_weights_epoch = ':'.join(opt.init_weights.split(':')[:-1]), opt.init_weights.split(':')[-1] 62 | self.load_networks(init_weights_name, init_weights_epoch, opt.init_weights_keys) 63 | 64 | if not self.isTrain or opt.continue_train: 65 | if opt.load_epoch == 'latest': 66 | current_epoch = max([int(os.path.basename(x).split('_')[0]) for x in glob.glob(os.path.join(self.save_dir, '*.pth')) if 'latest' not in x]) 67 | opt.load_epoch = current_epoch 68 | else: 69 | current_epoch = int(opt.load_epoch) 70 | self.load_networks(opt.name, opt.load_epoch) 71 | 72 | if self.isTrain and opt.fix_layers: 73 | for name in self.model_names: 74 | net = getattr(self, 'net' + name) 75 | if isinstance(net, torch.nn.DataParallel): 76 | net = net.module 77 | for param_name, params in net.named_parameters(): 78 | if re.match(opt.fix_layers, param_name): 79 | params.requires_grad = False 80 | 81 | if self.isTrain: 82 | self.schedulers = [networks.get_scheduler(optimizer, opt, last_epoch=current_epoch - 1) for optimizer in self.optimizers] 83 | 84 | self.print_networks(opt.verbose) 85 | return current_epoch 86 | 87 | def train(self): 88 | """Make models train mode during training time""" 89 | self.optimization = True 90 | for name in self.model_names: 91 | if isinstance(name, str): 92 | net = getattr(self, 'net' + name) 93 | net.train() 94 | 95 | def eval(self): 96 | """Make models eval mode during test time""" 97 | self.optimization = False 98 | for name in self.model_names: 99 | if isinstance(name, str): 100 | net = getattr(self, 'net' + name) 101 | net.eval() 102 | 103 | @abstractmethod 104 | def validate(self): 105 | """Function for validation procedure.""" 106 | pass 107 | 108 | @abstractmethod 109 | def test(self): 110 | """Function for test procedure.""" 111 | pass 112 | 113 | @abstractmethod 114 | def inference(self): 115 | """Function for inference procedure.""" 116 | pass 117 | 118 | @abstractmethod 119 | def update_hyperparameters(self, epoch): 120 | """ 121 | Define how hyperparameters are updated. 122 | Called before each epoch. 123 | """ 124 | pass 125 | 126 | @abstractmethod 127 | def update_hyperparameters_step(self, step): 128 | """ 129 | Define how hyperparameters are updated. 130 | Called before each step. 
131 | """ 132 | pass 133 | 134 | def update_learning_rate(self): 135 | """Update learning rates for all the networks; called at the end of every epoch""" 136 | old_lr = self.optimizers[0].param_groups[0]['lr'] 137 | for scheduler in self.schedulers: 138 | scheduler.step() 139 | lr = self.optimizers[0].param_groups[0]['lr'] 140 | 141 | def get_learning_rate(self): 142 | lr = self.optimizers[0].param_groups[0]['lr'] 143 | return lr 144 | 145 | def get_current_visuals(self, mode): 146 | """Return visualizations.""" 147 | visual_ret = OrderedDict() 148 | for name in getattr(self, f"{mode}_visual_names"): 149 | if isinstance(name, str) and hasattr(self, name): 150 | visual_ret[name] = getattr(self, name) 151 | return visual_ret 152 | 153 | def get_current_losses(self, mode): 154 | """Return losses / errors, used for logging.""" 155 | errors_ret = OrderedDict() 156 | for name in getattr(self, f"{mode}_loss_names"): 157 | if isinstance(name, str) and hasattr(self, 'loss_' + name): 158 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 159 | return errors_ret 160 | 161 | def save_networks(self, epoch): 162 | """Save all the networks to the disk. 163 | 164 | Parameters: 165 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 166 | """ 167 | for name in self.model_names: 168 | if isinstance(name, str): 169 | save_filename = '%s_net_%s.pth' % (epoch, name) 170 | save_path = os.path.join(self.save_dir, save_filename) 171 | net = getattr(self, 'net' + name) 172 | 173 | if isinstance(net, torch.nn.DataParallel): 174 | net = net.module 175 | torch.save(net.cpu().state_dict(), save_path) 176 | net.to(self.device) 177 | 178 | def load_networks(self, exp_name, epoch, keys=None): 179 | """Load all the networks from the disk. 
180 | 181 | Parameters: 182 | exp_name (str) -- experiment name 183 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 184 | keys (re) -- names (regular expressions) of the parameters to be loaded 185 | """ 186 | for name in self.model_names: 187 | if isinstance(name, str): 188 | load_filename = '%s_net_%s.pth' % (epoch, name) 189 | load_path = os.path.join(self.opt.checkpoints_dir, exp_name, load_filename) 190 | net = getattr(self, 'net' + name) 191 | if isinstance(net, torch.nn.DataParallel): 192 | net = net.module 193 | print('loading the model from %s' % load_path, 'with keys', keys) 194 | if keys is None: 195 | state_dict = torch.load(load_path, map_location=self.device) 196 | net.load_state_dict(state_dict, strict=False) 197 | else: 198 | state_dict = {k: v for k, v in torch.load(load_path, map_location=self.device).items() if re.match(keys, k)} 199 | net.load_state_dict(state_dict, strict=False) 200 | 201 | def print_networks(self, verbose): 202 | """Print the total number of parameters in the network and (if verbose) network architecture 203 | 204 | Parameters: 205 | verbose (bool) -- if verbose: print the network architecture 206 | """ 207 | print('---------- Networks initialized -------------') 208 | for name in self.model_names: 209 | if isinstance(name, str): 210 | net = getattr(self, 'net' + name) 211 | num_params = 0 212 | for param in net.parameters(): 213 | num_params += param.numel() 214 | if verbose: 215 | print(net) 216 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 217 | print('-----------------------------------------------') 218 | 219 | def set_requires_grad(self, nets, requires_grad=False): 220 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 221 | Parameters: 222 | nets (network list) -- a list of networks 223 | requires_grad (bool) -- whether the networks require gradients or not 224 | """ 225 | if not isinstance(nets, list): 226 | nets = [nets] 227 | for net in nets: 228 | if net is not None: 229 | for param in net.parameters(): 230 | param.requires_grad = requires_grad 231 | -------------------------------------------------------------------------------- /models/criterions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import soft_renderer as sr 7 | 8 | 9 | class LaplacianLoss(nn.Module): 10 | def __init__(self, opt): 11 | super().__init__() 12 | self.opt = opt 13 | self.template_mesh = sr.Mesh.from_obj(opt.template_path) 14 | self.loss = sr.LaplacianLoss( 15 | self.template_mesh.vertices[0].cpu(), 16 | self.template_mesh.faces[0].cpu() 17 | ).to(opt.device) 18 | 19 | def forward(self, v): 20 | return self.loss(v).mean() 21 | 22 | 23 | class FlattenLoss(nn.Module): 24 | def __init__(self, opt): 25 | super().__init__() 26 | self.opt = opt 27 | self.template_mesh = sr.Mesh.from_obj(opt.template_path) 28 | self.loss = sr.FlattenLoss( 29 | self.template_mesh.faces[0].cpu() 30 | ).to(opt.device) 31 | 32 | def forward(self, v): 33 | return self.loss(v).mean() 34 | 35 | 36 | def iou(pred, target, eps=1e-6): 37 | dims = tuple(range(pred.ndimension())[1:]) 38 | intersect = (pred * target).sum(dims) 39 | union = (pred + target - pred * target).sum(dims) + eps 40 | return (intersect / union).sum() / intersect.nelement() 41 | 42 | 43 | class IoULoss(nn.Module): 44 | def __init__(self, opt): 45 | super().__init__() 46 | 
self.opt = opt 47 | 48 | def forward(self, pred, target): 49 | return 1 - iou(pred[:, 3], target[:, 3]) 50 | 51 | 52 | class MultiViewIoULoss(nn.Module): 53 | def __init__(self, opt): 54 | super().__init__() 55 | self.opt = opt 56 | 57 | def forward(self, pred, target_a, target_b): 58 | return ( 59 | 1 - iou(pred[0][:, 3], target_a[:, 3]) + 60 | 1 - iou(pred[1][:, 3], target_a[:, 3]) + 61 | 1 - iou(pred[2][:, 3], target_b[:, 3]) + 62 | 1 - iou(pred[3][:, 3], target_b[:, 3]) 63 | ) / 4. 64 | 65 | 66 | class MSELoss(nn.Module): 67 | def __init__(self, opt): 68 | super().__init__() 69 | self.opt = opt 70 | 71 | def forward(self, pred, target): 72 | return F.mse_loss(pred, target, reduction='mean') 73 | 74 | 75 | def voxel_iou(pred, target): 76 | return ((pred * target).sum((1, 2, 3)) / (0 < (pred + target)).sum((1, 2, 3))).mean() 77 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim import lr_scheduler 5 | 6 | 7 | def init_net(net, opt): 8 | if opt.n_gpus > 0: 9 | assert(torch.cuda.is_available()) 10 | net.to(opt.device) 11 | net = torch.nn.DataParallel(net) # multi-GPUs 12 | return net 13 | 14 | 15 | def get_scheduler(optimizer, opt, last_epoch=-1): 16 | """Return a learning rate scheduler 17 | 18 | Parameters: 19 | optimizer -- the optimizer of the network 20 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 21 | opt.lr_policy is the name of learning rate policy: linear | exp | step 22 | """ 23 | if opt.lr_policy == 'linear': 24 | """Linear decay in the last opt.n_epochs_decay epochs.""" 25 | def lambda_rule(epoch): 26 | t = max(0, epoch + 1 - opt.n_epochs + opt.n_epochs_decay) / float(opt.n_epochs_decay + 1) 27 | lr = opt.lr * (1 - t) + opt.lr_final * t 28 | return lr / opt.lr 29 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule, last_epoch=last_epoch) 30 | elif opt.lr_policy == 'exp': 31 | """Exponential decay in the last opt.n_epochs_decay epochs.""" 32 | def lambda_rule(epoch): 33 | t = max(0, epoch + 1 - opt.n_epochs + opt.n_epochs_decay) / float(opt.n_epochs_decay + 1) 34 | lr = math.exp(math.log(opt.lr) * (1 - t) + math.log(opt.lr_final) * t) 35 | return lr / opt.lr 36 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule, last_epoch=last_epoch) 37 | elif opt.lr_policy == 'step': 38 | """Decay every opt.lr_decay_epochs epochs.""" 39 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=opt.lr_decay_gamma, last_epoch=last_epoch) 40 | else: 41 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 42 | return scheduler 43 | -------------------------------------------------------------------------------- /models/view_disentangle_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import init 8 | import torch.nn.functional as F 9 | from torch.autograd import Function 10 | from torch.nn.modules.loss import MSELoss 11 | import torchvision 12 | from tqdm import tqdm 13 | 14 | import soft_renderer as sr 15 | import soft_renderer.functional as srf 16 | 17 | 18 | from .base_model import BaseModel 19 | from .networks import init_net 20 | from .criterions import * 21 | from utils.utils import tensor2im 22 | 23
| class gradient_reversal(Function): 24 | """ 25 | Gradient Reversal Layer from: 26 | Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) 27 | Forward pass is the identity function. In the backward pass, 28 | the upstream gradients are multiplied by -lambda (i.e. gradient is reversed) 29 | """ 30 | @staticmethod 31 | def forward(ctx, x, lambda_): 32 | ctx.lambda_ = lambda_ 33 | return x.clone() 34 | 35 | @staticmethod 36 | def backward(ctx, grads): 37 | lambda_ = ctx.lambda_ 38 | lambda_ = grads.new_tensor(lambda_) 39 | dx = -lambda_ * grads 40 | return dx, None 41 | 42 | 43 | class GradientReversalLayer(nn.Module): 44 | def __init__(self, lambda_=1): 45 | super(GradientReversalLayer, self).__init__() 46 | self.lambda_ = lambda_ 47 | 48 | def forward(self, x): 49 | return gradient_reversal.apply(x, self.lambda_) 50 | 51 | 52 | class ResNet18Encoder(nn.Module): 53 | def __init__(self, dim_in, pretrained): 54 | super(ResNet18Encoder, self).__init__() 55 | assert(dim_in == 3) 56 | print('ResNet18 pretrained:', pretrained) 57 | self.backbone = torchvision.models.resnet18(pretrained=pretrained) 58 | self.backbone.avgpool = nn.Identity() 59 | self.backbone.fc = nn.Identity() 60 | self.x = {} 61 | 62 | def forward(self, x): 63 | batch_size = x.shape[0] 64 | x0 = self.backbone.conv1(x) 65 | x0 = self.backbone.bn1(x0) 66 | x0 = self.backbone.relu(x0) 67 | x0 = self.backbone.maxpool(x0) 68 | 69 | x1 = self.backbone.layer1(x0) 70 | x2 = self.backbone.layer2(x1) 71 | x3 = self.backbone.layer3(x2) 72 | x4 = self.backbone.layer4(x3) 73 | 74 | self.x[0], self.x[1], self.x[2], self.x[3], self.x[4] = x0, x1, x2, x3, x4 75 | 76 | x = x4 77 | x = x.view(batch_size, -1) 78 | return x 79 | 80 | 81 | class Encoder(nn.Module): 82 | def __init__(self, dim_in, dim_hidden, dim_s, dim_v, normalize=True): 83 | super(Encoder, self).__init__() 84 | self.fc1 = nn.Linear(dim_in, dim_hidden) 85 | self.fc2 = nn.Linear(dim_hidden, dim_hidden) 86 | self.fc_s = nn.Linear(dim_hidden, dim_s) 87 | self.fc_v = nn.Linear(dim_hidden, dim_v) 88 | self.normalize = normalize 89 | 90 | def forward(self, x): 91 | x = F.relu(self.fc1(x), inplace=True) 92 | x = F.relu(self.fc2(x), inplace=True) 93 | zs = F.relu(self.fc_s(x), inplace=True) 94 | zv = F.relu(self.fc_v(x), inplace=True) 95 | if self.normalize: 96 | zs = F.normalize(zs, dim=1) 97 | zv = F.normalize(zv, dim=1) 98 | return zs, zv 99 | 100 | 101 | class ViewEncoder(nn.Module): 102 | def __init__(self, dim_hidden, dim_out, normalize=True): 103 | super(ViewEncoder, self).__init__() 104 | dim_in = 2 105 | self.fc1 = nn.Linear(dim_in, dim_hidden) 106 | self.fc2 = nn.Linear(dim_hidden, dim_out) 107 | self.normalize = normalize 108 | 109 | def forward(self, x): 110 | x = F.relu(self.fc1(x), inplace=True) 111 | x = F.relu(self.fc2(x), inplace=True) 112 | if self.normalize: 113 | x = F.normalize(x, dim=1) 114 | return x 115 | 116 | 117 | class ViewDecoder(nn.Module): 118 | def __init__(self, dim_in, dim_hidden): 119 | super(ViewDecoder, self).__init__() 120 | dim_out = 2 121 | self.fc1 = nn.Linear(dim_in, dim_hidden) 122 | self.fc2 = nn.Linear(dim_hidden, dim_out) 123 | self.sigmoid = nn.Sigmoid() 124 | 125 | def forward(self, x): 126 | x = F.relu(self.fc1(x), inplace=True) 127 | x = self.sigmoid(self.fc2(x)) 128 | return x 129 | 130 | 131 | class Decoder(nn.Module): 132 | def __init__(self, dim_in, dim_shape, dim_view, dim_hidden, normalize=True): 133 | super(Decoder, self).__init__() 134 | self.dim_in = dim_in 135 | self.dim_shape = dim_shape 136 | self.dim_view 
= dim_view 137 | self.fc1 = nn.Linear(dim_in + dim_view, dim_hidden) 138 | self.fc2 = nn.Linear(dim_hidden, dim_shape) 139 | self.normalize = normalize 140 | 141 | def forward(self, x, view): 142 | x = torch.cat([x, view], dim=1) 143 | x = F.relu(self.fc1(x), inplace=True) 144 | x = F.relu(self.fc2(x), inplace=True) 145 | if self.normalize: 146 | x = F.normalize(x, dim=1) 147 | return x 148 | 149 | 150 | 151 | class MeshDecoder(nn.Module): 152 | """This MeshDecoder follows N3MR and SoftRas""" 153 | def __init__(self, filename_obj, dim_in, centroid_scale=0.1, bias_scale=1.0): 154 | super(MeshDecoder, self).__init__() 155 | self.template_mesh = sr.Mesh.from_obj(filename_obj) 156 | self.register_buffer('vertices_base', self.template_mesh.vertices.cpu()[0]) 157 | self.register_buffer('faces', self.template_mesh.faces.cpu()[0]) 158 | 159 | self.nv = self.vertices_base.size(0) 160 | self.nf = self.faces.size(0) 161 | self.centroid_scale = centroid_scale 162 | self.bias_scale = bias_scale 163 | self.obj_scale = 0.5 164 | 165 | dim = 1024 166 | dim_hidden = [dim, dim*2] 167 | self.fc1 = nn.Linear(dim_in, dim_hidden[0]) 168 | self.fc2 = nn.Linear(dim_hidden[0], dim_hidden[1]) 169 | self.fc_centroid = nn.Linear(dim_hidden[1], 3) 170 | self.fc_bias = nn.Linear(dim_hidden[1], self.nv*3) 171 | 172 | def forward(self, x): 173 | batch_size = x.shape[0] 174 | x = F.relu(self.fc1(x), inplace=True) 175 | x = F.relu(self.fc2(x), inplace=True) 176 | 177 | centroid = self.fc_centroid(x) * self.centroid_scale 178 | 179 | bias = self.fc_bias(x) * self.bias_scale 180 | bias = bias.view(-1, self.nv, 3) 181 | 182 | base = self.vertices_base * self.obj_scale 183 | 184 | sign = torch.sign(base) 185 | base = torch.abs(base) 186 | base = torch.log(base / (1 - base)) 187 | 188 | centroid = torch.tanh(centroid[:, None, :]) 189 | scale_pos = 1 - centroid 190 | scale_neg = centroid + 1 191 | 192 | vertices = torch.sigmoid(base + bias) * sign 193 | vertices = F.relu(vertices) * scale_pos - F.relu(-vertices) * scale_neg 194 | vertices = vertices + centroid 195 | vertices = vertices * 0.5 196 | faces = self.faces[None, :, :].repeat(batch_size, 1, 1) 197 | 198 | return vertices, faces 199 | 200 | 201 | class Normalize(nn.Module): 202 | def __init__(self, dim): 203 | super(Normalize, self).__init__() 204 | self.dim = dim 205 | def forward(self, x): 206 | return F.normalize(x, dim=self.dim) 207 | 208 | 209 | class ViewDisentangleNetwork(nn.Module): 210 | def __init__(self, opt): 211 | super().__init__() 212 | self.opt = opt 213 | self.feature_extractor = ResNet18Encoder(dim_in=opt.dim_in, pretrained=True) 214 | self.encoder = Encoder(dim_in=512 * 7 * 7, dim_hidden=2048, dim_s=1024, dim_v=opt.view_dim, normalize=True) 215 | self.view_encoder = ViewEncoder(dim_hidden=opt.view_dim, dim_out=opt.view_dim, normalize=True) 216 | self.view_decoder = ViewDecoder(dim_in=opt.view_dim, dim_hidden=opt.view_dim//2) 217 | self.decoder = Decoder(dim_in=1024, dim_shape=512, dim_view=opt.view_dim, dim_hidden=1024, normalize=True) 218 | 219 | self.shape_discriminator = nn.Sequential( 220 | GradientReversalLayer(lambda_=opt.grl_lambda), 221 | nn.Linear(opt.n_vertices * 3, 256), 222 | nn.ReLU(inplace=True), 223 | nn.Linear(256, 1) 224 | ) 225 | 226 | self.domain_discriminator = nn.Sequential( 227 | GradientReversalLayer(lambda_=opt.grl_lambda), 228 | nn.Linear(25088, 256), 229 | nn.ReLU(inplace=True), 230 | nn.Linear(256, 1) 231 | ) 232 | 233 | self.mesh_decoder = MeshDecoder(opt.template_path, dim_in=512) 234 | 235 | def forward(self, image, 
view=None, view_rand=None): 236 | N = image.shape[0] 237 | if view is None and view_rand is None: 238 | """Inference""" 239 | ft = self.feature_extractor(image[:,:self.opt.dim_in]) 240 | zs, zv_pred = self.encoder(ft) 241 | view_pred = self.view_decoder(zv_pred) 242 | zv_recon = self.view_encoder(view_pred) 243 | z = self.decoder(zs, zv_recon) 244 | vertices, faces = self.mesh_decoder(z) 245 | return { 246 | 'vertices': vertices, 247 | 'faces': faces, 248 | 'view_pred': view_pred 249 | } 250 | else: 251 | """Training / Validation / Testing""" 252 | ft = self.feature_extractor(image[:,:self.opt.dim_in]) 253 | zs, zv_pred = self.encoder(ft) 254 | view_pred = self.view_decoder(zv_pred) 255 | zv_recon = self.view_encoder(view_pred) 256 | zv = self.view_encoder(view) 257 | view_recon = self.view_decoder(zv) 258 | z = self.decoder(zs, zv) # teacher forcing 259 | z_pred = self.decoder(zs, zv_recon) 260 | vertices, faces = self.mesh_decoder(z) 261 | vertices_pred, faces_pred = self.mesh_decoder(z_pred) 262 | sd_score = self.shape_discriminator(vertices.view(N, -1)) 263 | rv = { 264 | 'vertices': vertices, 265 | 'faces': faces, 266 | 'vertices_pred': vertices_pred, 267 | 'faces_pred': faces_pred, 268 | 'view_pred': view_pred, 269 | 'view_recon': view_recon, 270 | 'zv_pred': zv_pred, 271 | 'zv': zv, 272 | 'zv_recon': zv_recon, 273 | 'sd_score': sd_score 274 | } 275 | 276 | if view_rand is not None: 277 | """Training""" 278 | zv_rand = self.view_encoder(view_rand) 279 | z_rand = self.decoder(zs, zv_rand) 280 | vertices_rand, faces_rand = self.mesh_decoder(z_rand) 281 | sd_score_rand = self.shape_discriminator(vertices_rand.view(N, -1)) 282 | rv.update({ 283 | 'vertices_rand': vertices_rand, 284 | 'faces_rand': faces_rand, 285 | 'zv_rand': zv_rand, 286 | 'sd_score_rand': sd_score_rand 287 | }) 288 | 289 | return rv 290 | 291 | 292 | class ViewDisentangleModel(BaseModel): 293 | @staticmethod 294 | def modify_commandline_options(parser): 295 | return parser 296 | 297 | def __init__(self, opt): 298 | BaseModel.__init__(self, opt) 299 | # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the tensorboard. 300 | self.train_loss_names = ['tot', 'iou_tot', 'iou_rand_tot', 'laplacian', 'flatten', 'view_pred', 'view_recon', 'zv_recon', 'sd'] 301 | self.val_loss_names = ['voxel_iou', 'voxel_iou_pred'] 302 | self.test_loss_names = ['voxel_iou', 'voxel_iou_pred'] 303 | 304 | self.model_names = ['Full'] 305 | self.netFull = init_net(ViewDisentangleNetwork(opt), opt) 306 | if self.isTrain: # only defined during training time 307 | # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss. 308 | self.criterions = { 309 | 'laplacian': LaplacianLoss(opt), 310 | 'flatten': FlattenLoss(opt), 311 | 'multiview-iou': MultiViewIoULoss(opt), 312 | 'iou': IoULoss(opt), 313 | 'mse': MSELoss(opt), 314 | 'gan': F.binary_cross_entropy_with_logits 315 | } 316 | # define and initialize optimizers. 
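        # a single Adam optimizer updates the whole ViewDisentangleNetwork (feature extractor, encoders, decoders and both discriminators);
        # the discriminators receive adversarial gradients through their GradientReversalLayer, so no separate discriminator optimizer is needed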
317 | self.optimizer = torch.optim.Adam(self.netFull.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 318 | self.optimizers = [self.optimizer] 319 | 320 | # larger scales correspond to smaller rendered images 321 | # render_size = image_size // render_scale 322 | self.render_scales = [1, 2, 4] 323 | self.sigmas = [1e-5, 3e-5, 1e-4] 324 | self.adaptive_weighting_func = [ 325 | lambda e: 1 if e > 1600 else 0, 326 | lambda e: 1 if 800 < e <= 1600 else 0, 327 | lambda e: 1 if e <= 800 else 0 328 | ] 329 | self.renderers = [ 330 | sr.SoftRasterizer( 331 | image_size=opt.image_size // scale, sigma_val=sigma, aggr_func_rgb='hard', aggr_func_alpha='prod', dist_eps=1e-10 332 | ) for (scale, sigma) in zip(self.render_scales, self.sigmas) 333 | ] 334 | 335 | def set_input(self, input): 336 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 337 | 338 | Parameters: 339 | input: a dictionary that contains the data itself and its metadata information. 340 | """ 341 | for k, v in input.items(): 342 | if isinstance(v, list): 343 | """Training""" 344 | setattr(self, f"data_{k}_a", v[0].to(self.device)) 345 | setattr(self, f"data_{k}_b", v[1].to(self.device)) 346 | else: 347 | """Validation / Testing / Inference""" 348 | setattr(self, f"data_{k}", v.to(self.device)) 349 | 350 | def update_hyperparameters(self, epoch): 351 | super().update_hyperparameters(epoch) 352 | self.adaptive_weighting = [f(epoch) for f in self.adaptive_weighting_func] 353 | 354 | def update_hyperparameters_step(self, step): 355 | super().update_hyperparameters_step(step) 356 | 357 | def get_random_view(self, N): 358 | """ 359 | Get random elevation angle from [-20, 40] 360 | Get random azimuth angle from [-180, 180] 361 | """ 362 | elevation_rand = (torch.rand(N, dtype=torch.float32) * 60 - 20) 363 | azimuth_rand = (torch.rand(N, dtype=torch.float32) * 360 - 180) 364 | elevation_rand, azimuth_rand = elevation_rand.to(self.device), azimuth_rand.to(self.device) 365 | return elevation_rand, azimuth_rand 366 | 367 | def encode_view(self, view): 368 | """ 369 | Project elevation angle from [-20, 40] to [0, 1] 370 | Project azimuth angle from [-180, 180] to [0, 1] 371 | """ 372 | view = view.clone() 373 | view[:,0] = (view[:,0] + 20) / 60. 374 | view[:,1] = (view[:,1] + 180) / 360. 375 | return view 376 | 377 | def decode_view(self, view): 378 | """ 379 | Un-project elevation angle from [0, 1] to [-20, 40] 380 | Un-project azimuth angle from [0, 1] to [-180, 180] 381 | """ 382 | view = view.clone() 383 | view[:,0] = (view[:,0] * 60) - 20. 384 | view[:,1] = (view[:,1] * 360) - 180. 385 | return view 386 | 387 | def view2camera(self, view): 388 | """ 389 | Calculate the camera position from given elevation and azimuth angles. 390 | The camera looks at the center of the object, with a distance of 2. 391 | """ 392 | N = view.shape[0] 393 | distance = torch.ones(N, dtype=torch.float32) * 2.
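        # the camera sits at a fixed distance of 2 from the object center; srf.get_points_from_angles converts (distance, elevation, azimuth) into a Cartesian eye position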
394 | distance = distance.to(self.device) 395 | camera = srf.get_points_from_angles(distance, view[:,0], view[:,1]) 396 | return camera 397 | 398 | def render_silhouette(self, v, f, camera, multiview=False): 399 | transform = sr.LookAt(viewing_angle=15, eye=camera) 400 | # only render when w > 0 to save time 401 | sil = [r(transform(sr.Mesh(v, f))) if w > 0 else None for r, w in zip(self.renderers, self.adaptive_weighting)] 402 | return [s.chunk(4, dim=0) if s is not None else None for s in sil] if multiview else sil 403 | 404 | def forward(self): 405 | """Run forward pass.""" 406 | N = self.data_image_a.shape[0] 407 | 408 | image_ab = torch.cat([self.data_image_a, self.data_image_b], dim=0) 409 | self.data_image_ab = image_ab 410 | camera_aabb = torch.cat([self.data_camera_a, self.data_camera_a, self.data_camera_b, self.data_camera_b], dim=0) 411 | 412 | view_a = torch.cat([self.data_elevation_a[:, None], self.data_azimuth_a[:, None]], dim=1) 413 | view_b = torch.cat([self.data_elevation_b[:, None], self.data_azimuth_b[:, None]], dim=1) 414 | 415 | elevation_a_rand, azimuth_a_rand = self.get_random_view(N) 416 | elevation_b_rand, azimuth_b_rand = self.get_random_view(N) 417 | 418 | view_a_rand = torch.cat([elevation_a_rand[:, None], azimuth_a_rand[:, None]], dim=1) 419 | view_b_rand = torch.cat([elevation_b_rand[:, None], azimuth_b_rand[:, None]], dim=1) 420 | camera_a_rand, camera_b_rand = self.view2camera(view_a_rand), self.view2camera(view_b_rand) 421 | camera_ab_rand = torch.cat([camera_a_rand, camera_b_rand], dim=0) 422 | 423 | view_a, view_b = self.encode_view(view_a), self.encode_view(view_b) 424 | view_a_rand, view_b_rand = self.encode_view(view_a_rand), self.encode_view(view_b_rand) 425 | 426 | view_ab = torch.cat([view_a, view_b], dim=0) 427 | view_ab_rand = torch.cat([view_a_rand, view_b_rand], dim=0) 428 | self.data_view_ab, self.data_view_ab_rand = view_ab, view_ab_rand 429 | 430 | out = self.netFull(image_ab, view=view_ab, view_rand=view_ab_rand) 431 | for k, v in out.items(): 432 | setattr(self, f"out_{k}", v) 433 | 434 | self.out_silhouette = self.render_silhouette( 435 | torch.cat([self.out_vertices, self.out_vertices], dim=0), 436 | torch.cat([self.out_faces, self.out_faces], dim=0), 437 | camera_aabb, 438 | multiview=True 439 | ) 440 | self.out_silhouette_rand = self.render_silhouette( 441 | self.out_vertices_rand, 442 | self.out_faces_rand, 443 | camera_ab_rand, 444 | multiview=False 445 | ) 446 | 447 | def forward_test(self): 448 | """Run forward pass for validation / testing.""" 449 | N = self.data_image.shape[0] 450 | self.data_view = torch.cat([self.data_elevation[:, None], self.data_azimuth[:, None]], dim=1) 451 | self.data_view = self.encode_view(self.data_view) 452 | out = self.netFull(self.data_image, view=self.data_view) 453 | for k, v in out.items(): 454 | setattr(self, f"out_{k}", v) 455 | 456 | def forward_inference(self): 457 | """Run forward pass for inference.""" 458 | if self.opt.view is None: 459 | self.data_view = None 460 | else: 461 | self.data_view = torch.cat([self.data_elevation[:, None], self.data_azimuth[:, None]], dim=1) 462 | self.data_view = self.encode_view(self.data_view) 463 | out = self.netFull(self.data_image, view=self.data_view) 464 | for k, v in out.items(): 465 | setattr(self, f"out_{k}", v) 466 | 467 | def backward(self): 468 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 469 | self.loss_laplacian, self.loss_laplacian_rand = self.criterions['laplacian'](self.out_vertices), 
self.criterions['laplacian'](self.out_vertices_rand) 470 | self.loss_flatten, self.loss_flatten_rand = self.criterions['flatten'](self.out_vertices), self.criterions['flatten'](self.out_vertices_rand) 471 | self.loss_view_pred = self.criterions['mse'](self.data_view_ab, self.out_view_pred) 472 | self.loss_view_recon = self.criterions['mse'](self.data_view_ab, self.out_view_recon) 473 | self.loss_zv_recon = self.criterions['mse'](self.out_zv_pred, self.out_zv_recon) 474 | self.loss_iou = [ 475 | self.criterions['multiview-iou']( 476 | sil, 477 | F.interpolate(self.data_image_a, (self.opt.image_size//scale, self.opt.image_size//scale), mode='nearest'), 478 | F.interpolate(self.data_image_b, (self.opt.image_size//scale, self.opt.image_size//scale), mode='nearest') 479 | ) * w if w > 0 else 0 for sil, w, scale in zip(self.out_silhouette, self.adaptive_weighting, self.render_scales) 480 | ] 481 | self.loss_iou_tot = sum(self.loss_iou) 482 | self.loss_iou_rand = [ 483 | self.criterions['iou']( 484 | sil, 485 | torch.cat([ 486 | F.interpolate(self.data_image_a, (self.opt.image_size//scale, self.opt.image_size//scale), mode='nearest'), 487 | F.interpolate(self.data_image_b, (self.opt.image_size//scale, self.opt.image_size//scale), mode='nearest') 488 | ], dim=0) 489 | ) * w if w > 0 else 0 for sil, w, scale in zip(self.out_silhouette_rand, self.adaptive_weighting, self.render_scales) 490 | ] 491 | self.loss_iou_rand_tot = sum(self.loss_iou_rand) 492 | 493 | self.loss_sd_real = self.criterions['gan'](self.out_sd_score, torch.ones_like(self.out_sd_score)) 494 | self.loss_sd_fake = self.criterions['gan'](self.out_sd_score_rand, torch.zeros_like(self.out_sd_score_rand)) 495 | self.loss_sd = self.loss_sd_real + self.loss_sd_fake 496 | 497 | self.loss_tot = self.loss_iou_tot + self.opt.lambda_iou_rand * self.loss_iou_rand_tot + \ 498 | self.opt.lambda_laplacian * (self.loss_laplacian + self.loss_laplacian_rand) + \ 499 | self.opt.lambda_flatten * (self.loss_flatten + self.loss_flatten_rand) + \ 500 | self.opt.lambda_view_pred * self.loss_view_pred + \ 501 | self.opt.lambda_view_recon * self.loss_view_recon + \ 502 | self.opt.lambda_zv_recon * self.loss_zv_recon + \ 503 | self.opt.lambda_sd * self.loss_sd 504 | 505 | self.loss_tot.backward() 506 | 507 | def optimize_parameters(self): 508 | """Update network weights; it will be called in every training iteration.""" 509 | self.forward() # first call forward to calculate intermediate results 510 | self.optimizer.zero_grad() # clear network G's existing gradients 511 | self.backward() # calculate gradients for network G 512 | self.optimizer.step() # update gradients for network G 513 | 514 | def visualize_train(self, it): 515 | """Save generated meshes every opt.vis_freq training steps.""" 516 | vis_save_dir = os.path.join(self.save_dir, 'train') 517 | os.makedirs(vis_save_dir, exist_ok=True) 518 | view = self.decode_view(self.data_view_ab) 519 | view_pred = self.decode_view(self.out_view_pred) 520 | view_rand = self.decode_view(self.data_view_ab_rand) 521 | image = self.data_image_ab[0] 522 | vt, f, vt_pred, f_pred, vt_rand, f_rand = self.out_vertices[0], self.out_faces[0], self.out_vertices_pred[0], self.out_faces_pred[0], self.out_vertices_rand[0], self.out_faces_rand[0] 523 | v, v_pred, v_rand = view[0], view_pred[0], view_rand[0] 524 | cv2.imwrite(os.path.join(vis_save_dir, f'{it:05d}_input.png'), tensor2im(image)[...,:3]) 525 | srf.save_obj(os.path.join(vis_save_dir, f'{it:05d}_e{int(v[0])}a{int(v[1])}.obj'), vt, f) 526 | 
srf.save_obj(os.path.join(vis_save_dir, f'{it:05d}_pred_e{int(v_pred[0])}a{int(v_pred[1])}.obj'), vt_pred, f_pred) 527 | srf.save_obj(os.path.join(vis_save_dir, f'{it:05d}_rand_e{int(v_rand[0])}a{int(v_rand[1])}.obj'), vt_rand, f_rand) 528 | 529 | def validate(self, epoch, dataset, phase='val', save_dir=None): 530 | """Validation procedure. Called every opt.val_epoch_freq epochs.""" 531 | count = 0 532 | iou_tot, iou_pred_tot = 0., 0. 533 | for i, data in enumerate(tqdm(dataset, desc=phase, total=len(dataset.dataloader))): 534 | self.set_input(data) 535 | self.forward_test() 536 | voxel_gt = self.data_voxel.cpu().numpy() 537 | faces = srf.face_vertices(self.out_vertices, self.out_faces) * 31. / 32. + 0.5 538 | voxel = srf.voxelization(faces, 32, False).cpu().numpy().transpose(0, 2, 1, 3)[...,::-1] 539 | faces_pred = srf.face_vertices(self.out_vertices_pred, self.out_faces_pred) * 31. / 32. + 0.5 540 | voxel_pred = srf.voxelization(faces_pred, 32, False).cpu().numpy().transpose(0, 2, 1, 3)[...,::-1] 541 | iou, iou_pred = voxel_iou(voxel, voxel_gt), voxel_iou(voxel_pred, voxel_gt) 542 | iou_tot += iou * self.data_image.shape[0] 543 | iou_pred_tot += iou_pred * self.data_image.shape[0] 544 | count += self.data_image.shape[0] 545 | 546 | if i < getattr(self.opt, f"{phase}_epoch_vis_n"): 547 | vis_save_dir = save_dir or os.path.join(self.save_dir, f"vis_{epoch}_{phase}") 548 | os.makedirs(vis_save_dir, exist_ok=True) 549 | view = self.decode_view(self.data_view) 550 | view_pred = self.decode_view(self.out_view_pred) 551 | image = self.data_image[0] 552 | vt, f, vt_pred, f_pred = self.out_vertices[0], self.out_faces[0], self.out_vertices_pred[0], self.out_faces_pred[0] 553 | v, v_pred = view[0], view_pred[0] 554 | cv2.imwrite(os.path.join(vis_save_dir, f'{i:02d}_input.png'), tensor2im(image)[...,:3]) 555 | srf.save_obj(os.path.join(vis_save_dir, f'{i:02d}_e{int(v[0])}a{int(v[1])}.obj'), vt, f) 556 | srf.save_obj(os.path.join(vis_save_dir, f'{i:02d}_pred_e{int(v_pred[0])}a{int(v_pred[1])}.obj'), vt_pred, f_pred) 557 | 558 | self.loss_voxel_iou, self.loss_voxel_iou_pred = iou_tot / count, iou_pred_tot / count 559 | 560 | def test(self, epoch, dataset, save_dir=None): 561 | """Testing procedure. Called every opt.test_epoch_freq epochs.""" 562 | self.validate(epoch, dataset, phase='test', save_dir=save_dir) 563 | 564 | def inference(self, epoch, dataset, save_dir): 565 | """Inference procedure.
Generate a 3D model from an input sketch and, optionally, a given view.""" 566 | data = next(iter(dataset)) 567 | self.set_input(data) 568 | self.forward_inference() 569 | image = self.data_image[0] 570 | cv2.imwrite(os.path.join(save_dir, 'input.png'), tensor2im(image)[...,:3]) 571 | if self.opt.view is None: 572 | v_pred = self.decode_view(self.out_view_pred)[0] 573 | vt, f = self.out_vertices[0], self.out_faces[0] 574 | srf.save_obj(os.path.join(save_dir, f'pred-view_e{int(v_pred[0])}a{int(v_pred[1])}.obj'), vt, f) 575 | else: 576 | v = self.decode_view(self.data_view)[0] 577 | v_pred = self.decode_view(self.out_view_pred)[0] 578 | vt, f, vt_pred, f_pred = self.out_vertices[0], self.out_faces[0], self.out_vertices_pred[0], self.out_faces_pred[0] 579 | srf.save_obj(os.path.join(save_dir, f'given-view_e{int(v[0])}a{int(v[1])}.obj'), vt, f) 580 | srf.save_obj(os.path.join(save_dir, f'pred-view_e{int(v_pred[0])}a{int(v_pred[1])}.obj'), vt_pred, f_pred) 581 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | class Configurable(object): 5 | @staticmethod 6 | def modify_commandline_options(parser): 7 | """Add new options, and rewrite default values for existing options. 8 | 9 | Parameters: 10 | parser -- original option parser 11 | 12 | Returns: 13 | the modified parser. 14 | """ 15 | return parser 16 | 17 | 18 | def get_option_setter(obj): 19 | if issubclass(obj, Configurable): 20 | return obj.modify_commandline_options 21 | print('Class', obj, 'is not a subclass of Configurable, hence unable to modify commandline arguments.') 22 | 23 | 24 | def str2bool(v): 25 | if isinstance(v, bool): 26 | return v 27 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 28 | return True 29 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 30 | return False 31 | else: 32 | raise argparse.ArgumentTypeError('Boolean value expected.') 33 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from models.base_model import BaseModel 3 | from data.base_dataset import BaseDataset 4 | import os 5 | from utils import utils 6 | import torch 7 | import options 8 | import json 9 | from utils.utils import find_class_using_name 10 | 11 | 12 | class BaseOptions(): 13 | """This class defines options used during both training and test time. 14 | 15 | It also implements several helper functions such as parsing, printing, and saving the options. 16 | It also gathers additional options defined in functions in both dataset class and model class. 17 | """ 18 | 19 | def __init__(self): 20 | """Reset the class; indicates the class hasn't been initialized""" 21 | self.initialized = False 22 | 23 | def initialize(self, parser): 24 | """Define the common options that are used in both training and test.""" 25 | # basic parameters 26 | parser.add_argument('--name', type=str, required=True, help='name of the experiment.
It decides where to store samples and models') 27 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 28 | parser.add_argument('--summary_dir', type=str, default='./runs', help='tensorboard logs are saved here') 29 | parser.add_argument('--seed', type=int, default=0) 30 | parser.add_argument('--class_id', type=str, required=True) 31 | 32 | # model parameters 33 | parser.add_argument('--model', type=str, default='view_disentangle', choices=['view_disentangle'], help='chooses which model to use.') 34 | parser.add_argument('--dim_in', type=int, default=3, help='number of input channels for image feature extractor') 35 | parser.add_argument('--grl_lambda', type=float, default=1, help='lambda in gradient reversal layer') 36 | parser.add_argument('--n_vertices', type=int, default=642, help='number of vertices of the base mesh') 37 | parser.add_argument('--image_size', type=int, default=224, help='input image size') 38 | parser.add_argument('--view_dim', type=int, default=512, help='dimension of the view latent code') 39 | parser.add_argument('--template_path', type=str, default='templates/sphere_642.obj', help='path to the base mesh') 40 | 41 | # dataset parameters 42 | parser.add_argument('--dataset_mode', type=str, default='shapenet', choices=['shapenet', 'shapenet_sketch', 'inference'], help='chooses how datasets are loaded.') 43 | parser.add_argument('--dataset_root', type=str, default='load/shapenet-synthetic') 44 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 45 | parser.add_argument('--batch_size', type=int, default=64, help='input batch size') 46 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 47 | 48 | # additional parameters 49 | parser.add_argument('--phase', type=str, choices=['train', 'test', 'infer']) 50 | parser.add_argument('--load_epoch', type=str, default='latest', help='epoch to load') 51 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 52 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{batch_size}') 53 | self.initialized = True 54 | return parser 55 | 56 | def gather_options(self): 57 | """Initialize our parser with basic options(only once). 58 | Add additional model-specific and dataset-specific options. 59 | These options are defined in the function 60 | in model and dataset classes. 
61 | """ 62 | if not self.initialized: # check if it has been initialized 63 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 64 | parser = self.initialize(parser) 65 | 66 | # get the basic options 67 | opt, _ = parser.parse_known_args() 68 | 69 | # modify model-related parser options 70 | model_name = opt.model 71 | model_option_setter = options.get_option_setter(find_class_using_name(f"models.{model_name}_model", model_name, 'model', BaseModel)) 72 | parser = model_option_setter(parser) 73 | 74 | # modify dataset-related parser options 75 | dataset_name = opt.dataset_mode 76 | dataset_option_setter = options.get_option_setter(find_class_using_name(f"data.{dataset_name}_dataset", dataset_name, 'dataset', BaseDataset)) 77 | parser = dataset_option_setter(parser) 78 | 79 | # save and return the parser 80 | self.parser = parser 81 | return parser.parse_args() 82 | 83 | def print_options(self, opt): 84 | """Print and save options 85 | 86 | It will print both current options and default values(if different). 87 | It will save options into a text file -- [checkpoints_dir] / opt.txt 88 | It will save options into a json file -- [checkpoints_dir] / opt.json 89 | """ 90 | message = '' 91 | message += '----------------- Options ---------------\n' 92 | opt_dict = {} 93 | for k, v in sorted(vars(opt).items()): 94 | comment = '' 95 | default = self.parser.get_default(k) 96 | if v != default: 97 | comment = '\t[default: %s]' % str(default) 98 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 99 | opt_dict[k] = v 100 | message += '----------------- End -------------------' 101 | print(message) 102 | 103 | # save to the disk 104 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 105 | utils.mkdirs(expr_dir) 106 | with open(os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)), 'wt') as opt_file: 107 | opt_file.write(message) 108 | opt_file.write('\n') 109 | with open(os.path.join(expr_dir, '{}_opt.json'.format(opt.phase)), 'w') as opt_json_file: 110 | opt_json_file.write(json.dumps(opt_dict)) 111 | 112 | def parse(self): 113 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 114 | opt = self.gather_options() 115 | opt.isTrain = self.isTrain # train or test 116 | opt.isTest = self.isTest 117 | opt.isInfer = self.isInfer 118 | 119 | # process opt.suffix 120 | if opt.suffix: 121 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 122 | opt.name = opt.name + suffix 123 | 124 | opt.n_gpus = torch.cuda.device_count() 125 | 126 | opt.device = 'cuda:0' if opt.n_gpus > 0 else 'cpu' 127 | 128 | if opt.n_gpus > 0: 129 | torch.cuda.set_device(opt.device) 130 | 131 | self.print_options(opt) 132 | 133 | self.opt = opt 134 | return self.opt 135 | -------------------------------------------------------------------------------- /options/infer_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | from datetime import datetime 3 | 4 | 5 | class InferOptions(BaseOptions): 6 | """This class includes inference options. 7 | 8 | It also includes shared options defined in BaseOptions. 
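    Typical runs provide --image_path pointing at the input sketch and, optionally, --view ELEVATION AZIMUTH; results are saved under --results_dir.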
9 | """ 10 | 11 | def initialize(self, parser): 12 | parser = BaseOptions.initialize(self, parser) # define shared options 13 | parser.set_defaults(phase='infer', dataset_mode='inference') 14 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 15 | 16 | parser.add_argument('--data_name', type=str, default=datetime.now().strftime("%Y%m%d%H%M%S"), help='identifier to distinguish different runs') 17 | parser.add_argument('--image_path', type=str, required=True, help='path to input image') 18 | parser.add_argument('--view', type=float, nargs=2, required=False, help='specified view, in the format of [elevation azimuth]') 19 | 20 | self.isTrain, self.isTest, self.isInfer = False, False, True 21 | return parser 22 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.set_defaults(phase='test', batch_size=1) 13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 14 | parser.add_argument('--test_split', type=str, default='test', help='which split to evaluate on') 15 | parser.add_argument('--test_epoch_vis_n', type=int, default=20, help='number of data to visualize') 16 | self.isTrain, self.isTest, self.isInfer = False, True, False 17 | return parser 18 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | """This class includes training options. 6 | 7 | It also includes shared options defined in BaseOptions. 
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | parser.set_defaults(phase='train') 13 | 14 | parser.add_argument('--lambda_iou_rand', type=float, default=0.1) 15 | parser.add_argument('--lambda_laplacian', type=float, default=5e-3) 16 | parser.add_argument('--lambda_flatten', type=float, default=5e-4) 17 | parser.add_argument('--lambda_view_pred', type=float, default=10) 18 | parser.add_argument('--lambda_view_recon', type=float, default=10) 19 | parser.add_argument('--lambda_zv_recon', type=float, default=100) 20 | parser.add_argument('--lambda_sd', type=float, default=0.1) 21 | 22 | # visdom and HTML visualization parameters 23 | parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console') 24 | parser.add_argument('--vis_freq', type=int, default=100, help='training visualization frequency, in steps') 25 | parser.add_argument('--val_epoch_freq', type=int, default=50, help='validation frequency, in epochs') 26 | parser.add_argument('--val_epoch_vis_n', type=int, default=20, help='number of data to visualize in validation') 27 | parser.add_argument('--test_epoch_freq', type=int, default=2000, help='testing frequency, in epochs') 28 | parser.add_argument('--test_epoch_vis_n', type=int, default=20, help='number of data to visualize in testing') 29 | 30 | # network saving and loading parameters 31 | parser.add_argument('--save_epoch_freq', type=int, default=500, help='frequency of saving checkpoints at the end of epochs') 32 | parser.add_argument('--continue_train', action='store_true', help='continue training: load model at epoch [load_epoch]') 33 | parser.add_argument('--init_weights', type=str, default=None, help='initialize weights from an existing model, in format [name]:[epoch]') 34 | parser.add_argument('--init_weights_keys', type=str, default='.+', help='regex for weights keys to be loaded') 35 | parser.add_argument('--fix_layers', type=str, default=None, help='regex for fix layers') 36 | 37 | # training parameters 38 | parser.add_argument('--n_epochs', type=int, default=2000, help='number of epochs in total') 39 | parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam') 40 | parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate for adam') 41 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | exp | step]') 42 | 43 | # linear | exp policy 44 | parser.add_argument('--lr_final', type=float, default=1e-5, help='final learning rate for adam, used in linear and exp') 45 | parser.add_argument('--n_epochs_decay', type=int, default=1000, help='number of epochs to decay learning rate to lr_final') 46 | 47 | # step policy 48 | parser.add_argument('--lr_decay_epochs', type=int, default=800, help='multiply by a gamma every lr_decay_epochs epochs, used in step') 49 | parser.add_argument('--lr_decay_gamma', type=float, default=0.3, help='multiply by a gamma every lr_decay_epochs epochs, used in step') 50 | 51 | self.isTrain, self.isTest, self.isInfer = True, False, False 52 | return parser 53 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | numpy 3 | scipy 4 | tqdm 5 | opencv-python 6 | tensorboard 7 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from options.test_options import TestOptions 3 | from data import create_dataset 4 | from models import create_model 5 | import torch 6 | 7 | 8 | if __name__ == '__main__': 9 | opt = TestOptions().parse() 10 | 11 | dataset_test = create_dataset(opt, mode=opt.test_split, shuffle=False) 12 | 13 | print(f'The number of test data = {len(dataset_test)}') 14 | 15 | model = create_model(opt) 16 | current_epoch = model.setup(opt) 17 | 18 | out_dir = os.path.join(opt.results_dir, opt.name, '{}-{}_{}'.format(opt.dataset_mode, opt.test_split, current_epoch)) 19 | print('creating out directory', out_dir) 20 | os.makedirs(out_dir, exist_ok=True) 21 | 22 | model.eval() 23 | with torch.no_grad(): 24 | model.test(current_epoch, dataset_test, save_dir=out_dir) 25 | test_losses = model.get_current_losses('test') 26 | print("Test losses |", ' '.join([f"{k}: {v:.3e}" for k, v in test_losses.items()])) 27 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from options.train_options import TrainOptions 4 | from data import create_dataset 5 | from models import create_model 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | if __name__ == '__main__': 11 | opt = TrainOptions().parse() 12 | 13 | torch.manual_seed(opt.seed) 14 | torch.cuda.manual_seed_all(opt.seed) 15 | 16 | dataset_train = create_dataset(opt, mode='train', shuffle=True) 17 | dataset_val = create_dataset(opt, mode='val', shuffle=False) 18 | dataset_test = create_dataset(opt, mode='test', shuffle=False) 19 | 20 | print(f'The number of training data = {len(dataset_train)}') 21 | print(f'The number of validation data = {len(dataset_val)}') 22 | print(f'The number of test data = {len(dataset_test)}') 23 | 24 | model = create_model(opt) 25 | writer = SummaryWriter(os.path.join(opt.summary_dir, opt.name)) 26 | current_epoch = model.setup(opt) 27 | total_iters = current_epoch * len(dataset_train.dataloader) 28 | 29 | for epoch in range(current_epoch + 1, opt.n_epochs + 1): 30 | epoch_start_time = time.time() 31 | iter_data_time = time.time() 32 | epoch_iter = 0 33 | print('Learning rate:', f"{model.get_learning_rate():.3e}") 34 | model.update_hyperparameters(epoch) 35 | for i, data in enumerate(dataset_train): 36 | iter_start_time 
= time.time() 37 | total_iters += 1 38 | epoch_iter += 1 39 | model.update_hyperparameters_step(total_iters) 40 | if total_iters % opt.print_freq == 0: 41 | t_data = iter_start_time - iter_data_time 42 | 43 | model.train() 44 | model.set_input(data) 45 | model.optimize_parameters() 46 | 47 | if total_iters % opt.vis_freq == 0: 48 | model.visualize_train(total_iters) 49 | 50 | if total_iters % opt.print_freq == 0: 51 | losses = model.get_current_losses('train') 52 | t_comp = time.time() - iter_start_time 53 | for loss_name, loss_val in losses.items(): 54 | writer.add_scalar(f"train_{loss_name}", loss_val, global_step=total_iters) 55 | print(f"Epoch {epoch} - Iteration {epoch_iter}/{len(dataset_train.dataloader)} (comp time {t_comp:.3f}, data time {t_data:.3f})") 56 | print("Training losses |", ' '.join([f"{k}: {v:.3e}" for k, v in losses.items()])) 57 | 58 | iter_data_time = time.time() 59 | 60 | if epoch % opt.val_epoch_freq == 0: 61 | model.eval() 62 | with torch.no_grad(): 63 | model.validate(epoch, dataset_val, phase='val') 64 | val_losses = model.get_current_losses('val') 65 | for loss_name, loss_val in val_losses.items(): 66 | writer.add_scalar(f"val_{loss_name}", loss_val, global_step=epoch) 67 | print("Validation losses |", ' '.join([f"{k}: {v:.3e}" for k, v in val_losses.items()])) 68 | 69 | if epoch % opt.test_epoch_freq == 0: 70 | model.eval() 71 | with torch.no_grad(): 72 | model.test(epoch, dataset_test) 73 | test_losses = model.get_current_losses('test') 74 | for loss_name, loss_test in test_losses.items(): 75 | writer.add_scalar(f"test_{loss_name}", loss_test, global_step=epoch) 76 | print("Test losses |", ' '.join([f"{k}: {v:.3e}" for k, v in test_losses.items()])) 77 | 78 | if epoch % opt.save_epoch_freq == 0: 79 | print('Saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 80 | model.save_networks(epoch) 81 | model.save_networks('latest') 82 | 83 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs, time.time() - epoch_start_time)) 84 | 85 | model.update_learning_rate() 86 | 87 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bennyguo/sketch2model/c1431c22e33889595bfac9e693b889fc000af6ec/utils/__init__.py -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | import pickle 8 | import importlib 9 | from collections import defaultdict 10 | 11 | 12 | def tensor2im(tensor): 13 | """"Converts a Tensor array into a numpy image array. 
14 | 15 | Parameters: 16 | input_image (tensor) -- the input image tensor array 17 | imtype (type) -- the desired type of the converted numpy array 18 | """ 19 | tensor = tensor.cpu().numpy() 20 | im = (tensor.transpose(1, 2, 0).clip(0, 1) * 255).astype(np.uint8) 21 | return im 22 | 23 | 24 | def diagnose_network(net, name='network'): 25 | """Calculate and print the mean of average absolute(gradients) 26 | 27 | Parameters: 28 | net (torch network) -- Torch network 29 | name (str) -- the name of the network 30 | """ 31 | mean = 0.0 32 | count = 0 33 | for param in net.parameters(): 34 | if param.grad is not None: 35 | mean += torch.mean(torch.abs(param.grad.data)) 36 | count += 1 37 | if count > 0: 38 | mean = mean / count 39 | print(name) 40 | print(mean) 41 | 42 | 43 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 44 | """Save a numpy image to the disk 45 | 46 | Parameters: 47 | image_numpy (numpy array) -- input numpy array 48 | image_path (str) -- the path of the image 49 | """ 50 | 51 | image_pil = Image.fromarray(image_numpy) 52 | h, w, _ = image_numpy.shape 53 | 54 | if aspect_ratio > 1.0: 55 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 56 | if aspect_ratio < 1.0: 57 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 58 | image_pil.save(image_path) 59 | 60 | 61 | def print_numpy(x, val=True, shp=False): 62 | """Print the mean, min, max, median, std, and size of a numpy array 63 | 64 | Parameters: 65 | val (bool) -- if print the values of the numpy array 66 | shp (bool) -- if print the shape of the numpy array 67 | """ 68 | x = x.astype(np.float64) 69 | if shp: 70 | print('shape,', x.shape) 71 | if val: 72 | x = x.flatten() 73 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 74 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 75 | 76 | 77 | def mkdirs(paths): 78 | """create empty directories if they don't exist 79 | 80 | Parameters: 81 | paths (str list) -- a list of directory paths 82 | """ 83 | if isinstance(paths, list) and not isinstance(paths, str): 84 | for path in paths: 85 | mkdir(path) 86 | else: 87 | mkdir(paths) 88 | 89 | 90 | def mkdir(path): 91 | """create a single empty directory if it didn't exist 92 | 93 | Parameters: 94 | path (str) -- a single directory path 95 | """ 96 | if not os.path.exists(path): 97 | os.makedirs(path) 98 | 99 | 100 | def one_hot(n, idx): 101 | x = torch.zeros(n, dtype=torch.float32) 102 | if n > 0: 103 | x[idx] = 1. 
104 | return x 105 | 106 | 107 | def batch_one_hot(n, idxs): 108 | return torch.stack([ 109 | one_hot(n, idx) for idx in idxs 110 | ], dim=0) 111 | 112 | 113 | def load_pickle(f): 114 | return pickle.load(open(f, 'rb'), encoding='latin1') 115 | 116 | 117 | def save_pickle(obj, f): 118 | pickle.dump(obj, open(f, 'wb')) 119 | 120 | 121 | def chunk_batch(func, chunk_size, *args, **kwargs): 122 | B = None 123 | for arg in args: 124 | if isinstance(arg, torch.Tensor): 125 | B = arg.shape[0] 126 | break 127 | out = defaultdict(list) 128 | out_dict = False 129 | for i in range(0, B, chunk_size): 130 | out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs) 131 | if isinstance(out_chunk, torch.Tensor): 132 | out_chunk = {0: out_chunk} 133 | out_dict = False 134 | elif isinstance(out_chunk, dict): 135 | out_dict = True 136 | else: 137 | print(f'Return value of func must be of type [torch.Tensor, dict], got {type(out_chunk)}.') 138 | exit(1) 139 | for k, v in out_chunk.items(): 140 | out[k].append(v) 141 | 142 | out = {k: torch.cat(v, dim=0) for k, v in out.items()} 143 | return out if out_dict else out[0] 144 | 145 | 146 | def find_class_using_name(module_name, class_name, suffix='', type=object): 147 | """Import the module and look up the target class. 148 | 149 | In the module, the class whose name matches class_name + suffix 150 | (case-insensitive) will be returned. It has to be 151 | a subclass of type. 152 | """ 153 | filename = module_name 154 | lib = importlib.import_module(filename) 155 | target = None 156 | target_class_name = class_name.replace('_', '') + suffix 157 | for name, cls in lib.__dict__.items(): 158 | if name.lower() == target_class_name.lower() \ 159 | and issubclass(cls, type): 160 | target = cls 161 | 162 | if target is None: 163 | print(f"In {filename}.py, there should be a subclass of {type} with class name that matches {target_class_name} in lowercase.") 164 | exit(0) 165 | 166 | return target 167 | --------------------------------------------------------------------------------
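A minimal sketch of how the gradient reversal layer defined in models/view_disentangle_model.py behaves, assuming the module is importable (it pulls in soft_renderer at import time, so the SoftRas submodule must be built first): the forward pass is the identity, while upstream gradients are multiplied by -lambda_ during backpropagation.

import torch
from models.view_disentangle_model import GradientReversalLayer  # requires soft_renderer (SoftRas) to be installed

grl = GradientReversalLayer(lambda_=1.0)
x = torch.randn(4, 8, requires_grad=True)
y = grl(x)             # forward: identity (a clone of x)
y.sum().backward()     # upstream gradient is all ones
print(torch.allclose(x.grad, -torch.ones_like(x)))  # True: gradients were negated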