├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── VCD ├── dataset.py ├── dataset_edge.py ├── generate_cached_initial_state.py ├── main.py ├── main_plan.py ├── main_train_edge.py ├── models.py ├── rs_planner.py ├── utils │ ├── __init__.py │ ├── camera_utils.py │ ├── data_utils.py │ ├── gemo_utils.py │ ├── plot_utils.py │ └── utils.py ├── vc_dynamics.py └── vc_edge.py ├── chester ├── README.md ├── add_variants.py ├── config.py ├── containers │ └── ubuntu-16.04-lts-rl.README ├── docs │ ├── Makefile │ ├── README.md │ ├── make.bat │ ├── readme.md │ └── source │ │ ├── _static │ │ └── basic_screenshot.png │ │ ├── conf.py │ │ ├── getting_started.rst │ │ ├── index.rst │ │ ├── launcher.rst │ │ ├── logger.rst │ │ └── visualization.rst ├── examples │ ├── cplot_example.py │ ├── pgm_plot.py │ ├── presets.py │ ├── presets2.py │ ├── presets3.py │ ├── presets_tiancheng.py │ ├── train.py │ └── train_launch.py ├── logger.py ├── plotting │ └── cplot.py ├── pull_result.py ├── pull_s3_result.py ├── rsync_exclude ├── rsync_include ├── run_exp.py ├── run_exp_worker.py ├── scripts │ ├── install_miniconda.sh │ └── install_mpi4py.sh ├── setup_ec2_for_chester.py ├── slurm.py ├── upload_result.py ├── utils_s3.py └── video_recorder.py ├── compile_1.0.sh ├── environment.yml ├── prepare_1.0.sh └── pretrained ├── README.md ├── vis_dynamics.gif └── vis_planning.gif /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.zip 3 | pretrained/dataset 4 | *.pkl 5 | softgym_rpad 6 | softgym_public 7 | wandb 8 | GNS/pcl_filter/build 9 | rlpyt_cloth/data 10 | datasets 11 | dpi_visualization 12 | .vscode 13 | chester/private 14 | wandb 15 | videos 16 | .idea 17 | data 18 | *.simg 19 | *.img 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | pip-wheel-metadata/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | .python-version 103 | 104 | # celery beat schedule file 105 | celerybeat-schedule 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | 138 | # Results 139 | results/ 140 | 141 | # DPI data 142 | DPI-Net/dump_*/ 143 | DPI-Net/dump_*/ 144 | 145 | test_*/ 146 | imgs/ 147 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "softgym"] 2 | path = softgym 3 | url = https://github.com/Xingyu-Lin/softgym.git 4 | 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Carnegie Mellon University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Learning Visible Connectivity Dynamics for Cloth Smoothing

2 | 3 | Xingyu Lin*, Yufei Wang*, Zixuan Huang, David Held, CoRL 2021. `*` indicates equal contribution (order by dice rolling). 4 | 5 | [Website](https://sites.google.com/view/vcd-cloth/home) / [ArXiv](https://arxiv.org/pdf/2105.10389.pdf) 6 | 7 | # Table of Contents 8 | - 1 [Simulation](#simulation) 9 | - 1.1 [Setup](#setup) 10 | - 1.2 [Train VCD](#train-vcd) 11 | - 1.3 [Planning with VCD](#planning-with-vcd) 12 | - 1.4 [Graph Imitation Learning](#graph-imitation-learning) 13 | - 1.5 [Pretrained Models](#pretrained-models) 14 | - 1.6 [Demo](#demo) 15 | ---- 16 | # Simulation 17 | 18 | ## Setup 19 | This repository is a subset of [SoftAgent](https://github.com/Xingyu-Lin/softagent) cleaned up for VCD. Environment setup for VCD is similar to that of SoftAgent. 20 | 1. Install [SoftGym](https://github.com/Xingyu-Lin/softgym). Then copy softgym into this directory as a submodule by running `cp -r [path to softgym] ./`. Switch to the updated `vcd` branch of softgym with `cd softgym && git checkout vcd`. 21 | 2. You should have a conda environment named `softgym`. Install the additional packages required by VCD with `conda env update --file environment.yml`. 22 | 3. Generate initial environment configurations and cache them by running `python VCD/generate_cached_initial_state.py`. 23 | 4. Run `./compile_1.0.sh && . ./prepare_1.0.sh` to compile PyFleX and set up other paths. 24 | 25 | ## Train VCD 26 | * Generate the dataset for training by running 27 | ``` 28 | python VCD/main.py --gen_data=1 --dataf=./data/vcd 29 | ``` 30 | Please refer to `main.py` for argument options. 31 | 32 | * Train the dynamics model by running 33 | ``` 34 | python VCD/main.py --gen_data=0 --dataf=./data/vcd_dyn 35 | ``` 36 | * Train the EdgeGNN model by running 37 | ``` 38 | python VCD/main_train_edge.py --gen_data=0 --dataf=./data/vcd_edge 39 | ``` 40 | ## Planning with VCD 41 | ``` 42 | python VCD/main_plan.py --edge_model_path={path_to_trained_edge_model}\ 43 | --partial_dyn_path={path_to_trained_dynamics_model} 44 | ``` 45 | An example that loads the models trained for 120 epochs: 46 | ``` 47 | python VCD/main_plan.py --edge_model_path ./data/vcd_edge/vsbl_edge_120.pth\ 48 | --partial_dyn_path ./data/vcd_dyn/vsbl_dyn_120.pth 49 | ``` 50 | ## Graph Imitation Learning 51 | 1. Train the dynamics model on the full mesh 52 | ``` 53 | python VCD/main.py --gen_data=0 --dataf=./data/vcd --train_mode=full 54 | ``` 55 | 2. Train the dynamics model on the partial point cloud while imitating the teacher model 56 | ``` 57 | python VCD/main.py --gen_data=0 --dataf=./data/vcd --train_mode=graph_imit --full_dyn_path={path_to_teacher_model} 58 | ``` 59 | 60 | ## Pretrained Models 61 | Please refer to [this page](pretrained/README.md) for downloading the pretrained models.
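Assuming the pretrained checkpoints have been downloaded and extracted into `pretrained/` (the file names below are illustrative; use whatever names the download page provides), they can be passed to the planning script the same way as self-trained models:
```
python VCD/main_plan.py --edge_model_path pretrained/vsbl_edge_best.pth\
 --partial_dyn_path pretrained/vsbl_dyn_best.pth
```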
62 | 63 | ## Demo 64 | * Dynamics rollout 65 | ![](pretrained/vis_dynamics.gif) 66 | * Planning on square cloth 67 | ![](pretrained/vis_planning.gif) 68 | ## Cite 69 | If you find this codebase useful in your research, please consider citing: 70 | ``` 71 | @inproceedings{lin2022learning, 72 | title={Learning visible connectivity dynamics for cloth smoothing}, 73 | author={Lin, Xingyu and Wang, Yufei and Huang, Zixuan and Held, David}, 74 | booktitle={Conference on Robot Learning}, 75 | pages={256--266}, 76 | year={2022}, 77 | organization={PMLR} 78 | } 79 | 80 | @inproceedings{corl2020softgym, 81 | title={SoftGym: Benchmarking Deep Reinforcement Learning for Deformable Object Manipulation}, 82 | author={Lin, Xingyu and Wang, Yufei and Olkin, Jake and Held, David}, 83 | booktitle={Conference on Robot Learning}, 84 | year={2020} 85 | } 86 | ``` 87 | -------------------------------------------------------------------------------- /VCD/dataset_edge.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy import spatial 4 | from torch_geometric.data import Data 5 | 6 | from VCD.utils.camera_utils import get_observable_particle_index_3 7 | from VCD.dataset import ClothDataset 8 | from VCD.utils.utils import load_data, voxelize_pointcloud 9 | 10 | 11 | class ClothDatasetPointCloudEdge(ClothDataset): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | 15 | def __getitem__(self, idx): 16 | data = self._prepare_transition(idx) 17 | d = self.build_graph(data) 18 | return Data.from_dict(d) 19 | 20 | def _prepare_transition(self, idx, eval=False): 21 | pred_time_interval = self.args.pred_time_interval 22 | success = False 23 | next = 1 if not eval else self.args.time_step - self.args.n_his 24 | 25 | while not success: 26 | idx_rollout = (idx // (self.args.time_step - self.args.n_his)) % self.n_rollout 27 | idx_timestep = (self.args.n_his - pred_time_interval) + idx % (self.args.time_step - self.args.n_his) 28 | idx_timestep = max(idx_timestep, 0) 29 | 30 | data = load_data(self.data_dir, idx_rollout, idx_timestep, self.data_names) 31 | pointcloud = data['pointcloud'].astype(np.float32) 32 | if len(pointcloud.shape) != 2: 33 | print('dataset_edge.py, erroneous data. What is going on?') 34 | import pdb 35 | pdb.set_trace() 36 | 37 | idx += next 38 | continue 39 | 40 | if len(pointcloud) < 100: # TODO Filter these during dataset generation 41 | print('dataset_edge.py, fix this') 42 | import pdb 43 | pdb.set_trace() 44 | idx += next 45 | continue 46 | 47 | vox_pc = voxelize_pointcloud(pointcloud, self.args.voxel_size) 48 | 49 | partial_particle_pos = data['positions'][data['downsample_idx']][data['downsample_observable_idx']] 50 | if len(vox_pc) <= len(partial_particle_pos): 51 | success = True 52 | 53 | # NOTE: in eval mode, give up and return None below instead of retrying with another index
54 | if eval and not success: 55 | return None 56 | 57 | idx += next 58 | 59 | pointcloud, partial_pc_mapped_idx = get_observable_particle_index_3(vox_pc, partial_particle_pos, threshold=self.args.voxel_size) 60 | normalized_vox_pc = vox_pc - np.mean(vox_pc, axis=0) 61 | 62 | ret_data = { 63 | 'scene_params': data['scene_params'], 64 | 'downsample_observable_idx': data['downsample_observable_idx'], 65 | 'normalized_vox_pc': normalized_vox_pc, 66 | 'partial_pc_mapped_idx': partial_pc_mapped_idx, 67 | } 68 | if eval: 69 | ret_data['downsample_idx'] = data['downsample_idx'] 70 | ret_data['pointcloud'] = vox_pc 71 | 72 | return ret_data 73 | 74 | def _compute_edge_attr(self, vox_pc): 75 | point_tree = spatial.cKDTree(vox_pc) 76 | undirected_neighbors = np.array(list(point_tree.query_pairs(self.args.neighbor_radius, p=2))).T 77 | 78 | if len(undirected_neighbors) > 0: 79 | dist_vec = vox_pc[undirected_neighbors[0, :]] - vox_pc[undirected_neighbors[1, :]] 80 | dist = np.linalg.norm(dist_vec, axis=1, keepdims=True) 81 | edge_attr = np.concatenate([dist_vec, dist], axis=1) 82 | edge_attr_reverse = np.concatenate([-dist_vec, dist], axis=1) 83 | 84 | # Generate directed edge list and corresponding edge attributes 85 | edges = torch.from_numpy(np.concatenate([undirected_neighbors, undirected_neighbors[::-1]], axis=1)) 86 | edge_attr = torch.from_numpy(np.concatenate([edge_attr, edge_attr_reverse])) 87 | else: 88 | print("number of distance edges is 0! adding fake edges") 89 | edges = np.zeros((2, 2), dtype=np.int64) 90 | edges[0][0] = 0 91 | edges[1][0] = 1 92 | edges[0][1] = 0 93 | edges[1][1] = 2 94 | edge_attr = np.zeros((2, self.args.relation_dim), dtype=np.float32) 95 | edges = torch.from_numpy(edges).long() # edge indices must be integer typed, not bool 96 | edge_attr = torch.from_numpy(edge_attr) 97 | print("shape of edges: ", edges.shape) 98 | print("shape of edge_attr: ", edge_attr.shape) 99 | 100 | return edges, edge_attr 101 | 102 | def build_graph(self, data, get_gt_edge_label=True): 103 | """ 104 | data: positions, picked_points, picked_point_positions, scene_params 105 | downsample: whether to downsample the graph 106 | test: if False, we are in the training mode, where we know exactly the picked point and its movement; 107 | if True, we are in the test mode, where we have to infer the picked point in the (downsampled) graph and compute 108 | its movement. 109 | 110 | return: 111 | node_attr: N x (vel_history x 3) 112 | edges: 2 x E, the edges 113 | edge_attr: E x edge_feature_dim 114 | gt_mesh_edge: 0/1 label for ground-truth mesh edge connection.
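Note: the edges returned by `_compute_edge_attr` are bidirectional, i.e. every undirected neighbor pair within `neighbor_radius` contributes two directed edges with mirrored displacement vectors.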
115 | """ 116 | node_attr = torch.from_numpy(data['normalized_vox_pc']) 117 | edges, edge_attr = self._compute_edge_attr(data['normalized_vox_pc']) 118 | 119 | if get_gt_edge_label: 120 | gt_mesh_edge = self._get_gt_mesh_edge(data, edges) 121 | gt_mesh_edge = torch.from_numpy(gt_mesh_edge) 122 | else: 123 | gt_mesh_edge = None 124 | 125 | return { 126 | 'x': node_attr, 127 | 'edge_index': edges, 128 | 'edge_attr': edge_attr, 129 | 'gt_mesh_edge': gt_mesh_edge 130 | } 131 | 132 | def _get_gt_mesh_edge(self, data, distance_edges): 133 | scene_params, observable_particle_idx, partial_pc_mapped_idx = data['scene_params'], data['downsample_observable_idx'], data['partial_pc_mapped_idx'] 134 | _, cloth_xdim, cloth_ydim, _ = scene_params 135 | cloth_xdim, cloth_ydim = int(cloth_xdim), int(cloth_ydim) 136 | 137 | observable_mask = np.zeros(cloth_xdim * cloth_ydim) 138 | observable_mask[observable_particle_idx] = 1 139 | 140 | num_edges = distance_edges.shape[1] 141 | gt_mesh_edge = np.zeros((num_edges, 1), dtype=np.float32) 142 | 143 | for edge_idx in range(num_edges): 144 | # the edge index is in the range [0, len(pointcloud) - 1] 145 | # needs to convert it back to the idx in the downsampled graph 146 | s = int(distance_edges[0][edge_idx].item()) 147 | r = int(distance_edges[1][edge_idx].item()) 148 | 149 | # map from pointcloud idx to observable particle index 150 | s = partial_pc_mapped_idx[s] 151 | r = partial_pc_mapped_idx[r] 152 | 153 | s = observable_particle_idx[s] 154 | r = observable_particle_idx[r] 155 | 156 | if (r == s + 1 or r == s - 1 or 157 | r == s + cloth_xdim or r == s - cloth_xdim or 158 | r == s + cloth_xdim + 1 or r == s + cloth_xdim - 1 or 159 | r == s - cloth_xdim + 1 or r == s - cloth_xdim - 1 160 | ): 161 | gt_mesh_edge[edge_idx] = 1 162 | 163 | return gt_mesh_edge 164 | -------------------------------------------------------------------------------- /VCD/generate_cached_initial_state.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from softgym.registered_env import SOFTGYM_ENVS, env_arg_dict 3 | from VCD.main import get_default_args 4 | 5 | 6 | def create_env(args): 7 | assert args.env_name == 'ClothFlatten' 8 | 9 | env_args = copy.deepcopy(env_arg_dict[args.env_name]) # Default args 10 | env_args['cached_states_path'] = args.cached_states_path 11 | env_args['num_variations'] = args.num_variations 12 | env_args['use_cached_states'] = False 13 | env_args['save_cached_states'] = True 14 | 15 | env_args['render'] = False 16 | env_args['headless'] = True 17 | env_args['render_mode'] = 'cloth' if args.gen_data else 'particle' 18 | env_args['camera_name'] = 'default_camera' 19 | env_args['camera_width'] = 360 20 | env_args['camera_height'] = 360 21 | 22 | env_args['num_picker'] = 2 # The extra picker is hidden and does not really matter 23 | env_args['picker_radius'] = 0.01 24 | env_args['picker_threshold'] = 0.00625 25 | env_args['action_repeat'] = 1 26 | 27 | if args.partial_observable and args.gen_data: 28 | env_args['observation_mode'] = 'cam_rgb' 29 | 30 | return SOFTGYM_ENVS[args.env_name](**env_args) 31 | 32 | 33 | if __name__ == '__main__': 34 | args = get_default_args() 35 | env = create_env(args) 36 | -------------------------------------------------------------------------------- /VCD/main.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import json 3 | import argparse 4 | 5 | from VCD.utils.utils import vv_to_args, set_resource 6 | import copy 
7 | from VCD.vc_dynamics import VCDynamics 8 | from VCD.vc_edge import VCConnection 9 | from chester import logger 10 | from VCD.utils.utils import configure_logger, configure_seed 11 | 12 | 13 | def get_default_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--exp_name', type=str, default='test', help='Name of the experiment') 16 | parser.add_argument('--log_dir', type=str, default='data/dyn_debug/', help='Logging directory') 17 | parser.add_argument('--seed', type=int, default=100) 18 | 19 | # Env 20 | parser.add_argument('--env_name', type=str, default='ClothFlatten') 21 | parser.add_argument('--cached_states_path', type=str, default='1213_release_n1000.pkl') 22 | parser.add_argument('--num_variations', type=int, default=1000) 23 | parser.add_argument('--partial_observable', type=bool, default=True, help="Whether only the partial point cloud can be observed") 24 | parser.add_argument('--particle_radius', type=float, default=0.00625, help='Particle radius for the cloth') 25 | ## pyflex shape state 26 | parser.add_argument('--shape_state_dim', type=int, default=14, help="[xyz, xyz_last, quat(4), quat_last(4)]") 27 | 28 | # Dataset 29 | parser.add_argument('--n_rollout', type=int, default=2000, help='Number of training trajectories') 30 | parser.add_argument('--time_step', type=int, default=100, help='Time steps per trajectory') 31 | parser.add_argument('--dt', type=float, default=1. / 100.) 32 | parser.add_argument('--pred_time_interval', type=int, default=5, help='Interval of timesteps between each dynamics prediction (model dt)') 33 | parser.add_argument('--train_valid_ratio', type=float, default=0.9, help="Ratio between training and validation") 34 | parser.add_argument('--dataf', type=str, default='./data/release/', help='Path to dataset') 35 | parser.add_argument('--gen_data', type=int, default=0, help='Whether to generate dataset') 36 | parser.add_argument('--gen_gif', type=bool, default=0, help='Whether to also save gif of each trajectory (for debugging)') 37 | 38 | # Model 39 | parser.add_argument('--global_size', type=int, default=128, help="Number of hidden nodes for global in GNN") 40 | parser.add_argument('--n_his', type=int, default=5, help="Number of history step input to the dynamics") 41 | parser.add_argument('--down_sample_scale', type=int, default=3, help="Downsample the simulated cloth by a scale of 3 on each dimension") 42 | parser.add_argument('--voxel_size', type=float, default=0.0216) 43 | parser.add_argument('--neighbor_radius', type=float, default=0.045, help="Radius for connecting nearby edges") 44 | parser.add_argument('--use_rest_distance', type=bool, default=True, help="Subtract the rest distance for the edge attribute of mesh edges") 45 | parser.add_argument('--use_mesh_edge', type=bool, default=True) 46 | parser.add_argument('--collect_data_delta_move_min', type=float, default=0.15) 47 | parser.add_argument('--collect_data_delta_move_max', type=float, default=0.4) 48 | parser.add_argument('--proc_layer', type=int, default=10, help="Number of processor layers in GNN") 49 | parser.add_argument('--state_dim', type=int, default=18, 50 | help="Dim of node feature input. 
Computed based on n_his: 3 x 5 + 1 dist to ground + 2 one-hot encoding of picked particle") 51 | parser.add_argument('--relation_dim', type=int, default=7, help="""Dim of edge feature input: 52 | 3 for directional vector + 1 for directional vector magnitude + 2 for one-hot encoding of mesh or collision edge + 1 for rest distance 53 | """) 54 | 55 | # Resume training 56 | parser.add_argument('--edge_model_path', type=str, default=None, help='Path to a trained edgeGNN model') 57 | parser.add_argument('--full_dyn_path', type=str, default=None, help='Path to a dynamics model using full point cloud') 58 | parser.add_argument('--partial_dyn_path', type=str, default=None, help='Path to a dynamics model using partial point cloud') 59 | parser.add_argument('--load_optim', type=bool, default=False, help='Load optimizer state when resuming training') 60 | 61 | # Training 62 | parser.add_argument('--train_mode', type=str, default='vsbl', help='Should be in ["vsbl", "graph_imit", "full"]') 63 | parser.add_argument('--n_epoch', type=int, default=1000) 64 | parser.add_argument('--beta1', type=float, default=0.9) 65 | parser.add_argument('--lr', type=float, default=1e-4) 66 | parser.add_argument('--fixed_lr', type=bool, default=False, help='By default, a decaying lr is used.') 67 | parser.add_argument('--batch_size', type=int, default=16) 68 | parser.add_argument('--cuda_idx', type=int, default=0) 69 | parser.add_argument('--num_workers', type=int, default=10, help='Number of workers for the dataloader') 70 | parser.add_argument('--eval', type=int, default=0, help='Whether to only evaluate the model') 71 | parser.add_argument('--nstep_eval_rollout', type=int, default=20, help='Number of rollout trajectories for evaluation') 72 | parser.add_argument('--save_model_interval', type=int, default=5, help='Save the model every N epochs during training') 73 | parser.add_argument('--use_wandb', type=bool, default=False, help='Use Weights & Biases for logging') 74 | 75 | # For graph imitation 76 | parser.add_argument('--vsbl_lr', type=float, default=1e-4, help='Learning rate for visible (vsbl) point cloud dynamics') 77 | parser.add_argument('--full_lr', type=float, default=1e-4, help='Learning rate for full point cloud dynamics') 78 | parser.add_argument('--tune_teach', type=bool, default=False, help='Whether to allow the teacher to adapt during graph imitation') 79 | parser.add_argument('--copy_teach', type=list, default=['encoder', 'decoder'], help="Which modules of the student are initialized from the teacher") 80 | parser.add_argument('--imit_w_lat', type=float, default=1, help='Weight for imitating the global feature') 81 | parser.add_argument('--imit_w', type=float, default=5, help='Weight for the imitation loss (vs. the student accel loss)') 82 | parser.add_argument('--reward_w', type=float, default=1e5, help='Weight for the reward loss') 83 | 84 | # For ablation 85 | parser.add_argument('--fix_collision_edge', type=bool, default=False, help='Ablation that uses fixed collision edges during rollout') 86 | parser.add_argument('--use_collision_as_mesh_edge', type=bool, default=False, 87 | help='If True, will not use gt mesh edges, but use all collision edges as mesh edges.') 88 | 89 | args = parser.parse_args() 90 | return args 91 | 92 | 93 | def create_env(args): 94 | from softgym.registered_env import env_arg_dict 95 | from softgym.registered_env import SOFTGYM_ENVS 96 | assert args.env_name == 'ClothFlatten' 97 | 98 | env_args = copy.deepcopy(env_arg_dict[args.env_name]) # Default args 99 | env_args['cached_states_path'] =
args.cached_states_path 100 | env_args['num_variations'] = args.num_variations 101 | 102 | env_args['render'] = True 103 | env_args['headless'] = True 104 | env_args['render_mode'] = 'cloth' if args.gen_data else 'particle' 105 | env_args['camera_name'] = 'default_camera' 106 | env_args['camera_width'] = 360 107 | env_args['camera_height'] = 360 108 | 109 | env_args['num_picker'] = 2 # The extra picker is hidden and does not really matter 110 | env_args['picker_radius'] = 0.01 111 | env_args['picker_threshold'] = 0.00625 112 | env_args['action_repeat'] = 1 113 | 114 | if args.partial_observable and args.gen_data: 115 | env_args['observation_mode'] = 'cam_rgb' 116 | 117 | return SOFTGYM_ENVS[args.env_name](**env_args) 118 | 119 | 120 | def main(): 121 | set_resource() # To avoid pin_memory issue 122 | args = get_default_args() 123 | env = create_env(args) 124 | 125 | configure_logger(args.log_dir, args.exp_name) 126 | configure_seed(args.seed) 127 | 128 | with open(osp.join(logger.get_dir(), 'variant.json'), 'w') as f: 129 | json.dump(args.__dict__, f, indent=2, sort_keys=True) 130 | 131 | # load vcd_edge 132 | if args.edge_model_path is not None: 133 | edge_model_vv = json.load(open(osp.join(args.edge_model_path, 'variant.json'))) 134 | edge_model_args = vv_to_args(edge_model_vv) 135 | vcd_edge = VCConnection(edge_model_args, env=env) 136 | vcd_edge.load_model(args.edge_model_path) 137 | print('EdgeGNN successfully loaded from ', args.edge_model_path, flush=True) 138 | else: 139 | vcd_edge = None 140 | vcdynamics = VCDynamics(args, env, vcd_edge=vcd_edge) # Input the vcd_edge model in case we want to train with predicted edges 141 | 142 | if args.gen_data: 143 | vcdynamics.generate_dataset() 144 | else: 145 | vcdynamics.train() 146 | 147 | 148 | if __name__ == '__main__': 149 | main() 150 | -------------------------------------------------------------------------------- /VCD/main_train_edge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from chester import logger 3 | import json 4 | import os.path as osp 5 | from VCD.vc_edge import VCConnection 6 | from VCD.main import create_env 7 | from VCD.utils.utils import configure_logger, configure_seed 8 | 9 | 10 | # TODO Merge arguments 11 | def get_default_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--exp_name', type=str, default='test', help='Name of the experiment') 14 | parser.add_argument('--log_dir', type=str, default='data/edge_debug/', help='Logging directory') 15 | parser.add_argument('--seed', type=int, default=100) 16 | 17 | # Env 18 | parser.add_argument('--env_name', type=str, default='ClothFlatten') 19 | parser.add_argument('--cached_states_path', type=str, default='1213_release_n1000.pkl') 20 | parser.add_argument('--num_variations', type=int, default=1000) 21 | parser.add_argument('--partial_observable', type=bool, default=True, help="Whether only the partial point cloud can be observed") 22 | parser.add_argument('--particle_radius', type=float, default=0.00625, help='Particle radius for the cloth') 23 | 24 | # Dataset 25 | parser.add_argument('--n_rollout', type=int, default=2000, help='Number of training trajectories') 26 | parser.add_argument('--time_step', type=int, default=100, help='Time steps per trajectory') 27 | parser.add_argument('--dt', type=float, default=1. / 100.) 
28 | parser.add_argument('--pred_time_interval', type=int, default=5, help='Interval of timesteps between each dynamics prediction (model dt)') 29 | parser.add_argument('--train_valid_ratio', type=float, default=0.9, help="Ratio between training and validation") 30 | parser.add_argument('--dataf', type=str, default='softgym/softgym/cached_initial_states/', help='Path to dataset') 31 | parser.add_argument('--gen_data', type=int, default=0, help='Whether to generate dataset') 32 | parser.add_argument('--gen_gif', type=bool, default=0, help='Whether to also save gif of each trajectory (for debugging)') 33 | 34 | # Model 35 | parser.add_argument('--global_size', type=int, default=128, help="Number of hidden nodes for global in GNN") 36 | parser.add_argument('--n_his', type=int, default=5, help="Number of history step input to the dynamics") 37 | parser.add_argument('--down_sample_scale', type=int, default=3, help="Downsample the simulated cloth by a scale of 3 on each dimension") 38 | parser.add_argument('--voxel_size', type=float, default=0.0216) 39 | parser.add_argument('--neighbor_radius', type=float, default=0.045, help="Radius for connecting nearby edges") 40 | parser.add_argument('--collect_data_delta_move_min', type=float, default=0.15) 41 | parser.add_argument('--collect_data_delta_move_max', type=float, default=0.4) 42 | parser.add_argument('--proc_layer', type=int, default=10, help="Number of processor layers in GNN") 43 | parser.add_argument('--state_dim', type=int, default=3, 44 | help="Dim of node feature input; the EdgeGNN only uses the xyz position of each point") 45 | parser.add_argument('--relation_dim', type=int, default=4, help="Dim of edge feature input: 3 for the displacement vector + 1 for its magnitude") 46 | 47 | # Resume training 48 | parser.add_argument('--edge_model_path', type=str, default=None, help='Path to a trained edgeGNN model') 49 | parser.add_argument('--load_optim', type=bool, default=False, help='Load optimizer state when resuming training') 50 | 51 | # Training 52 | parser.add_argument('--n_epoch', type=int, default=1000) 53 | parser.add_argument('--beta1', type=float, default=0.9) 54 | parser.add_argument('--lr', type=float, default=1e-4) 55 | parser.add_argument('--batch_size', type=int, default=16) 56 | parser.add_argument('--cuda_idx', type=int, default=0) 57 | parser.add_argument('--num_workers', type=int, default=15, help='Number of workers for the dataloader') 58 | parser.add_argument('--use_wandb', type=bool, default=False, help='Use Weights & Biases for logging') 59 | parser.add_argument('--plot_num', type=int, default=8, help='Number of edge prediction visuals to dump per training epoch') 60 | parser.add_argument('--eval', type=int, default=0, help='Whether to only evaluate the model') 61 | 62 | args = parser.parse_args() 63 | return args 64 | 65 | 66 | def main(): 67 | args = get_default_args() 68 | configure_logger(args.log_dir, args.exp_name) 69 | configure_seed(args.seed) 70 | 71 | # Dump parameters 72 | with open(osp.join(logger.get_dir(), 'variant.json'), 'w') as f: 73 | json.dump(args.__dict__, f, indent=2, sort_keys=True) 74 | 75 | env = create_env(args) 76 | vcd_edge = VCConnection(args, env=env) 77 | if args.gen_data: 78 | vcd_edge.generate_dataset() 79 | else: 80 | vcd_edge.train() 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /VCD/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_scatter 3 |
from itertools import chain 4 | from torch_geometric.nn import MetaLayer 5 | import os 6 | 7 | 8 | # ================== Encoder ================== # 9 | class NodeEncoder(torch.nn.Module): 10 | def __init__(self, input_size, hidden_size=128, output_size=128): 11 | super(NodeEncoder, self).__init__() 12 | self.input_size = input_size 13 | self.hidden_size = hidden_size 14 | self.output_size = output_size 15 | self.model = torch.nn.Sequential( 16 | torch.nn.Linear(self.input_size, self.hidden_size), 17 | torch.nn.ReLU(inplace=True), 18 | # torch.nn.LayerNorm(self.hidden_size), 19 | torch.nn.Linear(self.hidden_size, self.hidden_size), 20 | torch.nn.ReLU(inplace=True), 21 | # torch.nn.LayerNorm(self.hidden_size), 22 | torch.nn.Linear(self.hidden_size, self.output_size)) 23 | 24 | def forward(self, node_state): 25 | out = self.model(node_state) 26 | return out 27 | 28 | 29 | class EdgeEncoder(torch.nn.Module): 30 | def __init__(self, input_size, hidden_size=128, output_size=128): 31 | super(EdgeEncoder, self).__init__() 32 | self.input_size = input_size 33 | self.hidden_size = hidden_size 34 | self.output_size = output_size 35 | self.model = torch.nn.Sequential( 36 | torch.nn.Linear(self.input_size, self.hidden_size), 37 | torch.nn.ReLU(inplace=True), 38 | # torch.nn.LayerNorm(self.hidden_size), 39 | torch.nn.Linear(self.hidden_size, self.hidden_size), 40 | torch.nn.ReLU(inplace=True), 41 | # torch.nn.LayerNorm(self.hidden_size), 42 | torch.nn.Linear(self.hidden_size, self.output_size)) 43 | 44 | def forward(self, edge_properties): 45 | out = self.model(edge_properties) 46 | return out 47 | 48 | 49 | class Encoder(torch.nn.Module): 50 | def __init__(self, node_input_size, edge_input_size, hidden_size=128, output_size=128): 51 | super(Encoder, self).__init__() 52 | self.node_input_size = node_input_size 53 | self.edge_input_size = edge_input_size 54 | self.hidden_size = hidden_size 55 | self.output_size = output_size 56 | self.node_encoder = NodeEncoder(self.node_input_size, self.hidden_size, self.output_size) 57 | self.edge_encoder = EdgeEncoder(self.edge_input_size, self.hidden_size, self.output_size) 58 | 59 | def forward(self, node_states, edge_properties): 60 | node_embedding = self.node_encoder(node_states) 61 | edge_embedding = self.edge_encoder(edge_properties) 62 | return node_embedding, edge_embedding 63 | 64 | 65 | # ================== Processor ================== # 66 | class EdgeModel(torch.nn.Module): 67 | def __init__(self, input_size, hidden_size=128, output_size=128): 68 | super(EdgeModel, self).__init__() 69 | self.input_size = input_size 70 | self.hidden_size = hidden_size 71 | self.output_size = output_size 72 | self.model = torch.nn.Sequential( 73 | torch.nn.Linear(self.input_size, self.hidden_size), 74 | torch.nn.ReLU(inplace=True), 75 | # torch.nn.LayerNorm(self.hidden_size), 76 | torch.nn.Linear(self.hidden_size, self.hidden_size), 77 | torch.nn.ReLU(inplace=True), 78 | # torch.nn.LayerNorm(self.hidden_size), 79 | torch.nn.Linear(self.hidden_size, self.output_size)) 80 | 81 | def forward(self, src, dest, edge_attr, u, batch): 82 | # source, target: [E, F_x], where E is the number of edges. 83 | # edge_attr: [E, F_e] 84 | # u: [B, F_u], where B is the number of graphs. 85 | # batch: [E] with max entry B - 1. 
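        # The edge update below consumes the concatenation [src, dest, edge_attr, u[batch]]:
        # two node embeddings, one edge embedding, and the per-graph global vector, which
        # matches the edge-model input size of 3 * embed_dim + global_size set up in GNN.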
86 | # u_expanded = u.expand([src.size()[0], -1]) 87 | # model_input = torch.cat([src, dest, edge_attr, u_expanded], 1) 88 | # out = self.model(model_input) 89 | model_input = torch.cat([src, dest, edge_attr, u[batch]], 1) 90 | out = self.model(model_input) 91 | return out 92 | 93 | 94 | class NodeModel(torch.nn.Module): 95 | def __init__(self, input_size, hidden_size=128, output_size=128): 96 | super(NodeModel, self).__init__() 97 | self.input_size = input_size 98 | self.hidden_size = hidden_size 99 | self.output_size = output_size 100 | self.model = torch.nn.Sequential( 101 | torch.nn.Linear(self.input_size, self.hidden_size), 102 | torch.nn.ReLU(inplace=True), 103 | # torch.nn.LayerNorm(self.hidden_size), 104 | torch.nn.Linear(self.hidden_size, self.hidden_size), 105 | torch.nn.ReLU(inplace=True), 106 | # torch.nn.LayerNorm(self.hidden_size), 107 | torch.nn.Linear(self.hidden_size, self.output_size)) 108 | 109 | def forward(self, x, edge_index, edge_attr, u, batch): 110 | # x: [N, F_x], where N is the number of nodes. 111 | # edge_index: [2, E] with max entry N - 1. 112 | # edge_attr: [E, F_e] 113 | # u: [B, F_u] 114 | # batch: [N] with max entry B - 1. 115 | _, edge_dst = edge_index 116 | edge_attr_aggregated = torch_scatter.scatter_add(edge_attr, edge_dst, dim=0, dim_size=x.size(0)) 117 | model_input = torch.cat([x, edge_attr_aggregated, u[batch]], dim=1) 118 | out = self.model(model_input) 119 | return out 120 | 121 | 122 | class GlobalModel(torch.nn.Module): 123 | def __init__(self, input_size, hidden_size=128, output_size=128): 124 | super(GlobalModel, self).__init__() 125 | self.input_size = input_size 126 | self.hidden_size = hidden_size 127 | self.output_size = output_size 128 | self.model = torch.nn.Sequential( 129 | torch.nn.Linear(self.input_size, self.hidden_size), 130 | torch.nn.ReLU(inplace=True), 131 | # torch.nn.LayerNorm(self.hidden_size), 132 | torch.nn.Linear(self.hidden_size, self.hidden_size), 133 | torch.nn.ReLU(inplace=True), 134 | # torch.nn.LayerNorm(self.hidden_size), 135 | torch.nn.Linear(self.hidden_size, self.output_size)) 136 | 137 | def forward(self, x, edge_index, edge_attr, u, batch): 138 | # x: [N, F_x], where N is the number of nodes. 139 | # edge_index: [2, E] with max entry N - 1. 140 | # edge_attr: [E, F_e] 141 | # u: [B, F_u] 142 | # batch: [N] with max entry B - 1. 
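        # The global update below pools node and edge features per graph with scatter_mean;
        # batch[edge_index[0]] maps each edge to the graph id of its source node.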
143 | node_attr_mean = torch_scatter.scatter_mean(x, batch, dim=0) 144 | edge_attr_mean = torch_scatter.scatter_mean(edge_attr, batch[edge_index[0]], dim=0) 145 | model_input = torch.cat([u, node_attr_mean, edge_attr_mean], dim=1) 146 | out = self.model(model_input) 147 | assert out.shape == u.shape 148 | return out 149 | 150 | 151 | class RewardModel(torch.nn.Module): 152 | def __init__(self, node_size, global_size, hidden_size=128): 153 | super(RewardModel, self).__init__() 154 | self.node_size = node_size 155 | self.global_size = global_size 156 | self.hidden_size = hidden_size 157 | self.model = torch.nn.Sequential( 158 | torch.nn.Linear(self.global_size, self.hidden_size), 159 | torch.nn.ReLU(inplace=True), 160 | torch.nn.Linear(self.hidden_size, self.hidden_size), 161 | torch.nn.ReLU(inplace=True), 162 | torch.nn.Linear(self.hidden_size, 1)) 163 | 164 | def forward(self, node_feat, global_feat, batch): 165 | out = self.model(global_feat) 166 | return out 167 | 168 | 169 | class GNBlock(torch.nn.Module): 170 | def __init__(self, input_size, hidden_size=128, output_size=128, use_global=True, global_size=128): 171 | super(GNBlock, self).__init__() 172 | self.input_size = input_size 173 | self.hidden_size = hidden_size 174 | self.output_size = output_size 175 | if use_global: 176 | self.model = MetaLayer(EdgeModel(self.input_size[0], self.hidden_size, self.output_size), 177 | NodeModel(self.input_size[1], self.hidden_size, self.output_size), 178 | GlobalModel(self.input_size[2], self.hidden_size, global_size)) 179 | else: 180 | self.model = MetaLayer(EdgeModel(self.input_size[0], self.hidden_size, self.output_size), 181 | NodeModel(self.input_size[1], self.hidden_size, self.output_size), 182 | None) 183 | 184 | def forward(self, x, edge_index, edge_attr, u, batch): 185 | # x: [N, F_x], where N is the number of nodes. 186 | # edge_index: [2, E] with max entry N - 1. 187 | # edge_attr: [E, F_e] 188 | # u: [B, F_u] 189 | # batch: [N] with max entry B - 1. 190 | x, edge_attr, u = self.model(x, edge_index, edge_attr, u, batch) 191 | return x, edge_attr, u 192 | 193 | 194 | class Processor(torch.nn.Module): 195 | def __init__(self, input_size, hidden_size=128, output_size=128, use_global=True, global_size=128, layers=10): 196 | """ 197 | :param input_size: A list of size to edge model, node model and global model 198 | """ 199 | super(Processor, self).__init__() 200 | self.input_size = input_size 201 | self.hidden_size = hidden_size 202 | self.output_size = output_size 203 | self.use_global = use_global 204 | self.global_size = global_size 205 | self.gns = torch.nn.ModuleList([ 206 | GNBlock(self.input_size, self.hidden_size, self.output_size, self.use_global, global_size=global_size) 207 | for _ in range(layers)]) 208 | 209 | def forward(self, x, edge_index, edge_attr, u, batch): 210 | # def forward(self, data): 211 | # x, edge_index, edge_attr, u, batch = data.node_embedding, data.neighbors, data.edge_embedding, data.global_feat, data.batch 212 | # x: [N, F_x], where N is the number of nodes. 213 | # edge_index: [2, E] with max entry N - 1. 214 | # edge_attr: [E, F_e] 215 | # u: [B, F_u] 216 | # batch: [N] with max entry B - 1. 
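        # The GNBlocks below are applied residually (x_new = x_new + x_res, etc.), so node,
        # edge, and global features are refined over `layers` rounds of message passing.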
217 | if len(u.shape) == 1: 218 | u = u[None] 219 | if edge_index.shape[1] < 10: 220 | print("--------debug info---------") 221 | print("small number of edges") 222 | print("x.shape: ", x.shape) 223 | print("edge_index.shape: ", edge_index.shape) 224 | print("edge_attr.shape: ", edge_attr.shape, flush=True) 225 | print("--------------------------") 226 | 227 | x_new, edge_attr_new, u_new = x, edge_attr, u 228 | for gn in self.gns: 229 | x_res, edge_attr_res, u_res = gn(x_new, edge_index, edge_attr_new, u_new, batch) 230 | x_new = x_new + x_res 231 | edge_attr_new = edge_attr_new + edge_attr_res 232 | u_new = u_new + u_res 233 | return x_new, edge_attr_new, u_new 234 | 235 | 236 | # ================== Decoder ================== # 237 | class Decoder(torch.nn.Module): 238 | def __init__(self, input_size=128, hidden_size=128, output_size=3): 239 | super(Decoder, self).__init__() 240 | self.input_size = input_size 241 | self.hidden_size = hidden_size 242 | self.output_size = output_size 243 | self.model = torch.nn.Sequential( 244 | torch.nn.Linear(self.input_size, self.hidden_size), 245 | torch.nn.ReLU(inplace=True), 246 | # torch.nn.LayerNorm(self.hidden_size), 247 | torch.nn.Linear(self.hidden_size, self.hidden_size), 248 | torch.nn.ReLU(inplace=True), 249 | # torch.nn.LayerNorm(self.hidden_size), 250 | torch.nn.Linear(self.hidden_size, self.output_size)) 251 | 252 | def forward(self, node_feat, res=None): 253 | out = self.model(node_feat) 254 | if res is not None: 255 | out = out + res 256 | return out 257 | 258 | 259 | class GNN(torch.nn.Module): 260 | def __init__(self, args, decoder_output_dim, name, use_reward=False): 261 | super(GNN, self).__init__() 262 | self.name = name 263 | self.args = args 264 | self.use_global = True if self.args.global_size > 1 else False 265 | embed_dim = 128 266 | self.dyn_models = torch.nn.ModuleDict({'encoder': Encoder(args.state_dim, args.relation_dim, output_size=embed_dim), 267 | 'processor': Processor( 268 | [3 * embed_dim + args.global_size, 269 | 2 * embed_dim + args.global_size, 270 | 2 * embed_dim + args.global_size], 271 | use_global=self.use_global, layers=args.proc_layer, global_size=args.global_size), 272 | 'decoder': Decoder(output_size=decoder_output_dim)}) 273 | self.use_reward = use_reward 274 | print(use_reward) 275 | if use_reward: 276 | self.dyn_models['reward_model'] = RewardModel(128, 128, 128) 277 | 278 | def forward(self, data): 279 | """ data should be a dictionary containing the following dict 280 | edge_index: Edge index 2 x E 281 | x: Node feature 282 | edge_attr: Edge feature 283 | gt_accel: Acceleration label for each node 284 | x_batch: Batch index 285 | """ 286 | out = {} 287 | node_embedding, edge_embedding = self.dyn_models['encoder'](data['x'], data['edge_attr']) 288 | n_nxt, e_nxt, lat_nxt = self.dyn_models['processor'](node_embedding, 289 | data['edge_index'], 290 | edge_embedding, 291 | u=data['u'], 292 | batch=data['x_batch']) 293 | # Return acceleration for each node and the final global feature (for potential multi-step training) 294 | if self.name == 'EdgeGNN': 295 | out['mesh_edge'] = self.dyn_models['decoder'](e_nxt) 296 | else: 297 | out['accel'] = self.dyn_models['decoder'](n_nxt) 298 | if self.use_reward: 299 | out['reward_nxt'] = self.dyn_models['reward_model'](n_nxt, lat_nxt, batch=data['x_batch']) 300 | 301 | out['n_nxt'] = n_nxt[data['partial_pc_mapped_idx']] if 'partial_pc_mapped_idx' in data else n_nxt 302 | out['lat_nxt'] = lat_nxt 303 | return out 304 | 305 | def load_model(self, model_path, 
load_names='all', load_optim=False, optim=None): 306 | """ 307 | :param load_names: which part of ['encoder', 'processor', 'decoder'] to load 308 | :param load_optim: Whether to load optimizer states 309 | :return: 310 | """ 311 | ckpt = torch.load(model_path) 312 | optim_path = model_path.replace('dyn', 'optim') 313 | if load_names == 'all': 314 | for k, v in self.dyn_models.items(): 315 | self.dyn_models[k].load_state_dict(ckpt[k]) 316 | else: 317 | for model_name in load_names: 318 | self.dyn_models[model_name].load_state_dict(ckpt[model_name]) 319 | print('Loaded saved ckp from {} for {} models'.format(model_path, load_names)) 320 | 321 | if load_optim: 322 | assert os.path.exists(optim_path) 323 | optim.load_state_dict(torch.load(optim_path)) 324 | print('Load optimizer states from ', optim_path) 325 | 326 | def save_model(self, root_path, m_name, suffix, optim): 327 | """ 328 | Regular saving: {input_type}_dyn_{epoch}.pth 329 | Best model: {input_type}_dyn_best.pth 330 | Optim: {input_type}_optim_{epoch}.pth 331 | """ 332 | save_name = 'edge' if self.name == 'EdgeGNN' else 'dyn' 333 | model_path = os.path.join(root_path, '{}_{}_{}.pth'.format(m_name, save_name, suffix)) 334 | torch.save({k: v.state_dict() for k, v in self.dyn_models.items()}, model_path) 335 | optim_path = os.path.join(root_path, '{}_{}_{}.pth'.format(m_name, 'optim', suffix)) 336 | torch.save(optim.state_dict(), optim_path) 337 | 338 | def set_mode(self, mode='train'): 339 | assert mode in ['train', 'eval'] 340 | for model in self.dyn_models.values(): 341 | if mode == 'eval': 342 | model.eval() 343 | else: 344 | model.train() 345 | 346 | def param(self): 347 | model_parameters = list(chain(*[list(m.parameters()) for m in self.dyn_models.values()])) 348 | return model_parameters 349 | 350 | def to(self, device): 351 | for model in self.dyn_models.values(): 352 | model.to(device) 353 | 354 | def freeze(self, tgts=None): 355 | if tgts is None: 356 | for m in self.dyn_models.values(): 357 | for para in m.parameters(): 358 | para.requires_grad = False 359 | else: 360 | for tgt in tgts: 361 | m = self.dyn_models[tgt] 362 | for para in m.parameters(): 363 | para.requires_grad = False 364 | 365 | def unfreeze(self, tgts=None): 366 | if tgts is None: 367 | for m in self.dyn_models.values(): 368 | for para in m.parameters(): 369 | para.requires_grad = True 370 | else: 371 | for tgt in tgts: 372 | m = self.dyn_models[tgt] 373 | for para in m.parameters(): 374 | para.requires_grad = True 375 | -------------------------------------------------------------------------------- /VCD/rs_planner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from multiprocessing import pool 3 | import copy 4 | from VCD.utils.camera_utils import project_to_image, get_target_pos 5 | 6 | 7 | class RandomShootingUVPickandPlacePlanner(): 8 | 9 | def __init__(self, num_pick, delta_y, pull_step, wait_step, 10 | dynamics, reward_model, num_worker=10, 11 | move_distance_range=[0.05, 0.2], gpu_num=1, 12 | image_size=None, normalize_info=None, delta_y_range=None, 13 | matrix_world_to_camera=np.identity(4), task='flatten', 14 | use_pred_rwd=False): 15 | """ 16 | Random Shooting planner. 
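        Samples num_pick pick-and-place candidates (plus a no-op action for the flatten task),
        rolls each one out with the learned dynamics model, and returns the action sequence
        with the highest predicted return.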
17 | """ 18 | 19 | self.normalize_info = normalize_info # Used for robot experiments to denormalize before action clipping 20 | self.num_pick = num_pick 21 | self.delta_y = delta_y # for real world experiment, delta_y (pick_height) is fixed 22 | self.delta_y_range = delta_y_range # for simulation, delta_y is randomlized 23 | self.move_distance_low, self.move_distance_high = move_distance_range[0], move_distance_range[1] 24 | self.reward_model = reward_model 25 | self.dynamics = dynamics 26 | self.pull_step, self.wait_step = pull_step, wait_step 27 | self.gpu_num = gpu_num 28 | self.use_pred_rwd = use_pred_rwd 29 | 30 | if num_worker > 0: 31 | self.pool = pool.Pool(processes=num_worker) 32 | self.num_worker = num_worker 33 | self.matrix_world_to_camera = matrix_world_to_camera 34 | self.image_size = image_size 35 | self.task = task 36 | 37 | def project_3d(self, pos): 38 | return project_to_image(self.matrix_world_to_camera, pos, self.image_size[0], self.image_size[1]) 39 | 40 | def get_action(self, init_data, robot_exp=False, cloth_mask=None, check_mask=None, m_name='vsbl'): 41 | """ 42 | check_mask: Used to filter out place points that are on the cloth. 43 | init_data should be a list that include: 44 | ['pointcloud', 'velocities', 'picker_position', 'action', 'picked_points', 'scene_params', 'observable_particle_indices] 45 | note: require position, velocity to be already downsampled 46 | 47 | """ 48 | args = self.dynamics.args 49 | data = init_data.copy() 50 | data['picked_points'] = [-1, -1] 51 | 52 | pull_step, wait_step = self.pull_step, self.wait_step 53 | 54 | # add a no-op action 55 | pick_try_num = self.num_pick + 1 if self.task == 'flatten' else self.num_pick 56 | actions = np.zeros((pick_try_num, pull_step + wait_step, 8)) 57 | pointcloud = copy.deepcopy(data['pointcloud']) 58 | 59 | picker_pos = data['picker_position'][0][:3] if data['picker_position'] is not None else None 60 | bb_margin = 30 61 | 62 | # paralleled version of generating action sequences 63 | if robot_exp: 64 | num_samples = 10 * self.num_pick # Reject 65 | 66 | def filter_out_of_bound(pos, x_low=-0.45, x_high=0.06, z_low=0.3, z_high=0.65): 67 | """Rreturn in bound idxes""" 68 | cond1 = pos[:, 0] >= x_low 69 | cond2 = pos[:, 0] <= x_high 70 | cond3 = pos[:, 2] >= z_low 71 | cond4 = pos[:, 2] <= z_high 72 | cond = cond1 * cond2 * cond3 * cond4 73 | return np.where(cond)[0] 74 | 75 | idxes = np.random.randint(0, len(pointcloud), num_samples) 76 | 77 | # In real world, instead of using uv, simply pick a random idx 78 | pickup_pos = pointcloud[idxes] 79 | 80 | # Remove out of bound pick places 81 | x_mean, z_mean = self.normalize_info['xz_mean'] # First denormalize 82 | pickup_pos[:, 0] += x_mean 83 | pickup_pos[:, 2] += z_mean 84 | in_bound_idx = filter_out_of_bound(pickup_pos) 85 | pickup_pos[:, 0] -= x_mean 86 | pickup_pos[:, 2] -= z_mean 87 | pickup_pos = pickup_pos[in_bound_idx] 88 | idxes = idxes[in_bound_idx] 89 | num_samples = len(pickup_pos) 90 | 91 | move_theta = np.random.rand(num_samples).reshape(num_samples, 1) * 2 * np.pi 92 | move_distance = np.random.uniform(self.move_distance_low, self.move_distance_high, num_samples) 93 | move_direction = np.hstack( 94 | [np.cos(move_theta), np.zeros_like(move_theta), np.sin(move_theta)]) * move_distance.reshape( 95 | num_samples, 1) 96 | 97 | place_pos = (pickup_pos + move_direction).copy() 98 | 99 | # Clip place_pos with a fixed bounding box to make sure the place point is within the camera 100 | x_mean, z_mean = self.normalize_info['xz_mean'] # First 
denormalize 101 | place_pos[:, 0] += x_mean 102 | place_pos[:, 2] += z_mean 103 | in_bound_idx = filter_out_of_bound(place_pos) 104 | place_pos[:, 0] -= x_mean 105 | place_pos[:, 2] -= z_mean 106 | 107 | if check_mask is not None: 108 | out_cloth_idx = np.where(check_mask(place_pos))[0] 109 | # print('in bound number:', in_bound_idx.shape) 110 | # print(in_bound_idx[:50], out_cloth_idx[:50]) 111 | in_bound_idx = np.intersect1d(in_bound_idx, out_cloth_idx) 112 | # print(in_bound_idx) 113 | # print('how many left:', in_bound_idx.shape) 114 | 115 | select_idx = np.random.choice(in_bound_idx, self.num_pick, replace=len(in_bound_idx) < self.num_pick) 116 | pickup_pos = pickup_pos[select_idx] 117 | place_pos = place_pos[select_idx] 118 | idxes = idxes[select_idx] 119 | waypoints = np.zeros([self.num_pick, 3, 3]) 120 | 121 | waypoints[:, 0, :] = pickup_pos 122 | waypoints[:, 1, :] = pickup_pos + np.array([0, self.delta_y, 0]).reshape(1, 3) 123 | waypoints[:, 2, :] = place_pos + np.array([0, self.delta_y, 0]).reshape(1, 3) 124 | 125 | # Update move_direction after clipping 126 | move_direction = waypoints[:, 2, :] - waypoints[:, 1, :] 127 | 128 | delta_moves = list(move_direction / self.pull_step) 129 | picked_particles = list(idxes) 130 | waypoints = list(waypoints) 131 | move_vec = list(move_direction) 132 | # TODO try using delta_y_range instead of fixed y 133 | delta_move = move_direction 134 | delta_move[:, 1] += self.delta_y 135 | num_step = self.pull_step 136 | actions[:-1, :num_step, :3] = delta_move[:, None, :] / num_step # Upward 137 | actions[:-1, :num_step, 3] = 1 138 | actions[:, :, 4:] = 0 # we essentially only plan over 1 picker action 139 | 140 | # Add no-op 141 | if not self.random: 142 | waypoints.append([np.nan, np.nan, np.nan]) 143 | delta_moves.append([0., 0., 0.]) 144 | picked_particles.append(-1) 145 | move_vec.append([0., 0., 0.]) 146 | else: # simulation planning 147 | us, vs = self.project_3d(pointcloud) 148 | params = [ 149 | (us, vs, self.image_size, pointcloud, pull_step, 150 | self.delta_y_range, bb_margin, self.matrix_world_to_camera, 151 | self.move_distance_low, self.move_distance_high, cloth_mask, self.task) 152 | for i in range(self.num_pick) 153 | ] 154 | results = self.pool.map(parallel_generate_actions, params) 155 | delta_moves, start_poses, after_poses = [x[0] for x in results], [x[1] for x in results], [x[2] for x in results] 156 | if self.task == 'flatten': # add a no-op action 157 | start_poses.append(data['picker_position'][0, :]) 158 | after_poses.append(data['picker_position'][0, :]) 159 | 160 | actions[:-1, :pull_step, :3] = np.vstack(delta_moves)[:, None, :] 161 | actions[:-1, :pull_step, 3] = 1 162 | actions[:, :, 4:] = 0 163 | move_vec = None 164 | 165 | # roll out the dynamics model in parallel with the sampled action sequences 166 | data_cpy = copy.deepcopy(data) 167 | if self.num_worker > 0: 168 | job_each_gpu = pick_try_num // self.gpu_num 169 | params = [] 170 | for i in range(pick_try_num): 171 | if robot_exp: 172 | data_cpy['picked_points'] = [picked_particles[i], -1] 173 | else: 174 | data_cpy['picked_points'] = [-1, -1] 175 | data_cpy['picker_position'][0, :] = start_poses[i] 176 | 177 | gpu_id = i // job_each_gpu if i < self.gpu_num * job_each_gpu else i % self.gpu_num 178 | params.append( 179 | dict( 180 | model_input_data=copy.deepcopy(data_cpy), actions=actions[i], m_name=m_name, 181 | reward_model=self.reward_model, cuda_idx=gpu_id, robot_exp=robot_exp, 182 | ) 183 | ) 184 | results = self.pool.map(self.dynamics.rollout,
chunksize=max(1, pick_try_num // self.num_worker)) 185 | returns = [x['final_ret'] for x in results] 186 | else: # sequentially roll out each sampled action trajectory 187 | returns, results = [], [] 188 | for i in range(pick_try_num): 189 | res = self.dynamics.rollout( 190 | dict( 191 | model_input_data=copy.deepcopy(data_cpy), actions=actions[i], m_name=m_name, 192 | reward_model=self.reward_model, cuda_idx=0, robot_exp=robot_exp, 193 | ) 194 | ) 195 | results.append(res), returns.append(res['final_ret']) 196 | 197 | ret_info = {} 198 | highest_return_idx = np.argmax(returns) 199 | 200 | ret_info['highest_return_idx'] = highest_return_idx 201 | action_seq = actions[highest_return_idx] 202 | if robot_exp: 203 | ret_info['waypoints'] = np.array(waypoints[highest_return_idx]).copy() 204 | ret_info['all_candidate'] = np.array(waypoints[:-1]) 205 | ret_info['all_candidate_rewards'] = np.array(returns[:-1]) 206 | else: 207 | ret_info['start_pos'] = start_poses[highest_return_idx] 208 | ret_info['after_pos'] = after_poses[highest_return_idx] 209 | 210 | model_predict_particle_positions = results[highest_return_idx]['model_positions'] 211 | model_predict_shape_positions = results[highest_return_idx]['shape_positions'] 212 | predicted_edges = results[highest_return_idx]['mesh_edges'] 213 | if move_vec is not None: 214 | ret_info['picked_pos'] = pointcloud[picked_particles[highest_return_idx]] 215 | ret_info['move_vec'] = move_vec[highest_return_idx] 216 | 217 | return action_seq, model_predict_particle_positions, model_predict_shape_positions, ret_info, predicted_edges 218 | 219 | 220 | def pos_in_image(after_pos, matrix_world_to_camera, image_size): 221 | euv = project_to_image(matrix_world_to_camera, after_pos.reshape((1, 3)), image_size[0], image_size[1]) 222 | u, v = euv[0][0], euv[1][0] 223 | if u >= 0 and u < image_size[1] and v >= 0 and v < image_size[0]: 224 | return True 225 | else: 226 | return False 227 | 228 | 229 | def parallel_generate_actions(args): 230 | us, vs, image_size, pointcloud, pull_step, delta_y_range, bb_margin, matrix_world_to_camera, move_distance_low, move_distance_high, cloth_mask, task = args 231 | 232 | # first stage: choose a pick location 233 | lb_u, lb_v, ub_u, ub_v = int(np.min(us)), int(np.min(vs)), int(np.max(us)), int(np.max(vs)) 234 | u = np.random.randint(max(lb_u - bb_margin, 0), min(ub_u + bb_margin, image_size[1])) 235 | v = np.random.randint(max(lb_v - bb_margin, 0), min(ub_v + bb_margin, image_size[0])) 236 | target_pos = get_target_pos(pointcloud, u, v, image_size, matrix_world_to_camera, cloth_mask) 237 | 238 | # second stage: choose a random (x, y, z) direction and move along it to determine the place point 239 | while True: 240 | move_direction = np.random.rand(3) - 0.5 241 | if task == 'flatten': 242 | move_direction[1] = np.random.uniform(delta_y_range[0], delta_y_range[1]) 243 | else: # for fold, just generate horizontal move 244 | move_direction[1] = 0 245 | 246 | move_direction = move_direction / np.linalg.norm(move_direction) 247 | move_distance = np.random.uniform(move_distance_low, move_distance_high) 248 | delta_move = move_distance / pull_step * move_direction 249 | 250 | after_pos = target_pos + move_distance * move_direction 251 | if pos_in_image(after_pos, matrix_world_to_camera, image_size): 252 | break 253 | 254 | return delta_move, target_pos, after_pos 255 | -------------------------------------------------------------------------------- /VCD/utils/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/VCD/328f3e0ada7c3d39b974e875c5847badfca5037d/VCD/utils/__init__.py -------------------------------------------------------------------------------- /VCD/utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.optimize as opt 3 | import scipy 4 | from scipy.spatial.distance import cdist 5 | 6 | 7 | def build_depth_from_pointcloud(pointcloud, matrix_world_to_camera, imsize): 8 | height, width = imsize 9 | pointcloud = np.concatenate([pointcloud, np.ones((len(pointcloud), 1))], axis=1) # n x 4 10 | camera_coordinate = matrix_world_to_camera @ pointcloud.T # 3 x n 11 | camera_coordinate = camera_coordinate.T # n x 3 12 | K = intrinsic_from_fov(height, width, 45) # the fov is 45 degrees 13 | 14 | u0 = K[0, 2] 15 | v0 = K[1, 2] 16 | fx = K[0, 0] 17 | fy = K[1, 1] 18 | 19 | x, y, depth = camera_coordinate[:, 0], camera_coordinate[:, 1], camera_coordinate[:, 2] 20 | u = np.rint(x * fx / depth + u0).astype(int) # round first, then cast; casting first would truncate 21 | v = np.rint(y * fy / depth + v0).astype(int) 22 | 23 | us = u.flatten() 24 | vs = v.flatten() 25 | depth = depth.flatten() 26 | 27 | depth_map = dict() 28 | for u, v, d in zip(us, vs, depth): 29 | if depth_map.get((u, v)) is None: 30 | depth_map[(u, v)] = [] 31 | depth_map[(u, v)].append(d) 32 | else: 33 | depth_map[(u, v)].append(d) 34 | 35 | depth_2d = np.zeros((height, width)) 36 | for u in range(width): 37 | for v in range(height): 38 | if (u, v) in depth_map.keys(): 39 | depth_2d[v][u] = np.min(depth_map[(u, v)]) 40 | 41 | return depth_2d 42 | 43 | 44 | def pixel_coord_np(width, height): 45 | """ 46 | Pixel in homogeneous coordinates 47 | Returns: 48 | Pixel coordinate: [3, width * height] 49 | """ 50 | x = np.linspace(0, width - 1, width).astype(int) 51 | y = np.linspace(0, height - 1, height).astype(int) 52 | [x, y] = np.meshgrid(x, y) 53 | return np.vstack((x.flatten(), y.flatten(), np.ones_like(x.flatten()))) 54 | 55 | 56 | def intrinsic_from_fov(height, width, fov=90): 57 | """ 58 | Basic Pinhole Camera Model 59 | intrinsic params from fov and sensor width and height in pixels 60 | Returns: 61 | K: [4, 4] 62 | """ 63 | px, py = (width / 2, height / 2) 64 | hfov = fov / 360. * 2. * np.pi 65 | fx = width / (2. * np.tan(hfov / 2.)) 66 | 67 | vfov = 2. * np.arctan(np.tan(hfov / 2) * height / width) 68 | fy = height / (2.
* np.tan(vfov / 2.)) 69 | 70 | return np.array([[fx, 0, px, 0.], 71 | [0, fy, py, 0.], 72 | [0, 0, 1., 0.], 73 | [0., 0., 0., 1.]]) 74 | 75 | 76 | def get_rotation_matrix(angle, axis): 77 | axis = axis / np.linalg.norm(axis) 78 | s = np.sin(angle) 79 | c = np.cos(angle) 80 | 81 | m = np.zeros((4, 4)) 82 | 83 | m[0][0] = axis[0] * axis[0] + (1.0 - axis[0] * axis[0]) * c 84 | m[0][1] = axis[0] * axis[1] * (1.0 - c) - axis[2] * s 85 | m[0][2] = axis[0] * axis[2] * (1.0 - c) + axis[1] * s 86 | m[0][3] = 0.0 87 | 88 | m[1][0] = axis[0] * axis[1] * (1.0 - c) + axis[2] * s 89 | m[1][1] = axis[1] * axis[1] + (1.0 - axis[1] * axis[1]) * c 90 | m[1][2] = axis[1] * axis[2] * (1.0 - c) - axis[0] * s 91 | m[1][3] = 0.0 92 | 93 | m[2][0] = axis[0] * axis[2] * (1.0 - c) - axis[1] * s 94 | m[2][1] = axis[1] * axis[2] * (1.0 - c) + axis[0] * s 95 | m[2][2] = axis[2] * axis[2] + (1.0 - axis[2] * axis[2]) * c 96 | m[2][3] = 0.0 97 | 98 | m[3][0] = 0.0 99 | m[3][1] = 0.0 100 | m[3][2] = 0.0 101 | m[3][3] = 1.0 102 | 103 | return m 104 | 105 | 106 | def get_world_coords(rgb, depth, env, particle_pos=None): 107 | height, width, _ = rgb.shape 108 | K = intrinsic_from_fov(height, width, 45) # the fov is 45 degrees 109 | 110 | # Apply back-projection: K_inv @ pixels * depth 111 | u0 = K[0, 2] 112 | v0 = K[1, 2] 113 | fx = K[0, 0] 114 | fy = K[1, 1] 115 | 116 | x = np.linspace(0, width - 1, width).astype(float) 117 | y = np.linspace(0, height - 1, height).astype(float) 118 | u, v = np.meshgrid(x, y) 119 | one = np.ones((height, width, 1)) 120 | x = (u - u0) * depth / fx 121 | y = (v - v0) * depth / fy 122 | z = depth 123 | cam_coords = np.dstack([x, y, z, one]) 124 | 125 | matrix_world_to_camera = get_matrix_world_to_camera( 126 | env.camera_params[env.camera_name]['pos'], env.camera_params[env.camera_name]['angle']) 127 | 128 | # convert the camera coordinate back to the world coordinate using the rotation and translation matrix 129 | cam_coords = cam_coords.reshape((-1, 4)).transpose() # 4 x (height x width) 130 | world_coords = np.linalg.inv(matrix_world_to_camera) @ cam_coords # 4 x (height x width) 131 | world_coords = world_coords.transpose().reshape((height, width, 4)) 132 | 133 | return world_coords 134 | 135 | 136 | def get_observable_particle_index(world_coords, particle_pos, rgb, depth): 137 | height, width, _ = rgb.shape 138 | # perform the matching of pixel particle to real particle 139 | particle_pos = particle_pos[:, :3] 140 | 141 | estimated_world_coords = np.array(world_coords)[np.where(depth > 0)][:, :3] 142 | 143 | distance = scipy.spatial.distance.cdist(estimated_world_coords, particle_pos) 144 | # Each point in the point cloud will cover at most two particles.
Particles not covered will be deemed occluded 145 | estimated_particle_idx = np.argpartition(distance, 2)[:, :2].flatten() 146 | estimated_particle_idx = np.unique(estimated_particle_idx) 147 | 148 | return np.array(estimated_particle_idx, dtype=np.int32) 149 | 150 | 151 | def get_observable_particle_index_old(world_coords, particle_pos, rgb, depth): 152 | height, width, _ = rgb.shape 153 | # perform the matching of pixel particle to real particle 154 | particle_pos = particle_pos[:, :3] 155 | 156 | estimated_world_coords = np.array(world_coords)[np.where(depth > 0)][:, :3] 157 | 158 | distance = scipy.spatial.distance.cdist(estimated_world_coords, particle_pos) 159 | estimated_particle_idx = np.argmin(distance, axis=1) 160 | estimated_particle_idx = np.unique(estimated_particle_idx) 161 | 162 | return np.array(estimated_particle_idx, dtype=np.int32) 163 | 164 | 165 | def get_observable_particle_index_3(pointcloud, mesh, threshold=0.0216): 166 | ### bi-partite graph matching 167 | distance = scipy.spatial.distance.cdist(pointcloud, mesh) 168 | distance[distance > threshold] = 1e10 169 | row_idx, column_idx = opt.linear_sum_assignment(distance) 170 | 171 | distance_mapped = distance[np.arange(len(pointcloud)), column_idx] 172 | bad_mapping = distance_mapped > threshold 173 | if np.sum(bad_mapping) > 0: 174 | column_idx[bad_mapping] = np.argmin(distance[bad_mapping], axis=1) 175 | 176 | return pointcloud, column_idx 177 | 178 | 179 | def get_mapping_from_pointcloud_to_partile_nearest_neighbor(pointcloud, particle): 180 | distance = scipy.spatial.distance.cdist(pointcloud, particle) 181 | nearest_idx = np.argmin(distance, axis=1) 182 | return nearest_idx 183 | 184 | 185 | def get_observable_particle_index_4(pointcloud, mesh, threshold=0.0216): 186 | # perform the matching of pixel particle to real particle 187 | estimated_world_coords = pointcloud 188 | 189 | distance = scipy.spatial.distance.cdist(estimated_world_coords, mesh) 190 | estimated_particle_idx = np.argmin(distance, axis=1) 191 | 192 | return pointcloud, np.array(estimated_particle_idx, dtype=np.int32) 193 | 194 | 195 | def get_observable_particle_pos(world_coords, particle_pos): 196 | # perform the matching of pixel particle to real particle 197 | particle_pos = particle_pos[:, :3] 198 | distance = scipy.spatial.distance.cdist(world_coords, particle_pos) 199 | estimated_particle_idx = np.argmin(distance, axis=1) 200 | observable_particle_pos = particle_pos[estimated_particle_idx] 201 | 202 | return observable_particle_pos 203 | 204 | 205 | def get_matrix_world_to_camera(cam_pos=[-0.0, 0.82, 0.82], cam_angle=[0, -45 / 180. 
* np.pi, 0.]): 206 | cam_x, cam_y, cam_z = cam_pos[0], cam_pos[1], \ 207 | cam_pos[2] 208 | cam_x_angle, cam_y_angle, cam_z_angle = cam_angle[0], cam_angle[1], \ 209 | cam_angle[2] 210 | 211 | # get rotation matrix: from world to camera 212 | matrix1 = get_rotation_matrix(- cam_x_angle, [0, 1, 0]) 213 | matrix2 = get_rotation_matrix(- cam_y_angle - np.pi, [1, 0, 0]) 214 | rotation_matrix = matrix2 @ matrix1 215 | 216 | # get translation matrix: from world to camera 217 | translation_matrix = np.zeros((4, 4)) 218 | translation_matrix[0][0] = 1 219 | translation_matrix[1][1] = 1 220 | translation_matrix[2][2] = 1 221 | translation_matrix[3][3] = 1 222 | translation_matrix[0][3] = - cam_x 223 | translation_matrix[1][3] = - cam_y 224 | translation_matrix[2][3] = - cam_z 225 | 226 | return rotation_matrix @ translation_matrix 227 | 228 | 229 | def project_to_image(matrix_world_to_camera, world_coordinate, height=360, width=360): 230 | world_coordinate = np.concatenate([world_coordinate, np.ones((len(world_coordinate), 1))], axis=1) # n x 4 231 | camera_coordinate = matrix_world_to_camera @ world_coordinate.T # 3 x n 232 | camera_coordinate = camera_coordinate.T # n x 3 233 | K = intrinsic_from_fov(height, width, 45) # the fov is 45 degrees 234 | 235 | u0 = K[0, 2] 236 | v0 = K[1, 2] 237 | fx = K[0, 0] 238 | fy = K[1, 1] 239 | 240 | x, y, depth = camera_coordinate[:, 0], camera_coordinate[:, 1], camera_coordinate[:, 2] 241 | u = (x * fx / depth + u0).astype("int") 242 | v = (y * fy / depth + v0).astype("int") 243 | 244 | return u, v 245 | 246 | 247 | def _get_depth(matrix, vec, height): 248 | """ Get the depth such that the back-projected point has a fixed height""" 249 | return (height - matrix[1, 3]) / (vec[0] * matrix[1, 0] + vec[1] * matrix[1, 1] + matrix[1, 2]) 250 | 251 | 252 | def get_world_coor_from_image(u, v, image_size, matrix_world_to_camera, all_depth): 253 | height, width = image_size 254 | K = intrinsic_from_fov(height, width, 45) # the fov is 45 degrees 255 | 256 | matrix = np.linalg.inv(matrix_world_to_camera) 257 | 258 | u0, v0, fx, fy = K[0, 2], K[1, 2], K[0, 0], K[1, 1] 259 | 260 | depth = all_depth[v][u] 261 | if depth == 0: 262 | vec = ((u - u0) / fx, (v - v0) / fy) 263 | depth = _get_depth(matrix, vec, 0.00625) # set the height to the particle radius 264 | 265 | x = (u - u0) * depth / fx 266 | y = (v - v0) * depth / fy 267 | z = depth 268 | cam_coords = np.array([x, y, z, 1]) 269 | cam_coords = cam_coords.reshape((-1, 4)).transpose() # 4 x (height x width) 270 | 271 | world_coord = matrix @ cam_coords # 4 x (height x width) 272 | world_coord = world_coord.reshape(4) 273 | return world_coord[:3] 274 | 275 | 276 | def get_target_pos(pos, u, v, image_size, matrix_world_to_camera, depth): 277 | coor = get_world_coor_from_image(u, v, image_size, matrix_world_to_camera, depth) 278 | dists = cdist(coor[None], pos)[0] 279 | idx = np.argmin(dists) 280 | return pos[idx] + np.array([0, 0.01, 0]) 281 | -------------------------------------------------------------------------------- /VCD/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch_geometric.data import Data 4 | import torch_geometric 5 | 6 | 7 | class PrivilData(Data): 8 | """ 9 | Encapsulation of multi-graphs for multi-step training 10 | ind: 0-(hor-1), type: vsbl or full 11 | Each graph contains: 12 | edge_index_{type}_{ind}, 13 | x_{type}_{ind}, 14 | edge_attr_{type}_{ind}, 15 | gt_rwd_{type}_{ind} 16 | gt_accel_{type}_{ind} 17
| mesh_mapping_{type}_{ind} 18 | """ 19 | 20 | def __init__(self, has_part=False, has_full=False, **kwargs): 21 | super(PrivilData, self).__init__(**kwargs) 22 | self.has_part = has_part 23 | self.has_full = has_full 24 | 25 | def __inc__(self, key, value, *args, **kwargs): 26 | if 'edge_index' in key: 27 | x = key.replace('edge_index', 'x') 28 | return self[x].size(0) 29 | elif 'mesh_mapping' in key: 30 | # offset mesh-mapping indices by the node count of the corresponding graph when batching 31 | x = key.replace('mesh_mapping', 'x') 32 | return self[x].size(0) 33 | else: 34 | return super().__inc__(key, value, *args, **kwargs) 35 | 36 | 37 | class AggDict(dict): 38 | def __init__(self, is_detach=True): 39 | """ 40 | Aggregate numpy arrays or pytorch tensors 41 | :param is_detach: Whether to save numpy arrays instead of torch tensors 42 | """ 43 | super().__init__() 44 | self.is_detach = is_detach 45 | 46 | def __getitem__(self, item): 47 | return self.get(item, 0) 48 | 49 | def add_item(self, key, value): 50 | if self.is_detach and torch.is_tensor(value): 51 | value = value.detach().cpu().numpy() 52 | if not isinstance(value, torch.Tensor): 53 | if isinstance(value, np.ndarray) or isinstance(value, np.number): 54 | assert value.size == 1 55 | else: 56 | assert isinstance(value, int) or isinstance(value, float) 57 | if key not in self.keys(): 58 | self[key] = value 59 | else: 60 | self[key] += value 61 | 62 | def update_by_add(self, src_dict): 63 | for key, value in src_dict.items(): 64 | self.add_item(key, value) 65 | 66 | def get_mean(self, prefix, count=1): 67 | avg_dict = {} 68 | for k, v in self.items(): 69 | avg_dict[prefix + k] = v / count 70 | return avg_dict 71 | 72 | 73 | def updateDictByAdd(dict1, dict2): 74 | ''' 75 | update dict1 by dict2 76 | ''' 77 | for k1, v1 in dict2.items(): 78 | for k2, v2 in v1.items(): 79 | dict1[k1][k2] += v2.cpu().item() 80 | return dict1 81 | 82 | 83 | def get_index_before_padding(graph_sizes): 84 | ins_len = graph_sizes.max() 85 | pad_len = ins_len * graph_sizes.size(0) 86 | valid_len = graph_sizes.sum() 87 | accum = torch.zeros(1).cuda() 88 | out = [] 89 | for gs in graph_sizes: 90 | new_ind = torch.arange(int(gs)).cuda() + accum 91 | out.append(new_ind) 92 | accum += ins_len 93 | final_ind = torch.cat(out, dim=0) 94 | return final_ind.long() 95 | 96 | 97 | class MyDataParallel(torch_geometric.nn.DataParallel): 98 | def __init__(self, *args, **kwargs): 99 | super().__init__(*args, **kwargs) 100 | 101 | def __getattr__(self, name): 102 | if name == 'module': 103 | return self._modules['module'] 104 | else: 105 | return getattr(self.module, name) 106 | 107 | 108 | def retrieve_data(data, key): 109 | """ 110 | vsbl: [vsbl], full: [full], dual: [vsbl, full] 111 | """ 112 | if isinstance(data, dict): 113 | identifier = '_{}'.format(key) 114 | out_data = {k.replace(identifier, ''): v for k, v in data.items() if identifier in k} 115 | return out_data 116 | -------------------------------------------------------------------------------- /VCD/utils/gemo_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyflex 3 | 4 | 5 | def pixel_coord_np(width, height): 6 | """ 7 | Pixel in homogeneous coordinates 8 | Returns: 9 | Pixel coordinate: [3, width * height] 10 | """ 11 | x = np.linspace(0, width - 1, width).astype(int) 12 | y = np.linspace(0, height - 1, height).astype(int) 13 | [x, y] = np.meshgrid(x, y) 14 | return np.vstack((x.flatten(), y.flatten(), np.ones_like(x.flatten()))) 15 | 16 | 17 | def intrinsic_from_fov(height, width,
fov=90): 18 | """ 19 | Basic Pinhole Camera Model 20 | intrinsic params from fov and sensor width and height in pixels 21 | Returns: 22 | K: [4, 4] 23 | """ 24 | px, py = (width / 2, height / 2) 25 | hfov = fov / 360. * 2. * np.pi 26 | fx = width / (2. * np.tan(hfov / 2.)) 27 | 28 | vfov = 2. * np.arctan(np.tan(hfov / 2) * height / width) 29 | fy = height / (2. * np.tan(vfov / 2.)) 30 | 31 | return np.array([[fx, 0, px, 0.], 32 | [0, fy, py, 0.], 33 | [0, 0, 1., 0.], 34 | [0., 0., 0., 1.]]) 35 | 36 | 37 | def get_rotation_matrix(angle, axis): 38 | axis = axis / np.linalg.norm(axis) 39 | s = np.sin(angle) 40 | c = np.cos(angle) 41 | 42 | m = np.zeros((4, 4)) 43 | 44 | m[0][0] = axis[0] * axis[0] + (1.0 - axis[0] * axis[0]) * c 45 | # m[0][1] = axis[0] * axis[1] * (1.0 - c) + axis[2] * s 46 | m[0][1] = axis[0] * axis[1] * (1.0 - c) - axis[2] * s 47 | # m[0][2] = axis[0] * axis[2] * (1.0 - c) - axis[1] * s 48 | m[0][2] = axis[0] * axis[2] * (1.0 - c) + axis[1] * s 49 | m[0][3] = 0.0 50 | 51 | # m[1][0] = axis[0] * axis[1] * (1.0 - c) - axis[2] * s 52 | m[1][0] = axis[0] * axis[1] * (1.0 - c) + axis[2] * s 53 | m[1][1] = axis[1] * axis[1] + (1.0 - axis[1] * axis[1]) * c 54 | # m[1][2] = axis[1] * axis[2] * (1.0 - c) + axis[0] * s 55 | m[1][2] = axis[1] * axis[2] * (1.0 - c) - axis[0] * s 56 | m[1][3] = 0.0 57 | 58 | # m[2][0] = axis[0] * axis[2] * (1.0 - c) + axis[1] * s 59 | m[2][0] = axis[0] * axis[2] * (1.0 - c) - axis[1] * s 60 | # m[2][1] = axis[1] * axis[2] * (1.0 - c) - axis[0] * s 61 | m[2][1] = axis[1] * axis[2] * (1.0 - c) + axis[0] * s 62 | m[2][2] = axis[2] * axis[2] + (1.0 - axis[2] * axis[2]) * c 63 | m[2][3] = 0.0 64 | 65 | m[3][0] = 0.0 66 | m[3][1] = 0.0 67 | m[3][2] = 0.0 68 | m[3][3] = 1.0 69 | 70 | return m 71 | 72 | 73 | def get_world_coords(rgb, depth, env): 74 | height, width, _ = rgb.shape 75 | K = intrinsic_from_fov(height, width, 45) # the fov is 45 degrees 76 | 77 | # Apply back-projection: K_inv @ pixels * depth 78 | cam_coords = np.ones((height, width, 4)) 79 | u0 = K[0, 2] 80 | v0 = K[1, 2] 81 | fx = K[0, 0] 82 | fy = K[1, 1] 83 | # Loop through each pixel in the image 84 | for v in range(height): 85 | for u in range(width): 86 | # back-project pixel (u, v) and its depth to camera-frame (x, y, z) 87 | x = (u - u0) * depth[v, u] / fx 88 | y = (v - v0) * depth[v, u] / fy 89 | z = depth[v, u] 90 | cam_coords[v][u][:3] = (x, y, z) 91 | 92 | particle_pos = pyflex.get_positions().reshape((-1, 4)) 93 | print('cloth pixels: ', np.count_nonzero(depth)) 94 | print("cloth particle num: ", pyflex.get_n_particles()) 95 | 96 | # debug: print camera coordinates 97 | # print(cam_coords.shape) 98 | # cnt = 0 99 | # for v in range(height): 100 | # for u in range(width): 101 | # if depth[v][u] > 0: 102 | # print("v: {} u: {} cnt: {} cam_coord: {} approximate particle pos: {}".format( 103 | # v, u, cnt, cam_coords[v][u], particle_pos[cnt])) 104 | # rgb = rgbd[:, :, :3].copy() 105 | # rgb[v][u][0] = 255 106 | # rgb[v][u][1] = 0 107 | # rgb[v][u][2] = 0 108 | # cv2.imshow('rgb', rgb[:, :, ::-1]) 109 | # cv2.waitKey() 110 | # cnt += 1 111 | 112 | # from cam coord to world coord 113 | cam_x, cam_y, cam_z = env.camera_params['default_camera']['pos'][0], env.camera_params['default_camera']['pos'][1], \ 114 | env.camera_params['default_camera']['pos'][2] 115 | cam_x_angle, cam_y_angle, cam_z_angle = env.camera_params['default_camera']['angle'][0], env.camera_params['default_camera']['angle'][1], \ 116 | env.camera_params['default_camera']['angle'][2] 117 | 118 | # get rotation matrix: from world to camera 119 | matrix1 =
get_rotation_matrix(- cam_x_angle, [0, 1, 0]) 120 | # matrix2 = get_rotation_matrix(- cam_y_angle - np.pi, [np.cos(cam_x_angle), 0, np.sin(cam_x_angle)]) 121 | matrix2 = get_rotation_matrix(- cam_y_angle - np.pi, [1, 0, 0]) 122 | rotation_matrix = matrix2 @ matrix1 123 | 124 | # get translation matrix: from world to camera 125 | translation_matrix = np.zeros((4, 4)) 126 | translation_matrix[0][0] = 1 127 | translation_matrix[1][1] = 1 128 | translation_matrix[2][2] = 1 129 | translation_matrix[3][3] = 1 130 | translation_matrix[0][3] = - cam_x 131 | translation_matrix[1][3] = - cam_y 132 | translation_matrix[2][3] = - cam_z 133 | 134 | # debug: from world to camera 135 | cloth_x, cloth_y = env.current_config['ClothSize'][0], env.current_config['ClothSize'][1] 136 | # cnt = 0 137 | # for u in range(height): 138 | # for v in range(width): 139 | # if depth[u][v] > 0: 140 | # world_coord = np.ones(4) 141 | # world_coord[:3] = particle_pos[cnt][:3] 142 | # convert_cam_coord = rotation_matrix @ translation_matrix @ world_coord 143 | # # convert_cam_coord = translation_matrix @ matrix2 @ matrix1 @ world_coord 144 | # print("u {} v {} \n world coord {} \n convert camera coord {} \n real camera coord {}".format( 145 | # u, v, world_coord, convert_cam_coord, cam_coords[u][v] 146 | # )) 147 | # cnt += 1 148 | # input('wait...') 149 | 150 | # convert the camera coordinate back to the world coordinate using the rotation and translation matrix 151 | cam_coords = cam_coords.reshape((-1, 4)).transpose() # 4 x (height x width) 152 | world_coords = np.linalg.inv(rotation_matrix @ translation_matrix) @ cam_coords # 4 x (height x width) 153 | world_coords = world_coords.transpose().reshape((height, width, 4)) 154 | 155 | # roughly check the final world coordinate with the actual coordinate 156 | # firstu = 0 157 | # firstv = 0 158 | # for u in range(height): 159 | # for v in range(width): 160 | # if depth[u][v]: 161 | # if u > firstu: # move to a new line 162 | # firstu = u 163 | # firstv = v 164 | 165 | # cnt = (u - firstu) * cloth_x + (v - firstv) 166 | # print("u {} v {} cnt{}\nworld_coord\t{}\nparticle coord\t{}\nerror\t{}".format( 167 | # u, v, cnt, world_coords[u][v], particle_pos[cnt], np.linalg.norm( world_coords[u][v] - particle_pos[cnt]))) 168 | # rgb = rgbd[:, :, :3].copy() 169 | # rgb[u][v][0] = 255 170 | # rgb[u][v][1] = 0 171 | # rgb[u][v][2] = 0 172 | # cv2.imshow('rgb', rgb[:, :, ::-1]) 173 | # cv2.waitKey() 174 | # exit() 175 | return world_coords 176 | 177 | 178 | def get_observable_particle_index(world_coords, particle_pos, rgb, depth): 179 | height, width, _ = rgb.shape 180 | # perform the matching of pixel particle to real particle 181 | observable_particle_idxes = [] 182 | particle_pos = particle_pos[:, :3] 183 | for u in range(height): 184 | for v in range(width): 185 | if depth[u][v] > 0: 186 | estimated_world_coord = world_coords[u][v][:3] 187 | distance = np.linalg.norm(estimated_world_coord - particle_pos, axis=1) 188 | estimated_particle_idx = np.argmin(distance) 189 | # print("u {} v {} estimated particle idx {}".format(u, v, estimated_particle_idx)) 190 | observable_particle_idxes.append(estimated_particle_idx) 191 | # rgb = rgbd[:, :, :3].copy() 192 | # rgb[u][v][0] = 255 193 | # rgb[u][v][1] = 0 194 | # rgb[u][v][2] = 0 195 | # cv2.imshow('chosen_idx', rgb[:, :, ::-1]) 196 | # cv2.waitKey() 197 | # exit() 198 | return observable_particle_idxes 199 | -------------------------------------------------------------------------------- /VCD/utils/plot_utils.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def seg_3d_figure(data: np.ndarray, labels: np.ndarray, labelmap=None, sizes=None, fig=None): 5 | import plotly.colors as pc 6 | import plotly.graph_objects as go 7 | from plotly.subplots import make_subplots 8 | import plotly.express as px 9 | import plotly.figure_factory as ff 10 | 11 | # Create a figure. 12 | if fig is None: 13 | fig = go.Figure() 14 | 15 | # Find the ranges for visualizing. 16 | mean = data.mean(axis=0) 17 | max_x = np.abs(data[:, 0] - mean[0]).max() 18 | max_y = np.abs(data[:, 1] - mean[1]).max() 19 | max_z = np.abs(data[:, 2] - mean[2]).max() 20 | all_max = max(max(max_x, max_y), max_z) 21 | 22 | # Colormap. 23 | cols = np.array(pc.qualitative.Alphabet) 24 | labels = labels.astype(int) 25 | for label in np.unique(labels): 26 | subset = data[np.where(labels == label)] 27 | subset = np.squeeze(subset) 28 | if sizes is None: 29 | subset_sizes = 1.5 30 | else: 31 | subset_sizes = sizes[np.where(labels == label)] 32 | color = cols[label % len(cols)] 33 | if labelmap is not None: 34 | legend = labelmap[label] 35 | else: 36 | legend = str(label) 37 | fig.add_trace( 38 | go.Scatter3d( 39 | mode="markers", 40 | marker={"size": subset_sizes, "color": color, "line": {"width": 0}}, 41 | x=subset[:, 0], 42 | y=subset[:, 1], 43 | z=subset[:, 2], 44 | name=legend, 45 | ) 46 | ) 47 | fig.update_layout(showlegend=True) 48 | 49 | # This sets the figure to be a cube centered at the center of the pointcloud, such that it fits 50 | # all the points. 51 | fig.update_layout( 52 | scene=dict( 53 | xaxis=dict(nticks=10, range=[mean[0] - all_max, mean[0] + all_max]), 54 | yaxis=dict(nticks=10, range=[mean[1] - all_max, mean[1] + all_max]), 55 | zaxis=dict(nticks=10, range=[mean[2] - all_max, mean[2] + all_max]), 56 | aspectratio=dict(x=1, y=1, z=1), 57 | ), 58 | margin=dict(l=0, r=0, b=0, t=40), 59 | legend=dict(x=1.0, y=0.75), 60 | ) 61 | return fig 62 | -------------------------------------------------------------------------------- /VCD/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import cv2 4 | import torch 5 | from torchvision.utils import make_grid 6 | from VCD.utils.camera_utils import project_to_image 7 | import pyflex 8 | import re 9 | import h5py 10 | import os 11 | from softgym.utils.visualization import save_numpy_as_gif 12 | from chester import logger 13 | import random 14 | 15 | 16 | class VArgs(object): 17 | def __init__(self, vv): 18 | for key, val in vv.items(): 19 | setattr(self, key, val) 20 | 21 | 22 | def vv_to_args(vv): 23 | args = VArgs(vv) 24 | return args 25 | 26 | 27 | # Function to extract all the numbers from the given string 28 | def extract_numbers(str): 29 | array = re.findall(r'[0-9]+', str) 30 | if len(array) == 0: 31 | return [0] 32 | return array 33 | 34 | 35 | ################## Pointcloud Processing ################# 36 | import pcl 37 | 38 | 39 | # def get_partial_particle(full_particle, observable_idx): 40 | # return np.array(full_particle[observable_idx], dtype=np.float32) 41 | 42 | 43 | def voxelize_pointcloud(pointcloud, voxel_size): 44 | cloud = pcl.PointCloud(pointcloud) 45 | sor = cloud.make_voxel_grid_filter() 46 | sor.set_leaf_size(voxel_size, voxel_size, voxel_size) 47 | pointcloud = sor.filter() 48 | pointcloud = np.asarray(pointcloud).astype(np.float32) 49 | return pointcloud 50 | 51 | 52 | from softgym.utils.misc import 
vectorized_range, vectorized_meshgrid 53 | 54 | 55 | def pc_reward_model(pos, cloth_particle_radius=0.00625, downsample_scale=3): 56 | cloth_particle_radius *= downsample_scale 57 | pos = np.reshape(pos, [-1, 3]) 58 | min_x = np.min(pos[:, 0]) 59 | min_y = np.min(pos[:, 2]) 60 | max_x = np.max(pos[:, 0]) 61 | max_y = np.max(pos[:, 2]) 62 | init = np.array([min_x, min_y]) 63 | span = np.array([max_x - min_x, max_y - min_y]) / 100. 64 | pos2d = pos[:, [0, 2]] 65 | 66 | offset = pos2d - init 67 | slotted_x_low = np.maximum(np.round((offset[:, 0] - cloth_particle_radius) / span[0]).astype(int), 0) 68 | slotted_x_high = np.minimum(np.round((offset[:, 0] + cloth_particle_radius) / span[0]).astype(int), 100) 69 | slotted_y_low = np.maximum(np.round((offset[:, 1] - cloth_particle_radius) / span[1]).astype(int), 0) 70 | slotted_y_high = np.minimum(np.round((offset[:, 1] + cloth_particle_radius) / span[1]).astype(int), 100) 71 | 72 | grid = np.zeros(10000) # Discretization 73 | listx = vectorized_range(slotted_x_low, slotted_x_high) 74 | listy = vectorized_range(slotted_y_low, slotted_y_high) 75 | listxx, listyy = vectorized_meshgrid(listx, listy) 76 | idx = listxx * 100 + listyy 77 | idx = np.clip(idx.flatten(), 0, 9999) 78 | grid[idx] = 1 79 | 80 | res = np.sum(grid) * span[0] * span[1] 81 | return res 82 | 83 | 84 | ################## IO ################################# 85 | def downsample(cloth_xdim, cloth_ydim, scale): 86 | cloth_xdim, cloth_ydim = int(cloth_xdim), int(cloth_ydim) 87 | new_idx = np.arange(cloth_xdim * cloth_ydim).reshape((cloth_ydim, cloth_xdim)) 88 | new_idx = new_idx[::scale, ::scale] 89 | cloth_ydim, cloth_xdim = new_idx.shape 90 | new_idx = new_idx.flatten() 91 | 92 | return new_idx, cloth_xdim, cloth_ydim 93 | 94 | 95 | def load_h5_data(data_names, path): 96 | hf = h5py.File(path, 'r') 97 | data = {} 98 | for name in data_names: 99 | d = np.array(hf.get(name)) 100 | data[name] = d 101 | hf.close() 102 | return data 103 | 104 | 105 | def store_h5_data(data_names, data, path): 106 | hf = h5py.File(path, 'w') 107 | for name in data_names: 108 | hf.create_dataset(name, data=data[name]) 109 | hf.close() 110 | 111 | 112 | def load_data(data_dir, idx_rollout, idx_timestep, data_names): 113 | data_path = os.path.join(data_dir, str(idx_rollout), str(idx_timestep) + '.h5') 114 | return load_h5_data(data_names, data_path) 115 | 116 | 117 | def load_data_list(data_dir, idx_rollout, idx_timestep, data_names): 118 | data_path = os.path.join(data_dir, str(idx_rollout), str(idx_timestep) + '.h5') 119 | d = load_h5_data(data_names, data_path) 120 | return [d[name] for name in data_names] 121 | 122 | 123 | def store_data(): 124 | raise NotImplementedError 125 | 126 | 127 | def transform_info(all_infos): 128 | """ Input: All info is a nested list with the index of [episode][time]{info_key:info_value} 129 | Output: transformed_infos is a dictionary with the index of [info_key][episode][time] 130 | """ 131 | if len(all_infos) == 0: 132 | return [] 133 | transformed_info = {} 134 | num_episode = len(all_infos) 135 | T = len(all_infos[0]) 136 | 137 | for info_name in all_infos[0][0].keys(): 138 | infos = np.zeros([num_episode, T], dtype=np.float32) 139 | for i in range(num_episode): 140 | infos[i, :] = np.array([info[info_name] for info in all_infos[i]]) 141 | transformed_info[info_name] = infos 142 | return transformed_info 143 | 144 | 145 | def draw_grid(list_of_imgs, nrow, padding=10, pad_value=200): 146 | img_list = torch.from_numpy(np.array(list_of_imgs).transpose(0, 3, 1, 2)) 147 | img = 
make_grid(img_list, nrow=nrow, padding=padding, pad_value=pad_value) 148 | # print(img.shape) 149 | img = img.numpy().transpose(1, 2, 0) 150 | return img 151 | 152 | 153 | def inrange(x, low, high): 154 | if low <= x < high: 155 | return True 156 | else: 157 | return False 158 | 159 | 160 | ################## Visualization ###################### 161 | 162 | def draw_edge(frame, predicted_edges, matrix_world_to_camera, pointcloud, camera_height, camera_width): 163 | u, v = project_to_image(matrix_world_to_camera, pointcloud, camera_height, camera_width) 164 | for edge_idx in range(predicted_edges.shape[1]): 165 | s = predicted_edges[0][edge_idx] 166 | r = predicted_edges[1][edge_idx] 167 | start = (u[s], v[s]) 168 | end = (u[r], v[r]) 169 | color = (255, 0, 0) 170 | thickness = 1 171 | frame = cv2.line(frame, start, end, color, thickness) # cv2.line draws in place and returns the image 172 | 173 | return frame # also valid when there are no edges to draw 174 | 175 | 176 | def cem_make_gif(all_frames, save_dir, save_name): 177 | # Convert to T x index x C x H x W for pytorch 178 | all_frames = np.array(all_frames).transpose([1, 0, 4, 2, 3]) 179 | grid_imgs = [make_grid(torch.from_numpy(frame), nrow=5).permute(1, 2, 0).data.cpu().numpy() for frame in all_frames] 180 | save_numpy_as_gif(np.array(grid_imgs), osp.join(save_dir, save_name)) 181 | 182 | 183 | def draw_policy_action(obs_before, obs_after, start_loc_1, end_loc_1, matrix_world_to_camera, start_loc_2=None, end_loc_2=None): 184 | height, width, _ = obs_before.shape 185 | if start_loc_2 is not None: 186 | l = [(start_loc_1, end_loc_1), (start_loc_2, end_loc_2)] 187 | else: 188 | l = [(start_loc_1, end_loc_1)] 189 | for (start_loc, end_loc) in l: 190 | # print(start_loc, end_loc) 191 | suv = project_to_image(matrix_world_to_camera, start_loc.reshape((1, 3)), height, width) 192 | su, sv = suv[0][0], suv[1][0] 193 | euv = project_to_image(matrix_world_to_camera, end_loc.reshape((1, 3)), height, width) 194 | eu, ev = euv[0][0], euv[1][0] 195 | if inrange(su, 0, width) and inrange(sv, 0, height) and inrange(eu, 0, width) and inrange(ev, 0, height): 196 | cv2.arrowedLine(obs_before, (su, sv), (eu, ev), (255, 0, 0), 3) 197 | obs_before[sv - 5:sv + 5, su - 5:su + 5, :] = (0, 0, 0) 198 | 199 | res = np.concatenate((obs_before, obs_after), axis=1) 200 | return res 201 | 202 | 203 | def draw_planned_actions(save_idx, obses, start_poses, end_poses, matrix_world_to_camera, log_dir): 204 | height = width = obses[0].shape[0] 205 | 206 | start_uv = [] 207 | end_uv = [] 208 | for sp in start_poses: 209 | suv = project_to_image(matrix_world_to_camera, sp.reshape((1, 3)), height, width) 210 | start_uv.append((suv[0][0], suv[1][0])) 211 | for ep in end_poses: 212 | euv = project_to_image(matrix_world_to_camera, ep.reshape((1, 3)), height, width) 213 | end_uv.append((euv[0][0], euv[1][0])) 214 | 215 | res = [] 216 | for idx in range(len(obses) - 1): 217 | obs = obses[idx] 218 | su, sv = start_uv[idx] 219 | eu, ev = end_uv[idx] 220 | if inrange(su, 0, width) and inrange(sv, 0, height) and inrange(eu, 0, width) and inrange(ev, 0, height): 221 | cv2.arrowedLine(obs, (su, sv), (eu, ev), (255, 0, 0), 3) 222 | obs[sv - 5:sv + 5, su - 5:su + 5, :] = (0, 0, 0) 223 | res.append(obs) 224 | 225 | res.append(obses[-1]) 226 | res = np.concatenate(res, axis=1) 227 | cv2.imwrite(osp.join(log_dir, '{}_planned.png'.format(save_idx)), res[:, :, ::-1]) 228 | 229 | 230 | def draw_cem_elites(obs_, start_poses, end_poses, mean_start_pos, mean_end_pos, 231 | matrix_world_to_camera, log_dir, save_idx=None): 232 | obs = obs_.copy() 233 | start_uv = [] 234 | end_uv =
[] 235 | height = width = obs.shape[0] 236 | for sp in start_poses: 237 | suv = project_to_image(matrix_world_to_camera, sp.reshape((1, 3)), height, width) 238 | start_uv.append((suv[0][0], suv[1][0])) 239 | for ep in end_poses: 240 | euv = project_to_image(matrix_world_to_camera, ep.reshape((1, 3)), height, width) 241 | end_uv.append((euv[0][0], euv[1][0])) 242 | 243 | for idx in range(len(start_poses)): 244 | su, sv = start_uv[idx] 245 | eu, ev = end_uv[idx] 246 | # poses at the front have higher reward 247 | if inrange(su, 0, 255) and inrange(sv, 0, 255) and inrange(eu, 0, 255) and inrange(ev, 0, 255): 248 | cv2.arrowedLine(obs, (su, sv), (eu, ev), (255 * (1 - idx / len(start_poses)), 0, 0), 2) 249 | obs[sv - 2:sv + 2, su - 2:su + 2, :] = (0, 0, 0) 250 | 251 | mean_s_uv = project_to_image(matrix_world_to_camera, mean_start_pos.reshape((1, 3)), height, width) 252 | mean_e_uv = project_to_image(matrix_world_to_camera, mean_end_pos.reshape((1, 3)), height, width) 253 | mean_su, mean_sv = mean_s_uv[0][0], mean_s_uv[1][0] 254 | mean_eu, mean_ev = mean_e_uv[0][0], mean_e_uv[1][0] 255 | 256 | if inrange(mean_su, 0, 255) and inrange(mean_sv, 0, 255) and \ 257 | inrange(mean_eu, 0, 255) and inrange(mean_ev, 0, 255): 258 | cv2.arrowedLine(obs, (mean_su, mean_sv), (mean_eu, mean_ev), (0, 0, 255), 3) 259 | obs[mean_sv - 5:mean_sv + 5, mean_su - 5:mean_su + 5, :] = (0, 0, 0) # mark the mean start pixel (rows are v, columns are u) 260 | if save_idx is not None: 261 | cv2.imwrite(osp.join(log_dir, '{}_elite.png'.format(save_idx)), obs[:, :, ::-1]) # RGB -> BGR for OpenCV 262 | return obs 263 | 264 | 265 | def set_shape_pos(pos): 266 | shape_states = np.array(pyflex.get_shape_states()).reshape(-1, 14) 267 | shape_states[:, 3:6] = pos.reshape(-1, 3) 268 | shape_states[:, :3] = pos.reshape(-1, 3) 269 | pyflex.set_shape_states(shape_states) 270 | 271 | 272 | def visualize(env, particle_positions, shape_positions, config_id, sample_idx=None, picked_particles=None, show=False): 273 | """ Render point cloud trajectory without running the simulation dynamics""" 274 | env.reset(config_id=config_id) 275 | frames = [] 276 | for i in range(len(particle_positions)): 277 | particle_pos = particle_positions[i] 278 | shape_pos = shape_positions[i] 279 | p = pyflex.get_positions().reshape(-1, 4) 280 | p[:, :3] = [0., -0.1, 0.]
# All particles moved underground 281 | if sample_idx is None: 282 | p[:len(particle_pos), :3] = particle_pos 283 | else: 284 | p[:, :3] = [0, -0.1, 0] 285 | p[sample_idx, :3] = particle_pos 286 | pyflex.set_positions(p) 287 | set_shape_pos(shape_pos) 288 | rgb = env.get_image(env.camera_width, env.camera_height) 289 | frames.append(rgb) 290 | if show: 291 | if i == 0: continue 292 | picked_point = picked_particles[i] 293 | phases = np.zeros(pyflex.get_n_particles()) 294 | for id in picked_point: 295 | if id != -1: 296 | phases[sample_idx[int(id)]] = 1 297 | pyflex.set_phases(phases) 298 | img = env.get_image() 299 | 300 | cv2.imshow('picked particle images', img[:, :, ::-1]) 301 | cv2.waitKey() 302 | 303 | return frames 304 | 305 | 306 | def add_occluded_particles(observable_positions, observable_vel_history, particle_radius=0.00625, neighbor_distance=0.0216): 307 | occluded_idx = np.where(observable_positions[:, 1] > neighbor_distance / 2 + particle_radius) 308 | occluded_positions = [] 309 | for o_idx in occluded_idx[0]: 310 | pos = observable_positions[o_idx] 311 | occlude_num = np.floor(pos[1] / neighbor_distance).astype('int') 312 | for i in range(occlude_num): 313 | occluded_positions.append([pos[0], particle_radius + i * neighbor_distance, pos[2]]) 314 | 315 | print("add occluded particles num: ", len(occluded_positions)) 316 | occluded_positions = np.asarray(occluded_positions, dtype=np.float32).reshape((-1, 3)) 317 | occluded_velocity_his = np.zeros((len(occluded_positions), observable_vel_history.shape[1]), dtype=np.float32) 318 | 319 | all_positions = np.concatenate([observable_positions, occluded_positions], axis=0) 320 | all_vel_his = np.concatenate([observable_vel_history, occluded_velocity_his], axis=0) 321 | return all_positions, all_vel_his 322 | 323 | 324 | def sort_pointcloud_for_fold(pointcloud, dim): 325 | pointcloud = list(pointcloud) 326 | sorted_pointcloud = sorted(pointcloud, key=lambda k: (k[0], k[2])) 327 | for idx in range(len(sorted_pointcloud) - 1): 328 | assert sorted_pointcloud[idx][0] < sorted_pointcloud[idx + 1][0] or ( 329 | sorted_pointcloud[idx][0] == sorted_pointcloud[idx + 1][0] and 330 | sorted_pointcloud[idx][2] < sorted_pointcloud[idx + 1][2] 331 | ) 332 | 333 | real_sorted = [] 334 | for i in range(dim): 335 | points_row = sorted_pointcloud[i * dim: (i + 1) * dim] 336 | points_row = sorted(points_row, key=lambda k: k[2]) 337 | real_sorted += points_row 338 | 339 | sorted_pointcloud = real_sorted 340 | 341 | return np.asarray(sorted_pointcloud) 342 | 343 | 344 | def get_fold_idx(dim=4): 345 | group_a = [] 346 | for i in range(dim - 1): 347 | for j in range(dim - i - 1): 348 | group_a.append(i * dim + j) 349 | 350 | group_b = [] 351 | for j in range(dim - 1, 0, -1): 352 | for i in range(dim - 1, dim - 1 - j, -1): 353 | group_b.append(i * dim + j) 354 | 355 | return group_a, group_b 356 | 357 | 358 | ############################ Other ######################## 359 | def updateDictByAdd(dict1, dict2): 360 | ''' 361 | update dict1 by dict2 362 | ''' 363 | for k1, v1 in dict2.items(): 364 | for k2, v2 in v1.items(): 365 | dict1[k1][k2] += v2.cpu().item() 366 | return dict1 367 | 368 | 369 | def configure_logger(log_dir, exp_name): 370 | # Configure logger 371 | logger.configure(dir=log_dir, exp_name=exp_name) 372 | logdir = logger.get_dir() 373 | assert logdir is not None 374 | os.makedirs(logdir, exist_ok=True) 375 | 376 | 377 | def configure_seed(seed): 378 | # Configure seed 379 | torch.manual_seed(seed) 380 | if torch.cuda.is_available(): 381 | 
torch.cuda.manual_seed_all(seed) 382 | 383 | np.random.seed(seed) 384 | random.seed(seed) 385 | 386 | 387 | ############### for planning ############################### 388 | def set_picker_pos(pos): 389 | shape_states = pyflex.get_shape_states().reshape((-1, 14)) 390 | shape_states[1, :3] = -1 391 | shape_states[1, 3:6] = -1 392 | 393 | shape_states[0, :3] = pos 394 | shape_states[0, 3:6] = pos 395 | pyflex.set_shape_states(shape_states) 396 | pyflex.step() 397 | 398 | 399 | def set_resource(): 400 | import resource 401 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 402 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 403 | -------------------------------------------------------------------------------- /VCD/vc_edge.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | from VCD.models import GNN 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau 9 | 10 | import os.path as osp 11 | 12 | from chester import logger 13 | from VCD.utils.camera_utils import get_matrix_world_to_camera, project_to_image 14 | import matplotlib.pyplot as plt 15 | import torch_geometric 16 | 17 | from VCD.dataset_edge import ClothDatasetPointCloudEdge 18 | from VCD.utils.utils import extract_numbers 19 | from VCD.utils.data_utils import AggDict 20 | import json 21 | from tqdm import tqdm 22 | 23 | 24 | class VCConnection(object): 25 | def __init__(self, args, env=None): 26 | self.args, self.env = args, env # keep a handle on env; plot() projects edges onto its camera images 27 | self.model = GNN(args, decoder_output_dim=1, name='EdgeGNN') # predict a 0/1 label for mesh edge classification 28 | self.device = torch.device(self.args.cuda_idx) 29 | self.model.to(self.device) 30 | self.optim = torch.optim.Adam(self.model.param(), lr=args.lr, betas=(args.beta1, 0.999)) 31 | self.scheduler = ReduceLROnPlateau(self.optim, 'min', factor=0.8, patience=3, verbose=True) 32 | if self.args.edge_model_path is not None: 33 | self.load_model(self.args.load_optim) 34 | 35 | self.datasets = {phase: ClothDatasetPointCloudEdge(args, 'vsbl', phase, env) for phase in ['train', 'valid']} 36 | follow_batch = 'x_' 37 | self.dataloaders = { 38 | x: torch_geometric.data.DataLoader( 39 | self.datasets[x], batch_size=args.batch_size, follow_batch=follow_batch, 40 | shuffle=(x == 'train'), drop_last=True, 41 | num_workers=args.num_workers, prefetch_factor=8) 42 | for x in ['train', 'valid'] 43 | } 44 | 45 | self.log_dir = logger.get_dir() 46 | self.bce_logit_loss = nn.BCEWithLogitsLoss() 47 | self.load_epoch = 0 48 | 49 | def generate_dataset(self): 50 | os.system('mkdir -p ' + self.args.dataf) 51 | for phase in ['train', 'valid']: 52 | self.datasets[phase].generate_dataset() 53 | print('Dataset generated in', self.args.dataf) 54 | 55 | def plot(self, phase, epoch, i): 56 | data_folder = osp.join(self.args.dataf, phase) 57 | traj_ids = np.random.randint(0, len(os.listdir(data_folder)), self.args.plot_num) 58 | step_ids = np.random.randint(self.args.n_his, self.args.time_step - self.args.n_his, self.args.plot_num) 59 | pred_accs, pred_mesh_edges, gt_mesh_edges, edges, positionss, rgbs = [], [], [], [], [], [] 60 | for idx, (traj_id, step_id) in enumerate(zip(traj_ids, step_ids)): 61 | pred_mesh_edge, gt_mesh_edge, edge, positions, rgb = self.load_data_and_predict(traj_id, step_id, self.datasets[phase]) 62 | pred_acc = np.mean(pred_mesh_edge == gt_mesh_edge) 63 | pred_accs.append(pred_acc) 64 | 65 | if idx < 3: # plot the first 3 edge predictions 66 |
pred_mesh_edges.append(pred_mesh_edge) 67 | gt_mesh_edges.append(gt_mesh_edge) 68 | edges.append(edge) 69 | positionss.append(positions) 70 | rgbs.append(rgb) 71 | 72 | fig = plt.figure(figsize=(30, 30)) 73 | for idx in range(min(3, len(positionss))): 74 | pos, edge, pred_mesh_edge, gt_mesh_edge = positionss[idx], edges[idx], pred_mesh_edges[idx], gt_mesh_edges[idx] 75 | 76 | predict_ax = fig.add_subplot(3, 3, idx * 3 + 1, projection='3d') 77 | gt_ax = fig.add_subplot(3, 3, idx * 3 + 2, projection='3d') 78 | both_ax = fig.add_subplot(3, 3, idx * 3 + 3, projection='3d') 79 | 80 | for edge_idx in range(edge.shape[1]): 81 | s = int(edge[0][edge_idx]) 82 | r = int(edge[1][edge_idx]) 83 | if pred_mesh_edge[edge_idx]: 84 | predict_ax.plot([pos[s, 0], pos[r, 0]], [pos[s, 1], pos[r, 1]], [pos[s, 2], pos[r, 2]], c='r') 85 | both_ax.plot([pos[s, 0], pos[r, 0]], [pos[s, 1], pos[r, 1]], [pos[s, 2], pos[r, 2]], c='r') 86 | if gt_mesh_edge[edge_idx]: 87 | gt_ax.plot([pos[s, 0], pos[r, 0]], [pos[s, 1], pos[r, 1]], [pos[s, 2], pos[r, 2]], c='g') 88 | both_ax.plot([pos[s, 0], pos[r, 0]], [pos[s, 1], pos[r, 1]], [pos[s, 2], pos[r, 2]], c='g') 89 | 90 | gt_ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c='g', s=20) 91 | predict_ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c='r', s=20) 92 | both_ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c='g', s=20) 93 | 94 | plt.savefig(osp.join(self.log_dir, 'edge-prediction-{}-{}-{}.png'.format(phase, epoch, i))) 95 | plt.close('all') 96 | 97 | if rgbs[0] is not None: 98 | fig, axes = plt.subplots(3, 2, figsize=(30, 20)) 99 | for idx in range(min(3, len(positionss))): 100 | rgb, gt_mesh_edge, pointcloud = rgbs[idx], gt_mesh_edges[idx], positionss[idx] 101 | pred_mesh_edge, edge = pred_mesh_edges[idx], edges[idx] 102 | 103 | height, width, _ = rgb.shape 104 | matrix_world_to_camera = get_matrix_world_to_camera( 105 | self.env.camera_params[self.env.camera_name]['pos'], self.env.camera_params[self.env.camera_name]['angle'] 106 | ) 107 | matrix_world_to_camera = matrix_world_to_camera[:3, :] # 3 x 4 108 | u, v = project_to_image(matrix_world_to_camera, pointcloud, height, width) 109 | 110 | predict_ax_2 = axes[idx][0] 111 | true_ax_2 = axes[idx][1] 112 | 113 | predict_ax_2.imshow(rgb) 114 | true_ax_2.imshow(rgb) 115 | 116 | for edge_idx in range(edge.shape[1]): 117 | s = int(edge[0][edge_idx]) 118 | r = int(edge[1][edge_idx]) 119 | if pred_mesh_edge[edge_idx]: 120 | predict_ax_2.plot([u[s], u[r]], [v[s], v[r]], c='r', linewidth=0.5) 121 | if gt_mesh_edge[edge_idx]: 122 | true_ax_2.plot([u[s], u[r]], [v[s], v[r]], c='r', linewidth=0.5) 123 | 124 | predict_ax_2.set_title("predicted edge on point cloud") 125 | true_ax_2.set_title("mesh edge on particles") 126 | predict_ax_2.scatter(u, v, c='r', s=2) 127 | true_ax_2.scatter(u, v, c='r', s=2) 128 | 129 | plt.savefig(osp.join(self.log_dir, 'edge-projected-{}-{}-{}.png'.format(phase, epoch, i))) 130 | plt.close('all') 131 | 132 | return pred_accs 133 | 134 | def train(self): 135 | 136 | # Training loop 137 | st_epoch = self.load_epoch 138 | best_valid_loss = np.inf 139 | for epoch in range(st_epoch, self.args.n_epoch): 140 | phases = ['train', 'valid'] if self.args.eval == 0 else ['valid'] 141 | for phase in phases: 142 | self.set_mode(phase) 143 | epoch_info = AggDict(is_detach=True) 144 | 145 | for i, data in tqdm(enumerate(self.dataloaders[phase]), desc=f'Epoch {epoch}, phase {phase}'): 146 | data = data.to(self.device).to_dict() 147 | iter_info = AggDict(is_detach=False) 148 | last_global = torch.zeros(self.args.batch_size, 
self.args.global_size, dtype=torch.float32, device=self.device) 149 | with torch.set_grad_enabled(phase == 'train'): 150 | data['u'] = last_global 151 | pred_mesh_edge = self.model(data) 152 | loss = self.bce_logit_loss(pred_mesh_edge['mesh_edge'], data['gt_mesh_edge']) # TODO change accel to edge 153 | iter_info.add_item('loss', loss) 154 | 155 | if phase == 'train': 156 | self.optim.zero_grad() 157 | loss.backward() 158 | self.optim.step() 159 | 160 | epoch_info.update_by_add(iter_info) 161 | iter_info.clear() 162 | 163 | epoch_len = len(self.dataloaders[phase]) 164 | if i == len(self.dataloaders[phase]) - 1: 165 | avg_dict = epoch_info.get_mean('{}/'.format(phase), epoch_len) 166 | avg_dict['lr'] = self.optim.param_groups[0]['lr'] 167 | for k, v in avg_dict.items(): 168 | logger.record_tabular(k, v) 169 | 170 | pred_accs = self.plot(phase, epoch, i) 171 | 172 | logger.record_tabular(phase + '/epoch', epoch) 173 | logger.record_tabular(phase + '/pred_acc', np.mean(pred_accs)) 174 | logger.dump_tabular() 175 | 176 | if phase == 'train' and i == len(self.dataloaders[phase]) - 1: 177 | suffix = '{}'.format(epoch) 178 | self.model.save_model(self.log_dir, 'vsbl', suffix, self.optim) 179 | 180 | print('%s [%d/%d] Loss: %.4f, Best valid: %.4f' % 181 | (phase, epoch, self.args.n_epoch, avg_dict[f'{phase}/loss'], best_valid_loss)) 182 | 183 | if phase == 'valid': 184 | cur_loss = avg_dict[f'{phase}/loss'] 185 | self.scheduler.step(cur_loss) 186 | if cur_loss < best_valid_loss: 187 | best_valid_loss = cur_loss 188 | state_dict = self.args.__dict__ 189 | state_dict['best_epoch'] = epoch 190 | state_dict['best_valid_loss'] = cur_loss 191 | with open(osp.join(self.log_dir, 'best_state.json'), 'w') as f: 192 | json.dump(state_dict, f, indent=2, sort_keys=True) 193 | self.model.save_model(self.log_dir, 'vsbl', 'best', self.optim) 194 | 195 | def load_data_and_predict(self, rollout_idx, timestep, dataset): 196 | args = self.args 197 | self.set_mode('eval') 198 | 199 | idx = rollout_idx * (self.args.time_step - self.args.n_his) + timestep 200 | data_ori = dataset._prepare_transition(idx) 201 | data = dataset.build_graph(data_ori) 202 | 203 | gt_mesh_edge = data['gt_mesh_edge'].detach().cpu().numpy() 204 | 205 | with torch.no_grad(): 206 | data['x_batch'] = torch.zeros(data['x'].size(0), dtype=torch.long, device=self.device) 207 | data['u'] = torch.zeros([1, self.args.global_size], device=self.device) 208 | for key in ['x', 'edge_index', 'edge_attr']: 209 | data[key] = data[key].to(self.device) 210 | pred_mesh_edge = self.model(data)['mesh_edge'] 211 | 212 | pred_mesh_edge_logits = pred_mesh_edge.cpu().numpy() 213 | pred_mesh_edge = pred_mesh_edge_logits > 0 214 | edges = data['edge_index'].detach().cpu().numpy() 215 | 216 | return pred_mesh_edge, gt_mesh_edge, edges, data_ori['normalized_vox_pc'], None 217 | 218 | def infer_mesh_edges(self, args): 219 | """ 220 | args: a dict 221 | scene_params 222 | pointcloud 223 | cuda_idx 224 | """ 225 | scene_params = args['scene_params'] 226 | point_cloud = args['pointcloud'] 227 | cuda_idx = args.get('cuda_idx', 0) 228 | 229 | self.set_mode('eval') 230 | if cuda_idx >= 0: 231 | self.to(cuda_idx) 232 | edge_dataset = self.datasets['train'] 233 | 234 | normalized_point_cloud = point_cloud - np.mean(point_cloud, axis=0) 235 | data_ori = { 236 | 'scene_params': scene_params, 237 | 'observable_idx': None, 238 | 'normalized_vox_pc': normalized_point_cloud, 239 | 'pc_to_mesh_mapping': None 240 | } 241 | data = edge_dataset.build_graph(data_ori,
get_gt_edge_label=False) 242 | with torch.no_grad(): 243 | data['x_batch'] = torch.zeros(data['x'].size(0), dtype=torch.long, device=self.device) 244 | data['u'] = torch.zeros([1, self.args.global_size], device=self.device) 245 | for key in ['x', 'edge_index', 'edge_attr']: 246 | data[key] = data[key].to(self.device) 247 | pred_mesh_edge_logits = self.model(data)['mesh_edge'] 248 | 249 | pred_mesh_edge_logits = pred_mesh_edge_logits.cpu().numpy() 250 | pred_mesh_edge = pred_mesh_edge_logits > 0 251 | 252 | edges = data['edge_index'].detach().cpu().numpy() 253 | senders = [] 254 | receivers = [] 255 | num_edges = edges.shape[1] 256 | for e_idx in range(num_edges): 257 | if pred_mesh_edge[e_idx]: 258 | senders.append(int(edges[0][e_idx])) 259 | receivers.append(int(edges[1][e_idx])) 260 | 261 | mesh_edges = np.vstack([senders, receivers]) 262 | return mesh_edges 263 | 264 | def to(self, cuda_idx): 265 | self.model.to(torch.device("cuda:{}".format(cuda_idx))) 266 | 267 | def set_mode(self, mode='train'): 268 | self.model.set_mode('train' if mode == 'train' else 'eval') 269 | 270 | def load_model(self, load_optim=False): 271 | self.model.load_model(self.args.edge_model_path, load_optim=load_optim, optim=self.optim) 272 | self.load_epoch = extract_numbers(self.args.edge_model_path)[-1] 273 | -------------------------------------------------------------------------------- /chester/README.md: -------------------------------------------------------------------------------- 1 | # Chester 2 | 3 | Chester is a tool for automatically launching experiments. It is based on rllab (https://github.com/rll/rllab) and extended to launch and retrieve experiments on different remote machines, including: 4 | 1. Seuss 5 | 2. PSC (Pittsburgh Supercomputing Center) 6 | 7 | ## Getting Started 8 | 9 | We've provided an example that launches experiments for openai/baselines' DDPG algorithm. 10 | 11 | Look into /examples and you'll find 'train_launch.py' and 'train.py'. 'train.py' defines the task: we copied much of the code from openai/baselines/ddpg/main.py and combined it into a function 'run_task', which receives the parameters and runs the DDPG algorithm with the given settings. 12 | 13 | The launcher 'train_launch.py' uses chester and the 'run_task' function to launch a group of experiments locally. Running this launcher starts the group of experiments and collects the results in a single folder. These result files can be visualized with rllab's viskit. 14 | 15 | To support different visualization options, chester provides the 'presets.py' interface. The author can write custom splitters in this file and place it in the experiment directory; viskit detects this preset file and applies the corresponding options. A minimal launcher sketch is shown below.
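For concreteness, here is a minimal sketch of what such a launcher can look like. The actual entry point lives in chester/run_exp.py; the `run_experiment_lite`/`VariantGenerator` names below follow the rllab convention that chester inherits, and the environment names and hyperparameters are purely illustrative assumptions:

```python
# Hypothetical launcher sketch (rllab-style names assumed; adapt to chester/run_exp.py).
from chester.run_exp import run_experiment_lite, VariantGenerator  # assumed import path

from train import run_task  # the user-defined task function described above


def main():
    vg = VariantGenerator()
    # Every combination of the values below becomes one experiment variant.
    vg.add('env_name', ['HalfCheetah-v2', 'Hopper-v2'])
    vg.add('seed', [100, 200, 300])

    for vv in vg.variants():
        run_experiment_lite(
            run_task,                   # called with the variant dict
            variant=vv,
            mode='local',               # launch on this machine
            exp_prefix='ddpg-example',  # results grouped under data/local/ddpg-example/
        )


if __name__ == '__main__':
    main()
```

Each variant gets its own sub-folder containing a `variant.json` (the same file that `add_variants.py` edits after the fact), which is what viskit reads to split and plot the learning curves.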
16 | 17 | 18 | ### Prerequisites 19 | 20 | What things you need to install the software and how to install them 21 | 22 | ``` 23 | Give examples 24 | ``` 25 | 26 | ### Installation 27 | 28 | A step by step series of examples that tell you how to get a development env running 29 | 30 | Say what the step will be 31 | 32 | ``` 33 | Give the example 34 | ``` 35 | 36 | And repeat 37 | 38 | ``` 39 | until finished 40 | ``` 41 | 42 | End with an example of getting some data out of the system or using it for a little demo 43 | 44 | ## Running the tests 45 | 46 | Explain how to run the automated tests for this system 47 | 48 | ### Break down into end to end tests 49 | 50 | Explain what these tests test and why 51 | 52 | ``` 53 | Give an example 54 | ``` 55 | 56 | ### And coding style tests 57 | 58 | Explain what these tests test and why 59 | 60 | ``` 61 | Give an example 62 | ``` 63 | 64 | ## Deployment 65 | 66 | Add additional notes about how to deploy this on a live system 67 | 68 | ## Built With 69 | 70 | * [Dropwizard](http://www.dropwizard.io/1.0.2/docs/) - The web framework used 71 | * [Maven](https://maven.apache.org/) - Dependency Management 72 | * [ROME](https://rometools.github.io/rome/) - Used to generate RSS Feeds 73 | 74 | ## Contributing 75 | 76 | Please read [CONTRIBUTING.md](https://gist.github.com/PurpleBooth/b24679402957c63ec426) for details on our code of conduct, and the process for submitting pull requests to us. 77 | 78 | ## Versioning 79 | 80 | We use [SemVer](http://semver.org/) for versioning. For the versions available, see the [tags on this repository](https://github.com/your/project/tags). 81 | 82 | ## Authors 83 | 84 | * **Billie Thompson** - *Initial work* - [PurpleBooth](https://github.com/PurpleBooth) 85 | 86 | See also the list of [contributors](https://github.com/your/project/contributors) who participated in this project. 87 | 88 | ## License 89 | 90 | This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details 91 | 92 | ## Acknowledgments 93 | 94 | * Hat tip to anyone whose code was used 95 | * Inspiration 96 | * etc 97 | 98 | -------------------------------------------------------------------------------- /chester/add_variants.py: -------------------------------------------------------------------------------- 1 | # Add variants to finished experiments 2 | import argparse 3 | import os 4 | import json 5 | from pydoc import locate 6 | import config 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('exp_folder', type=str, help='Root of the experiment folder to walk through') 12 | parser.add_argument('key', type=str, help='Name of the additional key') 13 | parser.add_argument('value', help='Value of the additional key') 14 | parser.add_argument('value_type', default='str', type=str, help='Type of the additional key') 15 | parser.add_argument('remote', nargs='?', default=None, type=str, ) # Optional 16 | 17 | args = parser.parse_args() 18 | exp_paths = [x[0] for x in os.walk(args.exp_folder, followlinks=True)] 19 | 20 | value_type = locate(args.value_type) 21 | if value_type == bool: 22 | value = args.value in ['1', 'True', 'true'] 23 | else: 24 | value = value_type(args.value) 25 | 26 | for exp_path in exp_paths: 27 | try: 28 | variant_path = os.path.join(exp_path, "variant.json") 29 | # Modify locally 30 | with open(variant_path, 'r') as f: 31 | vv = json.load(f) 32 | if args.key in vv: 33 | print('Warning: key already in variants. {} = {}. 
Setting it to {}'.format(args.key, vv[args.key], value)) 34 | 35 | vv[args.key] = value 36 | with open(variant_path, 'w') as f: 37 | json.dump(vv, f, indent=2, sort_keys=True) 38 | print('{} modified'.format(variant_path)) 39 | 40 | # Upload it to remote 41 | if args.remote is not None: 42 | p = variant_path.rstrip('/').split('/') 43 | sub_exp_name, exp_name = p[-2], p[-3] 44 | 45 | remote_dir = os.path.join(config.REMOTE_DIR[args.remote], 'data', 'local', exp_name, sub_exp_name, 'variant.json') 46 | os.system('scp {} {}:{}'.format(variant_path, args.remote, remote_dir)) 47 | except IOError as e: 48 | print(e) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /chester/config.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | 4 | # TODO change this before make it into a pip package 5 | PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..')) 6 | 7 | LOG_DIR = os.path.join(PROJECT_PATH, "data") 8 | 9 | # Make sure to use absolute path 10 | REMOTE_DIR = { 11 | } 12 | 13 | REMOTE_MOUNT_OPTION = { 14 | } 15 | 16 | REMOTE_LOG_DIR = { 17 | } 18 | 19 | REMOTE_HEADER = dict() 20 | 21 | # location of the singularity file related to the project 22 | SIMG_DIR = { 23 | } 24 | CUDA_MODULE = { 25 | } 26 | MODULES = { 27 | } -------------------------------------------------------------------------------- /chester/containers/ubuntu-16.04-lts-rl.README: -------------------------------------------------------------------------------- 1 | Bootstrap: debootstrap 2 | OSVersion: xenial 3 | MirrorURL: http://us.archive.ubuntu.com/ubuntu/ 4 | 5 | %help 6 | This is a singularity container that runs Deep Reinforcement Learning algorithms on ubuntu 7 | Packages installed include: 8 | * cuda 9.0 and cuDNN 9 | Will run ~/.bashrc on start to make sure the PATH is the same. 10 | 11 | %runscript 12 | /usr/bin/nvidia-smi -L 13 | 14 | %environment 15 | LD_LIBRARY_PATH=/usr/local/cuda-9.0/cuda/lib64:/usr/local/cuda-9.0/lib64:/usr/lib/nvidia-384$LD_LIBRARY_PATH 16 | 17 | %setup 18 | echo "Let us have CUDA..." 19 | sh /home/xingyu/software/cuda/cuda_9.0.176_384.81_linux.run --silent --toolkit --toolkitpath=${SINGULARITY_ROOTFS}/usr/local/cuda-9.0 20 | ln -s ${SINGULARITY_ROOTFS}/usr/local/cuda-9.0 ${SINGULARITY_ROOTFS}/usr/local/cuda 21 | echo "Let us also have cuDNN..." 
22 | cp -prv /home/xingyu/software/cudnn/* ${SINGULARITY_ROOTFS}/usr/local/cuda-9.0/ 23 | 24 | %labels 25 | AUTHOR xlin3@cs.cmu.edu 26 | VERSION v1.0 27 | 28 | %post 29 | echo "Hello from inside the container" 30 | sed -i 's/$/ universe/' /etc/apt/sources.list 31 | touch /usr/bin/nvidia-smi 32 | chmod +x /usr/bin/nvidia-smi 33 | 34 | apt-get -y update 35 | apt-get -y install software-properties-common vim make wget curl emacs ffmpeg git htop libffi-dev libglew-dev libgl1-mesa-glx libosmesa6 libosmesa6-dev libssl-dev mesa-utils module-init-tools openjdk-8-jdk python-dev python-numpy python-tk bzip2 36 | apt-get -y install build-essential 37 | apt-get -y install libgl1-mesa-dev libglfw3-dev 38 | apt-get -y install strace 39 | 40 | echo "Install openmpi 3.1.1" 41 | 42 | cd /tmp 43 | wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.1.tar.gz 44 | tar xf openmpi-3.1.1.tar.gz 45 | cd openmpi-3.1.1 46 | mkdir -p build 47 | cd build 48 | ../configure 49 | make -j 8 all 50 | make install 51 | apt-get -y install openmpi-bin 52 | rm -rf /tmp/openmpi* 53 | rm -rf /usr/bin/mpirun 54 | ln -s /usr/local/bin/mpirun /usr/bin/mpirun 55 | 56 | echo "Install mpi4py 3.0.0" 57 | cd /tmp 58 | wget https://bitbucket.org/mpi4py/mpi4py/downloads/mpi4py-3.0.0.tar.gz 59 | tar -zxf mpi4py-3.0.0.tar.gz 60 | cd mpi4py-3.0.0 61 | python setup.py build --mpicc= 62 | python setup.py install --user 63 | mkdir -p /usr/lib/nvidia-384 -------------------------------------------------------------------------------- /chester/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /chester/docs/README.md: -------------------------------------------------------------------------------- 1 | ## References in making this documentation 2 | * [Getting started on Sphinx](https://pythonhosted.org/an_example_pypi_project/sphinx.html) 3 | * [Read the Docs Sphinx Theme](https://github.com/rtfd/sphinx_rtd_theme) 4 | * [reStructuredText](http://docutils.sourceforge.net/docs/user/rst/quickref.html) 5 | -------------------------------------------------------------------------------- /chester/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 
22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /chester/docs/readme.md: -------------------------------------------------------------------------------- 1 | ## References in making this documentation 2 | * [Getting started on Sphinx](https://pythonhosted.org/an_example_pypi_project/sphinx.html) 3 | * [Read the Docs Sphinx Theme](https://github.com/rtfd/sphinx_rtd_theme) 4 | * [reStructuredText](http://docutils.sourceforge.net/docs/user/rst/quickref.html) 5 | -------------------------------------------------------------------------------- /chester/docs/source/_static/basic_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/VCD/328f3e0ada7c3d39b974e875c5847badfca5037d/chester/docs/source/_static/basic_screenshot.png -------------------------------------------------------------------------------- /chester/docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # http://www.sphinx-doc.org/en/master/config 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'Chester' 21 | copyright = '2019, Xingyu Lin' 22 | author = 'Xingyu Lin' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = 'v0.1' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | ] 35 | 36 | # Add any paths that contain templates here, relative to this directory. 37 | templates_path = ['_templates'] 38 | 39 | # List of patterns, relative to source directory, that match files and 40 | # directories to ignore when looking for source files. 41 | # This pattern also affects html_static_path and html_extra_path. 42 | exclude_patterns = [] 43 | 44 | 45 | # -- Options for HTML output ------------------------------------------------- 46 | 47 | # The theme to use for HTML and HTML Help pages. See the documentation for 48 | # a list of builtin themes. 49 | # 50 | html_theme = 'sphinx_rtd_theme' 51 | 52 | # Add any paths that contain custom static files (such as style sheets) here, 53 | # relative to this directory. They are copied after the builtin static files, 54 | # so a file named "default.css" will overwrite the builtin "default.css". 
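# In this project, _static only ships the screenshot used by the docs pages; a (hypothetical) _static/default.css would override the theme's stylesheet as described above.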
55 | html_static_path = ['_static'] 56 | -------------------------------------------------------------------------------- /chester/docs/source/getting_started.rst: -------------------------------------------------------------------------------- 1 | .. _getting_started: 2 | 3 | 4 | *************** 5 | Getting started 6 | *************** 7 | 8 | .. _installing-docdir: 9 | 10 | Installing your doc directory 11 | ============================= 12 | 13 | You may already have `sphinx <http://sphinx-doc.org/>`_ 14 | installed -- you can check by doing:: 15 | 16 | python -c 'import sphinx' 17 | 18 | If that fails, grab the latest version and install it with:: 19 | 20 | > sudo easy_install -U Sphinx 21 | 22 | Now you are ready to build a template for your docs, using 23 | sphinx-quickstart:: 24 | 25 | > sphinx-quickstart 26 | 27 | accepting most of the defaults. I chose "sampledoc" as the name of my 28 | project. cd into your new directory and check the contents:: 29 | 30 | home:~/tmp/sampledoc> ls 31 | Makefile _static conf.py 32 | _build _templates index.rst 33 | 34 | The index.rst is the master ReST file for your project, but before adding 35 | anything, let's see if we can build some html:: 36 | 37 | make html 38 | 39 | If you now point your browser to :file:`_build/html/index.html`, you 40 | should see a basic sphinx site. 41 | 42 | .. image:: _static/basic_screenshot.png 43 | 44 | .. _fetching-the-data: 45 | 46 | Fetching the data 47 | ----------------- 48 | 49 | Now we will start to customize our docs. Grab a couple of files from 50 | the web site 51 | or git. You will need :file:`getting_started.rst` and 52 | :file:`_static/basic_screenshot.png`. All of the files live in the 53 | "completed" version of this tutorial, but since this is a tutorial, 54 | we'll just grab them one at a time, so you can learn what needs to be 55 | changed where. Since we have more files to come, I'm going to grab 56 | the whole git directory and just copy the files I need over for now. 57 | First, I'll cd up back into the directory containing my project, check 58 | out the "finished" product from git, and then copy in just the files I 59 | need into my :file:`sampledoc` directory:: 60 | 61 | home:~/tmp/sampledoc> pwd 62 | /Users/jdhunter/tmp/sampledoc 63 | home:~/tmp/sampledoc> cd .. 64 | home:~/tmp> git clone https://github.com/matplotlib/sampledoc.git tutorial 65 | Cloning into 'tutorial'... 66 | remote: Counting objects: 87, done. 67 | remote: Compressing objects: 100% (43/43), done. 68 | remote: Total 87 (delta 45), reused 83 (delta 41) 69 | Unpacking objects: 100% (87/87), done. 70 | Checking connectivity... done 71 | home:~/tmp> cp tutorial/getting_started.rst sampledoc/ 72 | home:~/tmp> cp tutorial/_static/basic_screenshot.png sampledoc/_static/ 73 | 74 | The last step is to modify :file:`index.rst` to include the 75 | :file:`getting_started.rst` file (be careful with the indentation, the 76 | "g" in "getting_started" should line up with the ':' in ``:maxdepth``):: 77 | 78 | Contents: 79 | 80 | .. toctree:: 81 | :maxdepth: 2 82 | 83 | getting_started.rst 84 | 85 | and then rebuild the docs:: 86 | 87 | cd sampledoc 88 | make html 89 | 90 | 91 | When you reload the page by refreshing your browser pointing to 92 | :file:`_build/html/index.html`, you should see a link to the 93 | "Getting Started" docs, and in there this page with the screenshot. 94 | `Voila!` 95 | 96 | We can also use the image directive in :file:`index.rst` to include the screenshot above 97 | with:: 98 | 99 | .. 
image:: 100 | _static/basic_screenshot.png 101 | 102 | 103 | Next we'll customize the look and feel of our site to give it a logo, 104 | some custom css, and update the navigation panels to look more like 105 | the `sphinx <http://sphinx-doc.org/>`_ site itself -- see 106 | :ref:`custom_look`. -------------------------------------------------------------------------------- /chester/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Chester documentation master file, created by 2 | sphinx-quickstart on Fri Apr 5 16:47:55 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Chester's documentation! 7 | =================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | getting_started.rst 14 | launcher.rst 15 | logger.rst 16 | visualization.rst 17 | 18 | 19 | 20 | Indices and tables 21 | ================== 22 | 23 | * :ref:`genindex` 24 | * :ref:`modindex` 25 | * :ref:`search` 26 | -------------------------------------------------------------------------------- /chester/docs/source/launcher.rst: -------------------------------------------------------------------------------- 1 | .. _launcher: 2 | 3 | 4 | *************** 5 | Launcher 6 | *************** 7 | -------------------------------------------------------------------------------- /chester/docs/source/logger.rst: -------------------------------------------------------------------------------- 1 | .. _logger: 2 | 3 | 4 | *************** 5 | Logger 6 | *************** 7 | -------------------------------------------------------------------------------- /chester/docs/source/visualization.rst: -------------------------------------------------------------------------------- 1 | .. _visualization: 2 | 3 | 4 | *************** 5 | Visualization 6 | *************** 7 | 8 | .. _Batch-Plotting: 9 | 10 | Batch Plotting 11 | ============== 12 | 13 | If the data are logged with Chester, they can also be easily plotted in batch. 14 | After the data are logged, for each experiment, the hyper-parameters are stored in ``variant.json`` and different 15 | key values are stored in ``progress.csv``. ``chester/plotting/cplot.py`` offers the functions that can be used to 16 | organize different experiments based on their key values: 17 | 18 | - ``reload_data()``: Iterates through the data folder and organizes each experiment into a list with its progress data and hyper-parameters; it also analyzes all the curves and reports the distinct hyper-parameters. 19 | - ``get_group_selectors()``: You should write a ``custom_series_splitter()``, which provides a legend for each experiment based on its hyper-parameters. This function will then group all the experiments by their legends. 20 | - ``get_shaded_curve()``: Creates the y-values needed for plots with shades (representing variance or median) for a certain key value. 21 | 22 | A data structure from the rllab visualization kit, ``Selector``, ties these together. 
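For instance, gluing the three functions together gives a minimal batch-plotting loop. The following is only a sketch; the data path, the plot key, and the ``custom_series_splitter`` are placeholders you would supply::

    import matplotlib.pyplot as plt
    from chester.plotting.cplot import reload_data, get_group_selectors, get_shaded_curve

    # load every experiment under the data folder, then group by legend
    exps_data, plottable_keys, distinct_params = reload_data('data/local/my-exp')
    selectors, legends = get_group_selectors(exps_data, custom_series_splitter)
    fig, ax = plt.subplots()
    for selector, legend in zip(selectors, legends):
        # median curve with the 25/75 percentiles as the shaded region
        y, y_lower, y_upper = get_shaded_curve(selector, 'test/success_state', shade_type='median')
        ax.plot(range(len(y)), y, label=legend)
        ax.fill_between(range(len(y)), y_lower, y_upper, alpha=0.2)
    ax.legend()
    fig.savefig('batch_plot.png', bbox_inches='tight')

``Selector`` is also useful on its own.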
It can be constructed from the loaded 23 | experiments data structure:: 24 | 25 | from rllab.viskit import core 26 | exps_data, plottable_keys, distinct_params = reload_data(path_to_data_folder) 27 | selector = Selector(exps_data) 28 | 29 | After that, it can be used to extract progress information for a certain key value:: 30 | 31 | progresses = [exp.progress.get(key) for exp in selector.extract()] 32 | 33 | or be filtered based on certain hyper-parameters:: 34 | 35 | selector = selector.where('env_name', env_name) 36 | 37 | Some examples can be found in both ``chester/plotting/cplot.py`` and ``chester/examples/cplot_example.py``. 38 | 39 | .. _Interactive-Frontend: 40 | 41 | Interactive Frontend 42 | ==================== 43 | 44 | Currently the interactive visualization feature is still coupled with rllab. 45 | It can be accessed by running:: 46 | 47 | python rllab/viskit/frontend.py 48 | 49 | .. _Preset: 50 | 51 | Preset 52 | ------ 53 | You may want to use a complex legend post-processor or splitter. 54 | The preset feature can be used to save such a setting. First write a ``presets.py``. Then, put it in the root of the 55 | data folder that you want to visualize. Now when you use the frontend visualization, there will be a preset button from which 56 | you can choose. Some examples of ``presets.py`` can be found at ``chester/examples``. 57 | -------------------------------------------------------------------------------- /chester/examples/cplot_example.py: -------------------------------------------------------------------------------- 1 | from chester.plotting.cplot import * 2 | import os.path as osp 3 | 4 | 5 | def custom_series_splitter(x): 6 | params = x['flat_params'] 7 | if 'use_ae_reward' in params and params['use_ae_reward']: 8 | return 'Auto Encoder' 9 | if params['her_replay_strategy'] == 'balance_filter': 10 | return 'Indicator+Balance+Filter' 11 | if params['env_kwargs.use_true_reward']: 12 | return 'Oracle' 13 | return 'Indicator' 14 | 15 | 16 | dict_leg2col = {"Oracle": 1, "Indicator": 0, 'Indicator+Balance+Filter': 2, "Auto Encoder": 3} 17 | save_path = './data/plots_chester' 18 | 19 | 20 | def plot_visual_learning(): 21 | data_path = './data/nsh/submit_rss/submit_rss/visual_learning' 22 | 23 | plot_keys = ['test/success_state', 'test/goal_dist_final_state'] 24 | plot_ylabels = ['Success', 'Final Distance to Goal'] 25 | plot_envs = ['FetchReach', 'Reacher', 'RopeFloat'] 26 | 27 | exps_data, plottable_keys, distinct_params = reload_data(data_path) 28 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter) 29 | for (plot_key, plot_ylabel) in zip(plot_keys, plot_ylabels): 30 | for env_name in plot_envs: 31 | fig, ax = plt.subplots(figsize=(8, 5)) 32 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)): 33 | color = core.color_defaults[dict_leg2col[legend]] 34 | y, y_lower, y_upper = get_shaded_curve(selector.where('env_name', env_name), plot_key, 35 | shade_type='median') 36 | 37 | env_horizon = selector.where('env_name', env_name).extract()[0].params["env_kwargs"]["horizon"] 38 | x, _, _ = get_shaded_curve(selector.where('env_name', env_name), 'train/episode') 39 | x = [ele * env_horizon for ele in x] 40 | 41 | ax.plot(x, y, color=color, label=legend, linewidth=2.0) 42 | 43 | ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, 44 | alpha=0.2) 45 | 46 | def y_fmt(x, y): 47 | return str(int(np.round(x / 1000.0))) + 'K' 48 | 49 | ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt)) 50 | 
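# FuncFormatter calls y_fmt(value, position); here it renders the x-axis ticks in thousands of timesteps, e.g. 20000 -> '20K'.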
ax.grid(True) 51 | ax.set_xlabel('Timesteps') 52 | ax.set_ylabel(plot_ylabel) 53 | axes = plt.gca() 54 | if 'Rope' in env_name: 55 | axes.set_xlim(left=20000) 56 | 57 | plt.title(env_name.replace('Float', 'Push')) 58 | loc = 'best' 59 | leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends) 60 | for legobj in leg.legendHandles: 61 | legobj.set_linewidth(3.0) 62 | 63 | save_name = filter_save_name('ind_visual_' + plot_key + '_' + env_name) 64 | 65 | plt.savefig(osp.join(save_path, save_name), bbox_inches='tight') 66 | 67 | 68 | def plot_state_learning(): 69 | data_path = './data/nsh/submit_rss/submit_rss/state_learning' 70 | 71 | plot_keys = ['test/success_state', 'test/goal_dist_final_state'] 72 | plot_envs = ['FetchReach', 'FetchPush', 'Reacher', 'RopeFloat'] 73 | 74 | exps_data, plottable_keys, distinct_params = reload_data(data_path) 75 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter) 76 | for plot_key in plot_keys: 77 | for env_name in plot_envs: 78 | fig, ax = plt.subplots(figsize=(8, 5)) 79 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)): 80 | color = core.color_defaults[dict_leg2col[legend]] 81 | y, y_lower, y_upper = get_shaded_curve(selector.where('env_name', env_name), plot_key, 82 | shade_type='median') 83 | env_horizon = selector.where('env_name', env_name).extract()[0].params["env_kwargs"]["horizon"] 84 | x, _, _ = get_shaded_curve(selector.where('env_name', env_name), 'train/episode') 85 | x = [ele * env_horizon for ele in x] 86 | ax.plot(x, y, color=color, label=legend, linewidth=2.0) 87 | 88 | ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, 89 | alpha=0.2) 90 | 91 | def y_fmt(x, y): 92 | return str(int(np.round(x / 1000.0))) + 'K' 93 | 94 | ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt)) 95 | ax.grid(True) 96 | ax.set_xlabel('Timesteps') 97 | ax.set_ylabel('Success') 98 | 99 | plt.title(env_name.replace('Float', 'Push')) 100 | loc = 'best' 101 | leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends) 102 | for legobj in leg.legendHandles: 103 | legobj.set_linewidth(3.0) 104 | 105 | save_name = filter_save_name('ind_state_' + plot_key + '_' + env_name) 106 | 107 | plt.savefig(osp.join(save_path, save_name), bbox_inches='tight') 108 | 109 | 110 | if __name__ == '__main__': 111 | plot_visual_learning() 112 | plot_state_learning() 113 | -------------------------------------------------------------------------------- /chester/examples/pgm_plot.py: -------------------------------------------------------------------------------- 1 | from chester.plotting.cplot import * 2 | import os.path as osp 3 | from random import shuffle 4 | 5 | save_path = '../sac/data/plots' 6 | dict_leg2col = {"LSP": 0, "Base": 1, "Behavior": 2} 7 | dict_xshift = {"LSP": 4000, "Base": 0, "Behavior": 6000} 8 | 9 | 10 | def custom_series_splitter(x): 11 | params = x['flat_params'] 12 | exp_name = params['exp_name'] 13 | dict_mapping = {'humanoid-resume-training-6000-00': 'Behavior', 14 | 'humanoid-resume-training-4000-00': 'LSP', 15 | 'humanoid-rllab/default-2019-04-14-07-04-08-421230-UTC-00': 'Base'} 16 | return dict_mapping[exp_name] 17 | 18 | 19 | def sliding_mean(data_array, window=5): 20 | data_array = np.array(data_array) 21 | new_list = [] 22 | for i in range(len(data_array)): 23 | indices = list(range(max(i - window + 1, 0), 24 | min(i + window + 1, len(data_array)))) 25 | avg = 0 26 | for j in indices: 27 | avg += data_array[j] 28 | avg /= 
float(len(indices)) 29 | new_list.append(avg) 30 | 31 | return np.array(new_list) 32 | 33 | 34 | def plot_main(): 35 | data_path = '../sac/data/mengxiong' 36 | plot_key = 'return-average' 37 | exps_data, plottable_keys, distinct_params = reload_data(data_path) 38 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter) 39 | fig, ax = plt.subplots(figsize=(8, 5)) 40 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)): 41 | color = core.color_defaults[dict_leg2col[legend]] 42 | 43 | y, y_lower, y_upper = get_shaded_curve(selector, plot_key, shade_type='median') 44 | x = np.array(range(len(y))) 45 | x += dict_xshift[legend] 46 | y = sliding_mean(y, 5) 47 | ax.plot(x, y, color=color, label=legend, linewidth=2.0) 48 | 49 | # ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, 50 | # alpha=0.2) 51 | 52 | def y_fmt(x, y): 53 | return str(int(np.round(x))) + 'K' 54 | 55 | ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt)) 56 | ax.grid(True) 57 | ax.set_xlabel('Timesteps') 58 | ax.set_ylabel('Average return') 59 | 60 | # plt.title(env_name.replace('Float', 'Push')) 61 | loc = 'best' 62 | leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends) 63 | for legobj in leg.legendHandles: 64 | legobj.set_linewidth(3.0) 65 | 66 | save_name = filter_save_name('plots.png') 67 | plt.savefig(osp.join(save_path, save_name), bbox_inches='tight') 68 | 69 | 70 | if __name__ == '__main__': 71 | plot_main() 72 | -------------------------------------------------------------------------------- /chester/examples/presets.py: -------------------------------------------------------------------------------- 1 | preset_names = ['default'] 2 | 3 | def make_custom_seris_splitter(preset_names): 4 | legendNote = None 5 | if preset_names == 'default': 6 | def custom_series_splitter(x): 7 | params = x['flat_params'] 8 | # return params['her_replay_strategy'] 9 | if params['her_replay_strategy'] == 'future': 10 | ret = 'RG' 11 | elif params['her_replay_strategy'] == 'only_fake': 12 | if params['her_use_reward']: 13 | ret = 'FG+RR' 14 | else: 15 | ret = 'FG+FR' 16 | return ret + '+' + str(params['her_clip_len']) + '+' + str(params['her_reward_choices']) + '+' + str( 17 | params['her_failed_goal_ratio']) 18 | 19 | legendNote = "Fake Goal(FG)/Real Goal(RG) + Fake Reward(FR)/Real Reward(RR) + HER_clip_len + HER_reward_choices + HER_failed_goal_ratio" 20 | else: 21 | raise NotImplementedError 22 | return custom_series_splitter, legendNote 23 | -------------------------------------------------------------------------------- /chester/examples/presets2.py: -------------------------------------------------------------------------------- 1 | preset_names = ['default'] 2 | x_axis = 'Epoch' 3 | y_axis = 'Success' 4 | FILTERED = 'filtered' 5 | 6 | def make_custom_seris_splitter(preset_names): 7 | legendNote = None 8 | if preset_names == 'default': 9 | def custom_series_splitter(x): 10 | params = x['flat_params'] 11 | if params['her_failed_goal_option'] is None: 12 | ret = 'Distance Reward' 13 | elif params['her_failed_goal_option'] == 'dist_behaviour': 14 | ret = 'Exact Match' 15 | else: 16 | ret = FILTERED 17 | return ret 18 | 19 | legendNote = None 20 | else: 21 | raise NotImplementedError 22 | return custom_series_splitter, legendNote 23 | 24 | 25 | def make_custom_filter(preset_names): 26 | if preset_names == 'default': 27 | custom_seris_splitter, _ = make_custom_seris_splitter(preset_names) 28 | def custom_filter(x):
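# Drop experiments whose legend was mapped to FILTERED by the splitter; everything else is kept.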
29 | legend = custom_seris_splitter(x) 30 | if legend == FILTERED: 31 | return False 32 | else: 33 | return True 34 | # params = x['flat_params'] 35 | # if params['her_failed_goal_option'] != FILTERED: 36 | # return True 37 | # else: 38 | # return False 39 | return custom_filter 40 | 41 | -------------------------------------------------------------------------------- /chester/examples/presets3.py: -------------------------------------------------------------------------------- 1 | set1 = 'identity_ratio+her_clip(dist_behavior and HER)' 2 | preset_names = [set1] 3 | FILTERED = 'filtered' 4 | 5 | 6 | def make_custom_seris_splitter(preset_names): 7 | legendNote = None 8 | if preset_names == set1: 9 | def custom_series_splitter(x): 10 | params = x['flat_params'] 11 | if params['her_failed_goal_option'] in ['dist_G', 'dist_policy']: 12 | return FILTERED 13 | if params['her_identity_ratio'] is not None: 14 | return 'IR: ' + str(params['her_identity_ratio']) 15 | if params['her_clip_len'] is not None: 16 | return 'CL: ' + str(params['her_clip_len']) 17 | return 'HER' 18 | 19 | legendNote = 'IR: identity ratio; CL: clip length' 20 | else: 21 | raise NotImplementedError 22 | return custom_series_splitter, legendNote 23 | 24 | 25 | def make_custom_filter(preset_names): 26 | if preset_names == set1: 27 | custom_seris_splitter, _ = make_custom_seris_splitter(preset_names) 28 | 29 | def custom_filter(x): 30 | legend = custom_seris_splitter(x) 31 | if legend == FILTERED: 32 | return False 33 | else: 34 | return True 35 | return custom_filter 36 | 37 | -------------------------------------------------------------------------------- /chester/examples/presets_tiancheng.py: -------------------------------------------------------------------------------- 1 | # Updated By Tiancheng Jin 08/28/2018 2 | 3 | # the preset file should be contained in the experiment folder ( which is assigned by exp_prefix ) 4 | # for example, this file should be put in /path to project/data/local/ 5 | 6 | # Here's an example for custom_series_splitter 7 | # suppose we want to split five experiments with random seeds from 0 to 4 in two different ways: 8 | # * two groups for those with even or odd random seeds: [0,2,4] and [1,3] 9 | # * two groups for those with smaller or larger random seeds: [0,1,2] and [3,4] 10 | 11 | preset_names = ['even or odd','small or large'] 12 | 13 | 14 | def make_custom_seris_splitter(preset_name): 15 | legend_note = None 16 | custom_series_splitter = None 17 | 18 | if preset_name == 'even or odd': 19 | # build a custom series splitter that groups by even or odd random seeds 20 | # where the input is the data for one experiment ( contains both the results and the parameters ) 21 | def custom_series_splitter(x): 22 | # extract the parameters 23 | params = x['flat_params'] 24 | # make up the legend 25 | if params['seed'] % 2 == 0: 26 | legend = 'even seeds' 27 | else: 28 | legend = 'odd seeds' 29 | return legend 30 | 31 | legend_note = "Even or Odd" 32 | 33 | elif preset_name == 'small or large': 34 | def custom_series_splitter(x): 35 | params = x['flat_params'] 36 | if params['seed'] <= 2: 37 | legend = 'smaller seeds' 38 | else: 39 | legend = 'larger seeds' 40 | return legend 41 | 42 | legend_note = "Small or Large" 43 | else: 44 | raise NotImplementedError 45 | 46 | return custom_series_splitter, legend_note 47 | -------------------------------------------------------------------------------- /chester/examples/train.py: -------------------------------------------------------------------------------- 1 | import json
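# Example task runner for chester: it wraps the DDPG implementation from OpenAI baselines so that the launcher (see train_launch.py) can call run_task(vv, log_dir, exp_name) for each variant.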
2 | import os 3 | from baselines import logger 4 | from baselines.common.misc_util import ( 5 | set_global_seeds, 6 | ) 7 | from baselines.ddpg.main import run 8 | from mpi4py import MPI 9 | 10 | 11 | DEFAULT_PARAMS = { 12 | # env 13 | 'env_id': 'HalfCheetah-v2', # gym environment id 14 | 15 | # ddpg 16 | 'layer_norm': True, 17 | 'render': False, 18 | 'normalize_returns': False, 19 | 'normalize_observations': True, 20 | 'actor_lr': 0.0001, # actor learning rate 21 | 'critic_lr': 0.001, # critic learning rate 22 | 'critic_l2_reg': 1e-2, 23 | 'popart': False, 24 | 'gamma': 0.99, 25 | 26 | # training 27 | 'seed': 0, 28 | 'nb_epochs': 500, # number of epochs 29 | 'nb_epoch_cycles': 20, # per epoch 30 | 'nb_rollout_steps': 100, # sampling steps per cycle 31 | 'nb_train_steps': 100, # training steps per cycle 32 | 'batch_size': 64, # per mpi thread, measured in transitions and reduced to even multiple of chunk_length. 33 | 'reward_scale': 1.0, 34 | 'clip_norm': None, 35 | 36 | # exploration 37 | 'noise_type': 'adaptive-param_0.2', 38 | 39 | # debugging, logging and visualization 40 | 'render_eval': False, 41 | 'nb_eval_steps': 100, 42 | 'evaluation': False, 43 | } 44 | 45 | 46 | def run_task(vv, log_dir=None, exp_name=None, allow_extra_parameters=False): 47 | # Configure logging system 48 | if log_dir or logger.get_dir() is None: 49 | logger.configure(dir=log_dir) 50 | logdir = logger.get_dir() 51 | assert logdir is not None 52 | os.makedirs(logdir, exist_ok=True) 53 | 54 | # Seed for multi-CPU MPI implementation ( rank = 0 for single threaded implementation ) 55 | rank = MPI.COMM_WORLD.Get_rank() 56 | rank_seed = vv['seed'] + 1000000 * rank 57 | set_global_seeds(rank_seed) 58 | 59 | # load params from config; copy so the module-level defaults are not mutated 60 | params = DEFAULT_PARAMS.copy() 61 | 62 | # update all parameters from the variant 63 | if not allow_extra_parameters: 64 | for k, v in vv.items(): 65 | if k not in DEFAULT_PARAMS: 66 | print("[ Warning ] Undefined parameter %s with value %s" % (str(k), str(v))) 67 | params.update(**{k: v for (k, v) in vv.items() if k in DEFAULT_PARAMS}) 68 | else: 69 | params.update(**{k: v for (k, v) in vv.items()}) 70 | 71 | with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f: 72 | json.dump(params, f) 73 | 74 | run(**params) 75 | 76 | 77 | -------------------------------------------------------------------------------- /chester/examples/train_launch.py: -------------------------------------------------------------------------------- 1 | import time 2 | from chester.run_exp import run_experiment_lite, VariantGenerator 3 | 4 | if __name__ == '__main__': 5 | 6 | # Here's an example of doing a grid search with openai's DDPG 7 | # on several gym environments 8 | 9 | # the experiment folder name 10 | # the directory is defined as /LOG_DIR/data/local/exp_prefix/, where LOG_DIR is defined in config.py 11 | exp_prefix = 'test-ddpg' 12 | vg = VariantGenerator() 13 | vg.add('env_id', ['HalfCheetah-v2', 'Hopper-v2', 'InvertedPendulum-v2']) 14 | 15 | # select random seeds from 0 to 4 16 | vg.add('seed', [0, 1, 2, 3, 4]) 17 | print('Number of configurations: ', len(vg.variants())) 18 | 19 | # set the maximum number for running experiments in parallel 20 | # this number depends on the number of processors in the runner 21 | maximum_launching_process = 5 22 | 23 | # launch experiments 24 | sub_process_popens = [] 25 | for vv in vg.variants(): 26 | while len(sub_process_popens) >= maximum_launching_process: 27 | sub_process_popens = [x for x in sub_process_popens if x.poll() is None] 28 | 
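# a Popen whose poll() is None is still running, so wait for one of the launched variants to finish before starting the next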
time.sleep(10) 29 | 30 | # import the launcher of experiments 31 | from chester.examples.train import run_task 32 | 33 | # use your written run_task function 34 | cur_popen = run_experiment_lite( 35 | stub_method_call=run_task, 36 | variant=vv, 37 | mode='local', 38 | exp_prefix=exp_prefix, 39 | wait_subprocess=False 40 | ) 41 | if cur_popen is not None: 42 | sub_process_popens.append(cur_popen) 43 | -------------------------------------------------------------------------------- /chester/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import os.path as osp 5 | import json 6 | import time 7 | import datetime 8 | import dateutil.tz 9 | import tempfile 10 | from collections import defaultdict 11 | 12 | # LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv', 'tensorboard'] 13 | LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv'] 14 | # Also valid: json, tensorboard 15 | 16 | DEBUG = 10 17 | INFO = 20 18 | WARN = 30 19 | ERROR = 40 20 | 21 | DISABLED = 50 22 | 23 | 24 | class KVWriter(object): 25 | def writekvs(self, kvs): 26 | raise NotImplementedError 27 | 28 | 29 | class SeqWriter(object): 30 | def writeseq(self, seq): 31 | raise NotImplementedError 32 | 33 | 34 | def put_in_middle(str1, str2): 35 | # Put str1 in str2 36 | n = len(str1) 37 | m = len(str2) 38 | if n <= m: 39 | return str2 40 | else: 41 | start = (n - m) // 2 42 | return str1[:start] + str2 + str1[start + m:] 43 | 44 | 45 | class HumanOutputFormat(KVWriter, SeqWriter): 46 | def __init__(self, filename_or_file): 47 | if isinstance(filename_or_file, str): 48 | self.file = open(filename_or_file, 'wt') 49 | self.own_file = True 50 | else: 51 | assert hasattr(filename_or_file, 'read'), 'expected file or str, got %s' % filename_or_file 52 | self.file = filename_or_file 53 | self.own_file = False 54 | 55 | def writekvs(self, kvs): 56 | # Create strings for printing 57 | key2str = {} 58 | for (key, val) in sorted(kvs.items()): 59 | if isinstance(val, float): 60 | valstr = '%-8.3g' % (val,) 61 | else: 62 | valstr = str(val) 63 | key2str[self._truncate(key)] = self._truncate(valstr) 64 | 65 | # Find max widths 66 | if len(key2str) == 0: 67 | print('WARNING: tried to write empty key-value dict') 68 | return 69 | else: 70 | keywidth = max(map(len, key2str.keys())) 71 | valwidth = max(map(len, key2str.values())) 72 | 73 | # Write out the data 74 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 75 | timestamp = now.strftime('%Y-%m-%d %H:%M:%S.%f %Z') 76 | 77 | dashes = '-' * (keywidth + valwidth + 7) 78 | dashes_time = put_in_middle(dashes, timestamp) 79 | lines = [dashes_time] 80 | for (key, val) in sorted(key2str.items()): 81 | lines.append('| %s%s | %s%s |' % ( 82 | key, 83 | ' ' * (keywidth - len(key)), 84 | val, 85 | ' ' * (valwidth - len(val)), 86 | )) 87 | lines.append(dashes) 88 | self.file.write('\n'.join(lines) + '\n') 89 | 90 | # Flush the output to the file 91 | self.file.flush() 92 | 93 | def _truncate(self, s): 94 | return s[:30] + '...' 
if len(s) > 33 else s 95 | 96 | def writeseq(self, seq): 97 | for arg in seq: 98 | self.file.write(arg) 99 | self.file.write('\n') 100 | self.file.flush() 101 | 102 | def close(self): 103 | if self.own_file: 104 | self.file.close() 105 | 106 | 107 | class JSONOutputFormat(KVWriter): 108 | def __init__(self, filename): 109 | self.file = open(filename, 'wt') 110 | 111 | def writekvs(self, kvs): 112 | for k, v in sorted(kvs.items()): 113 | if hasattr(v, 'dtype'): 114 | v = v.tolist() 115 | kvs[k] = float(v) 116 | self.file.write(json.dumps(kvs) + '\n') 117 | self.file.flush() 118 | 119 | def close(self): 120 | self.file.close() 121 | 122 | 123 | class CSVOutputFormat(KVWriter): 124 | def __init__(self, filename): 125 | self.file = open(filename, 'w+t') 126 | self.keys = [] 127 | self.sep = ',' 128 | 129 | def writekvs(self, kvs): 130 | # Add our current row to the history 131 | extra_keys = kvs.keys() - self.keys 132 | if extra_keys: 133 | self.keys.extend(extra_keys) 134 | self.file.seek(0) 135 | lines = self.file.readlines() 136 | self.file.seek(0) 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(',') 140 | self.file.write(k) 141 | self.file.write('\n') 142 | for line in lines[1:]: 143 | self.file.write(line[:-1]) 144 | self.file.write(self.sep * len(extra_keys)) 145 | self.file.write('\n') 146 | for (i, k) in enumerate(self.keys): 147 | if i > 0: 148 | self.file.write(',') 149 | v = kvs.get(k) 150 | if v is not None: 151 | self.file.write(str(v)) 152 | self.file.write('\n') 153 | self.file.flush() 154 | 155 | def close(self): 156 | self.file.close() 157 | 158 | 159 | class TensorBoardOutputFormat(KVWriter): 160 | """ 161 | Dumps key/value pairs into TensorBoard's numeric format. 162 | """ 163 | 164 | def __init__(self, dir): 165 | os.makedirs(dir, exist_ok=True) 166 | self.dir = dir 167 | self.step = 1 168 | prefix = 'events' 169 | path = osp.join(osp.abspath(dir), prefix) 170 | import tensorflow as tf 171 | from tensorflow.python import pywrap_tensorflow 172 | from tensorflow.core.util import event_pb2 173 | from tensorflow.python.util import compat 174 | self.tf = tf 175 | self.event_pb2 = event_pb2 176 | self.pywrap_tensorflow = pywrap_tensorflow 177 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 178 | 179 | def writekvs(self, kvs): 180 | def summary_val(k, v): 181 | kwargs = {'tag': k, 'simple_value': float(v)} 182 | return self.tf.Summary.Value(**kwargs) 183 | 184 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 185 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 186 | event.step = self.step # is there any reason why you'd want to specify the step? 
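# (Answer: TensorBoard uses `step` as the x-axis; here it is simply incremented once per writekvs() call below.)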
187 | self.writer.WriteEvent(event) 188 | self.writer.Flush() 189 | self.step += 1 190 | 191 | def close(self): 192 | if self.writer: 193 | self.writer.Close() 194 | self.writer = None 195 | 196 | 197 | def make_output_format(format, ev_dir, log_suffix=''): 198 | os.makedirs(ev_dir, exist_ok=True) 199 | if format == 'stdout': 200 | return HumanOutputFormat(sys.stdout) 201 | elif format == 'log': 202 | return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix)) 203 | elif format == 'json': 204 | return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix)) 205 | elif format == 'csv': 206 | return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix)) 207 | elif format == 'tensorboard': 208 | return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix)) 209 | else: 210 | raise ValueError('Unknown format specified: %s' % (format,)) 211 | 212 | 213 | # ================================================================ 214 | # API 215 | # ================================================================ 216 | 217 | def logkv(key, val): 218 | """ 219 | Log a value of some diagnostic 220 | Call this once for each diagnostic quantity, each iteration 221 | If called many times, last value will be used. 222 | """ 223 | Logger.CURRENT.logkv(key, val) 224 | 225 | 226 | def logkv_mean(key, val): 227 | """ 228 | The same as logkv(), but if called many times, values averaged. 229 | """ 230 | Logger.CURRENT.logkv_mean(key, val) 231 | 232 | 233 | def logkvs(d): 234 | """ 235 | Log a dictionary of key-value pairs 236 | """ 237 | for (k, v) in d.items(): 238 | logkv(k, v) 239 | 240 | 241 | def dumpkvs(): 242 | """ 243 | Write all of the diagnostics from the current iteration 244 | 245 | level: int. (see logger.py docs) If the global logger level is higher than 246 | the level argument here, don't print to stdout. 247 | """ 248 | Logger.CURRENT.dumpkvs() 249 | 250 | 251 | def getkvs(): 252 | return Logger.CURRENT.name2val 253 | 254 | 255 | def log(*args, level=INFO): 256 | """ 257 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 258 | """ 259 | Logger.CURRENT.log(*args, level=level) 260 | 261 | 262 | def debug(*args): 263 | log(*args, level=DEBUG) 264 | 265 | 266 | def info(*args): 267 | log(*args, level=INFO) 268 | 269 | 270 | def warn(*args): 271 | log(*args, level=WARN) 272 | 273 | 274 | def error(*args): 275 | log(*args, level=ERROR) 276 | 277 | 278 | def set_level(level): 279 | """ 280 | Set logging threshold on current logger. 281 | """ 282 | Logger.CURRENT.set_level(level) 283 | 284 | 285 | def get_dir(): 286 | """ 287 | Get directory that log files are being written to. 
288 | will be None if there is no output directory (i.e., if you didn't call start) 289 | """ 290 | return Logger.CURRENT.get_dir() 291 | 292 | 293 | record_tabular = logkv 294 | dump_tabular = dumpkvs 295 | 296 | 297 | class ProfileKV: 298 | """ 299 | Usage: 300 | with logger.ProfileKV("interesting_scope"): 301 | code 302 | """ 303 | 304 | def __init__(self, n): 305 | self.n = "wait_" + n 306 | 307 | def __enter__(self): 308 | self.t1 = time.time() 309 | 310 | def __exit__(self, type, value, traceback): 311 | Logger.CURRENT.name2val[self.n] += time.time() - self.t1 312 | 313 | 314 | def profile(n): 315 | """ 316 | Usage: 317 | @profile("my_func") 318 | def my_func(): code 319 | """ 320 | 321 | def decorator_with_name(func): 322 | def func_wrapper(*args, **kwargs): 323 | with ProfileKV(n): 324 | return func(*args, **kwargs) 325 | 326 | return func_wrapper 327 | 328 | return decorator_with_name 329 | 330 | 331 | # ================================================================ 332 | # Backend 333 | # ================================================================ 334 | 335 | class Logger(object): 336 | DEFAULT = None # A logger with no output files. (See right below class definition) 337 | # So that you can still log to the terminal without setting up any output files 338 | CURRENT = None # Current logger being used by the free functions above 339 | 340 | def __init__(self, dir, output_formats): 341 | self.name2val = defaultdict(float) # values this iteration 342 | self.name2cnt = defaultdict(int) 343 | self.level = INFO 344 | self.dir = dir 345 | self.output_formats = output_formats 346 | 347 | # Logging API, forwarded 348 | # ---------------------------------------- 349 | def logkv(self, key, val): 350 | self.name2val[key] = val 351 | 352 | def logkv_mean(self, key, val): 353 | if val is None: 354 | self.name2val[key] = None 355 | return 356 | oldval, cnt = self.name2val[key], self.name2cnt[key] 357 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 358 | self.name2cnt[key] = cnt + 1 359 | 360 | def dumpkvs(self): 361 | if self.level == DISABLED: return 362 | for fmt in self.output_formats: 363 | if isinstance(fmt, KVWriter): 364 | fmt.writekvs(self.name2val) 365 | self.name2val.clear() 366 | self.name2cnt.clear() 367 | 368 | def log(self, *args, level=INFO): 369 | if self.level <= level: 370 | self._do_log(args) 371 | 372 | # Configuration 373 | # ---------------------------------------- 374 | def set_level(self, level): 375 | self.level = level 376 | 377 | def get_dir(self): 378 | return self.dir 379 | 380 | def close(self): 381 | for fmt in self.output_formats: 382 | fmt.close() 383 | 384 | # Misc 385 | # ---------------------------------------- 386 | def _do_log(self, args): 387 | for fmt in self.output_formats: 388 | if isinstance(fmt, SeqWriter): 389 | fmt.writeseq(map(str, args)) 390 | 391 | 392 | Logger.DEFAULT = Logger.CURRENT = Logger(dir=None, output_formats=[HumanOutputFormat(sys.stdout)]) 393 | 394 | 395 | def configure(dir=None, format_strs=None, exp_name=None): 396 | if dir is None: 397 | dir = os.getenv('OPENAI_LOGDIR') 398 | if dir is None: 399 | dir = osp.join(tempfile.gettempdir(), 400 | datetime.datetime.now().strftime("chester-%Y-%m-%d-%H-%M-%S")) 401 | 402 | assert isinstance(dir, str) 403 | os.makedirs(dir, exist_ok=True) 404 | 405 | if format_strs is None: 406 | strs = os.getenv('OPENAI_LOG_FORMAT') 407 | format_strs = strs.split(',') if strs else LOG_OUTPUT_FORMATS 408 | output_formats = [make_output_format(f, dir) for f in format_strs] 409 | 410 
| Logger.CURRENT = Logger(dir=dir, output_formats=output_formats) 411 | log('Logging to %s' % dir) 412 | 413 | 414 | def reset(): 415 | if Logger.CURRENT is not Logger.DEFAULT: 416 | Logger.CURRENT.close() 417 | Logger.CURRENT = Logger.DEFAULT 418 | log('Reset logger') 419 | 420 | 421 | class scoped_configure(object): 422 | def __init__(self, dir=None, format_strs=None): 423 | self.dir = dir 424 | self.format_strs = format_strs 425 | self.prevlogger = None 426 | 427 | def __enter__(self): 428 | self.prevlogger = Logger.CURRENT 429 | configure(dir=self.dir, format_strs=self.format_strs) 430 | 431 | def __exit__(self, *args): 432 | Logger.CURRENT.close() 433 | Logger.CURRENT = self.prevlogger 434 | 435 | 436 | # ================================================================ 437 | 438 | def _demo(): 439 | info("hi") 440 | debug("shouldn't appear") 441 | set_level(DEBUG) 442 | debug("should appear") 443 | dir = "/tmp/testlogging" 444 | if os.path.exists(dir): 445 | shutil.rmtree(dir) 446 | configure(dir=dir) 447 | logkv("a", 3) 448 | logkv("b", 2.5) 449 | dumpkvs() 450 | logkv("b", -2.5) 451 | logkv("a", 5.5) 452 | dumpkvs() 453 | info("^^^ should see a = 5.5") 454 | logkv_mean("b", -22.5) 455 | logkv_mean("b", -44.4) 456 | logkv("a", 5.5) 457 | dumpkvs() 458 | info("^^^ should see b = 33.3") 459 | 460 | logkv("b", -2.5) 461 | dumpkvs() 462 | 463 | logkv("a", "longasslongasslongasslongasslongasslongassvalue") 464 | dumpkvs() 465 | 466 | 467 | # ================================================================ 468 | # Readers 469 | # ================================================================ 470 | 471 | def read_json(fname): 472 | import pandas 473 | ds = [] 474 | with open(fname, 'rt') as fh: 475 | for line in fh: 476 | ds.append(json.loads(line)) 477 | return pandas.DataFrame(ds) 478 | 479 | 480 | def read_csv(fname): 481 | import pandas 482 | return pandas.read_csv(fname, index_col=None, comment='#') 483 | 484 | 485 | def read_tb(path): 486 | """ 487 | path : a tensorboard file OR a directory, where we will find all TB files 488 | of the form events.* 489 | """ 490 | import pandas 491 | import numpy as np 492 | from glob import glob 493 | from collections import defaultdict 494 | import tensorflow as tf 495 | if osp.isdir(path): 496 | fnames = glob(osp.join(path, "events.*")) 497 | elif osp.basename(path).startswith("events."): 498 | fnames = [path] 499 | else: 500 | raise NotImplementedError("Expected tensorboard file or directory containing them. 
Got %s" % path) 501 | tag2pairs = defaultdict(list) 502 | maxstep = 0 503 | for fname in fnames: 504 | for summary in tf.train.summary_iterator(fname): 505 | if summary.step > 0: 506 | for v in summary.summary.value: 507 | pair = (summary.step, v.simple_value) 508 | tag2pairs[v.tag].append(pair) 509 | maxstep = max(summary.step, maxstep) 510 | data = np.empty((maxstep, len(tag2pairs))) 511 | data[:] = np.nan 512 | tags = sorted(tag2pairs.keys()) 513 | for (colidx, tag) in enumerate(tags): 514 | pairs = tag2pairs[tag] 515 | for (step, value) in pairs: 516 | data[step - 1, colidx] = value 517 | return pandas.DataFrame(data, columns=tags) 518 | 519 | 520 | if __name__ == "__main__": 521 | _demo() 522 | -------------------------------------------------------------------------------- /chester/plotting/cplot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import json 5 | import argparse 6 | import itertools 7 | import numpy as np 8 | 9 | # Matplotlib 10 | import matplotlib 11 | 12 | matplotlib.use('Agg') 13 | import matplotlib.pyplot as plt 14 | 15 | plt.rc('font', size=25) 16 | matplotlib.rcParams['pdf.fonttype'] = 42 # Default type3 cannot be rendered in some templates 17 | matplotlib.rcParams['ps.fonttype'] = 42 18 | matplotlib.rcParams['grid.alpha'] = 0.3 19 | matplotlib.rcParams['axes.titlesize'] = 25 20 | import matplotlib.ticker as tick 21 | 22 | # rllab 23 | sys.path.append('.') 24 | from rllab.misc.ext import flatten 25 | from rllab.viskit import core 26 | 27 | 28 | # from rllab.misc import ext 29 | 30 | # plotly 31 | # import plotly.offline as po 32 | # import plotly.graph_objs as go 33 | 34 | 35 | def reload_data(data_paths): 36 | """ 37 | Iterate through the data folder and organize each experiment into a list, with their progress data, hyper-parameters 38 | and also analyze all the curves and give the distinct hyper-parameters. 39 | :param data_path: Path of the folder storing all the data 40 | :return [exps_data, plottable_keys, distinct_params] 41 | exps_data: A list of the progress data for each curve. Each curve is an AttrDict with the key 42 | 'progress': A dictionary of plottable keys. The val of each key is an ndarray representing the 43 | values of the key during training, or one column in the progress.txt file. 44 | 'params'/'flat_params': A dictionary of all hyperparameters recorded in 'variants.json' file. 45 | plottable_keys: A list of strings representing all the keys that can be plotted. 46 | distinct_params: A list of hyper-parameters which have different values among all the curves. This can be used 47 | to split the graph into multiple figures. Each element is a tuple (param, list_of_values_to_take). 48 | """ 49 | 50 | exps_data = copy.copy(core.load_exps_data(data_paths, disable_variant=False, ignore_missing_keys=True)) 51 | plottable_keys = copy.copy(sorted(list(set(flatten(list(exp.progress.keys()) for exp in exps_data))))) 52 | distinct_params = copy.copy(sorted(core.extract_distinct_params(exps_data))) 53 | 54 | return exps_data, plottable_keys, distinct_params 55 | 56 | 57 | def get_shaded_curve(selector, key, shade_type='variance'): 58 | """ 59 | :param selector: Selector for a group of curves 60 | :param shade_type: Should be either 'variance' or 'median', indicating how the shades are calculated. 
61 | :return: [y, y_lower, y_upper], representing the mean, upper and lower boundary of the shaded region 62 | """ 63 | 64 | # First, get the progresses 65 | progresses = [exp.progress.get(key, np.array([np.nan])) for exp in selector.extract()] 66 | max_size = max(len(x) for x in progresses) 67 | progresses = [np.concatenate([ps, np.ones(max_size - len(ps)) * np.nan]) for ps in progresses] 68 | 69 | # Second, calculate the shaded area 70 | if shade_type == 'median': 71 | percentile25 = np.nanpercentile( 72 | progresses, q=25, axis=0) 73 | percentile50 = np.nanpercentile( 74 | progresses, q=50, axis=0) 75 | percentile75 = np.nanpercentile( 76 | progresses, q=75, axis=0) 77 | 78 | y = list(percentile50) 79 | y_upper = list(percentile75) 80 | y_lower = list(percentile25) 81 | elif shade_type == 'variance': 82 | means = np.nanmean(progresses, axis=0) 83 | stds = np.nanstd(progresses, axis=0) 84 | 85 | y = list(means) 86 | y_upper = list(means + stds) 87 | y_lower = list(means - stds) 88 | else: 89 | raise NotImplementedError 90 | 91 | return y, y_lower, y_upper 92 | 93 | 94 | def get_group_selectors(exps, custom_series_splitter): 95 | """ 96 | 97 | :param exps: 98 | :param custom_series_splitter: 99 | :return: A dictionary with (splitted_keys, group_selectors). Group selectors can be used to extract progresses. 100 | """ 101 | splitted_dict = dict() 102 | for exp in exps: 103 | # Group exps by their series_splitter key 104 | # splitted_dict: {key:[exp1, exp2, ...]} 105 | key = custom_series_splitter(exp) 106 | if key not in splitted_dict: 107 | splitted_dict[key] = list() 108 | splitted_dict[key].append(exp) 109 | 110 | splitted = list(splitted_dict.items()) 111 | # Group selectors: All the exps in one of the keys/legends 112 | # Group legends: All the different legends 113 | group_selectors = [core.Selector(list(x[1])) for x in splitted] 114 | group_legends = [x[0] for x in splitted] 115 | all_tuples = sorted(list(zip(group_selectors, group_legends)), key=lambda x: x[1], reverse=True) 116 | group_selectors = [x[0] for x in all_tuples] 117 | group_legends = [x[1] for x in all_tuples] 118 | return group_selectors, group_legends 119 | 120 | 121 | def filter_save_name(save_name): 122 | save_name = save_name.replace('/', '_') 123 | save_name = save_name.replace('[', '_') 124 | save_name = save_name.replace(']', '_') 125 | save_name = save_name.replace(',', '_') 126 | save_name = save_name.replace(' ', '_') 127 | save_name = save_name.replace('0.', '0_') 128 | 129 | return save_name 130 | 131 | 132 | def sliding_mean(data_array, window=5): 133 | data_array = np.array(data_array) 134 | new_list = [] 135 | for i in range(len(data_array)): 136 | indices = list(range(max(i - window + 1, 0), 137 | min(i + window + 1, len(data_array)))) 138 | avg = 0 139 | for j in indices: 140 | avg += data_array[j] 141 | avg /= float(len(indices)) 142 | new_list.append(avg) 143 | 144 | return np.array(new_list) 145 | 146 | 147 | if __name__ == '__main__': 148 | data_path = '/Users/Dora/Projects/baselines_hrl/data/seuss/visual_rss_RopeFloat_0407' 149 | exps_data, plottable_keys, distinct_params = reload_data(data_path) 150 | 151 | # Example of extracting a single curve 152 | selector = core.Selector(exps_data) 153 | selector = selector.where('her_replay_strategy', 'balance_filter') 154 | y, y_lower, y_upper = get_shaded_curve(selector, 'test/success_state') 155 | _, ax = plt.subplots() 156 | 157 | color = core.color_defaults[0] 158 | ax.fill_between(range(len(y)), y_lower, y_upper, interpolate=True, facecolor=color, 
linewidth=0.0, alpha=0.2) 159 | ax.plot(range(len(y)), y, color=color, label='balance_filter', linewidth=2.0) 160 | 161 | 162 | # Example of extracting all the curves 163 | def custom_series_splitter(x): 164 | params = x['flat_params'] 165 | if 'use_ae_reward' in params and params['use_ae_reward']: 166 | return 'Auto Encoder' 167 | if params['her_replay_strategy'] == 'balance_filter': 168 | return 'Indicator+Balance+Filter' 169 | if params['env_kwargs.use_true_reward']: 170 | return 'Oracle' 171 | return 'Indicator' 172 | 173 | 174 | fig, ax = plt.subplots(figsize=(8, 5)) 175 | 176 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter) 177 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)): 178 | color = core.color_defaults[idx] 179 | 180 | y, y_lower, y_upper = get_shaded_curve(selector, 'test/success_state') 181 | 182 | ax.plot(range(len(y)), y, color=color, label=legend, linewidth=2.0) 183 | ax.fill_between(range(len(y)), y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, alpha=0.2) 184 | ax.grid(True) 185 | ax.set_xlabel('Timesteps') 186 | ax.set_ylabel('Success') 187 | loc = 'best' 188 | leg = ax.legend(loc=loc, prop={'size': 15}, ncol=1, labels=group_legends) 189 | for legobj in leg.legendHandles: 190 | legobj.set_linewidth(3.0) 191 | plt.savefig('test.png', bbox_inches='tight') 192 | -------------------------------------------------------------------------------- /chester/pull_result.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | 5 | sys.path.append('.') 6 | from chester import config 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('host', type=str) 11 | parser.add_argument('folder', type=str) 12 | parser.add_argument('--dry', action='store_true', default=False) 13 | parser.add_argument('--bare', action='store_true', default=False) 14 | parser.add_argument('--img', action='store_true', default=False) 15 | parser.add_argument('--pkl', action='store_true', default=False) 16 | parser.add_argument('--gif', action='store_true', default=False) 17 | parser.add_argument('--newdatadir', action='store_true', default=False) 18 | args = parser.parse_args() 19 | 20 | args.folder = args.folder.rstrip('/') 21 | if args.folder.rfind('/') != -1: 22 | local_dir = os.path.join('./data', args.host, args.folder[:args.folder.rfind('/')]) 23 | else: 24 | local_dir = os.path.join('./data', args.host) 25 | if args.newdatadir: 26 | dir_path = '/data/yufeiw2/softagent_prvil_merge/' 27 | else: 28 | dir_path = config.REMOTE_DIR[args.host] 29 | remote_data_dir = os.path.join(dir_path, 'data', 'local', args.folder) 30 | command = """rsync -avzh --delete --progress {host}:{remote_data_dir} {local_dir}""".format(host=args.host, 31 | remote_data_dir=remote_data_dir, 32 | local_dir=local_dir) 33 | if args.bare: 34 | command += """ --exclude '*checkpoin*' --exclude '*ckpt*' --exclude '*tfevents*' --exclude '*.pth' --exclude '*.pt' --include '*.csv' --include '*.json' --delete""" 35 | if not args.img: 36 | command += """ --exclude '*.png' """ 37 | if not args.gif: 38 | command += """ --exclude '*.gif' """ 39 | if not args.pkl: 40 | command += """ --exclude '*.pkl' """ 41 | if args.dry: 42 | print(command) 43 | else: 44 | os.system(command) 45 | -------------------------------------------------------------------------------- /chester/pull_s3_result.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import argparse 4 | 5 | def aws_sync(bucket_name, s3_log_dir, target_dir, args): 6 | cmd = 'aws s3 cp --recursive s3://%s/%s %s' % (bucket_name, s3_log_dir, target_dir) 7 | exlus = ['"*.pkl"', '"*.gif"', '"*.png"', '"*.pth"'] 8 | inclus = [] 9 | if args.gif: 10 | exlus.remove('"*.gif"') 11 | if args.png: 12 | exlus.remove('"*.png"') 13 | if args.param: 14 | inclus.append('"params.pkl"') 15 | exlus.remove('"*.pkl"') 16 | 17 | if not args.include_all: 18 | for exc in exlus: 19 | cmd += ' --exclude ' + exc 20 | 21 | for inc in inclus: 22 | cmd += ' --include ' + inc 23 | 24 | print(cmd) 25 | # exit() 26 | subprocess.call(cmd, shell=True) 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser(description='Pull experiment results from S3.') 31 | parser.add_argument('log_dir', type=str, help='S3 Log dir') 32 | parser.add_argument('-b', '--bucket', type=str, default='chester-softgym', help='S3 Bucket') 33 | parser.add_argument('--param', type=int, default=0, help='Also pull params.pkl') 34 | parser.add_argument('--gif', type=int, default=0, help='Also pull gif files') 35 | parser.add_argument('--png', type=int, default=0, help='Also pull png files') 36 | parser.add_argument('--include_all', type=int, default=1, help='pull all data') 37 | 38 | args = parser.parse_args() 39 | s3_log_dir = "rllab/experiments/" + args.log_dir 40 | local_dir = os.path.join('./data', 'corl_s3_data', args.log_dir) 41 | if not os.path.exists(local_dir): 42 | os.makedirs(local_dir) 43 | aws_sync(args.bucket, s3_log_dir, local_dir, args) 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /chester/rsync_exclude: -------------------------------------------------------------------------------- 1 | softgym_rpad 2 | *.img 3 | datasets 4 | data 5 | data/yufei_s3_data 6 | data/yufei_seuss_data 7 | data/local 8 | data/icml 9 | *__pycache__* 10 | build 11 | .idea 12 | .git 13 | DPI-Net 14 | videos 15 | imgs 16 | planet 17 | cem 18 | curl 19 | dreamer 20 | drq 21 | experiments 22 | PDDM 23 | pouring 24 | ResRL 25 | rlkit 26 | rlpyt_cloth 27 | tests 28 | tmp 29 | softgym/softgym/cached_initial_states/* 30 | wandb 31 | data2 32 | datasets2 -------------------------------------------------------------------------------- /chester/rsync_include: -------------------------------------------------------------------------------- 1 | GNS 2 | softgym/softgym/envs 3 | softgym/PyFlexRobotics/bindings/softgym_scenes 4 | -------------------------------------------------------------------------------- /chester/run_exp_worker.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os.path as osp 4 | import datetime 5 | import dateutil.tz 6 | import ast 7 | import uuid 8 | import pickle 9 | import base64 10 | import joblib 11 | 12 | from chester import config 13 | 14 | 15 | def run_experiment(argv): 16 | default_log_dir = config.LOG_DIR 17 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 18 | 19 | # avoid name clashes when running distributed jobs 20 | rand_id = str(uuid.uuid4())[:5] 21 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z') 22 | 23 | default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id) 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--n_parallel', type=int, default=1, 26 | help='Number of parallel workers to perform rollouts. 
0 => don\'t start any workers') 27 | parser.add_argument( 28 | '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.') 29 | parser.add_argument('--log_dir', type=str, default=None, 30 | help='Path to save the log and iteration snapshot.') 31 | parser.add_argument('--snapshot_mode', type=str, default='all', 32 | help='Mode to save the snapshot. Can be either "all" ' 33 | '(all iterations will be saved), "last" (only ' 34 | 'the last iteration will be saved), "gap" (every' 35 | '`snapshot_gap` iterations are saved), or "none" ' 36 | '(do not save snapshots)') 37 | parser.add_argument('--snapshot_gap', type=int, default=1, 38 | help='Gap between snapshot iterations.') 39 | parser.add_argument('--tabular_log_file', type=str, default='progress.csv', 40 | help='Name of the tabular log file (in csv).') 41 | parser.add_argument('--text_log_file', type=str, default='debug.log', 42 | help='Name of the text log file (in pure text).') 43 | parser.add_argument('--params_log_file', type=str, default='params.json', 44 | help='Name of the parameter log file (in json).') 45 | parser.add_argument('--variant_log_file', type=str, default='variant.json', 46 | help='Name of the variant log file (in json).') 47 | parser.add_argument('--resume_from', type=str, default=None, 48 | help='Name of the pickle file to resume experiment from.') 49 | parser.add_argument('--plot', type=ast.literal_eval, default=False, 50 | help='Whether to plot the iteration results') 51 | parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False, 52 | help='Whether to only print the tabular log information (in a horizontal format)') 53 | parser.add_argument('--seed', type=int, 54 | help='Random seed for numpy') 55 | parser.add_argument('--args_data', type=str, 56 | help='Pickled data for stub objects') 57 | parser.add_argument('--variant_data', type=str, 58 | help='Pickled data for variant configuration') 59 | parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False) 60 | 61 | args = parser.parse_args(argv[1:]) 62 | 63 | # if args.seed is not None: 64 | # set_seed(args.seed) 65 | # 66 | # if args.n_parallel > 0: 67 | # from rllab.sampler import parallel_sampler 68 | # parallel_sampler.initialize(n_parallel=args.n_parallel) 69 | # if args.seed is not None: 70 | # parallel_sampler.set_seed(args.seed) 71 | # 72 | # if args.plot: 73 | # from rllab.plotter import plotter 74 | # plotter.init_worker() 75 | 76 | if args.log_dir is None: 77 | log_dir = osp.join(default_log_dir, args.exp_name) 78 | else: 79 | log_dir = args.log_dir 80 | # tabular_log_file = osp.join(log_dir, args.tabular_log_file) 81 | # text_log_file = osp.join(log_dir, args.text_log_file) 82 | # params_log_file = osp.join(log_dir, args.params_log_file) 83 | 84 | if args.variant_data is not None: 85 | variant_data = pickle.loads(base64.b64decode(args.variant_data)) 86 | variant_log_file = osp.join(log_dir, args.variant_log_file) 87 | # logger.log_variant(variant_log_file, variant_data) 88 | else: 89 | variant_data = None 90 | 91 | # if not args.use_cloudpickle: 92 | # logger.log_parameters_lite(params_log_file, args) 93 | # 94 | # logger.add_text_output(text_log_file) 95 | # logger.add_tabular_output(tabular_log_file) 96 | # prev_snapshot_dir = logger.get_snapshot_dir() 97 | # prev_mode = logger.get_snapshot_mode() 98 | # logger.set_snapshot_dir(log_dir) 99 | # logger.set_snapshot_mode(args.snapshot_mode) 100 | # logger.set_snapshot_gap(args.snapshot_gap) 101 | # logger.set_log_tabular_only(args.log_tabular_only) 
102 | # logger.push_prefix("[%s] " % args.exp_name) 103 | 104 | if args.resume_from is not None: 105 | data = joblib.load(args.resume_from) 106 | assert 'algo' in data 107 | algo = data['algo'] 108 | algo.train() 109 | else: 110 | # read from stdin 111 | if args.use_cloudpickle: 112 | import cloudpickle 113 | method_call = cloudpickle.loads(base64.b64decode(args.args_data)) 114 | method_call(variant_data, log_dir, args.exp_name) 115 | else: 116 | assert False  # non-cloudpickle payloads are not supported 117 | # data = pickle.loads(base64.b64decode(args.args_data)) 118 | # maybe_iter = concretize(data) 119 | # if is_iterable(maybe_iter): 120 | # for _ in maybe_iter: 121 | # pass 122 | 123 | # logger.set_snapshot_mode(prev_mode) 124 | # logger.set_snapshot_dir(prev_snapshot_dir) 125 | # logger.remove_tabular_output(tabular_log_file) 126 | # logger.remove_text_output(text_log_file) 127 | # logger.pop_prefix() 128 | 129 | 130 | if __name__ == "__main__": 131 | run_experiment(sys.argv) 132 |
-------------------------------------------------------------------------------- /chester/scripts/install_miniconda.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Check this file before using 3 | CONDA_INSTALL_PATH="$HOME/software/miniconda3"  # note: ~ does not expand inside quotes 4 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh 5 | chmod +x Miniconda3-latest-Linux-x86_64.sh 6 | ./Miniconda3-latest-Linux-x86_64.sh -b -p $CONDA_INSTALL_PATH 7 | if [ -d $CONDA_INSTALL_PATH/bin ]; then 8 | PATH=$PATH:$CONDA_INSTALL_PATH/bin 9 | fi 10 | echo 'PATH='$CONDA_INSTALL_PATH'/bin:$PATH' >> ~/.bashrc 11 | rm ./Miniconda3-latest-Linux-x86_64.sh -------------------------------------------------------------------------------- /chester/scripts/install_mpi4py.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Install mpi4py 3.0.0" 3 | cd /tmp 4 | wget https://bitbucket.org/mpi4py/mpi4py/downloads/mpi4py-3.0.0.tar.gz 5 | tar -zxf mpi4py-3.0.0.tar.gz 6 | cd mpi4py-3.0.0 7 | python setup.py build --mpicc=/usr/local/bin/mpicc 8 | python setup.py install --user -------------------------------------------------------------------------------- /chester/setup_ec2_for_chester.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import re 3 | import sys 4 | import json 5 | import botocore 6 | import os 7 | from rllab.misc import console 8 | from string import Template 9 | import os.path as osp 10 | 11 | CHESTER_DIR = osp.dirname(__file__) 12 | ACCESS_KEY = os.environ["AWS_ACCESS_KEY"] 13 | ACCESS_SECRET = os.environ["AWS_ACCESS_SECRET"] 14 | S3_BUCKET_NAME = os.environ["RLLAB_S3_BUCKET"] 15 | 16 | ALL_REGION_AWS_SECURITY_GROUP_IDS = {} 17 | ALL_REGION_AWS_KEY_NAMES = {} 18 | 19 | CONFIG_TEMPLATE = Template(""" 20 | import os.path as osp 21 | import os 22 | 23 | PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..')) 24 | 25 | AWS_NETWORK_INTERFACES = [] 26 | 27 | MUJOCO_KEY_PATH = osp.expanduser("~/.mujoco") 28 | 29 | USE_GPU = False 30 | 31 | USE_TF = True 32 | 33 | AWS_REGION_NAME = "us-east-2" 34 | 35 | if USE_GPU: 36 | DOCKER_IMAGE = "dementrock/rllab3-shared-gpu" 37 | else: 38 | DOCKER_IMAGE = "dementrock/rllab3-shared" 39 | 40 | DOCKER_LOG_DIR = "/tmp/expt" 41 | 42 | AWS_S3_PATH = "s3://$s3_bucket_name/rllab/experiments" 43 | 44 | AWS_CODE_SYNC_S3_PATH = "s3://$s3_bucket_name/rllab/code" 45 | 46 | ALL_REGION_AWS_IMAGE_IDS = { 47 | "ap-northeast-1": "ami-002f0167", 48 | "ap-northeast-2": "ami-590bd937", 
49 | "ap-south-1": "ami-77314318", 50 | "ap-southeast-1": "ami-1610a975", 51 | "ap-southeast-2": "ami-9dd4ddfe", 52 | "eu-central-1": "ami-63af720c", 53 | "eu-west-1": "ami-41484f27", 54 | "sa-east-1": "ami-b7234edb", 55 | "us-east-1": "ami-83f26195", 56 | "us-east-2": "ami-66614603", 57 | "us-west-1": "ami-576f4b37", 58 | "us-west-2": "ami-b8b62bd8" 59 | } 60 | 61 | AWS_IMAGE_ID = ALL_REGION_AWS_IMAGE_IDS[AWS_REGION_NAME] 62 | 63 | if USE_GPU: 64 | AWS_INSTANCE_TYPE = "g2.2xlarge" 65 | else: 66 | AWS_INSTANCE_TYPE = "c4.4xlarge" 67 | 68 | ALL_REGION_AWS_KEY_NAMES = $all_region_aws_key_names 69 | 70 | AWS_KEY_NAME = ALL_REGION_AWS_KEY_NAMES[AWS_REGION_NAME] 71 | 72 | AWS_SPOT = True 73 | 74 | AWS_SPOT_PRICE = '0.5' 75 | 76 | AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None) 77 | 78 | AWS_ACCESS_SECRET = os.environ.get("AWS_ACCESS_SECRET", None) 79 | 80 | AWS_IAM_INSTANCE_PROFILE_NAME = "rllab" 81 | 82 | AWS_SECURITY_GROUPS = ["rllab-sg"] 83 | 84 | ALL_REGION_AWS_SECURITY_GROUP_IDS = $all_region_aws_security_group_ids 85 | 86 | AWS_SECURITY_GROUP_IDS = ALL_REGION_AWS_SECURITY_GROUP_IDS[AWS_REGION_NAME] 87 | 88 | FAST_CODE_SYNC_IGNORES = [ 89 | ".git", 90 | "data", 91 | "data/local", 92 | "data/archive", 93 | "data/debug", 94 | "data/s3", 95 | "data/video", 96 | "src", 97 | ".idea", 98 | ".pods", 99 | "tests", 100 | "examples", 101 | "docs", 102 | ".idea", 103 | ".DS_Store", 104 | ".ipynb_checkpoints", 105 | "blackbox", 106 | "blackbox.zip", 107 | "*.pyc", 108 | "*.ipynb", 109 | "scratch-notebooks", 110 | "conopt_root", 111 | "private/key_pairs", 112 | ] 113 | 114 | FAST_CODE_SYNC = True 115 | 116 | """) 117 | 118 | 119 | def setup_iam(): 120 | iam_client = boto3.client( 121 | "iam", 122 | aws_access_key_id=ACCESS_KEY, 123 | aws_secret_access_key=ACCESS_SECRET, 124 | ) 125 | iam = boto3.resource('iam', aws_access_key_id=ACCESS_KEY, aws_secret_access_key=ACCESS_SECRET) 126 | 127 | # delete existing role if it exists 128 | try: 129 | existing_role = iam.Role('rllab') 130 | existing_role.load() 131 | # if role exists, delete and recreate 132 | if not query_yes_no( 133 | "There is an existing role named rllab. 
Proceed to delete everything rllab-related and recreate?", 134 | default="no"): 135 | sys.exit() 136 | print("Listing instance profiles...") 137 | inst_profiles = existing_role.instance_profiles.all() 138 | for prof in inst_profiles: 139 | for role in prof.roles: 140 | print("Removing role %s from instance profile %s" % (role.name, prof.name)) 141 | prof.remove_role(RoleName=role.name) 142 | print("Deleting instance profile %s" % prof.name) 143 | prof.delete() 144 | for policy in existing_role.policies.all(): 145 | print("Deleting inline policy %s" % policy.name) 146 | policy.delete() 147 | for policy in existing_role.attached_policies.all(): 148 | print("Detaching policy %s" % policy.arn) 149 | existing_role.detach_policy(PolicyArn=policy.arn) 150 | print("Deleting role") 151 | existing_role.delete() 152 | except botocore.exceptions.ClientError as e: 153 | if e.response['Error']['Code'] == 'NoSuchEntity': 154 | pass 155 | else: 156 | raise e 157 | 158 | print("Creating role rllab") 159 | iam_client.create_role( 160 | Path='/', 161 | RoleName='rllab', 162 | AssumeRolePolicyDocument=json.dumps({'Version': '2012-10-17', 'Statement': [ 163 | {'Action': 'sts:AssumeRole', 'Effect': 'Allow', 'Principal': {'Service': 'ec2.amazonaws.com'}}]}) 164 | ) 165 | 166 | role = iam.Role('rllab') 167 | print("Attaching policies") 168 | role.attach_policy(PolicyArn='arn:aws:iam::aws:policy/AmazonS3FullAccess') 169 | role.attach_policy(PolicyArn='arn:aws:iam::aws:policy/ResourceGroupsandTagEditorFullAccess') 170 | 171 | print("Creating inline policies") 172 | iam_client.put_role_policy( 173 | RoleName=role.name, 174 | PolicyName='CreateTags', 175 | PolicyDocument=json.dumps({ 176 | "Version": "2012-10-17", 177 | "Statement": [ 178 | { 179 | "Effect": "Allow", 180 | "Action": ["ec2:CreateTags"], 181 | "Resource": ["*"] 182 | } 183 | ] 184 | }) 185 | ) 186 | iam_client.put_role_policy( 187 | RoleName=role.name, 188 | PolicyName='TerminateInstances', 189 | PolicyDocument=json.dumps({ 190 | "Version": "2012-10-17", 191 | "Statement": [ 192 | { 193 | "Sid": "Stmt1458019101000", 194 | "Effect": "Allow", 195 | "Action": [ 196 | "ec2:TerminateInstances" 197 | ], 198 | "Resource": [ 199 | "*" 200 | ] 201 | } 202 | ] 203 | }) 204 | ) 205 | 206 | print("Creating instance profile rllab") 207 | iam_client.create_instance_profile( 208 | InstanceProfileName='rllab', 209 | Path='/' 210 | ) 211 | print("Adding role rllab to instance profile rllab") 212 | iam_client.add_role_to_instance_profile( 213 | InstanceProfileName='rllab', 214 | RoleName='rllab' 215 | ) 216 | 217 | 218 | def setup_s3(): 219 | print("Creating S3 bucket at s3://%s" % S3_BUCKET_NAME) 220 | s3_client = boto3.client( 221 | "s3", 222 | aws_access_key_id=ACCESS_KEY, 223 | aws_secret_access_key=ACCESS_SECRET, 224 | ) 225 | try: 226 | s3_client.create_bucket( 227 | ACL='private', 228 | Bucket=S3_BUCKET_NAME, 229 | CreateBucketConfiguration={ 230 | 'LocationConstraint': 'us-east-2' 231 | } 232 | ) 233 | except botocore.exceptions.ClientError as e: 234 | if e.response['Error']['Code'] == 'BucketAlreadyExists': 235 | raise ValueError("Bucket %s already exists. 
Please reconfigure S3_BUCKET_NAME" % S3_BUCKET_NAME) from e 236 | elif e.response['Error']['Code'] == 'BucketAlreadyOwnedByYou': 237 | print("Bucket already created by you") 238 | else: 239 | raise e 240 | print("S3 bucket created") 241 | 242 | 243 | def setup_ec2(): 244 | for region in ["us-east-1", "us-east-2", "us-west-1", "us-west-2"]: 245 | print("Setting up region %s" % region) 246 | 247 | ec2 = boto3.resource( 248 | "ec2", 249 | region_name=region, 250 | aws_access_key_id=ACCESS_KEY, 251 | aws_secret_access_key=ACCESS_SECRET, 252 | ) 253 | ec2_client = boto3.client( 254 | "ec2", 255 | region_name=region, 256 | aws_access_key_id=ACCESS_KEY, 257 | aws_secret_access_key=ACCESS_SECRET, 258 | ) 259 | existing_vpcs = list(ec2.vpcs.all()) 260 | assert len(existing_vpcs) >= 1 261 | vpc = existing_vpcs[0] 262 | print("Creating security group in VPC %s" % str(vpc.id)) 263 | try: 264 | security_group = vpc.create_security_group( 265 | GroupName='rllab-sg', Description='Security group for rllab' 266 | ) 267 | except botocore.exceptions.ClientError as e: 268 | if e.response['Error']['Code'] == 'InvalidGroup.Duplicate': 269 | sgs = list(vpc.security_groups.filter(GroupNames=['rllab-sg'])) 270 | security_group = sgs[0] 271 | else: 272 | raise e 273 | 274 | ALL_REGION_AWS_SECURITY_GROUP_IDS[region] = [security_group.id] 275 | 276 | ec2_client.create_tags(Resources=[security_group.id], Tags=[{'Key': 'Name', 'Value': 'rllab-sg'}]) 277 | try: 278 | security_group.authorize_ingress(FromPort=22, ToPort=22, IpProtocol='tcp', CidrIp='0.0.0.0/0') 279 | except botocore.exceptions.ClientError as e: 280 | if e.response['Error']['Code'] == 'InvalidPermission.Duplicate': 281 | pass 282 | else: 283 | raise e 284 | print("Security group created with id %s" % str(security_group.id)) 285 | 286 | key_name = 'rllab-%s' % region 287 | try: 288 | print("Trying to create key pair with name %s" % key_name) 289 | key_pair = ec2_client.create_key_pair(KeyName=key_name) 290 | except botocore.exceptions.ClientError as e: 291 | if e.response['Error']['Code'] == 'InvalidKeyPair.Duplicate': 292 | if not query_yes_no("Key pair with name %s exists. Proceed to delete and recreate?" % key_name, "no"): 293 | sys.exit() 294 | print("Deleting existing key pair with name %s" % key_name) 295 | ec2_client.delete_key_pair(KeyName=key_name) 296 | print("Recreating key pair with name %s" % key_name) 297 | key_pair = ec2_client.create_key_pair(KeyName=key_name) 298 | else: 299 | raise e 300 | 301 | key_pair_folder_path = os.path.join(CHESTER_DIR, "private", "key_pairs") 302 | file_name = os.path.join(key_pair_folder_path, "%s.pem" % key_name) 303 | 304 | print("Saving keypair file") 305 | console.mkdir_p(key_pair_folder_path) 306 | with os.fdopen(os.open(file_name, os.O_WRONLY | os.O_CREAT, 0o600), 'w') as handle: 307 | handle.write(key_pair['KeyMaterial'] + '\n') 308 | 309 | # adding pem file to ssh 310 | os.system("ssh-add %s" % file_name) 311 | 312 | ALL_REGION_AWS_KEY_NAMES[region] = key_name 313 | 314 | 315 | def write_config(): 316 | print("Writing config file...") 317 | content = CONFIG_TEMPLATE.substitute( 318 | all_region_aws_key_names=json.dumps(ALL_REGION_AWS_KEY_NAMES, indent=4), 319 | all_region_aws_security_group_ids=json.dumps(ALL_REGION_AWS_SECURITY_GROUP_IDS, indent=4), 320 | s3_bucket_name=S3_BUCKET_NAME, 321 | ) 322 | config_personal_file = os.path.join(CHESTER_DIR, "config_ec2.py") 323 | if os.path.exists(config_personal_file): 324 | if not query_yes_no("config_ec2.py exists. 
Override?", "no"): 325 | sys.exit() 326 | with open(config_personal_file, "wb") as f: 327 | f.write(content.encode("utf-8")) 328 | 329 | 330 | def setup(): 331 | setup_s3() 332 | setup_iam() 333 | setup_ec2() 334 | write_config() 335 | 336 | 337 | def query_yes_no(question, default="yes"): 338 | """Ask a yes/no question via input() and return their answer. 339 | 340 | "question" is a string that is presented to the user. 341 | "default" is the presumed answer if the user just hits <Enter>. 342 | It must be "yes" (the default), "no" or None (meaning 343 | an answer is required of the user). 344 | 345 | The "answer" return value is True for "yes" or False for "no". 346 | """ 347 | valid = {"yes": True, "y": True, "ye": True, 348 | "no": False, "n": False} 349 | if default is None: 350 | prompt = " [y/n] " 351 | elif default == "yes": 352 | prompt = " [Y/n] " 353 | elif default == "no": 354 | prompt = " [y/N] " 355 | else: 356 | raise ValueError("invalid default answer: '%s'" % default) 357 | 358 | while True: 359 | sys.stdout.write(question + prompt) 360 | choice = input().lower() 361 | if default is not None and choice == '': 362 | return valid[default] 363 | elif choice in valid: 364 | return valid[choice] 365 | else: 366 | sys.stdout.write("Please respond with 'yes' or 'no' " 367 | "(or 'y' or 'n').\n") 368 | 369 | 370 | if __name__ == "__main__": 371 | setup() -------------------------------------------------------------------------------- /chester/slurm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import re 4 | from subprocess import run 5 | from tempfile import NamedTemporaryFile 6 | from chester import config 7 | 8 | # TODO remove the singularity part 9 | 10 | slurm_dir = './' 11 | 12 | 13 | def slurm_run_scripts(scripts): 14 | """Submit a generated script string to Slurm. This is the function that those _sub files should call; it actually executes the script via sbatch.""" 15 | # TODO support running multiple scripts 16 | 17 | assert isinstance(scripts, str) 18 | 19 | os.chdir(slurm_dir) 20 | 21 | # make sure it will run. 22 | assert scripts.startswith('#!/usr/bin/env bash\n') 23 | file_temp = NamedTemporaryFile(delete=False) 24 | file_temp.write(scripts.encode('utf-8')) 25 | file_temp.close() 26 | run(['sbatch', file_temp.name], check=True) 27 | os.remove(file_temp.name) 28 | 29 | 30 | _find_unsafe = re.compile(r'[^a-zA-Z0-9_@%+=:,./-]').search  # matches any character outside the shell-safe set 31 | 32 | 33 | def _shellquote(s): 34 | """Return a shell-escaped version of the string *s*.""" 35 | if not s: 36 | return "''" 37 | 38 | if _find_unsafe(s) is None: 39 | return s 40 | 41 | # use single quotes, and put single quotes into double quotes 42 | # the string $'b is then quoted as '$'"'"'b' 43 | 44 | return "'" + s.replace("'", "'\"'\"'") + "'" 45 | 46 | 47 | def _to_param_val(v): 48 | if v is None: 49 | return "" 50 | elif isinstance(v, list): 51 | return " ".join(map(_shellquote, list(map(str, v)))) 52 | else: 53 | return _shellquote(str(v)) 54 | 55 | 56 | def to_slurm_command(params, header, python_command="python", remote_dir='~/', 57 | script=osp.join(config.PROJECT_PATH, 'scripts/run_experiment.py'), 58 | simg_dir=None, use_gpu=False, modules=None, cuda_module=None, use_singularity=True, 59 | mount_options=None, compile_script=None, wait_compile=None, set_egl_gpu=False): 60 | # TODO Add code for specifying the resource allocation 61 | # TODO Check if use_gpu can be applied 62 | """ 63 | Convert the given parameters into the list of shell commands to be submitted to Slurm. 
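    Illustrative usage (all argument values below are hypothetical, not taken
    from this repo's configs):

        cmds = to_slurm_command({'exp_name': 'run0'},
                                header='#!/usr/bin/env bash\n#SBATCH --gres=gpu:1',
                                remote_dir='~/softagent', simg_dir='softgym.simg',
                                modules=['singularity'])
        slurm_run_scripts('\n'.join(cmds))
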
64 | :param params: Dict of command-line arguments appended to the python command. 65 | :param python_command: Interpreter used to launch the script (default "python"). 66 | :param script: Path of the experiment entry script to run. 67 | :param use_gpu: If True, load the CUDA module and pass --nv to singularity. 68 | :return: List of shell command strings, starting with the sbatch header. 69 | """ 70 | assert simg_dir is not None 71 | command = python_command + " " + script 72 | 73 | pre_commands = params.pop("pre_commands", None) 74 | post_commands = params.pop("post_commands", None) 75 | 76 | command_list = list() 77 | command_list.append(header) 78 | 79 | # Log into singularity shell 80 | if use_singularity: 81 | command_list.append('set -x') # echo commands to stdout 82 | command_list.append('set -u') # throw an error if unset variable referenced 83 | command_list.append('set -e') # exit on errors 84 | command_list.append('srun hostname') 85 | 86 | for remote_module in modules: 87 | command_list.append('module load ' + remote_module) 88 | if use_gpu: 89 | assert cuda_module is not None 90 | command_list.append('module load ' + cuda_module) 91 | command_list.append('cd {}'.format(remote_dir)) 92 | # First execute a bash program inside the container and then run all the following commands 93 | 94 | if mount_options is not None: 95 | options = '-B ' + mount_options 96 | else: 97 | options = '' 98 | sing_prefix = 'singularity exec {} {} {} /bin/bash -c'.format(options, '--nv' if use_gpu else '', simg_dir) 99 | sing_commands = list() 100 | if compile_script is None or 'prepare' not in compile_script: 101 | # sing_commands.append('. ./prepare_1.0.sh') 102 | sing_commands.append('. ./prepare.sh') 103 | if set_egl_gpu: 104 | sing_commands.append('export EGL_GPU=$SLURM_JOB_GRES') 105 | sing_commands.append('echo $EGL_GPU') 106 | if compile_script is not None: 107 | sing_commands.append('./' + compile_script) 108 | if wait_compile is not None: 109 | sing_commands.append('sleep ' + str(int(wait_compile))) 110 | 111 | if pre_commands is not None: 112 | command_list.extend(pre_commands) 113 | for k, v in params.items(): 114 | if isinstance(v, dict): 115 | for nk, nv in v.items(): 116 | if str(nk) == "_name": 117 | command += " --%s %s" % (k, _to_param_val(nv)) 118 | else: 119 | command += " --%s_%s %s" % (k, nk, _to_param_val(nv)) 120 | else: 121 | command += " --%s %s" % (k, _to_param_val(v)) 122 | sing_commands.append(command) 123 | all_sing_cmds = ' && '.join(sing_commands) 124 | command_list.append(sing_prefix + ' \'{}\''.format(all_sing_cmds)) 125 | if post_commands is not None: 126 | command_list.extend(post_commands) 127 | return command_list 128 | 129 | # if __name__ == '__main__': 130 | # slurm_run_scripts(header) 131 |
-------------------------------------------------------------------------------- /chester/upload_result.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | from chester import config 5 | from chester.run_exp import rsync_code 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('mode', type=str) 10 | args = parser.parse_args() 11 | 12 | remote_dir = config.REMOTE_DIR[args.mode] 13 | rsync_code(args.mode, remote_dir) 14 |
-------------------------------------------------------------------------------- /chester/video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glfw 3 | from multiprocessing import Process, Queue 4 | 5 | import cv2 as cv 6 | 7 | class VideoRecorder(object): 8 | ''' 9 | Used to record videos for mujoco_py environments 10 | ''' 11 | def __init__(self, env, saved_path='./data/videos/', saved_name='temp'): 12 | # Get rid of the gym wrappers 13 | if 
hasattr(env, 'env'): 14 | env = env.env 15 | self.viewer = env._get_viewer() 16 | self.saved_path = saved_path 17 | self.saved_name = saved_name 18 | # self._set_filepath('/tmp/temp%07d.mp4') 19 | saved_name += '.mp4' 20 | self._set_filepath(os.path.join(saved_path, saved_name)) 21 | 22 | def _set_filepath(self, video_name): 23 | self.viewer._video_path = video_name 24 | 25 | def start(self): 26 | from mujoco_py.mjviewer import save_video  # assumed source of save_video used below 27 | self.viewer._record_video = True 28 | if self.viewer._record_video: 29 | fps = (1 / self.viewer._time_per_render) 30 | self.viewer._video_process = Process(target=save_video, 31 | args=(self.viewer._video_queue, 32 | self.viewer._video_path, fps)) 33 | self.viewer._video_process.start() 34 | 35 | def end(self): 36 | self.viewer.key_callback(None, glfw.KEY_V, None, glfw.RELEASE, None) 37 | 38 | # class VideoRecorderDM(object): 39 | # ''' 40 | # Used to record videos for dm_control based environments 41 | # ''' 42 | # def __init__(self, env, saved_path='./data/videos/', saved_name='temp'): 43 | # self.saved_path = saved_path 44 | # self.saved_name = saved_name 45 | # 46 | # def
-------------------------------------------------------------------------------- /compile_1.0.sh: -------------------------------------------------------------------------------- 1 | cd softgym/PyFlex/bindings 2 | rm -rf build 3 | mkdir build 4 | cd build 5 | # Seuss 6 | if [[ $(hostname) = *"compute-0"* ]] || [[ $(hostname) = *"autobot-"* ]] || [[ $(hostname) = *"yertle"* ]]; then 7 | export CUDA_BIN_PATH=/usr/local/cuda-9.1 8 | fi 9 | cmake -DPYBIND11_PYTHON_VERSION=3.6 .. 10 | make -j 11 |
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: softgym 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.6.9 6 | - numpy=1.16.4 7 | - gym=0.15.7 8 | - h5py 9 | - pyg=2.0.1 10 | - Pillow=6.1 11 | - pyquaternion=0.9.5 12 | - opencv-python=5.1.1 13 | - imageio=2.6.1 14 | - imageio-ffmpeg 15 | - glob2=0.7 16 | - cmake=3.14.0 17 | - pybind11=2.4.3 18 | - moviepy 19 | - click 20 | - matplotlib 21 | - joblib 22 | - plotly 23 | - gtimer 24 | - python-pcl 25 | - pip: 26 | - torch==1.9 27 | - torchvision 28 | - termcolor 29 | - scikit-image 30 |
-------------------------------------------------------------------------------- /prepare_1.0.sh: -------------------------------------------------------------------------------- 1 | PATH=~/software/miniconda3/bin:~/anaconda3/bin:$PATH 2 | cd softgym 3 | . prepare_1.0.sh 4 | cd .. 5 | export PYFLEXROOT=${PWD}/softgym/PyFlex 6 | export PYTHONPATH=${PWD}:${PWD}/softgym:${PYFLEXROOT}/bindings/build:$PYTHONPATH 7 | export LD_LIBRARY_PATH=${PYFLEXROOT}/external/SDL2-2.0.4/lib/x64:$LD_LIBRARY_PATH 8 | export EGL_GPU=$CUDA_VISIBLE_DEVICES 9 |
-------------------------------------------------------------------------------- /pretrained/README.md: -------------------------------------------------------------------------------- 1 | # Notes on pre-trained models 2 | Our pre-trained models, pre-collected dataset, and cached initial states for planning can be accessed through this Google Drive link: [Google Drive Link](https://drive.google.com/drive/folders/1gS8ejcY1imKVT8TD8zmNC38gNicpkL6X?usp=sharing) 3 | 4 | The Google Drive folder includes: 5 | * `dataset.zip`: Pre-collected dataset for training. 6 | * dynamics_model: 7 | - `vsbl_dyn_140.pth`: Pre-trained dynamics GNN for partially observed point clouds. 
8 | - `best_state.json`: Loading the pre-trained edge and dynamics GNNs requires a corresponding `best_state.json` that stores the model information. Just put these JSON files under the same directory as the pre-trained model. 9 | * edge_model: 10 | - `vsbl_edge_best.pth`: Pre-trained Edge GNN. 11 | - `best_state.json`: As above, loading requires the corresponding `best_state.json`, placed under the same directory as the pre-trained model. 12 | * cached_states: 13 | - `1213_release_n1000.pkl`: The cached initial states for generating the training data. 14 | - `cloth_flatten_init_states_test_40.pkl` and `cloth_flatten_init_states_test_40_2.pkl`: 40 initial states in total for testing cloth smoothing on square cloths (each cached file has 20 states). 15 | - `cloth_flatten_test_retangular_1.pkl` and `cloth_flatten_test_retangular_1.pkl`: Initial states for testing on rectangular cloths (each cached file has 20 states). 16 | - `tshirt_flatten_init_states_small_2021_05_28_01_22.pkl` and `tshirt_flatten_init_states_small_2021_05_28_01_16.pkl`: Initial states for testing on t-shirts (each cached file has 20 states). 17 |
-------------------------------------------------------------------------------- /pretrained/vis_dynamics.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/VCD/328f3e0ada7c3d39b974e875c5847badfca5037d/pretrained/vis_dynamics.gif -------------------------------------------------------------------------------- /pretrained/vis_planning.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/VCD/328f3e0ada7c3d39b974e875c5847badfca5037d/pretrained/vis_planning.gif --------------------------------------------------------------------------------
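A minimal, assumption-laden sketch of how a pre-trained checkpoint and its accompanying `best_state.json` might be loaded. The directory layout and file names mirror the pretrained README above; the repo's actual loading entry points live in `main.py` / `vc_dynamics.py`, and the model construction itself is elided here.

```python
# Hypothetical loading sketch -- not the repo's actual API. It assumes only
# what the README states: best_state.json sits in the same directory as the
# checkpoint and stores the model information, and the .pth file is a
# PyTorch-serialized object.
import json
import os.path as osp

import torch

model_dir = 'pretrained/dynamics_model'  # illustrative path
with open(osp.join(model_dir, 'best_state.json')) as f:
    model_info = json.load(f)  # metadata recorded at training time

checkpoint = torch.load(osp.join(model_dir, 'vsbl_dyn_140.pth'), map_location='cpu')
# A dynamics GNN built to match `model_info` (see models.py / vc_dynamics.py)
# would then restore its weights, e.g. model.load_state_dict(checkpoint).
```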