├── chester
│   ├── rsync_include
│   ├── scripts
│   │   ├── install_mpi4py.sh
│   │   └── install_miniconda.sh
│   ├── config_empty.py
│   ├── rsync_exclude
│   ├── examples
│   │   ├── presets.py
│   │   ├── presets3.py
│   │   ├── presets2.py
│   │   ├── train_launch.py
│   │   ├── presets_tiancheng.py
│   │   ├── train.py
│   │   ├── pgm_plot.py
│   │   └── cplot_example.py
│   ├── pull_result.py
│   ├── video_recorder.py
│   ├── pull_s3_result.py
│   ├── add_variants.py
│   ├── config_private.py
│   ├── containers
│   │   └── ubuntu-16.04-lts-rl.README
│   ├── config.py
│   ├── config_ec2.py
│   ├── slurm.py
│   ├── run_exp_worker.py
│   ├── plotting
│   │   └── cplot.py
│   ├── setup_ec2_for_chester.py
│   ├── logger.py
│   └── utils_s3.py
├── bc
│   ├── 100_tool_pts.pkl
│   ├── logger.py
│   ├── encoder.py
│   ├── pointnet2_classification.py
│   ├── rotations.py
│   └── se3.py
├── initializers
│   ├── __init__.py
│   ├── random.py
│   ├── utils.py
│   ├── base.py
│   └── mm.py
├── rpmg
│   ├── README.md
│   ├── rpmg_losses.py
│   └── rpmg.py
├── compile_1.0.sh
├── prepare_1.0.sh
├── compile.sh
├── LICENSE
├── .gitignore
├── README.md
└── tests
    ├── test_pt3d_pyquaternion.ipynb
    └── tool_reduce_test.ipynb
/chester/rsync_include: -------------------------------------------------------------------------------- 1 | softgym/PyFlexRobotics/data -------------------------------------------------------------------------------- /bc/100_tool_pts.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielTakeshi/softagent_tfn/HEAD/bc/100_tool_pts.pkl -------------------------------------------------------------------------------- /initializers/__init__.py: -------------------------------------------------------------------------------- 1 | from initializers.base import InitializerBase 2 | from initializers.utils import load_initializer, initialize_env -------------------------------------------------------------------------------- /rpmg/README.md: -------------------------------------------------------------------------------- 1 | # RPMG (Rotation Projective Manifold Gradient) 2 | 3 | See https://github.com/JYChen18/RPMG which gives us: 4 | 5 | - `rpmg.py` (this has the layer) 6 | - `tools.py` -------------------------------------------------------------------------------- /chester/scripts/install_mpi4py.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Install mpi4py 3.0.0" 3 | cd /tmp 4 | wget https://bitbucket.org/mpi4py/mpi4py/downloads/mpi4py-3.0.0.tar.gz 5 | tar -zxf mpi4py-3.0.0.tar.gz 6 | cd mpi4py-3.0.0 7 | python setup.py build --mpicc=/usr/local/bin/mpicc 8 | python setup.py install --user -------------------------------------------------------------------------------- /compile_1.0.sh: -------------------------------------------------------------------------------- 1 | cd ${PYFLEXROOT}/bindings 2 | rm -rf build 3 | mkdir build 4 | cd build 5 | # Seuss 6 | if [[ $(hostname) = *"compute-0"* ]] || [[ $(hostname) = *"autobot-"* ]] || [[ $(hostname) = *"yertle"* ]]; then 7 | export CUDA_BIN_PATH=/usr/local/cuda-9.1 8 | fi 9 | cmake -DPYBIND11_PYTHON_VERSION=3.6 .. 10 | make -j 11 | -------------------------------------------------------------------------------- /prepare_1.0.sh: -------------------------------------------------------------------------------- 1 | PATH=~/miniconda3/bin:$PATH 2 | cd softgym 3 | . prepare_1.0.sh 4 | cd ..
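5 | # Notes (added): the exports below assume this script is sourced from the repo root. 6 | # PYTORCH_JIT=0 disables TorchScript JIT compilation; PYFLEXROOT must point at softgym's bundled PyFlex. 7 | # EGL_GPU presumably steers softgym's headless EGL rendering to the GPU(s) in CUDA_VISIBLE_DEVICES.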
8 | export PYTORCH_JIT=0 9 | export PYFLEXROOT=${PWD}/softgym/PyFlex 10 | export PYTHONPATH=${PWD}:${PWD}/softgym:${PYFLEXROOT}/bindings/build:$PYTHONPATH 11 | export LD_LIBRARY_PATH=${PYFLEXROOT}/external/SDL2-2.0.4/lib/x64:$LD_LIBRARY_PATH 12 | export EGL_GPU=$CUDA_VISIBLE_DEVICES 13 | -------------------------------------------------------------------------------- /compile.sh: -------------------------------------------------------------------------------- 1 | cd ${PYFLEXROOT}/bindings 2 | rm -rf build 3 | mkdir build 4 | cd build 5 | # Seuss 6 | if [[ $(hostname) = *"compute-0"* ]] || [[ $(hostname) = *"autobot-"* ]] || [[ $(hostname) = *"yertle"* ]]; then 7 | export CUDA_BIN_PATH=/usr/local/cuda-9.1 8 | fi 9 | 10 | if [[ $(hostname) = *"Xingyu-"* ]]; then 11 | export CUDA_BIN_PATH=/usr/local/cuda-9.2 12 | fi 13 | cmake -DPYBIND11_PYTHON_VERSION=3.6 .. 14 | make -j 15 | -------------------------------------------------------------------------------- /initializers/random.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from initializers import InitializerBase 4 | 5 | class RandomInitializer(InitializerBase): 6 | def __init__(self, env, args): 7 | self.env = env 8 | 9 | def get_action(self, obs, info=None): 10 | action = self.env.action_space.sample() 11 | 12 | # done with probability 0.5 13 | done = random.random() < 0.5 14 | 15 | return action, done 16 | -------------------------------------------------------------------------------- /chester/scripts/install_miniconda.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Check this file before using 3 | CONDA_INSTALL_PATH="$HOME/software/miniconda3" 4 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh 5 | chmod +x Miniconda3-latest-Linux-x86_64.sh 6 | ./Miniconda3-latest-Linux-x86_64.sh -b -p $CONDA_INSTALL_PATH 7 | if [ -d $CONDA_INSTALL_PATH/bin ]; then 8 | PATH=$PATH:$HOME/bin 9 | fi 10 | echo 'PATH='$CONDA_INSTALL_PATH'/bin:$PATH' >> ~/.bashrc 11 | rm ./Miniconda3-latest-Linux-x86_64.sh -------------------------------------------------------------------------------- /chester/config_empty.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | 4 | # TODO change this before make it into a pip package 5 | PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..')) 6 | 7 | LOG_DIR = os.path.join(PROJECT_PATH, "data") 8 | 9 | # Make sure to use absolute path 10 | REMOTE_DIR = { 11 | } 12 | 13 | REMOTE_MOUNT_OPTION = { 14 | } 15 | 16 | REMOTE_LOG_DIR = { 17 | } 18 | 19 | REMOTE_HEADER = dict() 20 | 21 | # location of the singularity file related to the project 22 | SIMG_DIR = { 23 | } 24 | CUDA_MODULE = { 25 | } 26 | MODULES = { 27 | } -------------------------------------------------------------------------------- /chester/rsync_exclude: -------------------------------------------------------------------------------- 1 | *.img 2 | datasets 3 | data 4 | data/yufei_s3_data 5 | data/yufei_seuss_data 6 | data/local 7 | data/icml 8 | *__pycache__* 9 | build 10 | wandb 11 | .idea 12 | .git 13 | DPI-Net 14 | videos 15 | imgs 16 | data_demo/ 17 | *.swp 18 | *.swo 19 | *.png 20 | cloth_manipulation 21 | GNS 22 | PDDM 23 | pouring 24 | rlkit 25 | rllab 26 | rlpyt_cloth 27 | VCD 28 | softgym/dummy_data 29 | softgym/PyFlexRobotics/data/PR2/ 30 | softgym/PyFlexRobotics/data/atlas_description/ 31 |
softgym/PyFlexRobotics/data/baxter_common-master/ 32 | softgym/PyFlexRobotics/data/dex-net/ 33 | softgym/PyFlexRobotics/data/yumi_description/ 34 | softgym/PyFlexRobotics/data/franka_description/ 35 | softgym/PyFlexRobotics/data/sawyer/ 36 | softgym/PyFlexRobotics/data/jaco/meshes/ 37 | softgym/PyFlexRobotics/data/kuka_iiwa/ 38 | softgym/PyFlexRobotics/data/sektion_cabinet_model/ -------------------------------------------------------------------------------- /chester/examples/presets.py: -------------------------------------------------------------------------------- 1 | preset_names = ['default'] 2 | 3 | def make_custom_seris_splitter(preset_names): 4 | legendNote = None 5 | if preset_names == 'default': 6 | def custom_series_splitter(x): 7 | params = x['flat_params'] 8 | # return params['her_replay_strategy'] 9 | if params['her_replay_strategy'] == 'future': 10 | ret = 'RG' 11 | elif params['her_replay_strategy'] == 'only_fake': 12 | if params['her_use_reward']: 13 | ret = 'FG+RR' 14 | else: 15 | ret = 'FG+FR' 16 | return ret + '+' + str(params['her_clip_len']) + '+' + str(params['her_reward_choices']) + '+' + str( 17 | params['her_failed_goal_ratio']) 18 | 19 | legendNote = "Fake Goal(FG)/Real Goal(RG) + Fake Reward(FR)/Real Reward(RR) + HER_clip_len + HER_reward_choices + HER_failed_goal_ratio" 20 | else: 21 | raise NotImplementedError 22 | return custom_series_splitter, legendNote 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Carnegie Mellon University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /initializers/utils.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 3 | def load_initializer(initializer_cls_string): 4 | """ 5 | Loads an initializer class from a string.
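For example (added; the class path below comes from this repo's own initializers package), load_initializer('initializers.random.RandomInitializer') returns the RandomInitializer class, ready to be constructed with (env, args).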
6 | """ 7 | initializer_module_name, initializer_class_name = initializer_cls_string.rsplit('.', 1) 8 | initializer_module = import_module(initializer_module_name) 9 | initializer_cls = getattr(initializer_module, initializer_class_name) 10 | return initializer_cls 11 | 12 | def initialize_env(env, initializer, initial_obs): 13 | """ 14 | Fully initializes an environment using an initializer 15 | 16 | Args: 17 | env: the environment to initialize 18 | initializer: the initializer to use 19 | initial_obs: the initial observation 20 | """ 21 | initializer.reset() 22 | action, done = initializer.get_action(initial_obs) 23 | obs, _, env_done, info = env.step(action) 24 | 25 | while not done: 26 | if env_done: 27 | raise RuntimeError("Environment is done before initializer is done!") 28 | 29 | action, done = initializer.get_action(obs, info) 30 | obs, _, env_done, info = env.step(action) 31 | 32 | return obs, env_done, info 33 | -------------------------------------------------------------------------------- /chester/examples/presets3.py: -------------------------------------------------------------------------------- 1 | set1 = 'identity_ratio+her_clip(dist_behavior and HER)' 2 | preset_names = [set1] 3 | FILTERED = 'filtered' 4 | 5 | 6 | def make_custom_seris_splitter(preset_names): 7 | legendNote = None 8 | if preset_names == set1: 9 | def custom_series_splitter(x): 10 | params = x['flat_params'] 11 | if params['her_failed_goal_option'] in ['dist_G', 'dist_policy']: 12 | return FILTERED 13 | if params['her_identity_ratio'] is not None: 14 | return 'IR: ' + str(params['her_identity_ratio']) 15 | if params['her_clip_len'] is not None: 16 | return 'CL: ' + str(params['her_clip_len']) 17 | return 'HER' 18 | 19 | legendNote = 'IR: identity ratio; CL: clip length' 20 | else: 21 | raise NotImplementedError 22 | return custom_series_splitter, legendNote 23 | 24 | 25 | def make_custom_filter(preset_names): 26 | if preset_names == set1: 27 | custom_seris_splitter, _ = make_custom_seris_splitter(preset_names) 28 | 29 | def custom_filter(x): 30 | legend = custom_seris_splitter(x) 31 | if legend == FILTERED: 32 | return False 33 | else: 34 | return True 35 | return custom_filter 36 | 37 | -------------------------------------------------------------------------------- /chester/examples/presets2.py: -------------------------------------------------------------------------------- 1 | preset_names = ['default'] 2 | x_axis = 'Epoch' 3 | y_axis = 'Success' 4 | FILTERED = 'filtered' 5 | 6 | def make_custom_seris_splitter(preset_names): 7 | legendNote = None 8 | if preset_names == 'default': 9 | def custom_series_splitter(x): 10 | params = x['flat_params'] 11 | if params['her_failed_goal_option'] is None: 12 | ret = 'Distance Reward' 13 | elif params['her_failed_goal_option'] == 'dist_behaviour': 14 | ret = 'Exact Match' 15 | else: 16 | ret = FILTERED 17 | return ret 18 | 19 | legendNote = None 20 | else: 21 | raise NotImplementedError 22 | return custom_series_splitter, legendNote 23 | 24 | 25 | def make_custom_filter(preset_names): 26 | if preset_names == 'default': 27 | custom_seris_splitter, _ = make_custom_seris_splitter(preset_names) 28 | def custom_filter(x): 29 | legend = custom_seris_splitter(x) 30 | if legend == FILTERED: 31 | return False 32 | else: 33 | return True 34 | # params = x['flat_params'] 35 | # if params['her_failed_goal_option'] != FILTERED: 36 | # return True 37 | # else: 38 | # return False 39 | return custom_filter 40 | 41 | 
-------------------------------------------------------------------------------- /chester/pull_result.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | 5 | sys.path.append('.') 6 | from chester import config 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('host', type=str) 11 | parser.add_argument('folder', type=str) 12 | parser.add_argument('--dry', action='store_true', default=False) 13 | parser.add_argument('--bare', action='store_true', default=False) 14 | args = parser.parse_args() 15 | 16 | args.folder = args.folder.rstrip('/') 17 | if args.folder.rfind('/') !=-1: 18 | local_dir = os.path.join('./data', args.host, args.folder[:args.folder.rfind('/')]) 19 | else: 20 | local_dir = os.path.join('./data', args.host, args.folder) 21 | remote_data_dir = os.path.join(config.REMOTE_DIR[args.host], 'data', 'local', args.folder) 22 | command = """rsync -avzh --delete --progress {host}:{remote_data_dir} {local_dir}""".format(host=args.host, 23 | remote_data_dir=remote_data_dir, 24 | local_dir=local_dir) 25 | if args.bare: 26 | command += """ --exclude '*.pkl' --exclude '*.png' --exclude '*.gif' --exclude '*.pth' --exclude '*.pt' --include '*.csv' --include '*.json' --delete""" 27 | if args.dry: 28 | print(command) 29 | else: 30 | os.system(command) 31 | -------------------------------------------------------------------------------- /initializers/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class InitializerBase(ABC): 4 | """ 5 | Base class for initializers. 6 | """ 7 | @abstractmethod 8 | def __init__(self, env, args): 9 | """ 10 | Initialize the initializer. Variant `args` are passed in for convenience. 11 | 12 | Args: 13 | env: Environment object 14 | args: Argument class 15 | """ 16 | pass 17 | 18 | def reset(self): 19 | """ 20 | (Optional) Called when environment is reset. 21 | Useful for stateful initializers that need to re-init 22 | themselves when a new environment is created 23 | """ 24 | pass 25 | 26 | @abstractmethod 27 | def get_action(self, obs, info=None): 28 | """ 29 | Get action from the initializer. 30 | 31 | NOTE: `info` is NOT AVAILABLE on the first environment step! 32 | If `info` is always needed, just sample something arbitrary 33 | on the first step. Can also use `env` information saved from 34 | __init__. 35 | 36 | However, good practice is to only use information from `obs`, as 37 | this is what the initializer will have access to in real. 
38 | 39 | Args: 40 | obs: observation from the environment 41 | info: info from the environment 42 | 43 | Returns: 44 | action: action to be taken 45 | done: True if initializer is complete 46 | """ 47 | pass 48 | -------------------------------------------------------------------------------- /chester/examples/train_launch.py: -------------------------------------------------------------------------------- 1 | import time 2 | from chester.run_exp import run_experiment_lite, VariantGenerator 3 | 4 | if __name__ == '__main__': 5 | 6 | # Here's an example for doing grid search of openai's DDPG 7 | # on HalfCheetah 8 | 9 | # the experiment folder name 10 | # the directory is defined as /LOG_DIR/data/local/exp_prefix/, where LOG_DIR is defined in config.py 11 | exp_prefix = 'test-ddpg' 12 | vg = VariantGenerator() 13 | vg.add('env_id', ['HalfCheetah-v2', 'Hopper-v2', 'InvertedPendulum-v2']) 14 | 15 | # select random seeds from 0 to 4 16 | vg.add('seed', [0, 1, 2, 3, 4]) 17 | print('Number of configurations: ', len(vg.variants())) 18 | 19 | # set the maximum number for running experiments in parallel 20 | # this number depends on the number of processors in the runner 21 | maximum_launching_process = 5 22 | 23 | # launch experiments 24 | sub_process_popens = [] 25 | for vv in vg.variants(): 26 | while len(sub_process_popens) >= maximum_launching_process: 27 | sub_process_popens = [x for x in sub_process_popens if x.poll() is None] 28 | time.sleep(10) 29 | 30 | # import the launcher of experiments 31 | from chester.examples.train import run_task 32 | 33 | # use your written run_task function 34 | cur_popen = run_experiment_lite( 35 | stub_method_call=run_task, 36 | variant=vv, 37 | mode='local', 38 | exp_prefix=exp_prefix, 39 | wait_subprocess=False 40 | ) 41 | if cur_popen is not None: 42 | sub_process_popens.append(cur_popen) 43 | -------------------------------------------------------------------------------- /chester/video_recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glfw 3 | from multiprocessing import Process, Queue 4 | 5 | import cv2 as cv 6 | 7 | class VideoRecorder(object): 8 | ''' 9 | Used to record videos for mujoco_py environment 10 | ''' 11 | def __init__(self, env, saved_path='./data/videos/', saved_name='temp'): 12 | # Get rid of the gym wrappers 13 | if hasattr(env, 'env'): 14 | env = env.env 15 | self.viewer = env._get_viewer() 16 | self.saved_path = saved_path 17 | self.saved_name = saved_name 18 | # self._set_filepath('/tmp/temp%07d.mp4') 19 | saved_name += '.mp4' 20 | self._set_filepath(os.path.join(saved_path, saved_name)) 21 | 22 | def _set_filepath(self, video_name): 23 | self.viewer._video_path = video_name 24 | 25 | def start(self): 26 | self.viewer._record_video = True 27 | if self.viewer._record_video: 28 | fps = (1 / self.viewer._time_per_render) 29 | self.viewer._video_process = Process(target=save_video, 30 | args=(self.viewer._video_queue, 31 | self.viewer._video_path, fps)) 32 | self.viewer._video_process.start() 33 | 34 | def end(self): 35 | self.viewer.key_callback(None, glfw.KEY_V, None, glfw.RELEASE, None) 36 | 37 | # class VideoRecorderDM(object): 38 | # ''' 39 | # Used to record videos for dm_control based environments 40 | # ''' 41 | # def __init__(self, env, saved_path='./data/videos/', saved_name='temp'): 42 | # self.saved_path = saved_path 43 | # self.saved_name = saved_name 44 | # 45 | # def -------------------------------------------------------------------------------- 
/chester/pull_s3_result.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import argparse 4 | 5 | def aws_sync(bucket_name, s3_log_dir, target_dir, args): 6 | cmd = 'aws s3 cp --recursive s3://%s/%s %s' % (bucket_name, s3_log_dir, target_dir) 7 | exlus = ['"*.pkl"', '"*.gif"', '"*.png"', '"*.pth"'] 8 | inclus = [] 9 | if args.gif: 10 | exlus.remove('"*.gif"') 11 | if args.png: 12 | exlus.remove('"*.png"') 13 | if args.param: 14 | inclus.append('"params.pkl"') 15 | exlus.remove('"*.pkl"') 16 | 17 | if not args.include_all: 18 | for exc in exlus: 19 | cmd += ' --exclude ' + exc 20 | 21 | for inc in inclus: 22 | cmd += ' --include ' + inc 23 | 24 | print(cmd) 25 | # exit() 26 | subprocess.call(cmd, shell=True) 27 | 28 | 29 | def main(): 30 | parser = argparse.ArgumentParser(description='Pull experiment results from S3.') 31 | parser.add_argument('log_dir', type=str, help='S3 log dir') 32 | parser.add_argument('-b', '--bucket', type=str, default='chester-softgym', help='S3 bucket') 33 | parser.add_argument('--param', type=int, default=0, help='Also pull params.pkl') 34 | parser.add_argument('--gif', type=int, default=0, help='Also pull *.gif files') 35 | parser.add_argument('--png', type=int, default=0, help='Also pull *.png files') 36 | parser.add_argument('--include_all', type=int, default=1, help='Pull all data, ignoring the exclude list') 37 | 38 | args = parser.parse_args() 39 | s3_log_dir = "rllab/experiments/" + args.log_dir 40 | local_dir = os.path.join('./data', 'corl_s3_data', args.log_dir) 41 | if not os.path.exists(local_dir): 42 | os.makedirs(local_dir) 43 | aws_sync(args.bucket, s3_log_dir, local_dir, args) 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /chester/examples/presets_tiancheng.py: -------------------------------------------------------------------------------- 1 | # Updated By Tiancheng Jin 08/28/2018 2 | 3 | # the preset file should be contained in the experiment folder ( which is assigned by exp_prefix ) 4 | # for example, this file should be put in /path to project/data/local/ 5 | 6 | # Here's an example for custom_series_splitter: 7 | # suppose we want to split five experiments with random seeds from 0 to 4 in two different ways: 8 | # * into even and odd random seeds: [0, 2, 4] and [1, 3] 9 | # * into smaller and larger random seeds: [0, 1, 2] and [3, 4] 10 | 11 | preset_names = ['even or odd', 'small or large'] 12 | 13 | 14 | def make_custom_seris_splitter(preset_name): 15 | legend_note = None 16 | custom_series_splitter = None 17 | 18 | if preset_name == 'even or odd': 19 | # build a custom series splitter for even or odd random seeds, 20 | # where the input is the data for one experiment ( contains both the results and the parameters ) 21 | def custom_series_splitter(x): 22 | # extract the parameters 23 | params = x['flat_params'] 24 | # make up the legend 25 | if params['seed'] % 2 == 0: 26 | legend = 'even seeds' 27 | else: 28 | legend = 'odd seeds' 29 | return legend 30 | 31 | legend_note = "Even or Odd" 32 | 33 | elif preset_name == 'small or large': 34 | def custom_series_splitter(x): 35 | params = x['flat_params'] 36 | if params['seed'] <= 2: 37 | legend = 'smaller seeds' 38 | else: 39 | legend = 'larger seeds' 40 | return legend 41 | 42 | legend_note = "Small or Large" 43 | else: 44 | raise NotImplementedError 45 | 46 | return custom_series_splitter, legend_note 47 |
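# Usage sketch (added): exercise the splitter on fake experiment records; the 48 | # dicts below are hypothetical and only follow the x['flat_params'] layout above. 49 | if __name__ == '__main__': 50 |     splitter, note = make_custom_seris_splitter('even or odd') 51 |     for seed in range(5): 52 |         print(seed, splitter({'flat_params': {'seed': seed}}))  # 0/2/4 -> 'even seeds'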
-------------------------------------------------------------------------------- /chester/add_variants.py: -------------------------------------------------------------------------------- 1 | # Add variants to finished experiments 2 | import argparse 3 | import os 4 | import json 5 | from pydoc import locate 6 | import config 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('exp_folder', type=str, help='Root of the experiment folder to walk through') 12 | parser.add_argument('key', type=str, help='Name of the additional key') 13 | parser.add_argument('value', help='Value of the additional key') 14 | parser.add_argument('value_type', default='str', type=str, help='Type of the additional key') 15 | parser.add_argument('remote', nargs='?', default=None, type=str, ) # Optional 16 | 17 | args = parser.parse_args() 18 | exp_paths = [x[0] for x in os.walk(args.exp_folder, followlinks=True)] 19 | 20 | value_type = locate(args.value_type) 21 | if value_type == bool: 22 | value = args.value in ['1', 'True', 'true'] 23 | else: 24 | value = value_type(args.value) 25 | 26 | for exp_path in exp_paths: 27 | try: 28 | variant_path = os.path.join(exp_path, "variant.json") 29 | # Modify locally 30 | with open(variant_path, 'r') as f: 31 | vv = json.load(f) 32 | if args.key in vv: 33 | print('Warning: key already in variants. {} = {}. Setting it to {}'.format(args.key, vv[args.key], value)) 34 | 35 | vv[args.key] = value 36 | with open(variant_path, 'w') as f: 37 | json.dump(vv, f, indent=2, sort_keys=True) 38 | print('{} modified'.format(variant_path)) 39 | 40 | # Upload it to remote 41 | if args.remote is not None: 42 | p = variant_path.rstrip('/').split('/') 43 | sub_exp_name, exp_name = p[-2], p[-3] 44 | 45 | remote_dir = os.path.join(config.REMOTE_DIR[args.remote], 'data', 'local', exp_name, sub_exp_name, 'variant.json') 46 | os.system('scp {} {}:{}'.format(variant_path, args.remote, remote_dir)) 47 | except IOError as e: 48 | print(e) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /chester/config_private.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | 4 | # TODO change this before make it into a pip package 5 | PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..')) 6 | 7 | LOG_DIR = os.path.join(PROJECT_PATH, "data") 8 | 9 | # Make sure to use absolute path 10 | REMOTE_DIR = { 11 | 'seuss': '/home/yufeiw2/Projects/softagent', 12 | 'psc': '/home/yufeiw2/Projects/softagent', 13 | 'nsh': '/home/yufeiw2/Projects/softagent', 14 | 'yertle':'/home/yufeiw2/Projects/softagent' 15 | } 16 | 17 | REMOTE_MOUNT_OPTION = { 18 | 'seuss': '/usr/share/glvnd', 19 | # 'psc': '/pylon5/ir5fpfp/xlin3/Projects/baselines_hrl/:/mnt', 20 | } 21 | 22 | REMOTE_LOG_DIR = { 23 | 'seuss': os.path.join(REMOTE_DIR['seuss'], "data"), 24 | 25 | # 'psc': os.path.join(REMOTE_DIR['psc'], "data") 26 | 'psc': os.path.join('/mnt', "data"), 27 | } 28 | # PSC: https://www.psc.edu/bridges/user-guide/running-jobs 29 | # partition include [RM, RM-shared, LM, GPU] 30 | # TODO change cpu-per-task based on the actual cpus needed (on psc) 31 | # #SBATCH --exclude=compute-0-[7,11] 32 | # Adding this will make the job to grab the whole gpu. 
#SBATCH --gres=gpu:1 33 | REMOTE_HEADER = dict(seuss=""" 34 | #!/usr/bin/env bash 35 | #SBATCH --nodes=1 36 | #SBATCH --partition=GPU 37 | #SBATCH --exclude=compute-0-[5,7,13,27] 38 | #SBATCH --ntasks-per-node=8 39 | #SBATCH --time=480:00:00 40 | #SBATCH --gres=gpu:1 41 | #SBATCH --mem=90G 42 | """.strip(), psc=""" 43 | #!/usr/bin/env bash 44 | #SBATCH --nodes=1 45 | #SBATCH --partition=RM 46 | #SBATCH --ntasks-per-node=18 47 | #SBATCH --time=48:00:00 48 | #SBATCH --mem=64G 49 | """.strip(), psc_gpu=""" 50 | #!/usr/bin/env bash 51 | #SBATCH --nodes=1 52 | #SBATCH --partition=GPU-shared 53 | #SBATCH --gres=gpu:p100:1 54 | #SBATCH --ntasks-per-node=4 55 | #SBATCH --time=48:00:00 56 | """.strip()) 57 | 58 | # location of the singularity file related to the project 59 | SIMG_DIR = { 60 | 'seuss': '/home/xlin3/softgym_containers/softgymcontainer_v3.simg', 61 | # 'psc': '$SCRATCH/containers/ubuntu-16.04-lts-rl.img', 62 | 'psc': '/pylon5/ir5fpfp/xlin3/containers/ubuntu-16.04-lts-rl.img', 63 | 64 | } 65 | CUDA_MODULE = { 66 | 'seuss': 'cuda-91', 67 | 'psc': 'cuda/9.0', 68 | } 69 | MODULES = { 70 | 'seuss': ['singularity'], 71 | 'psc': ['singularity'], 72 | } 73 | -------------------------------------------------------------------------------- /chester/containers/ubuntu-16.04-lts-rl.README: -------------------------------------------------------------------------------- 1 | Bootstrap: debootstrap 2 | OSVersion: xenial 3 | MirrorURL: http://us.archive.ubuntu.com/ubuntu/ 4 | 5 | %help 6 | This is a Singularity container that runs deep reinforcement learning algorithms on Ubuntu. 7 | Packages installed include: 8 | * cuda 9.0 and cuDNN 9 | Will run ~/.bashrc on start to make sure the PATH is the same. 10 | 11 | %runscript 12 | /usr/bin/nvidia-smi -L 13 | 14 | %environment 15 | LD_LIBRARY_PATH=/usr/local/cuda-9.0/cuda/lib64:/usr/local/cuda-9.0/lib64:/usr/lib/nvidia-384:$LD_LIBRARY_PATH 16 | 17 | %setup 18 | echo "Let us have CUDA..." 19 | sh /home/xingyu/software/cuda/cuda_9.0.176_384.81_linux.run --silent --toolkit --toolkitpath=${SINGULARITY_ROOTFS}/usr/local/cuda-9.0 20 | ln -s ${SINGULARITY_ROOTFS}/usr/local/cuda-9.0 ${SINGULARITY_ROOTFS}/usr/local/cuda 21 | echo "Let us also have cuDNN..."
22 | cp -prv /home/xingyu/software/cudnn/* ${SINGULARITY_ROOTFS}/usr/local/cuda-9.0/ 23 | 24 | %labels 25 | AUTHOR xlin3@cs.cmu.edu 26 | VERSION v1.0 27 | 28 | %post 29 | echo "Hello from inside the container" 30 | sed -i 's/$/ universe/' /etc/apt/sources.list 31 | touch /usr/bin/nvidia-smi 32 | chmod +x /usr/bin/nvidia-smi 33 | 34 | apt-get -y update 35 | apt-get -y install software-properties-common vim make wget curl emacs ffmpeg git htop libffi-dev libglew-dev libgl1-mesa-glx libosmesa6 libosmesa6-dev libssl-dev mesa-utils module-init-tools openjdk-8-jdk python-dev python-numpy python-tk bzip2 36 | apt-get -y install build-essential 37 | apt-get -y install libgl1-mesa-dev libglfw3-dev 38 | apt-get -y install strace 39 | 40 | echo "Install openmpi 3.1.1" 41 | 42 | cd /tmp 43 | wget https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.1.tar.gz 44 | tar xf openmpi-3.1.1.tar.gz 45 | cd openmpi-3.1.1 46 | mkdir -p build 47 | cd build 48 | ../configure 49 | make -j 8 all 50 | make install 51 | apt-get -y install openmpi-bin 52 | rm -rf /tmp/openmpi* 53 | rm -rf /usr/bin/mpirun 54 | ln -s /usr/local/bin/mpirun /usr/bin/mpirun 55 | 56 | echo "Install mpi4py 3.0.0" 57 | cd /tmp 58 | wget https://bitbucket.org/mpi4py/mpi4py/downloads/mpi4py-3.0.0.tar.gz 59 | tar -zxf mpi4py-3.0.0.tar.gz 60 | cd mpi4py-3.0.0 61 | python setup.py build --mpicc=/usr/local/bin/mpicc 62 | python setup.py install --user 63 | mkdir -p /usr/lib/nvidia-384 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Daniel Seita 2 | data/plots/ 3 | *.png 4 | *.swp 5 | *.swo 6 | test.sh 7 | tmp/ 8 | 9 | 10 | GNS/pcl_filter/build 11 | rlpyt_cloth/data 12 | datasets 13 | dpi_visualization 14 | *.gif 15 | .vscode 16 | chester/private 17 | videos 18 | .idea 19 | data 20 | *.simg 21 | *.img 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | .Python 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | pip-wheel-metadata/ 45 | share/python-wheels/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | MANIFEST 50 | 51 | # PyInstaller 52 | # Usually these files are written by a python script from a template 53 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
54 | *.manifest 55 | *.spec 56 | 57 | # Installer logs 58 | pip-log.txt 59 | pip-delete-this-directory.txt 60 | 61 | # Unit test / coverage reports 62 | htmlcov/ 63 | .tox/ 64 | .nox/ 65 | .coverage 66 | .coverage.* 67 | .cache 68 | nosetests.xml 69 | coverage.xml 70 | *.cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # celery beat schedule file 107 | celerybeat-schedule 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | 140 | # Results 141 | results/ 142 | 143 | # DPI data 144 | DPI-Net/dump_*/ 145 | 146 | test_*/ 147 | imgs/ 148 | softgym 149 | wandb 150 | -------------------------------------------------------------------------------- /chester/examples/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from baselines import logger 4 | from baselines.common.misc_util import ( 5 | set_global_seeds, 6 | ) 7 | from baselines.ddpg.main import run 8 | from mpi4py import MPI 9 | 10 | 11 | DEFAULT_PARAMS = { 12 | # env 13 | 'env_id': 'HalfCheetah-v2', # gym environment id 14 | 15 | # ddpg 16 | 'layer_norm': True, 17 | 'render': False, 18 | 'normalize_returns': False, 19 | 'normalize_observations': True, 20 | 'actor_lr': 0.0001, # actor learning rate 21 | 'critic_lr': 0.001, # critic learning rate 22 | 'critic_l2_reg': 1e-2, 23 | 'popart': False, 24 | 'gamma': 0.99, 25 | 26 | # training 27 | 'seed': 0, 28 | 'nb_epochs': 500, # number of epochs 29 | 'nb_epoch_cycles': 20, # per epoch 30 | 'nb_rollout_steps': 100, # sampling batches per cycle 31 | 'nb_train_steps': 100, # training batches per cycle 32 | 'batch_size': 64, # per mpi thread, measured in transitions and reduced to even multiple of chunk_length.
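33 | # Worked total (added): with the defaults in this dict, one run collects about 34 | # nb_epochs * nb_epoch_cycles * nb_rollout_steps = 500 * 20 * 100 35 | # = 1,000,000 environment transitions in total.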
36 | 'reward_scale': 1.0, 37 | 'clip_norm': None, 38 | 39 | # exploration 40 | 'noise_type': 'adaptive-param_0.2', 41 | 42 | # debugging, logging and visualization 43 | 'render_eval': False, 44 | 'nb_eval_steps': 100, 45 | 'evaluation': False, 46 | } 47 | 48 | 49 | def run_task(vv, log_dir=None, exp_name=None, allow_extra_parameters=False): 50 | # Configure logging system 51 | if log_dir or logger.get_dir() is None: 52 | logger.configure(dir=log_dir) 53 | logdir = logger.get_dir() 54 | assert logdir is not None 55 | os.makedirs(logdir, exist_ok=True) 56 | 57 | # Seed for multi-CPU MPI implementation ( rank = 0 for single threaded implementation ) 58 | rank = MPI.COMM_WORLD.Get_rank() 59 | rank_seed = vv['seed'] + 1000000 * rank 60 | set_global_seeds(rank_seed) 61 | 62 | # load params from config 63 | params = DEFAULT_PARAMS.copy() # copy so repeated calls don't mutate the module-level defaults 64 | 65 | # update all given parameters 66 | if not allow_extra_parameters: 67 | for k, v in vv.items(): 68 | if k not in DEFAULT_PARAMS: 69 | print("[ Warning ] Undefined parameter %s with value %s" % (str(k), str(v))) 70 | params.update(**{k: v for (k, v) in vv.items() if k in DEFAULT_PARAMS}) 71 | else: 72 | params.update(**{k: v for (k, v) in vv.items()}) 73 | 74 | with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f: 75 | json.dump(params, f) 76 | 77 | run(**params) 78 | 79 | 80 | -------------------------------------------------------------------------------- /chester/examples/pgm_plot.py: -------------------------------------------------------------------------------- 1 | from chester.plotting.cplot import * 2 | import os.path as osp 3 | from random import shuffle 4 | 5 | save_path = '../sac/data/plots' 6 | dict_leg2col = {"LSP": 0, "Base": 1, "Behavior": 2} 7 | dict_xshift = {"LSP": 4000, "Base": 0, "Behavior": 6000} 8 | 9 | 10 | def custom_series_splitter(x): 11 | params = x['flat_params'] 12 | exp_name = params['exp_name'] 13 | dict_mapping = {'humanoid-resume-training-6000-00': 'Behavior', 14 | 'humanoid-resume-training-4000-00': 'LSP', 15 | 'humanoid-rllab/default-2019-04-14-07-04-08-421230-UTC-00': 'Base'} 16 | return dict_mapping[exp_name] 17 | 18 | 19 | def sliding_mean(data_array, window=5): 20 | data_array = np.array(data_array) 21 | new_list = [] 22 | for i in range(len(data_array)): 23 | indices = list(range(max(i - window + 1, 0), 24 | min(i + window + 1, len(data_array)))) 25 | avg = 0 26 | for j in indices: 27 | avg += data_array[j] 28 | avg /= float(len(indices)) 29 | new_list.append(avg) 30 | 31 | return np.array(new_list) 32 | 33 | 34 | def plot_main(): 35 | data_path = '../sac/data/mengxiong' 36 | plot_key = 'return-average' 37 | exps_data, plottable_keys, distinct_params = reload_data(data_path) 38 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter) 39 | fig, ax = plt.subplots(figsize=(8, 5)) 40 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)): 41 | color = core.color_defaults[dict_leg2col[legend]] 42 | 43 | y, y_lower, y_upper = get_shaded_curve(selector, plot_key, shade_type='median') 44 | x = np.array(range(len(y))) 45 | x += dict_xshift[legend] 46 | y = sliding_mean(y, 5) 47 | ax.plot(x, y, color=color, label=legend, linewidth=2.0) 48 | 49 | # ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, 50 | # alpha=0.2) 51 | 52 | def y_fmt(x, y): 53 | return str(int(np.round(x))) + 'K' 54 | 55 | ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt)) 56 | ax.grid(True) 57 | ax.set_xlabel('Timesteps') 58 | ax.set_ylabel('Average-return')
59 | 60 | # plt.title(env_name.replace('Float', 'Push')) 61 | loc = 'best' 62 | leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends) 63 | for legobj in leg.legendHandles: 64 | legobj.set_linewidth(3.0) 65 | 66 | save_name = filter_save_name('plots.png') 67 | plt.savefig(osp.join(save_path, save_name), bbox_inches='tight') 68 | 69 | 70 | if __name__ == '__main__': 71 | plot_main() 72 | -------------------------------------------------------------------------------- /chester/config.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | 4 | # TODO change this before make it into a pip package 5 | PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..')) 6 | 7 | LOG_DIR = os.path.join(PROJECT_PATH, "data") 8 | 9 | # Make sure to use absolute path 10 | REMOTE_DIR = { 11 | 'seuss': '/home/dseita/softagent_rpad_MM', 12 | 'autobot': '/home/xlin3/Projects/softagent', 13 | 'psc': '/home/xlin3/Projects/softagent', 14 | 'nsh': '/home/xingyu/Projects/softagent', 15 | 'yertle': '/home/xingyu/Projects/softagent' 16 | } 17 | 18 | REMOTE_MOUNT_OPTION = { 19 | 'seuss': '/usr/share/glvnd', 20 | 'autobot': '/usr/share/glvnd', 21 | # 'psc': '/pylon5/ir5fpfp/xlin3/Projects/baselines_hrl/:/mnt', 22 | } 23 | 24 | REMOTE_LOG_DIR = { 25 | 'seuss': os.path.join(REMOTE_DIR['seuss'], "data"), 26 | 'autobot': os.path.join(REMOTE_DIR['autobot'], "data"), 27 | # 'psc': os.path.join(REMOTE_DIR['psc'], "data") 28 | 'psc': os.path.join('/mnt', "data"), 29 | } 30 | 31 | # PSC: https://www.psc.edu/bridges/user-guide/running-jobs 32 | # partition include [RM, RM-shared, LM, GPU] 33 | # TODO change cpu-per-task based on the actual cpus needed (on psc) 34 | # #SBATCH --exclude=compute-0-[7,11] 35 | # Adding this will make the job to grab the whole gpu. 
#SBATCH --gres=gpu:1 36 | REMOTE_HEADER = dict(seuss=""" 37 | #!/usr/bin/env bash 38 | #SBATCH --nodes=1 39 | #SBATCH --partition=GPU 40 | #SBATCH --exclude=compute-0-[7,9,27] 41 | #SBATCH --cpus-per-task=8 42 | #SBATCH --time=480:00:00 43 | #SBATCH --gres=gpu:1 44 | #SBATCH --mem=110G 45 | """.strip(), psc=""" 46 | #!/usr/bin/env bash 47 | #SBATCH --nodes=1 48 | #SBATCH --partition=RM 49 | #SBATCH --ntasks-per-node=18 50 | #SBATCH --time=48:00:00 51 | #SBATCH --mem=64G 52 | """.strip(), psc_gpu=""" 53 | #!/usr/bin/env bash 54 | #SBATCH --nodes=1 55 | #SBATCH --partition=GPU-shared 56 | #SBATCH --gres=gpu:p100:1 57 | #SBATCH --ntasks-per-node=4 58 | #SBATCH --time=48:00:00 59 | """.strip(), autobot=""" 60 | #!/usr/bin/env bash 61 | #SBATCH --nodes=1 62 | #SBATCH --partition=long 63 | #SBATCH --cpus-per-task=32 64 | #SBATCH --time=3-12:00:00 65 | #SBATCH --gres=gpu:1 66 | #SBATCH --mem=40G 67 | """.strip()) 68 | 69 | # location of the singularity file related to the project 70 | SIMG_DIR = { 71 | 'seuss': '/home/xlin3/softgym_containers/softgymcontainer_v4.simg', 72 | 'autobot': '/home/xlin3/softgym_containers/softgymcontainer_v3.simg', 73 | # 'psc': '$SCRATCH/containers/ubuntu-16.04-lts-rl.img', 74 | 'psc': '/pylon5/ir5fpfp/xlin3/containers/ubuntu-16.04-lts-rl.img', 75 | 76 | } 77 | CUDA_MODULE = { 78 | 'seuss': 'cuda-91', 79 | 'autobot': 'cuda-10.2', 80 | 'psc': 'cuda/9.0', 81 | } 82 | MODULES = { 83 | 'seuss': ['singularity'], 84 | 'autobot': ['singularity'], 85 | 'psc': ['singularity'], 86 | } 87 | -------------------------------------------------------------------------------- /chester/config_ec2.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | 4 | PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..')) 5 | 6 | AWS_NETWORK_INTERFACES = [] 7 | 8 | AWS_BUCKET_REGION_NAME = 'us-east-2' 9 | 10 | MUJOCO_KEY_PATH = osp.expanduser("~/.mujoco") 11 | 12 | USE_GPU = True 13 | 14 | USE_TF = True 15 | 16 | AWS_REGION_NAME = "us-east-2" 17 | 18 | if USE_GPU: 19 | DOCKER_IMAGE = "dementrock/rllab3-shared-gpu" 20 | else: 21 | DOCKER_IMAGE = "dementrock/rllab3-shared" 22 | 23 | DOCKER_LOG_DIR = "/tmp/expt" 24 | 25 | CODE_DIR = "/root/code/" 26 | 27 | AWS_S3_PATH = "s3://chester-softgym/rllab/experiments" 28 | 29 | EBS_OPTIMIZED = True 30 | 31 | AWS_EXTRA_CONFIGS = dict() 32 | 33 | AWS_CODE_SYNC_S3_PATH = "s3://chester-softgym/rllab/code" 34 | 35 | ALL_REGION_AWS_IMAGE_IDS = { 36 | # "ap-northeast-1": "ami-002f0167", 37 | # "ap-northeast-2": "ami-590bd937", 38 | # "ap-south-1": "ami-77314318", 39 | # "ap-southeast-1": "ami-1610a975", 40 | # "ap-southeast-2": "ami-9dd4ddfe", 41 | # "eu-central-1": "ami-63af720c", 42 | # "eu-west-1": "ami-41484f27", 43 | # "sa-east-1": "ami-b7234edb", 44 | "us-east-1": "ami-83f26195", 45 | "us-east-2": "ami-0ec385d5f98faacc3", #"ami-0e63a1a8842443350", 46 | "us-west-1": "ami-576f4b37", 47 | "us-west-2": "ami-b8b62bd8" 48 | } 49 | 50 | AWS_IMAGE_ID = ALL_REGION_AWS_IMAGE_IDS[AWS_REGION_NAME] 51 | 52 | if USE_GPU: 53 | AWS_INSTANCE_TYPE = "p2.xlarge" 54 | else: 55 | AWS_INSTANCE_TYPE = "c4.4xlarge" 56 | 57 | ALL_REGION_AWS_KEY_NAMES = { 58 | "us-east-1": "rllab-us-east-1", 59 | "us-east-2": "rllab-us-east-2", 60 | "us-west-1": "rllab-us-west-1", 61 | "us-west-2": "rllab-us-west-2" 62 | } 63 | 64 | AWS_KEY_NAME = ALL_REGION_AWS_KEY_NAMES[AWS_REGION_NAME] 65 | 66 | AWS_SPOT = True 67 | 68 | AWS_SPOT_PRICE = '2.0' 69 | 70 | AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None) 71 | 72 
| AWS_ACCESS_SECRET = os.environ.get("AWS_ACCESS_SECRET", None) 73 | 74 | AWS_IAM_INSTANCE_PROFILE_NAME = "rllab" 75 | 76 | AWS_SECURITY_GROUPS = ["rllab-sg"] 77 | 78 | ALL_REGION_AWS_SECURITY_GROUP_IDS = { 79 | "us-east-1": [ 80 | "sg-e17bfc95" 81 | ], 82 | "us-east-2": [ 83 | "sg-1ddb3876" 84 | ], 85 | "us-west-1": [ 86 | "sg-cd5f9db4" 87 | ], 88 | "us-west-2": [ 89 | "sg-b585a8c9" 90 | ] 91 | } 92 | 93 | AWS_SECURITY_GROUP_IDS = ALL_REGION_AWS_SECURITY_GROUP_IDS[AWS_REGION_NAME] 94 | 95 | FAST_CODE_SYNC_IGNORES = [ 96 | ".git", 97 | "data/autobot", 98 | "data/corl_s3_data", 99 | "data/videos", 100 | "data/open_loop_videos", 101 | "data/icml", 102 | "data/local", 103 | "data/seuss", 104 | "data/yufei_s3_data", 105 | "data/archive", 106 | "data/debug", 107 | "data/s3", 108 | "data/video", 109 | ".idea", 110 | "tests", 111 | "examples", 112 | "docs", 113 | ".DS_Store", 114 | ".ipynb_checkpoints", 115 | "blackbox", 116 | "blackbox.zip", 117 | "*.pyc", 118 | "*.ipynb", 119 | "scratch-notebooks", 120 | "conopt_root", 121 | "private/key_pairs", 122 | "DPI-Net", 123 | "imgs/", 124 | "imgs", 125 | "videos" 126 | ] 127 | 128 | FAST_CODE_SYNC = True 129 | 130 | LABEL = "" 131 | -------------------------------------------------------------------------------- /chester/slurm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import re 4 | from subprocess import run 5 | from tempfile import NamedTemporaryFile 6 | from chester import config 7 | 8 | # TODO remove the singularity part 9 | 10 | slurm_dir = './' 11 | 12 | 13 | def slurm_run_scripts(scripts): 14 | """this is another function that those _sub files should call. this actually executes the script""" 15 | # TODO support running multiple scripts 16 | 17 | assert isinstance(scripts, str) 18 | 19 | os.chdir(slurm_dir) 20 | 21 | # make sure it will run. 22 | assert scripts.startswith('#!/usr/bin/env bash\n') 23 | file_temp = NamedTemporaryFile(delete=False) 24 | file_temp.write(scripts.encode('utf-8')) 25 | file_temp.close() 26 | run(['sbatch', file_temp.name], check=True) 27 | os.remove(file_temp.name) 28 | 29 | 30 | _find_unsafe = re.compile(r'[^a-zA-Z0-9_@%+=:,./-]').search 31 | 32 | 33 | def _shellquote(s): 34 | """Return a shell-escaped version of the string *s*.""" 35 | if not s: 36 | return "''" 37 | 38 | if _find_unsafe(s) is None: 39 | return s 40 | 41 | # use single quotes, and put single quotes into double quotes 42 | # the string $'b is then quoted as '$'"'"'b' 43 | 44 | return "'" + s.replace("'", "'\"'\"'") + "'" 45 | 46 | 47 | def _to_param_val(v): 48 | if v is None: 49 | return "" 50 | elif isinstance(v, list): 51 | return " ".join(map(_shellquote, list(map(str, v)))) 52 | else: 53 | return _shellquote(str(v)) 54 | 55 | 56 | def to_slurm_command(params, header, python_command="python", remote_dir='~/', 57 | script=osp.join(config.PROJECT_PATH, 'scripts/run_experiment.py'), 58 | simg_dir=None, use_gpu=False, modules=None, cuda_module=None, use_singularity=True, 59 | mount_options=None, compile_script=None, wait_compile=None, set_egl_gpu=False): 60 | # TODO Add code for specifying the resource allocation 61 | # TODO Check if use_gpu can be applied 62 | """ 63 | Transfer the commands to the format that can be run by slurm.
64 | :param params: 65 | :param python_command: 66 | :param script: 67 | :param use_gpu: 68 | :return: 69 | """ 70 | assert simg_dir is not None 71 | command = python_command + " " + script 72 | 73 | pre_commands = params.pop("pre_commands", None) 74 | post_commands = params.pop("post_commands", None) 75 | 76 | command_list = list() 77 | command_list.append(header) 78 | 79 | # Log into singularity shell 80 | if use_singularity: 81 | command_list.append('set -x') # echo commands to stdout 82 | command_list.append('set -u') # throw an error if unset variable referenced 83 | command_list.append('set -e') # exit on errors 84 | command_list.append('srun hostname') 85 | 86 | for remote_module in modules: 87 | command_list.append('module load ' + remote_module) 88 | if use_gpu: 89 | assert cuda_module is not None 90 | command_list.append('module load ' + cuda_module) 91 | command_list.append('cd {}'.format(remote_dir)) 92 | # First execute a bash program inside the container and then run all the following commands 93 | 94 | if mount_options is not None: 95 | options = '-B ' + mount_options 96 | else: 97 | options = '' 98 | options += " -B /opt" 99 | sing_prefix = 'singularity exec {} {} {} /bin/bash -c'.format(options, '--nv' if use_gpu else '', simg_dir) 100 | sing_commands = list() 101 | if compile_script is None or 'prepare' not in compile_script : 102 | sing_commands.append('. ./prepare_1.0.sh') 103 | if set_egl_gpu: 104 | sing_commands.append('export EGL_GPU=$SLURM_JOB_GRES') 105 | sing_commands.append('echo $EGL_GPU') 106 | if compile_script is not None: 107 | sing_commands.append(compile_script) 108 | if wait_compile is not None: 109 | sing_commands.append('sleep '+str(int(wait_compile))) 110 | 111 | if pre_commands is not None: 112 | command_list.extend(pre_commands) 113 | for k, v in params.items(): 114 | if isinstance(v, dict): 115 | for nk, nv in v.items(): 116 | if str(nk) == "_name": 117 | command += " --%s %s" % (k, _to_param_val(nv)) 118 | else: 119 | command += " --%s_%s %s" % (k, nk, _to_param_val(nv)) 120 | else: 121 | command += " --%s %s" % (k, _to_param_val(v)) 122 | sing_commands.append(command) 123 | all_sing_cmds = ' && '.join(sing_commands) 124 | command_list.append(sing_prefix + ' \'{}\''.format(all_sing_cmds)) 125 | if post_commands is not None: 126 | command_list.extend(post_commands) 127 | return command_list 128 | 129 | # if __name__ == '__main__': 130 | # slurm_run_scripts(header) 131 | -------------------------------------------------------------------------------- /chester/examples/cplot_example.py: -------------------------------------------------------------------------------- 1 | from chester.plotting.cplot import * 2 | import os.path as osp 3 | 4 | 5 | def custom_series_splitter(x): 6 | params = x['flat_params'] 7 | if 'use_ae_reward' in params and params['use_ae_reward']: 8 | return 'Auto Encoder' 9 | if params['her_replay_strategy'] == 'balance_filter': 10 | return 'Indicator+Balance+Filter' 11 | if params['env_kwargs.use_true_reward']: 12 | return 'Oracle' 13 | return 'Indicator' 14 | 15 | 16 | dict_leg2col = {"Oracle": 1, "Indicator": 0, 'Indicator+Balance+Filter': 2, "Auto Encoder": 3} 17 | save_path = './data/plots_chester' 18 | 19 | 20 | def plot_visual_learning(): 21 | data_path = './data/nsh/submit_rss/submit_rss/visual_learning' 22 | 23 | plot_keys = ['test/success_state', 'test/goal_dist_final_state'] 24 | plot_ylabels = ['Success', 'Final Distance to Goal'] 25 | plot_envs = ['FetchReach', 'Reacher', 'RopeFloat'] 26 | 27 | exps_data, plottable_keys, 
distinct_params = reload_data(data_path) 28 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter) 29 | for (plot_key, plot_ylabel) in zip(plot_keys, plot_ylabels): 30 | for env_name in plot_envs: 31 | fig, ax = plt.subplots(figsize=(8, 5)) 32 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)): 33 | color = core.color_defaults[dict_leg2col[legend]] 34 | y, y_lower, y_upper = get_shaded_curve(selector.where('env_name', env_name), plot_key, 35 | shade_type='median') 36 | 37 | env_horizon = selector.where('env_name', env_name).extract()[0].params["env_kwargs"]["horizon"] 38 | x, _, _ = get_shaded_curve(selector.where('env_name', env_name), 'train/episode') 39 | x = [ele * env_horizon for ele in x] 40 | 41 | ax.plot(x, y, color=color, label=legend, linewidth=2.0) 42 | 43 | ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, 44 | alpha=0.2) 45 | 46 | def y_fmt(x, y): 47 | return str(int(np.round(x / 1000.0))) + 'K' 48 | 49 | ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt)) 50 | ax.grid(True) 51 | ax.set_xlabel('Timesteps') 52 | ax.set_ylabel(plot_ylabel) 53 | axes = plt.gca() 54 | if 'Rope' in env_name: 55 | axes.set_xlim(left=20000) 56 | 57 | plt.title(env_name.replace('Float', 'Push')) 58 | loc = 'best' 59 | leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends) 60 | for legobj in leg.legendHandles: 61 | legobj.set_linewidth(3.0) 62 | 63 | save_name = filter_save_name('ind_visual_' + plot_key + '_' + env_name) 64 | 65 | plt.savefig(osp.join(save_path, save_name), bbox_inches='tight') 66 | 67 | 68 | def plot_state_learning(): 69 | data_path = './data/nsh/submit_rss/submit_rss/state_learning' 70 | 71 | plot_keys = ['test/success_state', 'test/goal_dist_final_state'] 72 | plot_envs = ['FetchReach', 'FetchPush', 'Reacher', 'RopeFloat'] 73 | 74 | exps_data, plottable_keys, distinct_params = reload_data(data_path) 75 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter) 76 | for plot_key in plot_keys: 77 | for env_name in plot_envs: 78 | fig, ax = plt.subplots(figsize=(8, 5)) 79 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)): 80 | color = core.color_defaults[dict_leg2col[legend]] 81 | y, y_lower, y_upper = get_shaded_curve(selector.where('env_name', env_name), plot_key, 82 | shade_type='median') 83 | env_horizon = selector.where('env_name', env_name).extract()[0].params["env_kwargs"]["horizon"] 84 | x, _, _ = get_shaded_curve(selector.where('env_name', env_name), 'train/episode') 85 | x = [ele * env_horizon for ele in x] 86 | ax.plot(x, y, color=color, label=legend, linewidth=2.0) 87 | 88 | ax.fill_between(x, y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, 89 | alpha=0.2) 90 | 91 | def y_fmt(x, y): 92 | return str(int(np.round(x / 1000.0))) + 'K' 93 | 94 | ax.xaxis.set_major_formatter(tick.FuncFormatter(y_fmt)) 95 | ax.grid(True) 96 | ax.set_xlabel('Timesteps') 97 | ax.set_ylabel('Success') 98 | 99 | plt.title(env_name.replace('Float', 'Push')) 100 | loc = 'best' 101 | leg = ax.legend(loc=loc, prop={'size': 20}, ncol=1, labels=group_legends) 102 | for legobj in leg.legendHandles: 103 | legobj.set_linewidth(3.0) 104 | 105 | save_name = filter_save_name('ind_state_' + plot_key + '_' + env_name) 106 | 107 | plt.savefig(osp.join(save_path, save_name), bbox_inches='tight') 108 | 109 | 110 | if __name__ == '__main__': 111 | plot_visual_learning() 112 | plot_state_learning() 113 | 
-------------------------------------------------------------------------------- /rpmg/rpmg_losses.py: -------------------------------------------------------------------------------- 1 | 2 | import tools 3 | import torch 4 | from pytorch3d.transforms import so3_log_map, so3_exponential_map 5 | 6 | mse_loss = torch.nn.MSELoss() 7 | 8 | 9 | def rpmg_forward(in_nd): 10 | proj_kind = in_nd.shape[1] 11 | if proj_kind == 6: 12 | r0 = tools.compute_rotation_matrix_from_ortho6d(in_nd) 13 | elif proj_kind == 9: 14 | r0 = tools.symmetric_orthogonalization(in_nd) 15 | elif proj_kind == 4: 16 | r0 = tools.compute_rotation_matrix_from_quaternion(in_nd) 17 | elif proj_kind == 10: 18 | r0 = tools.compute_rotation_matrix_from_10d(in_nd) 19 | else: 20 | raise NotImplementedError 21 | # return r0 22 | return r0.transpose(-1, -2) 23 | 24 | 25 | def rpmg_inverse(R, proj_kind): 26 | R = R.transpose(-1, -2) 27 | 28 | if proj_kind == 6: 29 | x = torch.cat([R[:, :, 0], R[:, :, 1]], dim=1) 30 | 31 | elif proj_kind == 9: 32 | x = R.reshape(-1, 9) 33 | 34 | elif proj_kind == 4: 35 | x = tools.compute_quaternions_from_rotation_matrices(R) 36 | 37 | elif proj_kind == 10: 38 | q = tools.compute_quaternions_from_rotation_matrices(R) 39 | reg_A = torch.eye(4, device=q.device)[None].repeat(q.shape[0], 1, 1) \ 40 | - torch.bmm(q.unsqueeze(-1), q.unsqueeze(-2)) 41 | x = tools.convert_A_to_Avec(reg_A) 42 | 43 | else: 44 | raise NotImplementedError 45 | 46 | return x 47 | 48 | 49 | def rpmg_goal_and_nearest(x, R_goal): 50 | R_goal = R_goal.transpose(-1, -2) 51 | proj_kind = x.shape[1] 52 | 53 | if proj_kind == 6: 54 | x_proj_1 = (R_goal[:, :, 0] * x[:, :3]).sum(dim=1, 55 | keepdim=True) * R_goal[:, :, 0] 56 | x_proj_2 = (R_goal[:, :, 0] * x[:, 3:]).sum(dim=1, keepdim=True) * R_goal[:, :, 0] \ 57 | + (R_goal[:, :, 1] * x[:, 3:]).sum(dim=1, 58 | keepdim=True) * R_goal[:, :, 1] 59 | x_goal = torch.cat([R_goal[:, :, 0], R_goal[:, :, 1]], dim=1) 60 | x_nearest = torch.cat([x_proj_1, x_proj_2], dim=1) 61 | 62 | elif proj_kind == 9: 63 | x_goal = R_goal.reshape(-1, 9) 64 | x_nearest = tools.compute_SVD_nearest_Mnlsew( 65 | x.reshape(-1, 3, 3), R_goal) 66 | 67 | elif proj_kind == 4: 68 | q_1 = tools.compute_quaternions_from_rotation_matrices(R_goal) 69 | q_2 = -q_1 70 | x_proj = tools.normalize_vector(x) 71 | x_goal = torch.where( 72 | (q_1 - x_proj).norm(dim=1, keepdim=True) < (q_2 - 73 | x_proj).norm(dim=1, keepdim=True), 74 | q_1, q_2) 75 | x_nearest = (x * x_goal).sum(dim=1, keepdim=True) * x_goal 76 | 77 | elif proj_kind == 10: 78 | q_goal = tools.compute_quaternions_from_rotation_matrices(R_goal) 79 | x_nearest = tools.compute_nearest_10d(x, q_goal) 80 | reg_A = torch.eye(4, device=q_goal.device)[None].repeat(q_goal.shape[0], 1, 1) \ 81 | - torch.bmm(q_goal.unsqueeze(-1), q_goal.unsqueeze(-2)) 82 | x_goal = tools.convert_A_to_Avec(reg_A) 83 | 84 | return x_nearest, x_goal 85 | 86 | 87 | def projective_manifold_gradient_loss( 88 | x0_pred, x1_pred, delta_R10_gt, 89 | step_size=0.05, lambda_reg=0.01, 90 | transpose=False): 91 | if transpose: 92 | R0_pred = rpmg_forward(x0_pred.detach()).transpose(-1, -2) 93 | R1_pred = rpmg_forward(x1_pred.detach()).transpose(-1, -2) 94 | else: 95 | R0_pred = rpmg_forward(x0_pred.detach()) 96 | R1_pred = rpmg_forward(x1_pred.detach()) 97 | 98 | delta_r = so3_log_map(torch.bmm(R0_pred.transpose(-1, -2), 99 | torch.bmm(delta_R10_gt, R1_pred)), eps=0.001) 100 | R0_goal = torch.bmm(R0_pred, so3_exponential_map(step_size*delta_r)) 101 | 102 | x0_nearest, x0_goal = rpmg_goal_and_nearest(x0_pred, 
R0_goal) 103 | if(False): 104 | loss = mse_loss(x0_pred, x0_nearest - 105 | lambda_reg*(x0_nearest - x0_goal)) 106 | else: 107 | loss = mse_loss(x0_pred, x0_nearest) \ 108 | + lambda_reg*mse_loss(x0_pred, x0_goal) 109 | return loss, delta_r.norm(dim=-1) 110 | 111 | 112 | def projective_manifold_gradient_loss_absolute( 113 | x0_pred, x1_pred, R0_gt, R1_gt, 114 | step_size=0.05, lambda_reg=0.01, 115 | transpose=False): 116 | if transpose: 117 | R0_pred = rpmg_forward(x0_pred.detach()).transpose(-1, -2) 118 | R1_pred = rpmg_forward(x1_pred.detach()).transpose(-1, -2) 119 | else: 120 | R0_pred = rpmg_forward(x0_pred.detach()) 121 | R1_pred = rpmg_forward(x1_pred.detach()) 122 | 123 | delta_R10_gt = torch.bmm(R0_gt, R1_gt.transpose(-1, -2)) 124 | delta_r = so3_log_map(torch.bmm(R0_pred.transpose(-1, -2), 125 | torch.bmm(delta_R10_gt, R1_pred))) 126 | R0_goal = torch.bmm(R0_pred, so3_exponential_map(step_size*delta_r)) 127 | 128 | x0_nearest, x0_goal = rpmg_goal_and_nearest(x0_pred, R0_goal) 129 | if(False): 130 | loss = mse_loss(x0_pred, x0_nearest - 131 | lambda_reg*(x0_nearest - x0_goal)) 132 | else: 133 | loss = mse_loss(x0_pred, x0_nearest) \ 134 | + lambda_reg*mse_loss(x0_pred, x0_goal) 135 | return loss, delta_r.norm(dim=-1) 136 | -------------------------------------------------------------------------------- /chester/run_exp_worker.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os.path as osp 4 | import datetime 5 | import dateutil.tz 6 | import ast 7 | import uuid 8 | import pickle as pickle 9 | import base64 10 | import joblib 11 | 12 | from chester import config 13 | 14 | 15 | def run_experiment(argv): 16 | default_log_dir = config.LOG_DIR 17 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 18 | 19 | # avoid name clashes when running distributed jobs 20 | rand_id = str(uuid.uuid4())[:5] 21 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S_%f_%Z') 22 | 23 | default_exp_name = 'experiment_%s_%s' % (timestamp, rand_id) 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--exp_name', type=str, default=default_exp_name, 26 | help='Name of the experiment.') 27 | parser.add_argument('--log_dir', type=str, default=None, 28 | help='Path to save the log and iteration snapshot.') 29 | parser.add_argument('--snapshot_mode', type=str, default='all', 30 | help='Mode to save the snapshot. 
Can be either "all" ' 31 | '(all iterations will be saved), "last" (only ' 32 | 'the last iteration will be saved), "gap" (every' 33 | '`snapshot_gap` iterations are saved), or "none" ' 34 | '(do not save snapshots)') 35 | parser.add_argument('--snapshot_gap', type=int, default=1, 36 | help='Gap between snapshot iterations.') 37 | parser.add_argument('--tabular_log_file', type=str, default='progress.csv', 38 | help='Name of the tabular log file (in csv).') 39 | parser.add_argument('--text_log_file', type=str, default='debug.log', 40 | help='Name of the text log file (in pure text).') 41 | parser.add_argument('--params_log_file', type=str, default='params.json', 42 | help='Name of the parameter log file (in json).') 43 | parser.add_argument('--variant_log_file', type=str, default='variant.json', 44 | help='Name of the variant log file (in json).') 45 | parser.add_argument('--resume_from', type=str, default=None, 46 | help='Name of the pickle file to resume experiment from.') 47 | parser.add_argument('--plot', type=ast.literal_eval, default=False, 48 | help='Whether to plot the iteration results') 49 | parser.add_argument('--log_tabular_only', type=ast.literal_eval, default=False, 50 | help='Whether to only print the tabular log information (in a horizontal format)') 51 | parser.add_argument('--seed', type=int, 52 | help='Random seed for numpy') 53 | parser.add_argument('--args_data', type=str, 54 | help='Pickled data for stub objects') 55 | parser.add_argument('--variant_data', type=str, 56 | help='Pickled data for variant configuration') 57 | parser.add_argument('--use_cloudpickle', type=ast.literal_eval, default=False) 58 | 59 | args = parser.parse_args(argv[1:]) 60 | 61 | # if args.seed is not None: 62 | # set_seed(args.seed) 63 | # 64 | # if args.plot: 65 | # from rllab.plotter import plotter 66 | # plotter.init_worker() 67 | 68 | if args.log_dir is None: 69 | log_dir = osp.join(default_log_dir, args.exp_name) 70 | else: 71 | log_dir = args.log_dir 72 | # tabular_log_file = osp.join(log_dir, args.tabular_log_file) 73 | # text_log_file = osp.join(log_dir, args.text_log_file) 74 | # params_log_file = osp.join(log_dir, args.params_log_file) 75 | 76 | if args.variant_data is not None: 77 | variant_data = pickle.loads(base64.b64decode(args.variant_data)) 78 | variant_log_file = osp.join(log_dir, args.variant_log_file) 79 | # logger.log_variant(variant_log_file, variant_data) 80 | else: 81 | variant_data = None 82 | 83 | # if not args.use_cloudpickle: 84 | # logger.log_parameters_lite(params_log_file, args) 85 | # 86 | # logger.add_text_output(text_log_file) 87 | # logger.add_tabular_output(tabular_log_file) 88 | # prev_snapshot_dir = logger.get_snapshot_dir() 89 | # prev_mode = logger.get_snapshot_mode() 90 | # logger.set_snapshot_dir(log_dir) 91 | # logger.set_snapshot_mode(args.snapshot_mode) 92 | # logger.set_snapshot_gap(args.snapshot_gap) 93 | # logger.set_log_tabular_only(args.log_tabular_only) 94 | # logger.push_prefix("[%s] " % args.exp_name) 95 | 96 | if args.resume_from is not None: 97 | data = joblib.load(args.resume_from) 98 | assert 'algo' in data 99 | algo = data['algo'] 100 | algo.train() 101 | else: 102 | # read from stdin 103 | if args.use_cloudpickle: 104 | import cloudpickle 105 | method_call = cloudpickle.loads(base64.b64decode(args.args_data)) 106 | method_call(variant_data, log_dir, args.exp_name) 107 | else: 108 | assert False 109 | # data = pickle.loads(base64.b64decode(args.args_data)) 110 | # maybe_iter = concretize(data) 111 | # if is_iterable(maybe_iter): 112 | # 
for _ in maybe_iter: 113 | # pass 114 | 115 | # logger.set_snapshot_mode(prev_mode) 116 | # logger.set_snapshot_dir(prev_snapshot_dir) 117 | # logger.remove_tabular_output(tabular_log_file) 118 | # logger.remove_text_output(text_log_file) 119 | # logger.pop_prefix() 120 | 121 | 122 | if __name__ == "__main__": 123 | run_experiment(sys.argv) 124 | -------------------------------------------------------------------------------- /initializers/mm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from initializers import InitializerBase 4 | 5 | class MoveCloserInitializer(InitializerBase): 6 | THRESH2 = 75 + 125 7 | 8 | def __init__(self, env, args): 9 | self.env = env 10 | 11 | def get_action(self, obs, info=None): 12 | step = self.env.inner_step 13 | T = self.env.tool_idx 14 | dx, dy, dz = 0., 0., 0. 15 | 16 | # Position of the tool. Unfortunately adding to tx and tz is a real hack. 17 | tx = self.env.tool_state[T,0] 18 | ty = self.env.tool_state[T,1] 19 | tz = self.env.tool_state[T,2] 20 | tx += 0.090 21 | tz += 0.090 22 | 23 | # Position of the item. 24 | rigid_avg_pos = self.env._get_rigid_pos() 25 | ix = rigid_avg_pos[0] 26 | iy = rigid_avg_pos[1] 27 | iz = rigid_avg_pos[2] 28 | 29 | # Unfortunately distance thresholds have to be tuned carefully. 30 | dist_xz = np.sqrt((tx-ix)**2 + (tz-iz)**2) 31 | 32 | thresh1 = 75 33 | thresh2 = thresh1 + 125 34 | if 0 <= step < 15: 35 | # Only move if the item is below the sphere. 36 | if dist_xz < 0.050: 37 | # Go the _negative_ direction 38 | dx = -(ix - tx) 39 | dz = -(iz - tz) 40 | elif 15 <= step < 66: 41 | # Actually it seems like faster strangely means less water movement. 42 | dy = -0.0040 # stronger -dy won't have effect due to action bounds 43 | elif thresh1 <= step < thresh2: 44 | # Try to correct for the discrepancy 45 | if dist_xz > 0.004: 46 | dx = ix - tx 47 | dz = iz - tz 48 | elif thresh2 <= step < 600: 49 | # Ah, it would actually be hard for us to do another xz correction here, since 50 | # if it causes collision, we'd just stop the motion. :( 51 | dy = 0.0040 52 | else: 53 | pass 54 | 55 | # Try to normalize (to magnitude 1) then downscale by a tuned amount. 56 | # Unfortunately these numbers just come from tuning / visualizing. 57 | action = np.array([dx, dy, dz, 0.]) 58 | if (5 <= step < 15) or (thresh1 <= step < thresh2): 59 | if np.linalg.norm(action) > 0: 60 | action = action / np.linalg.norm(action) * 0.0020 61 | 62 | if self.env.action_mode == 'translation': 63 | action = action[:3] 64 | else: 65 | raise NotImplementedError() 66 | 67 | # 'Un-scale' to anticipate effect of future scaling in `NormalizedEnv` 68 | lb = self.env._wrapped_env.action_space.low 69 | ub = self.env._wrapped_env.action_space.high 70 | action = (action - lb) / ((ub - lb) * 0.5) - 1.0 71 | 72 | # Check if we're done 73 | done = step >= (self.THRESH2 - 1) 74 | 75 | return action, done 76 | 77 | class SmartMoveCloserInitializer(InitializerBase): 78 | THRESH2 = 75 + 125 79 | 80 | def __init__(self, env, args): 81 | self.env = env 82 | self.state = 0 83 | 84 | def reset(self): 85 | self.state = 0 86 | 87 | def get_action(self, obs, info=None): 88 | step = self.env.inner_step 89 | T = self.env.tool_idx 90 | dx, dy, dz = 0., 0., 0. 91 | 92 | # Position of the tool. Unfortunately adding to tx and tz is a real hack. 93 | tx = self.env.tool_state[T,0] 94 | ty = self.env.tool_state[T,1] 95 | tz = self.env.tool_state[T,2] 96 | tx += 0.090 97 | tz += 0.090 98 | 99 | # Position of the item. 
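        # (self.env._get_rigid_pos() returns the average position of the rigid item,
        #  which we treat as the item's location in the xz-distance checks below.)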
100 | rigid_avg_pos = self.env._get_rigid_pos() 101 | ix = rigid_avg_pos[0] 102 | iy = rigid_avg_pos[1] 103 | iz = rigid_avg_pos[2] 104 | 105 | # Unfortunately distance thresholds have to be tuned carefully. 106 | dist_xz = np.sqrt((tx-ix)**2 + (tz-iz)**2) 107 | 108 | thresh1 = 75 109 | thresh2 = thresh1 + 125 110 | # if 0 <= step < 15: 111 | if self.state == 0: 112 | # Don't collide! 113 | if -0.2 >= tx or -0.2 >= tz or 0.2 <= tx or 0.2 <= tz: 114 | dx = -ix 115 | dz = -iz 116 | # Only move if the item is below the sphere. 117 | elif dist_xz < 0.050: 118 | # Go the _negative_ direction 119 | dx = -(ix - tx) 120 | dz = -(iz - tz) 121 | else: 122 | self.state += 1 123 | # elif 15 <= step < 66: 124 | elif self.state == 1: 125 | # Actually it seems like faster strangely means less water movement. 126 | if ty > 0.08: 127 | dy = -0.0040 # stronger -dy won't have effect due to action bounds 128 | else: 129 | self.state += 1 130 | self.finish_step = step 131 | # elif thresh1 <= step < thresh2: 132 | elif self.state == 2: 133 | # Try to correct for the discrepancy 134 | if dist_xz > 0.004 and step < self.finish_step + 125: 135 | dx = ix - tx 136 | dz = iz - tz 137 | else: 138 | self.state += 1 139 | # elif thresh2 <= step < 600: 140 | elif self.state == 3: 141 | # Ah, it would actually be hard for us to do another xz correction here, since 142 | # if it causes collision, we'd just stop the motion. :( 143 | dy = 0.0040 144 | if iy < ty: 145 | self.state = 0 146 | else: 147 | pass 148 | 149 | # Try to normalize (to magnitude 1) then downscale by a tuned amount. 150 | # Unfortunately these numbers just come from tuning / visualizing. 151 | action = np.array([dx, dy, dz, 0.]) 152 | # if (5 <= step < 15) or (thresh1 <= step < thresh2): 153 | if self.state == 0 or self.state == 2: 154 | if np.linalg.norm(action) > 0: 155 | action = action / np.linalg.norm(action) * 0.0020 156 | 157 | if self.env.action_mode == 'translation': 158 | action = action[:3] 159 | else: 160 | raise NotImplementedError() 161 | 162 | # 'Un-scale' to anticipate effect of future scaling in `NormalizedEnv` 163 | lb = self.env._wrapped_env.action_space.low 164 | ub = self.env._wrapped_env.action_space.high 165 | action = (action - lb) / ((ub - lb) * 0.5) - 1.0 166 | 167 | # Check if we're done 168 | done = step >= (self.THRESH2 - 1) 169 | 170 | return action, done 171 | -------------------------------------------------------------------------------- /bc/logger.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from collections import defaultdict 3 | import json 4 | import os 5 | import shutil 6 | import torch 7 | import torchvision 8 | import numpy as np 9 | from termcolor import colored 10 | 11 | FORMAT_CONFIG = { 12 | 'rl': { 13 | 'train': [ 14 | ('episode', 'E', 'int'), ('step', 'S', 'int'), 15 | ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'), 16 | ('batch_reward', 'BR', 'float'), ('actor_loss', 'A_LOSS', 'float'), 17 | ('critic_loss', 'CR_LOSS', 'float'), ('curl_loss', 'CU_LOSS', 'float'), 18 | # ('q1', 'Q1', 'float'), ('comp_Q', 'comp_Q', 'float'), ('comp_E', 'comp_E', 'float') 19 | ], 20 | 'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float')] 21 | } 22 | } 23 | 24 | 25 | class AverageMeter(object): 26 | def __init__(self): 27 | self._sum = 0 28 | self._count = 0 29 | 30 | def update(self, value, n=1): 31 | self._sum += value 32 | self._count += n 33 | 34 | def value(self): 35 | return self._sum / max(1, 
self._count) 36 | 37 | 38 | class MetersGroup(object): 39 |     def __init__(self, file_name, formating): 40 |         self._file_name = file_name 41 |         if os.path.exists(file_name): 42 |             os.remove(file_name) 43 |         self._formating = formating 44 |         self._meters = defaultdict(AverageMeter) 45 | 46 |     def log(self, key, value, n=1): 47 |         self._meters[key].update(value, n) 48 | 49 |     def _prime_meters(self): 50 |         data = dict() 51 |         for key, meter in self._meters.items(): 52 |             if key.startswith('train'): 53 |                 key = key[len('train') + 1:] 54 |             else: 55 |                 key = key[len('eval') + 1:] 56 |             key = key.replace('/', '_') 57 |             data[key] = meter.value() 58 |         return data 59 | 60 |     def _dump_to_file(self, data): 61 |         with open(self._file_name, 'a') as f: 62 |             f.write(json.dumps(data) + '\n') 63 | 64 |     def _format(self, key, value, ty): 65 |         template = '%s: ' 66 |         if ty == 'int': 67 |             template += '%d' 68 |         elif ty == 'float': 69 |             template += '%.04f' 70 |         elif ty == 'time': 71 |             template += '%.01f s' 72 |         else: 73 |             raise ValueError('invalid format type: %s' % ty) 74 |         return template % (key, value) 75 | 76 |     def _dump_to_console(self, data, prefix): 77 |         prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 78 |         pieces = ['{:5}'.format(prefix)] 79 |         for key, disp_key, ty in self._formating: 80 |             value = data.get(key, 0) 81 |             pieces.append(self._format(disp_key, value, ty)) 82 |         print('| %s' % (' | '.join(pieces))) 83 | 84 |     def dump(self, step, prefix): 85 |         data = self._prime_meters() 86 |         data['step'] = step 87 |         self._dump_to_file(data) 88 |         self._dump_to_console(data, prefix) 89 |         self._meters.clear() 90 | 91 | 92 | class Logger(object): 93 |     def __init__(self, log_dir, use_tb=True, config='rl', chester_logger=None): 94 |         self._log_dir = log_dir 95 |         if use_tb: 96 |             tb_dir = os.path.join(log_dir, 'tb') 97 |             if os.path.exists(tb_dir): 98 |                 shutil.rmtree(tb_dir) 99 |             self._sw = SummaryWriter(tb_dir) 100 |         else: 101 |             self._sw = None 102 |         self._train_mg = MetersGroup( 103 |             os.path.join(log_dir, 'train.log'), 104 |             formating=FORMAT_CONFIG[config]['train'] 105 |         ) 106 |         self._eval_mg = MetersGroup( 107 |             os.path.join(log_dir, 'eval.log'), 108 |             formating=FORMAT_CONFIG[config]['eval'] 109 |         ) 110 |         self.chester_logger = chester_logger 111 | 112 |     def _try_sw_log(self, key, value, step): 113 |         if self._sw is not None: 114 |             self._sw.add_scalar(key, value, step) 115 | 116 |     def _try_sw_log_image(self, key, image, step): 117 |         if self._sw is not None: 118 |             assert image.dim() == 3 119 |             grid = torchvision.utils.make_grid(image.unsqueeze(1)) 120 |             self._sw.add_image(key, grid, step) 121 | 122 |     def _try_sw_log_video(self, key, frames, step): 123 |         if self._sw is not None: 124 |             frames = torch.from_numpy(np.array(frames)) 125 |             frames = frames.unsqueeze(0) 126 |             self._sw.add_video(key, frames, step, fps=30) 127 | 128 |     def _try_sw_log_histogram(self, key, histogram, step): 129 |         if self._sw is not None: 130 |             self._sw.add_histogram(key, histogram, step) 131 | 132 |     def log(self, key, value, step, n=1): 133 |         assert key.startswith('train') or key.startswith('eval') 134 |         if type(value) == torch.Tensor: 135 |             value = value.item() 136 |         self._try_sw_log(key, value / n, step) 137 |         mg = self._train_mg if key.startswith('train') else self._eval_mg 138 |         mg.log(key, value, n) 139 |         if self.chester_logger is not None: 140 |             self.chester_logger.record_tabular(key, value) 141 | 142 |     def log_param(self, key, param, step): 143 |         self.log_histogram(key + '_w', param.weight.data, step) 144 |         if hasattr(param.weight, 'grad') and 
param.weight.grad is not None: 145 | self.log_histogram(key + '_w_g', param.weight.grad.data, step) 146 | if hasattr(param, 'bias'): 147 | self.log_histogram(key + '_b', param.bias.data, step) 148 | if hasattr(param.bias, 'grad') and param.bias.grad is not None: 149 | self.log_histogram(key + '_b_g', param.bias.grad.data, step) 150 | 151 | def log_image(self, key, image, step): 152 | assert key.startswith('train') or key.startswith('eval') 153 | self._try_sw_log_image(key, image, step) 154 | 155 | def log_video(self, key, frames, step): 156 | assert key.startswith('train') or key.startswith('eval') 157 | self._try_sw_log_video(key, frames, step) 158 | 159 | def log_histogram(self, key, histogram, step): 160 | assert key.startswith('train') or key.startswith('eval') 161 | self._try_sw_log_histogram(key, histogram, step) 162 | 163 | def dump(self, step): 164 | # NOTE(daniel): chester logger only dumps if we've done eval episodes 165 | #if len(self._eval_mg._prime_meters()) > 0 and self.chester_logger is not None: 166 | if self.chester_logger is not None: 167 | self.chester_logger.dump_tabular() 168 | self._train_mg.dump(step, 'train') 169 | self._eval_mg.dump(step, 'eval') 170 | -------------------------------------------------------------------------------- /bc/encoder.py: -------------------------------------------------------------------------------- 1 | from xml.dom.minidom import Identified 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def tie_weights(src, trg): 7 | assert type(src) == type(trg) 8 | trg.weight = src.weight 9 | trg.bias = src.bias 10 | 11 | 12 | # NOTE(daniel): get these by checking shape after going through conv layers 13 | # for 128 x 128 inputs 14 | OUT_DIM_128 = {4: 57} 15 | # for 100 x 100 inputs 16 | OUT_DIM_100 = {4: 43} 17 | # for 84 x 84 inputs 18 | OUT_DIM = {2: 39, 4: 35, 6: 31} 19 | # for 64 x 64 inputs 20 | OUT_DIM_64 = {2: 29, 4: 25, 6: 21} 21 | 22 | 23 | class PixelEncoder(nn.Module): 24 | """Convolutional encoder of pixels observations. 25 | 26 | Update 08/24/2022: support depth_segm as observation type. 27 | """ 28 | 29 | def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32, 30 | output_logits=False, depth_segm=False): 31 | super().__init__() 32 | 33 | assert len(obs_shape) == 3 34 | self.obs_shape = obs_shape 35 | self.feature_dim = feature_dim 36 | self.num_layers = num_layers 37 | self.depth_segm = depth_segm 38 | print(f'Making Image CNN. Using a segm? {self.depth_segm}, obs: {self.obs_shape}') 39 | 40 | self.convs = nn.ModuleList( 41 | [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)] 42 | ) 43 | for i in range(num_layers - 1): 44 | self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) 45 | 46 | # Handle compatibility with output of convs and the FC layers portion. 47 | if obs_shape[-1] == 64: 48 | out_dim = OUT_DIM_64[num_layers] 49 | elif obs_shape[-1] == 84: 50 | out_dim = OUT_DIM[num_layers] 51 | elif obs_shape[-1] == 100: 52 | out_dim = OUT_DIM_100[num_layers] 53 | elif obs_shape[-1] == 128: 54 | out_dim = OUT_DIM_128[num_layers] 55 | else: 56 | raise NotImplementedError 57 | self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) 58 | self.ln = nn.LayerNorm(self.feature_dim) 59 | 60 | self.outputs = dict() 61 | self.output_logits = output_logits 62 | 63 | def reparameterize(self, mu, logstd): 64 | std = torch.exp(logstd) 65 | eps = torch.randn_like(std) 66 | return mu + eps * std 67 | 68 | def forward_conv(self, obs): 69 | """Forward pass through CNN. 
70 | 71 |         NOTE(daniel): since we support RGBD input now, only divide by 255 for the 72 |         image-based components. With just RGBD we can do: obs = obs / 255. 73 |         """ 74 |         if self.depth_segm: 75 |             # Special case, we actually don't do anything as the image should already 76 |             # be binary 0-1 for the masks, and between [0,x] for depth, where x ~ 1. 77 |             # E.g., for MMOneSphere, image channels are: depth, tool mask, item mask. 78 |             # Update: also doing this for rgb_segm_masks, rgbd_segm_masks, as we assume 79 |             # the input will be in the correct range. 80 |             pass 81 |         else: 82 |             obs[:,:3,:,:] /= 255.  # (batch, channels, H, W) 83 |         self.outputs['obs'] = obs 84 | 85 |         conv = torch.relu(self.convs[0](obs)) 86 |         self.outputs['conv1'] = conv 87 | 88 |         for i in range(1, self.num_layers): 89 |             conv = torch.relu(self.convs[i](conv)) 90 |             self.outputs['conv%s' % (i + 1)] = conv 91 | 92 |         h = conv.view(conv.size(0), -1) 93 |         return h 94 | 95 |     def forward(self, obs, detach=False): 96 |         h = self.forward_conv(obs) 97 | 98 |         if detach: 99 |             h = h.detach() 100 |         h_fc = self.fc(h) 101 |         self.outputs['fc'] = h_fc 102 | 103 |         h_norm = self.ln(h_fc) 104 |         self.outputs['ln'] = h_norm 105 | 106 |         if self.output_logits: 107 |             out = h_norm 108 |         else: 109 |             out = torch.tanh(h_norm) 110 |             self.outputs['tanh'] = out 111 | 112 |         return out 113 | 114 |     def copy_conv_weights_from(self, source): 115 |         """Tie convolutional layers""" 116 |         # only tie conv layers 117 |         for i in range(self.num_layers): 118 |             tie_weights(src=source.convs[i], trg=self.convs[i]) 119 | 120 |     def log(self, L, step, log_freq): 121 |         if step % log_freq != 0: 122 |             return 123 | 124 |         for k, v in self.outputs.items(): 125 |             L.log_histogram('train_encoder/%s_hist' % k, v, step) 126 |             if len(v.shape) > 2: 127 |                 L.log_image('train_encoder/%s_img' % k, v[0], step) 128 | 129 |         for i in range(self.num_layers): 130 |             L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step) 131 |         L.log_param('train_encoder/fc', self.fc, step) 132 |         L.log_param('train_encoder/ln', self.ln, step) 133 | 134 | 135 | class IdentityEncoder(nn.Module): 136 |     def __init__(self, obs_shape, feature_dim, num_layers, num_filters, output_logits, 137 |             depth_segm=False): 138 |         super().__init__() 139 |         assert len(obs_shape) == 1 140 |         self.feature_dim = obs_shape[0] 141 | 142 |     def forward(self, obs, detach=False): 143 |         return obs 144 | 145 |     def copy_conv_weights_from(self, source): 146 |         pass 147 | 148 |     def log(self, L, step, log_freq): 149 |         pass 150 | 151 | 152 | class IdentityPNEncoder(nn.Module): 153 |     """NOTE(daniel) for PointNet++, to allow for len(obs_shape) = 2. 154 |     But this will still be an identity function. Just to make code consistent. 155 |     """ 156 |     def __init__(self, obs_shape, feature_dim, num_layers, num_filters, output_logits, 157 |             depth_segm=False): 158 |         super().__init__() 159 |         assert len(obs_shape) == 2 160 |         self.feature_dim = obs_shape 161 | 162 |     def forward(self, obs, detach=False): 163 |         return obs 164 | 165 |     def copy_conv_weights_from(self, source): 166 |         pass 167 | 168 |     def log(self, L, step, log_freq): 169 |         pass 170 | 171 | 172 | # NOTE(daniel): this should be fixed at some point, after deadline. 
:) 173 | _AVAILABLE_ENCODERS = {'pixel': PixelEncoder, 174 | 'segm': PixelEncoder, 175 | 'mlp': IdentityEncoder, 176 | 'state_predictor_then_mlp': IdentityEncoder, 177 | 'identity': IdentityEncoder, 178 | 'pointnet': IdentityPNEncoder, 179 | 'pointnet_rpmg': IdentityPNEncoder, 180 | 'pointnet_avg': IdentityPNEncoder, 181 | 'pointnet_svd': IdentityPNEncoder, 182 | 'pointnet_svd_centered': IdentityPNEncoder, 183 | 'pointnet_svd_pointwise': IdentityPNEncoder, 184 | 'pointnet_svd_pointwise_6d_flow': IdentityPNEncoder, 185 | 'pointnet_dense_tf_3D_MSE': IdentityPNEncoder, 186 | 'pointnet_dense_tf_6D_MSE': IdentityPNEncoder, 187 | 'pointnet_dense_tf_6D_pointwise': IdentityPNEncoder, 188 | 'pointnet_classif_6D_pointwise': IdentityPNEncoder, 189 | 'pointnet_rpmg_pointwise': IdentityPNEncoder, 190 | 'pointnet_rpmg_taugt': IdentityPNEncoder, 191 | 'pointnet_svd_6d_flow_mse_loss': IdentityPNEncoder, 192 | 'pointnet_svd_pointwise_PW_bef_SVD': IdentityPNEncoder, 193 | } 194 | 195 | 196 | def make_encoder( 197 | encoder_type, obs_shape, feature_dim, num_layers, num_filters, output_logits=False, 198 | depth_segm=False 199 | ): 200 | assert encoder_type in _AVAILABLE_ENCODERS 201 | return _AVAILABLE_ENCODERS[encoder_type]( 202 | obs_shape, feature_dim, num_layers, num_filters, output_logits, depth_segm=depth_segm 203 | ) 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # ToolFlowNet: Robotic Manipulation with Tools via Predicting Tool Flow from Point Clouds
3 | 
5 | Daniel Seita  •  6 | Yufei Wang  •  7 | Sarthak J. Shetty  •  8 | Edward Y. Li  •  9 | Zackory Erickson  •  10 | David Held
12 | 
14 | Website  •  15 | Paper
16 | 
19 | 20 | This repository contains the **ToolFlowNet** code for 21 | our CoRL 2022 paper. 22 | See our other repository 23 | for the SoftGym environment code. Note that you should install that code *first*, 24 | and then this ToolFlowNet code *second*. 25 | 26 | The code in this repository has a number of SoftGym-specific components. We are 27 | working on a separate, simplified version of ToolFlowNet which is more agnostic 28 | to the choice of environment or task. 29 | 30 | **Note:** This branch contains instructions to run the ToolFlowNet simulation code. For the physical experiments code, we maintain a separate [**`physical`**](https://github.com/DanielTakeshi/softagent_tfn/tree/physical) branch and a separate repository [**`tfn-robot`**](https://github.com/SarthakJShetty/tfn-robot) for the real-world data collection and robot control code. 31 | 32 |
33 | 34 | Contents: 35 | 36 | - [Installation](#installation) 37 | - [How to Use](#how-to-use) 38 | - [CoRL 2022 Experiments](#corl-2022-experiments) 39 | - [Inspect Results](#inspect-results) 40 | - [Citation](#citation) 41 | 42 |
43 | 44 | ## Installation 45 | 46 | Assuming you have already installed the other repository, then for this one, we 47 | first need to make a symlink to `softgym_tfn`. Make sure that `softgym_tfn/` 48 | does not exist within this folder. Then run this: 49 | 50 | ``` 51 | ln -s ../softgym_tfn/ softgym 52 | ``` 53 | 54 | This creates a symlink so the `softgym` subdirectory points to our `softgym_tfn` 55 | repository. By typing in `ls -lh` you should see: `softgym -> ../softgym_tfn/`. 56 | 57 | Then run this in any new Terminal window/tab that you want to run code in: 58 | 59 | ``` 60 | . ./prepare_1.0.sh 61 | ``` 62 | 63 | This will go into `softgym` and set up the conda environment. It should also set 64 | `PYFLEXROOT` which points to `PyFlex` in the SoftAgent repository. 65 | 66 | ## How to Use 67 | 68 | The main way that we launch experiments is with: 69 | 70 | ``` 71 | python launch_exp.py 72 | ``` 73 | 74 | or 75 | 76 | ``` 77 | python launch_exp.py --debug 78 | ``` 79 | 80 | The first case will run multiple combinations of variants in parallel. Thus, be 81 | careful about launching a lot of variants, since the combination can overwhelm 82 | one machine. Adding the `--debug` flag means the code runs just one of the 83 | variants. *We recommend using `--debug` to start*. In addition, when running 84 | multiple variants, we recommend only adjusting the random seed. We can do this 85 | by (for example) setting `vg.add('seed', [100,101])` and making all other 86 | `vg.add(...)` calls use just one-length lists. This will run 2 runs in parallel, 87 | each with the same settings, except with different random seeds. For the paper, 88 | we launched these scripts while using 5 random seeds with 89 | `vg.add('seed', [100,101,102,103,104])`. 90 | 91 | See `launch_exp.py` for details on what to modify. The three main areas to 92 | adjust for the purpose of learning from demonstrations are: 93 | 94 | - Adjusting the behavioral cloning data directory. 95 | - Selecting the environment to use, `PourWater` or `PourWater6D`. In this code, 96 | `PourWater` refers to the task version with 3DOF actions. 97 | - Selecting the method to use by setting `this_cfg` appropriately. See the code 98 | comments and `bc/exp_configs.py` for more about what the different 99 | configurations mean. 100 | 101 | You can find these areas by searching in `launch_exp.py` for this pattern: 102 | 103 | ``` 104 | # ----------------------------- ADJUST -------------------------------- # 105 | 106 | # --------------------------------------------------------------------- # 107 | ``` 108 | 109 | Check the content in between the above two lines. 110 | 111 | See the next section for how we set these for the CoRL 2022 submission. 112 | 113 | **Important note**: we highly recommend using `wandb` to track experiments. 114 | Please adjust these two lines in `launch_exp.py` appropriately: 115 | 116 | ``` 117 | vg.add('wandb_project', ['']) # Fill this in! 118 | vg.add('wandb_entity', ['']) # Fill this in! 119 | ``` 120 | 121 | If you need a refresher on `wandb`, refer to [the official documentation][2]. If 122 | you leave these blank, the script might not run successfully. 123 | 124 | ## CoRL 2022 Experiments 125 | 126 | Before this, make sure you have downloaded demonstration data [following our 127 | other repository's instructions][1]. This includes both the cache and the 128 | demonstrations themselves. 
While the default instructions put the data in 129 | `~/softgym_tfn/data_demo`, you may put the data in a different location if 130 | desired. 131 | 132 | **First**: with the demonstration data, adjust the `DATA_HEAD` variable. For 133 | example, setting this: 134 | 135 | ``` 136 | DATA_HEAD = '/home/seita/softgym_tfn/data_demo/' 137 | ``` 138 | 139 | near the top of `launch_exp.py` means that, for a run with PourWater (3D), I 140 | should expect to see the demonstrations located at: 141 | 142 | ``` 143 | /home/seita/softgym_tfn/data_demo/PourWater_v01_BClone_filtered_wDepth_pw_algo_v02_nVars_1500_obs_combo_act_translation_axis_angle_withWaterFrac 144 | ``` 145 | 146 | **Second**: select the task you want, either `PourWater` (the 3DOF action space 147 | version) or `PourWater6D` (with 6DOF actions). This means selecting *one* of the 148 | following: 149 | 150 | ``` 151 | env, env_version, alg_policy = 'PourWater', 'v01', 'pw_algo_v02' 152 | env, env_version, alg_policy = 'PourWater6D', 'v01', 'pw_algo_v02' 153 | ``` 154 | 155 | Be sure to comment out whatever option you are not using. 156 | 157 | **Third**: pick the method. For example, select ToolFlowNet with: 158 | 159 | ``` 160 | this_cfg = exp_configs.SVD_POINTWISE_EE2FLOW 161 | ``` 162 | 163 | or the PCL Direct Vector MSE baseline with: 164 | 165 | ``` 166 | this_cfg = exp_configs.DIRECT_VECTOR_INTRINSIC_AXIS_ANGLE 167 | ``` 168 | 169 | There are many experiment options. See the comments and `bc/exp_configs.py` for 170 | more details. 171 | 172 | Finally, double check all the settings in the variant generator (`vg`). For the 173 | paper we typically ran by setting: 174 | 175 | ``` 176 | vg.add('seed', [100,101,102,103,104]) 177 | ``` 178 | 179 | as the only variant, in the sense that this is the only `vg` option with more 180 | than one list item. This means (as stated in our paper) we ran 5 random seeds 181 | for each experiment setting. 182 | 183 | Once you are confident the settings are correct, run the script! (Did you 184 | remember to set up `wandb`?) 185 | 186 | 187 | ## Inspect Results 188 | 189 | For accumulating and computing results for the CoRL 2022 paper, we used one of 190 | the following four commands: 191 | 192 | ``` 193 | python results_table.py 194 | python results_table.py --show_avg_epoch 195 | python results_table.py --show_raw_perf 196 | python results_table.py --show_raw_perf --show_avg_epoch 197 | ``` 198 | 199 | These four commands, respectively, produce a table of statistics which correspond to results in 200 | **Table 1**, **Table S5**, **Table S6**, and **Table S7** in the paper. 201 | Some relevant keys in the code are: 202 | 203 | - `FLOW3D_SVD_PW_CONSIST_0_1` corresponds to ToolFlowNet with `lambda` (consistency weight) value of 0.1. 204 | - `PCL_DIRECT_VECTOR_MSE_INTRIN_AXANG` corresponds to "PCL Direct Vector MSE." 205 | - `PCL_DENSE_TRANSF_MSE_INTRIN_AXANG` corresponds to "PCL Dense Transformation MSE." 206 | 207 | Successfully running this command requires having all relevant experiment 208 | results. We provide this script to explain how we produced the results. 
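If you tracked runs with `wandb` (as recommended above), the same scalars can also be pulled programmatically. Below is a minimal sketch using the public `wandb` API; the entity/project strings and the metric key are placeholders, not names defined by this repository:

```
import wandb

api = wandb.Api()
# Placeholders: use the wandb_entity / wandb_project you set in launch_exp.py.
runs = api.runs("my-entity/my-project")
for run in runs:
    seed = run.config.get("seed")
    # run.summary holds the last logged value of each scalar.
    print(run.name, seed, run.summary.get("eval/episode_reward"))
```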
209 | 210 | 211 | ## Citation 212 | 213 | If you find this repository useful, please cite our paper: 214 | 215 | ``` 216 | @inproceedings{Seita2022toolflownet, 217 |     title={{ToolFlowNet: Robotic Manipulation with Tools via Predicting Tool Flow from Point Clouds}}, 218 |     author={Seita, Daniel and Wang, Yufei and Shetty, Sarthak and Li, Edward and Erickson, Zackory and Held, David}, 219 |     booktitle={Conference on Robot Learning (CoRL)}, 220 |     year={2022} 221 | } 222 | ``` 223 | 224 | [1]:https://github.com/DanielTakeshi/softgym_tfn 225 | [2]:https://docs.wandb.ai/ 226 | -------------------------------------------------------------------------------- /chester/plotting/cplot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import json 5 | import argparse 6 | import itertools 7 | import numpy as np 8 | 9 | # Matplotlib 10 | import matplotlib 11 | 12 | matplotlib.use('Agg') 13 | import matplotlib.pyplot as plt 14 | 15 | plt.rc('font', size=25) 16 | matplotlib.rcParams['pdf.fonttype'] = 42  # Default type3 cannot be rendered in some templates 17 | matplotlib.rcParams['ps.fonttype'] = 42 18 | matplotlib.rcParams['grid.alpha'] = 0.3 19 | matplotlib.rcParams['axes.titlesize'] = 25 20 | import matplotlib.ticker as tick 21 | 22 | # rllab 23 | sys.path.append('.') 24 | from rllab.misc.ext import flatten 25 | from rllab.viskit import core 26 | 27 | 28 | # from rllab.misc import ext 29 | 30 | # plotly 31 | # import plotly.offline as po 32 | # import plotly.graph_objs as go 33 | 34 | 35 | def smooth_data(data, smooth): 36 |     """NOTE(daniel) smoothing with window average, from SpinningUp. 37 |     https://github.com/openai/spinningup/blob/master/spinup/utils/plot.py#L15 38 | 39 |     smooth data with moving window average. 40 |     that is, 41 |         smoothed_y[t] = average(y[t-k], y[t-k+1], ..., y[t+k-1], y[t+k]) 42 |     where the "smooth" param is width of that window (2k+1) 43 |     """ 44 |     if smooth <= 1: 45 |         return data 46 |     x = np.asarray(data) 47 |     y = np.ones(smooth) 48 |     z = np.ones(len(x)) 49 |     smoothed_x = np.convolve(x,y,'same') / np.convolve(z,y,'same') 50 |     return smoothed_x 51 | 52 | 53 | def reload_data(data_paths): 54 |     """ 55 |     Iterate through the data folder and organize each experiment into a list, with its progress data and 56 |     hyper-parameters, and determine which hyper-parameters take distinct values across all the curves. 57 |     :param data_paths: Path of the folder storing all the data 58 |     :return [exps_data, plottable_keys, distinct_params] 59 |       exps_data: A list of the progress data for each curve. Each curve is an AttrDict with the key 60 |         'progress': A dictionary of plottable keys. The val of each key is an ndarray representing the 61 |             values of the key during training, or one column in the progress.txt file. 62 |         'params'/'flat_params': A dictionary of all hyperparameters recorded in 'variants.json' file. 63 |       plottable_keys: A list of strings representing all the keys that can be plotted. 64 |       distinct_params: A list of hyper-parameters which have different values among all the curves. This can be used 65 |         to split the graph into multiple figures. Each element is a tuple (param, list_of_values_to_take). 
66 | """ 67 | 68 | exps_data = copy.copy(core.load_exps_data(data_paths, disable_variant=False, ignore_missing_keys=True)) 69 | plottable_keys = copy.copy(sorted(list(set(flatten(list(exp.progress.keys()) for exp in exps_data))))) 70 | distinct_params = copy.copy(sorted(core.extract_distinct_params(exps_data))) 71 | 72 | return exps_data, plottable_keys, distinct_params 73 | 74 | 75 | def get_shaded_curve(selector, key, shade_type='variance'): 76 | """ 77 | :param selector: Selector for a group of curves 78 | :param shade_type: Should be either 'variance' or 'median', indicating how the shades are calculated. 79 | :return: [y, y_lower, y_upper], representing the mean, upper and lower boundary of the shaded region 80 | """ 81 | 82 | # First, get the progresses 83 | progresses = [exp.progress.get(key, np.array([np.nan])) for exp in selector.extract()] 84 | max_size = max(len(x) for x in progresses) 85 | progresses = [np.concatenate([ps, np.ones(max_size - len(ps)) * np.nan]) for ps in progresses] 86 | 87 | # Second, calculate the shaded area 88 | if shade_type == 'median': 89 | percentile25 = np.nanpercentile( 90 | progresses, q=25, axis=0) 91 | percentile50 = np.nanpercentile( 92 | progresses, q=50, axis=0) 93 | percentile75 = np.nanpercentile( 94 | progresses, q=75, axis=0) 95 | 96 | y = list(percentile50) 97 | y_upper = list(percentile75) 98 | y_lower = list(percentile25) 99 | elif shade_type == 'variance': 100 | means = np.nanmean(progresses, axis=0) 101 | stds = np.nanstd(progresses, axis=0) 102 | 103 | y = list(means) 104 | y_upper = list(means + stds) 105 | y_lower = list(means - stds) 106 | else: 107 | raise NotImplementedError 108 | 109 | return y, y_lower, y_upper 110 | 111 | 112 | def get_group_selectors(exps, custom_series_splitter): 113 | """Get selectors, a custom rllab class. 114 | 115 | Example: 116 | splitted_dict['Reduced State Oracle (SAC)'] = [{dict1}, {dict2},...] 117 | Each `dictk` has data from one `progress.csv`, created from one RL run. 118 | 119 | :param exps: list of experiments, each is of `rllab.misc.ext.AttrDict` type. 120 | IDK why they need that. The keys are 'progress' (which is loaded from the 121 | csv), 'params' and 'flat_params'. The 'params' and 'flat_params' seem to 122 | only differ based on the latter having our new `env_kwargs_{...}` stuff. 123 | :param custom_series_splitter: custom function defined to extract the algorithm 124 | and other info, and produce a label for the legend. 125 | :return: A tuple of (list,list) type, containing the selectors and legends. 
126 | """ 127 | splitted_dict = dict() 128 | for exp in exps: 129 | # Group exps by their series_splitter key 130 | # splitted_dict: {key:[exp1, exp2, ...]} 131 | key = custom_series_splitter(exp) 132 | if key not in splitted_dict: 133 | splitted_dict[key] = list() 134 | splitted_dict[key].append(exp) 135 | 136 | splitted = list(splitted_dict.items()) # list of tuples, each tuple is (key,val) 137 | # Group selectors: All the exps in one of the keys/legends 138 | # Group legends: All the different legends 139 | group_selectors = [core.Selector(list(x[1])) for x in splitted] # x[1]: list of progress.csv dicts 140 | group_legends = [x[0] for x in splitted] # x[0] is key, e.g., 'Reduced State Oracle (SAC)' 141 | all_tuples = sorted(list(zip(group_selectors, group_legends)), key=lambda x: x[1], reverse=True) 142 | group_selectors = [x[0] for x in all_tuples] 143 | group_legends = [x[1] for x in all_tuples] 144 | return group_selectors, group_legends 145 | 146 | 147 | def filter_save_name(save_name): 148 | save_name = save_name.replace('/', '_') 149 | save_name = save_name.replace('[', '_') 150 | save_name = save_name.replace(']', '_') 151 | save_name = save_name.replace('(', '_') 152 | save_name = save_name.replace(')', '_') 153 | save_name = save_name.replace(',', '_') 154 | save_name = save_name.replace(' ', '_') 155 | save_name = save_name.replace('0.', '0_') 156 | return save_name 157 | 158 | 159 | def sliding_mean(data_array, window=5): 160 | data_array = np.array(data_array) 161 | new_list = [] 162 | for i in range(len(data_array)): 163 | indices = list(range(max(i - window + 1, 0), 164 | min(i + window + 1, len(data_array)))) 165 | avg = 0 166 | for j in indices: 167 | avg += data_array[j] 168 | avg /= float(len(indices)) 169 | new_list.append(avg) 170 | 171 | return np.array(new_list) 172 | 173 | 174 | if __name__ == '__main__': 175 | data_path = '/Users/Dora/Projects/baselines_hrl/data/seuss/visual_rss_RopeFloat_0407' 176 | exps_data, plottable_keys, distinct_params = reload_data(data_path) 177 | 178 | # Example of extracting a single curve 179 | selector = core.Selector(exps_data) 180 | selector = selector.where('her_replay_strategy', 'balance_filter') 181 | y, y_lower, y_upper = get_shaded_curve(selector, 'test/success_state') 182 | _, ax = plt.subplots() 183 | 184 | color = core.color_defaults[0] 185 | ax.fill_between(range(len(y)), y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, alpha=0.2) 186 | ax.plot(range(len(y)), y, color=color, label=plt.legend, linewidth=2.0) 187 | 188 | 189 | # Example of extracting all the curves 190 | def custom_series_splitter(x): 191 | params = x['flat_params'] 192 | if 'use_ae_reward' in params and params['use_ae_reward']: 193 | return 'Auto Encoder' 194 | if params['her_replay_strategy'] == 'balance_filter': 195 | return 'Indicator+Balance+Filter' 196 | if params['env_kwargs.use_true_reward']: 197 | return 'Oracle' 198 | return 'Indicator' 199 | 200 | 201 | fig, ax = plt.subplots(figsize=(8, 5)) 202 | 203 | group_selectors, group_legends = get_group_selectors(exps_data, custom_series_splitter) 204 | for idx, (selector, legend) in enumerate(zip(group_selectors, group_legends)): 205 | color = core.color_defaults[idx] 206 | 207 | y, y_lower, y_upper = get_shaded_curve(selector, 'test/success_state') 208 | 209 | ax.plot(range(len(y)), y, color=color, label=legend, linewidth=2.0) 210 | ax.fill_between(range(len(y)), y_lower, y_upper, interpolate=True, facecolor=color, linewidth=0.0, alpha=0.2) 211 | ax.grid(True) 212 | 
ax.set_xlabel('Timesteps') 213 | ax.set_ylabel('Success') 214 | loc = 'best' 215 | leg = ax.legend(loc=loc, prop={'size': 15}, ncol=1, labels=group_legends) 216 | for legobj in leg.legendHandles: 217 | legobj.set_linewidth(3.0) 218 | plt.savefig('test.png', bbox_inches='tight') 219 | -------------------------------------------------------------------------------- /tests/test_pt3d_pyquaternion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 53, 6 | "id": "6b71c9f3-5edc-458d-9957-e74c7e0fd786", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import math\n", 11 | "import numpy as np\n", 12 | "import torch\n", 13 | "from pyquaternion import Quaternion\n", 14 | "from pytorch3d.transforms import matrix_to_quaternion, quaternion_to_axis_angle, axis_angle_to_matrix, Rotate\n", 15 | "\n", 16 | "from bc.se3 import flow2pose" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 25, 22 | "id": "39b8afcb-f566-48e4-bf53-5dfe69706ac3", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "axis = [0, math.sqrt(2) / 2, math.sqrt(2) / 2]\n", 27 | "angle = math.pi / 2" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 26, 33 | "id": "047facaf-c0e6-4e36-9150-fcdb27cf862a", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "quat = Quaternion(axis=axis, angle=angle)\n", 38 | "points = np.vstack(np.meshgrid(np.linspace(-1, 1, 10), np.linspace(-1, 1, 10), np.linspace(-1, 1, 10))).reshape(3, -1).T" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 27, 44 | "id": "fc97695c-75d9-42fb-94a1-7ba7ec108712", 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "text/plain": [ 50 | "array([[ 1.11022302e-16, -7.07106781e-01, 7.07106781e-01],\n", 51 | " [ 7.07106781e-01, 5.00000000e-01, 5.00000000e-01],\n", 52 | " [-7.07106781e-01, 5.00000000e-01, 5.00000000e-01]])" 53 | ] 54 | }, 55 | "execution_count": 27, 56 | "metadata": {}, 57 | "output_type": "execute_result" 58 | } 59 | ], 60 | "source": [ 61 | "quat.rotation_matrix" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 28, 67 | "id": "e17fedfa-08a3-44e2-b374-7760749d15dc", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "flow = np.zeros_like(points)\n", 72 | "for i in range(points.shape[0]):\n", 73 | " pt = points[i]\n", 74 | " rot = quat.rotate(pt)\n", 75 | " flow[i] += rot - pt\n", 76 | "\n", 77 | "points = torch.from_numpy(points.astype(np.float32))\n", 78 | "flow = torch.from_numpy(flow.astype(np.float32))" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 29, 84 | "id": "be73a4a7-51ab-41e2-ab79-10341a332637", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "trfm = flow2pose(\n", 89 | " xyz=points[None, :],\n", 90 | " flow=flow[None, :],\n", 91 | " weights=None,\n", 92 | " return_transform3d=True,\n", 93 | " return_quaternions=False,\n", 94 | ")" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 30, 100 | "id": "f43a6f21-b47a-471d-9736-b8d7d477b56a", 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "True" 107 | ] 108 | }, 109 | "execution_count": 30, 110 | "metadata": {}, 111 | "output_type": "execute_result" 112 | } 113 | ], 114 | "source": [ 115 | "pred_flow = trfm.transform_points(points).squeeze(0) - points\n", 116 | "torch.allclose(flow, pred_flow, atol=1e-4)" 117 | ] 118 | }, 119 | { 120 | 
"cell_type": "code", 121 | "execution_count": 31, 122 | "id": "90a38c7b-70c2-4b68-941b-8d9a09ceee16", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "rot_matrices, trans = flow2pose(\n", 127 | " xyz=points[None, :],\n", 128 | " flow=flow[None, :],\n", 129 | " weights=None,\n", 130 | " return_transform3d=False,\n", 131 | " return_quaternions=False,\n", 132 | " world_frameify=False,\n", 133 | ")" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 32, 139 | "id": "41559b3c-c7f6-4823-9ca4-557c4fe92d5d", 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "data": { 144 | "text/plain": [ 145 | "tensor([[[ 0.0000, 0.7071, -0.7071],\n", 146 | " [-0.7071, 0.5000, 0.5000],\n", 147 | " [ 0.7071, 0.5000, 0.5000]]])" 148 | ] 149 | }, 150 | "execution_count": 32, 151 | "metadata": {}, 152 | "output_type": "execute_result" 153 | } 154 | ], 155 | "source": [ 156 | "rot_matrices" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 33, 162 | "id": "afffab03-032b-4625-9687-f8838c804a78", 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "quats = matrix_to_quaternion(matrix=rot_matrices.transpose(1, 2))\n", 167 | "axis_ang = quaternion_to_axis_angle(quaternions=quats)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 34, 173 | "id": "6ab53033-a75a-4d6b-96b4-33801c3bb8a5", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "pred_axis = axis_ang / torch.linalg.norm(axis_ang)\n", 178 | "pred_angle = torch.linalg.norm(axis_ang)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 35, 184 | "id": "a2d19df3-9599-40f9-82ef-9725d2cbcad4", 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "name": "stdout", 189 | "output_type": "stream", 190 | "text": [ 191 | "tensor([[5.9605e-08, 7.0711e-01, 7.0711e-01]])\n", 192 | "tensor(1.5708)\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "print(pred_axis)\n", 198 | "print(pred_angle)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 36, 204 | "id": "c0663f71-3027-4a00-b205-69a1db9d5dbe", 205 | "metadata": {}, 206 | "outputs": [ 207 | { 208 | "data": { 209 | "text/plain": [ 210 | "True" 211 | ] 212 | }, 213 | "execution_count": 36, 214 | "metadata": {}, 215 | "output_type": "execute_result" 216 | } 217 | ], 218 | "source": [ 219 | "pred_flow = torch.bmm(points.unsqueeze(0), rot_matrices).squeeze(0) - points\n", 220 | "torch.allclose(flow, pred_flow, atol=1e-4)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 37, 226 | "id": "130c57bd-20e5-4676-84dc-5f1c34e85972", 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "data": { 231 | "text/plain": [ 232 | "torch.Size([1, 3])" 233 | ] 234 | }, 235 | "execution_count": 37, 236 | "metadata": {}, 237 | "output_type": "execute_result" 238 | } 239 | ], 240 | "source": [ 241 | "trans.shape" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 38, 247 | "id": "80180df0-a846-405b-9220-b9970295fa3e", 248 | "metadata": {}, 249 | "outputs": [ 250 | { 251 | "data": { 252 | "text/plain": [ 253 | "torch.Size([1, 3])" 254 | ] 255 | }, 256 | "execution_count": 38, 257 | "metadata": {}, 258 | "output_type": "execute_result" 259 | } 260 | ], 261 | "source": [ 262 | "torch.mean(flow, axis=0, keepdims=True).shape" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "id": "44c2d7d4-a27a-47d4-b5c2-42fb77450855", 268 | "metadata": {}, 269 | "source": [ 270 | "Dense Trfm Testing\n", 271 | 
"==================" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 61, 277 | "id": "8a1ac814-eae4-4705-8a57-9596466b4e1c", 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "axis_ang = torch.tensor([axis]) * angle\n", 282 | "rot_matrix = axis_angle_to_matrix(axis_ang).transpose(1, 2)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 62, 288 | "id": "3d854473-c848-4b67-b553-3b48f624b53a", 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/plain": [ 294 | "True" 295 | ] 296 | }, 297 | "execution_count": 62, 298 | "metadata": {}, 299 | "output_type": "execute_result" 300 | } 301 | ], 302 | "source": [ 303 | "pred_flow = torch.bmm(points.unsqueeze(0), rot_matrix).squeeze(0) - points\n", 304 | "torch.allclose(flow, pred_flow, atol=1e-4)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 63, 310 | "id": "da168b0b-5e14-4794-be24-9b32ea0ed47e", 311 | "metadata": {}, 312 | "outputs": [ 313 | { 314 | "data": { 315 | "text/plain": [ 316 | "True" 317 | ] 318 | }, 319 | "execution_count": 63, 320 | "metadata": {}, 321 | "output_type": "execute_result" 322 | } 323 | ], 324 | "source": [ 325 | "trfm = Rotate(rot_matrix)\n", 326 | "pred_flow = trfm.transform_points(points).squeeze(0) - points\n", 327 | "torch.allclose(flow, pred_flow, atol=1e-4)" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": null, 333 | "id": "6c0bc3e3-e16c-4522-8521-5918f457c0ab", 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [] 337 | } 338 | ], 339 | "metadata": { 340 | "kernelspec": { 341 | "display_name": "softgym", 342 | "language": "python", 343 | "name": "softgym" 344 | }, 345 | "language_info": { 346 | "codemirror_mode": { 347 | "name": "ipython", 348 | "version": 3 349 | }, 350 | "file_extension": ".py", 351 | "mimetype": "text/x-python", 352 | "name": "python", 353 | "nbconvert_exporter": "python", 354 | "pygments_lexer": "ipython3", 355 | "version": "3.6.13" 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 5 360 | } 361 | -------------------------------------------------------------------------------- /rpmg/rpmg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | BASEPATH = os.path.dirname(__file__) 5 | sys.path.append(BASEPATH) 6 | import tools 7 | 8 | def Rodrigues(w): 9 | ''' 10 | axis angle -> rotation 11 | :param w: [b,3] 12 | :return: R: [b,3,3] 13 | ''' 14 | w = w.unsqueeze(2).unsqueeze(3).repeat(1, 1, 3, 3) 15 | b = w.shape[0] 16 | theta = w.norm(dim=1) 17 | #print(theta[0]) 18 | #theta = torch.where(t>math.pi/16, torch.Tensor([math.pi/16]).cuda(), t) 19 | wnorm = w / (w.norm(dim=1,keepdim=True)+0.001) 20 | #wnorm = torch.nn.functional.normalize(w,dim=1) 21 | I = torch.eye(3, device=w.get_device()).repeat(b, 1, 1) 22 | help1 = torch.zeros((b,1,3, 3), device=w.get_device()) 23 | help2 = torch.zeros((b,1,3, 3), device=w.get_device()) 24 | help3 = torch.zeros((b,1,3, 3), device=w.get_device()) 25 | help1[:,:,1, 2] = -1 26 | help1[:,:,2, 1] = 1 27 | help2[:,:,0, 2] = 1 28 | help2[:,:,2, 0] = -1 29 | help3[:,:,0, 1] = -1 30 | help3[:,:,1, 0] = 1 31 | Jwnorm = (torch.cat([help1,help2,help3],1)*wnorm).sum(dim=1) 32 | 33 | return I + torch.sin(theta) * Jwnorm + (1 - torch.cos(theta)) * torch.bmm(Jwnorm, Jwnorm) 34 | 35 | logger = 0 36 | def logger_init(ll): 37 | global logger 38 | logger = ll 39 | print('logger init') 40 | 41 | class RPMG(torch.autograd.Function): 
42 |     ''' 43 |     Full version. See "simple_RPMG()" for a simplified version. 44 |     Tips: 45 |     1. Use "logger_init()" to initialize the logger, if you want to record some intermediate variables with tensorboard. 46 |     2. Use a sum of the L2/geodesic loss instead of the mean, since our tau_converge is derived without considering the scaling introduced by a mean loss. 47 |        See for an example. 48 |     3. Pass "weight=$YOUR_WEIGHT" instead of directly multiplying the weight into the rotation loss, if you want to reweight the R loss against other losses. 49 |        See for an example. 50 |     ''' 51 |     @staticmethod 52 |     def forward(ctx, in_nd, tau, lam, rgt, iter, weight=1): 53 |         proj_kind = in_nd.shape[1] 54 |         if proj_kind == 6: 55 |             r0 = tools.compute_rotation_matrix_from_ortho6d(in_nd) 56 |         elif proj_kind == 9: 57 |             r0 = tools.symmetric_orthogonalization(in_nd) 58 |         elif proj_kind == 4: 59 |             r0 = tools.compute_rotation_matrix_from_quaternion(in_nd) 60 |         elif proj_kind == 10: 61 |             r0 = tools.compute_rotation_matrix_from_10d(in_nd) 62 |         else: 63 |             raise NotImplementedError 64 |         ctx.save_for_backward(in_nd, r0, torch.Tensor([tau, lam, iter, weight]), rgt) 65 |         return r0 66 | 67 |     @staticmethod 68 |     def backward(ctx, grad_in): 69 |         in_nd, r0, config, rgt = ctx.saved_tensors 70 |         tau = config[0] 71 |         lam = config[1] 72 |         b = r0.shape[0] 73 |         iter = config[2] 74 |         weight = config[3] 75 |         proj_kind = in_nd.shape[1] 76 | 77 |         # use Riemannian optimization to get the next goal R 78 |         if tau == -1: 79 |             r_new = rgt 80 |         else: 81 |             # Euclidean gradient -> Riemannian gradient 82 |             Jx = torch.zeros((b, 3, 3)).cuda() 83 |             Jx[:, 2, 1] = 1 84 |             Jx[:, 1, 2] = -1 85 |             Jy = torch.zeros((b, 3, 3)).cuda() 86 |             Jy[:, 0, 2] = 1 87 |             Jy[:, 2, 0] = -1 88 |             Jz = torch.zeros((b, 3, 3)).cuda() 89 |             Jz[:, 0, 1] = -1 90 |             Jz[:, 1, 0] = 1 91 |             gx = (grad_in*torch.bmm(r0, Jx)).reshape(-1,9).sum(dim=1,keepdim=True) 92 |             gy = (grad_in * torch.bmm(r0, Jy)).reshape(-1, 9).sum(dim=1,keepdim=True) 93 |             gz = (grad_in * torch.bmm(r0, Jz)).reshape(-1, 9).sum(dim=1,keepdim=True) 94 |             g = torch.cat([gx,gy,gz],1) 95 | 96 |             # take one step 97 |             delta_w = -tau * g 98 | 99 |             # update R 100 |             r_new = torch.bmm(r0, Rodrigues(delta_w)) 101 | 102 |         # this can help you tune tau if you don't use the L2/geodesic loss. 
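        # delta_w.norm(dim=1) is the per-sample angle (radians) of each Riemannian update step,
        # and r0_rgt_angle is the geodesic distance between the prediction r0 and the ground truth.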
103 | if iter % 100 == 0: 104 | logger.add_scalar('next_goal_angle_mean', delta_w.norm(dim=1).mean(), iter) 105 | logger.add_scalar('next_goal_angle_max', delta_w.norm(dim=1).max(), iter) 106 | R0_Rgt = tools.compute_geodesic_distance_from_two_matrices(r0, rgt) 107 | logger.add_scalar('r0_rgt_angle', R0_Rgt.mean(), iter) 108 | 109 | # inverse & project 110 | if proj_kind == 6: 111 | r_proj_1 = (r_new[:, :, 0] * in_nd[:, :3]).sum(dim=1, keepdim=True) * r_new[:, :, 0] 112 | r_proj_2 = (r_new[:, :, 0] * in_nd[:, 3:]).sum(dim=1, keepdim=True) * r_new[:, :, 0] \ 113 | + (r_new[:, :, 1] * in_nd[:, 3:]).sum(dim=1, keepdim=True) * r_new[:, :, 1] 114 | r_reg_1 = lam * (r_proj_1 - r_new[:, :, 0]) 115 | r_reg_2 = lam * (r_proj_2 - r_new[:, :, 1]) 116 | gradient_nd = torch.cat([in_nd[:, :3] - r_proj_1 + r_reg_1, in_nd[:, 3:] - r_proj_2 + r_reg_2], 1) 117 | elif proj_kind == 9: 118 | SVD_proj = tools.compute_SVD_nearest_Mnlsew(in_nd.reshape(-1,3,3), r_new) 119 | gradient_nd = in_nd - SVD_proj + lam * (SVD_proj - r_new.reshape(-1,9)) 120 | R_proj_g = tools.symmetric_orthogonalization(SVD_proj) 121 | if iter % 100 == 0: 122 | logger.add_scalar('9d_reflection', (((R_proj_g-r_new).reshape(-1,9).abs().sum(dim=1))>5e-1).sum(), iter) 123 | logger.add_scalar('reg', (SVD_proj - r_new.reshape(-1, 9)).norm(dim=1).mean(), iter) 124 | logger.add_scalar('main', (in_nd - SVD_proj).norm(dim=1).mean(), iter) 125 | elif proj_kind == 4: 126 | q_1 = tools.compute_quaternions_from_rotation_matrices(r_new) 127 | q_2 = -q_1 128 | normalized_nd = tools.normalize_vector(in_nd) 129 | q_new = torch.where( 130 | (q_1 - normalized_nd).norm(dim=1, keepdim=True) < (q_2 - normalized_nd).norm(dim=1, keepdim=True), 131 | q_1, q_2) 132 | q_proj = (in_nd * q_new).sum(dim=1, keepdim=True) * q_new 133 | gradient_nd = in_nd - q_proj + lam * (q_proj - q_new) 134 | elif proj_kind == 10: 135 | qg = tools.compute_quaternions_from_rotation_matrices(r_new) 136 | new_x = tools.compute_nearest_10d(in_nd, qg) 137 | reg_A = torch.eye(4, device=qg.device)[None].repeat(qg.shape[0],1,1) - torch.bmm(qg.unsqueeze(-1), qg.unsqueeze(-2)) 138 | reg_x = tools.convert_A_to_Avec(reg_A) 139 | gradient_nd = in_nd - new_x + lam * (new_x - reg_x) 140 | if iter % 100 == 0: 141 | logger.add_scalar('reg', (new_x - reg_x).norm(dim=1).mean(), iter) 142 | logger.add_scalar('main', (in_nd - new_x).norm(dim=1).mean(), iter) 143 | 144 | return gradient_nd * weight, None, None,None,None,None 145 | 146 | 147 | 148 | class simple_RPMG(torch.autograd.Function): 149 | ''' 150 | simplified version without tensorboard and r_gt. 151 | ''' 152 | @staticmethod 153 | def forward(ctx, in_nd, tau, lam, weight=1, rgt=None): 154 | proj_kind = in_nd.shape[1] 155 | if proj_kind == 6: 156 | r0 = tools.compute_rotation_matrix_from_ortho6d(in_nd) 157 | elif proj_kind == 9: 158 | r0 = tools.symmetric_orthogonalization(in_nd) 159 | elif proj_kind == 4: 160 | r0 = tools.compute_rotation_matrix_from_quaternion(in_nd) 161 | elif proj_kind == 10: 162 | r0 = tools.compute_rotation_matrix_from_10d(in_nd) 163 | else: 164 | raise NotImplementedError 165 | 166 | if rgt is None: 167 | rgt = torch.zeros_like(r0) 168 | 169 | ctx.save_for_backward(in_nd, r0, torch.Tensor([tau,lam, weight]), rgt) 170 | # return r0.transpose(-1, -2) # TODO(daniel): getting rid of this? 171 | return r0 172 | 173 | @staticmethod 174 | def backward(ctx, grad_in): 175 | # grad_in = grad_in.transpose(-1, -2) # TODO(daniel): getting rid of this? 
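        # Recover what forward() stashed: the raw n-d prediction, its projected
        # rotation, the (tau, lam, weight) settings, and the optional ground-truth rotation.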
176 |         in_nd, r0, config, rgt = ctx.saved_tensors 177 |         tau = config[0] 178 |         lam = config[1] 179 |         weight = config[2] 180 |         b = r0.shape[0] 181 |         proj_kind = in_nd.shape[1] 182 | 183 |         # use Riemannian optimization to get the next goal R 184 |         if tau == -1: 185 |             # use r_gt directly if tau == -1 186 |             # requires rgt to be passed in initially 187 |             r_new = rgt 188 |         else: 189 |             # Euclidean gradient -> Riemannian gradient 190 |             Jx = torch.zeros((b, 3, 3)).cuda() 191 |             Jx[:, 2, 1] = 1 192 |             Jx[:, 1, 2] = -1 193 |             Jy = torch.zeros((b, 3, 3)).cuda() 194 |             Jy[:, 0, 2] = 1 195 |             Jy[:, 2, 0] = -1 196 |             Jz = torch.zeros((b, 3, 3)).cuda() 197 |             Jz[:, 0, 1] = -1 198 |             Jz[:, 1, 0] = 1 199 |             gx = (grad_in*torch.bmm(r0, Jx)).reshape(-1,9).sum(dim=1,keepdim=True) 200 |             gy = (grad_in * torch.bmm(r0, Jy)).reshape(-1, 9).sum(dim=1,keepdim=True) 201 |             gz = (grad_in * torch.bmm(r0, Jz)).reshape(-1, 9).sum(dim=1,keepdim=True) 202 |             g = torch.cat([gx,gy,gz],1) 203 | 204 |             # take one step 205 |             delta_w = -tau * g 206 | 207 |             # update R 208 |             r_new = torch.bmm(r0, Rodrigues(delta_w)) 209 | 210 |         # inverse & project 211 |         if proj_kind == 6: 212 |             r_proj_1 = (r_new[:, :, 0] * in_nd[:, :3]).sum(dim=1, keepdim=True) * r_new[:, :, 0] 213 |             r_proj_2 = (r_new[:, :, 0] * in_nd[:, 3:]).sum(dim=1, keepdim=True) * r_new[:, :, 0] \ 214 |                        + (r_new[:, :, 1] * in_nd[:, 3:]).sum(dim=1, keepdim=True) * r_new[:, :, 1] 215 |             r_reg_1 = lam * (r_proj_1 - r_new[:, :, 0]) 216 |             r_reg_2 = lam * (r_proj_2 - r_new[:, :, 1]) 217 |             gradient_nd = torch.cat([in_nd[:, :3] - r_proj_1 + r_reg_1, in_nd[:, 3:] - r_proj_2 + r_reg_2], 1) 218 |         elif proj_kind == 9: 219 |             SVD_proj = tools.compute_SVD_nearest_Mnlsew(in_nd.reshape(-1,3,3), r_new) 220 |             gradient_nd = in_nd - SVD_proj + lam * (SVD_proj - r_new.reshape(-1,9)) 221 |         elif proj_kind == 4: 222 |             q_1 = tools.compute_quaternions_from_rotation_matrices(r_new) 223 |             q_2 = -q_1 224 |             normalized_nd = tools.normalize_vector(in_nd) 225 |             q_new = torch.where( 226 |                 (q_1 - normalized_nd).norm(dim=1, keepdim=True) < (q_2 - normalized_nd).norm(dim=1, keepdim=True), 227 |                 q_1, q_2) 228 |             q_proj = (in_nd * q_new).sum(dim=1, keepdim=True) * q_new 229 |             gradient_nd = in_nd - q_proj + lam * (q_proj - q_new) 230 |         elif proj_kind == 10: 231 |             qg = tools.compute_quaternions_from_rotation_matrices(r_new) 232 |             new_x = tools.compute_nearest_10d(in_nd, qg) 233 |             reg_A = torch.eye(4, device=qg.device)[None].repeat(qg.shape[0],1,1) - torch.bmm(qg.unsqueeze(-1), qg.unsqueeze(-2)) 234 |             reg_x = tools.convert_A_to_Avec(reg_A) 235 |             gradient_nd = in_nd - new_x + lam * (new_x - reg_x) 236 | 237 |         return gradient_nd * weight, None, None, None, None, None 238 | -------------------------------------------------------------------------------- /bc/pointnet2_classification.py: -------------------------------------------------------------------------------- 1 | # https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pointnet2_classification.py 2 | import torch 3 | from torch_geometric.nn import MLP, PointConv, fps, global_max_pool, radius 4 | torch.set_printoptions(linewidth=180, precision=5) 5 | from pytorch3d.transforms import (Rotate, axis_angle_to_matrix) 6 | from rpmg.rpmg import (RPMG, simple_RPMG) 7 | from rpmg.rpmg_losses import rpmg_forward 8 | 9 | 10 | class SAModule(torch.nn.Module): 11 |     def __init__(self, ratio, r, nn): 12 |         super().__init__() 13 |         self.ratio = ratio 14 |         self.r = r 15 |         self.conv = PointConv(nn, add_self_loops=False) 16 | 17 |     def forward(self, x, pos, batch): 18 |         idx = fps(pos, batch, 
ratio=self.ratio) 19 | row, col = radius(pos, pos[idx], self.r, batch, batch[idx], 20 | max_num_neighbors=64) 21 | edge_index = torch.stack([col, row], dim=0) 22 | x_dst = None if x is None else x[idx] 23 | x = self.conv((x, x_dst), (pos, pos[idx]), edge_index) 24 | pos, batch = pos[idx], batch[idx] 25 | return x, pos, batch 26 | 27 | 28 | class GlobalSAModule(torch.nn.Module): 29 | def __init__(self, nn): 30 | super().__init__() 31 | self.nn = nn 32 | 33 | def forward(self, x, pos, batch): 34 | x = self.nn(torch.cat([x, pos], dim=1)) 35 | x = global_max_pool(x, batch) 36 | pos = pos.new_zeros((x.size(0), 3)) 37 | batch = torch.arange(x.size(0), device=batch.device) 38 | return x, pos, batch 39 | 40 | 41 | class PointNet2_Class(torch.nn.Module): 42 | """NOTE(daniel): PointNet++ from the NeurIPS 2017 paper. 43 | 44 | We need to adjust the input dimension to take into account our `data.x` 45 | which has the features, whereas `data.pos` is known to be 3D. Other than 46 | those changes, I haven't made changes compared to standard classification 47 | PointNet++. 48 | 49 | We need to construct `data.x, data.pos, data.batch` for the forward pass. 50 | For now I do this by constructing `Data` tuples when we have minibatches. 51 | PointNet++ may require smaller minibatch sizes compared to CNNs. On PyG, 52 | each PC for the ModelNet10 data is of dim (1024,3), and with a batch size 53 | of 32 it was taking ~5G of GPU RAM. 54 | 55 | 05/19/2022: if we assume geodesic distance, we must normalize output. 56 | 05/27/2022: adding special case of encoder type with classification but 57 | if we use pointwise loss. We use the per-point flow instead, which 58 | has the same minimum as the per-point MSE on the future tool. 59 | 05/31/2022: ah, fixed bug: I was not re-scaling the PCL values back to 60 | the original value, as I was in segm PN++, we should do that for a 61 | fair comparison with: {class PN++ and pointwise loss}. 62 | 06/02/2022: actually for pointwise baseline we really should remove any 63 | non-used rotations before we convert to axis-angle. That should help 64 | with training. 65 | """ 66 | 67 | def __init__(self, in_dim, out_dim, encoder_type, scale_pcl_val=None, 68 | normalize_quat=False, n_epochs=None, rpmg_lambda=None, lambda_rot=None): 69 | super().__init__() 70 | self.in_dim = in_dim # 3 for pos, then rest for segmentation 71 | self.out_dim = out_dim # the action dimension 72 | self.encoder_type = encoder_type 73 | self.scale_pcl_val = scale_pcl_val 74 | self.normalize_quat = normalize_quat 75 | self.n_epochs = n_epochs 76 | self.rpmg_lambda = rpmg_lambda 77 | self.lambda_rot = lambda_rot 78 | self.raw_out = None 79 | self._mask = None 80 | 81 | # Input channels account for both `pos` and node features. 82 | self.sa1_module = SAModule(0.5, 0.2, MLP([self.in_dim, 64, 64, 128])) 83 | self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256])) 84 | self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024])) 85 | 86 | self.mlp = MLP([1024, 512, 256, out_dim], dropout=0.5, batch_norm=False) 87 | 88 | def assign_clear_mask(self, mask): 89 | """ 90 | For pouring or scooping, in case we have a 6D vector and need to extract 91 | a transformation from it here (and want to clear out unused components). 92 | """ 93 | self._mask = mask 94 | 95 | def forward(self, data, info=None, epoch=None, rgt=None): 96 | """Standard forward pass, except we might need a custom `data`. 
97 | 98 | Just return the MLP output, no log-softmax as there's no classification, 99 | we are either doing regression or segmentation. This is for regression. 100 | If we are using geodesic distance, we assume last 4 parts form a quat, 101 | so we normalize the output. 102 | """ 103 | 104 | # Special case if we scaled everything (e.g., meter -> millimeter), need 105 | # to downscale the observation because otherwise PN++ produces NaNs. :( 106 | if self.scale_pcl_val is not None: 107 | data.pos /= self.scale_pcl_val 108 | 109 | sa0_out = (data.x, data.pos, data.batch) 110 | sa1_out = self.sa1_module(*sa0_out) 111 | sa2_out = self.sa2_module(*sa1_out) 112 | sa3_out = self.sa3_module(*sa2_out) 113 | x, _, _ = sa3_out 114 | out = self.mlp(x) 115 | 116 | # Should revert `data.pos` back to the original scale. 117 | if self.scale_pcl_val is not None: 118 | data.pos *= self.scale_pcl_val 119 | 120 | if self.normalize_quat: 121 | # shapes in operation: (batch_size, 4) / (batch_size, 1) 122 | assert self.out_dim == 7, self.out_dim 123 | out_q = out[:,3:] / torch.norm(out[:,3:], dim=1, keepdim=True) 124 | out = torch.cat([out[:,:3], out_q], dim=1) 125 | 126 | if self.encoder_type == 'pointnet_classif_6D_pointwise': 127 | # In this special case, we actually want to return flow, which 128 | # we can compare with ground truth to determine pointwise loss. 129 | # During inference time we need a special case to just take `out` 130 | # computed earlier. Similar to 'pointnet_dense_tf_6D_pointwise' 131 | # but dense_transf is directly the _only_ output of the PN++. 132 | self.raw_out = out 133 | 134 | # Get rotation center (scooping or pouring) for translation correction. 135 | if len(info.shape) == 1: 136 | info = info.unsqueeze(0) 137 | tool_tip_pos = info[:,:3] 138 | 139 | # The network doesn't predict flow, but we use pointwise matching loss. 140 | # Given (o(t), a(t)), we apply the predicted transform to o(t) to get predicted flow. 141 | flows = [] 142 | 143 | for i in range(len(data.ptr)-1): 144 | tip_pos_one = tool_tip_pos[i][None, None, :] # (1,1,3) 145 | idx1 = data.ptr[i].detach().cpu().numpy().item() 146 | idx2 = data.ptr[i+1].detach().cpu().numpy().item() 147 | posi_one = data.pos[idx1:idx2] # just this PCL's xyz (tool+items). 148 | 149 | # Compute a transformation from just `dense_transf`, no SVD. 150 | dense_transf = self.raw_out[i] * self._mask # note the mask! 151 | mean_one = dense_transf[:3] 152 | rot_matrices = axis_angle_to_matrix(dense_transf[3:]).transpose(0, 1).unsqueeze(0) 153 | 154 | # Correct rotation frame. 155 | t_correction = (tip_pos_one - torch.bmm(tip_pos_one, rot_matrices)).squeeze(1) # (1,3) 156 | 157 | # Compute a transformation. 158 | trfm = Rotate(rot_matrices).translate(mean_one[None,:] + t_correction) 159 | 160 | # Applies on ALL points, but later in loss fxn, we filter by tool. 161 | trfm_xyz = trfm.transform_points(posi_one).squeeze(0) 162 | batch_flow = trfm_xyz - posi_one 163 | flows.append(batch_flow) 164 | 165 | return torch.cat(flows) 166 | elif self.encoder_type == 'pointnet_rpmg_pointwise': 167 | # WARNING: DO NOT USE WITH INTRINSIC ROTATIONS! 168 | out_rot = simple_RPMG.apply(out[:, 3:], 50., self.rpmg_lambda, self.lambda_rot) 169 | out_rot_raw = out_rot.reshape(-1, 9) 170 | self.raw_out = torch.cat((out[:, :3], out_rot_raw), dim=1) 171 | 172 | # Get rotation center (scooping or pouring) for translation correction.
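# A note on the correction used below (same math as in the branch above): if the
# rotation should be about the tool tip p rather than the origin, then with
# row-vector points and right-multiplication, x' = (x - p) R + p + t, which
# equals x R + t' with t' = t + (p - p R). The (p - p R) term is exactly what
# `t_correction` computes.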
173 | if len(info.shape) == 1: 174 | info = info.unsqueeze(0) 175 | tool_tip_pos = info[:,:3] 176 | 177 | # The network doesn't predict flow, but we use pointwise matching loss. 178 | # Given (o(t), a(t)), we apply the predicted transform to o(t) to get predicted flow. 179 | flows = [] 180 | 181 | for i in range(len(data.ptr)-1): 182 | tip_pos_one = tool_tip_pos[i][None, None, :] # (1,1,3) 183 | idx1 = data.ptr[i].detach().cpu().numpy().item() 184 | idx2 = data.ptr[i+1].detach().cpu().numpy().item() 185 | posi_one = data.pos[idx1:idx2] # just this PCL's xyz (tool+items). 186 | 187 | # Build the transformation from the predicted translation and RPMG rotation, no SVD. 188 | mean_one = out[i, :3] 189 | rot_matrices = out_rot[i].unsqueeze(0) 190 | 191 | # Correct rotation frame. 192 | t_correction = (tip_pos_one - torch.bmm(tip_pos_one, rot_matrices)).squeeze(1) # (1,3) 193 | 194 | # Compute a transformation. 195 | trfm = Rotate(rot_matrices).translate(mean_one[None,:] + t_correction) 196 | 197 | # Applies on ALL points, but later in loss fxn, we filter by tool. 198 | trfm_xyz = trfm.transform_points(posi_one).squeeze(0) 199 | batch_flow = trfm_xyz - posi_one 200 | flows.append(batch_flow) 201 | 202 | self.flow_per_pt = torch.cat(flows) 203 | self.flow_per_pt_r = None 204 | 205 | return torch.cat(flows) 206 | elif self.encoder_type == 'pointnet_rpmg': 207 | # We do the default thing for tau as in the RPMG paper. 208 | # More concretely, we linearly scale tau from tau_init=1/20 to tau_converge=1/4 209 | # as training progresses. See section 4.3 (in RPMG) and B.1 for a proof of 210 | # tau_converge=1/4. 211 | if epoch is None: 212 | tau = 1 / 4 213 | else: 214 | tau = 1 / 20 + (1 / 4 - 1 / 20) / 9 * min(epoch // (self.n_epochs // 10), 9) 215 | out_rot = simple_RPMG.apply(out[:, 3:], tau, self.rpmg_lambda, self.lambda_rot) 216 | out_rot = out_rot.reshape(-1, 9) 217 | 218 | out = torch.cat((out[:, :3], out_rot), dim=1) 219 | return out 220 | elif self.encoder_type == 'pointnet_rpmg_forward': 221 | # Brian/Chuer stand-alone method returns (batch,3,3). This is just the 222 | # forward pass and doesn't produce the proper gradient, use as a baseline only.
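# In other words (our reading of the code): rpmg_forward only projects the raw
# n-d output onto SO(3) in the forward direction, so autograd backpropagates
# through the projection ops themselves instead of substituting the RPMG
# manifold gradient that simple_RPMG.apply provides in the branches above.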
223 | out_rot = rpmg_forward(out[:, 3:]) 224 | out_rot = out_rot.reshape(-1, 9) 225 | 226 | # Concatenate --> (batch, 3+9) 227 | out = torch.cat((out[:, :3], out_rot), dim=1) 228 | return out 229 | elif self.encoder_type == 'pointnet_rpmg_taugt': 230 | if rgt is not None: 231 | rgt = rgt.reshape(-1, 12)[:, 3:] 232 | rgt = rgt.reshape(-1, 3, 3) 233 | 234 | out_rot = simple_RPMG.apply(out[:, 3:], -1, self.rpmg_lambda, self.lambda_rot, rgt) 235 | out_rot = out_rot.reshape(-1, 9) 236 | 237 | out = torch.cat((out[:, :3], out_rot), dim=1) 238 | return out 239 | else: 240 | return out 241 | -------------------------------------------------------------------------------- /chester/setup_ec2_for_chester.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import re 3 | import sys 4 | import json 5 | import botocore 6 | import os 7 | from rllab.misc import console 8 | from string import Template 9 | import os.path as osp 10 | 11 | CHESTER_DIR = osp.dirname(__file__) 12 | ACCESS_KEY = os.environ["AWS_ACCESS_KEY"] 13 | ACCESS_SECRET = os.environ["AWS_ACCESS_SECRET"] 14 | S3_BUCKET_NAME = os.environ["RLLAB_S3_BUCKET"] 15 | 16 | ALL_REGION_AWS_SECURITY_GROUP_IDS = {} 17 | ALL_REGION_AWS_KEY_NAMES = {} 18 | 19 | CONFIG_TEMPLATE = Template(""" 20 | import os.path as osp 21 | import os 22 | 23 | PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..')) 24 | 25 | AWS_NETWORK_INTERFACES = [] 26 | 27 | MUJOCO_KEY_PATH = osp.expanduser("~/.mujoco") 28 | 29 | USE_GPU = False 30 | 31 | USE_TF = True 32 | 33 | AWS_REGION_NAME = "us-east-2" 34 | 35 | if USE_GPU: 36 | DOCKER_IMAGE = "dementrock/rllab3-shared-gpu" 37 | else: 38 | DOCKER_IMAGE = "dementrock/rllab3-shared" 39 | 40 | DOCKER_LOG_DIR = "/tmp/expt" 41 | 42 | AWS_S3_PATH = "s3://$s3_bucket_name/rllab/experiments" 43 | 44 | AWS_CODE_SYNC_S3_PATH = "s3://$s3_bucket_name/rllab/code" 45 | 46 | ALL_REGION_AWS_IMAGE_IDS = { 47 | "ap-northeast-1": "ami-002f0167", 48 | "ap-northeast-2": "ami-590bd937", 49 | "ap-south-1": "ami-77314318", 50 | "ap-southeast-1": "ami-1610a975", 51 | "ap-southeast-2": "ami-9dd4ddfe", 52 | "eu-central-1": "ami-63af720c", 53 | "eu-west-1": "ami-41484f27", 54 | "sa-east-1": "ami-b7234edb", 55 | "us-east-1": "ami-83f26195", 56 | "us-east-2": "ami-66614603", 57 | "us-west-1": "ami-576f4b37", 58 | "us-west-2": "ami-b8b62bd8" 59 | } 60 | 61 | AWS_IMAGE_ID = ALL_REGION_AWS_IMAGE_IDS[AWS_REGION_NAME] 62 | 63 | if USE_GPU: 64 | AWS_INSTANCE_TYPE = "g2.2xlarge" 65 | else: 66 | AWS_INSTANCE_TYPE = "c4.4xlarge" 67 | 68 | ALL_REGION_AWS_KEY_NAMES = $all_region_aws_key_names 69 | 70 | AWS_KEY_NAME = ALL_REGION_AWS_KEY_NAMES[AWS_REGION_NAME] 71 | 72 | AWS_SPOT = True 73 | 74 | AWS_SPOT_PRICE = '0.5' 75 | 76 | AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None) 77 | 78 | AWS_ACCESS_SECRET = os.environ.get("AWS_ACCESS_SECRET", None) 79 | 80 | AWS_IAM_INSTANCE_PROFILE_NAME = "rllab" 81 | 82 | AWS_SECURITY_GROUPS = ["rllab-sg"] 83 | 84 | ALL_REGION_AWS_SECURITY_GROUP_IDS = $all_region_aws_security_group_ids 85 | 86 | AWS_SECURITY_GROUP_IDS = ALL_REGION_AWS_SECURITY_GROUP_IDS[AWS_REGION_NAME] 87 | 88 | FAST_CODE_SYNC_IGNORES = [ 89 | ".git", 90 | "data", 91 | "data/local", 92 | "data/archive", 93 | "data/debug", 94 | "data/s3", 95 | "data/video", 96 | "src", 97 | ".idea", 98 | ".pods", 99 | "tests", 100 | "examples", 101 | "docs", 102 | ".idea", 103 | ".DS_Store", 104 | ".ipynb_checkpoints", 105 | "blackbox", 106 | "blackbox.zip", 107 | "*.pyc", 108 | "*.ipynb", 109 | "scratch-notebooks", 110 
| "conopt_root", 111 | "private/key_pairs", 112 | ] 113 | 114 | FAST_CODE_SYNC = True 115 | 116 | """) 117 | 118 | 119 | def setup_iam(): 120 | iam_client = boto3.client( 121 | "iam", 122 | aws_access_key_id=ACCESS_KEY, 123 | aws_secret_access_key=ACCESS_SECRET, 124 | ) 125 | iam = boto3.resource('iam', aws_access_key_id=ACCESS_KEY, aws_secret_access_key=ACCESS_SECRET) 126 | 127 | # delete existing role if it exists 128 | try: 129 | existing_role = iam.Role('rllab') 130 | existing_role.load() 131 | # if role exists, delete and recreate 132 | if not query_yes_no( 133 | "There is an existing role named rllab. Proceed to delete everything rllab-related and recreate?", 134 | default="no"): 135 | sys.exit() 136 | print("Listing instance profiles...") 137 | inst_profiles = existing_role.instance_profiles.all() 138 | for prof in inst_profiles: 139 | for role in prof.roles: 140 | print("Removing role %s from instance profile %s" % (role.name, prof.name)) 141 | prof.remove_role(RoleName=role.name) 142 | print("Deleting instance profile %s" % prof.name) 143 | prof.delete() 144 | for policy in existing_role.policies.all(): 145 | print("Deleting inline policy %s" % policy.name) 146 | policy.delete() 147 | for policy in existing_role.attached_policies.all(): 148 | print("Detaching policy %s" % policy.arn) 149 | existing_role.detach_policy(PolicyArn=policy.arn) 150 | print("Deleting role") 151 | existing_role.delete() 152 | except botocore.exceptions.ClientError as e: 153 | if e.response['Error']['Code'] == 'NoSuchEntity': 154 | pass 155 | else: 156 | raise e 157 | 158 | print("Creating role rllab") 159 | iam_client.create_role( 160 | Path='/', 161 | RoleName='rllab', 162 | AssumeRolePolicyDocument=json.dumps({'Version': '2012-10-17', 'Statement': [ 163 | {'Action': 'sts:AssumeRole', 'Effect': 'Allow', 'Principal': {'Service': 'ec2.amazonaws.com'}}]}) 164 | ) 165 | 166 | role = iam.Role('rllab') 167 | print("Attaching policies") 168 | role.attach_policy(PolicyArn='arn:aws:iam::aws:policy/AmazonS3FullAccess') 169 | role.attach_policy(PolicyArn='arn:aws:iam::aws:policy/ResourceGroupsandTagEditorFullAccess') 170 | 171 | print("Creating inline policies") 172 | iam_client.put_role_policy( 173 | RoleName=role.name, 174 | PolicyName='CreateTags', 175 | PolicyDocument=json.dumps({ 176 | "Version": "2012-10-17", 177 | "Statement": [ 178 | { 179 | "Effect": "Allow", 180 | "Action": ["ec2:CreateTags"], 181 | "Resource": ["*"] 182 | } 183 | ] 184 | }) 185 | ) 186 | iam_client.put_role_policy( 187 | RoleName=role.name, 188 | PolicyName='TerminateInstances', 189 | PolicyDocument=json.dumps({ 190 | "Version": "2012-10-17", 191 | "Statement": [ 192 | { 193 | "Sid": "Stmt1458019101000", 194 | "Effect": "Allow", 195 | "Action": [ 196 | "ec2:TerminateInstances" 197 | ], 198 | "Resource": [ 199 | "*" 200 | ] 201 | } 202 | ] 203 | }) 204 | ) 205 | 206 | print("Creating instance profile rllab") 207 | iam_client.create_instance_profile( 208 | InstanceProfileName='rllab', 209 | Path='/' 210 | ) 211 | print("Adding role rllab to instance profile rllab") 212 | iam_client.add_role_to_instance_profile( 213 | InstanceProfileName='rllab', 214 | RoleName='rllab' 215 | ) 216 | 217 | 218 | def setup_s3(): 219 | print("Creating S3 bucket at s3://%s" % S3_BUCKET_NAME) 220 | s3_client = boto3.client( 221 | "s3", 222 | aws_access_key_id=ACCESS_KEY, 223 | aws_secret_access_key=ACCESS_SECRET, 224 | ) 225 | try: 226 | s3_client.create_bucket( 227 | ACL='private', 228 | Bucket=S3_BUCKET_NAME, 229 | CreateBucketConfiguration={ 230 | 
'LocationConstraint': 'us-east-2' 231 | } 232 | ) 233 | except botocore.exceptions.ClientError as e: 234 | if e.response['Error']['Code'] == 'BucketAlreadyExists': 235 | raise ValueError("Bucket %s already exists. Please reconfigure S3_BUCKET_NAME" % S3_BUCKET_NAME) from e 236 | elif e.response['Error']['Code'] == 'BucketAlreadyOwnedByYou': 237 | print("Bucket already created by you") 238 | else: 239 | raise e 240 | print("S3 bucket created") 241 | 242 | 243 | def setup_ec2(): 244 | for region in ["us-east-1", "us-east-2", "us-west-1", "us-west-2"]: 245 | print("Setting up region %s" % region) 246 | 247 | ec2 = boto3.resource( 248 | "ec2", 249 | region_name=region, 250 | aws_access_key_id=ACCESS_KEY, 251 | aws_secret_access_key=ACCESS_SECRET, 252 | ) 253 | ec2_client = boto3.client( 254 | "ec2", 255 | region_name=region, 256 | aws_access_key_id=ACCESS_KEY, 257 | aws_secret_access_key=ACCESS_SECRET, 258 | ) 259 | existing_vpcs = list(ec2.vpcs.all()) 260 | assert len(existing_vpcs) >= 1 261 | vpc = existing_vpcs[0] 262 | print("Creating security group in VPC %s" % str(vpc.id)) 263 | try: 264 | security_group = vpc.create_security_group( 265 | GroupName='rllab-sg', Description='Security group for rllab' 266 | ) 267 | except botocore.exceptions.ClientError as e: 268 | if e.response['Error']['Code'] == 'InvalidGroup.Duplicate': 269 | sgs = list(vpc.security_groups.filter(GroupNames=['rllab-sg'])) 270 | security_group = sgs[0] 271 | else: 272 | raise e 273 | 274 | ALL_REGION_AWS_SECURITY_GROUP_IDS[region] = [security_group.id] 275 | 276 | ec2_client.create_tags(Resources=[security_group.id], Tags=[{'Key': 'Name', 'Value': 'rllab-sg'}]) 277 | try: 278 | security_group.authorize_ingress(FromPort=22, ToPort=22, IpProtocol='tcp', CidrIp='0.0.0.0/0') 279 | except botocore.exceptions.ClientError as e: 280 | if e.response['Error']['Code'] == 'InvalidPermission.Duplicate': 281 | pass 282 | else: 283 | raise e 284 | print("Security group created with id %s" % str(security_group.id)) 285 | 286 | key_name = 'rllab-%s' % region 287 | try: 288 | print("Trying to create key pair with name %s" % key_name) 289 | key_pair = ec2_client.create_key_pair(KeyName=key_name) 290 | except botocore.exceptions.ClientError as e: 291 | if e.response['Error']['Code'] == 'InvalidKeyPair.Duplicate': 292 | if not query_yes_no("Key pair with name %s exists. Proceed to delete and recreate?" 
% key_name, "no"): 293 | sys.exit() 294 | print("Deleting existing key pair with name %s" % key_name) 295 | ec2_client.delete_key_pair(KeyName=key_name) 296 | print("Recreating key pair with name %s" % key_name) 297 | key_pair = ec2_client.create_key_pair(KeyName=key_name) 298 | else: 299 | raise e 300 | 301 | key_pair_folder_path = os.path.join(CHESTER_DIR, "private", "key_pairs") 302 | file_name = os.path.join(key_pair_folder_path, "%s.pem" % key_name) 303 | 304 | print("Saving keypair file") 305 | console.mkdir_p(key_pair_folder_path) 306 | with os.fdopen(os.open(file_name, os.O_WRONLY | os.O_CREAT, 0o600), 'w') as handle: 307 | handle.write(key_pair['KeyMaterial'] + '\n') 308 | 309 | # adding pem file to ssh 310 | os.system("ssh-add %s" % file_name) 311 | 312 | ALL_REGION_AWS_KEY_NAMES[region] = key_name 313 | 314 | 315 | def write_config(): 316 | print("Writing config file...") 317 | content = CONFIG_TEMPLATE.substitute( 318 | all_region_aws_key_names=json.dumps(ALL_REGION_AWS_KEY_NAMES, indent=4), 319 | all_region_aws_security_group_ids=json.dumps(ALL_REGION_AWS_SECURITY_GROUP_IDS, indent=4), 320 | s3_bucket_name=S3_BUCKET_NAME, 321 | ) 322 | config_personal_file = os.path.join(CHESTER_DIR, "config_ec2.py") 323 | if os.path.exists(config_personal_file): 324 | if not query_yes_no("config_ec2.py exists. Override?", "no"): 325 | sys.exit() 326 | with open(config_personal_file, "wb") as f: 327 | f.write(content.encode("utf-8")) 328 | 329 | 330 | def setup(): 331 | setup_s3() 332 | setup_iam() 333 | setup_ec2() 334 | write_config() 335 | 336 | 337 | def query_yes_no(question, default="yes"): 338 | """Ask a yes/no question via raw_input() and return their answer. 339 | 340 | "question" is a string that is presented to the user. 341 | "default" is the presumed answer if the user just hits Enter. 342 | It must be "yes" (the default), "no" or None (meaning 343 | an answer is required of the user). 344 | 345 | The "answer" return value is True for "yes" or False for "no".
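For example (illustrative): query_yes_no("Override?", "no") returns False
if the user just presses Enter, since the default answer is used.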
346 | """ 347 | valid = {"yes": True, "y": True, "ye": True, 348 | "no": False, "n": False} 349 | if default is None: 350 | prompt = " [y/n] " 351 | elif default == "yes": 352 | prompt = " [Y/n] " 353 | elif default == "no": 354 | prompt = " [y/N] " 355 | else: 356 | raise ValueError("invalid default answer: '%s'" % default) 357 | 358 | while True: 359 | sys.stdout.write(question + prompt) 360 | choice = input().lower() 361 | if default is not None and choice == '': 362 | return valid[default] 363 | elif choice in valid: 364 | return valid[choice] 365 | else: 366 | sys.stdout.write("Please respond with 'yes' or 'no' " 367 | "(or 'y' or 'n').\n") 368 | 369 | 370 | if __name__ == "__main__": 371 | setup() -------------------------------------------------------------------------------- /bc/rotations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from rpmg import tools 4 | from pyquaternion import Quaternion 5 | from scipy.spatial.transform import Rotation as Rot 6 | np.set_printoptions(precision=8) 7 | #from pytorch3d.transforms import quaternion_to_matrix 8 | 9 | # === ROTATION REPRESENTATION CONVERTERS === 10 | 11 | # class Example: 12 | # def convert_to(self, quaternion, tool_rotation): 13 | # """ 14 | # Parameters: quaternion - a PyQuaternion representing EXTRINSIC rotation 15 | # tool_rotation - a PyQuaternion representing the current tool rotation 16 | 17 | # Returns: rotation in a new format, in shape (K,), where K is the number 18 | # of dimensions of the rotation 19 | # """ 20 | # pass 21 | 22 | # def convert_from(self, example, tool_rotation): 23 | # """ 24 | # Parameters: rotation in the format, in shape (K,) 25 | # tool_rotation - a PyQuaternion representing the current tool rotation 26 | 27 | # Returns: a PyQuaternion representing that rotation 28 | # """ 29 | # pass 30 | 31 | class ExtrinsicToIntrinsic: 32 | def convert_to(self, quaternion, tool_rotation): 33 | return tool_rotation.inverse * quaternion * tool_rotation 34 | 35 | def convert_from(self, rotation, tool_rotation): 36 | return tool_rotation * rotation * tool_rotation.inverse 37 | 38 | class Compose: 39 | def __init__(self, *args): 40 | self.ops = args 41 | 42 | def convert_to(self, quaternion, tool_rotation): 43 | for op in self.ops: 44 | quaternion = op.convert_to(quaternion, tool_rotation) 45 | 46 | return quaternion 47 | 48 | def convert_from(self, rotation, tool_rotation): 49 | for op in self.ops[::-1]: 50 | rotation = op.convert_from(rotation, tool_rotation) 51 | 52 | return rotation 53 | 54 | class RotationMatrix: 55 | # NOTE(daniel): this does not depend on the rotation_dim, except for how 56 | # I need to make rtol and atol different for 9D rotations (why? not sure). 57 | 58 | def __init__(self, rot_dim): 59 | self.rot_dim = rot_dim 60 | if self.rot_dim == 9: 61 | self.rtol = 1e-2 62 | self.atol = 1e-2 63 | else: 64 | self.rtol = 1e-5 65 | self.atol = 1e-5 66 | 67 | def convert_to(self, quaternion, tool_rotation): 68 | matrix_3x3 = quaternion.rotation_matrix 69 | return matrix_3x3.flatten() 70 | 71 | def convert_from(self, rotation, tool_rotation): 72 | # Use pyquaternion 0.9.9 to get rtol and atol args working. 
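# Why this matters (our understanding): a learned 9D output reshaped to 3x3
# is generally only approximately orthogonal, and pyquaternion validates the
# matrix against rtol/atol when constructing Quaternion(matrix=...); the
# looser tolerances set in __init__ for rot_dim == 9 let it accept these
# near-rotation matrices instead of raising.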
73 | matrix_3x3 = rotation.reshape((3,3)) 74 | return Quaternion(matrix=matrix_3x3, rtol=self.rtol, atol=self.atol) 75 | 76 | class AxisAngle: 77 | def convert_to(self, quaternion, tool_rotation): 78 | axis = quaternion.get_axis(undefined=np.array([0., 0., 1.])) 79 | dtheta = quaternion.radians 80 | 81 | return axis * dtheta 82 | 83 | def convert_from(self, axis_angle, tool_rotation): 84 | dtheta = np.linalg.norm(axis_angle) 85 | if dtheta != 0: 86 | axis = axis_angle / dtheta 87 | return Quaternion(axis=axis, angle=dtheta) 88 | else: 89 | # identity rotation 90 | return Quaternion() 91 | 92 | class FlowConverter(AxisAngle): 93 | # Flow needs to process quaternion further in the replay buffer 94 | def convert_to(self, quaternion, tool_rotation): 95 | return quaternion 96 | 97 | class RPMGFlowConverter(RotationMatrix): 98 | # RPMG outputs a rotation matrix, not axis angle 99 | def convert_to(self, quaternion, tool_rotation): 100 | return quaternion 101 | 102 | class NoRPMG: 103 | def __init__(self, rot_dim): 104 | self.rot_dim = rot_dim 105 | 106 | def convert_to(self, rot_matrix, tool_rotation): 107 | rot_matrix = rot_matrix.reshape((3, 3)) 108 | R = torch.from_numpy(rot_matrix).unsqueeze(0) 109 | 110 | if self.rot_dim == 6: 111 | x = torch.cat([R[:, :, 0], R[:, :, 1]], dim=1) 112 | elif self.rot_dim == 9: 113 | x = R.reshape(-1, 9) 114 | elif self.rot_dim == 4: 115 | x = tools.compute_quaternions_from_rotation_matrices(R) 116 | elif self.rot_dim == 10: 117 | q = tools.compute_quaternions_from_rotation_matrices(R) 118 | reg_A = torch.eye(4, device=q.device)[None].repeat(q.shape[0], 1, 1) \ 119 | - torch.bmm(q.unsqueeze(-1), q.unsqueeze(-2)) 120 | x = tools.convert_A_to_Avec(reg_A) 121 | else: 122 | raise NotImplementedError 123 | 124 | return x.flatten().numpy() 125 | 126 | def convert_from(self, rotation, tool_rotation): 127 | rotation = torch.from_numpy(rotation).unsqueeze(0) 128 | if self.rot_dim == 6: 129 | r0 = tools.compute_rotation_matrix_from_ortho6d(rotation) 130 | elif self.rot_dim == 9: 131 | r0 = tools.symmetric_orthogonalization(rotation) 132 | elif self.rot_dim == 4: 133 | r0 = tools.compute_rotation_matrix_from_quaternion(rotation) 134 | elif self.rot_dim == 10: 135 | r0 = tools.compute_rotation_matrix_from_10d(rotation) 136 | else: 137 | raise NotImplementedError 138 | return r0.flatten().numpy() 139 | 140 | # === CONVERTER REGISTRATION === 141 | 142 | CONVERTERS = { 143 | "flow": FlowConverter(), 144 | "axis_angle": AxisAngle(), 145 | "intrinsic_axis_angle": Compose( 146 | ExtrinsicToIntrinsic(), 147 | AxisAngle(), 148 | ), 149 | "rotation_4D": RotationMatrix(rot_dim=4), 150 | "rotation_6D": RotationMatrix(rot_dim=6), 151 | "rotation_9D": RotationMatrix(rot_dim=9), 152 | "rotation_10D": RotationMatrix(rot_dim=10), 153 | "intrinsic_rotation_6D": Compose( 154 | ExtrinsicToIntrinsic(), 155 | RotationMatrix(rot_dim=6), 156 | ), 157 | "rpmg_flow_6D": RPMGFlowConverter(rot_dim=6), 158 | "no_rpmg_6D": Compose( 159 | RotationMatrix(rot_dim=6), 160 | NoRPMG(rot_dim=6), 161 | ), 162 | "no_rpmg_9D": Compose( 163 | RotationMatrix(rot_dim=9), 164 | NoRPMG(rot_dim=9), 165 | ), 166 | "no_rpmg_10D": Compose( 167 | RotationMatrix(rot_dim=10), 168 | NoRPMG(rot_dim=10), 169 | ), 170 | } 171 | 172 | # === ENV SPECIFIC CANONICALIZATION === 173 | 174 | def _canonicalize_action(env_name, obs_tuple, act_raw, qt_current): 175 | """ 176 | Gets global delta action (in PyQuaternion form) from raw action and 177 | observation, depending on environment 178 | This is especially important as PourWater6D 
has act_raw as delta euler 179 | angle, while MixedMedia uses *local* axis-angle. Both need to be 180 | converted to global axis-angle. 181 | Note [qt_current] is only used for MixedMedia 182 | 183 | Returns [tool_rotation] the extrinsic rotation of the tool currently 184 | """ 185 | act_tran = act_raw[:3] 186 | axis = act_raw[3:] 187 | dtheta = np.linalg.norm(axis) 188 | 189 | # Unfortunately there are some 'hacks' here. If PourWater, we negate 190 | # the 3rd value because the axis in SoftGym is actually (0,0,-1) when 191 | # handling rotations. A positive value means dropping water, and that 192 | # is clockwise w.r.t. +z, but if we keep it here that would negate it 193 | # for PyQuaternion. OK to do here as the purpose is to get accurate 194 | # flow estimates, but check if we can simplify the env as well. 195 | if env_name == 'PourWater': 196 | tool_origin = obs_tuple[0][0,:3] # shape (10,14) 197 | if dtheta == 0: 198 | axis = np.array([0., 0., -1.]) 199 | else: 200 | # FYI, I tried visualizing with and without this, and we do 201 | # need to negate to see proper flow vectors. 202 | axis[2] = -axis[2] 203 | 204 | axis = axis / np.linalg.norm(axis) 205 | delta_quat = Quaternion(axis=axis, angle=dtheta) 206 | tool_raw = obs_tuple[0][0, 6:10] 207 | tool_rotation = Quaternion(w=tool_raw[3], x=tool_raw[0], y=tool_raw[1], z=tool_raw[2]) 208 | elif env_name == 'PourWater6D': 209 | tool_origin = obs_tuple[0][0, :3] # shape (10, 14) 210 | 211 | # Get current global rotation in intrinsic ZYX euler angles due to the way 212 | # we handle rotation 213 | curr_rot = Rot.from_quat(obs_tuple[0][0, 6:10]) 214 | curr_euler = curr_rot.as_euler('zyx') 215 | curr_z, curr_y, curr_x = -curr_euler[0], curr_euler[1], curr_euler[2] 216 | 217 | # Compute new tool orientation 218 | theta_x = curr_x + act_raw[3] 219 | theta_y = curr_y + act_raw[4] 220 | theta_z = curr_z + act_raw[5] 221 | axis_ang_z = np.array([0., 0., -1.]) 222 | axis_ang_y = np.array([0., 1., 0.]) 223 | axis_ang_x = np.array([1, 0., 0.]) 224 | axis_angle_z = axis_ang_z * theta_z 225 | axis_angle_y = axis_ang_y * theta_y 226 | axis_angle_x = axis_ang_x * theta_x 227 | final_rot = Rot.from_rotvec(axis_angle_x) * Rot.from_rotvec(axis_angle_y) * Rot.from_rotvec(axis_angle_z) 228 | 229 | # Compute rotation difference 230 | delta_rot = final_rot * curr_rot.inv() 231 | delta_raw = delta_rot.as_quat() 232 | delta_quat = Quaternion(w=delta_raw[3], x=delta_raw[0], y=delta_raw[1], z=delta_raw[2]) 233 | 234 | # Compute current tool rotation for later if we want to convert to intrinsic rotations 235 | tool_raw = obs_tuple[0][0, 6:10] 236 | tool_rotation = Quaternion(w=tool_raw[3], x=tool_raw[0], y=tool_raw[1], z=tool_raw[2]) 237 | elif env_name in ['MMOneSphere', 'MMMultiSphere']: 238 | tool_origin = obs_tuple[0][:3] # tool tip position 239 | if dtheta == 0: 240 | axis = np.array([0., -1., 0.]) 241 | 242 | axis = axis / np.linalg.norm(axis) 243 | 244 | # Change axis from local frame to world to work with 6DoF 245 | axis_world = qt_current.rotate(axis) 246 | delta_quat = Quaternion(axis=axis_world, angle=dtheta) 247 | 248 | tool_rotation = qt_current 249 | 250 | return act_tran, tool_origin, delta_quat, tool_rotation 251 | 252 | def _decanonicalize_action(env_name, delta_quat, env): 253 | dtheta = delta_quat.radians 254 | 255 | if (env_name in ['MMOneSphere', 'MMMultiSphere']): 256 | axis = delta_quat.get_axis(undefined=np.array([0., 0., 1.])) 257 | if dtheta != 0: 258 | # Get global rotation axis from env 259 | curr_q = env.tool_state[0, 6:10] 260 | 
qt_current = Quaternion(w=curr_q[3], x=curr_q[0], y=curr_q[1], z=curr_q[2]) 261 | inv_qt_current = qt_current.inverse 262 | local_axis = inv_qt_current.rotate(axis) * dtheta 263 | return local_axis 264 | else: 265 | return np.zeros(3) 266 | elif (env_name == 'PourWater'): 267 | # Now positive rotation is w.r.t., the negative z axis, not the 268 | # positive z axis. :( Also only do this if we used pytorch geometric. 269 | # I think this condition should be sufficient but DOUBLE CHECK. 270 | # UPDATE(06/02/2022), ah also needs to be done with pointwise losses! 271 | # UPDATE(06/03/2022), if SVD but with MSE after it, don't do this. 272 | axis = delta_quat.get_axis(undefined=np.array([0., 0., 1.])) 273 | action = axis * dtheta 274 | action[2] = -action[2] 275 | return action 276 | elif (env_name in ['PourWater6D']): 277 | # Sadly, we need to convert from delta axis angle to delta euler angle. oof 278 | if dtheta != 0: 279 | # Convert from delta axis angle -> delta scipy rotation 280 | delta_quat._normalise() # knock on wood 281 | delta_items = delta_quat.elements 282 | delta_raw = np.array([delta_items[1], delta_items[2], delta_items[3], delta_items[0]]) 283 | delta_rot = Rot.from_quat(delta_raw) 284 | 285 | # Get current env scipy rotation 286 | curr_rot = Rot.from_quat(env.glass_states[0, 6:10]) 287 | 288 | # Perform rotation 289 | final_rot = delta_rot * curr_rot 290 | 291 | # Convert from final scipy rotation -> final euler angles 292 | final_euler = final_rot.as_euler('zyx') 293 | # This flipping of [-final_euler[0]] is where we handle the 294 | # flipped z-axis in PourWater 295 | final_z, final_y, final_x = -final_euler[0], final_euler[1], final_euler[2] 296 | 297 | # Convert from final -> delta euler angles 298 | dtheta_x = final_x - env.glass_rotation_x 299 | dtheta_y = final_y - env.glass_rotation_y 300 | dtheta_z = final_z - env.glass_rotation_z 301 | 302 | return np.array([dtheta_x, dtheta_y, dtheta_z]) 303 | else: 304 | return np.zeros(3) 305 | 306 | def _env_to_tool_rotation(env_name, env): 307 | if (env_name in ['MMOneSphere', 'MMMultiSphere']): 308 | curr_q = env.tool_state[0, 6:10] 309 | return Quaternion(w=curr_q[3], x=curr_q[0], y=curr_q[1], z=curr_q[2]) 310 | elif env_name in ['PourWater', 'PourWater6D']: 311 | curr_q = env.glass_states[0, 6:10] 312 | return Quaternion(w=curr_q[3], x=curr_q[0], y=curr_q[1], z=curr_q[2]) 313 | 314 | # === DRIVER CODE === 315 | 316 | def convert_action(rotation_representation, env_name, obs_tuple, act_raw, qt_current): 317 | assert rotation_representation in CONVERTERS, f"Invalid rotation representation {rotation_representation}" 318 | act_tran, tool_origin, delta_quat, tool_rotation = _canonicalize_action(env_name, obs_tuple, act_raw, qt_current) 319 | converted_rotation = CONVERTERS[rotation_representation].convert_to(delta_quat, tool_rotation) 320 | return act_tran, tool_origin, converted_rotation 321 | 322 | def deconvert_action(rotation_representation, env_name, action, env): 323 | assert rotation_representation in CONVERTERS, f"Invalid rotation representation {rotation_representation}" 324 | tool_rotation = _env_to_tool_rotation(env_name, env) 325 | delta_quat = CONVERTERS[rotation_representation].convert_from(action[3:], tool_rotation) 326 | env_specific_rotation = _decanonicalize_action(env_name, delta_quat, env) 327 | return np.concatenate((action[:3], env_specific_rotation)) 328 | -------------------------------------------------------------------------------- /bc/se3.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pickle 4 | from pytorch3d.transforms import ( 5 | Transform3d, Rotate, rotation_6d_to_matrix, axis_angle_to_matrix, 6 | matrix_to_quaternion, matrix_to_euler_angles, quaternion_to_axis_angle 7 | ) 8 | import torch 9 | import numpy as np 10 | DEG_TO_RAD = np.pi / 180. 11 | RAD_TO_DEG = 180 / np.pi 12 | 13 | 14 | def to_transform3d(x, rot_function = rotation_6d_to_matrix): 15 | trans = x[:,:3] 16 | rot = x[:,3:] 17 | return Transform3d(device=x.device).compose(Rotate(rot_function(rot), 18 | device=rot.device)).translate(trans) 19 | 20 | 21 | def transform_points(points, x, rot_function = rotation_6d_to_matrix): 22 | t = x[:,:3] 23 | R = rot_function(x[:,3:]) 24 | return (torch.bmm(R, points.transpose(-2,-1)) + t.unsqueeze(-1)).transpose(-2,-1) 25 | 26 | 27 | def transform3d_to(T, device): 28 | T = T.to(device) 29 | # `.to` may not move the composed transforms, so move them manually. 30 | T._transforms = [ 31 | t.to(device) for t in T._transforms 32 | ] 33 | return T 34 | 35 | 36 | def random_se3(N, rot_var = np.pi/180 * 5, trans_var = 0.1, device = None): 37 | R = axis_angle_to_matrix(torch.randn(N,3, device=device)*rot_var) 38 | t = torch.randn(N,3, device=device)*trans_var 39 | return Rotate(R, device=device).translate(t) 40 | 41 | 42 | def symmetric_orthogonalization(M): 43 | """Maps arbitrary input matrices onto SO(3) via symmetric orthogonalization. 44 | (modified from https://github.com/amakadia/svd_for_pose) 45 | 46 | M: should have size [batch_size, 3, 3] 47 | 48 | Output has size [batch_size, 3, 3], where each inner 3x3 matrix is in SO(3). 49 | """ 50 | U, _, Vh = torch.linalg.svd(M) 51 | det = torch.det(torch.bmm(U, Vh)).view(-1, 1, 1) 52 | Vh = torch.cat((Vh[:, :2, :], Vh[:, -1:, :] * det), 1) 53 | R = U @ Vh 54 | return R 55 | 56 | 57 | def flow2pose(xyz, flow, weights=None, return_transform3d=False, 58 | return_quaternions=False, world_frameify=True): 59 | """Flow2Pose via SVD. 60 | 61 | Operates on minibatches of `B` point clouds, each with `N` points. Assumes 62 | all point clouds have `N` points, but in practice we only call this with 63 | minibatch size 1 and we get rid of non-tool points before calling this. 64 | 65 | Outputs a rotation with the origin at the tool centroid. This is not the 66 | origin of the frame where actions are expressed, which is at the tool tip. 67 | 68 | Parameters 69 | ---------- 70 | xyz: point clouds of shape (B,N,3). This gets zero-centered so it's OK if it 71 | is not already centered. 72 | flow: corresponding flow of shape (B,N,3). As with xyz, it gets zero-centered. 73 | weights: weights for the N points, set to None for uniform weighting, for now 74 | I don't think we want to weigh certain points more than others, and it 75 | could be tricky when points can technically be any order in a PCL. 76 | return_transform3d: Used if we want to return a transform, which we can apply 77 | to a set of point clouds. This is what Brian/Chuer use to compute losses, 78 | by applying this on original points and comparing point-wise MSEs. 79 | return_quaternions: Use if we want to convert rotation matrices to quaternions. 80 | Uses the (w,x,y,z) format, so the identity quaternion is (1,0,0,0). 81 | world_frameify: Use if we want to correct the translation vector so that the 82 | transformation is expressed w.r.t. the world frame.
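In equations (a sketch of the weighted Procrustes step implemented below,
using row-vector points): with weighted means p_bar of xyz and f_bar of flow,
let P = xyz - p_bar and Q = P + flow - f_bar. Then X = P^T (ww * Q),
R = symmetric_orthogonalization(X), and the world-frame translation is
t = f_bar + p_bar - p_bar R.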
83 | """ 84 | if weights is None: 85 | weights = torch.ones(xyz.shape[:-1], device=xyz.device) 86 | ww = (weights / weights.sum(dim=-1, keepdims=True)).unsqueeze(-1) 87 | 88 | # xyz_mean shape: ((B,N,1), (B,N,3)) mult -> (B,N,3) -> sum -> (B,1,3) 89 | xyz_mean = (ww * xyz).sum(dim=1, keepdims=True) 90 | xyz_demean = xyz - xyz_mean # broadcast `xyz_mean`, still shape (B,N,3) 91 | 92 | # As with xyz positions, find (weighted) mean of flow, shape (B,1,3). 93 | flow_mean = (ww * flow).sum(dim=1, keepdims=True) 94 | 95 | # Zero-mean positions plus zero-mean flow to find new points. 96 | xyz_trans = xyz_demean + flow - flow_mean # (B,N,3) 97 | 98 | # Batch matrix-multiply, get cross-covariance X: (B,3,3) (not itself in SO(3)). 99 | X = torch.bmm(xyz_demean.transpose(-2,-1), # (B,3,N) 100 | ww * xyz_trans) # (B,N,3) 101 | 102 | # Rotation matrix in SO(3) for each mb item, (B,3,3). 103 | R = symmetric_orthogonalization(X) 104 | 105 | # 3D translation vector for each mb item, (B,3) due to squeezing. 106 | if world_frameify: 107 | t = (flow_mean + xyz_mean - torch.bmm(xyz_mean, R)).squeeze(1) 108 | else: 109 | t = flow_mean.squeeze(1) 110 | 111 | if return_transform3d: 112 | return Rotate(R).translate(t) 113 | if return_quaternions: 114 | quats = matrix_to_quaternion(matrix=R) 115 | return quats, t 116 | return R, t 117 | 118 | eps = 1e-9 119 | 120 | 121 | def dualflow2pose(xyz, flow, polarity, weights = None, return_transform3d = False): 122 | if(weights is None): 123 | weights = torch.ones(xyz.shape[:-1], device=xyz.device) 124 | w = (weights / weights.sum(dim=-1, keepdims=True)).unsqueeze(-1) 125 | w_p = (polarity * weights).unsqueeze(-1) 126 | w_p_sum = w_p.sum(dim=1, keepdims=True) 127 | w_p = w_p / w_p_sum.clamp(min=eps) 128 | w_n = ((1-polarity) * weights).unsqueeze(-1) 129 | w_n_sum = w_n.sum(dim=1, keepdims=True) 130 | w_n = w_n / w_n_sum.clamp(min=eps) 131 | 132 | 133 | xyz_mean_p = (w_p * xyz).sum(dim=1, keepdims=True) 134 | xyz_demean_p = xyz - xyz_mean_p 135 | xyz_mean_n = (w_n * xyz).sum(dim=1, keepdims=True) 136 | xyz_demean_n = xyz - xyz_mean_n 137 | 138 | flow_mean_p = (w_p * flow).sum(dim=1, keepdims=True) 139 | flow_demean_p = flow - flow_mean_p 140 | flow_mean_n = (w_n * flow).sum(dim=1, keepdims=True) 141 | flow_demean_n = flow - flow_mean_n 142 | 143 | mask = (polarity.unsqueeze(-1).expand(-1,-1,3)==1) 144 | xyz_1 = torch.where(mask, 145 | xyz_demean_p, xyz_demean_n + flow_demean_n) 146 | xyz_2 = torch.where(mask, 147 | xyz_demean_p + flow_demean_p, xyz_demean_n) 148 | 149 | X = torch.bmm(xyz_1.transpose(-2,-1), w*xyz_2) 150 | 151 | R = symmetric_orthogonalization(X) 152 | t_p = (flow_mean_p + xyz_mean_p - torch.bmm(xyz_mean_p, R)) 153 | t_n = (xyz_mean_n - torch.bmm(flow_mean_n + xyz_mean_n, R)) 154 | 155 | t = ((w_p_sum * t_p + w_n_sum * t_n)/(w_p_sum + w_n_sum)).squeeze(1) 156 | 157 | if(return_transform3d): 158 | return Rotate(R).translate(t) 159 | return R, t 160 | 161 | 162 | def points2pose(xyz1, xyz2, weights = None, return_transform3d = False): 163 | if(weights is None): 164 | weights = torch.ones(xyz1.shape[:-1], device=xyz1.device) 165 | w = (weights / weights.sum(dim=-1, keepdims=True)).unsqueeze(-1) 166 | 167 | xyz1_mean = (w * xyz1).sum(dim=1, keepdims=True) 168 | xyz1_demean = xyz1 - xyz1_mean 169 | 170 | xyz2_mean = (w * xyz2).sum(dim=1, keepdims=True) 171 | xyz2_demean = xyz2 - xyz2_mean 172 | 173 | X = torch.bmm(xyz1_demean.transpose(-2,-1), 174 | w*xyz2_demean) 175 | 176 | R = symmetric_orthogonalization(X) 177 | t = (xyz2_mean - torch.bmm(xyz1_mean, R)).squeeze(1)
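# Sanity note: with the row-vector convention xyz2 ~= xyz1 R + t used in this
# file, matching the weighted centroids gives t = mean2 - mean1 R, which is
# what the line above computes.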
178 | 179 | if(return_transform3d): 180 | return Rotate(R).translate(t) 181 | return R, t 182 | 183 | 184 | def _debug_bc_data(device): 185 | """Debug our Behavioral Cloning data. 186 | 187 | Particularly, the data with the extremely simple 1DoF rotation about the 188 | y axis. Given the ground-truth flow, we should get the desired rotation. 189 | If this doesn't work, there's a problem with SVD. If it does work, then 190 | we at least know this step is OK? 191 | 192 | This could also be useful for debugging the translation corrections we 193 | use, since the center of rotation is not the centroid of the tool, but 194 | at the tip of the tool's stick. 195 | """ 196 | def get_obs_tool_flow(pcl, tool_flow): 197 | # Copied from bc utils 198 | pcl_tool = pcl[:,3] == 1 199 | tf_pts = tool_flow['points'] 200 | tf_flow = tool_flow['flow'] 201 | n_tool_pts_obs = np.sum(pcl_tool) 202 | n_tool_pts_flow = tf_pts.shape[0] 203 | # First shapes only equal if: (a) fewer than max pts or (b) no item/distr. 204 | assert tf_pts.shape[0] <= pcl.shape[0], f'{tf_pts.shape}, {pcl.shape}' 205 | assert tf_pts.shape == tf_flow.shape, f'{tf_pts.shape}, {tf_flow.shape}' 206 | assert n_tool_pts_obs == n_tool_pts_flow, f'{n_tool_pts_obs}, {n_tool_pts_flow}' 207 | assert np.array_equal(pcl[:n_tool_pts_obs,:3], tf_pts) # yay :) 208 | a = np.zeros((pcl.shape[0], 3)) # all non-tool point rows get 0s 209 | a[:n_tool_pts_obs] = tf_flow # actually encode flow for BC purposes 210 | return (pcl, a) 211 | 212 | # This is the v06 with the 1DoF rotation about y axis at the start. 213 | PATH = os.path.join( 214 | '/data/seita/softgym_mm/data_demo/', 215 | 'MMOneSphere_v01_BClone_filtered_ladle_algorithmic_v06_nVars_2000_obs_combo_act_translation_axis_angle', 216 | 'BC_0000_600.pkl' 217 | ) 218 | with open(PATH, 'rb') as fh: 219 | data = pickle.load(fh) 220 | 221 | # obses: list of tuples (different obs types), acts: list of axis-angles. 222 | obses = data['obs'] 223 | acts = data['act_raw'] 224 | assert len(obses) == len(acts) 225 | 226 | # Remember that to get flow we need the _next_ observation. 227 | t = 0 228 | obs_t = obses[t][3] # PCL at idx=3 229 | act_t = acts[t] # the axis-angle formulation 230 | info_tp1 = obses[t+1][4] # Flow at idx=4 for _next_ time step. 231 | obs_t, act_flow_t = get_obs_tool_flow(obs_t, info_tp1) 232 | print(f'Testing PCLs at t={t} shaped {obs_t.shape} with action: {act_t}') 233 | print(f' act_flow_t: {act_flow_t.shape}') 234 | 235 | # Let's visualize this, but make flow longer for clarity. EDIT: for some 236 | # reason I can't create it here, so let's put it in my MWE. 237 | 238 | # Get the pose from the flow. 239 | xyz = torch.as_tensor(np.array([obs_t[:,:3]])).float() 240 | flow = torch.as_tensor(np.array([act_flow_t])).float() 241 | print(f'Calling flow2pose, xyz: {xyz.shape}, flow: {flow.shape}') 242 | pose = flow2pose( 243 | xyz=xyz.to(device), 244 | flow=flow.to(device), 245 | weights=None, 246 | return_transform3d=False, 247 | return_quaternions=False, 248 | ) 249 | 250 | # Debugging, could be useful to show a visualization? 251 | R, t = pose 252 | print(f'\nFinished SVD! Shape of R, t: {R.shape}, {t.shape}\n') 253 | R = R[0] 254 | t = t[0] 255 | print('\nThe rotation and translation:') 256 | print(R) 257 | print(t) 258 | print('\nIs this a rotation matrix? 
This should be the identity.') 259 | RT_times_R = torch.matmul(torch.transpose(R, 0, 1), R) 260 | print(RT_times_R) 261 | print('\nEuler angles:') 262 | print(matrix_to_euler_angles(R, convention='XYZ')) 263 | print('\nQuaternion:') 264 | quat = matrix_to_quaternion(R) 265 | print(quat) 266 | print('\nAxis-angle:') 267 | # WAIT! This is producing the OPPOSITE rotation! 268 | print(quaternion_to_axis_angle(quat)) 269 | sys.exit() 270 | 271 | 272 | if __name__ == "__main__": 273 | # Try debugging / testing the flow / SVD methods. 274 | np.set_printoptions(suppress=True, precision=6, linewidth=150, edgeitems=20) 275 | torch.set_printoptions(sci_mode=False, precision=6, linewidth=150, edgeitems=20) 276 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 277 | #_debug_bc_data(device) 278 | 279 | # Minibatch of B point clouds, 10,000 points each, with 3D position or flow. 280 | B = 4 281 | xyz = torch.rand(B,10000,3) 282 | flow = torch.rand(B,10000,3) * 0.01 283 | ret_trans = False 284 | ret_quat = False 285 | world_frameify = True 286 | 287 | ## For debugging, let's actually do translation only for one item. 288 | ## If we do this, we get a tensor of (1,0,0,0) as output quaternion (good). 289 | #flow[0] = torch.ones((1,3)) * 0.5 290 | #flow[0,0,:] = torch.ones((1,3)) 291 | 292 | # More debugging: what if we just scale the first item (pos and flow)? 293 | # If world_frameify=False, translations are the same subject to positioning 294 | # of decimal points (good). 295 | # Rotations are _almost_ the same but up to 3 decimal points: 296 | # axis-angle: tensor([-0.091616, -0.033090, 0.173777]), magn: 0.1502 297 | # axis-angle: tensor([-0.091667, -0.033330, 0.174000]), magn: 0.1503 298 | # axis-angle: tensor([-0.091644, -0.033659, 0.173811]), magn: 0.1506 299 | # axis-angle: tensor([-0.091447, -0.033111, 0.173603]), magn: 0.1504 300 | # It does seem to be close enough and might not matter too much in practice. 301 | # Magnitudes seem to be close enough that the policy will do similar things. 302 | # 303 | # What if world_frameify=True? That would only affect the quality of the 304 | # resulting returned translation. It still scales well though not as perfectly 305 | # as just averaging the flow vectors (makes sense). 306 | xyz[1] = xyz[0] * 10. 307 | flow[1] = flow[0] * 10. 308 | xyz[2] = xyz[0] * 100. 309 | flow[2] = flow[0] * 100. 310 | xyz[3] = xyz[0] * 1000. 311 | flow[3] = flow[0] * 1000. 312 | 313 | # Get the pose from the flow. 314 | pose = flow2pose( 315 | xyz=xyz.to(device), 316 | flow=flow.to(device), 317 | weights=None, 318 | return_transform3d=ret_trans, # Brian/Chuer use True 319 | return_quaternions=ret_quat, # something new 320 | world_frameify=world_frameify, # something new 321 | ) 322 | if ret_trans: 323 | sys.exit() 324 | 325 | # Debugging, could be useful to show a visualization?
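# Why the scaled copies should give (nearly) the same rotation: scaling xyz
# and flow by s scales the cross-covariance X by s^2, which leaves its
# singular vectors, and hence R, unchanged up to floating-point error; only
# the translation should grow linearly with s.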
326 | R, t = pose 327 | print(f'Shape of R, t: {R.shape}, {t.shape}\n') 328 | for b in range(B): 329 | print(f'\n---------- On minibatch item {b} ----------') 330 | print(R[b]) 331 | quat_b = matrix_to_quaternion(R[b]) 332 | aang_b = quaternion_to_axis_angle(quat_b) 333 | magn_rad = torch.norm(aang_b) 334 | magn_deg = magn_rad * RAD_TO_DEG 335 | print(f'quaternion: {quat_b}') 336 | print(f'axis-angle: {aang_b}, magnitude: {magn_rad:0.3f} rad, {magn_deg:0.3f} deg') 337 | print(f'translation: {t[b]}') 338 | if not ret_quat: 339 | print(f'To confirm rotation matrix (also check it is on cuda):') 340 | RTR = torch.matmul(torch.transpose(R[b], 0, 1), R[b]) 341 | print(RTR) 342 | -------------------------------------------------------------------------------- /tests/tool_reduce_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 87, 6 | "id": "45533134-06e9-4f26-8b42-e7c2b7868227", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pickle\n", 11 | "from softgym.softgym.utils import visualization\n", 12 | "from torch_geometric.nn import fps\n", 13 | "import os\n", 14 | "import torch\n", 15 | "import numpy as np\n", 16 | "from pyquaternion import Quaternion" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 4, 22 | "id": "be2bd76c-4089-4fc0-86dc-a1729f99ba28", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "DATA_DIR = \"/data/seita/softgym_mm/data_demo/MMOneSphere_v01_BClone_filtered_ladle_algorithmic_v04_nVars_2000_obs_combo_act_translation_axis_angle\"" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 8, 32 | "id": "9a3db0d9-424d-44e0-ad01-e58a82d6f5bc", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "with open(os.path.join(DATA_DIR, \"BC_0000_100.pkl\"), 'rb') as f:\n", 37 | " data = pickle.load(f)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 21, 43 | "id": "f171c4f8-889f-4f0e-8530-ac6fac787466", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "obs_p = [data['obs'][0][3]]" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 22, 53 | "id": "e71f063b-f126-496c-8f63-ddbac7a9fd33", 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "MoviePy - Building file ./point_cloud_segm.gif with imageio.\n" 61 | ] 62 | }, 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | " \r" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "visualization.save_pointclouds(obs_p, savedir='.')" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 45, 78 | "id": "68b29fa7-c735-4846-819a-80817509c4b4", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "pcl = obs_p[0]\n", 83 | "\n", 84 | "i_tool = np.where(pcl[:,3] > 0)[0]\n", 85 | "tool_pts = pcl[i_tool]" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 46, 91 | "id": "48151410-912d-4104-9565-10235b8c5f89", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "tool_pts = torch.from_numpy(tool_pts)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 59, 101 | "id": "446e7d3b-1d4d-416b-bcd9-6a4d854d7ff4", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "sampled_tool_idxs = fps(tool_pts[:, :3], ratio=0.05, random_start=False)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 60, 111 | "id": 
"64f411eb-2fdf-44b0-b45a-4fbbf467996b", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "sampled_tool_pts = tool_pts[sampled_tool_idxs]" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 61, 121 | "id": "fb50a4b7-ff9d-4c7e-aed4-c8219b71fe8e", 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "MoviePy - Building file ./point_cloud_segm.gif with imageio.\n" 129 | ] 130 | }, 131 | { 132 | "name": "stderr", 133 | "output_type": "stream", 134 | "text": [ 135 | " \r" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "visualization.save_pointclouds([sampled_tool_pts], savedir='.')" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 66, 146 | "id": "f88ba691-8b9b-4cee-9474-ffac6b4f376c", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "tool_tip_pos = data['obs'][0][0][:3]" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 73, 156 | "id": "cd293673-4a7c-44b5-9a6b-0c4f7b4ba7d9", 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "data": { 161 | "text/plain": [ 162 | "array([-0.00986772, 0.54506862, -0.07499563])" 163 | ] 164 | }, 165 | "execution_count": 73, 166 | "metadata": {}, 167 | "output_type": "execute_result" 168 | } 169 | ], 170 | "source": [ 171 | "tool_tip_pos" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": 70, 177 | "id": "1e0d9a18-8411-4358-abe8-44e7b0b36666", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "gt_tool_pts = sampled_tool_pts[:, :3] - tool_tip_pos" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 71, 187 | "id": "48270c87-7b5d-4f0b-bf75-76f661fdae45", 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "data": { 192 | "text/plain": [ 193 | "tensor([-0.0045, -0.0117, -0.0070], dtype=torch.float64)" 194 | ] 195 | }, 196 | "execution_count": 71, 197 | "metadata": {}, 198 | "output_type": "execute_result" 199 | } 200 | ], 201 | "source": [ 202 | "gt_tool_pts[0]" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 74, 208 | "id": "b4f93391-559d-441e-8c8b-cf263871f903", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "with open('100_tool_pts.pkl', 'wb') as f:\n", 213 | " pickle.dump(gt_tool_pts, f)" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 75, 219 | "id": "deb4f8ad-b89a-4d80-8583-37ca6b83da5a", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "all_gt_tool_pts = sampled_tool_pts\n", 224 | "all_gt_tool_pts[:, :3] -= tool_tip_pos" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 76, 230 | "id": "795096bb-95b4-4afc-b609-255ba841ae7e", 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "name": "stdout", 235 | "output_type": "stream", 236 | "text": [ 237 | "MoviePy - Building file ./point_cloud_segm.gif with imageio.\n" 238 | ] 239 | }, 240 | { 241 | "name": "stderr", 242 | "output_type": "stream", 243 | "text": [ 244 | " \r" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "visualization.save_pointclouds([all_gt_tool_pts], savedir='.')" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 82, 255 | "id": "deb1bc32-fa0a-4e8c-a9b5-e06fea6c086b", 256 | "metadata": {}, 257 | "outputs": [ 258 | { 259 | "data": { 260 | "text/plain": [ 261 | "tensor(0.1027, dtype=torch.float64)" 262 | ] 263 | }, 264 | "execution_count": 82, 265 | "metadata": {}, 266 | "output_type": 
"execute_result" 267 | } 268 | ], 269 | "source": [ 270 | "max(gt_tool_pts[1])" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 92, 276 | "id": "6a3997c5-f96f-42d2-b650-bff12117f8fb", 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "class MixedMediaToolReducer:\n", 281 | " TOOL_DATA_PATH = \"bc/100_tool_pts.pkl\"\n", 282 | " ACTION_LOW = np.array([ 0, 0, 0, -1, -1, -1])\n", 283 | " ACTION_HIGH = np.array([ 0, 0, 0, 1, 1, 1])\n", 284 | " DEG_TO_RAD = np.pi / 180.\n", 285 | " MAX_ROT_AXIS_ANG = (10. * DEG_TO_RAD)\n", 286 | "\n", 287 | " def __init__(self, args, action_repeat):\n", 288 | " assert args.reduce_tool_points\n", 289 | " self.tool_point_num = args.tool_point_num\n", 290 | " self.action_repeat = action_repeat\n", 291 | "\n", 292 | " self.MAX_ROT_AXIS_ANG /= action_repeat\n", 293 | "\n", 294 | " with open(self.TOOL_DATA_PATH, 'rb') as f:\n", 295 | " self.all_tool_points = pickle.load(f)\n", 296 | "\n", 297 | " # Sample tool points\n", 298 | " ratio = self.tool_point_num / 100\n", 299 | " sampled_idxs = fps(self.all_tool_points, ratio=ratio, random_start=False)\n", 300 | " self.tool_points = self.all_tool_points[sampled_idxs].detach().numpy()\n", 301 | "\n", 302 | " self.rotation = Quaternion()\n", 303 | "\n", 304 | " # Prep tool points for rotation\n", 305 | " self.vec_mat = np.zeros((self.tool_point_num, 4, 4), dtype=self.tool_points.dtype)\n", 306 | " self.vec_mat[:, 0, 1] = -self.tool_points[:, 0]\n", 307 | " self.vec_mat[:, 0, 2] = -self.tool_points[:, 1]\n", 308 | " self.vec_mat[:, 0, 3] = -self.tool_points[:, 2]\n", 309 | "\n", 310 | " self.vec_mat[:, 1, 0] = self.tool_points[:, 0]\n", 311 | " self.vec_mat[:, 1, 2] = -self.tool_points[:, 2]\n", 312 | " self.vec_mat[:, 1, 3] = self.tool_points[:, 1]\n", 313 | "\n", 314 | " self.vec_mat[:, 2, 0] = self.tool_points[:, 1]\n", 315 | " self.vec_mat[:, 2, 1] = self.tool_points[:, 2]\n", 316 | " self.vec_mat[:, 2, 3] = -self.tool_points[:, 0]\n", 317 | "\n", 318 | " self.vec_mat[:, 3, 0] = self.tool_points[:, 2]\n", 319 | " self.vec_mat[:, 3, 1] = -self.tool_points[:, 1]\n", 320 | " self.vec_mat[:, 3, 2] = self.tool_points[:, 0]\n", 321 | "\n", 322 | " def reset(self):\n", 323 | " self.rotation = Quaternion()\n", 324 | "\n", 325 | " def set_axis(self, axis):\n", 326 | " self.rotation = Quaternion(w=axis[3], x=axis[0], y=axis[1], z=axis[2])\n", 327 | "\n", 328 | " def step(self, act_raw):\n", 329 | " # act_raw: [x, y, z, rx, ry, rz]\n", 330 | " act_clip = np.clip(act_raw, a_min=self.ACTION_LOW, a_max=self.ACTION_HIGH)\n", 331 | " axis = act_clip[3:]\n", 332 | "\n", 333 | " dtheta = np.linalg.norm(act_clip[3:])\n", 334 | " if dtheta > self.MAX_ROT_AXIS_ANG:\n", 335 | " dtheta = dtheta * self.MAX_ROT_AXIS_ANG / np.sqrt(3)\n", 336 | " \n", 337 | " if dtheta == 0:\n", 338 | " axis = np.array([0., -1., 0.])\n", 339 | "\n", 340 | " axis = axis / np.linalg.norm(axis)\n", 341 | "\n", 342 | " for i in range(self.action_repeat):\n", 343 | " axis_world = self.rotation.rotate(axis)\n", 344 | " qt_rotate = Quaternion(axis=axis_world, angle=dtheta)\n", 345 | " self.rotation = qt_rotate * self.rotation\n", 346 | "\n", 347 | " def reduce_tool(self, obs, info):\n", 348 | " tool_idxs = np.where(obs[:, 3] == 1)[0]\n", 349 | " obs_notool = obs[len(tool_idxs):]\n", 350 | "\n", 351 | " tool_tip_pos = info[:3]\n", 352 | "\n", 353 | " # Rotate tool points\n", 354 | " global_rotation = self.rotation\n", 355 | " global_rotation._normalise()\n", 356 | " dqp = global_rotation.conjugate.q\n", 357 | "\n", 358 | " mid = 
np.matmul(self.vec_mat, dqp)\n", 359 | " mid = np.expand_dims(mid, axis=-1)\n", 360 | "\n", 361 | " rotated_tool_pts = global_rotation._q_matrix() @ mid\n", 362 | " rotated_tool_pts = rotated_tool_pts[:, 1:, 0]\n", 363 | "\n", 364 | " rotated_tool_pts += tool_tip_pos\n", 365 | "\n", 366 | " num_classes = obs.shape[1] - 3\n", 367 | " tool_onehot = np.zeros((self.tool_point_num, num_classes), dtype=obs.dtype)\n", 368 | " tool_onehot[:, 0] = 1\n", 369 | "\n", 370 | " tool_reduced = np.concatenate((rotated_tool_pts, tool_onehot), axis=1)\n", 371 | " return np.concatenate((tool_reduced, obs_notool), axis=0)" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 98, 377 | "id": "2f4390a7-b05b-49ee-95b5-b49534a12dda", 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "class Args:\n", 382 | " reduce_tool_points = True\n", 383 | " tool_point_num = 20\n", 384 | " \n", 385 | "args = Args()" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 99, 391 | "id": "23529944-66e7-4c03-aa0d-ff42fb27898b", 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "num_obs = len(data['obs']) - 1\n", 396 | "\n", 397 | "raw_obs_p = []\n", 398 | "reduced_obs_p = []\n", 399 | "\n", 400 | "tool_reducer = MixedMediaToolReducer(args=args, action_repeat=8)\n", 401 | "tool_reducer.reset()\n", 402 | "\n", 403 | "for t in range(num_obs):\n", 404 | " obs = data['obs'][t]\n", 405 | " raw_obs_p.append(obs[3])\n", 406 | " reduced_obs = tool_reducer.reduce_tool(obs[3], info=obs[0])\n", 407 | " reduced_obs_p.append(reduced_obs)\n", 408 | " tool_reducer.step(data['act_raw'][t])" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 100, 414 | "id": "7278da5e-e342-4e4c-9b33-03f472216f2e", 415 | "metadata": {}, 416 | "outputs": [ 417 | { 418 | "name": "stdout", 419 | "output_type": "stream", 420 | "text": [ 421 | "MoviePy - Building file ./raw.gif with imageio.\n" 422 | ] 423 | }, 424 | { 425 | "name": "stderr", 426 | "output_type": "stream", 427 | "text": [ 428 | " \r" 429 | ] 430 | }, 431 | { 432 | "name": "stdout", 433 | "output_type": "stream", 434 | "text": [ 435 | "MoviePy - Building file ./reduced.gif with imageio.\n" 436 | ] 437 | }, 438 | { 439 | "name": "stderr", 440 | "output_type": "stream", 441 | "text": [ 442 | " \r" 443 | ] 444 | } 445 | ], 446 | "source": [ 447 | "visualization.save_pointclouds(raw_obs_p, savedir='.', suffix=\"raw.gif\")\n", 448 | "visualization.save_pointclouds(reduced_obs_p, savedir='.', suffix=\"reduced.gif\")" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "id": "5a1af452-1853-4d4b-bf1f-7c8d7d06a42f", 455 | "metadata": {}, 456 | "outputs": [], 457 | "source": [] 458 | } 459 | ], 460 | "metadata": { 461 | "kernelspec": { 462 | "display_name": "softgym", 463 | "language": "python", 464 | "name": "softgym" 465 | }, 466 | "language_info": { 467 | "codemirror_mode": { 468 | "name": "ipython", 469 | "version": 3 470 | }, 471 | "file_extension": ".py", 472 | "mimetype": "text/x-python", 473 | "name": "python", 474 | "nbconvert_exporter": "python", 475 | "pygments_lexer": "ipython3", 476 | "version": "3.6.13" 477 | } 478 | }, 479 | "nbformat": 4, 480 | "nbformat_minor": 5 481 | } 482 | -------------------------------------------------------------------------------- /chester/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import os.path as osp 5 | import json 6 | import 
time
7 | import datetime
8 | import dateutil.tz
9 | import tempfile
10 | import wandb
11 | from collections import defaultdict
12 | 
13 | # LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv', 'tensorboard']
14 | LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv']
15 | # Also valid: json, tensorboard, wandb
16 | 
17 | DEBUG = 10
18 | INFO = 20
19 | WARN = 30
20 | ERROR = 40
21 | 
22 | DISABLED = 50
23 | 
24 | 
25 | class KVWriter(object):
26 |     def writekvs(self, kvs):
27 |         raise NotImplementedError
28 | 
29 | 
30 | class SeqWriter(object):
31 |     def writeseq(self, seq):
32 |         raise NotImplementedError
33 | 
34 | 
35 | def put_in_middle(str1, str2):
36 |     # Put str2 in the middle of str1 (used below to center a timestamp in a dashed rule)
37 |     n = len(str1)
38 |     m = len(str2)
39 |     if n <= m:
40 |         return str2
41 |     else:
42 |         start = (n - m) // 2
43 |         return str1[:start] + str2 + str1[start + m:]
44 | 
45 | 
46 | class HumanOutputFormat(KVWriter, SeqWriter):
47 |     def __init__(self, filename_or_file):
48 |         if isinstance(filename_or_file, str):
49 |             self.file = open(filename_or_file, 'wt')
50 |             self.own_file = True
51 |         else:
52 |             assert hasattr(filename_or_file, 'write'), 'expected file or str, got %s' % filename_or_file
53 |             self.file = filename_or_file
54 |             self.own_file = False
55 | 
56 |     def writekvs(self, kvs):
57 |         # Create strings for printing
58 |         key2str = {}
59 |         for (key, val) in sorted(kvs.items()):
60 |             if isinstance(val, float):
61 |                 valstr = '%-8.3g' % (val,)
62 |             else:
63 |                 valstr = str(val)
64 |             key2str[self._truncate(key)] = self._truncate(valstr)
65 | 
66 |         # Find max widths
67 |         if len(key2str) == 0:
68 |             print('WARNING: tried to write empty key-value dict')
69 |             return
70 |         else:
71 |             keywidth = max(map(len, key2str.keys()))
72 |             valwidth = max(map(len, key2str.values()))
73 | 
74 |         # Write out the data
75 |         now = datetime.datetime.now(dateutil.tz.tzlocal())
76 |         timestamp = now.strftime('%Y-%m-%d %H:%M:%S.%f %Z')
77 | 
78 |         dashes = '-' * (keywidth + valwidth + 7)
79 |         dashes_time = put_in_middle(dashes, timestamp)
80 |         lines = [dashes_time]
81 |         for (key, val) in sorted(key2str.items()):
82 |             lines.append('| %s%s | %s%s |' % (
83 |                 key,
84 |                 ' ' * (keywidth - len(key)),
85 |                 val,
86 |                 ' ' * (valwidth - len(val)),
87 |             ))
88 |         lines.append(dashes)
89 |         self.file.write('\n'.join(lines) + '\n')
90 | 
91 |         # Flush the output to the file
92 |         self.file.flush()
93 | 
94 |     def _truncate(self, s):
95 |         return s[:30] + '...' 
if len(s) > 33 else s 96 | 97 | def writeseq(self, seq): 98 | for arg in seq: 99 | self.file.write(arg) 100 | self.file.write('\n') 101 | self.file.flush() 102 | 103 | def close(self): 104 | if self.own_file: 105 | self.file.close() 106 | 107 | 108 | class JSONOutputFormat(KVWriter): 109 | def __init__(self, filename): 110 | self.file = open(filename, 'wt') 111 | 112 | def writekvs(self, kvs): 113 | for k, v in sorted(kvs.items()): 114 | if hasattr(v, 'dtype'): 115 | v = v.tolist() 116 | kvs[k] = float(v) 117 | self.file.write(json.dumps(kvs) + '\n') 118 | self.file.flush() 119 | 120 | def close(self): 121 | self.file.close() 122 | 123 | 124 | class CSVOutputFormat(KVWriter): 125 | def __init__(self, filename): 126 | self.file = open(filename, 'w+t') 127 | self.keys = [] 128 | self.sep = ',' 129 | 130 | def writekvs(self, kvs): 131 | # Add our current row to the history 132 | extra_keys = kvs.keys() - self.keys 133 | if extra_keys: 134 | self.keys.extend(extra_keys) 135 | self.file.seek(0) 136 | lines = self.file.readlines() 137 | self.file.seek(0) 138 | for (i, k) in enumerate(self.keys): 139 | if i > 0: 140 | self.file.write(',') 141 | self.file.write(k) 142 | self.file.write('\n') 143 | for line in lines[1:]: 144 | self.file.write(line[:-1]) 145 | self.file.write(self.sep * len(extra_keys)) 146 | self.file.write('\n') 147 | for (i, k) in enumerate(self.keys): 148 | if i > 0: 149 | self.file.write(',') 150 | v = kvs.get(k) 151 | if v is not None: 152 | self.file.write(str(v)) 153 | self.file.write('\n') 154 | self.file.flush() 155 | 156 | def close(self): 157 | self.file.close() 158 | 159 | 160 | class TensorBoardOutputFormat(KVWriter): 161 | """ 162 | Dumps key/value pairs into TensorBoard's numeric format. 163 | """ 164 | 165 | def __init__(self, dir): 166 | os.makedirs(dir, exist_ok=True) 167 | self.dir = dir 168 | self.step = 1 169 | prefix = 'events' 170 | path = osp.join(osp.abspath(dir), prefix) 171 | import tensorflow as tf 172 | from tensorflow.python import pywrap_tensorflow 173 | from tensorflow.core.util import event_pb2 174 | from tensorflow.python.util import compat 175 | self.tf = tf 176 | self.event_pb2 = event_pb2 177 | self.pywrap_tensorflow = pywrap_tensorflow 178 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 179 | 180 | def writekvs(self, kvs): 181 | def summary_val(k, v): 182 | kwargs = {'tag': k, 'simple_value': float(v)} 183 | return self.tf.Summary.Value(**kwargs) 184 | 185 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 186 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 187 | event.step = self.step # is there any reason why you'd want to specify the step? 
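        # (TensorBoard uses event.step as the x-axis for scalar plots, so a monotonically increasing counter suffices.)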
188 | self.writer.WriteEvent(event) 189 | self.writer.Flush() 190 | self.step += 1 191 | 192 | def close(self): 193 | if self.writer: 194 | self.writer.Close() 195 | self.writer = None 196 | 197 | 198 | class WandbOutputFormat(KVWriter): 199 | def __init__(self, exp_name, variant): 200 | if variant is None: 201 | variant = {} 202 | 203 | project = variant.get("wandb_project") 204 | entity = variant.get("wandb_entity") 205 | wandb.init(project=project, entity=entity, name=exp_name) 206 | 207 | def writekvs(self, kvs): 208 | wandb.log(kvs) 209 | 210 | def close(self): 211 | wandb.finish() 212 | 213 | 214 | def make_output_format(format, ev_dir, exp_name, variant=None, log_suffix=''): 215 | os.makedirs(ev_dir, exist_ok=True) 216 | if format == 'stdout': 217 | return HumanOutputFormat(sys.stdout) 218 | elif format == 'log': 219 | return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix)) 220 | elif format == 'json': 221 | return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix)) 222 | elif format == 'csv': 223 | return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix)) 224 | elif format == 'tensorboard': 225 | return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix)) 226 | elif format == 'wandb': 227 | return WandbOutputFormat(exp_name, variant) 228 | else: 229 | raise ValueError('Unknown format specified: %s' % (format,)) 230 | 231 | 232 | # ================================================================ 233 | # API 234 | # ================================================================ 235 | 236 | def logkv(key, val): 237 | """ 238 | Log a value of some diagnostic 239 | Call this once for each diagnostic quantity, each iteration 240 | If called many times, last value will be used. 241 | """ 242 | Logger.CURRENT.logkv(key, val) 243 | 244 | 245 | def logkv_mean(key, val): 246 | """ 247 | The same as logkv(), but if called many times, values averaged. 248 | """ 249 | Logger.CURRENT.logkv_mean(key, val) 250 | 251 | 252 | def logkvs(d): 253 | """ 254 | Log a dictionary of key-value pairs 255 | """ 256 | for (k, v) in d.items(): 257 | logkv(k, v) 258 | 259 | 260 | def dumpkvs(): 261 | """ 262 | Write all of the diagnostics from the current iteration 263 | 264 | level: int. (see logger.py docs) If the global logger level is higher than 265 | the level argument here, don't print to stdout. 266 | """ 267 | Logger.CURRENT.dumpkvs() 268 | 269 | 270 | def getkvs(): 271 | return Logger.CURRENT.name2val 272 | 273 | 274 | def log(*args, level=INFO): 275 | """ 276 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 277 | """ 278 | Logger.CURRENT.log(*args, level=level) 279 | 280 | 281 | def debug(*args): 282 | log(*args, level=DEBUG) 283 | 284 | 285 | def info(*args): 286 | log(*args, level=INFO) 287 | 288 | 289 | def warn(*args): 290 | log(*args, level=WARN) 291 | 292 | 293 | def error(*args): 294 | log(*args, level=ERROR) 295 | 296 | 297 | def set_level(level): 298 | """ 299 | Set logging threshold on current logger. 300 | """ 301 | Logger.CURRENT.set_level(level) 302 | 303 | 304 | def get_dir(): 305 | """ 306 | Get directory that log files are being written to. 
307 | will be None if there is no output directory (i.e., if you didn't call start) 308 | """ 309 | return Logger.CURRENT.get_dir() 310 | 311 | 312 | record_tabular = logkv 313 | dump_tabular = dumpkvs 314 | 315 | 316 | class ProfileKV: 317 | """ 318 | Usage: 319 | with logger.ProfileKV("interesting_scope"): 320 | code 321 | """ 322 | 323 | def __init__(self, n): 324 | self.n = "wait_" + n 325 | 326 | def __enter__(self): 327 | self.t1 = time.time() 328 | 329 | def __exit__(self, type, value, traceback): 330 | Logger.CURRENT.name2val[self.n] += time.time() - self.t1 331 | 332 | 333 | def profile(n): 334 | """ 335 | Usage: 336 | @profile("my_func") 337 | def my_func(): code 338 | """ 339 | 340 | def decorator_with_name(func): 341 | def func_wrapper(*args, **kwargs): 342 | with ProfileKV(n): 343 | return func(*args, **kwargs) 344 | 345 | return func_wrapper 346 | 347 | return decorator_with_name 348 | 349 | 350 | # ================================================================ 351 | # Backend 352 | # ================================================================ 353 | 354 | class Logger(object): 355 | DEFAULT = None # A logger with no output files. (See right below class definition) 356 | # So that you can still log to the terminal without setting up any output files 357 | CURRENT = None # Current logger being used by the free functions above 358 | 359 | def __init__(self, dir, output_formats): 360 | self.name2val = defaultdict(float) # values this iteration 361 | self.name2cnt = defaultdict(int) 362 | self.level = INFO 363 | self.dir = dir 364 | self.output_formats = output_formats 365 | 366 | # Logging API, forwarded 367 | # ---------------------------------------- 368 | def logkv(self, key, val): 369 | self.name2val[key] = val 370 | 371 | def logkv_mean(self, key, val): 372 | if val is None: 373 | self.name2val[key] = None 374 | return 375 | oldval, cnt = self.name2val[key], self.name2cnt[key] 376 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 377 | self.name2cnt[key] = cnt + 1 378 | 379 | def dumpkvs(self): 380 | if self.level == DISABLED: return 381 | for fmt in self.output_formats: 382 | if isinstance(fmt, KVWriter): 383 | fmt.writekvs(self.name2val) 384 | self.name2val.clear() 385 | self.name2cnt.clear() 386 | 387 | def log(self, *args, level=INFO): 388 | if self.level <= level: 389 | self._do_log(args) 390 | 391 | # Configuration 392 | # ---------------------------------------- 393 | def set_level(self, level): 394 | self.level = level 395 | 396 | def get_dir(self): 397 | return self.dir 398 | 399 | def close(self): 400 | for fmt in self.output_formats: 401 | fmt.close() 402 | 403 | # Misc 404 | # ---------------------------------------- 405 | def _do_log(self, args): 406 | for fmt in self.output_formats: 407 | if isinstance(fmt, SeqWriter): 408 | fmt.writeseq(map(str, args)) 409 | 410 | 411 | Logger.DEFAULT = Logger.CURRENT = Logger(dir=None, output_formats=[HumanOutputFormat(sys.stdout)]) 412 | 413 | 414 | def configure(dir=None, format_strs=None, exp_name=None, variant=None): 415 | if dir is None: 416 | dir = os.getenv('OPENAI_LOGDIR') 417 | if dir is None: 418 | dir = osp.join(tempfile.gettempdir(), 419 | datetime.datetime.now().strftime("chester-%Y-%m-%d-%H-%M-%S")) 420 | 421 | assert isinstance(dir, str) 422 | os.makedirs(dir, exist_ok=True) 423 | 424 | if format_strs is None: 425 | strs = os.getenv('OPENAI_LOG_FORMAT') 426 | format_strs = strs.split(',') if strs else LOG_OUTPUT_FORMATS 427 | output_formats = [make_output_format(f, dir, exp_name, variant) 
for f in format_strs] 428 | 429 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats) 430 | log('Logging to %s' % dir) 431 | 432 | 433 | def reset(): 434 | if Logger.CURRENT is not Logger.DEFAULT: 435 | Logger.CURRENT.close() 436 | Logger.CURRENT = Logger.DEFAULT 437 | log('Reset logger') 438 | 439 | 440 | class scoped_configure(object): 441 | def __init__(self, dir=None, format_strs=None): 442 | self.dir = dir 443 | self.format_strs = format_strs 444 | self.prevlogger = None 445 | 446 | def __enter__(self): 447 | self.prevlogger = Logger.CURRENT 448 | configure(dir=self.dir, format_strs=self.format_strs) 449 | 450 | def __exit__(self, *args): 451 | Logger.CURRENT.close() 452 | Logger.CURRENT = self.prevlogger 453 | 454 | 455 | # ================================================================ 456 | 457 | def _demo(): 458 | info("hi") 459 | debug("shouldn't appear") 460 | set_level(DEBUG) 461 | debug("should appear") 462 | dir = "/tmp/testlogging" 463 | if os.path.exists(dir): 464 | shutil.rmtree(dir) 465 | configure(dir=dir) 466 | logkv("a", 3) 467 | logkv("b", 2.5) 468 | dumpkvs() 469 | logkv("b", -2.5) 470 | logkv("a", 5.5) 471 | dumpkvs() 472 | info("^^^ should see a = 5.5") 473 | logkv_mean("b", -22.5) 474 | logkv_mean("b", -44.4) 475 | logkv("a", 5.5) 476 | dumpkvs() 477 | info("^^^ should see b = 33.3") 478 | 479 | logkv("b", -2.5) 480 | dumpkvs() 481 | 482 | logkv("a", "longasslongasslongasslongasslongasslongassvalue") 483 | dumpkvs() 484 | 485 | 486 | # ================================================================ 487 | # Readers 488 | # ================================================================ 489 | 490 | def read_json(fname): 491 | import pandas 492 | ds = [] 493 | with open(fname, 'rt') as fh: 494 | for line in fh: 495 | ds.append(json.loads(line)) 496 | return pandas.DataFrame(ds) 497 | 498 | 499 | def read_csv(fname): 500 | import pandas 501 | return pandas.read_csv(fname, index_col=None, comment='#') 502 | 503 | 504 | def read_tb(path): 505 | """ 506 | path : a tensorboard file OR a directory, where we will find all TB files 507 | of the form events.* 508 | """ 509 | import pandas 510 | import numpy as np 511 | from glob import glob 512 | from collections import defaultdict 513 | import tensorflow as tf 514 | if osp.isdir(path): 515 | fnames = glob(osp.join(path, "events.*")) 516 | elif osp.basename(path).startswith("events."): 517 | fnames = [path] 518 | else: 519 | raise NotImplementedError("Expected tensorboard file or directory containing them. 
Got %s" % path) 520 | tag2pairs = defaultdict(list) 521 | maxstep = 0 522 | for fname in fnames: 523 | for summary in tf.train.summary_iterator(fname): 524 | if summary.step > 0: 525 | for v in summary.summary.value: 526 | pair = (summary.step, v.simple_value) 527 | tag2pairs[v.tag].append(pair) 528 | maxstep = max(summary.step, maxstep) 529 | data = np.empty((maxstep, len(tag2pairs))) 530 | data[:] = np.nan 531 | tags = sorted(tag2pairs.keys()) 532 | for (colidx, tag) in enumerate(tags): 533 | pairs = tag2pairs[tag] 534 | for (step, value) in pairs: 535 | data[step - 1, colidx] = value 536 | return pandas.DataFrame(data, columns=tags) 537 | 538 | 539 | if __name__ == "__main__": 540 | _demo() 541 | -------------------------------------------------------------------------------- /chester/utils_s3.py: -------------------------------------------------------------------------------- 1 | from chester import config_ec2 as config 2 | from io import StringIO 3 | import base64 4 | import os 5 | import os.path as osp 6 | import subprocess 7 | import hashlib 8 | import datetime 9 | import dateutil 10 | import re 11 | 12 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 13 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 14 | 15 | 16 | def dedent(s): 17 | lines = [l.strip() for l in s.split('\n')] 18 | return '\n'.join(lines) 19 | 20 | 21 | def upload_file_to_s3(script_content): 22 | import tempfile 23 | import uuid 24 | f = tempfile.NamedTemporaryFile(delete=False) 25 | f.write(script_content.encode()) 26 | f.close() 27 | remote_path = os.path.join( 28 | config.AWS_CODE_SYNC_S3_PATH, "oversize_bash_scripts", str(uuid.uuid4())) 29 | subprocess.check_call(["aws", "s3", "cp", f.name, remote_path]) 30 | os.unlink(f.name) 31 | return remote_path 32 | 33 | 34 | S3_CODE_PATH = None 35 | 36 | 37 | def s3_sync_code(config, dry=False): 38 | global S3_CODE_PATH 39 | if S3_CODE_PATH is not None: 40 | return S3_CODE_PATH 41 | base = config.AWS_CODE_SYNC_S3_PATH 42 | has_git = True 43 | 44 | if config.FAST_CODE_SYNC: 45 | try: 46 | current_commit = subprocess.check_output( 47 | ["git", "rev-parse", "HEAD"]).strip().decode("utf-8") 48 | except subprocess.CalledProcessError as _: 49 | print("Warning: failed to execute git commands") 50 | current_commit = None 51 | file_name = str(timestamp) + "_" + hashlib.sha224( 52 | subprocess.check_output(["pwd"]) + str(current_commit).encode() + str(timestamp).encode() 53 | ).hexdigest() + ".tar.gz" 54 | 55 | file_path = "/tmp/" + file_name 56 | 57 | tar_cmd = ["tar", "-zcvf", file_path, "-C", config.PROJECT_PATH] 58 | for pattern in config.FAST_CODE_SYNC_IGNORES: 59 | tar_cmd += ["--exclude", pattern] 60 | tar_cmd += ["-h", "."] 61 | 62 | remote_path = "%s/%s" % (base, file_name) 63 | 64 | upload_cmd = ["aws", "s3", "cp", file_path, remote_path] 65 | 66 | mujoco_key_cmd = [ 67 | "aws", "s3", "sync", config.MUJOCO_KEY_PATH, "{}/.mujoco/".format(base)] 68 | 69 | print(" ".join(tar_cmd)) 70 | print(" ".join(upload_cmd)) 71 | print(" ".join(mujoco_key_cmd)) 72 | 73 | if not dry: 74 | subprocess.check_call(tar_cmd) 75 | subprocess.check_call(upload_cmd) 76 | try: 77 | subprocess.check_call(mujoco_key_cmd) 78 | except Exception as e: 79 | print(e) 80 | 81 | S3_CODE_PATH = remote_path 82 | return remote_path 83 | else: 84 | try: 85 | current_commit = subprocess.check_output( 86 | ["git", "rev-parse", "HEAD"]).strip().decode("utf-8") 87 | clean_state = len( 88 | subprocess.check_output(["git", "status", "--porcelain"])) == 0 89 | except subprocess.CalledProcessError as _: 90 | 
print("Warning: failed to execute git commands") 91 | has_git = False 92 | dir_hash = base64.b64encode(subprocess.check_output(["pwd"])).decode("utf-8") 93 | code_path = "%s_%s" % ( 94 | dir_hash, 95 | (current_commit if clean_state else "%s_dirty_%s" % (current_commit, timestamp)) if 96 | has_git else timestamp 97 | ) 98 | full_path = "%s/%s" % (base, code_path) 99 | cache_path = "%s/%s" % (base, dir_hash) 100 | cache_cmds = ["aws", "s3", "cp", "--recursive"] + \ 101 | flatten(["--exclude", "%s" % pattern] for pattern in config.CODE_SYNC_IGNORES) + \ 102 | [cache_path, full_path] 103 | cmds = ["aws", "s3", "cp", "--recursive"] + \ 104 | flatten(["--exclude", "%s" % pattern] for pattern in config.CODE_SYNC_IGNORES) + \ 105 | [".", full_path] 106 | caching_cmds = ["aws", "s3", "cp", "--recursive"] + \ 107 | flatten(["--exclude", "%s" % pattern] for pattern in config.CODE_SYNC_IGNORES) + \ 108 | [full_path, cache_path] 109 | mujoco_key_cmd = [ 110 | "aws", "s3", "sync", config.MUJOCO_KEY_PATH, "{}/.mujoco/".format(base)] 111 | print(cache_cmds, cmds, caching_cmds, mujoco_key_cmd) 112 | if not dry: 113 | subprocess.check_call(cache_cmds) 114 | subprocess.check_call(cmds) 115 | subprocess.check_call(caching_cmds) 116 | try: 117 | subprocess.check_call(mujoco_key_cmd) 118 | except Exception: 119 | print('Unable to sync mujoco keys!') 120 | S3_CODE_PATH = full_path 121 | return full_path 122 | 123 | 124 | _find_unsafe = re.compile(r'[a-zA-Z0-9_^@%+=:,./-]').search 125 | 126 | 127 | def _shellquote(s): 128 | """Return a shell-escaped version of the string *s*.""" 129 | if not s: 130 | return "''" 131 | 132 | if _find_unsafe(s) is None: 133 | return s 134 | 135 | # use single quotes, and put single quotes into double quotes 136 | # the string $'b is then quoted as '$'"'"'b' 137 | 138 | return "'" + s.replace("'", "'\"'\"'") + "'" 139 | 140 | 141 | def _to_param_val(v): 142 | if v is None: 143 | return "" 144 | elif isinstance(v, list): 145 | return " ".join(map(_shellquote, list(map(str, v)))) 146 | else: 147 | return _shellquote(str(v)) 148 | 149 | 150 | def to_local_command(params, python_command="python", script=osp.join(config.PROJECT_PATH, 'scripts/run_experiment.py'), use_gpu=False): 151 | command = python_command + " " + script 152 | pre_commands = params.pop("pre_commands", None) 153 | post_commands = params.pop("post_commands", None) 154 | if post_commands is not None: 155 | print("Not executing the post_commands: ", post_commands) 156 | 157 | for k, v in params.items(): 158 | if isinstance(v, dict): 159 | for nk, nv in v.items(): 160 | if str(nk) == "_name": 161 | command += " --%s %s" % (k, _to_param_val(nv)) 162 | else: 163 | command += \ 164 | " --%s_%s %s" % (k, nk, _to_param_val(nv)) 165 | else: 166 | command += " --%s %s" % (k, _to_param_val(v)) 167 | for pre_command in reversed(pre_commands): 168 | command = pre_command + " && " + command 169 | return command 170 | 171 | 172 | def launch_ec2(params_list, exp_prefix, docker_image, code_full_path, 173 | python_command="python", 174 | script='scripts/run_experiment.py', 175 | aws_config=None, dry=False, terminate_machine=True, use_gpu=False, sync_s3_pkl=False, 176 | sync_s3_png=False, 177 | sync_s3_log=False, 178 | sync_s3_html=False, 179 | sync_s3_mp4=False, 180 | sync_s3_gif=False, 181 | sync_s3_pth=False, 182 | sync_s3_txt=False, 183 | sync_log_on_termination=True, 184 | periodic_sync=True, periodic_sync_interval=15): 185 | if len(params_list) == 0: 186 | return 187 | 188 | default_config = dict( 189 | 
image_id=config.AWS_IMAGE_ID, 190 | instance_type=config.AWS_INSTANCE_TYPE, 191 | key_name=config.AWS_KEY_NAME, 192 | spot=config.AWS_SPOT, 193 | spot_price=config.AWS_SPOT_PRICE, 194 | iam_instance_profile_name=config.AWS_IAM_INSTANCE_PROFILE_NAME, 195 | security_groups=config.AWS_SECURITY_GROUPS, 196 | security_group_ids=config.AWS_SECURITY_GROUP_IDS, 197 | network_interfaces=config.AWS_NETWORK_INTERFACES, 198 | instance_interruption_behavior='terminate', # TODO 199 | ) 200 | 201 | if aws_config is None: 202 | aws_config = dict() 203 | aws_config = dict(default_config, **aws_config) 204 | 205 | sio = StringIO() 206 | sio.write("#!/bin/bash\n") 207 | sio.write("{\n") 208 | sio.write(""" 209 | die() { status=$1; shift; echo "FATAL: $*"; exit $status; } 210 | """) 211 | sio.write(""" 212 | EC2_INSTANCE_ID="`wget -q -O - http://169.254.169.254/latest/meta-data/instance-id`" 213 | """) 214 | # sio.write("""service docker start""") 215 | # sio.write("""docker --config /home/ubuntu/.docker pull {docker_image}""".format(docker_image=docker_image)) 216 | sio.write(""" 217 | export PATH=/home/ubuntu/bin:/home/ubuntu/.local/bin:$PATH 218 | """) 219 | sio.write(""" 220 | export PATH=/home/ubuntu/miniconda3/bin:/usr/local/cuda/bin:$PATH 221 | """) 222 | sio.write(""" 223 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco200/bin 224 | """) 225 | sio.write(""" 226 | echo $PATH 227 | """) 228 | sio.write(""" 229 | export AWS_DEFAULT_REGION={aws_region} 230 | """.format(aws_region=config.AWS_BUCKET_REGION_NAME)) # add AWS_BUCKET_REGION_NAME=us-east-1 in your config.py 231 | sio.write(""" 232 | pip install --upgrade --user awscli 233 | """) 234 | 235 | sio.write(""" 236 | aws ec2 create-tags --resources $EC2_INSTANCE_ID --tags Key=Name,Value={exp_name} --region {aws_region} 237 | """.format(exp_name=params_list[0].get("exp_name"), aws_region=config.AWS_REGION_NAME)) 238 | if config.LABEL: 239 | sio.write(""" 240 | aws ec2 create-tags --resources $EC2_INSTANCE_ID --tags Key=owner,Value={label} --region {aws_region} 241 | """.format(label=config.LABEL, aws_region=config.AWS_REGION_NAME)) 242 | sio.write(""" 243 | aws ec2 create-tags --resources $EC2_INSTANCE_ID --tags Key=exp_prefix,Value={exp_prefix} --region {aws_region} 244 | """.format(exp_prefix=exp_prefix, aws_region=config.AWS_REGION_NAME)) 245 | 246 | if config.FAST_CODE_SYNC: 247 | sio.write(""" 248 | aws s3 cp {code_full_path} /tmp/chester_code.tar.gz 249 | """.format(code_full_path=code_full_path, local_code_path=config.CODE_DIR)) 250 | sio.write(""" 251 | mkdir -p {local_code_path} 252 | """.format(code_full_path=code_full_path, local_code_path=config.CODE_DIR, 253 | aws_region=config.AWS_REGION_NAME)) 254 | sio.write(""" 255 | tar -zxvf /tmp/chester_code.tar.gz -C {local_code_path} 256 | """.format(code_full_path=code_full_path, local_code_path=config.CODE_DIR, 257 | aws_region=config.AWS_REGION_NAME)) 258 | else: 259 | sio.write(""" 260 | aws s3 cp --recursive {code_full_path} {local_code_path} 261 | """.format(code_full_path=code_full_path, local_code_path=config.CODE_DIR)) 262 | s3_mujoco_key_path = config.AWS_CODE_SYNC_S3_PATH + '/.mujoco/' 263 | sio.write(""" 264 | aws s3 cp --recursive {} {} 265 | """.format(s3_mujoco_key_path, config.MUJOCO_KEY_PATH)) 266 | sio.write(""" 267 | cd {local_code_path} 268 | """.format(local_code_path=config.CODE_DIR)) 269 | 270 | for params in params_list: 271 | log_dir = params.get("log_dir") 272 | remote_log_dir = params.pop("remote_log_dir") 273 | env = params.pop("env", None) 274 | 275 | 
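        # Each experiment appends to the same user-data script: re-tag the instance, create its log dir, start the background S3 sync loops, run the experiment command, then do a final sync.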
sio.write(""" 276 | aws ec2 create-tags --resources $EC2_INSTANCE_ID --tags Key=Name,Value={exp_name} --region {aws_region} 277 | """.format(exp_name=params.get("exp_name"), aws_region=config.AWS_REGION_NAME)) 278 | sio.write(""" 279 | mkdir -p {log_dir} 280 | """.format(log_dir=log_dir)) 281 | if periodic_sync: 282 | include_png = " --include '*.png' " if sync_s3_png else " " 283 | include_pkl = " --include '*.pkl' " if sync_s3_pkl else " " 284 | include_log = " --include '*.log' " if sync_s3_log else " " 285 | include_html = " --include '*.html' " if sync_s3_html else " " 286 | include_mp4 = " --include '*.mp4' " if sync_s3_mp4 else " " 287 | include_gif = " --include '*.gif' " if sync_s3_gif else " " 288 | include_pth = " --include '*.pth' " if sync_s3_pth else " " 289 | include_txt = " --include '*.txt' " if sync_s3_txt else " " 290 | sio.write(""" 291 | while /bin/true; do 292 | aws s3 sync --exclude '*' {include_png} {include_pkl} {include_log} {include_html} {include_mp4} {include_gif} {include_pth} {include_txt} --include '*.csv' --include '*.json' {log_dir} {remote_log_dir} 293 | sleep {periodic_sync_interval} 294 | done & echo sync initiated""".format(include_png=include_png, include_pkl=include_pkl, 295 | include_log=include_log, include_html=include_html, 296 | include_mp4=include_mp4, include_gif=include_gif, 297 | include_pth=include_pth, include_txt=include_txt, 298 | log_dir=log_dir, remote_log_dir=remote_log_dir, 299 | periodic_sync_interval=periodic_sync_interval)) 300 | if sync_log_on_termination: 301 | sio.write(""" 302 | while /bin/true; do 303 | if [ -z $(curl -Is http://169.254.169.254/latest/meta-data/spot/termination-time | head -1 | grep 404 | cut -d \ -f 2) ] 304 | then 305 | logger "Running shutdown hook." 306 | aws s3 cp /home/ubuntu/user_data.log {remote_log_dir}/stdout.log 307 | aws s3 cp --recursive {log_dir} {remote_log_dir} 308 | break 309 | else 310 | # Spot instance not yet marked for termination. 
311 | sleep 5 312 | fi 313 | done & echo log sync initiated 314 | """.format(log_dir=log_dir, remote_log_dir=remote_log_dir)) 315 | sio.write("""{command}""".format(command=to_local_command(params, python_command=python_command, script=script, use_gpu=use_gpu))) 316 | sio.write(""" 317 | aws s3 cp --recursive {log_dir} {remote_log_dir} 318 | """.format(log_dir=log_dir, remote_log_dir=remote_log_dir)) 319 | sio.write(""" 320 | aws s3 cp /home/ubuntu/user_data.log {remote_log_dir}/stdout.log 321 | """.format(remote_log_dir=remote_log_dir)) 322 | 323 | if terminate_machine: 324 | sio.write(""" 325 | EC2_INSTANCE_ID="`wget -q -O - http://169.254.169.254/latest/meta-data/instance-id || die \"wget instance-id has failed: $?\"`" 326 | aws ec2 terminate-instances --instance-ids $EC2_INSTANCE_ID --region {aws_region} 327 | """.format(aws_region=config.AWS_REGION_NAME)) 328 | sio.write("} >> /home/ubuntu/user_data.log 2>&1\n") 329 | 330 | full_script = dedent(sio.getvalue()) 331 | 332 | import boto3 333 | import botocore 334 | if aws_config["spot"]: 335 | ec2 = boto3.client( 336 | "ec2", 337 | region_name=config.AWS_REGION_NAME, 338 | aws_access_key_id=config.AWS_ACCESS_KEY, 339 | aws_secret_access_key=config.AWS_ACCESS_SECRET, 340 | ) 341 | else: 342 | ec2 = boto3.resource( 343 | "ec2", 344 | region_name=config.AWS_REGION_NAME, 345 | aws_access_key_id=config.AWS_ACCESS_KEY, 346 | aws_secret_access_key=config.AWS_ACCESS_SECRET, 347 | ) 348 | 349 | print("len_full_script", len(full_script)) 350 | if len(full_script) > 16384 or len(base64.b64encode(full_script.encode()).decode("utf-8")) > 16384: 351 | # Script too long; need to upload script to s3 first. 352 | # We're being conservative here since the actual limit is 16384 bytes 353 | s3_path = upload_file_to_s3(full_script) 354 | sio = StringIO() 355 | sio.write("#!/bin/bash\n") 356 | sio.write(""" 357 | aws s3 cp {s3_path} /home/ubuntu/remote_script.sh --region {aws_region} && \\ 358 | chmod +x /home/ubuntu/remote_script.sh && \\ 359 | bash /home/ubuntu/remote_script.sh 360 | """.format(s3_path=s3_path, aws_region=config.AWS_REGION_NAME)) 361 | user_data = dedent(sio.getvalue()) 362 | else: 363 | user_data = full_script 364 | print(full_script) 365 | with open("/tmp/full_script", "w") as f: 366 | f.write(full_script) 367 | 368 | instance_args = dict( 369 | ImageId=aws_config["image_id"], 370 | KeyName=aws_config["key_name"], 371 | UserData=user_data, 372 | InstanceType=aws_config["instance_type"], 373 | EbsOptimized=config.EBS_OPTIMIZED, 374 | SecurityGroups=aws_config["security_groups"], 375 | SecurityGroupIds=aws_config["security_group_ids"], 376 | NetworkInterfaces=aws_config["network_interfaces"], 377 | IamInstanceProfile=dict( 378 | Name=aws_config["iam_instance_profile_name"], 379 | ), 380 | **config.AWS_EXTRA_CONFIGS, 381 | ) 382 | 383 | if len(instance_args["NetworkInterfaces"]) > 0: 384 | # disable_security_group = query_yes_no( 385 | # "Cannot provide both network interfaces and security groups info. 
Do you want to disable security group settings?",
386 |     #     default="yes",
387 |     # )
388 |     disable_security_group = True
389 |     if disable_security_group:
390 |         instance_args.pop("SecurityGroups")
391 |         instance_args.pop("SecurityGroupIds")
392 | 
393 |     if aws_config.get("placement", None) is not None:
394 |         instance_args["Placement"] = aws_config["placement"]
395 |     if not aws_config["spot"]:
396 |         instance_args["MinCount"] = 1
397 |         instance_args["MaxCount"] = 1
398 |     print("************************************************************")
399 |     print(instance_args["UserData"])
400 |     print("************************************************************")
401 |     if aws_config["spot"]:
402 |         instance_args["UserData"] = base64.b64encode(instance_args["UserData"].encode()).decode("utf-8")
403 |         spot_args = dict(
404 |             DryRun=dry,
405 |             InstanceCount=1,
406 |             LaunchSpecification=instance_args,
407 |             SpotPrice=aws_config["spot_price"],
408 |             # ClientToken=params_list[0]["exp_name"],
409 |         )
410 |         import pprint
411 |         pprint.pprint(spot_args)
412 |         if not dry:
413 |             response = ec2.request_spot_instances(**spot_args)
414 |             print(response)
415 |             spot_request_id = response['SpotInstanceRequests'][
416 |                 0]['SpotInstanceRequestId']
417 |             for _ in range(10):
418 |                 try:
419 |                     ec2.create_tags(
420 |                         Resources=[spot_request_id],
421 |                         Tags=[
422 |                             {'Key': 'Name', 'Value': params_list[0]["exp_name"]}
423 |                         ],
424 |                     )
425 |                     break
426 |                 except botocore.exceptions.ClientError:
427 |                     continue
428 |     else:
429 |         import pprint
430 |         pprint.pprint(instance_args)
431 |         ec2.create_instances(  # launch an on-demand instance via the boto3 EC2 resource
432 |             DryRun=dry,
433 |             **instance_args
434 |         )
435 | 
--------------------------------------------------------------------------------
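A quick illustration of the parameter flattening performed by to_local_command in chester/utils_s3.py (a sketch only: the 'algo' and 'seeds' parameters are hypothetical, and it assumes chester is on PYTHONPATH with a working config_ec2 so the module imports):

    from chester.utils_s3 import to_local_command

    params = {'exp_name': 'test-run',
              'algo': {'_name': 'sac', 'lr': 0.0003},  # nested dict: the '_name' key becomes the value of --algo
              'seeds': [0, 1]}                         # lists are shell-quoted and space-joined
    print(to_local_command(dict(params)))              # pass a copy; the function pops keys in place
    # -> python <PROJECT_PATH>/scripts/run_experiment.py --exp_name test-run --algo sac --algo_lr 0.0003 --seeds 0 1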