├── .gitattributes
├── .gitignore
├── README.md
├── experiments
│   ├── experiment_accuracy
│   │   ├── analyze.sh
│   │   ├── experiment_accuracy.py
│   │   ├── run.sh
│   │   └── table.py
│   ├── experiment_infeasible_poses
│   │   ├── analyze.sh
│   │   ├── convert_pose_dataset.py
│   │   ├── experiment_infeasible_poses.py
│   │   ├── generate_infeasible_poses.py
│   │   ├── run.sh
│   │   └── table.py
│   ├── experiment_joint_limits
│   │   ├── analyze.sh
│   │   ├── experiment_joint_limits.py
│   │   ├── joint_visualize.py
│   │   ├── results
│   │   │   └── paper_models
│   │   │       └── panda_512k_experiment
│   │   │           ├── plots
│   │   │           │   └── joint_limits.pdf
│   │   │           └── results.pkl
│   │   └── run.sh
│   ├── experiment_timing
│   │   ├── analyze.sh
│   │   ├── experiment_timing.py
│   │   ├── run.sh
│   │   └── timing_curve.py
│   └── experiment_tracIK
│       ├── analyze.sh
│       ├── dist_images.py
│       ├── experiment_2.py
│       ├── run.sh
│       ├── run_tracik.sh
│       ├── table.py
│       └── tracik_comparison.py
├── generative_graphik
│   ├── args
│   │   ├── parser.py
│   │   └── utils.py
│   ├── model.py
│   ├── networks
│   │   ├── egnn.py
│   │   ├── eqgraph.py
│   │   ├── gatgraph.py
│   │   ├── gcngraph.py
│   │   ├── linearvae.py
│   │   ├── mpnngraph.py
│   │   └── sagegraph.py
│   ├── train.py
│   ├── train.sh
│   └── utils
│       ├── __init__.py
│       ├── dataset_generation.py
│       └── torch_utils.py
├── paper_models.zip
└── setup.py
-------------------------------------------------------------------------------- /.gitattributes: --------------------------------------------------------------------------------
1 | paper_models.zip filter=lfs diff=lfs merge=lfs -text
2 |
-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
1 | saved_models/
2 | datasets/
3 | **/results/
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # PyCharm
136 | .idea/
137 |
138 | # Mac
139 | .DS_Store
140 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # generative-graphIK
2 | Code for the paper [Generative Graphical Inverse Kinematics](https://arxiv.org/abs/2209.08812).
3 |
4 | ## Installation
5 | Install matching versions of [PyTorch](https://pytorch.org/get-started/previous-versions/#v1101:~:text=org/whl/cpu-,v1.10.1,-Conda) and [PyTorch-Geometric](https://pytorch-geometric.readthedocs.io/en/2.0.3/notes/installation.html). We used `torch-1.10.1` and `torch-geometric-2.0.4`, but other versions should work as well.
6 |
7 | After installing the above:
8 | ```
9 | pip install -e .
10 | ```
11 |
12 | ## Generate a dataset and train
13 | ```
14 | ./train.sh
15 | ```
16 |
17 | See `generative_graphik/args/parser.py` for more details on data generation and model parameters.
18 |
19 | ## Modifying data generation
20 | To modify the training data, edit lines 29-35 of `train.sh`.
21 |
22 | To train on specific robots:
23 | ```
24 | python -u ${SRC_PATH}/generative_graphik/utils/dataset_generation.py \
25 |     --id "${DATASET_NAME}" \
26 |     --robots ur10 kuka panda lwa4d lwa4p \
27 |     --num_examples 512000 \
28 |     --max_examples_per_file 512000 \
29 |     --goal_type pose \
30 |     --randomize False
31 | ```
32 |
33 | To train on randomized robots with 6 or 7 DOF:
34 | ```
35 | python -u ${SRC_PATH}/generative_graphik/utils/dataset_generation.py \
36 |     --id "${DATASET_NAME}" \
37 |     --robots revolute_chain \
38 |     --dof 7 6 \
39 |     --num_examples 512000 \
40 |     --max_examples_per_file 512000 \
41 |     --goal_type pose \
42 |     --randomize True \
43 |     --randomize_percentage 0.4
44 | ```
45 |
46 | ## Modifying hyperparameters
47 | To modify the model parameters, edit lines 43-77 of `train.sh`.
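
## Evaluating a trained model
The scripts under `experiments/` contain the full evaluation pipeline. The snippet below is a minimal sketch condensed from `experiments/experiment_accuracy/experiment_accuracy.py`; the model directory name is illustrative, and any directory produced by `train.sh` (containing `model.py`, `hyperparameters.txt`, and `checkpoints/checkpoint.pth`) should work in its place:
```
import importlib.util
import json
from argparse import Namespace

import torch
from generative_graphik.utils.dataset_generation import generate_data_point
from graphik.utils.roboturdf import load_panda

model_path = "saved_models/panda_512k_model/"  # illustrative path

# Each saved model ships with its own copy of model.py; import it dynamically.
spec = importlib.util.spec_from_file_location("model", model_path + "model.py")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# Hyperparameters are stored as JSON alongside the checkpoint.
with open(model_path + "hyperparameters.txt") as fp:
    model_args = Namespace(**json.load(fp))

model = module.Model(model_args)
state_dict = torch.load(model_path + "checkpoints/checkpoint.pth", map_location="cpu")
model.load_state_dict(state_dict["net"])
model.eval()

# Sample a random IK problem for a Panda and draw 32 candidate configurations.
robot, graph = load_panda(limits=None)
prob_data = generate_data_point(graph)
prob_data.num_graphs = 1
data = model.preprocess(prob_data)
P_all = model.forward_eval(data, num_samples=32)  # one point set per sample
```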
48 |
-------------------------------------------------------------------------------- /experiments/experiment_accuracy/analyze.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NAME=$1
4 |
5 | SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
6 | SRC_PATH="${SCRIPT_DIR}/../.."
7 | MODEL_PATH="${SRC_PATH}/saved_models/${NAME}_model"
8 |
9 | python table.py \
10 |     --id "${NAME}_experiment" \
11 |     --save_latex True \
12 |     --latex_path "/home/olimoyo/generative-graphik/experiments/experiment_accuracy/results/"
13 |
-------------------------------------------------------------------------------- /experiments/experiment_accuracy/experiment_accuracy.py: --------------------------------------------------------------------------------
1 | import importlib.util
2 | import json
3 | import os
4 | import sys
5 | import argparse
6 | from graphik.graphs import ProblemGraphRevolute
7 | from graphik.graphs.graph_revolute import list_to_variable_dict
8 | from graphik.robots import RobotRevolute
9 |
10 | from torch_geometric.data import DataLoader
11 | from generative_graphik.utils.dataset_generation import generate_data_point
12 |
13 | from generative_graphik.utils.torch_utils import batchFKmultiDOF, batchIKmultiDOF, node_attributes
14 |
15 | os.environ["PYOPENGL_PLATFORM"] = "egl"
16 | import random
17 | from argparse import Namespace
18 | import copy
19 | import pandas as pd
20 | import time
21 |
22 | import graphik
23 | import matplotlib.pyplot as plt
24 | import numpy as np
25 | import torch
26 | from generative_graphik.args.parser import parse_analysis_args
27 | from graphik.utils.roboturdf import (
28 |     RobotURDF,
29 |     load_ur10,
30 |     load_kuka,
31 |     load_schunk_lwa4d,
32 |     load_schunk_lwa4p,
33 |     load_panda,
34 | )
35 | # import pyrender
36 |
37 | from graphik.utils.dgp import graph_from_pos
38 | from liegroups.numpy import SE3, SO3
39 |
40 | def model_arg_loader(path):
41 |     """Load hyperparameters from trained model."""
42 |     if os.path.isdir(path):
43 |         with open(os.path.join(path, "hyperparameters.txt"), "r") as fp:
44 |             return Namespace(**json.load(fp))
45 |
46 | # NOTE generates all the initializations and stores them to a pickle file
47 | def main(args):
48 |     device = args.device
49 |     num_evals = args.n_evals  # number of evaluations
50 |     robot_types = args.robots
51 |
52 |     evals_per_robot = num_evals // len(robot_types)  # number of evaluations per robot
53 |     for model_path in args.model_path:
54 |         spec = importlib.util.spec_from_file_location("model", model_path + "model.py")
55 |         model = importlib.util.module_from_spec(spec)
56 |         spec.loader.exec_module(model)
57 |
58 |         # load models
59 |         model_args = model_arg_loader(model_path)
60 |         model = model.Model(model_args).to(device)
61 |         name = model_args.id.replace("model", "results")
62 |         c = np.pi / 180
63 |
64 |         if model_path is not None:
65 |             try:
66 |                 state_dict = torch.load(model_path + f"checkpoints/checkpoint.pth", map_location=device)
67 |                 model.load_state_dict(state_dict["net"])
68 |                 model.eval()
69 |             except Exception as e:
70 |                 print(e)
71 |
72 |         all_sol_data = []
73 |         fig_handle, ax_handle = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))
74 |         for robot_type in robot_types:
75 |             if robot_type == "ur10":
76 |                 # robot, graph = load_ur10(limits=None)
77 |                 # fname = graphik.__path__[0] + "/robots/urdfs/ur10_mod.urdf"
78 |                 # urdf_robot = RobotURDF(fname)
79 |
80 |                 # UR10 coordinates for testing
81 |                 modified_dh = False
82 |                 a = [0, -0.612, 0.5723, 0, 0, 0]
83 |                 d = [0.1273, 0, 0, 0.1639, 0.1157, 0.0922]
84 |                 al = [np.pi / 2, 0, 0, np.pi / 2, -np.pi / 2, 0]
85 |                 # th = [0, np.pi, 0, 0, 0, 0]
86 |                 th = [0, 0, 0, 0, 0, 0]
87 |
88 |                 params = {
89 |                     "a": a,
90 |                     "alpha": al,
91 |                     "d": d,
92 |                     "theta": th,
93 |                     "modified_dh": modified_dh,
94 |                     "num_joints": 6,
95 |                 }
96 |                 robot = RobotRevolute(params)
97 |                 graph = ProblemGraphRevolute(robot)
98 |             elif robot_type == "kuka":
99 |                 # limits_l = -np.array([170, 120, 170, 120, 170, 120, 170]) * c
100 |                 # limits_u = np.array([170, 120, 170, 120, 170, 120, 170]) * c
101 |                 # limits = [limits_l, limits_u]
102 |                 # robot, graph = load_kuka(limits=None)
103 |                 # fname = graphik.__path__[0] + "/robots/urdfs/kuka_iiwr.urdf"
104 |                 # urdf_robot = RobotURDF(fname)
105 |
106 |                 # KUKA iiwa coordinates for testing
107 |                 modified_dh = False
108 |                 a = [0, 0, 0, 0, 0, 0, 0]
109 |                 d = [0.34, 0, 0.40, 0, 0.40, 0, 0.126]
110 |                 al = [-np.pi / 2, np.pi / 2, np.pi / 2, -np.pi / 2, -np.pi / 2, np.pi / 2, 0]
111 |                 th = [0, 0, 0, 0, 0, 0, 0]
112 |
113 |                 params = {
114 |                     "a": a,
115 |                     "alpha": al,
116 |                     "d": d,
117 |                     "theta": th,
118 |                     "modified_dh": modified_dh,
119 |                     "num_joints": 7,
120 |                 }
121 |                 robot = RobotRevolute(params)
122 |                 graph = ProblemGraphRevolute(robot)
123 |             elif robot_type == "lwa4d":
124 |                 # limits_l = -np.array([180, 123, 180, 125, 180, 170, 170]) * c
125 |                 # limits_u = np.array([180, 123, 180, 125, 180, 170, 170]) * c
126 |                 # limits = [limits_l, limits_u]
127 |                 # robot, graph = load_schunk_lwa4d(limits=None)
128 |                 # fname = graphik.__path__[0] + "/robots/urdfs/lwa4d.urdf"
129 |                 # urdf_robot = RobotURDF(fname)
130 |
131 |                 modified_dh = False
132 |                 a = [0, 0, 0, 0, 0, 0, 0]
133 |                 d = [0.3, 0, 0.328, 0, 0.323, 0, 0.0824]
134 |                 al = [-np.pi / 2, np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, np.pi / 2, 0]
135 |                 th = [0, 0, 0, 0, 0, 0, 0]
136 |
137 |                 params = {
138 |                     "a": a,
139 |                     "alpha": al,
140 |                     "d": d,
141 |                     "theta": th,
142 |                     "modified_dh": modified_dh,
143 |                     "num_joints": 7,
144 |                 }
145 |                 robot = RobotRevolute(params)
146 |                 graph = ProblemGraphRevolute(robot)
147 |             elif robot_type == "panda":
148 |                 limits_l = -np.array(
149 |                     [2.8973, 1.7628, 2.8973, 0.0698, 2.8973, 3.7525, 2.8973]
150 |                 )
151 |                 limits_u = np.array(
152 |                     [2.8973, 1.7628, 2.8973, 3.0718, 2.8973, 3.7525, 2.8973]
153 |                 )
154 |                 limits = [limits_l, limits_u]
155 |                 robot, graph = load_panda(limits=limits)
156 |                 fname = graphik.__path__[0] + "/robots/urdfs/panda_arm.urdf"
157 |                 urdf_robot = RobotURDF(fname)
158 |
159 |                 # modified_dh = False
160 |                 # a = [0, 0, 0, 0.0825, -0.0825, 0, 0.088]
161 |                 # d = [0.333, 0, 0.316, 0, 0.384, 0, 0]
162 |                 # al = [0, -np.pi/2, np.pi / 2, np.pi / 2, -np.pi / 2, np.pi / 2, np.pi / 2]
163 |                 # th = [0, 0, 0, 0, 0, 0, 0]
164 |
165 |                 # params = {
166 |                 #     "a": a,
167 |                 #     "alpha": al,
168 |                 #     "d": d,
169 |                 #     "theta": th,
170 |                 #     "modified_dh": modified_dh,
171 |                 #     "num_joints": 7,
172 |                 # }
173 |                 # robot = RobotRevolute(params)
174 |                 # graph = ProblemGraphRevolute(robot)
175 |             elif robot_type == "lwa4p":
176 |                 # limits_l = -np.array([170, 170, 155.3, 170, 170, 170]) * c
177 |                 # limits_u = np.array([170, 170, 155.3, 170, 170, 170]) * c
178 |                 # limits = [limits_l, limits_u]
179 |                 # robot, graph = load_schunk_lwa4p(limits=None)
180 |                 # fname = graphik.__path__[0] + "/robots/urdfs/lwa4p.urdf"
181 |                 # urdf_robot = RobotURDF(fname)
182 |
183 |                 modified_dh = False
184 |                 a = [0, 0.350, 0, 0, 0, 0]
185 |                 d = [0.205, 0, 0, 0.305, 0, 0.075]
186 |                 al = [-np.pi / 2, np.pi, -np.pi / 2, np.pi / 2, -np.pi / 2, 0]
187 |                 th = [0, 0, 0, 0, 0, 0]
188 |
189 |                 params = {
190 |                     "a": a,
191 |                     "alpha": al,
192 |                     "d": d,
193 |                     "theta": th,
194 |                     "modified_dh": modified_dh,
195 |                     "num_joints": 6,
196 |                 }
197 |                 robot = RobotRevolute(params)
198 |                 graph = ProblemGraphRevolute(robot)
199 |             else:
200 |                 raise NotImplementedError
201 |
202 |             for kdx in range(evals_per_robot):
203 |                 sol_data = []
204 |
205 |                 # Generate random problem
206 |                 prob_data = generate_data_point(graph).to(device)
207 |                 prob_data.num_graphs = 1
208 |                 # T_goal = prob_data.T_ee.cpu().numpy()
209 |                 data = model.preprocess(prob_data)
210 |                 P_goal = data.pos.cpu().numpy()
211 |                 T_goal = SE3.exp(data.T_ee.cpu().numpy())
212 |                 # T_goal = SE3.exp(data.T_ee[0].cpu().numpy())
213 |                 # q_goal = graph.joint_variables(graph_from_pos(P_goal, graph.node_ids))
214 |                 # q_goal_np = np.fromiter(
215 |                 #     (q_goal[f"p{jj}"] for jj in range(1, graph.robot.n + 1)), dtype=float
216 |                 # )
217 |
218 |                 # Compute solutions
219 |                 t0 = time.time()
220 |                 P_all = model.forward_eval(data, num_samples=args.num_samples).cpu().detach().numpy()
221 |                 # P_all = (
222 |                 #     model.forward_eval(
223 |                 #         x=data.pos,
224 |                 #         h=torch.cat((data.type, data.goal_data_repeated_per_node), dim=-1),
225 |                 #         edge_attr=data.edge_attr,
226 |                 #         edge_attr_partial=data.edge_attr_partial,
227 |                 #         edge_index=data.edge_index_full,
228 |                 #         partial_goal_mask=data.partial_goal_mask,
229 |                 #         nodes_per_single_graph= int(data.num_nodes / 1),
230 |                 #         batch_size=1,
231 |                 #         num_samples=args.num_samples
232 |                 #     )
233 |                 # )
234 |                 # torch.cuda.synchronize()
235 |                 t_sol = time.time() - t0
236 |
237 |                 # Analyze solutions
238 |                 e_pose = np.empty([P_all.shape[0]])
239 |                 e_pos = np.empty([P_all.shape[0]])
240 |                 e_rot = np.empty([P_all.shape[0]])
241 |                 q_sols_np = np.empty([P_all.shape[0], robot.n])
242 |                 q_sols = []
243 |                 for idx in range(P_all.shape[0]):
244 |                     P = P_all[idx, :]
245 |                     # q_sol = batchIKmultiDOF(P, prob_data.T0, prob_data.num_joints,
246 |                     #     T_final = torch.tensor(T_goal.as_matrix(), dtype=P.dtype).unsqueeze(0).to(device))
247 |
248 |                     q_sol = graph.joint_variables(
249 |                         graph_from_pos(P, graph.node_ids), {robot.end_effectors[0]: T_goal}
250 |                     )  # get joint angles
251 |
252 |                     q_sols_np[idx] = np.fromiter(
253 |                         (q_sol[f"p{jj}"] for jj in range(1, graph.robot.n + 1)), dtype=float
254 |                     )
255 |
256 |                     T_ee = graph.robot.pose(q_sol, robot.end_effectors[-1])
257 |                     e_pose[idx] = np.linalg.norm(T_ee.inv().dot(T_goal).log())
258 |                     e_pos[idx] = np.linalg.norm(T_ee.trans - T_goal.trans)
259 |                     e_rot[idx] = np.linalg.norm(T_ee.rot.inv().dot(T_goal.rot).log())
260 |
261 |                     entry = {
262 |                         "Id": kdx,
263 |                         "Robot": robot_type,
264 |                         "Goal Pose": T_goal.as_matrix(),
265 |                         "Sol. Config": q_sols_np[idx],
266 |                         # "Sol. Points": P_all[idx,:],
267 |                         "Err. Pose": e_pose[idx],
268 |                         "Err. Position": e_pos[idx],
269 |                         "Err. Rotation": e_rot[idx],
270 |                         # "Goal Config": q_goal_np,
271 |                         # "Goal Points": P_goal,
Time": t_sol, 273 | } 274 | sol_data.append(entry) 275 | all_sol_data.append(pd.DataFrame(sol_data)) 276 | 277 | pd_data = pd.concat(all_sol_data) 278 | 279 | exp_dir = f"{sys.path[0]}/results/"+ f"{args.id}/" 280 | os.makedirs(exp_dir, exist_ok=True) 281 | pd_data.to_pickle(os.path.join(exp_dir, "results.pkl")) 282 | 283 | 284 | if __name__ == "__main__": 285 | random.seed(17) 286 | parser = argparse.ArgumentParser() 287 | 288 | # General settings 289 | parser.add_argument("--id", type=str, default="test_experiment", help="Name of the folder with experiment data") 290 | parser.add_argument("--model_path", nargs="*", type=str, required=True, help="Path to folder with model data") 291 | parser.add_argument('--device', type=str, default='cuda:1', help='Device to use for PyTorch') 292 | parser.add_argument("--robots", nargs="*", type=str, default=["planar_chain"], help="Type of robot used") 293 | parser.add_argument("--n_evals", type=int, default=100, help="Number of evaluations") 294 | parser.add_argument("--num_samples", type=int, default=100, help="Total number of samples per problem") 295 | 296 | args = parser.parse_args() 297 | main(args) -------------------------------------------------------------------------------- /experiments/experiment_accuracy/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NAME=$1 4 | 5 | SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" 6 | SRC_PATH="${SCRIPT_DIR}/../.." 7 | MODEL_PATH="${SRC_PATH}/saved_models/${NAME}_model" 8 | 9 | # copy the model and training code if new 10 | if [ -d "${MODEL_PATH}" ] 11 | then 12 | echo "Directory already exists, using existing model." 13 | else 14 | echo "Model not found!" 15 | fi 16 | 17 | python experiment_accuracy.py \ 18 | --id "${NAME}_experiment" \ 19 | --robots panda \ 20 | --n_evals 500 \ 21 | --model_path "${MODEL_PATH}/" \ 22 | --device cuda:1 \ 23 | --num_samples 32 -------------------------------------------------------------------------------- /experiments/experiment_accuracy/table.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import importlib.util 3 | import json 4 | # import tikzplotlib 5 | import os 6 | import sys 7 | 8 | from seaborn.utils import despine 9 | 10 | os.environ["PYOPENGL_PLATFORM"] = "egl" 11 | import random 12 | import copy 13 | import pandas as pd 14 | import time 15 | 16 | # import graphik 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | from generative_graphik.args.utils import str2bool 20 | import argparse 21 | import seaborn as sns 22 | sns.set_theme(style="darkgrid") 23 | 24 | def main(args): 25 | data = pd.read_pickle(f"{sys.path[0]}/results/{args.id}/results.pkl") 26 | stats = data.reset_index() 27 | stats["Err. Position"] = stats["Err. Position"]*1000 28 | stats["Err. Rotation"] = stats["Err. Rotation"]*(180/np.pi) 29 | q_pos = stats["Err. Position"].quantile(0.99) 30 | q_rot = stats["Err. Rotation"].quantile(0.99) 31 | stats = stats.drop(stats[stats["Err. Position"] > q_pos].index) 32 | stats = stats.drop(stats[stats["Err. Rotation"] > q_rot].index) 33 | 34 | stats = stats.groupby(["Robot", "Id"])[["Err. Position", "Err. Rotation"]].describe().groupby("Robot").mean() 35 | 36 | stats = stats.drop(["count", "std", "50%"], axis=1, level=1) 37 | 38 | perc_data = data.set_index(["Robot", "Id"]) 39 | # perc_data["Success"] = ( 40 | # (perc_data["Err. Position"] < 0.01) & (perc_data["Err. 
Rotation"] < (180/np.pi)) 41 | # ) 42 | # suc_pos_perc = ( 43 | # perc_data["Success"] 44 | # .eq(True) 45 | # .groupby(level=[0, 1]) 46 | # .value_counts(True) 47 | # .unstack(fill_value=0) 48 | # ) 49 | # stats["Success [\%]"] = suc_pos_perc.groupby(level=0).apply(lambda c: (c>0).sum()/len(c))[True]*100 50 | 51 | stats.rename(columns = {'75%': 'Q$_{3}$', '25%': 'Q$_{1}$','Err. Position':'Err. Pos. [mm]', 'Err. Rotation':'Err. Rot. [deg]'}, inplace = True) 52 | 53 | # Swap to follow paper order 54 | cols = stats.columns.tolist() 55 | ins = cols.pop(4) 56 | cols.insert(2, ins) 57 | ins = cols.pop(9) 58 | cols.insert(7, ins) 59 | stats = stats[cols] 60 | 61 | if args.save_latex: 62 | s = stats.style 63 | s.format(precision=1) 64 | s.format_index(axis=1,level=[0,1]) 65 | latex = s.to_latex(hrules=True, multicol_align="c") 66 | print(latex) 67 | 68 | # open text file 69 | # text_file = open(args.latex_path + "tables/experiment_2.tex", "w") 70 | # write string to file 71 | # text_file.write(latex) 72 | 73 | if __name__ == "__main__": 74 | random.seed(17) 75 | parser = argparse.ArgumentParser() 76 | 77 | # General settings 78 | parser.add_argument("--id", type=str, default="test_experiment", help="Name of the folder with experiment data") 79 | parser.add_argument("--save_latex", type=str2bool, default=True, help="Save latex table.") 80 | parser.add_argument("--latex_path", type=str, default="/home/filipmrc/Documents/Latex/2022-limoyo-maric-generative-corl/tables/experiment_2.tex", help="Base path for folder with experiment data") 81 | 82 | args = parser.parse_args() 83 | main(args) -------------------------------------------------------------------------------- /experiments/experiment_infeasible_poses/analyze.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python table.py \ 3 | --save_latex True \ 4 | --latex_path "/home/olimoyo/generative-graphik/experiments/experiment_infeasible_poses/results/" 5 | -------------------------------------------------------------------------------- /experiments/experiment_infeasible_poses/convert_pose_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import pickle 6 | 7 | import torch 8 | from torch_geometric.data import InMemoryDataset 9 | from torch_geometric.loader import DataLoader 10 | 11 | def load_datasets(path: str, device, val_pcnt=0): 12 | with open(path, "rb") as f: 13 | try: 14 | # data = pickle.load(f) 15 | data = torch.load(f) 16 | data._data = data._data.to(device) 17 | val_size = int((val_pcnt/100)*len(data)) 18 | train_size = len(data) - val_size 19 | val_dataset, train_dataset = torch.utils.data.random_split(data, [val_size, train_size]) 20 | except (OSError, IOError) as e: 21 | val_dataset = None 22 | train_dataset = None 23 | return train_dataset, val_dataset 24 | 25 | class CachedDataset(InMemoryDataset): 26 | def __init__(self, data, slices): 27 | super(CachedDataset, self).__init__(None) 28 | self.data, self.slices = data, slices 29 | 30 | def convert_poses(args): 31 | dataset_path = os.path.join(args.dataset_path, "data_0.p") 32 | # with open(os.path.join(args.dataset_path, "data_0.p"), 'rb') as f: 33 | # reference_dataset = torch.load(f) 34 | 35 | # Load training dataset from training path 36 | all_data, _ = load_datasets( 37 | dataset_path, 38 | "cpu", 39 | val_pcnt=0 40 | ) 41 | 42 | loader = DataLoader( 43 | all_data, 44 | batch_size=256, 45 | num_workers=16, 46 | 
47 |     )
48 |     pose_list = []
49 |     for idx, data in enumerate(loader):
50 |         pose_list.append(data.T_ee.cpu().detach().numpy())
51 |     pose_list = np.concatenate(pose_list)
52 |
53 |     with open(os.path.join(args.dataset_path, f"np_poses.pkl"), 'wb') as f:
54 |         # Dump the list of NumPy arrays into the file
55 |         pickle.dump(pose_list, f)
56 |
57 | def parse_convert_poses_args():
58 |     parser = argparse.ArgumentParser()
59 |
60 |     # General settings
61 |     parser.add_argument("--dataset_path", type=str, default="/media/stonehenge/users/oliver-limoyo/2.56m-lwa4p", help="Path to folder with the dataset to convert.")
62 |
63 |     args = parser.parse_args()
64 |     return args
65 |
66 | if __name__ == "__main__":
67 |     args = parse_convert_poses_args()
68 |     convert_poses(args)
-------------------------------------------------------------------------------- /experiments/experiment_infeasible_poses/experiment_infeasible_poses.py: --------------------------------------------------------------------------------
1 | import importlib.util
2 | import argparse
3 | import random
4 | import os
5 | from argparse import Namespace
6 | import json
7 | import pickle
8 | import numpy as np
9 | import sys
10 |
11 | import torch
12 | from torch_geometric.data import InMemoryDataset
13 | from liegroups.numpy import SE3
14 | import pandas as pd
15 |
16 | import graphik
17 | from graphik.utils.dgp import graph_from_pos
18 | from graphik.utils.roboturdf import (
19 |     RobotURDF,
20 |     load_ur10,
21 |     load_kuka,
22 |     load_schunk_lwa4d,
23 |     load_schunk_lwa4p,
24 |     load_panda,
25 | )
26 | from generative_graphik.utils.dataset_generation import generate_data_point_from_pose
27 |
28 | def model_arg_loader(path):
29 |     """Load hyperparameters from trained model."""
30 |     if os.path.isdir(path):
31 |         with open(os.path.join(path, "hyperparameters.txt"), "r") as fp:
32 |             return Namespace(**json.load(fp))
33 |
34 | class CachedDataset(InMemoryDataset):
35 |     def __init__(self, data, slices):
36 |         super(CachedDataset, self).__init__(None)
37 |         self.data, self.slices = data, slices
38 |
39 | def run_experiment_infeasible_poses(args):
40 |     robot_types = args.robots
41 |     model_paths = args.model_paths
42 |     infeasible_pose_paths = args.infeasible_pose_paths
43 |     dataset_paths = args.dataset_paths
44 |
45 |     all_sol_data = []
46 |     # Evaluate each robot with its corresponding model and infeasible pose set
47 |     for robot_type, model_path, infeasible_pose_path, dataset_path in zip(robot_types, model_paths, infeasible_pose_paths, dataset_paths):
48 |
49 |         # Load model
50 |         spec = importlib.util.spec_from_file_location("model", os.path.join(model_path, "model.py"))
51 |         model = importlib.util.module_from_spec(spec)
52 |         spec.loader.exec_module(model)
53 |         model_args = model_arg_loader(model_path)
54 |         model = model.Model(model_args).to(args.device)
55 |
56 |         # Load infeasible poses
57 |         with open(infeasible_pose_path, 'rb') as f:
58 |             infeasible_poses_list = pickle.load(f)
59 |
60 |         infeasible_poses = []
61 |         for infeasible_pose in infeasible_poses_list:
62 |             infeasible_pose = infeasible_pose[None, ...]
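            # [None, ...] prepends a batch axis so each (4, 4) pose stacks into one (N, 4, 4) array below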
63 |             infeasible_poses.append(infeasible_pose)
64 |         infeasible_poses = np.concatenate(infeasible_poses)
65 |
66 |         # Load dataset poses to compare to
67 |         with open(os.path.join(dataset_path, "np_poses.pkl"), 'rb') as f:
68 |             dataset_poses = pickle.load(f)
69 |
70 |         # Load problem
71 |         if robot_type == "ur10":
72 |             robot, graph = load_ur10(limits=None)
73 |             fname = graphik.__path__[0] + "/robots/urdfs/ur10_mod.urdf"
74 |             link_base, link_ee = 'base_link', 'ee_link'
75 |         elif robot_type == "kuka":
76 |             robot, graph = load_kuka(limits=None)
77 |             fname = graphik.__path__[0] + "/robots/urdfs/kuka_iiwr.urdf"
78 |             link_base, link_ee = 'lbr_iiwa_link_0', 'ee_link'
79 |         elif robot_type == "lwa4d":
80 |             robot, graph = load_schunk_lwa4d(limits=None)
81 |             fname = graphik.__path__[0] + "/robots/urdfs/lwa4d.urdf"
82 |             link_base, link_ee = 'lwa4d_base_link', 'lwa4d_ee_link'
83 |         elif robot_type == "panda":
84 |             robot, graph = load_panda(limits=None)
85 |             fname = graphik.__path__[0] + "/robots/urdfs/panda_arm.urdf"
86 |             link_base, link_ee = 'panda_link0', 'panda_link7'
87 |         elif robot_type == "lwa4p":
88 |             robot, graph = load_schunk_lwa4p(limits=None)
89 |             fname = graphik.__path__[0] + "/robots/urdfs/lwa4p.urdf"
90 |             link_base, link_ee = 'lwa4p_base_link', 'lwa4p_ee_link'
91 |         else:
92 |             raise NotImplementedError
93 |
94 |         sol_data = []
95 |         for kdx, infeasible_pose in enumerate(infeasible_poses):
96 |             print(robot_type, f"{kdx + 1} / {len(infeasible_poses)}")
97 |             prob_data = generate_data_point_from_pose(graph, infeasible_pose).to(args.device)
98 |             prob_data.num_graphs = 1
99 |             data = model.preprocess(prob_data)
100 |             num_samples_pre = args.num_samples * 4
101 |             T_goal = SE3.from_matrix(infeasible_pose)
102 |
103 |             # Compute solutions
104 |             P_all = (
105 |                 model.forward_eval(
106 |                     x_partial=data.pos_partial,
107 |                     h=torch.cat((data.type, data.goal_data_repeated_per_node), dim=-1),
108 |                     edge_attr_partial=data.edge_attr_partial,
109 |                     edge_index=data.edge_index_full,
110 |                     nodes_per_single_graph=int(data.num_nodes / 1),
111 |                     batch_size=1,
112 |                     num_samples=num_samples_pre
113 |                 )
114 |             ).cpu().detach().numpy()
115 |
116 |             # Analyze solutions
117 |             e_pose = np.empty([P_all.shape[0]])
118 |             e_pos = np.empty([P_all.shape[0]])
119 |             e_rot = np.empty([P_all.shape[0]])
120 |             q_sols_np = np.empty([P_all.shape[0], robot.n])
121 |             for idx in range(P_all.shape[0]):
122 |                 P = P_all[idx, :]
123 |
124 |                 q_sol = graph.joint_variables(
125 |                     graph_from_pos(P, graph.node_ids), {robot.end_effectors[0]: T_goal}
126 |                 )  # get joint angles
127 |
128 |                 q_sols_np[idx] = np.fromiter(
129 |                     (q_sol[f"p{jj}"] for jj in range(1, graph.robot.n + 1)), dtype=float
130 |                 )
131 |
132 |                 T_ee = graph.robot.pose(q_sol, robot.end_effectors[-1])
133 |                 e_pose[idx] = np.linalg.norm(T_ee.inv().dot(T_goal).log())
134 |                 e_pos[idx] = np.linalg.norm(T_ee.trans - T_goal.trans)
135 |                 e_rot[idx] = np.linalg.norm(T_ee.rot.inv().dot(T_goal.rot).log())
136 |             idx_sorted = np.argsort(e_pose)
137 |
138 |             for ii in idx_sorted[:args.num_samples]:
139 |                 entry = {
140 |                     "Id": kdx,
141 |                     "Robot": robot_type,
142 |                     "Goal Pose": T_goal.as_matrix(),
143 |                     "Sol. Config": q_sols_np[ii],
144 |                     "Err. Pose": e_pose[ii],
145 |                     "Err. Position": e_pos[ii],
Rotation": e_rot[ii], 147 | } 148 | sol_data.append(entry) 149 | all_sol_data.append(pd.DataFrame(sol_data)) 150 | 151 | pd_data = pd.concat(all_sol_data) 152 | exp_dir = f"{sys.path[0]}/results/" 153 | os.makedirs(exp_dir, exist_ok=True) 154 | pd_data.to_pickle(os.path.join(exp_dir, "results.pkl")) 155 | 156 | def parse_experiment_infeasible_poses_args(): 157 | parser = argparse.ArgumentParser() 158 | 159 | # General settings 160 | parser.add_argument("--model_paths", nargs="*", type=str, default=["/home/olimoyo/generative-graphik/saved_models/paper_models/kuka_512k_model"], help="Path to folder with model") 161 | parser.add_argument("--infeasible_pose_paths", nargs="*", type=str, default=["/home/olimoyo/generative-graphik/datasets/infeasible_poses/infeasible_poses_kuka.pkl"], help="Path to folder with infeasible poses to test with.") 162 | parser.add_argument("--robots", nargs="*", type=str, default=["kuka"], help="Robots to test on") 163 | parser.add_argument("--dataset_paths", nargs="*", type=str, default=["/media/stonehenge/users/oliver-limoyo/2.56m-kuka"], help="Path to folder with infeasible poses to test with.") 164 | parser.add_argument('--device', type=str, default='cpu', help='Device to use for PyTorch') 165 | parser.add_argument("--num_samples", type=int, default=32, help="Total number of samples per problem") 166 | args = parser.parse_args() 167 | return args 168 | 169 | if __name__ == "__main__": 170 | random.seed(3) 171 | args = parse_experiment_infeasible_poses_args() 172 | infeasible_poses = run_experiment_infeasible_poses(args) 173 | 174 | -------------------------------------------------------------------------------- /experiments/experiment_infeasible_poses/generate_infeasible_poses.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import os 4 | 5 | import torch 6 | import numpy as np 7 | import pickle 8 | 9 | from generative_graphik.utils.dataset_generation import CachedDataset 10 | from generative_graphik.utils.dataset_generation import generate_data_point 11 | from tracikpy import TracIKSolver 12 | import graphik 13 | from graphik.utils.roboturdf import ( 14 | RobotURDF, 15 | load_ur10, 16 | load_kuka, 17 | load_schunk_lwa4d, 18 | load_schunk_lwa4p, 19 | load_panda, 20 | ) 21 | 22 | def ee_error(ee1, ee2): 23 | ee_diff = np.linalg.inv(ee1) @ ee2 24 | trans_err = np.linalg.norm(ee_diff[:3, 3], ord=1) 25 | angle_err = np.arccos(np.trace(ee_diff[:3, :3] - 1) / 2) 26 | return trans_err, angle_err 27 | 28 | def find_infeasible_poses(args): 29 | robot_types = args.robots 30 | 31 | # Initialize tracIK 32 | for robot_type in robot_types: 33 | if robot_type == "ur10": 34 | robot, graph = load_ur10(limits=None) 35 | fname = graphik.__path__[0] + "/robots/urdfs/ur10_mod.urdf" 36 | link_base, link_ee = 'base_link', 'ee_link' 37 | elif robot_type == "kuka": 38 | robot, graph = load_kuka(limits=None) 39 | fname = graphik.__path__[0] + "/robots/urdfs/kuka_iiwr.urdf" 40 | link_base, link_ee = 'lbr_iiwa_link_0', 'ee_link' 41 | elif robot_type == "lwa4d": 42 | robot, graph = load_schunk_lwa4d(limits=None) 43 | fname = graphik.__path__[0] + "/robots/urdfs/lwa4d.urdf" 44 | link_base, link_ee = 'lwa4d_base_link', 'lwa4d_ee_link' 45 | elif robot_type == "panda": 46 | robot, graph = load_panda(limits=None) 47 | fname = graphik.__path__[0] + "/robots/urdfs/panda_arm.urdf" 48 | link_base, link_ee = 'panda_link0', 'panda_link7' 49 | elif robot_type == "lwa4p": 50 | robot, graph = load_schunk_lwa4p(limits=None) 51 | 
52 |             fname = graphik.__path__[0] + "/robots/urdfs/lwa4p.urdf"
53 |             link_base, link_ee = 'lwa4p_base_link', 'lwa4p_ee_link'
54 |         else:
55 |             raise NotImplementedError
56 |
57 |         ik_solver = TracIKSolver(fname, link_base, link_ee, timeout=0.5)
58 |         # Set joint limits in [-pi, pi]
59 |         ub = np.ones(ik_solver.number_of_joints) * np.pi
60 |         lb = -ub
61 |         ik_solver.joint_limits = lb, ub
62 |
63 |         # Generate reachable seed poses, perturb them with random noise, and keep those tracIK cannot solve
64 |         infeasible_poses = []
65 |         for _ in range(args.n_poses):
66 |
67 |             # Generate random problem
68 |             prob_data = generate_data_point(graph)
69 |             q_goal = prob_data.q_goal.numpy()
70 |             q_goal[-1] += 0.0001
71 |             T_goal = ik_solver.fk(q_goal)
72 |
73 |             # Add some noise
74 |             translation_noise = np.random.uniform(
75 |                 -args.translational_noise,
76 |                 args.translational_noise,
77 |                 size=(3,)
78 |             )
79 |             T_goal[:3, 3] += translation_noise
80 |
81 |             # Test if pose is feasible
82 |             for _ in range(32):
83 |                 # Initialize randomly
84 |                 q_init = np.array(list(robot.random_configuration().values()))
85 |
86 |                 # Solve using tracIK and find poses that fail
87 |                 qout = ik_solver.ik(
88 |                     T_goal,
89 |                     qinit=q_init,
90 |                     bx = 1e-2,
91 |                     by = 1e-2,
92 |                     bz = 1e-2
93 |                 )
94 |                 if qout is not None:
95 |                     break
96 |
97 |             if qout is None:
98 |                 print("Failure detected")
99 |                 infeasible_poses.append(T_goal)
100 |
101 |         print(f"Total failures for {robot_type}: {len(infeasible_poses)} / {args.n_poses}")
102 |         os.makedirs(args.save_path, exist_ok=True)
103 |         with open(os.path.join(args.save_path, f"infeasible_poses_{robot_type}.pkl"), 'wb') as f:
104 |             # Dump the list of NumPy arrays into the file
105 |             pickle.dump(infeasible_poses, f)
106 |
107 | def parse_generate_infeasible_poses_args():
108 |     parser = argparse.ArgumentParser()
109 |
110 |     # General settings
111 |     parser.add_argument("--save_path", type=str, default="/home/olimoyo/generative-graphik/datasets/infeasible_poses", help="Path to folder to save poses")
112 |     parser.add_argument("--robots", nargs="*", type=str, default=["kuka", "panda", "ur10", "lwa4p", "lwa4d"], help="Robots to test on")
113 |     parser.add_argument("--n_poses", type=int, default=2400, help="Number of poses to search per robot")
114 |     parser.add_argument("--translational_noise", type=float, default=0.05, help="Noise in metres to add.")
115 |     args = parser.parse_args()
116 |     return args
117 |
118 | if __name__ == "__main__":
119 |     random.seed(3)
120 |     args = parse_generate_infeasible_poses_args()
121 |     find_infeasible_poses(args)
-------------------------------------------------------------------------------- /experiments/experiment_infeasible_poses/run.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python experiment_infeasible_poses.py \
3 |     --robots kuka panda ur10 lwa4d lwa4p \
4 |     --model_paths /home/olimoyo/generative-graphik/saved_models/paper_models/kuka_512k_model /home/olimoyo/generative-graphik/saved_models/paper_models/panda_512k_model /home/olimoyo/generative-graphik/saved_models/paper_models/ur10_512k_model /home/olimoyo/generative-graphik/saved_models/paper_models/lwa4d_512k_model /home/olimoyo/generative-graphik/saved_models/paper_models/lwa4p_512k_model \
5 |     --infeasible_pose_paths /home/olimoyo/generative-graphik/datasets/infeasible_poses/infeasible_poses_kuka.pkl /home/olimoyo/generative-graphik/datasets/infeasible_poses/infeasible_poses_panda.pkl /home/olimoyo/generative-graphik/datasets/infeasible_poses/infeasible_poses_ur10.pkl /home/olimoyo/generative-graphik/datasets/infeasible_poses/infeasible_poses_lwa4d.pkl /home/olimoyo/generative-graphik/datasets/infeasible_poses/infeasible_poses_lwa4p.pkl \
6 |     --dataset_paths /media/stonehenge/users/oliver-limoyo/2.56m-kuka /media/stonehenge/users/oliver-limoyo/2.56m-panda /media/stonehenge/users/oliver-limoyo/2.56m-ur10 /media/stonehenge/users/oliver-limoyo/2.56m-lwa4d /media/stonehenge/users/oliver-limoyo/2.56m-lwa4p \
7 |     --device cuda:1
8 |
-------------------------------------------------------------------------------- /experiments/experiment_infeasible_poses/table.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import importlib.util
3 | import json
4 | # import tikzplotlib
5 | import os
6 | import sys
7 |
8 | from seaborn.utils import despine
9 |
10 | os.environ["PYOPENGL_PLATFORM"] = "egl"
11 | import random
12 | import copy
13 | import pandas as pd
14 | import time
15 |
16 | # import graphik
17 | import matplotlib.pyplot as plt
18 | import numpy as np
19 | from generative_graphik.args.utils import str2bool
20 | import argparse
21 | import seaborn as sns
22 | sns.set_theme(style="darkgrid")
23 |
24 | def main(args):
25 |     data = pd.read_pickle(f"{sys.path[0]}/results/results.pkl")
26 |     stats = data.reset_index()
27 |     stats["Err. Position"] = stats["Err. Position"]*1000
28 |     stats["Err. Rotation"] = stats["Err. Rotation"]*(180/np.pi)
29 |     q_pos = stats["Err. Position"].quantile(0.99)
30 |     q_rot = stats["Err. Rotation"].quantile(0.99)
31 |     stats = stats.drop(stats[stats["Err. Position"] > q_pos].index)
32 |     stats = stats.drop(stats[stats["Err. Rotation"] > q_rot].index)
33 |
34 |     stats = stats.groupby(["Robot", "Id"])[["Err. Position", "Err. Rotation"]].describe().groupby("Robot").mean()
35 |
36 |     stats = stats.drop(["count", "std", "50%"], axis=1, level=1)
37 |
38 |     perc_data = data.set_index(["Robot", "Id"])
39 |     # perc_data["Success"] = (
40 |     #     (perc_data["Err. Position"] < 0.01) & (perc_data["Err. Rotation"] < (180/np.pi))
41 |     # )
42 |     # suc_pos_perc = (
43 |     #     perc_data["Success"]
44 |     #     .eq(True)
45 |     #     .groupby(level=[0, 1])
46 |     #     .value_counts(True)
47 |     #     .unstack(fill_value=0)
48 |     # )
49 |     # stats["Success [\%]"] = suc_pos_perc.groupby(level=0).apply(lambda c: (c>0).sum()/len(c))[True]*100
50 |
51 |     stats.rename(columns = {'75%': 'Q$_{3}$', '25%': 'Q$_{1}$','Err. Position':'Err. Pos. [mm]', 'Err. Rotation':'Err. Rot. [deg]'}, inplace = True)
52 |
53 |     # Swap to follow paper order
54 |     cols = stats.columns.tolist()
55 |     ins = cols.pop(4)
56 |     cols.insert(2, ins)
57 |     ins = cols.pop(9)
58 |     cols.insert(7, ins)
59 |     stats = stats[cols]
60 |
61 |     if args.save_latex:
62 |         s = stats.style
63 |         s.format(precision=1)
64 |         s.format_index(axis=1,level=[0,1])
65 |         latex = s.to_latex(hrules=True, multicol_align="c")
66 |         print(latex)
67 |
68 |         # open text file
69 |         # text_file = open(args.latex_path + "tables/experiment_2.tex", "w")
70 |         # write string to file
71 |         # text_file.write(latex)
72 |
73 | if __name__ == "__main__":
74 |     random.seed(17)
75 |     parser = argparse.ArgumentParser()
76 |
77 |     # General settings
78 |     parser.add_argument("--save_latex", type=str2bool, default=True, help="Save latex table.")
79 |     parser.add_argument("--latex_path", type=str, default="/home/filipmrc/Documents/Latex/2022-limoyo-maric-generative-corl/tables/experiment_2.tex", help="Base path for folder with experiment data")
80 |
81 |     args = parser.parse_args()
82 |     main(args)
-------------------------------------------------------------------------------- /experiments/experiment_joint_limits/analyze.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NAME=$1
4 |
5 | SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
6 | SRC_PATH="${SCRIPT_DIR}/../.."
7 | MODEL_PATH="${SRC_PATH}/saved_models/${NAME}_model"
8 |
9 | python3 joint_visualize.py \
10 |     --id "${NAME}_experiment" \
11 |     --save_path "/home/olimoyo/generative-graphik/experiments/experiment_joint_limits/results/"
-------------------------------------------------------------------------------- /experiments/experiment_joint_limits/experiment_joint_limits.py: --------------------------------------------------------------------------------
1 | import importlib.util
2 | import json
3 | import os
4 | import sys
5 | import argparse
6 | import math
7 |
8 | from graphik.graphs import ProblemGraphRevolute
9 | from graphik.graphs.graph_revolute import list_to_variable_dict
10 | from graphik.robots import RobotRevolute
11 |
12 | import torch_geometric
13 | from torch_geometric.data import DataLoader
14 | from generative_graphik.utils.dataset_generation import generate_data_point
15 |
16 | from generative_graphik.utils.torch_utils import batchFKmultiDOF, batchIKmultiDOF, node_attributes
17 |
18 | os.environ["PYOPENGL_PLATFORM"] = "egl"
19 | import random
20 | from argparse import Namespace
21 | import copy
22 | import pandas as pd
23 | import time
24 |
25 | import graphik
26 | import matplotlib.pyplot as plt
27 | import numpy as np
28 | import torch
29 | from generative_graphik.args.parser import parse_analysis_args
30 | from graphik.utils.roboturdf import (
31 |     RobotURDF,
32 |     load_ur10,
33 |     load_kuka,
34 |     load_schunk_lwa4d,
35 |     load_schunk_lwa4p,
36 |     load_panda,
37 | )
38 | # import pyrender
39 |
40 | from graphik.utils.dgp import graph_from_pos
41 | from liegroups.numpy import SE3, SO3
42 |
43 | def model_arg_loader(path):
44 |     """Load hyperparameters from trained model."""
45 |     if os.path.isdir(path):
46 |         with open(os.path.join(path, "hyperparameters.txt"), "r") as fp:
47 |             return Namespace(**json.load(fp))
48 |
49 | # NOTE generates all the initializations and stores them to a pickle file
50 | def main(args):
51 |     device = args.device
52 |     num_evals = args.n_evals  # number of evaluations
53 |     robot_types = args.robots
54 |
55 |     evals_per_robot = num_evals // len(robot_types)  # number of evaluations per robot
56 |     for model_path in args.model_path:
57 |         spec = importlib.util.spec_from_file_location("model", model_path + "model.py")
58 |         model = importlib.util.module_from_spec(spec)
59 |         spec.loader.exec_module(model)
60 |
61 |         # load models
62 |         model_args = model_arg_loader(model_path)
63 |         model = model.Model(model_args).to(device)
64 |         # model = torch_geometric.compile(model, fullgraph=True)
65 |         name = model_args.id.replace("model", "results")
66 |         c = np.pi / 180
67 |
68 |         if model_path is not None:
69 |             try:
70 |                 state_dict = torch.load(model_path + f"checkpoints/checkpoint.pth", map_location=device)
71 |                 model.load_state_dict(state_dict["net"])
72 |                 model.eval()
73 |             except Exception as e:
74 |                 print(e)
75 |
76 |         all_sol_data = []
77 |         fig_handle, ax_handle = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))
78 |         for robot_type in robot_types:
79 |             if robot_type == "panda":
80 |                 limits_l = np.array([-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -0.0175, -2.8973])
81 |                 limits_u = np.array([2.8973, 1.7628, 2.8973, -0.0698, 2.8973, 3.7525, 2.8973])
82 |                 limits = [limits_l, limits_u]
83 |                 robot, graph = load_panda(limits=limits)
84 |                 fname = graphik.__path__[0] + "/robots/urdfs/panda_arm.urdf"
85 |                 urdf_robot = RobotURDF(fname)
86 |             else:
87 |                 raise NotImplementedError
88 |
89 |             joint_limits = np.stack([limits_l, limits_u])
90 |             for kdx in range(evals_per_robot):
91 |                 sol_data = []
92 |
93 |                 # Generate random problem
94 |                 prob_data = generate_data_point(graph, joint_limits=joint_limits).to(device)
95 |                 prob_data.num_graphs = 1
96 |                 data = model.preprocess(prob_data)
97 |
98 |                 # T_goal = SE3.exp(data.T_ee.cpu().numpy())
99 |                 T_goal = SE3.exp(data.T_ee[0].cpu().numpy())
100 |
101 |                 # P_all = model.forward_eval(data, num_samples=args.num_samples).cpu().detach().numpy()
102 |                 # Compute solutions
103 |                 P_all = (
104 |                     model.forward_eval(
105 |                         x_partial=data.pos_partial,
106 |                         h=torch.cat((data.type, data.goal_data_repeated_per_node), dim=-1),
107 |                         edge_attr_partial=data.edge_attr_partial,
108 |                         edge_index=data.edge_index_full,
109 |                         nodes_per_single_graph=int(data.num_nodes / 1),
110 |                         batch_size=1,
111 |                         num_samples=args.num_samples
112 |                     )
113 |                 ).cpu().detach().numpy()
114 |
115 |                 # Analyze solutions
116 |                 e_pose = np.empty([P_all.shape[0]])
117 |                 e_pos = np.empty([P_all.shape[0]])
118 |                 e_rot = np.empty([P_all.shape[0]])
119 |                 q_sols_np = np.empty([P_all.shape[0], robot.n])
120 |                 for idx in range(P_all.shape[0]):
121 |                     P = P_all[idx, :]
122 |
123 |                     q_sol = graph.joint_variables(
124 |                         graph_from_pos(P, graph.node_ids), {robot.end_effectors[0]: T_goal}
125 |                     )  # get joint angles
126 |                     q_sols_np[idx] = np.fromiter(
127 |                         (q_sol[f"p{jj}"] for jj in range(1, graph.robot.n + 1)), dtype=float
128 |                     )
129 |                     # normalize to match panda joints and joint limits
130 |                     q_sols_np[idx][5] = q_sols_np[idx][5] % (2 * np.pi)
131 |
132 |                     T_ee = graph.robot.pose(q_sol, robot.end_effectors[-1])
133 |                     e_pose[idx] = np.linalg.norm(T_ee.inv().dot(T_goal).log())
134 |                     e_pos[idx] = np.linalg.norm(T_ee.trans - T_goal.trans)
135 |                     e_rot[idx] = np.linalg.norm(T_ee.rot.inv().dot(T_goal.rot).log())
136 |
137 |                     entry = {
138 |                         "Id": kdx,
139 |                         "Robot": robot_type,
140 |                         "Goal Pose": T_goal.as_matrix(),
141 |                         "Sol. Config": q_sols_np[idx],
142 |                         "Err. Pose": e_pose[idx],
143 |                         "Err. Position": e_pos[idx],
144 |                         "Err. Rotation": e_rot[idx],
Rotation": e_rot[idx], 145 | } 146 | sol_data.append(entry) 147 | all_sol_data.append(pd.DataFrame(sol_data)) 148 | 149 | pd_data = pd.concat(all_sol_data) 150 | 151 | exp_dir = f"{sys.path[0]}/results/"+ f"{args.id}/" 152 | os.makedirs(exp_dir, exist_ok=True) 153 | pd_data.to_pickle(os.path.join(exp_dir, "results.pkl")) 154 | 155 | 156 | def parse_experiment_joint_limits_args(): 157 | parser = argparse.ArgumentParser() 158 | 159 | # General settings 160 | parser.add_argument("--id", type=str, default="test_experiment", help="Name of the folder with experiment data") 161 | parser.add_argument("--model_path", nargs="*", type=str, required=True, help="Path to folder with model data") 162 | parser.add_argument('--device', type=str, default='cpu', help='Device to use for PyTorch') 163 | parser.add_argument("--robots", nargs="*", type=str, default=["planar_chain"], help="Type of robot used") 164 | parser.add_argument("--n_evals", type=int, default=1000, help="Number of evaluations") 165 | parser.add_argument("--num_samples", type=int, default=1, help="Total number of samples per problem") 166 | 167 | args = parser.parse_args() 168 | return args 169 | 170 | if __name__ == "__main__": 171 | random.seed(17) 172 | args = parse_experiment_joint_limits_args() 173 | main(args) -------------------------------------------------------------------------------- /experiments/experiment_joint_limits/joint_visualize.py: -------------------------------------------------------------------------------- 1 | import random 2 | import argparse 3 | import numpy as np 4 | import os 5 | import sys 6 | import math 7 | import pickle as pkl 8 | 9 | import pandas as pd 10 | import seaborn as sns 11 | sns.set_theme(style="darkgrid") 12 | sns.set_style({"xtick.direction": "in","ytick.direction": "in"}) 13 | sns.set(font_scale=2.750, rc={'text.usetex' : True, "font.family": "Computer Modern"}) 14 | params = ["xmajorticks=true", "ymajorticks=true", "xtick pos=left", "ytick pos=left"] 15 | import matplotlib.pyplot as plt 16 | 17 | def main(args): 18 | path = f"{sys.path[0]}/results/{args.id}/results.pkl" 19 | with open(path, 'rb') as f: 20 | data = pkl.load(f) 21 | 22 | dataset = { 23 | 'A': [], 24 | 'B': [], 25 | 'C': [], 26 | 'D': [], 27 | 'E': [], 28 | 'F': [], 29 | 'G': [] 30 | } 31 | 32 | for q in data.get('Sol. 
33 |         for i, key in enumerate(dataset.keys()):
34 |             dataset[key].append(q[i])
35 |
36 |     save_path = os.path.join(args.save_path, f"{args.id}/plots/")
37 |     os.makedirs(save_path, exist_ok=True)
38 |
39 |     limits_l = np.array([-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -0.0175, -2.8973])
40 |     limits_u = np.array([2.8973, 1.7628, 2.8973, -0.0698, 2.8973, 3.7525, 2.8973])
41 |
42 |     plt.figure(figsize=(10, 6))  # Adjust the width (10) as per your requirement
43 |     df = pd.DataFrame(dataset)
44 |
45 |     graph = sns.stripplot(x="variable", y="value", data=df.melt(), jitter=0.225, orient="v", s=3)
46 |     # Get unique categories on x-axis and their positions
47 |     categories = df.columns
48 |     category_positions = np.arange(len(categories))
49 |     category_width = 0.6  # Adjust this width to control the width of the shaded rectangle
50 |
51 |     # Shade the feasible range between limits_l and limits_u for each joint
52 |     for i, (limit_l, limit_u) in enumerate(zip(limits_l, limits_u)):
53 |         tick_center = category_positions[i]  # Center of the current tick
54 |         # Shading the area between the limits
55 |         plt.fill_between(
56 |             [tick_center - category_width/2, tick_center + category_width/2],
57 |             limit_l,
58 |             limit_u,
59 |             color='green',
60 |             alpha=0.25
61 |         )
62 |
63 |     plt.xticks(ticks=category_positions, labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$", r"$\theta_4$", r"$\theta_5$", r"$\theta_6$", r"$\theta_7$"])
64 |     plt.xlabel("")  # Remove x-axis label
65 |     plt.ylabel("Joint Angle [rad]")  # Rename y-axis label
66 |     plt.tight_layout()
67 |     plt.savefig(
68 |         os.path.join(save_path, "joint_limits.pdf")
69 |     )
70 |
71 | if __name__ == "__main__":
72 |     random.seed(17)
73 |     parser = argparse.ArgumentParser()
74 |
75 |     parser.add_argument("--id", type=str, default="test_experiment", help="Name of the folder with experiment data")
76 |     parser.add_argument("--save_path", type=str, required=True, help="Path to folder where plots are saved")
77 |     parser.add_argument('--device', type=str, default='cuda', help='Device to use for PyTorch')
78 |
79 |     args = parser.parse_args()
80 |     main(args)
81 |
-------------------------------------------------------------------------------- /experiments/experiment_joint_limits/results/paper_models/panda_512k_experiment/plots/joint_limits.pdf: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/utiasSTARS/generative-graphik/9263453cc7e1df506e57370d6558eca5aeb4b998/experiments/experiment_joint_limits/results/paper_models/panda_512k_experiment/plots/joint_limits.pdf
-------------------------------------------------------------------------------- /experiments/experiment_joint_limits/results/paper_models/panda_512k_experiment/results.pkl: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/utiasSTARS/generative-graphik/9263453cc7e1df506e57370d6558eca5aeb4b998/experiments/experiment_joint_limits/results/paper_models/panda_512k_experiment/results.pkl
-------------------------------------------------------------------------------- /experiments/experiment_joint_limits/run.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NAME=$1
4 |
5 | SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
6 | SRC_PATH="${SCRIPT_DIR}/../.."
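# NAME comes from the first argument; e.g. "./run.sh paper_models/panda_512k" resolves MODEL_PATH to saved_models/paper_models/panda_512k_model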
7 | MODEL_PATH="${SRC_PATH}/saved_models/${NAME}_model"
8 |
9 | # check that the model directory exists
10 | if [ -d "${MODEL_PATH}" ]
11 | then
12 |     echo "Directory already exists, using existing model."
13 | else
14 |     echo "Model not found!"
15 | fi
16 |
17 | python3 experiment_joint_limits.py \
18 |     --id "${NAME}_experiment" \
19 |     --model_path "${MODEL_PATH}/" \
20 |     --device cuda:1 \
21 |     --robots panda
-------------------------------------------------------------------------------- /experiments/experiment_timing/analyze.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NAME=$1
4 |
5 | SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
6 | SRC_PATH="${SCRIPT_DIR}/../.."
7 | MODEL_PATH="${SRC_PATH}/saved_models/${NAME}_model"
8 |
9 | python3 timing_curve.py \
10 |     --id "${NAME}_experiment" \
11 |     --save_path "/home/olimoyo/generative-graphik/experiments/experiment_timing/results/"
-------------------------------------------------------------------------------- /experiments/experiment_timing/experiment_timing.py: --------------------------------------------------------------------------------
1 | import importlib.util
2 | import json
3 | import os
4 | import sys
5 | os.environ["PYOPENGL_PLATFORM"] = "egl"
6 | import random
7 | from argparse import Namespace
8 | import time
9 | import pickle as pkl
10 | import numpy as np
11 | import argparse
12 |
13 | import torch
14 | import torch_geometric
15 | from generative_graphik.utils.dataset_generation import (
16 |     generate_data_point,
17 |     random_revolute_robot_graph,
18 | )
19 |
20 |
21 | def model_arg_loader(path):
22 |     """Load hyperparameters from trained model."""
23 |     if os.path.isdir(path):
24 |         with open(os.path.join(path, "hyperparameters.txt"), "r") as fp:
25 |             return Namespace(**json.load(fp))
26 |
27 | # NOTE generates all the initializations and stores them to a pickle file
28 | def main(args):
29 |     device = args.device
30 |
31 |     for model_path in args.model_path:
32 |         spec = importlib.util.spec_from_file_location("model", model_path + "model.py")
33 |         model = importlib.util.module_from_spec(spec)
34 |         spec.loader.exec_module(model)
35 |
36 |         # load models
37 |         model_args = model_arg_loader(model_path)
38 |         model = model.Model(model_args).to(device)
39 |         if model_path is not None:
40 |             try:
41 |                 model.load_state_dict(
42 |                     torch.load(model_path + f"/net.pth", map_location=device)
43 |                 )
44 |                 model.eval()
45 |             except Exception as e:
46 |                 print(e)
47 |         model = torch_geometric.compile(
48 |             model,
49 |             mode="max-autotune",
50 |             fullgraph=True
51 |         )
52 |
53 |         sample_amounts = [1, 16, 64, 128, 256, 512, 1024]
54 |         joint_amounts = ["6", "12", "18", "24", "30", "36"]
55 |         results = []
56 |
57 |         for sample_amount in sample_amounts:
58 |             for joint_amount in joint_amounts:
59 |                 starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
60 |
61 |                 #XXX: Warm up forward pass
62 |                 for _ in range(2):
63 |                     num_joints = int(joint_amount)
64 |                     graph = random_revolute_robot_graph(num_joints)
65 |
66 |                     # Generate random problem
67 |                     prob_data = generate_data_point(graph).to(device)
68 |                     prob_data.num_graphs = 1
69 |                     prob_data.T_ee = prob_data.T_ee.unsqueeze(0)
70 |                     data = model.preprocess(prob_data)
71 |
72 |                     # Compute solutions
73 |                     _ = model.forward_eval(
74 |                         x=data.pos,
75 |                         h=torch.cat((data.type, data.goal_data_repeated_per_node), dim=-1),
76 |                         edge_attr=data.edge_attr,
77 |                         edge_attr_partial=data.edge_attr_partial,
78 |                         edge_index=data.edge_index_full,
79 |                         partial_goal_mask=data.partial_goal_mask,
80 |                         nodes_per_single_graph=int(data.num_nodes / 1),
81 |                         num_samples=sample_amount,
82 |                         batch_size=1
83 |                     )
84 |                 # torch.cuda.synchronize() # wait for warm-up to finish
85 |
86 |                 for _ in range(16):
87 |                     num_joints = int(joint_amount)
88 |                     graph = random_revolute_robot_graph(num_joints)
89 |
90 |                     # Generate random problem
91 |                     prob_data = generate_data_point(graph).to(device)
92 |                     prob_data.num_graphs = 1
93 |                     prob_data.T_ee = prob_data.T_ee.unsqueeze(0)
94 |                     data = model.preprocess(prob_data)
95 |
96 |                     # Compute solutions
97 |                     # torch.cuda.synchronize()
98 |                     # t0 = time.time()
99 |                     starter.record()
100 |                     _ = model.forward_eval(
101 |                         x=data.pos,
102 |                         h=torch.cat((data.type, data.goal_data_repeated_per_node), dim=-1),
103 |                         edge_attr=data.edge_attr,
104 |                         edge_attr_partial=data.edge_attr_partial,
105 |                         edge_index=data.edge_index_full,
106 |                         partial_goal_mask=data.partial_goal_mask,
107 |                         nodes_per_single_graph=int(data.num_nodes / 1),
108 |                         num_samples=sample_amount,
109 |                         batch_size=1
110 |                     )
111 |                     ender.record()
112 |                     torch.cuda.synchronize()
113 |                     # t_sol = time.time() - t0
114 |                     t_sol = starter.elapsed_time(ender)  # elapsed time in milliseconds
115 |                     results.append((sample_amount, t_sol, joint_amount))
116 |
117 |     exp_dir = f"{sys.path[0]}/results/" + f"{args.id}/"
118 |     os.makedirs(exp_dir, exist_ok=True)
119 |     with open(os.path.join(exp_dir, 'results.pkl'), 'wb') as handle:
120 |         pkl.dump(results, handle, protocol=pkl.HIGHEST_PROTOCOL)
121 |
122 | if __name__ == "__main__":
123 |     random.seed(15)
124 |     parser = argparse.ArgumentParser()
125 |
126 |     # General settings
127 |     parser.add_argument("--id", type=str, default="test_experiment", help="Name of the folder with experiment data")
128 |     parser.add_argument("--model_path", nargs="*", type=str, required=True, help="Path to folder with model data")
129 |     parser.add_argument('--device', type=str, default='cuda:1', help='Device to use for PyTorch')
130 |
131 |     args = parser.parse_args()
132 |     main(args)
133 |
-------------------------------------------------------------------------------- /experiments/experiment_timing/run.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NAME=$1
4 |
5 | SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
6 | SRC_PATH="${SCRIPT_DIR}/../.."
7 | MODEL_PATH="${SRC_PATH}/saved_models/${NAME}_model"
8 |
9 | # check that the model directory exists
10 | if [ -d "${MODEL_PATH}" ]
11 | then
12 |     echo "Directory already exists, using existing model."
13 | else
14 |     echo "Model not found!"
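    # note: no exit here, so the experiment below still runs (and fails) when MODEL_PATH is missing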
15 | fi
16 | 
17 | python3 experiment_timing.py \
18 | --id "${NAME}_experiment" \
19 | --model_path "${MODEL_PATH}/" \
20 | --device cuda:1
21 | 
--------------------------------------------------------------------------------
/experiments/experiment_timing/timing_curve.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import sys
4 | import matplotlib.pyplot as plt
5 | import tikzplotlib
6 | import random
7 | import argparse
8 | import pickle as pkl
9 | import numpy as np
10 | import seaborn as sns
11 | sns.set_theme(style="darkgrid")
12 | sns.set_style({"xtick.direction": "in","ytick.direction": "in"})
13 | sns.set(font_scale=1.65, rc={'text.usetex' : True, "font.family": "Computer Modern"})
14 | params = ["xmajorticks=true", "ymajorticks=true", "xtick pos=left", "ytick pos=left"]
15 | import pandas as pd
16 | 
17 | def tikzplotlib_fix_ncols(obj):
18 | """
19 | workaround: matplotlib 3.6 renamed the legend's _ncol attribute to _ncols, which breaks tikzplotlib
20 | """
21 | if hasattr(obj, "_ncols"):
22 | obj._ncol = obj._ncols
23 | for child in obj.get_children():
24 | tikzplotlib_fix_ncols(child)
25 | 
26 | def main(args):
27 | path = f"{sys.path[0]}/results/{args.id}/results.pkl"
28 | 
29 | with open(path, 'rb') as f:
30 | data = pkl.load(f)
31 | data = np.array(data, dtype=object)
32 | # no unit conversion here: experiment_timing.py records starter.elapsed_time(ender), which is already in milliseconds
33 | df = pd.DataFrame(data, columns=["Number of Sampled Configurations", "Time [ms]", "DOF"])
34 | out = sns.lineplot(data=df, x="Number of Sampled Configurations", y="Time [ms]", hue="DOF", errorbar="sd")
35 | # plt.yscale('log')
36 | fig = out.get_figure()
37 | 
38 | plt.tight_layout()
39 | tikzplotlib_fix_ncols(fig)
40 | 
41 | save_path = os.path.join(args.save_path, f"{args.id}/plots/")
42 | os.makedirs(save_path, exist_ok=True)
43 | tikzplotlib.save(
44 | os.path.join(save_path, "timing_plot.tex"),
45 | figure="gcf",
46 | textsize=12.0,
47 | extra_axis_parameters=params,
48 | wrap=False,
49 | )
50 | 
51 | fig.savefig(
52 | os.path.join(save_path, "timing_plot.pdf")
53 | )
54 | 
55 | if __name__ == "__main__":
56 | random.seed(17)
57 | parser = argparse.ArgumentParser()
58 | 
59 | parser.add_argument("--id", type=str, default="test_experiment", help="Name of the folder with experiment data")
60 | parser.add_argument("--save_path", type=str, required=True, help="Path to folder where plots are saved")
61 | parser.add_argument('--device', type=str, default='cuda', help='Device to use for PyTorch')
62 | 
63 | args = parser.parse_args()
64 | main(args)
65 | 
--------------------------------------------------------------------------------
/experiments/experiment_tracIK/analyze.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | NAME=$1
4 | 
5 | SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
6 | SRC_PATH="${SCRIPT_DIR}/../.."
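# NOTE: assumes paper_models.zip (tracked via git-lfs) has been extracted into saved_models/paper_models/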
7 | MODEL_PATH="${SRC_PATH}/saved_models/paper_models/${NAME}_model" 8 | 9 | python table.py \ 10 | --id "${NAME}_experiment" \ 11 | --save_latex True \ 12 | --latex_path "/home/filipmrc/Documents/Latex/2022-limoyo-maric-generative-corl/" \ 13 | 14 | python dist_images.py \ 15 | --id "${NAME}_experiment" \ 16 | --robots ur10 kuka panda \ 17 | --n_evals 9 \ 18 | --model_path "${MODEL_PATH}/" \ 19 | --device cpu \ 20 | --randomize False \ 21 | --num_samples 32 22 | -------------------------------------------------------------------------------- /experiments/experiment_tracIK/dist_images.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import json 3 | import os 4 | import sys 5 | 6 | os.environ["PYOPENGL_PLATFORM"] = "egl" 7 | import random 8 | from argparse import Namespace 9 | import copy 10 | import pandas as pd 11 | import time 12 | 13 | import graphik 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | import torch 17 | from generative_graphik.args.parser import parse_analysis_args 18 | from generative_graphik.utils.dataset_generation import ( 19 | generate_data_point, 20 | random_revolute_robot_graph, 21 | ) 22 | from graphik.utils.roboturdf import ( 23 | RobotURDF, 24 | load_ur10, 25 | load_kuka, 26 | load_schunk_lwa4d, 27 | load_schunk_lwa4p, 28 | load_panda, 29 | ) 30 | import pyrender 31 | 32 | from graphik.utils.dgp import graph_from_pos 33 | from liegroups.numpy import SE3, SO3 34 | from graphik.utils.urdf_visualization import make_scene 35 | 36 | 37 | def model_arg_loader(path): 38 | """Load hyperparameters from trained model.""" 39 | if os.path.isdir(path): 40 | with open(os.path.join(path, "hyperparameters.txt"), "r") as fp: 41 | return Namespace(**json.load(fp)) 42 | 43 | 44 | def plot_revolute_manipulator_robot(robot, configs, transparency=0.7): 45 | 46 | # scene = make_scene(robot, with_balls=False, with_edges=False) 47 | scene = None 48 | for idx in range(len(configs)): 49 | if idx==0: 50 | trn=1 51 | else: 52 | trn=transparency 53 | scene = make_scene( 54 | robot, 55 | scene = scene, 56 | q=configs[idx], 57 | with_frames=False, 58 | with_balls=False, 59 | with_edges=False, 60 | with_robot=True, 61 | transparency=trn, 62 | ) 63 | 64 | camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=1.0) 65 | s = np.sqrt(2) / 2 66 | 67 | # panda 68 | # camera_pose = SE3(SO3.identity(), np.array([0,-1.45,0.25])).dot(SE3(SO3.rotx(np.pi/2), np.array([0,0,0]))) 69 | # ur10 70 | # camera_pose = SE3(SO3.identity(), np.array([0,-1.75,0.35])).dot(SE3(SO3.rotx(np.pi/2), np.array([0,0,0]))) 71 | # kuka 72 | # camera_pose = SE3(SO3.identity(), np.array([0,-1.75,0.35])).dot(SE3(SO3.rotx(np.pi/2), np.array([0,0,0]))) 73 | # lwa4d 74 | camera_pose = SE3(SO3.identity(), np.array([0,-1.75,0.35])).dot(SE3(SO3.rotx(np.pi/2), np.array([0,0,0]))) 75 | camera_pose = camera_pose.as_matrix() 76 | # camera_pose = np.array( 77 | # [ 78 | # [0.0, -s, s, 1.25], 79 | # [1.0, 0.0, 0.0, 0.0], 80 | # [0.0, s, s, 1.5], 81 | # [0.0, 0.0, 0.0, 1.0], 82 | # ] 83 | # ) 84 | 85 | scene.add(camera, pose=camera_pose) 86 | light = pyrender.SpotLight( 87 | color=np.ones(3), 88 | intensity=3.0, 89 | innerConeAngle=np.pi / 16.0, 90 | outerConeAngle=np.pi / 6.0, 91 | ) 92 | scene.add(light, pose=camera_pose) 93 | r = pyrender.OffscreenRenderer(2048, 2048) 94 | color, depth = r.render(scene) 95 | return color 96 | 97 | 98 | # NOTE generates all the initializations and stores them to a pickle file 99 | def main(args): 100 | device = 
args.device 101 | num_evals = args.n_evals # number of evaluations 102 | robot_types = args.robots 103 | dofs = args.dofs # number of dof we test on 104 | 105 | evals_per_robot = num_evals // len(robot_types) # number of evaluations per robots 106 | spec = importlib.util.spec_from_file_location("model", args.model_path[0] + "model.py") 107 | model = importlib.util.module_from_spec(spec) 108 | spec.loader.exec_module(model) 109 | 110 | # load models 111 | model_args = model_arg_loader(args.model_path[0]) 112 | model = model.Model(model_args).to(device) 113 | name = model_args.id.replace("model", "results") 114 | exp_dir = f"{sys.path[0]}/results/TRO/"+ f"{args.id}/images/" 115 | os.makedirs(exp_dir, exist_ok=True) 116 | c = np.pi / 180 117 | 118 | if args.model_path is not None: 119 | try: 120 | model.load_state_dict( 121 | torch.load(args.model_path[0] + f"/net.pth", map_location=device) 122 | ) 123 | model.eval() 124 | except Exception as e: 125 | print(e) 126 | 127 | all_sol_data = [] 128 | # fig_handle, ax_handle = plt.subplots(nrows=2, ncols=2, figsize=(8, 8)) 129 | for robot_type in robot_types: 130 | if robot_type == "ur10": 131 | limits_l = -np.array([180, 180, 180, 180, 180, 180]) * c 132 | limits_u = np.array([180, 180, 180, 180, 180, 180]) * c 133 | limits = [limits_l, limits_u] 134 | robot, graph = load_ur10(limits=limits) 135 | fname = graphik.__path__[0] + "/robots/urdfs/ur10_mod.urdf" 136 | urdf_robot = RobotURDF(fname) 137 | elif robot_type == "kuka": 138 | limits_l = -np.array([170, 120, 170, 120, 170, 120, 170]) * c 139 | limits_u = np.array([170, 120, 170, 120, 170, 120, 170]) * c 140 | limits = [limits_l, limits_u] 141 | robot, graph = load_kuka(limits=limits) 142 | fname = graphik.__path__[0] + "/robots/urdfs/kuka_iiwr.urdf" 143 | urdf_robot = RobotURDF(fname) 144 | elif robot_type == "lwa4d": 145 | limits_l = -np.array([180, 123, 180, 125, 180, 170, 170]) * c 146 | limits_u = np.array([180, 123, 180, 125, 180, 170, 170]) * c 147 | limits = [limits_l, limits_u] 148 | robot, graph = load_schunk_lwa4d(limits=None) 149 | fname = graphik.__path__[0] + "/robots/urdfs/lwa4d.urdf" 150 | urdf_robot = RobotURDF(fname) 151 | elif robot_type == "panda": 152 | limits_l = -np.array( 153 | [2.8973, 1.7628, 2.8973, 0.0698, 2.8973, 3.7525, 2.8973] 154 | ) 155 | limits_u = np.array( 156 | [2.8973, 1.7628, 2.8973, 3.0718, 2.8973, 3.7525, 2.8973] 157 | ) 158 | limits = [limits_l, limits_u] 159 | robot, graph = load_panda(limits=None) 160 | fname = graphik.__path__[0] + "/robots/urdfs/panda_arm.urdf" 161 | urdf_robot = RobotURDF(fname) 162 | elif robot_type == "lwa4p": 163 | limits_l = -np.array([170, 170, 155.3, 170, 170, 170]) * c 164 | limits_u = np.array([170, 170, 155.3, 170, 170, 170]) * c 165 | limits = [limits_l, limits_u] 166 | robot, graph = load_schunk_lwa4p(limits=limits) 167 | fname = graphik.__path__[0] + "/robots/urdfs/lwa4p.urdf" 168 | urdf_robot = RobotURDF(fname) 169 | else: 170 | raise NotImplementedError 171 | fig= plt.figure() 172 | for kdx in range(evals_per_robot): 173 | sol_data = [] 174 | 175 | # Generate random problem 176 | prob_data = generate_data_point(graph).to(device) 177 | prob_data.num_graphs = 1 178 | data = model.preprocess(prob_data) 179 | P_goal = data.pos.cpu().numpy() 180 | T_goal = SE3.exp(data.T_ee.cpu().numpy()) 181 | q_goal = graph.joint_variables(graph_from_pos(P_goal, graph.node_ids)) 182 | 183 | # Compute solutions 184 | P_all = ( 185 | model.forward_eval(data, num_samples=args.num_samples[0]*4) 186 | .cpu() 187 | .detach() 188 | # .numpy() 
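# kept as a torch tensor: 4x the requested samples are drawn above, then filtered down to the best num_samples below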
189 | ) 190 | src, dst = data["edge_index_full"] 191 | dist_samples = ((P_all[:,src] - P_all[:,dst])**2).sum(dim=-1).sqrt() 192 | dist_diff = dist_samples[:, data.partial_mask] - data["edge_attr_partial"].t() 193 | dist_diff_norm, _ = torch.max(torch.abs(dist_diff), dim=-1) 194 | ind = torch.argsort(dist_diff_norm) 195 | P_all = P_all[ind[:args.num_samples[0]]] 196 | 197 | # torch.cuda.synchronize() 198 | 199 | # Analyze solutions 200 | q_sols = [q_goal] 201 | for idx in range(P_all.shape[0]): 202 | P = P_all[idx, :] 203 | 204 | q_sol = graph.joint_variables( 205 | graph_from_pos(P, graph.node_ids), {robot.end_effectors[0]: T_goal} 206 | ) # get joint angles 207 | q_sols.append(q_sol) 208 | 209 | config_img = plot_revolute_manipulator_robot( 210 | urdf_robot, 211 | q_sols, 212 | transparency=0.15 213 | ) 214 | fig = plt.imshow(config_img) 215 | fig.axes.get_xaxis().set_visible(False) 216 | fig.axes.get_yaxis().set_visible(False) 217 | # fig.axes.set_title("Distribution") 218 | # plt.pause(0.1) 219 | os.makedirs(exp_dir, exist_ok=True) 220 | plt.savefig(exp_dir + str(robot_type) + '_' + str(kdx) + '.png', dpi=512) 221 | 222 | 223 | if __name__ == "__main__": 224 | random.seed(17) 225 | args = parse_analysis_args() 226 | main(args) 227 | -------------------------------------------------------------------------------- /experiments/experiment_tracIK/experiment_2.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import json 3 | import os 4 | import sys 5 | from graphik.graphs import ProblemGraphRevolute 6 | from graphik.graphs.graph_revolute import list_to_variable_dict 7 | from graphik.robots import RobotRevolute 8 | 9 | from torch_geometric.data import DataLoader 10 | from generative_graphik.utils.dataset_generation import generate_data_point 11 | 12 | from generative_graphik.utils.torch_utils import SE3_inv, batchFKmultiDOF, batchIKmultiDOF, node_attributes, torch_log_from_T 13 | 14 | os.environ["PYOPENGL_PLATFORM"] = "egl" 15 | import random 16 | from argparse import Namespace 17 | import copy 18 | import pandas as pd 19 | import time 20 | 21 | import graphik 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | import torch 25 | from generative_graphik.args.parser import parse_analysis_args 26 | from graphik.utils.roboturdf import ( 27 | RobotURDF, 28 | load_ur10, 29 | load_kuka, 30 | load_schunk_lwa4d, 31 | load_schunk_lwa4p, 32 | load_panda, 33 | ) 34 | # import pyrender 35 | 36 | from graphik.utils.dgp import graph_from_pos 37 | from liegroups.numpy import SE3, SO3 38 | 39 | def model_arg_loader(path): 40 | """Load hyperparameters from trained model.""" 41 | if os.path.isdir(path): 42 | with open(os.path.join(path, "hyperparameters.txt"), "r") as fp: 43 | return Namespace(**json.load(fp)) 44 | 45 | def filter_by_distance(P_all, data, norm = 'inf'): 46 | src, dst = data["edge_index_full"] 47 | dist_samples = ((P_all[:,src] - P_all[:,dst])**2).sum(dim=-1).sqrt() 48 | dist_diff = dist_samples[:, data.partial_mask] - data["edge_attr_partial"].t() 49 | 50 | if norm == 'inf': 51 | dist_diff_norm, _ = torch.max(torch.abs(dist_diff), dim=-1) 52 | else: 53 | dist_diff_norm = (dist_diff**2).sum(dim=-1).sqrt() 54 | 55 | ind = torch.argsort(dist_diff_norm) 56 | return ind 57 | 58 | def filter_by_error(T_ee, T_final_inv): 59 | e_pose = torch_log_from_T(torch.bmm(T_final_inv.expand(T_ee.shape[0],-1,-1), T_ee)) 60 | e_pose_norm = torch.norm(e_pose, dim=-1) 61 | ind = torch.argsort(e_pose_norm) 62 | return ind 63 | 64 | 65 | # 
NOTE generates all the initializations and stores them to a pickle file 66 | def main(args): 67 | device = args.device 68 | num_evals = args.n_evals # number of evaluations 69 | robot_types = args.robots 70 | dofs = args.dofs # number of dof we test on 71 | evals_per_robot = num_evals // len(robot_types) 72 | for model_path in args.model_path:# number of evaluations per robots 73 | spec = importlib.util.spec_from_file_location("model", model_path + "model.py") 74 | model = importlib.util.module_from_spec(spec) 75 | spec.loader.exec_module(model) 76 | 77 | # load models 78 | model_args = model_arg_loader(model_path) 79 | model = model.Model(model_args).to(device) 80 | name = model_args.id.replace("model", "results") 81 | c = np.pi / 180 82 | 83 | if model_path is not None: 84 | try: 85 | state_dict = torch.load(model_path + f"checkpoints/checkpoint.pth", map_location=device) 86 | model.load_state_dict(state_dict["net"]) 87 | # model.load_state_dict( 88 | # torch.load(model_path + f"/net.pth", map_location=device) 89 | # ) 90 | model.eval() 91 | except Exception as e: 92 | print(e) 93 | 94 | all_sol_data = [] 95 | fig_handle, ax_handle = plt.subplots(nrows=2, ncols=2, figsize=(8, 8)) 96 | for robot_type in robot_types: 97 | if robot_type == "ur10": 98 | robot, graph = load_ur10(limits=None) 99 | fname = graphik.__path__[0] + "/robots/urdfs/ur10_mod.urdf" 100 | urdf_robot = RobotURDF(fname) 101 | 102 | # # UR10 coordinates for testing 103 | # modified_dh = False 104 | # a = [0, -0.612, 0.5723, 0, 0, 0] 105 | # d = [0.1273, 0, 0, 0.1639, 0.1157, 0.0922] 106 | # al = [np.pi / 2, 0, 0, np.pi / 2, -np.pi / 2, 0] 107 | # # th = [0, np.pi, 0, 0, 0, 0] 108 | # th = [0, 0, 0, 0, 0, 0] 109 | 110 | # params = { 111 | # "a": a, 112 | # "alpha": al, 113 | # "d": d, 114 | # "theta": th, 115 | # "modified_dh": modified_dh, 116 | # "num_joints": 6, 117 | # } 118 | # robot = RobotRevolute(params) 119 | # graph = ProblemGraphRevolute(robot) 120 | elif robot_type == "kuka": 121 | limits_l = -np.array([170, 120, 170, 120, 170, 120, 170]) * c 122 | limits_u = np.array([170, 120, 170, 120, 170, 120, 170]) * c 123 | limits = [limits_l, limits_u] 124 | robot, graph = load_kuka(limits=None) 125 | fname = graphik.__path__[0] + "/robots/urdfs/kuka_iiwr.urdf" 126 | urdf_robot = RobotURDF(fname) 127 | 128 | # modified_dh = False 129 | # a = [0, 0, 0, 0, 0, 0, 0] 130 | # d = [0.34, 0, 0.40, 0, 0.40, 0, 0.126] 131 | # al = [-np.pi / 2, np.pi / 2, np.pi / 2, -np.pi / 2, -np.pi / 2, np.pi / 2, 0] 132 | # th = [0, 0, 0, 0, 0, 0, 0] 133 | 134 | # params = { 135 | # "a": a, 136 | # "alpha": al, 137 | # "d": d, 138 | # "theta": th, 139 | # "modified_dh": modified_dh, 140 | # "num_joints": 7, 141 | # } 142 | # robot = RobotRevolute(params) 143 | # graph = ProblemGraphRevolute(robot) 144 | elif robot_type == "lwa4d": 145 | limits_l = -np.array([180, 123, 180, 125, 180, 170, 170]) * c 146 | limits_u = np.array([180, 123, 180, 125, 180, 170, 170]) * c 147 | limits = [limits_l, limits_u] 148 | robot, graph = load_schunk_lwa4d(limits=None) 149 | fname = graphik.__path__[0] + "/robots/urdfs/lwa4d.urdf" 150 | urdf_robot = RobotURDF(fname) 151 | 152 | # modified_dh = False 153 | # a = [0, 0, 0, 0, 0, 0, 0] 154 | # d = [0.3, 0, 0.328, 0, 0.323, 0, 0.0824] 155 | # al = [-np.pi / 2, np.pi / 2, -np.pi / 2, np.pi / 2, -np.pi / 2, np.pi / 2, 0] 156 | # th = [0, 0, 0, 0, 0, 0, 0] 157 | 158 | # params = { 159 | # "a": a, 160 | # "alpha": al, 161 | # "d": d, 162 | # "theta": th, 163 | # "modified_dh": modified_dh, 164 | # "num_joints": 7, 165 | 
| # }
166 | # robot = RobotRevolute(params)
167 | # graph = ProblemGraphRevolute(robot)
168 | elif robot_type == "panda":
169 | limits_l = -np.array(
170 | [2.8973, 1.7628, 2.8973, 0.0698, 2.8973, 3.7525, 2.8973]
171 | )
172 | limits_u = np.array(
173 | [2.8973, 1.7628, 2.8973, 3.0718, 2.8973, 3.7525, 2.8973]
174 | )
175 | limits = [limits_l, limits_u]
176 | robot, graph = load_panda(limits=None)
177 | fname = graphik.__path__[0] + "/robots/urdfs/panda_arm.urdf"
178 | urdf_robot = RobotURDF(fname)
179 | 
180 | # modified_dh = False
181 | # a = [0, 0, 0, 0.0825, -0.0825, 0, 0.088]
182 | # d = [0.333, 0, 0.316, 0, 0.384, 0, 0]
183 | # al = [0, -np.pi/2, np.pi / 2, np.pi / 2, -np.pi / 2, np.pi / 2, np.pi / 2]
184 | # th = [0, 0, 0, 0, 0, 0, 0]
185 | 
186 | # params = {
187 | # "a": a,
188 | # "alpha": al,
189 | # "d": d,
190 | # "theta": th,
191 | # "modified_dh": modified_dh,
192 | # "num_joints": 7,
193 | # }
194 | # robot = RobotRevolute(params)
195 | # graph = ProblemGraphRevolute(robot)
196 | elif robot_type == "lwa4p":
197 | limits_l = -np.array([170, 170, 155.3, 170, 170, 170]) * c
198 | limits_u = np.array([170, 170, 155.3, 170, 170, 170]) * c
199 | limits = [limits_l, limits_u]
200 | robot, graph = load_schunk_lwa4p(limits=None)
201 | fname = graphik.__path__[0] + "/robots/urdfs/lwa4p.urdf"
202 | urdf_robot = RobotURDF(fname)
203 | 
204 | # modified_dh = False
205 | # a = [0, 0.350, 0, 0, 0, 0]
206 | # d = [0.205, 0, 0, 0.305, 0, 0.075]
207 | # al = [-np.pi / 2, np.pi, -np.pi / 2, np.pi / 2, -np.pi / 2, 0]
208 | # th = [0, 0, 0, 0, 0, 0]
209 | 
210 | # params = {
211 | # "a": a,
212 | # "alpha": al,
213 | # "d": d,
214 | # "theta": th,
215 | # "modified_dh": modified_dh,
216 | # "num_joints": 6,
217 | # }
218 | # robot = RobotRevolute(params)
219 | # graph = ProblemGraphRevolute(robot)
220 | else:
221 | raise NotImplementedError
222 | 
223 | for kdx in range(evals_per_robot):
224 | sol_data = []
225 | 
226 | # Generate random problem
227 | prob_data = generate_data_point(graph).to(device)
228 | prob_data.num_graphs = 1
229 | data = model.preprocess(prob_data)
230 | num_samples_pre = args.num_samples[0]*4
231 | 
232 | T_goal = SE3.exp(data.T_ee.cpu().numpy())
233 | T_final = torch.tensor(T_goal.as_matrix(), dtype=torch.float32).unsqueeze(0).to(device)
234 | T_final_inv = torch.tensor(T_goal.inv().as_matrix(), dtype=torch.float32).unsqueeze(0).to(device)
235 | T_goal = T_goal.as_matrix()
236 | ee_ind = torch.cumsum(prob_data.num_joints.expand(num_samples_pre) + 1, dim=0) - 1 # end indices of joints
237 | 
238 | data.goal_data_repeated_per_node = torch.repeat_interleave(data.T_ee, 2*data.num_joints + model.num_anchor_nodes, dim=0) # repeat the goal log-pose for every node, mirroring model.preprocess
239 | t0 = time.time()
240 | # Compute solutions
241 | # P_all = (
242 | # model.forward_eval(data, num_samples=num_samples_pre).to(device).detach()
243 | # )
244 | P_all = model.forward_eval(
245 | x=data.pos,
246 | h=torch.cat((data.type, data.goal_data_repeated_per_node), dim=-1),
247 | edge_attr=data.edge_attr,
248 | edge_attr_partial=data.edge_attr_partial,
249 | edge_index=data.edge_index_full,
250 | partial_goal_mask=data.partial_goal_mask,
251 | nodes_per_single_graph = int(data.num_nodes / 1),
252 | num_samples=num_samples_pre,
253 | batch_size=data.num_graphs
254 | )
255 | 
256 | # P_all = filter_by_distance(P_all, data, args.num_samples[0])
257 | # torch.cuda.synchronize()
258 | 
259 | # q_sols_ = batchIKmultiDOF(
260 | # P_all.reshape(-1,3),
261 | # prob_data.T0.repeat(args.num_samples[0],1,1),
262 | # prob_data.num_joints.expand(args.num_samples[0]),
263 | # T_final
= T_final.expand(args.num_samples[0],-1,-1) 264 | # ) 265 | 266 | # T = batchFKmultiDOF( 267 | # prob_data.T0.repeat(args.num_samples[0],1,1), 268 | # q_sols_.reshape(-1,1), 269 | # prob_data.num_joints.expand(args.num_samples[0]) 270 | # ) 271 | # ee_ind = torch.cumsum(prob_data.num_joints.expand(args.num_samples[0]) + 1, dim=0) - 1 # end indices of joints 272 | 273 | q_sols_ = batchIKmultiDOF( 274 | P_all.reshape(-1,3), 275 | prob_data.T0.repeat(num_samples_pre,1,1), 276 | prob_data.num_joints.expand(num_samples_pre), 277 | T_final = T_final.expand(num_samples_pre,-1,-1) 278 | ).reshape(-1,prob_data.num_joints) 279 | 280 | T = batchFKmultiDOF( 281 | prob_data.T0.repeat(num_samples_pre,1,1), 282 | q_sols_.reshape(-1,1), 283 | prob_data.num_joints.expand(num_samples_pre) 284 | ) 285 | T_ee = T[ee_ind] 286 | ind = filter_by_error(T_ee, T_final_inv.expand(num_samples_pre,-1,-1)) 287 | t_sol = time.time() - t0 288 | 289 | q_sols_ = q_sols_[ind[:args.num_samples[0]]] 290 | e_pose = torch_log_from_T(torch.bmm(T_final_inv.expand(args.num_samples[0],-1,-1), T_ee[ind[:args.num_samples[0]]])) 291 | e_pose_norm = torch.norm(e_pose, dim=-1).cpu().numpy() 292 | e_pos_norm = torch.norm(e_pose[:,:3], dim=-1).cpu().numpy() 293 | e_rot_norm = torch.norm(e_pose[:,3:], dim=-1).cpu().numpy() 294 | for idx in range(args.num_samples[0]): 295 | entry = { 296 | "Id": kdx, 297 | "Robot": robot_type, 298 | "Goal Pose": T_goal, 299 | "Sol. Config": q_sols_[idx], 300 | # "Sol. Points": P_all[idx,:], 301 | "Err. Pose": e_pose_norm[idx], 302 | "Err. Position": e_pos_norm[idx], 303 | "Err. Rotation": e_rot_norm[idx], 304 | # "Goal Config": q_goal_np, 305 | # "Goal Points": P_goal, 306 | "Sol. Time": t_sol, 307 | } 308 | sol_data.append(entry) 309 | all_sol_data.append(pd.DataFrame(sol_data)) 310 | print(kdx) 311 | 312 | pd_data = pd.concat(all_sol_data) 313 | 314 | # exp_dir = f"{sys.path[0]}/results/"+ f"{args.id}/" 315 | exp_dir = f"{sys.path[0]}/results/TRO/"+ f"{args.id}/" 316 | os.makedirs(exp_dir, exist_ok=True) 317 | pd_data.to_pickle(os.path.join(exp_dir, "results.pkl")) 318 | pd_data.to_csv(os.path.join(exp_dir, "results.csv")) 319 | 320 | 321 | if __name__ == "__main__": 322 | random.seed(17) 323 | args = parse_analysis_args() 324 | main(args) 325 | -------------------------------------------------------------------------------- /experiments/experiment_tracIK/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | NAME=$1 4 | 5 | SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" 6 | SRC_PATH="${SCRIPT_DIR}/../.." 7 | MODEL_PATH="${SRC_PATH}/saved_models/paper_models/${NAME}_model" 8 | 9 | export PYTORCH_ENABLE_MPS_FALLBACK=1 10 | # copy the model and training code if new 11 | if [ -d "${MODEL_PATH}" ] 12 | then 13 | echo "Directory already exists, using existing model." 14 | else 15 | echo "Model not found!" 
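# NOTE: results are written under experiments/experiment_tracIK/results/TRO/<id>/ by experiment_2.py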
16 | fi
17 | 
18 | # python experiment_2.py \
19 | # --id "${NAME}_experiment" \
20 | # --robots panda kuka lwa4d lwa4p ur10 \
21 | # --n_evals 100 \
22 | # --model_path "${MODEL_PATH}/" \
23 | # --device cuda \
24 | # --num_samples 32
25 | 
26 | python experiment_2.py \
27 | --id "${NAME}_tracik_experiment" \
28 | --robots panda kuka lwa4d lwa4p ur10 \
29 | --n_evals 10000 \
30 | --model_path "${MODEL_PATH}/" \
31 | --device cuda \
32 | --num_samples 32
33 | 
--------------------------------------------------------------------------------
/experiments/experiment_tracIK/run_tracik.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | NAME=$1
4 | 
5 | SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
6 | SRC_PATH="${SCRIPT_DIR}/../.."
7 | MODEL_PATH="${SRC_PATH}/saved_models/paper_models/${NAME}_model"
8 | 
9 | export PYTORCH_ENABLE_MPS_FALLBACK=1
10 | # check that a trained model exists
11 | if [ -d "${MODEL_PATH}" ]
12 | then
13 | echo "Directory already exists, using existing model."
14 | else
15 | echo "Model not found!"
16 | fi
17 | 
18 | python tracik_comparison.py \
19 | --id "${NAME}_experiment" \
20 | --robots ur10 kuka lwa4p lwa4d panda \
21 | --n_evals 100 \
22 | --model_path "${MODEL_PATH}/" \
23 | --device cuda \
24 | --num_samples 32
25 | 
--------------------------------------------------------------------------------
/experiments/experiment_tracIK/table.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import json
3 | import os
4 | import sys
5 | import importlib.util
6 | 
7 | 
8 | os.environ["PYOPENGL_PLATFORM"] = "egl"
9 | import random
10 | import copy
11 | import pandas as pd
12 | import time
13 | 
14 | # import graphik
15 | import matplotlib.pyplot as plt
16 | import numpy as np
17 | # from generative_graphik.args.parser import parse_analysis_args # shadowed by the local definition below
18 | from generative_graphik.args.utils import str2bool
19 | import argparse
20 | # import tikzplotlib
21 | # import seaborn as sns
22 | # from seaborn.utils import despine
23 | # sns.set_theme(style="darkgrid")
24 | 
25 | 
26 | def parse_analysis_args():
27 | parser = argparse.ArgumentParser()
28 | 
29 | # General settings
30 | parser.add_argument(
31 | "--id",
32 | type=str,
33 | default="test_experiment",
34 | help="Name of the folder with experiment data",
35 | )
36 | parser.add_argument(
37 | "--save_latex",
38 | type=str2bool,
39 | default=True,
40 | help="Save latex table.",
41 | )
42 | parser.add_argument(
43 | "--latex_path",
44 | type=str,
45 | default="/home/filipmrc/Documents/Latex/2022-limoyo-maric-generative-corl/tables/experiment_2.tex",
46 | help="Base path of the LaTeX project the table is written to",
47 | )
48 | 
49 | args = parser.parse_args()
50 | return args
51 | 
52 | 
53 | def main(args):
54 | # data = pd.read_pickle(f"{sys.path[0]}/results/{args.id}/results.pkl")
55 | data = pd.read_pickle(f"{sys.path[0]}/results/TRO/{args.id}/results.pkl")
56 | 
57 | # # Smallest error samples per problem
58 | # min_pos_error_norm = data.groupby(["Robot", "Id"])["Err. Position"].min()
59 | # min_rot_error_norm = data.groupby(["Robot", "Id"])["Err. Rotation"].min()
60 | # min_pose_error_norm = data.groupby(["Robot", "Id"])["Err. 
Pose"].min() 61 | 62 | # # Percentage of samples per robot that have error lower than some criteria 63 | # suc_pos_perc_all = [] 64 | # suc_rot_perc_all = [] 65 | # perc_data = data.set_index(["Robot", "Id"]) 66 | # pos_increment = 0.005 67 | # rot_increment = np.pi / 180 68 | # resolution = 1 69 | # for idx in [1, 5, 10]: 70 | # perc_data["Suc. Pos"] = ( 71 | # perc_data["Err. Position"] < (idx / resolution) * pos_increment 72 | # ) 73 | # suc_pos_perc = ( 74 | # perc_data["Suc. Pos"] 75 | # .eq(True) 76 | # .groupby(level=[0, 1]) 77 | # .value_counts(True) 78 | # .unstack(fill_value=0) 79 | # ) * 100 80 | # suc_pos_perc["Error"] = (idx / resolution) * pos_increment * 100 # centimeters 81 | # suc_pos_perc["Err. Type"] = "Pos" 82 | # suc_pos_perc_all.append(suc_pos_perc) 83 | 84 | # perc_data["Suc. Rot"] = ( 85 | # perc_data["Err. Rotation"] < (idx / resolution) * rot_increment 86 | # ) 87 | # suc_rot_perc = ( 88 | # perc_data["Suc. Rot"] 89 | # .eq(True) 90 | # .groupby(level=[0, 1]) 91 | # .value_counts(True) 92 | # .unstack(fill_value=0) 93 | # ) * 100 94 | # suc_rot_perc["Error"] = idx / resolution # degrees 95 | # suc_rot_perc["Err. Type"] = "Rot" 96 | # suc_rot_perc_all.append(suc_rot_perc) 97 | 98 | 99 | # # Average lowest errors per robot 100 | # avg_min_pos_err_norm = min_pos_error_norm.groupby("Robot").mean()*100 101 | # avg_min_pose_err_norm = min_pose_error_norm.groupby("Robot").mean() 102 | # avg_min_rot_err_norm = min_rot_error_norm.groupby("Robot").mean()*100 103 | 104 | # # Average time it took to generate the samples 105 | # avg_sol_t = data.groupby("Robot")["Sol. Time"].mean() 106 | 107 | # # Table data 108 | # data_dict = { 109 | # "Avg. Pos. Err.": avg_min_pos_err_norm, 110 | # # "$\bar{e_{pos}}$": avg_min_pos_err_norm, 111 | # # "Avg. Pose Err.": avg_min_pose_err_norm, 112 | # "Avg. Rot. Err.": avg_min_rot_err_norm, 113 | # "Avg. Sol. Time": avg_sol_t, 114 | # } 115 | 116 | # suc_perc_concat = pd.concat(suc_pos_perc_all + suc_rot_perc_all).reset_index() 117 | # for err_cut in sorted(suc_perc_concat["Error"].unique()): 118 | # if err_cut in suc_perc_concat[suc_perc_concat["Err. Type"] == "Pos"]["Error"].unique(): 119 | # data_dict[f"Err. Pos. $<$ {err_cut}"] = suc_perc_concat[ 120 | # (suc_perc_concat["Err. Type"] == "Pos") & (suc_perc_concat["Error"] == err_cut) 121 | # ].groupby("Robot")[True].mean() 122 | 123 | # if err_cut in suc_perc_concat[suc_perc_concat["Err. Type"] == "Rot"]["Error"].unique(): 124 | # data_dict[f"Err. Rot. $<$ {err_cut}"] = suc_perc_concat[ 125 | # (suc_perc_concat["Err. Type"] == "Rot") & (suc_perc_concat["Error"] == err_cut) 126 | # ].groupby("Robot")[True].mean() 127 | 128 | # # Assemble into new DataFrame 129 | # pd_all = pd.DataFrame(data_dict) 130 | # print(pd_all) 131 | # pd_all.index.name = None 132 | 133 | # # Format data 134 | # format_dict = { 135 | # "Avg. Pos. Err.": "{:.2f}", 136 | # "Avg. Pose Err.": "{:.2f}", 137 | # "Avg. Rot. Err.": "{:.2f}", 138 | # "Avg. Sol. Time": "{:.2f}", 139 | # } 140 | # for err_cut in suc_perc_concat["Error"].unique(): 141 | # format_dict[f"Err. Pos. $<$ {err_cut}"] = "{:.1f}" 142 | # format_dict[f"Err. Rot. 
$<$ {err_cut}"] = "{:.1f}" 143 | 144 | # # Print to latex table 145 | # if args.save_latex: 146 | # s = pd_all.style 147 | # s.format(format_dict) 148 | # latex = s.to_latex(hrules=True) 149 | # print(latex) 150 | 151 | # # open text file 152 | # text_file = open(args.latex_path + "tables/experiment_2.tex", "w") 153 | 154 | # # write string to file 155 | # text_file.write(latex) 156 | 157 | stats = data.reset_index() 158 | stats["Err. Position"] = stats["Err. Position"]*1000 159 | stats["Err. Rotation"] = stats["Err. Rotation"]*(180/np.pi) 160 | q_pos = stats["Err. Position"].quantile(0.99) 161 | q_rot = stats["Err. Rotation"].quantile(0.99) 162 | stats = stats.drop(stats[stats["Err. Position"] > q_pos].index) 163 | stats = stats.drop(stats[stats["Err. Rotation"] > q_rot].index) 164 | 165 | stats = stats.groupby(["Robot", "Id"])[["Err. Position", "Err. Rotation"]].describe().groupby("Robot").mean() 166 | stats = stats.drop(["count", "std", "50%"], axis=1, level=1) 167 | 168 | perc_data = data.set_index(["Robot", "Id"]) 169 | perc_data["Success"] = ( 170 | (perc_data["Err. Position"] < 0.01) & (perc_data["Err. Rotation"] < (180/np.pi)) 171 | ) 172 | suc_pos_perc = ( 173 | perc_data["Success"] 174 | .eq(True) 175 | .groupby(level=[0, 1]) 176 | .value_counts(True) 177 | .unstack(fill_value=0) 178 | ) 179 | stats["Success [\%]"] = suc_pos_perc.groupby(level=0).apply(lambda c: (c>0).sum()/len(c))[True]*100 180 | 181 | stats.rename(columns = {'75%': 'Q$_{3}$', '25%': 'Q$_{1}$','Err. Position':'Err. Pos. [mm]', 'Err. Rotation':'Err. Rot. [deg]'}, inplace = True) 182 | # Swap to follow paper order 183 | cols = stats.columns.tolist() 184 | ins = cols.pop(4) 185 | cols.insert(2, ins) 186 | ins = cols.pop(9) 187 | cols.insert(7, ins) 188 | stats = stats[cols] 189 | 190 | if args.save_latex: 191 | s = stats.style 192 | s.format(precision=1) 193 | s.format_index(axis=1,level=[0,1]) 194 | latex = s.to_latex(hrules=True, multicol_align="c") 195 | print(latex) 196 | 197 | # open text file 198 | # text_file = open(args.latex_path + "tables/experiment_2.tex", "w") 199 | # write string to file 200 | # text_file.write(latex) 201 | if __name__ == "__main__": 202 | random.seed(17) 203 | args = parse_analysis_args() 204 | main(args) 205 | -------------------------------------------------------------------------------- /experiments/experiment_tracIK/tracik_comparison.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import json 3 | import os 4 | import sys 5 | import tracikpy 6 | import random 7 | import copy 8 | import pandas as pd 9 | import time 10 | import graphik 11 | import numpy as np 12 | import torch 13 | 14 | from argparse import Namespace 15 | from tracikpy import TracIKSolver 16 | from graphik.graphs import ProblemGraphRevolute 17 | from graphik.graphs.graph_revolute import list_to_variable_dict 18 | from graphik.robots import RobotRevolute 19 | from torch_geometric.data import DataLoader 20 | from generative_graphik.utils.dataset_generation import generate_data_point 21 | # from generative_graphik.utils.torch_utils import SE3_inv, batchFKmultiDOF, batchIKmultiDOF, node_attributes, torch_log_from_T 22 | 23 | os.environ["PYOPENGL_PLATFORM"] = "egl" 24 | 25 | from generative_graphik.args.parser import parse_analysis_args 26 | from graphik.utils.roboturdf import ( 27 | RobotURDF, 28 | load_ur10, 29 | load_kuka, 30 | load_schunk_lwa4d, 31 | load_schunk_lwa4p, 32 | load_panda, 33 | ) 34 | 35 | from liegroups.numpy import SE3, SO3 36 | 37 | # 
NOTE generates all the initializations and stores them to a pickle file 38 | def main(args): 39 | np.random.seed(0) 40 | device = args.device 41 | num_evals = args.n_evals # number of evaluations 42 | robot_types = args.robots 43 | dofs = args.dofs # number of dof we test on 44 | evals_per_robot = num_evals // len(robot_types) 45 | all_sol_data = [] 46 | for robot_type in robot_types: 47 | if robot_type == "ur10": 48 | robot, graph = load_ur10(limits=None) 49 | fname = graphik.__path__[0] + "/robots/urdfs/ur10_mod.urdf" 50 | link_base, link_ee = 'base_link', 'ee_link' 51 | elif robot_type == "kuka": 52 | # limits_l = -np.array([170, 120, 170, 120, 170, 120, 170]) * c 53 | # limits_u = np.array([170, 120, 170, 120, 170, 120, 170]) * c 54 | # limits = [limits_l, limits_u] 55 | robot, graph = load_kuka(limits=None) 56 | fname = graphik.__path__[0] + "/robots/urdfs/kuka_iiwr.urdf" 57 | link_base, link_ee = 'lbr_iiwa_link_0', 'ee_link' 58 | elif robot_type == "lwa4d": 59 | # limits_l = -np.array([180, 123, 180, 125, 180, 170, 170]) * c 60 | # limits_u = np.array([180, 123, 180, 125, 180, 170, 170]) * c 61 | # limits = [limits_l, limits_u] 62 | robot, graph = load_schunk_lwa4d(limits=None) 63 | fname = graphik.__path__[0] + "/robots/urdfs/lwa4d.urdf" 64 | link_base, link_ee = 'lwa4d_base_link', 'lwa4d_ee_link' 65 | elif robot_type == "panda": 66 | # limits_l = -np.array( 67 | # [2.8973, 1.7628, 2.8973, 3.0718, 2.8973, 0.0175, 2.8973] 68 | # ) 69 | # limits_u = np.array( 70 | # [2.8973, 1.7628, 2.8973, -0.0698, 2.8973, 3.7525, 2.8973] 71 | # ) 72 | # limits = [limits_l, limits_u] 73 | robot, graph = load_panda(limits=None) 74 | fname = graphik.__path__[0] + "/robots/urdfs/panda_arm.urdf" 75 | link_base, link_ee = 'panda_link0', 'panda_link7' 76 | elif robot_type == "lwa4p": 77 | # limits_l = -np.array([170, 170, 155.3, 170, 170, 170]) * c 78 | # limits_u = np.array([170, 170, 155.3, 170, 170, 170]) * c 79 | # limits = [limits_l, limits_u] 80 | robot, graph = load_schunk_lwa4p(limits=None) 81 | fname = graphik.__path__[0] + "/robots/urdfs/lwa4p.urdf" 82 | link_base, link_ee = 'lwa4p_base_link', 'lwa4p_ee_link' 83 | else: 84 | raise NotImplementedError 85 | 86 | t_sol_max = 0.1 87 | ik_solver = TracIKSolver(fname, link_base, link_ee, timeout=t_sol_max) 88 | ub = np.ones(ik_solver.number_of_joints) * np.pi 89 | lb = -ub 90 | ik_solver.joint_limits = lb, ub 91 | # q_init = np.zeros(ik_solver.number_of_joints) 92 | for kdx in range(evals_per_robot): 93 | sol_data = [] 94 | q_init = np.array(list(robot.random_configuration().values())) 95 | 96 | # Generate random problem 97 | prob_data = generate_data_point(graph).to(device) 98 | q_goal = prob_data.q_goal.numpy() 99 | q_goal[-1] += 0.0001 100 | T_goal = ik_solver.fk(q_goal) 101 | 102 | t0 = time.time() 103 | q_sol = ik_solver.ik(T_goal, qinit=q_init, bx = 1e-3, by = 1e-3, bz = 1e-3) 104 | t_sol = time.time() - t0 105 | if q_sol is None: 106 | q_sol = q_init 107 | t_sol = t_sol_max 108 | # import ipdb; ipdb.set_trace() 109 | 110 | T_ee = ik_solver.fk(q_sol) 111 | 112 | e_pose = (SE3.from_matrix(T_goal, normalize=True).inv().dot(SE3.from_matrix(T_ee, normalize=True))).log() 113 | e_pose_norm = np.linalg.norm(e_pose) 114 | e_pos_norm = np.linalg.norm(e_pose[:3]) 115 | e_rot_norm = np.linalg.norm(e_pose[3:]) 116 | 117 | entry = { 118 | "Id": kdx, 119 | "Robot": robot_type, 120 | "Goal Pose": SE3.from_matrix(T_goal), 121 | "Sol. Config": [q_sol], 122 | # "Sol. Points": P_all[idx,:], 123 | "Err. Pose": e_pose_norm, 124 | "Err. 
Position": e_pos_norm, 125 | "Err. Rotation": e_rot_norm, 126 | # "Goal Config": q_goal_np, 127 | # "Goal Points": P_goal, 128 | "Sol. Time": t_sol, 129 | } 130 | all_sol_data.append(pd.DataFrame(entry)) 131 | # print(kdx) 132 | 133 | pd_data = pd.concat(all_sol_data) 134 | 135 | stats = pd_data.reset_index() 136 | stats["Err. Position"] = stats['Err. Position']*1000 137 | stats["Err. Rotation"] = stats['Err. Rotation']*(180/np.pi) 138 | q_pos = stats['Err. Position'].quantile(0.99) 139 | q_rot = stats['Err. Rotation'].quantile(0.99) 140 | stats = stats.drop(stats[stats['Err. Position'] > q_pos].index) 141 | stats = stats.drop(stats[stats['Err. Rotation'] > q_rot].index) 142 | 143 | stats = stats.groupby(['Robot'])[['Err. Position', 'Err. Rotation', 'Sol. Time']].describe() 144 | # stats = stats.drop(['count', 'std', '50%'], axis=1, level=1) 145 | stats = stats.drop(['count', '50%'], axis=1, level=1) 146 | perc_data = pd_data.set_index(['Robot']) 147 | perc_data['Success'] = ( 148 | (perc_data['Err. Position'] < 0.01) & (perc_data['Err. Rotation'] < (180/np.pi)) 149 | ) 150 | import ipdb; ipdb.set_trace() 151 | stats['Success [\%]'] = perc_data['Success'].eq(True).groupby(level=0).value_counts(True).unstack(fill_value=0)[True]*100 152 | # suc_pos_perc = ( 153 | # perc_data["Success"] 154 | # .eq(True) 155 | # .groupby(level=[0, 1]) 156 | # .value_counts(True) 157 | # .unstack(fill_value=0) 158 | # ) 159 | # stats["Success [\%]"] = suc_pos_perc.groupby(level=0).apply(lambda c: (c>0).sum()/len(c))[True]*100 160 | 161 | stats.rename(columns = {'75%': 'Q$_{3}$', '25%': 'Q$_{1}$','Err. Position':'Err. Pos. [mm]', 'Err. Rotation':'Err. Rot. [deg]'}, inplace = True) 162 | # Swap to follow paper order 163 | cols = stats.columns.tolist() 164 | ins = cols.pop(4) 165 | cols.insert(2, ins) 166 | ins = cols.pop(9) 167 | cols.insert(7, ins) 168 | stats = stats[cols] 169 | print(stats) 170 | print(stats.mean()) 171 | # import ipdb; ipdb.set_trace() 172 | # # print(pd_data.groupby('Robot').mean()) 173 | 174 | # for robot_type in robot_types: 175 | # print(robot_type) 176 | # success_data = pd_data[(pd_data['Err. Position'] < 1e-2) & (pd_data['Err. 
Rotation'] < (180/np.pi)) & (pd_data['Robot']==robot_type)] 177 | # print(success_data.describe()) 178 | # print(len(success_data)/evals_per_robot) 179 | # print('------------------------') 180 | 181 | # exp_dir = f"{sys.path[0]}/results/TRO/"+ f"{args.id}/" 182 | # os.makedirs(exp_dir, exist_ok=True) 183 | # pd_data.to_pickle(os.path.join(exp_dir, "results.pkl")) 184 | # pd_data.to_csv(os.path.join(exp_dir, "results.csv")) 185 | 186 | 187 | if __name__ == "__main__": 188 | random.seed(17) 189 | args = parse_analysis_args() 190 | main(args) 191 | -------------------------------------------------------------------------------- /generative_graphik/args/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from generative_graphik.args.utils import str2bool, str2inttuple, str2tuple, str2floattuple 4 | 5 | def parse_training_args(): 6 | parser = argparse.ArgumentParser() 7 | 8 | # Experiment Settings 9 | parser.add_argument('--device', type=str, default='cpu', help='Device to use for PyTorch') 10 | parser.add_argument('--cudnn_deterministic', type=str2bool, default=True, help='Use cudnn deterministic') 11 | parser.add_argument('--cudnn_benchmark', type=str2bool, default=False, help='Use cudnn benchmark') 12 | parser.add_argument('--id', type=str, default="None", help='Name of folder used to store model') 13 | parser.add_argument('--random_seed', type=int, default=333, help='Random seed') 14 | parser.add_argument('--use_validation', type=str2bool, default=True, help='Run validation') 15 | parser.add_argument('--debug', type=str2bool, default=True, help='Debug and do not save models or log anything') 16 | parser.add_argument('--storage_base_path', type=str, required=True, help='Base path to store all training data') 17 | parser.add_argument('--training_data_path', type=str, default="planar_chain_size_20000_dof_[5, 7, 9, 11]_randomize_True_partial_True_approx_edges_False_data.p", help='Path to training data') 18 | parser.add_argument('--validation_data_path', type=str, default="planar_chain_size_20000_dof_[5, 7, 9, 11]_randomize_True_partial_True_approx_edges_False_data.p", help='Path to training data') 19 | parser.add_argument('--pretrained_weights_path', type=str, default=None, help='Path to pretrained weights') 20 | parser.add_argument('--module_path', type=str, default="none", help='Path to network module.') 21 | 22 | # Training Settings 23 | parser.add_argument('--n_epoch', type=int, default=4096, help='Number of epochs') 24 | parser.add_argument('--n_scheduler_epoch', type=int, default=25, help='Number of epochs before fixed scheduler steps.') 25 | parser.add_argument('--n_checkpoint_epoch', type=int, default=32, help='Number of epochs for checkpointing') 26 | parser.add_argument('--n_beta_scaling_epoch', type=int, default=1, help='Warm start KL divergence for this amount of epochs.') 27 | parser.add_argument('--n_joint_scaling_epoch', type=int, default=1, help='Warm start joint loss for this amount of epochs.') 28 | parser.add_argument('--n_batch', type=int, default=32, help='Batch size') 29 | parser.add_argument('--n_worker', type=int, default=16, help='Amount of workers for dataloading.') 30 | parser.add_argument('--lr', type=float, default= 3e-4, help='Learning rate') 31 | 32 | # Network parameters 33 | parser.add_argument('--num_anchor_nodes', type=int, default=3, help='Number of anchor nodes') 34 | parser.add_argument('--num_node_features_out', type=int, default=3, help='Size of node features out') 35 | 
parser.add_argument('--num_coordinates_in', type=int, default=3, help='Size of node coordinates in')
36 | parser.add_argument('--num_features_in', type=int, default=3, help='Size of node features in')
37 | parser.add_argument('--num_edge_features_in', type=int, default=1, help='Size of edge features in')
38 | parser.add_argument('--gnn_type', type=str, default="egnn", help='GNN type used.')
39 | parser.add_argument('--num_gnn_layers', type=int, default=3, help='Number of GNN layers')
40 | parser.add_argument('--num_graph_mlp_layers', type=int, default=0, help='Number of layers for the MLPs used in the graph')
41 | parser.add_argument('--num_egnn_mlp_layers', type=int, default=2, help='Number of layers for the MLPs used in the EGNN layer itself')
42 | parser.add_argument('--num_iterations', type=int, default=1, help='Number of iterations the networks go through')
43 | parser.add_argument('--dim_latent', type=int, default=8, help='Size of latent node features fed into the encoder')
44 | parser.add_argument('--dim_goal', type=int, default=3, help='Size of goal representation (SE3-->6, SE2-->3)')
45 | parser.add_argument('--num_prior_mixture_components', type=int, default=1, help='Number of mixture components for prior network')
46 | parser.add_argument('--num_likelihood_mixture_components', type=int, default=1, help='Number of mixture components for likelihood network')
47 | parser.add_argument('--train_prior', type=str2bool, default=True, help='Learn prior parameters conditioned on variables.')
48 | parser.add_argument('--rec_gain', type=int, default=80, help='Gain on non-anchor node reconstruction')
49 | parser.add_argument('--non_linearity', type=str, default="silu", help='Non-linearity used.')
50 | parser.add_argument('--dim_latent_node_out', type=int, default=3, help='Size of node feature dim in enc/dec')
51 | parser.add_argument('--graph_mlp_hidden_size', type=int, default=4, help='Size of hidden layers of MLP used in GNN')
52 | parser.add_argument('--mlp_hidden_size', type=int, default=4, help='Size of all other MLP hidden layers')
53 | parser.add_argument('--norm_layer', choices=['None', 'BatchNorm', 'LayerNorm', 'GroupNorm', 'InstanceNorm', 'GraphNorm'], default='None', help='Layer normalization method.')
54 | 
55 | args = parser.parse_args()
56 | return args
57 | 
58 | def parse_data_generation_args():
59 | parser = argparse.ArgumentParser()
60 | 
61 | # General settings
62 | parser.add_argument("--id", type=str, default=None, help="Name of the dataset")
63 | parser.add_argument('--storage_base_path', type=str, default=None, help='Base path for storing dataset')
64 | 
65 | # Robot settings
66 | parser.add_argument("--robots", nargs="*", type=str, default=["planar_chain"], help="Type of robot used")
67 | parser.add_argument("--dofs", nargs="*", type=int, default=[5,7,9,11], help="Numbers of DoF that occur in the dataset")
68 | parser.add_argument("--randomize", type=str2bool, default=True, help="Randomize kinematic parameters for every instance")
69 | parser.add_argument("--randomize_percentage", type=float, default=0.2, help="Percentage variation of link lengths.")
70 | parser.add_argument("--goal_type", type=str, default="pose", help="Type of goal used for each problem (e.g. pose)")
71 | parser.add_argument("--obstacles", type=str2bool, default=False, help="Use obstacles")
72 | parser.add_argument("--semantic", type=str2bool, default=False, help="Use semantic tags for nodes.")
73 | parser.add_argument("--load_from_disk", type=str2bool, default=False, help="Store/load the dataset as separate files on disk.")
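# NOTE: boolean options above go through str2bool, so pass explicit values on the command line, e.g. --randomize False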
74 | 
75 | # Problem settings
76 | parser.add_argument("--num_examples", type=int, default=10, help="Total number of problems in the dataset")
77 | parser.add_argument("--validation_percentage", type=int, default=10, help="Percentage of validation data")
78 | parser.add_argument("--num_samples", type=int, default=100, help="Total number of samples per problem")
79 | parser.add_argument("--max_examples_per_file", type=int, default=512, help="Max number of problems per file in the dataset")
80 | 
81 | args = parser.parse_args()
82 | return args
83 | 
84 | def parse_analysis_args():
85 | parser = argparse.ArgumentParser()
86 | 
87 | # General settings
88 | parser.add_argument("--id", type=str, default="test_experiment", help="Name of the folder with experiment data")
89 | parser.add_argument('--storage_base_path', type=str, default=None, help='Base path for folder with experiment data')
90 | parser.add_argument("--model_path", nargs="*", type=str, required=True, help="Path to folder with model data")
91 | parser.add_argument('--device', type=str, default='cpu', help='Device to use for PyTorch')
92 | parser.add_argument("--num_samples", nargs="*", type=int, default=[1], help="Number of samples to generate") # default is a list: callers index num_samples[0]
93 | 
94 | # Robot settings
95 | parser.add_argument("--robots", nargs="*", type=str, default=["planar_chain"], help="Types of robots that are validated")
96 | parser.add_argument("--dofs", nargs="*", type=int, default=[6,8,10,12], help="Numbers of DoF that are validated")
97 | parser.add_argument("--n_evals", type=int, default=100, help="Number of evaluations")
98 | parser.add_argument("--randomize", type=str2bool, default=True, help="Randomize link lengths during test time.")
99 | 
100 | args = parser.parse_args()
101 | return args
102 | 
--------------------------------------------------------------------------------
/generative_graphik/args/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | def str2inttuple(v):
4 | return tuple([int(item) for item in v.split(',')] if v else [])
5 | 
6 | def str2floattuple(v):
7 | return tuple([float(item) for item in v.split(',')] if v else [])
8 | 
9 | def str2tuple(v):
10 | return tuple([item for item in v.split(',')] if v else [])
11 | 
12 | def str2bool(v):
13 | if isinstance(v, bool):
14 | return v
15 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
16 | return True
17 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
18 | return False
19 | else:
20 | raise argparse.ArgumentTypeError('Boolean value expected.')
21 | 
--------------------------------------------------------------------------------
/generative_graphik/model.py:
--------------------------------------------------------------------------------
1 | import random
2 | import time
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | from generative_graphik.networks.eqgraph import EqGraph
8 | from generative_graphik.networks.gatgraph import GATGraph
9 | from generative_graphik.networks.gcngraph import GCNGraph
10 | from generative_graphik.networks.sagegraph import SAGEGraph
11 | from generative_graphik.networks.mpnngraph import MPNNGraph
12 | from generative_graphik.networks.linearvae import LinearVAE
13 | from generative_graphik.utils.torch_utils import (MixtureGaussianDiag,
14 | MultivariateNormalDiag,
15 | kl_divergence,
16 | torch_log_from_T,
17 | repeat_offset_index)
18 | 
19 | class Model(nn.Module):
20 | def __init__(self, args):
21 | super(Model, self).__init__()
22 | 
23 | self.n_beta_scaling_epoch = args.n_beta_scaling_epoch
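# number of epochs over which loss() linearly warms the KL weight up to 1 (beta warm-up)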
24 | self.rec_gain = args.rec_gain 25 | self.num_anchor_nodes = args.num_anchor_nodes 26 | self.train_prior = args.train_prior 27 | self.max_num_iterations = args.num_iterations 28 | 29 | if args.non_linearity == "relu": 30 | non_linearity = nn.ReLU() 31 | elif args.non_linearity == "silu": 32 | non_linearity = nn.SiLU() 33 | elif args.non_linearity == "elu": 34 | non_linearity = nn.ELU() 35 | else: 36 | raise NotImplementedError 37 | 38 | if args.gnn_type == "egnn": 39 | gnn = EqGraph 40 | elif args.gnn_type == "gat": 41 | gnn = GATGraph 42 | elif args.gnn_type == "gcn": 43 | gnn = GCNGraph 44 | elif args.gnn_type == "mpnn": 45 | gnn = MPNNGraph 46 | elif args.gnn_type =="graphsage": 47 | gnn = SAGEGraph 48 | else: 49 | raise NotImplementedError 50 | 51 | self.goal_config_encoder = gnn( 52 | latent_dim=args.dim_latent, 53 | out_channels_node=args.dim_latent_node_out, 54 | coordinates_dim=args.num_coordinates_in, 55 | node_features_dim=args.num_features_in + args.dim_goal, 56 | edge_features_dim=args.num_edge_features_in, 57 | mlp_hidden_size=args.graph_mlp_hidden_size, 58 | num_graph_mlp_layers=args.num_graph_mlp_layers, 59 | num_egnn_mlp_layers=args.num_egnn_mlp_layers, 60 | num_gnn_layers=args.num_gnn_layers, 61 | norm_layer=args.norm_layer, 62 | stochastic=False, 63 | num_mixture_components=1, 64 | non_linearity=non_linearity 65 | ) 66 | 67 | self.inference_encoder = LinearVAE( 68 | dim_in=2 * args.dim_latent_node_out, 69 | dim_out=args.dim_latent_node_out, 70 | norm_layer=args.norm_layer, 71 | hidden_size=args.mlp_hidden_size, 72 | stochastic=True, 73 | non_linearity=non_linearity 74 | ) 75 | self.qz_x_dist = MultivariateNormalDiag 76 | 77 | self.goal_partial_config_encoder = gnn( 78 | latent_dim=args.dim_latent, 79 | out_channels_node=args.dim_latent_node_out, 80 | coordinates_dim=args.num_coordinates_in, 81 | node_features_dim=args.num_features_in + args.dim_goal, 82 | edge_features_dim=args.num_edge_features_in, 83 | mlp_hidden_size=args.graph_mlp_hidden_size, 84 | num_graph_mlp_layers=args.num_graph_mlp_layers, 85 | num_egnn_mlp_layers=args.num_egnn_mlp_layers, 86 | num_gnn_layers=args.num_gnn_layers, 87 | norm_layer=args.norm_layer, 88 | stochastic=False, 89 | num_mixture_components=1, 90 | non_linearity=non_linearity 91 | ) 92 | if self.train_prior: 93 | self.prior_encoder = LinearVAE( 94 | dim_in=args.dim_latent_node_out, 95 | dim_out=args.dim_latent_node_out, 96 | norm_layer=args.norm_layer, 97 | hidden_size=args.mlp_hidden_size, 98 | stochastic=True, 99 | num_mixture_components=args.num_prior_mixture_components, 100 | non_linearity=non_linearity 101 | ) 102 | 103 | if args.num_prior_mixture_components == 1: 104 | self.pz_c_dist = MultivariateNormalDiag 105 | else: 106 | self.pz_c_dist = MixtureGaussianDiag 107 | else: 108 | self.pz_c_dist = MultivariateNormalDiag 109 | 110 | self.decoder = gnn( 111 | latent_dim=args.dim_latent, 112 | out_channels_node=args.num_node_features_out, 113 | coordinates_dim=2 * args.dim_latent_node_out, 114 | node_features_dim=args.num_features_in + args.dim_goal, 115 | edge_features_dim=args.num_edge_features_in, 116 | mlp_hidden_size=args.graph_mlp_hidden_size, 117 | num_graph_mlp_layers=args.num_graph_mlp_layers, 118 | num_egnn_mlp_layers=args.num_egnn_mlp_layers, 119 | num_gnn_layers=args.num_gnn_layers, 120 | norm_layer=args.norm_layer, 121 | stochastic=False, 122 | num_mixture_components=args.num_likelihood_mixture_components, 123 | non_linearity=non_linearity, 124 | ) 125 | 126 | if args.num_likelihood_mixture_components == 1: 127 | 
self.px_z_dist = MultivariateNormalDiag
128 | else:
129 | self.px_z_dist = MixtureGaussianDiag
130 | 
131 | def preprocess(self, data):
132 | data["edge_index_full"] = data.edge_index_full.type(torch.long)
133 | data["edge_index_partial"] = data.edge_index_full[:, data.partial_mask].type(torch.long)
134 | # data["edge_attr_partial"] = data.edge_attr[data.partial_mask]
135 | data["edge_attr_partial"] = data.partial_mask.unsqueeze(-1) * data.edge_attr
136 | data["T_ee"] = torch_log_from_T(data.T_ee)
137 | #XXX: Hacky workaround since LieGroups doesn't handle batch_size=1
138 | if data["T_ee"].dim() == 1:
139 | data["T_ee"] = data["T_ee"].unsqueeze(0)
140 | dim = data.T_ee.shape[-1]//3 + 1
141 | data["goal_data_repeated_per_node"] = torch.repeat_interleave(data.T_ee, (dim-1)*data.num_joints + self.num_anchor_nodes, dim=0)
142 | return data
143 | 
144 | def loss(self, res, epoch, batch_size, goal_pos, partial_goal_mask):
145 | mu_x_sample = res["mu_x_sample"]
146 | partial_non_goal_mask = torch.ones_like(partial_goal_mask) - partial_goal_mask
147 | beta_kl = min(((epoch + 1) / self.n_beta_scaling_epoch), 1.0)
148 | stats = {}
149 | 
150 | # Point loss
151 | loss_anchor = torch.sum((partial_goal_mask[:,None]*(mu_x_sample - goal_pos))**2) / (batch_size)
152 | loss_non_anchor = torch.sum(partial_non_goal_mask[:,None]*((mu_x_sample - goal_pos)**2)) / (batch_size)
153 | loss_rec_pos_opt = loss_anchor + self.rec_gain * loss_non_anchor
154 | loss_rec_pos = torch.sum((mu_x_sample - goal_pos)**2) / (batch_size)
155 | stats["rec_pos_l"] = loss_rec_pos.item()
156 | 
157 | # Distance loss
158 | # src, dst = res["edge_index_partial"]
159 | # # src, dst = res["edge_index_full"]
160 | # dist_samples = ((mu_x_sample[src] - mu_x_sample[dst])**2).sum(dim=-1).sqrt()
161 | # dist = ((goal_pos[src] - goal_pos[dst])**2).sum(dim=-1).sqrt()
162 | # loss_rec_dist = torch.sum((dist_samples - dist)**2) / (batch_size)
163 | # stats["rec_dist_l"] = loss_rec_dist.item()
164 | 
165 | qz_xc = res["qz_xc"]
166 | pz_c = res["pz_c"]
167 | loss_kl = torch.sum(kl_divergence(qz_xc, pz_c)) / (batch_size)
168 | stats["kl_l"] = loss_kl.item()
169 | 
170 | # Point loss
171 | loss = loss_rec_pos + loss_kl
172 | loss_opt = loss_rec_pos_opt + beta_kl * loss_kl # + loss_sc
173 | 
174 | # Distance loss
175 | # loss = loss_rec_dist + loss_kl
176 | # loss_opt = self.rec_gain * loss_rec_dist + beta_kl * loss_kl
177 | 
178 | stats["total_l"] = loss.item()
179 | return loss_opt, stats
180 | 
181 | def forward(self, x, h, edge_attr, edge_attr_partial, edge_index, partial_goal_mask):
182 | # Goal T_g encoder, all edge attributes (distances)
183 | z_goal = self.goal_config_encoder(
184 | x=x,
185 | h=h,
186 | edge_attr=edge_attr,
187 | edge_index=edge_index,
188 | )
189 | 
190 | # AlphaFold style sampling of iterations to encourage fast convergence
191 | num_iterations = random.randint(1, self.max_num_iterations)
192 | 
193 | for _ in range(num_iterations):
194 | 
195 | # unknown distances and positions transformed to 0
196 | z_goal_partial = self.goal_partial_config_encoder(
197 | x=partial_goal_mask[:, None] * x,
198 | h=h,
199 | edge_attr=edge_attr_partial,
200 | edge_index=edge_index,
201 | )
202 | 
203 | # Encode conditional prior p(z | c)
204 | if self.train_prior:
205 | params = self.prior_encoder(z_goal_partial)
206 | pz_c = self.pz_c_dist(*params)
207 | else:
208 | pz_c = self.pz_c_dist(
209 | loc=torch.zeros_like(z_goal_partial), # standard-normal prior; z_goal_partial only supplies the shape and device
210 | scale=torch.ones_like(z_goal_partial)
211 | )
212 | 
213 | # Encode inference distribution q(z | x, c)
214 | inp = torch.cat((
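# per-node concatenation of the full-goal embedding (x) and the partial-goal conditioning embedding (c)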
215 | z_goal, 216 | z_goal_partial, 217 | ), dim=-1) 218 | params = self.inference_encoder(inp) 219 | qz_xc = self.qz_x_dist(*params) 220 | z = qz_xc.rsample() 221 | 222 | # Decode distribution p(x | z, c) 223 | inp_decoder = torch.cat(( 224 | z, 225 | z_goal_partial, 226 | ), dim=-1) 227 | mu_x_sample = self.decoder( 228 | x=inp_decoder, 229 | h=h, 230 | edge_attr=0.0 * edge_attr, 231 | edge_index=edge_index, 232 | ) 233 | 234 | # Decode distribution p(x | z, c) if we're going to iterate 235 | if self.max_num_iterations > 1 and self.train_prior: 236 | z_prior = pz_c.sample() 237 | inp_decoder_prior = torch.cat(( 238 | z_prior, 239 | z_goal_partial 240 | ), dim=-1) 241 | mu_x_sample_prior = self.decoder( 242 | x=inp_decoder_prior, 243 | h=h, 244 | edge_attr=0.0 * edge_attr, 245 | edge_index=edge_index, 246 | ) 247 | nodes = mu_x_sample_prior 248 | src, dst = edge_index 249 | edges = ((nodes[src] - nodes[dst])**2).sum(dim=-1).sqrt() 250 | edges = edges.unsqueeze(-1) 251 | 252 | return { 253 | "mu_x_sample": mu_x_sample, 254 | "qz_xc": qz_xc, 255 | "pz_c": pz_c 256 | } 257 | 258 | def forward_eval(self, x, h, edge_attr, edge_attr_partial, edge_index, partial_goal_mask, nodes_per_single_graph, num_samples, batch_size): 259 | for ii in range(self.max_num_iterations): 260 | with torch.no_grad(): 261 | # unknown distances and positions transformed to 0 262 | # tic = time.time() 263 | # torch.cuda.synchronize() 264 | z_goal_partial = self.goal_partial_config_encoder( 265 | x=partial_goal_mask[:, None] * x, 266 | h=h, 267 | edge_attr=edge_attr_partial, 268 | edge_index=edge_index, 269 | ) 270 | # torch.cuda.synchronize() 271 | # print(f"Goal encoder time: {time.time() - tic}") 272 | 273 | # Encode conditional prior p(z | c) 274 | # tic = time.time() 275 | # torch.cuda.synchronize() 276 | if self.train_prior: 277 | params = self.prior_encoder(z_goal_partial) 278 | pz_c = self.pz_c_dist(*params) 279 | else: 280 | pz_c = self.pz_c_dist( 281 | loc=torch.zeros((batch_size * nodes_per_single_graph, z_goal_partial.shape[-1])).to(device=z_goal_partial.device), 282 | scale=torch.ones((batch_size * nodes_per_single_graph, z_goal_partial.shape[-1])).to(device=z_goal_partial.device), 283 | ) 284 | # torch.cuda.synchronize() 285 | # print(f"Prior encoder time: {time.time() - tic}") 286 | 287 | # tic = time.time() 288 | # torch.cuda.synchronize() 289 | # Repeat data num_samples times 290 | if ii == 0: 291 | z_prior = pz_c.sample([num_samples]) 292 | z_prior = z_prior.reshape(-1, z_prior.shape[-1]) 293 | z_goal_partial = z_goal_partial.unsqueeze(0).expand(num_samples, -1, -1) 294 | z_goal_partial = z_goal_partial.reshape(-1, z_goal_partial.shape[-1]) 295 | h = h.unsqueeze(0).expand(num_samples,-1,-1) 296 | h = h.reshape(-1, h.shape[-1]) 297 | data_index = edge_index 298 | data_index = repeat_offset_index(data_index, num_samples, nodes_per_single_graph) 299 | data_index = data_index.reshape(data_index.shape[0], -1) 300 | data_edge_attr = edge_attr.unsqueeze(0).expand(num_samples, -1, -1) 301 | data_edge_attr = data_edge_attr.reshape(-1, data_edge_attr.shape[-1]) 302 | else: 303 | z_prior = pz_c.sample() 304 | # torch.cuda.synchronize() 305 | # print(f"Sampling time: {time.time() - tic}") 306 | 307 | # tic = time.time() 308 | # torch.cuda.synchronize() 309 | # Decode distribution p(x | z, c) 310 | inp_decoder_prior = torch.cat(( 311 | z_prior, 312 | z_goal_partial 313 | ), dim=-1) 314 | mu_x_sample = self.decoder( 315 | x=inp_decoder_prior, 316 | h=h, 317 | edge_attr=0.0 * data_edge_attr, 318 | 
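                    # distances are deliberately zeroed out: the decoder must infer
                    # geometry from the latent sample and the partial-goal embedding alone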
edge_index=data_index, 319 | ) 320 | # torch.cuda.synchronize() 321 | # print(f"Decoder time: {time.time() - tic}") 322 | 323 | if self.max_num_iterations > 1 and self.train_prior: 324 | nodes = mu_x_sample 325 | src, dst = data_index 326 | edges = ((nodes[src] - nodes[dst])**2).sum(dim=-1).sqrt() 327 | edges = edges.unsqueeze(-1) 328 | 329 | mu_x_sample = mu_x_sample.reshape(num_samples, batch_size * nodes_per_single_graph, -1) 330 | return mu_x_sample 331 | -------------------------------------------------------------------------------- /generative_graphik/networks/egnn.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from torch_geometric.utils import degree 7 | from torch_geometric.nn import MessagePassing 8 | from torch_geometric.typing import OptTensor 9 | 10 | from generative_graphik.utils.torch_utils import get_norm_layer 11 | 12 | class ResWrapper(nn.Module): 13 | def __init__(self, module, dim_res=2): 14 | super(ResWrapper, self).__init__() 15 | self.module = module 16 | self.dim_res = dim_res 17 | 18 | def forward(self, x): 19 | res = x[:, :self.dim_res] 20 | out = self.module(x) 21 | return out + res 22 | 23 | class EGNNLayer(MessagePassing): 24 | def __init__( 25 | self, 26 | non_linearity, 27 | channels_h, 28 | channels_m, 29 | channels_a, 30 | aggr: str = 'add', 31 | norm_layer: str = 'None', 32 | hidden_channels: int = 64, 33 | mlp_layers=2, 34 | **kwargs 35 | ): 36 | super(EGNNLayer, self).__init__(aggr=aggr, **kwargs) 37 | 38 | self.m_len = channels_m 39 | 40 | phi_e_layers = [] 41 | phi_e_layers.extend([ 42 | nn.Linear(2 * channels_h + 1 + channels_a, hidden_channels), 43 | get_norm_layer(hidden_channels, layer_type=norm_layer, layer_dim="1d"), 44 | non_linearity 45 | ]) 46 | for _ in range(mlp_layers-2): 47 | phi_e_layers.extend([ 48 | nn.Linear(hidden_channels, hidden_channels), 49 | get_norm_layer(hidden_channels, layer_type=norm_layer, layer_dim="1d"), 50 | non_linearity 51 | ]) 52 | phi_e_layers.extend([ 53 | nn.Linear(hidden_channels, channels_m), 54 | get_norm_layer(channels_m, layer_type=norm_layer, layer_dim="1d"), 55 | non_linearity 56 | ]) 57 | self.phi_e = nn.Sequential(*phi_e_layers) 58 | 59 | phi_x_layers = [] 60 | phi_x_layers.extend([ 61 | nn.Linear(channels_m, hidden_channels), 62 | get_norm_layer(hidden_channels, layer_type=norm_layer, layer_dim="1d"), 63 | non_linearity 64 | ]) 65 | for _ in range(mlp_layers-2): 66 | phi_x_layers.extend([ 67 | nn.Linear(hidden_channels, hidden_channels), 68 | get_norm_layer(hidden_channels, layer_type=norm_layer, layer_dim="1d"), 69 | non_linearity 70 | ]) 71 | phi_x_layers.append(nn.Linear(hidden_channels, 1)) 72 | self.phi_x = nn.Sequential(*phi_x_layers) 73 | 74 | phi_h_layers = [] 75 | phi_h_layers.extend([ 76 | nn.Linear(channels_h + channels_m, hidden_channels), 77 | get_norm_layer(hidden_channels, layer_type=norm_layer, layer_dim="1d"), 78 | non_linearity 79 | ]) 80 | for _ in range(mlp_layers-2): 81 | phi_h_layers.extend([ 82 | nn.Linear(hidden_channels, hidden_channels), 83 | get_norm_layer(hidden_channels, layer_type=norm_layer, layer_dim="1d"), 84 | non_linearity 85 | ]) 86 | phi_h_layers.append(nn.Linear(hidden_channels, channels_h)) 87 | self.phi_h = nn.Sequential(*phi_h_layers) 88 | self.phi_h = ResWrapper(self.phi_h, dim_res=channels_h) 89 | 90 | def forward(self, x: Tensor, h: Tensor, edge_attr: Tensor, edge_index: Tensor, c: OptTensor=None) -> Tensor: 91 | if c 
is None: 92 | c = degree(edge_index[0], x.shape[0]).unsqueeze(-1) 93 | # propagate_type: (x: Tensor, h: Tensor, edge_attr: Tensor, c: OptTensor) 94 | return self.propagate(edge_index=edge_index, x=x, h=h, edge_attr=edge_attr, c=c) 95 | 96 | def message(self, x_i: Tensor, x_j: Tensor, h_i: Tensor, h_j: Tensor, edge_attr: Tensor) -> Tensor: 97 | mh_ij = self.phi_e(torch.cat([h_i, h_j, torch.norm(x_i - x_j, dim=-1, keepdim=True)**2, edge_attr], dim=-1)) 98 | mx_ij = (x_i - x_j) * self.phi_x(mh_ij) 99 | return torch.cat((mx_ij, mh_ij), dim=-1) 100 | 101 | def update(self, aggr_out: Tensor, x: Tensor, h: Tensor, edge_attr: Tensor, c: Tensor) -> Tuple[Tensor, Tensor]: 102 | m_x, m_h = aggr_out[:, :self.m_len], aggr_out[:, self.m_len:] 103 | h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1)) 104 | x_l1 = x + (m_x / c) 105 | return x_l1, h_l1 106 | -------------------------------------------------------------------------------- /generative_graphik/networks/eqgraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from generative_graphik.utils.torch_utils import get_norm_layer 5 | from generative_graphik.networks.egnn import EGNNLayer 6 | 7 | class EqGraph(nn.Module): 8 | def __init__( 9 | self, 10 | latent_dim=32, 11 | out_channels_node=3, 12 | coordinates_dim=3, 13 | node_features_dim=5, 14 | edge_features_dim=1, 15 | mlp_hidden_size=64, 16 | num_graph_mlp_layers=0, 17 | num_egnn_mlp_layers=2, 18 | num_gnn_layers=3, 19 | stochastic=False, 20 | num_mixture_components=8, 21 | norm_layer='None', 22 | non_linearity=nn.SiLU(), 23 | ) -> None: 24 | super(EqGraph, self).__init__() 25 | self.stochastic = stochastic 26 | self.num_mixture_components = num_mixture_components 27 | self.num_gnn_layers = num_gnn_layers 28 | self.dim_out = out_channels_node 29 | 30 | mlp_x_list = [] 31 | mlp_x_list.extend([ 32 | nn.Linear(coordinates_dim, mlp_hidden_size), 33 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 34 | non_linearity 35 | ]) 36 | for _ in range(num_graph_mlp_layers): 37 | mlp_x_list.extend([ 38 | nn.Linear(mlp_hidden_size, mlp_hidden_size), 39 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 40 | non_linearity 41 | ]) 42 | mlp_x_list.append(nn.Linear(mlp_hidden_size, latent_dim)) 43 | self.mlp_x = nn.Sequential(*mlp_x_list) 44 | 45 | mlp_h_list = [] 46 | mlp_h_list.extend([ 47 | nn.Linear(node_features_dim, mlp_hidden_size), 48 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 49 | non_linearity 50 | ]) 51 | for _ in range(num_graph_mlp_layers): 52 | mlp_h_list.extend([ 53 | nn.Linear(mlp_hidden_size, mlp_hidden_size), 54 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 55 | non_linearity 56 | ]) 57 | mlp_h_list.append(nn.Linear(mlp_hidden_size, latent_dim)) 58 | self.mlp_h = nn.Sequential(*mlp_h_list) 59 | 60 | self.gnn = nn.ModuleList() 61 | for _ in range(num_gnn_layers): 62 | self.gnn.append( 63 | EGNNLayer( 64 | channels_h=latent_dim, 65 | channels_m=latent_dim, 66 | channels_a=edge_features_dim, 67 | hidden_channels=mlp_hidden_size, 68 | norm_layer=norm_layer, 69 | non_linearity=non_linearity, 70 | mlp_layers=num_egnn_mlp_layers 71 | ) 72 | ) 73 | 74 | # Final encoder layer that outputs mean and stddev 75 | self.fc_mu = nn.Sequential( 76 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 77 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 78 | non_linearity, 79 | nn.Linear(mlp_hidden_size, 
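            # num_mixture_components stacked means, flattened here and reshaped
            # to (n, num_mixture_components, dim_out) in forward()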
self.num_mixture_components * out_channels_node), 80 | ) 81 | 82 | if self.stochastic: 83 | self.fc_logvar = nn.Sequential( 84 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 85 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 86 | non_linearity, 87 | nn.Linear(mlp_hidden_size, self.num_mixture_components * out_channels_node), 88 | ) 89 | 90 | if self.num_mixture_components > 1: 91 | self.fc_mixture = nn.Sequential( 92 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 93 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 94 | non_linearity, 95 | nn.Linear(mlp_hidden_size, self.num_mixture_components), 96 | nn.Softmax(dim=1) 97 | ) 98 | 99 | def forward(self, x, h, edge_attr, edge_index, c=None): 100 | n = x.shape[0] 101 | x_l = self.mlp_x(x) 102 | h_l = self.mlp_h(h) 103 | 104 | for ii in range(self.num_gnn_layers): 105 | x_l1, h_l1 = self.gnn[ii]( 106 | x=x_l, 107 | h=h_l, 108 | edge_attr=edge_attr, 109 | edge_index=edge_index, 110 | c=c 111 | ) 112 | x_l = x_l1 113 | h_l = h_l1 114 | v = torch.cat([x_l, h_l], dim=-1) 115 | 116 | mu = self.fc_mu(v) 117 | mu = mu.reshape(n, self.num_mixture_components, self.dim_out) 118 | if self.stochastic: 119 | logvar = self.fc_logvar(v) 120 | std = torch.exp(logvar / 2.0) 121 | std = std.reshape(n, self.num_mixture_components, self.dim_out) 122 | if self.num_mixture_components > 1: 123 | k = self.fc_mixture(v) 124 | return (k, mu, std) 125 | else: 126 | return (mu.squeeze(1), std.squeeze(1)) 127 | else: 128 | return mu.squeeze(1) -------------------------------------------------------------------------------- /generative_graphik/networks/gatgraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from generative_graphik.utils.torch_utils import get_norm_layer 5 | from torch_geometric.nn import GATConv 6 | 7 | class GATGraph(nn.Module): 8 | def __init__( 9 | self, 10 | latent_dim=32, 11 | out_channels_node=3, 12 | coordinates_dim=3, 13 | node_features_dim=5, 14 | edge_features_dim=1, 15 | mlp_hidden_size=64, 16 | num_graph_mlp_layers=0, 17 | num_gnn_layers=3, 18 | stochastic=False, 19 | num_mixture_components=8, 20 | norm_layer='None', 21 | non_linearity=nn.SiLU(), 22 | **kwargs 23 | ) -> None: 24 | super(GATGraph, self).__init__() 25 | self.stochastic = stochastic 26 | self.num_mixture_components = num_mixture_components 27 | self.num_gnn_layers = num_gnn_layers 28 | self.dim_out = out_channels_node 29 | 30 | mlp_x_list = [] 31 | mlp_x_list.extend([ 32 | nn.Linear(coordinates_dim, mlp_hidden_size), 33 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 34 | non_linearity 35 | ]) 36 | for _ in range(num_graph_mlp_layers): 37 | mlp_x_list.extend([ 38 | nn.Linear(mlp_hidden_size, mlp_hidden_size), 39 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 40 | non_linearity 41 | ]) 42 | mlp_x_list.append(nn.Linear(mlp_hidden_size, latent_dim)) 43 | self.mlp_x = nn.Sequential(*mlp_x_list) 44 | 45 | mlp_h_list = [] 46 | mlp_h_list.extend([ 47 | nn.Linear(node_features_dim, mlp_hidden_size), 48 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 49 | non_linearity 50 | ]) 51 | for _ in range(num_graph_mlp_layers): 52 | mlp_h_list.extend([ 53 | nn.Linear(mlp_hidden_size, mlp_hidden_size), 54 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 55 | non_linearity 56 | ]) 57 | mlp_h_list.append(nn.Linear(mlp_hidden_size, latent_dim)) 58 
| self.mlp_h = nn.Sequential(*mlp_h_list) 59 | 60 | self.gnn = nn.ModuleList() 61 | for _ in range(num_gnn_layers): 62 | self.gnn.append( 63 | GATConv( 64 | latent_dim + latent_dim, 65 | latent_dim + latent_dim, 66 | edge_dim=edge_features_dim 67 | ) 68 | ) 69 | 70 | # Final encoder layer that outputs mean and stddev 71 | self.fc_mu = nn.Sequential( 72 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 73 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 74 | non_linearity, 75 | nn.Linear(mlp_hidden_size, self.num_mixture_components * out_channels_node), 76 | ) 77 | 78 | if self.stochastic: 79 | self.fc_logvar = nn.Sequential( 80 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 81 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 82 | non_linearity, 83 | nn.Linear(mlp_hidden_size, self.num_mixture_components * out_channels_node), 84 | ) 85 | 86 | if self.num_mixture_components > 1: 87 | self.fc_mixture = nn.Sequential( 88 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 89 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 90 | non_linearity, 91 | nn.Linear(mlp_hidden_size, self.num_mixture_components), 92 | nn.Softmax(dim=1) 93 | ) 94 | 95 | def forward(self, x, h, edge_attr, edge_index): 96 | n = x.shape[0] 97 | x_l = self.mlp_x(x) 98 | h_l = self.mlp_h(h) 99 | 100 | inp = torch.cat((x_l, h_l), dim=-1) 101 | for ii in range(self.num_gnn_layers): 102 | x_l1 = self.gnn[ii]( 103 | x=inp, 104 | edge_index=edge_index, 105 | edge_attr=edge_attr 106 | ) 107 | inp = x_l1 108 | 109 | mu = self.fc_mu(inp) 110 | mu = mu.reshape(n, self.num_mixture_components, self.dim_out) 111 | if self.stochastic: 112 | logvar = self.fc_logvar(inp) 113 | std = torch.exp(logvar / 2.0) 114 | std = std.reshape(n, self.num_mixture_components, self.dim_out) 115 | if self.num_mixture_components > 1: 116 | k = self.fc_mixture(inp) 117 | return (k, mu, std) 118 | else: 119 | return (mu.squeeze(1), std.squeeze(1)) 120 | else: 121 | return mu.squeeze(1) -------------------------------------------------------------------------------- /generative_graphik/networks/gcngraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from generative_graphik.utils.torch_utils import get_norm_layer 5 | from torch_geometric.nn import GCNConv 6 | 7 | class GCNGraph(nn.Module): 8 | def __init__( 9 | self, 10 | latent_dim=32, 11 | out_channels_node=3, 12 | coordinates_dim=3, 13 | node_features_dim=5, 14 | edge_features_dim=1, 15 | mlp_hidden_size=64, 16 | num_graph_mlp_layers=0, 17 | num_gnn_layers=3, 18 | stochastic=False, 19 | num_mixture_components=8, 20 | norm_layer='None', 21 | non_linearity=nn.SiLU(), 22 | **kwargs 23 | ) -> None: 24 | super(GCNGraph, self).__init__() 25 | self.stochastic = stochastic 26 | self.num_mixture_components = num_mixture_components 27 | self.num_gnn_layers = num_gnn_layers 28 | self.dim_out = out_channels_node 29 | 30 | mlp_x_list = [] 31 | mlp_x_list.extend([ 32 | nn.Linear(coordinates_dim, mlp_hidden_size), 33 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 34 | non_linearity 35 | ]) 36 | for _ in range(num_graph_mlp_layers): 37 | mlp_x_list.extend([ 38 | nn.Linear(mlp_hidden_size, mlp_hidden_size), 39 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 40 | non_linearity 41 | ]) 42 | mlp_x_list.append(nn.Linear(mlp_hidden_size, latent_dim)) 43 | self.mlp_x = nn.Sequential(*mlp_x_list) 44 | 45 
| mlp_h_list = [] 46 | mlp_h_list.extend([ 47 | nn.Linear(node_features_dim, mlp_hidden_size), 48 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 49 | non_linearity 50 | ]) 51 | for _ in range(num_graph_mlp_layers): 52 | mlp_h_list.extend([ 53 | nn.Linear(mlp_hidden_size, mlp_hidden_size), 54 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 55 | non_linearity 56 | ]) 57 | mlp_h_list.append(nn.Linear(mlp_hidden_size, latent_dim)) 58 | self.mlp_h = nn.Sequential(*mlp_h_list) 59 | 60 | self.gnn = nn.ModuleList() 61 | for _ in range(num_gnn_layers): 62 | self.gnn.append( 63 | GCNConv(latent_dim + latent_dim, latent_dim + latent_dim, improved=True) 64 | ) 65 | 66 | # Final encoder layer that outputs mean and stddev 67 | self.fc_mu = nn.Sequential( 68 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 69 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 70 | non_linearity, 71 | nn.Linear(mlp_hidden_size, self.num_mixture_components * out_channels_node), 72 | ) 73 | 74 | if self.stochastic: 75 | self.fc_logvar = nn.Sequential( 76 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 77 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 78 | non_linearity, 79 | nn.Linear(mlp_hidden_size, self.num_mixture_components * out_channels_node), 80 | ) 81 | 82 | if self.num_mixture_components > 1: 83 | self.fc_mixture = nn.Sequential( 84 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 85 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 86 | non_linearity, 87 | nn.Linear(mlp_hidden_size, self.num_mixture_components), 88 | nn.Softmax(dim=1) 89 | ) 90 | 91 | def forward(self, x, h, edge_attr, edge_index): 92 | n = x.shape[0] 93 | x_l = self.mlp_x(x) 94 | h_l = self.mlp_h(h) 95 | 96 | inp = torch.cat((x_l, h_l), dim=-1) 97 | for ii in range(self.num_gnn_layers): 98 | x_l1 = self.gnn[ii]( 99 | x=inp, 100 | edge_index=edge_index, 101 | ) 102 | inp = x_l1 103 | 104 | mu = self.fc_mu(inp) 105 | mu = mu.reshape(n, self.num_mixture_components, self.dim_out) 106 | if self.stochastic: 107 | logvar = self.fc_logvar(inp) 108 | std = torch.exp(logvar / 2.0) 109 | std = std.reshape(n, self.num_mixture_components, self.dim_out) 110 | if self.num_mixture_components > 1: 111 | k = self.fc_mixture(inp) 112 | return (k, mu, std) 113 | else: 114 | return (mu.squeeze(1), std.squeeze(1)) 115 | else: 116 | return mu.squeeze(1) -------------------------------------------------------------------------------- /generative_graphik/networks/linearvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from generative_graphik.utils.torch_utils import Flatten, get_norm_layer 5 | 6 | class LinearVAE(nn.Module): 7 | def __init__( 8 | self, 9 | dim_in, 10 | dim_out, 11 | norm_layer='None', 12 | dropout=False, 13 | non_linearity=nn.SiLU(), 14 | hidden_size=256, 15 | stochastic=False, 16 | num_mixture_components=1 17 | ): 18 | super(LinearVAE, self).__init__() 19 | self.flatten = Flatten() 20 | self.stochastic = stochastic 21 | self.layers = nn.ModuleList() 22 | self.num_mixture_components = num_mixture_components 23 | self.dim_out = dim_out 24 | 25 | self.layers.append(nn.Linear(dim_in, hidden_size)) 26 | self.layers.append(get_norm_layer(hidden_size, layer_type=norm_layer, layer_dim="1d")) 27 | if dropout: self.layers.append(nn.Dropout(p=0.5)) 28 | self.layers.append(non_linearity) 29 | 30 | self.fc_mu = nn.Sequential( 31 | 
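            # mean head; when stochastic=True, a parallel fc_logvar head below
            # yields the standard deviation, and fc_mixture adds mixture weights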
nn.Linear(hidden_size, hidden_size), 32 | get_norm_layer(hidden_size, layer_type=norm_layer, layer_dim="1d"), 33 | non_linearity, 34 | nn.Linear(hidden_size, num_mixture_components * dim_out), 35 | ) 36 | if self.stochastic: 37 | self.fc_logvar = nn.Sequential( 38 | nn.Linear(hidden_size, hidden_size), 39 | get_norm_layer(hidden_size, layer_type=norm_layer, layer_dim="1d"), 40 | non_linearity, 41 | nn.Linear(hidden_size, num_mixture_components * dim_out) 42 | ) 43 | 44 | if self.num_mixture_components > 1: 45 | self.fc_mixture = nn.Sequential( 46 | nn.Linear(hidden_size, hidden_size), 47 | get_norm_layer(hidden_size, layer_type=norm_layer, layer_dim="1d"), 48 | non_linearity, 49 | nn.Linear(hidden_size, num_mixture_components), 50 | nn.Softmax(dim=1) 51 | ) 52 | 53 | def forward(self, x): 54 | n = x.shape[0] 55 | for l in self.layers: 56 | x = l(x) 57 | 58 | mu = self.fc_mu(x) 59 | mu = mu.reshape(n, self.num_mixture_components, self.dim_out) 60 | if self.stochastic: 61 | logvar = self.fc_logvar(x) 62 | std = torch.exp(logvar / 2.0) 63 | std = std.reshape(n, self.num_mixture_components, self.dim_out) 64 | if self.num_mixture_components > 1: 65 | k = self.fc_mixture(x) 66 | return (k, mu, std) 67 | else: 68 | return (mu.squeeze(1), std.squeeze(1)) 69 | else: 70 | return mu.squeeze(1) 71 | -------------------------------------------------------------------------------- /generative_graphik/networks/mpnngraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from generative_graphik.utils.torch_utils import get_norm_layer 5 | from torch_geometric.nn import NNConv 6 | 7 | class MPNNGraph(nn.Module): 8 | def __init__( 9 | self, 10 | latent_dim=32, 11 | out_channels_node=3, 12 | coordinates_dim=3, 13 | node_features_dim=5, 14 | edge_features_dim=1, 15 | mlp_hidden_size=64, 16 | num_graph_mlp_layers=0, 17 | num_gnn_layers=3, 18 | stochastic=False, 19 | num_mixture_components=8, 20 | norm_layer='None', 21 | non_linearity=nn.SiLU(), 22 | **kwargs 23 | ) -> None: 24 | super(MPNNGraph, self).__init__() 25 | self.stochastic = stochastic 26 | self.num_mixture_components = num_mixture_components 27 | self.num_gnn_layers = num_gnn_layers 28 | self.dim_out = out_channels_node 29 | 30 | mlp_x_list = [] 31 | mlp_x_list.extend([ 32 | nn.Linear(coordinates_dim, mlp_hidden_size), 33 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 34 | non_linearity 35 | ]) 36 | for _ in range(num_graph_mlp_layers): 37 | mlp_x_list.extend([ 38 | nn.Linear(mlp_hidden_size, mlp_hidden_size), 39 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 40 | non_linearity 41 | ]) 42 | mlp_x_list.append(nn.Linear(mlp_hidden_size, latent_dim)) 43 | self.mlp_x = nn.Sequential(*mlp_x_list) 44 | 45 | mlp_h_list = [] 46 | mlp_h_list.extend([ 47 | nn.Linear(node_features_dim, mlp_hidden_size), 48 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 49 | non_linearity 50 | ]) 51 | for _ in range(num_graph_mlp_layers): 52 | mlp_h_list.extend([ 53 | nn.Linear(mlp_hidden_size, mlp_hidden_size), 54 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 55 | non_linearity 56 | ]) 57 | mlp_h_list.append(nn.Linear(mlp_hidden_size, latent_dim)) 58 | self.mlp_h = nn.Sequential(*mlp_h_list) 59 | 60 | mlp = nn.Sequential( 61 | nn.Linear(edge_features_dim, edge_features_dim), 62 | non_linearity, 63 | nn.Linear(edge_features_dim, (latent_dim + latent_dim) * (latent_dim + latent_dim)) 64 | 
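            # maps each scalar edge feature to a full (2*latent_dim) x (2*latent_dim)
            # weight matrix for NNConv; note the same MLP instance is shared by
            # every NNConv layer below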
) 65 | self.gnn = nn.ModuleList() 66 | for _ in range(num_gnn_layers): 67 | self.gnn.append( 68 | NNConv(latent_dim + latent_dim, latent_dim + latent_dim, mlp) 69 | ) 70 | 71 | # Final encoder layer that outputs mean and stddev 72 | self.fc_mu = nn.Sequential( 73 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 74 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 75 | non_linearity, 76 | nn.Linear(mlp_hidden_size, self.num_mixture_components * out_channels_node), 77 | ) 78 | 79 | if self.stochastic: 80 | self.fc_logvar = nn.Sequential( 81 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 82 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 83 | non_linearity, 84 | nn.Linear(mlp_hidden_size, self.num_mixture_components * out_channels_node), 85 | ) 86 | 87 | if self.num_mixture_components > 1: 88 | self.fc_mixture = nn.Sequential( 89 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 90 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 91 | non_linearity, 92 | nn.Linear(mlp_hidden_size, self.num_mixture_components), 93 | nn.Softmax(dim=1) 94 | ) 95 | 96 | def forward(self, x, h, edge_attr, edge_index): 97 | n = x.shape[0] 98 | x_l = self.mlp_x(x) 99 | h_l = self.mlp_h(h) 100 | 101 | inp = torch.cat((x_l, h_l), dim=-1) 102 | for ii in range(self.num_gnn_layers): 103 | x_l1 = self.gnn[ii]( 104 | x=inp, 105 | edge_index=edge_index, 106 | edge_attr=edge_attr 107 | ) 108 | inp = x_l1 109 | 110 | mu = self.fc_mu(inp) 111 | mu = mu.reshape(n, self.num_mixture_components, self.dim_out) 112 | if self.stochastic: 113 | logvar = self.fc_logvar(inp) 114 | std = torch.exp(logvar / 2.0) 115 | std = std.reshape(n, self.num_mixture_components, self.dim_out) 116 | if self.num_mixture_components > 1: 117 | k = self.fc_mixture(inp) 118 | return (k, mu, std) 119 | else: 120 | return (mu.squeeze(1), std.squeeze(1)) 121 | else: 122 | return mu.squeeze(1) -------------------------------------------------------------------------------- /generative_graphik/networks/sagegraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from generative_graphik.utils.torch_utils import get_norm_layer 5 | from torch_geometric.nn import SAGEConv 6 | 7 | class SAGEGraph(nn.Module): 8 | def __init__( 9 | self, 10 | latent_dim=32, 11 | out_channels_node=3, 12 | coordinates_dim=3, 13 | node_features_dim=5, 14 | edge_features_dim=1, 15 | mlp_hidden_size=64, 16 | num_graph_mlp_layers=0, 17 | num_gnn_layers=3, 18 | stochastic=False, 19 | num_mixture_components=8, 20 | norm_layer='None', 21 | non_linearity=nn.SiLU(), 22 | **kwargs 23 | ) -> None: 24 | super(SAGEGraph, self).__init__() 25 | self.stochastic = stochastic 26 | self.num_mixture_components = num_mixture_components 27 | self.num_gnn_layers = num_gnn_layers 28 | self.dim_out = out_channels_node 29 | 30 | mlp_x_list = [] 31 | mlp_x_list.extend([ 32 | nn.Linear(coordinates_dim, mlp_hidden_size), 33 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 34 | non_linearity 35 | ]) 36 | for _ in range(num_graph_mlp_layers): 37 | mlp_x_list.extend([ 38 | nn.Linear(mlp_hidden_size, mlp_hidden_size), 39 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 40 | non_linearity 41 | ]) 42 | mlp_x_list.append(nn.Linear(mlp_hidden_size, latent_dim)) 43 | self.mlp_x = nn.Sequential(*mlp_x_list) 44 | 45 | mlp_h_list = [] 46 | mlp_h_list.extend([ 47 | nn.Linear(node_features_dim, 
mlp_hidden_size), 48 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 49 | non_linearity 50 | ]) 51 | for _ in range(num_graph_mlp_layers): 52 | mlp_h_list.extend([ 53 | nn.Linear(mlp_hidden_size, mlp_hidden_size), 54 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 55 | non_linearity 56 | ]) 57 | mlp_h_list.append(nn.Linear(mlp_hidden_size, latent_dim)) 58 | self.mlp_h = nn.Sequential(*mlp_h_list) 59 | 60 | self.gnn = nn.ModuleList() 61 | for _ in range(num_gnn_layers): 62 | self.gnn.append( 63 | SAGEConv(latent_dim + latent_dim, latent_dim + latent_dim) 64 | ) 65 | 66 | # Final encoder layer that outputs mean and stddev 67 | self.fc_mu = nn.Sequential( 68 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 69 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 70 | non_linearity, 71 | nn.Linear(mlp_hidden_size, self.num_mixture_components * out_channels_node), 72 | ) 73 | 74 | if self.stochastic: 75 | self.fc_logvar = nn.Sequential( 76 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 77 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 78 | non_linearity, 79 | nn.Linear(mlp_hidden_size, self.num_mixture_components * out_channels_node), 80 | ) 81 | 82 | if self.num_mixture_components > 1: 83 | self.fc_mixture = nn.Sequential( 84 | nn.Linear(latent_dim + latent_dim, mlp_hidden_size), 85 | get_norm_layer(mlp_hidden_size, layer_type=norm_layer, layer_dim="1d"), 86 | non_linearity, 87 | nn.Linear(mlp_hidden_size, self.num_mixture_components), 88 | nn.Softmax(dim=1) 89 | ) 90 | 91 | def forward(self, x, h, edge_attr, edge_index): 92 | n = x.shape[0] 93 | x_l = self.mlp_x(x) 94 | h_l = self.mlp_h(h) 95 | 96 | inp = torch.cat((x_l, h_l), dim=-1) 97 | for ii in range(self.num_gnn_layers): 98 | x_l1 = self.gnn[ii]( 99 | x=inp, 100 | edge_index=edge_index, 101 | ) 102 | inp = x_l1 103 | 104 | mu = self.fc_mu(inp) 105 | mu = mu.reshape(n, self.num_mixture_components, self.dim_out) 106 | if self.stochastic: 107 | logvar = self.fc_logvar(inp) 108 | std = torch.exp(logvar / 2.0) 109 | std = std.reshape(n, self.num_mixture_components, self.dim_out) 110 | if self.num_mixture_components > 1: 111 | k = self.fc_mixture(inp) 112 | return (k, mu, std) 113 | else: 114 | return (mu.squeeze(1), std.squeeze(1)) 115 | else: 116 | return mu.squeeze(1) -------------------------------------------------------------------------------- /generative_graphik/train.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | import json 3 | import os 4 | import time 5 | from collections import OrderedDict 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | import torch 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch_geometric.loader import DataLoader 12 | import torch_geometric 13 | 14 | from generative_graphik.args.parser import parse_training_args 15 | from generative_graphik.utils.torch_utils import set_seed_torch 16 | from generative_graphik.utils.dataset_generation import CachedDataset 17 | 18 | def _init_fn(worker_id): 19 | np.random.seed(int(args.random_seed)) 20 | 21 | def load_datasets(path: str, device, val_pcnt=0): 22 | with open(path, "rb") as f: 23 | try: 24 | # data = pickle.load(f) 25 | data = torch.load(f) 26 | data._data = data._data.to(device) 27 | val_size = int((val_pcnt/100)*len(data)) 28 | train_size = len(data) - val_size 29 | val_dataset, train_dataset = torch.utils.data.random_split(data, [val_size, train_size]) 30 
| except (OSError, IOError) as e: 31 | val_dataset = None 32 | train_dataset = None 33 | return train_dataset, val_dataset 34 | 35 | def opt_epoch(paths, model, epoch, device, opt=None, total_batches=100): 36 | """Single training epoch.""" 37 | if opt: 38 | model.train() 39 | else: 40 | model.eval() 41 | 42 | # Keep track of losses 43 | running_stats = {} 44 | 45 | num_batches = 0 46 | with tqdm(total=total_batches) as pbar: 47 | for path in paths: 48 | 49 | # Load training dataset from training path 50 | all_data, _ = load_datasets( 51 | path, 52 | device, 53 | val_pcnt=0 54 | ) 55 | 56 | loader = DataLoader( 57 | all_data, 58 | batch_size=args.n_batch, 59 | num_workers=args.n_worker, 60 | shuffle=True, 61 | worker_init_fn=_init_fn 62 | ) 63 | for idx, data in enumerate(loader): 64 | 65 | # Pick 1 random config from the samples as the goal and the rest as the random configs 66 | with torch.no_grad(): 67 | data_ = model.preprocess(data) 68 | 69 | # Forward call 70 | res = model.forward( 71 | x=data_.pos, 72 | h=torch.cat((data_.type, data_.goal_data_repeated_per_node), dim=-1), 73 | edge_attr=data_.edge_attr, 74 | edge_attr_partial=data_.edge_attr_partial, 75 | edge_index=data_.edge_index_full, 76 | partial_goal_mask=data_.partial_goal_mask 77 | ) 78 | 79 | # Get loss and stats 80 | loss, stats = model.loss( 81 | res=res, 82 | epoch=epoch, 83 | batch_size=args.n_batch, 84 | goal_pos=data_.pos, 85 | partial_goal_mask=data_.partial_goal_mask 86 | ) 87 | 88 | if opt: 89 | opt.zero_grad() 90 | loss.backward() 91 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 100) 92 | opt.step() 93 | 94 | num_batches += 1 95 | pbar.update(1) 96 | # add stats to running stats 97 | for key, val in stats.items(): 98 | if key not in running_stats: 99 | running_stats[key] = [val] 100 | else: 101 | running_stats[key] += [val] 102 | summary_stats = {f"avg_{k}": sum(v) / len(v) for k, v in running_stats.items()} 103 | 104 | return summary_stats 105 | 106 | 107 | def train(args): 108 | # Make directory to save hyperparameters, models, etc 109 | if not args.debug: 110 | # Directories based on if using SLURM 111 | slurm_id = os.environ.get("SLURM_JOB_ID") 112 | save_dir = os.path.join(args.storage_base_path, args.id) 113 | if slurm_id is not None: 114 | user = os.environ.get("USER") 115 | checkpoint_dir = f"/checkpoint/{user}/{slurm_id}" 116 | else: 117 | checkpoint_dir = os.path.join(save_dir, "checkpoints/") 118 | 119 | if not os.path.exists(save_dir + "/hyperparameters.txt"): # if model not yet generated 120 | print("Saving hyperparameters ...") 121 | print(args) 122 | os.makedirs(save_dir, exist_ok=True) 123 | os.makedirs(checkpoint_dir, exist_ok=True) 124 | args.__dict__ = OrderedDict(sorted(args.__dict__.items(), key=lambda t: t[0])) 125 | with open(save_dir + "/hyperparameters.txt", "w") as f: 126 | json.dump(args.__dict__, f, indent=2) 127 | else: # if model exists load previous args 128 | with open(save_dir + "/hyperparameters.txt", "r") as f: 129 | print("Loading parameters from hyperparameters.txt") 130 | new_amount_epochs = args.n_epoch 131 | args.__dict__.update(json.load(f)) 132 | args.n_epoch = new_amount_epochs 133 | 134 | writer = SummaryWriter(log_dir=save_dir) 135 | tb_data = [] 136 | 137 | # Fix random seed 138 | torch.backends.cudnn.deterministic = args.cudnn_deterministic 139 | torch.backends.cudnn.benchmark = args.cudnn_benchmark 140 | set_seed_torch(args.random_seed) 141 | device = torch.device(args.device) 142 | print("PyTorch setting set") 143 | 144 | # Dynamically load the networks 
module specific to the model 145 | if args.module_path == "none": 146 | spec = importlib.util.spec_from_file_location("networks", save_dir + "/model.py") 147 | else: 148 | spec = importlib.util.spec_from_file_location("networks", args.module_path) 149 | 150 | network = importlib.util.module_from_spec(spec) 151 | spec.loader.exec_module(network) 152 | print("Network module loaded") 153 | 154 | # Model 155 | model = network.Model(args).to(device) 156 | # model = torch_geometric.compile(model, fullgraph=True) 157 | print("Model loaded") 158 | 159 | # Optimizer 160 | params = list(model.parameters()) 161 | opt = torch.optim.AdamW(params, lr=args.lr) 162 | sched = torch.optim.lr_scheduler.StepLR(opt, args.n_scheduler_epoch, gamma=0.5) 163 | print("Optimizer loaded") 164 | 165 | # XXX: If a checkpoint exists, assume preempted and resume training 166 | initial_epoch = 0 167 | if not args.debug: 168 | if os.path.exists(os.path.join(checkpoint_dir, "checkpoint.pth")): 169 | checkpoint = torch.load(os.path.join(checkpoint_dir, "checkpoint.pth")) 170 | model.load_state_dict(checkpoint["net"]) 171 | sched.load_state_dict(checkpoint["sched"]) 172 | opt.load_state_dict(checkpoint["opt"]) 173 | initial_epoch = checkpoint["epoch"] 174 | print(f"Resuming training from checkpoint at epoch {initial_epoch}") 175 | elif args.pretrained_weights_path: 176 | # Optionally, if using pre-trained weights 177 | state_dict = torch.load(args.pretrained_weights_path, map_location=device) 178 | print("State dict loaded") 179 | model.load_state_dict(state_dict) 180 | print(f"Using pretrained weights from {args.pretrained_weights_path}") 181 | 182 | root = args.training_data_path 183 | root_val = args.validation_data_path 184 | paths = [root + "/" + f for f in os.listdir(root) if os.path.isfile(os.path.join(root, f))] 185 | paths_val = [root_val + "/" + f for f in os.listdir(root_val) if os.path.isfile(os.path.join(root_val, f))] 186 | 187 | # Load one file from the training set to establish total size 188 | loader = DataLoader(load_datasets(paths[0],device)[0],batch_size=args.n_batch) 189 | total_batches = len(loader)*len(paths) 190 | val_loader = DataLoader(load_datasets(paths_val[0],device)[0],batch_size=args.n_batch) 191 | val_batches = len(val_loader)*len(paths_val) 192 | del loader, val_loader 193 | print("Data loaders loaded") 194 | 195 | # Training loop 196 | try: 197 | for epoch in range(initial_epoch + 1, args.n_epoch + 1): 198 | tic = time.time() 199 | 200 | # Train for one epoch 201 | summary_train = opt_epoch( 202 | paths, 203 | model=model, 204 | opt=opt, 205 | epoch=epoch, 206 | device=device, 207 | total_batches=total_batches 208 | ) 209 | 210 | if args.use_validation: 211 | with torch.no_grad(): 212 | summary_val = opt_epoch( 213 | paths_val, 214 | model=model, 215 | epoch=epoch, 216 | device=device, 217 | total_batches=val_batches 218 | ) 219 | if sched: 220 | sched.step() 221 | 222 | epoch_time = time.time() - tic 223 | 224 | print(f"Epoch {epoch}/{args.n_epoch}, Time per epoch: {epoch_time}") 225 | train_str = "[Train]" 226 | for key, val in summary_train.items(): 227 | train_str += " " + key + f": {val}," 228 | print(train_str) 229 | 230 | if args.use_validation: 231 | val_str = "[Val]" 232 | for key, val in summary_val.items(): 233 | val_str += " " + key + f": {val}," 234 | 235 | print(val_str) 236 | print("----------------------------------------") 237 | 238 | # Store tensorboard data 239 | if not args.debug: 240 | for k, v in summary_train.items(): 241 | tb_data.append((f"train/{k}", v, epoch)) 242 | 
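                # scalars are buffered in tb_data and only flushed to the
                # SummaryWriter every n_checkpoint_epoch epochs (see below)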
243 |                 if args.use_validation:
244 |                     for k, v in summary_val.items():
245 |                         tb_data.append((f"val/{k}", v, epoch))
246 | 
247 |                 if epoch % args.n_checkpoint_epoch == 0:
248 |                     # Write tensorboard data
249 |                     for data in tb_data:
250 |                         writer.add_scalar(data[0], data[1], data[2])
251 |                     tb_data = []
252 | 
253 |                     # Save model at intermittent checkpoints
254 |                     torch.save(
255 |                         {
256 |                             "net": model.state_dict(),
257 |                             "opt": opt.state_dict(),
258 |                             "sched": sched.state_dict(),
259 |                             "epoch": epoch,
260 |                         },
261 |                         os.path.join(checkpoint_dir, "checkpoint.pth"),
262 |                     )
263 |     finally:
264 |         # Save the final model weights, even if training was interrupted
265 |         if not args.debug:
266 |             # Weights are written next to the copied model.py in save_dir
267 |             torch.save(model.state_dict(), save_dir + "/net.pth")
268 |             writer.close()
269 | 
270 | if __name__ == "__main__":
271 |     args = parse_training_args()
272 |     train(args)
273 | 
--------------------------------------------------------------------------------
/generative_graphik/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
4 | SRC_PATH="${SCRIPT_DIR}/.."
5 | 
6 | NAME=$1
7 | DATASET_NAME=$2
8 | VALIDATION_DATASET_NAME=${3:-"${DATASET_NAME}_validation"}
9 | TRAIN=${4:-true}
10 | 
11 | MODEL_PATH="${SRC_PATH}/saved_models/${NAME}_model"
12 | 
13 | # export PYTORCH_ENABLE_MPS_FALLBACK=1
14 | 
15 | # copy the model and training code if new
16 | if [ -d "${MODEL_PATH}" ]
17 | then
18 |     echo "Directory already exists, using existing model."
19 | else
20 |     echo "Creating new model directory."
21 | 
22 |     mkdir ${MODEL_PATH}
23 |     cp ${SRC_PATH}/generative_graphik/model.py ${MODEL_PATH}/model.py
24 | fi
25 | 
26 | if [ ! -d "${SRC_PATH}/datasets/${DATASET_NAME}" ]
27 | then
28 |     echo "Dataset ${DATASET_NAME} not found, creating it."
29 |     python -u ${SRC_PATH}/generative_graphik/utils/dataset_generation.py \
30 |         --id "${DATASET_NAME}" \
31 |         --robots kuka \
32 |         --num_examples 512000 \
33 |         --max_examples_per_file 512000 \
34 |         --goal_type pose \
35 |         --randomize False
36 | else
37 |     echo "Dataset ${DATASET_NAME} found!"
38 | 
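    # Generated datasets are cached under ${SRC_PATH}/datasets/<DATASET_NAME>;
    # delete that directory to force regeneration with different flags.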
39 | fi 40 | 41 | if [ "${TRAIN}" = true ] 42 | then 43 | python -u ${SRC_PATH}/generative_graphik/train.py \ 44 | --id "${NAME}_model" \ 45 | --norm_layer LayerNorm \ 46 | --debug False \ 47 | --device cuda:1 \ 48 | --n_worker 0 \ 49 | --n_beta_scaling_epoch 1 \ 50 | --lr 3e-4 \ 51 | --n_batch 128 \ 52 | --num_graph_mlp_layers 2 \ 53 | --num_egnn_mlp_layers 2 \ 54 | --graph_mlp_hidden_size 128 \ 55 | --mlp_hidden_size 128 \ 56 | --dim_latent_node_out 16 \ 57 | --dim_latent 64 \ 58 | --gnn_type "egnn" \ 59 | --num_gnn_layers 5 \ 60 | --num_node_features_out 3 \ 61 | --num_coordinates_in 3 \ 62 | --num_features_in 3 \ 63 | --num_edge_features_in 1 \ 64 | --num_prior_mixture_components 16 \ 65 | --num_likelihood_mixture_components 1\ 66 | --num_anchor_nodes 4 \ 67 | --train_prior True \ 68 | --n_epoch 1 \ 69 | --n_scheduler_epoch 60\ 70 | --dim_goal 6 \ 71 | --storage_base_path "${SRC_PATH}/saved_models" \ 72 | --training_data_path "${SRC_PATH}/datasets/${DATASET_NAME}" \ 73 | --validation_data_path "${SRC_PATH}/datasets/${VALIDATION_DATASET_NAME}" \ 74 | --module_path "${SRC_PATH}/generative_graphik/model.py" \ 75 | --use_validation True \ 76 | --n_checkpoint_epoch 16 \ 77 | --non_linearity silu \ 78 | --rec_gain 10 79 | fi -------------------------------------------------------------------------------- /generative_graphik/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/utiasSTARS/generative-graphik/9263453cc7e1df506e57370d6558eca5aeb4b998/generative_graphik/utils/__init__.py -------------------------------------------------------------------------------- /generative_graphik/utils/dataset_generation.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | import numpy as np 3 | import os 4 | from tqdm import tqdm 5 | from dataclasses import dataclass 6 | 7 | import torch 8 | from torch_geometric.data import InMemoryDataset, Data 9 | import torch.multiprocessing as mp 10 | 11 | import graphik 12 | from graphik.robots import RobotRevolute 13 | from graphik.graphs import ProblemGraphRevolute 14 | from graphik.graphs.graph_revolute import random_revolute_robot_graph 15 | import generative_graphik 16 | from generative_graphik.args.parser import parse_data_generation_args 17 | from generative_graphik.utils.torch_utils import ( 18 | batchFKmultiDOF, 19 | batchPmultiDOF, 20 | edge_indices_attributes, 21 | node_attributes, 22 | ) 23 | from graphik.utils import ( 24 | BASE, 25 | DIST, 26 | ROBOT, 27 | OBSTACLE, 28 | POS, 29 | TYPE, 30 | ) 31 | from graphik.utils.roboturdf import RobotURDF 32 | 33 | TYPE_ENUM = { 34 | BASE: np.asarray([1, 0, 0]), 35 | ROBOT: np.asarray([0, 1, 0]), 36 | OBSTACLE: np.asarray([0, 0, 1]), 37 | } 38 | ANCHOR_ENUM = {"anchor": np.asarray([1, 0]), "not_anchor": np.asarray([0, 1])} 39 | 40 | 41 | class CachedDataset(InMemoryDataset): 42 | def __init__(self, data, slices): 43 | super(CachedDataset, self).__init__(None) 44 | self.data, self.slices = data, slices 45 | 46 | 47 | @dataclass 48 | class StructData: 49 | type: Union[List[torch.Tensor], torch.Tensor] 50 | num_joints: Union[List[int], int] 51 | num_nodes: Union[List[int], int] 52 | num_edges: Union[List[int], int] 53 | partial_mask: Union[List[torch.Tensor], torch.Tensor] 54 | partial_goal_mask: Union[List[torch.Tensor], torch.Tensor] 55 | edge_index_full: Union[List[torch.Tensor], torch.Tensor] 56 | T0: Union[List[torch.Tensor], torch.Tensor] 57 | 58 | 59 | def 
generate_data_point(graph): 60 | struct_data = generate_struct_data(graph) 61 | 62 | num_joints = torch.tensor([struct_data.num_joints]) 63 | edge_index_full = struct_data.edge_index_full 64 | T0 = struct_data.T0 65 | 66 | q = torch.rand(num_joints[0], dtype=T0.dtype) * 2 * torch.pi - torch.pi 67 | q[num_joints[0] - 1] = 0 68 | T = batchFKmultiDOF(T0, q, num_joints) 69 | P = batchPmultiDOF(T, num_joints) 70 | T_ee = T[num_joints[0]] 71 | distances = torch.linalg.norm( 72 | P[edge_index_full[0], :] - P[edge_index_full[1], :], dim=-1 73 | ) 74 | 75 | return Data( 76 | type=struct_data.type, 77 | pos=P, 78 | edge_attr=distances.unsqueeze(1), 79 | T_ee=T_ee, 80 | num_joints=num_joints.type(torch.int32), 81 | partial_mask=struct_data.partial_mask, 82 | partial_goal_mask=struct_data.partial_goal_mask, 83 | edge_index_full=edge_index_full.type(torch.int32), 84 | T0=struct_data.T0, 85 | q_goal=q, 86 | ) 87 | 88 | 89 | def generate_struct_data(graph): 90 | 91 | robot = graph.robot 92 | dof = robot.n 93 | num_joints = dof 94 | num_nodes = 2 * (dof + 1) + 2 # number of nodes for point graphs 95 | 96 | type = node_attributes(graph, attrs=[TYPE])[0] 97 | T0 = node_attributes(graph.robot, attrs=["T0"])[0] 98 | 99 | G_partial = graph.from_pose(robot.pose(robot.random_configuration(), f"p{dof}")) 100 | edge_index_partial, _ = edge_indices_attributes(G_partial) 101 | # D = nx.to_scipy_sparse_array(G_partial.to_undirected(), weight=DIST, format="coo") 102 | # ind0, ind1 = D.row, D.col 103 | ind0 = edge_index_partial[0] 104 | ind1 = edge_index_partial[1] 105 | 106 | edge_index_full = ( 107 | (torch.ones(num_nodes, num_nodes) - torch.eye(num_nodes)) 108 | .nonzero() 109 | .transpose(0, 1) 110 | ) 111 | num_edges = edge_index_full[-1].shape[-1] 112 | 113 | partial_goal_mask = torch.zeros(num_nodes) 114 | partial_goal_mask[: graph.dim + 1] = 1 115 | partial_goal_mask[-2:] = 1 116 | 117 | # _______extracting partial indices from vectorized full indices via mask 118 | mask_gen = torch.zeros(num_nodes, num_nodes) # square matrix of zeroes 119 | mask_gen[ind0, ind1] = 1 # set partial elements to 1 120 | mask = ( 121 | mask_gen[edge_index_full[0], edge_index_full[1]] > 0 122 | ) # get full elements from matrix (same order as generated) 123 | 124 | return StructData( 125 | type=type, 126 | num_joints=num_joints, 127 | num_edges=num_edges, 128 | num_nodes=num_nodes, 129 | partial_mask=mask, 130 | partial_goal_mask=partial_goal_mask, 131 | edge_index_full=edge_index_full, 132 | T0=T0, 133 | ) 134 | 135 | 136 | def generate_specific_robot_data(robots, num_examples, params): 137 | 138 | examples_per_robot = num_examples // len(robots) 139 | 140 | all_struct_data = StructData( 141 | type=[], 142 | num_joints=[], 143 | num_nodes=[], 144 | num_edges=[], 145 | partial_mask=[], 146 | partial_goal_mask=[], 147 | edge_index_full=[], 148 | T0=[], 149 | ) 150 | 151 | q_lim_l_all = [] 152 | q_lim_u_all = [] 153 | 154 | for robot_name in robots: 155 | # generate data for robot like ur10, kuka etc. 
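        # Each branch below resolves a URDF path and joint limits for a named robot.
        # Supporting a new robot amounts to adding an analogous branch; a minimal
        # sketch (the "my_arm" name, URDF file, and limits are hypothetical, not
        # assets shipped with graphik):
        #
        #     elif robot_name == "my_arm":
        #         fname = graphik.__path__[0] + "/robots/urdfs/my_arm.urdf"
        #         q_lim_l = -np.pi * np.ones(6)  # lower joint limits [rad]
        #         q_lim_u = np.pi * np.ones(6)   # upper joint limits [rad]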
156 | if robot_name == "ur10": 157 | # randomize won't work on ur10 158 | # robot, graph = load_ur10(limits=None, randomized_links = False) 159 | fname = graphik.__path__[0] + "/robots/urdfs/ur10_mod.urdf" 160 | q_lim_l = -np.pi * np.ones(6) 161 | q_lim_u = np.pi * np.ones(6) 162 | elif robot_name == "kuka": 163 | # robot, graph = load_kuka(limits=None, randomized_links = params["randomize"], randomize_percentage=0.2) 164 | fname = graphik.__path__[0] + "/robots/urdfs/kuka_iiwr.urdf" 165 | q_lim_l = -np.pi * np.ones(7) 166 | q_lim_u = np.pi * np.ones(7) 167 | elif robot_name == "lwa4d": 168 | # robot, graph = load_schunk_lwa4d(limits=None, randomized_links = params["randomize"], randomize_percentage=0.2) 169 | fname = graphik.__path__[0] + "/robots/urdfs/lwa4d.urdf" 170 | q_lim_l = -np.pi * np.ones(7) 171 | q_lim_u = np.pi * np.ones(7) 172 | elif robot_name == "panda": 173 | # robot, graph = load_schunk_lwa4d(limits=None, randomized_links = params["randomize"], randomize_percentage=0.2) 174 | fname = graphik.__path__[0] + "/robots/urdfs/panda_arm.urdf" 175 | # q_lim_l = -np.pi * np.ones(7) 176 | # q_lim_u = np.pi * np.ones(7) 177 | q_lim_l = np.array([-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -0.0175, -2.8973]) 178 | q_lim_u = np.array([2.8973, 1.7628, 2.8973, -0.0698, 2.8973, 3.7525, 2.8973]) 179 | elif robot_name == "lwa4p": 180 | # robot, graph = load_schunk_lwa4p(limits=None, randomized_links = params["randomize"], randomize_percentage=0.2) 181 | fname = graphik.__path__[0] + "/robots/urdfs/lwa4p.urdf" 182 | q_lim_l = -np.pi * np.ones(6) 183 | q_lim_u = np.pi * np.ones(6) 184 | else: 185 | raise NotImplementedError 186 | 187 | urdf_robot = RobotURDF(fname) 188 | robot = urdf_robot.make_Revolute3d( 189 | q_lim_l, 190 | q_lim_u, 191 | randomized_links=params["randomize"], 192 | randomize_percentage=params["randomize_percentage"], 193 | ) # make the Revolute class from a URDF 194 | graph = ProblemGraphRevolute(robot) 195 | struct_data = generate_struct_data(graph) 196 | 197 | for _ in tqdm(range(examples_per_robot), leave=False): 198 | # q_lim_l_all.append(q_lim_l) 199 | # q_lim_u_all.append(q_lim_u) 200 | for field in struct_data.__dataclass_fields__: 201 | all_struct_data.__dict__[field].append(getattr(struct_data, field)) 202 | 203 | types = torch.cat(all_struct_data.type, dim=0) 204 | T0 = torch.cat(all_struct_data.T0, dim=0).reshape(-1, 4, 4) 205 | # q_lim_l_all = torch.from_numpy(np.concatenate(q_lim_l_all)).type(T0.dtype) 206 | # q_lim_u_all = torch.from_numpy(np.concatenate(q_lim_u_all)).type(T0.dtype) 207 | num_joints = torch.tensor(all_struct_data.num_joints) 208 | num_nodes = torch.tensor(all_struct_data.num_nodes) 209 | num_edges = torch.tensor(all_struct_data.num_edges) 210 | 211 | # problem is that edge_index_full doesn't contain self-loops 212 | masks = torch.cat(all_struct_data.partial_mask, dim=-1) 213 | edge_index_full = torch.cat(all_struct_data.edge_index_full, dim=-1) 214 | partial_goal_mask = torch.cat(all_struct_data.partial_goal_mask, dim=-1) 215 | 216 | # delete struct data 217 | all_struct_data = None 218 | q = torch.rand(num_joints.sum(), dtype=T0.dtype) * 2 * torch.pi - torch.pi 219 | # q = torch.rand(num_joints.sum(), dtype=T0.dtype) * (q_lim_u_all - q_lim_l_all) + q_lim_l_all 220 | 221 | q[(num_joints).cumsum(dim=-1) - 1] = 0 222 | T = batchFKmultiDOF(T0, q, num_joints) 223 | P = batchPmultiDOF(T, num_joints) 224 | # T_ee = T[num_joints.cumsum(dim=-1)] 225 | T_ee = T[torch.cumsum(num_joints + 1, dim=0) - 1] 226 | offset_full = ( 227 | 
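        # per-graph node offsets: cumulative node counts (prepended with 0),
        # repeated once per edge, so that per-robot edge indices can address
        # rows of the stacked position tensor P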
torch.cat([torch.tensor([0]), num_nodes[:-1].cumsum(dim=-1)]) 228 | .repeat_interleave(num_edges, dim=-1) 229 | .unsqueeze(0) 230 | .expand(2, -1) 231 | ) 232 | edge_index_full_offset = edge_index_full + offset_full 233 | distances = torch.linalg.norm( 234 | P[edge_index_full_offset[0], :] - P[edge_index_full_offset[1], :], dim=-1 235 | ) 236 | 237 | node_slice = torch.cat([torch.tensor([0]), (num_nodes).cumsum(dim=-1)]) 238 | joint_slice = torch.cat([torch.tensor([0]), (num_joints).cumsum(dim=-1)]) 239 | frame_slice = torch.cat([torch.tensor([0]), (num_joints + 1).cumsum(dim=-1)]) 240 | robot_slice = torch.arange(num_joints.size(0) + 1) 241 | edge_full_slice = torch.cat([torch.tensor([0]), (num_edges).cumsum(dim=-1)]) 242 | 243 | slices = { 244 | "edge_attr": edge_full_slice, 245 | "pos": node_slice, 246 | "type": node_slice, 247 | "T_ee": robot_slice, 248 | "num_joints": robot_slice, 249 | "partial_mask": edge_full_slice, 250 | "partial_goal_mask": node_slice, 251 | "edge_index_full": edge_full_slice, 252 | "M": frame_slice, 253 | "q_goal": joint_slice, 254 | } 255 | 256 | data = Data( 257 | type=types, 258 | pos=P, 259 | edge_attr=distances.unsqueeze(1), 260 | T_ee=T_ee, 261 | num_joints=num_joints.type(torch.int32), 262 | partial_mask=masks, 263 | partial_goal_mask=partial_goal_mask, 264 | edge_index_full=edge_index_full.type(torch.int32), 265 | M=T0, 266 | q_goal=q, 267 | ) 268 | return data, slices 269 | 270 | 271 | def generate_random_struct_data(dof): 272 | return generate_struct_data(random_revolute_robot_graph(dof)) 273 | 274 | 275 | def generate_randomized_robot_data(robot_type, dofs, num_examples, params): 276 | # generate data for randomized robots 277 | 278 | examples_per_dof = num_examples // len(dofs) 279 | print("Generating " + robot_type + " data!") 280 | 281 | all_struct_data = StructData( 282 | type=[], 283 | num_joints=[], 284 | num_nodes=[], 285 | num_edges=[], 286 | partial_mask=[], 287 | partial_goal_mask=[], 288 | edge_index_full=[], 289 | T0=[], 290 | ) 291 | 292 | for dof in dofs: 293 | with mp.Pool() as p: 294 | graphs = p.map(random_revolute_robot_graph, [dof] * examples_per_dof) 295 | for idx in tqdm(range(examples_per_dof), leave=False): 296 | struct_data = generate_struct_data(graphs[idx]) 297 | for field in struct_data.__dataclass_fields__: 298 | all_struct_data.__dict__[field].append(getattr(struct_data, field)) 299 | 300 | types = torch.cat(all_struct_data.type, dim=0) 301 | T0 = torch.cat(all_struct_data.T0, dim=0).reshape(-1, 4, 4) 302 | num_joints = torch.tensor(all_struct_data.num_joints) 303 | num_nodes = torch.tensor(all_struct_data.num_nodes) 304 | num_edges = torch.tensor(all_struct_data.num_edges) 305 | 306 | # problem is that edge_index_full doesn't contain self-loops 307 | masks = torch.cat(all_struct_data.partial_mask, dim=-1) 308 | edge_index_full = torch.cat(all_struct_data.edge_index_full, dim=-1) 309 | partial_goal_mask = torch.cat(all_struct_data.partial_goal_mask, dim=-1) 310 | 311 | # delete struct data 312 | all_struct_data = None 313 | 314 | q = torch.rand(num_joints.sum(), dtype=T0.dtype) * 2 * torch.pi - torch.pi 315 | q[(num_joints).cumsum(dim=-1) - 1] = 0 316 | T = batchFKmultiDOF(T0, q, num_joints) 317 | P = batchPmultiDOF(T, num_joints) 318 | T_ee = T[num_joints.cumsum(dim=-1)] 319 | offset_full = ( 320 | torch.cat([torch.tensor([0]), num_nodes[:-1].cumsum(dim=-1)]) 321 | .repeat_interleave(num_edges, dim=-1) 322 | .unsqueeze(0) 323 | .expand(2, -1) 324 | ) 325 | edge_index_full_offset = edge_index_full + offset_full 326 | 
distances = torch.linalg.norm( 327 | P[edge_index_full_offset[0], :] - P[edge_index_full_offset[1], :], dim=-1 328 | ) 329 | 330 | node_slice = torch.cat([torch.tensor([0]), (num_nodes).cumsum(dim=-1)]) 331 | joint_slice = torch.cat([torch.tensor([0]), (num_joints).cumsum(dim=-1)]) 332 | frame_slice = torch.cat([torch.tensor([0]), (num_joints + 1).cumsum(dim=-1)]) 333 | robot_slice = torch.arange(num_joints.size(0) + 1) 334 | edge_full_slice = torch.cat([torch.tensor([0]), (num_edges).cumsum(dim=-1)]) 335 | 336 | slices = { 337 | "edge_attr": edge_full_slice, 338 | "pos": node_slice, 339 | "type": node_slice, 340 | "T_ee": robot_slice, 341 | "num_joints": robot_slice, 342 | "partial_mask": edge_full_slice, 343 | "partial_goal_mask": node_slice, 344 | "edge_index_full": edge_full_slice, 345 | "M": frame_slice, 346 | "q_goal": joint_slice, 347 | } 348 | 349 | data = Data( 350 | type=types, 351 | pos=P, 352 | edge_attr=distances.unsqueeze(1), 353 | T_ee=T_ee, 354 | num_joints=num_joints.type(torch.int32), 355 | partial_mask=masks, 356 | partial_goal_mask=partial_goal_mask, 357 | edge_index_full=edge_index_full.type(torch.int32), 358 | M=T0, 359 | q_goal=q, 360 | ) 361 | 362 | return data, slices 363 | 364 | 365 | def generate_dataset(params, robots): 366 | dof = params.get("dof", [3]) # if no dofs are defined, default to 3 367 | num_examples = params.get("size", 1000) 368 | 369 | if robots[0] == "revolute_chain": 370 | data, slices = generate_randomized_robot_data( 371 | robots[0], dof, num_examples, params 372 | ) 373 | else: 374 | data, slices = generate_specific_robot_data(robots, num_examples, params) 375 | 376 | return data, slices 377 | 378 | 379 | def main(args): 380 | # torch.multiprocessing.set_sharing_strategy('file_system') 381 | if args.num_examples > args.max_examples_per_file: 382 | num_files = int(args.num_examples / args.max_examples_per_file) 383 | else: 384 | num_files = 1 385 | 386 | if args.storage_base_path is None: 387 | storage_path = generative_graphik.__path__[0] + "/../datasets/" + args.id + "/" 388 | val_path = ( 389 | generative_graphik.__path__[0] + "/../datasets/" + args.id + "_validation/" 390 | ) 391 | else: 392 | storage_path = args.storage_base_path 393 | val_path = args.storage_base_path + "_validation/" 394 | 395 | if not os.path.exists(storage_path): 396 | print(f"Path {storage_path} not found. Creating directory.") 397 | os.makedirs(storage_path) 398 | 399 | if not os.path.exists(val_path): 400 | print(f"Path {val_path} not found. Creating directory.") 401 | os.makedirs(val_path) 402 | 403 | print(f"Saving dataset to {storage_path} as {num_files} separate files.") 404 | for idx in range(num_files): 405 | 406 | dataset_params = { 407 | "size": args.num_examples // num_files, 408 | "samples": args.num_samples, 409 | "dof": args.dofs, 410 | "goal_type": args.goal_type, 411 | "randomize": args.randomize, 412 | "randomize_percentage": args.randomize_percentage, 413 | } 414 | 415 | data, slices = generate_dataset( 416 | dataset_params, 417 | args.robots, 418 | ) 419 | 420 | dataset = CachedDataset(data, slices) 421 | 422 | with open(os.path.join(storage_path, "data_" + f"{idx}" + ".p"), "wb") as f: 423 | torch.save(dataset, f) 424 | 425 | num_val_examples = int( 426 | (args.num_examples / num_files) / (100 / args.validation_percentage) 427 | ) 428 | print( 429 | f"Generating validation set with {num_val_examples} problems (10% of single file)." 
430 |     )
431 |     dataset_params = {
432 |         "size": num_val_examples,
433 |         "samples": args.num_samples,
434 |         "dof": args.dofs,
435 |         "goal_type": args.goal_type,
436 |         "randomize": args.randomize,
437 |         "randomize_percentage": args.randomize_percentage,
438 |     }
439 |     data, slices = generate_dataset(
440 |         dataset_params,
441 |         args.robots,
442 |     )
443 |     dataset = CachedDataset(data, slices)
444 |     with open(val_path + "data_0" + ".p", "wb") as f:
445 |         torch.save(dataset, f)
446 | 
447 | 
448 | if __name__ == "__main__":
449 |     args = parse_data_generation_args()
450 |     main(args)
451 | 
--------------------------------------------------------------------------------
/generative_graphik/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import random
3 | import numpy as np
4 | 
5 | import torch
6 | import torch.nn as nn
7 | from torch_geometric.utils import remove_self_loops
8 | from torch.distributions.independent import Independent
9 | from torch.distributions.mixture_same_family import MixtureSameFamily
10 | from torch.distributions import Normal, Categorical
11 | from scipy.sparse import find
12 | from liegroups.torch import SE3, SE2, SO3, SO2
13 | 
14 | from graphik.utils import (
15 |     adjacency_matrix_from_graph,
16 |     distance_matrix_from_graph,
17 |     TYPE,
18 |     POS
19 | )
20 | import time
21 | 
22 | TYPE_ENUM = {
23 |     "base": np.asarray([1, 0, 0]),
24 |     "robot": np.asarray([0, 1, 0]),
25 |     "obstacle": np.asarray([0, 0, 1]),
26 | }
27 | 
28 | def repeat_offset_index(index, repeat, offset):
29 |     """Repeat and offset indices"""
30 |     cumsum = 0
31 |     new_edge_index = []
32 |     for _ in range(repeat):
33 |         new_edge_index.append((index + cumsum).unsqueeze(1))
34 |         cumsum += offset
35 |     new_edge_index = torch.cat(new_edge_index, dim=1)
36 |     return new_edge_index
37 | 
38 | def get_norm_layer(out_channels, num_groups=32, layer_type='None', layer_dim='1d'):
39 |     if layer_type == 'BatchNorm':
40 |         if layer_dim == '2d':
41 |             return nn.BatchNorm2d(out_channels)
42 |         elif layer_dim == '1d':
43 |             return nn.BatchNorm1d(out_channels)
44 |     elif layer_type == 'GroupNorm':
45 |         return nn.GroupNorm(num_groups, out_channels)
46 |     elif layer_type == 'LayerNorm':
47 |         if layer_dim == '2d':
48 |             return nn.GroupNorm(1, out_channels)
49 |         elif layer_dim == '1d':
50 |             return nn.LayerNorm(out_channels)
51 |     elif layer_type == 'None':
52 |         return nn.Identity()
53 |     else:
54 |         raise NotImplementedError(f"{layer_type} is not a valid normalization layer.")
55 | 
56 | 
57 | def kl_divergence(d1, d2, K=128):
58 |     """Computes closed-form KL if available, else computes a MC estimate."""
59 |     if (type(d1), type(d2)) in torch.distributions.kl._KL_REGISTRY:
60 |         return torch.distributions.kl_divergence(d1, d2)
61 |     else:
62 |         samples = d1.rsample(torch.Size([K]))
63 |         return (d1.log_prob(samples) - d2.log_prob(samples)).mean(0)
64 | 
65 | 
66 | def MixtureGaussianDiag(categorical_prob, loc, scale, batch_dim=1):
67 |     return MixtureSameFamily(
68 |         Categorical(categorical_prob), MultivariateNormalDiag(loc,scale), batch_dim
69 |     )
70 | 
71 | 
72 | def MultivariateNormalDiag(loc, scale, batch_dim=1):
73 |     """Returns a diagonal multivariate normal Torch distribution."""
74 |     return Independent(Normal(loc, scale), batch_dim)
75 | 
76 | 
77 | def set_seed_torch(seed):
78 |     """Set the same random seed for all sources of randomness."""
79 |     np.random.seed(seed)
80 |     torch.manual_seed(seed)
81 |     torch.cuda.manual_seed(seed)
82 |     torch.cuda.manual_seed_all(seed)
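    # (numpy, torch CPU, and all CUDA devices are seeded above; Python's
    # built-in RNG is seeded below, covering every RNG source used here)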
83 |     random.seed(seed)
84 | 
85 | def torch_log_from_T(T):
86 |     if T.dim() < 3:
87 |         T = T.unsqueeze(0)
88 | 
89 |     if T.shape[-1] == 4:
90 |         return SE3(SO3(T[:,:-1,:-1]), T[:,:-1,-1]).log()
91 |     elif T.shape[-1] == 3:
92 |         return SE2(SO2(T[:,:-1,:-1]), T[:,:-1,-1]).log()
93 | 
94 | def torch_T_from_log(S):
95 |     if S.dim() < 2:
96 |         S = S.unsqueeze(0)
97 | 
98 |     if S.shape[-1] == 6:
99 |         return SE3.exp(S).as_matrix().reshape(-1, 4, 4)
100 |     elif S.shape[-1] == 3:
101 |         return SE2.exp(S).as_matrix().reshape(-1, 3, 3)
102 | 
103 | def torch_T_from_log_angle_axis(S,q):
104 |     if S.dim() < 2:
105 |         S = S.unsqueeze(0)
106 |         q = q.unsqueeze(0)
107 | 
108 |     dim_tangent = S.shape[-1]
109 |     dim = torch.div(dim_tangent, 3, rounding_mode="trunc") + 1
110 |     if dim == 2:
111 |         Sq = torch.mul(q.view(-1, 1), S.view(-1, dim_tangent))
112 |         T = torch_T_from_log(Sq).reshape(-1, dim+1, dim+1)
113 |     else:
114 |         Phi = torch.zeros(S.shape[0], dim, dim, device=S.device)
115 |         Phi[:, 0, 1] = -S[:, 5]
116 |         Phi[:, 1, 0] = S[:, 5]
117 |         Phi[:, 0, 2] = S[:, 4]
118 |         Phi[:, 2, 0] = -S[:, 4]
119 |         Phi[:, 1, 2] = -S[:, 3]
120 |         Phi[:, 2, 1] = S[:, 3]
121 | 
122 |         Phi2 = torch.bmm(Phi,Phi)
123 | 
124 |         I = torch.eye(3, device=S.device).unsqueeze(0).expand(S.shape[0],-1,-1)
125 |         T = torch.eye(4, device=S.device).unsqueeze(0).repeat(S.shape[0],1,1)
126 | 
127 |         sin_q = torch.sin(q)
128 |         cos_q = torch.cos(q)
129 | 
130 |         A = sin_q * Phi.view(S.shape[0],-1) + (1-cos_q) * Phi2.view(S.shape[0],-1)
131 |         T[:,:dim,:dim] = I + A.view(S.shape[0],3,3)
132 | 
133 |         B = q * I.view(S.shape[0],-1) + (1 - cos_q) * Phi.view(S.shape[0],-1) + (q - sin_q) * Phi2.view(S.shape[0],-1)
134 |         T[:,:dim,dim] = torch.matmul(B.view(S.shape[0],3,3), S[:,:dim].view(S.shape[0],3,1)).view(S.shape[0],3)
135 | 
136 |     return T
137 | 
138 | def edge_indices_attributes(G):
139 |     A = adjacency_matrix_from_graph(G)  # adjacency matrix
140 |     idx, jdx, _ = find(A)  # indices of non-zero elements
141 |     edges = torch.tensor([list(idx), list(jdx)], dtype=torch.long)
142 | 
143 |     D = np.sqrt(distance_matrix_from_graph(G))  # distance matrix
144 |     d = D[A > 0][np.newaxis,:]
145 |     distances = torch.tensor(d, dtype=torch.float).t()
146 | 
147 |     return edges, distances
148 | 
149 | def node_attributes(G, attrs=["pos"]):
150 | 
151 |     out = []
152 |     for attr in attrs:
153 |         if attr in [POS, "S"]:
154 |             out += [
155 |                 torch.tensor(
156 |                     [list(data) for node, data in G.nodes(data=attr)], dtype=torch.float
157 |                 )
158 |             ]
159 |         if attr in ["T0"]:
160 |             Ts = np.stack([data.as_matrix() for node, data in G.nodes(data=attr)])
161 |             out += [torch.tensor(Ts, dtype=torch.float)]
162 |         if attr in [TYPE]:
163 |             out += [
164 |                 torch.tensor(
165 |                     [list(TYPE_ENUM[data[0]]) for node, data in G.nodes(data=attr)],
166 |                     dtype=torch.float,
167 |                 )
168 |             ]
169 |     return tuple(out)
170 | 
171 | class Flatten(nn.Module):
172 |     def forward(self, x):
173 |         return x.view(x.size(0), -1)
174 | 
175 | @torch.compile
176 | def SE3_from(rot: Optional[torch.Tensor] = None, trans: Optional[torch.Tensor] = None) -> torch.Tensor:
177 |     # Composes SE3 matrix from rotation and translation component
178 |     # ----------------------------------------------------------------------
179 |     # T = torch.eye(4, device=rot.device).repeat(rot.shape[0],1,1)
180 |     if trans is not None and rot is not None:
181 |         T = torch.eye(4, device=rot.device, dtype=rot.dtype).repeat(rot.shape[0],1,1)
182 |         T[:,:-1,:-1] = rot
183 |         T[:,:-1,-1] = trans
184 |     elif rot is not None:
185 |         T = torch.eye(4, device=rot.device, dtype=rot.dtype).repeat(rot.shape[0],1,1)
186 |         T[:,:-1,:-1] = rot
187 |     elif trans is not None:
188 |         T = torch.eye(4, device=trans.device, dtype=trans.dtype).repeat(trans.shape[0],1,1)
189 |         T[:,:-1,-1] = trans
190 |     else:
191 |         raise ValueError("Neither rotation nor translation specified")
192 |     return T
193 | 
194 | @torch.compile
195 | def SE3_inv_from(rot: torch.Tensor, trans: torch.Tensor) -> torch.Tensor:
196 |     # Composes inverse SE3 matrix from rotation and translation component
197 |     # ----------------------------------------------------------------------
198 |     inv_rot = rot.transpose(2,1)
199 |     inv_trans = -torch.bmm(inv_rot, trans[:,:,None])[:,:,0]
200 |     inv_T = torch.eye(4, device=rot.device, dtype=rot.dtype).repeat(rot.shape[0],1,1)
201 |     inv_T[:,:-1,:-1] = inv_rot
202 |     inv_T[:,:-1,-1] = inv_trans
203 |     return inv_T
204 | 
205 | def SE3_inv(T: torch.Tensor):
206 |     # Computes SE3 matrix inverse
207 |     # ----------------------------------------------------------------------
208 |     return SE3_inv_from(T[:,:-1,:-1], T[:,:-1,-1])
209 | 
210 | @torch.compile
211 | def rotz(angle_in_radians, dim=3):
212 |     """Form a rotation matrix given an angle in rad about the z-axis."""
213 |     s = angle_in_radians.sin()
214 |     c = angle_in_radians.cos()
215 | 
216 |     mat = angle_in_radians.new_empty(
217 |         angle_in_radians.shape[0], dim, dim, dtype=angle_in_radians.dtype).zero_()
218 |     mat[:, 2, 2] = 1.
219 |     mat[:, 0, 0] = c
220 |     mat[:, 0, 1] = -s
221 |     mat[:, 1, 0] = s
222 |     mat[:, 1, 1] = c
223 |     return mat
224 | 
225 | @torch.compile
226 | def batchJointScrews(T0: torch.Tensor):
227 |     omega = T0[:,:-1,2]  # z axis
228 |     q = T0[:,:-1,-1]
229 |     v = torch.cross(-omega, q, dim=-1)  # explicit dim avoids the deprecated ambiguous default
230 |     return torch.cat([v, omega], dim=-1)
231 | 
232 | 
233 | # @torch.compile
234 | def batchIKmultiDOF(
235 |     P: torch.Tensor,
236 |     T0: torch.Tensor,
237 |     num_joints: torch.Tensor,
238 |     T_final: Optional[torch.Tensor] = None,
239 |     T_rel: Optional[torch.Tensor] = None,
240 |     qs_0: Optional[torch.Tensor] = None,
241 | ) -> torch.Tensor:
242 |     # Computes IK for multiple robots with varying DOF from points.
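    # Angles are recovered sequentially: each sphere point q is expressed in
    # the preceding joint frame, and the rotation about that joint's z-axis
    # is read off with atan2 in the loop below.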
243 |     # ----------------------------------------------------------------------
244 |     # P [sum(num_nodes) x 3] - points corresponding to the distance-geometric
245 |     # robot model described in (Maric et al., 2021)
246 |     # T0 [num_frames x 4 x 4] - robot frames in home poses
247 |     # num_joints [num_robots] - number of joints (DOF) for each robot in batch
248 |     # ----------------------------------------------------------------------
249 |     #
250 |     device = T0.device  # GPU or CPU
251 |     dtype = T0.dtype  # data type for non-integer
252 |     dim = P.shape[1]  # dimension (2 or 3)
253 | 
254 |     # constant matrices (robot structure, known a priori)
255 |     if T_rel is None or qs_0 is None:
256 |         T0_inv = SE3_inv(T0)  # inverses of T0 (zero config xf)
257 |         T_rel = torch.bmm(T0_inv, T0.roll(-1,0))  # relative xf between T0
258 |         T0_q = SE3_from(T0[:,:-1,:-1], T0[:,:-1,-1] + T0[:,:-1,-2])  # frames at points q
259 |         qs_0 = torch.bmm(T0_inv, T0_q.roll(-1,0))[:,:-1,-1]  # points q at home config
260 | 
261 |     omega_z = torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 0]], dtype=dtype, device=device)
262 |     omega_z_sq = torch.mm(omega_z, omega_z.transpose(1,0))
263 | 
264 |     # number of robots, joints, unique ids
265 |     num_robots = num_joints.shape[0]  # total number of robots
266 |     num_nodes = 2*(num_joints+1) + (dim-1)  # number of nodes for point graphs
267 |     node_start_ind = torch.cumsum(num_nodes,dim=0) - num_nodes  # start indices for nodes
268 |     node_start_ind = node_start_ind.to(device)
269 |     robot_ids = torch.arange(num_robots, device=device)
270 | 
271 |     # normalizes the node positions to the canonical coordinate system
272 |     x_hat = (P[node_start_ind + 1] - P[node_start_ind])
273 |     y_hat = -(P[node_start_ind + 2] - P[node_start_ind])
274 |     z_hat = (P[node_start_ind + 3] - P[node_start_ind])
275 | 
276 |     # get modified base frames
277 |     R = torch.cat([x_hat.unsqueeze(1), y_hat.unsqueeze(1), z_hat.unsqueeze(1)], dim = 1).transpose(2,1)
278 |     B_inv = SE3_inv_from(R, P[node_start_ind])
279 |     hl = robot_ids.repeat_interleave(num_joints + 1, dim=0)
280 |     # hl = torch.arange(num_robots, device=device).repeat_interleave(num_joints + 1, dim=0)
281 |     # hl = torch.repeat_interleave(torch.arange(num_robots, device=device), num_joints + 1, dim=0)
282 | 
283 |     # mask = torch.tensor(True, device=device).repeat(num_nodes.sum())
284 |     mask = torch.ones(num_nodes.sum(), device=device, dtype=torch.bool)
285 |     mask[node_start_ind+1] = False
286 |     mask[node_start_ind+2] = False
287 |     P = P.masked_select(mask[:,None].expand(-1,3)).reshape(-1,3)
288 | 
289 |     # Initialize frame and joint angle tensors
290 |     num_frames_total = T0.shape[0]  # total number of frames on robots
291 |     T = torch.eye(dim+1, device=device, dtype=dtype)[None,:,:].repeat(num_frames_total,1,1)
292 |     theta = torch.zeros(num_frames_total, dtype=dtype, device=device)
293 | 
294 | 
295 |     # indices of relevant p and q nodes
296 |     ind = torch.arange(num_frames_total, device=device)
297 |     idx_p = 2*(ind - 1) + 2
298 |     idx_q = 2*(ind - 1) + 3
299 | 
300 |     # compute normalized q (i.e., distance fixed to 1) and transform to base frame
301 |     q = torch.baddbmm(B_inv[hl[ind],:-1,-1][:,:,None], B_inv[hl[ind],:-1,:-1], P[idx_q][:,:,None])[:,:,0]
302 | 
303 |     A = torch.matmul(qs_0, omega_z)[:,None,:]
304 |     B = torch.matmul(qs_0, omega_z_sq)[:,None,:]
305 | 
306 |     # generate virtual edge indices used to multiply pairs of joints
307 |     joint_end_ind = torch.cumsum(num_joints + 1, dim=0) - 1  # end indices of joints
308 |     joint_start_ind = joint_end_ind - num_joints  # start indices of joints
309 | 
310 |     ei = torch.zeros(2, joint_start_ind.size(-1), dtype=torch.int64, device=device)
311 |     ei[0] = joint_start_ind + 1  # starting joint id, repeated for all connections
312 |     ei[1] = joint_end_ind
313 |     inc = torch.tensor([[1],[0]], dtype=torch.int64, device=device)  # tensor for incrementing
314 |     for _ in range(1, num_joints.max()):
315 |         # q point expressed in previous frame
316 |         qs = torch.bmm(T[ei[0]-1,:-1,:-1].transpose(2,1), (q[ei[0]] - T[ei[0]-1,:-1,-1])[:,:,None])[:,:,0]
317 | 
318 |         # compute angle approximation
319 |         theta[ei[0]-1] = torch.atan2(
320 |             -torch.bmm(A[ei[0]-1], qs[:,None,:].transpose(2,1))[:,0] + 1e-7,
321 |             torch.bmm(B[ei[0]-1], qs[:,None,:].transpose(2,1))[:,0] + 1e-7
322 |         ).reshape(-1)
323 | 
324 |         rotmat = SE3_from(rotz(theta[ei[0]-1]))
325 |         T[ei[0]] = torch.bmm(torch.bmm(T[ei[0]-1], rotmat), T_rel[ei[0]-1].expand(ei[0].size(-1),-1,-1))
326 | 
327 |         ei, _ = remove_self_loops(ei + inc)  # removes equivalent elements
328 | 
329 |     ind = torch.arange(num_joints.sum(), device=device) + robot_ids.repeat_interleave(num_joints)
330 |     # ind = torch.arange(num_joints.sum(), device=device) + torch.arange(num_robots, device=device).repeat_interleave(num_joints)
331 |     # ind = torch.arange(num_joints.sum(), device=device) + torch.repeat_interleave(torch.arange(num_robots, device=device), num_joints, dim=0)
332 | 
333 |     if T_final is not None:
334 |         # parallel = torch.linalg.cross(T_rel[joint_end_ind-1,:-1,-1], torch.tensor([[0,0,1]], device=device, dtype=dtype)).norm(dim=-1) < 1e-6
335 |         # parallel_idx = parallel.nonzero()
336 |         T_th = torch.bmm(SE3_inv(T[joint_end_ind-1]), T_final)
337 |         theta[joint_end_ind-1] = theta[joint_end_ind-1] + torch.atan2(T_th[:, 1, 0], T_th[:, 0, 0])
338 | 
339 |     return theta[ind]
340 | 
341 | def batchPmultiDOF(T: torch.Tensor, num_joints: torch.Tensor) -> torch.Tensor:
342 |     # from frame transforms to batched points
343 |     # the key problem is to sort everything properly
344 | 
345 |     device = T.device  # GPU or CPU
346 |     dtype = T.dtype  # data type for non-integer
347 |     dim = 3  # problem dimension
348 | 
349 |     num_frames_total = T.shape[0]  # total number of frames on robots
350 |     num_nodes = 2*(num_joints+1) + (dim-1)  # number of nodes for point graphs
351 |     node_start_ind = torch.cumsum(num_nodes,dim=0) - num_nodes  # start indices for nodes
352 | 
353 |     # indices of relevant p and q nodes
354 |     ind = torch.arange(num_frames_total, device=device)
355 |     idx_p = 2*(ind - 1) + 2
356 |     idx_q = 2*(ind - 1) + 3
357 | 
358 |     pos_p = T[:,:-1,-1]
359 |     pos_q = T[:,:-1,-1] + T[:,:-1,-2]
360 |     pos_pq = torch.zeros(2*num_frames_total, 3, dtype=dtype, device=device)
361 |     pos_pq[torch.cat([idx_p, idx_q], dim=-1), :] = torch.cat([pos_p, pos_q], dim=0)
362 | 
363 | 
364 |     P = torch.zeros(num_nodes.sum(), 3, dtype=dtype, device=device)  # match pos_pq to avoid a device/dtype mismatch
365 | 
366 |     mask = torch.tensor(True, device=device).repeat(num_nodes.sum())
367 |     mask[node_start_ind+1] = False
368 |     mask[node_start_ind+2] = False
369 |     P = torch.masked_scatter(P, mask[:,None].expand(-1,3), pos_pq)
370 |     P[node_start_ind + 1] = torch.tensor([1,0,0], dtype=dtype, device=device)
371 |     P[node_start_ind + 2] = torch.tensor([0,-1,0], dtype=dtype, device=device)
372 | 
373 |     return P
374 | 
375 | @torch.compile
376 | def batchFKmultiDOF(T0: torch.Tensor, q: torch.Tensor, num_joints: torch.Tensor) -> torch.Tensor:
377 |     # Computes FK for multiple robots with varying DOF using graphs.
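    # All robots in the batch are advanced one joint per loop iteration, so
    # the main loop runs max(num_joints) times rather than once per robot.
    # (Hypothetical example: two robots with 2 and 3 joints give T0 with
    # 3 + 4 = 7 home frames and q with 5 angles; the returned T stacks all
    # 7 resulting frames in the same order as T0.)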
378 |     # Uses the lie group FK formula:
379 |     # T_ee = exp(S0*q0)*exp(S1*q1)*...*exp(Sn*qn)*M
380 |     # ----------------------------------------------------------------------
381 |     # T0 [num_frames x 4 x 4] - robot frames in home poses
382 |     # q [num_edges x 1] - joint angles
383 |     # num_joints [num_robots] - number of joints per robot.
384 |     # This is the joint-based model description, so each robot has num_joints+1 frames.
385 |     # edge_index - all edges in batch (constructed internally below)
386 |     # ----------------------------------------------------------------------
387 | 
388 |     device = T0.device
389 |     num_frames = num_joints + 1
390 |     num_robots = num_frames.shape[0]
391 |     total_num_frames = T0.shape[0]
392 |     max_num_frames = num_frames.max()
393 |     frame_start_ind = num_frames.cumsum(dim=0) - num_frames
394 | 
395 |     # get edge index
396 |     ind_i = torch.arange(num_joints.sum(), device=device) + torch.arange(num_robots, device=device).repeat_interleave(num_joints)
397 |     ind_j = ind_i + 1
398 |     edge_index = torch.concat([ind_i, ind_j], dim=0).reshape(2,-1)
399 |     total_num_edges = edge_index.shape[-1]
400 | 
401 |     # compute joint screws
402 |     S = batchJointScrews(T0[edge_index[0]])
403 |     dim_tangent = S.shape[-1]
404 |     dim = (dim_tangent // 3) + 1
405 | 
406 |     # get Tq = exp(S0*q) for every joint for every robot, shape [total_num_frames x 4 x 4]
407 |     Tq = torch.empty(total_num_frames, dim+1, dim+1, device=device)
408 |     Tq[edge_index[0]] = torch_T_from_log_angle_axis(S, q.view(total_num_edges, 1))
409 | 
410 |     T0 = T0.reshape(total_num_frames, dim+1, dim+1)
411 |     T = torch.eye(dim+1, device=device).unsqueeze(0).repeat(total_num_frames,1,1)
412 | 
413 |     ei = torch.zeros(edge_index.shape[0], total_num_frames, dtype=torch.int64, device=device)
414 |     # ei[0] = frame_start_ind.repeat_interleave(num_frames) # starting jnt id repeat for all con
415 |     ei[0] = torch.repeat_interleave(frame_start_ind, num_frames, 0)
416 |     ei[1] = torch.arange(0, total_num_frames)  # all others
417 |     inc = torch.tensor([[1],[0]], dtype=torch.int64, device=device)  # tensor for incrementing
418 |     for _ in range(max_num_frames-1):
419 |         ei, _ = remove_self_loops(ei)  # remove self-loops for starting frames
420 |         T[ei[1]] = torch.bmm(T[ei[1]], Tq[ei[0]])  # apply action from prev to cur
421 |         ei = ei + inc  # next
422 |     T[edge_index[1]] = torch.bmm(T[edge_index[1]], T0[edge_index[1]])
423 | 
424 |     return T
425 | 
426 | def batchPostProc(
427 |     P: torch.Tensor,
428 |     T0: torch.Tensor,
429 |     num_joints: torch.Tensor,
430 |     T_final: torch.Tensor,
431 |     T_rel: torch.Tensor,
432 |     qs_0: torch.Tensor,
433 |     A: torch.Tensor,
434 |     B: torch.Tensor,
435 | ) -> Tuple[torch.Tensor, torch.Tensor]:
436 |     # Computes IK for multiple robots with varying DOF from points.
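    # Mirrors batchIKmultiDOF, but takes the precomputed structure terms
    # (T_rel, qs_0, A, B) as inputs instead of rebuilding them, and also
    # returns the resulting frames T alongside the joint angles.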
437 |     # ----------------------------------------------------------------------
438 |     # P [sum(num_nodes) x 3] - points corresponding to the distance-geometric
439 |     # robot model described in (Maric et al., 2021)
440 |     # T0 [num_frames x 4 x 4] - robot frames in home poses
441 |     # num_joints [num_robots] - number of joints (DOF) for each robot in batch
442 |     # ----------------------------------------------------------------------
443 |     device = T0.device  # GPU or CPU
444 |     dtype = T0.dtype  # data type for non-integer
445 |     dim = P.shape[1]  # dimension (2 or 3)
446 | 
447 |     # number of robots, joints, unique ids
448 |     num_robots = num_joints.shape[0]  # total number of robots
449 |     num_nodes = 2*(num_joints+1) + (dim-1)  # number of nodes for point graphs
450 |     node_start_ind = torch.cumsum(num_nodes,dim=0) - num_nodes  # start indices for nodes
451 |     robot_ids = torch.arange(num_robots, device=device)
452 | 
453 |     # normalizes the node positions to the canonical coordinate system
454 |     x_hat = (P[node_start_ind + 1] - P[node_start_ind])
455 |     y_hat = -(P[node_start_ind + 2] - P[node_start_ind])
456 |     z_hat = (P[node_start_ind + 3] - P[node_start_ind])
457 | 
458 |     # get modified base frames
459 |     R = torch.cat([x_hat.unsqueeze(1), y_hat.unsqueeze(1), z_hat.unsqueeze(1)], dim = 1).transpose(2,1)
460 |     B_inv = SE3_inv_from(R, P[node_start_ind])
461 |     hl = robot_ids.repeat_interleave(num_joints + 1, dim=0)
462 | 
463 |     mask = torch.ones(num_nodes.sum(), device=device, dtype=torch.bool)
464 |     mask[node_start_ind+1] = False
465 |     mask[node_start_ind+2] = False
466 |     P = P.masked_select(mask[:,None].expand(-1,3)).reshape(-1,3)
467 | 
468 |     # Initialize frame and joint angle tensors
469 |     num_frames_total = T0.shape[0]  # total number of frames on robots
470 |     T = torch.eye(dim+1, device=device, dtype=dtype)[None,:,:].repeat(num_frames_total,1,1)
471 |     theta = torch.zeros(num_frames_total, dtype=dtype, device=device)
472 | 
473 |     # indices of relevant p and q nodes
474 |     ind = torch.arange(num_frames_total, device=device)
475 |     idx_p = 2*(ind - 1) + 2
476 |     idx_q = 2*(ind - 1) + 3
477 | 
478 |     # compute normalized q (i.e., distance fixed to 1) and transform to base frame
479 |     q = torch.baddbmm(B_inv[hl[ind],:-1,-1][:,:,None], B_inv[hl[ind],:-1,:-1], P[idx_q][:,:,None])[:,:,0]
480 | 
481 |     # generate virtual edge indices used to multiply pairs of joints
482 |     joint_end_ind = torch.cumsum(num_joints + 1, dim=0) - 1  # end indices of joints
483 |     joint_start_ind = joint_end_ind - num_joints  # start indices of joints
484 | 
485 |     ei = torch.zeros(2, joint_start_ind.size(-1), dtype=torch.int64, device=device)
486 |     ei[0] = joint_start_ind + 1  # starting joint id, repeated for all connections
487 |     ei[1] = joint_end_ind
488 |     inc = torch.tensor([[1],[0]], dtype=torch.int64, device=device)  # tensor for incrementing
489 |     for _ in range(1, num_joints.max()):
490 |         # q point expressed in previous frame
491 |         qs = torch.bmm(T[ei[0]-1,:-1,:-1].transpose(2,1), (q[ei[0]] - T[ei[0]-1,:-1,-1])[:,:,None])[:,:,0]
492 | 
493 |         # compute angle approximation
494 |         theta[ei[0]-1] = torch.atan2(
495 |             -torch.bmm(A[ei[0]-1], qs[:,:,None])[:,0] + 1e-7,
496 |             torch.bmm(B[ei[0]-1], qs[:,:,None])[:,0] + 1e-7
497 |         ).reshape(-1)
498 | 
499 |         rotmat = SE3_from(rotz(theta[ei[0]-1]))
500 |         T[ei[0]] = torch.bmm(torch.bmm(T[ei[0]-1], rotmat), T_rel[ei[0]-1])
501 | 
502 |         ei, _ = remove_self_loops(ei + inc)  # removes equivalent elements
503 | 
504 |     ind = torch.arange(num_joints.sum(), device=device) + robot_ids.repeat_interleave(num_joints)
505 | 
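    # If a goal pose is given, fold the residual rotation about the last
    # joint's z-axis into theta so the end-effector orientation matches
    # T_final, then update the final frames accordingly.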
506 |     if T_final is not None:
507 |         T_th = torch.bmm(SE3_inv(T[joint_end_ind-1]), T_final)
508 |         theta[joint_end_ind-1] = theta[joint_end_ind-1] + torch.atan2(T_th[:, 1, 0], T_th[:, 0, 0])
509 |         rotmat = SE3_from(rotz(theta[joint_end_ind-1]))
510 |         T[joint_end_ind] = torch.bmm(torch.bmm(T[joint_end_ind-1], rotmat), T_rel[joint_end_ind-1].expand(joint_end_ind.size(-1),-1,-1))
511 | 
512 |     return theta[ind], T
513 | 
--------------------------------------------------------------------------------
/paper_models.zip:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e664a96e7a64e01cbd004206942806eed00522ea15841587b0bccfe465711558
3 | size 259136245
4 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | # TODO: see https://github.com/pymanopt/pymanopt/blob/master/setup.py for more later
4 | setup(
5 |     name="generative-graphik",
6 |     version="0.1",
7 |     description="Generative inverse kinematics",
8 |     author="Filip Maric, Oliver Limoyo",
9 |     author_email="filip.maric@robotics.utias.utoronto.ca, oliver.limoyo@robotics.utias.utoronto.ca",
10 |     license="MIT",
11 |     url="https://github.com/utiasSTARS/generative-graphik",
12 |     packages=find_packages(),
13 |     install_requires=[
14 |         "pandas",
15 |         "urdfpy",
16 |         "numpy <= 1.23.5",
17 |         "liegroups @ git+ssh://git@github.com/utiasSTARS/liegroups@generative_ik#egg=liegroups",
18 |         "graphIK @ git+ssh://git@github.com/utiasSTARS/graphIK@generative_ik#egg=graphIK",
19 |         "networkx >= 2.8.7",
20 |         "tensorboard"
21 |     ],
22 |     python_requires=">=3.8",
23 | )
24 | 
--------------------------------------------------------------------------------