├── .gitignore ├── README.md ├── data ├── modelnet10.py └── put_modelnet10_dataset_here.txt ├── train.py ├── utils ├── README.md ├── __pycache__ │ └── binvox_rw.cpython-36.pyc ├── binvox ├── binvox_rw.py ├── off2binvox.py └── viewvox └── voxnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Jupyter Notebook 59 | .ipynb_checkpoints 60 | 61 | # IPython 62 | profile_default/ 63 | ipython_config.py 64 | 65 | # pyenv 66 | # For a library or package, you might want to ignore these files since the code is 67 | # intended to run in multiple environments; otherwise, check them in: 68 | # .python-version 69 | 70 | # pipenv 71 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
72 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 73 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 74 | # install all needed dependencies. 75 | #Pipfile.lock 76 | 77 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 78 | __pypackages__/ 79 | 80 | # Celery stuff 81 | celerybeat-schedule 82 | celerybeat.pid 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | .dmypy.json 109 | dmypy.json 110 | 111 | # Pyre type checker 112 | .pyre/ 113 | 114 | # pytype static type analyzer 115 | .pytype/ 116 | 117 | # ignore softlink `data/ModelNet10` 118 | data/ModelNet10 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VoxNet.pytorch 2 | 3 | A PyTorch implementation of "VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition". 4 | 5 | ## Prepare Data 6 | 7 | 1. Download the `ModelNet10` dataset and unzip it into `data/`, 8 | 9 | so that the class folders sit at `data/ModelNet10/bathtub`, ... 10 | 11 | 2. Convert the `*.off` files to `*.binvox` files. 12 | ```shell 13 | cd utils 14 | chmod +x binvox 15 | python off2binvox.py 16 | ``` 17 | 18 | ## Train 19 | Train VoxNet; the model weights are saved to `cls/`. 20 | ```shell 21 | python train.py 22 | ``` 23 | 24 | ## Reference 25 | ``` 26 | @inproceedings{maturana_iros_2015, 27 | author = "Maturana, D.
and Scherer, S.", 28 | title = "{VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition}", 29 | booktitle = "{IROS}", 30 | year = "2015", 31 | pdf = "/extra/voxnet_maturana_scherer_iros15.pdf", 32 | } 33 | ``` 34 | -------------------------------------------------------------------------------- /data/modelnet10.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import os 3 | import sys 4 | import numpy as np 5 | import glob 6 | import re 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | sys.path.insert(0, '../utils/') 10 | sys.path.insert(0, './utils/') 11 | import binvox_rw 12 | 13 | 14 | class ModelNet10(Dataset): 15 | def __init__(self, data_root, n_classes, idx2cls, split='train'): 16 | """Index the voxelized models stored as data_root/CLASS/SPLIT/CLASS_N.binvox. 17 | Args: 18 | data_root (str): dataset root with one folder per class; n_classes (int): number of classes (stored, not used for indexing); idx2cls (dict): class index -> class (folder) name; split (str, optional): 'train' or 'test'. Defaults to 'train'. 19 | """ 20 | self.data_root = data_root 21 | self.n_classes = n_classes 22 | self.samples_str = [] 23 | self.cls2idx = {} 24 | for k, v in idx2cls.items(): 25 | self.cls2idx.update({v: k}) 26 | for sample_str in glob.glob(os.path.join(data_root, v, split, '*.binvox')): 27 | if re.fullmatch(re.escape(v) + r"_\d+\.binvox", os.path.basename(sample_str)): # keep exactly CLASS_N.binvox; the class name itself may contain '_' (e.g. night_stand) 28 | self.samples_str.append(sample_str) 29 | print(self.cls2idx) 30 | 31 | def __getitem__(self, idx): 32 | sample_name = self.samples_str[idx] 33 | cls_name = re.split(r"_\d+\.binvox", os.path.basename(sample_name))[0] 34 | cls_idx = self.cls2idx[cls_name] 35 | with open(sample_name, 'rb') as file: 36 | data = np.int32(binvox_rw.read_as_3d_array(file).data) 37 | data = data[np.newaxis, :] # add a leading channel axis 38 | 39 | sample = {'voxel': data, 'cls_idx': cls_idx} 40 | 41 | return sample 42 | 43 | def __len__(self): 44 | return len(self.samples_str) 45 | 46 | 47 | if __name__ == "__main__": 48 | idx2cls = {0: 'bathtub', 1: 'chair', 2: 'dresser', 3: 'night_stand', 49 | 4: 'sofa', 5: 'toilet', 6: 'bed', 7: 'desk', 8: 'monitor', 9: 'table'} 50 | 51 | data_root = './ModelNet10' 52 |
53 | dataset = ModelNet10(data_root=data_root, n_classes=10, idx2cls=idx2cls, split='train') 54 | cnt = len(dataset) 55 | 56 | data, cls_idx = dataset[0]['voxel'], dataset[0]['cls_idx'] 57 | print(f"length: {cnt}\nsample data: {data}\nsample cls: {cls_idx}") 58 | -------------------------------------------------------------------------------- /data/put_modelnet10_dataset_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MonteYang/VoxNet.pytorch/98e9a6e7d5fd584fccaf2fd3c4616346fa201ed2/data/put_modelnet10_dataset_here.txt -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | ''' 4 | File: train.py 5 | Created: 2020-01-21 21:32:40 6 | Author : Yangmaonan 7 | Email : 59786677@qq.com 8 | Description: train VoxNet on ModelNet10 and report test-set accuracy 9 | ''' 10 | 11 | from __future__ import print_function 12 | import argparse 13 | import sys 14 | import os 15 | import random 16 | import torch 17 | import torch.nn.parallel 18 | import torch.optim as optim 19 | from torch.utils.data import DataLoader 20 | import torch.nn.functional as F 21 | from tqdm import tqdm 22 | from voxnet import VoxNet 23 | sys.path.insert(0, './data/') 24 | from modelnet10 import ModelNet10 25 | 26 | CLASSES = { 27 | 0: 'bathtub', 28 | 1: 'chair', 29 | 2: 'dresser', 30 | 3: 'night_stand', 31 | 4: 'sofa', 32 | 5: 'toilet', 33 | 6: 'bed', 34 | 7: 'desk', 35 | 8: 'monitor', 36 | 9: 'table' 37 | } 38 | N_CLASSES = len(CLASSES) 39 | 40 | def blue(x): return '\033[94m' + x + '\033[0m' # wrap text in ANSI blue colour codes for terminal output 41 | 42 | # Parse command-line arguments 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--data-root', type=str, default='data/ModelNet10', help="dataset path") 45 | parser.add_argument('--batchSize', type=int, default=256, help='input batch size') 46 | parser.add_argument('--workers', type=int,
help='number of data loading workers', default=4) 47 | parser.add_argument('--n-epoch', type=int, default=30, help='number of epochs to train for') 48 | parser.add_argument('--outf', type=str, default='cls', help='output folder') 49 | parser.add_argument('--model', type=str, default='', help='model path') 50 | opt = parser.parse_args() 51 | # print(opt) 52 | 53 | # Create the output folder; any OSError (e.g. it already exists) is ignored 54 | try: 55 | os.makedirs(opt.outf) 56 | except OSError: 57 | pass 58 | 59 | # Draw a random seed, print it, and seed Python's and torch's RNGs with it 60 | opt.manualSeed = random.randint(1, 10000) 61 | print("Random Seed: ", opt.manualSeed) 62 | random.seed(opt.manualSeed) 63 | torch.manual_seed(opt.manualSeed) 64 | 65 | # Datasets and loaders (the test loader is shuffled, so each periodic single-batch test check sees a different batch) 66 | train_dataset = ModelNet10(data_root=opt.data_root, n_classes=N_CLASSES, idx2cls=CLASSES, split='train') 67 | test_dataset = ModelNet10(data_root=opt.data_root, n_classes=N_CLASSES, idx2cls=CLASSES, split='test') 68 | 69 | train_dataloader = DataLoader(train_dataset, batch_size=opt.batchSize, shuffle=True, num_workers=int(opt.workers)) 70 | test_dataloader = DataLoader(test_dataset, batch_size=opt.batchSize, shuffle=True, num_workers=int(opt.workers)) 71 | 72 | # VoxNet 73 | voxnet = VoxNet(n_classes=N_CLASSES) 74 | 75 | print(voxnet) 76 | 77 | # Optionally resume from previously saved weights 78 | if opt.model != '': 79 | voxnet.load_state_dict(torch.load(opt.model)) 80 | 81 | # Optimizer 82 | optimizer = optim.Adam(voxnet.parameters(), lr=1e-4) 83 | # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) 84 | voxnet.cuda() 85 | 86 | num_batch = (len(train_dataset) + opt.batchSize - 1) // opt.batchSize # batches per epoch; the last one may be partial 87 | print(num_batch) 88 | 89 | for epoch in range(opt.n_epoch): 90 | # scheduler.step() 91 | for i, sample in enumerate(train_dataloader, 0): 92 | # Fetch a batch and move it to the GPU 93 | voxel, cls_idx = sample['voxel'], sample['cls_idx'] 94 | voxel, cls_idx = voxel.cuda(), cls_idx.cuda() 95 | voxel = voxel.float() # voxels arrive as int (0/1) and must be float for the conv layers, e.g. torch.Size([256, 1, 32, 32, 32]) 96 | 97 | # Reset accumulated gradients 98 | optimizer.zero_grad() 99 | 100 | # Switch the network to training mode (enables dropout) 101 | voxnet = voxnet.train() 102 | pred = voxnet(voxel) #
torch.Size([256, 10]) 103 | 104 | # Cross-entropy loss on the raw logits 105 | 106 | loss = F.cross_entropy(pred, cls_idx) 107 | 108 | # Backpropagate and update the weights 109 | loss.backward() 110 | optimizer.step() 111 | 112 | # Prediction accuracy on this batch (divide by the actual batch size: the last batch may be smaller) 113 | pred_choice = pred.data.max(1)[1] 114 | correct = pred_choice.eq(cls_idx.data).cpu().sum() 115 | print('[%d: %d/%d] train loss: %f accuracy: %f' % 116 | (epoch, i, num_batch, loss.item(), correct.item() / float(cls_idx.size(0)))) 117 | 118 | # Every 5 training batches, evaluate on one shuffled test batch 119 | if i % 5 == 0: 120 | j, sample = next(enumerate(test_dataloader, 0)) 121 | voxel, cls_idx = sample['voxel'], sample['cls_idx'] 122 | voxel, cls_idx = voxel.cuda(), cls_idx.cuda() 123 | voxel = voxel.float() # convert to float, torch.Size([256, 1, 32, 32, 32]) 124 | voxnet = voxnet.eval() 125 | with torch.no_grad(): pred = voxnet(voxel) 126 | loss = F.cross_entropy(pred, cls_idx) # the model outputs logits, so use the training loss (nll_loss expects log-probabilities) 127 | pred_choice = pred.data.max(1)[1] 128 | correct = pred_choice.eq(cls_idx.data).cpu().sum() 129 | print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, 130 | blue('test'), loss.item(), correct.item()/float(cls_idx.size(0)))) 131 | 132 | # Save the model weights 133 | torch.save(voxnet.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch)) 134 | 135 | 136 | # After training, evaluate on the whole test set 137 | total_correct = 0 138 | total_testset = 0 139 | 140 | for i, data in tqdm(enumerate(test_dataloader, 0)): 141 | voxel, cls_idx = data['voxel'], data['cls_idx'] 142 | voxel, cls_idx = voxel.cuda(), cls_idx.cuda() 143 | voxel = voxel.float() # convert to float, torch.Size([256, 1, 32, 32, 32]) 144 | 145 | voxnet = voxnet.eval() 146 | with torch.no_grad(): pred = voxnet(voxel) 147 | pred_choice = pred.data.max(1)[1] 148 | correct = pred_choice.eq(cls_idx.data).cpu().sum() 149 | total_correct += correct.item() 150 | total_testset += voxel.size()[0] 151 | 152 | print("final accuracy {}".format(total_correct / float(total_testset))) 153 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # binvox 2 |
```shell 3 | Usage: binvox [-d <voxel dimension>] [-t <voxel file type>] [-c] [-v] <model filespec> 4 | -license: show software license 5 | -d: specify voxel grid size (default 256, max 1024)(no max when using -e) 6 | -t: specify voxel file type (default binvox, also supported: hips, mira, vtk, raw, schematic, msh) 7 | -c: z-buffer based carving method only 8 | -dc: dilated carving, stop carving 1 voxel before intersection 9 | -v: z-buffer based parity voting method only (default is both -c and -v) 10 | -e: exact voxelization (any voxel intersecting a convex polygon gets set)(does not use graphics card) 11 | Additional parameters: 12 | -bb <minx> <miny> <minz> <maxx> <maxy> <maxz>: force a different input model bounding box 13 | -ri: remove internal voxels 14 | -cb: center model inside unit cube 15 | -rotx: rotate object 90 degrees ccw around x-axis before voxelizing 16 | -rotz: rotate object 90 degrees cw around z-axis before voxelizing 17 | both -rotx and -rotz can be used multiple times 18 | -aw: _also_ render the model in wireframe (helps with thin parts) 19 | -fit: only write the voxels in the voxel bounding box 20 | -bi <id>: when converting to schematic, use block ID <id> 21 | -mb: when converting using -e from .obj to schematic, parse block ID from material spec 'usemtl blockid_<id>' (ids 1-255 only) 22 | -pb: use offscreen pbuffer instead of onscreen window 23 | -down: downsample voxels by a factor of 2 in each dimension (can be used multiple times) 24 | -dmin <nr>: when downsampling, destination voxel is on if >= <nr> source voxels are (default 4) 25 | Supported 3D model file formats: 26 | VRML V2.0: almost fully supported 27 | UG, OBJ, OFF, DXF, XGL, POV, BREP, PLY, JOT: only polygons supported 28 | Example: 29 | binvox -c -d 200 -t mira plane.wrl 30 | ``` 31 | 32 | 33 | # viewvox 34 | 35 | ```shell 36 | Usage 37 | viewvox [-ki] <model filespec> 38 | 39 | -ki: keep internal voxels (removed by default) 40 | 41 | Mouse left button = rotate 42 | middle = pan 43 | right = zoom 44 | Key r = reset view 45 | arrow keys = move 1 voxel step along x (left, right) or y (up, down) 46 | =,- =
move 1 voxel step along z 47 | 48 | q = quit 49 | 50 | a = toggle alternating colours 51 | p = toggle between orthographic and perspective projection 52 | x, y, z = set camera looking down X, Y, or Z axis 53 | X, Y, Z = set camera looking up X, Y, or Z axis 54 | 1 = toggle show x, y, and z coordinates 55 | 56 | s = show single slice 57 | n = show both/above/below slice neighbour(s) 58 | t = toggle neighbour transparency 59 | j = move slice down 60 | k = move slice up 61 | g = toggle show grid at slice level 62 | 63 | A lot of the key commands were added to make viewvox more useful when building voxel models in minecraft.http://www.patrickmin.com/minecraft 64 | 65 | ``` 66 | 67 | -------------------------------------------------------------------------------- /utils/__pycache__/binvox_rw.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MonteYang/VoxNet.pytorch/98e9a6e7d5fd584fccaf2fd3c4616346fa201ed2/utils/__pycache__/binvox_rw.cpython-36.pyc -------------------------------------------------------------------------------- /utils/binvox: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MonteYang/VoxNet.pytorch/98e9a6e7d5fd584fccaf2fd3c4616346fa201ed2/utils/binvox -------------------------------------------------------------------------------- /utils/binvox_rw.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2012 Daniel Maturana 2 | # This file is part of binvox-rw-py. 3 | # 4 | # binvox-rw-py is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 
8 | # 9 | # binvox-rw-py is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with binvox-rw-py. If not, see <http://www.gnu.org/licenses/>. 16 | # 17 | 18 | """ 19 | Binvox to Numpy and back. 20 | 21 | 22 | >>> import numpy as np 23 | >>> import binvox_rw 24 | >>> with open('chair.binvox', 'rb') as f: 25 | ... m1 = binvox_rw.read_as_3d_array(f) 26 | ... 27 | >>> m1.dims 28 | [32, 32, 32] 29 | >>> m1.scale 30 | 41.133000000000003 31 | >>> m1.translate 32 | [0.0, 0.0, 0.0] 33 | >>> with open('chair_out.binvox', 'wb') as f: 34 | ... m1.write(f) 35 | ... 36 | >>> with open('chair_out.binvox', 'rb') as f: 37 | ... m2 = binvox_rw.read_as_3d_array(f) 38 | ... 39 | >>> m1.dims==m2.dims 40 | True 41 | >>> m1.scale==m2.scale 42 | True 43 | >>> m1.translate==m2.translate 44 | True 45 | >>> np.all(m1.data==m2.data) 46 | True 47 | 48 | >>> with open('chair.binvox', 'rb') as f: 49 | ... md = binvox_rw.read_as_3d_array(f) 50 | ... 51 | >>> with open('chair.binvox', 'rb') as f: 52 | ... ms = binvox_rw.read_as_coord_array(f) 53 | ... 54 | >>> data_ds = binvox_rw.dense_to_sparse(md.data) 55 | >>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32) 56 | >>> np.all(data_sd==md.data) 57 | True 58 | >>> # the ordering of elements returned by numpy.nonzero changes with axis 59 | >>> # ordering, so to compare for equality we first lexically sort the voxels. 60 | >>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)]) 61 | True 62 | """ 63 | 64 | import numpy as np 65 | 66 | class Voxels(object): 67 | """ Holds a binvox model. 68 | data is either a three-dimensional numpy boolean array (dense representation) 69 | or a two-dimensional 3 x N numpy array of voxel coordinates (coordinate representation).
70 | 71 | dims, translate and scale are the model metadata. 72 | 73 | dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. 74 | 75 | scale and translate relate the voxels to the original model coordinates. 76 | 77 | To translate voxel coordinates i, j, k to original coordinates x, y, z: 78 | 79 | x_n = (i+.5)/dims[0] 80 | y_n = (j+.5)/dims[1] 81 | z_n = (k+.5)/dims[2] 82 | x = scale*x_n + translate[0] 83 | y = scale*y_n + translate[1] 84 | z = scale*z_n + translate[2] 85 | 86 | """ 87 | 88 | def __init__(self, data, dims, translate, scale, axis_order): 89 | self.data = data 90 | self.dims = dims 91 | self.translate = translate 92 | self.scale = scale 93 | if axis_order not in ('xzy', 'xyz'): raise ValueError('Unsupported voxel model axis order') # explicit check: an assert would be stripped under python -O 94 | self.axis_order = axis_order 95 | 96 | def clone(self): # independent copy: data array, dims and translate lists are copied 97 | data = self.data.copy() 98 | dims = self.dims[:] 99 | translate = self.translate[:] 100 | return Voxels(data, dims, translate, self.scale, self.axis_order) 101 | 102 | def write(self, fp): # delegate to the module-level write() 103 | write(self, fp) 104 | 105 | def read_header(fp): 106 | """ Read binvox header. Mostly meant for internal use. 107 | Returns (dims, translate, scale) and leaves fp at the start of the run-length voxel data. """ 108 | line = fp.readline().strip() 109 | if not line.startswith(b'#binvox'): 110 | raise IOError('Not a binvox file') 111 | dims = list(map(int, fp.readline().strip().split(b' ')[1:])) 112 | translate = list(map(float, fp.readline().strip().split(b' ')[1:])) 113 | scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0] 114 | line = fp.readline() # consume the 'data' line that precedes the voxel payload 115 | return dims, translate, scale 116 | 117 | def read_as_3d_array(fp, fix_coords=True): 118 | """ Read binary binvox format as array. 119 | 120 | Returns the model with accompanying metadata. 121 | 122 | Voxels are stored in a three-dimensional numpy array, which is simple and 123 | direct, but may use a lot of memory for large models. (Storage requirements 124 | are d^3 bytes, where d is the dimension of the binvox model, since Numpy 125 | boolean arrays use a byte per element).
126 | 127 | Doesn't do any checks on input except for the '#binvox' line. 128 | """ 129 | dims, translate, scale = read_header(fp) 130 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 131 | # if just using reshape() on the raw data: 132 | # indexing the array as array[i,j,k], the indices map into the 133 | # coords as: 134 | # i -> x 135 | # j -> z 136 | # k -> y 137 | # if fix_coords is true, then data is rearranged so that 138 | # mapping is 139 | # i -> x 140 | # j -> y 141 | # k -> z 142 | values, counts = raw_data[::2], raw_data[1::2] # run-length pairs: (value, repeat count) 143 | data = np.repeat(values, counts).astype(bool) # builtin bool: the np.bool alias was removed in NumPy 1.24 144 | data = data.reshape(dims) 145 | if fix_coords: 146 | # xzy to xyz TODO the right thing 147 | data = np.transpose(data, (0, 2, 1)) 148 | axis_order = 'xyz' 149 | else: 150 | axis_order = 'xzy' 151 | return Voxels(data, dims, translate, scale, axis_order) 152 | 153 | def read_as_coord_array(fp, fix_coords=True): 154 | """ Read binary binvox format as coordinates. 155 | 156 | Returns binvox model with voxels in a "coordinate" representation, i.e. a 157 | 3 x N array where N is the number of nonzero voxels. Each column 158 | corresponds to a nonzero voxel; the 3 rows are its (x, y, z) coordinates 159 | when fix_coords is True (the default), otherwise (x, z, y), which is the 160 | order in which the binvox format lays out data. Note that coordinates refer to the binvox voxels, without any 161 | scaling or translation. 162 | 163 | Use this to save memory if your model is very sparse (mostly empty). 164 | 165 | Doesn't do any checks on input except for the '#binvox' line.
166 | """ 167 | dims, translate, scale = read_header(fp) 168 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 169 | 170 | values, counts = raw_data[::2], raw_data[1::2] 171 | 172 | sz = np.prod(dims) 173 | index, end_index = 0, 0 174 | end_indices = np.cumsum(counts) 175 | indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype) 176 | 177 | values = values.astype(bool) 178 | indices = indices[values] 179 | end_indices = end_indices[values] 180 | 181 | nz_voxels = [] 182 | for index, end_index in zip(indices, end_indices): 183 | nz_voxels.extend(range(index, end_index)) 184 | nz_voxels = np.array(nz_voxels, dtype=np.int64) # integer dtype even when there are no filled voxels 185 | # TODO are these dims correct? 186 | # according to docs, 187 | # index = x * wxh + z * width + y; // wxh = width * height = d * d 188 | 189 | x = nz_voxels // (dims[0]*dims[1]) # integer division: '/' yields float coordinates on Python 3 190 | zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y 191 | z = zwpy // dims[0] 192 | y = zwpy % dims[0] 193 | if fix_coords: 194 | data = np.vstack((x, y, z)) 195 | axis_order = 'xyz' 196 | else: 197 | data = np.vstack((x, z, y)) 198 | axis_order = 'xzy' 199 | 200 | #return Voxels(data, dims, translate, scale, axis_order) 201 | return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order) 202 | 203 | def dense_to_sparse(voxel_data, dtype=int): 204 | """ From dense representation to sparse (coordinate) representation. 205 | No coordinate reordering.
206 | """ 207 | if voxel_data.ndim!=3: 208 | raise ValueError('voxel_data is wrong shape; should be 3D array.') 209 | return np.asarray(np.nonzero(voxel_data), dtype) 210 | 211 | def sparse_to_dense(voxel_data, dims, dtype=bool): # 3xN coordinates -> dense array of shape dims (a scalar dims means a cube) 212 | if voxel_data.ndim!=2 or voxel_data.shape[0]!=3: 213 | raise ValueError('voxel_data is wrong shape; should be 3xN array.') 214 | if np.isscalar(dims): 215 | dims = [dims]*3 216 | dims = np.atleast_2d(dims).T # 3x1 column so it broadcasts against the 3xN coordinates 217 | # truncate to integers 218 | xyz = voxel_data.astype(int) 219 | # discard voxels that fall outside dims 220 | valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0) 221 | xyz = xyz[:,valid_ix] 222 | out = np.zeros(dims.flatten(), dtype=dtype) 223 | out[tuple(xyz)] = True 224 | return out 225 | 226 | #def get_linear_index(x, y, z, dims): 227 | #""" Assuming xzy order. (y increasing fastest. 228 | #TODO ensure this is right when dims are not all same 229 | #""" 230 | #return x*(dims[1]*dims[2]) + z*dims[1] + y 231 | 232 | def write(voxel_model, fp): 233 | """ Write binary binvox format. 234 | 235 | Note that when saving a model in sparse (coordinate) format, it is first 236 | converted to dense format. 237 | 238 | Doesn't check if the model is 'sane'.
239 | The file object fp must be opened in binary mode ('wb'): everything is written as bytes. 240 | """ 241 | if voxel_model.data.ndim==2: 242 | # TODO avoid conversion to dense 243 | dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) 244 | else: 245 | dense_voxel_data = voxel_model.data 246 | 247 | if voxel_model.axis_order not in ('xzy', 'xyz'): # validate before writing anything to fp 248 | raise ValueError('Unsupported voxel model axis order') 249 | fp.write(b'#binvox 1\n') 250 | fp.write(('dim '+' '.join(map(str, voxel_model.dims))+'\n').encode('ascii')) 251 | fp.write(('translate '+' '.join(map(str, voxel_model.translate))+'\n').encode('ascii')) 252 | fp.write(('scale '+str(voxel_model.scale)+'\n').encode('ascii')) 253 | fp.write(b'data\n') 254 | 255 | if voxel_model.axis_order=='xzy': 256 | voxels_flat = dense_voxel_data.flatten() 257 | elif voxel_model.axis_order=='xyz': 258 | voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() 259 | 260 | # keep a sort of state machine for writing run length encoding 261 | state = voxels_flat[0] 262 | ctr = 0 263 | for c in voxels_flat: 264 | if c==state: 265 | ctr += 1 266 | # if ctr hits max, dump 267 | if ctr==255: 268 | fp.write(bytes([int(state)])) 269 | fp.write(bytes([ctr])) 270 | ctr = 0 271 | else: 272 | # if switch state, dump 273 | fp.write(bytes([int(state)])) 274 | fp.write(bytes([ctr])) 275 | state = c 276 | ctr = 1 277 | # flush out remainders 278 | if ctr > 0: 279 | fp.write(bytes([int(state)])) 280 | fp.write(bytes([ctr])) 281 | 282 | if __name__ == '__main__': 283 | import doctest 284 | doctest.testmod() 285 | -------------------------------------------------------------------------------- /utils/off2binvox.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | ''' 4 | File: off2binvox.py 5 | Created: 2020-01-21 21:32:40 6 | Author : Yangmaonan 7 | Email : 59786677@qq.com 8 | Description: convert the .off files of the ModelNet10 dataset to .binvox files 9 | ''' 10 | # TODO: use multiple processes to speed up the conversion 11 | import os 12 | import glob 13 | 14 | DATA_ROOT = '../data/ModelNet10' 15 | 16 | CLASSES = {'bathtub', 'chair', 'dresser', 'night_stand', 'sofa',
'toilet', 'bed', 'desk', 'monitor', 'table'} 17 | 18 | for c in CLASSES: 19 | for split in ['test', 'train']: 20 | for off in glob.glob(os.path.join(DATA_ROOT, c, split, '*.off')): 21 | # Skip models whose .binvox file already exists 22 | binname = os.path.join(DATA_ROOT, c, split, os.path.splitext(os.path.basename(off))[0] + '.binvox') 23 | if os.path.exists(binname): 24 | print(binname, "exists, skipping...") 25 | continue 26 | os.system(f'./binvox -d 32 -cb -pb "{off}"') # 32^3 grid, model centered in the unit cube, offscreen pbuffer 27 | -------------------------------------------------------------------------------- /utils/viewvox: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MonteYang/VoxNet.pytorch/98e9a6e7d5fd584fccaf2fd3c4616346fa201ed2/utils/viewvox -------------------------------------------------------------------------------- /voxnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | ''' 4 | File: voxnet.py 5 | Created: 2020-01-21 21:32:40 6 | Author : Yangmaonan 7 | Email : 59786677@qq.com 8 | Description: VoxNet network architecture 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | from collections import OrderedDict 14 | 15 | 16 | class VoxNet(nn.Module): 17 | def __init__(self, n_classes=10, input_shape=(32, 32, 32)): 18 | super(VoxNet, self).__init__() 19 | self.n_classes = n_classes 20 | self.input_shape = input_shape 21 | self.feat = torch.nn.Sequential(OrderedDict([ 22 | ('conv3d_1', torch.nn.Conv3d(in_channels=1, 23 | out_channels=32, kernel_size=5, stride=2)), 24 | ('relu1', torch.nn.ReLU()), 25 | ('drop1', torch.nn.Dropout(p=0.2)), 26 | ('conv3d_2', torch.nn.Conv3d(in_channels=32, out_channels=32, kernel_size=3)), 27 | ('relu2', torch.nn.ReLU()), 28 | ('pool2', torch.nn.MaxPool3d(2)), 29 | ('drop2', torch.nn.Dropout(p=0.3)) 30 | ])) 31 | with torch.no_grad(): x = self.feat(torch.rand((1, 1) + input_shape)) # dummy forward pass to infer the flattened feature size 32 | dim_feat = 1 33 | for n in x.size()[1:]: 34 | dim_feat *= n 35 | 36 | self.mlp =
torch.nn.Sequential(OrderedDict([ 37 | ('fc1', torch.nn.Linear(dim_feat, 128)), 38 | ('relu1', torch.nn.ReLU()), 39 | ('drop3', torch.nn.Dropout(p=0.4)), 40 | ('fc2', torch.nn.Linear(128, self.n_classes)) # no softmax after this layer: forward() returns raw logits 41 | ])) 42 | 43 | def forward(self, x): # x: float voxel batch, expected shape (batch, 1, *input_shape) like the probe input built in __init__ 44 | x = self.feat(x) 45 | x = x.view(x.size(0), -1) # flatten the conv features to (batch, dim_feat) 46 | x = self.mlp(x) 47 | return x # (batch, n_classes) unnormalized class scores 48 | 49 | 50 | if __name__ == "__main__": 51 | voxnet = VoxNet() 52 | data = torch.rand([256, 1, 32, 32, 32]) 53 | voxnet(data) # smoke test: one forward pass on random input 54 | --------------------------------------------------------------------------------