├── ShapeNet.md
├── ShapeNet.json
├── shapenet_1000.json
├── utils.py
├── binvox2mat.py
├── README.md
├── generate3DR2N2split.py
├── crop_images.py
├── ConversionTestCase.py
├── evaluate.py
├── vis_objects.py
├── DatasetLoader.py
├── ResNet.py
├── DatasetCollector.py
├── grad_approx_fun
│   ├── finite_diff_torch.py
│   └── finite_diff_torch_faster.py
├── test.py
└── train.py

--------------------------------------------------------------------------------
/ShapeNet.md:
--------------------------------------------------------------------------------
# Erroneous Files

03624134/67ada28ebc79cc75a056f196c127ed77/model.obj
04090263/4a32519f44dc84aabafe26e2eb69ebf4/model.obj
04074963/b65b590a565fa2547e1c85c5c15da7fb/model.obj

--------------------------------------------------------------------------------
/ShapeNet.json:
--------------------------------------------------------------------------------
{
    "04256520": {
        "id": "04256520",
        "name": "sofa,couch,lounge"
    },
    "03001627": {
        "id": "03001627",
        "name": "chair"
    },
    "04401088": {
        "id": "04401088",
        "name": "telephone,phone,telephone set"
    },
    "03691459": {
        "id": "03691459",
        "name": "loudspeaker,speaker,speaker unit,loudspeaker system,speaker system"
    },
    "02958343": {
        "id": "02958343",
        "name": "car,auto,automobile,machine,motorcar"
    }
}
--------------------------------------------------------------------------------
/shapenet_1000.json:
--------------------------------------------------------------------------------
{
    "02876657": {
        "id": "02876657",
        "name": "bottle"
    },
    "04256520": {
        "id": "04256520",
        "name": "sofa,couch,lounge"
    },
    "03001627": {
        "id": "03001627",
        "name": "chair"
    },
    "04401088": {
        "id": "04401088",
        "name": "telephone,phone,telephone set"
    },
    "03691459": {
        "id": "03691459",
        "name": "loudspeaker,speaker,speaker unit,loudspeaker system,speaker system"
    },
    "02958343": {
        "id": "02958343",
        "name": "car,auto,automobile,machine,motorcar"
    }
}
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
'''
This code is based on the Matryoshka [1] repository [2] and was modified accordingly:

[1] https://arxiv.org/abs/1804.10975

[2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/

Copyright (c) 2018, Visual Inference Lab @TU Darmstadt
'''

import os

def convert_files(dir, in_ext, action, recursive=True, exclude_ext=None):
    """ Traverse directory recursively to convert files.
        If recursive==False, only files in the directory dir are converted.
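        Hypothetical example (any callable works as the action):

            convert_files('./data', '.binvox', print, recursive=True)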
16 | """ 17 | files = sorted(os.listdir(dir)) 18 | for file in files: 19 | path = os.path.join(dir, file) 20 | if os.path.isdir(path): 21 | if recursive: 22 | convert_files(path, in_ext, action, recursive) 23 | pass 24 | pass 25 | elif path.endswith(in_ext): 26 | if exclude_ext is None or not path.endswith(exclude_ext): 27 | action(path) 28 | pass 29 | pass 30 | pass 31 | pass 32 | -------------------------------------------------------------------------------- /binvox2mat.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is based on the Matryoshka [1] repository [2] and was modified accordingly: 3 | 4 | [1] https://arxiv.org/abs/1804.10975 5 | 6 | [2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/ 7 | 8 | Copyright (c) 2018, Visual Inference Lab @TU Darmstadt 9 | ''' 10 | 11 | from __future__ import print_function 12 | from utils import convert_files 13 | import scipy.io as sio 14 | import os 15 | import numpy as np 16 | import argparse 17 | 18 | def read_binvoxfile_as_3d_array(*args): 19 | # This file requires some external library function that reads binvox files 20 | # and returns the loaded shapes as a numpy 3D array. 21 | raise NotImplementedError 22 | 23 | 24 | def convert_binvox(filename): 25 | """ Converts a single binvox file to Matlab. 26 | The resulting mat file stores one variable 'voxel'. 27 | """ 28 | print('Converting %s ... ' % filename, end='') 29 | try: 30 | with open(filename, 'rb') as f: 31 | md = read_binvoxfile_as_3d_array(f) 32 | v = np.array(md.data, dtype='uint8') 33 | sio.savemat(filename[:-7]+'.vox.mat', {'voxel':v[::-1,::-1,::-1].copy()}, do_compression=True) 34 | pass 35 | pass 36 | except: 37 | print('failed.') 38 | return 39 | print('done.') 40 | pass 41 | 42 | 43 | if __name__ == '__main__': 44 | 45 | parser = argparse.ArgumentParser('Converts .binvox files to .mat files.') 46 | parser.add_argument('directory', type=str, help='Directory with binvox files.', default='.') 47 | parser.add_argument('-r', '--recursive', action='store_true', help='Recursively traverses the directory and converts all binvox files.') 48 | 49 | args = parser.parse_args() 50 | 51 | convert_files(args.directory, '.binvox', convert_binvox, args.recursive) 52 | pass 53 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Implemetation of the paper: 2 | **Implicit Surface Representations as Layers in Neural Networks** 3 | 4 | Michalkiewicz M, Pontes K, Jack D, Baktashmotlagh M, Eriksson A. In ICCV 2019. 

[link](http://openaccess.thecvf.com/content_ICCV_2019/papers/Michalkiewicz_Implicit_Surface_Representations_As_Layers_in_Neural_Networks_ICCV_2019_paper.pdf)

-----------------------
Dependencies:
-----------------------

To run the code, install the following packages in a conda environment:

```
conda create -n dls python=3.7
source activate dls
conda install scipy pillow trimesh numpy
conda install -c conda-forge scikit-fmm
conda install pytorch torchvision -c pytorch
```


----------------------------------------------------------------------
General notes
----------------------------------------------------------------------
**The code is largely based on the Matryoshka [1] repository [2] and was modified accordingly.**

The 2D encoder used is based on the Matryoshka paper [1]; however, any other encoder
should give similar results.

The very simple 3D decoder used is based on the TL paper [3]; any other
3D decoder should give similar (most likely better) results.

------------
Datasets
------------
We have used 3D models from ShapeNetCore.v1.

2D input images are expected to have a shape of 128x128.

To process standard 3D-R2N2 [4] views, use `crop_images.py`.

3D ground truth should be signed distance functions of watertight manifolds of shape
32x32x32. Watertight manifolds can be obtained with the Manifold code [5].

Datasets are loaded using DatasetCollector.py and DatasetLoader.py.

-------------------------
References
---------------------------

[1] https://arxiv.org/abs/1804.10975

[2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/

[3] https://arxiv.org/abs/1603.08637

[4] https://arxiv.org/abs/1604.00449

[5] https://github.com/hjwdzh/Manifold
--------------------------------------------------------------------------------
/generate3DR2N2split.py:
--------------------------------------------------------------------------------
'''
This code is based on the Matryoshka [1] repository [2] and was modified accordingly:

[1] https://arxiv.org/abs/1804.10975

[2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/

Copyright (c) 2018, Visual Inference Lab @TU Darmstadt
'''

import os
import json
from collections import OrderedDict




def id_to_name(id, category_list):
    for k, v in category_list.items():
        if v[0] <= id and v[1] > id:
            return (k, id - v[0])




def category_model_id_pair(dataset_portion=(0.0, 1.0)):
    '''
    Load category and model names from a ShapeNet dataset.
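    Hypothetical example: category_model_id_pair([0, 0.8]) returns
    (synset_id, model_id) pairs for the first 80% of models in each category.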
    '''

    def model_names(model_path):
        """ Return model names"""
        model_names = [name for name in os.listdir(model_path)
                       if os.path.isdir(os.path.join(model_path, name))]
        return sorted(model_names)

    category_name_pair = []  # full path of the objs files

    cats = json.load(open('ShapeNet.json'))

    cats = OrderedDict(sorted(cats.items(), key=lambda x: x[0]))

    for k, cat in cats.items():  # load by categories
        model_path = os.path.join('/home/matryoshka/matryoshka/data/ShapeNetVox32', cat['id'])

        models = model_names(model_path)
        num_models = len(models)

        portioned_models = models[int(num_models * dataset_portion[0]):int(num_models *
                                                                           dataset_portion[1])]

        category_name_pair.extend([(cat['id'], model_id) for model_id in portioned_models])

    return category_name_pair



with open('3dr2n2-train.txt', 'w') as f:
    for synset, model in category_model_id_pair([0, 0.8]):
        f.write('%s/%s\n' % (synset, model))



with open('3dr2n2-test.txt', 'w') as f:
    for synset, model in category_model_id_pair([0.8, 1]):
        f.write('%s/%s\n' % (synset, model))
--------------------------------------------------------------------------------
/crop_images.py:
--------------------------------------------------------------------------------
'''
This code is based on the Matryoshka [1] repository [2] and was modified accordingly:

[1] https://arxiv.org/abs/1804.10975

[2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/

Copyright (c) 2018, Visual Inference Lab @TU Darmstadt
'''

from utils import convert_files
from math import ceil, floor
import PIL.Image
import numpy as np
import argparse
import os

def load_image(temp):
    # Only process if image has transparency (http://stackoverflow.com/a/1963146)
    if temp.mode == 'RGBA':
        alpha = temp.split()[-1]
        bg = PIL.Image.new("RGBA", temp.size, (255,255,255) + (255,))
        bg.paste(temp, mask=alpha)
        im = bg.convert('RGB').copy()
        bg.close()
        temp.close()
    else:
        im = temp.copy()
        temp.close()
    return im

def make_crop_func(size):
    def croptox(path):
        pad = 100
        img = load_image(PIL.Image.open(path))
        a = np.asarray(img)
        bg_mask = np.min(a, axis=2) == 255
        cols = np.flatnonzero(np.logical_not(np.min(bg_mask, axis=0)))
        rows = np.flatnonzero(np.logical_not(np.min(bg_mask, axis=1)))
        p = ceil(np.max([cols[-1]-cols[0], rows[-1]-rows[0]]) / 2)
        imp = np.pad(a, ((pad, pad), (pad, pad), (0,0)), 'constant', constant_values=255)
        row_offset = pad+floor((rows[0]+rows[-1])/2)
        col_offset = pad+floor((cols[0]+cols[-1])/2)
        i2 = imp[row_offset-p:row_offset+p+1, \
                 col_offset-p:col_offset+p+1,:]

        i2 = PIL.Image.fromarray(i2).resize((size, size), PIL.Image.LANCZOS)
        i2.save(path[:-4] + '.%d.png' % size, 'PNG')
        pass
    return croptox


if __name__ == '__main__':

    parser = argparse.ArgumentParser('Crop images to certain size.')
    parser.add_argument('directory', type=str, help='Directory with PNG files.', default='.')
    parser.add_argument('-s', '--size', type=int, default=128, help='Target size of images (width in pixels).')
    parser.add_argument('-r', '--recursive', action='store_true', help='Recursively traverses the directory.')

    args = parser.parse_args()

    convert_files(args.directory, '.png', make_crop_func(args.size), args.recursive, '.%s.png' % args.size)
    pass
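# Typical invocation (hypothetical paths), writing a `<name>.128.png` crop next
# to each source image:
#
#   python crop_images.py ./ShapeNetRendering -s 128 -r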
--------------------------------------------------------------------------------
/ConversionTestCase.py:
--------------------------------------------------------------------------------
'''
This code is based on the Matryoshka [1] repository [2] and was modified accordingly:

[1] https://arxiv.org/abs/1804.10975

[2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/

Copyright (c) 2018, Visual Inference Lab @TU Darmstadt
'''

import unittest

import numpy as np
import torch

# for loading voxel grid from MATLAB file
import scipy.io as sio

import voxel2layer as v2lnp
import voxel2layer_torch as v2lt

class ConversionTestCase(unittest.TestCase):

    def setUp(self):
        # load sample shape as voxel representation
        d = sio.loadmat('./data/chair.mat')
        self.voxelnp = d['voxel']
        self.voxelt = torch.from_numpy(self.voxelnp)
        self.shlnp = d['shl']
        self.shlt = torch.from_numpy(self.shlnp.astype(np.int32)).to(torch.int16)
        pass

    def tearDown(self):
        pass

    def test_encode_numpy(self):
        """ Tests converting from voxel to shape layer representation using numpy. """
        shl = v2lnp.encode_shape(self.voxelnp)
        self.assertTrue(np.all(shl == self.shlnp))
        pass

    def test_decode_numpy(self):
        """ Tests converting from shape layer to voxel representation using numpy. """
        voxel = v2lnp.decode_shape(self.shlnp)
        # sio.savemat('decode_numpy.mat', {'voxel':voxel, 'gt':self.voxelnp})
        self.assertTrue(np.all(voxel == self.voxelnp))
        pass

    def test_encode_torch(self):
        """ Tests converting from voxel to shape layer representation using torch. """
        shl = v2lt.encode_shape(self.voxelt)
        # sio.savemat('encode_torch.mat', {'shl':shl.numpy(), 'gt':self.shlt.numpy()})
        self.assertTrue((shl == self.shlt).all())
        pass

    def test_decode_torch(self):
        """ Tests converting from shape layer to voxel representation using torch. """
        voxel = v2lt.decode_shape(self.shlt)
        # sio.savemat('decode_torch.mat', {'voxel':voxel.numpy(), 'gt':self.voxelt.numpy()})
        self.assertTrue((voxel == self.voxelt).all())
        pass

    def test_roundtrip_numpy(self):
        """ Tests converting from voxel to shape layer and back using numpy."""
        voxel = v2lnp.decode_shape(v2lnp.encode_shape(self.voxelnp, 2))
        self.assertTrue(np.all(voxel == self.voxelnp))
        pass

    def test_roundtrip_torch(self):
        """ Tests converting from voxel to shape layer and back using torch."""
        voxel = v2lt.decode_shape(v2lt.encode_shape(self.voxelt, 2))
        self.assertTrue((voxel == self.voxelt).all())
        pass

    def test_shapelayer_conversion(self):
        """ Tests modifying the shape layer representation for better alignment.
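            The round trip shl -> shlx -> shl must reproduce the original
            layers exactly.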
77 | """ 78 | shlx = v2lt.shl2shlx(self.shlt.clone().permute(2,0,1).reshape(1,6,128,128)) 79 | shl = v2lt.shlx2shl(shlx).reshape(6,128,128).permute(1,2,0) 80 | self.assertTrue((shl == self.shlt).all()) 81 | pass 82 | 83 | if __name__ == '__main__': 84 | unittest.main() 85 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is based on the Matryoshka [1] repository [2] and was modified accordingly: 3 | 4 | [1] https://arxiv.org/abs/1804.10975 5 | 6 | [2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/ 7 | 8 | Copyright (c) 2018, Visual Inference Lab @TU Darmstadt 9 | ''' 10 | 11 | import numpy as np 12 | import scipy.io as sio 13 | import argparse 14 | import os 15 | 16 | from voxel2layer import decode_shape, generate_indices 17 | from DatasetCollector import * 18 | 19 | def evaluate_sample(file, shape_layer, id1, id2, id3): 20 | 21 | d = sio.loadmat(file) 22 | voxel = d['voxel'] 23 | 24 | decoded = decode_shape(shape_layer, id1, id2, id3) 25 | 26 | return np.sum(np.logical_and(voxel, decoded)) / np.sum(np.logical_or(voxel, decoded)) 27 | 28 | 29 | if __name__ == '__main__': 30 | 31 | name2dataset = {'ShapeNetPTN':ShapeNetPTNCollector, \ 32 | 'ShapeNetCars':ShapeNetCarsOGNCollector, \ 33 | 'ShapeNet':ShapeNet3DR2N2Collector} 34 | dataset_default = 'ShapeNet' 35 | 36 | parser = argparse.ArgumentParser('Evaluate results') 37 | parser.add_argument('result_dir', type=str, default='.', help='Directory with result batches.') 38 | parser.add_argument('--dataset', type=str, default=dataset_default, help=('Dataset [%s]' % ','.join(name2dataset.keys()))) 39 | parser.add_argument('--basedir', type=str, default='./data/', help='Base directory for dataset.') 40 | parser.add_argument('--set', type=str, choices=['val', 'test'], help='Subset to evaluate. 
if __name__ == '__main__':

    name2dataset = {'ShapeNetPTN':ShapeNetPTNCollector, \
                    'ShapeNetCars':ShapeNetCarsOGNCollector, \
                    'ShapeNet':ShapeNet3DR2N2Collector}
    dataset_default = 'ShapeNet'

    parser = argparse.ArgumentParser('Evaluate results')
    parser.add_argument('result_dir', type=str, default='.', help='Directory with result batches.')
    parser.add_argument('--dataset', type=str, default=dataset_default, help=('Dataset [%s]' % ','.join(name2dataset.keys())))
    parser.add_argument('--basedir', type=str, default='./data/', help='Base directory for dataset.')
    parser.add_argument('--set', type=str, choices=['val', 'test'], help='Subset to evaluate. (default: val)', default='val')
    args = parser.parse_args()

    collector = name2dataset[args.dataset](base_dir=args.basedir)

    if args.set == 'val':
        samples = collector.val()
    elif args.set == 'test':
        samples = collector.test()

    batchfiles = [os.path.join(args.result_dir, f) \
                  for f in sorted(os.listdir(args.result_dir)) \
                  if f.startswith('b_') and f.endswith('.mat')]

    agg_iou = 0
    count = 0
    num_samples = len(samples)

    with open(os.path.join(args.result_dir, 'results.txt'), 'w') as f:
        for batchfile in batchfiles:
            batch = sio.loadmat(batchfile)['results']
            for i in range(batch.shape[0]):
                if count == 0:
                    side = batch.shape[2]
                    id1, id2, id3 = generate_indices(side)
                    pass

                shape_layer = np.transpose(batch[i,:,:,:], axes=[1,2,0])
                iou = evaluate_sample(samples[count][1][:-8]+'.vox.mat', shape_layer, id1, id2, id3)
                agg_iou += iou

                f.write('%d,%s,%.1f\n' % (count, samples[count][1][:-8], 100*iou))

                count += 1

                if count % 100 == 0:
                    print('Mean %.1f (%d/%d)' % (100*agg_iou / count, count, num_samples))
                    pass
                pass
            pass
        f.write('Mean,%.1f\n' % (100*agg_iou / count))
        pass
    pass
--------------------------------------------------------------------------------
/vis_objects.py:
--------------------------------------------------------------------------------
from time import sleep
from mayavi import mlab
from matplotlib import pyplot as plt
from binvox_rw import read_as_3d_array
import numpy as np
from trimesh import load_mesh
import trimesh
from skimage.measure import marching_cubes_lewiner as mc

from os import listdir, system
from os.path import join

color = tuple(np.asarray((38, 139, 210))/255.0)

def merge_mesh(mesh_list):
    vertices = mesh_list[0].vertices
    faces = mesh_list[0].faces
    for i in range(len(mesh_list))[1:]:
        nr_of_verts = vertices.shape[0]
        new_verts = mesh_list[i].vertices
        new_faces = mesh_list[i].faces
        vertices = np.append(vertices, new_verts).reshape(-1,3)
        new_faces += nr_of_verts
        faces = np.append(faces, new_faces).reshape(-1,3)
    new_mesh = trimesh.Trimesh(vertices, faces)
    return new_mesh

def vis_pc(pc, **kwargs):
    ''' vis point cloud of shape X,3
        use scale_factor=0.1 for example as arg
    '''
    assert len(pc.shape) == 2
    assert pc.shape[1] == 3
    mlab.points3d(pc[:, 0], pc[:, 1], pc[:, 2], **kwargs)

def vis_mesh(
        vertices, faces, axis_order='zyx', include_wireframe=True,
        color=color, distance=2., **kwargs):
    if len(faces) == 0:
        print('Warning: no faces')
        return
    fig_data = mlab.figure(size=(800, 600), bgcolor=(1,1,1), fgcolor=None, engine=None)
    x, y, z = permute_xyz(*vertices.T, order=axis_order)
    mlab.triangular_mesh(x, y, z, faces, color=color, **kwargs)
    mlab.view(distance=distance, focalpoint='auto', roll=2)

def vis_sdf(sdf, distance=56., threshold=0.0, **kwargs):
    ''' extracts levelset and uses vis_mesh'''
    levelset = mc(sdf, threshold)
    vis_mesh(vertices=levelset[0], faces=levelset[1], distance=distance,
             **kwargs)
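# Illustrative helper (not in the original file): visualize the zero level set
# of a sphere SDF with vis_sdf; assumes a working mayavi installation.
def demo_sphere_sdf(grid_size=32):
    x, y, z = np.mgrid[-1:1:grid_size*1j, -1:1:grid_size*1j, -1:1:grid_size*1j]
    sdf = np.sqrt(x**2 + y**2 + z**2) - 0.5   # signed distance to a sphere of radius 0.5
    vis_sdf(sdf, threshold=0.0)
    mlab.show()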
def vis_voxels(voxels, axis_order='xzy', color=color, **kwargs):
    fig_data = mlab.figure(size=(800, 600), bgcolor=(1,1,1), fgcolor=None, engine=None)
    data = permute_xyz(*np.where(voxels), order=axis_order)
    if len(data[0]) == 0:
        # raise ValueError('No voxels to display')
        print('Warning: no voxels to display')
    else:
        kwargs.setdefault('mode', 'cube')
        mlab.points3d(*data, color=color, **kwargs)

def vis_view(view_np):
    ''' views view. assumes a np array'''
    if view_np.shape[0] == 3:
        new_view\
            = np.zeros((view_np.shape[1], view_np.shape[2], view_np.shape[0]))
        new_view[:,:,0] = view_np[0,:]
        new_view[:,:,1] = view_np[1,:]
        new_view[:,:,2] = view_np[2,:]
        view_np = new_view
    plt.imshow(view_np)
    plt.show()

def show():
    mlab.show()

def permute_xyz(x, y, z, order='xyz'):
    _dim = {'x': 0, 'y': 1, 'z': 2}
    data = (x, y, z)
    return tuple(data[_dim[k]] for k in order)

def print_voxels_to_file(vxls, path_to_save_data):
    vis_voxels(vxls)
    mlab.savefig(path_to_save_data)
    sleep(2)
    mlab.close()

def print_sdf_to_file(sdf, path_to_save_data, threshold=0.0, distance=56.):
    vertices, faces, _, _ = mc(sdf, threshold)
    print_mesh_to_file(vertices, faces, path_to_save_data, distance=distance)

def print_mesh_to_file(verts, faces, path_to_save_data, distance=56.):
    vis_mesh(verts, faces, distance=distance)
    mlab.savefig(path_to_save_data)
    sleep(2)
    mlab.close()
--------------------------------------------------------------------------------
/DatasetLoader.py:
--------------------------------------------------------------------------------
'''
This code is based on the Matryoshka [1] repository [2] and was modified accordingly:

[1] https://arxiv.org/abs/1804.10975

[2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/

Copyright (c) 2018, Visual Inference Lab @TU Darmstadt
'''

from __future__ import print_function
import torch.utils.data as data
from PIL import Image, ImageOps
import os
import torch
import numpy as np
from torchvision import transforms
import scipy.io as sio

from binvox_rw import read_as_3d_array
import skfmm

class RandomColorFlip(object):
    def __call__(self, img):
        c = np.random.choice(3, 3, np.random.random() < 0.5)
        return img[c,:,:]

class DatasetLoader(data.Dataset):
    def __init__(self, samples, side, num_comp=1, input_transform=None, no_images=False, no_shapes=False):

        self.input_transform = input_transform
        self.num_comp = num_comp
        self.samples = samples
        self.side = side
        self.no_images = no_images
        self.no_shapes = no_shapes
        pass

    def __getitem__(self, index):

        imagepath = self.samples[index][0]
        shapepath = self.samples[index][1]
        example_id = self.samples[index][2]

        flipped = False  # random.random() > 0.5 if self.flip else False

        if self.no_images:
            imgs = None
        else:
            imgs = self.input_transform(self._load_image(Image.open(imagepath), flipped))

        if self.no_shapes:
            shape = None
        else:
            if shapepath.endswith('.shl.mat'):  # shape layer
                shape = self._load_shl(shapepath)
            elif shapepath.endswith('.vox.mat'):  # voxels in mat format
                shape = self._load_vox2(shapepath)
            elif shapepath.endswith('.binvox'):  # voxels in binvox format
                shape = self._load_binvox(shapepath)
            elif ('sdf_' in shapepath) and (shapepath.endswith('.npy')):
                shape = self._load_sdf(shapepath)
            elif ('vox_' in shapepath) and (shapepath.endswith('.npy')):
                shape = self._load_vox(shapepath)
            elif 'dist' in shapepath:  # chamfer
                shape = self._load_dist(shapepath)
            else:
                assert False, ('Could not determine shape representation from file name (%s has neither ".shl.mat" nor ".vox.mat").' % shapepath)

        if self.no_images:
            if self.no_shapes:
                return
            else:
                return shape
        else:
            if self.no_shapes:
                return imgs
            else:
                return imgs, shape, example_id


    def __len__(self):
        return len(self.samples)

    def _load_dist(self, path):
        dist = np.load(path)
        assert dist.shape == (self.side**3,)
        return torch.from_numpy(dist.astype('float32'))

    def _load_vox(self, path):
        vox = np.load(path)
        if len(vox.shape) != 3:  # voxels are flattened
            vox = vox.squeeze()
            assert len(vox.shape) == 1
            dim = np.cbrt(vox.shape[0])
            assert dim == int(dim)
            vox = vox.reshape((int(dim),)*3).astype('int')
        return torch.from_numpy(vox.astype('int'))

    def _load_sdf(self, path):
        sdf = np.load(path)
        if len(sdf.shape) != 3:  # sdf is flattened
            sdf = sdf.squeeze()
            assert len(sdf.shape) == 1
            dim = np.cbrt(sdf.shape[0])
            assert dim == int(dim)
            sdf = sdf.reshape((int(dim),)*3).astype('float32')
        return torch.from_numpy(sdf.astype('float32'))


    def _load_vox2(self, path):
        d = sio.loadmat(path)
        return torch.from_numpy(d['voxel'])

    def _load_binvox(self, path):
        ''' loads voxels saved in binvox format. see also _load_vox '''
        if not os.path.exists(path):
            raise Exception('path does not exist: ' + path)
        with open(path, 'rb') as fin:
            voxels = read_as_3d_array(fin)
        return torch.from_numpy(voxels.data.astype('uint8'))

    def _load_shl(self, path):
        d = sio.loadmat(path)
        return torch.from_numpy(np.array(d['shapelayer'], dtype=np.int32)[:,:,:6*self.num_comp]).permute(2,0,1).contiguous().float()

    def _load_image(self, temp, flipped=False):
        if temp.mode == 'RGBA':
            alpha = temp.split()[-1]
            bg = Image.new("RGBA", temp.size, (128,128,128) + (255,))
            bg.paste(temp, mask=alpha)
            im = bg.convert('RGB').copy()
            bg.close()
            temp.close()
        else:
            im = temp.copy()
            temp.close()
        return (im.transpose(Image.FLIP_LEFT_RIGHT) if flipped else im)
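# Example wiring (illustrative, not part of the original file; paths and the
# category id are hypothetical). This mirrors how test.py builds its loader:
#
#   from DatasetCollector import ShapeNet3DR2N2Collector
#   collector = ShapeNet3DR2N2Collector(base_dir='./data', cat_id='03001627',
#                                       representation='sdf', side=32)
#   loader = data.DataLoader(
#       DatasetLoader(collector.train(), side=32,
#                     input_transform=transforms.Compose([transforms.ToTensor()])),
#       batch_size=16, shuffle=False)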
--------------------------------------------------------------------------------
/ResNet.py:
--------------------------------------------------------------------------------
'''
This code is based on the Matryoshka [1] repository [2] and was modified accordingly:

[1] https://arxiv.org/abs/1804.10975

[2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/

Copyright (c) 2018, Visual Inference Lab @TU Darmstadt
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from math import log, floor


class ResNet(nn.Module):

    def __init__(self, num_input_channels, num_output_channels, num_penultimate_channels, \
                 input_resolution, output_resolution, num_initial_channels=16, num_inner_channels=64, \
                 num_downsampling=3, num_blocks=6, bottleneck_dim=1024):

        assert num_blocks >= 0

        super(ResNet, self).__init__()

        relu = nn.ReLU(True)

        model = [nn.BatchNorm2d(num_input_channels, True)]  # TODO: the second positional argument of BatchNorm2d is eps, not affine

        # additional down and upsampling blocks to account for difference in input/output resolution
        num_additional_down = int(log(input_resolution / output_resolution, 2)) if output_resolution <= input_resolution else 0
        num_additional_up = int(log(output_resolution / input_resolution, 2)) if output_resolution > input_resolution else 0

        # number of channels to add during downsampling
        num_channels_down = int(floor(float(num_inner_channels - num_initial_channels)/(num_downsampling+num_additional_down)))

        # adjust number of initial channels
        num_initial_channels += (num_inner_channels-num_initial_channels) % num_channels_down

        # initial feature block
        model += [nn.ReflectionPad2d(1),
                  nn.Conv2d(num_input_channels, num_initial_channels, kernel_size=3, padding=0),
                  nn.BatchNorm2d(num_initial_channels),
                  relu]
        model += [nn.ReflectionPad2d(1),
                  nn.Conv2d(num_initial_channels, num_initial_channels, kernel_size=3, padding=0)]

        # downsampling
        for i in range(num_downsampling+num_additional_down):
            model += [ResDownBlock(num_initial_channels, num_channels_down)]
            model += [ResSameBlock(num_initial_channels+num_channels_down)]
            num_initial_channels += num_channels_down
            pass

        # inner blocks at constant resolution
        for i in range(num_blocks):
            model += [ResSameBlock(num_initial_channels)]
            pass


        model_encoder = model[:20].copy()  # model_encoder is bsx512x1x1
        self.model_enc = nn.Sequential(*model_encoder)

        # BOTTLENECK
        tmp_init_depth = 8  # nr of channels
        tmp_n_final = 6     # we will first resize to 6x6x6 with 8 channels

        enc_last_dim = num_initial_channels
        model_bottleneck = [nn.Linear(enc_last_dim, bottleneck_dim)]
        model_bottleneck += [nn.Linear(bottleneck_dim, tmp_init_depth*tmp_n_final**3)]
        self.model_bottleneck = nn.Sequential(*model_bottleneck)
        # DECODER
        model_decoder = []

        kernel_size_list = [3, 3, 3, 5, 5, 7, 7]
        channels_list = [8, 8, 16, 32, 16, 8, 4, 4]

        for deconv_iter in range(len(kernel_size_list)):
            model_decoder += [nn.ConvTranspose3d(channels_list[deconv_iter],
                                  channels_list[deconv_iter+1], kernel_size_list[deconv_iter],
                                  stride=1, padding=0, output_padding=0, groups=1, bias=True,
                                  dilation=1),
                              nn.BatchNorm3d(channels_list[deconv_iter+1], eps=1e-05,
                                  momentum=0.1, affine=True, track_running_stats=True),
                              nn.ReLU(True)]

        model_decoder += [nn.ConvTranspose3d(channels_list[-1],
                              1, 1,
                              stride=1, padding=0, output_padding=0, groups=1, bias=True,
                              dilation=1)]
        self.model_decoder = nn.Sequential(*model_decoder)
        return



    def forward(self, input):
        x = self.model_enc(input)  # x should be 32,512,1,1
        x = x.view(-1, x.shape[1])
        # bottleneck FC 1024, FC 6*6*6*8
        x = self.model_bottleneck(x)
        tmp_init_depth = 8  # nr of channels
        tmp_n_final = 6     # we will first resize to 6x6x6 with 8 channels
        x = x.view((-1,) + (tmp_init_depth,) + (tmp_n_final,)*3)
        # decoder
        x = self.model_decoder(x)
        x = x[:,0,:,:,:]  # remove the '1' in bsxcxWxDxH, c=1
        return x
    pass


class ResSameBlock(nn.Module):
    """ ResNet block for constant resolution.
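        Pre-activation layout (BN -> ReLU -> 3x3 conv, twice); the forward
        pass returns x + f(x), so resolution and channel count are preserved.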
117 | """ 118 | 119 | def __init__(self, dim): 120 | super(ResSameBlock, self).__init__() 121 | 122 | self.model = nn.Sequential(*[nn.BatchNorm2d(dim, True), \ 123 | nn.ReLU(True), 124 | nn.Conv2d(dim, dim, kernel_size=3, padding=1), 125 | nn.BatchNorm2d(dim), 126 | nn.ReLU(True), 127 | nn.Conv2d(dim, dim, kernel_size=3, padding=1)]) 128 | 129 | def forward(self, x): 130 | return x + self.model(x) 131 | pass 132 | 133 | 134 | class ResUpBlock(nn.Module): 135 | """ ResNet block for upsampling. 136 | """ 137 | 138 | def __init__(self, dim, num_up): 139 | super(ResUpBlock, self).__init__() 140 | 141 | self.model = nn.Sequential(*[nn.BatchNorm2d(dim, True),\ 142 | nn.ReLU(False), 143 | nn.ConvTranspose2d(dim, -num_up+dim, kernel_size=4, padding=1, stride=2), 144 | nn.BatchNorm2d(-num_up+dim, True), 145 | nn.ReLU(True), 146 | nn.Conv2d(-num_up+dim, -num_up+dim, kernel_size=3, padding=1)]) 147 | 148 | self.project = nn.Conv2d(dim,dim-num_up,kernel_size=1) 149 | pass 150 | 151 | def forward(self, x): 152 | # xu = F.upsample(x,scale_factor=2,mode='nearest') 153 | xu = F.interpolate(x, scale_factor=2, mode='nearest') 154 | bs,_,h,w = xu.size() 155 | return self.project(xu) + self.model(x) 156 | pass 157 | 158 | 159 | class ResDownBlock(nn.Module): 160 | """ ResNet block for downsampling. 161 | """ 162 | 163 | def __init__(self, dim, num_down): 164 | super(ResDownBlock, self).__init__() 165 | self.num_down = num_down 166 | 167 | self.model = nn.Sequential(*[nn.BatchNorm2d(dim, True), \ 168 | nn.ReLU(False), 169 | nn.Conv2d(dim, num_down+dim, kernel_size=3, padding=1, stride=2), 170 | nn.BatchNorm2d(num_down+dim, True), 171 | nn.ReLU(True), 172 | nn.Conv2d(num_down+dim, num_down+dim, kernel_size=3, padding=1)]) 173 | pass 174 | 175 | def forward(self, x): 176 | xu = x[:,:,::2,::2] 177 | bs,_,h,w = xu.size() 178 | return torch.cat([xu, x.new_zeros(bs, self.num_down, h, w, requires_grad=False)],1) + self.model(x) 179 | pass 180 | -------------------------------------------------------------------------------- /DatasetCollector.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is based on the Matryoshka [1] repository [2] and was modified accordingly: 3 | 4 | [1] https://arxiv.org/abs/1804.10975 5 | 6 | [2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/ 7 | 8 | Copyright (c) 2018, Visual Inference Lab @TU Darmstadt 9 | ''' 10 | 11 | import os 12 | import logging 13 | from ipdb import set_trace 14 | 15 | 16 | class DatasetCollector: 17 | 18 | def __init__(self): 19 | pass 20 | 21 | def classes(self): 22 | return [] 23 | 24 | def train(self, cls=None): 25 | return [] 26 | 27 | def val(self, cls=None): 28 | return [] 29 | 30 | def test(self, cls=None): 31 | return [] 32 | 33 | 34 | class SanityCollector(DatasetCollector): 35 | 36 | def __init__(self, *args, **kwargs): 37 | self.cls = ['chair'] 38 | 39 | def classes(self): 40 | return self.cls 41 | 42 | def _gather(self): 43 | return [('./data/model.128.png', './data/model.shl.mat')] 44 | 45 | def train(self, cls=None): 46 | return self._gather() 47 | 48 | def val(self, cls=None): 49 | return self._gather() 50 | 51 | def test(self, cls=None): 52 | return self._gather() 53 | 54 | 55 | class ShapeNetPTNCollector(DatasetCollector): 56 | """ Collects samples from ShapeNet using the version of Yan et al. 57 | """ 58 | 59 | def __init__(self, base_dir, crop=True): 60 | assert os.path.exists(base_dir), ('Base directory for PTN dataset does not exist [%s].' 
        self.base_dir = base_dir
        self.id_dir = os.path.join(self.base_dir, 'shapenetcore_ids')
        self.view_dir = os.path.join(self.base_dir, 'shapenetcore_viewdata')
        self.shape_dir = os.path.join(self.base_dir, 'shapenetcore_voxdata')
        self.crop = crop
        self.cls = []

        for c in sorted([d[:-12] for d in os.listdir(self.id_dir) if d.endswith('_testids.txt')]):
            if os.path.exists(os.path.join(self.id_dir, c+'_trainids.txt')) and \
               os.path.exists(os.path.join(self.id_dir, c+'_valids.txt')) and \
               os.path.exists(os.path.join(self.view_dir, c)) and \
               os.path.exists(os.path.join(self.shape_dir, c)):
                self.cls.append(c)
                pass
            pass
        pass

    def _gather(self, subset, cls=None):
        if cls is None:
            cls = self.classes()
            pass

        samples = []

        # NOTE: self.representation is never set in __init__; default to shape
        # layers ('shl'), which matches the '.shl.mat' files used elsewhere.
        shape_suffix = 'model.shl.mat' if getattr(self, 'representation', 'shl') == 'shl' else 'model.vox.mat'
        for c in cls:
            logging.info('Collecting %s/%s...' % (subset, c))
            with open(os.path.join(self.id_dir, '%s_%sids.txt' % (c, 'all'))) as f:
                for line in f:
                    # format is class/id
                    id = line.strip().split('/')[1]
                    shapepath = os.path.join(self.shape_dir, c, id, shape_suffix)
                    # check images
                    viewdir = os.path.join(self.view_dir, c, id)
                    for file in sorted(os.listdir(viewdir)):
                        if self.crop and file.endswith('.128.png'):
                            samples.append((os.path.join(viewdir, file), shapepath))
                            pass
                        if not self.crop and file.endswith('.png') and not file.endswith('.128.png'):
                            samples.append((os.path.join(viewdir, file), shapepath))
                            pass
                        pass
                    pass
                pass
            pass

        return samples

    def classes(self):
        return self.cls

    def train(self, cls=None):
        return self._gather('train', cls)

    def val(self, cls=None):
        return self._gather('val', cls)

    def test(self, cls=None):
        return self._gather('test', cls)
    pass

class BlendswapOGNCollector(DatasetCollector):

    def __init__(self, base_dir, resolution=512):
        res2dir = {64:'64_l4', 128:'128_l4', 256:'256_l5', 512:'512_l5'}
        self.base_dir = os.path.join(base_dir, res2dir[resolution])
        assert os.path.exists(self.base_dir), ('Base directory for OGN Blendswap dataset does not exist [%s].' % self.base_dir)
        pass

    def _gather(self, subset):
        # subset is ignored; the collection is not split
        samples = []
        shape_suffix = '.shl.mat'

        for file in sorted(os.listdir(self.base_dir)):
            if file.endswith(shape_suffix):
                samples.append(os.path.join(self.base_dir, file))
                pass
            pass

        return samples

    def classes(self):
        return None

    def train(self):
        return self._gather('all')

    def val(self):
        return self._gather('all')

    def test(self):
        return self._gather('all')
    pass


class ShapeNetCarsOGNCollector(DatasetCollector):
    """Assuming that text files with sample paths are in root dir."""
    def __init__(self, base_dir, shapenet_base_dir, resolution=128, crop=True):
        res2dir = {64:'64_l4', 128:'128_l4', 256:'256_l4'}
        self.base_dir = os.path.join(base_dir, res2dir[resolution])
        assert os.path.exists(self.base_dir), ('Base directory for OGN ShapeNet Cars dataset does not exist [%s].' % self.base_dir)
        self.shapenet_base_dir = shapenet_base_dir
        assert os.path.exists(self.shapenet_base_dir), ('ShapeNet rendering directory for OGN ShapeNet Cars dataset does not exist [%s].' % self.shapenet_base_dir)

        self.crop = crop

        for s in ['train', 'validation', 'test']:
            id_path = os.path.join(self.base_dir, 'shapenet_cars_rendered_new_%s.txt' % s)
            assert os.path.exists(id_path), ('Could not find id list for %s set [%s].' % (s, id_path))
            pass
        pass

    def classes(self):
        return ['car']

    def _gather(self, subset):
        samples = []
        # NOTE: shape_suffix was undefined in the original; '.shl.mat' is
        # assumed here, as in BlendswapOGNCollector.
        shape_suffix = '.shl.mat'
        with open(os.path.join(self.base_dir, 'shapenet_cars_rendered_new_%s.txt' % subset)) as f:
            for line in f:
                img_path, id = line.strip().split(' ')
                img_id = img_path.split('/')[-1]
                shapenet_id = img_path.split('/')[-3]
                img_path = os.path.join(self.shapenet_base_dir, '02958343', shapenet_id, \
                                        'rendering', img_id + ('.128.png' if self.crop else '.png'))
                shape_path = os.path.join(self.base_dir, id + shape_suffix)
                samples.append((img_path, shape_path))
                pass
            pass
        return samples

    def train(self, cls=None):
        return self._gather('train')

    def val(self, cls=None):
        return self._gather('validation')

    def test(self, cls=None):
        return self._gather('test')
    pass


class ShapeNet3DR2N2Collector(DatasetCollector):
    def __init__(self, base_dir, cat_id, representation='vox', side=32, p_norm=2):
        assert representation in ['vox', 'sdf', 'chamfer']
        if representation == 'vox':
            self.shape_dir = os.path.join(base_dir, 'vox')
        elif representation == 'sdf':
            self.shape_dir = os.path.join(base_dir, 'sdfs')
        elif representation == 'chamfer':
            self.shape_dir = os.path.join(base_dir, 'ShapeNetDist'+str(side))
        else:
            print('unknown representation')
            exit()
        self.view_dir = os.path.join(base_dir, 'ShapeNetRendering')
        self.list_dir = os.path.join(base_dir, 'ids')
        self.representation = representation
        self.side = side
        self.p_norm = p_norm

        self.cls = [cat_id]

    def classes(self):
        return self.cls
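    # Expected on-disk layout (inferred from _gather below; all names are as
    # the code constructs them):
    #   <base_dir>/ids/<cat_id>/<cat_id>_<subset>.txt               one model id per line
    #   <base_dir>/vox/<cat_id>/<side>/<id>/<id>_vox_<side>.npy
    #   <base_dir>/sdfs/<cat_id>/<side>/<id>/<id>_sdf_<side>_pad_1.npy
    #   <base_dir>/ShapeNetRendering/<cat_id>/<id>/rendering/*.128.png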
    def _gather(self, subset, cls=None):
        if cls is None:
            cls = self.classes()
            pass

        samples = []
        if self.representation == 'vox':
            shape_suffix = 'model.binvox'
        elif self.representation == 'sdf':
            shape_suffix = 'sdf_manifold_50000_grid_size_32.npy'
        elif self.representation == 'chamfer':
            shape_suffix\
                = 'actual_dist_grid_size_'+str(self.side)+'_p_'+str(self.p_norm)+'.npy'
        else:
            print('unknown shape suffix')
            exit()
        for c in cls:
            logging.info('Collecting %s/%s...' % (subset, c))
            # open example_ids.txt
            with open(os.path.join(self.list_dir, c, '%s_%s.txt' % (c, subset))) as f:
                for line in f:
                    # format is class/id
                    id = line.strip()
                    if self.representation == 'vox':
                        shapepath = os.path.join(self.shape_dir, c, str(self.side), id,
                                                 id+'_vox_'+str(self.side)+'.npy')
                    else:
                        shapepath = os.path.join(self.shape_dir, c, str(self.side), id,
                                                 id+'_sdf_'+str(self.side)+'_pad_1.npy')
                    # check images
                    viewdir = os.path.join(self.view_dir, c, id, 'rendering')
                    for file in sorted(os.listdir(viewdir)):
                        if file.endswith('.128.png') and (not file.endswith('128.128.png')):  # FIX
                            samples.append((os.path.join(viewdir, file), shapepath, id))
                            pass
                        if not file.endswith('.128.png'):  # FIX
                            pass
                        pass
                    pass
                pass
            pass

        return samples

    def train(self, cls=None):
        return self._gather('train', cls)

    def val(self, cls=None):
        return []

    def test(self, cls=None):
        return self._gather('test', cls)
    pass
--------------------------------------------------------------------------------
/grad_approx_fun/finite_diff_torch.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


def get_norm3d_batch(list_of_grad):
    assert type(list_of_grad) is list
    batch_size = len(list_of_grad)
    grid_size = list_of_grad[0][0].shape[0]
    norm_batch = np.zeros((batch_size, grid_size, grid_size, grid_size))
    for _, c_grad in enumerate(list_of_grad):
        c_norm = get_norm3d(c_grad)
        norm_batch[_] = c_norm
    return norm_batch

def get_norm3d_batch_torch(list_of_grad_torch):
    assert type(list_of_grad_torch) is list
    batch_size = len(list_of_grad_torch)
    grid_size = list_of_grad_torch[0][0].shape[0]
    norm_batch_torch = torch.zeros((batch_size, grid_size, grid_size, grid_size))
    for _, c_grad_t in enumerate(list_of_grad_torch):
        c_norm_t = get_norm3d_torch(c_grad_t)
        norm_batch_torch[_] = c_norm_t
    return norm_batch_torch


def get_norm3d(gradients):
    assert type(gradients) is tuple
    assert len(gradients) == 3
    norm = gradients[0]**2 + gradients[1]**2 + gradients[2]**2
    norm = np.sqrt(norm)
    return norm

def get_norm3d_torch(gradients_torch):
    assert type(gradients_torch) is tuple
    assert len(gradients_torch) == 3
    norm_torch\
        = gradients_torch[0]**2 + gradients_torch[1]**2 + gradients_torch[2]**2
    norm_torch = torch.sqrt(norm_torch)
    return norm_torch
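# Sanity check (illustrative, not part of the original file): for a signed
# distance function sampled on a unit-spaced grid, the gradient norm should
# be close to 1 everywhere.
def sanity_check_sdf_norm(grid_size=32):
    x, y, z = np.mgrid[0:grid_size, 0:grid_size, 0:grid_size]
    c = (grid_size - 1) / 2.0
    sdf = np.sqrt((x - c)**2 + (y - c)**2 + (z - c)**2) - grid_size / 4.0
    norm = get_norm3d(get_grad3d(sdf))   # get_grad3d is defined further below
    return norm.mean()                   # ~1 up to discretization error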
def apply_diff3d_torch(phi_t, i, j, k, direction, type, step=1):
    assert len(phi_t.shape) == 3  # phi_t is 3d
    assert direction in ['X', 'Y', 'Z']
    assert type in ['forward', 'backward', 'central']
    if type == 'forward':
        if direction == 'X':
            c_value = (phi_t[i+step,j, k]-phi_t[i,j, k])/step
        elif direction == 'Y':
            c_value = (phi_t[i,j+step, k]-phi_t[i,j, k])/step
        elif direction == 'Z':
            c_value = (phi_t[i,j, k+step]-phi_t[i,j, k])/step

    elif type == 'backward':
        if direction == 'X':
            c_value = (phi_t[i,j, k]-phi_t[i-step,j, k])/step
        elif direction == 'Y':
            c_value = (phi_t[i,j, k]-phi_t[i,j-step, k])/step
        elif direction == 'Z':
            c_value = (phi_t[i,j, k]-phi_t[i,j, k-step])/step

    elif type == 'central':
        if direction == 'X':
            c_value = (phi_t[i+step,j, k]-phi_t[i-step,j, k])/(2*step)
        elif direction == 'Y':
            c_value = (phi_t[i,j+step, k]-phi_t[i,j-step, k])/(2*step)
        elif direction == 'Z':
            c_value = (phi_t[i,j, k+step]-phi_t[i,j, k-step])/(2*step)

    return c_value

def apply_diff3d(phi, i, j, k, direction, type, step=1):
    assert len(phi.shape) == 3  # phi is 3d
    assert direction in ['X', 'Y', 'Z']
    assert type in ['forward', 'backward', 'central']
    if type == 'forward':
        if direction == 'X':
            c_value = (phi[i+step,j, k]-phi[i,j, k])/step
        elif direction == 'Y':
            c_value = (phi[i,j+step, k]-phi[i,j, k])/step
        elif direction == 'Z':
            c_value = (phi[i,j, k+step]-phi[i,j, k])/step

    elif type == 'backward':
        if direction == 'X':
            c_value = (phi[i,j, k]-phi[i-step,j, k])/step
        elif direction == 'Y':
            c_value = (phi[i,j, k]-phi[i,j-step, k])/step
        elif direction == 'Z':
            c_value = (phi[i,j, k]-phi[i,j, k-step])/step

    elif type == 'central':
        if direction == 'X':
            c_value = (phi[i+step,j, k]-phi[i-step,j, k])/(2*step)
        elif direction == 'Y':
            c_value = (phi[i,j+step, k]-phi[i,j-step, k])/(2*step)
        elif direction == 'Z':
            c_value = (phi[i,j, k+step]-phi[i,j, k-step])/(2*step)

    return c_value

def get_grad3d_batch(phi):
    assert len(phi.shape) == 4
    batch_size = phi.shape[0]
    grad_list = []
    for i in range(batch_size):
        c_grad = get_grad3d(phi[i])
        grad_list.append(c_grad)
    return grad_list

def get_grad3d_batch_torch(phi_t):
    assert len(phi_t.shape) == 4
    batch_size = phi_t.shape[0]
    grad_list = []
    for i in range(batch_size):
        c_grad = get_grad3d_torch(phi_t[i])
        grad_list.append(c_grad)
    return grad_list


def get_grad3d_torch(phi_t):
    assert len(phi_t.shape) == 3  # phi_t is 3d
    grad_x = torch.zeros(phi_t.shape)
    grad_y = torch.zeros(phi_t.shape)
    grad_z = torch.zeros(phi_t.shape)

    X = phi_t.shape[0]
    Y = phi_t.shape[1]
    Z = phi_t.shape[2]

    assert X == Y and Y == Z

    grid_size = X

    # TODO: parallelize this triple loop (see the vectorized sketch below)

    for i in range(X):
        for j in range(Y):
            for k in range(Z):

                # do gradient with respect to X ; i axis
                if i == 0:  # forward difference
                    grad_x[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'X', 'forward')
                elif i == grid_size-1:  # backward diff
                    grad_x[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'X', 'backward')
                else:  # central diff
                    grad_x[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'X', 'central')

                # do gradient with respect to Y ; j axis
                if j == 0:  # forward difference
                    grad_y[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Y', 'forward')
                elif j == grid_size-1:  # backward diff
                    grad_y[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Y', 'backward')
                else:  # central diff
                    grad_y[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Y', 'central')

                # do gradient with respect to Z ; k axis
                if k == 0:  # forward difference
                    grad_z[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Z', 'forward')
                elif k == grid_size-1:  # backward diff
                    grad_z[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Z', 'backward')
                else:  # central diff
                    grad_z[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Z', 'central')

    return (grad_x, grad_y, grad_z)
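# A vectorized alternative to the triple loop above -- a minimal sketch added
# for illustration (this helper is not part of the original file). It uses the
# same stencils as get_grad3d_torch: central differences inside, one-sided
# differences at the borders.
def get_grad3d_torch_fast(phi_t):
    assert len(phi_t.shape) == 3
    grads = []
    for dim in range(3):
        g = torch.zeros_like(phi_t)
        n = phi_t.shape[dim]
        sl = lambda a, b: tuple(slice(a, b) if d == dim else slice(None) for d in range(3))
        g[sl(1, n - 1)] = (phi_t[sl(2, n)] - phi_t[sl(0, n - 2)]) / 2     # central
        g[sl(0, 1)] = phi_t[sl(1, 2)] - phi_t[sl(0, 1)]                   # forward at border
        g[sl(n - 1, n)] = phi_t[sl(n - 1, n)] - phi_t[sl(n - 2, n - 1)]   # backward at border
        grads.append(g)
    return tuple(grads)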

def get_grad3d(phi):
    assert len(phi.shape) == 3  # phi is 3d
    grad_x = np.zeros(phi.shape)
    grad_y = np.zeros(phi.shape)
    grad_z = np.zeros(phi.shape)

    X = phi.shape[0]
    Y = phi.shape[1]
    Z = phi.shape[2]

    assert X == Y and Y == Z

    grid_size = X

    for i in range(X):
        for j in range(Y):
            for k in range(Z):

                # do gradient with respect to X ; i axis
                if i == 0:  # forward difference
                    grad_x[i,j, k] = apply_diff3d(phi, i, j, k, 'X', 'forward')
                elif i == grid_size-1:  # backward diff
                    grad_x[i,j, k] = apply_diff3d(phi, i, j, k, 'X', 'backward')
                else:  # central diff
                    grad_x[i,j, k] = apply_diff3d(phi, i, j, k, 'X', 'central')

                # do gradient with respect to Y ; j axis
                if j == 0:  # forward difference
                    grad_y[i,j, k] = apply_diff3d(phi, i, j, k, 'Y', 'forward')
                elif j == grid_size-1:  # backward diff
                    grad_y[i,j, k] = apply_diff3d(phi, i, j, k, 'Y', 'backward')
                else:  # central diff
                    grad_y[i,j, k] = apply_diff3d(phi, i, j, k, 'Y', 'central')

                # do gradient with respect to Z ; k axis
                if k == 0:  # forward difference
                    grad_z[i,j, k] = apply_diff3d(phi, i, j, k, 'Z', 'forward')
                elif k == grid_size-1:  # backward diff
                    grad_z[i,j, k] = apply_diff3d(phi, i, j, k, 'Z', 'backward')
                else:  # central diff
                    grad_z[i,j, k] = apply_diff3d(phi, i, j, k, 'Z', 'central')

    return (grad_x, grad_y, grad_z)


def apply_diff2d(phi, i, j, direction, type, step=1):
    assert len(phi.shape) == 2  # phi is 2d
    assert direction in ['X', 'Y']
    assert type in ['forward', 'backward', 'central']
    if type == 'forward':
        if direction == 'X':
            c_value = (phi[i+step,j]-phi[i,j])/step
        elif direction == 'Y':
            c_value = (phi[i,j+step]-phi[i,j])/step
    elif type == 'backward':
        if direction == 'X':
            c_value = (phi[i,j]-phi[i-step,j])/step
        elif direction == 'Y':
            c_value = (phi[i,j]-phi[i,j-step])/step
    elif type == 'central':
        if direction == 'X':
            c_value = (phi[i+step,j]-phi[i-step,j])/(2*step)
        elif direction == 'Y':
            c_value = (phi[i,j+step]-phi[i,j-step])/(2*step)
    return c_value

def get_grad2d_fast(phi):
    # vectorized version of get_grad2d: central differences inside,
    # one-sided differences at the borders (matches np.gradient)
    grad_x = np.zeros(phi.shape)
    grad_y = np.zeros(phi.shape)
    grad_x[1:-1, :] = (phi[2:, :] - phi[:-2, :]) / 2
    grad_x[0, :] = phi[1, :] - phi[0, :]
    grad_x[-1, :] = phi[-1, :] - phi[-2, :]
    grad_y[:, 1:-1] = (phi[:, 2:] - phi[:, :-2]) / 2
    grad_y[:, 0] = phi[:, 1] - phi[:, 0]
    grad_y[:, -1] = phi[:, -1] - phi[:, -2]
    return (grad_x, grad_y)

def get_grad2d(phi):
    grad_x = np.zeros(phi.shape)
    grad_y = np.zeros(phi.shape)
    X = phi.shape[0]
    Y = phi.shape[1]
    grid_size = X
    for i in range(X):
        for j in range(Y):
            # do gradient with respect to X ; i axis
            if i == 0:  # forward difference
                grad_x[i,j] = apply_diff2d(phi, i, j, 'X', 'forward')
            elif i == grid_size-1:  # backward diff
                grad_x[i,j] = apply_diff2d(phi, i, j, 'X', 'backward')
            else:  # central diff
                grad_x[i,j] = apply_diff2d(phi, i, j, 'X', 'central')

            # do gradient with respect to Y ; j axis
            if j == 0:  # forward difference
                grad_y[i,j] = apply_diff2d(phi, i, j, 'Y', 'forward')
            elif j == grid_size-1:  # backward diff
                grad_y[i,j] = apply_diff2d(phi, i, j, 'Y', 'backward')
            else:  # central diff
                grad_y[i,j] = apply_diff2d(phi, i, j, 'Y', 'central')
    return (grad_x, grad_y)

def test_grad_2d(grid_size=32):
    phi = np.random.rand(grid_size, grid_size)
    my_grad = get_grad2d(phi)
    np_grad = np.gradient(phi)
    assert (my_grad[0] - np_grad[0]).sum() == 0
    assert (my_grad[1] - np_grad[1]).sum() == 0
    print('2D test successful.')

def test_grad_3d(grid_size=32):
    phi = np.random.rand(grid_size, grid_size, grid_size)
    my_grad = get_grad3d(phi)
    np_grad = np.gradient(phi)
    assert (my_grad[0] - np_grad[0]).sum() == 0
    assert (my_grad[1] - np_grad[1]).sum() == 0
    assert (my_grad[2] - np_grad[2]).sum() == 0
    print('3D test successful.')

def test_get_grad3d_torch(grid_size=32):
    phi = np.random.rand(grid_size, grid_size, grid_size)
    phi_t = torch.from_numpy(phi)
    np_grad = get_grad3d(phi)
    torch_grad = get_grad3d_torch(phi_t)
    for d in range(3):
        assert np.allclose(torch_grad[d].numpy(), np_grad[d])
    print('3D torch test successful.')

def test_get_norm3d_batch_torch(batch_size=16, grid_size=32):
    phi = np.random.rand(batch_size, grid_size, grid_size, grid_size)
    phi_t = torch.from_numpy(phi)
    grad_list_np = get_grad3d_batch(phi)
    grad_list_torch = get_grad3d_batch_torch(phi_t)
    norm_np = get_norm3d_batch(grad_list_np)
    norm_torch = get_norm3d_batch_torch(grad_list_torch)
    assert np.allclose(norm_torch.numpy(), norm_np)
    print('Batch norm test successful.')
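
# Minimal driver for the self-tests above (added here for convenience; the
# original file defines these tests but never invokes them):
if __name__ == '__main__':
    test_grad_2d()
    test_grad_3d()
    test_get_grad3d_torch()
    test_get_norm3d_batch_torch()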
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
'''
This code is based on the Matryoshka [1] repository [2] and was modified accordingly:

[1] https://arxiv.org/abs/1804.10975

[2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/

Copyright (c) 2018, Visual Inference Lab @TU Darmstadt
'''

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import scipy.io as sio
import functools
import PIL
import logging
import time
from utils import *
import math
import sys, os


from matr_iccv.train import iou_voxel, iou_shapelayer, my_iou_voxel, my_iou_sdf
from matr_iccv.ResNet import *
from matr_iccv.DatasetLoader import *
from matr_iccv.DatasetCollector import *
from scipy.stats import logistic
from vis_objects import vis_view, vis_mesh, vis_sdf, vis_voxels, show, merge_mesh
import trimesh

from binvox_rw import read_as_3d_array

from progressbar import AnimatedMarker, Bar, BouncingBar, Counter, ETA, \
    AdaptiveETA, FileTransferSpeed, FormatLabel, Percentage, \
    ProgressBar, ReverseBar, RotatingMarker, \
    SimpleProgress, Timer


widgets = [Percentage(),
           ' ', Bar(),
           ' ', ETA(),
           ' ', AdaptiveETA()]


def vis_slice(sdf_slice):
    from mayavi import mlab
    mlab.surf(sdf_slice, warp_scale='auto')

if __name__ == '__main__':


    logging.basicConfig(level=logging.INFO)
    logging.info(sys.argv)  # nice to have in log files

    # register networks, datasets, etc.
    name2net = {'resnet': ResNet}
    net_default = 'resnet'

    name2dataset = {'ShapeNet':ShapeNet3DR2N2Collector}
    dataset_default = 'ShapeNet'

    parser = argparse.ArgumentParser(description='Test a Matryoshka Network')

    # general options
    parser.add_argument('--title', type=str, default='matryoshka', help='Title in logs, filename (default: matryoshka).')
    parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--gpu', type=int, default=0, help='GPU ID if cuda is available and enabled')
    parser.add_argument('--batchsize', type=int, default=16, help='input batch size (default: 16)')
    parser.add_argument('--nthreads', type=int, default=1, help='number of threads for loader')
    parser.add_argument('--save_inter', type=int, default=10, help='Saving interval in epochs (default: 10)')

    # options for dataset
    parser.add_argument('--dataset', type=str, default=dataset_default, help=('Dataset [%s]' % ','.join(name2dataset.keys())))
    parser.add_argument('--set', type=str, default='test', help='Subset of input samples to evaluate: train, val, or test (default: test).', choices=['train', 'val', 'test'])
    parser.add_argument('--basedir', type=str, default='/media/data/', help='Base directory for dataset.')

    # options for network
    parser.add_argument('--file', type=str, default=None, help='Savegame')
    parser.add_argument('--net', type=str, default=net_default, help=('Network architecture [%s]' % ','.join(name2net.keys())))
    parser.add_argument('--ncomp', type=int, default=1, help='Number of nested shape layers (default: 1)')

    # other options
    parser.add_argument('--label_type', type=str, default='vox', help='Type of representation: vox (voxels), sdf or mesh')

    parser.add_argument('--vis_pred', action='store_true', default=False, help='if True, will only print predictions')
    parser.add_argument('--gen_report', type=str, default='no', help="Report to generate: 'iou', 'loss', 'imgs', 'all', or 'no' (default: no).")
    parser.add_argument('--test_ids', type=str, default='0 5 10', help='Space-separated sample ids used when generating a report (default: "0 5 10").')
    parser.add_argument('--cat_id', type=str, default='02958343', help='cat_id, default is cars 02958343')
    parser.add_argument('--path_to_res', type=str, default='/media/results/', help='path to output results')
    parser.add_argument('--path_to_prep_shapenet', type=str, default='/media/prep_shapenet', help='path to prep shapenet')
    parser.add_argument('--path_to_data', type=str, default='/media/data', help='path to output results')

101 |     parser.add_argument('--p_norm', type=int, default=1, help='p_norm for paper loss (default: 1)')
102 | 
103 |     args = parser.parse_args()
104 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
105 | 
106 |     assert args.gen_report in ['iou', 'loss', 'imgs', 'all', 'no']
107 | 
108 |     device = torch.device("cuda:{}".format(args.gpu) if args.cuda else "cpu")
109 | 
110 |     torch.manual_seed(1)
111 | 
112 |     if args.file is None:
113 |         args.file = '/home/results/matryoshka_ShapeNet_10.pth.tar'
114 |     else:
115 |         pass
116 |     assert os.path.exists(args.file)
117 | 
118 | 
119 | 
120 |     if args.vis_pred:
121 |         args.batchsize = 1
122 | 
123 |     savegame = torch.load(args.file)
124 |     args.side = savegame['side']
125 | 
126 |     # load dataset
127 |     try:
128 |         logging.info('Initializing dataset "%s"' % args.dataset)
129 |         Collector = ShapeNet3DR2N2Collector(base_dir=args.basedir, cat_id=args.cat_id,
130 |                                             representation=args.label_type, side=args.side, p_norm=args.p_norm)
131 |     except KeyError:
132 |         logging.error('A dataset named "%s" is not available.' % args.dataset)
133 |         exit(1)
134 | 
135 |     logging.info('Initializing dataset loader')
136 |     # set_trace()
137 | 
138 |     if args.set == 'val':
139 |         samples = Collector.val()
140 |     elif args.set == 'test':
141 |         samples = Collector.test()
142 |     elif args.set == 'train':
143 |         samples = Collector.train()
144 | 
145 |     num_samples = len(samples)
146 |     logging.info('Found %d test samples.' % num_samples)
147 |     test_loader = torch.utils.data.DataLoader(DatasetLoader(samples, args.side, args.ncomp, \
148 |                     input_transform=transforms.Compose([transforms.ToTensor()])), \
149 |                     batch_size=args.batchsize, shuffle=False, num_workers=args.nthreads, \
150 |                     pin_memory=True)
151 | 
152 |     if args.gen_report != 'no':
153 |         test_ids = list(map(int, args.test_ids.split(' ')))
154 |         samples_small = []
155 |         for test_id in test_ids:
156 |             samples_small += [samples[test_id*24]]
157 | 
158 |         test_loader_small = torch.utils.data.DataLoader(DatasetLoader(samples_small, args.side, args.ncomp, \
159 |                         input_transform=transforms.Compose([transforms.ToTensor()])), \
160 |                         batch_size=len(samples_small), shuffle=False, num_workers=1, \
161 |                         pin_memory=True)
162 | 
163 |         # gather np guys from test_ids
164 |         test_ids = list(map(int, args.test_ids.split(' ')))
165 |         test_guys_np = np.zeros((len(test_ids), args.side, args.side, args.side))
166 | 
167 |         if args.label_type == 'vox':
168 |             for _, test_id in enumerate(test_ids):
169 |                 path_to_vox_tmp = samples[test_id*24][1]  # 24 views per model, as for samples_small above
170 |                 with open(path_to_vox_tmp, 'rb') as fin:
171 |                     vox_tmp = read_as_3d_array(fin).data
172 |                 test_guys_np[_] = vox_tmp
173 |         elif args.label_type == 'sdf':
174 |             for _, test_id in enumerate(test_ids):
175 |                 path_to_sdf_tmp = samples[test_id*24][1]
176 |                 sdf_tmp = np.load(path_to_sdf_tmp)\
177 |                             .reshape(args.side, args.side, args.side)
178 |                 test_guys_np[_] = sdf_tmp
179 |         else:
180 |             print('unknown label type')
181 |             exit()
182 | 
183 |     samples = []
184 | 
185 |     net = name2net[args.net](\
186 |             num_input_channels=3,
187 |             num_initial_channels=savegame['ninf'],
188 |             num_inner_channels=savegame['ngf'],
189 |             num_penultimate_channels=savegame['noutf'],
190 |             num_output_channels=6*args.ncomp,
191 |             input_resolution=128,
192 |             output_resolution=savegame['side'],
193 |             num_downsampling=savegame['down'],
194 |             bottleneck_dim = 128,
195 |             num_blocks=savegame['block'],
196 |             ).to(device)
197 |     logging.info(net)
198 |     net.load_state_dict(savegame['state_dict'])
199 | 
200 |     net.eval()
201 | 
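A note on the metric used in the loops below: the IoU computed via my_iou_voxel
from train.py is plain intersection-over-union on binarized occupancy grids.
A minimal sketch with toy tensors (illustration only, not part of the repo):

    import torch
    pred = torch.tensor([[[[1., 1.], [0., 0.]]]])  # B x C x H x W occupancy
    gt   = torch.tensor([[[[1., 0.], [0., 1.]]]])
    inter = (pred * gt).sum()                      # 1 cell in common
    union = (pred + gt - pred * gt).sum()          # 3 cells covered in total
    print((inter / union).item())                  # 0.3333...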
202 |     agg_iou = 0.
203 |     count = 0
204 | 
205 | 
206 |     ctr_tej = 0
207 |     real_ctr = 0
208 |     vox_threshold = 0.5
209 | 
210 |     iou_list = []
211 | 
212 | 
213 |     if args.gen_report == 'imgs' or args.gen_report == 'all':
214 |         with torch.no_grad():
215 |             for batch_idx, (inputs, targets, ex_ids) in enumerate(test_loader_small):
216 |                 if args.label_type == 'vox':
217 |                     inputs = inputs.to(device, non_blocking=True)
218 |                     targets = targets.to(device, non_blocking=True)
219 |                     pred = net(inputs)
220 |                     pred = torch.sigmoid(pred) > vox_threshold
221 |                     targets = targets.float()
222 |                     pred = pred.float()
223 |                     targets = targets.cpu().numpy()
224 |                     pred = pred.cpu().numpy()
225 |                     inputs = inputs.cpu().numpy()
226 |                 elif args.label_type == 'sdf':
227 |                     inputs = inputs.to(device, non_blocking=True)
228 |                     targets = targets.to(device, non_blocking=True)
229 |                     pred = net(inputs)  # this is compressed
230 |                     targets = targets.float()
231 |                     pred = pred.float()
232 |                     targets = targets.cpu().numpy()
233 |                     pred = pred.cpu().numpy()
234 |                     inputs = inputs.cpu().numpy()
235 |                 # assert batch_idx == 0
236 |                 # GT
237 |                 tmp_str = os.path.basename(args.file)
238 |                 path_to_out_gt = os.path.join(os.path.dirname(args.file), 'imgs',
239 |                                               tmp_str[:-4]+'_gt_imgs')
240 |                 cmd = 'mkdir -p '+os.path.dirname(path_to_out_gt)
241 |                 os.system(cmd)
242 |                 np.save(path_to_out_gt, targets)
243 |                 # PRED
244 |                 path_to_out_pred = os.path.join(os.path.dirname(args.file), 'imgs',
245 |                                                 tmp_str[:-4]+'_pred_imgs')
246 |                 np.save(path_to_out_pred, pred)
247 |                 # VIEW
248 |                 path_to_out_view = os.path.join(os.path.dirname(args.file), 'imgs',
249 |                                                 tmp_str[:-4]+'_view_imgs')
250 |                 np.save(path_to_out_view, inputs)
251 |                 # EXAMPLE IDS
252 |                 example_ids = []
253 |                 for i in range(len(samples_small)):
254 |                     example_id = samples_small[i][1]
255 |                     tmp_right_idx = example_id.rfind('/')
256 |                     tmp_left_idx = example_id.rfind('/', 0, tmp_right_idx)
257 |                     example_id = example_id[tmp_left_idx+1:tmp_right_idx]
258 |                     example_ids.append(example_id)
259 |                 path_to_out_ids = os.path.join(os.path.dirname(args.file), 'imgs',
260 |                                                tmp_str[:-4]+'_ids_imgs.txt')
261 |                 with open(path_to_out_ids, 'w') as fout:
262 |                     fout.write('\n'.join(example_ids))
263 | 
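The report arrays saved above are plain .npy dumps in an imgs/ folder next to
the checkpoint. A hedged sketch of inspecting them offline (the file names
below are examples derived from the default checkpoint name, not repo fixtures):

    import numpy as np
    gt = np.load('imgs/matryoshka_ShapeNet_10.pth_gt_imgs.npy')
    pred = np.load('imgs/matryoshka_ShapeNet_10.pth_pred_imgs.npy')
    print(gt.shape, pred.shape)  # one entry per requested test id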
264 |     with torch.no_grad():
265 |         for batch_idx, (inputs, targets, example_ids) in enumerate(test_loader):
266 |             print(batch_idx)
267 |             # VIS INPUTS BLOCK
268 |             if args.vis_pred:
269 |                 if batch_idx % 24 != 10:  # there are 24 views, take only 1
270 |                     continue
271 |                 set_trace()
272 |                 inputs = inputs.to(device, non_blocking=True)
273 |                 pred = net(inputs)
274 |                 inputs = inputs.cpu()
275 |                 pred = pred.cpu()
276 | 
277 |                 example_nr = 0
278 |                 c_view = inputs[example_nr].numpy()
279 |                 c_label = targets[example_nr].numpy()
280 |                 vis_view(c_view)
281 |                 if args.label_type == 'vox':
282 |                     c_pred = logistic.cdf(pred[example_nr].numpy()) > vox_threshold
283 |                     vis_voxels(c_label, color=(0,1,0))
284 |                     vis_voxels(c_pred)
285 |                 elif args.label_type == 'sdf':
286 |                     # c_pred = pred[example_nr].numpy()+0.1
287 |                     c_pred = pred[example_nr].numpy()
288 |                     vis_sdf(c_label, color=(0,1,0))
289 |                     vis_sdf(c_pred)
290 |                 elif args.label_type == 'chamfer':
291 |                     # vis gt manifold
292 |                     path_to_man = os.path.join(args.path_to_prep_shapenet,
293 |                                                args.cat_id, 'meshes',
294 |                                                example_ids[example_nr],
295 |                                                'sim_manifold_50000.obj')
296 |                     c_label_man = trimesh.load_mesh(path_to_man)
297 |                     if type(c_label_man) is list:
298 |                         c_label_man = merge_mesh(c_label_man)
299 |                     vis_mesh(c_label_man.vertices, c_label_man.faces)
300 |                     # vis gt SDF
301 |                     path_to_sdf = os.path.join(args.path_to_data,
302 |                                                'ShapeNetSDF'+str(args.side), args.cat_id,
303 |                                                example_ids[example_nr],
304 |                                                'sdf_manifold_50000_grid_size_'+str(args.side)+'.npy')
305 |                     c_label_sdf = np.load(path_to_sdf)\
306 |                                     .reshape((args.side,)*3)
307 |                     vis_sdf(c_label_sdf, color=(1,0,0))
308 | 
309 |                     c_pred = pred[example_nr].numpy()
310 |                     vis_sdf(c_pred)
311 | 
312 |                 else:
313 |                     print('unknown label type')
314 |                     exit()
315 |                 show()
316 |                 print('Should I stop? ')
317 |                 user_response = input()
318 |                 if user_response == 'y':
319 |                     exit()
320 |                 else:
321 |                     continue
322 |             else:
323 |                 if args.label_type == 'vox':
324 |                     inputs = inputs.to(device, non_blocking=True)
325 |                     targets = targets.to(device, non_blocking=True)
326 |                     pred = net(inputs)
327 |                     pred = torch.sigmoid(pred) > vox_threshold
328 |                     targets = targets.float()
329 |                     pred = pred.float()
330 | 
331 |                     targets = targets.detach().cpu().numpy()
332 |                     pred = pred.detach().cpu().numpy()
333 |                     inputs = inputs.detach().cpu().numpy()
334 | 
335 |                     tmp_nr = 7
336 |                     example_id = example_ids[tmp_nr]
337 |                     path_to_out_gt = os.path.join(args.path_to_res, args.cat_id, args.label_type, 'metrics', example_id,
338 |                                                   example_id+'_vox_gt_'+str(args.side)+'_pad_1.npy')
339 |                     cmd = 'mkdir -p '+os.path.dirname(path_to_out_gt)
340 |                     os.system(cmd)
341 |                     np.save(path_to_out_gt, targets[tmp_nr])
342 | 
343 |                     path_to_out_pred = os.path.join(args.path_to_res, args.cat_id, args.label_type, 'metrics', example_id,
344 |                                                     example_id+'_vox_pred_'+str(args.side)+'_pad_1.npy')
345 |                     np.save(path_to_out_pred, pred[tmp_nr])
346 | 
347 |                     path_to_out_view = os.path.join(args.path_to_res, args.cat_id, args.label_type, 'metrics', example_id,
348 |                                                     example_id+'_view.npy')
349 |                     np.save(path_to_out_view, inputs[tmp_nr])
350 | 
351 |                 elif args.label_type == 'sdf':
352 |                     inputs = inputs.to(device, non_blocking=True)
353 |                     targets = targets.to(device, non_blocking=True)
354 |                     pred = net(inputs)
355 |                     targets = targets.float()
356 |                     pred = pred.float()
357 |                     # Just save to file
358 |                     targets = targets.detach().cpu().numpy()
359 |                     pred = pred.detach().cpu().numpy()
360 |                     inputs = inputs.detach().cpu().numpy()
361 |                     tmp_nr = 7
362 |                     example_id = example_ids[tmp_nr]
363 |                     path_to_out_gt = os.path.join(args.path_to_res, args.cat_id, args.label_type, 'metrics', example_id,
364 |                                                   example_id+'_sdf_gt_'+str(args.side)+'_pad_1.npy')
365 |                     cmd = 'mkdir -p '+os.path.dirname(path_to_out_gt)
366 |                     os.system(cmd)
367 |                     np.save(path_to_out_gt, targets[tmp_nr])
368 | 
369 |                     path_to_out_pred = os.path.join(args.path_to_res, args.cat_id, args.label_type, 'metrics', example_id,
370 |                                                     example_id+'_sdf_pred_'+str(args.side)+'_pad_1.npy')
371 |                     np.save(path_to_out_pred, pred[tmp_nr])
372 | 
373 |                     path_to_out_view = os.path.join(args.path_to_res, args.cat_id, args.label_type, 'metrics', example_id,
374 |                                                     example_id+'_view.npy')
375 |                     np.save(path_to_out_view, inputs[tmp_nr])
376 | 
377 | 
378 |                 else:
379 |                     print('unknown label type')
380 |                     exit()
381 |                 if args.gen_report == 'iou' or args.gen_report == 'all':
382 |                     if args.label_type == 'vox':
383 |                         # pred/targets were converted to numpy above; wrap them again for the torch-based IoU
384 |                         iou, bs = my_iou_voxel(torch.from_numpy(pred), torch.from_numpy(targets))
385 |                         iou_list += iou.cpu().tolist()
386 |                     elif args.label_type == 'sdf':
387 |                         iou, bs = my_iou_sdf(torch.from_numpy(pred), torch.from_numpy(targets))
388 |                         iou_list += iou.cpu().tolist()
389 |                     else:
390 |                         print('unknown label type')
391 |                         exit()
392 | 
393 | 
394 | 
395 |     pass
396 | 
397 | 
--------------------------------------------------------------------------------
/grad_approx_fun/finite_diff_torch_faster.py:
--------------------------------------------------------------------------------
1 | from ipdb import set_trace
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | def get_norm3d_batch(list_of_grad):
7 |     assert type(list_of_grad) is list
8 |     batch_size = len(list_of_grad)
9 |     grid_size = list_of_grad[0][0].shape[0]
10 |     norm_batch = np.zeros((batch_size, grid_size, grid_size, grid_size))
11 |     for _, c_grad in enumerate(list_of_grad):
12 |         c_norm = get_norm3d(c_grad)
13 |         norm_batch[_] = c_norm
14 |     return norm_batch
15 | 
16 | def get_norm3d_batch_torch(list_of_grad_torch, device):
17 |     # set_trace()
18 |     assert type(list_of_grad_torch) is list
19 |     batch_size = len(list_of_grad_torch)
20 |     grid_size = list_of_grad_torch[0][0].shape[0]
21 |     norm_batch_torch\
22 |         = torch.zeros((batch_size, grid_size, grid_size, grid_size),\
23 |                       dtype=list_of_grad_torch[0][0].dtype, device=device)
24 |     for _, c_grad_t in enumerate(list_of_grad_torch):
25 |         c_norm_t = get_norm3d_torch(c_grad_t)
26 |         norm_batch_torch[_] = c_norm_t
27 |     return norm_batch_torch
28 | 
29 | 
30 | def get_norm3d(gradients):
31 |     assert type(gradients) is tuple
32 |     assert len(gradients) == 3
33 |     norm = gradients[0]**2 + gradients[1]**2 + gradients[2]**2
34 |     norm = np.sqrt(norm)
35 |     return norm
36 | 
37 | def get_norm3d_torch(gradients_torch, small_nr = 1e-8):
38 |     assert type(gradients_torch) is tuple
39 |     assert len(gradients_torch) == 3
40 |     norm_torch\
41 |         = gradients_torch[0]**2 + gradients_torch[1]**2 + gradients_torch[2]**2
42 |     norm_torch = torch.sqrt(norm_torch+small_nr)
43 |     return norm_torch
44 | 
45 | 
46 | def apply_diff3d_torch(phi_t, i, j, k, direction, type, step=1):
47 |     assert len(phi_t.shape) == 3  # phi_t is 3d
48 |     assert direction in ['X', 'Y', 'Z']
49 |     assert type in ['forward', 'backward', 'central']
50 |     if type == 'forward':
51 |         if direction == 'X':
52 |             c_value = (phi_t[i+step,j, k]-phi_t[i,j, k])/step
53 |         elif direction == 'Y':
54 |             c_value = (phi_t[i,j+step, k]-phi_t[i,j, k])/step
55 |         elif direction == 'Z':
56 |             c_value = (phi_t[i,j, k+step]-phi_t[i,j, k])/step
57 | 
58 |     elif type == 'backward':
59 |         if direction == 'X':
60 |             c_value = (phi_t[i,j, k]-phi_t[i-step,j, k])/step
61 |         elif direction == 'Y':
62 |             c_value = (phi_t[i,j, k]-phi_t[i,j-step, k])/step
63 |         elif direction == 'Z':
64 |             c_value = (phi_t[i,j, k]-phi_t[i,j, k-step])/step
65 | 
66 |     elif type == 'central':
67 |         if direction == 'X':
68 |             c_value = (phi_t[i+step,j, k]-phi_t[i-step,j, k])/(2*step)
69 |         elif direction == 'Y':
70 |             c_value = (phi_t[i,j+step, k]-phi_t[i,j-step, k])/(2*step)
71 |         elif direction == 'Z':
72 |             c_value = (phi_t[i,j, k+step]-phi_t[i,j, k-step])/(2*step)
73 | 
74 |     return c_value
75 | 
76 | def apply_diff3d(phi, i, j, k, direction, type, step=1):
77 |     assert len(phi.shape) == 3  # phi is 3d
78 |     assert direction in ['X', 'Y', 'Z']
79 |     assert type in ['forward', 'backward', 'central']
80 |     if type == 'forward':
81 |         if direction == 'X':
82 |             c_value = (phi[i+step,j, k]-phi[i,j, k])/step
83 |         elif direction == 'Y':
84 |             c_value = (phi[i,j+step, k]-phi[i,j, k])/step
85 |         elif direction == 'Z':
86 |             c_value = (phi[i,j, k+step]-phi[i,j, k])/step
87 | 
88 |     elif type == 'backward':
89 |         if direction == 'X':
90 |             c_value = (phi[i,j, k]-phi[i-step,j, k])/step
91 |         elif direction == 'Y':
92 |             c_value = (phi[i,j, k]-phi[i,j-step, k])/step
93 |         elif direction == 'Z':
94 |             c_value = (phi[i,j, k]-phi[i,j, k-step])/step
95 | 
96 |     elif type == 'central':
97 |         if direction == 'X':
98 |             c_value = (phi[i+step,j, k]-phi[i-step,j, k])/(2*step)
99 |         elif direction == 'Y':
100 |             c_value = (phi[i,j+step, k]-phi[i,j-step, k])/(2*step)
101 |         elif direction == 'Z':
102 |             c_value = (phi[i,j, k+step]-phi[i,j, k-step])/(2*step)
103 | 
104 |     return c_value
105 | 
106 | def get_grad3d_batch_faster(phi):
107 |     assert len(phi.shape) == 4
108 |     batch_size = phi.shape[0]
109 |     grad_list = []
110 |     for i in range(batch_size):
111 |         c_grad = get_grad3d_faster(phi[i])
112 |         grad_list.append(c_grad)
113 |     return grad_list
114 | 
115 | 
116 | def get_grad3d_batch(phi):
117 |     assert len(phi.shape) == 4
118 |     batch_size = phi.shape[0]
119 |     grad_list = []
120 |     for i in range(batch_size):
121 |         c_grad = get_grad3d(phi[i])
122 |         grad_list.append(c_grad)
123 |     return grad_list
124 | 
125 | def get_grad3d_batch_faster_torch(phi_t, device):
126 |     assert len(phi_t.shape) == 4
127 |     batch_size = phi_t.shape[0]
128 |     grad_list = []
129 |     for i in range(batch_size):
130 |         c_grad = get_grad3d_faster_torch(phi_t[i], device)
131 |         grad_list.append(c_grad)
132 |     return grad_list
133 | 
134 | def get_grad3d_batch_torch(phi_t):
135 |     assert len(phi_t.shape) == 4
136 |     batch_size = phi_t.shape[0]
137 |     grad_list = []
138 |     for i in range(batch_size):
139 |         c_grad = get_grad3d_torch(phi_t[i])
140 |         grad_list.append(c_grad)
141 |     return grad_list
142 | 
143 | def get_grad3d_torch(phi_t):
144 |     assert len(phi_t.shape) == 3
145 |     grad_x = torch.zeros(phi_t.shape, dtype=phi_t.dtype)
146 |     grad_y = torch.zeros(phi_t.shape, dtype=phi_t.dtype)
147 |     grad_z = torch.zeros(phi_t.shape, dtype=phi_t.dtype)
148 |     X = phi_t.shape[0]
149 |     Y = phi_t.shape[1]
150 |     Z = phi_t.shape[2]
151 |     assert X == Y and Y == Z
152 |     grid_size = X
153 | 
154 |     for i in range(X):
155 |         for j in range(Y):
156 |             for k in range(Z):
157 | 
158 |                 # do gradient with respect to X ; i axis
159 |                 if i == 0:  # forward difference
160 |                     grad_x[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'X', 'forward')
161 |                 elif i == grid_size-1:  # backward diff
162 |                     grad_x[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'X', 'backward')
163 |                 else:  # central diff
164 |                     grad_x[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'X', 'central')
165 | 
166 |                 # do gradient with respect to Y ; j axis
167 |                 if j == 0:  # forward difference
168 |                     grad_y[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Y', 'forward')
169 |                 elif j == grid_size-1:  # backward diff
170 |                     grad_y[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Y', 'backward')
171 |                 else:  # central diff
172 |                     grad_y[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Y', 'central')
173 | 
174 |                 # do gradient with respect to Z ; k axis
175 |                 if k == 0:  # forward difference
176 |                     grad_z[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Z', 'forward')
177 |                 elif k == grid_size-1:  # backward diff
178 |                     grad_z[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Z', 'backward')
179 |                 else:  # central diff
180 |                     grad_z[i,j, k] = apply_diff3d_torch(phi_t, i, j, k, 'Z', 'central')
181 | 
182 |     return (grad_x, grad_y, grad_z)
183 | 
184 | def get_grad3d(phi):
185 |     assert len(phi.shape) == 3  # phi is 3d
186 |     grad_x = np.zeros(phi.shape)
187 |     grad_y = np.zeros(phi.shape)
188 |     grad_z = np.zeros(phi.shape)
189 | 
190 |     X = phi.shape[0]
191 |     Y = phi.shape[1]
192 |     Z = phi.shape[2]
193 | 
194 |     assert X == Y and Y == Z
195 | 
196 |     grid_size = X
197 | 
198 |     for i in range(X):
199 |         for j in range(Y):
200 |             for k in range(Z):
201 | 
202 |                 # do gradient with respect to X ; i axis
203 |                 if i == 0:  # forward difference
204 |                     grad_x[i,j, k] = apply_diff3d(phi, i, j, k, 'X', 'forward')
205 |                 elif i == grid_size-1:  # backward diff
206 |                     grad_x[i,j, k] = apply_diff3d(phi, i, j, k, 'X', 'backward')
207 |                 else:  # central diff
208 |                     grad_x[i,j, k] = apply_diff3d(phi, i, j, k, 'X', 'central')
209 | 
210 |                 # do gradient with respect to Y ; j axis
211 |                 if j == 0:  # forward difference
212 |                     grad_y[i,j, k] = apply_diff3d(phi, i, j, k, 'Y', 'forward')
213 |                 elif j == grid_size-1:  # backward diff
214 |                     grad_y[i,j, k] = apply_diff3d(phi, i, j, k, 'Y', 'backward')
215 |                 else:  # central diff
216 |                     grad_y[i,j, k] = apply_diff3d(phi, i, j, k, 'Y', 'central')
217 | 
218 |                 # do gradient with respect to Z ; k axis
219 |                 if k == 0:  # forward difference
220 |                     grad_z[i,j, k] = apply_diff3d(phi, i, j, k, 'Z', 'forward')
221 |                 elif k == grid_size-1:  # backward diff
222 |                     grad_z[i,j, k] = apply_diff3d(phi, i, j, k, 'Z', 'backward')
223 |                 else:  # central diff
224 |                     grad_z[i,j, k] = apply_diff3d(phi, i, j, k, 'Z', 'central')
225 | 
226 |     return (grad_x, grad_y, grad_z)
227 | 
228 | def apply_diff2d(phi, i, j, direction, type, step=1):
229 |     assert len(phi.shape) == 2  # phi is 2d
230 |     assert direction in ['X', 'Y']
231 |     assert type in ['forward', 'backward', 'central']
232 |     if type == 'forward':
233 |         if direction == 'X':
234 |             c_value = (phi[i+step,j]-phi[i,j])/step
235 |         elif direction == 'Y':
236 |             c_value = (phi[i,j+step]-phi[i,j])/step
237 |     elif type == 'backward':
238 |         if direction == 'X':
239 |             c_value = (phi[i,j]-phi[i-step,j])/step
240 |         elif direction == 'Y':
241 |             c_value = (phi[i,j]-phi[i,j-step])/step
242 |     elif type == 'central':
243 |         if direction == 'X':
244 |             c_value = (phi[i+step,j]-phi[i-step,j])/(2*step)
245 |         elif direction == 'Y':
246 |             c_value = (phi[i,j+step]-phi[i,j-step])/(2*step)
247 |     return c_value
248 | 
249 | def get_grad3d_faster_torch(phi_t, device):
250 |     assert len(phi_t.shape) == 3
251 | 
252 |     grad_t_x = torch.zeros(phi_t.shape, dtype=phi_t.dtype, device=device)
253 |     grad_t_y = torch.zeros(phi_t.shape, dtype=phi_t.dtype, device=device)
254 |     grad_t_z = torch.zeros(phi_t.shape, dtype=phi_t.dtype, device=device)
255 | 
256 |     X = phi_t.shape[0]
257 |     Y = phi_t.shape[1]
258 |     Z = phi_t.shape[2]
259 | 
260 |     grid_size = X
261 | 
262 |     # CENTRAL
263 |     # get the slice:
264 |     phi_t_slice_central = phi_t[1:-1, 1:-1, 1:-1]
265 |     # X direction
266 |     phi_t_slice_central_x_pos = phi_t[1+1:, :, :]  # should be : -0
267 |     phi_t_slice_central_x_neg = phi_t[1-1:-1-1, :, :]
268 |     grad_t_x[1:-1, :, :] = (phi_t_slice_central_x_pos - phi_t_slice_central_x_neg)/2.0
269 |     # Y direction
270 |     phi_t_slice_central_y_pos = phi_t[:, 1+1:, :]
271 |     phi_t_slice_central_y_neg = phi_t[:, 1-1:-1-1, :]
272 |     grad_t_y[:, 1:-1, :] = (phi_t_slice_central_y_pos - phi_t_slice_central_y_neg)/2.0
273 |     # Z direction
274 |     phi_t_slice_central_z_pos = phi_t[:,:, 1+1:]
275 |     phi_t_slice_central_z_neg = phi_t[:,:, 1-1:-1-1]
276 |     grad_t_z[:,:, 1:-1] = (phi_t_slice_central_z_pos - phi_t_slice_central_z_neg)/2.0
277 |     # FORWARD
278 |     # X direction
279 |     phi_t_slice_forward_x = phi_t[0,:,:]
280 |     phi_t_slice_forward_x_pos = phi_t[0+1,:,:]
281 |     grad_t_x[0,:,:] = phi_t_slice_forward_x_pos - phi_t_slice_forward_x
282 |     # Y direction
283 |     phi_t_slice_forward_y = phi_t[:,0,:]
284 |     phi_t_slice_forward_y_pos = phi_t[:,0+1,:]
285 |     grad_t_y[:,0,:] = phi_t_slice_forward_y_pos - phi_t_slice_forward_y
286 |     # Z direction
287 |     phi_t_slice_forward_z = phi_t[:,:,0]
288 |     phi_t_slice_forward_z_pos = phi_t[:,:,0+1]
289 |     grad_t_z[:,:,0] = phi_t_slice_forward_z_pos - phi_t_slice_forward_z
290 |     # BACKWARD
291 |     # X direction
292 |     phi_t_slice_backward_x = phi_t[-1,:,:]
293 |     phi_t_slice_backward_x_neg = phi_t[-1-1,:,:]
294 |     grad_t_x[-1,:,:] = phi_t_slice_backward_x - phi_t_slice_backward_x_neg
295 |     # Y direction
296 |     phi_t_slice_backward_y = phi_t[:,-1,:]
297 |     phi_t_slice_backward_y_neg = phi_t[:,-1-1,:]
298 |     grad_t_y[:,-1,:] = phi_t_slice_backward_y - phi_t_slice_backward_y_neg
299 |     # Z direction
300 |     phi_t_slice_backward_z = phi_t[:,:,-1]
301 |     phi_t_slice_backward_z_neg = phi_t[:,:,-1-1]
302 |     grad_t_z[:,:,-1] = phi_t_slice_backward_z - phi_t_slice_backward_z_neg
303 | 
304 |     return (grad_t_x, grad_t_y, grad_t_z)
305 | 
306 | 
307 | def get_grad3d_faster(phi):
308 |     assert len(phi.shape) == 3
309 | 
310 |     grad_x = np.zeros(phi.shape)
311 |     grad_y = np.zeros(phi.shape)
312 |     grad_z = np.zeros(phi.shape)
313 | 
314 |     X = phi.shape[0]
315 |     Y = phi.shape[1]
316 |     Z = phi.shape[2]
317 | 
318 |     grid_size = X
319 | 
320 |     # CENTRAL
321 |     # get the slice:
322 |     phi_slice_central = phi[1:-1, 1:-1, 1:-1]
323 |     # X direction
324 |     phi_slice_central_x_pos = phi[1+1:, :, :]  # should be : -0
325 |     phi_slice_central_x_neg = phi[1-1:-1-1, :, :]
326 |     grad_x[1:-1, :, :] = (phi_slice_central_x_pos - phi_slice_central_x_neg)/2.0
327 |     # Y direction
328 |     phi_slice_central_y_pos = phi[:, 1+1:, :]
329 |     phi_slice_central_y_neg = phi[:, 1-1:-1-1, :]
330 |     grad_y[:, 1:-1, :] = (phi_slice_central_y_pos - phi_slice_central_y_neg)/2.0
331 |     # Z direction
332 |     phi_slice_central_z_pos = phi[:,:, 1+1:]
333 |     phi_slice_central_z_neg = phi[:,:, 1-1:-1-1]
334 |     grad_z[:,:, 1:-1] = (phi_slice_central_z_pos - phi_slice_central_z_neg)/2.0
335 |     # FORWARD
336 |     # X direction
337 |     phi_slice_forward_x = phi[0,:,:]
338 |     phi_slice_forward_x_pos = phi[0+1,:,:]
339 |     grad_x[0,:,:] = phi_slice_forward_x_pos - phi_slice_forward_x
340 |     # Y direction
341 |     phi_slice_forward_y = phi[:,0,:]
342 |     phi_slice_forward_y_pos = phi[:,0+1,:]
343 |     grad_y[:,0,:] = phi_slice_forward_y_pos - phi_slice_forward_y
344 |     # Z direction
345 |     phi_slice_forward_z = phi[:,:,0]
346 |     phi_slice_forward_z_pos = phi[:,:,0+1]
347 |     grad_z[:,:,0] = phi_slice_forward_z_pos - phi_slice_forward_z
348 |     # BACKWARD
349 |     # X direction
350 |     phi_slice_backward_x = phi[-1,:,:]
351 |     phi_slice_backward_x_neg = phi[-1-1,:,:]
352 |     grad_x[-1,:,:] = phi_slice_backward_x - phi_slice_backward_x_neg
353 |     # Y direction
354 |     phi_slice_backward_y = phi[:,-1,:]
355 |     phi_slice_backward_y_neg = phi[:,-1-1,:]
356 |     grad_y[:,-1,:] = phi_slice_backward_y - phi_slice_backward_y_neg
357 |     # Z direction
358 |     phi_slice_backward_z = phi[:,:,-1]
359 |     phi_slice_backward_z_neg = phi[:,:,-1-1]
360 |     grad_z[:,:,-1] = phi_slice_backward_z - phi_slice_backward_z_neg
361 | 
362 |     return (grad_x, grad_y, grad_z)
363 | 
364 | 
365 | def get_grad2d_fast(phi):
366 |     grad_x = np.zeros(phi.shape)
367 |     grad_y = np.zeros(phi.shape)
368 |     X = phi.shape[0]
369 |     Y = phi.shape[1]
370 |     grid_size = X
371 |     # CENTRAL
372 |     # get the slice:
373 |     phi_slice_central = phi[1:-1, 1:-1]
374 |     # X direction
375 |     phi_slice_central_x_pos = phi[1+1:, :]  # should be : -0
376 |     phi_slice_central_x_neg = phi[1-1:-1-1, :]
377 |     grad_x[1:-1, :] = (phi_slice_central_x_pos - phi_slice_central_x_neg)/2.0
378 |     # Y direction
379 |     phi_slice_central_y_pos = phi[:, 1+1:]
380 |     phi_slice_central_y_neg = phi[:, 1-1:-1-1]
381 |     grad_y[:, 1:-1] = (phi_slice_central_y_pos - phi_slice_central_y_neg)/2.0
382 |     # FORWARD
383 |     # X direction
384 |     phi_slice_forward_x = phi[0,:]
385 |     phi_slice_forward_x_pos = phi[0+1,:]
386 |     grad_x[0,:] = phi_slice_forward_x_pos - phi_slice_forward_x
387 |     # Y direction
388 |     phi_slice_forward_y = phi[:,0]
389 |     phi_slice_forward_y_pos = phi[:,0+1]
390 |     grad_y[:,0] = phi_slice_forward_y_pos - phi_slice_forward_y
391 |     # BACKWARD
392 |     # X direction
393 |     phi_slice_backward_x = phi[-1,:]
394 |     phi_slice_backward_x_neg = phi[-1-1,:]
395 |     grad_x[-1,:] = phi_slice_backward_x - phi_slice_backward_x_neg
396 |     # Y direction
397 |     phi_slice_backward_y = phi[:,-1]
398 |     phi_slice_backward_y_neg = phi[:,-1-1]
399 |     grad_y[:,-1] = phi_slice_backward_y - phi_slice_backward_y_neg
400 | 
401 |     return (grad_x, grad_y)
402 | 
403 | def get_grad2d(phi):
404 |     grad_x = np.zeros(phi.shape)
405 |     grad_y = np.zeros(phi.shape)
406 |     X = phi.shape[0]
407 |     Y = phi.shape[1]
408 |     grid_size = X
409 |     for i in range(X):
410 |         for j in range(Y):
411 |             # do gradient with respect to X ; i axis
412 |             if i == 0:  # forward difference
413 |                 grad_x[i,j] = apply_diff2d(phi, i, j, 'X', 'forward')
414 |             elif i == grid_size-1:  # backward diff
415 |                 grad_x[i,j] = apply_diff2d(phi, i, j, 'X', 'backward')
416 |             else:  # central diff
417 |                 grad_x[i,j] = apply_diff2d(phi, i, j, 'X', 'central')
418 | 
419 |             # do gradient with respect to Y ; j axis
420 |             if j == 0:  # forward difference
421 |                 grad_y[i,j] = apply_diff2d(phi, i, j, 'Y', 'forward')
422 |             elif j == grid_size-1:  # backward diff
423 |                 grad_y[i,j] = apply_diff2d(phi, i, j, 'Y', 'backward')
424 |             else:  # central diff
425 |                 grad_y[i,j] = apply_diff2d(phi, i, j, 'Y', 'central')
426 |     return (grad_x, grad_y)
427 | 
428 | def test_grad_2d(grid_size=32):
429 |     set_trace()
430 | 
431 |     phi = np.random.rand(grid_size, grid_size)
432 | 
433 |     my_grad = get_grad2d(phi)
434 |     np_grad = np.gradient(phi)
435 | 
436 |     assert np.allclose(my_grad[0], np_grad[0])
437 |     assert np.allclose(my_grad[1], np_grad[1])
438 |     print('2D test successful.')
439 | 
440 | def test_grad_3d(grid_size=32):
441 |     set_trace()
442 | 
443 |     phi = np.random.rand(grid_size, grid_size, grid_size)
444 | 
445 |     my_grad = get_grad3d(phi)
446 |     np_grad = np.gradient(phi)
447 | 
448 |     assert np.allclose(my_grad[0], np_grad[0])
449 |     assert np.allclose(my_grad[1], np_grad[1])
450 |     assert np.allclose(my_grad[2], np_grad[2])
451 | 
452 |     print('3D test successful.')
453 | 
454 | def test_get_grad3d_faster_torch(grid_size=32):
455 | 
456 |     phi = np.random.rand(grid_size, grid_size, grid_size)
457 |     phi_t = torch.from_numpy(phi)
458 | 
459 |     np_grad = get_grad3d_faster(phi)
460 |     torch_grad = get_grad3d_faster_torch(phi_t, torch.device('cpu'))
461 | 
462 |     assert np.allclose(np_grad[0], torch_grad[0].cpu().detach().numpy())
463 |     assert np.allclose(np_grad[1], torch_grad[1].cpu().detach().numpy())
464 |     assert np.allclose(np_grad[2], torch_grad[2].cpu().detach().numpy())
465 | 
466 |     print('get_grad3d faster in torch matches the numpy version')
467 | 
468 | def test_get_grad3d_torch(grid_size=32):
469 | 
470 |     phi = np.random.rand(grid_size, grid_size, grid_size)
471 |     phi_t = torch.from_numpy(phi)
472 | 
473 |     np_grad = get_grad3d(phi)
474 |     torch_grad = get_grad3d_torch(phi_t)
475 | 
476 |     for axis in range(3):
477 |         assert np.allclose(np_grad[axis], torch_grad[axis].numpy())
478 | 
479 | def test_get_norm3d_batch_faster_torch(batch_size=16, grid_size=32):
480 |     phi = np.random.rand(batch_size, grid_size, grid_size, grid_size)
481 |     phi_t = torch.from_numpy(phi)
482 | 
483 |     grad_list_np = get_grad3d_batch_faster(phi)
484 |     grad_list_torch = get_grad3d_batch_faster_torch(phi_t, torch.device('cpu'))
485 |     assert np.allclose(get_norm3d_batch(grad_list_np),
486 |                        get_norm3d_batch_torch(grad_list_torch, torch.device('cpu')).numpy())
487 | 
488 | def test_get_norm3d_batch_torch(batch_size=16, grid_size=32):
489 |     phi = np.random.rand(batch_size, grid_size, grid_size, grid_size)
490 |     phi_t = torch.from_numpy(phi)
491 |     grad_list_np = get_grad3d_batch(phi)
492 |     grad_list_torch = get_grad3d_batch_torch(phi_t)
493 |     assert np.allclose(get_norm3d_batch(grad_list_np), get_norm3d_batch_torch(grad_list_torch, torch.device('cpu')).numpy())
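The slicing pattern used by get_grad2d_fast and get_grad3d_faster above is the
same scheme np.gradient applies per axis: central differences in the interior,
one-sided differences at the borders. A minimal 1-D sketch (illustration only):

    import numpy as np
    f = np.array([0., 1., 4., 9., 16.])  # x**2 sampled at x = 0..4
    g = np.zeros_like(f)
    g[1:-1] = (f[2:] - f[:-2]) / 2.0     # central: (f[i+1] - f[i-1]) / 2
    g[0] = f[1] - f[0]                   # forward difference at the left border
    g[-1] = f[-1] - f[-2]                # backward difference at the right border
    assert np.allclose(g, np.gradient(f))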
494 | def test_get_grad2d(grid_size=32):
495 |     phi = np.random.rand(grid_size, grid_size)
496 | 
497 |     grad_np = np.gradient(phi)
498 | 
499 |     my_grad = get_grad2d_fast(phi)
500 | 
501 |     assert np.allclose(my_grad[0], grad_np[0])
502 |     assert np.allclose(my_grad[1], grad_np[1])
503 | 
504 |     print('get_grad 2d, fast: all good')
505 | 
506 | def test_get_grad3d(grid_size=32):
507 |     phi = np.random.rand(grid_size, grid_size, grid_size)
508 | 
509 |     grad_np = np.gradient(phi)
510 | 
511 |     my_grad = get_grad3d_faster(phi)
512 | 
513 |     assert np.allclose(my_grad[0], grad_np[0])
514 |     assert np.allclose(my_grad[1], grad_np[1])
515 |     assert np.allclose(my_grad[2], grad_np[2])
516 | 
517 | 
518 |     print('get_grad 3d, fast: all good')
519 | 
520 | 
521 | 
522 | 
523 | 
524 | 
525 | 
--------------------------------------------------------------------------------
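Taken together, the two batch helpers above give the ingredient train.py needs
for its Eikonal-style regularizer: per-voxel gradient norms of a predicted SDF
batch. A minimal sketch (hypothetical shapes, CPU device chosen for illustration):

    import torch
    from grad_approx_fun.finite_diff_torch_faster import \
        get_grad3d_batch_faster_torch, get_norm3d_batch_torch

    device = torch.device('cpu')
    phi = torch.randn(4, 32, 32, 32, dtype=torch.float64)  # B x D x H x W
    grads = get_grad3d_batch_faster_torch(phi, device)     # list of (gx, gy, gz)
    norms = get_norm3d_batch_torch(grads, device)          # B x 32 x 32 x 32
    # A true signed distance function has |grad phi| == 1 almost everywhere,
    # which is what the (norms - 1)**2 term in train.py penalizes.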
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | This code is based on the Matryoshka [1] repository [2] and was modified accordingly:
3 | 
4 | [1] https://arxiv.org/abs/1804.10975
5 | 
6 | [2] https://bitbucket.org/visinf/projects-2018-matryoshka/src/master/
7 | 
8 | Copyright (c) 2018, Visual Inference Lab @TU Darmstadt
9 | '''
10 | 
11 | from __future__ import print_function
12 | import argparse
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch.optim as optim
17 | from torchvision import datasets, transforms
18 | from torch.autograd import Variable
19 | import scipy.io as sio
20 | import functools
21 | import PIL
22 | import logging
23 | import time
24 | #from utils import *
25 | import math
26 | import sys, os
27 | 
28 | from matr_iccv.ResNet import *
29 | from matr_iccv.DatasetLoader import *
30 | from matr_iccv.DatasetCollector import *
31 | 
32 | from matr_iccv.grad_approx_fun.finite_diff_torch_faster import get_grad3d_batch_faster_torch,\
33 |     get_norm3d_batch_torch
34 | 
35 | 
36 | from ipdb import set_trace
37 | 
38 | from matr_iccv.vis_objects import vis_view, vis_mesh, vis_sdf, vis_voxels, show, merge_mesh
39 | import trimesh
40 | 
41 | 
42 | def vis_slice(voxels, slice_nr=64, axis='y'):
43 |     from matplotlib import pyplot as plt
44 |     if axis == 'x':
45 |         tmp_slice = voxels[slice_nr,:,:]
46 |     elif axis == 'y':
47 |         tmp_slice = voxels[:,slice_nr,:]
48 |     elif axis == 'z':
49 |         tmp_slice = voxels[:,:,slice_nr]
50 |     plt.imshow(tmp_slice); plt.show()
51 | 
52 | 
53 | 
54 | def dirac_delta_torch(phi_pred, eps=0.1):
55 |     small_eps = 10**(-6)  # just in case
56 |     mask = (torch.abs(phi_pred) <= eps)
57 |     dirac_delta = (1/(2*eps)) * (1+torch.cos(np.pi * phi_pred/eps))
58 |     if phi_pred.dtype == torch.float32:
59 |         dirac_delta = dirac_delta * mask.float()
60 |     elif phi_pred.dtype == torch.float64:
61 |         dirac_delta = dirac_delta * mask.double()
62 |     else:
63 |         print('strange dtype in delta')
64 |         exit()
65 |     return dirac_delta
66 | 
67 | def my_loss(pred, target, loss_type, p_norm=2, eps=0.1, alpha_sdf=0.1):
68 |     if loss_type == 'cross':
69 |         loss = F.binary_cross_entropy(torch.sigmoid(pred), target.float())
70 |     elif loss_type == 'l2':
71 |         pred = pred.double()  # cast to float64 so pred matches the sdf targets
72 |         loss = F.mse_loss(pred, target)
73 |     elif loss_type == 'chamfer':
74 |         delta = dirac_delta_torch(pred.view(target.shape), eps=eps)
75 |         chamf_loss = (delta*target)  # BS X side**3
76 |         chamf_loss = chamf_loss.sum(dim=1)  # BS
77 |         chamf_loss = chamf_loss**(1/p_norm)  # BS, do abs?
78 |         chamf_loss = chamf_loss.sum()
79 |         phi_grad_t = get_grad3d_batch_faster_torch(pred, device)
80 |         phi_grad_norm_t = get_norm3d_batch_torch(phi_grad_t, device)
81 |         sdf_loss = (phi_grad_norm_t-1)**2
82 |         sdf_loss_flatten = sdf_loss.view(sdf_loss.shape[0],-1)
83 |         sdf_loss_final = sdf_loss_flatten.mean(dim=1)  # BS
84 |         sdf_loss_final = sdf_loss_final.sum()
85 | 
86 |         return (chamf_loss, sdf_loss_final)
87 |     else:
88 |         print('unknown loss type')
89 |         exit()
90 |     return loss
91 | 
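A quick numeric check of the smeared delta above (toy values, illustration
only): dirac_delta_torch is a cosine bump of width 2*eps around the zero level
set, so delta * target in the chamfer loss only samples the target where the
predicted SDF is close to zero.

    import torch
    phi = torch.tensor([-0.2, -0.05, 0.0, 0.05, 0.2], dtype=torch.float64)
    print(dirac_delta_torch(phi, eps=0.1))
    # tensor([ 0.,  5., 10.,  5.,  0.], dtype=torch.float64)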
173 | """ 174 | pred = pred.detach() 175 | voxel = voxel.detach() 176 | 177 | bs, _, side, _ = pred.shape 178 | vp = pred.new_zeros(bs,side,side,side, requires_grad=False) 179 | vt = pred.new_zeros(bs,side,side,side, requires_grad=False) 180 | 181 | for i in range(bs): 182 | vp[i,:,:,:] = decode_shape(pred[i,:,:,:].short().permute(1,2,0), id1, id2, id3) 183 | vt[i,:,:,:] = decode_shape(voxel[i,:,:,:].short().permute(1,2,0), id1, id2, id3) 184 | 185 | return iou_voxel(vp,vt) 186 | 187 | 188 | k_save = 0 189 | def save(c, d, name=None): 190 | global k_save 191 | if c: 192 | k_save += 1 193 | if name is None: 194 | name = 'dbg_%d.mat' % k_save 195 | sio.savemat(name, {k:d[k].detach().cpu().numpy() for k in d.keys()}) 196 | 197 | if __name__ == '__main__': 198 | 199 | logging.basicConfig(level=logging.INFO) 200 | logging.info(sys.argv) # nice to have in log files 201 | 202 | # register networks, datasets, etc. 203 | name2net = {'resnet': ResNet} 204 | net_default = 'resnet' 205 | 206 | name2dataset = {\ 207 | 'ShapeNet':ShapeNet3DR2N2Collector} 208 | dataset_default = 'ShapeNet' 209 | 210 | name2optim = {'adam': optim.Adam} 211 | optim_default = 'adam' 212 | 213 | parser = argparse.ArgumentParser(description='Train a Matryoshka Network') 214 | 215 | # general options 216 | parser.add_argument('--title', type=str, default='matryoshka', help='Title in logs, filename (default: matryoshka).') 217 | parser.add_argument('--no_cuda', action='store_true', default=False, help='disables CUDA training') 218 | parser.add_argument('--gpu', type=int, default=0, help='GPU ID if cuda is available and enabled') 219 | parser.add_argument('--no_save', action='store_true', default=False, help='Disables saving of final model') 220 | parser.add_argument('--no_val', action='store_true', default=False, help='Disable validation for faster training') 221 | parser.add_argument('--batchsize', type=int, default=16, help='input batch size for training (default: 32)') 222 | parser.add_argument('--epochs', type=int, default=1, help='number of epochs to train') 223 | parser.add_argument('--nthreads', type=int, default=4, help='number of threads for loader') 224 | parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)') 225 | parser.add_argument('--val_inter', type=int, default=1, help='Validation interval in epochs (default: 1)') 226 | parser.add_argument('--log_inter', type=int, default=100, help='Logging interval in batches (default: 100)') 227 | parser.add_argument('--save_inter', type=int, default=10, help='Saving interval in epochs (default: 10)') 228 | 229 | # options for optimizer 230 | parser.add_argument('--optim', type=str, default=optim_default, help=('Optimizer [%s]' % ','.join(name2optim.keys()))) 231 | parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate (default: 1e-3)') 232 | parser.add_argument('--decay', type=float, default=0, help='Weight decay for optimizer (default: 0)') 233 | parser.add_argument('--drop', type=int, default=30) 234 | 235 | # options for dataset 236 | parser.add_argument('--dataset', type=str, default=dataset_default, help=('Dataset [%s]' % ','.join(name2dataset.keys()))) 237 | parser.add_argument('--basedir', type=str, default='/media/data/', help='Base directory for dataset.') 238 | parser.add_argument('--no_shuffle_train', action='store_true', default=False, help='Disable shuffling of training samples') 239 | parser.add_argument('--no_shuffle_val', action='store_true', default=False, help='Disable shuffling of validation samples') 240 | 
241 |     # options for network
242 |     parser.add_argument('--file', type=str, default=None, help='Path to a saved model (savegame).')
243 |     parser.add_argument('--net', type=str, default=net_default, help=('Network architecture [%s]' % ','.join(name2net.keys())))
244 |     parser.add_argument('--side', type=int, default=32, help='Output resolution, if dataset has multiple resolutions (default: 32).')
245 |     parser.add_argument('--ncomp', type=int, default=1, help='Number of nested shape layers (default: 1)')
246 |     parser.add_argument('--ninf', type=int, default=8, help='Number of initial feature channels (default: 8)')
247 |     parser.add_argument('--ngf', type=int, default=512, help='Number of inner channels to train (default: 512)')
248 |     parser.add_argument('--noutf', type=int, default=128, help='Number of penultimate feature channels (default: 128)')
249 |     parser.add_argument('--down', type=int, default=5, help='Number of downsampling blocks. (default: 5)')
250 |     parser.add_argument('--block', type=int, default=1, help='Number of inner blocks at same resolution. (default: 1)')
251 | 
252 |     # options for visualisation
253 |     parser.add_argument('--vis_inputs', action='store_true', default=False, help='if True, will only visualize inputs')
254 | 
255 |     # other options
256 |     parser.add_argument('--label_type', type=str, default='vox', help='Type of representation: vox (voxels), sdf or chamfer')
257 |     parser.add_argument('--path_to_res', type=str, default='/media/results', help='path to output results')
258 |     parser.add_argument('--path_to_data', type=str, default='/media/data', help='path to data')
259 |     parser.add_argument('--path_to_prep_shapenet', type=str, default='/media/data/prep_shapenet', help='path to prep shapenet')
260 | 
261 |     parser.add_argument('--subset', type=str, default='train', help='data subset, can be train or test')
262 | 
263 |     parser.add_argument('--cat_id', type=str, default='02958343', help='cat_id, default is cars 02958343')
264 |     parser.add_argument('--p_norm', type=int, default=1, help='p_norm for paper loss (default: 1)')
265 |     parser.add_argument('--eps_delta', type=float, default=0.1, help='epsilon for dirac delta')
266 | 
267 | 
268 |     args = parser.parse_args()
269 | 
270 | 
271 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
272 |     args.shuffle_train = not args.no_shuffle_train
273 |     args.shuffle_val = not args.no_shuffle_val
274 | 
275 |     device = torch.device("cuda:{}".format(args.gpu) if args.cuda else "cpu")
276 | 
277 |     torch.manual_seed(args.seed)
278 | 
279 |     if args.vis_inputs:
280 |         args.shuffle_train = False
281 | 
282 |     # load paths of voxels and views
283 |     try:
284 |         logging.info('Initializing dataset "%s"' % args.dataset)
285 |         Collector = ShapeNet3DR2N2Collector(base_dir=args.basedir, cat_id=args.cat_id,
286 |                                             representation=args.label_type, side=args.side, p_norm=args.p_norm)
287 |     except KeyError:
288 |         logging.error('A dataset named "%s" is not available.' % args.dataset)
289 |         exit(1)
290 | 
291 |     # set_trace()
292 |     # print('check out samples')
293 | 
294 |     logging.info('Initializing dataset loader')
295 |     if args.subset == 'train':
296 |         samples = Collector.train()
297 |     elif args.subset == 'test':
298 |         samples = Collector.test()
299 |     else:
300 |         print('unknown subset')
301 |         exit()
302 | 
303 |     logging.info('Found %d training samples.' % len(samples))
304 | 
305 |     acoto_dataset = DatasetLoader(samples, args.side,
306 |                         input_transform=transforms.Compose([transforms.ToTensor(), RandomColorFlip()]))
307 | 
308 |     train_loader = torch.utils.data.DataLoader(acoto_dataset, \
309 |                     batch_size=args.batchsize, shuffle=args.shuffle_train, num_workers=args.nthreads, \
310 |                     pin_memory=True)
311 | 
312 |     if not args.no_val:
313 |         samples = Collector.val()
314 |         logging.info('Found %d validation samples.' % len(samples))
315 |         val_loader = torch.utils.data.DataLoader(DatasetLoader(samples, args.side, args.ncomp, \
316 |                         input_transform=transforms.Compose([transforms.ToTensor()])), \
317 |                         batch_size=args.batchsize, shuffle=args.shuffle_val, num_workers=args.nthreads, \
318 |                         pin_memory=True)
319 |         pass
320 | 
321 |     samples = []
322 | 
323 |     # load network
324 |     try:
325 |         logging.info('Initializing "%s" network' % args.net)
326 |         net = name2net[args.net](\
327 |                 num_input_channels=3,
328 |                 num_initial_channels=args.ninf,
329 |                 num_inner_channels=args.ngf,
330 |                 num_penultimate_channels=args.noutf,
331 |                 num_output_channels=6*args.ncomp,
332 |                 input_resolution=128,
333 |                 output_resolution=args.side,
334 |                 bottleneck_dim = 128,
335 |                 num_downsampling=args.down,
336 |                 num_blocks=args.block,
337 |                 ).to(device)
338 |         logging.info(net)
339 |     except KeyError:
340 |         logging.error('A network named "%s" is not available.' % args.net)
341 |         exit(2)
342 |     if args.file:
343 |         savegame = torch.load(args.file)
344 |         net.load_state_dict(savegame['state_dict'])
345 | 
346 |     # init optimizer
347 |     try:
348 |         logging.info('Initializing "%s" optimizer with learning rate = %f and weight decay = %f' % (args.optim, args.lr, args.decay))
349 |         optimizer = name2optim[args.optim](net.parameters(), lr=args.lr, weight_decay=args.decay)
350 |     except KeyError:
351 |         logging.error('An optimizer named "%s" is not available.' % args.optim)
352 |         exit(3)
353 | 
354 |     try:
355 |         net.train()
356 | 
357 |         scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.drop, gamma=0.5)
358 |         if args.label_type == 'vox':
359 |             loss_type = 'cross'
360 |         elif args.label_type == 'sdf':
361 |             loss_type = 'l2'
362 |         elif args.label_type == 'chamfer':
363 |             loss_type = 'chamfer'
364 |         else:
365 |             print('unknown label_type')
366 |             exit()
367 | 
368 | 
369 |         agg_loss = 0.
370 |         count = 0
371 | 
372 |         for epoch in range(1, args.epochs + 1):
373 |             scheduler.step()
374 |             # VIS INPUTS BLOCK
375 |             if args.vis_inputs:
376 |                 coto_dataiter = iter(train_loader)
377 |                 while True:
378 |                     set_trace()
379 |                     print('check out images and labels')
380 |                     images, labels, example_ids = next(coto_dataiter)
381 |                     for example_nr in range(args.batchsize):
382 |                         c_view = images[example_nr].numpy()
383 |                         c_label = labels[example_nr].numpy()
384 |                         vis_view(c_view)
385 |                         if args.label_type == 'vox':
386 |                             vis_voxels(c_label)
387 |                         elif args.label_type == 'sdf':
388 |                             vis_sdf(c_label)
389 |                         elif args.label_type == 'chamfer':
390 |                             # vis gt manifold
391 |                             path_to_man = os.path.join(args.path_to_prep_shapenet,
392 |                                                        args.cat_id, 'meshes',
393 |                                                        example_ids[example_nr],
394 |                                                        'sim_manifold_50000.obj')
395 |                             c_label_man = trimesh.load_mesh(path_to_man)
396 |                             if type(c_label_man) is list:
397 |                                 c_label_man = merge_mesh(c_label_man)
398 |                             vis_mesh(c_label_man.vertices, c_label_man.faces)
399 |                             # vis gt SDF
400 |                             path_to_sdf = os.path.join(args.path_to_data,
401 |                                                        'ShapeNetSDF'+str(args.side), args.cat_id,
402 |                                                        example_ids[example_nr],
403 |                                                        'sdf_manifold_50000_grid_size_'+str(args.side)+'.npy')
404 |                             c_label_sdf = np.load(path_to_sdf)\
405 |                                             .reshape((args.side,)*3)
406 |                             vis_sdf(c_label_sdf)
407 | 
408 |                         else:
409 |                             print('unknown label_type')
410 |                             exit()
411 |                         show()
412 |                         print('Should I stop? ')
413 |                         user_response = input()
414 |                         if user_response == 'y':
415 |                             exit()
416 |                         else:
417 |                             break
418 |             tmp_loss_list = []
419 | 
420 |             for batch_idx, (inputs, targets, _) in enumerate(train_loader):
421 |                 optimizer.zero_grad()
422 | 
423 |                 inputs = inputs.to(device, non_blocking=True)
424 |                 targets = targets.to(device, non_blocking=True)
425 |                 pred = net(inputs)  # this is compressed
426 | 
427 |                 if loss_type == 'chamfer':
428 |                     loss_chamf, loss_sdf = my_loss(pred, targets, loss_type, eps=args.eps_delta,
429 |                                                    p_norm=args.p_norm)
430 |                     alpha_sdf = 0.1
431 |                     loss = loss_chamf + alpha_sdf * loss_sdf
432 |                 else:
433 |                     loss = my_loss(pred, targets, loss_type, eps=args.eps_delta,
434 |                                    p_norm=args.p_norm)
435 | 
436 | 
437 |                 loss.backward()
438 |                 optimizer.step()
439 | 
440 |                 agg_loss += loss.detach()
441 |                 count += inputs.shape[0]
442 |                 tmp_loss_list.append(loss.item())
443 |                 pass
444 | 
445 |                 if batch_idx % args.log_inter == 0:
446 |                     logging.info('%d/%d: Train loss: %.10f [%s]' % (epoch, batch_idx, agg_loss.item()/count, args.title))
447 | 
448 |                     agg_loss = 0.
449 |                     count = 0
450 | 
451 |             filename = '%s_%s_%s_%d.pth.tar'\
452 |                 % (args.title, args.dataset, args.label_type, epoch)
453 |             path_to_out = os.path.join(args.path_to_res, args.cat_id, args.label_type, 'loss',
454 |                                        filename+'_loss.txt')
455 |             cmd = 'mkdir -p '+os.path.dirname(path_to_out)
456 |             os.system(cmd)
457 |             tmp_loss_list = list(map(str, tmp_loss_list))
458 | 
459 |             with open(path_to_out, 'a') as fout:
460 |                 fout.write('\n'.join(tmp_loss_list))
461 | 
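The checkpoints written below bundle the weights with the architecture
hyper-parameters, which is how test.py can rebuild the network from the file
alone. A sketch of reading one back (hypothetical path, illustration only):

    import torch
    sg = torch.load('matryoshka_ShapeNet_vox_10.pth.tar', map_location='cpu')
    print(sg['side'], sg['down'], sg['block'])  # resolution and depth settings
    # net.load_state_dict(sg['state_dict'])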
462 |             if (epoch % args.save_inter == 0):
463 |                 filename = '%s_%s_%s_%d.pth.tar'\
464 |                     % (args.title, args.dataset, args.label_type, epoch)
465 |                 logging.info('Saving model to %s.' % filename)
466 |                 torch.save({'state_dict': net.state_dict(),
467 |                             'optimizer' : optimizer.state_dict(),
468 |                             'ninf':args.ninf,
469 |                             'ngf':args.ngf,
470 |                             'noutf':args.noutf,
471 |                             'block':args.block,
472 |                             'side': args.side,
473 |                             'down':args.down,
474 |                             'epoch': epoch,
475 |                             'optim': args.optim,
476 |                             'lr': args.lr,
477 |                             }, os.path.join(args.path_to_res, args.cat_id, args.label_type, filename))
478 | 
479 |             if not args.no_val and epoch % args.val_inter == 0:
480 |                 net.eval()
481 | 
482 |                 agg_iou = 0.
483 |                 count = 0
484 |                 with torch.no_grad():
485 |                     for batch_idx, (inputs, targets) in enumerate(val_loader):
486 | 
487 |                         inputs = inputs.to(device, non_blocking=True)
488 |                         targets = targets.to(device, non_blocking=True)
489 | 
490 |                         pred = net(inputs)
491 |                         iou, bs = iou_shapelayer(shlx2shl(pred), targets, id1, id2, id3)
492 |                         agg_iou += float(iou)
493 |                         count += bs
494 | 
495 |                         pass
496 |                     pass
497 | 
498 |                 net.train()
499 | 
500 |                 total_iou = (100 * agg_iou / count) if count > 0 else 0
501 | 
502 |                 logging.info('%d: Val set accuracy, iou: %.1f [%s]' % (epoch, total_iou, args.title))
503 |                 pass
504 |             pass
505 | 
506 |     except KeyboardInterrupt:
507 |         pass
508 | 
509 |     if not args.no_save:
510 |         filename = '%s_%s_%d.pth.tar' % (args.title, args.dataset, epoch)
511 |         logging.info('Saving model to %s.' % filename)
512 |         torch.save({'state_dict': net.state_dict(),
513 |                     'optimizer' : optimizer.state_dict(),
514 |                     'ninf':args.ninf,
515 |                     'ngf':args.ngf,
516 |                     'noutf':args.noutf,
517 |                     'block':args.block,
518 |                     'side': args.side,
519 |                     'down':args.down,
520 |                     'epoch': epoch,
521 |                     'optim': args.optim,
522 |                     'lr': args.lr,
523 |                     }, os.path.join(args.path_to_res, args.cat_id, args.label_type, filename))
524 | 
--------------------------------------------------------------------------------
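A hypothetical end-to-end invocation, tying train.py and test.py together
(paths and values are examples, not repo defaults beyond those shown above):

    python train.py --label_type sdf --cat_id 02958343 --side 32 \
        --batchsize 16 --epochs 100 --basedir /media/data/
    python test.py --label_type sdf --set test --gen_report iou \
        --file /media/results/02958343/sdf/matryoshka_ShapeNet_sdf_100.pth.tar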