├── LICENCE
├── README.md
├── assets
│   ├── general.mp4
│   └── teaser.png
├── data_util
│   ├── README.md
│   ├── blender_triangulate.py
│   ├── clear_obj.sh
│   ├── csv_to_ply.py
│   ├── csv_to_ply.sh
│   ├── generate_pc.cpp
│   ├── generate_pc.sh
│   ├── generate_urdf.py
│   ├── partA.obj
│   ├── partA_cp.obj
│   ├── preprocess_obj_blender.py
│   └── preprocess_obj_blender.sh
├── environment.yml
└── src
    ├── .DS_Store
    ├── config
    │   ├── eval.yml
    │   ├── train_A.yml
    │   └── train_B.yml
    ├── script
    │   ├── .DS_Store
    │   ├── our_eval.py
    │   ├── our_train_A.py
    │   └── our_train_B.py
    └── shape_assembly
        ├── .DS_Store
        ├── __init__.py
        ├── __pycache__
        │   ├── config.cpython-38.pyc
        │   └── utils.cpython-38.pyc
        ├── config.py
        ├── config_eval.py
        ├── datasets
        │   ├── .DS_Store
        │   └── dataloader
        │       ├── .DS_Store
        │       ├── __pycache__
        │       │   ├── cls_dataloader.cpython-38.pyc
        │       │   ├── dataloader_A.cpython-38.pyc
        │       │   ├── dataloader_CR.cpython-38.pyc
        │       │   ├── dataloader_kit.cpython-38.pyc
        │       │   ├── dataloader_kit1.cpython-38.pyc
        │       │   ├── dataloader_kit_1.cpython-38.pyc
        │       │   ├── dataloader_kit_ori.cpython-38.pyc
        │       │   ├── dataloader_kit_single.cpython-38.pyc
        │       │   ├── dataloader_our.cpython-38.pyc
        │       │   ├── dataloader_our1.cpython-38.pyc
        │       │   ├── sub1.cpython-38.pyc
        │       │   └── sub_1_7.cpython-38.pyc
        │       ├── dataloader_A.py
        │       └── dataloader_B.py
        ├── models
        │   ├── decoder
        │   │   ├── .DS_Store
        │   │   ├── MLPDecoder.py
        │   │   └── __pycache__
        │   │       └── MLPDecoder.cpython-38.pyc
        │   ├── encoder
        │   │   ├── .DS_Store
        │   │   ├── __pycache__
        │   │   │   ├── dgcnn.cpython-38.pyc
        │   │   │   ├── pointnet.cpython-38.pyc
        │   │   │   ├── vn_dgcnn.cpython-38.pyc
        │   │   │   └── vn_layers.cpython-38.pyc
        │   │   ├── vn_dgcnn.py
        │   │   ├── vn_dgcnn_util.py
        │   │   └── vn_layers.py
        │   └── train
        │       ├── __pycache__
        │       │   ├── Pose_Refinement_Module.cpython-38.pyc
        │       │   ├── network_vnn1.cpython-38.pyc
        │       │   ├── network_vnn2.cpython-38.pyc
        │       │   ├── network_vnn_A.cpython-38.pyc
        │       │   ├── network_vnn_A_indi.cpython-38.pyc
        │       │   ├── network_vnn_A_indi2.cpython-38.pyc
        │       │   ├── network_vnn_B.cpython-38.pyc
        │       │   ├── network_vnn_B2.cpython-38.pyc
        │       │   ├── regressor_CR.cpython-38.pyc
        │       │   └── transformer.cpython-38.pyc
        │       ├── network_vnn_A.py
        │       ├── network_vnn_A_indi.py
        │       ├── network_vnn_B.py
        │       ├── pose_estimator.py
        │       ├── regressor_CR.py
        │       └── transformer.py
        ├── utils.py
        └── version.py
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Yu Qi and Yuanchen Ju
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Two by Two✌️ : Learning Multi-Task Pairwise Objects Assembly for Generalizable Robot Manipulation](https://tea-lab.github.io/TwoByTwo/)
2 |
3 | Project Page | arXiv | Twitter | Dataset
10 | Yu Qi*,
11 | Yuanchen Ju*,
12 | Tianming Wei,
13 | Chi Chu,
14 | Lawson L.S. Wong,
15 | Huazhe Xu
16 |
17 |
18 | **CVPR, 2025**
19 |
20 |
21 |
22 |

23 |
24 |
25 |
26 | # 🛠️ Installation
27 |
28 | This project is tested on Ubuntu 22.04 with CUDA 11.8.
29 |
30 | - Install [Anaconda](https://www.anaconda.com/docs/getting-started/anaconda/install#macos-linux-installation:to-download-a-different-version) or [Miniconda](https://www.anaconda.com/docs/getting-started/miniconda/install#linux-terminal-installer)
31 | - Clone the repository and create the environment. The environment should install within a few minutes.
32 |
33 | ```shell
34 | git clone git@github.com:TEA-Lab/TwoByTwo.git
35 | conda env create -f environment.yml
36 | conda activate twobytwo
37 | ```
38 |
39 | - (Optional) If you would like to calculate Chamfer Distance, clone the [CUDA-accelerated Chamfer Distance library](https://github.com/ThibaultGROUEIX/ChamferDistancePytorch/tree/master):
40 |
41 | ```shell
42 | cd src/shape_assembly/models
43 | git clone https://github.com/ThibaultGROUEIX/ChamferDistancePytorch.git
44 | ```
45 |
46 | ## 🧩 Dataset
47 |
48 | The 2BY2 dataset has been released. To obtain it, please fill out this [form](https://docs.google.com/forms/d/e/1FAIpQLSfhPcAdky8ZojjPlSSHN4ubYqc7WHIwfiqFW2L5YpqAHbbVgg/viewform?usp=sharing).
49 |
50 | ## 🍰 Dataset Utility Support
51 | We recommend using our pre-generated point clouds. You can also generate your own point clouds, add your own data, or generate **URDF (Unified Robot Description Format)** files for robot simulation; see the `data_util` folder for detailed instructions.
52 |
53 | ## 🐰 Training and Inference
54 |
55 | In `src/config`, modify `log_dir` and `data.root_dir` to your own paths. We support Distributed Data Parallel (DDP) training.
56 |
57 |
58 | - Train Network B
59 |
60 | ```shell
61 | cd src
62 | python script/our_train_B.py --cfg_file train_B.yml
63 | ```
64 |
65 | - Train Network A
66 | ```shell
67 | cd src
68 | python script/our_train_A.py --cfg_file train_A.yml
69 | ```
70 | - Inference
71 |
72 | ```shell
73 | cd src
74 | python script/our_eval.py --cfg_file eval.yml
75 | ```
76 |
77 | # 🎟️ Licence
78 | This repository is released under the MIT license. Refer to [LICENCE](LICENCE) for more information.
79 |
80 | # 🎨 Acknowledgement & Contact
81 | Our codebase is developed based on [SE3-part-assembly](https://crtie.github.io/SE-3-part-assembly/), and we express our gratitude to the authors for their generously open-sourced code, as well as to the baseline projects [Puzzlefusion++](https://puzzlefusion-plusplus.github.io), [Jigsaw](https://jiaxin-lu.github.io/Jigsaw/), and [Neural Shape Mating](https://neural-shape-mating.github.io/) for their valuable impact on the community.
82 |
83 | For inquiries about this project, please reach out to **Yu Qi: qi.yu2@northeastern.edu** and **Yuanchen Ju: juuycc0213@gmail.com**. You’re also welcome to open an issue or submit a pull request!😄
84 |
85 |
86 | # 🎸 BibTeX
87 |
88 | If you find this work useful, we would appreciate you citing it.
89 | ```
90 | @inproceedings{qi2025two,
91 |   title={Two by two: Learning multi-task pairwise objects assembly for generalizable robot manipulation},
92 |   author={Qi, Yu and Ju, Yuanchen and Wei, Tianming and Chu, Chi and Wong, Lawson LS and Xu, Huazhe},
93 |   booktitle={CVPR},
94 | year={2025}
95 | }
96 | ```
97 |
--------------------------------------------------------------------------------
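The training and evaluation scripts load their settings through `yacs` (`get_cfg_defaults()` in `src/shape_assembly/config.py`). As a quick way to confirm your edited paths are picked up before launching a run, the sketch below mirrors the config-loading code in `src/script/our_train_B.py` (run it from `src/`):

```python
# Sanity-check a config file before launching training (run from src/).
import sys
sys.path.append('./shape_assembly')
from config import get_cfg_defaults  # yacs defaults, as used by the training scripts

cfg = get_cfg_defaults()
cfg.merge_from_file('./config/train_B.yml')
cfg.freeze()

# The fields the README asks you to point at your own machine:
print('log_dir: ', cfg.exp.log_dir)
print('root_dir:', cfg.data.root_dir)
```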
/assets/general.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/assets/general.mp4
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/assets/teaser.png
--------------------------------------------------------------------------------
/data_util/README.md:
--------------------------------------------------------------------------------
1 | # Dataset Utility Support
2 |
3 | ## Point Cloud Generation
4 |
5 | You can directly use the sampled point clouds in the 2BY2 dataset, which contain 1024 points per object, but you can also generate your own. As in [Breaking Bad](https://breaking-bad-dataset.github.io/), we use blue-noise sampling to generate point clouds.
6 |
7 | ```shell
8 | sudo apt update
9 |
10 | sudo apt install build-essential cmake libglfw3-dev libgl1-mesa-dev libx11-dev libxi-dev libxrandr-dev libxinerama-dev libxcursor-dev libxxf86vm-dev libxext-dev libpthread-stubs0-dev libdl-dev
11 |
12 | g++ -I ./eigen-3.4.0 -I ./libigl/include -I ./glad/include -I /usr/include \
13 | -L /usr/lib \
14 | -o generate_pc \
15 | generate_pc.cpp ./glad/src/glad.c \
16 | -ldl -lGL -lglfw -pthread
17 |
18 | bash generate_pc.sh
19 | ```
20 |
21 | After this step, you will generate `.csv` files containing point clouds.
22 |
23 | If you would like a different number of points, change `num_points` (around `L150` of `generate_pc.cpp`) to your target number.
24 |
25 | ## Mesh Preprocessing
26 |
27 | We triangulate each mesh with Blender and uniformly scale it before generating point clouds. This step takes `partA.obj` and `partB.obj` as input and outputs the normalized meshes `partA_new.obj` and `partB_new.obj`.
28 |
29 | ```shell
30 | bash preprocess_obj_blender.sh
31 | ```
32 |
33 | ## URDF (Unified Robot Description Format) Generation
34 |
35 | If you would like to use our meshes in a simulator, run the following script.
36 |
37 | ```shell
38 | python generate_urdf.py
39 | ```
40 |
41 | When creating the VHACD file for collision, you might need to tune the `simplify_quadric_decimation(0.8)` parameter for a better collision approximation.
--------------------------------------------------------------------------------
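As a quick sanity check after `generate_pc.sh` finishes, each `partA-pc.csv` / `partB-pc.csv` is a plain comma-separated file with one x,y,z row per point; a minimal sketch (the file name is an example):

```python
# Verify a generated point cloud: one x,y,z row per point, 1024 points by default.
import numpy as np

pc = np.loadtxt('partA-pc.csv', delimiter=',')
print(pc.shape)  # expected: (1024, 3)
assert pc.shape[1] == 3
```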
/data_util/blender_triangulate.py:
--------------------------------------------------------------------------------
1 | import bpy
2 | import sys
3 |
4 | def triangulate_object(obj):
5 | bpy.ops.object.select_all(action='DESELECT')
6 | obj.select_set(True)
7 | # Select the object
8 | bpy.context.view_layer.objects.active = obj
9 | # Switch to object mode
10 | bpy.ops.object.mode_set(mode='OBJECT')
11 | # Switch to edit mode
12 | bpy.ops.object.mode_set(mode='EDIT')
13 | # Select all faces
14 | bpy.ops.mesh.select_all(action='SELECT')
15 | # Triangulate the mesh
16 | bpy.ops.mesh.quads_convert_to_tris()
17 | # Switch back to object mode
18 | bpy.ops.object.mode_set(mode='OBJECT')
19 |
20 | def main():
21 | input_path = sys.argv[-2]
22 | output_path = sys.argv[-1]
23 |
24 | bpy.context.preferences.system.audio_device = 'Null'
25 |
26 | # Clear existing scene
27 | bpy.ops.wm.read_factory_settings(use_empty=True)
28 |
29 | # Import the OBJ file
30 | bpy.ops.import_scene.obj(filepath=input_path)
31 |
32 | # Triangulate all imported objects
33 | for obj in bpy.context.scene.objects:
34 | if obj.type == 'MESH':
35 | triangulate_object(obj)
36 |
37 | # Export the triangulated mesh to OBJ
38 | bpy.ops.export_scene.obj(filepath=output_path, use_selection=False)
39 |
40 | if __name__ == "__main__":
41 | main()
42 |
--------------------------------------------------------------------------------
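This script is normally driven by `preprocess_obj_blender.py` through `subprocess`; to triangulate a single mesh by hand, the equivalent call looks like the sketch below (assuming `blender` is on your `PATH`; the file names are examples):

```python
# Triangulate one mesh manually, mirroring the call in preprocess_obj_blender.py.
import subprocess

subprocess.run([
    'blender', '--factory-startup', '--background',
    '--python', 'blender_triangulate.py',
    '--', 'partA.obj', 'partA_tri.obj',  # input mesh, triangulated output
], check=True)
```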
/data_util/clear_obj.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | input_folders=("processed_cups" "water_bottle_processed" "beer_bottle")
4 |
5 | # delete partA_new.obj, partB_new.obj
6 |
7 | for folder in "${input_folders[@]}"; do
8 | for subfolder in $(find "$folder" -type d); do
9 | partA_obj="$subfolder/partA_new.obj"
10 | partB_obj="$subfolder/partB_new.obj"
11 | # partA_ply="$subfolder/partA-pc.ply"
12 | # partB_ply="$subfolder/partB-pc.ply"
13 |
14 | if [[ -f "$partA_obj" && -f "$partB_obj" ]]; then
15 | echo "$partA_obj and $partB_obj found, deleted"
16 | rm "$partA_obj"
17 | rm "$partB_obj"
18 | else
19 | echo "partA_new.obj and partB_new.obj are missing in $subfolder"
20 | fi
21 | done
22 | done
--------------------------------------------------------------------------------
/data_util/csv_to_ply.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import argparse
3 |
4 | # Function to read the CSV file (duplicate removal is currently disabled)
5 | def read_csv_remove_duplicates(input_csv):
6 | df = pd.read_csv(input_csv, header=None, names=["x", "y", "z"])
7 | # df = df.drop_duplicates()
8 | return df
9 |
10 | # Function to write the DataFrame to a PLY file
11 | def write_ply(df, output_ply):
12 | with open(output_ply, 'w+') as ply_file:
13 | ply_file.write('ply\n')
14 | ply_file.write('format ascii 1.0\n')
15 | ply_file.write(f'element vertex {len(df)}\n')
16 | ply_file.write('property float x\n')
17 | ply_file.write('property float y\n')
18 | ply_file.write('property float z\n')
19 | ply_file.write('property uchar red\n')
20 | ply_file.write('property uchar green\n')
21 | ply_file.write('property uchar blue\n')
22 | ply_file.write('end_header\n')
23 | for _, row in df.iterrows():
24 | ply_file.write(f'{row["x"]} {row["y"]} {row["z"]} 255 0 0\n')
25 |
26 | # Main function to convert CSV to PLY
27 | def convert_csv_to_ply(input_csv, output_ply):
28 | df = read_csv_remove_duplicates(input_csv)
29 | write_ply(df, output_ply)
30 |
31 | # Example usage
32 | if __name__ == "__main__":
33 | parser = argparse.ArgumentParser(description="Convert CSV to PLY")
34 | parser.add_argument('--input_csv', default='points.csv', type=str)
35 | parser.add_argument('--output_ply', default='points.ply', type=str)
36 |     args = parser.parse_args()
37 | convert_csv_to_ply(args.input_csv, args.output_ply)
38 |
--------------------------------------------------------------------------------
/data_util/csv_to_ply.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | input_folders=("processed_cups" "water_bottle_processed" "beer_bottle")
4 |
5 | for folder in "${input_folders[@]}"; do
6 | for subfolder in $(find "$folder" -type d); do
7 | partA_csv="$subfolder/partA-pc.csv"
8 | partB_csv="$subfolder/partB-pc.csv"
9 | partA_ply="$subfolder/partA-pc.ply"
10 | partB_ply="$subfolder/partB-pc.ply"
11 |
12 | if [[ -f "$partA_csv" && -f "$partB_csv" ]]; then
13 | echo "Processing $subfolder"
14 | python csv_to_ply.py --input_csv "$partA_csv" --output_ply "$partA_ply"
15 | python csv_to_ply.py --input_csv "$partB_csv" --output_ply "$partB_ply"
16 | else
17 | echo "Skipping $subfolder: partA-pc.csv or partB-pc.csv not found"
18 | fi
19 | done
20 | done
--------------------------------------------------------------------------------
/data_util/generate_pc.cpp:
--------------------------------------------------------------------------------
1 | // NOTE: the original angle-bracket include paths were stripped during extraction;
2 | // the headers below are reconstructed from the igl/Eigen/std calls used in this file.
3 | #include <igl/read_triangle_mesh.h>
4 | #include <igl/blue_noise.h>
5 | #include <igl/doublearea.h>
6 | #include <igl/PI.h>
7 | #include <Eigen/Core>
8 | #include <Eigen/Dense>
9 | #include <iostream>
10 | #include <fstream>
11 | #include <string>
12 | #include <vector>
13 | #include <numeric>
14 | #include <algorithm>
15 | #include <random>
16 |
17 | const static Eigen::IOFormat CSVFormat(Eigen::StreamPrecision, Eigen::DontAlignCols, ",", "\n");
18 |
19 | void sample_pc_with_blue_noise(
20 | const int & num_points,
21 | const Eigen::MatrixXd & mesh_v,
22 | const Eigen::MatrixXi & mesh_f,
23 | Eigen::MatrixXd & pc,
24 | Eigen::MatrixXd & normals)
25 | {
26 | if(mesh_v.size() == 0 || mesh_f.size() == 0) {
27 | std::cerr << "Error: Input mesh is empty.\n";
28 | return;
29 | }
30 |
31 | Eigen::VectorXd A;
32 | igl::doublearea(mesh_v, mesh_f, A);
33 | double radius = sqrt(((A.sum()*0.5/(num_points*0.6162910373))/igl::PI));
34 | std::cout << "Initial Blue noise radius: " << radius << "\n";
35 |
36 | Eigen::MatrixXd B;
37 | Eigen::VectorXi I;
38 | int max_attempts = 1;
39 | for(int attempt = 0; attempt < max_attempts; ++attempt)
40 | {
41 | igl::blue_noise(mesh_v, mesh_f, radius, B, I, pc);
42 | std::cout << "successfully generated a set of blue noise!" << std::endl;
43 | if(pc.rows() >= num_points * 0.9 && pc.rows() <= num_points * 1.1)
44 | {
45 | break;
46 | }
47 | if(pc.rows() == 0) {
48 | std::cerr << "Error: Blue noise sampling generated an empty point cloud. Attempt: " << attempt + 1 << "\n";
49 | break;
50 | }
51 | radius *= sqrt(num_points * 1.0 / pc.rows());
52 | //std::cout << "Adjusted Blue noise radius: " << radius << " (Attempt " << attempt + 1 << ")\n";
53 | }
54 |
55 | if(pc.rows() == 0) {
56 | std::cerr << "Error: Blue noise sampling failed to generate points.\n";
57 | return;
58 | }
59 |
60 | if (pc.rows() > num_points)
61 | {
62 | std::cout << "Trimming point cloud from " << pc.rows() << " to " << num_points << " points\n";
63 | std::vector<int> indices(pc.rows());
64 | std::iota(indices.begin(), indices.end(), 0);
65 | std::shuffle(indices.begin(), indices.end(), std::default_random_engine{});
66 | indices.resize(num_points);
67 | Eigen::MatrixXd trimmed_pc(num_points, pc.cols());
68 | //Eigen::MatrixXd trimmed_normals(num_points, normals.cols());
69 | std::cout << "pc.rows(): " << pc.rows() << ", pc.cols(): " << pc.cols() << std::endl;
70 | //std::cout << "normals.rows(): " << normals.rows() << ", normals.cols(): " << normals.cols() << std::endl;
71 | for (int i = 0; i < num_points; ++i)
72 | {
73 | //std::cout <<"for trimming" << " i is" << i << std::endl;
74 | trimmed_pc.row(i) = pc.row(indices[i]);
75 | //std::cout << "1" << std::endl;
76 | }
77 | pc = trimmed_pc;
78 | }
79 | }
80 |
81 | // NOTE: the original lines between the trimming loop and main()'s argument
82 | // check were garbled in extraction; the CSV writer below is a minimal
83 | // reconstruction using the CSVFormat defined above. (The original normal
84 | // computation could not be recovered, so L_normal/R_normal may be written out empty.)
85 | void write_matrix_to_csv(const std::string & file_path, const Eigen::MatrixXd & matrix)
86 | {
87 | std::ofstream file(file_path);
88 | file << matrix.format(CSVFormat);
89 | }
90 |
91 | int main(int argc, char * argv[])
92 | {
93 | if (argc != 4)
94 | {
95 | std::cerr << "Usage: " << argv[0] << " <mesh_A.obj> <mesh_B.obj> <save_root_directory>\n";
127 | return 1;
128 | }
129 |
130 | std::string mesh_file_left = argv[1];
131 | std::string mesh_file_right = argv[2];
132 | std::string save_root_directory = argv[3];
133 | Eigen::MatrixXd L_mesh_v, R_mesh_v;
134 | Eigen::MatrixXi L_mesh_f, R_mesh_f;
135 |
136 | if (!igl::read_triangle_mesh(mesh_file_left, L_mesh_v, L_mesh_f))
137 | {
138 | std::cerr << "Failed to load mesh from " << mesh_file_left << std::endl;
139 | return 1;
140 | }
141 |
142 | if (!igl::read_triangle_mesh(mesh_file_right, R_mesh_v, R_mesh_f))
143 | {
144 | std::cerr << "Failed to load mesh from " << mesh_file_right << std::endl;
145 | return 1;
146 | }
147 |
148 | std::cout << "Sampling point clouds...\n";
149 | int num_points = 1024;
150 | Eigen::MatrixXd L_pc, R_pc;
151 | Eigen::MatrixXd L_normal, R_normal;
152 | sample_pc_with_blue_noise(num_points, L_mesh_v, L_mesh_f, L_pc, L_normal);
153 | sample_pc_with_blue_noise(num_points, R_mesh_v, R_mesh_f, R_pc, R_normal);
154 |
155 | std::cout << "The target root file is " << save_root_directory << std::endl;
156 |
157 | std::cout << "Saving point clouds...\n";
158 | std::string L_pc_file = save_root_directory + '/' + "partA-pc.csv";
159 | std::string R_pc_file = save_root_directory + '/' + "partB-pc.csv";
160 | write_matrix_to_csv(L_pc_file, L_pc);
161 | write_matrix_to_csv(R_pc_file, R_pc);
162 |
163 | std::cout << "Saving normals...\n";
164 | std::string L_normal_file = save_root_directory + '/' + "partA-normal.csv";
165 | std::string R_normal_file = save_root_directory + '/' + "partB-normal.csv";
166 | write_matrix_to_csv(L_normal_file, L_normal);
167 | write_matrix_to_csv(R_normal_file, R_normal);
168 |
169 | std::cout << "Saved point clouds and normals to " << save_root_directory << std::endl;
170 |
171 | return 0;
172 | }
--------------------------------------------------------------------------------
/data_util/generate_pc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | input_folders=("processed_cups" "water_bottle_processed" "beer_bottle")
4 |
5 | for folder in "${input_folders[@]}"; do
6 | for subfolder in $(find "$folder" -type d); do
7 | partA_file="$subfolder/partA_new.obj"
8 | partB_file="$subfolder/partB_new.obj"
9 |
10 | if [[ -f "$partA_file" && -f "$partB_file" ]]; then
11 | echo "Processing $subfolder"
12 | ./generate_pc "$partA_file" "$partB_file" "$subfolder"
13 | else
14 | echo "Skipping $subfolder: partA_new.obj or partB_new.obj not found"
15 | fi
16 | done
17 | done
18 |
--------------------------------------------------------------------------------
/data_util/generate_urdf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import trimesh
3 |
4 | input_obj = "insert_flower/1/partA_new.obj"
5 | vhacd_obj = "insert_flower/1/vhacd_partA_new.obj"
6 | urdf_file = "insert_flower/1/partA.urdf"
7 |
8 | # Step 1: Simplify the mesh to serve as collision geometry (a stand-in for VHACD)
9 | mesh = trimesh.load(input_obj)
10 | print(f"loaded mesh with {len(mesh.faces)} faces and {len(mesh.vertices)} vertices")
11 |
12 | # simple = mesh.simplify_quadric_decimation(500)
13 |
14 | simple = mesh.simplify_quadric_decimation(0.8)
15 |
16 | simple.export(vhacd_obj)
17 |
18 | print(f"VHACD complete. Output saved to: {vhacd_obj}")
19 |
20 | # Step 2: Write URDF
21 | print(f"Generating URDF: {urdf_file}")
22 |
23 | # NOTE: the URDF body below is reconstructed (the original XML tags were
24 | # stripped in extraction); a minimal single-link template with the full mesh
25 | # as visual geometry and the simplified mesh as collision geometry.
26 | urdf_content = f"""<?xml version="1.0"?>
27 | <robot name="partA">
28 |   <link name="partA_link">
29 |     <visual>
30 |       <geometry>
31 |         <mesh filename="{input_obj}"/>
32 |       </geometry>
33 |     </visual>
34 |     <collision>
35 |       <geometry>
36 |         <mesh filename="{vhacd_obj}"/>
37 |       </geometry>
38 |     </collision>
39 |     <inertial>
40 |       <mass value="1.0"/>
41 |       <inertia ixx="1.0" ixy="0.0" ixz="0.0" iyy="1.0" iyz="0.0" izz="1.0"/>
42 |     </inertial>
43 |   </link>
44 | </robot>
45 | """
43 |
44 | with open(urdf_file, "w") as f:
45 | f.write(urdf_content)
46 |
47 | print(f"URDF file written: {urdf_file}")
--------------------------------------------------------------------------------
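One way to smoke-test the generated URDF is PyBullet, which is not part of `environment.yml`, so this is a hedged sketch assuming `pip install pybullet`; the path matches the example hard-coded in `generate_urdf.py`:

```python
# Load the generated URDF into a headless PyBullet session.
import pybullet as p

p.connect(p.DIRECT)  # headless physics server
body_id = p.loadURDF('insert_flower/1/partA.urdf')
print('loaded body id:', body_id)
p.disconnect()
```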
/data_util/preprocess_obj_blender.py:
--------------------------------------------------------------------------------
1 | # This script takes the raw part meshes, triangulates them with Blender,
2 | # then translates them by their combined center and normalizes the bounding box.
3 | # Finally, it saves the recentered, rescaled meshes to the instance folder.
4 | import open3d as o3d
5 | import numpy as np
6 | import os
7 | import time
8 | import matplotlib.pyplot as plt
9 | import json
10 | import trimesh
11 | import sys
12 | # import bpy
13 | import subprocess
14 | from tqdm import tqdm
15 |
16 | global_scale = 3
17 |
18 | current_dir = os.path.dirname(os.path.abspath(__file__))
19 | sys.path.append(current_dir)
20 |
21 | """
22 | partA.obj -> partA_tri.obj (triangulate
23 | partA_tri.obj -> partA_new.obj (recenter and rescale)
24 | normal_vector.json
25 | """
26 |
27 | # scale_factor = 2
28 | def file_exists(filepath):
29 | return os.path.exists(filepath)
30 |
31 | # activate when need to triangulate the meshes
32 |
33 | def triangulate_mesh_with_pyblender(instance_dir):
34 | blender_script_path = os.path.join(current_dir, 'blender_triangulate.py')
35 | subprocess.run(['blender', '--factory-startup', '--background', '--python', blender_script_path, '--', os.path.join(current_dir, instance_dir, 'partA.obj'), os.path.join(current_dir, instance_dir, 'partA_tri.obj')])
36 | subprocess.run(['blender','--factory-startup', '--background', '--python', blender_script_path, '--', os.path.join(current_dir, instance_dir, 'partB.obj'), os.path.join(current_dir, instance_dir, 'partB_tri.obj')])
37 |
38 |
39 |
40 | # use Blender to convert non-triangle meshes to triangle meshes
41 | def combine_meshes(instance_dir):
42 | triangulate_mesh_with_pyblender(instance_dir)
43 |
44 | # Load the two meshes
45 | mesh1_path = os.path.join(current_dir, instance_dir, "partA_tri.obj")
46 | mesh2_path = os.path.join(current_dir, instance_dir, "partB_tri.obj")
47 | combined_mesh_path = os.path.join(current_dir, instance_dir, "combined_mesh.obj")
48 |
49 | if not file_exists(mesh1_path) or not file_exists(mesh2_path):
50 | sys.exit(1)
51 |
52 |
53 | mesh1 = o3d.io.read_triangle_mesh(mesh1_path)
54 | # mesh1.triangulate()
55 | mesh2 = o3d.io.read_triangle_mesh(mesh2_path)
56 | # mesh2.triangulate()
57 |
58 | if not mesh1.has_vertices():
59 | raise ValueError("mesh1 does not contain any vertices.", mesh1_path)
60 | if not mesh2.has_vertices():
61 | raise ValueError("mesh2 does not contain any vertices.", mesh2_path)
62 |
63 | # Combine vertices and triangles
64 | combined_vertices = np.vstack((np.asarray(mesh1.vertices), np.asarray(mesh2.vertices)))
65 | combined_triangles = np.vstack((np.asarray(mesh1.triangles), np.asarray(mesh2.triangles) + len(mesh1.vertices)))
66 |
67 | # Create a new mesh with the combined vertices and triangles
68 | combined_mesh = o3d.geometry.TriangleMesh()
69 | combined_mesh.vertices = o3d.utility.Vector3dVector(combined_vertices)
70 | combined_mesh.triangles = o3d.utility.Vector3iVector(combined_triangles)
71 |
72 | # Optionally combine vertex colors if they exist
73 | if mesh1.has_vertex_colors() and mesh2.has_vertex_colors():
74 | combined_colors = np.vstack((np.asarray(mesh1.vertex_colors), np.asarray(mesh2.vertex_colors)))
75 | combined_mesh.vertex_colors = o3d.utility.Vector3dVector(combined_colors)
76 |
77 | # print("Combined mesh saved to {}".format(combined_mesh_path))
78 | return (mesh1, mesh2, combined_mesh)
79 |
80 | def calculate_center_and_scale_two_seperate_part(instance_dir):
81 |
82 | # calculate the center after combined two meshes and remove them to the center
83 | (mesh1, mesh2, combined_mesh) = combine_meshes(instance_dir)
84 | global_center = combined_mesh.get_center()
85 |
86 | mesh1.translate(-global_center)
87 | mesh2.translate(-global_center)
88 | combined_mesh.translate(-global_center)
89 |
90 | bounding_box = combined_mesh.get_axis_aligned_bounding_box()
91 | max_extent = np.max(bounding_box.get_extent())
92 | max_extent_index = np.argmax(bounding_box.get_extent())
93 | normal_vector = np.zeros(3)
94 | normal_vector[max_extent_index] = 1
95 |
96 | # Write stats to JSON file
97 | stats = {
98 | "Normal Vector":{
99 | "x": normal_vector[0],
100 | "y": normal_vector[1],
101 | "z": normal_vector[2]
102 | }
103 | }
104 |
105 | with open(os.path.join(current_dir, instance_dir, "normal_vector.json"), "w+") as f:
106 | json.dump(stats, f, indent=4)
107 |
108 |
109 | scale_factor = global_scale / max_extent
110 |
111 | mesh1.scale(scale_factor, (0,0,0))
112 | mesh2.scale(scale_factor, (0,0,0))
113 | combined_mesh.scale(scale_factor, (0,0,0))
114 |
115 | """
116 | save the final files after re-centering and re-scaling (normal_vector.json is written above)
117 | """
118 |
119 | o3d.io.write_triangle_mesh(os.path.join(current_dir, instance_dir, "partA_new.obj"), mesh1)
120 | o3d.io.write_triangle_mesh(os.path.join(current_dir, instance_dir, "partB_new.obj"), mesh2)
121 | o3d.io.write_triangle_mesh(os.path.join(current_dir, instance_dir, "combined_mesh.obj"), combined_mesh)
122 |
123 | # Example usage
124 | # calculate_center_and_scale_two_seperate_part("original-2.obj", "original-3.obj", "combined_mesh.obj")
125 |
126 |
127 | if __name__ == '__main__':
128 |
129 | instance_dir = sys.argv[1]
130 | print("In the processing of {}".format(instance_dir))
131 | calculate_center_and_scale_two_seperate_part(instance_dir)
132 |
133 |
134 |
135 |
--------------------------------------------------------------------------------
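The script recenters both parts jointly (translating by the combined mesh center) and rescales them so the combined bounding box's longest side equals `global_scale = 3`. A quick check of the outputs, sketched with Open3D and example paths:

```python
# Check that partA_new.obj / partB_new.obj are jointly centered and scaled.
import numpy as np
import open3d as o3d

a = o3d.io.read_triangle_mesh('partA_new.obj')
b = o3d.io.read_triangle_mesh('partB_new.obj')
v = np.vstack((np.asarray(a.vertices), np.asarray(b.vertices)))

print('vertex mean:', v.mean(axis=0))               # expected: ~(0, 0, 0)
print('max extent: ', (v.max(0) - v.min(0)).max())  # expected: ~3.0 (global_scale)
```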
/data_util/preprocess_obj_blender.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | input_folders=("processed_cups" "water_bottle_processed" "beer_bottle")
4 |
5 | for folder in "${input_folders[@]}"; do
6 | for subfolder in $(find "$folder" -type d); do
7 | partA_obj="$subfolder/partA.obj"
8 | partB_obj="$subfolder/partB.obj"
9 | # partA_ply="$subfolder/partA-pc.ply"
10 | # partB_ply="$subfolder/partB-pc.ply"
11 |
12 | if [[ -f "$partA_obj" && -f "$partB_obj" ]]; then
13 | echo "Processing $subfolder"
14 | python preprocess_obj_blender.py "$subfolder"
15 | # python csv_to_ply.py --input_csv "$partB_csv" --output_ply "$partB_ply"
16 | else
17 | echo "Skipping $subfolder: partA.obj or partB.obj not found"
18 | fi
19 | done
20 | done
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: twobytwo
2 |
3 | channels:
4 |
5 | - conda-forge
6 |
7 | - defaults
8 |
9 | dependencies:
10 |
11 | - _libgcc_mutex=0.1=conda_forge
12 |
13 | - _openmp_mutex=4.5=2_gnu
14 |
15 | - ca-certificates=2024.7.4=hbcca054_0
16 |
17 | - ld_impl_linux-64=2.40=hf3520f5_7
18 |
19 | - libffi=3.2.1=he1b5a44_1007
20 |
21 | - libgcc-ng=14.1.0=h77fa898_0
22 |
23 | - libgomp=14.1.0=h77fa898_0
24 |
25 | - libsqlite=3.46.0=hde9e2c9_0
26 |
27 | - libstdcxx-ng=14.1.0=hc0a3c3a_0
28 |
29 | - libzlib=1.2.13=h4ab18f5_6
30 |
31 | - ncurses=6.5=h59595ed_0
32 |
33 | - openssl=1.1.1w=hd590300_0
34 |
35 | - pip=24.0=pyhd8ed1ab_0
36 |
37 | - python=3.8.0=h357f687_5
38 |
39 | - readline=8.2=h8228510_1
40 |
41 | - setuptools=71.0.4=pyhd8ed1ab_0
42 |
43 | - sqlite=3.46.0=h6d4b2fc_0
44 |
45 | - tk=8.6.13=noxft_h4845f30_101
46 |
47 | - wheel=0.43.0=pyhd8ed1ab_1
48 |
49 | - xz=5.2.6=h166bdaf_0
50 |
51 | - zlib=1.2.13=h4ab18f5_6
52 |
53 | - pip:
54 |
55 | - addict==2.4.0
56 |
57 | - aiohttp==3.9.5
58 |
59 | - aiosignal==1.3.1
60 |
61 | - asttokens==2.4.1
62 |
63 | - async-timeout==4.0.3
64 |
65 | - attrs==23.2.0
66 |
67 | - backcall==0.2.0
68 |
69 | - blinker==1.8.2
70 |
71 | - certifi==2024.7.4
72 |
73 | - charset-normalizer==3.3.2
74 |
75 | - click==8.1.7
76 |
77 | - comm==0.2.2
78 |
79 | - configargparse==1.7
80 |
81 | - contourpy==1.1.1
82 |
83 | - cycler==0.12.1
84 |
85 | - dash==2.17.1
86 |
87 | - dash-core-components==2.0.0
88 |
89 | - dash-html-components==2.0.0
90 |
91 | - dash-table==5.0.0
92 |
93 | - decorator==5.1.1
94 |
95 | - docker-pycreds==0.4.0
96 |
97 | - einops==0.8.0
98 |
99 | - executing==2.0.1
100 |
101 | - fastjsonschema==2.20.0
102 |
103 | - filelock==3.15.4
104 |
105 | - flask==3.0.3
106 |
107 | - fonttools==4.53.1
108 |
109 | - freetype-py==2.4.0
110 |
111 | - frozenlist==1.4.1
112 |
113 | - fsspec==2024.6.1
114 |
115 | - fvcore==0.1.5.post20221221
116 |
117 | - gitdb==4.0.11
118 |
119 | - gitpython==3.1.43
120 |
121 | - h5py==3.11.0
122 |
123 | - idna==3.7
124 |
125 | - imageio==2.34.2
126 |
127 | - importlib-metadata==8.2.0
128 |
129 | - importlib-resources==6.4.0
130 |
131 | - iopath==0.1.10
132 |
133 | - ipdb==0.13.13
134 |
135 | - ipython==8.12.3
136 |
137 | - ipywidgets==8.1.3
138 |
139 | - itsdangerous==2.2.0
140 |
141 | - jedi==0.19.1
142 |
143 | - jinja2==3.1.4
144 |
145 | - joblib==1.4.2
146 |
147 | - jsonschema==4.23.0
148 |
149 | - jsonschema-specifications==2023.12.1
150 |
151 | - jupyter-core==5.7.2
152 |
153 | - jupyterlab-widgets==3.0.11
154 |
155 | - kiwisolver==1.4.5
156 |
157 | - lazy-loader==0.4
158 |
159 | - lightning-utilities==0.11.6
160 |
161 | - markupsafe==2.1.5
162 |
163 | - matplotlib==3.7.5
164 |
165 | - matplotlib-inline==0.1.7
166 |
167 | - mesh-to-sdf==0.0.15
168 |
169 | - mpmath==1.3.0
170 |
171 | - multidict==6.0.5
172 |
173 | - nbformat==5.10.4
174 |
175 | - nest-asyncio==1.6.0
176 |
177 | - networkx==3.1
178 |
179 | - ninja==1.11.1.1
180 |
181 | - numpy==1.24.4
182 |
183 | - nvidia-cublas-cu12==12.1.3.1
184 |
185 | - nvidia-cuda-cupti-cu12==12.1.105
186 |
187 | - nvidia-cuda-nvrtc-cu12==12.1.105
188 |
189 | - nvidia-cuda-runtime-cu12==12.1.105
190 |
191 | - nvidia-cudnn-cu12==9.1.0.70
192 |
193 | - nvidia-cufft-cu12==11.0.2.54
194 |
195 | - nvidia-curand-cu12==10.3.2.106
196 |
197 | - nvidia-cusolver-cu12==11.4.5.107
198 |
199 | - nvidia-cusparse-cu12==12.1.0.106
200 |
201 | - nvidia-nccl-cu12==2.20.5
202 |
203 | - nvidia-nvjitlink-cu12==12.5.82
204 |
205 | - nvidia-nvtx-cu12==12.1.105
206 |
207 | - open3d==0.18.0
208 |
209 | - packaging==24.1
210 |
211 | - pandas==2.0.3
212 |
213 | - parso==0.8.4
214 |
215 | - pexpect==4.9.0
216 |
217 | - pickleshare==0.7.5
218 |
219 | - pillow==10.4.0
220 |
221 | - pkgutil-resolve-name==1.3.10
222 |
223 | - platformdirs==4.2.2
224 |
225 | - plotly==5.23.0
226 |
227 | - portalocker==2.10.1
228 |
229 | - progressbar==2.5
230 |
231 | - prompt-toolkit==3.0.47
232 |
233 | - protobuf==5.27.2
234 |
235 | - psutil==6.0.0
236 |
237 | - ptyprocess==0.7.0
238 |
239 | - pure-eval==0.2.3
240 |
241 | - pyglet==2.0.16
242 |
243 | - pygments==2.18.0
244 |
245 | - pyopengl==3.1.0
246 |
247 | - pyparsing==3.1.2
248 |
249 | - pyquaternion==0.9.9
250 |
251 | - pyrender==0.1.45
252 |
253 | - python-dateutil==2.9.0.post0
254 |
255 | - pytorch-lightning==2.3.3
256 |
257 | - pytorch3d==0.3.0
258 |
259 | - pytz==2024.1
260 |
261 | - pywavelets==1.4.1
262 |
263 | - pyyaml==6.0.1
264 |
265 | - referencing==0.35.1
266 |
267 | - requests==2.32.3
268 |
269 | - retrying==1.3.4
270 |
271 | - rpds-py==0.19.1
272 |
273 | - scikit-image==0.21.0
274 |
275 | - scikit-learn==1.3.2
276 |
277 | - scipy==1.10.1
278 |
279 | - sentry-sdk==2.11.0
280 |
281 | - setproctitle==1.3.3
282 |
283 | - six==1.16.0
284 |
285 | - smmap==5.0.1
286 |
287 | - stack-data==0.6.3
288 |
289 | - sympy==1.13.1
290 |
291 | - tabulate==0.9.0
292 |
293 | - tenacity==8.5.0
294 |
295 | - termcolor==2.4.0
296 |
297 | - threadpoolctl==3.5.0
298 |
299 | - tifffile==2023.7.10
300 |
301 | - tomli==2.0.1
302 |
303 | - torch==2.4.0
304 |
305 | - torchmetrics==1.4.0.post0
306 |
307 | - torchvision==0.19.0
308 |
309 | - tqdm==4.66.4
310 |
311 | - traitlets==5.14.3
312 |
313 | - transforms3d==0.4.2
314 |
315 | - trimesh==4.4.3
316 |
317 | - triton==3.0.0
318 |
319 | - typing-extensions==4.12.2
320 |
321 | - tzdata==2024.1
322 |
323 | - urllib3==2.2.2
324 |
325 | - wandb==0.17.5
326 |
327 | - wcwidth==0.2.13
328 |
329 | - werkzeug==3.0.3
330 |
331 | - widgetsnbextension==4.0.11
332 |
333 | - yacs==0.1.8
334 |
335 | - yarl==1.9.4
336 |
337 | - zipp==3.19.2
338 |
339 | prefix: /localdata/miniconda3/envs/se3
340 |
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/.DS_Store
--------------------------------------------------------------------------------
/src/config/eval.yml:
--------------------------------------------------------------------------------
1 | exp:
2 | name: 'eval'
3 | checkpoint_dir: 'checkpoints'
4 | num_workers: 4
5 | batch_size: 4
6 | num_epochs: 1000
7 | gpus: [0]
8 | log_dir: '/home/two-by-two/logs/eval' # checkpoints will be read from this dir
9 | vis_dir: '/home/two-by-two/logs/eval'
10 |
11 | model:
12 | encoderA: 'vn_dgcnn'
13 | encoderB: 'vn_dgcnn'
14 | pose_predictor_quat: 'original'
15 | pose_predictor_rot: 'original'
16 | pose_predictor_trans: 'original'
17 | point_loss: 'True'
18 | recon_loss: 'False'
19 | corr_module: 'yes'
20 | pc_feat_dim: 512
21 | num_heads: 4
22 | num_blocks: 1
23 |
24 | cls_model:
25 | encoder: 'vn_dgcnn'
26 | point_loss: 'False'
27 | recon_loss: 'False'
28 | corr_module: 'No'
29 | pc_feat_dim: 512
30 | num_heads: 4
31 | num_blocks: 1
32 |
33 | # gpus: [0,1,2,3,4,5,6,7]
34 |
35 | # please modify the data path to the path of the generated point cloud
36 | data:
37 | root_dir: '/home/Shape_Data_pc'
38 | train_csv_file: '/home/Shape_Data_pc/stats/Train_All_List.txt'
39 | val_csv_file: '/home/Shape_Data_pc/stats/Test_All_List.txt'
40 | num_pc_points: 1024
41 | translation: ''
42 |
43 | optimizer:
44 | lr: 1e-4
45 | lr_decay: 0.0
46 | weight_decay: 1e-6
47 | lr_clip: 1e-5
--------------------------------------------------------------------------------
/src/config/train_A.yml:
--------------------------------------------------------------------------------
1 | exp:
2 | name: 'network_A'
3 | checkpoint_dir: 'checkpoints'
4 | num_workers: 4
5 | batch_size: 4
6 | num_epochs: 1000
7 | gpus: [0]
8 | log_dir: '/home/two-by-two/logs/network_A' #modify to the correct log path
9 | vis_dir: '/home/two-by-two/logs/network_A' #modify to the correct log path
10 |
11 | model:
12 | encoderA: 'vn_dgcnn'
13 | encoderB: 'vn_dgcnn'
14 | pose_predictor_quat: 'original'
15 | pose_predictor_rot: 'original'
16 | pose_predictor_trans: 'original'
17 | point_loss: 'True'
18 | recon_loss: 'False'
19 | corr_module: 'yes'
20 | pc_feat_dim: 512
21 | num_heads: 4
22 | num_blocks: 1
23 |
24 | cls_model:
25 | encoder: 'vn_dgcnn'
26 | point_loss: 'False'
27 | recon_loss: 'False'
28 | corr_module: 'No'
29 | pc_feat_dim: 512
30 | num_heads: 4
31 | num_blocks: 1
32 |
33 | # gpus: [0,1,2,3,4,5,6,7]
34 |
35 | # please modify the data path to the path of the generated point cloud
36 | data:
37 | root_dir: '/home/Shape_Data_pc'
38 | train_csv_file: '/home/Shape_Data_pc/stats/Train_All_List.txt'
39 | val_csv_file: '/home/Shape_Data_pc/stats/Test_All_List.txt'
40 | num_pc_points: 1024
41 | translation: ''
42 |
43 | optimizer:
44 | lr: 1e-4
45 | lr_decay: 0.0
46 | weight_decay: 1e-6
47 | lr_clip: 1e-5
--------------------------------------------------------------------------------
/src/config/train_B.yml:
--------------------------------------------------------------------------------
1 | exp:
2 | name: 'network_B'
3 | checkpoint_dir: 'checkpoints'
4 | num_workers: 4
5 | batch_size: 4
6 | num_epochs: 1000
7 | gpus: [0]
8 | log_dir: '/home/two-by-two/logs/network_B' #modify to the correct log path
9 | vis_dir: '/home/two-by-two/logs/network_B' #modify to the correct log path
10 |
11 | model:
12 | encoderA: 'vn_dgcnn'
13 | encoderB: 'vn_dgcnn'
14 | pose_predictor_quat: 'original'
15 | pose_predictor_rot: 'original'
16 | pose_predictor_trans: 'original'
17 | point_loss: 'True'
18 | recon_loss: 'False'
19 | corr_module: 'yes'
20 | pc_feat_dim: 512
21 | num_heads: 4
22 | num_blocks: 1
23 |
24 | cls_model:
25 | encoder: 'vn_dgcnn'
26 | point_loss: 'False'
27 | recon_loss: 'False'
28 | corr_module: 'No'
29 | pc_feat_dim: 512
30 | num_heads: 4
31 | num_blocks: 1
32 |
33 | # gpus: [0,1,2,3,4,5,6,7]
34 |
35 | # please modify the data path to the path of the generated point cloud
36 | data:
37 | root_dir: '/home/Shape_Data_pc'
38 | train_csv_file: '/home/Shape_Data_pc/stats/Train_All_List.txt'
39 | val_csv_file: '/home/Shape_Data_pc/stats/Test_All_List.txt'
40 | num_pc_points: 1024
41 | translation: ''
42 |
43 | optimizer:
44 | lr: 1e-4
45 | lr_decay: 0.0
46 | weight_decay: 1e-6
47 | lr_clip: 1e-5
--------------------------------------------------------------------------------
/src/script/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/script/.DS_Store
--------------------------------------------------------------------------------
/src/script/our_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | import argparse
4 | import wandb
5 |
6 | wandb.require("core")
7 |
8 | os.environ['TORCH_CUDA_ARCH_LIST'] = '8.6'
9 |
10 | # Suppress specific warnings
11 | warnings.filterwarnings('ignore', category=UserWarning, message='.*TORCH_CUDA_ARCH_LIST.*')
12 |
13 | import torch.distributed as dist
14 | import torch.multiprocessing as mp
15 | from torch.nn.parallel import DistributedDataParallel as DDP
16 | from torch.nn import DataParallel
17 | from torch.utils.data import DataLoader
18 | from torch.utils.data.distributed import DistributedSampler
19 | import torch
20 |
21 | import sys
22 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
23 | sys.path.append(os.path.join(BASE_DIR, '../shape_assembly'))
24 | from config import get_cfg_defaults
25 | from datasets.dataloader.dataloader_B import OurDataset
26 | from models.train.network_vnn_B import ShapeAssemblyNet_B_vnn
27 | from models.train.network_vnn_A import ShapeAssemblyNet_A_vnn
28 | import utils
29 | from tqdm import tqdm
30 |
31 | orange = '\033[38;5;214m'
32 | reset = '\033[0m'
33 |
34 | def setup(rank, world_size, cfg):
35 | os.environ['MASTER_ADDR'] = 'localhost'
36 | os.environ['MASTER_PORT'] = '52997'
37 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
38 | torch.cuda.set_device(rank)
39 |
40 | def cleanup():
41 | dist.destroy_process_group()
42 |
43 |
44 | def train(rank, world_size, conf):
45 | # create the log_file
46 | os.makedirs(conf.exp.log_dir, exist_ok=True)
47 |
48 | # create the vis file
49 | os.makedirs(conf.exp.vis_dir, exist_ok=True)
50 |
51 | setup(rank, world_size, conf)
52 |
53 | if dist.get_rank() == 0:
54 | wandb.init(project='shape-matching', notes='weak baseline', config=conf)
55 | wandb.define_metric("test/epoch/*", step_metric="test/epoch/epoch")
56 | # wandb.define_metric("val/step/*", step_metric="val/step/step")
57 | wandb.define_metric("train/network_B/*", step_metric="train/network_B/step")
58 | wandb.define_metric("train/network_A/*", step_metric="train/network_A/step")
59 |
60 | data_features = ['src_pc', 'src_rot', 'src_trans', 'tgt_pc', 'tgt_rot', 'tgt_trans', 'partA_symmetry_type', 'partB_symmetry_type','predicted_partB_rotation', 'predicted_partB_position', 'predicted_partA_rotation', 'predicted_partA_position']
61 |
62 | network_B = ShapeAssemblyNet_B_vnn(cfg=conf, data_features=data_features)
63 | network_B.cuda(rank)
64 |
65 | # Load pretrained weights for network_B
66 | network_B.load_state_dict(torch.load(os.path.join(conf.exp.log_dir, 'ckpts', 'network_B.pth')))
67 |
68 | for param in network_B.parameters():
69 | param.requires_grad = False
70 |
71 |
72 | network_A = ShapeAssemblyNet_A_vnn(cfg=conf, data_features=data_features)
73 | network_A.cuda(rank)
74 |
75 | # Load pretrained weights for network_A
76 | network_A.load_state_dict(torch.load(os.path.join(conf.exp.log_dir, 'ckpts', 'network_A.pth')))
77 | network_A = DDP(network_A, device_ids=[rank])
78 |
79 | for param in network_A.parameters():
80 |         param.requires_grad = False
81 |
82 | # Initialize train dataloader
83 | train_data = OurDataset(
84 | data_root_dir=conf.data.root_dir,
85 | data_csv_file=conf.data.train_csv_file,
86 | data_features=data_features,
87 | num_points=conf.data.num_pc_points
88 | )
89 | train_data.load_data()
90 |
91 | print('Len of Train Data: ', len(train_data))
92 |
93 | # Initialize val dataloader
94 | val_data = OurDataset(
95 | data_root_dir=conf.data.root_dir,
96 | data_csv_file=conf.data.val_csv_file,
97 | data_features=data_features,
98 | num_points=conf.data.num_pc_points
99 | )
100 | val_data.load_data()
101 |
102 | print('Len of Val Data: ', len(val_data))
103 |
104 | val_sampler = DistributedSampler(val_data, num_replicas=world_size, rank=rank)
105 | val_dataloader = DataLoader(
106 | dataset=val_data,
107 | batch_size=conf.exp.batch_size,
108 | num_workers=conf.exp.num_workers,
109 | pin_memory=True,
110 | # shuffle=True,
111 | shuffle=False,
112 | drop_last=False,
113 | sampler=val_sampler
114 | )
115 |
116 | network_opt = torch.optim.Adam(list(network_A.parameters()), lr=conf.optimizer.lr, weight_decay=conf.optimizer.weight_decay)
117 | val_num_batch = len(val_dataloader)
118 | print(f"\033[33mNumber of batches in val_dataloader: {len(val_dataloader)}\033[0m")
119 | print(f"\033[33mNumber of samples in val_data: {len(val_data)}\033[0m")
120 |
121 | # === Evaluation ===
122 | tot, tot_gd, tot_r, tot_t, tot_CD_A, tot_CD_B = 0, 0, 0, 0, 0, 0
123 |
124 | network_A.train()
125 | network_B.train()
126 |
127 | for val_batch_ind, val_batch in enumerate(val_dataloader):
128 | for key in val_batch.keys():
129 | if key not in ['category_name', 'cut_name', 'shape_id', 'result_id']:
130 | val_batch[key] = val_batch[key].to(rank)
131 | with torch.no_grad():
132 | partB_eval_metric, partB_pos, partB_rot, *_ = network_B.forward_pass(
133 | batch_data=val_batch, device=rank, mode='val', vis_idx=val_batch_ind)
134 | val_batch['predicted_partB_position'] = partB_pos
135 | val_batch['predicted_partB_rotation'] = partB_rot
136 |
137 | partA_eval_metric, *_ = network_A.module.forward_pass(
138 | batch_data=val_batch, device=rank, mode='val', vis_idx=val_batch_ind)
139 |
140 | GD_B, R_B, T_B, _, _, CD_B = partB_eval_metric
141 | GD_A, R_A, T_A, _, _, CD_A = partA_eval_metric
142 |
143 | tot += 1
144 | tot_gd += (GD_A.mean() + GD_B.mean()) / 2
145 | tot_r += (R_A.mean() + R_B.mean()) / 2
146 | tot_t += (T_A.mean() + T_B.mean()) / 2
147 | tot_CD_A += CD_A.mean()
148 | tot_CD_B += CD_B.mean()
149 |
150 | if dist.get_rank() == 0:
151 | print(f"\033[1;32m[Eval Results]\033[0m")
152 | print(f"avg_gd: {tot_gd / tot}")
153 | print(f"avg_r: {tot_r / tot}")
154 | print(f"avg_t: {tot_t / tot}")
155 | print(f"avg_CD_A: {tot_CD_A / tot}")
156 | print(f"avg_CD_B: {tot_CD_B / tot}")
157 |
158 | wandb.log({
159 | 'test/epoch/avg_gd': (tot_gd / tot).item(),
160 | 'test/epoch/avg_r': (tot_r / tot).item(),
161 | 'test/epoch/avg_t': (tot_t / tot).item(),
162 | 'test/epoch/avg_CD_A': (tot_CD_A / tot).item(),
163 | 'test/epoch/avg_CD_B': (tot_CD_B / tot).item(),
164 | 'test/epoch/epoch': 0
165 | })
166 |
167 | dist.barrier()
168 | cleanup()
169 |
170 |
171 | def main(cfg):
172 | world_size = len(cfg.gpus)
173 |
174 | mp.spawn(train, args=(world_size, cfg), nprocs=world_size, join=True)
175 |
176 | if __name__ == '__main__':
177 |
178 | wandb.login()
179 |
180 | parser = argparse.ArgumentParser(description="Training script")
181 | parser.add_argument('--cfg_file', default='', type=str)
182 | parser.add_argument('--gpus', nargs='+', default=-1, type=int)
183 |
184 | args = parser.parse_args()
185 | args.cfg_file = os.path.join('./config', args.cfg_file)
186 |
187 | cfg = get_cfg_defaults()
188 | cfg.merge_from_file(args.cfg_file)
189 |
190 | cfg.gpus = cfg.exp.gpus
191 |
192 | cfg.freeze()
193 | print(cfg)
194 |
195 | main(cfg)
--------------------------------------------------------------------------------
/src/script/our_train_A.py:
--------------------------------------------------------------------------------
1 | """
2 | train network_A based on partB's ground truth
3 | """
4 |
5 |
6 | import os
7 | import warnings
8 | import argparse
9 | import wandb
10 |
11 | wandb.require("core")
12 |
13 | os.environ['TORCH_CUDA_ARCH_LIST'] = '8.6'
14 |
15 | # Suppress specific warnings
16 | warnings.filterwarnings('ignore', category=UserWarning, message='.*TORCH_CUDA_ARCH_LIST.*')
17 |
18 | import torch.distributed as dist
19 | import torch.multiprocessing as mp
20 | from torch.nn.parallel import DistributedDataParallel as DDP
21 | from torch.nn import DataParallel
22 | from torch.utils.data import DataLoader
23 | from torch.utils.data.distributed import DistributedSampler
24 | import torch
25 |
26 | import sys
27 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
28 | sys.path.append(os.path.join(BASE_DIR, '../shape_assembly'))
29 | from config import get_cfg_defaults
30 | from datasets.dataloader.dataloader_A import OurDataset
31 | from models.train.network_vnn_A_indi import ShapeAssemblyNet_A_vnn
32 | import utils
33 | from tqdm import tqdm
34 |
35 | orange = '\033[38;5;214m'
36 | reset = '\033[0m'
37 |
38 | def setup(rank, world_size, cfg):
39 | os.environ['MASTER_ADDR'] = 'localhost'
40 | os.environ['MASTER_PORT'] = '13013'
41 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
42 | torch.cuda.set_device(rank)
43 |
44 | def cleanup():
45 | dist.destroy_process_group()
46 |
47 | def train(rank, world_size, conf):
48 | os.makedirs(conf.exp.log_dir, exist_ok=True)
49 | os.makedirs(conf.exp.vis_dir, exist_ok=True)
50 | setup(rank, world_size, conf)
51 |
52 | if dist.get_rank() == 0:
53 | wandb.init(project='shape-matching', notes='weak baseline', config=conf)
54 | wandb.define_metric("test/epoch/*", step_metric="test/epoch/epoch")
55 | wandb.define_metric("train/network_A/*", step_metric="train/network_A/step")
56 |
57 | data_features = ['src_pc', 'src_rot', 'src_trans', 'tgt_pc', 'tgt_rot', 'tgt_trans', 'partA_symmetry_type', 'partB_symmetry_type','predicted_partB_rotation', 'predicted_partB_position', 'predicted_partA_rotation', 'predicted_partA_position']
58 |
59 | network_A = ShapeAssemblyNet_A_vnn(cfg=conf, data_features=data_features)
60 | network_A.cuda(rank)
61 | network_A = DDP(network_A, device_ids=[rank])
62 |
63 | # Initialize train dataloader
64 | train_data = OurDataset(
65 | data_root_dir=conf.data.root_dir,
66 | data_csv_file=conf.data.train_csv_file,
67 | data_features=data_features,
68 | num_points=conf.data.num_pc_points
69 | )
70 | train_data.load_data()
71 |
72 | print('Len of Train Data: ', len(train_data))
73 |
74 | train_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank)
75 | train_dataloader = DataLoader(
76 | dataset=train_data,
77 | batch_size=conf.exp.batch_size,
78 | num_workers=conf.exp.num_workers,
79 | pin_memory=True,
80 | shuffle=False,
81 | drop_last=False,
82 | sampler=train_sampler
83 | )
84 |
85 | # Initialize val dataloader
86 | val_data = OurDataset(
87 | data_root_dir=conf.data.root_dir,
88 | data_csv_file=conf.data.val_csv_file,
89 | data_features=data_features,
90 | num_points=conf.data.num_pc_points
91 | )
92 | val_data.load_data()
93 |
94 | # Output the distribution of the validation data
95 | print('Len of Val Data: ', len(val_data))
96 |
97 | val_sampler = DistributedSampler(val_data, num_replicas=world_size, rank=rank)
98 | val_dataloader = DataLoader(
99 | dataset=val_data,
100 | batch_size=conf.exp.batch_size,
101 | num_workers=conf.exp.num_workers,
102 | pin_memory=True,
103 | # shuffle=True,
104 | shuffle=False,
105 | drop_last=False,
106 | sampler=val_sampler
107 | )
108 |
109 | network_opt = torch.optim.Adam(list(network_A.parameters()), lr=conf.optimizer.lr, weight_decay=conf.optimizer.weight_decay)
110 | val_num_batch = len(val_dataloader)
111 | print(f"\033[33mNumber of batches in val_dataloader: {len(val_dataloader)}\033[0m")
112 | print(f"\033[33mNumber of samples in val_data: {len(val_data)}\033[0m")
113 | train_num_batch = len(train_dataloader)
114 |
115 | val_step = 0
116 | for epoch in tqdm(range(1, conf.exp.num_epochs + 1)):
117 | print("\033[1;32m", "Epoch {} In training".format(epoch), "\033[0m")
118 | train_dataloader.sampler.set_epoch(epoch)
119 | val_dataloader.sampler.set_epoch(epoch)
120 |
121 | train_batches = enumerate(train_dataloader, 0)
122 | val_batches = enumerate(val_dataloader, 0)
123 | val_fraction_done = 0.0
124 | val_batch_ind = -1
125 |
126 | # train for every batch
127 | for train_batch_ind, batch in train_batches:
128 | if train_batch_ind % 50 == 0 and dist.get_rank() == 0:
129 | print("*" * 10)
130 | print(epoch, train_batch_ind)
131 | # print("*" * 10)
132 | train_fraction_done = (train_batch_ind + 1) / train_num_batch
133 | train_step = epoch * train_num_batch + train_batch_ind
134 |
135 | log_console = True
136 |
137 |             # save checkpoint of network_A
138 | if epoch % 10 == 0 and train_batch_ind == 0 and dist.get_rank() == 0:
139 | with torch.no_grad():
140 |
141 | os.makedirs(os.path.join(conf.exp.log_dir, 'ckpts'), exist_ok=True)
142 | torch.save(network_A.module.state_dict(), os.path.join(conf.exp.log_dir, 'ckpts', '%d-network_A.pth' % epoch), _use_new_zipfile_serialization=False)
143 |
144 | network_A.train()
145 | for key in batch.keys():
146 | if key not in ['category_name', 'cut_name', 'shape_id', 'result_id']:
147 | batch[key] = batch[key].cuda(non_blocking=True)
148 |
149 | partA_predictions = network_A.module.training_step(batch_data=batch, device=rank, batch_idx=train_batch_ind)
150 | total_loss_A = partA_predictions["total_loss"]
151 | rot_loss_A = partA_predictions["rot_loss"]
152 | trans_loss_A = partA_predictions["trans_loss"]
153 |
154 | if train_batch_ind % 50 == 0 and dist.get_rank() == 0:
155 | print(total_loss_A.detach().cpu().numpy())
156 |
157 | step = train_step
158 | if dist.get_rank() == 0:
159 | wandb.log({'train/network_A/total_loss': total_loss_A.item(), 'train/network_A/rot_loss': rot_loss_A.item(), 'train/network_A/trans_loss':trans_loss_A.item(), 'train/network_A/step':step})
160 |
161 |
162 | total_loss = total_loss_A
163 |
164 | # optimize one step
165 | network_opt.zero_grad()
166 | total_loss.backward()
167 | network_opt.step()
168 |
169 | dist.barrier()
170 |
171 | """
172 | In evaluation mode:
173 | """
174 |
175 | if epoch % 1 == 0:
176 | tot = 0
177 | tot_gd_A = 0
178 | # tot_gd_B = 0
179 | tot_r_A = 0
180 | # tot_r_B = 0
181 | tot_t_A = 0
182 | # tot_t_B = 0
183 | tot_pa_A = 0
184 | # tot_pa_B = 0
185 | tot_pa = 0
186 | tot_t = 0
187 | tot_r = 0
188 | tot_gd = 0
189 | tot_pa_threshold = 0
190 | tot_CD_A = 0
191 | # tot_CD_B = 0
192 | tot_pa_threshold_A = 0
193 | # tot_pa_threshold_B = 0
194 | val_batches = enumerate(val_dataloader, 0)
195 | val_fraction_done = 0.0
196 | val_batch_ind = -1
197 | # device = torch.device('cuda:0')
198 |
199 | # train for every batch
200 |
201 | total_loss_epoch = 0
202 | rot_loss_epoch = 0
203 | trans_loss_epoch = 0
204 |
205 | for val_batch_ind, val_batch in val_batches:
206 | if val_batch_ind % 50 == 0:
207 | print("*" * 10)
208 | print(epoch, val_batch_ind)
209 | print("*" * 10)
210 |
211 |
212 | network_A.train()
213 | # network_B.train()
214 |
215 | for key in val_batch.keys():
216 | if key not in ['category_name', 'cut_name', 'shape_id', 'result_id']:
217 | val_batch[key] = val_batch[key].to(rank)
218 | with torch.no_grad():
219 |
220 | partA_eval_metric, partA_total_loss, partA_point_loss, partA_rot_loss, partA_trans_loss, partA_recon_loss = network_A.module.forward_pass(batch_data=val_batch, device=rank, mode='val', vis_idx=val_batch_ind)
221 | GD_A, R_error_A, RMSE_T_A, PA_threshold_A, PA_A, CD_A = partA_eval_metric
222 |
223 | total_loss_epoch += partA_total_loss.item()
224 | rot_loss_epoch += partA_rot_loss.item()
225 | trans_loss_epoch += partA_trans_loss.item()
226 |
227 | val_step += 1
228 |
229 | tot_gd_A += GD_A.mean()
230 | # tot_gd_B += GD_B.mean()
231 | tot_r_A += R_error_A.mean() # the rotation loss for each batch
232 | # tot_r_B += R_error_B.mean() # the rotation loss for each batch
233 | tot_t_A += RMSE_T_A.mean()
234 | # tot_t_B += RMSE_T_B.mean()
235 | tot_pa_threshold_A += PA_threshold_A.mean()
236 | # tot_pa_threshold_B += PA_threshold_B.mean()
237 | tot_pa_A += PA_A.mean()
238 | # tot_pa_B += PA_B.mean()
239 | tot_CD_A += CD_A.mean()
240 | # tot_CD_B += CD_B.mean()
241 |
242 | tot_r = tot_r_A
243 | # tot_r = tot_r_A
244 | tot_t = tot_t_A
245 | # tot_t = tot_t_A
246 | tot_gd = tot_gd_A
247 | tot += 1
248 |
249 | if dist.get_rank() == 0:
250 | print("\033[1;32m", "Epoch {} In validation".format(epoch), "\033[0m")
251 |
252 | print("avg_gd: ", tot_gd / tot)
253 | print("avg_r: ", tot_r / tot)
254 | print("avg_t: ", tot_t / tot)
255 |
256 | print("avg_CD_1: ", tot_CD_A / tot)
257 |
258 | log_data = {
259 | 'test/epoch/avg_gd':(tot_gd/tot).item(),
260 | 'test/epoch/avg_r': (tot_r/tot).item(),
261 | 'test/epoch/avg_t': (tot_t/tot).item(),
262 | 'test/epoch/avg_CD_A': (tot_CD_A/tot).item(),
263 | 'test/epoch/total_loss': total_loss_epoch,
264 | 'test/epoch/rot_loss': rot_loss_epoch,
265 | 'test/epoch/trans_loss': trans_loss_epoch,
266 | 'test/epoch/epoch': epoch
267 | }
268 | wandb.log(
269 | log_data
270 | )
271 |
272 | dist.barrier()
273 | cleanup()
274 |
275 |
276 | def main(cfg):
277 | world_size = len(cfg.gpus)
278 |
279 | mp.spawn(train, args=(world_size, cfg), nprocs=world_size, join=True)
280 |
281 |
282 | if __name__ == '__main__':
283 |
284 | wandb.login()
285 |
286 | parser = argparse.ArgumentParser(description="Training script")
287 | parser.add_argument('--cfg_file', default='', type=str)
288 | parser.add_argument('--gpus', nargs='+', default=-1, type=int)
289 |
290 | args = parser.parse_args()
291 | args.cfg_file = os.path.join('./config', args.cfg_file)
292 |
293 | cfg = get_cfg_defaults()
294 | cfg.merge_from_file(args.cfg_file)
295 |
296 | cfg.gpus = cfg.exp.gpus
297 |
298 | cfg.freeze()
299 | print(cfg)
300 |
301 | main(cfg)
--------------------------------------------------------------------------------
/src/script/our_train_B.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | import argparse
4 | import wandb
5 |
6 | wandb.require("core")
7 |
8 | os.environ['TORCH_CUDA_ARCH_LIST'] = '8.6'
9 |
10 | # Suppress specific warnings
11 | warnings.filterwarnings('ignore', category=UserWarning, message='.*TORCH_CUDA_ARCH_LIST.*')
12 |
13 | import torch.distributed as dist
14 | import torch.multiprocessing as mp
15 | from torch.nn.parallel import DistributedDataParallel as DDP
16 | from torch.nn import DataParallel
17 | from torch.utils.data import DataLoader
18 | from torch.utils.data.distributed import DistributedSampler
19 | import torch
20 |
21 | import sys
22 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
23 | sys.path.append(os.path.join(BASE_DIR, '../shape_assembly'))
24 | from config import get_cfg_defaults
25 | from datasets.dataloader.dataloader_B import OurDataset
26 | from models.train.network_vnn_B import ShapeAssemblyNet_B_vnn
27 | import utils
28 | # from tensorboardX import SummaryWriter
29 | from tqdm import tqdm
30 |
31 | orange = '\033[38;5;214m'
32 | reset = '\033[0m'
33 |
34 | def setup(rank, world_size, cfg):
35 | os.environ['MASTER_ADDR'] = 'localhost'
36 | os.environ['MASTER_PORT'] = '12719'
37 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
38 | torch.cuda.set_device(rank)
39 |
40 | def cleanup():
41 | dist.destroy_process_group()
42 |
43 | def train(rank, world_size, conf):
44 | os.makedirs(conf.exp.log_dir, exist_ok=True)
45 | os.makedirs(conf.exp.vis_dir, exist_ok=True)
46 | setup(rank, world_size, conf)
47 |
48 | if dist.get_rank() == 0:
49 | wandb.init(project='shape-matching', notes='weak baseline', config=conf)
50 | wandb.define_metric("test/epoch/*", step_metric="test/epoch/epoch")
51 | wandb.define_metric("train/network_B/*", step_metric="train/network_B/step")
52 |
53 | data_features = ['src_pc', 'src_rot', 'src_trans', 'tgt_pc', 'tgt_rot', 'tgt_trans', 'partA_symmetry_type', 'partB_symmetry_type', 'predicted_partB_rotation', 'predicted_partB_position', 'predicted_partA_rotation', 'predicted_partA_position']
54 |
55 | network_B = ShapeAssemblyNet_B_vnn(cfg=conf, data_features=data_features)
56 | network_B.cuda(rank)
57 | network_B = DDP(network_B, device_ids=[rank])
58 |
59 | # Initialize train dataloader
60 | train_data = OurDataset(
61 | data_root_dir=conf.data.root_dir,
62 | data_csv_file=conf.data.train_csv_file,
63 | data_features=data_features,
64 | num_points=conf.data.num_pc_points
65 | )
66 | train_data.load_data()
67 |
68 | print('Len of Train Data: ', len(train_data))
69 |
70 | train_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank)
71 | train_dataloader = DataLoader(
72 | dataset=train_data,
73 | batch_size=conf.exp.batch_size,
74 | num_workers=conf.exp.num_workers,
75 | pin_memory=True,
76 | shuffle=False,
77 | drop_last=False,
78 | sampler=train_sampler
79 | )
80 |
81 | # Initialize val dataloader
82 | val_data = OurDataset(
83 | data_root_dir=conf.data.root_dir,
84 | data_csv_file=conf.data.val_csv_file,
85 | data_features=data_features,
86 | num_points=conf.data.num_pc_points
87 | )
88 | val_data.load_data()
89 |
90 | print('Len of Val Data: ', len(val_data))
91 |
92 | val_sampler = DistributedSampler(val_data, num_replicas=world_size, rank=rank)
93 | val_dataloader = DataLoader(
94 | dataset=val_data,
95 | batch_size=conf.exp.batch_size,
96 | num_workers=conf.exp.num_workers,
97 | pin_memory=True,
98 | shuffle=False,
99 | drop_last=False,
100 | sampler=val_sampler
101 | )
102 |
103 | network_opt = torch.optim.Adam(list(network_B.parameters()), lr=conf.optimizer.lr, weight_decay=conf.optimizer.weight_decay)
104 | val_num_batch = len(val_dataloader)
105 | print(f"\033[33mNumber of batches in val_dataloader: {len(val_dataloader)}\033[0m")
106 | print(f"\033[33mNumber of samples in val_data: {len(val_data)}\033[0m")
107 | train_num_batch = len(train_dataloader)
108 |
109 | val_step = 0
110 | for epoch in tqdm(range(1, conf.exp.num_epochs + 1)):
111 | print("\033[1;32m", "Epoch {} In training".format(epoch), "\033[0m")
112 | train_dataloader.sampler.set_epoch(epoch)
113 | val_dataloader.sampler.set_epoch(epoch)
114 |
115 | train_batches = enumerate(train_dataloader, 0)
116 | val_batches = enumerate(val_dataloader, 0)
117 | val_fraction_done = 0.0
118 | val_batch_ind = -1
119 |
120 | # train for every batch
121 | for train_batch_ind, batch in train_batches:
122 | if train_batch_ind % 50 == 0 and dist.get_rank() == 0:
123 | print("*" * 10)
124 | print(epoch, train_batch_ind)
125 | # print("*" * 10)
126 | train_fraction_done = (train_batch_ind + 1) / train_num_batch
127 | train_step = epoch * train_num_batch + train_batch_ind
128 |
129 | log_console = True
130 |
131 | # save checkpoint of network_B
132 | if epoch % 10 == 0 and train_batch_ind == 0 and dist.get_rank() == 0:
133 | with torch.no_grad():
134 |
135 | os.makedirs(os.path.join(conf.exp.log_dir, 'ckpts'), exist_ok=True)
136 | torch.save(network_B.module.state_dict(), os.path.join(conf.exp.log_dir, 'ckpts', '%d-network_B.pth' % epoch), _use_new_zipfile_serialization=False)
137 |
138 | network_B.train()
139 | for key in batch.keys():
140 | if key not in ['category_name', 'cut_name', 'shape_id', 'result_id']:
141 | batch[key] = batch[key].cuda(non_blocking=True)
142 |
143 | partB_predictions = network_B.module.training_step(batch_data=batch, device=rank, batch_idx=train_batch_ind)
144 | total_loss_B = partB_predictions["total_loss"]
145 | rot_loss_B = partB_predictions["rot_loss"]
146 | trans_loss_B = partB_predictions["trans_loss"]
147 |
148 | batch['predicted_partB_position'] = partB_predictions['predicted_partB_position']
149 | batch['predicted_partB_rotation'] = partB_predictions['predicted_partB_rotation']
150 |
151 | if train_batch_ind % 50 == 0 and dist.get_rank() == 0:
152 | print(total_loss_B.detach().cpu().numpy())
153 |
154 | step = train_step
155 | if dist.get_rank() == 0:
156 | wandb.log({'train/network_B/total_loss': total_loss_B.item(), 'train/network_B/rot_loss': rot_loss_B.item(), 'train/network_B/trans_loss': trans_loss_B.item(), 'train/network_B/step':step})
157 |
158 | step = train_step
159 |
160 | total_loss = total_loss_B
161 |
162 | network_opt.zero_grad()
163 | total_loss.backward()
164 | network_opt.step()
165 |
166 | dist.barrier()
167 |
168 | if epoch % 1 == 0:  # validate every epoch
169 | tot = 0
170 | tot_gd_A = 0
171 | tot_gd_B = 0
172 | tot_r_A = 0
173 | tot_r_B = 0
174 | tot_t_A = 0
175 | tot_t_B = 0
176 | tot_pa_A = 0
177 | tot_pa_B = 0
178 | tot_pa = 0
179 | tot_t = 0
180 | tot_r = 0
181 | tot_gd = 0
182 | tot_pa_threshold = 0
183 | tot_CD_A = 0
184 | tot_CD_B = 0
185 | tot_pa_threshold_A = 0
186 | tot_pa_threshold_B = 0
187 | val_batches = enumerate(val_dataloader, 0)
188 | val_fraction_done = 0.0
189 | val_batch_ind = -1
190 |
191 | total_loss_epoch = 0
192 | rot_loss_epoch = 0
193 | trans_loss_epoch = 0
194 |
195 | for val_batch_ind, val_batch in val_batches:
196 | if val_batch_ind % 50 == 0:
197 | print("*" * 10)
198 | print(epoch, val_batch_ind)
199 | print("*" * 10)
200 |
201 | # evaluation pass: freeze BatchNorm statistics
202 | network_B.eval()
203 |
204 | for key in val_batch.keys():
205 | if key not in ['category_name', 'cut_name', 'shape_id', 'result_id']:
206 | val_batch[key] = val_batch[key].to(rank)
207 | with torch.no_grad():
208 |
209 | partB_eval_metric, partB_position, partB_rotation, partB_total_loss, partB_point_loss, partB_rot_loss, partB_trans_loss, partB_recon_loss = network_B.module.forward_pass(batch_data=val_batch, device=rank, mode='val', vis_idx=val_batch_ind)
210 | GD_B, R_error_B, RMSE_T_B, PA_threshold_B, PA_B, CD_B = partB_eval_metric
211 |
212 | total_loss_epoch += partB_total_loss.item()
213 | rot_loss_epoch += partB_rot_loss.item()
214 | trans_loss_epoch += partB_trans_loss.item()
215 |
216 | val_batch['predicted_partB_position'] = partB_position
217 | val_batch['predicted_partB_rotation'] = partB_rotation
218 |
219 | val_step += 1
220 |
221 |
222 |
223 | # tot_gd_A += GD_A.mean()
224 | tot_gd_B += GD_B.mean()
225 | # tot_r_A += R_error_A.mean() # the rotation loss for each batch
226 | tot_r_B += R_error_B.mean() # the rotation loss for each batch
227 | # tot_t_A += RMSE_T_A.mean()
228 | tot_t_B += RMSE_T_B.mean()
229 | # tot_pa_threshold_A += PA_threshold_A.mean()
230 | tot_pa_threshold_B += PA_threshold_B.mean()
231 | # tot_pa_A += PA_A.mean()
232 | tot_pa_B += PA_B.mean()
233 | # tot_CD_A += CD_A.mean()
234 | tot_CD_B += CD_B.mean()
235 |
236 | tot_r = tot_r_B
237 | tot_t = tot_t_B
238 | tot_gd = tot_gd_B
239 | tot += 1
240 |
241 | if dist.get_rank() == 0:
242 |
243 | print("\033[1;32m", "Epoch {} In validation".format(epoch), "\033[0m")
244 |
245 | print("avg_gd: ", tot_gd / tot)
246 | print("avg_r: ", tot_r / tot)
247 | print("avg_t: ", tot_t / tot)
248 | # part A metrics are handled in our_train_A.py; only part B is tracked here
249 | print("avg_CD_B: ", tot_CD_B / tot)
250 |
251 | log_data = {
252 | 'test/epoch/avg_gd': (tot_gd/tot).item(),
253 | 'test/epoch/avg_r': (tot_r/tot).item(),
254 | 'test/epoch/avg_t': (tot_t/tot).item(),
255 | 'test/epoch/avg_CD_B': (tot_CD_B/tot).item(),
256 | 'test/epoch/total_loss': total_loss_epoch,
257 | 'test/epoch/rot_loss': rot_loss_epoch,
258 | 'test/epoch/trans_loss': trans_loss_epoch,
259 | 'test/epoch/epoch': epoch
260 | }
261 | wandb.log(
262 | log_data
263 | )
264 |
265 |
266 |
267 | dist.barrier()
268 | cleanup()
269 |
270 |
271 | def main(cfg):
272 | world_size = len(cfg.gpus)
273 | mp.spawn(train, args=(world_size, cfg), nprocs=world_size, join=True)
274 |
275 |
276 | if __name__ == '__main__':
277 |
278 | wandb.login()
279 |
280 | parser = argparse.ArgumentParser(description="Training script")
281 | parser.add_argument('--cfg_file', default='', type=str)
282 | parser.add_argument('--gpus', nargs='+', default=[-1], type=int)
283 |
284 | args = parser.parse_args()
285 | args.cfg_file = os.path.join('./config', args.cfg_file)
286 |
287 | cfg = get_cfg_defaults()
288 | cfg.merge_from_file(args.cfg_file)
289 |
290 | cfg.gpus = cfg.exp.gpus
291 |
292 | cfg.freeze()
293 | print(cfg)
294 | main(cfg)
--------------------------------------------------------------------------------
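
our_train_A.py and our_train_B.py share the same launch skeleton: mp.spawn forks one worker per GPU, each worker joins an NCCL process group, pins its device, wraps the network in DDP, and destroys the group on exit. A stripped-down sketch of that pattern (the Linear model is a placeholder, not the repo's assembly network):

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torch.nn.parallel import DistributedDataParallel as DDP

    def worker(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12719'   # same rendezvous port as setup()
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

        model = DDP(torch.nn.Linear(3, 3).cuda(rank), device_ids=[rank])
        # ... DistributedSampler + DataLoader + the epoch loop go here ...

        dist.barrier()
        dist.destroy_process_group()

    if __name__ == '__main__':
        world_size = torch.cuda.device_count()
        mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
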
/src/shape_assembly/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/.DS_Store
--------------------------------------------------------------------------------
/src/shape_assembly/__init__.py:
--------------------------------------------------------------------------------
1 | from utils import *
--------------------------------------------------------------------------------
/src/shape_assembly/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/config.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # Miscellaneous configs
4 | _C = CN()
5 |
6 | # Experiment related
7 | _C.exp = CN()
8 | _C.exp.name = ''
9 | _C.exp.checkpoint_dir = ''
10 | _C.exp.weight_file = ''
11 | _C.exp.gpus = [0]
12 | _C.exp.num_workers = 8
13 | _C.exp.batch_size = 1
14 | _C.exp.num_epochs = 1000
15 | _C.exp.log_dir = ''
16 | _C.exp.load_from = ''
17 | _C.exp.ckpt_name = ''
18 | _C.exp.vis_dir = ''
19 |
20 | # Model related
21 | _C.model = CN()
22 | _C.model.encoderA = ''
23 | _C.model.encoderB = ''
24 | _C.model.encoder = ''
25 | _C.model.encoder_geo = ''
26 | _C.model.pose_predictor_quat = ''
27 | _C.model.pose_predictor_rot = ''
28 | _C.model.pose_predictor_trans = ''
29 | _C.model.corr_module = ''
30 | _C.model.sdf_predictor = ''
31 | _C.model.aggregator = ''
32 | _C.model.pc_feat_dim = 512
33 | _C.model.transformer_feat_dim = 1024
34 | _C.model.num_heads = 4
35 | _C.model.num_blocks = 1
36 | _C.model.recon_loss = False
37 | _C.model.point_loss = False
38 | _C.model.corr_feat = 'False'
39 |
40 | # Classification Model related
41 | _C.cls_model = CN()
42 | _C.cls_model.encoder = ''
43 | _C.cls_model.point_loss = False
44 | _C.cls_model.recon_loss = False
45 | _C.cls_model.corr_module = 'No'
46 | _C.cls_model.pc_feat_dim = 512
47 | _C.cls_model.num_heads = 4
48 | _C.cls_model.num_blocks = 1
49 | _C.cls_model.PCA_obb = False
50 |
51 | # Data related
52 | _C.data = CN()
53 | _C.data.root_dir = ''
54 | _C.data.train_csv_file = ''
55 | _C.data.val_csv_file = ''
56 | _C.data.num_pc_points = 1024
57 | _C.data.translation = ''
58 |
59 | # Classification Task Data related
60 | _C.cls_data = CN()
61 | _C.cls_data.root_dir = ''
62 | _C.cls_data.train_csv_file = ''
63 | _C.cls_data.val_csv_file = ''
64 | _C.cls_data.num_pc_points = 1024
65 | _C.cls_data.translation = ''
66 |
67 | # Optimizer related
68 | _C.optimizer = CN()
69 | _C.optimizer.lr = 1e-3
70 | _C.optimizer.lr_decay = 0.7
71 | _C.optimizer.decay_step = 2e4
72 | _C.optimizer.weight_decay = 1e-6
73 | _C.optimizer.lr_clip = 1e-5
74 |
75 | def get_cfg_defaults():
76 | return _C.clone()
--------------------------------------------------------------------------------
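
The training scripts consume these defaults through get_cfg_defaults() followed by merge_from_file(), as in their __main__ blocks. A minimal sketch; the YAML filename is an assumption that matches the ./config prefix the scripts prepend:

    from config import get_cfg_defaults

    cfg = get_cfg_defaults()
    cfg.merge_from_file('./config/train_A.yml')   # hypothetical config path
    cfg.freeze()                                  # make the config read-only
    print(cfg.exp.batch_size, cfg.optimizer.lr, cfg.data.num_pc_points)
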
/src/shape_assembly/config_eval.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # Miscellaneous configs
4 | _C = CN()
5 |
6 | # Experiment related
7 | _C.exp = CN()
8 | _C.exp.name = ''
9 | _C.exp.checkpoint_dir = ''
10 | _C.exp.weight_file = ''
11 | _C.exp.gpus = [0]
12 | _C.exp.num_workers = 8
13 | _C.exp.batch_size = 1
14 | _C.exp.num_epochs = 1000
15 | _C.exp.log_dir = ''
16 | _C.exp.eval_epoch = 100
17 | _C.exp.load_from = ''
18 |
19 | # Model related
20 | _C.model = CN()
21 | _C.model.encoder = ''
22 | _C.model.encoder_geo = ''
23 | _C.model.pose_predictor_quat = ''
24 | _C.model.pose_predictor_trans = ''
25 | _C.model.sdf_predictor = ''
26 | _C.model.aggregator = ''
27 | _C.model.pc_feat_dim = 512
28 | _C.model.transformer_feat_dim = 1024
29 | _C.model.num_heads = 4
30 | _C.model.num_blocks = 1
31 |
32 | # Data related
33 | _C.data = CN()
34 | _C.data.root_dir = ''
35 | _C.data.train_csv_file = ''
36 | _C.data.val_csv_file = ''
37 | _C.data.test_csv_file = ''
38 | _C.data.num_pc_points = 1024
39 | _C.data.draw_map = False
40 | _C.data.draw_gt_map = False
41 |
42 | def get_cfg_defaults():
43 | return _C.clone()
44 |
--------------------------------------------------------------------------------
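
config_eval.py is consumed the same way. For quick experiments, yacs can also override single keys without editing the YAML, through merge_from_list; a small sketch:

    from config_eval import get_cfg_defaults

    cfg = get_cfg_defaults()
    # alternating key/value pairs, matching the types declared above
    cfg.merge_from_list(['exp.eval_epoch', 200, 'data.draw_map', True])
    cfg.freeze()
    print(cfg.exp.eval_epoch, cfg.data.draw_map)
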
/src/shape_assembly/datasets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/.DS_Store
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/.DS_Store
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/cls_dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/cls_dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_A.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_A.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_CR.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_CR.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit1.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_1.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_ori.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_ori.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_single.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_kit_single.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_our.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_our.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_our1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/dataloader_our1.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/sub1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/sub1.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/__pycache__/sub_1_7.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/datasets/dataloader/__pycache__/sub_1_7.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/datasets/dataloader/dataloader_A.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import h5py
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import trimesh
9 | from PIL import Image
10 | import json
11 | from progressbar import ProgressBar
12 | import random
13 | import copy
14 | import time
15 | import ipdb
16 |
17 | import pandas as pd
18 | from scipy.spatial.transform import Rotation as R
19 | from pdb import set_trace
20 |
21 | os.environ['PYOPENGL_PLATFORM'] = 'egl'
22 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
23 | sys.path.append(os.path.join(BASE_DIR, '../shape_assembly'))
24 |
25 | from pytorch3d.transforms import quaternion_to_matrix
26 | from mesh_to_sdf import sample_sdf_near_surface
27 |
28 | def load_data(file_dir, cat_shape_dict):
29 | with open(file_dir, 'r') as fin:
30 | for l in fin.readlines():
31 | shape_id, cat = l.rstrip().split()
32 | cat_shape_dict[cat].append(shape_id)
33 | return cat_shape_dict
34 |
35 | class OurDataset(data.Dataset):
36 |
37 | def __init__(self, data_root_dir, data_csv_file, data_features=[], num_points=1024, num_query_points=1024, data_per_seg=1):
38 | self.data_root_dir = data_root_dir
39 | self.data_csv_file = data_csv_file
40 | self.num_points = num_points
41 | self.num_query_points = num_query_points
42 | self.data_features = data_features
43 | self.data_per_seg = data_per_seg
44 | self.dataset = []
45 |
46 | with open(self.data_csv_file, 'r') as fin:
47 | self.category_list = [line.strip() for line in fin.readlines()]
48 |
49 |
50 | def transform_pc_to_rot(self, pcs):
51 | # zero-centered
52 | pc_center = (pcs.max(axis=0, keepdims=True) + pcs.min(axis=0, keepdims=True)) / 2
53 | pc_center = pc_center[0]
54 | new_pcs = pcs - pc_center
55 |
56 | # (batch_size, 2, 3)
57 | def bgs(d6s):
58 | bsz = d6s.shape[0]
59 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1)
60 | a2 = d6s[:, :, 1]
61 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1)
62 | b3 = torch.cross(b1, b2, dim=1)
63 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1)
64 |
65 | # randomly sample a rotation matrix from a random 6D representation
66 | rotmat = bgs(torch.rand(1, 6).reshape(-1, 2, 3).permute(0, 2, 1))
67 | new_pcs = (rotmat.reshape(3, 3) @ new_pcs.T).T
68 |
69 | gt_rot = rotmat[:, :, :2].permute(0, 2, 1).reshape(6).numpy()
70 |
71 | return new_pcs, pc_center, gt_rot
72 |
73 |
74 | def load_data(self):
75 | bar = ProgressBar()
76 |
77 | for category_i in bar(range(len(self.category_list))):
78 | category_id = self.category_list[category_i]
79 | instance_dir = os.path.join(self.data_root_dir, category_id)
80 | fileA = os.path.join(instance_dir, 'partA-pc.csv')
81 | fileB = os.path.join(instance_dir, 'partB-pc.csv')
82 |
83 | if not os.path.exists(fileA) or not os.path.exists(fileB):
84 | print("fileA is", fileA)
85 | print("fileB is", fileB)
86 | print("file not exists")
87 | continue
88 |
89 | dataframe_A = pd.read_csv(fileA, header=None)
90 | dataframe_B = pd.read_csv(fileB, header=None)
91 | gt_pcs_A = dataframe_A.to_numpy()
92 | gt_pcs_B = dataframe_B.to_numpy()
93 |
94 | for i in range(self.data_per_seg):
95 | gt_pcs_total = np.concatenate((gt_pcs_A, gt_pcs_B), axis=0)
96 | per = np.random.permutation(gt_pcs_total.shape[0])
97 | gt_pcs_total = gt_pcs_total[per]
98 |
99 | self.dataset.append([fileA,
100 | fileB,
101 | ])
102 |
103 | def __str__(self):
104 | return 'OurDataset with {} samples'.format(len(self.dataset))
105 |
106 | def __len__(self):
107 | return len(self.dataset)
108 |
109 | def __getitem__(self, index):
110 |
111 | flag = 0
112 | while flag == 0:
113 | point_fileA, point_fileB = self.dataset[index]
114 |
115 | dataframe_A = pd.read_csv(point_fileA, header=None)
116 | dataframe_B = pd.read_csv(point_fileB, header=None)
117 | gt_pcs_A = dataframe_A.to_numpy()
118 | gt_pcs_B = dataframe_B.to_numpy()
119 |
120 | if gt_pcs_A[0][0] != 0 and gt_pcs_A[0][1] != 0:
121 | flag = 1
122 | else:
123 | index += 1
124 |
125 | if gt_pcs_A[0][0] == 0 and gt_pcs_A[0][1] == 0:
126 | raise ValueError('getitem Zero encountered!')
127 |
128 | new_pcs_A, trans_A, rot_A = self.transform_pc_to_rot(gt_pcs_A)
129 |
130 | new_pcs_B, trans_B, rot_B = self.transform_pc_to_rot(gt_pcs_B)
131 |
132 | partA_symmetry_type = np.array([0,0,0,0,1,0])
133 | partB_symmetry_type = np.array([0,0,0,0,1,0])
134 |
135 | data_feats = dict()
136 |
137 | for feat in self.data_features:
138 | if feat == 'src_pc':
139 | data_feats['src_pc'] = new_pcs_A.T.float()
140 |
141 | elif feat == 'tgt_pc':
142 | data_feats['tgt_pc'] = torch.tensor(gt_pcs_B).T.float()
143 |
144 | elif feat == 'src_rot':
145 | data_feats['src_rot'] = rot_A.astype(np.float32)
146 |
147 | elif feat == 'tgt_rot':
148 | data_feats['tgt_rot'] = rot_B.astype(np.float32)
149 |
150 | elif feat == 'src_trans':
151 | data_feats['src_trans'] = trans_A.reshape(1, 3).T.astype(np.float32)
152 |
153 | elif feat == 'tgt_trans':
154 | data_feats['tgt_trans'] = trans_B.reshape(1, 3).T.astype(np.float32)
155 |
156 | elif feat == 'partA_symmetry_type':
157 | data_feats['partA_symmetry_type'] = partA_symmetry_type
158 |
159 | elif feat == 'partB_symmetry_type':
160 | data_feats['partB_symmetry_type'] = partB_symmetry_type
161 |
162 | elif feat == 'predicted_partB_position':
163 | data_feats['predicted_partB_position'] = np.array([0, 0, 0, 0]).astype(np.float32)
164 |
165 | elif feat == 'predicted_partB_rotation':
166 | data_feats['predicted_partB_rotation'] = np.array([0, 0, 0, 0, 0, 0]).astype(np.float32)
167 |
168 | elif feat == 'predicted_partA_position':
169 | data_feats['predicted_partA_position'] = np.array([0, 0, 0, 0]).astype(np.float32)
170 |
171 | elif feat == 'predicted_partA_rotation':
172 | data_feats['predicted_partA_rotation'] = np.array([0, 0, 0, 0, 0, 0]).astype(np.float32)
173 |
174 | elif feat == 'partA_mesh':
175 | data_feats['partA_mesh'] = self.load_mesh(data_feats, point_fileA)
176 |
177 | elif feat == 'partB_mesh':
178 | data_feats['partB_mesh'] = self.load_mesh(data_feats, point_fileB)
179 |
180 | return data_feats
--------------------------------------------------------------------------------
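
A sketch of driving OurDataset outside the DDP scripts, assuming src/shape_assembly is on sys.path (the scripts arrange this) and a data layout of one directory per category id containing partA-pc.csv and partB-pc.csv, as load_data() expects; the root and CSV paths are placeholders:

    from torch.utils.data import DataLoader
    from datasets.dataloader.dataloader_A import OurDataset

    dataset = OurDataset(
        data_root_dir='/path/to/data',           # placeholder
        data_csv_file='/path/to/train.csv',      # one category id per line
        data_features=['src_pc', 'tgt_pc', 'src_rot', 'tgt_rot',
                       'src_trans', 'tgt_trans'],
        num_points=1024,
    )
    dataset.load_data()      # must run before len() or iteration

    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    print(batch['src_pc'].shape)    # (8, 3, num_points_in_csv)
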
/src/shape_assembly/datasets/dataloader/dataloader_B.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import h5py
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import trimesh
9 | from PIL import Image
10 | import json
11 | from progressbar import ProgressBar
12 | # from pyquaternion import Quaternion
13 | import random
14 | import copy
15 | import time
16 | import ipdb
17 |
18 | import pandas as pd
19 | from scipy.spatial.transform import Rotation as R
20 | from pdb import set_trace
21 |
22 | os.environ['PYOPENGL_PLATFORM'] = 'egl'
23 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
24 | sys.path.append(os.path.join(BASE_DIR, '../shape_assembly'))
25 |
26 | from pytorch3d.transforms import quaternion_to_matrix
27 |
28 | def load_data(file_dir, cat_shape_dict):
29 | with open(file_dir, 'r') as fin:
30 | for l in fin.readlines():
31 | shape_id, cat = l.rstrip().split()
32 | cat_shape_dict[cat].append(shape_id)
33 | return cat_shape_dict
34 |
35 | class OurDataset(data.Dataset):
36 |
37 | def __init__(self, data_root_dir, data_csv_file, data_features=[], num_points=1024, num_query_points=1024, data_per_seg=1):
38 | self.data_root_dir = data_root_dir
39 | self.data_csv_file = data_csv_file
40 | self.num_points = num_points
41 | self.num_query_points = num_query_points
42 | self.data_features = data_features
43 | self.data_per_seg = data_per_seg
44 | self.dataset = []
45 |
46 | with open(self.data_csv_file, 'r') as fin:
47 | self.category_list = [line.strip() for line in fin.readlines()]
48 |
49 |
50 | def transform_pc_to_rot(self, pcs):
51 | # zero-centered
52 | pc_center = (pcs.max(axis=0, keepdims=True) + pcs.min(axis=0, keepdims=True)) / 2
53 | pc_center = pc_center[0]
54 | new_pcs = pcs - pc_center
55 |
56 | # (batch_size, 2, 3)
57 | def bgs(d6s):
58 | bsz = d6s.shape[0]
59 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1)
60 | a2 = d6s[:, :, 1]
61 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1)
62 | b3 = torch.cross(b1, b2, dim=1)
63 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1)
64 |
65 | # randomly sample a rotation matrix from a random 6D representation
66 | rotmat = bgs(torch.rand(1, 6).reshape(-1, 2, 3).permute(0, 2, 1))
67 | new_pcs = (rotmat.reshape(3, 3) @ new_pcs.T).T
68 |
69 | gt_rot = rotmat[:, :, :2].permute(0, 2, 1).reshape(6).numpy()
70 |
71 | return new_pcs, pc_center, gt_rot
72 |
73 |
74 | def load_data(self):
75 | bar = ProgressBar()
76 |
77 | for category_i in bar(range(len(self.category_list))):
78 | category_id = self.category_list[category_i]
79 | instance_dir = os.path.join(self.data_root_dir, category_id)
80 | fileA = os.path.join(instance_dir, 'partA-pc.csv')
81 | fileB = os.path.join(instance_dir, 'partB-pc.csv')
82 | if not os.path.exists(fileA) or not os.path.exists(fileB):
83 | print("fileA is", fileA)
84 | print("fileB is", fileB)
85 | print("file does not exist")
86 | # exit(0)
87 | continue
88 |
89 | dataframe_A = pd.read_csv(fileA, header=None)
90 | dataframe_B = pd.read_csv(fileB, header=None)
91 | gt_pcs_A = dataframe_A.to_numpy()
92 | gt_pcs_B = dataframe_B.to_numpy()
93 |
94 | for i in range(self.data_per_seg):
95 | gt_pcs_total = np.concatenate((gt_pcs_A, gt_pcs_B), axis=0)
96 | per = np.random.permutation(gt_pcs_total.shape[0])
97 | gt_pcs_total = gt_pcs_total[per]
98 |
99 | self.dataset.append([fileA,
100 | fileB,])
101 |
102 | def __str__(self):
103 | return 'OurDataset with {} samples'.format(len(self.dataset))
104 |
105 | def __len__(self):
106 | return len(self.dataset)
107 |
108 | def __getitem__(self, index):
109 |
110 | flag = 0
111 | while flag == 0:
112 | point_fileA, point_fileB = self.dataset[index]
113 |
114 | dataframe_A = pd.read_csv(point_fileA, header=None)
115 | dataframe_B = pd.read_csv(point_fileB, header=None)
116 | gt_pcs_A = dataframe_A.to_numpy()
117 | gt_pcs_B = dataframe_B.to_numpy()
118 |
119 | if gt_pcs_A[0][0] != 0 and gt_pcs_A[0][1] != 0:
120 | flag = 1
121 | else:
122 | index += 1
123 |
124 | if gt_pcs_A[0][0] == 0 and gt_pcs_A[0][1] == 0:
125 | raise ValueError('getitem Zero encountered!')
126 |
127 | try:
128 | new_pcs_A, trans_A, rot_A = self.transform_pc_to_rot(gt_pcs_A)
129 | except Exception as e:
130 | print("file A error!", point_fileA)
131 | raise
132 |
133 | try:
134 | new_pcs_B, trans_B, rot_B = self.transform_pc_to_rot(gt_pcs_B)
135 | except Exception as e:
136 | print("file B error!", point_fileB)
137 | raise
138 | partA_symmetry_type = np.array([0,0,0,0,1,0])
139 | partB_symmetry_type = np.array([0,0,0,0,1,0])
140 |
141 | data_feats = dict()
142 |
143 | for feat in self.data_features:
144 | if feat == 'src_pc':
145 | data_feats['src_pc'] = new_pcs_A.T.float()
146 |
147 | elif feat == 'tgt_pc':
148 | data_feats['tgt_pc'] = new_pcs_B.T.float()
149 |
150 | elif feat == 'src_rot':
151 | data_feats['src_rot'] = rot_A.astype(np.float32)
152 |
153 | elif feat == 'tgt_rot':
154 | data_feats['tgt_rot'] = rot_B.astype(np.float32)
155 |
156 | elif feat == 'src_trans':
157 | data_feats['src_trans'] = trans_A.reshape(1, 3).T.astype(np.float32)
158 |
159 | elif feat == 'tgt_trans':
160 | data_feats['tgt_trans'] = trans_B.reshape(1, 3).T.astype(np.float32)
161 |
162 | elif feat == 'partA_symmetry_type':
163 | data_feats['partA_symmetry_type'] = partA_symmetry_type
164 |
165 | elif feat == 'partB_symmetry_type':
166 | data_feats['partB_symmetry_type'] = partB_symmetry_type
167 |
168 | elif feat == 'predicted_partB_position':
169 | data_feats['predicted_partB_position'] = np.array([0, 0, 0, 0]).astype(np.float32)
170 |
171 | elif feat == 'predicted_partB_rotation':
172 | data_feats['predicted_partB_rotation'] = np.array([0, 0, 0, 0, 0, 0]).astype(np.float32)
173 |
174 | elif feat == 'predicted_partA_position':
175 | data_feats['predicted_partA_position'] = np.array([0, 0, 0, 0]).astype(np.float32)
176 |
177 | elif feat == 'predicted_partA_rotation':
178 | data_feats['predicted_partA_rotation'] = np.array([0, 0, 0, 0, 0, 0]).astype(np.float32)
179 |
180 | elif feat == 'partA_mesh':
181 | data_feats['partA_mesh'] = self.load_mesh(data_feats, point_fileA)
182 |
183 | elif feat == 'partB_mesh':
184 | data_feats['partB_mesh'] = self.load_mesh(data_feats, point_fileB)
185 |
186 | return data_feats
--------------------------------------------------------------------------------
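
Both dataloaders store ground-truth rotations in the 6D representation (the first two columns of the rotation matrix, flattened), reconstructed by the inner bgs() Gram-Schmidt step. A standalone check, mirroring the dataloader code, that this construction yields proper rotations (orthonormal, determinant +1):

    import torch
    import torch.nn.functional as F

    def bgs(d6s):
        # Gram-Schmidt on two 3-vectors -> rotation matrix (as in transform_pc_to_rot)
        bsz = d6s.shape[0]
        b1 = F.normalize(d6s[:, :, 0], p=2, dim=1)
        a2 = d6s[:, :, 1]
        b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1)
        b3 = torch.cross(b1, b2, dim=1)
        return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1)

    R = bgs(torch.rand(4, 6).reshape(-1, 2, 3).permute(0, 2, 1))
    eye = torch.eye(3).expand(4, 3, 3)
    print(torch.allclose(R @ R.transpose(1, 2), eye, atol=1e-5))  # orthonormal
    print(torch.det(R))                                           # each ~ +1
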
/src/shape_assembly/models/decoder/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/decoder/.DS_Store
--------------------------------------------------------------------------------
/src/shape_assembly/models/decoder/MLPDecoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from pdb import set_trace
5 | import os
6 | import sys
7 | import copy
8 | import math
9 | import numpy as np
10 |
11 |
12 | class MLPDecoder(nn.Module):
13 | def __init__(self, feat_dim, num_points):
14 | super().__init__()
15 | self.np = num_points
16 | self.fc_layers = nn.Sequential(
17 | nn.Linear(feat_dim * 2, num_points * 2),
18 | nn.BatchNorm1d(num_points * 2),
19 | nn.LeakyReLU(0.2),
20 | nn.Linear(num_points * 2, num_points * 3),
21 | )
22 |
23 | def forward(self, x):
24 | # x.shape: (bs,1024)
25 | batch_size = x.shape[0]
26 | f = self.fc_layers(x)
27 | return f.reshape(batch_size, self.np, 3)
--------------------------------------------------------------------------------
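
A quick shape check for the decoder, assuming src/shape_assembly is on sys.path. With feat_dim=512 (the pc_feat_dim default in config.py) the decoder expects the 1024-dim concatenated feature:

    import torch
    from models.decoder.MLPDecoder import MLPDecoder

    decoder = MLPDecoder(feat_dim=512, num_points=1024)
    feats = torch.rand(8, 1024)        # (batch, feat_dim * 2)
    points = decoder(feats)
    print(points.shape)                # torch.Size([8, 1024, 3])
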
/src/shape_assembly/models/decoder/__pycache__/MLPDecoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/decoder/__pycache__/MLPDecoder.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/encoder/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/encoder/.DS_Store
--------------------------------------------------------------------------------
/src/shape_assembly/models/encoder/__pycache__/dgcnn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/encoder/__pycache__/dgcnn.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/encoder/__pycache__/pointnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/encoder/__pycache__/pointnet.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/encoder/__pycache__/vn_dgcnn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/encoder/__pycache__/vn_dgcnn.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/encoder/__pycache__/vn_layers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/encoder/__pycache__/vn_layers.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/encoder/vn_dgcnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from pdb import set_trace
5 | import os
6 | import sys
7 | import copy
8 | import math
9 | import numpy as np
10 |
11 | from models.encoder.vn_layers import *
12 |
13 | def knn(x, k):
14 | inner = -2 * torch.matmul(x.transpose(2, 1), x)
15 | xx = torch.sum(x ** 2, dim=1, keepdim=True)
16 | pairwise_distance = -xx - inner - xx.transpose(2, 1)
17 |
18 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k)
19 | return idx
20 |
21 | def get_graph_feature(x, k=20, idx=None, x_coord=None):
22 | batch_size = x.size(0)
23 | num_points = x.size(3)
24 | x = x.view(batch_size, -1, num_points)
25 | if idx is None:
26 | if x_coord is None: # dynamic knn graph
27 | idx = knn(x, k=k)
28 | else: # fixed knn graph with input point coordinates
29 | idx = knn(x_coord, k=k)
30 | device = torch.device('cuda')
31 |
32 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
33 |
34 | idx = idx + idx_base
35 |
36 | idx = idx.view(-1)
37 |
38 | _, num_dims, _ = x.size()
39 | num_dims = num_dims // 3
40 |
41 | x = x.transpose(2, 1).contiguous()
42 | feature = x.view(batch_size * num_points, -1)[idx, :]
43 | feature = feature.view(batch_size, num_points, k, num_dims, 3)
44 | x = x.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1)
45 |
46 | feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 4, 1, 2).contiguous()
47 |
48 | return feature
49 |
50 | class VN_DGCNN(nn.Module):
51 |
52 |
53 | def __init__(self, feat_dim):
54 | super(VN_DGCNN, self).__init__()
55 | self.n_knn = 20
56 | # num_part = feat_dim
57 |
58 | pooling = 'mean'
59 |
60 | self.conv1 = VNLinearLeakyReLU(2, 64 // 3)
61 | self.conv2 = VNLinearLeakyReLU(64 // 3, 64 // 3)
62 | self.conv3 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3)
63 | self.conv4 = VNLinearLeakyReLU(64 // 3, 64 // 3)
64 | self.conv5 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3)
65 |
66 | if pooling == 'max':
67 | self.pool1 = VNMaxPool(64 // 3)
68 | self.pool2 = VNMaxPool(64 // 3)
69 | self.pool3 = VNMaxPool(64 // 3)
70 | self.pool4 = VNMaxPool(2 * feat_dim)
71 | elif pooling == 'mean':
72 | self.pool1 = mean_pool
73 | self.pool2 = mean_pool
74 | self.pool3 = mean_pool
75 | self.pool4 = mean_pool
76 |
77 | self.conv6 = VNLinearLeakyReLU(64 // 3 * 3, feat_dim, dim=4, share_nonlinearity=True)
78 |
79 | def forward(self, x):
80 |
81 | batch_size = x.size(0)
82 | num_points = x.size(2)
83 | l = x[:, 0, 0:16].reshape(batch_size, 1, 16)
84 |
85 | x = x.unsqueeze(1) # (32, 1, 3, 1024)
86 |
87 | x = get_graph_feature(x, k=self.n_knn) # (32, 2, 3, 1024, 20)
88 |
89 | x = self.conv1(x) # (32, 21, 3, 1024, 20)
90 | x = self.conv2(x) # (32, 21, 3, 1024, 20)
91 | x1 = self.pool1(x) # (32, 21, 3, 1024)
92 |
93 | x = get_graph_feature(x1, k=self.n_knn)
94 | x = self.conv3(x)
95 | x = self.conv4(x)
96 | x2 = self.pool2(x)
97 |
98 | x = get_graph_feature(x2, k=self.n_knn)
99 | x = self.conv5(x)
100 | x3 = self.pool3(x)
101 |
102 | x123 = torch.cat((x1, x2, x3), dim=1)
103 |
104 | x = self.conv6(x123)
105 | x = self.pool4(x) # [batch, feature_dim, 3]
106 | return x
107 |
108 | class VN_DGCNN_corr(nn.Module):
109 |
110 | def __init__(self, feat_dim):
111 | super(VN_DGCNN_corr, self).__init__()
112 | self.n_knn = 20
113 | # num_part = feat_dim
114 |
115 | pooling = 'mean'
116 |
117 | self.conv1 = VNLinearLeakyReLU(2, 64 // 3)
118 | self.conv2 = VNLinearLeakyReLU(64 // 3, 64 // 3)
119 | self.conv3 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3)
120 | self.conv4 = VNLinearLeakyReLU(64 // 3, 64 // 3)
121 | self.conv5 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3)
122 | self.VnInv = VNStdFeature(2 * feat_dim, dim=3, normalize_frame=False)
123 |
124 | if pooling == 'max':
125 | self.pool1 = VNMaxPool(64 // 3)
126 | self.pool2 = VNMaxPool(64 // 3)
127 | self.pool3 = VNMaxPool(64 // 3)
128 | self.pool4 = VNMaxPool(2 * feat_dim)
129 | elif pooling == 'mean':
130 | self.pool1 = mean_pool
131 | self.pool2 = mean_pool
132 | self.pool3 = mean_pool
133 | self.pool4 = mean_pool
134 |
135 | self.conv6 = VNLinearLeakyReLU(64 // 3 * 3, feat_dim, dim=4, share_nonlinearity=True)
136 | self.linear0 = nn.Linear(3, 2 * feat_dim)
137 |
138 | def forward(self, x):
139 |
140 | batch_size = x.size(0)
141 | num_points = x.size(2)
142 | l = x[:, 0, 0:16].reshape(batch_size, 1, 16)
143 |
144 | x = x.unsqueeze(1) # (32, 1, 3, 1024)
145 |
146 | x = get_graph_feature(x, k=self.n_knn) # (32, 2, 3, 1024, 20)
147 |
148 | x = self.conv1(x) # (32, 21, 3, 1024, 20)
149 | x = self.conv2(x) # (32, 21, 3, 1024, 20)
150 | x1 = self.pool1(x) # (32, 21, 3, 1024)
151 |
152 | x = get_graph_feature(x1, k=self.n_knn)
153 | x = self.conv3(x)
154 | x = self.conv4(x)
155 | x2 = self.pool2(x)
156 |
157 | x = get_graph_feature(x2, k=self.n_knn)
158 | x = self.conv5(x)
159 | x3 = self.pool3(x)
160 |
161 | x123 = torch.cat((x1, x2, x3), dim=1)
162 |
163 | x = self.conv6(x123)
164 | x_mean = x.mean(dim=-1, keepdim=True).expand(x.size())
165 | x = torch.cat((x, x_mean), 1)
166 | x = self.pool4(x) # [batch, feature_dim, 3]
167 | x1, z0 = self.VnInv(x)
168 | x1 = self.linear0(x1)
169 | return x, x1 # [batch, 1024, 3], [batch, 1024, 1024]
170 |
171 | class VN_DGCNN_New(nn.Module):
172 |
173 | def __init__(self, feat_dim):
174 | super(VN_DGCNN_New, self).__init__()
175 | self.n_knn = 20
176 | num_part = feat_dim
177 |
178 | print("feat_dim is: ", feat_dim)
179 |
180 | pooling = 'mean'
181 |
182 | self.conv1 = VNLinearLeakyReLU(2, 64 // 3)
183 | self.conv2 = VNLinearLeakyReLU(64 // 3, 64 // 3)
184 | self.conv3 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3)
185 | self.conv4 = VNLinearLeakyReLU(64 // 3, 64 // 3)
186 | self.conv5 = VNLinearLeakyReLU(64 // 3 * 2, 64 // 3)
187 | self.VnInv = VNStdFeature(2 * feat_dim, dim=3, normalize_frame=False)
188 |
189 | if pooling == 'max':
190 | self.pool1 = VNMaxPool(64 // 3)
191 | self.pool2 = VNMaxPool(64 // 3)
192 | self.pool3 = VNMaxPool(64 // 3)
193 | self.pool4 = VNMaxPool(2 * feat_dim)
194 | elif pooling == 'mean':
195 | self.pool1 = mean_pool
196 | self.pool2 = mean_pool
197 | self.pool3 = mean_pool
198 | self.pool4 = mean_pool
199 |
200 | self.conv6 = VNLinearLeakyReLU(64 // 3 * 3, feat_dim, dim=4, share_nonlinearity=True)
201 | self.linear0 = nn.Linear(3, 2 * feat_dim)
202 |
203 | def forward(self, x):
204 |
205 | batch_size = x.size(0)
206 | num_points = x.size(2)
207 | l = x[:, 0, 0:16].reshape(batch_size, 1, 16)
208 |
209 | x = x.unsqueeze(1) # (32, 1, 3, 1024)
210 |
211 | x = get_graph_feature(x, k=self.n_knn) # (32, 2, 3, 1024, 20)
212 |
213 | x = self.conv1(x) # (32, 21, 3, 1024, 20)
214 | x = self.conv2(x) # (32, 21, 3, 1024, 20)
215 | x1 = self.pool1(x) # (32, 21, 3, 1024)
216 |
217 | x = get_graph_feature(x1, k=self.n_knn)
218 | x = self.conv3(x)
219 | x = self.conv4(x)
220 | x2 = self.pool2(x)
221 |
222 | x = get_graph_feature(x2, k=self.n_knn)
223 | x = self.conv5(x)
224 | x3 = self.pool3(x)
225 |
226 | x123 = torch.cat((x1, x2, x3), dim=1)
227 |
228 | x = self.conv6(x123)
229 | x_mean = x.mean(dim=-1, keepdim=True).expand(x.size())
230 | x = torch.cat((x, x_mean), 1)
231 | x = self.pool4(x) # [batch, feature_dim, 3]
232 | x1, z0 = self.VnInv(x)
233 | x1 = self.linear0(x1)
234 | return x, x1 # [batch, 1024, 3], [batch, 1024, 1024]
235 |
236 | class DGCNN_New(nn.Module):
237 |
238 | def __init__(self, feat_dim):
239 | super(DGCNN_New, self).__init__()
240 | self.n_knn = 20
241 | num_part = feat_dim
242 |
243 | print("feat_dim is: ", feat_dim)
244 |
245 | pooling = 'mean'
246 |
247 | self.conv1 = NonEquivariantLinearLeakyReLU(2, 64 // 3)
248 | self.conv2 = NonEquivariantLinearLeakyReLU(64 // 3, 64 // 3)
249 | self.conv3 = NonEquivariantLinearLeakyReLU(64 // 3 * 2, 64 // 3)
250 | self.conv4 = NonEquivariantLinearLeakyReLU(64 // 3, 64 // 3)
251 | self.conv5 = NonEquivariantLinearLeakyReLU(64 // 3 * 2, 64 // 3)
252 | self.VnInv = NonEquivariantStdFeature(2 * feat_dim, dim=3, normalize_frame=False)
253 |
254 | if pooling == 'max':
255 | self.pool1 = NonEquivariantMaxPool(64 // 3)
256 | self.pool2 = NonEquivariantMaxPool(64 // 3)
257 | self.pool3 = NonEquivariantMaxPool(64 // 3)
258 | self.pool4 = NonEquivariantMaxPool(2 * feat_dim)
259 | elif pooling == 'mean':
260 | self.pool1 = mean_pool
261 | self.pool2 = mean_pool
262 | self.pool3 = mean_pool
263 | self.pool4 = mean_pool
264 |
265 | self.conv6 = NonEquivariantLinearLeakyReLU(64 // 3 * 3, feat_dim, dim=4)
266 | self.linear0 = nn.Linear(3, 2 * feat_dim)
267 |
268 | def forward(self, x):
269 |
270 | # x: (batch_size, 3, num_points)
271 | # l: (batch_size, 1, 16)
272 |
273 | batch_size = x.size(0)
274 | num_points = x.size(2)
275 | l = x[:, 0, 0:16].reshape(batch_size, 1, 16)
276 |
277 |
278 |
279 | x = x.unsqueeze(1) # (32, 1, 3, 1024)
280 |
281 |
282 | x = get_graph_feature(x, k=self.n_knn) # (32, 2, 3, 1024, 20)
283 |
284 |
285 | x = self.conv1(x) # (32, 21, 3, 1024, 20)
286 |
287 | x = self.conv2(x) # (32, 21, 3, 1024, 20)
288 | x1 = self.pool1(x) # (32, 21, 3, 1024)
289 |
290 | x = get_graph_feature(x1, k=self.n_knn)
291 | x = self.conv3(x)
292 | x = self.conv4(x)
293 | x2 = self.pool2(x)
294 |
295 | x = get_graph_feature(x2, k=self.n_knn)
296 | x = self.conv5(x)
297 | x3 = self.pool3(x)
298 |
299 | x123 = torch.cat((x1, x2, x3), dim=1)
300 |
301 | x = self.conv6(x123)
302 | x_mean = x.mean(dim=-1, keepdim=True).expand(x.size())
303 | x = torch.cat((x, x_mean), 1)
304 | x = self.pool4(x) # [batch, feature_dim, 3]
305 | x1, z0 = self.VnInv(x)
306 | x1 = self.linear0(x1)
307 | return x, x1 # [batch, 1024, 3], [batch, 1024, 1024]
308 |
--------------------------------------------------------------------------------
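
The point of the vector-neuron encoder is SO(3) equivariance: rotating the input cloud rotates every per-channel feature vector by the same matrix. A rough numerical check (a sketch, not repo code; it needs a CUDA device because get_graph_feature hard-codes torch.device('cuda'), and assumes src/shape_assembly is on sys.path):

    import torch
    from models.encoder.vn_dgcnn import VN_DGCNN

    torch.manual_seed(0)
    enc = VN_DGCNN(feat_dim=16).cuda().eval()

    pts = torch.rand(2, 3, 256, device='cuda')        # (B, 3, N)
    R, _ = torch.linalg.qr(torch.randn(3, 3, device='cuda'))
    if torch.det(R) < 0:
        R[:, 0] = -R[:, 0]                            # force det(R) = +1

    with torch.no_grad():
        f = enc(pts)                                  # (B, feat_dim, 3)
        f_rot = enc(torch.einsum('ij,bjn->bin', R, pts))

    # each feature 3-vector should rotate with the input: f_rot ~ f R^T
    print(torch.allclose(f_rot, torch.einsum('bcj,ij->bci', f, R), atol=1e-4))
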
/src/shape_assembly/models/encoder/vn_dgcnn_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import math
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def knn(x, k):
12 | inner = -2 * torch.matmul(x.transpose(2, 1), x)
13 | xx = torch.sum(x ** 2, dim=1, keepdim=True)
14 | pairwise_distance = -xx - inner - xx.transpose(2, 1)
15 |
16 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k)
17 | return idx
18 |
19 |
20 | def get_graph_feature(x, k=20, idx=None, x_coord=None):
21 | batch_size = x.size(0)
22 | num_points = x.size(3)
23 | x = x.view(batch_size, -1, num_points)
24 | if idx is None:
25 | if x_coord is None: # dynamic knn graph
26 | idx = knn(x, k=k)
27 | else: # fixed knn graph with input point coordinates
28 | idx = knn(x_coord, k=k)
29 | device = torch.device('cuda')
30 |
31 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
32 |
33 | idx = idx + idx_base
34 |
35 | idx = idx.view(-1)
36 |
37 | _, num_dims, _ = x.size()
38 | num_dims = num_dims // 3
39 |
40 | x = x.transpose(2, 1).contiguous()
41 | feature = x.view(batch_size * num_points, -1)[idx, :]
42 | feature = feature.view(batch_size, num_points, k, num_dims, 3)
43 | x = x.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1)
44 |
45 | feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 4, 1, 2).contiguous()
46 |
47 | return feature
48 |
49 |
50 | def get_graph_feature_cross(x, k=20, idx=None):
51 | batch_size = x.size(0)
52 | num_points = x.size(3)
53 | x = x.view(batch_size, -1, num_points)
54 | if idx is None:
55 | idx = knn(x, k=k)
56 | device = torch.device('cuda')
57 |
58 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
59 |
60 | idx = idx + idx_base
61 |
62 | idx = idx.view(-1)
63 |
64 | _, num_dims, _ = x.size()
65 | num_dims = num_dims // 3
66 |
67 | x = x.transpose(2, 1).contiguous()
68 | feature = x.view(batch_size * num_points, -1)[idx, :]
69 | feature = feature.view(batch_size, num_points, k, num_dims, 3)
70 | x = x.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1)
71 | cross = torch.cross(feature, x, dim=-1)
72 |
73 | feature = torch.cat((feature - x, x, cross), dim=3).permute(0, 3, 4, 1, 2).contiguous()
74 |
75 | return feature
--------------------------------------------------------------------------------
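
A quick shape check for the knn helper shared by both DGCNN variants; for random clouds the first neighbor returned is (numerically) the query point itself, since the self-distance is the maximum of the negated squared distances:

    import torch
    from models.encoder.vn_dgcnn_util import knn

    x = torch.rand(2, 3, 128)      # (batch_size, dims, num_points)
    idx = knn(x, k=20)
    print(idx.shape)               # torch.Size([2, 128, 20])
    print((idx[..., 0] == torch.arange(128)).float().mean())  # ~1.0
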
/src/shape_assembly/models/encoder/vn_layers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import math
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | EPS = 1e-6
11 |
12 | def conv1x1(in_channels, out_channels, dim):
13 | if dim == 3:
14 | return nn.Conv1d(in_channels, out_channels, 1, bias=False)
15 | elif dim == 4:
16 | return nn.Conv2d(in_channels, out_channels, 1, bias=False)
17 | elif dim == 5:
18 | return nn.Conv3d(in_channels, out_channels, 1, bias=False)
19 | else:
20 | raise NotImplementedError(f'{dim}D 1x1 Conv is not supported')
21 |
22 |
23 | class VNLinear(nn.Module):
24 | def __init__(self, in_channels, out_channels):
25 | super(VNLinear, self).__init__()
26 | self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)
27 |
28 | def forward(self, x):
29 | '''
30 | x: point features of shape [B, N_feat, 3, N_samples, ...]
31 | '''
32 | x_out = self.map_to_feat(x.transpose(1, -1)).transpose(1, -1)
33 | return x_out
34 |
35 |
36 | class VNLeakyReLU(nn.Module):
37 | def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.2):
38 | super(VNLeakyReLU, self).__init__()
39 | if share_nonlinearity == True:
40 | self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
41 | else:
42 | self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
43 | self.negative_slope = negative_slope
44 |
45 | def forward(self, x):
46 | '''
47 | x: point features of shape [B, N_feat, 3, N_samples, ...]
48 | '''
49 | d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
50 | dotprod = (x * d).sum(2, keepdim=True)
51 | mask = (dotprod >= 0).float()
52 | d_norm_sq = (d * d).sum(2, keepdim=True)
53 | x_out = self.negative_slope * x + (1 - self.negative_slope) * (
54 | mask * x + (1 - mask) * (x - (dotprod / (d_norm_sq + EPS)) * d))
55 | return x_out
56 |
57 |
58 | class VNNewLeakyReLU(nn.Module):
59 | def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.2):
60 | super(VNNewLeakyReLU, self).__init__()
61 | if share_nonlinearity == True:
62 | self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
63 | else:
64 | self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
65 | self.negative_slope = negative_slope
66 |
67 | def forward(self, x):
68 | '''
69 | x: point features of shape [B, N_feat, 3, N_samples, ...]
70 | '''
71 | d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
72 | dotprod = (x * d)
73 | mask = (dotprod >= 0).float()
74 | d_norm_sq = (d * d)
75 | x_out = self.negative_slope * x + (1 - self.negative_slope) * (
76 | mask * x + (1 - mask) * (x - (d / (d_norm_sq + EPS)) * d))
77 | return x_out
78 |
79 |
80 | class VNLinearLeakyReLU(nn.Module):
81 | def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, negative_slope=0.2):
82 | super(VNLinearLeakyReLU, self).__init__()
83 | self.dim = dim
84 | self.negative_slope = negative_slope
85 |
86 | self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)
87 | self.batchnorm = VNBatchNorm(out_channels, dim=dim)
88 |
89 | if share_nonlinearity == True:
90 | self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
91 | else:
92 | self.map_to_dir = nn.Linear(in_channels, out_channels, bias=False)
93 |
94 | def forward(self, x):
95 | '''
96 | x: point features of shape [B, N_feat, 3, N_samples, ...]
97 | '''
98 | # Linear
99 | p = self.map_to_feat(x.transpose(1, -1)).transpose(1, -1)
100 | # BatchNorm
101 | p = self.batchnorm(p)
102 | # LeakyReLU
103 | d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
104 | dotprod = (p * d).sum(2, keepdims=True)
105 | mask = (dotprod >= 0).float()
106 | d_norm_sq = (d * d).sum(2, keepdims=True)
107 | x_out = self.negative_slope * p + (1 - self.negative_slope) * (
108 | mask * p + (1 - mask) * (p - (dotprod / (d_norm_sq + EPS)) * d))
109 | return x_out
110 |
111 |
112 | class VNLinearAndLeakyReLU(nn.Module):
113 | def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, use_batchnorm='norm',
114 | negative_slope=0.2):
115 | super(VNLinearAndLeakyReLU, self).__init__()
116 | self.dim = dim
117 | self.share_nonlinearity = share_nonlinearity
118 | self.use_batchnorm = use_batchnorm
119 | self.negative_slope = negative_slope
120 |
121 | self.linear = VNLinear(in_channels, out_channels)
122 | self.leaky_relu = VNLeakyReLU(out_channels, share_nonlinearity=share_nonlinearity,
123 | negative_slope=negative_slope)
124 |
125 | # BatchNorm
126 | self.use_batchnorm = use_batchnorm
127 | if use_batchnorm != 'none':
128 | self.batchnorm = VNBatchNorm(out_channels, dim=dim, mode=use_batchnorm)
129 |
130 | def forward(self, x):
131 | '''
132 | x: point features of shape [B, N_feat, 3, N_samples, ...]
133 | '''
134 | # Conv
135 | x = self.linear(x)
136 | # InstanceNorm
137 | if self.use_batchnorm != 'none':
138 | x = self.batchnorm(x)
139 | # LeakyReLU
140 | x_out = self.leaky_relu(x)
141 | return x_out
142 |
143 |
144 | class VNBatchNorm(nn.Module):
145 | def __init__(self, num_features, dim):
146 | super(VNBatchNorm, self).__init__()
147 | self.dim = dim
148 | if dim == 3 or dim == 4:
149 | self.bn = nn.BatchNorm1d(num_features)
150 | elif dim == 5:
151 | self.bn = nn.BatchNorm2d(num_features)
152 |
153 | def forward(self, x):
154 | '''
155 | x: point features of shape [B, N_feat, 3, N_samples, ...]
156 | '''
157 | # norm = torch.sqrt((x*x).sum(2))
158 | norm = torch.norm(x, dim=2) + EPS
159 | norm_bn = self.bn(norm)
160 | norm = norm.unsqueeze(2)
161 | norm_bn = norm_bn.unsqueeze(2)
162 | x = x / norm * norm_bn
163 |
164 | return x
165 |
166 |
167 | class VNMaxPool(nn.Module):
168 | def __init__(self, in_channels):
169 | super(VNMaxPool, self).__init__()
170 | self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
171 |
172 | def forward(self, x):
173 | '''
174 | x: point features of shape [B, N_feat, 3, N_samples, ...]
175 | '''
176 | d = self.map_to_dir(x.transpose(1, -1)).transpose(1, -1)
177 | dotprod = (x * d).sum(2, keepdims=True)
178 | idx = dotprod.max(dim=-1, keepdim=False)[1]
179 | index_tuple = torch.meshgrid([torch.arange(j) for j in x.size()[:-1]]) + (idx,)
180 | x_max = x[index_tuple]
181 | return x_max
182 |
183 |
184 | def mean_pool(x, dim=-1, keepdim=False):
185 | return x.mean(dim=dim, keepdim=keepdim)
186 |
187 |
188 | class VNStdFeature(nn.Module):
189 | def __init__(self, in_channels, dim=4, normalize_frame=False, share_nonlinearity=False, negative_slope=0.2):
190 | super(VNStdFeature, self).__init__()
191 | self.dim = dim
192 | self.normalize_frame = normalize_frame
193 |
194 | self.vn1 = VNLinearLeakyReLU(in_channels, in_channels // 2, dim=dim, share_nonlinearity=share_nonlinearity,
195 | negative_slope=negative_slope)
196 | self.vn2 = VNLinearLeakyReLU(in_channels // 2, in_channels // 4, dim=dim, share_nonlinearity=share_nonlinearity,
197 | negative_slope=negative_slope)
198 | if normalize_frame:
199 | self.vn_lin = nn.Linear(in_channels // 4, 2, bias=False)
200 | else:
201 | self.vn_lin = nn.Linear(in_channels // 4, 3, bias=False)
202 |
203 | def forward(self, x):
204 | '''
205 | x: point features of shape [B, N_feat, 3, N_samples, ...]
206 | '''
207 | z0 = x
208 | z0 = self.vn1(z0)
209 | z0 = self.vn2(z0)
210 | z0 = self.vn_lin(z0.transpose(1, -1)).transpose(1, -1)
211 |
212 | if self.normalize_frame:
213 | # make z0 orthogonal. u2 = v2 - proj_u1(v2)
214 | v1 = z0[:, 0, :]
215 | # u1 = F.normalize(v1, dim=1)
216 | v1_norm = torch.sqrt((v1 * v1).sum(1, keepdims=True))
217 | u1 = v1 / (v1_norm + EPS)
218 | v2 = z0[:, 1, :]
219 | v2 = v2 - (v2 * u1).sum(1, keepdims=True) * u1
220 | # u2 = F.normalize(u2, dim=1)
221 | v2_norm = torch.sqrt((v2 * v2).sum(1, keepdims=True))
222 | u2 = v2 / (v2_norm + EPS)
223 |
224 | # compute the cross product of the two output vectors
225 | u3 = torch.cross(u1, u2)
226 | z0 = torch.stack([u1, u2, u3], dim=1).transpose(1, 2)
227 | else:
228 | z0 = z0.transpose(1, 2)
229 |
230 | if self.dim == 4:
231 | x_std = torch.einsum('bijm,bjkm->bikm', x, z0)
232 | elif self.dim == 3:
233 | x_std = torch.einsum('bij,bjk->bik', x, z0)
234 | elif self.dim == 5:
235 | x_std = torch.einsum('bijmn,bjkmn->bikmn', x, z0)
236 |
237 | return x_std, z0
238 |
239 |
240 | class VNInFeature(nn.Module):
241 | """VN-Invariant layer."""
242 |
243 | def __init__(
244 | self,
245 | in_channels,
246 | dim=4,
247 | share_nonlinearity=False,
248 | negative_slope=0.2,
249 | use_rmat=False,
250 | ):
251 | super().__init__()
252 |
253 | self.dim = dim
254 | self.use_rmat = use_rmat
255 | self.vn1 = VNLinearBNLeakyReLU(
256 | in_channels,
257 | in_channels // 2,
258 | dim=dim,
259 | share_nonlinearity=share_nonlinearity,
260 | negative_slope=negative_slope,
261 | )
262 | self.vn2 = VNLinearBNLeakyReLU(
263 | in_channels // 2,
264 | in_channels // 4,
265 | dim=dim,
266 | share_nonlinearity=share_nonlinearity,
267 | negative_slope=negative_slope,
268 | )
269 | self.vn_lin = conv1x1(
270 | in_channels // 4, 2 if self.use_rmat else 3, dim=dim)
271 |
272 | def forward(self, x):
273 | """
274 | Args:
275 | x: point features of shape [B, C, 3, N, ...]
276 | Returns:
277 | rotation invariant features of the same shape
278 | """
279 | z = self.vn1(x)
280 | z = self.vn2(z)
281 | z = self.vn_lin(z) # [B, 3, 3, N] or [B, 2, 3, N]
282 | if self.use_rmat:
283 | z = z.flatten(1, 2).transpose(1, 2).contiguous() # [B, N, 6]
284 | z = rot6d_to_matrix(z) # [B, N, 3, 3]
285 | z = z.permute(0, 2, 3, 1) # [B, 3, 3, N]
286 | z = z.transpose(1, 2).contiguous()
287 |
288 | if self.dim == 4:
289 | x_in = torch.einsum('bijm,bjkm->bikm', x, z)
290 | elif self.dim == 3:
291 | x_in = torch.einsum('bij,bjk->bik', x, z)
292 | elif self.dim == 5:
293 | x_in = torch.einsum('bijmn,bjkmn->bikmn', x, z)
294 | else:
295 | raise NotImplementedError(f'dim={self.dim} is not supported')
296 |
297 | return x_in
298 |
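    | # With use_rmat the per-point 2x3 output is read as a 6D rotation
    | # parameterization and projected onto SO(3) via rot6d_to_matrix
    | # (Gram-Schmidt on the two predicted columns), so the learned frame is a
    | # proper rotation; otherwise the raw 3x3 output is used as-is.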
299 |
300 | class VNLinearBNLeakyReLU(nn.Module):
301 |
302 | def __init__(
303 | self,
304 | in_channels,
305 | out_channels,
306 | dim=5,
307 | share_nonlinearity=False,
308 | negative_slope=0.2,
309 | ):
310 | super().__init__()
311 |
312 | self.linear = VNLinear(in_channels, out_channels)
313 | self.batchnorm = VNBatchNorm(out_channels, dim=dim)
314 | self.leaky_relu = VNLeakyReLU(
315 | out_channels,
316 | # dim=dim,
317 | share_nonlinearity=share_nonlinearity,
318 | negative_slope=negative_slope,
319 | )
320 |
321 | def forward(self, x):
322 | # Linear
323 | p = self.linear(x)
324 | # BatchNorm
325 | p = self.batchnorm(p)
326 | # LeakyReLU
327 | p = self.leaky_relu(p)
328 | return p
329 |
330 | class NonEquivariantLinearLeakyReLU(nn.Module):
331 | def __init__(self, in_channels, out_channels, dim=2, use_batchnorm=True, negative_slope=0.2):
332 | super(NonEquivariantLinearLeakyReLU, self).__init__()
333 | self.dim = dim
334 | self.negative_slope = negative_slope
335 | self.use_batchnorm = use_batchnorm
336 |
337 | self.map_to_feat = nn.Linear(in_channels, out_channels, bias=False)
338 | if use_batchnorm:
339 | self.batchnorm = nn.BatchNorm1d(out_channels)
340 | self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope)
341 |
342 | def forward(self, x):
343 | # Linear
344 | print("PRE x.shape is, ", x.shape)
345 | # [todo]
346 | x = self.map_to_feat(x.transpose(1, -1)).transpose(1, -1)
347 | x = self.map_to_feat(x)
348 |
349 | # BatchNorm
350 |         if self.use_batchnorm:
351 |             # BatchNorm1d normalizes over dim 1, which is already the channel dim of [B, C, N]
352 |             x = self.batchnorm(x)
354 |
355 | # LeakyReLU
356 | x_out = self.leaky_relu(x)
357 | return x_out
358 |
359 |
360 | class NonEquivariantMaxPool(nn.Module):
361 | def __init__(self, dim=-1):
362 | super(NonEquivariantMaxPool, self).__init__()
363 | self.dim = dim
364 |
365 | def forward(self, x):
366 | '''
367 | x: point features of shape [B, C, N]
368 | '''
369 | return torch.max(x, dim=self.dim, keepdim=True)[0]
370 |
371 |
372 | class NonEquivariantStdFeature(nn.Module):
373 | def __init__(self, in_channels, dim=4, normalize_frame=False, negative_slope=0.2):
374 | super(NonEquivariantStdFeature, self).__init__()
375 | self.dim = dim
376 | self.normalize_frame = normalize_frame
377 |
378 | self.fc1 = NonEquivariantLinearLeakyReLU(in_channels, in_channels // 2, dim=dim, negative_slope=negative_slope)
379 | self.fc2 = NonEquivariantLinearLeakyReLU(in_channels // 2, in_channels // 4, dim=dim, negative_slope=negative_slope)
380 |
381 | def forward(self, x):
382 | '''
383 | x: point features of shape [B, C_in, N]
384 | '''
385 | z0 = x
386 | z0 = self.fc1(z0)
387 | z0 = self.fc2(z0)
388 |
389 | # No need for orthogonalization or frame normalization in non-equivariant version
390 | return z0
391 |
392 |
393 |
394 | class NonEquivariantLinearAndLeakyReLU(nn.Module):
395 | def __init__(self, in_channels, out_channels, dim=4, use_batchnorm=True, negative_slope=0.2):
396 | super(NonEquivariantLinearAndLeakyReLU, self).__init__()
397 | self.dim = dim
398 | self.use_batchnorm = use_batchnorm
399 | self.negative_slope = negative_slope
400 |
401 | self.linear = nn.Linear(in_channels, out_channels, bias=False)
402 | if use_batchnorm:
403 | self.batchnorm = nn.BatchNorm1d(out_channels)
404 | self.leaky_relu = nn.LeakyReLU(negative_slope=negative_slope)
405 |
406 | def forward(self, x):
407 | '''
408 | x: point features of shape [B, C_in, N]
409 | '''
410 | # Linear
411 |         x = self.linear(x.transpose(1, -1)).transpose(1, -1)  # nn.Linear acts on the last dim
412 | # BatchNorm
413 | if self.use_batchnorm:
414 | x = self.batchnorm(x)
415 | # LeakyReLU
416 | x = self.leaky_relu(x)
417 | return x
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/Pose_Refinement_Module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/Pose_Refinement_Module.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn1.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn2.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn_A.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn_A.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn_A_indi.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn_A_indi.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn_A_indi2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn_A_indi2.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn_B.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn_B.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/network_vnn_B2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/network_vnn_B2.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/regressor_CR.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/regressor_CR.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/__pycache__/transformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TEA-Lab/TwoByTwo/b68f6594626a53e912be77e5a4d62e65af813141/src/shape_assembly/models/train/__pycache__/transformer.cpython-38.pyc
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/network_vnn_A.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 |
7 | import pytorch3d
8 | from pytorch3d.transforms import quaternion_to_matrix
9 | from transforms3d.quaternions import quat2mat
10 | from scipy.spatial.transform import Rotation as R
11 | from mpl_toolkits.mplot3d import Axes3D
12 |
13 | import os
14 | import sys
15 | import copy
16 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17 | sys.path.append(os.path.join(BASE_DIR, '../../'))
18 | from models.encoder.pointnet import PointNet
19 | from models.encoder.dgcnn import DGCNN
20 | from models.decoder.MLPDecoder import MLPDecoder
21 | from models.encoder.vn_dgcnn import VN_DGCNN, VN_DGCNN_corr, VN_DGCNN_New
22 | from models.train.transformer import Transformer
23 | from models.train.regressor_CR import Regressor_CR, Regressor_6d, VN_Regressor_6d
24 | import utils
25 | from pdb import set_trace
26 |
27 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
28 | from ..ChamferDistancePytorch.chamfer3D import dist_chamfer_3D
29 |
30 | def bgs(d6s):
31 | bsz = d6s.shape[0]
32 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1)
33 | a2 = d6s[:, :, 1]
34 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1)
35 | b3 = torch.cross(b1, b2, dim=1)
36 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1)
37 |
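    | # bgs maps the 6D rotation representation (two raw 3D columns) to a
    | # rotation matrix via Gram-Schmidt. A self-contained sanity sketch, never
    | # called anywhere; batch size and tolerance are illustrative assumptions:
    | def _bgs_orthonormality_demo():
    |     d6s = torch.randn(4, 3, 2)
    |     Rm = bgs(d6s)  # [4, 3, 3] with orthonormal columns
    |     eye = torch.eye(3).expand(4, 3, 3)
    |     assert torch.allclose(Rm.transpose(1, 2) @ Rm, eye, atol=1e-5)
    |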
38 | def debug_vis_input(batch_data, cfg, pred_data, iter_counts):
39 | for i in range(cfg.exp.batch_size):
40 | save_dir = cfg.exp.vis_dir
41 | vis_dir = os.path.join(save_dir, 'vis_A_input')
42 | if not os.path.exists((vis_dir)):
43 | os.mkdir(vis_dir)
44 |
45 | src_pc = batch_data['src_pc'][i]
46 | tgt_pc = batch_data['tgt_pc'][i]
47 |
48 | device = src_pc.device
49 | num_points = src_pc.shape[1]
50 |
51 | tgt_pc_trans = batch_data['predicted_partB_position'][i].unsqueeze(1)
52 | tgt_pc_rot = batch_data['predicted_partB_rotation'][i].unsqueeze(1)
53 |
54 | tgt_rot_mat = bgs(tgt_pc_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3)
55 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat)
56 | tgt_pc_trans = tgt_pc_trans.expand(-1, num_points)
57 |
58 | tgt_pc = torch.matmul(tgt_rot_mat.double(), tgt_pc.double())
59 | tgt_pc = tgt_pc + tgt_pc_trans
60 |
61 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu())
62 | color_mask_src = ['r'] * num_points
63 | color_mask_tgt = ['g'] * num_points
64 | total_color = color_mask_src + color_mask_tgt
65 |
66 | fig = plt.figure()
67 | ax = fig.add_subplot(projection='3d')
68 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1)
69 | ax.axis('scaled')
70 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'})
71 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'})
72 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'})
73 |
74 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_input'.format(iter_counts,i)))
75 | plt.close(fig)
76 |
77 | def debug_vis_output(batch_data, cfg, pred_data, iter_counts):
78 | for i in range(cfg.exp.batch_size):
79 | save_dir = cfg.exp.vis_dir
80 | vis_dir = os.path.join(save_dir, 'vis_A_output')
81 | if not os.path.exists((vis_dir)):
82 | os.mkdir(vis_dir)
83 |
84 | src_pc = batch_data['src_pc'][i]
85 | src_trans = pred_data['src_trans'][i].unsqueeze(1)
86 | src_rot = pred_data['src_rot'][i]
87 |
88 | device = src_pc.device
89 | num_points = src_pc.shape[1]
90 |
91 | src_rot_mat = bgs(src_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3)
92 | src_rot_mat = torch.linalg.inv(src_rot_mat)
93 |
94 | src_trans = src_trans.expand(-1, num_points)
95 |
96 | src_pc = torch.matmul(src_rot_mat.double(), src_pc.double())
97 |
98 | src_pc = src_pc + src_trans
99 |
100 | tgt_pc = batch_data['tgt_pc'][i]
101 |
102 | tgt_pc_trans = batch_data['predicted_partB_position'][i].unsqueeze(1)
103 | tgt_pc_rot = batch_data['predicted_partB_rotation'][i].unsqueeze(1)
104 |
105 | tgt_rot_mat = bgs(tgt_pc_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3)
106 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat)
107 | tgt_pc_trans = tgt_pc_trans.expand(-1, num_points)
108 |
109 | tgt_pc = torch.matmul(tgt_rot_mat.double(), tgt_pc.double())
110 | tgt_pc = tgt_pc + tgt_pc_trans
111 |
112 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu())
113 | color_mask_src = ['r'] * num_points
114 | color_mask_tgt = ['g'] * num_points
115 | total_color = color_mask_src + color_mask_tgt
116 |
117 | fig = plt.figure()
118 | ax = fig.add_subplot(projection='3d')
119 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1)
120 | ax.axis('scaled')
121 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'})
122 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'})
123 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'})
124 |
125 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_output'.format(iter_counts,i)))
126 | plt.close(fig)
127 |
128 | def debug_vis_gt(batch_data, cfg, pred_data, iter_counts):
129 | for i in range(cfg.exp.batch_size):
130 | save_dir = cfg.exp.vis_dir
131 | vis_dir = os.path.join(save_dir, 'vis_A_gt')
132 | if not os.path.exists((vis_dir)):
133 | os.mkdir(vis_dir)
134 | src_pc = batch_data['src_pc'][i]
135 | tgt_pc = batch_data['tgt_pc'][i]
136 |
137 | src_trans = batch_data['src_trans'][i]
138 | src_rot = batch_data['src_rot'][i]
139 | tgt_trans = batch_data['tgt_trans'][i]
140 | tgt_rot = batch_data['tgt_rot'][i]
141 | device = src_pc.device
142 | num_points = src_pc.shape[1]
143 |
144 | src_rot_mat = bgs(src_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3)
145 | tgt_rot_mat = bgs(tgt_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3)
146 | src_rot_mat = torch.linalg.inv(src_rot_mat)
147 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat)
148 |
149 | src_trans = src_trans.expand(-1, num_points)
150 | tgt_trans = tgt_trans.expand(-1, num_points)
151 |
152 | src_pc = torch.matmul(src_rot_mat, src_pc)
153 | tgt_pc = torch.matmul(tgt_rot_mat, tgt_pc)
154 |
155 | src_pc = src_pc + src_trans
156 | tgt_pc = tgt_pc + tgt_trans
157 |
158 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu())
159 | color_mask_src = ['r'] * num_points
160 | color_mask_tgt = ['g'] * num_points
161 | total_color = color_mask_src + color_mask_tgt
162 |
163 | fig = plt.figure()
164 | ax = fig.add_subplot(projection='3d')
165 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1)
166 | ax.axis('scaled')
167 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'})
168 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'})
169 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'})
170 |
171 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_gt'.format(iter_counts,i)))
172 | plt.close(fig)
173 |
174 | class ShapeAssemblyNet_A_vnn(nn.Module):
175 | def __init__(self, cfg, data_features):
176 | super().__init__()
177 | self.cfg = cfg
178 | self.encoder = self.init_encoder()
179 | self.pose_predictor_rot = self.init_pose_predictor_rot()
180 | self.pose_predictor_trans = self.init_pose_predictor_trans()
181 | if self.cfg.model.recon_loss:
182 | self.decoder = self.init_decoder()
183 | self.data_features = data_features
184 | self.iter_counts = 0
185 | self.close_eps = 0.1
186 | self.L2 = nn.MSELoss()
187 | self.R = torch.tensor([[0.26726124, -0.57735027, 0.77151675],
188 | [0.53452248, -0.57735027, -0.6172134],
189 | [0.80178373, 0.57735027, 0.15430335]], dtype=torch.float64).unsqueeze(0)
190 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist()
191 | self.transformer = self.init_transformer()
192 |
193 | def init_transformer(self):
194 | transformer = Transformer(cfg = self.cfg)
195 | return transformer
196 |
197 | def init_encoder(self):
198 | if self.cfg.model.encoderA == 'dgcnn':
199 | encoder = DGCNN(feat_dim=self.cfg.model.pc_feat_dim)
200 | elif self.cfg.model.encoderA == 'vn_dgcnn':
201 | encoder = VN_DGCNN_New(feat_dim=self.cfg.model.pc_feat_dim)
202 | elif self.cfg.model.encoderA == 'pointnet':
203 | encoder = PointNet(feat_dim=self.cfg.model.pc_feat_dim)
204 | return encoder
205 |
206 | def init_pose_predictor_rot(self):
207 | if self.cfg.model.encoderA == 'vn_dgcnn':
208 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3
209 | if self.cfg.model.pose_predictor_rot == 'original':
210 | pose_predictor_rot = Regressor_CR(pc_feat_dim= pc_feat_dim, out_dim=6)
211 | elif self.cfg.model.pose_predictor_rot == 'vn':
212 | pose_predictor_rot = VN_equ_Regressor(pc_feat_dim= pc_feat_dim/3, out_dim=6)
213 | return pose_predictor_rot
214 |
215 | def init_pose_predictor_trans(self):
216 | if self.cfg.model.encoderA == 'vn_dgcnn':
217 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3
218 | if self.cfg.model.pose_predictor_trans == 'original':
219 | pose_predictor_trans = Regressor_CR(pc_feat_dim=pc_feat_dim, out_dim=3)
220 | elif self.cfg.model.pose_predictor_trans == 'vn':
221 | pose_predictor_trans = VN_inv_Regressor(pc_feat_dim=pc_feat_dim/3, out_dim=3)
222 | return pose_predictor_trans
223 |
224 | def init_decoder(self):
225 | pc_feat_dim = self.cfg.model.pc_feat_dim
226 | decoder = MLPDecoder(feat_dim=pc_feat_dim, num_points=self.cfg.data.num_pc_points)
227 | return decoder
228 | def configure_optimizers(self):
229 | optimizer = torch.optim.Adam(
230 | self.parameters(),
231 | lr=self.cfg.optimizer.lr,
232 | weight_decay=self.cfg.optimizer.weight_decay,
233 | )
234 | return optimizer
235 |
236 | def check_equiv(self, x, R, xR, name):
237 | mean_diff = torch.mean(torch.abs(torch.matmul(x, R) - xR))
238 | if mean_diff > self.close_eps:
239 | print(f'---[Equiv check]--- {name}: {mean_diff}')
240 | return
241 |
242 | def check_inv(self, x, R, xR, name):
243 | mean_diff = torch.mean(torch.abs(x - xR))
244 | if mean_diff > self.close_eps:
245 | print(f'---[Equiv check]--- {name}: {mean_diff}')
246 | return
247 |
248 | def check_network_property(self, gt_data, pred_data):
249 | with torch.no_grad():
250 | B, _, N = gt_data['src_pc'].shape
251 | R = self.R.float().repeat(B, 1, 1).to(gt_data['src_pc'].device)
252 | pcs_R = torch.matmul(gt_data['src_pc'].permute(0, 2, 1), R).permute(0, 2, 1)
253 | pred_data_R = self.forward(pcs_R, gt_data['tgt_pc'])
254 |
255 | equiv_feats = pred_data['Fa']
256 | equiv_feats_R = pred_data_R['Fa']
257 | self.check_equiv(equiv_feats, R, equiv_feats_R, 'equiv_feats')
258 |
259 | inv_feats = pred_data['Ga']
260 | inv_feats_R = pred_data_R['Ga']
261 | self.check_inv(inv_feats, R, inv_feats_R, 'inv_feats')
262 |
263 | if self.cfg.model.pose_predictor_rot == 'vn':
264 | rot = bgs(pred_data['src_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1)
265 | rot_R = bgs(pred_data_R['src_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1)
266 | self.check_equiv(rot, R, rot_R, 'rot')
267 | return
268 |
269 | def _recon_pts(self, Ga, Gb):
270 | global_inv_feat = torch.sum(torch.cat([Ga, Gb], dim=1), dim=1)
271 | recon_pts = self.decoder(global_inv_feat)
272 | return recon_pts
273 |
274 | def forward(self, src_pc, tgt_pc):
275 | batch_size = src_pc.shape[0]
276 | num_points = src_pc.shape[2]
277 | if self.cfg.model.encoderA == 'dgcnn':
278 | src_point_feat = self.encoder(src_pc)
279 | tgt_point_feat = self.encoder(tgt_pc)
280 |
281 | src_feat = torch.mean(src_point_feat, dim=2)
282 | tgt_feat = torch.mean(tgt_point_feat, dim=2)
283 |
284 | if self.cfg.model.encoderA == 'vn_dgcnn':
285 | Fa, Ga = self.encoder(src_pc)
286 | Fb, Gb = self.encoder(tgt_pc)
287 |
288 | src_feat_corr = Fa * Gb[:, :, :3]
289 |
290 | src_feat = Fa
291 |
292 | if self.cfg.model.pose_predictor_rot == 'original':
293 | src_rot = self.pose_predictor_rot(src_feat_corr.reshape(batch_size, -1))
294 | else:
295 | src_rot = self.pose_predictor_rot(src_feat_corr)
296 |
297 | if self.cfg.model.pose_predictor_trans == 'original':
298 | src_trans = self.pose_predictor_trans(src_feat_corr.reshape(batch_size, -1))
299 | else:
300 | src_trans = self.pose_predictor_trans(src_feat_corr)
301 |
302 | if self.cfg.model.recon_loss:
303 | recon_pts = self._recon_pts(Ga, Gb)
304 | pred_dict = {
305 | 'src_rot': src_rot,
306 | 'src_trans': src_trans,
307 | }
308 | if self.cfg.model.encoderA == 'vn_dgcnn':
309 | pred_dict['Fa'] = Fa
310 | pred_dict['Ga'] = Ga
311 | if self.cfg.model.recon_loss:
312 | pred_dict['recon_pts'] = recon_pts
313 | return pred_dict
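    |     # Note on the fusion above: Fa is part A's rotation-equivariant feature
    |     # and Gb is part B's invariant descriptor, so Fa * Gb[:, :, :3] gates
    |     # A's features with B-dependent values. A's predicted pose is thereby
    |     # conditioned on part B's geometry rather than on part A alone.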
314 |
315 | def compute_point_loss(self, batch_data, pred_data):
316 | src_pc = batch_data['src_pc'].float()
317 |
318 | src_rot_gt = self.recover_R_from_6d(batch_data['src_rot'].float())
319 | src_trans_gt = batch_data['src_trans'].float()
320 |
321 | src_rot_pred = self.recover_R_from_6d(pred_data['src_rot'].float())
322 | src_trans_pred = pred_data['src_trans'].float()
323 |
324 | src_trans_pred = src_trans_pred.unsqueeze(2)
325 |
326 | transformed_src_pc_pred = src_rot_pred @ src_pc + src_trans_pred
327 | with torch.no_grad():
328 | transformed_src_pc_gt = src_rot_gt @ src_pc + src_trans_gt
329 | src_point_loss = torch.mean(torch.sum((transformed_src_pc_pred - transformed_src_pc_gt) ** 2, axis=1))
330 |
331 | point_loss = src_point_loss
332 | return point_loss
333 |
334 | def compute_trans_loss(self, batch_data, pred_data):
335 | src_trans_gt = batch_data['src_trans'].float()
336 |
337 | src_trans_pred = pred_data['src_trans']
338 |
339 | src_trans_pred = src_trans_pred.unsqueeze(dim=2)
340 |
341 | src_trans_loss = F.l1_loss(src_trans_pred, src_trans_gt)
342 | trans_loss = src_trans_loss
343 | return trans_loss
344 |
345 | def compute_rot_loss(self, batch_data, pred_data):
346 | src_R_6d = batch_data['src_rot']
347 |
348 | src_R_6d_pred = pred_data['src_rot']
349 |
350 | src_rot_loss = torch.mean(utils.get_6d_rot_loss(src_R_6d, src_R_6d_pred))
351 | rot_loss = src_rot_loss
352 | return rot_loss
353 |
354 | def compute_rot_loss_symmetry(self, batch_data, pred_data, device):
355 | partA_symmetry_type = batch_data['partA_symmetry_type']
356 |
357 | src_R_6d = batch_data['src_rot']
358 |
359 | src_R_6d_pred = pred_data['src_rot']
360 | src_rot_loss = torch.mean(utils.get_6d_rot_loss_symmetry(src_R_6d, src_R_6d_pred, partA_symmetry_type, device))
361 |
362 | rot_loss = src_rot_loss
363 | return rot_loss
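    |     # get_6d_rot_loss_symmetry (utils) is expected to score the rotation
    |     # error up to the part's annotated symmetry type, so poses that differ
    |     # only by a symmetry of the part are not penalized.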
364 |
365 | def compute_recon_loss(self, batch_data, pred_data):
366 | recon_pts = pred_data['recon_pts']
367 |
368 | src_pc = batch_data['src_pc'].float()
369 | tgt_pc = batch_data['tgt_pc'].float()
370 | src_quat_gt = batch_data['src_rot'].float()
371 | tgt_quat_gt = batch_data['tgt_rot'].float()
372 |
373 | src_Rs = utils.bgs(src_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1))
374 | tgt_Rs = utils.bgs(tgt_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1))
375 |
376 | src_trans_gt = batch_data['src_trans'].float()
377 | tgt_trans_gt = batch_data['tgt_trans'].float()
378 | with torch.no_grad():
379 | transformed_src_pc_gt = src_Rs @ src_pc + src_trans_gt
380 | transformed_tgt_pc_gt = tgt_Rs @ tgt_pc + tgt_trans_gt
381 | gt_pts = torch.cat([transformed_src_pc_gt, transformed_tgt_pc_gt], dim=2).permute(0, 2, 1)
382 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist()
383 | dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts)
384 | recon_loss = torch.mean(dist1) + torch.mean(dist2)
385 | return recon_loss
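    |     # The reconstruction objective is the symmetric Chamfer distance between
    |     # the GT-posed union of both parts and the decoded cloud:
    |     #   CD(X, Y) = mean_{x in X} min_{y in Y} ||x - y||^2
    |     #           + mean_{y in Y} min_{x in X} ||y - x||^2
    |     # dist1 / dist2 above are the per-point minima in each direction.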
386 |
387 | def recover_R_from_6d(self, R_6d):
388 | R = utils.bgs(R_6d.reshape(-1, 2, 3).permute(0, 2, 1))
389 | return R
390 |
391 | def quat_to_eular(self, quat):
392 | quat = np.array([quat[1], quat[2], quat[3], quat[0]])
393 |
394 | r = R.from_quat(quat)
395 | euler0 = r.as_euler('xyz', degrees=True)
396 |
397 | return euler0
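    |     # scipy's Rotation.from_quat expects (x, y, z, w); the reorder above
    |     # converts from a (w, x, y, z) input quaternion.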
398 |
399 | def training_step(self, batch_data, device, batch_idx):
400 | self.iter_counts += 1
401 | total_loss, point_loss, rot_loss, trans_loss, recon_loss = self.forward_pass(batch_data, device, mode='train')
402 | return {"total_loss": total_loss,
403 | "point_loss": point_loss,
404 | "rot_loss": rot_loss,
405 | "trans_loss": trans_loss,
406 | "recon_loss": recon_loss,
407 | }
408 |
409 | def calculate_metrics(self, batch_data, pred_data, device, mode):
410 |         GD = self.compute_rot_loss(batch_data, pred_data)
411 |
412 |         rot_error = self.compute_rot_loss(batch_data, pred_data)
413 |
414 | src_pc = batch_data['src_pc'].float()
415 | src_quat_gt = batch_data['src_rot'].float()
416 |
417 | src_Rs = utils.bgs(src_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1))
418 |
419 | src_trans_gt = batch_data['src_trans'].float()
420 | with torch.no_grad():
421 | transformed_src_pc_gt = src_Rs @ src_pc + src_trans_gt
422 | gt_pts = transformed_src_pc_gt.permute(0, 2, 1)
423 |
424 | pred_R_src = self.recover_R_from_6d(pred_data['src_rot'])
425 | pred_t_src = pred_data['src_trans'].view(-1, 3, 1)
426 |
427 | gt_euler_src = pytorch3d.transforms.matrix_to_euler_angles(src_Rs, convention="XYZ")
428 |
429 | pred_euler_src = pytorch3d.transforms.matrix_to_euler_angles(pred_R_src, convention="XYZ")
430 |
431 | with torch.no_grad():
432 | transformed_src_pc_pred = pred_R_src @ src_pc + pred_t_src
433 |
434 |         recon_pts = transformed_src_pc_pred.permute(0, 2, 1)
435 |
436 | dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts)
437 | PA = torch.mean(dist1, dim=-1) + torch.mean(dist2, dim=-1)
438 |
439 | thre = 0.0001
440 | acc = (PA < thre)
441 | PA_threshold = acc.sum(-1) / acc.shape[0]
442 |
443 | RMSE_T_1 = (pred_t_src - src_trans_gt).pow(2).mean(dim=-1) ** 0.5
444 |
445 | RMSE_T = RMSE_T_1
446 |
447 | dist_a1, dist_a2, idx_a1, idx_a2 = self.chamLoss(transformed_src_pc_gt.permute(0,2,1), transformed_src_pc_pred.permute(0,2,1))
448 |
449 | CD_1 = torch.mean(dist_a1, dim=-1) + torch.mean(dist_a2, dim=-1)
450 |
451 | return GD, rot_error, RMSE_T, PA_threshold, PA, CD_1
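    |     # Metric summary, as inferred from the computation above: GD and
    |     # rot_error are the same 6D rotation loss, RMSE_T is the translation
    |     # RMSE, PA is the symmetric Chamfer distance of the posed part (with
    |     # PA_threshold the fraction under 1e-4, i.e. part accuracy), and CD_1
    |     # is the Chamfer distance between GT-posed and predicted-posed clouds.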
452 |
453 | def forward_pass(self, batch_data, device, mode, vis_idx=-1):
454 | tgt_pc = batch_data['tgt_pc'].float()
455 |
456 | device = tgt_pc.device
457 |
458 |         tgt_trans = batch_data['predicted_partB_position'].unsqueeze(-1).repeat(1, 1, tgt_pc.shape[2])  # [B, 3, N]
459 | tgt_rot = batch_data['predicted_partB_rotation']
460 |
461 |         num_points = tgt_pc.shape[2]
463 | batch_size = tgt_pc.shape[0]
464 |
465 | tgt_rot_mat = bgs(tgt_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(batch_size, 3, 3)
466 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat)
467 |
468 | transformed_tgt_pc = torch.matmul(tgt_rot_mat.double(), tgt_pc.double())
469 | transformed_tgt_pc = transformed_tgt_pc + tgt_trans
470 | transformed_tgt_pc = transformed_tgt_pc.float()
471 |
472 | pred_data = self.forward(batch_data['src_pc'].float(), transformed_tgt_pc)
473 | self.check_network_property(batch_data, pred_data)
474 | point_loss = 0.0
475 |
476 |         rot_loss = self.compute_rot_loss(batch_data, pred_data)
477 | trans_loss = self.compute_trans_loss(batch_data, pred_data)
478 | if self.cfg.model.recon_loss:
479 | recon_loss = self.compute_recon_loss(batch_data, pred_data)
480 | else:
481 | recon_loss = 0.0
482 |
483 | if vis_idx > -1:
484 | debug_vis_input(batch_data, self.cfg, pred_data, vis_idx)
485 | debug_vis_output(batch_data, self.cfg, pred_data, vis_idx)
486 | debug_vis_gt(batch_data, self.cfg, pred_data, vis_idx)
487 |
488 | total_loss = point_loss + rot_loss + trans_loss + recon_loss
489 |
490 | if mode == 'val':
491 | return (self.calculate_metrics(batch_data, pred_data, device, mode), total_loss, point_loss,rot_loss,trans_loss,recon_loss)
492 |
493 | return total_loss, point_loss, rot_loss, trans_loss, recon_loss
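    | # Pipeline note: forward_pass first re-poses part B's cloud using the pose
    | # predicted by network B ('predicted_partB_position' / '..._rotation'),
    | # then predicts part A's pose against that re-posed cloud, so stage A is
    | # explicitly conditioned on stage B's output.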
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/network_vnn_A_indi.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 |
7 | import pytorch3d
8 | from pytorch3d.transforms import quaternion_to_matrix
9 | from transforms3d.quaternions import quat2mat
10 | from scipy.spatial.transform import Rotation as R
11 | from mpl_toolkits.mplot3d import Axes3D
12 |
13 | import os
14 | import sys
15 | import copy
16 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17 | sys.path.append(os.path.join(BASE_DIR, '../../'))
18 | from models.encoder.pointnet import PointNet
19 | from models.encoder.dgcnn import DGCNN
20 | from models.decoder.MLPDecoder import MLPDecoder
21 | from models.encoder.vn_dgcnn import VN_DGCNN, VN_DGCNN_corr, VN_DGCNN_New
22 | from models.train.transformer import Transformer
23 | from models.train.regressor_CR import Regressor_CR, Regressor_6d, VN_Regressor_6d
24 | import utils
25 | from pdb import set_trace
26 |
27 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
28 | from ..ChamferDistancePytorch.chamfer3D import dist_chamfer_3D
29 |
30 | def bgs(d6s):
31 | bsz = d6s.shape[0]
32 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1)
33 | a2 = d6s[:, :, 1]
34 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1)
35 | b3 = torch.cross(b1, b2, dim=1)
36 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1)
37 |
38 | def debug_vis_input(batch_data, cfg, pred_data, iter_counts):
39 | for i in range(cfg.exp.batch_size):
40 | save_dir = cfg.exp.vis_dir
41 | vis_dir = os.path.join(save_dir, 'vis_A_input')
42 | if not os.path.exists((vis_dir)):
43 | os.mkdir(vis_dir)
44 |
45 | src_pc = batch_data['src_pc'][i]
46 | tgt_pc = batch_data['tgt_pc'][i]
47 |
48 | device = src_pc.device
49 | num_points = src_pc.shape[1]
50 |
51 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu())
52 | color_mask_src = ['r'] * num_points
53 | color_mask_tgt = ['g'] * num_points
54 | total_color = color_mask_src + color_mask_tgt
55 |
56 | fig = plt.figure()
57 | ax = fig.add_subplot(projection='3d')
58 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1)
59 | ax.axis('scaled')
60 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'})
61 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'})
62 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'})
63 |
64 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_input'.format(iter_counts,i)))
65 | plt.close(fig)
66 |
67 | def debug_vis_output(batch_data, cfg, pred_data, iter_counts):
68 | for i in range(cfg.exp.batch_size):
69 | save_dir = cfg.exp.vis_dir
70 | vis_dir = os.path.join(save_dir, 'vis_A_output')
71 | if not os.path.exists((vis_dir)):
72 | os.mkdir(vis_dir)
73 |
74 | src_pc = batch_data['src_pc'][i]
75 | src_trans = pred_data['src_trans'][i].unsqueeze(1)
76 | src_rot = pred_data['src_rot'][i]
77 |
78 | device = src_pc.device
79 | num_points = src_pc.shape[1]
80 |
81 | src_rot_mat = bgs(src_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3)
82 | src_rot_mat = torch.linalg.inv(src_rot_mat)
83 |
84 | src_trans = src_trans.expand(-1, num_points)
85 | src_pc = torch.matmul(src_rot_mat.double(), src_pc.double())
86 |
87 | src_pc = src_pc + src_trans
88 |
89 | tgt_pc = batch_data['tgt_pc'][i]
90 |
91 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu())
92 | color_mask_src = ['r'] * num_points
93 | color_mask_tgt = ['g'] * num_points
94 | total_color = color_mask_src + color_mask_tgt
95 |
96 | fig = plt.figure()
97 | ax = fig.add_subplot(projection='3d')
98 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1)
99 | ax.axis('scaled')
100 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'})
101 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'})
102 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'})
103 |
104 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_output'.format(iter_counts,i)))
105 | plt.close(fig)
106 |
107 | def debug_vis_gt(batch_data, cfg, pred_data, iter_counts):
108 | for i in range(cfg.exp.batch_size):
109 | save_dir = cfg.exp.vis_dir
110 | vis_dir = os.path.join(save_dir, 'vis_A_gt')
111 | if not os.path.exists((vis_dir)):
112 | os.mkdir(vis_dir)
113 | src_pc = batch_data['src_pc'][i]
114 | tgt_pc = batch_data['tgt_pc'][i]
115 |
116 | src_trans = batch_data['src_trans'][i]
117 | src_rot = batch_data['src_rot'][i]
118 |
119 | device = src_pc.device
120 | num_points = src_pc.shape[1]
121 |
122 | src_rot_mat = bgs(src_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3)
123 | src_rot_mat = torch.linalg.inv(src_rot_mat)
124 |
125 | src_trans = src_trans.expand(-1, num_points)
126 |
127 | src_pc = torch.matmul(src_rot_mat, src_pc)
128 |
129 | src_pc = src_pc + src_trans
130 |
131 | total_pc = np.array(torch.cat([src_pc, tgt_pc], dim=1).detach().cpu())
132 | color_mask_src = ['r'] * num_points
133 | color_mask_tgt = ['g'] * num_points
134 | total_color = color_mask_src + color_mask_tgt
135 |
136 | fig = plt.figure()
137 | ax = fig.add_subplot(projection='3d')
138 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=1)
139 | ax.axis('scaled')
140 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'})
141 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'})
142 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'})
143 |
144 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_gt'.format(iter_counts,i)))
145 | plt.close(fig)
146 |
147 | class ShapeAssemblyNet_A_vnn(nn.Module):
148 | def __init__(self, cfg, data_features):
149 | super().__init__()
150 | self.cfg = cfg
151 | self.encoder = self.init_encoder()
152 | self.pose_predictor_rot = self.init_pose_predictor_rot()
153 | self.pose_predictor_trans = self.init_pose_predictor_trans()
154 | if self.cfg.model.recon_loss:
155 | self.decoder = self.init_decoder()
156 | self.data_features = data_features
157 | self.iter_counts = 0
158 | self.close_eps = 0.1
159 | self.L2 = nn.MSELoss()
160 | self.R = torch.tensor([[0.26726124, -0.57735027, 0.77151675],
161 | [0.53452248, -0.57735027, -0.6172134],
162 | [0.80178373, 0.57735027, 0.15430335]], dtype=torch.float64).unsqueeze(0)
163 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist()
164 | self.transformer = self.init_transformer()
165 |
166 | def init_transformer(self):
167 | transformer = Transformer(cfg = self.cfg)
168 | return transformer
169 |
170 | def init_encoder(self):
171 | if self.cfg.model.encoderA == 'dgcnn':
172 | encoder = DGCNN(feat_dim=self.cfg.model.pc_feat_dim)
173 | elif self.cfg.model.encoderA == 'vn_dgcnn':
174 | encoder = VN_DGCNN_New(feat_dim=self.cfg.model.pc_feat_dim)
175 | elif self.cfg.model.encoderA == 'pointnet':
176 | encoder = PointNet(feat_dim=self.cfg.model.pc_feat_dim)
177 | return encoder
178 |
179 | def init_pose_predictor_rot(self):
180 | if self.cfg.model.encoderA == 'vn_dgcnn':
181 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3
182 | if self.cfg.model.pose_predictor_rot == 'original':
183 | pose_predictor_rot = Regressor_CR(pc_feat_dim= pc_feat_dim, out_dim=6)
184 | elif self.cfg.model.pose_predictor_rot == 'vn':
185 | pose_predictor_rot = VN_equ_Regressor(pc_feat_dim= pc_feat_dim/3, out_dim=6)
186 | return pose_predictor_rot
187 |
188 | def init_pose_predictor_trans(self):
189 | if self.cfg.model.encoderA == 'vn_dgcnn':
190 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3
191 | if self.cfg.model.pose_predictor_trans == 'original':
192 | pose_predictor_trans = Regressor_CR(pc_feat_dim=pc_feat_dim, out_dim=3)
193 | elif self.cfg.model.pose_predictor_trans == 'vn':
194 | pose_predictor_trans = VN_inv_Regressor(pc_feat_dim=pc_feat_dim/3, out_dim=3)
195 | return pose_predictor_trans
196 |
197 | def init_decoder(self):
198 | pc_feat_dim = self.cfg.model.pc_feat_dim
199 | decoder = MLPDecoder(feat_dim=pc_feat_dim, num_points=self.cfg.data.num_pc_points)
200 | return decoder
201 | def configure_optimizers(self):
202 | optimizer = torch.optim.Adam(
203 | self.parameters(),
204 | lr=self.cfg.optimizer.lr,
205 | weight_decay=self.cfg.optimizer.weight_decay,
206 | )
207 | return optimizer
208 |
209 | def check_equiv(self, x, R, xR, name):
210 | mean_diff = torch.mean(torch.abs(torch.matmul(x, R) - xR))
211 | if mean_diff > self.close_eps:
212 | print(f'---[Equiv check]--- {name}: {mean_diff}')
213 | return
214 |
215 | def check_inv(self, x, R, xR, name):
216 | mean_diff = torch.mean(torch.abs(x - xR))
217 | if mean_diff > self.close_eps:
218 | print(f'---[Equiv check]--- {name}: {mean_diff}')
219 | return
220 |
221 | def check_network_property(self, gt_data, pred_data):
222 | with torch.no_grad():
223 | B, _, N = gt_data['src_pc'].shape
224 | R = self.R.float().repeat(B, 1, 1).to(gt_data['src_pc'].device)
225 | pcs_R = torch.matmul(gt_data['src_pc'].permute(0, 2, 1), R).permute(0, 2, 1)
226 | pred_data_R = self.forward(pcs_R, gt_data['tgt_pc'])
227 |
228 | equiv_feats = pred_data['Fa']
229 | equiv_feats_R = pred_data_R['Fa']
230 | self.check_equiv(equiv_feats, R, equiv_feats_R, 'equiv_feats')
231 |
232 | inv_feats = pred_data['Ga']
233 | inv_feats_R = pred_data_R['Ga']
234 | self.check_inv(inv_feats, R, inv_feats_R, 'inv_feats')
235 |
236 | if self.cfg.model.pose_predictor_rot == 'vn':
237 | rot = bgs(pred_data['src_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1)
238 | rot_R = bgs(pred_data_R['src_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1)
239 | self.check_equiv(rot, R, rot_R, 'rot')
240 | return
241 |
242 | def _recon_pts(self, Ga, Gb):
243 | global_inv_feat = torch.sum(torch.cat([Ga, Gb], dim=1), dim=1)
244 | recon_pts = self.decoder(global_inv_feat)
245 | return recon_pts
246 |
247 | def forward(self, src_pc, tgt_pc):
248 | batch_size = src_pc.shape[0]
249 | num_points = src_pc.shape[2]
250 | if self.cfg.model.encoderA == 'dgcnn':
251 | src_point_feat = self.encoder(src_pc)
252 | tgt_point_feat = self.encoder(tgt_pc)
253 |
254 | src_feat = torch.mean(src_point_feat, dim=2)
255 | tgt_feat = torch.mean(tgt_point_feat, dim=2)
256 |
257 | if self.cfg.model.encoderA == 'vn_dgcnn':
258 | Fa, Ga = self.encoder(src_pc)
259 | Fb, Gb = self.encoder(tgt_pc)
260 |
261 | src_feat_corr = Fa * Gb[:, :, :3]
262 |
263 | src_feat = Fa
264 |
265 | if self.cfg.model.pose_predictor_rot == 'original':
266 | src_rot = self.pose_predictor_rot(src_feat_corr.reshape(batch_size, -1))
267 | else:
268 | src_rot = self.pose_predictor_rot(src_feat_corr)
269 |
270 | if self.cfg.model.pose_predictor_trans == 'original':
271 | src_trans = self.pose_predictor_trans(src_feat_corr.reshape(batch_size, -1))
272 | else:
273 | src_trans = self.pose_predictor_trans(src_feat_corr)
274 |
275 | if self.cfg.model.recon_loss:
276 | recon_pts = self._recon_pts(Ga, Gb)
277 | pred_dict = {
278 | 'src_rot': src_rot,
279 | 'src_trans': src_trans,
280 | }
281 | if self.cfg.model.encoderA == 'vn_dgcnn':
282 | pred_dict['Fa'] = Fa
283 | pred_dict['Ga'] = Ga
284 | if self.cfg.model.recon_loss:
285 | pred_dict['recon_pts'] = recon_pts
286 | return pred_dict
287 |
288 | def compute_point_loss(self, batch_data, pred_data):
289 | src_pc = batch_data['src_pc'].float()
290 |
291 | src_rot_gt = self.recover_R_from_6d(batch_data['src_rot'].float())
292 | src_trans_gt = batch_data['src_trans'].float()
293 |
294 | src_rot_pred = self.recover_R_from_6d(pred_data['src_rot'].float())
295 | src_trans_pred = pred_data['src_trans'].float()
296 |
297 | src_trans_pred = src_trans_pred.unsqueeze(2)
298 |
299 | transformed_src_pc_pred = src_rot_pred @ src_pc + src_trans_pred
300 | with torch.no_grad():
301 | transformed_src_pc_gt = src_rot_gt @ src_pc + src_trans_gt
302 |
303 | src_point_loss = torch.mean(torch.sum((transformed_src_pc_pred - transformed_src_pc_gt) ** 2, axis=1))
304 |
305 | point_loss = src_point_loss
306 | return point_loss
307 |
308 | def compute_trans_loss(self, batch_data, pred_data):
309 | src_trans_gt = batch_data['src_trans'].float()
310 |
311 | src_trans_pred = pred_data['src_trans']
312 |
313 | src_trans_pred = src_trans_pred.unsqueeze(dim=2)
314 |
315 | src_trans_loss = F.l1_loss(src_trans_pred, src_trans_gt)
316 | trans_loss = src_trans_loss
317 | return trans_loss
318 |
319 | def compute_rot_loss(self, batch_data, pred_data):
320 | src_R_6d = batch_data['src_rot']
321 |
322 | src_R_6d_pred = pred_data['src_rot']
323 |
324 | src_rot_loss = torch.mean(utils.get_6d_rot_loss(src_R_6d, src_R_6d_pred))
325 | rot_loss = src_rot_loss
326 | return rot_loss
327 |
328 | def compute_rot_loss_symmetry(self, batch_data, pred_data, device):
329 | partA_symmetry_type = batch_data['partA_symmetry_type']
330 |
331 | src_R_6d = batch_data['src_rot']
332 |
333 | src_R_6d_pred = pred_data['src_rot']
334 | src_rot_loss = torch.mean(utils.get_6d_rot_loss_symmetry(src_R_6d, src_R_6d_pred, partA_symmetry_type, device))
335 |
336 | rot_loss = src_rot_loss
337 | return rot_loss
338 |
339 | def compute_recon_loss(self, batch_data, pred_data):
340 | recon_pts = pred_data['recon_pts']
341 |
342 | src_pc = batch_data['src_pc'].float()
343 | tgt_pc = batch_data['tgt_pc'].float()
344 | src_quat_gt = batch_data['src_rot'].float()
345 | tgt_quat_gt = batch_data['tgt_rot'].float()
346 |
347 | src_Rs = utils.bgs(src_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1))
348 | tgt_Rs = utils.bgs(tgt_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1))
349 |
350 | src_trans_gt = batch_data['src_trans'].float()
351 | tgt_trans_gt = batch_data['tgt_trans'].float()
352 | with torch.no_grad():
353 | transformed_src_pc_gt = src_Rs @ src_pc + src_trans_gt
354 | transformed_tgt_pc_gt = tgt_Rs @ tgt_pc + tgt_trans_gt
355 | gt_pts = torch.cat([transformed_src_pc_gt, transformed_tgt_pc_gt], dim=2).permute(0, 2, 1)
356 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist()
357 | dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts)
358 | recon_loss = torch.mean(dist1) + torch.mean(dist2)
359 | return recon_loss
360 |
361 | def recover_R_from_6d(self, R_6d):
362 | R = utils.bgs(R_6d.reshape(-1, 2, 3).permute(0, 2, 1))
363 | return R
364 |
365 | def quat_to_eular(self, quat):
366 | quat = np.array([quat[1], quat[2], quat[3], quat[0]])
367 |
368 | r = R.from_quat(quat)
369 | euler0 = r.as_euler('xyz', degrees=True)
370 |
371 | return euler0
372 |
373 | def training_step(self, batch_data, device, batch_idx):
374 | self.iter_counts += 1
375 | partA_position, partA_rotation, total_loss, point_loss, rot_loss, trans_loss, recon_loss = self.forward_pass(batch_data, device, mode='train')
376 | return {"total_loss": total_loss,
377 | "point_loss": point_loss,
378 | "rot_loss": rot_loss,
379 | "trans_loss": trans_loss,
380 | "recon_loss": recon_loss,
381 | "predicted_partA_position": partA_position,
382 | "predicted_partA_rotation": partA_rotation
383 | }
384 |
385 | def calculate_metrics(self, batch_data, pred_data, device, mode):
386 | GD = self.compute_rot_loss(batch_data, pred_data)
387 |
388 | rot_error = self.compute_rot_loss(batch_data, pred_data)
389 |
390 | src_pc = batch_data['src_pc'].float()
391 | src_quat_gt = batch_data['src_rot'].float()
392 |
393 | src_Rs = utils.bgs(src_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1))
394 |
395 | src_trans_gt = batch_data['src_trans'].float()
396 | with torch.no_grad():
397 | transformed_src_pc_gt = src_Rs @ src_pc + src_trans_gt
398 | gt_pts = transformed_src_pc_gt.permute(0, 2, 1)
399 |
400 | pred_R_src = self.recover_R_from_6d(pred_data['src_rot'])
401 | pred_t_src = pred_data['src_trans'].view(-1, 3, 1)
402 |
403 | gt_euler_src = pytorch3d.transforms.matrix_to_euler_angles(src_Rs, convention="XYZ")
404 |
405 | pred_euler_src = pytorch3d.transforms.matrix_to_euler_angles(pred_R_src, convention="XYZ")
406 |
407 | with torch.no_grad():
408 | transformed_src_pc_pred = pred_R_src @ src_pc + pred_t_src
409 |
410 |         recon_pts = transformed_src_pc_pred.permute(0, 2, 1)
411 |
412 | dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts)
413 | PA = torch.mean(dist1, dim=-1) + torch.mean(dist2, dim=-1)
414 |
415 | thre = 0.0001
416 | acc = (PA < thre)
417 | PA_threshold = acc.sum(-1) / acc.shape[0]
418 |
419 | RMSE_T_1 = (pred_t_src - src_trans_gt).pow(2).mean(dim=-1) ** 0.5
420 |
421 | RMSE_T = RMSE_T_1
422 |
423 | dist_a1, dist_a2, idx_a1, idx_a2 = self.chamLoss(transformed_src_pc_gt.permute(0,2,1), transformed_src_pc_pred.permute(0,2,1))
424 |
425 | CD_1 = torch.mean(dist_a1, dim=-1) + torch.mean(dist_a2, dim=-1)
426 |
427 | return GD, rot_error, RMSE_T, PA_threshold, PA, CD_1
428 |
429 | def forward_pass(self, batch_data, device, mode, vis_idx=-1):
430 | tgt_pc = batch_data['tgt_pc'].float()
431 | device = tgt_pc.device
432 |         num_points = tgt_pc.shape[2]
433 | batch_size = tgt_pc.shape[0]
434 |
435 | pred_data = self.forward(batch_data['src_pc'].float(), batch_data['tgt_pc'].float())
436 | self.check_network_property(batch_data, pred_data)
437 | point_loss = 0.0
438 |
439 | rot_loss = self.compute_rot_loss(batch_data, pred_data)
440 | trans_loss = self.compute_trans_loss(batch_data, pred_data)
441 | if self.cfg.model.recon_loss:
442 | recon_loss = self.compute_recon_loss(batch_data, pred_data)
443 | else:
444 | recon_loss = 0.0
445 |
446 | if vis_idx > -1:
447 | debug_vis_input(batch_data, self.cfg, pred_data, vis_idx)
448 | debug_vis_output(batch_data, self.cfg, pred_data, vis_idx)
449 | debug_vis_gt(batch_data, self.cfg, pred_data, vis_idx)
450 |
451 | total_loss = point_loss + rot_loss + trans_loss + recon_loss
452 |
453 | if mode == 'val':
454 | return (self.calculate_metrics(batch_data, pred_data, device, mode), total_loss, point_loss,rot_loss,trans_loss,recon_loss)
455 |
456 | return pred_data['src_trans'], pred_data['src_rot'], total_loss, point_loss, rot_loss, trans_loss, recon_loss
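    | # Unlike network_vnn_A.py, this variant predicts part A's pose directly
    | # from the untransformed part B cloud, and training_step additionally
    | # returns the predicted translation/rotation for downstream stages.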
--------------------------------------------------------------------------------
/src/shape_assembly/models/train/network_vnn_B.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 |
7 | import pytorch3d
8 | from pytorch3d.transforms import quaternion_to_matrix
9 | from transforms3d.quaternions import quat2mat
10 | from scipy.spatial.transform import Rotation as R
11 | from mpl_toolkits.mplot3d import Axes3D
12 |
13 | import os
14 | import sys
15 | import copy
16 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
17 | sys.path.append(os.path.join(BASE_DIR, '../../'))
18 | from models.encoder.pointnet import PointNet
19 | from models.encoder.dgcnn import DGCNN
20 | from models.decoder.MLPDecoder import MLPDecoder
21 | from models.encoder.vn_dgcnn import VN_DGCNN, VN_DGCNN_corr, VN_DGCNN_New, DGCNN_New
22 | from models.train.regressor_CR import Regressor_CR, Regressor_6d, VN_Regressor_6d
23 | import utils
24 | from pdb import set_trace
25 |
26 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
27 | # print("BASE_DIR: ", BASE_DIR)
28 | from ..ChamferDistancePytorch.chamfer3D import dist_chamfer_3D
29 | # from chamfer_distance import ChamferDistance as dist_chamfer_3D
30 |
31 | def bgs(d6s):
32 | bsz = d6s.shape[0]
33 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1)
34 | a2 = d6s[:, :, 1]
35 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1)
36 | b3 = torch.cross(b1, b2, dim=1)
37 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1)
38 |
39 | def debug_vis_input(batch_data, cfg, pred_data, iter_counts):
40 | for i in range(cfg.exp.batch_size):
41 | # print("i is,", i)
42 | save_dir = cfg.exp.vis_dir
43 | vis_dir = os.path.join(save_dir, 'vis_B_input')
44 | if not os.path.exists((vis_dir)):
45 | os.mkdir(vis_dir)
46 | # src_pc = batch_data['src_pc'][i] # (3, 1024)
47 | tgt_pc = batch_data['tgt_pc'][i] # (3, 1024)
48 |
49 | device = tgt_pc.device
50 | num_points = tgt_pc.shape[1]
51 |
52 | total_pc = np.array(tgt_pc.detach().cpu())
53 | color_mask_tgt = ['g'] * num_points
54 |
55 | fig = plt.figure()
56 | ax = fig.add_subplot(projection='3d')
57 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=color_mask_tgt, s=10, alpha=0.9)
58 | ax.axis('scaled')
59 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'})
60 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'})
61 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'})
62 |
63 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_input'.format(iter_counts,i)))
64 | plt.close(fig)
65 |
66 | def debug_vis_output(batch_data, cfg, pred_data, iter_counts):
67 |
68 | # for every data in batch_size
69 | for i in range(cfg.exp.batch_size):
70 | save_dir = cfg.exp.vis_dir
71 | vis_dir = os.path.join(save_dir, 'vis_B_output')
72 | if not os.path.exists((vis_dir)):
73 | os.mkdir(vis_dir)
74 |
75 | tgt_pc = batch_data['tgt_pc'][i] # (3, 1024)
76 |
77 | tgt_trans = pred_data['tgt_trans'][i].unsqueeze(1) # (3)
78 | tgt_rot = pred_data['tgt_rot'][i] # (6)
79 |
80 | device = tgt_pc.device
81 | num_points = tgt_pc.shape[1]
82 | tgt_rot_mat = bgs(tgt_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3)
83 |
84 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat)
85 |
86 | tgt_trans = tgt_trans.expand(-1, num_points) # (3, 1024)
87 |
88 | tgt_pc = torch.matmul(tgt_rot_mat.double(), tgt_pc.double()) # (3, 1024)
89 |
90 | tgt_pc = tgt_pc + tgt_trans
91 |
92 | total_pc = np.array(tgt_pc.detach().cpu())
93 | color_mask_tgt = ['g'] * num_points
94 | total_color = color_mask_tgt
95 |
96 | fig = plt.figure()
97 | ax = fig.add_subplot(projection='3d')
98 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=0.9)
99 | ax.axis('scaled')
100 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'})
101 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'})
102 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'})
103 |
104 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_output'.format(iter_counts,i)))
105 | plt.close(fig)
106 |
107 |
108 | def debug_vis_gt(batch_data, cfg, pred_data, iter_counts):
109 |
110 | for i in range(cfg.exp.batch_size):
111 | save_dir = cfg.exp.vis_dir
112 | vis_dir = os.path.join(save_dir, 'vis_B_gt')
113 | if not os.path.exists((vis_dir)):
114 | os.mkdir(vis_dir)
115 | tgt_pc = batch_data['tgt_pc'][i] # (3, 1024)
116 | tgt_trans = batch_data['tgt_trans'][i] # (3, 1)
117 |         tgt_rot = batch_data['tgt_rot'][i] # (6)
118 | device = tgt_pc.device
119 | num_points = tgt_pc.shape[1]
120 |
121 | tgt_rot_mat = bgs(tgt_rot.reshape(-1, 2, 3).permute(0, 2, 1)).reshape(3, 3)
122 | tgt_rot_mat = torch.linalg.inv(tgt_rot_mat)
123 |
124 | tgt_trans = tgt_trans.expand(-1, num_points) # (3, 1024)
125 |
126 | tgt_pc = torch.matmul(tgt_rot_mat, tgt_pc) # (3, 1024)
127 |
128 | tgt_pc = tgt_pc + tgt_trans
129 |
130 | total_pc = np.array(tgt_pc.detach().cpu())
131 | color_mask_tgt = ['g'] * num_points
132 | total_color = color_mask_tgt
133 |
134 | fig = plt.figure()
135 | ax = fig.add_subplot(projection='3d')
136 | ax.scatter3D(list(total_pc[0]), list(total_pc[1]), list(total_pc[2]), c=total_color, s=10, alpha=0.9)
137 | ax.axis('scaled')
138 | ax.set_zlabel('Z', fontdict={'size': 20, 'color': 'red'})
139 | ax.set_ylabel('Y', fontdict={'size': 20, 'color': 'red'})
140 | ax.set_xlabel('X', fontdict={'size': 20, 'color': 'red'})
141 |
142 | fig.savefig(os.path.join(vis_dir, 'batch_{}_{}_gt'.format(iter_counts,i)))
143 | plt.close(fig)
144 |
145 |
146 | class ShapeAssemblyNet_B_vnn(nn.Module):
147 |
148 | def __init__(self, cfg, data_features):
149 | super().__init__()
150 | self.cfg = cfg
151 | self.encoder = self.init_encoder()
152 |
153 | self.encoder_dgcnn = DGCNN_New(feat_dim=cfg.model.pc_feat_dim)
154 |
155 | self.pose_predictor_rot = self.init_pose_predictor_rot()
156 | self.pose_predictor_trans = self.init_pose_predictor_trans()
157 | if self.cfg.model.recon_loss:
158 | self.decoder = self.init_decoder()
159 | self.data_features = data_features
160 |
161 | self.iter_counts = 0
162 | self.close_eps = 0.1
163 | self.L2 = nn.MSELoss()
164 | self.R = torch.tensor([[0.26726124, -0.57735027, 0.77151675],
165 | [0.53452248, -0.57735027, -0.6172134],
166 | [0.80178373, 0.57735027, 0.15430335]], dtype=torch.float64).unsqueeze(0)
167 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist()
168 |
169 | self.mlp_color = nn.Sequential(
170 | nn.Linear(512*2*3, 1024))
171 |
172 |
173 | def init_encoder(self):
174 | if self.cfg.model.encoderB == 'dgcnn':
175 | encoder = DGCNN(feat_dim=self.cfg.model.pc_feat_dim)
176 | elif self.cfg.model.encoderB == 'vn_dgcnn':
177 | encoder = VN_DGCNN_New(feat_dim=self.cfg.model.pc_feat_dim)
178 | elif self.cfg.model.encoderB == 'pointnet':
179 | encoder = PointNet(feat_dim=self.cfg.model.pc_feat_dim)
180 | return encoder
181 |
182 | def init_pose_predictor_rot(self):
183 | if self.cfg.model.encoderB == 'vn_dgcnn':
184 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3
185 | if self.cfg.model.pose_predictor_rot == 'original':
186 | pose_predictor_rot = Regressor_CR(pc_feat_dim= pc_feat_dim, out_dim=6)
187 | elif self.cfg.model.pose_predictor_rot == 'vn':
188 | pose_predictor_rot = VN_equ_Regressor(pc_feat_dim= pc_feat_dim/3, out_dim=6)
189 |
190 | return pose_predictor_rot
191 |
192 | def init_pose_predictor_trans(self):
193 | if self.cfg.model.encoderB == 'vn_dgcnn':
194 | pc_feat_dim = self.cfg.model.pc_feat_dim * 2 * 3
195 | if self.cfg.model.pose_predictor_trans == 'original':
196 | pose_predictor_trans = Regressor_CR(pc_feat_dim=pc_feat_dim, out_dim=3)
197 | elif self.cfg.model.pose_predictor_trans == 'vn':
198 | pose_predictor_trans = VN_inv_Regressor(pc_feat_dim=pc_feat_dim/3, out_dim=3)
199 | return pose_predictor_trans
200 |
201 | def init_decoder(self):
202 | pc_feat_dim = self.cfg.model.pc_feat_dim
203 | decoder = MLPDecoder(feat_dim=pc_feat_dim, num_points=self.cfg.data.num_pc_points)
204 | return decoder
205 |
206 |
207 | def configure_optimizers(self):
208 | optimizer = torch.optim.Adam(
209 | self.parameters(),
210 | lr=self.cfg.optimizer.lr,
211 | weight_decay=self.cfg.optimizer.weight_decay,
212 | )
213 | return optimizer
214 |
215 |
216 | def check_equiv(self, x, R, xR, name):
217 | mean_diff = torch.mean(torch.abs(torch.matmul(x, R) - xR))
218 | if mean_diff > self.close_eps:
219 | print(f'---[Equiv check]--- {name}: {mean_diff}')
220 | return
221 |
222 | def check_inv(self, x, R, xR, name):
223 | mean_diff = torch.mean(torch.abs(x - xR))
224 | if mean_diff > self.close_eps:
225 | print(f'---[Equiv check]--- {name}: {mean_diff}')
226 | return
227 |
228 |     def check_network_property(self, gt_data, pred_data):
229 |         with torch.no_grad():
230 |             B, _, N = gt_data['tgt_pc'].shape
231 |             R = self.R.float().repeat(B, 1, 1).to(gt_data['tgt_pc'].device)
232 |             pcs_R = torch.matmul(gt_data['tgt_pc'].permute(0, 2, 1), R).permute(0, 2, 1)
233 |             pred_data_R = self.forward(pcs_R)
234 |
235 |             equiv_feats = pred_data['Fb']
236 |             equiv_feats_R = pred_data_R['Fb']
237 |             self.check_equiv(equiv_feats, R, equiv_feats_R, 'equiv_feats')
238 |
239 |             inv_feats = pred_data['Gb']
240 |             inv_feats_R = pred_data_R['Gb']
241 |             self.check_inv(inv_feats, R, inv_feats_R, 'inv_feats')
242 |
243 |             if self.cfg.model.pose_predictor_rot == 'vn':
244 |                 rot = bgs(pred_data['tgt_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1)
245 |                 rot_R = bgs(pred_data_R['tgt_rot'].reshape(-1, 2, 3).permute(0, 2, 1)).permute(0, 2, 1)
246 |                 self.check_equiv(rot, R, rot_R, 'rot')
247 |         return
248 |
249 | def _recon_pts(self, Ga, Gb):
250 | global_inv_feat = torch.sum(torch.cat([Ga, Gb], dim=1), dim=1)
251 | recon_pts = self.decoder(global_inv_feat)
252 | return recon_pts
253 |
254 | def forward(self, tgt_pc):
255 | batch_size = tgt_pc.shape[0]
256 | num_points = tgt_pc.shape[2]
257 | if self.cfg.model.encoderB == 'dgcnn':
258 | tgt_point_feat = self.encoder(tgt_pc) # (batch_size, pc_feat_dim(512), num_point(1024))
259 |
260 | tgt_feat = torch.mean(tgt_point_feat, dim=2)
261 |
262 | if self.cfg.model.encoderB == 'vn_dgcnn':
263 | Fb, Gb = self.encoder(tgt_pc)
264 | device = tgt_pc.device
265 |
266 | tgt_feat = Fb
267 |
268 | if self.cfg.model.pose_predictor_rot == 'original':
269 | tgt_rot = self.pose_predictor_rot(tgt_feat.reshape(batch_size, -1))
270 | else:
271 | tgt_rot = self.pose_predictor_rot(tgt_feat)
272 |
273 | if self.cfg.model.pose_predictor_trans == 'original':
274 | tgt_trans = self.pose_predictor_trans(tgt_feat.reshape(batch_size, -1))
275 | else:
276 | tgt_trans = self.pose_predictor_trans(tgt_feat)
277 |
278 |         pred_dict = {'tgt_rot': tgt_rot, 'tgt_trans': tgt_trans}
279 |         if self.cfg.model.encoderB == 'vn_dgcnn':
280 |             # expose the encoder features so check_network_property can test them
281 |             pred_dict['Fb'], pred_dict['Gb'] = Fb, Gb
282 |         return pred_dict
283 |
284 | def compute_point_loss(self, batch_data, pred_data):
285 | tgt_pc = batch_data['tgt_pc'].float() # batch x 3 x 1024
286 | tgt_rot_gt = self.recover_R_from_6d(batch_data['tgt_rot'].float())
287 | tgt_trans_gt = batch_data['tgt_trans'].float() # batch x 3 x 1
288 | tgt_rot_pred = self.recover_R_from_6d(pred_data['tgt_rot'].float())
289 | tgt_trans_pred = pred_data['tgt_trans'].float()
290 |
291 | tgt_trans_pred = tgt_trans_pred.unsqueeze(2)
292 |
293 | # Target point loss
294 | transformed_tgt_pc_pred = tgt_rot_pred @ tgt_pc + tgt_trans_pred # batch x 3 x 1024 (predicted rotation with predicted translation)
295 | with torch.no_grad():
296 | transformed_tgt_pc_gt = tgt_rot_gt @ tgt_pc + tgt_trans_gt # batch x 3 x 1024 (ground-truth pose, no gradient needed)
297 | tgt_point_loss = torch.mean(torch.sum((transformed_tgt_pc_pred - transformed_tgt_pc_gt) ** 2, axis=1))
298 |
299 | # Point loss
300 | point_loss = tgt_point_loss
301 | return point_loss
302 |
303 | def compute_trans_loss(self, batch_data, pred_data):
304 | tgt_trans_gt = batch_data['tgt_trans'].float() # batch x 3 x 1
305 | tgt_trans_pred = pred_data['tgt_trans'] # batch x 3 x 1
306 | tgt_trans_pred = tgt_trans_pred.unsqueeze(dim=2)
307 | tgt_trans_loss = F.l1_loss(tgt_trans_pred, tgt_trans_gt)
308 | trans_loss = tgt_trans_loss
309 | return trans_loss
310 |
311 | def compute_rot_loss(self, batch_data, pred_data):
312 | tgt_R_6d = batch_data['tgt_rot']
313 | tgt_R_6d_pred = pred_data['tgt_rot']
314 | tgt_rot_loss = torch.mean(utils.get_6d_rot_loss(tgt_R_6d, tgt_R_6d_pred))
315 | rot_loss = tgt_rot_loss
316 | return rot_loss
317 |
318 | def compute_rot_loss_symmetry(self, batch_data, pred_data, device):
319 |
320 | partB_symmetry_type = batch_data['partB_symmetry_type']
321 |
322 | tgt_R_6d = batch_data['tgt_rot']
323 |
324 | tgt_R_6d_pred = pred_data['tgt_rot']
325 | tgt_rot_loss = torch.mean(utils.get_6d_rot_loss_symmetry(tgt_R_6d, tgt_R_6d_pred, partB_symmetry_type, device))
326 |
327 | rot_loss = tgt_rot_loss
328 | return rot_loss
329 |
330 | def compute_recon_loss(self, batch_data, pred_data):
331 |
332 | recon_pts = pred_data['recon_pts'] # batch x 1024 x 3
333 |
334 | src_pc = batch_data['src_pc'].float() # batch x 3 x 1024
335 | tgt_pc = batch_data['tgt_pc'].float() # batch x 3 x 1024
336 | src_quat_gt = batch_data['src_rot'].float()
337 | tgt_quat_gt = batch_data['tgt_rot'].float()
338 |
339 | src_Rs = utils.bgs(src_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1)) # batch x 3 x 3
340 | tgt_Rs = utils.bgs(tgt_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1)) # batch x 3 x 3
341 |
342 | src_trans_gt = batch_data['src_trans'].float() # batch x 3 x 1
343 | tgt_trans_gt = batch_data['tgt_trans'].float() # batch x 3 x 1
344 | with torch.no_grad():
345 | transformed_src_pc_gt = src_Rs @ src_pc + src_trans_gt # batch x 3 x 1024
346 | transformed_tgt_pc_gt = tgt_Rs @ tgt_pc + tgt_trans_gt # batch x 3 x 1024
347 | gt_pts = torch.cat([transformed_src_pc_gt, transformed_tgt_pc_gt], dim=2).permute(0, 2, 1) # batch x 2048 x 3
348 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist() # cached on self so calculate_metrics can reuse it
349 | dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts)
350 | recon_loss = torch.mean(dist1) + torch.mean(dist2)
351 |
352 | return recon_loss
353 |
354 | def recover_R_from_6d(self, R_6d):
355 | R = utils.bgs(R_6d.reshape(-1, 2, 3).permute(0, 2, 1))
356 | return R
357 |
358 | def quat_to_eular(self, quat):
359 | quat = np.array([quat[1], quat[2], quat[3], quat[0]]) # reorder (w, x, y, z) -> (x, y, z, w) for SciPy
360 |
361 | r = R.from_quat(quat)
362 | euler0 = r.as_euler('xyz', degrees=True)
363 |
364 | return euler0
365 |
366 | def training_step(self, batch_data, device, batch_idx):
367 | self.iter_counts += 1
368 | partB_position, partB_rotation, total_loss, point_loss, rot_loss, trans_loss, recon_loss = self.forward_pass(batch_data, device, mode='train')
369 | return {"total_loss": total_loss,
370 | "point_loss": point_loss,
371 | "rot_loss": rot_loss,
372 | "trans_loss": trans_loss,
373 | "recon_loss": recon_loss,
374 | 'predicted_partB_position': partB_position,
375 | 'predicted_partB_rotation': partB_rotation
376 | }
377 |
378 |
379 | def calculate_metrics(self, batch_data, pred_data, device, mode):
380 | # GD: geodesic rotation distance in degrees (also reported as rot_error).
381 | GD = self.compute_rot_loss(batch_data, pred_data)
382 | rot_error = GD
383 |
384 | tgt_pc = batch_data['tgt_pc'].float() # batch x 3 x 1024
385 | tgt_quat_gt = batch_data['tgt_rot'].float()
386 |
387 | tgt_Rs = utils.bgs(tgt_quat_gt.reshape(-1, 2, 3).permute(0, 2, 1)) # batch x 3 x 3
388 |
389 | tgt_trans_gt = batch_data['tgt_trans'].float() # batch x 3 x 1
390 | with torch.no_grad():
391 | transformed_tgt_pc_gt = tgt_Rs @ tgt_pc + tgt_trans_gt # batch x 3 x 1024
392 | gt_pts = transformed_tgt_pc_gt.permute(0, 2, 1) # batch x 1024 x 3
393 |
394 | pred_R_tgt = self.recover_R_from_6d(pred_data['tgt_rot'])
395 | pred_t_tgt = pred_data['tgt_trans'].view(-1, 3, 1)
396 |
397 | with torch.no_grad():
398 | transformed_tgt_pc_pred = pred_R_tgt @ tgt_pc + pred_t_tgt # batch x 3 x 1024
399 | recon_pts = transformed_tgt_pc_pred.permute(0, 2, 1) # batch x 1024 x 3
400 |
401 | # PA: chamfer distance between the GT-posed and predicted-posed clouds.
402 | # The chamfer module may not have been created yet, since compute_recon_loss
403 | # is not called on every path.
404 | if not hasattr(self, 'chamLoss'):
405 | self.chamLoss = dist_chamfer_3D.chamfer_3DDist()
406 | dist1, dist2, idx1, idx2 = self.chamLoss(gt_pts, recon_pts)
407 | PA = torch.mean(dist1, dim=-1) + torch.mean(dist2, dim=-1)
408 |
409 | thre = 0.0001 # chamfer threshold below which a prediction counts as accurate
410 | acc = (PA < thre)
411 | PA_threshold = acc.sum(-1) / acc.shape[0]
412 |
413 | # Translation RMSE.
414 | RMSE_T = (pred_t_tgt - tgt_trans_gt).pow(2).mean(dim=-1) ** 0.5
415 |
416 | # Chamfer distance of the same two clouds; equals PA here and is kept so the
417 | # caller's metric tuple is unchanged.
418 | CD_2 = PA
419 | return GD, rot_error, RMSE_T, PA_threshold, PA, CD_2
420 |
421 | def forward_pass(self, batch_data, device, mode, vis_idx=-1):
422 |
423 | pred_data = self.forward(batch_data['tgt_pc'].float())
424 | if self.cfg.model.point_loss:
425 | point_loss = self.compute_point_loss(batch_data, pred_data)
426 | else:
427 | point_loss = 0.0
428 |
429 | rot_loss = self.compute_rot_loss(batch_data, pred_data)
430 | trans_loss = self.compute_trans_loss(batch_data, pred_data)
431 | recon_loss = 0.0
432 |
433 | if vis_idx > -1:
434 | debug_vis_input(batch_data, self.cfg, pred_data, vis_idx)
435 | debug_vis_output(batch_data, self.cfg, pred_data, vis_idx)
436 | debug_vis_gt(batch_data, self.cfg, pred_data, vis_idx)
437 |
438 | total_loss = point_loss + rot_loss + trans_loss + recon_loss
439 |
440 | if mode == 'val':
441 | return (self.calculate_metrics(batch_data, pred_data, device, mode), pred_data['tgt_trans'], pred_data['tgt_rot'], total_loss, point_loss, rot_loss, trans_loss, recon_loss)
442 |
443 | # Training mode: return predictions and the individual loss terms.
444 | return pred_data['tgt_trans'], pred_data['tgt_rot'], total_loss, point_loss, rot_loss, trans_loss, recon_loss
445 |
446 |
--------------------------------------------------------------------------------
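Note on the point loss above: `compute_point_loss` compares the target cloud under the predicted and ground-truth rigid transforms. A self-contained sketch of that computation with random stand-ins for the network outputs (illustrative only; the shapes follow the `batch x 3 x 1024` convention used in the file):

```python
import torch

B, N = 2, 1024
tgt_pc = torch.rand(B, 3, N)             # batch x 3 x num_points
R_gt = torch.eye(3).expand(B, 3, 3)      # ground-truth rotation (identity for the sketch)
t_gt = torch.zeros(B, 3, 1)              # ground-truth translation
R_pred = torch.eye(3).expand(B, 3, 3)    # predicted rotation
t_pred = 0.01 * torch.ones(B, 3, 1)      # predicted translation, slightly off

# Pose the cloud with each transform and average the squared distances,
# mirroring compute_point_loss in network_vnn_B.py.
pc_pred = R_pred @ tgt_pc + t_pred
pc_gt = R_gt @ tgt_pc + t_gt
point_loss = torch.mean(torch.sum((pc_pred - pc_gt) ** 2, dim=1))
print(point_loss)  # 3 * 0.01**2 = 3.0e-4 for this toy setup
```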
/src/shape_assembly/models/train/pose_estimator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class PointNet(nn.Module):
6 | def __init__(self, out_channels=(32, 64, 128), train_with_norm=True):
7 | super(PointNet, self).__init__()
8 | self.layers = nn.ModuleList()
9 | in_channels = 3
10 | for out_channel in out_channels:
11 | self.layers.append(nn.Conv1d(in_channels, out_channel, 1))
12 | self.layers.append(nn.BatchNorm1d(out_channel) if train_with_norm else nn.Identity())
13 | self.layers.append(nn.ReLU())
14 | in_channels = out_channel
15 | self.global_pool = nn.AdaptiveMaxPool1d(1)
16 |
17 | def forward(self, x): # (batch, 3, num_points) -> global feature (batch, out_channels[-1])
18 | for layer in self.layers:
19 | x = layer(x)
20 | x = self.global_pool(x)
21 | x = x.squeeze(-1)
22 | return x
23 |
24 | class PoseClassifier(nn.Module):
25 | def __init__(self, pointnet_out_dim=128, pose_dim=6, hidden_dims=(512, 256, 128)):
26 | super(PoseClassifier, self).__init__()
27 | self.pointnet = PointNet(out_channels=(32, 64, pointnet_out_dim))
28 | input_dim = pointnet_out_dim + pose_dim
29 | layers = []
30 | for hidden_dim in hidden_dims:
31 | layers.append(nn.Linear(input_dim, hidden_dim))
32 | layers.append(nn.ReLU())
33 | input_dim = hidden_dim
34 | layers.append(nn.Linear(input_dim, 1))
35 | self.classifier = nn.Sequential(*layers)
36 |
37 | def forward(self, point_cloud, poses):
38 | # Point cloud feature extraction
39 | point_cloud_features = self.pointnet(point_cloud) # (batch_size, pointnet_out_dim)
40 |
41 | # Repeat point cloud features for each pose
42 | repeated_features = point_cloud_features.unsqueeze(1).repeat(1, poses.size(1), 1) # (batch_size, num_poses, pointnet_out_dim)
43 |
44 | # Concatenate pose features with point cloud features
45 | combined_features = torch.cat((repeated_features, poses), dim=-1) # (batch_size, num_poses, pointnet_out_dim + pose_dim)
46 |
47 | # Flatten the input for the classifier
48 | combined_features = combined_features.view(-1, combined_features.size(-1)) # (batch_size * num_poses, pointnet_out_dim + pose_dim)
49 |
50 | # Classification
51 | scores = self.classifier(combined_features) # (batch_size * num_poses, 1)
52 | scores = scores.view(-1, poses.size(1)) # (batch_size, num_poses)
53 |
54 | return scores
55 |
56 |
57 | if __name__ == '__main__':
58 | # smoke test with random inputs, guarded so importing this module does not run it
59 | batch_size = 8
60 | num_points = 1024
61 | num_poses = 1024
62 | pose_dim = 6
63 |
64 | point_cloud = torch.rand(batch_size, 3, num_points)
65 | poses = torch.rand(batch_size, num_poses, pose_dim)
66 | model = PoseClassifier(pointnet_out_dim=128, pose_dim=pose_dim)
67 | scores = model(point_cloud, poses)
68 | print(scores.size()) # torch.Size([8, 1024])
69 |
--------------------------------------------------------------------------------
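The file does not show how the classifier's scores are consumed; a plausible usage sketch (assumed, not from the repo) that keeps the highest-scoring candidate pose per cloud:

```python
import torch

# scores: (batch_size, num_poses), as returned by PoseClassifier.forward.
scores = torch.rand(8, 1024)
poses = torch.rand(8, 1024, 6)

probs = scores.softmax(dim=1)                 # optional: normalized confidence per candidate
best_idx = scores.argmax(dim=1)               # (batch_size,)
best_pose = poses[torch.arange(8), best_idx]  # (batch_size, pose_dim)
```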
/src/shape_assembly/models/train/regressor_CR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.encoder.vn_layers import *
5 |
6 |
7 |
8 | class Regressor_CR(nn.Module):
9 | def __init__(self, pc_feat_dim, out_dim):
10 | super().__init__()
11 | self.fc_layers = nn.Sequential(
12 | nn.Linear(pc_feat_dim, 256),
13 | nn.LeakyReLU(),
14 | nn.Linear(256, 128),
15 | nn.LeakyReLU()
16 | )
17 | print("pc_feature.size is", pc_feat_dim)
18 |
19 | self.head = nn.Linear(128, out_dim)
20 |
21 | def forward(self, x):
22 | f = self.fc_layers(x)
23 | output = self.head(f)
24 | return output
25 |
26 | class Corr_Aggregator_CR(nn.Module):
27 | def __init__(self, pc_feat_dim, out_dim):
28 | super().__init__()
29 | self.fc_layers = nn.Sequential(
30 | nn.Linear(pc_feat_dim, out_dim),
31 | nn.LeakyReLU(),
32 | )
33 | self.head = nn.Linear(out_dim, out_dim)
34 | def forward(self, x):
35 | f = self.fc_layers(x)
36 | output = self.head(f)
37 | return output
38 |
39 | class VN_Corr_Aggregator_CR(nn.Module): # same architecture as Corr_Aggregator_CR; kept as the name used on the VN code path
40 | def __init__(self, pc_feat_dim, out_dim):
41 | super().__init__()
42 | self.fc_layers = nn.Sequential(
43 | nn.Linear(pc_feat_dim, out_dim),
44 | nn.LeakyReLU(),
45 | )
46 | self.head = nn.Linear(out_dim, out_dim)
47 | def forward(self, x):
48 | f = self.fc_layers(x)
49 | output = self.head(f)
50 | return output
51 |
52 | class Regressor_6d(nn.Module):
53 | def __init__(self, pc_feat_dim):
54 | super().__init__()
55 | self.fc_layers = nn.Sequential(
56 | nn.Linear(2*pc_feat_dim, 256),
57 | nn.BatchNorm1d(256),
58 | nn.LeakyReLU(0.2),
59 | nn.Linear(256, 128),
60 | nn.BatchNorm1d(128),
61 | nn.LeakyReLU(0.2)
62 | )
63 |
64 | # Rotation prediction head
65 | self.rot_head = nn.Linear(128, 6)
66 |
67 | # Translation prediction head
68 | self.trans_head = nn.Linear(128, 3)
69 |
70 | def forward(self, x):
71 | f = self.fc_layers(x)
72 | rot_6d = self.rot_head(f) # a 6D rotation representation, not a quaternion
73 | rot_6d = rot_6d / torch.norm(rot_6d, p=2, dim=1, keepdim=True)
74 | trans = self.trans_head(f)
75 | trans = torch.unsqueeze(trans, dim=2)
76 | return rot_6d, trans
77 |
78 |
79 | class VN_Regressor(nn.Module):
80 | def __init__(self, pc_feat_dim, out_dim):
81 | super().__init__()
82 | self.fc_layers = nn.Sequential(
83 | VNLinear(2*pc_feat_dim, 256),
84 | nn.BatchNorm1d(256),
85 | VNNewLeakyReLU(in_channels=256, negative_slope=0.2),
86 | VNLinear(256, 128),
87 | nn.BatchNorm1d(128),
88 | VNNewLeakyReLU(in_channels=128, negative_slope=0.2)
89 | )
90 |
91 | # Rotation prediction head
92 | self.head = VNLinear(128, out_dim)
93 |
94 | def forward(self, x):
95 | f = self.fc_layers(x)
96 | output = self.head(f)
97 | return output
98 |
99 | class VN_Regressor_6d(nn.Module):
100 | def __init__(self, pc_feat_dim):
101 | super().__init__()
102 | self.fc_layers = nn.Sequential(
103 | # input width follows the caller-supplied pc_feat_dim
104 | VNLinear(pc_feat_dim, 256),
105 | nn.BatchNorm1d(256),
106 | VNLeakyReLU(in_channels=256, negative_slope=0.2),
107 | VNLinear(256, 128),
108 | nn.BatchNorm1d(128),
109 | VNLeakyReLU(in_channels=128, negative_slope=0.2)
110 | )
111 |
112 |
113 | # Rotation prediction head
114 | self.rot_head = VNLinear(128, 2)
115 |
116 | # Translation prediction head
117 | self.trans_head = nn.Linear(128*3, 3)
118 |
119 | def forward(self, x):
120 | f = self.fc_layers(x)
121 |
122 | rot = self.rot_head(f)
123 | rot = rot / torch.norm(rot, p=2, dim=1, keepdim=True)
124 | trans = self.trans_head(f.reshape(-1, 128*3))
125 | trans = torch.unsqueeze(trans, dim=2)
126 | return rot, trans
--------------------------------------------------------------------------------
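For reference, the heads above consume flattened point-cloud features. A usage sketch for `Regressor_CR` (the import path assumes `src/shape_assembly` is on `PYTHONPATH`, as the file's own `models.encoder.vn_layers` import implies; `pc_feat_dim=3072` corresponds to the `512 * 2 * 3` bookkeeping in `network_vnn_B.py`, assuming `pc_feat_dim=512` in the config):

```python
import torch
from models.train.regressor_CR import Regressor_CR

rot_head = Regressor_CR(pc_feat_dim=3072, out_dim=6)    # 6D rotation head
trans_head = Regressor_CR(pc_feat_dim=3072, out_dim=3)  # translation head

feat = torch.rand(8, 3072)   # batch of flattened features
rot_6d = rot_head(feat)      # (8, 6)
trans = trans_head(feat)     # (8, 3)
```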
/src/shape_assembly/models/train/transformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import pytorch_lightning as pl
7 |
8 | def clones(module, N):
9 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
10 |
11 | def attention(query, key, value, mask=None):
12 | d_k = query.size(-1)
13 | scores = torch.matmul(query, key.transpose(-2, -1).contiguous()) / (d_k ** 0.5)
14 | if mask is not None:
15 | scores = scores.masked_fill(mask == 0, -1e9)
16 | p_attn = F.softmax(scores, dim=-1)
17 | return torch.matmul(p_attn, value), p_attn
18 |
19 | class EncoderDecoder(pl.LightningModule):
20 | def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
21 | super().__init__()
22 | self.encoder = encoder
23 | self.decoder = decoder
24 | self.src_embed = src_embed
25 | self.tgt_embed = tgt_embed
26 | self.generator = generator
27 |
28 | def encode(self, src, src_mask):
29 | return self.encoder(self.src_embed(src), src_mask)
30 |
31 | def decode(self, memory, src_mask, tgt, tgt_mask):
32 | return self.generator(self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask))
33 |
34 | def forward(self, src, tgt, src_mask, tgt_mask):
35 | return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
36 |
37 | class Encoder(pl.LightningModule):
38 | def __init__(self, layer, N):
39 | super().__init__()
40 | self.layers = clones(layer, N)
41 | self.norm = LayerNorm(layer.size)
42 |
43 | def forward(self, x, mask):
44 | for layer in self.layers:
45 | x = layer(x, mask)
46 | return self.norm(x)
47 |
48 | class Decoder(pl.LightningModule):
49 | def __init__(self, layer, N):
50 | super().__init__()
51 | self.layers = clones(layer, N)
52 | self.norm = LayerNorm(layer.size)
53 |
54 | def forward(self, x, memory, src_mask, tgt_mask):
55 | for layer in self.layers:
56 | x = layer(x, memory, src_mask, tgt_mask)
57 | return self.norm(x)
58 |
59 | class LayerNorm(pl.LightningModule):
60 | def __init__(self, features, eps=1e-6):
61 | super().__init__()
62 | self.a_2 = nn.Parameter(torch.ones(features))
63 | self.b_2 = nn.Parameter(torch.zeros(features))
64 | self.eps = eps
65 |
66 | def forward(self, x):
67 | mean = x.mean(-1, keepdim=True)
68 | std = x.std(-1, keepdim=True)
69 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
70 |
71 | class SublayerConnection(pl.LightningModule):
72 | def __init__(self, size):
73 | super().__init__()
74 | self.norm = LayerNorm(size)
75 |
76 | def forward(self, x, sublayer):
77 | return x + sublayer(self.norm(x))
78 |
79 | class EncoderLayer(pl.LightningModule):
80 | def __init__(self, size, self_attn, feed_forward):
81 | super(EncoderLayer, self).__init__()
82 | self.self_attn = self_attn
83 | self.feed_forward = feed_forward
84 | self.sublayer = clones(SublayerConnection(size), 2)
85 | self.size = size
86 |
87 | def forward(self, x, mask):
88 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
89 | return self.sublayer[1](x, self.feed_forward)
90 |
91 | class DecoderLayer(pl.LightningModule):
92 |
93 | def __init__(self, size, self_attn, src_attn, feed_forward):
94 | super().__init__()
95 | self.size = size
96 | self.self_attn = self_attn
97 | self.src_attn = src_attn
98 | self.feed_forward = feed_forward
99 | self.sublayer = clones(SublayerConnection(size), 3)
100 |
101 | def forward(self, x, memory, src_mask, tgt_mask):
102 | m = memory
103 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
104 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
105 | return self.sublayer[2](x, self.feed_forward)
106 |
107 | class MultiHeadedAttention(pl.LightningModule):
108 |
109 | def __init__(self, h, d_model):
110 | super().__init__()
111 | assert d_model % h == 0
112 | # We assume d_v always equals d_k
113 | self.d_k = d_model // h
114 | self.h = h
115 | self.linears = clones(nn.Linear(d_model, d_model), 4)
116 |
117 | def forward(self, query, key, value, mask=None):
118 | if mask is not None:
119 | mask = mask.unsqueeze(1)
120 | nbatches = query.size(0)
121 |
122 | # 1) Do all the linear projections in batch from d_model => h x d_k
123 | query, key, value = \
124 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous()
125 | for l, x in zip(self.linears, (query, key, value))]
126 |
127 | # 2) Apply attention on all the projected vectors in batch.
128 | x, self.attn = attention(query, key, value, mask=mask)
129 |
130 | # 3) "Concat" using a view and apply a final linear.
131 | x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
132 |
133 | return self.linears[-1](x)
134 |
135 | class PositionwiseFeedForward(pl.LightningModule):
136 |
137 | def __init__(self, d_model, d_ff):
138 | super().__init__()
139 | self.w_1 = nn.Linear(d_model, d_ff)
140 | self.norm = nn.Sequential() # identity placeholder; the transposes in forward keep a channels-first layout if a norm layer is ever swapped in
141 | self.w_2 = nn.Linear(d_ff, d_model)
142 |
143 | def forward(self, x):
144 | return self.w_2(self.norm(F.relu(self.w_1(x)).transpose(2, 1).contiguous()).transpose(2, 1).contiguous())
145 |
146 | class Transformer(pl.LightningModule):
147 |
148 | def __init__(self, cfg):
149 | super().__init__()
150 | c = copy.deepcopy
151 | attn = MultiHeadedAttention(
152 | cfg.model.num_heads,
153 | cfg.model.pc_feat_dim
154 | )
155 |
156 | ff = PositionwiseFeedForward(
157 | cfg.model.pc_feat_dim,
158 | cfg.model.transformer_feat_dim
159 | )
160 |
161 | self.model = EncoderDecoder(
162 | Encoder(EncoderLayer(cfg.model.pc_feat_dim, c(attn), c(ff)), cfg.model.num_blocks),
163 | Decoder(DecoderLayer(cfg.model.pc_feat_dim, c(attn), c(attn), c(ff)), cfg.model.num_blocks),
164 | nn.Sequential(),
165 | nn.Sequential(),
166 | nn.Sequential()
167 | )
168 |
169 | def forward(self, src, tgt):
170 | src = src.transpose(2, 1).contiguous()
171 | tgt = tgt.transpose(2, 1).contiguous()
172 | src_corr_feat = self.model(tgt, src, None, None).transpose(2, 1).contiguous()
173 | return src_corr_feat
174 |
--------------------------------------------------------------------------------
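The blocks above follow the standard encoder-decoder Transformer layout. A quick shape check of the `attention` helper (same import-path assumption as before; the layout matches how `MultiHeadedAttention` reshapes its inputs):

```python
import torch
from models.train.transformer import attention

q = torch.rand(2, 4, 16, 64)  # (batch, heads, seq_len, d_k)
k = torch.rand(2, 4, 16, 64)
v = torch.rand(2, 4, 16, 64)

out, p_attn = attention(q, k, v)  # out: (2, 4, 16, 64); p_attn: (2, 4, 16, 16)
```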
/src/shape_assembly/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from mpl_toolkits.mplot3d import Axes3D # noqa: F401 -- registers the '3d' projection
4 | import torch
5 | import torch.nn.functional as F
6 |
7 |
8 | def render_pts_label_png(fn, pc, color):
9 | # pc: (num_points, 3), color: (num_points,)
10 | new_color = []
11 | for i in range(len(color)):
12 | if color[i] == 1:
13 | new_color.append('#ab4700')
14 | else:
15 | new_color.append('#00479e')
16 | fig = plt.figure(dpi=200)
17 | ax = fig.add_subplot(111, projection='3d')
18 | plt.title('point cloud')
19 |
20 | ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=new_color, marker='.', s=5, linewidth=0, alpha=1)
21 | ax.axis('scaled')
22 | ax.set_xlabel('X Label')
23 | ax.set_ylabel('Y Label')
24 | ax.set_zlabel('Z Label')
25 |
26 | plt.savefig(fn+'.png')
27 |
28 | def compute_distance_between_rotations(P, Q):
29 | # Geodesic distance between two batches of rotation matrices, in radians.
30 | # P, Q: numpy arrays of shape (N, 3, 3); returns the mean angle.
31 | P = np.asarray(P)
32 | Q = np.asarray(Q)
33 | R = np.matmul(P, Q.swapaxes(1, 2))
34 | theta = np.arccos(np.clip((np.trace(R, axis1=1, axis2=2) - 1) / 2, -1, 1))
35 | return np.mean(theta)
36 |
37 | def bgs(d6s):
38 | bsz = d6s.shape[0]
39 | b1 = F.normalize(d6s[:, :, 0], p=2, dim=1)
40 | a2 = d6s[:, :, 1]
41 | b2 = F.normalize(a2 - torch.bmm(b1.view(bsz, 1, -1), a2.view(bsz, -1, 1)).view(bsz, 1) * b1, p=2, dim=1)
42 | b3 = torch.cross(b1, b2, dim=1)
43 | return torch.stack([b1, b2, b3], dim=1).permute(0, 2, 1)
44 |
45 |
46 | def bgdR(Rgts, Rps):
47 | Rds = torch.bmm(Rgts.permute(0, 2, 1), Rps)
48 | Rt = torch.sum(Rds[:, torch.eye(3).bool()], 1) #batch trace
49 | # clamp so acos stays strictly in-domain; otherwise NaNs can appear
50 | theta = torch.clamp(0.5 * (Rt - 1), -1 + 1e-6, 1 - 1e-6)
51 |
52 | return torch.acos(theta)
53 |
54 | def get_6d_rot_loss(gt_6d, pred_6d):
55 | # input pred_6d, gt_6d: batch x 2 x 3 (or batch x 6)
56 | # returns the per-example geodesic angle between the two rotations, in degrees
57 | pred_Rs = bgs(pred_6d.reshape(-1, 2, 3).permute(0, 2, 1))
58 | gt_Rs = bgs(gt_6d.reshape(-1, 2, 3).permute(0, 2, 1))
59 | theta = bgdR(gt_Rs, pred_Rs)
60 | theta_degree = theta * 180 / np.pi
61 | return theta_degree
62 |
63 | def get_6d_rot_loss_symmetry_new(batch_data, pred_data, device):
64 |
65 | batch_size = batch_data['src_rot'].shape[0]
66 |
67 | partA_symmetry_type = batch_data['partA_symmetry_type']
68 | partB_symmetry_type = batch_data['partB_symmetry_type']
69 |
70 | src_R_6d = batch_data['src_rot']
71 | tgt_R_6d = batch_data['tgt_rot']
72 |
73 | src_R_6d_pred = pred_data['src_rot']
74 | tgt_R_6d_pred = pred_data['tgt_rot']
75 | tgt_R_6d_init = batch_data['predicted_partB_rotation']
76 |
77 | src_Rs = bgs(src_R_6d_pred.reshape(-1, 2, 3).permute(0, 2, 1))
78 | gt_src_Rs = bgs(src_R_6d.reshape(-1, 2, 3).permute(0, 2, 1))
79 |
80 | tgt_Rs = bgs(tgt_R_6d_pred.reshape(-1, 2, 3).permute(0, 2, 1))
81 | tgt_Rs_init = bgs(tgt_R_6d_init.reshape(-1, 2, 3).permute(0, 2, 1))
82 |
83 | tgt_Rs_new = torch.matmul(tgt_Rs, tgt_Rs_init)
84 | gt_tgt_Rs = bgs(tgt_R_6d.reshape(-1, 2, 3).permute(0, 2, 1))
85 |
86 |
87 | # Normalize each matrix by the cube root of its own determinant so the
88 | # trace-based angle below is well conditioned (rotations already have det 1).
89 | R1 = src_Rs / torch.pow(torch.det(src_Rs), 1/3).view(-1, 1, 1)
90 | R2 = gt_src_Rs / torch.pow(torch.det(gt_src_Rs), 1/3).view(-1, 1, 1)
91 |
92 | R3 = tgt_Rs_new / torch.pow(torch.det(tgt_Rs_new), 1/3).view(-1, 1, 1)
93 | R4 = gt_tgt_Rs / torch.pow(torch.det(gt_tgt_Rs), 1/3).view(-1, 1, 1)
94 |
95 | # unit z-axis, repeated per batch element
96 | z = torch.tensor([0.0,0.0,1.0], device=device).unsqueeze(0).repeat(batch_size,1)
97 | cos_theta = torch.zeros(batch_size)
98 | for i in range(batch_size):
99 | # for every data in batch_size
100 | symmetry_i = partA_symmetry_type[i]
101 |
102 | # if the part is z-axis symmetric
103 | if symmetry_i[4].item() == 1:
104 | # compare only the rotated z-axes: any spin about z maps the part to itself
105 | z1 = torch.matmul(R1[i], z[i])
106 | z2 = torch.matmul(R2[i], z[i])
107 |
108 | cos_theta[i] = torch.dot(z1,z2) / (torch.norm(z1) * torch.norm(z2))
109 |
110 | else:
111 | R_A = torch.matmul(R1[i], R2[i].transpose(1,0))
112 | cos_theta[i] = (torch.trace(R_A) - 1) /2
113 |
114 | theta_src = torch.acos(torch.clamp(cos_theta, -1.0+1e-6, 1.0-1e-6))*180 / np.pi
115 |
116 | cos_theta_B = torch.zeros(batch_size)
117 | for i in range(batch_size):
118 | # for every data in batch_size
119 | symmetry_i = partB_symmetry_type[i]
120 |
121 | # if the part is z-axis symmetric
122 | if symmetry_i[4].item() == 1:
123 | # compare only the rotated z-axes: any spin about z maps the part to itself
124 |
126 | z3 = torch.matmul(R3[i], z[i])
127 | z4 = torch.matmul(R4[i], z[i])
128 |
129 | cos_theta_B[i] = torch.dot(z3,z4) / (torch.norm(z3) * torch.norm(z4))
130 |
131 | else:
132 | R_B = torch.matmul(R3[i], R4[i].transpose(1,0))
133 | cos_theta_B[i] = (torch.trace(R_B) - 1) /2
134 |
135 | theta_tgt = torch.acos(torch.clamp(cos_theta_B, -1.0+1e-6, 1.0-1e-6))*180 / np.pi
136 |
137 | src_rot_loss = torch.mean(theta_src)
138 | tgt_rot_loss = torch.mean(theta_tgt)
139 |
140 |
141 |
142 | return (src_rot_loss, tgt_rot_loss)
143 |
144 | def get_6d_rot_loss_symmetry(gt_6d, pred_6d, symmetry, device):
145 | batch_size = gt_6d.shape[0]
146 |
147 | pred_Rs = bgs(pred_6d.reshape(-1, 2, 3).permute(0, 2, 1))
148 | gt_Rs = bgs(gt_6d.reshape(-1, 2, 3).permute(0, 2, 1))
149 | R1 = pred_Rs / torch.pow(torch.det(pred_Rs), 1/3).view(-1,1,1)
150 | R2 = gt_Rs / torch.pow(torch.det(gt_Rs), 1/3).view(-1,1,1)
151 |
152 | # unit z-axis, repeated per batch element
153 | z = torch.tensor([0.0,0.0,1.0], device=device).unsqueeze(0).repeat(batch_size,1)
154 | cos_theta = torch.zeros(batch_size)
155 | for i in range(batch_size):
156 | # for every data in batch_size
157 | symmetry_i = symmetry[i]
158 |
159 | # if the part is z-axis symmetric
160 | if symmetry_i[4].item() == 1:
161 | # compare only the rotated z-axes: any spin about z maps the part to itself
162 | z1 = torch.matmul(R1[i], z[i])
163 | z2 = torch.matmul(R2[i], z[i])
164 | cos_theta[i] = torch.dot(z1,z2) / (torch.norm(z1) * torch.norm(z2))
165 | else:
166 | R = torch.matmul(R1[i], R2[i].transpose(1,0))
167 | cos_theta[i] = (torch.trace(R) - 1) /2
168 |
169 | theta = torch.acos(torch.clamp(cos_theta, -1.0+1e-6, 1.0-1e-6))*180 / np.pi
170 | return theta
171 |
172 |
173 | def printout(flog, strout):
174 | print(strout)
175 | if flog is not None:
176 | flog.write(strout + '\n')
177 |
--------------------------------------------------------------------------------
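The 6D rotation representation used throughout the repo stores the first two columns of a rotation matrix, and `bgs` re-orthonormalizes them with Gram-Schmidt. A round-trip sketch (the import path assumes `src/shape_assembly` is on `PYTHONPATH`):

```python
import torch
from shape_assembly.utils import bgs, get_6d_rot_loss

# Random proper rotations via QR, flipping any det = -1 factor to +1.
Q = torch.linalg.qr(torch.randn(4, 3, 3)).Q
Q = Q * torch.sign(torch.det(Q)).view(-1, 1, 1)

d6 = Q[:, :, :2].permute(0, 2, 1).reshape(4, 6)     # first two columns -> 6D rep
R_rec = bgs(d6.reshape(-1, 2, 3).permute(0, 2, 1))  # Gram-Schmidt recovery, (4, 3, 3)

print(torch.allclose(R_rec, Q, atol=1e-5))  # True
print(get_6d_rot_loss(d6, d6))              # near-zero angles (the acos clamp keeps them > 0)
```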
/src/shape_assembly/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.1'
2 |
--------------------------------------------------------------------------------