├── .gitignore
├── README.md
├── __init__.py
├── common
│   ├── __init__.py
│   ├── bpy_util.py
│   ├── domain.py
│   ├── file_util.py
│   ├── input_param_map.py
│   ├── intersection_util.py
│   ├── param_descriptors.py
│   ├── point_cloud_util.py
│   ├── sampling_util.py
│   └── save_obj.py
├── config
│   └── neptune_config_example.yml
├── data
│   ├── __init__.py
│   ├── data_processing.py
│   ├── dataset_pc.py
│   ├── dataset_sketch.py
│   └── dataset_util.py
├── dataset_generator
│   ├── __init__.py
│   ├── base_recipe_generator.py
│   ├── dataset_generator.py
│   ├── recipe_files
│   │   ├── recipe_ceiling_lamp.yml
│   │   ├── recipe_chair.yml
│   │   ├── recipe_table.yml
│   │   └── recipe_vase.yml
│   ├── shape_validators
│   │   ├── __init__.py
│   │   ├── ceiling_lamp_validator.py
│   │   ├── chair_validator.py
│   │   ├── common_validations.py
│   │   ├── shape_validator_factory.py
│   │   ├── shape_validator_interface.py
│   │   ├── table_validator.py
│   │   └── vase_validator.py
│   └── sketch_generator.py
├── environment.yml
├── geocode
│   ├── __init__.py
│   ├── barplot_util.py
│   ├── calculator_accuracy.py
│   ├── calculator_loss.py
│   ├── calculator_util.py
│   ├── geocode.py
│   ├── geocode_model.py
│   ├── geocode_model_alexnet.py
│   ├── geocode_model_resnet.py
│   ├── geocode_test.py
│   ├── geocode_train.py
│   └── geocode_util.py
├── models
│   ├── __init__.py
│   ├── decoder.py
│   ├── dgcnn.py
│   └── vgg.py
├── resources
│   ├── chair_back_frame_mid_y_offset_pct_0_0000_0002.png
│   ├── geo_nodes_button.png
│   ├── geo_nodes_workspace.png
│   ├── geocode_addon.png
│   └── teaser.png
├── scripts
│   ├── __init__.py
│   ├── download_ds.py
│   ├── download_ds_processing_scripts.py
│   └── install_blender4.2.sh
├── setup.py
├── stability_metric
│   ├── __init__.py
│   ├── stability.py
│   ├── stability_parallel.py
│   └── stability_simulation.blend
└── visualize_results
    └── visualize.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | 27 | .DS_Store 28 | datasets/ 29 | *.zip.
30 | *.zip 31 | .idea/ 32 | Logs/ 33 | cls/ 34 | slurm-*.out 35 | node_modules/ 36 | additional/ 37 | neptune_config.yml 38 | neptune_session.json 39 | .neptune/ 40 | .vscode/ 41 | *.egg_inf 42 | lightning_logs/ 43 | stability_results.json 44 | dataset_processing/ 45 | docs/ 46 | *.slurm 47 | *.log 48 | *.out 49 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/__init__.py -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/common/__init__.py -------------------------------------------------------------------------------- /common/bpy_util.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import math 3 | from mathutils import Vector 4 | from typing import Union 5 | from pathlib import Path 6 | import subprocess 7 | 8 | 9 | def save_obj(target_obj_file_path: Union[Path, str], additional_objs_to_save=None, simplification_ratio=None): 10 | """ 11 | save the object and returns a mesh duplicate version of it 12 | """ 13 | obj = select_shape() 14 | refresh_obj_in_viewport(obj) 15 | dup_obj = copy(obj) 16 | # set active 17 | bpy.ops.object.select_all(action='DESELECT') 18 | dup_obj.select_set(True) 19 | bpy.context.view_layer.objects.active = dup_obj 20 | # apply the modifier to turn the geometry node to a mesh 21 | bpy.ops.object.modifier_apply(modifier="GeometryNodes") 22 | if simplification_ratio and simplification_ratio < 1.0: 23 | bpy.ops.object.modifier_add(type='DECIMATE') 24 | dup_obj.modifiers["Decimate"].decimate_type = 'COLLAPSE' 25 | dup_obj.modifiers["Decimate"].ratio = simplification_ratio 26 | bpy.ops.object.modifier_apply(modifier="Decimate") 27 | assert dup_obj.type == 'MESH' 28 | bpy.ops.object.transform_apply(location=True, rotation=True, scale=True) 29 | # set origin to center of bounding box 30 | bpy.ops.object.origin_set(type='ORIGIN_GEOMETRY', center='BOUNDS') 31 | dup_obj.location.x = dup_obj.location.y = dup_obj.location.z = 0 32 | normalize_scale(dup_obj) 33 | if additional_objs_to_save: 34 | for additional_obj in additional_objs_to_save: 35 | additional_obj.select_set(True) 36 | # save 37 | bpy.ops.wm.obj_export(filepath=str(target_obj_file_path), export_selected_objects=True, export_materials=False, export_triangulated_mesh=True) 38 | return dup_obj 39 | 40 | 41 | def get_geometric_nodes_modifier(obj): 42 | # loop through all modifiers of the given object 43 | gnodes_mod = None 44 | for modifier in obj.modifiers: 45 | # check if current modifier is the geometry nodes modifier 46 | if modifier.type == "NODES": 47 | gnodes_mod = modifier 48 | break 49 | return gnodes_mod 50 | 51 | 52 | def normalize_scale(obj): 53 | obj.select_set(True) 54 | bpy.context.view_layer.objects.active = obj 55 | bpy.ops.object.transform_apply(location=True, rotation=True, scale=True) 56 | # set origin to the center of the bounding box 57 | bpy.ops.object.origin_set(type='ORIGIN_GEOMETRY', center='BOUNDS') 58 | obj.location.x = 0 59 | obj.location.y = 0 60 | obj.location.z = 0 61 | bpy.ops.object.transform_apply(location=True, rotation=True, scale=True) 62 | max_vert_dist = math.sqrt(max([v.co.dot(v.co) for 
v in obj.data.vertices])) 63 | for v in obj.data.vertices: 64 | v.co /= max_vert_dist 65 | bpy.ops.object.transform_apply(location=True, rotation=True, scale=True) 66 | # verify that the shape is normalized 67 | # max_vert_dist = math.sqrt(max([v.co.dot(v.co) for v in obj.data.vertices])) 68 | # assert abs(max_vert_dist - 1.0) < 0.01 69 | 70 | 71 | def setup_lights(): 72 | """ 73 | setup lights for rendering 74 | used for visualization of 3D objects as images 75 | """ 76 | scene = bpy.context.scene 77 | # light 1 78 | light_data_1 = bpy.data.lights.new(name="light_data_1", type='POINT') 79 | light_data_1.energy = 300 80 | light_object_1 = bpy.data.objects.new(name="Light_1", object_data=light_data_1) 81 | light_object_1.location = Vector((10, -10, 10)) 82 | scene.collection.objects.link(light_object_1) 83 | # light 2 84 | light_data_2 = bpy.data.lights.new(name="light_data_2", type='POINT') 85 | light_data_2.energy = 300 86 | light_object_2 = bpy.data.objects.new(name="Light_2", object_data=light_data_2) 87 | light_object_2.location = Vector((-10, -10, 10)) 88 | scene.collection.objects.link(light_object_2) 89 | # light 3 90 | light_data_3 = bpy.data.lights.new(name="light_data_3", type='POINT') 91 | light_data_3.energy = 300 92 | light_object_3 = bpy.data.objects.new(name="Light_3", object_data=light_data_3) 93 | light_object_3.location = Vector((10, 0, 10)) 94 | scene.collection.objects.link(light_object_3) 95 | 96 | 97 | def look_at(obj_camera, point): 98 | """ 99 | orient the given camera with a fixed position to loot at a given point in space 100 | """ 101 | loc_camera = obj_camera.matrix_world.to_translation() 102 | direction = point - loc_camera 103 | # point the cameras '-Z' and use its 'Y' as up 104 | rot_quat = direction.to_track_quat('-Z', 'Y') 105 | obj_camera.rotation_euler = rot_quat.to_euler() 106 | 107 | 108 | def clean_scene(start_with_strings=["Camera", "procedural", "Light"]): 109 | """ 110 | delete all object of which the name's prefix is matching any of the given strings 111 | """ 112 | scene = bpy.context.scene 113 | bpy.ops.object.select_all(action='DESELECT') 114 | for obj in scene.objects: 115 | if any([obj.name.startswith(starts_with_string) for starts_with_string in start_with_strings]): 116 | # select the object 117 | if obj.visible_get(): 118 | obj.select_set(True) 119 | bpy.ops.object.delete() 120 | 121 | 122 | def del_obj(obj): 123 | bpy.ops.object.select_all(action='DESELECT') 124 | obj.select_set(True) 125 | bpy.ops.object.delete() 126 | 127 | 128 | def refresh_obj_in_viewport(obj): 129 | # the following two line cause the object to update according to the new geometric nodes input 130 | obj.show_bounds = not obj.show_bounds 131 | obj.show_bounds = not obj.show_bounds 132 | obj.data.update() 133 | 134 | def select_objs(*objs): 135 | bpy.ops.object.select_all(action='DESELECT') 136 | for i, obj in enumerate(objs): 137 | if i == 0: 138 | bpy.context.view_layer.objects.active = obj 139 | obj.select_set(True) 140 | 141 | 142 | def select_obj(obj): 143 | select_objs(obj) 144 | 145 | 146 | def select_shape(): 147 | """ 148 | select the procedural shape in the blend file 149 | note that in all our domains, the procedural shape is named "procedural shape" within the blend file 150 | """ 151 | obj = bpy.data.objects["procedural shape"] 152 | select_obj(obj) 153 | return obj 154 | 155 | 156 | def copy(obj): 157 | dup_obj = obj.copy() 158 | dup_obj.data = obj.data.copy() 159 | dup_obj.animation_data_clear() 160 | bpy.context.collection.objects.link(dup_obj) 161 | 
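    # editor's note (clarifying comment, not part of the original source): obj.copy() on its
    # own returns an object that still shares the original mesh datablock, which is why
    # obj.data.copy() is assigned above; linking the duplicate into the active collection is
    # what makes it part of the scene so the bpy.ops operators used elsewhere in this module
    # can select and act on it.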
return dup_obj 162 | 163 | 164 | def use_gpu_if_available(): 165 | """ 166 | allow Blender to use all available GPUs 167 | """ 168 | try: 169 | subprocess.check_output('nvidia-smi') 170 | print('Nvidia GPU detected!') 171 | except Exception: 172 | print('No Nvidia GPU available!') 173 | return 174 | bpy.data.scenes['Scene'].render.engine = "CYCLES" 175 | # set the device_type 176 | bpy.context.preferences.addons["cycles"].preferences.compute_device_type = "CUDA" 177 | # set device to GPU 178 | bpy.context.scene.cycles.device = "GPU" 179 | # get_devices detects GPU devices 180 | bpy.context.preferences.addons["cycles"].preferences.get_devices() 181 | print(bpy.context.preferences.addons["cycles"].preferences.compute_device_type) 182 | for d in bpy.context.preferences.addons["cycles"].preferences.devices: 183 | d["use"] = 1 # using all devices, include GPU and CPU 184 | print(d["name"], d["use"]) 185 | -------------------------------------------------------------------------------- /common/domain.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class Domain(Enum): 4 | chair = 'chair' 5 | vase = 'vase' 6 | table = 'table' 7 | ceiling_lamp = 'ceiling_lamp' 8 | 9 | def __str__(self): 10 | return self.value 11 | -------------------------------------------------------------------------------- /common/file_util.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import hashlib 3 | import numpy as np 4 | from pathlib import Path 5 | from typing import Union 6 | 7 | 8 | def save_yml(yml_obj, target_yml_file_path): 9 | with open(target_yml_file_path, 'w') as target_yml_file: 10 | yaml.dump(yml_obj, target_yml_file, sort_keys=False, width=1000) 11 | 12 | 13 | def get_source_recipe_file_path(domain): 14 | """ 15 | get the path to the recipe file path that is found in the source code under the directory "recipe_files" 16 | """ 17 | return Path(__file__).parent.joinpath('..', 'dataset_generator', 'recipe_files', f'recipe_{domain}.yml').resolve() 18 | 19 | 20 | def hash_file_name(file_name): 21 | return int(hashlib.sha1(file_name.encode("utf-8")).hexdigest(), 16) % (10 ** 8) 22 | 23 | 24 | def get_recipe_yml_obj(recipe_file_path: Union[str, Path]): 25 | with open(recipe_file_path, 'r') as recipe_file: 26 | recipe_yml_obj = yaml.load(recipe_file, Loader=yaml.FullLoader) 27 | return recipe_yml_obj 28 | 29 | 30 | def load_obj(file: str): 31 | vs, faces = [], [] 32 | f = open(file) 33 | for line in f: 34 | line = line.strip() 35 | split_line = line.split() 36 | if not split_line: 37 | continue 38 | elif split_line[0] == 'v': 39 | vs.append([float(v) for v in split_line[1:4]]) 40 | elif split_line[0] == 'f': 41 | face_vertex_ids = [int(c.split('/')[0]) for c in split_line[1:]] 42 | assert len(face_vertex_ids) == 3 43 | face_vertex_ids = [(ind - 1) if (ind >= 0) else (len(vs) + ind) 44 | for ind in face_vertex_ids] 45 | faces.append(face_vertex_ids) 46 | f.close() 47 | vs = np.asarray(vs) 48 | faces = np.asarray(faces, dtype=np.int64) 49 | assert np.logical_and(faces >= 0, faces < len(vs)).all() 50 | return vs, faces 51 | -------------------------------------------------------------------------------- /common/input_param_map.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import random 3 | import traceback 4 | import numpy as np 5 | from typing import List, Optional 6 | from dataclasses import dataclass 7 | from bpy.types import 
NodeInputs, Modifier 8 | from common.bpy_util import select_shape, refresh_obj_in_viewport, get_geometric_nodes_modifier 9 | 10 | 11 | @dataclass 12 | class InputParam: 13 | gnodes_mod: Modifier 14 | input: NodeInputs 15 | axis: Optional[str] # None indicates that this is not a vector 16 | possible_values: List 17 | 18 | def assign_random_value(self): 19 | self.assign_value(random.choice(self.possible_values)) 20 | 21 | def assign_value(self, val): 22 | assert val in self.possible_values 23 | input_type = self.input.bl_label 24 | identifier = self.input.identifier 25 | if input_type == "Float": 26 | self.gnodes_mod[identifier] = val 27 | 28 | if input_type == "Integer": 29 | self.gnodes_mod[identifier] = int(val) 30 | 31 | if input_type == "Boolean": 32 | self.gnodes_mod[identifier] = int(val) 33 | 34 | if input_type == "Vector": 35 | axis_idx = ['x', 'y', 'z'].index(self.axis) 36 | self.gnodes_mod[identifier][axis_idx] = val 37 | 38 | def get_value(self): 39 | identifier = self.input.identifier 40 | if self.axis: 41 | axis_idx = ['x', 'y', 'z'].index(self.axis) 42 | return self.gnodes_mod[identifier][axis_idx] 43 | return self.gnodes_mod[identifier] 44 | 45 | def get_name_for_file(self): 46 | res = str(self.input.name) + ("" if not self.axis else "_" + self.axis) 47 | return res.replace(" ", "_") 48 | 49 | 50 | def get_input_values(gnodes_mod, input, yml_gen_rule): 51 | min_value = None 52 | max_value = None 53 | param_name = input.name 54 | if input.bl_label != 'Boolean': 55 | # TODO(ofekp): find something simpler that works for Blender4.2, input.min_value no longer works 56 | min_value = gnodes_mod.node_group.interface.items_tree[param_name].min_value 57 | max_value = gnodes_mod.node_group.interface.items_tree[param_name].max_value 58 | # override min and max with requested values from recipe yml file 59 | if 'min' in yml_gen_rule: 60 | requested_min_value = yml_gen_rule['min'] 61 | if min_value and requested_min_value < min_value: 62 | if abs(min_value - requested_min_value) > 1e-6: 63 | raise Exception( 64 | f'Requested a min value of [{requested_min_value}] for parameter [{input.name}], but min allowed is [{min_value}]') 65 | # otherwise min_value should remain input.min_value 66 | else: 67 | min_value = requested_min_value 68 | if 'max' in yml_gen_rule: 69 | requested_max_value = yml_gen_rule['max'] 70 | if max_value and requested_max_value > max_value: 71 | if abs(max_value - requested_max_value) > 1e-6: 72 | raise Exception( 73 | f'Requested a max value of [{requested_max_value}] for parameter [{input.name}], but max allowed is [{max_value}]') 74 | # otherwise max_value should remain input.max_value 75 | max_value = requested_max_value 76 | step = 1 if 'samples' not in yml_gen_rule else calculate_step(min_value, max_value, yml_gen_rule['samples']) 77 | res = np.arange(min_value, max_value + 1e-6, step) 78 | 79 | # convert to integers if needed 80 | if input.bl_label in ['Boolean', 'Integer']: 81 | res = list(res.astype(int)) 82 | else: 83 | res = [round(x, 4) for x in list(res)] 84 | 85 | return res 86 | 87 | 88 | def calculate_step(min_value, max_value, samples): 89 | return (max_value - min_value) / (samples - 1) 90 | 91 | 92 | def get_input_param_map(gnodes_mod, yml): 93 | input_params_map = {} 94 | # loops through all the inputs in the geometric node group 95 | group_input_nodes = [node for node in gnodes_mod.node_group.nodes if node.type == 'GROUP_INPUT'] 96 | assert len(group_input_nodes) > 0 97 | group_input_node = group_input_nodes[0] 98 | param_names = [input.name for 
input in group_input_node.outputs if len(input.name) > 0] 99 | for param_name in yml['dataset_generation']: 100 | if param_name not in param_names: 101 | raise Exception(f"Parameter named [{param_name}] was not found in geometry nodes input group.") 102 | for input in group_input_node.outputs: 103 | param_name = str(input.name) 104 | if len(param_name) == 0: 105 | continue 106 | # we only change inputs that are explicitly noted in the yaml object 107 | if param_name in yml['dataset_generation']: 108 | param_gen_rule = yml['dataset_generation'][param_name] 109 | if 'x' in param_gen_rule or 'y' in param_gen_rule or 'z' in param_gen_rule: 110 | # vector handling 111 | for idx, axis in enumerate(['x', 'y', 'z']): 112 | if not axis in param_gen_rule: 113 | continue 114 | curr_param_values = get_input_values(gnodes_mod, input, param_gen_rule[axis]) 115 | input_params_map[f"{param_name} {axis}"] = InputParam(gnodes_mod, input, axis, curr_param_values) 116 | else: 117 | curr_param_values = get_input_values(gnodes_mod, input, param_gen_rule) 118 | input_params_map[param_name] = InputParam(gnodes_mod, input, None, curr_param_values) 119 | return input_params_map 120 | 121 | 122 | def yml_to_shape(shape_yml_obj, input_params_map, ignore_sanity_check=False): 123 | try: 124 | # select the object in blender 125 | obj = select_shape() 126 | # get the geometric nodes modifier fo the object 127 | gnodes_mod = get_geometric_nodes_modifier(obj) 128 | 129 | # loops through all the inputs in the geometric node group 130 | group_input_nodes = [node for node in gnodes_mod.node_group.nodes if node.type == 'GROUP_INPUT'] 131 | assert len(group_input_nodes) > 0 132 | group_input_node = group_input_nodes[0] 133 | for input in group_input_node.outputs: 134 | param_name = str(input.name) 135 | if len(param_name) == 0: 136 | continue 137 | if param_name not in shape_yml_obj: 138 | continue 139 | param_val = shape_yml_obj[param_name] 140 | if hasattr(param_val, '__iter__'): 141 | # vector handling 142 | for axis_idx, axis in enumerate(['x', 'y', 'z']): 143 | val = param_val[axis] 144 | val = round(val, 4) 145 | param_name_with_axis = f'{param_name} {axis}' 146 | gnodes_mod[input.identifier][axis_idx] = val if abs(val + 1.0) > 0.1 else input_params_map[param_name_with_axis].possible_values[0].item() 147 | assert gnodes_mod[input.identifier][axis_idx] >= 0.0 148 | else: 149 | param_val = round(param_val, 4) 150 | if not ignore_sanity_check: 151 | err_msg = f'param_name [{param_name}] param_val [{param_val}] possible_values {input_params_map[param_name].possible_values}' 152 | assert param_val == -1 or (param_val in input_params_map[param_name].possible_values), err_msg 153 | gnodes_mod[input.identifier] = param_val if (abs(param_val + 1.0) > 0.1) else (input_params_map[param_name].possible_values[0].item()) 154 | # we assume that all input values are non-negative 155 | assert gnodes_mod[input.identifier] >= 0.0 156 | 157 | refresh_obj_in_viewport(obj) 158 | except Exception as e: 159 | print(repr(e)) 160 | print(traceback.format_exc()) 161 | 162 | 163 | def load_shape_from_yml(yml_file_path, input_params_map, ignore_sanity_check=False): 164 | with open(yml_file_path, 'r') as f: 165 | yml_obj = yaml.load(f, Loader=yaml.FullLoader) 166 | yml_to_shape(yml_obj, input_params_map, ignore_sanity_check=ignore_sanity_check) 167 | 168 | 169 | def load_base_shape_from_yml(recipe_file_path, input_params_map): 170 | print(f'Loading the base shape from the YML file [{recipe_file_path}]') 171 | 172 | with open(recipe_file_path, 'r') as 
f: 173 | yml_obj = yaml.load(f, Loader=yaml.FullLoader) 174 | 175 | yml_to_shape(yml_obj['base'], input_params_map) 176 | 177 | 178 | def randomize_all_params(input_params_map): 179 | param_values_map = {} 180 | for param_name, input_param in input_params_map.items(): 181 | param_values_map[param_name] = random.choice(input_param.possible_values) 182 | return param_values_map 183 | -------------------------------------------------------------------------------- /common/intersection_util.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import array 3 | import bmesh 4 | import mathutils 5 | from common.bpy_util import select_objs, select_shape, refresh_obj_in_viewport 6 | 7 | 8 | # refere to https://github.com/blender/blender-addons/blob/main/object_print3d_utils/mesh_helpers.py 9 | # in Blender 4.2 we cannot use `from object_print3d_utils import mesh_helpers` 10 | def bmesh_copy_from_object(obj, transform=True, triangulate=True, apply_modifiers=False): 11 | """Returns a transformed, triangulated copy of the mesh""" 12 | 13 | assert obj.type == 'MESH' 14 | 15 | if apply_modifiers and obj.modifiers: 16 | import bpy 17 | depsgraph = bpy.context.evaluated_depsgraph_get() 18 | obj_eval = obj.evaluated_get(depsgraph) 19 | me = obj_eval.to_mesh() 20 | bm = bmesh.new() 21 | bm.from_mesh(me) 22 | obj_eval.to_mesh_clear() 23 | else: 24 | me = obj.data 25 | if obj.mode == 'EDIT': 26 | bm_orig = bmesh.from_edit_mesh(me) 27 | bm = bm_orig.copy() 28 | else: 29 | bm = bmesh.new() 30 | bm.from_mesh(me) 31 | 32 | # TODO. remove all customdata layers. 33 | # would save ram 34 | 35 | if transform: 36 | matrix = obj.matrix_world.copy() 37 | if not matrix.is_identity: 38 | bm.transform(matrix) 39 | # Update normals if the matrix has no rotation. 
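            # editor's note (clarifying comment, not part of the original source): translation
            # never changes normals, so the translation component is zeroed first; only if the
            # remaining rotation/scale part of the matrix is still non-identity does
            # bm.normal_update() below need to recompute the normals of the just-transformed mesh.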
40 | matrix.translation.zero() 41 | if not matrix.is_identity: 42 | bm.normal_update() 43 | 44 | if triangulate: 45 | bmesh.ops.triangulate(bm, faces=bm.faces) 46 | 47 | return bm 48 | 49 | 50 | def isolate_node_as_final_geometry(obj, node_label): 51 | gm = obj.modifiers.get("GeometryNodes") 52 | group_output_node = None 53 | node_to_isolate = None 54 | for n in gm.node_group.nodes: 55 | # print(n.name), print(n.type), print(dir(n)) 56 | if n.type == 'GROUP_OUTPUT': 57 | group_output_node = n 58 | elif n.label == node_label: 59 | node_to_isolate = n 60 | if not node_to_isolate: 61 | raise Exception(f"Did not find any node with the label [{node_label}]") 62 | 63 | realize_instances_node = group_output_node.inputs[0].links[0].from_node 64 | third_to_last_node = realize_instances_node.inputs[0].links[0].from_node 65 | third_to_last_node_socket = None 66 | # to later revert this operation, we need to find the socket which is currently connected 67 | # this happens since the SWITCH node has multiple options, and each option translates to 68 | # a different output socket in the node (so there isn't just one socket as you would think) 69 | for i, socket in enumerate(third_to_last_node.outputs): 70 | if socket.is_linked: 71 | third_to_last_node_socket = i 72 | break 73 | node_group = next(m for m in obj.modifiers if m.type == 'NODES').node_group 74 | # find the output socket that actually is connected to something, 75 | # we do this since some nodes have multiple output sockets 76 | out_socket_idx = 0 77 | for out_socket_idx, out_socket in enumerate(node_to_isolate.outputs): 78 | if out_socket.is_linked: 79 | break 80 | node_group.links.new(node_to_isolate.outputs[out_socket_idx], realize_instances_node.inputs[0]) 81 | def revert(): 82 | node_group.links.new(third_to_last_node.outputs[third_to_last_node_socket], realize_instances_node.inputs[0]) 83 | refresh_obj_in_viewport(obj) 84 | return revert 85 | 86 | 87 | def detect_self_intersection(obj): 88 | """ 89 | refer to: 90 | https://blenderartists.org/t/self-intersection-detection/671080 91 | documentation of the intersection detection method 92 | https://docs.blender.org/api/current/mathutils.bvhtree.html 93 | """ 94 | if not obj.data.polygons: 95 | return array.array('i', ()) 96 | 97 | bm = bmesh_copy_from_object(obj, transform=False, triangulate=False) # mesh_helpers 98 | tree = mathutils.bvhtree.BVHTree.FromBMesh(bm, epsilon=0.00001) 99 | 100 | overlap = tree.overlap(tree) 101 | faces_error = {i for i_pair in overlap for i in i_pair} 102 | return array.array('i', faces_error) 103 | 104 | 105 | def find_self_intersections(node_label): 106 | # intersection detection 107 | chair = select_shape() 108 | revert_isolation = isolate_node_as_final_geometry(chair, node_label) 109 | 110 | dup_obj = chair.copy() 111 | dup_obj.data = chair.data.copy() 112 | dup_obj.animation_data_clear() 113 | bpy.context.collection.objects.link(dup_obj) 114 | # move for clarity 115 | dup_obj.location.x += 2.0 116 | 117 | # set active 118 | bpy.ops.object.select_all(action='DESELECT') 119 | dup_obj.select_set(True) 120 | bpy.context.view_layer.objects.active = dup_obj 121 | # apply the modifier to turn the geometry node to a mesh 122 | bpy.ops.object.modifier_apply(modifier="GeometryNodes") 123 | assert dup_obj.type == 'MESH' 124 | 125 | intersections = detect_self_intersection(dup_obj) 126 | 127 | # delete the duplicate 128 | bpy.ops.object.delete() 129 | 130 | revert_isolation() 131 | 132 | # reselect the original object 133 | select_shape() 134 | 135 | return 
len(intersections) 136 | 137 | 138 | def detect_cross_intersection(obj1, obj2): 139 | if not obj1.data.polygons or not obj2.data.polygons: 140 | return array.array('i', ()) 141 | 142 | bm1 = bmesh_copy_from_object(obj1, transform=False, triangulate=False) # mesh_helpers 143 | tree1 = mathutils.bvhtree.BVHTree.FromBMesh(bm1, epsilon=0.00001) 144 | bm2 = bmesh_copy_from_object(obj2, transform=False, triangulate=False) # mesh_helpers 145 | tree2 = mathutils.bvhtree.BVHTree.FromBMesh(bm2, epsilon=0.00001) 146 | 147 | overlap = tree1.overlap(tree2) 148 | faces_error = {i for i_pair in overlap for i in i_pair} 149 | return array.array('i', faces_error) 150 | 151 | 152 | def find_cross_intersections(node_label1, node_label2): 153 | # intersection detection 154 | chair = select_shape() 155 | revert_isolation = isolate_node_as_final_geometry(chair, node_label1) 156 | 157 | dup_obj1 = chair.copy() 158 | dup_obj1.data = chair.data.copy() 159 | dup_obj1.animation_data_clear() 160 | bpy.context.collection.objects.link(dup_obj1) 161 | # move for clarity 162 | dup_obj1.location.x += 2.0 163 | # set active 164 | bpy.ops.object.select_all(action='DESELECT') 165 | dup_obj1.select_set(True) 166 | bpy.context.view_layer.objects.active = dup_obj1 167 | # apply the modifier to turn the geometry node to a mesh 168 | bpy.ops.object.modifier_apply(modifier="GeometryNodes") 169 | # export the object 170 | assert dup_obj1.type == 'MESH' 171 | 172 | revert_isolation() 173 | 174 | chair = select_shape() 175 | revert_isolation = isolate_node_as_final_geometry(chair, node_label2) 176 | 177 | dup_obj2 = chair.copy() 178 | dup_obj2.data = chair.data.copy() 179 | dup_obj2.animation_data_clear() 180 | bpy.context.collection.objects.link(dup_obj2) 181 | # move for clarity 182 | dup_obj2.location.x += 2.0 183 | # set active 184 | bpy.ops.object.select_all(action='DESELECT') 185 | dup_obj2.select_set(True) 186 | bpy.context.view_layer.objects.active = dup_obj2 187 | # apply the modifier to turn the geometry node to a mesh 188 | bpy.ops.object.modifier_apply(modifier="GeometryNodes") 189 | # export the object 190 | assert dup_obj2.type == 'MESH' 191 | 192 | revert_isolation() 193 | 194 | intersections = detect_cross_intersection(dup_obj1, dup_obj2) 195 | 196 | # delete the duplicate 197 | select_objs(dup_obj1, dup_obj2) 198 | bpy.ops.object.delete() 199 | 200 | # reselect the original object 201 | select_shape() 202 | 203 | return len(intersections) 204 | -------------------------------------------------------------------------------- /common/param_descriptors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dataclasses import dataclass 3 | from collections import OrderedDict 4 | from typing import List, Optional, Dict 5 | 6 | 7 | arithmetic_symbols = ['and', 'or', 'not', '(', ')', '<', '<=' , '>', '>=', '==', '-', '+', '/', '*'] 8 | 9 | def isfloat(num): 10 | try: 11 | float(num) 12 | return True 13 | except ValueError: 14 | return False 15 | 16 | 17 | @dataclass 18 | class ParamDescriptor: 19 | input_type: str # ['Integer', 'Boolean', 'Float', 'Vector'] 20 | num_classes: int # including visibility class (-1) if exists 21 | step: float 22 | classes: np.ndarray # e.g. [-1, 5, 6, 7, 8] 23 | normalized_classes: np.ndarray # e.g. 
[-1, 0, 1, 2, 3] 24 | min_val: float 25 | max_val: float 26 | visibility_condition: str 27 | is_regression: bool 28 | normalized_acc_threshold: float 29 | 30 | def is_visible(self, param_values_map): 31 | assert param_values_map 32 | if not self.visibility_condition: 33 | return True 34 | is_visible_cond = " ".join([(word if (word in arithmetic_symbols or isfloat(word) or word.isnumeric()) else (f"param_values_map[\"{word}\"] == 1" if 'is_' in word else f"param_values_map[\"{word}\"]")) for word in self.visibility_condition.split(" ")]) 35 | return eval(is_visible_cond) 36 | 37 | 38 | class ParamDescriptors: 39 | def __init__(self, recipe_yml_obj, inputs_to_eval, use_regression=False, train_with_visibility_label=True): 40 | self.epsilon = 1e-6 41 | self.recipe_yml_obj = recipe_yml_obj 42 | self.inputs_to_eval = inputs_to_eval 43 | self.use_regression = use_regression 44 | self.train_with_visibility_label = train_with_visibility_label 45 | self.__overall_num_of_classes_without_visibility_label = 0 46 | self.param_descriptors_map: Optional[Dict[str, ParamDescriptor]] = None 47 | self.__constraints: Optional[List[str]] = None 48 | 49 | def check_constraints(self, param_values_map): 50 | assert param_values_map 51 | for constraint in self.get_constraints(): 52 | is_fulfilled = " ".join([(word if (word in arithmetic_symbols or isfloat(word) or word.isnumeric()) else (f"param_values_map[\"{word}\"] == 1" if 'is_' in word else f"param_values_map[\"{word}\"]")) for word in constraint.split(" ")]) 53 | if not eval(is_fulfilled): 54 | return False 55 | return True 56 | 57 | def get_constraints(self): 58 | if self.__constraints: 59 | return self.__constraints 60 | self.__constraints = [] 61 | if 'constraints' in self.recipe_yml_obj: 62 | for constraint_name, constraint in self.recipe_yml_obj['constraints'].items(): 63 | self.__constraints.append(constraint) 64 | return self.__constraints 65 | 66 | def get_param_descriptors_map(self): 67 | if self.param_descriptors_map: 68 | return self.param_descriptors_map 69 | recipe_yml_obj = self.recipe_yml_obj # for readability 70 | param_descriptors_map = OrderedDict() 71 | visibility_conditions = {} 72 | if 'visibility_conditions' in recipe_yml_obj: 73 | visibility_conditions = recipe_yml_obj['visibility_conditions'] 74 | 75 | for i, param_name in enumerate(self.inputs_to_eval): 76 | is_regression = False 77 | normalized_acc_threshold = None 78 | if " x" in param_name or " y" in param_name or " z" in param_name: 79 | input_type = recipe_yml_obj['data_types'][param_name[:-2]]['type'] 80 | else: 81 | input_type = recipe_yml_obj['data_types'][param_name]['type'] 82 | if input_type == 'Integer' or input_type == 'Boolean': 83 | max_val = recipe_yml_obj['dataset_generation'][param_name]['max'] 84 | min_val = recipe_yml_obj['dataset_generation'][param_name]['min'] 85 | step = 1 86 | num_classes = max_val - min_val + step 87 | self.__overall_num_of_classes_without_visibility_label += num_classes 88 | classes = np.arange(min_val, max_val + self.epsilon, step) 89 | normalized_classes = classes - min_val 90 | # visibility label adjustment 91 | if self.train_with_visibility_label: 92 | for vis_cond_name, vis_cond in visibility_conditions.items(): 93 | if vis_cond_name in param_name: 94 | num_classes += 1 95 | classes = np.concatenate((np.array([-1.0]), classes)) 96 | normalized_classes = np.concatenate((np.array([-1.0]), normalized_classes)) 97 | break 98 | elif input_type == 'Float': 99 | max_val = recipe_yml_obj['dataset_generation'][param_name]['max'] 100 | min_val = 
recipe_yml_obj['dataset_generation'][param_name]['min'] 101 | samples = recipe_yml_obj['dataset_generation'][param_name]['samples'] 102 | step, num_classes, classes, normalized_classes, is_regression, normalized_acc_threshold \ 103 | = self._handle_float(param_name, samples, min_val, max_val, visibility_conditions) 104 | elif input_type == 'Vector': 105 | axis = param_name[-1] 106 | param_name_no_axis = param_name[:-2] 107 | max_val = recipe_yml_obj['dataset_generation'][param_name_no_axis][axis]['max'] 108 | min_val = recipe_yml_obj['dataset_generation'][param_name_no_axis][axis]['min'] 109 | samples = recipe_yml_obj['dataset_generation'][param_name_no_axis][axis]['samples'] 110 | step, num_classes, classes, normalized_classes, is_regression, normalized_acc_threshold \ 111 | = self._handle_float(param_name_no_axis, samples, min_val, max_val, visibility_conditions) 112 | else: 113 | raise Exception(f'Input type [{input_type}] is not supported yet') 114 | 115 | visibility_condition = None 116 | for vis_cond_name, vis_cond in visibility_conditions.items(): 117 | if vis_cond_name in param_name: 118 | visibility_condition = vis_cond 119 | break 120 | param_descriptors_map[param_name] = ParamDescriptor(input_type, num_classes, step, classes, 121 | normalized_classes, min_val, max_val, 122 | visibility_condition, is_regression, 123 | normalized_acc_threshold) 124 | self.param_descriptors_map = param_descriptors_map 125 | return self.param_descriptors_map 126 | 127 | def _handle_float(self, param_name, samples, min_val, max_val, visibility_conditions): 128 | """ 129 | :param param_name: the parameter name, if the parameter is a vector, the axis should be omitted 130 | :param samples: the number of samples requested in the recipe file 131 | :param min_val: the min value allowed in the recipe file 132 | :param max_val: the max value allowed in the recipe file 133 | :param visibility_conditions: visibility conditions from the recipe file 134 | :return: step, num_classes, classes, normalized_classes, is_regression, normalized_acc_threshold 135 | """ 136 | is_regression = False 137 | normalized_acc_threshold = None 138 | if not self.use_regression: 139 | step = (max_val - min_val) / (samples - 1) 140 | classes = np.arange(min_val, max_val + self.epsilon, step) 141 | classes = classes.astype(np.float64) 142 | normalized_classes = (classes - min_val) / (max_val - min_val) 143 | normalized_classes = normalized_classes.astype(np.float64) 144 | num_classes = classes.shape[0] 145 | self.__overall_num_of_classes_without_visibility_label += num_classes 146 | # visibility label adjustment 147 | if self.train_with_visibility_label: 148 | for vis_cond_name, vis_cond in visibility_conditions.items(): 149 | if vis_cond_name in param_name: 150 | num_classes += 1 151 | classes = np.concatenate((np.array([-1.0]), classes)) 152 | normalized_classes = np.concatenate((np.array([-1.0]), normalized_classes)) 153 | break 154 | else: 155 | step = 0 156 | num_classes = 2 # one for prediction and one for visibility label 157 | classes = None 158 | normalized_classes = None 159 | is_regression = True 160 | normalized_acc_threshold = 1 / (2 * (samples - 1)) 161 | return step, num_classes, classes, normalized_classes, is_regression, normalized_acc_threshold 162 | 163 | 164 | def convert_prediction_vector_to_map(self, pred_vector, use_regression=False): 165 | """ 166 | :param pred_vector: predicted vector from the network 167 | :param use_regression: whether we use regression for float values 168 | :return: map object 
representing the shape 169 | """ 170 | pred_vector = pred_vector.squeeze() 171 | assert len(pred_vector.shape) == 1 172 | shape_map = {} 173 | idx = 0 174 | param_descriptors_map = self.get_param_descriptors_map() 175 | for param_name in self.inputs_to_eval: 176 | param_descriptor = param_descriptors_map[param_name] 177 | input_type = param_descriptor.input_type 178 | classes = param_descriptor.classes 179 | num_classes = param_descriptor.num_classes 180 | if input_type == 'Float' or input_type == 'Vector': 181 | if not use_regression: 182 | normalized_pred_class = int(np.argmax(pred_vector[idx:idx + num_classes])) 183 | pred_val = float(classes[normalized_pred_class]) 184 | else: 185 | min_val = param_descriptor.min_val 186 | max_val = param_descriptor.max_val 187 | pred_val = -1.0 188 | if float(pred_vector[idx + 1]) < 0.5: # visibility class 189 | pred_val = (float(pred_vector[idx]) * (max_val - min_val)) + min_val 190 | else: 191 | # Integer or Boolean 192 | normalized_pred_class = int(np.argmax(pred_vector[idx:idx + num_classes])) 193 | pred_val = int(classes[normalized_pred_class]) 194 | if input_type == 'Vector': 195 | if param_name[:-2] not in shape_map: 196 | shape_map[param_name[:-2]] = {} 197 | shape_map[param_name[:-2]][param_name[-1]] = pred_val 198 | else: 199 | shape_map[param_name] = pred_val 200 | idx += num_classes 201 | return shape_map 202 | 203 | def convert_prediction_vector_to_map_continuous_only(self, inputs_to_eval, pred_vector): 204 | """ 205 | for the comparison to SRPM 206 | """ 207 | pred_vector = pred_vector.squeeze() 208 | assert len(pred_vector.shape) == 1 209 | shape_map = {} 210 | idx = 0 211 | param_descriptors_map = self.get_param_descriptors_map() 212 | for param_name in inputs_to_eval: 213 | param_descriptor = param_descriptors_map[param_name] 214 | input_type = param_descriptor.input_type 215 | classes = param_descriptor.classes 216 | num_classes = param_descriptor.num_classes 217 | assert input_type == 'Float' or input_type == 'Vector', f"param_name [{param_name}] input_type [{input_type}]" 218 | min_val = param_descriptor.min_val 219 | max_val = param_descriptor.max_val 220 | pred_val = (float(pred_vector[idx]) * (max_val - min_val)) + min_val 221 | if input_type == 'Vector': 222 | if param_name[:-2] not in shape_map: 223 | shape_map[param_name[:-2]] = {} 224 | shape_map[param_name[:-2]][param_name[-1]] = pred_val 225 | else: 226 | shape_map[param_name] = pred_val 227 | idx += 1 228 | return shape_map 229 | 230 | def convert_prediction_vector_to_map_discrete_only(self, inputs_to_eval, pred_vector): 231 | """ 232 | :param pred_vector: predicted vector from the network 233 | :param use_regression: whether we use regression for float values 234 | :return: map object representing the shape 235 | """ 236 | pred_vector = pred_vector.squeeze() 237 | assert len(pred_vector.shape) == 1 238 | shape_map = {} 239 | idx = 0 240 | param_descriptors_map = self.get_param_descriptors_map() 241 | for param_name in inputs_to_eval: 242 | param_descriptor = param_descriptors_map[param_name] 243 | input_type = param_descriptor.input_type 244 | classes = param_descriptor.classes 245 | num_classes = param_descriptor.num_classes 246 | assert input_type == 'Boolean' or input_type == 'Integer', f"param_name [{param_name}] input_type [{input_type}]" 247 | # Integer or Boolean 248 | normalized_pred_class = int(np.argmax(pred_vector[idx:idx + num_classes])) 249 | pred_val = int(classes[normalized_pred_class]) 250 | if input_type == 'Vector': 251 | if param_name[:-2] not in 
shape_map: 252 | shape_map[param_name[:-2]] = {} 253 | shape_map[param_name[:-2]][param_name[-1]] = pred_val 254 | else: 255 | shape_map[param_name] = pred_val 256 | idx += num_classes 257 | return shape_map 258 | 259 | def get_overall_num_of_classes_without_visibility_label(self): 260 | self.get_param_descriptors_map() 261 | return self.__overall_num_of_classes_without_visibility_label 262 | 263 | def expand_target_vector(self, targets): 264 | """ 265 | :param targets: 1-dim target vector which includes a single normalized value for each parameter 266 | :return: 1-dim vector where each parameter prediction is in one-hot representation 267 | """ 268 | targets = targets.squeeze() 269 | assert len(targets.shape) == 1 270 | res_vector = np.array([]) 271 | param_descriptors = self.get_param_descriptors_map() 272 | for i, param_name in enumerate(self.inputs_to_eval): 273 | param_descriptor = param_descriptors[param_name] 274 | num_classes = param_descriptor.num_classes 275 | if param_descriptor.is_regression: 276 | val = targets[i].reshape(1).item() 277 | if val == -1.0: 278 | res_vector = np.concatenate((res_vector, np.array([0.0, 1.0]))) 279 | else: 280 | res_vector = np.concatenate((res_vector, np.array([val, 0.0]))) 281 | else: 282 | normalized_classes = param_descriptor.normalized_classes 283 | normalized_gt_class_idx = int(np.where(abs(normalized_classes - targets[i].item()) < 1e-3)[0].item()) 284 | one_hot = np.eye(num_classes)[normalized_gt_class_idx] 285 | res_vector = np.concatenate((res_vector, one_hot)) 286 | return res_vector 287 | 288 | def expand_target_vector_contiuous_only(self, inputs_to_eval, targets): 289 | """ 290 | :param targets: 1-dim target vector which includes a single normalized value for each parameter 291 | :return: 1-dim vector where each parameter prediction is in one-hot representation 292 | """ 293 | targets = targets.squeeze() 294 | assert len(targets.shape) == 1 295 | res_vector = np.array([]) 296 | for i, param_name in enumerate(inputs_to_eval): 297 | val = targets[i].reshape(1).item() 298 | if val == -1.0: 299 | res_vector = np.concatenate((res_vector, np.array([0.0]))) 300 | else: 301 | res_vector = np.concatenate((res_vector, np.array([val]))) 302 | return res_vector 303 | 304 | def expand_target_vector_discrete_only(self, inputs_to_eval, targets): 305 | """ 306 | :param targets: 1-dim target vector which includes a single normalized value for each parameter 307 | :return: 1-dim vector where each parameter prediction is in one-hot representation 308 | """ 309 | targets = targets.squeeze() 310 | assert len(targets.shape) == 1 311 | res_vector = np.array([]) 312 | param_descriptors = self.get_param_descriptors_map() 313 | for i, param_name in enumerate(inputs_to_eval): 314 | param_descriptor = param_descriptors[param_name] 315 | num_classes = param_descriptor.num_classes 316 | normalized_classes = param_descriptor.normalized_classes 317 | if targets[i].item() == -1.0: 318 | normalized_gt_class_idx = 0 319 | else: 320 | normalized_gt_class_idx = int(np.where(abs(normalized_classes - targets[i].item()) < 1e-3)[0].item()) 321 | one_hot = np.eye(num_classes)[normalized_gt_class_idx] 322 | res_vector = np.concatenate((res_vector, one_hot)) 323 | return res_vector -------------------------------------------------------------------------------- /common/point_cloud_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def normalize_point_cloud(point_cloud, use_center_of_bounding_box=True): 5 | min_x, 
max_x = torch.min(point_cloud[:, 0]), torch.max(point_cloud[:, 0]) 6 | min_y, max_y = torch.min(point_cloud[:, 1]), torch.max(point_cloud[:, 1]) 7 | min_z, max_z = torch.min(point_cloud[:, 2]), torch.max(point_cloud[:, 2]) 8 | # center the point cloud 9 | if use_center_of_bounding_box: 10 | center = torch.tensor([(min_x + max_x) / 2, (min_y + max_y) / 2, (min_z + max_z) / 2]) 11 | else: 12 | center = torch.mean(point_cloud, dim=0) 13 | point_cloud = point_cloud - center 14 | dist = torch.max(torch.sqrt(torch.sum((point_cloud ** 2), dim=1))) 15 | point_cloud = point_cloud / dist # scale the point cloud 16 | return point_cloud 17 | -------------------------------------------------------------------------------- /common/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgl.geometry import farthest_point_sampler 3 | 4 | 5 | def farthest_point_sampling(faces, vertices, num_points=1000): 6 | random_sampling = sample_surface(faces, vertices, num_points=30000) 7 | point_cloud_indices = farthest_point_sampler(random_sampling.unsqueeze(0), npoints=num_points) 8 | point_cloud = random_sampling[point_cloud_indices[0]] 9 | return point_cloud 10 | 11 | 12 | def face_areas_normals(faces, vs): 13 | face_normals = torch.cross(vs[:, faces[:, 1], :] - vs[:, faces[:, 0], :], 14 | vs[:, faces[:, 2], :] - vs[:, faces[:, 1], :], dim=2) 15 | face_areas = torch.norm(face_normals, dim=2) 16 | face_normals = face_normals / face_areas[:, :, None] 17 | face_areas = 0.5 * face_areas 18 | return face_areas, face_normals 19 | 20 | 21 | def sample_surface(faces, vertices, num_points=1000): 22 | """ 23 | sample mesh surface 24 | sample method: 25 | http://mathworld.wolfram.com/TrianglePointPicking.html 26 | Args 27 | --------- 28 | vertices: vertices 29 | faces: triangle faces (torch.long) 30 | num_points: number of samples in the final point cloud 31 | Return 32 | --------- 33 | samples: (count, 3) points in space on the surface of mesh 34 | normals: (count, 3) corresponding face normals for points 35 | """ 36 | bsize, nvs, _ = vertices.shape 37 | weights, normal = face_areas_normals(faces, vertices) 38 | weights_sum = torch.sum(weights, dim=1) 39 | dist = torch.distributions.categorical.Categorical(probs=weights / weights_sum[:, None]) 40 | face_index = dist.sample((num_points,)) 41 | 42 | # pull triangles into the form of an origin + 2 vectors 43 | tri_origins = vertices[:, faces[:, 0], :] 44 | tri_vectors = vertices[:, faces[:, 1:], :].clone() 45 | tri_vectors -= tri_origins.repeat(1, 1, 2).reshape((bsize, len(faces), 2, 3)) 46 | 47 | # pull the vectors for the faces we are going to sample from 48 | face_index = face_index.transpose(0, 1) 49 | face_index = face_index[:, :, None].expand((bsize, num_points, 3)) 50 | tri_origins = torch.gather(tri_origins, dim=1, index=face_index) 51 | face_index2 = face_index[:, :, None, :].expand((bsize, num_points, 2, 3)) 52 | tri_vectors = torch.gather(tri_vectors, dim=1, index=face_index2) 53 | 54 | # randomly generate two 0-1 scalar components to multiply edge vectors by 55 | random_lengths = torch.rand(num_points, 2, 1, device=vertices.device, dtype=tri_vectors.dtype) 56 | 57 | # points will be distributed on a quadrilateral if we use 2x [0-1] samples 58 | # if the two scalar components sum less than 1.0 the point will be 59 | # inside the triangle, so we find vectors longer than 1.0 and 60 | # transform them to be inside the triangle 61 | random_test = random_lengths.sum(dim=1).reshape(-1) > 1.0 62 | 
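    # editor's note (worked example, not part of the original source): a draw of
    # (u, v) = (0.7, 0.6) has u + v > 1 and therefore lands outside the triangle; subtracting
    # 1 and taking the absolute value below maps it to (0.3, 0.4), the point mirrored back
    # into the triangle, so the samples stay uniformly distributed over the triangle's area.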
random_lengths[random_test] -= 1.0 63 | random_lengths = torch.abs(random_lengths) 64 | 65 | # multiply triangle edge vectors by the random lengths and sum 66 | sample_vector = (tri_vectors * random_lengths[None, :]).sum(dim=2) 67 | 68 | # finally, offset by the origin to generate 69 | # (n,3) points in space on the triangle 70 | samples = sample_vector + tri_origins 71 | 72 | # normals = torch.gather(normal, dim=1, index=face_index) 73 | 74 | # return samples, normals 75 | return samples[0] 76 | -------------------------------------------------------------------------------- /common/save_obj.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | import bpy 5 | import argparse 6 | import traceback 7 | from pathlib import Path 8 | import importlib 9 | 10 | def import_parents(level=1): 11 | global __package__ 12 | file = Path(__file__).resolve() 13 | parent, top = file.parent, file.parents[level] 14 | sys.path.append(str(top)) 15 | try: 16 | sys.path.remove(str(parent)) 17 | except ValueError: 18 | pass 19 | 20 | __package__ = '.'.join(parent.parts[len(top.parts):]) 21 | importlib.import_module(__package__) 22 | 23 | if __name__ == '__main__' and __package__ is None: 24 | import_parents(level=1) 25 | 26 | from common.file_util import get_recipe_yml_obj 27 | from common.input_param_map import get_input_param_map, load_shape_from_yml 28 | from common.bpy_util import clean_scene, select_shape, select_objs, get_geometric_nodes_modifier, save_obj 29 | 30 | 31 | def save_obj_from_yml(args): 32 | if args.simplification_ratio: 33 | assert 0.0 <= args.simplification_ratio <= 1.0 34 | target_obj_file_path = Path(args.target_obj_file_path) 35 | assert target_obj_file_path.suffix == ".obj" 36 | clean_scene(start_with_strings=["Camera", "Light"]) 37 | assert 'Main' in bpy.context.view_layer.layer_collection.children, "The procedural shape must be inside a collection called 'Main'" 38 | bpy.context.view_layer.layer_collection.children['Main'].hide_viewport = False 39 | bpy.context.view_layer.layer_collection.children['Main'].exclude = False 40 | obj = select_shape() 41 | gnodes_mod = get_geometric_nodes_modifier(obj) 42 | recipe_yml = get_recipe_yml_obj(args.recipe_file_path) 43 | input_params_map = get_input_param_map(gnodes_mod, recipe_yml) 44 | load_shape_from_yml(args.yml_file_path, input_params_map, ignore_sanity_check=args.ignore_sanity_check) 45 | dup_obj = save_obj(target_obj_file_path, simplification_ratio=args.simplification_ratio) 46 | bpy.data.collections["Main"].hide_render = False 47 | chair_obj = select_shape() 48 | dup_obj.hide_render = False 49 | chair_obj.hide_render = True 50 | dup_obj.data.materials.clear() 51 | select_objs(dup_obj) 52 | bpy.ops.object.delete() 53 | 54 | 55 | def main(): 56 | if '--' in sys.argv: 57 | # refer to https://b3d.interplanety.org/en/how-to-pass-command-line-arguments-to-a-blender-python-script-or-add-on/ 58 | argv = sys.argv[sys.argv.index('--') + 1:] 59 | else: 60 | raise Exception("Expected \'--\' followed by arguments to the script") 61 | 62 | parser = argparse.ArgumentParser(prog='save_obj') 63 | parser.add_argument('--recipe-file-path', type=str, required=True, help='Path to recipe.yml file') 64 | parser.add_argument('--yml-file-path', type=str, required=True, help='Path to yaml file to convert to object') 65 | parser.add_argument('--target-obj-file-path', type=str, required=True, help='Path the obj file that will be created') 66 | 
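    # editor's note (usage sketch, paths and file names are illustrative only): this script is
    # meant to be launched through Blender, with its own arguments placed after the standalone
    # '--' separator that main() looks for, e.g.:
    #   blender chair.blend -b --python common/save_obj.py -- \
    #     --recipe-file-path dataset_generator/recipe_files/recipe_chair.yml \
    #     --yml-file-path shape.yml --target-obj-file-path out.obj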
parser.add_argument('--simplification-ratio', type=float, default=None, help='Simplification ratio to decimate the mesh') 67 | parser.add_argument('--ignore-sanity-check', action='store_true', default=False, help='Do not check the shape\'s parameters') 68 | 69 | try: 70 | args = parser.parse_known_args(argv)[0] 71 | save_obj_from_yml(args) 72 | except Exception as e: 73 | print(repr(e)) 74 | print(traceback.format_exc()) 75 | 76 | if __name__ == '__main__': 77 | main() 78 | -------------------------------------------------------------------------------- /config/neptune_config_example.yml: -------------------------------------------------------------------------------- 1 | neptune: 2 | api_token = "" 3 | project = "project_dir/project_name" 4 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/data/__init__.py -------------------------------------------------------------------------------- /data/data_processing.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import numpy as np 3 | from pathlib import Path 4 | from tqdm import tqdm 5 | import traceback 6 | import torch 7 | from common.file_util import load_obj 8 | from common.point_cloud_util import normalize_point_cloud 9 | 10 | 11 | def generate_point_clouds(data_dir, phase, num_points, num_point_clouds_per_combination, 12 | processed_dataset_dir_name, sampling_method, 13 | gaussian=0.0, apply_point_cloud_normalization=False): 14 | """ 15 | samples point cloud from mesh 16 | """ 17 | dataset_dir = Path(data_dir, phase) 18 | obj_gt_dir = dataset_dir.joinpath('obj_gt') 19 | processed_dataset_dir = dataset_dir.joinpath(processed_dataset_dir_name) 20 | processed_dataset_dir.mkdir(exist_ok=True) 21 | files = sorted(obj_gt_dir.glob('*.obj')) 22 | for file in tqdm(files): 23 | faces = None 24 | vertices = None 25 | should_load_obj = True # this is done to only load the obj if it is actually required 26 | for point_cloud_idx in range(num_point_clouds_per_combination): 27 | new_file_name = Path(processed_dataset_dir, file.with_suffix('.npy').name.replace(".npy", f"_{point_cloud_idx}.npy")) 28 | if new_file_name.is_file(): 29 | continue 30 | # only load the obj (once per all the instances) if it is actually needed 31 | if should_load_obj: 32 | vertices, faces = load_obj(file) 33 | vertices = vertices.reshape(1, vertices.shape[0], vertices.shape[1]) 34 | vertices = torch.from_numpy(vertices) 35 | faces = torch.from_numpy(faces) 36 | should_load_obj = False 37 | 38 | try: 39 | point_cloud = sampling_method(faces, vertices, num_points=num_points) 40 | except Exception as e: 41 | print(traceback.format_exc()) 42 | print(repr(e)) 43 | print(file) 44 | continue 45 | if apply_point_cloud_normalization: 46 | # normalize the point cloud and use center of bounding box 47 | point_cloud = normalize_point_cloud(point_cloud) 48 | if gaussian and gaussian > 0.0: 49 | point_cloud += np.random.normal(0, gaussian, point_cloud.shape) 50 | np.save(str(new_file_name), point_cloud) 51 | 52 | 53 | def normalize_labels(data_dir, phase, processed_dataset_dir_name, params_descriptors, train_with_visibility_label): 54 | dataset_dir = Path(data_dir, phase) 55 | processed_dataset_dir = dataset_dir.joinpath(processed_dataset_dir_name) 56 | processed_dataset_dir.mkdir(exist_ok=True) 57 | 58 | yml_gt_dir = 
dataset_dir.joinpath('yml_gt') 59 | files = sorted(yml_gt_dir.glob('*.yml')) 60 | for file in files: 61 | if not file.is_file(): 62 | # it is only allowed to not have a gt yml file when we are in test phase 63 | assert phase == "test" 64 | continue 65 | save_path = Path(processed_dataset_dir, file.name) 66 | if save_path.is_file(): 67 | # this will skip normalization if the file exists, but if the recipe file changes, then normalization needs to be performed again 68 | # in that case, disable this if statement to regenerate the normalized labels 69 | continue 70 | with open(file, 'r') as f: 71 | yml_obj = yaml.load(f, Loader=yaml.FullLoader) 72 | normalized_yml_obj = yml_obj.copy() 73 | 74 | # only apply the normalization to the inputs that were changed in this dataset 75 | for param_name, param_descriptor in params_descriptors.items(): 76 | param_input_type = param_descriptor.input_type 77 | min_val = param_descriptor.min_val 78 | max_val = param_descriptor.max_val 79 | if param_input_type == 'Integer': 80 | normalized_yml_obj[param_name] -= min_val 81 | elif param_input_type == 'Float': 82 | value = yml_obj[param_name] 83 | normalized_yml_obj[param_name] = (value - min_val) / (max_val - min_val) 84 | elif param_input_type == 'Boolean': 85 | pass 86 | elif param_input_type == 'Vector': 87 | param_name_no_axis = param_name[:-2] 88 | for axis in ['x', 'y', 'z']: 89 | if param_name[-2:] != f" {axis}": 90 | continue 91 | value = yml_obj[param_name_no_axis][axis] 92 | normalized_yml_obj[param_name_no_axis][axis] = (value - min_val) / (max_val - min_val) 93 | if train_with_visibility_label and not params_descriptors[param_name].is_visible(yml_obj): 94 | normalized_yml_obj[param_name] = -1 95 | 96 | with open(save_path, 'w') as out_file: 97 | yaml.dump(normalized_yml_obj, out_file) 98 | -------------------------------------------------------------------------------- /data/dataset_pc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.utils.data as data 3 | import torch 4 | import yaml 5 | from pathlib import Path 6 | from .data_processing import generate_point_clouds, normalize_labels 7 | from common.sampling_util import sample_surface, farthest_point_sampling 8 | from .dataset_util import assemble_targets 9 | 10 | 11 | class DatasetPC(data.Dataset): 12 | def __init__(self, 13 | inputs_to_eval, 14 | dataset_processing_preferred_device, 15 | params_descriptors, 16 | data_dir, 17 | phase, 18 | num_points=1500, 19 | num_point_clouds_per_combination=1, 20 | random_pc=None, 21 | gaussian=0.0, 22 | apply_point_cloud_normalization=False, 23 | scanobjectnn=False, 24 | augment_with_random_points=True, 25 | train_with_visibility_label=True): 26 | self.inputs_to_eval = inputs_to_eval 27 | self.data_dir = data_dir 28 | self.phase = phase 29 | self.random_pc = random_pc 30 | self.gaussian = gaussian 31 | self.apply_point_cloud_normalization = apply_point_cloud_normalization 32 | self.dataset_processing_preferred_device = dataset_processing_preferred_device 33 | self.train_with_visibility_label = train_with_visibility_label 34 | self.yml_gt_normalized_dir_name = 'yml_gt_normalized' 35 | self.point_cloud_fps_dir_name = 'point_cloud_fps' 36 | self.point_cloud_random_dir_name = 'point_cloud_random' 37 | self.num_point_clouds_per_combination = num_point_clouds_per_combination 38 | self.augment_with_random_points = augment_with_random_points 39 | self.ds_path = Path(data_dir, phase) 40 | if not self.ds_path.is_dir(): 41 | raise 
Exception(f"Could not find a dataset in path [{self.ds_path}]") 42 | 43 | if scanobjectnn: 44 | random_pc_dir = self.ds_path.joinpath(self.point_cloud_random_dir_name) 45 | # [:-2] removes the _0 so that when it is added later it will match the file name 46 | self.file_names = [f.stem[:-2] for f in random_pc_dir.glob("*.npy")] 47 | self.num_files = len(self.file_names) 48 | self.size = self.num_files * self.num_point_clouds_per_combination 49 | return 50 | print(f"Processing dataset [{phase}] with farthest point sampling...") 51 | if not self.random_pc: 52 | generate_point_clouds(data_dir, phase, num_points, self.num_point_clouds_per_combination, 53 | self.point_cloud_fps_dir_name, sampling_method=farthest_point_sampling, gaussian=self.gaussian, 54 | apply_point_cloud_normalization=self.apply_point_cloud_normalization) 55 | else: 56 | num_points = self.random_pc 57 | print(f"Using uniform sampling only with [{num_points}] samples") 58 | normalize_labels(data_dir, phase, self.yml_gt_normalized_dir_name, params_descriptors, self.train_with_visibility_label) 59 | print(f"Processing dataset [{phase}] with uniform sampling (augmentation)...") 60 | generate_point_clouds(data_dir, phase, num_points, self.num_point_clouds_per_combination, 61 | self.point_cloud_random_dir_name, sampling_method=sample_surface, gaussian=self.gaussian, 62 | apply_point_cloud_normalization=self.apply_point_cloud_normalization) 63 | 64 | obj_gt_dir = self.ds_path.joinpath('obj_gt') 65 | self.file_names = [f.stem for f in obj_gt_dir.glob("*.obj")] 66 | 67 | self.num_files = len(self.file_names) 68 | self.size = self.num_files * self.num_point_clouds_per_combination 69 | 70 | 71 | def __getitem__(self, _index): 72 | file_idx = _index // self.num_point_clouds_per_combination 73 | sample_idx = _index % self.num_point_clouds_per_combination 74 | file_name = self.file_names[file_idx] 75 | 76 | pc = [] 77 | random_pc_path = self.ds_path.joinpath(self.point_cloud_random_dir_name, f"{file_name}_{sample_idx}.npy") 78 | fps_pc_path = self.ds_path.joinpath(self.point_cloud_fps_dir_name, f"{file_name}_{sample_idx}.npy") 79 | if self.random_pc: 80 | pc = np.load(str(random_pc_path)) 81 | pc = torch.from_numpy(pc).float() 82 | assert len(pc) == self.random_pc 83 | else: 84 | if fps_pc_path.is_file(): 85 | pc = np.load(str(fps_pc_path)) 86 | pc = torch.from_numpy(pc).float() 87 | 88 | # augment the farthest point sampled point cloud with points from a randomly sampled point cloud 89 | # note that in some tests we did not apply the augmentation 90 | if self.augment_with_random_points: 91 | pc_aug = np.load(str(random_pc_path)) 92 | pc_aug = torch.from_numpy(pc_aug).float() 93 | pc_aug = pc_aug[np.random.choice(pc_aug.shape[0], replace=False, size=800)] 94 | pc = torch.cat((pc, pc_aug), dim=0) 95 | else: 96 | assert self.phase == "real" 97 | 98 | # assert that the point cloud is normalized 99 | max_diff = 0.05 100 | if self.random_pc: 101 | max_diff = 0.3 102 | if not self.gaussian or self.gaussian == 0.0: 103 | max_dist_from_center = abs(1.0 - torch.max(torch.sqrt(torch.sum((pc ** 2), dim=1)))) 104 | assert max_dist_from_center < max_diff, f"Point cloud is not normalized [{max_dist_from_center} > {max_diff}] for sample [{file_name}]. If this is an external ds, please consider using prepare_coseg.py script first." 
105 | 106 | # load target vectors, for test phase, some examples may not have a yml file attached to the 107 | yml_path = self.ds_path.joinpath(self.yml_gt_normalized_dir_name, f"{file_name}.yml") 108 | yml_obj = None 109 | if yml_path.is_file(): 110 | with open(yml_path, 'r') as f: 111 | yml_obj = yaml.load(f, Loader=yaml.FullLoader) 112 | else: 113 | # for training and validation we must have a yml file for each sample, for certain phases, yml file is not mandatory 114 | assert self.phase == "coseg" or self.phase == "real" 115 | 116 | # assemble the vectors in the requested order of parameters 117 | targets = assemble_targets(yml_obj, self.inputs_to_eval) 118 | 119 | # dataloaders are not allowed to return None, anything empty is converted to [] 120 | return file_name, pc, targets, yml_obj if yml_obj else [] 121 | 122 | def __len__(self): 123 | return self.size 124 | -------------------------------------------------------------------------------- /data/dataset_sketch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.utils.data as data 3 | import yaml 4 | from pathlib import Path 5 | from .data_processing import normalize_labels 6 | from torchvision import transforms 7 | from PIL import Image 8 | from skimage.morphology import erosion, dilation 9 | import random 10 | from .dataset_util import assemble_targets 11 | 12 | 13 | class DatasetSketch(data.Dataset): 14 | def __init__(self, 15 | inputs_to_eval, 16 | params_descriptors, 17 | camera_angles_to_process, 18 | pretrained_vgg, 19 | data_dir, 20 | phase, 21 | train_with_visibility_label=True): 22 | self.inputs_to_eval = inputs_to_eval 23 | self.data_dir = data_dir 24 | self.phase = phase 25 | self.pretrained_vgg = pretrained_vgg 26 | self.train_with_visibility_label = train_with_visibility_label 27 | self.camera_angles_to_process = camera_angles_to_process 28 | self.num_sketches_camera_angles = len(self.camera_angles_to_process) 29 | self.yml_gt_normalized_dir_name = 'yml_gt_normalized' 30 | self.ds_path = Path(data_dir, phase) 31 | if not self.ds_path.is_dir(): 32 | raise Exception(f"Could not find a dataset in path [{self.ds_path}]") 33 | self.sketches_path = self.ds_path.joinpath("sketches") 34 | if not self.sketches_path.is_dir(): 35 | raise Exception(f"Could not find a sketches in path [{self.sketches_path}]") 36 | self.sketch_transforms = transforms.Compose([ 37 | transforms.RandomHorizontalFlip(), 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 40 | ]) 41 | normalize_labels(data_dir, phase, self.yml_gt_normalized_dir_name, params_descriptors, self.train_with_visibility_label) 42 | 43 | obj_gt_dir = self.ds_path.joinpath('obj_gt') 44 | self.file_names = [f.stem for f in obj_gt_dir.glob("*.obj")] 45 | if self.phase == "real" or self.phase == "clipasso" or self.phase == "traced": 46 | self.file_names = [f.stem for f in self.sketches_path.glob("*.png")] 47 | 48 | num_files = len(self.file_names) 49 | if self.phase == "real" or self.phase == "clipasso" or self.phase == "traced": 50 | self.size = num_files 51 | else: 52 | self.size = num_files * self.num_sketches_camera_angles 53 | 54 | def __getitem__(self, _index): 55 | if self.phase == "real" or self.phase == "clipasso" or self.phase == "traced": 56 | file_idx = _index 57 | sketch_idx = 0 58 | else: 59 | file_idx = _index // self.num_sketches_camera_angles 60 | sketch_idx = _index % self.num_sketches_camera_angles 61 | file_name = 
self.file_names[file_idx] 62 | 63 | # load target vectors, for test phase, some examples may not have a yml file attached to them 64 | yml_path = self.ds_path.joinpath(self.yml_gt_normalized_dir_name, f"{file_name}.yml") 65 | yml_obj = None 66 | if yml_path.is_file(): 67 | with open(yml_path, 'r') as f: 68 | yml_obj = yaml.load(f, Loader=yaml.FullLoader) 69 | else: 70 | # for training and validation we must have a yml file for each sample, for certain phases, yml file is not mandatory 71 | assert self.phase == "test" or self.phase == "coseg" or self.phase == "real" or self.phase == "clipasso" or self.phase == "traced" 72 | 73 | # assemble the vectors in the requested order of parameters 74 | targets = assemble_targets(yml_obj, self.inputs_to_eval) 75 | 76 | sketch_files = sorted(self.sketches_path.glob(f"{file_name}_*.png")) 77 | if self.phase == "real" or self.phase == "clipasso" or self.phase == "traced": 78 | sketch_files = sorted(self.sketches_path.glob(f"{file_name}.png")) 79 | # filter out sketches from camera angles that are excluded 80 | if self.phase != "real" and self.phase != "clipasso" and self.phase != "traced": 81 | sketch_files = [f for f in sketch_files if any( camera_angle in f.name for camera_angle in self.camera_angles_to_process )] 82 | if len(sketch_files) != len(self.camera_angles_to_process): 83 | raise Exception(f"Object [{file_name}] is missing sketch files") 84 | sketch_file = sketch_files[sketch_idx] 85 | sketch = Image.open(sketch_file).convert("RGB") 86 | if sketch.size[0] != sketch.size[1]: 87 | raise Exception(f"Images should be square, got [{sketch.size}] instead.") 88 | if sketch.size[0] != 224: 89 | sketch = sketch.resize((224, 224), Image.BILINEAR) 90 | # augmentation for the sketches 91 | if self.phase == "train": 92 | # three augmentation options: 1) original 2) erosion 3) erosion then dilation 93 | aug_idx = random.randint(0, 2) 94 | if aug_idx == 1: 95 | sketch = np.array(sketch) 96 | sketch = erosion(sketch) 97 | sketch = Image.fromarray(sketch) 98 | if aug_idx == 2: 99 | sketch = np.array(sketch) 100 | eroded = erosion(sketch) 101 | sketch = dilation(eroded) 102 | sketch = Image.fromarray(sketch) 103 | sketch = self.sketch_transforms(sketch) 104 | if not self.pretrained_vgg: 105 | sketch = sketch[0].unsqueeze(0) # sketch.shape = [1, 224, 224] 106 | 107 | curr_file_camera_angle = 'angle_na' 108 | for camera_angle in self.camera_angles_to_process: 109 | if camera_angle in str(sketch_file): 110 | curr_file_camera_angle = camera_angle 111 | break 112 | if self.phase != "real" and self.phase != "clipasso" and self.phase != "traced": 113 | assert curr_file_camera_angle != 'angle_na' 114 | 115 | # dataloaders are not allowed to return None, anything empty is converted to [] 116 | return file_name, curr_file_camera_angle, sketch, targets, yml_obj if yml_obj else [] 117 | 118 | def __len__(self): 119 | return self.size 120 | -------------------------------------------------------------------------------- /data/dataset_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Dict, List 4 | 5 | 6 | def assemble_targets(yml_obj: Dict, inputs_to_eval: List[str]): 7 | targets = [] 8 | if yml_obj: 9 | for param_name in inputs_to_eval: 10 | if param_name[-2:] == ' x': 11 | targets.append(yml_obj[param_name[:-2]]['x']) 12 | elif param_name[-2:] == ' y': 13 | targets.append(yml_obj[param_name[:-2]]['y']) 14 | elif param_name[-2:] == ' z': 15 | 
targets.append(yml_obj[param_name[:-2]]['z']) 16 | else: 17 | targets.append(yml_obj[param_name]) 18 | 19 | # convert from list to numpy array and then to torch tensor 20 | targets = torch.from_numpy(np.asarray(targets)) 21 | return targets 22 | -------------------------------------------------------------------------------- /dataset_generator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/dataset_generator/__init__.py -------------------------------------------------------------------------------- /dataset_generator/base_recipe_generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations # to allow deferred annotations 2 | 3 | import bpy 4 | import sys 5 | import argparse 6 | from pathlib import Path 7 | from typing import List, Set 8 | from dataclasses import dataclass, field 9 | 10 | import pip 11 | pip.main(['install', 'sympy']) 12 | pip.main(['install', 'pyyaml']) 13 | 14 | from sympy import symbols, simplify_logic, And, Or, Eq, Ne, Not 15 | from sympy.parsing.sympy_parser import parse_expr 16 | import yaml 17 | 18 | import importlib 19 | 20 | def import_parents(level=1): 21 | global __package__ 22 | file = Path(__file__).resolve() 23 | parent, top = file.parent, file.parents[level] 24 | 25 | sys.path.append(str(top)) 26 | try: 27 | sys.path.remove(str(parent)) 28 | except ValueError: 29 | pass 30 | 31 | __package__ = '.'.join(parent.parts[len(top.parts):]) 32 | importlib.import_module(__package__) 33 | 34 | if __name__ == '__main__' and __package__ is None: 35 | import_parents(level=1) 36 | 37 | from common.bpy_util import select_shape, get_geometric_nodes_modifier 38 | from common.file_util import save_yml 39 | from dataset_generator.dataset_generator import update_base_shape_in_yml, update_recipe_yml_obj_with_metadata 40 | 41 | 42 | def OrAll(expressions): 43 | cond = False 44 | for expression in expressions: 45 | cond |= expression 46 | return cond 47 | 48 | 49 | @dataclass 50 | class TreeNode: 51 | name: str 52 | condition: False 53 | param_names_true: Set[str] = field(default_factory=set) 54 | param_names_false: Set[str] = field(default_factory=set) 55 | node_true: List[TreeNode] = field(default_factory=list) 56 | node_false: List[TreeNode] = field(default_factory=list) 57 | 58 | def tostring(self, is_true): 59 | print(f"is_true [{is_true}] name [{self.name}] condition [{self.condition}] params true {self.param_names_true} params false {self.param_names_false}") 60 | for node in self.node_true: 61 | node.tostring(True) 62 | for node in self.node_false: 63 | node.tostring(False) 64 | 65 | def get_all_param_names(self): 66 | def get_all_param_names_rec(node: TreeNode, param_names): 67 | param_names.update(node.param_names_true) 68 | param_names.update(node.param_names_false) 69 | for n in node.node_true: 70 | get_all_param_names_rec(n, param_names) 71 | for n in node.node_false: 72 | get_all_param_names_rec(n, param_names) 73 | param_names = set([]) 74 | get_all_param_names_rec(self, param_names) 75 | return param_names 76 | 77 | def induce_condition(self, param_name): 78 | def induce_cond_rec(node: TreeNode, param_name: str): 79 | if param_name in node.param_names_true and param_name in node.param_names_false: 80 | return True 81 | elif node.condition == 'True' and param_name in node.param_names_true: 82 | return True 83 | elif node.condition == 'False' and param_name in 
node.param_names_false: 84 | return True 85 | elif node.condition == 'True' and param_name not in node.param_names_true: 86 | expressions = [(induce_cond_rec(n, param_name)) for n in node.node_true] 87 | cond = OrAll(expressions) 88 | return (cond) 89 | elif node.condition == 'False' and param_name not in node.param_names_false: 90 | expressions = [(induce_cond_rec(n, param_name)) for n in node.node_false] 91 | cond = OrAll(expressions) 92 | return (cond) 93 | elif param_name in node.param_names_true: 94 | expressions = [(induce_cond_rec(n, param_name)) for n in node.node_false] 95 | if not expressions: 96 | return (node.condition) 97 | cond = OrAll(expressions) 98 | return ((node.condition) | (Not(node.condition) & (cond))) 99 | elif param_name in node.param_names_false: 100 | expressions = [(induce_cond_rec(n, param_name)) for n in node.node_true] 101 | if not expressions: 102 | return Not(node.condition) 103 | cond = OrAll(expressions) 104 | return (Not(node.condition) | ((node.condition) & (cond))) 105 | else: 106 | expressions1 = [(induce_cond_rec(n, param_name)) for n in node.node_false] 107 | expressions2 = [(induce_cond_rec(n, param_name)) for n in node.node_true] 108 | cond1 = OrAll(expressions1) 109 | cond2 = OrAll(expressions2) 110 | if not expressions1 and not expressions2: 111 | return False 112 | elif not expressions1: 113 | return ((node.condition) & (cond2)) 114 | else: 115 | try: 116 | return (Not(node.condition) & (cond1)) 117 | except: 118 | print(f"[{node.condition}] [{cond1}]") 119 | return ((Not(node.condition) & (cond1)) | ((node.condition) & (cond2))) 120 | return induce_cond_rec(self, param_name) 121 | 122 | 123 | def apply_operation(operation: str, term1, term2): 124 | if operation == 'LESS_THAN': 125 | return term1 < term2 126 | elif operation == 'LESS_THAN_OR_EQUAL': 127 | return term1 <= term2 128 | elif operation == 'GREATER_THAN': 129 | return term1 > term2 130 | elif operation == 'GREATER_THAN_OR_EQUAL': 131 | return term1 >= term2 132 | elif operation == 'EQUAL': 133 | return Eq(term1, term2) 134 | elif operation == 'NOT_EQUAL': 135 | return Ne(term1, term2) 136 | else: 137 | raise Exception(f"operation [{operation}] is not recognized") 138 | 139 | 140 | def assemble_condition(link, symbols_map): 141 | def assemble_condition_rec(link, symbols_map): 142 | from_node = link.from_node 143 | if from_node.type == 'GROUP_INPUT': 144 | socket_name = link.from_socket.name 145 | if socket_name not in symbols_map: 146 | symbols_map[socket_name] = symbols(socket_name) 147 | if link.to_node.type == 'SWITCH': 148 | # booleans are represented as 0 or 1 149 | return Eq(symbols_map[socket_name], 1) 150 | return symbols_map[socket_name] 151 | elif from_node.type == 'COMPARE': 152 | data_type = from_node.data_type 153 | input_idx_offset = None 154 | if data_type == 'FLOAT': 155 | input_idx_offset = 0 156 | elif data_type == 'INT': 157 | input_idx_offset = 2 158 | else: 159 | raise Exception(f"data_type [{data_type}] is not supported") 160 | input1 = from_node.inputs[0 + input_idx_offset] 161 | input2 = from_node.inputs[1 + input_idx_offset] 162 | if input1.is_linked: 163 | term1 = assemble_condition_rec(input1.links[0], symbols_map) 164 | else: 165 | term1 = input1.default_value 166 | if input2.is_linked: 167 | term2 = assemble_condition_rec(input2.links[0], symbols_map) 168 | else: 169 | term2 = input2.default_value 170 | cond = apply_operation(from_node.operation, term1, term2) 171 | return cond 172 | elif from_node.type == 'SWITCH': 173 | raise Exception("Aggregated 
switches are not supported yet") 174 | else: 175 | raise Exception(f"No support for node_type of [{from_node.type}]") 176 | condition = assemble_condition_rec(link, symbols_map) 177 | return condition 178 | 179 | 180 | def parse_switch(switch_node, symbols_map): 181 | switch_inputs = [input for input in switch_node.inputs if input.is_linked] 182 | switch_param_name = None 183 | link_true = None 184 | link_false = None 185 | condition = False # if not connected, the input condition defaults to 'False' 186 | for switch_input in switch_inputs: 187 | if switch_input.name == 'Switch': 188 | assert switch_node.inputs[1].is_linked 189 | condition = assemble_condition(switch_input.links[0], symbols_map) 190 | elif switch_input.name == 'True': 191 | assert switch_node.inputs[15].is_linked 192 | link_true = switch_input.links[0] 193 | elif switch_input.name == 'False': 194 | assert switch_node.inputs[14].is_linked 195 | link_false = switch_input.links[0] 196 | return condition, link_true, link_false 197 | 198 | 199 | def traverse_geo_node(link, tree_node: TreeNode, is_true, visited, symbols_map): 200 | from_node = link.from_node 201 | if from_node.type == 'SWITCH' and from_node.input_type == 'GEOMETRY': 202 | cond, link_true, link_false = parse_switch(from_node, symbols_map) 203 | print(f"{from_node.name} {link_true} {link_false}") 204 | tn = TreeNode(from_node.name, cond, set([]), set([]), [], []) 205 | if is_true: 206 | tree_node.node_true.append(tn) 207 | else: 208 | tree_node.node_false.append(tn) 209 | if link_true: 210 | traverse_geo_node(link_true, tn, True, visited, symbols_map) 211 | if link_false: 212 | traverse_geo_node(link_false, tn, False, visited, symbols_map) 213 | return 214 | if from_node.type == 'GROUP_INPUT': 215 | socket_name = link.from_socket.name 216 | if is_true: 217 | tree_node.param_names_true.add(socket_name) 218 | else: 219 | tree_node.param_names_false.add(socket_name) 220 | return 221 | for input in from_node.inputs: 222 | if not input.is_linked: 223 | continue 224 | for link in input.links: 225 | child_node = link.from_node 226 | vis_name = f"{child_node.name}_{link.from_socket.name}" 227 | #if vis_name not in visited: 228 | # visited.add(vis_name) 229 | traverse_geo_node(link, tree_node, is_true, visited, symbols_map) 230 | 231 | 232 | def assemble_rules(final_node_name): 233 | obj = bpy.data.objects['procedural shape'] 234 | mod = obj.modifiers['GeometryNodes'] 235 | ng = mod.node_group 236 | final_node = ng.nodes[final_node_name] 237 | final_link = final_node.inputs['Geometry'].links[0] 238 | 239 | # by default we assume the condition is False 240 | visited = set([]) 241 | symbols_map = {} 242 | root = TreeNode("root", False, set([]), set([]), [], []) 243 | traverse_geo_node(final_link, root, False, visited, symbols_map) 244 | root.tostring(False) 245 | param_names = root.get_all_param_names() 246 | # param_names = set([k for k in symbols_map]) 247 | #bool_param_names = [p for p in param_names if p.startswith("is_")] 248 | cond_map = {} 249 | 250 | # build the current parameter dict (used for a sanity check) 251 | param_value_map = {} 252 | group_input_nodes = [node for node in mod.node_group.nodes if node.type == 'GROUP_INPUT'] 253 | assert len(group_input_nodes) > 0 254 | group_input_node = group_input_nodes[0] 255 | for input in group_input_node.outputs: 256 | param_name = str(input.name) 257 | if len(param_name) == 0: 258 | continue 259 | # for input in ng.interface.items_tree: 260 | # for param_name in [input.name for input in ng.interface.items_tree]: 261 | 
#print(param_name) 262 | # param_name = input.name 263 | #print(ng.inputs[param_name].identifier) 264 | print(input.identifier) 265 | value = mod[input.identifier] 266 | if input.bl_label == 'BOOLEAN': 267 | param_value_map[param_name] = 1 if value else 0 268 | else: 269 | param_value_map[param_name] = value 270 | print(param_value_map) 271 | 272 | print(f"Number of parameters [{len(param_names)}]") 273 | for param_name in param_names: 274 | raw_expression = root.induce_condition(param_name) 275 | # parsed_expression = parse_expr(raw_expression) 276 | simplified_expression = simplify_logic(raw_expression) 277 | cond_map[param_name] = str(simplified_expression) 278 | current_eval = parse_expr(cond_map[param_name], local_dict=param_value_map) 279 | print(f"{param_name}: {cond_map[param_name]} with current parameters evaluates to {current_eval}") 280 | # remove "True" conditions 281 | return {k: v for k, v in cond_map.items() if v != "True"} 282 | 283 | 284 | def main(args): 285 | recipe_file_path = Path(args.recipe_file_path) 286 | obj = select_shape() 287 | # get the geometric nodes modifier for the object 288 | gnodes_mod = get_geometric_nodes_modifier(obj) 289 | recipe_yml_obj = update_base_shape_in_yml(gnodes_mod, recipe_file_path) 290 | update_recipe_yml_obj_with_metadata(recipe_yml_obj, gnodes_mod, write_dataset_generation=True) 291 | visibility_conditions = assemble_rules(args.final_node_name) 292 | if len(visibility_conditions) > 0: 293 | recipe_yml_obj['visibility_conditions'] = visibility_conditions 294 | else: 295 | if "visibility_conditions" in recipe_yml_obj: 296 | del recipe_yml_obj["visibility_conditions"] 297 | save_yml(recipe_yml_obj, recipe_file_path) 298 | 299 | 300 | 301 | if __name__ == "__main__": 302 | if '--' in sys.argv: 303 | argv = sys.argv[sys.argv.index('--') + 1:] 304 | else: 305 | raise Exception("Expected \'--\' followed by arguments to the script") 306 | 307 | parser = argparse.ArgumentParser(prog='dataset_generator') 308 | parser.add_argument('--recipe-file-path', type=str, required=True, help='Path to the output recipe file.') 309 | parser.add_argument('--final-node-name', type=str, help="The name of the output node of the procedural shape, e.g. 
\"Realize Instances\".") 310 | 311 | args = parser.parse_args(argv) 312 | main(args) 313 | 314 | """ 315 | "C:\Program Files\Blender Foundation\Blender 4.2\blender.exe" "D:\TAU MSc\Semester 4\Thesis\Shape Editing\Procedural Chair Revisit\simple_ceiling_lamp_with_inputs.blend" -b --python dataset_generator/base_recipe_generator.py -- --recipe-file-path "D:\TAU MSc\Semester 4\Thesis\Shape Editing\Procedural Chair Revisit\recipe_ceiling_lamp.yml" --final-node-name "Realize Instances" 316 | """ 317 | -------------------------------------------------------------------------------- /dataset_generator/recipe_files/recipe_ceiling_lamp.yml: -------------------------------------------------------------------------------- 1 | base: 2 | base_roundness: 0.5400000214576721 3 | base_thickness: 0.07000000029802322 4 | base_profile_type: 1 5 | base_profile_strength: 0.6800000071525574 6 | base_x: 0.5499999523162842 7 | base_y: 0.6000000238418579 8 | light_frame_count: 8 9 | light_frame_radius: 0.7900000214576721 10 | light_frame_thickness_x: 0.03999999910593033 11 | light_frame_thickness_y: 0.019999999552965164 12 | light_frame_roundness: 0.6399999856948853 13 | light_frame_style: 0.5 14 | lightbulb_ radius: 0.07999999821186066 15 | light_frame_mid_point_pos: 0.5160000324249268 16 | light_frame_mid_point_offset_y: 0.1459999680519104 17 | light_frame_mid_point_offset_z: 0.9339999556541443 18 | light_frame_turning_effect: 0.5 19 | light_frame_end_point_y: 0.1459999978542328 20 | light_frame_end_point_z: 0.23800000548362732 21 | dataset_generation: 22 | base_roundness: 23 | min: 0.0 24 | max: 1.0 25 | samples: 6 26 | base_thickness: 27 | min: 0.009999999776482582 28 | max: 0.17000000178813934 29 | samples: 5 30 | base_profile_type: 31 | min: 0 32 | max: 3 33 | base_profile_strength: 34 | min: 0.0 35 | max: 1.0 36 | samples: 6 37 | base_x: 38 | min: 0.029999999329447746 39 | max: 1.0 40 | samples: 8 41 | base_y: 42 | min: 0.029999999329447746 43 | max: 1.0 44 | samples: 8 45 | light_frame_count: 46 | min: 1 47 | max: 8 48 | light_frame_radius: 49 | min: 0.0 50 | max: 1.0 51 | samples: 8 52 | light_frame_thickness_x: 53 | min: 0.019999999552965164 54 | max: 0.11999999731779099 55 | samples: 5 56 | light_frame_thickness_y: 57 | min: 0.019999999552965164 58 | max: 0.11999999731779099 59 | samples: 5 60 | light_frame_roundness: 61 | min: 0.0 62 | max: 1.0 63 | samples: 6 64 | light_frame_style: 65 | min: 0.0 66 | max: 1.0 67 | samples: 6 68 | lightbulb_ radius: 69 | min: 0.019999999552965164 70 | max: 0.10000000149011612 71 | samples: 6 72 | light_frame_mid_point_pos: 73 | min: 0.0 74 | max: 1.0 75 | samples: 6 76 | light_frame_mid_point_offset_y: 77 | min: 0.0 78 | max: 1.0 79 | samples: 6 80 | light_frame_mid_point_offset_z: 81 | min: 0.0 82 | max: 1.0 83 | samples: 6 84 | light_frame_turning_effect: 85 | min: 0.0 86 | max: 1.0 87 | samples: 6 88 | light_frame_end_point_y: 89 | min: 0.0 90 | max: 0.5 91 | samples: 6 92 | light_frame_end_point_z: 93 | min: 0.11999999731779099 94 | max: 0.5 95 | samples: 5 96 | camera_angles_train: 97 | - - 30.0 98 | - 35.0 99 | - - 30.0 100 | - 55.0 101 | camera_angles_test: 102 | - - 30.0 103 | - 15.0 104 | data_types: 105 | base_roundness: 106 | type: Float 107 | min: 0.0 108 | max: 1.0 109 | base_thickness: 110 | type: Float 111 | min: 0.009999999776482582 112 | max: 0.17000000178813934 113 | base_profile_type: 114 | type: Integer 115 | min: 0 116 | max: 3 117 | base_profile_strength: 118 | type: Float 119 | min: 0.0 120 | max: 1.0 121 | base_x: 122 | type: Float 123 | min: 
0.029999999329447746 124 | max: 1.0 125 | base_y: 126 | type: Float 127 | min: 0.029999999329447746 128 | max: 1.0 129 | light_frame_count: 130 | type: Integer 131 | min: 1 132 | max: 8 133 | light_frame_radius: 134 | type: Float 135 | min: 0.0 136 | max: 1.0 137 | light_frame_thickness_x: 138 | type: Float 139 | min: 0.019999999552965164 140 | max: 0.11999999731779099 141 | light_frame_thickness_y: 142 | type: Float 143 | min: 0.019999999552965164 144 | max: 0.11999999731779099 145 | light_frame_roundness: 146 | type: Float 147 | min: 0.0 148 | max: 1.0 149 | light_frame_style: 150 | type: Float 151 | min: 0.0 152 | max: 1.0 153 | lightbulb_ radius: 154 | type: Float 155 | min: 0.019999999552965164 156 | max: 0.10000000149011612 157 | light_frame_mid_point_pos: 158 | type: Float 159 | min: 0.0 160 | max: 1.0 161 | light_frame_mid_point_offset_y: 162 | type: Float 163 | min: 0.0 164 | max: 1.0 165 | light_frame_mid_point_offset_z: 166 | type: Float 167 | min: 0.0 168 | max: 1.0 169 | light_frame_turning_effect: 170 | type: Float 171 | min: 0.0 172 | max: 1.0 173 | light_frame_end_point_y: 174 | type: Float 175 | min: 0.0 176 | max: 0.5 177 | light_frame_end_point_z: 178 | type: Float 179 | min: 0.11999999731779099 180 | max: 0.5 181 | -------------------------------------------------------------------------------- /dataset_generator/recipe_files/recipe_chair.yml: -------------------------------------------------------------------------------- 1 | base: 2 | scale: 3 | x: 1.0 4 | y: 1.0 5 | z: 1.0 6 | bevel_rails: 0.0 7 | pillow_state: 1 8 | pillow_fill_edge: 1 9 | seat_shape: 0.423308789730072 10 | seat_pos: 0.5646687746047974 11 | cr_count: 5 12 | cr_scale_y: 0.7333 13 | cr_scale_z: 0.98 14 | cr_offset_bottom: 0.4071 15 | cr_offset_top: 0.7786 16 | cr_shape_1: 0.0 17 | curvature: 0.25 18 | is_top_rail: 1 19 | tr_fill_edge: 1 20 | tr_scale_y: 0.6800000071525574 21 | tr_scale_z: 1.3600000143051147 22 | tr_shape_1: 0.5525987148284912 23 | is_vertical_rail: 1 24 | vr_count: 4 25 | vr_scale_x: 0.7799999713897705 26 | vr_scale_y: 0.3499999940395355 27 | vr_shape_1: 0.0 28 | is_back_rest: 0 29 | legs_shape_1: 1.0 30 | legs_shape_2: 1.0 31 | legs_bevel: 0.5 32 | is_monoleg: 0 33 | is_monoleg_tent: 0 34 | monoleg_tent_pct: 0.4 35 | monoleg_tent_count: 4 36 | monoleg_bezier_start_x_offset: 0.0 37 | monoleg_bezier_start_handle_x_offset: 0.0 38 | monoleg_bezier_start_handle_z_pct: 0.4 39 | monoleg_bezier_end_x_offset: 0.8199999332427979 40 | monoleg_bezier_end_handle_x_offset: 0.4 41 | monoleg_bezier_end_handle_z_pct: 1.0 42 | back_frame_top_y_offset_pct: 0.0 43 | back_frame_mid_y_offset_pct: 0.0 44 | back_leg_bottom_y_offset_pct: 0.0 45 | back_leg_mid_y_offset_pct: 0.0 46 | handles_state: 0 47 | is_handles_support: 1 48 | is_handles_cusion: 1 49 | handles_base_pos_z_pct: 0.15 50 | handles_mid_pos_x_pct: 0.0 51 | handles_mid_pos_y_pct: 0.42329999804496765 52 | handles_mid_pos_z_pct: 0.5 53 | handles_edge_pos_x_pct: 0.0 54 | handles_bottom_pos_along_seat_pct: 0.4 55 | handles_profile_width: 0.9 56 | handles_profile_height: 0.9 57 | handles_support_mid_x: 0.5800000429153442 58 | handles_support_mid_y: 0.0 59 | handles_support_top_pos: 0.5 60 | handles_support_thickness: 0.8167 61 | handles_cusion_cover_pct: 1.0 62 | dataset_generation: 63 | scale: 64 | x: 65 | min: 0.5 66 | max: 2.0 67 | samples: 10 68 | y: 69 | min: 0.5 70 | max: 2.0 71 | samples: 10 72 | z: 73 | min: 0.5 74 | max: 2.0 75 | samples: 10 76 | bevel_rails: 77 | min: 0.0 78 | max: 1.0 79 | samples: 3 80 | pillow_state: 81 | min: 0 82 | 
max: 2 83 | pillow_fill_edge: 84 | min: 0 85 | max: 1 86 | seat_shape: 87 | min: 0.0 88 | max: 1.0 89 | samples: 5 90 | seat_pos: 91 | min: 0.2 92 | max: 1.0 93 | samples: 9 94 | cr_count: 95 | min: 3 96 | max: 8 97 | cr_scale_y: 98 | min: 0.5 99 | max: 1.2 100 | samples: 4 101 | cr_scale_z: 102 | min: 0.3 103 | max: 2.0 104 | samples: 6 105 | cr_offset_bottom: 106 | min: 0.15 107 | max: 0.75 108 | samples: 8 109 | cr_offset_top: 110 | min: 0.35 111 | max: 0.95 112 | samples: 8 113 | cr_shape_1: 114 | min: 0.0 115 | max: 2.0 116 | samples: 7 117 | curvature: 118 | min: 0.0 119 | max: 1.0 120 | samples: 5 121 | is_top_rail: 122 | min: 0 123 | max: 1 124 | tr_fill_edge: 125 | min: 0 126 | max: 1 127 | tr_scale_y: 128 | min: 0.5 129 | max: 1.2 130 | samples: 5 131 | tr_scale_z: 132 | min: 0.5 133 | max: 1.5 134 | samples: 5 135 | tr_shape_1: 136 | min: 0.0 137 | max: 1.0 138 | samples: 5 139 | is_vertical_rail: 140 | min: 0 141 | max: 1 142 | vr_count: 143 | min: 3 144 | max: 8 145 | vr_scale_x: 146 | min: 0.3 147 | max: 2.0 148 | samples: 6 149 | vr_scale_y: 150 | min: 0.2 151 | max: 1.0 152 | samples: 4 153 | vr_shape_1: 154 | min: 0.0 155 | max: 0.4 156 | samples: 5 157 | is_back_rest: 158 | min: 0 159 | max: 1 160 | legs_shape_1: 161 | min: 0.0 162 | max: 1.0 163 | samples: 3 164 | legs_shape_2: 165 | min: 0.0 166 | max: 1.0 167 | samples: 3 168 | legs_bevel: 169 | min: 0.0 170 | max: 1.0 171 | samples: 3 172 | is_monoleg: 173 | min: 0 174 | max: 1 175 | is_monoleg_tent: 176 | min: 0 177 | max: 1 178 | monoleg_tent_pct: 179 | min: 0.2 180 | max: 0.8 181 | samples: 7 182 | monoleg_tent_count: 183 | min: 3 184 | max: 8 185 | monoleg_bezier_start_x_offset: 186 | min: 0.0 187 | max: 1.0 188 | samples: 6 189 | monoleg_bezier_start_handle_x_offset: 190 | min: 0.0 191 | max: 1.0 192 | samples: 6 193 | monoleg_bezier_start_handle_z_pct: 194 | min: 0.0 195 | max: 1.0 196 | samples: 6 197 | monoleg_bezier_end_x_offset: 198 | min: 0.0 199 | max: 1.0 200 | samples: 6 201 | monoleg_bezier_end_handle_x_offset: 202 | min: 0.0 203 | max: 1.0 204 | samples: 6 205 | monoleg_bezier_end_handle_z_pct: 206 | min: 0.0 207 | max: 1.0 208 | samples: 6 209 | back_frame_top_y_offset_pct: 210 | min: 0.0 211 | max: 1.0 212 | samples: 6 213 | back_frame_mid_y_offset_pct: 214 | min: 0.0 215 | max: 1.0 216 | samples: 6 217 | back_leg_bottom_y_offset_pct: 218 | min: 0.0 219 | max: 1.0 220 | samples: 6 221 | back_leg_mid_y_offset_pct: 222 | min: 0.0 223 | max: 1.0 224 | samples: 6 225 | handles_state: 226 | min: 0 227 | max: 2 228 | is_handles_support: 229 | min: 0 230 | max: 1 231 | is_handles_cusion: 232 | min: 0 233 | max: 1 234 | handles_profile_width: 235 | min: 0.5 236 | max: 1.0 237 | samples: 6 238 | handles_profile_height: 239 | min: 0.5 240 | max: 1.0 241 | samples: 6 242 | handles_base_pos_z_pct: 243 | min: 0.15 244 | max: 0.8 245 | samples: 8 246 | handles_mid_pos_x_pct: 247 | min: 0.0 248 | max: 1.0 249 | samples: 7 250 | handles_mid_pos_y_pct: 251 | min: 0.1 252 | max: 0.9 253 | samples: 7 254 | handles_mid_pos_z_pct: 255 | min: 0.0 256 | max: 1.0 257 | samples: 7 258 | handles_edge_pos_x_pct: 259 | min: 0.0 260 | max: 1.0 261 | samples: 6 262 | handles_bottom_pos_along_seat_pct: 263 | min: 0.3 264 | max: 0.9 265 | samples: 7 266 | handles_support_mid_x: 267 | min: 0.0 268 | max: 1.0 269 | samples: 6 270 | handles_support_mid_y: 271 | min: 0.0 272 | max: 1.0 273 | samples: 6 274 | handles_support_top_pos: 275 | min: 0.3 276 | max: 0.9 277 | samples: 7 278 | handles_support_thickness: 279 | min: 0.4 280 | 
max: 0.9 281 | samples: 7 282 | handles_cusion_cover_pct: 283 | min: 0.0 284 | max: 1.0 285 | samples: 6 286 | constraints: 287 | rule1: cr_offset_top - cr_offset_bottom >= 0.1 288 | rule2: monoleg_bezier_end_x_offset >= monoleg_bezier_start_x_offset 289 | visibility_conditions: 290 | bevel_rails: ( not is_back_rest ) 291 | cr_: ( not is_back_rest ) and ( not is_top_rail or not is_vertical_rail ) 292 | vr_: ( not is_back_rest ) and is_top_rail and is_vertical_rail 293 | tr_: is_top_rail 294 | is_vertical_rail: ( not is_back_rest ) and ( is_top_rail ) 295 | legs_shape_: ( is_monoleg and is_monoleg_tent ) or ( not is_monoleg ) 296 | is_monoleg_tent: is_monoleg 297 | monoleg_tent_pct: is_monoleg and is_monoleg_tent 298 | monoleg_tent_count: is_monoleg and is_monoleg_tent 299 | monoleg_bezier: is_monoleg 300 | pillow_fill_edge: pillow_state > 0 301 | back_leg_: ( not is_monoleg ) 302 | is_handles_support: handles_state == 1 303 | handles_support_: ( handles_state == 1 ) and ( is_handles_support ) 304 | is_handles_cusion: handles_state > 0 305 | handles_cusion_cover_pct: ( handles_state > 0 ) and ( is_handles_cusion ) 306 | handles_profile_: handles_state > 0 307 | handles_base_: handles_state > 0 308 | handles_mid_: handles_state > 0 309 | handles_edge_: handles_state == 1 310 | handles_bottom_pos_along_seat_pct: ( handles_state == 2 ) or ( handles_state == 1 and is_handles_support == 1 ) 311 | camera_angles_train: 312 | - - -30.0 313 | - 35.0 314 | - - -30.0 315 | - 55.0 316 | camera_angles_test: 317 | - - -30.0 318 | - 15.0 319 | -------------------------------------------------------------------------------- /dataset_generator/recipe_files/recipe_table.yml: -------------------------------------------------------------------------------- 1 | base: 2 | table_top_scale_x: 2.5999999046325684 3 | table_top_scale_y: 2.5999999046325684 4 | table_top_height: 0.0 5 | table_top_shape: 1.0 6 | table_top_thickness: 0.0 7 | table_top_profile_state: 2 8 | table_top_profile_strength: 0.0 9 | legs_shape_1: 1.0 10 | legs_shape_2: 1.0 11 | legs_bevel: 0.0 12 | std_legs_bottom_offset_y: 1.0 13 | std_legs_mid_offset_y: 0.0 14 | std_legs_top_offset_x: 0.26999998092651367 15 | std_legs_top_offset_y: 0.0 16 | std_legs_rotation: 0.0 17 | is_std_legs_support_x: 1 18 | std_legs_support_x_height: 0.7200000286102295 19 | std_legs_support_x_curvature: 0.0 20 | std_legs_support_x_profile_width: 1.0 21 | std_legs_support_x_profile_height: 0.27000004053115845 22 | is_std_legs_support_y: 1 23 | std_legs_support_y_height: 0.36000001430511475 24 | std_legs_support_y_curvature: 0.0 25 | std_legs_support_y_profile_width: 1.0 26 | std_legs_support_y_profile_height: 1.0 27 | is_monoleg: 0 28 | is_monoleg_tent: 1 29 | monoleg_tent_pct: 0.6649518609046936 30 | monoleg_tent_base_radius: 0.0 31 | monoleg_tent_count: 5 32 | monoleg_bezier_start_x_offset: 0.5273312330245972 33 | monoleg_bezier_start_handle_x_offset: 1.0 34 | monoleg_bezier_start_handle_z_pct: 0.46000000834465027 35 | monoleg_bezier_end_x_offset: 0.2499999850988388 36 | monoleg_bezier_end_handle_x_offset: 0.10999999940395355 37 | monoleg_bezier_end_handle_z_pct: 0.20999999344348907 38 | dataset_generation: 39 | table_top_scale_x: 40 | min: 0.6 41 | max: 2.6 42 | samples: 12 43 | table_top_scale_y: 44 | min: 0.6 45 | max: 2.6 46 | samples: 12 47 | table_top_height: 48 | min: 0.0 49 | max: 1.0 50 | samples: 8 51 | table_top_shape: 52 | min: 0.0 53 | max: 1.0 54 | samples: 11 55 | table_top_thickness: 56 | min: 0.0 57 | max: 1.0 58 | samples: 6 59 | 
table_top_profile_state: 60 | min: 0 61 | max: 3 62 | table_top_profile_strength: 63 | min: 0.0 64 | max: 1.0 65 | samples: 6 66 | legs_shape_1: 67 | min: 0.0 68 | max: 1.0 69 | samples: 3 70 | legs_shape_2: 71 | min: 0.0 72 | max: 1.0 73 | samples: 3 74 | legs_bevel: 75 | min: 0.0 76 | max: 1.0 77 | samples: 3 78 | std_legs_bottom_offset_y: 79 | min: 0.0 80 | max: 1.0 81 | samples: 6 82 | std_legs_mid_offset_y: 83 | min: 0.0 84 | max: 1.0 85 | samples: 6 86 | std_legs_top_offset_x: 87 | min: 0.0 88 | max: 1.0 89 | samples: 6 90 | std_legs_top_offset_y: 91 | min: 0.0 92 | max: 1.0 93 | samples: 6 94 | std_legs_rotation: 95 | min: 0.0 96 | max: 1.0 97 | samples: 6 98 | is_std_legs_support_x: 99 | min: 0 100 | max: 1 101 | std_legs_support_x_height: 102 | min: 0.0 103 | max: 1.0 104 | samples: 6 105 | std_legs_support_x_curvature: 106 | min: 0.0 107 | max: 1.0 108 | samples: 6 109 | std_legs_support_x_profile_width: 110 | min: 0.0 111 | max: 1.0 112 | samples: 5 113 | std_legs_support_x_profile_height: 114 | min: 0.0 115 | max: 1.0 116 | samples: 5 117 | is_std_legs_support_y: 118 | min: 0 119 | max: 1 120 | std_legs_support_y_height: 121 | min: 0.0 122 | max: 1.0 123 | samples: 6 124 | std_legs_support_y_curvature: 125 | min: 0.0 126 | max: 1.0 127 | samples: 6 128 | std_legs_support_y_profile_width: 129 | min: 0.0 130 | max: 1.0 131 | samples: 5 132 | std_legs_support_y_profile_height: 133 | min: 0.0 134 | max: 1.0 135 | samples: 5 136 | is_monoleg: 137 | min: 0 138 | max: 1 139 | is_monoleg_tent: 140 | min: 0 141 | max: 1 142 | monoleg_tent_pct: 143 | min: 0.2 144 | max: 0.8 145 | samples: 7 146 | monoleg_tent_base_radius: 147 | min: 0.0 148 | max: 1.0 149 | samples: 11 150 | monoleg_tent_count: 151 | min: 3 152 | max: 8 153 | monoleg_bezier_start_x_offset: 154 | min: 0.0 155 | max: 1.0 156 | samples: 6 157 | monoleg_bezier_start_handle_x_offset: 158 | min: 0.0 159 | max: 1.0 160 | samples: 6 161 | monoleg_bezier_start_handle_z_pct: 162 | min: 0.0 163 | max: 1.0 164 | samples: 6 165 | monoleg_bezier_end_x_offset: 166 | min: 0.2 167 | max: 1.0 168 | samples: 5 169 | monoleg_bezier_end_handle_x_offset: 170 | min: 0.0 171 | max: 1.0 172 | samples: 6 173 | monoleg_bezier_end_handle_z_pct: 174 | min: 0.0 175 | max: 1.0 176 | samples: 6 177 | constraints: 178 | rule1: monoleg_bezier_end_x_offset >= monoleg_bezier_start_x_offset 179 | visibility_conditions: 180 | legs_shape_: ( is_monoleg and is_monoleg_tent ) or ( not is_monoleg ) 181 | legs_bevel: ( is_monoleg and is_monoleg_tent ) or ( not is_monoleg ) 182 | is_monoleg_tent: is_monoleg 183 | monoleg_tent_pct: is_monoleg and is_monoleg_tent 184 | monoleg_tent_count: is_monoleg and is_monoleg_tent 185 | monoleg_tent_base_radius: is_monoleg and is_monoleg_tent 186 | monoleg_bezier: is_monoleg and ( not is_monoleg_tent or monoleg_tent_pct < 0.4 ) 187 | std_legs_support_x_: ( not is_monoleg ) and is_std_legs_support_x 188 | std_legs_support_y_: ( not is_monoleg ) and is_std_legs_support_y 189 | std_legs: ( not is_monoleg ) 190 | table_top_profile_strength: table_top_profile_state > 0 191 | camera_angles_train: 192 | - - -30.0 193 | - 35.0 194 | - - -30.0 195 | - 55.0 196 | camera_angles_test: 197 | - - -30.0 198 | - 15.0 199 | -------------------------------------------------------------------------------- /dataset_generator/recipe_files/recipe_vase.yml: -------------------------------------------------------------------------------- 1 | base: 2 | body_height: 1.2143 3 | body_width: 0.2944 4 | body_bottom_curve_width: 0.0 5 | 
body_bottom_curve_height: 0.6 6 | body_mouth_width: 0.6 7 | body_top_curve_width: 0.2 8 | body_top_curve_height: 0.9 9 | body_profile_blend: 0.8 10 | has_body_thickness: 0 11 | body_thickness_val: 0.05999999865889549 12 | handle_count: 5 13 | hndl_type: 2 14 | hndl_profile_width: 1.0 15 | hndl_profile_height: 0.2 16 | hndl_profile_blend: 0.8 17 | hndl_base_z: 0.4699999988079071 18 | hndl_base_bezier_handle_angle: 0.4 19 | hndl_base_bezier_handle_length: 0.2 20 | hndl_radius_along_path: 0.0 21 | hndl1_top_z: 0.5 22 | hndl1_end_bezier_handle_angle: 0.2 23 | hndl1_end_bezier_handle_length: 0.6 24 | hndl2_end_x: 0.0 25 | hndl2_end_z: 0.4099999964237213 26 | hndl2_end_bezier_handle_x: 0.0 27 | hndl2_end_bezier_handle_z: 0.2 28 | has_neck: 1 29 | neck_end_x: 0.8 30 | neck_end_z: 0.8 31 | neck_end_bezier_handle_x: 0.8 32 | neck_end_bezier_handle_z: 0.0 33 | has_base: 1 34 | base_start_x: 0.7 35 | base_start_z: 0.6 36 | base_mid_x: 0.2 37 | base_mid_z: 0.8 38 | has_lid: 0 39 | has_lid_handle: 1 40 | lid_handle_radius: 0.02 41 | dataset_generation: 42 | body_height: 43 | min: 0.5 44 | max: 1.5 45 | samples: 15 46 | body_width: 47 | min: 0.1 48 | max: 0.45 49 | samples: 10 50 | body_bottom_curve_width: 51 | min: 0.0 52 | max: 1.0 53 | samples: 11 54 | body_bottom_curve_height: 55 | min: 0.1 56 | max: 0.9 57 | samples: 9 58 | body_mouth_width: 59 | min: 0.0 60 | max: 0.7 61 | samples: 8 62 | body_top_curve_width: 63 | min: 0.0 64 | max: 0.9 65 | samples: 10 66 | body_top_curve_height: 67 | min: 0.1 68 | max: 0.9 69 | samples: 9 70 | body_profile_blend: 71 | min: 0.0 72 | max: 1.0 73 | samples: 11 74 | has_body_thickness: 75 | min: 0 76 | max: 1 77 | body_thickness_val: 78 | min: 0.01 79 | max: 0.07 80 | samples: 7 81 | handle_count: 82 | min: 0 83 | max: 6 84 | hndl_type: 85 | min: 1 86 | max: 2 87 | hndl_profile_width: 88 | min: 0.0 89 | max: 1.0 90 | samples: 6 91 | hndl_profile_height: 92 | min: 0.0 93 | max: 1.0 94 | samples: 6 95 | hndl_profile_blend: 96 | min: 0.0 97 | max: 1.0 98 | samples: 6 99 | hndl_base_z: 100 | min: 0.1 101 | max: 0.6 102 | samples: 6 103 | hndl_base_bezier_handle_angle: 104 | min: 0.0 105 | max: 1.0 106 | samples: 11 107 | hndl_base_bezier_handle_length: 108 | min: 0.0 109 | max: 1.0 110 | samples: 6 111 | hndl_radius_along_path: 112 | min: 0.0 113 | max: 1.0 114 | samples: 11 115 | hndl1_top_z: 116 | min: 0.2 117 | max: 0.8 118 | samples: 7 119 | hndl1_end_bezier_handle_angle: 120 | min: 0.0 121 | max: 1.0 122 | samples: 11 123 | hndl1_end_bezier_handle_length: 124 | min: 0.0 125 | max: 1.0 126 | samples: 6 127 | hndl2_end_x: 128 | min: 0.0 129 | max: 1.0 130 | samples: 11 131 | hndl2_end_z: 132 | min: 0.0 133 | max: 1.0 134 | samples: 11 135 | hndl2_end_bezier_handle_x: 136 | min: 0.0 137 | max: 1.0 138 | samples: 11 139 | hndl2_end_bezier_handle_z: 140 | min: 0.1 141 | max: 1.0 142 | samples: 10 143 | has_neck: 144 | min: 0 145 | max: 1 146 | neck_end_x: 147 | min: 0.1 148 | max: 1.0 149 | samples: 10 150 | neck_end_z: 151 | min: 0.0 152 | max: 1.0 153 | samples: 11 154 | neck_end_bezier_handle_x: 155 | min: 0.0 156 | max: 1.0 157 | samples: 11 158 | neck_end_bezier_handle_z: 159 | min: 0.0 160 | max: 1.0 161 | samples: 11 162 | has_base: 163 | min: 0 164 | max: 1 165 | base_start_x: 166 | min: 0.0 167 | max: 1.0 168 | samples: 11 169 | base_start_z: 170 | min: 0.0 171 | max: 1.0 172 | samples: 6 173 | base_mid_x: 174 | min: 0.0 175 | max: 1.0 176 | samples: 11 177 | base_mid_z: 178 | min: 0.0 179 | max: 1.0 180 | samples: 6 181 | has_lid: 182 | min: 0 183 | max: 1 
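  # A hedged reading of this recipe (annotation, not part of the original file): every entry under
  # dataset_generation gives the sweep range for one geometry-node input. Integer parameters appear
  # to be enumerated over every value in [min, max], while Float parameters also carry a `samples`
  # count that appears to set how many evenly spaced values are drawn from [min, max]; under that
  # reading, lid_handle_radius below (min 0.02, max 0.07, samples 6) would produce
  # 0.02, 0.03, 0.04, 0.05, 0.06, 0.07. The keys under visibility_conditions appear to act as
  # parameter-name prefixes, e.g. `hndl_` covers every parameter whose name starts with hndl_.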
184 | has_lid_handle: 185 | min: 0 186 | max: 1 187 | lid_handle_radius: 188 | min: 0.02 189 | max: 0.07 190 | samples: 6 191 | visibility_conditions: 192 | body_thickness_val: has_body_thickness and not has_lid 193 | hndl_: handle_count > 0 194 | hndl1_: handle_count > 0 and hndl_type == 1 195 | hndl2_: handle_count > 0 and hndl_type == 2 196 | neck_: has_neck 197 | base_start_: has_base 198 | base_mid_: has_base 199 | has_lid_handle: has_lid 200 | lid_handle_radius: has_lid and has_lid_handle 201 | camera_angles_train: 202 | - - -30.0 203 | - 35.0 204 | - - -30.0 205 | - 55.0 206 | camera_angles_test: 207 | - - -30.0 208 | - 15.0 209 | -------------------------------------------------------------------------------- /dataset_generator/shape_validators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/dataset_generator/shape_validators/__init__.py -------------------------------------------------------------------------------- /dataset_generator/shape_validators/ceiling_lamp_validator.py: -------------------------------------------------------------------------------- 1 | from dataset_generator.shape_validators.shape_validator_interface import ShapeValidatorInterface 2 | from common.intersection_util import find_self_intersections 3 | 4 | 5 | class CeilingLampValidator(ShapeValidatorInterface): 6 | def validate_shape(self, input_params_map) -> (bool, str): 7 | frames_self_intersections = find_self_intersections('frames') 8 | if frames_self_intersections > 0: 9 | return False, "Self intersection between the frames" 10 | light_bulbs_self_intersections = find_self_intersections('light_bulbs') 11 | if light_bulbs_self_intersections > 0: 12 | return False, "Self intersection between the light bulbs" 13 | return True, "Valid" 14 | -------------------------------------------------------------------------------- /dataset_generator/shape_validators/chair_validator.py: -------------------------------------------------------------------------------- 1 | from dataset_generator.shape_validators.shape_validator_interface import ShapeValidatorInterface 2 | from dataset_generator.shape_validators.common_validations import validate_monoleg 3 | from common.intersection_util import find_self_intersections, find_cross_intersections 4 | 5 | 6 | class ChairValidator(ShapeValidatorInterface): 7 | def validate_shape(self, input_params_map) -> (bool, str): 8 | if input_params_map['is_back_rest'].get_value() == 0 \ 9 | and input_params_map['is_top_rail'].get_value() == 1 \ 10 | and input_params_map['is_vertical_rail'].get_value() == 1: 11 | # reaching here means the vertical rails are visible 12 | # also note the assumption that the min vertical rails count is 3 13 | if find_self_intersections('vertical_rails_out') > 0: 14 | # try again, as we have vertical rails intersecting each other 15 | return False, "Vertical rails intersect themselves" 16 | if input_params_map['is_back_rest'].get_value() == 0 \ 17 | and (input_params_map['is_top_rail'].get_value() == 0 18 | or input_params_map['is_vertical_rail'].get_value() == 0): 19 | # reaching here means the cross rails are visible 20 | # also note the assumption that the min cross rails count is 3 21 | if find_self_intersections('cross_rails_and_top_rail_out') > 0: 22 | # try again, as we have cross rails intersecting each other or the top rail 23 | return False, "Cross rails intersect with top rail" 24 | if 
input_params_map['handles_state'].get_value() == 1 and input_params_map['is_handles_support'].get_value(): 25 | if find_self_intersections('handles_support_and_back_frame') > 0: 26 | return False, "Handles' supports intersect with the back frame" 27 | if input_params_map['handles_state'].get_value() > 0: 28 | if find_cross_intersections('handles_left_side', 'handles_right_side') > 0: 29 | # the handles on both sides of the chair should never intersect 30 | return False, "Left handles intersect with the right handles" 31 | if input_params_map['is_monoleg'].get_value() > 0 and input_params_map['is_monoleg_tent'].get_value() == 0: 32 | if not validate_monoleg('monoleg'): 33 | return False, "Monoleg Center of Mass" 34 | return True, "Valid" 35 | -------------------------------------------------------------------------------- /dataset_generator/shape_validators/common_validations.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import numpy as np 3 | from common.file_util import load_obj 4 | from common.bpy_util import select_shape 5 | from common.intersection_util import isolate_node_as_final_geometry 6 | 7 | 8 | def triangle_area(x): 9 | a = x[:, 0, :] - x[:, 1, :] 10 | b = x[:, 0, :] - x[:, 2, :] 11 | cross = np.cross(a, b) 12 | area = 0.5 * np.linalg.norm(cross, axis=1) 13 | return area 14 | 15 | 16 | def object_sanity_check(obj_file): 17 | try: 18 | vertices, faces = load_obj(obj_file) 19 | vertices = vertices.reshape(1, vertices.shape[0], vertices.shape[1]) 20 | faces = vertices.squeeze()[faces] 21 | triangle_area(faces) 22 | except Exception: 23 | print('Invalid sample') 24 | return False 25 | return True 26 | 27 | 28 | def validate_monoleg(node_label, factor=0.08): 29 | chair = select_shape() 30 | revert_isolation = isolate_node_as_final_geometry(chair, node_label) 31 | 32 | dup_obj = chair.copy() 33 | dup_obj.data = chair.data.copy() 34 | dup_obj.animation_data_clear() 35 | bpy.context.collection.objects.link(dup_obj) 36 | # move for clarity 37 | dup_obj.location.x += 2.0 38 | # set active 39 | bpy.ops.object.select_all(action='DESELECT') 40 | dup_obj.select_set(True) 41 | bpy.context.view_layer.objects.active = dup_obj 42 | # apply the modifier to turn the geometry node to a mesh 43 | bpy.ops.object.modifier_apply(modifier="GeometryNodes") 44 | # export the object 45 | assert dup_obj.type == 'MESH' 46 | 47 | revert_isolation() 48 | 49 | bpy.ops.object.origin_set(type='ORIGIN_GEOMETRY', center='BOUNDS') 50 | center_of_volume = dup_obj.location[2] 51 | # another option is to use (type='ORIGIN_CENTER_OF_MASS', center='MEDIAN') as the center of mass 52 | bpy.ops.object.origin_set(type='ORIGIN_CENTER_OF_VOLUME', center='MEDIAN') 53 | center_of_mass = dup_obj.location[2] 54 | height = dup_obj.dimensions[2] 55 | 56 | if center_of_volume - center_of_mass > factor * height: 57 | return True 58 | return False 59 | -------------------------------------------------------------------------------- /dataset_generator/shape_validators/shape_validator_factory.py: -------------------------------------------------------------------------------- 1 | from common.domain import Domain 2 | from dataset_generator.shape_validators.shape_validator_interface import ShapeValidatorInterface 3 | from dataset_generator.shape_validators.chair_validator import ChairValidator 4 | from dataset_generator.shape_validators.vase_validator import VaseValidator 5 | from dataset_generator.shape_validators.table_validator import TableValidator 6 | from 
dataset_generator.shape_validators.ceiling_lamp_validator import CeilingLampValidator 7 | 8 | 9 | class ShapeValidatorFactory: 10 | @staticmethod 11 | def create_validator(domain) -> ShapeValidatorInterface: 12 | if domain == Domain.chair: 13 | return ChairValidator() 14 | elif domain == Domain.vase: 15 | return VaseValidator() 16 | elif domain == Domain.table: 17 | return TableValidator() 18 | elif domain == Domain.ceiling_lamp: 19 | return CeilingLampValidator() 20 | else: 21 | raise Exception(f"Domain [{domain}] is not recognized.") 22 | -------------------------------------------------------------------------------- /dataset_generator/shape_validators/shape_validator_interface.py: -------------------------------------------------------------------------------- 1 | class ShapeValidatorInterface: 2 | def validate_shape(self, input_params_map) -> (bool, str): 3 | """validate the shape and return True if valid""" 4 | pass 5 | -------------------------------------------------------------------------------- /dataset_generator/shape_validators/table_validator.py: -------------------------------------------------------------------------------- 1 | from dataset_generator.shape_validators.shape_validator_interface import ShapeValidatorInterface 2 | from dataset_generator.shape_validators.common_validations import validate_monoleg 3 | from common.intersection_util import find_self_intersections 4 | 5 | 6 | class TableValidator(ShapeValidatorInterface): 7 | def validate_shape(self, input_params_map) -> (bool, str): 8 | table_top_and_legs_support_intersections = find_self_intersections('table_top_and_legs_support') 9 | if table_top_and_legs_support_intersections > 0: 10 | return False, "Table top intersects with the legs supports" 11 | floor_and_legs_support_intersections = find_self_intersections('floor_and_legs_support') 12 | if floor_and_legs_support_intersections > 0: 13 | return False, "Legs supports intersect with the floor" 14 | if input_params_map['is_monoleg'].get_value() > 0 and input_params_map['is_monoleg_tent'].get_value() == 0: 15 | if not validate_monoleg('monoleg', factor=0.16): 16 | # the factor is more restricting since the tables can be much wider than chairs 17 | return False, "Invalid monoleg" 18 | return True, "Valid" 19 | -------------------------------------------------------------------------------- /dataset_generator/shape_validators/vase_validator.py: -------------------------------------------------------------------------------- 1 | from dataset_generator.shape_validators.shape_validator_interface import ShapeValidatorInterface 2 | from common.intersection_util import find_self_intersections, find_cross_intersections 3 | 4 | 5 | class VaseValidator(ShapeValidatorInterface): 6 | def validate_shape(self, input_params_map) -> (bool, str): 7 | body_self_intersections = find_self_intersections('Body Self Intersections') 8 | if body_self_intersections > 0: 9 | return False, "Self intersection in the body" 10 | if input_params_map['handle_count'].get_value() > 0: 11 | handle_self_intersections = find_self_intersections('Handle Self Intersections') 12 | if handle_self_intersections > 0: 13 | return False, "self intersection in the handle" 14 | base_handle_intersections = find_self_intersections('Base and Handle Intersections') 15 | if base_handle_intersections > 0: 16 | return False, "Base intersects with handles" 17 | floor_handle_intersections = find_self_intersections('Floor and Handle Intersections') 18 | if floor_handle_intersections > 0: 19 | return False, "Floor 
intersects with handles" 20 | return True, "Valid" 21 | -------------------------------------------------------------------------------- /dataset_generator/sketch_generator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | import traceback 5 | import bpy 6 | import time 7 | from mathutils import Vector 8 | import math 9 | import mathutils 10 | import random 11 | import argparse 12 | from tqdm import tqdm 13 | from pathlib import Path 14 | import importlib 15 | 16 | def import_parents(level=1): 17 | global __package__ 18 | file = Path(__file__).resolve() 19 | parent, top = file.parent, file.parents[level] 20 | 21 | sys.path.append(str(top)) 22 | try: 23 | sys.path.remove(str(parent)) 24 | except ValueError: 25 | pass 26 | 27 | __package__ = '.'.join(parent.parts[len(top.parts):]) 28 | importlib.import_module(__package__) 29 | 30 | if __name__ == '__main__' and __package__ is None: 31 | import_parents(level=1) 32 | 33 | from common.bpy_util import normalize_scale, look_at, del_obj, clean_scene, use_gpu_if_available 34 | from common.file_util import get_recipe_yml_obj, hash_file_name 35 | 36 | 37 | """ 38 | Shader references: 39 | pencil shader - https://www.youtube.com/watch?v=71KGlu_Yxtg 40 | white background (compositing) - https://www.youtube.com/watch?v=aegiN7XeLow 41 | creating transparent object - https://www.katsbits.com/codex/transparency-cycles/ 42 | """ 43 | 44 | 45 | def main(dataset_dir: Path, phase, parallel, mod): 46 | try: 47 | clean_scene() 48 | 49 | use_gpu_if_available() # also switches to Cycles 50 | 51 | # setup to avoid rendering surfaces and only render the freestyle curves 52 | bpy.context.view_layer.use_pass_z = False 53 | bpy.context.view_layer.use_pass_combined = False 54 | bpy.context.view_layer.use_sky = False 55 | bpy.context.view_layer.use_solid = False 56 | bpy.context.view_layer.use_volumes = False 57 | bpy.context.view_layer.use_strand = True # freestyle curves 58 | bpy.context.scene.render.use_freestyle = True 59 | bpy.context.scene.render.film_transparent = True 60 | bpy.context.scene.render.image_settings.color_mode = 'RGB' 61 | bpy.context.scene.view_settings.view_transform = 'Standard' 62 | 63 | if "Along Stroke" not in bpy.data.linestyles['LineStyle'].thickness_modifiers: 64 | bpy.ops.scene.freestyle_thickness_modifier_add(type='ALONG_STROKE') 65 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].mapping = 'CURVE' 66 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points[0].location = (0.0, 0.44375) 67 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.031879, 0.6875) 68 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.088926, 0.8625) 69 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.104027, 0.918751) 70 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.213087, 0.5875) 71 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.315436, 0.887501) 72 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.404362, 0.64375) 73 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.463088, 0.55625) 74 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along 
Stroke"].curve.curves[0].points.new(0.520134, 0.7125) 75 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.545302, 0.975001) 76 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.630872, 0.7) 77 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.778523, 0.76875) 78 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points.new(0.892618, 0.55) 79 | bpy.data.linestyles["LineStyle"].thickness_modifiers["Along Stroke"].curve.curves[0].points[-1].location = (1.0, 0.70625) 80 | 81 | # compositing 82 | if not bpy.context.scene.use_nodes: 83 | bpy.context.scene.use_nodes = True 84 | render_layers_node = bpy.context.scene.node_tree.nodes['Render Layers'] 85 | composite_node = bpy.context.scene.node_tree.nodes['Composite'] 86 | alpha_over_node = bpy.context.scene.node_tree.nodes.new('CompositorNodeAlphaOver') 87 | bpy.context.scene.node_tree.links.new(render_layers_node.outputs['Image'], alpha_over_node.inputs[2]) 88 | bpy.context.scene.node_tree.links.new(alpha_over_node.outputs['Image'], composite_node.inputs['Image']) 89 | 90 | recipe_file_path = dataset_dir.joinpath('recipe.yml') 91 | recipe_yml_obj = get_recipe_yml_obj(recipe_file_path) 92 | camera_angles = recipe_yml_obj['camera_angles_train'] + recipe_yml_obj['camera_angles_test'] 93 | # euler setting 94 | radius = 2 95 | eulers = [mathutils.Euler((math.radians(camera_angle[0]), 0.0, math.radians(camera_angle[1])), 'XYZ') for camera_angle in camera_angles] 96 | 97 | obj_gt_dir = dataset_dir.joinpath(phase, 'obj_gt') 98 | path_to_sketches = dataset_dir.joinpath(phase, 'sketches') # output folder 99 | if (parallel == 1 or mod == 0) and not path_to_sketches.is_dir(): 100 | path_to_sketches.mkdir() 101 | 102 | if parallel == 1 and mod != 0: 103 | while not path_to_sketches.is_dir(): 104 | time.sleep(2) 105 | 106 | obj_files = sorted(obj_gt_dir.glob('*.obj')) 107 | # filter out files that were already processed 108 | obj_files = [file for file in obj_files if 109 | not all( 110 | list(path_to_sketches.glob(f'{file.stem}_{camera_angle[0]}_{camera_angle[1]}.png')) 111 | for camera_angle in camera_angles)] 112 | # remove any file that is not handled in this job 113 | if parallel > 1: 114 | obj_files = [file for file in obj_files if hash_file_name(file.name) % parallel == mod] 115 | 116 | for obj_file in tqdm(obj_files): 117 | file_name = obj_file.name 118 | 119 | filepath = obj_gt_dir.joinpath(file_name) 120 | bpy.ops.wm.obj_import(filepath=str(filepath), forward_axis='NEGATIVE_Z', up_axis='Y', filter_glob="*.obj;*.mtl") 121 | obj = bpy.context.selected_objects[0] 122 | 123 | # normalize the object 124 | normalize_scale(obj) 125 | 126 | for i, eul in enumerate(eulers): 127 | filename_no_ext = obj_file.stem 128 | target_file_name = f"{filename_no_ext}_{camera_angles[i][0]:.1f}_{camera_angles[i][1]:.1f}.png" 129 | target_file = path_to_sketches.joinpath(target_file_name) 130 | if target_file.is_file(): 131 | continue 132 | 133 | # camera setting 134 | cam_pos = mathutils.Vector((0.0, -radius, 0.0)) 135 | cam_pos.rotate(eul) 136 | if i < 4: 137 | # camera position perturbation 138 | rand_x = random.uniform(-2.0, 2.0) 139 | rand_z = random.uniform(-3.0, 3.0) 140 | eul_perturb = mathutils.Euler((math.radians(rand_x), 0.0, math.radians(rand_z)), 'XYZ') 141 | cam_pos.rotate(eul_perturb) 142 | 143 | scene = bpy.context.scene 144 | bpy.ops.object.camera_add(enter_editmode=False, 
location=cam_pos) 145 | new_camera = bpy.context.active_object 146 | new_camera.name = "camera_tmp" 147 | new_camera.data.name = "camera_tmp" 148 | new_camera.data.lens_unit = 'FOV' 149 | new_camera.data.angle = math.radians(60) 150 | look_at(new_camera, Vector((0.0, 0.0, 0.0))) 151 | 152 | # render 153 | scene.camera = new_camera 154 | scene.render.filepath = str(target_file) 155 | scene.render.resolution_x = 224 156 | scene.render.resolution_y = 224 157 | bpy.context.scene.cycles.samples = 10 158 | # debug 159 | if False: 160 | debug_file_path = dataset_dir / "gen_sketch_debug.blend" 161 | print(f"saving scene to [{debug_file_path}]") 162 | bpy.ops.wm.save_as_mainfile(filepath=str(debug_file_path)) 163 | return 164 | bpy.ops.render.render(write_still=True) 165 | 166 | # prepare for the next camera 167 | del_obj(new_camera) 168 | 169 | # delete the obj to prepare for the next one 170 | del_obj(obj) 171 | 172 | # clean the scene 173 | clean_scene() 174 | except Exception as e: 175 | print(repr(e)) 176 | print(traceback.format_exc()) 177 | 178 | 179 | if __name__ == "__main__": 180 | argv = sys.argv 181 | if '--' in sys.argv: 182 | # refer to https://b3d.interplanety.org/en/how-to-pass-command-line-arguments-to-a-blender-python-script-or-add-on/ 183 | argv = sys.argv[sys.argv.index('--') + 1:] 184 | else: 185 | raise Exception("Expected \'--\' followed by arguments to the script") 186 | 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument('--dataset-dir', type=str, required=True, help='Path to dataset directory') 189 | parser.add_argument('--parallel', type=int, default=1, help='Number of processes that will run the script') 190 | parser.add_argument('--mod', type=int, default=0, help='The modulo for this process to match files\' hash') 191 | parser.add_argument('--phases', type=str, required=True, nargs='+', help='List of phases to generate the sketches for') 192 | 193 | args = parser.parse_args(argv) 194 | 195 | # hide the main collections (if it is already hidden, there is no effect) 196 | bpy.context.view_layer.layer_collection.children['Main'].hide_viewport = True 197 | bpy.context.view_layer.layer_collection.children['Main'].exclude = True 198 | 199 | dataset_dir = Path(args.dataset_dir).expanduser() 200 | phases = args.phases 201 | for phase in phases: 202 | main(dataset_dir, phase, args.parallel, args.mod) 203 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: geocode 2 | channels: 3 | - pytorch 4 | - defaults 5 | - dglteam 6 | - bottler 7 | - conda-forge 8 | - fvcore 9 | - iopath 10 | - pytorch3d 11 | dependencies: 12 | - python=3.8 13 | - pytorch=1.10.2 14 | - torchvision 15 | - numpy=1.21.2 16 | - matplotlib=3.5.1 17 | - pytorch-lightning=1.5.10 18 | - neptune-client=0.16.4 19 | - nvidiacub=1.10.0 20 | - dgl=0.9.1 21 | - scikit-image=0.19.2 22 | - iopath=0.1.9 23 | - fvcore=0.1.5.post20210915 24 | - tqdm>=4.64.0 25 | - pytorch3d 26 | - pip 27 | - pip: 28 | - git+https://github.com/otaheri/chamfer_distance 29 | -------------------------------------------------------------------------------- /geocode/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/geocode/__init__.py -------------------------------------------------------------------------------- /geocode/barplot_util.py: 
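The next file renders per-parameter accuracy bar plots from a JSON summary produced during evaluation. As a minimal sketch of the JSON layout it expects (the key names follow the code below, while the parameter names, counts, and file path are invented for illustration):

import json
from pathlib import Path

# Invented example: accuracy counts for two parameters over point-cloud and sketch test sets.
barplot_data = {
    "inputs_to_eval": ["seat_width", "back_height"],   # parameter names, one bar per parameter
    "correct_arr_pc": [90, 75],                        # correct predictions per parameter (point clouds)
    "correct_arr_sketch": [80, 70],                    # correct predictions per parameter (sketches)
    "total_pc": 100,                                   # number of point-cloud test samples
    "total_sketch": 100,                               # number of sketch test samples
}

barplot_json_path = Path("/tmp/barplot_example.json")  # made-up location
with open(barplot_json_path, "w") as barplot_json_file:
    json.dump(barplot_data, barplot_json_file)

# gen_and_save_barplot (defined below) reads this JSON, sorts the parameters by average accuracy,
# and saves the figure when a target image path is given:
# fig = gen_and_save_barplot(barplot_json_path, "Example accuracy", barplot_target_image_path="/tmp/barplot.png")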
-------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | 5 | def gen_and_save_barplot(barplot_json_path, title, barplot_target_image_path=None): 6 | from matplotlib import pyplot as plt 7 | with open(barplot_json_path, 'r') as barplot_json_file: 8 | data = json.load(barplot_json_file) 9 | 10 | inputs_to_eval = data['inputs_to_eval'] 11 | correct_arr_pc = data['correct_arr_pc'] 12 | correct_arr_sketch = data['correct_arr_sketch'] 13 | total_pc = data['total_pc'] 14 | total_sketch = data['total_sketch'] 15 | 16 | correct = [a + b for a, b in zip(correct_arr_pc, correct_arr_sketch)] 17 | accuracy_avg = [a / (total_pc + total_sketch) for a in correct] 18 | accuracy_pc = [a / total_pc for a in correct_arr_pc] 19 | accuracy_sketch = [a / total_sketch for a in correct_arr_sketch] 20 | 21 | overall_acc_avg = (sum(correct_arr_pc) + sum(correct_arr_sketch)) / ( 22 | len(inputs_to_eval) * (total_pc + total_sketch)) 23 | overall_acc_pc = sum(correct_arr_pc) / (len(inputs_to_eval) * total_pc) 24 | overall_acc_sketch = sum(correct_arr_sketch) / (len(inputs_to_eval) * total_sketch) 25 | 26 | is_only_sketches = False 27 | is_only_pcs = False 28 | if all([param_acc == 0 for param_acc in accuracy_pc]): 29 | # only sketches 30 | overall_acc_avg = overall_acc_sketch 31 | is_only_sketches = True 32 | if all([param_acc == 0 for param_acc in accuracy_sketch]): 33 | # only pcs 34 | overall_acc_avg = overall_acc_pc 35 | is_only_pcs = True 36 | 37 | # sort by average accuracy 38 | inputs_to_eval, accuracy_avg, accuracy_pc, accuracy_sketch = zip( 39 | *sorted(zip(inputs_to_eval, accuracy_avg, accuracy_pc, accuracy_sketch), key=lambda x: x[1])) 40 | 41 | inputs_to_eval += ("Overall",) 42 | accuracy_avg += (overall_acc_avg,) 43 | accuracy_pc += (overall_acc_pc,) 44 | accuracy_sketch += (overall_acc_sketch,) 45 | 46 | fig, ax = plt.subplots(figsize=(16, 14)) 47 | X_axis = np.arange(len(inputs_to_eval)) * 2.6 48 | if not is_only_pcs and not is_only_sketches: 49 | pps = ax.barh(X_axis + 0.7, accuracy_avg, 0.7, color='steelblue') 50 | ax.barh(X_axis - 0.0, accuracy_pc, 0.7, color='lightsteelblue') 51 | ax.barh(X_axis - 0.7, accuracy_sketch, 0.7, color='wheat') 52 | ax.legend(labels=['Average', 'Point Clouds', 'Sketches']) 53 | elif is_only_pcs: 54 | pps = ax.barh(X_axis - 0.0, accuracy_pc, 0.7, color='lightsteelblue') 55 | ax.legend(labels=['Point Clouds']) 56 | elif is_only_sketches: 57 | pps = ax.barh(X_axis - 0.7, accuracy_sketch, 0.7, color='wheat') 58 | ax.legend(labels=['Sketches']) 59 | else: 60 | raise Exception("Either point cloud or sketch input should be processed") 61 | 62 | ax.bar_label(pps, fmt='%.2f', label_type='center', fontsize=8) 63 | ax.set_yticks(X_axis, inputs_to_eval) 64 | ax.set_title(title) 65 | 66 | if barplot_target_image_path: 67 | plt.savefig(barplot_target_image_path) 68 | return fig 69 | -------------------------------------------------------------------------------- /geocode/calculator_accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from calculator_util import eval_metadata 3 | 4 | 5 | class AccuracyCalculator: 6 | def __init__(self, inputs_to_eval, param_descriptors): 7 | self.inputs_to_eval = inputs_to_eval 8 | self.normalized_classes_all, self.num_classes_all_shifted_cumulated, self.num_classes_all, self.regression_params_indices \ 9 | = eval_metadata(inputs_to_eval, param_descriptors) 10 | self.param_descriptors = param_descriptors 11 | 12 | def 
eval(self, pred, targets, top_k_acc): 13 | batch_size = pred.shape[0] 14 | device = targets.device 15 | normalized_classes_all = self.normalized_classes_all.to(device) 16 | num_classes_all_shifted_cumulated = self.num_classes_all_shifted_cumulated.to(device) 17 | num_classes_all = self.num_classes_all.to(device) 18 | correct = [[0] * len(self.inputs_to_eval) for _ in range(top_k_acc)] 19 | targets_interleaved = torch.repeat_interleave(targets, num_classes_all.view(-1), dim=1) 20 | normalized_classes_all_repeated = normalized_classes_all.repeat(batch_size, 1).to(device) 21 | target_class = torch.abs(normalized_classes_all_repeated - targets_interleaved) 22 | target_class = torch.where(target_class < 1e-3)[1].view(batch_size, -1) # take the indices along dim=1 since target is of size [1, param_count] 23 | if len(self.regression_params_indices) > 0: 24 | regression_params_indices_repeated = self.regression_params_indices.repeat(batch_size, 1).to(device) 25 | target_class = torch.cat((target_class, regression_params_indices_repeated), dim=1) 26 | target_class, _ = torch.sort(target_class, dim=1) 27 | assert target_class.shape[1] == len(self.inputs_to_eval) 28 | target_class = target_class - num_classes_all_shifted_cumulated 29 | pred_split = torch.split(pred, list(num_classes_all), dim=1) 30 | class_indices_diff = [(torch.argmax(p, dim=1) - t if p.shape[1] > 1 else None) for p, t in zip( pred_split, target_class.T )] 31 | 32 | l1_distance = [None] * targets.shape[1] 33 | if len(self.regression_params_indices) > 0: 34 | for param_idx, (p, t) in enumerate(zip(pred_split, targets.T)): 35 | if self.param_descriptors[self.inputs_to_eval[param_idx]].is_regression: 36 | adjusted_pred = p[:, 0].clone() 37 | adjusted_pred[p[:, 1] >= 0.5] = -1.0 38 | l1_distance[param_idx] = torch.abs(adjusted_pred.squeeze() - t) 39 | 40 | for i, param_name in enumerate(self.inputs_to_eval): 41 | if self.param_descriptors[param_name].is_regression: 42 | # regression parameter 43 | normalized_acc_threshold = self.param_descriptors[param_name].normalized_acc_threshold 44 | for j in range(top_k_acc): 45 | assert len(l1_distance[i]) == batch_size 46 | correct[j][i] += torch.sum((l1_distance[i] < normalized_acc_threshold * (j + 1)).int()).item() 47 | else: 48 | cid = class_indices_diff[i] 49 | assert len(cid) == batch_size 50 | for j in range(top_k_acc): 51 | correct[j][i] += len(cid[(cid <= j) & (cid >= -j)]) 52 | return correct 53 | 54 | def eval_continuous_only(self, pred, targets, top_k_acc): 55 | assert pred.dtype == torch.float 56 | assert targets.dtype == torch.float 57 | batch_size = pred.shape[0] 58 | l1_distance = torch.where(targets == -1.0, torch.tensor(0.0).to(pred.device), torch.abs(pred - targets)) 59 | # continuous_param_names = [p_name for p_name in self.inputs_to_eval if self.param_descriptors[p_name].input_type in ['Vector', 'Float']] 60 | correct = [[0] * len(self.inputs_to_eval) for _ in range(top_k_acc)] 61 | for i, param_name in enumerate(self.inputs_to_eval): 62 | normalized_acc_threshold = self.param_descriptors[param_name].normalized_acc_threshold 63 | for j in range(top_k_acc): 64 | correct[j][i] += torch.sum((l1_distance[:, i] < normalized_acc_threshold * (j + 1)).int()).item() 65 | return correct 66 | 67 | 68 | def eval_discrete_only(self, pred, targets, top_k_acc): 69 | assert pred.dtype == torch.float 70 | assert targets.dtype == torch.float 71 | batch_size = pred.shape[0] 72 | # l1_distance = torch.where(torch.logical_and(targets == -1.0, pred < 0.0), torch.tensor(0.0).to(pred.device), 
torch.abs(pred - targets)) 73 | # import pdb; pdb.set_trace() 74 | device = targets.device 75 | num_classes_all = self.num_classes_all.to(device) 76 | pred_split = torch.split(pred, list(num_classes_all), dim=1) 77 | correct = [[0] * len(self.inputs_to_eval) for _ in range(top_k_acc)] 78 | for i, (pr, tt) in enumerate(zip(pred_split, targets.long().T)): 79 | pred_classes = torch.argmax(pr, axis=1) 80 | class_indices_diff = torch.abs(pred_classes - tt) 81 | class_indices_diff = torch.where(tt == -1, 0, class_indices_diff) # we consider target = -1 as success, thus we set the diff to 0 82 | for j in range(top_k_acc): 83 | correct[j][i] += len(class_indices_diff[class_indices_diff <= j]) 84 | return correct 85 | -------------------------------------------------------------------------------- /geocode/calculator_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from calculator_util import eval_metadata 3 | 4 | 5 | MSE = torch.nn.MSELoss() 6 | CElossSum = torch.nn.CrossEntropyLoss(reduction='sum') 7 | CEloss = torch.nn.CrossEntropyLoss(reduction='none') 8 | 9 | 10 | class LossCalculator(): 11 | def __init__(self, inputs_to_eval, param_descriptors): 12 | self.inputs_to_eval = inputs_to_eval 13 | self.normalized_classes_all, self.num_classes_all_shifted_cumulated, self.num_classes_all, self.regression_params_indices \ 14 | = eval_metadata(inputs_to_eval, param_descriptors) 15 | self.param_descriptors = param_descriptors 16 | 17 | def loss(self, pred, targets): 18 | """ 19 | _pred: (B, TARGET_VEC_LEN) 20 | """ 21 | batch_size = pred.shape[0] 22 | device = targets.device 23 | normalized_classes_all = self.normalized_classes_all.to(device) 24 | num_classes_all_shifted_cumulated = self.num_classes_all_shifted_cumulated.to(device) 25 | num_classes_all = self.num_classes_all.to(device) 26 | targets_interleaved = torch.repeat_interleave(targets, num_classes_all.view(-1), dim=1) 27 | normalized_classes_all_repeated = normalized_classes_all.repeat(batch_size, 1).to(device) 28 | target_class = torch.abs(normalized_classes_all_repeated - targets_interleaved) 29 | target_class = torch.where(target_class < 1e-3)[1].view(batch_size, -1) # take the indices along dim=1 30 | if len(self.regression_params_indices) > 0: 31 | regression_params_indices_repeated = self.regression_params_indices.repeat(batch_size, 1).to(device) 32 | target_class = torch.cat((target_class, regression_params_indices_repeated), dim=1) 33 | target_class, _ = torch.sort(target_class, dim=1) 34 | assert target_class.shape[1] == len(self.inputs_to_eval) 35 | target_class = target_class - num_classes_all_shifted_cumulated 36 | # target_class = target_class.to(_pred.get_device()) 37 | pred_split = torch.split(pred, list(num_classes_all), dim=1) 38 | detailed_ce_loss = [(CElossSum(p, t) if p.shape[1] > 1 else None) for p, t in zip( pred_split, target_class.T )] 39 | 40 | detailed_mse_loss = [None] * targets.shape[1] 41 | if len(self.regression_params_indices) > 0: 42 | for param_idx, (p, t) in enumerate(zip(pred_split, targets.T)): 43 | if self.param_descriptors[self.inputs_to_eval[param_idx]].is_regression: 44 | t_visibility = torch.zeros(t.shape[0]) 45 | t_visibility[t >= 0.0] = 0.0 46 | t_visibility[t == -1.0] = 1.0 47 | t_visibility = t_visibility.to(device) 48 | t_clone = t.clone() 49 | t_clone = t_clone.float() 50 | t_clone[t_clone == -1] = p[t_clone == -1,0] 51 | t_adjusted = torch.concat((t_clone.unsqueeze(1), t_visibility.unsqueeze(1)), dim=1) 52 | detailed_mse_loss[param_idx] 
= MSE(p, t_adjusted) 53 | detailed_mse_loss_no_none = [e for e in detailed_mse_loss if e] 54 | detailed_ce_loss_no_none = [e for e in detailed_ce_loss if e] 55 | mse_loss_range = 1.0 if not detailed_mse_loss_no_none else (max(detailed_mse_loss_no_none).item() - min(detailed_mse_loss_no_none).item()) 56 | ce_loss_range = max(detailed_ce_loss_no_none).item() - min(detailed_ce_loss_no_none).item() 57 | detailed_loss = [(ce_loss / ce_loss_range) if not mse_loss else (mse_loss / mse_loss_range) for ce_loss, mse_loss in zip(detailed_ce_loss, detailed_mse_loss)] 58 | 59 | return sum(detailed_loss), detailed_loss 60 | 61 | def loss_discrete(self, pred, targets): 62 | # this is specifically made to compare to SRPM paper 63 | device = targets.device 64 | num_classes_all = self.num_classes_all.to(device) 65 | pred_split = torch.split(pred, list(num_classes_all), dim=1) 66 | # import pdb; pdb.set_trace() 67 | detailed_loss = [torch.sum(CEloss(p, torch.where(t == -1, 0, t)) * (t != -1).int()) for p, t in zip( pred_split, targets.long().T )] 68 | return sum(detailed_loss), detailed_loss 69 | -------------------------------------------------------------------------------- /geocode/calculator_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, List 3 | from common.param_descriptors import ParamDescriptor 4 | 5 | 6 | def eval_metadata(inputs_to_eval: List[str], param_descriptors_map: Dict[str, ParamDescriptor]): 7 | num_classes_all = torch.empty(0) 8 | normalized_classes_all = torch.empty(0) 9 | for i, param_name in enumerate(inputs_to_eval): 10 | param_descriptor = param_descriptors_map[param_name] 11 | num_classes = param_descriptor.num_classes # Including the visibility label. If using regression then num_classes=2. 12 | num_classes_all = torch.cat((num_classes_all, torch. 
tensor([num_classes]))).long() 13 | if param_descriptor.normalized_classes is not None: 14 | normalized_classes = torch.from_numpy(param_descriptor.normalized_classes) 15 | else: 16 | # high values so that eval and loss methods will work when using regression 17 | normalized_classes = torch.tensor([100000.0, 100000.0]) 18 | normalized_classes_all = torch.cat((normalized_classes_all, normalized_classes.view(-1))) 19 | num_classes_all_shifted = torch.cat((torch.tensor([0]), num_classes_all))[0:-1] # shift right + drop right-most element 20 | num_classes_all_shifted_cumulated = torch.cumsum(num_classes_all_shifted, dim=0).view(1, -1) 21 | 22 | # get the indices of all the regression params, then shift them to match the expanded vector 23 | regression_params = torch.tensor([param_descriptors_map[param_name].is_regression for param_name in inputs_to_eval], dtype=torch.int) 24 | regression_params_indices = torch.where(regression_params)[0] 25 | regression_params_indices = torch.tensor([num_classes_all_shifted_cumulated[0, idx] for idx in regression_params_indices]) 26 | 27 | return normalized_classes_all, num_classes_all_shifted_cumulated, num_classes_all, regression_params_indices 28 | -------------------------------------------------------------------------------- /geocode/geocode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | from geocode_util import InputType 5 | from geocode_train import train 6 | from geocode_test import test 7 | 8 | 9 | def str2bool(v): 10 | if isinstance(v, bool): 11 | return v 12 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 13 | return True 14 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 15 | return False 16 | else: 17 | raise argparse.ArgumentTypeError('Boolean value expected but got [{}].'.format(v)) 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser(prog='ShapeEditing') 22 | 23 | common_parser = argparse.ArgumentParser(add_help=False) 24 | common_parser.add_argument('--dataset-dir', type=str, required=True, help='Path to dataset directory') 25 | common_parser.add_argument('--models-dir', type=str, required=False, help='Directory where experiments will be saved') 26 | common_parser.add_argument('--exp-name', type=str, required=True, help='Experiment directory within the models directory, where checkpoints will be saved') 27 | common_parser.add_argument('--input-type', type=InputType, nargs='+', default='pc sketch', help='Either \"pc\", \"sketch\" or \"pc sketch\"') 28 | common_parser.add_argument('--increase-network-size', action='store_true', default=False, help='Use larger encoders networks sizes') 29 | common_parser.add_argument('--normalize-embeddings', action='store_true', default=False, help='Normalize embeddings before using the decoders') 30 | common_parser.add_argument('--pretrained-vgg', action='store_true', default=False, help='Use a pretrained VGG network') 31 | common_parser.add_argument('--use-regression', action='store_true', default=False, help='Use regression instead of classification for continuous parameters') 32 | common_parser.add_argument('--huang', choices=['continuous', 'discrete'], default=False, help='Comparison to Huang et al. 
which separates continuous and discrete parameters.') 33 | 34 | sp = parser.add_subparsers() 35 | sp_train = sp.add_parser('train', parents=[common_parser]) 36 | sp_test = sp.add_parser('test', parents=[common_parser]) 37 | 38 | sp_train.set_defaults(func=train) 39 | sp_test.set_defaults(func=test) 40 | 41 | sp_train.add_argument('--batch_size', type=int, required=True, help='Batch size') 42 | sp_train.add_argument('--nepoch', type=int, required=True, help='Number of epochs to train') 43 | 44 | sp_test.add_argument('--phase', type=str, default='test') 45 | sp_test.add_argument('--blender-exe', type=str, required=True, help='Path to blender executable') 46 | sp_test.add_argument('--blend-file', type=str, required=True, help='Path to blend file') 47 | sp_test.add_argument('--random-pc', type=int, default=None, help='Use only random point cloud sampling with specified number of points') 48 | sp_test.add_argument('--gaussian', type=float, default=0.0, help='Add Gaussian noise to the point cloud with the specified STD') 49 | sp_test.add_argument('--normalize-pc', action='store_true', default=False, help='Automatically normalize the input point clouds') 50 | sp_test.add_argument('--scanobjectnn', action='store_true', default=False, help='ScanObjectNN dataset which has only point clouds input') 51 | # we augment in phases "train", "val", "test" and experiments "coseg", "simplify_mesh", and "gaussian" 52 | # use `--augment-with-random-points false` to disable 53 | sp_test.add_argument('--augment-with-random-points', type=str2bool, default='True', help='Augment FPS point cloud with randomly sampled points') 54 | 55 | args = parser.parse_args() 56 | args.func(args) 57 | 58 | # either pc or sketch, or both must be trained 59 | assert InputType.pc in args.input_type or InputType.sketch in args.input_type 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /geocode/geocode_train.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import json 3 | from pathlib import Path 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pytorch_lightning.loggers import NeptuneLogger 9 | import neptune.new as neptune 10 | from data.dataset_pc import DatasetPC 11 | from data.dataset_sketch import DatasetSketch 12 | from common.param_descriptors import ParamDescriptors 13 | from common.file_util import get_recipe_yml_obj 14 | from pytorch_lightning.trainer.supporters import CombinedLoader 15 | from geocode_util import InputType, get_inputs_to_eval, calc_prediction_vector_size 16 | 17 | 18 | def train(opt): 19 | torch.set_printoptions(precision=4) 20 | torch.multiprocessing.set_sharing_strategy('file_system') # to prevent "received 0 items of data" errors 21 | recipe_file_path = Path(opt.dataset_dir, 'recipe.yml') 22 | if not recipe_file_path.is_file(): 23 | raise Exception(f'No \'recipe.yml\' file found in path [{recipe_file_path}]') 24 | recipe_yml_obj = get_recipe_yml_obj(str(recipe_file_path)) 25 | 26 | inputs_to_eval = get_inputs_to_eval(recipe_yml_obj) 27 | 28 | top_k_acc = 2 29 | camera_angles_to_process = [f'{a}_{b}' for a, b in recipe_yml_obj['camera_angles_train']] 30 | param_descriptors = ParamDescriptors(recipe_yml_obj, inputs_to_eval, use_regression=opt.use_regression, train_with_visibility_label=(not opt.huang)) 31 | param_descriptors_map = 
param_descriptors.get_param_descriptors_map() 32 | detailed_vec_size = calc_prediction_vector_size(param_descriptors_map) 33 | print(f"Prediction vector length is set to [{sum(detailed_vec_size)}]") 34 | 35 | # create datasets 36 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 37 | train_loaders_map = {} 38 | val_loaders_map = {} 39 | 40 | num_workers = 5 # use 1 when debugging, otherwise 5 41 | 42 | # pc 43 | if InputType.pc in opt.input_type: 44 | train_dataset_pc = DatasetPC(inputs_to_eval, device, param_descriptors_map, opt.dataset_dir, "train", augment_with_random_points=True) 45 | train_dataloader_pc = DataLoader(train_dataset_pc, batch_size=opt.batch_size, shuffle=True, num_workers=num_workers, prefetch_factor=5) 46 | val_dataset_pc = DatasetPC(inputs_to_eval, device, param_descriptors_map, opt.dataset_dir, "val", augment_with_random_points=True) 47 | val_dataloader_pc = DataLoader(val_dataset_pc, batch_size=opt.batch_size, shuffle=False, num_workers=num_workers, prefetch_factor=5) 48 | train_loaders_map['pc'] = train_dataloader_pc 49 | val_loaders_map['pc'] = val_dataloader_pc 50 | print(f"Point cloud train dataset size [{len(train_dataset_pc)}] val dataset size [{len(val_dataset_pc)}]") 51 | 52 | # sketch 53 | if InputType.sketch in opt.input_type: 54 | train_dataset_sketch = DatasetSketch(inputs_to_eval, param_descriptors_map, camera_angles_to_process, opt.pretrained_vgg, opt.dataset_dir, "train") 55 | train_dataloader_sketch = DataLoader(train_dataset_sketch, batch_size=opt.batch_size, shuffle=True, num_workers=num_workers, prefetch_factor=5) 56 | val_dataset_sketch = DatasetSketch(inputs_to_eval, param_descriptors_map, camera_angles_to_process, opt.pretrained_vgg, opt.dataset_dir, "val") 57 | val_dataloader_sketch = DataLoader(val_dataset_sketch, batch_size=opt.batch_size, shuffle=False, num_workers=num_workers, prefetch_factor=5) 58 | train_loaders_map['sketch'] = train_dataloader_sketch 59 | val_loaders_map['sketch'] = val_dataloader_sketch 60 | print(f"Sketch train dataset size [{len(train_dataset_sketch)}] val dataset size [{len(val_dataset_sketch)}]") 61 | 62 | combined_train_dataloader = CombinedLoader(train_loaders_map, mode="max_size_cycle") 63 | combined_val_dataloader = CombinedLoader(val_loaders_map, mode="max_size_cycle") 64 | 65 | if InputType.pc in opt.input_type and InputType.sketch in opt.input_type: 66 | assert ( len(camera_angles_to_process) * len(train_dataset_pc) ) == len(train_dataset_sketch) 67 | assert ( len(camera_angles_to_process) * len(val_dataset_pc) ) == len(val_dataset_sketch) 68 | 69 | print(f"Experiment name [{opt.exp_name}]") 70 | 71 | exp_dir = Path(opt.models_dir, opt.exp_name) 72 | exp_dir.mkdir(exist_ok=True, parents=True) 73 | 74 | neptune_short_id = None 75 | neptune_short_id_file_path = exp_dir.joinpath('neptune_session.json') 76 | if neptune_short_id_file_path.is_file(): 77 | with open(neptune_short_id_file_path, 'r') as neptune_short_id_file: 78 | try: 79 | neptune_session_json = json.load(neptune_short_id_file) 80 | if 'short_id' in neptune_session_json: 81 | neptune_short_id = neptune_session_json['short_id'] 82 | print(f'Continuing Neptune run [{neptune_short_id}]') 83 | except: 84 | print("Could not resume neptune session") 85 | 86 | # create/load NeptuneLogger 87 | neptune_logger = None 88 | neptune_config_file_path = Path(__file__).parent.joinpath('..', 'config', 'neptune_config.yml').resolve() 89 | if neptune_config_file_path.is_file(): 90 | print(f"Found neptune config file 
[{neptune_config_file_path}]") 91 | with open(neptune_config_file_path) as neptune_config_file: 92 | config = yaml.safe_load(neptune_config_file) 93 | api_token = config['neptune']['api_token'] 94 | project = config['neptune']['project'] 95 | tags = ["train"] 96 | if neptune_short_id: 97 | neptune_logger = NeptuneLogger( run=neptune.init(run=neptune_short_id, project=project, api_token=api_token, tags=tags), log_model_checkpoints=False ) 98 | else: 99 | # log_model_checkpoints=False avoids saving the models to Neptune 100 | neptune_logger = NeptuneLogger(api_key=api_token, project=project, tags=tags, log_model_checkpoints=False) 101 | if neptune_short_id is None: 102 | # new experiment 103 | neptune_short_id = neptune_logger.run.fetch()['sys']['id'] # e.g. IN-105 (-) 104 | with open(neptune_short_id_file_path, 'w') as neptune_short_id_file: 105 | json.dump({'short_id': neptune_short_id}, neptune_short_id_file) 106 | print(f'Started a new Neptune.ai run with id [{neptune_short_id}]') 107 | 108 | # log parameters to Neptune 109 | params = { 110 | "exp_name": opt.exp_name, 111 | "lr": 1e-2 if not opt.huang else 3e-4, 112 | "bs": opt.batch_size, 113 | "n_parameters": len(inputs_to_eval), 114 | "sched_step_size": 20, 115 | "sched_gamma": 0.85 if not opt.huang else 0.9, 116 | "normalize_embeddings": opt.normalize_embeddings, 117 | "increase_net_size": opt.increase_network_size, 118 | "pretrained_vgg": opt.pretrained_vgg, 119 | "use_regression": opt.use_regression, 120 | } 121 | if neptune_logger: 122 | neptune_logger.run['parameters'] = params 123 | 124 | checkpoint_callback = ModelCheckpoint( 125 | dirpath=exp_dir, 126 | filename='ise-epoch{epoch:03d}-val_loss{val/loss/total:.2f}-val_acc{val/acc_top1/avg:.2f}', 127 | auto_insert_metric_name=False, 128 | save_last=True, 129 | monitor="val/acc_top1/avg", 130 | mode="max", 131 | save_top_k=3) 132 | 133 | huang_continuous = False 134 | huang_discrete = False 135 | if opt.huang == 'continuous': 136 | huang_continuous = True 137 | elif opt.huang == 'discrete': 138 | huang_discrete = True 139 | 140 | # import the relevant Model class 141 | if opt.huang: 142 | # comparison to Huang et al. 
143 | from geocode_model_alexnet import Model 144 | else: 145 | from geocode_model import Model 146 | 147 | trainer = pl.Trainer(gpus=[0], max_epochs=opt.nepoch, logger=neptune_logger, callbacks=[checkpoint_callback]) 148 | last_ckpt_file_name = f"{checkpoint_callback.CHECKPOINT_NAME_LAST}{checkpoint_callback.FILE_EXTENSION}" # "last.ckpt" by default 149 | last_checkpoint_file_path = exp_dir.joinpath(last_ckpt_file_name) 150 | if last_checkpoint_file_path.is_file(): 151 | print(f"Loading checkpoint file [{last_checkpoint_file_path}]...") 152 | pl_model = Model.load_from_checkpoint(str(last_checkpoint_file_path), 153 | param_descriptors=param_descriptors, 154 | trainer=trainer, 155 | models_dir=opt.models_dir, 156 | exp_name=opt.exp_name, 157 | use_regression=opt.use_regression, 158 | discrete=huang_discrete, 159 | continuous=huang_continuous) 160 | else: 161 | last_checkpoint_file_path = None 162 | if opt.huang: 163 | pl_model = Model(top_k_acc, opt.batch_size, detailed_vec_size, opt.increase_network_size, opt.normalize_embeddings, opt.pretrained_vgg, 164 | opt.input_type, inputs_to_eval, params['lr'], params['sched_step_size'], params['sched_gamma'], opt.exp_name, 165 | trainer=trainer, param_descriptors=param_descriptors, models_dir=opt.models_dir, use_regression=opt.use_regression, 166 | discrete=huang_discrete, continuous=huang_continuous) 167 | else: 168 | # no huang related arguments 169 | pl_model = Model(top_k_acc, opt.batch_size, detailed_vec_size, opt.increase_network_size, opt.normalize_embeddings, opt.pretrained_vgg, 170 | opt.input_type, inputs_to_eval, params['lr'], params['sched_step_size'], params['sched_gamma'], opt.exp_name, 171 | trainer=trainer, param_descriptors=param_descriptors, models_dir=opt.models_dir, use_regression=opt.use_regression) 172 | trainer.fit(pl_model, train_dataloaders=combined_train_dataloader, val_dataloaders=combined_val_dataloader, ckpt_path=last_checkpoint_file_path) 173 | -------------------------------------------------------------------------------- /geocode/geocode_util.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Dict 3 | from common.param_descriptors import ParamDescriptor 4 | 5 | 6 | class InputType(Enum): 7 | sketch = 'sketch' 8 | pc = 'pc' 9 | 10 | def __str__(self): 11 | return self.value 12 | 13 | def __eq__(self, other): 14 | return self.value == other.value 15 | 16 | 17 | def get_inputs_to_eval(recipe_yml_obj): 18 | inputs_to_eval = [] 19 | for param_name, param_dict in recipe_yml_obj['dataset_generation'].items(): 20 | is_vector = False 21 | for axis in ['x', 'y', 'z']: 22 | if axis in param_dict: 23 | inputs_to_eval.append(f'{param_name} {axis}') 24 | is_vector = True 25 | if not is_vector: 26 | inputs_to_eval.append(param_name) 27 | print("Inputs that will be evaluated:") 28 | print("\t" + "\n\t".join(inputs_to_eval)) 29 | return inputs_to_eval 30 | 31 | 32 | def calc_prediction_vector_size(param_descriptors_map: Dict[str, ParamDescriptor]): 33 | detailed_vec_size = [param_descriptor.num_classes for param_name, param_descriptor in param_descriptors_map.items()] 34 | print(f"Found [{len(detailed_vec_size)}] parameters with a combined number of classes of [{sum(detailed_vec_size)}]") 35 | return detailed_vec_size 36 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 
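Referring back to get_inputs_to_eval in geocode_util.py above: vector parameters in the recipe's dataset_generation section expand into one evaluated input per axis, while scalar parameters contribute a single entry. A small sketch with an invented recipe fragment (only the presence of the 'x'/'y'/'z' keys matters for the expansion; the import path assumes the repo root is on the Python path or the package was installed via setup.py):

from geocode.geocode_util import get_inputs_to_eval

# Invented recipe fragment mimicking recipe_yml_obj['dataset_generation'].
recipe_yml_obj = {
    "dataset_generation": {
        "seat_scale": {"x": {}, "y": {}, "z": {}},  # vector parameter -> three evaluated inputs
        "leg_count": {},                            # scalar parameter -> one evaluated input
    }
}

inputs_to_eval = get_inputs_to_eval(recipe_yml_obj)
# inputs_to_eval == ['seat_scale x', 'seat_scale y', 'seat_scale z', 'leg_count']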
https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/models/__init__.py -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Decoder(nn.Module): 7 | def __init__(self, emb_dims, output_channels, increase_network_size, bn, use_regression): 8 | super(Decoder, self).__init__() 9 | self.emb_dims = emb_dims 10 | 11 | if increase_network_size: 12 | self.linear_out_features_1 = 512 13 | self.linear_out_features_2 = 256 14 | else: 15 | self.linear_out_features_1 = 128 16 | self.linear_out_features_2 = 64 17 | self.bn = bn 18 | 19 | if use_regression: 20 | output_channels = 2 21 | 22 | self.linear1 = nn.Linear(self.emb_dims * 2, self.linear_out_features_1, bias=False) # args.emb_dims 23 | if self.bn: 24 | self.bn6 = nn.BatchNorm1d(self.linear_out_features_1) 25 | else: 26 | self.bn6 = nn.InstanceNorm1d(self.linear_out_features_1) 27 | self.dp1 = nn.Dropout(p=0.5) # args.dropout 28 | self.linear2 = nn.Linear(self.linear_out_features_1, self.linear_out_features_2) 29 | if self.bn: 30 | self.bn7 = nn.BatchNorm1d(self.linear_out_features_2) 31 | else: 32 | self.bn7 = nn.InstanceNorm1d(self.linear_out_features_2) 33 | self.dp2 = nn.Dropout(p=0.5) # args.dropout 34 | 35 | self.linear3 = nn.Linear(self.linear_out_features_2, output_channels) 36 | 37 | def forward(self, x): 38 | if self.bn: 39 | x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2) 40 | else: 41 | x = F.leaky_relu(self.linear1(x), negative_slope=0.2) 42 | x = self.dp1(x) 43 | if self.bn: 44 | x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) 45 | else: 46 | x = F.leaky_relu(self.linear2(x), negative_slope=0.2) 47 | x = self.dp2(x) 48 | x = self.linear3(x) 49 | return x 50 | 51 | 52 | class DecodersNet(nn.Module): 53 | def __init__(self, output_channels, increase_network_size, regression_params=None): 54 | """ 55 | output_channels - array containing the number of classes per parameter (including visibility label if exists) 56 | """ 57 | # there is no Parallel module in torch (and there is no reason for one to exist) 58 | # refer to https://github.com/pytorch/pytorch/issues/36459 59 | super(DecodersNet, self).__init__() 60 | if increase_network_size: 61 | self.emb_dims = 1024 62 | else: 63 | self.emb_dims = 64 64 | self.bn = True 65 | fan_out_list = [] 66 | for i, param_output_channels in enumerate(output_channels): 67 | use_regression = False 68 | if regression_params: 69 | use_regression = regression_params[i] 70 | fan_out_list.append(Decoder(self.emb_dims, param_output_channels, increase_network_size, self.bn, use_regression)) 71 | self.fan_out = nn.ModuleList(fan_out_list) 72 | self.initialize_weights(self) 73 | 74 | @staticmethod 75 | def initialize_weights(module): 76 | for m in module.modules(): 77 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): 78 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 79 | if m.bias is not None: 80 | nn.init.constant_(m.bias, 0) 81 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 82 | nn.init.constant_(m.weight, 1.0) 83 | nn.init.constant_(m.bias, 0) 84 | elif isinstance(m, nn.Linear): 85 | nn.init.normal_(m.weight, 0, 0.01) 86 | if m.bias is not None: 87 | nn.init.constant_(m.bias, 0) 88 | 89 | def decode(self, embedding): 90 | param_outs = [] 91 | for net in 
self.fan_out: 92 | param_outs.append(net(embedding)) 93 | 94 | x = torch.cat(param_outs, dim=1) 95 | return x 96 | 97 | 98 | # used for comparison to "Shape Synthesis from Sketches via Procedural Models and Convolutional Networks" 99 | class DecodersNetAlex(nn.Module): 100 | def __init__(self, output_channels): 101 | """ 102 | output_channels - array containing the number of classes per parameter (including visibility label if exists) 103 | """ 104 | # there is no Parallel module in torch (and there is no reason for one to exist) 105 | # refer to https://github.com/pytorch/pytorch/issues/36459 106 | super(DecodersNetAlex, self).__init__() 107 | self.emb_dims = 4096 108 | fan_out_list = [] 109 | for param_output_channels in output_channels: 110 | fan_out_list.append(nn.Linear(4096, param_output_channels)) 111 | self.fan_out = nn.ModuleList(fan_out_list) 112 | self.initialize_weights(self) 113 | 114 | @staticmethod 115 | def initialize_weights(module): 116 | for m in module.modules(): 117 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): 118 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 119 | if m.bias is not None: 120 | nn.init.constant_(m.bias, 0) 121 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 122 | nn.init.constant_(m.weight, 1.0) 123 | nn.init.constant_(m.bias, 0) 124 | elif isinstance(m, nn.Linear): 125 | nn.init.normal_(m.weight, 0, 0.01) 126 | if m.bias is not None: 127 | nn.init.constant_(m.bias, 0) 128 | 129 | def decode(self, embedding): 130 | param_outs = [] 131 | for net in self.fan_out: 132 | param_outs.append(net(embedding)) 133 | 134 | x = torch.cat(param_outs, dim=1) 135 | return x 136 | -------------------------------------------------------------------------------- /models/dgcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def knn(x, k): 7 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 8 | xx = torch.sum(x ** 2, dim=1, keepdim=True) 9 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 10 | 11 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 12 | return idx 13 | 14 | 15 | def get_graph_feature(x, k=20, idx=None): 16 | batch_size = x.size(0) 17 | num_points = x.size(2) 18 | x = x.view(batch_size, -1, num_points) 19 | if idx is None: 20 | idx = knn(x, k=k) # (batch_size, num_points, k) 21 | device = x.device 22 | 23 | # separate indices for each batch 24 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points 25 | idx = idx + idx_base 26 | idx = idx.view(-1) 27 | 28 | _, num_dims, _ = x.size() # at first num_dims = 3, then num_dims = 16 (or 64 if increased size), and so on... 29 | 30 | x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points) 31 | feature = x.view(batch_size * num_points, -1)[idx, :] 32 | feature = feature.view(batch_size, num_points, k, num_dims) 33 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 34 | 35 | feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous() # (batch_size, num_dims, num_points, k) and num_dims is now 6 then 32, and so on... 
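# In other words: each point is paired with its k nearest neighbors, and every edge carries the
# feature pair (neighbor - center, center), doubling the channel count; the EdgeConv blocks below
# apply a shared 1x1 convolution over these edge features and then max-pool over the k neighbors.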
36 | 37 | return feature 38 | 39 | 40 | class DGCNN(nn.Module): 41 | def __init__(self, increase_network_size=False, normalize_embeddings=False): 42 | super(DGCNN, self).__init__() 43 | #### hard coded params #### 44 | # args.k = 20 45 | ########################### 46 | # self.args = args 47 | self.k = 20 # args.k 48 | self.increase_network_size = increase_network_size 49 | self.normalize_embeddings = normalize_embeddings 50 | if self.increase_network_size: 51 | self.channels_layer_1 = 64 52 | self.channels_layer_2 = 64 53 | self.channels_layer_3 = 128 54 | self.channels_layer_4 = 256 55 | self.emb_dims = 1024 56 | else: 57 | self.channels_layer_1 = 16 58 | self.channels_layer_2 = 16 59 | self.channels_layer_3 = 32 60 | self.channels_layer_4 = 64 61 | self.emb_dims = 64 62 | 63 | self.bn = True 64 | if self.bn: 65 | self.bn1 = nn.BatchNorm2d(self.channels_layer_1) 66 | self.bn2 = nn.BatchNorm2d(self.channels_layer_2) 67 | self.bn3 = nn.BatchNorm2d(self.channels_layer_3) 68 | self.bn4 = nn.BatchNorm2d(self.channels_layer_4) 69 | else: 70 | self.bn1 = nn.InstanceNorm2d(self.channels_layer_1) 71 | self.bn2 = nn.InstanceNorm2d(self.channels_layer_2) 72 | self.bn3 = nn.InstanceNorm2d(self.channels_layer_3) 73 | self.bn4 = nn.InstanceNorm2d(self.channels_layer_4) 74 | 75 | self.conv1 = nn.Sequential(nn.Conv2d(6, self.channels_layer_1, kernel_size=1, bias=False), 76 | self.bn1, 77 | nn.LeakyReLU(negative_slope=0.2)) 78 | self.conv2 = nn.Sequential(nn.Conv2d(self.channels_layer_1 * 2, self.channels_layer_2, kernel_size=1, bias=False), 79 | self.bn2, 80 | nn.LeakyReLU(negative_slope=0.2)) 81 | self.conv3 = nn.Sequential(nn.Conv2d(self.channels_layer_2 * 2, self.channels_layer_3, kernel_size=1, bias=False), 82 | self.bn3, 83 | nn.LeakyReLU(negative_slope=0.2)) 84 | self.conv4 = nn.Sequential(nn.Conv2d(self.channels_layer_3 * 2, self.channels_layer_4, kernel_size=1, bias=False), 85 | self.bn4, 86 | nn.LeakyReLU(negative_slope=0.2)) 87 | 88 | if self.bn: 89 | self.bn5 = nn.BatchNorm1d(self.emb_dims) # args.emb_dims 90 | else: 91 | self.bn5 = nn.InstanceNorm1d(self.emb_dims) # args.emb_dims 92 | 93 | self.conv5 = nn.Sequential(nn.Conv1d(self.channels_layer_4 * 2, self.emb_dims, kernel_size=1, bias=False), # args.emb_dims 94 | self.bn5, 95 | nn.LeakyReLU(negative_slope=0.2)) 96 | 97 | # net.apply(init_weights) 98 | 99 | @staticmethod 100 | def initialize_weights(module): 101 | for m in module.modules(): 102 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): 103 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 104 | if m.bias is not None: 105 | nn.init.constant_(m.bias, 0) 106 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 107 | nn.init.constant_(m.weight, 1.0) 108 | nn.init.constant_(m.bias, 0) 109 | elif isinstance(m, nn.Linear): 110 | nn.init.normal_(m.weight, 0, 0.01) 111 | if m.bias is not None: 112 | nn.init.constant_(m.bias, 0) 113 | 114 | def forward(self, x): 115 | x = get_graph_feature(x, k=self.k) 116 | x = self.conv1(x) 117 | x1 = x.max(dim=-1, keepdim=False)[0] # we are given back the values and indices, but we only need the values, hence the [0] 118 | 119 | x = get_graph_feature(x1, k=self.k) 120 | x = self.conv2(x) 121 | x2 = x.max(dim=-1, keepdim=False)[0] 122 | 123 | x = get_graph_feature(x2, k=self.k) 124 | x = self.conv3(x) 125 | x3 = x.max(dim=-1, keepdim=False)[0] 126 | 127 | x = get_graph_feature(x3, k=self.k) 128 | x = self.conv4(x) 129 | x4 = x.max(dim=-1, keepdim=False)[0] 130 | 131 | x = torch.cat((x1, x2, x3, x4), dim=1) 
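# x1..x4 are the per-point outputs of the four EdgeConv stages; concatenating them channel-wise
# yields a multi-scale per-point feature that conv5 lifts to emb_dims channels, after which the
# globally max-pooled and average-pooled vectors are concatenated into the final shape embedding.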
132 | 133 | batch_size = x.size(0) 134 | x = self.conv5(x) 135 | x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) 136 | x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1) 137 | enc = torch.cat((x1, x2), 1) 138 | 139 | if self.normalize_embeddings: 140 | enc = F.normalize(enc, dim=1) 141 | 142 | return enc 143 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.hub import load_state_dict_from_url 5 | 6 | 7 | __all__ = [ 8 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 9 | 'vgg19_bn', 'vgg19', 10 | ] 11 | 12 | 13 | model_urls = { 14 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 15 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 16 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 17 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 18 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 19 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 20 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 21 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 22 | } 23 | 24 | 25 | class VGG(nn.Module): 26 | 27 | def __init__(self, features, num_classes=1000, init_weights=True, encoder_only=False, increase_network_size=False, normalize_embeddings=False, pretrained=False): 28 | super(VGG, self).__init__() 29 | self.features = features 30 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 31 | self.encoder_only = encoder_only 32 | self.increase_network_size = increase_network_size 33 | self.normalize_embeddings = normalize_embeddings 34 | self.pretrained = pretrained 35 | if not encoder_only: 36 | assert False 37 | self.classifier = nn.Sequential( 38 | nn.Linear(512 * 7 * 7, 4096), 39 | nn.ReLU(True), 40 | nn.Dropout(), 41 | nn.Linear(4096, 4096), 42 | nn.ReLU(True), 43 | nn.Dropout(), 44 | nn.Linear(4096, num_classes), 45 | ) 46 | else: 47 | if self.increase_network_size: 48 | self.lin = nn.Linear(512 * 7 * 7, 2048) 49 | else: 50 | self.lin = nn.Linear(64 * 7 * 7, 128) 51 | with torch.no_grad(): 52 | if init_weights: 53 | self._initialize_weights() 54 | 55 | def forward(self, x): 56 | x = self.features(x) 57 | x = self.avgpool(x) 58 | x = torch.flatten(x, 1) 59 | # x = F.relu(x) 60 | if not self.encoder_only: 61 | assert False 62 | x = self.classifier(x) 63 | else: 64 | x = self.lin(x) 65 | if self.normalize_embeddings: 66 | x = F.normalize(x, dim=1) 67 | return x 68 | 69 | @staticmethod 70 | def scale_weights(m, scale=1.0): 71 | with torch.no_grad(): 72 | m.weight *= scale 73 | 74 | def _initialize_weights(self): 75 | for m in self.modules(): 76 | # import pdb; pdb.set_trace() 77 | if isinstance(m, nn.Conv2d): 78 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 79 | if m.bias is not None: 80 | nn.init.constant_(m.bias, 0) 81 | VGG.scale_weights(m) 82 | elif isinstance(m, nn.BatchNorm2d): 83 | nn.init.constant_(m.weight, 1.0) 84 | nn.init.constant_(m.bias, 0) 85 | VGG.scale_weights(m) 86 | elif isinstance(m, nn.Linear): 87 | nn.init.normal_(m.weight, 0, 0.01) 88 | nn.init.constant_(m.bias, 0) 89 | VGG.scale_weights(m) 90 | 91 | 92 | def make_layers(cfg, batch_norm=False, pretrained=False): 93 | layers = [] 94 | in_channels = 3 if pretrained else 1 95 | for v in cfg: 
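# Each cfg entry is either 'M' (a 2x2 max-pooling layer) or an integer (a 3x3 convolution with that
# many output channels); the 'A_reduced' configuration in cfgs below shrinks these widths for the
# smaller, non-pretrained encoder.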
96 | if v == 'M': 97 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 98 | else: 99 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 100 | if batch_norm: 101 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 102 | else: 103 | layers += [conv2d, nn.ReLU(inplace=True)] 104 | in_channels = v 105 | return nn.Sequential(*layers) 106 | 107 | 108 | cfgs = { 109 | 'A_reduced': [8, 'M', 16, 'M', 32, 32, 'M', 64, 64, 'M', 64, 64, 'M'], 110 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 111 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 112 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 113 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 114 | } 115 | 116 | 117 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 118 | assert not (pretrained and not kwargs['increase_network_size']) 119 | kwargs['pretrained'] = False 120 | if pretrained: 121 | kwargs['init_weights'] = False 122 | kwargs['pretrained'] = True 123 | if not kwargs['increase_network_size']: 124 | cfg += '_reduced' 125 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm, pretrained=pretrained), **kwargs) 126 | if pretrained: 127 | model_dict = model.state_dict() 128 | lin_state = {k: v for k, v in model_dict.items() if k == "lin.bias" or k == "lin.weight"} 129 | pretrained_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 130 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 131 | pretrained_dict.update(lin_state) 132 | model.load_state_dict(pretrained_dict) 133 | return model 134 | 135 | 136 | def vgg11(pretrained=False, progress=True, **kwargs): 137 | r"""VGG 11-layer model (configuration "A") from 138 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 139 | Args: 140 | pretrained (bool): If True, returns a model pre-trained on ImageNet 141 | progress (bool): If True, displays a progress bar of the download to stderr 142 | """ 143 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 144 | 145 | 146 | def vgg11_bn(pretrained=False, progress=True, **kwargs): 147 | r"""VGG 11-layer model (configuration "A") with batch normalization 148 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 149 | Args: 150 | pretrained (bool): If True, returns a model pre-trained on ImageNet 151 | progress (bool): If True, displays a progress bar of the download to stderr 152 | """ 153 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 154 | 155 | 156 | def vgg13(pretrained=False, progress=True, **kwargs): 157 | r"""VGG 13-layer model (configuration "B") 158 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 159 | Args: 160 | pretrained (bool): If True, returns a model pre-trained on ImageNet 161 | progress (bool): If True, displays a progress bar of the download to stderr 162 | """ 163 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 164 | 165 | 166 | def vgg13_bn(pretrained=False, progress=True, **kwargs): 167 | r"""VGG 13-layer model (configuration "B") with batch normalization 168 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | progress (bool): If True, displays a progress bar of the download to stderr 172 | """ 173 | return _vgg('vgg13_bn', 'B', 
True, pretrained, progress, **kwargs) 174 | 175 | 176 | def vgg16(pretrained=False, progress=True, **kwargs): 177 | r"""VGG 16-layer model (configuration "D") 178 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 179 | Args: 180 | pretrained (bool): If True, returns a model pre-trained on ImageNet 181 | progress (bool): If True, displays a progress bar of the download to stderr 182 | """ 183 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 184 | 185 | 186 | def vgg16_bn(pretrained=False, progress=True, **kwargs): 187 | r"""VGG 16-layer model (configuration "D") with batch normalization 188 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 189 | Args: 190 | pretrained (bool): If True, returns a model pre-trained on ImageNet 191 | progress (bool): If True, displays a progress bar of the download to stderr 192 | """ 193 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 194 | 195 | 196 | def vgg19(pretrained=False, progress=True, **kwargs): 197 | r"""VGG 19-layer model (configuration "E") 198 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | progress (bool): If True, displays a progress bar of the download to stderr 202 | """ 203 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 204 | 205 | 206 | def vgg19_bn(pretrained=False, progress=True, **kwargs): 207 | r"""VGG 19-layer model (configuration 'E') with batch normalization 208 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ 209 | Args: 210 | pretrained (bool): If True, returns a model pre-trained on ImageNet 211 | progress (bool): If True, displays a progress bar of the download to stderr 212 | """ 213 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /resources/chair_back_frame_mid_y_offset_pct_0_0000_0002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/resources/chair_back_frame_mid_y_offset_pct_0_0000_0002.png -------------------------------------------------------------------------------- /resources/geo_nodes_button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/resources/geo_nodes_button.png -------------------------------------------------------------------------------- /resources/geo_nodes_workspace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/resources/geo_nodes_workspace.png -------------------------------------------------------------------------------- /resources/geocode_addon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/resources/geocode_addon.png -------------------------------------------------------------------------------- /resources/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/resources/teaser.png 
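Referring back to models/vgg.py above: the vgg* constructors only support encoder_only=True in this repo (the classifier branch asserts False), so the usual VGG classifier head is replaced by a single linear projection, and make_layers builds a one-channel network unless a pretrained (ImageNet, three-channel) variant is requested. A minimal sketch of building the reduced encoder and embedding a batch of 224x224 sketches follows; the exact variant and flags used by GeoCode's training code are not asserted here, and the import assumes the repo root is on the Python path:

import torch
from models.vgg import vgg11_bn

# Reduced, single-channel encoder; pretrained=True would additionally require
# increase_network_size=True (see the assert in _vgg) and three input channels.
encoder = vgg11_bn(pretrained=False,
                   encoder_only=True,
                   increase_network_size=False,
                   normalize_embeddings=True)

sketches = torch.rand(4, 1, 224, 224)  # sketches are rendered at 224x224 (see sketch_generator.py)
embeddings = encoder(sketches)
print(embeddings.shape)  # torch.Size([4, 128]) for the reduced configuration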
-------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/download_ds.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import zipfile 4 | import hashlib 5 | import requests 6 | import traceback 7 | from pathlib import Path 8 | from common.domain import Domain 9 | from argparse import ArgumentParser 10 | 11 | 12 | def download_file(url, target_file_path): 13 | print(f"Downloading file from [{url}] as [{target_file_path}]") 14 | req = requests.get(url, allow_redirects=True) 15 | with open(target_file_path, 'wb') as target_file: 16 | target_file.write(req.content) 17 | 18 | 19 | def download_ds(args): 20 | datasets_dir = None 21 | if args.datasets_dir: 22 | datasets_dir = Path(args.datasets_dir) 23 | if not datasets_dir.is_dir(): 24 | raise Exception(f'Given datasets path [{datasets_dir}] is not an existing directory.') 25 | models_dir = None 26 | if args.models_dir: 27 | models_dir = Path(args.models_dir) 28 | if not models_dir.is_dir(): 29 | raise Exception(f'Given models path [{models_dir}] is not an existing directory.') 30 | blends_dir = None 31 | if args.blends_dir: 32 | blends_dir = Path(args.blends_dir) 33 | if not blends_dir.is_dir(): 34 | raise Exception(f'Given blends path [{blends_dir}] is not an existing directory.') 35 | 36 | if args.domain == Domain.chair: 37 | md5 = "27c283fa6893b23400a9bba6aca92854" 38 | ds_url = "https://figshare.com/ndownloader/files/39487282?private_link=d06bff0ae6b0c710bec8" 39 | ds_zip_file_name = "ChairDataset.zip" 40 | best_epoch = 585 41 | elif args.domain == Domain.vase: 42 | md5 = "1200bfb9552513ea6c9a3b9050af470e" 43 | ds_url = "https://figshare.com/ndownloader/files/39487153?private_link=1b30f4105c0518ce9071" 44 | ds_zip_file_name = "VaseDataset.zip" 45 | best_epoch = 573 46 | elif args.domain == Domain.table: 47 | md5 = "c7a0fc73c2b3f39dcd02f8cd3380d9dd" 48 | ds_url = "https://figshare.com/ndownloader/files/39487033?private_link=53f9de1359c3e3cc3218" 49 | ds_zip_file_name = "TableDataset.zip" 50 | best_epoch = 537 51 | elif args.domain == Domain.ceiling_lamp: 52 | md5 = "a6e2e29790f74219539f4c151f566ba8" 53 | ds_url = "https://figshare.com/ndownloader/files/53115095?private_link=e568d4700d54a8f48289" 54 | ds_zip_file_name = "CeilingLampDataset.zip" 55 | best_epoch = 259 56 | else: 57 | raise Exception(f'Domain [{args.domain}] is not recognized.') 58 | 59 | if args.datasets_dir: 60 | target_ds_zip_file_path = datasets_dir.joinpath(ds_zip_file_name) 61 | # download requested dataset zip file from Google Drive 62 | if not target_ds_zip_file_path.is_file(): 63 | download_file(ds_url, target_ds_zip_file_path) 64 | else: 65 | print(f"Skipping downloading dataset from Google Drive, file [{target_ds_zip_file_path}] already exists.") 66 | 67 | unzipped_dataset_dir = datasets_dir.joinpath(f"{str(args.domain).title()}Dataset") 68 | 69 | if not unzipped_dataset_dir.is_dir(): 70 | # verify md5 71 | print("Verifying MD5 hash...") 72 | assert hashlib.md5(open(target_ds_zip_file_path, 'rb').read()).hexdigest() == md5 73 | 74 | print("Unzipping dataset...") 75 | with zipfile.ZipFile(target_ds_zip_file_path, 'r') as target_ds_zip_file: 76 | 
target_ds_zip_file.extractall(datasets_dir) 77 | else: 78 | print(f"Skipping dataset unzip, directory [{unzipped_dataset_dir}] already exists.") 79 | 80 | release_url = "https://github.com/threedle/GeoCode/releases/latest/download" 81 | 82 | if args.models_dir: 83 | best_ckpt_file_name = f"procedural_{args.domain}_last_ckpt.zip" 84 | latest_ckpt_file_name = f"procedural_{args.domain}_epoch{best_epoch:03d}_ckpt.zip" 85 | exp_target_dir = models_dir.joinpath(f"exp_geocode_{args.domain}") 86 | exp_target_dir.mkdir(exist_ok=True) 87 | 88 | best_ckpt_url = f"{release_url}/{best_ckpt_file_name}" 89 | best_ckpt_file_path = exp_target_dir.joinpath(best_ckpt_file_name) 90 | download_file(best_ckpt_url, best_ckpt_file_path) 91 | 92 | print(f"Unzipping checkpoint file [{best_ckpt_file_path}]...") 93 | with zipfile.ZipFile(best_ckpt_file_path, 'r') as best_ckpt_file: 94 | best_ckpt_file.extractall(exp_target_dir) 95 | 96 | latest_ckpt_url = f"{release_url}/{latest_ckpt_file_name}" 97 | latest_ckpt_file_path = exp_target_dir.joinpath(latest_ckpt_file_name) 98 | download_file(latest_ckpt_url, latest_ckpt_file_path) 99 | 100 | print(f"Unzipping checkpoint file [{latest_ckpt_file_path}]...") 101 | with zipfile.ZipFile(latest_ckpt_file_path, 'r') as latest_ckpt_file: 102 | latest_ckpt_file.extractall(exp_target_dir) 103 | 104 | if args.blends_dir: 105 | blend_file_name = f"procedural_{args.domain}.blend" 106 | blend_file_path = blends_dir.joinpath(blend_file_name) 107 | blend_url = f"{release_url}/{blend_file_name}" 108 | download_file(blend_url, blend_file_path) 109 | 110 | print("Done") 111 | 112 | 113 | def main(): 114 | parser = ArgumentParser(prog='download_ds') 115 | parser.add_argument('--domain', type=Domain, choices=list(Domain), required=True, help='The domain name to download the dataset for.') 116 | parser.add_argument('--datasets-dir', type=str, default=None, help='The directory to download the dataset to.') 117 | parser.add_argument('--models-dir', type=str, default=None, help='The directory to download checkpoint file to.') 118 | parser.add_argument('--blends-dir', type=str, default=None, help='The directory to download blend file to.') 119 | 120 | try: 121 | args = parser.parse_args() 122 | download_ds(args) 123 | except Exception as e: 124 | print(repr(e)) 125 | print(traceback.format_exc()) 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /scripts/download_ds_processing_scripts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import zipfile 4 | import hashlib 5 | import requests 6 | import traceback 7 | from pathlib import Path 8 | from argparse import ArgumentParser 9 | 10 | 11 | def download_file(url, target_file_path): 12 | print(f"Downloading file from [{url}] as [{target_file_path}]") 13 | req = requests.get(url, allow_redirects=True) 14 | with open(target_file_path, 'wb') as target_file: 15 | target_file.write(req.content) 16 | 17 | 18 | def download_ds(args): 19 | md5 = "b641562224202ff5afa86f023661e9c2" 20 | ds_url = "https://figshare.com/ndownloader/files/47453975?private_link=50549dabd53a72065749" 21 | ds_zip_file_name = "dataset_processing.zip" 22 | 23 | geocode_dir = Path('.').resolve() 24 | dataset_processing_dir_path = geocode_dir / "dataset_processing" 25 | dataset_processing_dir_path.mkdir(exist_ok=True) 26 | target_ds_processing_scripts_zip_file_path = dataset_processing_dir_path / ds_zip_file_name 27 | # 
download requested dataset processing scripts zip file from Google Drive 28 | if not target_ds_processing_scripts_zip_file_path.is_file(): 29 | download_file(ds_url, target_ds_processing_scripts_zip_file_path) 30 | else: 31 | print(f"Skipping downloading dataset from Google Drive, file [{target_ds_processing_scripts_zip_file_path}] already exists.") 32 | 33 | unzipped_dataset_dir = target_ds_processing_scripts_zip_file_path.with_suffix('') 34 | 35 | if not unzipped_dataset_dir.is_dir(): 36 | # verify md5 37 | print("Verifying MD5 hash...") 38 | assert hashlib.md5(open(target_ds_processing_scripts_zip_file_path, 'rb').read()).hexdigest() == md5 39 | 40 | print("Unzipping dataset...") 41 | with zipfile.ZipFile(target_ds_processing_scripts_zip_file_path, 'r') as target_ds_zip_file: 42 | target_ds_zip_file.extractall(dataset_processing_dir_path) 43 | else: 44 | print(f"Skipping dataset processing scripts unzip, directory [{unzipped_dataset_dir}] already exists.") 45 | 46 | print("Done") 47 | 48 | 49 | def main(): 50 | parser = ArgumentParser(prog='download_ds_processing_scripts') 51 | 52 | try: 53 | args = parser.parse_args() 54 | download_ds(args) 55 | except Exception as e: 56 | print(repr(e)) 57 | print(traceback.format_exc()) 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /scripts/install_blender4.2.sh: -------------------------------------------------------------------------------- 1 | mkdir ~/Blender 2 | cd ~/Blender 3 | wget https://mirror.freedif.org/blender/release/Blender4.2/blender-4.2.3-linux-x64.tar.xz 4 | tar -xvf blender-4.2.3-linux-x64.tar.xz 5 | BLENDER_PYTHON_BIN=~/Blender/blender-4.2.3-linux-x64/4.2/python/bin 6 | cd ${BLENDER_PYTHON_BIN} 7 | wget -P /tmp https://bootstrap.pypa.io/get-pip.py 8 | ${BLENDER_PYTHON_BIN}/python3.11 /tmp/get-pip.py 9 | ${BLENDER_PYTHON_BIN}/python3.11 -m pip install --upgrade pip 10 | ${BLENDER_PYTHON_BIN}/python3.11 -m pip install pyyaml 11 | ${BLENDER_PYTHON_BIN}/python3.11 -m pip install tqdm -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | setup(name='geocode', packages=find_packages()) 3 | -------------------------------------------------------------------------------- /stability_metric/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/stability_metric/__init__.py -------------------------------------------------------------------------------- /stability_metric/stability.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import bpy 4 | import sys 5 | import argparse 6 | import numpy as np 7 | import traceback 8 | from mathutils import Vector 9 | 10 | from pathlib import Path 11 | import importlib 12 | 13 | def import_parents(level=1): 14 | global __package__ 15 | file = Path(__file__).resolve() 16 | parent, top = file.parent, file.parents[level] 17 | 18 | sys.path.append(str(top)) 19 | try: 20 | sys.path.remove(str(parent)) 21 | except ValueError: 22 | pass 23 | 24 | __package__ = '.'.join(parent.parts[len(top.parts):]) 25 | importlib.import_module(__package__) 26 | 27 | if __name__ == '__main__' and __package__ is None: 28 | import_parents(level=1) 29 | 30 | from
common.bpy_util import normalize_scale, select_obj 31 | from common.intersection_util import detect_cross_intersection 32 | 33 | 34 | class DropSimulator: 35 | def __init__(self, args): 36 | self.drop_height = 0.1 37 | self.duration_sec = 5 38 | self.skip_components_check = args.skip_components_check 39 | self.apply_normalization = args.apply_normalization 40 | self.obj_file_path = Path(args.obj_path).expanduser() 41 | 42 | def simulate(self): 43 | print(f"Importing object file [{self.obj_file_path}]") 44 | bpy.ops.wm.obj_import(filepath=str(self.obj_file_path), use_split_objects=False) 45 | obj = bpy.context.selected_objects[0] 46 | obj.data.materials.clear() 47 | select_obj(obj) 48 | if self.apply_normalization: 49 | print("Normalizing object...") 50 | normalize_scale(obj) 51 | # set origin to center of mass 52 | bpy.ops.object.origin_set(type='ORIGIN_CENTER_OF_MASS', center='MEDIAN') 53 | vertices = np.array([(obj.matrix_world @ v.co) for v in obj.data.vertices]) 54 | # verify that the object is normalized 55 | max_diff = 0.05 56 | max_dist_from_center = abs(1.0 - np.max(np.sqrt(np.sum((vertices ** 2), axis=1)))) 57 | assert max_dist_from_center < max_diff, f"Point cloud is not normalized [{max_dist_from_center} > {max_diff}] for sample [{self.obj_file_path.name}]. If this is an external dataset, please consider adding --apply-normalization flag." 58 | # position the object at drop height 59 | obj.location = Vector((0, 0, obj.location.z - min(vertices[:, 2]) + self.drop_height)) 60 | height_before_drop = max(vertices[:, 2]) - min(vertices[:, 2]) 61 | 62 | # apply rigid body simulation 63 | bpy.ops.rigidbody.object_add() 64 | frame_end = self.duration_sec * 25 65 | area = [a for a in bpy.context.screen.areas if a.type == "VIEW_3D"][0] 66 | with bpy.context.temp_override(area=area): 67 | select_obj(obj) 68 | bpy.ops.rigidbody.bake_to_keyframes(frame_start=1, frame_end=frame_end, step=1) 69 | bpy.context.scene.frame_current = frame_end 70 | obj.data.update() 71 | bpy.context.view_layer.update() 72 | print("Simulation completed") 73 | self.eval(obj, height_before_drop) 74 | 75 | def eval(self, obj, height_before_drop): 76 | vertices = np.array([(obj.matrix_world @ v.co) for v in obj.data.vertices]) 77 | height_after_drop = max(vertices[:, 2]) - min(vertices[:, 2]) 78 | score = min(height_after_drop, height_before_drop) / max(height_after_drop, height_before_drop) 79 | score = 1.0 if score > 1.0 else score 80 | print(f"Height before simulation [{height_before_drop:.5f}]") 81 | print(f"Height after simulation [{height_after_drop:.5f}]") 82 | print(f"Score [{score:.5f}]") 83 | print(f"is_stable (score > 0.98) [{score > 0.98}]") 84 | self.reset_simulation(obj) 85 | # structural evaluation 86 | if self.skip_components_check: 87 | print("is_structurally_valid (shape is connected) [True] (check skipped)") 88 | else: 89 | print("Checking structural validity...") 90 | obj_is_valid = self.is_structurally_connected() 91 | print(f"is_structurally_valid (shape is connected) [{obj_is_valid}]") 92 | 93 | def is_structurally_connected(self): 94 | """ 95 | return True if all the parts that make the shape are reachable from any other part 96 | two parts are connected if they are intersecting or there is a path from one part 97 | to the other that passes only through intersecting parts 98 | """ 99 | bpy.ops.wm.obj_import(filepath=str(self.obj_file_path), use_split_objects=False) 100 | obj = bpy.context.selected_objects[0] 101 | select_obj(obj) 102 | bpy.ops.object.mode_set(mode='EDIT') 103 | 
bpy.ops.mesh.separate(type='LOOSE') 104 | bpy.ops.object.mode_set(mode='OBJECT') 105 | parts = bpy.context.selected_objects 106 | # in the beginning, each part is put into a separate set 107 | components = [] 108 | for part in parts: 109 | components.append({part}) 110 | 111 | idx_a = 0 112 | while idx_a + 1 < len(components): 113 | component_a = components[idx_a] 114 | found_intersection = False 115 | for idx_b in range(idx_a + 1, len(components)): 116 | component_b = components[idx_b] 117 | for part_a in component_a: 118 | for part_b in component_b: 119 | assert part_a.name != part_b.name 120 | if len(detect_cross_intersection(part_a, part_b)) > 0: 121 | components.remove(component_a) 122 | components.remove(component_b) 123 | components.append(component_a.union(component_b)) 124 | found_intersection = True 125 | break 126 | if found_intersection: 127 | break 128 | if found_intersection: 129 | break 130 | if not found_intersection: 131 | idx_a += 1 132 | # note that we can 'break' here and return False if we are only looking to have a single connected component 133 | bpy.ops.object.delete() 134 | return len(components) <= 1 135 | 136 | @staticmethod 137 | def reset_simulation(obj): 138 | bpy.context.scene.frame_current = 0 139 | obj.data.update() 140 | bpy.context.view_layer.update() 141 | bpy.ops.object.delete() 142 | 143 | 144 | def main(): 145 | if '--' in sys.argv: 146 | # refer to https://b3d.interplanety.org/en/how-to-pass-command-line-arguments-to-a-blender-python-script-or-add-on/ 147 | argv = sys.argv[sys.argv.index('--') + 1:] 148 | else: 149 | raise Exception("Expected \'--\' followed by arguments to the script") 150 | 151 | parser = argparse.ArgumentParser(prog='stability') 152 | parser.add_argument('--obj-path', type=str, required=True, help='Path to the object to test') 153 | parser.add_argument('--apply-normalization', action='store_true', default=False, help='Apply normalization on the object upon importing') 154 | parser.add_argument('--skip-components-check', action='store_true', default=False, help='Do not check that the shape is structurally valid') 155 | 156 | try: 157 | args = parser.parse_known_args(argv)[0] 158 | drop_simulator = DropSimulator(args) 159 | drop_simulator.simulate() 160 | except Exception as e: 161 | print(repr(e)) 162 | print(traceback.format_exc()) 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | -------------------------------------------------------------------------------- /stability_metric/stability_parallel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | import random 5 | import argparse 6 | import traceback 7 | import multiprocessing 8 | import subprocess 9 | from subprocess import Popen 10 | from functools import partial 11 | from pathlib import Path 12 | 13 | 14 | def calculate_stability_proc(pred_obj_file_path, apply_normalization, skip_components_check, blender_exe: Path): 15 | print(f"Calculating stability for object [{pred_obj_file_path}]") 16 | simulation_blend_file_path = Path(__file__).parent.joinpath('stability_simulation.blend').resolve() 17 | stability_script_path = Path(__file__).parent.joinpath('stability.py').resolve() 18 | cmd = [str(blender_exe.expanduser()), 19 | str(simulation_blend_file_path), 20 | '-b', '--python', str(stability_script_path), '--', 21 | 'sim-obj', '--obj-path', str(pred_obj_file_path)] 22 | if apply_normalization: 23 | cmd.append('--apply-normalization') 24 | if skip_components_check: 25 | 
cmd.append('--skip-components-check') 26 | print(" ".join(cmd)) 27 | process = Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True) 28 | out, err = process.communicate() 29 | result = out.splitlines() 30 | score_str_list = [line for line in result if 'Score [' in line] 31 | structurally_valid_str_list = [line for line in result if 'is_structurally_valid' in line] 32 | assert score_str_list, out 33 | score = float(score_str_list[0][-8:-1]) 34 | assert structurally_valid_str_list, out 35 | is_structurally_valid = True if 'True' in structurally_valid_str_list[0] else False 36 | assert score > 0.1 37 | return score, is_structurally_valid 38 | 39 | 40 | def sim_dir_parallel(args): 41 | cpu_count = multiprocessing.cpu_count() 42 | stability_json = {} 43 | count_stable = 0 44 | count_structurally_valid = 0 45 | count_good = 0 46 | dir_path = Path(args.dir_path).resolve() 47 | print(f"Calculating stability for dir [{dir_path}] with [{cpu_count}] processes") 48 | try: 49 | obj_files = sorted(dir_path.glob("*.obj")) 50 | print(len(obj_files)) 51 | if args.limit and args.limit < len(obj_files): 52 | obj_files = random.sample(obj_files, args.limit) 53 | blender_exe = Path(args.blender_exe).resolve() 54 | calculate_stability_proc_partial = partial(calculate_stability_proc, 55 | apply_normalization=args.apply_normalization, 56 | skip_components_check=args.skip_components_check, 57 | blender_exe=blender_exe) 58 | p = multiprocessing.Pool(cpu_count) 59 | stability_results = p.map(calculate_stability_proc_partial, obj_files) 60 | p.close() 61 | p.join() 62 | for obj_file_idx, obj_file in enumerate(obj_files): 63 | stability_json[str(obj_file)] = stability_results[obj_file_idx] 64 | score = stability_results[obj_file_idx][0] 65 | is_structurally_valid = stability_results[obj_file_idx][1] 66 | count_stable += 1 if score > 0.98 else 0 67 | count_structurally_valid += 1 if is_structurally_valid else 0 68 | count_good += 1 if (score > 0.98 and is_structurally_valid) else 0 69 | except Exception as e: 70 | print(traceback.format_exc()) 71 | print(repr(e)) 72 | sample_count = len(stability_json) 73 | print(f"# stable samples [{count_stable}] out of total [{sample_count}]") 74 | print(f"# structurally valid samples [{count_structurally_valid}] out of total [{sample_count}]") 75 | print(f"# good samples [{count_good}] out of total [{sample_count}] = [{(count_good/sample_count) * 100}%]") 76 | # save the detailed results to a json file 77 | stability_json['execution-details'] = {} 78 | stability_json['execution-details']['dir-path'] = str(dir_path) 79 | json_result_file_path = Path(__file__).parent.joinpath('stability_results.json').resolve() 80 | with open(json_result_file_path, 'w') as json_result_file: 81 | json.dump(stability_json, json_result_file) 82 | print(f"Results per .obj file were saved to [{json_result_file_path}]") 83 | 84 | 85 | def main(): 86 | parser = argparse.ArgumentParser(prog='stability_parallel') 87 | parser.add_argument('--dir-path', type=str, required=True, help='Path to the dir to test with the \'stability metric\'') 88 | parser.add_argument('--blender-exe', type=str, required=True, help='Path to Blender executable') 89 | parser.add_argument('--skip-components-check', action='store_true', default=False, help='Skip checking if the shape is structurally valid') 90 | parser.add_argument('--apply-normalization', action='store_true', default=False, help='Apply normalization on the imported objects') 91 | parser.add_argument('--limit', type=int, default=None, 
help='Limit the number of shapes that will be evaluated, randomly selected shapes will be tested') 92 | 93 | try: 94 | args = parser.parse_args() 95 | sim_dir_parallel(args) 96 | except Exception as e: 97 | print(repr(e)) 98 | print(traceback.format_exc()) 99 | 100 | 101 | if __name__ == '__main__': 102 | main() 103 | -------------------------------------------------------------------------------- /stability_metric/stability_simulation.blend: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/threedle/GeoCode/8fc8e4d98a7375ab21c690224bce858330cd2c4f/stability_metric/stability_simulation.blend -------------------------------------------------------------------------------- /visualize_results/visualize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import sys 4 | import bpy 5 | import math 6 | import random 7 | import argparse 8 | import traceback 9 | import mathutils 10 | from pathlib import Path 11 | from mathutils import Vector 12 | import importlib 13 | 14 | def import_parents(level=1): 15 | global __package__ 16 | file = Path(__file__).resolve() 17 | parent, top = file.parent, file.parents[level] 18 | sys.path.append(str(top)) 19 | try: 20 | sys.path.remove(str(parent)) 21 | except ValueError: 22 | pass 23 | 24 | __package__ = '.'.join(parent.parts[len(top.parts):]) 25 | importlib.import_module(__package__) 26 | 27 | if __name__ == '__main__' and __package__ is None: 28 | import_parents(level=1) 29 | 30 | from common.file_util import hash_file_name 31 | from common.bpy_util import clean_scene, setup_lights, select_objs, normalize_scale, del_obj, look_at 32 | 33 | 34 | def add_3d_text(obj_to_align_with, text): 35 | """ 36 | Adds 3D text object in front of the normalized object 37 | """ 38 | bpy.ops.object.select_all(action='DESELECT') 39 | font_curve = bpy.data.curves.new(type="FONT", name="Font Curve") 40 | font_curve.body = text 41 | font_obj = bpy.data.objects.new(name="Font Object", object_data=font_curve) 42 | bpy.context.scene.collection.objects.link(font_obj) 43 | font_obj.select_set(True) 44 | bpy.context.view_layer.objects.active = font_obj 45 | bpy.ops.object.transform_apply(location=False, rotation=False, scale=True) 46 | bpy.ops.object.origin_set(type='ORIGIN_GEOMETRY', center='MEDIAN') 47 | font_obj.location.x = obj_to_align_with.location.x 48 | font_obj.location.y = obj_to_align_with.location.y 49 | font_obj.location.z = 0 50 | font_obj.scale.x *= 0.2 51 | font_obj.scale.y *= 0.2 52 | font_obj.location.y -= 1 53 | return font_obj 54 | 55 | 56 | def visualize_results(args): 57 | """ 58 | Before using this method, a test dataset should be evaluated using the model 59 | """ 60 | test_ds_dir = Path(args.dataset_dir, args.phase).expanduser() 61 | if not test_ds_dir.is_dir(): 62 | raise Exception(f"Expected a \'{args.phase}\' dataset directory with 3D object to evaluate") 63 | 64 | results_dir = test_ds_dir.joinpath(f'results_{args.exp_name}') 65 | model_predictions_pc_dir = results_dir.joinpath('yml_predictions_pc') 66 | if not model_predictions_pc_dir.is_dir(): 67 | raise Exception(f"Expected a \'results_{args.exp_name}/yml_predictions_pc\' directory with predictions from point clouds") 68 | 69 | model_predictions_sketch_dir = results_dir.joinpath('yml_predictions_sketch') 70 | if not model_predictions_sketch_dir.is_dir(): 71 | raise Exception(f"Expected a \'results_{args.exp_name}/yml_predictions_sketch\' directory with predictions from 
skeches") 72 | 73 | obj_gt_dir = results_dir.joinpath('obj_gt') 74 | render_gt_dir = results_dir.joinpath('render_gt') 75 | obj_predictions_pc_dir = results_dir.joinpath('obj_predictions_pc') 76 | render_predictions_pc_dir = results_dir.joinpath('render_predictions_pc') 77 | obj_predictions_sketch_dir = results_dir.joinpath('obj_predictions_sketch') 78 | render_predictions_sketch_dir = results_dir.joinpath('render_predictions_sketch') 79 | 80 | work = [ 81 | (obj_gt_dir, render_gt_dir, "GT"), # render original 3D objs 82 | (obj_predictions_pc_dir, render_predictions_pc_dir, "PRED FROM PC"), # render predictions from point cloud input 83 | (obj_predictions_sketch_dir, render_predictions_sketch_dir, "PRED FROM SKETCH") # render predictions from sketch input 84 | ] 85 | 86 | try: 87 | clean_scene(start_with_strings=["Camera", "Light"]) 88 | setup_lights() 89 | # hide the main collections 90 | bpy.context.view_layer.layer_collection.children['Main'].hide_viewport = True 91 | bpy.context.view_layer.layer_collection.children['Main'].exclude = True 92 | for obj_dir, render_dir, title in work: 93 | file_names = sorted([f.stem for f in obj_dir.glob("*.obj")]) 94 | if args.parallel > 1: 95 | file_names = [file for file in file_names if hash_file_name(file) % args.parallel == args.mod] 96 | for file_name in file_names: 97 | original_obj_file_path = obj_dir.joinpath(f'{file_name}.obj') 98 | bpy.ops.wm.obj_import(filepath=str(original_obj_file_path)) 99 | imported_object = bpy.context.selected_objects[0] 100 | imported_object.hide_render = False 101 | imported_object.data.materials.clear() 102 | normalize_scale(imported_object) 103 | title_obj = add_3d_text(imported_object, title) 104 | render_images(render_dir, file_name) 105 | select_objs(title_obj, imported_object) 106 | bpy.ops.object.delete() 107 | except Exception as e: 108 | print(repr(e)) 109 | print(traceback.format_exc()) 110 | 111 | 112 | def render_images(target_dir: Path, file_name, suffix=None): 113 | # euler setting 114 | camera_angles = [ 115 | [-30.0, -35.0] 116 | ] 117 | radius = 2 118 | eulers = [mathutils.Euler((math.radians(camera_angle[0]), 0.0, math.radians(camera_angle[1])), 'XYZ') for 119 | camera_angle in camera_angles] 120 | 121 | for i, eul in enumerate(eulers): 122 | target_file_name = f"{file_name}{(f'_{suffix}' if suffix else '')}_at_{camera_angles[i][0]:.1f}_{camera_angles[i][1]:.1f}.png" 123 | target_file = target_dir.joinpath(target_file_name) 124 | 125 | # camera setting 126 | cam_pos = mathutils.Vector((0.0, -radius, 0.0)) 127 | cam_pos.rotate(eul) 128 | if i < 4: 129 | rand_x = random.uniform(-2.0, 2.0) 130 | rand_z = random.uniform(-5.0, 5.0) 131 | eul_perturb = mathutils.Euler((math.radians(rand_x), 0.0, math.radians(rand_z)), 'XYZ') 132 | cam_pos.rotate(eul_perturb) 133 | 134 | scene = bpy.context.scene 135 | bpy.ops.object.camera_add(enter_editmode=False, location=cam_pos) 136 | new_camera = bpy.context.active_object 137 | new_camera.name = "camera_tmp" 138 | new_camera.data.name = "camera_tmp" 139 | new_camera.data.lens_unit = 'FOV' 140 | new_camera.data.angle = math.radians(60) 141 | look_at(new_camera, Vector((0.0, 0.0, 0.0))) 142 | 143 | # render 144 | scene.camera = new_camera 145 | scene.render.filepath = str(target_file) 146 | scene.render.resolution_x = 224 147 | scene.render.resolution_y = 224 148 | bpy.context.scene.cycles.samples = 5 149 | # disable the sketch shader 150 | bpy.context.scene.render.use_freestyle = False 151 | bpy.ops.render.render(write_still=True) 152 | 153 | # prepare for the next 
camera 154 | del_obj(new_camera) 155 | 156 | 157 | def main(): 158 | if '--' in sys.argv: 159 | # refer to https://b3d.interplanety.org/en/how-to-pass-command-line-arguments-to-a-blender-python-script-or-add-on/ 160 | argv = sys.argv[sys.argv.index('--') + 1:] 161 | else: 162 | raise Exception("Expected \'--\' followed by arguments to the script") 163 | 164 | parser = argparse.ArgumentParser(prog='visualize') 165 | parser.add_argument('--dataset-dir', type=str, required=True, help='Path to dataset directory') 166 | parser.add_argument('--phase', type=str, required=True, help='E.g. train, test or val') 167 | parser.add_argument('--exp-name', type=str, required=True) 168 | parser.add_argument('--parallel', type=int, default=1) 169 | parser.add_argument('--mod', type=int, default=None) 170 | 171 | try: 172 | args = parser.parse_known_args(argv)[0] 173 | visualize_results(args) 174 | except Exception as e: 175 | print(repr(e)) 176 | print(traceback.format_exc()) 177 | 178 | 179 | if __name__ == '__main__': 180 | main() 181 | --------------------------------------------------------------------------------
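Like stability.py, visualize_results/visualize.py is designed to run inside Blender in background mode, with its own arguments passed after a literal '--'. Below is a minimal, hedged driver sketch in the spirit of calculate_stability_proc from stability_parallel.py above; the Blender executable location (taken from install_blender4.2.sh), the .blend file, the dataset directory, and the experiment name are placeholders/assumptions rather than values prescribed by the repository.

#!/usr/bin/env python3
# Hedged driver sketch: launch Blender headless and forward arguments to
# visualize_results/visualize.py after '--'. All concrete paths and names are placeholders.
import subprocess
from pathlib import Path

blender_exe = Path("~/Blender/blender-4.2.3-linux-x64/blender").expanduser()  # install location from install_blender4.2.sh
blend_file = Path("~/blends/procedural_chair.blend").expanduser()  # assumed scene; visualize.py expects a 'Main' collection
script_path = Path("visualize_results/visualize.py").resolve()

cmd = [
    str(blender_exe), str(blend_file),
    "-b", "--python", str(script_path), "--",
    "--dataset-dir", "datasets/ChairDataset",  # placeholder dataset location (see scripts/download_ds.py)
    "--phase", "test",
    "--exp-name", "exp_geocode_chair",  # placeholder; must match the results_<exp-name> directory
]
print(" ".join(cmd))
subprocess.run(cmd, check=True)

Since each Blender process is independent, several such invocations can be run side by side, which is the same design choice stability_parallel.py makes with multiprocessing.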