├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── VCD
│   ├── dataset.py
│   ├── dataset_edge.py
│   ├── generate_cached_initial_state.py
│   ├── main.py
│   ├── main_plan.py
│   ├── main_train_edge.py
│   ├── models.py
│   ├── rs_planner.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── camera_utils.py
│   │   ├── data_utils.py
│   │   ├── gemo_utils.py
│   │   ├── plot_utils.py
│   │   └── utils.py
│   ├── vc_dynamics.py
│   └── vc_edge.py
├── chester
│   ├── README.md
│   ├── add_variants.py
│   ├── config.py
│   ├── containers
│   │   └── ubuntu-16.04-lts-rl.README
│   ├── docs
│   │   ├── Makefile
│   │   ├── README.md
│   │   ├── make.bat
│   │   ├── readme.md
│   │   └── source
│   │       ├── _static
│   │       │   └── basic_screenshot.png
│   │       ├── conf.py
│   │       ├── getting_started.rst
│   │       ├── index.rst
│   │       ├── launcher.rst
│   │       ├── logger.rst
│   │       └── visualization.rst
│   ├── examples
│   │   ├── cplot_example.py
│   │   ├── pgm_plot.py
│   │   ├── presets.py
│   │   ├── presets2.py
│   │   ├── presets3.py
│   │   ├── presets_tiancheng.py
│   │   ├── train.py
│   │   └── train_launch.py
│   ├── logger.py
│   ├── plotting
│   │   └── cplot.py
│   ├── pull_result.py
│   ├── pull_s3_result.py
│   ├── rsync_exclude
│   ├── rsync_include
│   ├── run_exp.py
│   ├── run_exp_worker.py
│   ├── scripts
│   │   ├── install_miniconda.sh
│   │   └── install_mpi4py.sh
│   ├── setup_ec2_for_chester.py
│   ├── slurm.py
│   ├── upload_result.py
│   ├── utils_s3.py
│   └── video_recorder.py
├── compile_1.0.sh
├── environment.yml
├── prepare_1.0.sh
└── pretrained
    ├── README.md
    ├── vis_dynamics.gif
    └── vis_planning.gif
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pth
2 | *.zip
3 | pretrained/dataset
4 | *.pkl
5 | softgym_rpad
6 | softgym_public
7 | wandb
8 | GNS/pcl_filter/build
9 | rlpyt_cloth/data
10 | datasets
11 | dpi_visualization
12 | .vscode
13 | chester/private
14 | wandb
15 | videos
16 | .idea
17 | data
18 | *.simg
19 | *.img
20 | # Byte-compiled / optimized / DLL files
21 | __pycache__/
22 | *.py[cod]
23 | *$py.class
24 |
25 | # C extensions
26 | *.so
27 |
28 | # Distribution / packaging
29 | .Python
30 | build/
31 | develop-eggs/
32 | dist/
33 | downloads/
34 | eggs/
35 | .eggs/
36 | lib/
37 | lib64/
38 | parts/
39 | sdist/
40 | var/
41 | wheels/
42 | pip-wheel-metadata/
43 | share/python-wheels/
44 | *.egg-info/
45 | .installed.cfg
46 | *.egg
47 | MANIFEST
48 |
49 | # PyInstaller
50 | # Usually these files are written by a python script from a template
51 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
52 | *.manifest
53 | *.spec
54 |
55 | # Installer logs
56 | pip-log.txt
57 | pip-delete-this-directory.txt
58 |
59 | # Unit test / coverage reports
60 | htmlcov/
61 | .tox/
62 | .nox/
63 | .coverage
64 | .coverage.*
65 | .cache
66 | nosetests.xml
67 | coverage.xml
68 | *.cover
69 | .hypothesis/
70 | .pytest_cache/
71 |
72 | # Translations
73 | *.mo
74 | *.pot
75 |
76 | # Django stuff:
77 | *.log
78 | local_settings.py
79 | db.sqlite3
80 |
81 | # Flask stuff:
82 | instance/
83 | .webassets-cache
84 |
85 | # Scrapy stuff:
86 | .scrapy
87 |
88 | # Sphinx documentation
89 | docs/_build/
90 |
91 | # PyBuilder
92 | target/
93 |
94 | # Jupyter Notebook
95 | .ipynb_checkpoints
96 |
97 | # IPython
98 | profile_default/
99 | ipython_config.py
100 |
101 | # pyenv
102 | .python-version
103 |
104 | # celery beat schedule file
105 | celerybeat-schedule
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
137 |
138 | # Results
139 | results/
140 |
141 | # DPI data
142 | DPI-Net/dump_*/
143 | DPI-Net/dump_*/
144 |
145 | test_*/
146 | imgs/
147 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "softgym"]
2 | path = softgym
3 | url = https://github.com/Xingyu-Lin/softgym.git
4 |
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Carnegie Mellon University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Visible Connectivity Dynamics for Cloth Smoothing
2 |
3 | Xingyu Lin*, Yufei Wang*, Zixuan Huang, David Held, CoRL 2021. `*` indicates equal contribution (order by dice rolling).
4 |
5 | [Website](https://sites.google.com/view/vcd-cloth/home) / [ArXiv](https://arxiv.org/pdf/2105.10389.pdf)
6 |
7 | # Table of Contents
8 | - 1 [Simulation](#simulation)
9 | - 1.1 [Setup](#setup)
10 | - 1.2 [Train VCD](#train-vcd)
- 1.3 [Planning with VCD](#planning-with-vcd)
- 1.4 [Graph Imitation Learning](#graph-imitation-learning)
- 1.5 [Pretrained Models](#pretrained-models)
- 1.6 [Demo](#demo)
15 | ----
16 | # Simulation
17 |
18 | ## Setup
19 | This repository is a subset of [SoftAgent](https://github.com/Xingyu-Lin/softagent) cleaned up for VCD. Environment setup for VCD is similar to that of SoftAgent.
20 | 1. Install [SoftGym](https://github.com/Xingyu-Lin/softgym). Then copy softgym into this directory as a submodule by running `cp -r [path to softgym] ./`, and switch to the updated `vcd` branch with `cd softgym && git checkout vcd`.
21 | 2. You should now have a conda environment named `softgym`. Install the additional packages required by VCD with `conda env update --file environment.yml`.
22 | 3. Generate and cache the initial environment configurations by running `python VCD/generate_cached_initial_state.py`.
23 | 4. Run `./compile_1.0.sh && . ./prepare_1.0.sh` to compile PyFleX and prepare other paths.
24 |
25 | ## Train VCD
26 | * Generate the dataset for training by running
27 | ```
28 | python VCD/main.py --gen_data=1 --dataf=./data/vcd
29 | ```
30 | Please refer to `main.py` for argument options.
31 |
32 | * Train the dynamics model by running
33 | ```
34 | python VCD/main.py --gen_data=0 --dataf=./data/vcd_dyn
35 | ```
36 | * Train the EdgeGNN model by running
37 | ```
38 | python VCD/main_train_edge.py --gen_data=0 --dataf=./data/vcd_edge
39 | ```
40 | ## Planning with VCD
41 | ```
42 | python VCD/main_plan.py --edge_model_path={path_to_trained_edge_model}\
43 | --partial_dyn_path={path_to_trained_dynamics_model}
44 | ```
45 | An example of loading the models trained for 120 epochs:
46 | ```
47 | python VCD/main_plan.py --edge_model_path ./data/vcd_edge/vsbl_edge_120.pth\
48 | --partial_dyn_path ./data/vcd_dyn/vsbl_dyn_120.pth
49 | ```
50 | ## Graph Imitation Learning
51 | 1. Train dynamics using the full mesh
52 | ```
53 | python VCD/main.py --gen_data=0 --dataf=./data/vcd --train_mode=full
54 | ```
55 | 2. Train dynamics using partial point cloud and imitate the teacher model
56 | ```
57 | python VCD/main.py --gen_data=0 --dataf=./data/vcd --train_mode=graph_imit --full_dyn_path={path_to_teacher_model}
58 | ```
59 |
60 | ## Pretrained Models
61 | Please refer to [this page](pretrained/README.md) for downloading the pretrained models.
62 |
63 | ## Demo
64 | * Dynamics rollout
65 | ![](pretrained/vis_dynamics.gif)
66 | * Planning on square cloth
67 | ![](pretrained/vis_planning.gif)
68 | ## Cite
69 | If you find this codebase useful in your research, please consider citing:
70 | ```
71 | @inproceedings{lin2022learning,
72 | title={Learning visible connectivity dynamics for cloth smoothing},
73 | author={Lin, Xingyu and Wang, Yufei and Huang, Zixuan and Held, David},
74 | booktitle={Conference on Robot Learning},
75 | pages={256--266},
76 | year={2022},
77 | organization={PMLR}
78 | }
79 |
80 | @inproceedings{corl2020softgym,
81 | title={SoftGym: Benchmarking Deep Reinforcement Learning for Deformable Object Manipulation},
82 | author={Lin, Xingyu and Wang, Yufei and Olkin, Jake and Held, David},
83 | booktitle={Conference on Robot Learning},
84 | year={2020}
85 | }
86 | ```
87 |
--------------------------------------------------------------------------------
/VCD/dataset_edge.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy import spatial
4 | from torch_geometric.data import Data
5 |
6 | from VCD.utils.camera_utils import get_observable_particle_index_3
7 | from VCD.dataset import ClothDataset
8 | from VCD.utils.utils import load_data, voxelize_pointcloud
9 |
10 |
11 | class ClothDatasetPointCloudEdge(ClothDataset):
12 | def __init__(self, *args, **kwargs):
13 | super().__init__(*args, **kwargs)
14 |
15 | def __getitem__(self, idx):
16 | data = self._prepare_transition(idx)
17 | d = self.build_graph(data)
18 | return Data.from_dict(d)
19 |
20 | def _prepare_transition(self, idx, eval=False):
21 | pred_time_interval = self.args.pred_time_interval
22 | success = False
23 | next = 1 if not eval else self.args.time_step - self.args.n_his
24 |
25 | while not success:
26 | idx_rollout = (idx // (self.args.time_step - self.args.n_his)) % self.n_rollout
27 | idx_timestep = (self.args.n_his - pred_time_interval) + idx % (self.args.time_step - self.args.n_his)
28 | idx_timestep = max(idx_timestep, 0)
29 |
30 | data = load_data(self.data_dir, idx_rollout, idx_timestep, self.data_names)
31 | pointcloud = data['pointcloud'].astype(np.float32)
32 | if len(pointcloud.shape) != 2:
33 |                 print('dataset_edge.py, erroneous data. What is going on?')
34 | import pdb
35 | pdb.set_trace()
36 |
37 | idx += next
38 | continue
39 |
40 | if len(pointcloud) < 100: # TODO Filter these during dataset generation
41 |                 print('dataset_edge.py: point cloud has fewer than 100 points, skipping')
42 | import pdb
43 | pdb.set_trace()
44 | idx += next
45 | continue
46 |
47 | vox_pc = voxelize_pointcloud(pointcloud, self.args.voxel_size)
48 |
49 | partial_particle_pos = data['positions'][data['downsample_idx']][data['downsample_observable_idx']]
50 | if len(vox_pc) <= len(partial_particle_pos):
51 | success = True
52 |
53 |             # In eval mode, do not retry with another index; return None and let the caller skip this sample
54 | if eval and not success:
55 | return None
56 |
57 | idx += next
58 |
59 | pointcloud, partial_pc_mapped_idx = get_observable_particle_index_3(vox_pc, partial_particle_pos, threshold=self.args.voxel_size)
60 | normalized_vox_pc = vox_pc - np.mean(vox_pc, axis=0)
61 |
62 | ret_data = {
63 | 'scene_params': data['scene_params'],
64 | 'downsample_observable_idx': data['downsample_observable_idx'],
65 | 'normalized_vox_pc': normalized_vox_pc,
66 | 'partial_pc_mapped_idx': partial_pc_mapped_idx,
67 | }
68 | if eval:
69 | ret_data['downsample_idx'] = data['downsample_idx']
70 | ret_data['pointcloud'] = vox_pc
71 |
72 | return ret_data
73 |
74 | def _compute_edge_attr(self, vox_pc):
75 | point_tree = spatial.cKDTree(vox_pc)
76 | undirected_neighbors = np.array(list(point_tree.query_pairs(self.args.neighbor_radius, p=2))).T
77 |
78 | if len(undirected_neighbors) > 0:
79 | dist_vec = vox_pc[undirected_neighbors[0, :]] - vox_pc[undirected_neighbors[1, :]]
80 | dist = np.linalg.norm(dist_vec, axis=1, keepdims=True)
81 | edge_attr = np.concatenate([dist_vec, dist], axis=1)
82 | edge_attr_reverse = np.concatenate([-dist_vec, dist], axis=1)
83 |
84 | # Generate directed edge list and corresponding edge attributes
85 | edges = torch.from_numpy(np.concatenate([undirected_neighbors, undirected_neighbors[::-1]], axis=1))
86 | edge_attr = torch.from_numpy(np.concatenate([edge_attr, edge_attr_reverse]))
87 | else:
88 | print("number of distance edges is 0! adding fake edges")
89 |             edges = np.zeros((2, 2), dtype=np.int64)  # edge indices must be integers, not uint8
90 |             edges[0][0] = 0
91 |             edges[1][0] = 1
92 |             edges[0][1] = 0
93 |             edges[1][1] = 2
94 |             edge_attr = np.zeros((2, self.args.relation_dim), dtype=np.float32)
95 |             edges = torch.from_numpy(edges)  # keep int64 so it is usable as an edge_index
96 | edge_attr = torch.from_numpy(edge_attr)
97 | print("shape of edges: ", edges.shape)
98 | print("shape of edge_attr: ", edge_attr.shape)
99 |
100 | return edges, edge_attr
101 |
102 | def build_graph(self, data, get_gt_edge_label=True):
103 |         """
104 |         data: dict from _prepare_transition, containing the normalized voxelized
105 |               point cloud ('normalized_vox_pc') and the index mappings
106 |               ('partial_pc_mapped_idx', 'downsample_observable_idx') needed to
107 |               derive the ground-truth mesh edge labels.
108 |         get_gt_edge_label: if True, also compute the 0/1 mesh-edge labels.
109 |
110 |         return:
111 |         node_attr: N x 3, normalized point positions
112 |         edges: 2 x E, the distance-based edges
113 |         edge_attr: E x edge_feature_dim
114 |         gt_mesh_edge: 0/1 label for ground-truth mesh edge connection.
115 |         """
116 | node_attr = torch.from_numpy(data['normalized_vox_pc'])
117 | edges, edge_attr = self._compute_edge_attr(data['normalized_vox_pc'])
118 |
119 | if get_gt_edge_label:
120 | gt_mesh_edge = self._get_gt_mesh_edge(data, edges)
121 | gt_mesh_edge = torch.from_numpy(gt_mesh_edge)
122 | else:
123 | gt_mesh_edge = None
124 |
125 | return {
126 | 'x': node_attr,
127 | 'edge_index': edges,
128 | 'edge_attr': edge_attr,
129 | 'gt_mesh_edge': gt_mesh_edge
130 | }
131 |
132 | def _get_gt_mesh_edge(self, data, distance_edges):
133 | scene_params, observable_particle_idx, partial_pc_mapped_idx = data['scene_params'], data['downsample_observable_idx'], data['partial_pc_mapped_idx']
134 | _, cloth_xdim, cloth_ydim, _ = scene_params
135 | cloth_xdim, cloth_ydim = int(cloth_xdim), int(cloth_ydim)
136 |
137 | observable_mask = np.zeros(cloth_xdim * cloth_ydim)
138 | observable_mask[observable_particle_idx] = 1
139 |
140 | num_edges = distance_edges.shape[1]
141 | gt_mesh_edge = np.zeros((num_edges, 1), dtype=np.float32)
142 |
143 | for edge_idx in range(num_edges):
144 | # the edge index is in the range [0, len(pointcloud) - 1]
145 | # needs to convert it back to the idx in the downsampled graph
146 | s = int(distance_edges[0][edge_idx].item())
147 | r = int(distance_edges[1][edge_idx].item())
148 |
149 | # map from pointcloud idx to observable particle index
150 | s = partial_pc_mapped_idx[s]
151 | r = partial_pc_mapped_idx[r]
152 |
153 | s = observable_particle_idx[s]
154 | r = observable_particle_idx[r]
155 |
156 | if (r == s + 1 or r == s - 1 or
157 | r == s + cloth_xdim or r == s - cloth_xdim or
158 | r == s + cloth_xdim + 1 or r == s + cloth_xdim - 1 or
159 | r == s - cloth_xdim + 1 or r == s - cloth_xdim - 1
160 | ):
161 | gt_mesh_edge[edge_idx] = 1
162 |
163 | return gt_mesh_edge
164 |
--------------------------------------------------------------------------------
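`_compute_edge_attr` above is the heart of the edge-prediction input pipeline: every pair of voxelized points within `neighbor_radius` becomes a bidirectional edge whose 4-dim feature is the displacement vector plus its norm. A self-contained sketch of the same construction (the point cloud and radius values here are illustrative):

```
import numpy as np
from scipy import spatial

def radius_graph(points, radius):
    """Connect every pair of points closer than `radius`, mirroring _compute_edge_attr."""
    tree = spatial.cKDTree(points)
    pairs = np.array(list(tree.query_pairs(radius, p=2))).T  # (2, E_undirected)
    if pairs.size == 0:
        return np.zeros((2, 0), dtype=np.int64), np.zeros((0, 4), dtype=np.float32)
    disp = points[pairs[0]] - points[pairs[1]]               # displacement per edge
    dist = np.linalg.norm(disp, axis=1, keepdims=True)
    # Duplicate each undirected edge in both directions, flipping the displacement
    edges = np.concatenate([pairs, pairs[::-1]], axis=1)
    attr = np.concatenate([np.concatenate([disp, dist], axis=1),
                           np.concatenate([-disp, dist], axis=1)])
    return edges.astype(np.int64), attr.astype(np.float32)

pc = np.random.rand(200, 3).astype(np.float32) * 0.1
edges, attr = radius_graph(pc, radius=0.045)  # 0.045 is the default --neighbor_radius
print(edges.shape, attr.shape)                # (2, E) and (E, 4)
```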
/VCD/generate_cached_initial_state.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from softgym.registered_env import SOFTGYM_ENVS, env_arg_dict
3 | from VCD.main import get_default_args
4 |
5 |
6 | def create_env(args):
7 | assert args.env_name == 'ClothFlatten'
8 |
9 | env_args = copy.deepcopy(env_arg_dict[args.env_name]) # Default args
10 | env_args['cached_states_path'] = args.cached_states_path
11 | env_args['num_variations'] = args.num_variations
12 | env_args['use_cached_states'] = False
13 | env_args['save_cached_states'] = True
14 |
15 | env_args['render'] = False
16 | env_args['headless'] = True
17 | env_args['render_mode'] = 'cloth' if args.gen_data else 'particle'
18 | env_args['camera_name'] = 'default_camera'
19 | env_args['camera_width'] = 360
20 | env_args['camera_height'] = 360
21 |
22 | env_args['num_picker'] = 2 # The extra picker is hidden and does not really matter
23 | env_args['picker_radius'] = 0.01
24 | env_args['picker_threshold'] = 0.00625
25 | env_args['action_repeat'] = 1
26 |
27 | if args.partial_observable and args.gen_data:
28 | env_args['observation_mode'] = 'cam_rgb'
29 |
30 | return SOFTGYM_ENVS[args.env_name](**env_args)
31 |
32 |
33 | if __name__ == '__main__':
34 | args = get_default_args()
35 | env = create_env(args)
36 |
--------------------------------------------------------------------------------
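This script only writes the cache (`use_cached_states=False`, `save_cached_states=True`). For the consuming side, a hedged sketch assuming the same SoftGym keyword arguments shown above: a later training or planning run would flip the two flags so the environment loads `1213_release_n1000.pkl` instead of regenerating it.

```
import copy
from softgym.registered_env import SOFTGYM_ENVS, env_arg_dict

# Sketch only; the repo's own environment construction lives in VCD/main.py (create_env).
env_args = copy.deepcopy(env_arg_dict['ClothFlatten'])
env_args['cached_states_path'] = '1213_release_n1000.pkl'
env_args['num_variations'] = 1000
env_args['use_cached_states'] = True    # load the pickle written by this script
env_args['save_cached_states'] = False  # and do not overwrite it
env = SOFTGYM_ENVS['ClothFlatten'](**env_args)
```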
/VCD/main.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import json
3 | import argparse
4 |
5 | from VCD.utils.utils import vv_to_args, set_resource
6 | import copy
7 | from VCD.vc_dynamics import VCDynamics
8 | from VCD.vc_edge import VCConnection
9 | from chester import logger
10 | from VCD.utils.utils import configure_logger, configure_seed
11 |
12 |
13 | def get_default_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--exp_name', type=str, default='test', help='Name of the experiment')
16 | parser.add_argument('--log_dir', type=str, default='data/dyn_debug/', help='Logging directory')
17 | parser.add_argument('--seed', type=int, default=100)
18 |
19 | # Env
20 | parser.add_argument('--env_name', type=str, default='ClothFlatten')
21 | parser.add_argument('--cached_states_path', type=str, default='1213_release_n1000.pkl')
22 | parser.add_argument('--num_variations', type=int, default=1000)
23 | parser.add_argument('--partial_observable', type=bool, default=True, help="Whether only the partial point cloud can be observed")
24 | parser.add_argument('--particle_radius', type=float, default=0.00625, help='Particle radius for the cloth')
25 | ## pyflex shape state
26 | parser.add_argument('--shape_state_dim', type=int, default=14, help="[xyz, xyz_last, quat(4), quat_last(4)]")
27 |
28 | # Dataset
29 | parser.add_argument('--n_rollout', type=int, default=2000, help='Number of training trajectories')
30 | parser.add_argument('--time_step', type=int, default=100, help='Time steps per trajectory')
31 | parser.add_argument('--dt', type=float, default=1. / 100.)
32 | parser.add_argument('--pred_time_interval', type=int, default=5, help='Interval of timesteps between each dynamics prediction (model dt)')
33 | parser.add_argument('--train_valid_ratio', type=float, default=0.9, help="Ratio between training and validation")
34 | parser.add_argument('--dataf', type=str, default='./data/release/', help='Path to dataset')
35 | parser.add_argument('--gen_data', type=int, default=0, help='Whether to generate dataset')
36 | parser.add_argument('--gen_gif', type=bool, default=0, help='Whether to also save gif of each trajectory (for debugging)')
37 |
38 | # Model
39 | parser.add_argument('--global_size', type=int, default=128, help="Number of hidden nodes for global in GNN")
40 | parser.add_argument('--n_his', type=int, default=5, help="Number of history step input to the dynamics")
41 | parser.add_argument('--down_sample_scale', type=int, default=3, help="Downsample the simulated cloth by a scale of 3 on each dimension")
42 | parser.add_argument('--voxel_size', type=float, default=0.0216)
43 | parser.add_argument('--neighbor_radius', type=float, default=0.045, help="Radius for connecting nearby edges")
44 | parser.add_argument('--use_rest_distance', type=bool, default=True, help="Subtract the rest distance for the edge attribute of mesh edges")
45 | parser.add_argument('--use_mesh_edge', type=bool, default=True)
46 | parser.add_argument('--collect_data_delta_move_min', type=float, default=0.15)
47 | parser.add_argument('--collect_data_delta_move_max', type=float, default=0.4)
48 | parser.add_argument('--proc_layer', type=int, default=10, help="Number of processor layers in GNN")
49 | parser.add_argument('--state_dim', type=int, default=18,
50 | help="Dim of node feature input. Computed based on n_his: 3 x 5 + 1 dist to ground + 2 one-hot encoding of picked particle")
51 | parser.add_argument('--relation_dim', type=int, default=7, help="""Dim of edge feature input:
52 | 3 for directional vector + 1 for directional vector magnitude + 2 for one-hot encoding of mesh or collision edge + 1 for rest distance
53 | """)
54 |
55 | # Resume training
56 | parser.add_argument('--edge_model_path', type=str, default=None, help='Path to a trained edgeGNN model')
57 | parser.add_argument('--full_dyn_path', type=str, default=None, help='Path to a dynamics model using full point cloud')
58 | parser.add_argument('--partial_dyn_path', type=str, default=None, help='Path to a dynamics model using partial point cloud')
59 | parser.add_argument('--load_optim', type=bool, default=False, help='Load optimizer when resume training')
60 |
61 | # Training
62 | parser.add_argument('--train_mode', type=str, default='vsbl', help='Should be in ["vsbl", "graph_imit", "full"]')
63 | parser.add_argument('--n_epoch', type=int, default=1000)
64 | parser.add_argument('--beta1', type=float, default=0.9)
65 | parser.add_argument('--lr', type=float, default=1e-4)
66 | parser.add_argument('--fixed_lr', type=bool, default=False, help='By default, decaying lr is used.')
67 | parser.add_argument('--batch_size', type=int, default=16)
68 | parser.add_argument('--cuda_idx', type=int, default=0)
69 | parser.add_argument('--num_workers', type=int, default=10, help='Number of workers for dataloader')
70 |     parser.add_argument('--eval', type=int, default=0, help='Whether to just evaluate the model')
71 | parser.add_argument('--nstep_eval_rollout', type=int, default=20, help='Number of rollout trajectory for evaluation')
72 | parser.add_argument('--save_model_interval', type=int, default=5, help='Save the model every N epochs during training')
73 |     parser.add_argument('--use_wandb', type=bool, default=False, help='Use Weights & Biases (wandb) for logging')
74 |
75 | # For graph imitation
76 | parser.add_argument('--vsbl_lr', type=float, default=1e-4, help='Learning rate for visible(vsbl) point cloud dynamics')
77 | parser.add_argument('--full_lr', type=float, default=1e-4, help='Learning rate for full point cloud dynamics')
78 | parser.add_argument('--tune_teach', type=bool, default=False, help='Whether to allow teacher to adapt during graph imitation')
79 | parser.add_argument('--copy_teach', type=list, default=['encoder', 'decoder'], help="Which modules of the student are initialized from teacher")
80 | parser.add_argument('--imit_w_lat', type=float, default=1, help='Weight for imitating the global feature')
81 | parser.add_argument('--imit_w', type=float, default=5, help='Weight for imitation loss (vs student accel loss)')
82 | parser.add_argument('--reward_w', type=float, default=1e5, help='Weight for reward loss')
83 |
84 | # For ablation
85 | parser.add_argument('--fix_collision_edge', type=bool, default=False, help='Ablation that use fixed collision edges during rollout')
86 | parser.add_argument('--use_collision_as_mesh_edge', type=bool, default=False,
87 | help='If True, will not use gt mesh edges, but use all collision edges as mesh edges.')
88 |
89 | args = parser.parse_args()
90 | return args
91 |
92 |
93 | def create_env(args):
94 | from softgym.registered_env import env_arg_dict
95 | from softgym.registered_env import SOFTGYM_ENVS
96 | assert args.env_name == 'ClothFlatten'
97 |
98 | env_args = copy.deepcopy(env_arg_dict[args.env_name]) # Default args
99 | env_args['cached_states_path'] = args.cached_states_path
100 | env_args['num_variations'] = args.num_variations
101 |
102 | env_args['render'] = True
103 | env_args['headless'] = True
104 | env_args['render_mode'] = 'cloth' if args.gen_data else 'particle'
105 | env_args['camera_name'] = 'default_camera'
106 | env_args['camera_width'] = 360
107 | env_args['camera_height'] = 360
108 |
109 | env_args['num_picker'] = 2 # The extra picker is hidden and does not really matter
110 | env_args['picker_radius'] = 0.01
111 | env_args['picker_threshold'] = 0.00625
112 | env_args['action_repeat'] = 1
113 |
114 | if args.partial_observable and args.gen_data:
115 | env_args['observation_mode'] = 'cam_rgb'
116 |
117 | return SOFTGYM_ENVS[args.env_name](**env_args)
118 |
119 |
120 | def main():
121 | set_resource() # To avoid pin_memory issue
122 | args = get_default_args()
123 | env = create_env(args)
124 |
125 | configure_logger(args.log_dir, args.exp_name)
126 | configure_seed(args.seed)
127 |
128 | with open(osp.join(logger.get_dir(), 'variant.json'), 'w') as f:
129 | json.dump(args.__dict__, f, indent=2, sort_keys=True)
130 |
131 | # load vcd_edge
132 | if args.edge_model_path is not None:
133 | edge_model_vv = json.load(open(osp.join(args.edge_model_path, 'variant.json')))
134 | edge_model_args = vv_to_args(edge_model_vv)
135 | vcd_edge = VCConnection(edge_model_args, env=env)
136 | vcd_edge.load_model(args.edge_model_path)
137 | print('EdgeGNN successfully loaded from ', args.edge_model_path, flush=True)
138 | else:
139 | vcd_edge = None
140 | vcdynamics = VCDynamics(args, env, vcd_edge=vcd_edge) # Input the vcd_edge model in case we want to train with predicted edges
141 |
142 | if args.gen_data:
143 | vcdynamics.generate_dataset()
144 | else:
145 | vcdynamics.train()
146 |
147 |
148 | if __name__ == '__main__':
149 | main()
150 |
--------------------------------------------------------------------------------
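`main()` dumps `args.__dict__` to `variant.json`, and later runs rebuild an args object from that file through `vv_to_args` (defined in `VCD/utils/utils.py`, which is not shown in this excerpt). A plausible minimal implementation, assuming the helper is a plain dict-to-namespace wrapper:

```
import argparse
import json

def vv_to_args(vv):
    # Hypothetical sketch of VCD.utils.utils.vv_to_args: wrap a variant dict
    # in a Namespace so vv['lr'] becomes args.lr; the real helper may differ.
    args = argparse.Namespace()
    for key, val in vv.items():
        setattr(args, key, val)
    return args

# Round trip, mirroring main(): json.dump(args.__dict__, ...) then reload.
vv = json.load(open('data/vcd_edge/variant.json'))  # path is illustrative
edge_model_args = vv_to_args(vv)
```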
/VCD/main_train_edge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from chester import logger
3 | import json
4 | import os.path as osp
5 | from VCD.vc_edge import VCConnection
6 | from VCD.main import create_env
7 | from VCD.utils.utils import configure_logger, configure_seed
8 |
9 |
10 | # TODO Merge arguments
11 | def get_default_args():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--exp_name', type=str, default='test', help='Name of the experiment')
14 | parser.add_argument('--log_dir', type=str, default='data/edge_debug/', help='Logging directory')
15 | parser.add_argument('--seed', type=int, default=100)
16 |
17 | # Env
18 | parser.add_argument('--env_name', type=str, default='ClothFlatten')
19 | parser.add_argument('--cached_states_path', type=str, default='1213_release_n1000.pkl')
20 | parser.add_argument('--num_variations', type=int, default=1000)
21 | parser.add_argument('--partial_observable', type=bool, default=True, help="Whether only the partial point cloud can be observed")
22 | parser.add_argument('--particle_radius', type=float, default=0.00625, help='Particle radius for the cloth')
23 |
24 | # Dataset
25 | parser.add_argument('--n_rollout', type=int, default=2000, help='Number of training trajectories')
26 | parser.add_argument('--time_step', type=int, default=100, help='Time steps per trajectory')
27 | parser.add_argument('--dt', type=float, default=1. / 100.)
28 | parser.add_argument('--pred_time_interval', type=int, default=5, help='Interval of timesteps between each dynamics prediction (model dt)')
29 | parser.add_argument('--train_valid_ratio', type=float, default=0.9, help="Ratio between training and validation")
30 | parser.add_argument('--dataf', type=str, default='softgym/softgym/cached_initial_states/', help='Path to dataset')
31 | parser.add_argument('--gen_data', type=int, default=0, help='Whether to generate dataset')
32 | parser.add_argument('--gen_gif', type=bool, default=0, help='Whether to also save gif of each trajectory (for debugging)')
33 |
34 | # Model
35 | parser.add_argument('--global_size', type=int, default=128, help="Number of hidden nodes for global in GNN")
36 | parser.add_argument('--n_his', type=int, default=5, help="Number of history step input to the dynamics")
37 | parser.add_argument('--down_sample_scale', type=int, default=3, help="Downsample the simulated cloth by a scale of 3 on each dimension")
38 | parser.add_argument('--voxel_size', type=float, default=0.0216)
39 | parser.add_argument('--neighbor_radius', type=float, default=0.045, help="Radius for connecting nearby edges")
40 | parser.add_argument('--collect_data_delta_move_min', type=float, default=0.15)
41 | parser.add_argument('--collect_data_delta_move_max', type=float, default=0.4)
42 | parser.add_argument('--proc_layer', type=int, default=10, help="Number of processor layers in GNN")
43 |     parser.add_argument('--state_dim', type=int, default=3,
44 |                         help="Dim of node feature input: the xyz position of each point (no velocity history for the EdgeGNN)")
45 |     parser.add_argument('--relation_dim', type=int, default=4, help="Dim of edge feature input: 3 for directional vector + 1 for its magnitude")
46 |
47 | # Resume training
48 | parser.add_argument('--edge_model_path', type=str, default=None, help='Path to a trained edgeGNN model')
49 | parser.add_argument('--load_optim', type=bool, default=False, help='Load optimizer when resume training')
50 |
51 | # Training
52 | parser.add_argument('--n_epoch', type=int, default=1000)
53 | parser.add_argument('--beta1', type=float, default=0.9)
54 | parser.add_argument('--lr', type=float, default=1e-4)
55 | parser.add_argument('--batch_size', type=int, default=16)
56 | parser.add_argument('--cuda_idx', type=int, default=0)
57 | parser.add_argument('--num_workers', type=int, default=15, help='Number of workers for dataloader')
58 |     parser.add_argument('--use_wandb', type=bool, default=False, help='Use Weights & Biases (wandb) for logging')
59 | parser.add_argument('--plot_num', type=int, default=8, help='Number of edge prediction visuals to dump per training epoch')
60 |     parser.add_argument('--eval', type=int, default=0, help='Whether to just evaluate the model')
61 |
62 | args = parser.parse_args()
63 | return args
64 |
65 |
66 | def main():
67 | args = get_default_args()
68 | configure_logger(args.log_dir, args.exp_name)
69 | configure_seed(args.seed)
70 |
71 | # Dump parameters
72 | with open(osp.join(logger.get_dir(), 'variant.json'), 'w') as f:
73 | json.dump(args.__dict__, f, indent=2, sort_keys=True)
74 |
75 | env = create_env(args)
76 | vcd_edge = VCConnection(args, env=env)
77 | if args.gen_data:
78 | vcd_edge.generate_dataset()
79 | else:
80 | vcd_edge.train()
81 |
82 |
83 | if __name__ == '__main__':
84 | main()
85 |
--------------------------------------------------------------------------------
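One caveat that applies to the argument parsers in both `main.py` and `main_train_edge.py`: flags declared with `type=bool` (e.g. `--partial_observable`, `--use_wandb`, `--fixed_lr`) cannot be switched off from the command line, because argparse applies `bool()` to the raw string and any non-empty string, including `'False'`, is truthy. A common workaround is an explicit converter; a sketch:

```
import argparse

def str2bool(v):
    """Parse 'true'/'false'-style strings; argparse's type=bool would treat
    any non-empty string, even 'False', as True."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)

parser = argparse.ArgumentParser()
parser.add_argument('--use_wandb', type=str2bool, default=False)
print(parser.parse_args(['--use_wandb', 'False']).use_wandb)  # False, as intended
```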
/VCD/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_scatter
3 | from itertools import chain
4 | from torch_geometric.nn import MetaLayer
5 | import os
6 |
7 |
8 | # ================== Encoder ================== #
9 | class NodeEncoder(torch.nn.Module):
10 | def __init__(self, input_size, hidden_size=128, output_size=128):
11 | super(NodeEncoder, self).__init__()
12 | self.input_size = input_size
13 | self.hidden_size = hidden_size
14 | self.output_size = output_size
15 | self.model = torch.nn.Sequential(
16 | torch.nn.Linear(self.input_size, self.hidden_size),
17 | torch.nn.ReLU(inplace=True),
18 | # torch.nn.LayerNorm(self.hidden_size),
19 | torch.nn.Linear(self.hidden_size, self.hidden_size),
20 | torch.nn.ReLU(inplace=True),
21 | # torch.nn.LayerNorm(self.hidden_size),
22 | torch.nn.Linear(self.hidden_size, self.output_size))
23 |
24 | def forward(self, node_state):
25 | out = self.model(node_state)
26 | return out
27 |
28 |
29 | class EdgeEncoder(torch.nn.Module):
30 | def __init__(self, input_size, hidden_size=128, output_size=128):
31 | super(EdgeEncoder, self).__init__()
32 | self.input_size = input_size
33 | self.hidden_size = hidden_size
34 | self.output_size = output_size
35 | self.model = torch.nn.Sequential(
36 | torch.nn.Linear(self.input_size, self.hidden_size),
37 | torch.nn.ReLU(inplace=True),
38 | # torch.nn.LayerNorm(self.hidden_size),
39 | torch.nn.Linear(self.hidden_size, self.hidden_size),
40 | torch.nn.ReLU(inplace=True),
41 | # torch.nn.LayerNorm(self.hidden_size),
42 | torch.nn.Linear(self.hidden_size, self.output_size))
43 |
44 | def forward(self, edge_properties):
45 | out = self.model(edge_properties)
46 | return out
47 |
48 |
49 | class Encoder(torch.nn.Module):
50 | def __init__(self, node_input_size, edge_input_size, hidden_size=128, output_size=128):
51 | super(Encoder, self).__init__()
52 | self.node_input_size = node_input_size
53 | self.edge_input_size = edge_input_size
54 | self.hidden_size = hidden_size
55 | self.output_size = output_size
56 | self.node_encoder = NodeEncoder(self.node_input_size, self.hidden_size, self.output_size)
57 | self.edge_encoder = EdgeEncoder(self.edge_input_size, self.hidden_size, self.output_size)
58 |
59 | def forward(self, node_states, edge_properties):
60 | node_embedding = self.node_encoder(node_states)
61 | edge_embedding = self.edge_encoder(edge_properties)
62 | return node_embedding, edge_embedding
63 |
64 |
65 | # ================== Processor ================== #
66 | class EdgeModel(torch.nn.Module):
67 | def __init__(self, input_size, hidden_size=128, output_size=128):
68 | super(EdgeModel, self).__init__()
69 | self.input_size = input_size
70 | self.hidden_size = hidden_size
71 | self.output_size = output_size
72 | self.model = torch.nn.Sequential(
73 | torch.nn.Linear(self.input_size, self.hidden_size),
74 | torch.nn.ReLU(inplace=True),
75 | # torch.nn.LayerNorm(self.hidden_size),
76 | torch.nn.Linear(self.hidden_size, self.hidden_size),
77 | torch.nn.ReLU(inplace=True),
78 | # torch.nn.LayerNorm(self.hidden_size),
79 | torch.nn.Linear(self.hidden_size, self.output_size))
80 |
81 | def forward(self, src, dest, edge_attr, u, batch):
82 | # source, target: [E, F_x], where E is the number of edges.
83 | # edge_attr: [E, F_e]
84 | # u: [B, F_u], where B is the number of graphs.
85 | # batch: [E] with max entry B - 1.
86 | # u_expanded = u.expand([src.size()[0], -1])
87 | # model_input = torch.cat([src, dest, edge_attr, u_expanded], 1)
88 | # out = self.model(model_input)
89 | model_input = torch.cat([src, dest, edge_attr, u[batch]], 1)
90 | out = self.model(model_input)
91 | return out
92 |
93 |
94 | class NodeModel(torch.nn.Module):
95 | def __init__(self, input_size, hidden_size=128, output_size=128):
96 | super(NodeModel, self).__init__()
97 | self.input_size = input_size
98 | self.hidden_size = hidden_size
99 | self.output_size = output_size
100 | self.model = torch.nn.Sequential(
101 | torch.nn.Linear(self.input_size, self.hidden_size),
102 | torch.nn.ReLU(inplace=True),
103 | # torch.nn.LayerNorm(self.hidden_size),
104 | torch.nn.Linear(self.hidden_size, self.hidden_size),
105 | torch.nn.ReLU(inplace=True),
106 | # torch.nn.LayerNorm(self.hidden_size),
107 | torch.nn.Linear(self.hidden_size, self.output_size))
108 |
109 | def forward(self, x, edge_index, edge_attr, u, batch):
110 | # x: [N, F_x], where N is the number of nodes.
111 | # edge_index: [2, E] with max entry N - 1.
112 | # edge_attr: [E, F_e]
113 | # u: [B, F_u]
114 | # batch: [N] with max entry B - 1.
115 | _, edge_dst = edge_index
116 | edge_attr_aggregated = torch_scatter.scatter_add(edge_attr, edge_dst, dim=0, dim_size=x.size(0))
117 | model_input = torch.cat([x, edge_attr_aggregated, u[batch]], dim=1)
118 | out = self.model(model_input)
119 | return out
120 |
121 |
122 | class GlobalModel(torch.nn.Module):
123 | def __init__(self, input_size, hidden_size=128, output_size=128):
124 | super(GlobalModel, self).__init__()
125 | self.input_size = input_size
126 | self.hidden_size = hidden_size
127 | self.output_size = output_size
128 | self.model = torch.nn.Sequential(
129 | torch.nn.Linear(self.input_size, self.hidden_size),
130 | torch.nn.ReLU(inplace=True),
131 | # torch.nn.LayerNorm(self.hidden_size),
132 | torch.nn.Linear(self.hidden_size, self.hidden_size),
133 | torch.nn.ReLU(inplace=True),
134 | # torch.nn.LayerNorm(self.hidden_size),
135 | torch.nn.Linear(self.hidden_size, self.output_size))
136 |
137 | def forward(self, x, edge_index, edge_attr, u, batch):
138 | # x: [N, F_x], where N is the number of nodes.
139 | # edge_index: [2, E] with max entry N - 1.
140 | # edge_attr: [E, F_e]
141 | # u: [B, F_u]
142 | # batch: [N] with max entry B - 1.
143 | node_attr_mean = torch_scatter.scatter_mean(x, batch, dim=0)
144 | edge_attr_mean = torch_scatter.scatter_mean(edge_attr, batch[edge_index[0]], dim=0)
145 | model_input = torch.cat([u, node_attr_mean, edge_attr_mean], dim=1)
146 | out = self.model(model_input)
147 | assert out.shape == u.shape
148 | return out
149 |
150 |
151 | class RewardModel(torch.nn.Module):
152 | def __init__(self, node_size, global_size, hidden_size=128):
153 | super(RewardModel, self).__init__()
154 | self.node_size = node_size
155 | self.global_size = global_size
156 | self.hidden_size = hidden_size
157 | self.model = torch.nn.Sequential(
158 | torch.nn.Linear(self.global_size, self.hidden_size),
159 | torch.nn.ReLU(inplace=True),
160 | torch.nn.Linear(self.hidden_size, self.hidden_size),
161 | torch.nn.ReLU(inplace=True),
162 | torch.nn.Linear(self.hidden_size, 1))
163 |
164 | def forward(self, node_feat, global_feat, batch):
165 | out = self.model(global_feat)
166 | return out
167 |
168 |
169 | class GNBlock(torch.nn.Module):
170 | def __init__(self, input_size, hidden_size=128, output_size=128, use_global=True, global_size=128):
171 | super(GNBlock, self).__init__()
172 | self.input_size = input_size
173 | self.hidden_size = hidden_size
174 | self.output_size = output_size
175 | if use_global:
176 | self.model = MetaLayer(EdgeModel(self.input_size[0], self.hidden_size, self.output_size),
177 | NodeModel(self.input_size[1], self.hidden_size, self.output_size),
178 | GlobalModel(self.input_size[2], self.hidden_size, global_size))
179 | else:
180 | self.model = MetaLayer(EdgeModel(self.input_size[0], self.hidden_size, self.output_size),
181 | NodeModel(self.input_size[1], self.hidden_size, self.output_size),
182 | None)
183 |
184 | def forward(self, x, edge_index, edge_attr, u, batch):
185 | # x: [N, F_x], where N is the number of nodes.
186 | # edge_index: [2, E] with max entry N - 1.
187 | # edge_attr: [E, F_e]
188 | # u: [B, F_u]
189 | # batch: [N] with max entry B - 1.
190 | x, edge_attr, u = self.model(x, edge_index, edge_attr, u, batch)
191 | return x, edge_attr, u
192 |
193 |
194 | class Processor(torch.nn.Module):
195 | def __init__(self, input_size, hidden_size=128, output_size=128, use_global=True, global_size=128, layers=10):
196 | """
197 | :param input_size: A list of size to edge model, node model and global model
198 | """
199 | super(Processor, self).__init__()
200 | self.input_size = input_size
201 | self.hidden_size = hidden_size
202 | self.output_size = output_size
203 | self.use_global = use_global
204 | self.global_size = global_size
205 | self.gns = torch.nn.ModuleList([
206 | GNBlock(self.input_size, self.hidden_size, self.output_size, self.use_global, global_size=global_size)
207 | for _ in range(layers)])
208 |
209 | def forward(self, x, edge_index, edge_attr, u, batch):
210 | # def forward(self, data):
211 | # x, edge_index, edge_attr, u, batch = data.node_embedding, data.neighbors, data.edge_embedding, data.global_feat, data.batch
212 | # x: [N, F_x], where N is the number of nodes.
213 | # edge_index: [2, E] with max entry N - 1.
214 | # edge_attr: [E, F_e]
215 | # u: [B, F_u]
216 | # batch: [N] with max entry B - 1.
217 | if len(u.shape) == 1:
218 | u = u[None]
219 | if edge_index.shape[1] < 10:
220 | print("--------debug info---------")
221 | print("small number of edges")
222 | print("x.shape: ", x.shape)
223 | print("edge_index.shape: ", edge_index.shape)
224 | print("edge_attr.shape: ", edge_attr.shape, flush=True)
225 | print("--------------------------")
226 |
227 | x_new, edge_attr_new, u_new = x, edge_attr, u
228 | for gn in self.gns:
229 | x_res, edge_attr_res, u_res = gn(x_new, edge_index, edge_attr_new, u_new, batch)
230 | x_new = x_new + x_res
231 | edge_attr_new = edge_attr_new + edge_attr_res
232 | u_new = u_new + u_res
233 | return x_new, edge_attr_new, u_new
234 |
235 |
236 | # ================== Decoder ================== #
237 | class Decoder(torch.nn.Module):
238 | def __init__(self, input_size=128, hidden_size=128, output_size=3):
239 | super(Decoder, self).__init__()
240 | self.input_size = input_size
241 | self.hidden_size = hidden_size
242 | self.output_size = output_size
243 | self.model = torch.nn.Sequential(
244 | torch.nn.Linear(self.input_size, self.hidden_size),
245 | torch.nn.ReLU(inplace=True),
246 | # torch.nn.LayerNorm(self.hidden_size),
247 | torch.nn.Linear(self.hidden_size, self.hidden_size),
248 | torch.nn.ReLU(inplace=True),
249 | # torch.nn.LayerNorm(self.hidden_size),
250 | torch.nn.Linear(self.hidden_size, self.output_size))
251 |
252 | def forward(self, node_feat, res=None):
253 | out = self.model(node_feat)
254 | if res is not None:
255 | out = out + res
256 | return out
257 |
258 |
259 | class GNN(torch.nn.Module):
260 | def __init__(self, args, decoder_output_dim, name, use_reward=False):
261 | super(GNN, self).__init__()
262 | self.name = name
263 | self.args = args
264 | self.use_global = True if self.args.global_size > 1 else False
265 | embed_dim = 128
266 | self.dyn_models = torch.nn.ModuleDict({'encoder': Encoder(args.state_dim, args.relation_dim, output_size=embed_dim),
267 | 'processor': Processor(
268 | [3 * embed_dim + args.global_size,
269 | 2 * embed_dim + args.global_size,
270 | 2 * embed_dim + args.global_size],
271 | use_global=self.use_global, layers=args.proc_layer, global_size=args.global_size),
272 | 'decoder': Decoder(output_size=decoder_output_dim)})
273 | self.use_reward = use_reward
274 |
275 | if use_reward:
276 | self.dyn_models['reward_model'] = RewardModel(128, 128, 128)
277 |
278 | def forward(self, data):
279 | """ data should be a dictionary containing the following dict
280 | edge_index: Edge index 2 x E
281 | x: Node feature
282 | edge_attr: Edge feature
283 | gt_accel: Acceleration label for each node
284 | x_batch: Batch index
285 | """
286 | out = {}
287 | node_embedding, edge_embedding = self.dyn_models['encoder'](data['x'], data['edge_attr'])
288 | n_nxt, e_nxt, lat_nxt = self.dyn_models['processor'](node_embedding,
289 | data['edge_index'],
290 | edge_embedding,
291 | u=data['u'],
292 | batch=data['x_batch'])
293 | # Return acceleration for each node and the final global feature (for potential multi-step training)
294 | if self.name == 'EdgeGNN':
295 | out['mesh_edge'] = self.dyn_models['decoder'](e_nxt)
296 | else:
297 | out['accel'] = self.dyn_models['decoder'](n_nxt)
298 | if self.use_reward:
299 | out['reward_nxt'] = self.dyn_models['reward_model'](n_nxt, lat_nxt, batch=data['x_batch'])
300 |
301 | out['n_nxt'] = n_nxt[data['partial_pc_mapped_idx']] if 'partial_pc_mapped_idx' in data else n_nxt
302 | out['lat_nxt'] = lat_nxt
303 | return out
304 |
305 | def load_model(self, model_path, load_names='all', load_optim=False, optim=None):
306 | """
307 | :param load_names: which part of ['encoder', 'processor', 'decoder'] to load
308 | :param load_optim: Whether to load optimizer states
309 | :return:
310 | """
311 | ckpt = torch.load(model_path)
312 | optim_path = model_path.replace('dyn', 'optim')
313 | if load_names == 'all':
314 | for k, v in self.dyn_models.items():
315 | self.dyn_models[k].load_state_dict(ckpt[k])
316 | else:
317 | for model_name in load_names:
318 | self.dyn_models[model_name].load_state_dict(ckpt[model_name])
319 | print('Loaded saved ckp from {} for {} models'.format(model_path, load_names))
320 |
321 | if load_optim:
322 | assert os.path.exists(optim_path)
323 | optim.load_state_dict(torch.load(optim_path))
324 | print('Load optimizer states from ', optim_path)
325 |
326 | def save_model(self, root_path, m_name, suffix, optim):
327 | """
328 | Regular saving: {input_type}_dyn_{epoch}.pth
329 | Best model: {input_type}_dyn_best.pth
330 | Optim: {input_type}_optim_{epoch}.pth
331 | """
332 | save_name = 'edge' if self.name == 'EdgeGNN' else 'dyn'
333 | model_path = os.path.join(root_path, '{}_{}_{}.pth'.format(m_name, save_name, suffix))
334 | torch.save({k: v.state_dict() for k, v in self.dyn_models.items()}, model_path)
335 | optim_path = os.path.join(root_path, '{}_{}_{}.pth'.format(m_name, 'optim', suffix))
336 | torch.save(optim.state_dict(), optim_path)
337 |
338 | def set_mode(self, mode='train'):
339 | assert mode in ['train', 'eval']
340 | for model in self.dyn_models.values():
341 | if mode == 'eval':
342 | model.eval()
343 | else:
344 | model.train()
345 |
346 | def param(self):
347 | model_parameters = list(chain(*[list(m.parameters()) for m in self.dyn_models.values()]))
348 | return model_parameters
349 |
350 | def to(self, device):
351 | for model in self.dyn_models.values():
352 | model.to(device)
353 |
354 | def freeze(self, tgts=None):
355 | if tgts is None:
356 | for m in self.dyn_models.values():
357 | for para in m.parameters():
358 | para.requires_grad = False
359 | else:
360 | for tgt in tgts:
361 | m = self.dyn_models[tgt]
362 | for para in m.parameters():
363 | para.requires_grad = False
364 |
365 | def unfreeze(self, tgts=None):
366 | if tgts is None:
367 | for m in self.dyn_models.values():
368 | for para in m.parameters():
369 | para.requires_grad = True
370 | else:
371 | for tgt in tgts:
372 | m = self.dyn_models[tgt]
373 | for para in m.parameters():
374 | para.requires_grad = True
375 |
--------------------------------------------------------------------------------
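The `GNN` wrapper expects a plain dict rather than a `torch_geometric.data.Data` object. A smoke-test sketch with random tensors, assuming the torch / torch-geometric / torch-scatter stack from `environment.yml` is installed (shapes follow the defaults in `main.py`; `proc_layer` is reduced here only to keep the example fast):

```
from argparse import Namespace
import torch
from VCD.models import GNN

args = Namespace(state_dim=18, relation_dim=7, global_size=128, proc_layer=2)
model = GNN(args, decoder_output_dim=3, name='DynGNN')

N, E = 50, 200  # nodes and directed edges for a single graph (batch of one)
data = {
    'x': torch.randn(N, args.state_dim),             # per-node input features
    'edge_attr': torch.randn(E, args.relation_dim),  # per-edge input features
    'edge_index': torch.randint(0, N, (2, E)),       # sender/receiver indices
    'u': torch.zeros(1, args.global_size),           # initial global feature
    'x_batch': torch.zeros(N, dtype=torch.long),     # every node belongs to graph 0
}
out = model(data)
print(out['accel'].shape)    # torch.Size([50, 3]): per-node acceleration
print(out['lat_nxt'].shape)  # torch.Size([1, 128]): final global feature
```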
/VCD/rs_planner.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from multiprocessing import pool
3 | import copy
4 | from VCD.utils.camera_utils import project_to_image, get_target_pos
5 |
6 |
7 | class RandomShootingUVPickandPlacePlanner():
8 |
9 | def __init__(self, num_pick, delta_y, pull_step, wait_step,
10 | dynamics, reward_model, num_worker=10,
11 | move_distance_range=[0.05, 0.2], gpu_num=1,
12 | image_size=None, normalize_info=None, delta_y_range=None,
13 | matrix_world_to_camera=np.identity(4), task='flatten',
14 | use_pred_rwd=False):
15 | """
16 | Random Shooting planner.
17 | """
18 |
19 | self.normalize_info = normalize_info # Used for robot experiments to denormalize before action clipping
20 | self.num_pick = num_pick
21 | self.delta_y = delta_y # for real world experiment, delta_y (pick_height) is fixed
22 | self.delta_y_range = delta_y_range # for simulation, delta_y is randomlized
23 | self.move_distance_low, self.move_distance_high = move_distance_range[0], move_distance_range[1]
24 | self.reward_model = reward_model
25 | self.dynamics = dynamics
26 | self.pull_step, self.wait_step = pull_step, wait_step
27 | self.gpu_num = gpu_num
28 | self.use_pred_rwd = use_pred_rwd
29 |
30 | if num_worker > 0:
31 | self.pool = pool.Pool(processes=num_worker)
32 | self.num_worker = num_worker
33 | self.matrix_world_to_camera = matrix_world_to_camera
34 | self.image_size = image_size
35 | self.task = task
36 |
37 | def project_3d(self, pos):
38 | return project_to_image(self.matrix_world_to_camera, pos, self.image_size[0], self.image_size[1])
39 |
40 | def get_action(self, init_data, robot_exp=False, cloth_mask=None, check_mask=None, m_name='vsbl'):
41 | """
42 | check_mask: Used to filter out place points that are on the cloth.
43 | init_data should be a list that include:
44 | ['pointcloud', 'velocities', 'picker_position', 'action', 'picked_points', 'scene_params', 'observable_particle_indices]
45 | note: require position, velocity to be already downsampled
46 |
47 | """
48 | args = self.dynamics.args
49 | data = init_data.copy()
50 | data['picked_points'] = [-1, -1]
51 |
52 | pull_step, wait_step = self.pull_step, self.wait_step
53 |
54 | # add a no-op action
55 | pick_try_num = self.num_pick + 1 if self.task == 'flatten' else self.num_pick
56 | actions = np.zeros((pick_try_num, pull_step + wait_step, 8))
57 | pointcloud = copy.deepcopy(data['pointcloud'])
58 |
59 | picker_pos = data['picker_position'][0][:3] if data['picker_position'] is not None else None
60 | bb_margin = 30
61 |
62 | # paralleled version of generating action sequences
63 | if robot_exp:
64 |             num_samples = 10 * self.num_pick  # Oversample, then reject out-of-bound candidates below
65 |
66 | def filter_out_of_bound(pos, x_low=-0.45, x_high=0.06, z_low=0.3, z_high=0.65):
67 |                 """Return the indices that are within bounds"""
68 | cond1 = pos[:, 0] >= x_low
69 | cond2 = pos[:, 0] <= x_high
70 | cond3 = pos[:, 2] >= z_low
71 | cond4 = pos[:, 2] <= z_high
72 | cond = cond1 * cond2 * cond3 * cond4
73 | return np.where(cond)[0]
74 |
75 | idxes = np.random.randint(0, len(pointcloud), num_samples)
76 |
77 | # In real world, instead of using uv, simply pick a random idx
78 | pickup_pos = pointcloud[idxes]
79 |
80 | # Remove out of bound pick places
81 | x_mean, z_mean = self.normalize_info['xz_mean'] # First denormalize
82 | pickup_pos[:, 0] += x_mean
83 | pickup_pos[:, 2] += z_mean
84 | in_bound_idx = filter_out_of_bound(pickup_pos)
85 | pickup_pos[:, 0] -= x_mean
86 | pickup_pos[:, 2] -= z_mean
87 | pickup_pos = pickup_pos[in_bound_idx]
88 | idxes = idxes[in_bound_idx]
89 | num_samples = len(pickup_pos)
90 |
91 | move_theta = np.random.rand(num_samples).reshape(num_samples, 1) * 2 * np.pi
92 | move_distance = np.random.uniform(self.move_distance_low, self.move_distance_high, num_samples)
93 | move_direction = np.hstack(
94 | [np.cos(move_theta), np.zeros_like(move_theta), np.sin(move_theta)]) * move_distance.reshape(
95 | num_samples, 1)
96 |
97 | place_pos = (pickup_pos + move_direction).copy()
98 |
99 | # Clip place_pos with a fixed bounding box to make sure the place point is within the camera
100 | x_mean, z_mean = self.normalize_info['xz_mean'] # First denormalize
101 | place_pos[:, 0] += x_mean
102 | place_pos[:, 2] += z_mean
103 | in_bound_idx = filter_out_of_bound(place_pos)
104 | place_pos[:, 0] -= x_mean
105 | place_pos[:, 2] -= z_mean
106 |
107 | if check_mask is not None:
108 | out_cloth_idx = np.where(check_mask(place_pos))[0]
109 | # print('in bound number:', in_bound_idx.shape)
110 | # print(in_bound_idx[:50], out_cloth_idx[:50])
111 | in_bound_idx = np.intersect1d(in_bound_idx, out_cloth_idx)
112 | # print(in_bound_idx)
113 | # print('how many left:', in_bound_idx.shape)
114 |
115 | select_idx = np.random.choice(in_bound_idx, self.num_pick, replace=len(in_bound_idx) < self.num_pick)
116 | pickup_pos = pickup_pos[select_idx]
117 | place_pos = place_pos[select_idx]
118 | idxes = idxes[select_idx]
119 | waypoints = np.zeros([self.num_pick, 3, 3])
120 |
121 | waypoints[:, 0, :] = pickup_pos
122 | waypoints[:, 1, :] = pickup_pos + np.array([0, self.delta_y, 0]).reshape(1, 3)
123 | waypoints[:, 2, :] = place_pos + np.array([0, self.delta_y, 0]).reshape(1, 3)
124 |
125 | # Update move_direction after clipping
126 | move_direction = waypoints[:, 2, :] - waypoints[:, 1, :]
127 |
128 | delta_moves = list(move_direction / self.pull_step)
129 | picked_particles = list(idxes)
130 | waypoints = list(waypoints)
131 | move_vec = list(move_direction)
132 |             # TODO try using delta_y_range instead of fixed y
133 | delta_move = move_direction
134 | delta_move[:, 1] += self.delta_y
135 | num_step = self.pull_step
136 | actions[:-1, :num_step, :3] = delta_move[:, None, :] / num_step # Upward
137 | actions[:-1, :num_step, 3] = 1
138 | actions[:, :, 4:] = 0 # we essentially only plan over 1 picker action
139 |
140 |             # Always add a no-op candidate so it competes with the sampled picks
141 |             # (fills the extra action row reserved when task == 'flatten')
142 |             waypoints.append([np.nan, np.nan, np.nan])
143 |             delta_moves.append([0., 0., 0.])
144 |             picked_particles.append(-1)
145 |             move_vec.append([0., 0., 0.])
146 | else: # simulation planning
147 | us, vs = self.project_3d(pointcloud)
148 | params = [
149 | (us, vs, self.image_size, pointcloud, pull_step,
150 | self.delta_y_range, bb_margin, self.matrix_world_to_camera,
151 | self.move_distance_low, self.move_distance_high, cloth_mask, self.task)
152 | for i in range(self.num_pick)
153 | ]
154 | results = self.pool.map(parallel_generate_actions, params)
155 | delta_moves, start_poses, after_poses = [x[0] for x in results], [x[1] for x in results], [x[2] for x in results]
156 | if self.task == 'flatten': # add a no-op action
157 | start_poses.append(data['picker_position'][0, :])
158 | after_poses.append(data['picker_position'][0, :])
159 |
160 | actions[:-1, :pull_step, :3] = np.vstack(delta_moves)[:, None, :]
161 | actions[:-1, :pull_step, 3] = 1
162 | actions[:, :, 4:] = 0
163 | move_vec = None
164 |
165 |         # roll out the dynamics model in parallel with the sampled action sequences
166 | data_cpy = copy.deepcopy(data)
167 | if self.num_worker > 0:
168 | job_each_gpu = pick_try_num // self.gpu_num
169 | params = []
170 | for i in range(pick_try_num):
171 | if robot_exp:
172 | data_cpy['picked_points'] = [picked_particles[i], -1]
173 | else:
174 | data_cpy['picked_points'] = [-1, -1]
175 | data_cpy['picker_position'][0, :] = start_poses[i]
176 |
177 | gpu_id = i // job_each_gpu if i < self.gpu_num * job_each_gpu else i % self.gpu_num
178 | params.append(
179 | dict(
180 | model_input_data=copy.deepcopy(data_cpy), actions=actions[i], m_name=m_name,
181 | reward_model=self.reward_model, cuda_idx=gpu_id, robot_exp=robot_exp,
182 | )
183 | )
184 | results = self.pool.map(self.dynamics.rollout, params, chunksize=max(1, pick_try_num // self.num_worker))
185 | returns = [x['final_ret'] for x in results]
186 | else: # sequentially rollout each sampled action trajectory
187 | returns, results = [], []
188 | for i in range(pick_try_num):
189 | res = self.dynamics.rollout(
190 | dict(
191 | model_input_data=copy.deepcopy(data_cpy), actions=actions[i], m_name=m_name,
192 | reward_model=self.reward_model, cuda_idx=0, robot_exp=robot_exp,
193 | )
194 | )
195 | results.append(res), returns.append(res['final_ret'])
196 |
197 | ret_info = {}
198 | highest_return_idx = np.argmax(returns)
199 |
200 | ret_info['highest_return_idx'] = highest_return_idx
201 | action_seq = actions[highest_return_idx]
202 | if robot_exp:
203 | ret_info['waypoints'] = np.array(waypoints[highest_return_idx]).copy()
204 | ret_info['all_candidate'] = np.array(waypoints[:-1])
205 | ret_info['all_candidate_rewards'] = np.array(returns[:-1])
206 | else:
207 | ret_info['start_pos'] = start_poses[highest_return_idx]
208 | ret_info['after_pos'] = after_poses[highest_return_idx]
209 |
210 | model_predict_particle_positions = results[highest_return_idx]['model_positions']
211 | model_predict_shape_positions = results[highest_return_idx]['shape_positions']
212 | predicted_edges = results[highest_return_idx]['mesh_edges']
213 | if move_vec is not None:
214 | ret_info['picked_pos'] = pointcloud[picked_particles[highest_return_idx]]
215 | ret_info['move_vec'] = move_vec[highest_return_idx]
216 |
217 | return action_seq, model_predict_particle_positions, model_predict_shape_positions, ret_info, predicted_edges
218 |
219 |
220 | def pos_in_image(after_pos, matrix_world_to_camera, image_size):
221 | euv = project_to_image(matrix_world_to_camera, after_pos.reshape((1, 3)), image_size[0], image_size[1])
222 | u, v = euv[0][0], euv[1][0]
223 | if u >= 0 and u < image_size[1] and v >= 0 and v < image_size[0]:
224 | return True
225 | else:
226 | return False
227 |
228 |
229 | def parallel_generate_actions(args):
230 | us, vs, image_size, pointcloud, pull_step, delta_y_range, bb_margin, matrix_world_to_camera, move_distance_low, move_distance_high, cloth_mask, task = args
231 |
232 | # choosing a pick location
233 | lb_u, lb_v, ub_u, ub_v = int(np.min(us)), int(np.min(vs)), int(np.max(us)), int(np.max(vs))
234 | u = np.random.randint(max(lb_u - bb_margin, 0), min(ub_u + bb_margin, image_size[1]))
235 | v = np.random.randint(max(lb_v - bb_margin, 0), min(ub_v + bb_margin, image_size[0]))
236 | target_pos = get_target_pos(pointcloud, u, v, image_size, matrix_world_to_camera, cloth_mask)
237 |
238 | # second stage: choose a random (x, y, z) direction, move towards that direction to determine the pick point
239 | while True:
240 | move_direction = np.random.rand(3) - 0.5
241 | if task == 'flatten':
242 | move_direction[1] = np.random.uniform(delta_y_range[0], delta_y_range[1])
243 | else: # for fold, just generate horizontal move
244 | move_direction[1] = 0
245 |
246 | move_direction = move_direction / np.linalg.norm(move_direction)
247 | move_distance = np.random.uniform(move_distance_low, move_distance_high)
248 | delta_move = move_distance / pull_step * move_direction
249 |
250 | after_pos = target_pos + move_distance * move_direction
251 | if pos_in_image(after_pos, matrix_world_to_camera, image_size):
252 | break
253 |
254 | return delta_move, target_pos, after_pos
255 |
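# A minimal sketch of how `parallel_generate_actions` can be driven (the
# numeric values here are hypothetical): each worker gets one packed tuple,
# in exactly the unpacking order at the top of the function, and returns a
# (delta_move, target_pos, after_pos) candidate.
#
#     from multiprocessing import Pool
#     arg = (us, vs, (360, 360), pointcloud, 10, (0.1, 0.4), 10,
#            matrix_world_to_camera, 0.05, 0.3, cloth_mask, 'flatten')
#     with Pool(8) as pool:
#         candidates = pool.map(parallel_generate_actions, [arg] * 100)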
--------------------------------------------------------------------------------
/VCD/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xingyu-Lin/VCD/328f3e0ada7c3d39b974e875c5847badfca5037d/VCD/utils/__init__.py
--------------------------------------------------------------------------------
/VCD/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.optimize as opt
3 | import scipy
4 | from scipy.spatial.distance import cdist
5 |
6 |
7 | def build_depth_from_pointcloud(pointcloud, matrix_world_to_camera, imsize):
8 | height, width = imsize
9 | pointcloud = np.concatenate([pointcloud, np.ones((len(pointcloud), 1))], axis=1) # n x 4
10 | camera_coordinate = matrix_world_to_camera @ pointcloud.T # 3 x n
11 | camera_coordinate = camera_coordinate.T # n x 3
12 |     K = intrinsic_from_fov(height, width, 45) # the fov is 45 degrees
13 |
14 | u0 = K[0, 2]
15 | v0 = K[1, 2]
16 | fx = K[0, 0]
17 | fy = K[1, 1]
18 |
19 | x, y, depth = camera_coordinate[:, 0], camera_coordinate[:, 1], camera_coordinate[:, 2]
20 |     u = np.rint(x * fx / depth + u0).astype(int)  # round before casting, otherwise the cast truncates
21 |     v = np.rint(y * fy / depth + v0).astype(int)
22 |
23 | us = u.flatten()
24 | vs = v.flatten()
25 | depth = depth.flatten()
26 |
27 |     depth_map = dict()
28 |     for u, v, d in zip(us, vs, depth):
29 |         # collect every depth value that projects to the same pixel;
30 |         # the closest (minimum) one is kept when filling depth_2d below
31 |         if (u, v) not in depth_map:
32 |             depth_map[(u, v)] = []
33 |         depth_map[(u, v)].append(d)
34 |
35 | depth_2d = np.zeros((height, width))
36 | for u in range(width):
37 | for v in range(height):
38 | if (u, v) in depth_map.keys():
39 | depth_2d[v][u] = np.min(depth_map[(u, v)])
40 |
41 | return depth_2d
42 |
43 |
44 | def pixel_coord_np(width, height):
45 | """
46 | Pixel in homogenous coordinate
47 | Returns:
48 | Pixel coordinate: [3, width * height]
49 | """
50 |     x = np.linspace(0, width - 1, width).astype(int)
51 |     y = np.linspace(0, height - 1, height).astype(int)
52 | [x, y] = np.meshgrid(x, y)
53 | return np.vstack((x.flatten(), y.flatten(), np.ones_like(x.flatten())))
54 |
55 |
56 | def intrinsic_from_fov(height, width, fov=90):
57 | """
58 | Basic Pinhole Camera Model
59 | intrinsic params from fov and sensor width and height in pixels
60 | Returns:
61 | K: [4, 4]
62 | """
63 | px, py = (width / 2, height / 2)
64 | hfov = fov / 360. * 2. * np.pi
65 | fx = width / (2. * np.tan(hfov / 2.))
66 |
67 | vfov = 2. * np.arctan(np.tan(hfov / 2) * height / width)
68 | fy = height / (2. * np.tan(vfov / 2.))
69 |
70 | return np.array([[fx, 0, px, 0.],
71 | [0, fy, py, 0.],
72 | [0, 0, 1., 0.],
73 | [0., 0., 0., 1.]])
74 |
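# Worked example: for a square 360 x 360 image with fov=45 (the value used
# throughout this file), hfov = pi/4 and fx = 360 / (2 * tan(pi/8)) ~= 434.6;
# by symmetry fy = fx, with the principal point at (180, 180).
#
#     K = intrinsic_from_fov(360, 360, 45)
#     assert np.isclose(K[0, 0], 360 / (2 * np.tan(np.pi / 8)))
#     assert K[0, 2] == K[1, 2] == 180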
75 |
76 | def get_rotation_matrix(angle, axis):
77 | axis = axis / np.linalg.norm(axis)
78 | s = np.sin(angle)
79 | c = np.cos(angle)
80 |
81 | m = np.zeros((4, 4))
82 |
83 | m[0][0] = axis[0] * axis[0] + (1.0 - axis[0] * axis[0]) * c
84 | m[0][1] = axis[0] * axis[1] * (1.0 - c) - axis[2] * s
85 | m[0][2] = axis[0] * axis[2] * (1.0 - c) + axis[1] * s
86 | m[0][3] = 0.0
87 |
88 | m[1][0] = axis[0] * axis[1] * (1.0 - c) + axis[2] * s
89 | m[1][1] = axis[1] * axis[1] + (1.0 - axis[1] * axis[1]) * c
90 | m[1][2] = axis[1] * axis[2] * (1.0 - c) - axis[0] * s
91 | m[1][3] = 0.0
92 |
93 | m[2][0] = axis[0] * axis[2] * (1.0 - c) - axis[1] * s
94 | m[2][1] = axis[1] * axis[2] * (1.0 - c) + axis[0] * s
95 | m[2][2] = axis[2] * axis[2] + (1.0 - axis[2] * axis[2]) * c
96 | m[2][3] = 0.0
97 |
98 | m[3][0] = 0.0
99 | m[3][1] = 0.0
100 | m[3][2] = 0.0
101 | m[3][3] = 1.0
102 |
103 | return m
104 |
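# Sanity check (a sketch, not part of the original file): this is the
# standard axis-angle (Rodrigues) rotation matrix; rotating the x-axis by
# 90 degrees about the y-axis sends it to -z, matching the usual R_y.
#
#     m = get_rotation_matrix(np.pi / 2, np.array([0., 1., 0.]))
#     assert np.allclose((m @ np.array([1., 0., 0., 1.]))[:3], [0., 0., -1.])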
105 |
106 | def get_world_coords(rgb, depth, env, particle_pos=None):
107 | height, width, _ = rgb.shape
108 |     K = intrinsic_from_fov(height, width, 45) # the fov is 45 degrees
109 |
110 | # Apply back-projection: K_inv @ pixels * depth
111 | u0 = K[0, 2]
112 | v0 = K[1, 2]
113 | fx = K[0, 0]
114 | fy = K[1, 1]
115 |
116 |     x = np.linspace(0, width - 1, width).astype(float)
117 |     y = np.linspace(0, height - 1, height).astype(float)
118 | u, v = np.meshgrid(x, y)
119 | one = np.ones((height, width, 1))
120 | x = (u - u0) * depth / fx
121 | y = (v - v0) * depth / fy
122 | z = depth
123 | cam_coords = np.dstack([x, y, z, one])
124 |
125 | matrix_world_to_camera = get_matrix_world_to_camera(
126 | env.camera_params[env.camera_name]['pos'], env.camera_params[env.camera_name]['angle'])
127 |
128 | # convert the camera coordinate back to the world coordinate using the rotation and translation matrix
129 | cam_coords = cam_coords.reshape((-1, 4)).transpose() # 4 x (height x width)
130 | world_coords = np.linalg.inv(matrix_world_to_camera) @ cam_coords # 4 x (height x width)
131 | world_coords = world_coords.transpose().reshape((height, width, 4))
132 |
133 | return world_coords
134 |
135 |
136 | def get_observable_particle_index(world_coords, particle_pos, rgb, depth):
137 | height, width, _ = rgb.shape
138 | # perform the matching of pixel particle to real particle
139 | particle_pos = particle_pos[:, :3]
140 |
141 | estimated_world_coords = np.array(world_coords)[np.where(depth > 0)][:, :3]
142 |
143 | distance = scipy.spatial.distance.cdist(estimated_world_coords, particle_pos)
144 | # Each point in the point cloud will cover at most two particles. Particles not covered will be deemed occluded
145 | estimated_particle_idx = np.argpartition(distance, 2)[:, :2].flatten()
146 | estimated_particle_idx = np.unique(estimated_particle_idx)
147 |
148 | return np.array(estimated_particle_idx, dtype=np.int32)
149 |
150 |
151 | def get_observable_particle_index_old(world_coords, particle_pos, rgb, depth):
152 | height, width, _ = rgb.shape
153 | # perform the matching of pixel particle to real particle
154 | particle_pos = particle_pos[:, :3]
155 |
156 | estimated_world_coords = np.array(world_coords)[np.where(depth > 0)][:, :3]
157 |
158 | distance = scipy.spatial.distance.cdist(estimated_world_coords, particle_pos)
159 | estimated_particle_idx = np.argmin(distance, axis=1)
160 | estimated_particle_idx = np.unique(estimated_particle_idx)
161 |
162 | return np.array(estimated_particle_idx, dtype=np.int32)
163 |
164 |
165 | def get_observable_particle_index_3(pointcloud, mesh, threshold=0.0216):
166 | ### bi-partite graph matching
167 | distance = scipy.spatial.distance.cdist(pointcloud, mesh)
168 | distance[distance > threshold] = 1e10
169 | row_idx, column_idx = opt.linear_sum_assignment(distance)
170 |
171 | distance_mapped = distance[np.arange(len(pointcloud)), column_idx]
172 | bad_mapping = distance_mapped > threshold
173 | if np.sum(bad_mapping) > 0:
174 | column_idx[bad_mapping] = np.argmin(distance[bad_mapping], axis=1)
175 |
176 | return pointcloud, column_idx
177 |
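# Example of the matching behaviour (hypothetical points): the Hungarian
# assignment pairs each cloud point with a distinct mesh vertex whenever
# the distance stays below the threshold.
#
#     pc = np.array([[0., 0., 0.], [0.02, 0., 0.]])
#     mesh = np.array([[0.001, 0., 0.], [0.021, 0., 0.]])
#     _, col = get_observable_particle_index_3(pc, mesh)
#     assert (col == [0, 1]).all()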
178 |
179 | def get_mapping_from_pointcloud_to_partile_nearest_neighbor(pointcloud, particle):
180 | distance = scipy.spatial.distance.cdist(pointcloud, particle)
181 | nearest_idx = np.argmin(distance, axis=1)
182 | return nearest_idx
183 |
184 |
185 | def get_observable_particle_index_4(pointcloud, mesh, threshold=0.0216):
186 | # perform the matching of pixel particle to real particle
187 | estimated_world_coords = pointcloud
188 |
189 | distance = scipy.spatial.distance.cdist(estimated_world_coords, mesh)
190 | estimated_particle_idx = np.argmin(distance, axis=1)
191 |
192 | return pointcloud, np.array(estimated_particle_idx, dtype=np.int32)
193 |
194 |
195 | def get_observable_particle_pos(world_coords, particle_pos):
196 | # perform the matching of pixel particle to real particle
197 | particle_pos = particle_pos[:, :3]
198 | distance = scipy.spatial.distance.cdist(world_coords, particle_pos)
199 | estimated_particle_idx = np.argmin(distance, axis=1)
200 | observable_particle_pos = particle_pos[estimated_particle_idx]
201 |
202 | return observable_particle_pos
203 |
204 |
205 | def get_matrix_world_to_camera(cam_pos=[-0.0, 0.82, 0.82], cam_angle=[0, -45 / 180. * np.pi, 0.]):
206 | cam_x, cam_y, cam_z = cam_pos[0], cam_pos[1], \
207 | cam_pos[2]
208 | cam_x_angle, cam_y_angle, cam_z_angle = cam_angle[0], cam_angle[1], \
209 | cam_angle[2]
210 |
211 | # get rotation matrix: from world to camera
212 | matrix1 = get_rotation_matrix(- cam_x_angle, [0, 1, 0])
213 | matrix2 = get_rotation_matrix(- cam_y_angle - np.pi, [1, 0, 0])
214 | rotation_matrix = matrix2 @ matrix1
215 |
216 | # get translation matrix: from world to camera
217 | translation_matrix = np.zeros((4, 4))
218 | translation_matrix[0][0] = 1
219 | translation_matrix[1][1] = 1
220 | translation_matrix[2][2] = 1
221 | translation_matrix[3][3] = 1
222 | translation_matrix[0][3] = - cam_x
223 | translation_matrix[1][3] = - cam_y
224 | translation_matrix[2][3] = - cam_z
225 |
226 | return rotation_matrix @ translation_matrix
227 |
228 |
229 | def project_to_image(matrix_world_to_camera, world_coordinate, height=360, width=360):
230 | world_coordinate = np.concatenate([world_coordinate, np.ones((len(world_coordinate), 1))], axis=1) # n x 4
231 | camera_coordinate = matrix_world_to_camera @ world_coordinate.T # 3 x n
232 | camera_coordinate = camera_coordinate.T # n x 3
233 |     K = intrinsic_from_fov(height, width, 45) # the fov is 45 degrees
234 |
235 | u0 = K[0, 2]
236 | v0 = K[1, 2]
237 | fx = K[0, 0]
238 | fy = K[1, 1]
239 |
240 | x, y, depth = camera_coordinate[:, 0], camera_coordinate[:, 1], camera_coordinate[:, 2]
241 | u = (x * fx / depth + u0).astype("int")
242 | v = (y * fy / depth + v0).astype("int")
243 |
244 | return u, v
245 |
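# Round-trip sketch (not part of the original file): project a point near
# the origin with the default camera defined above and check that it lands
# inside a 360 x 360 image.
#
#     m = get_matrix_world_to_camera()
#     u, v = project_to_image(m, np.array([[0., 0.05, 0.]]), 360, 360)
#     assert 0 <= u[0] < 360 and 0 <= v[0] < 360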
246 |
247 | def _get_depth(matrix, vec, height):
248 | """ Get the depth such that the back-projected point has a fixed height"""
249 | return (height - matrix[1, 3]) / (vec[0] * matrix[1, 0] + vec[1] * matrix[1, 1] + matrix[1, 2])
250 |
251 |
252 | def get_world_coor_from_image(u, v, image_size, matrix_world_to_camera, all_depth):
253 | height, width = image_size
254 |     K = intrinsic_from_fov(height, width, 45) # the fov is 45 degrees
255 |
256 | matrix = np.linalg.inv(matrix_world_to_camera)
257 |
258 | u0, v0, fx, fy = K[0, 2], K[1, 2], K[0, 0], K[1, 1]
259 |
260 | depth = all_depth[v][u]
261 | if depth == 0:
262 | vec = ((u - u0) / fx, (v - v0) / fy)
263 | depth = _get_depth(matrix, vec, 0.00625) # Height to be the particle radius
264 |
265 | x = (u - u0) * depth / fx
266 | y = (v - v0) * depth / fy
267 | z = depth
268 | cam_coords = np.array([x, y, z, 1])
269 |     cam_coords = cam_coords.reshape((-1, 4)).transpose()  # 4 x 1 (a single back-projected point)
270 | 
271 |     world_coord = matrix @ cam_coords  # 4 x 1
272 | world_coord = world_coord.reshape(4)
273 | return world_coord[:3]
274 |
275 |
276 | def get_target_pos(pos, u, v, image_size, matrix_world_to_camera, depth):
277 | coor = get_world_coor_from_image(u, v, image_size, matrix_world_to_camera, depth)
278 | dists = cdist(coor[None], pos)[0]
279 | idx = np.argmin(dists)
280 | return pos[idx] + np.array([0, 0.01, 0])
281 |
--------------------------------------------------------------------------------
/VCD/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch_geometric.data import Data
4 | import torch_geometric
5 |
6 |
7 | class PrivilData(Data):
8 | """
9 | Encapsulation of multi-graphs for multi-step training
10 | ind: 0-(hor-1), type: vsbl or full
11 | Each graph contain:
12 | edge_index_{type}_{ind},
13 | x_{type}_{ind},
14 | edge_attr_{type}_{ind},
15 | gt_rwd_{type}_{ind}
16 | gt_accel_{type}_{ind}
17 | mesh_mapping_{type}_{ind}
18 | """
19 |
20 | def __init__(self, has_part=False, has_full=False, **kwargs):
21 | super(PrivilData, self).__init__(**kwargs)
22 | self.has_part = has_part
23 | self.has_full = has_full
24 |
25 | def __inc__(self, key, value, *args, **kwargs):
26 | if 'edge_index' in key:
27 | x = key.replace('edge_index', 'x')
28 | return self[x].size(0)
29 | elif 'mesh_mapping' in key:
30 |             # offset the mesh-mapping indices by the node count of the matching x tensor
31 |             x = key.replace('mesh_mapping', 'x')
32 | return self[x].size(0)
33 | else:
34 |             return super().__inc__(key, value, *args, **kwargs)
35 |
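# A minimal batching sketch (hypothetical tensors): keys carry a
# `{type}_{ind}` suffix, and `__inc__` offsets every `edge_index_*` by the
# node count of the matching `x_*` tensor so graphs concatenate safely.
#
#     data = PrivilData(
#         x_vsbl_0=torch.zeros(5, 3),
#         edge_index_vsbl_0=torch.zeros(2, 4, dtype=torch.long),
#     )
#     assert data.__inc__('edge_index_vsbl_0', data.edge_index_vsbl_0) == 5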
36 |
37 | class AggDict(dict):
38 | def __init__(self, is_detach=True):
39 | """
40 | Aggregate numpy arrays or pytorch tensors
41 |         :param is_detach: Whether to save numpy arrays instead of torch tensors
42 | """
43 |         super(AggDict, self).__init__()
44 | self.is_detach = is_detach
45 |
46 | def __getitem__(self, item):
47 | return self.get(item, 0)
48 |
49 | def add_item(self, key, value):
50 | if self.is_detach and torch.is_tensor(value):
51 | value = value.detach().cpu().numpy()
52 | if not isinstance(value, torch.Tensor):
53 | if isinstance(value, np.ndarray) or isinstance(value, np.number):
54 | assert value.size == 1
55 | else:
56 | assert isinstance(value, int) or isinstance(value, float)
57 | if key not in self.keys():
58 | self[key] = value
59 | else:
60 | self[key] += value
61 |
62 | def update_by_add(self, src_dict):
63 | for key, value in src_dict.items():
64 | self.add_item(key, value)
65 |
66 | def get_mean(self, prefix, count=1):
67 | avg_dict = {}
68 | for k, v in self.items():
69 | avg_dict[prefix + k] = v / count
70 | return avg_dict
71 |
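# Usage sketch: accumulate per-iteration scalars and report the epoch mean.
#
#     info = AggDict(is_detach=True)
#     for _ in range(10):
#         info.add_item('loss', 0.5)  # tensors would be detached to scalars
#     assert info.get_mean('train/', count=10) == {'train/loss': 0.5}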
72 |
73 | def updateDictByAdd(dict1, dict2):
74 | '''
75 |     Add the (tensor) values of nested dict2 into dict1 in place
76 | '''
77 | for k1, v1 in dict2.items():
78 | for k2, v2 in v1.items():
79 | dict1[k1][k2] += v2.cpu().item()
80 | return dict1
81 |
82 |
83 | def get_index_before_padding(graph_sizes):
84 | ins_len = graph_sizes.max()
85 | pad_len = ins_len * graph_sizes.size(0)
86 | valid_len = graph_sizes.sum()
87 | accum = torch.zeros(1).cuda()
88 | out = []
89 | for gs in graph_sizes:
90 |         new_ind = torch.arange(int(gs)).cuda() + accum  # torch.range is deprecated
91 | out.append(new_ind)
92 | accum += ins_len
93 | final_ind = torch.cat(out, dim=0)
94 | return final_ind.long()
95 |
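# Worked example (the function runs on CUDA as written): for graph sizes
# [2, 3] padded to the max length 3, the valid rows sit at flat indices
# [0, 1] (graph 0) and [3, 4, 5] (graph 1), so the returned index is
# tensor([0, 1, 3, 4, 5]).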
96 |
97 | class MyDataParallel(torch_geometric.nn.DataParallel):
98 | def __init__(self, *args, **kwargs):
99 | super().__init__(*args, **kwargs)
100 |
101 | def __getattr__(self, name):
102 | if name == 'module':
103 | return self._modules['module']
104 | else:
105 | return getattr(self.module, name)
106 |
107 |
108 | def retrieve_data(data, key):
109 | """
110 | vsbl: [vsbl], full: [full], dual :[vsbl, full]
111 | """
112 | if isinstance(data, dict):
113 | identifier = '_{}'.format(key)
114 | out_data = {k.replace(identifier, ''): v for k, v in data.items() if identifier in k}
115 | return out_data
116 |
--------------------------------------------------------------------------------
/VCD/utils/gemo_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pyflex
3 |
4 |
5 | def pixel_coord_np(width, height):
6 | """
7 | Pixel in homogenous coordinate
8 | Returns:
9 | Pixel coordinate: [3, width * height]
10 | """
11 |     x = np.linspace(0, width - 1, width).astype(int)
12 |     y = np.linspace(0, height - 1, height).astype(int)
13 | [x, y] = np.meshgrid(x, y)
14 | return np.vstack((x.flatten(), y.flatten(), np.ones_like(x.flatten())))
15 |
16 |
17 | def intrinsic_from_fov(height, width, fov=90):
18 | """
19 | Basic Pinhole Camera Model
20 | intrinsic params from fov and sensor width and height in pixels
21 | Returns:
22 | K: [4, 4]
23 | """
24 | px, py = (width / 2, height / 2)
25 | hfov = fov / 360. * 2. * np.pi
26 | fx = width / (2. * np.tan(hfov / 2.))
27 |
28 | vfov = 2. * np.arctan(np.tan(hfov / 2) * height / width)
29 | fy = height / (2. * np.tan(vfov / 2.))
30 |
31 | return np.array([[fx, 0, px, 0.],
32 | [0, fy, py, 0.],
33 | [0, 0, 1., 0.],
34 | [0., 0., 0., 1.]])
35 |
36 |
37 | def get_rotation_matrix(angle, axis):
38 | axis = axis / np.linalg.norm(axis)
39 | s = np.sin(angle)
40 | c = np.cos(angle)
41 |
42 | m = np.zeros((4, 4))
43 |
44 | m[0][0] = axis[0] * axis[0] + (1.0 - axis[0] * axis[0]) * c
45 | # m[0][1] = axis[0] * axis[1] * (1.0 - c) + axis[2] * s
46 | m[0][1] = axis[0] * axis[1] * (1.0 - c) - axis[2] * s
47 | # m[0][2] = axis[0] * axis[2] * (1.0 - c) - axis[1] * s
48 | m[0][2] = axis[0] * axis[2] * (1.0 - c) + axis[1] * s
49 | m[0][3] = 0.0
50 |
51 | # m[1][0] = axis[0] * axis[1] * (1.0 - c) - axis[2] * s
52 | m[1][0] = axis[0] * axis[1] * (1.0 - c) + axis[2] * s
53 | m[1][1] = axis[1] * axis[1] + (1.0 - axis[1] * axis[1]) * c
54 | # m[1][2] = axis[1] * axis[2] * (1.0 - c) + axis[0] * s
55 | m[1][2] = axis[1] * axis[2] * (1.0 - c) - axis[0] * s
56 | m[1][3] = 0.0
57 |
58 | # m[2][0] = axis[0] * axis[2] * (1.0 - c) + axis[1] * s
59 | m[2][0] = axis[0] * axis[2] * (1.0 - c) - axis[1] * s
60 | # m[2][1] = axis[1] * axis[2] * (1.0 - c) - axis[0] * s
61 | m[2][1] = axis[1] * axis[2] * (1.0 - c) + axis[0] * s
62 | m[2][2] = axis[2] * axis[2] + (1.0 - axis[2] * axis[2]) * c
63 | m[2][3] = 0.0
64 |
65 | m[3][0] = 0.0
66 | m[3][1] = 0.0
67 | m[3][2] = 0.0
68 | m[3][3] = 1.0
69 |
70 | return m
71 |
72 |
73 | def get_world_coords(rgb, depth, env):
74 | height, width, _ = rgb.shape
75 |     K = intrinsic_from_fov(height, width, 45) # the fov is 45 degrees
76 |
77 | # Apply back-projection: K_inv @ pixels * depth
78 | cam_coords = np.ones((height, width, 4))
79 | u0 = K[0, 2]
80 | v0 = K[1, 2]
81 | fx = K[0, 0]
82 | fy = K[1, 1]
83 | # Loop through each pixel in the image
84 | for v in range(height):
85 | for u in range(width):
86 | # Apply equation in fig 3
87 | x = (u - u0) * depth[v, u] / fx
88 | y = (v - v0) * depth[v, u] / fy
89 | z = depth[v, u]
90 | cam_coords[v][u][:3] = (x, y, z)
91 |
92 | particle_pos = pyflex.get_positions().reshape((-1, 4))
93 | print('cloth pixels: ', np.count_nonzero(depth))
94 | print("cloth particle num: ", pyflex.get_n_particles())
95 |
96 | # debug: print camera coordinates
97 | # print(cam_coords.shape)
98 | # cnt = 0
99 | # for v in range(height):
100 | # for u in range(width):
101 | # if depth[v][u] > 0:
102 | # print("v: {} u: {} cnt: {} cam_coord: {} approximate particle pos: {}".format(
103 | # v, u, cnt, cam_coords[v][u], particle_pos[cnt]))
104 | # rgb = rgbd[:, :, :3].copy()
105 | # rgb[v][u][0] = 255
106 | # rgb[v][u][1] = 0
107 | # rgb[v][u][2] = 0
108 | # cv2.imshow('rgb', rgb[:, :, ::-1])
109 | # cv2.waitKey()
110 | # cnt += 1
111 |
112 | # from cam coord to world coord
113 | cam_x, cam_y, cam_z = env.camera_params['default_camera']['pos'][0], env.camera_params['default_camera']['pos'][1], \
114 | env.camera_params['default_camera']['pos'][2]
115 | cam_x_angle, cam_y_angle, cam_z_angle = env.camera_params['default_camera']['angle'][0], env.camera_params['default_camera']['angle'][1], \
116 | env.camera_params['default_camera']['angle'][2]
117 |
118 | # get rotation matrix: from world to camera
119 | matrix1 = get_rotation_matrix(- cam_x_angle, [0, 1, 0])
120 | # matrix2 = get_rotation_matrix(- cam_y_angle - np.pi, [np.cos(cam_x_angle), 0, np.sin(cam_x_angle)])
121 | matrix2 = get_rotation_matrix(- cam_y_angle - np.pi, [1, 0, 0])
122 | rotation_matrix = matrix2 @ matrix1
123 |
124 | # get translation matrix: from world to camera
125 | translation_matrix = np.zeros((4, 4))
126 | translation_matrix[0][0] = 1
127 | translation_matrix[1][1] = 1
128 | translation_matrix[2][2] = 1
129 | translation_matrix[3][3] = 1
130 | translation_matrix[0][3] = - cam_x
131 | translation_matrix[1][3] = - cam_y
132 | translation_matrix[2][3] = - cam_z
133 |
134 | # debug: from world to camera
135 | cloth_x, cloth_y = env.current_config['ClothSize'][0], env.current_config['ClothSize'][1]
136 | # cnt = 0
137 | # for u in range(height):
138 | # for v in range(width):
139 | # if depth[u][v] > 0:
140 | # world_coord = np.ones(4)
141 | # world_coord[:3] = particle_pos[cnt][:3]
142 | # convert_cam_coord = rotation_matrix @ translation_matrix @ world_coord
143 | # # convert_cam_coord = translation_matrix @ matrix2 @ matrix1 @ world_coord
144 | # print("u {} v {} \n world coord {} \n convert camera coord {} \n real camera coord {}".format(
145 | # u, v, world_coord, convert_cam_coord, cam_coords[u][v]
146 | # ))
147 | # cnt += 1
148 | # input('wait...')
149 |
150 | # convert the camera coordinate back to the world coordinate using the rotation and translation matrix
151 | cam_coords = cam_coords.reshape((-1, 4)).transpose() # 4 x (height x width)
152 | world_coords = np.linalg.inv(rotation_matrix @ translation_matrix) @ cam_coords # 4 x (height x width)
153 | world_coords = world_coords.transpose().reshape((height, width, 4))
154 |
155 | # roughly check the final world coordinate with the actual coordinate
156 | # firstu = 0
157 | # firstv = 0
158 | # for u in range(height):
159 | # for v in range(width):
160 | # if depth[u][v]:
161 | # if u > firstu: # move to a new line
162 | # firstu = u
163 | # firstv = v
164 |
165 | # cnt = (u - firstu) * cloth_x + (v - firstv)
166 | # print("u {} v {} cnt{}\nworld_coord\t{}\nparticle coord\t{}\nerror\t{}".format(
167 | # u, v, cnt, world_coords[u][v], particle_pos[cnt], np.linalg.norm( world_coords[u][v] - particle_pos[cnt])))
168 | # rgb = rgbd[:, :, :3].copy()
169 | # rgb[u][v][0] = 255
170 | # rgb[u][v][1] = 0
171 | # rgb[u][v][2] = 0
172 | # cv2.imshow('rgb', rgb[:, :, ::-1])
173 | # cv2.waitKey()
174 | # exit()
175 | return world_coords
176 |
177 |
178 | def get_observable_particle_index(world_coords, particle_pos, rgb, depth):
179 | height, width, _ = rgb.shape
180 | # perform the matching of pixel particle to real particle
181 | observable_particle_idxes = []
182 | particle_pos = particle_pos[:, :3]
183 | for u in range(height):
184 | for v in range(width):
185 | if depth[u][v] > 0:
186 | estimated_world_coord = world_coords[u][v][:3]
187 | distance = np.linalg.norm(estimated_world_coord - particle_pos, axis=1)
188 | estimated_particle_idx = np.argmin(distance)
189 | # print("u {} v {} estimated particle idx {}".format(u, v, estimated_particle_idx))
190 | observable_particle_idxes.append(estimated_particle_idx)
191 | # rgb = rgbd[:, :, :3].copy()
192 | # rgb[u][v][0] = 255
193 | # rgb[u][v][1] = 0
194 | # rgb[u][v][2] = 0
195 | # cv2.imshow('chosen_idx', rgb[:, :, ::-1])
196 | # cv2.waitKey()
197 | # exit()
198 | return observable_particle_idxes
199 |
--------------------------------------------------------------------------------
/VCD/utils/plot_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def seg_3d_figure(data: np.ndarray, labels: np.ndarray, labelmap=None, sizes=None, fig=None):
5 | import plotly.colors as pc
6 | import plotly.graph_objects as go
7 | from plotly.subplots import make_subplots
8 | import plotly.express as px
9 | import plotly.figure_factory as ff
10 |
11 | # Create a figure.
12 | if fig is None:
13 | fig = go.Figure()
14 |
15 | # Find the ranges for visualizing.
16 | mean = data.mean(axis=0)
17 | max_x = np.abs(data[:, 0] - mean[0]).max()
18 | max_y = np.abs(data[:, 1] - mean[1]).max()
19 | max_z = np.abs(data[:, 2] - mean[2]).max()
20 | all_max = max(max(max_x, max_y), max_z)
21 |
22 | # Colormap.
23 | cols = np.array(pc.qualitative.Alphabet)
24 | labels = labels.astype(int)
25 | for label in np.unique(labels):
26 | subset = data[np.where(labels == label)]
27 | subset = np.squeeze(subset)
28 | if sizes is None:
29 | subset_sizes = 1.5
30 | else:
31 | subset_sizes = sizes[np.where(labels == label)]
32 | color = cols[label % len(cols)]
33 | if labelmap is not None:
34 | legend = labelmap[label]
35 | else:
36 | legend = str(label)
37 | fig.add_trace(
38 | go.Scatter3d(
39 | mode="markers",
40 | marker={"size": subset_sizes, "color": color, "line": {"width": 0}},
41 | x=subset[:, 0],
42 | y=subset[:, 1],
43 | z=subset[:, 2],
44 | name=legend,
45 | )
46 | )
47 | fig.update_layout(showlegend=True)
48 |
49 | # This sets the figure to be a cube centered at the center of the pointcloud, such that it fits
50 | # all the points.
51 | fig.update_layout(
52 | scene=dict(
53 | xaxis=dict(nticks=10, range=[mean[0] - all_max, mean[0] + all_max]),
54 | yaxis=dict(nticks=10, range=[mean[1] - all_max, mean[1] + all_max]),
55 | zaxis=dict(nticks=10, range=[mean[2] - all_max, mean[2] + all_max]),
56 | aspectratio=dict(x=1, y=1, z=1),
57 | ),
58 | margin=dict(l=0, r=0, b=0, t=40),
59 | legend=dict(x=1.0, y=0.75),
60 | )
61 | return fig
62 |
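# Usage sketch (hypothetical data): color a random point cloud by a
# two-class label and open the interactive plotly figure.
#
#     pts = np.random.rand(200, 3)
#     labels = (pts[:, 0] > 0.5).astype(int)
#     fig = seg_3d_figure(pts, labels, labelmap={0: 'left', 1: 'right'})
#     fig.show()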
--------------------------------------------------------------------------------
/VCD/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import numpy as np
3 | import cv2
4 | import torch
5 | from torchvision.utils import make_grid
6 | from VCD.utils.camera_utils import project_to_image
7 | import pyflex
8 | import re
9 | import h5py
10 | import os
11 | from softgym.utils.visualization import save_numpy_as_gif
12 | from chester import logger
13 | import random
14 |
15 |
16 | class VArgs(object):
17 | def __init__(self, vv):
18 | for key, val in vv.items():
19 | setattr(self, key, val)
20 |
21 |
22 | def vv_to_args(vv):
23 | args = VArgs(vv)
24 | return args
25 |
26 |
27 | # Function to extract all the numbers from the given string
28 | def extract_numbers(s):
29 |     array = [int(x) for x in re.findall(r'[0-9]+', s)]
30 | if len(array) == 0:
31 | return [0]
32 | return array
33 |
34 |
35 | ################## Pointcloud Processing #################
36 | import pcl
37 |
38 |
39 | # def get_partial_particle(full_particle, observable_idx):
40 | # return np.array(full_particle[observable_idx], dtype=np.float32)
41 |
42 |
43 | def voxelize_pointcloud(pointcloud, voxel_size):
44 | cloud = pcl.PointCloud(pointcloud)
45 | sor = cloud.make_voxel_grid_filter()
46 | sor.set_leaf_size(voxel_size, voxel_size, voxel_size)
47 | pointcloud = sor.filter()
48 | pointcloud = np.asarray(pointcloud).astype(np.float32)
49 | return pointcloud
50 |
51 |
52 | from softgym.utils.misc import vectorized_range, vectorized_meshgrid
53 |
54 |
55 | def pc_reward_model(pos, cloth_particle_radius=0.00625, downsample_scale=3):
56 | cloth_particle_radius *= downsample_scale
57 | pos = np.reshape(pos, [-1, 3])
58 | min_x = np.min(pos[:, 0])
59 | min_y = np.min(pos[:, 2])
60 | max_x = np.max(pos[:, 0])
61 | max_y = np.max(pos[:, 2])
62 | init = np.array([min_x, min_y])
63 | span = np.array([max_x - min_x, max_y - min_y]) / 100.
64 | pos2d = pos[:, [0, 2]]
65 |
66 | offset = pos2d - init
67 | slotted_x_low = np.maximum(np.round((offset[:, 0] - cloth_particle_radius) / span[0]).astype(int), 0)
68 | slotted_x_high = np.minimum(np.round((offset[:, 0] + cloth_particle_radius) / span[0]).astype(int), 100)
69 | slotted_y_low = np.maximum(np.round((offset[:, 1] - cloth_particle_radius) / span[1]).astype(int), 0)
70 | slotted_y_high = np.minimum(np.round((offset[:, 1] + cloth_particle_radius) / span[1]).astype(int), 100)
71 |
72 | grid = np.zeros(10000) # Discretization
73 | listx = vectorized_range(slotted_x_low, slotted_x_high)
74 | listy = vectorized_range(slotted_y_low, slotted_y_high)
75 | listxx, listyy = vectorized_meshgrid(listx, listy)
76 | idx = listxx * 100 + listyy
77 | idx = np.clip(idx.flatten(), 0, 9999)
78 | grid[idx] = 1
79 |
80 | res = np.sum(grid) * span[0] * span[1]
81 | return res
82 |
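# The reward is the planar area covered by the cloth: the xz bounding box
# is discretized into a 100 x 100 grid, every cell touched by a particle
# (inflated by the particle radius) is marked, and the marked count is
# scaled back to physical area. A usage sketch with hypothetical positions:
#
#     pos = np.random.rand(500, 3) * np.array([0.3, 0.0, 0.3])
#     area = pc_reward_model(pos)  # approximate covered area in m^2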
83 |
84 | ################## IO #################################
85 | def downsample(cloth_xdim, cloth_ydim, scale):
86 | cloth_xdim, cloth_ydim = int(cloth_xdim), int(cloth_ydim)
87 | new_idx = np.arange(cloth_xdim * cloth_ydim).reshape((cloth_ydim, cloth_xdim))
88 | new_idx = new_idx[::scale, ::scale]
89 | cloth_ydim, cloth_xdim = new_idx.shape
90 | new_idx = new_idx.flatten()
91 |
92 | return new_idx, cloth_xdim, cloth_ydim
93 |
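# Example: downsampling a 4 x 4 particle grid by a factor of 2 keeps every
# other row and column of the index grid.
#
#     idx, xdim, ydim = downsample(4, 4, 2)
#     assert (idx == [0, 2, 8, 10]).all() and xdim == ydim == 2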
94 |
95 | def load_h5_data(data_names, path):
96 | hf = h5py.File(path, 'r')
97 | data = {}
98 | for name in data_names:
99 | d = np.array(hf.get(name))
100 | data[name] = d
101 | hf.close()
102 | return data
103 |
104 |
105 | def store_h5_data(data_names, data, path):
106 | hf = h5py.File(path, 'w')
107 | for name in data_names:
108 | hf.create_dataset(name, data=data[name])
109 | hf.close()
110 |
111 |
112 | def load_data(data_dir, idx_rollout, idx_timestep, data_names):
113 | data_path = os.path.join(data_dir, str(idx_rollout), str(idx_timestep) + '.h5')
114 | return load_h5_data(data_names, data_path)
115 |
116 |
117 | def load_data_list(data_dir, idx_rollout, idx_timestep, data_names):
118 | data_path = os.path.join(data_dir, str(idx_rollout), str(idx_timestep) + '.h5')
119 | d = load_h5_data(data_names, data_path)
120 | return [d[name] for name in data_names]
121 |
122 |
123 | def store_data():
124 | raise NotImplementedError
125 |
126 |
127 | def transform_info(all_infos):
128 | """ Input: All info is a nested list with the index of [episode][time]{info_key:info_value}
129 | Output: transformed_infos is a dictionary with the index of [info_key][episode][time]
130 | """
131 | if len(all_infos) == 0:
132 | return []
133 | transformed_info = {}
134 | num_episode = len(all_infos)
135 | T = len(all_infos[0])
136 |
137 | for info_name in all_infos[0][0].keys():
138 | infos = np.zeros([num_episode, T], dtype=np.float32)
139 | for i in range(num_episode):
140 | infos[i, :] = np.array([info[info_name] for info in all_infos[i]])
141 | transformed_info[info_name] = infos
142 | return transformed_info
143 |
144 |
145 | def draw_grid(list_of_imgs, nrow, padding=10, pad_value=200):
146 | img_list = torch.from_numpy(np.array(list_of_imgs).transpose(0, 3, 1, 2))
147 | img = make_grid(img_list, nrow=nrow, padding=padding, pad_value=pad_value)
148 | # print(img.shape)
149 | img = img.numpy().transpose(1, 2, 0)
150 | return img
151 |
152 |
153 | def inrange(x, low, high):
154 | if x >= low and x < high:
155 | return True
156 | else:
157 | return False
158 |
159 |
160 | ################## Visualization ######################
161 |
162 | def draw_edge(frame, predicted_edges, matrix_world_to_camera, pointcloud, camera_height, camera_width):
163 | u, v = project_to_image(matrix_world_to_camera, pointcloud, camera_height, camera_width)
164 | for edge_idx in range(predicted_edges.shape[1]):
165 | s = predicted_edges[0][edge_idx]
166 | r = predicted_edges[1][edge_idx]
167 | start = (u[s], v[s])
168 | end = (u[r], v[r])
169 | color = (255, 0, 0)
170 | thickness = 1
171 | image = cv2.line(frame, start, end, color, thickness)
172 |
173 | return image
174 |
175 |
176 | def cem_make_gif(all_frames, save_dir, save_name):
177 | # Convert to T x index x C x H x W for pytorch
178 | all_frames = np.array(all_frames).transpose([1, 0, 4, 2, 3])
179 | grid_imgs = [make_grid(torch.from_numpy(frame), nrow=5).permute(1, 2, 0).data.cpu().numpy() for frame in all_frames]
180 | save_numpy_as_gif(np.array(grid_imgs), osp.join(save_dir, save_name))
181 |
182 |
183 | def draw_policy_action(obs_before, obs_after, start_loc_1, end_loc_1, matrix_world_to_camera, start_loc_2=None, end_loc_2=None):
184 | height, width, _ = obs_before.shape
185 | if start_loc_2 is not None:
186 | l = [(start_loc_1, end_loc_1), (start_loc_2, end_loc_2)]
187 | else:
188 | l = [(start_loc_1, end_loc_1)]
189 | for (start_loc, end_loc) in l:
190 | # print(start_loc, end_loc)
191 | suv = project_to_image(matrix_world_to_camera, start_loc.reshape((1, 3)), height, width)
192 | su, sv = suv[0][0], suv[1][0]
193 | euv = project_to_image(matrix_world_to_camera, end_loc.reshape((1, 3)), height, width)
194 | eu, ev = euv[0][0], euv[1][0]
195 | if inrange(su, 0, width) and inrange(sv, 0, height) and inrange(eu, 0, width) and inrange(ev, 0, height):
196 | cv2.arrowedLine(obs_before, (su, sv), (eu, ev), (255, 0, 0), 3)
197 | obs_before[sv - 5:sv + 5, su - 5:su + 5, :] = (0, 0, 0)
198 |
199 | res = np.concatenate((obs_before, obs_after), axis=1)
200 | return res
201 |
202 |
203 | def draw_planned_actions(save_idx, obses, start_poses, end_poses, matrix_world_to_camera, log_dir):
204 | height = width = obses[0].shape[0]
205 |
206 | start_uv = []
207 | end_uv = []
208 | for sp in start_poses:
209 | suv = project_to_image(matrix_world_to_camera, sp.reshape((1, 3)), height, width)
210 | start_uv.append((suv[0][0], suv[1][0]))
211 | for ep in end_poses:
212 | euv = project_to_image(matrix_world_to_camera, ep.reshape((1, 3)), height, width)
213 | end_uv.append((euv[0][0], euv[1][0]))
214 |
215 | res = []
216 | for idx in range(len(obses) - 1):
217 | obs = obses[idx]
218 | su, sv = start_uv[idx]
219 | eu, ev = end_uv[idx]
220 | if inrange(su, 0, width) and inrange(sv, 0, height) and inrange(eu, 0, width) and inrange(ev, 0, height):
221 | cv2.arrowedLine(obs, (su, sv), (eu, ev), (255, 0, 0), 3)
222 | obs[sv - 5:sv + 5, su - 5:su + 5, :] = (0, 0, 0)
223 | res.append(obs)
224 |
225 | res.append(obses[-1])
226 | res = np.concatenate(res, axis=1)
227 | cv2.imwrite(osp.join(log_dir, '{}_planned.png'.format(save_idx)), res[:, :, ::-1])
228 |
229 |
230 | def draw_cem_elites(obs_, start_poses, end_poses, mean_start_pos, mean_end_pos,
231 | matrix_world_to_camera, log_dir, save_idx=None):
232 | obs = obs_.copy()
233 | start_uv = []
234 | end_uv = []
235 | height = width = obs.shape[0]
236 | for sp in start_poses:
237 | suv = project_to_image(matrix_world_to_camera, sp.reshape((1, 3)), height, width)
238 | start_uv.append((suv[0][0], suv[1][0]))
239 | for ep in end_poses:
240 | euv = project_to_image(matrix_world_to_camera, ep.reshape((1, 3)), height, width)
241 | end_uv.append((euv[0][0], euv[1][0]))
242 |
243 | for idx in range(len(start_poses)):
244 | su, sv = start_uv[idx]
245 | eu, ev = end_uv[idx]
246 | # poses at the front have higher reward
247 | if inrange(su, 0, 255) and inrange(sv, 0, 255) and inrange(eu, 0, 255) and inrange(ev, 0, 255):
248 | cv2.arrowedLine(obs, (su, sv), (eu, ev), (255 * (1 - idx / len(start_poses)), 0, 0), 2)
249 | obs[sv - 2:sv + 2, su - 2:su + 2, :] = (0, 0, 0)
250 |
251 | mean_s_uv = project_to_image(matrix_world_to_camera, mean_start_pos.reshape((1, 3)), height, width)
252 | mean_e_uv = project_to_image(matrix_world_to_camera, mean_end_pos.reshape((1, 3)), height, width)
253 | mean_su, mean_sv = mean_s_uv[0][0], mean_s_uv[1][0]
254 | mean_eu, mean_ev = mean_e_uv[0][0], mean_e_uv[1][0]
255 |
256 | if inrange(mean_su, 0, 255) and inrange(mean_sv, 0, 255) and \
257 | inrange(mean_eu, 0, 255) and inrange(mean_ev, 0, 255):
258 | cv2.arrowedLine(obs, (mean_su, mean_sv), (mean_eu, mean_ev), (0, 0, 255), 3)
259 |         obs[mean_sv - 5:mean_sv + 5, mean_su - 5:mean_su + 5, :] = (0, 0, 0)
260 | if save_idx is not None:
261 | cv2.imwrite(osp.join(log_dir, '{}_elite.png'.format(save_idx)), obs)
262 | return obs
263 |
264 |
265 | def set_shape_pos(pos):
266 | shape_states = np.array(pyflex.get_shape_states()).reshape(-1, 14)
267 | shape_states[:, 3:6] = pos.reshape(-1, 3)
268 | shape_states[:, :3] = pos.reshape(-1, 3)
269 | pyflex.set_shape_states(shape_states)
270 |
271 |
272 | def visualize(env, particle_positions, shape_positions, config_id, sample_idx=None, picked_particles=None, show=False):
273 | """ Render point cloud trajectory without running the simulation dynamics"""
274 | env.reset(config_id=config_id)
275 | frames = []
276 | for i in range(len(particle_positions)):
277 | particle_pos = particle_positions[i]
278 | shape_pos = shape_positions[i]
279 | p = pyflex.get_positions().reshape(-1, 4)
280 | p[:, :3] = [0., -0.1, 0.] # All particles moved underground
281 | if sample_idx is None:
282 | p[:len(particle_pos), :3] = particle_pos
283 | else:
284 | p[:, :3] = [0, -0.1, 0]
285 | p[sample_idx, :3] = particle_pos
286 | pyflex.set_positions(p)
287 | set_shape_pos(shape_pos)
288 | rgb = env.get_image(env.camera_width, env.camera_height)
289 | frames.append(rgb)
290 | if show:
291 | if i == 0: continue
292 | picked_point = picked_particles[i]
293 | phases = np.zeros(pyflex.get_n_particles())
294 |             for pid in picked_point:
295 |                 if pid != -1:
296 |                     phases[sample_idx[int(pid)]] = 1
297 | pyflex.set_phases(phases)
298 | img = env.get_image()
299 |
300 | cv2.imshow('picked particle images', img[:, :, ::-1])
301 | cv2.waitKey()
302 |
303 | return frames
304 |
305 |
306 | def add_occluded_particles(observable_positions, observable_vel_history, particle_radius=0.00625, neighbor_distance=0.0216):
307 | occluded_idx = np.where(observable_positions[:, 1] > neighbor_distance / 2 + particle_radius)
308 | occluded_positions = []
309 | for o_idx in occluded_idx[0]:
310 | pos = observable_positions[o_idx]
311 | occlude_num = np.floor(pos[1] / neighbor_distance).astype('int')
312 | for i in range(occlude_num):
313 | occluded_positions.append([pos[0], particle_radius + i * neighbor_distance, pos[2]])
314 |
315 | print("add occluded particles num: ", len(occluded_positions))
316 | occluded_positions = np.asarray(occluded_positions, dtype=np.float32).reshape((-1, 3))
317 | occluded_velocity_his = np.zeros((len(occluded_positions), observable_vel_history.shape[1]), dtype=np.float32)
318 |
319 | all_positions = np.concatenate([observable_positions, occluded_positions], axis=0)
320 | all_vel_his = np.concatenate([observable_vel_history, occluded_velocity_his], axis=0)
321 | return all_positions, all_vel_his
322 |
323 |
324 | def sort_pointcloud_for_fold(pointcloud, dim):
325 | pointcloud = list(pointcloud)
326 | sorted_pointcloud = sorted(pointcloud, key=lambda k: (k[0], k[2]))
327 | for idx in range(len(sorted_pointcloud) - 1):
328 | assert sorted_pointcloud[idx][0] < sorted_pointcloud[idx + 1][0] or (
329 | sorted_pointcloud[idx][0] == sorted_pointcloud[idx + 1][0] and
330 | sorted_pointcloud[idx][2] < sorted_pointcloud[idx + 1][2]
331 | )
332 |
333 | real_sorted = []
334 | for i in range(dim):
335 | points_row = sorted_pointcloud[i * dim: (i + 1) * dim]
336 | points_row = sorted(points_row, key=lambda k: k[2])
337 | real_sorted += points_row
338 |
339 | sorted_pointcloud = real_sorted
340 |
341 | return np.asarray(sorted_pointcloud)
342 |
343 |
344 | def get_fold_idx(dim=4):
345 | group_a = []
346 | for i in range(dim - 1):
347 | for j in range(dim - i - 1):
348 | group_a.append(i * dim + j)
349 |
350 | group_b = []
351 | for j in range(dim - 1, 0, -1):
352 | for i in range(dim - 1, dim - 1 - j, -1):
353 | group_b.append(i * dim + j)
354 |
355 | return group_a, group_b
356 |
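# Example: on a 3 x 3 grid (flattened indices 0..8) the two groups are the
# triangles on either side of the anti-diagonal, paired element-wise for
# the fold:
#
#     a, b = get_fold_idx(dim=3)
#     # a == [0, 1, 3], b == [8, 5, 7]  ->  pairs 0<->8, 1<->5, 3<->7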
357 |
358 | ############################ Other ########################
359 | def updateDictByAdd(dict1, dict2):
360 | '''
361 |     Add the (tensor) values of nested dict2 into dict1 in place
362 | '''
363 | for k1, v1 in dict2.items():
364 | for k2, v2 in v1.items():
365 | dict1[k1][k2] += v2.cpu().item()
366 | return dict1
367 |
368 |
369 | def configure_logger(log_dir, exp_name):
370 | # Configure logger
371 | logger.configure(dir=log_dir, exp_name=exp_name)
372 | logdir = logger.get_dir()
373 | assert logdir is not None
374 | os.makedirs(logdir, exist_ok=True)
375 |
376 |
377 | def configure_seed(seed):
378 | # Configure seed
379 | torch.manual_seed(seed)
380 | if torch.cuda.is_available():
381 | torch.cuda.manual_seed_all(seed)
382 |
383 | np.random.seed(seed)
384 | random.seed(seed)
385 |
386 |
387 | ############### for planning ###############################
388 | def set_picker_pos(pos):
389 | shape_states = pyflex.get_shape_states().reshape((-1, 14))
390 | shape_states[1, :3] = -1
391 | shape_states[1, 3:6] = -1
392 |
393 | shape_states[0, :3] = pos
394 | shape_states[0, 3:6] = pos
395 | pyflex.set_shape_states(shape_states)
396 | pyflex.step()
397 |
398 |
399 | def set_resource():
400 | import resource
401 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
402 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
403 |
--------------------------------------------------------------------------------
/VCD/vc_edge.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 | from VCD.models import GNN
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.optim.lr_scheduler import ReduceLROnPlateau
9 |
10 | import os.path as osp
11 |
12 | from chester import logger
13 | from VCD.utils.camera_utils import get_matrix_world_to_camera, project_to_image
14 | import matplotlib.pyplot as plt
15 | import torch_geometric
16 |
17 | from VCD.dataset_edge import ClothDatasetPointCloudEdge
18 | from VCD.utils.utils import extract_numbers
19 | from VCD.utils.data_utils import AggDict
20 | import json
21 | from tqdm import tqdm
22 |
23 |
24 | class VCConnection(object):
25 | def __init__(self, args, env=None):
26 |         self.args, self.env = args, env  # plot() below projects with self.env's camera
27 | self.model = GNN(args, decoder_output_dim=1, name='EdgeGNN') # Predict 0/1 Label for mesh edge classification
28 | self.device = torch.device(self.args.cuda_idx)
29 | self.model.to(self.device)
30 | self.optim = torch.optim.Adam(self.model.param(), lr=args.lr, betas=(args.beta1, 0.999))
31 | self.scheduler = ReduceLROnPlateau(self.optim, 'min', factor=0.8, patience=3, verbose=True)
32 | if self.args.edge_model_path is not None:
33 | self.load_model(self.args.load_optim)
34 |
35 | self.datasets = {phase: ClothDatasetPointCloudEdge(args, 'vsbl', phase, env) for phase in ['train', 'valid']}
36 | follow_batch = 'x_'
37 | self.dataloaders = {
38 | x: torch_geometric.data.DataLoader(
39 | self.datasets[x], batch_size=args.batch_size, follow_batch=follow_batch,
40 | shuffle=True if x == 'train' else False, drop_last=True,
41 | num_workers=args.num_workers, prefetch_factor=8)
42 | for x in ['train', 'valid']
43 | }
44 |
45 | self.log_dir = logger.get_dir()
46 | self.bce_logit_loss = nn.BCEWithLogitsLoss()
47 | self.load_epoch = 0
48 |
49 | def generate_dataset(self):
50 | os.system('mkdir -p ' + self.args.dataf)
51 | for phase in ['train', 'valid']:
52 | self.datasets[phase].generate_dataset()
53 | print('Dataset generated in', self.args.dataf)
54 |
55 | def plot(self, phase, epoch, i):
56 | data_folder = osp.join(self.args.dataf, phase)
57 | traj_ids = np.random.randint(0, len(os.listdir(data_folder)), self.args.plot_num)
58 | step_ids = np.random.randint(self.args.n_his, self.args.time_step - self.args.n_his, self.args.plot_num)
59 | pred_accs, pred_mesh_edges, gt_mesh_edges, edges, positionss, rgbs = [], [], [], [], [], []
60 | for idx, (traj_id, step_id) in enumerate(zip(traj_ids, step_ids)):
61 | pred_mesh_edge, gt_mesh_edge, edge, positions, rgb = self.load_data_and_predict(traj_id, step_id, self.datasets[phase])
62 | pred_acc = np.mean(pred_mesh_edge == gt_mesh_edge)
63 | pred_accs.append(pred_acc)
64 |
65 |             if idx < 3: # plot the first 3 edge predictions
66 | pred_mesh_edges.append(pred_mesh_edge)
67 | gt_mesh_edges.append(gt_mesh_edge)
68 | edges.append(edge)
69 | positionss.append(positions)
70 | rgbs.append(rgb)
71 |
72 | fig = plt.figure(figsize=(30, 30))
73 | for idx in range(min(3, len(positionss))):
74 | pos, edge, pred_mesh_edge, gt_mesh_edge = positionss[idx], edges[idx], pred_mesh_edges[idx], gt_mesh_edges[idx]
75 |
76 | predict_ax = fig.add_subplot(3, 3, idx * 3 + 1, projection='3d')
77 | gt_ax = fig.add_subplot(3, 3, idx * 3 + 2, projection='3d')
78 | both_ax = fig.add_subplot(3, 3, idx * 3 + 3, projection='3d')
79 |
80 | for edge_idx in range(edge.shape[1]):
81 | s = int(edge[0][edge_idx])
82 | r = int(edge[1][edge_idx])
83 | if pred_mesh_edge[edge_idx]:
84 | predict_ax.plot([pos[s, 0], pos[r, 0]], [pos[s, 1], pos[r, 1]], [pos[s, 2], pos[r, 2]], c='r')
85 | both_ax.plot([pos[s, 0], pos[r, 0]], [pos[s, 1], pos[r, 1]], [pos[s, 2], pos[r, 2]], c='r')
86 | if gt_mesh_edge[edge_idx]:
87 | gt_ax.plot([pos[s, 0], pos[r, 0]], [pos[s, 1], pos[r, 1]], [pos[s, 2], pos[r, 2]], c='g')
88 | both_ax.plot([pos[s, 0], pos[r, 0]], [pos[s, 1], pos[r, 1]], [pos[s, 2], pos[r, 2]], c='g')
89 |
90 | gt_ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c='g', s=20)
91 | predict_ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c='r', s=20)
92 | both_ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c='g', s=20)
93 |
94 | plt.savefig(osp.join(self.log_dir, 'edge-prediction-{}-{}-{}.png'.format(phase, epoch, i)))
95 | plt.close('all')
96 |
97 | if rgbs[0] is not None:
98 | fig, axes = plt.subplots(3, 2, figsize=(30, 20))
99 | for idx in range(min(3, len(positionss))):
100 | rgb, gt_mesh_edge, pointcloud = rgbs[idx], gt_mesh_edges[idx], positionss[idx]
101 | pred_mesh_edge, edge = pred_mesh_edges[idx], edges[idx]
102 |
103 | height, width, _ = rgb.shape
104 | matrix_world_to_camera = get_matrix_world_to_camera(
105 | self.env.camera_params[self.env.camera_name]['pos'], self.env.camera_params[self.env.camera_name]['angle']
106 | )
107 | matrix_world_to_camera = matrix_world_to_camera[:3, :] # 3 x 4
108 | u, v = project_to_image(matrix_world_to_camera, pointcloud, height, width)
109 |
110 | predict_ax_2 = axes[idx][0]
111 | true_ax_2 = axes[idx][1]
112 |
113 | predict_ax_2.imshow(rgb)
114 | true_ax_2.imshow(rgb)
115 |
116 | for edge_idx in range(edge.shape[1]):
117 | s = int(edge[0][edge_idx])
118 | r = int(edge[1][edge_idx])
119 | if pred_mesh_edge[edge_idx]:
120 | predict_ax_2.plot([u[s], u[r]], [v[s], v[r]], c='r', linewidth=0.5)
121 | if gt_mesh_edge[edge_idx]:
122 | true_ax_2.plot([u[s], u[r]], [v[s], v[r]], c='r', linewidth=0.5)
123 |
124 | predict_ax_2.set_title("predicted edge on point cloud")
125 | true_ax_2.set_title("mesh edge on particles")
126 | predict_ax_2.scatter(u, v, c='r', s=2)
127 | true_ax_2.scatter(u, v, c='r', s=2)
128 |
129 | plt.savefig(osp.join(self.log_dir, 'edge-projected-{}-{}-{}.png'.format(phase, epoch, i)))
130 | plt.close('all')
131 |
132 | return pred_accs
133 |
134 | def train(self):
135 |
136 | # Training loop
137 | st_epoch = self.load_epoch
138 | best_valid_loss = np.inf
139 | for epoch in range(st_epoch, self.args.n_epoch):
140 | phases = ['train', 'valid'] if self.args.eval == 0 else ['valid']
141 | for phase in phases:
142 | self.set_mode(phase)
143 | epoch_info = AggDict(is_detach=True)
144 |
145 | for i, data in tqdm(enumerate(self.dataloaders[phase]), desc=f'Epoch {epoch}, phase {phase}'):
146 | data = data.to(self.device).to_dict()
147 | iter_info = AggDict(is_detach=False)
148 | last_global = torch.zeros(self.args.batch_size, self.args.global_size, dtype=torch.float32, device=self.device)
149 | with torch.set_grad_enabled(phase == 'train'):
150 | data['u'] = last_global
151 | pred_mesh_edge = self.model(data)
152 |                         loss = self.bce_logit_loss(pred_mesh_edge['mesh_edge'], data['gt_mesh_edge']) # TODO change accel to edge
153 | iter_info.add_item('loss', loss)
154 |
155 | if phase == 'train':
156 | self.optim.zero_grad()
157 | loss.backward()
158 | self.optim.step()
159 |
160 | epoch_info.update_by_add(iter_info)
161 | iter_info.clear()
162 |
163 | epoch_len = len(self.dataloaders[phase])
164 | if i == len(self.dataloaders[phase]) - 1:
165 | avg_dict = epoch_info.get_mean('{}/'.format(phase), epoch_len)
166 | avg_dict['lr'] = self.optim.param_groups[0]['lr']
167 | for k, v in avg_dict.items():
168 | logger.record_tabular(k, v)
169 |
170 | pred_accs = self.plot(phase, epoch, i)
171 |
172 | logger.record_tabular(phase + '/epoch', epoch)
173 | logger.record_tabular(phase + '/pred_acc', np.mean(pred_accs))
174 | logger.dump_tabular()
175 |
176 | if phase == 'train' and i == len(self.dataloaders[phase]) - 1:
177 | suffix = '{}'.format(epoch)
178 | self.model.save_model(self.log_dir, 'vsbl', suffix, self.optim)
179 |
180 | print('%s [%d/%d] Loss: %.4f, Best valid: %.4f' %
181 | (phase, epoch, self.args.n_epoch, avg_dict[f'{phase}/loss'], best_valid_loss))
182 |
183 | if phase == 'valid':
184 | cur_loss = avg_dict[f'{phase}/loss']
185 | self.scheduler.step(cur_loss)
186 | if (cur_loss < best_valid_loss):
187 | best_valid_loss = cur_loss
188 | state_dict = self.args.__dict__
189 | state_dict['best_epoch'] = epoch
190 | state_dict['best_valid_loss'] = cur_loss
191 | with open(osp.join(self.log_dir, 'best_state.json'), 'w') as f:
192 | json.dump(state_dict, f, indent=2, sort_keys=True)
193 | self.model.save_model(self.log_dir, 'vsbl', 'best', self.optim)
194 |
195 | def load_data_and_predict(self, rollout_idx, timestep, dataset):
196 | args = self.args
197 | self.set_mode('eval')
198 |
199 | idx = rollout_idx * (self.args.time_step - self.args.n_his) + timestep
200 | data_ori = dataset._prepare_transition(idx)
201 | data = dataset.build_graph(data_ori)
202 |
203 | gt_mesh_edge = data['gt_mesh_edge'].detach().cpu().numpy()
204 |
205 | with torch.no_grad():
206 | data['x_batch'] = torch.zeros(data['x'].size(0), dtype=torch.long, device=self.device)
207 | data['u'] = torch.zeros([1, self.args.global_size], device=self.device)
208 | for key in ['x', 'edge_index', 'edge_attr']:
209 | data[key] = data[key].to(self.device)
210 | pred_mesh_edge = self.model(data)['mesh_edge']
211 |
212 | pred_mesh_edge_logits = pred_mesh_edge.cpu().numpy()
213 | pred_mesh_edge = pred_mesh_edge_logits > 0
214 | edges = data['edge_index'].detach().cpu().numpy()
215 |
216 | return pred_mesh_edge, gt_mesh_edge, edges, data_ori['normalized_vox_pc'], None
217 |
218 | def infer_mesh_edges(self, args):
219 | """
220 | args: a dict
221 | scene_params
222 | pointcloud
223 | cuda_idx
224 | """
225 | scene_params = args['scene_params']
226 | point_cloud = args['pointcloud']
227 | cuda_idx = args.get('cuda_idx', 0)
228 |
229 | self.set_mode('eval')
230 | if cuda_idx >= 0:
231 | self.to(cuda_idx)
232 | edge_dataset = self.datasets['train']
233 |
234 | normalized_point_cloud = point_cloud - np.mean(point_cloud, axis=0)
235 | data_ori = {
236 | 'scene_params': scene_params,
237 | 'observable_idx': None,
238 | 'normalized_vox_pc': normalized_point_cloud,
239 | 'pc_to_mesh_mapping': None
240 | }
241 | data = edge_dataset.build_graph(data_ori, get_gt_edge_label=False)
242 | with torch.no_grad():
243 | data['x_batch'] = torch.zeros(data['x'].size(0), dtype=torch.long, device=self.device)
244 | data['u'] = torch.zeros([1, self.args.global_size], device=self.device)
245 | for key in ['x', 'edge_index', 'edge_attr']:
246 | data[key] = data[key].to(self.device)
247 | pred_mesh_edge_logits = self.model(data)['mesh_edge']
248 |
249 | pred_mesh_edge_logits = pred_mesh_edge_logits.cpu().numpy()
250 | pred_mesh_edge = pred_mesh_edge_logits > 0
251 |
252 | edges = data['edge_index'].detach().cpu().numpy()
253 | senders = []
254 | receivers = []
255 | num_edges = edges.shape[1]
256 | for e_idx in range(num_edges):
257 | if pred_mesh_edge[e_idx]:
258 | senders.append(int(edges[0][e_idx]))
259 | receivers.append(int(edges[1][e_idx]))
260 |
261 | mesh_edges = np.vstack([senders, receivers])
262 | return mesh_edges
263 |
264 | def to(self, cuda_idx):
265 | self.model.to(torch.device("cuda:{}".format(cuda_idx)))
266 |
267 | def set_mode(self, mode='train'):
268 | self.model.set_mode('train' if mode == 'train' else 'eval')
269 |
270 | def load_model(self, load_optim=False):
271 | self.model.load_model(self.args.edge_model_path, load_optim=load_optim, optim=self.optim)
272 | self.load_epoch = extract_numbers(self.args.edge_model_path)[-1]
273 |
--------------------------------------------------------------------------------
/chester/README.md:
--------------------------------------------------------------------------------
1 | # Chester
2 |
3 | Chester is a tool for automatically launching experiments. It is based on rllab (https://github.com/rll/rllab) and further extended for launching and retrieving experiments on different remote machines, including:
4 | 1. Seuss
5 | 2. PSC (Pittsburgh Super Computing)
6 |
7 | ## Getting Started
8 |
9 | We've provided an example that launches experiments with openai/baselines' DDPG algorithm.
10 | 
11 | Look into /examples and you'll find 'train_launch.py' and 'train.py'. 'train.py' adapts openai/baselines/ddpg/main.py: much of its code is combined into a function 'run_task', which receives the parameters and runs the DDPG algorithm with the given settings.
12 | 
13 | The launcher 'train_launch.py' uses chester and the 'run_task' function to launch a group of experiments locally. Running the launcher starts the experiments and collects the results in a given folder; these result files can be visualized with rllab's viskit. A schematic sketch of this split is given below.
14 | 
15 | To support different visualization options, chester provides a self-written interface, 'presets.py'. The author can write custom splitters in this file and put it in the experiment directory; the viskit tool detects this preset file and applies the corresponding options.
16 |
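A minimal sketch of the run_task/launcher split (illustrative only; chester's real entry point lives in chester/run_exp.py and examples/train_launch.py, and its exact signature may differ):

```python
# Hypothetical sketch: the task function receives a variant dict plus logging
# info, and a launcher loop hands it one variant per experiment.
def run_task(vv, log_dir, exp_name):
    print('running', exp_name, 'in', log_dir, 'with variant', vv)

for seed in [100, 200, 300]:
    run_task({'seed': seed}, log_dir='data/local/demo/seed_{}'.format(seed), exp_name='demo')
```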
17 |
18 | ### Prerequisites
19 |
20 | What things you need to install the software and how to install them
21 |
22 | ```
23 | Give examples
24 | ```
25 |
26 | ### Installation
27 |
28 | A step by step series of examples that tell you how to get a development env running
29 |
30 | Say what the step will be
31 |
32 | ```
33 | Give the example
34 | ```
35 |
36 | And repeat
37 |
38 | ```
39 | until finished
40 | ```
41 |
42 | End with an example of getting some data out of the system or using it for a little demo
43 |
44 | ## Running the tests
45 |
46 | Explain how to run the automated tests for this system
47 |
48 | ### Break down into end to end tests
49 |
50 | Explain what these tests test and why
51 |
52 | ```
53 | Give an example
54 | ```
55 |
56 | ### And coding style tests
57 |
58 | Explain what these tests test and why
59 |
60 | ```
61 | Give an example
62 | ```
63 |
64 | ## Deployment
65 |
66 | Add additional notes about how to deploy this on a live system
67 |
68 | ## Built With
69 |
70 | * [Dropwizard](http://www.dropwizard.io/1.0.2/docs/) - The web framework used
71 | * [Maven](https://maven.apache.org/) - Dependency Management
72 | * [ROME](https://rometools.github.io/rome/) - Used to generate RSS Feeds
73 |
74 | ## Contributing
75 |
76 | Please read [CONTRIBUTING.md](https://gist.github.com/PurpleBooth/b24679402957c63ec426) for details on our code of conduct, and the process for submitting pull requests to us.
77 |
78 | ## Versioning
79 |
80 | We use [SemVer](http://semver.org/) for versioning. For the versions available, see the [tags on this repository](https://github.com/your/project/tags).
81 |
82 | ## Authors
83 |
84 | * **Billie Thompson** - *Initial work* - [PurpleBooth](https://github.com/PurpleBooth)
85 |
86 | See also the list of [contributors](https://github.com/your/project/contributors) who participated in this project.
87 |
88 | ## License
89 |
90 | This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details
91 |
92 | ## Acknowledgments
93 |
94 | * Hat tip to anyone whose code was used
95 | * Inspiration
96 | * etc
97 |
98 |
--------------------------------------------------------------------------------
/chester/add_variants.py:
--------------------------------------------------------------------------------
1 | # Add variants to finished experiments
2 | import argparse
3 | import os
4 | import json
5 | from pydoc import locate
6 | import config
7 |
8 |
9 | def main():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('exp_folder', type=str, help='Root of the experiment folder to walk through')
12 | parser.add_argument('key', type=str, help='Name of the additional key')
13 | parser.add_argument('value', help='Value of the additional key')
14 | parser.add_argument('value_type', default='str', type=str, help='Type of the additional key')
15 | parser.add_argument('remote', nargs='?', default=None, type=str, ) # Optional
16 |
17 | args = parser.parse_args()
18 | exp_paths = [x[0] for x in os.walk(args.exp_folder, followlinks=True)]
19 |
20 | value_type = locate(args.value_type)
21 | if value_type == bool:
22 | value = args.value in ['1', 'True', 'true']
23 | else:
24 | value = value_type(args.value)
25 |
26 | for exp_path in exp_paths:
27 | try:
28 | variant_path = os.path.join(exp_path, "variant.json")
29 | # Modify locally
30 | with open(variant_path, 'r') as f:
31 | vv = json.load(f)
32 | if args.key in vv:
33 | print('Warning: key already in variants. {} = {}. Setting it to {}'.format(args.key, vv[args.key], value))
34 |
35 | vv[args.key] = value
36 | with open(variant_path, 'w') as f:
37 | json.dump(vv, f, indent=2, sort_keys=True)
38 | print('{} modified'.format(variant_path))
39 |
40 | # Upload it to remote
41 | if args.remote is not None:
42 | p = variant_path.rstrip('/').split('/')
43 | sub_exp_name, exp_name = p[-2], p[-3]
44 |
45 | remote_dir = os.path.join(config.REMOTE_DIR[args.remote], 'data', 'local', exp_name, sub_exp_name, 'variant.json')
46 | os.system('scp {} {}:{}'.format(variant_path, args.remote, remote_dir))
47 | except IOError as e:
48 | print(e)
49 |
50 |
51 | if __name__ == '__main__':
52 | main()
53 |
--------------------------------------------------------------------------------
/chester/config.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import os
3 |
4 | # TODO change this before making it into a pip package
5 | PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..'))
6 |
7 | LOG_DIR = os.path.join(PROJECT_PATH, "data")
8 |
9 | # Make sure to use absolute path
10 | REMOTE_DIR = {
11 | }
12 |
13 | REMOTE_MOUNT_OPTION = {
14 | }
15 |
16 | REMOTE_LOG_DIR = {
17 | }
18 |
19 | REMOTE_HEADER = dict()
20 |
21 | # location of the singularity file related to the project
22 | SIMG_DIR = {
23 | }
24 | CUDA_MODULE = {
25 | }
26 | MODULES = {
27 | }
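28 | 
29 | # A hypothetical sketch of how these dicts are meant to be filled in
30 | # (host name and paths below are placeholders, not real machines):
31 | # REMOTE_DIR = {
32 | #     'my-cluster': '/home/username/projects/VCD',  # absolute path of the project root on the remote
33 | # }
34 | # REMOTE_LOG_DIR = {
35 | #     'my-cluster': '/home/username/projects/VCD/data',
36 | # }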
--------------------------------------------------------------------------------
/chester/containers/ubuntu-16.04-lts-rl.README:
--------------------------------------------------------------------------------
1 | Bootstrap: debootstrap
2 | OSVersion: xenial
3 | MirrorURL: http://us.archive.ubuntu.com/ubuntu/
4 |
5 | %help
6 |     This is a Singularity container that runs deep reinforcement learning algorithms on Ubuntu.
7 |     Packages installed include:
8 |         * CUDA 9.0 and cuDNN
9 |     Runs ~/.bashrc on start to make sure the PATH is the same.
10 |
11 | %runscript
12 | /usr/bin/nvidia-smi -L
13 |
14 | %environment
15 |     LD_LIBRARY_PATH=/usr/local/cuda-9.0/cuda/lib64:/usr/local/cuda-9.0/lib64:/usr/lib/nvidia-384:$LD_LIBRARY_PATH
16 |
17 | %setup
18 | echo "Let us have CUDA..."
19 | sh /home/xingyu/software/cuda/cuda_9.0.176_384.81_linux.run --silent --toolkit --toolkitpath=${SINGULARITY_ROOTFS}/usr/local/cuda-9.0
20 | ln -s ${SINGULARITY_ROOTFS}/usr/local/cuda-9.0 ${SINGULARITY_ROOTFS}/usr/local/cuda
21 | echo "Let us also have cuDNN..."
22 | cp -prv /home/xingyu/software/cudnn/* ${SINGULARITY_ROOTFS}/usr/local/cuda-9.0/
23 |
24 | %labels
25 | AUTHOR xlin3@cs.cmu.edu
26 | VERSION v1.0
27 |
28 | %post
29 | echo "Hello from inside the container"
30 | sed -i 's/$/ universe/' /etc/apt/sources.list
31 | touch /usr/bin/nvidia-smi
32 | chmod +x /usr/bin/nvidia-smi
33 |
34 | apt-get -y update
35 | apt-get -y install software-properties-common vim make wget curl emacs ffmpeg git htop libffi-dev libglew-dev libgl1-mesa-glx libosmesa6 libosmesa6-dev libssl-dev mesa-utils module-init-tools openjdk-8-jdk python-dev python-numpy python-tk bzip2
36 | apt-get -y install build-essential
37 | apt-get -y install libgl1-mesa-dev libglfw3-dev
38 | apt-get -y install strace
39 |
40 | echo "Install openmpi 3.1.1"
41 |
42 | cd /tmp
43 | wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.1.tar.gz
44 | tar xf openmpi-3.1.1.tar.gz
45 | cd openmpi-3.1.1
46 | mkdir -p build
47 | cd build
48 | ../configure
49 | make -j 8 all
50 | make install
51 | apt-get -y install openmpi-bin
52 | rm -rf /tmp/openmpi*
53 | rm -rf /usr/bin/mpirun
54 | ln -s /usr/local/bin/mpirun /usr/bin/mpirun
55 |
56 | echo "Install mpi4py 3.0.0"
57 | cd /tmp
58 | wget https://bitbucket.org/mpi4py/mpi4py/downloads/mpi4py-3.0.0.tar.gz
59 | tar -zxf mpi4py-3.0.0.tar.gz
60 | cd mpi4py-3.0.0
61 | python setup.py build --mpicc=
62 | python setup.py install --user
63 | mkdir -p /usr/lib/nvidia-384
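64 | 
65 |     # A hedged sketch of how this definition might be used (Singularity 2.x-era
66 |     # syntax; exact commands depend on your Singularity version):
67 |     #   sudo singularity build ubuntu-16.04-lts-rl.simg ubuntu-16.04-lts-rl.def
68 |     #   singularity run ubuntu-16.04-lts-rl.simg    # executes %runscript (nvidia-smi -L)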
--------------------------------------------------------------------------------
/chester/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = source
8 | BUILDDIR = build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/chester/docs/README.md:
--------------------------------------------------------------------------------
1 | ## References in making this documentation
2 | * [Getting started on Sphinx](https://pythonhosted.org/an_example_pypi_project/sphinx.html)
3 | * [Read the Docs Sphinx Theme](https://github.com/rtfd/sphinx_rtd_theme)
4 | * [reStructuredText](http://docutils.sourceforge.net/docs/user/rst/quickref.html)
5 |
--------------------------------------------------------------------------------
/chester/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/chester/docs/readme.md:
--------------------------------------------------------------------------------
1 | ## References in making this documentation
2 | * [Getting started on Sphinx](https://pythonhosted.org/an_example_pypi_project/sphinx.html)
3 | * [Read the Docs Sphinx Theme](https://github.com/rtfd/sphinx_rtd_theme)
4 | * [reStructuredText](http://docutils.sourceforge.net/docs/user/rst/quickref.html)
5 |
--------------------------------------------------------------------------------
/chester/docs/source/_static/basic_screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xingyu-Lin/VCD/328f3e0ada7c3d39b974e875c5847badfca5037d/chester/docs/source/_static/basic_screenshot.png
--------------------------------------------------------------------------------
/chester/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # http://www.sphinx-doc.org/en/master/config
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = 'Chester'
21 | copyright = '2019, Xingyu Lin'
22 | author = 'Xingyu Lin'
23 |
24 | # The full version, including alpha/beta/rc tags
25 | release = 'v0.1'
26 |
27 |
28 | # -- General configuration ---------------------------------------------------
29 |
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = [
34 | ]
35 |
36 | # Add any paths that contain templates here, relative to this directory.
37 | templates_path = ['_templates']
38 |
39 | # List of patterns, relative to source directory, that match files and
40 | # directories to ignore when looking for source files.
41 | # This pattern also affects html_static_path and html_extra_path.
42 | exclude_patterns = []
43 |
44 |
45 | # -- Options for HTML output -------------------------------------------------
46 |
47 | # The theme to use for HTML and HTML Help pages. See the documentation for
48 | # a list of builtin themes.
49 | #
50 | html_theme = 'sphinx_rtd_theme'
51 |
52 | # Add any paths that contain custom static files (such as style sheets) here,
53 | # relative to this directory. They are copied after the builtin static files,
54 | # so a file named "default.css" will overwrite the builtin "default.css".
55 | html_static_path = ['_static']
56 |
--------------------------------------------------------------------------------
/chester/docs/source/getting_started.rst:
--------------------------------------------------------------------------------
1 | .. _getting_started:
2 |
3 |
4 | ***************
5 | Getting started
6 | ***************
7 |
8 | .. _installing-docdir:
9 |
10 | Installing your doc directory
11 | =============================
12 |
13 | You may already have `sphinx <http://sphinx-doc.org/>`_
14 | installed -- you can check by doing::
15 |
16 | python -c 'import sphinx'
17 |
18 | If that fails, grab the latest version and install it with::
19 |
20 | > sudo easy_install -U Sphinx
21 |
22 | Now you are ready to build a template for your docs, using
23 | sphinx-quickstart::
24 |
25 | > sphinx-quickstart
26 |
27 | accepting most of the defaults. I choose "sampledoc" as the name of my
28 | project. cd into your new directory and check the contents::
29 |
30 | home:~/tmp/sampledoc> ls
31 | Makefile _static conf.py
32 | _build _templates index.rst
33 |
34 | The index.rst is the master ReST for your project, but before adding
35 | anything, let's see if we can build some html::
36 |
37 | make html
38 |
39 | If you now point your browser to :file:`_build/html/index.html`, you
40 | should see a basic sphinx site.
41 |
42 | .. image:: _static/basic_screenshot.png
43 |
44 | .. _fetching-the-data:
45 |
46 | Fetching the data
47 | -----------------
48 |
49 | Now we will start to customize our docs. Grab a couple of files from
50 | the `web site <https://github.com/matplotlib/sampledoc>`_
51 | or git. You will need :file:`getting_started.rst` and
52 | :file:`_static/basic_screenshot.png`. All of the files live in the
53 | "completed" version of this tutorial, but since this is a tutorial,
54 | we'll just grab them one at a time, so you can learn what needs to be
55 | changed where. Since we have more files to come, I'm going to grab
56 | the whole git directory and just copy the files I need over for now.
57 | First, I'll cd up back into the directory containing my project, check
58 | out the "finished" product from git, and then copy in just the files I
59 | need into my :file:`sampledoc` directory::
60 |
61 | home:~/tmp/sampledoc> pwd
62 | /Users/jdhunter/tmp/sampledoc
63 | home:~/tmp/sampledoc> cd ..
64 | home:~/tmp> git clone https://github.com/matplotlib/sampledoc.git tutorial
65 | Cloning into 'tutorial'...
66 | remote: Counting objects: 87, done.
67 | remote: Compressing objects: 100% (43/43), done.
68 | remote: Total 87 (delta 45), reused 83 (delta 41)
69 | Unpacking objects: 100% (87/87), done.
70 | Checking connectivity... done
71 | home:~/tmp> cp tutorial/getting_started.rst sampledoc/
72 | home:~/tmp> cp tutorial/_static/basic_screenshot.png sampledoc/_static/
73 |
74 | The last step is to modify :file:`index.rst` to include the
75 | :file:`getting_started.rst` file (be careful with the indentation, the
76 | "g" in "getting_started" should line up with the ':' in ``:maxdepth``::
77 |
78 | Contents:
79 |
80 | .. toctree::
81 | :maxdepth: 2
82 |
83 | getting_started.rst
84 |
85 | and then rebuild the docs::
86 |
87 | cd sampledoc
88 | make html
89 |
90 |
91 | When you reload the page by refreshing your browser pointing to
92 | :file:`_build/html/index.html`, you should see a link to the
93 | "Getting Started" docs, and in there this page with the screenshot.
94 | `Voila!`
95 |
96 | We can also use the image directive in :file:`index.rst` to include the screenshot above
97 | with::
98 |
99 | .. image::
100 | _static/basic_screenshot.png
101 |
102 |
103 | Next we'll customize the look and feel of our site to give it a logo,
104 | some custom css, and update the navigation panels to look more like
105 | the `sphinx <http://sphinx-doc.org/>`_ site itself -- see
106 | :ref:`custom_look`.
--------------------------------------------------------------------------------
/chester/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Chester documentation master file, created by
2 | sphinx-quickstart on Fri Apr 5 16:47:55 2019.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Chester's documentation!
7 | ===================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | getting_started.rst
14 | launcher.rst
15 | logger.rst
16 | visualization.rst
17 |
18 |
19 |
20 | Indices and tables
21 | ==================
22 |
23 | * :ref:`genindex`
24 | * :ref:`modindex`
25 | * :ref:`search`
26 |
--------------------------------------------------------------------------------
/chester/docs/source/launcher.rst:
--------------------------------------------------------------------------------
1 | .. _launcher:
2 |
3 |
4 | ***************
5 | Launcher
6 | ***************
7 |
--------------------------------------------------------------------------------
/chester/docs/source/logger.rst:
--------------------------------------------------------------------------------
1 | .. _logger:
2 |
3 |
4 | ***************
5 | Logger
6 | ***************
7 |
--------------------------------------------------------------------------------
/chester/docs/source/visualization.rst:
--------------------------------------------------------------------------------
1 | .. _visualization:
2 |
3 |
4 | ***************
5 | Visualization
6 | ***************
7 |
8 | .. _Batch-Plotting:
9 |
10 | Batch Plotting
11 | ==============
12 |
13 | If the data are logged with Chester, they can also be easily plotted in batch.
14 | After the data are logged, the hyper-parameters of each experiment are stored in ``variant.json`` and the different
15 | key values are stored in ``progress.csv``. ``chester/plotting/cplot.py`` offers functions that can be used to
16 | organize different experiments based on their key values:
17 |
18 | - ``reload_data()``: Iterate through the data folder and organize each experiment into a list with its progress data and hyper-parameters, and also analyze all the curves to extract the distinct hyper-parameters.
19 | - ``get_group_selectors()``: You should write a ``custom_series_splitter()``, which provides a legend for each experiment based on its hyper-parameters. This function will then group all the experiments by their legends.
20 | - ``get_shaded_curve()``: Create the y-values needed for shaded plots (mean with a standard-deviation band, or median with a 25th-75th percentile band) for a certain key value.
21 |
22 | A data structure from rllab visualization kit can be useful: ``Selector``. It can be constructed from the loaded
23 | experiments data structure::
24 |
25 |     from rllab.viskit import core
26 |     exps_data, plottable_keys, distinct_params = reload_data(path_to_data_folder)
27 |     selector = core.Selector(exps_data)
28 |
29 | After that, it can be used to extract progress information for a certain key value::
30 | 
31 |     progresses = [exp.progress.get(key) for exp in selector.extract()]
32 |
33 | or be filtered based on certain hyper-parameters::
34 |
35 | selector = selector.where('env_name', env_name)
36 |
37 | Some examples can be found in both ``chester/plotting/cplot.py`` and ``chester/examples/cplot_example.py``; a minimal end-to-end sketch is also given at the end of this page.
38 |
39 | .. _Interactive-Frontend:
40 |
41 | Interactive Frontend
42 | ====================
43 |
44 | Currently the interactive visualization feature is still coupled with rllab.
45 | It can be accessed by doing::
46 |
47 | python rllab/viskit/frontend.py
48 |
49 | .. _Preset:
50 |
51 | Preset
52 | ------
53 | You may want to use a complex legend post-processor or splitter.
54 | The preset feature can be used to save such a setting. First write a ``presets.py``. Then, put it in the root of the
55 | data folder that you want to visualize. Now when you use the frontend visualization, there will be a preset button
56 | from which you can choose. Some examples of ``presets.py`` can be found at ``chester/examples``.
57 |
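58 | 
59 | .. _Batch-Plotting-Sketch:
60 | 
61 | A minimal batch-plotting sketch
62 | ===============================
63 | 
64 | As a minimal sketch (the data folder path, the splitter, and the plotted
65 | key below are placeholders; substitute values from your own experiments),
66 | the pieces above can be combined as follows::
67 | 
68 |     from chester.plotting.cplot import *
69 | 
70 |     def my_splitter(x):
71 |         # Group curves by their random seed (a hypothetical choice)
72 |         return 'seed: ' + str(x['flat_params']['seed'])
73 | 
74 |     exps_data, plottable_keys, distinct_params = reload_data('data/local/my-exp')
75 |     group_selectors, group_legends = get_group_selectors(exps_data, my_splitter)
76 |     fig, ax = plt.subplots(figsize=(8, 5))
77 |     for selector, legend in zip(group_selectors, group_legends):
78 |         y, y_lower, y_upper = get_shaded_curve(selector, 'test/success_state')
79 |         ax.plot(range(len(y)), y, label=legend, linewidth=2.0)
80 |         ax.fill_between(range(len(y)), y_lower, y_upper, alpha=0.2)
81 |     ax.legend(loc='best')
82 |     plt.savefig(filter_save_name('my_plot.png'), bbox_inches='tight')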
--------------------------------------------------------------------------------
/chester/examples/cplot_example.py:
--------------------------------------------------------------------------------
1 | from chester.plotting.cplot import *
2 | import os.path as osp
3 |
4 |
5 | def custom_series_splitter(x):
6 | params = x['flat_params']
7 | if 'use_ae_reward' in params and params['use_ae_reward']:
8 | return 'Auto Encoder'
9 | if params['her_replay_strategy'] == 'balance_filter':
10 | return 'Indicator+Balance+Filter'
11 | if params['env_kwargs.use_true_reward']:
12 | return 'Oracle'
13 | return 'Indicator'
14 |
15 |
16 | dict_leg2col = {"Oracle": 1, "Indicator": 0, 'Indicator+Balance+Filter': 2, "Auto Encoder": 3}
17 | save_path = './data/plots_chester'
18 |
19 |
20 | def plot_visual_learning():
21 | data_path = './data/nsh/submit_rss/submit_rss/visual_learning'
22 |
23 | plot_keys = ['test/success_state', 'test/goal_dist_final_state']
24 | plot_ylabels = ['Success', 'Final Distance to Goal']
25 | plot_envs = ['FetchReach', 'Reacher', 'RopeFloat']
26 |
27 | exps_data, plottable_keys, distinct_params = reload_data(data_path)
28 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter)
29 | for (plot_key, plot_ylabel) in zip(plot_keys, plot_ylabels):
30 | for env_name in plot_envs:
31 | fig, ax = plt.subplots(figsize=(8, 5))
32 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)):
33 | color = core.color_defaults[dict_leg2col[legend]]
34 | y, y_lower, y_upper = get_shaded_curve(selector.where('env_name', env_name), plot_key,
35 | shade_type='median')
36 |
37 | env_horizon = selector.where('env_name', env_name).extract()[0].params["env_kwargs"]["horizon"]
38 | x, _, _ = get_shaded_curve(selector.where('env_name', env_name), 'train/episode')
39 | x = [ele * env_horizon for ele in x]
40 |
41 | ax.plot(x, y, color=color, label=legend, linewidth=2.0)
42 |
43 | ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0,
44 | alpha=0.2)
45 |
46 | def y_fmt(x, y):
47 | return str(int(np.round(x / 1000.0))) + 'K'
48 |
49 | ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt))
50 | ax.grid(True)
51 | ax.set_xlabel('Timesteps')
52 | ax.set_ylabel(plot_ylabel)
53 | axes = plt.gca()
54 | if 'Rope' in env_name:
55 | axes.set_xlim(left=20000)
56 |
57 | plt.title(env_name.replace('Float', 'Push'))
58 | loc = 'best'
59 | leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends)
60 | for legobj in leg.legendHandles:
61 | legobj.set_linewidth(3.0)
62 |
63 | save_name = filter_save_name('ind_visual_' + plot_key + '_' + env_name)
64 |
65 | plt.savefig(osp.join(save_path, save_name), bbox_inches='tight')
66 |
67 |
68 | def plot_state_learning():
69 | data_path = './data/nsh/submit_rss/submit_rss/state_learning'
70 |
71 | plot_keys = ['test/success_state', 'test/goal_dist_final_state']
72 | plot_envs = ['FetchReach', 'FetchPush', 'Reacher', 'RopeFloat']
73 |
74 | exps_data, plottable_keys, distinct_params = reload_data(data_path)
75 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter)
76 | for plot_key in plot_keys:
77 | for env_name in plot_envs:
78 | fig, ax = plt.subplots(figsize=(8, 5))
79 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)):
80 | color = core.color_defaults[dict_leg2col[legend]]
81 | y, y_lower, y_upper = get_shaded_curve(selector.where('env_name', env_name), plot_key,
82 | shade_type='median')
83 | env_horizon = selector.where('env_name', env_name).extract()[0].params["env_kwargs"]["horizon"]
84 | x, _, _ = get_shaded_curve(selector.where('env_name', env_name), 'train/episode')
85 | x = [ele * env_horizon for ele in x]
86 | ax.plot(x, y, color=color, label=legend, linewidth=2.0)
87 |
88 | ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0,
89 | alpha=0.2)
90 |
91 | def y_fmt(x, y):
92 | return str(int(np.round(x / 1000.0))) + 'K'
93 |
94 | ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt))
95 | ax.grid(True)
96 | ax.set_xlabel('Timesteps')
97 | ax.set_ylabel('Success')
98 |
99 | plt.title(env_name.replace('Float', 'Push'))
100 | loc = 'best'
101 | leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends)
102 | for legobj in leg.legendHandles:
103 | legobj.set_linewidth(3.0)
104 |
105 | save_name = filter_save_name('ind_state_' + plot_key + '_' + env_name)
106 |
107 | plt.savefig(osp.join(save_path, save_name), bbox_inches='tight')
108 |
109 |
110 | if __name__ == '__main__':
111 | plot_visual_learning()
112 | plot_state_learning()
113 |
--------------------------------------------------------------------------------
/chester/examples/pgm_plot.py:
--------------------------------------------------------------------------------
1 | from chester.plotting.cplot import *
2 | import os.path as osp
3 | from random import shuffle
4 |
5 | save_path = '../sac/data/plots'
6 | dict_leg2col = {"LSP": 0, "Base": 1, "Behavior": 2}
7 | dict_xshift = {"LSP": 4000, "Base": 0, "Behavior": 6000}
8 |
9 |
10 | def custom_series_splitter(x):
11 | params = x['flat_params']
12 | exp_name = params['exp_name']
13 | dict_mapping = {'humanoid-resume-training-6000-00': 'Behavior',
14 | 'humanoid-resume-training-4000-00': 'LSP',
15 | 'humanoid-rllab/default-2019-04-14-07-04-08-421230-UTC-00': 'Base'}
16 | return dict_mapping[exp_name]
17 |
18 |
19 | def sliding_mean(data_array, window=5):
20 | data_array = np.array(data_array)
21 | new_list = []
22 | for i in range(len(data_array)):
23 | indices = list(range(max(i - window + 1, 0),
24 | min(i + window + 1, len(data_array))))
25 | avg = 0
26 | for j in indices:
27 | avg += data_array[j]
28 | avg /= float(len(indices))
29 | new_list.append(avg)
30 |
31 | return np.array(new_list)
32 |
33 |
34 | def plot_main():
35 | data_path = '../sac/data/mengxiong'
36 | plot_key = 'return-average'
37 | exps_data, plottable_keys, distinct_params = reload_data(data_path)
38 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter)
39 | fig, ax = plt.subplots(figsize=(8, 5))
40 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)):
41 | color = core.color_defaults[dict_leg2col[legend]]
42 |
43 | y, y_lower, y_upper = get_shaded_curve(selector, plot_key, shade_type='median')
44 | x = np.array(range(len(y)))
45 | x += dict_xshift[legend]
46 | y = sliding_mean(y, 5)
47 | ax.plot(x, y, color=color, label=legend, linewidth=2.0)
48 |
49 | # ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0,
50 | # alpha=0.2)
51 |
52 | def y_fmt(x, y):
53 | return str(int(np.round(x))) + 'K'
54 |
55 | ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt))
56 | ax.grid(True)
57 | ax.set_xlabel('Timesteps')
58 | ax.set_ylabel('Average-return')
59 |
60 | # plt.title(env_name.replace('Float', 'Push'))
61 | loc = 'best'
62 | leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends)
63 | for legobj in leg.legendHandles:
64 | legobj.set_linewidth(3.0)
65 |
66 | save_name = filter_save_name('plots.png')
67 | plt.savefig(osp.join(save_path, save_name), bbox_inches='tight')
68 |
69 |
70 | if __name__ == '__main__':
71 | plot_main()
72 |
--------------------------------------------------------------------------------
/chester/examples/presets.py:
--------------------------------------------------------------------------------
1 | preset_names = ['default']
2 |
3 | def make_custom_seris_splitter(preset_names):
4 | legendNote = None
5 | if preset_names == 'default':
6 | def custom_series_splitter(x):
7 | params = x['flat_params']
 8 |             # Map the replay strategy to a legend label
 9 |             if params['her_replay_strategy'] == 'future':
10 |                 ret = 'RG'
11 |             elif params['her_replay_strategy'] == 'only_fake':
12 |                 ret = 'FG+RR' if params['her_use_reward'] else 'FG+FR'
13 |             else:
14 |                 raise NotImplementedError  # avoid returning an unbound `ret`
15 |             return ret + '+' + str(params['her_clip_len']) + '+' + str(
16 |                 params['her_reward_choices']) + '+' + str(
17 |                 params['her_failed_goal_ratio'])
18 |
19 | legendNote = "Fake Goal(FG)/Real Goal(RG) + Fake Reward(FR)/Real Goal(RG) + HER_clip_len + HER_reward_choices + HER_failed_goal_ratio"
20 | else:
21 | raise NotImplementedError
22 | return custom_series_splitter, legendNote
23 |
--------------------------------------------------------------------------------
/chester/examples/presets2.py:
--------------------------------------------------------------------------------
1 | preset_names = ['default']
2 | x_axis = 'Epoch'
3 | y_axis = 'Success'
4 | FILTERED = 'filtered'
5 |
6 | def make_custom_seris_splitter(preset_names):
7 | legendNote = None
8 | if preset_names == 'default':
9 | def custom_series_splitter(x):
10 | params = x['flat_params']
11 | if params['her_failed_goal_option'] is None:
12 | ret = 'Distance Reward'
13 | elif params['her_failed_goal_option'] == 'dist_behaviour':
14 | ret = 'Exact Match'
15 | else:
16 | ret = FILTERED
17 | return ret
18 |
19 | legendNote = None
20 | else:
21 | raise NotImplementedError
22 | return custom_series_splitter, legendNote
23 |
24 |
25 | def make_custom_filter(preset_names):
26 | if preset_names == 'default':
27 | custom_seris_splitter, _ = make_custom_seris_splitter(preset_names)
28 | def custom_filter(x):
29 | legend = custom_seris_splitter(x)
30 | if legend == FILTERED:
31 | return False
32 | else:
33 | return True
34 | # params = x['flat_params']
35 | # if params['her_failed_goal_option'] != FILTERED:
36 | # return True
37 | # else:
38 | # return False
39 | return custom_filter
40 |
41 |
--------------------------------------------------------------------------------
/chester/examples/presets3.py:
--------------------------------------------------------------------------------
1 | set1 = 'identity_ratio+her_clip(dist_behavior and HER)'
2 | preset_names = [set1]
3 | FILTERED = 'filtered'
4 |
5 |
6 | def make_custom_seris_splitter(preset_names):
7 | legendNote = None
8 | if preset_names == set1:
9 | def custom_series_splitter(x):
10 | params = x['flat_params']
11 | if params['her_failed_goal_option'] in ['dist_G', 'dist_policy']:
12 | return FILTERED
13 | if params['her_identity_ratio'] is not None:
14 | return 'IR: ' + str(params['her_identity_ratio'])
15 | if params['her_clip_len'] is not None:
16 | return 'CL: ' + str(params['her_clip_len'])
17 | return 'HER'
18 |
19 | legendNote = 'IR: identity ratio; CL: clip length'
20 | else:
21 | raise NotImplementedError
22 | return custom_series_splitter, legendNote
23 |
24 |
25 | def make_custom_filter(preset_names):
26 | if preset_names == set1:
27 | custom_seris_splitter, _ = make_custom_seris_splitter(preset_names)
28 |
29 | def custom_filter(x):
30 | legend = custom_seris_splitter(x)
31 | if legend == FILTERED:
32 | return False
33 | else:
34 | return True
35 | return custom_filter
36 |
37 |
--------------------------------------------------------------------------------
/chester/examples/presets_tiancheng.py:
--------------------------------------------------------------------------------
1 | # Updated By Tiancheng Jin 08/28/2018
2 |
3 | # the preset file should be contained in the experiment folder ( which is assigned by exp_prefix )
4 | # for example, this file should be put in /path to project/data/local/
5 |
6 | # Here's an example for custom_series_splitter
7 | # suppose we want to split five experiments with random seeds from 0 to 4 into two strategies:
8 | # * two groups for those with even or odd random seeds: [0,2,4] and [1,3]
9 | # * two groups for those with smaller or larger random seeds: [0,1,2] and [3,4]
10 |
11 | preset_names = ['even or odd', 'small or large']
12 |
13 |
14 | def make_custom_seris_splitter(preset_name):
15 | legend_note = None
16 | custom_series_splitter = None
17 |
18 |     if preset_name == 'even or odd':
19 |         # build a custom series splitter for even or odd random seeds,
20 |         # where the input is the data for one experiment (containing both the results and the parameters)
21 | def custom_series_splitter(x):
22 | # extract the parameters
23 | params = x['flat_params']
24 | # make up the legend
25 |             if params['seed'] % 2 == 0:
26 |                 legend = 'even seeds'
27 |             else:
28 |                 legend = 'odd seeds'
29 | return legend
30 |
31 | legend_note = "Odd or Plural"
32 |
33 | elif preset_name == 'small or large':
34 | def custom_series_splitter(x):
35 | params = x['flat_params']
36 | if params['seed'] <= 2:
37 | legend = 'smaller seeds'
38 | else:
39 | legend = 'larger seeds'
40 | return legend
41 |
42 | legend_note = "Small or Large"
43 | else:
44 |         raise NotImplementedError
45 |
46 | return custom_series_splitter, legend_note
47 |
--------------------------------------------------------------------------------
/chester/examples/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from baselines import logger
4 | from baselines.common.misc_util import (
5 | set_global_seeds,
6 | )
7 | from baselines.ddpg.main import run
8 | from mpi4py import MPI
9 |
10 |
11 | DEFAULT_PARAMS = {
12 | # env
13 |     'env_id': 'HalfCheetah-v2',  # Gym environment id
14 |
15 | # ddpg
16 | 'layer_norm': True,
17 | 'render': False,
18 |     'normalize_returns': False,
19 |     'normalize_observations': True,
20 |     'actor_lr': 0.0001,  # actor learning rate
21 |     'critic_lr': 0.001,  # critic learning rate
22 | 'critic_l2_reg': 1e-2,
23 | 'popart': False,
24 | 'gamma': 0.99,
25 |
26 | # training
27 | 'seed': 0,
28 |     'nb_epochs': 500,  # number of epochs
29 | 'nb_epoch_cycles': 20, # per epoch
30 | 'nb_rollout_steps': 100, # sampling batches per cycle
31 | 'nb_train_steps': 100, # training batches per cycle
32 | 'batch_size': 64, # per mpi thread, measured in transitions and reduced to even multiple of chunk_length.
33 | 'reward_scale': 1.0,
34 | 'clip_norm': None,
35 |
36 | # exploration
37 |     'noise_type': 'adaptive-param_0.2',
38 |
39 | # debugging, logging and visualization
40 | 'render_eval': False,
41 |     'nb_eval_steps': 100,
42 |     'evaluation': False,
43 | }
44 |
45 |
46 | def run_task(vv, log_dir=None, exp_name=None, allow_extra_parameters=False):
47 | # Configure logging system
48 | if log_dir or logger.get_dir() is None:
49 | logger.configure(dir=log_dir)
50 | logdir = logger.get_dir()
51 | assert logdir is not None
52 | os.makedirs(logdir, exist_ok=True)
53 |
54 | # Seed for multi-CPU MPI implementation ( rank = 0 for single threaded implementation )
55 | rank = MPI.COMM_WORLD.Get_rank()
56 | rank_seed = vv['seed'] + 1000000 * rank
57 | set_global_seeds(rank_seed)
58 |
59 | # load params from config
60 |     params = DEFAULT_PARAMS.copy()  # copy so the module-level defaults are not mutated
61 |
62 |     # update parameters from the variant
63 |     if not allow_extra_parameters:
64 |         for k, v in vv.items():
65 |             if k not in DEFAULT_PARAMS:
66 |                 print("[ Warning ] Undefined parameter %s with value %s" % (str(k), str(v)))
67 | params.update(**{k: v for (k, v) in vv.items() if k in DEFAULT_PARAMS})
68 | else:
69 | params.update(**{k: v for (k, v) in vv.items()})
70 |
71 | with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f:
72 | json.dump(params, f)
73 |
74 | run(**params)
75 |
76 |
77 |
--------------------------------------------------------------------------------
/chester/examples/train_launch.py:
--------------------------------------------------------------------------------
1 | import time
2 | from chester.run_exp import run_experiment_lite, VariantGenerator
3 |
4 | if __name__ == '__main__':
5 |
6 | # Here's an example for doing grid search of openai's DDPG
7 | # on HalfCheetah
8 |
9 | # the experiment folder name
10 | # the directory is defined as /LOG_DIR/data/local/exp_prefix/, where LOG_DIR is defined in config.py
11 | exp_prefix = 'test-ddpg'
12 | vg = VariantGenerator()
13 | vg.add('env_id', ['HalfCheetah-v2', 'Hopper-v2', 'InvertedPendulum-v2'])
14 |
15 | # select random seeds from 0 to 4
16 | vg.add('seed', [0, 1, 2, 3, 4])
17 | print('Number of configurations: ', len(vg.variants()))
18 |
19 | # set the maximum number for running experiments in parallel
20 | # this number depends on the number of processors in the runner
21 | maximum_launching_process = 5
22 |
23 | # launch experiments
24 | sub_process_popens = []
25 | for vv in vg.variants():
26 | while len(sub_process_popens) >= maximum_launching_process:
27 | sub_process_popens = [x for x in sub_process_popens if x.poll() is None]
28 | time.sleep(10)
29 |
30 | # import the launcher of experiments
31 | from chester.examples.train import run_task
32 |
33 | # use your written run_task function
34 | cur_popen = run_experiment_lite(
35 | stub_method_call=run_task,
36 | variant=vv,
37 | mode='local',
38 | exp_prefix=exp_prefix,
39 | wait_subprocess=False
40 | )
41 | if cur_popen is not None:
42 | sub_process_popens.append(cur_popen)
43 |
--------------------------------------------------------------------------------
/chester/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import os.path as osp
5 | import json
6 | import time
7 | import datetime
8 | import dateutil.tz
9 | import tempfile
10 | from collections import defaultdict
11 |
12 | # LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv', 'tensorboard']
13 | LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv']
14 | # Also valid: json, tensorboard
15 |
16 | DEBUG = 10
17 | INFO = 20
18 | WARN = 30
19 | ERROR = 40
20 |
21 | DISABLED = 50
22 |
23 |
24 | class KVWriter(object):
25 | def writekvs(self, kvs):
26 | raise NotImplementedError
27 |
28 |
29 | class SeqWriter(object):
30 | def writeseq(self, seq):
31 | raise NotImplementedError
32 |
33 |
34 | def put_in_middle(str1, str2):
35 |     # Put str2 in the middle of str1 (assumes str1 is the longer, dashed line)
36 | n = len(str1)
37 | m = len(str2)
38 | if n <= m:
39 | return str2
40 | else:
41 | start = (n - m) // 2
42 | return str1[:start] + str2 + str1[start + m:]
43 |
44 |
45 | class HumanOutputFormat(KVWriter, SeqWriter):
46 | def __init__(self, filename_or_file):
47 | if isinstance(filename_or_file, str):
48 | self.file = open(filename_or_file, 'wt')
49 | self.own_file = True
50 | else:
51 | assert hasattr(filename_or_file, 'read'), 'expected file or str, got %s' % filename_or_file
52 | self.file = filename_or_file
53 | self.own_file = False
54 |
55 | def writekvs(self, kvs):
56 | # Create strings for printing
57 | key2str = {}
58 | for (key, val) in sorted(kvs.items()):
59 | if isinstance(val, float):
60 | valstr = '%-8.3g' % (val,)
61 | else:
62 | valstr = str(val)
63 | key2str[self._truncate(key)] = self._truncate(valstr)
64 |
65 | # Find max widths
66 | if len(key2str) == 0:
67 | print('WARNING: tried to write empty key-value dict')
68 | return
69 | else:
70 | keywidth = max(map(len, key2str.keys()))
71 | valwidth = max(map(len, key2str.values()))
72 |
73 | # Write out the data
74 | now = datetime.datetime.now(dateutil.tz.tzlocal())
75 | timestamp = now.strftime('%Y-%m-%d %H:%M:%S.%f %Z')
76 |
77 | dashes = '-' * (keywidth + valwidth + 7)
78 | dashes_time = put_in_middle(dashes, timestamp)
79 | lines = [dashes_time]
80 | for (key, val) in sorted(key2str.items()):
81 | lines.append('| %s%s | %s%s |' % (
82 | key,
83 | ' ' * (keywidth - len(key)),
84 | val,
85 | ' ' * (valwidth - len(val)),
86 | ))
87 | lines.append(dashes)
88 | self.file.write('\n'.join(lines) + '\n')
89 |
90 | # Flush the output to the file
91 | self.file.flush()
92 |
93 | def _truncate(self, s):
94 | return s[:30] + '...' if len(s) > 33 else s
95 |
96 | def writeseq(self, seq):
97 | for arg in seq:
98 | self.file.write(arg)
99 | self.file.write('\n')
100 | self.file.flush()
101 |
102 | def close(self):
103 | if self.own_file:
104 | self.file.close()
105 |
106 |
107 | class JSONOutputFormat(KVWriter):
108 | def __init__(self, filename):
109 | self.file = open(filename, 'wt')
110 |
111 | def writekvs(self, kvs):
112 | for k, v in sorted(kvs.items()):
113 | if hasattr(v, 'dtype'):
114 | v = v.tolist()
115 | kvs[k] = float(v)
116 | self.file.write(json.dumps(kvs) + '\n')
117 | self.file.flush()
118 |
119 | def close(self):
120 | self.file.close()
121 |
122 |
123 | class CSVOutputFormat(KVWriter):
124 | def __init__(self, filename):
125 | self.file = open(filename, 'w+t')
126 | self.keys = []
127 | self.sep = ','
128 |
129 | def writekvs(self, kvs):
130 | # Add our current row to the history
131 | extra_keys = kvs.keys() - self.keys
132 | if extra_keys:
133 | self.keys.extend(extra_keys)
134 | self.file.seek(0)
135 | lines = self.file.readlines()
136 | self.file.seek(0)
137 | for (i, k) in enumerate(self.keys):
138 | if i > 0:
139 | self.file.write(',')
140 | self.file.write(k)
141 | self.file.write('\n')
142 | for line in lines[1:]:
143 | self.file.write(line[:-1])
144 | self.file.write(self.sep * len(extra_keys))
145 | self.file.write('\n')
146 | for (i, k) in enumerate(self.keys):
147 | if i > 0:
148 | self.file.write(',')
149 | v = kvs.get(k)
150 | if v is not None:
151 | self.file.write(str(v))
152 | self.file.write('\n')
153 | self.file.flush()
154 |
155 | def close(self):
156 | self.file.close()
157 |
158 |
159 | class TensorBoardOutputFormat(KVWriter):
160 | """
161 | Dumps key/value pairs into TensorBoard's numeric format.
162 | """
163 |
164 | def __init__(self, dir):
165 | os.makedirs(dir, exist_ok=True)
166 | self.dir = dir
167 | self.step = 1
168 | prefix = 'events'
169 | path = osp.join(osp.abspath(dir), prefix)
170 | import tensorflow as tf
171 | from tensorflow.python import pywrap_tensorflow
172 | from tensorflow.core.util import event_pb2
173 | from tensorflow.python.util import compat
174 | self.tf = tf
175 | self.event_pb2 = event_pb2
176 | self.pywrap_tensorflow = pywrap_tensorflow
177 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
178 |
179 | def writekvs(self, kvs):
180 | def summary_val(k, v):
181 | kwargs = {'tag': k, 'simple_value': float(v)}
182 | return self.tf.Summary.Value(**kwargs)
183 |
184 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
185 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
186 | event.step = self.step # is there any reason why you'd want to specify the step?
187 | self.writer.WriteEvent(event)
188 | self.writer.Flush()
189 | self.step += 1
190 |
191 | def close(self):
192 | if self.writer:
193 | self.writer.Close()
194 | self.writer = None
195 |
196 |
197 | def make_output_format(format, ev_dir, log_suffix=''):
198 | os.makedirs(ev_dir, exist_ok=True)
199 | if format == 'stdout':
200 | return HumanOutputFormat(sys.stdout)
201 | elif format == 'log':
202 | return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix))
203 | elif format == 'json':
204 | return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix))
205 | elif format == 'csv':
206 | return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix))
207 | elif format == 'tensorboard':
208 | return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix))
209 | else:
210 | raise ValueError('Unknown format specified: %s' % (format,))
211 |
212 |
213 | # ================================================================
214 | # API
215 | # ================================================================
216 |
217 | def logkv(key, val):
218 | """
219 | Log a value of some diagnostic
220 | Call this once for each diagnostic quantity, each iteration
221 | If called many times, last value will be used.
222 | """
223 | Logger.CURRENT.logkv(key, val)
224 |
225 |
226 | def logkv_mean(key, val):
227 | """
228 | The same as logkv(), but if called many times, values averaged.
229 | """
230 | Logger.CURRENT.logkv_mean(key, val)
231 |
232 |
233 | def logkvs(d):
234 | """
235 | Log a dictionary of key-value pairs
236 | """
237 | for (k, v) in d.items():
238 | logkv(k, v)
239 |
240 |
241 | def dumpkvs():
242 | """
243 | Write all of the diagnostics from the current iteration
244 |
245 |     Writes the accumulated key-value pairs to every configured output
246 |     format and then clears them for the next iteration.
247 | """
248 | Logger.CURRENT.dumpkvs()
249 |
250 |
251 | def getkvs():
252 | return Logger.CURRENT.name2val
253 |
254 |
255 | def log(*args, level=INFO):
256 | """
257 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
258 | """
259 | Logger.CURRENT.log(*args, level=level)
260 |
261 |
262 | def debug(*args):
263 | log(*args, level=DEBUG)
264 |
265 |
266 | def info(*args):
267 | log(*args, level=INFO)
268 |
269 |
270 | def warn(*args):
271 | log(*args, level=WARN)
272 |
273 |
274 | def error(*args):
275 | log(*args, level=ERROR)
276 |
277 |
278 | def set_level(level):
279 | """
280 | Set logging threshold on current logger.
281 | """
282 | Logger.CURRENT.set_level(level)
283 |
284 |
285 | def get_dir():
286 | """
287 | Get directory that log files are being written to.
288 | will be None if there is no output directory (i.e., if you didn't call start)
289 | """
290 | return Logger.CURRENT.get_dir()
291 |
292 |
293 | record_tabular = logkv
294 | dump_tabular = dumpkvs
295 |
296 |
297 | class ProfileKV:
298 | """
299 | Usage:
300 | with logger.ProfileKV("interesting_scope"):
301 | code
302 | """
303 |
304 | def __init__(self, n):
305 | self.n = "wait_" + n
306 |
307 | def __enter__(self):
308 | self.t1 = time.time()
309 |
310 | def __exit__(self, type, value, traceback):
311 | Logger.CURRENT.name2val[self.n] += time.time() - self.t1
312 |
313 |
314 | def profile(n):
315 | """
316 | Usage:
317 | @profile("my_func")
318 | def my_func(): code
319 | """
320 |
321 | def decorator_with_name(func):
322 | def func_wrapper(*args, **kwargs):
323 | with ProfileKV(n):
324 | return func(*args, **kwargs)
325 |
326 | return func_wrapper
327 |
328 | return decorator_with_name
329 |
330 |
331 | # ================================================================
332 | # Backend
333 | # ================================================================
334 |
335 | class Logger(object):
336 | DEFAULT = None # A logger with no output files. (See right below class definition)
337 | # So that you can still log to the terminal without setting up any output files
338 | CURRENT = None # Current logger being used by the free functions above
339 |
340 | def __init__(self, dir, output_formats):
341 | self.name2val = defaultdict(float) # values this iteration
342 | self.name2cnt = defaultdict(int)
343 | self.level = INFO
344 | self.dir = dir
345 | self.output_formats = output_formats
346 |
347 | # Logging API, forwarded
348 | # ----------------------------------------
349 | def logkv(self, key, val):
350 | self.name2val[key] = val
351 |
352 | def logkv_mean(self, key, val):
353 | if val is None:
354 | self.name2val[key] = None
355 | return
356 | oldval, cnt = self.name2val[key], self.name2cnt[key]
357 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
358 | self.name2cnt[key] = cnt + 1
359 |
360 | def dumpkvs(self):
361 | if self.level == DISABLED: return
362 | for fmt in self.output_formats:
363 | if isinstance(fmt, KVWriter):
364 | fmt.writekvs(self.name2val)
365 | self.name2val.clear()
366 | self.name2cnt.clear()
367 |
368 | def log(self, *args, level=INFO):
369 | if self.level <= level:
370 | self._do_log(args)
371 |
372 | # Configuration
373 | # ----------------------------------------
374 | def set_level(self, level):
375 | self.level = level
376 |
377 | def get_dir(self):
378 | return self.dir
379 |
380 | def close(self):
381 | for fmt in self.output_formats:
382 | fmt.close()
383 |
384 | # Misc
385 | # ----------------------------------------
386 | def _do_log(self, args):
387 | for fmt in self.output_formats:
388 | if isinstance(fmt, SeqWriter):
389 | fmt.writeseq(map(str, args))
390 |
391 |
392 | Logger.DEFAULT = Logger.CURRENT = Logger(dir=None, output_formats=[HumanOutputFormat(sys.stdout)])
393 |
394 |
395 | def configure(dir=None, format_strs=None, exp_name=None):
396 | if dir is None:
397 | dir = os.getenv('OPENAI_LOGDIR')
398 | if dir is None:
399 | dir = osp.join(tempfile.gettempdir(),
400 | datetime.datetime.now().strftime("chester-%Y-%m-%d-%H-%M-%S"))
401 |
402 | assert isinstance(dir, str)
403 | os.makedirs(dir, exist_ok=True)
404 |
405 | if format_strs is None:
406 | strs = os.getenv('OPENAI_LOG_FORMAT')
407 | format_strs = strs.split(',') if strs else LOG_OUTPUT_FORMATS
408 | output_formats = [make_output_format(f, dir) for f in format_strs]
409 |
410 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
411 | log('Logging to %s' % dir)
412 |
413 |
414 | def reset():
415 | if Logger.CURRENT is not Logger.DEFAULT:
416 | Logger.CURRENT.close()
417 | Logger.CURRENT = Logger.DEFAULT
418 | log('Reset logger')
419 |
420 |
421 | class scoped_configure(object):
422 | def __init__(self, dir=None, format_strs=None):
423 | self.dir = dir
424 | self.format_strs = format_strs
425 | self.prevlogger = None
426 |
427 | def __enter__(self):
428 | self.prevlogger = Logger.CURRENT
429 | configure(dir=self.dir, format_strs=self.format_strs)
430 |
431 | def __exit__(self, *args):
432 | Logger.CURRENT.close()
433 | Logger.CURRENT = self.prevlogger
434 |
435 |
436 | # ================================================================
437 |
438 | def _demo():
439 | info("hi")
440 | debug("shouldn't appear")
441 | set_level(DEBUG)
442 | debug("should appear")
443 | dir = "/tmp/testlogging"
444 | if os.path.exists(dir):
445 | shutil.rmtree(dir)
446 | configure(dir=dir)
447 | logkv("a", 3)
448 | logkv("b", 2.5)
449 | dumpkvs()
450 | logkv("b", -2.5)
451 | logkv("a", 5.5)
452 | dumpkvs()
453 | info("^^^ should see a = 5.5")
454 | logkv_mean("b", -22.5)
455 | logkv_mean("b", -44.4)
456 | logkv("a", 5.5)
457 | dumpkvs()
458 | info("^^^ should see b = 33.3")
459 |
460 | logkv("b", -2.5)
461 | dumpkvs()
462 |
463 | logkv("a", "longasslongasslongasslongasslongasslongassvalue")
464 | dumpkvs()
465 |
466 |
467 | # ================================================================
468 | # Readers
469 | # ================================================================
470 |
471 | def read_json(fname):
472 | import pandas
473 | ds = []
474 | with open(fname, 'rt') as fh:
475 | for line in fh:
476 | ds.append(json.loads(line))
477 | return pandas.DataFrame(ds)
478 |
479 |
480 | def read_csv(fname):
481 | import pandas
482 | return pandas.read_csv(fname, index_col=None, comment='#')
483 |
484 |
485 | def read_tb(path):
486 | """
487 | path : a tensorboard file OR a directory, where we will find all TB files
488 | of the form events.*
489 | """
490 | import pandas
491 | import numpy as np
492 | from glob import glob
493 | from collections import defaultdict
494 | import tensorflow as tf
495 | if osp.isdir(path):
496 | fnames = glob(osp.join(path, "events.*"))
497 | elif osp.basename(path).startswith("events."):
498 | fnames = [path]
499 | else:
500 | raise NotImplementedError("Expected tensorboard file or directory containing them. Got %s" % path)
501 | tag2pairs = defaultdict(list)
502 | maxstep = 0
503 | for fname in fnames:
504 | for summary in tf.train.summary_iterator(fname):
505 | if summary.step > 0:
506 | for v in summary.summary.value:
507 | pair = (summary.step, v.simple_value)
508 | tag2pairs[v.tag].append(pair)
509 | maxstep = max(summary.step, maxstep)
510 | data = np.empty((maxstep, len(tag2pairs)))
511 | data[:] = np.nan
512 | tags = sorted(tag2pairs.keys())
513 | for (colidx, tag) in enumerate(tags):
514 | pairs = tag2pairs[tag]
515 | for (step, value) in pairs:
516 | data[step - 1, colidx] = value
517 | return pandas.DataFrame(data, columns=tags)
518 |
519 |
520 | if __name__ == "__main__":
521 | _demo()
522 |
--------------------------------------------------------------------------------
/chester/plotting/cplot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import json
5 | import argparse
6 | import itertools
7 | import numpy as np
8 |
9 | # Matplotlib
10 | import matplotlib
11 |
12 | matplotlib.use('Agg')
13 | import matplotlib.pyplot as plt
14 |
15 | plt.rc('font', size=25)
16 | matplotlib.rcParams['pdf.fonttype'] = 42 # Default type3 cannot be rendered in some templates
17 | matplotlib.rcParams['ps.fonttype'] = 42
18 | matplotlib.rcParams['grid.alpha'] = 0.3
19 | matplotlib.rcParams['axes.titlesize'] = 25
20 | import matplotlib.ticker as tick
21 |
22 | # rllab
23 | sys.path.append('.')
24 | from rllab.misc.ext import flatten
25 | from rllab.viskit import core
26 |
27 |
28 | # from rllab.misc import ext
29 |
30 | # plotly
31 | # import plotly.offline as po
32 | # import plotly.graph_objs as go
33 |
34 |
35 | def reload_data(data_paths):
36 | """
37 | Iterate through the data folder and organize each experiment into a list, with their progress data, hyper-parameters
38 | and also analyze all the curves and give the distinct hyper-parameters.
39 |     :param data_paths: Path of the folder (or folders) storing all the data
40 | :return [exps_data, plottable_keys, distinct_params]
41 | exps_data: A list of the progress data for each curve. Each curve is an AttrDict with the key
42 | 'progress': A dictionary of plottable keys. The val of each key is an ndarray representing the
43 |                 values of the key during training, or one column in the progress.csv file.
44 |             'params'/'flat_params': A dictionary of all hyperparameters recorded in the 'variant.json' file.
45 | plottable_keys: A list of strings representing all the keys that can be plotted.
46 | distinct_params: A list of hyper-parameters which have different values among all the curves. This can be used
47 | to split the graph into multiple figures. Each element is a tuple (param, list_of_values_to_take).
48 | """
49 |
50 | exps_data = copy.copy(core.load_exps_data(data_paths, disable_variant=False, ignore_missing_keys=True))
51 | plottable_keys = copy.copy(sorted(list(set(flatten(list(exp.progress.keys()) for exp in exps_data)))))
52 | distinct_params = copy.copy(sorted(core.extract_distinct_params(exps_data)))
53 |
54 | return exps_data, plottable_keys, distinct_params
55 |
56 |
57 | def get_shaded_curve(selector, key, shade_type='variance'):
58 | """
59 | :param selector: Selector for a group of curves
60 | :param shade_type: Should be either 'variance' or 'median', indicating how the shades are calculated.
61 | :return: [y, y_lower, y_upper], representing the mean, upper and lower boundary of the shaded region
62 | """
63 |
64 | # First, get the progresses
65 | progresses = [exp.progress.get(key, np.array([np.nan])) for exp in selector.extract()]
66 | max_size = max(len(x) for x in progresses)
67 | progresses = [np.concatenate([ps, np.ones(max_size - len(ps)) * np.nan]) for ps in progresses]
68 |
69 | # Second, calculate the shaded area
70 | if shade_type == 'median':
71 | percentile25 = np.nanpercentile(
72 | progresses, q=25, axis=0)
73 | percentile50 = np.nanpercentile(
74 | progresses, q=50, axis=0)
75 | percentile75 = np.nanpercentile(
76 | progresses, q=75, axis=0)
77 |
78 | y = list(percentile50)
79 | y_upper = list(percentile75)
80 | y_lower = list(percentile25)
81 | elif shade_type == 'variance':
82 | means = np.nanmean(progresses, axis=0)
83 | stds = np.nanstd(progresses, axis=0)
84 |
85 | y = list(means)
86 | y_upper = list(means + stds)
87 | y_lower = list(means - stds)
88 | else:
89 | raise NotImplementedError
90 |
91 | return y, y_lower, y_upper
92 |
93 |
94 | def get_group_selectors(exps, custom_series_splitter):
95 | """
96 |
97 | :param exps:
98 | :param custom_series_splitter:
99 |     :return: A tuple (group_selectors, group_legends). Group selectors can be used to extract progresses.
100 | """
101 | splitted_dict = dict()
102 | for exp in exps:
103 | # Group exps by their series_splitter key
104 | # splitted_dict: {key:[exp1, exp2, ...]}
105 | key = custom_series_splitter(exp)
106 | if key not in splitted_dict:
107 | splitted_dict[key] = list()
108 | splitted_dict[key].append(exp)
109 |
110 | splitted = list(splitted_dict.items())
111 | # Group selectors: All the exps in one of the keys/legends
112 | # Group legends: All the different legends
113 | group_selectors = [core.Selector(list(x[1])) for x in splitted]
114 | group_legends = [x[0] for x in splitted]
115 | all_tuples = sorted(list(zip(group_selectors, group_legends)), key=lambda x: x[1], reverse=True)
116 | group_selectors = [x[0] for x in all_tuples]
117 | group_legends = [x[1] for x in all_tuples]
118 | return group_selectors, group_legends
119 |
120 |
121 | def filter_save_name(save_name):
122 | save_name = save_name.replace('/', '_')
123 | save_name = save_name.replace('[', '_')
124 | save_name = save_name.replace(']', '_')
125 | save_name = save_name.replace(',', '_')
126 | save_name = save_name.replace(' ', '_')
127 | save_name = save_name.replace('0.', '0_')
128 |
129 | return save_name
130 |
131 |
132 | def sliding_mean(data_array, window=5):
133 | data_array = np.array(data_array)
134 | new_list = []
135 | for i in range(len(data_array)):
136 | indices = list(range(max(i - window + 1, 0),
137 | min(i + window + 1, len(data_array))))
138 | avg = 0
139 | for j in indices:
140 | avg += data_array[j]
141 | avg /= float(len(indices))
142 | new_list.append(avg)
143 |
144 | return np.array(new_list)
145 |
146 |
147 | if __name__ == '__main__':
148 | data_path = '/Users/Dora/Projects/baselines_hrl/data/seuss/visual_rss_RopeFloat_0407'
149 | exps_data, plottable_keys, distinct_params = reload_data(data_path)
150 |
151 | # Example of extracting a single curve
152 | selector = core.Selector(exps_data)
153 | selector = selector.where('her_replay_strategy', 'balance_filter')
154 | y, y_lower, y_upper = get_shaded_curve(selector, 'test/success_state')
155 | _, ax = plt.subplots()
156 |
157 | color = core.color_defaults[0]
158 | ax.fill_between(range(len(y)), y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, alpha=0.2)
159 |     ax.plot(range(len(y)), y, color=color, label='balance_filter', linewidth=2.0)  # label must be a string
160 |
161 |
162 | # Example of extracting all the curves
163 | def custom_series_splitter(x):
164 | params = x['flat_params']
165 | if 'use_ae_reward' in params and params['use_ae_reward']:
166 | return 'Auto Encoder'
167 | if params['her_replay_strategy'] == 'balance_filter':
168 | return 'Indicator+Balance+Filter'
169 | if params['env_kwargs.use_true_reward']:
170 | return 'Oracle'
171 | return 'Indicator'
172 |
173 |
174 | fig, ax = plt.subplots(figsize=(8, 5))
175 |
176 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter)
177 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)):
178 | color = core.color_defaults[idx]
179 |
180 | y, y_lower, y_upper = get_shaded_curve(selector, 'test/success_state')
181 |
182 | ax.plot(range(len(y)), y, color=color, label=legend, linewidth=2.0)
183 | ax.fill_between(range(len(y)), y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, alpha=0.2)
184 | ax.grid(True)
185 | ax.set_xlabel('Timesteps')
186 | ax.set_ylabel('Success')
187 | loc = 'best'
188 | leg = ax.legend(loc=loc, prop={'size': 15}, ncol=1, labels=group_legends)
189 | for legobj in leg.legendHandles:
190 | legobj.set_linewidth(3.0)
191 | plt.savefig('test.png', bbox_inches='tight')
192 |
--------------------------------------------------------------------------------
/chester/pull_result.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import argparse
4 |
5 | sys.path.append('.')
6 | from chester import config
7 |
8 | if __name__ == "__main__":
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('host', type=str)
11 | parser.add_argument('folder', type=str)
12 | parser.add_argument('--dry', action='store_true', default=False)
13 | parser.add_argument('--bare', action='store_true', default=False)
14 | parser.add_argument('--img', action='store_true', default=False)
15 | parser.add_argument('--pkl', action='store_true', default=False)
16 | parser.add_argument('--gif', action='store_true', default=False)
17 | parser.add_argument('--newdatadir', action='store_true', default=False)
18 | args = parser.parse_args()
19 |
20 | args.folder = args.folder.rstrip('/')
21 |     if args.folder.rfind('/') != -1:
22 | local_dir = os.path.join('./data', args.host, args.folder[:args.folder.rfind('/')])
23 | else:
24 | local_dir = os.path.join('./data', args.host)
25 |     if args.newdatadir:
26 |         dir_path = '/data/yufeiw2/softagent_prvil_merge/'
27 |     else:
28 |         dir_path = config.REMOTE_DIR[args.host]
29 | remote_data_dir = os.path.join(dir_path, 'data', 'local', args.folder)
30 | command = """rsync -avzh --delete --progress {host}:{remote_data_dir} {local_dir}""".format(host=args.host,
31 | remote_data_dir=remote_data_dir,
32 | local_dir=local_dir)
33 | if args.bare:
34 |         command += """ --exclude '*checkpoin*' --exclude '*ckpt*' --exclude '*tfevents*' --exclude '*.pth' --exclude '*.pt' --include '*.csv' --include '*.json'"""
35 | if not args.img:
36 | command += """ --exclude '*.png' """
37 | if not args.gif:
38 | command += """ --exclude '*.gif' """
39 | if not args.pkl:
40 | command += """ --exclude '*.pkl' """
41 | if args.dry:
42 | print(command)
43 | else:
44 | os.system(command)
45 |
--------------------------------------------------------------------------------
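With `--dry`, pull_result.py only prints the assembled rsync command, which makes it easy to sanity-check the exclude filters before syncing anything. A minimal sketch of such a dry run (the host alias `seuss` and folder `0407_exp` are hypothetical):

    import subprocess

    # prints the rsync command without executing it; run from the repo root
    subprocess.run(
        ['python', 'chester/pull_result.py', 'seuss', '0407_exp', '--bare', '--dry'],
        check=True,
    )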
/chester/pull_s3_result.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import argparse
4 |
5 | def aws_sync(bucket_name, s3_log_dir, target_dir, args):
6 | cmd = 'aws s3 cp --recursive s3://%s/%s %s' % (bucket_name, s3_log_dir, target_dir)
7 | exlus = ['"*.pkl"', '"*.gif"', '"*.png"', '"*.pth"']
8 | inclus = []
9 | if args.gif:
10 | exlus.remove('"*.gif"')
11 | if args.png:
12 | exlus.remove('"*.png"')
13 | if args.param:
14 | inclus.append('"params.pkl"')
15 | exlus.remove('"*.pkl"')
16 |
17 | if not args.include_all:
18 | for exc in exlus:
19 | cmd += ' --exclude ' + exc
20 |
21 | for inc in inclus:
22 | cmd += ' --include ' + inc
23 |
24 | print(cmd)
25 | # exit()
26 | subprocess.call(cmd, shell=True)
27 |
28 |
29 | def main():
30 | parser = argparse.ArgumentParser(description='Process some integers.')
31 | parser.add_argument('log_dir', type=str, help='S3 Log dir')
32 | parser.add_argument('-b', '--bucket', type=str, default='chester-softgym', help='S3 Bucket')
33 |     parser.add_argument('--param', type=int, default=0, help='If nonzero, also pull params.pkl')
34 |     parser.add_argument('--gif', type=int, default=0, help='If nonzero, also pull *.gif files')
35 |     parser.add_argument('--png', type=int, default=0, help='If nonzero, also pull *.png files')
36 |     parser.add_argument('--include_all', type=int, default=1, help='If nonzero, pull all data without applying excludes')
37 |
38 | args = parser.parse_args()
39 | s3_log_dir = "rllab/experiments/" + args.log_dir
40 | local_dir = os.path.join('./data', 'corl_s3_data', args.log_dir)
41 | if not os.path.exists(local_dir):
42 | os.makedirs(local_dir)
43 | aws_sync(args.bucket, s3_log_dir, local_dir, args)
44 |
45 |
46 | if __name__ == "__main__":
47 | main()
48 |
--------------------------------------------------------------------------------
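For `aws s3 cp`, filters that appear later on the command line take precedence, which is why aws_sync appends all `--include` flags after the `--exclude` flags: `--param 1` can then re-include `params.pkl` even though `*.pkl` is excluded. A minimal sketch of calling aws_sync directly (the log dir is hypothetical, and without AWS credentials the printed command is the useful part):

    from argparse import Namespace

    from chester.pull_s3_result import aws_sync  # assumes the repo root is on PYTHONPATH

    args = Namespace(gif=0, png=0, param=1, include_all=0)
    aws_sync('chester-softgym', 'rllab/experiments/demo_run', './data/corl_s3_data/demo_run', args)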
/chester/rsync_exclude:
--------------------------------------------------------------------------------
1 | softgym_rpad
2 | *.img
3 | datasets
4 | data
5 | data/yufei_s3_data
6 | data/yufei_seuss_data
7 | data/local
8 | data/icml
9 | *__pycache__*
10 | build
11 | .idea
12 | .git
13 | DPI-Net
14 | videos
15 | imgs
16 | planet
17 | cem
18 | curl
19 | dreamer
20 | drq
21 | experiments
22 | PDDM
23 | pouring
24 | ResRL
25 | rlkit
26 | rlpyt_cloth
27 | tests
28 | tmp
29 | softgym/softgym/cached_initial_states/*
30 | wandb
31 | data2
32 | datasets2
--------------------------------------------------------------------------------
/chester/rsync_include:
--------------------------------------------------------------------------------
1 | GNS
2 | softgym/softgym/envs
3 | softgym/PyFlexRobotics/bindings/softgym_scenes
4 |
--------------------------------------------------------------------------------
/chester/run_exp_worker.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import os.path as osp
4 | import datetime
5 | import dateutil.tz
6 | import ast
7 | import uuid
8 | import pickle
9 | import base64
10 | import joblib
11 |
12 | from chester import config
13 |
14 |
15 | def run_experiment(argv):
16 | default_log_dir = config.LOG_DIR
17 | now = datetime.datetime.now(dateutil.tz.tzlocal())
18 |
19 | # avoid name clashes when running distributed jobs
20 | rand_id = str(uuid.uuid4())[:5]
21 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')
22 |
23 | default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--n_parallel', type=int, default=1,
26 | help='Number of parallel workers to perform rollouts. 0 => don\'t start any workers')
27 | parser.add_argument(
28 | '--exp_name', type=str, default=default_exp_name, help='Name of the experiment.')
29 | parser.add_argument('--log_dir', type=str, default=None,
30 | help='Path to save the log and iteration snapshot.')
31 | parser.add_argument('--snapshot_mode', type=str, default='all',
32 | help='Mode to save the snapshot. Can be either "all" '
33 | '(all iterations will be saved), "last" (only '
34 |                              'the last iteration will be saved), "gap" (every '
35 | '`snapshot_gap` iterations are saved), or "none" '
36 | '(do not save snapshots)')
37 | parser.add_argument('--snapshot_gap', type=int, default=1,
38 | help='Gap between snapshot iterations.')
39 | parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
40 | help='Name of the tabular log file (in csv).')
41 | parser.add_argument('--text_log_file', type=str, default='debug.log',
42 | help='Name of the text log file (in pure text).')
43 | parser.add_argument('--params_log_file', type=str, default='params.json',
44 | help='Name of the parameter log file (in json).')
45 | parser.add_argument('--variant_log_file', type=str, default='variant.json',
46 | help='Name of the variant log file (in json).')
47 | parser.add_argument('--resume_from', type=str, default=None,
48 | help='Name of the pickle file to resume experiment from.')
49 | parser.add_argument('--plot', type=ast.literal_eval, default=False,
50 | help='Whether to plot the iteration results')
51 | parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
52 | help='Whether to only print the tabular log information (in a horizontal format)')
53 | parser.add_argument('--seed', type=int,
54 | help='Random seed for numpy')
55 | parser.add_argument('--args_data', type=str,
56 | help='Pickled data for stub objects')
57 | parser.add_argument('--variant_data', type=str,
58 | help='Pickled data for variant configuration')
59 | parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)
60 |
61 | args = parser.parse_args(argv[1:])
62 |
63 | # if args.seed is not None:
64 | # set_seed(args.seed)
65 | #
66 | # if args.n_parallel > 0:
67 | # from rllab.sampler import parallel_sampler
68 | # parallel_sampler.initialize(n_parallel=args.n_parallel)
69 | # if args.seed is not None:
70 | # parallel_sampler.set_seed(args.seed)
71 | #
72 | # if args.plot:
73 | # from rllab.plotter import plotter
74 | # plotter.init_worker()
75 |
76 | if args.log_dir is None:
77 | log_dir = osp.join(default_log_dir, args.exp_name)
78 | else:
79 | log_dir = args.log_dir
80 | # tabular_log_file = osp.join(log_dir, args.tabular_log_file)
81 | # text_log_file = osp.join(log_dir, args.text_log_file)
82 | # params_log_file = osp.join(log_dir, args.params_log_file)
83 |
84 | if args.variant_data is not None:
85 | variant_data = pickle.loads(base64.b64decode(args.variant_data))
86 | variant_log_file = osp.join(log_dir, args.variant_log_file)
87 | # logger.log_variant(variant_log_file, variant_data)
88 | else:
89 | variant_data = None
90 |
91 | # if not args.use_cloudpickle:
92 | # logger.log_parameters_lite(params_log_file, args)
93 | #
94 | # logger.add_text_output(text_log_file)
95 | # logger.add_tabular_output(tabular_log_file)
96 | # prev_snapshot_dir = logger.get_snapshot_dir()
97 | # prev_mode = logger.get_snapshot_mode()
98 | # logger.set_snapshot_dir(log_dir)
99 | # logger.set_snapshot_mode(args.snapshot_mode)
100 | # logger.set_snapshot_gap(args.snapshot_gap)
101 | # logger.set_log_tabular_only(args.log_tabular_only)
102 | # logger.push_prefix("[%s] " % args.exp_name)
103 |
104 | if args.resume_from is not None:
105 | data = joblib.load(args.resume_from)
106 | assert 'algo' in data
107 | algo = data['algo']
108 | algo.train()
109 | else:
110 |         # deserialize the experiment function passed via --args_data
111 | if args.use_cloudpickle:
112 | import cloudpickle
113 | method_call = cloudpickle.loads(base64.b64decode(args.args_data))
114 | method_call(variant_data, log_dir, args.exp_name)
115 | else:
116 | assert False
117 | # data = pickle.loads(base64.b64decode(args.args_data))
118 | # maybe_iter = concretize(data)
119 | # if is_iterable(maybe_iter):
120 | # for _ in maybe_iter:
121 | # pass
122 |
123 | # logger.set_snapshot_mode(prev_mode)
124 | # logger.set_snapshot_dir(prev_snapshot_dir)
125 | # logger.remove_tabular_output(tabular_log_file)
126 | # logger.remove_text_output(text_log_file)
127 | # logger.pop_prefix()
128 |
129 |
130 | if __name__ == "__main__":
131 | run_experiment(sys.argv)
132 |
--------------------------------------------------------------------------------
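The worker expects `--args_data` to be a base64-encoded cloudpickle of a callable taking `(variant_data, log_dir, exp_name)`; this is the payload the chester launcher passes down. A minimal sketch of producing such a payload by hand (the `train` function is a stand-in for a real task):

    import base64

    import cloudpickle

    def train(variant, log_dir, exp_name):
        # stand-in task; a real task would build the model and start training
        print(exp_name, log_dir, variant)

    args_data = base64.b64encode(cloudpickle.dumps(train)).decode('utf-8')
    # then: python chester/run_exp_worker.py --use_cloudpickle True --args_data <args_data>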
/chester/scripts/install_miniconda.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Check this file before using
3 | CONDA_INSTALL_PATH="$HOME/software/miniconda3"
4 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
5 | chmod +x Miniconda3-latest-Linux-x86_64.sh
6 | ./Miniconda3-latest-Linux-x86_64.sh -b -p $CONDA_INSTALL_PATH
7 | if [ -d $CONDA_INSTALL_PATH/bin ]; then
8 |     PATH=$PATH:$CONDA_INSTALL_PATH/bin
9 | fi
10 | echo 'PATH='$CONDA_INSTALL_PATH'/bin:$PATH' >> ~/.bashrc
11 | rm ./Miniconda3-latest-Linux-x86_64.sh
--------------------------------------------------------------------------------
/chester/scripts/install_mpi4py.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Install mpi4py 3.0.0"
3 | cd /tmp
4 | wget https://bitbucket.org/mpi4py/mpi4py/downloads/mpi4py-3.0.0.tar.gz
5 | tar -zxf mpi4py-3.0.0.tar.gz
6 | cd mpi4py-3.0.0
7 | python setup.py build --mpicc=/usr/local/bin/mpicc
8 | python setup.py install --user
--------------------------------------------------------------------------------
/chester/setup_ec2_for_chester.py:
--------------------------------------------------------------------------------
1 | import boto3
2 | import re
3 | import sys
4 | import json
5 | import botocore
6 | import os
7 | from rllab.misc import console
8 | from string import Template
9 | import os.path as osp
10 |
11 | CHESTER_DIR = osp.dirname(__file__)
12 | ACCESS_KEY = os.environ["AWS_ACCESS_KEY"]
13 | ACCESS_SECRET = os.environ["AWS_ACCESS_SECRET"]
14 | S3_BUCKET_NAME = os.environ["RLLAB_S3_BUCKET"]
15 |
16 | ALL_REGION_AWS_SECURITY_GROUP_IDS = {}
17 | ALL_REGION_AWS_KEY_NAMES = {}
18 |
19 | CONFIG_TEMPLATE = Template("""
20 | import os.path as osp
21 | import os
22 |
23 | PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..'))
24 |
25 | AWS_NETWORK_INTERFACES = []
26 |
27 | MUJOCO_KEY_PATH = osp.expanduser("~/.mujoco")
28 |
29 | USE_GPU = False
30 |
31 | USE_TF = True
32 |
33 | AWS_REGION_NAME = "us-east-2"
34 |
35 | if USE_GPU:
36 | DOCKER_IMAGE = "dementrock/rllab3-shared-gpu"
37 | else:
38 | DOCKER_IMAGE = "dementrock/rllab3-shared"
39 |
40 | DOCKER_LOG_DIR = "/tmp/expt"
41 |
42 | AWS_S3_PATH = "s3://$s3_bucket_name/rllab/experiments"
43 |
44 | AWS_CODE_SYNC_S3_PATH = "s3://$s3_bucket_name/rllab/code"
45 |
46 | ALL_REGION_AWS_IMAGE_IDS = {
47 | "ap-northeast-1": "ami-002f0167",
48 | "ap-northeast-2": "ami-590bd937",
49 | "ap-south-1": "ami-77314318",
50 | "ap-southeast-1": "ami-1610a975",
51 | "ap-southeast-2": "ami-9dd4ddfe",
52 | "eu-central-1": "ami-63af720c",
53 | "eu-west-1": "ami-41484f27",
54 | "sa-east-1": "ami-b7234edb",
55 | "us-east-1": "ami-83f26195",
56 | "us-east-2": "ami-66614603",
57 | "us-west-1": "ami-576f4b37",
58 | "us-west-2": "ami-b8b62bd8"
59 | }
60 |
61 | AWS_IMAGE_ID = ALL_REGION_AWS_IMAGE_IDS[AWS_REGION_NAME]
62 |
63 | if USE_GPU:
64 | AWS_INSTANCE_TYPE = "g2.2xlarge"
65 | else:
66 | AWS_INSTANCE_TYPE = "c4.4xlarge"
67 |
68 | ALL_REGION_AWS_KEY_NAMES = $all_region_aws_key_names
69 |
70 | AWS_KEY_NAME = ALL_REGION_AWS_KEY_NAMES[AWS_REGION_NAME]
71 |
72 | AWS_SPOT = True
73 |
74 | AWS_SPOT_PRICE = '0.5'
75 |
76 | AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None)
77 |
78 | AWS_ACCESS_SECRET = os.environ.get("AWS_ACCESS_SECRET", None)
79 |
80 | AWS_IAM_INSTANCE_PROFILE_NAME = "rllab"
81 |
82 | AWS_SECURITY_GROUPS = ["rllab-sg"]
83 |
84 | ALL_REGION_AWS_SECURITY_GROUP_IDS = $all_region_aws_security_group_ids
85 |
86 | AWS_SECURITY_GROUP_IDS = ALL_REGION_AWS_SECURITY_GROUP_IDS[AWS_REGION_NAME]
87 |
88 | FAST_CODE_SYNC_IGNORES = [
89 | ".git",
90 | "data",
91 | "data/local",
92 | "data/archive",
93 | "data/debug",
94 | "data/s3",
95 | "data/video",
96 | "src",
97 | ".idea",
98 | ".pods",
99 | "tests",
100 | "examples",
101 | "docs",
102 | ".idea",
103 | ".DS_Store",
104 | ".ipynb_checkpoints",
105 | "blackbox",
106 | "blackbox.zip",
107 | "*.pyc",
108 | "*.ipynb",
109 | "scratch-notebooks",
110 | "conopt_root",
111 | "private/key_pairs",
112 | ]
113 |
114 | FAST_CODE_SYNC = True
115 |
116 | """)
117 |
118 |
119 | def setup_iam():
120 | iam_client = boto3.client(
121 | "iam",
122 | aws_access_key_id=ACCESS_KEY,
123 | aws_secret_access_key=ACCESS_SECRET,
124 | )
125 | iam = boto3.resource('iam', aws_access_key_id=ACCESS_KEY, aws_secret_access_key=ACCESS_SECRET)
126 |
127 | # delete existing role if it exists
128 | try:
129 | existing_role = iam.Role('rllab')
130 | existing_role.load()
131 | # if role exists, delete and recreate
132 | if not query_yes_no(
133 | "There is an existing role named rllab. Proceed to delete everything rllab-related and recreate?",
134 | default="no"):
135 | sys.exit()
136 | print("Listing instance profiles...")
137 | inst_profiles = existing_role.instance_profiles.all()
138 | for prof in inst_profiles:
139 | for role in prof.roles:
140 | print("Removing role %s from instance profile %s" % (role.name, prof.name))
141 | prof.remove_role(RoleName=role.name)
142 | print("Deleting instance profile %s" % prof.name)
143 | prof.delete()
144 | for policy in existing_role.policies.all():
145 | print("Deleting inline policy %s" % policy.name)
146 | policy.delete()
147 | for policy in existing_role.attached_policies.all():
148 | print("Detaching policy %s" % policy.arn)
149 | existing_role.detach_policy(PolicyArn=policy.arn)
150 | print("Deleting role")
151 | existing_role.delete()
152 | except botocore.exceptions.ClientError as e:
153 | if e.response['Error']['Code'] == 'NoSuchEntity':
154 | pass
155 | else:
156 | raise e
157 |
158 | print("Creating role rllab")
159 | iam_client.create_role(
160 | Path='/',
161 | RoleName='rllab',
162 | AssumeRolePolicyDocument=json.dumps({'Version': '2012-10-17', 'Statement': [
163 | {'Action': 'sts:AssumeRole', 'Effect': 'Allow', 'Principal': {'Service': 'ec2.amazonaws.com'}}]})
164 | )
165 |
166 | role = iam.Role('rllab')
167 | print("Attaching policies")
168 | role.attach_policy(PolicyArn='arn:aws:iam::aws:policy/AmazonS3FullAccess')
169 | role.attach_policy(PolicyArn='arn:aws:iam::aws:policy/ResourceGroupsandTagEditorFullAccess')
170 |
171 | print("Creating inline policies")
172 | iam_client.put_role_policy(
173 | RoleName=role.name,
174 | PolicyName='CreateTags',
175 | PolicyDocument=json.dumps({
176 | "Version": "2012-10-17",
177 | "Statement": [
178 | {
179 | "Effect": "Allow",
180 | "Action": ["ec2:CreateTags"],
181 | "Resource": ["*"]
182 | }
183 | ]
184 | })
185 | )
186 | iam_client.put_role_policy(
187 | RoleName=role.name,
188 | PolicyName='TerminateInstances',
189 | PolicyDocument=json.dumps({
190 | "Version": "2012-10-17",
191 | "Statement": [
192 | {
193 | "Sid": "Stmt1458019101000",
194 | "Effect": "Allow",
195 | "Action": [
196 | "ec2:TerminateInstances"
197 | ],
198 | "Resource": [
199 | "*"
200 | ]
201 | }
202 | ]
203 | })
204 | )
205 |
206 | print("Creating instance profile rllab")
207 | iam_client.create_instance_profile(
208 | InstanceProfileName='rllab',
209 | Path='/'
210 | )
211 | print("Adding role rllab to instance profile rllab")
212 | iam_client.add_role_to_instance_profile(
213 | InstanceProfileName='rllab',
214 | RoleName='rllab'
215 | )
216 |
217 |
218 | def setup_s3():
219 | print("Creating S3 bucket at s3://%s" % S3_BUCKET_NAME)
220 | s3_client = boto3.client(
221 | "s3",
222 | aws_access_key_id=ACCESS_KEY,
223 | aws_secret_access_key=ACCESS_SECRET,
224 | )
225 | try:
226 | s3_client.create_bucket(
227 | ACL='private',
228 | Bucket=S3_BUCKET_NAME,
229 | CreateBucketConfiguration={
230 | 'LocationConstraint': 'us-east-2'
231 | }
232 | )
233 | except botocore.exceptions.ClientError as e:
234 | if e.response['Error']['Code'] == 'BucketAlreadyExists':
235 | raise ValueError("Bucket %s already exists. Please reconfigure S3_BUCKET_NAME" % S3_BUCKET_NAME) from e
236 | elif e.response['Error']['Code'] == 'BucketAlreadyOwnedByYou':
237 | print("Bucket already created by you")
238 | else:
239 | raise e
240 | print("S3 bucket created")
241 |
242 |
243 | def setup_ec2():
244 | for region in ["us-east-1", "us-east-2", "us-west-1", "us-west-2"]:
245 | print("Setting up region %s" % region)
246 |
247 | ec2 = boto3.resource(
248 | "ec2",
249 | region_name=region,
250 | aws_access_key_id=ACCESS_KEY,
251 | aws_secret_access_key=ACCESS_SECRET,
252 | )
253 | ec2_client = boto3.client(
254 | "ec2",
255 | region_name=region,
256 | aws_access_key_id=ACCESS_KEY,
257 | aws_secret_access_key=ACCESS_SECRET,
258 | )
259 | existing_vpcs = list(ec2.vpcs.all())
260 | assert len(existing_vpcs) >= 1
261 | vpc = existing_vpcs[0]
262 | print("Creating security group in VPC %s" % str(vpc.id))
263 | try:
264 | security_group = vpc.create_security_group(
265 | GroupName='rllab-sg', Description='Security group for rllab'
266 | )
267 | except botocore.exceptions.ClientError as e:
268 | if e.response['Error']['Code'] == 'InvalidGroup.Duplicate':
269 | sgs = list(vpc.security_groups.filter(GroupNames=['rllab-sg']))
270 | security_group = sgs[0]
271 | else:
272 | raise e
273 |
274 | ALL_REGION_AWS_SECURITY_GROUP_IDS[region] = [security_group.id]
275 |
276 | ec2_client.create_tags(Resources=[security_group.id], Tags=[{'Key': 'Name', 'Value': 'rllab-sg'}])
277 | try:
278 | security_group.authorize_ingress(FromPort=22, ToPort=22, IpProtocol='tcp', CidrIp='0.0.0.0/0')
279 | except botocore.exceptions.ClientError as e:
280 | if e.response['Error']['Code'] == 'InvalidPermission.Duplicate':
281 | pass
282 | else:
283 | raise e
284 | print("Security group created with id %s" % str(security_group.id))
285 |
286 | key_name = 'rllab-%s' % region
287 | try:
288 | print("Trying to create key pair with name %s" % key_name)
289 | key_pair = ec2_client.create_key_pair(KeyName=key_name)
290 | except botocore.exceptions.ClientError as e:
291 | if e.response['Error']['Code'] == 'InvalidKeyPair.Duplicate':
292 | if not query_yes_no("Key pair with name %s exists. Proceed to delete and recreate?" % key_name, "no"):
293 | sys.exit()
294 | print("Deleting existing key pair with name %s" % key_name)
295 | ec2_client.delete_key_pair(KeyName=key_name)
296 | print("Recreating key pair with name %s" % key_name)
297 | key_pair = ec2_client.create_key_pair(KeyName=key_name)
298 | else:
299 | raise e
300 |
301 | key_pair_folder_path = os.path.join(CHESTER_DIR, "private", "key_pairs")
302 | file_name = os.path.join(key_pair_folder_path, "%s.pem" % key_name)
303 |
304 | print("Saving keypair file")
305 | console.mkdir_p(key_pair_folder_path)
306 | with os.fdopen(os.open(file_name, os.O_WRONLY | os.O_CREAT, 0o600), 'w') as handle:
307 | handle.write(key_pair['KeyMaterial'] + '\n')
308 |
309 | # adding pem file to ssh
310 | os.system("ssh-add %s" % file_name)
311 |
312 | ALL_REGION_AWS_KEY_NAMES[region] = key_name
313 |
314 |
315 | def write_config():
316 | print("Writing config file...")
317 | content = CONFIG_TEMPLATE.substitute(
318 | all_region_aws_key_names=json.dumps(ALL_REGION_AWS_KEY_NAMES, indent=4),
319 | all_region_aws_security_group_ids=json.dumps(ALL_REGION_AWS_SECURITY_GROUP_IDS, indent=4),
320 | s3_bucket_name=S3_BUCKET_NAME,
321 | )
322 | config_personal_file = os.path.join(CHESTER_DIR, "config_ec2.py")
323 | if os.path.exists(config_personal_file):
324 | if not query_yes_no("config_ec2.py exists. Override?", "no"):
325 | sys.exit()
326 | with open(config_personal_file, "wb") as f:
327 | f.write(content.encode("utf-8"))
328 |
329 |
330 | def setup():
331 | setup_s3()
332 | setup_iam()
333 | setup_ec2()
334 | write_config()
335 |
336 |
337 | def query_yes_no(question, default="yes"):
338 | """Ask a yes/no question via raw_input() and return their answer.
339 |
340 | "question" is a string that is presented to the user.
341 |     "default" is the presumed answer if the user just hits Enter.
342 | It must be "yes" (the default), "no" or None (meaning
343 | an answer is required of the user).
344 |
345 | The "answer" return value is True for "yes" or False for "no".
346 | """
347 | valid = {"yes": True, "y": True, "ye": True,
348 | "no": False, "n": False}
349 | if default is None:
350 | prompt = " [y/n] "
351 | elif default == "yes":
352 | prompt = " [Y/n] "
353 | elif default == "no":
354 | prompt = " [y/N] "
355 | else:
356 | raise ValueError("invalid default answer: '%s'" % default)
357 |
358 | while True:
359 | sys.stdout.write(question + prompt)
360 | choice = input().lower()
361 | if default is not None and choice == '':
362 | return valid[default]
363 | elif choice in valid:
364 | return valid[choice]
365 | else:
366 | sys.stdout.write("Please respond with 'yes' or 'no' "
367 | "(or 'y' or 'n').\n")
368 |
369 |
370 | if __name__ == "__main__":
371 | setup()
--------------------------------------------------------------------------------
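The script reads the AWS credentials and the S3 bucket name from the environment at import time and raises a KeyError if any of them is missing. A minimal sketch of the expected setup (all values are placeholders):

    import os

    # placeholders; set these before running chester/setup_ec2_for_chester.py
    os.environ['AWS_ACCESS_KEY'] = '<access-key-id>'
    os.environ['AWS_ACCESS_SECRET'] = '<secret-access-key>'
    os.environ['RLLAB_S3_BUCKET'] = '<s3-bucket-name>'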
/chester/slurm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import re
4 | from subprocess import run
5 | from tempfile import NamedTemporaryFile
6 | from chester import config
7 |
8 | # TODO remove the singularity part
9 |
10 | slurm_dir = './'
11 |
12 |
13 | def slurm_run_scripts(scripts):
14 |     """Submit a script to slurm via sbatch; the generated submission files call this to actually launch the job."""
15 | # TODO support running multiple scripts
16 |
17 | assert isinstance(scripts, str)
18 |
19 | os.chdir(slurm_dir)
20 |
21 | # make sure it will run.
22 | assert scripts.startswith('#!/usr/bin/env bash\n')
23 | file_temp = NamedTemporaryFile(delete=False)
24 | file_temp.write(scripts.encode('utf-8'))
25 | file_temp.close()
26 | run(['sbatch', file_temp.name], check=True)
27 | os.remove(file_temp.name)
28 |
29 |
30 | _find_unsafe = re.compile(r'[^a-zA-Z0-9_^@%+=:,./-]').search  # matches any character that is unsafe in a shell command
31 |
32 |
33 | def _shellquote(s):
34 | """Return a shell-escaped version of the string *s*."""
35 | if not s:
36 | return "''"
37 |
38 | if _find_unsafe(s) is None:
39 | return s
40 |
41 | # use single quotes, and put single quotes into double quotes
42 | # the string $'b is then quoted as '$'"'"'b'
43 |
44 | return "'" + s.replace("'", "'\"'\"'") + "'"
45 |
46 |
47 | def _to_param_val(v):
48 | if v is None:
49 | return ""
50 | elif isinstance(v, list):
51 | return " ".join(map(_shellquote, list(map(str, v))))
52 | else:
53 | return _shellquote(str(v))
54 |
55 |
56 | def to_slurm_command(params, header, python_command="python", remote_dir='~/',
57 | script=osp.join(config.PROJECT_PATH, 'scripts/run_experiment.py'),
58 | simg_dir=None, use_gpu=False, modules=None, cuda_module=None, use_singularity=True,
59 | mount_options=None, compile_script=None, wait_compile=None, set_egl_gpu=False):
60 | # TODO Add code for specifying the resource allocation
61 | # TODO Check if use_gpu can be applied
62 |     """
63 |     Transfer the commands to the format that can be run by slurm.
64 |     :param params: dict of command-line parameters appended to the python command
65 |     :param python_command: interpreter used to launch the script
66 |     :param script: path of the entry script to run
67 |     :param use_gpu: if True, load the cuda module and pass --nv to singularity
68 |     :return: list of command strings that make up the submission script
69 |     """
70 | assert simg_dir is not None
71 | command = python_command + " " + script
72 |
73 | pre_commands = params.pop("pre_commands", None)
74 | post_commands = params.pop("post_commands", None)
75 |
76 | command_list = list()
77 | command_list.append(header)
78 |
79 | # Log into singularity shell
80 | if use_singularity:
81 | command_list.append('set -x') # echo commands to stdout
82 | command_list.append('set -u') # throw an error if unset variable referenced
83 | command_list.append('set -e') # exit on errors
84 | command_list.append('srun hostname')
85 |
86 |     for remote_module in (modules or []):  # tolerate modules=None (the default)
87 | command_list.append('module load ' + remote_module)
88 | if use_gpu:
89 | assert cuda_module is not None
90 | command_list.append('module load ' + cuda_module)
91 | command_list.append('cd {}'.format(remote_dir))
92 | # First execute a bash program inside the container and then run all the following commands
93 |
94 | if mount_options is not None:
95 | options = '-B ' + mount_options
96 | else:
97 | options = ''
98 | sing_prefix = 'singularity exec {} {} {} /bin/bash -c'.format(options, '--nv' if use_gpu else '', simg_dir)
99 | sing_commands = list()
100 |     if compile_script is None or 'prepare' not in compile_script:
101 | # sing_commands.append('. ./prepare_1.0.sh')
102 | sing_commands.append('. ./prepare.sh')
103 | if set_egl_gpu:
104 | sing_commands.append('export EGL_GPU=$SLURM_JOB_GRES')
105 | sing_commands.append('echo $EGL_GPU')
106 | if compile_script is not None:
107 | sing_commands.append('./' + compile_script)
108 | if wait_compile is not None:
109 | sing_commands.append('sleep '+str(int(wait_compile)))
110 |
111 | if pre_commands is not None:
112 | command_list.extend(pre_commands)
113 | for k, v in params.items():
114 | if isinstance(v, dict):
115 | for nk, nv in v.items():
116 | if str(nk) == "_name":
117 | command += " --%s %s" % (k, _to_param_val(nv))
118 | else:
119 | command += " --%s_%s %s" % (k, nk, _to_param_val(nv))
120 | else:
121 | command += " --%s %s" % (k, _to_param_val(v))
122 | sing_commands.append(command)
123 | all_sing_cmds = ' && '.join(sing_commands)
124 | command_list.append(sing_prefix + ' \'{}\''.format(all_sing_cmds))
125 | if post_commands is not None:
126 | command_list.extend(post_commands)
127 | return command_list
128 |
129 | # if __name__ == '__main__':
130 | # slurm_run_scripts(header)
131 |
--------------------------------------------------------------------------------
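to_slurm_command flattens the `params` dict into CLI flags: a plain value becomes `--key value`, a list is shell-quoted and space-joined (useful for argparse `nargs`), and a nested dict becomes `--key_subkey value`, with the special `_name` entry mapping back to `--key`. A minimal standalone sketch of that mapping, mirroring the loop above (assumes the repo root is on PYTHONPATH; the params are illustrative):

    from chester.slurm import _to_param_val

    params = {'seed': 100, 'tags': ['cloth', 'v1'], 'env': {'_name': 'ClothFlatten', 'horizon': 100}}
    command = 'python main.py'
    for k, v in params.items():
        if isinstance(v, dict):
            for nk, nv in v.items():
                if str(nk) == '_name':
                    command += ' --%s %s' % (k, _to_param_val(nv))
                else:
                    command += ' --%s_%s %s' % (k, nk, _to_param_val(nv))
        else:
            command += ' --%s %s' % (k, _to_param_val(v))
    print(command)
    # -> python main.py --seed 100 --tags cloth v1 --env ClothFlatten --env_horizon 100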
/chester/upload_result.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import argparse
4 | from chester import config
5 | from chester.run_exp import rsync_code
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('mode', type=str)
10 | args = parser.parse_args()
11 |
12 | remote_dir = config.REMOTE_DIR[args.mode]
13 | rsync_code(args.mode, remote_dir)
14 |
--------------------------------------------------------------------------------
/chester/video_recorder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glfw
3 | from multiprocessing import Process
4 | 
5 | from mujoco_py.mjviewer import save_video  # worker that writes queued frames to the video file
6 |
7 | class VideoRecorder(object):
8 | '''
9 | Used to record videos for mujoco_py environment
10 | '''
11 | def __init__(self, env, saved_path='./data/videos/', saved_name='temp'):
12 | # Get rid of the gym wrappers
13 | if hasattr(env, 'env'):
14 | env = env.env
15 | self.viewer = env._get_viewer()
16 | self.saved_path = saved_path
17 | self.saved_name = saved_name
18 | # self._set_filepath('/tmp/temp%07d.mp4')
19 | saved_name += '.mp4'
20 | self._set_filepath(os.path.join(saved_path, saved_name))
21 |
22 | def _set_filepath(self, video_name):
23 | self.viewer._video_path = video_name
24 |
25 | def start(self):
26 | self.viewer._record_video = True
27 | if self.viewer._record_video:
28 | fps = (1 / self.viewer._time_per_render)
29 | self.viewer._video_process = Process(target=save_video,
30 | args=(self.viewer._video_queue,
31 | self.viewer._video_path, fps))
32 | self.viewer._video_process.start()
33 |
34 | def end(self):
35 | self.viewer.key_callback(None, glfw.KEY_V, None, glfw.RELEASE, None)
36 |
37 | # class VideoRecorderDM(object):
38 | # '''
39 | # Used to record videos for dm_control based environments
40 | # '''
41 | # def __init__(self, env, saved_path='./data/videos/', saved_name='temp'):
42 | # self.saved_path = saved_path
43 | # self.saved_name = saved_name
44 | #
45 | # def
--------------------------------------------------------------------------------
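A minimal usage sketch for VideoRecorder, assuming a mujoco_py-backed gym environment whose `_get_viewer()` takes no arguments (as in older gym versions); the env id is illustrative, and the caller must keep rendering so the viewer pushes frames to the recording process:

    import gym

    from chester.video_recorder import VideoRecorder

    env = gym.make('Hopper-v2')
    env.reset()
    recorder = VideoRecorder(env, saved_path='./data/videos/', saved_name='rollout')
    recorder.start()
    for _ in range(100):
        env.step(env.action_space.sample())
        env.render()  # each render pushes a frame onto the video queue
    recorder.end()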
/compile_1.0.sh:
--------------------------------------------------------------------------------
1 | cd softgym/PyFlex/bindings
2 | rm -rf build
3 | mkdir build
4 | cd build
5 | # Seuss
6 | if [[ $(hostname) = *"compute-0"* ]] || [[ $(hostname) = *"autobot-"* ]] || [[ $(hostname) = *"yertle"* ]]; then
7 | export CUDA_BIN_PATH=/usr/local/cuda-9.1
8 | fi
9 | cmake -DPYBIND11_PYTHON_VERSION=3.6 ..
10 | make -j
11 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: softgym
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.6.9
6 | - numpy=1.16.4
7 | - gym=0.15.7
8 | - h5py
9 | - pyg=2.0.1
10 | - Pillow=6.1
11 | - pyquaternion=0.9.5
12 | - opencv-python=5.1.1
13 | - imageio=2.6.1
14 | - imageio-ffmpeg
15 | - glob2=0.7
16 | - cmake=3.14.0
17 | - pybind11=2.4.3
18 | - moviepy
19 | - click
20 | - matplotlib
21 | - joblib
22 | - plotly
23 | - gtimer
24 | - python-pcl
25 | - pip:
26 |     - torch==1.9
27 | - torchvision
28 | - termcolor
29 | - scikit-image
30 |
--------------------------------------------------------------------------------
/prepare_1.0.sh:
--------------------------------------------------------------------------------
1 | PATH=~/software/miniconda3/bin:~/anaconda3/bin:$PATH
2 | cd softgym
3 | . prepare_1.0.sh
4 | cd ..
5 | export PYFLEXROOT=${PWD}/softgym/PyFlex
6 | export PYTHONPATH=${PWD}:${PWD}/softgym:${PYFLEXROOT}/bindings/build:$PYTHONPATH
7 | export LD_LIBRARY_PATH=${PYFLEXROOT}/external/SDL2-2.0.4/lib/x64:$LD_LIBRARY_PATH
8 | export EGL_GPU=$CUDA_VISIBLE_DEVICES
9 |
--------------------------------------------------------------------------------
/pretrained/README.md:
--------------------------------------------------------------------------------
1 | # Notes on pre-trained model
2 | Our pre-trained models, pre-collected dataset, and cached initial states for planning can be accessed through this Google Drive link: [Google Drive Link](https://drive.google.com/drive/folders/1gS8ejcY1imKVT8TD8zmNC38gNicpkL6X?usp=sharing)
3 |
4 | The Google Drive folder includes:
5 | * `dataset.zip`: Pre-collected dataset for training.
6 | * dynamics_model:
7 | - `vsbl_dyn_140.pth`: Pre-trained dynamics GNN for partially observed point cloud.
8 |     - `best_state.json`: Loading the pre-trained edge and dynamics GNNs requires a corresponding `best_state.json` that stores the model information; put this JSON file in the same directory as the pre-trained model.
9 | * edge_model:
10 | - `vsbl_edge_best.pth`: Pre-trained Edge GNN.
11 |     - `best_state.json`: Loading the pre-trained edge and dynamics GNNs requires a corresponding `best_state.json` that stores the model information; put this JSON file in the same directory as the pre-trained model.
12 | * cached_states:
13 | - `1213_release_n1000.pkl`: The cached initial states for generating the training data.
14 |     - `cloth_flatten_init_states_test_40.pkl` and `cloth_flatten_init_states_test_40_2.pkl`: In total, 40 initial states for testing cloth smoothing on square cloths (each cached file has 20 states).
15 |     - `cloth_flatten_test_retangular_1.pkl` and `cloth_flatten_test_retangular_2.pkl`: Initial states for testing on rectangular cloths (each cached file has 20 states).
16 |     - `tshirt_flatten_init_states_small_2021_05_28_01_22.pkl` and `tshirt_flatten_init_states_small_2021_05_28_01_16.pkl`: Initial states for testing on t-shirts (each cached file has 20 states).
17 |
--------------------------------------------------------------------------------
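One possible on-disk layout after downloading, consistent with the notes above (the `pretrained/` root is an assumption; what matters is that each `best_state.json` sits next to its `.pth` file):

    pretrained/
    ├── dynamics_model/
    │   ├── vsbl_dyn_140.pth
    │   └── best_state.json
    ├── edge_model/
    │   ├── vsbl_edge_best.pth
    │   └── best_state.json
    └── cached_states/
        ├── 1213_release_n1000.pkl
        └── ...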
/pretrained/vis_dynamics.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xingyu-Lin/VCD/328f3e0ada7c3d39b974e875c5847badfca5037d/pretrained/vis_dynamics.gif
--------------------------------------------------------------------------------
/pretrained/vis_planning.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Xingyu-Lin/VCD/328f3e0ada7c3d39b974e875c5847badfca5037d/pretrained/vis_planning.gif
--------------------------------------------------------------------------------