├── .DS_Store ├── LICENSE ├── README.md ├── Render ├── .DS_Store ├── blender_utils.py ├── bpy_render.py ├── data_generation.py └── py_render.py ├── core ├── .DS_Store ├── __init__.py ├── dataset.py ├── loss.py ├── model.py ├── transform.py └── utils.py ├── objects.yaml ├── pretrained_models ├── checkpoint_1.pth └── checkpoint_2.pth ├── test.py └── train.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sailor-z/Unseen_Object_Pose/42d10d4498660882874a36d4797397737857d0ac/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Chen Zhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fusing Local Similarities for Retrieval-based 3D Orientation Estimation of Unseen Objects 2 | PyTorch implementation of the paper "Fusing Local Similarities for Retrieval-based 3D Orientation Estimation of Unseen Objects" (ECCV 2022) 3 | * [[project page](https://sailor-z.github.io/projects/Unseen_Object_Pose.html)] 4 | * [[paper](https://arxiv.org/abs/2203.08472)] 5 | 6 | # Citation 7 | If you find the code useful, please consider citing: 8 | ```bibtex 9 | @article{zhao2022fusing, 10 | title={Fusing Local Similarities for Retrieval-based 3D Orientation Estimation of Unseen Objects}, 11 | author={Zhao, Chen and Hu, Yinlin and Salzmann, Mathieu}, 12 | journal={arXiv preprint arXiv:2203.08472}, 13 | year={2022} 14 | } 15 | ``` 16 | # Setup 17 | Our code has been tested with the following dependencies: Python 3.7.11, Pytorch 1.7.1, Python-Blender 2.8, Pytorch3d 0.6.0, Python-OpenCV 3.4.2.17, Imutils 0.5.4.
Please start by installing all the dependencies: 18 | 19 | conda create -n UnseenObjectPose python=3.7.11 20 | source activate UnseenObjectPose 21 | conda install -c conda-forge imutils 22 | conda install -c pytorch pytorch=1.7.1 torchvision 23 | conda install pytorch3d -c pytorch3d 24 | conda install -c jewfrocuban python-blender 25 | pip install opencv-python 26 | pip install tqdm 27 | pip install imageio 28 | pip install pyyaml 29 | 30 | # Data Processing 31 | First please download the LineMOD dataset we used in our experiments from [LineMOD](https://u.pcloud.link/publink/show?code=XZrVD8VZCwypoMMPVA5QF0WeevE3SyyaeR07). The data should be organized as 32 | 33 | UnseenObjectPose 34 | |-- data 35 | |-- linemod_zhs 36 | |-- models 37 | |-- real 38 | 39 | ### Rendering 40 | We generate 10,000 reference images for each object by rendering. The number of reference images and data path can be changed by modifying ```cfg["RENDER"]["NUM"]``` and ```cfg["RENDER"]["OUTPUT_PATH"]```, respectively. Please run the following code for rendering: 41 | 42 | cd ./Render 43 | python data_generation.py 44 | 45 | # Training 46 | 47 | The model can be trained on LineMOD by running: 48 | 49 | python ./train.py 50 | 51 | The unseen objects are ```'APE', 'BENCHVISE', 'CAM', 'CAN'``` by default. 52 | 53 | The pretrained models are also available at ```./pretrained_models```. 54 | -------------------------------------------------------------------------------- /Render/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sailor-z/Unseen_Object_Pose/42d10d4498660882874a36d4797397737857d0ac/Render/.DS_Store -------------------------------------------------------------------------------- /Render/blender_utils.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | from mathutils import Matrix 3 | import numpy as np 4 | import math 5 | import os 6 | import cv2 7 | 8 | def matrixToNumpyArray(mat): 9 | new_mat = np.array([[mat[0][0],mat[0][1],mat[0][2],mat[0][3]], 10 | [mat[1][0],mat[1][1],mat[1][2],mat[1][3]], 11 | [mat[2][0],mat[2][1],mat[2][2],mat[2][3]], 12 | [mat[3][0],mat[3][1],mat[3][2],mat[3][3]]]) 13 | return new_mat 14 | 15 | def numpyArrayToMatrix(array): 16 | mat = Matrix(((array[0,0],array[0,1], array[0,2], array[0,3]), 17 | (array[1,0],array[1,1], array[1,2], array[1,3]), 18 | (array[2,0],array[2,1], array[2,2], array[2,3]) , 19 | (array[3,0],array[3,1], array[3,2], array[3,3]))) 20 | return mat 21 | 22 | def save_visual(rgb, mask, output_path, euler): 23 | path = os.path.join(output_path, "rgb", str(euler[0]) + "_" + str(euler[1]) + "_" + str(euler[2]) + ".png") 24 | cv2.imwrite(path, rgb) 25 | path = os.path.join(output_path, "mask", str(euler[0]) + "_" + str(euler[1]) + "_" + str(euler[2]) + ".png") 26 | cv2.imwrite(path, mask) 27 | 28 | def draw_bounding_box(cvImg, kpts_2d, color, thickness): 29 | x = np.int32(kpts_2d[0]) 30 | y = np.int32(kpts_2d[1]) 31 | bbox_lines = [0, 1, 0, 2, 0, 4, 5, 1, 5, 4, 6, 2, 6, 4, 3, 2, 3, 1, 7, 3, 7, 5, 7, 6] 32 | for i in range(12): 33 | id1 = bbox_lines[2*i] 34 | id2 = bbox_lines[2*i+1] 35 | cvImg = cv2.line(cvImg, (x[id1],y[id1]), (x[id2],y[id2]), color, thickness=thickness, lineType=cv2.LINE_AA) 36 | return cvImg 37 | 38 | # we could also define the camera matrix 39 | # https://blender.stackexchange.com/questions/38009/3x4-camera-matrix-from-blender-camera 40 | def get_calibration_matrix_K_from_blender(camera): 41 | f_in_mm = camera.lens 42 | scene 
= bpy.context.scene 43 | resolution_x_in_px = scene.render.resolution_x 44 | resolution_y_in_px = scene.render.resolution_y 45 | scale = scene.render.resolution_percentage / 100 46 | sensor_width_in_mm = camera.sensor_width 47 | sensor_height_in_mm = camera.sensor_height 48 | pixel_aspect_ratio = scene.render.pixel_aspect_x / scene.render.pixel_aspect_y 49 | 50 | if camera.sensor_fit == 'VERTICAL': 51 | # the sensor height is fixed (sensor fit is horizontal), 52 | # the sensor width is effectively changed with the pixel aspect ratio 53 | s_u = resolution_x_in_px * scale / sensor_width_in_mm / pixel_aspect_ratio 54 | s_v = resolution_y_in_px * scale / sensor_height_in_mm 55 | else: # 'HORIZONTAL' and 'AUTO' 56 | # the sensor width is fixed (sensor fit is horizontal), 57 | # the sensor height is effectively changed with the pixel aspect ratio 58 | s_u = resolution_x_in_px * scale / sensor_width_in_mm 59 | s_v = resolution_y_in_px * scale * pixel_aspect_ratio / sensor_height_in_mm 60 | 61 | # Parameters of intrinsic calibration matrix K 62 | alpha_u = f_in_mm * s_u 63 | alpha_v = f_in_mm * s_u 64 | u_0 = resolution_x_in_px * scale / 2 65 | v_0 = resolution_y_in_px * scale / 2 66 | skew = 0 # only use rectangular pixels 67 | 68 | K = Matrix(((alpha_u, skew, u_0), 69 | (0, alpha_v, v_0), 70 | (0, 0, 1))) 71 | 72 | return K 73 | 74 | 75 | # Returns camera rotation and translation matrices from Blender. 76 | # 77 | # There are 3 coordinate systems involved: 78 | # 1. The World coordinates: "world" 79 | # - right-handed 80 | # 2. The Blender camera coordinates: "bcam" 81 | # - x is horizontal 82 | # - y is up 83 | # - right-handed: negative z look-at direction 84 | # 3. The desired computer vision camera coordinates: "cv" 85 | # - x is horizontal 86 | # - y is down (to align to the actual pixel coordinates 87 | # used in digital images) 88 | # - right-handed: positive z look-at direction 89 | def get_3x4_RT_matrix_from_blender(camera): 90 | # bcam stands for blender camera 91 | R_bcam2cv = Matrix( 92 | ((1, 0, 0), 93 | (0, -1, 0), 94 | (0, 0, -1))) 95 | 96 | # Use matrix_world instead to account for all constraints 97 | location, rotation = camera.matrix_world.decompose()[0:2] 98 | R_world2bcam = rotation.to_matrix().transposed() 99 | 100 | # Convert camera location to translation vector used in coordinate changes 101 | # Use location from matrix_world to account for constraints: 102 | T_world2bcam = -1 * R_world2bcam @ location 103 | 104 | # Build the coordinate transform matrix from world to computer vision camera 105 | R_world2cv = R_bcam2cv @ R_world2bcam 106 | T_world2cv = R_bcam2cv @ T_world2bcam 107 | 108 | # put into 3x4 matrix 109 | RT = Matrix((R_world2cv[0][:] + (T_world2cv[0],), 110 | R_world2cv[1][:] + (T_world2cv[1],), 111 | R_world2cv[2][:] + (T_world2cv[2],))) 112 | return RT 113 | 114 | 115 | def get_3x4_P_matrix_from_blender(camera): 116 | K = get_calibration_matrix_K_from_blender(camera.data) 117 | RT = get_3x4_RT_matrix_from_blender(camera) 118 | return K*RT 119 | 120 | 121 | def get_K_P_from_blender(camera): 122 | K = get_calibration_matrix_K_from_blender(camera.data) 123 | RT = get_3x4_RT_matrix_from_blender(camera) 124 | return {"K": np.asarray(K, dtype=np.float32), "RT": np.asarray(RT, dtype=np.float32)} 125 | 126 | def quaternionFromYawPitchRoll(yaw, pitch, roll): 127 | c1 = math.cos(yaw / 2.0) 128 | c2 = math.cos(pitch / 2.0) 129 | c3 = math.cos(roll / 2.0) 130 | s1 = math.sin(yaw / 2.0) 131 | s2 = math.sin(pitch / 2.0) 132 | s3 = math.sin(roll / 2.0) 133 | q1 = c1 * c2 * 
c3 + s1 * s2 * s3 134 | q2 = c1 * c2 * s3 - s1 * s2 * c3 135 | q3 = c1 * s2 * c3 + s1 * c2 * s3 136 | q4 = s1 * c2 * c3 - c1 * s2 * s3 137 | return q1, q2, q3, q4 138 | 139 | 140 | def camPosToQuaternion(cx, cy, cz): 141 | q1a = 0 142 | q1b = 0 143 | q1c = math.sqrt(2) / 2 144 | q1d = math.sqrt(2) / 2 145 | camDist = math.sqrt(cx * cx + cy * cy + cz * cz) 146 | cx = cx / camDist 147 | cy = cy / camDist 148 | cz = cz / camDist 149 | t = math.sqrt(cx * cx + cy * cy) 150 | tx = cx / t 151 | ty = cy / t 152 | yaw = math.acos(ty) 153 | if tx > 0: 154 | yaw = 2 * math.pi - yaw 155 | pitch = 0 156 | tmp = min(max(tx * cx + ty * cy, -1), 1) 157 | # roll = math.acos(tx * cx + ty * cy) 158 | roll = math.acos(tmp) 159 | if cz < 0: 160 | roll = -roll 161 | print("%f %f %f" % (yaw, pitch, roll)) 162 | q2a, q2b, q2c, q2d = quaternionFromYawPitchRoll(yaw, pitch, roll) 163 | q1 = q1a * q2a - q1b * q2b - q1c * q2c - q1d * q2d 164 | q2 = q1b * q2a + q1a * q2b + q1d * q2c - q1c * q2d 165 | q3 = q1c * q2a - q1d * q2b + q1a * q2c + q1b * q2d 166 | q4 = q1d * q2a + q1c * q2b - q1b * q2c + q1a * q2d 167 | return q1, q2, q3, q4 168 | 169 | 170 | def camRotQuaternion(cx, cy, cz, theta): 171 | theta = theta / 180.0 * math.pi 172 | camDist = math.sqrt(cx * cx + cy * cy + cz * cz) 173 | cx = -cx / camDist 174 | cy = -cy / camDist 175 | cz = -cz / camDist 176 | q1 = math.cos(theta * 0.5) 177 | q2 = -cx * math.sin(theta * 0.5) 178 | q3 = -cy * math.sin(theta * 0.5) 179 | q4 = -cz * math.sin(theta * 0.5) 180 | return q1, q2, q3, q4 181 | 182 | 183 | def quaternionProduct(qx, qy): 184 | a = qx[0] 185 | b = qx[1] 186 | c = qx[2] 187 | d = qx[3] 188 | e = qy[0] 189 | f = qy[1] 190 | g = qy[2] 191 | h = qy[3] 192 | q1 = a * e - b * f - c * g - d * h 193 | q2 = a * f + b * e + c * h - d * g 194 | q3 = a * g - b * h + c * e + d * f 195 | q4 = a * h + b * g - c * f + d * e 196 | return q1, q2, q3, q4 197 | 198 | def obj_location(dist, azi, ele): 199 | ele = math.radians(ele) 200 | azi = math.radians(azi) 201 | x = dist * math.cos(azi) * math.cos(ele) 202 | y = dist * math.sin(azi) * math.cos(ele) 203 | z = dist * math.sin(ele) 204 | return x, y, z 205 | 206 | def obj_centened_camera_pos(dist, azimuth_deg, elevation_deg): 207 | phi = float(elevation_deg) / 180 * math.pi 208 | theta = float(azimuth_deg) / 180 * math.pi 209 | x = (dist * math.cos(theta) * math.cos(phi)) 210 | y = (dist * math.sin(theta) * math.cos(phi)) 211 | z = (dist * math.sin(phi)) 212 | return x, y, z 213 | -------------------------------------------------------------------------------- /Render/bpy_render.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple script that uses bpy to render views of a single object by 3 | move the camera around it. 
4 | 5 | Original source: 6 | https://github.com/panmari/stanford-shapenet-renderer 7 | """ 8 | 9 | import os, sys 10 | import bpy 11 | import math 12 | from math import radians 13 | from tqdm import tqdm 14 | from PIL import Image 15 | import numpy as np 16 | import cv2 17 | import pickle 18 | from blender_utils import obj_location, save_visual 19 | import imutils 20 | pi = math.pi 21 | np.random.seed(0) 22 | 23 | def reset_blend(): 24 | bpy.ops.wm.read_factory_settings() 25 | 26 | for scene in bpy.data.scenes: 27 | for obj in scene.objects: 28 | scene.objects.unlink(obj) 29 | 30 | # only worry about data in the startup scene 31 | for bpy_data_iter in ( 32 | bpy.data.objects, 33 | bpy.data.meshes, 34 | bpy.data.lamps, 35 | bpy.data.cameras, 36 | ): 37 | for id_data in bpy_data_iter: 38 | bpy_data_iter.remove(id_data) 39 | 40 | def resize_padding(im, desired_size): 41 | # compute the new size 42 | old_size = im.size 43 | ratio = float(desired_size)/max(old_size) 44 | new_size = tuple([int(x*ratio) for x in old_size]) 45 | im = im.resize(new_size, Image.ANTIALIAS) 46 | 47 | # create a new image and paste the resized on it 48 | new_im = Image.new("RGBA", (desired_size, desired_size)) 49 | new_im.paste(im, ((desired_size - new_size[0]) // 2, (desired_size - new_size[1]) // 2)) 50 | 51 | return new_im 52 | 53 | 54 | def resize_padding_v2(im, desired_size_in, desired_size_out): 55 | # compute the new size 56 | old_size = im.size 57 | ratio = float(desired_size_in)/max(old_size) 58 | new_size = tuple([int(x*ratio) for x in old_size]) 59 | 60 | im = im.resize(new_size, Image.ANTIALIAS) 61 | 62 | # create a new image and paste the resized on it 63 | new_im = Image.new("RGBA", (desired_size_out, desired_size_out)) 64 | new_im.paste(im, ((desired_size_out - new_size[0]) // 2, (desired_size_out - new_size[1]) // 2)) 65 | return new_im 66 | 67 | 68 | # create a lamp with an appropriate energy 69 | def makeLamp(lamp_name, rad): 70 | # Create new lamp data block 71 | lamp_data = bpy.data.lights.new(name=lamp_name, type='POINT') 72 | lamp_data.energy = rad 73 | # modify the distance when the object is not normalized 74 | # lamp_data.distance = rad * 2.5 75 | 76 | # Create new object with our lamp data block 77 | lamp_object = bpy.data.objects.new(name=lamp_name, object_data=lamp_data) 78 | # Link lamp object to the scene so it'll appear in this scene 79 | scene = bpy.context.collection 80 | scene.objects.link(lamp_object) 81 | return lamp_object 82 | 83 | 84 | def parent_obj_to_camera(b_camera): 85 | # set the parenting to the origin 86 | origin = (0, 0, 0) 87 | b_empty = bpy.data.objects.new("Empty", None) 88 | b_empty.location = origin 89 | b_camera.parent = b_empty 90 | 91 | scn = bpy.context.collection 92 | scn.objects.link(b_empty) 93 | bpy.context.view_layer.objects.active = b_empty 94 | return b_empty 95 | 96 | 97 | def clean_obj_lamp_and_mesh(context): 98 | scene = context.collection 99 | objs = bpy.data.objects 100 | meshes = bpy.data.meshes 101 | for obj in objs: 102 | if obj.type == "MESH" or obj.type == 'LAMP': 103 | scene.objects.unlink(obj) 104 | objs.remove(obj) 105 | for mesh in meshes: 106 | meshes.remove(mesh) 107 | 108 | def add_shader_on_world(): 109 | bpy.data.worlds['World'].use_nodes = True 110 | env_node = bpy.data.worlds['World'].node_tree.nodes.new(type='ShaderNodeTexEnvironment') 111 | back_node = bpy.data.worlds['World'].node_tree.nodes['Background'] 112 | bpy.data.worlds['World'].node_tree.links.new(env_node.outputs['Color'], back_node.inputs['Color']) 113 | 114 | def 
set_material_node_parameters(material): 115 | nodes = material.node_tree.nodes 116 | nodes['Principled BSDF'].inputs['Roughness'].default_value = 0.5#np.random.uniform(0.8, 1) 117 | 118 | 119 | def add_shader_on_ply_object(obj): 120 | material = bpy.data.materials.new("VertCol") 121 | 122 | material.use_nodes = True 123 | material.node_tree.links.clear() 124 | 125 | mat_out = material.node_tree.nodes['Material Output'] 126 | diffuse_node = material.node_tree.nodes['Principled BSDF'] 127 | attr_node = material.node_tree.nodes.new(type='ShaderNodeAttribute') 128 | 129 | attr_node.attribute_name = 'Col' 130 | 131 | material.node_tree.links.new(attr_node.outputs['Color'], diffuse_node.inputs['Base Color']) 132 | material.node_tree.links.new(diffuse_node.outputs['BSDF'], mat_out.inputs['Surface']) 133 | 134 | obj.data.materials.append(material) 135 | 136 | return material 137 | 138 | def setup(shape, light_main, light_add): 139 | clean_obj_lamp_and_mesh(bpy.context) 140 | # Set up rendering of depth map: 141 | bpy.context.scene.use_nodes = True 142 | tree = bpy.context.scene.node_tree 143 | links = tree.links 144 | 145 | # clear default nodes 146 | for n in tree.nodes: 147 | tree.nodes.remove(n) 148 | 149 | # Depth config 150 | rl = tree.nodes.new(type="CompositorNodeRLayers") 151 | depth_file_output = tree.nodes.new(type="CompositorNodeOutputFile") 152 | depth_file_output.base_path = '' 153 | depth_file_output.format.file_format = 'PNG' 154 | depth_file_output.format.color_depth = '16' 155 | 156 | map_node = tree.nodes.new(type="CompositorNodeMapRange") 157 | map_node.inputs[1].default_value = 0 158 | map_node.inputs[2].default_value = 255 159 | map_node.inputs[3].default_value = 0 160 | map_node.inputs[4].default_value = 1 161 | links.new(rl.outputs['Depth'], map_node.inputs[0]) 162 | links.new(map_node.outputs[0], depth_file_output.inputs[0]) 163 | 164 | # Setting up the environment 165 | scene = bpy.context.scene 166 | context = bpy.context 167 | 168 | scene.render.engine = "CYCLES" 169 | scene.cycles.sample_clamp_indirect = 1.0 170 | scene.cycles.blur_glossy = 3.0 171 | scene.cycles.samples = 100 172 | 173 | for mesh in bpy.data.meshes: 174 | mesh.use_auto_smooth = True 175 | 176 | scene.render.resolution_x = shape[0] 177 | scene.render.resolution_y = shape[1] 178 | scene.render.resolution_percentage = 100 179 | scene.render.film_transparent = True 180 | scene.render.image_settings.color_mode = 'RGBA' 181 | scene.render.image_settings.file_format = 'PNG' 182 | 183 | # Camera setting 184 | cam = scene.objects['Camera'] 185 | cam_constraint = scene.objects['Camera'].constraints.new(type='TRACK_TO') 186 | cam_constraint.track_axis = 'TRACK_NEGATIVE_Z' 187 | cam_constraint.up_axis = 'UP_Y' 188 | cam_empty = parent_obj_to_camera(scene.objects['Camera']) 189 | cam_constraint.target = cam_empty 190 | 191 | # Light setting 192 | lamp_object = makeLamp('Lamp1', light_main) 193 | lamp_add = makeLamp('Lamp2', light_add) 194 | 195 | return cam, depth_file_output, lamp_object, lamp_add 196 | 197 | def render(camera, lamp_object, lamp_add, depth_file_output, outfile, pose): 198 | bpy.context.scene.render.filepath = outfile 199 | depth_file_output.file_slots[0].path = bpy.context.scene.render.filepath + '_depth.png' 200 | 201 | theta, elevation, azimuth = pose[:3] 202 | 203 | azimuth = -(azimuth + 90) 204 | elevation = elevation - 90 205 | 206 | cam_dist = pose[-1] 207 | 208 | x, y, z = obj_location(cam_dist, azimuth, elevation) 209 | camera.location = (x, y, z) 210 | 211 | ## setup_light 212 | 
lamp_object.location = (0, 0, 4) 213 | lamp_add.location = (0, 0, -4) 214 | 215 | bpy.ops.render.render(write_still=True) 216 | 217 | im_path = outfile + '.png' 218 | im = Image.open(im_path).copy() 219 | im = np.array(im) 220 | 221 | im = imutils.rotate(im, angle=-theta) 222 | 223 | mask = (im[:, :, 3] > 0).astype(np.uint8) * 255 224 | im = cv2.cvtColor(im[:, :, :3], cv2.COLOR_RGB2BGR) 225 | return im, mask 226 | 227 | def render_ply(obj, output_dir, pose_list, shape=[256, 256], light_main=5, light_add=1, normalize=False, forward=None, up=None, texture=True): 228 | # setup 229 | cam, depth_file_output, lamp_object, lamp_add = setup(shape, light_main, light_add) 230 | 231 | # import object 232 | bpy.ops.import_mesh.ply(filepath=obj) 233 | object = bpy.data.objects[os.path.basename(obj).replace('.ply', '')] 234 | 235 | ## texture 236 | if texture is True: 237 | material = add_shader_on_ply_object(object) 238 | set_material_node_parameters(material) 239 | 240 | for object in bpy.context.scene.objects: 241 | if object.name in ['Lamp'] or object.type in ['EMPTY', 'LAMP']: 242 | continue 243 | bpy.context.view_layer.objects.active = object 244 | max_dim = max(object.dimensions) 245 | 246 | # normalize the object 247 | if normalize: 248 | object.dimensions = object.dimensions / max_dim if max_dim != 0 else object.dimensions 249 | 250 | cam_ob = bpy.context.scene.camera 251 | bpy.context.view_layer.objects.active = cam_ob 252 | 253 | bbx_list = [] 254 | for i, pose in enumerate(tqdm(pose_list)): 255 | # redirect output to log file 256 | logfile = 'render.log' 257 | open(logfile, 'a').close() 258 | old = os.dup(1) 259 | sys.stdout.flush() 260 | os.close(1) 261 | os.open(logfile, os.O_WRONLY) 262 | 263 | im, mask = render(cam_ob, lamp_object, lamp_add, depth_file_output, os.path.join(output_dir, 'rendered_img'), pose) 264 | 265 | bbx = np.where(mask > 0) 266 | bbx = np.asarray([bbx[1].min(), bbx[0].min(), bbx[1].max(), bbx[0].max()]) 267 | bbx_list.append(bbx) 268 | 269 | save_visual(im, mask, output_dir, pose[:3]) 270 | 271 | # disable output redirection 272 | os.close(1) 273 | os.dup(old) 274 | os.close(old) 275 | 276 | os.system("rm render.log") 277 | return bbx_list 278 | -------------------------------------------------------------------------------- /Render/data_generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import yaml 4 | import torch 5 | import math 6 | import numpy as np 7 | import cv2 8 | from tqdm import tqdm 9 | import json 10 | import trimesh 11 | import argparse 12 | import glob 13 | import pickle 14 | from tqdm import trange 15 | from bpy_render import render_ply 16 | from pytorch3d.transforms import euler_angles_to_matrix, matrix_to_euler_angles, rotation_6d_to_matrix 17 | 18 | sys.path.append('../') 19 | from core.utils import ( 20 | load_bop_meshes, 21 | load_bbox_3d, 22 | get_single_bop_annotation, 23 | remap_pose, 24 | ) 25 | np.set_printoptions(threshold=np.inf) 26 | np.random.seed(0) 27 | 28 | def sample_6d(num): 29 | samples = [] 30 | for i in range(num): 31 | x = np.asarray([np.random.normal() for j in range(3)]).squeeze() 32 | y = np.asarray([np.random.normal() for j in range(3)]).squeeze() 33 | 34 | x = x / max(np.linalg.norm(x, ord=2), 1e-8) 35 | y = y / max(np.linalg.norm(y, ord=2), 1e-8) 36 | 37 | samples.append(np.concatenate([x, y], axis=-1)) 38 | 39 | return torch.from_numpy(np.asarray(samples)).float() 40 | 41 | def src_image_generate(cfg, mode): 42 | if mode == 'test': 43 | seq_paths 
= ['../data/linemod_zhs/000001_test.txt','../data/linemod_zhs/000002_test.txt',\ 44 | '../data/linemod_zhs/000004_test.txt','../data/linemod_zhs/000005_test.txt',\ 45 | '../data/linemod_zhs/000006_test.txt','../data/linemod_zhs/000008_test.txt',\ 46 | '../data/linemod_zhs/000009_test.txt','../data/linemod_zhs/000010_test.txt',\ 47 | '../data/linemod_zhs/000011_test.txt','../data/linemod_zhs/000012_test.txt',\ 48 | '../data/linemod_zhs/000013_test.txt','../data/linemod_zhs/000014_test.txt',\ 49 | '../data/linemod_zhs/000015_test.txt'\ 50 | ] 51 | elif mode == 'train': 52 | seq_paths = ['../data/linemod_zhs/000001_train.txt','../data/linemod_zhs/000002_train.txt',\ 53 | '../data/linemod_zhs/000004_train.txt','../data/linemod_zhs/000005_train.txt',\ 54 | '../data/linemod_zhs/000006_train.txt','../data/linemod_zhs/000008_train.txt',\ 55 | '../data/linemod_zhs/000009_train.txt','../data/linemod_zhs/000010_train.txt',\ 56 | '../data/linemod_zhs/000011_train.txt','../data/linemod_zhs/000012_train.txt',\ 57 | '../data/linemod_zhs/000013_train.txt','../data/linemod_zhs/000014_train.txt',\ 58 | '../data/linemod_zhs/000015_train.txt' 59 | ] 60 | else: 61 | raise RuntimeError("Unsupported mode") 62 | 63 | img_files = {} 64 | for seq_path in seq_paths: 65 | dataDir = os.path.split(seq_path)[0] 66 | scene = seq_path.split('/')[-1].split('.')[0] 67 | with open(seq_path, 'r') as f: 68 | img_file = f.readlines() 69 | img_file = [dataDir + '/' + x.strip() for x in img_file] 70 | img_files[scene] = img_file 71 | 72 | meshes, objID_2_clsID = load_bop_meshes(cfg['DATA']['MESH_DIR']) 73 | 74 | bbox = load_bbox_3d(cfg["DATA"]["BBOX_FILE"]) 75 | 76 | if not os.path.exists("../data/src_images_" + mode + "_pkl/"): 77 | os.makedirs("../data/src_images_" + mode + "_pkl/") 78 | 79 | for key in img_files.keys(): 80 | print("Generating src image for " + key.split('_')[0]) 81 | Ks, Rs, Ts, bbxs, imgs, masks, depths, ids, kpts = [], [], [], [], [], [], [], [], [] 82 | src_info = {} 83 | # Load image 84 | for idx in trange(len(img_files[key])): 85 | try: 86 | img = cv2.imread(img_files[key][idx], cv2.IMREAD_UNCHANGED) # BGR(A) 87 | if img is None: 88 | raise RuntimeError('load image error') 89 | # 90 | if img.dtype == np.uint16: 91 | img = cv2.convertScaleAbs(img, alpha=(255.0/65535.0)).astype(np.uint8) 92 | # 93 | if len(img.shape) == 2: 94 | # convert gray to 3 channels 95 | img = np.repeat(img.reshape(img.shape[0], img.shape[1], 1), 3, axis=2) 96 | 97 | elif img.shape[2] == 4: 98 | # having alpha 99 | tmpBack = (img[:,:,3] == 0) 100 | img[:,:,0:3][tmpBack] = 255 # white background 101 | except: 102 | print('image %s not found' % img_path) 103 | return None 104 | 105 | # Load labels (BOP format) 106 | height, width, _ = img.shape 107 | 108 | K, merged_mask, class_ids, rotations, translations = get_single_bop_annotation(img_files[key][idx], objID_2_clsID) 109 | 110 | for i in range(len(class_ids)): 111 | if (merged_mask==i+1).sum() == 0: 112 | continue 113 | bbx = np.where(merged_mask==i+1) 114 | x_min = int(np.min(bbx[1])) 115 | y_min = int(np.min(bbx[0])) 116 | x_max = int(np.max(bbx[1])) 117 | y_max = int(np.max(bbx[0])) 118 | 119 | h, w = y_max - y_min, x_max - x_min 120 | 121 | x_min = max(x_min - 0.0*w, 0) 122 | y_min = max(y_min - 0.0*h, 0) 123 | x_max = min(x_max + 0.0*w, width) 124 | y_max = min(y_max + 0.0*h, height) 125 | 126 | mask = (merged_mask==i+1).astype(np.uint8) 127 | 128 | K = np.asarray(K) 129 | R = rotations[i] 130 | T = translations[i] 131 | 132 | dst_K = 
np.asarray(cfg["DATA"]["INTERNAL_K"]).reshape(3, 3) 133 | points = np.asarray(bbox[class_ids[i]]) 134 | R, T = remap_pose(K, R, T, points, dst_K) 135 | 136 | pose = np.concatenate([R.reshape(3, 3), T.reshape(3, 1)], axis=-1) 137 | 138 | _, depth = render_objects([meshes[int(key.split('_')[0])-1]], [0], [pose], dst_K, \ 139 | cfg["DATA"]["INTERNAL_WIDTH"], cfg["DATA"]["INTERNAL_HEIGHT"]) 140 | 141 | Ks.append(dst_K.reshape(3, 3)) 142 | Rs.append(R.reshape(3, 3)) 143 | Ts.append(T.reshape(3, 1)) 144 | bbxs.append(np.asarray([x_min, y_min, x_max, y_max])) 145 | ids.append(key.split('_')[0]) 146 | imgs.append(img) 147 | masks.append(mask) 148 | depths.append(depth) 149 | 150 | src_info["Ks"] = Ks 151 | src_info["Rs"] = Rs 152 | src_info["Ts"] = Ts 153 | src_info["imgs"] = imgs 154 | src_info["masks"] = masks 155 | src_info["depths"] = depths 156 | src_info["bbxs"] = bbxs 157 | src_info["ids"] = ids 158 | 159 | with open("../data/src_images_" + mode + "_pkl/" + key.split('_')[0] + ".pkl", "wb") as f: 160 | pickle.dump(src_info, f) 161 | f.close() 162 | 163 | def reference_generation(cfg, mode, num): 164 | print("-----------Loading Renderer------------") 165 | src_info_paths = glob.glob(os.path.join("../data/src_images_test_pkl/", '*.pkl')) 166 | # Load the obj and ignore the textures and materials. 167 | 168 | if not os.path.exists(os.path.join(cfg["RENDER"]["OUTPUT_PATH"])): 169 | os.makedirs(os.path.join(cfg["RENDER"]["OUTPUT_PATH"])) 170 | 171 | ## continuous sample 172 | print("Number of rendered images:", num) 173 | samples = sample_6d(num) 174 | 175 | sample_R = rotation_6d_to_matrix(samples) 176 | sample_euler = matrix_to_euler_angles(sample_R, 'ZXZ') * 180. / np.pi 177 | sample_pose = [np.concatenate([sample_euler[i].numpy(), [0, 0, cfg["RENDER"]["CAM_DIST"]]], axis=-1) \ 178 | for i in range(sample_euler.shape[0])] 179 | 180 | for src_info_path in src_info_paths: 181 | idx = int(src_info_path.split('/')[-1].split('.')[0]) 182 | print("Now Processing obj: %06d" % (idx)) 183 | output_path = os.path.join(cfg["RENDER"]["OUTPUT_PATH"], "%06d" % (idx)) 184 | 185 | if not os.path.exists(output_path): 186 | os.makedirs(output_path) 187 | 188 | if not os.path.exists(os.path.join(output_path, "rgb")): 189 | os.makedirs(os.path.join(output_path, "rgb")) 190 | 191 | if not os.path.exists(os.path.join(output_path, "mask")): 192 | os.makedirs(os.path.join(output_path, "mask")) 193 | 194 | with open(src_info_path, "rb") as f: 195 | meta_info = pickle.load(f) 196 | f.close() 197 | 198 | obj = os.path.join(cfg["DATA"]["MESH_DIR"], "obj_%06d.ply" % (idx)) 199 | 200 | print("Perform image rendering") 201 | ref_bbxs = render_ply(obj, output_path, sample_pose, shape=[cfg["DATA"]["INTERNAL_WIDTH"], cfg["DATA"]["INTERNAL_HEIGHT"]], \ 202 | light_main=50, light_add=5, normalize=True, forward="X", up="Z", texture=True) 203 | 204 | ref_paths, ref_eulers, ref_Rs = [], [], [] 205 | for j in range(num): 206 | ref_paths.append(os.path.join(output_path[1:], "rgb", str(sample_pose[j][0]) + "_" + str(sample_pose[j][1]) + "_" + str(sample_pose[j][2]) + ".png")) 207 | ref_eulers.append(sample_pose[j][:3]) 208 | ref_Rs.append(sample_R[j].numpy()) 209 | 210 | ref_info = {} 211 | 212 | ref_info["paths"] = ref_paths 213 | ref_info["eulers"] = ref_eulers 214 | ref_info["Rs"] = ref_Rs 215 | ref_info["bbxs"] = ref_bbxs 216 | 217 | res_path = os.path.join(cfg["RENDER"]["OUTPUT_PATH"], src_info_path.split('/')[-1]) 218 | with open(res_path, "wb") as f: 219 | pickle.dump(ref_info, f) 220 | f.close() 221 | 222 | if __name__ == 
'__main__': 223 | with open("../objects.yaml", 'r') as load_f: 224 | cfg = yaml.load(load_f, Loader=yaml.FullLoader) 225 | print("---------- image generation------------") 226 | src_image_generate(cfg, 'train') 227 | src_image_generate(cfg, 'test') 228 | reference_generation(cfg, 'test', cfg["RENDER"]["NUM"]) 229 | -------------------------------------------------------------------------------- /Render/py_render.py: -------------------------------------------------------------------------------- 1 | import pyrender 2 | import numpy as np 3 | import cv2 4 | 5 | def render_objects(meshes, ids, poses, K, w, h): 6 | assert(K[0][1] == 0 and K[1][0] == 0 and K[2][0] ==0 and K[2][1] == 0 and K[2][2] == 1) 7 | fx = K[0][0] 8 | fy = K[1][1] 9 | cx = K[0][2] 10 | cy = K[1][2] 11 | objCnt = len(ids) 12 | assert(len(poses) == objCnt) 13 | 14 | # set background with 0 alpha, important for RGBA rendering 15 | scene = pyrender.Scene(bg_color=np.array([0.0, 0.0, 0.0, 1.0]), ambient_light=np.array([0.02, 0.02, 0.02, 1.0])) 16 | # pyrender.Viewer(scene, use_raymond_lighting=True) 17 | # camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0) 18 | camera = pyrender.IntrinsicsCamera(fx=fx,fy=fy,cx=cx,cy=cy,znear=0.05,zfar=100000) 19 | camera_pose = np.eye(4) 20 | # reverse the direction of Y and Z, check: https://pyrender.readthedocs.io/en/latest/examples/cameras.html 21 | camera_pose[1][1] = -1 22 | camera_pose[2][2] = -1 23 | scene.add(camera, pose=camera_pose) 24 | #light = pyrender.SpotLight(color=np.ones(3), intensity=4.0, innerConeAngle=np.pi/16.0, outerConeAngle=np.pi/6.0) 25 | light = pyrender.DirectionalLight(color=np.ones(3), intensity=4.0) 26 | #light = pyrender.PointLight(color=np.ones(3), intensity=4.0) 27 | scene.add(light, pose=camera_pose) 28 | for i in range(objCnt): 29 | clsId = int(ids[i]) 30 | mesh = pyrender.Mesh.from_trimesh(meshes[clsId]) 31 | H = np.zeros((4,4)) 32 | H[0:3] = poses[i][0:3] 33 | H[3][3] = 1.0 34 | scene.add(mesh, pose=H) 35 | # pyrender.Viewer(scene, use_raymond_lighting=True) 36 | 37 | r = pyrender.OffscreenRenderer(w, h) 38 | # flags = pyrender.RenderFlags.OFFSCREEN | pyrender.RenderFlags.DEPTH_ONLY 39 | #flags = pyrender.RenderFlags.OFFSCREEN 40 | #flags = pyrender.RenderFlags.OFFSCREEN | pyrender.RenderFlags.RGBA 41 | #color, depth = r.render(scene, flags=flags) 42 | color, depth = r.render(scene) 43 | 44 | color = cv2.cvtColor(color, cv2.COLOR_RGB2BGR) # RGB to BGR (for OpenCV) 45 | #color = cv2.cvtColor(color, cv2.COLOR_RGBA2BGRA) # RGBA to BGRA (for OpenCV) 46 | 47 | return color, depth 48 | -------------------------------------------------------------------------------- /core/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sailor-z/Unseen_Object_Pose/42d10d4498660882874a36d4797397737857d0ac/core/.DS_Store -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sailor-z/Unseen_Object_Pose/42d10d4498660882874a36d4797397737857d0ac/core/__init__.py -------------------------------------------------------------------------------- /core/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from torch.utils.data import DataLoader 3 | import os, sys 4 | import torch 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from 
PIL import Image 8 | import random 9 | import cv2 10 | import glob 11 | import pickle 12 | import imutils 13 | from tqdm import trange 14 | from pytorch3d.ops import sample_farthest_points 15 | import core.transform as transform 16 | from core.utils import pose_symmetry_handling, geodesic_distance 17 | 18 | np.set_printoptions(threshold=np.inf) 19 | 20 | def resize_pad(im, dim): 21 | _, h, w = im.shape 22 | im = transforms.functional.resize(im, int(dim * min(w, h) / max(w, h))) 23 | left = int(np.ceil((dim - im.shape[2]) / 2)) 24 | right = int(np.floor((dim - im.shape[2]) / 2)) 25 | top = int(np.ceil((dim - im.shape[1]) / 2)) 26 | bottom = int(np.floor((dim - im.shape[1]) / 2)) 27 | im = transforms.functional.pad(im, (left, top, right, bottom)) 28 | 29 | return im 30 | 31 | def bbx_resize(bbx, img_w, img_h): 32 | w, h = bbx[2] - bbx[0], bbx[3] - bbx[1] 33 | dim = max(w, h) 34 | 35 | left = int(np.ceil((dim - w) / 2)) 36 | right = int(np.floor((dim - w) / 2)) 37 | top = int(np.ceil((dim - h) / 2)) 38 | bottom = int(np.floor((dim - h) / 2)) 39 | 40 | bbx[0] = max(bbx[0] - left, 0) 41 | bbx[1] = max(bbx[1] - top, 0) 42 | bbx[2] = min(bbx[2] + right, img_w) 43 | bbx[3] = min(bbx[3] + bottom, img_h) 44 | 45 | return bbx 46 | 47 | def crop(img, bbx): 48 | if len(img.shape) < 4: 49 | crop_img = img[int(bbx[1]):int(bbx[3]), int(bbx[0]):int(bbx[2])] 50 | else: 51 | crop_img = [img[i, int(bbx[i, 1]):int(bbx[i, 3]), int(bbx[i, 0]):int(bbx[i, 2])] for i in range(img.shape[0])] 52 | return crop_img 53 | 54 | class ref_image_loader(data.Dataset): 55 | def __init__(self, cfg, ref_info): 56 | self.ref_info = ref_info 57 | self.cfg = cfg 58 | self.trans = transforms.Compose( 59 | [ 60 | transforms.ToTensor(), 61 | transforms.Normalize( 62 | self.cfg['DATA']['PIXEL_MEAN'], 63 | self.cfg['DATA']['PIXEL_STD']), 64 | ] 65 | ) 66 | self.mask_trans = transforms.ToTensor() 67 | 68 | def __len__(self): 69 | return len(self.ref_info["paths"]) 70 | 71 | def __getitem__(self, idx): 72 | ref_bbx = np.array(self.ref_info["bbxs"][idx]) 73 | ref_img = cv2.imread(self.ref_info["paths"][idx]) 74 | ref_mask = cv2.imread(self.ref_info["paths"][idx].split('rgb')[0] + 'mask' + \ 75 | self.ref_info["paths"][idx].split('rgb')[1], 0).astype(np.uint8) 76 | 77 | ref_img = crop(ref_img, ref_bbx) 78 | ref_mask = crop(ref_mask[..., None], ref_bbx) 79 | 80 | ref_img = self.trans(ref_img) * self.mask_trans(ref_mask) 81 | ref_img = resize_pad(ref_img, self.cfg["DATA"]["CROP_SIZE"]) 82 | 83 | return ref_img 84 | 85 | class ref_f_loader(data.Dataset): 86 | def __init__(self, ref_f): 87 | self.ref_f = ref_f 88 | 89 | def __len__(self): 90 | return len(self.ref_f) 91 | 92 | def __getitem__(self, idx): 93 | ref = self.ref_f[idx] 94 | return ref 95 | 96 | ## reference loader 97 | class ref_loader_so3(): 98 | def __init__(self, cfg): 99 | self.cfg = cfg 100 | self.ref_path = glob.glob(os.path.join(self.cfg['DATA']['META_DIR'], self.cfg['DATA']['RENDER_DIR'], '*.pkl')) 101 | 102 | self.ref_info = {} 103 | clsIDs = [] 104 | for ref_file in self.ref_path: 105 | clsID = ref_file.split('/')[-1].split('.')[0] 106 | with open(ref_file, 'rb') as f: 107 | self.ref_info[clsID] = pickle.load(f) 108 | f.close() 109 | clsIDs.append(clsID) 110 | 111 | def load(self, clsID): 112 | ref_info_clsID = self.ref_info[clsID] 113 | 114 | ref_paths = ref_info_clsID["paths"] 115 | ref_bbxs = np.array(ref_info_clsID["bbxs"]) 116 | ref_Rs = np.array(ref_info_clsID["Rs"]) 117 | 118 | if clsID in self.cfg["LINEMOD"]["SYMMETRIC_OBJ"].keys(): 119 | ref_Rs = 
[pose_symmetry_handling(ref_Rs[i], self.cfg["LINEMOD"]["SYMMETRIC_OBJ"][clsID]) for i in range(ref_Rs.shape[0])] 120 | ref_Rs = torch.from_numpy(np.asarray(ref_Rs)) 121 | else: 122 | ref_Rs = torch.from_numpy(ref_Rs) 123 | 124 | ref_database = ref_image_loader(self.cfg, ref_info_clsID) 125 | dataset_loader = DataLoader(ref_database, batch_size=128, shuffle=False, num_workers=self.cfg["TRAIN"]["WORKERS"], drop_last=False) 126 | 127 | ref_imgs = [] 128 | for i, ref_img in enumerate(dataset_loader): 129 | ref_imgs.append(ref_img) 130 | 131 | ref_imgs = torch.cat(ref_imgs, dim=0) 132 | ref_info_clsID = {} 133 | ref_info_clsID['imgs'] = ref_imgs 134 | ref_info_clsID['Rs'] = ref_Rs 135 | return ref_info_clsID 136 | 137 | 138 | class LINEMOD_SO3(data.Dataset): 139 | def __init__(self, cfg, mode, clsID): 140 | self.cfg = cfg 141 | self.mode = mode 142 | self.unseen_cat = [cfg["LINEMOD"][cat] for cat in cfg["TEST"]["UNSEEN"]] 143 | 144 | if self.mode == 'train': 145 | self.src_path = glob.glob(os.path.join(self.cfg['DATA']['META_DIR'], 'src_images_test_pkl', '*.pkl')) 146 | self.ref_path = glob.glob(os.path.join(self.cfg['DATA']['META_DIR'], self.cfg['DATA']['RENDER_DIR'], '*.pkl')) 147 | 148 | for cat in self.unseen_cat: 149 | self.src_path.remove(os.path.join(self.cfg['DATA']['META_DIR'], 'src_images_test_pkl', '%06d.pkl' % (cat))) 150 | self.ref_path.remove(os.path.join(self.cfg['DATA']['META_DIR'], self.cfg['DATA']['RENDER_DIR'], '%06d.pkl' % (cat))) 151 | 152 | self.trans = transforms.Compose( 153 | [ 154 | transform.RandomHSV(0.2, 0.5, 0.5), 155 | transform.RandomNoise(0.1), 156 | transform.RandomSmooth(0.5), 157 | transforms.ToTensor(), 158 | transforms.Normalize( 159 | self.cfg['DATA']['PIXEL_MEAN'], 160 | self.cfg['DATA']['PIXEL_STD']), 161 | ] 162 | ) 163 | self.mask_trans = transforms.ToTensor() 164 | 165 | elif self.mode == 'test': 166 | self.src_path = [os.path.join(self.cfg['DATA']['META_DIR'], 'src_images_test_pkl', '%06d.pkl' % (clsID))] 167 | 168 | self.trans = transforms.Compose( 169 | [ 170 | transforms.ToTensor(), 171 | transforms.Normalize( 172 | self.cfg['DATA']['PIXEL_MEAN'], 173 | self.cfg['DATA']['PIXEL_STD']), 174 | ] 175 | ) 176 | else: 177 | raise RuntimeError('Unsupported mode') 178 | 179 | print(">>>>>>> Loading source and reference data") 180 | self.src_info, self.ref_info = {}, {} 181 | src_imgs, src_masks, src_Ks, src_bbxs, src_ids = [], [], [], [], [] 182 | ref_info = {} 183 | 184 | for i in range(len(self.src_path)): 185 | with open(self.src_path[i], 'rb') as f: 186 | src_info_i = pickle.load(f) 187 | f.close() 188 | 189 | if src_info_i["ids"][0] in cfg["LINEMOD"]["SYMMETRIC_OBJ"].keys(): 190 | src_info_i["Rs"] = [pose_symmetry_handling(src_info_i["Rs"][j], self.cfg["LINEMOD"]["SYMMETRIC_OBJ"][src_info_i["ids"][0]]) for j in range(len(src_info_i["Rs"]))] 191 | src_info_i["Rs"] = np.asarray(src_info_i["Rs"]) 192 | else: 193 | src_info_i["Rs"] = np.asarray(src_info_i["Rs"]) 194 | 195 | src_Rs = src_info_i["Rs"] if i == 0 else np.concatenate([src_Rs, src_info_i["Rs"]], axis=0) 196 | 197 | src_imgs += src_info_i["imgs"] 198 | src_masks += src_info_i["masks"] 199 | src_Ks += src_info_i["Ks"] 200 | src_bbxs += src_info_i["bbxs"] 201 | src_ids += src_info_i["ids"] 202 | 203 | if self.mode == 'train': 204 | with open(self.ref_path[i], 'rb') as f: 205 | ref_info_i = pickle.load(f) 206 | f.close() 207 | 208 | if src_info_i["ids"][0] in cfg["LINEMOD"]["SYMMETRIC_OBJ"].keys(): 209 | ref_info_i["Rs"] = [pose_symmetry_handling(ref_info_i["Rs"][j], 
self.cfg["LINEMOD"]["SYMMETRIC_OBJ"][src_info_i["ids"][0]]) for j in range(len(ref_info_i["Rs"]))] 210 | ref_info_i["Rs"] = np.asarray(ref_info_i["Rs"]) 211 | else: 212 | ref_info_i["Rs"] = np.asarray(ref_info_i["Rs"]) 213 | 214 | self.ref_info[src_info_i["ids"][0]] = ref_info_i 215 | 216 | indices = np.random.permutation(len(src_imgs)) 217 | self.src_info["imgs"] = np.asarray(src_imgs)[indices] 218 | self.src_info["masks"] = np.asarray(src_masks)[indices] 219 | self.src_info["Ks"] = np.asarray(src_Ks)[indices] 220 | self.src_info["Rs"] = np.asarray(src_Rs)[indices] 221 | self.src_info["bbxs"] = np.asarray(src_bbxs)[indices] 222 | self.src_info["ids"] = np.asarray(src_ids)[indices] 223 | 224 | def load_sample(self, index, ref_paths, ref_bbxs): 225 | path = ref_paths[index].split('rgb') 226 | 227 | img = cv2.imread(ref_paths[index]) 228 | mask = cv2.imread(path[0] + 'mask' + path[1], 0) 229 | 230 | img = crop(img, ref_bbxs[index]) 231 | mask = crop(mask, ref_bbxs[index]) 232 | 233 | return img, mask 234 | 235 | def sampling(self, src_R, ref_Rs, ref_bbxs, ref_paths): 236 | ### anchor sample 237 | _, anchor_indices = sample_farthest_points(ref_Rs.view(1, -1, 9), K=self.cfg["DATA"]["ANCHOR_NUM"], random_start_point=True) 238 | anchor_indices = anchor_indices[0] 239 | 240 | gt_sim, _ = geodesic_distance(src_R, ref_Rs[anchor_indices]) 241 | anchor_index = anchor_indices[torch.argmax(gt_sim)] 242 | anchor_img, anchor_mask = self.load_sample(anchor_index, ref_paths, ref_bbxs) 243 | 244 | ### positive sample 245 | gt_sim, _ = geodesic_distance(src_R, ref_Rs) 246 | pos_index = torch.argmax(gt_sim) 247 | pos_img, pos_mask = self.load_sample(pos_index, ref_paths, ref_bbxs) 248 | 249 | ### negative sample 250 | random_index = torch.randperm(gt_sim.shape[0])[0] 251 | random_img, random_mask = self.load_sample(random_index, ref_paths, ref_bbxs) 252 | 253 | ref_R = torch.stack([ref_Rs[anchor_index], ref_Rs[pos_index], ref_Rs[random_index]]) 254 | 255 | return [anchor_img, pos_img, random_img], [anchor_mask, pos_mask, random_mask], ref_R 256 | 257 | def __len__(self): 258 | return self.src_info["imgs"].shape[0] 259 | 260 | def __getitem__(self, idx): 261 | src_img = self.src_info["imgs"][idx].astype(np.uint8) 262 | src_mask = self.src_info["masks"][idx].astype(np.uint8) 263 | bbx = self.src_info["bbxs"][idx] 264 | id = int(self.src_info["ids"][idx]) 265 | K = torch.from_numpy(self.src_info["Ks"][idx]).float() 266 | src_R = torch.from_numpy(self.src_info["Rs"][idx]).float() 267 | 268 | if self.mode == 'train': 269 | ref_info = self.ref_info[self.src_info["ids"][idx]] 270 | ref_paths = ref_info["paths"] 271 | ref_bbxs = np.array(ref_info["bbxs"]) 272 | ref_Rs = torch.from_numpy(ref_info["Rs"]) 273 | 274 | if random.random() > 0.5 and self.cfg['TRAIN']['ROTATION_AG'] is True: 275 | r = max(-60, min(60, torch.randn(1) * 30)) 276 | src_img = imutils.rotate_bound(src_img, angle=-r) 277 | src_mask = imutils.rotate_bound(src_mask, angle=-r) 278 | r = r * np.pi / 180. 
279 | delta_R = torch.tensor([[np.cos(r), -np.sin(r), 0], [np.sin(r), np.cos(r), 0], [0, 0, 1]]).float() 280 | src_R = torch.matmul(torch.inverse(delta_R), src_R) 281 | 282 | bbx = np.where(src_mask>0) 283 | x_min = int(np.min(bbx[1])) 284 | y_min = int(np.min(bbx[0])) 285 | x_max = int(np.max(bbx[1])) 286 | y_max = int(np.max(bbx[0])) 287 | bbx = np.asarray([x_min, y_min, x_max, y_max]) 288 | 289 | ref_img, ref_mask, ref_R = self.sampling(src_R, ref_Rs, ref_bbxs, ref_paths) 290 | bbx = bbx_resize(bbx, src_mask.shape[1], src_mask.shape[0]) 291 | 292 | src_img = crop(src_img, bbx) 293 | src_mask = crop(src_mask, bbx) 294 | 295 | src_img = self.trans(src_img) 296 | 297 | src_img = resize_pad(src_img, self.cfg["DATA"]["CROP_SIZE"]) 298 | ref_img = [resize_pad(self.trans(ref_img[i]) * self.mask_trans(ref_mask[i][..., None]), self.cfg["DATA"]["CROP_SIZE"])\ 299 | for i in range(len(ref_img))] 300 | 301 | return src_img, ref_img, src_R, ref_R, id 302 | 303 | else: 304 | bbx = bbx_resize(bbx, src_mask.shape[1], src_mask.shape[0]) 305 | src_img = crop(src_img, bbx) 306 | src_img = self.trans(src_img) 307 | src_img = resize_pad(src_img, self.cfg["DATA"]["CROP_SIZE"]) 308 | 309 | return src_img, src_R, id 310 | -------------------------------------------------------------------------------- /core/loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from pytorch3d.transforms import matrix_to_euler_angles 7 | from core.utils import geodesic_distance 8 | 9 | def weighted_infoNCE_loss_func(ref_sim_pos, ref_sim_rand, ref_R_positive, ref_R_rand, src_R, id, tau=0.1): 10 | with torch.no_grad(): 11 | same_cat = ((id[:, None] - id[None]) == 0).float() 12 | 13 | gt_sim_pos = (torch.sum(src_R.view(-1, 1, 9) * ref_R_positive.view(1, -1, 9), dim=-1).clamp(-1, 3) - 1) / 2 14 | gt_dis_pos = torch.arccos(gt_sim_pos) / np.pi 15 | gt_dis_pos = gt_dis_pos * same_cat + gt_dis_pos.new_ones(gt_dis_pos.shape) * (1 - same_cat) 16 | 17 | gt_sim_rand = (torch.sum(src_R.view(-1, 1, 9) * ref_R_rand.view(1, -1, 9), dim=-1).clamp(-1, 3) - 1) / 2 18 | gt_dis_rand = torch.arccos(gt_sim_rand) / np.pi 19 | gt_dis_rand = gt_dis_rand * same_cat + gt_dis_rand.new_ones(gt_dis_rand.shape) * (1 - same_cat) 20 | 21 | postive_term = (torch.diag(ref_sim_pos) / tau).exp() * torch.diag(gt_dis_pos) 22 | ref_sim = torch.cat([gt_dis_pos * (ref_sim_pos / tau).exp(), gt_dis_rand * (ref_sim_rand / tau).exp()], dim=-1) 23 | 24 | nce_loss = (-torch.log(postive_term / (torch.sum(ref_sim, dim=-1)))).mean() 25 | 26 | return nce_loss 27 | -------------------------------------------------------------------------------- /core/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LocalNorm2d(nn.Module): 6 | ## borrowed from https://github.com/DagnyT/hardnet/tree/deab7e892468a07fb2cf77d41e38714fa96a6e99 7 | def __init__(self, kernel_size = 32): 8 | super(LocalNorm2d, self).__init__() 9 | self.ks = kernel_size 10 | self.pool = nn.AvgPool2d(kernel_size = self.ks, stride = 1, padding = 0) 11 | self.eps = 1e-10 12 | return 13 | def forward(self,x): 14 | pd = int(self.ks/2) 15 | mean = self.pool(F.pad(x, (pd,pd,pd,pd), 'reflect')) 16 | return torch.clamp((x - mean) / (torch.sqrt(torch.abs(self.pool(F.pad(x*x, (pd,pd,pd,pd), 'reflect')) - mean*mean)) + self.eps), min = -6.0, max = 
6.0) 17 | 18 | class ResidualBlock(nn.Module): 19 | def __init__(self, in_planes, planes, norm_fn='group', kernel_size=3, stride=1, padding=1, bias=True): 20 | super(ResidualBlock, self).__init__() 21 | 22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=kernel_size, padding=padding, stride=stride, bias=bias) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, stride=1, bias=bias) 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | num_groups = planes // 8 27 | 28 | if norm_fn == 'group': 29 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 30 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 31 | if stride != 1 or in_planes != planes: 32 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 33 | 34 | elif norm_fn == 'batch': 35 | self.norm1 = nn.BatchNorm2d(planes) 36 | self.norm2 = nn.BatchNorm2d(planes) 37 | if stride != 1 or in_planes != planes: 38 | self.norm3 = nn.BatchNorm2d(planes) 39 | 40 | elif norm_fn == 'instance': 41 | self.norm1 = nn.InstanceNorm2d(planes) 42 | self.norm2 = nn.InstanceNorm2d(planes) 43 | if stride != 1 or in_planes != planes: 44 | self.norm3 = nn.InstanceNorm2d(planes) 45 | 46 | elif norm_fn == 'none': 47 | self.norm1 = nn.Sequential() 48 | self.norm2 = nn.Sequential() 49 | if stride != 1 or in_planes != planes: 50 | self.norm3 = nn.Sequential() 51 | 52 | if stride == 1 and in_planes == planes: 53 | self.downsample = None 54 | else: 55 | self.downsample = nn.Sequential( 56 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=bias), self.norm3) 57 | 58 | 59 | def forward(self, x): 60 | y = x 61 | y = self.relu(self.norm1(self.conv1(y))) 62 | y = self.relu(self.norm2(self.conv2(y))) 63 | 64 | if self.downsample is not None: 65 | x = self.downsample(x) 66 | return self.relu(x+y) 67 | 68 | 69 | class ResNet_encoder(nn.Module): 70 | def __init__(self, norm_fn='none', dropout=0.5): 71 | super().__init__() 72 | if norm_fn == 'group': 73 | self.norm = nn.GroupNorm(num_groups=8, num_channels=64) 74 | 75 | elif norm_fn == 'batch': 76 | self.norm = nn.BatchNorm2d(64) 77 | 78 | elif norm_fn == 'instance': 79 | self.norm = nn.InstanceNorm2d(64) 80 | 81 | elif norm_fn == 'none': 82 | self.norm = nn.Sequential() 83 | 84 | self.conv = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.relu = nn.ReLU(inplace=True) 86 | 87 | self.resblock_1 = nn.Sequential( 88 | ResidualBlock(32, 32, norm_fn=norm_fn, kernel_size=3, stride=1, padding=1, bias=False), 89 | ResidualBlock(32, 64, norm_fn=norm_fn, kernel_size=3, stride=2, padding=1, bias=False), 90 | ResidualBlock(64, 64, norm_fn=norm_fn, kernel_size=3, stride=1, padding=1, bias=False), 91 | ) 92 | 93 | self.resblock_2 = nn.Sequential( 94 | ResidualBlock(64, 128, norm_fn=norm_fn, kernel_size=3, stride=2, padding=1, bias=False), 95 | ResidualBlock(128, 128, norm_fn=norm_fn, kernel_size=3, stride=1, padding=1, bias=False), 96 | ) 97 | 98 | self.resblock_3 = nn.Sequential( 99 | ResidualBlock(128, 128, norm_fn=norm_fn, kernel_size=3, stride=2, padding=1, bias=False), 100 | ResidualBlock(128, 128, norm_fn=norm_fn, kernel_size=3, stride=1, padding=1, bias=False), 101 | ) 102 | 103 | self.conv_4 = nn.Sequential( 104 | nn.Conv2d(128, 128, kernel_size=4, bias=False), 105 | ) 106 | 107 | def forward(self, x): 108 | out = self.conv(x) 109 | out = self.norm(out) 110 | out = self.relu(out) 111 | out1 = self.resblock_1(out) # /2 112 | out2 = self.resblock_2(out1) # /4 113 | out3 = self.resblock_3(out2) # /8 114 | out4 
= self.conv_4(out3) # /8 -3 115 | 116 | return [out1, out2, out3, out4] 117 | 118 | class ResNet_decoder(nn.Module): 119 | def __init__(self, dropout=0.5, feature_dim=16): 120 | super().__init__() 121 | self.feature_dim = feature_dim 122 | 123 | self.head_4 = nn.Sequential( 124 | nn.ReLU(), 125 | nn.Dropout(dropout), 126 | nn.Conv2d(128, self.feature_dim, kernel_size=1, bias=False) 127 | ) 128 | 129 | self.up_sample_3 = nn.Sequential( 130 | nn.ReLU(), 131 | nn.ConvTranspose2d(128, 128, kernel_size=4, bias=False) 132 | ) 133 | self.conv_3 = nn.Sequential( 134 | nn.ReLU(), 135 | nn.Conv2d(256, 128, kernel_size=1, bias=False), 136 | 137 | ) 138 | self.head_3 = nn.Sequential( 139 | nn.ReLU(), 140 | nn.Dropout(dropout), 141 | nn.Conv2d(128, self.feature_dim, kernel_size=1, bias=False) 142 | ) 143 | 144 | self.up_sample_2 = nn.UpsamplingBilinear2d(scale_factor=2) 145 | self.conv_2 = nn.Sequential( 146 | nn.ReLU(), 147 | nn.Conv2d(256, 128, kernel_size=1, bias=False), 148 | 149 | ) 150 | self.head_2 = nn.Sequential( 151 | nn.ReLU(), 152 | nn.Dropout(dropout), 153 | nn.Conv2d(128, self.feature_dim, kernel_size=1, bias=False) 154 | ) 155 | 156 | def forward(self, x1, x2, x3, x4): 157 | out4 = self.head_4(x4) 158 | 159 | out = self.up_sample_3(x4) 160 | out = torch.cat([out, x3], dim=1) 161 | out = self.conv_3(out) 162 | out3 = self.head_3(out) 163 | 164 | out = self.up_sample_2(out) 165 | out = torch.cat([out, x2], dim=1) 166 | out = self.conv_2(out) 167 | out2 = self.head_2(out) 168 | 169 | return [out2, out3, out4] 170 | 171 | class RetrievalNet(nn.Module): 172 | def __init__(self, cfg): 173 | super().__init__() 174 | self.cfg = cfg 175 | self.hidden_dim = cfg["MODEL"]["HIDDEN_DIM"]**2 176 | self.feature_dim = cfg["MODEL"]["FEATURE_DIM"] 177 | self.input_dim = cfg["DATA"]["CROP_SIZE"] 178 | 179 | self.input_norm = LocalNorm2d(17) 180 | 181 | self.encoder = ResNet_encoder(norm_fn='none', dropout=cfg["TRAIN"]["DROP"]) 182 | self.decoder = ResNet_decoder(dropout=self.cfg["TRAIN"]["DROP"], feature_dim=cfg["MODEL"]["FEATURE_DIM"]) 183 | 184 | def forward(self, img): 185 | B, _, H, W = img.shape 186 | 187 | if img.size(1) > 1: 188 | img = img.mean(dim=1, keepdim=True) 189 | 190 | if self.cfg["MODEL"]["LOCALNORM"] is True: 191 | img = self.input_norm(img) 192 | 193 | [out1, out2, out3, out4] = self.encoder(img) 194 | 195 | out = self.decoder(out1, out2, out3, out4) 196 | 197 | out = [out[i] / torch.norm(out[i], p=2, dim=1, keepdim=True).clamp(min=1e-8) for i in range(len(out))] 198 | return out 199 | 200 | class Sim_predictor(nn.Module): 201 | def __init__(self, cfg): 202 | super().__init__() 203 | self.cfg = cfg 204 | self.dropout = self.cfg["TRAIN"]["DROP"] 205 | self.feature_dim = cfg["MODEL"]["FEATURE_DIM"] 206 | 207 | self.scales = cfg["MODEL"]["SCALES"] 208 | 209 | self.fc_1 = nn.Sequential( 210 | nn.Conv1d(self.feature_dim, 2, 1), 211 | ) 212 | self.fc_2 = nn.Sequential( 213 | nn.Conv1d(self.feature_dim, 2, 1), 214 | ) 215 | self.fc_3 = nn.Sequential( 216 | nn.Conv1d(self.feature_dim, 2, 1), 217 | ) 218 | 219 | self.fcs = [self.fc_1, self.fc_2, self.fc_3] 220 | 221 | self.fc_finial = nn.Sequential( 222 | nn.Linear(self.feature_dim*len(self.scales), self.feature_dim), 223 | nn.LeakyReLU(), 224 | nn.Linear(self.feature_dim, 1) 225 | ) 226 | 227 | def fusion(self, src_f, ref_f): 228 | out = [] 229 | for i in range(len(self.scales)): 230 | fuse_f = (src_f[i][:, None]*ref_f[i][None]).view(-1, self.feature_dim, self.scales[i]**2) 231 | weights = self.fcs[i](fuse_f) 232 | 233 | local_mask = weights[:, 
0, :] 234 | global_mask = weights[:, 1, :] 235 | 236 | local_mask = torch.sigmoid(local_mask) 237 | 238 | weights = torch.exp(global_mask) * local_mask 239 | weights = weights / torch.sum(weights, dim=-1, keepdim=True).clamp(min=1e-8) 240 | 241 | out.append((fuse_f * weights[:, None]).sum(dim=-1)) 242 | 243 | out = torch.cat(out, dim=-1) 244 | return out 245 | 246 | def forward(self, src_f, ref_f): 247 | B_src = src_f[0].shape[0] 248 | out = self.fusion(src_f, ref_f) 249 | sim = torch.tanh(self.fc_finial(out)).view(B_src, -1) 250 | return sim 251 | -------------------------------------------------------------------------------- /core/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | import random 5 | import glob 6 | import os 7 | from core.utils import distort_hsv, distort_noise, distort_smooth 8 | 9 | class Compose: 10 | def __init__(self, transforms): 11 | self.transforms = transforms 12 | 13 | def __call__(self, img): 14 | for t in self.transforms: 15 | img = t(img) 16 | 17 | return img 18 | 19 | def __repr__(self): 20 | format_str = self.__class__.__name__ + '(' 21 | for t in self.transforms: 22 | format_str += '\n' 23 | format_str += f' {t}' 24 | format_str += '\n)' 25 | 26 | return format_str 27 | 28 | class RandomHSV: 29 | def __init__(self, h_ratio, s_ratio, v_ratio): 30 | self.h_ratio = h_ratio 31 | self.s_ratio = s_ratio 32 | self.v_ratio = v_ratio 33 | def __call__(self, img): 34 | img = distort_hsv(img, self.h_ratio, self.s_ratio, self.v_ratio) 35 | return img 36 | 37 | class RandomNoise: 38 | def __init__(self, noise_ratio): 39 | self.noise_ratio = noise_ratio 40 | def __call__(self, img): 41 | img = distort_noise(img, self.noise_ratio) 42 | return img 43 | 44 | class RandomSmooth: 45 | def __init__(self, smooth_ratio): 46 | self.smooth_ratio = smooth_ratio 47 | def __call__(self, img): 48 | img = distort_smooth(img, self.smooth_ratio) 49 | return img 50 | 51 | class ToTensor: 52 | def __call__(self, img): 53 | img = img.transpose(2, 0, 1) 54 | img = torch.from_numpy(img).float() 55 | return img 56 | 57 | class Normalize: 58 | def __init__(self, mean, std): 59 | self.mean = mean 60 | self.std = std 61 | 62 | def __call__(self, img): 63 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255 64 | img = img - np.array(self.mean).reshape(1,1,3) 65 | img = img / np.array(self.std).reshape(1,1,3) 66 | return img 67 | 68 | class RandomBackground: 69 | def __init__(self, background_dir): 70 | self.background_files = [] 71 | try: 72 | if os.path.exists(background_dir): 73 | png_files = glob.glob(os.path.join(background_dir, '*.png')) 74 | jpg_files = glob.glob(os.path.join(background_dir, '*.jpg')) 75 | self.background_files += png_files 76 | self.background_files += jpg_files 77 | except: 78 | print("can not read background directory, remains empty") 79 | pass 80 | print("Number of background images: %d" % len(self.background_files)) 81 | 82 | def __call__(self, img, mask): 83 | if len(self.background_files) > 0: 84 | if img.shape[2] == 4: 85 | img = self.merge_background_alpha(img, self.get_a_random_background()) 86 | else: 87 | img = self.merge_background_mask(img, self.get_a_random_background(), mask) 88 | else: 89 | img = img[:,:,0:3] 90 | return img 91 | 92 | def get_a_random_background(self): 93 | backImg = None 94 | while backImg is None: 95 | backIdx = random.randint(0, len(self.background_files) - 1) 96 | img_path = self.background_files[backIdx] 97 | try: 98 | backImg = 
cv2.imread(img_path) 99 | if backImg is None: 100 | raise RuntimeError('load image error') 101 | except: 102 | print('Error in loading background image: %s' % img_path) 103 | backImg = None 104 | return backImg 105 | 106 | def merge_background_alpha(self, foreImg, backImg): 107 | assert(foreImg.shape[2] == 4) 108 | forergb = foreImg[:, :, :3] 109 | alpha = foreImg[:, :, 3] / 255.0 110 | if forergb.shape != backImg.shape: 111 | backImg = cv2.resize(backImg, (foreImg.shape[1], foreImg.shape[0])) 112 | alpha = np.repeat(alpha, 3).reshape(foreImg.shape[0], foreImg.shape[1], 3) 113 | mergedImg = np.uint8(backImg * (1 - alpha) + forergb * alpha) 114 | # backImg[alpha > 128] = forergb[alpha > 128] 115 | return mergedImg 116 | 117 | def merge_background_mask(self, foreImg, backImg, maskImg): 118 | forergb = foreImg[:, :, :3] 119 | if forergb.shape != backImg.shape: 120 | backImg = cv2.resize(backImg, (foreImg.shape[1], foreImg.shape[0])) 121 | alpha = np.ones((foreImg.shape[0], foreImg.shape[1], 3), np.float32) 122 | alpha[maskImg == 0] = 0 123 | mergedImg = np.uint8(backImg * (1 - alpha) + forergb * alpha) 124 | # backImg[alpha > 128] = forergb[alpha > 128] 125 | return mergedImg 126 | 127 | class RandomOcclusion: 128 | """ 129 | randomly erasing holes 130 | ref: https://arxiv.org/abs/1708.04896 131 | """ 132 | def __init__(self, prob = 0): 133 | self.prob = prob 134 | 135 | def __call__(self, img): 136 | if self.prob > 0: 137 | height, width, _ = img.shape 138 | bw = width 139 | bh = height 140 | x1 = 0 141 | x2 = width 142 | y1 = 0 143 | y2 = height 144 | if random.uniform(0, 1) <= self.prob and bw > 2 and bh > 2: 145 | bb_size = bw*bh 146 | size = random.uniform(0.02, 0.7) * bb_size 147 | ratio = random.uniform(0.5, 2.0) 148 | ew = int(np.sqrt(size * ratio)) 149 | eh = int(np.sqrt(size / ratio)) 150 | ecx = random.uniform(x1, x2) 151 | ecy = random.uniform(y1, y2) 152 | esx = int(np.clip((ecx - ew/2 + 0.5), 0, width-1)) 153 | esy = int(np.clip((ecy - eh/2 + 0.5), 0, height-1)) 154 | eex = int(np.clip((ecx + ew/2 + 0.5), 0, width-1)) 155 | eey = int(np.clip((ecy + eh/2 + 0.5), 0, height-1)) 156 | targetshape = img[esy:eey, esx:eex, :].shape 157 | img[esy:eey, esx:eex, :] = np.random.randint(256, size=targetshape) 158 | return img  # Compose chains transforms as img = t(img), so only the image is returned 159 | -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import json 4 | import trimesh 5 | import random 6 | import numpy as np 7 | import cv2 8 | import torch 9 | from pytorch3d.ops import sample_farthest_points 10 | import transforms3d 11 | 12 | def save_json(path, meta): 13 | meta_dump = json.dumps(meta) 14 | f = open(path, 'w') 15 | f.write(meta_dump) 16 | f.close() 17 | 18 | def load_json(path): 19 | f = open(path, 'r') 20 | info = json.load(f) 21 | return info 22 | 23 | def load_checkpoint(model, optimizer, pth_file): 24 | """load state and network weights""" 25 | checkpoint = torch.load(pth_file, map_location=lambda storage, loc: storage.cuda()) 26 | pretrained_dict = checkpoint['state_dict'] 27 | model_dict = model.state_dict() 28 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 29 | model_dict.update(pretrained_dict) 30 | model.load_state_dict(model_dict) 31 | optimizer.load_state_dict(checkpoint['optimizer']) 32 | start_epoch = checkpoint['epoch'] + 1 33 | try: 34 | best_acc = checkpoint['best_acc'] 35 | except: 36 | best_acc = 0 37 | print("Best acc was not
saved") 38 | print('Previous weight loaded') 39 | return model, optimizer, start_epoch, best_acc 40 | 41 | def load_bop_meshes(model_path): 42 | # load meshes 43 | meshFiles = [f for f in os.listdir(model_path) if f.endswith('.ply')] 44 | meshFiles.sort() 45 | meshes = [] 46 | objID_2_clsID = {} 47 | 48 | for i in range(len(meshFiles)): 49 | mFile = meshFiles[i] 50 | objId = int(os.path.splitext(mFile)[0][4:]) 51 | objID_2_clsID[str(objId)] = i 52 | meshes.append(trimesh.load(model_path + mFile)) 53 | 54 | return meshes, objID_2_clsID 55 | 56 | def load_bbox_3d(jsonFile): 57 | with open(jsonFile, 'r') as f: 58 | bbox_3d = json.load(f) 59 | return bbox_3d 60 | 61 | def get_single_bop_annotation(img_path, objID_2_clsID): 62 | # add attributes to function, for fast loading 63 | if not hasattr(get_single_bop_annotation, "dir_annots"): 64 | get_single_bop_annotation.dir_annots = {} 65 | # 66 | img_path = img_path.strip() 67 | cvImg = cv2.imread(img_path) 68 | height, width, _ = cvImg.shape 69 | # 70 | gt_dir, tmp, imgName = img_path.rsplit('/', 2) 71 | assert(tmp == 'rgb') 72 | imgBaseName, _ = os.path.splitext(imgName) 73 | im_id = int(imgBaseName) 74 | # 75 | camera_file = gt_dir + '/scene_camera.json' 76 | gt_file = gt_dir + "/scene_gt.json" 77 | # gt_info_file = gt_dir + "/scene_gt_info.json" 78 | gt_mask_visib = gt_dir + "/mask_visib/" 79 | 80 | if gt_dir in get_single_bop_annotation.dir_annots: 81 | gt_json, cam_json = get_single_bop_annotation.dir_annots[gt_dir] 82 | else: 83 | gt_json = json.load(open(gt_file)) 84 | # gt_info_json = json.load(open(gt_info_file)) 85 | cam_json = json.load(open(camera_file)) 86 | # 87 | get_single_bop_annotation.dir_annots[gt_dir] = [gt_json, cam_json] 88 | 89 | if str(im_id) in cam_json: 90 | annot_camera = cam_json[str(im_id)] 91 | else: 92 | annot_camera = cam_json[("%06d" % im_id)] 93 | if str(im_id) in gt_json: 94 | annot_poses = gt_json[str(im_id)] 95 | else: 96 | annot_poses = gt_json[("%06d" % im_id)] 97 | # annot_infos = gt_info_json[str(im_id)] 98 | 99 | objCnt = len(annot_poses) 100 | K = np.array(annot_camera['cam_K']).reshape(3,3) 101 | 102 | class_ids = [] 103 | # bbox_objs = [] 104 | rotations = [] 105 | translations = [] 106 | merged_mask = np.zeros((height, width), np.uint8) # segmenation masks 107 | for i in range(objCnt): 108 | mask_vis_file = gt_mask_visib + ("%06d_%06d.png" %(im_id, i)) 109 | mask_vis = cv2.imread(mask_vis_file, cv2.IMREAD_UNCHANGED) 110 | # 111 | # bbox = annot_infos[i]['bbox_visib'] 112 | # bbox = annot_infos[i]['bbox_obj'] 113 | # contourImg = cv2.rectangle(contourImg, (bbox[0], bbox[1]), (bbox[0]+bbox[2], bbox[1]+bbox[3]), (0,0,255)) 114 | # cv2.imshow(str(i), mask_vis) 115 | # 116 | R = np.array(annot_poses[i]['cam_R_m2c']).reshape(3,3) 117 | T = np.array(annot_poses[i]['cam_t_m2c']).reshape(3,1) 118 | obj_id = annot_poses[i]['obj_id'] 119 | cls_id = objID_2_clsID[str(obj_id)] 120 | # 121 | # bbox_objs.append(bbox) 122 | class_ids.append(cls_id) 123 | rotations.append(R) 124 | translations.append(T) 125 | # compose segmentation labels 126 | merged_mask[mask_vis==255] = (i+1) 127 | 128 | return K, merged_mask, class_ids, rotations, translations 129 | 130 | def remap_pose(srcK, srcR, srcT, pt3d, dstK): 131 | ptCnt = len(pt3d) 132 | pts = np.matmul(srcK, np.matmul(srcR, pt3d.transpose()) + srcT) 133 | xs = pts[0] / (pts[2] + 1e-12) 134 | ys = pts[1] / (pts[2] + 1e-12) 135 | xy2d = np.concatenate((xs.reshape(-1,1),ys.reshape(-1,1)), axis=1) 136 | 137 | #retval, rot, trans, inliers = cv2.solvePnPRansac(pt3d, xy2d, 
dstK, None, flags=cv2.SOLVEPNP_EPNP, reprojectionError=5.0) 138 | retval, rot, trans = cv2.solvePnP(pt3d.reshape(ptCnt,1,3), xy2d.reshape(ptCnt,1,2), dstK, None, flags=cv2.SOLVEPNP_EPNP) 139 | if retval: 140 | newR = cv2.Rodrigues(rot)[0] # convert to rotation matrix 141 | newT = trans.reshape(-1, 1) 142 | 143 | return newR, newT 144 | else: 145 | print('Error in pose remapping!') 146 | return srcR, srcT 147 | 148 | def pose_symmetry_handling(R, sym_types): 149 | if len(sym_types) == 0: 150 | return R  # no symmetry annotations, return the rotation unchanged 151 | 152 | assert(len(sym_types) % 2 == 0) 153 | itemCnt = int(len(sym_types) / 2) 154 | 155 | for i in range(itemCnt): 156 | axis = sym_types[2*i] 157 | mod = sym_types[2*i + 1] * np.pi / 180 158 | if axis == 'X': 159 | ai, aj, ak = transforms3d.euler.mat2euler(R, axes='sxyz') 160 | ai = 0 if mod == 0 else (ai % mod) 161 | R = transforms3d.euler.euler2mat(ai, aj, ak, axes='sxyz') 162 | elif axis == 'Y': 163 | ai, aj, ak = transforms3d.euler.mat2euler(R, axes='syzx') 164 | ai = 0 if mod == 0 else (ai % mod) 165 | R = transforms3d.euler.euler2mat(ai, aj, ak, axes='syzx') 166 | elif axis == 'Z': 167 | ai, aj, ak = transforms3d.euler.mat2euler(R, axes='szyx') 168 | ai = 0 if mod == 0 else (ai % mod) 169 | R = transforms3d.euler.euler2mat(ai, aj, ak, axes='szyx') 170 | else: 171 | print("symmetry axis should be 'X', 'Y' or 'Z'") 172 | assert(0) 173 | return R.astype(np.float32) 174 | 175 | def geodesic_distance(src_R, ref_R): 176 | sim = (torch.sum(src_R.view(-1, 9) * ref_R.view(-1, 9), dim=-1).clamp(-1, 3) - 1) / 2 177 | geo_dis = torch.arccos(sim) * 180. / np.pi 178 | return sim, geo_dis 179 | 180 | def farthest_point_sample_6d(Rs, K, random_start_point): 181 | x_col = Rs.view(-1, 3, 3)[:, :, 0] 182 | y_col = Rs.view(-1, 3, 3)[:, :, 1] 183 | 184 | x_col = x_col / torch.norm(x_col, p=2, dim=-1, keepdim=True).clamp(min=1e-8) 185 | y_col = y_col / torch.norm(y_col, p=2, dim=-1, keepdim=True).clamp(min=1e-8) 186 | 187 | Rs_6d = torch.cat([x_col, y_col], dim=-1) 188 | 189 | _, anchor_indices = sample_farthest_points(Rs_6d[None], K=K, random_start_point=random_start_point) 190 | anchor_indices = anchor_indices.squeeze(0) 191 | 192 | anchor = Rs[anchor_indices] 193 | 194 | return anchor, anchor_indices 195 | 196 | def distort_hsv(img, h_ratio, s_ratio, v_ratio): 197 | img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) # hue, sat, val 198 | h = img_hsv[:, :, 0].astype(np.float32) # hue 199 | s = img_hsv[:, :, 1].astype(np.float32) # saturation 200 | v = img_hsv[:, :, 2].astype(np.float32) # value 201 | a = random.uniform(-1, 1) * h_ratio + 1 202 | b = random.uniform(-1, 1) * s_ratio + 1 203 | c = random.uniform(-1, 1) * v_ratio + 1 204 | h *= a 205 | s *= b 206 | v *= c 207 | img_hsv[:, :, 0] = h if a < 1 else h.clip(None, 179) 208 | img_hsv[:, :, 1] = s if b < 1 else s.clip(None, 255) 209 | img_hsv[:, :, 2] = v if c < 1 else v.clip(None, 255) 210 | return cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR) 211 | 212 | def distort_noise(img, noise_ratio=0): 213 | # add noise 214 | noisesigma = random.uniform(0, noise_ratio) 215 | gauss = np.random.normal(0, noisesigma, img.shape) * 255 216 | img = img + gauss 217 | 218 | img[img > 255] = 255 219 | img[img < 0] = 0 220 | 221 | return np.uint8(img) 222 | 223 | def distort_smooth(img, smooth_ratio=0): 224 | # add smooth 225 | smoothsigma = random.uniform(0, smooth_ratio) 226 | res = cv2.GaussianBlur(img, (7, 7), smoothsigma, cv2.BORDER_DEFAULT) 227 | return res 228 | -------------------------------------------------------------------------------- /objects.yaml: 
-------------------------------------------------------------------------------- 1 | RENDER: 2 | OUTPUT_PATH: "../data/reference_pool_6d_sample_10000" 3 | NUM: 10000 4 | CAM_DIST: 4 5 | MODEL: 6 | LOCALNORM: True 7 | HIDDEN_DIM: 7 8 | FEATURE_DIM: 128 9 | SCALES: [32, 16, 13] 10 | 11 | DATA: 12 | PIXEL_MEAN: [0.485, 0.456, 0.406] 13 | PIXEL_STD: [0.229, 0.224, 0.225] 14 | MESH_DIR: '../data/linemod_zhs/models/' #/cvlabdata2/home/yhu/data/linemod/models/ 15 | BBOX_FILE: '../data/linemod_zhs/linemod_bbox.json' 16 | META_DIR: '../data/' 17 | RENDER_DIR: 'reference_pool_6d_sample_10000' 18 | INTERNAL_WIDTH: 640 19 | INTERNAL_HEIGHT: 480 20 | CROP_SIZE: 128 21 | VIEW_NUM: 18 22 | ANCHOR_NUM: 1024 23 | DATASET: 'LINEMOD' 24 | INTERNAL_K: [572.4114, 0, 325.2611, 0, 573.57043, 242.04899, 0, 0, 1] 25 | ROTATION_NOISE: [0.5, 0.5, 0.5, 0.5] 26 | TRANSLATION_NOISE: [0, 0, 0] 27 | DIAMETER: [9.74298, 28.6908, 17.1185, 17.1593, 19.3416, 15.2633, 12.5961, 25.9425, 10.7131, 17.6364, 16.4857, 14.8204, 30.3153, 28.5155, 20.8394] 28 | 29 | TRAIN: 30 | VAL_STEP: 10 31 | DROP: 0.1 32 | BS: 16 33 | WORKERS: 8 34 | MAX_EPOCH: 200 35 | RATIO: 0.5 36 | LR: 0.0001 37 | GAMMA: 0.1 38 | STEP: [50, 150] 39 | WORKING_DIR: "./logs/linemod/" 40 | FROM_SCRATCH: True 41 | ROTATION_AG: True 42 | RANDOM_OCC: False 43 | NORM_FUNC: "instance" 44 | RUNNING_DEVICE: "cuda" 45 | 46 | TEST: 47 | VISUAL: True 48 | VISUAL_PATH: "./logs/linemod/visual/" 49 | THR_SO3: 30 50 | INIT_K: 4096 51 | FPS_K: 256 52 | UNSEEN: ['APE', 'BENCHVISE', 'CAM', 'CAN'] 53 | VIZ: True 54 | 55 | LINEMOD: 56 | APE: 1 57 | BENCHVISE: 2 58 | CAM: 4 59 | CAN: 5 60 | CAT: 6 61 | DRILLER: 8 62 | DUCK: 9 63 | EGGBOX: 10 64 | GLUE: 11 65 | HOLEPUNCHER: 12 66 | IRON: 13 67 | LAMP: 14 68 | PHONE: 15 69 | 70 | SYMMETRIC_OBJ: {"000010": ['Z', 180], "000011": ['Z', 180]} 71 | MOD: 180. 
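# SYMMETRIC_OBJ maps zero-padded object IDs to (axis, modulo-in-degrees) pairs, the sym_types format
# consumed by pose_symmetry_handling() in core/utils.py; for LineMOD this folds the rotations of the
# eggbox ("000010") and glue ("000011") onto a canonical representative about the Z axis.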
72 | 73 | YCBV: 74 | master_che_can: 1 75 | racker_box: 2 76 | sugar_box: 3 77 | tomato_soup_can: 4 78 | mustard_bottle: 5 79 | tuna_fish_can: 6 80 | pudding_box: 7 81 | gelatin_box: 8 82 | potted_meat_can: 9 83 | banana: 10 84 | pitcher_base: 11 85 | bleach_cleanser: 12 86 | bowl: 13 87 | mug: 14 88 | power_drill: 15 89 | wood_block: 16 90 | scissors: 17 91 | large_marker: 18 92 | large_clamp: 19 93 | extra_large_clamp: 20 94 | foam_brick: 21 95 | 96 | SYMMETRIC_OBJ: { 97 | "000013":['Z',0], 98 | "000016":['X',180,'Y',180,'Z',90], 99 | "000019":['Y',180], 100 | "000020":['X',180], 101 | "000021":['X',180,'Y',90,'Z',180] 102 | } 103 | -------------------------------------------------------------------------------- /pretrained_models/checkpoint_1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sailor-z/Unseen_Object_Pose/42d10d4498660882874a36d4797397737857d0ac/pretrained_models/checkpoint_1.pth -------------------------------------------------------------------------------- /pretrained_models/checkpoint_2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sailor-z/Unseen_Object_Pose/42d10d4498660882874a36d4797397737857d0ac/pretrained_models/checkpoint_2.pth -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import yaml 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | import cv2 9 | import imageio 10 | from tqdm import tqdm, trange 11 | from core.model import RetrievalNet as Model 12 | from core.model import Sim_predictor as Predictor 13 | from core.dataset import LINEMOD_SO3 as LINEMOD 14 | from core.dataset import ref_loader_so3 as ref_loader 15 | from core.utils import geodesic_distance 16 | from pytorch3d.ops import sample_farthest_points 17 | 18 | np.set_printoptions(threshold=np.inf) 19 | np.random.seed(0) 20 | 21 | def visual(cfg, src_img, ref_img, gt_ref_img): 22 | src_img = src_img.permute(1, 2, 0).cpu().detach().numpy() 23 | ref_img = ref_img.permute(1, 2, 0).cpu().detach().numpy() 24 | gt_ref_img = gt_ref_img.permute(1, 2, 0).cpu().detach().numpy() 25 | 26 | ref_mask = np.absolute(ref_img).sum(axis=-1, keepdims=True) > 0 27 | gt_ref_mask = np.absolute(gt_ref_img).sum(axis=-1, keepdims=True) > 0 28 | 29 | src_img = src_img * np.array(cfg['DATA']['PIXEL_STD']).reshape(1, 1, 3) \ 30 | + np.array(cfg['DATA']['PIXEL_MEAN']).reshape(1, 1, 3) 31 | src_img = (255*src_img).astype(np.uint8) 32 | 33 | ref_img = ref_img * np.array(cfg['DATA']['PIXEL_STD']).reshape(1, 1, 3) \ 34 | + np.array(cfg['DATA']['PIXEL_MEAN']).reshape(1, 1, 3) 35 | ref_img = ref_img * ref_mask 36 | ref_img = (255*ref_img).astype(np.uint8) 37 | 38 | gt_ref_img = gt_ref_img * np.array(cfg['DATA']['PIXEL_STD']).reshape(1, 1, 3) \ 39 | + np.array(cfg['DATA']['PIXEL_MEAN']).reshape(1, 1, 3) 40 | gt_ref_img = gt_ref_img * gt_ref_mask 41 | gt_ref_img = (255*gt_ref_img).astype(np.uint8) 42 | 43 | h, w, _ = src_img.shape 44 | viz_img = np.zeros([h, 3*w, 3]).astype(np.uint8) 45 | 46 | viz_img[:, :w, :] = src_img 47 | viz_img[:, w:2*w, :] = gt_ref_img 48 | viz_img[:, 2*w:, :] = ref_img 49 | return viz_img 50 | 51 | def val(cfg, model, predictor, device): 52 | model.eval() 53 | predictor.eval() 54 | print(">>>>>>>>>>>>>> Loading reference database") 
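        # Build the reference database once per object class: every rendered reference image is
        # passed through the retrieval network, its multi-scale features are cached on the CPU,
        # and farthest-point sampling over the flattened rotation matrices selects ANCHOR_NUM
        # anchor views whose features stay on the GPU for the coarse retrieval step.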
55 | ref_database = ref_loader(cfg) 56 | 57 | ref_info = {} 58 | with torch.no_grad(): 59 | for clsID in ref_database.ref_info.keys(): 60 | print(">>>>>>>>>>>>>> Estimating features for ref " + clsID) 61 | ref_info_clsID = ref_database.load(clsID) 62 | anchors, anchor_indices = sample_farthest_points(ref_info_clsID["Rs"].view(1, -1, 9), K=cfg["DATA"]["ANCHOR_NUM"], random_start_point=False) 63 | anchors, anchor_indices = anchors[0], anchor_indices[0] 64 | 65 | ref_info_clsID["anchors"] = anchors.to(device) 66 | ref_info_clsID["indices"] = anchor_indices.to(device) 67 | ref_info_clsID["Rs"] = ref_info_clsID["Rs"].to(device) 68 | ref_info_clsID["ref_f"] = [] 69 | 70 | ref_imgs = ref_info_clsID["imgs"] 71 | for i in trange(ref_imgs.shape[0]): 72 | ref_img = ref_imgs[i][None].to(device) 73 | ref_f = model(ref_img) 74 | ref_f = [ref_f[j].cpu().detach() for j in range(len(ref_f))] 75 | ref_info_clsID["ref_f"].append(ref_f) 76 | 77 | del ref_img, ref_imgs, ref_f 78 | torch.cuda.empty_cache() 79 | 80 | ref_info_clsID["anchor_f"] = [] 81 | for j in range(len(ref_info_clsID["ref_f"][0])): 82 | ref_info_clsID["anchor_f"].append(torch.cat([ref_info_clsID["ref_f"][idx][j].to(device) for idx in ref_info_clsID["indices"]], dim=0)) 83 | 84 | ref_info[clsID] = ref_info_clsID 85 | 86 | if cfg["TEST"]["VISUAL"] is True: 87 | if not os.path.exists(cfg["TEST"]["VISUAL_PATH"]): 88 | os.makedirs(cfg["TEST"]["VISUAL_PATH"]) 89 | 90 | print(">>>>>>>>>>>>>> START TESTING") 91 | test_cat_acc_all, test_R_acc_all = [], [] 92 | for cat in cfg["TEST"]["UNSEEN"]: 93 | objID = cfg[cfg["DATA"]["DATASET"]][cat] 94 | 95 | test_cat_acc, test_R_acc = test_category(cfg, model, predictor, ref_info, device, objID) 96 | 97 | test_cat_acc_all += [test_cat_acc] 98 | test_R_acc_all += [test_R_acc] 99 | 100 | test_cat_acc_all = np.asarray(test_cat_acc_all).mean() 101 | test_R_acc_all = np.asarray(test_R_acc_all).mean() 102 | 103 | print('All categories -- || Testing Cls Acc: %.2f || Testing R Acc: %.2f' % (test_cat_acc_all, test_R_acc_all)) 104 | 105 | return [test_R_acc_all, test_cat_acc_all] 106 | 107 | def iterative_retrieval(src_R, src_f, ref_info_clsID, predictor, device, max_iter=5, init_K=4096, fps_k=256, shrink_ratio=2): 108 | pred_sims = predictor(src_f, ref_info_clsID["anchor_f"]).squeeze(0) 109 | pred_index = torch.argmax(pred_sims) 110 | 111 | anchor = ref_info_clsID["anchors"][pred_index] 112 | 113 | for i in range(max_iter): 114 | neighbor_indices = torch.topk((anchor.view(1, 9) * ref_info_clsID["Rs"].view(-1, 9)).sum(dim=-1), k=init_K//(shrink_ratio**i), sorted=False)[1] 115 | 116 | if neighbor_indices.shape[0] > fps_k: 117 | _, fps_indices = sample_farthest_points(ref_info_clsID["Rs"][neighbor_indices].view(1, -1, 9), K=fps_k, random_start_point=False) 118 | neighbor_indices = neighbor_indices[fps_indices.squeeze(0)] 119 | 120 | ref_f = [] 121 | for j in range(len(ref_info_clsID["ref_f"][0])): 122 | ref_f.append(torch.cat([ref_info_clsID["ref_f"][neighbor_indices[idx]][j].to(device) \ 123 | for idx in range(neighbor_indices.shape[0])], dim=0)) 124 | 125 | pred_sims = predictor(src_f, ref_f).squeeze(0) 126 | 127 | pred_index = neighbor_indices[torch.argmax(pred_sims)] 128 | pred_sim = torch.max(pred_sims) 129 | 130 | anchor_new = ref_info_clsID["Rs"][pred_index] 131 | 132 | anchor = anchor_new 133 | 134 | return pred_index, pred_sim 135 | 136 | def test_category(cfg, model, predictor, ref_info, device, objID): 137 | dataset_test = LINEMOD(cfg, 'test', objID) 138 | dataset_loader = DataLoader(dataset_test, batch_size=1, 
shuffle=False, num_workers=cfg["TRAIN"]["WORKERS"], drop_last=False) 139 | 140 | print(">>>>>>>>>> TESTING DATA of %06d:" % (objID), len(dataset_loader)) 141 | 142 | if cfg["TEST"]["VISUAL"] is True: 143 | filename_output = cfg["TEST"]["VISUAL_PATH"] + '/%06d.gif' % (objID) 144 | writer = imageio.get_writer(filename_output, mode='I', duration=1.0) 145 | 146 | test_cat_acc, test_R_acc, test_errs, gt_errs = [], [], [], [] 147 | with torch.no_grad(): 148 | for i, data in enumerate(tqdm(dataset_loader)): 149 | # load data and label 150 | src_img, src_R, id = data 151 | src_img = src_img.to(device) 152 | src_R = src_R.to(device) 153 | 154 | src_f = model(src_img) 155 | 156 | '''Category prediction''' 157 | gt_clsID = "%06d" % (id) 158 | 159 | anchor_sims = [] 160 | for clsID in ref_info.keys(): 161 | pred_sims = predictor(src_f, ref_info[clsID]["anchor_f"]).squeeze(0) 162 | anchor_sims.append(torch.topk(pred_sims, k=3)[0].mean()) 163 | pred_cls_index = torch.argmax(torch.stack(anchor_sims)) 164 | pred_clsID = list(ref_info)[pred_cls_index] 165 | 166 | '''Fast Retrieval''' 167 | pred_index, pred_sim = iterative_retrieval(src_R, src_f, ref_info[pred_clsID], predictor, device, \ 168 | max_iter=4, init_K=cfg["TEST"]["INIT_K"], fps_k=cfg["TEST"]["FPS_K"], shrink_ratio=2) 169 | 170 | ref_sim, ref_err = geodesic_distance(src_R, ref_info[gt_clsID]["Rs"]) 171 | 172 | gt_index = torch.argmin(ref_err).item() 173 | gt_err = torch.min(ref_err).item() 174 | 175 | pred_err = ref_err[pred_index].item() 176 | 177 | cls_acc = int(pred_clsID == gt_clsID) 178 | 179 | test_cat_acc.append(cls_acc) 180 | test_errs.append(pred_err) 181 | test_R_acc.append(float(pred_err <= cfg["TEST"]["THR_SO3"]) * cls_acc) 182 | gt_errs.append(gt_err) 183 | if i % 20 == 0: 184 | if cfg["TEST"]["VISUAL"] is True: 185 | ref_img_pred = ref_info[pred_clsID]["imgs"][pred_index] 186 | ref_img_gt = ref_info[gt_clsID]["imgs"][gt_index] 187 | viz_img = visual(cfg, src_img.squeeze(0), ref_img_pred, ref_img_gt) 188 | writer.append_data(cv2.cvtColor(viz_img, cv2.COLOR_BGR2RGB)) 189 | 190 | if cfg["TEST"]["VISUAL"] is True: 191 | writer.close() 192 | 193 | test_cat_acc = 100 * np.asarray(test_cat_acc).mean() 194 | test_R_acc = 100 * np.asarray(test_R_acc).mean() 195 | test_errs = np.asarray(test_errs).mean() 196 | gt_errs = np.asarray(gt_errs).mean() 197 | print('Category: %02d -- || GT Err: %.2f || Test Err: %.2f || Testing Classifcation Acc: %.2f || Testing R1: %.2f' % \ 198 | (objID, gt_errs, test_errs, test_cat_acc, test_R_acc)) 199 | return test_cat_acc, test_R_acc 200 | 201 | def test(cfg, device): 202 | if not os.path.exists(cfg["TRAIN"]["WORKING_DIR"]): 203 | os.makedirs(cfg["TRAIN"]["WORKING_DIR"]) 204 | logname = os.path.join(cfg["TRAIN"]["WORKING_DIR"], 'testing_log.txt') 205 | 206 | if cfg["TEST"]["VISUAL"] is True: 207 | if not os.path.exists(cfg["TEST"]["VISUAL_PATH"]): 208 | os.makedirs(cfg["TEST"]["VISUAL_PATH"]) 209 | 210 | print(">>>>>>>>>>>>>> LOADING NETWORK") 211 | model = Model(cfg).to(device) 212 | checkpoint = torch.load(cfg["TRAIN"]["WORKING_DIR"] + "checkpoint_1.pth", map_location=lambda storage, loc: storage.cuda()) 213 | pretrained_dict = checkpoint['state_dict'] 214 | best_epoch = checkpoint["epoch"] 215 | model.load_state_dict(pretrained_dict) 216 | model.eval() 217 | 218 | predictor = Predictor(cfg).to(device) 219 | checkpoint = torch.load(cfg["TRAIN"]["WORKING_DIR"] + "checkpoint_2.pth", map_location=lambda storage, loc: storage.cuda()) 220 | pretrained_dict = checkpoint['state_dict'] 221 | 
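    # checkpoint_2.pth stores the similarity predictor (Sim_predictor); restore its weights
    # before evaluation, mirroring the retrieval network loaded from checkpoint_1.pth above.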
predictor.load_state_dict(pretrained_dict) 222 | predictor.eval() 223 | 224 | print(">>>>>>>>>>>>>> Network trained on epoch %03d has been loaded" % (best_epoch)) 225 | 226 | print(">>>>>>>>>>>>>> Loading reference database") 227 | ref_database = ref_loader(cfg) 228 | ref_info = {} 229 | 230 | with torch.no_grad(): 231 | for clsID in ref_database.ref_info.keys(): 232 | print(">>>>>>>>>>>>>> Estimating features for ref " + clsID) 233 | ref_info_clsID = ref_database.load(clsID) 234 | anchors, anchor_indices = sample_farthest_points(ref_info_clsID["Rs"].reshape(1, -1, 9), K=cfg["DATA"]["ANCHOR_NUM"], random_start_point=False) 235 | anchors, anchor_indices = anchors[0], anchor_indices[0] 236 | 237 | ref_info_clsID["anchors"] = anchors.to(device) 238 | ref_info_clsID["indices"] = anchor_indices.to(device) 239 | ref_info_clsID["Rs"] = ref_info_clsID["Rs"].to(device) 240 | ref_info_clsID["ref_f"] = [] 241 | 242 | ref_imgs = ref_info_clsID["imgs"] 243 | for i in trange(ref_imgs.shape[0]): 244 | ref_img = ref_imgs[i][None].to(device) 245 | ref_f = model(ref_img) 246 | ref_f = [ref_f[j].cpu().detach() for j in range(len(ref_f))] 247 | ref_info_clsID["ref_f"].append(ref_f) 248 | 249 | del ref_img, ref_imgs, ref_f 250 | torch.cuda.empty_cache() 251 | 252 | ref_info_clsID["anchor_f"] = [] 253 | for j in range(len(ref_info_clsID["ref_f"][0])): 254 | ref_info_clsID["anchor_f"].append(torch.cat([ref_info_clsID["ref_f"][idx][j].to(device) for idx in ref_info_clsID["indices"]], dim=0)) 255 | 256 | ref_info[clsID] = ref_info_clsID 257 | 258 | 259 | test_cat_acc_all, test_R_acc_all = [], [] 260 | for cat in cfg["TEST"]["UNSEEN"]: 261 | objID = cfg[cfg["DATA"]["DATASET"]][cat] 262 | test_cat_acc, test_R_acc = test_category(cfg, model, predictor, ref_info, device, objID) 263 | 264 | test_cat_acc_all += [test_cat_acc] 265 | test_R_acc_all += [test_R_acc] 266 | 267 | with open(logname, 'a') as f: 268 | f.write('objID: %02d -- || Testing Cls: %.2f || Testing R Acc: %.2f\n' % (objID, test_cat_acc, test_R_acc)) 269 | f.close() 270 | 271 | test_cat_acc_all = np.asarray(test_cat_acc_all).mean() 272 | test_R_acc_all = np.asarray(test_R_acc_all).mean() 273 | 274 | with open(logname, 'a') as f: 275 | f.write('All categories -- || Testing Cls Acc: %.2f || Testing R Acc: %.2f\n' % (test_cat_acc_all, test_R_acc_all)) 276 | f.close() 277 | 278 | if __name__ == '__main__': 279 | 280 | with open("./objects.yaml", 'r') as load_f: 281 | cfg = yaml.load(load_f, Loader=yaml.FullLoader) 282 | cfg["TRAIN"]["BS"] = 1 283 | 284 | if torch.cuda.is_available(): 285 | device = torch.device("cuda:0") 286 | torch.cuda.set_device(device) 287 | else: 288 | device = torch.device("cpu") 289 | 290 | test(cfg, device) 291 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import yaml 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | import torch.optim as optim 8 | import numpy as np 9 | import cv2 10 | from tqdm import tqdm 11 | import argparse 12 | from tqdm import trange 13 | from test import val 14 | from core.dataset import LINEMOD_SO3 as LINEMOD 15 | from core.loss import weighted_infoNCE_loss_func 16 | from core.utils import load_checkpoint 17 | from core.model import RetrievalNet as Model 18 | from core.model import Sim_predictor as Predictor 19 | 20 | np.set_printoptions(threshold=np.inf) 21 | 
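# Note: cudnn.benchmark=True lets cuDNN auto-tune convolution algorithms, which can undermine
# the determinism requested by cudnn.deterministic=True; disable benchmarking (and anomaly
# detection, which slows training) if bit-exact reproducibility across runs is required.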
torch.backends.cudnn.deterministic=True 22 | torch.backends.cudnn.enabled=True 23 | torch.backends.cudnn.benchmark=True 24 | torch.autograd.set_detect_anomaly(True) 25 | 26 | torch.manual_seed(0) 27 | np.random.seed(0) 28 | 29 | def train_one_epoch(epoch, train_loader, model, predictor, optimizer): 30 | model.train() 31 | predictor.train() 32 | train_loss = [] 33 | train_errs = [] 34 | for i, data in enumerate(train_loader): 35 | ## load data and label 36 | src_img, ref_img, src_R, ref_R, id = data 37 | 38 | if torch.any(id == -1): 39 | print("Skip incorrect data") 40 | continue 41 | 42 | src_img = src_img.cuda() 43 | ref_img = torch.cat(ref_img, dim=0).cuda() 44 | src_R, ref_R = src_R.cuda(), ref_R.cuda() 45 | id = id.cuda() 46 | 47 | B, _, H, W = src_img.shape 48 | 49 | ## feature extraction 50 | src_f = model(src_img) 51 | ref_f = model(ref_img) 52 | 53 | ## similarity estimation 54 | ref_sim = predictor(src_f, ref_f) 55 | 56 | ## loss estimation 57 | loss = weighted_infoNCE_loss_func(ref_sim[:, B:2*B], ref_sim[:, 2*B:], ref_R[:, 1], ref_R[:, 2], src_R, id, tau=0.1) 58 | 59 | try: 60 | optimizer.zero_grad() 61 | loss.backward() 62 | optimizer.step() 63 | except: 64 | print("Skip incorrect data") 65 | continue 66 | 67 | train_loss.append(loss.item()) 68 | 69 | if i % 20 == 0: 70 | print("\tEpoch %3d --- Iter [%d/%d] Train --- Loss: %.4f" % (epoch, i + 1, len(train_loader), loss.item())) 71 | 72 | train_loss = np.asarray(train_loss).mean() 73 | return train_loss 74 | 75 | def train(cfg, device): 76 | print(">>>>>>>>>>>>>> CREATE DATASET") 77 | dataset_train = LINEMOD(cfg, 'train', 0) 78 | train_loader = DataLoader(dataset_train, batch_size=cfg["TRAIN"]["BS"], shuffle=True, \ 79 | num_workers=cfg["TRAIN"]["WORKERS"], drop_last=True) 80 | print(">>>>>>>>>> TRAINING DATA:", len(train_loader)*cfg["TRAIN"]["BS"]) 81 | 82 | print(">>>>>>>>>>>>>> CREATE NETWORK") 83 | model = Model(cfg).to(device) 84 | predictor = Predictor(cfg).to(device) 85 | 86 | print(">>>>>>>>>>>>>> CREATE OPTIMIZER") 87 | optimizer = optim.Adam([{'params': model.parameters()}, {'params': predictor.parameters()}], lr=cfg["TRAIN"]["LR"]) 88 | lrScheduler = optim.lr_scheduler.MultiStepLR(optimizer, cfg["TRAIN"]["STEP"], gamma=cfg["TRAIN"]["GAMMA"]) 89 | 90 | if not os.path.exists(cfg["TRAIN"]["WORKING_DIR"]): 91 | os.makedirs(cfg["TRAIN"]["WORKING_DIR"]) 92 | 93 | logname = os.path.join(cfg["TRAIN"]["WORKING_DIR"], 'training_log.txt') 94 | with open(logname, 'a') as f: 95 | f.write('training set: ' + str(len(dataset_train)) + '\n') 96 | 97 | if cfg["TRAIN"]["FROM_SCRATCH"] is False: 98 | print(">>>>>>>>>>>>>> LOAD MODEL") 99 | model, optimizer, start_epoch, best_acc = load_checkpoint(model, optimizer, cfg["TRAIN"]["WORKING_DIR"] + "checkpoint_1.pth") 100 | predictor, _, _, _ = load_checkpoint(predictor, optimizer, cfg["TRAIN"]["WORKING_DIR"] + "checkpoint_2.pth") 101 | else: 102 | print(">>>>>>>>>>>>>> TRAINING FROM SCRATCH") 103 | best_acc = 0 104 | start_epoch = 0 105 | 106 | print(">>>>>>>>>>>>>> START TRAINING") 107 | for epoch in trange(start_epoch, cfg["TRAIN"]["MAX_EPOCH"]): 108 | loss = train_one_epoch(epoch, train_loader, model, predictor, optimizer) 109 | 110 | # update learning rate 111 | lrScheduler.step() 112 | 113 | if (epoch + 1) % cfg["TRAIN"]["VAL_STEP"] == 0: 114 | res = val(cfg, model, predictor, device) 115 | 116 | if res[0] > best_acc: 117 | best_acc = res[0] 118 | state_dict = {'epoch': epoch, 'state_dict': model.state_dict(),\ 119 | 'optimizer': optimizer.state_dict(), 120 | 'best_acc': res[0]} 121 | 
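                # A new best rotation accuracy triggers two snapshots: the retrieval backbone is
                # saved to checkpoint_1.pth here and the similarity predictor to checkpoint_2.pth below.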
torch.save(state_dict, os.path.join(cfg["TRAIN"]["WORKING_DIR"], 'checkpoint_1.pth')) 122 | 123 | state_dict = {'epoch': epoch, 'state_dict': predictor.state_dict(),\ 124 | 'optimizer': optimizer.state_dict(), 125 | 'best_acc': res[0]} 126 | torch.save(state_dict, os.path.join(cfg["TRAIN"]["WORKING_DIR"], 'checkpoint_2.pth')) 127 | 128 | with open(logname, 'a') as f: 129 | text = str('Epoch: %03d || train_loss %.4f || test_cls_acc: %.2f || test_R_acc %.2f\n' % (epoch, loss, res[1], res[0])) 130 | f.write(text) 131 | else: 132 | with open(logname, 'a') as f: 133 | text = str('Epoch: %03d || train_loss %.4f \n' % (epoch, loss)) 134 | f.write(text) 135 | 136 | if __name__ == '__main__': 137 | print(">>>>>>>>>>>>> Loading configuration") 138 | with open("./objects.yaml", 'r') as load_f: 139 | cfg = yaml.load(load_f, Loader=yaml.FullLoader) 140 | 141 | if torch.cuda.is_available(): 142 | device = torch.device("cuda:0") 143 | torch.cuda.set_device(device) 144 | else: 145 | device = torch.device("cpu") 146 | 147 | train(cfg, device) 148 | --------------------------------------------------------------------------------
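
For reference, the snippet below sketches one way to exercise the released checkpoints outside of train.py and test.py. It is a minimal, untested sketch: it assumes the configuration and pretrained_models/ checkpoints shipped with the repository, follows test.py's convention that checkpoint_1 holds the retrieval network and checkpoint_2 the similarity predictor, and uses random tensors named query_img and ref_img as placeholders for crops that the real pipeline would resize to CROP_SIZE and normalize with PIXEL_MEAN/PIXEL_STD.

import yaml
import torch
from core.model import RetrievalNet, Sim_predictor

# Load the same configuration used by train.py and test.py.
with open("./objects.yaml", "r") as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoints store {'epoch', 'state_dict', 'optimizer', 'best_acc'}.
model = RetrievalNet(cfg).to(device).eval()
predictor = Sim_predictor(cfg).to(device).eval()
model.load_state_dict(torch.load("./pretrained_models/checkpoint_1.pth", map_location=device)["state_dict"])
predictor.load_state_dict(torch.load("./pretrained_models/checkpoint_2.pth", map_location=device)["state_dict"])

# Placeholder inputs: (1, 3, CROP_SIZE, CROP_SIZE) crops, already cropped and normalized in practice.
crop = cfg["DATA"]["CROP_SIZE"]
query_img = torch.randn(1, 3, crop, crop, device=device)
ref_img = torch.randn(1, 3, crop, crop, device=device)

with torch.no_grad():
    query_f = model(query_img)       # list of L2-normalized multi-scale feature maps
    ref_f = model(ref_img)
    sim = predictor(query_f, ref_f)  # (1, 1) similarity score in (-1, 1) from tanh
print("similarity:", sim.item())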