├── .gitignore ├── 3dgan.py ├── LICENSE ├── README.md ├── dataset.py ├── nets.py ├── vis.py ├── visualization ├── __init__.py ├── util.py └── util_vtk.py └── visualize.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | train_log/ 107 | model/ 108 | log/ 109 | img/ 110 | *.mat 111 | *.jpg 112 | *.png 113 | 114 | 115 | -------------------------------------------------------------------------------- /3dgan.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Created Time: 2018/05/11 11:50:23 3 | # Author: Taihong Xiao 4 | 5 | from dataset import config, ShapeNet 6 | from nets import Generator, Discriminator 7 | 8 | import os, argparse 9 | import torch 10 | import numpy as np 11 | import scipy.io as sio 12 | from tensorboardX import SummaryWriter 13 | from itertools import chain 14 | 15 | 16 | class _3DGAN(object): 17 | def __init__(self, args, config=config): 18 | self.args = args 19 | self.attribute = args.attribute 20 | self.gpu = args.gpu 21 | self.mode = args.mode 22 | self.restore = args.restore 23 | 24 | # init dataset and networks 25 | self.config = config 26 | self.dataset = ShapeNet(self.attribute) 27 | self.G = Generator() 28 | self.D = Discriminator() 29 | 30 | self.adv_criterion = torch.nn.BCELoss() 31 | 32 | self.set_mode_and_gpu() 33 | self.restore_from_file() 34 | 35 | def set_mode_and_gpu(self): 36 | if self.mode == 'train': 37 | self.G.train() 38 | self.D.train() 39 | if self.gpu: 40 | with torch.cuda.device(self.gpu[0]): 41 | self.G.cuda() 42 | self.D.cuda() 43 | self.adv_criterion.cuda() 44 | 45 | if len(self.gpu) > 1: 46 | self.G = torch.nn.DataParallel(self.G, device_ids=self.gpu) 47 | self.D = torch.nn.DataParallel(self.D, device_ids=self.gpu) 48 | 49 | elif self.mode == 'test': 50 | self.G.eval() 51 | self.D.eval() 52 | if self.gpu: 53 | with torch.cuda.device(self.gpu[0]): 54 | self.G.cuda() 55 | self.D.cuda() 56 | 57 | if len(self.gpu) > 1: 58 | self.G = torch.nn.DataParallel(self.G, device_ids=self.gpu) 59 | self.D = torch.nn.DataParallel(self.D, device_ids=self.gpu) 60 | 61 | else: 62 | raise NotImplementationError() 63 | 64 | def restore_from_file(self): 65 | if self.restore is not None: 66 | ckpt_file_G = os.path.join(self.config.model_dir, 'G_iter_{:06d}.pth'.format(self.restore)) 67 | assert os.path.exists(ckpt_file_G) 68 | self.G.load_state_dict(torch.load(ckpt_file_G)) 69 | 70 | if self.mode == 'train': 71 | ckpt_file_D = os.path.join(self.config.model_dir, 'D_iter_{:06d}.pth'.format(self.restore)) 72 | assert os.path.exists(ckpt_file_D) 73 | self.D.load_state_dict(torch.load(ckpt_file_D)) 74 | 75 | self.start_step = self.restore + 1 76 | else: 77 | self.start_step = 1 78 | 79 | def save_log(self): 80 | scalar_info = { 81 | 'loss_D': self.loss_D, 82 | 'loss_G': self.loss_G, 83 | 'G_lr' : self.G_lr_scheduler.get_lr()[0], 84 | 'D_lr' : self.D_lr_scheduler.get_lr()[0], 85 | } 86 | for key, value in self.G_loss.items(): 87 | scalar_info['G_loss/' + key] = value 88 | 89 | for key, value in self.D_loss.items(): 90 | scalar_info['D_loss/' + key] = value 91 | 92 | for tag, value in scalar_info.items(): 93 | self.writer.add_scalar(tag, value, self.step) 94 | 95 | def save_img(self, save_num=5): 96 | for i in range(save_num): 97 | mdict = { 98 | 'instance': self.fake_X[i,0].data.cpu().numpy() 99 | } 100 | sio.savemat(os.path.join(self.config.img_dir, '{:06d}_{:02d}.mat'.format(self.step, i)), mdict) 101 | 102 | def save_model(self): 103 | torch.save({key: val.cpu() for key, val in self.G.state_dict().items()}, os.path.join(self.config.model_dir, 'G_iter_{:06d}.pth'.format(self.step))) 104 | torch.save({key: val.cpu() for key, val in self.D.state_dict().items()}, os.path.join(self.config.model_dir, 'D_iter_{:06d}.pth'.format(self.step))) 105 | 106 | def train(self): 107 | self.writer = SummaryWriter(self.config.log_dir) 108 | self.opt_G = torch.optim.Adam(self.G.parameters(), lr=self.config.G_lr, betas=(0.5, 0.999)) 109 | self.opt_D = torch.optim.Adam(self.D.parameters(), lr=self.config.D_lr, betas=(0.5, 0.999)) 110 | self.G_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.opt_G, step_size=self.config.step_size, gamma=self.config.gamma) 111 | self.D_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.opt_D, step_size=self.config.step_size, gamma=self.config.gamma) 112 | 113 | # start training 114 | for step in range(self.start_step, 1 + self.config.max_iter): 115 | self.step = step 116 | self.G_lr_scheduler.step() 117 | self.D_lr_scheduler.step() 118 | 119 | self.real_X = next(self.dataset.gen(True)) 120 | self.noise = torch.randn(self.config.nchw[0], 200) 121 | if len(self.gpu): 122 | with torch.cuda.device(self.gpu[0]): 123 | self.real_X = self.real_X.cuda() 124 | self.noise = self.noise.cuda() 125 | 126 | self.fake_X = self.G(self.noise) 127 | 128 | # update D 129 | self.D_real = self.D(self.real_X) 130 | self.D_fake = self.D(self.fake_X.detach()) 131 | self.D_loss = { 132 | 'adv_real': self.adv_criterion(self.D_real, torch.ones_like(self.D_real)), 133 | 'adv_fake': self.adv_criterion(self.D_fake, torch.zeros_like(self.D_fake)), 134 | } 135 | self.loss_D = sum(self.D_loss.values()) 136 | 137 | self.opt_D.zero_grad() 138 | self.loss_D.backward() 139 | self.opt_D.step() 140 | 141 | # update G 142 | self.D_fake = self.D(self.fake_X) 143 | self.G_loss = { 144 | 'adv_fake': self.adv_criterion(self.D_fake, torch.ones_like(self.D_fake)) 145 | } 146 | self.loss_G = sum(self.G_loss.values()) 147 | self.opt_G.zero_grad() 148 | self.loss_G.backward() 149 | self.opt_G.step() 150 | 151 | print('step: {:06d}, loss_D: {:.6f}, loss_G: {:.6f}'.format(self.step, self.loss_D.data.cpu().numpy(), self.loss_G.data.cpu().numpy())) 152 | 153 | if self.step % 100 == 0: 154 | self.save_log() 155 | 156 | if self.step % 1000 == 0: 157 | self.save_img() 158 | self.save_model() 159 | 160 | print('Finished training!') 161 | self.writer.close() 162 | 163 | 164 | if __name__ == "__main__": 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument('-a', '--attribute', type=str, help='Specify category for training.') 167 | parser.add_argument('-g', '--gpu', default=[], nargs='+', type=int, help='Specify GPU ids.') 168 | parser.add_argument('-r', '--restore', default=None, action='store', type=int, help='Specify checkpoint id to restore.') 169 | parser.add_argument('-m', '--mode', default='train', type=str, choices=['train', 'test']) 170 | args = parser.parse_args() 171 | print(args) 172 | 173 | model = _3DGAN(args) 174 | if args.mode == 'train': 175 | model.train() 176 | 177 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Taihong Xiao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3D-GAN-pytorch 2 | 3 | Pytorch Implementation of [3D-GAN](http://3dgan.csail.mit.edu/papers/3dgan_nips.pdf). 4 | 5 | ## Dataset 6 | 7 | ``` 8 | wget http://3dshapenets.cs.princeton.edu/3DShapeNetsCode.zip 9 | unzip 3DShapeNetsCode.zip 10 | mv 3DShapeNetsCode ModelNet 11 | ``` 12 | 13 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Created Time: 2018/05/10 17:22:38 3 | # Author: Taihong Xiao 4 | 5 | import os 6 | import scipy.ndimage as nd 7 | import scipy.io as sio 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits import mplot3d 11 | 12 | import torch 13 | from torch.utils.data import Dataset, DataLoader 14 | from torchvision import transforms, utils 15 | 16 | 17 | class Config: 18 | @property 19 | def data_dir(self): 20 | # data_dir = '/home/xiaoth/datasets/ModelNet/volumetric_data' 21 | data_dir = '/gpfs/share/home/1501210096/datasets/ModelNet/volumetric_data' 22 | if not os.path.exists(data_dir): 23 | os.makedirs(data_dir) 24 | return data_dir 25 | 26 | @property 27 | def exp_dir(self): 28 | exp_dir = os.path.join('train_log') 29 | if not os.path.exists(exp_dir): 30 | os.makedirs(exp_dir) 31 | return exp_dir 32 | 33 | @property 34 | def model_dir(self): 35 | model_dir = os.path.join(self.exp_dir, 'model') 36 | if not os.path.exists(model_dir): 37 | os.makedirs(model_dir) 38 | return model_dir 39 | 40 | @property 41 | def log_dir(self): 42 | log_dir = os.path.join(self.exp_dir, 'log') 43 | if not os.path.exists(log_dir): 44 | os.makedirs(log_dir) 45 | return log_dir 46 | 47 | @property 48 | def img_dir(self): 49 | img_dir = os.path.join(self.exp_dir, 'img') 50 | if not os.path.exists(img_dir): 51 | os.makedirs(img_dir) 52 | return img_dir 53 | 54 | nchw = [32,64,64,64] 55 | 56 | G_lr = 2.5e-3 57 | 58 | D_lr = 1e-5 59 | 60 | step_size = 2000 61 | 62 | gamma = 0.95 63 | 64 | shuffle = True 65 | 66 | num_workers = 0 67 | 68 | max_iter = 20000 69 | 70 | config = Config() 71 | 72 | 73 | class Single(Dataset): 74 | def __init__(self, filenames, config): 75 | self.filenames = filenames 76 | self.config = config 77 | 78 | def __len__(self): 79 | return len(self.filenames) 80 | 81 | def __getitem__(self, idx): 82 | voxel = sio.loadmat(self.filenames[idx])['instance'] 83 | voxel = np.pad(voxel, (1,1), 'constant', constant_values=(0,0)) 84 | if self.config.nchw[-1] != 32: 85 | ratio = self.config.nchw[-1] / 32. 86 | voxel = nd.zoom(voxel, (ratio, ratio, ratio), mode='constant', order=0) 87 | return np.expand_dims(voxel.astype(np.float32), 0) 88 | 89 | def gen(self): 90 | dataloader = DataLoader(self, batch_size=self.config.nchw[0], shuffle=self.config.shuffle, num_workers=self.config.num_workers, drop_last=True) 91 | while True: 92 | for data in dataloader: 93 | yield data 94 | 95 | 96 | class ShapeNet(object): 97 | def __init__(self, category, config=config): 98 | self.category = category 99 | self.config = config 100 | 101 | self.dict = {True: None, False: None} 102 | for is_train in [True, False]: 103 | prefix = os.path.join(self.config.data_dir, category, '30') 104 | data_dir = prefix + '/train' if is_train else prefix + '/test' 105 | filenames = [os.path.join(data_dir, name) for name in os.listdir(data_dir) if name.endswith('.mat')] 106 | self.dict[is_train] = Single(filenames, self.config).gen() 107 | 108 | def gen(self, is_train): 109 | data_gen = self.dict[is_train] 110 | return data_gen 111 | 112 | def test(): 113 | dataset = ShapeNet('chair') 114 | import cProfile 115 | pr = cProfile.Profile() 116 | pr.enable() 117 | for i in range(10): 118 | if 1 % 2 == 0: 119 | voxel = next(dataset.gen(True)) 120 | else: 121 | voxel = next(dataset.gen(False)) 122 | print(i) 123 | print(voxel.shape) 124 | 125 | pr.disable() 126 | pr.print_stats() 127 | 128 | 129 | if __name__ == "__main__": 130 | test() 131 | -------------------------------------------------------------------------------- /nets.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Created Time: 2018/05/11 10:21:32 3 | # Author: Taihong Xiao 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | class Generator(nn.Module): 10 | def __init__(self): 11 | super(Generator, self).__init__() 12 | self.main = nn.Sequential( 13 | nn.ConvTranspose3d(200, 512, 4, 2, 0), 14 | nn.BatchNorm3d(512), 15 | nn.ReLU(), 16 | 17 | nn.ConvTranspose3d(512, 256, 4, 2, 1), 18 | nn.BatchNorm3d(256), 19 | nn.ReLU(), 20 | 21 | nn.ConvTranspose3d(256, 128, 4, 2, 1), 22 | nn.BatchNorm3d(128), 23 | nn.ReLU(), 24 | 25 | nn.ConvTranspose3d(128, 64, 4, 2, 1), 26 | nn.BatchNorm3d(64), 27 | nn.ReLU(), 28 | 29 | nn.ConvTranspose3d(64, 1, 4, 2, 1), 30 | nn.Sigmoid(), 31 | ) 32 | 33 | def forward(self, x): 34 | # x's size: batch_size * hidden_size 35 | x = x.view(x.size(0), x.size(1), 1, 1, 1) 36 | return self.main(x) 37 | 38 | 39 | class Discriminator(nn.Module): 40 | def __init__(self): 41 | super(Discriminator, self).__init__() 42 | self.main = nn.Sequential( 43 | nn.Conv3d(1, 64, 4, 2, 1), 44 | nn.BatchNorm3d(64), 45 | nn.LeakyReLU(0.2), 46 | 47 | nn.Conv3d(64, 128, 4, 2, 1), 48 | nn.BatchNorm3d(128), 49 | nn.LeakyReLU(0.2), 50 | 51 | nn.Conv3d(128, 256, 4, 2, 1), 52 | nn.BatchNorm3d(256), 53 | nn.LeakyReLU(0.2), 54 | 55 | nn.Conv3d(256, 512, 4, 2, 1), 56 | nn.BatchNorm3d(512), 57 | nn.LeakyReLU(0.2), 58 | 59 | nn.Conv3d(512, 1, 4, 2, 0), 60 | nn.Sigmoid() 61 | ) 62 | 63 | def forward(self, x): 64 | # x's size: batch_size * 1 * 64 * 64 * 64 65 | x = self.main(x) 66 | return x.view(-1, x.size(1)) 67 | 68 | if __name__ == "__main__": 69 | G = Generator().cuda(0) 70 | D = Discriminator().cuda(0) 71 | G = torch.nn.DataParallel(G, device_ids=[0,1]) 72 | D = torch.nn.DataParallel(D, device_ids=[0,1]) 73 | 74 | # z = Variable(torch.rand(16,512,4,4,4)) 75 | # m = nn.ConvTranspose3d(512, 256, 4, 2, 1) 76 | z = Variable(torch.rand(16, 200, 1,1,1)).cuda(1) 77 | X = G(z) 78 | m = nn.Conv3d(1, 64, 4, 2, 1) 79 | D_X = D(X) 80 | print(X.shape, D_X.shape) 81 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import os.path 4 | from visualization.util import * 5 | from visualization.util_vtk import visualization 6 | 7 | if __name__ == '__main__': 8 | import argparse 9 | cmd_parser = argparse.ArgumentParser(description="""Visualizing .mat voxel file. """) 10 | cmd_parser.add_argument('-t', '--threshold', metavar='threshold', type=float, default=0.1, help='voxels with confidence lower than the threshold are not displayed') 11 | cmd_parser.add_argument('-i', '--index', metavar='index', type=int, default=1, help='the index of objects in the inputfile that should be rendered (one based)') 12 | cmd_parser.add_argument('filename', metavar='filename', type=str, help='name of .torch or .mat file to be visualized') 13 | cmd_parser.add_argument('-df', '--downsample-factor', metavar='factor', type=int, default=1, help="downsample objects via a max pooling of step STEPSIZE for efficiency. Currently supporting STEPSIZE 1, 2, and 4.") 14 | cmd_parser.add_argument('-dm', '--downsample-method', metavar='downsample_method', type=str, default='max', help='downsample method, where mean stands for average pooling and max for max pooling') 15 | cmd_parser.add_argument('-u', '--uniform-size', metavar='uniform_size', type=float, default=0.9, help='set the size of the voxels to BLOCK_SIZE') 16 | cmd_parser.add_argument('-cm', '--colormap', action="store_true", help='whether to use a colormap to represent voxel occupancy, or to use a uniform color') 17 | cmd_parser.add_argument('-mc', '--max-component', metavar='max_component', type=int, default=3, help='whether to keep only the maximal connected component, where voxels of distance no larger than `DISTANCE` are considered connected. Set to 0 to disable this function.') 18 | 19 | args = cmd_parser.parse_args() 20 | filename = args.filename 21 | matname = 'instance' 22 | threshold = args.threshold 23 | ind = args.index - 1 # matlab use 1 base index 24 | downsample_factor = args.downsample_factor 25 | downsample_method = args.downsample_method 26 | uniform_size = args.uniform_size 27 | use_colormap = args.colormap 28 | connect = args.max_component 29 | 30 | assert downsample_method in ('max', 'mean') 31 | 32 | # read file 33 | print("==> Reading input voxel file: "+filename) 34 | voxels_raw = read_tensor(filename, matname) 35 | print("Done") 36 | 37 | voxels = voxels_raw[ind] 38 | 39 | # keep only max connected component 40 | print("Looking for max connected component") 41 | if connect > 0: 42 | voxels_keep = (voxels >= threshold) 43 | voxels_keep = max_connected(voxels_keep, connect) 44 | voxels[np.logical_not(voxels_keep)] = 0 45 | 46 | # downsample if needed 47 | if downsample_factor > 1: 48 | print("==> Performing downsample: factor: "+str(downsample_factor)+" method: "+downsample_method) 49 | voxels = downsample(voxels, downsample_factor, method=downsample_method) 50 | print("Done") 51 | 52 | visualization(voxels, threshold, title=str(ind+1)+'/'+str(voxels_raw.shape[0]), uniform_size=uniform_size, use_colormap=use_colormap) 53 | -------------------------------------------------------------------------------- /visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prinsphield/3D-GAN-pytorch/06b4b34e499a5c5187b6a67a8c8dfc35b2e1ce62/visualization/__init__.py -------------------------------------------------------------------------------- /visualization/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import os 4 | from scipy import ndimage 5 | from scipy.io import loadmat 6 | 7 | def read_tensor(filename, varname='voxels'): 8 | """ return a 4D matrix, with dimensions point, x, y, z """ 9 | assert(filename[-4:] == '.mat') 10 | mats = loadmat(filename) 11 | if varname not in mats: 12 | print(".mat file only has these matrices:") 13 | for var in mats: 14 | print(var) 15 | # assert(False) 16 | 17 | voxels = mats[varname] 18 | dims = voxels.shape 19 | if len(dims) == 5: 20 | assert dims[1] == 1 21 | dims = (dims[0],) + tuple(dims[2:]) 22 | elif len(dims) == 3: 23 | dims = [1] + list(dims) 24 | else: 25 | assert len(dims) == 4 26 | result = np.reshape(voxels, dims) 27 | return result 28 | 29 | def sigmoid(z, offset=0, ratio=1): 30 | s = 1.0 / (1.0 + np.exp(-1.0 * (z-offset) * ratio)) 31 | return s 32 | 33 | ############################################################################ 34 | ### Voxel Utility functions 35 | ############################################################################ 36 | def blocktrans_cen2side(cen_size): 37 | """ Convert from center rep to side rep 38 | In center rep, the 6 numbers are center coordinates, then size in 3 dims 39 | In side rep, the 6 numbers are lower x, y, z, then higher x, y, z """ 40 | cx = float(cen_size[0]) 41 | cy = float(cen_size[1]) 42 | cz = float(cen_size[2]) 43 | sx = float(cen_size[3]) 44 | sy = float(cen_size[4]) 45 | sz = float(cen_size[5]) 46 | lx,ly,lz = cx-sx/2., cy-sy/2., cz-sz/2. 47 | hx,hy,hz = cx+sx/2., cy+sy/2., cz+sz/2. 48 | return [lx,ly,lz,hx,hy,hz] 49 | 50 | def blocktrans_side2cen6(side_size): 51 | """ Convert from side rep to center rep 52 | In center rep, the 6 numbers are center coordinates, then size in 3 dims 53 | In side rep, the 6 numbers are lower x, y, z, then higher x, y, z """ 54 | lx,ly,lz = float(side_size[0]), float(side_size[1]), float(side_size[2]) 55 | hx,hy,hz = float(side_size[3]), float(side_size[4]), float(side_size[5]) 56 | return [(lx+hx)*.5,(ly+hy)*.5,(lz+hz)*.5,abs(hx-lx),abs(hy-ly),abs(hz-lz)] 57 | 58 | 59 | def center_of_mass(voxels, threshold=0.1): 60 | """ Calculate the center of mass for the current object. 61 | Voxels with occupancy less than threshold are ignored 62 | """ 63 | assert voxels.ndim == 3 64 | center = [0]*3 65 | voxels_filtered = np.copy(voxels) 66 | voxels_filtered[voxels_filtered < threshold] = 0 67 | 68 | total = voxels_filtered.sum() 69 | if total == 0: 70 | print('threshold too high for current object.') 71 | return [length / 2 for length in voxels.shape] 72 | 73 | # calculate center of mass 74 | center[0] = np.multiply(voxels_filtered.sum(1).sum(1), np.arange(voxels.shape[0])).sum()/total 75 | center[1] = np.multiply(voxels_filtered.sum(0).sum(1), np.arange(voxels.shape[1])).sum()/total 76 | center[2] = np.multiply(voxels_filtered.sum(0).sum(0), np.arange(voxels.shape[2])).sum()/total 77 | 78 | return center 79 | 80 | def downsample(voxels, step, method='max'): 81 | """ 82 | downsample a voxels matrix by a factor of step. 83 | downsample method options: max/mean 84 | same as a pooling 85 | """ 86 | assert step > 0 87 | assert voxels.ndim == 3 or voxels.ndim == 4 88 | assert method in ('max', 'mean') 89 | if step == 1: 90 | return voxels 91 | 92 | if voxels.ndim == 3: 93 | sx, sy, sz = voxels.shape[-3:] 94 | X, Y, Z = np.ogrid[0:sx, 0:sy, 0:sz] 95 | regions = sz/step * sy/step * (X/step) + sz/step * (Y/step) + Z/step 96 | if method == 'max': 97 | res = ndimage.maximum(voxels, labels=regions, index=np.arange(regions.max() + 1)) 98 | elif method == 'mean': 99 | res = ndimage.mean(voxels, labels=regions, index=np.arange(regions.max() + 1)) 100 | res.shape = (sx/step, sy/step, sz/step) 101 | return res 102 | else: 103 | res0 = downsample(voxels[0], step, method) 104 | res = np.zeros((voxels.shape[0],) + res0.shape) 105 | res[0] = res0 106 | for ind in range(1, voxels.shape[0]): 107 | res[ind] = downsample(voxels[ind], step, method) 108 | return res 109 | 110 | def max_connected(voxels, distance): 111 | """ Keep the max connected component of the voxels (a boolean matrix). 112 | distance is the distance considered as neighbors, i.e. if distance = 2, 113 | then two blocks are considered connected even with a hole in between""" 114 | assert(distance > 0) 115 | max_component = np.zeros(voxels.shape, dtype=bool) 116 | voxels = np.copy(voxels) 117 | for startx in range(voxels.shape[0]): 118 | for starty in range(voxels.shape[1]): 119 | for startz in range(voxels.shape[2]): 120 | if not voxels[startx,starty,startz]: 121 | continue 122 | # start a new component 123 | component = np.zeros(voxels.shape, dtype=bool) 124 | stack = [[startx,starty,startz]] 125 | component[startx,starty,startz] = True 126 | voxels[startx,starty,startz] = False 127 | while len(stack) > 0: 128 | x,y,z = stack.pop() 129 | for i in range(x-distance, x+distance + 1): 130 | for j in range(y-distance, y+distance + 1): 131 | for k in range(z-distance, z+distance + 1): 132 | if (i-x)**2+(j-y)**2+(k-z)**2 > distance * distance: 133 | continue 134 | if voxel_exist(voxels, i,j,k): 135 | voxels[i,j,k] = False 136 | component[i,j,k] = True 137 | stack.append([i,j,k]) 138 | if component.sum() > max_component.sum(): 139 | max_component = component 140 | return max_component 141 | 142 | 143 | def voxel_exist(voxels, x,y,z): 144 | if x < 0 or y < 0 or z < 0 or x >= voxels.shape[0] or y >= voxels.shape[1] or z >= voxels.shape[2]: 145 | return False 146 | else : 147 | return voxels[x,y,z] 148 | -------------------------------------------------------------------------------- /visualization/util_vtk.py: -------------------------------------------------------------------------------- 1 | from visualization.util import * 2 | import matplotlib.cm 3 | import vtk 4 | import math 5 | 6 | ############################################################################ 7 | ### VTK functions 8 | ############################################################################ 9 | def block_generation(cen_size, color): 10 | """ generate a block up to actor stage 11 | User may choose to use VTK boxsource implementation, or the polydata implementation 12 | """ 13 | cubeMapper = vtk.vtkPolyDataMapper() 14 | cubeActor = vtk.vtkActor() 15 | 16 | lx,ly,lz,hx,hy,hz = blocktrans_cen2side(cen_size) 17 | vertices = [ [lx,ly,lz], [hx,ly,lz], [hx,hy,lz], [lx,hy,lz], 18 | [lx,ly,hz], [hx,ly,hz], [hx,hy,hz], [lx,hy,hz]] 19 | 20 | pts =[[0,1,2,3], [4,5,6,7], [0,1,5,4], 21 | [1,2,6,5], [2,3,7,6], [3,0,4,7]] 22 | 23 | cube = vtk.vtkPolyData() 24 | points = vtk.vtkPoints() 25 | polys = vtk.vtkCellArray() 26 | 27 | for i in range(0,8): 28 | points.InsertPoint(i,vertices[i]) 29 | 30 | for i in range(0,6): 31 | polys.InsertNextCell(4) 32 | for j in range(0,4): 33 | polys.InsertCellPoint(pts[i][j]) 34 | cube.SetPoints(points) 35 | cube.SetPolys(polys) 36 | # cubeMapper.SetInput(cube) 37 | cubeMapper.SetInputData(cube) 38 | cubeActor.SetMapper(cubeMapper) 39 | 40 | # set the colors 41 | cubeActor.GetProperty().SetColor(np.array(color[:3])) 42 | cubeActor.GetProperty().SetAmbient(0.5) 43 | cubeActor.GetProperty().SetDiffuse(.5) 44 | cubeActor.GetProperty().SetSpecular(0.1) 45 | cubeActor.GetProperty().SetSpecularColor(1,1,1) 46 | cubeActor.GetProperty().SetDiffuseColor(color[:3]) 47 | # cubeActor.GetProperty().SetAmbientColor(1,1,1) 48 | # cubeActor.GetProperty().ShadingOn() 49 | return cubeActor 50 | 51 | def generate_all_blocks(voxels, threshold=0.1, uniform_size=-1, use_colormap=False): 52 | """ 53 | Generate one block per voxel, with block size and color dependent on probability. 54 | Performance is desirable if number of blocks is below 20,000. 55 | """ 56 | assert voxels.ndim == 3 57 | actors = [] 58 | counter = 0 59 | dims = voxels.shape 60 | 61 | cmap = matplotlib.cm.get_cmap('jet') 62 | DEFAULT_COLOR = [0.9,0,0] 63 | 64 | for k in range(dims[2]): 65 | for j in range(dims[1]): 66 | for i in range(dims[0]): 67 | occupancy = voxels[i][j][k] 68 | if occupancy < threshold: 69 | continue 70 | 71 | if use_colormap: 72 | color = cmap(float(occupancy)) 73 | else: # use default color 74 | color = DEFAULT_COLOR 75 | 76 | if uniform_size > 0 and uniform_size <= 1: 77 | block_size = uniform_size 78 | else: 79 | block_size = occupancy 80 | actors.append(block_generation([i+0.5, j+0.5, k+0.5, block_size, block_size, block_size], color=(color))) 81 | counter = counter + 1 82 | 83 | print(counter, "blocks filled") 84 | return actors 85 | 86 | def display(actors, cam_pos, cam_vocal, cam_up, title=None): 87 | """ Display the scene from actors. 88 | cam_pos: list of positions of cameras. 89 | cam_vocal: vocal point of cameras 90 | cam_up: view up direction of cameras 91 | title: display window title 92 | """ 93 | 94 | renWin = vtk.vtkRenderWindow() 95 | window_size = 1024 96 | 97 | renderer = vtk.vtkRenderer() 98 | for actor in actors: 99 | renderer.AddActor(actor) 100 | renderer.SetBackground(1,1,1) 101 | renWin.AddRenderer(renderer) 102 | 103 | camera = vtk.vtkCamera() 104 | renderer.SetActiveCamera(camera) 105 | renderer.ResetCamera() 106 | 107 | # the object is located at 0 <= x,y,z <= dims[i] 108 | camera.SetFocalPoint(*cam_vocal) 109 | camera.SetViewUp(*cam_up) 110 | camera.SetPosition(*cam_pos) 111 | 112 | renWin.SetSize(window_size, window_size) 113 | 114 | iren = vtk.vtkRenderWindowInteractor() 115 | style = vtk.vtkInteractorStyleTrackballCamera() 116 | iren.SetInteractorStyle(style) 117 | iren.SetRenderWindow(renWin) 118 | if title != None: 119 | renWin.SetWindowName(title) 120 | 121 | renderer.ResetCameraClippingRange() 122 | renWin.Render() 123 | 124 | iren.Initialize() 125 | iren.Start() 126 | 127 | def visualization(voxels, threshold, title=None, uniform_size=-1, use_colormap=False): 128 | """ 129 | Given a voxel matrix, plot all occupied blocks (defined by voxels[x][y][z] > threshold) 130 | if size_change is set to true, block size will be proportional to voxels[x][y][z] 131 | otherwise voxel matrix is transfered to {0,1} matrix, where consecutive blocks are merged for performance. 132 | 133 | The function saves an image at address ofilename, with form jpg/png. If form is empty string, no image is saved. 134 | 135 | """ 136 | actors = generate_all_blocks(voxels, threshold, uniform_size=uniform_size, use_colormap=use_colormap) 137 | 138 | center = center_of_mass(voxels) 139 | dims = voxels.shape 140 | distance = voxels.shape[0] * 2.8 141 | height = voxels.shape[2] * 0.85 142 | rad = math.pi * 0.43 #+ math.pi 143 | cam_pos = [center[0] + distance * math.cos(rad), center[1] + distance * math.sin(rad), center[2] + height] 144 | 145 | display(actors, cam_pos, center, (0,0,1), title=title) 146 | --------------------------------------------------------------------------------