├── .gitmodules
├── README.md
├── config_cars.json
├── config_chairs.json
├── config_planes.json
├── dataset
│   └── dataset.py
├── dataset_generation_scripts
│   ├── cloud_export.py
│   ├── generate.py
│   ├── obj2gltf.py
│   ├── render.py
│   ├── render180.py
│   └── requirements.txt
├── eval.py
├── examples
│   ├── car.gif
│   ├── chair.gif
│   ├── interpolation.gif
│   └── plane.gif
├── models
│   ├── encoder.py
│   ├── nerf.py
│   └── resnet.py
├── nerf_helpers.py
├── pts2nerf.py
├── render_samples.py
├── requirements.txt
└── utils.py

/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "ChamferDistancePytorch"]
2 | path = ChamferDistancePytorch
3 | url = https://github.com/ThibaultGROUEIX/ChamferDistancePytorch
4 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Points2NeRF: Generating Neural Radiance Fields from 3D point cloud
2 | 
3 | 
4 | ![Car](examples/car.gif) ![Plane](examples/plane.gif) ![Chair](examples/chair.gif)
5 | ![Interpolation](examples/interpolation.gif)
6 | 
7 | | arXiv |
8 | | :---- |
9 | | [Points2NeRF: Generating Neural Radiance Fields from 3D point cloud](https://arxiv.org/pdf/2206.01290.pdf)|
10 | 
11 | 
12 | The model is based on a VAE/hypernetwork architecture: it takes a colored point cloud as input, encodes it into a latent space, and generates NeRF (xyz$\alpha\beta$ -> RGB) functions that can be used for volumetric rendering reconstruction.
13 | 
14 | ### Abstract
15 | *Contemporary registration devices for 3D visual information, such as LIDARs and various depth cameras, capture data as 3D point clouds. In turn, such clouds are challenging to be processed due to their size and complexity. Existing methods address this problem by fitting a mesh to the point cloud and rendering it instead. This approach, however, leads to the reduced fidelity of the resulting visualization and misses color information of the objects crucial in computer graphics applications. In this work, we propose to mitigate this challenge by representing 3D objects as Neural Radiance Fields (NeRFs). We leverage a hypernetwork paradigm and train the model to take a 3D point cloud with the associated color values and return a NeRF network's weights that reconstruct 3D objects from input 2D images. Our method provides efficient 3D object representation and offers several advantages over the existing approaches, including the ability to condition NeRFs and improved generalization beyond objects seen in training. The latter we also confirmed in the results of our empirical evaluation.*
16 | 
17 | ## Requirements
18 | - Dependencies stored in `requirements.txt`.
19 | - submodule [ChamferDistancePytorch](https://github.com/ThibaultGROUEIX/ChamferDistancePytorch)
20 | - Python 3.9.12
21 | - CUDA
22 | 
23 | ## Usage
24 | 
25 | ### Installation
26 | Create a new conda environment and install the dependencies:
27 | `pip install -r requirements.txt`
28 | 
29 | ### Training
30 | Edit the config file by setting the dataset and results paths, then run:
31 | 
32 | `python pts2nerf.py config_cars.json`
33 | 
34 | Results will be saved in the directory defined in the config file.
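For reference, the paths in question sit at the top of each config file; the excerpt below repeats the relevant fields from `config_cars.json` (reproduced in full later in this listing) with their default values:

```json
{
  "results_dir": "./results/cars",
  "data_dir": "./pts2nerf_data",
  "classes": ["cars"]
}
```

`data_dir` should point at the folder containing the per-class `sampled/*.npz` files described under Data Preparation below, and `results_dir` is where training results are saved.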
35 | 
36 | ### Evaluation and Sampling
37 | 
38 | You can use a pre-trained model:
39 | 
40 | [Download model here.](https://drive.google.com/drive/folders/1dcSxbXfSYpwjcazrVsHm3WNWcB8uPps-?usp=sharing)
41 | 
42 | For calculating metrics:
43 | `python eval.py config_cars.json`
44 | 
45 | For sampling images and interpolations:
46 | 
47 | `python render_samples.py config_cars.json -o_anim_count=10 -g_anim_count=25 -i_anim_count=5 -train_ds=0 -epoch=1000`
48 | 
49 | The line above will render image sets for 10 reconstructed objects, 25 generated objects, and 5 interpolations.
50 | For each object, and for some interpolation steps, the script also produces a 3D mesh using the marching cubes algorithm.
51 | 
52 | ### Prepared Dataset
53 | 
54 | [Download dataset here.](https://ujchmura-my.sharepoint.com/:u:/g/personal/przemyslaw_spurek_uj_edu_pl/ETy5BPpf4ZFLorYEpXxhRRcBY1ASvCqDCgEX_h75Um6MlA?e=MTJdaj)
55 | 
56 | The class folders should be placed in the local folder specified in the experiment's config file.
57 | For metrics calculation you need to download ShapeNetCore.v2 yourself and specify its location in the config file.
58 | 
59 | ### Data Preparation
60 | 
61 | Use the scripts in the `dataset_generation_scripts` folder (create a new `conda` environment for them if needed),
62 | or download the provided data: ShapeNet cars, planes, and chairs, each with 50 renders at 200x200 resolution and 2048 colored points per object.
63 | 
64 | 1. Download ShapeNet, or another dataset of your choice that contains 3D models.
65 | 2. If using ShapeNet, convert the models to the `.gltf` format (this removes duplicate faces, etc.).
66 | 3. Use `cloud_export.py` to sample a colored point cloud from each object.
67 | 4. Render the objects using `generate.py`, which requires `Blender`.
68 | 5. Put the results in a separate folder with the structure: `dataset_path_from_config_file/<class_name>/sampled/<count>_<name>.npz`
69 | 6. Use that dataset folder as the input for training.
70 | 
71 | Note:
72 | - The `<name>` part (the ShapeNet model id) is required if you want evaluation to find the original `.obj` files in order to sample and compare point positions with the model reconstruction.
73 | - training requires considerable amount of GPU memory, it can be reduced by changing parameters in the `config` files (`N_samples`, `N_rand`, `chunk` sizes for target NeRF networks and the Hypernetwork) 74 | -------------------------------------------------------------------------------- /config_cars.json: -------------------------------------------------------------------------------- 1 | { 2 | "results_dir": "./results/cars", 3 | "clean_results_dir": true, 4 | "clean_weights_dir": true, 5 | "cuda": true, 6 | "gpu": 0, 7 | "data_dir": "./pts2nerf_data", 8 | "classes": [ 9 | "cars" 10 | ], 11 | "n_points": 2048, 12 | "max_epochs": 2000, 13 | "poses": 1, 14 | "batch_size": 1, 15 | "shuffle": true, 16 | "z_size": 4096, 17 | "seed": 111, 18 | "i_log": 1, 19 | "i_sample": 1, 20 | "i_save": 100, 21 | "resnet": false, 22 | "lr_decay": 0.999, 23 | "model": { 24 | "D": { 25 | "dropout": 0.5, 26 | "use_bias": true, 27 | "relu_slope": 0.2 28 | }, 29 | "HN": { 30 | "use_bias": true, 31 | "relu_slope": 0.2, 32 | "arch": [ 33 | 4096, 34 | 8192 35 | ], 36 | "chunk_size": 16384 37 | }, 38 | "E": { 39 | "use_bias": true, 40 | "relu_slope": 0.2 41 | }, 42 | "TN": { 43 | "use_bias": true, 44 | "D": 8, 45 | "W": 256, 46 | "skips": [ 47 | 4 48 | ], 49 | "peturb": 1, 50 | "N_importance": 0, 51 | "N_samples": 256, 52 | "N_rand": 1024, 53 | "white_bkgd": true, 54 | "use_viewdirs": false, 55 | "raw_noise_std": 0, 56 | "multires": 10, 57 | "multires_views": 4, 58 | "i_embed": 0, 59 | "netchunk": 8192, 60 | "chunk": 16384, 61 | "relu_slope": 0.2, 62 | "freeze_layers_learning": false, 63 | "input_ch_embed": 63, 64 | "input_ch_views_embed": 27 65 | } 66 | }, 67 | "optimizer": { 68 | "D": { 69 | "hyperparams": { 70 | "lr": 5e-05, 71 | "betas": [ 72 | 0.9, 73 | 0.999 74 | ] 75 | } 76 | }, 77 | "E_HN": { 78 | "hyperparams": { 79 | "lr": 5e-05 80 | } 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /config_chairs.json: -------------------------------------------------------------------------------- 1 | { 2 | "results_dir": "./results/chairs", 3 | "clean_results_dir": true, 4 | "clean_weights_dir": true, 5 | "cuda": true, 6 | "gpu": 0, 7 | "data_dir": "./pts2nerf_data", 8 | "classes": [ 9 | "chairs" 10 | ], 11 | "n_points": 2048, 12 | "max_epochs": 2000, 13 | "poses": 1, 14 | "batch_size": 1, 15 | "shuffle": true, 16 | "z_size": 4096, 17 | "seed": 111, 18 | "i_log": 1, 19 | "i_sample": 1, 20 | "i_save": 100, 21 | "resnet": false, 22 | "lr_decay": 0.999, 23 | "model": { 24 | "D": { 25 | "dropout": 0.5, 26 | "use_bias": true, 27 | "relu_slope": 0.2 28 | }, 29 | "HN": { 30 | "use_bias": true, 31 | "relu_slope": 0.2, 32 | "arch": [ 33 | 4096, 34 | 8192 35 | ], 36 | "chunk_size": 16384 37 | }, 38 | "E": { 39 | "use_bias": true, 40 | "relu_slope": 0.2 41 | }, 42 | "TN": { 43 | "use_bias": true, 44 | "D": 8, 45 | "W": 256, 46 | "skips": [ 47 | 4 48 | ], 49 | "peturb": 1, 50 | "N_importance": 0, 51 | "N_samples": 256, 52 | "N_rand": 1024, 53 | "white_bkgd": true, 54 | "use_viewdirs": false, 55 | "raw_noise_std": 0, 56 | "multires": 10, 57 | "multires_views": 4, 58 | "i_embed": 0, 59 | "netchunk": 8192, 60 | "chunk": 16384, 61 | "relu_slope": 0.2, 62 | "freeze_layers_learning": false, 63 | "input_ch_embed": 63, 64 | "input_ch_views_embed": 27 65 | } 66 | }, 67 | "optimizer": { 68 | "D": { 69 | "hyperparams": { 70 | "lr": 5e-05, 71 | "betas": [ 72 | 0.9, 73 | 0.999 74 | ] 75 | } 76 | }, 77 | "E_HN": { 78 | "hyperparams": { 79 | "lr": 5e-05 80 | } 81 | } 82 | } 83 | } 84 
| -------------------------------------------------------------------------------- /config_planes.json: -------------------------------------------------------------------------------- 1 | { 2 | "results_dir": "./results/planes", 3 | "clean_results_dir": true, 4 | "clean_weights_dir": true, 5 | "cuda": true, 6 | "gpu": 0, 7 | "data_dir": "./pts2nerf_data", 8 | "classes": [ 9 | "planes" 10 | ], 11 | "n_points": 2048, 12 | "max_epochs": 2000, 13 | "poses": 1, 14 | "batch_size": 1, 15 | "shuffle": true, 16 | "z_size": 4096, 17 | "seed": 111, 18 | "i_log": 1, 19 | "i_sample": 1, 20 | "i_save": 100, 21 | "resnet": false, 22 | "lr_decay": 0.999, 23 | "model": { 24 | "D": { 25 | "dropout": 0.5, 26 | "use_bias": true, 27 | "relu_slope": 0.2 28 | }, 29 | "HN": { 30 | "use_bias": true, 31 | "relu_slope": 0.2, 32 | "arch": [ 33 | 4096, 34 | 8192 35 | ], 36 | "chunk_size": 16384 37 | }, 38 | "E": { 39 | "use_bias": true, 40 | "relu_slope": 0.2 41 | }, 42 | "TN": { 43 | "use_bias": true, 44 | "D": 8, 45 | "W": 256, 46 | "skips": [ 47 | 4 48 | ], 49 | "peturb": 1, 50 | "N_importance": 0, 51 | "N_samples": 256, 52 | "N_rand": 1024, 53 | "white_bkgd": true, 54 | "use_viewdirs": false, 55 | "raw_noise_std": 0, 56 | "multires": 10, 57 | "multires_views": 4, 58 | "i_embed": 0, 59 | "netchunk": 8192, 60 | "chunk": 16384, 61 | "relu_slope": 0.2, 62 | "freeze_layers_learning": false, 63 | "input_ch_embed": 63, 64 | "input_ch_views_embed": 27 65 | } 66 | }, 67 | "optimizer": { 68 | "D": { 69 | "hyperparams": { 70 | "lr": 5e-05, 71 | "betas": [ 72 | 0.9, 73 | 0.999 74 | ] 75 | } 76 | }, 77 | "E_HN": { 78 | "hyperparams": { 79 | "lr": 5e-05 80 | } 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import pandas as pd 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | from os.path import join 7 | 8 | synth_id_to_category = { 9 | '02691156': 'planes', '02773838': 'bag', '02801938': 'basket', #airplane = planes temporary 10 | '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', 11 | '02834778': 'bicycle', '02843684': 'birdhouse', '02871439': 'bookshelf', 12 | '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', 13 | '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera', 14 | '02954340': 'cap', '02958343': 'cars', '03001627': 'chairs', #car=cars temporary chair=chairs 15 | '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', 16 | '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can', 17 | '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', 18 | '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', 19 | '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', 20 | '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', 21 | '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', 22 | '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', 23 | '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', 24 | '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', 25 | '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', 26 | '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', 27 | '04554684': 'washer', '02858304': 'boat', '02992529': 'cellphone' 28 | } 29 | 30 | category_to_synth_id = {v: k for k, v in 
synth_id_to_category.items()} 31 | synth_id_to_number = {k: i for i, k in enumerate(synth_id_to_category.keys())} 32 | 33 | 34 | class NeRFShapeNetDataset(Dataset): 35 | def __init__(self, root_dir='/home/datasets/nerfdataset', shapenet_root_dir='/shared/sets/datasets/3D_points/ShapeNetCore.v2', classes=[], 36 | transform=None, train=True): 37 | """ 38 | Args: 39 | root_dir (string): Directory of structure: 40 | > 41 | >classname1 42 | >sampled 43 | >count_{name}.npz 44 | >classname2 45 | ... 46 | 47 | where sampled has all the .NPZ of format: images : (n, W, H, channels), cam_poses (n, 4, 4), data :(N, 6) 48 | and shapenet is a shapenet directory for this class (contains .obj files). 49 | 50 | classes: list of class names 51 | 52 | transform (callable, optional): Optional transform to be applied on a sample. 53 | """ 54 | self.root_dir = root_dir 55 | self.shapenet_root_dir = shapenet_root_dir 56 | self.transform = transform 57 | 58 | self.classes = classes 59 | self.train = train 60 | 61 | self.data = [] 62 | 63 | self._load() 64 | 65 | def __len__(self): 66 | if self.train: 67 | return len(self.train_data) 68 | else: 69 | return len(self.test_data) 70 | 71 | 72 | def __getitem__(self, idx): 73 | if self.train: 74 | data_files = self.train_data 75 | else: 76 | data_files = self.test_data 77 | 78 | sample = np.load(data_files['sample_filename'][idx]) 79 | class_name = data_files['class'][idx] 80 | if self.transform: 81 | sample = self.transform(sample) 82 | 83 | #return self.data[idx] 84 | return sample, class_name, data_files['obj_filename'][idx] 85 | 86 | def _load(self): 87 | print("Loading dataset:") 88 | self.train_data = pd.DataFrame(columns=['class', 'name', 'sample_filename', 'obj_filename']) 89 | self.test_data = pd.DataFrame(columns=['class', 'name', 'sample_filename', 'obj_filename']) 90 | 91 | for data_class in self.classes: 92 | df = pd.DataFrame(columns=['class', 'name', 'sample_filename', 'obj_filename']) 93 | print(data_class) 94 | 95 | npz_glob = glob.glob(join(self.root_dir,data_class,'sampled','*.npz')) 96 | print(len(npz_glob)) 97 | for file in npz_glob: 98 | sample_name = file.split('_')[-1].split('.')[0] 99 | 100 | df = df.append({'class': data_class, 101 | 'name': sample_name, 102 | 'sample_filename':file, 103 | 'obj_filename':join(self.shapenet_root_dir, category_to_synth_id[data_class], sample_name, 'models','model_normalized.obj')}, 104 | ignore_index=True) 105 | 106 | #with np.load(file) as data: 107 | #self.data.append({'data': np.array(data['data']), 'images':np.array(data['images']), 'cam_poses':np.array(data['cam_poses'])}) 108 | 109 | #Sort and split, same like Atlasnet 110 | df = df.sort_values(by=['name']) 111 | df_train = df.head(max(1,int(len(df)*(0.8)))) 112 | df_test = df.tail(max(1,int(len(df)*(0.2)))) 113 | 114 | self.train_data = pd.concat([self.train_data, df_train]) 115 | self.test_data = pd.concat([self.test_data, df_test]) 116 | 117 | self.train_data = self.train_data.reset_index(drop=True) 118 | self.test_data = self.test_data.reset_index(drop=True) 119 | 120 | print("Loaded train data:", len(self.train_data), "samples") 121 | print("Loaded test data:", len(self.test_data), "samples") -------------------------------------------------------------------------------- /dataset_generation_scripts/cloud_export.py: -------------------------------------------------------------------------------- 1 | import os 2 | import bpy 3 | import numpy as np 4 | from random import randrange 5 | import glob 6 | import traceback 7 | import argparse 8 | 9 | def 
get_materials_info(material_slots): 10 | images = [] 11 | for material_slot in material_slots: 12 | image = None 13 | color = None 14 | local_pixels = None 15 | for node in material_slot.material.node_tree.nodes: 16 | if node.type == 'TEX_IMAGE': 17 | image = bpy.data.images[node.image.name] 18 | local_pixels = list(image.pixels[:]) 19 | elif node.type == 'BSDF_PRINCIPLED': 20 | color = node.color 21 | if image != None: 22 | images.append(('img', image, local_pixels)) 23 | else: 24 | images.append(('col', color)) 25 | return images 26 | 27 | def clamp_uv(val): 28 | return max(0, min(val, 1)) 29 | 30 | def should_skip(selected_verts, vert_idx, inserted, ob): 31 | try: 32 | index = selected_verts.index(ob.data.vertices[vert_idx]) 33 | if inserted[index] == True: 34 | return True 35 | inserted[index] = True 36 | return False 37 | except ValueError: 38 | return True 39 | 40 | def append_vert_and_color(verts_coordinates, verts_colors, ob, vert_idx, loop_idx, face, images, width, height, local_pixels): 41 | if images[face.material_index][0] == 'col': 42 | verts_coordinates.append((ob.data.vertices[vert_idx].co[0], ob.data.vertices[vert_idx].co[1], ob.data.vertices[vert_idx].co[2])) 43 | verts_colors.append((images[face.material_index][1][0], images[face.material_index][1][1], images[face.material_index][1][2], 1)) 44 | else: 45 | uv_coords = ob.data.uv_layers.active.data[loop_idx].uv 46 | 47 | target = [round(clamp_uv(uv_coords.x) * (width - 1)), round(clamp_uv(uv_coords.y) * (height - 1))] 48 | index = ( target[1] * width + target[0] ) * 4 49 | 50 | verts_coordinates.append((ob.data.vertices[vert_idx].co[0], ob.data.vertices[vert_idx].co[1], ob.data.vertices[vert_idx].co[2])) 51 | verts_colors.append((local_pixels[index], local_pixels[index + 1], local_pixels[index + 2], 1)) 52 | 53 | def process_faces_for_cloud(ob, images, inserted): 54 | verts_coordinates, verts_colors = [], [] 55 | image, local_pixels, width, height = None, None, None, None 56 | for face in ob.data.polygons: 57 | if images[face.material_index][0] == 'img': 58 | image = images[face.material_index][1] 59 | local_pixels = images[face.material_index][2] 60 | width = image.size[0] 61 | height = image.size[1] 62 | else: 63 | image = images[face.material_index][1][2] 64 | 65 | for vert_idx, loop_idx in zip(face.vertices, face.loop_indices): 66 | if should_skip(selected_verts, vert_idx, inserted, ob): 67 | continue 68 | append_vert_and_color(verts_coordinates, verts_colors, ob, vert_idx, loop_idx, face, images, width, height, local_pixels) 69 | return verts_coordinates, verts_colors 70 | 71 | def export_to_cloud(ob, filepath : str, selected_verts): 72 | images = get_materials_info(ob.material_slots) 73 | inserted = np.zeros((len(selected_verts)), dtype=bool) 74 | 75 | verts_coordinates, verts_colors = process_faces_for_cloud(ob, images, inserted) 76 | 77 | np.savetxt(filepath + ob.name + '_mesh_data.txt', np.asarray(verts_coordinates), delimiter=' ', fmt='%f') 78 | np.savetxt(filepath + ob.name + '_color_data.txt', np.asarray(verts_colors), delimiter=' ', fmt='%f') 79 | 80 | def bake(filepath : str, selected_verts): 81 | bpy.context.view_layer.objects.active = bpy.data.objects[0] 82 | bpy.ops.object.mode_set(mode='VERTEX_PAINT') 83 | 84 | for obj in bpy.context.scene.objects: 85 | if hasattr(obj.data, 'vertices') == False: 86 | continue 87 | 88 | export_to_cloud(obj, filepath, selected_verts) 89 | bpy.ops.object.mode_set(mode='OBJECT') 90 | 91 | 92 | def get_axis_index(pos, bound_lower, bound_upper, interval_len, axis_len): 93 | x = 
pos 94 | x_start = bound_lower + interval_len 95 | x_index = 0 96 | while(x_start < x): 97 | x_start += interval_len 98 | x_index += 1 99 | if x_index >= axis_len: 100 | x_index = axis_len - 1 101 | return x_index 102 | 103 | 104 | def get_box_index(vert, bounds, interval_len, axis_len): 105 | return (get_axis_index(vert.co[0], bounds[0][0], bounds[0][1], interval_len[0], axis_len), get_axis_index(vert.co[1], bounds[1][0], bounds[1][1], interval_len[1], axis_len), get_axis_index(vert.co[2], bounds[2][0], bounds[2][1], interval_len[2], axis_len)) 106 | 107 | 108 | def get_random_vertices(boxes, axis_len): 109 | selected_verts = [] 110 | 111 | while(len(selected_verts) < axis_len**3): 112 | box_index = randrange(len(boxes)) 113 | box_len = len(boxes[box_index]) 114 | vert = boxes[box_index].pop(randrange(box_len)) 115 | if vert in selected_verts: 116 | continue 117 | selected_verts.append(vert) 118 | if len(boxes[box_index]) == 0: 119 | boxes.pop(box_index) 120 | return selected_verts 121 | 122 | def put_verts_to_boxes(boxes, vertices, bounds, x_interval_len, y_interval_len, z_interval_len, axis_len): 123 | for vert in vertices: 124 | box_index = get_box_index(vert, bounds, (x_interval_len, y_interval_len, z_interval_len), axis_len) 125 | if boxes[box_index[0]][box_index[1]][box_index[2]] == None: 126 | boxes[box_index[0]][box_index[1]][box_index[2]] = [] 127 | boxes[box_index[0]][box_index[1]][box_index[2]].append(vert) 128 | 129 | def select_verts_subspace(obj, axis_len): 130 | x_sort = sorted(obj.data.vertices, key=lambda v: v.co[0]) 131 | y_sort = sorted(obj.data.vertices, key=lambda v: v.co[1]) 132 | z_sort = sorted(obj.data.vertices, key=lambda v: v.co[2]) 133 | 134 | bounds = ((x_sort[0].co[0], x_sort[-1].co[0]), (y_sort[0].co[1], y_sort[-1].co[1]), (z_sort[0].co[2], z_sort[-1].co[2])) 135 | 136 | x_interval_len = (bounds[0][1] - bounds[0][0]) / axis_len 137 | y_interval_len = (bounds[1][1] - bounds[1][0]) / axis_len 138 | z_interval_len = (bounds[2][1] - bounds[2][0]) / axis_len 139 | 140 | boxes = np.empty((axis_len, axis_len, axis_len), dtype=type(list)) 141 | 142 | put_verts_to_boxes(boxes, x_sort, bounds, x_interval_len, y_interval_len, z_interval_len, axis_len) 143 | 144 | boxes = boxes.flatten() 145 | boxes = [box for box in boxes if box != None] 146 | 147 | return get_random_vertices(boxes, axis_len) 148 | 149 | 150 | def get_subdiv_amount(vert_len, axis_len): 151 | subdiv_threshold = 4**2 * axis_len**3 152 | subdiv_amount = 0 153 | vert_len *= 4 154 | while vert_len < subdiv_threshold: 155 | subdiv_amount += 1 156 | vert_len *= 4 157 | return subdiv_amount 158 | 159 | def merge_doubles(merge_threshold = 0.0000001): 160 | for obj in bpy.context.scene.objects: 161 | if hasattr(obj.data, 'vertices') == False: 162 | continue 163 | print('before doubles merge: ', len(obj.data.vertices)) 164 | bpy.context.view_layer.objects.active = obj 165 | bpy.ops.object.mode_set(mode='EDIT') 166 | bpy.ops.mesh.select_all(action='SELECT') 167 | bpy.ops.mesh.remove_doubles(threshold = merge_threshold) 168 | bpy.ops.mesh.select_all(action='DESELECT') 169 | bpy.ops.mesh.select_mode(type = 'FACE') 170 | bpy.ops.mesh.select_interior_faces() 171 | bpy.ops.mesh.delete(type='FACE') 172 | bpy.ops.object.mode_set(mode='OBJECT') 173 | print('after doubles merge: ', len(obj.data.vertices)) 174 | 175 | 176 | def apply_subsurf(axis_len): 177 | for obj in bpy.context.scene.objects: 178 | if hasattr(obj.data, 'vertices') == False: 179 | continue 180 | subdiv_amount = 1 181 | if subdiv_amount == 0: 182 | return 
183 | 184 | bpy.context.view_layer.objects.active = obj 185 | while(len(obj.data.vertices) < axis_len**3): 186 | bpy.ops.object.modifier_add(type='SUBSURF') 187 | bpy.context.object.modifiers[0].subdivision_type = 'SIMPLE' 188 | bpy.context.object.modifiers[0].levels = subdiv_amount 189 | bpy.ops.object.modifier_apply(modifier='Subdivision') 190 | 191 | def apply_simple_subdivide(min_vert_count): 192 | for obj in bpy.context.scene.objects: 193 | if hasattr(obj.data, 'vertices') == False: 194 | continue 195 | bpy.context.view_layer.objects.active = obj 196 | while(len(obj.data.vertices) < min_vert_count): 197 | # print('subdiv', len(obj.data.vertices)) 198 | bpy.ops.object.mode_set(mode="EDIT") 199 | bpy.ops.mesh.select_all() 200 | bpy.ops.mesh.subdivide(number_cuts=1) 201 | bpy.ops.object.mode_set(mode="OBJECT") 202 | 203 | 204 | def cleanup_mesh(): 205 | for obj in bpy.context.scene.objects: 206 | if hasattr(obj.data, 'vertices') == False: 207 | continue 208 | bpy.context.view_layer.objects.active = obj 209 | bpy.ops.object.mode_set(mode="EDIT") 210 | bpy.ops.mesh.remove_doubles() 211 | bpy.ops.mesh.delete_loose() 212 | bpy.ops.object.mode_set(mode="OBJECT") 213 | 214 | def apply_decimate(): 215 | for obj in bpy.context.scene.objects: 216 | if hasattr(obj.data, 'vertices') == False: 217 | continue 218 | bpy.context.view_layer.objects.active = obj 219 | bpy.ops.object.modifier_add(type='DECIMATE') 220 | bpy.context.object.modifiers[0].ratio = 0.5 221 | bpy.ops.object.modifier_apply(modifier='Decimate') 222 | 223 | 224 | def get_models_directories(path): 225 | obj_list = [] 226 | dir_list = sorted(os.listdir(path)) 227 | dir_list = [os.path.join(path, dir) for dir in dir_list if os.path.isdir(os.path.join(path, dir))] 228 | for dir in dir_list: 229 | for r, d, files in os.walk(dir): 230 | if 'images' not in d: 231 | continue 232 | for r1, d1, f1 in os.walk(os.path.join(dir, 'models')): 233 | for file in f1: 234 | if file.endswith('.obj'): 235 | obj_list.append(os.path.join(dir, 'models', file)) 236 | 237 | print('models with textures: ', len(obj_list)) 238 | return obj_list 239 | 240 | 241 | if __name__ == "__main__": 242 | parser = argparse.ArgumentParser(description='Export cloud of points') 243 | parser.add_argument('shapenet_path', type=str, 244 | help='Relative shapenet path (./shapenet/02958343)') 245 | parser.add_argument('export_path', type=str, 246 | help='Relative output path (./dataset_maker/export/02958343)') 247 | parser.add_argument('check_path', type=str, 248 | help='Relative previous run output path (./dataset_maker/data/02958343/*.txt) to continue sampling in case of error') 249 | 250 | args = parser.parse_args() 251 | 252 | path_to_obj_dir = args.shapenet_path 253 | export_path = args.export_path 254 | check_path = args.check_path 255 | 256 | axis_len = 3 257 | min_vert_subdiv_count = 10000 258 | 259 | bpy.ops.object.delete({"selected_objects": bpy.context.scene.objects}) 260 | 261 | already_generated = glob.glob(check_path) 262 | already_generated_names = [x.split('\\')[-1].split('_')[0] for x in already_generated] 263 | print(already_generated_names) 264 | 265 | obj_list = get_models_directories(path_to_obj_dir) 266 | i = 0 267 | for file in obj_list: 268 | try: 269 | print('file ', file) 270 | name = file.split('\\')[-3] 271 | if name in already_generated_names: 272 | print(name, ' is already generated.') 273 | continue 274 | bpy.ops.import_scene.obj(filepath = file, use_split_objects=False) 275 | for obj in bpy.context.scene.objects: 276 | obj.name = name 277 | 
print('set name ', obj.name) 278 | cleanup_mesh() 279 | apply_decimate() 280 | apply_simple_subdivide(min_vert_subdiv_count) 281 | selected_verts = select_verts_subspace(bpy.data.objects[0], axis_len) 282 | print('selected verts len ', len(selected_verts)) 283 | bake(export_path, selected_verts) 284 | bpy.context.view_layer.objects.active = bpy.data.objects[0] 285 | bpy.ops.object.delete({"selected_objects": bpy.context.scene.objects}) 286 | i += 1 287 | print('export progress: ', (i * 100) / len(obj_list), ' %') 288 | except Exception as e: 289 | bpy.ops.object.mode_set(mode="OBJECT") 290 | bpy.ops.object.delete({"selected_objects": bpy.context.scene.objects}) 291 | f=open("cloud_error.txt", "a") 292 | f.write(file + '\n') 293 | f.write(str(e) + '\n') 294 | f.write(traceback.format_exc()) 295 | f.close() 296 | print(e) 297 | print("Shapenet size: ", len(obj_list)) 298 | print('Generated :', i) 299 | print('Already there: ', len(already_generated_names)) -------------------------------------------------------------------------------- /dataset_generation_scripts/generate.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy 3 | import os 4 | import numpy as np 5 | from pathlib import Path 6 | import subprocess 7 | 8 | from PIL import Image 9 | import json 10 | from numpy import asarray 11 | import shutil 12 | 13 | 14 | DEBUG = False 15 | MAX = 99999 16 | VIEW_COUNT = 50 17 | RESOLUTION = 200 18 | 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser(description='Render gltfs') 23 | parser.add_argument('path', type=str, 24 | help='Glob string for model files (./shapenet/02958343/**/*.gltf)') 25 | parser.add_argument('cp_path', type=str, 26 | help='Cloud of points files from cloud export script (./data/02958343/)') 27 | parser.add_argument('export_name', type=str, 28 | help='Result folder name (./shapenet/02958343)') 29 | 30 | args = parser.parse_args() 31 | files = glob.glob(args.path, recursive=True) 32 | count = len(files) 33 | 34 | RESULT_DIR = f"{args.export_name}_{VIEW_COUNT}_{RESOLUTION}x{RESOLUTION}" 35 | 36 | np.random.seed(1234) 37 | generated = 0 38 | for i, f in enumerate(files): 39 | os.makedirs(f"./{RESULT_DIR}", exist_ok=True) 40 | 41 | print(f"{i}/{count}") 42 | name = f.split('/')[-3] 43 | print(f"{name} opened at {f}") 44 | 45 | 46 | print("Rendering...") 47 | x=subprocess.run(f"blender --background --python render.py -- --output_folder ./tmp {f} --name {name} --views {VIEW_COUNT} --resolution {RESOLUTION}", capture_output=False) #render180.py for circle 48 | print("Images rendered") 49 | 50 | images = [] 51 | cam_poses = [] 52 | 53 | for render in glob.glob(f'./tmp/{name}/*.png'): 54 | images.append(asarray(Image.open(render))) 55 | 56 | f = (render.split('/')[-1]).split('.')[0] 57 | with open(f"./tmp/{name}/{f}.json", "r") as file: 58 | cp = json.load(fp=file) 59 | cam_poses.append(np.array(cp)) 60 | 61 | images = np.array(images, dtype="float16")/255 62 | 63 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) #enforce white background 64 | 65 | cam_poses = np.array(cam_poses) 66 | 67 | if DEBUG: 68 | print(images.shape) 69 | print(cam_poses.shape) 70 | 71 | print("Sampling points...") 72 | 73 | vertices = [] 74 | try: 75 | with open(f'{args.cp_path}/{name}_mesh_data.txt') as file: 76 | for line in file: 77 | vertices.append([float(x) for x in line.split()]) 78 | vertices = np.array(vertices) 79 | except Exception as e: 80 | continue 81 | #print(e) 82 | 83 | colors = [] 
84 | with open(f'{args.cp_path}/{name}_color_data.txt') as file: 85 | for line in file: 86 | colors.append([float(x) for x in line.split()]) 87 | colors = np.array(colors) 88 | 89 | data = np.concatenate((vertices, colors[:,0:3]), axis=1) 90 | 91 | #uncomment if you want to take random number of points 92 | data = data[np.random.choice(data.shape[0], 2048, replace=False), :] 93 | 94 | if DEBUG: 95 | print(data.shape) 96 | print(data) 97 | 98 | print(f"Data for {name} was generated!") 99 | print(f"Saving {name}...") 100 | np.savez_compressed(f"./{RESULT_DIR}/{VIEW_COUNT}_{name}.npz", images=images, cam_poses=cam_poses, data=data) 101 | print(f"{name} was saved!") 102 | 103 | shutil.rmtree(f'./tmp/{name}') 104 | 105 | if DEBUG: 106 | print("DEBUG is TRUE: pausing rendering") 107 | break 108 | 109 | if generated >= MAX: 110 | break 111 | 112 | generated += 1 113 | 114 | print("Data was generated!") 115 | -------------------------------------------------------------------------------- /dataset_generation_scripts/obj2gltf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from tqdm import tqdm 4 | from pathlib import Path 5 | 6 | Path(__file__).parent 7 | 8 | files = glob.glob('.\shapenet\02958343\**\*.obj', recursive=True) 9 | for file in tqdm(files): 10 | os.system(f"obj2gltf -i {file} -o {Path(file).parent}\\model_normalized.gltf") -------------------------------------------------------------------------------- /dataset_generation_scripts/render.py: -------------------------------------------------------------------------------- 1 | # Modified: https://github.com/panmari/stanford-shapenet-renderer 2 | # A simple script that uses blender to render views of a single object by rotation the camera around it. 3 | # Also produces depth map at the same time. 4 | # 5 | # Tested with Blender 2.9 6 | # 7 | # Example: 8 | # blender --background --python mytest.py -- --views 10 /path/to/my.obj 9 | # 10 | 11 | import argparse, sys, os, math, re 12 | import bpy 13 | from glob import glob 14 | import numpy as np 15 | 16 | import json 17 | 18 | import random 19 | 20 | def listify_matrix(matrix): 21 | matrix_list = [] 22 | for row in matrix: 23 | matrix_list.append(list(row)) 24 | return matrix_list 25 | 26 | parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.') 27 | parser.add_argument('--views', type=int, default=100, 28 | help='number of views to be rendered') 29 | parser.add_argument('obj', type=str, 30 | help='Path to the obj file to be rendered.') 31 | parser.add_argument('--output_folder', type=str, default='./tmp', 32 | help='The path the output will be dumped to.') 33 | parser.add_argument('--scale', type=float, default=2.2, 34 | help='Scaling factor applied to model. Depends on size of mesh.') 35 | parser.add_argument('--remove_doubles', type=bool, default=True, 36 | help='Remove double vertices to improve mesh quality.') 37 | parser.add_argument('--edge_split', type=bool, default=True, 38 | help='Adds edge split filter.') 39 | parser.add_argument('--depth_scale', type=float, default=2.2, 40 | help='Scaling that is applied to depth. Depends on size of mesh. Try out various values until you get a good result. Ignored if format is OPEN_EXR.') 41 | parser.add_argument('--color_depth', type=str, default='8', 42 | help='Number of bit per channel used for output. Either 8 or 16.') 43 | parser.add_argument('--format', type=str, default='PNG', 44 | help='Format of files generated. 
Either PNG or OPEN_EXR') 45 | parser.add_argument('--resolution', type=int, default=400, #800! 46 | help='Resolution of the images.') 47 | parser.add_argument('--engine', type=str, default='BLENDER_EEVEE', 48 | help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...') 49 | parser.add_argument('--name', type=str, default='model', 50 | help='Model id') 51 | 52 | argv = sys.argv[sys.argv.index("--") + 1:] 53 | args = parser.parse_args(argv) 54 | 55 | # Set up rendering 56 | context = bpy.context 57 | scene = bpy.context.scene 58 | render = bpy.context.scene.render 59 | 60 | render.engine = args.engine 61 | render.image_settings.color_mode = 'RGBA' # ('RGB', 'RGBA', ...) 62 | render.image_settings.color_depth = args.color_depth # ('8', '16') 63 | render.image_settings.file_format = args.format # ('PNG', 'OPEN_EXR', 'JPEG, ...) 64 | render.resolution_x = args.resolution 65 | render.resolution_y = args.resolution 66 | render.resolution_percentage = 100 67 | render.film_transparent = True 68 | 69 | #set up nodes to change background color 70 | 71 | bpy.context.scene.use_nodes = True 72 | tree = bpy.context.scene.node_tree 73 | composite = tree.nodes[0] 74 | render_layers = tree.nodes[1] 75 | alpha_over = tree.nodes.new(type='CompositorNodeAlphaOver') 76 | links = tree.links 77 | link_1 = links.new(render_layers.outputs[0], alpha_over.inputs[2]) 78 | link_2 = links.new(alpha_over.outputs[0], composite.inputs[0]) 79 | alpha_over.inputs[1].default_value = (1, 1, 1, 0) 80 | 81 | 82 | scene.use_nodes = True 83 | 84 | scene.view_layers["ViewLayer"].use_pass_normal = True 85 | scene.view_layers["ViewLayer"].use_pass_diffuse_color = True 86 | scene.view_layers["ViewLayer"].use_pass_object_index = True 87 | 88 | nodes = bpy.context.scene.node_tree.nodes 89 | links = bpy.context.scene.node_tree.links 90 | 91 | bpy.context.scene.render.use_persistent_data = True 92 | 93 | # Clear default nodes 94 | for n in nodes: 95 | nodes.remove(n) 96 | 97 | # Create input render layer node 98 | render_layers = nodes.new('CompositorNodeRLayers') 99 | 100 | # Create depth output nodes 101 | depth_file_output = nodes.new(type="CompositorNodeOutputFile") 102 | depth_file_output.label = 'Depth Output' 103 | depth_file_output.base_path = '' 104 | depth_file_output.file_slots[0].use_node_format = True 105 | depth_file_output.format.file_format = args.format 106 | depth_file_output.format.color_depth = args.color_depth 107 | if args.format == 'OPEN_EXR': 108 | links.new(render_layers.outputs['Depth'], depth_file_output.inputs[0]) 109 | else: 110 | depth_file_output.format.color_mode = "BW" 111 | 112 | # Remap as other types can not represent the full range of depth. 113 | map = nodes.new(type="CompositorNodeMapValue") 114 | # Size is chosen kind of arbitrarily, try out until you're satisfied with resulting depth map. 
115 | map.offset = [-0.7] 116 | map.size = [args.depth_scale] 117 | map.use_min = True 118 | map.min = [0] 119 | 120 | links.new(render_layers.outputs['Depth'], map.inputs[0]) 121 | links.new(map.outputs[0], depth_file_output.inputs[0]) 122 | 123 | # Create normal output nodes 124 | scale_node = nodes.new(type="CompositorNodeMixRGB") 125 | scale_node.blend_type = 'MULTIPLY' 126 | # scale_node.use_alpha = True 127 | scale_node.inputs[2].default_value = (0.5, 0.5, 0.5, 1) 128 | links.new(render_layers.outputs['Normal'], scale_node.inputs[1]) 129 | 130 | bias_node = nodes.new(type="CompositorNodeMixRGB") 131 | bias_node.blend_type = 'ADD' 132 | # bias_node.use_alpha = True 133 | bias_node.inputs[2].default_value = (0.5, 0.5, 0.5, 0) 134 | links.new(scale_node.outputs[0], bias_node.inputs[1]) 135 | 136 | normal_file_output = nodes.new(type="CompositorNodeOutputFile") 137 | normal_file_output.label = 'Normal Output' 138 | normal_file_output.base_path = '' 139 | normal_file_output.file_slots[0].use_node_format = True 140 | normal_file_output.format.file_format = args.format 141 | links.new(bias_node.outputs[0], normal_file_output.inputs[0]) 142 | 143 | # Create albedo output nodes 144 | alpha_albedo = nodes.new(type="CompositorNodeSetAlpha") 145 | links.new(render_layers.outputs['DiffCol'], alpha_albedo.inputs['Image']) 146 | links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha']) 147 | 148 | albedo_file_output = nodes.new(type="CompositorNodeOutputFile") 149 | albedo_file_output.label = 'Albedo Output' 150 | albedo_file_output.base_path = '' 151 | albedo_file_output.file_slots[0].use_node_format = True 152 | albedo_file_output.format.file_format = args.format 153 | albedo_file_output.format.color_mode = 'RGBA' 154 | albedo_file_output.format.color_depth = args.color_depth 155 | links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0]) 156 | 157 | # Create id map output nodes 158 | id_file_output = nodes.new(type="CompositorNodeOutputFile") 159 | id_file_output.label = 'ID Output' 160 | id_file_output.base_path = '' 161 | id_file_output.file_slots[0].use_node_format = True 162 | id_file_output.format.file_format = args.format 163 | id_file_output.format.color_depth = args.color_depth 164 | 165 | if args.format == 'OPEN_EXR': 166 | links.new(render_layers.outputs['IndexOB'], id_file_output.inputs[0]) 167 | else: 168 | id_file_output.format.color_mode = 'BW' 169 | 170 | divide_node = nodes.new(type='CompositorNodeMath') 171 | divide_node.operation = 'DIVIDE' 172 | divide_node.use_clamp = False 173 | divide_node.inputs[1].default_value = 2**int(args.color_depth) 174 | 175 | links.new(render_layers.outputs['IndexOB'], divide_node.inputs[0]) 176 | links.new(divide_node.outputs[0], id_file_output.inputs[0]) 177 | 178 | # Delete default cube 179 | context.active_object.select_set(True) 180 | bpy.ops.object.delete() 181 | 182 | # Import textured mesh 183 | bpy.ops.object.select_all(action='DESELECT') 184 | 185 | bpy.ops.import_scene.gltf(filepath=args.obj) 186 | 187 | 188 | obj = bpy.context.selected_objects[0] 189 | context.view_layer.objects.active = obj 190 | 191 | # create material 192 | #mat = bpy.data.materials.new(name='Material') 193 | 194 | #obj.data.materials.append(mat) 195 | #mat.use_nodes=True 196 | 197 | # let's create a variable to store our list of nodes 198 | #mat_nodes = mat.node_tree.nodes 199 | 200 | # let's set the metallic to 1.0 201 | #mat_nodes['Principled BSDF'].inputs['Metallic'].default_value=1.0 202 | #mat_nodes['Principled 
BSDF'].inputs['Roughness'].default_value=0.0 203 | 204 | # Possibly disable specular shading 205 | for slot in obj.material_slots: 206 | node = slot.material.node_tree.nodes['Principled BSDF'] 207 | node.inputs['Specular'].default_value = 0.3 208 | node.inputs['Metallic'].default_value=0.5 209 | node.inputs['Roughness'].default_value=0.25 210 | 211 | if args.scale != 1: 212 | bpy.ops.transform.resize(value=(args.scale,args.scale,args.scale)) 213 | bpy.ops.object.transform_apply(scale=True) 214 | #if args.remove_doubles: 215 | # bpy.ops.object.mode_set(mode='EDIT') 216 | # bpy.ops.mesh.remove_doubles() 217 | # bpy.ops.object.mode_set(mode='OBJECT') 218 | #if args.edge_split: 219 | # bpy.ops.object.modifier_add(type='EDGE_SPLIT') 220 | # context.object.modifiers["EdgeSplit"].split_angle = 1.32645 221 | # bpy.ops.object.modifier_apply(modifier="EdgeSplit") 222 | 223 | # Set objekt IDs 224 | obj.pass_index = 1 225 | 226 | #Make light just directional, disable shadows. 227 | light = bpy.data.lights['Light'] 228 | light.type = 'SUN' 229 | light.use_shadow = True 230 | #Possibly disable specular shading: 231 | light.specular_factor = 1.0 232 | light.energy = 0.0 233 | 234 | # create light datablock, set attributes 235 | light_data = bpy.data.lights.new(name="light_2.80", type='POINT') 236 | light_data.energy = 200 237 | light_data.specular_factor = 0.4 238 | light_data.use_shadow = True 239 | #light_data.color = (1.0,0,0) 240 | 241 | # create new object with our light datablock 242 | light_object = bpy.data.objects.new(name="light_2.80", object_data=light_data) 243 | 244 | # link light object 245 | bpy.context.collection.objects.link(light_object) 246 | 247 | # make it active 248 | bpy.context.view_layer.objects.active = light_object 249 | 250 | #change location 251 | light_object.location = (4, 1, 1) 252 | 253 | # Add another light source so stuff facing away from light is not completely dark 254 | #bpy.ops.object.light_add(type='SUN') 255 | #light2 = bpy.data.lights['Sun'] 256 | #light2.use_shadow = True 257 | #light2.specular_factor = 1.0 258 | #light2.energy = 0.045 259 | #bpy.data.objects['Sun'].rotation_euler = bpy.data.objects['Light'].rotation_euler 260 | #bpy.data.objects['Sun'].rotation_euler[0] += 180 261 | 262 | # Place camera 263 | cam = scene.objects['Camera'] 264 | cam.location = (0, 3.2, 0) 265 | cam.data.angle_x = 0.6911112070083618 266 | #cam.data.sensor_width = 32 267 | 268 | cam_constraint = cam.constraints.new(type='TRACK_TO') 269 | cam_constraint.track_axis = 'TRACK_NEGATIVE_Z' 270 | cam_constraint.up_axis = 'UP_Y' 271 | 272 | cam_empty = bpy.data.objects.new("Empty", None) 273 | cam_empty.location = (0, 0, 0) 274 | cam.parent = cam_empty 275 | 276 | scene.collection.objects.link(cam_empty) 277 | context.view_layer.objects.active = cam_empty 278 | cam_constraint.target = cam_empty 279 | 280 | stepsize = 360.0 / args.views 281 | rotation_mode = 'XYZ' 282 | 283 | model_identifier = args.name 284 | fp = os.path.join(os.path.abspath(args.output_folder), model_identifier) 285 | 286 | 287 | for i in range(0, args.views): 288 | 289 | rot = np.random.uniform(0.001, 1, size=3) * (1,0,2*np.pi) # ( gora-dol, , z prawo-lewo) 290 | rot[0] = np.abs(np.arccos(1 - 2 * rot[0]) - np.pi/2) 291 | cam_empty.rotation_euler = rot 292 | 293 | render_file_path = fp + f'/image_{i}' 294 | 295 | scene.render.filepath = render_file_path 296 | 297 | bpy.ops.render.render(write_still=True) # render still 298 | 299 | with open(fp+f"/image_{i}.json", "w") as file: 300 | 
json.dump(listify_matrix(cam.matrix_world), file, indent=4) 301 | 302 | -------------------------------------------------------------------------------- /dataset_generation_scripts/render180.py: -------------------------------------------------------------------------------- 1 | # A simple script that uses blender to render views of a single object by rotation the camera around it. 2 | # Also produces depth map at the same time. 3 | # 4 | # Tested with Blender 2.9 5 | # 6 | # Example: 7 | # blender --background --python mytest.py -- --views 10 /path/to/my.obj 8 | # 9 | 10 | import argparse, sys, os, math, re 11 | import bpy 12 | from glob import glob 13 | import numpy as np 14 | 15 | import json 16 | 17 | import random 18 | 19 | def listify_matrix(matrix): 20 | matrix_list = [] 21 | for row in matrix: 22 | matrix_list.append(list(row)) 23 | return matrix_list 24 | 25 | parser = argparse.ArgumentParser(description='Renders given obj file by rotation a camera around it.') 26 | parser.add_argument('--views', type=int, default=100, 27 | help='number of views to be rendered') 28 | parser.add_argument('obj', type=str, 29 | help='Path to the obj file to be rendered.') 30 | parser.add_argument('--output_folder', type=str, default='./tmp', 31 | help='The path the output will be dumped to.') 32 | parser.add_argument('--scale', type=float, default=2.8, 33 | help='Scaling factor applied to model. Depends on size of mesh.') 34 | parser.add_argument('--remove_doubles', type=bool, default=True, 35 | help='Remove double vertices to improve mesh quality.') 36 | parser.add_argument('--edge_split', type=bool, default=True, 37 | help='Adds edge split filter.') 38 | parser.add_argument('--depth_scale', type=float, default=1.4, 39 | help='Scaling that is applied to depth. Depends on size of mesh. Try out various values until you get a good result. Ignored if format is OPEN_EXR.') 40 | parser.add_argument('--color_depth', type=str, default='8', 41 | help='Number of bit per channel used for output. Either 8 or 16.') 42 | parser.add_argument('--format', type=str, default='PNG', 43 | help='Format of files generated. Either PNG or OPEN_EXR') 44 | parser.add_argument('--resolution', type=int, default=400, #800! 45 | help='Resolution of the images.') 46 | parser.add_argument('--engine', type=str, default='BLENDER_EEVEE', 47 | help='Blender internal engine for rendering. E.g. CYCLES, BLENDER_EEVEE, ...') 48 | parser.add_argument('--name', type=str, default='model', 49 | help='Model id') 50 | 51 | argv = sys.argv[sys.argv.index("--") + 1:] 52 | args = parser.parse_args(argv) 53 | 54 | # Set up rendering 55 | context = bpy.context 56 | scene = bpy.context.scene 57 | render = bpy.context.scene.render 58 | 59 | render.engine = args.engine 60 | render.image_settings.color_mode = 'RGBA' # ('RGB', 'RGBA', ...) 61 | render.image_settings.color_depth = args.color_depth # ('8', '16') 62 | render.image_settings.file_format = args.format # ('PNG', 'OPEN_EXR', 'JPEG, ...) 
63 | render.resolution_x = args.resolution 64 | render.resolution_y = args.resolution 65 | render.resolution_percentage = 100 66 | render.film_transparent = True 67 | 68 | #set up nodes to change background color 69 | 70 | bpy.context.scene.use_nodes = True 71 | tree = bpy.context.scene.node_tree 72 | composite = tree.nodes[0] 73 | render_layers = tree.nodes[1] 74 | alpha_over = tree.nodes.new(type='CompositorNodeAlphaOver') 75 | links = tree.links 76 | link_1 = links.new(render_layers.outputs[0], alpha_over.inputs[2]) 77 | link_2 = links.new(alpha_over.outputs[0], composite.inputs[0]) 78 | alpha_over.inputs[1].default_value = (1, 1, 1, 0) 79 | 80 | 81 | scene.use_nodes = True 82 | scene.view_layers["View Layer"].use_pass_normal = True 83 | scene.view_layers["View Layer"].use_pass_diffuse_color = True 84 | scene.view_layers["View Layer"].use_pass_object_index = True 85 | 86 | nodes = bpy.context.scene.node_tree.nodes 87 | links = bpy.context.scene.node_tree.links 88 | 89 | bpy.context.scene.render.use_persistent_data = True 90 | 91 | # Clear default nodes 92 | for n in nodes: 93 | nodes.remove(n) 94 | 95 | # Create input render layer node 96 | render_layers = nodes.new('CompositorNodeRLayers') 97 | 98 | # Create depth output nodes 99 | depth_file_output = nodes.new(type="CompositorNodeOutputFile") 100 | depth_file_output.label = 'Depth Output' 101 | depth_file_output.base_path = '' 102 | depth_file_output.file_slots[0].use_node_format = True 103 | depth_file_output.format.file_format = args.format 104 | depth_file_output.format.color_depth = args.color_depth 105 | if args.format == 'OPEN_EXR': 106 | links.new(render_layers.outputs['Depth'], depth_file_output.inputs[0]) 107 | else: 108 | depth_file_output.format.color_mode = "BW" 109 | 110 | # Remap as other types can not represent the full range of depth. 111 | map = nodes.new(type="CompositorNodeMapValue") 112 | # Size is chosen kind of arbitrarily, try out until you're satisfied with resulting depth map. 
113 | map.offset = [-0.7] 114 | map.size = [args.depth_scale] 115 | map.use_min = True 116 | map.min = [0] 117 | 118 | links.new(render_layers.outputs['Depth'], map.inputs[0]) 119 | links.new(map.outputs[0], depth_file_output.inputs[0]) 120 | 121 | # Create normal output nodes 122 | scale_node = nodes.new(type="CompositorNodeMixRGB") 123 | scale_node.blend_type = 'MULTIPLY' 124 | # scale_node.use_alpha = True 125 | scale_node.inputs[2].default_value = (0.5, 0.5, 0.5, 1) 126 | links.new(render_layers.outputs['Normal'], scale_node.inputs[1]) 127 | 128 | bias_node = nodes.new(type="CompositorNodeMixRGB") 129 | bias_node.blend_type = 'ADD' 130 | # bias_node.use_alpha = True 131 | bias_node.inputs[2].default_value = (0.5, 0.5, 0.5, 0) 132 | links.new(scale_node.outputs[0], bias_node.inputs[1]) 133 | 134 | normal_file_output = nodes.new(type="CompositorNodeOutputFile") 135 | normal_file_output.label = 'Normal Output' 136 | normal_file_output.base_path = '' 137 | normal_file_output.file_slots[0].use_node_format = True 138 | normal_file_output.format.file_format = args.format 139 | links.new(bias_node.outputs[0], normal_file_output.inputs[0]) 140 | 141 | # Create albedo output nodes 142 | alpha_albedo = nodes.new(type="CompositorNodeSetAlpha") 143 | links.new(render_layers.outputs['DiffCol'], alpha_albedo.inputs['Image']) 144 | links.new(render_layers.outputs['Alpha'], alpha_albedo.inputs['Alpha']) 145 | 146 | albedo_file_output = nodes.new(type="CompositorNodeOutputFile") 147 | albedo_file_output.label = 'Albedo Output' 148 | albedo_file_output.base_path = '' 149 | albedo_file_output.file_slots[0].use_node_format = True 150 | albedo_file_output.format.file_format = args.format 151 | albedo_file_output.format.color_mode = 'RGBA' 152 | albedo_file_output.format.color_depth = args.color_depth 153 | links.new(alpha_albedo.outputs['Image'], albedo_file_output.inputs[0]) 154 | 155 | # Create id map output nodes 156 | id_file_output = nodes.new(type="CompositorNodeOutputFile") 157 | id_file_output.label = 'ID Output' 158 | id_file_output.base_path = '' 159 | id_file_output.file_slots[0].use_node_format = True 160 | id_file_output.format.file_format = args.format 161 | id_file_output.format.color_depth = args.color_depth 162 | 163 | if args.format == 'OPEN_EXR': 164 | links.new(render_layers.outputs['IndexOB'], id_file_output.inputs[0]) 165 | else: 166 | id_file_output.format.color_mode = 'BW' 167 | 168 | divide_node = nodes.new(type='CompositorNodeMath') 169 | divide_node.operation = 'DIVIDE' 170 | divide_node.use_clamp = False 171 | divide_node.inputs[1].default_value = 2**int(args.color_depth) 172 | 173 | links.new(render_layers.outputs['IndexOB'], divide_node.inputs[0]) 174 | links.new(divide_node.outputs[0], id_file_output.inputs[0]) 175 | 176 | # Delete default cube 177 | context.active_object.select_set(True) 178 | bpy.ops.object.delete() 179 | 180 | # Import textured mesh 181 | bpy.ops.object.select_all(action='DESELECT') 182 | 183 | bpy.ops.import_scene.obj(filepath=args.obj) 184 | 185 | 186 | 187 | obj = bpy.context.selected_objects[0] 188 | context.view_layer.objects.active = obj 189 | 190 | # create material 191 | #mat = bpy.data.materials.new(name='Material') 192 | 193 | #obj.data.materials.append(mat) 194 | #mat.use_nodes=True 195 | 196 | # let's create a variable to store our list of nodes 197 | #mat_nodes = mat.node_tree.nodes 198 | 199 | # let's set the metallic to 1.0 200 | #mat_nodes['Principled BSDF'].inputs['Metallic'].default_value=1.0 201 | #mat_nodes['Principled 
BSDF'].inputs['Roughness'].default_value=0.0 202 | 203 | # Possibly disable specular shading 204 | for slot in obj.material_slots: 205 | node = slot.material.node_tree.nodes['Principled BSDF'] 206 | node.inputs['Specular'].default_value = 0.3 207 | node.inputs['Metallic'].default_value=0.5 208 | node.inputs['Roughness'].default_value=0.25 209 | 210 | if args.scale != 1: 211 | bpy.ops.transform.resize(value=(args.scale,args.scale,args.scale)) 212 | bpy.ops.object.transform_apply(scale=True) 213 | if args.remove_doubles: 214 | bpy.ops.object.mode_set(mode='EDIT') 215 | bpy.ops.mesh.remove_doubles() 216 | bpy.ops.object.mode_set(mode='OBJECT') 217 | if args.edge_split: 218 | bpy.ops.object.modifier_add(type='EDGE_SPLIT') 219 | context.object.modifiers["EdgeSplit"].split_angle = 1.32645 220 | bpy.ops.object.modifier_apply(modifier="EdgeSplit") 221 | 222 | # Set objekt IDs 223 | obj.pass_index = 1 224 | 225 | #Make light just directional, disable shadows. 226 | light = bpy.data.lights['Light'] 227 | light.type = 'SUN' 228 | light.use_shadow = True 229 | #Possibly disable specular shading: 230 | light.specular_factor = 1.0 231 | light.energy = 0.0 232 | 233 | # create light datablock, set attributes 234 | light_data = bpy.data.lights.new(name="light_2.80", type='POINT') 235 | light_data.energy = 60 236 | light_data.specular_factor = 0.4 237 | light_data.use_shadow = True 238 | #light_data.color = (1.0,0,0) 239 | 240 | # create new object with our light datablock 241 | light_object = bpy.data.objects.new(name="light_2.80", object_data=light_data) 242 | 243 | # link light object 244 | bpy.context.collection.objects.link(light_object) 245 | 246 | # make it active 247 | bpy.context.view_layer.objects.active = light_object 248 | 249 | #change location 250 | light_object.location = (1.3, 1, 1) 251 | 252 | # Add another light source so stuff facing away from light is not completely dark 253 | #bpy.ops.object.light_add(type='SUN') 254 | #light2 = bpy.data.lights['Sun'] 255 | #light2.use_shadow = True 256 | #light2.specular_factor = 1.0 257 | #light2.energy = 0.045 258 | #bpy.data.objects['Sun'].rotation_euler = bpy.data.objects['Light'].rotation_euler 259 | #bpy.data.objects['Sun'].rotation_euler[0] += 180 260 | 261 | # Place camera 262 | cam = scene.objects['Camera'] 263 | cam.location = (0, 3.2, 0) 264 | cam.data.angle_x = 0.6911112070083618 265 | #cam.data.sensor_width = 32 266 | 267 | cam_constraint = cam.constraints.new(type='TRACK_TO') 268 | cam_constraint.track_axis = 'TRACK_NEGATIVE_Z' 269 | cam_constraint.up_axis = 'UP_Y' 270 | 271 | cam_empty = bpy.data.objects.new("Empty", None) 272 | cam_empty.location = (0, 0, 0) 273 | cam.parent = cam_empty 274 | 275 | scene.collection.objects.link(cam_empty) 276 | context.view_layer.objects.active = cam_empty 277 | cam_constraint.target = cam_empty 278 | 279 | stepsize = 360.0 / args.views 280 | rotation_mode = 'XYZ' 281 | 282 | model_identifier = args.name 283 | fp = os.path.join(os.path.abspath(args.output_folder), model_identifier) 284 | 285 | from math import radians 286 | 287 | for i in range(0, args.views): 288 | 289 | cam_empty.rotation_euler[0] = radians(45.0) 290 | cam_empty.rotation_euler[2] += radians(stepsize) 291 | render_file_path = fp + f'/image_{i}' 292 | 293 | scene.render.filepath = render_file_path 294 | 295 | bpy.ops.render.render(write_still=True) # render still 296 | 297 | with open(fp+f"/image_{i}.json", "w") as file: 298 | json.dump(listify_matrix(cam.matrix_world), file, indent=4) 299 | 300 | 
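Unlike `render.py`, which draws a random camera pose for every view, `render180.py` steps the camera evenly around the object at a fixed 45° elevation and imports the model with `bpy.ops.import_scene.obj` rather than the glTF importer. `generate.py` notes it can be swapped in for turntable-style renders ("render180.py for circle"). Below is a minimal sketch of such an invocation, assuming Blender is on the PATH; the model path and `<model_id>` are placeholders, and the argument-list form of `subprocess.run` is used instead of the single command string seen in `generate.py`:

```python
import subprocess

# Placeholder values standing in for generate.py's loop variables.
model_path = "./shapenet/02958343/<model_id>/models/model_normalized.obj"  # render180.py imports .obj
name = "<model_id>"
views, resolution = 50, 200

# Same call generate.py issues for render.py, with render180.py swapped in.
# An argument list (rather than one command string) keeps the call portable across OSes.
subprocess.run([
    "blender", "--background", "--python", "render180.py", "--",
    "--output_folder", "./tmp", model_path,
    "--name", name,
    "--views", str(views),
    "--resolution", str(resolution),
], check=True)
```

Each view's camera matrix is still written next to its PNG as `image_{i}.json`, so the packaging step in `generate.py` (stacking the images, `cam_poses`, and the sampled point cloud into one `.npz`) works unchanged.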
-------------------------------------------------------------------------------- /dataset_generation_scripts/requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster==0.7.12 2 | anyio==2.2.0 3 | argon2-cffi==20.1.0 4 | arrow==1.2.2 5 | async-generator==1.10 6 | attrs==20.3.0 7 | Babel==2.9.0 8 | backcall==0.2.0 9 | binaryornot==0.4.4 10 | bleach==3.3.0 11 | certifi==2020.12.5 12 | cffi==1.14.5 13 | chardet==4.0.0 14 | click==8.0.3 15 | colorama==0.4.4 16 | cookiecutter==1.7.3 17 | decorator==4.4.2 18 | defusedxml==0.7.1 19 | deprecation==2.1.0 20 | docutils==0.16 21 | docx2txt==0.8 22 | entrypoints==0.3 23 | idna==2.10 24 | imagesize==1.2.0 25 | ipykernel==5.5.0 26 | ipython==7.21.0 27 | ipython-genutils==0.2.0 28 | ipywidgets==7.6.5 29 | jedi==0.18.0 30 | Jinja2==2.11.3 31 | jinja2-time==0.2.0 32 | joblib==1.1.0 33 | json5==0.9.5 34 | jsonschema==3.2.0 35 | jupyter-client==6.1.12 36 | jupyter-core==4.7.1 37 | jupyter-packaging==0.11.1 38 | jupyter-server==1.5.1 39 | jupyterlab==3.0.12 40 | jupyterlab-pygments==0.1.2 41 | jupyterlab-server==2.3.0 42 | jupyterlab-widgets==1.0.2 43 | MarkupSafe==1.1.1 44 | mistune==0.8.4 45 | nbclassic==0.2.6 46 | nbclient==0.5.3 47 | nbconvert==6.0.7 48 | nbformat==5.1.2 49 | nest-asyncio==1.5.1 50 | nltk==3.6.7 51 | notebook==6.3.0 52 | npzviewer==0.2.0 53 | numpy==1.20.1 54 | open3d==0.14.1 55 | packaging==20.9 56 | pandas==1.2.3 57 | pandocfilters==1.4.3 58 | parso==0.8.1 59 | pickleshare==0.7.5 60 | Pillow==9.0.1 61 | pip-chill==1.0.1 62 | poyo==0.5.0 63 | prometheus-client==0.9.0 64 | prompt-toolkit==3.0.18 65 | pycparser==2.20 66 | Pygments==2.7.4 67 | pymeshlab==2021.10 68 | pyntcloud==0.1.6 69 | pyparsing==2.4.7 70 | PyQt5==5.15.6 71 | PyQt5-Qt5==5.15.2 72 | PyQt5-sip==12.9.0 73 | pyrsistent==0.17.3 74 | python-dateutil==2.8.1 75 | python-slugify==6.1.1 76 | pytz==2021.1 77 | pywin32==300 78 | pywinpty==0.5.7 79 | pyzmq==22.0.3 80 | regex==2021.11.10 81 | requests==2.25.1 82 | scikit-learn==1.1.0 83 | scipy==1.8.0 84 | Send2Trash==1.5.0 85 | six==1.15.0 86 | sklearn==0.0 87 | sniffio==1.2.0 88 | snowballstemmer==2.1.0 89 | Sphinx==3.4.3 90 | sphinxcontrib-applehelp==1.0.2 91 | sphinxcontrib-devhelp==1.0.2 92 | sphinxcontrib-htmlhelp==1.0.3 93 | sphinxcontrib-jsmath==1.0.1 94 | sphinxcontrib-qthelp==1.0.3 95 | sphinxcontrib-serializinghtml==1.1.4 96 | terminado==0.9.3 97 | testpath==0.4.4 98 | text-unidecode==1.3 99 | threadpoolctl==3.1.0 100 | tomlkit==0.9.2 101 | tornado==6.1 102 | tqdm==4.62.3 103 | traitlets==5.0.5 104 | urllib3==1.26.3 105 | wcwidth==0.2.5 106 | webencodings==0.5.1 107 | widgetsnbextension==3.5.2 108 | zstandard==0.15.2 109 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from os.path import join, exists 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from datetime import datetime 9 | from nerf_helpers import * 10 | from itertools import chain 11 | from tqdm import tqdm 12 | 13 | import matplotlib 14 | matplotlib.use('Agg') 15 | import matplotlib.pyplot as plt 16 | 17 | from utils import * 18 | 19 | import trimesh, mcubes 20 | 21 | from dataset.dataset import NeRFShapeNetDataset 22 | 23 | from models.encoder import Encoder 24 | from models.nerf import NeRF 25 | from models.resnet import resnet18 26 | from 
hypnettorch.hnets.chunked_mlp_hnet import ChunkedHMLP 27 | 28 | from ChamferDistancePytorch.fscore import fscore 29 | 30 | import open3d as o3d 31 | 32 | import argparse 33 | 34 | import ChamferDistancePytorch.chamfer_python as chfp 35 | 36 | #Needed for workers for dataloader 37 | from torch.multiprocessing import Pool, Process, set_start_method 38 | set_start_method('spawn', force=True) 39 | 40 | def rot_x(angle): 41 | rx = torch.Tensor([ [1,0,0], 42 | [0, math.cos(angle), -math.sin(angle)], 43 | [0, math.sin(angle), math.cos(angle)]]) 44 | return rx 45 | 46 | 47 | def as_mesh(scene_or_mesh): 48 | if isinstance(scene_or_mesh, trimesh.Scene): 49 | mesh = trimesh.util.concatenate([ 50 | trimesh.Trimesh(vertices=m.vertices, faces=m.faces) 51 | for m in scene_or_mesh.geometry.values()]) 52 | else: 53 | mesh = scene_or_mesh 54 | return mesh 55 | 56 | def calculate_best_mesh_metrics(obj_path, render_kwargs, save_pc=True, name="1", thresholds=[1,2,3]): 57 | 58 | fscores = [calculate_mesh_metrics(obj_path, render_kwargs, save_pc, name+f'_{t}', t) for t in thresholds] 59 | return max(fscores, key=lambda x: x[0].item()) 60 | 61 | def calculate_mesh_metrics(obj_path, render_kwargs, save_pc=True, name="1", threshold = 3): 62 | 63 | with torch.no_grad(): 64 | N = 128 65 | t = torch.linspace(-1.1, 1.1, N+1) 66 | 67 | query_pts = torch.stack(torch.meshgrid(t, t, t), -1) 68 | sh = query_pts.shape 69 | flat = query_pts.reshape([-1,3]) 70 | 71 | def batchify(fn, chunk): 72 | if chunk is None: 73 | return fn 74 | def ret(inputs): 75 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 76 | return ret 77 | 78 | fn = lambda i0, i1 : render_kwargs['network_query_fn'](flat[i0:i1,None,:], viewdirs=None, network_fn=render_kwargs['network_fn']) 79 | chunk = 1024*16 80 | raw = torch.cat([fn(i, i+chunk) for i in range(0, flat.shape[0], chunk)], 0) 81 | raw = torch.reshape(raw, list(sh[:-1]) + [-1]) 82 | sigma = torch.maximum(raw[...,-1], torch.Tensor([0.])) 83 | 84 | 85 | vertices, triangles = mcubes.marching_cubes(sigma.cpu().numpy(), threshold) 86 | mesh = trimesh.Trimesh(vertices / N - .5, triangles) 87 | 88 | try: 89 | entry_mesh = trimesh.load_mesh(obj_path, force='mesh') 90 | entry_mesh = as_mesh(entry_mesh) 91 | 92 | entry_points = trimesh.sample.sample_surface(entry_mesh, 3000) 93 | entry_points = torch.from_numpy(entry_points[0]).to(device, dtype=torch.float) 94 | 95 | entry_points = rot_x(math.pi/2)@entry_points.T 96 | entry_points = entry_points.T 97 | entry_points = entry_points[None] 98 | 99 | sampled_points = trimesh.sample.sample_surface(mesh, 3000) 100 | sampled_points = torch.from_numpy(sampled_points[0])[None].to(device, dtype=torch.float) 101 | 102 | dist1, dist2, idx1, idx2 = chfp.distChamfer(entry_points, sampled_points) 103 | cd = (torch.mean(dist1)) + (torch.mean(dist2)) 104 | f_score, precision, recall = fscore(dist1, dist2, 0.01) 105 | except Exception as e: 106 | print(e) 107 | f_score = torch.Tensor([0.0]) 108 | cd = torch.Tensor([0.0]) 109 | 110 | iou=0 111 | if save_pc: 112 | try: 113 | pcd = o3d.geometry.PointCloud() 114 | pcd.points = o3d.utility.Vector3dVector(sampled_points.detach().cpu().numpy()[0]) 115 | o3d.io.write_point_cloud(f"./results/pcs/{name}_sampled_points.ply", pcd) 116 | 117 | pcd = o3d.geometry.PointCloud() 118 | #print(entry_points.detach().cpu().numpy()) 119 | #print(entry_points.detach().cpu().numpy().shape) 120 | pcd.points = o3d.utility.Vector3dVector(entry_points.detach().cpu().numpy()[0]) 121 | 
o3d.io.write_point_cloud(f"./results/pcs/{name}_entry_points.ply", pcd) 122 | except Exception as e: 123 | print(e) 124 | print("something went wrong with saving point cloud!") 125 | raise e 126 | return f_score, cd #remember the threshold! 127 | 128 | 129 | def calculate_image_metrics(entry, render_kwargs, metric_fn, count=5): 130 | x = [] 131 | y = [] 132 | with torch.no_grad(): 133 | for c in range(count): 134 | img_i = np.random.choice(len(entry['images'][j]), 1) 135 | target = entry['images'][j][img_i][0].to(device)#entry['images'][j][img_i][0].to(device) 136 | target = torch.Tensor(target.float()) 137 | pose = entry['cam_poses'][j][img_i, :3,:4][0].to(device) 138 | 139 | H = entry["images"][j].shape[1] 140 | W = entry["images"][j].shape[2] 141 | focal = .5 * W / np.tan(.5 * 0.6911112070083618) 142 | 143 | K = np.array([ 144 | [focal, 0, 0.5*W], 145 | [0, focal, 0.5*H], 146 | [0, 0, 1] 147 | ]) 148 | 149 | img, _, _, _ = render(H, W, K, chunk=config['model']['TN']['netchunk'], c2w=pose, 150 | verbose=True, retraw=True, 151 | **render_kwargs) 152 | 153 | x.append(img) 154 | y.append(target) 155 | 156 | x = torch.stack(x) 157 | y = torch.stack(y) 158 | 159 | metric_val = metric_fn(y, x) 160 | 161 | return metric_val 162 | 163 | if __name__ == '__main__': 164 | pd.set_option('display.max_columns', None) 165 | pd.set_option('display.max_rows', None) 166 | 167 | dirname = os.path.dirname(__file__) 168 | 169 | parser = argparse.ArgumentParser(description='Start training HyperRF') 170 | parser.add_argument('config_path', type=str, 171 | help='Relative config path') 172 | 173 | args = parser.parse_args() 174 | 175 | config = None 176 | with open(args.config_path) as f: 177 | config = json.load(f) 178 | assert config is not None 179 | 180 | print(config) 181 | 182 | set_seed(config['seed']) 183 | 184 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 185 | 186 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 187 | 188 | #config['classes'] = ['cars'] 189 | 190 | dataset = NeRFShapeNetDataset(root_dir=config['data_dir'], classes=config['classes'], train=False) 191 | 192 | config['batch_size'] = 1 193 | 194 | dataloader = DataLoader(dataset, batch_size=config['batch_size'], 195 | shuffle=config['shuffle'], 196 | num_workers=2, drop_last=True, 197 | pin_memory=True, generator=torch.Generator(device='cuda')) 198 | 199 | embed_fn, config['model']['TN']['input_ch_embed'] = get_embedder(config['model']['TN']['multires'], config['model']['TN']['i_embed']) 200 | 201 | embeddirs_fn = None 202 | config['model']['TN']['input_ch_views_embed'] = 0 203 | if config['model']['TN']['use_viewdirs']: 204 | embeddirs_fn, config['model']['TN']['input_ch_views_embed']= get_embedder(config['model']['TN']['multires_views'], config['model']['TN']['i_embed']) 205 | 206 | # Create a NeRF network 207 | nerf = NeRF(config['model']['TN']['D'],config['model']['TN']['W'], 208 | config['model']['TN']['input_ch_embed'], 209 | config['model']['TN']['input_ch_views_embed'], 210 | config['model']['TN']['use_viewdirs']).to(device) 211 | 212 | #Hypernetwork 213 | hnet = ChunkedHMLP(nerf.param_shapes, uncond_in_size=config['z_size'], cond_in_size=0, 214 | layers=config['model']['HN']['arch'], chunk_size=config['model']['HN']['chunk_size'], cond_chunk_embs=False, use_bias=config['model']['HN']['use_bias']).to(device) 215 | 216 | #Create encoder: either Resnet or classic 217 | if config['resnet']==True: 218 | encoder = resnet18(num_classes=config['z_size']).to(device) 219 | else: 220 | encoder = 
Encoder(config).to(device) 221 | 222 | results_dir = config['results_dir'] 223 | os.makedirs(join(dirname,results_dir), exist_ok=True) 224 | 225 | with open(join(results_dir, "config_eval.json"), "w") as file: 226 | json.dump(config, file, indent=4) 227 | 228 | try: 229 | losses_r = np.load(join(results_dir, f'losses_r.npy')).tolist() 230 | print("Loaded reconstruction losses") 231 | losses_kld = np.load(join(results_dir, f'losses_kld.npy')).tolist() 232 | print("Loaded KLD losses") 233 | losses_total = np.load(join(results_dir, f'losses_total.npy')).tolist() 234 | print("Loaded total losses") 235 | except: 236 | print("Haven't found previous losses. Is this a new experiment?") 237 | losses_r = [] 238 | losses_kld = [] 239 | losses_total = [] 240 | 241 | if losses_total == []: 242 | print("Loading \'latest\' model without loaded losses") 243 | try: 244 | hnet.load_state_dict(torch.load(join(results_dir, f"model_hn_latest.pt"))) 245 | print("Loaded HNet") 246 | encoder.load_state_dict(torch.load(join(results_dir, f"model_e_latest.pt"))) 247 | print("Loaded Encoder") 248 | scheduler.load_state_dict(torch.load(join(results_dir, f"lr_latest.pt"))) 249 | print("Loaded Scheduler") 250 | except: 251 | print("Haven't loaded all previous models.") 252 | else: 253 | starting_epoch = len(losses_total) 254 | 255 | print("starting epoch:", starting_epoch) 256 | 257 | if(starting_epoch>0): 258 | print("Loading weights since previous losses were found") 259 | try: 260 | hnet.load_state_dict(torch.load(join(results_dir, f"model_hn_{starting_epoch-1}.pt"))) 261 | print("Loaded HNet") 262 | encoder.load_state_dict(torch.load(join(results_dir, f"model_e_{starting_epoch-1}.pt"))) 263 | print("Loaded Encoder") 264 | scheduler.load_state_dict(torch.load(join(results_dir, f"lr_{starting_epoch-1}.pt"))) 265 | print("Loaded Scheduler") 266 | except: 267 | print("Haven't loaded all previous models.") 268 | 269 | results_dir = join(results_dir, 'eval') 270 | os.makedirs(results_dir, exist_ok=True) 271 | 272 | encoder.eval() 273 | hnet.eval() 274 | 275 | mse = torch.nn.MSELoss() 276 | psnr_metric = lambda x,y: torch.mean(mse2psnr(mse(y,x))) 277 | 278 | eval_results = pd.DataFrame(columns=['class', 'fscore', 'cd', 'psnr']) 279 | 280 | for i, (entry, cat, obj_path) in enumerate(dataloader): 281 | start_time = datetime.now() 282 | 283 | if config['resnet']: 284 | nerf_Ws = get_nerf_resnet(entry, encoder, hnet) 285 | else: 286 | nerf_Ws, mu, logvar = get_nerf(entry, encoder, hnet) 287 | 288 | #For batch size == 1 hnet doesn't return batch dimension... 289 | if config['batch_size'] == 1: 290 | nerf_Ws = [nerf_Ws] 291 | 292 | for j, target_w in enumerate(nerf_Ws): 293 | render_kwargs = get_render_kwargs(config, nerf, target_w, embed_fn, embeddirs_fn) 294 | render_kwargs['perturb'] = False 295 | render_kwargs['raw_noise_std'] = 0. 
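        # perturb and raw_noise_std are switched off so evaluation renders are
        # deterministic (no stratified-sampling jitter, no density noise). Below,
        # F-score and Chamfer distance are computed between points sampled from a
        # marching-cubes mesh of the predicted density and points sampled from the
        # ground-truth .obj, and PSNR is computed from 5 randomly selected views.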
296 | 297 | points = entry["data"][j] 298 | points = points.to(device, dtype=torch.float) 299 | 300 | #render_kwargs = get_render_kwargs(config, nerf, target_w, embed_fn, embeddirs_fn) 301 | 302 | #f_score,cd = calculate_metrics(obj_path[0], render_kwargs,save_pc=i<=2, name=str(i)) 303 | f_score, cd = calculate_best_mesh_metrics(obj_path[j], render_kwargs, save_pc=False, name=str(i), thresholds=[3]) 304 | #f_score = torch.Tensor([0.]) 305 | #cd = torch.Tensor([0.]) 306 | psnr = calculate_image_metrics(entry, render_kwargs, psnr_metric, count=5) 307 | 308 | #print(i,' fscore-> ', f_score) 309 | #print(i, 'cd-> ', cd) 310 | #print(i,' psnr-> ', psnr) 311 | eval_results = eval_results.append({'class': cat[j], 'fscore': f_score.item(),'cd': cd.item(), 'psnr': psnr.item()}, ignore_index=True) 312 | #print("Time:", round((datetime.now() - start_time).total_seconds(), 2)) 313 | print(eval_results.groupby("class").describe()) 314 | print("---------------") 315 | print(eval_results[['fscore', 'cd', 'psnr']].describe()) 316 | -------------------------------------------------------------------------------- /examples/car.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/points2nerf/798c6c4b4e309a93c5cf1051e942c8e2629e4076/examples/car.gif -------------------------------------------------------------------------------- /examples/chair.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/points2nerf/798c6c4b4e309a93c5cf1051e942c8e2629e4076/examples/chair.gif -------------------------------------------------------------------------------- /examples/interpolation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/points2nerf/798c6c4b4e309a93c5cf1051e942c8e2629e4076/examples/interpolation.gif -------------------------------------------------------------------------------- /examples/plane.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gmum/points2nerf/798c6c4b4e309a93c5cf1051e942c8e2629e4076/examples/plane.gif -------------------------------------------------------------------------------- /models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Encoder(nn.Module): 5 | def __init__(self, config): 6 | super().__init__() 7 | 8 | self.z_size = config['z_size'] 9 | self.use_bias = config['model']['E']['use_bias'] 10 | self.relu_slope = config['model']['E']['relu_slope'] 11 | 12 | self.conv = nn.Sequential( 13 | nn.Conv1d(in_channels=6, out_channels=64, kernel_size=1, bias=self.use_bias), 14 | nn.ReLU(inplace=True), 15 | 16 | nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1, bias=self.use_bias), 17 | nn.ReLU(inplace=True), 18 | 19 | nn.Conv1d(in_channels=128, out_channels=256, kernel_size=1, bias=self.use_bias), 20 | nn.ReLU(inplace=True), 21 | 22 | nn.Conv1d(in_channels=256, out_channels=512, kernel_size=1, bias=self.use_bias), 23 | nn.ReLU(inplace=True), 24 | 25 | nn.Conv1d(in_channels=512, out_channels=512, kernel_size=1, bias=self.use_bias), 26 | ) 27 | 28 | self.fc = nn.Sequential( 29 | nn.Linear(512, 512, bias=True), 30 | nn.ReLU(inplace=True) 31 | ) 32 | 33 | self.mu_layer = nn.Linear(512, self.z_size, bias=True) 34 | self.std_layer = nn.Linear(512, self.z_size, bias=True) 35 | 36 | def reparameterize(self, 
mu, logvar): 37 | std = torch.exp(logvar) 38 | eps = torch.randn_like(std) 39 | return eps.mul(std).add_(mu) 40 | 41 | def forward(self, x): 42 | output = self.conv(x) 43 | output2 = output.max(dim=2)[0] 44 | logit = self.fc(output2) 45 | mu = self.mu_layer(logit) 46 | logvar = self.std_layer(logit) 47 | z = self.reparameterize(mu, logvar) 48 | return z, mu, torch.exp(logvar) 49 | 50 | def freeze(self): 51 | for p in self.parameters(): 52 | p.requires_grad = False -------------------------------------------------------------------------------- /models/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from itertools import chain 3 | from hypnettorch.mnets import MLP 4 | from hypnettorch.hnets import ChunkedHMLP, HMLP 5 | 6 | class NeRF(torch.nn.Module): 7 | 8 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, use_viewdirs=True): 9 | super(NeRF, self).__init__() 10 | self.D = D 11 | self.W = W 12 | self.input_ch = input_ch 13 | self.input_ch_views = input_ch_views 14 | self.use_viewdirs = use_viewdirs 15 | 16 | self.before_skip = MLP(n_in=self.input_ch, 17 | n_out=self.W, hidden_layers=[self.W]*(self.D//2), 18 | no_weights=False, out_fn=torch.nn.ReLU()) 19 | 20 | self.after_skip = MLP(n_in=self.input_ch + self.W, 21 | n_out=self.W, hidden_layers=[self.W]*(self.D//2), 22 | no_weights=False, out_fn=torch.nn.ReLU()) 23 | 24 | #assume we use viewdirs 25 | if use_viewdirs: 26 | self.out_sigma = MLP(n_in=self.W, 27 | n_out=1, hidden_layers=[], 28 | no_weights=False) 29 | self.out_feature = MLP(n_in=self.W, 30 | n_out=self.W, hidden_layers=[], 31 | no_weights=False, activation_fn=None) 32 | self.out_rgb = MLP(n_in=self.W + self.input_ch_views, 33 | n_out=3, hidden_layers=[self.W//2], 34 | no_weights=False) 35 | 36 | self.internal_params = chain(self.before_skip.internal_params, self.after_skip.internal_params, self.out_sigma.internal_params, self.out_feature.internal_params, self.out_rgb.internal_params) 37 | else: 38 | self.output_linear = MLP(n_in=self.W, 39 | n_out=4, hidden_layers=[], 40 | no_weights=False, activation_fn=None) 41 | 42 | self.internal_params = chain(self.before_skip.internal_params, self.after_skip.internal_params, self.output_linear.internal_params) 43 | 44 | 45 | 46 | def forward(self, x, weights=None): 47 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 48 | h = input_pts 49 | 50 | if self.use_viewdirs: 51 | bs_l, as_l, os_l, of_l, orgb_l = self.unpack_weights(weights) 52 | else: 53 | bs_l, as_l, ol_l = self.unpack_weights(weights) 54 | 55 | h, _=self.before_skip(h, weights=bs_l) 56 | h = torch.cat([h, input_pts], -1) 57 | h, _=self.after_skip(h,weights=as_l) 58 | 59 | if self.use_viewdirs: 60 | sigma = self.out_sigma(h,weights=os_l) 61 | feature = self.out_feature(h,weights=of_l) 62 | h = torch.cat([feature, input_views], -1) 63 | rgb = self.out_rgb(h, weights=orgb_l) 64 | return torch.cat([rgb, sigma], -1) 65 | else: 66 | return self.output_linear(h, weights=ol_l) 67 | 68 | 69 | @property 70 | def param_shapes(self) -> list: 71 | if self.use_viewdirs: 72 | return list(chain(self.before_skip.param_shapes, self.after_skip.param_shapes, self.out_sigma.param_shapes, 73 | self.out_feature.param_shapes, self.out_rgb.param_shapes)) 74 | else: 75 | return list(chain(self.before_skip.param_shapes, self.after_skip.param_shapes, self.output_linear.param_shapes)) 76 | 77 | def unpack_weights(self, weights) -> list: 78 | if(weights is None): 79 | print("No weights!") 80 | 
return None 81 | weights = weights.copy() 82 | bs_weights = [] 83 | for param in self.before_skip.param_shapes: 84 | bs_weights.append(weights.pop(0)) 85 | 86 | as_weights = [] 87 | for param in self.after_skip.param_shapes: 88 | as_weights.append(weights.pop(0)) 89 | 90 | if self.use_viewdirs: 91 | os_weights = [] 92 | for param in self.out_sigma.param_shapes: 93 | os_weights.append(weights.pop(0)) 94 | 95 | of_weights = [] 96 | for param in self.out_feature.param_shapes: 97 | of_weights.append(weights.pop(0)) 98 | 99 | orgb_weights = [] 100 | for param in self.out_rgb.param_shapes: 101 | orgb_weights.append(weights.pop(0)) 102 | 103 | assert len(weights)==0 104 | 105 | return bs_weights, as_weights, os_weights, of_weights, orgb_weights 106 | else: 107 | ol_weights = [] 108 | for param in self.output_linear.param_shapes: 109 | ol_weights.append(weights.pop(0)) 110 | 111 | assert len(weights)==0 112 | 113 | return bs_weights, as_weights, ol_weights -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Any, Callable, Union, List, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | #from .._internally_replaced_utils import load_state_dict_from_url 8 | #from ..utils import _log_api_usage_once 9 | 10 | 11 | __all__ = [ 12 | "ResNet", 13 | "resnet18", 14 | "resnet34", 15 | "resnet50", 16 | "resnet101", 17 | "resnet152", 18 | "resnext50_32x4d", 19 | "resnext101_32x8d", 20 | "wide_resnet50_2", 21 | "wide_resnet101_2", 22 | ] 23 | 24 | 25 | model_urls = { 26 | "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", 27 | "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", 28 | "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", 29 | "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", 30 | "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", 31 | "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", 32 | "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", 33 | "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", 34 | "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", 35 | } 36 | 37 | 38 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 39 | """3x3 convolution with padding""" 40 | return nn.Conv2d( 41 | in_planes, 42 | out_planes, 43 | kernel_size=3, 44 | stride=stride, 45 | padding=dilation, 46 | groups=groups, 47 | bias=False, 48 | dilation=dilation, 49 | ) 50 | 51 | 52 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 53 | """1x1 convolution""" 54 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 55 | 56 | 57 | class BasicBlock(nn.Module): 58 | expansion: int = 1 59 | 60 | def __init__( 61 | self, 62 | inplanes: int, 63 | planes: int, 64 | stride: int = 1, 65 | downsample: Optional[nn.Module] = None, 66 | groups: int = 1, 67 | base_width: int = 64, 68 | dilation: int = 1, 69 | norm_layer: Optional[Callable[..., nn.Module]] = None, 70 | ) -> None: 71 | super().__init__() 72 | if norm_layer is None: 73 | norm_layer = nn.BatchNorm2d 74 | if groups != 1 or base_width != 64: 75 | raise ValueError("BasicBlock only supports groups=1 and 
base_width=64") 76 | if dilation > 1: 77 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 78 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 79 | self.conv1 = conv3x3(inplanes, planes, stride) 80 | self.bn1 = norm_layer(planes) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.conv2 = conv3x3(planes, planes) 83 | self.bn2 = norm_layer(planes) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x: Tensor) -> Tensor: 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | 97 | if self.downsample is not None: 98 | identity = self.downsample(x) 99 | 100 | out += identity 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class Bottleneck(nn.Module): 107 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 108 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 109 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 110 | # This variant is also known as ResNet V1.5 and improves accuracy according to 111 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 112 | 113 | expansion: int = 4 114 | 115 | def __init__( 116 | self, 117 | inplanes: int, 118 | planes: int, 119 | stride: int = 1, 120 | downsample: Optional[nn.Module] = None, 121 | groups: int = 1, 122 | base_width: int = 64, 123 | dilation: int = 1, 124 | norm_layer: Optional[Callable[..., nn.Module]] = None, 125 | ) -> None: 126 | super().__init__() 127 | if norm_layer is None: 128 | norm_layer = nn.BatchNorm2d 129 | width = int(planes * (base_width / 64.0)) * groups 130 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 131 | self.conv1 = conv1x1(inplanes, width) 132 | self.bn1 = norm_layer(width) 133 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 134 | self.bn2 = norm_layer(width) 135 | self.conv3 = conv1x1(width, planes * self.expansion) 136 | self.bn3 = norm_layer(planes * self.expansion) 137 | self.relu = nn.ReLU(inplace=True) 138 | self.downsample = downsample 139 | self.stride = stride 140 | 141 | def forward(self, x: Tensor) -> Tensor: 142 | identity = x 143 | 144 | out = self.conv1(x) 145 | out = self.bn1(out) 146 | out = self.relu(out) 147 | 148 | out = self.conv2(out) 149 | out = self.bn2(out) 150 | out = self.relu(out) 151 | 152 | out = self.conv3(out) 153 | out = self.bn3(out) 154 | 155 | if self.downsample is not None: 156 | identity = self.downsample(x) 157 | 158 | out += identity 159 | out = self.relu(out) 160 | 161 | return out 162 | 163 | 164 | class ResNet(nn.Module): 165 | def __init__( 166 | self, 167 | block: Type[Union[BasicBlock, Bottleneck]], 168 | layers: List[int], 169 | num_classes: int = 1000, 170 | zero_init_residual: bool = False, 171 | groups: int = 1, 172 | width_per_group: int = 64, 173 | replace_stride_with_dilation: Optional[List[bool]] = None, 174 | norm_layer: Optional[Callable[..., nn.Module]] = None, 175 | ) -> None: 176 | super().__init__() 177 | #_log_api_usage_once(self) 178 | if norm_layer is None: 179 | norm_layer = nn.BatchNorm2d 180 | self._norm_layer = norm_layer 181 | 182 | self.inplanes = 64 183 | self.dilation = 1 184 | if replace_stride_with_dilation is None: 185 | # each element in the tuple indicates if we should replace 186 | # the 2x2 stride with a dilated 
convolution instead 187 | replace_stride_with_dilation = [False, False, False] 188 | if len(replace_stride_with_dilation) != 3: 189 | raise ValueError( 190 | "replace_stride_with_dilation should be None " 191 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 192 | ) 193 | self.groups = groups 194 | self.base_width = width_per_group 195 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 196 | self.bn1 = norm_layer(self.inplanes) 197 | self.relu = nn.ReLU(inplace=True) 198 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 199 | self.layer1 = self._make_layer(block, 64, layers[0]) 200 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 201 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 202 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 203 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 204 | self.fc = nn.Linear(512 * block.expansion, num_classes) 205 | 206 | #VAE 207 | self.mu_layer = nn.Linear(512*block.expansion, num_classes, bias=True) 208 | self.std_layer = nn.Linear(512*block.expansion, num_classes, bias=True) 209 | 210 | for m in self.modules(): 211 | if isinstance(m, nn.Conv2d): 212 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 213 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 214 | nn.init.constant_(m.weight, 1) 215 | nn.init.constant_(m.bias, 0) 216 | 217 | # Zero-initialize the last BN in each residual branch, 218 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 219 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 220 | if zero_init_residual: 221 | for m in self.modules(): 222 | if isinstance(m, Bottleneck): 223 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 224 | elif isinstance(m, BasicBlock): 225 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 226 | 227 | def _make_layer( 228 | self, 229 | block: Type[Union[BasicBlock, Bottleneck]], 230 | planes: int, 231 | blocks: int, 232 | stride: int = 1, 233 | dilate: bool = False, 234 | ) -> nn.Sequential: 235 | norm_layer = self._norm_layer 236 | downsample = None 237 | previous_dilation = self.dilation 238 | if dilate: 239 | self.dilation *= stride 240 | stride = 1 241 | if stride != 1 or self.inplanes != planes * block.expansion: 242 | downsample = nn.Sequential( 243 | conv1x1(self.inplanes, planes * block.expansion, stride), 244 | norm_layer(planes * block.expansion), 245 | ) 246 | 247 | layers = [] 248 | layers.append( 249 | block( 250 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 251 | ) 252 | ) 253 | self.inplanes = planes * block.expansion 254 | for _ in range(1, blocks): 255 | layers.append( 256 | block( 257 | self.inplanes, 258 | planes, 259 | groups=self.groups, 260 | base_width=self.base_width, 261 | dilation=self.dilation, 262 | norm_layer=norm_layer, 263 | ) 264 | ) 265 | 266 | return nn.Sequential(*layers) 267 | 268 | def reparameterize(self, mu, logvar): 269 | std = torch.exp(logvar) 270 | eps = torch.randn_like(std) 271 | return eps.mul(std).add_(mu) 272 | 273 | def _forward_impl(self, x: Tensor) -> Tensor: 274 | # See note [TorchScript super()] 275 | x = self.conv1(x) 276 | x = self.bn1(x) 277 | x = self.relu(x) 278 | x = self.maxpool(x) 279 | 280 | x = self.layer1(x) 281 | x = 
self.layer2(x) 282 | x = self.layer3(x) 283 | x = self.layer4(x) 284 | 285 | x = self.avgpool(x) 286 | x = torch.flatten(x, 1) 287 | 288 | mu = self.mu_layer(x) 289 | logvar = self.std_layer(x) 290 | z = self.reparameterize(mu, logvar) 291 | return z, mu, torch.exp(logvar) 292 | 293 | #x = self.fc(x) 294 | 295 | #return x 296 | 297 | def forward(self, x: Tensor) -> Tensor: 298 | return self._forward_impl(x) 299 | 300 | 301 | def _resnet( 302 | arch: str, 303 | block: Type[Union[BasicBlock, Bottleneck]], 304 | layers: List[int], 305 | pretrained: bool, 306 | progress: bool, 307 | **kwargs: Any, 308 | ) -> ResNet: 309 | model = ResNet(block, layers, **kwargs) 310 | if pretrained: 311 | #state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 312 | #model.load_state_dict(state_dict) 313 | pass 314 | return model 315 | 316 | 317 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 318 | r"""ResNet-18 model from 319 | `"Deep Residual Learning for Image Recognition" `_. 320 | 321 | Args: 322 | pretrained (bool): If True, returns a model pre-trained on ImageNet 323 | progress (bool): If True, displays a progress bar of the download to stderr 324 | """ 325 | return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) 326 | 327 | 328 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 329 | r"""ResNet-34 model from 330 | `"Deep Residual Learning for Image Recognition" `_. 331 | 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) 337 | 338 | 339 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 340 | r"""ResNet-50 model from 341 | `"Deep Residual Learning for Image Recognition" `_. 342 | 343 | Args: 344 | pretrained (bool): If True, returns a model pre-trained on ImageNet 345 | progress (bool): If True, displays a progress bar of the download to stderr 346 | """ 347 | return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 348 | 349 | 350 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 351 | r"""ResNet-101 model from 352 | `"Deep Residual Learning for Image Recognition" `_. 353 | 354 | Args: 355 | pretrained (bool): If True, returns a model pre-trained on ImageNet 356 | progress (bool): If True, displays a progress bar of the download to stderr 357 | """ 358 | return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 359 | 360 | 361 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 362 | r"""ResNet-152 model from 363 | `"Deep Residual Learning for Image Recognition" `_. 364 | 365 | Args: 366 | pretrained (bool): If True, returns a model pre-trained on ImageNet 367 | progress (bool): If True, displays a progress bar of the download to stderr 368 | """ 369 | return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) 370 | 371 | 372 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 373 | r"""ResNeXt-50 32x4d model from 374 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
375 | 376 | Args: 377 | pretrained (bool): If True, returns a model pre-trained on ImageNet 378 | progress (bool): If True, displays a progress bar of the download to stderr 379 | """ 380 | kwargs["groups"] = 32 381 | kwargs["width_per_group"] = 4 382 | return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 383 | 384 | 385 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 386 | r"""ResNeXt-101 32x8d model from 387 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 388 | 389 | Args: 390 | pretrained (bool): If True, returns a model pre-trained on ImageNet 391 | progress (bool): If True, displays a progress bar of the download to stderr 392 | """ 393 | kwargs["groups"] = 32 394 | kwargs["width_per_group"] = 8 395 | return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 396 | 397 | 398 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 399 | r"""Wide ResNet-50-2 model from 400 | `"Wide Residual Networks" `_. 401 | 402 | The model is the same as ResNet except for the bottleneck number of channels 403 | which is twice larger in every block. The number of channels in outer 1x1 404 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 405 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 406 | 407 | Args: 408 | pretrained (bool): If True, returns a model pre-trained on ImageNet 409 | progress (bool): If True, displays a progress bar of the download to stderr 410 | """ 411 | kwargs["width_per_group"] = 64 * 2 412 | return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 413 | 414 | 415 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 416 | r"""Wide ResNet-101-2 model from 417 | `"Wide Residual Networks" `_. 418 | 419 | The model is the same as ResNet except for the bottleneck number of channels 420 | which is twice larger in every block. The number of channels in outer 1x1 421 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 422 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 423 | 424 | Args: 425 | pretrained (bool): If True, returns a model pre-trained on ImageNet 426 | progress (bool): If True, displays a progress bar of the download to stderr 427 | """ 428 | kwargs["width_per_group"] = 64 * 2 429 | return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 430 | -------------------------------------------------------------------------------- /nerf_helpers.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import imageio 4 | import json 5 | import random 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from tqdm import tqdm, trange 11 | 12 | import math 13 | 14 | DEBUG = False 15 | 16 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 17 | mse2psnr = lambda x : -10. 
* torch.log(x) / torch.log(torch.Tensor([10.])) 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | trans_t = lambda t : torch.Tensor([ 22 | [1,0,0,0], 23 | [0,1,0,0], 24 | [0,0,1,t], 25 | [0,0,0,1]]).float() 26 | 27 | rot_phi = lambda phi : torch.Tensor([ 28 | [1,0,0,0], 29 | [0,np.cos(phi),-np.sin(phi),0], 30 | [0,np.sin(phi), np.cos(phi),0], 31 | [0,0,0,1]]).float() 32 | 33 | rot_theta = lambda th : torch.Tensor([ 34 | [np.cos(th),0,-np.sin(th),0], 35 | [0,1,0,0], 36 | [np.sin(th),0, np.cos(th),0], 37 | [0,0,0,1]]).float() 38 | 39 | 40 | def pose_spherical(theta, phi, radius): 41 | c2w = trans_t(radius) 42 | c2w = rot_phi(phi/180.*np.pi) @ c2w 43 | c2w = rot_theta(theta/180.*np.pi) @ c2w 44 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 45 | return c2w 46 | 47 | class Embedder: 48 | def __init__(self, **kwargs): 49 | self.kwargs = kwargs 50 | self.create_embedding_fn() 51 | 52 | def create_embedding_fn(self): 53 | embed_fns = [] 54 | d = self.kwargs['input_dims'] 55 | out_dim = 0 56 | if self.kwargs['include_input']: 57 | embed_fns.append(lambda x : x) 58 | out_dim += d 59 | 60 | max_freq = self.kwargs['max_freq_log2'] 61 | N_freqs = self.kwargs['num_freqs'] 62 | 63 | if self.kwargs['log_sampling']: 64 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 65 | else: 66 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 67 | 68 | for freq in freq_bands: 69 | for p_fn in self.kwargs['periodic_fns']: 70 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) 71 | out_dim += d 72 | 73 | self.embed_fns = embed_fns 74 | self.out_dim = out_dim 75 | 76 | def embed(self, inputs): 77 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 78 | 79 | 80 | def get_embedder(multires, i=0): 81 | if i == -1: 82 | return nn.Identity(), 3 83 | 84 | embed_kwargs = { 85 | 'include_input' : True, 86 | 'input_dims' : 3, 87 | 'max_freq_log2' : multires-1, 88 | 'num_freqs' : multires, 89 | 'log_sampling' : True, 90 | 'periodic_fns' : [torch.sin, torch.cos], 91 | } 92 | 93 | embedder_obj = Embedder(**embed_kwargs) 94 | embed = lambda x, eo=embedder_obj : eo.embed(x) 95 | return embed, embedder_obj.out_dim 96 | 97 | 98 | # Ray helpers 99 | def get_rays(H, W, K, c2w): 100 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij' 101 | i = i.t() 102 | j = j.t() 103 | dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1) 104 | # Rotate ray directions from camera frame to the world frame 105 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 106 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 107 | rays_o = c2w[:3,-1].expand(rays_d.shape) 108 | return rays_o, rays_d 109 | 110 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 111 | # Shift ray origins to near plane 112 | t = -(near + rays_o[...,2]) / rays_d[...,2] 113 | rays_o = rays_o + t[...,None] * rays_d 114 | 115 | # Projection 116 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] 117 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] 118 | o2 = 1. + 2. * near / rays_o[...,2] 119 | 120 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) 121 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) 122 | d2 = -2. 
* near / rays_o[...,2] 123 | 124 | rays_o = torch.stack([o0,o1,o2], -1) 125 | rays_d = torch.stack([d0,d1,d2], -1) 126 | 127 | return rays_o, rays_d 128 | 129 | 130 | # Hierarchical sampling (section 5.2) 131 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 132 | # Get pdf 133 | weights = weights + 1e-5 # prevent nans 134 | pdf = weights / torch.sum(weights, -1, keepdim=True) 135 | cdf = torch.cumsum(pdf, -1) 136 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins)) 137 | 138 | # Take uniform samples 139 | if det: 140 | u = torch.linspace(0., 1., steps=N_samples) 141 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 142 | else: 143 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) 144 | 145 | # Pytest, overwrite u with numpy's fixed random numbers 146 | if pytest: 147 | np.random.seed(0) 148 | new_shape = list(cdf.shape[:-1]) + [N_samples] 149 | if det: 150 | u = np.linspace(0., 1., N_samples) 151 | u = np.broadcast_to(u, new_shape) 152 | else: 153 | u = np.random.rand(*new_shape) 154 | u = torch.Tensor(u) 155 | 156 | # Invert CDF 157 | u = u.contiguous() 158 | inds = torch.searchsorted(cdf, u, right=True) 159 | below = torch.max(torch.zeros_like(inds-1), inds-1) 160 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) 161 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 162 | 163 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 164 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 165 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 166 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 167 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 168 | 169 | denom = (cdf_g[...,1]-cdf_g[...,0]) 170 | denom = torch.where(denom<1e-5, torch.ones_like(denom), denom) 171 | t = (u-cdf_g[...,0])/denom 172 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) 173 | 174 | return samples 175 | 176 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, 177 | near=0., far=1., 178 | use_viewdirs=False, c2w_staticcam=None, 179 | **kwargs): 180 | """Render rays 181 | Args: 182 | H: int. Height of image in pixels. 183 | W: int. Width of image in pixels. 184 | focal: float. Focal length of pinhole camera. 185 | chunk: int. Maximum number of rays to process simultaneously. Used to 186 | control maximum memory usage. Does not affect final results. 187 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 188 | each example in batch. 189 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 190 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 191 | near: float or array of shape [batch_size]. Nearest distance for a ray. 192 | far: float or array of shape [batch_size]. Farthest distance for a ray. 193 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 194 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 195 | camera while using other c2w argument for viewing directions. 196 | Returns: 197 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 198 | disp_map: [batch_size]. Disparity map. Inverse of depth. 199 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 200 | extras: dict with everything returned by render_rays(). 
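    Example (illustrative, mirroring the call made in eval.py):
        rgb, disp, acc, extras = render(H, W, K, chunk=config['model']['TN']['netchunk'],
                                        c2w=pose, retraw=True, **render_kwargs)
        # pose is a [3, 4] camera-to-world matrix; chunk bounds rays per forward pass.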
201 | """ 202 | if c2w is not None: 203 | # special case to render full image 204 | rays_o, rays_d = get_rays(H, W, K, c2w) 205 | else: 206 | # use provided ray batch 207 | rays_o, rays_d = rays 208 | 209 | if use_viewdirs: 210 | # provide ray directions as input 211 | viewdirs = rays_d 212 | if c2w_staticcam is not None: 213 | # special case to visualize effect of viewdirs 214 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 215 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 216 | viewdirs = torch.reshape(viewdirs, [-1,3]).float() 217 | 218 | sh = rays_d.shape # [..., 3] 219 | if ndc: 220 | # for forward facing scenes 221 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 222 | 223 | # Create ray batch 224 | rays_o = torch.reshape(rays_o, [-1,3]).float() 225 | rays_d = torch.reshape(rays_d, [-1,3]).float() 226 | 227 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) 228 | rays = torch.cat([rays_o, rays_d, near, far], -1) 229 | if use_viewdirs: 230 | rays = torch.cat([rays, viewdirs], -1) 231 | 232 | # Render and reshape 233 | all_ret = batchify_rays(rays, chunk, **kwargs) 234 | for k in all_ret: 235 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 236 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 237 | 238 | k_extract = ['rgb_map', 'disp_map', 'acc_map'] 239 | ret_list = [all_ret[k] for k in k_extract] 240 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} 241 | return ret_list + [ret_dict] 242 | 243 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs): 244 | """Render rays in smaller minibatches to avoid OOM. 245 | """ 246 | all_ret = {} 247 | for i in range(0, rays_flat.shape[0], chunk): 248 | ret = render_rays(rays_flat[i:i+chunk], **kwargs) 249 | for k in ret: 250 | if k not in all_ret: 251 | all_ret[k] = [] 252 | all_ret[k].append(ret[k]) 253 | 254 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} 255 | return all_ret 256 | 257 | 258 | def render_rays(ray_batch, 259 | network_fn, 260 | network_query_fn, 261 | N_samples, 262 | retraw=False, 263 | lindisp=False, 264 | perturb=0., 265 | N_importance=0, 266 | network_fine=None, 267 | white_bkgd=False, 268 | raw_noise_std=0., 269 | verbose=False, 270 | pytest=False): 271 | """Volumetric rendering. 272 | Args: 273 | ray_batch: array of shape [batch_size, ...]. All information necessary 274 | for sampling along a ray, including: ray origin, ray direction, min 275 | dist, max dist, and unit-magnitude viewing direction. 276 | network_fn: function. Model for predicting RGB and density at each point 277 | in space. 278 | network_query_fn: function used for passing queries to network_fn. 279 | N_samples: int. Number of different times to sample along each ray. 280 | retraw: bool. If True, include model's raw, unprocessed predictions. 281 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth. 282 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified 283 | random points in time. 284 | N_importance: int. Number of additional times to sample along each ray. 285 | These samples are only passed to network_fine. 286 | network_fine: "fine" network with same spec as network_fn. 287 | white_bkgd: bool. If True, assume a white background. 288 | raw_noise_std: ... 289 | verbose: bool. If True, print more debugging info. 290 | Returns: 291 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model. 292 | disp_map: [num_rays]. Disparity map. 1 / depth. 293 | acc_map: [num_rays]. 
Accumulated opacity along each ray. Comes from fine model. 294 | raw: [num_rays, num_samples, 4]. Raw predictions from model. 295 | rgb0: See rgb_map. Output for coarse model. 296 | disp0: See disp_map. Output for coarse model. 297 | acc0: See acc_map. Output for coarse model. 298 | z_std: [num_rays]. Standard deviation of distances along ray for each 299 | sample. 300 | """ 301 | N_rays = ray_batch.shape[0] 302 | rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each 303 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None 304 | bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2]) 305 | near, far = bounds[...,0], bounds[...,1] # [-1,1] 306 | 307 | t_vals = torch.linspace(0., 1., steps=N_samples) 308 | if not lindisp: 309 | z_vals = near * (1.-t_vals) + far * (t_vals) 310 | else: 311 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) 312 | 313 | z_vals = z_vals.expand([N_rays, N_samples]) 314 | 315 | if perturb > 0.: 316 | # get intervals between samples 317 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 318 | upper = torch.cat([mids, z_vals[...,-1:]], -1) 319 | lower = torch.cat([z_vals[...,:1], mids], -1) 320 | # stratified samples in those intervals 321 | t_rand = torch.rand(z_vals.shape) 322 | 323 | # Pytest, overwrite u with numpy's fixed random numbers 324 | if pytest: 325 | np.random.seed(0) 326 | t_rand = np.random.rand(*list(z_vals.shape)) 327 | t_rand = torch.Tensor(t_rand) 328 | 329 | z_vals = lower + (upper - lower) * t_rand 330 | 331 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3] 332 | 333 | 334 | #raw = run_network(pts, **kwargs) 335 | raw = network_query_fn(pts, viewdirs, network_fn) 336 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 337 | """ 338 | if N_importance > 0: 339 | 340 | rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map 341 | 342 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 343 | z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest) 344 | z_samples = z_samples.detach() 345 | 346 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) 347 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3] 348 | 349 | run_fn = network_fn if network_fine is None else network_fine #jak nie ma network fine to uzyj zwyklej 350 | #raw = run_network(pts, fn=run_fn) 351 | raw = network_query_fn(pts, viewdirs, run_fn) 352 | 353 | rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 354 | 355 | """ 356 | 357 | ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map} 358 | if retraw: 359 | ret['raw'] = raw 360 | """ 361 | if N_importance > 0: 362 | ret['rgb0'] = rgb_map_0 363 | ret['disp0'] = disp_map_0 364 | ret['acc0'] = acc_map_0 365 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] 366 | """ 367 | for k in ret: 368 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 369 | print(f"! [Numerical Error] {k} contains nan or inf.") 370 | 371 | return ret 372 | 373 | 374 | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64): 375 | """Prepares inputs and applies network 'fn'. 
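    Flattens the sampled points, positionally encodes them with embed_fn (and, when
    viewdirs is given, encodes the per-ray viewing directions with embeddirs_fn and
    concatenates both embeddings), then evaluates 'fn' in chunks of 'netchunk' inputs
    to bound peak memory and reshapes the result back to the input layout.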
376 | """ 377 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 378 | embedded = embed_fn(inputs_flat) 379 | 380 | if viewdirs is not None: 381 | input_dirs = viewdirs[:,None].expand(inputs.shape) 382 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 383 | embedded_dirs = embeddirs_fn(input_dirs_flat) 384 | embedded = torch.cat([embedded, embedded_dirs], -1) 385 | 386 | outputs_flat = batchify(fn, netchunk)(embedded) 387 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 388 | return outputs 389 | 390 | def batchify(fn, chunk): 391 | """Constructs a version of 'fn' that applies to smaller batches. 392 | """ 393 | if chunk is None: 394 | return fn 395 | def ret(inputs): 396 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 397 | return ret 398 | 399 | 400 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): 401 | """Transforms model's predictions to semantically meaningful values. 402 | Args: 403 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 404 | z_vals: [num_rays, num_samples along ray]. Integration time. 405 | rays_d: [num_rays, 3]. Direction of each ray. 406 | Returns: 407 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 408 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 409 | acc_map: [num_rays]. Sum of weights along each ray. 410 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 411 | depth_map: [num_rays]. Estimated distance to object. 412 | """ 413 | raw2alpha = lambda raw, dists, act_fn=F.softplus: 1.-torch.exp(-act_fn(raw)*dists) 414 | 415 | dists = z_vals[...,1:] - z_vals[...,:-1] 416 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] 417 | 418 | dists = dists * torch.norm(rays_d[...,None,:], dim=-1) 419 | 420 | rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3] 421 | noise = 0. 
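    # What follows is the discrete volume-rendering step from NeRF:
    #   alpha_i = 1 - exp(-softplus(sigma_i + noise) * delta_i)
    #   w_i     = alpha_i * prod_{j<i} (1 - alpha_j + 1e-10)
    #   rgb_map = sum_i w_i * c_i          (with c_i = sigmoid(raw rgb))
    # Optional Gaussian noise of std raw_noise_std is added to the raw density before
    # the activation, and the dists (delta_i) were scaled by ||rays_d|| above so they
    # are expressed in world units.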
422 | if raw_noise_std > 0.: 423 | noise = torch.randn(raw[...,3].shape) * raw_noise_std 424 | 425 | # Overwrite randomly sampled data if pytest 426 | if pytest: 427 | np.random.seed(0) 428 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std 429 | noise = torch.Tensor(noise) 430 | 431 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] 432 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 433 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1] 434 | rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3] 435 | 436 | depth_map = torch.sum(weights * z_vals, -1) 437 | disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) 438 | acc_map = torch.sum(weights, -1) 439 | 440 | if white_bkgd: 441 | rgb_map = rgb_map + (1.-acc_map[...,None]) 442 | 443 | return rgb_map, disp_map, acc_map, weights, depth_map 444 | 445 | 446 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): 447 | 448 | H, W, focal = hwf 449 | 450 | if render_factor!=0: 451 | # Render downsampled for speed 452 | H = H//render_factor 453 | W = W//render_factor 454 | focal = focal/render_factor 455 | 456 | rgbs = [] 457 | disps = [] 458 | 459 | t = time.time() 460 | for i, c2w in enumerate(tqdm(render_poses)): 461 | print(i, time.time() - t) 462 | t = time.time() 463 | rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs) 464 | rgbs.append(rgb.detach().cpu().numpy()) 465 | disps.append(disp.detach().cpu().numpy()) 466 | if i==0: 467 | print(rgb.shape, disp.shape) 468 | 469 | """ 470 | if gt_imgs is not None and render_factor==0: 471 | p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) 472 | print(p) 473 | """ 474 | 475 | if savedir is not None: 476 | rgb8 = to8b(rgbs[-1]) 477 | filename = os.path.join(savedir, '{:03d}.png'.format(i)) 478 | imageio.imwrite(filename, rgb8) 479 | 480 | 481 | rgbs = np.stack(rgbs, 0) 482 | disps = np.stack(disps, 0) 483 | 484 | return rgbs, disps 485 | 486 | -------------------------------------------------------------------------------- /pts2nerf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from os.path import join, exists 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from datetime import datetime 8 | from nerf_helpers import * 9 | from itertools import chain 10 | from tqdm import tqdm 11 | 12 | import matplotlib 13 | matplotlib.use('Agg') 14 | import matplotlib.pyplot as plt 15 | 16 | from utils import * 17 | 18 | from dataset.dataset import NeRFShapeNetDataset 19 | 20 | from models.encoder import Encoder 21 | from models.nerf import NeRF 22 | from models.resnet import resnet18 23 | from hypnettorch.hnets.chunked_mlp_hnet import ChunkedHMLP 24 | 25 | #Needed for workers for dataloader 26 | from torch.multiprocessing import Pool, Process, set_start_method 27 | set_start_method('spawn', force=True) 28 | 29 | import argparse 30 | 31 | if __name__ == '__main__': 32 | dirname = os.path.dirname(__file__) 33 | 34 | parser = argparse.ArgumentParser(description='Start training') 35 | parser.add_argument('config_path', type=str, 36 | help='Relative config path') 37 | 38 | args = parser.parse_args() 39 | 40 | config = None 41 | with open(args.config_path) as f: 42 | config = json.load(f) 43 | assert config is not 
None 44 | 45 | print(config) 46 | 47 | set_seed(config['seed']) 48 | 49 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 50 | print('Device: ', device) 51 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 52 | 53 | dataset = NeRFShapeNetDataset(root_dir=config['data_dir'], classes=config['classes']) 54 | 55 | dataloader = DataLoader(dataset, batch_size=config['batch_size'], 56 | shuffle=config['shuffle'], 57 | num_workers=8, drop_last=True, 58 | pin_memory=True, generator=torch.Generator(device='cuda')) 59 | 60 | embed_fn, config['model']['TN']['input_ch_embed'] = get_embedder(config['model']['TN']['multires'], config['model']['TN']['i_embed']) 61 | 62 | embeddirs_fn = None 63 | config['model']['TN']['input_ch_views_embed'] = 0 64 | if config['model']['TN']['use_viewdirs']: 65 | embeddirs_fn, config['model']['TN']['input_ch_views_embed']= get_embedder(config['model']['TN']['multires_views'], config['model']['TN']['i_embed']) 66 | 67 | 68 | # Create a NeRF network 69 | nerf = NeRF(config['model']['TN']['D'],config['model']['TN']['W'], 70 | config['model']['TN']['input_ch_embed'], 71 | config['model']['TN']['input_ch_views_embed'], 72 | config['model']['TN']['use_viewdirs']).to(device) 73 | 74 | #Hypernetwork 75 | hnet = ChunkedHMLP(nerf.param_shapes, uncond_in_size=config['z_size'], cond_in_size=0, 76 | layers=config['model']['HN']['arch'], chunk_size=config['model']['HN']['chunk_size'], cond_chunk_embs=False, use_bias=config['model']['HN']['use_bias']).to(device) 77 | 78 | print(hnet.param_shapes) 79 | 80 | #Create encoder: either Resnet or classic 81 | if config['resnet']==True: 82 | encoder = resnet18(num_classes=config['z_size']).to(device) 83 | else: 84 | encoder = Encoder(config).to(device) 85 | 86 | #RAdam because it might help with not collapsing to white background 87 | optimizer = torch.optim.RAdam(chain(encoder.parameters(), hnet.internal_params), **config['optimizer']['E_HN']['hyperparams']) 88 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['lr_decay']) 89 | loss_fn = torch.nn.MSELoss() 90 | 91 | results_dir = config['results_dir'] 92 | os.makedirs(join(dirname,results_dir), exist_ok=True) 93 | 94 | with open(join(results_dir, "config.json"), "w") as file: 95 | json.dump(config, file, indent=4) 96 | 97 | try: 98 | losses_r = np.load(join(results_dir, f'losses_r.npy')).tolist() 99 | print("Loaded reconstruction losses") 100 | losses_kld = np.load(join(results_dir, f'losses_kld.npy')).tolist() 101 | print("Loaded KLD losses") 102 | losses_total = np.load(join(results_dir, f'losses_total.npy')).tolist() 103 | print("Loaded total losses") 104 | except: 105 | print("Haven't found previous loss data. 
We are assuming that this is a new experiment.") 106 | losses_r = [] 107 | losses_kld = [] 108 | losses_total = [] 109 | 110 | starting_epoch = len(losses_total) 111 | 112 | print("starting epoch:", starting_epoch) 113 | 114 | if(starting_epoch>0): 115 | print("Loading weights since previous losses were found") 116 | try: 117 | hnet.load_state_dict(torch.load(join(results_dir, f"model_hn_{starting_epoch-1}.pt"))) 118 | print("Loaded HNet") 119 | encoder.load_state_dict(torch.load(join(results_dir, f"model_e_{starting_epoch-1}.pt"))) 120 | print("Loaded Encoder") 121 | scheduler.load_state_dict(torch.load(join(results_dir, f"lr_{starting_epoch-1}.pt"))) 122 | print("Loaded Scheduler") 123 | except: 124 | print("Haven't found all previous models.") 125 | 126 | 127 | hnet.train() 128 | encoder.train() 129 | 130 | os.makedirs(join(results_dir, 'samples'), exist_ok=True) 131 | 132 | for epoch in range(starting_epoch, starting_epoch+config['max_epochs'] + 1): 133 | start_epoch_time = datetime.now() 134 | 135 | total_loss = 0.0 136 | total_loss_r = 0.0 137 | total_loss_kld = 0.0 138 | 139 | for i, (entry, cat, obj_path) in enumerate(dataloader): 140 | x = [] 141 | y = [] 142 | 143 | if config['resnet']: 144 | nerf_Ws, mu, logvar = get_nerf_resnet(entry, encoder, hnet) 145 | else: 146 | nerf_Ws, mu, logvar = get_nerf(entry, encoder, hnet) 147 | 148 | #For batch size == 1 hnet doesn't return batch dimension... 149 | if config['batch_size'] == 1: 150 | nerf_Ws = [nerf_Ws] 151 | 152 | for j, target_w in enumerate(nerf_Ws): 153 | render_kwargs_train = get_render_kwargs(config, nerf, target_w, embed_fn, embeddirs_fn) 154 | 155 | for p in range(config["poses"]): 156 | img_i = np.random.choice(len(entry['images'][j]), 1) 157 | target = entry['images'][j][img_i][0].to(device) 158 | target = torch.Tensor(target.float()) 159 | pose = entry['cam_poses'][j][img_i, :3,:4][0].to(device) 160 | 161 | H = entry["images"][j].shape[1] 162 | W = entry["images"][j].shape[2] 163 | focal = .5 * W / np.tan(.5 * 0.6911112070083618) 164 | 165 | K = np.array([ 166 | [focal, 0, 0.5*W], 167 | [0, focal, 0.5*H], 168 | [0, 0, 1] 169 | ]) 170 | 171 | #Calculate rays from camera origin 172 | rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose.float())) 173 | 174 | #Create coordinates array (for ray selection) 175 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) 176 | 177 | #To 1D 178 | coords = torch.reshape(coords, [-1,2]) # (H * W, 2) 179 | 180 | #Select rays based on random coord selection 181 | select_inds = np.random.choice(coords.shape[0], size=[config['model']['TN']['N_rand'],], replace=False) # (N_rand,) 182 | select_coords = coords[select_inds].long() # (N_rand, 2) 183 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 184 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 185 | batch_rays = torch.stack([rays_o, rays_d], 0) 186 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 187 | 188 | 189 | img_r, _, _, _ = render(H, W, K, chunk=config['model']['TN']['netchunk'], rays=batch_rays.to(device), 190 | verbose=True, retraw=True, 191 | **render_kwargs_train) 192 | 193 | x.append(target_s) 194 | y.append(img_r) 195 | 196 | optimizer.zero_grad() 197 | x = torch.stack(x) 198 | y = torch.stack(y) 199 | 200 | loss_r = loss_fn(y, x) 201 | 202 | loss_kld = 0.5 * (torch.exp(logvar) + torch.pow(mu, 2) - 1 - logvar).sum() 203 | 204 | loss = loss_r + loss_kld 205 | 206 | loss.backward() 207 | 
optimizer.step() 208 | 209 | total_loss_r += loss_r.item() 210 | total_loss += loss.item() 211 | total_loss_kld += loss_kld.item() 212 | 213 | losses_r.append(total_loss_r) 214 | losses_kld.append(total_loss_kld) 215 | losses_total.append(total_loss) 216 | 217 | scheduler.step() 218 | 219 | #Log information, save models etc. 220 | if epoch % config['i_log'] == 0: 221 | print(f"Epoch {epoch}: took {round((datetime.now() - start_epoch_time).total_seconds(), 3)} seconds") 222 | print(f"Total loss: {total_loss} Loss R: {total_loss_r} Loss KLD: {total_loss_kld}") 223 | 224 | #Compare current reconstruction 225 | if epoch % config['i_sample'] == 0 or epoch == 0: 226 | with torch.no_grad(): 227 | render_kwargs_test = { 228 | k: render_kwargs_train[k] for k in render_kwargs_train} 229 | render_kwargs_test['perturb'] = False 230 | render_kwargs_test['raw_noise_std'] = 0. 231 | img, _, _, _ = render(H,W,K, chunk=config['model']['TN']['netchunk'], c2w=pose, 232 | verbose=True, retraw=True, 233 | **render_kwargs_test) 234 | f, axarr = plt.subplots(1,2) 235 | axarr[0].imshow(img.detach().cpu()) 236 | axarr[1].imshow(target.detach().cpu()) 237 | f.savefig(join(results_dir, 'samples', f"epoch_{epoch}.png")) 238 | plt.close(f) 239 | 240 | 241 | if epoch % config['i_save']==0: 242 | torch.save(hnet.state_dict(), join(results_dir, f"model_hn_{epoch}.pt")) 243 | torch.save(encoder.state_dict(), join(results_dir, f"model_e_{epoch}.pt")) 244 | torch.save(scheduler.state_dict(), join(results_dir, f"lr_{epoch}.pt")) 245 | #torch.save(optimizer.state_dict(), join(results_dir, f"opt_{epoch}.pt")) 246 | 247 | np.save(join(results_dir, 'losses_r.npy'), np.array(losses_r)) 248 | np.save(join(results_dir, 'losses_kld.npy'), np.array(losses_kld)) 249 | np.save(join(results_dir, 'losses_total.npy'), np.array(losses_total)) 250 | 251 | plt.plot(losses_r) 252 | plt.savefig(os.path.join(results_dir, f'loss_r_plot.png')) 253 | plt.close() 254 | 255 | plt.loglog(losses_r) 256 | plt.savefig(os.path.join(results_dir, f'loss_r_plot_log.png')) 257 | plt.close() 258 | 259 | plt.plot(losses_kld) 260 | plt.savefig(os.path.join(results_dir, f'loss_kld_plot.png')) 261 | plt.close() 262 | 263 | plt.plot(losses_total) 264 | plt.savefig(os.path.join(results_dir, f'loss_total_plot.png')) 265 | plt.close() 266 | 267 | -------------------------------------------------------------------------------- /render_samples.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from os.path import join, exists 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from datetime import datetime 9 | from nerf_helpers import * 10 | from itertools import chain 11 | from tqdm import tqdm 12 | 13 | import matplotlib 14 | matplotlib.use('Agg') 15 | import matplotlib.pyplot as plt 16 | 17 | from utils import * 18 | 19 | import trimesh, mcubes 20 | 21 | from dataset.dataset import NeRFShapeNetDataset 22 | 23 | from models.encoder import Encoder 24 | from models.nerf import NeRF 25 | from models.resnet import resnet18 26 | from hypnettorch.hnets.chunked_mlp_hnet import ChunkedHMLP 27 | 28 | import open3d as o3d 29 | 30 | import argparse 31 | 32 | #Needed for workers for dataloader 33 | from torch.multiprocessing import Pool, Process, set_start_method 34 | set_start_method('spawn', force=True) 35 | 36 | import math 37 | def cart2sph(x,y,z): 38 | XsqPlusYsq = x**2 + y**2 39 | r = math.sqrt(XsqPlusYsq + z**2) # r 40 | elev 
= math.atan2(z,math.sqrt(XsqPlusYsq)) # theta 41 | az = math.atan2(y,x) # phi 42 | return r, elev, az 43 | 44 | def export_model(render_kwargs, focal, path, path_colored, N=256): 45 | width = 1.1 46 | with torch.no_grad(): 47 | #Sample NeRF 48 | t = torch.linspace(-width, width, N+1) 49 | query_pts = torch.stack(torch.meshgrid(t, t, t), -1) 50 | print(query_pts.shape) 51 | sh = query_pts.shape 52 | flat = query_pts.reshape([-1,3]) 53 | print(flat.shape) 54 | 55 | fn = lambda i0, i1 : render_kwargs['network_query_fn'](flat[i0:i1,None,:], viewdirs=None, network_fn=render_kwargs['network_fn']) 56 | chunk = 1024*16 57 | raw = torch.cat([fn(i, i+chunk) for i in range(0, flat.shape[0], chunk)], 0) 58 | raw = torch.reshape(raw, list(sh[:-1]) + [-1]) 59 | sigma = torch.maximum(raw[...,-1], torch.Tensor([0.])) 60 | 61 | #Marching cubes 62 | threshold = 5 63 | vertices, triangles = mcubes.marching_cubes(sigma.cpu().numpy(), threshold) 64 | print('done', vertices.shape, triangles.shape) 65 | 66 | #Two meshes because colors tend to be misplaced on mesh_export 67 | mesh = trimesh.Trimesh((vertices / N) - 0.5, triangles) 68 | 69 | obj = trimesh.exchange.ply.export_ply(mesh) 70 | 71 | with open(path, "wb+") as f: 72 | f.write(obj) 73 | 74 | print("Saved uncolored model to", path) 75 | 76 | rgbs = [] 77 | final = [] 78 | vertex_colors = [] 79 | radius = 0.05 # distance from camera to a vertex, theoretically it could be lower to properly capture its color 80 | 81 | H = 1 82 | W = 1 83 | K = np.array([ 84 | [focal, 0, 0.5*W], 85 | [0, focal, 0.5*H], 86 | [0, 0, 1] 87 | ]) 88 | 89 | for i, vert in enumerate(mesh.vertices): 90 | coords = np.array(vert) 91 | 92 | coords = coords / np.linalg.norm(coords) 93 | r, phi, theta = cart2sph(*coords) 94 | theta += math.pi/2 95 | phi -= math.pi 96 | c2w = pose_spherical(theta * 180 / math.pi, phi * 180 / math.pi, r+radius) 97 | result = render(H, W, K, chunk=2048, c2w=c2w, **render_kwargs) 98 | rgb = np.clip(result[0].detach().cpu().numpy(),0,1).squeeze() 99 | rgbs.append(rgb) 100 | final.append([*vert, *rgb]) 101 | mesh.visual.vertex_colors[i] = np.concatenate((rgb, [1]))*255 102 | 103 | obj = trimesh.exchange.ply.export_ply(mesh) 104 | 105 | with open(path_colored, "wb+") as f: 106 | f.write(obj) 107 | 108 | print("Saved colored model to", path_colored) 109 | 110 | 111 | if __name__ == '__main__': 112 | pd.set_option('display.max_columns', None) 113 | pd.set_option('display.max_rows', None) 114 | 115 | dirname = os.path.dirname(__file__) 116 | 117 | parser = argparse.ArgumentParser(description='Start training HyperRF') 118 | parser.add_argument('config_path', type=str, 119 | help='Relative config path') 120 | parser.add_argument('-o_anim_count', type=int, help='How many object animations') 121 | parser.add_argument('-g_anim_count', type=int, help='How many generated object animations') 122 | parser.add_argument('-i_anim_count', type=int, help='How many interpolation object animations') 123 | parser.add_argument('-train_ds', type=int, help="Use train dataset?", default=0) 124 | parser.add_argument('-epoch', type=int, help="Default epoch to use. 
Set 0 to use latest.", default=0) 125 | #TODO: add more arguments here 126 | 127 | args = parser.parse_args() 128 | 129 | config = None 130 | with open(args.config_path) as f: 131 | config = json.load(f) 132 | assert config is not None 133 | 134 | print(config) 135 | 136 | set_seed(config['seed']) 137 | 138 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 139 | 140 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 141 | 142 | dataset = NeRFShapeNetDataset(root_dir=config['data_dir'], classes=config['classes'], train=args.train_ds != 0) 143 | 144 | config['batch_size'] = 1 145 | 146 | dataloader = DataLoader(dataset, batch_size=config['batch_size'], 147 | shuffle=config['shuffle'], 148 | num_workers=2, drop_last=True, 149 | pin_memory=True, generator=torch.Generator(device='cuda')) 150 | 151 | embed_fn, config['model']['TN']['input_ch_embed'] = get_embedder(config['model']['TN']['multires'], config['model']['TN']['i_embed']) 152 | 153 | embeddirs_fn = None 154 | config['model']['TN']['input_ch_views_embed'] = 0 155 | if config['model']['TN']['use_viewdirs']: 156 | embeddirs_fn, config['model']['TN']['input_ch_views_embed'] = get_embedder(config['model']['TN']['multires_views'], config['model']['TN']['i_embed']) 157 | 158 | # Create a NeRF network 159 | nerf = NeRF(config['model']['TN']['D'],config['model']['TN']['W'], 160 | config['model']['TN']['input_ch_embed'], 161 | config['model']['TN']['input_ch_views_embed'], 162 | config['model']['TN']['use_viewdirs']).to(device) 163 | 164 | #Hypernetwork 165 | hnet = ChunkedHMLP(nerf.param_shapes, uncond_in_size=config['z_size'], cond_in_size=0, 166 | layers=config['model']['HN']['arch'], chunk_size=config['model']['HN']['chunk_size'], cond_chunk_embs=False, use_bias=config['model']['HN']['use_bias']).to(device) 167 | 168 | #Create encoder: either Resnet or classic 169 | if config['resnet']: 170 | encoder = resnet18(num_classes=config['z_size']).to(device) 171 | else: 172 | encoder = Encoder(config).to(device) 173 | 174 | results_dir = config['results_dir'] 175 | os.makedirs(join(dirname,results_dir), exist_ok=True) 176 | 177 | with open(join(results_dir, "config_eval.json"), "w") as file: 178 | json.dump(config, file, indent=4) 179 | 180 | 181 | print(args.epoch, "set as starting epoch") 182 | if args.epoch == 0: 183 | print("Loading \'latest\' models") 184 | try: 185 | hnet.load_state_dict(torch.load(join(results_dir, f"model_hn_latest.pt"))) 186 | print("Loaded HNet") 187 | encoder.load_state_dict(torch.load(join(results_dir, f"model_e_latest.pt"))) 188 | print("Loaded Encoder") 189 | except: 190 | print("Haven't loaded all previous models.") 191 | else: 192 | starting_epoch = args.epoch 193 | print("Starting epoch:", starting_epoch) 194 | 195 | if starting_epoch > 0: 196 | print("Loading weights") 197 | try: 198 | hnet.load_state_dict(torch.load(join(results_dir, f"model_hn_{starting_epoch}.pt"))) 199 | print("Loaded HNet") 200 | encoder.load_state_dict(torch.load(join(results_dir, f"model_e_{starting_epoch}.pt"))) 201 | print("Loaded Encoder") 202 | except: 203 | print("Haven't found all previous models.") 204 | 205 | results_dir = join(dirname, 'rendered_samples', config['classes'][0]) 206 | os.makedirs(results_dir, exist_ok=True) 207 | results_dir_main = results_dir 208 | 209 | encoder.eval() 210 | hnet.eval() 211 | 212 | default_N = 256 213 | render_iterations = 60 + 1 214 | render_fps = 30 215 | 216 | for i, (entry, cat, obj_path) in enumerate(dataloader): 217 | if i > args.o_anim_count: 218 | break 219 | 220 |
start_time = datetime.now() 221 | 222 | if config['resnet']: 223 | nerf_Ws, mu, logvar = get_nerf_resnet(entry, encoder, hnet) 224 | else: 225 | nerf_Ws, mu, logvar = get_nerf(entry, encoder, hnet) 226 | 227 | #For batch size == 1 hnet doesn't return batch dimension... 228 | if config['batch_size'] == 1: 229 | nerf_Ws = [nerf_Ws] 230 | 231 | for j, target_w in enumerate(nerf_Ws): 232 | render_kwargs = get_render_kwargs(config, nerf, target_w, embed_fn, embeddirs_fn) 233 | render_kwargs['perturb'] = False 234 | render_kwargs['raw_noise_std'] = 0. 235 | 236 | print("Animation", i, obj_path) 237 | H = entry["images"][j].shape[1] 238 | W = entry["images"][j].shape[2] 239 | focal = .5 * W / np.tan(.5 * 0.6911112070083618) 240 | 241 | K = np.array([ 242 | [focal, 0, 0.5*W], 243 | [0, focal, 0.5*H], 244 | [0, 0, 1] 245 | ]) 246 | 247 | results_dir = join(results_dir_main, f'o{i}') 248 | os.makedirs(results_dir, exist_ok=True) 249 | torch.set_printoptions(threshold=100) 250 | 251 | #Render cloud of points 252 | """ 253 | for el in [0,45,90,135, 180, 225, 270, 315]: 254 | for az in [0,45,90,135, 180, 225, 270, 315]: 255 | fig = plt.figure(figsize=(8,8)) 256 | ax = fig.add_subplot(111, projection = '3d') 257 | ax.view_init(elev=el, azim=az) 258 | ax.scatter(entry['data'][j][:,0], entry['data'][j][:,1], entry['data'][j][:,2], c = entry['data'][j][:,3:]) 259 | ax.set_xlim3d(-1, 1) 260 | ax.set_ylim3d(-1, 1) 261 | ax.set_zlim3d(-1, 1) 262 | plt.axis('off') 263 | plt.grid(b=None) 264 | plt.tight_layout() 265 | plt.savefig(join(results_dir, f'pc_{el}_{az}.png')) 266 | plt.close() 267 | """ 268 | 269 | for gt in range(10): 270 | imageio.imsave(join(results_dir, f'ground_t_{gt}.png'), to8b(entry['images'][j][gt].detach().cpu().numpy())) 271 | 272 | with torch.no_grad(): 273 | img_i = np.random.choice(len(entry['images'][j]), 1) 274 | target = entry['images'][j][img_i][0].to(device) 275 | target = torch.Tensor(target.float()) 276 | pose = entry['cam_poses'][j][img_i, :3,:4][0].to(device) 277 | 278 | img_r, _, _, _ = render(H, W, K, chunk=config['model']['TN']['netchunk'], c2w=pose, 279 | verbose=True, retraw=True, 280 | **render_kwargs) 281 | 282 | frame = torch.cat([img_r,target], dim=1) 283 | 284 | imageio.imsave(join(results_dir, f'compare_{i}.png'), to8b(frame.detach().cpu().numpy())) 285 | 286 | with torch.no_grad(): 287 | render_poses = torch.stack([pose_spherical(angle, -45, 3.2) for angle in np.linspace(-180,180,render_iterations)[:-1]], 0) 288 | frames = [] 289 | for k, pose in enumerate(render_poses): 290 | 291 | img, disp, acc, _ = render(H, W, K, chunk=config['model']['TN']['netchunk'], c2w=pose, 292 | verbose=True, retraw=True, 293 | **render_kwargs) 294 | frames.append(to8b(img.detach().cpu().numpy())) 295 | 296 | if k%4==0: 297 | imageio.imsave(join(results_dir, f'o_{i}_{k}.png'), to8b(img.detach().cpu().numpy())) 298 | 299 | writer = imageio.get_writer(join(results_dir, f'an_{i}.gif'), fps=render_fps) 300 | for frame in frames: 301 | writer.append_data(frame) 302 | writer.close() 303 | 304 | with torch.no_grad(): 305 | render_poses = torch.stack([pose_spherical(angle, -45, 3.2) for angle in np.linspace(-180,180,9)[:-1]]+\ 306 | [pose_spherical(angle, -30, 3.2) for angle in np.linspace(-180,180,9)[:-1]]+\ 307 | [pose_spherical(angle, -15, 3.2) for angle in np.linspace(-180,180,9)[:-1]], 308 | 0) 309 | for k, pose in enumerate(render_poses): 310 | 311 | img, disp, acc, _ = render(H, W, K, chunk=config['model']['TN']['netchunk'], c2w=pose, 312 | verbose=True, retraw=True, 313 | **render_kwargs) 314 | 315 | 316 |
imageio.imsave(join(results_dir, f'o_other_{i}_{k}.png'), to8b(img.detach().cpu().numpy())) 317 | 318 | render_kwargs['near'] = 0. 319 | 320 | export_model(render_kwargs, focal, join(results_dir, f'o_model_{i}.ply'), join(results_dir, f'o_model_col_{i}.ply'), N=default_N) 321 | 322 | print("Time:", round((datetime.now() - start_time).total_seconds(), 2)) 323 | 324 | for i in range(args.g_anim_count): 325 | start_time = datetime.now() 326 | sample = torch.normal(mean=torch.zeros(config["z_size"]), std=torch.full((config["z_size"],), fill_value=0.006)) 327 | render_kwargs = get_render_kwargs(config, nerf, get_nerf_from_code(hnet, sample[None]), embed_fn, embeddirs_fn) 328 | render_kwargs['perturb'] = False 329 | render_kwargs['raw_noise_std'] = 0. 330 | 331 | results_dir = join(results_dir_main, f'g{i}') 332 | os.makedirs(results_dir, exist_ok=True) 333 | 334 | print("Generated Object Animation", i) 335 | with torch.no_grad(): 336 | render_poses = torch.stack([pose_spherical(angle, -45, 3.2) for angle in np.linspace(-180,180,render_iterations)[:-1]], 0) 337 | frames = [] 338 | for k, pose in enumerate(render_poses): 339 | 340 | img, disp, acc, _ = render(H, W, K, chunk=config['model']['TN']['netchunk'], c2w=pose, 341 | verbose=True, retraw=True, 342 | **render_kwargs) 343 | frames.append(to8b(img.detach().cpu().numpy())) 344 | 345 | if k%4==0: 346 | imageio.imsave(join(results_dir, f'g_{i}_{k}.png'), to8b(img.detach().cpu().numpy())) 347 | 348 | writer = imageio.get_writer(join(results_dir, f'g_an_{i}.gif'), fps=render_fps) 349 | for frame in frames: 350 | writer.append_data(frame) 351 | writer.close() 352 | 353 | render_kwargs['near'] = 0. 354 | 355 | export_model(render_kwargs, focal, join(results_dir, f'g_model_{i}.ply'), join(results_dir, f'g_model_col_{i}.ply'), N=default_N) 356 | print("Time:", round((datetime.now() - start_time).total_seconds(), 2)) 357 | 358 | 359 | dl_iter = iter(dataloader) 360 | 361 | for i in range(args.i_anim_count): 362 | with torch.no_grad(): 363 | 364 | results_dir = join(results_dir_main, f'i{i}') 365 | os.makedirs(results_dir, exist_ok=True) 366 | 367 | full_interpolations = None 368 | start_time = datetime.now() 369 | 370 | entry_1, cat_1, obj_path_1 = next(dl_iter) 371 | entry_2, cat_2, obj_path_2 = next(dl_iter) 372 | 373 | nerf_1_code = get_code(entry_1, encoder) 374 | nerf_2_code = get_code(entry_2, encoder) 375 | print("Generated Object Animation", i) 376 | print(obj_path_1) 377 | print(obj_path_2) 378 | 379 | kwargs_1 = get_render_kwargs(config, nerf, get_nerf_from_code(hnet, nerf_1_code), embed_fn, embeddirs_fn) 380 | kwargs_2 = get_render_kwargs(config, nerf, get_nerf_from_code(hnet, nerf_2_code), embed_fn, embeddirs_fn) 381 | 382 | kwargs_1['perturb'] = False 383 | kwargs_1['raw_noise_std'] = 0. 384 | 385 | kwargs_2['perturb'] = False 386 | kwargs_2['raw_noise_std'] = 0. 
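# Latent-space interpolation: the two point clouds encoded above give nerf_1_code and
# nerf_2_code; at each animation step torch.lerp blends the two codes, the hypernetwork
# maps the blended code to a fresh set of NeRF weights, and the interpolated object is
# rendered between the two endpoint renders (every 5th step is also exported as a mesh
# via marching cubes).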
387 | 388 | steps = render_iterations + 1 389 | 390 | export_model(kwargs_1, focal, join(results_dir, f'i_1_model_{i}.ply'), join(results_dir, f'i_1_model_col_{i}.ply'), N=default_N) 391 | export_model(kwargs_2, focal, join(results_dir, f'i_2_model_{i}.ply'), join(results_dir, f'i_2_model_col_{i}.ply'), N=default_N) 392 | 393 | writer = imageio.get_writer(join(results_dir, f'i_an_{i}.gif'), fps=render_fps) 394 | render_poses = torch.stack([pose_spherical(angle, -45, 3.2) for angle in np.linspace(-180,180,steps)[:-1]], 0) 395 | for k, pose in enumerate(render_poses): 396 | 397 | #Camera is fixed at render_poses[-36]; pass c2w=pose instead to rotate it 398 | img1, disp, acc, _ = render(H, W, K, chunk=config['model']['TN']['netchunk'], c2w=render_poses[-36], 399 | verbose=True, retraw=True, **kwargs_1) 400 | img2, disp, acc, _ = render(H, W, K, chunk=config['model']['TN']['netchunk'], c2w=render_poses[-36], 401 | verbose=True, retraw=True, **kwargs_2) 402 | 403 | nerf_3_code = torch.lerp(nerf_1_code, nerf_2_code, k/steps) 404 | 405 | kwargs_3 = get_render_kwargs(config, nerf, get_nerf_from_code(hnet, nerf_3_code), embed_fn, embeddirs_fn) 406 | kwargs_3['perturb'] = False 407 | kwargs_3['raw_noise_std'] = 0. 408 | 409 | 410 | img3, disp, acc, _ = render(H, W, K, chunk=config['model']['TN']['netchunk'], c2w=render_poses[-36], 411 | verbose=True, retraw=True, **kwargs_3) 412 | 413 | frame = torch.cat([img1,img3,img2], dim=1) 414 | 415 | if k % 5 == 0: 416 | kwargs_3['near'] = 0. 417 | export_model(kwargs_3, focal, join(results_dir, f'interpolated_model_{i}_{k}.ply'), join(results_dir, f'interpolated_model_col_{i}_{k}.ply'), N=default_N) 418 | imageio.imsave(join(results_dir, f'ii_{i}_{k}.png'), to8b(img3.detach().cpu().numpy())) 419 | 420 | writer.append_data(to8b(frame.detach().cpu().numpy())) 421 | writer.close() 422 | 423 | 424 | print("Time:", round((datetime.now() - start_time).total_seconds(), 2)) 425 | 426 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | anyio==3.5.0 3 | argon2-cffi==20.1.0 4 | async-generator==1.10 5 | attrs==21.2.0 6 | Babel==2.9.1 7 | backcall==0.2.0 8 | bleach==4.0.0 9 | certifi==2021.10.8 10 | cffi==1.14.6 11 | charset-normalizer==2.0.12 12 | cycler==0.11.0 13 | debugpy==1.5.1 14 | decorator==5.1.0 15 | defusedxml==0.7.1 16 | deprecation==2.1.0 17 | entrypoints==0.3 18 | fonttools==4.25.0 19 | freetype-py==2.2.0 20 | hypnettorch==0.0.3 21 | idna==3.3 22 | imageio==2.9.0 23 | imageio-ffmpeg==0.4.5 24 | importlib-metadata==4.8.2 25 | ipykernel==6.4.1 26 | ipython==7.29.0 27 | ipython-genutils==0.2.0 28 | ipywidgets==7.6.5 29 | jedi==0.18.0 30 | Jinja2==3.0.2 31 | joblib==1.1.0 32 | json5==0.9.6 33 | jsonschema==3.2.0 34 | jupyter==1.0.0 35 | jupyter-client==7.0.6 36 | jupyter-console==6.4.0 37 | jupyter-core==4.9.1 38 | jupyter-packaging==0.11.1 39 | jupyter-server==1.13.5 40 | jupyterlab==3.2.9 41 | jupyterlab-pygments==0.1.2 42 | jupyterlab-server==2.10.3 43 | jupyterlab-widgets==1.0.0 44 | kiwisolver==1.3.1 45 | MarkupSafe==2.0.1 46 | matplotlib==3.5.0 47 | matplotlib-inline==0.1.2 48 | mistune==0.8.4 49 | mkl-fft 50 | mkl-random 51 | mkl-service 52 | munkres==1.1.4 53 | nbclassic==0.3.5 54 | nbclient==0.5.3 55 | nbconvert==6.1.0 56 | nbformat==5.1.3 57 | nest-asyncio==1.5.1 58 | networkx==2.6.3 59 | ninja==1.10.2.3 60 | notebook==6.4.6 61 | numpy 62 | olefile==0.46 63 | open3d==0.14.1 64 | packaging==21.3 65 | pandas==1.3.5 66 | pandocfilters==1.4.3 67 | parso==0.8.2 68 |
pexpect==4.8.0 69 | pickleshare==0.7.5 70 | Pillow==8.4.0 71 | pip==21.2.4 72 | prometheus-client==0.12.0 73 | prompt-toolkit==3.0.20 74 | ptyprocess==0.7.0 75 | pycparser==2.21 76 | pyglet==1.5.21 77 | Pygments==2.10.0 78 | PyMCubes==0.1.2 79 | PyOpenGL==3.1.0 80 | pyparsing==3.0.4 81 | pyrender==0.1.45 82 | pyrsistent==0.18.0 83 | python-dateutil==2.8.2 84 | pytz==2021.3 85 | PyYAML==6.0 86 | pyzmq==22.3.0 87 | qtconsole==5.1.1 88 | QtPy==1.10.0 89 | requests==2.27.1 90 | scikit-learn==1.0.2 91 | scipy==1.7.3 92 | Send2Trash==1.8.0 93 | setuptools==58.0.4 94 | sip 95 | six==1.16.0 96 | sniffio==1.2.0 97 | terminado==0.9.4 98 | testpath==0.5.0 99 | threadpoolctl==3.1.0 100 | tomlkit==0.10.0 101 | torch==1.10.1 102 | torchvision==0.11.2 103 | tornado==6.1 104 | tqdm==4.62.3 105 | traitlets==5.1.1 106 | trimesh==3.9.40 107 | typing-extensions==3.10.0.2 108 | urllib3==1.26.8 109 | wcwidth==0.2.5 110 | webencodings==0.5.1 111 | websocket-client==1.2.3 112 | wheel==0.37.0 113 | widgetsnbextension==3.5.1 114 | zipp==3.6.0 115 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | import shutil 3 | import torch 4 | import torch.nn as nn 5 | from os import listdir, makedirs, remove 6 | from os.path import exists, join 7 | import glob 8 | import numpy as np 9 | import pandas as pd 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms 12 | 13 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 | 15 | from nerf_helpers import * 16 | 17 | def set_seed(seed: int = 0): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | def get_nerf(entry, encoder, hnet): 24 | points = entry["data"] 25 | points = points.to(device, dtype=torch.float) 26 | 27 | if points.size(-1) == 6: 28 | points.transpose_(points.dim() - 2, points.dim() - 1) 29 | 30 | code, mu, logvar = encoder(points) 31 | nerf_W = hnet(uncond_input=code) 32 | 33 | return nerf_W, mu, logvar 34 | 35 | def get_nerf_resnet(entry, encoder, hnet): 36 | img_i = np.random.choice(24, len(entry['images'])) #get 0..max_imgs random ids for each batch 37 | images = [imgs[i] for imgs, i in zip(entry["images"], img_i)] #get those images 38 | 39 | images = torch.stack(images) 40 | 41 | images = images.to(device, dtype=torch.float) 42 | images.transpose_(1, -1) 43 | code, mu, logvar = encoder(images) 44 | 45 | nerf_W = hnet(uncond_input=code) 46 | 47 | return nerf_W, mu, logvar 48 | 49 | def get_code(entry, encoder): 50 | points = entry["data"] 51 | points = points.to(device, dtype=torch.float) 52 | 53 | if points.size(-1) == 6: 54 | points.transpose_(points.dim() - 2, points.dim() - 1) 55 | 56 | code, mu, logvar = encoder(points) 57 | 58 | return code 59 | 60 | def get_nerf_from_code(hnet, code): 61 | 62 | nerf_W = hnet(uncond_input=code) 63 | return nerf_W 64 | 65 | def get_render_kwargs(config, nerf, nerf_w, embed_fn, embeddirs_fn): 66 | 67 | render_kwargs = { 68 | 'network_query_fn' : lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn, 69 | embed_fn=embed_fn, 70 | embeddirs_fn=embeddirs_fn, 71 | netchunk=config['model']['TN']['netchunk']), 72 | 'perturb' : config['model']['TN']['peturb'], 73 | 'N_importance' : config['model']['TN']['N_importance'], 74 | 'network_fine' : None, 75 | 'N_samples' : config['model']['TN']['N_samples'], 76 | 'network_fn' : lambda x: nerf(x,weights=nerf_w), 77 | 
'use_viewdirs' : config['model']['TN']['use_viewdirs'], 78 | 'white_bkgd' : config['model']['TN']['white_bkgd'], 79 | 'raw_noise_std' : config['model']['TN']['raw_noise_std'], 80 | 'near': 2., 81 | 'far': 6., 82 | 'ndc': False 83 | } 84 | 85 | return render_kwargs 86 | --------------------------------------------------------------------------------
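For quick reference, the sketch below shows how the helpers above are meant to be combined at inference time: sample a latent code, let the hypernetwork produce NeRF weights for it, and render a single view. It is a minimal, hypothetical sketch rather than a script from the repository, mirroring the generation branch of `render_samples.py`. Assumptions (not guaranteed by the repo): `config_cars.json` in the repository root, a trained hypernetwork checkpoint saved as `model_hn_latest.pt` inside the configured `results_dir`, an available CUDA device, the 200x200 dataset renders and their camera angle, and the arbitrary output name `generated_view.png`; the exact module that exports each helper (`nerf_helpers.py` vs `utils.py`) is also assumed.

```python
# Minimal, hypothetical sketch: render one view of a generated object.
# Assumes config_cars.json, a trained model_hn_latest.pt in results_dir, and a CUDA device.
import json
from os.path import join

import imageio
import numpy as np
import torch

from hypnettorch.hnets.chunked_mlp_hnet import ChunkedHMLP
from models.nerf import NeRF
from nerf_helpers import *  # repo helpers: get_embedder, render, pose_spherical, to8b
from utils import *         # repo helpers: get_render_kwargs, get_nerf_from_code

with open('config_cars.json') as f:
    config = json.load(f)

torch.set_default_tensor_type('torch.cuda.FloatTensor')
tn = config['model']['TN']

# Positional/directional embeddings, set up the same way as in pts2nerf.py.
embed_fn, tn['input_ch_embed'] = get_embedder(tn['multires'], tn['i_embed'])
embeddirs_fn, tn['input_ch_views_embed'] = None, 0
if tn['use_viewdirs']:
    embeddirs_fn, tn['input_ch_views_embed'] = get_embedder(tn['multires_views'], tn['i_embed'])

# Target NeRF plus the hypernetwork that emits its weights.
nerf = NeRF(tn['D'], tn['W'], tn['input_ch_embed'], tn['input_ch_views_embed'],
            tn['use_viewdirs']).to('cuda')
hnet = ChunkedHMLP(nerf.param_shapes, uncond_in_size=config['z_size'], cond_in_size=0,
                   layers=config['model']['HN']['arch'],
                   chunk_size=config['model']['HN']['chunk_size'],
                   cond_chunk_embs=False, use_bias=config['model']['HN']['use_bias']).to('cuda')
hnet.load_state_dict(torch.load(join(config['results_dir'], 'model_hn_latest.pt')))
hnet.eval()

with torch.no_grad():
    # Sample a latent code and turn it into a full set of NeRF weights.
    code = torch.normal(mean=torch.zeros(config['z_size']),
                        std=torch.full((config['z_size'],), 0.006))
    render_kwargs = get_render_kwargs(config, nerf, get_nerf_from_code(hnet, code[None]),
                                      embed_fn, embeddirs_fn)
    render_kwargs['perturb'] = False
    render_kwargs['raw_noise_std'] = 0.

    # Same intrinsics as the provided 200x200 dataset renders.
    H = W = 200
    focal = .5 * W / np.tan(.5 * 0.6911112070083618)
    K = np.array([[focal, 0, 0.5 * W], [0, focal, 0.5 * H], [0, 0, 1]])
    pose = pose_spherical(45, -45, 3.2)

    img, _, _, _ = render(H, W, K, chunk=tn['netchunk'], c2w=pose,
                          verbose=True, retraw=True, **render_kwargs)
    imageio.imsave('generated_view.png', to8b(img.detach().cpu().numpy()))
```

To reconstruct a specific object instead of sampling one, encode its colored point cloud with `get_code` (or `get_nerf`) from `utils.py` and pass that code to `get_nerf_from_code` in place of the random sample, which is what `render_samples.py` does for its reconstruction and interpolation animations.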