├── chester
├── rsync_include
├── scripts
│ ├── install_mpi4py.sh
│ └── install_miniconda.sh
├── config_empty.py
├── rsync_exclude
├── examples
│ ├── presets.py
│ ├── presets3.py
│ ├── presets2.py
│ ├── train_launch.py
│ ├── presets_tiancheng.py
│ ├── train.py
│ ├── pgm_plot.py
│ └── cplot_example.py
├── pull_result.py
├── video_recorder.py
├── pull_s3_result.py
├── add_variants.py
├── config_private.py
├── containers
│ └── ubuntu-16.04-lts-rl.README
├── config.py
├── config_ec2.py
├── slurm.py
├── run_exp_worker.py
├── plotting
│ └── cplot.py
├── setup_ec2_for_chester.py
├── logger.py
└── utils_s3.py
├── bc
├── 100_tool_pts.pkl
├── logger.py
├── encoder.py
├── pointnet2_classification.py
├── rotations.py
└── se3.py
├── initializers
├── __init__.py
├── random.py
├── utils.py
├── base.py
└── mm.py
├── rpmg
├── README.md
├── rpmg_losses.py
└── rpmg.py
├── compile_1.0.sh
├── prepare_1.0.sh
├── compile.sh
├── LICENSE
├── .gitignore
├── README.md
└── tests
├── test_pt3d_pyquaternion.ipynb
└── tool_reduce_test.ipynb
/chester/rsync_include:
--------------------------------------------------------------------------------
1 | softgym/PyFlexRobotics/data
--------------------------------------------------------------------------------
/bc/100_tool_pts.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DanielTakeshi/softagent_tfn/HEAD/bc/100_tool_pts.pkl
--------------------------------------------------------------------------------
/initializers/__init__.py:
--------------------------------------------------------------------------------
1 | from initializers.base import InitializerBase
2 | from initializers.utils import load_initializer, initialize_env
--------------------------------------------------------------------------------
/rpmg/README.md:
--------------------------------------------------------------------------------
1 | # RPMG (Rotation Projective Manifold Gradient)
2 |
3 | See https://github.com/JYChen18/RPMG which gives us:
4 |
5 | - `rpmg.py` (this has the layer)
6 | - `tools.py`
7 |
--------------------------------------------------------------------------------
/chester/scripts/install_mpi4py.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Download, build and install mpi4py 3.0.0 for the current user.
# Stop at the first failing step: without this a failed download or `cd`
# would let the later build/install commands run in the wrong directory.
set -euo pipefail

echo "Install mpi4py 3.0.0"
cd /tmp
wget https://bitbucket.org/mpi4py/mpi4py/downloads/mpi4py-3.0.0.tar.gz
tar -zxf mpi4py-3.0.0.tar.gz
cd mpi4py-3.0.0
# Build against the MPI compiler wrapper located at /usr/local/bin/mpicc.
python setup.py build --mpicc=/usr/local/bin/mpicc
python setup.py install --user
--------------------------------------------------------------------------------
/compile_1.0.sh:
--------------------------------------------------------------------------------
# Rebuild the PyFlex Python bindings from a clean build directory.
# Guard first: if PYFLEXROOT is unset or the cd fails, the `rm -rf build`
# below would otherwise delete a `build` directory in the caller's cwd.
# `return ... || exit` works whether this file is sourced or executed.
if [[ -z "${PYFLEXROOT:-}" ]]; then
    echo "PYFLEXROOT is not set" >&2
    return 1 2>/dev/null || exit 1
fi
cd "${PYFLEXROOT}/bindings" || { return 1 2>/dev/null || exit 1; }
rm -rf build
mkdir build
cd build
# Seuss
if [[ $(hostname) = *"compute-0"* ]] || [[ $(hostname) = *"autobot-"* ]] || [[ $(hostname) = *"yertle"* ]]; then
    export CUDA_BIN_PATH=/usr/local/cuda-9.1
fi
cmake -DPYBIND11_PYTHON_VERSION=3.6 ..
make -j
11 |
--------------------------------------------------------------------------------
/prepare_1.0.sh:
--------------------------------------------------------------------------------
# Environment setup for SoftGym / PyFlex. It exports variables, so it is
# presumably meant to be sourced from the repository root — confirm.
PATH=~/miniconda3/bin:$PATH
# Stop if softgym/ is missing: otherwise the `.` below could pick up this
# very file from the current directory and source itself again.
cd softgym || { return 1 2>/dev/null || exit 1; }
. prepare_1.0.sh
cd ..
export PYTORCH_JIT=0
export PYFLEXROOT=${PWD}/softgym/PyFlex
export PYTHONPATH=${PWD}:${PWD}/softgym:${PYFLEXROOT}/bindings/build:$PYTHONPATH
export LD_LIBRARY_PATH=${PYFLEXROOT}/external/SDL2-2.0.4/lib/x64:$LD_LIBRARY_PATH
export EGL_GPU=$CUDA_VISIBLE_DEVICES
10 |
--------------------------------------------------------------------------------
/compile.sh:
--------------------------------------------------------------------------------
# Rebuild the PyFlex Python bindings from a clean build directory.
# Guard first: if PYFLEXROOT is unset or the cd fails, the `rm -rf build`
# below would otherwise delete a `build` directory in the caller's cwd.
# `return ... || exit` works whether this file is sourced or executed.
if [[ -z "${PYFLEXROOT:-}" ]]; then
    echo "PYFLEXROOT is not set" >&2
    return 1 2>/dev/null || exit 1
fi
cd "${PYFLEXROOT}/bindings" || { return 1 2>/dev/null || exit 1; }
rm -rf build
mkdir build
cd build
# Seuss
if [[ $(hostname) = *"compute-0"* ]] || [[ $(hostname) = *"autobot-"* ]] || [[ $(hostname) = *"yertle"* ]]; then
    export CUDA_BIN_PATH=/usr/local/cuda-9.1
fi

# Hosts whose name contains "Xingyu-" use CUDA 9.2 instead.
if [[ $(hostname) = *"Xingyu-"* ]]; then
    export CUDA_BIN_PATH=/usr/local/cuda-9.2
fi
cmake -DPYBIND11_PYTHON_VERSION=3.6 ..
make -j
15 |
--------------------------------------------------------------------------------
/initializers/random.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from initializers import InitializerBase
4 |
class RandomInitializer(InitializerBase):
    """Initializer that takes sampled actions and stops at a random step."""

    def __init__(self, env, args):
        # `args` is accepted to match the InitializerBase signature; unused here.
        self.env = env

    def get_action(self, obs, info=None):
        """
        Sample an action from ``env.action_space`` and flip a coin for `done`.

        Args:
            obs: observation from the environment (ignored)
            info: info from the environment (ignored)

        Returns:
            action: the sampled action
            done: True with probability 0.5
        """
        sampled_action = self.env.action_space.sample()

        # done with probability 0.5
        return sampled_action, random.random() < 0.5
16 |
--------------------------------------------------------------------------------
/chester/scripts/install_miniconda.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Check this file before using
# Installs Miniconda3 in batch mode and prepends its bin directory to PATH
# in ~/.bashrc.
# Use $HOME instead of a quoted "~": tilde is not expanded inside quotes, so
# the installer prefix and the directory test below would otherwise refer to
# a literal "~" directory relative to the current directory.
CONDA_INSTALL_PATH="$HOME/software/miniconda3"
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
chmod +x Miniconda3-latest-Linux-x86_64.sh
./Miniconda3-latest-Linux-x86_64.sh -b -p "$CONDA_INSTALL_PATH"
if [ -d "$CONDA_INSTALL_PATH/bin" ]; then
    # NOTE(review): this appends $HOME/bin, not the conda bin directory —
    # confirm whether that is intended.
    PATH=$PATH:$HOME/bin
fi
# $PATH stays single-quoted so it is expanded when ~/.bashrc runs, not now.
echo 'PATH='"$CONDA_INSTALL_PATH"'/bin:$PATH' >> ~/.bashrc
rm ./Miniconda3-latest-Linux-x86_64.sh
--------------------------------------------------------------------------------
/chester/config_empty.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import os
3 |
# TODO change this before make it into a pip package
# Parent directory of the directory that contains this file.
PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), osp.pardir))

# Local log directory: <PROJECT_PATH>/data
LOG_DIR = osp.join(PROJECT_PATH, "data")

# The per-host tables below are intentionally empty in this template.

# Make sure to use absolute path
REMOTE_DIR = {}

REMOTE_MOUNT_OPTION = {}

REMOTE_LOG_DIR = {}

REMOTE_HEADER = {}

# location of the singularity file related to the project
SIMG_DIR = {}
CUDA_MODULE = {}
MODULES = {}
--------------------------------------------------------------------------------
/chester/rsync_exclude:
--------------------------------------------------------------------------------
1 | *.img
2 | datasets
3 | data
4 | data/yufei_s3_data
5 | data/yufei_seuss_data
6 | data/local
7 | data/icml
8 | *__pycache__*
9 | build
10 | wandb
11 | .idea
12 | .git
13 | DPI-Net
14 | videos
15 | imgs
16 | data_demo/
17 | *.swp
18 | *.swo
19 | *.png
20 | cloth_manipulation
21 | GNS
22 | PDDM
23 | PDDM
24 | pouring
25 | rlkit
26 | rllab
27 | rlpyt_cloth
28 | VCD
29 | softgym/dummy_data
30 | softgym/PyFlexRobotics/data/PR2/
31 | softgym/PyFlexRobotics/data/atlas_description/
32 | softgym/PyFlexRobotics/data/baxter_common-master/
33 | softgym/PyFlexRobotics/data/dex-net/
34 | softgym/PyFlexRobotics/data/yumi_description/
35 | softgym/PyFlexRobotics/data/franka_description/
36 | softgym/PyFlexRobotics/data/sawyer/
37 | softgym/PyFlexRobotics/data/jaco/meshes/
38 | softgym/PyFlexRobotics/data/kuka_iiwa/
39 | softgym/PyFlexRobotics/data/sektion_cabinet_model/
--------------------------------------------------------------------------------
/chester/examples/presets.py:
--------------------------------------------------------------------------------
1 | preset_names = ['default']
2 |
def make_custom_seris_splitter(preset_names):
    """
    Build the series splitter (legend function) for a plotting preset.

    Args:
        preset_names: name of the preset; only 'default' is implemented

    Returns:
        custom_series_splitter: maps an experiment record ``x`` (a mapping
            with a 'flat_params' entry) to its legend string
        legendNote: human-readable explanation of the legend format

    Raises:
        NotImplementedError: if the preset is not 'default'
    """
    if preset_names != 'default':
        raise NotImplementedError

    def custom_series_splitter(x):
        params = x['flat_params']
        strategy = params['her_replay_strategy']
        if strategy == 'future':
            ret = 'RG'
        elif strategy == 'only_fake':
            ret = 'FG+RR' if params['her_use_reward'] else 'FG+FR'
        else:
            # Previously this fell through and failed on an unbound `ret`.
            raise ValueError('Unsupported her_replay_strategy: {!r}'.format(strategy))
        return '+'.join([
            ret,
            str(params['her_clip_len']),
            str(params['her_reward_choices']),
            str(params['her_failed_goal_ratio']),
        ])

    # The reward half of the legend is FR / RR (matching the 'FG+FR' and
    # 'FG+RR' labels above), not a second "Real Goal(RG)".
    legendNote = "Fake Goal(FG)/Real Goal(RG) + Fake Reward(FR)/Real Reward(RR) + HER_clip_len + HER_reward_choices + HER_failed_goal_ratio"
    return custom_series_splitter, legendNote
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Carnegie Mellon University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/initializers/utils.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 |
def load_initializer(initializer_cls_string):
    """
    Resolve a dotted ``module.path.Name`` string to the object it names.

    The text before the last dot is imported as a module and the text after
    it is looked up as an attribute of that module.
    """
    module_path, attr_name = initializer_cls_string.rsplit('.', 1)
    return getattr(import_module(module_path), attr_name)
11 |
def initialize_env(env, initializer, initial_obs):
    """
    Step an environment with an initializer until the initializer is done.

    The action returned together with ``done=True`` is still applied to the
    environment before this function returns.

    Args:
        env: the environment to initialize
        initializer: the initializer to use
        initial_obs: the initial observation

    Returns:
        The observation, environment-done flag and info from the last
        ``env.step`` call.

    Raises:
        RuntimeError: if the environment reports done while the initializer
            still has steps left.
    """
    initializer.reset()
    # No `info` exists before the first step, so only the observation is passed.
    action, finished = initializer.get_action(initial_obs)
    obs, _, env_done, info = env.step(action)

    while True:
        if finished:
            return obs, env_done, info
        if env_done:
            raise RuntimeError("Environment is done before initializer is done!")

        action, finished = initializer.get_action(obs, info)
        obs, _, env_done, info = env.step(action)
33 |
--------------------------------------------------------------------------------
/chester/examples/presets3.py:
--------------------------------------------------------------------------------
set1 = 'identity_ratio+her_clip(dist_behavior and HER)'
preset_names = [set1]
# Legend value marking a series that the filter should drop.
FILTERED = 'filtered'


def make_custom_seris_splitter(preset_names):
    """
    Build the series splitter (legend function) for a preset.

    Args:
        preset_names: name of the preset; only `set1` is implemented

    Returns:
        custom_series_splitter: maps an experiment record (a mapping with a
            'flat_params' entry) to its legend string, or FILTERED
        legendNote: explanation of the legend abbreviations

    Raises:
        NotImplementedError: for any other preset
    """
    if preset_names != set1:
        raise NotImplementedError

    def custom_series_splitter(x):
        params = x['flat_params']
        if params['her_failed_goal_option'] in ['dist_G', 'dist_policy']:
            return FILTERED
        if params['her_identity_ratio'] is not None:
            return 'IR: ' + str(params['her_identity_ratio'])
        if params['her_clip_len'] is not None:
            return 'CL: ' + str(params['her_clip_len'])
        return 'HER'

    legendNote = 'IR: identity ratio; CL: clip length'
    return custom_series_splitter, legendNote


def make_custom_filter(preset_names):
    """
    Build a predicate that keeps only series the splitter does not filter.

    Args:
        preset_names: name of the preset; only `set1` is implemented

    Returns:
        custom_filter: returns False for records whose legend is FILTERED,
            True otherwise

    Raises:
        NotImplementedError: for any other preset (previously this fell
            through without defining the filter)
    """
    if preset_names != set1:
        raise NotImplementedError
    custom_seris_splitter, _ = make_custom_seris_splitter(preset_names)

    def custom_filter(x):
        return custom_seris_splitter(x) != FILTERED

    return custom_filter
36 |
37 |
--------------------------------------------------------------------------------
/chester/examples/presets2.py:
--------------------------------------------------------------------------------
preset_names = ['default']
x_axis = 'Epoch'
y_axis = 'Success'
# Legend value marking a series that the filter should drop.
FILTERED = 'filtered'


def make_custom_seris_splitter(preset_names):
    """
    Build the series splitter (legend function) for a preset.

    Args:
        preset_names: name of the preset; only 'default' is implemented

    Returns:
        custom_series_splitter: maps an experiment record (a mapping with a
            'flat_params' entry) to 'Distance Reward', 'Exact Match' or
            FILTERED depending on its 'her_failed_goal_option'
        legendNote: always None for this preset

    Raises:
        NotImplementedError: for any other preset
    """
    if preset_names != 'default':
        raise NotImplementedError

    def custom_series_splitter(x):
        option = x['flat_params']['her_failed_goal_option']
        if option is None:
            return 'Distance Reward'
        if option == 'dist_behaviour':
            return 'Exact Match'
        return FILTERED

    legendNote = None
    return custom_series_splitter, legendNote


def make_custom_filter(preset_names):
    """
    Build a predicate that keeps only series the splitter does not filter.

    Args:
        preset_names: name of the preset; only 'default' is implemented

    Returns:
        custom_filter: returns False for records whose legend is FILTERED,
            True otherwise

    Raises:
        NotImplementedError: for any other preset (previously this fell
            through without defining the filter)
    """
    if preset_names != 'default':
        raise NotImplementedError
    custom_seris_splitter, _ = make_custom_seris_splitter(preset_names)

    def custom_filter(x):
        return custom_seris_splitter(x) != FILTERED

    return custom_filter
40 |
41 |
--------------------------------------------------------------------------------
/chester/pull_result.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import argparse
4 |
5 | sys.path.append('.')
6 | from chester import config
7 |
if __name__ == "__main__":
    # Pull an experiment folder from a remote host into ./data/<host>/ with rsync.
    parser = argparse.ArgumentParser()
    for positional in ('host', 'folder'):
        parser.add_argument(positional, type=str)
    for flag in ('--dry', '--bare'):
        parser.add_argument(flag, action='store_true', default=False)
    args = parser.parse_args()

    args.folder = args.folder.rstrip('/')
    # For a nested folder ("a/b") the local target is the parent ("a");
    # a folder without any slash is used as the target directly.
    parent, slash, _ = args.folder.rpartition('/')
    local_subdir = parent if slash else args.folder
    local_dir = os.path.join('./data', args.host, local_subdir)

    remote_data_dir = os.path.join(config.REMOTE_DIR[args.host], 'data', 'local', args.folder)
    command = "rsync -avzh --delete --progress {}:{} {}".format(args.host, remote_data_dir, local_dir)
    if args.bare:
        # Bare mode appends filters that leave out pickles, images and checkpoints.
        command += " --exclude '*.pkl' --exclude '*.png' --exclude '*.gif' --exclude '*.pth' --exclude '*.pt' --include '*.csv' --include '*.json' --delete"

    # --dry only prints the command instead of running it.
    if args.dry:
        print(command)
    else:
        os.system(command)
31 |
--------------------------------------------------------------------------------
/initializers/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
class InitializerBase(ABC):
    """
    Abstract interface that every initializer implements.
    """

    @abstractmethod
    def __init__(self, env, args):
        """
        Set up the initializer. Variant `args` are passed in for convenience.

        Args:
            env: Environment object
            args: Argument class
        """

    def reset(self):
        """
        (Optional) Hook called when the environment is reset.
        Stateful initializers can override it to re-initialize themselves
        when a new environment is created. The default does nothing.
        """
        return None

    @abstractmethod
    def get_action(self, obs, info=None):
        """
        Produce the next initializer action.

        NOTE: `info` is NOT AVAILABLE on the first environment step!
        If `info` is always needed, just sample something arbitrary
        on the first step. `env` information saved from __init__ can
        also be used.

        However, good practice is to rely only on `obs`, as this is
        what the initializer will have access to in real.

        Args:
            obs: observation from the environment
            info: info from the environment

        Returns:
            action: action to be taken
            done: True if initializer is complete
        """
48 |
--------------------------------------------------------------------------------
/chester/examples/train_launch.py:
--------------------------------------------------------------------------------
1 | import time
2 | from chester.run_exp import run_experiment_lite, VariantGenerator
3 |
if __name__ == '__main__':

    # Example: grid search of openai's DDPG over a few environments and seeds.

    # the experiment folder name
    # the directory is defined as /LOG_DIR/data/local/exp_prefix/, where LOG_DIR is defined in config.py
    exp_prefix = 'test-ddpg'
    vg = VariantGenerator()
    vg.add('env_id', ['HalfCheetah-v2', 'Hopper-v2', 'InvertedPendulum-v2'])

    # random seeds 0 to 4
    vg.add('seed', list(range(5)))
    print('Number of configurations: ', len(vg.variants()))

    # upper bound on experiments running in parallel;
    # this number depends on the number of processors in the runner
    maximum_launching_process = 5

    # launch experiments
    sub_process_popens = []
    for vv in vg.variants():
        # Throttle: while the pool is full, drop finished processes and wait.
        while not len(sub_process_popens) < maximum_launching_process:
            still_running = []
            for proc in sub_process_popens:
                if proc.poll() is None:
                    still_running.append(proc)
            sub_process_popens = still_running
            time.sleep(10)

        # import the launcher of experiments
        from chester.examples.train import run_task

        # use your written run_task function
        launch_kwargs = dict(
            stub_method_call=run_task,
            variant=vv,
            mode='local',
            exp_prefix=exp_prefix,
            wait_subprocess=False,
        )
        cur_popen = run_experiment_lite(**launch_kwargs)
        if cur_popen is not None:
            sub_process_popens.append(cur_popen)
43 |
--------------------------------------------------------------------------------
/chester/video_recorder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glfw
3 | from multiprocessing import Process, Queue
4 |
5 | import cv2 as cv
6 |
class VideoRecorder(object):
    '''
    Records videos through the viewer of a mujoco_py environment
    '''
    def __init__(self, env, saved_path='./data/videos/', saved_name='temp'):
        # Unwrap one level of gym wrapper when present
        unwrapped = env.env if hasattr(env, 'env') else env
        self.viewer = unwrapped._get_viewer()
        self.saved_path = saved_path
        self.saved_name = saved_name
        # The video is written to <saved_path>/<saved_name>.mp4
        self._set_filepath(os.path.join(saved_path, saved_name + '.mp4'))

    def _set_filepath(self, video_name):
        # Point the viewer at the file the recording should go to
        self.viewer._video_path = video_name

    def start(self):
        # Turn recording on and spawn the process that writes queued frames
        self.viewer._record_video = True
        if not self.viewer._record_video:
            return
        fps = (1 / self.viewer._time_per_render)
        # NOTE(review): `save_video` is neither defined nor imported in this
        # module, so this line raises NameError as written — confirm where it
        # is supposed to come from.
        writer = Process(target=save_video,
                         args=(self.viewer._video_queue,
                               self.viewer._video_path, fps))
        self.viewer._video_process = writer
        writer.start()

    def end(self):
        # Send the viewer a release event for the V key
        self.viewer.key_callback(None, glfw.KEY_V, None, glfw.RELEASE, None)
36 |
37 | # class VideoRecorderDM(object):
38 | # '''
39 | # Used to record videos for dm_control based environments
40 | # '''
41 | # def __init__(self, env, saved_path='./data/videos/', saved_name='temp'):
42 | # self.saved_path = saved_path
43 | # self.saved_name = saved_name
44 | #
45 | # def
--------------------------------------------------------------------------------
/chester/pull_s3_result.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import argparse
4 |
def aws_sync(bucket_name, s3_log_dir, target_dir, args):
    """
    Copy an S3 log directory to a local directory with ``aws s3 cp``.

    Args:
        bucket_name: S3 bucket name
        s3_log_dir: key prefix inside the bucket
        target_dir: local destination directory
        args: parsed CLI options; reads `gif`, `png`, `param`, `include_all`

    The command is printed and then run through the shell.
    """
    pieces = ['aws s3 cp --recursive s3://%s/%s %s' % (bucket_name, s3_log_dir, target_dir)]
    excluded = ['"*.pkl"', '"*.gif"', '"*.png"', '"*.pth"']
    included = []

    # Each switch removes its pattern from the exclusion list.
    if args.gif:
        excluded.remove('"*.gif"')
    if args.png:
        excluded.remove('"*.png"')
    if args.param:
        included.append('"params.pkl"')
        excluded.remove('"*.pkl"')

    # Filters are only applied when not pulling everything.
    # NOTE(review): the include patterns are assumed to belong to this branch
    # together with the excludes — confirm against the original indentation.
    if not args.include_all:
        pieces.extend(' --exclude ' + pattern for pattern in excluded)
        pieces.extend(' --include ' + pattern for pattern in included)

    cmd = ''.join(pieces)
    print(cmd)
    subprocess.call(cmd, shell=True)
27 |
28 |
def main():
    """
    Command-line entry point: pull an experiment's logs from S3.

    Files under ``rllab/experiments/<log_dir>`` in the bucket are copied to
    ``./data/corl_s3_data/<log_dir>``, which is created if needed.
    """
    parser = argparse.ArgumentParser(description='Pull experiment results from S3 into ./data/corl_s3_data.')
    parser.add_argument('log_dir', type=str, help='S3 Log dir')
    parser.add_argument('-b', '--bucket', type=str, default='chester-softgym', help='S3 Bucket')
    # These switches only matter when --include_all is 0; each one lifts an
    # exclusion rather than adding one.
    parser.add_argument('--param', type=int, default=0, help='1: do not exclude *.pkl (and include params.pkl)')
    parser.add_argument('--gif', type=int, default=0, help='1: do not exclude *.gif')
    parser.add_argument('--png', type=int, default=0, help='1: do not exclude *.png')
    parser.add_argument('--include_all', type=int, default=1,
                        help='1: pull all data (default); 0: apply the exclude filters')

    args = parser.parse_args()
    s3_log_dir = "rllab/experiments/" + args.log_dir
    local_dir = os.path.join('./data', 'corl_s3_data', args.log_dir)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(local_dir, exist_ok=True)
    aws_sync(args.bucket, s3_log_dir, local_dir, args)
44 |
45 |
46 | if __name__ == "__main__":
47 | main()
48 |
--------------------------------------------------------------------------------
/chester/examples/presets_tiancheng.py:
--------------------------------------------------------------------------------
1 | # Updated By Tiancheng Jin 08/28/2018
2 |
3 | # the preset file should be contained in the experiment folder ( which is assigned by exp_prefix )
4 | # for example, this file should be put in /path to project/data/local/
5 |
6 | # Here's an example for custom_series_splitter
7 | # suppose we want to split five experiments with random seeds from 0 to 4 into two strategies
8 | # * two groups for those with even or odd random seeds: [0,2,4] and [1,3]
9 | # * two groups for those with smaller or larger random seeds: [0,1,2] and [3,4]
10 |
11 | preset_names = ['odd or plural','small or large']
12 |
13 |
def make_custom_seris_splitter(preset_name):
    """
    Build the series splitter (legend function) for a preset.

    Args:
        preset_name: 'odd or plural' (split seeds by parity) or
            'small or large' (split seeds at 2)

    Returns:
        custom_series_splitter: maps an experiment record (results plus
            parameters, with a 'flat_params' entry) to its legend string
        legend_note: short title describing the split

    Raises:
        NotImplementedError: for an unknown preset. The previous
            ``assert NotImplementedError`` could never fail, so unknown
            presets silently produced ``(None, None)``.
    """
    if preset_name == 'odd or plural':
        # split by parity of the random seed
        def custom_series_splitter(x):
            # extract the parameters
            params = x['flat_params']
            # make up the legend; seeds divisible by two are the even ones
            if params['seed'] % 2 == 0:
                legend = 'even seeds'
            else:
                legend = 'odd seeds'
            return legend

        legend_note = "Odd or Even"

    elif preset_name == 'small or large':
        def custom_series_splitter(x):
            params = x['flat_params']
            if params['seed'] <= 2:
                legend = 'smaller seeds'
            else:
                legend = 'larger seeds'
            return legend

        legend_note = "Small or Large"
    else:
        raise NotImplementedError

    return custom_series_splitter, legend_note
47 |
--------------------------------------------------------------------------------
/chester/add_variants.py:
--------------------------------------------------------------------------------
1 | # Add variants to finished experiments
2 | import argparse
3 | import os
4 | import json
5 | from pydoc import locate
6 | import config
7 |
8 |
def main():
    """
    Add (or overwrite) one key in every ``variant.json`` under a folder.

    Positional CLI arguments: exp_folder, key, value, then optionally
    value_type (dotted type name, default 'str') and remote (a host key of
    ``config.REMOTE_DIR``). When a remote is given, each modified file is
    also copied to that host with scp.

    Directories without a readable, valid ``variant.json`` are reported and
    skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('exp_folder', type=str, help='Root of the experiment folder to walk through')
    parser.add_argument('key', type=str, help='Name of the additional key')
    parser.add_argument('value', help='Value of the additional key')
    # nargs='?' is what makes the declared default reachable; a required
    # positional never falls back to its default.
    parser.add_argument('value_type', nargs='?', default='str', type=str, help='Type of the additional key')
    parser.add_argument('remote', nargs='?', default=None, type=str, )  # Optional

    args = parser.parse_args()
    exp_paths = [x[0] for x in os.walk(args.exp_folder, followlinks=True)]

    # NOTE: locate() resolves any importable dotted name and the result is
    # called on the raw value, so only pass trusted type names here.
    value_type = locate(args.value_type)
    if value_type is None:
        parser.error('Unknown value_type: {}'.format(args.value_type))
    if value_type is bool:
        # bool('False') would be True, so parse the common spellings instead.
        value = args.value in ['1', 'True', 'true']
    else:
        value = value_type(args.value)

    for exp_path in exp_paths:
        variant_path = os.path.join(exp_path, "variant.json")
        try:
            # Modify locally
            with open(variant_path, 'r') as f:
                vv = json.load(f)
            if args.key in vv:
                print('Warning: key already in variants. {} = {}. Setting it to {}'.format(args.key, vv[args.key], value))

            vv[args.key] = value
            with open(variant_path, 'w') as f:
                json.dump(vv, f, indent=2, sort_keys=True)
            print('{} modified'.format(variant_path))

            # Upload it to remote
            if args.remote is not None:
                # .../<exp_name>/<sub_exp_name>/variant.json; going through
                # the absolute path keeps this working for short relative paths.
                sub_exp_dir = os.path.dirname(os.path.abspath(variant_path))
                sub_exp_name = os.path.basename(sub_exp_dir)
                exp_name = os.path.basename(os.path.dirname(sub_exp_dir))

                remote_dir = os.path.join(config.REMOTE_DIR[args.remote], 'data', 'local', exp_name, sub_exp_name, 'variant.json')
                os.system('scp {} {}:{}'.format(variant_path, args.remote, remote_dir))
        except (IOError, ValueError) as e:
            # IOError: missing/unreadable file; ValueError: malformed JSON.
            print(e)
50 |
51 | if __name__ == '__main__':
52 | main()
53 |
--------------------------------------------------------------------------------
/chester/config_private.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import os
3 |
# TODO change this before make it into a pip package
# Parent directory of the directory that contains this file.
PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), osp.pardir))

LOG_DIR = osp.join(PROJECT_PATH, "data")

# Make sure to use absolute path
# Every listed host keeps the project checkout at the same location.
REMOTE_DIR = {
    host: '/home/yufeiw2/Projects/softagent'
    for host in ('seuss', 'psc', 'nsh', 'yertle')
}

REMOTE_MOUNT_OPTION = dict(
    seuss='/usr/share/glvnd',
    # 'psc': '/pylon5/ir5fpfp/xlin3/Projects/baselines_hrl/:/mnt',
)

REMOTE_LOG_DIR = dict(
    seuss=osp.join(REMOTE_DIR['seuss'], "data"),

    # 'psc': os.path.join(REMOTE_DIR['psc'], "data")
    psc=osp.join('/mnt', "data"),
)
# PSC: https://www.psc.edu/bridges/user-guide/running-jobs
# partition include [RM, RM-shared, LM, GPU]
# TODO change cpu-per-task based on the actual cpus needed (on psc)
# #SBATCH --exclude=compute-0-[7,11]
# Adding this will make the job to grab the whole gpu. #SBATCH --gres=gpu:1
# Batch-script headers per cluster; .strip() drops the newlines that
# surround each triple-quoted block.
REMOTE_HEADER = {
    'seuss': """
#!/usr/bin/env bash
#SBATCH --nodes=1
#SBATCH --partition=GPU
#SBATCH --exclude=compute-0-[5,7,13,27]
#SBATCH --ntasks-per-node=8
#SBATCH --time=480:00:00
#SBATCH --gres=gpu:1
#SBATCH --mem=90G
""".strip(),
    'psc': """
#!/usr/bin/env bash
#SBATCH --nodes=1
#SBATCH --partition=RM
#SBATCH --ntasks-per-node=18
#SBATCH --time=48:00:00
#SBATCH --mem=64G
""".strip(),
    'psc_gpu': """
#!/usr/bin/env bash
#SBATCH --nodes=1
#SBATCH --partition=GPU-shared
#SBATCH --gres=gpu:p100:1
#SBATCH --ntasks-per-node=4
#SBATCH --time=48:00:00
""".strip(),
}

# location of the singularity file related to the project
SIMG_DIR = dict(
    seuss='/home/xlin3/softgym_containers/softgymcontainer_v3.simg',
    # 'psc': '$SCRATCH/containers/ubuntu-16.04-lts-rl.img',
    psc='/pylon5/ir5fpfp/xlin3/containers/ubuntu-16.04-lts-rl.img',

)
CUDA_MODULE = dict(
    seuss='cuda-91',
    psc='cuda/9.0',
)
MODULES = dict(
    seuss=['singularity'],
    psc=['singularity'],
)
73 |
--------------------------------------------------------------------------------
/chester/containers/ubuntu-16.04-lts-rl.README:
--------------------------------------------------------------------------------
1 | Bootstrap: debootstrap
2 | OSVersion: xenial
3 | MirrorURL: http://us.archive.ubuntu.com/ubuntu/
4 |
5 | %help
6 | This is a singularity container that runs Deep Reinforcement Learning algorithms on ubuntu
7 | Packages installed include:
8 | * cuda 9.0 and cuDNN
9 | Will run ~/.bashrc on start to make sure the PATH is the same.
10 |
11 | %runscript
12 | /usr/bin/nvidia-smi -L
13 |
%environment
    # A ':' must separate the last entry from the inherited value; without it
    # "/usr/lib/nvidia-384" is glued onto the first inherited path and both
    # entries are lost.
    LD_LIBRARY_PATH=/usr/local/cuda-9.0/cuda/lib64:/usr/local/cuda-9.0/lib64:/usr/lib/nvidia-384:$LD_LIBRARY_PATH
16 |
17 | %setup
18 | echo "Let us have CUDA..."
19 | sh /home/xingyu/software/cuda/cuda_9.0.176_384.81_linux.run --silent --toolkit --toolkitpath=${SINGULARITY_ROOTFS}/usr/local/cuda-9.0
20 | ln -s ${SINGULARITY_ROOTFS}/usr/local/cuda-9.0 ${SINGULARITY_ROOTFS}/usr/local/cuda
21 | echo "Let us also have cuDNN..."
22 | cp -prv /home/xingyu/software/cudnn/* ${SINGULARITY_ROOTFS}/usr/local/cuda-9.0/
23 |
24 | %labels
25 | AUTHOR xlin3@cs.cmu.edu
26 | VERSION v1.0
27 |
28 | %post
29 | echo "Hello from inside the container"
30 | sed -i 's/$/ universe/' /etc/apt/sources.list
31 | touch /usr/bin/nvidia-smi
32 | chmod +x /usr/bin/nvidia-smi
33 |
34 | apt-get -y update
35 | apt-get -y install software-properties-common vim make wget curl emacs ffmpeg git htop libffi-dev libglew-dev libgl1-mesa-glx libosmesa6 libosmesa6-dev libssl-dev mesa-utils module-init-tools openjdk-8-jdk python-dev python-numpy python-tk bzip2
36 | apt-get -y install build-essential
37 | apt-get -y install libgl1-mesa-dev libglfw3-dev
38 | apt-get -y install strace
39 |
40 | echo "Install openmpi 3.1.1"
41 |
42 | cd /tmp
43 | wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.1.tar.gz
44 | tar xf openmpi-3.1.1.tar.gz
45 | cd openmpi-3.1.1
46 | mkdir -p build
47 | cd build
48 | ../configure
49 | make -j 8 all
50 | make install
51 | apt-get -y install openmpi-bin
52 | rm -rf /tmp/openmpi*
53 | rm -rf /usr/bin/mpirun
54 | ln -s /usr/local/bin/mpirun /usr/bin/mpirun
55 |
56 | echo "Install mpi4py 3.0.0"
57 | cd /tmp
58 | wget https://bitbucket.org/mpi4py/mpi4py/downloads/mpi4py-3.0.0.tar.gz
59 | tar -zxf mpi4py-3.0.0.tar.gz
60 | cd mpi4py-3.0.0
61 | python setup.py build --mpicc=
62 | python setup.py install --user
63 | mkdir -p /usr/lib/nvidia-384
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Daniel Seita
2 | data/plots/
3 | *.png
4 | *.swp
5 | *.swo
6 | test.sh
7 | tmp/
8 |
9 |
10 | GNS/pcl_filter/build
11 | rlpyt_cloth/data
12 | datasets
13 | dpi_visualization
14 | *.gif
15 | .vscode
16 | chester/private
17 | videos
18 | .idea
19 | data
20 | *.simg
21 | *.img
22 | # Byte-compiled / optimized / DLL files
23 | __pycache__/
24 | *.py[cod]
25 | *$py.class
26 |
27 | # C extensions
28 | *.so
29 |
30 | # Distribution / packaging
31 | .Python
32 | build/
33 | develop-eggs/
34 | dist/
35 | downloads/
36 | eggs/
37 | .eggs/
38 | lib/
39 | lib64/
40 | parts/
41 | sdist/
42 | var/
43 | wheels/
44 | pip-wheel-metadata/
45 | share/python-wheels/
46 | *.egg-info/
47 | .installed.cfg
48 | *.egg
49 | MANIFEST
50 |
51 | # PyInstaller
52 | # Usually these files are written by a python script from a template
53 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
54 | *.manifest
55 | *.spec
56 |
57 | # Installer logs
58 | pip-log.txt
59 | pip-delete-this-directory.txt
60 |
61 | # Unit test / coverage reports
62 | htmlcov/
63 | .tox/
64 | .nox/
65 | .coverage
66 | .coverage.*
67 | .cache
68 | nosetests.xml
69 | coverage.xml
70 | *.cover
71 | .hypothesis/
72 | .pytest_cache/
73 |
74 | # Translations
75 | *.mo
76 | *.pot
77 |
78 | # Django stuff:
79 | *.log
80 | local_settings.py
81 | db.sqlite3
82 |
83 | # Flask stuff:
84 | instance/
85 | .webassets-cache
86 |
87 | # Scrapy stuff:
88 | .scrapy
89 |
90 | # Sphinx documentation
91 | docs/_build/
92 |
93 | # PyBuilder
94 | target/
95 |
96 | # Jupyter Notebook
97 | .ipynb_checkpoints
98 |
99 | # IPython
100 | profile_default/
101 | ipython_config.py
102 |
103 | # pyenv
104 | .python-version
105 |
106 | # celery beat schedule file
107 | celerybeat-schedule
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 | .dmypy.json
134 | dmypy.json
135 |
136 | # Pyre type checker
137 | .pyre/
138 |
139 |
140 | # Results
141 | results/
142 |
143 | # DPI data
144 | DPI-Net/dump_*/
145 | DPI-Net/dump_*/
146 |
147 | test_*/
148 | imgs/
149 | softgym
150 | wandb
151 |
--------------------------------------------------------------------------------
/chester/examples/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from baselines import logger
4 | from baselines.common.misc_util import (
5 | set_global_seeds,
6 | )
7 | from baselines.ddpg.main import run
8 | from mpi4py import MPI
9 |
10 |
# Default hyper-parameters for the DDPG example. `run_task` overlays the
# experiment variant on top of these and forwards the result to
# `baselines.ddpg.main.run` as keyword arguments.
DEFAULT_PARAMS = {
    # env
    'env_id': 'HalfCheetah-v2',  # environment identifier

    # ddpg
    'layer_norm': True,
    'render': False,
    'normalize_returns':False,
    'normalize_observations':True,
    'actor_lr': 0.0001,  # actor learning rate
    'critic_lr': 0.001,  # critic learning rate
    'critic_l2_reg': 1e-2,
    'popart': False,
    'gamma': 0.99,

    # training
    'seed': 0,
    'nb_epochs':500,  # number of epochs
    'nb_epoch_cycles': 20,  # per epoch
    'nb_rollout_steps': 100,  # sampling batches per cycle
    'nb_train_steps': 100,  # training batches per cycle
    'batch_size': 64,  # per mpi thread, measured in transitions and reduced to even multiple of chunk_length.
    'reward_scale': 1.0,
    'clip_norm': None,

    # exploration
    'noise_type':'adaptive-param_0.2',

    # debugging, logging and visualization
    'render_eval': False,
    'nb_eval_steps':100,
    'evaluation':False,
}
44 |
45 |
def run_task(vv, log_dir=None, exp_name=None, allow_extra_parameters=False):
    """Run one DDPG experiment configured by the variant ``vv``.

    :param vv: variant dict; must contain ``'seed'``. Its entries override the
        values in ``DEFAULT_PARAMS``.
    :param log_dir: directory passed to ``logger.configure``. The logger is
        (re)configured when this is truthy or when it has no directory yet.
    :param exp_name: not used by this function.
    :param allow_extra_parameters: when False, keys of ``vv`` that are missing
        from ``DEFAULT_PARAMS`` are reported with a warning and dropped; when
        True every key of ``vv`` is forwarded to ``run``.
    """
    # Configure logging system
    if log_dir or logger.get_dir() is None:
        logger.configure(dir=log_dir)
    logdir = logger.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    # Seed for multi-CPU MPI implementation ( rank = 0 for single threaded implementation )
    rank = MPI.COMM_WORLD.Get_rank()
    rank_seed = vv['seed'] + 1000000 * rank
    set_global_seeds(rank_seed)

    # Work on a copy: updating DEFAULT_PARAMS itself would leak one call's
    # variant (including extra keys) into every later call in this process.
    params = dict(DEFAULT_PARAMS)

    if allow_extra_parameters:
        params.update(vv)
    else:
        for k, v in vv.items():
            if k not in DEFAULT_PARAMS:
                print("[ Warning ] Undefined Parameters %s with value %s"%(str(k),str(v)))
        params.update({k: v for (k, v) in vv.items() if k in DEFAULT_PARAMS})

    # Record the effective configuration next to the logs.
    with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(params, f)

    run(**params)
75 |
76 |
77 |
--------------------------------------------------------------------------------
/chester/examples/pgm_plot.py:
--------------------------------------------------------------------------------
1 | from chester.plotting.cplot import *
2 | import os.path as osp
3 | from random import shuffle
4 |
# Directory the rendered figure is saved into.
save_path = '../sac/data/plots'
# Legend label -> index into `core.color_defaults`.
dict_leg2col = {"LSP": 0, "Base": 1, "Behavior": 2}
# Legend label -> constant added to that curve's x values.
dict_xshift = {"LSP": 4000, "Base": 0, "Behavior": 6000}
8 |
9 |
def custom_series_splitter(x):
    """Return the legend label for an experiment record, keyed by its ``exp_name``.

    ``x`` must provide ``x['flat_params']['exp_name']``; a name that is not
    listed below raises ``KeyError``.
    """
    exp_name = x['flat_params']['exp_name']
    legend_by_exp_name = {
        'humanoid-resume-training-6000-00': 'Behavior',
        'humanoid-resume-training-4000-00': 'LSP',
        'humanoid-rllab/default-2019-04-14-07-04-08-421230-UTC-00': 'Base',
    }
    return legend_by_exp_name[exp_name]
17 |
18 |
def sliding_mean(data_array, window=5):
    """Smooth a sequence with a moving average and return it as an array.

    For position ``i`` the average covers indices ``i - window + 1`` through
    ``i + window`` (inclusive), clipped to the bounds of the input, so the
    window reaches one element further forward than backward.
    """
    values = np.array(data_array)
    count = len(values)
    smoothed = []
    for center in range(count):
        span = range(max(center - window + 1, 0), min(center + window + 1, count))
        total = sum((values[j] for j in span), 0)
        total /= float(len(span))
        smoothed.append(total)

    return np.array(smoothed)
32 |
33 |
def plot_main():
    """Plot the smoothed median ``return-average`` curve of each experiment group.

    Experiments are loaded from ``../sac/data/mengxiong`` and grouped with
    ``custom_series_splitter``. Each group's curve is shifted along x by its
    ``dict_xshift`` entry, smoothed with ``sliding_mean`` and drawn on one
    figure, which is saved as a PNG under ``save_path``.
    """
    # NOTE(review): indentation was flattened in the listing this was taken
    # from; the loop is assumed to end after `ax.plot(...)` - confirm.
    data_path = '../sac/data/mengxiong'
    plot_key = 'return-average'
    exps_data, plottable_keys, distinct_params = reload_data(data_path)
    group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter)
    fig, ax = plt.subplots(figsize=(8, 5))
    for selector, legend in zip(group_selectors, group_legends):
        color = core.color_defaults[dict_leg2col[legend]]

        # Only the central curve is drawn; the lower/upper band is not plotted.
        y, _y_lower, _y_upper = get_shaded_curve(selector, plot_key, shade_type='median')
        x = np.array(range(len(y)))
        x += dict_xshift[legend]
        y = sliding_mean(y, 5)
        ax.plot(x, y, color=color, label=legend, linewidth=2.0)

    def y_fmt(x, y):
        # Tick formatter: rounded tick value followed by 'K'.
        return str(int(np.round(x))) + 'K'

    ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt))
    ax.grid(True)
    ax.set_xlabel('Timesteps')
    ax.set_ylabel('Average-return')

    loc = 'best'
    leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends)
    # Matplotlib 3.7 renamed `Legend.legendHandles` to `legend_handles` and
    # later removed the old name; support both.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    for legobj in handles:
        legobj.set_linewidth(3.0)

    save_name = filter_save_name('plots.png')
    plt.savefig(osp.join(save_path, save_name), bbox_inches='tight')
68 |
69 |
# Script entry point: build and save the figure when executed directly.
if __name__ == '__main__':
    plot_main()
72 |
--------------------------------------------------------------------------------
/chester/config.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import os
3 |
# TODO change this before make it into a pip package
# Repository root: the parent of the directory holding this file.
PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..'))

# Local directory for experiment logs.
LOG_DIR = osp.join(PROJECT_PATH, "data")

# Project checkout on each remote machine. Make sure to use absolute path
REMOTE_DIR = dict(
    seuss='/home/dseita/softagent_rpad_MM',
    autobot='/home/xlin3/Projects/softagent',
    psc='/home/xlin3/Projects/softagent',
    nsh='/home/xingyu/Projects/softagent',
    yertle='/home/xingyu/Projects/softagent',
)

# Per-cluster mount option string.
REMOTE_MOUNT_OPTION = dict(
    seuss='/usr/share/glvnd',
    autobot='/usr/share/glvnd',
    # psc='/pylon5/ir5fpfp/xlin3/Projects/baselines_hrl/:/mnt',
)

# Remote log directories: "<REMOTE_DIR>/data" for seuss and autobot ...
REMOTE_LOG_DIR = {
    cluster: osp.join(REMOTE_DIR[cluster], "data") for cluster in ('seuss', 'autobot')
}
# ... while psc uses a path under /mnt rather than REMOTE_DIR['psc'].
REMOTE_LOG_DIR['psc'] = osp.join('/mnt', "data")

# PSC: https://www.psc.edu/bridges/user-guide/running-jobs
# partition include [RM, RM-shared, LM, GPU]
# TODO change cpu-per-task based on the actual cpus needed (on psc)
# #SBATCH --exclude=compute-0-[7,11]
# Adding this will make the job to grab the whole gpu. #SBATCH --gres=gpu:1
# sbatch script headers, keyed by cluster/queue name.
REMOTE_HEADER = {
    'seuss': """
#!/usr/bin/env bash
#SBATCH --nodes=1
#SBATCH --partition=GPU
#SBATCH --exclude=compute-0-[7,9,27]
#SBATCH --cpus-per-task=8
#SBATCH --time=480:00:00
#SBATCH --gres=gpu:1
#SBATCH --mem=110G
""".strip(),
    'psc': """
#!/usr/bin/env bash
#SBATCH --nodes=1
#SBATCH --partition=RM
#SBATCH --ntasks-per-node=18
#SBATCH --time=48:00:00
#SBATCH --mem=64G
""".strip(),
    'psc_gpu': """
#!/usr/bin/env bash
#SBATCH --nodes=1
#SBATCH --partition=GPU-shared
#SBATCH --gres=gpu:p100:1
#SBATCH --ntasks-per-node=4
#SBATCH --time=48:00:00
""".strip(),
    'autobot': """
#!/usr/bin/env bash
#SBATCH --nodes=1
#SBATCH --partition=long
#SBATCH --cpus-per-task=32
#SBATCH --time=3-12:00:00
#SBATCH --gres=gpu:1
#SBATCH --mem=40G
""".strip(),
}

# location of the singularity file related to the project
SIMG_DIR = dict(
    seuss='/home/xlin3/softgym_containers/softgymcontainer_v4.simg',
    autobot='/home/xlin3/softgym_containers/softgymcontainer_v3.simg',
    # psc='$SCRATCH/containers/ubuntu-16.04-lts-rl.img',
    psc='/pylon5/ir5fpfp/xlin3/containers/ubuntu-16.04-lts-rl.img',
)
# CUDA module name per cluster.
CUDA_MODULE = dict(
    seuss='cuda-91',
    autobot='cuda-10.2',
    psc='cuda/9.0',
)
# Additional module names per cluster (each cluster gets its own list).
MODULES = {cluster: ['singularity'] for cluster in ('seuss', 'autobot', 'psc')}
87 |
--------------------------------------------------------------------------------
/chester/config_ec2.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import os
3 |
# Repository root: the parent of the directory holding this file.
PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..'))

AWS_NETWORK_INTERFACES = []

AWS_BUCKET_REGION_NAME = 'us-east-2'

MUJOCO_KEY_PATH = osp.expanduser("~/.mujoco")

# Selects the GPU variants of DOCKER_IMAGE and AWS_INSTANCE_TYPE below.
USE_GPU = True

USE_TF = True

# Key used to pick the per-region image id, key name and security group ids.
AWS_REGION_NAME = "us-east-2"

if USE_GPU:
    DOCKER_IMAGE = "dementrock/rllab3-shared-gpu"
else:
    DOCKER_IMAGE = "dementrock/rllab3-shared"

DOCKER_LOG_DIR = "/tmp/expt"

CODE_DIR = "/root/code/"

AWS_S3_PATH = "s3://chester-softgym/rllab/experiments"

EBS_OPTIMIZED = True

AWS_EXTRA_CONFIGS = dict()

AWS_CODE_SYNC_S3_PATH = "s3://chester-softgym/rllab/code"

# Image id per region; the commented-out regions are not selectable.
ALL_REGION_AWS_IMAGE_IDS = {
    # "ap-northeast-1": "ami-002f0167",
    # "ap-northeast-2": "ami-590bd937",
    # "ap-south-1": "ami-77314318",
    # "ap-southeast-1": "ami-1610a975",
    # "ap-southeast-2": "ami-9dd4ddfe",
    # "eu-central-1": "ami-63af720c",
    # "eu-west-1": "ami-41484f27",
    # "sa-east-1": "ami-b7234edb",
    "us-east-1": "ami-83f26195",
    "us-east-2": "ami-0ec385d5f98faacc3", #"ami-0e63a1a8842443350",
    "us-west-1": "ami-576f4b37",
    "us-west-2": "ami-b8b62bd8"
}

# Raises KeyError at import time if AWS_REGION_NAME has no entry above.
AWS_IMAGE_ID = ALL_REGION_AWS_IMAGE_IDS[AWS_REGION_NAME]

if USE_GPU:
    AWS_INSTANCE_TYPE = "p2.xlarge"
else:
    AWS_INSTANCE_TYPE = "c4.4xlarge"

# Key-pair name per region.
ALL_REGION_AWS_KEY_NAMES = {
    "us-east-1": "rllab-us-east-1",
    "us-east-2": "rllab-us-east-2",
    "us-west-1": "rllab-us-west-1",
    "us-west-2": "rllab-us-west-2"
}

AWS_KEY_NAME = ALL_REGION_AWS_KEY_NAMES[AWS_REGION_NAME]

AWS_SPOT = True

# Kept as a string, not a number.
AWS_SPOT_PRICE = '2.0'

# Credentials come from the environment; both are None when the variable is unset.
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None)

AWS_ACCESS_SECRET = os.environ.get("AWS_ACCESS_SECRET", None)

AWS_IAM_INSTANCE_PROFILE_NAME = "rllab"

AWS_SECURITY_GROUPS = ["rllab-sg"]

# Security group ids per region.
ALL_REGION_AWS_SECURITY_GROUP_IDS = {
    "us-east-1": [
        "sg-e17bfc95"
    ],
    "us-east-2": [
        "sg-1ddb3876"
    ],
    "us-west-1": [
        "sg-cd5f9db4"
    ],
    "us-west-2": [
        "sg-b585a8c9"
    ]
}

AWS_SECURITY_GROUP_IDS = ALL_REGION_AWS_SECURITY_GROUP_IDS[AWS_REGION_NAME]
94 |
95 | FAST_CODE_SYNC_IGNORES = [
96 | ".git",
97 | "data/autobot",
98 | "data/corl_s3_data",
99 | "data/videos",
100 | "data/open_loop_videos",
101 | "data/icml",
102 | "data/local",
103 | "data/seuss",
104 | "data/yufei_s3_data",
105 | "data/icml"
106 | "data/local",
107 | "data/archive",
108 | "data/debug",
109 | "data/s3",
110 | "data/video",
111 | ".idea",
112 | "tests",
113 | "examples",
114 | "docs",
115 | ".idea",
116 | ".DS_Store",
117 | ".ipynb_checkpoints",
118 | "blackbox",
119 | "blackbox.zip",
120 | "*.pyc",
121 | "*.ipynb",
122 | "scratch-notebooks",
123 | "conopt_root",
124 | "private/key_pairs",
125 | "DPI-Net",
126 | "imgs/",
127 | "imgs",
128 | "videos"
129 | ]
130 |
131 | FAST_CODE_SYNC = True
132 |
133 | LABEL = ""
134 |
--------------------------------------------------------------------------------
/chester/slurm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import re
4 | from subprocess import run
5 | from tempfile import NamedTemporaryFile
6 | from chester import config
7 |
8 | # TODO remove the singularity part
9 |
# Working directory that `slurm_run_scripts` changes into before submitting.
slurm_dir = './'


def slurm_run_scripts(scripts):
    """Submit a batch script to slurm with ``sbatch``.

    The script text is written to a temporary file, submitted, and the file is
    removed afterwards - also when writing or submission fails.

    :param scripts: full text of one batch script; must be a ``str`` that
        starts with ``'#!/usr/bin/env bash\\n'``.
    :raises AssertionError: if ``scripts`` is not such a string.
    :raises subprocess.CalledProcessError: if ``sbatch`` exits non-zero.

    Side effect: changes the process working directory to ``slurm_dir``.
    """
    # TODO support running multiple scripts

    assert isinstance(scripts, str)

    os.chdir(slurm_dir)

    # make sure it will run.
    assert scripts.startswith('#!/usr/bin/env bash\n')
    file_temp = NamedTemporaryFile(delete=False)
    try:
        # Close the file before sbatch reads it so the content is flushed.
        with file_temp:
            file_temp.write(scripts.encode('utf-8'))
        run(['sbatch', file_temp.name], check=True)
    finally:
        # `check=True` raises on failure; without this the temp file would leak.
        os.remove(file_temp.name)
28 |
29 |
30 | _find_unsafe = re.compile(r'[a-zA-Z0-9_^@%+=:,./-]').search
31 |
32 |
33 | def _shellquote(s):
34 | """Return a shell-escaped version of the string *s*."""
35 | if not s:
36 | return "''"
37 |
38 | if _find_unsafe(s) is None:
39 | return s
40 |
41 | # use single quotes, and put single quotes into double quotes
42 | # the string $'b is then quoted as '$'"'"'b'
43 |
44 | return "'" + s.replace("'", "'\"'\"'") + "'"
45 |
46 |
def _to_param_val(v):
    """Render a parameter value as shell-quoted command-line text.

    ``None`` becomes the empty string, a list becomes its items (each passed
    through ``str`` and ``_shellquote``) joined by single spaces, and any other
    value is stringified and quoted as a single word.
    """
    if v is None:
        return ""
    if not isinstance(v, list):
        return _shellquote(str(v))
    return " ".join(_shellquote(str(item)) for item in v)
54 |
55 |
def to_slurm_command(params, header, python_command="python", remote_dir='~/',
                     script=osp.join(config.PROJECT_PATH, 'scripts/run_experiment.py'),
                     simg_dir=None, use_gpu=False, modules=None, cuda_module=None, use_singularity=True,
                     mount_options=None, compile_script=None, wait_compile=None, set_egl_gpu=False):
    """
    Transfer the commands to the format that can be run by slurm.

    :param params: dict of script arguments. ``pre_commands`` and
        ``post_commands`` are popped from it (the caller's dict is modified).
        Every remaining item becomes ``--key value``; for a dict value each
        entry becomes ``--key_subkey value``, except the sub-key ``_name``,
        which becomes ``--key value``.
    :param header: first line(s) of the generated script (the sbatch header).
    :param python_command: interpreter used to run ``script``.
    :param remote_dir: directory the script changes into before running.
    :param script: path of the python script to run.
    :param simg_dir: path of the singularity image; must not be None.
    :param use_gpu: load ``cuda_module`` and pass ``--nv`` to singularity.
    :param modules: iterable of module names to ``module load``; None means none.
    :param cuda_module: module loaded when ``use_gpu`` is set (then required).
    :param use_singularity: emit the singularity set-up lines.
    :param mount_options: extra bind mount given to singularity via ``-B``.
    :param compile_script: command run inside the container before the script.
        ``. ./prepare_1.0.sh`` is run first unless this contains ``'prepare'``.
    :param wait_compile: if given, seconds to ``sleep`` after the compile step.
    :param set_egl_gpu: export ``EGL_GPU=$SLURM_JOB_GRES`` inside the container.
    :return: list of script lines.
    """
    # TODO Add code for specifying the resource allocation
    # TODO Check if use_gpu can be applied
    # NOTE(review): indentation was flattened in the listing this was taken
    # from. Everything from `set -x` through the `sleep` line is assumed to be
    # inside `if use_singularity:` - confirm. With that nesting
    # `use_singularity=False` is not supported (the singularity names used
    # below are never defined).
    assert simg_dir is not None
    command = python_command + " " + script

    pre_commands = params.pop("pre_commands", None)
    post_commands = params.pop("post_commands", None)

    command_list = list()
    command_list.append(header)

    # Log into singularity shell
    if use_singularity:
        command_list.append('set -x')  # echo commands to stdout
        command_list.append('set -u')  # throw an error if unset variable referenced
        command_list.append('set -e')  # exit on errors
        command_list.append('srun hostname')

        # `modules` defaults to None; treat that as "no modules" instead of
        # failing with a TypeError when iterating.
        for remote_module in (modules or ()):
            command_list.append('module load ' + remote_module)
        if use_gpu:
            assert cuda_module is not None
            command_list.append('module load ' + cuda_module)
        command_list.append('cd {}'.format(remote_dir))
        # First execute a bash program inside the container and then run all the following commands

        if mount_options is not None:
            options = '-B ' + mount_options
        else:
            options = ''
        options += " -B /opt"
        sing_prefix = 'singularity exec {} {} {} /bin/bash -c'.format(options, '--nv' if use_gpu else '', simg_dir)
        sing_commands = list()
        if compile_script is None or 'prepare' not in compile_script :
            sing_commands.append('. ./prepare_1.0.sh')
        if set_egl_gpu:
            sing_commands.append('export EGL_GPU=$SLURM_JOB_GRES')
            sing_commands.append('echo $EGL_GPU')
        if compile_script is not None:
            sing_commands.append(compile_script)
        if wait_compile is not None:
            sing_commands.append('sleep '+str(int(wait_compile)))

    if pre_commands is not None:
        command_list.extend(pre_commands)
    for k, v in params.items():
        if isinstance(v, dict):
            for nk, nv in v.items():
                if str(nk) == "_name":
                    command += " --%s %s" % (k, _to_param_val(nv))
                else:
                    command += " --%s_%s %s" % (k, nk, _to_param_val(nv))
        else:
            command += " --%s %s" % (k, _to_param_val(v))
    sing_commands.append(command)
    # All in-container commands run as one `bash -c '...'` chain joined by &&.
    all_sing_cmds = ' && '.join(sing_commands)
    command_list.append(sing_prefix + ' \'{}\''.format(all_sing_cmds))
    if post_commands is not None:
        command_list.extend(post_commands)
    return command_list
128 |
129 | # if __name__ == '__main__':
130 | # slurm_run_scripts(header)
131 |
--------------------------------------------------------------------------------
/chester/examples/cplot_example.py:
--------------------------------------------------------------------------------
1 | from chester.plotting.cplot import *
2 | import os.path as osp
3 |
4 |
def custom_series_splitter(x):
    """Return the legend label for an experiment record.

    Reads ``x['flat_params']`` and checks, in this order: a truthy
    ``use_ae_reward`` -> 'Auto Encoder'; ``her_replay_strategy`` equal to
    'balance_filter' -> 'Indicator+Balance+Filter'; a truthy
    ``env_kwargs.use_true_reward`` -> 'Oracle'; otherwise 'Indicator'.
    """
    flat = x['flat_params']
    if flat.get('use_ae_reward'):
        label = 'Auto Encoder'
    elif flat['her_replay_strategy'] == 'balance_filter':
        label = 'Indicator+Balance+Filter'
    elif flat['env_kwargs.use_true_reward']:
        label = 'Oracle'
    else:
        label = 'Indicator'
    return label
14 |
15 |
# Legend label -> index into `core.color_defaults`.
dict_leg2col = {"Oracle": 1, "Indicator": 0, 'Indicator+Balance+Filter': 2, "Auto Encoder": 3}
# Directory the rendered figures are saved into.
save_path = './data/plots_chester'
18 |
19 |
def plot_visual_learning():
    """Plot the visual-learning results, one figure per (metric, environment).

    For every metric in ``plot_keys`` and environment in ``plot_envs`` the
    median curve of each experiment group is drawn with its shaded band
    against timesteps (episode count times the environment horizon), and the
    figure is saved under ``save_path``.
    """
    # NOTE(review): indentation was flattened in the listing this was taken
    # from; the nesting below is reconstructed - confirm.
    data_path = './data/nsh/submit_rss/submit_rss/visual_learning'

    plot_keys = ['test/success_state', 'test/goal_dist_final_state']
    plot_ylabels = ['Success', 'Final Distance to Goal']
    plot_envs = ['FetchReach', 'Reacher', 'RopeFloat']

    exps_data, plottable_keys, distinct_params = reload_data(data_path)
    group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter)
    for (plot_key, plot_ylabel) in zip(plot_keys, plot_ylabels):
        for env_name in plot_envs:
            fig, ax = plt.subplots(figsize=(8, 5))
            for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)):
                color = core.color_defaults[dict_leg2col[legend]]
                y, y_lower, y_upper = get_shaded_curve(selector.where('env_name', env_name), plot_key,
                                                       shade_type='median')

                # x axis in timesteps: episodes times the horizon of the first matching experiment.
                env_horizon = selector.where('env_name', env_name).extract()[0].params["env_kwargs"]["horizon"]
                x, _, _ = get_shaded_curve(selector.where('env_name', env_name), 'train/episode')
                x = [ele * env_horizon for ele in x]

                ax.plot(x, y, color=color, label=legend, linewidth=2.0)

                ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0,
                                alpha=0.2)

            def y_fmt(x, y):
                # Tick formatter: tick value in thousands, followed by 'K'.
                return str(int(np.round(x / 1000.0))) + 'K'

            ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt))
            ax.grid(True)
            ax.set_xlabel('Timesteps')
            ax.set_ylabel(plot_ylabel)
            axes = plt.gca()
            if 'Rope' in env_name:
                axes.set_xlim(left=20000)

            plt.title(env_name.replace('Float', 'Push'))
            loc = 'best'
            leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends)
            # Matplotlib 3.7 renamed `Legend.legendHandles` to `legend_handles`
            # and later removed the old name; support both.
            handles = getattr(leg, 'legend_handles', None)
            if handles is None:
                handles = leg.legendHandles
            for legobj in handles:
                legobj.set_linewidth(3.0)

            save_name = filter_save_name('ind_visual_' + plot_key + '_' + env_name)

            plt.savefig(osp.join(save_path, save_name), bbox_inches='tight')
66 |
67 |
def plot_state_learning():
    """Plot the state-learning results, one figure per (metric, environment).

    For every metric in ``plot_keys`` and environment in ``plot_envs`` the
    median curve of each experiment group is drawn with its shaded band
    against timesteps (episode count times the environment horizon), and the
    figure is saved under ``save_path``.
    """
    # NOTE(review): indentation was flattened in the listing this was taken
    # from; the nesting below is reconstructed - confirm.
    data_path = './data/nsh/submit_rss/submit_rss/state_learning'

    plot_keys = ['test/success_state', 'test/goal_dist_final_state']
    # One y-axis label per metric (same pairing as plot_visual_learning);
    # previously every figure, including the distance metric, was labelled
    # 'Success'.
    plot_ylabels = ['Success', 'Final Distance to Goal']
    plot_envs = ['FetchReach', 'FetchPush', 'Reacher', 'RopeFloat']

    exps_data, plottable_keys, distinct_params = reload_data(data_path)
    group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter)
    for (plot_key, plot_ylabel) in zip(plot_keys, plot_ylabels):
        for env_name in plot_envs:
            fig, ax = plt.subplots(figsize=(8, 5))
            for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)):
                color = core.color_defaults[dict_leg2col[legend]]
                y, y_lower, y_upper = get_shaded_curve(selector.where('env_name', env_name), plot_key,
                                                       shade_type='median')
                # x axis in timesteps: episodes times the horizon of the first matching experiment.
                env_horizon = selector.where('env_name', env_name).extract()[0].params["env_kwargs"]["horizon"]
                x, _, _ = get_shaded_curve(selector.where('env_name', env_name), 'train/episode')
                x = [ele * env_horizon for ele in x]
                ax.plot(x, y, color=color, label=legend, linewidth=2.0)

                ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0,
                                alpha=0.2)

            def y_fmt(x, y):
                # Tick formatter: tick value in thousands, followed by 'K'.
                return str(int(np.round(x / 1000.0))) + 'K'

            ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt))
            ax.grid(True)
            ax.set_xlabel('Timesteps')
            ax.set_ylabel(plot_ylabel)

            plt.title(env_name.replace('Float', 'Push'))
            loc = 'best'
            leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends)
            # Matplotlib 3.7 renamed `Legend.legendHandles` to `legend_handles`
            # and later removed the old name; support both.
            handles = getattr(leg, 'legend_handles', None)
            if handles is None:
                handles = leg.legendHandles
            for legobj in handles:
                legobj.set_linewidth(3.0)

            save_name = filter_save_name('ind_state_' + plot_key + '_' + env_name)

            plt.savefig(osp.join(save_path, save_name), bbox_inches='tight')
108 |
109 |
# Script entry point: render both figure sets when executed directly.
if __name__ == '__main__':
    plot_visual_learning()
    plot_state_learning()
113 |
--------------------------------------------------------------------------------
/rpmg/rpmg_losses.py:
--------------------------------------------------------------------------------
1 |
2 | import tools
3 | import torch
4 | from pytorch3d.transforms import so3_log_map, so3_exponential_map
5 |
# Shared mean-squared-error criterion used by the loss functions in this module.
mse_loss = torch.nn.MSELoss()
7 |
8 |
def rpmg_forward(in_nd):
    """Convert a batch of rotation representations to matrices, then transpose.

    The representation is selected by ``in_nd.shape[1]`` (6, 9, 4 or 10) and
    converted with the matching ``tools`` function; the result is returned
    with its last two axes swapped.

    :raises NotImplementedError: for any other width.
    """
    converter_names = {
        6: 'compute_rotation_matrix_from_ortho6d',
        9: 'symmetric_orthogonalization',
        4: 'compute_rotation_matrix_from_quaternion',
        10: 'compute_rotation_matrix_from_10d',
    }
    proj_kind = in_nd.shape[1]
    if proj_kind not in converter_names:
        raise NotImplementedError
    # Look the converter up lazily so only the one that is needed is resolved.
    r0 = getattr(tools, converter_names[proj_kind])(in_nd)
    return r0.transpose(-1, -2)
23 |
24 |
def rpmg_inverse(R, proj_kind):
    """Map a batch of matrices ``R`` back to the representation ``proj_kind``.

    ``R`` is first transposed over its last two axes. Then:

    * 6  - the first two columns, concatenated along dim 1;
    * 9  - the matrix flattened to 9 values;
    * 4  - ``tools.compute_quaternions_from_rotation_matrices``;
    * 10 - ``tools.convert_A_to_Avec`` applied to ``I - q q^T`` for that
      quaternion ``q``.

    :raises NotImplementedError: for any other ``proj_kind``.
    """
    Rt = R.transpose(-1, -2)

    if proj_kind == 6:
        first_col = Rt[:, :, 0]
        second_col = Rt[:, :, 1]
        return torch.cat([first_col, second_col], dim=1)

    if proj_kind == 9:
        return Rt.reshape(-1, 9)

    if proj_kind == 4:
        return tools.compute_quaternions_from_rotation_matrices(Rt)

    if proj_kind == 10:
        q = tools.compute_quaternions_from_rotation_matrices(Rt)
        identity = torch.eye(4, device=q.device)[None].repeat(q.shape[0], 1, 1)
        outer = torch.bmm(q.unsqueeze(-1), q.unsqueeze(-2))
        return tools.convert_A_to_Avec(identity - outer)

    raise NotImplementedError
47 |
48 |
def rpmg_goal_and_nearest(x, R_goal):
    """Return ``(x_nearest, x_goal)`` for a raw prediction and a goal matrix.

    ``R_goal`` is transposed over its last two axes and the representation is
    selected by ``x.shape[1]`` (6, 9, 4 or 10). ``x_goal`` is the goal
    expressed in that representation; ``x_nearest`` is computed from ``x``
    and the goal per representation (see the branches below).

    :raises NotImplementedError: for any other width, matching
        ``rpmg_forward`` and ``rpmg_inverse`` (previously this surfaced as an
        ``UnboundLocalError`` at the return statement).
    """
    R_goal = R_goal.transpose(-1, -2)
    proj_kind = x.shape[1]

    if proj_kind == 6:
        # First half of x: keep only its component along the first goal column.
        x_proj_1 = (R_goal[:, :, 0] * x[:, :3]).sum(dim=1,
                                                    keepdim=True) * R_goal[:, :, 0]
        # Second half of x: keep its components along the first two goal columns.
        x_proj_2 = (R_goal[:, :, 0] * x[:, 3:]).sum(dim=1, keepdim=True) * R_goal[:, :, 0] \
            + (R_goal[:, :, 1] * x[:, 3:]).sum(dim=1,
                                               keepdim=True) * R_goal[:, :, 1]
        x_goal = torch.cat([R_goal[:, :, 0], R_goal[:, :, 1]], dim=1)
        x_nearest = torch.cat([x_proj_1, x_proj_2], dim=1)

    elif proj_kind == 9:
        x_goal = R_goal.reshape(-1, 9)
        x_nearest = tools.compute_SVD_nearest_Mnlsew(
            x.reshape(-1, 3, 3), R_goal)

    elif proj_kind == 4:
        # q and -q are both candidates; pick the one closer to the normalized prediction.
        q_1 = tools.compute_quaternions_from_rotation_matrices(R_goal)
        q_2 = -q_1
        x_proj = tools.normalize_vector(x)
        x_goal = torch.where(
            (q_1 - x_proj).norm(dim=1, keepdim=True) < (q_2 -
                                                        x_proj).norm(dim=1, keepdim=True),
            q_1, q_2)
        x_nearest = (x * x_goal).sum(dim=1, keepdim=True) * x_goal

    elif proj_kind == 10:
        q_goal = tools.compute_quaternions_from_rotation_matrices(R_goal)
        x_nearest = tools.compute_nearest_10d(x, q_goal)
        reg_A = torch.eye(4, device=q_goal.device)[None].repeat(q_goal.shape[0], 1, 1) \
            - torch.bmm(q_goal.unsqueeze(-1), q_goal.unsqueeze(-2))
        x_goal = tools.convert_A_to_Avec(reg_A)

    else:
        raise NotImplementedError

    return x_nearest, x_goal
85 |
86 |
def projective_manifold_gradient_loss(
        x0_pred, x1_pred, delta_R10_gt,
        step_size=0.05, lambda_reg=0.01,
        transpose=False):
    """Loss for two raw rotation predictions given their relative rotation.

    :param x0_pred: raw prediction that receives the gradient.
    :param x1_pred: second raw prediction; only its detached value is used.
    :param delta_R10_gt: batch of target relative matrices, used as
        ``R0_pred^T @ delta_R10_gt @ R1_pred``.
    :param step_size: factor applied to the log-map residual before taking
        the step from ``R0_pred``.
    :param lambda_reg: weight of the term pulling ``x0_pred`` towards the goal.
    :param transpose: transpose the matrices returned by ``rpmg_forward``.
    :return: ``(loss, residual_norm)`` where ``residual_norm`` is the norm of
        the log-map residual over its last axis.
    """
    # Rotations are computed from detached inputs; gradients reach x0_pred
    # only through the MSE terms below.
    R0_pred = rpmg_forward(x0_pred.detach())
    R1_pred = rpmg_forward(x1_pred.detach())
    if transpose:
        R0_pred = R0_pred.transpose(-1, -2)
        R1_pred = R1_pred.transpose(-1, -2)

    delta_r = so3_log_map(torch.bmm(R0_pred.transpose(-1, -2),
                                    torch.bmm(delta_R10_gt, R1_pred)), eps=0.001)
    R0_goal = torch.bmm(R0_pred, so3_exponential_map(step_size*delta_r))

    x0_nearest, x0_goal = rpmg_goal_and_nearest(x0_pred, R0_goal)
    # The unreachable `if(False)` alternative (a single MSE against
    # x0_nearest - lambda_reg*(x0_nearest - x0_goal)) has been removed.
    loss = mse_loss(x0_pred, x0_nearest) \
        + lambda_reg*mse_loss(x0_pred, x0_goal)
    return loss, delta_r.norm(dim=-1)
110 |
111 |
def projective_manifold_gradient_loss_absolute(
        x0_pred, x1_pred, R0_gt, R1_gt,
        step_size=0.05, lambda_reg=0.01,
        transpose=False):
    """Loss for two raw rotation predictions given both target rotations.

    The relative target is built here as ``R0_gt @ R1_gt^T``; the rest
    mirrors ``projective_manifold_gradient_loss``.

    :param x0_pred: raw prediction that receives the gradient.
    :param x1_pred: second raw prediction; only its detached value is used.
    :param R0_gt: batch of target matrices for the first prediction.
    :param R1_gt: batch of target matrices for the second prediction.
    :param step_size: factor applied to the log-map residual before taking
        the step from ``R0_pred``.
    :param lambda_reg: weight of the term pulling ``x0_pred`` towards the goal.
    :param transpose: transpose the matrices returned by ``rpmg_forward``.
    :return: ``(loss, residual_norm)`` where ``residual_norm`` is the norm of
        the log-map residual over its last axis.
    """
    # Rotations are computed from detached inputs; gradients reach x0_pred
    # only through the MSE terms below.
    R0_pred = rpmg_forward(x0_pred.detach())
    R1_pred = rpmg_forward(x1_pred.detach())
    if transpose:
        R0_pred = R0_pred.transpose(-1, -2)
        R1_pred = R1_pred.transpose(-1, -2)

    delta_R10_gt = torch.bmm(R0_gt, R1_gt.transpose(-1, -2))
    # NOTE(review): unlike projective_manifold_gradient_loss, no `eps` is
    # passed to so3_log_map here - confirm this is intended.
    delta_r = so3_log_map(torch.bmm(R0_pred.transpose(-1, -2),
                                    torch.bmm(delta_R10_gt, R1_pred)))
    R0_goal = torch.bmm(R0_pred, so3_exponential_map(step_size*delta_r))

    x0_nearest, x0_goal = rpmg_goal_and_nearest(x0_pred, R0_goal)
    # The unreachable `if(False)` alternative (a single MSE against
    # x0_nearest - lambda_reg*(x0_nearest - x0_goal)) has been removed.
    loss = mse_loss(x0_pred, x0_nearest) \
        + lambda_reg*mse_loss(x0_pred, x0_goal)
    return loss, delta_r.norm(dim=-1)
136 |
--------------------------------------------------------------------------------
/chester/run_exp_worker.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import os.path as osp
4 | import datetime
5 | import dateutil.tz
6 | import ast
7 | import uuid
8 | import pickle as pickle
9 | import base64
10 | import joblib
11 |
12 | from chester import config
13 |
14 |
def run_experiment(argv):
    """Parse worker command-line arguments and launch (or resume) an experiment.

    With ``--resume_from`` the file is loaded with ``joblib`` and
    ``data['algo'].train()`` is called. Otherwise, with
    ``--use_cloudpickle True``, the callable encoded in ``--args_data`` is
    invoked as ``method_call(variant_data, log_dir, exp_name)``.

    The snapshot, tabular/text/params/variant log-file, plot and seed options
    are parsed but not used by this function.

    :param argv: full argument vector with the program name first
        (``argv[1:]`` is parsed).
    :raises NotImplementedError: when neither ``--resume_from`` nor
        ``--use_cloudpickle True`` is given.
    """
    # SECURITY: --variant_data, --args_data and --resume_from are unpickled
    # (pickle / cloudpickle / joblib), which can execute arbitrary code. Only
    # pass data produced by a trusted launcher.
    default_log_dir = config.LOG_DIR
    now = datetime.datetime.now(dateutil.tz.tzlocal())

    # avoid name clashes when running distributed jobs
    rand_id = str(uuid.uuid4())[:5]
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z')

    default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id)
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default=default_exp_name,
                        help='Name of the experiment.')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='Path to save the log and iteration snapshot.')
    parser.add_argument('--snapshot_mode', type=str, default='all',
                        help='Mode to save the snapshot. Can be either "all" '
                             '(all iterations will be saved), "last" (only '
                             'the last iteration will be saved), "gap" (every '
                             '`snapshot_gap` iterations are saved), or "none" '
                             '(do not save snapshots)')
    parser.add_argument('--snapshot_gap', type=int, default=1,
                        help='Gap between snapshot iterations.')
    parser.add_argument('--tabular_log_file', type=str, default='progress.csv',
                        help='Name of the tabular log file (in csv).')
    parser.add_argument('--text_log_file', type=str, default='debug.log',
                        help='Name of the text log file (in pure text).')
    parser.add_argument('--params_log_file', type=str, default='params.json',
                        help='Name of the parameter log file (in json).')
    parser.add_argument('--variant_log_file', type=str, default='variant.json',
                        help='Name of the variant log file (in json).')
    parser.add_argument('--resume_from', type=str, default=None,
                        help='Name of the pickle file to resume experiment from.')
    parser.add_argument('--plot', type=ast.literal_eval, default=False,
                        help='Whether to plot the iteration results')
    parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False,
                        help='Whether to only print the tabular log information (in a horizontal format)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for numpy')
    parser.add_argument('--args_data', type=str,
                        help='Pickled data for stub objects')
    parser.add_argument('--variant_data', type=str,
                        help='Pickled data for variant configuration')
    parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False)

    args = parser.parse_args(argv[1:])

    # Default log directory is <LOG_DIR>/<exp_name>.
    if args.log_dir is None:
        log_dir = osp.join(default_log_dir, args.exp_name)
    else:
        log_dir = args.log_dir

    if args.variant_data is not None:
        variant_data = pickle.loads(base64.b64decode(args.variant_data))
    else:
        variant_data = None

    if args.resume_from is not None:
        data = joblib.load(args.resume_from)
        assert 'algo' in data
        algo = data['algo']
        algo.train()
    elif args.use_cloudpickle:
        import cloudpickle
        method_call = cloudpickle.loads(base64.b64decode(args.args_data))
        method_call(variant_data, log_dir, args.exp_name)
    else:
        # Was `assert False`, which is stripped under `python -O` and would
        # make the worker exit silently without running anything.
        raise NotImplementedError(
            'Only --use_cloudpickle True or --resume_from is supported.')
121 |
# Script entry point: forward the full argument vector (program name included).
if __name__ == "__main__":
    run_experiment(sys.argv)
124 |
--------------------------------------------------------------------------------
/initializers/mm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from initializers import InitializerBase
4 |
class MoveCloserInitializer(InitializerBase):
    """Scripted initializer that emits one tool action per step.

    The action depends on ``env.inner_step``:

    * steps 0-14:    move in xz away from the item while it is within 0.050;
    * steps 15-65:   move with ``dy = -0.0040``;
    * steps 75-199:  move in xz towards the item while further than 0.004;
    * steps 200-599: move with ``dy = 0.0040``;
    * any other step: zero motion.
    """

    # `done` is reported once the step reaches THRESH2 - 1.
    THRESH2 = 75 + 125

    def __init__(self, env, args):
        # `args` is accepted but not used.
        self.env = env

    def get_action(self, obs, info=None):
        """Return ``(action, done)`` for the environment's current inner step.

        ``obs`` and ``info`` are not used. The action is expressed in the
        normalized action space (see the un-scaling step below).

        :raises NotImplementedError: if ``env.action_mode`` is not
            ``'translation'``.
        """
        env = self.env
        step = env.inner_step
        tool = env.tool_idx

        # Position of the tool. Unfortunately adding to x and z is a real hack.
        tool_x, _tool_y, tool_z = (env.tool_state[tool, axis] for axis in range(3))
        tool_x = tool_x + 0.090
        tool_z = tool_z + 0.090

        # Position of the item.
        item_pos = env._get_rigid_pos()
        item_x, _item_y, item_z = item_pos[0], item_pos[1], item_pos[2]

        # Unfortunately distance thresholds have to be tuned carefully.
        dist_xz = np.sqrt((tool_x - item_x) ** 2 + (tool_z - item_z) ** 2)

        correct_start = 75
        correct_end = correct_start + 125
        dx = dy = dz = 0.
        if 0 <= step < 15:
            # Only move if the item is below the sphere.
            if dist_xz < 0.050:
                # Go the _negative_ direction
                dx = -(item_x - tool_x)
                dz = -(item_z - tool_z)
        elif 15 <= step < 66:
            # Actually it seems like faster strangely means less water movement.
            dy = -0.0040  # stronger -dy won't have effect due to action bounds
        elif correct_start <= step < correct_end:
            # Try to correct for the discrepancy
            if dist_xz > 0.004:
                dx = item_x - tool_x
                dz = item_z - tool_z
        elif correct_end <= step < 600:
            # Ah, it would actually be hard for us to do another xz correction here, since
            # if it causes collision, we'd just stop the motion. :(
            dy = 0.0040

        # Try to normalize (to magnitude 1) then downscale by a tuned amount.
        # Unfortunately these numbers just come from tuning / visualizing.
        action = np.array([dx, dy, dz, 0.])
        rescale = (5 <= step < 15) or (correct_start <= step < correct_end)
        if rescale:
            magnitude = np.linalg.norm(action)
            if magnitude > 0:
                action = action / magnitude * 0.0020

        if env.action_mode != 'translation':
            raise NotImplementedError()
        action = action[:3]

        # 'Un-scale' to anticipate effect of future scaling in `NormalizedEnv`
        low = env._wrapped_env.action_space.low
        high = env._wrapped_env.action_space.high
        action = (action - low) / ((high - low) * 0.5) - 1.0

        # Check if we're done
        done = step >= (self.THRESH2 - 1)

        return action, done
76 |
class SmartMoveCloserInitializer(InitializerBase):
    """Scripted initializer driven by a small state machine (`self.state`).

    States:
      0: xz positioning. If the offset tool position is outside (-0.2, 0.2) in
         x or z, move by (-ix, -iz); else if the item is within 0.050 (xz),
         move away from it; otherwise advance to state 1.
      1: move down while the tool's y is above 0.08, then advance to state 2
         and remember the step at which lowering finished.
      2: move in xz towards the item while the xz distance exceeds 0.004 and
         fewer than 125 steps have passed since lowering finished; then
         advance to state 3.
      3: move up; if the item's y is below the tool's y, go back to state 0.
    """
    # Step index that drives the `done` flag.
    THRESH2 = 75 + 125

    def __init__(self, env, args):
        # `args` is accepted for interface compatibility; it is not used here.
        self.env = env
        self.state = 0
        # Step at which state 1 finished; set on the 1 -> 2 transition.
        self.finish_step = None

    def reset(self):
        """Restart the state machine from state 0."""
        self.state = 0
        self.finish_step = None

    def get_action(self, obs, info=None):
        """Return `(action, done)` and advance the state machine if needed.

        `obs` and `info` are not used; all state is read from `self.env`.
        Raises NotImplementedError unless `env.action_mode == 'translation'`.
        """
        step = self.env.inner_step
        T = self.env.tool_idx
        dx, dy, dz = 0., 0., 0.

        # The state that produces this step's action. Normalization below is
        # keyed on this value rather than on `self.state`, which may already
        # have been advanced (or reset to 0) by the time we get there.
        phase = self.state

        # Position of the tool. Unfortunately adding to tx and tz is a real hack.
        tx = self.env.tool_state[T, 0] + 0.090
        ty = self.env.tool_state[T, 1]
        tz = self.env.tool_state[T, 2] + 0.090

        # Position of the item.
        rigid_avg_pos = self.env._get_rigid_pos()
        ix = rigid_avg_pos[0]
        iy = rigid_avg_pos[1]
        iz = rigid_avg_pos[2]

        # Unfortunately distance thresholds have to be tuned carefully.
        dist_xz = np.sqrt((tx - ix) ** 2 + (tz - iz) ** 2)

        if phase == 0:
            # Don't collide!
            if -0.2 >= tx or -0.2 >= tz or 0.2 <= tx or 0.2 <= tz:
                dx = -ix
                dz = -iz
            # Only move if the item is below the sphere.
            elif dist_xz < 0.050:
                # Go the _negative_ direction
                dx = -(ix - tx)
                dz = -(iz - tz)
            else:
                self.state += 1
        elif phase == 1:
            # Actually it seems like faster strangely means less water movement.
            if ty > 0.08:
                dy = -0.0040  # stronger -dy won't have effect due to action bounds
            else:
                self.state += 1
                self.finish_step = step
        elif phase == 2:
            # Try to correct for the discrepancy
            if dist_xz > 0.004 and step < self.finish_step + 125:
                dx = ix - tx
                dz = iz - tz
            else:
                self.state += 1
        elif phase == 3:
            # Ah, it would actually be hard for us to do another xz correction here, since
            # if it causes collision, we'd just stop the motion. :(
            dy = 0.0040
            if iy < ty:
                self.state = 0

        # Try to normalize (to magnitude 1) then downscale by a tuned amount.
        # Unfortunately these numbers just come from tuning / visualizing.
        # Only the xz-motion phases (0 and 2) are rescaled; the vertical moves
        # of phases 1 and 3 keep their magnitude, including on the step where
        # phase 3 hands control back to state 0.
        action = np.array([dx, dy, dz, 0.])
        if phase == 0 or phase == 2:
            norm = np.linalg.norm(action)
            if norm > 0:
                action = action / norm * 0.0020

        if self.env.action_mode == 'translation':
            action = action[:3]
        else:
            raise NotImplementedError()

        # 'Un-scale' to anticipate effect of future scaling in `NormalizedEnv`
        lb = self.env._wrapped_env.action_space.low
        ub = self.env._wrapped_env.action_space.high
        action = (action - lb) / ((ub - lb) * 0.5) - 1.0

        # Check if we're done
        done = step >= (self.THRESH2 - 1)

        return action, done
171 |
--------------------------------------------------------------------------------
/bc/logger.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 | from collections import defaultdict
3 | import json
4 | import os
5 | import shutil
6 | import torch
7 | import torchvision
8 | import numpy as np
9 | from termcolor import colored
10 |
# Console layout per logging config. `Logger` looks up
# FORMAT_CONFIG[config]['train'] and FORMAT_CONFIG[config]['eval'] and passes
# each list to a `MetersGroup`. Every entry is a triple
# (meter key, label printed on the console, format type), where the format
# type is one of 'int', 'float' or 'time' (see `MetersGroup._format`).
FORMAT_CONFIG = {
    'rl': {
        'train': [
            ('episode', 'E', 'int'), ('step', 'S', 'int'),
            ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'),
            ('batch_reward', 'BR', 'float'), ('actor_loss', 'A_LOSS', 'float'),
            ('critic_loss', 'CR_LOSS', 'float'), ('curl_loss', 'CU_LOSS', 'float'),
            # ('q1', 'Q1', 'float'), ('comp_Q', 'comp_Q', 'float'), ('comp_E', 'comp_E', 'float')
        ],
        'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float')]
    }
}
23 |
24 |
class AverageMeter(object):
    """Keep a running total and a running count, and report their ratio."""

    def __init__(self):
        self._sum, self._count = 0, 0

    def update(self, value, n=1):
        """Add `value` to the running total and `n` to the running count."""
        self._sum += value
        self._count += n

    def value(self):
        """Return total / count, with the divisor clamped to at least 1."""
        divisor = max(1, self._count)
        return self._sum / divisor
36 |
37 |
class MetersGroup(object):
    """A named collection of `AverageMeter`s dumped to a JSON-lines file and to stdout.

    :param file_name: path of the log file; an existing file at that path is
        deleted on construction, and each `dump` appends one JSON line.
    :param formating: list of (meter key, display label, format type) triples
        that controls which values are printed to the console and how.
    """

    def __init__(self, file_name, formating):
        self._file_name = file_name
        # Start from a clean file: every run rewrites its own log.
        if os.path.exists(file_name):
            os.remove(file_name)
        self._formating = formating
        self._meters = defaultdict(AverageMeter)

    def log(self, key, value, n=1):
        """Accumulate `value` (with count `n`) into the meter named `key`."""
        self._meters[key].update(value, n)

    def _prime_meters(self):
        """Return {short key: meter value} for all meters logged so far.

        The leading 'train' or 'eval' prefix plus the following separator
        character is stripped, and any remaining '/' becomes '_'. Keys that do
        not start with 'train' are treated as if they started with 'eval'.
        """
        data = dict()
        for key, meter in self._meters.items():
            if key.startswith('train'):
                key = key[len('train') + 1:]
            else:
                key = key[len('eval') + 1:]
            key = key.replace('/', '_')
            data[key] = meter.value()
        return data

    def _dump_to_file(self, data):
        """Append `data` to the log file as a single JSON line."""
        with open(self._file_name, 'a') as f:
            f.write(json.dumps(data) + '\n')

    def _format(self, key, value, ty):
        """Render 'key: value' according to the format type `ty`.

        :raises ValueError: if `ty` is not 'int', 'float' or 'time'.
        """
        template = '%s: '
        if ty == 'int':
            template += '%d'
        elif ty == 'float':
            template += '%.04f'
        elif ty == 'time':
            template += '%.01f s'
        else:
            # Raising a plain string is itself a TypeError in Python 3, which
            # hid the intended message; raise a real exception instead.
            raise ValueError('invalid format type: %s' % ty)
        return template % (key, value)

    def _dump_to_console(self, data, prefix):
        """Print one '|'-separated line; keys missing from `data` show as 0."""
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = ['{:5}'.format(prefix)]
        for key, disp_key, ty in self._formating:
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print('| %s' % (' | '.join(pieces)))

    def dump(self, step, prefix):
        """Write the current averages to file and console, then clear all meters."""
        data = self._prime_meters()
        data['step'] = step
        self._dump_to_file(data)
        self._dump_to_console(data, prefix)
        self._meters.clear()
90 |
91 |
class Logger(object):
    """Fan scalar/image/video/histogram logs out to TensorBoard, log files and chester.

    :param log_dir: directory that receives 'train.log', 'eval.log' and, when
        `use_tb` is true, a 'tb' sub-directory (wiped if it already exists).
    :param use_tb: whether to create a TensorBoard `SummaryWriter`.
    :param config: key into `FORMAT_CONFIG` selecting the console layout.
    :param chester_logger: optional object with `record_tabular(key, value)`
        and `dump_tabular()`; used in addition to the other sinks when given.
    """

    def __init__(self, log_dir, use_tb=True, config='rl', chester_logger=None):
        self._log_dir = log_dir
        if use_tb:
            tb_dir = os.path.join(log_dir, 'tb')
            if os.path.exists(tb_dir):
                shutil.rmtree(tb_dir)
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        self._train_mg = MetersGroup(
            os.path.join(log_dir, 'train.log'),
            formating=FORMAT_CONFIG[config]['train']
        )
        self._eval_mg = MetersGroup(
            os.path.join(log_dir, 'eval.log'),
            formating=FORMAT_CONFIG[config]['eval']
        )
        self.chester_logger = chester_logger

    # The `_try_sw_*` helpers are no-ops when TensorBoard is disabled.

    def _try_sw_log(self, key, value, step):
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_image(self, key, image, step):
        if self._sw is not None:
            assert image.dim() == 3
            # Insert a dim at position 1 so each slice along dim 0 becomes one grid tile.
            grid = torchvision.utils.make_grid(image.unsqueeze(1))
            self._sw.add_image(key, grid, step)

    def _try_sw_log_video(self, key, frames, step):
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1):
        """Record a scalar under `key` (which must start with 'train' or 'eval').

        TensorBoard receives `value / n`; the meters group receives `value`
        together with the count `n`.
        """
        assert key.startswith('train') or key.startswith('eval')
        # isinstance also covers Tensor subclasses such as nn.Parameter.
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value, n)
        if self.chester_logger is not None:
            self.chester_logger.record_tabular(key, value)

    def log_param(self, key, param, step):
        """Log histograms of a module's weight and bias, and of their gradients if set.

        `param` is an object with a `weight` attribute and, optionally, a
        `bias` attribute; a missing or `None` bias is skipped.
        """
        self.log_histogram(key + '_w', param.weight.data, step)
        if getattr(param.weight, 'grad', None) is not None:
            self.log_histogram(key + '_w_g', param.weight.grad.data, step)
        # Modules built with bias=False expose `bias` as None, so checking
        # `hasattr` alone is not enough.
        bias = getattr(param, 'bias', None)
        if bias is not None:
            self.log_histogram(key + '_b', bias.data, step)
            if getattr(bias, 'grad', None) is not None:
                self.log_histogram(key + '_b_g', bias.grad.data, step)

    def log_image(self, key, image, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_image(key, image, step)

    def log_video(self, key, frames, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step):
        assert key.startswith('train') or key.startswith('eval')
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step):
        """Flush the chester logger (if any), then the train and eval meter groups."""
        # NOTE(daniel): chester logger only dumps if we've done eval episodes
        #if len(self._eval_mg._prime_meters()) > 0 and self.chester_logger is not None:
        if self.chester_logger is not None:
            self.chester_logger.dump_tabular()
        self._train_mg.dump(step, 'train')
        self._eval_mg.dump(step, 'eval')
170 |
--------------------------------------------------------------------------------
/bc/encoder.py:
--------------------------------------------------------------------------------
1 | from xml.dom.minidom import Identified
2 | import torch
3 | import torch.nn as nn
4 |
5 |
def tie_weights(src, trg):
    """Make `trg` reference the very same `weight` and `bias` objects as `src`.

    Both arguments must be of exactly the same type.
    """
    assert type(src) == type(trg)
    for attr_name in ('weight', 'bias'):
        setattr(trg, attr_name, getattr(src, attr_name))
10 |
11 |
# NOTE(daniel): get these by checking shape after going through conv layers
# Each table maps `num_layers` -> side length of the final conv feature map for
# one input resolution; `PixelEncoder` uses it to size its FC layer. The values
# are consistent with one 3x3 stride-2 conv followed by (num_layers - 1)
# unpadded 3x3 stride-1 convs: side = (W - 3) // 2 + 1 - 2 * (num_layers - 1).
# for 128 x 128 inputs
OUT_DIM_128 = {4: 57}
# for 100 x 100 inputs
OUT_DIM_100 = {4: 43}
# for 84 x 84 inputs
OUT_DIM = {2: 39, 4: 35, 6: 31}
# for 64 x 64 inputs
OUT_DIM_64 = {2: 29, 4: 25, 6: 21}
21 |
22 |
class PixelEncoder(nn.Module):
    """Convolutional encoder of pixels observations.

    Update 08/24/2022: support depth_segm as observation type.

    Architecture: one 3x3 stride-2 conv, then `num_layers - 1` 3x3 stride-1
    convs (ReLU after each), flatten, a linear layer to `feature_dim`, a
    LayerNorm, and finally tanh unless `output_logits` is true.

    :param obs_shape: 3-element shape; element 0 is the number of input
        channels and the last element selects the `OUT_DIM*` lookup table
        (64, 84, 100 or 128).
    :param feature_dim: size of the output feature vector.
    :param num_layers: total number of conv layers; must be a key of the
        selected table.
    :param num_filters: number of channels of every conv layer.
    :param output_logits: if true, return the LayerNorm output without tanh.
    :param depth_segm: if true, the input is used as-is (no division by 255).
    """

    def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32,
                 output_logits=False, depth_segm=False):
        super().__init__()

        assert len(obs_shape) == 3
        self.obs_shape = obs_shape
        self.feature_dim = feature_dim
        self.num_layers = num_layers
        self.depth_segm = depth_segm
        print(f'Making Image CNN. Using a segm? {self.depth_segm}, obs: {self.obs_shape}')

        self.convs = nn.ModuleList(
            [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)]
        )
        for _ in range(num_layers - 1):
            self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))

        # Handle compatibility with output of convs and the FC layers portion.
        if obs_shape[-1] == 64:
            out_dim = OUT_DIM_64[num_layers]
        elif obs_shape[-1] == 84:
            out_dim = OUT_DIM[num_layers]
        elif obs_shape[-1] == 100:
            out_dim = OUT_DIM_100[num_layers]
        elif obs_shape[-1] == 128:
            out_dim = OUT_DIM_128[num_layers]
        else:
            raise NotImplementedError(
                'Unsupported observation size: %s' % (obs_shape[-1],))
        self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim)
        self.ln = nn.LayerNorm(self.feature_dim)

        # Intermediate activations of the latest forward pass, kept for `log`.
        self.outputs = dict()
        self.output_logits = output_logits

    def reparameterize(self, mu, logstd):
        """Return a sample `mu + eps * exp(logstd)` with `eps` drawn from N(0, I)."""
        std = torch.exp(logstd)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward_conv(self, obs):
        """Forward pass through CNN; returns the flattened conv features.

        NOTE(daniel): since we support RGBD input now, only divide by 255 for the
        image based components. With just RGBD we can do: obs = obs / 255.

        The caller's tensor is never modified: scaling happens on a copy.
        """
        if self.depth_segm:
            # Special case, we actually don't do anything as the image should already
            # be binary 0-1 for the masks, and between [0,x] for depth, where x ~ 1.
            # E.g., for MMOneSphere, image channels are: depth, tool mask, item mask.
            # Update: also doing this for rgb_segm_masks, rgbd_segm_masks, as we assume
            # the input will be in the correct range.
            pass
        else:
            # Scale a copy. Dividing the argument in place would rescale the
            # caller's batch again every time the same tensor is encoded.
            obs = obs.clone()
            obs[:, :3, :, :] /= 255.  # (batch, channels, H, W)
        self.outputs['obs'] = obs

        conv = torch.relu(self.convs[0](obs))
        self.outputs['conv1'] = conv

        for i in range(1, self.num_layers):
            conv = torch.relu(self.convs[i](conv))
            self.outputs['conv%s' % (i + 1)] = conv

        h = conv.view(conv.size(0), -1)
        return h

    def forward(self, obs, detach=False):
        """Encode `obs`; with `detach=True` no gradient flows into the conv stack."""
        h = self.forward_conv(obs)

        if detach:
            h = h.detach()
        h_fc = self.fc(h)
        self.outputs['fc'] = h_fc

        h_norm = self.ln(h_fc)
        self.outputs['ln'] = h_norm

        if self.output_logits:
            out = h_norm
        else:
            out = torch.tanh(h_norm)
            self.outputs['tanh'] = out

        return out

    def copy_conv_weights_from(self, source):
        """Tie convolutional layers"""
        # only tie conv layers
        for i in range(self.num_layers):
            tie_weights(src=source.convs[i], trg=self.convs[i])

    def log(self, L, step, log_freq):
        """Every `log_freq` steps, log stored activations and layer parameters to `L`."""
        if step % log_freq != 0:
            return

        for k, v in self.outputs.items():
            L.log_histogram('train_encoder/%s_hist' % k, v, step)
            if len(v.shape) > 2:
                L.log_image('train_encoder/%s_img' % k, v[0], step)

        for i in range(self.num_layers):
            L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step)
        L.log_param('train_encoder/fc', self.fc, step)
        L.log_param('train_encoder/ln', self.ln, step)
133 |
134 |
class IdentityEncoder(nn.Module):
    """Pass-through encoder for 1-D observation shapes.

    Shares the constructor signature of the other encoders; only `obs_shape`
    is used, and `feature_dim` is set to its single entry.
    """

    def __init__(self, obs_shape, feature_dim, num_layers, num_filters, output_logits,
                 depth_segm=False):
        super(IdentityEncoder, self).__init__()
        num_obs_dims = len(obs_shape)
        assert num_obs_dims == 1
        self.feature_dim = obs_shape[0]

    def forward(self, obs, detach=False):
        """Return `obs` unchanged; `detach` is ignored."""
        return obs

    def copy_conv_weights_from(self, source):
        """No-op: there are no conv layers to tie."""
        return None

    def log(self, L, step, log_freq):
        """No-op: nothing to log."""
        return None
150 |
151 |
class IdentityPNEncoder(nn.Module):
    """NOTE(daniel) for PointNet++, to allow for len(obs_shape) = 2.
    But this will still be an identity function. Just to make code consistent.

    Unlike `IdentityEncoder`, `feature_dim` is set to the whole `obs_shape`
    rather than to a single integer.
    """

    def __init__(self, obs_shape, feature_dim, num_layers, num_filters, output_logits,
                 depth_segm=False):
        super(IdentityPNEncoder, self).__init__()
        num_obs_dims = len(obs_shape)
        assert num_obs_dims == 2
        self.feature_dim = obs_shape

    def forward(self, obs, detach=False):
        """Return `obs` unchanged; `detach` is ignored."""
        return obs

    def copy_conv_weights_from(self, source):
        """No-op: there are no conv layers to tie."""
        return None

    def log(self, L, step, log_freq):
        """No-op: nothing to log."""
        return None
170 |
171 |
# NOTE(daniel): this should be fixed at some point, after deadline. :)
# Registry used by `make_encoder`: encoder type name -> encoder class.
# 'pixel' and 'segm' share `PixelEncoder`; the MLP-style names map to
# `IdentityEncoder`; every 'pointnet*' name maps to `IdentityPNEncoder`.
_AVAILABLE_ENCODERS = {'pixel': PixelEncoder,
                       'segm': PixelEncoder,
                       'mlp': IdentityEncoder,
                       'state_predictor_then_mlp': IdentityEncoder,
                       'identity': IdentityEncoder,
                       'pointnet': IdentityPNEncoder,
                       'pointnet_rpmg': IdentityPNEncoder,
                       'pointnet_avg': IdentityPNEncoder,
                       'pointnet_svd': IdentityPNEncoder,
                       'pointnet_svd_centered': IdentityPNEncoder,
                       'pointnet_svd_pointwise': IdentityPNEncoder,
                       'pointnet_svd_pointwise_6d_flow': IdentityPNEncoder,
                       'pointnet_dense_tf_3D_MSE': IdentityPNEncoder,
                       'pointnet_dense_tf_6D_MSE': IdentityPNEncoder,
                       'pointnet_dense_tf_6D_pointwise': IdentityPNEncoder,
                       'pointnet_classif_6D_pointwise': IdentityPNEncoder,
                       'pointnet_rpmg_pointwise': IdentityPNEncoder,
                       'pointnet_rpmg_taugt': IdentityPNEncoder,
                       'pointnet_svd_6d_flow_mse_loss': IdentityPNEncoder,
                       'pointnet_svd_pointwise_PW_bef_SVD': IdentityPNEncoder,
                       }
194 |
195 |
def make_encoder(
    encoder_type, obs_shape, feature_dim, num_layers, num_filters, output_logits=False,
    depth_segm=False
):
    """Instantiate the encoder class registered under `encoder_type`.

    `encoder_type` must be a key of `_AVAILABLE_ENCODERS`; all other arguments
    are forwarded to the selected class unchanged.
    """
    assert encoder_type in _AVAILABLE_ENCODERS
    encoder_cls = _AVAILABLE_ENCODERS[encoder_type]
    encoder = encoder_cls(
        obs_shape, feature_dim, num_layers, num_filters, output_logits,
        depth_segm=depth_segm,
    )
    return encoder
204 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ToolFlowNet: Robotic Manipulation with Tools via Predicting Tool Flow from Point Clouds
3 |
4 |
12 |
13 |
17 |
18 |
19 |
20 | This has the **ToolFlowNet** code for
21 | our CoRL 2022 paper.
22 | See our other repository
23 | for the SoftGym environment code. Note that you should install that code *first*,
24 | and then this ToolFlowNet code *second*.
25 |
26 | The code in this repository has a bunch of SoftGym-specific components. We are
27 | working on a separate, simplified version of ToolFlowNet which is more agnostic
28 | to the choice of environment or task.
29 |
30 | **Note:** This branch contains instructions to run the ToolFlowNet simulation code. For the physical experiments code, we maintain a separate [**`physical`**](https://github.com/DanielTakeshi/softagent_tfn/tree/physical) branch and a separate repository [**`tfn-robot`**](https://github.com/SarthakJShetty/tfn-robot) for the real-world data collection and robot control code.
31 |
32 |
33 |
34 | Contents:
35 |
36 | - [Installation](#installation)
37 | - [How to Use](#how-to-use)
38 | - [CoRL 2022 Experiments](#corl-2022-experiments)
39 | - [Inspect Results](#inspect-results)
40 | - [Citation](#citation)
41 |
42 |
43 |
44 | ## Installation
45 |
46 | Assuming you have already installed the other repository, then for this one, we
47 | first need to make a symlink to `softgym_tfn`. Make sure that `softgym_tfn/`
48 | does not exist within this folder. Then run this:
49 |
50 | ```
51 | ln -s ../softgym_tfn/ softgym
52 | ```
53 |
54 | This creates a symlink so the `softgym` subdirectory points to our `softgym_tfn`
55 | repository. By typing in `ls -lh` you should see: `softgym -> ../softgym_tfn/`.
56 |
57 | Then run this in any new Terminal window/tab that you want to run code in:
58 |
59 | ```
60 | . ./prepare_1.0.sh
61 | ```
62 |
63 | This will go into `softgym` and set up the conda environment. It should also set
64 | `PYFLEXROOT` which points to `PyFlex` in the SoftAgent repository.
65 |
66 | ## How to Use
67 |
68 | The main way that we launch experiments is with:
69 |
70 | ```
71 | python launch_exp.py
72 | ```
73 |
74 | or
75 |
76 | ```
77 | python launch_exp.py --debug
78 | ```
79 |
80 | The first case will run multiple combinations of variants in parallel. Thus, be
81 | careful about launching a lot of variants, since the combination can overwhelm
82 | one machine. Adding the `--debug` flag means the code runs just one of the
83 | variants. *We recommend using `--debug` to start*. In addition, when running
84 | multiple variants, we recommend only adjusting the random seed. We can do this
85 | by (for example) setting `vg.add('seed', [100,101])` and making all other
86 | `vg.add(...)` calls use just one-length lists. This will run 2 runs in parallel,
87 | each with the same settings, except with different random seeds. For the paper,
88 | we launched these scripts while using 5 random seeds with
89 | `vg.add('seed', [100,101,102,103,104])`.
90 |
91 | See `launch_exp.py` for details on what to modify. The three main areas to
92 | adjust for the purpose of learning from demonstrations are:
93 |
94 | - Adjusting the behavioral cloning data directory.
95 | - Selecting the environment to use, `PourWater` or `PourWater6D`. In this code,
96 | `PourWater` refers to the task version with 3DOF actions.
97 | - Selecting the method to use by setting `this_cfg` appropriately. See the code
98 | comments and `bc/exp_configs.py` for more about what the different
99 | configurations mean.
100 |
101 | You can find these areas by searching in `launch_exp.py` for this pattern:
102 |
103 | ```
104 | # ----------------------------- ADJUST -------------------------------- #
105 |
106 | # --------------------------------------------------------------------- #
107 | ```
108 |
109 | Check the content in between the above two lines.
110 |
111 | See the next section for how we set these for the CoRL 2022 submission.
112 |
113 | **Important note**: we highly recommend using `wandb` to track experiments.
114 | Please adjust these two lines in `launch_exp.py` appropriately:
115 |
116 | ```
117 | vg.add('wandb_project', ['']) # Fill this in!
118 | vg.add('wandb_entity', ['']) # Fill this in!
119 | ```
120 |
121 | If you need a refresher on `wandb`, refer to [the official documentation][2]. If
122 | you leave these blank, the script might not run successfully.
123 |
124 | ## CoRL 2022 Experiments
125 |
126 | Before this, make sure you have downloaded demonstration data [following our
127 | other repository's instructions][1]. This includes both the cache and the
128 | demonstrations themselves. While the default instructions put the data in
129 | `~/softgym_tfn/data_demo`, you may put the data in a different location if
130 | desired.
131 |
132 | **First**: with the demonstration data, adjust the `DATA_HEAD` variable. For
133 | example, setting this:
134 |
135 | ```
136 | DATA_HEAD = '/home/seita/softgym_tfn/data_demo/'
137 | ```
138 |
139 | near the top of `launch_exp.py` means that, for a run with PourWater (3D), I
140 | should expect to see the demonstrations located at:
141 |
142 | ```
143 | /home/seita/softgym_tfn/data_demo/PourWater_v01_BClone_filtered_wDepth_pw_algo_v02_nVars_1500_obs_combo_act_translation_axis_angle_withWaterFrac
144 | ```
145 |
146 | **Second**: select the task you want, either `PourWater` (the 3DOF action space
147 | version) or `PourWater6D` (with 6DOF actions). This means selecting *one* of the
148 | following:
149 |
150 | ```
151 | env, env_version, alg_policy = 'PourWater', 'v01', 'pw_algo_v02'
152 | env, env_version, alg_policy = 'PourWater6D', 'v01', 'pw_algo_v02'
153 | ```
154 |
155 | Be sure to comment out whatever option you are not using.
156 |
157 | **Third**: pick the method. For example, select ToolFlowNet with:
158 |
159 | ```
160 | this_cfg = exp_configs.SVD_POINTWISE_EE2FLOW
161 | ```
162 |
163 | or the PCL Direct Vector MSE baseline with:
164 |
165 | ```
166 | this_cfg = exp_configs.DIRECT_VECTOR_INTRINSIC_AXIS_ANGLE
167 | ```
168 |
169 | There are many experiment options. See the comments and `bc/exp_configs.py` for
170 | more details.
171 |
172 | Finally, double check all the settings in the variant generator (`vg`). For the
173 | paper we typically ran by setting:
174 |
175 | ```
176 | vg.add('seed', [100,101,102,103,104])
177 | ```
178 |
179 | as the only variant, in the sense that this is the only `vg` option with more
180 | than one list item. This means (as stated in our paper) we ran 5 random seeds
181 | for each experiment setting.
182 |
183 | Once you are confident the settings are correct, run the script! (Did you
184 | remember to set up `wandb`?)
185 |
186 |
187 | ## Inspect Results
188 |
189 | For accumulating and computing results for the CoRL 2022 paper, we used one of
190 | the following four commands:
191 |
192 | ```
193 | python results_table.py
194 | python results_table.py --show_avg_epoch
195 | python results_table.py --show_raw_perf
196 | python results_table.py --show_raw_perf --show_avg_epoch
197 | ```
198 |
199 | These four commands, respectively, produce a table of statistics which correspond to results in
200 | **Table 1**, **Table S5**, **Table S6**, and **Table S7** in the paper.
201 | Some relevant keys in the code are:
202 |
203 | - `FLOW3D_SVD_PW_CONSIST_0_1` corresponds to ToolFlowNet with `lambda` (consistency weight) value of 0.1.
204 | - `PCL_DIRECT_VECTOR_MSE_INTRIN_AXANG` corresponds to "PCL Direct Vector MSE."
205 | - `PCL_DENSE_TRANSF_MSE_INTRIN_AXANG` corresponds to "PCL Dense Transformation MSE."
206 |
207 | Successfully running this command requires having all relevant experiment
208 | results. We provide this script to explain how we produced the results.
209 |
210 |
211 | ## Citation
212 |
213 | If you find this repository useful, please cite our paper:
214 |
215 | ```
216 | @inproceedings{Seita2022toolflownet,
217 | title={{ToolFlowNet: Robotic Manipulation with Tools via Predicting Tool Flow from Point Clouds}},
218 | author={Seita, Daniel and Wang, Yufei and Shetty, Sarthak and Li, Edward and Erickson, Zackory and Held, David},
219 | booktitle={Conference on Robot Learning (CoRL)},
220 | year={2022}
221 | }
222 | ```
223 |
224 | [1]:https://github.com/DanielTakeshi/softgym_tfn
225 | [2]:https://docs.wandb.ai/
226 |
--------------------------------------------------------------------------------
/chester/plotting/cplot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import copy
4 | import json
5 | import argparse
6 | import itertools
7 | import numpy as np
8 |
9 | # Matplotlib
10 | import matplotlib
11 |
12 | matplotlib.use('Agg')
13 | import matplotlib.pyplot as plt
14 |
15 | plt.rc('font', size=25)
16 | matplotlib.rcParams['pdf.fonttype'] = 42 # Default type3 cannot be rendered in some templates
17 | matplotlib.rcParams['ps.fonttype'] = 42
18 | matplotlib.rcParams['grid.alpha'] = 0.3
19 | matplotlib.rcParams['axes.titlesize'] = 25
20 | import matplotlib.ticker as tick
21 |
22 | # rllab
23 | sys.path.append('.')
24 | from rllab.misc.ext import flatten
25 | from rllab.viskit import core
26 |
27 |
28 | # from rllab.misc import ext
29 |
30 | # plotly
31 | # import plotly.offline as po
32 | # import plotly.graph_objs as go
33 |
34 |
def smooth_data(data, smooth):
    """NOTE(daniel) smoothing with window average, from SpinningUp.
    https://github.com/openai/spinningup/blob/master/spinup/utils/plot.py#L15

    Moving-window average over `smooth` consecutive samples. Each output is
    the sum of the samples under the window divided by the number of samples
    the window actually covers, so edges are averaged over fewer points.
    When `smooth <= 1` the input is returned untouched.
    """
    if smooth <= 1:
        return data
    values = np.asarray(data)
    kernel = np.ones(smooth)
    window_sums = np.convolve(values, kernel, mode='same')
    window_counts = np.convolve(np.ones(len(values)), kernel, mode='same')
    return window_sums / window_counts
51 |
52 |
def reload_data(data_paths):
    """
    Load every experiment found under `data_paths` and summarise what can be plotted.

    :param data_paths: Path(s) of the folder(s) storing all the data; passed straight
                       to `core.load_exps_data`.
    :return [exps_data, plottable_keys, distinct_params]
        exps_data: A list of the progress data for each curve. Each curve is an AttrDict with the key
                    'progress': A dictionary of plottable keys. The val of each key is an ndarray representing the
                        values of the key during training, or one column in the progress.txt file.
                    'params'/'flat_params': A dictionary of all hyperparameters recorded in 'variants.json' file.
        plottable_keys: A sorted list of strings representing all the keys that can be plotted.
        distinct_params: A sorted list of hyper-parameters which have different values among all the curves.
                    This can be used to split the graph into multiple figures. Each element is a tuple
                    (param, list_of_values_to_take).
    """
    loaded = core.load_exps_data(data_paths, disable_variant=False, ignore_missing_keys=True)
    exps_data = copy.copy(loaded)

    # Union of the progress keys of all experiments, in sorted order.
    all_keys = set(flatten(list(exp.progress.keys()) for exp in exps_data))
    plottable_keys = sorted(all_keys)

    distinct_params = sorted(core.extract_distinct_params(exps_data))

    return exps_data, plottable_keys, distinct_params
73 |
74 |
def get_shaded_curve(selector, key, shade_type='variance'):
    """
    Aggregate one progress key over a group of curves into a centre line and a band.

    :param selector: Selector for a group of curves; `selector.extract()` yields the experiments.
    :param key: Name of the progress entry to aggregate. Experiments lacking it contribute only NaN.
    :param shade_type: Should be either 'variance' or 'median', indicating how the shades are calculated:
        'variance' gives mean -/+ one standard deviation, 'median' gives the 25th/50th/75th percentiles.
    :return: (y, y_lower, y_upper) as lists: the centre curve, then the lower and upper boundary
        of the shaded region. Shorter curves are NaN-padded and NaNs are ignored in the statistics.
    """

    # First, collect the curves and NaN-pad them to a common length
    curves = []
    for exp in selector.extract():
        curves.append(exp.progress.get(key, np.array([np.nan])))
    longest = max(len(curve) for curve in curves)
    padded = [
        np.concatenate([curve, np.full(longest - len(curve), np.nan)])
        for curve in curves
    ]

    # Second, calculate the centre line and the shaded area
    if shade_type == 'median':
        def percentile(q):
            return list(np.nanpercentile(padded, q=q, axis=0))

        y_lower = percentile(25)
        y = percentile(50)
        y_upper = percentile(75)
    elif shade_type == 'variance':
        center = np.nanmean(padded, axis=0)
        spread = np.nanstd(padded, axis=0)
        y = list(center)
        y_upper = list(center + spread)
        y_lower = list(center - spread)
    else:
        raise NotImplementedError

    return y, y_lower, y_upper
110 |
111 |
def get_group_selectors(exps, custom_series_splitter):
    """Get selectors, a custom rllab class.

    Experiments are bucketed by the label that `custom_series_splitter`
    returns for them; each bucket becomes one `core.Selector`.

    Example:
        buckets['Reduced State Oracle (SAC)'] = [{dict1}, {dict2},...]
        Each `dictk` has data from one `progress.csv`, created from one RL run.

    :param exps: list of experiments, each is of `rllab.misc.ext.AttrDict` type.
        IDK why they need that. The keys are 'progress' (which is loaded from the
        csv), 'params' and 'flat_params'. The 'params' and 'flat_params' seem to
        only differ based on the latter having our new `env_kwargs_{...}` stuff.
    :param custom_series_splitter: custom function defined to extract the algorithm
        and other info, and produce a label for the legend.
    :return: A tuple of (list,list) type, containing the selectors and legends,
        both ordered by legend in descending order.
    """
    # label -> experiments carrying that label, in first-seen order
    buckets = dict()
    for exp in exps:
        label = custom_series_splitter(exp)
        buckets.setdefault(label, list()).append(exp)

    # One (selector, legend) pair per bucket, then order the pairs by legend.
    pairs = [(core.Selector(list(members)), label) for label, members in buckets.items()]
    pairs = sorted(pairs, key=lambda pair: pair[1], reverse=True)

    group_selectors = [selector for selector, _ in pairs]
    group_legends = [legend for _, legend in pairs]
    return group_selectors, group_legends
145 |
146 |
def filter_save_name(save_name):
    """Return `save_name` with characters awkward in file names replaced.

    Slashes, brackets, parentheses, commas and spaces each become an
    underscore, and every occurrence of '0.' becomes '0_'.
    """
    for char in ('/', '[', ']', '(', ')', ',', ' '):
        save_name = save_name.replace(char, '_')
    return save_name.replace('0.', '0_')
157 |
158 |
def sliding_mean(data_array, window=5):
    """Smooth a sequence with a moving average.

    For position `i` the average covers indices from `i - window + 1` up to
    and including `i + window`, clipped to the bounds of the data.

    :param data_array: sequence of values; converted with `np.array`.
    :param window: reach of the averaging window on each side (see above).
    :return: numpy array with one averaged value per input position.
    """
    values = np.array(data_array)
    count = len(values)
    smoothed = []
    for center in range(count):
        span = range(max(center - window + 1, 0),
                     min(center + window + 1, count))
        total = 0
        for pos in span:
            total += values[pos]
        total /= float(len(span))
        smoothed.append(total)

    return np.array(smoothed)
172 |
173 |
if __name__ == '__main__':
    # Demo script: load a directory of experiment results and draw shaded
    # learning curves, first for a single selector and then for every group.
    data_path = '/Users/Dora/Projects/baselines_hrl/data/seuss/visual_rss_RopeFloat_0407'
    exps_data, plottable_keys, distinct_params = reload_data(data_path)

    # Example of extracting a single curve
    selector = core.Selector(exps_data)
    selector = selector.where('her_replay_strategy', 'balance_filter')
    y, y_lower, y_upper = get_shaded_curve(selector, 'test/success_state')
    _, ax = plt.subplots()

    color = core.color_defaults[0]
    ax.fill_between(range(len(y)), y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, alpha=0.2)
    # The label must be a string; the previous code passed the `plt.legend`
    # function object itself, which would have been rendered as its repr.
    ax.plot(range(len(y)), y, color=color, label='balance_filter', linewidth=2.0)


    # Example of extracting all the curves
    def custom_series_splitter(x):
        """Map one experiment to the legend label of the method it ran."""
        params = x['flat_params']
        if 'use_ae_reward' in params and params['use_ae_reward']:
            return 'Auto Encoder'
        if params['her_replay_strategy'] == 'balance_filter':
            return 'Indicator+Balance+Filter'
        if params['env_kwargs.use_true_reward']:
            return 'Oracle'
        return 'Indicator'


    fig, ax = plt.subplots(figsize=(8, 5))

    group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter)
    for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)):
        color = core.color_defaults[idx]

        y, y_lower, y_upper = get_shaded_curve(selector, 'test/success_state')

        ax.plot(range(len(y)), y, color=color, label=legend, linewidth=2.0)
        ax.fill_between(range(len(y)), y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, alpha=0.2)
    ax.grid(True)
    ax.set_xlabel('Timesteps')
    ax.set_ylabel('Success')
    loc = 'best'
    # Every line already carries its legend via `label=`. Passing `labels=`
    # without handles would pair the labels with the axes' artists in creation
    # order, which interleaves the unlabeled shaded regions with the lines.
    leg = ax.legend(loc=loc, prop={'size': 15}, ncol=1)
    # `legendHandles` was renamed to `legend_handles` in newer Matplotlib
    # releases; support both spellings.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles
    for legobj in handles:
        legobj.set_linewidth(3.0)
    plt.savefig('test.png', bbox_inches='tight')
219 |
--------------------------------------------------------------------------------
/tests/test_pt3d_pyquaternion.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 53,
6 | "id": "6b71c9f3-5edc-458d-9957-e74c7e0fd786",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import math\n",
11 | "import numpy as np\n",
12 | "import torch\n",
13 | "from pyquaternion import Quaternion\n",
14 | "from pytorch3d.transforms import matrix_to_quaternion, quaternion_to_axis_angle, axis_angle_to_matrix, Rotate\n",
15 | "\n",
16 | "from bc.se3 import flow2pose"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 25,
22 | "id": "39b8afcb-f566-48e4-bf53-5dfe69706ac3",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "axis = [0, math.sqrt(2) / 2, math.sqrt(2) / 2]\n",
27 | "angle = math.pi / 2"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 26,
33 | "id": "047facaf-c0e6-4e36-9150-fcdb27cf862a",
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "quat = Quaternion(axis=axis, angle=angle)\n",
38 | "points = np.vstack(np.meshgrid(np.linspace(-1, 1, 10), np.linspace(-1, 1, 10), np.linspace(-1, 1, 10))).reshape(3, -1).T"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 27,
44 | "id": "fc97695c-75d9-42fb-94a1-7ba7ec108712",
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "data": {
49 | "text/plain": [
50 | "array([[ 1.11022302e-16, -7.07106781e-01, 7.07106781e-01],\n",
51 | " [ 7.07106781e-01, 5.00000000e-01, 5.00000000e-01],\n",
52 | " [-7.07106781e-01, 5.00000000e-01, 5.00000000e-01]])"
53 | ]
54 | },
55 | "execution_count": 27,
56 | "metadata": {},
57 | "output_type": "execute_result"
58 | }
59 | ],
60 | "source": [
61 | "quat.rotation_matrix"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 28,
67 | "id": "e17fedfa-08a3-44e2-b374-7760749d15dc",
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "flow = np.zeros_like(points)\n",
72 | "for i in range(points.shape[0]):\n",
73 | " pt = points[i]\n",
74 | " rot = quat.rotate(pt)\n",
75 | " flow[i] += rot - pt\n",
76 | "\n",
77 | "points = torch.from_numpy(points.astype(np.float32))\n",
78 | "flow = torch.from_numpy(flow.astype(np.float32))"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 29,
84 | "id": "be73a4a7-51ab-41e2-ab79-10341a332637",
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "trfm = flow2pose(\n",
89 | " xyz=points[None, :],\n",
90 | " flow=flow[None, :],\n",
91 | " weights=None,\n",
92 | " return_transform3d=True,\n",
93 | " return_quaternions=False,\n",
94 | ")"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 30,
100 | "id": "f43a6f21-b47a-471d-9736-b8d7d477b56a",
101 | "metadata": {},
102 | "outputs": [
103 | {
104 | "data": {
105 | "text/plain": [
106 | "True"
107 | ]
108 | },
109 | "execution_count": 30,
110 | "metadata": {},
111 | "output_type": "execute_result"
112 | }
113 | ],
114 | "source": [
115 | "pred_flow = trfm.transform_points(points).squeeze(0) - points\n",
116 | "torch.allclose(flow, pred_flow, atol=1e-4)"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 31,
122 | "id": "90a38c7b-70c2-4b68-941b-8d9a09ceee16",
123 | "metadata": {},
124 | "outputs": [],
125 | "source": [
126 | "rot_matrices, trans = flow2pose(\n",
127 | " xyz=points[None, :],\n",
128 | " flow=flow[None, :],\n",
129 | " weights=None,\n",
130 | " return_transform3d=False,\n",
131 | " return_quaternions=False,\n",
132 | " world_frameify=False,\n",
133 | ")"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": 32,
139 | "id": "41559b3c-c7f6-4823-9ca4-557c4fe92d5d",
140 | "metadata": {},
141 | "outputs": [
142 | {
143 | "data": {
144 | "text/plain": [
145 | "tensor([[[ 0.0000, 0.7071, -0.7071],\n",
146 | " [-0.7071, 0.5000, 0.5000],\n",
147 | " [ 0.7071, 0.5000, 0.5000]]])"
148 | ]
149 | },
150 | "execution_count": 32,
151 | "metadata": {},
152 | "output_type": "execute_result"
153 | }
154 | ],
155 | "source": [
156 | "rot_matrices"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 33,
162 | "id": "afffab03-032b-4625-9687-f8838c804a78",
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "quats = matrix_to_quaternion(matrix=rot_matrices.transpose(1, 2))\n",
167 | "axis_ang = quaternion_to_axis_angle(quaternions=quats)"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 34,
173 | "id": "6ab53033-a75a-4d6b-96b4-33801c3bb8a5",
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "pred_axis = axis_ang / torch.linalg.norm(axis_ang)\n",
178 | "pred_angle = torch.linalg.norm(axis_ang)"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 35,
184 | "id": "a2d19df3-9599-40f9-82ef-9725d2cbcad4",
185 | "metadata": {},
186 | "outputs": [
187 | {
188 | "name": "stdout",
189 | "output_type": "stream",
190 | "text": [
191 | "tensor([[5.9605e-08, 7.0711e-01, 7.0711e-01]])\n",
192 | "tensor(1.5708)\n"
193 | ]
194 | }
195 | ],
196 | "source": [
197 | "print(pred_axis)\n",
198 | "print(pred_angle)"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 36,
204 | "id": "c0663f71-3027-4a00-b205-69a1db9d5dbe",
205 | "metadata": {},
206 | "outputs": [
207 | {
208 | "data": {
209 | "text/plain": [
210 | "True"
211 | ]
212 | },
213 | "execution_count": 36,
214 | "metadata": {},
215 | "output_type": "execute_result"
216 | }
217 | ],
218 | "source": [
219 | "pred_flow = torch.bmm(points.unsqueeze(0), rot_matrices).squeeze(0) - points\n",
220 | "torch.allclose(flow, pred_flow, atol=1e-4)"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 37,
226 | "id": "130c57bd-20e5-4676-84dc-5f1c34e85972",
227 | "metadata": {},
228 | "outputs": [
229 | {
230 | "data": {
231 | "text/plain": [
232 | "torch.Size([1, 3])"
233 | ]
234 | },
235 | "execution_count": 37,
236 | "metadata": {},
237 | "output_type": "execute_result"
238 | }
239 | ],
240 | "source": [
241 | "trans.shape"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 38,
247 | "id": "80180df0-a846-405b-9220-b9970295fa3e",
248 | "metadata": {},
249 | "outputs": [
250 | {
251 | "data": {
252 | "text/plain": [
253 | "torch.Size([1, 3])"
254 | ]
255 | },
256 | "execution_count": 38,
257 | "metadata": {},
258 | "output_type": "execute_result"
259 | }
260 | ],
261 | "source": [
262 | "torch.mean(flow, axis=0, keepdims=True).shape"
263 | ]
264 | },
265 | {
266 | "cell_type": "markdown",
267 | "id": "44c2d7d4-a27a-47d4-b5c2-42fb77450855",
268 | "metadata": {},
269 | "source": [
270 | "Dense Trfm Testing\n",
271 | "=================="
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": 61,
277 | "id": "8a1ac814-eae4-4705-8a57-9596466b4e1c",
278 | "metadata": {},
279 | "outputs": [],
280 | "source": [
281 | "axis_ang = torch.tensor([axis]) * angle\n",
282 | "rot_matrix = axis_angle_to_matrix(axis_ang).transpose(1, 2)"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": 62,
288 | "id": "3d854473-c848-4b67-b553-3b48f624b53a",
289 | "metadata": {},
290 | "outputs": [
291 | {
292 | "data": {
293 | "text/plain": [
294 | "True"
295 | ]
296 | },
297 | "execution_count": 62,
298 | "metadata": {},
299 | "output_type": "execute_result"
300 | }
301 | ],
302 | "source": [
303 | "pred_flow = torch.bmm(points.unsqueeze(0), rot_matrix).squeeze(0) - points\n",
304 | "torch.allclose(flow, pred_flow, atol=1e-4)"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 63,
310 | "id": "da168b0b-5e14-4794-be24-9b32ea0ed47e",
311 | "metadata": {},
312 | "outputs": [
313 | {
314 | "data": {
315 | "text/plain": [
316 | "True"
317 | ]
318 | },
319 | "execution_count": 63,
320 | "metadata": {},
321 | "output_type": "execute_result"
322 | }
323 | ],
324 | "source": [
325 | "trfm = Rotate(rot_matrix)\n",
326 | "pred_flow = trfm.transform_points(points).squeeze(0) - points\n",
327 | "torch.allclose(flow, pred_flow, atol=1e-4)"
328 | ]
329 | },
330 | {
331 | "cell_type": "code",
332 | "execution_count": null,
333 | "id": "6c0bc3e3-e16c-4522-8521-5918f457c0ab",
334 | "metadata": {},
335 | "outputs": [],
336 | "source": []
337 | }
338 | ],
339 | "metadata": {
340 | "kernelspec": {
341 | "display_name": "softgym",
342 | "language": "python",
343 | "name": "softgym"
344 | },
345 | "language_info": {
346 | "codemirror_mode": {
347 | "name": "ipython",
348 | "version": 3
349 | },
350 | "file_extension": ".py",
351 | "mimetype": "text/x-python",
352 | "name": "python",
353 | "nbconvert_exporter": "python",
354 | "pygments_lexer": "ipython3",
355 | "version": "3.6.13"
356 | }
357 | },
358 | "nbformat": 4,
359 | "nbformat_minor": 5
360 | }
361 |
--------------------------------------------------------------------------------
/rpmg/rpmg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | import os
4 | BASEPATH = os.path.dirname(__file__)
5 | sys.path.append(BASEPATH)
6 | import tools
7 |
def Rodrigues(w):
    '''
    axis angle -> rotation (Rodrigues' formula).

    The rotation angle is the norm of each row of `w`. The axis is
    `w / (|w| + 0.001)`; the small constant avoids a division by zero for
    (near-)zero rotations, at the cost of a slightly shortened axis.

    Works for tensors on any device (CPU or CUDA); the helper tensors are
    created on `w.device`.

    :param w: [b,3]
    :return: R: [b,3,3]
    '''
    device = w.device
    b = w.shape[0]

    theta = w.norm(dim=1).reshape(b, 1, 1)
    wnorm = w / (w.norm(dim=1, keepdim=True) + 0.001)

    # Skew-symmetric generators for the x, y and z axes, stacked as [3,3,3].
    basis = torch.zeros((3, 3, 3), device=device)
    basis[0, 1, 2] = -1
    basis[0, 2, 1] = 1
    basis[1, 0, 2] = 1
    basis[1, 2, 0] = -1
    basis[2, 0, 1] = -1
    basis[2, 1, 0] = 1

    # Cross-product matrix of the (approximately unit) axis: [b,3,3].
    Jwnorm = (basis.unsqueeze(0) * wnorm.reshape(b, 3, 1, 1)).sum(dim=1)
    I = torch.eye(3, device=device).repeat(b, 1, 1)

    return I + torch.sin(theta) * Jwnorm + (1 - torch.cos(theta)) * torch.bmm(Jwnorm, Jwnorm)
34 |
# Module-wide logger used by RPMG.backward; stays 0 until logger_init() runs.
logger = 0


def logger_init(ll):
    """Install `ll` as the module-wide logger and print a confirmation."""
    global logger
    print('logger init')
    logger = ll
40 |
class RPMG(torch.autograd.Function):
    '''
    Full version. See "simple_RPMG()" for a simplified version.
    Tips:
        1. Use "logger_init()" to initialize the logger if you want to record some intermediate variables with
           tensorboard. Without it, the logging steps are skipped.
        2. Use the sum of the L2/geodesic loss instead of the mean, since tau_converge is derived without
           considering the scalar introduced by a mean loss.
        3. Pass "weight=$YOUR_WEIGHT" instead of directly multiplying the weight onto the rotation loss, if you
           want to reweight the rotation loss against other losses.
    '''
    @staticmethod
    def forward(ctx, in_nd, tau, lam, rgt, iter, weight=1):
        """Project the raw network output `in_nd` onto a rotation matrix.

        The representation is chosen from `in_nd.shape[1]`: 6 (ortho6d),
        9 (SVD), 4 (quaternion) or 10 (10D). Any other size raises
        NotImplementedError. `tau`, `lam`, `iter` and `weight` are stored for
        the backward pass together with `rgt`.
        """
        proj_kind = in_nd.shape[1]
        if proj_kind == 6:
            r0 = tools.compute_rotation_matrix_from_ortho6d(in_nd)
        elif proj_kind == 9:
            r0 = tools.symmetric_orthogonalization(in_nd)
        elif proj_kind == 4:
            r0 = tools.compute_rotation_matrix_from_quaternion(in_nd)
        elif proj_kind == 10:
            r0 = tools.compute_rotation_matrix_from_10d(in_nd)
        else:
            raise NotImplementedError
        ctx.save_for_backward(in_nd, r0, torch.Tensor([tau,lam, iter, weight]), rgt)
        return r0

    @staticmethod
    def backward(ctx, grad_in):
        """Return the manifold-aware gradient w.r.t. `in_nd` (None for the other inputs)."""
        in_nd, r0, config,rgt, = ctx.saved_tensors
        tau = config[0]
        lam = config[1]
        step = config[2]
        weight = config[3]
        b = r0.shape[0]
        proj_kind = in_nd.shape[1]

        # Only log when a logger with `add_scalar` was installed via
        # logger_init(); the module default (0) cannot log.
        should_log = hasattr(logger, 'add_scalar') and bool(step % 100 == 0)

        # use Riemannian optimization to get the next goal R
        if tau == -1:
            r_new = rgt
        else:
            # Euclidean gradient -> Riemannian gradient.
            # Generators are created on r0's device instead of assuming CUDA.
            Jx = torch.zeros((b, 3, 3), device=r0.device)
            Jx[:, 2, 1] = 1
            Jx[:, 1, 2] = -1
            Jy = torch.zeros((b, 3, 3), device=r0.device)
            Jy[:, 0, 2] = 1
            Jy[:, 2, 0] = -1
            Jz = torch.zeros((b, 3, 3), device=r0.device)
            Jz[:, 0, 1] = -1
            Jz[:, 1, 0] = 1
            gx = (grad_in * torch.bmm(r0, Jx)).reshape(-1, 9).sum(dim=1, keepdim=True)
            gy = (grad_in * torch.bmm(r0, Jy)).reshape(-1, 9).sum(dim=1, keepdim=True)
            gz = (grad_in * torch.bmm(r0, Jz)).reshape(-1, 9).sum(dim=1, keepdim=True)
            g = torch.cat([gx, gy, gz], 1)

            # take one step
            delta_w = -tau * g

            # update R
            r_new = torch.bmm(r0, Rodrigues(delta_w))

            # this can help you to tune the tau if you don't use L2/geodesic loss.
            if should_log:
                logger.add_scalar('next_goal_angle_mean', delta_w.norm(dim=1).mean(), step)
                logger.add_scalar('next_goal_angle_max', delta_w.norm(dim=1).max(), step)
                R0_Rgt = tools.compute_geodesic_distance_from_two_matrices(r0, rgt)
                logger.add_scalar('r0_rgt_angle', R0_Rgt.mean(), step)

        # inverse & project
        if proj_kind == 6:
            r_proj_1 = (r_new[:, :, 0] * in_nd[:, :3]).sum(dim=1, keepdim=True) * r_new[:, :, 0]
            r_proj_2 = (r_new[:, :, 0] * in_nd[:, 3:]).sum(dim=1, keepdim=True) * r_new[:, :, 0] \
                       + (r_new[:, :, 1] * in_nd[:, 3:]).sum(dim=1, keepdim=True) * r_new[:, :, 1]
            r_reg_1 = lam * (r_proj_1 - r_new[:, :, 0])
            r_reg_2 = lam * (r_proj_2 - r_new[:, :, 1])
            gradient_nd = torch.cat([in_nd[:, :3] - r_proj_1 + r_reg_1, in_nd[:, 3:] - r_proj_2 + r_reg_2], 1)
        elif proj_kind == 9:
            SVD_proj = tools.compute_SVD_nearest_Mnlsew(in_nd.reshape(-1,3,3), r_new)
            gradient_nd = in_nd - SVD_proj + lam * (SVD_proj - r_new.reshape(-1,9))
            if should_log:
                # Only needed for the reflection statistic below.
                R_proj_g = tools.symmetric_orthogonalization(SVD_proj)
                logger.add_scalar('9d_reflection', (((R_proj_g-r_new).reshape(-1,9).abs().sum(dim=1))>5e-1).sum(), step)
                logger.add_scalar('reg', (SVD_proj - r_new.reshape(-1, 9)).norm(dim=1).mean(), step)
                logger.add_scalar('main', (in_nd - SVD_proj).norm(dim=1).mean(), step)
        elif proj_kind == 4:
            q_1 = tools.compute_quaternions_from_rotation_matrices(r_new)
            q_2 = -q_1
            normalized_nd = tools.normalize_vector(in_nd)
            # q and -q encode the same rotation; pick the one closer to the input.
            q_new = torch.where(
                (q_1 - normalized_nd).norm(dim=1, keepdim=True) < (q_2 - normalized_nd).norm(dim=1, keepdim=True),
                q_1, q_2)
            q_proj = (in_nd * q_new).sum(dim=1, keepdim=True) * q_new
            gradient_nd = in_nd - q_proj + lam * (q_proj - q_new)
        elif proj_kind == 10:
            qg = tools.compute_quaternions_from_rotation_matrices(r_new)
            new_x = tools.compute_nearest_10d(in_nd, qg)
            reg_A = torch.eye(4, device=qg.device)[None].repeat(qg.shape[0],1,1) - torch.bmm(qg.unsqueeze(-1), qg.unsqueeze(-2))
            reg_x = tools.convert_A_to_Avec(reg_A)
            gradient_nd = in_nd - new_x + lam * (new_x - reg_x)
            if should_log:
                logger.add_scalar('reg', (new_x - reg_x).norm(dim=1).mean(), step)
                logger.add_scalar('main', (in_nd - new_x).norm(dim=1).mean(), step)
        else:
            # forward() already rejects other sizes; kept for symmetry.
            raise NotImplementedError

        # One gradient per forward input: in_nd, tau, lam, rgt, iter, weight.
        return gradient_nd * weight, None, None, None, None, None
145 |
146 |
147 |
class simple_RPMG(torch.autograd.Function):
    '''
    Simplified version without tensorboard logging; r_gt is optional.
    '''
    @staticmethod
    def forward(ctx, in_nd, tau, lam, weight=1, rgt=None):
        """Project the raw network output `in_nd` onto a rotation matrix.

        The representation is chosen from `in_nd.shape[1]`: 6 (ortho6d),
        9 (SVD), 4 (quaternion) or 10 (10D). Any other size raises
        NotImplementedError. When `rgt` is None a zero tensor shaped like the
        result is stored in its place; pass a real `rgt` when `tau == -1`.
        """
        proj_kind = in_nd.shape[1]
        if proj_kind == 6:
            r0 = tools.compute_rotation_matrix_from_ortho6d(in_nd)
        elif proj_kind == 9:
            r0 = tools.symmetric_orthogonalization(in_nd)
        elif proj_kind == 4:
            r0 = tools.compute_rotation_matrix_from_quaternion(in_nd)
        elif proj_kind == 10:
            r0 = tools.compute_rotation_matrix_from_10d(in_nd)
        else:
            raise NotImplementedError

        if rgt is None:
            rgt = torch.zeros_like(r0)

        ctx.save_for_backward(in_nd, r0, torch.Tensor([tau,lam, weight]), rgt)
        # return r0.transpose(-1, -2) # TODO(daniel): getting rid of this?
        return r0

    @staticmethod
    def backward(ctx, grad_in):
        """Return the manifold-aware gradient w.r.t. `in_nd` (None for the other inputs)."""
        # grad_in = grad_in.transpose(-1, -2) # TODO(daniel): getting rid of this?
        in_nd, r0, config, rgt = ctx.saved_tensors
        tau = config[0]
        lam = config[1]
        weight = config[2]
        b = r0.shape[0]
        proj_kind = in_nd.shape[1]

        # use Riemannian optimization to get the next goal R
        if tau == -1:
            # use tau_gt if tau = -1
            # requires rgt to be passed in initially
            r_new = rgt
        else:
            # Euclidean gradient -> Riemannian gradient.
            # Generators are created on r0's device instead of assuming CUDA.
            Jx = torch.zeros((b, 3, 3), device=r0.device)
            Jx[:, 2, 1] = 1
            Jx[:, 1, 2] = -1
            Jy = torch.zeros((b, 3, 3), device=r0.device)
            Jy[:, 0, 2] = 1
            Jy[:, 2, 0] = -1
            Jz = torch.zeros((b, 3, 3), device=r0.device)
            Jz[:, 0, 1] = -1
            Jz[:, 1, 0] = 1
            gx = (grad_in * torch.bmm(r0, Jx)).reshape(-1, 9).sum(dim=1, keepdim=True)
            gy = (grad_in * torch.bmm(r0, Jy)).reshape(-1, 9).sum(dim=1, keepdim=True)
            gz = (grad_in * torch.bmm(r0, Jz)).reshape(-1, 9).sum(dim=1, keepdim=True)
            g = torch.cat([gx, gy, gz], 1)

            # take one step
            delta_w = -tau * g

            # update R
            r_new = torch.bmm(r0, Rodrigues(delta_w))

        # inverse & project
        if proj_kind == 6:
            r_proj_1 = (r_new[:, :, 0] * in_nd[:, :3]).sum(dim=1, keepdim=True) * r_new[:, :, 0]
            r_proj_2 = (r_new[:, :, 0] * in_nd[:, 3:]).sum(dim=1, keepdim=True) * r_new[:, :, 0] \
                       + (r_new[:, :, 1] * in_nd[:, 3:]).sum(dim=1, keepdim=True) * r_new[:, :, 1]
            r_reg_1 = lam * (r_proj_1 - r_new[:, :, 0])
            r_reg_2 = lam * (r_proj_2 - r_new[:, :, 1])
            gradient_nd = torch.cat([in_nd[:, :3] - r_proj_1 + r_reg_1, in_nd[:, 3:] - r_proj_2 + r_reg_2], 1)
        elif proj_kind == 9:
            SVD_proj = tools.compute_SVD_nearest_Mnlsew(in_nd.reshape(-1,3,3), r_new)
            gradient_nd = in_nd - SVD_proj + lam * (SVD_proj - r_new.reshape(-1,9))
        elif proj_kind == 4:
            q_1 = tools.compute_quaternions_from_rotation_matrices(r_new)
            q_2 = -q_1
            normalized_nd = tools.normalize_vector(in_nd)
            # q and -q encode the same rotation; pick the one closer to the input.
            q_new = torch.where(
                (q_1 - normalized_nd).norm(dim=1, keepdim=True) < (q_2 - normalized_nd).norm(dim=1, keepdim=True),
                q_1, q_2)
            q_proj = (in_nd * q_new).sum(dim=1, keepdim=True) * q_new
            gradient_nd = in_nd - q_proj + lam * (q_proj - q_new)
        elif proj_kind == 10:
            qg = tools.compute_quaternions_from_rotation_matrices(r_new)
            new_x = tools.compute_nearest_10d(in_nd, qg)
            reg_A = torch.eye(4, device=qg.device)[None].repeat(qg.shape[0],1,1) - torch.bmm(qg.unsqueeze(-1), qg.unsqueeze(-2))
            reg_x = tools.convert_A_to_Avec(reg_A)
            gradient_nd = in_nd - new_x + lam * (new_x - reg_x)
        else:
            # forward() already rejects other sizes; kept for symmetry.
            raise NotImplementedError

        # One gradient per forward input: in_nd, tau, lam, weight, rgt.
        return gradient_nd * weight, None, None, None, None
238 |
--------------------------------------------------------------------------------
/bc/pointnet2_classification.py:
--------------------------------------------------------------------------------
1 | # https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pointnet2_classification.py
2 | import torch
3 | from torch_geometric.nn import MLP, PointConv, fps, global_max_pool, radius
4 | torch.set_printoptions(linewidth=180, precision=5)
5 | from pytorch3d.transforms import (Rotate, axis_angle_to_matrix)
6 | from rpmg.rpmg import (RPMG, simple_RPMG)
7 | from rpmg.rpmg_losses import rpmg_forward
8 |
9 |
class SAModule(torch.nn.Module):
    """Set-abstraction layer: subsample points, then aggregate around them.

    `ratio` is handed to `fps` to pick the subset of points, `r` is the
    neighbourhood radius handed to `radius`, and `nn` is the network wrapped
    by `PointConv`.
    """

    def __init__(self, ratio, r, nn):
        super().__init__()
        self.ratio = ratio
        self.r = r
        self.conv = PointConv(nn, add_self_loops=False)

    def forward(self, x, pos, batch):
        """Return features, positions and batch indices of the sampled points."""
        idx = fps(pos, batch, ratio=self.ratio)
        centers = pos[idx]
        center_batch = batch[idx]

        # Neighbours (at most 64 each) of every sampled centre within radius r.
        row, col = radius(pos, centers, self.r, batch, center_batch,
                          max_num_neighbors=64)
        edge_index = torch.stack([col, row], dim=0)

        x_dst = x[idx] if x is not None else None
        x = self.conv((x, x_dst), (pos, centers), edge_index)
        return x, centers, center_batch
26 |
27 |
class GlobalSAModule(torch.nn.Module):
    """Global set abstraction: one max-pooled feature vector per batch entry."""

    def __init__(self, nn):
        super().__init__()
        self.nn = nn

    def forward(self, x, pos, batch):
        """Apply `nn` to [features, xyz] and max-pool the result per batch index.

        Returns the pooled features, an all-zero (num_pooled, 3) position
        tensor and a fresh batch index `0..num_pooled-1`.
        """
        features = self.nn(torch.cat([x, pos], dim=1))
        pooled = global_max_pool(features, batch)
        num_pooled = pooled.size(0)
        pos = pos.new_zeros((num_pooled, 3))
        batch = torch.arange(num_pooled, device=batch.device)
        return pooled, pos, batch
39 |
40 |
class PointNet2_Class(torch.nn.Module):
    """NOTE(daniel): PointNet++ from the NeurIPS 2017 paper.

    We need to adjust the input dimension to take into account our `data.x`
    which has the features, whereas `data.pos` is known to be 3D. Other than
    those changes, I haven't made changes compared to standard classification
    PointNet++.

    We need to construct `data.x, data.pos, data.batch` for the forward pass.
    For now I do this by constructing `Data` tuples when we have minibatches.
    PointNet++ may require smaller minibatch sizes compared to CNNs. On PyG,
    each PC for the ModelNet10 data is of dim (1024,3), and with a batch size
    of 32 it was taking ~5G of GPU RAM.

    05/19/2022: if we assume geodesic distance, we must normalize output.
    05/27/2022: adding special case of encoder type with classification but
        if we use pointwise loss. We use the per-point flow instead, which
        has the same minimum as the per-point MSE on the future tool.
    05/31/2022: ah, fixed bug: I was not re-scaling the PCL values back to
        the original value, as I was in segm PN++, we should do that for a
        fair comparison with: {class PN++ and pointwise loss}.
    06/02/2022: actually for pointwise baseline we really should remove any
        non-used rotations before we convert to axis-angle. That should help
        with training.
    """

    def __init__(self, in_dim, out_dim, encoder_type, scale_pcl_val=None,
            normalize_quat=False, n_epochs=None, rpmg_lambda=None, lambda_rot=None):
        super().__init__()
        self.in_dim = in_dim  # 3 for pos, then rest for segmentation
        self.out_dim = out_dim  # the action dimension
        self.encoder_type = encoder_type
        self.scale_pcl_val = scale_pcl_val
        self.normalize_quat = normalize_quat
        self.n_epochs = n_epochs
        self.rpmg_lambda = rpmg_lambda
        self.lambda_rot = lambda_rot
        self.raw_out = None
        self._mask = None

        # Input channels account for both `pos` and node features.
        self.sa1_module = SAModule(0.5, 0.2, MLP([self.in_dim, 64, 64, 128]))
        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))
        self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))

        self.mlp = MLP([1024, 512, 256, out_dim], dropout=0.5, batch_norm=False)

    def assign_clear_mask(self, mask):
        """
        For pouring or scooping, in case we have a 6D vector and need to extract
        a transformation from it here (and want to clear out unused components).
        """
        self._mask = mask

    @staticmethod
    def _rpmg_tau(epoch, n_epochs):
        """Return the RPMG step size tau for the given training epoch.

        Tau is scaled linearly in ten stages from tau_init=1/20 to
        tau_converge=1/4 (see section 4.3 and B.1 of the RPMG paper). Without
        an epoch the converged value 1/4 is used. Each stage lasts
        `n_epochs // 10` epochs, but never fewer than one, so runs shorter
        than ten epochs do not divide by zero.
        """
        if epoch is None:
            return 1 / 4
        stage_len = max(n_epochs // 10, 1)
        return 1 / 20 + (1 / 4 - 1 / 20) / 9 * min(epoch // stage_len, 9)

    @staticmethod
    def _tool_tip_positions(info):
        """Return the first three columns of `info` as a (batch, 3) tensor."""
        if len(info.shape) == 1:
            info = info.unsqueeze(0)
        return info[:, :3]

    @staticmethod
    def _iter_point_clouds(data):
        """Yield `(i, xyz)` for each point cloud in the batch, split by `data.ptr`."""
        for i in range(len(data.ptr) - 1):
            idx1 = data.ptr[i].item()
            idx2 = data.ptr[i + 1].item()
            yield i, data.pos[idx1:idx2]  # just this PCL's xyz (tool+items).

    @staticmethod
    def _rigid_flow(posi_one, tip_pos, mean_one, rot_matrices):
        """Return the per-point flow of one point cloud under a rigid transform.

        The rotation `rot_matrices` (1,3,3) is applied about `tip_pos` (3,)
        and followed by the translation `mean_one` (3,).
        """
        tip_pos_one = tip_pos[None, None, :]  # (1,1,3)

        # Correct rotation frame.
        t_correction = (tip_pos_one - torch.bmm(tip_pos_one, rot_matrices)).squeeze(1)  # (1,3)

        # Compute a transformation.
        trfm = Rotate(rot_matrices).translate(mean_one[None, :] + t_correction)

        # Applies on ALL points, but later in loss fxn, we filter by tool.
        trfm_xyz = trfm.transform_points(posi_one).squeeze(0)
        return trfm_xyz - posi_one

    def forward(self, data, info=None, epoch=None, rgt=None):
        """Standard forward pass, except we might need a custom `data`.

        Just return MLP output, no log softmax's as there's no classification,
        we are either doing regression or segmentation. This is for regression.
        If we are using geodesic distance, we assume last 4 parts form a quat,
        so we normalize the output.

        Depending on `encoder_type` the return value is the raw MLP output,
        a (batch, 3+9) translation + rotation matrix tensor, or (for the
        pointwise encoder types, which require `info`) the per-point flow
        concatenated over all point clouds in the batch.
        """

        # Special case if we scaled everything (e.g., meter -> millimeter), need
        # to downscale the observation because otherwise PN++ produces NaNs. :(
        if self.scale_pcl_val is not None:
            data.pos /= self.scale_pcl_val

        sa0_out = (data.x, data.pos, data.batch)
        sa1_out = self.sa1_module(*sa0_out)
        sa2_out = self.sa2_module(*sa1_out)
        sa3_out = self.sa3_module(*sa2_out)
        x, _, _ = sa3_out
        out = self.mlp(x)

        # Should revert `data.pos` back to the original scale.
        if self.scale_pcl_val is not None:
            data.pos *= self.scale_pcl_val

        if self.normalize_quat:
            # shapes in operation: (batch_size, 4) / (batch_size, 1)
            assert self.out_dim == 7, self.out_dim
            out_q = out[:,3:] / torch.norm(out[:,3:], dim=1, keepdim=True)
            out = torch.cat([out[:,:3], out_q], dim=1)

        if self.encoder_type == 'pointnet_classif_6D_pointwise':
            # In this special case, we actually want to return flow, which
            # we can compare with ground truth to determine pointwise loss.
            # During inference time we need a special case to just take `out`
            # computed earlier. Similar to 'pointnet_dense_tf_6D_pointwise'
            # but dense_transf is directly the _only_ output of the PN++.
            self.raw_out = out

            # Get rotation center (scooping or pouring) for translation correction.
            tool_tip_pos = self._tool_tip_positions(info)

            # The network doesn't predict flow, but we use pointwise matching loss.
            # Given (o(t), a(t)), we apply transform to get a predicted a(t).
            flows = []
            for i, posi_one in self._iter_point_clouds(data):
                # Compute a transformation from just `dense_transf`, no SVD.
                dense_transf = self.raw_out[i] * self._mask  # note the mask!
                mean_one = dense_transf[:3]
                rot_matrices = axis_angle_to_matrix(dense_transf[3:]).transpose(0, 1).unsqueeze(0)
                flows.append(self._rigid_flow(posi_one, tool_tip_pos[i], mean_one, rot_matrices))

            return torch.cat(flows)
        elif self.encoder_type == 'pointnet_rpmg_pointwise':
            # WARNING: DO NOT USE WITH INTRINSIC ROTATIONS!
            out_rot = simple_RPMG.apply(out[:, 3:], 50., self.rpmg_lambda, self.lambda_rot)
            out_rot_raw = out_rot.reshape(-1, 9)
            self.raw_out = torch.cat((out[:, :3], out_rot_raw), dim=1)

            # Get rotation center (scooping or pouring) for translation correction.
            tool_tip_pos = self._tool_tip_positions(info)

            # The network doesn't predict flow, but we use pointwise matching loss.
            # Given (o(t), a(t)), we apply transform to get a predicted a(t).
            flows = []
            for i, posi_one in self._iter_point_clouds(data):
                mean_one = out[i, :3]
                rot_matrices = out_rot[i].unsqueeze(0)
                flows.append(self._rigid_flow(posi_one, tool_tip_pos[i], mean_one, rot_matrices))

            self.flow_per_pt = torch.cat(flows)
            self.flow_per_pt_r = None

            return torch.cat(flows)
        elif self.encoder_type == 'pointnet_rpmg':
            # We do the default thing for tau as in the RPMG paper: see `_rpmg_tau`.
            tau = self._rpmg_tau(epoch, self.n_epochs)
            out_rot = simple_RPMG.apply(out[:, 3:], tau, self.rpmg_lambda, self.lambda_rot)
            out_rot = out_rot.reshape(-1, 9)

            out = torch.cat((out[:, :3], out_rot), dim=1)
            return out
        elif self.encoder_type == 'pointnet_rpmg_forward':
            # Brian/Chuer stand-alone method returns (batch,3,3). This is just the
            # forward pass and doesn't produce the proper gradient, use as a baseline only.
            out_rot = rpmg_forward(out[:, 3:])
            out_rot = out_rot.reshape(-1, 9)

            # Concatenate --> (batch, 3+9)
            out = torch.cat((out[:, :3], out_rot), dim=1)
            return out
        elif self.encoder_type == 'pointnet_rpmg_taugt':
            if rgt is not None:
                # Drop the 3 translation entries, keep the 3x3 rotation target.
                rgt = rgt.reshape(-1, 12)[:, 3:]
                rgt = rgt.reshape(-1, 3, 3)

            out_rot = simple_RPMG.apply(out[:, 3:], -1, self.rpmg_lambda, self.lambda_rot, rgt)
            out_rot = out_rot.reshape(-1, 9)

            out = torch.cat((out[:, :3], out_rot), dim=1)
            return out
        else:
            return out
241 |
--------------------------------------------------------------------------------
/chester/setup_ec2_for_chester.py:
--------------------------------------------------------------------------------
1 | import boto3
2 | import re
3 | import sys
4 | import json
5 | import botocore
6 | import os
7 | from rllab.misc import console
8 | from string import Template
9 | import os.path as osp
10 |
# Directory that contains this script; the key-pair folder and the generated
# config_ec2.py are both written beneath it.
CHESTER_DIR = osp.dirname(__file__)
# AWS credentials and the target bucket name are read from the environment at
# import time; a missing variable raises KeyError before any setup step runs.
ACCESS_KEY = os.environ["AWS_ACCESS_KEY"]
ACCESS_SECRET = os.environ["AWS_ACCESS_SECRET"]
S3_BUCKET_NAME = os.environ["RLLAB_S3_BUCKET"]

# Populated per region by setup_ec2() and serialized into the generated
# config file by write_config().
ALL_REGION_AWS_SECURITY_GROUP_IDS = {}
ALL_REGION_AWS_KEY_NAMES = {}
18 |
# Source text of the generated config_ec2.py. write_config() fills in the
# three `$`-placeholders ($s3_bucket_name, $all_region_aws_key_names,
# $all_region_aws_security_group_ids) via string.Template.substitute; every
# other line is emitted verbatim, so the body must itself be valid Python.
CONFIG_TEMPLATE = Template("""
import os.path as osp
import os

PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..'))

AWS_NETWORK_INTERFACES = []

MUJOCO_KEY_PATH = osp.expanduser("~/.mujoco")

USE_GPU = False

USE_TF = True

AWS_REGION_NAME = "us-east-2"

if USE_GPU:
    DOCKER_IMAGE = "dementrock/rllab3-shared-gpu"
else:
    DOCKER_IMAGE = "dementrock/rllab3-shared"

DOCKER_LOG_DIR = "/tmp/expt"

AWS_S3_PATH = "s3://$s3_bucket_name/rllab/experiments"

AWS_CODE_SYNC_S3_PATH = "s3://$s3_bucket_name/rllab/code"

ALL_REGION_AWS_IMAGE_IDS = {
    "ap-northeast-1": "ami-002f0167",
    "ap-northeast-2": "ami-590bd937",
    "ap-south-1": "ami-77314318",
    "ap-southeast-1": "ami-1610a975",
    "ap-southeast-2": "ami-9dd4ddfe",
    "eu-central-1": "ami-63af720c",
    "eu-west-1": "ami-41484f27",
    "sa-east-1": "ami-b7234edb",
    "us-east-1": "ami-83f26195",
    "us-east-2": "ami-66614603",
    "us-west-1": "ami-576f4b37",
    "us-west-2": "ami-b8b62bd8"
}

AWS_IMAGE_ID = ALL_REGION_AWS_IMAGE_IDS[AWS_REGION_NAME]

if USE_GPU:
    AWS_INSTANCE_TYPE = "g2.2xlarge"
else:
    AWS_INSTANCE_TYPE = "c4.4xlarge"

ALL_REGION_AWS_KEY_NAMES = $all_region_aws_key_names

AWS_KEY_NAME = ALL_REGION_AWS_KEY_NAMES[AWS_REGION_NAME]

AWS_SPOT = True

AWS_SPOT_PRICE = '0.5'

AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None)

AWS_ACCESS_SECRET = os.environ.get("AWS_ACCESS_SECRET", None)

AWS_IAM_INSTANCE_PROFILE_NAME = "rllab"

AWS_SECURITY_GROUPS = ["rllab-sg"]

ALL_REGION_AWS_SECURITY_GROUP_IDS = $all_region_aws_security_group_ids

AWS_SECURITY_GROUP_IDS = ALL_REGION_AWS_SECURITY_GROUP_IDS[AWS_REGION_NAME]

FAST_CODE_SYNC_IGNORES = [
    ".git",
    "data",
    "data/local",
    "data/archive",
    "data/debug",
    "data/s3",
    "data/video",
    "src",
    ".idea",
    ".pods",
    "tests",
    "examples",
    "docs",
    ".idea",
    ".DS_Store",
    ".ipynb_checkpoints",
    "blackbox",
    "blackbox.zip",
    "*.pyc",
    "*.ipynb",
    "scratch-notebooks",
    "conopt_root",
    "private/key_pairs",
]

FAST_CODE_SYNC = True

""")
117 |
118 |
def setup_iam():
    """Recreate the ``rllab`` IAM role and its instance profile.

    If a role named ``rllab`` already exists, the user is asked to confirm
    (the process exits when they decline); its instance profiles, inline
    policies and attached policies are then removed before the role itself is
    deleted. Afterwards a fresh role is created, two AWS-managed policies are
    attached, two inline EC2 policies are added, and an ``rllab`` instance
    profile is created and linked to the role.

    Raises:
        botocore.exceptions.ClientError: for any AWS error other than the
            "role does not exist yet" case during the cleanup phase.
    """
    credentials = dict(
        aws_access_key_id=ACCESS_KEY,
        aws_secret_access_key=ACCESS_SECRET,
    )
    iam_client = boto3.client("iam", **credentials)
    iam = boto3.resource('iam', **credentials)

    # Tear down any previous rllab role; `load()` raises NoSuchEntity when
    # there is nothing to clean up.
    try:
        stale_role = iam.Role('rllab')
        stale_role.load()
        confirmed = query_yes_no(
            "There is an existing role named rllab. Proceed to delete everything rllab-related and recreate?",
            default="no")
        if not confirmed:
            sys.exit()
        print("Listing instance profiles...")
        for profile in stale_role.instance_profiles.all():
            for member in profile.roles:
                print("Removing role %s from instance profile %s" % (member.name, profile.name))
                profile.remove_role(RoleName=member.name)
            print("Deleting instance profile %s" % profile.name)
            profile.delete()
        for inline_policy in stale_role.policies.all():
            print("Deleting inline policy %s" % inline_policy.name)
            inline_policy.delete()
        for managed_policy in stale_role.attached_policies.all():
            print("Detaching policy %s" % managed_policy.arn)
            stale_role.detach_policy(PolicyArn=managed_policy.arn)
        print("Deleting role")
        stale_role.delete()
    except botocore.exceptions.ClientError as e:
        # A missing role simply means there is nothing to delete.
        if e.response['Error']['Code'] != 'NoSuchEntity':
            raise e

    print("Creating role rllab")
    trust_policy = {'Version': '2012-10-17', 'Statement': [
        {'Action': 'sts:AssumeRole', 'Effect': 'Allow', 'Principal': {'Service': 'ec2.amazonaws.com'}}]}
    iam_client.create_role(
        Path='/',
        RoleName='rllab',
        AssumeRolePolicyDocument=json.dumps(trust_policy),
    )

    role = iam.Role('rllab')
    print("Attaching policies")
    for managed_arn in ('arn:aws:iam::aws:policy/AmazonS3FullAccess',
                        'arn:aws:iam::aws:policy/ResourceGroupsandTagEditorFullAccess'):
        role.attach_policy(PolicyArn=managed_arn)

    print("Creating inline policies")
    create_tags_document = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": ["ec2:CreateTags"],
                "Resource": ["*"]
            }
        ]
    }
    iam_client.put_role_policy(
        RoleName=role.name,
        PolicyName='CreateTags',
        PolicyDocument=json.dumps(create_tags_document),
    )
    terminate_document = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "Stmt1458019101000",
                "Effect": "Allow",
                "Action": [
                    "ec2:TerminateInstances"
                ],
                "Resource": [
                    "*"
                ]
            }
        ]
    }
    iam_client.put_role_policy(
        RoleName=role.name,
        PolicyName='TerminateInstances',
        PolicyDocument=json.dumps(terminate_document),
    )

    print("Creating instance profile rllab")
    iam_client.create_instance_profile(InstanceProfileName='rllab', Path='/')
    print("Adding role rllab to instance profile rllab")
    iam_client.add_role_to_instance_profile(InstanceProfileName='rllab', RoleName='rllab')
216 |
217 |
def setup_s3():
    """Create the private S3 bucket named by ``S3_BUCKET_NAME`` in us-east-2.

    A bucket that already belongs to the caller is tolerated; one owned by
    somebody else is reported as a ``ValueError``.

    Raises:
        ValueError: the bucket name is already taken (``BucketAlreadyExists``).
        botocore.exceptions.ClientError: any other S3 failure.
    """
    print("Creating S3 bucket at s3://%s" % S3_BUCKET_NAME)
    s3_client = boto3.client(
        "s3",
        aws_access_key_id=ACCESS_KEY,
        aws_secret_access_key=ACCESS_SECRET,
    )
    bucket_location = {'LocationConstraint': 'us-east-2'}
    try:
        s3_client.create_bucket(
            ACL='private',
            Bucket=S3_BUCKET_NAME,
            CreateBucketConfiguration=bucket_location,
        )
    except botocore.exceptions.ClientError as e:
        error_code = e.response['Error']['Code']
        if error_code == 'BucketAlreadyExists':
            raise ValueError("Bucket %s already exists. Please reconfigure S3_BUCKET_NAME" % S3_BUCKET_NAME) from e
        if error_code != 'BucketAlreadyOwnedByYou':
            raise e
        print("Bucket already created by you")
    print("S3 bucket created")
241 |
242 |
def setup_ec2():
    """Create the ``rllab-sg`` security group and an SSH key pair in each region.

    For every region in the hard-coded list this:

    1. picks the first VPC returned by EC2 and creates (or reuses) a security
       group named ``rllab-sg`` with inbound TCP/22 open to ``0.0.0.0/0``;
    2. creates a key pair named ``rllab-<region>`` (after confirmation, an
       existing pair of that name is deleted and recreated);
    3. writes the private key to ``<CHESTER_DIR>/private/key_pairs/<name>.pem``
       with mode ``0o600`` and registers it with ``ssh-add``.

    The resulting ids/names are recorded in the module-level
    ``ALL_REGION_AWS_SECURITY_GROUP_IDS`` and ``ALL_REGION_AWS_KEY_NAMES``.

    Raises:
        botocore.exceptions.ClientError: on any unexpected AWS error.
        AssertionError: if a region has no VPC.
    """
    # Local import: the module-level import block does not provide subprocess.
    import subprocess

    for region in ["us-east-1", "us-east-2", "us-west-1", "us-west-2"]:
        print("Setting up region %s" % region)

        ec2 = boto3.resource(
            "ec2",
            region_name=region,
            aws_access_key_id=ACCESS_KEY,
            aws_secret_access_key=ACCESS_SECRET,
        )
        ec2_client = boto3.client(
            "ec2",
            region_name=region,
            aws_access_key_id=ACCESS_KEY,
            aws_secret_access_key=ACCESS_SECRET,
        )
        existing_vpcs = list(ec2.vpcs.all())
        assert len(existing_vpcs) >= 1
        vpc = existing_vpcs[0]
        print("Creating security group in VPC %s" % str(vpc.id))
        try:
            security_group = vpc.create_security_group(
                GroupName='rllab-sg', Description='Security group for rllab'
            )
        except botocore.exceptions.ClientError as e:
            if e.response['Error']['Code'] == 'InvalidGroup.Duplicate':
                # Reuse the group left behind by a previous run.
                sgs = list(vpc.security_groups.filter(GroupNames=['rllab-sg']))
                security_group = sgs[0]
            else:
                raise

        ALL_REGION_AWS_SECURITY_GROUP_IDS[region] = [security_group.id]

        ec2_client.create_tags(Resources=[security_group.id], Tags=[{'Key': 'Name', 'Value': 'rllab-sg'}])
        try:
            security_group.authorize_ingress(FromPort=22, ToPort=22, IpProtocol='tcp', CidrIp='0.0.0.0/0')
        except botocore.exceptions.ClientError as e:
            # The SSH rule already being present is fine.
            if e.response['Error']['Code'] != 'InvalidPermission.Duplicate':
                raise
        print("Security group created with id %s" % str(security_group.id))

        key_name = 'rllab-%s' % region
        try:
            print("Trying to create key pair with name %s" % key_name)
            key_pair = ec2_client.create_key_pair(KeyName=key_name)
        except botocore.exceptions.ClientError as e:
            if e.response['Error']['Code'] == 'InvalidKeyPair.Duplicate':
                if not query_yes_no("Key pair with name %s exists. Proceed to delete and recreate?" % key_name, "no"):
                    sys.exit()
                print("Deleting existing key pair with name %s" % key_name)
                ec2_client.delete_key_pair(KeyName=key_name)
                print("Recreating key pair with name %s" % key_name)
                key_pair = ec2_client.create_key_pair(KeyName=key_name)
            else:
                raise

        key_pair_folder_path = os.path.join(CHESTER_DIR, "private", "key_pairs")
        file_name = os.path.join(key_pair_folder_path, "%s.pem" % key_name)

        print("Saving keypair file")
        console.mkdir_p(key_pair_folder_path)
        # Drop any stale .pem first: the `mode` passed to os.open only applies
        # when the file is newly created, and a read-only leftover could not
        # be opened for writing at all.
        try:
            os.remove(file_name)
        except FileNotFoundError:
            pass
        # O_TRUNC guards against trailing bytes of an older, longer key
        # surviving if the file reappears between the remove and the open.
        pem_fd = os.open(file_name, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(pem_fd, 'w') as handle:
            handle.write(key_pair['KeyMaterial'] + '\n')

        # Register the key with the SSH agent. An argument list (no shell)
        # keeps paths with spaces or shell metacharacters intact. As before,
        # a failing or missing ssh-add does not abort the setup.
        try:
            subprocess.call(["ssh-add", file_name])
        except OSError as err:
            print("Could not run ssh-add for %s: %s" % (file_name, err))

        ALL_REGION_AWS_KEY_NAMES[region] = key_name
313 |
314 |
def write_config():
    """Render ``CONFIG_TEMPLATE`` and save it as ``config_ec2.py`` in ``CHESTER_DIR``.

    The per-region key names and security-group ids collected so far are
    embedded as JSON. If the target file already exists the user must confirm
    the overwrite; declining exits the process.
    """
    print("Writing config file...")
    substitutions = {
        "all_region_aws_key_names": json.dumps(ALL_REGION_AWS_KEY_NAMES, indent=4),
        "all_region_aws_security_group_ids": json.dumps(ALL_REGION_AWS_SECURITY_GROUP_IDS, indent=4),
        "s3_bucket_name": S3_BUCKET_NAME,
    }
    content = CONFIG_TEMPLATE.substitute(**substitutions)
    target_path = os.path.join(CHESTER_DIR, "config_ec2.py")
    if os.path.exists(target_path) and not query_yes_no("config_ec2.py exists. Override?", "no"):
        sys.exit()
    # Written as bytes so the file is UTF-8 with untouched newlines.
    with open(target_path, "wb") as f:
        f.write(content.encode("utf-8"))
328 |
329 |
def setup():
    """Run every provisioning step in order: S3, IAM, EC2, then the config file."""
    for step in (setup_s3, setup_iam, setup_ec2, write_config):
        step()
335 |
336 |
def query_yes_no(question, default="yes"):
    """Ask a yes/no question on stdout and return the answer read via input().

    Args:
        question: Text presented to the user (the prompt hint is appended).
        default: The presumed answer if the user just hits <Enter>. It must
            be "yes" (the default), "no", or None (meaning an explicit
            answer is required).

    Returns:
        True for "yes"/"ye"/"y", False for "no"/"n". Matching ignores case
        and surrounding whitespace.

    Raises:
        ValueError: if ``default`` is not "yes", "no" or None.
    """
    valid = {"yes": True, "y": True, "ye": True,
             "no": False, "n": False}
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    # Keep asking until we get something we understand.
    while True:
        sys.stdout.write(question + prompt)
        # Strip so that stray spaces (e.g. "y ") are not rejected and a
        # whitespace-only line counts as "just hit <Enter>".
        choice = input().strip().lower()
        if default is not None and choice == '':
            return valid[default]
        elif choice in valid:
            return valid[choice]
        else:
            sys.stdout.write("Please respond with 'yes' or 'no' "
                             "(or 'y' or 'n').\n")
368 |
369 |
# Run the full AWS bootstrap only when executed as a script, not on import.
if __name__ == "__main__":
    setup()
--------------------------------------------------------------------------------
/bc/rotations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from rpmg import tools
4 | from pyquaternion import Quaternion
5 | from scipy.spatial.transform import Rotation as Rot
6 | np.set_printoptions(precision=8)
7 | #from pytorch3d.transforms import quaternion_to_matrix
8 |
9 | # === ROTATION REPRESENTATION CONVERTERS ===
10 |
11 | # class Example:
12 | # def convert_to(self, quaternion, tool_rotation):
13 | # """
14 | # Parameters: quaternion - a PyQuaternion representing EXTRINSIC rotation
15 | # tool_rotation - a PyQuaternion representing the current tool rotation
16 |
17 | # Returns: rotation in a new format, in shape (K,), where K is the number
18 | # of dimensions of the rotation
19 | # """
20 | # pass
21 |
22 | # def convert_from(self, example, tool_rotation):
23 | # """
24 | # Parameters: rotation in the format, in shape (K,)
25 | # tool_rotation - a PyQuaternion representing the current tool rotation
26 |
27 | # Returns: a PyQuaternion representing that rotation
28 | # """
29 | # pass
30 |
class ExtrinsicToIntrinsic:
    """Converter that conjugates a rotation by the current tool rotation.

    Going by the class name, ``convert_to`` re-expresses an extrinsic
    rotation in the tool's own frame and ``convert_from`` undoes that; the
    two are exact inverses of each other for a fixed ``tool_rotation``.
    """

    def convert_to(self, quaternion, tool_rotation):
        """Return ``tool_rotation.inverse * quaternion * tool_rotation``."""
        tool_inverse = tool_rotation.inverse
        return tool_inverse * quaternion * tool_rotation

    def convert_from(self, rotation, tool_rotation):
        """Return ``tool_rotation * rotation * tool_rotation.inverse``."""
        tool_inverse = tool_rotation.inverse
        return tool_rotation * rotation * tool_inverse
37 |
class Compose:
    """Chain several converters into one.

    ``convert_to`` applies the wrapped converters first-to-last;
    ``convert_from`` applies their inverses last-to-first, so the pair
    round-trips whenever each individual converter does.
    """

    def __init__(self, *args):
        # Stored as the tuple of converters, in application order.
        self.ops = args

    def convert_to(self, quaternion, tool_rotation):
        """Feed ``quaternion`` through every converter's ``convert_to`` in order."""
        current = quaternion
        for stage in self.ops:
            current = stage.convert_to(current, tool_rotation)
        return current

    def convert_from(self, rotation, tool_rotation):
        """Feed ``rotation`` through every converter's ``convert_from`` in reverse order."""
        current = rotation
        for stage in reversed(self.ops):
            current = stage.convert_from(current, tool_rotation)
        return current
53 |
class RotationMatrix:
    """Converter between a quaternion and its flattened 3x3 rotation matrix."""

    # NOTE(daniel): the conversion itself does not depend on `rot_dim`; it is
    # only used to loosen rtol/atol for 9D rotations (reason unclear).

    def __init__(self, rot_dim):
        self.rot_dim = rot_dim
        # Tolerances forwarded to the Quaternion(matrix=...) constructor.
        tolerance = 1e-2 if rot_dim == 9 else 1e-5
        self.rtol = tolerance
        self.atol = tolerance

    def convert_to(self, quaternion, tool_rotation):
        """Return the quaternion's rotation matrix flattened to shape (9,)."""
        return quaternion.rotation_matrix.flatten()

    def convert_from(self, rotation, tool_rotation):
        """Build a quaternion from 9 values interpreted as a 3x3 matrix."""
        # Needs pyquaternion 0.9.9 for the rtol and atol arguments.
        as_matrix = rotation.reshape((3, 3))
        return Quaternion(matrix=as_matrix, rtol=self.rtol, atol=self.atol)
75 |
class AxisAngle:
    """Converter between a quaternion and an axis-angle 3-vector (axis * angle)."""

    def convert_to(self, quaternion, tool_rotation):
        """Return the rotation axis scaled by the rotation angle in radians."""
        # +z is the stand-in axis when the quaternion's axis is undefined.
        fallback_axis = np.array([0., 0., 1.])
        unit_axis = quaternion.get_axis(undefined=fallback_axis)
        return unit_axis * quaternion.radians

    def convert_from(self, axis_angle, tool_rotation):
        """Return the quaternion whose axis-angle vector is ``axis_angle``."""
        angle = np.linalg.norm(axis_angle)
        if angle == 0:
            # Zero-length vector: identity rotation.
            return Quaternion()
        return Quaternion(axis=axis_angle / angle, angle=angle)
91 |
class FlowConverter(AxisAngle):
    """Axis-angle converter whose forward direction is a pass-through.

    Flow needs to process the quaternion further in the replay buffer, so
    ``convert_to`` hands it back untouched; ``convert_from`` is inherited.
    """

    def convert_to(self, quaternion, tool_rotation):
        return quaternion
96 |
class RPMGFlowConverter(RotationMatrix):
    """Rotation-matrix converter whose forward direction is a pass-through.

    RPMG outputs a rotation matrix rather than axis-angle, so only the
    inherited ``convert_from`` does real work; ``convert_to`` returns the
    quaternion unchanged.
    """

    def convert_to(self, quaternion, tool_rotation):
        return quaternion
101 |
class NoRPMG:
    """Converter between a flattened 3x3 matrix and a 4/6/9/10-D rotation encoding.

    The encodings for 4 and 10 dimensions, and every decoding, are delegated
    to helpers in ``rpmg.tools``.
    """

    def __init__(self, rot_dim):
        self.rot_dim = rot_dim

    def convert_to(self, rot_matrix, tool_rotation):
        """Encode a NumPy rotation matrix as a flat NumPy vector of ``rot_dim`` values.

        Raises:
            NotImplementedError: if ``rot_dim`` is not one of 4, 6, 9, 10.
        """
        # Work on a batch of one: shape (1, 3, 3).
        batch = torch.from_numpy(rot_matrix.reshape((3, 3))).unsqueeze(0)
        dim = self.rot_dim

        if dim == 6:
            # First two matrix columns, concatenated.
            encoded = torch.cat([batch[:, :, 0], batch[:, :, 1]], dim=1)
        elif dim == 9:
            encoded = batch.reshape(-1, 9)
        elif dim == 4:
            encoded = tools.compute_quaternions_from_rotation_matrices(batch)
        elif dim == 10:
            quat = tools.compute_quaternions_from_rotation_matrices(batch)
            identity = torch.eye(4, device=quat.device)[None].repeat(quat.shape[0], 1, 1)
            outer = torch.bmm(quat.unsqueeze(-1), quat.unsqueeze(-2))
            # I - q q^T, then vectorized by the rpmg helper.
            encoded = tools.convert_A_to_Avec(identity - outer)
        else:
            raise NotImplementedError

        return encoded.flatten().numpy()

    def convert_from(self, rotation, tool_rotation):
        """Decode a ``rot_dim``-valued NumPy vector and return the result flattened.

        Raises:
            NotImplementedError: if ``rot_dim`` is not one of 4, 6, 9, 10.
        """
        batch = torch.from_numpy(rotation).unsqueeze(0)
        dim = self.rot_dim
        if dim == 6:
            decoded = tools.compute_rotation_matrix_from_ortho6d(batch)
        elif dim == 9:
            decoded = tools.symmetric_orthogonalization(batch)
        elif dim == 4:
            decoded = tools.compute_rotation_matrix_from_quaternion(batch)
        elif dim == 10:
            decoded = tools.compute_rotation_matrix_from_10d(batch)
        else:
            raise NotImplementedError
        return decoded.flatten().numpy()
139 |
140 | # === CONVERTER REGISTRATION ===
141 |
# Registry mapping a rotation-representation name to its converter instance.
# convert_action() / deconvert_action() look converters up here by name and
# reject names that are not present.
CONVERTERS = {
    "flow": FlowConverter(),
    "axis_angle": AxisAngle(),
    "intrinsic_axis_angle": Compose(
        ExtrinsicToIntrinsic(),
        AxisAngle(),
    ),
    # For the plain rotation_* entries rot_dim only selects the tolerances
    # used when rebuilding a quaternion from a matrix.
    "rotation_4D": RotationMatrix(rot_dim=4),
    "rotation_6D": RotationMatrix(rot_dim=6),
    "rotation_9D": RotationMatrix(rot_dim=9),
    "rotation_10D": RotationMatrix(rot_dim=10),
    "intrinsic_rotation_6D": Compose(
        ExtrinsicToIntrinsic(),
        RotationMatrix(rot_dim=6),
    ),
    "rpmg_flow_6D": RPMGFlowConverter(rot_dim=6),
    # no_rpmg_*: quaternion -> flattened matrix -> rot_dim-valued encoding.
    "no_rpmg_6D": Compose(
        RotationMatrix(rot_dim=6),
        NoRPMG(rot_dim=6),
    ),
    "no_rpmg_9D": Compose(
        RotationMatrix(rot_dim=9),
        NoRPMG(rot_dim=9),
    ),
    "no_rpmg_10D": Compose(
        RotationMatrix(rot_dim=10),
        NoRPMG(rot_dim=10),
    ),
}
171 |
172 | # === ENV SPECIFIC CANONICALIZATION ===
173 |
174 | def _canonicalize_action(env_name, obs_tuple, act_raw, qt_current):
175 | """
176 | Gets global delta action (in PyQuaternion form) from raw action and
177 | observation, depending on environment
178 | This is especially important as PourWater6D has act_raw as delta euler
179 | angle, while MixedMedia uses *local* axis-angle. Both need to be
180 | converted to global axis-angle.
181 | Note [qt_current] is only used for MixedMedia
182 |
183 | Returns [tool_rotation] the extrinsic rotation of the tool currently
184 | """
185 | act_tran = act_raw[:3]
186 | axis = act_raw[3:]
187 | dtheta = np.linalg.norm(axis)
188 |
189 | # Unfortunately there are some 'hacks' here. If PourWater, we negate
190 | # the 3rd value because the axis in SoftGym is actually (0,0,-1) when
191 | # handling rotations. A positive value means dropping water, and that
192 | # is clockwise w.r.t. +z, but if we keep it here that would negate it
193 | # for PyQuaternion. OK to do here as the purpose is to get accurate
194 | # flow estimates, but check if we can simplify the env as well.
195 | if env_name == 'PourWater':
196 | tool_origin = obs_tuple[0][0,:3] # shape (10,14)
197 | if dtheta == 0:
198 | axis = np.array([0., 0., -1.])
199 | else:
200 | # FYI, I tried visualizing with and without this, and we do
201 | # need to negate to see proper flow vectors.
202 | axis[2] = -axis[2]
203 |
204 | axis = axis / np.linalg.norm(axis)
205 | delta_quat = Quaternion(axis=axis, angle=dtheta)
206 | tool_raw = obs_tuple[0][0, 6:10]
207 | tool_rotation = Quaternion(w=tool_raw[3], x=tool_raw[0], y=tool_raw[1], z=tool_raw[2])
208 | elif env_name == 'PourWater6D':
209 | tool_origin = obs_tuple[0][0, :3] # shape (10, 14)
210 |
211 | # Get current global rotation in intrinsic ZYX euler angles due to the way
212 | # we handle rotation
213 | curr_rot = Rot.from_quat(obs_tuple[0][0, 6:10])
214 | curr_euler = curr_rot.as_euler('zyx')
215 | curr_z, curr_y, curr_x = -curr_euler[0], curr_euler[1], curr_euler[2]
216 |
217 | # Compute new tool orientation
218 | theta_x = curr_x + act_raw[3]
219 | theta_y = curr_y + act_raw[4]
220 | theta_z = curr_z + act_raw[5]
221 | axis_ang_z = np.array([0., 0., -1.])
222 | axis_ang_y = np.array([0., 1., 0.])
223 | axis_ang_x = np.array([1, 0., 0.])
224 | axis_angle_z = axis_ang_z * theta_z
225 | axis_angle_y = axis_ang_y * theta_y
226 | axis_angle_x = axis_ang_x * theta_x
227 | final_rot = Rot.from_rotvec(axis_angle_x) * Rot.from_rotvec(axis_angle_y) * Rot.from_rotvec(axis_angle_z)
228 |
229 | # Compute rotation difference
230 | delta_rot = final_rot * curr_rot.inv()
231 | delta_raw = delta_rot.as_quat()
232 | delta_quat = Quaternion(w=delta_raw[3], x=delta_raw[0], y=delta_raw[1], z=delta_raw[2])
233 |
234 | # Compute current tool rotation for later if we want to convert to intrinsic rotations
235 | tool_raw = obs_tuple[0][0, 6:10]
236 | tool_rotation = Quaternion(w=tool_raw[3], x=tool_raw[0], y=tool_raw[1], z=tool_raw[2])
237 | elif env_name in ['MMOneSphere', 'MMMultiSphere']:
238 | tool_origin = obs_tuple[0][:3] # tool tip position
239 | if dtheta == 0:
240 | axis = np.array([0., -1., 0.])
241 |
242 | axis = axis / np.linalg.norm(axis)
243 |
244 | # Change axis from local frame to world to work with 6DoF
245 | axis_world = qt_current.rotate(axis)
246 | delta_quat = Quaternion(axis=axis_world, angle=dtheta)
247 |
248 | tool_rotation = qt_current
249 |
250 | return act_tran, tool_origin, delta_quat, tool_rotation
251 |
252 | def _decanonicalize_action(env_name, delta_quat, env):
253 | dtheta = delta_quat.radians
254 |
255 | if (env_name in ['MMOneSphere', 'MMMultiSphere']):
256 | axis = delta_quat.get_axis(undefined=np.array([0., 0., 1.]))
257 | if dtheta != 0:
258 | # Get global rotation axis from env
259 | curr_q = env.tool_state[0, 6:10]
260 | qt_current = Quaternion(w=curr_q[3], x=curr_q[0], y=curr_q[1], z=curr_q[2])
261 | inv_qt_current = qt_current.inverse
262 | local_axis = inv_qt_current.rotate(axis) * dtheta
263 | return local_axis
264 | else:
265 | return np.zeros(3)
266 | elif (env_name == 'PourWater'):
267 | # Now positive rotation is w.r.t., the negative z axis, not the
268 | # positive z axis. :( Also only do this if we used pytorch geometric.
269 | # I think this condition should be sufficient but DOUBLE CHECK.
270 | # UPDATE(06/02/2022), ah also needs to be done with pointwise losses!
271 | # UPDATE(06/03/2022), if SVD but with MSE after it, don't do this.
272 | axis = delta_quat.get_axis(undefined=np.array([0., 0., 1.]))
273 | action = axis * dtheta
274 | action[2] = -action[2]
275 | return action
276 | elif (env_name in ['PourWater6D']):
277 | # Sadly, we need to convert from delta axis angle to delta euler angle. oof
278 | if dtheta != 0:
279 | # Convert from delta axis angle -> delta scipy rotation
280 | delta_quat._normalise() # knock on wood
281 | delta_items = delta_quat.elements
282 | delta_raw = np.array([delta_items[1], delta_items[2], delta_items[3], delta_items[0]])
283 | delta_rot = Rot.from_quat(delta_raw)
284 |
285 | # Get current env scipy rotation
286 | curr_rot = Rot.from_quat(env.glass_states[0, 6:10])
287 |
288 | # Perform rotation
289 | final_rot = delta_rot * curr_rot
290 |
291 | # Convert from final scipy rotation -> final euler angles
292 | final_euler = final_rot.as_euler('zyx')
293 | # This flipping of [-final_euler[0]] is where we handle the
294 | # flipped z-axis in PourWater
295 | final_z, final_y, final_x = -final_euler[0], final_euler[1], final_euler[2]
296 |
297 | # Convert from final -> delta euler angles
298 | dtheta_x = final_x - env.glass_rotation_x
299 | dtheta_y = final_y - env.glass_rotation_y
300 | dtheta_z = final_z - env.glass_rotation_z
301 |
302 | return np.array([dtheta_x, dtheta_y, dtheta_z])
303 | else:
304 | return np.zeros(3)
305 |
306 | def _env_to_tool_rotation(env_name, env):
307 | if (env_name in ['MMOneSphere', 'MMMultiSphere']):
308 | curr_q = env.tool_state[0, 6:10]
309 | return Quaternion(w=curr_q[3], x=curr_q[0], y=curr_q[1], z=curr_q[2])
310 | elif env_name in ['PourWater', 'PourWater6D']:
311 | curr_q = env.glass_states[0, 6:10]
312 | return Quaternion(w=curr_q[3], x=curr_q[0], y=curr_q[1], z=curr_q[2])
313 |
314 | # === DRIVER CODE ===
315 |
def convert_action(rotation_representation, env_name, obs_tuple, act_raw, qt_current):
    """Convert a raw env action into the requested rotation representation.

    The raw action is first canonicalized to a global delta quaternion, which
    is then passed through the converter registered under
    ``rotation_representation``.

    Returns:
        ``(act_tran, tool_origin, converted_rotation)``.
    """
    assert rotation_representation in CONVERTERS, f"Invalid rotation representation {rotation_representation}"
    canonical = _canonicalize_action(env_name, obs_tuple, act_raw, qt_current)
    act_tran, tool_origin, delta_quat, tool_rotation = canonical
    converter = CONVERTERS[rotation_representation]
    return act_tran, tool_origin, converter.convert_to(delta_quat, tool_rotation)
321 |
def deconvert_action(rotation_representation, env_name, action, env):
    """Map an action in ``rotation_representation`` back to the env's raw action.

    ``action[:3]`` is kept as the translation; ``action[3:]`` is decoded to a
    delta quaternion by the registered converter and then translated into the
    env-specific rotation parameters.

    Returns:
        The concatenation of the translation and the env-specific rotation.
    """
    assert rotation_representation in CONVERTERS, f"Invalid rotation representation {rotation_representation}"
    tool_rotation = _env_to_tool_rotation(env_name, env)
    converter = CONVERTERS[rotation_representation]
    delta_quat = converter.convert_from(action[3:], tool_rotation)
    rotation_part = _decanonicalize_action(env_name, delta_quat, env)
    return np.concatenate((action[:3], rotation_part))
328 |
--------------------------------------------------------------------------------
/bc/se3.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pickle
4 | from pytorch3d.transforms import (
5 | Transform3d, Rotate, rotation_6d_to_matrix, axis_angle_to_matrix,
6 | matrix_to_quaternion, matrix_to_euler_angles, quaternion_to_axis_angle
7 | )
8 | import torch
9 | import numpy as np
10 | DEG_TO_RAD = np.pi / 180.
11 | RAD_TO_DEG = 180 / np.pi
12 |
13 |
def to_transform3d(x, rot_function = rotation_6d_to_matrix):
    """Build a pytorch3d transform (rotate, then translate) from packed parameters.

    ``x[:, :3]`` is used as the translation and ``x[:, 3:]`` is turned into
    rotation matrices by ``rot_function``.
    """
    translation = x[:, :3]
    rot_params = x[:, 3:]
    base = Transform3d(device=x.device)
    rotate = Rotate(rot_function(rot_params), device=rot_params.device)
    return base.compose(rotate).translate(translation)
19 |
20 |
def transform_points(points, x, rot_function = rotation_6d_to_matrix):
    """Apply ``R @ p + t`` to every point, with R and t unpacked from ``x``.

    ``x[:, :3]`` is the translation and ``x[:, 3:]`` is converted to rotation
    matrices by ``rot_function``. ``points`` is batched with the coordinate
    axis last; the result has the same layout.
    """
    # NOTE(review): this multiplies with R on the left of column vectors;
    # confirm it is meant to agree with to_transform3d, which goes through
    # pytorch3d's Rotate.
    translation = x[:, :3]
    rotation = rot_function(x[:, 3:])
    rotated = torch.bmm(rotation, points.transpose(-2, -1))
    shifted = rotated + translation.unsqueeze(-1)
    return shifted.transpose(-2, -1)
25 |
26 |
def transform3d_to(T, device):
    """Move a transform and every entry of its ``_transforms`` list to ``device``.

    Args:
        T: Object exposing ``.to(device)`` and a ``_transforms`` list whose
            items also expose ``.to(device)`` (presumably a pytorch3d
            ``Transform3d`` -- confirm against callers).
        device: Target device, forwarded unchanged to each ``.to`` call.

    Returns:
        The object returned by ``T.to(device)``, with ``_transforms`` rebuilt
        from the moved entries.
    """
    # One call suffices; the original repeated this statement verbatim.
    T = T.to(device)
    # The composed sub-transforms are moved explicitly as well.
    T._transforms = [
        t.to(device) for t in T._transforms
    ]
    return T
34 |
35 |
def random_se3(N, rot_var = np.pi/180 * 5, trans_var = 0.1, device = None):
    """Sample ``N`` random rigid transforms as a pytorch3d rotate-then-translate.

    Rotations come from standard-normal axis-angle vectors scaled by
    ``rot_var`` (default: 5 degrees in radians); translations are
    standard-normal vectors scaled by ``trans_var``.
    """
    axis_angles = torch.randn(N, 3, device=device) * rot_var
    rotations = axis_angle_to_matrix(axis_angles)
    translations = torch.randn(N, 3, device=device) * trans_var
    return Rotate(rotations, device=device).translate(translations)
40 |
41 |
def symmetric_orthogonalization(M):
    """Project arbitrary 3x3 matrices onto SO(3) using the SVD.

    (modified from https://github.com/amakadia/svd_for_pose)

    M: tensor of size [batch_size, 3, 3].

    Returns a tensor of size [batch_size, 3, 3] whose 3x3 matrices are
    rotations: ``U @ diag(1, 1, det(U @ Vh)) @ Vh``.
    """
    left, _, right_h = torch.linalg.svd(M)
    # Sign of the plain orthogonal factor; -1 means it is a reflection.
    sign = torch.det(torch.bmm(left, right_h)).view(-1, 1, 1)
    # Scale only the last row of Vh so the product has determinant +1.
    leading_rows = right_h[:, :2, :]
    last_row = right_h[:, -1:, :] * sign
    corrected = torch.cat((leading_rows, last_row), 1)
    return left @ corrected
55 |
56 |
def flow2pose(xyz, flow, weights=None, return_transform3d=False,
              return_quaternions=False, world_frameify=True):
    """Flow2Pose via SVD.

    Operates on minibatches of `B` point clouds, each with `N` points. Assumes
    all point clouds have `N` points, but in practice we only call this with
    minibatch size 1 and we get rid of non-tool points before calling this.

    Outputs a rotation with the origin at the tool centroid. This is not the
    origin of the frame where actions are expressed, which is at the tool tip.

    Parameters
    ----------
    xyz: point clouds of shape (B,N,3). This gets zero-centered so it's OK if it
        is not already centered.
    flow: corresponding flow of shape (B,N,3). As with xyz, it gets zero-centered.
    weights: weights for the N points, shape (B,N); None means uniform
        weighting. For now we probably do not want to weigh certain points
        more than others, and it could be tricky since points can be in any
        order in a PCL.
    return_transform3d: if True, return a pytorch3d transform (rotate, then
        translate) instead of the raw tensors. This is what Brian/Chuer use to
        compute losses, by applying it on the original points and comparing
        point-wise MSEs.
    return_quaternions: if True (and return_transform3d is False), return
        `(quats, t)` with the rotation converted to quaternions in (wxyz)
        format, so the identity quaternion is (1,0,0,0).
    world_frameify: if True, correct the translation vector so that the
        transformation is expressed w.r.t. the world frame; otherwise the
        translation is just the mean flow.

    Returns
    -------
    A pytorch3d transform, `(quats, t)`, or `(R, t)` depending on the flags,
    with R of shape (B,3,3) and t of shape (B,3).
    """
    if weights is None:
        weights = torch.ones(xyz.shape[:-1], device=xyz.device)
    # Normalized per-point weights, shape (B,N,1).
    norm_w = (weights / weights.sum(dim=-1, keepdims=True)).unsqueeze(-1)

    # Weighted centroid (B,1,3) and the centered cloud (B,N,3).
    centroid = (norm_w * xyz).sum(dim=1, keepdims=True)
    centered = xyz - centroid

    # Weighted mean flow, shape (B,1,3).
    mean_flow = (norm_w * flow).sum(dim=1, keepdims=True)

    # Centered positions after applying the centered flow, shape (B,N,3).
    displaced = centered + flow - mean_flow

    # Weighted cross-covariance between the two centered clouds, (B,3,3).
    # This is a general matrix; the projection below makes it a rotation.
    cross_cov = torch.bmm(centered.transpose(-2, -1),  # (B,3,N)
                          norm_w * displaced)          # (B,N,3)

    # Rotation matrix in SO(3) for each minibatch item, (B,3,3).
    R = symmetric_orthogonalization(cross_cov)

    # 3D translation vector for each minibatch item, (B,3) after squeezing.
    if world_frameify:
        t = (mean_flow + centroid - torch.bmm(centroid, R)).squeeze(1)
    else:
        t = mean_flow.squeeze(1)

    if return_transform3d:
        return Rotate(R).translate(t)
    if return_quaternions:
        return matrix_to_quaternion(matrix=R), t
    return R, t
117 |
# Lower bound used by dualflow2pose when dividing by summed weights, so an
# all-zero weight group does not cause a division by zero.
eps = 1e-9
119 |
120 |
def dualflow2pose(xyz, flow, polarity, weights = None, return_transform3d = False):
    """Estimate one rigid transform from flow defined on two groups of points.

    Points with ``polarity == 1`` contribute the pair (xyz, xyz + flow);
    the remaining points contribute the pair (xyz + flow, xyz). Each group is
    centered with its own weighted means before the SVD step.

    Args:
        xyz, flow: batched points and per-point flow, coordinate axis last
            (presumably (B,N,3), as in flow2pose -- confirm).
        polarity: per-point group indicator multiplied against ``weights``
            (presumably 0/1 valued with shape (B,N) -- confirm).
        weights: per-point weights; None means all ones.
        return_transform3d: if True, return a pytorch3d rotate-then-translate
            transform instead of the raw tensors.

    Returns:
        ``(R, t)``, or the pytorch3d transform when requested.
    """
    if weights is None:
        weights = torch.ones(xyz.shape[:-1], device=xyz.device)

    def weighted_mean(w, values):
        # Weighted sum over the point axis, keeping that axis for broadcasting.
        return (w * values).sum(dim=1, keepdims=True)

    # Overall normalized weights, used for the cross-covariance.
    w_all = (weights / weights.sum(dim=-1, keepdims=True)).unsqueeze(-1)

    # Per-group weights, each normalized by its own (clamped) total.
    w_pos = (polarity * weights).unsqueeze(-1)
    w_pos_total = w_pos.sum(dim=1, keepdims=True)
    w_pos = w_pos / w_pos_total.clamp(min=eps)
    w_neg = ((1-polarity) * weights).unsqueeze(-1)
    w_neg_total = w_neg.sum(dim=1, keepdims=True)
    w_neg = w_neg / w_neg_total.clamp(min=eps)

    xyz_mean_pos = weighted_mean(w_pos, xyz)
    xyz_centered_pos = xyz - xyz_mean_pos
    xyz_mean_neg = weighted_mean(w_neg, xyz)
    xyz_centered_neg = xyz - xyz_mean_neg

    flow_mean_pos = weighted_mean(w_pos, flow)
    flow_centered_pos = flow - flow_mean_pos
    flow_mean_neg = weighted_mean(w_neg, flow)
    flow_centered_neg = flow - flow_mean_neg

    # Select, per point, which group's source/target coordinates to use.
    is_pos = (polarity.unsqueeze(-1).expand(-1,-1,3)==1)
    source = torch.where(is_pos,
                         xyz_centered_pos, xyz_centered_neg + flow_centered_neg)
    target = torch.where(is_pos,
                         xyz_centered_pos + flow_centered_pos, xyz_centered_neg)

    cross_cov = torch.bmm(source.transpose(-2,-1), w_all*target)

    R = symmetric_orthogonalization(cross_cov)
    t_pos = (flow_mean_pos + xyz_mean_pos - torch.bmm(xyz_mean_pos, R))
    t_neg = (xyz_mean_neg - torch.bmm(flow_mean_neg + xyz_mean_neg, R))

    # Blend the two translation estimates by each group's total weight.
    t = ((w_pos_total * t_pos + w_neg_total * t_neg)/(w_pos_total + w_neg_total)).squeeze(1)

    if return_transform3d:
        return Rotate(R).translate(t)
    return R, t
160 |
161 |
def points2pose(xyz1, xyz2, weights = None, return_transform3d = False):
    """Estimate the rigid transform aligning ``xyz1`` to ``xyz2`` via SVD.

    Both clouds are centered with the same normalized weights, their weighted
    cross-covariance is projected onto SO(3), and the translation is
    ``mean(xyz2) - mean(xyz1) @ R``.

    Args:
        xyz1, xyz2: batched corresponding points, coordinate axis last
            (presumably (B,N,3), as in flow2pose -- confirm).
        weights: per-point weights; None means all ones.
        return_transform3d: if True, return a pytorch3d rotate-then-translate
            transform instead of the raw tensors.

    Returns:
        ``(R, t)``, or the pytorch3d transform when requested.
    """
    if weights is None:
        weights = torch.ones(xyz1.shape[:-1], device=xyz1.device)
    norm_w = (weights / weights.sum(dim=-1, keepdims=True)).unsqueeze(-1)

    src_mean = (norm_w * xyz1).sum(dim=1, keepdims=True)
    src_centered = xyz1 - src_mean

    dst_mean = (norm_w * xyz2).sum(dim=1, keepdims=True)
    dst_centered = xyz2 - dst_mean

    cross_cov = torch.bmm(src_centered.transpose(-2,-1),
                          norm_w*dst_centered)

    R = symmetric_orthogonalization(cross_cov)
    t = (dst_mean - torch.bmm(src_mean, R)).squeeze(1)

    if return_transform3d:
        return Rotate(R).translate(t)
    return R, t
182 |
183 |
def _debug_bc_data(device):
    """Sanity-check flow-to-pose recovery on one saved Behavioral Cloning file.

    The data used here is (per the original notes) the v06 demonstration set
    with an extremely simple 1DoF rotation about the y axis at the start.
    Feeding the ground-truth tool flow for the first time step through
    `flow2pose` should reproduce the demonstrated rotation. If it does not,
    the SVD step is suspect; if it does, that step is at least OK.

    This can also help when debugging the translation corrections, because the
    center of rotation is not the centroid of the tool but the tip of the
    tool's stick.

    Args:
        device: torch device that the point cloud and flow are moved to.

    Note:
        Reads a pickle from a hard-coded path and always ends the process
        with `sys.exit()`.
    """
    def get_obs_tool_flow(pcl, tool_flow):
        # Copied from bc utils. Returns the point cloud together with a
        # per-point flow array; rows for non-tool points stay zero.
        tf_pts = tool_flow['points']
        tf_flow = tool_flow['flow']
        pcl_tool = pcl[:,3] == 1   # column 3 == 1 marks tool points
        n_tool_pts_obs = np.sum(pcl_tool)
        n_tool_pts_flow = tf_pts.shape[0]
        # First shapes only equal if: (a) fewer than max pts or (b) no item/distr.
        assert tf_pts.shape[0] <= pcl.shape[0], f'{tf_pts.shape}, {pcl.shape}'
        assert tf_pts.shape == tf_flow.shape, f'{tf_pts.shape}, {tf_flow.shape}'
        assert n_tool_pts_obs == n_tool_pts_flow, f'{n_tool_pts_obs}, {n_tool_pts_flow}'
        # The tool points must be the leading rows of the point cloud.
        assert np.array_equal(pcl[:n_tool_pts_obs,:3], tf_pts)
        a = np.zeros((pcl.shape[0], 3))
        a[:n_tool_pts_obs] = tf_flow  # encode the flow for BC purposes
        return (pcl, a)

    # This is the v06 with the 1DoF rotation about y axis at the start.
    PATH = os.path.join(
        '/data/seita/softgym_mm/data_demo/',
        'MMOneSphere_v01_BClone_filtered_ladle_algorithmic_v06_nVars_2000_obs_combo_act_translation_axis_angle',
        'BC_0000_600.pkl'
    )
    with open(PATH, 'rb') as fh:
        data = pickle.load(fh)

    # obses: list of tuples (different obs types), acts: list of axis-angles.
    obses, acts = data['obs'], data['act_raw']
    assert len(obses) == len(acts)

    # The flow for step t is stored with the _next_ observation.
    t = 0
    act_t = acts[t]             # the axis-angle formulation
    info_tp1 = obses[t+1][4]    # flow lives at index 4 of the next time step
    obs_t, act_flow_t = get_obs_tool_flow(obses[t][3], info_tp1)  # PCL at index 3
    print(f'Testing PCLs at t={t} shaped {obs_t.shape} with action: {act_t}')
    print(f'  act_flow_t: {act_flow_t.shape}')

    # Visualizing this (with longer flow for clarity) did not work from here,
    # so that lives in a separate minimal working example.

    # Get the pose from the flow; a leading batch dimension of 1 is added.
    xyz = torch.as_tensor(np.array([obs_t[:,:3]])).float()
    flow = torch.as_tensor(np.array([act_flow_t])).float()
    print(f'Calling flow2pose, xyz: {xyz.shape}, flow: {flow.shape}')
    R, t = flow2pose(
        xyz=xyz.to(device),
        flow=flow.to(device),
        weights=None,
        return_transform3d=False,
        return_quaternions=False,
    )

    # Report the single batch item in several rotation parameterizations.
    print(f'\nFinished SVD! Shape of R, t: {R.shape}, {t.shape}\n')
    R, t = R[0], t[0]
    print('\nThe rotation and translation:')
    print(R)
    print(t)
    print('\nIs this a rotation matrix? This should be the identity.')
    RT_times_R = torch.matmul(torch.transpose(R, 0, 1), R)
    print(RT_times_R)
    print('\nEuler angles:')
    print(matrix_to_euler_angles(R, convention='XYZ'))
    print('\nQuaternion:')
    quat = matrix_to_quaternion(R)
    print(quat)
    print('\nAxis-angle:')
    # NOTE: this was observed to produce the OPPOSITE rotation!
    print(quaternion_to_axis_angle(quat))
    sys.exit()
270 |
271 |
if __name__ == "__main__":
    # Try debugging / testing the flow / SVD methods.
    torch.set_printoptions(sci_mode=False, precision=6, linewidth=150, edgeitems=20)
    np.set_printoptions(suppress=True, precision=6, linewidth=150, edgeitems=20)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #_debug_bc_data(device)

    # Minibatch of B point clouds with 10000 points each; every point has a
    # 3D position (xyz) and a small 3D flow vector.
    B = 4
    xyz = torch.rand(B,10000,3)
    flow = torch.rand(B,10000,3) * 0.01
    ret_trans = False
    ret_quat = False
    world_frameify = True

    ## For debugging, let's actually do translation only for one item.
    ## If we do this, we get a tensor of (1,0,0,0) as output quaternion (good).
    #flow[0] = torch.ones((1,3)) * 0.5
    #flow[0,0,:] = torch.ones((1,3))

    # More debugging: items 1..3 are item 0 with positions and flow scaled by
    # 10, 100 and 1000.
    # If world_frameify=False, translations are the same up to the position
    # of the decimal point (good).
    # Rotations are _almost_ the same, agreeing to about 3 decimal places:
    #   axis-angle: tensor([-0.091616, -0.033090,  0.173777]), magn: 0.1502
    #   axis-angle: tensor([-0.091667, -0.033330,  0.174000]), magn: 0.1503
    #   axis-angle: tensor([-0.091644, -0.033659,  0.173811]), magn: 0.1506
    #   axis-angle: tensor([-0.091447, -0.033111,  0.173603]), magn: 0.1504
    # That seems close enough and might not matter too much in practice:
    # magnitudes are similar enough that the policy will do similar things.
    #
    # What if world_frameify=True? That would only affect the quality of the
    # returned translation. It still scales well, though not as perfectly as
    # just averaging the flow vectors (makes sense).
    for item, scale in enumerate((10., 100., 1000.), start=1):
        xyz[item] = xyz[0] * scale
        flow[item] = flow[0] * scale

    # Get the pose from the flow.
    pose = flow2pose(
        xyz=xyz.to(device),
        flow=flow.to(device),
        weights=None,
        return_transform3d=ret_trans,   # Brian/Chuer use True
        return_quaternions=ret_quat,    # something new
        world_frameify=world_frameify,  # something new
    )
    if ret_trans:
        # A transform object cannot be unpacked below; nothing more to show.
        sys.exit()

    # Debugging, could be useful to show a visualization?
    R, t = pose
    print(f'Shape of R, t: {R.shape}, {t.shape}\n')
    for b in range(B):
        R_b = R[b]
        print(f'\n---------- On minibatch item {b} ----------')
        print(R_b)
        quat_b = matrix_to_quaternion(R_b)
        aang_b = quaternion_to_axis_angle(quat_b)
        magn_rad = torch.norm(aang_b)
        magn_deg = magn_rad * RAD_TO_DEG
        print(f'quaternion: {quat_b}')
        print(f'axis-angle: {aang_b}, magnitude: {magn_rad:0.3f} rad, {magn_deg:0.3f} deg')
        print(f'translation: {t[b]}')
        if ret_quat:
            continue
        # R^T R should be the identity for a proper rotation matrix.
        print(f'To confirm rotation matrix (also check it is on cuda):')
        RTR = torch.matmul(torch.transpose(R_b, 0, 1), R_b)
        print(RTR)
342 |
--------------------------------------------------------------------------------
/tests/tool_reduce_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 87,
6 | "id": "45533134-06e9-4f26-8b42-e7c2b7868227",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import pickle\n",
11 | "from softgym.softgym.utils import visualization\n",
12 | "from torch_geometric.nn import fps\n",
13 | "import os\n",
14 | "import torch\n",
15 | "import numpy as np\n",
16 | "from pyquaternion import Quaternion"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 4,
22 | "id": "be2bd76c-4089-4fc0-86dc-a1729f99ba28",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "DATA_DIR = \"/data/seita/softgym_mm/data_demo/MMOneSphere_v01_BClone_filtered_ladle_algorithmic_v04_nVars_2000_obs_combo_act_translation_axis_angle\""
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 8,
32 | "id": "9a3db0d9-424d-44e0-ad01-e58a82d6f5bc",
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "with open(os.path.join(DATA_DIR, \"BC_0000_100.pkl\"), 'rb') as f:\n",
37 | " data = pickle.load(f)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 21,
43 | "id": "f171c4f8-889f-4f0e-8530-ac6fac787466",
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "obs_p = [data['obs'][0][3]]"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 22,
53 | "id": "e71f063b-f126-496c-8f63-ddbac7a9fd33",
54 | "metadata": {},
55 | "outputs": [
56 | {
57 | "name": "stdout",
58 | "output_type": "stream",
59 | "text": [
60 | "MoviePy - Building file ./point_cloud_segm.gif with imageio.\n"
61 | ]
62 | },
63 | {
64 | "name": "stderr",
65 | "output_type": "stream",
66 | "text": [
67 | " \r"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "visualization.save_pointclouds(obs_p, savedir='.')"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 45,
78 | "id": "68b29fa7-c735-4846-819a-80817509c4b4",
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "pcl = obs_p[0]\n",
83 | "\n",
84 | "i_tool = np.where(pcl[:,3] > 0)[0]\n",
85 | "tool_pts = pcl[i_tool]"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": 46,
91 | "id": "48151410-912d-4104-9565-10235b8c5f89",
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "tool_pts = torch.from_numpy(tool_pts)"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 59,
101 | "id": "446e7d3b-1d4d-416b-bcd9-6a4d854d7ff4",
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "sampled_tool_idxs = fps(tool_pts[:, :3], ratio=0.05, random_start=False)"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": 60,
111 | "id": "64f411eb-2fdf-44b0-b45a-4fbbf467996b",
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "sampled_tool_pts = tool_pts[sampled_tool_idxs]"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 61,
121 | "id": "fb50a4b7-ff9d-4c7e-aed4-c8219b71fe8e",
122 | "metadata": {},
123 | "outputs": [
124 | {
125 | "name": "stdout",
126 | "output_type": "stream",
127 | "text": [
128 | "MoviePy - Building file ./point_cloud_segm.gif with imageio.\n"
129 | ]
130 | },
131 | {
132 | "name": "stderr",
133 | "output_type": "stream",
134 | "text": [
135 | " \r"
136 | ]
137 | }
138 | ],
139 | "source": [
140 | "visualization.save_pointclouds([sampled_tool_pts], savedir='.')"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 66,
146 | "id": "f88ba691-8b9b-4cee-9474-ffac6b4f376c",
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "tool_tip_pos = data['obs'][0][0][:3]"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 73,
156 | "id": "cd293673-4a7c-44b5-9a6b-0c4f7b4ba7d9",
157 | "metadata": {},
158 | "outputs": [
159 | {
160 | "data": {
161 | "text/plain": [
162 | "array([-0.00986772, 0.54506862, -0.07499563])"
163 | ]
164 | },
165 | "execution_count": 73,
166 | "metadata": {},
167 | "output_type": "execute_result"
168 | }
169 | ],
170 | "source": [
171 | "tool_tip_pos"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 70,
177 | "id": "1e0d9a18-8411-4358-abe8-44e7b0b36666",
178 | "metadata": {},
179 | "outputs": [],
180 | "source": [
181 | "gt_tool_pts = sampled_tool_pts[:, :3] - tool_tip_pos"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 71,
187 | "id": "48270c87-7b5d-4f0b-bf75-76f661fdae45",
188 | "metadata": {},
189 | "outputs": [
190 | {
191 | "data": {
192 | "text/plain": [
193 | "tensor([-0.0045, -0.0117, -0.0070], dtype=torch.float64)"
194 | ]
195 | },
196 | "execution_count": 71,
197 | "metadata": {},
198 | "output_type": "execute_result"
199 | }
200 | ],
201 | "source": [
202 | "gt_tool_pts[0]"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": 74,
208 | "id": "b4f93391-559d-441e-8c8b-cf263871f903",
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "with open('100_tool_pts.pkl', 'wb') as f:\n",
213 | " pickle.dump(gt_tool_pts, f)"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 75,
219 | "id": "deb4f8ad-b89a-4d80-8583-37ca6b83da5a",
220 | "metadata": {},
221 | "outputs": [],
222 | "source": [
223 | "all_gt_tool_pts = sampled_tool_pts\n",
224 | "all_gt_tool_pts[:, :3] -= tool_tip_pos"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": 76,
230 | "id": "795096bb-95b4-4afc-b609-255ba841ae7e",
231 | "metadata": {},
232 | "outputs": [
233 | {
234 | "name": "stdout",
235 | "output_type": "stream",
236 | "text": [
237 | "MoviePy - Building file ./point_cloud_segm.gif with imageio.\n"
238 | ]
239 | },
240 | {
241 | "name": "stderr",
242 | "output_type": "stream",
243 | "text": [
244 | " \r"
245 | ]
246 | }
247 | ],
248 | "source": [
249 | "visualization.save_pointclouds([all_gt_tool_pts], savedir='.')"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": 82,
255 | "id": "deb1bc32-fa0a-4e8c-a9b5-e06fea6c086b",
256 | "metadata": {},
257 | "outputs": [
258 | {
259 | "data": {
260 | "text/plain": [
261 | "tensor(0.1027, dtype=torch.float64)"
262 | ]
263 | },
264 | "execution_count": 82,
265 | "metadata": {},
266 | "output_type": "execute_result"
267 | }
268 | ],
269 | "source": [
270 | "max(gt_tool_pts[1])"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 92,
276 | "id": "6a3997c5-f96f-42d2-b650-bff12117f8fb",
277 | "metadata": {},
278 | "outputs": [],
279 | "source": [
280 | "class MixedMediaToolReducer:\n",
281 | " TOOL_DATA_PATH = \"bc/100_tool_pts.pkl\"\n",
282 | " ACTION_LOW = np.array([ 0, 0, 0, -1, -1, -1])\n",
283 | " ACTION_HIGH = np.array([ 0, 0, 0, 1, 1, 1])\n",
284 | " DEG_TO_RAD = np.pi / 180.\n",
285 | " MAX_ROT_AXIS_ANG = (10. * DEG_TO_RAD)\n",
286 | "\n",
287 | " def __init__(self, args, action_repeat):\n",
288 | " assert args.reduce_tool_points\n",
289 | " self.tool_point_num = args.tool_point_num\n",
290 | " self.action_repeat = action_repeat\n",
291 | "\n",
292 | " self.MAX_ROT_AXIS_ANG /= action_repeat\n",
293 | "\n",
294 | " with open(self.TOOL_DATA_PATH, 'rb') as f:\n",
295 | " self.all_tool_points = pickle.load(f)\n",
296 | "\n",
297 | " # Sample tool points\n",
298 | " ratio = self.tool_point_num / 100\n",
299 | " sampled_idxs = fps(self.all_tool_points, ratio=ratio, random_start=False)\n",
300 | " self.tool_points = self.all_tool_points[sampled_idxs].detach().numpy()\n",
301 | "\n",
302 | " self.rotation = Quaternion()\n",
303 | "\n",
304 | " # Prep tool points for rotation\n",
305 | " self.vec_mat = np.zeros((self.tool_point_num, 4, 4), dtype=self.tool_points.dtype)\n",
306 | " self.vec_mat[:, 0, 1] = -self.tool_points[:, 0]\n",
307 | " self.vec_mat[:, 0, 2] = -self.tool_points[:, 1]\n",
308 | " self.vec_mat[:, 0, 3] = -self.tool_points[:, 2]\n",
309 | "\n",
310 | " self.vec_mat[:, 1, 0] = self.tool_points[:, 0]\n",
311 | " self.vec_mat[:, 1, 2] = -self.tool_points[:, 2]\n",
312 | " self.vec_mat[:, 1, 3] = self.tool_points[:, 1]\n",
313 | "\n",
314 | " self.vec_mat[:, 2, 0] = self.tool_points[:, 1]\n",
315 | " self.vec_mat[:, 2, 1] = self.tool_points[:, 2]\n",
316 | " self.vec_mat[:, 2, 3] = -self.tool_points[:, 0]\n",
317 | "\n",
318 | " self.vec_mat[:, 3, 0] = self.tool_points[:, 2]\n",
319 | " self.vec_mat[:, 3, 1] = -self.tool_points[:, 1]\n",
320 | " self.vec_mat[:, 3, 2] = self.tool_points[:, 0]\n",
321 | "\n",
322 | " def reset(self):\n",
323 | " self.rotation = Quaternion()\n",
324 | "\n",
325 | " def set_axis(self, axis):\n",
326 | " self.rotation = Quaternion(w=axis[3], x=axis[0], y=axis[1], z=axis[2])\n",
327 | "\n",
328 | " def step(self, act_raw):\n",
329 | " # act_raw: [x, y, z, rx, ry, rz]\n",
330 | " act_clip = np.clip(act_raw, a_min=self.ACTION_LOW, a_max=self.ACTION_HIGH)\n",
331 | " axis = act_clip[3:]\n",
332 | "\n",
333 | " dtheta = np.linalg.norm(act_clip[3:])\n",
334 | " if dtheta > self.MAX_ROT_AXIS_ANG:\n",
335 | " dtheta = dtheta * self.MAX_ROT_AXIS_ANG / np.sqrt(3)\n",
336 | " \n",
337 | " if dtheta == 0:\n",
338 | " axis = np.array([0., -1., 0.])\n",
339 | "\n",
340 | " axis = axis / np.linalg.norm(axis)\n",
341 | "\n",
342 | " for i in range(self.action_repeat):\n",
343 | " axis_world = self.rotation.rotate(axis)\n",
344 | " qt_rotate = Quaternion(axis=axis_world, angle=dtheta)\n",
345 | " self.rotation = qt_rotate * self.rotation\n",
346 | "\n",
347 | " def reduce_tool(self, obs, info):\n",
348 | " tool_idxs = np.where(obs[:, 3] == 1)[0]\n",
349 | " obs_notool = obs[len(tool_idxs):]\n",
350 | "\n",
351 | " tool_tip_pos = info[:3]\n",
352 | "\n",
353 | " # Rotate tool points\n",
354 | " global_rotation = self.rotation\n",
355 | " global_rotation._normalise()\n",
356 | " dqp = global_rotation.conjugate.q\n",
357 | "\n",
358 | " mid = np.matmul(self.vec_mat, dqp)\n",
359 | " mid = np.expand_dims(mid, axis=-1)\n",
360 | "\n",
361 | " rotated_tool_pts = global_rotation._q_matrix() @ mid\n",
362 | " rotated_tool_pts = rotated_tool_pts[:, 1:, 0]\n",
363 | "\n",
364 | " rotated_tool_pts += tool_tip_pos\n",
365 | "\n",
366 | " num_classes = obs.shape[1] - 3\n",
367 | " tool_onehot = np.zeros((self.tool_point_num, num_classes), dtype=obs.dtype)\n",
368 | " tool_onehot[:, 0] = 1\n",
369 | "\n",
370 | " tool_reduced = np.concatenate((rotated_tool_pts, tool_onehot), axis=1)\n",
371 | " return np.concatenate((tool_reduced, obs_notool), axis=0)"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": 98,
377 | "id": "2f4390a7-b05b-49ee-95b5-b49534a12dda",
378 | "metadata": {},
379 | "outputs": [],
380 | "source": [
381 | "class Args:\n",
382 | " reduce_tool_points = True\n",
383 | " tool_point_num = 20\n",
384 | " \n",
385 | "args = Args()"
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": 99,
391 | "id": "23529944-66e7-4c03-aa0d-ff42fb27898b",
392 | "metadata": {},
393 | "outputs": [],
394 | "source": [
395 | "num_obs = len(data['obs']) - 1\n",
396 | "\n",
397 | "raw_obs_p = []\n",
398 | "reduced_obs_p = []\n",
399 | "\n",
400 | "tool_reducer = MixedMediaToolReducer(args=args, action_repeat=8)\n",
401 | "tool_reducer.reset()\n",
402 | "\n",
403 | "for t in range(num_obs):\n",
404 | " obs = data['obs'][t]\n",
405 | " raw_obs_p.append(obs[3])\n",
406 | " reduced_obs = tool_reducer.reduce_tool(obs[3], info=obs[0])\n",
407 | " reduced_obs_p.append(reduced_obs)\n",
408 | " tool_reducer.step(data['act_raw'][t])"
409 | ]
410 | },
411 | {
412 | "cell_type": "code",
413 | "execution_count": 100,
414 | "id": "7278da5e-e342-4e4c-9b33-03f472216f2e",
415 | "metadata": {},
416 | "outputs": [
417 | {
418 | "name": "stdout",
419 | "output_type": "stream",
420 | "text": [
421 | "MoviePy - Building file ./raw.gif with imageio.\n"
422 | ]
423 | },
424 | {
425 | "name": "stderr",
426 | "output_type": "stream",
427 | "text": [
428 | " \r"
429 | ]
430 | },
431 | {
432 | "name": "stdout",
433 | "output_type": "stream",
434 | "text": [
435 | "MoviePy - Building file ./reduced.gif with imageio.\n"
436 | ]
437 | },
438 | {
439 | "name": "stderr",
440 | "output_type": "stream",
441 | "text": [
442 | " \r"
443 | ]
444 | }
445 | ],
446 | "source": [
447 | "visualization.save_pointclouds(raw_obs_p, savedir='.', suffix=\"raw.gif\")\n",
448 | "visualization.save_pointclouds(reduced_obs_p, savedir='.', suffix=\"reduced.gif\")"
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": null,
454 | "id": "5a1af452-1853-4d4b-bf1f-7c8d7d06a42f",
455 | "metadata": {},
456 | "outputs": [],
457 | "source": []
458 | }
459 | ],
460 | "metadata": {
461 | "kernelspec": {
462 | "display_name": "softgym",
463 | "language": "python",
464 | "name": "softgym"
465 | },
466 | "language_info": {
467 | "codemirror_mode": {
468 | "name": "ipython",
469 | "version": 3
470 | },
471 | "file_extension": ".py",
472 | "mimetype": "text/x-python",
473 | "name": "python",
474 | "nbconvert_exporter": "python",
475 | "pygments_lexer": "ipython3",
476 | "version": "3.6.13"
477 | }
478 | },
479 | "nbformat": 4,
480 | "nbformat_minor": 5
481 | }
482 |
--------------------------------------------------------------------------------
/chester/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import os.path as osp
5 | import json
6 | import time
7 | import datetime
8 | import dateutil.tz
9 | import tempfile
10 | import wandb
11 | from collections import defaultdict
12 |
13 | # LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv', 'tensorboard']
14 | LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv']
15 | # Also valid: json, tensorboard, wandb
16 |
17 | DEBUG = 10
18 | INFO = 20
19 | WARN = 30
20 | ERROR = 40
21 |
22 | DISABLED = 50
23 |
24 |
class KVWriter(object):
    """Interface for output formats that can write a dict of key/value pairs."""

    def writekvs(self, kvs):
        """Write the mapping `kvs`; concrete writers must override this."""
        raise NotImplementedError()
28 |
29 |
class SeqWriter(object):
    """Interface for output formats that can write a sequence of strings."""

    def writeseq(self, seq):
        """Write the items of `seq`; concrete writers must override this."""
        raise NotImplementedError()
33 |
34 |
def put_in_middle(str1, str2):
    """Overlay `str2` onto the centre of `str1`.

    The result has the same length as `str1`, with the middle characters
    replaced by `str2`. When `str1` is not strictly longer than `str2`,
    `str2` is returned unchanged.
    """
    outer_len, inner_len = len(str1), len(str2)
    if outer_len <= inner_len:
        return str2
    left = (outer_len - inner_len) // 2
    right = left + inner_len
    return str1[:left] + str2 + str1[right:]
44 |
45 |
class HumanOutputFormat(KVWriter, SeqWriter):
    """Writes key/value tables and plain log lines as human-readable text.

    The target is either a path (opened and owned by this object) or an
    already-open writable stream such as `sys.stdout` (never closed here).
    """

    def __init__(self, filename_or_file):
        """Open `filename_or_file` if it is a path, else adopt it as a stream.

        Raises:
            AssertionError: if the argument is neither a `str` nor an object
                with a `write` method.
        """
        if isinstance(filename_or_file, str):
            self.file = open(filename_or_file, 'wt')
            self.own_file = True
        else:
            # This class only ever writes and flushes, so the capability to
            # require is `write`. Checking for `read` rejected valid
            # write-only streams (e.g. tee-style wrappers around stdout).
            assert hasattr(filename_or_file, 'write'), 'expected file or str, got %s' % (filename_or_file,)
            self.file = filename_or_file
            self.own_file = False

    def writekvs(self, kvs):
        """Write `kvs` as a boxed two-column table headed by a timestamp."""
        # Create strings for printing
        key2str = {}
        for (key, val) in sorted(kvs.items()):
            if isinstance(val, float):
                valstr = '%-8.3g' % (val,)
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Find max widths
        if len(key2str) == 0:
            print('WARNING: tried to write empty key-value dict')
            return
        else:
            keywidth = max(map(len, key2str.keys()))
            valwidth = max(map(len, key2str.values()))

        # Write out the data
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y-%m-%d %H:%M:%S.%f %Z')

        # 7 = width of the '| ', ' | ' and ' |' decorations around the cells.
        dashes = '-' * (keywidth + valwidth + 7)
        dashes_time = put_in_middle(dashes, timestamp)
        lines = [dashes_time]
        for (key, val) in sorted(key2str.items()):
            lines.append('| %s%s | %s%s |' % (
                key,
                ' ' * (keywidth - len(key)),
                val,
                ' ' * (valwidth - len(val)),
            ))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        """Shorten strings longer than 33 characters to 30 plus '...'."""
        return s[:30] + '...' if len(s) > 33 else s

    def writeseq(self, seq):
        """Write the items of `seq` back to back, then a newline, and flush."""
        for arg in seq:
            self.file.write(arg)
        self.file.write('\n')
        self.file.flush()

    def close(self):
        """Close the underlying file, but only if this object opened it."""
        if self.own_file:
            self.file.close()
106 |
107 |
class JSONOutputFormat(KVWriter):
    """Writes each dump of key/value pairs as one JSON object per line."""

    def __init__(self, filename):
        self.file = open(filename, 'wt')

    def writekvs(self, kvs):
        """Append `kvs` to the file as a single JSON line and flush.

        Values exposing a `dtype` attribute (e.g. numpy scalars) are converted
        with `float(v.tolist())` so that `json.dumps` can serialise them.
        The conversion is done on a copy: `kvs` is the logger's live mapping
        and is handed to every writer, so it must not be modified here.
        """
        serializable = {}
        for k, v in kvs.items():
            if hasattr(v, 'dtype'):
                v = float(v.tolist())
            serializable[k] = v
        self.file.write(json.dumps(serializable) + '\n')
        self.file.flush()

    def close(self):
        """Close the output file."""
        self.file.close()
122 |
123 |
class CSVOutputFormat(KVWriter):
    """Writes each dump of key/value pairs as one row of a CSV file.

    The header grows as new keys appear: the file is rewritten in place with
    the widened header, and rows written earlier are padded with empty cells.
    Values are written with `str()` and are not quoted or escaped.
    """

    def __init__(self, filename):
        # 'w+t' truncates on open but keeps the file readable, which
        # writekvs() needs when it re-reads rows to widen the header.
        self.file = open(filename, 'w+t')
        self.keys = []
        self.sep = ','

    def writekvs(self, kvs):
        """Append one row for `kvs`, first widening the header if needed.

        Keys missing from `kvs` (or mapped to None) produce an empty cell.
        """
        # Sort the new keys so the column order is reproducible; iterating
        # the raw set difference depends on hash ordering.
        extra_keys = sorted(kvs.keys() - self.keys)
        if extra_keys:
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            self.file.write(self.sep.join(self.keys))
            self.file.write('\n')
            # lines[0] is the old header; pad every older data row with one
            # empty cell per newly added column.
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write('\n')
        cells = []
        for k in self.keys:
            v = kvs.get(k)
            cells.append('' if v is None else str(v))
        self.file.write(self.sep.join(cells))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        """Close the output file."""
        self.file.close()
158 |
159 |
class TensorBoardOutputFormat(KVWriter):
    """
    Dumps key/value pairs into TensorBoard's numeric format.

    TensorFlow is imported lazily in the constructor, so the dependency is
    only needed when this writer is actually used. The APIs used here
    (`tf.Summary`, `pywrap_tensorflow.EventsWriter`) presumably require a
    TensorFlow 1.x install -- TODO confirm.
    """

    def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1  # event step counter, incremented on every writekvs()
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat
        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        events_path = osp.join(osp.abspath(dir), 'events')
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(events_path))

    def writekvs(self, kvs):
        """Emit one event holding every value of `kvs` as a float scalar."""
        values = [
            self.tf.Summary.Value(tag=tag, simple_value=float(val))
            for tag, val in kvs.items()
        ]
        event = self.event_pb2.Event(
            wall_time=time.time(),
            summary=self.tf.Summary(value=values),
        )
        event.step = self.step  # is there any reason why you'd want to specify the step?
        self.writer.WriteEvent(event)
        self.writer.Flush()
        self.step += 1

    def close(self):
        """Close the events writer; safe to call more than once."""
        if not self.writer:
            return
        self.writer.Close()
        self.writer = None
196 |
197 |
class WandbOutputFormat(KVWriter):
    """Forwards key/value dumps to a Weights & Biases run."""

    def __init__(self, exp_name, variant):
        """Start a wandb run named `exp_name`.

        The project and entity are read from the "wandb_project" and
        "wandb_entity" keys of `variant`; either is None when absent or when
        `variant` itself is None.
        """
        settings = {} if variant is None else variant
        wandb.init(
            project=settings.get("wandb_project"),
            entity=settings.get("wandb_entity"),
            name=exp_name,
        )

    def writekvs(self, kvs):
        """Log the mapping `kvs` to the active run."""
        wandb.log(kvs)

    def close(self):
        """Finish the active run."""
        wandb.finish()
212 |
213 |
def make_output_format(format, ev_dir, exp_name, variant=None, log_suffix=''):
    """Build the writer object for one output format name.

    Args:
        format: one of 'stdout', 'log', 'json', 'csv', 'tensorboard', 'wandb'.
        ev_dir: directory for file-based outputs; created if missing.
        exp_name: experiment name, used only by the 'wandb' writer.
        variant: optional settings mapping, used only by the 'wandb' writer.
        log_suffix: text inserted into file/directory names before the
            extension.

    Raises:
        ValueError: if `format` is not a known format name.
    """
    os.makedirs(ev_dir, exist_ok=True)

    def in_ev_dir(pattern):
        # Fill the suffix into the name pattern and place it under ev_dir.
        return osp.join(ev_dir, pattern % log_suffix)

    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    if format == 'log':
        return HumanOutputFormat(in_ev_dir('log%s.txt'))
    if format == 'json':
        return JSONOutputFormat(in_ev_dir('progress%s.json'))
    if format == 'csv':
        return CSVOutputFormat(in_ev_dir('progress%s.csv'))
    if format == 'tensorboard':
        return TensorBoardOutputFormat(in_ev_dir('tb%s'))
    if format == 'wandb':
        return WandbOutputFormat(exp_name, variant)
    raise ValueError('Unknown format specified: %s' % (format,))
230 |
231 |
232 | # ================================================================
233 | # API
234 | # ================================================================
235 |
def logkv(key, val):
    """
    Log a value of some diagnostic on the current logger.

    Call this once for each diagnostic quantity, each iteration.
    If called many times before a dump, the last value will be used.
    """
    Logger.CURRENT.logkv(key, val)
243 |
244 |
def logkv_mean(key, val):
    """
    The same as logkv(), but if called many times before a dump, the logged
    values are averaged instead of overwritten.
    """
    Logger.CURRENT.logkv_mean(key, val)
250 |
251 |
def logkvs(d):
    """
    Log a dictionary of key-value pairs, one logkv() call per item.
    """
    for (k, v) in d.items():
        logkv(k, v)
258 |
259 |
def dumpkvs():
    """
    Write all of the diagnostics from the current iteration.

    The values are handed to every key/value writer of the current logger
    and then cleared. Nothing is written (or cleared) while the logger's
    level is DISABLED.
    """
    Logger.CURRENT.dumpkvs()
268 |
269 |
def getkvs():
    """
    Return the current logger's mapping of diagnostics logged so far.

    This is the live mapping, not a copy.
    """
    return Logger.CURRENT.name2val
272 |
273 |
def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).

    The message is dropped when `level` is below the current logger's level.
    """
    Logger.CURRENT.log(*args, level=level)
279 |
280 |
def debug(*args):
    """Log `args` at DEBUG level (see log())."""
    log(*args, level=DEBUG)
283 |
284 |
def info(*args):
    """Log `args` at INFO level (see log())."""
    log(*args, level=INFO)
287 |
288 |
def warn(*args):
    """Log `args` at WARN level (see log())."""
    log(*args, level=WARN)
291 |
292 |
def error(*args):
    """Log `args` at ERROR level (see log())."""
    log(*args, level=ERROR)
295 |
296 |
def set_level(level):
    """
    Set logging threshold on current logger.

    Messages logged with a level below this threshold are dropped.
    """
    Logger.CURRENT.set_level(level)
302 |
303 |
def get_dir():
    """
    Get directory that log files are being written to.
    Will be None if there is no output directory (i.e., the default logger
    is still active because configure() was never called).
    """
    return Logger.CURRENT.get_dir()
310 |
311 |
312 | record_tabular = logkv
313 | dump_tabular = dumpkvs
314 |
315 |
class ProfileKV:
    """
    Context manager that accumulates wall-clock time under a "wait_<n>" key.

    Usage:
    with logger.ProfileKV("interesting_scope"):
        code
    """

    def __init__(self, n):
        self.n = "wait_" + n

    def __enter__(self):
        # Remember when the timed scope started.
        self.t1 = time.time()

    def __exit__(self, type, value, traceback):
        # Add the elapsed time to the current logger's running total.
        Logger.CURRENT.name2val[self.n] += time.time() - self.t1
331 |
332 |
def profile(n):
    """
    Decorator factory that times every call of the wrapped function.

    Each call runs inside `ProfileKV(n)`, so its duration is accumulated
    under the "wait_<n>" key of the current logger. The wrapper keeps the
    wrapped function's name, docstring and other metadata.

    Usage:
    @profile("my_func")
    def my_func(): code
    """
    # Local import: keeps this block self-contained without touching the
    # file's top-level import list.
    import functools

    def decorator_with_name(func):
        @functools.wraps(func)
        def func_wrapper(*args, **kwargs):
            with ProfileKV(n):
                return func(*args, **kwargs)

        return func_wrapper

    return decorator_with_name
348 |
349 |
350 | # ================================================================
351 | # Backend
352 | # ================================================================
353 |
class Logger(object):
    """Collects diagnostics and messages and fans them out to output formats.

    Key/value diagnostics are buffered until dumpkvs(); free-form messages
    are forwarded immediately to every sequence-capable output format.
    """

    # A logger with no output files, assigned right below the class
    # definition, so that you can still log to the terminal without setting
    # up any output files.
    DEFAULT = None
    # Current logger being used by the module-level free functions.
    CURRENT = None

    def __init__(self, dir, output_formats):
        self.dir = dir
        self.output_formats = output_formats
        self.level = INFO
        self.name2val = defaultdict(float)  # values this iteration
        self.name2cnt = defaultdict(int)    # samples seen per key by logkv_mean

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        """Record `val` under `key`, replacing any earlier value."""
        self.name2val[key] = val

    def logkv_mean(self, key, val):
        """Fold `val` into the running mean stored under `key`.

        A `val` of None stores None under `key` and leaves the sample count
        untouched.
        """
        if val is None:
            self.name2val[key] = None
            return
        running, seen = self.name2val[key], self.name2cnt[key]
        total = seen + 1
        self.name2val[key] = running * seen / total + val / total
        self.name2cnt[key] = total

    def dumpkvs(self):
        """Send the buffered diagnostics to every KVWriter, then clear them.

        Does nothing at all (not even clearing) when the level is DISABLED.
        """
        if self.level == DISABLED:
            return
        kv_writers = [fmt for fmt in self.output_formats if isinstance(fmt, KVWriter)]
        for writer in kv_writers:
            writer.writekvs(self.name2val)
        self.name2val.clear()
        self.name2cnt.clear()

    def log(self, *args, level=INFO):
        """Forward `args` to the sequence writers unless `level` is too low."""
        if self.level > level:
            return
        self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        """Set the minimum level a message needs in order to be written."""
        self.level = level

    def get_dir(self):
        """Return the output directory (None for the default logger)."""
        return self.dir

    def close(self):
        """Close every output format."""
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        """Write `args`, converted to strings, to every SeqWriter."""
        seq_writers = [fmt for fmt in self.output_formats if isinstance(fmt, SeqWriter)]
        for writer in seq_writers:
            # A fresh lazy map per writer: each one consumes its own iterator.
            writer.writeseq(map(str, args))
409 |
410 |
411 | Logger.DEFAULT = Logger.CURRENT = Logger(dir=None, output_formats=[HumanOutputFormat(sys.stdout)])
412 |
413 |
def configure(dir=None, format_strs=None, exp_name=None, variant=None):
    """Install a new current logger writing into `dir`.

    Args:
        dir: output directory. Falls back to the OPENAI_LOGDIR environment
            variable, then to a timestamped "chester-..." directory under
            the system temp dir. Created if missing.
        format_strs: list of output format names. Falls back to the
            comma-separated OPENAI_LOG_FORMAT environment variable, then to
            LOG_OUTPUT_FORMATS.
        exp_name: experiment name passed through to the output formats.
        variant: optional settings mapping passed through to the output
            formats.
    """
    log_dir = dir
    if log_dir is None:
        log_dir = os.getenv('OPENAI_LOGDIR')
    if log_dir is None:
        stamp = datetime.datetime.now().strftime("chester-%Y-%m-%d-%H-%M-%S")
        log_dir = osp.join(tempfile.gettempdir(), stamp)

    assert isinstance(log_dir, str)
    os.makedirs(log_dir, exist_ok=True)

    if format_strs is None:
        env_formats = os.getenv('OPENAI_LOG_FORMAT')
        if env_formats:
            format_strs = env_formats.split(',')
        else:
            format_strs = LOG_OUTPUT_FORMATS

    output_formats = []
    for name in format_strs:
        output_formats.append(make_output_format(name, log_dir, exp_name, variant))

    Logger.CURRENT = Logger(dir=log_dir, output_formats=output_formats)
    log('Logging to %s' % log_dir)
431 |
432 |
def reset():
    """Close the current logger and go back to Logger.DEFAULT; a no-op if the default is already active."""
    if Logger.CURRENT is Logger.DEFAULT:
        return
    Logger.CURRENT.close()
    Logger.CURRENT = Logger.DEFAULT
    log('Reset logger')
438 |
439 |
class scoped_configure(object):
    """
    Context manager that calls configure() on entry and, on exit, closes the
    logger that is current at that point and restores the one that was
    current before entry.
    """

    def __init__(self, dir=None, format_strs=None):
        self.dir, self.format_strs = dir, format_strs
        # Filled in by __enter__ with the logger to restore on exit.
        self.prevlogger = None

    def __enter__(self):
        self.prevlogger = Logger.CURRENT
        configure(dir=self.dir, format_strs=self.format_strs)

    def __exit__(self, *args):
        active = Logger.CURRENT
        active.close()
        Logger.CURRENT = self.prevlogger
453 |
454 |
455 | # ================================================================
456 |
def _demo():
    """
    Manual smoke test of the module-level logging API.

    Logs at several levels, reconfigures the logger to write into
    /tmp/testlogging (deleting that directory first if it exists) and dumps
    a few key/value tables, printing the value a reader should expect to see.
    """
    info("hi")
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    log_dir = "/tmp/testlogging"  # local renamed so it no longer shadows the builtin dir()
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    configure(dir=log_dir)
    logkv("a", 3)
    logkv("b", 2.5)
    dumpkvs()
    logkv("b", -2.5)
    logkv("a", 5.5)
    dumpkvs()
    info("^^^ should see a = 5.5")
    logkv_mean("b", -22.5)
    logkv_mean("b", -44.4)
    logkv("a", 5.5)
    dumpkvs()
    # The running mean of -22.5 and -44.4 is -33.45 (the old message had the
    # wrong sign and value).
    info("^^^ should see b = -33.45")

    logkv("b", -2.5)
    dumpkvs()

    logkv("a", "longasslongasslongasslongasslongasslongassvalue")
    dumpkvs()
484 |
485 |
486 | # ================================================================
487 | # Readers
488 | # ================================================================
489 |
def read_json(fname):
    """Load a file holding one JSON document per line into a pandas DataFrame (one row per line)."""
    import pandas
    with open(fname, 'rt') as fh:
        records = [json.loads(raw) for raw in fh]
    return pandas.DataFrame(records)
497 |
498 |
def read_csv(fname):
    """Load a CSV file into a pandas DataFrame, treating '#' as a comment marker and using no index column."""
    import pandas
    frame = pandas.read_csv(fname, index_col=None, comment='#')
    return frame
502 |
503 |
def read_tb(path):
    """
    path : a tensorboard file OR a directory, where we will find all TB files
           of the form events.*

    Returns a pandas DataFrame with one column per tag (sorted by name) and
    one row per step; row ``i`` holds the values logged at step ``i + 1`` and
    cells with no logged value are NaN. Events with step <= 0 are ignored.
    """
    import pandas
    import numpy as np
    from glob import glob
    from collections import defaultdict
    import tensorflow as tf
    if osp.isdir(path):
        event_files = glob(osp.join(path, "events.*"))
    elif osp.basename(path).startswith("events."):
        event_files = [path]
    else:
        raise NotImplementedError("Expected tensorboard file or directory containing them. Got %s" % path)

    # tag -> list of (step, value) pairs collected from every event file
    series = defaultdict(list)
    last_step = 0
    for event_file in event_files:
        for event in tf.train.summary_iterator(event_file):
            step = event.step
            if step > 0:
                for item in event.summary.value:
                    series[item.tag].append((step, item.simple_value))
                last_step = max(step, last_step)

    columns = sorted(series.keys())
    table = np.empty((last_step, len(series)))
    table[:] = np.nan
    for col, tag in enumerate(columns):
        for step, value in series[tag]:
            table[step - 1, col] = value
    return pandas.DataFrame(table, columns=columns)
537 |
538 |
# Run the logging demo when this module is executed directly as a script.
if __name__ == "__main__":
    _demo()
541 |
--------------------------------------------------------------------------------
/chester/utils_s3.py:
--------------------------------------------------------------------------------
1 | from chester import config_ec2 as config
2 | from io import StringIO
3 | import base64
4 | import os
5 | import os.path as osp
6 | import subprocess
7 | import hashlib
8 | import datetime
9 | import dateutil
10 | import re
11 |
# `import dateutil` on its own does not load the `tz` submodule on
# python-dateutil releases without lazy submodule imports, so import it
# explicitly instead of relying on some other module having done so.
import dateutil.tz

# Captured once at import time, in the local timezone; `timestamp` is reused
# in the names of the code archives uploaded to S3.
now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
14 |
15 |
def dedent(s):
    """Strip leading and trailing whitespace from every line of ``s`` and re-join the lines with newlines."""
    return '\n'.join(map(str.strip, s.split('\n')))
19 |
20 |
def upload_file_to_s3(script_content):
    """
    Upload ``script_content`` to S3 and return the remote path.

    The text is written to a local temporary file and copied with
    ``aws s3 cp`` to ``<AWS_CODE_SYNC_S3_PATH>/oversize_bash_scripts/<uuid4>``.
    The temporary file is removed even when the write or the copy fails.

    Raises subprocess.CalledProcessError if ``aws s3 cp`` exits non-zero.
    """
    import posixpath
    import tempfile
    import uuid
    f = tempfile.NamedTemporaryFile(delete=False)
    try:
        f.write(script_content.encode())
        f.close()
        # S3 keys always use '/', so join with posixpath rather than the
        # OS-specific os.path (identical result on POSIX systems).
        remote_path = posixpath.join(
            config.AWS_CODE_SYNC_S3_PATH, "oversize_bash_scripts", str(uuid.uuid4()))
        subprocess.check_call(["aws", "s3", "cp", f.name, remote_path])
    finally:
        f.close()  # no-op when already closed
        os.unlink(f.name)
    return remote_path
32 |
33 |
# S3 location of the code uploaded by s3_sync_code(); None until the first call.
S3_CODE_PATH = None


def s3_sync_code(config, dry=False):
    """
    Upload the project code to S3 and return its remote path.

    The result is memoized in the module-level S3_CODE_PATH, so only the first
    call in a process does any work.

    config : object providing AWS_CODE_SYNC_S3_PATH, FAST_CODE_SYNC,
        MUJOCO_KEY_PATH and, depending on the mode, PROJECT_PATH /
        FAST_CODE_SYNC_IGNORES or CODE_SYNC_IGNORES.
    dry : when True the commands are only printed, not executed; the remote
        path is still computed, cached and returned.

    With FAST_CODE_SYNC the project directory is packed into a tarball under
    /tmp and uploaded as a single object. Otherwise the current directory is
    copied recursively, using a per-directory cache prefix on S3.

    Raises subprocess.CalledProcessError if a tar/aws command fails (the
    mujoco key sync is best-effort and only reports failures).
    """
    global S3_CODE_PATH
    if S3_CODE_PATH is not None:
        return S3_CODE_PATH
    base = config.AWS_CODE_SYNC_S3_PATH
    has_git = True

    if config.FAST_CODE_SYNC:
        try:
            current_commit = subprocess.check_output(
                ["git", "rev-parse", "HEAD"]).strip().decode("utf-8")
        except (subprocess.CalledProcessError, OSError):
            # OSError covers a missing git executable (FileNotFoundError).
            print("Warning: failed to execute git commands")
            current_commit = None
        file_name = str(timestamp) + "_" + hashlib.sha224(
            subprocess.check_output(["pwd"]) + str(current_commit).encode() + str(timestamp).encode()
        ).hexdigest() + ".tar.gz"

        file_path = "/tmp/" + file_name

        tar_cmd = ["tar", "-zcvf", file_path, "-C", config.PROJECT_PATH]
        for pattern in config.FAST_CODE_SYNC_IGNORES:
            tar_cmd += ["--exclude", pattern]
        tar_cmd += ["-h", "."]

        remote_path = "%s/%s" % (base, file_name)

        upload_cmd = ["aws", "s3", "cp", file_path, remote_path]

        mujoco_key_cmd = [
            "aws", "s3", "sync", config.MUJOCO_KEY_PATH, "{}/.mujoco/".format(base)]

        print(" ".join(tar_cmd))
        print(" ".join(upload_cmd))
        print(" ".join(mujoco_key_cmd))

        if not dry:
            subprocess.check_call(tar_cmd)
            subprocess.check_call(upload_cmd)
            try:
                subprocess.check_call(mujoco_key_cmd)
            except Exception as e:
                # Best effort: a missing mujoco key must not abort the launch.
                print(e)

        S3_CODE_PATH = remote_path
        return remote_path
    else:
        try:
            current_commit = subprocess.check_output(
                ["git", "rev-parse", "HEAD"]).strip().decode("utf-8")
            clean_state = len(
                subprocess.check_output(["git", "status", "--porcelain"])) == 0
        except (subprocess.CalledProcessError, OSError):
            # OSError covers a missing git executable (FileNotFoundError).
            print("Warning: failed to execute git commands")
            has_git = False
        dir_hash = base64.b64encode(subprocess.check_output(["pwd"])).decode("utf-8")
        # Version suffix: the commit when the tree is clean, commit + timestamp
        # when it is dirty, and just the timestamp when git is unavailable.
        if not has_git:
            version = timestamp
        elif clean_state:
            version = current_commit
        else:
            version = "%s_dirty_%s" % (current_commit, timestamp)
        code_path = "%s_%s" % (dir_hash, version)
        full_path = "%s/%s" % (base, code_path)
        cache_path = "%s/%s" % (base, dir_hash)

        # Flat ["--exclude", pattern, ...] argument list shared by all copies.
        # (Built inline: this module has no `flatten` helper in scope.)
        exclude_args = []
        for pattern in config.CODE_SYNC_IGNORES:
            exclude_args += ["--exclude", "%s" % pattern]

        copy_prefix = ["aws", "s3", "cp", "--recursive"]
        # cache -> versioned path, local dir -> versioned path, then refresh the cache
        cache_cmds = copy_prefix + exclude_args + [cache_path, full_path]
        cmds = copy_prefix + exclude_args + [".", full_path]
        caching_cmds = copy_prefix + exclude_args + [full_path, cache_path]
        mujoco_key_cmd = [
            "aws", "s3", "sync", config.MUJOCO_KEY_PATH, "{}/.mujoco/".format(base)]
        print(cache_cmds, cmds, caching_cmds, mujoco_key_cmd)
        if not dry:
            subprocess.check_call(cache_cmds)
            subprocess.check_call(cmds)
            subprocess.check_call(caching_cmds)
            try:
                subprocess.check_call(mujoco_key_cmd)
            except Exception:
                # Best effort: a missing mujoco key must not abort the launch.
                print('Unable to sync mujoco keys!')
        S3_CODE_PATH = full_path
        return full_path
122 |
123 |
# Finds the first character that is NOT safe inside an unquoted shell word.
# The character class is negated: without the leading '^' the pattern would
# match the safe characters instead, leaving strings made only of shell
# metacharacters (e.g. "*", ";", "$") unquoted.
_find_unsafe = re.compile(r'[^a-zA-Z0-9_@%+=:,./-]').search


def _shellquote(s):
    """Return a shell-escaped version of the string *s*.

    Empty strings become ``''``; strings made only of safe characters are
    returned unchanged; everything else is wrapped in single quotes.
    """
    if not s:
        return "''"

    if _find_unsafe(s) is None:
        return s

    # use single quotes, and put single quotes into double quotes
    # the string $'b is then quoted as '$'"'"'b'

    return "'" + s.replace("'", "'\"'\"'") + "'"
139 |
140 |
def _to_param_val(v):
    """
    Render a parameter value as shell-quoted command-line text.

    None becomes the empty string, a list becomes its shell-quoted items
    joined by spaces, anything else is converted with str() and shell-quoted.
    """
    if v is None:
        return ""
    if isinstance(v, list):
        return " ".join(_shellquote(str(item)) for item in v)
    return _shellquote(str(v))
148 |
149 |
def to_local_command(params, python_command="python", script=osp.join(config.PROJECT_PATH, 'scripts/run_experiment.py'), use_gpu=False):
    """
    Build the shell command line that runs ``script`` with ``params``.

    params : dict of experiment parameters. "pre_commands" and "post_commands"
        are popped from it (the dict is modified in place). Every other entry
        becomes ``--key value``; a dict value expands to ``--key_subkey value``
        per item, except the sub-key "_name", which is emitted as ``--key value``.
    python_command : interpreter placed in front of ``script``.
    script : path of the script to run.
    use_gpu : accepted for interface compatibility; not used here.

    Returns the command string, with the pre-commands chained in front of it
    using ``&&`` in their original order. Post-commands are only reported,
    never executed.
    """
    command = python_command + " " + script
    pre_commands = params.pop("pre_commands", None)
    post_commands = params.pop("post_commands", None)
    if post_commands is not None:
        print("Not executing the post_commands: ", post_commands)

    for k, v in params.items():
        if isinstance(v, dict):
            for nk, nv in v.items():
                if str(nk) == "_name":
                    command += "  --%s %s" % (k, _to_param_val(nv))
                else:
                    command += \
                        "  --%s_%s %s" % (k, nk, _to_param_val(nv))
        else:
            command += "  --%s %s" % (k, _to_param_val(v))
    # "pre_commands" is optional: reversed(None) would raise a TypeError.
    # Prepending in reverse keeps the commands in their original order.
    for pre_command in reversed(pre_commands or []):
        command = pre_command + " && " + command
    return command
170 |
171 |
def launch_ec2(params_list, exp_prefix, docker_image, code_full_path,
               python_command="python",
               script='scripts/run_experiment.py',
               aws_config=None, dry=False, terminate_machine=True, use_gpu=False, sync_s3_pkl=False,
               sync_s3_png=False,
               sync_s3_log=False,
               sync_s3_html=False,
               sync_s3_mp4=False,
               sync_s3_gif=False,
               sync_s3_pth=False,
               sync_s3_txt=False,
               sync_log_on_termination=True,
               periodic_sync=True, periodic_sync_interval=15):
    """
    Launch one EC2 instance that runs every experiment in ``params_list``.

    A bash user-data script is generated that sets up the environment, tags
    the instance, downloads the code from ``code_full_path``, runs each
    experiment command in turn, syncs its logs to S3 and optionally
    terminates the machine. The instance is then requested through boto3 as
    a spot request or an on-demand instance, depending on ``aws_config["spot"]``.

    params_list : list of parameter dicts, one per experiment. "remote_log_dir"
        (required) and "env" are popped from each dict, and the dict is then
        passed to to_local_command. Nothing happens when the list is empty.
    exp_prefix : value of the instance's ``exp_prefix`` tag.
    docker_image : not used by this function.
    code_full_path : S3 path the code is downloaded from.
    python_command, script, use_gpu : forwarded to to_local_command.
    aws_config : optional overrides merged over the defaults taken from config.
    dry : passed to boto3 as DryRun; spot requests are not sent at all.
    terminate_machine : append an instance self-termination step to the script.
    sync_s3_* : file extensions included in the periodic log sync besides
        *.csv and *.json.
    sync_log_on_termination : install a background hook that uploads the logs
        once the spot termination notice appears.
    periodic_sync, periodic_sync_interval : enable the background log sync
        loop and set its sleep interval in seconds.
    """
    if len(params_list) == 0:
        return

    default_config = dict(
        image_id=config.AWS_IMAGE_ID,
        instance_type=config.AWS_INSTANCE_TYPE,
        key_name=config.AWS_KEY_NAME,
        spot=config.AWS_SPOT,
        spot_price=config.AWS_SPOT_PRICE,
        iam_instance_profile_name=config.AWS_IAM_INSTANCE_PROFILE_NAME,
        security_groups=config.AWS_SECURITY_GROUPS,
        security_group_ids=config.AWS_SECURITY_GROUP_IDS,
        network_interfaces=config.AWS_NETWORK_INTERFACES,
        instance_interruption_behavior='terminate',  # TODO
    )

    if aws_config is None:
        aws_config = dict()
    aws_config = dict(default_config, **aws_config)

    # ------------------------------------------------------------------
    # Build the user-data script. Indentation inside the snippets is
    # irrelevant: dedent() strips every line before the script is used.
    # ------------------------------------------------------------------
    sio = StringIO()
    sio.write("#!/bin/bash\n")
    sio.write("{\n")
    sio.write("""
        die() { status=$1; shift; echo "FATAL: $*"; exit $status; }
    """)
    sio.write("""
        EC2_INSTANCE_ID="`wget -q -O - http://169.254.169.254/latest/meta-data/instance-id`"
    """)
    # sio.write("""service docker start""")
    # sio.write("""docker --config /home/ubuntu/.docker pull {docker_image}""".format(docker_image=docker_image))
    sio.write("""
        export PATH=/home/ubuntu/bin:/home/ubuntu/.local/bin:$PATH
    """)
    sio.write("""
        export PATH=/home/ubuntu/miniconda3/bin:/usr/local/cuda/bin:$PATH
    """)
    sio.write("""
        export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco200/bin
    """)
    sio.write("""
        echo $PATH
    """)
    sio.write("""
        export AWS_DEFAULT_REGION={aws_region}
    """.format(aws_region=config.AWS_BUCKET_REGION_NAME))  # add AWS_BUCKET_REGION_NAME=us-east-1 in your config.py
    sio.write("""
        pip install --upgrade --user awscli
    """)

    sio.write("""
        aws ec2 create-tags --resources $EC2_INSTANCE_ID --tags Key=Name,Value={exp_name} --region {aws_region}
    """.format(exp_name=params_list[0].get("exp_name"), aws_region=config.AWS_REGION_NAME))
    if config.LABEL:
        sio.write("""
            aws ec2 create-tags --resources $EC2_INSTANCE_ID --tags Key=owner,Value={label} --region {aws_region}
        """.format(label=config.LABEL, aws_region=config.AWS_REGION_NAME))
    sio.write("""
        aws ec2 create-tags --resources $EC2_INSTANCE_ID --tags Key=exp_prefix,Value={exp_prefix} --region {aws_region}
    """.format(exp_prefix=exp_prefix, aws_region=config.AWS_REGION_NAME))

    if config.FAST_CODE_SYNC:
        # The code was uploaded as a single tarball: fetch and unpack it.
        sio.write("""
            aws s3 cp {code_full_path} /tmp/chester_code.tar.gz
        """.format(code_full_path=code_full_path))
        sio.write("""
            mkdir -p {local_code_path}
        """.format(local_code_path=config.CODE_DIR))
        sio.write("""
            tar -zxvf /tmp/chester_code.tar.gz -C {local_code_path}
        """.format(local_code_path=config.CODE_DIR))
    else:
        sio.write("""
            aws s3 cp --recursive {code_full_path} {local_code_path}
        """.format(code_full_path=code_full_path, local_code_path=config.CODE_DIR))
    s3_mujoco_key_path = config.AWS_CODE_SYNC_S3_PATH + '/.mujoco/'
    sio.write("""
        aws s3 cp --recursive {} {}
    """.format(s3_mujoco_key_path, config.MUJOCO_KEY_PATH))
    sio.write("""
        cd {local_code_path}
    """.format(local_code_path=config.CODE_DIR))

    for params in params_list:
        log_dir = params.get("log_dir")
        # Popped so they are not turned into command-line arguments below.
        remote_log_dir = params.pop("remote_log_dir")
        params.pop("env", None)

        sio.write("""
            aws ec2 create-tags --resources $EC2_INSTANCE_ID --tags Key=Name,Value={exp_name} --region {aws_region}
        """.format(exp_name=params.get("exp_name"), aws_region=config.AWS_REGION_NAME))
        sio.write("""
            mkdir -p {log_dir}
        """.format(log_dir=log_dir))
        if periodic_sync:
            include_png = " --include '*.png' " if sync_s3_png else " "
            include_pkl = " --include '*.pkl' " if sync_s3_pkl else " "
            include_log = " --include '*.log' " if sync_s3_log else " "
            include_html = " --include '*.html' " if sync_s3_html else " "
            include_mp4 = " --include '*.mp4' " if sync_s3_mp4 else " "
            include_gif = " --include '*.gif' " if sync_s3_gif else " "
            include_pth = " --include '*.pth' " if sync_s3_pth else " "
            include_txt = " --include '*.txt' " if sync_s3_txt else " "
            # The snippet ends with a newline so that whatever is written next
            # (possibly the experiment command itself) starts on its own line
            # instead of being appended to "echo sync initiated".
            sio.write("""
                while /bin/true; do
                    aws s3 sync --exclude '*' {include_png} {include_pkl} {include_log} {include_html} {include_mp4} {include_gif} {include_pth} {include_txt} --include '*.csv' --include '*.json' {log_dir} {remote_log_dir}
                    sleep {periodic_sync_interval}
                done & echo sync initiated
            """.format(include_png=include_png, include_pkl=include_pkl,
                       include_log=include_log, include_html=include_html,
                       include_mp4=include_mp4, include_gif=include_gif,
                       include_pth=include_pth, include_txt=include_txt,
                       log_dir=log_dir, remote_log_dir=remote_log_dir,
                       periodic_sync_interval=periodic_sync_interval))
            # NOTE(review): the termination hook is installed only together
            # with the periodic sync — confirm this nesting against the
            # original file (both options default to True).
            if sync_log_on_termination:
                # Poll the spot termination notice; the endpoint answers 404
                # until the instance is marked for termination. The status
                # code is the second space-separated field of the status line.
                sio.write("""
                    while /bin/true; do
                        if [ -z $(curl -Is http://169.254.169.254/latest/meta-data/spot/termination-time | head -1 | grep 404 | cut -d ' ' -f 2) ]
                        then
                            logger "Running shutdown hook."
                            aws s3 cp /home/ubuntu/user_data.log {remote_log_dir}/stdout.log
                            aws s3 cp --recursive {log_dir} {remote_log_dir}
                            break
                        else
                            # Spot instance not yet marked for termination.
                            sleep 5
                        fi
                    done & echo log sync initiated
                """.format(log_dir=log_dir, remote_log_dir=remote_log_dir))
        sio.write(to_local_command(params, python_command=python_command, script=script, use_gpu=use_gpu))
        # Final upload of the experiment output and of the instance log.
        sio.write("""
            aws s3 cp --recursive {log_dir} {remote_log_dir}
        """.format(log_dir=log_dir, remote_log_dir=remote_log_dir))
        sio.write("""
            aws s3 cp /home/ubuntu/user_data.log {remote_log_dir}/stdout.log
        """.format(remote_log_dir=remote_log_dir))

    if terminate_machine:
        sio.write("""
            EC2_INSTANCE_ID="`wget -q -O - http://169.254.169.254/latest/meta-data/instance-id || die \"wget instance-id has failed: $?\"`"
            aws ec2 terminate-instances --instance-ids $EC2_INSTANCE_ID --region {aws_region}
        """.format(aws_region=config.AWS_REGION_NAME))
    sio.write("} >> /home/ubuntu/user_data.log 2>&1\n")

    full_script = dedent(sio.getvalue())

    import pprint

    import boto3
    import botocore
    # Spot requests go through the low-level client, on-demand instances
    # through the resource API (which provides create_instances).
    if aws_config["spot"]:
        ec2 = boto3.client(
            "ec2",
            region_name=config.AWS_REGION_NAME,
            aws_access_key_id=config.AWS_ACCESS_KEY,
            aws_secret_access_key=config.AWS_ACCESS_SECRET,
        )
    else:
        ec2 = boto3.resource(
            "ec2",
            region_name=config.AWS_REGION_NAME,
            aws_access_key_id=config.AWS_ACCESS_KEY,
            aws_secret_access_key=config.AWS_ACCESS_SECRET,
        )

    print("len_full_script", len(full_script))
    if len(full_script) > 16384 or len(base64.b64encode(full_script.encode()).decode("utf-8")) > 16384:
        # Script too long; need to upload script to s3 first.
        # We're being conservative here since the actual limit is 16384 bytes
        s3_path = upload_file_to_s3(full_script)
        sio = StringIO()
        sio.write("#!/bin/bash\n")
        sio.write("""
            aws s3 cp {s3_path} /home/ubuntu/remote_script.sh --region {aws_region} && \\
            chmod +x /home/ubuntu/remote_script.sh && \\
            bash /home/ubuntu/remote_script.sh
        """.format(s3_path=s3_path, aws_region=config.AWS_REGION_NAME))
        user_data = dedent(sio.getvalue())
    else:
        user_data = full_script
    print(full_script)
    with open("/tmp/full_script", "w") as f:
        f.write(full_script)

    instance_args = dict(
        ImageId=aws_config["image_id"],
        KeyName=aws_config["key_name"],
        UserData=user_data,
        InstanceType=aws_config["instance_type"],
        EbsOptimized=config.EBS_OPTIMIZED,
        SecurityGroups=aws_config["security_groups"],
        SecurityGroupIds=aws_config["security_group_ids"],
        NetworkInterfaces=aws_config["network_interfaces"],
        IamInstanceProfile=dict(
            Name=aws_config["iam_instance_profile_name"],
        ),
        **config.AWS_EXTRA_CONFIGS,
    )

    if len(instance_args["NetworkInterfaces"]) > 0:
        # disable_security_group = query_yes_no(
        #     "Cannot provide both network interfaces and security groups info. Do you want to disable security group settings?",
        #     default="yes",
        # )
        disable_security_group = True
        if disable_security_group:
            instance_args.pop("SecurityGroups")
            instance_args.pop("SecurityGroupIds")

    if aws_config.get("placement", None) is not None:
        instance_args["Placement"] = aws_config["placement"]
    if not aws_config["spot"]:
        instance_args["MinCount"] = 1
        instance_args["MaxCount"] = 1
    print("************************************************************")
    print(instance_args["UserData"])
    print("************************************************************")
    if aws_config["spot"]:
        # The launch specification of a spot request carries the user data
        # base64-encoded.
        instance_args["UserData"] = base64.b64encode(instance_args["UserData"].encode()).decode("utf-8")
        spot_args = dict(
            DryRun=dry,
            InstanceCount=1,
            LaunchSpecification=instance_args,
            SpotPrice=aws_config["spot_price"],
            # ClientToken=params_list[0]["exp_name"],
        )
        pprint.pprint(spot_args)
        if not dry:
            response = ec2.request_spot_instances(**spot_args)
            print(response)
            spot_request_id = response['SpotInstanceRequests'][
                0]['SpotInstanceRequestId']
            # Tagging can fail right after the request is created; retry a
            # bounded number of times and give up silently afterwards.
            for _ in range(10):
                try:
                    ec2.create_tags(
                        Resources=[spot_request_id],
                        Tags=[
                            {'Key': 'Name', 'Value': params_list[0]["exp_name"]}
                        ],
                    )
                    break
                except botocore.exceptions.ClientError:
                    continue
    else:
        pprint.pprint(instance_args)
        # On-demand launch: the EC2 resource method is create_instances
        # (the previous `ec2.q` does not exist and raised AttributeError).
        ec2.create_instances(
            DryRun=dry,
            **instance_args
        )
435 |
--------------------------------------------------------------------------------