├── .gitignore ├── README.md ├── __init__.py ├── assets ├── 10_tasks.csv ├── 10_tasks.json ├── all_tasks.json └── taskvar_instructions.jsonl ├── config ├── default.py ├── plain_unet.yaml └── transformer_unet.yaml ├── core ├── actioner.py └── environments.py ├── dataloaders ├── __init__.py ├── keystep_dataset.py └── loader.py ├── eval_models.py ├── job_scripts ├── eval_tst_split.sh └── train_multitask_bc.sh ├── models ├── __init__.py ├── network_utils.py ├── plain_unet.py └── transformer_unet.py ├── optim ├── __init__.py ├── adamw.py ├── lookahead.py ├── misc.py ├── radam.py ├── ralamb.py ├── rangerlars.py └── sched.py ├── preprocess ├── evaluate_dataset_keysteps.py ├── generate_dataset_keysteps.py ├── generate_dataset_microsteps.py └── generate_instructions.py ├── requirements.txt ├── summarize_tst_results.py ├── summarize_val_results.py ├── train_models.py └── utils ├── __init__.py ├── coord_transforms.py ├── distributed.py ├── keystep_detection.py ├── logger.py ├── misc.py ├── ops.py ├── recorder.py ├── save.py ├── utils.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | *.pyc 4 | __pycache__/ 5 | core-python-* 6 | bug-report-* 7 | 8 | data 9 | notebooks 10 | PyRep 11 | RLBench 12 | 13 | *.lock 14 | .~lock.* 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hiveformer: History-aware instruction-conditioned multi-view transformer for robotic manipulation 2 | 3 | This is a PyTorch re-implementation of the Hiveformer paper: 4 | > Instruction-driven history-aware policies for robotic manipulations 5 | > Pierre-Louis Guhur, Shizhe Chen, Ricardo Garcia, Makarand Tapaswi, Ivan Laptev, Cordelia Schmid 6 | > **CoRL 2022 (oral)** 7 | 8 | 9 | ## Prerequisite 10 | 11 | 1. Installation. 12 | 13 | Option 1: Use our pre-build singularity image. 14 | ``` 15 | singularity pull library://rjgpinel/rlbench/vlc_rlbench.sif 16 | ``` 17 | 18 | Option 2: Install everything from scratch. 19 | ```bash 20 | conda create --name hiveformer python=3.9 21 | conda activate hiveformer 22 | ``` 23 | 24 | See instructions in [PyRep](https://github.com/stepjam/PyRep) and [RLBench](https://github.com/stepjam/RLBench) to install RLBench simulator (with VirtualGL in headless machines). Use our modified version of [RLBench](https://github.com/rjgpinel/RLBench) to support additional tasks. 25 | 26 | ```bash 27 | pip install -r requirements.txt 28 | 29 | export PYTHONPATH=$PYTHONPATH:$(pwd) 30 | ``` 31 | 32 | 33 | 2. Dataset generation 34 | 35 | Option 1: Use our [generated datasets](https://drive.google.com/drive/folders/1BCvGrTK_cLkMuF9XR40Xx_QciXM3htwh?usp=drive_link) including the keystep trajectories and instruction embeddings. 36 | 37 | Option 2: generate the dataset on your own. 38 | ```bash 39 | seed=0 40 | task=put_knife_on_chopping_board 41 | variation=0 42 | variation_count=1 43 | 44 | # 1. generate microstep demonstrations 45 | python preprocess/generate_dataset_microsteps.py \ 46 | --save_path data/train_dataset/microsteps/seed{seed} \ 47 | --all_task_file assets/all_tasks.json \ 48 | --image_size 128,128 --renderer opengl \ 49 | --episodes_per_task 100 \ 50 | --tasks ${task} --variations ${variation_count} --offset ${variation} \ 51 | --processes 1 --seed ${seed} 52 | 53 | # 2. 
generate keystep demonstrations 54 | python preprocess/generate_dataset_keysteps.py \ 55 | --microstep_data_dir data/train_dataset/microsteps/seed${seed} \ 56 | --keystep_data_dir data/train_dataset/keysteps/seed${seed} \ 57 | --tasks ${task} 58 | 59 | # 3. (optional) check the correctness of generated keysteps 60 | python preprocess/evaluate_dataset_keysteps.py \ 61 | --microstep_data_dir data/train_dataset/microsteps/seed${seed} \ 62 | --keystep_data_dir data/train_dataset/keysteps/seed${seed} \ 63 | --tasks ${task} 64 | 65 | # 4. generate instructions embeddings for the tasks 66 | python preprocess/generate_instructions.py \ 67 | --encoder clip \ 68 | --output_file data/train_dataset/taskvar_instrs/clip 69 | ``` 70 | 71 | 72 | 73 | ## Train 74 | 75 | Our codes support distributed training with multiple GPUs in SLURM clusters. 76 | 77 | For slurm users, please use the following command to launch the training script. 78 | ```bash 79 | sbatch job_scripts/train_multitask_bc.sh 80 | ``` 81 | 82 | For non-slurm users, please manually set the environment variables as follows. 83 | 84 | ```bash 85 | export WORLD_SIZE=1 86 | export MASTER_ADDR='localhost' 87 | export MASTER_PORT=10000 88 | 89 | export LOCAL_RANK=0 90 | export RANK=0 91 | export CUDA_VISIBLE_DEVICES=0 92 | 93 | python train_models.py --exp-config config/transformer_unet.yaml 94 | ``` 95 | 96 | 97 | 98 | ## Evaluation 99 | 100 | For slurm users, please use the following command to launch the evaluation script. 101 | ```bash 102 | sbatch job_scripts/eval_tst_split.sh 103 | ``` 104 | 105 | For non-slurm users, run the following commands to evaluate the trained model. 106 | 107 | ```bash 108 | # set outdir to the directory of your trained model 109 | export DISPLAY=:0.0 # in headless machines 110 | 111 | # validation: select the best epoch 112 | for step in {5000..300000..5000} 113 | do 114 | python eval_models.py \ 115 | --exp_config ${outdir}/logs/training_config.yaml \ 116 | --seed 100 --num_demos 20 \ 117 | checkpoint ${outdir}/ckpts/model_step_${step}.pt 118 | done 119 | 120 | # run the script to summarize the validation results 121 | python summarize_val_results.py --result_file ${outdir}/preds/seed100/results.jsonl 122 | 123 | # test: use a different seed from validation 124 | step=300000 125 | python eval_models.py \ 126 | --exp_config ${outdir}/logs/training_config.yaml \ 127 | --seed 200 --num_demos 500 \ 128 | checkpoint ${outdir}/ckpts/model_step_${step}.pt 129 | 130 | # run the script to summarize the testing results 131 | python summarize_tst_results.py --result_file ${outdir}/preds/seed200/results.jsonl 132 | ``` 133 | 134 | We provided trained models in [Dropbox](https://www.dropbox.com/s/o4na7namn1ujhng/transformer_unet%2Bgripper_attn_multi32_300k.tar.gz?dl=0) for the multi-task setting (10 tasks). 135 | You could obtain results as follows which are similar to the results in the paper: 136 | 137 | | | pick_ and_lift | pick_up _cup | put_knife_on_ chopping_board | put_money _in_safe | push_ button | reach_ target | slide_block _to_target | stack _wine | take_money _out_safe | take_umbrella_out_ of_umbrella_stand | Avg. 
| 138 | |:------:|:--------------:|:------------:|:----------------------------:|:------------------:|:------------:|:-------------:|:----------------------:|:-----------:|:--------------------:|:------------------------------------:|:-----:| 139 | | seed=0 | 89.00 | 76.80 | 72.80 | 93.00 | 69.60 | 100.00 | 74.20 | 87.20 | 73.20 | 89.80 | 82.56 | 140 | | seed=2 | 91.40 | 75.80 | 76.20 | 81.60 | 86.60 | 100.00 | 85.00 | 89.00 | 72.80 | 79.60 | 83.80 | 141 | | seed=4 | 91.60 | 83.60 | 72.80 | 83.00 | 88.40 | 100.00 | 57.80 | 83.20 | 69.60 | 89.60 | 81.96 | 142 | | Avg. | 90.67 | 78.73 | 73.93 | 85.87 | 81.53 | 100.00 | 72.33 | 86.47 | 71.87 | 86.33 | 82.77 | 143 | 144 | 145 | We also trained the hiveformer model on 74 RLBench tasks. 146 | For the single-task setting, it achieves 66.09% success rate on average. 147 | For the multi-task setting, it achieves 49.22%. The multi-task policy is provided in [Dropbox](https://www.dropbox.com/sh/fwxtojgiusv8v82/AABQkDczpBOZYKp2tp1q1gMja?dl=0). 148 | 149 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlc-robot/hiveformer/0ca80a156acb3985be236fd9ab50e56734f970d6/__init__.py -------------------------------------------------------------------------------- /assets/10_tasks.csv: -------------------------------------------------------------------------------- 1 | pick_and_lift 2 | pick_up_cup 3 | put_knife_on_chopping_board 4 | put_money_in_safe 5 | push_button 6 | reach_target 7 | slide_block_to_target 8 | stack_wine 9 | take_money_out_safe 10 | take_umbrella_out_of_umbrella_stand 11 | -------------------------------------------------------------------------------- /assets/10_tasks.json: -------------------------------------------------------------------------------- 1 | [ 2 | "pick_and_lift", 3 | "pick_up_cup", 4 | "put_knife_on_chopping_board", 5 | "put_money_in_safe", 6 | "push_button", 7 | "reach_target", 8 | "slide_block_to_target", 9 | "stack_wine", 10 | "take_money_out_safe", 11 | "take_umbrella_out_of_umbrella_stand" 12 | ] -------------------------------------------------------------------------------- /assets/all_tasks.json: -------------------------------------------------------------------------------- 1 | [ 2 | "basketball_in_hoop", 3 | "beat_the_buzz", 4 | "block_pyramid", 5 | "change_channel", 6 | "change_clock", 7 | "close_box", 8 | "close_door", 9 | "close_drawer", 10 | "close_fridge", 11 | "close_grill", 12 | "close_jar", 13 | "close_laptop_lid", 14 | "close_microwave", 15 | "empty_container", 16 | "empty_dishwasher", 17 | "get_ice_from_fridge", 18 | "hang_frame_on_hanger", 19 | "hit_ball_with_queue", 20 | "hockey", 21 | "insert_onto_square_peg", 22 | "insert_usb_in_computer", 23 | "lamp_off", 24 | "lamp_on", 25 | "lift_numbered_block", 26 | "light_bulb_in", 27 | "light_bulb_out", 28 | "meat_off_grill", 29 | "meat_on_grill", 30 | "move_hanger", 31 | "open_box", 32 | "open_door", 33 | "open_drawer", 34 | "open_fridge", 35 | "open_grill", 36 | "open_jar", 37 | "open_microwave", 38 | "open_oven", 39 | "open_washing_machine", 40 | "open_window", 41 | "open_wine_bottle", 42 | "phone_on_base", 43 | "pick_and_lift", 44 | "pick_and_lift_small", 45 | "pick_up_cup", 46 | "place_cups", 47 | "place_hanger_on_rack", 48 | "place_shape_in_shape_sorter", 49 | "play_jenga", 50 | "plug_charger_in_power_supply", 51 | "pour_from_cup_to_cup", 52 | "press_switch", 53 | "push_button", 54 | 
"push_buttons", 55 | "put_all_groceries_in_cupboard", 56 | "put_books_on_bookshelf", 57 | "put_bottle_in_fridge", 58 | "put_groceries_in_cupboard", 59 | "put_item_in_drawer", 60 | "put_knife_in_knife_block", 61 | "put_knife_on_chopping_board", 62 | "put_money_in_safe", 63 | "put_plate_in_colored_dish_rack", 64 | "put_rubbish_in_bin", 65 | "put_shoes_in_box", 66 | "put_toilet_roll_on_stand", 67 | "put_tray_in_oven", 68 | "put_umbrella_in_umbrella_stand", 69 | "reach_and_drag", 70 | "reach_target", 71 | "remove_cups", 72 | "scoop_with_spatula", 73 | "screw_nail", 74 | "set_the_table", 75 | "setup_checkers", 76 | "setup_chess", 77 | "slide_block_to_target", 78 | "slide_cabinet_open", 79 | "slide_cabinet_open_and_place_cups", 80 | "solve_puzzle", 81 | "stack_blocks", 82 | "stack_chairs", 83 | "stack_cups", 84 | "stack_wine", 85 | "straighten_rope", 86 | "sweep_to_dustpan", 87 | "take_cup_out_from_cabinet", 88 | "take_frame_off_hanger", 89 | "take_item_out_of_drawer", 90 | "take_lid_off_saucepan", 91 | "take_money_out_safe", 92 | "take_off_weighing_scales", 93 | "take_plate_off_colored_dish_rack", 94 | "take_shoes_out_of_box", 95 | "take_toilet_roll_off_stand", 96 | "take_tray_out_of_oven", 97 | "take_umbrella_out_of_umbrella_stand", 98 | "take_usb_out_of_computer", 99 | "toilet_seat_down", 100 | "toilet_seat_up", 101 | "turn_oven_on", 102 | "turn_tap", 103 | "tv_on", 104 | "unplug_charger", 105 | "water_plants", 106 | "weighing_scales", 107 | "wipe_desk" 108 | ] -------------------------------------------------------------------------------- /config/default.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import numpy as np 4 | 5 | import yacs.config 6 | 7 | # Default config node 8 | class Config(yacs.config.CfgNode): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs, new_allowed=True) 11 | 12 | CN = Config 13 | 14 | 15 | CONFIG_FILE_SEPARATOR = ',' 16 | 17 | # ----------------------------------------------------------------------------- 18 | # EXPERIMENT CONFIG 19 | # ----------------------------------------------------------------------------- 20 | _C = CN() 21 | _C.SEED = 2023 22 | _C.CMD_TRAILING_OPTS = [] # store command line options as list of strings 23 | 24 | # ----------------------------------------------------------------------------- 25 | # MODEL 26 | # ----------------------------------------------------------------------------- 27 | _C.MODEL = CN() 28 | 29 | # ----------------------------------------------------------------------------- 30 | # DATASET 31 | # ----------------------------------------------------------------------------- 32 | _C.DATASET = CN() 33 | 34 | 35 | def get_config( 36 | config_paths: Optional[Union[List[str], str]] = None, 37 | opts: Optional[list] = None, 38 | ) -> CN: 39 | r"""Create a unified config with default values overwritten by values from 40 | :ref:`config_paths` and overwritten by options from :ref:`opts`. 41 | 42 | Args: 43 | config_paths: List of config paths or string that contains comma 44 | separated list of config paths. 45 | opts: Config options (keys, values) in a list (e.g., passed from 46 | command line into the config. For example, ``opts = ['FOO.BAR', 47 | 0.5]``. Argument can be used for parameter sweeping or quick tests. 
48 | """ 49 | config = _C.clone() 50 | if config_paths: 51 | if isinstance(config_paths, str): 52 | if CONFIG_FILE_SEPARATOR in config_paths: 53 | config_paths = config_paths.split(CONFIG_FILE_SEPARATOR) 54 | else: 55 | config_paths = [config_paths] 56 | 57 | for config_path in config_paths: 58 | config.merge_from_file(config_path) 59 | 60 | if opts: 61 | config.CMD_TRAILING_OPTS = config.CMD_TRAILING_OPTS + opts 62 | # FIXME: remove later 63 | for i in range(len(config.CMD_TRAILING_OPTS)): 64 | if config.CMD_TRAILING_OPTS[i] == "DATASET.taskvars": 65 | if type(config.CMD_TRAILING_OPTS[i + 1]) is str: 66 | config.CMD_TRAILING_OPTS[i + 1] = [config.CMD_TRAILING_OPTS[i + 1]] 67 | 68 | config.merge_from_list(config.CMD_TRAILING_OPTS) 69 | 70 | config.freeze() 71 | return config 72 | -------------------------------------------------------------------------------- /config/plain_unet.yaml: -------------------------------------------------------------------------------- 1 | SEED: 2023 2 | output_dir: 'data/exprs/plain_unet/pick_up_cup+0/seed0' 3 | checkpoint: null 4 | 5 | train_batch_size: 16 6 | gradient_accumulation_steps: 1 7 | num_epochs: null 8 | num_train_steps: 100000 9 | warmup_steps: 2000 10 | log_steps: 1000 11 | save_steps: 5000 12 | 13 | optim: 'adamw' 14 | learning_rate: 5e-4 15 | lr_sched: 'linear' # inverse_sqrt, linear 16 | betas: [0.9, 0.98] 17 | weight_decay: 0.001 18 | grad_norm: 5 19 | dropout: 0.1 20 | n_workers: 0 21 | pin_mem: True 22 | 23 | DATASET: 24 | dataset_class: 'keystep_stepwise' 25 | 26 | data_dir: 'data/train_dataset/keysteps/seed0' 27 | taskvars: ('pick_up_cup+0', ) 28 | # taskvars: ['pick_and_lift+0', 29 | # 'pick_up_cup+0', 30 | # 'put_knife_on_chopping_board+0', 31 | # 'put_money_in_safe+0', 32 | # 'push_button+0', 33 | # 'reach_target+0', 34 | # 'slide_block_to_target+0', 35 | # 'stack_wine+0', 36 | # 'take_money_out_safe+0', 37 | # 'take_umbrella_out_of_umbrella_stand+0'] 38 | instr_embed_file: null 39 | # instr_embed_file: 'data/train_dataset/taskvar_instrs/clip' 40 | use_instr_embed: 'none' # none, avg, last, all 41 | gripper_channel: False 42 | cameras: ('left_shoulder', 'right_shoulder', 'wrist') 43 | is_training: True 44 | in_memory: True 45 | num_workers: 0 46 | 47 | MODEL: 48 | model_class: 'PlainUNet' 49 | 50 | unet: True 51 | num_tasks: 1 52 | use_instr_embed: 'none' # none, avg, last, all 53 | instr_embed_size: 512 54 | max_steps: 20 55 | 56 | num_layers: 4 57 | hidden_size: 16 58 | gripper_channel: False 59 | 60 | -------------------------------------------------------------------------------- /config/transformer_unet.yaml: -------------------------------------------------------------------------------- 1 | SEED: 2023 2 | output_dir: 'data/exprs/transformer_unet/put_knife_on_chopping_board+0/seed0' 3 | checkpoint: null 4 | 5 | train_batch_size: 16 6 | gradient_accumulation_steps: 1 7 | num_epochs: null 8 | num_train_steps: 300000 9 | warmup_steps: 2000 10 | log_steps: 1000 11 | save_steps: 5000 12 | 13 | optim: 'adamw' 14 | learning_rate: 5e-4 15 | lr_sched: 'linear' # inverse_sqrt, linear 16 | betas: [0.9, 0.98] 17 | weight_decay: 0.001 18 | grad_norm: 5 19 | dropout: 0.1 20 | n_workers: 0 21 | pin_mem: True 22 | 23 | DATASET: 24 | dataset_class: 'keystep_episode' 25 | 26 | data_dir: 'data/train_dataset/keysteps/seed0' 27 | taskvars: ('put_knife_on_chopping_board+0', ) 28 | # taskvars: ['pick_and_lift+0', 29 | # 'pick_up_cup+0', 30 | # 'put_knife_on_chopping_board+0', 31 | # 'put_money_in_safe+0', 32 | # 'push_button+0', 33 | # 'reach_target+0', 
34 | # 'slide_block_to_target+0', 35 | # 'stack_wine+0', 36 | # 'take_money_out_safe+0', 37 | # 'take_umbrella_out_of_umbrella_stand+0'] 38 | # instr_embed_file: null 39 | instr_embed_file: 'data/train_dataset/taskvar_instrs/clip' 40 | use_instr_embed: 'all' # none, avg, last, all 41 | gripper_channel: True 42 | cameras: ('left_shoulder', 'right_shoulder', 'wrist') 43 | is_training: True 44 | in_memory: True 45 | num_workers: 0 46 | 47 | MODEL: 48 | model_class: 'TransformerUNet' 49 | 50 | unet: True 51 | num_tasks: null 52 | use_instr_embed: 'all' # none, avg, last, all 53 | instr_embed_size: 512 54 | max_steps: 20 55 | 56 | num_layers: 4 57 | hidden_size: 32 58 | gripper_channel: 'attn' 59 | 60 | num_trans_layers: 1 61 | nhead: 8 62 | txt_attn_type: 'cross' # self, cross 63 | num_cams: 3 64 | latent_im_size: (8, 8) 65 | quat_input: 'concat' # add, concat 66 | 67 | -------------------------------------------------------------------------------- /core/actioner.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional, Sequence, Tuple, TypedDict, Union, Any 2 | 3 | class BaseActioner: 4 | 5 | def reset(self, task_str, variation, instructions, demo_id): 6 | self.task_str = task_str 7 | self.variation = variation 8 | self.instructions = instructions 9 | self.demo_id = demo_id 10 | 11 | self.step_id = 0 12 | self.state_dict = {} 13 | self.history_obs = {} 14 | 15 | def predict(self, *args, **kwargs): 16 | raise NotImplementedError('implete predict function') 17 | -------------------------------------------------------------------------------- /core/environments.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional, Sequence, Tuple, TypedDict, Union, Any 2 | from pathlib import Path 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from rlbench.observation_config import ObservationConfig, CameraConfig 8 | from rlbench.environment import Environment 9 | from rlbench.task_environment import TaskEnvironment 10 | from rlbench.action_modes.action_mode import MoveArmThenGripper 11 | from rlbench.action_modes.gripper_action_modes import Discrete 12 | from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaPlanning 13 | from rlbench.backend.exceptions import InvalidActionError 14 | from rlbench.backend.observation import Observation 15 | from rlbench.demo import Demo 16 | from rlbench.backend.utils import task_file_to_task_class 17 | from pyrep.errors import IKError, ConfigurationPathError 18 | from pyrep.const import RenderMode 19 | from pyrep.objects.dummy import Dummy 20 | from pyrep.objects.vision_sensor import VisionSensor 21 | 22 | from core.actioner import BaseActioner 23 | from utils.coord_transforms import convert_gripper_pose_world_to_image 24 | from utils.visualize import plot_attention 25 | from utils.recorder import TaskRecorder, StaticCameraMotion, CircleCameraMotion, AttachedCameraMotion 26 | 27 | 28 | class Mover: 29 | def __init__(self, task: TaskEnvironment, disabled: bool = False, max_tries: int = 1): 30 | self._task = task 31 | self._last_action: Optional[np.ndarray] = None 32 | self._step_id = 0 33 | self._max_tries = max_tries 34 | self._disabled = disabled 35 | 36 | def __call__(self, action: np.ndarray): 37 | action = action.copy() 38 | 39 | if self._disabled: 40 | return self._task.step(action) 41 | 42 | target = action.copy() 43 | if self._last_action is not None: 44 | 
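            # action is [position(3), quaternion(4), gripper_open(1)]; while retrying the
            # arm motion we keep the previous gripper state and only execute the new
            # gripper command after the retries below.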
action[7] = self._last_action[7].copy() 45 | 46 | try_id = 0 47 | obs = None 48 | terminate = None 49 | reward = 0 50 | 51 | for try_id in range(self._max_tries): 52 | obs, reward, terminate = self._task.step(action) 53 | 54 | pos = obs.gripper_pose[:3] 55 | rot = obs.gripper_pose[3:7] 56 | dist_pos = np.sqrt(np.square(target[:3] - pos).sum()) # type: ignore 57 | dist_rot = np.sqrt(np.square(target[3:7] - rot).sum()) # type: ignore 58 | # criteria = (dist_pos < 5e-2, dist_rot < 1e-1, (gripper > 0.5) == (target_gripper > 0.5)) 59 | criteria = (dist_pos < 5e-2,) 60 | 61 | if all(criteria) or reward == 1: 62 | break 63 | 64 | print( 65 | f"Too far away (pos: {dist_pos:.3f}, rot: {dist_rot:.3f}, step: {self._step_id})... Retrying..." 66 | ) 67 | 68 | # we execute the gripper action after re-tries 69 | action = target 70 | if ( 71 | not reward 72 | and self._last_action is not None 73 | and action[7] != self._last_action[7] 74 | ): 75 | obs, reward, terminate = self._task.step(action) 76 | 77 | if try_id == self._max_tries: 78 | print(f"Failure after {self._max_tries} tries") 79 | 80 | self._step_id += 1 81 | self._last_action = action.copy() 82 | 83 | other_obs = [] 84 | 85 | return obs, reward, terminate, other_obs 86 | 87 | 88 | class RLBenchEnv(object): 89 | def __init__( 90 | self, 91 | data_path='', 92 | apply_rgb=False, 93 | apply_depth=False, 94 | apply_pc=False, 95 | headless=False, 96 | apply_cameras=("left_shoulder", "right_shoulder", "wrist", "front"), 97 | gripper_pose=None, 98 | ): 99 | 100 | # setup required inputs 101 | self.data_path = data_path 102 | self.apply_rgb = apply_rgb 103 | self.apply_depth = apply_depth 104 | self.apply_pc = apply_pc 105 | self.apply_cameras = apply_cameras 106 | self.gripper_pose = gripper_pose 107 | 108 | # setup RLBench environments 109 | self.obs_config = self.create_obs_config( 110 | apply_rgb, apply_depth, apply_pc, apply_cameras 111 | ) 112 | self.action_mode = MoveArmThenGripper( 113 | arm_action_mode=EndEffectorPoseViaPlanning(), 114 | gripper_action_mode=Discrete(), 115 | ) 116 | self.env = Environment( 117 | self.action_mode, str(data_path), self.obs_config, headless=headless 118 | ) 119 | 120 | def get_observation(self, obs: Observation): 121 | """Fetch the desired state based on the provided demo. 
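        Images are returned per camera as (H, W, C) arrays; the gripper state is an
        8-dim vector (pose 7 + open 1), and if ``gripper_pose`` is set, an additional
        (N, 1, 128, 128) heatmap marking the projected gripper location is included.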
122 | :param obs: incoming obs 123 | :return: required observation (rgb, pc, gripper state) 124 | """ 125 | 126 | # fetch state: (#cameras, H, W, C) 127 | state_dict = {"rgb": [], "depth": [], "pc": []} 128 | for cam in self.apply_cameras: 129 | if self.apply_rgb: 130 | rgb = getattr(obs, "{}_rgb".format(cam)) 131 | state_dict["rgb"] += [rgb] 132 | 133 | if self.apply_depth: 134 | depth = getattr(obs, "{}_depth".format(cam)) 135 | state_dict["depth"] += [depth] 136 | 137 | if self.apply_pc: 138 | pc = getattr(obs, "{}_point_cloud".format(cam)) 139 | state_dict["pc"] += [pc] 140 | 141 | # fetch gripper state (3+4+1, ) 142 | gripper = np.concatenate([obs.gripper_pose, [obs.gripper_open]]).astype(np.float32) 143 | state_dict['gripper'] = gripper 144 | 145 | if self.gripper_pose: 146 | gripper_imgs = np.zeros( 147 | (len(self.apply_cameras), 1, 128, 128), dtype=np.float32 148 | ) 149 | for i, cam in enumerate(self.apply_cameras): 150 | u, v = convert_gripper_pose_world_to_image(obs, cam) 151 | if u > 0 and u < 128 and v > 0 and v < 128: 152 | gripper_imgs[i, 0, v, u] = 1 153 | state_dict["gripper_imgs"] = gripper_imgs 154 | 155 | return state_dict 156 | 157 | def get_demo(self, task_name, variation, episode_index): 158 | """ 159 | Fetch a demo from the saved environment. 160 | :param task_name: fetch task name 161 | :param variation: fetch variation id 162 | :param episode_index: fetch episode index: 0 ~ 99 163 | :return: desired demo 164 | """ 165 | demos = self.env.get_demos( 166 | task_name=task_name, 167 | variation_number=variation, 168 | amount=1, 169 | from_episode_number=episode_index, 170 | random_selection=False, 171 | ) 172 | return demos[0] 173 | 174 | def evaluate( 175 | self, taskvar_id, task_str, max_episodes, variation, num_demos, log_dir, 176 | actioner: BaseActioner, max_tries: int = 1, 177 | demos: Optional[List[Demo]] = None, demo_keys: List = None, 178 | save_attn: bool = False, save_image: bool = False, record_video: bool = False, 179 | ): 180 | """ 181 | Evaluate the policy network on the desired demo or test environments 182 | :param task_type: type of task to evaluate 183 | :param max_episodes: maximum episodes to finish a task 184 | :param num_demos: number of test demos for evaluation 185 | :param model: the policy network 186 | :param demos: whether to use the saved demos 187 | :return: success rate 188 | """ 189 | 190 | self.env.launch() 191 | task_type = task_file_to_task_class(task_str) 192 | task = self.env.get_task(task_type) 193 | task.set_variation(variation) # type: ignore 194 | 195 | if record_video: 196 | # Add a global camera to the scene 197 | cam_placeholder = Dummy('cam_cinematic_placeholder') 198 | cam_resolution = [480, 480] 199 | cam = VisionSensor.create(cam_resolution) 200 | cam.set_pose(cam_placeholder.get_pose()) 201 | cam.set_parent(cam_placeholder) 202 | 203 | # cam_motion = CircleCameraMotion(cam, Dummy('cam_cinematic_base'), 0.005) 204 | global_cam_motion = StaticCameraMotion(cam) 205 | 206 | include_robot_cameras = True 207 | 208 | cams_motion = {"global": global_cam_motion} 209 | 210 | if include_robot_cameras: 211 | # Env cameras 212 | cam_left = VisionSensor.create(cam_resolution) 213 | cam_right = VisionSensor.create(cam_resolution) 214 | cam_wrist = VisionSensor.create(cam_resolution) 215 | 216 | left_cam_motion = AttachedCameraMotion(cam_left, task._scene._cam_over_shoulder_left) 217 | right_cam_motion = AttachedCameraMotion(cam_right, task._scene._cam_over_shoulder_right) 218 | wrist_cam_motion = AttachedCameraMotion(cam_wrist, 
task._scene._cam_wrist) 219 | 220 | cams_motion["left"] = left_cam_motion 221 | cams_motion["right"] = right_cam_motion 222 | cams_motion["wrist"] = wrist_cam_motion 223 | tr = TaskRecorder(cams_motion, fps=30) 224 | task._scene.register_step_callback(tr.take_snap) 225 | 226 | success_rate = 0.0 227 | 228 | if demos is None: 229 | fetch_list = [i for i in range(num_demos)] 230 | else: 231 | fetch_list = demos 232 | 233 | if demo_keys is None: 234 | demo_keys = [f'episode{i}' for i in range(num_demos)] 235 | 236 | with torch.no_grad(): 237 | for demo_id, demo in tqdm(zip(demo_keys, fetch_list)): 238 | # reset a new demo or a defined demo in the demo list 239 | if isinstance(demo, int): 240 | instructions, obs = task.reset() 241 | else: 242 | print("Resetting to demo", demo_id) 243 | instructions, obs = task.reset_to_demo(demo) # type: ignore 244 | 245 | actioner.reset(task_str, variation, instructions, demo_id) 246 | 247 | move = Mover(task, max_tries=max_tries) 248 | reward = None 249 | 250 | if log_dir is not None and (save_attn or save_image): 251 | ep_dir = log_dir / task_str / demo_id 252 | ep_dir.mkdir(exist_ok=True, parents=True) 253 | 254 | for step_id in range(max_episodes): 255 | # fetch the current observation, and predict one action 256 | obs_state_dict = self.get_observation(obs) # type: ignore 257 | 258 | if log_dir is not None and save_image: 259 | for cam_id, img_by_cam in enumerate(obs_state_dict['rgb']): 260 | cam_dir = ep_dir / f'camera_{cam_id}' 261 | cam_dir.mkdir(exist_ok=True, parents=True) 262 | Image.fromarray(img_by_cam).save(cam_dir / f"{step_id}.png") 263 | 264 | output = actioner.predict(taskvar_id, step_id, obs_state_dict) 265 | action = output["action"] 266 | 267 | if action is None: 268 | break 269 | 270 | # TODO 271 | if log_dir is not None and save_attn and output["action"] is not None: 272 | ep_dir = log_dir / f"episode{demo_id}" 273 | fig = plot_attention( 274 | output["attention"], 275 | obs_state_dict['rgb'], 276 | obs_state_dict['pc'], 277 | ep_dir / f"attn_{step_id}.png", 278 | ) 279 | 280 | # update the observation based on the predicted action 281 | try: 282 | obs, reward, terminate, _ = move(action) 283 | 284 | if reward == 1: 285 | success_rate += 1 / num_demos 286 | break 287 | if terminate: 288 | print("The episode has terminated!") 289 | except (IKError, ConfigurationPathError, InvalidActionError) as e: 290 | print(task_str, demo_id, step_id, e) 291 | reward = 0 292 | break 293 | 294 | print( 295 | task_str, "Variation", variation, "Demo", demo_id, 296 | "Reward", reward, "SR: %.2f" % (success_rate * 100), 297 | ) 298 | 299 | if record_video: 300 | tr.save(str(log_dir/ f"{demo_id}")) 301 | 302 | self.env.shutdown() 303 | return success_rate 304 | 305 | def create_obs_config( 306 | self, apply_rgb, apply_depth, apply_pc, apply_cameras, **kwargs 307 | ): 308 | """ 309 | Set up observation config for RLBench environment. 310 | :param apply_rgb: Applying RGB as inputs. 311 | :param apply_depth: Applying Depth as inputs. 312 | :param apply_pc: Applying Point Cloud as inputs. 313 | :param apply_cameras: Desired cameras. 
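        Cameras not listed in ``apply_cameras`` are disabled through an all-off
        CameraConfig.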
314 | :return: observation config 315 | """ 316 | unused_cams = CameraConfig() 317 | unused_cams.set_all(False) 318 | used_cams = CameraConfig( 319 | rgb=apply_rgb, 320 | point_cloud=apply_pc, 321 | depth=apply_depth, 322 | mask=False, 323 | render_mode=RenderMode.OPENGL, 324 | **kwargs, 325 | ) 326 | 327 | camera_names = apply_cameras 328 | kwargs = {} 329 | for n in camera_names: 330 | kwargs[n] = used_cams 331 | 332 | obs_config = ObservationConfig( 333 | front_camera=kwargs.get("front", unused_cams), 334 | left_shoulder_camera=kwargs.get("left_shoulder", unused_cams), 335 | right_shoulder_camera=kwargs.get("right_shoulder", unused_cams), 336 | wrist_camera=kwargs.get("wrist", unused_cams), 337 | overhead_camera=kwargs.get("overhead", unused_cams), 338 | joint_forces=False, 339 | joint_positions=False, 340 | joint_velocities=True, 341 | task_low_dim_state=False, 342 | gripper_touch_forces=False, 343 | gripper_pose=True, 344 | gripper_open=True, 345 | gripper_matrix=True, 346 | gripper_joint_positions=True, 347 | ) 348 | 349 | return obs_config 350 | 351 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlc-robot/hiveformer/0ca80a156acb3985be236fd9ab50e56734f970d6/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/keystep_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | 3 | import os 4 | import numpy as np 5 | import einops 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision.transforms as transforms 10 | import torchvision.transforms.functional as transforms_f 11 | 12 | import lmdb 13 | import msgpack 14 | import msgpack_numpy 15 | msgpack_numpy.patch() 16 | 17 | from utils.ops import pad_tensors, gen_seq_masks 18 | 19 | 20 | class DataTransform(object): 21 | def __init__(self, scales): 22 | self.scales = scales 23 | 24 | def __call__(self, data) -> Dict[str, torch.Tensor]: 25 | """ 26 | Inputs: 27 | data: dict 28 | - rgb: (T, N, C, H, W), N: num of cameras 29 | - pc: (T, N, C, H, W) 30 | """ 31 | keys = list(data.keys()) 32 | 33 | # Continuous range of scales 34 | sc = np.random.uniform(*self.scales) 35 | 36 | t, n, c, raw_h, raw_w = data[keys[0]].shape 37 | data = {k: v.flatten(0, 1) for k, v in data.items()} # (t*n, h, w, c) 38 | resized_size = [int(raw_h * sc), int(raw_w * sc)] 39 | 40 | # Resize based on randomly sampled scale 41 | data = { 42 | k: transforms_f.resize( 43 | v, 44 | resized_size, 45 | transforms.InterpolationMode.BILINEAR 46 | ) 47 | for k, v in data.items() 48 | } 49 | 50 | # Adding padding if crop size is smaller than the resized size 51 | if raw_h > resized_size[0] or raw_w > resized_size[1]: 52 | right_pad = max(raw_w - resized_size[1], 0) 53 | bottom_pad = max(raw_h - resized_size[0], 0) 54 | data = { 55 | k: transforms_f.pad( 56 | v, 57 | padding=[0, 0, right_pad, bottom_pad], 58 | padding_mode="edge", 59 | ) 60 | for k, v in data.items() 61 | } 62 | 63 | # Random Cropping 64 | i, j, h, w = transforms.RandomCrop.get_params( 65 | data[keys[0]], output_size=(raw_h, raw_w) 66 | ) 67 | 68 | data = {k: transforms_f.crop(v, i, j, h, w) for k, v in data.items()} 69 | 70 | data = { 71 | k: einops.rearrange(v, "(t n) c h w -> t n c h w", t=t) 72 | for k, v in data.items() 73 | } 74 | 75 | return data 76 | 77 | 
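# Illustrative usage of DataTransform (a sketch; `rgbs` and `pcds` stand for tensors of
# shape (T, N, C, H, W) as in the docstring above). The same randomly sampled scale and
# crop are applied to every entry of the dict, so RGB images and point clouds stay
# pixel-aligned:
#
#     transform = DataTransform(scales=(0.75, 1.25))   # range used by KeystepDataset below
#     augmented = transform({'rgbs': rgbs, 'pcds': pcds})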
78 | class KeystepDataset(Dataset): 79 | def __init__( 80 | self, data_dir, taskvars, instr_embed_file=None, 81 | gripper_channel=False, 82 | cameras=('left_shoulder', 'right_shoulder', 'wrist'), 83 | use_instr_embed='none', is_training=False, 84 | in_memory=False, **kwargs 85 | ): 86 | ''' 87 | - use_instr_embed: 88 | 'none': use task_id; 89 | 'avg': use the average instruction embedding; 90 | 'last': use the last instruction embedding; 91 | 'all': use the embedding of all instruction tokens. 92 | ''' 93 | self.data_dir = data_dir 94 | self.taskvars = taskvars 95 | self.instr_embed_file = instr_embed_file 96 | self.taskvar_to_id = { 97 | x: i for i, x in enumerate(self.taskvars) 98 | } 99 | self.use_instr_embed = use_instr_embed 100 | self.gripper_channel = gripper_channel 101 | self.cameras = cameras 102 | self.in_memory = in_memory 103 | self.is_training = is_training 104 | 105 | if self.in_memory: 106 | self.memory = {} 107 | 108 | self._transform = DataTransform((0.75, 1.25)) 109 | 110 | self.lmdb_envs, self.lmdb_txns = [], [] 111 | self.episode_ids = [] 112 | for i, taskvar in enumerate(self.taskvars): 113 | lmdb_env = lmdb.open(os.path.join(data_dir, taskvar), readonly=True) 114 | self.lmdb_envs.append(lmdb_env) 115 | lmdb_txn = lmdb_env.begin() 116 | self.lmdb_txns.append(lmdb_txn) 117 | keys = list(lmdb_txn.cursor().iternext(values=False)) 118 | self.episode_ids.extend([(i, key) for key in keys]) 119 | if self.in_memory: 120 | self.memory[f'taskvar{i}'] = {} 121 | 122 | if self.use_instr_embed != 'none': 123 | assert self.instr_embed_file is not None 124 | self.lmdb_instr_env = lmdb.open(self.instr_embed_file, readonly=True) 125 | self.lmdb_instr_txn = self.lmdb_instr_env.begin() 126 | if self.in_memory: 127 | self.memory['instr_embeds'] = {} 128 | else: 129 | self.lmdb_instr_env = None 130 | 131 | def __exit__(self): 132 | for lmdb_env in self.lmdb_envs: 133 | lmdb_env.close() 134 | if self.lmdb_instr_env is not None: 135 | self.lmdb_instr_env.close() 136 | 137 | def __len__(self): 138 | return len(self.episode_ids) 139 | 140 | def get_taskvar_episode(self, taskvar_idx, episode_key): 141 | if self.in_memory: 142 | mem_key = f'taskvar{taskvar_idx}' 143 | if episode_key in self.memory[mem_key]: 144 | return self.memory[mem_key][episode_key] 145 | 146 | value = self.lmdb_txns[taskvar_idx].get(episode_key) 147 | value = msgpack.unpackb(value) 148 | if self.in_memory: 149 | self.memory[mem_key][episode_key] = value 150 | return value 151 | 152 | def get_taskvar_instr_embeds(self, taskvar): 153 | instr_embeds = None 154 | if self.in_memory: 155 | if taskvar in self.memory['instr_embeds']: 156 | instr_embeds = self.memory['instr_embeds'][taskvar] 157 | 158 | if instr_embeds is None: 159 | instr_embeds = self.lmdb_instr_txn.get(taskvar.encode('ascii')) 160 | instr_embeds = msgpack.unpackb(instr_embeds) 161 | instr_embeds = [torch.from_numpy(x).float() for x in instr_embeds] 162 | if self.in_memory: 163 | self.memory['instr_embeds'][taskvar] = instr_embeds 164 | 165 | # randomly select one instruction for the taskvar 166 | ridx = np.random.randint(len(instr_embeds)) 167 | instr_embeds = instr_embeds[ridx] 168 | 169 | if self.use_instr_embed == 'avg': 170 | instr_embeds = torch.mean(instr_embeds, 0, keepdim=True) 171 | elif self.use_instr_embed == 'last': 172 | instr_embeds = instr_embeds[-1:] 173 | 174 | return instr_embeds # (num_ttokens, dim) 175 | 176 | def __getitem__(self, idx): 177 | taskvar_idx, episode_key = self.episode_ids[idx] 178 | 179 | value = 
self.get_taskvar_episode(taskvar_idx, episode_key) 180 | 181 | # The last one is the stop observation 182 | rgbs = torch.from_numpy(value['rgb'][:-1]).float().permute(0, 1, 4, 2, 3) # (T, N, C, H, W) 183 | pcs = torch.from_numpy(value['pc'][:-1]).float().permute(0, 1, 4, 2, 3) 184 | # # normalise to [-1, 1] 185 | # rgbs = 2 * (rgbs / 255.0 - 0.5) 186 | rgbs = transforms_f.normalize( 187 | rgbs.float(), 188 | [0.485, 0.456, 0.406], 189 | [0.229, 0.224, 0.225] 190 | ) 191 | 192 | num_steps, num_cameras, _, im_height, im_width = rgbs.size() 193 | 194 | if self.gripper_channel: 195 | gripper_imgs = torch.zeros( 196 | num_steps, num_cameras, 1, im_height, im_width, 197 | dtype=torch.float32 198 | ) 199 | for t in range(num_steps): 200 | for c, cam in enumerate(self.cameras): 201 | u, v = value['gripper_pose'][t][cam] 202 | if u >= 0 and u < 128 and v >=0 and v < 128: 203 | gripper_imgs[t, c, 0, v, u] = 1 204 | rgbs = torch.cat([rgbs, gripper_imgs], dim=2) 205 | 206 | # rgb, pcd: (T, N, C, H, W) 207 | outs = {'rgbs': rgbs, 'pcds': pcs} 208 | if self.is_training: 209 | outs = self._transform(outs) 210 | 211 | outs['step_ids'] = torch.arange(0, num_steps).long() 212 | outs['actions'] = torch.from_numpy(value['action'][1:]) 213 | outs['episode_ids'] = episode_key.decode('ascii') 214 | outs['taskvars'] = self.taskvars[taskvar_idx] 215 | outs['taskvar_ids'] = taskvar_idx 216 | 217 | if self.use_instr_embed != 'none': 218 | outs['instr_embeds'] = self.get_taskvar_instr_embeds(outs['taskvars']) 219 | 220 | return outs 221 | 222 | 223 | def stepwise_collate_fn(data: List[Dict]): 224 | batch = {} 225 | 226 | for key in data[0].keys(): 227 | if key == 'taskvar_ids': 228 | batch[key] = [ 229 | torch.LongTensor([v['taskvar_ids']] * len(v['step_ids'])) for v in data 230 | ] 231 | elif key == 'instr_embeds': 232 | batch[key] = sum([ 233 | [v['instr_embeds']] * len(v['step_ids']) for v in data 234 | ], []) 235 | else: 236 | batch[key] = [v[key] for v in data] 237 | 238 | for key in ['rgbs', 'pcds', 'taskvar_ids', 'step_ids', 'actions']: 239 | # e.g. rgbs: (B*T, N, C, H, W) 240 | batch[key] = torch.cat(batch[key], dim=0) 241 | 242 | if 'instr_embeds' in batch: 243 | batch['instr_embeds'] = pad_tensors(batch['instr_embeds']) 244 | 245 | return batch 246 | 247 | 248 | def episode_collate_fn(data: List[Dict]): 249 | batch = {} 250 | 251 | for key in data[0].keys(): 252 | batch[key] = [v[key] for v in data] 253 | 254 | batch['taskvar_ids'] = torch.LongTensor(batch['taskvar_ids']) 255 | num_steps = [len(x['rgbs']) for x in data] 256 | if 'instr_embeds' in batch: 257 | num_ttokens = [len(x['instr_embeds']) for x in data] 258 | 259 | for key in ['rgbs', 'pcds', 'step_ids', 'actions']: 260 | # e.g. 
rgbs: (B, T, N, C, H, W) 261 | batch[key] = pad_tensors(batch[key], lens=num_steps) 262 | 263 | if 'instr_embeds' in batch: 264 | batch['instr_embeds'] = pad_tensors(batch['instr_embeds'], lens=num_ttokens) 265 | batch['txt_masks'] = torch.from_numpy(gen_seq_masks(num_ttokens)) 266 | else: 267 | batch['txt_masks'] = torch.ones(len(num_steps), 1).bool() 268 | 269 | batch['step_masks'] = torch.from_numpy(gen_seq_masks(num_steps)) 270 | 271 | return batch 272 | 273 | 274 | 275 | if __name__ == '__main__': 276 | import time 277 | from torch.utils.data import DataLoader 278 | 279 | data_dir = 'data/train_dataset/keysteps/seed0' 280 | taskvars = ['pick_up_cup+0'] 281 | cameras = ['left_shoulder', 'right_shoulder', 'wrist'] 282 | instr_embed_file = None 283 | instr_embed_file = 'data/train_dataset/taskvar_instrs/clip' 284 | 285 | dataset = KeystepDataset(data_dir, taskvars, 286 | instr_embed_file=instr_embed_file, 287 | use_instr_embed='all', 288 | gripper_channel='attn', cameras=cameras, 289 | is_training=True 290 | ) 291 | 292 | data_loader = DataLoader( 293 | dataset, batch_size=16, 294 | # collate_fn=stepwise_collate_fn 295 | collate_fn=episode_collate_fn, 296 | ) 297 | 298 | print(len(dataset), len(data_loader)) 299 | 300 | st = time.time() 301 | for batch in data_loader: 302 | for k, v in batch.items(): 303 | if isinstance(v, torch.Tensor): 304 | print(k, v.size()) 305 | break 306 | et = time.time() 307 | print('cost time: %.2fs' % (et - st)) -------------------------------------------------------------------------------- /dataloaders/loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | A prefetch loader to speedup data loading 6 | Modified from Nvidia Deep Learning Examples 7 | (https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch). 
8 | """ 9 | import random 10 | from typing import List, Dict, Tuple, Union, Iterator 11 | 12 | import torch 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 14 | from torch.utils.data.distributed import DistributedSampler 15 | import torch.distributed as dist 16 | 17 | 18 | class MetaLoader: 19 | """wraps multiple data loaders""" 20 | 21 | def __init__( 22 | self, loaders, accum_steps: int = 1, distributed: bool = False, device=None 23 | ): 24 | assert isinstance(loaders, dict) 25 | self.name2loader = {} 26 | self.name2iter = {} 27 | self.name2pre_epoch = {} 28 | self.names: List[str] = [] 29 | ratios: List[int] = [] 30 | for n, l in loaders.items(): 31 | if isinstance(l, tuple): 32 | l, r, p = l 33 | elif isinstance(l, DataLoader): 34 | r = 1 35 | p = lambda e: None 36 | else: 37 | raise ValueError() 38 | self.names.append(n) 39 | self.name2loader[n] = l 40 | self.name2iter[n] = iter(l) 41 | self.name2pre_epoch[n] = p 42 | ratios.append(r) 43 | 44 | self.accum_steps = accum_steps 45 | self.device = device 46 | self.sampling_ratios = torch.tensor(ratios).float().to(self.device) 47 | self.distributed = distributed 48 | self.step = 0 49 | 50 | def __iter__(self) -> Iterator[Tuple]: 51 | """this iterator will run indefinitely""" 52 | task_id = None 53 | epoch_id = 0 54 | while True: 55 | if self.step % self.accum_steps == 0: 56 | task_id = torch.multinomial(self.sampling_ratios, 1) 57 | if self.distributed: 58 | # make sure all process is training same task 59 | dist.broadcast(task_id, 0) 60 | self.step += 1 61 | task = self.names[task_id.cpu().item()] 62 | iter_ = self.name2iter[task] 63 | try: 64 | batch = next(iter_) 65 | except StopIteration: 66 | epoch_id += 1 67 | # In distributed mode, calling the set_epoch() method at the beginning of each epoch 68 | # before creating the DataLoader iterator is necessary to make shuffling work properly 69 | # across multiple epochs. Otherwise, the same ordering will be always used. 
70 | self.name2pre_epoch[task](epoch_id) 71 | iter_ = iter(self.name2loader[task]) 72 | batch = next(iter_) 73 | self.name2iter[task] = iter_ 74 | 75 | yield task, batch 76 | 77 | 78 | def move_to_cuda(batch: Union[List, Tuple, Dict, torch.Tensor], device: torch.device): 79 | if isinstance(batch, torch.Tensor): 80 | return batch.to(device, non_blocking=True) 81 | elif isinstance(batch, list): 82 | return [move_to_cuda(t, device) for t in batch] 83 | elif isinstance(batch, tuple): 84 | return tuple(move_to_cuda(t, device) for t in batch) 85 | elif isinstance(batch, dict): 86 | return {n: move_to_cuda(t, device) for n, t in batch.items()} 87 | return batch 88 | 89 | 90 | class PrefetchLoader(object): 91 | """ 92 | overlap compute and cuda data transfer 93 | """ 94 | def __init__(self, loader, device: torch.device): 95 | self.loader = loader 96 | self.device = device 97 | 98 | def __iter__(self): 99 | loader_it = iter(self.loader) 100 | self.preload(loader_it) 101 | batch = self.next(loader_it) 102 | while batch is not None: 103 | yield batch 104 | batch = self.next(loader_it) 105 | 106 | def __len__(self): 107 | return len(self.loader) 108 | 109 | def preload(self, it): 110 | try: 111 | self.batch = next(it) 112 | except StopIteration: 113 | self.batch = None 114 | return 115 | self.batch = move_to_cuda(self.batch, self.device) 116 | 117 | def next(self, it): 118 | batch = self.batch 119 | self.preload(it) 120 | return batch 121 | 122 | def __getattr__(self, name): 123 | method = self.loader.__getattribute__(name) 124 | return method 125 | 126 | 127 | def build_dataloader(dataset, collate_fn, is_train: bool, opts): 128 | 129 | batch_size = opts.train_batch_size if is_train else opts.val_batch_size 130 | 131 | if opts.local_rank == -1: 132 | if is_train: 133 | sampler: Union[ 134 | RandomSampler, SequentialSampler, DistributedSampler 135 | ] = RandomSampler(dataset) 136 | else: 137 | sampler = SequentialSampler(dataset) 138 | 139 | size = torch.cuda.device_count() if torch.cuda.is_available() else 1 140 | pre_epoch = lambda e: None 141 | 142 | # DataParallel: scale the batch size by the number of GPUs 143 | if size > 1: 144 | batch_size *= size 145 | 146 | else: 147 | size = dist.get_world_size() 148 | sampler = DistributedSampler( 149 | dataset, num_replicas=size, rank=dist.get_rank(), 150 | shuffle=is_train 151 | ) 152 | pre_epoch = sampler.set_epoch 153 | 154 | loader = DataLoader( 155 | dataset, 156 | sampler=sampler, 157 | batch_size=batch_size, 158 | num_workers=opts.n_workers, 159 | pin_memory=opts.pin_mem, 160 | collate_fn=collate_fn, 161 | drop_last=False, 162 | prefetch_factor=2, 163 | ) 164 | 165 | return loader, pre_epoch 166 | -------------------------------------------------------------------------------- /eval_models.py: -------------------------------------------------------------------------------- 1 | from train_models import model_factory 2 | from utils.misc import set_random_seed 3 | from config.default import get_config 4 | from core.actioner import BaseActioner 5 | from core.environments import RLBenchEnv 6 | from typing import Tuple, Dict, List 7 | 8 | import os 9 | import numpy as np 10 | import itertools 11 | from tqdm import tqdm 12 | import copy 13 | from pathlib import Path 14 | import jsonlines 15 | import tap 16 | 17 | import torch 18 | import lmdb 19 | import msgpack 20 | import msgpack_numpy 21 | msgpack_numpy.patch() 22 | 23 | import torchvision.transforms.functional as transforms_f 24 | 25 | 26 | 27 | class Arguments(tap.Tap): 28 | exp_config: str 29 | 
device: str = 'cuda' # cpu, cuda 30 | 31 | eval_train_split: bool = False 32 | 33 | seed: int = 100 # seed for RLBench 34 | num_demos: int = 500 35 | 36 | headless: bool = False 37 | max_tries: int = 10 38 | save_image: bool = False 39 | record_video: bool = False 40 | 41 | 42 | class Actioner(BaseActioner): 43 | def __init__(self, args) -> None: 44 | config = get_config(args.exp_config, args.extra_args) 45 | self.config = config 46 | 47 | self.device = torch.device(args.device) 48 | 49 | self.gripper_channel = self.config.MODEL.gripper_channel 50 | model_class = model_factory[config.MODEL.model_class] 51 | self.model = model_class(**config.MODEL) 52 | if config.checkpoint: 53 | checkpoint = torch.load( 54 | config.checkpoint, map_location=lambda storage, loc: storage) 55 | if 'state_dict' in checkpoint: 56 | checkpoint = checkpoint['state_dict'] 57 | self.model.load_state_dict(checkpoint, strict=True) 58 | 59 | self.model.to(self.device) 60 | self.model.eval() 61 | 62 | self.use_history = config.MODEL.model_class == 'TransformerUNet' 63 | self.use_instr_embed = config.MODEL.use_instr_embed 64 | if type(config.DATASET.taskvars) is str: 65 | config.DATASET.taskvars = [config.DATASET.taskvars] 66 | self.taskvars = config.DATASET.taskvars 67 | 68 | if self.use_instr_embed != 'none': 69 | assert config.DATASET.instr_embed_file is not None 70 | self.lmdb_instr_env = lmdb.open( 71 | config.DATASET.instr_embed_file, readonly=True) 72 | self.lmdb_instr_txn = self.lmdb_instr_env.begin() 73 | self.memory = {'instr_embeds': {}} 74 | else: 75 | self.lmdb_instr_env = None 76 | 77 | def __exit__(self): 78 | self.lmdb_instr_env.close() 79 | 80 | def get_taskvar_instr_embeds(self, taskvar): 81 | instr_embeds = None 82 | if taskvar in self.memory['instr_embeds']: 83 | instr_embeds = self.memory['instr_embeds'][taskvar] 84 | 85 | if instr_embeds is None: 86 | instr_embeds = self.lmdb_instr_txn.get(taskvar.encode('ascii')) 87 | instr_embeds = msgpack.unpackb(instr_embeds) 88 | instr_embeds = [torch.from_numpy(x).float() for x in instr_embeds] 89 | # ridx = np.random.randint(len(instr_embeds)) 90 | ridx = 0 91 | instr_embeds = instr_embeds[ridx] 92 | if self.use_instr_embed == 'avg': 93 | instr_embeds = torch.mean(instr_embeds, 0, keepdim=True) 94 | elif self.use_instr_embed == 'last': 95 | instr_embeds = instr_embeds[-1:] 96 | self.memory['instr_embeds'][taskvar] = instr_embeds 97 | return instr_embeds # (num_ttokens, dim) 98 | 99 | def preprocess_obs(self, taskvar_id, step_id, obs): 100 | rgb = np.stack(obs['rgb'], 0) # (N, H, W, C) 101 | rgb = torch.from_numpy(rgb).float().permute(0, 3, 1, 2) 102 | # # normalise to [-1, 1] 103 | # rgb = 2 * (rgb / 255.0 - 0.5) 104 | rgb = transforms_f.normalize( 105 | rgb.float(), 106 | [0.485, 0.456, 0.406], 107 | [0.229, 0.224, 0.225] 108 | ) 109 | 110 | if self.gripper_channel == "attn": 111 | gripper_imgs = torch.from_numpy( 112 | obs["gripper_imgs"]).float() # (N, 1, H, W) 113 | rgb = torch.cat([rgb, gripper_imgs], dim=1) 114 | 115 | pcd = np.stack(obs['pc'], 0) # (N, H, W, C) 116 | pcd = torch.from_numpy(pcd).float().permute(0, 3, 1, 2) 117 | 118 | batch = { 119 | 'rgbs': rgb.unsqueeze(0), 120 | 'pcds': pcd.unsqueeze(0), 121 | 'step_ids': torch.LongTensor([step_id]), 122 | 'taskvar_ids': torch.LongTensor([taskvar_id]), 123 | } 124 | if self.use_instr_embed != 'none': 125 | taskvar = self.taskvars[taskvar_id] 126 | batch['instr_embeds'] = self.get_taskvar_instr_embeds( 127 | taskvar).unsqueeze(0) 128 | batch['txt_masks'] = torch.ones( 129 | 1, 
batch['instr_embeds'].size(1)).long() 130 | 131 | if self.use_history: 132 | batch['rgbs'] = batch['rgbs'].unsqueeze(1) # (B, T, N, C, H, W) 133 | batch['pcds'] = batch['pcds'].unsqueeze(1) 134 | batch['step_ids'] = batch['step_ids'].unsqueeze(1) 135 | batch['step_masks'] = torch.ones(1, 1) 136 | if len(self.history_obs) == 0: 137 | self.history_obs = batch 138 | else: 139 | for key in ['rgbs', 'pcds', 'step_ids', 'step_masks']: 140 | self.history_obs[key] = torch.cat( 141 | [self.history_obs[key], batch[key]], dim=1 142 | ) 143 | batch = copy.deepcopy(self.history_obs) 144 | 145 | # for k, v in batch.items(): 146 | # print(k, v.size()) 147 | return batch 148 | 149 | def predict(self, taskvar_id, step_id, obs_state_dict): 150 | # print(obs_state_dict) 151 | batch = self.preprocess_obs(taskvar_id, step_id, obs_state_dict) 152 | with torch.no_grad(): 153 | action = self.model(batch)[0] 154 | if self.use_history: 155 | action = action[-1] 156 | 157 | action = action.data.cpu().numpy() 158 | out = { 159 | 'action': action 160 | } 161 | # print(self.demo_id, step_id) 162 | 163 | return out 164 | 165 | 166 | def evaluate_keysteps(args): 167 | set_random_seed(args.seed) 168 | 169 | actioner = Actioner(args) 170 | 171 | config = actioner.config 172 | 173 | if args.eval_train_split: 174 | microstep_data_dir = Path( 175 | config.DATASET.data_dir.replace('keysteps', 'microsteps')) 176 | pred_dir = os.path.join(config.output_dir, 'preds', 'train') 177 | else: 178 | microstep_data_dir = '' 179 | pred_dir = os.path.join(config.output_dir, 'preds', f'seed{args.seed}') 180 | os.makedirs(pred_dir, exist_ok=True) 181 | 182 | env = RLBenchEnv( 183 | data_path=microstep_data_dir, 184 | apply_rgb=True, 185 | apply_pc=True, 186 | apply_cameras=config.DATASET.cameras, 187 | headless=args.headless, 188 | gripper_pose=config.MODEL.gripper_channel, 189 | ) 190 | 191 | outfile = jsonlines.open( 192 | os.path.join(pred_dir, 'results.jsonl'), 'a', flush=True 193 | ) 194 | for taskvar_id, taskvar in enumerate(actioner.taskvars): 195 | task_str, variation = taskvar.split('+') 196 | variation = int(variation) 197 | 198 | if args.eval_train_split: 199 | episodes_dir = microstep_data_dir / task_str / \ 200 | f"variation{variation}" / "episodes" 201 | demo_keys, demos = [], [] 202 | for ep in tqdm(episodes_dir.glob('episode*')): 203 | episode_id = int(ep.stem[7:]) 204 | demo = env.get_demo(task_str, variation, episode_id) 205 | demo_keys.append(f'episode{episode_id}') 206 | demos.append(demo) 207 | # if len(demos) > 1: 208 | # break 209 | num_demos = len(demos) 210 | else: 211 | demo_keys = None 212 | demos = None 213 | num_demos = args.num_demos 214 | 215 | success_rate = env.evaluate( 216 | taskvar_id, 217 | task_str, 218 | actioner=actioner, 219 | max_episodes=config.MODEL.max_steps, 220 | variation=variation, 221 | num_demos=num_demos, 222 | demos=demos, 223 | demo_keys=demo_keys, 224 | log_dir=Path(pred_dir), 225 | max_tries=args.max_tries, 226 | save_image=args.save_image, 227 | record_video=args.record_video, 228 | ) 229 | 230 | print("Testing Success Rate {}: {:.04f}".format(task_str, success_rate)) 231 | outfile.write( 232 | { 233 | 'checkpoint': config.checkpoint, 234 | 'task': task_str, 'variation': variation, 235 | 'num_demos': num_demos, 'sr': success_rate 236 | } 237 | ) 238 | 239 | outfile.close() 240 | 241 | 242 | if __name__ == '__main__': 243 | args = Arguments().parse_args(known_only=True) 244 | evaluate_keysteps(args) 245 | -------------------------------------------------------------------------------- 
/job_scripts/eval_tst_split.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=eval_tst 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --qos=qos_gpu-t3 8 | #SBATCH --hint=nomultithread 9 | #SBATCH --time=20:00:00 10 | #SBATCH --output=slurm_logs/%j.out 11 | #SBATCH --error=slurm_logs/%j.out 12 | 13 | set -x 14 | set -e 15 | 16 | module purge 17 | pwd; hostname; date 18 | 19 | cd $WORK/codes/hiveformer 20 | 21 | . $WORK/miniconda3/etc/profile.d/conda.sh 22 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 23 | 24 | conda activate hiveformer 25 | export PYTHONPATH=$PYTHONPATH:$(pwd) 26 | 27 | 28 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 29 | mkdir $XDG_RUNTIME_DIR 30 | chmod 700 $XDG_RUNTIME_DIR 31 | 32 | dirname=$1 33 | seed=$2 34 | checkpoint_step=$3 35 | 36 | outdir=data/exprs/transformer_unet/${dirname}/seed${seed} 37 | 38 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 39 | singularity exec \ 40 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 41 | $SINGULARITY_ALLOWED_DIR/vlc_rlbench.sif \ 42 | xvfb-run -a python eval_models.py \ 43 | --exp_config ${outdir}/logs/training_config.yaml \ 44 | --seed 200 --num_demos 500 \ 45 | --checkpoint ${outdir}/ckpts/model_step_${checkpoint_step}.pt -------------------------------------------------------------------------------- /job_scripts/train_multitask_bc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=trainbc 3 | #SBATCH --nodes 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gres=gpu:1 6 | #SBATCH --cpus-per-task=10 7 | #SBATCH --qos=qos_gpu-t3 8 | #SBATCH --hint=nomultithread 9 | #SBATCH --time=20:00:00 10 | #SBATCH --output=slurm_logs/%j.out 11 | #SBATCH --error=slurm_logs/%j.out 12 | 13 | set -x 14 | set -e 15 | 16 | module purge 17 | pwd; hostname; date 18 | 19 | cd $WORK/codes/hiveformer 20 | 21 | . 
$WORK/miniconda3/etc/profile.d/conda.sh 22 | export LD_LIBRARY_PATH=$WORK/miniconda3/envs/bin/lib:$LD_LIBRARY_PATH 23 | 24 | conda activate hiveformer 25 | export PYTHONPATH=$PYTHONPATH:$(pwd) 26 | 27 | export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4)) 28 | export WORLD_SIZE=$(($SLURM_NNODES * $SLURM_NTASKS_PER_NODE)) 29 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 30 | export MASTER_ADDR=$master_addr 31 | 32 | 33 | export XDG_RUNTIME_DIR=$SCRATCH/tmp/runtime-$SLURM_JOBID 34 | mkdir $XDG_RUNTIME_DIR 35 | chmod 700 $XDG_RUNTIME_DIR 36 | 37 | taskvars="pick_and_lift+0,pick_up_cup+0,put_knife_on_chopping_board+0,put_money_in_safe+0,push_button+0,reach_target+0,slide_block_to_target+0,stack_wine+0,take_money_out_safe+0,take_umbrella_out_of_umbrella_stand+0" 38 | seed=$1 39 | 40 | configfile=config/transformer_unet.yaml 41 | 42 | srun --export=ALL,XDG_RUNTIME_DIR=$XDG_RUNTIME_DIR \ 43 | singularity exec \ 44 | --bind $WORK:$WORK,$SCRATCH:$SCRATCH,$STORE:$STORE --nv \ 45 | $SINGULARITY_ALLOWED_DIR/vlc_rlbench.sif \ 46 | xvfb-run -a python train_models.py \ 47 | --exp-config ${configfile} \ 48 | output_dir data/exprs/transformer_unet/10tasks \ 49 | DATASET.taskvars ${taskvars} \ 50 | DATASET.data_dir data/train_dataset/keysteps/seed${seed} 51 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlc-robot/hiveformer/0ca80a156acb3985be236fd9ab50e56734f970d6/models/__init__.py -------------------------------------------------------------------------------- /models/network_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Literal, Union, List, Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class ConvLayer(nn.Module): 9 | def __init__( 10 | self, 11 | in_channels, 12 | out_channels, 13 | kernel_size, 14 | stride_size, 15 | apply_norm=True, 16 | apply_activation=True, 17 | ): 18 | super().__init__() 19 | 20 | padding_size = ( 21 | kernel_size // 2 22 | if isinstance(kernel_size, int) 23 | else (kernel_size[0] // 2, kernel_size[1] // 2) 24 | ) 25 | 26 | self.conv = nn.Conv2d( 27 | in_channels, 28 | out_channels, 29 | kernel_size, 30 | stride_size, 31 | padding_size, 32 | padding_mode="replicate", 33 | ) 34 | 35 | if apply_norm: 36 | self.norm = nn.GroupNorm(1, out_channels, affine=True) 37 | 38 | if apply_activation: 39 | self.activation = nn.LeakyReLU(0.02) 40 | 41 | def forward( 42 | self, ft: torch.Tensor 43 | ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: 44 | out = self.conv(ft) 45 | 46 | if hasattr(self, "norm"): 47 | out = self.norm(out) 48 | 49 | if hasattr(self, "activation"): 50 | out = self.activation(out) 51 | 52 | return out 53 | 54 | 55 | def dense_layer(in_channels, out_channels, apply_activation=True): 56 | layer: List[nn.Module] = [nn.Linear(in_channels, out_channels)] 57 | if apply_activation: 58 | layer += [nn.LeakyReLU(0.02)] 59 | return layer 60 | 61 | 62 | def normalise_quat(x): 63 | return x / x.square().sum(dim=-1).sqrt().unsqueeze(-1) 64 | 65 | 66 | class ActionLoss(object): 67 | def decompose_actions(self, actions): 68 | pos = actions[..., :3] 69 | rot = actions[..., 3:7] 70 | open = actions[..., 7] 71 | return pos, rot, open 72 | 73 | def compute_loss(self, preds, targets, masks=None) -> Dict[str, torch.Tensor]: 74 | 
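        # preds/targets are 8-dim actions: position (3), quaternion (4), gripper open (1);
        # `masks` marks the valid timesteps when episodes are padded to a common length.
        # Quaternions q and -q encode the same rotation, so the rotation loss is also
        # evaluated against the negated target quaternion below.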
pred_pos, pred_rot, pred_open = self.decompose_actions(preds) 75 | tgt_pos, tgt_rot, tgt_open = self.decompose_actions(targets) 76 | 77 | losses = {} 78 | 79 | # Automatically matching the closest quaternions (symmetrical solution). 80 | tgt_rot_ = -tgt_rot.clone() 81 | 82 | if masks is None: 83 | losses['pos'] = F.mse_loss(pred_pos, tgt_pos) 84 | 85 | rot_loss = F.mse_loss(pred_rot, tgt_rot) 86 | rot_loss_ = F.mse_loss(pred_rot, tgt_rot_) 87 | select_mask = (rot_loss < rot_loss_).float() 88 | losses['rot'] = (select_mask * rot_loss + (1 - select_mask) * rot_loss_) 89 | 90 | 91 | losses['open'] = F.binary_cross_entropy_with_logits(pred_open, tgt_open) 92 | else: 93 | div_sum = torch.sum(masks) 94 | 95 | losses['pos'] = torch.sum(F.mse_loss( 96 | pred_pos, tgt_pos, reduction='none') * masks.unsqueeze(-1)) / div_sum / 3 97 | 98 | rot_loss = torch.sum(F.mse_loss( 99 | pred_rot, tgt_rot, reduction='none') * masks.unsqueeze(-1)) 100 | rot_loss_ = torch.sum(F.mse_loss( 101 | pred_rot, tgt_rot_, reduction='none') * masks.unsqueeze(-1)) 102 | select_mask = (rot_loss < rot_loss_).float() 103 | losses['rot'] = (select_mask * rot_loss + (1 - select_mask) * rot_loss_) / div_sum / 4 104 | 105 | losses['open'] = torch.sum(F.binary_cross_entropy_with_logits( 106 | pred_open, tgt_open, reduction='none') * masks) / div_sum 107 | 108 | losses['total'] = losses['pos'] + losses['rot'] + losses['open'] 109 | 110 | return losses 111 | -------------------------------------------------------------------------------- /models/plain_unet.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Literal, Union, List, Dict 2 | 3 | import numpy as np 4 | from einops.layers.torch import Rearrange 5 | import einops 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from models.network_utils import ( 12 | ConvLayer, dense_layer, normalise_quat, 13 | ActionLoss 14 | ) 15 | 16 | 17 | class PlainUNet(nn.Module): 18 | def __init__( 19 | self, hidden_size: int = 16, num_layers: int = 4, 20 | num_tasks: int = None, max_steps: int = 20, 21 | gripper_channel: Union[bool, str] = False, unet: bool = True, 22 | use_instr_embed: str = 'none', instr_embed_size: int = None, 23 | **kwargs 24 | ): 25 | super().__init__() 26 | 27 | self.hidden_size = hidden_size 28 | self.num_layers = num_layers 29 | self.num_tasks = num_tasks 30 | self.max_steps = max_steps 31 | self.gripper_channel = gripper_channel 32 | self.unet = unet 33 | self.use_instr_embed = use_instr_embed 34 | self.instr_embed_size = instr_embed_size 35 | 36 | if self.use_instr_embed == 'none': 37 | assert self.num_tasks is not None 38 | self.task_embedding = nn.Embedding( 39 | self.num_tasks, self.hidden_size) 40 | else: 41 | assert self.instr_embed_size is not None 42 | self.task_embedding = nn.Linear( 43 | self.instr_embed_size, self.hidden_size) 44 | 45 | self.step_embedding = nn.Embedding(self.max_steps, self.hidden_size) 46 | 47 | # in_channels: RGB or RGB + gripper pose heatmap image 48 | self.in_channels = 4 if self.gripper_channel == "attn" else 3 49 | 50 | # Input RGB Preprocess (SiameseNet for each camera) 51 | self.rgb_preprocess = ConvLayer( 52 | self.in_channels, self.hidden_size // 2, 53 | kernel_size=(3, 3), 54 | stride_size=(1, 1), 55 | apply_norm=False, 56 | ) 57 | self.to_feat = ConvLayer( 58 | self.hidden_size // 2, self.hidden_size, 59 | kernel_size=(1, 1), 60 | stride_size=(1, 1), 61 | apply_norm=False, 62 | ) 63 | 64 | # 
Encoder-Decoder Network, maps to pixel location with spatial argmax 65 | self.feature_encoder = nn.ModuleList() 66 | for i in range(self.num_layers): 67 | self.feature_encoder.append( 68 | ConvLayer( 69 | in_channels=self.hidden_size, 70 | out_channels=self.hidden_size, 71 | kernel_size=(3, 3), 72 | stride_size=(2, 2) 73 | ) 74 | ) 75 | 76 | self.enc_size = self.hidden_size 77 | 78 | if self.unet: 79 | self.trans_decoder = nn.ModuleList() 80 | for i in range(self.num_layers): 81 | self.trans_decoder.extend( 82 | [ 83 | nn.Sequential( 84 | ConvLayer( 85 | in_channels=self.hidden_size * 2, 86 | out_channels=self.hidden_size, 87 | kernel_size=(3, 3), 88 | stride_size=(1, 1), 89 | ), 90 | nn.Upsample( 91 | scale_factor=2, mode="bilinear", align_corners=True 92 | ), 93 | ) 94 | ] 95 | ) 96 | 97 | quat_hidden_size = self.quat_hidden_size 98 | self.quat_decoder = nn.Sequential( 99 | ConvLayer( 100 | in_channels=quat_hidden_size, out_channels=quat_hidden_size, 101 | kernel_size=(3, 3), stride_size=(2, 2), 102 | ), 103 | ConvLayer( 104 | in_channels=quat_hidden_size, out_channels=quat_hidden_size, 105 | kernel_size=(3, 3), stride_size=(2, 2) 106 | ), 107 | nn.AdaptiveAvgPool2d(1), 108 | Rearrange("b c h w -> b (c h w)"), 109 | *dense_layer(quat_hidden_size, quat_hidden_size), 110 | *dense_layer(quat_hidden_size, 3 + 4 + 1, apply_activation=False), 111 | ) 112 | 113 | self.maps_to_coord = ConvLayer( 114 | in_channels=self.hidden_size, 115 | out_channels=1, 116 | kernel_size=(1, 1), 117 | stride_size=(1, 1), 118 | apply_norm=False, 119 | apply_activation=False, 120 | ) 121 | 122 | self.loss_fn = ActionLoss() 123 | 124 | @property 125 | def quat_hidden_size(self): 126 | return self.hidden_size * 4 # 3 cameras + task/step embedding 127 | 128 | @property 129 | def num_parameters(self): 130 | nweights, nparams = 0, 0 131 | for k, v in self.named_parameters(): 132 | nweights += np.prod(v.size()) 133 | nparams += 1 134 | return nweights, nparams 135 | 136 | @property 137 | def num_trainable_parameters(self): 138 | nweights, nparams = 0, 0 139 | for k, v in self.named_parameters(): 140 | if v.requires_grad: 141 | nweights += np.prod(v.size()) 142 | nparams += 1 143 | return nweights, nparams 144 | 145 | def prepare_batch(self, batch): 146 | device = next(self.parameters()).device 147 | for k, v in batch.items(): 148 | if isinstance(v, torch.Tensor): 149 | batch[k] = v.to(device) 150 | return batch 151 | 152 | def forward(self, batch, compute_loss=False): 153 | '''Inputs: 154 | - rgb_obs, pc_obs: (B, N, C, H, W) B: batch_size, N: #cameras 155 | - task_ids, step_ids: (B, ) 156 | ''' 157 | batch = self.prepare_batch(batch) 158 | 159 | rgb_obs = batch['rgbs'] 160 | pc_obs = batch['pcds'] 161 | taskvar_ids = batch['taskvar_ids'] 162 | step_ids = batch['step_ids'] 163 | instr_embeds = batch.get('instr_embeds', None) 164 | 165 | batch_size, n_cameras, _, im_height, im_width = rgb_obs.size() 166 | 167 | rgb_fts = einops.rearrange(rgb_obs, "b n c h w -> (b n) c h w") 168 | rgb_fts = self.rgb_preprocess(rgb_fts) 169 | x = self.to_feat(rgb_fts) 170 | 171 | # encoding features 172 | enc_fts = [] 173 | for l in self.feature_encoder: 174 | x = l(x) 175 | enc_fts.append(x) 176 | 177 | # concat the rgb fts with task/step embeds 178 | if self.use_instr_embed == 'none': 179 | task_embeds = self.task_embedding(taskvar_ids) 180 | else: 181 | assert instr_embeds.size(1) == 1 182 | task_embeds = self.task_embedding(instr_embeds[:, 0]) 183 | step_embeds = self.step_embedding(step_ids) 184 | task_step_embeds = task_embeds + 
step_embeds 185 | 186 | # decoding features 187 | enc_fts.reverse() 188 | 189 | if self.unet: 190 | ext_task_step_embeds = einops.repeat( 191 | task_step_embeds, 'b c -> (b n) c h w', 192 | n=n_cameras, h=enc_fts[0].shape[-2], w=enc_fts[0].shape[-1] 193 | ) 194 | x = torch.cat([enc_fts[0], ext_task_step_embeds], dim=1) 195 | 196 | for i, l in enumerate(self.trans_decoder): 197 | if i == 0: 198 | xtr = l(x) 199 | else: 200 | xtr = l(torch.cat([xtr, enc_fts[i]], dim=1)) 201 | 202 | # predict the translation of the gripper 203 | xt_heatmap = self.maps_to_coord(xtr) 204 | xt_heatmap = einops.rearrange( 205 | xt_heatmap, '(b n) c h w -> b (n c h w)', n=n_cameras, c=1 206 | ) 207 | xt_heatmap = torch.softmax(xt_heatmap / 0.1, dim=1) 208 | xt_heatmap = einops.rearrange( 209 | xt_heatmap, 'b (n c h w) -> b n c h w', 210 | n=n_cameras, c=1, h=im_height, w=im_width 211 | ) 212 | xt = einops.reduce(pc_obs * xt_heatmap, 'b n c h w -> b c', 'sum') 213 | 214 | else: 215 | xt = 0 216 | 217 | # predict the (translation_offset, rotation and openness) of the gripper 218 | xg = einops.rearrange( 219 | enc_fts[0], '(b n) c h w -> b (n c) h w', n=n_cameras 220 | ) 221 | ext_task_step_embeds = einops.repeat( 222 | task_step_embeds, 'b c -> b c h w', 223 | h=xg.size(2), w=xg.size(3) 224 | ) 225 | xg = torch.cat([xg, ext_task_step_embeds], dim=1) 226 | xg = self.quat_decoder(xg) 227 | xt_offset = xg[..., :3] 228 | xr = normalise_quat(xg[..., 3:7]) 229 | xo = xg[..., 7].unsqueeze(-1) 230 | 231 | actions = torch.cat([xt + xt_offset, xr, xo], dim=-1) 232 | 233 | if compute_loss: 234 | losses = self.loss_fn.compute_loss(actions, batch['actions']) 235 | return losses, actions 236 | 237 | return actions 238 | 239 | 240 | if __name__ == '__main__': 241 | model = PlainUNet( 242 | hidden_size=16, num_layers=4, 243 | num_tasks=1, max_steps=20, 244 | gripper_channel=False, unet=True, 245 | use_instr_embed='avg', instr_embed_size=512, 246 | ) 247 | print(next(model.parameters()).device) 248 | 249 | b, n, h, w = 2, 3, 128, 128 250 | batch = { 251 | 'rgbs': torch.rand(b, n, 3, h, w), 252 | 'pcds': torch.rand(b, n, 3, h, w), 253 | 'taskvar_ids': torch.zeros(b).long(), 254 | 'step_ids': torch.randint(0, 10, (b, )).long(), 255 | 'instr_embeds': torch.rand(b, 1, 512), 256 | } 257 | 258 | actions = model(batch) 259 | print(actions.size()) 260 | print(actions) 261 | -------------------------------------------------------------------------------- /models/transformer_unet.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import numpy as np 3 | 4 | import einops 5 | from einops.layers.torch import Rearrange 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from models.network_utils import normalise_quat 11 | from models.plain_unet import PlainUNet 12 | 13 | 14 | class TransformerUNet(PlainUNet): 15 | def __init__( 16 | self, num_trans_layers: int = 1, nhead: int = 8, 17 | txt_attn_type: str = 'self', num_cams: int = 3, 18 | latent_im_size: Tuple = (8, 8), 19 | quat_input: str = 'add', **kwargs 20 | ): 21 | self.num_trans_layers = num_trans_layers 22 | self.nhead = nhead 23 | self.txt_attn_type = txt_attn_type 24 | self.num_cams = num_cams 25 | self.latent_im_size = latent_im_size 26 | self.quat_input = quat_input 27 | 28 | super().__init__(**kwargs) 29 | 30 | if self.txt_attn_type == 'self': 31 | encoder_layer = nn.TransformerEncoderLayer( 32 | d_model=self.hidden_size, 33 | nhead=self.nhead, 34 | dim_feedforward=self.hidden_size*4, 35 | dropout=0.1, 
activation='gelu', 36 | layer_norm_eps=1e-12, norm_first=False, 37 | batch_first=True, 38 | ) 39 | self.self_attention = nn.TransformerEncoder( 40 | encoder_layer, num_layers=self.num_trans_layers 41 | ) 42 | 43 | elif self.txt_attn_type == 'cross': 44 | decoder_layer = nn.TransformerDecoderLayer( 45 | d_model=self.hidden_size, 46 | nhead=self.nhead, 47 | dim_feedforward=self.hidden_size*4, 48 | dropout=0.1, activation='gelu', 49 | layer_norm_eps=1e-12, norm_first=False, 50 | batch_first=True, 51 | ) 52 | self.cross_attention = nn.TransformerDecoder( 53 | decoder_layer, num_layers=self.num_trans_layers 54 | ) 55 | 56 | self.visual_embedding = nn.Linear(self.enc_size, self.hidden_size) 57 | self.cam_embedding = nn.Embedding(self.num_cams, self.hidden_size) 58 | self.pix_embedding = nn.Embedding( 59 | np.prod(self.latent_im_size), self.hidden_size 60 | ) 61 | 62 | @property 63 | def quat_hidden_size(self): 64 | if self.quat_input == 'add': 65 | return self.hidden_size * 3 # 3 cameras 66 | elif self.quat_input == 'concat': 67 | return self.hidden_size * 3 * 2 68 | else: 69 | raise NotImplementedError( 70 | 'unsupport quat_input %s' % self.quat_input) 71 | 72 | def forward(self, batch, compute_loss=False): 73 | '''Inputs: 74 | - rgb_obs, pc_obs: (B, T, N, C, H, W) B: batch_size, N: #cameras 75 | - task_ids: (B, ) 76 | - step_ids: (B, T) 77 | - step_masks: (B, T) 78 | - txt_masks: (B, L_txt) 79 | ''' 80 | batch = self.prepare_batch(batch) 81 | device = batch['rgbs'].device 82 | 83 | rgb_obs = batch['rgbs'] 84 | pc_obs = batch['pcds'] 85 | taskvar_ids = batch['taskvar_ids'] 86 | step_masks = batch['step_masks'] 87 | step_ids = batch['step_ids'] 88 | 89 | batch_size, nsteps, num_cams, _, im_height, im_width = rgb_obs.size() 90 | 91 | # (B, L, C) 92 | if 'instr_embeds' in batch: 93 | instr_embeds = self.task_embedding(batch['instr_embeds']) 94 | txt_masks = batch['txt_masks'] 95 | else: 96 | instr_embeds = self.task_embedding(taskvar_ids).unsqueeze(1) 97 | txt_masks = torch.ones(batch_size, 1).long().to(device) 98 | 99 | rgb_fts = einops.rearrange(rgb_obs, "b t n c h w -> (b t n) c h w") 100 | rgb_fts = self.rgb_preprocess(rgb_fts) 101 | x = self.to_feat(rgb_fts) 102 | 103 | # encoding features 104 | enc_fts = [] 105 | for l in self.feature_encoder: 106 | x = l(x) 107 | enc_fts.append(x) 108 | # x: (b t n) c h w 109 | x = einops.rearrange( 110 | x, '(b t n) c h w -> b t n (h w) c', 111 | b=batch_size, t=nsteps, n=num_cams 112 | ) 113 | x = self.visual_embedding(x) 114 | 115 | enc_fts[-1] = einops.rearrange(x, "b t n (h w) c -> (b t n) c h w", h=self.latent_im_size[0], w=self.latent_im_size[1]) 116 | 117 | step_embeds = self.step_embedding(batch['step_ids']) # (B, T, C) 118 | cam_embeds = self.cam_embedding( 119 | torch.arange(num_cams).long().to(device) 120 | ) # (N, C) 121 | pix_embeds = self.pix_embedding( 122 | torch.arange(np.prod(self.latent_im_size)).long().to(device) 123 | ) # (H * W, C) 124 | x = x + einops.rearrange(step_embeds, 'b t c -> b t 1 1 c') + \ 125 | einops.rearrange(cam_embeds, 'n c -> 1 1 n 1 c') + \ 126 | einops.rearrange(pix_embeds, 'l c -> 1 1 1 l c') 127 | x = einops.rearrange(x, 'b t n l c -> b (t n l) c') 128 | 129 | # transformer: text / history encoding 130 | num_vtokens_per_step = num_cams * np.prod(self.latent_im_size) 131 | num_ttokens = instr_embeds.size(1) 132 | num_vtokens = num_vtokens_per_step * nsteps 133 | num_tokens = num_ttokens + num_vtokens 134 | 135 | if self.txt_attn_type == 'self': 136 | inputs = torch.cat( 137 | [instr_embeds, x], dim=1 138 | ) 139 | 
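# The masks built below give a block-causal structure over the token sequence
# [instruction tokens, step-0 visual tokens, step-1 visual tokens, ...]:
# instruction tokens attend only to other instruction tokens, while the visual
# tokens of step t attend to the instruction and to the visual tokens of all
# steps <= t. PyTorch transformer layers expect True to mark *disallowed*
# positions, hence the logical_not() calls when invoking self.self_attention.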
140 | ext_step_masks = einops.repeat( 141 | step_masks, 'b t -> b t k', k=num_vtokens_per_step 142 | ).flatten(start_dim=1) 143 | src_masks = torch.cat( 144 | [txt_masks, ext_step_masks], dim=1 145 | ) # (b, l+t*n*h*w) 146 | 147 | causal_masks = torch.zeros(num_tokens, num_tokens).bool() 148 | causal_masks[:num_ttokens, :num_ttokens] = True 149 | for t in range(nsteps): 150 | s = num_ttokens + num_vtokens_per_step * t 151 | e = s + num_vtokens_per_step 152 | causal_masks[s:e, :e] = True 153 | causal_masks = causal_masks.to(device) 154 | 155 | outputs = self.self_attention( 156 | inputs, 157 | mask=causal_masks.logical_not(), 158 | src_key_padding_mask=src_masks.logical_not() 159 | ) 160 | outputs = outputs[:, num_ttokens:] 161 | 162 | elif self.txt_attn_type == 'cross': 163 | src_masks = einops.repeat( 164 | step_masks, 'b t -> b t k', k=num_vtokens_per_step 165 | ).flatten(start_dim=1) 166 | 167 | causal_masks = torch.zeros(num_vtokens, num_vtokens).bool() 168 | for t in range(nsteps): 169 | s = num_vtokens_per_step * t 170 | e = s + num_vtokens_per_step 171 | causal_masks[s:e, :e] = True 172 | causal_masks = causal_masks.to(device) 173 | 174 | outputs = self.cross_attention( 175 | x, instr_embeds, 176 | tgt_mask=causal_masks.logical_not(), 177 | tgt_key_padding_mask=src_masks.logical_not(), 178 | memory_key_padding_mask=txt_masks.logical_not(), 179 | ) 180 | 181 | outputs = einops.rearrange( 182 | outputs, 'b (t n h w) c -> (b t n) c h w', 183 | t=nsteps, n=self.num_cams, 184 | h=self.latent_im_size[0], w=self.latent_im_size[1] 185 | ) 186 | 187 | # decoding features 188 | enc_fts.reverse() 189 | 190 | if self.unet: 191 | xtr = outputs 192 | for i, l in enumerate(self.trans_decoder): 193 | xtr = l(torch.cat([xtr, enc_fts[i]], dim=1)) 194 | 195 | # predict the translation of the gripper 196 | xt_heatmap = self.maps_to_coord(xtr) 197 | xt_heatmap = einops.rearrange( 198 | xt_heatmap, '(b t n) c h w -> b t (n c h w)', t=nsteps, n=num_cams, c=1 199 | ) 200 | xt_heatmap = torch.softmax(xt_heatmap / 0.1, dim=-1) 201 | xt_heatmap = einops.rearrange( 202 | xt_heatmap, 'b t (n c h w) -> b t n c h w', 203 | t=nsteps, n=num_cams, c=1, h=im_height, w=im_width 204 | ) 205 | xt = einops.reduce(pc_obs * xt_heatmap, 206 | 'b t n c h w -> b t c', 'sum') 207 | 208 | else: 209 | xt = 0 210 | 211 | # predict the (translation_offset, rotation and openness) of the gripper 212 | if self.quat_input == 'add': 213 | xg = outputs + enc_fts[0] 214 | elif self.quat_input == 'concat': 215 | xg = torch.cat([outputs, enc_fts[0]], dim=1) 216 | xg = einops.rearrange( 217 | xg, '(b t n) c h w -> (b t) (n c) h w', t=nsteps, n=num_cams 218 | ) 219 | xg = self.quat_decoder(xg) 220 | xg = einops.rearrange(xg, '(b t) c -> b t c', t=nsteps) 221 | xt_offset = xg[..., :3] 222 | xr = normalise_quat(xg[..., 3:7]) 223 | xo = xg[..., 7].unsqueeze(-1) 224 | 225 | actions = torch.cat([xt + xt_offset, xr, xo], dim=-1) 226 | 227 | if compute_loss: 228 | losses = self.loss_fn.compute_loss( 229 | actions, batch['actions'], masks=batch['step_masks'] 230 | ) 231 | return losses, actions 232 | 233 | return actions 234 | 235 | 236 | if __name__ == '__main__': 237 | model = TransformerUNet( 238 | hidden_size=16, num_layers=4, 239 | num_tasks=None, max_steps=20, 240 | gripper_channel=False, unet=False, 241 | use_instr_embed='all', instr_embed_size=512, 242 | num_trans_layers=1, nhead=8, 243 | txt_attn_type='self', num_cams=3, 244 | latent_im_size=(8, 8) 245 | ) 246 | 247 | b, t, n, h, w = 2, 4, 3, 128, 128 248 | batch = { 249 | 'rgbs': 
torch.rand(b, t, n, 3, h, w), 250 | 'pcds': torch.rand(b, t, n, 3, h, w), 251 | 'taskvar_ids': torch.zeros(b).long(), 252 | 'step_ids': einops.repeat(torch.arange(t).long(), 't -> b t', b=b), 253 | 'step_masks': torch.ones(b, t).bool(), 254 | 'instr_embeds': torch.rand(b, 5, 512), 255 | 'txt_masks': torch.ones(b, 5).bool(), 256 | } 257 | 258 | actions = model(batch) 259 | print(actions.size()) 260 | print(actions) 261 | -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | """ 6 | from .sched import noam_schedule, warmup_linear, get_lr_sched 7 | from .adamw import AdamW 8 | -------------------------------------------------------------------------------- /optim/adamw.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdamW optimizer (weight decay fix) 3 | copied from hugginface (https://github.com/huggingface/transformers). 4 | """ 5 | 6 | import math 7 | from typing import Callable, Iterable, Tuple 8 | 9 | import torch 10 | 11 | from torch.optim import Optimizer 12 | 13 | class AdamW(Optimizer): 14 | """ 15 | Implements Adam algorithm with weight decay fix as introduced in `Decoupled Weight Decay Regularization 16 | `__. 17 | 18 | Parameters: 19 | params (:obj:`Iterable[torch.nn.parameter.Parameter]`): 20 | Iterable of parameters to optimize or dictionaries defining parameter groups. 21 | lr (:obj:`float`, `optional`, defaults to 1e-3): 22 | The learning rate to use. 23 | betas (:obj:`Tuple[float,float]`, `optional`, defaults to (0.9, 0.999)): 24 | Adam's betas parameters (b1, b2). 25 | eps (:obj:`float`, `optional`, defaults to 1e-6): 26 | Adam's epsilon for numerical stability. 27 | weight_decay (:obj:`float`, `optional`, defaults to 0): 28 | Decoupled weight decay to apply. 29 | correct_bias (:obj:`bool`, `optional`, defaults to `True`): 30 | Whether ot not to correct bias in Adam (for instance, in Bert TF repository they use :obj:`False`). 31 | """ 32 | 33 | def __init__( 34 | self, 35 | params: Iterable[torch.nn.parameter.Parameter], 36 | lr: float = 1e-3, 37 | betas: Tuple[float, float] = (0.9, 0.999), 38 | eps: float = 1e-6, 39 | weight_decay: float = 0.0, 40 | correct_bias: bool = True, 41 | ): 42 | if lr < 0.0: 43 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 44 | if not 0.0 <= betas[0] < 1.0: 45 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 46 | if not 0.0 <= betas[1] < 1.0: 47 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 48 | if not 0.0 <= eps: 49 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 50 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) 51 | super().__init__(params, defaults) 52 | 53 | def step(self, closure: Callable = None): 54 | """ 55 | Performs a single optimization step. 56 | 57 | Arguments: 58 | closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss. 
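        A minimal usage sketch (illustrative only; assumes a `model` and a scalar
        `loss` have been built elsewhere)::

            optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1e-5)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()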
59 | """ 60 | loss = None 61 | if closure is not None: 62 | loss = closure() 63 | 64 | for group in self.param_groups: 65 | for p in group["params"]: 66 | if p.grad is None: 67 | continue 68 | grad = p.grad.data 69 | if grad.is_sparse: 70 | raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") 71 | 72 | state = self.state[p] 73 | 74 | # State initialization 75 | if len(state) == 0: 76 | state["step"] = 0 77 | # Exponential moving average of gradient values 78 | state["exp_avg"] = torch.zeros_like(p.data) 79 | # Exponential moving average of squared gradient values 80 | state["exp_avg_sq"] = torch.zeros_like(p.data) 81 | 82 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 83 | beta1, beta2 = group["betas"] 84 | 85 | state["step"] += 1 86 | 87 | # Decay the first and second moment running average coefficient 88 | # In-place operations to update the averages at the same time 89 | exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) 90 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) 91 | denom = exp_avg_sq.sqrt().add_(group["eps"]) 92 | 93 | step_size = group["lr"] 94 | if group["correct_bias"]: # No bias correction for Bert 95 | bias_correction1 = 1.0 - beta1 ** state["step"] 96 | bias_correction2 = 1.0 - beta2 ** state["step"] 97 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 98 | 99 | p.data.addcdiv_(exp_avg, denom, value=-step_size) 100 | 101 | # Just adding the square of the weights to the loss function is *not* 102 | # the correct way of using L2 regularization/weight decay with Adam, 103 | # since that will interact with the m and v parameters in strange ways. 104 | # 105 | # Instead we want to decay the weights in a manner that doesn't interact 106 | # with the m/v parameters. This is equivalent to adding the square 107 | # of the weights to the loss with plain (non-momentum) SGD. 108 | # Add weight decay at the end (fixed version) 109 | if group["weight_decay"] > 0.0: 110 | p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"]) 111 | 112 | return loss 113 | -------------------------------------------------------------------------------- /optim/lookahead.py: -------------------------------------------------------------------------------- 1 | # Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py 2 | 3 | """ Lookahead Optimizer Wrapper. 
4 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 5 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 6 | """ 7 | import torch 8 | from torch.optim.optimizer import Optimizer 9 | from torch.optim import Adam 10 | from collections import defaultdict 11 | 12 | class Lookahead(Optimizer): 13 | def __init__(self, base_optimizer, alpha=0.5, k=6): 14 | if not 0.0 <= alpha <= 1.0: 15 | raise ValueError(f'Invalid slow update rate: {alpha}') 16 | if not 1 <= k: 17 | raise ValueError(f'Invalid lookahead steps: {k}') 18 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 19 | self.base_optimizer = base_optimizer 20 | self.param_groups = self.base_optimizer.param_groups 21 | self.defaults = base_optimizer.defaults 22 | self.defaults.update(defaults) 23 | self.state = defaultdict(dict) 24 | # manually add our defaults to the param groups 25 | for name, default in defaults.items(): 26 | for group in self.param_groups: 27 | group.setdefault(name, default) 28 | 29 | def update_slow(self, group): 30 | for fast_p in group["params"]: 31 | if fast_p.grad is None: 32 | continue 33 | param_state = self.state[fast_p] 34 | if 'slow_buffer' not in param_state: 35 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 36 | param_state['slow_buffer'].copy_(fast_p.data) 37 | slow = param_state['slow_buffer'] 38 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 39 | fast_p.data.copy_(slow) 40 | 41 | def sync_lookahead(self): 42 | for group in self.param_groups: 43 | self.update_slow(group) 44 | 45 | def step(self, closure=None): 46 | # print(self.k) 47 | #assert id(self.param_groups) == id(self.base_optimizer.param_groups) 48 | loss = self.base_optimizer.step(closure) 49 | for group in self.param_groups: 50 | group['lookahead_step'] += 1 51 | if group['lookahead_step'] % group['lookahead_k'] == 0: 52 | self.update_slow(group) 53 | return loss 54 | 55 | def state_dict(self): 56 | fast_state_dict = self.base_optimizer.state_dict() 57 | slow_state = { 58 | (id(k) if isinstance(k, torch.Tensor) else k): v 59 | for k, v in self.state.items() 60 | } 61 | fast_state = fast_state_dict['state'] 62 | param_groups = fast_state_dict['param_groups'] 63 | return { 64 | 'state': fast_state, 65 | 'slow_state': slow_state, 66 | 'param_groups': param_groups, 67 | } 68 | 69 | def load_state_dict(self, state_dict): 70 | fast_state_dict = { 71 | 'state': state_dict['state'], 72 | 'param_groups': state_dict['param_groups'], 73 | } 74 | self.base_optimizer.load_state_dict(fast_state_dict) 75 | 76 | # We want to restore the slow state, but share param_groups reference 77 | # with base_optimizer. 
This is a bit redundant but least code 78 | slow_state_new = False 79 | if 'slow_state' not in state_dict: 80 | print('Loading state_dict from optimizer without Lookahead applied.') 81 | state_dict['slow_state'] = defaultdict(dict) 82 | slow_state_new = True 83 | slow_state_dict = { 84 | 'state': state_dict['slow_state'], 85 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 86 | } 87 | super(Lookahead, self).load_state_dict(slow_state_dict) 88 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 89 | if slow_state_new: 90 | # reapply defaults to catch missing lookahead specific ones 91 | for name, default in self.defaults.items(): 92 | for group in self.param_groups: 93 | group.setdefault(name, default) 94 | 95 | def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs): 96 | adam = Adam(params, *args, **kwargs) 97 | return Lookahead(adam, alpha, k) 98 | -------------------------------------------------------------------------------- /optim/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Misc lr helper 6 | """ 7 | from torch.optim import Adam, Adamax 8 | 9 | from .adamw import AdamW 10 | from .rangerlars import RangerLars 11 | 12 | def build_optimizer(model, opts): 13 | param_optimizer = list(model.named_parameters()) 14 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 15 | optimizer_grouped_parameters = [ 16 | {'params': [p for n, p in param_optimizer 17 | if not any(nd in n for nd in no_decay)], 18 | 'weight_decay': opts.weight_decay}, 19 | {'params': [p for n, p in param_optimizer 20 | if any(nd in n for nd in no_decay)], 21 | 'weight_decay': 0.0} 22 | ] 23 | 24 | # currently Adam only 25 | if opts.optim == 'adam': 26 | OptimCls = Adam 27 | elif opts.optim == 'adamax': 28 | OptimCls = Adamax 29 | elif opts.optim == 'adamw': 30 | OptimCls = AdamW 31 | elif opts.optim == 'rangerlars': 32 | OptimCls = RangerLars 33 | else: 34 | raise ValueError('invalid optimizer') 35 | optimizer = OptimCls(optimizer_grouped_parameters, 36 | lr=opts.learning_rate, betas=opts.betas) 37 | return optimizer 38 | -------------------------------------------------------------------------------- /optim/radam.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam.py 2 | 3 | import math 4 | import torch 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | class RAdam(Optimizer): 8 | 9 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 10 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 11 | self.buffer = [[None, None, None] for ind in range(10)] 12 | super(RAdam, self).__init__(params, defaults) 13 | 14 | def __setstate__(self, state): 15 | super(RAdam, self).__setstate__(state) 16 | 17 | def step(self, closure=None): 18 | 19 | loss = None 20 | if closure is not None: 21 | loss = closure() 22 | 23 | for group in self.param_groups: 24 | 25 | for p in group['params']: 26 | if p.grad is None: 27 | continue 28 | grad = p.grad.data.float() 29 | if grad.is_sparse: 30 | raise RuntimeError('RAdam does not support sparse gradients') 31 | 32 | p_data_fp32 = p.data.float() 33 | 34 | state = self.state[p] 35 | 36 | if len(state) == 0: 37 | state['step'] = 0 38 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 39 | state['exp_avg_sq'] = 
torch.zeros_like(p_data_fp32) 40 | else: 41 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 42 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 43 | 44 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 45 | beta1, beta2 = group['betas'] 46 | 47 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 48 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 49 | 50 | state['step'] += 1 51 | buffered = self.buffer[int(state['step'] % 10)] 52 | if state['step'] == buffered[0]: 53 | N_sma, step_size = buffered[1], buffered[2] 54 | else: 55 | buffered[0] = state['step'] 56 | beta2_t = beta2 ** state['step'] 57 | N_sma_max = 2 / (1 - beta2) - 1 58 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 59 | buffered[1] = N_sma 60 | 61 | # more conservative since it's an approximated value 62 | if N_sma >= 5: 63 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 64 | else: 65 | step_size = 1.0 / (1 - beta1 ** state['step']) 66 | buffered[2] = step_size 67 | 68 | if group['weight_decay'] != 0: 69 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 70 | 71 | # more conservative since it's an approximated value 72 | if N_sma >= 5: 73 | denom = exp_avg_sq.sqrt().add_(group['eps']) 74 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 75 | else: 76 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 77 | 78 | p.data.copy_(p_data_fp32) 79 | 80 | return loss 81 | 82 | class PlainRAdam(Optimizer): 83 | 84 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 85 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 86 | 87 | super(PlainRAdam, self).__init__(params, defaults) 88 | 89 | def __setstate__(self, state): 90 | super(PlainRAdam, self).__setstate__(state) 91 | 92 | def step(self, closure=None): 93 | 94 | loss = None 95 | if closure is not None: 96 | loss = closure() 97 | 98 | for group in self.param_groups: 99 | 100 | for p in group['params']: 101 | if p.grad is None: 102 | continue 103 | grad = p.grad.data.float() 104 | if grad.is_sparse: 105 | raise RuntimeError('RAdam does not support sparse gradients') 106 | 107 | p_data_fp32 = p.data.float() 108 | 109 | state = self.state[p] 110 | 111 | if len(state) == 0: 112 | state['step'] = 0 113 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 114 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 115 | else: 116 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 117 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 118 | 119 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 120 | beta1, beta2 = group['betas'] 121 | 122 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 123 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 124 | 125 | state['step'] += 1 126 | beta2_t = beta2 ** state['step'] 127 | N_sma_max = 2 / (1 - beta2) - 1 128 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 129 | 130 | if group['weight_decay'] != 0: 131 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 132 | 133 | # more conservative since it's an approximated value 134 | if N_sma >= 5: 135 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 136 | denom = exp_avg_sq.sqrt().add_(group['eps']) 137 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 138 | else: 139 | 
step_size = group['lr'] / (1 - beta1 ** state['step']) 140 | p_data_fp32.add_(-step_size, exp_avg) 141 | 142 | p.data.copy_(p_data_fp32) 143 | 144 | return loss 145 | 146 | 147 | class AdamW(Optimizer): 148 | 149 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 150 | defaults = dict(lr=lr, betas=betas, eps=eps, 151 | weight_decay=weight_decay, warmup = warmup) 152 | super(AdamW, self).__init__(params, defaults) 153 | 154 | def __setstate__(self, state): 155 | super(AdamW, self).__setstate__(state) 156 | 157 | def step(self, closure=None): 158 | loss = None 159 | if closure is not None: 160 | loss = closure() 161 | 162 | for group in self.param_groups: 163 | 164 | for p in group['params']: 165 | if p.grad is None: 166 | continue 167 | grad = p.grad.data.float() 168 | if grad.is_sparse: 169 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 170 | 171 | p_data_fp32 = p.data.float() 172 | 173 | state = self.state[p] 174 | 175 | if len(state) == 0: 176 | state['step'] = 0 177 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 178 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 179 | else: 180 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 181 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 182 | 183 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 184 | beta1, beta2 = group['betas'] 185 | 186 | state['step'] += 1 187 | 188 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 189 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 190 | 191 | denom = exp_avg_sq.sqrt().add_(group['eps']) 192 | bias_correction1 = 1 - beta1 ** state['step'] 193 | bias_correction2 = 1 - beta2 ** state['step'] 194 | 195 | if group['warmup'] > state['step']: 196 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 197 | else: 198 | scheduled_lr = group['lr'] 199 | 200 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 201 | 202 | if group['weight_decay'] != 0: 203 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 204 | 205 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 206 | 207 | p.data.copy_(p_data_fp32) 208 | 209 | return loss 210 | -------------------------------------------------------------------------------- /optim/ralamb.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | from torch.optim.optimizer import Optimizer 3 | 4 | # RAdam + LARS 5 | class Ralamb(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 8 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 9 | self.buffer = [[None, None, None] for ind in range(10)] 10 | super(Ralamb, self).__init__(params, defaults) 11 | 12 | def __setstate__(self, state): 13 | super(Ralamb, self).__setstate__(state) 14 | 15 | def step(self, closure=None): 16 | 17 | loss = None 18 | if closure is not None: 19 | loss = closure() 20 | 21 | for group in self.param_groups: 22 | 23 | for p in group['params']: 24 | if p.grad is None: 25 | continue 26 | grad = p.grad.data.float() 27 | if grad.is_sparse: 28 | raise RuntimeError('Ralamb does not support sparse gradients') 29 | 30 | p_data_fp32 = p.data.float() 31 | 32 | state = self.state[p] 33 | 34 | if len(state) == 0: 35 | state['step'] = 0 36 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 37 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 38 | else: 39 | state['exp_avg'] = 
state['exp_avg'].type_as(p_data_fp32) 40 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 41 | 42 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 43 | beta1, beta2 = group['betas'] 44 | 45 | # Decay the first and second moment running average coefficient 46 | # m_t 47 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 48 | # v_t 49 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 50 | 51 | state['step'] += 1 52 | buffered = self.buffer[int(state['step'] % 10)] 53 | 54 | if state['step'] == buffered[0]: 55 | N_sma, radam_step_size = buffered[1], buffered[2] 56 | else: 57 | buffered[0] = state['step'] 58 | beta2_t = beta2 ** state['step'] 59 | N_sma_max = 2 / (1 - beta2) - 1 60 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 61 | buffered[1] = N_sma 62 | 63 | # more conservative since it's an approximated value 64 | if N_sma >= 5: 65 | radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 66 | else: 67 | radam_step_size = 1.0 / (1 - beta1 ** state['step']) 68 | buffered[2] = radam_step_size 69 | 70 | if group['weight_decay'] != 0: 71 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) 72 | 73 | # more conservative since it's an approximated value 74 | radam_step = p_data_fp32.clone() 75 | if N_sma >= 5: 76 | denom = exp_avg_sq.sqrt().add_(group['eps']) 77 | radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom) 78 | else: 79 | radam_step.add_(exp_avg, alpha=-radam_step_size * group['lr']) 80 | 81 | radam_norm = radam_step.pow(2).sum().sqrt() 82 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 83 | if weight_norm == 0 or radam_norm == 0: 84 | trust_ratio = 1 85 | else: 86 | trust_ratio = weight_norm / radam_norm 87 | 88 | state['weight_norm'] = weight_norm 89 | state['adam_norm'] = radam_norm 90 | state['trust_ratio'] = trust_ratio 91 | 92 | if N_sma >= 5: 93 | p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom) 94 | else: 95 | p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg) 96 | 97 | p.data.copy_(p_data_fp32) 98 | 99 | return loss 100 | -------------------------------------------------------------------------------- /optim/rangerlars.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | from torch.optim.optimizer import Optimizer 3 | import itertools as it 4 | from .lookahead import * 5 | from .ralamb import * 6 | 7 | # RAdam + LARS + LookAHead 8 | 9 | # Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py 10 | # RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20 11 | 12 | def RangerLars(params, alpha=0.5, k=6, *args, **kwargs): 13 | ralamb = Ralamb(params, *args, **kwargs) 14 | return Lookahead(ralamb, alpha, k) 15 | -------------------------------------------------------------------------------- /optim/sched.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
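Illustrative use of the helpers defined below (a sketch; assumes `opts` carries
learning_rate, lr_sched, warmup_steps and num_train_steps, and that the caller
writes the value back into the optimizer's param groups):

    lr_this_step = get_lr_sched(global_step, opts)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_this_step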
4 | 5 | optimizer learning rate scheduling helpers 6 | """ 7 | from math import ceil 8 | 9 | 10 | def noam_schedule(step, warmup_step=4000): 11 | """ original Transformer schedule""" 12 | if step <= warmup_step: 13 | return step / warmup_step 14 | return (warmup_step ** 0.5) * (step ** -0.5) 15 | 16 | 17 | def warmup_linear(step, warmup_step, tot_step): 18 | """ BERT schedule """ 19 | if step < warmup_step: 20 | return step / warmup_step 21 | return max(0, (tot_step-step)/(tot_step-warmup_step)) 22 | 23 | def warmup_inverse_sqrt(step, warmup_step, tot_step): 24 | """Decay the LR based on the inverse square root of the update number. 25 | We also support a warmup phase where we linearly increase the learning rate 26 | from some initial learning rate (``--warmup-init-lr``) until the configured 27 | learning rate (``--lr``). Thereafter we decay proportional to the number of 28 | updates, with a decay factor set to align with the configured learning rate. 29 | 30 | During warmup:: 31 | 32 | lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates) 33 | lr = lrs[update_num] 34 | 35 | After warmup:: 36 | 37 | decay_factor = cfg.lr * sqrt(cfg.warmup_updates) 38 | lr = decay_factor / sqrt(update_num) 39 | """ 40 | if step < warmup_step: 41 | return step / warmup_step 42 | else: 43 | return warmup_step**0.5 * step**-0.5 44 | 45 | 46 | def get_lr_sched(global_step, opts): 47 | # learning rate scheduling 48 | if opts.lr_sched == 'linear': 49 | func = warmup_linear 50 | elif opts.lr_sched == 'inverse_sqrt': 51 | func = warmup_inverse_sqrt 52 | else: 53 | raise NotImplementedError(f'invalid lr scheduler {opts.lr_sched}') 54 | 55 | lr_this_step = opts.learning_rate * func( 56 | global_step, opts.warmup_steps, opts.num_train_steps 57 | ) 58 | if lr_this_step <= 0: 59 | lr_this_step = 1e-8 60 | return lr_this_step 61 | -------------------------------------------------------------------------------- /preprocess/evaluate_dataset_keysteps.py: -------------------------------------------------------------------------------- 1 | from core.actioner import BaseActioner 2 | from core.environments import RLBenchEnv 3 | from typing import Tuple, Dict, List 4 | 5 | import os 6 | import numpy as np 7 | import itertools 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | import collections 11 | import tap 12 | 13 | import lmdb 14 | import msgpack 15 | import msgpack_numpy 16 | msgpack_numpy.patch() 17 | 18 | 19 | class Arguments(tap.Tap): 20 | microstep_data_dir: Path = "data/train_dataset/microsteps/seed0" 21 | keystep_data_dir: Path = "data/train_dataset/keysteps/seed0" 22 | 23 | tasks: Tuple[str, ...] = ("pick_up_cup",) 24 | cameras: Tuple[str, ...] 
= ("left_shoulder", "right_shoulder", "wrist") 25 | 26 | max_variations: int = 1 27 | offset: int = 0 28 | 29 | headless: bool = False 30 | gripper_pose: str = None 31 | max_tries: int = 10 32 | 33 | log_dir: Path = None 34 | 35 | 36 | class KeystepActioner(BaseActioner): 37 | def __init__(self, keystep_data_dir) -> None: 38 | self.lmdb_env = lmdb.open(str(keystep_data_dir), readonly=True) 39 | self.lmdb_txn = self.lmdb_env.begin() 40 | 41 | def __exit__(self): 42 | self.lmdb_env.close() 43 | 44 | def reset(self, task_str, variation, instructions, demo_id): 45 | super().reset(task_str, variation, instructions, demo_id) 46 | 47 | value = self.lmdb_txn.get(demo_id.encode('ascii')) 48 | value = msgpack.unpackb(value) 49 | self.actions = value['action'][1:] 50 | 51 | def predict(self, taskvar_id, step_id, *args, **kwargs): 52 | out = {} 53 | if step_id < len(self.actions): 54 | out['action'] = self.actions[step_id] 55 | else: 56 | out['action'] = np.zeros((8, ), dtype=np.float32) 57 | print(self.demo_id, step_id, len(self.actions)) 58 | return out 59 | 60 | 61 | def evaluate_keysteps(args): 62 | env = RLBenchEnv( 63 | data_path=args.microstep_data_dir, 64 | apply_rgb=True, 65 | apply_pc=True, 66 | apply_cameras=args.cameras, 67 | headless=args.headless, 68 | gripper_pose=args.gripper_pose, 69 | ) 70 | 71 | variations = range(args.offset, args.max_variations) 72 | 73 | taskvar_id = 0 74 | for task_str in args.tasks: 75 | for variation in variations: 76 | actioner = KeystepActioner( 77 | args.keystep_data_dir / f"{task_str}+{variation}") 78 | episodes_dir = args.microstep_data_dir / \ 79 | task_str / f"variation{variation}" / "episodes" 80 | 81 | demo_keys, demos = [], [] 82 | for ep in tqdm(episodes_dir.glob('episode*')): 83 | episode_id = int(ep.stem[7:]) 84 | demo = env.get_demo(task_str, variation, episode_id) 85 | demo_keys.append(f'episode{episode_id}') 86 | demos.append(demo) 87 | 88 | success_rate = env.evaluate( 89 | taskvar_id, 90 | task_str, 91 | actioner=actioner, 92 | max_episodes=10, 93 | variation=variation, 94 | num_demos=len(demos), 95 | demos=demos, 96 | demo_keys=demo_keys, 97 | log_dir=args.log_dir, 98 | max_tries=args.max_tries, 99 | save_image=False, 100 | ) 101 | 102 | print("Testing Success Rate {}: {:.04f}".format( 103 | task_str, success_rate)) 104 | 105 | if __name__ == '__main__': 106 | args = Arguments().parse_args() 107 | evaluate_keysteps(args) 108 | -------------------------------------------------------------------------------- /preprocess/generate_dataset_keysteps.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, List 2 | 3 | import os 4 | import numpy as np 5 | import itertools 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | import collections 9 | import tap 10 | 11 | import lmdb 12 | import msgpack 13 | import msgpack_numpy 14 | msgpack_numpy.patch() 15 | 16 | from utils.keystep_detection import keypoint_discovery 17 | from utils.coord_transforms import convert_gripper_pose_world_to_image 18 | 19 | from core.environments import RLBenchEnv 20 | 21 | 22 | class Arguments(tap.Tap): 23 | microstep_data_dir: Path = "data/train_dataset/microsteps/seed0" 24 | keystep_data_dir: Path = "data/train_dataset/keysteps/seed0" 25 | 26 | tasks: Tuple[str, ...] = ("pick_up_cup",) 27 | cameras: Tuple[str, ...] 
= ("left_shoulder", "right_shoulder", "wrist") 28 | 29 | max_variations: int = 1 30 | offset: int = 0 31 | 32 | 33 | 34 | def get_observation(task_str: str, variation: int, episode: int, env: RLBenchEnv): 35 | demo = env.get_demo(task_str, variation, episode) 36 | 37 | key_frames = keypoint_discovery(demo) 38 | key_frames.insert(0, 0) 39 | 40 | state_dict_ls = collections.defaultdict(list) 41 | for f in key_frames: 42 | state_dict = env.get_observation(demo._observations[f]) 43 | for k, v in state_dict.items(): 44 | if len(v) > 0: 45 | # rgb: (N: num_of_cameras, H, W, C); gripper: (7+1, ) 46 | state_dict_ls[k].append(v) 47 | 48 | for k, v in state_dict_ls.items(): 49 | state_dict_ls[k] = np.stack(v, 0) # (T, N, H, W, C) 50 | 51 | action_ls = state_dict_ls['gripper'] # (T, 7+1) 52 | del state_dict_ls['gripper'] 53 | 54 | return demo, key_frames, state_dict_ls, action_ls 55 | 56 | 57 | def generate_keystep_dataset(args: Arguments): 58 | # load RLBench environment 59 | rlbench_env = RLBenchEnv( 60 | data_path=args.microstep_data_dir, 61 | apply_rgb=True, 62 | apply_pc=True, 63 | apply_cameras=args.cameras, 64 | ) 65 | 66 | tasks = args.tasks 67 | variations = range(args.offset, args.max_variations) 68 | 69 | for task_str, variation in itertools.product(tasks, variations): 70 | episodes_dir = args.microstep_data_dir / task_str / f"variation{variation}" / "episodes" 71 | 72 | output_dir = args.keystep_data_dir / f"{task_str}+{variation}" 73 | output_dir.mkdir(parents=True, exist_ok=True) 74 | 75 | lmdb_env = lmdb.open(str(output_dir), map_size=int(1024**4)) 76 | 77 | for ep in tqdm(episodes_dir.glob('episode*')): 78 | episode = int(ep.stem[7:]) 79 | try: 80 | demo, key_frameids, state_dict_ls, action_ls = get_observation( 81 | task_str, variation, episode, rlbench_env 82 | ) 83 | except (FileNotFoundError, RuntimeError, IndexError) as e: 84 | print(e) 85 | return 86 | 87 | gripper_pose = [] 88 | for key_frameid in key_frameids: 89 | gripper_pose.append({ 90 | cam: convert_gripper_pose_world_to_image(demo[key_frameid], cam) for cam in args.cameras 91 | }) 92 | 93 | outs = { 94 | 'key_frameids': key_frameids, 95 | 'rgb': state_dict_ls['rgb'], # (T, N, H, W, 3) 96 | 'pc': state_dict_ls['pc'], # (T, N, H, W, 3) 97 | 'action': action_ls, # (T, A) 98 | 'gripper_pose': gripper_pose, # [T of dict] 99 | } 100 | 101 | txn = lmdb_env.begin(write=True) 102 | txn.put(f'episode{episode}'.encode('ascii'), msgpack.packb(outs)) 103 | txn.commit() 104 | 105 | lmdb_env.close() 106 | 107 | 108 | if __name__ == "__main__": 109 | args = Arguments().parse_args() 110 | generate_keystep_dataset(args) 111 | -------------------------------------------------------------------------------- /preprocess/generate_dataset_microsteps.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from RLBench/tools/dataset_generator.py 3 | ''' 4 | from multiprocessing import Process, Manager 5 | 6 | from pyrep.const import RenderMode 7 | 8 | from rlbench import ObservationConfig 9 | from rlbench.action_modes.action_mode import MoveArmThenGripper 10 | from rlbench.action_modes.arm_action_modes import JointVelocity 11 | from rlbench.action_modes.gripper_action_modes import Discrete 12 | from rlbench.backend.utils import task_file_to_task_class 13 | from rlbench.environment import Environment 14 | import rlbench.backend.task as task 15 | 16 | import os 17 | import pickle 18 | import json 19 | from PIL import Image 20 | from rlbench.backend import utils 21 | from rlbench.backend.const 
import * 22 | import numpy as np 23 | import random 24 | 25 | from absl import app 26 | from absl import flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string('save_path', 31 | 'data/train_dataset/microsteps/seed{seed}', 32 | 'Where to save the demos.') 33 | flags.DEFINE_string('all_task_file', 34 | 'assets/rlbench_all_tasks.json', 35 | 'A json file with list of all rlbench tasks.') 36 | flags.DEFINE_list('tasks', [], 37 | 'The tasks to collect. If empty, all tasks are collected.') 38 | flags.DEFINE_list('image_size', [128, 128], 39 | 'The size of the images tp save.') 40 | flags.DEFINE_enum('renderer', 'opengl3', ['opengl', 'opengl3'], 41 | 'The renderer to use. opengl does not include shadows, ' 42 | 'but is faster.') 43 | flags.DEFINE_integer('processes', 1, 44 | 'The number of parallel processes during collection.') 45 | flags.DEFINE_integer('episodes_per_task', 10, 46 | 'The number of episodes to collect per task.') 47 | flags.DEFINE_integer('variations', -1, 48 | 'Number of variations to collect per task. -1 for all.') 49 | flags.DEFINE_integer('offset', 0, 50 | 'First variation id.') 51 | flags.DEFINE_boolean('state', False, 52 | 'Record the state (not available for all tasks).') 53 | flags.DEFINE_integer('seed', 0, 54 | 'Seed of randomness') 55 | 56 | 57 | def check_and_make(dir): 58 | if not os.path.exists(dir): 59 | os.makedirs(dir) 60 | 61 | 62 | def save_demo(demo, example_path): 63 | 64 | # Save image data first, and then None the image data, and pickle 65 | left_shoulder_rgb_path = os.path.join( 66 | example_path, LEFT_SHOULDER_RGB_FOLDER) 67 | left_shoulder_depth_path = os.path.join( 68 | example_path, LEFT_SHOULDER_DEPTH_FOLDER) 69 | left_shoulder_mask_path = os.path.join( 70 | example_path, LEFT_SHOULDER_MASK_FOLDER) 71 | right_shoulder_rgb_path = os.path.join( 72 | example_path, RIGHT_SHOULDER_RGB_FOLDER) 73 | right_shoulder_depth_path = os.path.join( 74 | example_path, RIGHT_SHOULDER_DEPTH_FOLDER) 75 | right_shoulder_mask_path = os.path.join( 76 | example_path, RIGHT_SHOULDER_MASK_FOLDER) 77 | overhead_rgb_path = os.path.join( 78 | example_path, OVERHEAD_RGB_FOLDER) 79 | overhead_depth_path = os.path.join( 80 | example_path, OVERHEAD_DEPTH_FOLDER) 81 | overhead_mask_path = os.path.join( 82 | example_path, OVERHEAD_MASK_FOLDER) 83 | wrist_rgb_path = os.path.join(example_path, WRIST_RGB_FOLDER) 84 | wrist_depth_path = os.path.join(example_path, WRIST_DEPTH_FOLDER) 85 | wrist_mask_path = os.path.join(example_path, WRIST_MASK_FOLDER) 86 | front_rgb_path = os.path.join(example_path, FRONT_RGB_FOLDER) 87 | front_depth_path = os.path.join(example_path, FRONT_DEPTH_FOLDER) 88 | front_mask_path = os.path.join(example_path, FRONT_MASK_FOLDER) 89 | 90 | check_and_make(left_shoulder_rgb_path) 91 | check_and_make(left_shoulder_depth_path) 92 | check_and_make(left_shoulder_mask_path) 93 | check_and_make(right_shoulder_rgb_path) 94 | check_and_make(right_shoulder_depth_path) 95 | check_and_make(right_shoulder_mask_path) 96 | check_and_make(overhead_rgb_path) 97 | check_and_make(overhead_depth_path) 98 | check_and_make(overhead_mask_path) 99 | check_and_make(wrist_rgb_path) 100 | check_and_make(wrist_depth_path) 101 | check_and_make(wrist_mask_path) 102 | check_and_make(front_rgb_path) 103 | check_and_make(front_depth_path) 104 | check_and_make(front_mask_path) 105 | 106 | for i, obs in enumerate(demo): 107 | left_shoulder_rgb = Image.fromarray(obs.left_shoulder_rgb) 108 | left_shoulder_depth = utils.float_array_to_rgb_image( 109 | obs.left_shoulder_depth, 
scale_factor=DEPTH_SCALE) 110 | left_shoulder_mask = Image.fromarray( 111 | (obs.left_shoulder_mask * 255).astype(np.uint8)) 112 | right_shoulder_rgb = Image.fromarray(obs.right_shoulder_rgb) 113 | right_shoulder_depth = utils.float_array_to_rgb_image( 114 | obs.right_shoulder_depth, scale_factor=DEPTH_SCALE) 115 | right_shoulder_mask = Image.fromarray( 116 | (obs.right_shoulder_mask * 255).astype(np.uint8)) 117 | overhead_rgb = Image.fromarray(obs.overhead_rgb) 118 | overhead_depth = utils.float_array_to_rgb_image( 119 | obs.overhead_depth, scale_factor=DEPTH_SCALE) 120 | overhead_mask = Image.fromarray( 121 | (obs.overhead_mask * 255).astype(np.uint8)) 122 | wrist_rgb = Image.fromarray(obs.wrist_rgb) 123 | wrist_depth = utils.float_array_to_rgb_image( 124 | obs.wrist_depth, scale_factor=DEPTH_SCALE) 125 | wrist_mask = Image.fromarray((obs.wrist_mask * 255).astype(np.uint8)) 126 | front_rgb = Image.fromarray(obs.front_rgb) 127 | front_depth = utils.float_array_to_rgb_image( 128 | obs.front_depth, scale_factor=DEPTH_SCALE) 129 | front_mask = Image.fromarray((obs.front_mask * 255).astype(np.uint8)) 130 | 131 | left_shoulder_rgb.save( 132 | os.path.join(left_shoulder_rgb_path, IMAGE_FORMAT % i)) 133 | left_shoulder_depth.save( 134 | os.path.join(left_shoulder_depth_path, IMAGE_FORMAT % i)) 135 | left_shoulder_mask.save( 136 | os.path.join(left_shoulder_mask_path, IMAGE_FORMAT % i)) 137 | right_shoulder_rgb.save( 138 | os.path.join(right_shoulder_rgb_path, IMAGE_FORMAT % i)) 139 | right_shoulder_depth.save( 140 | os.path.join(right_shoulder_depth_path, IMAGE_FORMAT % i)) 141 | right_shoulder_mask.save( 142 | os.path.join(right_shoulder_mask_path, IMAGE_FORMAT % i)) 143 | overhead_rgb.save( 144 | os.path.join(overhead_rgb_path, IMAGE_FORMAT % i)) 145 | overhead_depth.save( 146 | os.path.join(overhead_depth_path, IMAGE_FORMAT % i)) 147 | overhead_mask.save( 148 | os.path.join(overhead_mask_path, IMAGE_FORMAT % i)) 149 | wrist_rgb.save(os.path.join(wrist_rgb_path, IMAGE_FORMAT % i)) 150 | wrist_depth.save(os.path.join(wrist_depth_path, IMAGE_FORMAT % i)) 151 | wrist_mask.save(os.path.join(wrist_mask_path, IMAGE_FORMAT % i)) 152 | front_rgb.save(os.path.join(front_rgb_path, IMAGE_FORMAT % i)) 153 | front_depth.save(os.path.join(front_depth_path, IMAGE_FORMAT % i)) 154 | front_mask.save(os.path.join(front_mask_path, IMAGE_FORMAT % i)) 155 | 156 | # We save the images separately, so set these to None for pickling. 
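# (Only the low-dimensional state is kept in the pickle written at the end of
#  this function; the RGB, depth and mask frames have already been saved to the
#  per-camera folders above.)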
157 | obs.left_shoulder_rgb = None 158 | obs.left_shoulder_depth = None 159 | obs.left_shoulder_point_cloud = None 160 | obs.left_shoulder_mask = None 161 | obs.right_shoulder_rgb = None 162 | obs.right_shoulder_depth = None 163 | obs.right_shoulder_point_cloud = None 164 | obs.right_shoulder_mask = None 165 | obs.overhead_rgb = None 166 | obs.overhead_depth = None 167 | obs.overhead_point_cloud = None 168 | obs.overhead_mask = None 169 | obs.wrist_rgb = None 170 | obs.wrist_depth = None 171 | obs.wrist_point_cloud = None 172 | obs.wrist_mask = None 173 | obs.front_rgb = None 174 | obs.front_depth = None 175 | obs.front_point_cloud = None 176 | obs.front_mask = None 177 | 178 | # Save the low-dimension data 179 | with open(os.path.join(example_path, LOW_DIM_PICKLE), 'wb') as f: 180 | pickle.dump(demo, f) 181 | 182 | 183 | def run(i, lock, task_index, variation_count, results, file_lock, tasks): 184 | """Each thread will choose one task and variation, and then gather 185 | all the episodes_per_task for that variation.""" 186 | 187 | # Initialise each thread with random seed 188 | # np.random.seed(None) 189 | np.random.seed(FLAGS.seed) 190 | random.seed(FLAGS.seed) 191 | num_tasks = len(tasks) 192 | 193 | img_size = list(map(int, FLAGS.image_size)) 194 | 195 | obs_config = ObservationConfig() 196 | obs_config.set_all(True) 197 | obs_config.right_shoulder_camera.image_size = img_size 198 | obs_config.left_shoulder_camera.image_size = img_size 199 | obs_config.overhead_camera.image_size = img_size 200 | obs_config.wrist_camera.image_size = img_size 201 | obs_config.front_camera.image_size = img_size 202 | 203 | # Store depth as 0 - 1 204 | obs_config.right_shoulder_camera.depth_in_meters = False 205 | obs_config.left_shoulder_camera.depth_in_meters = False 206 | obs_config.overhead_camera.depth_in_meters = False 207 | obs_config.wrist_camera.depth_in_meters = False 208 | obs_config.front_camera.depth_in_meters = False 209 | 210 | # We want to save the masks as rgb encodings. 
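# (masks_as_one_channel=False asks RLBench for colour-coded 3-channel mask
#  images rather than single-channel maps, which is what the (mask * 255)
#  uint8 conversion in save_demo() above expects.)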
211 | obs_config.left_shoulder_camera.masks_as_one_channel = False 212 | obs_config.right_shoulder_camera.masks_as_one_channel = False 213 | obs_config.overhead_camera.masks_as_one_channel = False 214 | obs_config.wrist_camera.masks_as_one_channel = False 215 | obs_config.front_camera.masks_as_one_channel = False 216 | 217 | if FLAGS.renderer == 'opengl': 218 | obs_config.right_shoulder_camera.render_mode = RenderMode.OPENGL 219 | obs_config.left_shoulder_camera.render_mode = RenderMode.OPENGL 220 | obs_config.overhead_camera.render_mode = RenderMode.OPENGL 221 | obs_config.wrist_camera.render_mode = RenderMode.OPENGL 222 | obs_config.front_camera.render_mode = RenderMode.OPENGL 223 | 224 | rlbench_env = Environment( 225 | action_mode=MoveArmThenGripper(JointVelocity(), Discrete()), 226 | obs_config=obs_config, 227 | headless=True) 228 | rlbench_env.launch() 229 | 230 | task_env = None 231 | 232 | tasks_with_problems = results[i] = '' 233 | 234 | while True: 235 | # Figure out what task/variation this thread is going to do 236 | with lock: 237 | 238 | if task_index.value >= num_tasks: 239 | print('Process', i, 'finished') 240 | break 241 | 242 | my_variation_count = variation_count.value 243 | t = tasks[task_index.value] 244 | task_env = rlbench_env.get_task(t) 245 | var_target = task_env.variation_count() 246 | if FLAGS.variations >= 0: 247 | var_target = np.minimum(FLAGS.variations+FLAGS.offset, var_target) 248 | if my_variation_count >= var_target: 249 | # If we have reached the required number of variations for this 250 | # task, then move on to the next task. 251 | variation_count.value = my_variation_count = FLAGS.offset 252 | task_index.value += 1 253 | 254 | variation_count.value += 1 255 | if task_index.value >= num_tasks: 256 | print('Process', i, 'finished') 257 | break 258 | t = tasks[task_index.value] 259 | 260 | task_env = rlbench_env.get_task(t) 261 | task_env.set_variation(my_variation_count) 262 | descriptions, obs = task_env.reset() 263 | 264 | variation_path = os.path.join( 265 | FLAGS.save_path, task_env.get_name(), 266 | VARIATIONS_FOLDER % my_variation_count 267 | ) 268 | print(variation_path) 269 | 270 | check_and_make(variation_path) 271 | 272 | with open(os.path.join( 273 | variation_path, VARIATION_DESCRIPTIONS), 'wb') as f: 274 | pickle.dump(descriptions, f) 275 | 276 | episodes_path = os.path.join(variation_path, EPISODES_FOLDER) 277 | check_and_make(episodes_path) 278 | 279 | abort_variation = False 280 | for ex_idx in range(FLAGS.episodes_per_task): 281 | print('Process', i, '// Task:', task_env.get_name(), 282 | '// Variation:', my_variation_count, '// Demo:', ex_idx) 283 | attempts = 10 284 | while attempts > 0: 285 | episode_path = os.path.join(episodes_path, EPISODE_FOLDER % ex_idx) 286 | if os.path.exists(episode_path): 287 | break 288 | try: 289 | # TODO: for now we do the explicit looping. 290 | demo, = task_env.get_demos( 291 | amount=1, 292 | live_demos=True) 293 | except Exception as e: 294 | attempts -= 1 295 | if attempts > 0: 296 | continue 297 | problem = ( 298 | 'Process %d failed collecting task %s (variation: %d, ' 299 | 'example: %d). 
Skipping this task/variation.\n%s\n' % ( 300 | i, task_env.get_name(), my_variation_count, ex_idx, 301 | str(e)) 302 | ) 303 | print(problem) 304 | tasks_with_problems += problem 305 | abort_variation = True 306 | break 307 | with file_lock: 308 | save_demo(demo, episode_path) 309 | break 310 | if abort_variation: 311 | break 312 | 313 | results[i] = tasks_with_problems 314 | rlbench_env.shutdown() 315 | 316 | 317 | def main(argv): 318 | 319 | FLAGS.save_path = FLAGS.save_path.format(seed=FLAGS.seed) 320 | 321 | with open(FLAGS.all_task_file, 'r') as f: 322 | task_files = json.load(f) 323 | 324 | if len(FLAGS.tasks) > 0: 325 | for t in FLAGS.tasks: 326 | if t not in task_files: 327 | raise ValueError('Task %s not recognised!.' % t) 328 | task_files = FLAGS.tasks 329 | 330 | tasks = [task_file_to_task_class(t) for t in task_files] 331 | 332 | manager = Manager() 333 | 334 | result_dict = manager.dict() 335 | file_lock = manager.Lock() 336 | 337 | task_index = manager.Value('i', 0) 338 | variation_count = manager.Value('i', FLAGS.offset) 339 | lock = manager.Lock() 340 | 341 | check_and_make(FLAGS.save_path) 342 | 343 | processes = [Process( 344 | target=run, args=( 345 | i, lock, task_index, variation_count, result_dict, file_lock, 346 | tasks)) 347 | for i in range(FLAGS.processes)] 348 | [t.start() for t in processes] 349 | [t.join() for t in processes] 350 | 351 | print('Data collection done!') 352 | for i in range(FLAGS.processes): 353 | print(result_dict[i]) 354 | 355 | 356 | if __name__ == '__main__': 357 | app.run(main) 358 | -------------------------------------------------------------------------------- /preprocess/generate_instructions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate instruction embeddings 3 | """ 4 | import os 5 | import json 6 | import jsonlines 7 | from tqdm import tqdm 8 | 9 | import lmdb 10 | import msgpack 11 | import msgpack_numpy 12 | msgpack_numpy.patch() 13 | 14 | import torch 15 | 16 | from rlbench.action_modes.action_mode import MoveArmThenGripper 17 | from rlbench.action_modes.arm_action_modes import JointVelocity 18 | from rlbench.action_modes.gripper_action_modes import Discrete 19 | from rlbench.environment import Environment 20 | from rlbench.utils import name_to_task_class 21 | 22 | import transformers 23 | 24 | INSTRUCTION_FILE = 'assets/taskvar_instructions.jsonl' 25 | 26 | BROKEN_TASKS = set([ 27 | "empty_container", 28 | "set_the_table", 29 | ]) 30 | 31 | 32 | def generate_all_instructions(): 33 | if os.path.exists(INSTRUCTION_FILE): 34 | exist_tasks = set() 35 | with jsonlines.open(INSTRUCTION_FILE) as f: 36 | for item in f: 37 | exist_tasks.add(item['task']) 38 | print('Exist task', len(exist_tasks)) 39 | 40 | all_tasks = json.load(open('assets/all_tasks.json')) 41 | 42 | action_mode = MoveArmThenGripper( 43 | arm_action_mode=JointVelocity(), 44 | gripper_action_mode=Discrete() 45 | ) 46 | env = Environment(action_mode) 47 | env.launch() 48 | 49 | outfile = jsonlines.open(INSTRUCTION_FILE, 'a', flush=True) 50 | 51 | for task in tqdm(all_tasks): 52 | if task in BROKEN_TASKS or task in exist_tasks: 53 | continue 54 | print(task) 55 | outs = {'task': task, 'variations': {}} 56 | task_env = env.get_task(name_to_task_class(task)) 57 | num_variations = task_env.variation_count() 58 | for v in tqdm(range(num_variations)): 59 | try: 60 | task_env.set_variation(v) 61 | descriptions, obs = task_env.reset() 62 | outs['variations'][v] = descriptions 63 | except Exception as e: 64 | print('Error', 
task, v, e) 65 | outfile.write(outs) 66 | 67 | env.shutdown() 68 | outfile.close() 69 | 70 | 71 | def load_all_instructions(): 72 | data = [] 73 | with jsonlines.open(INSTRUCTION_FILE, 'r') as f: 74 | for item in f: 75 | data.append(item) 76 | return data 77 | 78 | def load_text_encoder(encoder: str): 79 | if encoder == "bert": 80 | tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") 81 | model = transformers.BertModel.from_pretrained("bert-base-uncased") 82 | elif encoder == "clip": 83 | model_name = "openai/clip-vit-base-patch32" 84 | tokenizer = transformers.CLIPTokenizer.from_pretrained(model_name) 85 | model = transformers.CLIPTextModel.from_pretrained(model_name) 86 | else: 87 | raise ValueError(f"Unexpected encoder {encoder}") 88 | 89 | return tokenizer, model 90 | 91 | def main(args): 92 | taskvar_instrs = load_all_instructions() 93 | 94 | tokenizer, model = load_text_encoder(args.encoder) 95 | model = model.to(args.device) 96 | 97 | os.makedirs(args.output_file, exist_ok=True) 98 | lmdb_env = lmdb.open(args.output_file, map_size=int(1024**4)) 99 | 100 | for item in tqdm(taskvar_instrs): 101 | task = item['task'] 102 | for variation, instructions in item['variations'].items(): 103 | key = '%s+%s' % (task, variation) 104 | 105 | instr_embeds = [] 106 | for instr in instructions: 107 | tokens = tokenizer(instr, padding=False)["input_ids"] 108 | if len(tokens) > 77: 109 | print('Too long', task, variation, instr) 110 | 111 | tokens = torch.LongTensor(tokens).unsqueeze(0).to(args.device) 112 | with torch.no_grad(): 113 | embed = model(tokens).last_hidden_state.squeeze(0) 114 | instr_embeds.append(embed.data.cpu().numpy()) 115 | 116 | txn = lmdb_env.begin(write=True) 117 | txn.put(key.encode('ascii'), msgpack.packb(instr_embeds)) 118 | txn.commit() 119 | 120 | lmdb_env.close() 121 | 122 | 123 | if __name__ == "__main__": 124 | import argparse 125 | 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument('--encoder', choices=['bert', 'clip'], default='clip') 128 | parser.add_argument('--device', default='cuda') 129 | parser.add_argument('--output_file', required=True) 130 | parser.add_argument('--generate_all_instructions', action='store_true', default=False) 131 | args = parser.parse_args() 132 | 133 | if args.generate_all_instructions: 134 | generate_all_instructions() 135 | main(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.6.1 2 | filelock==3.12.2 3 | jsonlines==3.1.0 4 | lmdb==1.4.1 5 | matplotlib==3.5.1 6 | msgpack_numpy==0.4.8 7 | msgpack_python==0.5.6 8 | numpy==1.23.5 9 | opencv_python==4.7.0.72 10 | opencv_python_headless==4.7.0.72 11 | Pillow==9.4.0 12 | tensorboardX==2.6 13 | torch==1.13.0 14 | torchvision==0.14.0 15 | tqdm==4.62.3 16 | transformers==4.19.4 17 | typed_argument_parser==1.8.0 18 | yacs==0.1.8 19 | absl-py==1.4.0 20 | -------------------------------------------------------------------------------- /summarize_tst_results.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import jsonlines 5 | import collections 6 | import tap 7 | 8 | 9 | class Arguments(tap.Tap): 10 | result_file: str 11 | 12 | 13 | def main(args): 14 | results = collections.defaultdict(list) 15 | with jsonlines.open(args.result_file, 'r') as f: 16 | for item in f: 17 | results[item['checkpoint']].append((item['task'], item['sr'])) 
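            # (each line of results.jsonl is expected to provide 'checkpoint',
            #  'task' and 'sr' (per-task success rate) fields)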
18 | 19 | for ckpt, res in results.items(): 20 | print('\n', ckpt) 21 | print(','.join([x[0] for x in res])) 22 | 23 | print(','.join(['%.2f' % (x[1]*100) for x in res])) 24 | 25 | print(np.mean([x[1] for x in res]) * 100) 26 | 27 | 28 | if __name__ == '__main__': 29 | args = Arguments().parse_args() 30 | main(args) 31 | -------------------------------------------------------------------------------- /summarize_val_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import jsonlines 4 | import collections 5 | import tap 6 | 7 | 8 | class Arguments(tap.Tap): 9 | result_file: str 10 | 11 | 12 | def main(args): 13 | results = collections.defaultdict(list) 14 | with jsonlines.open(args.result_file, 'r') as f: 15 | for item in f: 16 | results[item['checkpoint']].append(item['sr']) 17 | 18 | avg_results = [] 19 | for k, v in results.items(): 20 | print(k, len(v), np.mean(v)) 21 | avg_results.append((k, np.mean(v))) 22 | 23 | print() 24 | print('Best checkpoint and SR') 25 | avg_results.sort(key=lambda x: -x[1]) 26 | print(avg_results[0]) 27 | 28 | 29 | if __name__ == '__main__': 30 | args = Arguments().parse_args() 31 | main(args) 32 | -------------------------------------------------------------------------------- /train_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | import time 6 | from collections import defaultdict 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.distributed as dist 13 | 14 | from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file 15 | from utils.save import ModelSaver, save_training_meta 16 | from utils.misc import NoOp, set_dropout, set_random_seed, set_cuda, wrap_model 17 | from utils.distributed import all_gather 18 | 19 | from optim import get_lr_sched 20 | from optim.misc import build_optimizer 21 | 22 | from config.default import get_config 23 | from dataloaders.loader import build_dataloader 24 | 25 | from dataloaders.keystep_dataset import ( 26 | KeystepDataset, stepwise_collate_fn, episode_collate_fn 27 | ) 28 | from models.plain_unet import PlainUNet 29 | from models.transformer_unet import TransformerUNet 30 | 31 | import warnings 32 | warnings.filterwarnings("ignore") 33 | 34 | 35 | dataset_factory = { 36 | 'keystep_stepwise': (KeystepDataset, stepwise_collate_fn), 37 | 'keystep_episode': (KeystepDataset, episode_collate_fn), 38 | } 39 | model_factory = { 40 | 'PlainUNet': PlainUNet, 41 | 'TransformerUNet': TransformerUNet, 42 | } 43 | 44 | 45 | def main(config): 46 | config.defrost() 47 | default_gpu, n_gpu, device = set_cuda(config) 48 | # config.freeze() 49 | 50 | if default_gpu: 51 | LOGGER.info( 52 | 'device: {} n_gpu: {}, distributed training: {}'.format( 53 | device, n_gpu, bool(config.local_rank != -1) 54 | ) 55 | ) 56 | 57 | seed = config.SEED 58 | if config.local_rank != -1: 59 | seed += config.rank 60 | set_random_seed(seed) 61 | 62 | if type(config.DATASET.taskvars) is str: 63 | config.DATASET.taskvars = [config.DATASET.taskvars] 64 | 65 | # load data training set 66 | dataset_class, dataset_collate_fn = dataset_factory[config.DATASET.dataset_class] 67 | 68 | dataset = dataset_class(**config.DATASET) 69 | data_loader, pre_epoch = build_dataloader( 70 | dataset, dataset_collate_fn, True, config 71 | ) 72 | LOGGER.info(f'#num_steps_per_epoch: {len(data_loader)}') 73 
| if config.num_train_steps is None: 74 | config.num_train_steps = len(data_loader) * config.num_epochs 75 | else: 76 | assert config.num_epochs is None, 'cannot set num_train_steps and num_epochs at the same time.' 77 | config.num_epochs = int( 78 | np.ceil(config.num_train_steps / len(data_loader))) 79 | config.freeze() 80 | 81 | # setup loggers 82 | if default_gpu: 83 | save_training_meta(config) 84 | TB_LOGGER.create(os.path.join(config.output_dir, 'logs')) 85 | pbar = tqdm(total=config.num_train_steps) 86 | model_saver = ModelSaver(os.path.join(config.output_dir, 'ckpts')) 87 | add_log_to_file(os.path.join(config.output_dir, 'logs', 'log.txt')) 88 | else: 89 | LOGGER.disabled = True 90 | pbar = NoOp() 91 | model_saver = NoOp() 92 | 93 | # Prepare model 94 | model_class = model_factory[config.MODEL.model_class] 95 | model = model_class(**config.MODEL) 96 | 97 | LOGGER.info("Model: nweights %d nparams %d" % (model.num_parameters)) 98 | LOGGER.info("Model: trainable nweights %d nparams %d" % 99 | (model.num_trainable_parameters)) 100 | 101 | if config.checkpoint: 102 | checkpoint = torch.load( 103 | config.checkpoint, map_location=lambda storage, loc: storage) 104 | if 'state_dict' in checkpoint: 105 | checkpoint = checkpoint['state_dict'] 106 | model.load_state_dict(checkpoint, strict=True) 107 | 108 | model.train() 109 | set_dropout(model, config.dropout) 110 | model = wrap_model(model, device, config.local_rank) 111 | 112 | # Prepare optimizer 113 | optimizer = build_optimizer(model, config) 114 | 115 | LOGGER.info(f"***** Running training with {config.world_size} GPUs *****") 116 | LOGGER.info(" Batch size = %d", config.train_batch_size if config.local_rank == - 117 | 1 else config.train_batch_size * config.world_size) 118 | LOGGER.info(" Accumulate steps = %d", config.gradient_accumulation_steps) 119 | LOGGER.info(" Num steps = %d", config.num_train_steps) 120 | 121 | # to compute training statistics 122 | global_step = 0 123 | 124 | start_time = time.time() 125 | # quick hack for amp delay_unscale bug 126 | optimizer.zero_grad() 127 | optimizer.step() 128 | 129 | for epoch_id in range(config.num_epochs): 130 | # In distributed mode, calling the set_epoch() method at the beginning of each epoch 131 | pre_epoch(epoch_id) 132 | 133 | for step, batch in enumerate(data_loader): 134 | # forward pass 135 | losses, logits = model(batch, compute_loss=True) 136 | 137 | # backward pass 138 | if config.gradient_accumulation_steps > 1: # average loss 139 | losses['total'] = losses['total'] / \ 140 | config.gradient_accumulation_steps 141 | losses['total'].backward() 142 | 143 | acc = ((logits[..., -1].data.cpu() > 0.5) 144 | == batch['actions'][..., -1]).float() 145 | 146 | if 'step_masks' in batch: 147 | acc = torch.sum(acc * batch['step_masks']) / \ 148 | torch.sum(batch['step_masks']) 149 | else: 150 | acc = acc.mean() 151 | 152 | for key, value in losses.items(): 153 | TB_LOGGER.add_scalar( 154 | f'step/loss_{key}', value.item(), global_step) 155 | TB_LOGGER.add_scalar('step/acc_open', acc.item(), global_step) 156 | 157 | # optimizer update and logging 158 | if (step + 1) % config.gradient_accumulation_steps == 0: 159 | global_step += 1 160 | 161 | # learning rate scheduling 162 | lr_this_step = get_lr_sched(global_step, config) 163 | for param_group in optimizer.param_groups: 164 | param_group['lr'] = lr_this_step 165 | TB_LOGGER.add_scalar('lr', lr_this_step, global_step) 166 | 167 | # log loss 168 | # NOTE: not gathered across GPUs for efficiency 169 | TB_LOGGER.step() 170 | 171 | # 
update model params 172 | if config.grad_norm != -1: 173 | grad_norm = torch.nn.utils.clip_grad_norm_( 174 | model.parameters(), config.grad_norm 175 | ) 176 | # print(step, name, grad_norm) 177 | # for k, v in model.named_parameters(): 178 | # if v.grad is not None: 179 | # v = torch.norm(v).data.item() 180 | # print(k, v) 181 | TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step) 182 | optimizer.step() 183 | optimizer.zero_grad() 184 | pbar.update(1) 185 | 186 | if global_step % config.log_steps == 0: 187 | # monitor training throughput 188 | LOGGER.info( 189 | f'==============Epoch {epoch_id} Step {global_step}===============') 190 | LOGGER.info(', '.join(['%s:%.4f' % ( 191 | lk, lv.item()) for lk, lv in losses.items()] + ['acc:%.2f' % (acc*100)])) 192 | LOGGER.info('===============================================') 193 | 194 | if global_step % config.save_steps == 0: 195 | model_saver.save(model, global_step) 196 | 197 | if global_step >= config.num_train_steps: 198 | break 199 | 200 | if global_step % config.save_steps != 0: 201 | LOGGER.info( 202 | f'==============Epoch {epoch_id} Step {global_step}===============') 203 | LOGGER.info(', '.join(['%s:%.4f' % (lk, lv.item()) 204 | for lk, lv in losses.items()] + ['acc:%.2f' % (acc*100)])) 205 | LOGGER.info('===============================================') 206 | model_saver.save(model, global_step) 207 | 208 | 209 | def build_args(): 210 | parser = argparse.ArgumentParser() 211 | parser.add_argument( 212 | "--exp-config", 213 | type=str, 214 | required=True, 215 | help="path to config yaml containing info about experiment", 216 | ) 217 | parser.add_argument( 218 | "opts", 219 | default=None, 220 | nargs=argparse.REMAINDER, 221 | help="Modify config options from command line", 222 | ) 223 | args = parser.parse_args() 224 | 225 | config = get_config(args.exp_config, args.opts) 226 | 227 | for i in range(len(config.CMD_TRAILING_OPTS)): 228 | if config.CMD_TRAILING_OPTS[i] == "DATASET.taskvars": 229 | if type(config.CMD_TRAILING_OPTS[i + 1]) is str: 230 | config.CMD_TRAILING_OPTS[i + 231 | 1] = [config.CMD_TRAILING_OPTS[i + 1]] 232 | 233 | if os.path.exists(config.output_dir) and os.listdir(config.output_dir): 234 | LOGGER.warning( 235 | "Output directory ({}) already exists and is not empty.".format( 236 | config.output_dir 237 | ) 238 | ) 239 | 240 | return config 241 | 242 | 243 | if __name__ == '__main__': 244 | config = build_args() 245 | main(config) 246 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vlc-robot/hiveformer/0ca80a156acb3985be236fd9ab50e56734f970d6/utils/__init__.py -------------------------------------------------------------------------------- /utils/coord_transforms.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | 5 | def convert_gripper_pose_world_to_image(obs, camera: str) -> Tuple[int, int]: 6 | '''Convert the gripper pose from world coordinate system to image coordinate system. 7 | image[v, u] is the gripper location. 
8 | ''' 9 | extrinsics_44 = obs.misc[f"{camera}_camera_extrinsics"].astype(np.float32) 10 | extrinsics_44 = np.linalg.inv(extrinsics_44) 11 | 12 | intrinsics_33 = obs.misc[f"{camera}_camera_intrinsics"].astype(np.float32) 13 | intrinsics_34 = np.concatenate([intrinsics_33, np.zeros((3, 1), dtype=np.float32)], 1) 14 | 15 | gripper_pos_31 = obs.gripper_pose[:3].astype(np.float32)[:, None] 16 | gripper_pos_41 = np.concatenate([gripper_pos_31, np.ones((1, 1), dtype=np.float32)], 0) 17 | 18 | points_cam_41 = extrinsics_44 @ gripper_pos_41 19 | 20 | proj_31 = intrinsics_34 @ points_cam_41 21 | proj_3 = proj_31[:, 0] 22 | 23 | u = int((proj_3[0] / proj_3[2]).round()) 24 | v = int((proj_3[1] / proj_3[2]).round()) 25 | 26 | return u, v -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Distributed tools 3 | """ 4 | import os 5 | from pathlib import Path 6 | from pprint import pformat 7 | import pickle 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | 13 | def set_local_rank(opts) -> int: 14 | if os.environ.get("LOCAL_RANK", "") != "": 15 | opts.local_rank = int(os.environ["LOCAL_RANK"]) 16 | elif os.environ.get("SLURM_LOCALID", "") != "": 17 | opts.local_rank = int(os.environ["SLURM_LOCALID"]) 18 | else: 19 | opts.local_rank = -1 20 | return opts.local_rank 21 | 22 | 23 | def load_init_param(opts): 24 | """ 25 | Load parameters for the rendezvous distributed procedure 26 | """ 27 | # num of gpus per node 28 | # WARNING: this assumes that each node has the same number of GPUs 29 | if os.environ.get("SLURM_NTASKS_PER_NODE", "") != "": 30 | num_gpus = int(os.environ['SLURM_NTASKS_PER_NODE']) 31 | else: 32 | num_gpus = torch.cuda.device_count() 33 | 34 | # world size 35 | if os.environ.get("WORLD_SIZE", "") != "": 36 | world_size = int(os.environ["WORLD_SIZE"]) 37 | elif os.environ.get("SLURM_JOB_NUM_NODES", ""): 38 | num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"]) 39 | world_size = num_nodes * num_gpus 40 | else: 41 | raise RuntimeError("Can't find any world size") 42 | opts.world_size = world_size 43 | 44 | # rank 45 | if os.environ.get("RANK", "") != "": 46 | # pytorch.distributed.launch provide this variable no matter what 47 | opts.rank = int(os.environ["RANK"]) 48 | elif os.environ.get("SLURM_PROCID", "") != "": 49 | opts.rank = int(os.environ["SLURM_PROCID"]) 50 | else: 51 | if os.environ.get("NODE_RANK", "") != "": 52 | opts.node_rank = int(os.environ["NODE_RANK"]) 53 | elif os.environ.get("SLURM_NODEID", "") != "": 54 | opts.node_rank = int(os.environ["SLURM_NODEID"]) 55 | else: 56 | raise RuntimeError("Can't find any rank or node rank") 57 | 58 | opts.rank = opts.local_rank + node_rank * num_gpus 59 | 60 | init_method = "env://" # need to specify MASTER_ADDR and MASTER_PORT 61 | 62 | return { 63 | "backend": "nccl", 64 | "init_method": init_method, 65 | "rank": opts.rank, 66 | "world_size": world_size, 67 | } 68 | 69 | 70 | def init_distributed(opts): 71 | init_param = load_init_param(opts) 72 | rank = init_param["rank"] 73 | print(f"Init distributed {init_param['rank']} - {init_param['world_size']}") 74 | 75 | dist.init_process_group(**init_param) 76 | 77 | 78 | def is_default_gpu(opts) -> bool: 79 | return opts.local_rank == -1 or dist.get_rank() == 0 80 | 81 | 82 | def is_dist_avail_and_initialized(): 83 | if not dist.is_available(): 84 | return False 85 | if not dist.is_initialized(): 86 | return False 87 | return True 88 | 
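# The helpers below degrade gracefully to the single-process case:
# get_world_size() returns 1 and all_gather() simply wraps its input in a
# list when torch.distributed has not been initialised, so the same training
# code runs with or without a distributed launcher.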
89 | def get_world_size(): 90 | if not is_dist_avail_and_initialized(): 91 | return 1 92 | return dist.get_world_size() 93 | 94 | def all_gather(data): 95 | """ 96 | Run all_gather on arbitrary picklable data (not necessarily tensors) 97 | Args: 98 | data: any picklable object 99 | Returns: 100 | list[data]: list of data gathered from each rank 101 | """ 102 | world_size = get_world_size() 103 | if world_size == 1: 104 | return [data] 105 | 106 | # serialized to a Tensor 107 | buffer = pickle.dumps(data) 108 | storage = torch.ByteStorage.from_buffer(buffer) 109 | tensor = torch.ByteTensor(storage).to("cuda") 110 | 111 | # obtain Tensor size of each rank 112 | local_size = torch.tensor([tensor.numel()], device="cuda") 113 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 114 | dist.all_gather(size_list, local_size) 115 | size_list = [int(size.item()) for size in size_list] 116 | max_size = max(size_list) 117 | 118 | # receiving Tensor from all ranks 119 | # we pad the tensor because torch all_gather does not support 120 | # gathering tensors of different shapes 121 | tensor_list = [] 122 | for _ in size_list: 123 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 124 | if local_size != max_size: 125 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 126 | tensor = torch.cat((tensor, padding), dim=0) 127 | dist.all_gather(tensor_list, tensor) 128 | 129 | data_list = [] 130 | for size, tensor in zip(size_list, tensor_list): 131 | buffer = tensor.cpu().numpy().tobytes()[:size] 132 | data_list.append(pickle.loads(buffer)) 133 | 134 | return data_list 135 | 136 | 137 | def reduce_dict(input_dict, average=True): 138 | """ 139 | Args: 140 | input_dict (dict): all the values will be reduced 141 | average (bool): whether to do average or sum 142 | Reduce the values in the dictionary from all processes so that all processes 143 | have the averaged results. Returns a dict with the same fields as 144 | input_dict, after reduction. 
145 | """ 146 | world_size = get_world_size() 147 | if world_size < 2: 148 | return input_dict 149 | with torch.no_grad(): 150 | names = [] 151 | values = [] 152 | # sort the keys so that they are consistent across processes 153 | for k in sorted(input_dict.keys()): 154 | names.append(k) 155 | values.append(input_dict[k]) 156 | values = torch.stack(values, dim=0) 157 | dist.all_reduce(values) 158 | if average: 159 | values /= world_size 160 | reduced_dict = {k: v for k, v in zip(names, values)} 161 | return reduced_dict 162 | 163 | 164 | -------------------------------------------------------------------------------- /utils/keystep_detection.py: -------------------------------------------------------------------------------- 1 | '''Identify way-point in each RLBench Demo 2 | ''' 3 | 4 | from typing import List, Dict, Optional, Sequence, Tuple, TypedDict, Union, Any 5 | 6 | import numpy as np 7 | 8 | from rlbench.demo import Demo 9 | 10 | 11 | def _is_stopped(demo, i, obs, stopped_buffer): 12 | next_is_not_final = (i < (len(demo) - 2)) 13 | gripper_state_no_change = i < (len(demo) - 2) and ( 14 | obs.gripper_open == demo[i + 1].gripper_open 15 | and obs.gripper_open == demo[max(0, i - 1)].gripper_open 16 | and demo[max(0, i - 2)].gripper_open == demo[max(0, i - 1)].gripper_open 17 | ) 18 | small_delta = np.allclose(obs.joint_velocities, 0, atol=0.1) 19 | stopped = ( 20 | stopped_buffer <= 0 21 | and small_delta 22 | and next_is_not_final 23 | and gripper_state_no_change 24 | ) 25 | return stopped 26 | 27 | 28 | def keypoint_discovery(demo: Demo) -> List[int]: 29 | episode_keypoints = [] 30 | prev_gripper_open = demo[0].gripper_open 31 | stopped_buffer = 0 32 | for i, obs in enumerate(demo): 33 | stopped = _is_stopped(demo, i, obs, stopped_buffer) 34 | stopped_buffer = 4 if stopped else stopped_buffer - 1 35 | # If change in gripper, or end of episode. 36 | last = i == (len(demo) - 1) 37 | if i != 0 and (obs.gripper_open != prev_gripper_open or last or stopped): 38 | episode_keypoints.append(i) 39 | prev_gripper_open = obs.gripper_open 40 | if ( 41 | len(episode_keypoints) > 1 42 | and (episode_keypoints[-1] - 1) == episode_keypoints[-2] 43 | ): 44 | episode_keypoints.pop(-2) 45 | 46 | return episode_keypoints 47 | 48 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | helper for logging 6 | NOTE: loggers are global objects use with caution 7 | """ 8 | import logging 9 | import math 10 | 11 | import tensorboardX 12 | 13 | 14 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 15 | _DATE_FMT = '%m/%d/%Y %H:%M:%S' 16 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO) 17 | LOGGER = logging.getLogger('__main__') # this is the global logger 18 | 19 | 20 | def add_log_to_file(log_path): 21 | fh = logging.FileHandler(log_path) 22 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT) 23 | fh.setFormatter(formatter) 24 | LOGGER.addHandler(fh) 25 | 26 | 27 | class TensorboardLogger(object): 28 | def __init__(self): 29 | self._logger = None 30 | self._global_step = 0 31 | 32 | def create(self, path): 33 | self._logger = tensorboardX.SummaryWriter(path) 34 | 35 | def noop(self, *args, **kwargs): 36 | return 37 | 38 | def step(self): 39 | self._global_step += 1 40 | 41 | @property 42 | def global_step(self): 43 | return self._global_step 44 | 45 | def log_scalar_dict(self, log_dict, prefix=''): 46 | """ log a dictionary of scalar values""" 47 | if self._logger is None: 48 | return 49 | if prefix: 50 | prefix = f'{prefix}_' 51 | for name, value in log_dict.items(): 52 | if isinstance(value, dict): 53 | self.log_scalar_dict(value, self._global_step, 54 | prefix=f'{prefix}{name}') 55 | else: 56 | self._logger.add_scalar(f'{prefix}{name}', value, 57 | self._global_step) 58 | 59 | def __getattr__(self, name): 60 | if self._logger is None: 61 | return self.noop 62 | return self._logger.__getattribute__(name) 63 | 64 | 65 | TB_LOGGER = TensorboardLogger() 66 | 67 | 68 | class RunningMeter(object): 69 | """ running meteor of a scalar value 70 | (useful for monitoring training loss) 71 | """ 72 | def __init__(self, name, val=None, smooth=0.99): 73 | self._name = name 74 | self._sm = smooth 75 | self._val = val 76 | 77 | def __call__(self, value): 78 | val = (value if self._val is None 79 | else value*(1-self._sm) + self._val*self._sm) 80 | if not math.isnan(val): 81 | self._val = val 82 | 83 | def __str__(self): 84 | return f'{self._name}: {self._val:.4f}' 85 | 86 | @property 87 | def val(self): 88 | if self._val is None: 89 | return 0 90 | return self._val 91 | 92 | @property 93 | def name(self): 94 | return self._name 95 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from typing import Tuple, Union, Dict, Any 5 | 6 | import torch 7 | import torch.distributed as dist 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | 10 | from .distributed import init_distributed, set_local_rank 11 | from .logger import LOGGER 12 | 13 | 14 | def set_random_seed(seed): 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | def set_dropout(model, drop_p): 21 | for name, module in model.named_modules(): 22 | # we might want to tune dropout for smaller dataset 23 | if name.startswith('map_encoder'): continue 24 | if isinstance(module, torch.nn.Dropout): 25 | if module.p != drop_p: 26 | module.p = drop_p 27 | LOGGER.info(f'{name} set to {drop_p}') 28 | 29 | 30 | def set_cuda(opts) -> Tuple[bool, int, torch.device]: 31 | """ 32 | Initialize CUDA for distributed computing 33 | """ 34 | set_local_rank(opts) 35 | 36 | if not torch.cuda.is_available(): 37 | assert 
opts.local_rank == -1, opts.local_rank 38 | return True, 0, torch.device("cpu") 39 | 40 | # get device settings 41 | if opts.local_rank != -1: 42 | init_distributed(opts) 43 | torch.cuda.set_device(opts.local_rank) 44 | device = torch.device("cuda", opts.local_rank) 45 | n_gpu = 1 46 | default_gpu = dist.get_rank() == 0 47 | if default_gpu: 48 | LOGGER.info(f"Found {dist.get_world_size()} GPUs") 49 | else: 50 | default_gpu = True 51 | device = torch.device("cuda") 52 | n_gpu = torch.cuda.device_count() 53 | 54 | return default_gpu, n_gpu, device 55 | 56 | 57 | def wrap_model( 58 | model: torch.nn.Module, device: torch.device, local_rank: int 59 | ) -> torch.nn.Module: 60 | model.to(device) 61 | 62 | if local_rank != -1: 63 | model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) 64 | # At the time of DDP wrapping, parameters and buffers (i.e., model.state_dict()) 65 | # on rank0 are broadcasted to all other ranks. 66 | elif torch.cuda.device_count() > 1: 67 | LOGGER.info("Using data parallel") 68 | model = torch.nn.DataParallel(model) 69 | 70 | return model 71 | 72 | 73 | class NoOp(object): 74 | """ useful for distributed training No-Ops """ 75 | def __getattr__(self, name): 76 | return self.noop 77 | 78 | def noop(self, *args, **kwargs): 79 | return 80 | 81 | -------------------------------------------------------------------------------- /utils/ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def pad_tensors(tensors, lens=None, pad=0): 5 | """B x [T, ...] torch tensors""" 6 | if lens is None: 7 | lens = [t.size(0) for t in tensors] 8 | max_len = max(lens) 9 | bs = len(tensors) 10 | hid = list(tensors[0].size()[1:]) 11 | size = [bs, max_len] + hid 12 | 13 | dtype = tensors[0].dtype 14 | output = torch.zeros(*size, dtype=dtype) 15 | if pad: 16 | output.data.fill_(pad) 17 | for i, (t, l) in enumerate(zip(tensors, lens)): 18 | output.data[i, :l, ...] = t.data 19 | return output 20 | 21 | def pad_tensors_wgrad(tensors, lens=None, value=0): 22 | """B x [T, ...] 
torch tensors""" 23 | if lens is None: 24 | lens = [t.size(0) for t in tensors] 25 | max_len = max(lens) 26 | batch_size = len(tensors) 27 | hid = list(tensors[0].size()[1:]) 28 | 29 | device = tensors[0].device 30 | dtype = tensors[0].dtype 31 | 32 | output = [] 33 | for i in range(batch_size): 34 | if lens[i] < max_len: 35 | tmp = torch.cat( 36 | [tensors[i], torch.zeros([max_len-lens[i]]+hid, dtype=dtype).to(device) + value], 37 | dim=0 38 | ) 39 | else: 40 | tmp = tensors[i] 41 | output.append(tmp) 42 | output = torch.stack(output, 0) 43 | return output 44 | 45 | 46 | def gen_seq_masks(seq_lens, max_len=None): 47 | """ 48 | Args: 49 | seq_lens: list or nparray int, shape=(N, ) 50 | Returns: 51 | masks: nparray, shape=(N, L), padded=0 52 | """ 53 | seq_lens = np.array(seq_lens) 54 | if max_len is None: 55 | max_len = max(seq_lens) 56 | if max_len == 0: 57 | return np.zeros((len(seq_lens), 0), dtype=np.bool) 58 | batch_size = len(seq_lens) 59 | masks = np.arange(max_len).reshape(-1, max_len).repeat(batch_size, 0) 60 | masks = masks < seq_lens.reshape(-1, 1) 61 | return masks 62 | 63 | 64 | def extend_neg_masks(masks, dtype=None): 65 | """ 66 | mask from (N, L) into (N, 1(H), 1(L), L) and make it negative 67 | """ 68 | if dtype is None: 69 | dtype = torch.float 70 | extended_masks = masks.unsqueeze(1).unsqueeze(2) 71 | extended_masks = extended_masks.to(dtype=dtype) 72 | extended_masks = (1.0 - extended_masks) * -10000.0 73 | return extended_masks -------------------------------------------------------------------------------- /utils/recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Type 3 | import numpy as np 4 | 5 | from pathlib import Path 6 | from typing import Tuple, Dict, List 7 | from pyrep.objects.dummy import Dummy 8 | from pyrep.objects.vision_sensor import VisionSensor 9 | 10 | 11 | class CameraMotion(object): 12 | def __init__(self, cam: VisionSensor): 13 | self.cam = cam 14 | 15 | def step(self): 16 | raise NotImplementedError() 17 | 18 | def save_pose(self): 19 | self._prev_pose = self.cam.get_pose() 20 | 21 | def restore_pose(self): 22 | self.cam.set_pose(self._prev_pose) 23 | 24 | 25 | class CircleCameraMotion(CameraMotion): 26 | 27 | def __init__(self, cam: VisionSensor, origin: Dummy, speed: float): 28 | super().__init__(cam) 29 | self.origin = origin 30 | self.speed = speed # in radians 31 | 32 | def step(self): 33 | self.origin.rotate([0, 0, self.speed]) 34 | 35 | 36 | class StaticCameraMotion(CameraMotion): 37 | 38 | def __init__(self, cam: VisionSensor): 39 | super().__init__(cam) 40 | 41 | def step(self): 42 | pass 43 | 44 | class AttachedCameraMotion(CameraMotion): 45 | 46 | def __init__(self, cam: VisionSensor, parent_cam: VisionSensor): 47 | super().__init__(cam) 48 | self.parent_cam = parent_cam 49 | 50 | def step(self): 51 | self.cam.set_pose(self.parent_cam.get_pose()) 52 | 53 | 54 | class TaskRecorder(object): 55 | 56 | def __init__(self, cams_motion: Dict[str, CameraMotion], fps=30): 57 | self._cams_motion = cams_motion 58 | self._fps = fps 59 | self._snaps = {cam_name: [] for cam_name in self._cams_motion.keys()} 60 | 61 | def take_snap(self): 62 | for cam_name, cam_motion in self._cams_motion.items(): 63 | cam_motion.step() 64 | self._snaps[cam_name].append( 65 | (cam_motion.cam.capture_rgb() * 255.).astype(np.uint8)) 66 | 67 | def save(self, path): 68 | print('Converting to video ...') 69 | path = Path(path) 70 | path.mkdir(exist_ok=True) 71 | # OpenCV QT version can conflict 
with PyRep, so import here 72 | import cv2 73 | for cam_name, cam_motion in self._cams_motion.items(): 74 | video = cv2.VideoWriter( 75 | str(path / f"{cam_name}.avi"), cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), self._fps, 76 | tuple(cam_motion.cam.get_resolution())) 77 | for image in self._snaps[cam_name]: 78 | video.write(cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 79 | video.release() 80 | 81 | self._snaps = {cam_name: [] for cam_name in self._cams_motion.keys()} 82 | 83 | -------------------------------------------------------------------------------- /utils/save.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | saving utilities 6 | """ 7 | import json 8 | import os 9 | import torch 10 | 11 | 12 | def save_training_meta(args): 13 | os.makedirs(os.path.join(args.output_dir, 'logs'), exist_ok=True) 14 | os.makedirs(os.path.join(args.output_dir, 'ckpts'), exist_ok=True) 15 | 16 | with open(os.path.join(args.output_dir, 'logs', 'training_config.yaml'), 'w') as writer: 17 | args_str = args.dump() 18 | print(args_str, file=writer) 19 | 20 | class ModelSaver(object): 21 | def __init__(self, output_dir, prefix='model_step', suffix='pt'): 22 | self.output_dir = output_dir 23 | self.prefix = prefix 24 | self.suffix = suffix 25 | 26 | def save(self, model, step, optimizer=None): 27 | output_model_file = os.path.join(self.output_dir, 28 | f"{self.prefix}_{step}.{self.suffix}") 29 | state_dict = {} 30 | for k, v in model.state_dict().items(): 31 | if k.startswith('module.'): 32 | k = k[7:] 33 | if isinstance(v, torch.Tensor): 34 | state_dict[k] = v.cpu() 35 | else: 36 | state_dict[k] = v 37 | torch.save(state_dict, output_model_file) 38 | if optimizer is not None: 39 | dump = {'step': step, 'optimizer': optimizer.state_dict()} 40 | if hasattr(optimizer, '_amp_stash'): 41 | pass # TODO fp16 optimizer 42 | torch.save(dump, f'{self.output_dir}/train_state_{step}.pt') 43 | 44 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import random 6 | import torch 7 | 8 | def set_random_seed(seed): 9 | torch.manual_seed(seed) 10 | np.random.seed(seed) 11 | random.seed(seed) 12 | 13 | 14 | def get_expr_dirs(output_dir): 15 | log_dir = os.path.join(output_dir, 'logs') 16 | ckpt_dir = os.path.join(output_dir, 'ckpts') 17 | pred_dir = os.path.join(output_dir, 'preds') 18 | 19 | os.makedirs(log_dir, exist_ok=True) 20 | os.makedirs(ckpt_dir, exist_ok=True) 21 | os.makedirs(pred_dir, exist_ok=True) 22 | 23 | return log_dir, ckpt_dir, pred_dir 24 | 25 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | def plot_attention( 7 | attentions: torch.Tensor, rgbs: torch.Tensor, pcds: torch.Tensor, dest: Path 8 | ) -> plt.Figure: 9 | attentions = attentions.detach().cpu() 10 | rgbs = rgbs.detach().cpu() 11 | pcds = pcds.detach().cpu() 12 | 13 | ep_dir = dest.parent 14 | ep_dir.mkdir(exist_ok=True, parents=True) 15 | name = dest.stem 16 | ext = dest.suffix 17 | 18 | # plt.figure(figsize=(10, 8)) 19 | num_cameras = len(attentions) 20 | for i, (a, rgb, pcd) in 
enumerate(zip(attentions, rgbs, pcds)): 21 | # plt.subplot(num_cameras, 4, i * 4 + 1) 22 | plt.imshow(a.permute(1, 2, 0).log()) 23 | plt.axis("off") 24 | plt.colorbar() 25 | plt.savefig(ep_dir / f"{name}-{i}-attn{ext}", bbox_inches="tight") 26 | plt.tight_layout() 27 | plt.clf() 28 | 29 | # plt.subplot(num_cameras, 4, i * 4 + 2) 30 | # plt.imshow(a.permute(1, 2, 0)) 31 | # plt.axis('off') 32 | # plt.colorbar() 33 | # plt.tight_layout() 34 | # plt.clf() 35 | 36 | # plt.subplot(num_cameras, 4, i * 4 + 3) 37 | plt.imshow(((rgb + 1) / 2).permute(1, 2, 0)) 38 | plt.axis("off") 39 | plt.savefig(ep_dir / f"{name}-{i}-rgb{ext}", bbox_inches="tight") 40 | plt.tight_layout() 41 | plt.clf() 42 | 43 | pcd_norm = (pcd - pcd.min(0).values) / (pcd.max(0).values - pcd.min(0).values) 44 | # plt.subplot(num_cameras, 4, i * 4 + 4) 45 | plt.imshow(pcd_norm.permute(1, 2, 0)) 46 | plt.axis("off") 47 | plt.savefig(ep_dir / f"{name}-{i}-pcd{ext}", bbox_inches="tight") 48 | plt.tight_layout() 49 | plt.clf() 50 | 51 | return plt.gcf() --------------------------------------------------------------------------------