├── data
│   └── put data folder here.txt
├── checkpoint
│   └── put checkpoint folder here.txt
├── test_ae.sh
├── test_svr.sh
├── train_svr.sh
├── train_ae.sh
├── LICENSE
├── main.py
├── utils.py
├── README.md
├── modelAE.py
└── modelSVR.py

/data/put data folder here.txt:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/checkpoint/put checkpoint folder here.txt:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/test_ae.sh:
--------------------------------------------------------------------------------
1 | python main.py --ae --sample_dir samples/im_ae_out --start 0 --end 9000
--------------------------------------------------------------------------------
/test_svr.sh:
--------------------------------------------------------------------------------
1 | python main.py --svr --sample_dir samples/im_svr_out --start 0 --end 9000
--------------------------------------------------------------------------------
/train_svr.sh:
--------------------------------------------------------------------------------
1 | python main.py --svr --train --epoch 1000 --sample_dir samples/all_vox256_img1
2 | 
3 | 
--------------------------------------------------------------------------------
/train_ae.sh:
--------------------------------------------------------------------------------
1 | python main.py --ae --train --epoch 200 --sample_dir samples/all_vox256_img0_16 --sample_vox_size 16
2 | python main.py --ae --train --epoch 200 --sample_dir samples/all_vox256_img0_32 --sample_vox_size 32
3 | python main.py --ae --train --epoch 400 --sample_dir samples/all_vox256_img0_64 --sample_vox_size 64
4 | python main.py --ae --getz
5 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Learning Implicit Fields for Generative Shape Modeling
2 | 
3 | Copyright (c) 2018, GrUVi lab of Simon Fraser University
4 | 
5 | MIT License
6 | 
7 | Copyright (c) 2018 Zhiqin Chen
8 | 
9 | Permission is hereby granted, free of charge, to any person obtaining a copy
10 | of this software and associated documentation files (the "Software"), to deal
11 | in the Software without restriction, including without limitation the rights
12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | copies of the Software, and to permit persons to whom the Software is
14 | furnished to do so, subject to the following conditions:
15 | 
16 | The above copyright notice and this permission notice shall be included in all
17 | copies or substantial portions of the Software.
18 | 
19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | SOFTWARE.
26 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 3 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 4 | import numpy as np 5 | 6 | from modelAE import IM_AE 7 | from modelSVR import IM_SVR 8 | 9 | import argparse 10 | import h5py 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--epoch", action="store", dest="epoch", default=0, type=int, help="Epoch to train [0]") 14 | parser.add_argument("--iteration", action="store", dest="iteration", default=0, type=int, help="Iteration to train. Either epoch or iteration need to be zero [0]") 15 | parser.add_argument("--learning_rate", action="store", dest="learning_rate", default=0.00005, type=float, help="Learning rate for adam [0.00005]") 16 | parser.add_argument("--beta1", action="store", dest="beta1", default=0.5, type=float, help="Momentum term of adam [0.5]") 17 | parser.add_argument("--dataset", action="store", dest="dataset", default="all_vox256_img", help="The name of dataset") 18 | parser.add_argument("--checkpoint_dir", action="store", dest="checkpoint_dir", default="checkpoint", help="Directory name to save the checkpoints [checkpoint]") 19 | parser.add_argument("--data_dir", action="store", dest="data_dir", default="./data/all_vox256_img/", help="Root directory of dataset [data]") 20 | parser.add_argument("--sample_dir", action="store", dest="sample_dir", default="./samples/", help="Directory name to save the image samples [samples]") 21 | parser.add_argument("--sample_vox_size", action="store", dest="sample_vox_size", default=64, type=int, help="Voxel resolution for coarse-to-fine training [64]") 22 | parser.add_argument("--train", action="store_true", dest="train", default=False, help="True for training, False for testing [False]") 23 | parser.add_argument("--start", action="store", dest="start", default=0, type=int, help="In testing, output shapes [start:end]") 24 | parser.add_argument("--end", action="store", dest="end", default=16, type=int, help="In testing, output shapes [start:end]") 25 | parser.add_argument("--ae", action="store_true", dest="ae", default=False, help="True for ae [False]") 26 | parser.add_argument("--svr", action="store_true", dest="svr", default=False, help="True for svr [False]") 27 | parser.add_argument("--getz", action="store_true", dest="getz", default=False, help="True for getting latent codes [False]") 28 | FLAGS = parser.parse_args() 29 | 30 | 31 | 32 | if not os.path.exists(FLAGS.sample_dir): 33 | os.makedirs(FLAGS.sample_dir) 34 | 35 | if FLAGS.ae: 36 | im_ae = IM_AE(FLAGS) 37 | 38 | if FLAGS.train: 39 | im_ae.train(FLAGS) 40 | elif FLAGS.getz: 41 | im_ae.get_z(FLAGS) 42 | else: 43 | #im_ae.test_mesh(FLAGS) 44 | im_ae.test_mesh_point(FLAGS) 45 | elif FLAGS.svr: 46 | im_svr = IM_SVR(FLAGS) 47 | 48 | if FLAGS.train: 49 | im_svr.train(FLAGS) 50 | else: 51 | #im_svr.test_mesh(FLAGS) 52 | im_svr.test_mesh_point(FLAGS) 53 | else: 54 | print("Please specify an operation: ae or svr?") 55 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | 5 | def write_ply_point(name, vertices): 6 | fout = open(name, 'w') 7 | fout.write("ply\n") 8 | fout.write("format ascii 1.0\n") 9 | fout.write("element vertex "+str(len(vertices))+"\n") 10 | 
fout.write("property float x\n") 11 | fout.write("property float y\n") 12 | fout.write("property float z\n") 13 | fout.write("end_header\n") 14 | for ii in range(len(vertices)): 15 | fout.write(str(vertices[ii,0])+" "+str(vertices[ii,1])+" "+str(vertices[ii,2])+"\n") 16 | fout.close() 17 | 18 | 19 | def write_ply_point_normal(name, vertices, normals=None): 20 | fout = open(name, 'w') 21 | fout.write("ply\n") 22 | fout.write("format ascii 1.0\n") 23 | fout.write("element vertex "+str(len(vertices))+"\n") 24 | fout.write("property float x\n") 25 | fout.write("property float y\n") 26 | fout.write("property float z\n") 27 | fout.write("property float nx\n") 28 | fout.write("property float ny\n") 29 | fout.write("property float nz\n") 30 | fout.write("end_header\n") 31 | if normals is None: 32 | for ii in range(len(vertices)): 33 | fout.write(str(vertices[ii,0])+" "+str(vertices[ii,1])+" "+str(vertices[ii,2])+" "+str(vertices[ii,3])+" "+str(vertices[ii,4])+" "+str(vertices[ii,5])+"\n") 34 | else: 35 | for ii in range(len(vertices)): 36 | fout.write(str(vertices[ii,0])+" "+str(vertices[ii,1])+" "+str(vertices[ii,2])+" "+str(normals[ii,0])+" "+str(normals[ii,1])+" "+str(normals[ii,2])+"\n") 37 | fout.close() 38 | 39 | 40 | def write_ply_triangle(name, vertices, triangles): 41 | fout = open(name, 'w') 42 | fout.write("ply\n") 43 | fout.write("format ascii 1.0\n") 44 | fout.write("element vertex "+str(len(vertices))+"\n") 45 | fout.write("property float x\n") 46 | fout.write("property float y\n") 47 | fout.write("property float z\n") 48 | fout.write("element face "+str(len(triangles))+"\n") 49 | fout.write("property list uchar int vertex_index\n") 50 | fout.write("end_header\n") 51 | for ii in range(len(vertices)): 52 | fout.write(str(vertices[ii,0])+" "+str(vertices[ii,1])+" "+str(vertices[ii,2])+"\n") 53 | for ii in range(len(triangles)): 54 | fout.write("3 "+str(triangles[ii,0])+" "+str(triangles[ii,1])+" "+str(triangles[ii,2])+"\n") 55 | fout.close() 56 | 57 | 58 | def sample_points_triangle(vertices, triangles, num_of_points): 59 | epsilon = 1e-6 60 | triangle_area_list = np.zeros([len(triangles)],np.float32) 61 | triangle_normal_list = np.zeros([len(triangles),3],np.float32) 62 | for i in range(len(triangles)): 63 | #area = |u x v|/2 = |u||v|sin(uv)/2 64 | a,b,c = vertices[triangles[i,1]]-vertices[triangles[i,0]] 65 | x,y,z = vertices[triangles[i,2]]-vertices[triangles[i,0]] 66 | ti = b*z-c*y 67 | tj = c*x-a*z 68 | tk = a*y-b*x 69 | area2 = math.sqrt(ti*ti+tj*tj+tk*tk) 70 | if area2100: 94 | print("infinite loop here!") 95 | return point_normal_list 96 | for i in range(len(triangle_index_list)): 97 | if count>=num_of_points: break 98 | dxb = triangle_index_list[i] 99 | prob = sample_prob_list[dxb] 100 | prob_i = int(prob) 101 | prob_f = prob-prob_i 102 | if np.random.random()=1: 113 | u_x = 1-u_x 114 | v_y = 1-v_y 115 | ppp = u*u_x+v*v_y+base 116 | 117 | point_normal_list[count,:3] = ppp 118 | point_normal_list[count,3:] = normal_direction 119 | count += 1 120 | if count>=num_of_points: break 121 | 122 | return point_normal_list -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IM-NET-pytorch 2 | PyTorch 1.2 implementation for paper "Learning Implicit Fields for Generative Shape Modeling", [Zhiqin Chen](https://czq142857.github.io/), [Hao (Richard) Zhang](https://www.cs.sfu.ca/~haoz/). 
3 | 
4 | ### [Supplementary material](https://github.com/czq142857/implicit-decoder/tree/master/supplementary_material) | [Paper](https://arxiv.org/abs/1812.02822)
5 | 
6 | ### [Original implementation](https://github.com/czq142857/implicit-decoder)
7 | 
8 | ### [Improved TensorFlow1 implementation](https://github.com/czq142857/IM-NET)
9 | 
10 | 
11 | ## Improvements
12 | 
13 | In short, this repo is an implementation of [IM-NET](https://github.com/czq142857/IM-NET) with the framework provided by [BSP-NET-pytorch](https://github.com/czq142857/BSP-NET-pytorch).
14 | 
15 | The improvements over the [original implementation](https://github.com/czq142857/implicit-decoder) are the same as in [IM-NET (improved TensorFlow1 implementation)](https://github.com/czq142857/IM-NET):
16 | 
17 | Encoder:
18 | 
19 | - In IM-AE (autoencoder), changed batch normalization to instance normalization.
20 | 
21 | Decoder (=generator):
22 | 
23 | - Changed the first layer from 2048-1024 to 1024-1024-1024.
24 | - Changed latent code size from 128 to 256.
25 | - Removed all skip connections.
26 | - Changed the last activation function from sigmoid to clip ( max(min(h, 1), 0) ).
27 | 
28 | Training:
29 | 
30 | - Trained one model on the 13 ShapeNet categories, as most single-view reconstruction networks do.
31 | - For each category, sort the object names and use the first 80% as the training set and the rest as the testing set, same as [AtlasNet](https://github.com/ThibaultGROUEIX/AtlasNet). A minimal sketch of this split appears at the end of this README.
32 | - Reduced the number of sampled points by half in the training set. Points were sampled on 256³ voxels.
33 | - Removed data augmentation (image crops), same as [Occupancy Networks](https://github.com/autonomousvision/occupancy_networks).
34 | - Added coarse-to-fine sampling for inference to speed up testing.
35 | - Added post-processing to make the output mesh smoother. To enable it, find and uncomment all *"self.optimize_mesh(vertices,model_z)"*.
36 | 
37 | 
38 | ## Citation
39 | If you find our work useful in your research, please consider citing:
40 | 
41 |     @article{chen2018implicit_decoder,
42 |       title={Learning Implicit Fields for Generative Shape Modeling},
43 |       author={Chen, Zhiqin and Zhang, Hao},
44 |       journal={Proceedings of IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
45 |       year={2019}
46 |     }
47 | 
48 | ## Dependencies
49 | Requirements:
50 | - Python 3.5 with numpy, scipy and h5py
51 | - [PyTorch 1.2](https://pytorch.org/get-started/locally/)
52 | - [PyMCubes](https://github.com/pmneila/PyMCubes) (for marching cubes)
53 | 
54 | Our code has been tested on Ubuntu 16.04 and Windows 10.
55 | 
56 | 
57 | ## Datasets and pre-trained weights
58 | The original voxel models are from [HSP](https://github.com/chaene/hsp).
59 | 
60 | The rendered views are from [3D-R2N2](https://github.com/chrischoy/3D-R2N2).
61 | 
62 | Since our network takes point-value pairs, the voxel models require further sampling.
63 | 
64 | For data preparation, please see directory [point_sampling](https://github.com/czq142857/IM-NET/tree/master/point_sampling).
65 | 
66 | We provide the ready-to-use datasets in hdf5 format, together with our pre-trained network weights.
67 | 
68 | - [IM-NET-pytorch](https://drive.google.com/open?id=1ykE6MB2iW1Dk5t4wRx85MgpggeoyAqu3)
69 | 
70 | Backup links:
71 | 
72 | - [IM-NET-pytorch](https://pan.baidu.com/s/10695F20-xTWCrltYGhBPcQ) (pwd: bqex)
73 | 
74 | 
75 | ## Usage
76 | 
77 | Please use the provided scripts *train_ae.sh*, *train_svr.sh*, *test_ae.sh*, *test_svr.sh* to train the network on the training set and get output meshes for the testing set.
78 | 
79 | To train an autoencoder, use the following commands for progressive training.
80 | ```
81 | python main.py --ae --train --epoch 200 --sample_dir samples/all_vox256_img0_16 --sample_vox_size 16
82 | python main.py --ae --train --epoch 200 --sample_dir samples/all_vox256_img0_32 --sample_vox_size 32
83 | python main.py --ae --train --epoch 200 --sample_dir samples/all_vox256_img0_64 --sample_vox_size 64
84 | ```
85 | The above commands will train the AE model 200 epochs at 16³ resolution, then 200 epochs at 32³ resolution, and finally 200 epochs at 64³ resolution.
86 | Training on the 13 ShapeNet categories takes about 3 days on one GeForce RTX 2080 Ti GPU.
87 | 
88 | After training, you may visualize some results from the testing set.
89 | ```
90 | python main.py --ae --sample_dir samples/im_ae_out --start 0 --end 16
91 | ```
92 | You can specify the start and end indices of the shapes by *--start* and *--end*.
93 | 
94 | 
95 | To train the network for single-view reconstruction, after training the autoencoder, use the following command to extract the latent codes:
96 | ```
97 | python main.py --ae --getz
98 | ```
99 | Then use the following command to train the SVR model:
100 | ```
101 | python main.py --svr --train --epoch 1000 --sample_dir samples/all_vox256_img1
102 | ```
103 | After training, you may visualize some results from the testing set.
104 | ```
105 | python main.py --svr --sample_dir samples/im_svr_out --start 0 --end 16
106 | ```
107 | 
108 | 
109 | ## License
110 | This project is licensed under the terms of the MIT license (see LICENSE for details).
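
## Appendix: train/test split sketch

A minimal sketch of the per-category 80/20 split described in the Improvements section. This helper is illustrative only; its name and inputs are not part of this repo.

```
def split_category(object_names):
    #sort object names; first 80% -> training set, rest -> testing set
    names = sorted(object_names)
    n_train = int(len(names)*0.8)
    return names[:n_train], names[n_train:]
```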
111 | 112 | 113 | -------------------------------------------------------------------------------- /modelAE.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import random 5 | import numpy as np 6 | import h5py 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch import optim 13 | from torch.autograd import Variable 14 | 15 | import mcubes 16 | 17 | from utils import * 18 | 19 | #pytorch 1.2.0 implementation 20 | 21 | 22 | class generator(nn.Module): 23 | def __init__(self, z_dim, point_dim, gf_dim): 24 | super(generator, self).__init__() 25 | self.z_dim = z_dim 26 | self.point_dim = point_dim 27 | self.gf_dim = gf_dim 28 | self.linear_1 = nn.Linear(self.z_dim+self.point_dim, self.gf_dim*8, bias=True) 29 | self.linear_2 = nn.Linear(self.gf_dim*8, self.gf_dim*8, bias=True) 30 | self.linear_3 = nn.Linear(self.gf_dim*8, self.gf_dim*8, bias=True) 31 | self.linear_4 = nn.Linear(self.gf_dim*8, self.gf_dim*4, bias=True) 32 | self.linear_5 = nn.Linear(self.gf_dim*4, self.gf_dim*2, bias=True) 33 | self.linear_6 = nn.Linear(self.gf_dim*2, self.gf_dim*1, bias=True) 34 | self.linear_7 = nn.Linear(self.gf_dim*1, 1, bias=True) 35 | nn.init.normal_(self.linear_1.weight, mean=0.0, std=0.02) 36 | nn.init.constant_(self.linear_1.bias,0) 37 | nn.init.normal_(self.linear_2.weight, mean=0.0, std=0.02) 38 | nn.init.constant_(self.linear_2.bias,0) 39 | nn.init.normal_(self.linear_3.weight, mean=0.0, std=0.02) 40 | nn.init.constant_(self.linear_3.bias,0) 41 | nn.init.normal_(self.linear_4.weight, mean=0.0, std=0.02) 42 | nn.init.constant_(self.linear_4.bias,0) 43 | nn.init.normal_(self.linear_5.weight, mean=0.0, std=0.02) 44 | nn.init.constant_(self.linear_5.bias,0) 45 | nn.init.normal_(self.linear_6.weight, mean=0.0, std=0.02) 46 | nn.init.constant_(self.linear_6.bias,0) 47 | nn.init.normal_(self.linear_7.weight, mean=1e-5, std=0.02) 48 | nn.init.constant_(self.linear_7.bias,0) 49 | 50 | def forward(self, points, z, is_training=False): 51 | zs = z.view(-1,1,self.z_dim).repeat(1,points.size()[1],1) 52 | pointz = torch.cat([points,zs],2) 53 | 54 | l1 = self.linear_1(pointz) 55 | l1 = F.leaky_relu(l1, negative_slope=0.02, inplace=True) 56 | 57 | l2 = self.linear_2(l1) 58 | l2 = F.leaky_relu(l2, negative_slope=0.02, inplace=True) 59 | 60 | l3 = self.linear_3(l2) 61 | l3 = F.leaky_relu(l3, negative_slope=0.02, inplace=True) 62 | 63 | l4 = self.linear_4(l3) 64 | l4 = F.leaky_relu(l4, negative_slope=0.02, inplace=True) 65 | 66 | l5 = self.linear_5(l4) 67 | l5 = F.leaky_relu(l5, negative_slope=0.02, inplace=True) 68 | 69 | l6 = self.linear_6(l5) 70 | l6 = F.leaky_relu(l6, negative_slope=0.02, inplace=True) 71 | 72 | l7 = self.linear_7(l6) 73 | 74 | #l7 = torch.clamp(l7, min=0, max=1) 75 | l7 = torch.max(torch.min(l7, l7*0.01+0.99), l7*0.01) 76 | 77 | return l7 78 | 79 | class encoder(nn.Module): 80 | def __init__(self, ef_dim, z_dim): 81 | super(encoder, self).__init__() 82 | self.ef_dim = ef_dim 83 | self.z_dim = z_dim 84 | self.conv_1 = nn.Conv3d(1, self.ef_dim, 4, stride=2, padding=1, bias=False) 85 | self.in_1 = nn.InstanceNorm3d(self.ef_dim) 86 | self.conv_2 = nn.Conv3d(self.ef_dim, self.ef_dim*2, 4, stride=2, padding=1, bias=False) 87 | self.in_2 = nn.InstanceNorm3d(self.ef_dim*2) 88 | self.conv_3 = nn.Conv3d(self.ef_dim*2, self.ef_dim*4, 4, stride=2, padding=1, bias=False) 89 | self.in_3 = nn.InstanceNorm3d(self.ef_dim*4) 90 | self.conv_4 = 
nn.Conv3d(self.ef_dim*4, self.ef_dim*8, 4, stride=2, padding=1, bias=False) 91 | self.in_4 = nn.InstanceNorm3d(self.ef_dim*8) 92 | self.conv_5 = nn.Conv3d(self.ef_dim*8, self.z_dim, 4, stride=1, padding=0, bias=True) 93 | nn.init.xavier_uniform_(self.conv_1.weight) 94 | nn.init.xavier_uniform_(self.conv_2.weight) 95 | nn.init.xavier_uniform_(self.conv_3.weight) 96 | nn.init.xavier_uniform_(self.conv_4.weight) 97 | nn.init.xavier_uniform_(self.conv_5.weight) 98 | nn.init.constant_(self.conv_5.bias,0) 99 | 100 | def forward(self, inputs, is_training=False): 101 | d_1 = self.in_1(self.conv_1(inputs)) 102 | d_1 = F.leaky_relu(d_1, negative_slope=0.02, inplace=True) 103 | 104 | d_2 = self.in_2(self.conv_2(d_1)) 105 | d_2 = F.leaky_relu(d_2, negative_slope=0.02, inplace=True) 106 | 107 | d_3 = self.in_3(self.conv_3(d_2)) 108 | d_3 = F.leaky_relu(d_3, negative_slope=0.02, inplace=True) 109 | 110 | d_4 = self.in_4(self.conv_4(d_3)) 111 | d_4 = F.leaky_relu(d_4, negative_slope=0.02, inplace=True) 112 | 113 | d_5 = self.conv_5(d_4) 114 | d_5 = d_5.view(-1, self.z_dim) 115 | d_5 = torch.sigmoid(d_5) 116 | 117 | return d_5 118 | 119 | 120 | class im_network(nn.Module): 121 | def __init__(self, ef_dim, gf_dim, z_dim, point_dim): 122 | super(im_network, self).__init__() 123 | self.ef_dim = ef_dim 124 | self.gf_dim = gf_dim 125 | self.z_dim = z_dim 126 | self.point_dim = point_dim 127 | self.encoder = encoder(self.ef_dim, self.z_dim) 128 | self.generator = generator(self.z_dim, self.point_dim, self.gf_dim) 129 | 130 | def forward(self, inputs, z_vector, point_coord, is_training=False): 131 | if is_training: 132 | z_vector = self.encoder(inputs, is_training=is_training) 133 | net_out = self.generator(point_coord, z_vector, is_training=is_training) 134 | else: 135 | if inputs is not None: 136 | z_vector = self.encoder(inputs, is_training=is_training) 137 | if z_vector is not None and point_coord is not None: 138 | net_out = self.generator(point_coord, z_vector, is_training=is_training) 139 | else: 140 | net_out = None 141 | 142 | return z_vector, net_out 143 | 144 | 145 | class IM_AE(object): 146 | def __init__(self, config): 147 | #progressive training 148 | #1-- (16, 16*16*16) 149 | #2-- (32, 16*16*16) 150 | #3-- (64, 16*16*16*4) 151 | self.sample_vox_size = config.sample_vox_size 152 | self.point_batch_size = 16*16*16 153 | self.shape_batch_size = 32 154 | self.input_size = 64 #input voxel grid size 155 | 156 | self.ef_dim = 32 157 | self.gf_dim = 128 158 | self.z_dim = 256 159 | self.point_dim = 3 160 | 161 | self.dataset_name = config.dataset 162 | self.dataset_load = self.dataset_name + '_train' 163 | if not (config.train or config.getz): 164 | self.dataset_load = self.dataset_name + '_test' 165 | self.checkpoint_dir = config.checkpoint_dir 166 | self.data_dir = config.data_dir 167 | 168 | data_hdf5_name = self.data_dir+'/'+self.dataset_load+'.hdf5' 169 | if os.path.exists(data_hdf5_name): 170 | data_dict = h5py.File(data_hdf5_name, 'r') 171 | self.data_points = (data_dict['points_'+str(self.sample_vox_size)][:].astype(np.float32)+0.5)/256-0.5 172 | self.data_values = data_dict['values_'+str(self.sample_vox_size)][:].astype(np.float32) 173 | self.data_voxels = data_dict['voxels'][:] 174 | self.load_point_batch_size = self.data_points.shape[1] 175 | #reshape to NCHW 176 | self.data_voxels = np.reshape(self.data_voxels, [-1,1,self.input_size,self.input_size,self.input_size]) 177 | else: 178 | print("error: cannot load "+data_hdf5_name) 179 | exit(0) 180 | 181 | 182 | if torch.cuda.is_available(): 183 | 
self.device = torch.device('cuda') 184 | torch.backends.cudnn.benchmark = True 185 | else: 186 | self.device = torch.device('cpu') 187 | 188 | #build model 189 | self.im_network = im_network(self.ef_dim, self.gf_dim, self.z_dim, self.point_dim) 190 | self.im_network.to(self.device) 191 | #print params 192 | #for param_tensor in self.im_network.state_dict(): 193 | # print(param_tensor, "\t", self.im_network.state_dict()[param_tensor].size()) 194 | self.optimizer = torch.optim.Adam(self.im_network.parameters(), lr=config.learning_rate, betas=(config.beta1, 0.999)) 195 | #pytorch does not have a checkpoint manager 196 | #have to define it myself to manage max num of checkpoints to keep 197 | self.max_to_keep = 2 198 | self.checkpoint_path = os.path.join(self.checkpoint_dir, self.model_dir) 199 | self.checkpoint_name='IM_AE.model' 200 | self.checkpoint_manager_list = [None] * self.max_to_keep 201 | self.checkpoint_manager_pointer = 0 202 | #loss 203 | def network_loss(G,point_value): 204 | return torch.mean((G-point_value)**2) 205 | self.loss = network_loss 206 | 207 | 208 | #keep everything a power of 2 209 | self.cell_grid_size = 4 210 | self.frame_grid_size = 64 211 | self.real_size = self.cell_grid_size*self.frame_grid_size #=256, output point-value voxel grid size in testing 212 | self.test_size = 32 #related to testing batch_size, adjust according to gpu memory size 213 | self.test_point_batch_size = self.test_size*self.test_size*self.test_size #do not change 214 | 215 | #get coords for training 216 | dima = self.test_size 217 | dim = self.frame_grid_size 218 | self.aux_x = np.zeros([dima,dima,dima],np.uint8) 219 | self.aux_y = np.zeros([dima,dima,dima],np.uint8) 220 | self.aux_z = np.zeros([dima,dima,dima],np.uint8) 221 | multiplier = int(dim/dima) 222 | multiplier2 = multiplier*multiplier 223 | multiplier3 = multiplier*multiplier*multiplier 224 | for i in range(dima): 225 | for j in range(dima): 226 | for k in range(dima): 227 | self.aux_x[i,j,k] = i*multiplier 228 | self.aux_y[i,j,k] = j*multiplier 229 | self.aux_z[i,j,k] = k*multiplier 230 | self.coords = np.zeros([multiplier3,dima,dima,dima,3],np.float32) 231 | for i in range(multiplier): 232 | for j in range(multiplier): 233 | for k in range(multiplier): 234 | self.coords[i*multiplier2+j*multiplier+k,:,:,:,0] = self.aux_x+i 235 | self.coords[i*multiplier2+j*multiplier+k,:,:,:,1] = self.aux_y+j 236 | self.coords[i*multiplier2+j*multiplier+k,:,:,:,2] = self.aux_z+k 237 | self.coords = (self.coords.astype(np.float32)+0.5)/dim-0.5 238 | self.coords = np.reshape(self.coords,[multiplier3,self.test_point_batch_size,3]) 239 | self.coords = torch.from_numpy(self.coords) 240 | self.coords = self.coords.to(self.device) 241 | 242 | 243 | #get coords for testing 244 | dimc = self.cell_grid_size 245 | dimf = self.frame_grid_size 246 | self.cell_x = np.zeros([dimc,dimc,dimc],np.int32) 247 | self.cell_y = np.zeros([dimc,dimc,dimc],np.int32) 248 | self.cell_z = np.zeros([dimc,dimc,dimc],np.int32) 249 | self.cell_coords = np.zeros([dimf,dimf,dimf,dimc,dimc,dimc,3],np.float32) 250 | self.frame_coords = np.zeros([dimf,dimf,dimf,3],np.float32) 251 | self.frame_x = np.zeros([dimf,dimf,dimf],np.int32) 252 | self.frame_y = np.zeros([dimf,dimf,dimf],np.int32) 253 | self.frame_z = np.zeros([dimf,dimf,dimf],np.int32) 254 | for i in range(dimc): 255 | for j in range(dimc): 256 | for k in range(dimc): 257 | self.cell_x[i,j,k] = i 258 | self.cell_y[i,j,k] = j 259 | self.cell_z[i,j,k] = k 260 | for i in range(dimf): 261 | for j in range(dimf): 262 | for k in 
range(dimf): 263 | self.cell_coords[i,j,k,:,:,:,0] = self.cell_x+i*dimc 264 | self.cell_coords[i,j,k,:,:,:,1] = self.cell_y+j*dimc 265 | self.cell_coords[i,j,k,:,:,:,2] = self.cell_z+k*dimc 266 | self.frame_coords[i,j,k,0] = i 267 | self.frame_coords[i,j,k,1] = j 268 | self.frame_coords[i,j,k,2] = k 269 | self.frame_x[i,j,k] = i 270 | self.frame_y[i,j,k] = j 271 | self.frame_z[i,j,k] = k 272 | self.cell_coords = (self.cell_coords.astype(np.float32)+0.5)/self.real_size-0.5 273 | self.cell_coords = np.reshape(self.cell_coords,[dimf,dimf,dimf,dimc*dimc*dimc,3]) 274 | self.cell_x = np.reshape(self.cell_x,[dimc*dimc*dimc]) 275 | self.cell_y = np.reshape(self.cell_y,[dimc*dimc*dimc]) 276 | self.cell_z = np.reshape(self.cell_z,[dimc*dimc*dimc]) 277 | self.frame_x = np.reshape(self.frame_x,[dimf*dimf*dimf]) 278 | self.frame_y = np.reshape(self.frame_y,[dimf*dimf*dimf]) 279 | self.frame_z = np.reshape(self.frame_z,[dimf*dimf*dimf]) 280 | self.frame_coords = (self.frame_coords.astype(np.float32)+0.5)/dimf-0.5 281 | self.frame_coords = np.reshape(self.frame_coords,[dimf*dimf*dimf,3]) 282 | 283 | self.sampling_threshold = 0.5 #final marching cubes threshold 284 | 285 | @property 286 | def model_dir(self): 287 | return "{}_ae_{}".format(self.dataset_name, self.input_size) 288 | 289 | def train(self, config): 290 | #load previous checkpoint 291 | checkpoint_txt = os.path.join(self.checkpoint_path, "checkpoint") 292 | if os.path.exists(checkpoint_txt): 293 | fin = open(checkpoint_txt) 294 | model_dir = fin.readline().strip() 295 | fin.close() 296 | self.im_network.load_state_dict(torch.load(model_dir)) 297 | print(" [*] Load SUCCESS") 298 | else: 299 | print(" [!] Load failed...") 300 | 301 | shape_num = len(self.data_voxels) 302 | batch_index_list = np.arange(shape_num) 303 | 304 | print("\n\n----------net summary----------") 305 | print("training samples ", shape_num) 306 | print("-------------------------------\n\n") 307 | 308 | start_time = time.time() 309 | assert config.epoch==0 or config.iteration==0 310 | training_epoch = config.epoch + int(config.iteration/shape_num) 311 | batch_num = int(shape_num/self.shape_batch_size) 312 | point_batch_num = int(self.load_point_batch_size/self.point_batch_size) 313 | 314 | for epoch in range(0, training_epoch): 315 | self.im_network.train() 316 | np.random.shuffle(batch_index_list) 317 | avg_loss_sp = 0 318 | avg_num = 0 319 | for idx in range(batch_num): 320 | dxb = batch_index_list[idx*self.shape_batch_size:(idx+1)*self.shape_batch_size] 321 | batch_voxels = self.data_voxels[dxb].astype(np.float32) 322 | if point_batch_num==1: 323 | point_coord = self.data_points[dxb] 324 | point_value = self.data_values[dxb] 325 | else: 326 | which_batch = np.random.randint(point_batch_num) 327 | point_coord = self.data_points[dxb,which_batch*self.point_batch_size:(which_batch+1)*self.point_batch_size] 328 | point_value = self.data_values[dxb,which_batch*self.point_batch_size:(which_batch+1)*self.point_batch_size] 329 | 330 | batch_voxels = torch.from_numpy(batch_voxels) 331 | point_coord = torch.from_numpy(point_coord) 332 | point_value = torch.from_numpy(point_value) 333 | 334 | batch_voxels = batch_voxels.to(self.device) 335 | point_coord = point_coord.to(self.device) 336 | point_value = point_value.to(self.device) 337 | 338 | self.im_network.zero_grad() 339 | _, net_out = self.im_network(batch_voxels, None, point_coord, is_training=True) 340 | errSP = self.loss(net_out, point_value) 341 | 342 | errSP.backward() 343 | self.optimizer.step() 344 | 345 | avg_loss_sp += 
errSP.item() 346 | avg_num += 1 347 | print(str(self.sample_vox_size)+" Epoch: [%2d/%2d] time: %4.4f, loss_sp: %.6f" % (epoch, training_epoch, time.time() - start_time, avg_loss_sp/avg_num)) 348 | if epoch%10==9: 349 | self.test_1(config,"train_"+str(self.sample_vox_size)+"_"+str(epoch)) 350 | if epoch%20==19: 351 | if not os.path.exists(self.checkpoint_path): 352 | os.makedirs(self.checkpoint_path) 353 | save_dir = os.path.join(self.checkpoint_path,self.checkpoint_name+str(self.sample_vox_size)+"-"+str(epoch)+".pth") 354 | self.checkpoint_manager_pointer = (self.checkpoint_manager_pointer+1)%self.max_to_keep 355 | #delete checkpoint 356 | if self.checkpoint_manager_list[self.checkpoint_manager_pointer] is not None: 357 | if os.path.exists(self.checkpoint_manager_list[self.checkpoint_manager_pointer]): 358 | os.remove(self.checkpoint_manager_list[self.checkpoint_manager_pointer]) 359 | #save checkpoint 360 | torch.save(self.im_network.state_dict(), save_dir) 361 | #update checkpoint manager 362 | self.checkpoint_manager_list[self.checkpoint_manager_pointer] = save_dir 363 | #write file 364 | checkpoint_txt = os.path.join(self.checkpoint_path, "checkpoint") 365 | fout = open(checkpoint_txt, 'w') 366 | for i in range(self.max_to_keep): 367 | pointer = (self.checkpoint_manager_pointer+self.max_to_keep-i)%self.max_to_keep 368 | if self.checkpoint_manager_list[pointer] is not None: 369 | fout.write(self.checkpoint_manager_list[pointer]+"\n") 370 | fout.close() 371 | 372 | if not os.path.exists(self.checkpoint_path): 373 | os.makedirs(self.checkpoint_path) 374 | save_dir = os.path.join(self.checkpoint_path,self.checkpoint_name+str(self.sample_vox_size)+"-"+str(epoch)+".pth") 375 | self.checkpoint_manager_pointer = (self.checkpoint_manager_pointer+1)%self.max_to_keep 376 | #delete checkpoint 377 | if self.checkpoint_manager_list[self.checkpoint_manager_pointer] is not None: 378 | if os.path.exists(self.checkpoint_manager_list[self.checkpoint_manager_pointer]): 379 | os.remove(self.checkpoint_manager_list[self.checkpoint_manager_pointer]) 380 | #save checkpoint 381 | torch.save(self.im_network.state_dict(), save_dir) 382 | #update checkpoint manager 383 | self.checkpoint_manager_list[self.checkpoint_manager_pointer] = save_dir 384 | #write file 385 | checkpoint_txt = os.path.join(self.checkpoint_path, "checkpoint") 386 | fout = open(checkpoint_txt, 'w') 387 | for i in range(self.max_to_keep): 388 | pointer = (self.checkpoint_manager_pointer+self.max_to_keep-i)%self.max_to_keep 389 | if self.checkpoint_manager_list[pointer] is not None: 390 | fout.write(self.checkpoint_manager_list[pointer]+"\n") 391 | fout.close() 392 | 393 | def test_1(self, config, name): 394 | multiplier = int(self.frame_grid_size/self.test_size) 395 | multiplier2 = multiplier*multiplier 396 | self.im_network.eval() 397 | t = np.random.randint(len(self.data_voxels)) 398 | model_float = np.zeros([self.frame_grid_size+2,self.frame_grid_size+2,self.frame_grid_size+2],np.float32) 399 | batch_voxels = self.data_voxels[t:t+1].astype(np.float32) 400 | batch_voxels = torch.from_numpy(batch_voxels) 401 | batch_voxels = batch_voxels.to(self.device) 402 | z_vector, _ = self.im_network(batch_voxels, None, None, is_training=False) 403 | for i in range(multiplier): 404 | for j in range(multiplier): 405 | for k in range(multiplier): 406 | minib = i*multiplier2+j*multiplier+k 407 | point_coord = self.coords[minib:minib+1] 408 | _, net_out = self.im_network(None, z_vector, point_coord, is_training=False) 409 | #net_out = torch.clamp(net_out, 
min=0, max=1) 410 | model_float[self.aux_x+i+1,self.aux_y+j+1,self.aux_z+k+1] = np.reshape(net_out.detach().cpu().numpy(), [self.test_size,self.test_size,self.test_size]) 411 | 412 | vertices, triangles = mcubes.marching_cubes(model_float, self.sampling_threshold) 413 | vertices = (vertices.astype(np.float32)-0.5)/self.frame_grid_size-0.5 414 | #output ply sum 415 | write_ply_triangle(config.sample_dir+"/"+name+".ply", vertices, triangles) 416 | print("[sample]") 417 | 418 | 419 | 420 | def z2voxel(self, z): 421 | model_float = np.zeros([self.real_size+2,self.real_size+2,self.real_size+2],np.float32) 422 | dimc = self.cell_grid_size 423 | dimf = self.frame_grid_size 424 | 425 | frame_flag = np.zeros([dimf+2,dimf+2,dimf+2],np.uint8) 426 | queue = [] 427 | 428 | frame_batch_num = int(dimf**3/self.test_point_batch_size) 429 | assert frame_batch_num>0 430 | 431 | #get frame grid values 432 | for i in range(frame_batch_num): 433 | point_coord = self.frame_coords[i*self.test_point_batch_size:(i+1)*self.test_point_batch_size] 434 | point_coord = np.expand_dims(point_coord, axis=0) 435 | point_coord = torch.from_numpy(point_coord) 436 | point_coord = point_coord.to(self.device) 437 | _, model_out_ = self.im_network(None, z, point_coord, is_training=False) 438 | model_out = model_out_.detach().cpu().numpy()[0] 439 | x_coords = self.frame_x[i*self.test_point_batch_size:(i+1)*self.test_point_batch_size] 440 | y_coords = self.frame_y[i*self.test_point_batch_size:(i+1)*self.test_point_batch_size] 441 | z_coords = self.frame_z[i*self.test_point_batch_size:(i+1)*self.test_point_batch_size] 442 | frame_flag[x_coords+1,y_coords+1,z_coords+1] = np.reshape((model_out>self.sampling_threshold).astype(np.uint8), [self.test_point_batch_size]) 443 | 444 | #get queue and fill up ones 445 | for i in range(1,dimf+1): 446 | for j in range(1,dimf+1): 447 | for k in range(1,dimf+1): 448 | maxv = np.max(frame_flag[i-1:i+2,j-1:j+2,k-1:k+2]) 449 | minv = np.min(frame_flag[i-1:i+2,j-1:j+2,k-1:k+2]) 450 | if maxv!=minv: 451 | queue.append((i,j,k)) 452 | elif maxv==1: 453 | x_coords = self.cell_x+(i-1)*dimc 454 | y_coords = self.cell_y+(j-1)*dimc 455 | z_coords = self.cell_z+(k-1)*dimc 456 | model_float[x_coords+1,y_coords+1,z_coords+1] = 1.0 457 | 458 | print("running queue:",len(queue)) 459 | cell_batch_size = dimc**3 460 | cell_batch_num = int(self.test_point_batch_size/cell_batch_size) 461 | assert cell_batch_num>0 462 | #run queue 463 | while len(queue)>0: 464 | batch_num = min(len(queue),cell_batch_num) 465 | point_list = [] 466 | cell_coords = [] 467 | for i in range(batch_num): 468 | point = queue.pop(0) 469 | point_list.append(point) 470 | cell_coords.append(self.cell_coords[point[0]-1,point[1]-1,point[2]-1]) 471 | cell_coords = np.concatenate(cell_coords, axis=0) 472 | cell_coords = np.expand_dims(cell_coords, axis=0) 473 | cell_coords = torch.from_numpy(cell_coords) 474 | cell_coords = cell_coords.to(self.device) 475 | _, model_out_batch_ = self.im_network(None, z, cell_coords, is_training=False) 476 | model_out_batch = model_out_batch_.detach().cpu().numpy()[0] 477 | for i in range(batch_num): 478 | point = point_list[i] 479 | model_out = model_out_batch[i*cell_batch_size:(i+1)*cell_batch_size,0] 480 | x_coords = self.cell_x+(point[0]-1)*dimc 481 | y_coords = self.cell_y+(point[1]-1)*dimc 482 | z_coords = self.cell_z+(point[2]-1)*dimc 483 | model_float[x_coords+1,y_coords+1,z_coords+1] = model_out 484 | 485 | if np.max(model_out)>self.sampling_threshold: 486 | for i in range(-1,2): 487 | pi = point[0]+i 488 | if 
pi<=0 or pi>dimf: continue
489 |                 for j in range(-1,2):
490 |                     pj = point[1]+j
491 |                     if pj<=0 or pj>dimf: continue
492 |                     for k in range(-1,2):
493 |                         pk = point[2]+k
494 |                         if pk<=0 or pk>dimf: continue
495 |                         if (frame_flag[pi,pj,pk] == 0):
496 |                             frame_flag[pi,pj,pk] = 1
497 |                             queue.append((pi,pj,pk))
498 |         return model_float
499 | 
500 |     #may introduce foldovers
501 |     def optimize_mesh(self, vertices, z, iteration = 3):
502 |         new_vertices = np.copy(vertices)
503 | 
504 |         new_vertices_ = np.expand_dims(new_vertices, axis=0)
505 |         new_vertices_ = torch.from_numpy(new_vertices_)
506 |         new_vertices_ = new_vertices_.to(self.device)
507 |         _, new_v_out_ = self.im_network(None, z, new_vertices_, is_training=False)
508 |         new_v_out = new_v_out_.detach().cpu().numpy()[0]
509 | 
510 |         for iter in range(iteration):
511 |             for i in [-1,0,1]:
512 |                 for j in [-1,0,1]:
513 |                     for k in [-1,0,1]:
514 |                         if i==0 and j==0 and k==0: continue
515 |                         offset = np.array([[i,j,k]],np.float32)/(self.real_size*6*2**iter)
516 |                         current_vertices = vertices+offset
517 |                         current_vertices_ = np.expand_dims(current_vertices, axis=0)
518 |                         current_vertices_ = torch.from_numpy(current_vertices_)
519 |                         current_vertices_ = current_vertices_.to(self.device)
520 |                         _, current_v_out_ = self.im_network(None, z, current_vertices_, is_training=False)
521 |                         current_v_out = current_v_out_.detach().cpu().numpy()[0]
522 |                         keep_flag = abs(current_v_out-self.sampling_threshold)<abs(new_v_out-self.sampling_threshold)
523 |                         keep_flag = keep_flag.astype(np.float32)
524 |                         new_vertices = current_vertices*keep_flag+new_vertices*(1-keep_flag)
525 |                         new_v_out = current_v_out*keep_flag+new_v_out*(1-keep_flag)
526 |             vertices = new_vertices
527 |         return vertices
--------------------------------------------------------------------------------
/modelSVR.py:
--------------------------------------------------------------------------------
489 |         assert frame_batch_num>0
490 | 
491 |         #get frame grid values
492 |         for i in range(frame_batch_num):
493 |             point_coord = self.frame_coords[i*self.test_point_batch_size:(i+1)*self.test_point_batch_size]
494 |             point_coord = np.expand_dims(point_coord, axis=0)
495 |             point_coord = torch.from_numpy(point_coord)
496 |             point_coord = point_coord.to(self.device)
497 |             _, model_out_ = self.im_network(None, z, point_coord, is_training=False)
498 |             model_out = model_out_.detach().cpu().numpy()[0]
499 |             x_coords = self.frame_x[i*self.test_point_batch_size:(i+1)*self.test_point_batch_size]
500 |             y_coords = self.frame_y[i*self.test_point_batch_size:(i+1)*self.test_point_batch_size]
501 |             z_coords = self.frame_z[i*self.test_point_batch_size:(i+1)*self.test_point_batch_size]
502 |             frame_flag[x_coords+1,y_coords+1,z_coords+1] = np.reshape((model_out>self.sampling_threshold).astype(np.uint8), [self.test_point_batch_size])
503 | 
504 |         #get queue and fill up ones
505 |         for i in range(1,dimf+1):
506 |             for j in range(1,dimf+1):
507 |                 for k in range(1,dimf+1):
508 |                     maxv = np.max(frame_flag[i-1:i+2,j-1:j+2,k-1:k+2])
509 |                     minv = np.min(frame_flag[i-1:i+2,j-1:j+2,k-1:k+2])
510 |                     if maxv!=minv:
511 |                         queue.append((i,j,k))
512 |                     elif maxv==1:
513 |                         x_coords = self.cell_x+(i-1)*dimc
514 |                         y_coords = self.cell_y+(j-1)*dimc
515 |                         z_coords = self.cell_z+(k-1)*dimc
516 |                         model_float[x_coords+1,y_coords+1,z_coords+1] = 1.0
517 | 
518 |         print("running queue:",len(queue))
519 |         cell_batch_size = dimc**3
520 |         cell_batch_num = int(self.test_point_batch_size/cell_batch_size)
521 |         assert cell_batch_num>0
522 |         #run queue
523 |         while len(queue)>0:
524 |             batch_num = min(len(queue),cell_batch_num)
525 |             point_list = []
526 |             cell_coords = []
527 |             for i in range(batch_num):
528 |                 point = queue.pop(0)
529 |                 point_list.append(point)
530 |                 cell_coords.append(self.cell_coords[point[0]-1,point[1]-1,point[2]-1])
531 |             cell_coords = np.concatenate(cell_coords, axis=0)
532 |             cell_coords = np.expand_dims(cell_coords, axis=0)
533 |             cell_coords = torch.from_numpy(cell_coords)
534 |             cell_coords = cell_coords.to(self.device)
535 |             _, model_out_batch_ = self.im_network(None, z, cell_coords, is_training=False)
536 |             model_out_batch = model_out_batch_.detach().cpu().numpy()[0]
537 |             for i in range(batch_num):
538 |                 point = point_list[i]
539 |                 model_out = model_out_batch[i*cell_batch_size:(i+1)*cell_batch_size,0]
540 |                 x_coords = self.cell_x+(point[0]-1)*dimc
541 |                 y_coords = self.cell_y+(point[1]-1)*dimc
542 |                 z_coords = self.cell_z+(point[2]-1)*dimc
543 |                 model_float[x_coords+1,y_coords+1,z_coords+1] = model_out
544 | 
545 |                 if np.max(model_out)>self.sampling_threshold:
546 |                     for i in range(-1,2):
547 |                         pi = point[0]+i
548 |                         if pi<=0 or pi>dimf: continue
549 |                         for j in range(-1,2):
550 |                             pj = point[1]+j
551 |                             if pj<=0 or pj>dimf: continue
552 |                             for k in range(-1,2):
553 |                                 pk = point[2]+k
554 |                                 if pk<=0 or pk>dimf: continue
555 |                                 if (frame_flag[pi,pj,pk] == 0):
556 |                                     frame_flag[pi,pj,pk] = 1
557 |                                     queue.append((pi,pj,pk))
558 |         return model_float
559 | 
560 |     #may introduce foldovers
561 |     def optimize_mesh(self, vertices, z, iteration = 3):
562 |         new_vertices = np.copy(vertices)
563 | 
564 |         new_vertices_ = np.expand_dims(new_vertices, axis=0)
565 |         new_vertices_ = torch.from_numpy(new_vertices_)
566 |         new_vertices_ = new_vertices_.to(self.device)
567 |         _, new_v_out_ = self.im_network(None, z, new_vertices_, is_training=False)
568 |         new_v_out = new_v_out_.detach().cpu().numpy()[0]
569 | 
570 |         for iter in range(iteration):
571 |             for i in [-1,0,1]:
572 |                 for j in [-1,0,1]:
573 |                     for k in [-1,0,1]:
574 |                         if i==0 and j==0 and k==0: continue
575 |                         offset = np.array([[i,j,k]],np.float32)/(self.real_size*6*2**iter)
576 |                         current_vertices = vertices+offset
577 |                         current_vertices_ = np.expand_dims(current_vertices, axis=0)
578 |                         current_vertices_ = torch.from_numpy(current_vertices_)
579 |                         current_vertices_ = current_vertices_.to(self.device)
580 |                         _, current_v_out_ = self.im_network(None, z, current_vertices_, is_training=False)
581 |                         current_v_out = current_v_out_.detach().cpu().numpy()[0]
582 |                         keep_flag = abs(current_v_out-self.sampling_threshold)<abs(new_v_out-self.sampling_threshold)
583 |                         keep_flag = keep_flag.astype(np.float32)
584 |                         new_vertices = current_vertices*keep_flag+new_vertices*(1-keep_flag)
585 |                         new_v_out = current_v_out*keep_flag+new_v_out*(1-keep_flag)
586 |             vertices = new_vertices
587 |         return vertices
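#A hedged usage sketch, not part of the original file: how z2voxel and
#optimize_mesh above are typically chained to turn a latent code into an
#output mesh, following the pattern of test_1 in modelAE.py. The names
#"model_z" (a latent code) and "name" (an output file name) are assumed.
#
#    model_float = self.z2voxel(model_z)
#    vertices, triangles = mcubes.marching_cubes(model_float, self.sampling_threshold)
#    vertices = (vertices.astype(np.float32)-0.5)/self.real_size-0.5
#    #vertices = self.optimize_mesh(vertices, model_z) #optional smoothing, see README
#    write_ply_triangle(config.sample_dir+"/"+name+".ply", vertices, triangles)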