├── .gitignore ├── README.md ├── __init__.py ├── config.py ├── core ├── .gitignore ├── README.md ├── __init__.py ├── annotations │ ├── .gitignore │ ├── __init__.py │ ├── benchmark.py │ ├── datasets.py │ ├── expert_verified.py │ ├── io_util.py │ ├── path.py │ ├── points.py │ ├── points_label.py │ └── segment.py ├── blender_renderings │ ├── .gitignore │ ├── README.md │ ├── __init__.py │ ├── config.py │ ├── path.py │ └── scripts │ │ ├── blender_render.py │ │ ├── check_archive.py │ │ ├── create_archive.py │ │ ├── render_cat.py │ │ └── vis.py ├── fixed_objs.py ├── frustrum_voxels │ ├── __init__.py │ └── scripts │ │ ├── check_frustrum_voxels.py │ │ └── create_frustrum_voxels.py ├── meshes │ ├── __init__.py │ └── scripts │ │ ├── check_mesh_data.py │ │ ├── generate_mesh_data.py │ │ └── remove_empty_meshes.py ├── objs.py ├── path.py ├── point_clouds │ ├── __init__.py │ └── scripts │ │ └── generate_point_clouds.py ├── renderings │ ├── __init__.py │ ├── archive.py │ ├── archive_manager.py │ ├── path.py │ ├── renderings_manager.py │ └── scripts │ │ ├── archive.py │ │ ├── blender_render.py │ │ ├── create_base_renderings.py │ │ └── report.py ├── scripts │ └── create_ids.py ├── views │ ├── __init__.py │ ├── archive.py │ ├── base.py │ ├── h5.py │ ├── lazy.py │ ├── manager.py │ ├── scripts │ │ ├── check_consistent.py │ │ └── create_base_data.py │ └── txt.py └── voxels │ ├── .gitignore │ ├── __init__.py │ ├── concat_ds.py │ ├── config.py │ ├── datasets.py │ ├── filled.py │ ├── path.py │ └── scripts │ ├── cat.py │ ├── create_all.py │ ├── create_dataset.py │ └── example.py ├── default_config.yaml ├── example ├── core │ ├── benchmark_frustrum.py │ ├── blender_renderings.py │ ├── clip_space_voxels.py │ ├── filled_voxels.py │ ├── frust_saved.py │ ├── frustrum.py │ ├── meshes.py │ ├── objs.py │ ├── point_clouds.py │ ├── train_test_split.py │ ├── vox_ds │ │ ├── compare_compression.py │ │ └── compare_formats.py │ └── voxels.py ├── iccv17 │ └── voxels.py └── r2n2 │ ├── angle_dist.py │ ├── binvox.py │ ├── hdf5.py │ └── renderings.py ├── iccv17 ├── README.md ├── __init__.py ├── path.py └── voxels │ └── __init__.py ├── image.py ├── path.py ├── r2n2 ├── .gitignore ├── README.md ├── __init__.py ├── hdf5.py ├── path.py ├── scripts │ ├── download.py │ ├── extract_renderings.py │ └── hdf5.py ├── split.py └── tgz.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | config.yaml 4 | data 5 | .vscode 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository provides python loading, manipulation and caching functions for interacting with the [ShapeNet](https://www.shapenet.org/) dataset. 2 | 3 | Dependencies: 4 | * [Dictionary Interface to Datasets (DIDS)](https://github.com/jackd/dids) 5 | * [util3d](https://github.com/jackd/util3d) 6 | 7 | # Setup 8 | 1. Install pip dependencies 9 | ```bash 10 | pip install numpy h5py progress wget 11 | ``` 12 | 2. Clone relevant repositories 13 | ```bash 14 | cd /path/to/parent_dir 15 | git clone https://github.com/jackd/dids.git 16 | git clone https://github.com/jackd/util3d.git 17 | git clone https://github.com/jackd/shapenet.git 18 | ``` 19 | 3. Add parent directory to `PYTHONPATH` 20 | ```bash 21 | export PYTHONPATH=$PYTHONPATH:/path/to/parent_dir 22 | ``` 23 | Consider adding this to your `~/.bashrc` file if you do not want to call it for each new terminal. 24 | 4. 
Copy `default_config.yaml` to `config.yaml` and edit the values to match where your data is stored. See the comments in `default_config.yaml` for details. 25 | ```bash 26 | cd shapenet 27 | cp default_config.yaml config.yaml 28 | gedit config.yaml # make changes 29 | ``` 30 | 31 | ## Core Dataset 32 | We cannot provide data for the core dataset - see the [dataset website](https://www.shapenet.org/) for access. 33 | 34 | Assign the location of your ShapeNet dataset to the `SHAPENET_CORE_PATH` environment variable, 35 | ``` 36 | export SHAPENET_CORE_PATH=/path/to/ShapeNetCore.v1 37 | ``` 38 | This directory should contain the zip files for each category, e.g. `02691156.zip` should contain all `obj` files for planes. 39 | 40 | ## ICCV 2017 Competition Dataset 41 | To use functions associated with the ICCV2017 competition dataset, after downloading the data, 42 | ``` 43 | export SHAPENET17_PATH=/path/to/shapenet2017/dataset 44 | ``` 45 | Currently this only supports uncompressed data. This directory should contain the `train_imgs`, `test_images`, `train_voxels`, `val_imgs` and `val_voxels` directories as provided by the dataset. This may change in the future to support compressed versions. 46 | 47 | # Data 48 | Most data is provided via a [DIDS](https://github.com/jackd/dids) `Dataset`. A number of these datasets are saved to disk to avoid repeated calculations. They are computed and saved on first use, but you can also force the preprocessing manually. For example, the following will generate meshes, point clouds and voxels for the plane category in the core dataset with default arguments. 49 | 50 | ``` 51 | cd core 52 | cd meshes/scripts 53 | python generate_mesh_data.py plane # parses objs to hdf5 vertices/faces 54 | cd ../../point_clouds/scripts 55 | python generate_point_clouds.py plane # samples faces of meshes 56 | cd ../../voxels/scripts 57 | python create_voxels.py plane # creates voxels in binvox format 58 | python create_archive.py plane # zip the binvox files created above 59 | ``` 60 | 61 | Resulting datasets can be used much like dictionaries, though they must be explicitly opened or used in a `with` block. 62 | 63 | ``` 64 | import numpy as np 65 | from shapenet.core.meshes import get_mesh_dataset 66 | from shapenet.core import cat_desc_to_id 67 | from util3d.mayavi_vis import vis_mesh 68 | from mayavi import mlab 69 | 70 | desc = 'plane' 71 | 72 | cat_id = cat_desc_to_id(desc) 73 | with get_mesh_dataset(cat_id) as mesh_dataset: 74 | for example_id in mesh_dataset: 75 | example_group = mesh_dataset[example_id] 76 | vertices, faces = ( 77 | np.array(example_group[k]) for k in ('vertices', 'faces')) 78 | vis_mesh(vertices, faces, color=(0, 0, 1), axis_order='xzy') 79 | mlab.show() 80 | ``` 81 | 82 | See [examples](https://github.com/jackd/shapenet/tree/master/example) for more.
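
Category descriptors, example ids and the SHREC16 train/val/test split are exposed as top-level helpers in `shapenet.core`. A minimal sketch (the printed counts are illustrative; the first call to `get_example_ids` reads ids from the category zip, so `SHAPENET_CORE_PATH` must be set, and `get_test_train_split` downloads the split csv on first use):

```
from shapenet.core import cat_desc_to_id, get_example_ids
from shapenet.core import get_test_train_split

cat_id = cat_desc_to_id('plane')        # synset id '02691156'
example_ids = get_example_ids(cat_id)   # sorted tuple of example ids
print('%d plane examples' % len(example_ids))

# dict keyed by cat_id with 'train' / 'test' / 'val' lists of example ids
split = get_test_train_split()
print('%d plane training examples' % len(split[cat_id]['train']))
```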
83 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackd/shapenet/4ae662743b0e5d0bd4d96f224c96be811149b9eb/__init__.py -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import yaml 7 | 8 | root_dir = os.path.realpath(os.path.dirname(__file__)) 9 | _config_path = os.path.join(root_dir, 'config.yaml') 10 | 11 | 12 | def _create_config(): 13 | import shutil 14 | default_config_path = os.path.join(root_dir, 'default_config.yaml') 15 | if not os.path.isfile(default_config_path): 16 | raise IOError( 17 | 'No file at default_config_path "%s"' % default_config_path) 18 | shutil.copyfile(default_config_path, _config_path) 19 | 20 | 21 | if not os.path.isfile(_config_path): 22 | _create_config() 23 | 24 | 25 | def load_config(): 26 | with open(_config_path, 'r') as fp: 27 | config = yaml.load(fp) 28 | return config 29 | 30 | 31 | config = load_config() 32 | -------------------------------------------------------------------------------- /core/.gitignore: -------------------------------------------------------------------------------- 1 | split.csv 2 | -------------------------------------------------------------------------------- /core/README.md: -------------------------------------------------------------------------------- 1 | # General notes 2 | * Every plane example has `top = [0, 0, 1]` 3 | * Front values vary - mostly aligned with major axis, though not all 4 | * Despite varying front values, all look like they're facing the same way -- no need to rotate 5 | * Scale looks consistent -- no need to rescale 6 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import six 7 | from . 
import path 8 | 9 | _cat_descs = { 10 | '02691156': 'plane', 11 | '02773838': 'bag', 12 | '02801938': 'basket', 13 | '02808440': 'bathtub', 14 | '02818832': 'bed', 15 | '02828884': 'bench', 16 | '02834778': 'bicycle', 17 | '02843684': 'birdhouse', 18 | '02871439': 'bookshelf', 19 | '02876657': 'bottle', 20 | '02880940': 'bowl', 21 | '02924116': 'bus', 22 | '02933112': 'cabinet', 23 | '02747177': 'can', 24 | '02942699': 'camera', 25 | '02954340': 'cap', 26 | '02958343': 'car', 27 | '03001627': 'chair', 28 | '03046257': 'clock', 29 | '03207941': 'dishwasher', 30 | '03211117': 'monitor', 31 | '04379243': 'table', 32 | '04401088': 'telephone', 33 | '02946921': 'tin_can', 34 | '04460130': 'tower', 35 | '04468005': 'train', 36 | '03085013': 'keyboard', 37 | '03261776': 'earphone', 38 | '03325088': 'faucet', 39 | '03337140': 'file', 40 | '03467517': 'guitar', 41 | '03513137': 'helmet', 42 | '03593526': 'jar', 43 | '03624134': 'knife', 44 | '03636649': 'lamp', 45 | '03642806': 'laptop', 46 | '03691459': 'speaker', 47 | '03710193': 'mailbox', 48 | '03759954': 'microphone', 49 | '03761084': 'microwave', 50 | '03790512': 'motorcycle', 51 | '03797390': 'mug', 52 | '03928116': 'piano', 53 | '03938244': 'pillow', 54 | '03948459': 'pistol', 55 | '03991062': 'pot', 56 | '04004475': 'printer', 57 | '04074963': 'remote_control', 58 | '04090263': 'rifle', 59 | '04099429': 'rocket', 60 | '04225987': 'skateboard', 61 | '04256520': 'sofa', 62 | '04330267': 'stove', 63 | '04530566': 'watercraft', 64 | '04554684': 'washer', 65 | '02858304': 'boat', 66 | '02992529': 'cellphone' 67 | } 68 | 69 | 70 | def get_cat_ids(): 71 | cat_ids = list(_cat_descs.keys()) 72 | cat_ids.sort() 73 | return tuple(cat_ids) 74 | 75 | 76 | _cat_ids = {v: k for k, v in _cat_descs.items()} 77 | 78 | 79 | def get_test_train_split(): 80 | import os 81 | split_path = path.get_test_train_split_path() 82 | if not os.path.isfile(split_path): 83 | import wget 84 | url = ('http://shapenet.cs.stanford.edu/shapenet/obj-zip/SHREC16/' 85 | 'all.csv') 86 | wget.download(url, split_path) 87 | if not os.path.isfile(split_path): 88 | raise IOError('Failed to download test/train split from %s' % url) 89 | 90 | split = {k: {'train': [], 'test': [], 'val': []} for k in get_cat_ids()} 91 | with open(split_path, 'r') as fp: 92 | fp.readline() 93 | for line in fp.readlines(): 94 | line = line.rstrip() 95 | if len(line) > 0: 96 | _, cat_id, __, example_id, ds = line.split(',') 97 | split[cat_id][ds].append(example_id) 98 | return split 99 | 100 | 101 | def cat_id_to_desc(cat_id): 102 | if isinstance(cat_id, (list, tuple)): 103 | return tuple(_cat_descs[c] for c in cat_id) 104 | else: 105 | return _cat_descs[cat_id] 106 | 107 | 108 | def cat_desc_to_id(cat_desc): 109 | if isinstance(cat_desc, (list, tuple)): 110 | return tuple(_cat_ids[c] for c in cat_desc) 111 | else: 112 | return _cat_ids[cat_desc] 113 | 114 | 115 | def to_cat_desc(cat): 116 | if cat in _cat_descs: 117 | return _cat_descs[cat] 118 | elif cat in _cat_ids: 119 | return cat 120 | else: 121 | raise ValueError('cat %s is not a valid id or descriptor' % cat) 122 | 123 | 124 | def to_cat_id(cat): 125 | if cat in _cat_ids: 126 | return _cat_ids[cat] 127 | elif cat in _cat_descs: 128 | return cat 129 | else: 130 | raise ValueError('cat %s is not a valid id or descriptor' % cat) 131 | 132 | 133 | def get_example_ids(cat_id, include_bad=False): 134 | from .path import get_ids_path 135 | ids_path = get_ids_path(cat_id) 136 | if not os.path.isfile(ids_path): 137 | create_ids(cat_id) 138 | with 
open(ids_path, 'r') as fp: 139 | ids = fp.readlines() 140 | ids = [id.rstrip() for id in ids] 141 | ids.sort() 142 | if not include_bad: 143 | from .objs import get_bad_objs 144 | bad_objs = get_bad_objs(cat_id) 145 | ids = (i for i in ids if i not in bad_objs) 146 | return tuple(ids) 147 | 148 | 149 | def get_old_example_ids(cat_id): 150 | from .path import get_ids_path 151 | with open(get_ids_path(cat_id), 'r') as fp: 152 | ids = fp.readlines() 153 | ids = [id.rstrip() for id in ids] 154 | return ids 155 | 156 | 157 | def create_ids(cat_ids): 158 | from progress.bar import IncrementalBar 159 | if isinstance(cat_ids, six.string_types): 160 | cat_ids = [cat_ids] 161 | bar = IncrementalBar(max=len(cat_ids)) 162 | for cat_id in cat_ids: 163 | ids_path = path.get_ids_path(cat_id) 164 | d = os.path.dirname(ids_path) 165 | if not os.path.isdir(d): 166 | os.makedirs(d) 167 | example_ids = path.get_example_ids_from_zip(cat_id) 168 | with open(ids_path, 'w') as fp: 169 | fp.writelines( 170 | ('%s\n' % example_id for example_id in example_ids)) 171 | bar.next() 172 | bar.finish() 173 | 174 | 175 | # __all__ = [ 176 | # path, 177 | # get_cat_ids, 178 | # get_example_ids, 179 | # cat_id_to_desc, 180 | # cat_desc_to_id, 181 | # ] 182 | -------------------------------------------------------------------------------- /core/annotations/.gitignore: -------------------------------------------------------------------------------- 1 | ids.json 2 | seg_names.json 3 | -------------------------------------------------------------------------------- /core/annotations/__init__.py: -------------------------------------------------------------------------------- 1 | from expert_verified import get_seg_image, get_expert_point_labels 2 | from expert_verified import has_expert_labels 3 | from points import get_points 4 | from points_label import get_point_labels 5 | from path import get_zip_file 6 | import benchmark 7 | from segment import segment 8 | 9 | 10 | def get_best_point_labels(zipfile, cat_id, example_id): 11 | if has_expert_labels(zipfile, cat_id, example_id): 12 | return get_expert_point_labels(zipfile, cat_id, example_id) 13 | else: 14 | return get_point_labels(zipfile, cat_id, example_id) 15 | 16 | 17 | __all__ = [ 18 | get_seg_image, 19 | get_expert_point_labels, 20 | get_point_labels, 21 | get_points, 22 | get_zip_file, 23 | segment, 24 | has_expert_labels, 25 | benchmark, 26 | ] 27 | -------------------------------------------------------------------------------- /core/annotations/benchmark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import json 4 | import zipfile 5 | from path import annot_dir 6 | from shapenet.image import load_image_from_zip 7 | from io_util import parse_seg, parse_pts 8 | 9 | 10 | def get_zip_path(): 11 | return os.path.join( 12 | annot_dir, 'shapenetcore_partanno_segmentation_benchmark_v0.zip') 13 | 14 | 15 | def get_zip_file(): 16 | return zipfile.ZipFile(get_zip_path(), 'r') 17 | 18 | 19 | def _subpath(topic, cat_id, example_id, ext): 20 | fn = '%s.%s' % (example_id, ext) 21 | return os.path.join( 22 | 'shapenetcore_partanno_segmentation_benchmark_v0', cat_id, topic, fn) 23 | 24 | 25 | def get_points(zip_file, cat_id, example_id): 26 | with zip_file.open(_subpath('points', cat_id, example_id, 'pts')) as fp: 27 | pts = parse_pts(fp) 28 | return np.array(pts, dtype=np.float32) 29 | 30 | 31 | def get_point_labels(zip_file, cat_id, example_id): 32 | with zip_file.open( 33 | _subpath('points_label', 
cat_id, example_id, 'seg')) as fp: 34 | seg = parse_seg(fp) 35 | return np.array(seg, dtype=np.uint8) 36 | 37 | 38 | def get_seg_image(zip_file, cat_id, example_id): 39 | with zip_file.open( 40 | _subpath('points_label', cat_id, example_id, 'png')) as fp: 41 | image = load_image_from_zip(fp) 42 | return image 43 | 44 | 45 | def _read_ids(zip_file): 46 | ids = {} 47 | for name in zip_file.namelist(): 48 | split = name.split('/') 49 | if len(split) < 4: 50 | continue 51 | _, cat_id, topic, fn = split 52 | if fn == '' or topic != 'points': 53 | continue 54 | example_id = fn.split('.')[0] 55 | if cat_id not in ids: 56 | ids[cat_id] = [] 57 | ids[cat_id].append(example_id) 58 | ids = {k: sorted(v) for k, v in ids.items()} 59 | return ids 60 | 61 | 62 | _ids_path = os.path.join(os.path.dirname(__file__), 'ids.json') 63 | 64 | 65 | def get_ids(): 66 | if not os.path.isfile(_ids_path): 67 | with get_zip_file() as zf: 68 | ids = _read_ids(zf) 69 | with open(_ids_path, 'w') as fp: 70 | json.dump(ids, fp) 71 | else: 72 | with open(_ids_path, 'r') as fp: 73 | ids = json.load(fp) 74 | return ids 75 | 76 | 77 | if __name__ == '__main__': 78 | def vis(points, labels, image): 79 | from segment import segment 80 | from mayavi import mlab 81 | import matplotlib.pyplot as plt 82 | colors = ( 83 | (1, 1, 1), 84 | (0, 1, 0), 85 | (0, 0, 1), 86 | (1, 0, 0), 87 | (0, 1, 1), 88 | (1, 1, 0), 89 | (1, 0, 1), 90 | (0, 0, 0) 91 | ) 92 | points = segment(points, labels) 93 | for i, ps in enumerate(points): 94 | color = colors[i % len(colors)] 95 | x, z, y = ps.T 96 | mlab.points3d(x, y, z, color=color, scale_factor=0.005) 97 | if image is not None: 98 | plt.imshow(image) 99 | plt.show(block=False) 100 | mlab.show() 101 | 102 | ids = get_ids() 103 | cat_id = '02691156' 104 | example_ids = ids[cat_id] 105 | with get_zip_file() as zf: 106 | for example_id in example_ids: 107 | points = get_points(zf, cat_id, example_id) 108 | labels = get_point_labels(zf, cat_id, example_id) 109 | try: 110 | image = get_seg_image(zf, cat_id, example_id) 111 | except KeyError: 112 | image = None 113 | vis(points, labels, image) 114 | 115 | # from ffd.templates import get_templated_cat_ids 116 | # from ffd.templates import get_templated_example_ids 117 | # from shapenet.core import cat_id_to_desc 118 | # ids = get_ids() 119 | # counts = {} 120 | # for cat_id in get_templated_cat_ids(): 121 | # desc = cat_id_to_desc(cat_id) 122 | # counts[cat_id] = 0 123 | # if cat_id not in ids: 124 | # print('No template data for CAT %s' % desc) 125 | # else: 126 | # cid = set(ids[cat_id]) 127 | # for example_id in get_templated_example_ids(cat_id): 128 | # if example_id not in cid: 129 | # print('No template data for %s/%s' %(cat_id, example_id)) 130 | # else: 131 | # counts[cat_id] += 1 132 | # 133 | # for k, v in counts.items(): 134 | # print('%s: %d / %d' % (cat_id_to_desc(k), v, 30)) 135 | -------------------------------------------------------------------------------- /core/annotations/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import dids.file_io.zip_file_dataset as zfd 4 | from shapenet.image import load_image_from_zip 5 | from path import get_zip_path 6 | from points import get_points 7 | from io_util import parse_seg 8 | from expert_verified import _label_path, _get_subpath 9 | from expert_verified import has_expert_labels 10 | 11 | 12 | class _AnnotationsDataset(zfd.ZipFileDataset): 13 | def __init__(self, cat_id): 14 | super(_AnnotationsDataset, 
self).__init__(get_zip_path()) 15 | self._cat_id = cat_id 16 | 17 | def __contains__(self, key): 18 | return has_expert_labels(self._file, self._cat_id, key) 19 | 20 | 21 | class PointCloudDataset(_AnnotationsDataset): 22 | def __getitem__(self, key): 23 | return get_points(self._file, self._cat_id, key) 24 | 25 | 26 | class SegmentationDataset(_AnnotationsDataset): 27 | def __getitem__(self, key): 28 | with self._file.open(_label_path(self._cat_id, key)) as fp: 29 | data = parse_seg(fp) 30 | return np.array(data, np.int32) 31 | 32 | 33 | class SegmentedImageDataset(_AnnotationsDataset): 34 | def __getitem__(self, key): 35 | topic = os.path.join('expert_verified', 'seg_img') 36 | subpath = _get_subpath(self._cat_id, key, topic, 'png') 37 | return load_image_from_zip(self._file, subpath) 38 | 39 | 40 | def _main(cat_id, example_id): 41 | from dids import Dataset 42 | import matplotlib.pyplot as plt 43 | from util3d.mayavi_vis import vis_point_cloud 44 | from mayavi import mlab 45 | 46 | colors = ( 47 | (1, 0, 0), 48 | (0, 1, 0), 49 | (0, 0, 1), 50 | (1, 1, 1), 51 | ) 52 | 53 | image_ds = SegmentedImageDataset(cat_id) 54 | pc_ds = PointCloudDataset(cat_id) 55 | s_ds = SegmentationDataset(cat_id) 56 | ds = Dataset.zip(image_ds, pc_ds, s_ds) 57 | with ds: 58 | image, pc, s = ds[example_id] 59 | print(np.min(s)) 60 | print(np.max(s)) 61 | ns = np.max(s) + 1 62 | plt.imshow(image) 63 | for i in range(ns-1): 64 | cloud = pc[s == i+1] 65 | color = colors[i % len(colors)] 66 | vis_point_cloud(cloud, color=color, scale_factor=0.02) 67 | plt.show(block=False) 68 | mlab.show() 69 | 70 | 71 | if __name__ == '__main__': 72 | cat_id = '02691156' 73 | example_id = '1a04e3eab45ca15dd86060f189eb133' 74 | _main(cat_id, example_id) 75 | -------------------------------------------------------------------------------- /core/annotations/expert_verified.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from path import _get_subpath, get_zip_path 4 | from shapenet.image import load_image_from_zip 5 | from io_util import parse_seg 6 | import dids.file_io.zip_file_dataset as zfd 7 | 8 | 9 | class SegmentedDataset(zfd.ZipFileDataset): 10 | def __init__(self, cat_id): 11 | super(SegmentedDataset, self).__init__(get_zip_path()) 12 | self._cat_id = cat_id 13 | 14 | def keys(self): 15 | if not self.open: 16 | raise RuntimeError('Cannot check keys of closed dataset.') 17 | return example_ids_with_expert_labels(self._file, self._cat_id) 18 | 19 | 20 | class SegmentedPointCloudDataset(SegmentedDataset): 21 | def __getitem__(self, key): 22 | with self._file.open(_label_path(self._cat_id, key)) as fp: 23 | data = parse_seg(fp) 24 | return np.array(data) 25 | 26 | 27 | class SegmentedImageDataset(SegmentedDataset): 28 | def __getitem__(self, key): 29 | topic = os.path.join('expert_verified', 'seg_img') 30 | subpath = _get_subpath(self._cat_id, key, topic, 'png') 31 | return load_image_from_zip(self._file, subpath) 32 | 33 | 34 | def _label_path(cat_id, example_id): 35 | topic = os.path.join('expert_verified', 'points_label') 36 | return _get_subpath(cat_id, example_id, topic, 'seg') 37 | 38 | 39 | def has_expert_labels(zipfile, cat_id, example_id): 40 | try: 41 | with zipfile.open(_label_path(cat_id, example_id)): 42 | pass 43 | return True 44 | except KeyError: 45 | return False 46 | 47 | 48 | def example_ids_with_expert_labels(zipfile, cat_id): 49 | base_path = os.path.join( 50 | 'PartAnnotation', cat_id, 'expert_verified', 'points_label') 51 | n = 
len(base_path) 52 | return [name[n+1:-4] for name in zipfile.namelist() 53 | if name.startswith(base_path) and len(name) > n+1] 54 | 55 | 56 | def get_expert_point_labels(zipfile, cat_id, example_id): 57 | with zipfile.open(_label_path(cat_id, example_id)) as fp: 58 | point_labels = parse_seg(fp) 59 | 60 | data = np.array(point_labels, dtype=np.uint8) 61 | return data 62 | 63 | 64 | def get_seg_image(zipfile, cat_id, example_id): 65 | topic = os.path.join('expert_verified', 'seg_img') 66 | subpath = _get_subpath(cat_id, example_id, topic, 'png') 67 | return load_image_from_zip(zipfile, subpath) 68 | 69 | 70 | def _main(): 71 | # from path import get_zip_file 72 | from dids import Dataset 73 | import matplotlib.pyplot as plt 74 | cat_id = '02691156' 75 | example_id = '1a04e3eab45ca15dd86060f189eb133' 76 | ds = Dataset.zip( 77 | SegmentedImageDataset(cat_id), SegmentedPointCloudDataset(cat_id)) 78 | with ds: 79 | image, cloud = ds[example_id] 80 | print(cloud.shape) 81 | print(cloud.dtype) 82 | plt.imshow(image) 83 | plt.show() 84 | # with get_zip_file() as f: 85 | # print(np.min(get_expert_point_labels(f, cat_id, example_id))) 86 | 87 | # image = get_seg_image(f, cat_id, example_id) 88 | # plt.imshow(image) 89 | # plt.show() 90 | 91 | 92 | if __name__ == '__main__': 93 | _main() 94 | -------------------------------------------------------------------------------- /core/annotations/io_util.py: -------------------------------------------------------------------------------- 1 | def parse_seg(fp): 2 | return [int(r) for r in fp.readlines()] 3 | 4 | 5 | def parse_pts(fp): 6 | return [[float(a) for a in r.split(' ')] for r in fp.readlines()] 7 | -------------------------------------------------------------------------------- /core/annotations/path.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import zipfile 7 | 8 | 9 | def get_annot_dir(): 10 | from ...config import config 11 | return config['core_annotations_dir'] 12 | 13 | 14 | annot_dir = get_annot_dir() 15 | 16 | 17 | def get_zip_path(): 18 | return os.path.join(annot_dir, 'shapenetcore_partanno_v0.zip') 19 | 20 | 21 | def get_zip_file(): 22 | return zipfile.ZipFile(get_zip_path(), 'r') 23 | 24 | 25 | def _get_subpath(cat_id, example_id, topic, ext): 26 | fn = '%s.%s' % (example_id, ext) 27 | return os.path.join('PartAnnotation', cat_id, topic, fn) 28 | -------------------------------------------------------------------------------- /core/annotations/points.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from path import _get_subpath 3 | from io_util import parse_pts 4 | 5 | 6 | def get_points(zipfile, cat_id, example_id, dtype=np.float32): 7 | subpath = _get_subpath(cat_id, example_id, 'points', 'pts') 8 | with zipfile.open(subpath) as fp: 9 | pts = parse_pts(fp) 10 | return np.array(pts, dtype=dtype) 11 | 12 | 13 | def _main(): 14 | from path import get_zip_file 15 | from expert_verified import get_expert_point_labels, get_seg_image 16 | from expert_verified import example_ids_with_expert_labels 17 | from points_label import get_point_labels 18 | from shapenet.core.meshes import get_mesh_dataset 19 | import matplotlib.pyplot as plt 20 | from mayavi import mlab 21 | # example_id = '1a04e3eab45ca15dd86060f189eb133' 22 | 23 | def vis_all(zipfile, cat_id, example_id): 24 | points = get_points(zipfile, 
cat_id, example_id) 25 | expert_labels = get_expert_point_labels(zipfile, cat_id, example_id) 26 | labels = get_point_labels(zipfile, cat_id, example_id) 27 | image = get_seg_image(zipfile, cat_id, example_id) 28 | print(len(points)) 29 | 30 | with get_mesh_dataset(cat_id) as mesh_ds: 31 | example = mesh_ds[example_id] 32 | vertices, faces = (np.array(example[k]) for k in 33 | ('vertices', 'faces')) 34 | 35 | def vis(vertices, faces, points, labels): 36 | x, z, y = vertices.T 37 | mlab.triangular_mesh(x, y, z, faces, color=(1, 1, 1)) 38 | mlab.triangular_mesh( 39 | x, y, z, faces, color=(0, 0, 0), representation='wireframe') 40 | 41 | n = np.max(labels) + 1 42 | colors = ( 43 | (1, 1, 1), 44 | (0, 1, 0), 45 | (0, 0, 1), 46 | (1, 0, 0), 47 | (0, 1, 1), 48 | (1, 1, 0), 49 | (1, 0, 1), 50 | (0, 0, 0) 51 | ) 52 | colors = colors[:n] 53 | for i, c in enumerate(colors): 54 | x, z, y = points[labels == i].T 55 | mlab.points3d(x, y, z, color=c, opacity=0.8, scale_factor=0.02) 56 | 57 | print(np.min(points, axis=0), np.max(points, axis=0)) 58 | print(np.min(vertices, axis=0), np.max(vertices, axis=0)) 59 | mlab.figure() 60 | vis(vertices, faces, points, expert_labels) 61 | mlab.figure() 62 | vis(vertices, faces, points, labels) 63 | plt.imshow(image) 64 | plt.show(block=False) 65 | mlab.show() 66 | plt.close() 67 | 68 | cat_id = '02691156' 69 | print('cat_id: %s' % cat_id) 70 | with get_zip_file() as f: 71 | example_ids = example_ids_with_expert_labels(f, cat_id) 72 | for example_id in example_ids: 73 | print('example_id: %s' % example_id) 74 | vis_all(f, cat_id, example_id) 75 | 76 | 77 | if __name__ == '__main__': 78 | _main() 79 | -------------------------------------------------------------------------------- /core/annotations/points_label.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from io_util import parse_seg 5 | from path import _get_subpath 6 | 7 | _seg_names_path = os.path.join(os.path.dirname(__file__), 'seg_names.json') 8 | 9 | 10 | def _get_cat_dir(cat_id): 11 | return os.path.join('PartAnnotation', cat_id, 'points_label') 12 | 13 | 14 | def get_segmentation_names(cat_id=None): 15 | if not os.path.isfile(_seg_names_path): 16 | names = _create_segmentation_names() 17 | with open(_seg_names_path, 'w') as fp: 18 | json.dump(names, fp) 19 | else: 20 | with open(_seg_names_path, 'r') as fp: 21 | names = json.load(fp) 22 | 23 | if cat_id is None: 24 | return names 25 | else: 26 | return names[cat_id] 27 | 28 | 29 | def get_binary_point_labels(zipfile, cat_id, example_id, seg_name): 30 | path = _get_subpath( 31 | cat_id, example_id, os.path.join('points_label', seg_name), 'seg') 32 | with zipfile.open(path) as fp: 33 | labels = parse_seg(fp) 34 | return labels 35 | 36 | 37 | def get_point_labels(zipfile, cat_id, example_id): 38 | names = get_segmentation_names(cat_id) 39 | result = get_binary_point_labels(zipfile, cat_id, example_id, names[0]) 40 | result = np.array(result, dtype=np.uint8) 41 | for i, name in enumerate(names[1:]): 42 | bin_seg = get_binary_point_labels(zipfile, cat_id, example_id, name) 43 | result[bin_seg] = i + 2 44 | return result 45 | 46 | 47 | def _create_segmentation_names(): 48 | from path import get_zip_file 49 | names = {} 50 | with get_zip_file() as zf: 51 | for n in zf.namelist(): 52 | split = n.split('/') 53 | if len(split) >= 4 and split[2] == 'points_label': 54 | cat_id = split[1] 55 | class_name = split[3] 56 | if cat_id not in names: 57 | names[cat_id] = set() 58 
| names[cat_id].add(class_name) 59 | 60 | def map_fn(s): 61 | li = list(s) 62 | li.sort() 63 | li = [v for v in li if v != ''] 64 | return li 65 | 66 | return {k: map_fn(v) for k, v in names.items()} 67 | 68 | 69 | def _main(): 70 | from path import get_zip_file 71 | from points import get_points 72 | # from expert_verified import get_seg_image 73 | from shapenet.core.meshes import get_mesh_dataset 74 | # import matplotlib.pyplot as plt 75 | from mayavi import mlab 76 | cat_id = '02691156' 77 | # example_id = '1a04e3eab45ca15dd86060f189eb133' 78 | # example_id = '1a6ad7a24bb89733f412783097373bdc' 79 | example_id = '1a9b552befd6306cc8f2d5fe7449af61' 80 | with get_zip_file() as f: 81 | points = get_points(f, cat_id, example_id) 82 | labels = get_point_labels(f, cat_id, example_id) 83 | names = get_segmentation_names(cat_id) 84 | # image = get_seg_image(f, cat_id, example_id) 85 | print(len(points)) 86 | names = ['unlabelled'] + names 87 | 88 | clouds = [] 89 | for i, name in enumerate(names): 90 | cloud = points[labels == i] 91 | clouds.append(cloud) 92 | print('%s: %d' % (name, len(cloud))) 93 | 94 | with get_mesh_dataset(cat_id) as mesh_ds: 95 | example = mesh_ds[example_id] 96 | vertices, faces = (np.array(example[k]) for k in ('vertices', 'faces')) 97 | 98 | x, z, y = vertices.T 99 | mlab.triangular_mesh(x, y, z, faces, color=(1, 1, 1)) 100 | mlab.triangular_mesh( 101 | x, y, z, faces, color=(0, 0, 0), representation='wireframe') 102 | 103 | colors = np.random.uniform(size=(np.max(labels)+1, 3)) 104 | for color, cloud in zip(colors, clouds): 105 | x, z, y = cloud.T 106 | mlab.points3d( 107 | x, y, z, color=tuple(color), opacity=0.8, scale_factor=0.02) 108 | print(np.min(points, axis=0), np.max(points, axis=0)) 109 | print(np.min(vertices, axis=0), np.max(vertices, axis=0)) 110 | # plt.imshow(image) 111 | # plt.show(block=False) 112 | mlab.show() 113 | # plt.close() 114 | 115 | 116 | if __name__ == '__main__': 117 | _main() 118 | -------------------------------------------------------------------------------- /core/annotations/segment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def segment(points, segmentation): 5 | n = np.max(segmentation) + 1 6 | seg_points = [points[segmentation == i] for i in range(n)] 7 | return seg_points 8 | -------------------------------------------------------------------------------- /core/blender_renderings/.gitignore: -------------------------------------------------------------------------------- 1 | _images/* 2 | _fixed_meshes/* 3 | -------------------------------------------------------------------------------- /core/blender_renderings/README.md: -------------------------------------------------------------------------------- 1 | Provides data on renderings done via blender. 2 | 3 | See `scripts/blender_render.py` for blender details. 4 | 5 | # Creating Renderings 6 | * `scripts/render_cat.py` creates renderings for a given render configuration (shape, number of images at even rotations about `z` axis) 7 | * `scripts/create_archive.py` zips renderings associated with the provided category, or adds specific `example_id`s to an existing archive 8 | * `scripts/check_archive.py` checks which examples have renderings associated with them in the given archive. 9 | 10 | Some `obj` files in the original dataset cannot be opened by blender. These models are extracted, imported into meshlab and re-exported as `obj`s, and the resulting files compressed in `_fixed_meshes/CAT_ID.zip`. 
These renderings can be created using `scripts/render_cat.py CAT_ID -f` 11 | 12 | # Accessing Image Data 13 | Data (either from compressed archives or raw files) can be accessed via `RenderConfig` methods. A `RenderConfig` specifies an image shape and the number of images rendered at linearly spaced angles. 14 | -------------------------------------------------------------------------------- /core/blender_renderings/__init__.py: -------------------------------------------------------------------------------- 1 | from config import RenderConfig 2 | 3 | __all__ = [ 4 | RenderConfig, 5 | ] 6 | -------------------------------------------------------------------------------- /core/blender_renderings/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import string 7 | import zipfile 8 | import numpy as np 9 | from . import path 10 | 11 | 12 | def get_config_id(shape, n_images, scale=None): 13 | shape_str = string.join((str(s) for s in shape), '-') 14 | id = 'r%s_%d' % (shape_str, n_images) 15 | if scale is not None and scale != 1: 16 | id = '%s_s%03d' % (id, int(100*scale)) 17 | return id 18 | 19 | 20 | def parse_config_id(config_id): 21 | substrs = config_id[1:].split('_') 22 | if len(substrs) == 2: 23 | shape_str, n_images = substrs 24 | scale = None 25 | else: 26 | shape_str, n_images = substrs[:2] 27 | scale = int(substrs[2][1:]) / 100 28 | n_images = int(n_images) 29 | shape = tuple(int(s) for s in shape_str.split('-')) 30 | return dict(shape=shape, n_images=n_images, scale=scale) 31 | 32 | 33 | class RenderConfig(object): 34 | def __init__( 35 | self, shape=(192, 256), n_images=8, scale=None, config_id=None): 36 | self._shape = tuple(shape) 37 | self._n_images = n_images 38 | self._scale = scale 39 | self._id = get_config_id(shape, n_images, scale) if config_id is None \ 40 | else config_id 41 | 42 | @staticmethod 43 | def from_id(config_id): 44 | args = parse_config_id(config_id) 45 | return RenderConfig(args['shape'], args['n_images'], config_id) 46 | 47 | @property 48 | def config_id(self): 49 | return self._id 50 | 51 | @property 52 | def scale(self): 53 | return self._scale 54 | 55 | @property 56 | def shape(self): 57 | return self._shape 58 | 59 | @property 60 | def n_images(self): 61 | return self._n_images 62 | 63 | def view_angle(self, view_index): 64 | return view_index * 360 // self._n_images 65 | 66 | def view_elevation(self, view_index): 67 | return np.rad2deg(np.atan(0.6)) 68 | 69 | @property 70 | def root_dir(self): 71 | return path.get_renderings_dir(self.config_id) 72 | 73 | def get_zip_path(self, cat_id): 74 | return os.path.join(self.root_dir, '%s.zip' % cat_id) 75 | 76 | def get_zip_file(self, cat_id, mode='r'): 77 | return zipfile.ZipFile(self.get_zip_path(cat_id), mode=mode) 78 | 79 | def has_zip_file(self, cat_id): 80 | return os.path.isfile(self.get_zip_path(cat_id)) 81 | 82 | def get_cat_dir(self, cat_id): 83 | return path.get_cat_dir(self.config_id, cat_id) 84 | 85 | def get_example_dir(self, cat_id, example_id): 86 | return path.get_example_dir(self.config_id, cat_id, example_id) 87 | 88 | def _path(self, subpath): 89 | return os.path.join(self.root_dir, subpath) 90 | 91 | def get_example_image_subpath(self, cat_id, example_id, view_angle): 92 | return path.get_example_image_subpath( 93 | cat_id, example_id, view_angle) 94 | 95 | def get_example_normals_subpath(self, cat_id, 
example_id, view_angle): 96 | return path.get_example_normals_subpath(cat_id, example_id, view_angle) 97 | 98 | def get_example_albedo_path(self, cat_id, example_id, view_angle): 99 | return path.get_example_albedo_subpath(cat_id, example_id, view_angle) 100 | 101 | def get_example_depth_subpath(self, cat_id, example_id, view_angle): 102 | return path.get_example_depth_subpath(cat_id, example_id, view_angle) 103 | 104 | def get_multi_view_dataset( 105 | self, cat_id, view_indices=None, example_ids=None, mode='r'): 106 | from shapenet.image import load_image_from_file 107 | from dids.file_io.zip_file_dataset import ZipFileDataset 108 | if view_indices is None: 109 | view_indices = range(self.n_images) 110 | view_angles = [self.view_angle(i) for i in view_indices] 111 | else: 112 | view_angles = {i: self.view_angle(i) for i in view_indices} 113 | 114 | def key_fn(key): 115 | example_id, view_index = key 116 | view_angle = view_angles[view_index] 117 | return self.get_example_image_subpath( 118 | cat_id, example_id, view_angle) 119 | 120 | dataset = ZipFileDataset(self.get_zip_path(cat_id), mode) 121 | dataset = dataset.map(load_image_from_file) 122 | dataset = dataset.map_keys(key_fn) 123 | if example_ids is not None: 124 | keys = [] 125 | for example_id in example_ids: 126 | keys.extend((example_id, k) for k in view_indices) 127 | dataset = dataset.subset(keys) 128 | return dataset 129 | 130 | def get_dataset(self, cat_id, view_index, example_ids=None, mode='r'): 131 | from shapenet.image import load_image_from_file 132 | from dids.file_io.zip_file_dataset import ZipFileDataset 133 | view_angle = self.view_angle(view_index) 134 | 135 | def key_fn(example_id): 136 | return self.get_example_image_subpath( 137 | cat_id, example_id, view_angle) 138 | 139 | def inverse_key_fn(path): 140 | subpaths = path.split('/') 141 | if len(subpaths) == 3: 142 | i = subpaths[1] 143 | p0 = self.get_example_image_subpath(cat_id, i, view_angle) 144 | if path == p0: 145 | return i 146 | return None 147 | 148 | dataset = ZipFileDataset(self.get_zip_path(cat_id), mode) 149 | dataset = dataset.map(load_image_from_file) 150 | dataset = dataset.map_keys(key_fn, inverse_key_fn) 151 | if example_ids is not None: 152 | dataset = dataset.subset(example_ids) 153 | return dataset 154 | 155 | 156 | if __name__ == '__main__': 157 | # import numpy as np 158 | import matplotlib.pyplot as plt 159 | from shapenet.core import cat_desc_to_id 160 | # from zipfile import ZipFile 161 | # from shapenet.image import load_image_from_zip 162 | from shapenet.image import with_background 163 | cat_desc = 'plane' 164 | cat_id = cat_desc_to_id(cat_desc) 165 | view_index = 5 166 | config = RenderConfig() 167 | with config.get_dataset(cat_id, view_index) as ds: 168 | for k, v in ds.items(): 169 | plt.imshow(with_background(v, 255)) 170 | plt.title(k) 171 | plt.show() 172 | # with ZipFile(config.get_zip_path(cat_id)) as zf: 173 | # for i in range(config.n_images): 174 | # subpath = config.get_example_image_subpath( 175 | # cat_id, example_id, config.view_angle(i)) 176 | # image = load_image_from_zip(zf, subpath) 177 | # 178 | # plt.figure() 179 | # plt.imshow(np.array(image)[..., :3]) 180 | # image = with_background(image, 255) 181 | # plt.figure() 182 | # plt.imshow(image) 183 | # plt.show() 184 | -------------------------------------------------------------------------------- /core/blender_renderings/path.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | blender_renderings_dir = 
os.path.realpath(os.path.dirname(__file__)) 4 | images_dir = os.path.join(blender_renderings_dir, '_images') 5 | 6 | 7 | def get_renderings_dir(config_id): 8 | return os.path.join(images_dir, config_id) 9 | 10 | 11 | def get_cat_dir(config_id, cat_id): 12 | return os.path.join(get_renderings_dir(config_id), cat_id) 13 | 14 | 15 | def get_example_dir(config_id, cat_id, example_id): 16 | return os.path.join(get_cat_dir(config_id, cat_id), example_id) 17 | 18 | 19 | def _get_example_subpath(cat_id, example_id, view_angle, extra): 20 | return os.path.join( 21 | cat_id, example_id, 22 | '%s_r_%03d%s.png' % (example_id, view_angle, extra)) 23 | 24 | 25 | def _get_example_path(config_id, cat_id, example_id, view_angle, extra): 26 | return os.path.join( 27 | get_renderings_dir(config_id), 28 | _get_example_subpath(cat_id, example_id, view_angle, extra)) 29 | 30 | 31 | def get_example_image_subpath(cat_id, example_id, view_angle): 32 | return _get_example_subpath(cat_id, example_id, view_angle, '') 33 | 34 | 35 | def get_example_normals_subpath(cat_id, example_id, view_angle): 36 | return _get_example_subpath( 37 | cat_id, example_id, view_angle, '_normals.png0001') 38 | 39 | 40 | def get_example_albedo_subpath(cat_id, example_id, view_angle): 41 | return _get_example_subpath( 42 | cat_id, example_id, view_angle, '_albedo.png0001') 43 | 44 | 45 | def get_example_depth_subpath(cat_id, example_id, view_angle): 46 | return _get_example_subpath( 47 | cat_id, example_id, view_angle, '_depth.png0001') 48 | 49 | 50 | def get_fixed_meshes_dir(): 51 | return os.path.join(blender_renderings_dir, '_fixed_meshes') 52 | 53 | 54 | def get_fixed_meshes_zip_path(cat_id): 55 | return os.path.join(get_fixed_meshes_dir(), '%s.zip' % cat_id) 56 | 57 | 58 | if __name__ == '__main__': 59 | from shapenet.core import cat_desc_to_id, get_example_ids 60 | from shapenet.image import load_image_from_zip, with_background 61 | import random 62 | import zipfile 63 | import matplotlib.pyplot as plt 64 | from config import RenderConfig 65 | cat_desc = 'plane' 66 | cat_id = cat_desc_to_id(cat_desc) 67 | example_ids = get_example_ids(cat_id) 68 | random.shuffle(example_ids) 69 | 70 | config = RenderConfig() 71 | with zipfile.ZipFile(config.get_zip_path(cat_id), 'r') as zf: 72 | for example_id in example_ids: 73 | subpath = get_example_image_subpath( 74 | cat_id, example_id, config.view_angle(5)) 75 | image = load_image_from_zip(zf, subpath) 76 | image = with_background(image, 255) 77 | plt.imshow(image) 78 | plt.show() 79 | -------------------------------------------------------------------------------- /core/blender_renderings/scripts/blender_render.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple script that uses blender to render views of a single object by 3 | rotation the camera around it. 4 | 5 | Also produces depth map at the same time. 6 | 7 | Example: 8 | blender --background --python blender_render.py -- --views 10 /path/to/my.obj 9 | 10 | Original source: 11 | https://github.com/panmari/stanford-shapenet-renderer 12 | """ 13 | 14 | 15 | def main( 16 | depth_scale, obj, scale, remove_doubles, edge_split, output_folder, 17 | views, shape): 18 | import os 19 | from math import radians 20 | import bpy 21 | 22 | # Set up rendering of depth map: 23 | bpy.context.scene.use_nodes = True 24 | tree = bpy.context.scene.node_tree 25 | links = tree.links 26 | 27 | # Add passes for additionally dumping albed and normals. 
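    # The compositor graph built below routes the Z pass through a Map Value
    # node and an Invert node into a file-output node (depth), rescales the
    # Normal pass from [-1, 1] to [0, 1] with multiply/add Mix nodes, and
    # writes the Color (albedo) pass straight to its own file-output node.
    # Note this targets the pre-2.80 Blender Python API
    # (scene.render.layers, bpy.data.lamps, etc.).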
28 | bpy.context.scene.render.layers["RenderLayer"].use_pass_normal = True 29 | bpy.context.scene.render.layers["RenderLayer"].use_pass_color = True 30 | 31 | # clear default nodes 32 | for n in tree.nodes: 33 | tree.nodes.remove(n) 34 | 35 | # create input render layer node 36 | rl = tree.nodes.new('CompositorNodeRLayers') 37 | 38 | map = tree.nodes.new(type="CompositorNodeMapValue") 39 | # Size is chosen kind of arbitrarily, try out until you're satisfied with 40 | # resulting depth map. 41 | map.offset = [-0.7] 42 | map.size = [depth_scale] 43 | map.use_min = True 44 | map.min = [0] 45 | map.use_max = True 46 | map.max = [255] 47 | try: 48 | links.new(rl.outputs['Z'], map.inputs[0]) 49 | except KeyError: 50 | # some versions of blender don't like this? 51 | pass 52 | 53 | invert = tree.nodes.new(type="CompositorNodeInvert") 54 | links.new(map.outputs[0], invert.inputs[1]) 55 | 56 | # create a file output node and set the path 57 | depthFileOutput = tree.nodes.new(type="CompositorNodeOutputFile") 58 | depthFileOutput.label = 'Depth Output' 59 | links.new(invert.outputs[0], depthFileOutput.inputs[0]) 60 | 61 | scale_normal = tree.nodes.new(type="CompositorNodeMixRGB") 62 | scale_normal.blend_type = 'MULTIPLY' 63 | # scale_normal.use_alpha = True 64 | scale_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 1) 65 | links.new(rl.outputs['Normal'], scale_normal.inputs[1]) 66 | 67 | bias_normal = tree.nodes.new(type="CompositorNodeMixRGB") 68 | bias_normal.blend_type = 'ADD' 69 | # bias_normal.use_alpha = True 70 | bias_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 0) 71 | links.new(scale_normal.outputs[0], bias_normal.inputs[1]) 72 | 73 | normalFileOutput = tree.nodes.new(type="CompositorNodeOutputFile") 74 | normalFileOutput.label = 'Normal Output' 75 | links.new(bias_normal.outputs[0], normalFileOutput.inputs[0]) 76 | 77 | albedoFileOutput = tree.nodes.new(type="CompositorNodeOutputFile") 78 | albedoFileOutput.label = 'Albedo Output' 79 | # For some reason, 80 | links.new(rl.outputs['Color'], albedoFileOutput.inputs[0]) 81 | 82 | # Delete default cube 83 | bpy.data.objects['Cube'].select = True 84 | bpy.ops.object.delete() 85 | 86 | bpy.ops.import_scene.obj(filepath=obj) 87 | 88 | if scale != 1: 89 | bpy.ops.transform.resize(value=(scale, scale, scale)) 90 | 91 | for object in bpy.context.scene.objects: 92 | if object.name in ['Camera', 'Lamp']: 93 | continue 94 | bpy.context.scene.objects.active = object 95 | if scale != 1: 96 | bpy.ops.object.transform_apply(scale=True) 97 | if remove_doubles: 98 | bpy.ops.object.mode_set(mode='EDIT') 99 | bpy.ops.mesh.remove_doubles() 100 | bpy.ops.object.mode_set(mode='OBJECT') 101 | if edge_split: 102 | bpy.ops.object.modifier_add(type='EDGE_SPLIT') 103 | bpy.context.object.modifiers["EdgeSplit"].split_angle = 1.32645 104 | bpy.ops.object.modifier_apply( 105 | apply_as='DATA', modifier="EdgeSplit") 106 | 107 | # Make light just directional, disable shadows. 
108 | lamp = bpy.data.lamps['Lamp'] 109 | lamp.type = 'SUN' 110 | lamp.shadow_method = 'NOSHADOW' 111 | # Possibly disable specular shading: 112 | lamp.use_specular = False 113 | 114 | # Add another light source so stuff facing away from light is not 115 | # completely dark 116 | bpy.ops.object.lamp_add(type='SUN') 117 | lamp2 = bpy.data.lamps['Sun'] 118 | lamp2.shadow_method = 'NOSHADOW' 119 | lamp2.use_specular = False 120 | lamp2.energy = 0.015 121 | sun = bpy.data.objects['Sun'] 122 | sun.rotation_euler = bpy.data.objects['Lamp'].rotation_euler 123 | sun.rotation_euler[0] += 180 124 | 125 | def parent_obj_to_camera(b_camera): 126 | origin = (0, 0, 0) 127 | b_empty = bpy.data.objects.new("Empty", None) 128 | b_empty.location = origin 129 | b_camera.parent = b_empty # setup parenting 130 | 131 | scn = bpy.context.scene 132 | scn.objects.link(b_empty) 133 | scn.objects.active = b_empty 134 | return b_empty 135 | 136 | scene = bpy.context.scene 137 | scene.render.resolution_x = shape[1] 138 | scene.render.resolution_y = shape[0] 139 | scene.render.resolution_percentage = 100 140 | scene.render.alpha_mode = 'TRANSPARENT' 141 | cam = scene.objects['Camera'] 142 | cam.location = (0, 1, 0.6) 143 | cam_constraint = cam.constraints.new(type='TRACK_TO') 144 | cam_constraint.track_axis = 'TRACK_NEGATIVE_Z' 145 | cam_constraint.up_axis = 'UP_Y' 146 | b_empty = parent_obj_to_camera(cam) 147 | cam_constraint.target = b_empty 148 | 149 | model_identifier = os.path.split(os.path.split(obj)[0])[1] 150 | fp = os.path.join(output_folder, model_identifier, model_identifier) 151 | scene.render.image_settings.file_format = 'PNG' # set output format to png 152 | 153 | stepsize = 360.0 / views 154 | # rotation_mode = 'XYZ' 155 | 156 | for output_node in [depthFileOutput, normalFileOutput, albedoFileOutput]: 157 | output_node.base_path = '' 158 | 159 | for i in range(0, views): 160 | print("Rotation {}, {}".format((stepsize * i), radians(stepsize * i))) 161 | 162 | scene.render.filepath = fp + '_r_{0:03d}'.format(int(i * stepsize)) 163 | depthFileOutput.file_slots[0].path = \ 164 | scene.render.filepath + "_depth.png" 165 | normalFileOutput.file_slots[0].path = \ 166 | scene.render.filepath + "_normal.png" 167 | albedoFileOutput.file_slots[0].path = \ 168 | scene.render.filepath + "_albedo.png" 169 | 170 | bpy.ops.render.render(write_still=True) # render still 171 | 172 | b_empty.rotation_euler[2] += radians(stepsize) 173 | 174 | 175 | def get_args(): 176 | import argparse 177 | import sys 178 | parser = argparse.ArgumentParser( 179 | description='Renders given obj file by rotation a camera around it.') 180 | parser.add_argument('--views', type=int, default=30, 181 | help='number of views to be rendered') 182 | parser.add_argument('obj', type=str, 183 | help='Path to the obj file to be rendered.') 184 | parser.add_argument('--output_folder', type=str, default='/tmp', 185 | help='The path the output will be dumped to.') 186 | parser.add_argument('--scale', type=float, default=1, 187 | help='Scaling factor applied to model. ' 188 | 'Depends on size of mesh.') 189 | parser.add_argument('--remove_doubles', action='store_true', 190 | help='Remove double vertices to improve mesh quality.') 191 | parser.add_argument('--edge_split', action='store_true', 192 | help='Adds edge split filter.') 193 | parser.add_argument('--depth_scale', type=float, default=1.4, 194 | help='Scaling that is applied to depth. ' 195 | 'Depends on size of mesh. 
' 196 | 'Try out various values until you get a good ' 197 | 'result.') 198 | parser.add_argument('--shape', type=int, default=[192, 256], nargs=2, 199 | help='2D shape of rendered images.') 200 | 201 | argv = sys.argv[sys.argv.index("--") + 1:] 202 | args = parser.parse_args(argv) 203 | return args 204 | 205 | 206 | args = get_args() 207 | main( 208 | args.depth_scale, args.obj, args.scale, args.remove_doubles, 209 | args.edge_split, args.output_folder, args.views, args.shape) 210 | -------------------------------------------------------------------------------- /core/blender_renderings/scripts/check_archive.py: -------------------------------------------------------------------------------- 1 | """Checks rendered archive for renderings of all example ids.""" 2 | 3 | 4 | def check_zip(cat_desc, shape, n_images): 5 | import zipfile 6 | from shapenet.core.blender_renderings.config import RenderConfig 7 | from shapenet.core import cat_desc_to_id, get_example_ids 8 | cat_id = cat_desc_to_id(cat_desc) 9 | 10 | config = RenderConfig(shape=shape, n_images=n_images) 11 | rendered_ids = set() 12 | with zipfile.ZipFile(config.get_zip_path(cat_id)) as zf: 13 | for name in zf.namelist(): 14 | rendered_ids.add(name.split('/')[1]) 15 | 16 | not_rendered_count = 0 17 | example_ids = get_example_ids(cat_id) 18 | for example_id in example_ids: 19 | if example_id not in rendered_ids: 20 | print(example_id) 21 | not_rendered_count += 1 22 | 23 | if not_rendered_count > 0: 24 | print('%d / %d not rendered' % (not_rendered_count, len(example_ids))) 25 | else: 26 | print('All %d %ss rendered!' % (len(example_ids), cat_desc)) 27 | 28 | 29 | if __name__ == '__main__': 30 | import argparse 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('cat', type=str) 33 | parser.add_argument('--shape', type=int, nargs=2, default=[192, 256]) 34 | parser.add_argument('-n', '--n_images', type=int, default=8) 35 | args = parser.parse_args() 36 | check_zip(args.cat, args.shape, args.n_images) 37 | -------------------------------------------------------------------------------- /core/blender_renderings/scripts/create_archive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | """Creates or adds to an archive of renderings.""" 3 | 4 | 5 | def create_archive( 6 | cat_desc, shape=(192, 256), n_images=8, scale=None, example_ids=None, 7 | delete_src=False): 8 | import os 9 | from shapenet.core import cat_desc_to_id 10 | from shapenet.core import get_example_ids 11 | from shapenet.core.blender_renderings.config import RenderConfig 12 | from progress.bar import IncrementalBar 13 | import zipfile 14 | cat_id = cat_desc_to_id(cat_desc) 15 | if example_ids is None or len(example_ids) == 0: 16 | example_ids = get_example_ids(cat_id) 17 | config = RenderConfig(shape=shape, n_images=n_images, scale=scale) 18 | zip_path = config.get_zip_path(cat_id) 19 | with zipfile.ZipFile(zip_path, mode='a', allowZip64=True) as zf: 20 | bar = IncrementalBar(max=len(example_ids)) 21 | for example_id in example_ids: 22 | example_dir = config.get_example_dir(cat_id, example_id) 23 | if not os.path.isdir(example_dir): 24 | print('No directory at %s' % example_dir) 25 | else: 26 | for fn in os.listdir(example_dir): 27 | src = os.path.join(example_dir, fn) 28 | dst = os.path.join(cat_id, example_id, fn) 29 | zf.write(src, dst) 30 | bar.next() 31 | bar.finish() 32 | if delete_src: 33 | import shutil 34 | print('Removing src...') 35 | shutil.rmtree(config.get_cat_dir(cat_id)) 36 | 37 | 38 | if 
__name__ == '__main__': 39 | import argparse 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('cat', type=str) 42 | parser.add_argument('--shape', type=int, nargs=2, default=[192, 256]) 43 | parser.add_argument('--scale', type=float, default=None) 44 | parser.add_argument('-n', '--n_images', type=int, default=8) 45 | parser.add_argument('-d', '--debug', action='store_true') 46 | parser.add_argument('-o', '--overwrite', action='store_true') 47 | parser.add_argument('-i', '--example_ids', nargs='*') 48 | parser.add_argument('--delete_src', action='store_true') 49 | args = parser.parse_args() 50 | create_archive( 51 | args.cat, args.shape, args.n_images, args.scale, args.example_ids, 52 | delete_src=args.delete_src) 53 | -------------------------------------------------------------------------------- /core/blender_renderings/scripts/render_cat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | """Create uncompressed renderings.""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import shutil 10 | import subprocess 11 | import tempfile 12 | from datetime import datetime 13 | from shapenet.core import get_example_ids, cat_desc_to_id 14 | from shapenet.core.path import get_zip_path, get_example_subdir, \ 15 | get_obj_subpath 16 | from shapenet.core.blender_renderings.path import get_fixed_meshes_zip_path 17 | from shapenet.core.blender_renderings.config import RenderConfig 18 | 19 | _FNULL = open(os.devnull, 'w') 20 | 21 | 22 | def render_obj( 23 | config, obj_path, output_dir, call_kwargs, blender_path='blender'): 24 | script_path = os.path.join( 25 | os.path.realpath(os.path.dirname(__file__)), 'blender_render.py') 26 | scale_str = '1' if config.scale is None else str(config.scale) 27 | subprocess.call([ 28 | blender_path, 29 | '--background', 30 | '--python', script_path, '--', 31 | '--views', str(config.n_images), 32 | '--shape', str(config.shape[0]), str(config.shape[1]), 33 | '--scale', scale_str, 34 | '--output_folder', output_dir, 35 | '--remove_doubles', 36 | '--edge_split', 37 | obj_path, 38 | ], **call_kwargs) 39 | 40 | 41 | def get_file_index(zf): 42 | ret = {} 43 | for f in zf.namelist(): 44 | key = f.split('/') 45 | if len(key) > 2: 46 | ret.setdefault('/'.join(key[:2]), []).append(f) 47 | return ret 48 | 49 | 50 | def render_example( 51 | config, cat_id, example_id, zip_file, overwrite, call_kwargs, 52 | blender_path='blender', verbose=False, file_index=None): 53 | if file_index is None: 54 | print('Warning: computing file_index in render_example. ' 55 | 'Highly inefficient if rendering multiple examples') 56 | file_index = get_file_index(zip_file) 57 | subdir = get_example_subdir(cat_id, example_id) 58 | cat_dir = config.get_cat_dir(cat_id) 59 | example_dir = config.get_example_dir(cat_id, example_id) 60 | if not overwrite: 61 | # some versions of blender create 4 files per image? 
62 | # standard is a .png for each or image, normal, albedo 63 | if (os.path.isdir(example_dir) and 64 | len(os.listdir(example_dir)) in ( 65 | 3*config.n_images, 4*config.n_images)): 66 | return False 67 | else: 68 | if os.path.isdir(example_dir): 69 | shutil.rmtree(example_dir) 70 | if not os.path.isdir(cat_dir): 71 | os.makedirs(cat_dir) 72 | 73 | paths = file_index.get(subdir) 74 | if paths is None: 75 | return False 76 | 77 | tmp = tempfile.mkdtemp(prefix='shapenet') 78 | 79 | for f in paths: 80 | zip_file.extract(f, tmp) 81 | # for f in zip_file.namelist(): 82 | # if f.startswith(subdir): 83 | # zip_file.extract(f, tmp) 84 | subpath = get_obj_subpath(cat_id, example_id) 85 | obj_path = os.path.join(tmp, subpath) 86 | if verbose: 87 | print('') 88 | print(datetime.now()) 89 | print('Rendering %s' % example_id) 90 | render_obj( 91 | config, obj_path, cat_dir, call_kwargs, blender_path=blender_path) 92 | shutil.rmtree(tmp) 93 | return True 94 | 95 | 96 | def render_cat( 97 | config, cat_id, overwrite, reverse=False, debug=False, 98 | example_ids=None, use_fixed_meshes=False, blender_path='blender', 99 | verbose=False): 100 | import zipfile 101 | from progress.bar import IncrementalBar 102 | call_kwargs = {} if debug else dict( 103 | stdout=_FNULL, stderr=subprocess.STDOUT) 104 | if example_ids is None or len(example_ids) == 0: 105 | example_ids = get_example_ids(cat_id) 106 | if reverse: 107 | example_ids = example_ids[-1::-1] 108 | print('Rendering %d images for cat %s' % (len(example_ids), cat_id)) 109 | bar = IncrementalBar(max=len(example_ids)) 110 | if use_fixed_meshes: 111 | zip_path = get_fixed_meshes_zip_path(cat_id) 112 | else: 113 | zip_path = get_zip_path(cat_id) 114 | with zipfile.ZipFile(zip_path) as zip_file: 115 | file_index = get_file_index(zip_file) 116 | for example_id in example_ids: 117 | bar.next() 118 | render_example( 119 | config, cat_id, example_id, zip_file, 120 | overwrite, call_kwargs, blender_path=blender_path, 121 | verbose=verbose, file_index=file_index) 122 | bar.finish() 123 | 124 | 125 | if __name__ == '__main__': 126 | import argparse 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument('cat', type=str) 129 | parser.add_argument('--shape', type=int, nargs=2, default=[192, 256]) 130 | parser.add_argument('--scale', type=float, default=None) 131 | parser.add_argument('--blender_path', type=str, default='blender') 132 | parser.add_argument('-n', '--n_images', type=int, default=8) 133 | parser.add_argument('-d', '--debug', action='store_true') 134 | parser.add_argument('-r', '--reverse', action='store_true') 135 | parser.add_argument('-o', '--overwrite', action='store_true') 136 | parser.add_argument('-i', '--example_ids', nargs='*') 137 | parser.add_argument('-f', '--fixed_meshes', action='store_true') 138 | parser.add_argument('-v', '--verbose', action='store_true') 139 | args = parser.parse_args() 140 | config = RenderConfig(args.shape, args.n_images, args.scale) 141 | cat_id = cat_id = cat_desc_to_id(args.cat) 142 | render_cat(config, cat_id, args.overwrite, args.reverse, args.debug, 143 | args.example_ids, args.fixed_meshes, args.blender_path, 144 | args.verbose) 145 | -------------------------------------------------------------------------------- /core/blender_renderings/scripts/vis.py: -------------------------------------------------------------------------------- 1 | def vis(cat, n_images, view_index=5, example_ids=None): 2 | import matplotlib.pyplot as plt 3 | from shapenet.core import cat_desc_to_id, get_example_ids 4 | from 
shapenet.core.blender_renderings.config import RenderConfig 5 | cat_id = cat_desc_to_id(cat) 6 | config = RenderConfig(n_images=n_images) 7 | dataset = config.get_dataset(cat_id, view_index) 8 | if example_ids is not None and len(example_ids) > 0: 9 | dataset = dataset.subset(example_ids) 10 | else: 11 | example_ids = get_example_ids(cat_id) 12 | with dataset: 13 | for example_id in example_ids: 14 | plt.imshow(dataset[example_id]) 15 | plt.title(example_id) 16 | plt.show() 17 | 18 | 19 | if __name__ == '__main__': 20 | import argparse 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('cat', type=str) 23 | parser.add_argument('--shape', type=int, nargs=2, default=[192, 256]) 24 | parser.add_argument('--blender_path', type=str, default='blender') 25 | parser.add_argument('-n', '--n_images', type=int, default=8) 26 | parser.add_argument('-i', '--example_ids', nargs='*') 27 | parser.add_argument('-v', '--view_index', default=5, type=int) 28 | args = parser.parse_args() 29 | 30 | vis(args.cat, args.n_images, args.view_index, args.example_ids) 31 | -------------------------------------------------------------------------------- /core/fixed_objs.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import os 5 | from .path import get_data_dir 6 | 7 | fixed_obj_data_dir = get_data_dir('fixed_meshes') 8 | 9 | _bad_ids = {'02958343': ('f9c1d7748c15499c6f2bd1c4e9adb41')} 10 | 11 | 12 | def get_fixed_obj_dir(cat_id, example_id=None): 13 | if example_id is None: 14 | return os.path.join(fixed_obj_data_dir, cat_id) 15 | else: 16 | return os.path.join(fixed_obj_data_dir, cat_id, example_id) 17 | 18 | 19 | def get_fixed_obj_path(cat_id, example_id): 20 | return os.path.join(get_fixed_obj_dir(cat_id, example_id), 'model.obj') 21 | 22 | 23 | def get_fixed_example_ids(cat_id=None): 24 | if cat_id is None: 25 | return _bad_ids.copy() 26 | else: 27 | return _bad_ids.get(cat_id, ()) 28 | 29 | 30 | def is_fixed_obj(cat_id, example_id): 31 | return example_id in get_fixed_example_ids(cat_id) 32 | -------------------------------------------------------------------------------- /core/frustrum_voxels/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import numpy as np 7 | import h5py 8 | from ..voxels.datasets import get_manager as get_voxel_manager, GROUP_KEY 9 | 10 | from util3d.transform.frustrum import voxel_values_to_frustrum 11 | from util3d.transform.nonhom import get_eye_to_world_transform 12 | from util3d.voxel.binvox import DenseVoxels, RleVoxels 13 | 14 | 15 | def _make_dir(filename): 16 | folder = os.path.dirname(filename) 17 | if not os.path.isdir(folder): 18 | os.makedirs(folder) 19 | 20 | 21 | def convert(vox, eye, ray_shape, f): 22 | dense_data = vox.dense_data() 23 | dense_data = dense_data[:, -1::-1] 24 | n = np.linalg.norm(eye) 25 | R, t = get_eye_to_world_transform(eye) 26 | z_near = n - 0.5 27 | z_far = z_near + 1 28 | 29 | frust, inside = voxel_values_to_frustrum( 30 | dense_data, R, t, f, z_near, z_far, ray_shape, 31 | include_corners=False) 32 | frust[np.logical_not(inside)] = 0 33 | frust = frust[:, -1::-1] 34 | return DenseVoxels(frust) 35 | 36 | 37 | def _get_frustrum_voxels_path( 38 | voxel_config, view_id, out_dim, cat_id, code=None): 
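    # Resulting layout, derived from the join below: one hdf5 file per category,
    #   {core_data_dir}/frustrum_voxels/{voxel_id}/{view_id}/v{out_dim:03d}/{cat_id}.hdf5
    # and, when `code` is given (e.g. the oversized scratch file uses code='temp'),
    #   {core_data_dir}/frustrum_voxels/{voxel_id}/{view_id}/v{out_dim:03d}/temp_{cat_id}.hdf5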
39 | from ..path import get_data_dir 40 | fn = ('%s.hdf5' % cat_id) if code is None else ( 41 | '%s_%s.hdf5' % (code, cat_id)) 42 | return os.path.join( 43 | get_data_dir('frustrum_voxels'), voxel_config.voxel_id, view_id, 44 | 'v%03d' % out_dim, fn) 45 | 46 | 47 | def get_frustrum_voxels_path( 48 | voxel_config, view_id, out_dim, cat_id): 49 | return _get_frustrum_voxels_path( 50 | voxel_config, view_id, out_dim, cat_id) 51 | 52 | 53 | def get_frustrum_voxels_data(voxel_config, view_id, out_dim, cat_id): 54 | return h5py.File(get_frustrum_voxels_path( 55 | voxel_config, view_id, out_dim, cat_id), 'r') 56 | 57 | 58 | def create_temp_frustrum_voxels( 59 | view_manager, voxel_config, out_dim, cat_id, compression='lzf'): 60 | from progress.bar import IncrementalBar 61 | view_params = view_manager.get_view_params() 62 | n_views = view_params['n_views'] 63 | f = view_params['f'] 64 | in_dims = (voxel_config.voxel_dim,) * 3 65 | ray_shape = (out_dim,) * 3 66 | example_ids = tuple(view_manager.get_example_ids(cat_id)) 67 | n0 = len(example_ids) 68 | 69 | temp_path = _get_frustrum_voxels_path( 70 | voxel_config, view_manager.view_id, out_dim, 71 | cat_id, code='temp') 72 | _make_dir(temp_path) 73 | with h5py.File(temp_path, 'a') as vox_dst: 74 | attrs = vox_dst.attrs 75 | prog = attrs.get('prog', 0) 76 | if prog == n0: 77 | return temp_path 78 | 79 | attrs.setdefault('n_views', n_views) 80 | max_len = attrs.setdefault('max_len', 0) 81 | 82 | vox_manager = get_voxel_manager( 83 | voxel_config, cat_id, key='rle', 84 | compression=compression, shape_key='pad') 85 | vox_manager.get_dataset() # ensure data exists 86 | assert(vox_manager.has_dataset()) 87 | with h5py.File(vox_manager.path, 'r') as vox_src: 88 | rle_src = vox_src[GROUP_KEY] 89 | 90 | n, m = rle_src.shape 91 | max_max_len = m * 3 92 | assert(n == n0) 93 | 94 | print( 95 | 'Creating temp rle frustrum voxel data at %s' % temp_path) 96 | rle_dst = vox_dst.require_dataset( 97 | GROUP_KEY, shape=(n, n_views, max_max_len), 98 | dtype=np.uint8, compression=compression) 99 | bar = IncrementalBar(max=n-prog) 100 | for i in range(prog, n): 101 | bar.next() 102 | voxels = RleVoxels(np.array(rle_src[i]), in_dims) 103 | eye = view_manager.get_camera_positions(cat_id, example_ids[i]) 104 | for j in range(n_views): 105 | out = convert(voxels, eye[j], ray_shape, f) 106 | data = out.rle_data() 107 | dlen = len(data) 108 | if dlen > max_len: 109 | attrs['max_len'] = dlen 110 | max_len = dlen 111 | if dlen > max_max_len: 112 | raise ValueError( 113 | 'max_max_len exceeded. 
%d > %d' 114 | % (dlen, max_max_len)) 115 | rle_dst[i, j, :dlen] = data 116 | attrs['prog'] = i+1 117 | bar.finish() 118 | return temp_path 119 | 120 | 121 | def _shrink_data(temp_path, dst_path, chunk_size=100, compression='lzf'): 122 | from progress.bar import IncrementalBar 123 | print('Shrinking data to fit.') 124 | with h5py.File(temp_path, 'r') as src: 125 | max_len = int(src.attrs['max_len']) 126 | src_group = src[GROUP_KEY] 127 | _make_dir(dst_path) 128 | with h5py.File(dst_path, 'w') as dst: 129 | n_examples, n_renderings = src_group.shape[:2] 130 | dst_dataset = dst.create_dataset( 131 | GROUP_KEY, shape=(n_examples, n_renderings, max_len), 132 | dtype=np.uint8, compression=compression) 133 | bar = IncrementalBar(max=n_examples // chunk_size) 134 | for i in range(0, n_examples, chunk_size): 135 | stop = min(i + chunk_size, n_examples) 136 | dst_dataset[i:stop] = src_group[i:stop, :, :max_len] 137 | bar.next() 138 | bar.finish() 139 | 140 | 141 | # def _concat_data(temp_path, dst_path): 142 | # from progress.bar import IncrementalBar 143 | # from util3d.voxel import rle 144 | # print('Concatenating data') 145 | # with h5py.File(temp_path, 'r') as src: 146 | # src_group = src[GROUP_KEY] 147 | # _make_dir(dst_path) 148 | # with h5py.File(dst_path, 'w') as dst: 149 | # n_examples, n_renderings = src_group.shape[:2] 150 | # n_total = n_examples * n_renderings 151 | # starts = np.empty(dtype=np.int64, shape=(n_total+1,)) 152 | # print('Computing starts...') 153 | # k = 1 154 | # start = 0 155 | # starts[0] = start 156 | # bar = IncrementalBar(max=n_examples) 157 | # for i in range(n_examples): 158 | # bar.next() 159 | # example_data = np.array(src_group[i]) 160 | # for j in range(n_renderings): 161 | # data = rle.remove_length_padding(example_data[j]) 162 | # start += len(data) 163 | # starts[k] = start 164 | # k += 1 165 | # bar.finish() 166 | # assert(k == n_total+1) 167 | # dst.create_dataset('starts', data=starts) 168 | # values = dst.create_dataset( 169 | # 'values', dtype=np.uint8, shape=(starts[-1],)) 170 | 171 | # k = 0 172 | # print('Transfering data...') 173 | # bar = IncrementalBar(max=n_examples) 174 | # for i in range(n_examples): 175 | # example_data = np.array(src_group[i]) 176 | # bar.next() 177 | # for j in range(n_renderings): 178 | # values[starts[k]: starts[k+1]] = \ 179 | # rle.remove_length_padding(example_data[j]) 180 | # k += 1 181 | # bar.finish() 182 | # assert(k == n_total) 183 | 184 | 185 | def create_frustrum_voxels( 186 | view_manager, voxel_config, out_dim, cat_id, compression='lzf'): 187 | kwargs = dict( 188 | voxel_config=voxel_config, 189 | out_dim=out_dim, cat_id=cat_id) 190 | dst_path = _get_frustrum_voxels_path( 191 | view_id=view_manager.view_id, code=None, **kwargs) 192 | if os.path.isfile(dst_path): 193 | print('Already present.') 194 | return 195 | temp_path = create_temp_frustrum_voxels( 196 | view_manager=view_manager, compression=compression, **kwargs) 197 | _shrink_data(temp_path, dst_path, compression=compression) 198 | # _concat_data(temp_path, dst_path) 199 | 200 | 201 | # def fix(view_manager, voxel_config, out_dim, cat_id, example_index, 202 | # view_index, example_ids, compression='lzf'): 203 | # # from util3d.voxel.rle import length 204 | # dst_path = _get_frustrum_voxels_path( 205 | # manager_dir=view_manager.root_dir, code=None, 206 | # voxel_config=voxel_config, out_dim=out_dim, cat_id=cat_id) 207 | # # load src data 208 | # vox_manager = get_voxel_manager( 209 | # voxel_config, cat_id, key='rle', 210 | # compression=compression, 
shape_key='pad') 211 | # if not vox_manager.has_dataset(): 212 | # raise RuntimeError('No dataset') 213 | # with h5py.File(vox_manager.path, 'r') as vox_src: 214 | # rle_src = vox_src[GROUP_KEY] 215 | # voxels = RleVoxels( 216 | # np.array(rle_src[example_index]), 217 | # (voxel_config.voxel_dim,) * 3) 218 | # eye = view_manager.get_camera_positions( 219 | # cat_id, example_ids[example_index]) 220 | # f = view_manager.get_view_params()['f'] 221 | # # convert 222 | # rle_data = convert(voxels, eye, (out_dim,)*3, f).rle_data() 223 | # # save 224 | # with h5py.File(dst_path, 'a') as dst: 225 | # dst_group = dst[GROUP_KEY] 226 | # padded_rle_data = np.zeros((dst_group.shape[2],), dtype=np.uint8) 227 | # padded_rle_data[:len(rle_data)] = rle_data 228 | # dst_group[example_index, view_index] = padded_rle_data 229 | 230 | # # with h5py.File(dst_path, 'r') as dst: 231 | # # dst_group = dst[GROUP_KEY] 232 | # # old_data = np.array(dst_group[example_index, view_index]) 233 | # # print(length(old_data)) 234 | # # print(length(rle_data)) 235 | 236 | # # start = 500 237 | # # end = 520 238 | # # print(rle_data[start:end]) 239 | # # print(old_data[start:end]) 240 | # # print(out_dim) 241 | # # print(out_dim**3) 242 | # # print(length(rle_data)) 243 | # # print(length(old_data)) 244 | -------------------------------------------------------------------------------- /core/frustrum_voxels/scripts/check_frustrum_voxels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from absl import flags, app 7 | FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_integer('dim', default=256, help='dimension of square renderings') 10 | flags.DEFINE_bool('turntable', default=False, help='if True, renderings') 11 | flags.DEFINE_integer('n_views', default=24, help='number of views') 12 | flags.DEFINE_integer('voxel_dim', default=32, help='output voxel dimension') 13 | flags.DEFINE_integer( 14 | 'src_voxel_dim', default=256, help='input voxel dimension') 15 | flags.DEFINE_list( 16 | 'cat', default=None, help='category ids/descriptors, comma separated') 17 | 18 | 19 | def main(_): 20 | from progress.bar import IncrementalBar 21 | import numpy as np 22 | from shapenet.core import to_cat_id 23 | from shapenet.core.renderings.renderings_manager import get_base_manager 24 | from shapenet.core.frustrum_voxels import get_frustrum_voxels_data 25 | from shapenet.core.frustrum_voxels import GROUP_KEY 26 | from shapenet.core.voxels.config import get_config 27 | from util3d.voxel import rle 28 | voxel_config = get_config( 29 | FLAGS.src_voxel_dim, alt=False).filled('orthographic') 30 | manager = get_base_manager( 31 | dim=FLAGS.dim, turntable=FLAGS.turntable, n_views=FLAGS.n_views) 32 | 33 | cats = FLAGS.cat 34 | if cats is None: 35 | from shapenet.r2n2 import get_cat_ids 36 | cats = get_cat_ids() 37 | 38 | expected_length = FLAGS.voxel_dim**3 39 | 40 | for ci, cat in enumerate(cats): 41 | # if ci >= 3: 42 | # continue 43 | cat_id = to_cat_id(cat) 44 | print('Checking cat %s: %d / %d' % (cat_id, ci+1, len(cats))) 45 | with get_frustrum_voxels_data( 46 | manager.root_dir, voxel_config, FLAGS.voxel_dim, 47 | cat_id) as root: 48 | data = root[GROUP_KEY] 49 | ne, nr = data.shape[:2] 50 | bar = IncrementalBar(max=ne) 51 | for i in range(ne): 52 | bar.next() 53 | di = np.array(data[i]) 54 | for j in range(nr): 55 | actual_length = rle.length(di[j]) 56 | # actual_length = 
len(rle.rle_to_dense(di[j])) 57 | if actual_length != expected_length: 58 | raise ValueError( 59 | 'Incorrect length at %s, %d, %d\n' 60 | 'Expected %d, got %d' 61 | % (cat_id, i, j, expected_length, actual_length)) 62 | bar.finish() 63 | 64 | 65 | app.run(main) 66 | -------------------------------------------------------------------------------- /core/frustrum_voxels/scripts/create_frustrum_voxels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from absl import flags, app 7 | FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_integer('dim', default=256, help='dimension of square renderings') 10 | flags.DEFINE_bool( 11 | 'turntable', default=False, 12 | help='if True, views are based on equal rotation about z axis') 13 | flags.DEFINE_integer('n_views', default=24, help='number of views') 14 | flags.DEFINE_integer('voxel_dim', default=32, help='output voxel dimension') 15 | flags.DEFINE_integer( 16 | 'src_voxel_dim', default=None, help='input voxel dimension') 17 | flags.DEFINE_list( 18 | 'cat', default=None, help='category ids/descriptors, comma separated') 19 | flags.DEFINE_bool('temp_only', default=False, help='If True, does not squeeze') 20 | flags.DEFINE_string( 21 | 'compression', default='lzf', help='compression for encoded data') 22 | flags.DEFINE_string( 23 | 'fill_alg', default='orthographic', 24 | help='fill algorithm for base voxel grid.') 25 | 26 | 27 | def main(_): 28 | from shapenet.core import to_cat_id 29 | from shapenet.core.renderings.renderings_manager import get_base_manager 30 | from shapenet.core.frustrum_voxels import create_frustrum_voxels 31 | from shapenet.core.frustrum_voxels import create_temp_frustrum_voxels 32 | from shapenet.core.voxels.config import get_config 33 | if FLAGS.src_voxel_dim is None: 34 | FLAGS.src_voxel_dim = FLAGS.voxel_dim 35 | 36 | voxel_config = get_config( 37 | FLAGS.src_voxel_dim, alt=False).filled(FLAGS.fill_alg) 38 | manager = get_base_manager( 39 | dim=FLAGS.dim, turntable=FLAGS.turntable, 40 | n_views=FLAGS.n_views) 41 | 42 | cats = FLAGS.cat 43 | if cats is None: 44 | from shapenet.r2n2 import get_cat_ids 45 | cats = get_cat_ids() 46 | 47 | for cat in cats: 48 | cat_id = to_cat_id(cat) 49 | args = manager, voxel_config, FLAGS.voxel_dim, cat_id 50 | if FLAGS.temp_only: 51 | create_temp_frustrum_voxels(*args, compression=FLAGS.compression) 52 | else: 53 | create_frustrum_voxels(*args, compression=FLAGS.compression) 54 | 55 | 56 | app.run(main) 57 | -------------------------------------------------------------------------------- /core/meshes/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import dids.file_io.hdf5 as dh 3 | from ..path import get_data_dir 4 | 5 | 6 | _meshes_dir = get_data_dir('meshes') 7 | 8 | 9 | class _MeshAutoSavingManager(dh.Hdf5AutoSavingManager): 10 | def __init__(self, cat_id): 11 | self.cat_id = cat_id 12 | 13 | @property 14 | def path(self): 15 | return os.path.join(_meshes_dir, '%s.hdf5' % self.cat_id) 16 | 17 | @property 18 | def saving_message(self): 19 | return 'Parsing mesh data for cat_id %s' % self.cat_id 20 | 21 | def get_lazy_dataset(self): 22 | from shapenet.core.objs import get_obj_file_dataset 23 | from util3d.mesh.obj_io import parse_obj_file 24 | 25 | def map_fn(f): 26 | vertices, faces = parse_obj_file(f)[:2] 27 | return dict(vertices=vertices, faces=faces) 
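        # The obj-file dataset below is mapped lazily: each example id resolves
        # to a dict of 'vertices' and 'faces' only when it is read (e.g. while
        # save_all() is writing the per-category hdf5 file).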
28 | 29 | return get_obj_file_dataset(self.cat_id).map(map_fn) 30 | 31 | 32 | def remove_empty_meshes(cat_id): 33 | from shapenet.core.meshes import get_mesh_dataset 34 | to_remove = [] 35 | with get_mesh_dataset(cat_id, mode='a') as ds: 36 | for example_id, mesh in ds.items(): 37 | if len(mesh['faces']) == 0: 38 | to_remove.append(example_id) 39 | for t in to_remove: 40 | del ds[t] 41 | return to_remove 42 | 43 | 44 | def generate_mesh_data(cat_id, overwrite=False): 45 | _MeshAutoSavingManager(cat_id).save_all(overwrite=overwrite) 46 | remove_empty_meshes(cat_id) 47 | 48 | 49 | def get_mesh_dataset(cat_id, mode='r'): 50 | manager = _MeshAutoSavingManager(cat_id) 51 | if not os.path.isfile(manager.path): 52 | manager.save_all() 53 | return manager.get_saving_dataset(mode) 54 | 55 | 56 | # __all__ = [ 57 | # get_mesh_dataset, 58 | # generate_mesh_data, 59 | # ] 60 | -------------------------------------------------------------------------------- /core/meshes/scripts/check_mesh_data.py: -------------------------------------------------------------------------------- 1 | def check_mesh_data(cat_desc): 2 | from shapenet.core import cat_desc_to_id, get_example_ids 3 | from shapenet.core.meshes import get_mesh_dataset 4 | cat_id = cat_desc_to_id(cat_desc) 5 | example_ids = get_example_ids(cat_id) 6 | n_absent = 0 7 | with get_mesh_dataset(cat_id) as ds: 8 | for example_id in example_ids: 9 | if example_id not in ds: 10 | n_absent += 1 11 | 12 | n = len(example_ids) 13 | if n_absent == 0: 14 | print('All %d %s meshes present!' % (n, cat_desc)) 15 | else: 16 | print('%d / %d %s meshes absent' % (n_absent, n, cat_desc)) 17 | 18 | 19 | if __name__ == '__main__': 20 | import argparse 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('cat', type=str) 23 | 24 | args = parser.parse_args() 25 | check_mesh_data(args.cat) 26 | -------------------------------------------------------------------------------- /core/meshes/scripts/generate_mesh_data.py: -------------------------------------------------------------------------------- 1 | def generate_mesh_data(cat_desc, overwrite=False): 2 | from shapenet.core import cat_desc_to_id 3 | from shapenet.core.meshes import generate_mesh_data 4 | 5 | cat_id = cat_desc_to_id(cat_desc) 6 | generate_mesh_data(cat_id, overwrite=overwrite) 7 | 8 | 9 | if __name__ == '__main__': 10 | import argparse 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('cat', type=str, nargs='+') 13 | parser.add_argument('-o', '--overwrite', action='store_true') 14 | 15 | args = parser.parse_args() 16 | for cat in args.cat: 17 | generate_mesh_data(cat, args.overwrite) 18 | -------------------------------------------------------------------------------- /core/meshes/scripts/remove_empty_meshes.py: -------------------------------------------------------------------------------- 1 | if __name__ == '__main__': 2 | from shapenet.core import get_cat_ids 3 | from shapenet.core.meshes import remove_empty_meshes 4 | for cat_id in get_cat_ids(): 5 | remove_empty_meshes(cat_id) 6 | -------------------------------------------------------------------------------- /core/objs.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | from .path import get_zip_path, get_zip_file, get_extracted_core_dir 7 | 8 | 9 | def get_obj_file_dataset(cat_id): 10 | """Get a DIDS dataset mapping example_ids to obj file 
objects.""" 11 | import dids.file_io.zip_file_dataset as zfd 12 | dataset = zfd.ZipFileDataset(get_zip_path(cat_id)) 13 | 14 | def key_fn(example_id): 15 | return os.path.join(cat_id, example_id, 'model.obj') 16 | 17 | def inverse_key_fn(path): 18 | subpaths = path.split('/') 19 | if len(subpaths) == 3 and subpaths[2][-4:] == '.obj': 20 | return subpaths[1] 21 | else: 22 | return None 23 | 24 | dataset = dataset.map_keys(key_fn, inverse_key_fn) 25 | return dataset 26 | 27 | 28 | def get_extracted_obj_path_dataset( 29 | cat_id, include_bad=False, use_fixed_meshes=True): 30 | from dids.core import FunctionDataset 31 | from .fixed_objs import get_fixed_obj_path, get_fixed_example_ids 32 | from . import get_example_ids 33 | from .path import get_obj_subpath 34 | extracted_dir = get_extracted_core_dir() 35 | example_ids = get_example_ids(cat_id, include_bad=include_bad) 36 | try_extract_models(cat_id) 37 | 38 | if use_fixed_meshes: 39 | fixed_paths = { 40 | k: get_fixed_obj_path(cat_id, k) 41 | for k in get_fixed_example_ids(cat_id)} 42 | 43 | def get_path(example_id): 44 | if example_id in fixed_paths: 45 | return fixed_paths[example_id] 46 | else: 47 | return os.path.join( 48 | extracted_dir, get_obj_subpath(cat_id, example_id)) 49 | else: 50 | def get_path(example_id): 51 | return os.path.join( 52 | extracted_dir, get_obj_subpath(cat_id, example_id)) 53 | 54 | return FunctionDataset(get_path, example_ids) 55 | 56 | 57 | def get_extract_obj_file_dataset( 58 | cat_id, include_bad=False, use_fixed_meshes=True): 59 | return get_extracted_obj_path_dataset( 60 | cat_id, include_bad=include_bad, 61 | use_fixed_meshes=use_fixed_meshes).map(lambda path: open(path, 'r')) 62 | 63 | 64 | def extract_models(cat_id): 65 | extraction_dir = get_extracted_core_dir(cat_id) 66 | if os.path.isdir(extraction_dir): 67 | raise IOError('Directory %s already exists' % extraction_dir) 68 | _extract_models(cat_id) 69 | 70 | 71 | def _extract_models(cat_id): 72 | folder = get_extracted_core_dir() 73 | print('Extracting obj models for cat %s to %s' % (cat_id, folder)) 74 | if not os.path.isdir(folder): 75 | os.makedirs(folder) 76 | with get_zip_file(cat_id) as zf: 77 | zf.extractall(folder) 78 | 79 | 80 | def try_extract_models(cat_id): 81 | from . import get_example_ids 82 | folder = get_extracted_core_dir(cat_id) 83 | if os.path.isdir(folder): 84 | example_ids = os.listdir(folder) 85 | if len(example_ids) != len(get_example_ids(cat_id, include_bad=True)): 86 | _extract_models(cat_id) 87 | else: 88 | _extract_models(cat_id) 89 | 90 | 91 | _bad_objs = { 92 | '04090263': ('4a32519f44dc84aabafe26e2eb69ebf4',) 93 | } 94 | 95 | 96 | def get_bad_objs(cat_id): 97 | """Get a tuple of ids that are "bad" - no vertices.""" 98 | return _bad_objs.get(cat_id, ()) 99 | 100 | 101 | def is_bad_obj(cat_id, example_id): 102 | """Flag indicating whether the obj file is "bad" (has no vertices).""" 103 | return cat_id in _bad_objs and example_id in _bad_objs[cat_id] 104 | 105 | 106 | def remove_extracted_models(cat_id, confirm=True): 107 | import shutil 108 | extraction_dir = get_extracted_core_dir(cat_id) 109 | if os.path.isdir(extraction_dir): 110 | if confirm: 111 | try: 112 | get_input = raw_input 113 | except NameError: 114 | get_input = input 115 | confirmed = get_input( 116 | 'Really delete directory %s? 
(y/N) ' % extraction_dir).lower() 117 | if confirmed == 'y': 118 | shutil.rmtree(extraction_dir) 119 | print('Removed %s' % extraction_dir) 120 | else: 121 | print('NOT removing %s' % extraction_dir) 122 | else: 123 | print('No directory at %s' % extraction_dir) 124 | -------------------------------------------------------------------------------- /core/path.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import zipfile 7 | from .. import config 8 | 9 | 10 | def get_core_dir(): 11 | return config.config['core_dir'] 12 | 13 | 14 | def get_data_dir(*args): 15 | from .. import path 16 | return path.get_data_dir('core', *args) 17 | 18 | 19 | _ids_dir = get_data_dir('ids') 20 | 21 | 22 | def get_extracted_core_dir(cat_id=None): 23 | args = 'extracted', 24 | if cat_id is not None: 25 | args = args + (cat_id,) 26 | return get_data_dir(*args) 27 | 28 | 29 | def get_csv_path(cat_id): 30 | return os.path.join(get_core_dir(), '%s.csv' % cat_id) 31 | 32 | 33 | def get_zip_path(cat_id): 34 | return os.path.join(get_core_dir(), '%s.zip' % cat_id) 35 | 36 | 37 | def get_test_train_split_path(): 38 | return os.path.join( 39 | os.path.realpath(os.path.dirname(__file__)), 'split.csv') 40 | 41 | 42 | def get_zip_file(cat_id): 43 | return zipfile.ZipFile(get_zip_path(cat_id)) 44 | 45 | 46 | def get_example_subdir(cat_id, example_id): 47 | return os.path.join(cat_id, example_id) 48 | 49 | 50 | def get_obj_subpath(cat_id, example_id): 51 | return os.path.join(cat_id, example_id, 'model.obj') 52 | 53 | 54 | def get_mtl_subpath(cat_id, example_id): 55 | return os.path.join(cat_id, example_id, 'model.mtl') 56 | 57 | 58 | def _get_example_ids_from_zip(cat_id, category_zipfile): 59 | start = len(cat_id) + 1 60 | end = -len('model.obj')-1 61 | names = [ 62 | n[start:end] for n in category_zipfile.namelist() if n[-4:] == '.obj'] 63 | return names 64 | 65 | 66 | def get_example_ids_from_zip(cat_id, category_zipfile=None): 67 | if category_zipfile is None: 68 | with get_zip_file(cat_id) as f: 69 | return _get_example_ids_from_zip(cat_id, f) 70 | else: 71 | return _get_example_ids_from_zip(cat_id, category_zipfile) 72 | 73 | 74 | def get_ids_path(cat_id): 75 | return os.path.join(_ids_dir, '%s.txt' % cat_id) 76 | -------------------------------------------------------------------------------- /core/point_clouds/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import os 5 | import numpy as np 6 | from dids.file_io.hdf5 import Hdf5AutoSavingManager 7 | from dids.core import BiKeyDataset 8 | from ..path import get_data_dir 9 | 10 | 11 | _point_cloud_dir = get_data_dir('point_clouds') 12 | 13 | 14 | class PointCloudAutoSavingManager(Hdf5AutoSavingManager): 15 | def __init__(self, cat_id, n_samples, example_ids=None): 16 | if not isinstance(n_samples, int): 17 | raise ValueError('n_samples must be an int') 18 | self._cat_id = cat_id 19 | self._n_samples = n_samples 20 | self._example_ids = example_ids 21 | 22 | @property 23 | def path(self): 24 | return os.path.join( 25 | _point_cloud_dir, 'point_clouds', str(self._n_samples), 26 | '%s.hdf5' % self._cat_id) 27 | 28 | @property 29 | def saving_message(self): 30 | return ('Creating shapenet point cloud\n' 31 | 'cat_id: %s\n' 32 | 
'n_samples: %d' % (self._cat_id, self._n_samples)) 33 | 34 | def get_lazy_dataset(self): 35 | from shapenet.core.meshes import get_mesh_dataset 36 | from util3d.mesh.sample import sample_faces 37 | 38 | def map_fn(mesh): 39 | v, f = (np.array(mesh[k]) for k in ('vertices', 'faces')) 40 | return sample_faces(v, f, self._n_samples) 41 | 42 | mesh_dataset = get_mesh_dataset(self._cat_id) 43 | if self._example_ids is not None: 44 | mesh_dataset = mesh_dataset.subset(self._example_ids) 45 | 46 | with mesh_dataset: 47 | keys = [k for k, v in mesh_dataset.items() if len(v['faces']) > 0] 48 | mesh_dataset = mesh_dataset.subset(keys) 49 | 50 | return mesh_dataset.map(map_fn) 51 | 52 | 53 | def _get_point_cloud_dataset(cat_id, n_samples, example_ids=None, mode='r'): 54 | """ 55 | Get a dataset with point cloud data. 56 | 57 | Args: 58 | `cat_id`: category id 59 | `n_samples`: number of points in each cloud 60 | `example_ids`: If specified, only create data for these examples, 61 | and return a dataset with only these ids exposed as keys. 62 | Defaults to all examples in the category. 63 | `mode`: mode to open the file in. If `a`, data can be deleted or saved. 64 | """ 65 | manager = PointCloudAutoSavingManager(cat_id, n_samples, example_ids) 66 | if not os.path.isfile(manager.path): 67 | dataset = manager.get_saved_dataset() 68 | else: 69 | if mode not in ('a', 'r'): 70 | raise NotImplementedError('mode must be in ("a", "r")') 71 | dataset = manager.get_saving_dataset(mode=mode) 72 | if example_ids is not None: 73 | dataset = dataset.subset(example_ids) 74 | return dataset 75 | 76 | 77 | def get_point_cloud_dataset(cat_id, n_samples, example_ids=None, mode='r'): 78 | def f(c, e): 79 | return _get_point_cloud_dataset(c, n_samples, e, mode) 80 | 81 | if isinstance(cat_id, (tuple, list)): 82 | if example_ids is None: 83 | example_ids = tuple(None for _ in cat_id) 84 | datasets = { 85 | c: f(c, e) 86 | for c, e in zip(cat_id, example_ids)} 87 | return BiKeyDataset(datasets) 88 | else: 89 | return f(cat_id, example_ids) 90 | 91 | 92 | class CloudNormalAutoSavingManager(Hdf5AutoSavingManager): 93 | def __init__(self, cat_id, n_samples, example_ids=None): 94 | if not isinstance(n_samples, int): 95 | raise ValueError('n_samples must be an int') 96 | self._cat_id = cat_id 97 | self._n_samples = n_samples 98 | self._example_ids = example_ids 99 | 100 | @property 101 | def path(self): 102 | return os.path.join( 103 | _point_cloud_dir, 'cloud_normals', str(self._n_samples), 104 | '%s.hdf5' % self._cat_id) 105 | 106 | @property 107 | def saving_message(self): 108 | return ('Creating shapenet cloud normal data\n' 109 | 'cat_id: %s\n' 110 | 'n_samples: %d' % (self._cat_id, self._n_samples)) 111 | 112 | def get_lazy_dataset(self): 113 | from shapenet.core.meshes import get_mesh_dataset 114 | from util3d.mesh.sample import sample_faces_with_normals 115 | 116 | def map_fn(mesh): 117 | v, f = (np.array(mesh[k]) for k in ('vertices', 'faces')) 118 | p, n = sample_faces_with_normals(v, f, self._n_samples) 119 | return dict(points=p, normals=n) 120 | 121 | mesh_dataset = get_mesh_dataset(self._cat_id) 122 | if self._example_ids is not None: 123 | mesh_dataset = mesh_dataset.subset(self._example_ids) 124 | 125 | with mesh_dataset: 126 | keys = [k for k, v in mesh_dataset.items() if len(v['faces']) > 0] 127 | mesh_dataset = mesh_dataset.subset(keys) 128 | 129 | return mesh_dataset.map(map_fn) 130 | 131 | 132 | def _get_cloud_normal_dataset(cat_id, n_samples, example_ids=None, mode='r'): 133 | manager = 
CloudNormalAutoSavingManager(cat_id, n_samples, example_ids) 134 | if not os.path.isfile(manager.path): 135 | dataset = manager.get_saved_dataset(mode=mode) 136 | else: 137 | dataset = manager.get_saving_dataset(mode=mode) 138 | if example_ids is not None: 139 | dataset = dataset.subset(example_ids) 140 | return dataset 141 | 142 | 143 | def get_cloud_normal_dataset(cat_id, n_samples, example_ids=None, mode='r'): 144 | if not isinstance(cat_id, (tuple, list)): 145 | cat_id = [cat_id] 146 | example_ids = [example_ids] 147 | else: 148 | if example_ids is None: 149 | example_ids = [None for _ in cat_id] 150 | datasets = { 151 | c: _get_cloud_normal_dataset(c, n_samples, e, mode) 152 | for c, e in zip(cat_id, example_ids)} 153 | return BiKeyDataset(datasets) 154 | 155 | 156 | def generate_point_cloud_data( 157 | cat_id, n_samples, normals=False, example_ids=None, overwrite=False): 158 | if normals: 159 | CloudNormalAutoSavingManager( 160 | cat_id, n_samples, example_ids).save_all(overwrite=overwrite) 161 | get_cloud_normal_dataset(cat_id, n_samples) 162 | else: 163 | PointCloudAutoSavingManager( 164 | cat_id, n_samples, example_ids).save_all(overwrite=overwrite) 165 | -------------------------------------------------------------------------------- /core/point_clouds/scripts/generate_point_clouds.py: -------------------------------------------------------------------------------- 1 | """Example usage: 2 | ```bash 3 | python generate_point_clouds.py rifle -s=1024 4 | ``` 5 | """ 6 | 7 | def generate_point_cloud_data(cat_desc, samples, normals, overwrite=False): 8 | from shapenet.core import cat_desc_to_id 9 | from shapenet.core.point_clouds import generate_point_cloud_data 10 | 11 | cat_id = cat_desc_to_id(cat_desc) 12 | generate_point_cloud_data(cat_id, samples, normals, overwrite=overwrite) 13 | 14 | 15 | if __name__ == '__main__': 16 | import argparse 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('cat', type=str, nargs='+') 19 | parser.add_argument('-s', '--samples', type=int, help='number of samples') 20 | parser.add_argument('-n', '--normals', action='store_true') 21 | parser.add_argument('-o', '--overwrite', action='store_true') 22 | 23 | args = parser.parse_args() 24 | for cat in args.cat: 25 | generate_point_cloud_data( 26 | cat, args.samples, args.normals, args.overwrite) 27 | -------------------------------------------------------------------------------- /core/renderings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jackd/shapenet/4ae662743b0e5d0bd4d96f224c96be811149b9eb/core/renderings/__init__.py -------------------------------------------------------------------------------- /core/renderings/archive.py: -------------------------------------------------------------------------------- 1 | """Uniform interface for zip/tar archives.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import zipfile 8 | import tarfile 9 | 10 | 11 | class Archive(object): 12 | def add(self, src, dst): 13 | raise NotImplementedError('Abstract method') 14 | 15 | def get_names(self): 16 | raise NotImplementedError('Abstract method') 17 | 18 | def __enter__(self): 19 | self.open() 20 | return self 21 | 22 | def __exit__(self, *args, **kwargs): 23 | self.close() 24 | 25 | def open(self): 26 | raise NotImplementedError('Abstract method') 27 | 28 | def close(self): 29 | raise NotImplementedError('Abstract method') 30 | 
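    # Typical use is via the get_archive() factory defined at the bottom of this
    # module; the paths below are illustrative only:
    #
    #     with get_archive('/path/to/renderings.zip', mode='a') as archive:
    #         archive.add('/tmp/r000.png', 'cat_id/example_id/r000.png')
    #         print(archive.get_names())
    #
    # The same calls work unchanged for a '.tar' path (TarArchive).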
31 | def extractall(self, path=None): 32 | raise NotImplementedError('Abstract method') 33 | 34 | 35 | class ArchiveBase(Archive): 36 | def __init__(self, path, mode): 37 | self._path = path 38 | self._mode = mode 39 | self._file = None 40 | 41 | @property 42 | def path(self): 43 | return self._path 44 | 45 | @property 46 | def mode(self): 47 | return self._mode 48 | 49 | def get_open_file(self): 50 | raise NotImplementedError('Abstract method') 51 | 52 | def __str__(self): 53 | return '' % self._path 54 | 55 | def __repr__(self): 56 | return self.__str__() 57 | 58 | def close(self): 59 | self._file = None 60 | 61 | def open(self): 62 | self._file = self.get_open_file() 63 | 64 | 65 | class ZipArchive(ArchiveBase): 66 | def get_open_file(self): 67 | return zipfile.ZipFile( 68 | self.path, self.mode, zipfile.ZIP_DEFLATED, allowZip64=True) 69 | 70 | def add(self, src, dst): 71 | self._file.write(src, dst) 72 | 73 | def get_names(self): 74 | return self._file.namelist() 75 | 76 | def extractall(self, path='.'): 77 | return self._file.extractall(path=path) 78 | 79 | 80 | class TarArchive(ArchiveBase): 81 | def get_open_file(self): 82 | return tarfile.open(self.path, self.mode) 83 | 84 | def add(self, src, dst): 85 | self._file.add(src, dst) 86 | 87 | def get_names(self): 88 | return self._file.getnames() 89 | 90 | def extractall(self, path='.'): 91 | return self._file.extractall(path) 92 | 93 | 94 | def get_archive(path, mode='r'): 95 | ext = os.path.splitext(path)[1] 96 | if ext == '.zip': 97 | return ZipArchive(path, mode) 98 | elif ext == '.tar': 99 | return TarArchive(path, mode) 100 | else: 101 | raise KeyError('Unrecognized archive extension "%s"' % ext) 102 | -------------------------------------------------------------------------------- /core/renderings/archive_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | from .. import to_cat_id, get_example_ids 7 | from . 
import archive 8 | 9 | 10 | class DummyBar(object): 11 | def __init__(self, *args, **kwargs): 12 | pass 13 | 14 | def next(self): 15 | pass 16 | 17 | def finish(self): 18 | pass 19 | 20 | 21 | def get_bar(show_progress, *args, **kwargs): 22 | from progress.bar import IncrementalBar 23 | return (IncrementalBar if show_progress else DummyBar)(*args, **kwargs) 24 | 25 | 26 | class ArchiveManager(object): 27 | def __init__(self, renderings_manager, cat, archive, base_only): 28 | self._renderings_manager = renderings_manager 29 | self._cat_id = to_cat_id(cat) 30 | self._archive = archive 31 | self._base_only = base_only 32 | 33 | @property 34 | def archive(self): 35 | return self._archive 36 | 37 | @property 38 | def cat_id(self): 39 | return self._cat_id 40 | 41 | @property 42 | def renderings_manager(self): 43 | return self._renderings_manager 44 | 45 | def get_src_paths(self): 46 | cat_id = self.cat_id 47 | manager = self.renderings_manager 48 | 49 | if self._base_only: 50 | example_ids = get_example_ids(cat_id) 51 | n_views = manager.get_view_params()['n_views'] 52 | 53 | for example_id in example_ids: 54 | for i in range(n_views): 55 | yield manager.get_rendering_path(cat_id, example_id, i) 56 | else: 57 | cat_dir = manager.get_cat_dir(cat_id) 58 | for r, _, fns in os.walk(cat_dir): 59 | for fn in fns: 60 | yield os.path.join(r, fn) 61 | 62 | def n_src_paths(self): 63 | manager = self.renderings_manager 64 | cat_id = self.cat_id 65 | if self._base_only: 66 | return len(get_example_ids(self.cat_id)) \ 67 | * manager.get_view_params()['n_views'] 68 | else: 69 | cat_dir = manager.get_cat_dir(cat_id) 70 | c = 0 71 | for _, __, fns in os.walk(cat_dir): 72 | c += len(fns) 73 | return c 74 | 75 | def _do_all(self, fn, show_progress): 76 | manager = self._renderings_manager 77 | cat_id = self._cat_id 78 | archive = self.archive 79 | print('Getting names...') 80 | names = set(self.archive.get_names()) 81 | print('%d files found' % len(names)) 82 | cat_dir = manager.get_cat_dir(cat_id) 83 | base_folder, rest = os.path.split(cat_dir) 84 | assert(rest == cat_id) 85 | nrd = len(base_folder) + 1 86 | 87 | n = self.n_src_paths() 88 | bar = get_bar(show_progress, max=n) 89 | for path in self.get_src_paths(): 90 | bar.next() 91 | name = path[nrd:] 92 | if name not in names: 93 | fn(archive, path, name) 94 | bar.finish() 95 | 96 | def unpack(self): 97 | manager = self._renderings_manager 98 | cat_id = self._cat_id 99 | archive = self.archive 100 | 101 | cat_dir = manager.get_cat_dir(cat_id) 102 | base_folder, rest = os.path.split(cat_dir) 103 | assert(rest == cat_id) 104 | print('Extracting files at %s' % archive.path) 105 | archive.extractall(base_folder) 106 | 107 | def check(self): 108 | def fn(archive, src_path, name): 109 | raise RuntimeError('%s not present' % name) 110 | self._do_all(fn, False) 111 | print('Check complete: %s' % str(self.archive)) 112 | 113 | def add(self): 114 | def fn(archive, src_path, name): 115 | archive.add(src_path, name) 116 | self._do_all(fn, True) 117 | print('Add complete: %s' % str(self.archive)) 118 | 119 | 120 | def get_archive_path(renderings_manager, cat, base_only=True, format='zip'): 121 | cat_id = to_cat_id(cat) 122 | fn = ('%s-base.%s' if base_only else '%s.%s') % (cat_id, format) 123 | return os.path.join(renderings_manager.root_dir, 'renderings', fn) 124 | 125 | 126 | def get_archive( 127 | renderings_manager, cat, base_only=True, format='zip', mode='r'): 128 | return archive.get_archive( 129 | get_archive_path(renderings_manager, cat, base_only, format), mode) 
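# Sketch of the intended workflow (it mirrors scripts/archive.py): open the
# category archive in append mode and add whatever renderings are not already
# present, or open read-only and run a consistency check. 'chair' is an
# illustrative category descriptor.
#
#     manager = get_archive_manager(renderings_manager, 'chair', mode='a')
#     with manager.archive:
#         manager.add()        # or manager.check() after opening with mode='r'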
130 | 131 | 132 | def get_archive_manager( 133 | renderings_manager, cat, base_only=True, format='zip', mode='r'): 134 | archive = get_archive(renderings_manager, cat, base_only, format, mode) 135 | return ArchiveManager( 136 | renderings_manager, cat, archive, base_only) 137 | -------------------------------------------------------------------------------- /core/renderings/path.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | 8 | renderings_format = 'r%03d%s.png' 9 | 10 | 11 | def get_rendering_filename(view_index, suffix=''): 12 | return renderings_format % (view_index, suffix) 13 | 14 | 15 | def get_rendering_subpath(cat_id, example_id, view_index): 16 | return os.path.join( 17 | get_renderings_subdir(cat_id, example_id), 18 | get_rendering_filename(view_index)) 19 | 20 | 21 | def get_renderings_subdir(cat_id, example_id=None): 22 | if example_id is None: 23 | return cat_id 24 | else: 25 | return os.path.join(cat_id, example_id) 26 | -------------------------------------------------------------------------------- /core/renderings/renderings_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import yaml 6 | import json 7 | import os 8 | import numpy as np 9 | from ..fixed_objs import is_fixed_obj, get_fixed_obj_path 10 | from .path import get_renderings_subdir 11 | 12 | 13 | def _get_core_renderings_dir(): 14 | from ..path import get_data_dir 15 | return get_data_dir('renderings') 16 | 17 | 18 | renderings_dir = _get_core_renderings_dir() 19 | 20 | 21 | def get_default_manager_dir(manager_id): 22 | return os.path.join(renderings_dir, manager_id) 23 | 24 | 25 | def has_renderings(folder, n_renderings, files_per_rendering=4): 26 | return os.path.isdir(folder) and \ 27 | len(os.listdir(folder)) == files_per_rendering*n_renderings 28 | 29 | 30 | class RenderingsManager(object): 31 | @property 32 | def view_manager(self): 33 | raise NotImplementedError('Abstract property') 34 | 35 | def get_image(self, cat_id, example_id, view_index): 36 | raise NotImplementedError('Abstract method') 37 | 38 | 39 | class RenderableManager(RenderingsManager): 40 | def get_obj_path(self, cat_id, example_id): 41 | raise NotImplementedError('Abstract method') 42 | 43 | def get_rendering_path(self, cat_id, example_id, view_index): 44 | raise NotImplementedError('Abstract method') 45 | 46 | def get_view_params(self): 47 | raise NotImplementedError('Abstract method') 48 | 49 | def get_image_params(self): 50 | raise NotImplementedError('Abstract method') 51 | 52 | def needs_rendering_keys(self): 53 | raise NotImplementedError('Abstract method') 54 | 55 | def get_image(self, cat_id, example_id, view_index): 56 | from PIL import Image 57 | return Image.open( 58 | self.get_rendering_path(cat_id, example_id, view_index)) 59 | 60 | 61 | class RenderableManagerBase(RenderableManager): 62 | def __init__(self, root_dir, view_manager, image_shape): 63 | self._root_dir = root_dir 64 | self._view_manager = view_manager 65 | self._view_params = view_manager.get_view_params() 66 | self._image_shape = image_shape 67 | self._renderings_dir = os.path.join(root_dir, 'renderings') 68 | 69 | @property 70 | def view_manager(self): 71 | return self._view_manager 72 | 73 | def 
get_view_params(self): 74 | return self._view_params.copy() 75 | 76 | def needs_rendering_keys(self, cat_ids=None): 77 | n_views = self.get_view_params()['n_views'] 78 | files_per_rendering = 4 79 | return ( 80 | k for k in self.view_manager.keys(cat_ids) 81 | if not has_renderings( 82 | self.get_renderings_dir(*k), n_views, files_per_rendering)) 83 | 84 | @property 85 | def root_dir(self): 86 | return self._root_dir 87 | 88 | def _path(self, *subpaths): 89 | return os.path.join(self.root_dir, *subpaths) 90 | 91 | @property 92 | def _image_params_path(self): 93 | return self._path('image_params.yaml') 94 | 95 | def get_image_params(self): 96 | path = self._image_params_path 97 | if not os.path.isfile(path): 98 | raise IOError('No image_params saved at %s' % path) 99 | with open(path, 'r') as fp: 100 | return yaml.load(fp) 101 | 102 | def set_image_params(self, **image_params): 103 | path = self._image_params_path 104 | folder = os.path.dirname(path) 105 | if not os.path.isdir(folder): 106 | os.makedirs(folder) 107 | 108 | if os.path.isfile(path): 109 | self.check_image_params(**image_params) 110 | else: 111 | with open(path, 'w') as fp: 112 | yaml.dump(image_params, fp, default_flow_style=False) 113 | 114 | def check_image_params(self, **image_params): 115 | saved = self.get_image_params() 116 | if saved != image_params: 117 | raise ValueError( 118 | 'image_params not consistent with saved image_params\n' 119 | 'saved_image_params:\n' 120 | '%s\n' 121 | 'passed image_params:\n' 122 | '%s\n' % (saved, image_params)) 123 | 124 | def get_obj_path(self, cat_id, example_id): 125 | from shapenet.core.path import get_extracted_core_dir 126 | if is_fixed_obj(cat_id, example_id): 127 | return get_fixed_obj_path(cat_id, example_id) 128 | else: 129 | return os.path.join( 130 | get_extracted_core_dir(), cat_id, example_id, 'model.obj') 131 | 132 | def get_renderings_dir(self, cat_id, example_id): 133 | return os.path.join( 134 | self._renderings_dir, get_renderings_subdir(cat_id, example_id)) 135 | 136 | def get_rendering_path(self, cat_id, example_id, view_index): 137 | from . 
import path 138 | subpath = path.get_rendering_subpath(cat_id, example_id, view_index) 139 | return os.path.join(self._renderings_dir, subpath) 140 | 141 | def get_cat_dir(self, cat_id): 142 | return os.path.join( 143 | self._renderings_dir, get_renderings_subdir(cat_id)) 144 | 145 | def render_all( 146 | self, cat_ids=None, verbose=True, blender_path='blender'): 147 | import subprocess 148 | from progress.bar import IncrementalBar 149 | import tempfile 150 | from .path import renderings_format 151 | from ..objs import try_extract_models 152 | for cat_id in cat_ids: 153 | try_extract_models(cat_id) 154 | _FNULL = open(os.devnull, 'w') 155 | call_kwargs = dict() 156 | if not verbose: 157 | call_kwargs['stdout'] = _FNULL 158 | call_kwargs['stderr'] = subprocess.STDOUT 159 | 160 | root_dir = os.path.realpath(os.path.dirname(__file__)) 161 | script_path = os.path.join(root_dir, 'scripts', 'blender_render.py') 162 | 163 | render_params_path = None 164 | camera_positions_path = None 165 | 166 | def clean_up(): 167 | for path in (render_params_path, camera_positions_path): 168 | if path is not None and os.path.isfile(path): 169 | os.remove(path) 170 | 171 | render_params_fp, render_params_path = tempfile.mkstemp(suffix='.json') 172 | try: 173 | view_params = self.get_view_params() 174 | view_params.update(**self.get_image_params()) 175 | os.write(render_params_fp, json.dumps(view_params)) 176 | os.close(render_params_fp) 177 | 178 | args = [ 179 | blender_path, '--background', 180 | '--python', script_path, '--', 181 | '--render_params', render_params_path] 182 | 183 | keys = tuple(self.needs_rendering_keys(cat_ids)) 184 | n = len(keys) 185 | if n == 0: 186 | print('No keys to render.') 187 | return 188 | print('Rendering %d examples' % n) 189 | bar = IncrementalBar(max=n) 190 | for cat_id, example_id in keys: 191 | bar.next() 192 | 193 | camera_positions_fp, camera_positions_path = tempfile.mkstemp( 194 | suffix='.npy') 195 | os.close(camera_positions_fp) 196 | np.save( 197 | camera_positions_path, 198 | self.view_manager.get_camera_positions(cat_id, example_id)) 199 | 200 | out_dir = self.get_renderings_dir(cat_id, example_id) 201 | proc = subprocess.Popen( 202 | args + [ 203 | '--obj', self.get_obj_path(cat_id, example_id), 204 | '--out_dir', out_dir, 205 | '--filename_format', renderings_format, 206 | '--camera_positions', camera_positions_path, 207 | ], 208 | **call_kwargs) 209 | try: 210 | proc.wait() 211 | except KeyboardInterrupt: 212 | proc.kill() 213 | raise 214 | if os.path.isfile(camera_positions_path): 215 | os.remove(camera_positions_path) 216 | bar.finish() 217 | except (Exception, KeyboardInterrupt): 218 | clean_up() 219 | raise 220 | clean_up() 221 | 222 | 223 | def get_base_manager(turntable=False, n_views=24, format='h5', dim=128): 224 | from ..views.base import get_base_manager, get_base_id 225 | kwargs = dict(turntable=turntable, n_views=n_views) 226 | manager_id = '%s-%03d' % (get_base_id(**kwargs), dim) 227 | manager = RenderableManagerBase( 228 | get_default_manager_dir(manager_id), get_base_manager(**kwargs), 229 | (dim,)*2) 230 | manager.set_image_params(shape=(dim,)*2) 231 | return manager 232 | -------------------------------------------------------------------------------- /core/renderings/scripts/archive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from absl import flags, app 7 | 
FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_integer('dim', default=128, help='dimension of square renderings') 10 | flags.DEFINE_bool('turntable', default=False, help='if True, renderings') 11 | flags.DEFINE_integer('n_renderings', default=24, help='number of renderings') 12 | flags.DEFINE_string('path', default=None, help='(optional) path to save') 13 | flags.DEFINE_list( 14 | 'cat', default=None, help='cat(s), either ids of descriptors') 15 | flags.DEFINE_boolean( 16 | 'full', default=False, help='compress all images (True), or just base') 17 | flags.DEFINE_string( 18 | 'format', default='zip', help='compression format, one of ["zip", "tar"]') 19 | flags.DEFINE_boolean('check', default=False, help='If True, just runs a check') 20 | 21 | 22 | def main(_): 23 | from shapenet.core.renderings.archive_manager import get_archive_manager 24 | from shapenet.core.renderings.renderings_manager import get_base_manager 25 | 26 | rend_manager = get_base_manager( 27 | dim=FLAGS.dim, turntable=FLAGS.turntable, 28 | n_views=FLAGS.n_renderings) 29 | if FLAGS.cat is None: 30 | from shapenet.r2n2 import get_cat_ids 31 | cats = get_cat_ids() 32 | else: 33 | cats = FLAGS.cat 34 | format = FLAGS.format 35 | mode = 'r' if FLAGS.check else 'a' 36 | for cat in cats: 37 | archive_manager = get_archive_manager( 38 | rend_manager, cat, base_only=not FLAGS.full, 39 | format=format, mode=mode) 40 | with archive_manager.archive: 41 | if FLAGS.check: 42 | archive_manager.check() 43 | else: 44 | archive_manager.add() 45 | 46 | 47 | app.run(main) 48 | -------------------------------------------------------------------------------- /core/renderings/scripts/blender_render.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple script that uses blender to render views of a single object by 3 | rotation the camera around it. 4 | 5 | Also produces depth map at the same time. 6 | 7 | Example: 8 | blender --background --python blender_render.py -- \ 9 | --manager_dir=/path/to/manager 10 | 11 | Original source: 12 | https://github.com/panmari/stanford-shapenet-renderer 13 | """ 14 | import os 15 | import bpy 16 | 17 | 18 | def setup(depth_scale): 19 | # Set up rendering of depth map: 20 | bpy.context.scene.use_nodes = True 21 | tree = bpy.context.scene.node_tree 22 | links = tree.links 23 | 24 | # Add passes for additionally dumping albed and normals. 25 | bpy.context.scene.render.layers["RenderLayer"].use_pass_normal = True 26 | bpy.context.scene.render.layers["RenderLayer"].use_pass_color = True 27 | 28 | # clear default nodes 29 | for n in tree.nodes: 30 | tree.nodes.remove(n) 31 | 32 | # create input render layer node 33 | rl = tree.nodes.new('CompositorNodeRLayers') 34 | 35 | map = tree.nodes.new(type="CompositorNodeMapValue") 36 | # Size is chosen kind of arbitrarily, try out until you're satisfied with 37 | # resulting depth map. 38 | map.offset = [-0.7] 39 | map.size = [depth_scale] 40 | map.use_min = True 41 | map.min = [0] 42 | map.use_max = True 43 | map.max = [255] 44 | try: 45 | links.new(rl.outputs['Z'], map.inputs[0]) 46 | except KeyError: 47 | # some versions of blender don't like this? 
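        # (presumably the render layer's depth socket goes by a different name,
        # e.g. 'Depth' instead of 'Z', on some Blender releases; skipping the
        # link just means the depth output is not fed real depth values there)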
48 | pass 49 | 50 | invert = tree.nodes.new(type="CompositorNodeInvert") 51 | links.new(map.outputs[0], invert.inputs[1]) 52 | 53 | # create a file output node and set the path 54 | depthFileOutput = tree.nodes.new(type="CompositorNodeOutputFile") 55 | depthFileOutput.label = 'Depth Output' 56 | links.new(invert.outputs[0], depthFileOutput.inputs[0]) 57 | 58 | scale_normal = tree.nodes.new(type="CompositorNodeMixRGB") 59 | scale_normal.blend_type = 'MULTIPLY' 60 | # scale_normal.use_alpha = True 61 | scale_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 1) 62 | links.new(rl.outputs['Normal'], scale_normal.inputs[1]) 63 | 64 | bias_normal = tree.nodes.new(type="CompositorNodeMixRGB") 65 | bias_normal.blend_type = 'ADD' 66 | # bias_normal.use_alpha = True 67 | bias_normal.inputs[2].default_value = (0.5, 0.5, 0.5, 0) 68 | links.new(scale_normal.outputs[0], bias_normal.inputs[1]) 69 | 70 | normalFileOutput = tree.nodes.new(type="CompositorNodeOutputFile") 71 | normalFileOutput.label = 'Normal Output' 72 | links.new(bias_normal.outputs[0], normalFileOutput.inputs[0]) 73 | 74 | albedoFileOutput = tree.nodes.new(type="CompositorNodeOutputFile") 75 | albedoFileOutput.label = 'Albedo Output' 76 | # For some reason, 77 | links.new(rl.outputs['Color'], albedoFileOutput.inputs[0]) 78 | 79 | # Delete default cube 80 | bpy.data.objects['Cube'].select = True 81 | bpy.ops.object.delete() 82 | 83 | # Make light just directional, disable shadows. 84 | lamp = bpy.data.lamps['Lamp'] 85 | lamp.type = 'SUN' 86 | lamp.shadow_method = 'NOSHADOW' 87 | # Possibly disable specular shading: 88 | lamp.use_specular = False 89 | 90 | # Add another light source so stuff facing away from light is not 91 | # completely dark 92 | bpy.ops.object.lamp_add(type='SUN') 93 | lamp2 = bpy.data.lamps['Sun'] 94 | lamp2.shadow_method = 'NOSHADOW' 95 | lamp2.use_specular = False 96 | lamp2.energy = 0.015 97 | sun = bpy.data.objects['Sun'] 98 | sun.rotation_euler = bpy.data.objects['Lamp'].rotation_euler 99 | sun.rotation_euler[0] += 180 100 | 101 | invariants = set(bpy.context.scene.objects) 102 | return invariants, depthFileOutput, normalFileOutput, albedoFileOutput 103 | 104 | 105 | def load_obj(path, scale, remove_doubles, edge_split, invariants): 106 | bpy.ops.import_scene.obj(filepath=path) 107 | if scale != 1: 108 | bpy.ops.transform.resize(value=(scale, scale, scale)) 109 | 110 | objs = [obj for obj in bpy.context.scene.objects if obj not in invariants] 111 | 112 | for object in objs: 113 | bpy.context.scene.objects.active = object 114 | if scale != 1: 115 | bpy.ops.object.transform_apply(scale=True) 116 | if remove_doubles: 117 | bpy.ops.object.mode_set(mode='EDIT') 118 | bpy.ops.mesh.remove_doubles() 119 | bpy.ops.object.mode_set(mode='OBJECT') 120 | if edge_split: 121 | bpy.ops.object.modifier_add(type='EDGE_SPLIT') 122 | bpy.context.object.modifiers["EdgeSplit"].split_angle = 1.32645 123 | bpy.ops.object.modifier_apply( 124 | apply_as='DATA', modifier="EdgeSplit") 125 | 126 | return objs 127 | 128 | 129 | def remove_obj(objs): 130 | for object in objs: 131 | object.select = True 132 | bpy.ops.object.delete() 133 | 134 | 135 | # def parent_obj_to_camera(b_camera): 136 | # origin = (0, 0, 0) 137 | # b_empty = bpy.data.objects.new("Empty", None) 138 | # b_empty.location = origin 139 | # b_camera.parent = b_empty # setup parenting 140 | # 141 | # scn = bpy.context.scene 142 | # scn.objects.link(b_empty) 143 | # scn.objects.active = b_empty 144 | # return b_empty 145 | 146 | 147 | def load_camera_positions(path): 148 | import 
numpy as np 149 | if path.endswith('.txt'): 150 | return np.loadtxt(path) 151 | elif path.endswith('.npy'): 152 | return np.load(path) 153 | else: 154 | raise IOError( 155 | 'Unrecognized extension for camera_positions %s' % path) 156 | 157 | 158 | def load_render_params(path): 159 | import json 160 | if path is None: 161 | return {} 162 | else: 163 | assert(isinstance(path, str)) 164 | assert(path.endswith('.json')) 165 | with open(path, 'r') as fp: 166 | return json.load(fp) 167 | 168 | 169 | def main(render_params, out_dir, filename_format, obj_path, camera_positions): 170 | render_params = load_render_params(render_params) 171 | if filename_format.endswith('.png'): 172 | filename_format = filename_format[:-4] 173 | 174 | assert(isinstance(obj_path, str)) 175 | assert(obj_path.endswith('.obj')) 176 | assert(os.path.isfile(obj_path)) 177 | 178 | camera_positions = load_camera_positions(camera_positions) 179 | 180 | shape = render_params.get('shape', [128, 128]) 181 | scale = render_params.get('scale', 1) 182 | remove_doubles = render_params.get('remove_doubles', False) 183 | edge_split = render_params.get('edge_split', False) 184 | f = render_params.get('f', 32 / 35) 185 | if f != 32 / 35: 186 | raise NotImplementedError( 187 | 'Only default focal length of 32 / 35 implemented') 188 | 189 | invariants, depthFileOutput, normalFileOutput, albedoFileOutput = setup( 190 | render_params.get('depth_scale', 1.4)) 191 | empty = bpy.data.objects.new("Empty", None) 192 | empty.location = (0, 0, 0) 193 | invariants.add(empty) 194 | 195 | scene = bpy.context.scene 196 | scene.render.resolution_x = shape[1] 197 | scene.render.resolution_y = shape[0] 198 | scene.render.resolution_percentage = 100 199 | scene.render.alpha_mode = 'TRANSPARENT' 200 | cam = scene.objects['Camera'] 201 | # b_empty = parent_obj_to_camera(cam) 202 | 203 | # invariants.add(b_empty) 204 | outputs = { 205 | 'depth': depthFileOutput, 206 | 'normal': normalFileOutput, 207 | 'albedo': albedoFileOutput 208 | } 209 | 210 | def set_camera(eye): 211 | cam.location = eye 212 | cam_constraint = cam.constraints.new(type='TRACK_TO') 213 | cam_constraint.track_axis = 'TRACK_NEGATIVE_Z' 214 | cam_constraint.up_axis = 'UP_Y' 215 | cam_constraint.target = empty 216 | 217 | 218 | eyes = camera_positions 219 | load_obj(obj_path, scale, remove_doubles, edge_split, invariants) 220 | # set output format to png 221 | scene.render.image_settings.file_format = 'PNG' 222 | 223 | for output_node in [ 224 | depthFileOutput, normalFileOutput, albedoFileOutput]: 225 | output_node.base_path = '' 226 | 227 | if not os.path.isdir(out_dir): 228 | os.makedirs(out_dir) 229 | 230 | for i, eye in enumerate(eyes): 231 | set_camera(eye) 232 | base_path = os.path.join(out_dir, filename_format % (i, '')) 233 | scene.render.filepath = base_path 234 | for k, v in outputs.items(): 235 | filename = filename_format % (i, k) 236 | v.file_slots[0].path = os.path.join(out_dir, filename) 237 | bpy.ops.render.render(write_still=True) # render still 238 | 239 | 240 | def get_args(): 241 | import argparse 242 | import sys 243 | parser = argparse.ArgumentParser( 244 | description='Renders given obj file by rotation a camera around it.') 245 | parser.add_argument( 246 | '--render_params', type=str, help='path to json render_params file') 247 | parser.add_argument('--out_dir', type=str, help='output directory') 248 | parser.add_argument( 249 | '--filename_format', type=str, default='r%03%s', 250 | help='output directory') 251 | parser.add_argument('--obj', type=str, help='path to 
obj file') 252 | parser.add_argument( 253 | '--camera_positions', type=str, 254 | help='path to camera_positions npy file') 255 | 256 | argv = sys.argv[sys.argv.index("--") + 1:] 257 | args = parser.parse_args(argv) 258 | 259 | return args 260 | 261 | 262 | args = get_args() 263 | main( 264 | args.render_params, args.out_dir, args.filename_format, args.obj, 265 | args.camera_positions) 266 | -------------------------------------------------------------------------------- /core/renderings/scripts/create_base_renderings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from absl import flags, app 7 | FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_integer('dim', default=128, help='dimension of square renderings') 10 | flags.DEFINE_bool( 11 | 'turntable', default=False, help='render regular angles (default random)') 12 | flags.DEFINE_integer('n_views', default=24, help='number of views to render') 13 | flags.DEFINE_bool( 14 | 'verbose', default=True, help='suppress blender output if False') 15 | flags.DEFINE_list( 16 | 'cat', default=None, 17 | help='category descriptions to render, ' 18 | 'comma separated, e.g. chair,sofa,plane') 19 | 20 | 21 | def main(_): 22 | from shapenet.core.renderings.renderings_manager import get_base_manager 23 | from shapenet.core import to_cat_id 24 | from shapenet.core.objs import try_extract_models 25 | cat_ids = [to_cat_id(c) for c in FLAGS.cat] 26 | for cat_id in cat_ids: 27 | try_extract_models(cat_id) 28 | 29 | manager = get_base_manager( 30 | dim=FLAGS.dim, 31 | turntable=FLAGS.turntable, 32 | n_views=FLAGS.n_views, 33 | ) 34 | manager.render_all(cat_ids=cat_ids, verbose=FLAGS.verbose) 35 | 36 | 37 | app.run(main) 38 | -------------------------------------------------------------------------------- /core/renderings/scripts/report.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from absl import flags, app 7 | FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_integer('dim', default=128, help='dimension of square renderings') 10 | flags.DEFINE_bool( 11 | 'turntable', default=False, help='render regular angles (default random)') 12 | flags.DEFINE_integer('n_views', default=24, help='number of views') 13 | flags.DEFINE_list( 14 | 'cat', default=None, 15 | help='category descriptions to render, ' 16 | 'comma separated, e.g. 
chair,sofa,plane') 17 | 18 | 19 | def main(_): 20 | from shapenet.core.renderings.renderings_manager import get_base_manager 21 | from shapenet.core import cat_desc_to_id 22 | from shapenet.core import cat_id_to_desc 23 | cat = FLAGS.cat 24 | if cat is None or len(cat) == 0: 25 | from shapenet.r2n2 import get_cat_ids 26 | cat_ids = get_cat_ids() 27 | else: 28 | cat_ids = [cat_desc_to_id(c) for c in FLAGS.cat] 29 | 30 | print('Required renderings:') 31 | for cat_id in cat_ids: 32 | manager = get_base_manager( 33 | dim=FLAGS.dim, 34 | turntable=FLAGS.turntable, 35 | n_views=FLAGS.n_views, 36 | ) 37 | n = len(tuple(manager.needs_rendering_keys(cat_ids=[cat_id]))) 38 | n_total = len(tuple(manager.view_manager.keys(cat_ids=[cat_id]))) 39 | cat = cat_id_to_desc(cat_id) 40 | print('%s: %s\n %d / %d' % (cat_id, cat, n, n_total)) 41 | 42 | 43 | app.run(main) 44 | -------------------------------------------------------------------------------- /core/scripts/create_ids.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from progress.bar import IncrementalBar 7 | from shapenet.core.path import get_cat_ids 8 | from shapenet.core.path import get_example_ids_from_zip 9 | from shapenet.core.path import get_example_ids 10 | from shapenet.core import create_ids 11 | 12 | 13 | def check(cat_ids): 14 | print('Checking...') 15 | bar = IncrementalBar(max=len(cat_ids)) 16 | for cat_id in cat_ids: 17 | actual = get_example_ids(cat_id) 18 | original = get_example_ids_from_zip(cat_id) 19 | if tuple(actual) != tuple(original): 20 | raise IOError('ids for cat_id "%s" not consistent' % cat_id) 21 | bar.next() 22 | bar.finish() 23 | print('All ids consistent!') 24 | 25 | 26 | def main(cat_ids): 27 | print('Creating example_ids') 28 | create_ids(cat_ids) 29 | check(cat_ids) 30 | 31 | 32 | main(get_cat_ids()) 33 | -------------------------------------------------------------------------------- /core/views/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module for creating/loading random or turntable views for experiments. 3 | 4 | On their own, views have little meaning. However, they can be used consistently 5 | between renderings and frustrum_voxels. They contain rendering parameters 6 | shared across the dataset/view manager and camera positions (potentially unique 7 | to each model/view). 
8 | """ 9 | -------------------------------------------------------------------------------- /core/views/archive.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import zipfile 7 | import numpy as np 8 | import yaml 9 | from .manager import ViewManager 10 | 11 | 12 | class ArchiveViewManager(ViewManager): 13 | def __init__(self, archive): 14 | self.archive = archive 15 | 16 | def view_params_subpath(self): 17 | return 'view_params.yaml' 18 | 19 | def camera_positions_subpath(self, cat_id, example_id): 20 | return os.path.join('camera_positions', cat_id, '%s.txt' % example_id) 21 | 22 | def open(self, subpath): 23 | raise NotImplementedError('Abstract method') 24 | 25 | def list_contents(self): 26 | raise NotImplementedError('Abstract method') 27 | 28 | def get_view_params(self): 29 | return yaml.load(self.open(self.view_params_subpath())) 30 | 31 | def get_camera_positions(self, cat_id, example_id): 32 | return np.loadtxt( 33 | self.open(self.camera_positions_subpath(cat_id, example_id))) 34 | 35 | def get_example_ids(self, cat_id): 36 | return [ 37 | k.split('/')[2] for k in self.list_contents() 38 | if k.startswith('camera_positions/%s/' % cat_id)] 39 | 40 | def get_cat_ids(self): 41 | split_contents = (k.split('/') for k in self.list_contents()) 42 | split_contents = (k for k in split_contents if len(k) == 3) 43 | return sorted(set(k[1] for k in split_contents)) 44 | 45 | 46 | class ZipViewManager(ArchiveViewManager): 47 | def open(self, subpath): 48 | return self.archive.open(subpath) 49 | 50 | def list_contents(self): 51 | return self.archive.namelist() 52 | 53 | 54 | def get_base_zip_path(turntable=False, n_views=24): 55 | import os 56 | from ..path import get_data_dir 57 | from .base import get_base_id 58 | return os.path.join( 59 | get_data_dir('views'), '%s.zip' % get_base_id(turntable, n_views)) 60 | 61 | 62 | def get_base_zip_manager(turntable=False, n_views=24): 63 | path = get_base_zip_path(turntable=turntable, n_views=n_views) 64 | if not os.path.isfile(path): 65 | from .txt import get_base_txt_manager 66 | import shutil 67 | txt = get_base_txt_manager(turntable=turntable, n_views=n_views) 68 | shutil.make_archive(path[:-4], 'zip', txt.root_folder) 69 | return ZipViewManager(zipfile.ZipFile(path, 'r')) 70 | -------------------------------------------------------------------------------- /core/views/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | def get_base_id(turntable=False, n_views=24): 7 | return '%s-%03d' % ( 8 | 'turntable' if turntable else 'rand', n_views) 9 | 10 | 11 | def get_base_manager(turntable=False, n_views=24, format='h5'): 12 | if format == 'zip': 13 | from .archive import get_base_zip_manager 14 | fn = get_base_zip_manager 15 | elif format == 'txt': 16 | from .txt import get_base_txt_manager 17 | fn = get_base_txt_manager 18 | elif format == 'h5': 19 | from .h5 import get_base_h5_manager 20 | fn = get_base_h5_manager 21 | elif format == 'lazy': 22 | from .lazy import get_base_lazy_manager 23 | fn = get_base_lazy_manager 24 | else: 25 | raise ValueError('Unrecognized format "%s"' % format) 26 | return fn(turntable=turntable, n_views=n_views) 27 | 
-------------------------------------------------------------------------------- /core/views/h5.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from .manager import WritableViewManager 7 | from .. import get_example_ids 8 | 9 | 10 | class H5ViewManager(WritableViewManager): 11 | def __init__(self, base_group): 12 | self._base_group = base_group 13 | self._indices = {} 14 | 15 | def get_example_indices(self, cat_id): 16 | if cat_id not in self._indices: 17 | example_ids = self.get_example_ids(cat_id) 18 | indices = {k: i for i, k in enumerate(example_ids)} 19 | self._indices[cat_id] = indices 20 | return self._indices[cat_id] 21 | 22 | def _get_view_params(self): 23 | return self._base_group.require_group('view_params').attrs 24 | 25 | def get_view_params(self): 26 | return dict(**self._get_view_params()) 27 | 28 | def get_camera_positions(self, cat_id, example_id): 29 | return self.get_camera_positions_from_index( 30 | cat_id, self.get_example_indices(cat_id)[example_id]) 31 | 32 | def get_camera_positions_from_index(self, cat_id, example_index): 33 | return np.array(self.camera_positions_group[cat_id][example_index]) 34 | 35 | @property 36 | def camera_positions_group(self): 37 | return self._base_group.require_group('camera_positions') 38 | 39 | def set_view_params(self, **params): 40 | attrs = self._get_view_params() 41 | for k, v in params.items(): 42 | attrs[k] = v 43 | 44 | def set_camera_positions(self, cat_id, example_id, value): 45 | example_indices = self.get_example_indices(cat_id) 46 | n = len(example_indices) 47 | example_index = example_indices[example_id] 48 | dataset = self.camera_positions_group.require_dataset( 49 | cat_id, shape=(n, self.get_view_params()['n_views'], 3), 50 | dtype=np.float32, exact=True) 51 | dataset[example_index] = value 52 | 53 | def get_example_ids(self, cat_id): 54 | return get_example_ids(cat_id) 55 | 56 | def get_cat_ids(self): 57 | return self._base_group.require_group('camera_positions').keys() 58 | 59 | 60 | def get_base_h5_manager_path(turntable=False, n_views=24): 61 | import os 62 | from ..path import get_data_dir 63 | from .base import get_base_id 64 | return os.path.join( 65 | get_data_dir('views'), '%s.h5' % get_base_id(turntable, n_views)) 66 | 67 | 68 | def get_base_h5_manager(turntable=False, n_views=24): 69 | import os 70 | import h5py 71 | from .lazy import get_base_lazy_manager 72 | data_path = get_base_h5_manager_path( 73 | turntable, n_views) 74 | if os.path.isfile(data_path): 75 | fp = h5py.File(data_path, mode='r') 76 | return H5ViewManager(fp) 77 | else: 78 | try: 79 | with h5py.File(data_path, 'w') as fp: 80 | manager = H5ViewManager(fp) 81 | manager.copy_from(get_base_lazy_manager(turntable, n_views)) 82 | return H5ViewManager(h5py.File(data_path, mode='r')) 83 | except (Exception, KeyboardInterrupt): 84 | os.remove(data_path) 85 | raise 86 | -------------------------------------------------------------------------------- /core/views/lazy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .manager import ViewManager 6 | 7 | 8 | class LazyViewManager(ViewManager): 9 | def __init__( 10 | self, view_params, cat_ids, example_ids_fn, camera_pos_fn): 11 | self._view_params = 
view_params 12 | self._cat_ids = tuple(cat_ids) 13 | self._example_ids_fn = example_ids_fn 14 | self._camera_pos_fn = camera_pos_fn 15 | 16 | def get_view_params(self): 17 | return self._view_params.copy() 18 | 19 | def get_camera_positions(self, cat_id, example_id): 20 | return self._camera_pos_fn(cat_id, example_id) 21 | 22 | def get_example_ids(self, cat_id): 23 | return self._example_ids_fn(cat_id) 24 | 25 | def get_cat_ids(self): 26 | return self._cat_ids 27 | 28 | 29 | def get_base_lazy_manager(turntable=False, n_views=24, seed=0): 30 | import numpy as np 31 | from ...r2n2 import get_cat_ids 32 | from .. import get_example_ids 33 | from .base import get_base_id 34 | cat_ids = get_cat_ids() 35 | dist = 1.166 # sqrt(1 + 0.6**2) - looked good in experiments 36 | 37 | def polar_to_cartesian(dist, theta, phi): 38 | z = np.cos(phi) 39 | s = np.sin(phi) 40 | x = s * np.cos(theta) 41 | y = s * np.sin(theta) 42 | return np.stack((x, y, z), axis=-1) * dist 43 | 44 | if turntable: 45 | def get_camera_pos(): 46 | theta = np.deg2rad(np.linspace(0, 360, n_views+1)[:-1]) 47 | phi = np.deg2rad(60.0) * np.ones_like(theta) 48 | return polar_to_cartesian(dist, theta, phi).astype(np.float32) 49 | else: 50 | np.random.seed(seed) 51 | 52 | def get_camera_pos(): 53 | size = (n_views,) 54 | theta = np.deg2rad(np.random.uniform(0, 360, size=size)) 55 | phi = np.deg2rad(90 - np.random.uniform(25, 30, size=size)) 56 | return polar_to_cartesian(dist, theta, phi).astype(np.float32) 57 | 58 | view_params = dict( 59 | depth_scale=1.4, 60 | scale=1, 61 | n_views=24, 62 | f=32 / 35, 63 | view_id=get_base_id(turntable=turntable, n_views=n_views) 64 | ) 65 | 66 | return LazyViewManager( 67 | view_params, cat_ids, get_example_ids, 68 | lambda cat_id, example_id: get_camera_pos()) 69 | -------------------------------------------------------------------------------- /core/views/manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | class ViewManager(object): 7 | @property 8 | def view_id(self): 9 | return self.get_view_params()['view_id'] 10 | 11 | def get_view_params(self): 12 | raise NotImplementedError('Abstract method') 13 | 14 | def get_camera_positions(self, cat_id, example_id): 15 | raise NotImplementedError('Abstract method') 16 | 17 | def get_example_ids(self, cat_id): 18 | raise NotImplementedError('Abstract method') 19 | 20 | def get_cat_ids(self): 21 | raise NotImplementedError('Abstract method') 22 | 23 | def keys(self, cat_ids=None): 24 | if cat_ids is None: 25 | cat_ids = self.get_cat_ids() 26 | for cat_id in cat_ids: 27 | for example_id in self.get_example_ids(cat_id): 28 | yield (cat_id, example_id) 29 | 30 | def to_dict(self): 31 | raise NotImplementedError('Abstract method') 32 | 33 | 34 | class WritableViewManager(ViewManager): 35 | def set_view_params(self, **params): 36 | raise NotImplementedError('Abstract method') 37 | 38 | def set_camera_positions(self, cat_id, example_id, positions): 39 | raise NotImplementedError('Abstract method') 40 | 41 | def copy_from(self, src_manager): 42 | self.set_view_params(**src_manager.get_view_params()) 43 | 44 | for cat_id in src_manager.get_cat_ids(): 45 | for example_id in src_manager.get_example_ids(cat_id): 46 | positions = src_manager.get_camera_positions( 47 | cat_id, example_id) 48 | self.set_camera_positions(cat_id, example_id, positions) 49 | 
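Because `copy_from` only relies on the getters above plus the two setters, adding a new storage backend amounts to implementing that small surface. A toy, purely illustrative in-memory backend (not part of this repository) sketching the contract:

```python
from shapenet.core.views.manager import WritableViewManager


class DictViewManager(WritableViewManager):
    """Illustrative in-memory backend; data lives in plain dicts."""

    def __init__(self):
        self._params = {}
        self._positions = {}  # {(cat_id, example_id): (n_views, 3) array}

    def get_view_params(self):
        return dict(self._params)

    def set_view_params(self, **params):
        self._params.update(params)

    def get_camera_positions(self, cat_id, example_id):
        return self._positions[(cat_id, example_id)]

    def set_camera_positions(self, cat_id, example_id, positions):
        self._positions[(cat_id, example_id)] = positions

    def get_cat_ids(self):
        return sorted(set(k[0] for k in self._positions))

    def get_example_ids(self, cat_id):
        return sorted(k[1] for k in self._positions if k[0] == cat_id)


# e.g. DictViewManager().copy_from(existing_manager) pulls all view params and
# camera positions of `existing_manager` into memory via the methods above.
```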
-------------------------------------------------------------------------------- /core/views/scripts/check_consistent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from shapenet.core.views.txt import get_base_txt_manager 8 | from shapenet.core.views.h5 import get_base_h5_manager 9 | from shapenet.core.views.archive import get_base_zip_manager 10 | 11 | txt = get_base_txt_manager() 12 | h5 = get_base_h5_manager() 13 | zi = get_base_zip_manager() 14 | 15 | cat_id = h5.get_cat_ids()[2] 16 | example_id = h5.get_example_ids(cat_id)[10] 17 | 18 | hc = h5.get_camera_positions(cat_id, example_id) 19 | tc = txt.get_camera_positions(cat_id, example_id) 20 | zc = zi.get_camera_positions(cat_id, example_id) 21 | 22 | print(hc) 23 | print(tc) 24 | print(zc) 25 | print(hc - tc) 26 | print(hc - zc) 27 | 28 | print('Checking params...') 29 | hp = h5.get_view_params() 30 | tp = txt.get_view_params() 31 | zp = zi.get_view_params() 32 | 33 | for k, v in hp.items(): 34 | assert(tp[k] == v) 35 | assert(zp[k] == v) 36 | 37 | assert(len(hp) == len(tp)) 38 | assert(len(hp) == len(zp)) 39 | print('Consistent!') 40 | -------------------------------------------------------------------------------- /core/views/scripts/create_base_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | 7 | def create_h5_data(**kwargs): 8 | from shapenet.core.views.h5 import get_base_h5_manager 9 | get_base_h5_manager(**kwargs) 10 | 11 | 12 | def create_txt_data(**kwargs): 13 | from shapenet.core.views.txt import get_base_txt_manager 14 | get_base_txt_manager(**kwargs) 15 | 16 | 17 | def create_zip_data(**kwargs): 18 | from shapenet.core.views.archive import get_base_zip_manager 19 | get_base_zip_manager(**kwargs) 20 | 21 | 22 | # create_txt_data() 23 | create_h5_data() 24 | # create_zip_data() 25 | -------------------------------------------------------------------------------- /core/views/txt.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import yaml 7 | import numpy as np 8 | from .manager import WritableViewManager 9 | 10 | 11 | class TxtViewManager(WritableViewManager): 12 | def __init__(self, root_folder): 13 | self._root_folder = root_folder 14 | 15 | @property 16 | def root_folder(self): 17 | return self._root_folder 18 | 19 | def view_params_path(self): 20 | return os.path.join(self.root_folder, 'view_params.yaml') 21 | 22 | def has_view_params(self): 23 | return os.path.isfile(self.view_params_path()) 24 | 25 | def camera_positions_path(self, cat_id, example_id): 26 | return os.path.join( 27 | self.root_folder, 'camera_positions', cat_id, 28 | '%s.txt' % example_id) 29 | 30 | def get_view_params(self): 31 | path = self.view_params_path() 32 | with open(path, 'r') as fp: 33 | return yaml.load(fp) 34 | 35 | def get_camera_positions(self, cat_id, example_id): 36 | return np.loadtxt(self.camera_positions_path(cat_id, example_id)) 37 | 38 | def set_view_params(self, **params): 39 | with open(self.view_params_path(), 'w') as fp: 40 | yaml.dump(params, fp, 
default_flow_style=False) 41 | 42 | def set_camera_positions(self, cat_id, example_id, camera_pos): 43 | path = self.camera_positions_path(cat_id, example_id) 44 | dirname = os.path.dirname(path) 45 | if not os.path.isdir(dirname): 46 | os.makedirs(dirname) 47 | return np.savetxt(path, camera_pos) 48 | 49 | def get_cat_ids(self): 50 | return os.listdir(os.path.join(self.root_folder, 'camera_positions')) 51 | 52 | def get_example_ids(self, cat_id): 53 | return ( 54 | k[:-4] for k in os.listdir( 55 | os.path.join(self.root_folder, 'camera_positions', cat_id))) 56 | 57 | 58 | def get_base_txt_manager_dir(turntable=False, n_views=24): 59 | from .. import path 60 | from .base import get_base_id 61 | manager_id = get_base_id(turntable, n_views) 62 | return path.get_data_dir('views', manager_id) 63 | 64 | 65 | def get_base_txt_manager(turntable=False, n_views=24): 66 | from .txt import TxtViewManager 67 | root_folder = get_base_txt_manager_dir(turntable, n_views) 68 | manager = TxtViewManager(root_folder) 69 | if not manager.has_view_params(): 70 | from .lazy import get_base_lazy_manager 71 | manager.copy_from(get_base_lazy_manager(turntable, n_views)) 72 | return manager 73 | -------------------------------------------------------------------------------- /core/voxels/.gitignore: -------------------------------------------------------------------------------- 1 | _data/* 2 | *.binvox 3 | -------------------------------------------------------------------------------- /core/voxels/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .config import VoxelConfig, get_config 6 | 7 | __all__ = [ 8 | VoxelConfig, 9 | get_config 10 | ] 11 | -------------------------------------------------------------------------------- /core/voxels/concat_ds.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import dids.core as c 6 | from dids.file_io.hdf5 import Hdf5Resource 7 | 8 | 9 | class ConcatenatedDataset(c.DependentResource, c.Dataset): 10 | def __init__(self, path, starts_key='starts', values_key='values'): 11 | c.DependentResource.__init__(Hdf5Resource(path)) 12 | self._starts_key = starts_key 13 | self._values_key = values_key 14 | 15 | def _open_self(self): 16 | c.DependentResource._open_self(self) 17 | self._starts = self._parent._base[self._starts_key] 18 | self._values = self._parent._base[self._values_key] 19 | 20 | def __contains__(self, key): 21 | return isinstance(key, int) and 0 <= key < self._len 22 | 23 | def __len__(self): 24 | return self._len 25 | 26 | def __keys__(self): 27 | return range(self._len) 28 | 29 | def __getitem__(self, key): 30 | start, end = self._starts[key: key + 2] 31 | values = self._values[start:end] 32 | return values 33 | -------------------------------------------------------------------------------- /core/voxels/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import six 7 | from . import path 8 | from .. 
import get_example_ids 9 | 10 | 11 | class VoxelConfig(object): 12 | def __init__( 13 | self, voxel_dim=32, exact=True, dc=True, aw=True, c=False, 14 | v=False): 15 | self._voxel_dim = voxel_dim 16 | self._exact = exact 17 | self._dc = dc 18 | self._aw = aw 19 | self._c = c 20 | self._v = v 21 | self._voxel_id = get_voxel_id( 22 | voxel_dim, exact=exact, dc=dc, aw=aw, c=c, v=v) 23 | 24 | def filled(self, fill_alg=None): 25 | if fill_alg is None: 26 | return self 27 | else: 28 | from .filled import FilledVoxelConfig 29 | return FilledVoxelConfig(self, fill_alg) 30 | 31 | @staticmethod 32 | def from_id(voxel_id): 33 | kwargs = parse_voxel_id(voxel_id) 34 | fill_alg = kwargs.pop('fill_alg', None) 35 | return VoxelConfig(**kwargs).filled(fill_alg) 36 | 37 | @property 38 | def voxel_dim(self): 39 | return self._voxel_dim 40 | 41 | @property 42 | def exact(self): 43 | return self._exact 44 | 45 | @property 46 | def dc(self): 47 | return self._dc 48 | 49 | @property 50 | def aw(self): 51 | return self._aw 52 | 53 | @property 54 | def c(self): 55 | return self._c 56 | 57 | @property 58 | def v(self): 59 | return self._v 60 | 61 | @property 62 | def voxel_id(self): 63 | return self._voxel_id 64 | 65 | def get_binvox_subpath(self, cat_id, example_id): 66 | return path.get_binvox_subpath(cat_id, example_id) 67 | 68 | def get_binvox_path(self, cat_id, example_id): 69 | return os.path.join( 70 | self.root_dir, self.get_binvox_subpath(cat_id, example_id)) 71 | 72 | @property 73 | def root_dir(self): 74 | return path.get_binvox_dir(self.voxel_id) 75 | 76 | def create_voxel_data(self, cat_id, example_ids=None, overwrite=False): 77 | from progress.bar import IncrementalBar 78 | from util3d.voxel.convert import obj_to_binvox 79 | from .. import objs 80 | if example_ids is None: 81 | example_ids = get_example_ids(cat_id) 82 | 83 | kwargs = dict( 84 | voxel_dim=self.voxel_dim, 85 | exact=self.exact, 86 | dc=self.dc, 87 | aw=self.aw, 88 | c=self.c, 89 | v=self.v, 90 | overwrite_original=True) 91 | 92 | path_ds = objs.get_extracted_obj_path_dataset(cat_id) 93 | bvd = self.get_binvox_path(cat_id, None) 94 | if not os.path.isdir(bvd): 95 | os.makedirs(bvd) 96 | 97 | with path_ds: 98 | print('Creating binvox voxel data') 99 | bar = IncrementalBar(max=len(example_ids)) 100 | for example_id in example_ids: 101 | bar.next() 102 | binvox_path = self.get_binvox_path(cat_id, example_id) 103 | if overwrite or not os.path.isfile(binvox_path): 104 | obj_path = path_ds[example_id] 105 | obj_to_binvox(obj_path, binvox_path, **kwargs) 106 | bar.finish() 107 | 108 | def get_dataset(self, cat_id): 109 | from .datasets import get_dataset 110 | print( 111 | 'Warning: voxel_config.get_dataset deprecated. 
' 112 | 'Use voxels.datasets.get_dataset instead') 113 | return get_dataset(self, cat_id, key='zip') 114 | 115 | 116 | def get_base_config(voxel_dim): 117 | return VoxelConfig(voxel_dim) 118 | 119 | 120 | def get_alt_config(voxel_dim): 121 | return VoxelConfig( 122 | voxel_dim, exact=False, dc=False, aw=False, c=True, v=True) 123 | 124 | 125 | def get_config(voxel_dim, alt=False): 126 | if alt: 127 | return get_alt_config(voxel_dim) 128 | else: 129 | return get_base_config(voxel_dim) 130 | 131 | 132 | def get_voxel_id( 133 | voxel_dim=32, exact=True, dc=True, aw=True, c=False, v=False, 134 | fill=None): 135 | def bstr(b): 136 | return 't' if b else 'f' 137 | voxel_id = 'd%03d%s%s%s' % (voxel_dim, bstr(exact), bstr(dc), bstr(aw)) 138 | if c: 139 | voxel_id = '%sc' % voxel_id 140 | if v: 141 | voxel_id = '%sv' % voxel_id 142 | if fill is not None: 143 | voxel_id = '%s_%s' % (voxel_id, fill) 144 | return voxel_id 145 | 146 | 147 | default_voxel_id = get_voxel_id() 148 | 149 | 150 | def split_id(voxel_id): 151 | parts = voxel_id.split('_') 152 | if len(parts) == 2: 153 | voxel_id, fill_alg = parts 154 | else: 155 | assert(len(parts) == 1) 156 | fill_alg = None 157 | return voxel_id, fill_alg 158 | 159 | 160 | def parse_voxel_id(voxel_id): 161 | """Inverse function to `get_voxel_id`.""" 162 | voxel_id, fill_alg = split_id(voxel_id) 163 | if not is_valid_voxel_id(voxel_id): 164 | raise ValueError('voxel_id %s not valid.' % voxel_id) 165 | kwargs = dict( 166 | voxel_dim=int(voxel_id[1:4]), 167 | exact=voxel_id[4] == 't', 168 | dc=voxel_id[5] == 't', 169 | aw=voxel_id[6] == 't', 170 | ) 171 | rest = voxel_id[7:] 172 | if rest.startswith('c'): 173 | rest = rest[1:] 174 | kwargs['c'] = True 175 | if rest.startswith('v'): 176 | kwargs['v'] = True 177 | rest = rest[1:] 178 | assert(rest == '') 179 | if fill_alg is not None: 180 | kwargs['fill_alg'] = fill_alg 181 | return kwargs 182 | 183 | 184 | def is_valid_voxel_id(voxel_id): 185 | voxel_id, fill_alg = split_id(voxel_id) 186 | if fill_alg is not None: 187 | from . 
import filled 188 | try: 189 | filled.check_valid_fill_alg(fill_alg) 190 | except ValueError: 191 | return False 192 | 193 | nc = len(voxel_id) 194 | if not isinstance(voxel_id, six.string_types) or not (7 <= nc <= 9): 195 | return False 196 | if nc == 9 and voxel_id[-2:] != 'cv': 197 | return False 198 | if nc == 8 and voxel_id[-1] not in ('c', 'v'): 199 | return False 200 | try: 201 | int3 = int(voxel_id[1:4]) 202 | except ValueError: 203 | return False 204 | if int3 <= 0: 205 | return False 206 | return voxel_id[0] == 'd' and all(s in ('t', 'f') for s in voxel_id[4:7]) 207 | -------------------------------------------------------------------------------- /core/voxels/filled.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | from .config import VoxelConfig 7 | from util3d.voxel.binvox import DenseVoxels 8 | from util3d.voxel.manip import OrthographicFiller 9 | 10 | 11 | def filled_voxels(voxels_dense): 12 | from scipy.ndimage.morphology import binary_fill_holes 13 | _structure = None 14 | return binary_fill_holes(voxels_dense, _structure) 15 | 16 | 17 | class FillAlg(object): 18 | def __init__(self): 19 | raise RuntimeError('Not meant to be instantiated') 20 | 21 | BASE = 'filled' 22 | ORTHOGRAPHIC = 'orthographic' 23 | 24 | 25 | _fill_fns = { 26 | FillAlg.BASE: lambda dims: filled_voxels, 27 | FillAlg.ORTHOGRAPHIC: OrthographicFiller, 28 | } 29 | 30 | 31 | _algs = (FillAlg.BASE, FillAlg.ORTHOGRAPHIC) 32 | 33 | 34 | def check_valid_fill_alg(fill_alg): 35 | if fill_alg not in _algs: 36 | raise ValueError('Invalid fill_alg "%s": must be one of %s' % (fill_alg, _algs)) 37 | 38 | 39 | class FilledVoxelConfig(VoxelConfig): 40 | def __init__(self, base_config, fill_alg=FillAlg.BASE): 41 | check_valid_fill_alg(fill_alg) 42 | self._fill_alg = fill_alg 43 | self._fill_fn = _fill_fns[fill_alg] 44 | self._base_config = base_config 45 | self._voxel_id = '%s_%s' % (base_config.voxel_id, fill_alg) 46 | 47 | @property 48 | def voxel_dim(self): 49 | return self._base_config.voxel_dim 50 | 51 | @property 52 | def root_dir(self): 53 | from .
import path 54 | dir = os.path.join( 55 | path.data_dir, self._fill_alg, self._base_config.voxel_id) 56 | if not os.path.isdir(dir): 57 | os.makedirs(dir) 58 | return dir 59 | 60 | def get_fill_dense_fn(self, shape): 61 | return self._fill_fn(shape) 62 | 63 | def get_fill_voxels_fn(self, shape): 64 | dense_fn = self.get_fill_dense_fn(shape) 65 | 66 | def f(vox): 67 | return DenseVoxels( 68 | dense_fn(vox.dense_data()), scale=vox.scale, 69 | translate=vox.translate) 70 | return f 71 | 72 | def create_voxel_data(self, cat_id, example_ids=None, overwrite=False): 73 | from .datasets import get_manager 74 | src = None 75 | for shape_key in ('pad', 'jag', 'ind'): 76 | for compression in ('lzf', 'gzip', None): 77 | for key in ('brle', 'rle'): 78 | src = get_manager( 79 | self._base_config, cat_id, key=key, 80 | compression=compression, shape_key=shape_key) 81 | if src.has_dataset(): 82 | break 83 | else: 84 | src = get_manager(self._base_config, cat_id, key='zip') 85 | if not src.has_dataset(): 86 | src = get_manager(self._base_config, cat_id, key='file') 87 | 88 | dst = get_manager(self, cat_id, 'file')._get_dataset(mode='a') 89 | fill_fn = self.get_fill_voxels_fn((self.voxel_dim,)*3) 90 | src_ds = src.get_dataset().map(fill_fn) 91 | if example_ids is not None: 92 | src_ds = src_ds.subset(example_ids) 93 | with src_ds, dst: 94 | print('Writing filled voxels to file') 95 | dst.save_dataset(src_ds) 96 | -------------------------------------------------------------------------------- /core/voxels/path.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | from ..path import get_data_dir 7 | 8 | data_dir = get_data_dir('voxels') 9 | if not os.path.isdir(data_dir): 10 | os.makedirs(data_dir) 11 | 12 | 13 | def get_binvox_subpath(cat_id, example_id=None): 14 | if example_id is None: 15 | return cat_id 16 | return os.path.join(cat_id, '%s.binvox' % example_id) 17 | 18 | 19 | def get_binvox_dir(voxel_id): 20 | return os.path.join(data_dir, voxel_id) 21 | -------------------------------------------------------------------------------- /core/voxels/scripts/cat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from absl import flags, app 7 | FLAGS = flags.FLAGS 8 | 9 | # VoxelConfig args 10 | flags.DEFINE_integer('voxel_dim', default=32, help='voxel dimension') 11 | flags.DEFINE_boolean( 12 | 'alt', default=False, help='use alternative base VoxelConfig') 13 | flags.DEFINE_string('fill', default=None, help='optional fill algorithm') 14 | flags.DEFINE_list( 15 | 'cat', default=None, help='catergory(s) to create (ID or descriptor)') 16 | 17 | 18 | def main(_): 19 | from shapenet.core.voxels.config import get_config 20 | from shapenet.core import to_cat_id 21 | from shapenet.r2n2 import get_cat_ids 22 | config = get_config(FLAGS.voxel_dim, alt=FLAGS.alt) 23 | fill = FLAGS.fill 24 | if fill is not None: 25 | config = config.filled(fill) 26 | if FLAGS.cat is None: 27 | cat_ids = get_cat_ids() 28 | else: 29 | cat_ids = [to_cat_id(c) for c in FLAGS.cat] 30 | if FLAGS.fill is not None: 31 | config = config.filled(FLAGS.fill) 32 | for cat_id in cat_ids: 33 | config.create_voxel_data(cat_id) 34 | 35 | 36 | app.run(main) 37 | 
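The `--voxel_dim`/`--alt`/`--fill` flags in the script above select a `VoxelConfig` whose `voxel_id` determines where binvox data is read and written (see `core/voxels/config.py` and `core/voxels/path.py` earlier in this module). A small sketch of that mapping; the flag values are illustrative, and the commented paths assume the default `created_data_dir` layout from `default_config.yaml`:

```python
from shapenet.core.voxels.config import get_config

config = get_config(64, alt=False)    # exact binvox settings
print(config.voxel_id)                # 'd064ttt'
print(config.root_dir)                # <created_data_dir>/voxels/d064ttt

alt = get_config(64, alt=True)
print(alt.voxel_id)                   # 'd064fffcv'

filled = config.filled('orthographic')
print(filled.voxel_id)                # 'd064ttt_orthographic'
print(filled.root_dir)                # <created_data_dir>/voxels/orthographic/d064ttt
```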
-------------------------------------------------------------------------------- /core/voxels/scripts/create_all.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from shapenet.r2n2 import get_cat_ids 7 | from shapenet.core.voxels.datasets import get_manager, convert 8 | from shapenet.core.voxels import get_config 9 | 10 | cat_ids = get_cat_ids() 11 | voxel_dims = (32, 64, 128, 256) 12 | n_configs = 2 13 | n = len(cat_ids) * len(voxel_dims) * n_configs 14 | i = 1 15 | for voxel_dim in voxel_dims: 16 | for config in ( 17 | get_config(voxel_dim).filled('orthographic'), 18 | get_config(voxel_dim, alt=True), 19 | ): 20 | for cat_id in cat_ids: 21 | dst = get_manager( 22 | config, cat_id, key='brle', compression='lzf', pad=True) 23 | src_kwargs = dict(key='file') 24 | print('Converting %d / %d' % (i, n)) 25 | convert(dst, overwrite=False, delete_src=False, **src_kwargs) 26 | i += 1 27 | -------------------------------------------------------------------------------- /core/voxels/scripts/create_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | """ 4 | Script for creating/converting voxel data formats. 5 | 6 | Example usage: 7 | ./create_dataset.py --cat=car,plane --voxel_dim=128 \ 8 | --format=rle --overwrite=True --compression=gzip --shape=pad 9 | --src_format=file 10 | """ 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | from absl import flags, app 17 | FLAGS = flags.FLAGS 18 | 19 | # VoxelConfig args 20 | flags.DEFINE_integer('voxel_dim', default=32, help='voxel dimension') 21 | flags.DEFINE_boolean( 22 | 'alt', default=False, help='use alternative base VoxelConfig') 23 | flags.DEFINE_string('fill', default=None, help='optional fill algorithm') 24 | 25 | flags.DEFINE_list( 26 | 'cat', default=None, help='cat desc(s) to create (comma separated)') 27 | 28 | flags.DEFINE_boolean( 29 | 'overwrite', default=False, help='overwrite existing data') 30 | flags.DEFINE_string( 31 | 'format', default='file', 32 | help='input format, one of ' 33 | '["file", "zip", "rle", "brle"]') 34 | flags.DEFINE_string( 35 | 'shape', default=None, help='one of ["pad", "jag", "ind"]') 36 | flags.DEFINE_string( 37 | 'compression', default=None, help='one of [None, "gzip", "lzf"]') 38 | 39 | flags.DEFINE_string( 40 | 'src_shape', default=None, help='one of ["pad", "jag", "ind", "cat"]') 41 | flags.DEFINE_string( 42 | 'src_compression', default=None, help='one of [None, "gzip", "lzf"]') 43 | flags.DEFINE_string( 44 | 'src_format', default=None, 45 | help='output format, one of ' 46 | '[None, "file", "zip", "rle", "brle"]') 47 | 48 | flags.DEFINE_boolean( 49 | 'delete_src', default=False, 50 | help='if true, source data is delete (unless in is None)') 51 | 52 | 53 | def safe_update(out, **added): 54 | for k, v in added.items(): 55 | if v is not None: 56 | if k in out: 57 | raise ValueError('%s already in out' % k) 58 | out[k] = v 59 | 60 | 61 | def main(_): 62 | from shapenet.core.voxels.config import get_config 63 | from shapenet.core.voxels.datasets import get_manager, convert 64 | from shapenet.core import to_cat_id 65 | config = get_config(FLAGS.voxel_dim, alt=FLAGS.alt) 66 | if FLAGS.cat is None: 67 | # from shapenet.r2n2 import get_cat_ids 68 | raise ValueError('Must provide 
at least one cat to convert.') 69 | if FLAGS.fill is not None: 70 | config = config.filled(FLAGS.fill) 71 | 72 | kwargs = dict(config=config, key=FLAGS.format) 73 | safe_update(kwargs, compression=FLAGS.compression, shape_key=FLAGS.shape) 74 | src_kwargs = dict() 75 | safe_update( 76 | src_kwargs, key=FLAGS.src_format, compression=FLAGS.src_compression, 77 | shape_key=FLAGS.src_shape) 78 | 79 | for cat in FLAGS.cat: 80 | dst = get_manager(cat_id=to_cat_id(cat), **kwargs) 81 | convert( 82 | dst, overwrite=FLAGS.overwrite, delete_src=FLAGS.delete_src, 83 | **src_kwargs) 84 | 85 | 86 | app.run(main) 87 | -------------------------------------------------------------------------------- /core/voxels/scripts/example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from absl import flags, app 7 | FLAGS = flags.FLAGS 8 | 9 | # VoxelConfig args 10 | flags.DEFINE_integer('voxel_dim', default=32, help='voxel dimension') 11 | flags.DEFINE_boolean( 12 | 'alt', default=False, help='use alternative base VoxelConfig') 13 | flags.DEFINE_string('fill', default=None, help='optional fill algorithm') 14 | 15 | flags.DEFINE_string( 16 | 'cat', default=None, help='catergory to create (ID or descriptor)') 17 | flags.DEFINE_list( 18 | 'example_id', default=None, 19 | help='example id(s) to create base for. Defaults to all') 20 | 21 | 22 | def main(_): 23 | from shapenet.core.voxels.config import get_config 24 | from shapenet.core import to_cat_id 25 | from shapenet.core import get_example_ids 26 | config = get_config(FLAGS.voxel_dim, alt=FLAGS.alt) 27 | fill = FLAGS.fill 28 | if fill is not None: 29 | config = config.filled(fill) 30 | if FLAGS.cat is None: 31 | raise ValueError('Must provide at least one cat to convert.') 32 | if FLAGS.fill is not None: 33 | config = config.filled(FLAGS.fill) 34 | cat_id = to_cat_id(FLAGS.cat) 35 | example_ids = FLAGS.example_id 36 | if example_ids is None: 37 | example_ids = get_example_ids(cat_id) 38 | config.create_voxel_data(cat_id, example_ids) 39 | 40 | 41 | app.run(main) 42 | -------------------------------------------------------------------------------- /default_config.yaml: -------------------------------------------------------------------------------- 1 | core_dir: /data/ShapeNetCore.v1 # directory containing CAT_ID.zip, CAT_ID.csv, taxonomy.json 2 | iccv17_dir: /data/shapenet2017 # directory containing (train)|(test)|(val)_(imgs)|(voxels).zip 3 | core_annotations_dir: /data/shapenet_annotations # directory containing shapenetcore_partanno_segmentation_benchmark_v0.zip and shapenetcore_partanno_v0.zip 4 | created_data_dir: ./data # directory into data created by this repo saves data 5 | voxel_compression: lzf 6 | -------------------------------------------------------------------------------- /example/core/benchmark_frustrum.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from util3d.transform.frustrum import voxel_values_to_frustrum 8 | from util3d.transform.nonhom import get_eye_to_world_transform 9 | from util3d.voxel.binvox import DenseVoxels 10 | from shapenet.core import get_example_ids, to_cat_id 11 | from shapenet.core.renderings.renderings_manager import get_base_manager 
12 | from shapenet.core.voxels.config import get_config 13 | from shapenet.core.voxels.datasets import get_dataset as get_voxel_dataset 14 | 15 | import time 16 | 17 | 18 | cat = 'plane' 19 | voxel_dim = 64 20 | ray_shape = (32,)*3 21 | view_index = 0 22 | cat_id = to_cat_id(cat) 23 | config = get_config(voxel_dim, alt=False).filled('orthographic') 24 | voxel_dataset = get_voxel_dataset( 25 | config, cat_id, id_keys=True, key='rle', compression='lzf') 26 | image_manager = get_base_manager(dim=256) 27 | n_renderings = image_manager.get_render_params()['n_renderings'] 28 | f = 32 / 35 29 | 30 | 31 | example_ids = get_example_ids(cat_id) 32 | with voxel_dataset: 33 | for example_id in example_ids: 34 | start = time.time() 35 | dense_data = voxel_dataset[example_id].dense_data() 36 | dense_data = dense_data[:, -1::-1] 37 | 38 | key = (cat_id, example_id) 39 | eyes = image_manager.get_camera_positions(key) 40 | for vi in range(n_renderings): 41 | eye = eyes[vi] 42 | n = np.linalg.norm(eye) 43 | R, t = get_eye_to_world_transform(eye) 44 | z_near = n - 0.5 45 | z_far = z_near + 1 46 | 47 | frust, inside = voxel_values_to_frustrum( 48 | dense_data, R, t, f, z_near, z_far, ray_shape, 49 | include_corners=False) 50 | frust[np.logical_not(inside)] = 0 51 | frust = frust[:, -1::-1] 52 | vox = DenseVoxels(frust) 53 | rle = vox.rle_data() 54 | print(time.time() - start) 55 | -------------------------------------------------------------------------------- /example/core/blender_renderings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import matplotlib.pyplot as plt 8 | from shapenet.image import with_background 9 | from shapenet.core.blender_renderings.config import RenderConfig 10 | from shapenet.core import cat_desc_to_id, get_example_ids 11 | 12 | 13 | cat_desc = 'plane' 14 | view_index = 5 15 | config = RenderConfig() 16 | view_angle = config.view_angle(view_index) 17 | cat_id = cat_desc_to_id(cat_desc) 18 | example_ids = get_example_ids(cat_id) 19 | 20 | path = config.get_zip_path(cat_id) 21 | if not os.path.isfile(path): 22 | raise IOError('No renderings at %s' % path) 23 | 24 | with config.get_dataset(cat_id, view_index) as ds: 25 | ds = ds.map(lambda image: with_background(image, 255)) 26 | for example_id in ds: 27 | image = ds[example_id] 28 | plt.imshow(image) 29 | plt.show() 30 | -------------------------------------------------------------------------------- /example/core/clip_space_voxels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import random 7 | import numpy as np 8 | 9 | from shapenet.core.voxels.rotated import FrustrumVoxelConfig 10 | from shapenet.core.blender_renderings import RenderConfig 11 | from shapenet.core.voxels import VoxelConfig 12 | from shapenet.core import cat_desc_to_id 13 | 14 | cat_desc = 'chair' 15 | base_config = VoxelConfig(voxel_dim=128) 16 | render_config = RenderConfig(shape=(224, 224)) 17 | view_index = 0 18 | out_shape = (28,)*3 19 | 20 | cat_id = cat_desc_to_id(cat_desc) 21 | 22 | frustrum_config = FrustrumVoxelConfig( 23 | base_config, render_config, view_index, out_shape) 24 | 25 | 26 | def vis(base, frust, image): 27 | from util3d.mayavi_vis import vis_voxels 28 | from 
mayavi import mlab 29 | from PIL import Image 30 | frust_data = frust.dense_data() 31 | frust_flat = np.max(frust_data, axis=-1) 32 | Image.fromarray( 33 | frust_flat.T.astype(np.uint8)*255).resize((224, 224)).show() 34 | 35 | image.show() 36 | mlab.figure() 37 | vis_voxels(base.dense_data(), color=(0, 0, 1)) 38 | mlab.figure() 39 | vis_voxels(frust_data, color=(0, 1, 0)) 40 | mlab.show() 41 | 42 | 43 | with frustrum_config.get_dataset(cat_id) as fds: 44 | with base_config.get_dataset(cat_id) as bds: 45 | with render_config.get_dataset(cat_id, view_index) as vds: 46 | example_ids = list(fds.keys()) 47 | random.shuffle(example_ids) 48 | for example_id in example_ids: 49 | vis(bds[example_id], fds[example_id], vds[example_id]) 50 | -------------------------------------------------------------------------------- /example/core/filled_voxels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from shapenet.core.voxels.config import VoxelConfig 7 | from shapenet.core import cat_desc_to_id 8 | from util3d.mayavi_vis import vis_sliced 9 | from mayavi import mlab 10 | import numpy as np 11 | 12 | cat_desc = 'watercraft' 13 | cat_id = cat_desc_to_id(cat_desc) 14 | 15 | 16 | base = VoxelConfig(voxel_dim=128) 17 | filled = base.filled('orthographic') 18 | with base.get_dataset(cat_id) as bds, filled.get_dataset(cat_id) as fds: 19 | for example_id in bds: 20 | base_data = bds[example_id].dense_data() 21 | filled_data = fds[example_id].dense_data() 22 | for dense_data in (base_data, filled_data): 23 | mlab.figure() 24 | vis_sliced(dense_data.astype(np.float32), axis_order='xyz') 25 | mlab.show() 26 | break 27 | -------------------------------------------------------------------------------- /example/core/frust_saved.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from shapenet.core import to_cat_id, get_example_ids 7 | from shapenet.core.renderings.archive_manager import get_archive 8 | from shapenet.core.renderings.renderings_manager import get_base_manager 9 | from shapenet.core.renderings.frustrum_voxels import get_frustrum_voxels_data 10 | from shapenet.core.renderings.frustrum_voxels import GROUP_KEY 11 | from shapenet.core.voxels import get_config 12 | 13 | 14 | def vis(image_fp, rle_data, out_dim): 15 | import numpy as np 16 | from PIL import Image 17 | from util3d.voxel.binvox import RleVoxels 18 | from util3d.mayavi_vis import vis_voxels 19 | from mayavi import mlab 20 | im = Image.open(image_fp) 21 | im.show() 22 | dense_data = RleVoxels(np.array(rle_data), (out_dim,)*3).dense_data() 23 | sil = np.any(dense_data, axis=-1).T.astype(np.uint8)*255 24 | sil = Image.fromarray(sil).resize(im.size) 25 | sil.show() 26 | comb = np.array(im) // 2 + np.array(sil)[:, :, np.newaxis] // 2 27 | Image.fromarray(comb).show() 28 | vis_voxels(dense_data, color=(0, 0, 1)) 29 | mlab.show() 30 | 31 | 32 | image_dim = 256 33 | 34 | src_voxel_dim = 256 35 | out_dim = 32 36 | n_renderings = 24 37 | cat = 'car' 38 | 39 | cat_id = to_cat_id(cat) 40 | example_ids = get_example_ids(cat_id) 41 | 42 | renderings_manager = get_base_manager(image_dim, n_renderings=n_renderings) 43 | src_config = get_config(src_voxel_dim, alt=False).filled('orthographic') 44 | 
view_index = 10 45 | with get_archive(renderings_manager, cat_id).get_open_file() as zf: 46 | with get_frustrum_voxels_data( 47 | renderings_manager.root_dir, src_config, out_dim, cat_id) as vd: 48 | rle_data = vd[GROUP_KEY] 49 | for i, example_id in enumerate(example_ids): 50 | subpath = renderings_manager.get_rendering_subpath( 51 | (cat_id, example_id), view_index) 52 | vis(zf.open(subpath), rle_data[i, view_index], out_dim) 53 | -------------------------------------------------------------------------------- /example/core/frustrum.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from util3d.transform.frustrum import voxel_values_to_frustrum 8 | from util3d.transform.nonhom import get_eye_to_world_transform 9 | from shapenet.core import get_example_ids, to_cat_id 10 | from shapenet.core.renderings.renderings_manager import \ 11 | get_base_manager as get_image_manager 12 | from shapenet.core.views.base import get_base_manager as get_view_manager 13 | from shapenet.core.voxels.config import get_config 14 | from shapenet.core.voxels.datasets import get_dataset as get_voxel_dataset 15 | 16 | 17 | cat = 'telephone' 18 | voxel_dim = 32 19 | ray_shape = (32,)*3 20 | view_index = 0 21 | 22 | cat_id = to_cat_id(cat) 23 | config = get_config(voxel_dim, alt=False).filled('orthographic') 24 | voxel_dataset = get_voxel_dataset( 25 | config, cat_id, id_keys=True, key='rle', compression='lzf') 26 | image_manager = get_image_manager(dim=128) 27 | view_manager = get_view_manager() 28 | f = view_manager.get_view_params()['f'] 29 | 30 | 31 | def vis(dense_data, image_path, frust): 32 | from PIL import Image 33 | from mayavi import mlab 34 | from util3d.mayavi_vis import vis_contours 35 | image = Image.open(image_path) 36 | image.show() 37 | frust = Image.fromarray(frust).resize(image.size) 38 | frust.show() 39 | combined = np.array(image) // 2 + np.array(frust)[:, :, np.newaxis] // 2 40 | Image.fromarray(combined).show() 41 | mlab.figure() 42 | vis_contours(dense_data, contours=[0.5]) 43 | mlab.show() 44 | 45 | 46 | example_ids = get_example_ids(cat_id) 47 | with voxel_dataset: 48 | for example_id in example_ids: 49 | dense_data = voxel_dataset[example_id].dense_data() 50 | dense_data = dense_data[:, -1::-1] 51 | 52 | key = (cat_id, example_id) 53 | eyes = image_manager.get_camera_positions(key) 54 | image_path = image_manager.get_rendering_path(key, view_index) 55 | 56 | eye = eyes[view_index] 57 | n = np.linalg.norm(eye) 58 | R, t = get_eye_to_world_transform(eye) 59 | z_near = n - 0.5 60 | z_far = z_near + 1 61 | 62 | frust, inside = voxel_values_to_frustrum( 63 | dense_data, R, t, f, z_near, z_far, ray_shape, 64 | include_corners=False) 65 | frust[np.logical_not(inside)] = 0 66 | frust_image = np.any(frust, axis=-1).T 67 | frust_image = frust_image[-1::-1].astype(np.uint8)*255 68 | vis(dense_data, image_path, frust_image) 69 | -------------------------------------------------------------------------------- /example/core/meshes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from shapenet.core.meshes import get_mesh_dataset 8 | from shapenet.core import cat_desc_to_id 9 | from 
util3d.mayavi_vis import vis_mesh 10 | from mayavi import mlab 11 | 12 | desc = 'plane' 13 | 14 | cat_id = cat_desc_to_id(desc) 15 | with get_mesh_dataset(cat_id) as mesh_dataset: 16 | for example_id in mesh_dataset: 17 | example_group = mesh_dataset[example_id] 18 | vertices, faces = ( 19 | np.array(example_group[k]) for k in ('vertices', 'faces')) 20 | vis_mesh(vertices, faces, color=(0, 0, 1), axis_order='xzy') 21 | mlab.show() 22 | -------------------------------------------------------------------------------- /example/core/objs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from shapenet.core import cat_desc_to_id 7 | from shapenet.core.objs import get_obj_file_dataset 8 | from mayavi import mlab 9 | from util3d.mayavi_vis import vis_mesh 10 | from util3d.mesh.obj_io import parse_obj_file 11 | 12 | cat_desc = 'plane' 13 | cat_id = cat_desc_to_id(cat_desc) 14 | 15 | 16 | def map_fn(f): 17 | return parse_obj_file(f)[:2] 18 | 19 | 20 | dataset = get_obj_file_dataset(cat_id).map(map_fn) 21 | 22 | with dataset: 23 | for k, (vertices, faces) in dataset.items(): 24 | print(k) 25 | vis_mesh(vertices, faces, axis_order='xzy') 26 | mlab.show() 27 | -------------------------------------------------------------------------------- /example/core/point_clouds.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import random 7 | import numpy as np 8 | from mayavi import mlab 9 | from util3d.mayavi_vis import vis_point_cloud, vis_normals 10 | from shapenet.core.point_clouds import get_point_cloud_dataset 11 | from shapenet.core.point_clouds import get_cloud_normal_dataset 12 | from shapenet.core import cat_desc_to_id 13 | 14 | cat_desc = 'plane' 15 | n_points = 16384 16 | cat_id = cat_desc_to_id(cat_desc) 17 | 18 | show_normals = True 19 | if vis_normals: 20 | dataset = get_cloud_normal_dataset(cat_id, n_points) 21 | else: 22 | dataset = get_point_cloud_dataset(cat_id, n_points) 23 | 24 | with dataset: 25 | for example_id in dataset: 26 | data = dataset[example_id] 27 | s = random.sample(range(n_points), 1024) 28 | if show_normals: 29 | cloud = np.array(data['points']) 30 | normals = np.array(data['normals']) 31 | cloud = cloud[s] 32 | normals = normals[s] 33 | else: 34 | cloud = np.array(data) 35 | normals = None 36 | cloud = cloud[s] 37 | vis_point_cloud( 38 | cloud, axis_order='xzy', color=(0, 0, 1), scale_factor=0.002) 39 | if normals is not None: 40 | vis_normals(cloud, normals, scale_factor=0.01, axis_order='xzy') 41 | mlab.show() 42 | -------------------------------------------------------------------------------- /example/core/train_test_split.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import dids 7 | from shapenet.core.point_clouds import get_point_cloud_dataset 8 | from shapenet.core.meshes import get_mesh_dataset 9 | from shapenet.core import cat_desc_to_id, get_test_train_split 10 | 11 | cat_desc = 'plane' 12 | cat_id = cat_desc_to_id(cat_desc) 13 | 14 | cloud_ds = get_point_cloud_dataset(cat_id, 1024) 15 | mesh_ds = 
get_mesh_dataset(cat_id) 16 | 17 | zipped_ds = dids.Dataset.zip(mesh_ds, cloud_ds) 18 | 19 | 20 | def vis(mesh, cloud): 21 | import numpy as np 22 | from util3d.mayavi_vis import vis_mesh, vis_point_cloud 23 | from mayavi import mlab 24 | vertices, faces = (np.array(mesh[k]) for k in ('vertices', 'faces')) 25 | vis_mesh(vertices, faces, color=(0, 0, 1), axis_order='xzy') 26 | vis_point_cloud( 27 | np.array(cloud), color=(0, 1, 0), scale_factor=0.01, axis_order='xzy') 28 | mlab.show() 29 | 30 | 31 | test_train_split = get_test_train_split()[cat_id] 32 | train_keys = test_train_split['train'] 33 | test_keys = test_train_split['test'] 34 | val_keys = test_train_split['val'] 35 | 36 | with zipped_ds: 37 | print(len(zipped_ds)) 38 | train_ds = zipped_ds.subset(train_keys) 39 | test_ds = zipped_ds.subset(test_keys) 40 | val_ds = zipped_ds.subset(val_keys) 41 | print('n train: %d, n valid train: %d' % (len(train_keys), len(train_ds))) 42 | print('n test: %d, n valid test: %d' % (len(test_keys), len(test_ds))) 43 | print('n val: %d, n valid val: %d' % (len(val_keys), len(val_ds))) 44 | 45 | for example_id in train_ds.keys(): 46 | print('train dataset, example_id: %s' % example_id) 47 | mesh, cloud = train_ds[example_id] 48 | break 49 | vis(mesh, cloud) 50 | 51 | for example_id in test_ds.keys(): 52 | print('test dataset, example_id: %s' % example_id) 53 | mesh, cloud = test_ds[example_id] 54 | break 55 | vis(mesh, cloud) 56 | 57 | for example_id in val_ds.keys(): 58 | print('val dataset, example_id: %s' % example_id) 59 | mesh, cloud = val_ds[example_id] 60 | break 61 | vis(mesh, cloud) 62 | -------------------------------------------------------------------------------- /example/core/vox_ds/compare_compression.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | from shapenet.core.voxels import VoxelConfig 9 | from shapenet.core.voxels.datasets import get_dataset 10 | from shapenet.core import cat_desc_to_id, get_example_ids 11 | 12 | # from util3d.voxel import rle as rle 13 | # from util3d.voxel import brle as brle 14 | 15 | config = VoxelConfig(voxel_dim=128).filled('orthographic') 16 | cat_desc = 'cellphone' 17 | cat_id = cat_desc_to_id(cat_desc) 18 | example_ids = get_example_ids(cat_id) 19 | 20 | example_index = 10 21 | example_id = example_ids[example_index] 22 | 23 | 24 | def vis(*dense_data): 25 | from mayavi import mlab 26 | from util3d.mayavi_vis import vis_voxels 27 | for d in dense_data: 28 | mlab.figure() 29 | vis_voxels(d, color=(0, 0, 1), axis_order='xyz') 30 | mlab.show() 31 | 32 | 33 | out = [] 34 | 35 | for key in ('rle', 'brle'): 36 | for compression in ('lzf', 'gzip', None): 37 | for pad in (True, False): 38 | with get_dataset( 39 | config, cat_id, id_keys=True, key=key, pad=pad, 40 | compression=compression) as ds: 41 | out.append(ds[example_id].dense_data()) 42 | 43 | assert(all(np.all(j == out[0]) for j in out[1:])) 44 | vis(out[-1]) 45 | -------------------------------------------------------------------------------- /example/core/vox_ds/compare_formats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | from shapenet.core.voxels import VoxelConfig 9 | from 
shapenet.core.voxels.datasets import get_dataset 10 | from shapenet.core import cat_desc_to_id, get_example_ids 11 | 12 | # from util3d.voxel import rle as rle 13 | # from util3d.voxel import brle as brle 14 | 15 | config = VoxelConfig(voxel_dim=32) 16 | cat_desc = 'cellphone' 17 | cat_id = cat_desc_to_id(cat_desc) 18 | example_ids = get_example_ids(cat_id) 19 | 20 | example_index = 10 21 | example_id = example_ids[example_index] 22 | 23 | 24 | def vis(*dense_data): 25 | from mayavi import mlab 26 | from util3d.mayavi_vis import vis_voxels 27 | for d in dense_data: 28 | mlab.figure() 29 | vis_voxels(d, color=(0, 0, 1), axis_order='xyz') 30 | mlab.show() 31 | 32 | 33 | out = [] 34 | 35 | for key in ('file', 'zip', 'rle', 'brle'): 36 | if key in ('rle', 'brle'): 37 | kwargs = dict(compression='lzf', pad=True) 38 | else: 39 | kwargs = {} 40 | with get_dataset( 41 | config, cat_id, id_keys=True, key=key, **kwargs) as ds: 42 | out.append(ds[example_id].dense_data()) 43 | 44 | 45 | assert(all(np.all(j == out[0]) for j in out[1:])) 46 | vis(out[-1]) 47 | -------------------------------------------------------------------------------- /example/core/voxels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from mayavi import mlab 7 | from shapenet.core.voxels.config import get_config 8 | from shapenet.core.voxels.datasets import get_dataset 9 | from shapenet.core.blender_renderings import RenderConfig 10 | from shapenet.core import to_cat_id, get_example_ids 11 | from util3d.mayavi_vis import vis_sliced 12 | from util3d.mayavi_vis import vis_contours 13 | # from util3d.mayavi_vis import vix_voxels 14 | 15 | cat = 'plane' 16 | voxel_dim = 128 17 | alt = False 18 | fill = 'orthographic' 19 | ds_kwargs = dict(key='rle', compression='lzf', shape_key='pad') 20 | 21 | # cat = 'pistol' 22 | # voxel_dim = 32 23 | # alt = True 24 | # fill = None 25 | # ds_kwargs = dict(key='zip') 26 | 27 | cat_id = to_cat_id(cat) 28 | example_ids = get_example_ids(cat_id) 29 | config = get_config(voxel_dim, alt=alt) 30 | if fill is not None: 31 | config = config.filled(fill) 32 | with get_dataset(config, cat_id, **ds_kwargs) as dataset: 33 | with RenderConfig(shape=(256, 256), n_images=8).get_dataset( 34 | cat_id, view_index=5) as render_ds: 35 | for example_id in example_ids: 36 | dense = dataset[example_id].dense_data() 37 | render_ds[example_id].show() 38 | mlab.figure() 39 | vis_sliced(dense, axis_order='xyz') 40 | mlab.figure() 41 | vis_contours(dense, contours=[0.5]) 42 | mlab.show() 43 | -------------------------------------------------------------------------------- /example/iccv17/voxels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from util3d.mayavi_vis import vis_voxels, vis_point_cloud 7 | from mayavi import mlab 8 | from util3d.voxel.manip import fast_resize, get_surface_voxels 9 | from util3d.voxel.convert import voxels_to_point_cloud 10 | 11 | from shapenet.iccv17.voxels import get_mat_data 12 | 13 | voxels = get_mat_data('eval', 9) 14 | 15 | 16 | def test_fast_resize(): 17 | # zoomed = resize(voxels, 64) 18 | vis_voxels(fast_resize(voxels, 32), color=(0, 0, 1), axis_order='xyz') 19 | mlab.show() 20 | 21 | 22 | def test_surface_voxels(): 23 
| point_cloud = voxels_to_point_cloud(get_surface_voxels(voxels)) 24 | vis_point_cloud(point_cloud[:len(point_cloud) // 2], color=(0, 0, 1)) 25 | mlab.show() 26 | 27 | 28 | def test_mpl_vis(): 29 | print(voxels.shape) 30 | v = fast_resize(voxels, 32) 31 | print(v.shape) 32 | vis_voxels(v, color=(0, 0, 1), axis_order='xyz') 33 | mlab.show() 34 | 35 | 36 | test_mpl_vis() 37 | test_fast_resize() 38 | test_surface_voxels() 39 | -------------------------------------------------------------------------------- /example/r2n2/angle_dist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from shapenet.r2n2 import get_cat_ids 8 | from shapenet.r2n2.hdf5 import Hdf5Manager, meta_index 9 | 10 | az_index = meta_index('azimuth') 11 | el_index = meta_index('elevation') 12 | az = [] 13 | el = [] 14 | 15 | for cat_id in get_cat_ids(): 16 | with Hdf5Manager(cat_id) as manager: 17 | g = manager.meta_group 18 | az.append(np.array(g[..., az_index])) 19 | el.append(np.array(g[..., el_index])) 20 | 21 | az = np.concatenate(az, axis=0).flatten() 22 | el = np.concatenate(el, axis=0).flatten() 23 | 24 | 25 | def vis(az, el): 26 | import matplotlib.pyplot as plt 27 | _, (ax0, ax1) = plt.subplots(1, 2) 28 | ax0.hist(az, bins=36) 29 | ax1.hist(el, bins=36) 30 | plt.show() 31 | 32 | 33 | vis(az, el) 34 | -------------------------------------------------------------------------------- /example/r2n2/binvox.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from shapenet.r2n2.tgz import BinvoxManager 8 | 9 | 10 | def vis(vox): 11 | import numpy as np 12 | from mayavi import mlab 13 | voxels = np.pad( 14 | vox.dense_data(), [[1, 1], [1, 1], [1, 1]], mode='constant') 15 | mlab.figure() 16 | mlab.contour3d( 17 | voxels.astype(np.float32), color=(0, 0, 1), contours=[0.5], 18 | opacity=0.5) 19 | mlab.show() 20 | 21 | 22 | cat_id = '02958343' # car 23 | 24 | with BinvoxManager() as bvm: 25 | ids = bvm.get_example_ids() 26 | example_ids = ids[cat_id] 27 | print(cat_id) 28 | for example_id in example_ids: 29 | vox = bvm[cat_id, example_id] 30 | vis(vox) 31 | -------------------------------------------------------------------------------- /example/r2n2/hdf5.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import random 8 | from shapenet.r2n2.hdf5 import Hdf5Manager, get_cat_ids 9 | from shapenet.r2n2.split import split_indices 10 | 11 | cat_id = random.sample(get_cat_ids(), 1)[0] 12 | # cat_id = '03691459' 13 | view_index = 0 14 | 15 | 16 | def vis(vox, image, meta): 17 | from mayavi import mlab 18 | from util3d.mayavi_vis import vis_voxels 19 | print(meta) 20 | image.resize((137*4,)*2).show() 21 | mlab.figure() 22 | vis_voxels(vox, color=(0, 0, 1), scale_factor=0.5, axis_order='xyz') 23 | mlab.show() 24 | 25 | 26 | print('Opening manager...') 27 | with Hdf5Manager(cat_id) as m: 28 | print('Getting example_ids...') 29 | example_ids = m.get_example_ids() 30 | print('Getting split indices') 31 | indices = split_indices(example_ids, 'train') 32 | 
print(indices[:10]) 33 | for index in indices: 34 | print('Loading data...') 35 | vox = m.get_voxels(index).dense_data() 36 | image = m.get_rendering(index, view_index) 37 | meta = m.get_meta(index, view_index) 38 | print('Visualizing...') 39 | vis(vox, image, meta) 40 | -------------------------------------------------------------------------------- /example/r2n2/renderings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from shapenet.r2n2.tgz import RenderingsManager 8 | from shapenet.r2n2.tgz import BinvoxManager 9 | 10 | 11 | print('Opening BM') 12 | with BinvoxManager() as bm: 13 | print('Getting example_ids...') 14 | example_ids = bm.get_example_ids() 15 | 16 | print('Opening RM') 17 | with RenderingsManager() as rm: 18 | print('Getting metas...') 19 | for cat_id, example_ids in example_ids.items(): 20 | for example_id in example_ids: 21 | metas = rm.get_metas(cat_id, example_id) 22 | print(cat_id, example_id, len(metas)) 23 | -------------------------------------------------------------------------------- /iccv17/README.md: -------------------------------------------------------------------------------- 1 | [Website](https://shapenet.cs.stanford.edu/iccv17/) 2 | -------------------------------------------------------------------------------- /iccv17/__init__.py: -------------------------------------------------------------------------------- 1 | import path 2 | from path import get_example_ids 3 | 4 | __all__ = [ 5 | 'path', 6 | 'get_example_ids', 7 | ] 8 | -------------------------------------------------------------------------------- /iccv17/path.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def get_dataset_dir(): 5 | if 'SHAPENET17_PATH' in os.environ: 6 | dataset_dir = os.environ['SHAPENET17_PATH'] 7 | if not os.path.isdir(dataset_dir): 8 | raise Exception('SHAPENET17_PATH directory does not exist') 9 | return dataset_dir 10 | else: 11 | raise Exception('SHAPENET17_PATH environment variable not set.') 12 | 13 | 14 | def get_image_shape(): 15 | return (192, 256, 3) 16 | 17 | 18 | def _mode(mode): 19 | return 'val' if mode == 'eval' else mode 20 | 21 | 22 | def _example_id(example_id): 23 | return '%06d' % example_id if isinstance(example_id, int) else example_id 24 | 25 | 26 | def _image_id(image_id): 27 | return ('%03d' % image_id) if isinstance(image_id, int) else image_id 28 | 29 | 30 | def get_all_images_dir(mode): 31 | return os.path.join(get_dataset_dir(), '%s_imgs' % _mode(mode)) 32 | 33 | 34 | def get_example_images_dir(mode, example_id): 35 | return os.path.join(get_all_images_dir(mode), _example_id(example_id)) 36 | 37 | 38 | def get_example_ids(mode): 39 | return os.listdir(get_all_images_dir(_mode(mode))) 40 | 41 | 42 | def n_images(mode, example_id): 43 | return len(os.listdir(get_example_images_dir(mode, example_id))) 44 | 45 | 46 | def get_image_path(mode, example_id, image_id): 47 | filename = '%s.png' % _image_id(image_id) 48 | return os.path.join(get_example_images_dir(mode, example_id), filename) 49 | 50 | 51 | def get_example_indices(mode): 52 | return (int(k) for k in get_example_ids(mode)) 53 | 54 | 55 | def get_all_voxels_dir(mode): 56 | return os.path.join(get_dataset_dir(), '%s_voxels' % _mode(mode)) 57 | 58 | 59 | def get_voxel_path(mode, example_id): 60 | return os.path.join( 61 | 
get_all_voxels_dir(mode), _example_id(example_id), 'model.mat') 62 | 63 | 64 | if __name__ == '__main__': 65 | mode = 'eval' 66 | example_id = 9 67 | 68 | def vis_image(): 69 | import matplotlib.pyplot as plt 70 | from scipy.misc import imread 71 | image_idx = 0 72 | image = imread(get_image_path(mode, example_id, image_idx)) 73 | image = image[..., :3] 74 | plt.imshow(image) 75 | plt.figure() 76 | plt.imshow(image[..., -1]) 77 | plt.show() 78 | 79 | vis_image() 80 | -------------------------------------------------------------------------------- /iccv17/voxels/__init__.py: -------------------------------------------------------------------------------- 1 | from shapenet.iccv17.path import get_voxel_path 2 | from scipy.io import loadmat 3 | 4 | 5 | def get_shape(): 6 | return (256, 256, 256) 7 | 8 | 9 | def load_mat_data(voxel_path): 10 | return loadmat(voxel_path)['input'][0] 11 | 12 | 13 | def get_mat_data(mode, example_id): 14 | return load_mat_data(get_voxel_path(mode, example_id)) 15 | 16 | 17 | if __name__ == '__main__': 18 | from path import get_image_path 19 | mode = 'eval' 20 | example_id = 9 21 | 22 | def vis_image(): 23 | import matplotlib.pyplot as plt 24 | from scipy.misc import imread 25 | image_idx = 0 26 | image = imread(get_image_path(mode, example_id, image_idx)) 27 | image = image[..., :3] 28 | plt.imshow(image) 29 | plt.figure() 30 | plt.imshow(image[..., -1]) 31 | plt.show() 32 | 33 | def vis_voxels(): 34 | from mayavi import mlab 35 | from util3d.mayavi_vis import vis_voxels 36 | path = get_voxel_path('eval', 9) 37 | voxels = load_mat_data(path) 38 | 39 | vis_voxels(voxels, color=(0, 0, 1)) 40 | mlab.show() 41 | 42 | vis_image() 43 | vis_voxels() 44 | -------------------------------------------------------------------------------- /image.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from PIL import Image 7 | 8 | try: 9 | from StringIO import StringIO 10 | except ImportError: 11 | from io import StringIO 12 | 13 | 14 | def load_image_from_file(f): 15 | return Image.open(StringIO(f.read())) 16 | 17 | 18 | def load_image_from_zip(zip_file, path): 19 | with zip_file.open(path) as fp: 20 | return load_image_from_file(fp) 21 | 22 | 23 | def with_background(image4d, background_color): 24 | """Sets background of 4d image to the specified color.""" 25 | image = np.asarray(image4d) 26 | assert(image.shape[-1] == 4) 27 | background = image[..., 3] == 0 28 | image = image[..., :3].copy() 29 | image[background] = background_color 30 | return image 31 | -------------------------------------------------------------------------------- /path.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | from .config import config 7 | 8 | root_dir = os.path.realpath(os.path.dirname(__file__)) 9 | 10 | 11 | def get_data_dir(*args): 12 | data_path = config['created_data_dir'] 13 | if data_path.startswith('/') or data_path.startswith('~'): 14 | folder = os.path.join(data_path, *args) 15 | else: 16 | if data_path.startswith('./'): 17 | data_path = data_path[2:] 18 | folder = os.path.join(root_dir, data_path, *args) 19 | if not os.path.isdir(folder): 20 | os.makedirs(folder) 21 | return folder 22 | 
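Note: the following is a minimal usage sketch of the `get_data_dir` helper defined in `/path.py` above; it is not part of the repository, and the subfolder names are purely illustrative. It assumes the repository is importable as the `shapenet` package, as in the example scripts. `get_data_dir` resolves `created_data_dir` from the config (paths starting with `/` or `~` are used as given, other paths are taken relative to the package root) and creates the directory if it does not yet exist:
```python
# Illustrative only: build a cache directory under the configured
# `created_data_dir`, e.g. for derived data belonging to one experiment.
from shapenet.path import get_data_dir

out_dir = get_data_dir('my_experiment', 'voxels')  # hypothetical subfolders
print(out_dir)  # the directory now exists; safe to write files into it
```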
-------------------------------------------------------------------------------- /r2n2/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | _data 4 | -------------------------------------------------------------------------------- /r2n2/README.md: -------------------------------------------------------------------------------- 1 | Download and loading functions for data from [R2N2](https://github.com/chrischoy/3D-R2N2). 2 | 3 | ## Data Download 4 | `scripts/download.py` should download renderings and binvox archives to `_data`. 5 | 6 | ### Manual Download 7 | * [Renderings](ftp://cs.stanford.edu/cs/cvgl/ShapeNetRendering.tgz) 8 | * [Voxels](ftp://cs.stanford.edu/cs/cvgl/ShapeNetVox32.tgz) 9 | 10 | These should be downloaded/linked to `DATA/r2n2/ShapeNetRendering.tgz` and `DATA/r2n2/ShapeNetVox32.tgz` respectively, where `DATA` is `saved_data_dir` defined in `shapenet/config.yaml`. 11 | -------------------------------------------------------------------------------- /r2n2/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | cat_descs = ( 7 | 'plane', 8 | 'bench', 9 | 'cabinet', 10 | 'car', 11 | 'chair', 12 | 'monitor', 13 | 'lamp', 14 | 'speaker', 15 | 'rifle', 16 | 'sofa', 17 | 'table', 18 | 'telephone', 19 | 'watercraft', 20 | ) 21 | 22 | 23 | def get_cat_descs(): 24 | return cat_descs 25 | 26 | 27 | def get_cat_ids(): 28 | from ..core import cat_desc_to_id 29 | return tuple(cat_desc_to_id(desc) for desc in cat_descs) 30 | 31 | 32 | # __all__ = [ 33 | # get_cat_ids, 34 | # get_cat_descs, 35 | # cat_descs, 36 | # ] 37 | -------------------------------------------------------------------------------- /r2n2/hdf5.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | import numpy as np 8 | import h5py 9 | from progress.bar import IncrementalBar 10 | from PIL import Image 11 | from . import path 12 | from . 
import tgz 13 | 14 | hdf5_dir = os.path.join(path.data_dir, 'hdf5') 15 | # hdf5_dir = os.path.join(path.data_dir, '..', 'scripts') 16 | _logger = logging.getLogger(__name__) 17 | n_renderings = 24 18 | voxel_dim = 32 19 | meta_keys = ( 20 | 'azimuth', 'elevation', 'in_plane_rotation', 'distance', 'field_of_view') 21 | _meta_indices = {k: i for i, k in enumerate(meta_keys)} 22 | 23 | 24 | def meta_index(key): 25 | return _meta_indices[key] 26 | 27 | 28 | def get_hdf5_path(cat_id): 29 | return os.path.join(hdf5_dir, '%s.hdf5' % cat_id) 30 | 31 | 32 | def get_hdf5_data(cat_id, mode='r'): 33 | return h5py.File(get_hdf5_path(cat_id), mode) 34 | 35 | 36 | def numpy_to_buffer(data): 37 | import io 38 | import numpy as np 39 | if not isinstance(data, np.ndarray): 40 | raise ValueError('data must be a numpy array') 41 | return io.BytesIO(data) 42 | 43 | 44 | def numpy_to_image(data): 45 | return Image.open(numpy_to_buffer(data)) 46 | 47 | 48 | def buffer_to_numpy(fp, dtype=np.uint8): 49 | import numpy as np 50 | return np.fromstring(fp.read(), dtype=dtype) 51 | 52 | 53 | def _ensure_extracted(): 54 | if not os.path.isdir(path.get_renderings_path()): 55 | with tgz.RenderingsManager() as rm: 56 | rm.extract() 57 | 58 | 59 | def get_cat_ids(): 60 | return tuple(fn[:-5] for fn in os.listdir(hdf5_dir)) 61 | 62 | 63 | class Converter(object): 64 | def __init__(self, cat_id, example_ids=None, mode='r'): 65 | if example_ids is None: 66 | with tgz.BinvoxManager() as bm: 67 | example_ids = bm.get_example_ids()[cat_id] 68 | self._cat_id = cat_id 69 | self._example_ids = sorted(example_ids) 70 | self._mode = mode 71 | self._file = None 72 | 73 | def __enter__(self): 74 | self.open() 75 | return self 76 | 77 | def __exit__(self, *args, **kwargs): 78 | self.close() 79 | 80 | def open(self): 81 | p = get_hdf5_path(self._cat_id) 82 | d = os.path.dirname(p) 83 | if not os.path.isdir(d): 84 | os.makedirs(d) 85 | self._file = h5py.File(p, self._mode) 86 | 87 | def close(self): 88 | self._file.close() 89 | self._file = None 90 | 91 | def setup(self, overwrite=False): 92 | _logger.info('Setting up hdf5.%s' % self._cat_id) 93 | vlen_dtype = h5py.special_dtype(vlen=np.dtype(np.uint8)) 94 | n = len(self._example_ids) 95 | if overwrite: 96 | for k in ('example_ids', 'rle_data', 'renderings', 'meta'): 97 | del self._file[k] 98 | 99 | if 'example_ids' not in self._file: 100 | id_group = self._file.create_dataset( 101 | 'example_ids', shape=(n,), dtype='S32') 102 | for i, e in enumerate(self._example_ids): 103 | id_group[i] = e 104 | else: 105 | assert(len(self._file['example_ids']) == n) 106 | self._file.require_dataset( 107 | 'rle_data', shape=(n,), dtype=vlen_dtype) 108 | self._file.require_dataset( 109 | 'renderings', shape=(n, n_renderings), dtype=vlen_dtype) 110 | self._file.require_dataset( 111 | 'meta', shape=(n, n_renderings, 5), dtype=np.float32) 112 | 113 | def convert_voxels(self, overwrite=False): 114 | _logger.info('Converting voxel data for hdf5.%s' % self._cat_id) 115 | group = self._file['rle_data'] 116 | with tgz.BinvoxManager() as bm: 117 | bar = IncrementalBar(max=len(self._example_ids)) 118 | for i, example_id in enumerate(self._example_ids): 119 | if overwrite or len(group[i]) == 0: 120 | group[i] = bm.load(self._cat_id, example_id).rle_data() 121 | bar.next() 122 | bar.finish() 123 | 124 | def convert_renderings(self, overwrite=False): 125 | _logger.info('Converting voxel data for hdf5.%s' % self._cat_id) 126 | group = self._file['renderings'] 127 | cat_id = self._cat_id 128 | _ensure_extracted() 129 | 
bar = IncrementalBar(max=len(self._example_ids)) 130 | for i, example_id in enumerate(self._example_ids): 131 | for j in range(n_renderings): 132 | if overwrite or len(group[i, j]) == 0: 133 | p = path.get_renderings_path(cat_id, example_id, j) 134 | with open(p, 'r') as fp: 135 | group[i, j] = buffer_to_numpy(fp) 136 | bar.next() 137 | bar.finish() 138 | 139 | def convert_meta(self, overwrite=False): 140 | group = self._file['meta'] 141 | _ensure_extracted() 142 | cat_id = self._cat_id 143 | bar = IncrementalBar(max=len(self._example_ids)) 144 | for i, example_id in enumerate(self._example_ids): 145 | if overwrite or np.all(group[i] == 0): 146 | p = path.get_renderings_path( 147 | cat_id, example_id, 'rendering_metadata.txt') 148 | with open(p, 'r') as fp: 149 | lines = [line for line in fp.readlines() if len(line) > 1] 150 | meta = [ 151 | [float(n) for n in line.rstrip().split(' ')] 152 | for line in lines] 153 | group[i] = np.array(meta, dtype=np.float32) 154 | bar.next() 155 | bar.finish() 156 | 157 | def convert( 158 | self, setup=True, voxels=True, meta=True, renderings=True, 159 | overwrite=False): 160 | if setup: 161 | self.setup(overwrite=overwrite) 162 | if voxels: 163 | self.convert_voxels(overwrite=overwrite) 164 | if meta: 165 | self.convert_meta(overwrite=overwrite) 166 | if renderings: 167 | self.convert_renderings(overwrite=overwrite) 168 | 169 | 170 | class Hdf5Manager(object): 171 | def __init__(self, cat_id): 172 | self._file = None 173 | self._cat_id = cat_id 174 | 175 | @property 176 | def cat_id(self): 177 | return self._cat_id 178 | 179 | def __enter__(self): 180 | self.open() 181 | return self 182 | 183 | def __exit__(self, *args, **kwargs): 184 | self.close() 185 | 186 | def close(self): 187 | self._file.close() 188 | self._file = None 189 | 190 | @property 191 | def meta_group(self): 192 | return self._file['meta'] 193 | 194 | @property 195 | def renderings_group(self): 196 | return self._file['renderings'] 197 | 198 | @property 199 | def rle_group(self): 200 | return self._file['rle_data'] 201 | 202 | def open(self): 203 | path = get_hdf5_path(self._cat_id) 204 | if not os.path.isfile(path): 205 | _logger.info( 206 | 'No hdf5 data found for category %s. Converting...' 207 | % self._cat_id) 208 | with self.get_converter(mode='a') as converter: 209 | converter.convert() 210 | self._file = h5py.File(path, mode='r') 211 | 212 | def get_converter(self, mode='a'): 213 | return Converter(self._cat_id, 'a') 214 | 215 | def get_voxels(self, example_index): 216 | from util3d.voxel.binvox import RleVoxels 217 | return RleVoxels(self.get_rle_data(example_index), (32,)*3) 218 | 219 | def get_rle_data(self, example_index): 220 | return np.array(self._file['rle_data'][example_index]) 221 | 222 | def get_rendering(self, example_index, view_index): 223 | return numpy_to_image( 224 | np.array(self._file['renderings'][example_index, view_index])) 225 | 226 | def get_meta(self, example_index, view_index): 227 | return np.array(self._file['meta'][example_index, view_index]) 228 | 229 | def get_example_ids(self): 230 | return np.array(self._file['example_ids'], dtype='S32') 231 | -------------------------------------------------------------------------------- /r2n2/path.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import six 7 | from .. 
import path 8 | 9 | data_dir = path.get_data_dir('r2n2') 10 | 11 | 12 | def get_binvox_subpath(cat_id=None, example_id=None): 13 | args = ['ShapeNetVox32'] 14 | if cat_id is not None: 15 | args.append(cat_id) 16 | if example_id is not None: 17 | args.extend((example_id, 'model.binvox')) 18 | return os.path.join(*args) 19 | 20 | 21 | def get_renderings_subpath(cat_id=None, example_id=None, data_id=None): 22 | args = ['ShapeNetRendering'] 23 | if cat_id is not None: 24 | args.append(cat_id) 25 | if example_id is not None: 26 | args.append(example_id) 27 | if data_id is not None: 28 | args.append('rendering') 29 | if isinstance(data_id, int): 30 | args.append('%02d.png' % data_id) 31 | elif isinstance(data_id, six.string_types): 32 | args.append(data_id) 33 | else: 34 | raise ValueError('Unrecognized data_id `%s`' % data_id) 35 | 36 | return os.path.join(*args) 37 | 38 | 39 | def get_binvox_path(cat_id=None, example_id=None): 40 | return os.path.join(data_dir, get_binvox_subpath(cat_id, example_id)) 41 | 42 | 43 | def get_renderings_path(cat_id=None, example_id=None, data_id=None): 44 | return os.path.join( 45 | data_dir, get_renderings_subpath(cat_id, example_id, data_id)) 46 | -------------------------------------------------------------------------------- /r2n2/scripts/download.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import shapenet.r2n2.tgz as tgz 7 | 8 | 9 | tgz.BinvoxManager().download() 10 | tgz.RenderingsManager().download() 11 | -------------------------------------------------------------------------------- /r2n2/scripts/extract_renderings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import shapenet.r2n2.tgz as tgz # NOQA 6 | 7 | 8 | with tgz.RenderingsManager() as rm: 9 | rm.extract() 10 | -------------------------------------------------------------------------------- /r2n2/scripts/hdf5.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import shapenet.r2n2.tgz as tgz 7 | import shapenet.r2n2.hdf5 as hdf5 8 | 9 | overwrite = False 10 | # overwrite = True 11 | 12 | with tgz.BinvoxManager() as bm: 13 | example_ids = bm.get_example_ids() 14 | 15 | 16 | n = len(example_ids) 17 | for i, (cat_id, ex_ids) in enumerate(example_ids.items()): 18 | print('Converting %s, %d / %d' % (cat_id, i+1, n)) 19 | with hdf5.Converter(cat_id, ex_ids, mode='a') as converter: 20 | converter.convert(overwrite=overwrite) 21 | -------------------------------------------------------------------------------- /r2n2/split.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | def get_start_end_frac(mode): 7 | if mode == 'train': 8 | start_frac = 0 9 | end_frac = 0.8 10 | elif mode == 'test': 11 | start_frac = 0.8 12 | end_frac = 1 13 | else: 14 | raise ValueError( 15 | '`mode` must be one of ("train", "test"), got "%s" % mode') 16 | return start_frac, end_frac 17 | 18 | 19 | def 
get_start_end_index(n_examples, mode): 20 | start, end = get_start_end_frac(mode) 21 | return int(n_examples*start), int(n_examples*end) 22 | 23 | 24 | def _split(sorted_keys, mode): 25 | start, end = get_start_end_index(len(sorted_keys), mode) 26 | return sorted_keys[start:end] 27 | 28 | 29 | def split(example_ids, mode): 30 | example_ids = sorted(example_ids) 31 | return _split(example_ids, mode) 32 | 33 | 34 | def split_indices(example_ids, mode): 35 | enumerated = sorted(enumerate(example_ids), key=lambda x: x[1]) 36 | indices, keys = zip(*_split(enumerated, mode)) 37 | return indices 38 | -------------------------------------------------------------------------------- /r2n2/tgz.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import tarfile 5 | import os 6 | import collections 7 | import logging 8 | from .path import data_dir, get_renderings_subpath, get_binvox_subpath 9 | 10 | 11 | _logger = logging.getLogger(__name__) 12 | 13 | renderings_path = os.path.join(data_dir, 'ShapeNetRendering.tgz') 14 | 15 | 16 | binvox_path = os.path.join(data_dir, 'ShapeNetVox32.tgz') 17 | 18 | 19 | # binvox_url = 'ftp://cs.stanford.edu/cs/cvgl/ShapeNetVox32.tgz' 20 | # renderings_url = 'ftp://cs.stanford.edu/cs/cvgl/ShapeNetRendering.tgz' 21 | binvox_url = 'http://cvgl.stanford.edu/data2/ShapeNetVox32.tgz' 22 | renderings_url = 'http://cvgl.stanford.edu/data2/ShapeNetRendering.tgz' 23 | 24 | 25 | class ArchiveManager(object): 26 | def __init__(self, url, path, mode='r:gz'): 27 | self._path = path 28 | self._mode = mode 29 | self._url = url 30 | self._file = None 31 | self._members = None 32 | 33 | @property 34 | def path(self): 35 | return self._path 36 | 37 | @property 38 | def url(self): 39 | return self._url 40 | 41 | @property 42 | def mode(self): 43 | return self._mode 44 | 45 | def __enter__(self): 46 | self.open() 47 | return self 48 | 49 | def __exit__(self, *args, **kwargs): 50 | self.close() 51 | 52 | def open(self): 53 | if not os.path.isfile(self._path): 54 | self.download() 55 | assert(os.path.isfile(self._path)) 56 | self._file = tarfile.open(self._path, self._mode) 57 | self._members = {m.name: m for m in self._file.getmembers()} 58 | 59 | def close(self): 60 | self._file.close() 61 | self._members = self._file = None 62 | 63 | def extract(self): 64 | _logger.info('Extracting contents of %s' % self.path) 65 | self._file.extractall(path=data_dir) 66 | 67 | @property 68 | def is_open(self): 69 | return not self.is_closed 70 | 71 | @property 72 | def is_closed(self): 73 | return self._file is None 74 | 75 | def load_subpath(self, subpath): 76 | return self._file.extractfile(self._members[subpath]) 77 | 78 | def download(self): 79 | import wget 80 | path = self.path 81 | if os.path.isfile(path): 82 | _logger.info( 83 | 'Data already present at %s. 
Skipping download' % path) 84 | else: 85 | url = self.url 86 | _logger.info('Downloading from %s' % url) 87 | try: 88 | wget.download(url, out=path) 89 | except IOError: 90 | print('Problem downloading from %s' % url) 91 | raise 92 | 93 | 94 | class BinvoxManager(ArchiveManager): 95 | def __init__(self): 96 | super(BinvoxManager, self).__init__( 97 | binvox_url, binvox_path, 'r:gz') 98 | self._cat_ids = None 99 | self._example_ids = {} 100 | 101 | def __getitem__(self, args): 102 | cat_id, example_id = args 103 | return self.load(cat_id, example_id) 104 | 105 | def load(self, cat_id, example_id): 106 | from util3d.voxel.binvox import Voxels 107 | fp = self.load_subpath(get_binvox_subpath(cat_id, example_id)) 108 | vox = Voxels.from_file(fp) 109 | return vox 110 | 111 | def get_example_ids(self): 112 | example_ids = {} 113 | for name in self._members: 114 | args = name.split('/') 115 | if args[-1].endswith('.binvox'): 116 | cat_id, example_id = args[1:3] 117 | example_ids.setdefault(cat_id, []).append(example_id) 118 | return example_ids 119 | 120 | 121 | # https://github.com/chrischoy/3D-R2N2/issues/12 122 | meta = collections.namedtuple('RenderingMeta', [ 123 | 'azimuth', 124 | 'elevation', 125 | 'in_plane_rotation', 126 | 'distance', 127 | 'field_of_view', 128 | ]) 129 | 130 | 131 | def parse_meta_line(line): 132 | return tuple(float(x) for x in line.rstrip().split(' ')) 133 | 134 | 135 | class RenderingsManager(ArchiveManager): 136 | def __init__(self): 137 | super(RenderingsManager, self).__init__( 138 | renderings_url, renderings_path, 'r:gz') 139 | self._cat_ids = None 140 | self._example_ids = {} 141 | 142 | def __getitem__(self, args): 143 | cat_id, example_id, view_index = args 144 | return self.load(cat_id, example_id, view_index) 145 | 146 | def load(self, cat_id, example_id, view_index): 147 | from PIL import Image 148 | fp = self.load_subpath(get_renderings_subpath( 149 | cat_id, example_id, view_index)) 150 | return Image.open(fp) 151 | 152 | def get_meta_lines(self, cat_id, example_id): 153 | meta_path = get_renderings_subpath( 154 | cat_id, example_id, 'rendering_metadata.txt') 155 | fp = self.load_subpath(meta_path) 156 | lines = fp.readlines() 157 | return tuple(line for line in lines if len(line) > 1) 158 | 159 | def get_metas(self, cat_id, example_id): 160 | lines = self.get_meta_lines(cat_id, example_id) 161 | return tuple(parse_meta_line(line) for line in lines) 162 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | class LengthedGenerator(object): 7 | """Generator with an efficient, fixed length.""" 8 | def __init__(self, gen, gen_len): 9 | self._gen = gen 10 | self._len = gen_len 11 | 12 | def __iter__(self): 13 | return iter(self._gen) 14 | 15 | def __len__(self): 16 | return self._len 17 | --------------------------------------------------------------------------------
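Note: the following is a minimal usage sketch for the `LengthedGenerator` class in `/util.py` above; it is not repository code, and the sample ids are made up. It assumes the package is importable as `shapenet`, as in the example scripts. The class simply pairs a generator with a length known in advance so that `len()` works, e.g. when sizing a progress bar, without materializing the underlying sequence:
```python
# Illustrative only: wrap a generator whose length is known up front so that
# consumers can call len() on it while still iterating lazily.
from shapenet.util import LengthedGenerator

example_ids = ['1a2b', '3c4d', '5e6f']  # hypothetical ids
gen = LengthedGenerator((i.upper() for i in example_ids), len(example_ids))
print(len(gen))   # 3
print(list(gen))  # ['1A2B', '3C4D', '5E6F']
```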