├── LICENCE ├── README.md ├── assets ├── general.mp4 └── teaser.png ├── data_util ├── README.md ├── blender_triangulate.py ├── clear_obj.sh ├── csv_to_ply.py ├── csv_to_ply.sh ├── generate_pc.cpp ├── generate_pc.sh ├── generate_urdf.py ├── partA.obj ├── partA_cp.obj ├── preprocess_obj_blender.py └── preprocess_obj_blender.sh ├── environment.yml └── src ├── .DS_Store ├── config ├── eval.yml ├── train_A.yml └── train_B.yml ├── script ├── .DS_Store ├── our_eval.py ├── our_train_A.py └── our_train_B.py └── shape_assembly ├── .DS_Store ├── __init__.py ├── __pycache__ ├── config.cpython-38.pyc └── utils.cpython-38.pyc ├── config.py ├── config_eval.py ├── datasets ├── .DS_Store └── dataloader │ ├── .DS_Store │ ├── __pycache__ │ ├── cls_dataloader.cpython-38.pyc │ ├── dataloader_A.cpython-38.pyc │ ├── dataloader_CR.cpython-38.pyc │ ├── dataloader_kit.cpython-38.pyc │ ├── dataloader_kit1.cpython-38.pyc │ ├── dataloader_kit_1.cpython-38.pyc │ ├── dataloader_kit_ori.cpython-38.pyc │ ├── dataloader_kit_single.cpython-38.pyc │ ├── dataloader_our.cpython-38.pyc │ ├── dataloader_our1.cpython-38.pyc │ ├── sub1.cpython-38.pyc │ └── sub_1_7.cpython-38.pyc │ ├── dataloader_A.py │ └── dataloader_B.py ├── models ├── decoder │ ├── .DS_Store │ ├── MLPDecoder.py │ └── __pycache__ │ │ └── MLPDecoder.cpython-38.pyc ├── encoder │ ├── .DS_Store │ ├── __pycache__ │ │ ├── dgcnn.cpython-38.pyc │ │ ├── pointnet.cpython-38.pyc │ │ ├── vn_dgcnn.cpython-38.pyc │ │ └── vn_layers.cpython-38.pyc │ ├── vn_dgcnn.py │ ├── vn_dgcnn_util.py │ └── vn_layers.py └── train │ ├── __pycache__ │ ├── Pose_Refinement_Module.cpython-38.pyc │ ├── network_vnn1.cpython-38.pyc │ ├── network_vnn2.cpython-38.pyc │ ├── network_vnn_A.cpython-38.pyc │ ├── network_vnn_A_indi.cpython-38.pyc │ ├── network_vnn_A_indi2.cpython-38.pyc │ ├── network_vnn_B.cpython-38.pyc │ ├── network_vnn_B2.cpython-38.pyc │ ├── regressor_CR.cpython-38.pyc │ └── transformer.cpython-38.pyc │ ├── network_vnn_A.py │ ├── network_vnn_A_indi.py │ ├── network_vnn_B.py │ ├── pose_estimator.py │ ├── regressor_CR.py │ └── transformer.py ├── utils.py └── version.py /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Yu Qi and Yuanchen Ju 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Two by Two✌️ : Learning Multi-Task Pairwise Objects Assembly for Generalizable Robot Manipulation](https://tea-lab.github.io/TwoByTwo/) 2 | 3 | Project Page 4 | | 5 | arXiv 6 | | 7 | Twitter 8 | | Dataset 9 | 10 | Yu Qi*, 11 | Yuanchen Ju*, 12 | Tianming Wei, 13 | Chi Chu, 14 | Lawson L.S. Wong, 15 | Huazhe Xu 16 | 17 | 18 | **CVPR, 2025** 19 | 20 | 21 |
22 | <img src="assets/teaser.png" alt="2by2"/> 23 |
24 | 25 | 26 | # 🛠️ Installation 27 | 28 | This project is tested on Ubuntu 22.04 with CUDA 11.8. 29 | 30 | - Install [Anaconda](https://www.anaconda.com/docs/getting-started/anaconda/install#macos-linux-installation:to-download-a-different-version) or [Miniconda](https://www.anaconda.com/docs/getting-started/miniconda/install#linux-terminal-installer) 31 | - Clone the repository and create the environment. The environment should install correctly within minutes. 32 | 33 | ```shell 34 | git clone git@github.com:TEA-Lab/TwoByTwo.git 35 | conda env create -f environment.yml 36 | conda activate twobytwo 37 | ``` 38 | 39 | - (Optional) If you would like to calculate Chamfer Distance, clone the [CUDA-accelerated Chamfer Distance library](https://github.com/ThibaultGROUEIX/ChamferDistancePytorch/tree/master): 40 | 41 | ```shell 42 | cd src/shape_assembly/models 43 | git clone https://github.com/ThibaultGROUEIX/ChamferDistancePytorch.git 44 | ``` 45 | 46 | ## 🧩 Dataset 47 | 48 | The 2BY2 dataset has been released. To obtain it, please fill out this [form](https://docs.google.com/forms/d/e/1FAIpQLSfhPcAdky8ZojjPlSSHN4ubYqc7WHIwfiqFW2L5YpqAHbbVgg/viewform?usp=sharing). 49 | 50 | ## 🍰 Dataset Utility Support 51 | We recommend using our pre-generated point clouds. You can also generate your own point clouds, add your own data, or generate **URDF (Unified Robot Description Format)** files for robot simulation; see the `data_util` folder for more detailed instructions. 52 | 53 | ## 🐰 Training and Inference 54 | 55 | In `src/config`, modify the paths of `log_dir`, `vis_dir`, and the `data` section's `root_dir`. We support Distributed Data Parallel training. 56 | 57 | 58 | - Train Network B 59 | 60 | ```shell 61 | cd src 62 | python script/our_train_B.py --cfg_file train_B.yml 63 | ``` 64 | 65 | - Train Network A 66 | ```shell 67 | cd src 68 | python script/our_train_A.py --cfg_file train_A.yml 69 | ``` 70 | - Inference 71 | 72 | ```shell 73 | cd src 74 | python script/our_eval.py --cfg_file eval.yml 75 | ``` 76 | 77 | # 🎟️ Licence 78 | This repository is released under the MIT licence. Refer to [LICENCE](LICENCE) for more information. 79 | 80 | # 🎨 Acknowledgement & Contact 81 | Our codebase is developed based on [SE3-part-assembly](https://crtie.github.io/SE-3-part-assembly/), and we express our gratitude to its authors for their generously open-sourced code, as well as to the baseline projects [Puzzlefusion++](https://puzzlefusion-plusplus.github.io/), [Jigsaw](https://jiaxin-lu.github.io/Jigsaw/), and [Neural Shape Mating](https://neural-shape-mating.github.io/) for their valuable contributions to the community. 82 | 83 | For inquiries about this project, please reach out to **Yu Qi: qi.yu2@northeastern.edu** and **Yuanchen Ju: juuycc0213@gmail.com**. You’re also welcome to open an issue or submit a pull request!😄 84 | 85 | 86 | # 🎸 BibTeX 87 | 88 | If you find this work useful, please consider citing it. 
89 | ``` 90 | @inproceedings{qi2025two, 91 | title={Two by two: Learning multi-task pairwise objects assembly for generalizable robot manipulation}, 92 | author={Qi, Yu and Ju, Yuanchen and Wei, Tianming and Chu, Chi and Wong, Lawson LS and Xu, Huazhe}, 93 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 94 | year={2025} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /assets/general.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/assets/general.mp4 -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/assets/teaser.png -------------------------------------------------------------------------------- /data_util/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Utility Support 2 | 3 | ## Point Cloud Generation 4 | 5 | You can directly use the sampled point clouds in the 2BY2 dataset, which contain 1024 points per object, but you can also generate your own. As in [Breaking Bad](https://breaking-bad-dataset.github.io/), we use blue-noise sampling to generate point clouds. 6 | 7 | ```shell 8 | sudo apt update 9 | 10 | sudo apt install build-essential cmake libglfw3-dev libgl1-mesa-dev libx11-dev libxi-dev libxrandr-dev libxinerama-dev libxcursor-dev libxxf86vm-dev libxext-dev libpthread-stubs0-dev 11 | 12 | g++ -I ./eigen-3.4.0 -I ./libigl/include -I ./glad/include -I /usr/include \ 13 | -L /usr/lib \ 14 | -o generate_pc \ 15 | generate_pc.cpp ./glad/src/glad.c \ 16 | -ldl -lGL -lglfw -pthread 17 | 18 | bash generate_pc.sh 19 | ``` 20 | 21 | After this step, you will have `.csv` files containing the point clouds. 22 | 23 | If you would like a different number of points, change the `num_points` value (`int num_points = 1024;` in `main` of `generate_pc.cpp`) to your target number. 24 | 25 | ## Mesh Preprocessing 26 | 27 | We triangulate each mesh with Blender and uniformly scale it before generating the point cloud. This step takes `partA.obj` and `partB.obj` as input and outputs the normalized meshes `partA_new.obj` and `partB_new.obj`. 28 | 29 | ```shell 30 | bash preprocess_obj_blender.sh 31 | ``` 32 | 33 | ## URDF (Unified Robot Description Format) Generation 34 | 35 | If you would like to use our meshes in a simulator, you can run the following script. 36 | 37 | ```shell 38 | python generate_urdf.py 39 | ``` 40 | 41 | When creating the collision mesh (named `vhacd_*.obj`, although it is produced by quadric decimation rather than true VHACD), you might need to tune the `simplify_quadric_decimation(0.8)` parameter for a better collision result. 
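As a quick sanity check of a generated URDF, here is a minimal sketch, assuming `pybullet` is installed (it is not part of `environment.yml`) and that `generate_urdf.py` has already been run on `insert_flower/1`:

```python
import pybullet as p

client = p.connect(p.DIRECT)  # headless; use p.GUI to inspect the model visually
body_id = p.loadURDF("insert_flower/1/partA.urdf")  # path written by generate_urdf.py
print("body id:", body_id)
print("collision AABB:", p.getAABB(body_id))  # rough check that the collision mesh is sane
p.disconnect(client)
```

If the axis-aligned bounding box looks degenerate, revisit the decimation parameter mentioned above.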
-------------------------------------------------------------------------------- /data_util/blender_triangulate.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import sys 3 | 4 | def triangulate_object(obj): 5 | bpy.ops.object.select_all(action='DESELECT') 6 | obj.select_set(True) 7 | # Make the object active 8 | bpy.context.view_layer.objects.active = obj 9 | # Ensure we start from object mode 10 | bpy.ops.object.mode_set(mode='OBJECT') 11 | # Switch to edit mode 12 | bpy.ops.object.mode_set(mode='EDIT') 13 | # Select all faces 14 | bpy.ops.mesh.select_all(action='SELECT') 15 | # Triangulate the mesh 16 | bpy.ops.mesh.quads_convert_to_tris() 17 | # Switch back to object mode 18 | bpy.ops.object.mode_set(mode='OBJECT') 19 | 20 | def main(): 21 | input_path = sys.argv[-2] 22 | output_path = sys.argv[-1] 23 | 24 | bpy.context.preferences.system.audio_device = 'Null' 25 | 26 | # Clear the existing scene 27 | bpy.ops.wm.read_factory_settings(use_empty=True) 28 | 29 | # Import the OBJ file 30 | bpy.ops.import_scene.obj(filepath=input_path) 31 | 32 | # Triangulate all imported objects 33 | for obj in bpy.context.scene.objects: 34 | if obj.type == 'MESH': 35 | triangulate_object(obj) 36 | 37 | # Export the triangulated mesh to OBJ 38 | bpy.ops.export_scene.obj(filepath=output_path, use_selection=False) 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /data_util/clear_obj.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | input_folders=("processed_cups" "water_bottle_processed" "beer_bottle") 4 | 5 | # delete partA_new.obj and partB_new.obj 6 | 7 | for folder in "${input_folders[@]}"; do 8 | for subfolder in $(find "$folder" -type d); do 9 | partA_obj="$subfolder/partA_new.obj" 10 | partB_obj="$subfolder/partB_new.obj" 11 | # partA_ply="$subfolder/partA-pc.ply" 12 | # partB_ply="$subfolder/partB-pc.ply" 13 | 14 | if [[ -f "$partA_obj" && -f "$partB_obj" ]]; then 15 | echo "$partA_obj and $partB_obj found, deleting" 16 | rm "$partA_obj" 17 | rm "$partB_obj" 18 | else 19 | echo "partA_new.obj and partB_new.obj are missing" 20 | fi 21 | done 22 | done -------------------------------------------------------------------------------- /data_util/csv_to_ply.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import argparse 3 | 4 | # Function to read the CSV file (duplicate removal is currently disabled) 5 | def read_csv_remove_duplicates(input_csv): 6 | df = pd.read_csv(input_csv, header=None, names=["x", "y", "z"]) 7 | # df = df.drop_duplicates() 8 | return df 9 | 10 | # Function to write the DataFrame to a PLY file 11 | def write_ply(df, output_ply): 12 | with open(output_ply, 'w+') as ply_file: 13 | ply_file.write('ply\n') 14 | ply_file.write('format ascii 1.0\n') 15 | ply_file.write(f'element vertex {len(df)}\n') 16 | ply_file.write('property float x\n') 17 | ply_file.write('property float y\n') 18 | ply_file.write('property float z\n') 19 | ply_file.write('property uchar red\n') 20 | ply_file.write('property uchar green\n') 21 | ply_file.write('property uchar blue\n') 22 | ply_file.write('end_header\n') 23 | for _, row in df.iterrows(): 24 | ply_file.write(f'{row["x"]} {row["y"]} {row["z"]} 255 0 0\n') 25 | 26 | # Main function to convert CSV to PLY 27 | def convert_csv_to_ply(input_csv, output_ply): 28 | df = read_csv_remove_duplicates(input_csv) 29 | write_ply(df, output_ply) 30 | 31 | 
# Command-line entry point 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser(description="Convert CSV to PLY") 34 | parser.add_argument('--input_csv', default='points.csv', type=str) 35 | parser.add_argument('--output_ply', default='points.ply', type=str) 36 | args = parser.parse_args() 37 | convert_csv_to_ply(args.input_csv, args.output_ply) 38 | -------------------------------------------------------------------------------- /data_util/csv_to_ply.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | input_folders=("processed_cups" "water_bottle_processed" "beer_bottle") 4 | 5 | for folder in "${input_folders[@]}"; do 6 | for subfolder in $(find "$folder" -type d); do 7 | partA_csv="$subfolder/partA-pc.csv" 8 | partB_csv="$subfolder/partB-pc.csv" 9 | partA_ply="$subfolder/partA-pc.ply" 10 | partB_ply="$subfolder/partB-pc.ply" 11 | 12 | if [[ -f "$partA_csv" && -f "$partB_csv" ]]; then 13 | echo "Processing $subfolder" 14 | python csv_to_ply.py --input_csv "$partA_csv" --output_ply "$partA_ply" 15 | python csv_to_ply.py --input_csv "$partB_csv" --output_ply "$partB_ply" 16 | else 17 | echo "Skipping $subfolder: partA-pc.csv or partB-pc.csv not found" 18 | fi 19 | done 20 | done -------------------------------------------------------------------------------- /data_util/generate_pc.cpp: -------------------------------------------------------------------------------- 1 | #include <igl/read_triangle_mesh.h> // include list reconstructed from the calls used below 2 | #include <igl/doublearea.h> 3 | #include <igl/blue_noise.h> 4 | #include <igl/per_face_normals.h> 5 | #include <igl/PI.h> 6 | #include <Eigen/Core> 7 | #include <Eigen/Dense> 8 | #include <iostream> 9 | #include <fstream> 10 | #include <string> 11 | #include <vector> 12 | #include <numeric> 13 | #include <algorithm> 14 | #include <random> 15 | #include <cmath> 16 | 17 | const static Eigen::IOFormat CSVFormat(Eigen::StreamPrecision, Eigen::DontAlignCols, ",", "\n"); 18 | 19 | void sample_pc_with_blue_noise( 20 | const int & num_points, 21 | const Eigen::MatrixXd & mesh_v, 22 | const Eigen::MatrixXi & mesh_f, 23 | Eigen::MatrixXd & pc, 24 | Eigen::MatrixXd & normals) 25 | { 26 | if(mesh_v.size() == 0 || mesh_f.size() == 0) { 27 | std::cerr << "Error: Input mesh is empty.\n"; 28 | return; 29 | } 30 | 31 | Eigen::VectorXd A; 32 | igl::doublearea(mesh_v, mesh_f, A); 33 | double radius = sqrt(((A.sum()*0.5/(num_points*0.6162910373))/igl::PI)); 34 | std::cout << "Initial Blue noise radius: " << radius << "\n"; 35 | 36 | Eigen::MatrixXd B; 37 | Eigen::VectorXi I; 38 | int max_attempts = 1; 39 | for(int attempt = 0; attempt < max_attempts; ++attempt) 40 | { 41 | igl::blue_noise(mesh_v, mesh_f, radius, B, I, pc); 42 | std::cout << "Successfully generated a set of blue-noise samples!" << std::endl; 43 | if(pc.rows() >= num_points * 0.9 && pc.rows() <= num_points * 1.1) 44 | { 45 | break; 46 | } 47 | if(pc.rows() == 0) { 48 | std::cerr << "Error: Blue noise sampling generated an empty point cloud. 
Attempt: " << attempt + 1 << "\n"; 49 | break; 50 | } 51 | radius *= sqrt(num_points * 1.0 / pc.rows()); 52 | //std::cout << "Adjusted Blue noise radius: " << radius << " (Attempt " << attempt + 1 << ")\n"; 53 | } 54 | 55 | if(pc.rows() == 0) { 56 | std::cerr << "Error: Blue noise sampling failed to generate points.\n"; 57 | return; 58 | } 59 | 60 | if (pc.rows() > num_points) 61 | { 62 | std::cout << "Trimming point cloud from " << pc.rows() << " to " << num_points << " points\n"; 63 | std::vector indices(pc.rows()); 64 | std::iota(indices.begin(), indices.end(), 0); 65 | std::shuffle(indices.begin(), indices.end(), std::default_random_engine{}); 66 | indices.resize(num_points); 67 | Eigen::MatrixXd trimmed_pc(num_points, pc.cols()); 68 | //Eigen::MatrixXd trimmed_normals(num_points, normals.cols()); 69 | std::cout << "pc.rows(): " << pc.rows() << ", pc.cols(): " << pc.cols() << std::endl; 70 | //std::cout << "normals.rows(): " << normals.rows() << ", normals.cols(): " << normals.cols() << std::endl; 71 | for (int i = 0; i < num_points; ++i) 72 | { 73 | //std::cout <<"for trimming" << " i is" << i << std::endl; 74 | trimmed_pc.row(i) = pc.row(indices[i]); 75 | //std::cout<<"1"< \n"; 127 | return 1; 128 | } 129 | 130 | std::string mesh_file_left = argv[1]; 131 | std::string mesh_file_right = argv[2]; 132 | std::string save_root_directory = argv[3]; 133 | Eigen::MatrixXd L_mesh_v, R_mesh_v; 134 | Eigen::MatrixXi L_mesh_f, R_mesh_f; 135 | 136 | if (!igl::read_triangle_mesh(mesh_file_left, L_mesh_v, L_mesh_f)) 137 | { 138 | std::cerr << "Failed to load mesh from " << mesh_file_left << std::endl; 139 | return 1; 140 | } 141 | 142 | if (!igl::read_triangle_mesh(mesh_file_right, R_mesh_v, R_mesh_f)) 143 | { 144 | std::cerr << "Failed to load mesh from " << mesh_file_right << std::endl; 145 | return 1; 146 | } 147 | 148 | std::cout << "Sampling point clouds...\n"; 149 | int num_points = 1024; 150 | Eigen::MatrixXd L_pc, R_pc; 151 | Eigen::MatrixXd L_normal, R_normal; 152 | sample_pc_with_blue_noise(num_points, L_mesh_v, L_mesh_f, L_pc, L_normal); 153 | sample_pc_with_blue_noise(num_points, R_mesh_v, R_mesh_f, R_pc, R_normal); 154 | 155 | std::cout << "The target root file is " << save_root_directory << std::endl; 156 | 157 | std::cout << "Saving point clouds...\n"; 158 | std::string L_pc_file = save_root_directory + '/' + "partA-pc.csv"; 159 | std::string R_pc_file = save_root_directory + '/' + "partB-pc.csv"; 160 | write_matrix_to_csv(L_pc_file, L_pc); 161 | write_matrix_to_csv(R_pc_file, R_pc); 162 | 163 | std::cout << "Saving normals...\n"; 164 | std::string L_normal_file = save_root_directory + '/' + "partA-normal.csv"; 165 | std::string R_normal_file = save_root_directory + '/' + "partB-normal.csv"; 166 | write_matrix_to_csv(L_normal_file, L_normal); 167 | write_matrix_to_csv(R_normal_file, R_normal); 168 | 169 | std::cout << "Saved point clouds and normals to " << save_root_directory << std::endl; 170 | 171 | return 0; 172 | } -------------------------------------------------------------------------------- /data_util/generate_pc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | input_folders=("processed_cups" "water_bottle_processed" "beer_bottle") 4 | 5 | for folder in "${input_folders[@]}"; do 6 | for subfolder in $(find "$folder" -type d); do 7 | partA_file="$subfolder/partA_new.obj" 8 | partB_file="$subfolder/partB_new.obj" 9 | 10 | if [[ -f "$partA_file" && -f "$partB_file" ]]; then 11 | echo "Processing $subfolder" 12 | 
./generate_pc "$partA_file" "$partB_file" "$subfolder" 13 | else 14 | echo "Skipping $subfolder: partA_new.obj or partB_new.obj not found" 15 | fi 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /data_util/generate_urdf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import trimesh 3 | 4 | input_obj = "insert_flower/1/partA_new.obj" 5 | vhacd_obj = "insert_flower/1/vhacd_partA_new.obj" 6 | urdf_file = "insert_flower/1/partA.urdf" 7 | 8 | # Step 1: Run VHACD 9 | mesh = trimesh.load(input_obj) 10 | print(f"loaded mesh with {len(mesh.faces)} faces and {len(mesh.vertices)} vertices") 11 | 12 | # simple = mesh.simplify_quadric_decimation(500) 13 | 14 | simple = mesh.simplify_quadric_decimation(0.8) 15 | 16 | simple.export(vhacd_obj) 17 | 18 | print(f"VHACD complete. Output saved to: {vhacd_obj}") 19 | 20 | # Step 2: Write URDF 21 | print(f"Generating URDF: {urdf_file}") 22 | 23 | urdf_content = f""" 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | """ 43 | 44 | with open(urdf_file, "w") as f: 45 | f.write(urdf_content) 46 | 47 | print(f"URDF file written: {urdf_file}") -------------------------------------------------------------------------------- /data_util/preprocess_obj_blender.py: -------------------------------------------------------------------------------- 1 | # This script takes a raw .glb file, adds vertex colors or throw away if without color 2 | # It then translates by its center of mass and normalize the bounding box 3 | # Finally, it saves the mesh or point cloud to a specified folder 4 | import open3d as o3d 5 | import numpy as np 6 | import os 7 | import time 8 | import matplotlib.pyplot as plt 9 | import json 10 | import trimesh 11 | import sys 12 | # import bpy 13 | import subprocess 14 | from tqdm import tqdm 15 | 16 | global_scale = 3 17 | 18 | current_dir = os.path.dirname(os.path.abspath(__file__)) 19 | sys.path.append(current_dir) 20 | 21 | """ 22 | partA.obj -> partA_tri.obj (triangulate 23 | partA_tri.obj -> partA_new.obj (recenter and rescale) 24 | normal_vector.json 25 | """ 26 | 27 | # scale_factor = 2 28 | def file_exists(filepath): 29 | return os.path.exists(filepath) 30 | 31 | # activate when need to triangulate the meshes 32 | 33 | def triangulate_mesh_with_pyblender(instance_dir): 34 | blender_script_path = os.path.join(current_dir, 'blender_triangulate.py') 35 | subprocess.run(['blender', '--factory-startup', '--background', '--python', blender_script_path, '--', os.path.join(current_dir, instance_dir, 'partA.obj'), os.path.join(current_dir, instance_dir, 'partA_tri.obj')]) 36 | subprocess.run(['blender','--factory-startup', '--background', '--python', blender_script_path, '--', os.path.join(current_dir, instance_dir, 'partB.obj'), os.path.join(current_dir, instance_dir, 'partB_tri.obj')]) 37 | 38 | 39 | 40 | # use pyblender change the non-triangle mesh to triangle mesh 41 | def combine_meshes(instance_dir): 42 | triangulate_mesh_with_pyblender(instance_dir) 43 | 44 | # Load the two meshes 45 | mesh1_path = os.path.join(current_dir, instance_dir, "partA_tri.obj") 46 | mesh2_path = os.path.join(current_dir, instance_dir, "partB_tri.obj") 47 | combined_mesh_path = os.path.join(current_dir, instance_dir, "combined_mesh.obj") 48 | 49 | if not file_exists(mesh1_path) or not file_exists(mesh2_path): 50 | sys.exit(1) 51 | 52 | 53 | mesh1 = o3d.io.read_triangle_mesh(mesh1_path) 54 | # mesh1.triangulate() 55 | 
mesh2 = o3d.io.read_triangle_mesh(mesh2_path) 56 | # mesh2.triangulate() 57 | 58 | if not mesh1.has_vertices(): 59 | raise ValueError("mesh1 does not contain any vertices.", mesh1_path) 60 | if not mesh2.has_vertices(): 61 | raise ValueError("mesh2 does not contain any vertices.", mesh2_path) 62 | 63 | # Combine vertices and triangles 64 | combined_vertices = np.vstack((np.asarray(mesh1.vertices), np.asarray(mesh2.vertices))) 65 | combined_triangles = np.vstack((np.asarray(mesh1.triangles), np.asarray(mesh2.triangles) + len(mesh1.vertices))) 66 | 67 | # Create a new mesh with the combined vertices and triangles 68 | combined_mesh = o3d.geometry.TriangleMesh() 69 | combined_mesh.vertices = o3d.utility.Vector3dVector(combined_vertices) 70 | combined_mesh.triangles = o3d.utility.Vector3iVector(combined_triangles) 71 | 72 | # Optionally combine vertex colors if they exist 73 | if mesh1.has_vertex_colors() and mesh2.has_vertex_colors(): 74 | combined_colors = np.vstack((np.asarray(mesh1.vertex_colors), np.asarray(mesh2.vertex_colors))) 75 | combined_mesh.vertex_colors = o3d.utility.Vector3dVector(combined_colors) 76 | 77 | # print("Combined mesh saved to {}".format(combined_mesh_path)) 78 | return (mesh1, mesh2, combined_mesh) 79 | 80 | def calculate_center_and_scale_two_separate_part(instance_dir): 81 | 82 | # compute the center of the combined mesh and translate both parts so it sits at the origin 83 | (mesh1, mesh2, combined_mesh) = combine_meshes(instance_dir) 84 | global_center = combined_mesh.get_center() 85 | 86 | mesh1.translate(-global_center) 87 | mesh2.translate(-global_center) 88 | combined_mesh.translate(-global_center) 89 | 90 | bounding_box = combined_mesh.get_axis_aligned_bounding_box() 91 | max_extent = np.max(bounding_box.get_extent()) 92 | max_extent_index = np.argmax(bounding_box.get_extent()) 93 | normal_vector = np.zeros(3) 94 | normal_vector[max_extent_index] = 1 95 | 96 | # Write stats to JSON file 97 | stats = { 98 | "Normal Vector":{ 99 | "x": normal_vector[0], 100 | "y": normal_vector[1], 101 | "z": normal_vector[2] 102 | } 103 | } 104 | 105 | with open(os.path.join(current_dir, instance_dir, "normal_vector.json"), "w+") as f: 106 | json.dump(stats, f, indent=4) 107 | 108 | 109 | scale_factor = global_scale / max_extent 110 | 111 | mesh1.scale(scale_factor, (0,0,0)) 112 | mesh2.scale(scale_factor, (0,0,0)) 113 | combined_mesh.scale(scale_factor, (0,0,0)) 114 | 115 | """ 116 | save the final files after re-centering and re-scaling (normal_vector.json was written above) 117 | """ 118 | 119 | o3d.io.write_triangle_mesh(os.path.join(current_dir, instance_dir, "partA_new.obj"), mesh1) 120 | o3d.io.write_triangle_mesh(os.path.join(current_dir, instance_dir, "partB_new.obj"), mesh2) 121 | o3d.io.write_triangle_mesh(os.path.join(current_dir, instance_dir, "combined_mesh.obj"), combined_mesh) 122 | 123 | # Example usage 124 | # calculate_center_and_scale_two_separate_part("<instance_dir>") 125 | 126 | 127 | if __name__ == '__main__': 128 | 129 | instance_dir = sys.argv[1] 130 | print("Processing {}".format(instance_dir)) 131 | calculate_center_and_scale_two_separate_part(instance_dir) 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /data_util/preprocess_obj_blender.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | input_folders=("processed_cups" "water_bottle_processed" "beer_bottle") 4 | 5 | for folder in "${input_folders[@]}"; 
do 6 | for subfolder in $(find "$folder" -type d); do 7 | partA_obj="$subfolder/partA.obj" 8 | partB_obj="$subfolder/partB.obj" 9 | # partA_ply="$subfolder/partA-pc.ply" 10 | # partB_ply="$subfolder/partB-pc.ply" 11 | 12 | if [[ -f "$partA_obj" && -f "$partB_obj" ]]; then 13 | echo "Processing $subfolder" 14 | python preprocess_obj_blender.py "$subfolder" 15 | # python csv_to_ply.py --input_csv "$partB_csv" --output_ply "$partB_ply" 16 | else 17 | echo "Skipping $subfolder: partA.obj or partB.obj not found" 18 | fi 19 | done 20 | done -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: twobytwo 2 | 3 | channels: 4 | 5 | - conda-forge 6 | 7 | - defaults 8 | 9 | dependencies: 10 | 11 | - _libgcc_mutex=0.1=conda_forge 12 | 13 | - _openmp_mutex=4.5=2_gnu 14 | 15 | - ca-certificates=2024.7.4=hbcca054_0 16 | 17 | - ld_impl_linux-64=2.40=hf3520f5_7 18 | 19 | - libffi=3.2.1=he1b5a44_1007 20 | 21 | - libgcc-ng=14.1.0=h77fa898_0 22 | 23 | - libgomp=14.1.0=h77fa898_0 24 | 25 | - libsqlite=3.46.0=hde9e2c9_0 26 | 27 | - libstdcxx-ng=14.1.0=hc0a3c3a_0 28 | 29 | - libzlib=1.2.13=h4ab18f5_6 30 | 31 | - ncurses=6.5=h59595ed_0 32 | 33 | - openssl=1.1.1w=hd590300_0 34 | 35 | - pip=24.0=pyhd8ed1ab_0 36 | 37 | - python=3.8.0=h357f687_5 38 | 39 | - readline=8.2=h8228510_1 40 | 41 | - setuptools=71.0.4=pyhd8ed1ab_0 42 | 43 | - sqlite=3.46.0=h6d4b2fc_0 44 | 45 | - tk=8.6.13=noxft_h4845f30_101 46 | 47 | - wheel=0.43.0=pyhd8ed1ab_1 48 | 49 | - xz=5.2.6=h166bdaf_0 50 | 51 | - zlib=1.2.13=h4ab18f5_6 52 | 53 | - pip: 54 | 55 | - addict==2.4.0 56 | 57 | - aiohttp==3.9.5 58 | 59 | - aiosignal==1.3.1 60 | 61 | - asttokens==2.4.1 62 | 63 | - async-timeout==4.0.3 64 | 65 | - attrs==23.2.0 66 | 67 | - backcall==0.2.0 68 | 69 | - blinker==1.8.2 70 | 71 | - certifi==2024.7.4 72 | 73 | - charset-normalizer==3.3.2 74 | 75 | - click==8.1.7 76 | 77 | - comm==0.2.2 78 | 79 | - configargparse==1.7 80 | 81 | - contourpy==1.1.1 82 | 83 | - cycler==0.12.1 84 | 85 | - dash==2.17.1 86 | 87 | - dash-core-components==2.0.0 88 | 89 | - dash-html-components==2.0.0 90 | 91 | - dash-table==5.0.0 92 | 93 | - decorator==5.1.1 94 | 95 | - docker-pycreds==0.4.0 96 | 97 | - einops==0.8.0 98 | 99 | - executing==2.0.1 100 | 101 | - fastjsonschema==2.20.0 102 | 103 | - filelock==3.15.4 104 | 105 | - flask==3.0.3 106 | 107 | - fonttools==4.53.1 108 | 109 | - freetype-py==2.4.0 110 | 111 | - frozenlist==1.4.1 112 | 113 | - fsspec==2024.6.1 114 | 115 | - fvcore==0.1.5.post20221221 116 | 117 | - gitdb==4.0.11 118 | 119 | - gitpython==3.1.43 120 | 121 | - h5py==3.11.0 122 | 123 | - idna==3.7 124 | 125 | - imageio==2.34.2 126 | 127 | - importlib-metadata==8.2.0 128 | 129 | - importlib-resources==6.4.0 130 | 131 | - iopath==0.1.10 132 | 133 | - ipdb==0.13.13 134 | 135 | - ipython==8.12.3 136 | 137 | - ipywidgets==8.1.3 138 | 139 | - itsdangerous==2.2.0 140 | 141 | - jedi==0.19.1 142 | 143 | - jinja2==3.1.4 144 | 145 | - joblib==1.4.2 146 | 147 | - jsonschema==4.23.0 148 | 149 | - jsonschema-specifications==2023.12.1 150 | 151 | - jupyter-core==5.7.2 152 | 153 | - jupyterlab-widgets==3.0.11 154 | 155 | - kiwisolver==1.4.5 156 | 157 | - lazy-loader==0.4 158 | 159 | - lightning-utilities==0.11.6 160 | 161 | - markupsafe==2.1.5 162 | 163 | - matplotlib==3.7.5 164 | 165 | - matplotlib-inline==0.1.7 166 | 167 | - mesh-to-sdf==0.0.15 168 | 169 | - mpmath==1.3.0 170 | 171 | - multidict==6.0.5 172 | 173 | - nbformat==5.10.4 174 | 
175 | - nest-asyncio==1.6.0 176 | 177 | - networkx==3.1 178 | 179 | - ninja==1.11.1.1 180 | 181 | - numpy==1.24.4 182 | 183 | - nvidia-cublas-cu12==12.1.3.1 184 | 185 | - nvidia-cuda-cupti-cu12==12.1.105 186 | 187 | - nvidia-cuda-nvrtc-cu12==12.1.105 188 | 189 | - nvidia-cuda-runtime-cu12==12.1.105 190 | 191 | - nvidia-cudnn-cu12==9.1.0.70 192 | 193 | - nvidia-cufft-cu12==11.0.2.54 194 | 195 | - nvidia-curand-cu12==10.3.2.106 196 | 197 | - nvidia-cusolver-cu12==11.4.5.107 198 | 199 | - nvidia-cusparse-cu12==12.1.0.106 200 | 201 | - nvidia-nccl-cu12==2.20.5 202 | 203 | - nvidia-nvjitlink-cu12==12.5.82 204 | 205 | - nvidia-nvtx-cu12==12.1.105 206 | 207 | - open3d==0.18.0 208 | 209 | - packaging==24.1 210 | 211 | - pandas==2.0.3 212 | 213 | - parso==0.8.4 214 | 215 | - pexpect==4.9.0 216 | 217 | - pickleshare==0.7.5 218 | 219 | - pillow==10.4.0 220 | 221 | - pkgutil-resolve-name==1.3.10 222 | 223 | - platformdirs==4.2.2 224 | 225 | - plotly==5.23.0 226 | 227 | - portalocker==2.10.1 228 | 229 | - progressbar==2.5 230 | 231 | - prompt-toolkit==3.0.47 232 | 233 | - protobuf==5.27.2 234 | 235 | - psutil==6.0.0 236 | 237 | - ptyprocess==0.7.0 238 | 239 | - pure-eval==0.2.3 240 | 241 | - pyglet==2.0.16 242 | 243 | - pygments==2.18.0 244 | 245 | - pyopengl==3.1.0 246 | 247 | - pyparsing==3.1.2 248 | 249 | - pyquaternion==0.9.9 250 | 251 | - pyrender==0.1.45 252 | 253 | - python-dateutil==2.9.0.post0 254 | 255 | - pytorch-lightning==2.3.3 256 | 257 | - pytorch3d==0.3.0 258 | 259 | - pytz==2024.1 260 | 261 | - pywavelets==1.4.1 262 | 263 | - pyyaml==6.0.1 264 | 265 | - referencing==0.35.1 266 | 267 | - requests==2.32.3 268 | 269 | - retrying==1.3.4 270 | 271 | - rpds-py==0.19.1 272 | 273 | - scikit-image==0.21.0 274 | 275 | - scikit-learn==1.3.2 276 | 277 | - scipy==1.10.1 278 | 279 | - sentry-sdk==2.11.0 280 | 281 | - setproctitle==1.3.3 282 | 283 | - six==1.16.0 284 | 285 | - smmap==5.0.1 286 | 287 | - stack-data==0.6.3 288 | 289 | - sympy==1.13.1 290 | 291 | - tabulate==0.9.0 292 | 293 | - tenacity==8.5.0 294 | 295 | - termcolor==2.4.0 296 | 297 | - threadpoolctl==3.5.0 298 | 299 | - tifffile==2023.7.10 300 | 301 | - tomli==2.0.1 302 | 303 | - torch==2.4.0 304 | 305 | - torchmetrics==1.4.0.post0 306 | 307 | - torchvision==0.19.0 308 | 309 | - tqdm==4.66.4 310 | 311 | - traitlets==5.14.3 312 | 313 | - transforms3d==0.4.2 314 | 315 | - trimesh==4.4.3 316 | 317 | - triton==3.0.0 318 | 319 | - typing-extensions==4.12.2 320 | 321 | - tzdata==2024.1 322 | 323 | - urllib3==2.2.2 324 | 325 | - wandb==0.17.5 326 | 327 | - wcwidth==0.2.13 328 | 329 | - werkzeug==3.0.3 330 | 331 | - widgetsnbextension==4.0.11 332 | 333 | - yacs==0.1.8 334 | 335 | - yarl==1.9.4 336 | 337 | - zipp==3.19.2 338 | 339 | prefix: /localdata/miniconda3/envs/se3 340 | -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/.DS_Store -------------------------------------------------------------------------------- /src/config/eval.yml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: 'eval' 3 | checkpoint_dir: 'checkpoints' 4 | num_workers: 4 5 | batch_size: 4 6 | num_epochs: 1000 7 | gpus: [0] 8 | log_dir: '/home/two-by-two/logs/eval' # checkpoints will be read from this dir 9 | vis_dir: '/home/two-by-two/logs/eval' 10 | 11 | model: 12 | encoderA: 'vn_dgcnn' 13 | 
encoderB: 'vn_dgcnn' 14 | pose_predictor_quat: 'original' 15 | pose_predictor_rot: 'original' 16 | pose_predictor_trans: 'original' 17 | point_loss: 'True' 18 | recon_loss: 'False' 19 | corr_module: 'yes' 20 | pc_feat_dim: 512 21 | num_heads: 4 22 | num_blocks: 1 23 | 24 | cls_model: 25 | encoder: 'vn_dgcnn' 26 | point_loss: 'False' 27 | recon_loss: 'False' 28 | corr_module: 'No' 29 | pc_feat_dim: 512 30 | num_heads: 4 31 | num_blocks: 1 32 | 33 | # gpus: [0,1,2,3,4,5,6,7] 34 | 35 | # please modify the data path to the path of the generated point cloud 36 | data: 37 | root_dir: '/home/Shape_Data_pc' 38 | train_csv_file: '/home/Shape_Data_pc/stats/Train_All_List.txt' 39 | val_csv_file: '/home/Shape_Data_pc/stats/Test_All_List.txt' 40 | num_pc_points: 1024 41 | translation: '' 42 | 43 | optimizer: 44 | lr: 1e-4 45 | lr_decay: 0.0 46 | weight_decay: 1e-6 47 | lr_clip: 1e-5 -------------------------------------------------------------------------------- /src/config/train_A.yml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: 'network_A' 3 | checkpoint_dir: 'checkpoints' 4 | num_workers: 4 5 | batch_size: 4 6 | num_epochs: 1000 7 | gpus: [0] 8 | log_dir: '/home/two-by-two/logs/network_A' #modify to the correct log path 9 | vis_dir: '/home/two-by-two/logs/network_A' #modify to the correct log path 10 | 11 | model: 12 | encoderA: 'vn_dgcnn' 13 | encoderB: 'vn_dgcnn' 14 | pose_predictor_quat: 'original' 15 | pose_predictor_rot: 'original' 16 | pose_predictor_trans: 'original' 17 | point_loss: 'True' 18 | recon_loss: 'False' 19 | corr_module: 'yes' 20 | pc_feat_dim: 512 21 | num_heads: 4 22 | num_blocks: 1 23 | 24 | cls_model: 25 | encoder: 'vn_dgcnn' 26 | point_loss: 'False' 27 | recon_loss: 'False' 28 | corr_module: 'No' 29 | pc_feat_dim: 512 30 | num_heads: 4 31 | num_blocks: 1 32 | 33 | # gpus: [0,1,2,3,4,5,6,7] 34 | 35 | # please modify the data path to the path of the generated point cloud 36 | data: 37 | root_dir: '/home/Shape_Data_pc' 38 | train_csv_file: '/home/Shape_Data_pc/stats/Train_All_List.txt' 39 | val_csv_file: '/home/Shape_Data_pc/stats/Test_All_List.txt' 40 | num_pc_points: 1024 41 | translation: '' 42 | 43 | optimizer: 44 | lr: 1e-4 45 | lr_decay: 0.0 46 | weight_decay: 1e-6 47 | lr_clip: 1e-5 -------------------------------------------------------------------------------- /src/config/train_B.yml: -------------------------------------------------------------------------------- 1 | exp: 2 | name: 'network_B' 3 | checkpoint_dir: 'checkpoints' 4 | num_workers: 4 5 | batch_size: 4 6 | num_epochs: 1000 7 | gpus: [0] 8 | log_dir: '/home/two-by-two/logs/network_B' #modify to the correct log path 9 | vis_dir: '/home/two-by-two/logs/network_B' #modify to the correct log path 10 | 11 | model: 12 | encoderA: 'vn_dgcnn' 13 | encoderB: 'vn_dgcnn' 14 | pose_predictor_quat: 'original' 15 | pose_predictor_rot: 'original' 16 | pose_predictor_trans: 'original' 17 | point_loss: 'True' 18 | recon_loss: 'False' 19 | corr_module: 'yes' 20 | pc_feat_dim: 512 21 | num_heads: 4 22 | num_blocks: 1 23 | 24 | cls_model: 25 | encoder: 'vn_dgcnn' 26 | point_loss: 'False' 27 | recon_loss: 'False' 28 | corr_module: 'No' 29 | pc_feat_dim: 512 30 | num_heads: 4 31 | num_blocks: 1 32 | 33 | # gpus: [0,1,2,3,4,5,6,7] 34 | 35 | # please modify the data path to the path of the generated point cloud 36 | data: 37 | root_dir: '/home/Shape_Data_pc' 38 | train_csv_file: '/home/Shape_Data_pc/stats/Train_All_List.txt' 39 | val_csv_file: 
'/home/Shape_Data_pc/stats/Test_All_List.txt' 40 | num_pc_points: 1024 41 | translation: '' 42 | 43 | optimizer: 44 | lr: 1e-4 45 | lr_decay: 0.0 46 | weight_decay: 1e-6 47 | lr_clip: 1e-5 -------------------------------------------------------------------------------- /src/script/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/script/.DS_Store -------------------------------------------------------------------------------- /src/script/our_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import argparse 4 | import wandb 5 | 6 | wandb.require("core") 7 | 8 | os.environ['TORCH_CUDA_ARCH_LIST'] = '8.6' 9 | 10 | # Suppress specific warnings 11 | warnings.filterwarnings('ignore', category=UserWarning, message='.*TORCH_CUDA_ARCH_LIST.*') 12 | 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from torch.nn import DataParallel 17 | from torch.utils.data import DataLoader 18 | from torch.utils.data.distributed import DistributedSampler 19 | import torch 20 | 21 | import sys 22 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 23 | sys.path.append(os.path.join(BASE_DIR, '../shape_assembly')) 24 | from config import get_cfg_defaults 25 | from datasets.dataloader.dataloader_B import OurDataset 26 | from models.train.network_vnn_B import ShapeAssemblyNet_B_vnn 27 | from models.train.network_vnn_A import ShapeAssemblyNet_A_vnn 28 | import utils 29 | from tqdm import tqdm 30 | 31 | orange = '\033[38;5;214m' 32 | reset = '\033[0m' 33 | 34 | def setup(rank, world_size, cfg): 35 | os.environ['MASTER_ADDR'] = 'localhost' 36 | os.environ['MASTER_PORT'] = '52997' 37 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 38 | torch.cuda.set_device(rank) 39 | 40 | def cleanup(): 41 | dist.destroy_process_group() 42 | 43 | 44 | def train(rank, world_size, conf): # despite its name, this function only runs evaluation 45 | # create the log directory 46 | os.makedirs(conf.exp.log_dir, exist_ok=True) 47 | 48 | # create the vis directory 49 | os.makedirs(conf.exp.vis_dir, exist_ok=True) 50 | 51 | setup(rank, world_size, conf) 52 | 53 | if dist.get_rank() == 0: 54 | wandb.init(project='shape-matching', notes='weak baseline', config=conf) 55 | wandb.define_metric("test/epoch/*", step_metric="test/epoch/epoch") 56 | # wandb.define_metric("val/step/*", step_metric="val/step/step") 57 | wandb.define_metric("train/network_B/*", step_metric="train/network_B/step") 58 | wandb.define_metric("train/network_A/*", step_metric="train/network_A/step") 59 | 60 | data_features = ['src_pc', 'src_rot', 'src_trans', 'tgt_pc', 'tgt_rot', 'tgt_trans', 'partA_symmetry_type', 'partB_symmetry_type', 'predicted_partB_rotation', 'predicted_partB_position', 'predicted_partA_rotation', 'predicted_partA_position'] 61 | 62 | network_B = ShapeAssemblyNet_B_vnn(cfg=conf, data_features=data_features) 63 | network_B.cuda(rank) 64 | 65 | # Load pretrained weights for network_B. Note: our_train_B.py saves per-epoch checkpoints named '<epoch>-network_B.pth', so rename or symlink your chosen checkpoint to 'network_B.pth'. 66 | network_B.load_state_dict(torch.load(os.path.join(conf.exp.log_dir, 'ckpts', 'network_B.pth'))) 67 | 68 | for param in network_B.parameters(): 69 | param.requires_grad = False 70 | 71 | 72 | network_A = ShapeAssemblyNet_A_vnn(cfg=conf, data_features=data_features) 73 | network_A.cuda(rank) 74 | 75 | # Load pretrained weights for network_A (same convention: our_train_A.py saves '<epoch>-network_A.pth') 76 | 
network_A.load_state_dict(torch.load(os.path.join(conf.exp.log_dir, 'ckpts', 'network_A.pth'))) 77 | network_A = DDP(network_A, device_ids=[rank]) 78 | 79 | for param in network_A.parameters(): 80 | param.requires_grad = False 81 | 82 | # Initialize train dataloader 83 | train_data = OurDataset( 84 | data_root_dir=conf.data.root_dir, 85 | data_csv_file=conf.data.train_csv_file, 86 | data_features=data_features, 87 | num_points=conf.data.num_pc_points 88 | ) 89 | train_data.load_data() 90 | 91 | print('Len of Train Data: ', len(train_data)) 92 | 93 | # Initialize val dataloader 94 | val_data = OurDataset( 95 | data_root_dir=conf.data.root_dir, 96 | data_csv_file=conf.data.val_csv_file, 97 | data_features=data_features, 98 | num_points=conf.data.num_pc_points 99 | ) 100 | val_data.load_data() 101 | 102 | print('Len of Val Data: ', len(val_data)) 103 | 104 | val_sampler = DistributedSampler(val_data, num_replicas=world_size, rank=rank) 105 | val_dataloader = DataLoader( 106 | dataset=val_data, 107 | batch_size=conf.exp.batch_size, 108 | num_workers=conf.exp.num_workers, 109 | pin_memory=True, 110 | # shuffle=True, 111 | shuffle=False, 112 | drop_last=False, 113 | sampler=val_sampler 114 | ) 115 | 116 | network_opt = torch.optim.Adam(list(network_A.parameters()), lr=conf.optimizer.lr, weight_decay=conf.optimizer.weight_decay) # note: the optimizer is created but unused during evaluation 117 | val_num_batch = len(val_dataloader) 118 | print(f"\033[33mNumber of batches in val_dataloader: {len(val_dataloader)}\033[0m") 119 | print(f"\033[33mNumber of samples in val_data: {len(val_data)}\033[0m") 120 | 121 | # === Evaluation === 122 | tot, tot_gd, tot_r, tot_t, tot_CD_A, tot_CD_B = 0, 0, 0, 0, 0, 0 123 | 124 | network_A.train() 125 | network_B.train() 126 | 127 | for val_batch_ind, val_batch in enumerate(val_dataloader): 128 | for key in val_batch.keys(): 129 | if key not in ['category_name', 'cut_name', 'shape_id', 'result_id']: 130 | val_batch[key] = val_batch[key].to(rank) 131 | with torch.no_grad(): 132 | partB_eval_metric, partB_pos, partB_rot, *_ = network_B.forward_pass( 133 | batch_data=val_batch, device=rank, mode='val', vis_idx=val_batch_ind) 134 | val_batch['predicted_partB_position'] = partB_pos 135 | val_batch['predicted_partB_rotation'] = partB_rot 136 | 137 | partA_eval_metric, *_ = network_A.module.forward_pass( 138 | batch_data=val_batch, device=rank, mode='val', vis_idx=val_batch_ind) 139 | 140 | GD_B, R_B, T_B, _, _, CD_B = partB_eval_metric 141 | GD_A, R_A, T_A, _, _, CD_A = partA_eval_metric 142 | 143 | tot += 1 144 | tot_gd += (GD_A.mean() + GD_B.mean()) / 2 145 | tot_r += (R_A.mean() + R_B.mean()) / 2 146 | tot_t += (T_A.mean() + T_B.mean()) / 2 147 | tot_CD_A += CD_A.mean() 148 | tot_CD_B += CD_B.mean() 149 | 150 | if dist.get_rank() == 0: 151 | print(f"\033[1;32m[Eval Results]\033[0m") 152 | print(f"avg_gd: {tot_gd / tot}") 153 | print(f"avg_r: {tot_r / tot}") 154 | print(f"avg_t: {tot_t / tot}") 155 | print(f"avg_CD_A: {tot_CD_A / tot}") 156 | print(f"avg_CD_B: {tot_CD_B / tot}") 157 | 158 | wandb.log({ 159 | 'test/epoch/avg_gd': (tot_gd / tot).item(), 160 | 'test/epoch/avg_r': (tot_r / tot).item(), 161 | 'test/epoch/avg_t': (tot_t / tot).item(), 162 | 'test/epoch/avg_CD_A': (tot_CD_A / tot).item(), 163 | 'test/epoch/avg_CD_B': (tot_CD_B / tot).item(), 164 | 'test/epoch/epoch': 0 165 | }) 166 | 167 | dist.barrier() 168 | cleanup() 169 | 170 | 171 | def main(cfg): 172 | world_size = len(cfg.gpus) 173 | 174 | mp.spawn(train, args=(world_size, cfg), nprocs=world_size, join=True) 175 | 176 | if __name__ == '__main__': 177 | 178 | 
wandb.login() 179 | 180 | parser = argparse.ArgumentParser(description="Training script") 181 | parser.add_argument('--cfg_file', default='', type=str) 182 | parser.add_argument('--gpus', nargs='+', default=-1, type=int) 183 | 184 | args = parser.parse_args() 185 | args.cfg_file = os.path.join('./config', args.cfg_file) 186 | 187 | cfg = get_cfg_defaults() 188 | cfg.merge_from_file(args.cfg_file) 189 | 190 | cfg.gpus = cfg.exp.gpus 191 | 192 | cfg.freeze() 193 | print(cfg) 194 | 195 | main(cfg) -------------------------------------------------------------------------------- /src/script/our_train_A.py: -------------------------------------------------------------------------------- 1 | """ 2 | train network_A based on partB's ground truth 3 | """ 4 | 5 | 6 | import os 7 | import warnings 8 | import argparse 9 | import wandb 10 | 11 | wandb.require("core") 12 | 13 | os.environ['TORCH_CUDA_ARCH_LIST'] = '8.6' 14 | 15 | # Suppress specific warnings 16 | warnings.filterwarnings('ignore', category=UserWarning, message='.*TORCH_CUDA_ARCH_LIST.*') 17 | 18 | import torch.distributed as dist 19 | import torch.multiprocessing as mp 20 | from torch.nn.parallel import DistributedDataParallel as DDP 21 | from torch.nn import DataParallel 22 | from torch.utils.data import DataLoader 23 | from torch.utils.data.distributed import DistributedSampler 24 | import torch 25 | 26 | import sys 27 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | sys.path.append(os.path.join(BASE_DIR, '../shape_assembly')) 29 | from config import get_cfg_defaults 30 | from datasets.dataloader.dataloader_A import OurDataset 31 | from models.train.network_vnn_A_indi import ShapeAssemblyNet_A_vnn 32 | import utils 33 | from tqdm import tqdm 34 | 35 | orange = '\033[38;5;214m' 36 | reset = '\033[0m' 37 | 38 | def setup(rank, world_size, cfg): 39 | os.environ['MASTER_ADDR'] = 'localhost' 40 | os.environ['MASTER_PORT'] = '13013' 41 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 42 | torch.cuda.set_device(rank) 43 | 44 | def cleanup(): 45 | dist.destroy_process_group() 46 | 47 | def train(rank, world_size, conf): 48 | os.makedirs(conf.exp.log_dir, exist_ok=True) 49 | os.makedirs(conf.exp.vis_dir, exist_ok=True) 50 | setup(rank, world_size, conf) 51 | 52 | if dist.get_rank() == 0: 53 | wandb.init(project='shape-matching', notes='weak baseline', config=conf) 54 | wandb.define_metric("test/epoch/*", step_metric="test/epoch/epoch") 55 | wandb.define_metric("train/network_A/*", step_metric="train/network_A/step") 56 | 57 | data_features = ['src_pc', 'src_rot', 'src_trans', 'tgt_pc', 'tgt_rot', 'tgt_trans', 'partA_symmetry_type', 'partB_symmetry_type','predicted_partB_rotation', 'predicted_partB_position', 'predicted_partA_rotation', 'predicted_partA_position'] 58 | 59 | network_A = ShapeAssemblyNet_A_vnn(cfg=conf, data_features=data_features) 60 | network_A.cuda(rank) 61 | network_A = DDP(network_A, device_ids=[rank]) 62 | 63 | # Initialize train dataloader 64 | train_data = OurDataset( 65 | data_root_dir=conf.data.root_dir, 66 | data_csv_file=conf.data.train_csv_file, 67 | data_features=data_features, 68 | num_points=conf.data.num_pc_points 69 | ) 70 | train_data.load_data() 71 | 72 | print('Len of Train Data: ', len(train_data)) 73 | 74 | train_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank) 75 | train_dataloader = DataLoader( 76 | dataset=train_data, 77 | batch_size=conf.exp.batch_size, 78 | num_workers=conf.exp.num_workers, 79 | pin_memory=True, 80 | shuffle=False, 81 | 
drop_last=False, 82 | sampler=train_sampler 83 | ) 84 | 85 | # Initialize val dataloader 86 | val_data = OurDataset( 87 | data_root_dir=conf.data.root_dir, 88 | data_csv_file=conf.data.val_csv_file, 89 | data_features=data_features, 90 | num_points=conf.data.num_pc_points 91 | ) 92 | val_data.load_data() 93 | 94 | # Report the size of the validation data 95 | print('Len of Val Data: ', len(val_data)) 96 | 97 | val_sampler = DistributedSampler(val_data, num_replicas=world_size, rank=rank) 98 | val_dataloader = DataLoader( 99 | dataset=val_data, 100 | batch_size=conf.exp.batch_size, 101 | num_workers=conf.exp.num_workers, 102 | pin_memory=True, 103 | # shuffle=True, 104 | shuffle=False, 105 | drop_last=False, 106 | sampler=val_sampler 107 | ) 108 | 109 | network_opt = torch.optim.Adam(list(network_A.parameters()), lr=conf.optimizer.lr, weight_decay=conf.optimizer.weight_decay) 110 | val_num_batch = len(val_dataloader) 111 | print(f"\033[33mNumber of batches in val_dataloader: {len(val_dataloader)}\033[0m") 112 | print(f"\033[33mNumber of samples in val_data: {len(val_data)}\033[0m") 113 | train_num_batch = len(train_dataloader) 114 | 115 | val_step = 0 116 | for epoch in tqdm(range(1, conf.exp.num_epochs + 1)): 117 | print("\033[1;32m", "Epoch {} In training".format(epoch), "\033[0m") 118 | train_dataloader.sampler.set_epoch(epoch) 119 | val_dataloader.sampler.set_epoch(epoch) 120 | 121 | train_batches = enumerate(train_dataloader, 0) 122 | val_batches = enumerate(val_dataloader, 0) 123 | val_fraction_done = 0.0 124 | val_batch_ind = -1 125 | 126 | # train for every batch 127 | for train_batch_ind, batch in train_batches: 128 | if train_batch_ind % 50 == 0 and dist.get_rank() == 0: 129 | print("*" * 10) 130 | print(epoch, train_batch_ind) 131 | # print("*" * 10) 132 | train_fraction_done = (train_batch_ind + 1) / train_num_batch 133 | train_step = epoch * train_num_batch + train_batch_ind 134 | 135 | log_console = True 136 | 137 | # save a checkpoint of network_A (stored as '<epoch>-network_A.pth'; our_eval.py expects 'network_A.pth') 138 | if epoch % 10 == 0 and train_batch_ind == 0 and dist.get_rank() == 0: 139 | with torch.no_grad(): 140 | 141 | os.makedirs(os.path.join(conf.exp.log_dir, 'ckpts'), exist_ok=True) 142 | torch.save(network_A.module.state_dict(), os.path.join(conf.exp.log_dir, 'ckpts', '%d-network_A.pth' % epoch), _use_new_zipfile_serialization=False) 143 | 144 | network_A.train() 145 | for key in batch.keys(): 146 | if key not in ['category_name', 'cut_name', 'shape_id', 'result_id']: 147 | batch[key] = batch[key].cuda(non_blocking=True) 148 | 149 | partA_predictions = network_A.module.training_step(batch_data=batch, device=rank, batch_idx=train_batch_ind) 150 | total_loss_A = partA_predictions["total_loss"] 151 | rot_loss_A = partA_predictions["rot_loss"] 152 | trans_loss_A = partA_predictions["trans_loss"] 153 | 154 | if train_batch_ind % 50 == 0 and dist.get_rank() == 0: 155 | print(total_loss_A.detach().cpu().numpy()) 156 | 157 | step = train_step 158 | if dist.get_rank() == 0: 159 | wandb.log({'train/network_A/total_loss': total_loss_A.item(), 'train/network_A/rot_loss': rot_loss_A.item(), 'train/network_A/trans_loss': trans_loss_A.item(), 'train/network_A/step': step}) 160 | 161 | 162 | total_loss = total_loss_A 163 | 164 | # optimize one step 165 | network_opt.zero_grad() 166 | total_loss.backward() 167 | network_opt.step() 168 | 169 | dist.barrier() 170 | 171 | """ 172 | In evaluation mode: 173 | """ 174 | 175 | if epoch % 1 == 0: 176 | tot = 0 177 | tot_gd_A = 0 178 | # tot_gd_B = 0 179 | tot_r_A = 0 180 | # tot_r_B = 0
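        # gd: geodesic distance between predicted and ground-truth rotations (likely semantics), r: rotation error, t: translation RMSE, pa: part accuracy, CD: Chamfer distance; the *_B counters stay commented out because this script trains network_A only.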
181 | tot_t_A = 0 182 | # tot_t_B = 0 183 | tot_pa_A = 0 184 | # tot_pa_B = 0 185 | tot_pa = 0 186 | tot_t = 0 187 | tot_r = 0 188 | tot_gd = 0 189 | tot_pa_threshold = 0 190 | tot_CD_A = 0 191 | # tot_CD_B = 0 192 | tot_pa_threshold_A = 0 193 | # tot_pa_threshold_B = 0 194 | val_batches = enumerate(val_dataloader, 0) 195 | val_fraction_done = 0.0 196 | val_batch_ind = -1 197 | # device = torch.device('cuda:0') 198 | 199 | # train for every batch 200 | 201 | total_loss_epoch = 0 202 | rot_loss_epoch = 0 203 | trans_loss_epoch = 0 204 | 205 | for val_batch_ind, val_batch in val_batches: 206 | if val_batch_ind % 50 == 0: 207 | print("*" * 10) 208 | print(epoch, val_batch_ind) 209 | print("*" * 10) 210 | 211 | 212 | network_A.train() 213 | # network_B.train() 214 | 215 | for key in val_batch.keys(): 216 | if key not in ['category_name', 'cut_name', 'shape_id', 'result_id']: 217 | val_batch[key] = val_batch[key].to(rank) 218 | with torch.no_grad(): 219 | 220 | partA_eval_metric, partA_total_loss, partA_point_loss, partA_rot_loss, partA_trans_loss, partA_recon_loss = network_A.module.forward_pass(batch_data=val_batch, device=rank, mode='val', vis_idx=val_batch_ind) 221 | GD_A, R_error_A, RMSE_T_A, PA_threshold_A, PA_A, CD_A = partA_eval_metric 222 | 223 | total_loss_epoch += partA_total_loss.item() 224 | rot_loss_epoch += partA_rot_loss.item() 225 | trans_loss_epoch += partA_trans_loss.item() 226 | 227 | val_step += 1 228 | 229 | tot_gd_A += GD_A.mean() 230 | # tot_gd_B += GD_B.mean() 231 | tot_r_A += R_error_A.mean() # the rotation loss for each batch 232 | # tot_r_B += R_error_B.mean() # the rotation loss for each batch 233 | tot_t_A += RMSE_T_A.mean() 234 | # tot_t_B += RMSE_T_B.mean() 235 | tot_pa_threshold_A += PA_threshold_A.mean() 236 | # tot_pa_threshold_B += PA_threshold_B.mean() 237 | tot_pa_A += PA_A.mean() 238 | # tot_pa_B += PA_B.mean() 239 | tot_CD_A += CD_A.mean() 240 | # tot_CD_B += CD_B.mean() 241 | 242 | tot_r = tot_r_A 243 | # tot_r = tot_r_A 244 | tot_t = tot_t_A 245 | # tot_t = tot_t_A 246 | tot_gd = tot_gd_A 247 | tot += 1 248 | 249 | if dist.get_rank() == 0: 250 | print("\033[1;32m", "Epoch {} In validation".format(epoch), "\033[0m") 251 | 252 | print("avg_gd: ", tot_gd / tot) 253 | print("avg_r: ", tot_r / tot) 254 | print("avg_t: ", tot_t / tot) 255 | 256 | print("avg_CD_1: ", tot_CD_A / tot) 257 | 258 | log_data = { 259 | 'test/epoch/avg_gd':(tot_gd/tot).item(), 260 | 'test/epoch/avg_r': (tot_r/tot).item(), 261 | 'test/epoch/avg_t': (tot_t/tot).item(), 262 | 'test/epoch/avg_CD_A': (tot_CD_A/tot).item(), 263 | 'test/epoch/total_loss': total_loss_epoch, 264 | 'test/epoch/rot_loss': rot_loss_epoch, 265 | 'test/epoch/trans_loss': trans_loss_epoch, 266 | 'test/epoch/epoch': epoch 267 | } 268 | wandb.log( 269 | log_data 270 | ) 271 | 272 | dist.barrier() 273 | cleanup() 274 | 275 | 276 | def main(cfg): 277 | world_size = len(cfg.gpus) 278 | 279 | mp.spawn(train, args=(world_size, cfg), nprocs=world_size, join=True) 280 | 281 | 282 | if __name__ == '__main__': 283 | 284 | wandb.login() 285 | 286 | parser = argparse.ArgumentParser(description="Training script") 287 | parser.add_argument('--cfg_file', default='', type=str) 288 | parser.add_argument('--gpus', nargs='+', default=-1, type=int) 289 | 290 | args = parser.parse_args() 291 | args.cfg_file = os.path.join('./config', args.cfg_file) 292 | 293 | cfg = get_cfg_defaults() 294 | cfg.merge_from_file(args.cfg_file) 295 | 296 | cfg.gpus = cfg.exp.gpus 297 | 298 | cfg.freeze() 299 | print(cfg) 300 | 301 | main(cfg) 
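Both training scripts report `avg_gd`, `avg_r`, and `avg_t` during validation. As an illustration only (the repository's actual metric code lives in `src/shape_assembly` and may differ), a minimal sketch of rotation and translation errors with these assumed semantics:

```python
# Illustrative sketch of common pose-error metrics (assumed semantics,
# not the repository's exact implementation).
import torch

def geodesic_rotation_distance(R_pred: torch.Tensor, R_gt: torch.Tensor) -> torch.Tensor:
    """Angle in radians between batches of 3x3 rotation matrices."""
    R_rel = R_pred @ R_gt.transpose(-1, -2)           # relative rotation
    trace = R_rel.diagonal(dim1=-2, dim2=-1).sum(-1)  # tr(R_rel) per batch element
    cos_theta = ((trace - 1.0) / 2.0).clamp(-1.0, 1.0)
    return torch.acos(cos_theta)

def translation_rmse(t_pred: torch.Tensor, t_gt: torch.Tensor) -> torch.Tensor:
    """Root-mean-square error over the three translation components."""
    return ((t_pred - t_gt) ** 2).mean(dim=-1).sqrt()
```

Averaging these per batch and then over batches reproduces the `tot_* / tot` pattern used in the validation loops.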
-------------------------------------------------------------------------------- /src/script/our_train_B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import argparse 4 | import wandb 5 | 6 | wandb.require("core") 7 | 8 | os.environ['TORCH_CUDA_ARCH_LIST'] = '8.6' 9 | 10 | # Suppress specific warnings 11 | warnings.filterwarnings('ignore', category=UserWarning, message='.*TORCH_CUDA_ARCH_LIST.*') 12 | 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from torch.nn import DataParallel 17 | from torch.utils.data import DataLoader 18 | from torch.utils.data.distributed import DistributedSampler 19 | import torch 20 | 21 | import sys 22 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 23 | sys.path.append(os.path.join(BASE_DIR, '../shape_assembly')) 24 | from config import get_cfg_defaults 25 | from datasets.dataloader.dataloader_B import OurDataset 26 | from models.train.network_vnn_B import ShapeAssemblyNet_B_vnn 27 | import utils 28 | # from tensorboardX import SummaryWriter 29 | from tqdm import tqdm 30 | 31 | orange = '\033[38;5;214m' 32 | reset = '\033[0m' 33 | 34 | def setup(rank, world_size, cfg): 35 | os.environ['MASTER_ADDR'] = 'localhost' 36 | os.environ['MASTER_PORT'] = '12719' 37 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 38 | torch.cuda.set_device(rank) 39 | 40 | def cleanup(): 41 | dist.destroy_process_group() 42 | 43 | def train(rank, world_size, conf): 44 | os.makedirs(conf.exp.log_dir, exist_ok=True) 45 | os.makedirs(conf.exp.vis_dir, exist_ok=True) 46 | setup(rank, world_size, conf) 47 | 48 | if dist.get_rank() == 0: 49 | wandb.init(project='shape-matching', notes='weak baseline', config=conf) 50 | wandb.define_metric("test/epoch/*", step_metric="test/epoch/epoch") 51 | wandb.define_metric("train/network_B/*", step_metric="train/network_B/step") 52 | 53 | data_features = ['src_pc', 'src_rot', 'src_trans', 'tgt_pc', 'tgt_rot', 'tgt_trans', 'partA_symmetry_type', 'partB_symmetry_type','predicted_partB_rotation', 'predicted_partB_position', 'predicted_partA_rotation', 'predicted_partA_position'] 54 | 55 | network_B = ShapeAssemblyNet_B_vnn(cfg=conf, data_features=data_features) 56 | network_B.cuda(rank) 57 | network_B = DDP(network_B, device_ids=[rank]) 58 | 59 | # Initialize train dataloader 60 | train_data = OurDataset( 61 | data_root_dir=conf.data.root_dir, 62 | data_csv_file=conf.data.train_csv_file, 63 | data_features=data_features, 64 | num_points=conf.data.num_pc_points 65 | ) 66 | train_data.load_data() 67 | 68 | print('Len of Train Data: ', len(train_data)) 69 | 70 | train_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank) 71 | train_dataloader = DataLoader( 72 | dataset=train_data, 73 | batch_size=conf.exp.batch_size, 74 | num_workers=conf.exp.num_workers, 75 | pin_memory=True, 76 | shuffle=False, 77 | drop_last=False, 78 | sampler=train_sampler 79 | ) 80 | 81 | # Initialize val dataloader 82 | val_data = OurDataset( 83 | data_root_dir=conf.data.root_dir, 84 | data_csv_file=conf.data.val_csv_file, 85 | data_features=data_features, 86 | num_points=conf.data.num_pc_points 87 | ) 88 | val_data.load_data() 89 | 90 | print('Len of Val Data: ', len(val_data)) 91 | 92 | val_sampler = DistributedSampler(val_data, num_replicas=world_size, rank=rank) 93 | val_dataloader = DataLoader( 94 | dataset=val_data, 95 | batch_size=conf.exp.batch_size, 96 | 
num_workers=conf.exp.num_workers, 97 | pin_memory=True, 98 | shuffle=False, 99 | drop_last=False, 100 | sampler=val_sampler 101 | ) 102 | 103 | network_opt = torch.optim.Adam(list(network_B.parameters()), lr=conf.optimizer.lr, weight_decay=conf.optimizer.weight_decay) 104 | val_num_batch = len(val_dataloader) 105 | print(f"\033[33mNumber of batches in val_dataloader: {len(val_dataloader)}\033[0m") 106 | print(f"\033[33mNumber of samples in val_data: {len(val_data)}\033[0m") 107 | train_num_batch = len(train_dataloader) 108 | 109 | val_step = 0 110 | for epoch in tqdm(range(1, conf.exp.num_epochs + 1)): 111 | print("\033[1;32m", "Epoch {} In training".format(epoch), "\033[0m") 112 | train_dataloader.sampler.set_epoch(epoch) 113 | val_dataloader.sampler.set_epoch(epoch) 114 | 115 | train_batches = enumerate(train_dataloader, 0) 116 | val_batches = enumerate(val_dataloader, 0) 117 | val_fraction_done = 0.0 118 | val_batch_ind = -1 119 | 120 | # train for every batch 121 | for train_batch_ind, batch in train_batches: 122 | if train_batch_ind % 50 == 0 and dist.get_rank() == 0: 123 | print("*" * 10) 124 | print(epoch, train_batch_ind) 125 | # print("*" * 10) 126 | train_fraction_done = (train_batch_ind + 1) / train_num_batch 127 | train_step = epoch * train_num_batch + train_batch_ind 128 | 129 | log_console = True 130 | 131 | # save checkpoint of network_A and network_B 132 | if epoch % 10 == 0 and train_batch_ind == 0 and dist.get_rank() == 0: 133 | with torch.no_grad(): 134 | 135 | os.makedirs(os.path.join(conf.exp.log_dir, 'ckpts'), exist_ok=True) 136 | torch.save(network_B.module.state_dict(), os.path.join(conf.exp.log_dir, 'ckpts', '%d-network_B.pth' % epoch), _use_new_zipfile_serialization=False) 137 | 138 | network_B.train() 139 | for key in batch.keys(): 140 | if key not in ['category_name', 'cut_name', 'shape_id', 'result_id']: 141 | batch[key] = batch[key].cuda(non_blocking=True) 142 | 143 | partB_predictions = network_B.module.training_step(batch_data=batch, device=rank, batch_idx=train_batch_ind) 144 | total_loss_B = partB_predictions["total_loss"] 145 | rot_loss_B = partB_predictions["rot_loss"] 146 | trans_loss_B = partB_predictions["trans_loss"] 147 | 148 | batch['predicted_partB_position'] = partB_predictions['predicted_partB_position'] 149 | batch['predicted_partB_rotation'] = partB_predictions['predicted_partB_rotation'] 150 | 151 | if train_batch_ind % 50 == 0 and dist.get_rank() == 0: 152 | print(total_loss_B.detach().cpu().numpy()) 153 | 154 | step = train_step 155 | if dist.get_rank() == 0: 156 | wandb.log({'train/network_B/total_loss': total_loss_B.item(), 'train/network_B/rot_loss': rot_loss_B.item(), 'train/network_B/trans_loss': trans_loss_B.item(), 'train/network_B/step':step}) 157 | 158 | step = train_step 159 | 160 | total_loss = total_loss_B 161 | 162 | network_opt.zero_grad() 163 | total_loss.backward() 164 | network_opt.step() 165 | 166 | dist.barrier() 167 | 168 | if epoch % 1 == 0: 169 | tot = 0 170 | tot_gd_A = 0 171 | tot_gd_B = 0 172 | tot_r_A = 0 173 | tot_r_B = 0 174 | tot_t_A = 0 175 | tot_t_B = 0 176 | tot_pa_A = 0 177 | tot_pa_B = 0 178 | tot_pa = 0 179 | tot_t = 0 180 | tot_r = 0 181 | tot_gd = 0 182 | tot_pa_threshold = 0 183 | tot_CD_A = 0 184 | tot_CD_B = 0 185 | tot_pa_threshold_A = 0 186 | tot_pa_threshold_B = 0 187 | val_batches = enumerate(val_dataloader, 0) 188 | val_fraction_done = 0.0 189 | val_batch_ind = -1 190 | 191 | total_loss_epoch = 0 192 | rot_loss_epoch = 0 193 | trans_loss_epoch = 0 194 | 195 | for val_batch_ind, val_batch in 
val_batches: 196 | if val_batch_ind % 50 == 0: 197 | print("*" * 10) 198 | print(epoch, val_batch_ind) 199 | print("*" * 10) 200 | 201 | # network_A.train() 202 | network_B.eval() # evaluation mode for validation: freeze BatchNorm/Dropout statistics 203 | 204 | for key in val_batch.keys(): 205 | if key not in ['category_name', 'cut_name', 'shape_id', 'result_id']: 206 | val_batch[key] = val_batch[key].to(rank) 207 | with torch.no_grad(): 208 | 209 | partB_eval_metric, partB_position, partB_rotation, partB_total_loss, partB_point_loss, partB_rot_loss, partB_trans_loss, partB_recon_loss = network_B.module.forward_pass(batch_data=val_batch, device=rank, mode='val', vis_idx=val_batch_ind) 210 | GD_B, R_error_B, RMSE_T_B, PA_threshold_B, PA_B, CD_B = partB_eval_metric 211 | 212 | total_loss_epoch += partB_total_loss.item() 213 | rot_loss_epoch += partB_rot_loss.item() 214 | trans_loss_epoch += partB_trans_loss.item() 215 | 216 | val_batch['predicted_partB_position'] = partB_position 217 | val_batch['predicted_partB_rotation'] = partB_rotation 218 | 219 | val_step += 1 220 | 221 | # [todo] 222 | 223 | # tot_gd_A += GD_A.mean() 224 | tot_gd_B += GD_B.mean() 225 | # tot_r_A += R_error_A.mean() # the rotation error for each batch 226 | tot_r_B += R_error_B.mean() # the rotation error for each batch 227 | # tot_t_A += RMSE_T_A.mean() 228 | tot_t_B += RMSE_T_B.mean() 229 | # tot_pa_threshold_A += PA_threshold_A.mean() 230 | tot_pa_threshold_B += PA_threshold_B.mean() 231 | # tot_pa_A += PA_A.mean() 232 | tot_pa_B += PA_B.mean() 233 | # tot_CD_A += CD_A.mean() 234 | tot_CD_B += CD_B.mean() 235 | 236 | tot_r = tot_r_B 237 | tot_t = tot_t_B 238 | tot_gd = tot_gd_B # only part B is evaluated in this script; averaging with the never-updated tot_gd_A would halve the metric 239 | tot += 1 240 | 241 | if dist.get_rank() == 0: 242 | # print("logging wandb !!! in val !!!!! ") 243 | print("\033[1;32m", "Epoch {} In validation".format(epoch), "\033[0m") 244 | 245 | print("avg_gd: ", tot_gd / tot) 246 | print("avg_r: ", tot_r / tot) 247 | print("avg_t: ", tot_t / tot) 248 | # avg_CD_A is not tracked in this B-only script 249 | print("avg_CD_B: ", tot_CD_B / tot) 250 | 251 | log_data = { 252 | 'test/epoch/avg_gd':(tot_gd/tot).item(), 253 | 'test/epoch/avg_r': (tot_r/tot).item(), 254 | 'test/epoch/avg_t': (tot_t/tot).item(), 255 | 'test/epoch/avg_CD_B': (tot_CD_B/tot).item(), 256 | 'test/epoch/total_loss': total_loss_epoch, 257 | 'test/epoch/rot_loss': rot_loss_epoch, 258 | 'test/epoch/trans_loss': trans_loss_epoch, 259 | 'test/epoch/epoch': epoch 260 | } 261 | wandb.log( 262 | log_data 263 | ) 264 | 265 | # val_step += 1 266 | 267 | dist.barrier() 268 | cleanup() 269 | 270 | 271 | def main(cfg): 272 | world_size = len(cfg.gpus) 273 | mp.spawn(train, args=(world_size, cfg), nprocs=world_size, join=True) 274 | 275 | 276 | if __name__ == '__main__': 277 | 278 | wandb.login() 279 | 280 | parser = argparse.ArgumentParser(description="Training script") 281 | parser.add_argument('--cfg_file', default='', type=str) 282 | parser.add_argument('--gpus', nargs='+', default=-1, type=int) 283 | 284 | args = parser.parse_args() 285 | args.cfg_file = os.path.join('./config', args.cfg_file) 286 | 287 | cfg = get_cfg_defaults() 288 | cfg.merge_from_file(args.cfg_file) 289 | 290 | cfg.gpus = cfg.exp.gpus # the effective device list comes from exp.gpus in the YAML; the parsed --gpus flag is not used 291 | 292 | cfg.freeze() 293 | print(cfg) 294 | main(cfg) -------------------------------------------------------------------------------- /src/shape_assembly/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/.DS_Store
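Usage sketch for the stage-B trainer above (an illustration, not a file in the repository): our_train_B.py joins --cfg_file onto ./config, so with src/ as the working directory a typical single-node launch is:

    python script/our_train_B.py --cfg_file train_B.yml

main() then spawns one DDP worker per GPU via mp.spawn, with world_size = len(cfg.gpus). Although a --gpus flag is parsed, it is overridden by cfg.gpus = cfg.exp.gpus, so the device list has to be set through exp.gpus in the YAML config rather than on the command line; wandb.login() also expects W&B credentials to be configured before training starts.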
-------------------------------------------------------------------------------- /src/shape_assembly/__init__.py: -------------------------------------------------------------------------------- 1 | from utils import * -------------------------------------------------------------------------------- /src/shape_assembly/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # Miscellaneous configs 4 | _C = CN() 5 | 6 | # Experiment related 7 | _C.exp = CN() 8 | _C.exp.name = '' 9 | _C.exp.checkpoint_dir = '' 10 | _C.exp.weight_file = '' 11 | _C.exp.gpus = [0] 12 | _C.exp.num_workers = 8 13 | _C.exp.batch_size = 1 14 | _C.exp.num_epochs = 1000 15 | _C.exp.log_dir = '' 16 | _C.exp.load_from = '' 17 | _C.exp.ckpt_name = '' 18 | _C.exp.vis_dir = '' 19 | 20 | # Model related 21 | _C.model = CN() 22 | _C.model.encoderA = '' 23 | _C.model.encoderB = '' 24 | _C.model.encoder = '' 25 | _C.model.encoder_geo = '' 26 | _C.model.pose_predictor_quat = '' 27 | _C.model.pose_predictor_rot = '' 28 | _C.model.pose_predictor_trans = '' 29 | _C.model.corr_module = '' 30 | _C.model.sdf_predictor= '' 31 | _C.model.aggregator = '' 32 | _C.model.pc_feat_dim = 512 33 | _C.model.transformer_feat_dim = 1024 34 | _C.model.num_heads = 4 35 | _C.model.num_blocks = 1 36 | _C.model.recon_loss = False 37 | _C.model.point_loss = False 38 | _C.model.corr_feat = 'False' 39 | 40 | # Classification Model related 41 | _C.cls_model = CN() 42 | _C.cls_model.encoder = '' 43 | _C.cls_model.point_loss = False 44 | _C.cls_model.recon_loss = False 45 | _C.cls_model.corr_module = 'No' 46 | _C.cls_model.pc_feat_dim = 512 47 | _C.cls_model.num_heads = 4 48 | _C.cls_model.num_blocks = 1 49 | _C.cls_model.PCA_obb = False 50 | 51 | # Data related 52 | _C.data = CN() 53 | _C.data.root_dir = '' 54 | _C.data.train_csv_file = '' 55 | _C.data.val_csv_file = '' 56 | _C.data.num_pc_points = 1024 57 | _C.data.translation = '' 58 | 59 | # Classification Task Data related 60 | _C.cls_data = CN() 61 | _C.cls_data.root_dir = '' 62 | _C.cls_data.train_csv_file = '' 63 | _C.cls_data.val_csv_file = '' 64 | _C.cls_data.num_pc_points = 1024 65 | _C.cls_data.translation = '' 66 | 67 | # Optimizer related 68 | _C.optimizer = CN() 69 | _C.optimizer.lr = 1e-3 70 | _C.optimizer.lr_decay = 0.7 71 | _C.optimizer.decay_step = 2e4 72 | _C.optimizer.weight_decay = 1e-6 73 | _C.optimizer.lr_clip = 1e-5 74 | 75 | def get_cfg_defaults(): 76 | return _C.clone() -------------------------------------------------------------------------------- /src/shape_assembly/config_eval.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # Miscellaneous configs 4 | _C = CN() 5 | 6 | # Experiment related 7 | 
_C.exp = CN() 8 | _C.exp.name = '' 9 | _C.exp.checkpoint_dir = '' 10 | _C.exp.weight_file = '' 11 | _C.exp.gpus = [0] 12 | _C.exp.num_workers = 8 13 | _C.exp.batch_size = 1 14 | _C.exp.num_epochs = 1000 15 | _C.exp.log_dir = '' 16 | _C.exp.eval_epoch = 100 17 | _C.exp.load_from = '' 18 | 19 | # Model related 20 | _C.model = CN() 21 | _C.model.encoder = '' 22 | _C.model.encoder_geo = '' 23 | _C.model.pose_predictor_quat = '' 24 | _C.model.pose_predictor_trans = '' 25 | _C.model.sdf_predictor= '' 26 | _C.model.aggregator = '' 27 | _C.model.pc_feat_dim = 512 28 | _C.model.transformer_feat_dim = 1024 29 | _C.model.num_heads = 4 30 | _C.model.num_blocks = 1 31 | 32 | # Data related 33 | _C.data = CN() 34 | _C.data.root_dir = '' 35 | _C.data.train_csv_file = '' 36 | _C.data.val_csv_file = '' 37 | _C.data.test_csv_file = '' 38 | _C.data.num_pc_points = 1024 39 | _C.data.draw_map = False 40 | _C.data.draw_gt_map = False 41 | 42 | def get_cfg_defaults(): 43 | return _C.clone() 44 | -------------------------------------------------------------------------------- /src/shape_assembly/datasets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/.DS_Store -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/.DS_Store -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/cls_dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/cls_dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/dataloader_A.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_A.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/dataloader_CR.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_CR.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit1.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit1.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_1.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_ori.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_ori.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_single.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_single.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/dataloader_our.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_our.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/dataloader_our1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_our1.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/sub1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/sub1.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/__pycache__/sub_1_7.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/sub_1_7.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/dataloader_A.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import h5py 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 
import trimesh 9 | from PIL import Image 10 | import json 11 | from progressbar import ProgressBar 12 | import random 13 | import copy 14 | import time 15 | import ipdb 16 | 17 | import pandas as pd 18 | from scipy.spatial.transform import Rotation as R 19 | from pdb import set_trace 20 | 21 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 22 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 23 | sys.path.append(os.path.join(BASE_DIR, '../shape_assembly')) 24 | 25 | from pytorch3d.transforms import quaternion_to_matrix 26 | from mesh_to_sdf import sample_sdf_near_surface 27 | 28 | def load_data(file_dir, cat_shape_dict): 29 | with open(file_dir, 'r') as fin: 30 | for l in fin.readlines(): 31 | shape_id, cat = l.rstrip().split() 32 | cat_shape_dict[cat].append(shape_id) 33 | return cat_shape_dict 34 | 35 | class OurDataset(data.Dataset): 36 | 37 | def __init__(self, data_root_dir, data_csv_file, data_features=[], num_points = 1024 ,num_query_points = 1024 ,data_per_seg = 1): 38 | self.data_root_dir = data_root_dir 39 | self.data_csv_file = data_csv_file 40 | self.num_points = num_points 41 | self.num_query_points = num_query_points 42 | self.data_features = data_features 43 | self.data_per_seg = data_per_seg 44 | self.dataset = [] 45 | 46 | with open(self.data_csv_file, 'r') as fin: 47 | self.category_list = [line.strip() for line in fin.readlines()] 48 | 49 | 50 | def transform_pc_to_rot(self, pcs): 51 | # zero-centered 52 | pc_center = (pcs.max(axis=0, keepdims=True) + pcs.min(axis=0, keepdims=True)) / 2 53 | pc_center = pc_center[0] 54 | new_pcs = pcs - pc_center 55 | 56 | # (batch_size, 2, 3) 57 | def bgs(d6s): 58 | bsz = d6s.shape[0] 59 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1) 60 | a2 = d6s[:, :, 1] 61 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1) 62 | b3 = torch.cross(b1, b2, dim=1) 63 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1) 64 | 65 | # randomly sample one rotation matrix (Gram-Schmidt on a random 2x3 '6D' seed) 66 | rotmat = bgs(torch.rand(1, 6).reshape(-1, 2, 3).permute(0, 2, 1)) 67 | new_pcs = (rotmat.reshape(3, 3) @ new_pcs.T).T 68 | 69 | gt_rot = rotmat[:, :, :2].permute(0, 2, 1).reshape(6).numpy() 70 | 71 | return new_pcs, pc_center, gt_rot 72 | 73 | 74 | def load_data(self): 75 | bar = ProgressBar() 76 | 77 | for category_i in bar(range(len(self.category_list))): 78 | category_id = self.category_list[category_i] 79 | instance_dir = os.path.join(self.data_root_dir, category_id) 80 | fileA = os.path.join(instance_dir, 'partA-pc.csv') 81 | fileB = os.path.join(instance_dir, 'partB-pc.csv') 82 | 83 | if not os.path.exists(fileA) or not os.path.exists(fileB): 84 | print("fileA is", fileA) 85 | print("fileB is", fileB) 86 | print("file not exists") 87 | continue 88 | 89 | dataframe_A = pd.read_csv(fileA, header=None) 90 | dataframe_B = pd.read_csv(fileB, header=None) 91 | gt_pcs_A = dataframe_A.to_numpy() 92 | gt_pcs_B = dataframe_B.to_numpy() 93 | 94 | for i in range(self.data_per_seg): 95 | gt_pcs_total = np.concatenate((gt_pcs_A, gt_pcs_B), axis=0) 96 | per = np.random.permutation(gt_pcs_total.shape[0]) 97 | gt_pcs_total = gt_pcs_total[per] 98 | 99 | self.dataset.append([fileA, 100 | fileB, 101 | ]) 102 | 103 | def __str__(self): 104 | return 'OurDataset: %d part pairs' % len(self.dataset) 105 | 106 | def __len__(self): 107 | return len(self.dataset) 108 | 109 | def __getitem__(self, index): 110 | 111 | flag = 0 112 | while flag == 0: 113 | point_fileA, point_fileB, = self.dataset[index] 114 | 115 | dataframe_A = pd.read_csv(point_fileA, header=None)
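# (annotation, not in the original source: each part*-pc.csv is assumed to be a plain
# N x 3 table of xyz coordinates with no header row, which is why
# pd.read_csv(..., header=None).to_numpy() yields an (N, 3) float array. The surrounding
# while/flag loop skips samples whose first point lies exactly at the origin -- apparently
# the pipeline's sentinel for an invalid point cloud, cf. the 'Zero encountered' check
# below -- and retries with the next index.)
116 |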
dataframe_B = pd.read_csv(point_fileB, header=None) 117 | gt_pcs_A = dataframe_A.to_numpy() 118 | gt_pcs_B = dataframe_B.to_numpy() 119 | 120 | if gt_pcs_A[0][0] != 0 and gt_pcs_A[0][1] != 0: 121 | flag = 1 122 | else: 123 | index += 1 124 | 125 | if gt_pcs_A[0][0] == 0 and gt_pcs_A[0][1] == 0: 126 | raise ValueError('getitem Zero encountered!') 127 | 128 | new_pcs_A, trans_A, rot_A = self.transform_pc_to_rot(gt_pcs_A) 129 | 130 | new_pcs_B, trans_B, rot_B = self.transform_pc_to_rot(gt_pcs_B) 131 | 132 | partA_symmetry_type = np.array([0,0,0,0,1,0]) 133 | partB_symmetry_type = np.array([0,0,0,0,1,0]) 134 | 135 | data_feats = dict() 136 | 137 | for feat in self.data_features: 138 | if feat == 'src_pc': 139 | data_feats['src_pc'] = new_pcs_A.T.float() 140 | 141 | elif feat == 'tgt_pc': 142 | data_feats['tgt_pc'] = torch.tensor(gt_pcs_B).T.float() 143 | 144 | elif feat == 'src_rot': 145 | data_feats['src_rot'] = rot_A.astype(np.float32) 146 | 147 | elif feat == 'tgt_rot': 148 | data_feats['tgt_rot'] = rot_B.astype(np.float32) 149 | 150 | elif feat == 'src_trans': 151 | data_feats['src_trans'] = trans_A.reshape(1, 3).T.astype(np.float32) 152 | 153 | elif feat == 'tgt_trans': 154 | data_feats['tgt_trans'] = trans_B.reshape(1, 3).T.astype(np.float32) 155 | 156 | elif feat == 'partA_symmetry_type': 157 | data_feats['partA_symmetry_type'] = partA_symmetry_type 158 | 159 | elif feat == 'partB_symmetry_type': 160 | data_feats['partB_symmetry_type'] = partB_symmetry_type 161 | 162 | elif feat == 'predicted_partB_position': 163 | data_feats['predicted_partB_position'] = np.array([0, 0, 0, 0]).astype(np.float32) 164 | 165 | elif feat == 'predicted_partB_rotation': 166 | data_feats['predicted_partB_rotation'] = np.array([0, 0, 0, 0, 0, 0]).astype(np.float32) 167 | 168 | elif feat == 'predicted_partA_position': 169 | data_feats['predicted_partA_position'] = np.array([0, 0, 0, 0]).astype(np.float32) 170 | 171 | elif feat == 'predicted_partA_rotation': 172 | data_feats['predicted_partA_rotation'] = np.array([0, 0, 0, 0, 0, 0]).astype(np.float32) 173 | 174 | elif feat == 'partA_mesh': 175 | data_feats['partA_mesh'] = self.load_mesh(data_feats, point_fileA) 176 | 177 | elif feat == 'partB_mesh': 178 | data_feats['partB_mesh'] = self.load_mesh(data_feats, point_fileB) 179 | 180 | return data_feats -------------------------------------------------------------------------------- /src/shape_assembly/datasets/dataloader/dataloader_B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import h5py 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import trimesh 9 | from PIL import Image 10 | import json 11 | from progressbar import ProgressBar 12 | # from pyquaternion import Quaternion 13 | import random 14 | import copy 15 | import time 16 | import ipdb 17 | import numpy as np 18 | import pandas as pd 19 | from scipy.spatial.transform import Rotation as R 20 | from pdb import set_trace 21 | 22 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 23 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 24 | sys.path.append(os.path.join(BASE_DIR, '../shape_assembly')) 25 | 26 | from pytorch3d.transforms import quaternion_to_matrix 27 | 28 | def load_data(file_dir, cat_shape_dict): 29 | with open(file_dir, 'r') as fin: 30 | for l in fin.readlines(): 31 | shape_id, cat = l.rstrip().split() 32 | cat_shape_dict[cat].append(shape_id) 33 | return cat_shape_dict 34 | 35 | class 
OurDataset(data.Dataset): 36 | 37 | def __init__(self, data_root_dir, data_csv_file, data_features=[], num_points = 1024 ,num_query_points = 1024 ,data_per_seg = 1): 38 | self.data_root_dir = data_root_dir 39 | self.data_csv_file = data_csv_file 40 | self.num_points = num_points 41 | self.num_query_points = num_query_points 42 | self.data_features = data_features 43 | self.data_per_seg = data_per_seg 44 | self.dataset = [] 45 | 46 | with open(self.data_csv_file, 'r') as fin: 47 | self.category_list = [line.strip() for line in fin.readlines()] 48 | 49 | 50 | def transform_pc_to_rot(self, pcs): 51 | # zero-centered 52 | pc_center = (pcs.max(axis=0, keepdims=True) + pcs.min(axis=0, keepdims=True)) / 2 53 | pc_center = pc_center[0] 54 | new_pcs = pcs - pc_center 55 | 56 | # (batch_size, 2, 3) 57 | def bgs(d6s): 58 | bsz = d6s.shape[0] 59 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1) 60 | a2 = d6s[:, :, 1] 61 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1) 62 | b3 = torch.cross(b1, b2, dim=1) 63 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1) 64 | 65 | # randomly sample one rotation matrix (Gram-Schmidt on a random 2x3 '6D' seed) 66 | rotmat = bgs(torch.rand(1, 6).reshape(-1, 2, 3).permute(0, 2, 1)) 67 | new_pcs = (rotmat.reshape(3, 3) @ new_pcs.T).T 68 | 69 | gt_rot = rotmat[:, :, :2].permute(0, 2, 1).reshape(6).numpy() 70 | 71 | return new_pcs, pc_center, gt_rot 72 | 73 | 74 | def load_data(self): 75 | bar = ProgressBar() 76 | 77 | for category_i in bar(range(len(self.category_list))): 78 | category_id = self.category_list[category_i] 79 | instance_dir = os.path.join(self.data_root_dir, category_id) 80 | fileA = os.path.join(instance_dir, 'partA-pc.csv') 81 | fileB = os.path.join(instance_dir, 'partB-pc.csv') 82 | if not os.path.exists(fileA) or not os.path.exists(fileB): 83 | print("fileA is", fileA) 84 | print("fileB is", fileB) 85 | print("file not exists") 86 | # exit(0) 87 | continue 88 | 89 | dataframe_A = pd.read_csv(fileA, header=None) 90 | dataframe_B = pd.read_csv(fileB, header=None) 91 | gt_pcs_A = dataframe_A.to_numpy() 92 | gt_pcs_B = dataframe_B.to_numpy() 93 | 94 | for i in range(self.data_per_seg): 95 | gt_pcs_total = np.concatenate((gt_pcs_A, gt_pcs_B), axis=0) 96 | per = np.random.permutation(gt_pcs_total.shape[0]) 97 | gt_pcs_total = gt_pcs_total[per] 98 | 99 | self.dataset.append([fileA, 100 | fileB,]) 101 | 102 | def __str__(self): 103 | return 'OurDataset: %d part pairs' % len(self.dataset) 104 | 105 | def __len__(self): 106 | return len(self.dataset) 107 | 108 | def __getitem__(self, index): 109 | 110 | flag = 0 111 | while flag == 0: 112 | point_fileA, point_fileB = self.dataset[index] 113 | 114 | dataframe_A = pd.read_csv(point_fileA, header=None) 115 | dataframe_B = pd.read_csv(point_fileB, header=None) 116 | gt_pcs_A = dataframe_A.to_numpy() 117 | gt_pcs_B = dataframe_B.to_numpy() 118 | 119 | if gt_pcs_A[0][0] != 0 and gt_pcs_A[0][1] != 0: 120 | flag = 1 121 | else: 122 | index += 1 123 | 124 | if gt_pcs_A[0][0] == 0 and gt_pcs_A[0][1] == 0: 125 | raise ValueError('getitem Zero encountered!') 126 | 127 | try: 128 | new_pcs_A, trans_A, rot_A = self.transform_pc_to_rot(gt_pcs_A) 129 | except Exception as e: 130 | print("file A error!", point_fileA) 131 | raise 132 | 133 | try: 134 | new_pcs_B, trans_B, rot_B = self.transform_pc_to_rot(gt_pcs_B) 135 | except Exception as e: 136 | print("file B error!", point_fileB) 137 | raise # re-raise like the A branch above; otherwise new_pcs_B would be undefined below 138 | partA_symmetry_type = np.array([0,0,0,0,1,0]) 139 | partB_symmetry_type = np.array([0,0,0,0,1,0]) 140 | 141 | data_feats = dict() 142 | 143 | for feat in
self.data_features: 144 | if feat == 'src_pc': 145 | data_feats['src_pc'] = new_pcs_A.T.float() 146 | 147 | elif feat == 'tgt_pc': 148 | data_feats['tgt_pc'] = new_pcs_B.T.float() 149 | 150 | elif feat == 'src_rot': 151 | data_feats['src_rot'] = rot_A.astype(np.float32) 152 | 153 | elif feat == 'tgt_rot': 154 | data_feats['tgt_rot'] = rot_B.astype(np.float32) 155 | 156 | elif feat == 'src_trans': 157 | data_feats['src_trans'] = trans_A.reshape(1, 3).T.astype(np.float32) 158 | 159 | elif feat == 'tgt_trans': 160 | data_feats['tgt_trans'] = trans_B.reshape(1, 3).T.astype(np.float32) 161 | 162 | elif feat == 'partA_symmetry_type': 163 | data_feats['partA_symmetry_type'] = partA_symmetry_type 164 | 165 | elif feat == 'partB_symmetry_type': 166 | data_feats['partB_symmetry_type'] = partB_symmetry_type 167 | 168 | elif feat == 'predicted_partB_position': 169 | data_feats['predicted_partB_position'] = np.array([0, 0, 0, 0]).astype(np.float32) 170 | 171 | elif feat == 'predicted_partB_rotation': 172 | data_feats['predicted_partB_rotation'] = np.array([0, 0, 0, 0, 0, 0]).astype(np.float32) 173 | 174 | elif feat == 'predicted_partA_position': 175 | data_feats['predicted_partA_position'] = np.array([0, 0, 0, 0]).astype(np.float32) 176 | 177 | elif feat == 'predicted_partA_rotation': 178 | data_feats['predicted_partA_rotation'] = np.array([0, 0, 0, 0, 0, 0]).astype(np.float32) 179 | 180 | elif feat == 'partA_mesh': 181 | data_feats['partA_mesh'] = self.load_mesh(data_feats, point_fileA) 182 | 183 | elif feat == 'partB_mesh': 184 | data_feats['partB_mesh'] = self.load_mesh(data_feats, point_fileB) 185 | 186 | return data_feats -------------------------------------------------------------------------------- /src/shape_assembly/models/decoder/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/decoder/.DS_Store -------------------------------------------------------------------------------- /src/shape_assembly/models/decoder/MLPDecoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pdb import set_trace 5 | import os 6 | import sys 7 | import copy 8 | import math 9 | import numpy as np 10 | 11 | 12 | class MLPDecoder(nn.Module): 13 | def __init__(self, feat_dim, num_points): 14 | super().__init__() 15 | self.np = num_points 16 | self.fc_layers = nn.Sequential( 17 | nn.Linear(feat_dim * 2, num_points * 2), 18 | nn.BatchNorm1d(num_points * 2), 19 | nn.LeakyReLU(0.2), 20 | nn.Linear(num_points * 2, num_points * 3), 21 | ) 22 | 23 | def forward(self, x): 24 | # x.shape: (bs,1024) 25 | batch_size = x.shape[0] 26 | f = self.fc_layers(x) 27 | return f.reshape(batch_size, self.np, 3) -------------------------------------------------------------------------------- /src/shape_assembly/models/decoder/__pycache__/MLPDecoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/decoder/__pycache__/MLPDecoder.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/models/encoder/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/encoder/.DS_Store -------------------------------------------------------------------------------- /src/shape_assembly/models/encoder/__pycache__/dgcnn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/encoder/__pycache__/dgcnn.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/models/encoder/__pycache__/pointnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/encoder/__pycache__/pointnet.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/models/encoder/__pycache__/vn_dgcnn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/encoder/__pycache__/vn_dgcnn.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/models/encoder/__pycache__/vn_layers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/encoder/__pycache__/vn_layers.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/models/encoder/vn_dgcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pdb import set_trace 5 | import os 6 | import sys 7 | import copy 8 | import math 9 | import numpy as np 10 | 11 | from models.encoder.vn_layers import * 12 | 13 | def knn(x, k): 14 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 15 | xx = torch.sum(x ** 2, dim=1, keepdim=True) 16 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 17 | 18 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 19 | return idx 20 | 21 | def get_graph_feature(x, k=20, idx=None, x_coord=None): 22 | batch_size = x.size(0) 23 | num_points = x.size(3) 24 | x = x.view(batch_size, -1, num_points) 25 | if idx is None: 26 | if x_coord is None: # dynamic knn graph 27 | idx = knn(x, k=k) 28 | else: # fixed knn graph with input point coordinates 29 | idx = knn(x_coord, k=k) 30 | device = torch.device('cuda') 31 | 32 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points 33 | 34 | idx = idx + idx_base 35 | 36 | idx = idx.view(-1) 37 | 38 | _, num_dims, _ = x.size() 39 | num_dims = num_dims // 3 40 | 41 | x = x.transpose(2, 1).contiguous() 42 | feature = x.view(batch_size * num_points, -1)[idx, :] 43 | feature = feature.view(batch_size, num_points, k, num_dims, 3) 44 | x = x.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1) 45 | 46 | feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 4, 1, 2).contiguous() 47 | 48 | return feature 49 | 50 | class VN_DGCNN(nn.Module): 51 | 52 | 53 | def __init__(self, feat_dim): 54 | super(VN_DGCNN, self).__init__() 55 | self.n_knn = 
20 56 | # num_part = feat_dim 57 | 58 | pooling = 'mean' 59 | 60 | self.conv1 = VNLinearLeakyReLU(2, 64 // 3) 61 | self.conv2 = VNLinearLeakyReLU(64 // 3, 64 // 3) 62 | self.conv3 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3) 63 | self.conv4 = VNLinearLeakyReLU(64 // 3, 64 // 3) 64 | self.conv5 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3) 65 | 66 | if pooling == 'max': 67 | self.pool1 = VNMaxPool(64 // 3) 68 | self.pool2 = VNMaxPool(64 // 3) 69 | self.pool3 = VNMaxPool(64 // 3) 70 | self.pool4 = VNMaxPool(2 * feat_dim) 71 | elif pooling == 'mean': 72 | self.pool1 = mean_pool 73 | self.pool2 = mean_pool 74 | self.pool3 = mean_pool 75 | self.pool4 = mean_pool 76 | 77 | self.conv6 = VNLinearLeakyReLU(64 // 3 * 3, feat_dim, dim=4, share_nonlinearity=True) 78 | 79 | def forward(self, x): 80 | 81 | batch_size = x.size(0) 82 | num_points = x.size(2) 83 | l = x[:, 0, 0:16].reshape(batch_size, 1, 16) 84 | 85 | x = x.unsqueeze(1) # (32, 1, 3, 1024) 86 | 87 | x = get_graph_feature(x, k=self.n_knn) # (32, 2, 3, 1024, 20) 88 | 89 | x = self.conv1(x) # (32, 21, 3, 1024, 20) 90 | x = self.conv2(x) # (32, 21, 3, 1024, 20) 91 | x1 = self.pool1(x) # (32, 21, 3, 1024) 92 | 93 | x = get_graph_feature(x1, k=self.n_knn) 94 | x = self.conv3(x) 95 | x = self.conv4(x) 96 | x2 = self.pool2(x) 97 | 98 | x = get_graph_feature(x2, k=self.n_knn) 99 | x = self.conv5(x) 100 | x3 = self.pool3(x) 101 | 102 | x123 = torch.cat((x1, x2, x3), dim=1) 103 | 104 | x = self.conv6(x123) 105 | x = self.pool4(x) # [batch, feature_dim, 3] 106 | return x 107 | 108 | class VN_DGCNN_corr(nn.Module): 109 | 110 | def __init__(self, feat_dim): 111 | super(VN_DGCNN_corr, self).__init__() 112 | self.n_knn = 20 113 | # num_part = feat_dim 114 | 115 | pooling = 'mean' 116 | 117 | self.conv1 = VNLinearLeakyReLU(2, 64 // 3) 118 | self.conv2 = VNLinearLeakyReLU(64 // 3, 64 // 3) 119 | self.conv3 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3) 120 | self.conv4 = VNLinearLeakyReLU(64 // 3, 64 // 3) 121 | self.conv5 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3) 122 | self.VnInv = VNStdFeature(2 * feat_dim, dim=3, normalize_frame=False) 123 | 124 | if pooling == 'max': 125 | self.pool1 = VNMaxPool(64 // 3) 126 | self.pool2 = VNMaxPool(64 // 3) 127 | self.pool3 = VNMaxPool(64 // 3) 128 | self.pool4 = VNMaxPool(2 * feat_dim) 129 | elif pooling == 'mean': 130 | self.pool1 = mean_pool 131 | self.pool2 = mean_pool 132 | self.pool3 = mean_pool 133 | self.pool4 = mean_pool 134 | 135 | self.conv6 = VNLinearLeakyReLU(64 // 3 * 3, feat_dim, dim=4, share_nonlinearity=True) 136 | self.linear0 = nn.Linear(3, 2 * feat_dim) 137 | 138 | def forward(self, x): 139 | 140 | batch_size = x.size(0) 141 | num_points = x.size(2) 142 | l = x[:, 0, 0:16].reshape(batch_size, 1, 16) 143 | 144 | x = x.unsqueeze(1) # (32, 1, 3, 1024) 145 | 146 | x = get_graph_feature(x, k=self.n_knn) # (32, 2, 3, 1024, 20) 147 | 148 | x = self.conv1(x) # (32, 21, 3, 1024, 20) 149 | x = self.conv2(x) # (32, 21, 3, 1024, 20) 150 | x1 = self.pool1(x) # (32, 21, 3, 1024) 151 | 152 | x = get_graph_feature(x1, k=self.n_knn) 153 | x = self.conv3(x) 154 | x = self.conv4(x) 155 | x2 = self.pool2(x) 156 | 157 | x = get_graph_feature(x2, k=self.n_knn) 158 | x = self.conv5(x) 159 | x3 = self.pool3(x) 160 | 161 | x123 = torch.cat((x1, x2, x3), dim=1) 162 | 163 | x = self.conv6(x123) 164 | x_mean = x.mean(dim=-1, keepdim=True).expand(x.size()) 165 | x = torch.cat((x, x_mean), 1) 166 | x = self.pool4(x) # [batch, feature_dim, 3] 167 | x1, z0 = self.VnInv(x) 168 | x1 = self.linear0(x1) 169 | return x, x1 # [batch, 1024, 3], 
[batch, 1024, 1024] 170 | 171 | class VN_DGCNN_New(nn.Module): 172 | 173 | def __init__(self, feat_dim): 174 | super(VN_DGCNN_New, self).__init__() 175 | self.n_knn = 20 176 | num_part = feat_dim 177 | 178 | print("feat_dim is: ", feat_dim) 179 | 180 | pooling = 'mean' 181 | 182 | self.conv1 = VNLinearLeakyReLU(2, 64 // 3) 183 | self.conv2 = VNLinearLeakyReLU(64 // 3, 64 // 3) 184 | self.conv3 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3) 185 | self.conv4 = VNLinearLeakyReLU(64 // 3, 64 // 3) 186 | self.conv5 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3) 187 | self.VnInv = VNStdFeature(2 * feat_dim, dim=3, normalize_frame=False) 188 | 189 | if pooling == 'max': 190 | self.pool1 = VNMaxPool(64 // 3) 191 | self.pool2 = VNMaxPool(64 // 3) 192 | self.pool3 = VNMaxPool(64 // 3) 193 | self.pool4 = VNMaxPool(2 * feat_dim) 194 | elif pooling == 'mean': 195 | self.pool1 = mean_pool 196 | self.pool2 = mean_pool 197 | self.pool3 = mean_pool 198 | self.pool4 = mean_pool 199 | 200 | self.conv6 = VNLinearLeakyReLU(64 // 3 * 3, feat_dim, dim=4, share_nonlinearity=True) 201 | self.linear0 = nn.Linear(3, 2 * feat_dim) 202 | 203 | def forward(self, x): 204 | 205 | batch_size = x.size(0) 206 | num_points = x.size(2) 207 | l = x[:, 0, 0:16].reshape(batch_size, 1, 16) 208 | 209 | x = x.unsqueeze(1) # (32, 1, 3, 1024) 210 | 211 | x = get_graph_feature(x, k=self.n_knn) # (32, 2, 3, 1024, 20) 212 | 213 | x = self.conv1(x) # (32, 21, 3, 1024, 20) 214 | x = self.conv2(x) # (32, 21, 3, 1024, 20) 215 | x1 = self.pool1(x) # (32, 21, 3, 1024) 216 | 217 | x = get_graph_feature(x1, k=self.n_knn) 218 | x = self.conv3(x) 219 | x = self.conv4(x) 220 | x2 = self.pool2(x) 221 | 222 | x = get_graph_feature(x2, k=self.n_knn) 223 | x = self.conv5(x) 224 | x3 = self.pool3(x) 225 | 226 | x123 = torch.cat((x1, x2, x3), dim=1) 227 | 228 | x = self.conv6(x123) 229 | x_mean = x.mean(dim=-1, keepdim=True).expand(x.size()) 230 | x = torch.cat((x, x_mean), 1) 231 | x = self.pool4(x) # [batch, feature_dim, 3] 232 | x1, z0 = self.VnInv(x) 233 | x1 = self.linear0(x1) 234 | return x, x1 # [batch, 1024, 3], [batch, 1024, 1024] 235 | 236 | class DGCNN_New(nn.Module): 237 | 238 | def __init__(self, feat_dim): 239 | super(DGCNN_New, self).__init__() 240 | self.n_knn = 20 241 | num_part = feat_dim 242 | 243 | print("feat_dim is: ", feat_dim) 244 | 245 | pooling = 'mean' 246 | 247 | self.conv1 = NonEquivariantLinearLeakyReLU(2, 64 // 3) 248 | self.conv2 = NonEquivariantLinearLeakyReLU(64 // 3, 64 // 3) 249 | self.conv3 = NonEquivariantLinearLeakyReLU(64 // 3 * 2, 64 // 3) 250 | self.conv4 = NonEquivariantLinearLeakyReLU(64 // 3, 64 // 3) 251 | self.conv5 = NonEquivariantLinearLeakyReLU(64 // 3 * 2, 64 // 3) 252 | self.VnInv = NonEquivariantStdFeature(2 * feat_dim, dim=3, normalize_frame=False) 253 | 254 | if pooling == 'max': 255 | self.pool1 = NonEquivariantMaxPool(64 // 3) 256 | self.pool2 = NonEquivariantMaxPool(64 // 3) 257 | self.pool3 = NonEquivariantMaxPool(64 // 3) 258 | self.pool4 = NonEquivariantMaxPool(2 * feat_dim) 259 | elif pooling == 'mean': 260 | self.pool1 = mean_pool 261 | self.pool2 = mean_pool 262 | self.pool3 = mean_pool 263 | self.pool4 = mean_pool 264 | 265 | self.conv6 = NonEquivariantLinearLeakyReLU(64 // 3 * 3, feat_dim, dim=4) 266 | self.linear0 = nn.Linear(3, 2 * feat_dim) 267 | 268 | def forward(self, x): 269 | 270 | # x: (batch_size, 3, num_points) 271 | # l: (batch_size, 1, 16) 272 | 273 | batch_size = x.size(0) 274 | num_points = x.size(2) 275 | l = x[:, 0, 0:16].reshape(batch_size, 1, 16) 276 | print("!!! 
the shape of l is: ", l.shape) 277 | 278 | print("!!! the shape of x is: ", x.shape) 279 | x = x.unsqueeze(1) # (32, 1, 3, 1024) 280 | print("!!! the shape of x is: ", x.shape) 281 | 282 | x = get_graph_feature(x, k=self.n_knn) # (32, 2, 3, 1024, 20) 283 | 284 | print("!!! the shape of x is: ", x.shape) 285 | x = self.conv1(x) # (32, 21, 3, 1024, 20) 286 | print("!!! the shape of x is: ", x.shape) 287 | x = self.conv2(x) # (32, 21, 3, 1024, 20) 288 | x1 = self.pool1(x) # (32, 21, 3, 1024) 289 | 290 | x = get_graph_feature(x1, k=self.n_knn) 291 | x = self.conv3(x) 292 | x = self.conv4(x) 293 | x2 = self.pool2(x) 294 | 295 | x = get_graph_feature(x2, k=self.n_knn) 296 | x = self.conv5(x) 297 | x3 = self.pool3(x) 298 | 299 | x123 = torch.cat((x1, x2, x3), dim=1) 300 | 301 | x = self.conv6(x123) 302 | x_mean = x.mean(dim=-1, keepdim=True).expand(x.size()) 303 | x = torch.cat((x, x_mean), 1) 304 | x = self.pool4(x) # [batch, feature_dim, 3] 305 | x1, z0 = self.VnInv(x) 306 | x1 = self.linear0(x1) 307 | return x, x1 # [batch, 1024, 3], [batch, 1024, 1024] 308 | -------------------------------------------------------------------------------- /src/shape_assembly/models/encoder/vn_dgcnn_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import math 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def knn(x, k): 12 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 13 | xx = torch.sum(x ** 2, dim=1, keepdim=True) 14 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 15 | 16 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 17 | return idx 18 | 19 | 20 | def get_graph_feature(x, k=20, idx=None, x_coord=None): 21 | batch_size = x.size(0) 22 | num_points = x.size(3) 23 | x = x.view(batch_size, -1, num_points) 24 | if idx is None: 25 | if x_coord is None: # dynamic knn graph 26 | idx = knn(x, k=k) 27 | else: # fixed knn graph with input point coordinates 28 | idx = knn(x_coord, k=k) 29 | device = torch.device('cuda') 30 | 31 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points 32 | 33 | idx = idx + idx_base 34 | 35 | idx = idx.view(-1) 36 | 37 | _, num_dims, _ = x.size() 38 | num_dims = num_dims // 3 39 | 40 | x = x.transpose(2, 1).contiguous() 41 | feature = x.view(batch_size * num_points, -1)[idx, :] 42 | feature = feature.view(batch_size, num_points, k, num_dims, 3) 43 | x = x.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1) 44 | 45 | feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 4, 1, 2).contiguous() 46 | 47 | return feature 48 | 49 | 50 | def get_graph_feature_cross(x, k=20, idx=None): 51 | batch_size = x.size(0) 52 | num_points = x.size(3) 53 | x = x.view(batch_size, -1, num_points) 54 | if idx is None: 55 | idx = knn(x, k=k) 56 | device = torch.device('cuda') 57 | 58 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points 59 | 60 | idx = idx + idx_base 61 | 62 | idx = idx.view(-1) 63 | 64 | _, num_dims, _ = x.size() 65 | num_dims = num_dims // 3 66 | 67 | x = x.transpose(2, 1).contiguous() 68 | feature = x.view(batch_size * num_points, -1)[idx, :] 69 | feature = feature.view(batch_size, num_points, k, num_dims, 3) 70 | x = x.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1) 71 | cross = torch.cross(feature, x, dim=-1) 72 | 73 | feature = torch.cat((feature - x, x, cross), dim=3).permute(0, 3, 4, 1, 
2).contiguous() 74 | 75 | return feature -------------------------------------------------------------------------------- /src/shape_assembly/models/encoder/vn_layers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import math 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | EPS = 1e-6 11 | 12 | def conv1x1(in_channels, out_channels, dim): 13 | if dim == 3: 14 | return nn.Conv1d(in_channels, out_channels, 1, bias=False) 15 | elif dim == 4: 16 | return nn.Conv2d(in_channels, out_channels, 1, bias=False) 17 | elif dim == 5: 18 | return nn.Conv3d(in_channels, out_channels, 1, bias=False) 19 | else: 20 | raise NotImplementedError(f'{dim}D 1x1 Conv is not supported') 21 | 22 | 23 | class VNLinear(nn.Module): 24 | def __init__(self, in_channels, out_channels): 25 | super(VNLinear, self).__init__() 26 | self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False) 27 | 28 | def forward(self, x): 29 | ''' 30 | x: point features of shape [B, N_feat, 3, N_samples, ...] 31 | ''' 32 | x_out = self.map_to_feat(x.transpose(1, -1)).transpose(1, -1) 33 | return x_out 34 | 35 | 36 | class VNLeakyReLU(nn.Module): 37 | def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.2): 38 | super(VNLeakyReLU, self).__init__() 39 | if share_nonlinearity == True: 40 | self.map_to_dir = nn.Linear(in_channels, 1, bias=False) 41 | else: 42 | self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False) 43 | self.negative_slope = negative_slope 44 | 45 | def forward(self, x): 46 | ''' 47 | x: point features of shape [B, N_feat, 3, N_samples, ...] 48 | ''' 49 | d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1) 50 | dotprod = (x * d).sum(2, keepdim=True) 51 | mask = (dotprod >= 0).float() 52 | d_norm_sq = (d * d).sum(2, keepdim=True) 53 | x_out = self.negative_slope * x + (1 - self.negative_slope) * ( 54 | mask * x + (1 - mask) * (x - (dotprod / (d_norm_sq + EPS)) * d)) 55 | return x_out 56 | 57 | 58 | class VNNewLeakyReLU(nn.Module): 59 | def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.2): 60 | super(VNNewLeakyReLU, self).__init__() 61 | if share_nonlinearity == True: 62 | self.map_to_dir = nn.Linear(in_channels, 1, bias=False) 63 | else: 64 | self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False) 65 | self.negative_slope = negative_slope 66 | 67 | def forward(self, x): 68 | ''' 69 | x: point features of shape [B, N_feat, 3, N_samples, ...] 
70 | ''' 71 | d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1) 72 | dotprod = (x * d) 73 | mask = (dotprod >= 0).float() 74 | d_norm_sq = (d * d) 75 | x_out = self.negative_slope * x + (1 - self.negative_slope) * ( 76 | mask * x + (1 - mask) * (x - (d / (d_norm_sq + EPS)) * d)) 77 | return x_out 78 | 79 | 80 | class VNLinearLeakyReLU(nn.Module): 81 | def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, negative_slope=0.2): 82 | super(VNLinearLeakyReLU, self).__init__() 83 | self.dim = dim 84 | self.negative_slope = negative_slope 85 | 86 | self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False) 87 | self.batchnorm = VNBatchNorm(out_channels, dim=dim) 88 | 89 | if share_nonlinearity == True: 90 | self.map_to_dir = nn.Linear(in_channels, 1, bias=False) 91 | else: 92 | self.map_to_dir = nn.Linear(in_channels, out_channels, bias=False) 93 | 94 | def forward(self, x): 95 | ''' 96 | x: point features of shape [B, N_feat, 3, N_samples, ...] 97 | ''' 98 | # Linear 99 | p = self.map_to_feat(x.transpose(1, -1)).transpose(1, -1) 100 | # BatchNorm 101 | p = self.batchnorm(p) 102 | # LeakyReLU 103 | d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1) 104 | dotprod = (p * d).sum(2, keepdims=True) 105 | mask = (dotprod >= 0).float() 106 | d_norm_sq = (d * d).sum(2, keepdims=True) 107 | x_out = self.negative_slope * p + (1 - self.negative_slope) * ( 108 | mask * p + (1 - mask) * (p - (dotprod / (d_norm_sq + EPS)) * d)) 109 | return x_out 110 | 111 | 112 | class VNLinearAndLeakyReLU(nn.Module): 113 | def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, use_batchnorm='norm', 114 | negative_slope=0.2): 115 | super(VNLinearAndLeakyReLU, self).__init__() 116 | self.dim = dim 117 | self.share_nonlinearity = share_nonlinearity 118 | self.use_batchnorm = use_batchnorm 119 | self.negative_slope = negative_slope 120 | 121 | self.linear = VNLinear(in_channels, out_channels) 122 | self.leaky_relu = VNLeakyReLU(out_channels, share_nonlinearity=share_nonlinearity, 123 | negative_slope=negative_slope) 124 | 125 | # BatchNorm 126 | self.use_batchnorm = use_batchnorm 127 | if use_batchnorm != 'none': 128 | self.batchnorm = VNBatchNorm(out_channels, dim=dim) # VNBatchNorm below takes (num_features, dim); it has no 'mode' argument 129 | 130 | def forward(self, x): 131 | ''' 132 | x: point features of shape [B, N_feat, 3, N_samples, ...] 133 | ''' 134 | # Linear 135 | x = self.linear(x) 136 | # BatchNorm 137 | if self.use_batchnorm != 'none': 138 | x = self.batchnorm(x) 139 | # LeakyReLU 140 | x_out = self.leaky_relu(x) 141 | return x_out 142 | 143 | 144 | class VNBatchNorm(nn.Module): 145 | def __init__(self, num_features, dim): 146 | super(VNBatchNorm, self).__init__() 147 | self.dim = dim 148 | if dim == 3 or dim == 4: 149 | self.bn = nn.BatchNorm1d(num_features) 150 | elif dim == 5: 151 | self.bn = nn.BatchNorm2d(num_features) 152 | 153 | def forward(self, x): 154 | ''' 155 | x: point features of shape [B, N_feat, 3, N_samples, ...] 156 | ''' 157 | # norm = torch.sqrt((x*x).sum(2)) 158 | norm = torch.norm(x, dim=2) + EPS 159 | norm_bn = self.bn(norm) 160 | norm = norm.unsqueeze(2) 161 | norm_bn = norm_bn.unsqueeze(2) 162 | x = x / norm * norm_bn 163 | 164 | return x 165 | 166 | 167 | class VNMaxPool(nn.Module): 168 | def __init__(self, in_channels): 169 | super(VNMaxPool, self).__init__() 170 | self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False) 171 | 172 | def forward(self, x): 173 | ''' 174 | x: point features of shape [B, N_feat, 3, N_samples, ...]
175 | ''' 176 | d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1) 177 | dotprod = (x * d).sum(2, keepdims=True) 178 | idx = dotprod.max(dim=-1, keepdim=False)[1] 179 | index_tuple = torch.meshgrid([torch.arange(j) for j in x.size()[:-1]]) + (idx,) 180 | x_max = x[index_tuple] 181 | return x_max 182 | 183 | 184 | def mean_pool(x, dim=-1, keepdim=False): 185 | return x.mean(dim=dim, keepdim=keepdim) 186 | 187 | 188 | class VNStdFeature(nn.Module): 189 | def __init__(self, in_channels, dim=4, normalize_frame=False, share_nonlinearity=False, negative_slope=0.2): 190 | super(VNStdFeature, self).__init__() 191 | self.dim = dim 192 | self.normalize_frame = normalize_frame 193 | 194 | self.vn1 = VNLinearLeakyReLU(in_channels, in_channels // 2, dim=dim, share_nonlinearity=share_nonlinearity, 195 | negative_slope=negative_slope) 196 | self.vn2 = VNLinearLeakyReLU(in_channels // 2, in_channels // 4, dim=dim, share_nonlinearity=share_nonlinearity, 197 | negative_slope=negative_slope) 198 | if normalize_frame: 199 | self.vn_lin = nn.Linear(in_channels // 4, 2, bias=False) 200 | else: 201 | self.vn_lin = nn.Linear(in_channels // 4, 3, bias=False) 202 | 203 | def forward(self, x): 204 | ''' 205 | x: point features of shape [B, N_feat, 3, N_samples, ...] 206 | ''' 207 | z0 = x 208 | z0 = self.vn1(z0) 209 | z0 = self.vn2(z0) 210 | z0 = self.vn_lin(z0.transpose(1, -1)).transpose(1, -1) 211 | 212 | if self.normalize_frame: 213 | # make z0 orthogonal. u2 = v2 - proj_u1(v2) 214 | v1 = z0[:, 0, :] 215 | # u1 = F.normalize(v1, dim=1) 216 | v1_norm = torch.sqrt((v1 * v1).sum(1, keepdims=True)) 217 | u1 = v1 / (v1_norm + EPS) 218 | v2 = z0[:, 1, :] 219 | v2 = v2 - (v2 * u1).sum(1, keepdims=True) * u1 220 | # u2 = F.normalize(u2, dim=1) 221 | v2_norm = torch.sqrt((v2 * v2).sum(1, keepdims=True)) 222 | u2 = v2 / (v2_norm + EPS) 223 | 224 | # compute the cross product of the two output vectors 225 | u3 = torch.cross(u1, u2) 226 | z0 = torch.stack([u1, u2, u3], dim=1).transpose(1, 2) 227 | else: 228 | z0 = z0.transpose(1, 2) 229 | 230 | if self.dim == 4: 231 | x_std = torch.einsum('bijm,bjkm->bikm', x, z0) 232 | elif self.dim == 3: 233 | x_std = torch.einsum('bij,bjk->bik', x, z0) 234 | elif self.dim == 5: 235 | x_std = torch.einsum('bijmn,bjkmn->bikmn', x, z0) 236 | 237 | return x_std, z0 238 | 239 | 240 | class VNInFeature(nn.Module): 241 | """VN-Invariant layer.""" 242 | 243 | def __init__( 244 | self, 245 | in_channels, 246 | dim=4, 247 | share_nonlinearity=False, 248 | negative_slope=0.2, 249 | use_rmat=False, 250 | ): 251 | super().__init__() 252 | 253 | self.dim = dim 254 | self.use_rmat = use_rmat 255 | self.vn1 = VNLinearBNLeakyReLU( 256 | in_channels, 257 | in_channels // 2, 258 | dim=dim, 259 | share_nonlinearity=share_nonlinearity, 260 | negative_slope=negative_slope, 261 | ) 262 | self.vn2 = VNLinearBNLeakyReLU( 263 | in_channels // 2, 264 | in_channels // 4, 265 | dim=dim, 266 | share_nonlinearity=share_nonlinearity, 267 | negative_slope=negative_slope, 268 | ) 269 | self.vn_lin = conv1x1( 270 | in_channels // 4, 2 if self.use_rmat else 3, dim=dim) 271 | 272 | def forward(self, x): 273 | """ 274 | Args: 275 | x: point features of shape [B, C, 3, N, ...] 
276 | Returns: 277 | rotation invariant features of the same shape 278 | """ 279 | z = self.vn1(x) 280 | z = self.vn2(z) 281 | z = self.vn_lin(z) # [B, 3, 3, N] or [B, 2, 3, N] 282 | if self.use_rmat: 283 | z = z.flatten(1, 2).transpose(1, 2).contiguous() # [B, N, 6] 284 | z = rot6d_to_matrix(z) # [B, N, 3, 3]; rot6d_to_matrix is not defined in this file and must be supplied by a 6D-rotation utility (e.g. pytorch3d.transforms.rotation_6d_to_matrix) 285 | z = z.permute(0, 2, 3, 1) # [B, 3, 3, N] 286 | z = z.transpose(1, 2).contiguous() 287 | 288 | if self.dim == 4: 289 | x_in = torch.einsum('bijm,bjkm->bikm', x, z) 290 | elif self.dim == 3: 291 | x_in = torch.einsum('bij,bjk->bik', x, z) 292 | elif self.dim == 5: 293 | x_in = torch.einsum('bijmn,bjkmn->bikmn', x, z) 294 | else: 295 | raise NotImplementedError(f'dim={self.dim} is not supported') 296 | 297 | return x_in 298 | 299 | 300 | class VNLinearBNLeakyReLU(nn.Module): 301 | 302 | def __init__( 303 | self, 304 | in_channels, 305 | out_channels, 306 | dim=5, 307 | share_nonlinearity=False, 308 | negative_slope=0.2, 309 | ): 310 | super().__init__() 311 | 312 | self.linear = VNLinear(in_channels, out_channels) 313 | self.batchnorm = VNBatchNorm(out_channels, dim=dim) 314 | self.leaky_relu = VNLeakyReLU( 315 | out_channels, 316 | # dim=dim, 317 | share_nonlinearity=share_nonlinearity, 318 | negative_slope=negative_slope, 319 | ) 320 | 321 | def forward(self, x): 322 | # Linear 323 | p = self.linear(x) 324 | # BatchNorm 325 | p = self.batchnorm(p) 326 | # LeakyReLU 327 | p = self.leaky_relu(p) 328 | return p 329 | 330 | class NonEquivariantLinearLeakyReLU(nn.Module): 331 | def __init__(self, in_channels, out_channels, dim=2, use_batchnorm=True, negative_slope=0.2): 332 | super(NonEquivariantLinearLeakyReLU, self).__init__() 333 | self.dim = dim 334 | self.negative_slope = negative_slope 335 | self.use_batchnorm = use_batchnorm 336 | 337 | self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False) 338 | if use_batchnorm: 339 | self.batchnorm = nn.BatchNorm1d(out_channels) 340 | self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope) 341 | 342 | def forward(self, x): 343 | # Linear 344 | # nn.Linear acts on the last axis, so the channel dimension is moved there and back 345 | x = self.map_to_feat(x.transpose(1, -1)).transpose(1, -1) 346 | 347 | 348 | 349 | # BatchNorm 350 | if self.use_batchnorm: 351 | x = x.transpose(1, -1).contiguous() 352 | x = self.batchnorm(x) 353 | x = x.transpose(1, -1).contiguous() 354 | 355 | # LeakyReLU 356 | x_out = self.leaky_relu(x) 357 | return x_out 358 | 359 | 360 | class NonEquivariantMaxPool(nn.Module): 361 | def __init__(self, dim=-1): 362 | super(NonEquivariantMaxPool, self).__init__() 363 | self.dim = dim 364 | 365 | def forward(self, x): 366 | ''' 367 | x: point features of shape [B, C, N] 368 | ''' 369 | return torch.max(x, dim=self.dim, keepdim=True)[0] 370 | 371 | 372 | class NonEquivariantStdFeature(nn.Module): 373 | def __init__(self, in_channels, dim=4, normalize_frame=False, negative_slope=0.2): 374 | super(NonEquivariantStdFeature, self).__init__() 375 | self.dim = dim 376 | self.normalize_frame = normalize_frame 377 | 378 | self.fc1 = NonEquivariantLinearLeakyReLU(in_channels, in_channels // 2, dim=dim, negative_slope=negative_slope) 379 | self.fc2 = NonEquivariantLinearLeakyReLU(in_channels // 2, in_channels // 4, dim=dim, negative_slope=negative_slope) 380 | 381 | def forward(self, x): 382 | ''' 383 | x: point features of shape [B, C_in, N] 384 | ''' 385 | z0 = x 386 | z0 = self.fc1(z0) 387 | z0 = self.fc2(z0) 388 | 389 | # No need for orthogonalization or frame normalization in the non-equivariant version 390 | return z0, None # mirror VNStdFeature's (features, frame) return signature; there is no frame here 391 | 392 |
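# --- Illustrative sketch (not part of the original file) ----------------------------------
# The VN* layers above are SO(3)-equivariant: rotating the coordinate axis (dim 2) of the
# input commutes with the layer, because every VN op mixes feature channels with scalar
# weights and never mixes the three coordinates. The NonEquivariant* variants drop exactly
# this property. A quick numerical check for VNLinear, under the usual
# [B, N_feat, 3, N_samples] layout:
#
#   import torch
#   lin = VNLinear(8, 16)                        # defined earlier in this file
#   x = torch.randn(2, 8, 3, 100)
#   q = torch.linalg.qr(torch.randn(3, 3)).Q     # random orthogonal 3x3
#   if torch.det(q) < 0:
#       q[:, 0] = -q[:, 0]                       # flip one column -> a proper rotation
#   rotate = lambda t: torch.einsum('ij,bcjn->bcin', q, t)
#   assert torch.allclose(rotate(lin(x)), lin(rotate(x)), atol=1e-5)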
393 | 
394 | class NonEquivariantLinearAndLeakyReLU(nn.Module):
395 |     def __init__(self, in_channels, out_channels, dim=4, use_batchnorm=True, negative_slope=0.2):
396 |         super(NonEquivariantLinearAndLeakyReLU, self).__init__()
397 |         self.dim = dim
398 |         self.use_batchnorm = use_batchnorm
399 |         self.negative_slope = negative_slope
400 | 
401 |         self.linear = nn.Linear(in_channels, out_channels, bias=False)
402 |         if use_batchnorm:
403 |             self.batchnorm = nn.BatchNorm1d(out_channels)
404 |         self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope)
405 | 
406 |     def forward(self, x):
407 |         '''
408 |         x: point features of shape [B, C_in, N]
409 |         '''
410 |         # Linear: project the channel axis (nn.Linear acts on the last axis)
411 |         x = self.linear(x.transpose(1, -1)).transpose(1, -1)
412 |         # BatchNorm over [B, C_out, N], channels already at dim 1
413 |         if self.use_batchnorm:
414 |             x = self.batchnorm(x)
415 |         # LeakyReLU
416 |         x = self.leaky_relu(x)
417 |         return x
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/Pose_Refinement_Module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/Pose_Refinement_Module.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn1.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn2.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn_A.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn_A.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn_A_indi.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn_A_indi.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn_A_indi2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn_A_indi2.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn_B.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn_B.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/models/train/__pycache__/network_vnn_B2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn_B2.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/models/train/__pycache__/regressor_CR.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/regressor_CR.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/models/train/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /src/shape_assembly/models/train/network_vnn_A.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | import pytorch3d 8 | from pytorch3d.transforms import quaternion_to_matrix 9 | from transforms3d.quaternions import quat2mat 10 | from scipy.spatial.transform import Rotation as R 11 | from mpl_toolkits.mplot3d import Axes3D 12 | 13 | import os 14 | import sys 15 | import copy 16 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | sys.path.append(os.path.join(BASE_DIR, '../../')) 18 | from models.encoder.pointnet import PointNet 19 | from models.encoder.dgcnn import DGCNN 20 | from models.decoder.MLPDecoder import MLPDecoder 21 | from models.encoder.vn_dgcnn import VN_DGCNN, VN_DGCNN_corr, VN_DGCNN_New 22 | from models.baseline.transformer import Transformer 23 | from models.baseline.regressor_CR import Regressor_CR, Regressor_6d, VN_Regressor_6d 24 | import utils 25 | from pdb import set_trace 26 | 27 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | from ..ChamferDistancePytorch.chamfer3D import dist_chamfer_3D 29 | 30 | def bgs(d6s): 31 | bsz = d6s.shape[0] 32 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1) 33 | a2 = d6s[:, :, 1] 34 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1) 35 | b3 = torch.cross(b1, b2, dim=1) 36 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1) 37 | 38 | def debug_vis_input(batch_data, cfg, prd_data, iter_counts): 39 | for i in range(cfg.exp.batch_size): 40 | save_dir = cfg.exp.vis_dir 41 | vis_dir = os.path.join(save_dir, 'vis_A_input') 42 | if not os.path.exists((vis_dir)): 43 | os.mkdir(vis_dir) 44 | 45 | src_pc = batch_data['src_pc'][i] 46 | tgt_pc = batch_data['tgt_pc'][i] 47 | 48 | device = src_pc.device 49 | num_points = src_pc.shape[1] 50 | 51 | tgt_pc_trans = batch_data['predicted_partB_position'][i].unsqueeze(1) 52 | tgt_pc_rot = 
batch_data['predicted_partB_rotation'][i].unsqueeze(1) 53 | 54 | tgt_rot_mat = bgs(tgt_pc_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3) 55 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat) 56 | tgt_pc_trans = tgt_pc_trans.expand(-1, num_points) 57 | 58 | tgt_pc = torch.matmul(tgt_rot_mat.double(), tgt_pc.double()) 59 | tgt_pc = tgt_pc + tgt_pc_trans 60 | 61 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu()) 62 | color_mask_src = ['r'] * num_points 63 | color_mask_tgt = ['g'] * num_points 64 | total_color = color_mask_src + color_mask_tgt 65 | 66 | fig = plt.figure() 67 | ax = fig.add_subplot(projection='3d') 68 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1) 69 | ax.axis('scaled') 70 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'}) 71 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'}) 72 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'}) 73 | 74 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_input'.format(iter_counts,i))) 75 | plt.close(fig) 76 | 77 | def debug_vis_output(batch_data, cfg, pred_data, iter_counts): 78 | for i in range(cfg.exp.batch_size): 79 | save_dir = cfg.exp.vis_dir 80 | vis_dir = os.path.join(save_dir, 'vis_A_output') 81 | if not os.path.exists((vis_dir)): 82 | os.mkdir(vis_dir) 83 | 84 | src_pc = batch_data['src_pc'][i] 85 | src_trans = pred_data['src_trans'][i].unsqueeze(1) 86 | src_rot = pred_data['src_rot'][i] 87 | 88 | device = src_pc.device 89 | num_points = src_pc.shape[1] 90 | 91 | src_rot_mat = bgs(src_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3) 92 | src_rot_mat = torch.linalg.inv(src_rot_mat) 93 | 94 | src_trans = src_trans.expand(-1, num_points) 95 | 96 | src_pc = torch.matmul(src_rot_mat.double(), src_pc.double()) 97 | 98 | src_pc = src_pc + src_trans 99 | 100 | tgt_pc = batch_data['tgt_pc'][i] 101 | 102 | tgt_pc_trans = batch_data['predicted_partB_position'][i].unsqueeze(1) 103 | tgt_pc_rot = batch_data['predicted_partB_rotation'][i].unsqueeze(1) 104 | 105 | tgt_rot_mat = bgs(tgt_pc_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3) 106 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat) 107 | tgt_pc_trans = tgt_pc_trans.expand(-1, num_points) 108 | 109 | tgt_pc = torch.matmul(tgt_rot_mat.double(), tgt_pc.double()) 110 | tgt_pc = tgt_pc + tgt_pc_trans 111 | 112 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu()) 113 | color_mask_src = ['r'] * num_points 114 | color_mask_tgt = ['g'] * num_points 115 | total_color = color_mask_src + color_mask_tgt 116 | 117 | fig = plt.figure() 118 | ax = fig.add_subplot(projection='3d') 119 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1) 120 | ax.axis('scaled') 121 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'}) 122 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'}) 123 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'}) 124 | 125 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_output'.format(iter_counts,i))) 126 | plt.close(fig) 127 | 128 | def debug_vis_gt(batch_data, cfg, pred_data, iter_counts): 129 | for i in range(cfg.exp.batch_size): 130 | save_dir = cfg.exp.vis_dir 131 | vis_dir = os.path.join(save_dir, 'vis_A_gt') 132 | if not os.path.exists((vis_dir)): 133 | os.mkdir(vis_dir) 134 | src_pc = batch_data['src_pc'][i] 135 | tgt_pc = batch_data['tgt_pc'][i] 136 | 137 | src_trans = batch_data['src_trans'][i] 138 | src_rot = batch_data['src_rot'][i] 139 | tgt_trans = batch_data['tgt_trans'][i] 140 | 
tgt_rot = batch_data['tgt_rot'][i] 141 | device = src_pc.device 142 | num_points = src_pc.shape[1] 143 | 144 | src_rot_mat = bgs(src_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3) 145 | tgt_rot_mat = bgs(tgt_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3) 146 | src_rot_mat = torch.linalg.inv(src_rot_mat) 147 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat) 148 | 149 | src_trans = src_trans.expand(-1, num_points) 150 | tgt_trans = tgt_trans.expand(-1, num_points) 151 | 152 | src_pc = torch.matmul(src_rot_mat, src_pc) 153 | tgt_pc = torch.matmul(tgt_rot_mat, tgt_pc) 154 | 155 | src_pc = src_pc + src_trans 156 | tgt_pc = tgt_pc + tgt_trans 157 | 158 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu()) 159 | color_mask_src = ['r'] * num_points 160 | color_mask_tgt = ['g'] * num_points 161 | total_color = color_mask_src + color_mask_tgt 162 | 163 | fig = plt.figure() 164 | ax = fig.add_subplot(projection='3d') 165 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1) 166 | ax.axis('scaled') 167 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'}) 168 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'}) 169 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'}) 170 | 171 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_gt'.format(iter_counts,i))) 172 | plt.close(fig) 173 | 174 | class ShapeAssemblyNet_A_vnn(nn.Module): 175 | def __init__(self, cfg, data_features): 176 | super().__init__() 177 | self.cfg = cfg 178 | self.encoder = self.init_encoder() 179 | self.pose_predictor_rot = self.init_pose_predictor_rot() 180 | self.pose_predictor_trans = self.init_pose_predictor_trans() 181 | if self.cfg.model.recon_loss: 182 | self.decoder = self.init_decoder() 183 | self.data_features = data_features 184 | self.iter_counts = 0 185 | self.close_eps = 0.1 186 | self.L2 = nn.MSELoss() 187 | self.R = torch.tensor([[0.26726124, -0.57735027, 0.77151675], 188 | [0.53452248, -0.57735027, -0.6172134], 189 | [0.80178373, 0.57735027, 0.15430335]], dtype=torch.float64).unsqueeze(0) 190 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist() 191 | self.transformer = self.init_transformer() 192 | 193 | def init_transformer(self): 194 | transformer = Transformer(cfg = self.cfg) 195 | return transformer 196 | 197 | def init_encoder(self): 198 | if self.cfg.model.encoderA == 'dgcnn': 199 | encoder = DGCNN(feat_dim=self.cfg.model.pc_feat_dim) 200 | elif self.cfg.model.encoderA == 'vn_dgcnn': 201 | encoder = VN_DGCNN_New(feat_dim=self.cfg.model.pc_feat_dim) 202 | elif self.cfg.model.encoderA == 'pointnet': 203 | encoder = PointNet(feat_dim=self.cfg.model.pc_feat_dim) 204 | return encoder 205 | 206 | def init_pose_predictor_rot(self): 207 | if self.cfg.model.encoderA == 'vn_dgcnn': 208 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3 209 | if self.cfg.model.pose_predictor_rot == 'original': 210 | pose_predictor_rot = Regressor_CR(pc_feat_dim= pc_feat_dim, out_dim=6) 211 | elif self.cfg.model.pose_predictor_rot == 'vn': 212 | pose_predictor_rot = VN_equ_Regressor(pc_feat_dim= pc_feat_dim/3, out_dim=6) 213 | return pose_predictor_rot 214 | 215 | def init_pose_predictor_trans(self): 216 | if self.cfg.model.encoderA == 'vn_dgcnn': 217 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3 218 | if self.cfg.model.pose_predictor_trans == 'original': 219 | pose_predictor_trans = Regressor_CR(pc_feat_dim=pc_feat_dim, out_dim=3) 220 | elif self.cfg.model.pose_predictor_trans == 'vn': 221 | pose_predictor_trans = 
VN_inv_Regressor(pc_feat_dim=pc_feat_dim/3, out_dim=3) 222 | return pose_predictor_trans 223 | 224 | def init_decoder(self): 225 | pc_feat_dim = self.cfg.model.pc_feat_dim 226 | decoder = MLPDecoder(feat_dim=pc_feat_dim, num_points=self.cfg.data.num_pc_points) 227 | return decoder 228 | def configure_optimizers(self): 229 | optimizer = torch.optim.Adam( 230 | self.parameters(), 231 | lr=self.cfg.optimizer.lr, 232 | weight_decay=self.cfg.optimizer.weight_decay, 233 | ) 234 | return optimizer 235 | 236 | def check_equiv(self, x, R, xR, name): 237 | mean_diff = torch.mean(torch.abs(torch.matmul(x, R) - xR)) 238 | if mean_diff > self.close_eps: 239 | print(f'---[Equiv check]--- {name}: {mean_diff}') 240 | return 241 | 242 | def check_inv(self, x, R, xR, name): 243 | mean_diff = torch.mean(torch.abs(x - xR)) 244 | if mean_diff > self.close_eps: 245 | print(f'---[Equiv check]--- {name}: {mean_diff}') 246 | return 247 | 248 | def check_network_property(self, gt_data, pred_data): 249 | with torch.no_grad(): 250 | B, _, N = gt_data['src_pc'].shape 251 | R = self.R.float().repeat(B, 1, 1).to(gt_data['src_pc'].device) 252 | pcs_R = torch.matmul(gt_data['src_pc'].permute(0, 2, 1), R).permute(0, 2, 1) 253 | pred_data_R = self.forward(pcs_R, gt_data['tgt_pc']) 254 | 255 | equiv_feats = pred_data['Fa'] 256 | equiv_feats_R = pred_data_R['Fa'] 257 | self.check_equiv(equiv_feats, R, equiv_feats_R, 'equiv_feats') 258 | 259 | inv_feats = pred_data['Ga'] 260 | inv_feats_R = pred_data_R['Ga'] 261 | self.check_inv(inv_feats, R, inv_feats_R, 'inv_feats') 262 | 263 | if self.cfg.model.pose_predictor_rot == 'vn': 264 | rot = bgs(pred_data['src_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1) 265 | rot_R = bgs(pred_data_R['src_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1) 266 | self.check_equiv(rot, R, rot_R, 'rot') 267 | return 268 | 269 | def _recon_pts(self, Ga, Gb): 270 | global_inv_feat = torch.sum(torch.cat([Ga, Gb], dim=1), dim=1) 271 | recon_pts = self.decoder(global_inv_feat) 272 | return recon_pts 273 | 274 | def forward(self, src_pc, tgt_pc): 275 | batch_size = src_pc.shape[0] 276 | num_points = src_pc.shape[2] 277 | if self.cfg.model.encoderA == 'dgcnn': 278 | src_point_feat = self.encoder(src_pc) 279 | tgt_point_feat = self.encoder(tgt_pc) 280 | 281 | src_feat = torch.mean(src_point_feat, dim=2) 282 | tgt_feat = torch.mean(tgt_point_feat, dim=2) 283 | 284 | if self.cfg.model.encoderA == 'vn_dgcnn': 285 | Fa, Ga = self.encoder(src_pc) 286 | Fb, Gb = self.encoder(tgt_pc) 287 | 288 | src_feat_corr = Fa * Gb[:, :, :3] 289 | 290 | src_feat = Fa 291 | 292 | if self.cfg.model.pose_predictor_rot == 'original': 293 | src_rot = self.pose_predictor_rot(src_feat_corr.reshape(batch_size, -1)) 294 | else: 295 | src_rot = self.pose_predictor_rot(src_feat_corr) 296 | 297 | if self.cfg.model.pose_predictor_trans == 'original': 298 | src_trans = self.pose_predictor_trans(src_feat_corr.reshape(batch_size, -1)) 299 | else: 300 | src_trans = self.pose_predictor_trans(src_feat_corr) 301 | 302 | if self.cfg.model.recon_loss: 303 | recon_pts = self._recon_pts(Ga, Gb) 304 | pred_dict = { 305 | 'src_rot': src_rot, 306 | 'src_trans': src_trans, 307 | } 308 | if self.cfg.model.encoderA == 'vn_dgcnn': 309 | pred_dict['Fa'] = Fa 310 | pred_dict['Ga'] = Ga 311 | if self.cfg.model.recon_loss: 312 | pred_dict['recon_pts'] = recon_pts 313 | return pred_dict 314 | 315 | def compute_point_loss(self, batch_data, pred_data): 316 | src_pc = batch_data['src_pc'].float() 317 | 318 | src_rot_gt = 
self.recover_R_from_6d(batch_data['src_rot'].float())
319 |         src_trans_gt = batch_data['src_trans'].float()
320 | 
321 |         src_rot_pred = self.recover_R_from_6d(pred_data['src_rot'].float())
322 |         src_trans_pred = pred_data['src_trans'].float()
323 | 
324 |         src_trans_pred = src_trans_pred.unsqueeze(2)
325 | 
326 |         transformed_src_pc_pred = src_rot_pred @ src_pc + src_trans_pred
327 |         with torch.no_grad():
328 |             transformed_src_pc_gt = src_rot_gt @ src_pc + src_trans_gt
329 |         src_point_loss = torch.mean(torch.sum((transformed_src_pc_pred - transformed_src_pc_gt) ** 2, axis=1))
330 | 
331 |         point_loss = src_point_loss
332 |         return point_loss
333 | 
334 |     def compute_trans_loss(self, batch_data, pred_data):
335 |         src_trans_gt = batch_data['src_trans'].float()
336 | 
337 |         src_trans_pred = pred_data['src_trans']
338 | 
339 |         src_trans_pred = src_trans_pred.unsqueeze(dim=2)
340 | 
341 |         src_trans_loss = F.l1_loss(src_trans_pred, src_trans_gt)
342 |         trans_loss = src_trans_loss
343 |         return trans_loss
344 | 
345 |     def compute_rot_loss(self, batch_data, pred_data, device=None):  # device is optional; some call sites pass it
346 |         src_R_6d = batch_data['src_rot']
347 | 
348 |         src_R_6d_pred = pred_data['src_rot']
349 | 
350 |         src_rot_loss = torch.mean(utils.get_6d_rot_loss(src_R_6d, src_R_6d_pred))
351 |         rot_loss = src_rot_loss
352 |         return rot_loss
353 | 
354 |     def compute_rot_loss_symmetry(self, batch_data, pred_data, device):
355 |         partA_symmetry_type = batch_data['partA_symmetry_type']
356 | 
357 |         src_R_6d = batch_data['src_rot']
358 | 
359 |         src_R_6d_pred = pred_data['src_rot']
360 |         src_rot_loss = torch.mean(utils.get_6d_rot_loss_symmetry(src_R_6d, src_R_6d_pred, partA_symmetry_type, device))
361 | 
362 |         rot_loss = src_rot_loss
363 |         return rot_loss
364 | 
365 |     def compute_recon_loss(self, batch_data, pred_data):
366 |         recon_pts = pred_data['recon_pts']
367 | 
368 |         src_pc = batch_data['src_pc'].float()
369 |         tgt_pc = batch_data['tgt_pc'].float()
370 |         src_quat_gt = batch_data['src_rot'].float()
371 |         tgt_quat_gt = batch_data['tgt_rot'].float()
372 | 
373 |         src_Rs = utils.bgs(src_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1))
374 |         tgt_Rs = utils.bgs(tgt_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1))
375 | 
376 |         src_trans_gt = batch_data['src_trans'].float()
377 |         tgt_trans_gt = batch_data['tgt_trans'].float()
378 |         with torch.no_grad():
379 |             transformed_src_pc_gt = src_Rs @ src_pc + src_trans_gt
380 |             transformed_tgt_pc_gt = tgt_Rs @ tgt_pc + tgt_trans_gt
381 |             gt_pts = torch.cat([transformed_src_pc_gt, transformed_tgt_pc_gt], dim=2).permute(0, 2, 1)
382 |         self.chamLoss = dist_chamfer_3D.chamfer_3DDist()
383 |         dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts)
384 |         recon_loss = torch.mean(dist1) + torch.mean(dist2)
385 |         return recon_loss
386 | 
387 |     def recover_R_from_6d(self, R_6d):
388 |         R = utils.bgs(R_6d.reshape(-1, 2, 3).permute(0, 2, 1))
389 |         return R
390 | 
391 |     def quat_to_eular(self, quat):
392 |         quat = np.array([quat[1], quat[2], quat[3], quat[0]])
393 | 
394 |         r = R.from_quat(quat)
395 |         euler0 = r.as_euler('xyz', degrees=True)
396 | 
397 |         return euler0
398 | 
399 |     def training_step(self, batch_data, device, batch_idx):
400 |         self.iter_counts += 1
401 |         total_loss, point_loss, rot_loss, trans_loss, recon_loss = self.forward_pass(batch_data, device, mode='train')
402 |         return {"total_loss": total_loss,
403 |                 "point_loss": point_loss,
404 |                 "rot_loss": rot_loss,
405 |                 "trans_loss": trans_loss,
406 |                 "recon_loss": recon_loss,
407 |                 }
408 | 
409 |     def calculate_metrics(self, batch_data, pred_data, device, mode):
410 |         GD = self.compute_rot_loss(batch_data,
pred_data, device) 411 | 412 | rot_error = self.compute_rot_loss(batch_data, pred_data, device) 413 | 414 | src_pc = batch_data['src_pc'].float() 415 | src_quat_gt = batch_data['src_rot'].float() 416 | 417 | src_Rs = utils.bgs(src_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1)) 418 | 419 | src_trans_gt = batch_data['src_trans'].float() 420 | with torch.no_grad(): 421 | transformed_src_pc_gt = src_Rs @ src_pc + src_trans_gt 422 | gt_pts = transformed_src_pc_gt.permute(0, 2, 1) 423 | 424 | pred_R_src = self.recover_R_from_6d(pred_data['src_rot']) 425 | pred_t_src = pred_data['src_trans'].view(-1, 3, 1) 426 | 427 | gt_euler_src = pytorch3d.transforms.matrix_to_euler_angles(src_Rs, convention="XYZ") 428 | 429 | pred_euler_src = pytorch3d.transforms.matrix_to_euler_angles(pred_R_src, convention="XYZ") 430 | 431 | with torch.no_grad(): 432 | transformed_src_pc_pred = pred_R_src @ src_pc + pred_t_src 433 | 434 | recon_pts = transformed_src_pc_pred .permute(0, 2, 1) 435 | 436 | dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts) 437 | PA = torch.mean(dist1, dim=-1) + torch.mean(dist2, dim=-1) 438 | 439 | thre = 0.0001 440 | acc = (PA < thre) 441 | PA_threshold = acc.sum(-1) / acc.shape[0] 442 | 443 | RMSE_T_1 = (pred_t_src - src_trans_gt).pow(2).mean(dim=-1) ** 0.5 444 | 445 | RMSE_T = RMSE_T_1 446 | 447 | dist_a1, dist_a2, idx_a1, idx_a2 = self.chamLoss(transformed_src_pc_gt.permute(0,2,1), transformed_src_pc_pred.permute(0,2,1)) 448 | 449 | CD_1 = torch.mean(dist_a1, dim=-1) + torch.mean(dist_a2, dim=-1) 450 | 451 | return GD, rot_error, RMSE_T, PA_threshold, PA, CD_1 452 | 453 | def forward_pass(self, batch_data, device, mode, vis_idx=-1): 454 | tgt_pc = batch_data['tgt_pc'].float() 455 | 456 | device = tgt_pc.device 457 | 458 | tgt_trans = batch_data['predicted_partB_position'].unsqueeze(-1).repeat(1,1,1024) 459 | tgt_rot = batch_data['predicted_partB_rotation'] 460 | 461 | device = tgt_pc.device 462 | num_points = tgt_pc.shape[1] 463 | batch_size = tgt_pc.shape[0] 464 | 465 | tgt_rot_mat = bgs(tgt_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(batch_size, 3, 3) 466 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat) 467 | 468 | transformed_tgt_pc = torch.matmul(tgt_rot_mat.double(), tgt_pc.double()) 469 | transformed_tgt_pc = transformed_tgt_pc + tgt_trans 470 | transformed_tgt_pc = transformed_tgt_pc.float() 471 | 472 | pred_data = self.forward(batch_data['src_pc'].float(), transformed_tgt_pc) 473 | self.check_network_property(batch_data, pred_data) 474 | point_loss = 0.0 475 | 476 | rot_loss = self.compute_rot_loss(batch_data, pred_data, device) 477 | trans_loss = self.compute_trans_loss(batch_data, pred_data) 478 | if self.cfg.model.recon_loss: 479 | recon_loss = self.compute_recon_loss(batch_data, pred_data) 480 | else: 481 | recon_loss = 0.0 482 | 483 | if vis_idx > -1: 484 | debug_vis_input(batch_data, self.cfg, pred_data, vis_idx) 485 | debug_vis_output(batch_data, self.cfg, pred_data, vis_idx) 486 | debug_vis_gt(batch_data, self.cfg, pred_data, vis_idx) 487 | 488 | total_loss = point_loss + rot_loss + trans_loss + recon_loss 489 | 490 | if mode == 'val': 491 | return (self.calculate_metrics(batch_data, pred_data, device, mode), total_loss, point_loss,rot_loss,trans_loss,recon_loss) 492 | 493 | return total_loss, point_loss, rot_loss, trans_loss, recon_loss -------------------------------------------------------------------------------- /src/shape_assembly/models/train/network_vnn_A_indi.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | import pytorch3d 8 | from pytorch3d.transforms import quaternion_to_matrix 9 | from transforms3d.quaternions import quat2mat 10 | from scipy.spatial.transform import Rotation as R 11 | from mpl_toolkits.mplot3d import Axes3D 12 | 13 | import os 14 | import sys 15 | import copy 16 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | sys.path.append(os.path.join(BASE_DIR, '../../')) 18 | from models.encoder.pointnet import PointNet 19 | from models.encoder.dgcnn import DGCNN 20 | from models.decoder.MLPDecoder import MLPDecoder 21 | from models.encoder.vn_dgcnn import VN_DGCNN, VN_DGCNN_corr, VN_DGCNN_New 22 | from models.baseline.transformer import Transformer 23 | from models.baseline.regressor_CR import Regressor_CR, Regressor_6d, VN_Regressor_6d 24 | import utils 25 | from pdb import set_trace 26 | 27 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | from ..ChamferDistancePytorch.chamfer3D import dist_chamfer_3D 29 | 30 | def bgs(d6s): 31 | bsz = d6s.shape[0] 32 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1) 33 | a2 = d6s[:, :, 1] 34 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1) 35 | b3 = torch.cross(b1, b2, dim=1) 36 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1) 37 | 38 | def debug_vis_input(batch_data, cfg, prd_data, iter_counts): 39 | for i in range(cfg.exp.batch_size): 40 | save_dir = cfg.exp.vis_dir 41 | vis_dir = os.path.join(save_dir, 'vis_A_input') 42 | if not os.path.exists((vis_dir)): 43 | os.mkdir(vis_dir) 44 | 45 | src_pc = batch_data['src_pc'][i] 46 | tgt_pc = batch_data['tgt_pc'][i] 47 | 48 | device = src_pc.device 49 | num_points = src_pc.shape[1] 50 | 51 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu()) 52 | color_mask_src = ['r'] * num_points 53 | color_mask_tgt = ['g'] * num_points 54 | total_color = color_mask_src + color_mask_tgt 55 | 56 | fig = plt.figure() 57 | ax = fig.add_subplot(projection='3d') 58 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1) 59 | ax.axis('scaled') 60 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'}) 61 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'}) 62 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'}) 63 | 64 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_input'.format(iter_counts,i))) 65 | plt.close(fig) 66 | 67 | def debug_vis_output(batch_data, cfg, pred_data, iter_counts): 68 | for i in range(cfg.exp.batch_size): 69 | save_dir = cfg.exp.vis_dir 70 | vis_dir = os.path.join(save_dir, 'vis_A_output') 71 | if not os.path.exists((vis_dir)): 72 | os.mkdir(vis_dir) 73 | 74 | src_pc = batch_data['src_pc'][i] 75 | src_trans = pred_data['src_trans'][i].unsqueeze(1) 76 | src_rot = pred_data['src_rot'][i] 77 | 78 | device = src_pc.device 79 | num_points = src_pc.shape[1] 80 | 81 | src_rot_mat = bgs(src_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3) 82 | src_rot_mat = torch.linalg.inv(src_rot_mat) 83 | 84 | src_trans = src_trans.expand(-1, num_points) 85 | src_pc = torch.matmul(src_rot_mat.double(), src_pc.double()) 86 | 87 | src_pc = src_pc + src_trans 88 | 89 | tgt_pc = batch_data['tgt_pc'][i] 90 | 91 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu()) 92 | color_mask_src = ['r'] * num_points 93 | color_mask_tgt = ['g'] * num_points 94 | total_color = color_mask_src + color_mask_tgt 95 | 96 
| fig = plt.figure() 97 | ax = fig.add_subplot(projection='3d') 98 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1) 99 | ax.axis('scaled') 100 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'}) 101 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'}) 102 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'}) 103 | 104 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_output'.format(iter_counts,i))) 105 | plt.close(fig) 106 | 107 | def debug_vis_gt(batch_data, cfg, pred_data, iter_counts): 108 | for i in range(cfg.exp.batch_size): 109 | save_dir = cfg.exp.vis_dir 110 | vis_dir = os.path.join(save_dir, 'vis_A_gt') 111 | if not os.path.exists((vis_dir)): 112 | os.mkdir(vis_dir) 113 | src_pc = batch_data['src_pc'][i] 114 | tgt_pc = batch_data['tgt_pc'][i] 115 | 116 | src_trans = batch_data['src_trans'][i] 117 | src_rot = batch_data['src_rot'][i] 118 | 119 | device = src_pc.device 120 | num_points = src_pc.shape[1] 121 | 122 | src_rot_mat = bgs(src_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3) 123 | src_rot_mat = torch.linalg.inv(src_rot_mat) 124 | 125 | src_trans = src_trans.expand(-1, num_points) 126 | 127 | src_pc = torch.matmul(src_rot_mat, src_pc) 128 | 129 | src_pc = src_pc + src_trans 130 | 131 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu()) 132 | color_mask_src = ['r'] * num_points 133 | color_mask_tgt = ['g'] * num_points 134 | total_color = color_mask_src + color_mask_tgt 135 | 136 | fig = plt.figure() 137 | ax = fig.add_subplot(projection='3d') 138 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1) 139 | ax.axis('scaled') 140 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'}) 141 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'}) 142 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'}) 143 | 144 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_gt'.format(iter_counts,i))) 145 | plt.close(fig) 146 | 147 | class ShapeAssemblyNet_A_vnn(nn.Module): 148 | def __init__(self, cfg, data_features): 149 | super().__init__() 150 | self.cfg = cfg 151 | self.encoder = self.init_encoder() 152 | self.pose_predictor_rot = self.init_pose_predictor_rot() 153 | self.pose_predictor_trans = self.init_pose_predictor_trans() 154 | if self.cfg.model.recon_loss: 155 | self.decoder = self.init_decoder() 156 | self.data_features = data_features 157 | self.iter_counts = 0 158 | self.close_eps = 0.1 159 | self.L2 = nn.MSELoss() 160 | self.R = torch.tensor([[0.26726124, -0.57735027, 0.77151675], 161 | [0.53452248, -0.57735027, -0.6172134], 162 | [0.80178373, 0.57735027, 0.15430335]], dtype=torch.float64).unsqueeze(0) 163 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist() 164 | self.transformer = self.init_transformer() 165 | 166 | def init_transformer(self): 167 | transformer = Transformer(cfg = self.cfg) 168 | return transformer 169 | 170 | def init_encoder(self): 171 | if self.cfg.model.encoderA == 'dgcnn': 172 | encoder = DGCNN(feat_dim=self.cfg.model.pc_feat_dim) 173 | elif self.cfg.model.encoderA == 'vn_dgcnn': 174 | encoder = VN_DGCNN_New(feat_dim=self.cfg.model.pc_feat_dim) 175 | elif self.cfg.model.encoderA == 'pointnet': 176 | encoder = PointNet(feat_dim=self.cfg.model.pc_feat_dim) 177 | return encoder 178 | 179 | def init_pose_predictor_rot(self): 180 | if self.cfg.model.encoderA == 'vn_dgcnn': 181 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3 182 | if self.cfg.model.pose_predictor_rot == 'original': 183 | 
pose_predictor_rot = Regressor_CR(pc_feat_dim= pc_feat_dim, out_dim=6) 184 | elif self.cfg.model.pose_predictor_rot == 'vn': 185 | pose_predictor_rot = VN_equ_Regressor(pc_feat_dim= pc_feat_dim/3, out_dim=6) 186 | return pose_predictor_rot 187 | 188 | def init_pose_predictor_trans(self): 189 | if self.cfg.model.encoderA == 'vn_dgcnn': 190 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3 191 | if self.cfg.model.pose_predictor_trans == 'original': 192 | pose_predictor_trans = Regressor_CR(pc_feat_dim=pc_feat_dim, out_dim=3) 193 | elif self.cfg.model.pose_predictor_trans == 'vn': 194 | pose_predictor_trans = VN_inv_Regressor(pc_feat_dim=pc_feat_dim/3, out_dim=3) 195 | return pose_predictor_trans 196 | 197 | def init_decoder(self): 198 | pc_feat_dim = self.cfg.model.pc_feat_dim 199 | decoder = MLPDecoder(feat_dim=pc_feat_dim, num_points=self.cfg.data.num_pc_points) 200 | return decoder 201 | def configure_optimizers(self): 202 | optimizer = torch.optim.Adam( 203 | self.parameters(), 204 | lr=self.cfg.optimizer.lr, 205 | weight_decay=self.cfg.optimizer.weight_decay, 206 | ) 207 | return optimizer 208 | 209 | def check_equiv(self, x, R, xR, name): 210 | mean_diff = torch.mean(torch.abs(torch.matmul(x, R) - xR)) 211 | if mean_diff > self.close_eps: 212 | print(f'---[Equiv check]--- {name}: {mean_diff}') 213 | return 214 | 215 | def check_inv(self, x, R, xR, name): 216 | mean_diff = torch.mean(torch.abs(x - xR)) 217 | if mean_diff > self.close_eps: 218 | print(f'---[Equiv check]--- {name}: {mean_diff}') 219 | return 220 | 221 | def check_network_property(self, gt_data, pred_data): 222 | with torch.no_grad(): 223 | B, _, N = gt_data['src_pc'].shape 224 | R = self.R.float().repeat(B, 1, 1).to(gt_data['src_pc'].device) 225 | pcs_R = torch.matmul(gt_data['src_pc'].permute(0, 2, 1), R).permute(0, 2, 1) 226 | pred_data_R = self.forward(pcs_R, gt_data['tgt_pc']) 227 | 228 | equiv_feats = pred_data['Fa'] 229 | equiv_feats_R = pred_data_R['Fa'] 230 | self.check_equiv(equiv_feats, R, equiv_feats_R, 'equiv_feats') 231 | 232 | inv_feats = pred_data['Ga'] 233 | inv_feats_R = pred_data_R['Ga'] 234 | self.check_inv(inv_feats, R, inv_feats_R, 'inv_feats') 235 | 236 | if self.cfg.model.pose_predictor_rot == 'vn': 237 | rot = bgs(pred_data['src_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1) 238 | rot_R = bgs(pred_data_R['src_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1) 239 | self.check_equiv(rot, R, rot_R, 'rot') 240 | return 241 | 242 | def _recon_pts(self, Ga, Gb): 243 | global_inv_feat = torch.sum(torch.cat([Ga, Gb], dim=1), dim=1) 244 | recon_pts = self.decoder(global_inv_feat) 245 | return recon_pts 246 | 247 | def forward(self, src_pc, tgt_pc): 248 | batch_size = src_pc.shape[0] 249 | num_points = src_pc.shape[2] 250 | if self.cfg.model.encoderA == 'dgcnn': 251 | src_point_feat = self.encoder(src_pc) 252 | tgt_point_feat = self.encoder(tgt_pc) 253 | 254 | src_feat = torch.mean(src_point_feat, dim=2) 255 | tgt_feat = torch.mean(tgt_point_feat, dim=2) 256 | 257 | if self.cfg.model.encoderA == 'vn_dgcnn': 258 | Fa, Ga = self.encoder(src_pc) 259 | Fb, Gb = self.encoder(tgt_pc) 260 | 261 | src_feat_corr = Fa * Gb[:, :, :3] 262 | 263 | src_feat = Fa 264 | 265 | if self.cfg.model.pose_predictor_rot == 'original': 266 | src_rot = self.pose_predictor_rot(src_feat_corr.reshape(batch_size, -1)) 267 | else: 268 | src_rot = self.pose_predictor_rot(src_feat_corr) 269 | 270 | if self.cfg.model.pose_predictor_trans == 'original': 271 | src_trans = 
self.pose_predictor_trans(src_feat_corr.reshape(batch_size, -1)) 272 | else: 273 | src_trans = self.pose_predictor_trans(src_feat_corr) 274 | 275 | if self.cfg.model.recon_loss: 276 | recon_pts = self._recon_pts(Ga, Gb) 277 | pred_dict = { 278 | 'src_rot': src_rot, 279 | 'src_trans': src_trans, 280 | } 281 | if self.cfg.model.encoderA == 'vn_dgcnn': 282 | pred_dict['Fa'] = Fa 283 | pred_dict['Ga'] = Ga 284 | if self.cfg.model.recon_loss: 285 | pred_dict['recon_pts'] = recon_pts 286 | return pred_dict 287 | 288 | def compute_point_loss(self, batch_data, pred_data): 289 | src_pc = batch_data['src_pc'].float() 290 | 291 | src_rot_gt = self.recover_R_from_6d(batch_data['src_rot'].float()) 292 | src_trans_gt = batch_data['src_trans'].float() 293 | 294 | src_rot_pred = self.recover_R_from_6d(pred_data['src_rot'].float()) 295 | src_trans_pred = pred_data['src_trans'].float() 296 | 297 | src_trans_pred = src_trans_pred.unsqueeze(2) 298 | 299 | transformed_src_pc_pred = src_rot_pred @ src_pc + src_trans_pred 300 | with torch.no_grad(): 301 | transformed_src_pc_gt = src_rot_gt @ src_pc + src_trans_gt 302 | 303 | src_point_loss = torch.mean(torch.sum((transformed_src_pc_pred - transformed_src_pc_gt) ** 2, axis=1)) 304 | 305 | point_loss = src_point_loss 306 | return point_loss 307 | 308 | def compute_trans_loss(self, batch_data, pred_data): 309 | src_trans_gt = batch_data['src_trans'].float() 310 | 311 | src_trans_pred = pred_data['src_trans'] 312 | 313 | src_trans_pred = src_trans_pred.unsqueeze(dim=2) 314 | 315 | src_trans_loss = F.l1_loss(src_trans_pred, src_trans_gt) 316 | trans_loss = src_trans_loss 317 | return trans_loss 318 | 319 | def compute_rot_loss(self, batch_data, pred_data): 320 | src_R_6d = batch_data['src_rot'] 321 | 322 | src_R_6d_pred = pred_data['src_rot'] 323 | 324 | src_rot_loss = torch.mean(utils.get_6d_rot_loss(src_R_6d, src_R_6d_pred)) 325 | rot_loss = src_rot_loss 326 | return rot_loss 327 | 328 | def compute_rot_loss_symmetry(self, batch_data, pred_data, device): 329 | partA_symmetry_type = batch_data['partA_symmetry_type'] 330 | 331 | src_R_6d = batch_data['src_rot'] 332 | 333 | src_R_6d_pred = pred_data['src_rot'] 334 | src_rot_loss = torch.mean(utils.get_6d_rot_loss_symmetry(src_R_6d, src_R_6d_pred, partA_symmetry_type, device)) 335 | 336 | rot_loss = src_rot_loss 337 | return rot_loss 338 | 339 | def compute_recon_loss(self, batch_data, pred_data): 340 | recon_pts = pred_data['recon_pts'] 341 | 342 | src_pc = batch_data['src_pc'].float() 343 | tgt_pc = batch_data['tgt_pc'].float() 344 | src_quat_gt = batch_data['src_rot'].float() 345 | tgt_quat_gt = batch_data['tgt_rot'].float() 346 | 347 | src_Rs = utils.bgs(src_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1)) 348 | tgt_Rs = utils.bgs(tgt_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1)) 349 | 350 | src_trans_gt = batch_data['src_trans'].float() 351 | tgt_trans_gt = batch_data['tgt_trans'].float() 352 | with torch.no_grad(): 353 | transformed_src_pc_gt = src_Rs @ src_pc + src_trans_gt 354 | transformed_tgt_pc_gt = tgt_Rs @ tgt_pc + tgt_trans_gt 355 | gt_pts = torch.cat([transformed_src_pc_gt, transformed_tgt_pc_gt], dim=2).permute(0, 2, 1) 356 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist() 357 | dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts) 358 | recon_loss = torch.mean(dist1) + torch.mean(dist2) 359 | return recon_loss 360 | 361 | def recover_R_from_6d(self, R_6d): 362 | R = utils.bgs(R_6d.reshape(-1, 2, 3).permute(0, 2, 1)) 363 | return R 364 | 365 | def quat_to_eular(self, quat): 366 | quat = 
np.array([quat[1], quat[2], quat[3], quat[0]]) 367 | 368 | r = R.from_quat(quat) 369 | euler0 = r.as_euler('xyz', degrees=True) 370 | 371 | return euler0 372 | 373 | def training_step(self, batch_data, device, batch_idx): 374 | self.iter_counts += 1 375 | partA_position, partA_rotation, total_loss, point_loss, rot_loss, trans_loss, recon_loss = self.forward_pass(batch_data, device, mode='train') 376 | return {"total_loss": total_loss, 377 | "point_loss": point_loss, 378 | "rot_loss": rot_loss, 379 | "trans_loss": trans_loss, 380 | "recon_loss": recon_loss, 381 | "predicted_partA_position": partA_position, 382 | "predicted_partA_rotation": partA_rotation 383 | } 384 | 385 | def calculate_metrics(self, batch_data, pred_data, device, mode): 386 | GD = self.compute_rot_loss(batch_data, pred_data) 387 | 388 | rot_error = self.compute_rot_loss(batch_data, pred_data) 389 | 390 | src_pc = batch_data['src_pc'].float() 391 | src_quat_gt = batch_data['src_rot'].float() 392 | 393 | src_Rs = utils.bgs(src_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1)) 394 | 395 | src_trans_gt = batch_data['src_trans'].float() 396 | with torch.no_grad(): 397 | transformed_src_pc_gt = src_Rs @ src_pc + src_trans_gt 398 | gt_pts = transformed_src_pc_gt.permute(0, 2, 1) 399 | 400 | pred_R_src = self.recover_R_from_6d(pred_data['src_rot']) 401 | pred_t_src = pred_data['src_trans'].view(-1, 3, 1) 402 | 403 | gt_euler_src = pytorch3d.transforms.matrix_to_euler_angles(src_Rs, convention="XYZ") 404 | 405 | pred_euler_src = pytorch3d.transforms.matrix_to_euler_angles(pred_R_src, convention="XYZ") 406 | 407 | with torch.no_grad(): 408 | transformed_src_pc_pred = pred_R_src @ src_pc + pred_t_src 409 | 410 | recon_pts = transformed_src_pc_pred .permute(0, 2, 1) 411 | 412 | dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts) 413 | PA = torch.mean(dist1, dim=-1) + torch.mean(dist2, dim=-1) 414 | 415 | thre = 0.0001 416 | acc = (PA < thre) 417 | PA_threshold = acc.sum(-1) / acc.shape[0] 418 | 419 | RMSE_T_1 = (pred_t_src - src_trans_gt).pow(2).mean(dim=-1) ** 0.5 420 | 421 | RMSE_T = RMSE_T_1 422 | 423 | dist_a1, dist_a2, idx_a1, idx_a2 = self.chamLoss(transformed_src_pc_gt.permute(0,2,1), transformed_src_pc_pred.permute(0,2,1)) 424 | 425 | CD_1 = torch.mean(dist_a1, dim=-1) + torch.mean(dist_a2, dim=-1) 426 | 427 | return GD, rot_error, RMSE_T, PA_threshold, PA, CD_1 428 | 429 | def forward_pass(self, batch_data, device, mode, vis_idx=-1): 430 | tgt_pc = batch_data['tgt_pc'].float() 431 | device = tgt_pc.device 432 | num_points = tgt_pc.shape[1] 433 | batch_size = tgt_pc.shape[0] 434 | 435 | pred_data = self.forward(batch_data['src_pc'].float(), batch_data['tgt_pc'].float()) 436 | self.check_network_property(batch_data, pred_data) 437 | point_loss = 0.0 438 | 439 | rot_loss = self.compute_rot_loss(batch_data, pred_data) 440 | trans_loss = self.compute_trans_loss(batch_data, pred_data) 441 | if self.cfg.model.recon_loss: 442 | recon_loss = self.compute_recon_loss(batch_data, pred_data) 443 | else: 444 | recon_loss = 0.0 445 | 446 | if vis_idx > -1: 447 | debug_vis_input(batch_data, self.cfg, pred_data, vis_idx) 448 | debug_vis_output(batch_data, self.cfg, pred_data, vis_idx) 449 | debug_vis_gt(batch_data, self.cfg, pred_data, vis_idx) 450 | 451 | total_loss = point_loss + rot_loss + trans_loss + recon_loss 452 | 453 | if mode == 'val': 454 | return (self.calculate_metrics(batch_data, pred_data, device, mode), total_loss, point_loss,rot_loss,trans_loss,recon_loss) 455 | 456 | return pred_data['src_trans'], pred_data['src_rot'], 
total_loss, point_loss, rot_loss, trans_loss, recon_loss -------------------------------------------------------------------------------- /src/shape_assembly/models/train/network_vnn_B.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | import pytorch3d 8 | from pytorch3d.transforms import quaternion_to_matrix 9 | from transforms3d.quaternions import quat2mat 10 | from scipy.spatial.transform import Rotation as R 11 | from mpl_toolkits.mplot3d import Axes3D 12 | 13 | import os 14 | import sys 15 | import copy 16 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | sys.path.append(os.path.join(BASE_DIR, '../../')) 18 | from models.encoder.pointnet import PointNet 19 | from models.encoder.dgcnn import DGCNN 20 | from models.decoder.MLPDecoder import MLPDecoder 21 | from models.encoder.vn_dgcnn import VN_DGCNN, VN_DGCNN_corr, VN_DGCNN_New, DGCNN_New 22 | from models.baseline.regressor_CR import Regressor_CR, Regressor_6d, VN_Regressor_6d 23 | import utils 24 | from pdb import set_trace 25 | 26 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 27 | # print("BASE_DIR: ", BASE_DIR) 28 | from ..ChamferDistancePytorch.chamfer3D import dist_chamfer_3D 29 | # from chamfer_distance import ChamferDistance as dist_chamfer_3D 30 | 31 | def bgs(d6s): 32 | bsz = d6s.shape[0] 33 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1) 34 | a2 = d6s[:, :, 1] 35 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1) 36 | b3 = torch.cross(b1, b2, dim=1) 37 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1) 38 | 39 | def debug_vis_input(batch_data, cfg, prd_data, iter_counts): 40 | for i in range(cfg.exp.batch_size): 41 | # print("i is,", i) 42 | save_dir = cfg.exp.vis_dir 43 | vis_dir = os.path.join(save_dir, 'vis_B_input') 44 | if not os.path.exists((vis_dir)): 45 | os.mkdir(vis_dir) 46 | # src_pc = batch_data['src_pc'][i] # (3, 1024) 47 | tgt_pc = batch_data['tgt_pc'][i] # (3, 1024) 48 | 49 | device = tgt_pc.device 50 | num_points = tgt_pc.shape[1] 51 | 52 | total_pc = np.array(tgt_pc.detach().cpu()) 53 | color_mask_tgt = ['g'] * num_points 54 | 55 | fig = plt.figure() 56 | ax = fig.add_subplot(projection='3d') 57 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=color_mask_tgt, s=10, alpha=0.9) 58 | ax.axis('scaled') 59 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'}) 60 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'}) 61 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'}) 62 | 63 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_input'.format(iter_counts,i))) 64 | plt.close(fig) 65 | 66 | def debug_vis_output(batch_data, cfg, pred_data, iter_counts): 67 | 68 | # for every data in batch_size 69 | for i in range(cfg.exp.batch_size): 70 | save_dir = cfg.exp.vis_dir 71 | vis_dir = os.path.join(save_dir, 'vis_B_output') 72 | if not os.path.exists((vis_dir)): 73 | os.mkdir(vis_dir) 74 | 75 | tgt_pc = batch_data['tgt_pc'][i] # (3, 1024) 76 | 77 | tgt_trans = pred_data['tgt_trans'][i].unsqueeze(1) # (3) 78 | tgt_rot = pred_data['tgt_rot'][i] # (6) 79 | 80 | device = tgt_pc.device 81 | num_points = tgt_pc.shape[1] 82 | tgt_rot_mat = bgs(tgt_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3) 83 | 84 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat) 85 | 86 | tgt_trans = tgt_trans.expand(-1, num_points) # (3, 1024) 87 | 88 | 
tgt_pc = torch.matmul(tgt_rot_mat.double(), tgt_pc.double()) # (3, 1024) 89 | 90 | tgt_pc = tgt_pc + tgt_trans 91 | 92 | total_pc = np.array(tgt_pc.detach().cpu()) 93 | color_mask_tgt = ['g'] * num_points 94 | total_color = color_mask_tgt 95 | 96 | fig = plt.figure() 97 | ax = fig.add_subplot(projection='3d') 98 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=0.9) 99 | ax.axis('scaled') 100 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'}) 101 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'}) 102 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'}) 103 | 104 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_output'.format(iter_counts,i))) 105 | plt.close(fig) 106 | 107 | 108 | def debug_vis_gt(batch_data, cfg, pred_data, iter_counts): 109 | 110 | for i in range(cfg.exp.batch_size): 111 | save_dir = cfg.exp.vis_dir 112 | vis_dir = os.path.join(save_dir, 'vis_B_gt') 113 | if not os.path.exists((vis_dir)): 114 | os.mkdir(vis_dir) 115 | tgt_pc = batch_data['tgt_pc'][i] # (3, 1024) 116 | tgt_trans = batch_data['tgt_trans'][i] # (3, 1) 117 | tgt_rot = batch_data['tgt_rot'][i] # (4) 118 | device = tgt_pc.device 119 | num_points = tgt_pc.shape[1] 120 | 121 | tgt_rot_mat = bgs(tgt_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3) 122 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat) 123 | 124 | tgt_trans = tgt_trans.expand(-1, num_points) # (3, 1024) 125 | 126 | tgt_pc = torch.matmul(tgt_rot_mat, tgt_pc) # (3, 1024) 127 | 128 | tgt_pc = tgt_pc + tgt_trans 129 | 130 | total_pc = np.array(tgt_pc.detach().cpu()) 131 | color_mask_tgt = ['g'] * num_points 132 | total_color = color_mask_tgt 133 | 134 | fig = plt.figure() 135 | ax = fig.add_subplot(projection='3d') 136 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=0.9) 137 | ax.axis('scaled') 138 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'}) 139 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'}) 140 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'}) 141 | 142 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_gt'.format(iter_counts,i))) 143 | plt.close(fig) 144 | 145 | 146 | class ShapeAssemblyNet_B_vnn(nn.Module): 147 | 148 | def __init__(self, cfg, data_features): 149 | super().__init__() 150 | self.cfg = cfg 151 | self.encoder = self.init_encoder() 152 | 153 | self.encoder_dgcnn = DGCNN_New(feat_dim=cfg.model.pc_feat_dim) 154 | 155 | self.pose_predictor_rot = self.init_pose_predictor_rot() 156 | self.pose_predictor_trans = self.init_pose_predictor_trans() 157 | if self.cfg.model.recon_loss: 158 | self.decoder = self.init_decoder() 159 | self.data_features = data_features 160 | 161 | self.iter_counts = 0 162 | self.close_eps = 0.1 163 | self.L2 = nn.MSELoss() 164 | self.R = torch.tensor([[0.26726124, -0.57735027, 0.77151675], 165 | [0.53452248, -0.57735027, -0.6172134], 166 | [0.80178373, 0.57735027, 0.15430335]], dtype=torch.float64).unsqueeze(0) 167 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist() 168 | 169 | self.mlp_color = nn.Sequential( 170 | nn.Linear(512*2*3, 1024)) 171 | 172 | 173 | def init_encoder(self): 174 | if self.cfg.model.encoderB == 'dgcnn': 175 | encoder = DGCNN(feat_dim=self.cfg.model.pc_feat_dim) 176 | elif self.cfg.model.encoderB == 'vn_dgcnn': 177 | encoder = VN_DGCNN_New(feat_dim=self.cfg.model.pc_feat_dim) 178 | elif self.cfg.model.encoderB == 'pointnet': 179 | encoder = PointNet(feat_dim=self.cfg.model.pc_feat_dim) 180 | return encoder 181 | 182 | def 
init_pose_predictor_rot(self): 183 | if self.cfg.model.encoderB == 'vn_dgcnn': 184 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3 185 | if self.cfg.model.pose_predictor_rot == 'original': 186 | pose_predictor_rot = Regressor_CR(pc_feat_dim= pc_feat_dim, out_dim=6) 187 | elif self.cfg.model.pose_predictor_rot == 'vn': 188 | pose_predictor_rot = VN_equ_Regressor(pc_feat_dim= pc_feat_dim/3, out_dim=6) 189 | 190 | return pose_predictor_rot 191 | 192 | def init_pose_predictor_trans(self): 193 | if self.cfg.model.encoderB == 'vn_dgcnn': 194 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3 195 | if self.cfg.model.pose_predictor_trans == 'original': 196 | pose_predictor_trans = Regressor_CR(pc_feat_dim=pc_feat_dim, out_dim=3) 197 | elif self.cfg.model.pose_predictor_trans == 'vn': 198 | pose_predictor_trans = VN_inv_Regressor(pc_feat_dim=pc_feat_dim/3, out_dim=3) 199 | return pose_predictor_trans 200 | 201 | def init_decoder(self): 202 | pc_feat_dim = self.cfg.model.pc_feat_dim 203 | decoder = MLPDecoder(feat_dim=pc_feat_dim, num_points=self.cfg.data.num_pc_points) 204 | return decoder 205 | 206 | 207 | def configure_optimizers(self): 208 | optimizer = torch.optim.Adam( 209 | self.parameters(), 210 | lr=self.cfg.optimizer.lr, 211 | weight_decay=self.cfg.optimizer.weight_decay, 212 | ) 213 | return optimizer 214 | 215 | 216 | def check_equiv(self, x, R, xR, name): 217 | mean_diff = torch.mean(torch.abs(torch.matmul(x, R) - xR)) 218 | if mean_diff > self.close_eps: 219 | print(f'---[Equiv check]--- {name}: {mean_diff}') 220 | return 221 | 222 | def check_inv(self, x, R, xR, name): 223 | mean_diff = torch.mean(torch.abs(x - xR)) 224 | if mean_diff > self.close_eps: 225 | print(f'---[Equiv check]--- {name}: {mean_diff}') 226 | return 227 | 228 | def check_network_property(self, gt_data, pred_data): 229 | with torch.no_grad(): 230 | B, _, N = gt_data['src_pc'].shape 231 | R = self.R.float().repeat(B, 1, 1).to(gt_data['src_pc'].device) 232 | pcs_R = torch.matmul(gt_data['src_pc'].permute(0, 2, 1), R).permute(0, 2, 1) 233 | pred_data_R = self.forward(pcs_R, gt_data['tgt_pc']) 234 | 235 | equiv_feats = pred_data['Fa'] 236 | equiv_feats_R = pred_data_R['Fa'] 237 | self.check_equiv(equiv_feats, R, equiv_feats_R, 'equiv_feats') 238 | 239 | inv_feats = pred_data['Ga'] 240 | inv_feats_R = pred_data_R['Ga'] 241 | self.check_inv(inv_feats, R, inv_feats_R, 'inv_feats') 242 | 243 | if self.cfg.model.pose_predictor_rot == 'vn': 244 | rot = bgs(pred_data['src_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1) 245 | rot_R = bgs(pred_data_R['src_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1) 246 | self.check_equiv(rot, R, rot_R, 'rot') 247 | return 248 | 249 | def _recon_pts(self, Ga, Gb): 250 | global_inv_feat = torch.sum(torch.cat([Ga, Gb], dim=1), dim=1) 251 | recon_pts = self.decoder(global_inv_feat) 252 | return recon_pts 253 | 254 | def forward(self, tgt_pc): 255 | batch_size = tgt_pc.shape[0] 256 | num_points = tgt_pc.shape[2] 257 | if self.cfg.model.encoderB == 'dgcnn': 258 | tgt_point_feat = self.encoder(tgt_pc) # (batch_size, pc_feat_dim(512), num_point(1024)) 259 | 260 | tgt_feat = torch.mean(tgt_point_feat, dim=2) 261 | 262 | if self.cfg.model.encoderB == 'vn_dgcnn': 263 | Fb, Gb = self.encoder(tgt_pc) 264 | device = tgt_pc.device 265 | 266 | tgt_feat = Fb 267 | 268 | if self.cfg.model.pose_predictor_rot == 'original': 269 | tgt_rot = self.pose_predictor_rot(tgt_feat.reshape(batch_size, -1)) 270 | else: 271 | tgt_rot = self.pose_predictor_rot(tgt_feat) 272 | 273 | if 
self.cfg.model.pose_predictor_trans == 'original':
274 |             tgt_trans = self.pose_predictor_trans(tgt_feat.reshape(batch_size, -1))
275 |         else:
276 |             tgt_trans = self.pose_predictor_trans(tgt_feat)
277 | 
278 |         pred_dict = {
279 |             'tgt_rot': tgt_rot,
280 |             'tgt_trans': tgt_trans,
281 |         }
282 |         return pred_dict
283 | 
284 |     def compute_point_loss(self, batch_data, pred_data):
285 |         tgt_pc = batch_data['tgt_pc'].float()  # batch x 3 x 1024
286 |         tgt_rot_gt = self.recover_R_from_6d(batch_data['tgt_rot'].float())
287 |         tgt_trans_gt = batch_data['tgt_trans'].float()  # batch x 3 x 1
288 |         tgt_rot_pred = self.recover_R_from_6d(pred_data['tgt_rot'].float())
289 |         tgt_trans_pred = pred_data['tgt_trans'].float()
290 | 
291 |         tgt_trans_pred = tgt_trans_pred.unsqueeze(2)
292 | 
293 |         # Target point loss: predicted pose on one side, GT pose on the other
294 |         transformed_tgt_pc_pred = tgt_rot_pred @ tgt_pc + tgt_trans_pred  # batch x 3 x 1024
295 |         with torch.no_grad():
296 |             transformed_tgt_pc_gt = tgt_rot_gt @ tgt_pc + tgt_trans_gt  # batch x 3 x 1024
297 |         tgt_point_loss = torch.mean(torch.sum((transformed_tgt_pc_pred - transformed_tgt_pc_gt) ** 2, axis=1))
298 | 
299 |         # Point loss
300 |         point_loss = tgt_point_loss
301 |         return point_loss
302 | 
303 |     def compute_trans_loss(self, batch_data, pred_data):
304 |         tgt_trans_gt = batch_data['tgt_trans'].float()  # batch x 3 x 1
305 |         tgt_trans_pred = pred_data['tgt_trans']  # batch x 3 x 1
306 |         tgt_trans_pred = tgt_trans_pred.unsqueeze(dim=2)
307 |         tgt_trans_loss = F.l1_loss(tgt_trans_pred, tgt_trans_gt)
308 |         trans_loss = tgt_trans_loss
309 |         return trans_loss
310 | 
311 |     def compute_rot_loss(self, batch_data, pred_data):
312 |         tgt_R_6d = batch_data['tgt_rot']
313 |         tgt_R_6d_pred = pred_data['tgt_rot']
314 |         tgt_rot_loss = torch.mean(utils.get_6d_rot_loss(tgt_R_6d, tgt_R_6d_pred))
315 |         rot_loss = tgt_rot_loss
316 |         return rot_loss
317 | 
318 |     def compute_rot_loss_symmetry(self, batch_data, pred_data, device):
319 | 
320 |         partB_symmetry_type = batch_data['partB_symmetry_type']
321 | 
322 |         tgt_R_6d = batch_data['tgt_rot']
323 | 
324 |         tgt_R_6d_pred = pred_data['tgt_rot']
325 |         tgt_rot_loss = torch.mean(utils.get_6d_rot_loss_symmetry(tgt_R_6d, tgt_R_6d_pred, partB_symmetry_type, device))
326 | 
327 |         rot_loss = tgt_rot_loss
328 |         return rot_loss
329 | 
330 |     def compute_recon_loss(self, batch_data, pred_data):
331 | 
332 |         recon_pts = pred_data['recon_pts']  # batch x 1024 x 3
333 | 
334 |         src_pc = batch_data['src_pc'].float()  # batch x 3 x 1024
335 |         tgt_pc = batch_data['tgt_pc'].float()  # batch x 3 x 1024
336 |         src_quat_gt = batch_data['src_rot'].float()
337 |         tgt_quat_gt = batch_data['tgt_rot'].float()
338 | 
339 |         src_Rs = utils.bgs(src_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1))  # batch x 3 x 3
340 |         tgt_Rs = utils.bgs(tgt_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1))  # batch x 3 x 3
341 | 
342 |         src_trans_gt = batch_data['src_trans'].float()  # batch x 3 x 1
343 |         tgt_trans_gt = batch_data['tgt_trans'].float()  # batch x 3 x 1
344 |         with torch.no_grad():
345 |             transformed_src_pc_gt = src_Rs @ src_pc + src_trans_gt  # batch x 3 x 1024
346 |             transformed_tgt_pc_gt = tgt_Rs @ tgt_pc + tgt_trans_gt  # batch x 3 x 1024
347 |             gt_pts = torch.cat([transformed_src_pc_gt, transformed_tgt_pc_gt], dim=2).permute(0, 2, 1)  # batch x 2048 x 3
348 |         self.chamLoss = dist_chamfer_3D.chamfer_3DDist()
349 |         dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts)
350 |         recon_loss = torch.mean(dist1) + torch.mean(dist2)
351 | 
352 |         return recon_loss
353 | 
354 |     def recover_R_from_6d(self, R_6d):
355 |         R = utils.bgs(R_6d.reshape(-1, 2,
3).permute(0, 2, 1)) 356 | return R 357 | 358 | def quat_to_eular(self, quat): 359 | quat = np.array([quat[1], quat[2], quat[3], quat[0]]) 360 | 361 | r = R.from_quat(quat) 362 | euler0 = r.as_euler('xyz', degrees=True) 363 | 364 | return euler0 365 | 366 | def training_step(self, batch_data, device, batch_idx): 367 | self.iter_counts += 1 368 | partB_position, partB_rotation, total_loss, point_loss, rot_loss, trans_loss, recon_loss = self.forward_pass(batch_data, device, mode='train') 369 | return {"total_loss": total_loss, 370 | "point_loss": point_loss, 371 | "rot_loss": rot_loss, 372 | "trans_loss": trans_loss, 373 | "recon_loss": recon_loss, 374 | 'predicted_partB_position': partB_position, 375 | 'predicted_partB_rotation': partB_rotation 376 | } 377 | 378 | 379 | def calculate_metrics(self, batch_data, pred_data, device, mode): 380 | GD = self.compute_rot_loss(batch_data, pred_data) 381 | 382 | rot_error = self.compute_rot_loss(batch_data, pred_data) 383 | 384 | tgt_pc = batch_data['tgt_pc'].float() # batch x 3 x 1024 385 | tgt_quat_gt = batch_data['tgt_rot'].float() 386 | 387 | tgt_Rs = utils.bgs(tgt_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1)) # batch x 3 x 3 388 | 389 | tgt_trans_gt = batch_data['tgt_trans'].float() # batch x 3 x 1 390 | with torch.no_grad(): 391 | transformed_tgt_pc_gt = tgt_Rs @ tgt_pc + tgt_trans_gt # batch x 3 x 1024 392 | gt_pts = transformed_tgt_pc_gt.permute(0, 2, 1) # batch x 1024 x 3 393 | 394 | pred_R_tgt = self.recover_R_from_6d(pred_data['tgt_rot']) 395 | pred_t_tgt = pred_data['tgt_trans'].view(-1, 3, 1) 396 | 397 | gt_euler_tgt = pytorch3d.transforms.matrix_to_euler_angles(tgt_Rs, convention="XYZ") 398 | 399 | pred_euler_tgt = pytorch3d.transforms.matrix_to_euler_angles(pred_R_tgt, convention="XYZ") 400 | 401 | with torch.no_grad(): 402 | transformed_tgt_pc_pred = pred_R_tgt @ tgt_pc + pred_t_tgt # batch x 3 x 1024 403 | 404 | recon_pts = transformed_tgt_pc_pred.permute(0, 2, 1) # batch x 2048 x 3 405 | 406 | dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts) 407 | PA = torch.mean(dist1, dim=-1) + torch.mean(dist2, dim=-1) 408 | 409 | thre = 0.0001 410 | acc = (PA < thre) 411 | PA_threshold = acc.sum(-1) / acc.shape[0] 412 | 413 | RMSE_T_2 = (pred_t_tgt - tgt_trans_gt).pow(2).mean(dim=-1) ** 0.5 414 | RMSE_T = RMSE_T_2 415 | 416 | dist_b1, dist_b2, idx_b1, idx_b2 = self.chamLoss(transformed_tgt_pc_gt.permute(0,2,1), transformed_tgt_pc_pred.permute(0,2,1)) 417 | CD_2 = torch.mean(dist_b1, dim=-1) + torch.mean(dist_b2, dim=-1) 418 | 419 | return GD, rot_error, RMSE_T, PA_threshold, PA, CD_2 420 | 421 | def forward_pass(self, batch_data, device, mode, vis_idx=-1): 422 | 423 | pred_data = self.forward(batch_data['tgt_pc'].float()) 424 | if self.cfg.model.point_loss: 425 | point_loss = self.compute_point_loss(batch_data, pred_data) 426 | else: 427 | point_loss = 0.0 428 | 429 | rot_loss = self.compute_rot_loss(batch_data, pred_data) 430 | trans_loss = self.compute_trans_loss(batch_data, pred_data) 431 | recon_loss = 0.0 432 | 433 | if vis_idx > -1: 434 | debug_vis_input(batch_data, self.cfg, pred_data, vis_idx) 435 | debug_vis_output(batch_data, self.cfg, pred_data, vis_idx) 436 | debug_vis_gt(batch_data, self.cfg, pred_data, vis_idx) 437 | 438 | total_loss = point_loss + rot_loss + trans_loss + recon_loss 439 | 440 | if mode == 'val': 441 | return (self.calculate_metrics(batch_data, pred_data, device, mode), pred_data['tgt_trans'], pred_data['tgt_rot'], total_loss, point_loss,rot_loss,trans_loss,recon_loss) 442 | 443 | # Total loss 444 | return 
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/pose_estimator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class PointNet(nn.Module):
6 |     def __init__(self, out_channels=(32, 64, 128), train_with_norm=True):
7 |         super(PointNet, self).__init__()
8 |         self.layers = nn.ModuleList()
9 |         in_channels = 3
10 |         for out_channel in out_channels:
11 |             self.layers.append(nn.Conv1d(in_channels, out_channel, 1))
12 |             self.layers.append(nn.BatchNorm1d(out_channel) if train_with_norm else nn.Identity())
13 |             self.layers.append(nn.ReLU())
14 |             in_channels = out_channel
15 |         self.global_pool = nn.AdaptiveMaxPool1d(1)
16 | 
17 |     def forward(self, x):
18 |         for layer in self.layers:
19 |             x = layer(x)
20 |         x = self.global_pool(x)
21 |         x = x.squeeze(-1)
22 |         return x
23 | 
24 | class PoseClassifier(nn.Module):
25 |     def __init__(self, pointnet_out_dim=128, pose_dim=6, hidden_dims=(512, 256, 128)):
26 |         super(PoseClassifier, self).__init__()
27 |         self.pointnet = PointNet(out_channels=(32, 64, pointnet_out_dim))
28 |         input_dim = pointnet_out_dim + pose_dim
29 |         layers = []
30 |         for hidden_dim in hidden_dims:
31 |             layers.append(nn.Linear(input_dim, hidden_dim))
32 |             layers.append(nn.ReLU())
33 |             input_dim = hidden_dim
34 |         layers.append(nn.Linear(input_dim, 1))
35 |         self.classifier = nn.Sequential(*layers)
36 | 
37 |     def forward(self, point_cloud, poses):
38 |         # Point cloud feature extraction
39 |         point_cloud_features = self.pointnet(point_cloud)  # (batch_size, pointnet_out_dim)
40 | 
41 |         # Repeat point cloud features for each pose
42 |         repeated_features = point_cloud_features.unsqueeze(1).repeat(1, poses.size(1), 1)  # (batch_size, num_poses, pointnet_out_dim)
43 | 
44 |         # Concatenate pose features with point cloud features
45 |         combined_features = torch.cat((repeated_features, poses), dim=-1)  # (batch_size, num_poses, pointnet_out_dim + pose_dim)
46 | 
47 |         # Flatten the input for the classifier
48 |         combined_features = combined_features.view(-1, combined_features.size(-1))  # (batch_size * num_poses, pointnet_out_dim + pose_dim)
49 | 
50 |         # Classification
51 |         scores = self.classifier(combined_features)  # (batch_size * num_poses, 1)
52 |         scores = scores.view(-1, poses.size(1))  # (batch_size, num_poses)
53 | 
54 |         return scores
55 | 
56 | if __name__ == '__main__':  # quick shape check, no longer runs on import
57 |     batch_size = 8
58 |     num_points = 1024
59 |     num_poses = 1024
60 |     pose_dim = 6
61 | 
62 |     point_cloud = torch.rand(batch_size, 3, num_points)
63 |     poses = torch.rand(batch_size, num_poses, pose_dim)
64 | 
65 |     model = PoseClassifier(pointnet_out_dim=128, pose_dim=pose_dim)
66 |     scores = model(point_cloud, poses)
67 |     print(scores)
68 |     print(scores.size())
69 | 
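One plausible way to consume `PoseClassifier` scores downstream is to rank a set of candidate poses and keep the best-scoring one. A minimal sketch; the random candidates and the softmax readout are illustrative, not the repository's actual pipeline:

import torch

# assumes PoseClassifier from pose_estimator.py above is in scope
model = PoseClassifier(pointnet_out_dim=128, pose_dim=6)
model.eval()

point_cloud = torch.rand(2, 3, 1024)   # (batch, 3, num_points)
candidates = torch.rand(2, 64, 6)      # (batch, num_poses, pose_dim), illustrative candidates

with torch.no_grad():
    scores = model(point_cloud, candidates)                    # (batch, num_poses)
    probs = torch.softmax(scores, dim=-1)                      # relative confidence per candidate
    best = candidates[torch.arange(2), scores.argmax(dim=-1)]  # (batch, 6) best-scoring pose

print(best.shape)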
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/regressor_CR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from models.encoder.vn_layers import *
5 | from pdb import set_trace
6 | 
7 | 
8 | class Regressor_CR(nn.Module):
9 |     def __init__(self, pc_feat_dim, out_dim):
10 |         super().__init__()
11 |         self.fc_layers = nn.Sequential(
12 |             nn.Linear(pc_feat_dim, 256),
13 |             nn.LeakyReLU(),
14 |             nn.Linear(256, 128),
15 |             nn.LeakyReLU()
16 |         )
17 |         print("pc_feature.size is", pc_feat_dim)
18 | 
19 |         self.head = nn.Linear(128, out_dim)
20 | 
21 |     def forward(self, x):
22 |         f = self.fc_layers(x)
23 |         output = self.head(f)
24 |         return output
25 | 
26 | class Corr_Aggregator_CR(nn.Module):
27 |     def __init__(self, pc_feat_dim, out_dim):
28 |         super().__init__()
29 |         self.fc_layers = nn.Sequential(
30 |             nn.Linear(pc_feat_dim, out_dim),
31 |             nn.LeakyReLU(),
32 |         )
33 |         self.head = nn.Linear(out_dim, out_dim)
34 |     def forward(self, x):
35 |         f = self.fc_layers(x)
36 |         output = self.head(f)
37 |         return output
38 | 
39 | class VN_Corr_Aggregator_CR(nn.Module):
40 |     def __init__(self, pc_feat_dim, out_dim):
41 |         super().__init__()
42 |         self.fc_layers = nn.Sequential(
43 |             nn.Linear(pc_feat_dim, out_dim),
44 |             nn.LeakyReLU(),
45 |         )
46 |         self.head = nn.Linear(out_dim, out_dim)
47 |     def forward(self, x):
48 |         f = self.fc_layers(x)
49 |         output = self.head(f)
50 |         return output
51 | 
52 | class Regressor_6d(nn.Module):
53 |     def __init__(self, pc_feat_dim):
54 |         super().__init__()
55 |         self.fc_layers = nn.Sequential(
56 |             nn.Linear(2 * pc_feat_dim, 256),
57 |             nn.BatchNorm1d(256),
58 |             nn.LeakyReLU(0.2),
59 |             nn.Linear(256, 128),
60 |             nn.BatchNorm1d(128),
61 |             nn.LeakyReLU(0.2)
62 |         )
63 | 
64 |         # Rotation prediction head
65 |         self.rot_head = nn.Linear(128, 6)
66 | 
67 |         # Translation prediction head
68 |         self.trans_head = nn.Linear(128, 3)
69 | 
70 |     def forward(self, x):
71 |         f = self.fc_layers(x)
72 |         rot = self.rot_head(f)  # 6D rotation representation, not a quaternion
73 |         rot = rot / torch.norm(rot, p=2, dim=1, keepdim=True)
74 |         trans = self.trans_head(f)
75 |         trans = torch.unsqueeze(trans, dim=2)
76 |         return rot, trans
77 | 
78 | 
79 | class VN_Regressor(nn.Module):
80 |     def __init__(self, pc_feat_dim, out_dim):
81 |         super().__init__()
82 |         self.fc_layers = nn.Sequential(
83 |             VNLinear(2 * pc_feat_dim, 256),
84 |             nn.BatchNorm1d(256),
85 |             VNNewLeakyReLU(in_channels=256, negative_slope=0.2),
86 |             VNLinear(256, 128),
87 |             nn.BatchNorm1d(128),
88 |             VNNewLeakyReLU(in_channels=128, negative_slope=0.2)
89 |         )
90 | 
91 |         # Rotation prediction head
92 |         self.head = VNLinear(128, out_dim)
93 | 
94 |     def forward(self, x):
95 |         f = self.fc_layers(x)
96 |         output = self.head(f)
97 |         return output
98 | 
99 | class VN_Regressor_6d(nn.Module):
100 |     def __init__(self, pc_feat_dim):
101 |         super().__init__()
102 |         self.fc_layers = nn.Sequential(
103 |             VNLinear(1024, 256),
104 |             # VNBatchNorm(256),
105 |             nn.BatchNorm1d(256),
106 |             VNLeakyReLU(in_channels=256, negative_slope=0.2),
107 |             VNLinear(256, 128),
108 |             # VNBatchNorm(128),
109 |             nn.BatchNorm1d(128),
110 |             VNLeakyReLU(in_channels=128, negative_slope=0.2)
111 |         )
112 | 
113 |         # Rotation prediction head
114 |         self.rot_head = VNLinear(128, 2)
115 | 
116 |         # Translation prediction head
117 |         self.trans_head = nn.Linear(128 * 3, 3)
118 | 
119 |     def forward(self, x):
120 |         f = self.fc_layers(x)
121 | 
122 |         rot = self.rot_head(f)
123 |         rot = rot / torch.norm(rot, p=2, dim=1, keepdim=True)
124 |         trans = self.trans_head(f.reshape(-1, 128 * 3))
125 |         trans = torch.unsqueeze(trans, dim=2)
126 |         return rot, trans
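A quick shape check for `Regressor_6d`; the feature dimension 512 is arbitrary, and the input is assumed to be the two parts' global features concatenated:

import torch

# assumes Regressor_6d from regressor_CR.py above is in scope
reg = Regressor_6d(pc_feat_dim=512)
feat = torch.rand(8, 2 * 512)   # concatenated per-part global features
rot6d, trans = reg(feat)
print(rot6d.shape)              # torch.Size([8, 6]) -> unit-norm 6D rotation
print(trans.shape)              # torch.Size([8, 3, 1])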
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/transformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | import pytorch_lightning as pl
7 | 
8 | def clones(module, N):
9 |     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
10 | 
11 | def attention(query, key, value, mask=None):
12 |     d_k = query.size(-1)
13 |     scores = torch.matmul(query, key.transpose(-2, -1).contiguous()) / (d_k ** 0.5)
14 |     if mask is not None:
15 |         scores = scores.masked_fill(mask == 0, -1e9)
16 |     p_attn = F.softmax(scores, dim=-1)
17 |     return torch.matmul(p_attn, value), p_attn
18 | 
19 | class EncoderDecoder(pl.LightningModule):
20 |     def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
21 |         super().__init__()
22 |         self.encoder = encoder
23 |         self.decoder = decoder
24 |         self.src_embed = src_embed
25 |         self.tgt_embed = tgt_embed
26 |         self.generator = generator
27 | 
28 |     def encode(self, src, src_mask):
29 |         return self.encoder(self.src_embed(src), src_mask)
30 | 
31 |     def decode(self, memory, src_mask, tgt, tgt_mask):
32 |         return self.generator(self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask))
33 | 
34 |     def forward(self, src, tgt, src_mask, tgt_mask):
35 |         return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
36 | 
37 | class Encoder(pl.LightningModule):
38 |     def __init__(self, layer, N):
39 |         super().__init__()
40 |         self.layers = clones(layer, N)
41 |         self.norm = LayerNorm(layer.size)
42 | 
43 |     def forward(self, x, mask):
44 |         for layer in self.layers:
45 |             x = layer(x, mask)
46 |         return self.norm(x)
47 | 
48 | class Decoder(pl.LightningModule):
49 |     def __init__(self, layer, N):
50 |         super().__init__()
51 |         self.layers = clones(layer, N)
52 |         self.norm = LayerNorm(layer.size)
53 | 
54 |     def forward(self, x, memory, src_mask, tgt_mask):
55 |         for layer in self.layers:
56 |             x = layer(x, memory, src_mask, tgt_mask)
57 |         return self.norm(x)
58 | 
59 | class LayerNorm(pl.LightningModule):
60 |     def __init__(self, features, eps=1e-6):
61 |         super().__init__()
62 |         self.a_2 = nn.Parameter(torch.ones(features))
63 |         self.b_2 = nn.Parameter(torch.zeros(features))
64 |         self.eps = eps
65 | 
66 |     def forward(self, x):
67 |         mean = x.mean(-1, keepdim=True)
68 |         std = x.std(-1, keepdim=True)
69 |         return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
70 | 
71 | class SublayerConnection(pl.LightningModule):
72 |     def __init__(self, size):
73 |         super().__init__()
74 |         self.norm = LayerNorm(size)
75 | 
76 |     def forward(self, x, sublayer):
77 |         return x + sublayer(self.norm(x))
78 | 
79 | class EncoderLayer(pl.LightningModule):
80 |     def __init__(self, size, self_attn, feed_forward):
81 |         super(EncoderLayer, self).__init__()
82 |         self.self_attn = self_attn
83 |         self.feed_forward = feed_forward
84 |         self.sublayer = clones(SublayerConnection(size), 2)
85 |         self.size = size
86 | 
87 |     def forward(self, x, mask):
88 |         x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
89 |         return self.sublayer[1](x, self.feed_forward)
90 | 
91 | class DecoderLayer(pl.LightningModule):
92 | 
93 |     def __init__(self, size, self_attn, src_attn, feed_forward):
94 |         super().__init__()
95 |         self.size = size
96 |         self.self_attn = self_attn
97 |         self.src_attn = src_attn
98 |         self.feed_forward = feed_forward
99 |         self.sublayer = clones(SublayerConnection(size), 3)
100 | 
101 |     def forward(self, x, memory, src_mask, tgt_mask):
102 |         m = memory
103 |         x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
104 |         x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
105 |         return self.sublayer[2](x, self.feed_forward)
106 | 
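# Shape walkthrough for the multi-head attention below: with h heads and
# model width d_model, each of the first three linears maps
# (batch, n, d_model) -> (batch, n, d_model), which is viewed as
# (batch, n, h, d_k) with d_k = d_model // h and transposed to
# (batch, h, n, d_k) so attention() runs over all heads in parallel; the
# fourth linear mixes the re-concatenated heads back to d_model.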
107 | class MultiHeadedAttention(pl.LightningModule):
108 | 
109 |     def __init__(self, h, d_model):
110 |         super().__init__()
111 |         assert d_model % h == 0
112 |         # We assume d_v always equals d_k
113 |         self.d_k = d_model // h
114 |         self.h = h
115 |         self.linears = clones(nn.Linear(d_model, d_model), 4)  # query, key, value and output projections
116 | 
117 |     def forward(self, query, key, value, mask=None):
118 |         if mask is not None:
119 |             mask = mask.unsqueeze(1)
120 |         nbatches = query.size(0)
121 | 
122 |         # 1) Do all the linear projections in batch from d_model => h x d_k
123 |         query, key, value = \
124 |             [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous()
125 |              for l, x in zip(self.linears, (query, key, value))]
126 | 
127 |         # 2) Apply attention on all the projected vectors in batch.
128 |         x, self.attn = attention(query, key, value, mask=mask)
129 | 
130 |         # 3) "Concat" using a view and apply a final linear.
131 |         x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
132 | 
133 |         return self.linears[-1](x)
134 | 
135 | class PositionwiseFeedForward(pl.LightningModule):
136 | 
137 |     def __init__(self, d_model, d_ff):
138 |         super().__init__()
139 |         self.w_1 = nn.Linear(d_model, d_ff)
140 |         self.norm = nn.Sequential()  # identity placeholder where a norm layer could be inserted
141 |         self.w_2 = nn.Linear(d_ff, d_model)
142 | 
143 |     def forward(self, x):
144 |         return self.w_2(self.norm(F.relu(self.w_1(x)).transpose(2, 1).contiguous()).transpose(2, 1).contiguous())
145 | 
146 | class Transformer(pl.LightningModule):
147 | 
148 |     def __init__(self, cfg):
149 |         super().__init__()
150 |         c = copy.deepcopy
151 |         attn = MultiHeadedAttention(
152 |             cfg.model.num_heads,
153 |             cfg.model.pc_feat_dim
154 |         )
155 | 
156 |         ff = PositionwiseFeedForward(
157 |             cfg.model.pc_feat_dim,
158 |             cfg.model.transformer_feat_dim
159 |         )
160 | 
161 |         self.model = EncoderDecoder(
162 |             Encoder(EncoderLayer(cfg.model.pc_feat_dim, c(attn), c(ff)), cfg.model.num_blocks),
163 |             Decoder(DecoderLayer(cfg.model.pc_feat_dim, c(attn), c(attn), c(ff)), cfg.model.num_blocks),
164 |             nn.Sequential(),
165 |             nn.Sequential(),
166 |             nn.Sequential()
167 |         )
168 | 
169 |     def forward(self, src, tgt):
170 |         src = src.transpose(2, 1).contiguous()
171 |         tgt = tgt.transpose(2, 1).contiguous()
172 |         src_corr_feat = self.model(tgt, src, None, None).transpose(2, 1).contiguous()
173 |         return src_corr_feat
174 | 
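A minimal wiring check for the `Transformer` above. The cfg object is a throwaway `types.SimpleNamespace`; the field values are illustrative, and the real ones come from the YAML files under `src/config`:

import types
import torch

# illustrative config; real values live in src/config/*.yml
cfg = types.SimpleNamespace(model=types.SimpleNamespace(
    num_heads=4, pc_feat_dim=256, transformer_feat_dim=512, num_blocks=1))

model = Transformer(cfg)
src = torch.rand(2, 256, 1024)   # (batch, pc_feat_dim, num_points)
tgt = torch.rand(2, 256, 1024)
out = model(src, tgt)
print(out.shape)                 # torch.Size([2, 256, 1024]): per-point correspondence features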
--------------------------------------------------------------------------------
/src/shape_assembly/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import matplotlib.pyplot as plt
3 | from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | 
8 | def render_pts_label_png(fn, pc, color):
9 |     # pc: (num_points, 3), color: (num_points,)
10 |     new_color = []
11 |     for i in range(len(color)):
12 |         if color[i] == 1:
13 |             new_color.append('#ab4700')
14 |         else:
15 |             new_color.append('#00479e')
16 |     fig = plt.figure(dpi=200)
17 |     ax = fig.add_subplot(111, projection='3d')
18 |     plt.title('point cloud')
19 | 
20 |     ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=new_color, marker='.', s=5, linewidth=0, alpha=1)
21 |     ax.axis('scaled')
22 |     ax.set_xlabel('X Label')
23 |     ax.set_ylabel('Y Label')
24 |     ax.set_zlabel('Z Label')
25 | 
26 |     plt.savefig(fn + '.png')
27 | 
28 | def compute_distance_between_rotations(P, Q):
29 |     #! input two batches of rotation matrices, output the mean geodesic distance between them (in radians)
30 |     #! P, Q are numpy arrays of shape (batch, 3, 3)
31 |     P = np.asarray(P)
32 |     Q = np.asarray(Q)
33 |     R = np.matmul(P, Q.swapaxes(1, 2))
34 |     theta = np.arccos(np.clip((np.trace(R, axis1=1, axis2=2) - 1) / 2, -1, 1))
35 |     return np.mean(theta)
36 | 
37 | def bgs(d6s):  # d6s: (batch, 3, 2) -> rotation matrices (batch, 3, 3) via Gram-Schmidt
38 |     bsz = d6s.shape[0]
39 |     b1 = F.normalize(d6s[:, :, 0], p=2, dim=1)
40 |     a2 = d6s[:, :, 1]
41 |     b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1)
42 |     b3 = torch.cross(b1, b2, dim=1)
43 |     return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1)
44 | 
45 | 
46 | def bgdR(Rgts, Rps):
47 |     Rds = torch.bmm(Rgts.permute(0, 2, 1), Rps)
48 |     Rt = torch.sum(Rds[:, torch.eye(3).bool()], 1)  # batch trace
49 |     # necessary or it might lead to nans and the likes
50 |     theta = torch.clamp(0.5 * (Rt - 1), -1 + 1e-6, 1 - 1e-6)
51 |     # print("\033[33m the theta in bgdR is", theta, "\033[0m")
52 |     return torch.acos(theta)
53 | 
54 | def get_6d_rot_loss(gt_6d, pred_6d):
55 |     # input pred_6d, gt_6d: batch * 2 * 3
56 |     pred_Rs = bgs(pred_6d.reshape(-1, 2, 3).permute(0, 2, 1))
57 |     gt_Rs = bgs(gt_6d.reshape(-1, 2, 3).permute(0, 2, 1))
58 |     theta = bgdR(gt_Rs, pred_Rs)
59 |     theta_degree = theta * 180 / np.pi
60 |     # print("\033[33m the theta is", theta, "\033[0m")
61 |     return theta_degree
62 | 
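# The two symmetry-aware variants below relax the geodesic loss for
# rotationally symmetric parts: if a part is symmetric about the z-axis
# (symmetry flag at index 4 set), any rotation about z is equally valid,
# so the error is measured as the angle between the predicted and
# ground-truth images of the z-axis; otherwise the full relative-rotation
# angle (trace formula) is used.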
63 | def get_6d_rot_loss_symmetry_new(batch_data, pred_data, device):
64 | 
65 |     batch_size = batch_data['src_rot'].shape[0]
66 | 
67 |     partA_symmetry_type = batch_data['partA_symmetry_type']
68 |     partB_symmetry_type = batch_data['partB_symmetry_type']
69 | 
70 |     src_R_6d = batch_data['src_rot']
71 |     tgt_R_6d = batch_data['tgt_rot']
72 | 
73 |     src_R_6d_pred = pred_data['src_rot']
74 |     tgt_R_6d_pred = pred_data['tgt_rot']
75 |     tgt_R_6d_init = batch_data['predicted_partB_rotation']
76 | 
77 |     src_Rs = bgs(src_R_6d_pred.reshape(-1, 2, 3).permute(0, 2, 1))
78 |     gt_src_Rs = bgs(src_R_6d.reshape(-1, 2, 3).permute(0, 2, 1))
79 | 
80 |     tgt_Rs = bgs(tgt_R_6d_pred.reshape(-1, 2, 3).permute(0, 2, 1))
81 |     tgt_Rs_init = bgs(tgt_R_6d_init.reshape(-1, 2, 3).permute(0, 2, 1))
82 | 
83 |     tgt_Rs_new = torch.matmul(tgt_Rs, tgt_Rs_init)
84 |     gt_tgt_Rs = bgs(tgt_R_6d.reshape(-1, 2, 3).permute(0, 2, 1))
85 | 
86 | 
87 |     R1 = src_Rs / torch.pow(torch.det(src_Rs), 1/3).view(-1, 1, 1)
88 |     R2 = gt_src_Rs / torch.pow(torch.det(gt_src_Rs), 1/3).view(-1, 1, 1)
89 | 
90 |     R3 = tgt_Rs_new / torch.pow(torch.det(tgt_Rs_new), 1/3).view(-1, 1, 1)  # each matrix normalized by its own determinant
91 |     R4 = gt_tgt_Rs / torch.pow(torch.det(gt_tgt_Rs), 1/3).view(-1, 1, 1)
92 | 
93 |     # R5 = gt_tgt_Rs / torch.pow(torch.det(gt_tgt_Rs), 1/3).view(-1,1,1)
94 | 
95 |     # cos_theta = torch.zeros(batch_size)
96 |     z = torch.tensor([0.0, 0.0, 1.0], device=device).unsqueeze(0).repeat(batch_size, 1)
97 |     cos_theta = torch.zeros(batch_size)
98 |     for i in range(batch_size):
99 |         # for every data in batch_size
100 |         symmetry_i = partA_symmetry_type[i]
101 | 
102 |         # if the data is z-axis symmetric
103 |         if symmetry_i[4].item() == 1:
104 |             # considering symmetry when rotating around z-axis
105 |             z1 = torch.matmul(R1[i], z[i])
106 |             z2 = torch.matmul(R2[i], z[i])
107 | 
108 |             cos_theta[i] = torch.dot(z1, z2) / (torch.norm(z1) * torch.norm(z2))
109 | 
110 |         else:
111 |             R_A = torch.matmul(R1[i], R2[i].transpose(1, 0))
112 |             cos_theta[i] = (torch.trace(R_A) - 1) / 2
113 | 
114 |     theta_src = torch.acos(torch.clamp(cos_theta, -1.0 + 1e-6, 1.0 - 1e-6)) * 180 / np.pi
115 | 
116 |     cos_theta_B = torch.zeros(batch_size)
117 |     for i in range(batch_size):
118 |         # for every data in batch_size
119 |         symmetry_i = partB_symmetry_type[i]
120 | 
121 |         # if the data is z-axis symmetric
122 |         if symmetry_i[4].item() == 1:
123 |             # considering symmetry when rotating around z-axis
124 | 
125 |             # z4 = torch.matmul(R4[i], z[i])
126 |             z3 = torch.matmul(R3[i], z[i])
127 |             z4 = torch.matmul(R4[i], z[i])
128 | 
129 |             cos_theta_B[i] = torch.dot(z3, z4) / (torch.norm(z3) * torch.norm(z4))
130 | 
131 |         else:
132 |             R_B = torch.matmul(R3[i], R4[i].transpose(1, 0))
133 |             cos_theta_B[i] = (torch.trace(R_B) - 1) / 2
134 | 
135 |     theta_tgt = torch.acos(torch.clamp(cos_theta_B, -1.0 + 1e-6, 1.0 - 1e-6)) * 180 / np.pi
136 | 
137 |     src_rot_loss = torch.mean(theta_src)
138 |     tgt_rot_loss = torch.mean(theta_tgt)
139 | 
140 |     # rot_loss = (src_rot_loss + tgt_rot_loss)/2.0
141 | 
142 |     return (src_rot_loss, tgt_rot_loss)
143 | 
144 | def get_6d_rot_loss_symmetry(gt_6d, pred_6d, symmetry, device):
145 |     batch_size = gt_6d.shape[0]
146 | 
147 |     pred_Rs = bgs(pred_6d.reshape(-1, 2, 3).permute(0, 2, 1))
148 |     gt_Rs = bgs(gt_6d.reshape(-1, 2, 3).permute(0, 2, 1))
149 |     R1 = pred_Rs / torch.pow(torch.det(pred_Rs), 1/3).view(-1, 1, 1)
150 |     R2 = gt_Rs / torch.pow(torch.det(gt_Rs), 1/3).view(-1, 1, 1)
151 | 
152 |     # cos_theta = torch.zeros(batch_size)
153 |     z = torch.tensor([0.0, 0.0, 1.0], device=device).unsqueeze(0).repeat(batch_size, 1)
154 |     cos_theta = torch.zeros(batch_size)
155 |     for i in range(batch_size):
156 |         # for every data in batch_size
157 |         symmetry_i = symmetry[i]
158 | 
159 |         # if the data is z-axis symmetric
160 |         if symmetry_i[4].item() == 1:
161 |             # considering symmetry when rotating around z-axis
162 |             z1 = torch.matmul(R1[i], z[i])
163 |             z2 = torch.matmul(R2[i], z[i])
164 |             cos_theta[i] = torch.dot(z1, z2) / (torch.norm(z1) * torch.norm(z2))
165 |         else:
166 |             R = torch.matmul(R1[i], R2[i].transpose(1, 0))
167 |             cos_theta[i] = (torch.trace(R) - 1) / 2
168 | 
169 |     theta = torch.acos(torch.clamp(cos_theta, -1.0 + 1e-6, 1.0 - 1e-6)) * 180 / np.pi
170 |     return theta
171 | 
172 | 
173 | def printout(flog, strout):
174 |     print(strout)
175 |     if flog is not None:
176 |         flog.write(strout + '\n')
177 | 
--------------------------------------------------------------------------------
/src/shape_assembly/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.1'
2 | 
--------------------------------------------------------------------------------
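As a closing sanity check of the rotation utilities, the 6D encoding and `get_6d_rot_loss` can be exercised with a known rotation. A sketch assuming `src` is on `PYTHONPATH` so that `shape_assembly.utils` is importable:

import numpy as np
import torch

from shape_assembly.utils import get_6d_rot_loss

# 30-degree rotation about the z-axis
angle = np.deg2rad(30.0)
Rz = torch.tensor([[np.cos(angle), -np.sin(angle), 0.0],
                   [np.sin(angle),  np.cos(angle), 0.0],
                   [0.0,            0.0,           1.0]], dtype=torch.float32)

# 6D encoding: the first two columns of the rotation matrix, flattened
gt_6d = torch.eye(3)[:, :2].T.reshape(1, 6)   # identity rotation
pred_6d = Rz[:, :2].T.reshape(1, 6)           # rotated 30 degrees about z
print(get_6d_rot_loss(gt_6d, pred_6d))        # approximately tensor([30.])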