├── agents
│   ├── __init__.py
│   ├── baselines
│   │   ├── __init__.py
│   │   ├── bc_lang
│   │   │   ├── __init__.py
│   │   │   └── bc_lang_agent.py
│   │   └── vit_bc_lang
│   │       ├── __init__.py
│   │       └── vit_bc_lang_agent.py
│   ├── act_bc_lang
│   │   ├── detr
│   │   │   ├── __init__.py
│   │   │   ├── util
│   │   │   │   └── __init__.py
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   ├── position_encoding.py
│   │   │   │   └── backbone.py
│   │   │   └── build.py
│   │   ├── __init__.py
│   │   └── act_policy.py
│   ├── arm
│   │   ├── __init__.py
│   │   └── qattention_agent.py
│   ├── rvt
│   │   ├── __init__.py
│   │   └── launch_utils.py
│   ├── peract_bc
│   │   ├── __init__.py
│   │   ├── launch_utils.py
│   │   └── qattention_stack_agent.py
│   ├── bimanual_peract
│   │   ├── __init__.py
│   │   ├── launch_utils.py
│   │   └── qattention_stack_agent.py
│   ├── c2farm_lingunet_bc
│   │   ├── __init__.py
│   │   ├── qattention_stack_agent.py
│   │   └── networks.py
│   └── agent_factory.py
├── voxel
│   ├── __init__.py
│   └── voxel_grid.py
├── helpers
│   ├── __init__.py
│   ├── clip
│   │   ├── __init__.py
│   │   └── core
│   │       ├── __init__.py
│   │       ├── bpe_simple_vocab_16e6.txt.gz
│   │       ├── attention_image_goal.py
│   │       ├── attention.py
│   │       ├── unet.py
│   │       ├── transport.py
│   │       ├── simple_tokenizer.py
│   │       ├── transport_image_goal.py
│   │       └── resnet.py
│   ├── optim
│   │   ├── __init__.py
│   │   └── lamb.py
│   ├── preprocess_agent.py
│   ├── demo_loading_utils.py
│   └── observation_utils.py
├── conf
│   ├── method
│   │   ├── BC_LANG.yaml
│   │   ├── VIT_BC_LANG.yaml
│   │   ├── ARM.yaml
│   │   ├── C2FARM_LINGUNET_BC.yaml
│   │   ├── ACT_BC_LANG.yaml
│   │   ├── RVT.yaml
│   │   ├── PERACT_BC.yaml
│   │   └── BIMANUAL_PERACT.yaml
│   ├── hydra
│   │   └── job_logging
│   │       └── custom.yaml
│   ├── eval.yaml
│   └── config.yaml
├── scripts
│   ├── install_conda.sh
│   └── install_dependencies.sh
├── peract_config.py
├── pyproject.toml
├── INSTALLATION.md
├── Dockerfile
├── model-card.md
├── .gitignore
├── train.py
├── run_seed_fn.py
├── eval.py
└── ARM_LICENSE
/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/voxel/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/helpers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/agents/baselines/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/helpers/clip/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/helpers/optim/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/helpers/clip/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/agents/act_bc_lang/detr/__init__.py:
--------------------------------------------------------------------------------
1 | import agents.act_bc_lang.launch_utils
2 |
--------------------------------------------------------------------------------
/agents/arm/__init__.py:
--------------------------------------------------------------------------------
1 | import agents.arm.launch_utils
2 |
--------------------------------------------------------------------------------
/agents/rvt/__init__.py:
--------------------------------------------------------------------------------
1 | import agents.rvt.launch_utils
2 |
--------------------------------------------------------------------------------
/agents/peract_bc/__init__.py:
--------------------------------------------------------------------------------
1 | import agents.peract_bc.launch_utils
2 |
--------------------------------------------------------------------------------
/agents/bimanual_peract/__init__.py:
--------------------------------------------------------------------------------
1 | import agents.bimanual_peract.launch_utils
2 |
--------------------------------------------------------------------------------
/agents/baselines/bc_lang/__init__.py:
--------------------------------------------------------------------------------
1 | import agents.baselines.bc_lang.launch_utils
2 |
--------------------------------------------------------------------------------
/agents/c2farm_lingunet_bc/__init__.py:
--------------------------------------------------------------------------------
1 | import agents.c2farm_lingunet_bc.launch_utils
2 |
--------------------------------------------------------------------------------
/agents/baselines/vit_bc_lang/__init__.py:
--------------------------------------------------------------------------------
1 | import agents.baselines.vit_bc_lang.launch_utils
2 |
--------------------------------------------------------------------------------
/agents/act_bc_lang/detr/util/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | -------------------------------------------------------------------------------- /helpers/clip/core/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markusgrotz/peract_bimanual/HEAD/helpers/clip/core/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /conf/method/BC_LANG.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | name: 'BC_LANG' 4 | activation: lrelu 5 | lr: 0.0005 6 | weight_decay: 0.000001 7 | grad_clip: 0.1 8 | demo_augmentation: True 9 | demo_augmentation_every_n: 10 10 | -------------------------------------------------------------------------------- /conf/method/VIT_BC_LANG.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | name: 'VIT_BC_LANG' 4 | activation: lrelu 5 | lr: 0.0005 6 | weight_decay: 0.000001 7 | grad_clip: 0.1 8 | demo_augmentation: True 9 | demo_augmentation_every_n: 10 10 | -------------------------------------------------------------------------------- /conf/hydra/job_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | formatters: 3 | simple: 4 | format: '[%(levelname)s] - %(message)s' 5 | handlers: 6 | rich_console: 7 | class: rich.logging.RichHandler 8 | root: 9 | handlers: [rich_console] 10 | 11 | 12 | disable_existing_loggers: false 13 | -------------------------------------------------------------------------------- /agents/act_bc_lang/detr/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from .detr_vae import build as build_vae 3 | from .detr_vae import build_cnnmlp as build_cnnmlp 4 | 5 | def build_ACT_model(args): 6 | return build_vae(args) 7 | 8 | def build_CNNMLP_model(args): 9 | return build_cnnmlp(args) -------------------------------------------------------------------------------- /scripts/install_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -exu 2 | 3 | # install conda 4 | 5 | sudo apt install curl 6 | 7 | 8 | TEMP_DIR=$(mktemp --tmpdir -d miniconda_XXXXXXXXXX) 9 | cd $TEMP_DIR 10 | 11 | curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 12 | chmod +x Miniconda3-latest-Linux-x86_64.sh 13 | ./Miniconda3-latest-Linux-x86_64.sh 14 | 15 | 16 | SHELL_NAME=`basename $SHELL` 17 | eval "$($HOME/miniconda3/bin/conda shell.${SHELL_NAME} hook)" 18 | 19 | conda init ${SHELL_NAME} 20 | conda install mamba -c conda-forge 21 | conda config --set auto_activate_base false 22 | 23 | -------------------------------------------------------------------------------- /conf/method/ARM.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | name: 'ARM' 4 | activation: lrelu 5 | q_conf: True 6 | alpha: 0.05 7 | alpha_lr: 0.0001 8 | alpha_auto_tune: False 9 | next_best_pose_critic_lr: 0.0025 10 | next_best_pose_actor_lr: 0.001 11 | next_best_pose_critic_weight_decay: 0.00001 12 | next_best_pose_actor_weight_decay: 0.00001 13 | crop_shape: [16, 16] 14 | next_best_pose_tau: 0.005 15 | next_best_pose_critic_grad_clip: 5 16 | next_best_pose_actor_grad_clip: 5 17 | qattention_grad_clip: 5 18 | qattention_tau: 0.005 19 | qattention_lr: 0.0005 20 | qattention_weight_decay: 0.00001 21 | qattention_lambda_qreg: 0.0000001 22 | 23 | demo_augmentation: True 24 | demo_augmentation_every_n: 10 25 | -------------------------------------------------------------------------------- /conf/method/C2FARM_LINGUNET_BC.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | name: 'C2FARM_LINGUNET_BC' 4 | 5 | # Voxelization 6 | image_crop_size: 64 7 | bounds_offset: [0.15] 8 | voxel_sizes: [32, 32] 9 | include_prev_layer: False 10 | 11 | # Training 12 | lr: 0.0005 13 | lr_scheduler: False 14 | num_warmup_steps: 10000 15 | 16 | lambda_weight_l2: 0.000001 17 | trans_loss_weight: 1.0 18 | rot_loss_weight: 1.0 19 | grip_loss_weight: 1.0 20 | collision_loss_weight: 1.0 21 | rotation_resolution: 5 22 | 23 | # Network 24 | activation: lrelu 25 | norm: None 26 | 27 | # Augmentation 28 | crop_augmentation: True 29 | transform_augmentation: 30 | apply_se3: True 31 | aug_xyz: [0.125, 0.125, 0.125] 32 | aug_rpy: [0.0, 0.0, 45.0] 33 | aug_rot_resolution: ${method.rotation_resolution} 34 | 35 | demo_augmentation: True 36 | demo_augmentation_every_n: 10 37 | exploration_strategy: gaussian 38 | 39 | # Ablations 40 | keypoint_method: 'heuristic' -------------------------------------------------------------------------------- /peract_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | System configuration for peract 3 | """ 4 | import os 5 | import logging 6 | 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def config_logging(logging_level=logging.INFO, reset=False): 11 | if reset: 12 | root = logging.getLogger() 13 | list(map(root.removeHandler, root.handlers)) 14 | list(map(root.removeFilter, root.filters)) 15 | 16 | from rich.logging 
import RichHandler 17 | 18 | logging.basicConfig(level=logging_level, handlers=[RichHandler()]) 19 | 20 | 21 | def on_init(): 22 | config_logging(logging.INFO) 23 | 24 | logging.debug("Configuring environment.") 25 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 26 | mp.set_start_method("spawn", force=True) 27 | mp.set_sharing_strategy("file_system") 28 | 29 | 30 | def on_config(cfg): 31 | os.environ["MASTER_ADDR"] = str(cfg.ddp.master_addr) 32 | os.environ["MASTER_PORT"] = str(cfg.ddp.master_port) 33 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "peract_bimanual" 3 | version = "0.0.1" 4 | description = "A perceiver actor framework for bimanual manipulation tasks" 5 | authors = [ "Markus Grotz ", 6 | "Mohit Shridhar "] 7 | packages = [{include = "agents"}, {include = "helpers"}, {include = "voxel"}] 8 | 9 | 10 | readme = "README.md" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "Framework :: Robot Framework " 14 | ] 15 | 16 | [tool.poetry.dependencies] 17 | python = ">=3.8,<4.0" 18 | einops = "0.3.2" 19 | ftfy = "^6.1.1" 20 | hydra-core = ">=1.0.5" 21 | matplotlib = "^3.7.1" 22 | pandas = "1.4.1" 23 | regex = "^2023.6.3" 24 | tensorboard = "^2.13.0" 25 | perceiver-pytorch = "^0.8.7" 26 | transformers = "^4.21" 27 | 28 | 29 | 30 | [tool.poetry.extras] 31 | docs = ["sphinx"] 32 | 33 | [build-system] 34 | requires = ["setuptools", "wheel", "poetry-core>=1.0.0"] 35 | build-backend = "poetry.core.masonry.api" 36 | -------------------------------------------------------------------------------- /conf/method/ACT_BC_LANG.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | name: 'ACT_BC_LANG' 4 | 5 | # Agent 6 | robot_name: 'bimanual' 7 | agent_type: 'bimanual' 8 | 9 | 10 | train_demo_path: "/home/markus/rlbench_data_v2_128/train/" 11 | 12 | activation: lrelu 13 | lr: 1e-4 14 | weight_decay: 0.000001 15 | grad_clip: 0.1 16 | demo_augmentation: True 17 | demo_augmentation_every_n: 10 18 | 19 | prev_action_horizon: 1 20 | next_action_horizon: 10 21 | 22 | # hyperparameters 23 | lr_backbone: 1e-5 24 | backbone: resnet18 25 | dilation: False 26 | position_embedding: sine 27 | kl_weight: 100 28 | chunk_size: ${method.next_action_horizon} 29 | 30 | # transformer 31 | input_dim: 16 # 7 revolute joints + 1 gripper joints 32 | enc_layers: 4 33 | dec_layers: 7 34 | dim_feedforward: 3200 35 | hidden_dim: 512 36 | dropout: 0.1 37 | nheads: 8 38 | num_queries: ${method.next_action_horizon} 39 | pre_norm: False 40 | 41 | # unused 42 | masks: False 43 | 44 | # legacy 45 | camera_names: ${rlbench.cameras} 46 | 47 | # ..todo:: also set the following 48 | 49 | +rlbench.episode_length: 400 50 | +rlbench.arm_action_mode: JointPosition 51 | +rlbench.action_mode: JointPositionActionMode 52 | -------------------------------------------------------------------------------- /conf/eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - method: PERACT_BC 3 | 4 | 5 | rlbench: 6 | task_name: "multi" 7 | tasks: [open_drawer,slide_block_to_color_target] 8 | demo_path: /my/demo/path 9 | episode_length: 25 10 | cameras: ["over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"] 11 | camera_resolution: [128, 128] 12 | scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6] 13 | 
include_lang_goal_in_obs: True 14 | time_in_state: True 15 | headless: True 16 | gripper_mode: 'Discrete' 17 | arm_action_mode: 'EndEffectorPoseViaPlanning' 18 | action_mode: 'MoveArmThenGripper' 19 | 20 | framework: 21 | tensorboard_logging: True 22 | csv_logging: True 23 | gpu: 0 24 | logdir: '/tmp/arm_test/' 25 | start_seed: 0 26 | record_every_n: 5 27 | 28 | eval_envs: 1 29 | eval_from_eps_number: 0 30 | eval_episodes: 5 31 | eval_type: 'last' # or 'best', 'missing', or 'last' 32 | eval_save_metrics: True 33 | 34 | cinematic_recorder: 35 | enabled: False 36 | camera_resolution: [1280, 720] 37 | fps: 30 38 | rotate_speed: 0.005 39 | save_path: '/tmp/videos/' 40 | -------------------------------------------------------------------------------- /agents/act_bc_lang/detr/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import argparse 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from .models import build_ACT_model, build_CNNMLP_model 8 | 9 | 10 | 11 | def build_ACT_model_and_optimizer(args): 12 | model = build_ACT_model(args) 13 | 14 | param_dicts = [ 15 | {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, 16 | { 17 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 18 | "lr": args.lr_backbone, 19 | }, 20 | ] 21 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, 22 | weight_decay=args.weight_decay) 23 | 24 | return model, optimizer 25 | 26 | 27 | def build_CNNMLP_model_and_optimizer(args): 28 | model = build_CNNMLP_model(args) 29 | 30 | param_dicts = [ 31 | {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, 32 | { 33 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 34 | "lr": args.lr_backbone, 35 | }, 36 | ] 37 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, 38 | weight_decay=args.weight_decay) 39 | 40 | return model, optimizer 41 | 42 | -------------------------------------------------------------------------------- /conf/method/RVT.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | name: 'RVT' 4 | 5 | # Agent 6 | agent_type: 'leader_follower' 7 | robot_name: 'bimanual' 8 | 9 | # Voxelization 10 | image_crop_size: 64 11 | bounds_offset: [0.15] 12 | voxel_sizes: [100] 13 | include_prev_layer: False 14 | 15 | low_dim_size: 4 16 | 17 | # Perceiver 18 | num_latents: 2048 19 | latent_dim: 512 20 | transformer_depth: 6 21 | transformer_iterations: 1 22 | cross_heads: 1 23 | cross_dim_head: 64 24 | latent_heads: 8 25 | latent_dim_head: 64 26 | pos_encoding_with_lang: True 27 | conv_downsample: True 28 | lang_fusion_type: 'seq' # or 'concat' 29 | voxel_patch_size: 5 30 | voxel_patch_stride: 5 31 | final_dim: 64 32 | 33 | # Training 34 | input_dropout: 0.1 35 | attn_dropout: 0.1 36 | decoder_dropout: 0.0 37 | 38 | lr: 0.0005 39 | lr_scheduler: False 40 | num_warmup_steps: 3000 41 | optimizer: 'lamb' # or 'adam' 42 | 43 | lambda_weight_l2: 0.000001 44 | trans_loss_weight: 1.0 45 | rot_loss_weight: 1.0 46 | grip_loss_weight: 1.0 47 | collision_loss_weight: 1.0 48 | rotation_resolution: 5 49 | 50 | # Network 51 | activation: lrelu 52 | norm: None 53 | 54 | # Augmentation 55 | crop_augmentation: True 56 | transform_augmentation: 57 | apply_se3: True 58 | aug_xyz: [0.125, 0.125, 0.125] 59 | 
aug_rpy: [0.0, 0.0, 45.0] 60 | aug_rot_resolution: ${method.rotation_resolution} 61 | 62 | demo_augmentation: True 63 | demo_augmentation_every_n: 10 64 | 65 | # Ablations 66 | no_skip_connection: False 67 | no_perceiver: False 68 | no_language: False 69 | keypoint_method: 'heuristic' 70 | -------------------------------------------------------------------------------- /conf/method/PERACT_BC.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | name: 'PERACT_BC' 4 | 5 | # Agent 6 | agent_type: 'leader_follower' 7 | robot_name: 'bimanual' 8 | 9 | # Voxelization 10 | image_crop_size: 64 11 | bounds_offset: [0.15] 12 | voxel_sizes: [100] 13 | include_prev_layer: False 14 | 15 | # Perceiver 16 | num_latents: 2048 17 | latent_dim: 512 18 | transformer_depth: 6 19 | transformer_iterations: 1 20 | cross_heads: 1 21 | cross_dim_head: 64 22 | latent_heads: 8 23 | latent_dim_head: 64 24 | pos_encoding_with_lang: True 25 | conv_downsample: True 26 | lang_fusion_type: 'seq' # or 'concat' 27 | voxel_patch_size: 5 28 | voxel_patch_stride: 5 29 | final_dim: 64 30 | low_dim_size: 4 31 | 32 | # Training 33 | input_dropout: 0.1 34 | attn_dropout: 0.1 35 | decoder_dropout: 0.0 36 | 37 | lr: 0.0005 38 | lr_scheduler: False 39 | num_warmup_steps: 3000 40 | optimizer: 'lamb' # or 'adam' 41 | 42 | lambda_weight_l2: 0.000001 43 | trans_loss_weight: 1.0 44 | rot_loss_weight: 1.0 45 | grip_loss_weight: 1.0 46 | collision_loss_weight: 1.0 47 | rotation_resolution: 5 48 | 49 | # Network 50 | activation: lrelu 51 | norm: None 52 | 53 | # Augmentation 54 | crop_augmentation: True 55 | transform_augmentation: 56 | apply_se3: True 57 | aug_xyz: [0.125, 0.125, 0.125] 58 | aug_rpy: [0.0, 0.0, 45.0] 59 | aug_rot_resolution: ${method.rotation_resolution} 60 | 61 | demo_augmentation: True 62 | demo_augmentation_every_n: 10 63 | 64 | # Ablations 65 | no_skip_connection: False 66 | no_perceiver: False 67 | no_language: False 68 | keypoint_method: 'heuristic' 69 | -------------------------------------------------------------------------------- /conf/method/BIMANUAL_PERACT.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | name: 'BIMANUAL_PERACT' 4 | 5 | # Agent 6 | robot_name: 'bimanual' 7 | agent_type: 'bimanual' 8 | 9 | 10 | # Voxelization 11 | image_crop_size: 64 12 | bounds_offset: [0.15] 13 | voxel_sizes: [100] 14 | include_prev_layer: False 15 | 16 | # Perceiver 17 | num_latents: 2048 18 | latent_dim: 512 19 | transformer_depth: 6 20 | transformer_iterations: 1 21 | cross_heads: 1 22 | cross_dim_head: 64 23 | latent_heads: 8 24 | latent_dim_head: 64 25 | pos_encoding_with_lang: True 26 | conv_downsample: True 27 | lang_fusion_type: 'seq' # or 'concat' 28 | voxel_patch_size: 5 29 | voxel_patch_stride: 5 30 | final_dim: 64 31 | low_dim_size: 8 32 | 33 | 34 | # Training 35 | input_dropout: 0.1 36 | attn_dropout: 0.1 37 | decoder_dropout: 0.0 38 | 39 | lr: 0.0005 40 | lr_scheduler: False 41 | num_warmup_steps: 3000 42 | optimizer: 'lamb' # or 'adam' 43 | 44 | lambda_weight_l2: 0.000001 45 | trans_loss_weight: 1.0 46 | rot_loss_weight: 1.0 47 | grip_loss_weight: 1.0 48 | collision_loss_weight: 1.0 49 | rotation_resolution: 5 50 | 51 | # Network 52 | activation: lrelu 53 | norm: None 54 | 55 | # Augmentation 56 | crop_augmentation: True 57 | transform_augmentation: 58 | apply_se3: True 59 | aug_xyz: [0.125, 0.125, 0.125] 60 | aug_rpy: [0.0, 0.0, 45.0] 61 | aug_rot_resolution: 
${method.rotation_resolution} 62 | 63 | demo_augmentation: True 64 | demo_augmentation_every_n: 10 65 | 66 | # Ablations 67 | no_skip_connection: False 68 | no_perceiver: False 69 | no_language: False 70 | keypoint_method: 'heuristic' 71 | -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | ddp: 2 | master_addr: "localhost" 3 | master_port: "0" 4 | num_devices: 1 5 | 6 | rlbench: 7 | task_name: "multi" 8 | tasks: [open_drawer,slide_block_to_color_target] 9 | demos: 100 10 | demo_path: /my/demo/path 11 | episode_length: 25 12 | cameras: ["over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"] 13 | camera_resolution: [128, 128] 14 | scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6] 15 | include_lang_goal_in_obs: True 16 | 17 | replay: 18 | batch_size: 8 19 | timesteps: 1 20 | prioritisation: False 21 | task_uniform: True # uniform sampling of tasks for multi-task buffers 22 | use_disk: True 23 | path: '/tmp/arm/replay' # only used when use_disk is True. 24 | max_parallel_processes: 32 25 | 26 | framework: 27 | log_freq: 100 28 | save_freq: 100 29 | train_envs: 1 30 | replay_ratio: ${replay.batch_size} 31 | transitions_before_train: 200 32 | tensorboard_logging: True 33 | csv_logging: True 34 | training_iterations: 40000 35 | gpu: 0 36 | env_gpu: 0 37 | logdir: '/tmp/arm_test/' 38 | logging_level: 20 # https://docs.python.org/3/library/logging.html#levels 39 | seeds: 1 40 | start_seed: 0 41 | load_existing_weights: True 42 | num_weights_to_keep: 60 # older checkpoints will be deleted chronologically 43 | num_workers: 0 44 | record_every_n: 5 45 | checkpoint_name_prefix: "checkpoint" 46 | 47 | defaults: 48 | - method: PERACT_BC 49 | 50 | hydra: 51 | run: 52 | dir: ${framework.logdir}/${rlbench.task_name}/${method.name} 53 | -------------------------------------------------------------------------------- /helpers/clip/core/attention_image_goal.py: -------------------------------------------------------------------------------- 1 | """Attention module.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | from cliport.models.core.attention import Attention 9 | 10 | 11 | class AttentionImageGoal(Attention): 12 | """Attention (a.k.a Pick) with image-goals module.""" 13 | 14 | def __init__(self, stream_fcn, in_shape, n_rotations, preprocess, cfg, device): 15 | super().__init__(stream_fcn, in_shape, n_rotations, preprocess, cfg, device) 16 | 17 | def forward(self, inp_img, goal_img, softmax=True): 18 | """Forward pass.""" 19 | # Input image. 20 | in_data = np.pad(inp_img, self.padding, mode="constant") 21 | in_shape = (1,) + in_data.shape 22 | in_data = in_data.reshape(in_shape) 23 | in_tens = torch.from_numpy(in_data).to(dtype=torch.float, device=self.device) 24 | 25 | goal_tensor = np.pad(goal_img, self.padding, mode="constant") 26 | goal_shape = (1,) + goal_tensor.shape 27 | goal_tensor = goal_tensor.reshape(goal_shape) 28 | goal_tensor = torch.from_numpy(goal_tensor.copy()).to( 29 | dtype=torch.float, device=self.device 30 | ) 31 | in_tens = in_tens * goal_tensor 32 | 33 | # Rotation pivot. 34 | pv = np.array(in_data.shape[1:3]) // 2 35 | 36 | # Rotate input. 37 | in_tens = in_tens.permute(0, 3, 1, 2) 38 | in_tens = in_tens.repeat(self.n_rotations, 1, 1, 1) 39 | in_tens = self.rotator(in_tens, pivot=pv) 40 | 41 | # Forward pass. 
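        # At this point `in_tens` holds `n_rotations` rotated copies of the
        # goal-modulated input. Each copy is scored independently below by
        # `self.attend`; the per-rotation logits are then concatenated, rotated
        # back into the original frame, cropped to undo the padding, and
        # flattened so that the softmax normalizes over all pixel locations
        # and rotations.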
42 |         logits = []
43 |         for x in in_tens:
44 |             logits.append(self.attend(x))
45 |         logits = torch.cat(logits, dim=0)
46 | 
47 |         # Rotate back output.
48 |         logits = self.rotator(logits, reverse=True, pivot=pv)
49 |         logits = torch.cat(logits, dim=0)
50 |         c0 = self.padding[:2, 0]
51 |         c1 = c0 + inp_img.shape[:2]
52 |         logits = logits[:, :, c0[0] : c1[0], c0[1] : c1[1]]
53 | 
54 |         logits = logits.permute(1, 2, 3, 0)  # D H W C
55 |         output = logits.reshape(1, np.prod(logits.shape))
56 |         if softmax:
57 |             output = F.softmax(output, dim=-1)
58 |             output = output.reshape(logits.shape[1:])
59 |         return output
60 | 
--------------------------------------------------------------------------------
/INSTALLATION.md:
--------------------------------------------------------------------------------
1 | # INSTALLATION
2 | 
3 | To install the dependencies, execute the `scripts/install_dependencies.sh` script:
4 | 
5 | ```bash
6 | scripts/install_conda.sh # Skip this step if you already have conda installed.
7 | scripts/install_dependencies.sh
8 | ```
9 | 
10 | Please see the [README](README.md) for quick start instructions.
11 | 
12 | 
13 | Alternatively, you can follow the detailed instructions below to set up the software from scratch.
14 | 
15 | #### 2. PyRep and Coppelia Simulator
16 | 
17 | Follow instructions from my [PyRep fork](https://github.com/markusgrotz/PyRep); reproduced here for convenience:
18 | 
19 | PyRep requires version **4.1** of CoppeliaSim. Download:
20 | - [Ubuntu 20.04](https://www.coppeliarobotics.com/files/V4_1_0/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz)
21 | 
22 | Once you have downloaded CoppeliaSim, you can pull PyRep from git:
23 | 
24 | ```bash
25 | cd
26 | git clone https://github.com/markusgrotz/PyRep.git
27 | cd PyRep
28 | ```
29 | 
30 | Add the following to your *~/.bashrc* file: (__NOTE__: edit the path in the first line)
31 | 
32 | ```bash
33 | export COPPELIASIM_ROOT=/PATH/TO/COPPELIASIM/INSTALL/DIR
34 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT
35 | export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
36 | ```
37 | 
38 | Remember to source your bashrc (`source ~/.bashrc`) or
39 | zshrc (`source ~/.zshrc`) after this.
40 | 
41 | **Warning**: CoppeliaSim might cause conflicts with ROS workspaces.
42 | 
43 | Finally, install the Python library:
44 | 
45 | ```bash
46 | pip install -e .
47 | ```
48 | 
49 | You should be good to go!
50 | You could try running one of the examples in the *examples/* folder.
51 | 
52 | #### 3. RLBench
53 | 
54 | PerAct uses my [RLBench fork](https://github.com/markusgrotz/RLBench/tree/peract).
55 | 
56 | ```bash
57 | cd
58 | git clone https://github.com/markusgrotz/RLBench.git
59 | 
60 | cd RLBench
61 | pip install -e .
62 | ```
63 | 
64 | For [running in headless mode](https://github.com/MohitShridhar/RLBench/tree/peract#running-headless), task setups, and other issues, please refer to the [official repo](https://github.com/stepjam/RLBench).
65 | 
66 | #### 4. YARR
67 | 
68 | PerAct uses my [YARR fork](https://github.com/markusgrotz/YARR/).
69 | 
70 | ```bash
71 | cd
72 | git clone https://github.com/markusgrotz/YARR.git
73 | 
74 | cd YARR
75 | pip install -e .
76 | ```
77 | 
78 | 
79 | 
80 | #### RVT baseline
81 | 
82 | pip install git+https://github.com/NVlabs/RVT.git
83 | pip install -e . 
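As a quick sanity check after installation, you can launch a short training run driven by the Hydra configuration in `conf/config.yaml`. The command below is only a sketch: the task list, demo path, and log directory are placeholder values that have to match your own generated RLBench demos.

```bash
# Override the defaults from conf/config.yaml on the command line.
python train.py \
    method=PERACT_BC \
    'rlbench.tasks=[open_drawer]' \
    rlbench.demo_path=$HOME/rlbench_data/train \
    replay.batch_size=8 \
    framework.training_iterations=10000 \
    framework.logdir=/tmp/arm_test/
```

Evaluation is configured analogously through `conf/eval.yaml` and started with `eval.py`.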
84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /scripts/install_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # edit this line if you want to install the dependencies to another directory 5 | 6 | WORKSPACE_DIR=${HOME}/code 7 | ENVIRONMENT_NAME=rlbench 8 | 9 | basedir=$(dirname $0) 10 | basedir=$(readlink -f $basedir) 11 | 12 | 13 | if ! [ -x "$(command -v curl)" ]; then 14 | echo "Unable to find curl. installing." 15 | sudo apt install curl 16 | fi 17 | 18 | if ! [ -x "$(command -v git)" ]; then 19 | echo "Unable to find git. installing." 20 | sudo apt install git 21 | fi 22 | 23 | if ! [ -x "$(command -v conda)" ]; then 24 | echo "Unable to find conda" 25 | exit 1 26 | fi 27 | 28 | conda create -n ${ENVIRONMENT_NAME} python=3.8 29 | mamba install -n ${ENVIRONMENT_NAME} pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia 30 | 31 | 32 | export COPPELIASIM_ROOT=${WORKSPACE_DIR}/coppelia_sim 33 | mkdir -p $COPPELIASIM_ROOT 34 | 35 | TEMP_DIR=$(mktemp --tmpdir -d coppelia_XXXXXXXXXX) 36 | cd $TEMP_DIR 37 | 38 | curl -L -O https://www.coppeliarobotics.com/files/V4_1_0/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz 39 | tar -xvf CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz -C $COPPELIASIM_ROOT --strip-components 1 40 | rm -rf CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz 41 | 42 | CONDA_PREFIX=$(conda info --envs | grep -e "^${ENVIRONMENT_NAME}\ " | awk '{print $2}') 43 | mkdir -p ${CONDA_PREFIX}/etc/conda/activate.d/ 44 | cat > ${CONDA_PREFIX}/etc/conda/activate.d/coppelia_sim.sh <> ${HOME}/.ssh/known_hosts 39 | 40 | RUN curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 41 | RUN bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda 42 | RUN export PATH=/opt/conda/bin:${PATH} 43 | 44 | # Install code and dependencies 45 | 46 | WORKDIR ${HOME}/code 47 | 48 | RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && conda init bash 49 | RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && conda install mamba -c conda-forge 50 | #RUN conda config --set auto_activate_base false 51 | 52 | 53 | RUN git clone https://github.com/markusgrotz/peract_bimanual.git ${HOME}/code/peract_bimanual 54 | 55 | 56 | RUN eval "$(/opt/conda/bin/conda shell.bash hook)" && ${HOME}/code/peract_bimanual/scripts/install_dependencies.sh 57 | 58 | 59 | # Activate the environment by default 60 | RUN echo ". 
/opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 61 | echo "conda activate rlbench" >> ~/.bashrc 62 | 63 | 64 | WORKDIR /root/code/peract_bimanual 65 | 66 | # Default command 67 | CMD ["/bin/bash"] 68 | 69 | -------------------------------------------------------------------------------- /helpers/clip/core/attention.py: -------------------------------------------------------------------------------- 1 | """Attention module.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import cliport.models as models 9 | from cliport.utils import utils 10 | 11 | 12 | class Attention(nn.Module): 13 | """Attention (a.k.a Pick) module.""" 14 | 15 | def __init__(self, stream_fcn, in_shape, n_rotations, preprocess, cfg, device): 16 | super().__init__() 17 | self.stream_fcn = stream_fcn 18 | self.n_rotations = n_rotations 19 | self.preprocess = preprocess 20 | self.cfg = cfg 21 | self.device = device 22 | self.batchnorm = self.cfg["train"]["batchnorm"] 23 | 24 | self.padding = np.zeros((3, 2), dtype=int) 25 | max_dim = np.max(in_shape[:2]) 26 | pad = (max_dim - np.array(in_shape[:2])) / 2 27 | self.padding[:2] = pad.reshape(2, 1) 28 | 29 | in_shape = np.array(in_shape) 30 | in_shape += np.sum(self.padding, axis=1) 31 | in_shape = tuple(in_shape) 32 | self.in_shape = in_shape 33 | 34 | self.rotator = utils.ImageRotator(self.n_rotations) 35 | 36 | self._build_nets() 37 | 38 | def _build_nets(self): 39 | stream_one_fcn, _ = self.stream_fcn 40 | self.attn_stream = models.names[stream_one_fcn]( 41 | self.in_shape, 1, self.cfg, self.device 42 | ) 43 | print(f"Attn FCN: {stream_one_fcn}") 44 | 45 | def attend(self, x): 46 | return self.attn_stream(x) 47 | 48 | def forward(self, inp_img, softmax=True): 49 | """Forward pass.""" 50 | in_data = np.pad(inp_img, self.padding, mode="constant") 51 | in_shape = (1,) + in_data.shape 52 | in_data = in_data.reshape(in_shape) 53 | in_tens = torch.from_numpy(in_data).to( 54 | dtype=torch.float, device=self.device 55 | ) # [B W H 6] 56 | 57 | # Rotation pivot. 58 | pv = np.array(in_data.shape[1:3]) // 2 59 | 60 | # Rotate input. 61 | in_tens = in_tens.permute(0, 3, 1, 2) # [B 6 W H] 62 | in_tens = in_tens.repeat(self.n_rotations, 1, 1, 1) 63 | in_tens = self.rotator(in_tens, pivot=pv) 64 | 65 | # Forward pass. 66 | logits = [] 67 | for x in in_tens: 68 | lgts = self.attend(x) 69 | logits.append(lgts) 70 | logits = torch.cat(logits, dim=0) 71 | 72 | # Rotate back output. 
73 | logits = self.rotator(logits, reverse=True, pivot=pv) 74 | logits = torch.cat(logits, dim=0) 75 | c0 = self.padding[:2, 0] 76 | c1 = c0 + inp_img.shape[:2] 77 | logits = logits[:, :, c0[0] : c1[0], c0[1] : c1[1]] 78 | 79 | logits = logits.permute(1, 2, 3, 0) # [B W H 1] 80 | output = logits.reshape(1, np.prod(logits.shape)) 81 | if softmax: 82 | output = F.softmax(output, dim=-1) 83 | output = output.reshape(logits.shape[1:]) 84 | return output 85 | -------------------------------------------------------------------------------- /helpers/clip/core/unet.py: -------------------------------------------------------------------------------- 1 | # Credit: https://github.com/milesial/Pytorch-UNet/ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | # nn.BatchNorm2d(mid_channels), # (Mohit): argh... forgot to remove this batchnorm 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | # nn.BatchNorm2d(out_channels), # (Mohit): argh... forgot to remove this batchnorm 21 | nn.ReLU(inplace=True), 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.maxpool_conv(x) 39 | 40 | 41 | class Up(nn.Module): 42 | """Upscaling then double conv""" 43 | 44 | def __init__(self, in_channels, out_channels, bilinear=True): 45 | super().__init__() 46 | 47 | # if bilinear, use the normal convolutions to reduce the number of channels 48 | if bilinear: 49 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 50 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 51 | else: 52 | self.up = nn.ConvTranspose2d( 53 | in_channels, in_channels // 2, kernel_size=2, stride=2 54 | ) 55 | self.conv = DoubleConv(in_channels, out_channels) 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # input is CHW 60 | diffY = x2.size()[2] - x1.size()[2] 61 | diffX = x2.size()[3] - x1.size()[3] 62 | 63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) 64 | # if you have padding issues, see 65 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 66 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 67 | x = torch.cat([x2, x1], dim=1) 68 | return self.conv(x) 69 | 70 | 71 | class OutConv(nn.Module): 72 | def __init__(self, in_channels, out_channels): 73 | super(OutConv, self).__init__() 74 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 75 | 76 | def forward(self, x): 77 | return self.conv(x) 78 | -------------------------------------------------------------------------------- /model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: Perceiver-Actor 2 | 3 | Following [Model Cards for Model Reporting 
(Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf) we provide additional information on PerAct.
4 | 
5 | ## Model Details
6 | 
7 | 
8 | ### Overview
9 | - Developed by Shridhar et al. at University of Washington and NVIDIA. PerAct is an end-to-end behavior cloning agent that learns to perform a wide variety of language-conditioned manipulation tasks. PerAct uses a Transformer that exploits the 3D structure of _voxel patches_ to learn policies with just a few demonstrations per task.
10 | - Architecture: Transformer trained from scratch with end-to-end supervised learning.
11 | - Trained for 6-DoF manipulation tasks with objects that appear in tabletop scenes.
12 | 
13 | ### Model Date
14 | 
15 | Nov 2022
16 | 
17 | ### Documents
18 | 
19 | - [PerAct Paper](https://peract.github.io/paper/peract_corl2022.pdf)
20 | - [PerceiverIO Paper](https://arxiv.org/abs/2107.14795)
21 | - [C2FARM Paper](https://arxiv.org/abs/2106.12534)
22 | 
23 | 
24 | ## Model Use
25 | 
26 | - **Primary intended use case**: PerAct is intended for robotic manipulation research. We hope the benchmark and pre-trained models will enable researchers to study the capabilities of Transformers for end-to-end 6-DoF Manipulation. Specifically, we hope the setup serves as a reproducible framework for evaluating robustness and scaling capabilities of manipulation agents.
27 | - **Primary intended users**: Robotics researchers.
28 | - **Out-of-scope use cases**: Deployed use cases in real-world autonomous systems without human supervision during test-time are currently out-of-scope. Use cases that involve manipulating novel objects and observations with people are not recommended for safety-critical systems. The agent is also intended to be trained and evaluated with English language instructions.
29 | 
30 | ## Data
31 | 
32 | - Pre-training Data for CLIP's language encoder: See [OpenAI's Model Card](https://github.com/openai/CLIP/blob/main/model-card.md#data) for full details. **Note:** We do not use CLIP's vision encoders for any agents in the repo.
33 | - Manipulation Data for PerAct: The agent was trained with expert demonstrations. In simulation, we use oracle agents, and in the real world we use human demonstrations. Since the agent is used in few-shot settings with very limited data, the agent might exploit intended and unintended biases in the training demonstrations. Currently, these biases are limited to just objects that appear on tabletops.
34 | 
35 | 
36 | ## Limitations
37 | 
38 | - Depends on a sampling-based motion planner.
39 | - Hard to extend to dexterous and continuous manipulation tasks.
40 | - Lacks memory to solve tasks with ordering and history-based sequencing.
41 | - Exploits biases in training demonstrations.
42 | - Needs good hand-eye calibration.
43 | - Doesn't generalize to novel objects.
44 | - Struggles with grounding complex spatial relationships.
45 | - Does not predict task completion.
46 | 
47 | See Appendix L in the [paper](https://peract.github.io/paper/peract_corl2022.pdf) for an extended discussion.
--------------------------------------------------------------------------------
/agents/act_bc_lang/detr/models/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Various positional encodings for the transformer.
4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from agents.act_bc_lang.detr.util.misc import NestedTensor 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, tensor): 29 | x = tensor 30 | # mask = tensor_list.mask 31 | # assert mask is not None 32 | # not_mask = ~mask 33 | 34 | not_mask = torch.ones_like(x[0, [0]]) 35 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 36 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 37 | if self.normalize: 38 | eps = 1e-6 39 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 40 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 41 | 42 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 43 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 44 | 45 | pos_x = x_embed[:, :, :, None] / dim_t 46 | pos_y = y_embed[:, :, :, None] / dim_t 47 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 48 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 49 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 50 | return pos 51 | 52 | 53 | class PositionEmbeddingLearned(nn.Module): 54 | """ 55 | Absolute pos embedding, learned. 
56 | """ 57 | def __init__(self, num_pos_feats=256): 58 | super().__init__() 59 | self.row_embed = nn.Embedding(50, num_pos_feats) 60 | self.col_embed = nn.Embedding(50, num_pos_feats) 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self): 64 | nn.init.uniform_(self.row_embed.weight) 65 | nn.init.uniform_(self.col_embed.weight) 66 | 67 | def forward(self, tensor_list: NestedTensor): 68 | x = tensor_list.tensors 69 | h, w = x.shape[-2:] 70 | i = torch.arange(w, device=x.device) 71 | j = torch.arange(h, device=x.device) 72 | x_emb = self.col_embed(i) 73 | y_emb = self.row_embed(j) 74 | pos = torch.cat([ 75 | x_emb.unsqueeze(0).repeat(h, 1, 1), 76 | y_emb.unsqueeze(1).repeat(1, w, 1), 77 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 78 | return pos 79 | 80 | 81 | def build_position_encoding(args): 82 | N_steps = args.hidden_dim // 2 83 | if args.position_embedding in ('v2', 'sine'): 84 | # TODO find a better way of exposing other arguments 85 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 86 | elif args.position_embedding in ('v3', 'learned'): 87 | position_embedding = PositionEmbeddingLearned(N_steps) 88 | else: 89 | raise ValueError(f"not supported {args.position_embedding}") 90 | 91 | return position_embedding 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /agents/agent_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from omegaconf import DictConfig 5 | 6 | 7 | from yarr.agents.agent import BimanualAgent 8 | from yarr.agents.agent import LeaderFollowerAgent 9 | from yarr.agents.agent import Agent 10 | 11 | 12 | supported_agents = { 13 | "leader_follower": ("PERACT_BC", "RVT"), 14 | "independent": ("PERACT_BC", "RVT"), 15 | "bimanual": ("BIMANUAL_PERACT", "ACT_BC_LANG"), 16 | "unimanual": (), 17 | } 18 | 19 | 20 | def create_agent(cfg: DictConfig) -> Agent: 21 | method_name = cfg.method.name 22 | agent_type = cfg.method.agent_type 23 | 24 | logging.info("Using method %s with type %s", method_name, agent_type) 25 | 26 | assert method_name in supported_agents[agent_type] 27 | 28 | agent_fn = agent_fn_by_name(method_name) 29 | 30 | if agent_type == "leader_follower": 31 | checkpoint_name_prefix = cfg.framework.checkpoint_name_prefix 32 | cfg.method.robot_name = "right" 33 | cfg.framework.checkpoint_name_prefix = ( 34 | f"{checkpoint_name_prefix}_{method_name.lower()}_leader" 35 | ) 36 | leader_agent = agent_fn(cfg) 37 | 38 | cfg.method.robot_name = "left" 39 | cfg.framework.checkpoint_name_prefix = ( 40 | f"{checkpoint_name_prefix}_{method_name.lower()}_follower" 41 | ) 42 | cfg.method.low_dim_size = ( 43 | cfg.method.low_dim_size + 8 44 | ) # also add the action size 45 | follower_agent = agent_fn(cfg) 46 | 47 | cfg.method.robot_name = "bimanual" 48 | 49 | return LeaderFollowerAgent(leader_agent, follower_agent) 50 | 51 | elif agent_type == "independent": 52 | checkpoint_name_prefix = cfg.framework.checkpoint_name_prefix 53 | cfg.method.robot_name = "right" 54 | cfg.framework.checkpoint_name_prefix = ( 55 | f"{checkpoint_name_prefix}_{method_name.lower()}_right" 56 | ) 57 | right_agent = agent_fn(cfg) 58 | 59 | cfg.method.robot_name = "left" 60 | cfg.framework.checkpoint_name_prefix = ( 61 | f"{checkpoint_name_prefix}_{method_name.lower()}_left" 62 | ) 63 | left_agent = agent_fn(cfg) 64 | 65 | cfg.method.robot_name = "bimanual" 66 | 67 | return BimanualAgent(right_agent, left_agent) 68 | elif agent_type == "bimanual" or agent_type == "unimanual": 69 | return agent_fn(cfg) 70 | else: 71 | raise Exception("invalid agent type") 72 | 73 | 74 | def agent_fn_by_name(method_name: str) -> Agent: 75 | if method_name == "ARM": 76 | from agents import arm 77 | 78 | raise NotImplementedError("ARM not yet supported for eval.py") 79 | elif method_name == "BC_LANG": 80 | from agents.baselines import bc_lang 81 | 82 | return bc_lang.launch_utils.create_agent 83 | elif method_name == "VIT_BC_LANG": 84 | from agents.baselines import vit_bc_lang 85 | 86 | return vit_bc_lang.launch_utils.create_agent 87 | elif method_name == "C2FARM_LINGUNET_BC": 88 | from agents import c2farm_lingunet_bc 89 | 90 | return c2farm_lingunet_bc.launch_utils.create_agent 91 | elif method_name.startswith("PERACT_BC"): 92 | from agents import peract_bc 93 | 94 | return peract_bc.launch_utils.create_agent 95 | elif method_name.startswith("BIMANUAL_PERACT"): 96 | from agents import bimanual_peract 97 | 98 | return bimanual_peract.launch_utils.create_agent 99 | elif method_name.startswith("RVT"): 100 | from agents import rvt 101 | 102 | return rvt.launch_utils.create_agent 103 | elif method_name.startswith("ACT_BC_LANG"): 104 | from agents import act_bc_lang 105 | 106 | return 
act_bc_lang.launch_utils.create_agent 107 | elif method_name == "PERACT_RL": 108 | raise NotImplementedError("PERACT_RL not yet supported for eval.py") 109 | 110 | else: 111 | raise ValueError("Method %s does not exists." % method_name) 112 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import logging 3 | import os 4 | import sys 5 | from datetime import datetime 6 | 7 | import peract_config 8 | 9 | import hydra 10 | from omegaconf import DictConfig, OmegaConf, ListConfig 11 | 12 | import run_seed_fn 13 | from helpers.observation_utils import create_obs_config 14 | 15 | import torch.multiprocessing as mp 16 | 17 | 18 | @hydra.main(config_name="config", config_path="conf") 19 | def main(cfg: DictConfig) -> None: 20 | cfg_yaml = OmegaConf.to_yaml(cfg) 21 | logging.info("\n" + cfg_yaml) 22 | 23 | peract_config.on_config(cfg) 24 | 25 | cfg.rlbench.cameras = ( 26 | cfg.rlbench.cameras 27 | if isinstance(cfg.rlbench.cameras, ListConfig) 28 | else [cfg.rlbench.cameras] 29 | ) 30 | 31 | # sanity check if rgb is not used as camera name 32 | for camera_name in cfg.rlbench.cameras: 33 | assert "rgb" not in camera_name 34 | 35 | obs_config = create_obs_config( 36 | cfg.rlbench.cameras, cfg.rlbench.camera_resolution, cfg.method.name 37 | ) 38 | 39 | cwd = os.getcwd() 40 | logging.info("CWD:" + os.getcwd()) 41 | 42 | if cfg.framework.start_seed >= 0: 43 | # seed specified 44 | start_seed = cfg.framework.start_seed 45 | elif ( 46 | cfg.framework.start_seed == -1 47 | and len(list(filter(lambda x: "seed" in x, os.listdir(cwd)))) > 0 48 | ): 49 | # unspecified seed; use largest existing seed plus one 50 | largest_seed = max( 51 | [ 52 | int(n.replace("seed", "")) 53 | for n in list(filter(lambda x: "seed" in x, os.listdir(cwd))) 54 | ] 55 | ) 56 | start_seed = largest_seed + 1 57 | else: 58 | # start with seed 0 59 | start_seed = 0 60 | 61 | seed_folder = os.path.join(os.getcwd(), "seed%d" % start_seed) 62 | os.makedirs(seed_folder, exist_ok=True) 63 | 64 | start_time = datetime.now() 65 | with open(os.path.join(seed_folder, "config.yaml"), "w") as f: 66 | f.write(cfg_yaml) 67 | 68 | # check if previous checkpoints already exceed the number of desired training iterations 69 | # if so, exit the script 70 | latest_weight = 0 71 | weights_folder = os.path.join(seed_folder, "weights") 72 | if os.path.isdir(weights_folder) and len(os.listdir(weights_folder)) > 0: 73 | weights = os.listdir(weights_folder) 74 | latest_weight = sorted(map(int, weights))[-1] 75 | if latest_weight >= cfg.framework.training_iterations: 76 | logging.info( 77 | "Agent was already trained for %d iterations. Exiting." % latest_weight 78 | ) 79 | sys.exit(0) 80 | 81 | with open(os.path.join(seed_folder, "training.log"), "a") as f: 82 | f.write( 83 | f"# Starting training from weights: {latest_weight} to {cfg.framework.training_iterations}" 84 | ) 85 | f.write(f"# Training started on: {start_time.isoformat()}") 86 | f.write(os.linesep) 87 | 88 | # run train jobs with multiple seeds (sequentially) 89 | for seed in range(start_seed, start_seed + cfg.framework.seeds): 90 | logging.info("Starting seed %d." 
% seed) 91 | 92 | world_size = cfg.ddp.num_devices 93 | mp.spawn( 94 | run_seed_fn.run_seed, 95 | args=( 96 | cfg, 97 | obs_config, 98 | seed, 99 | world_size, 100 | ), 101 | nprocs=world_size, 102 | join=True, 103 | ) 104 | 105 | end_time = datetime.now() 106 | duration = end_time - start_time 107 | with open(os.path.join(seed_folder, "training.log"), "a") as f: 108 | f.write(f"# Training finished on: {end_time.isoformat()}") 109 | f.write(f"# Took {duration.total_seconds()}") 110 | f.write(os.linesep) 111 | f.write(os.linesep) 112 | 113 | 114 | if __name__ == "__main__": 115 | peract_config.on_init() 116 | main() 117 | -------------------------------------------------------------------------------- /helpers/clip/core/transport.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cliport.models as models 3 | from cliport.utils import utils 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Transport(nn.Module): 11 | def __init__( 12 | self, stream_fcn, in_shape, n_rotations, crop_size, preprocess, cfg, device 13 | ): 14 | """Transport (a.k.a Place) module.""" 15 | super().__init__() 16 | 17 | self.iters = 0 18 | self.stream_fcn = stream_fcn 19 | self.n_rotations = n_rotations 20 | self.crop_size = crop_size # crop size must be N*16 (e.g. 96) 21 | self.preprocess = preprocess 22 | self.cfg = cfg 23 | self.device = device 24 | self.batchnorm = self.cfg["train"]["batchnorm"] 25 | 26 | self.pad_size = int(self.crop_size / 2) 27 | self.padding = np.zeros((3, 2), dtype=int) 28 | self.padding[:2, :] = self.pad_size 29 | 30 | in_shape = np.array(in_shape) 31 | in_shape = tuple(in_shape) 32 | self.in_shape = in_shape 33 | 34 | # Crop before network (default from Transporters CoRL 2020). 
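        # The crop taken around the pick location in forward() is encoded by
        # `self.query_resnet` and then used as the correlation kernel that
        # `correlate()` slides over the `self.key_resnet` features of the full
        # padded image, so the kernel input keeps the crop's spatial size and
        # the original channel depth.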
35 | self.kernel_shape = (self.crop_size, self.crop_size, self.in_shape[2]) 36 | 37 | if not hasattr(self, "output_dim"): 38 | self.output_dim = 3 39 | if not hasattr(self, "kernel_dim"): 40 | self.kernel_dim = 3 41 | 42 | self.rotator = utils.ImageRotator(self.n_rotations) 43 | 44 | self._build_nets() 45 | 46 | def _build_nets(self): 47 | stream_one_fcn, _ = self.stream_fcn 48 | model = models.names[stream_one_fcn] 49 | self.key_resnet = model(self.in_shape, self.output_dim, self.cfg, self.device) 50 | self.query_resnet = model( 51 | self.kernel_shape, self.kernel_dim, self.cfg, self.device 52 | ) 53 | print(f"Transport FCN: {stream_one_fcn}") 54 | 55 | def correlate(self, in0, in1, softmax): 56 | """Correlate two input tensors.""" 57 | output = F.conv2d(in0, in1, padding=(self.pad_size, self.pad_size)) 58 | output = F.interpolate( 59 | output, size=(in0.shape[-2], in0.shape[-1]), mode="bilinear" 60 | ) 61 | output = output[ 62 | :, :, self.pad_size : -self.pad_size, self.pad_size : -self.pad_size 63 | ] 64 | if softmax: 65 | output_shape = output.shape 66 | output = output.reshape((1, np.prod(output.shape))) 67 | output = F.softmax(output, dim=-1) 68 | output = output.reshape(output_shape[1:]) 69 | return output 70 | 71 | def transport(self, in_tensor, crop): 72 | logits = self.key_resnet(in_tensor) 73 | kernel = self.query_resnet(crop) 74 | return logits, kernel 75 | 76 | def forward(self, inp_img, p, softmax=True): 77 | """Forward pass.""" 78 | img_unprocessed = np.pad(inp_img, self.padding, mode="constant") 79 | input_data = img_unprocessed 80 | in_shape = (1,) + input_data.shape 81 | input_data = input_data.reshape(in_shape) # [B W H D] 82 | in_tensor = torch.from_numpy(input_data).to( 83 | dtype=torch.float, device=self.device 84 | ) 85 | 86 | # Rotation pivot. 87 | pv = np.array([p[0], p[1]]) + self.pad_size 88 | 89 | # Crop before network (default from Transporters CoRL 2020). 90 | hcrop = self.pad_size 91 | in_tensor = in_tensor.permute(0, 3, 1, 2) # [B D W H] 92 | 93 | crop = in_tensor.repeat(self.n_rotations, 1, 1, 1) 94 | crop = self.rotator(crop, pivot=pv) 95 | crop = torch.cat(crop, dim=0) 96 | crop = crop[:, :, pv[0] - hcrop : pv[0] + hcrop, pv[1] - hcrop : pv[1] + hcrop] 97 | 98 | logits, kernel = self.transport(in_tensor, crop) 99 | 100 | # TODO(Mohit): Crop after network. Broken for now. 
101 | # in_tensor = in_tensor.permute(0, 3, 1, 2) 102 | # logits, crop = self.transport(in_tensor) 103 | # crop = crop.repeat(self.n_rotations, 1, 1, 1) 104 | # crop = self.rotator(crop, pivot=pv) 105 | # crop = torch.cat(crop, dim=0) 106 | 107 | # kernel = crop[:, :, pv[0]-hcrop:pv[0]+hcrop, pv[1]-hcrop:pv[1]+hcrop] 108 | # kernel = crop[:, :, p[0]:(p[0] + self.crop_size), p[1]:(p[1] + self.crop_size)] 109 | 110 | return self.correlate(logits, kernel, softmax) 111 | -------------------------------------------------------------------------------- /agents/bimanual_peract/launch_utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from ARM 2 | # Source: https://github.com/stepjam/ARM 3 | # License: https://github.com/stepjam/ARM/LICENSE 4 | 5 | 6 | from helpers.preprocess_agent import PreprocessAgent 7 | 8 | from agents.bimanual_peract.perceiver_lang_io import PerceiverVoxelLangEncoder 9 | from agents.bimanual_peract.qattention_peract_bc_agent import QAttentionPerActBCAgent 10 | from agents.bimanual_peract.qattention_stack_agent import QAttentionStackAgent 11 | 12 | from omegaconf import DictConfig 13 | 14 | 15 | def create_agent(cfg: DictConfig): 16 | depth_0bounds = cfg.rlbench.scene_bounds 17 | cam_resolution = cfg.rlbench.camera_resolution 18 | 19 | num_rotation_classes = int(360.0 // cfg.method.rotation_resolution) 20 | qattention_agents = [] 21 | for depth, vox_size in enumerate(cfg.method.voxel_sizes): 22 | last = depth == len(cfg.method.voxel_sizes) - 1 23 | perceiver_encoder = PerceiverVoxelLangEncoder( 24 | depth=cfg.method.transformer_depth, 25 | iterations=cfg.method.transformer_iterations, 26 | voxel_size=vox_size, 27 | initial_dim=3 + 3 + 1 + 3, 28 | low_dim_size=cfg.method.low_dim_size, 29 | layer=depth, 30 | num_rotation_classes=num_rotation_classes if last else 0, 31 | num_grip_classes=2 if last else 0, 32 | num_collision_classes=2 if last else 0, 33 | input_axis=3, 34 | num_latents=cfg.method.num_latents, 35 | latent_dim=cfg.method.latent_dim, 36 | cross_heads=cfg.method.cross_heads, 37 | latent_heads=cfg.method.latent_heads, 38 | cross_dim_head=cfg.method.cross_dim_head, 39 | latent_dim_head=cfg.method.latent_dim_head, 40 | weight_tie_layers=False, 41 | activation=cfg.method.activation, 42 | pos_encoding_with_lang=cfg.method.pos_encoding_with_lang, 43 | input_dropout=cfg.method.input_dropout, 44 | attn_dropout=cfg.method.attn_dropout, 45 | decoder_dropout=cfg.method.decoder_dropout, 46 | lang_fusion_type=cfg.method.lang_fusion_type, 47 | voxel_patch_size=cfg.method.voxel_patch_size, 48 | voxel_patch_stride=cfg.method.voxel_patch_stride, 49 | no_skip_connection=cfg.method.no_skip_connection, 50 | no_perceiver=cfg.method.no_perceiver, 51 | no_language=cfg.method.no_language, 52 | final_dim=cfg.method.final_dim, 53 | ) 54 | 55 | qattention_agent = QAttentionPerActBCAgent( 56 | layer=depth, 57 | coordinate_bounds=depth_0bounds, 58 | perceiver_encoder=perceiver_encoder, 59 | camera_names=cfg.rlbench.cameras, 60 | voxel_size=vox_size, 61 | bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None, 62 | image_crop_size=cfg.method.image_crop_size, 63 | lr=cfg.method.lr, 64 | training_iterations=cfg.framework.training_iterations, 65 | lr_scheduler=cfg.method.lr_scheduler, 66 | num_warmup_steps=cfg.method.num_warmup_steps, 67 | trans_loss_weight=cfg.method.trans_loss_weight, 68 | rot_loss_weight=cfg.method.rot_loss_weight, 69 | grip_loss_weight=cfg.method.grip_loss_weight, 70 | 
collision_loss_weight=cfg.method.collision_loss_weight, 71 | include_low_dim_state=True, 72 | image_resolution=cam_resolution, 73 | batch_size=cfg.replay.batch_size, 74 | voxel_feature_size=3, 75 | lambda_weight_l2=cfg.method.lambda_weight_l2, 76 | num_rotation_classes=num_rotation_classes, 77 | rotation_resolution=cfg.method.rotation_resolution, 78 | transform_augmentation=cfg.method.transform_augmentation.apply_se3, 79 | transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz, 80 | transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy, 81 | transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution, 82 | optimizer_type=cfg.method.optimizer, 83 | num_devices=cfg.ddp.num_devices, 84 | ) 85 | qattention_agents.append(qattention_agent) 86 | 87 | rotation_agent = QAttentionStackAgent( 88 | qattention_agents=qattention_agents, 89 | rotation_resolution=cfg.method.rotation_resolution, 90 | camera_names=cfg.rlbench.cameras, 91 | ) 92 | preprocess_agent = PreprocessAgent(pose_agent=rotation_agent) 93 | return preprocess_agent 94 | -------------------------------------------------------------------------------- /agents/peract_bc/launch_utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from ARM 2 | # Source: https://github.com/stepjam/ARM 3 | # License: https://github.com/stepjam/ARM/LICENSE 4 | 5 | 6 | from helpers.preprocess_agent import PreprocessAgent 7 | from agents.peract_bc.perceiver_lang_io import PerceiverVoxelLangEncoder 8 | from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent 9 | from agents.peract_bc.qattention_stack_agent import QAttentionStackAgent 10 | 11 | from omegaconf import DictConfig 12 | 13 | 14 | def create_agent(cfg: DictConfig): 15 | LATENT_SIZE = 64 16 | depth_0bounds = cfg.rlbench.scene_bounds 17 | cam_resolution = cfg.rlbench.camera_resolution 18 | 19 | num_rotation_classes = int(360.0 // cfg.method.rotation_resolution) 20 | qattention_agents = [] 21 | for depth, vox_size in enumerate(cfg.method.voxel_sizes): 22 | last = depth == len(cfg.method.voxel_sizes) - 1 23 | perceiver_encoder = PerceiverVoxelLangEncoder( 24 | depth=cfg.method.transformer_depth, 25 | iterations=cfg.method.transformer_iterations, 26 | voxel_size=vox_size, 27 | initial_dim=3 + 3 + 1 + 3, 28 | low_dim_size=cfg.method.low_dim_size, 29 | layer=depth, 30 | num_rotation_classes=num_rotation_classes if last else 0, 31 | num_grip_classes=2 if last else 0, 32 | num_collision_classes=2 if last else 0, 33 | input_axis=3, 34 | num_latents=cfg.method.num_latents, 35 | latent_dim=cfg.method.latent_dim, 36 | cross_heads=cfg.method.cross_heads, 37 | latent_heads=cfg.method.latent_heads, 38 | cross_dim_head=cfg.method.cross_dim_head, 39 | latent_dim_head=cfg.method.latent_dim_head, 40 | weight_tie_layers=False, 41 | activation=cfg.method.activation, 42 | pos_encoding_with_lang=cfg.method.pos_encoding_with_lang, 43 | input_dropout=cfg.method.input_dropout, 44 | attn_dropout=cfg.method.attn_dropout, 45 | decoder_dropout=cfg.method.decoder_dropout, 46 | lang_fusion_type=cfg.method.lang_fusion_type, 47 | voxel_patch_size=cfg.method.voxel_patch_size, 48 | voxel_patch_stride=cfg.method.voxel_patch_stride, 49 | no_skip_connection=cfg.method.no_skip_connection, 50 | no_perceiver=cfg.method.no_perceiver, 51 | no_language=cfg.method.no_language, 52 | final_dim=cfg.method.final_dim, 53 | ) 54 | 55 | qattention_agent = QAttentionPerActBCAgent( 56 | layer=depth, 57 | 
coordinate_bounds=depth_0bounds, 58 | perceiver_encoder=perceiver_encoder, 59 | camera_names=cfg.rlbench.cameras, 60 | voxel_size=vox_size, 61 | bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None, 62 | image_crop_size=cfg.method.image_crop_size, 63 | lr=cfg.method.lr, 64 | training_iterations=cfg.framework.training_iterations, 65 | lr_scheduler=cfg.method.lr_scheduler, 66 | num_warmup_steps=cfg.method.num_warmup_steps, 67 | trans_loss_weight=cfg.method.trans_loss_weight, 68 | rot_loss_weight=cfg.method.rot_loss_weight, 69 | grip_loss_weight=cfg.method.grip_loss_weight, 70 | collision_loss_weight=cfg.method.collision_loss_weight, 71 | include_low_dim_state=True, 72 | image_resolution=cam_resolution, 73 | batch_size=cfg.replay.batch_size, 74 | voxel_feature_size=3, 75 | lambda_weight_l2=cfg.method.lambda_weight_l2, 76 | num_rotation_classes=num_rotation_classes, 77 | rotation_resolution=cfg.method.rotation_resolution, 78 | transform_augmentation=cfg.method.transform_augmentation.apply_se3, 79 | transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz, 80 | transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy, 81 | transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution, 82 | optimizer_type=cfg.method.optimizer, 83 | num_devices=cfg.ddp.num_devices, 84 | checkpoint_name_prefix=cfg.framework.checkpoint_name_prefix, 85 | ) 86 | qattention_agents.append(qattention_agent) 87 | 88 | rotation_agent = QAttentionStackAgent( 89 | qattention_agents=qattention_agents, 90 | rotation_resolution=cfg.method.rotation_resolution, 91 | camera_names=cfg.rlbench.cameras, 92 | ) 93 | preprocess_agent = PreprocessAgent(pose_agent=rotation_agent) 94 | return preprocess_agent 95 | -------------------------------------------------------------------------------- /agents/act_bc_lang/detr/models/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Backbone modules. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from typing import Dict, List 13 | 14 | from agents.act_bc_lang.detr.util.misc import NestedTensor, is_main_process 15 | 16 | from .position_encoding import build_position_encoding 17 | 18 | 19 | class FrozenBatchNorm2d(torch.nn.Module): 20 | """ 21 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 22 | 23 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 24 | without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] 25 | produce nans. 
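    In forward() the frozen statistics are folded into a single affine map, i.e.
    with eps = 1e-5:

        scale = weight * (running_var + eps).rsqrt()
        out   = x * scale + (bias - running_mean * scale)

    which equals standard batch norm evaluated with fixed statistics and no update.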
26 | """ 27 | 28 | def __init__(self, n): 29 | super(FrozenBatchNorm2d, self).__init__() 30 | self.register_buffer("weight", torch.ones(n)) 31 | self.register_buffer("bias", torch.zeros(n)) 32 | self.register_buffer("running_mean", torch.zeros(n)) 33 | self.register_buffer("running_var", torch.ones(n)) 34 | 35 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 36 | missing_keys, unexpected_keys, error_msgs): 37 | num_batches_tracked_key = prefix + 'num_batches_tracked' 38 | if num_batches_tracked_key in state_dict: 39 | del state_dict[num_batches_tracked_key] 40 | 41 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 42 | state_dict, prefix, local_metadata, strict, 43 | missing_keys, unexpected_keys, error_msgs) 44 | 45 | def forward(self, x): 46 | # move reshapes to the beginning 47 | # to make it fuser-friendly 48 | w = self.weight.reshape(1, -1, 1, 1) 49 | b = self.bias.reshape(1, -1, 1, 1) 50 | rv = self.running_var.reshape(1, -1, 1, 1) 51 | rm = self.running_mean.reshape(1, -1, 1, 1) 52 | eps = 1e-5 53 | scale = w * (rv + eps).rsqrt() 54 | bias = b - rm * scale 55 | return x * scale + bias 56 | 57 | 58 | class BackboneBase(nn.Module): 59 | 60 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 61 | super().__init__() 62 | # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? 63 | # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 64 | # parameter.requires_grad_(False) 65 | if return_interm_layers: 66 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 67 | else: 68 | return_layers = {'layer4': "0"} 69 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 70 | self.num_channels = num_channels 71 | 72 | def forward(self, tensor): 73 | xs = self.body(tensor) 74 | return xs 75 | # out: Dict[str, NestedTensor] = {} 76 | # for name, x in xs.items(): 77 | # m = tensor_list.mask 78 | # assert m is not None 79 | # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 80 | # out[name] = NestedTensor(x, mask) 81 | # return out 82 | 83 | 84 | class Backbone(BackboneBase): 85 | """ResNet backbone with frozen BatchNorm.""" 86 | def __init__(self, name: str, 87 | train_backbone: bool, 88 | return_interm_layers: bool, 89 | dilation: bool): 90 | backbone = getattr(torchvision.models, name)( 91 | replace_stride_with_dilation=[False, False, dilation], 92 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm?? 
93 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 94 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 95 | 96 | 97 | class Joiner(nn.Sequential): 98 | def __init__(self, backbone, position_embedding): 99 | super().__init__(backbone, position_embedding) 100 | 101 | def forward(self, tensor_list: NestedTensor): 102 | xs = self[0](tensor_list) 103 | out: List[NestedTensor] = [] 104 | pos = [] 105 | for name, x in xs.items(): 106 | out.append(x) 107 | # position encoding 108 | pos.append(self[1](x).to(x.dtype)) 109 | 110 | return out, pos 111 | 112 | 113 | def build_backbone(args): 114 | position_embedding = build_position_encoding(args) 115 | train_backbone = args.lr_backbone > 0 116 | return_interm_layers = args.masks 117 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 118 | model = Joiner(backbone, position_embedding) 119 | model.num_channels = backbone.num_channels 120 | return model 121 | -------------------------------------------------------------------------------- /agents/act_bc_lang/act_policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import torchvision.transforms as transforms 5 | 6 | from agents.act_bc_lang.detr.build import ( 7 | build_ACT_model_and_optimizer, 8 | build_CNNMLP_model_and_optimizer, 9 | ) 10 | 11 | 12 | class ACTPolicy(nn.Module): 13 | def __init__(self, args): 14 | super().__init__() 15 | model, optimizer = build_ACT_model_and_optimizer(args) 16 | self.model = model # CVAE decoder 17 | self.optimizer = optimizer 18 | self.kl_weight = args.kl_weight 19 | print(f"KL Weight {self.kl_weight}") 20 | 21 | def forward(self, qpos, image, actions=None, is_pad=None): 22 | env_state = None 23 | 24 | if actions is not None: # training time 25 | actions = actions[:, : self.model.num_queries] 26 | is_pad = is_pad[:, : self.model.num_queries] 27 | 28 | a_hat, is_pad_hat, (mu, logvar) = self.model( 29 | qpos, image, env_state, actions, is_pad 30 | ) 31 | total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar) 32 | loss_dict = dict() 33 | 34 | right_actions_joints, right_a_hat_joints = ( 35 | actions[:, :, 0:8], 36 | a_hat[:, :, 0:8], 37 | ) 38 | right_actions_gripper, right_a_hat_gripper = ( 39 | actions[:, :, 7], 40 | a_hat[:, :, 7], 41 | ) 42 | left_actions_joints, left_a_hat_joints = ( 43 | actions[:, :, 8:16], 44 | a_hat[:, :, 8:16], 45 | ) 46 | left_actions_gripper, left_a_hat_gripper = ( 47 | actions[:, :, 15], 48 | a_hat[:, :, 15], 49 | ) 50 | 51 | # use L1 loss for joints 52 | right_l1_loss = F.l1_loss( 53 | right_a_hat_joints, right_actions_joints, reduction="none" 54 | ) 55 | right_l1 = (right_l1_loss * ~is_pad.unsqueeze(-1)).mean() 56 | 57 | left_l1_loss = F.l1_loss( 58 | left_a_hat_joints, left_actions_joints, reduction="none" 59 | ) 60 | left_l1 = (left_l1_loss * ~is_pad.unsqueeze(-1)).mean() 61 | 62 | l1 = right_l1 + left_l1 63 | 64 | right_gripper_l1_loss = F.l1_loss( 65 | right_a_hat_gripper, right_actions_gripper, reduction="none" 66 | ) 67 | right_gripper_l1_loss = (right_gripper_l1_loss * ~is_pad).mean() 68 | 69 | left_gripper_l1_loss = F.l1_loss( 70 | left_a_hat_gripper, left_actions_gripper, reduction="none" 71 | ) 72 | left_gripper_l1_loss = (left_gripper_l1_loss * ~is_pad).mean() 73 | 74 | gripper_l1 = right_gripper_l1_loss + left_gripper_l1_loss 75 | loss_dict["right_l1"] = right_l1 76 | loss_dict["left_l1"] = left_l1 77 | 78 | 
loss_dict["l1"] = l1 79 | loss_dict["gripper_l1"] = gripper_l1 80 | 81 | loss_dict["kl"] = total_kld[0] 82 | loss_dict["total_losses"] = ( 83 | loss_dict["l1"] + loss_dict["kl"] * self.kl_weight 84 | ) 85 | return loss_dict 86 | else: # inference time 87 | a_hat, _, (_, _) = self.model( 88 | qpos, image, env_state 89 | ) # no action, sample from prior 90 | return a_hat 91 | 92 | def configure_optimizers(self): 93 | return self.optimizer 94 | 95 | 96 | class CNNMLPPolicy(nn.Module): 97 | def __init__(self, args): 98 | super().__init__() 99 | model, optimizer = build_CNNMLP_model_and_optimizer(args) 100 | self.model = model # decoder 101 | self.optimizer = optimizer 102 | 103 | def forward(self, qpos, image, actions=None, is_pad=None): 104 | env_state = None # TODO 105 | 106 | if actions is not None: # training time 107 | actions = actions[:, 0] 108 | a_hat = self.model(qpos, image, env_state, actions) 109 | mse = F.mse_loss(actions, a_hat) 110 | loss_dict = dict() 111 | loss_dict["mse"] = mse 112 | loss_dict["loss"] = loss_dict["mse"] 113 | return loss_dict 114 | else: # inference time 115 | a_hat = self.model(qpos, image, env_state) # no action, sample from prior 116 | return a_hat 117 | 118 | def configure_optimizers(self): 119 | return self.optimizer 120 | 121 | 122 | def kl_divergence(mu, logvar): 123 | batch_size = mu.size(0) 124 | assert batch_size != 0 125 | if mu.data.ndimension() == 4: 126 | mu = mu.view(mu.size(0), mu.size(1)) 127 | if logvar.data.ndimension() == 4: 128 | logvar = logvar.view(logvar.size(0), logvar.size(1)) 129 | 130 | klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()) 131 | total_kld = klds.sum(1).mean(0, True) 132 | dimension_wise_kld = klds.mean(0) 133 | mean_kld = klds.mean(1).mean(0, True) 134 | 135 | return total_kld, dimension_wise_kld, mean_kld 136 | -------------------------------------------------------------------------------- /agents/peract_bc/qattention_stack_agent.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from yarr.agents.agent import Agent, ActResult, Summary 5 | 6 | import numpy as np 7 | 8 | from helpers import utils 9 | from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent 10 | 11 | NAME = "QAttentionStackAgent" 12 | 13 | 14 | class QAttentionStackAgent(Agent): 15 | def __init__( 16 | self, 17 | qattention_agents: List[QAttentionPerActBCAgent], 18 | rotation_resolution: float, 19 | camera_names: List[str], 20 | rotation_prediction_depth: int = 0, 21 | ): 22 | super(QAttentionStackAgent, self).__init__() 23 | self._qattention_agents = qattention_agents 24 | self._rotation_resolution = rotation_resolution 25 | self._camera_names = camera_names 26 | self._rotation_prediction_depth = rotation_prediction_depth 27 | 28 | def build(self, training: bool, device=None) -> None: 29 | self._device = device 30 | if self._device is None: 31 | self._device = torch.device("cpu") 32 | for qa in self._qattention_agents: 33 | qa.build(training, device) 34 | 35 | def update(self, step: int, replay_sample: dict) -> dict: 36 | priorities = 0 37 | total_losses = 0.0 38 | for qa in self._qattention_agents: 39 | update_dict = qa.update(step, replay_sample) 40 | replay_sample.update(update_dict) 41 | total_losses += update_dict["total_loss"] 42 | return { 43 | "total_losses": total_losses, 44 | } 45 | 46 | def act(self, step: int, observation: dict, deterministic=False) -> ActResult: 47 | observation_elements = {} 48 | translation_results, 
rot_grip_results, ignore_collisions_results = [], [], [] 49 | infos = {} 50 | for depth, qagent in enumerate(self._qattention_agents): 51 | act_results = qagent.act(step, observation, deterministic) 52 | attention_coordinate = ( 53 | act_results.observation_elements["attention_coordinate"].cpu().numpy() 54 | ) 55 | observation_elements[ 56 | "attention_coordinate_layer_%d" % depth 57 | ] = attention_coordinate[0] 58 | 59 | translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action 60 | translation_results.append(translation_idxs) 61 | if rot_grip_idxs is not None: 62 | rot_grip_results.append(rot_grip_idxs) 63 | if ignore_collisions_idxs is not None: 64 | ignore_collisions_results.append(ignore_collisions_idxs) 65 | 66 | observation["attention_coordinate"] = act_results.observation_elements[ 67 | "attention_coordinate" 68 | ] 69 | observation["prev_layer_voxel_grid"] = act_results.observation_elements[ 70 | "prev_layer_voxel_grid" 71 | ] 72 | observation["prev_layer_bounds"] = act_results.observation_elements[ 73 | "prev_layer_bounds" 74 | ] 75 | 76 | for n in self._camera_names: 77 | px, py = utils.point_to_pixel_index( 78 | attention_coordinate[0], 79 | observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(), 80 | observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(), 81 | ) 82 | pc_t = torch.tensor( 83 | [[[py, px]]], dtype=torch.float32, device=self._device 84 | ) 85 | observation["%s_pixel_coord" % n] = pc_t 86 | observation_elements["%s_pixel_coord" % n] = [py, px] 87 | 88 | infos.update(act_results.info) 89 | 90 | rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy() 91 | ignore_collisions = float( 92 | torch.cat(ignore_collisions_results, 1)[0].cpu().numpy() 93 | ) 94 | observation_elements["trans_action_indicies"] = ( 95 | torch.cat(translation_results, 1)[0].cpu().numpy() 96 | ) 97 | observation_elements["rot_grip_action_indicies"] = rgai 98 | continuous_action = np.concatenate( 99 | [ 100 | act_results.observation_elements["attention_coordinate"] 101 | .cpu() 102 | .numpy()[0], 103 | utils.discrete_euler_to_quaternion( 104 | rgai[-4:-1], self._rotation_resolution 105 | ), 106 | rgai[-1:], 107 | [ignore_collisions], 108 | ] 109 | ) 110 | return ActResult( 111 | continuous_action, observation_elements=observation_elements, info=infos 112 | ) 113 | 114 | def update_summaries(self) -> List[Summary]: 115 | summaries = [] 116 | for qa in self._qattention_agents: 117 | summaries.extend(qa.update_summaries()) 118 | return summaries 119 | 120 | def act_summaries(self) -> List[Summary]: 121 | s = [] 122 | for qa in self._qattention_agents: 123 | s.extend(qa.act_summaries()) 124 | return s 125 | 126 | def load_weights(self, savedir: str): 127 | for qa in self._qattention_agents: 128 | qa.load_weights(savedir) 129 | 130 | def save_weights(self, savedir: str): 131 | for qa in self._qattention_agents: 132 | qa.save_weights(savedir) 133 | -------------------------------------------------------------------------------- /helpers/clip/core/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join( 13 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" 14 | ) 15 | 16 | 17 | @lru_cache() 18 | def bytes_to_unicode(): 19 | """ 20 | Returns list of utf-8 byte and a corresponding list of unicode strings. 
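    For example, the printable byte 0x41 maps to "A" (itself), while the space
    byte 0x20 lies outside the printable ranges kept below and is remapped to
    chr(256 + 32) = "Ġ", so every byte gets a distinct, non-whitespace stand-in.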
21 | The reversible bpe codes work on unicode strings. 22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 24 | This is a signficant percentage of your normal, say, 32K bpe vocab. 25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 26 | And avoids mapping to whitespace/control characters the bpe code barfs on. 27 | """ 28 | bs = ( 29 | list(range(ord("!"), ord("~") + 1)) 30 | + list(range(ord("¡"), ord("¬") + 1)) 31 | + list(range(ord("®"), ord("ÿ") + 1)) 32 | ) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2**8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2**8 + n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 46 | Word is represented as tuple of symbols (symbols being variable-length strings). 47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r"\s+", " ", text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe()): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 73 | merges = merges[1 : 49152 - 256 - 2 + 1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v + "" for v in vocab] 77 | for merge in merges: 78 | vocab.append("".join(merge)) 79 | vocab.extend(["<|startoftext|>", "<|endoftext|>"]) 80 | self.encoder = dict(zip(vocab, range(len(vocab)))) 81 | self.decoder = {v: k for k, v in self.encoder.items()} 82 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 83 | self.cache = { 84 | "<|startoftext|>": "<|startoftext|>", 85 | "<|endoftext|>": "<|endoftext|>", 86 | } 87 | self.pat = re.compile( 88 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 89 | re.IGNORECASE, 90 | ) 91 | 92 | def bpe(self, token): 93 | if token in self.cache: 94 | return self.cache[token] 95 | word = tuple(token[:-1]) + (token[-1] + "",) 96 | pairs = get_pairs(word) 97 | 98 | if not pairs: 99 | return token + "" 100 | 101 | while True: 102 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 103 | if bigram not in self.bpe_ranks: 104 | break 105 | first, second = bigram 106 | new_word = [] 107 | i = 0 108 | while i < len(word): 109 | try: 110 | j = word.index(first, i) 111 | new_word.extend(word[i:j]) 112 | i = j 113 | except: 114 | new_word.extend(word[i:]) 115 | break 116 | 117 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 118 | new_word.append(first + second) 119 | i += 2 120 | else: 121 | new_word.append(word[i]) 122 | i += 1 123 | new_word = tuple(new_word) 124 | word = new_word 125 | if len(word) == 1: 126 | break 127 | else: 128 | pairs = get_pairs(word) 129 | word = " ".join(word) 130 | self.cache[token] = word 131 | return word 132 | 133 | def encode(self, text): 134 | bpe_tokens = [] 135 | text = 
whitespace_clean(basic_clean(text)).lower() 136 | for token in re.findall(self.pat, text): 137 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 138 | bpe_tokens.extend( 139 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") 140 | ) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = "".join([self.decoder[token] for token in tokens]) 145 | text = ( 146 | bytearray([self.byte_decoder[c] for c in text]) 147 | .decode("utf-8", errors="replace") 148 | .replace("", " ") 149 | ) 150 | return text 151 | -------------------------------------------------------------------------------- /agents/c2farm_lingunet_bc/qattention_stack_agent.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from yarr.agents.agent import Agent, ActResult, Summary 5 | 6 | import numpy as np 7 | 8 | from helpers import utils 9 | from agents.c2farm_lingunet_bc.qattention_lingunet_bc_agent import ( 10 | QAttentionLingUNetBCAgent, 11 | ) 12 | 13 | from scipy.spatial.transform import Rotation 14 | 15 | NAME = "QAttentionStackAgent" 16 | 17 | 18 | class QAttentionStackAgent(Agent): 19 | def __init__( 20 | self, 21 | qattention_agents: List[QAttentionLingUNetBCAgent], 22 | rotation_resolution: float, 23 | camera_names: List[str], 24 | rotation_prediction_depth: int = 0, 25 | ): 26 | super(QAttentionStackAgent, self).__init__() 27 | self._qattention_agents = qattention_agents 28 | self._rotation_resolution = rotation_resolution 29 | self._camera_names = camera_names 30 | self._rotation_prediction_depth = rotation_prediction_depth 31 | 32 | def build(self, training: bool, device=None) -> None: 33 | self._device = device 34 | if self._device is None: 35 | self._device = torch.device("cpu") 36 | for qa in self._qattention_agents: 37 | qa.build(training, device) 38 | 39 | def update(self, step: int, replay_sample: dict) -> dict: 40 | priorities = 0 41 | total_losses = 0.0 42 | for qa in self._qattention_agents: 43 | update_dict = qa.update(step, replay_sample) 44 | replay_sample.update(update_dict) 45 | total_losses += update_dict["total_loss"] 46 | return { 47 | "total_losses": total_losses, 48 | } 49 | 50 | def act(self, step: int, observation: dict, deterministic=False) -> ActResult: 51 | observation_elements = {} 52 | translation_results, rot_grip_results, ignore_collisions_results = [], [], [] 53 | infos = {} 54 | for depth, qagent in enumerate(self._qattention_agents): 55 | act_results = qagent.act(step, observation, deterministic) 56 | attention_coordinate = ( 57 | act_results.observation_elements["attention_coordinate"].cpu().numpy() 58 | ) 59 | observation_elements[ 60 | "attention_coordinate_layer_%d" % depth 61 | ] = attention_coordinate[0] 62 | 63 | translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action 64 | translation_results.append(translation_idxs) 65 | if rot_grip_idxs is not None: 66 | rot_grip_results.append(rot_grip_idxs) 67 | if ignore_collisions_idxs is not None: 68 | ignore_collisions_results.append(ignore_collisions_idxs) 69 | 70 | observation["attention_coordinate"] = act_results.observation_elements[ 71 | "attention_coordinate" 72 | ] 73 | observation["prev_layer_voxel_grid"] = act_results.observation_elements[ 74 | "prev_layer_voxel_grid" 75 | ] 76 | observation["prev_layer_bounds"] = act_results.observation_elements[ 77 | "prev_layer_bounds" 78 | ] 79 | 80 | for n in self._camera_names: 81 | px, py = utils.point_to_pixel_index( 82 | 
attention_coordinate[0], 83 | observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(), 84 | observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(), 85 | ) 86 | pc_t = torch.tensor( 87 | [[[py, px]]], dtype=torch.float32, device=self._device 88 | ) 89 | observation["%s_pixel_coord" % n] = pc_t 90 | observation_elements["%s_pixel_coord" % n] = [py, px] 91 | 92 | infos.update(act_results.info) 93 | 94 | rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy() 95 | ignore_collisions = float( 96 | torch.cat(ignore_collisions_results, 1)[0].cpu().numpy() 97 | ) 98 | observation_elements["trans_action_indicies"] = ( 99 | torch.cat(translation_results, 1)[0].cpu().numpy() 100 | ) 101 | observation_elements["rot_grip_action_indicies"] = rgai 102 | continuous_action = np.concatenate( 103 | [ 104 | act_results.observation_elements["attention_coordinate"] 105 | .cpu() 106 | .numpy()[0], 107 | utils.discrete_euler_to_quaternion( 108 | rgai[-4:-1], self._rotation_resolution 109 | ), 110 | rgai[-1:], 111 | [ignore_collisions], 112 | ] 113 | ) 114 | return ActResult( 115 | continuous_action, observation_elements=observation_elements, info=infos 116 | ) 117 | 118 | def update_summaries(self) -> List[Summary]: 119 | summaries = [] 120 | for qa in self._qattention_agents: 121 | summaries.extend(qa.update_summaries()) 122 | return summaries 123 | 124 | def act_summaries(self) -> List[Summary]: 125 | s = [] 126 | for qa in self._qattention_agents: 127 | s.extend(qa.act_summaries()) 128 | return s 129 | 130 | def load_weights(self, savedir: str): 131 | for qa in self._qattention_agents: 132 | qa.load_weights(savedir) 133 | 134 | def save_weights(self, savedir: str): 135 | for qa in self._qattention_agents: 136 | qa.save_weights(savedir) 137 | -------------------------------------------------------------------------------- /agents/rvt/launch_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | import torch 4 | import numpy as np 5 | 6 | from omegaconf import DictConfig 7 | 8 | from yarr.agents.agent import Agent 9 | from yarr.agents.agent import ActResult 10 | from yarr.agents.agent import Summary 11 | from yarr.agents.agent import ScalarSummary 12 | 13 | 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | 16 | from helpers.preprocess_agent import PreprocessAgent 17 | 18 | 19 | from rvt.mvt.mvt import MVT 20 | from rvt.models import rvt_agent 21 | from rvt.utils.peract_utils import ( 22 | CAMERAS, 23 | SCENE_BOUNDS, 24 | IMAGE_SIZE, 25 | DATA_FOLDER, 26 | ) 27 | 28 | 29 | import rvt.config as exp_cfg_mod 30 | import rvt.models.rvt_agent as rvt_agent 31 | import rvt.mvt.config as mvt_cfg_mod 32 | 33 | 34 | def create_agent(cfg: DictConfig): 35 | exp_cfg = exp_cfg_mod.get_cfg_defaults() 36 | exp_cfg.bs = cfg.replay.batch_size 37 | exp_cfg.tasks = ",".join(cfg.rlbench.tasks) 38 | 39 | exp_cfg.freeze() 40 | 41 | mvt_cfg = mvt_cfg_mod.get_cfg_defaults() 42 | mvt_cfg.proprio_dim = cfg.method.low_dim_size 43 | mvt_cfg.freeze() 44 | 45 | agent = RVTAgentWrapper( 46 | cfg.framework.checkpoint_name_prefix, cfg.rlbench, mvt_cfg, exp_cfg 47 | ) 48 | 49 | preprocess_agent = PreprocessAgent(pose_agent=agent) 50 | return preprocess_agent 51 | 52 | 53 | class RVTAgentWrapper(Agent): 54 | def __init__(self, checkpoint_name_prefix, rlbench_cfg, mvt_cfg, exp_cfg): 55 | self._checkpoint_filename = f"{checkpoint_name_prefix}.pt" 56 | self.rvt_agent = None 57 | self.rlbench_cfg = rlbench_cfg 58 | self.mvt_cfg = mvt_cfg 59 | 
self.exp_cfg = exp_cfg 60 | self._summaries = {} 61 | 62 | def build(self, training: bool, device=None) -> None: 63 | import torch 64 | 65 | torch.cuda.set_device(device) 66 | torch.cuda.empty_cache() 67 | 68 | if isinstance(device, int): 69 | device = f"cuda:{device}" 70 | 71 | rvt = MVT( 72 | renderer_device=device, 73 | **self.mvt_cfg, 74 | ) 75 | rvt = rvt.to(device) 76 | 77 | if training: 78 | rvt = DDP(rvt, device_ids=[device]) 79 | 80 | self.rvt_agent = rvt_agent.RVTAgent( 81 | network=rvt, 82 | # image_resolution=self.rlbench_cfg.camera_resolution, 83 | add_lang=self.mvt_cfg.add_lang, 84 | scene_bounds=self.rlbench_cfg.scene_bounds, 85 | cameras=self.rlbench_cfg.cameras, 86 | log_dir="/tmp/eval_run", 87 | **self.exp_cfg.peract, 88 | **self.exp_cfg.rvt, 89 | ) 90 | 91 | self.rvt_agent.build(training, device) 92 | 93 | def update(self, step: int, replay_sample: dict) -> dict: 94 | for k, v in replay_sample.items(): 95 | replay_sample[k] = v.unsqueeze(1) 96 | # RVT is based on the PerAct's Colab version. 97 | replay_sample["lang_goal_embs"] = replay_sample["lang_token_embs"] 98 | replay_sample["tasks"] = self.exp_cfg.tasks.split(",") 99 | 100 | update_dict = self.rvt_agent.update(step, replay_sample) 101 | 102 | for key, val in self.rvt_agent.loss_log.items(): 103 | self._summaries[key] = np.mean(np.array(val)) 104 | 105 | return { 106 | "total_losses": update_dict["total_loss"], 107 | } 108 | 109 | return result 110 | 111 | def act(self, step: int, observation: dict, deterministic: bool) -> ActResult: 112 | return self.rvt_agent.act(step, observation, deterministic) 113 | 114 | def reset(self) -> None: 115 | self.rvt_agent.reset() 116 | 117 | def update_summaries(self) -> List[Summary]: 118 | summaries = [] 119 | for k, v in self._summaries.items(): 120 | summaries.append(ScalarSummary(f"RVT/{k}", v)) 121 | return summaries 122 | 123 | def act_summaries(self) -> List[Summary]: 124 | return [] 125 | 126 | def load_weights(self, savedir: str) -> None: 127 | """ 128 | copied from RVT 129 | """ 130 | device = torch.device("cuda:0") 131 | weight_file = os.path.join(savedir, self._checkpoint_filename) 132 | state_dict = torch.load(weight_file, map_location=device) 133 | 134 | model = self.rvt_agent._network 135 | optimizer = self.rvt_agent._optimizer 136 | lr_sched = self.rvt_agent._lr_sched 137 | 138 | if isinstance(model, DDP): 139 | model = model.module 140 | 141 | model.load_state_dict(state_dict["model_state"]) 142 | optimizer.load_state_dict(state_dict["optimizer_state"]) 143 | lr_sched.load_state_dict(state_dict["lr_sched_state"]) 144 | 145 | return self.rvt_agent.load_clip() 146 | 147 | def save_weights(self, savedir: str) -> None: 148 | os.makedirs(savedir, exist_ok=True) 149 | 150 | weight_file = os.path.join(savedir, self._checkpoint_filename) 151 | 152 | model = self.rvt_agent._network 153 | optimizer = self.rvt_agent._optimizer 154 | lr_sched = self.rvt_agent._lr_sched 155 | 156 | if isinstance(model, DDP): 157 | model = model.module 158 | 159 | model_state = model.state_dict() 160 | 161 | torch.save( 162 | { 163 | "model_state": model_state, 164 | "optimizer_state": optimizer.state_dict(), 165 | "lr_sched_state": lr_sched.state_dict(), 166 | }, 167 | weight_file, 168 | ) 169 | -------------------------------------------------------------------------------- /helpers/preprocess_agent.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torchvision.transforms as transforms 5 | 6 | from 
yarr.agents.agent import ( 7 | Agent, 8 | Summary, 9 | ActResult, 10 | ScalarSummary, 11 | HistogramSummary, 12 | ImageSummary, 13 | ) 14 | 15 | 16 | class PreprocessAgent(Agent): 17 | def __init__( 18 | self, pose_agent: Agent, norm_rgb: bool = True, norm_type: str = "zero_mean" 19 | ): 20 | self._pose_agent = pose_agent 21 | self._norm_rgb = norm_rgb 22 | self._norm_type = norm_type 23 | 24 | def build(self, training: bool, device: torch.device = None): 25 | self._pose_agent.build(training, device) 26 | 27 | def _norm_rgb_(self, x): 28 | if self._norm_type == "zero_mean": 29 | return (x.float() / 255.0) * 2.0 - 1.0 30 | elif self._norm_type == "imagenet": 31 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 32 | # std=[0.229, 0.224, 0.225]) 33 | # return normalize(x) 34 | return x.float() / 255.0 35 | else: 36 | raise NotImplementedError 37 | 38 | def update(self, step: int, replay_sample: dict) -> dict: 39 | # Samples are (B, N, ...) where N is number of buffers/tasks. This is a single task setup, so 0 index. 40 | replay_sample = { 41 | k: v[:, 0] if len(v.shape) > 2 and v.shape[1] == 1 else v 42 | for k, v in replay_sample.items() 43 | } 44 | for k, v in replay_sample.items(): 45 | if self._norm_rgb and "rgb" in k: 46 | replay_sample[k] = self._norm_rgb_(v) 47 | else: 48 | replay_sample[k] = v.float() 49 | self._replay_sample = replay_sample 50 | return self._pose_agent.update(step, replay_sample) 51 | 52 | def act(self, step: int, observation: dict, deterministic=False) -> ActResult: 53 | # observation = {k: torch.tensor(v) for k, v in observation.items()} 54 | for k, v in observation.items(): 55 | if self._norm_rgb and "rgb" in k: 56 | observation[k] = self._norm_rgb_(v) 57 | else: 58 | observation[k] = v.float() 59 | act_res = self._pose_agent.act(step, observation, deterministic) 60 | act_res.replay_elements.update({"demo": False}) 61 | return act_res 62 | 63 | def update_summaries(self) -> List[Summary]: 64 | prefix = "inputs" 65 | demo_f = self._replay_sample["demo"].float() 66 | demo_proportion = demo_f.mean() 67 | tile = lambda x: torch.squeeze(torch.cat(x.split(1, dim=1), dim=-1), dim=1) 68 | sums = [ 69 | ScalarSummary("%s/demo_proportion" % prefix, demo_proportion), 70 | ScalarSummary( 71 | "%s/timeouts" % prefix, self._replay_sample["timeout"].float().mean() 72 | ), 73 | ] 74 | 75 | for robot_prefix in ["", "right_", "left_"]: 76 | if not f"{robot_prefix}low_dim_state" in self._replay_sample.keys(): 77 | continue 78 | 79 | sums.extend( 80 | [ 81 | HistogramSummary( 82 | f"{prefix}/{robot_prefix}low_dim_state", 83 | self._replay_sample[f"{robot_prefix}low_dim_state"], 84 | ), 85 | HistogramSummary( 86 | f"{prefix}/{robot_prefix}low_dim_state_tp1", 87 | self._replay_sample[f"{robot_prefix}low_dim_state_tp1"], 88 | ), 89 | ScalarSummary( 90 | f"{prefix}/{robot_prefix}low_dim_state_mean", 91 | self._replay_sample[f"{robot_prefix}low_dim_state"].mean(), 92 | ), 93 | ScalarSummary( 94 | f"{prefix}/{robot_prefix}low_dim_state_min", 95 | self._replay_sample[f"{robot_prefix}low_dim_state"].min(), 96 | ), 97 | ScalarSummary( 98 | f"{prefix}/{robot_prefix}low_dim_state_max", 99 | self._replay_sample[f"{robot_prefix}low_dim_state"].max(), 100 | ), 101 | ] 102 | ) 103 | 104 | for k, v in self._replay_sample.items(): 105 | if "rgb" in k or "point_cloud" in k: 106 | if "rgb" in k: 107 | # Convert back to 0 - 1 108 | v = (v + 1.0) / 2.0 109 | sums.append( 110 | ImageSummary( 111 | "%s/%s" % (prefix, k), tile(v) if len(v.shape) > 4 else v 112 | ) 113 | ) 114 | 115 | if 
"sampling_probabilities" in self._replay_sample: 116 | sums.extend( 117 | [ 118 | HistogramSummary( 119 | "replay/priority", self._replay_sample["sampling_probabilities"] 120 | ), 121 | ] 122 | ) 123 | sums.extend(self._pose_agent.update_summaries()) 124 | return sums 125 | 126 | def act_summaries(self) -> List[Summary]: 127 | return self._pose_agent.act_summaries() 128 | 129 | def load_weights(self, savedir: str): 130 | self._pose_agent.load_weights(savedir) 131 | 132 | def save_weights(self, savedir: str): 133 | self._pose_agent.save_weights(savedir) 134 | 135 | def reset(self) -> None: 136 | self._pose_agent.reset() 137 | -------------------------------------------------------------------------------- /helpers/demo_loading_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | import numpy as np 5 | from rlbench.demo import Demo 6 | import omegaconf 7 | 8 | 9 | def _is_stopped(demo, i, obs, delta=0.1): 10 | next_is_not_final = i == (len(demo) - 2) 11 | gripper_state_no_change = i < (len(demo) - 2) and ( 12 | obs.gripper_open == demo[i + 1].gripper_open 13 | and obs.gripper_open == demo[i - 1].gripper_open 14 | and demo[i - 2].gripper_open == demo[i - 1].gripper_open 15 | ) 16 | small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) 17 | return small_delta and (not next_is_not_final) and gripper_state_no_change 18 | 19 | 20 | def _is_stopped_right(demo, i, obs, delta=0.1): 21 | next_is_not_final = i == (len(demo) - 2) 22 | gripper_state_no_change = i < (len(demo) - 2) and ( 23 | obs.gripper_open == demo[i + 1].right.gripper_open 24 | and obs.gripper_open == demo[i - 1].right.gripper_open 25 | and demo[i - 2].right.gripper_open == demo[i - 1].right.gripper_open 26 | ) 27 | small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) 28 | return small_delta and (not next_is_not_final) and gripper_state_no_change 29 | 30 | 31 | def _is_stopped_left(demo, i, obs, delta=0.1): 32 | next_is_not_final = i == (len(demo) - 2) 33 | gripper_state_no_change = i < (len(demo) - 2) and ( 34 | obs.gripper_open == demo[i + 1].left.gripper_open 35 | and obs.gripper_open == demo[i - 1].left.gripper_open 36 | and demo[i - 2].left.gripper_open == demo[i - 1].left.gripper_open 37 | ) 38 | small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) 39 | return small_delta and (not next_is_not_final) and gripper_state_no_change 40 | 41 | 42 | def _keypoint_discovery_bimanual(demo: Demo, stopping_delta=0.1) -> List[int]: 43 | episode_keypoints = [] 44 | right_prev_gripper_open = demo[0].right.gripper_open 45 | left_prev_gripper_open = demo[0].left.gripper_open 46 | stopped_buffer = 0 47 | for i, obs in enumerate(demo._observations): 48 | right_stopped = _is_stopped_right(demo, i, obs.right, stopping_delta) 49 | left_stopped = _is_stopped_left(demo, i, obs.left, stopping_delta) 50 | stopped = (stopped_buffer <= 0) and right_stopped and left_stopped 51 | stopped_buffer = 4 if stopped else stopped_buffer - 1 52 | # if change in gripper, or end of episode. 
53 | last = i == (len(demo) - 1) 54 | right_state_changed = obs.right.gripper_open != right_prev_gripper_open 55 | left_state_changed = obs.left.gripper_open != left_prev_gripper_open 56 | state_changed = right_state_changed or left_state_changed 57 | if i != 0 and (state_changed or last or stopped): 58 | episode_keypoints.append(i) 59 | 60 | right_prev_gripper_open = obs.right.gripper_open 61 | left_prev_gripper_open = obs.left.gripper_open 62 | if ( 63 | len(episode_keypoints) > 1 64 | and (episode_keypoints[-1] - 1) == episode_keypoints[-2] 65 | ): 66 | episode_keypoints.pop(-2) 67 | print("Found %d keypoints." % len(episode_keypoints), episode_keypoints) 68 | return episode_keypoints 69 | 70 | 71 | def _keypoint_discovery_unimanual(demo: Demo, stopping_delta=0.1) -> List[int]: 72 | episode_keypoints = [] 73 | prev_gripper_open = demo[0].gripper_open 74 | stopped_buffer = 0 75 | for i, obs in enumerate(demo): 76 | stopped = _is_stopped(demo, i, obs, stopping_delta) 77 | stopped = (stopped_buffer <= 0) and stopped 78 | stopped_buffer = 4 if stopped else stopped_buffer - 1 79 | # if change in gripper, or end of episode. 80 | last = i == (len(demo) - 1) 81 | if i != 0 and (obs.gripper_open != prev_gripper_open or last or stopped): 82 | episode_keypoints.append(i) 83 | prev_gripper_open = obs.gripper_open 84 | if ( 85 | len(episode_keypoints) > 1 86 | and (episode_keypoints[-1] - 1) == episode_keypoints[-2] 87 | ): 88 | episode_keypoints.pop(-2) 89 | print("Found %d keypoints." % len(episode_keypoints), episode_keypoints) 90 | return episode_keypoints 91 | 92 | 93 | def _keypoint_discovery_heuristic(demo: Demo, stopping_delta=0.1) -> List[int]: 94 | if demo[0].is_bimanual: 95 | return _keypoint_discovery_bimanual(demo, stopping_delta) 96 | else: 97 | return _keypoint_discovery_unimanual(demo, stopping_delta) 98 | 99 | 100 | def keypoint_discovery(demo: Demo, stopping_delta=0.1, method="heuristic") -> List[int]: 101 | episode_keypoints = [] 102 | if method == "heuristic": 103 | return _keypoint_discovery_heuristic(demo, stopping_delta) 104 | 105 | elif method == "random": 106 | # Randomly select keypoints. 107 | episode_keypoints = np.random.choice(range(len(demo)), size=20, replace=False) 108 | episode_keypoints.sort() 109 | return episode_keypoints 110 | 111 | elif method == "fixed_interval": 112 | # Fixed interval. 113 | episode_keypoints = [] 114 | segment_length = len(demo) // 20 115 | for i in range(0, len(demo), segment_length): 116 | episode_keypoints.append(i) 117 | return episode_keypoints 118 | elif isinstance(method, omegaconf.listconfig.ListConfig): 119 | return list(method) 120 | else: 121 | raise NotImplementedError 122 | 123 | 124 | # find minimum difference between any two elements in list 125 | def find_minimum_difference(lst): 126 | minimum = lst[-1] 127 | for i in range(1, len(lst)): 128 | if lst[i] - lst[i - 1] < minimum: 129 | minimum = lst[i] - lst[i - 1] 130 | return minimum 131 | -------------------------------------------------------------------------------- /helpers/optim/lamb.py: -------------------------------------------------------------------------------- 1 | """Lamb optimizer.""" 2 | 3 | # LAMB optimizer used as is. 
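# [Editor's note] Per-parameter update implemented in step() below, written out for
# reference (standard LAMB; the Adam moments are used without bias correction, as
# the in-code comment notes):
#   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
#   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
#   r_t = m_t / (sqrt(v_t) + eps) + weight_decay * p
#   trust_ratio = clamp(||p||, 0, 10) / ||r_t||   (1 if either norm is 0, or if adam=True)
#   p <- p - lr * trust_ratio * r_t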
4 | # Source: https://github.com/cybertronai/pytorch-lamb 5 | # License: https://github.com/cybertronai/pytorch-lamb/blob/master/LICENSE 6 | 7 | import collections 8 | import math 9 | 10 | import torch 11 | from torch.optim import Optimizer 12 | 13 | 14 | # def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 15 | # """Log a histogram of trust ratio scalars in across layers.""" 16 | # results = collections.defaultdict(list) 17 | # for group in optimizer.param_groups: 18 | # for p in group['params']: 19 | # state = optimizer.state[p] 20 | # for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 21 | # if i in state: 22 | # results[i].append(state[i]) 23 | # 24 | # for k, v in results.items(): 25 | # event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 26 | 27 | 28 | class Lamb(Optimizer): 29 | r"""Implements Lamb algorithm. 30 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 31 | Arguments: 32 | params (iterable): iterable of parameters to optimize or dicts defining 33 | parameter groups 34 | lr (float, optional): learning rate (default: 1e-3) 35 | betas (Tuple[float, float], optional): coefficients used for computing 36 | running averages of gradient and its square (default: (0.9, 0.999)) 37 | eps (float, optional): term added to the denominator to improve 38 | numerical stability (default: 1e-8) 39 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 40 | adam (bool, optional): always use trust ratio = 1, which turns this into 41 | Adam. Useful for comparison purposes. 42 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 43 | https://arxiv.org/abs/1904.00962 44 | """ 45 | 46 | def __init__( 47 | self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False 48 | ): 49 | if not 0.0 <= lr: 50 | raise ValueError("Invalid learning rate: {}".format(lr)) 51 | if not 0.0 <= eps: 52 | raise ValueError("Invalid epsilon value: {}".format(eps)) 53 | if not 0.0 <= betas[0] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 55 | if not 0.0 <= betas[1] < 1.0: 56 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 57 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 58 | self.adam = adam 59 | super(Lamb, self).__init__(params, defaults) 60 | 61 | def step(self, closure=None): 62 | """Performs a single optimization step. 63 | Arguments: 64 | closure (callable, optional): A closure that reevaluates the model 65 | and returns the loss. 66 | """ 67 | loss = None 68 | if closure is not None: 69 | loss = closure() 70 | 71 | for group in self.param_groups: 72 | for p in group["params"]: 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | if grad.is_sparse: 77 | raise RuntimeError( 78 | "Lamb does not support sparse gradients, consider SparseAdam instad." 
79 | ) 80 | 81 | state = self.state[p] 82 | 83 | # State initialization 84 | if len(state) == 0: 85 | state["step"] = 0 86 | # Exponential moving average of gradient values 87 | state["exp_avg"] = torch.zeros_like(p.data) 88 | # Exponential moving average of squared gradient values 89 | state["exp_avg_sq"] = torch.zeros_like(p.data) 90 | 91 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 92 | beta1, beta2 = group["betas"] 93 | 94 | state["step"] += 1 95 | 96 | # Decay the first and second moment running average coefficient 97 | # m_t 98 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 99 | # v_t 100 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 101 | 102 | # Paper v3 does not use debiasing. 103 | # bias_correction1 = 1 - beta1 ** state['step'] 104 | # bias_correction2 = 1 - beta2 ** state['step'] 105 | # Apply bias to lr to avoid broadcast. 106 | step_size = group[ 107 | "lr" 108 | ] # * math.sqrt(bias_correction2) / bias_correction1 109 | 110 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 111 | 112 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) 113 | if group["weight_decay"] != 0: 114 | adam_step.add_(p.data, alpha=group["weight_decay"]) 115 | 116 | adam_norm = adam_step.pow(2).sum().sqrt() 117 | if weight_norm == 0 or adam_norm == 0: 118 | trust_ratio = 1 119 | else: 120 | trust_ratio = weight_norm / adam_norm 121 | state["weight_norm"] = weight_norm 122 | state["adam_norm"] = adam_norm 123 | state["trust_ratio"] = trust_ratio 124 | if self.adam: 125 | trust_ratio = 1 126 | 127 | p.data.add_(adam_step, alpha=-step_size * trust_ratio) 128 | 129 | return loss 130 | -------------------------------------------------------------------------------- /agents/baselines/bc_lang/bc_lang_agent.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | from typing import List 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from yarr.agents.agent import Agent, Summary, ActResult, ScalarSummary, HistogramSummary 10 | 11 | from helpers import utils 12 | from helpers.utils import stack_on_channel 13 | 14 | from helpers.clip.core.clip import build_model, load_clip 15 | 16 | NAME = "BCLangAgent" 17 | REPLAY_ALPHA = 0.7 18 | REPLAY_BETA = 1.0 19 | 20 | 21 | class Actor(nn.Module): 22 | def __init__(self, actor_network: nn.Module): 23 | super(Actor, self).__init__() 24 | self._actor_network = copy.deepcopy(actor_network) 25 | self._actor_network.build() 26 | 27 | def forward(self, observations, robot_state, lang_goal_emb): 28 | mu = self._actor_network(observations, robot_state, lang_goal_emb) 29 | return mu 30 | 31 | 32 | class BCLangAgent(Agent): 33 | def __init__( 34 | self, 35 | actor_network: nn.Module, 36 | camera_name: str, 37 | lr: float = 0.01, 38 | weight_decay: float = 1e-5, 39 | grad_clip: float = 20.0, 40 | ): 41 | self._camera_name = camera_name 42 | self._actor_network = actor_network 43 | self._lr = lr 44 | self._weight_decay = weight_decay 45 | self._grad_clip = grad_clip 46 | 47 | def build(self, training: bool, device: torch.device = None): 48 | if device is None: 49 | device = torch.device("cpu") 50 | self._actor = Actor(self._actor_network).to(device).train(training) 51 | if training: 52 | self._actor_optimizer = torch.optim.Adam( 53 | self._actor.parameters(), lr=self._lr, weight_decay=self._weight_decay 54 | ) 55 | logging.info( 56 | "# Actor Params: %d" 57 | % sum(p.numel() for p in self._actor.parameters() if 
p.requires_grad) 58 | ) 59 | else: 60 | for p in self._actor.parameters(): 61 | p.requires_grad = False 62 | 63 | model, _ = load_clip("RN50", jit=False) 64 | self._clip_rn50 = build_model(model.state_dict()) 65 | self._clip_rn50 = self._clip_rn50.float().to(device) 66 | self._clip_rn50.eval() 67 | del model 68 | 69 | self._device = device 70 | 71 | def _grad_step(self, loss, opt, model_params=None, clip=None): 72 | opt.zero_grad() 73 | loss.backward() 74 | if clip is not None and model_params is not None: 75 | nn.utils.clip_grad_value_(model_params, clip) 76 | opt.step() 77 | 78 | def update(self, step: int, replay_sample: dict) -> dict: 79 | lang_goal_emb = replay_sample["lang_goal_emb"] 80 | robot_state = replay_sample["low_dim_state"] 81 | observations = [ 82 | replay_sample["%s_rgb" % self._camera_name], 83 | replay_sample["%s_point_cloud" % self._camera_name], 84 | ] 85 | mu = self._actor(observations, robot_state, lang_goal_emb) 86 | loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA) 87 | delta = F.mse_loss(mu, replay_sample["action"], reduction="none").mean(1) 88 | loss = (delta * loss_weights).mean() 89 | self._grad_step( 90 | loss, self._actor_optimizer, self._actor.parameters(), self._grad_clip 91 | ) 92 | self._summaries = { 93 | "pi/loss": loss, 94 | "pi/mu": mu.mean(), 95 | } 96 | return {"total_losses": loss} 97 | 98 | def _normalize_quat(self, x): 99 | return x / x.square().sum(dim=1).sqrt().unsqueeze(-1) 100 | 101 | def act(self, step: int, observation: dict, deterministic=False) -> ActResult: 102 | lang_goal_tokens = observation.get("lang_goal_tokens", None).long() 103 | 104 | with torch.no_grad(): 105 | lang_goal_tokens = lang_goal_tokens.to(device=self._device) 106 | lang_goal_emb, _ = self._clip_rn50.encode_text_with_embeddings( 107 | lang_goal_tokens[0] 108 | ) 109 | lang_goal_emb = lang_goal_emb.to(device=self._device) 110 | 111 | observations = [ 112 | observation["%s_rgb" % self._camera_name][0].to(self._device), 113 | observation["%s_point_cloud" % self._camera_name][0].to(self._device), 114 | ] 115 | robot_state = observation["low_dim_state"][0].to(self._device) 116 | 117 | mu = self._actor(observations, robot_state, lang_goal_emb) 118 | mu = torch.cat([mu[:, :3], self._normalize_quat(mu[:, 3:7]), mu[:, 7:]], dim=-1) 119 | ignore_collisions = torch.Tensor([1.0]).to(mu.device) 120 | mu0 = torch.cat([mu[0], ignore_collisions]) 121 | return ActResult(mu0.detach().cpu()) 122 | 123 | def update_summaries(self) -> List[Summary]: 124 | summaries = [] 125 | for n, v in self._summaries.items(): 126 | summaries.append(ScalarSummary("%s/%s" % (NAME, n), v)) 127 | 128 | for tag, param in self._actor.named_parameters(): 129 | summaries.append( 130 | HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad) 131 | ) 132 | summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data)) 133 | 134 | return summaries 135 | 136 | def act_summaries(self) -> List[Summary]: 137 | return [] 138 | 139 | def load_weights(self, savedir: str): 140 | self._actor.load_state_dict( 141 | torch.load( 142 | os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu") 143 | ) 144 | ) 145 | print("Loaded weights from %s" % savedir) 146 | 147 | def save_weights(self, savedir: str): 148 | torch.save(self._actor.state_dict(), os.path.join(savedir, "bc_actor.pt")) 149 | -------------------------------------------------------------------------------- /agents/baselines/vit_bc_lang/vit_bc_lang_agent.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | from typing import List 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from yarr.agents.agent import Agent, Summary, ActResult, ScalarSummary, HistogramSummary 10 | 11 | from helpers import utils 12 | from helpers.utils import stack_on_channel 13 | 14 | from helpers.clip.core.clip import build_model, load_clip 15 | 16 | NAME = "ViTBCLangAgent" 17 | REPLAY_ALPHA = 0.7 18 | REPLAY_BETA = 1.0 19 | 20 | 21 | class Actor(nn.Module): 22 | def __init__(self, actor_network: nn.Module): 23 | super(Actor, self).__init__() 24 | self._actor_network = copy.deepcopy(actor_network) 25 | self._actor_network.build() 26 | 27 | def forward(self, observations, robot_state, lang_goal_emb): 28 | mu = self._actor_network(observations, robot_state, lang_goal_emb) 29 | return mu 30 | 31 | 32 | class ViTBCLangAgent(Agent): 33 | def __init__( 34 | self, 35 | actor_network: nn.Module, 36 | camera_name: str, 37 | lr: float = 0.01, 38 | weight_decay: float = 1e-5, 39 | grad_clip: float = 20.0, 40 | ): 41 | self._camera_name = camera_name 42 | self._actor_network = actor_network 43 | self._lr = lr 44 | self._weight_decay = weight_decay 45 | self._grad_clip = grad_clip 46 | 47 | def build(self, training: bool, device: torch.device = None): 48 | if device is None: 49 | device = torch.device("cpu") 50 | self._actor = Actor(self._actor_network).to(device).train(training) 51 | if training: 52 | self._actor_optimizer = torch.optim.Adam( 53 | self._actor.parameters(), lr=self._lr, weight_decay=self._weight_decay 54 | ) 55 | logging.info( 56 | "# Actor Params: %d" 57 | % sum(p.numel() for p in self._actor.parameters() if p.requires_grad) 58 | ) 59 | else: 60 | for p in self._actor.parameters(): 61 | p.requires_grad = False 62 | 63 | model, _ = load_clip("RN50", jit=False) 64 | self._clip_rn50 = build_model(model.state_dict()) 65 | self._clip_rn50 = self._clip_rn50.float().to(device) 66 | self._clip_rn50.eval() 67 | del model 68 | 69 | self._device = device 70 | 71 | def _grad_step(self, loss, opt, model_params=None, clip=None): 72 | opt.zero_grad() 73 | loss.backward() 74 | if clip is not None and model_params is not None: 75 | nn.utils.clip_grad_value_(model_params, clip) 76 | opt.step() 77 | 78 | def update(self, step: int, replay_sample: dict) -> dict: 79 | lang_goal_emb = replay_sample["lang_goal_emb"] 80 | robot_state = replay_sample["low_dim_state"] 81 | observations = [ 82 | replay_sample["%s_rgb" % self._camera_name], 83 | replay_sample["%s_point_cloud" % self._camera_name], 84 | ] 85 | mu = self._actor(observations, robot_state, lang_goal_emb) 86 | loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA) 87 | delta = F.mse_loss(mu, replay_sample["action"], reduction="none").mean(1) 88 | loss = (delta * loss_weights).mean() 89 | self._grad_step( 90 | loss, self._actor_optimizer, self._actor.parameters(), self._grad_clip 91 | ) 92 | self._summaries = { 93 | "pi/loss": loss, 94 | "pi/mu": mu.mean(), 95 | } 96 | return {"total_losses": loss} 97 | 98 | def _normalize_quat(self, x): 99 | return x / x.square().sum(dim=1).sqrt().unsqueeze(-1) 100 | 101 | def act(self, step: int, observation: dict, deterministic=False) -> ActResult: 102 | lang_goal_tokens = observation.get("lang_goal_tokens", None).long() 103 | 104 | with torch.no_grad(): 105 | lang_goal_tokens = lang_goal_tokens.to(device=self._device) 106 | lang_goal_emb, _ = 
self._clip_rn50.encode_text_with_embeddings( 107 | lang_goal_tokens[0] 108 | ) 109 | lang_goal_emb = lang_goal_emb.to(device=self._device) 110 | 111 | observations = [ 112 | observation["%s_rgb" % self._camera_name][0].to(self._device), 113 | observation["%s_point_cloud" % self._camera_name][0].to(self._device), 114 | ] 115 | robot_state = observation["low_dim_state"][0].to(self._device) 116 | 117 | mu = self._actor(observations, robot_state, lang_goal_emb) 118 | mu = torch.cat([mu[:, :3], self._normalize_quat(mu[:, 3:7]), mu[:, 7:]], dim=-1) 119 | ignore_collisions = torch.Tensor([1.0]).to(mu.device) 120 | mu0 = torch.cat([mu[0], ignore_collisions]) 121 | return ActResult(mu0.detach().cpu()) 122 | 123 | def update_summaries(self) -> List[Summary]: 124 | summaries = [] 125 | for n, v in self._summaries.items(): 126 | summaries.append(ScalarSummary("%s/%s" % (NAME, n), v)) 127 | 128 | for tag, param in self._actor.named_parameters(): 129 | summaries.append( 130 | HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad) 131 | ) 132 | summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data)) 133 | 134 | return summaries 135 | 136 | def act_summaries(self) -> List[Summary]: 137 | return [] 138 | 139 | def load_weights(self, savedir: str): 140 | self._actor.load_state_dict( 141 | torch.load( 142 | os.path.join(savedir, "bc_actor.pt"), map_location=torch.device("cpu") 143 | ) 144 | ) 145 | print("Loaded weights from %s" % savedir) 146 | 147 | def save_weights(self, savedir: str): 148 | torch.save(self._actor.state_dict(), os.path.join(savedir, "bc_actor.pt")) 149 | -------------------------------------------------------------------------------- /helpers/clip/core/transport_image_goal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cliport.models as models 3 | from cliport.utils import utils 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class TransportImageGoal(nn.Module): 11 | """Transport module.""" 12 | 13 | def __init__( 14 | self, stream_fcn, in_shape, n_rotations, crop_size, preprocess, cfg, device 15 | ): 16 | """Transport module for placing. 17 | Args: 18 | in_shape: shape of input image. 19 | n_rotations: number of rotations of convolving kernel. 20 | crop_size: crop size around pick argmax used as convolving kernel. 21 | preprocess: function to preprocess input images. 22 | """ 23 | super().__init__() 24 | 25 | self.iters = 0 26 | self.stream_fcn = stream_fcn 27 | self.n_rotations = n_rotations 28 | self.crop_size = crop_size # crop size must be N*16 (e.g. 96) 29 | self.preprocess = preprocess 30 | self.cfg = cfg 31 | self.device = device 32 | self.batchnorm = self.cfg["train"]["batchnorm"] 33 | 34 | self.pad_size = int(self.crop_size / 2) 35 | self.padding = np.zeros((3, 2), dtype=int) 36 | self.padding[:2, :] = self.pad_size 37 | 38 | in_shape = np.array(in_shape) 39 | in_shape = tuple(in_shape) 40 | self.in_shape = in_shape 41 | 42 | # Crop before network (default for Transporters CoRL 2020). 
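# Sketch of the flow in forward() below: the padded scene image is rotated
# n_rotations times about the pick pixel p, cropped to a crop_size x crop_size
# window, and encoded by query_resnet; the crop features (fused with goal-crop
# features) act as the convolving kernel that correlate() slides over the
# goal-fused key_resnet features of the full image to score placements.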
43 | self.kernel_shape = (self.crop_size, self.crop_size, self.in_shape[2]) 44 | 45 | if not hasattr(self, "output_dim"): 46 | self.output_dim = 3 47 | if not hasattr(self, "kernel_dim"): 48 | self.kernel_dim = 3 49 | 50 | self.rotator = utils.ImageRotator(self.n_rotations) 51 | 52 | self._build_nets() 53 | 54 | def _build_nets(self): 55 | stream_one_fcn, _ = self.stream_fcn 56 | model = models.names[stream_one_fcn] 57 | self.key_resnet = model(self.in_shape, self.output_dim, self.cfg, self.device) 58 | self.query_resnet = model(self.in_shape, self.kernel_dim, self.cfg, self.device) 59 | self.goal_resnet = model(self.in_shape, self.output_dim, self.cfg, self.device) 60 | print(f"Transport FCN: {stream_one_fcn}") 61 | 62 | def correlate(self, in0, in1, softmax): 63 | """Correlate two input tensors.""" 64 | output = F.conv2d(in0, in1, padding=(self.pad_size, self.pad_size)) 65 | output = F.interpolate( 66 | output, size=(in0.shape[-2], in0.shape[-1]), mode="bilinear" 67 | ) 68 | output = output[ 69 | :, :, self.pad_size : -self.pad_size, self.pad_size : -self.pad_size 70 | ] 71 | if softmax: 72 | output_shape = output.shape 73 | output = output.reshape((1, np.prod(output.shape))) 74 | output = F.softmax(output, dim=-1) 75 | output = output.reshape(output_shape[1:]) 76 | return output 77 | 78 | def forward(self, inp_img, goal_img, p, softmax=True): 79 | """Forward pass.""" 80 | 81 | # Input image. 82 | img_unprocessed = np.pad(inp_img, self.padding, mode="constant") 83 | input_data = img_unprocessed 84 | in_shape = (1,) + input_data.shape 85 | input_data = input_data.reshape(in_shape) 86 | in_tensor = torch.from_numpy(input_data.copy()).to( 87 | dtype=torch.float, device=self.device 88 | ) 89 | in_tensor = in_tensor.permute(0, 3, 1, 2) 90 | 91 | # Goal image. 92 | goal_tensor = np.pad(goal_img, self.padding, mode="constant") 93 | goal_shape = (1,) + goal_tensor.shape 94 | goal_tensor = goal_tensor.reshape(goal_shape) 95 | goal_tensor = torch.from_numpy(goal_tensor.copy()).to( 96 | dtype=torch.float, device=self.device 97 | ) 98 | goal_tensor = goal_tensor.permute(0, 3, 1, 2) 99 | 100 | # Rotation pivot. 101 | pv = np.array([p[0], p[1]]) + self.pad_size 102 | hcrop = self.pad_size 103 | 104 | # Cropped input features. 105 | in_crop = in_tensor.repeat(self.n_rotations, 1, 1, 1) 106 | in_crop = self.rotator(in_crop, pivot=pv) 107 | in_crop = torch.cat(in_crop, dim=0) 108 | in_crop = in_crop[ 109 | :, :, pv[0] - hcrop : pv[0] + hcrop, pv[1] - hcrop : pv[1] + hcrop 110 | ] 111 | 112 | # Cropped goal features. 113 | goal_crop = goal_tensor.repeat(self.n_rotations, 1, 1, 1) 114 | goal_crop = self.rotator(goal_crop, pivot=pv) 115 | goal_crop = torch.cat(goal_crop, dim=0) 116 | goal_crop = goal_crop[ 117 | :, :, pv[0] - hcrop : pv[0] + hcrop, pv[1] - hcrop : pv[1] + hcrop 118 | ] 119 | 120 | in_logits = self.key_resnet(in_tensor) 121 | goal_logits = self.goal_resnet(goal_tensor) 122 | kernel_crop = self.query_resnet(in_crop) 123 | goal_crop = self.goal_resnet(goal_crop) 124 | 125 | # Fuse Goal and Transport features 126 | goal_x_in_logits = ( 127 | goal_logits + in_logits 128 | ) # Mohit: why doesn't multiply work? :( 129 | goal_x_kernel = goal_crop + kernel_crop 130 | 131 | # TODO(Mohit): Crop after network. 
Broken for now 132 | # in_logits = self.key_resnet(in_tensor) 133 | # kernel_nocrop_logits = self.query_resnet(in_tensor) 134 | # goal_logits = self.goal_resnet(goal_tensor) 135 | 136 | # goal_x_in_logits = in_logits 137 | # goal_x_kernel_logits = goal_logits * kernel_nocrop_logits 138 | 139 | # goal_crop = goal_x_kernel_logits.repeat(self.n_rotations, 1, 1, 1) 140 | # goal_crop = self.rotator(goal_crop, pivot=pv) 141 | # goal_crop = torch.cat(goal_crop, dim=0) 142 | # goal_crop = goal_crop[:, :, pv[0]-hcrop:pv[0]+hcrop, pv[1]-hcrop:pv[1]+hcrop] 143 | 144 | return self.correlate(goal_x_in_logits, goal_x_kernel, softmax) 145 | -------------------------------------------------------------------------------- /helpers/clip/core/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class IdentityBlock(nn.Module): 7 | def __init__( 8 | self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True 9 | ): 10 | super(IdentityBlock, self).__init__() 11 | self.final_relu = final_relu 12 | self.batchnorm = batchnorm 13 | 14 | filters1, filters2, filters3 = filters 15 | self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity() 17 | self.conv2 = nn.Conv2d( 18 | filters1, 19 | filters2, 20 | kernel_size=kernel_size, 21 | dilation=1, 22 | stride=stride, 23 | padding=1, 24 | bias=False, 25 | ) 26 | self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity() 27 | self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity() 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | out = F.relu(self.bn2(self.conv2(out))) 33 | out = self.bn3(self.conv3(out)) 34 | out += x 35 | if self.final_relu: 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ConvBlock(nn.Module): 41 | def __init__( 42 | self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True 43 | ): 44 | super(ConvBlock, self).__init__() 45 | self.final_relu = final_relu 46 | self.batchnorm = batchnorm 47 | 48 | filters1, filters2, filters3 = filters 49 | self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity() 51 | self.conv2 = nn.Conv2d( 52 | filters1, 53 | filters2, 54 | kernel_size=kernel_size, 55 | dilation=1, 56 | stride=stride, 57 | padding=1, 58 | bias=False, 59 | ) 60 | self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity() 61 | self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity() 63 | 64 | self.shortcut = nn.Sequential( 65 | nn.Conv2d(in_planes, filters3, kernel_size=1, stride=stride, bias=False), 66 | nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity(), 67 | ) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = F.relu(self.bn2(self.conv2(out))) 72 | out = self.bn3(self.conv3(out)) 73 | out += self.shortcut(x) 74 | if self.final_relu: 75 | out = F.relu(out) 76 | return out 77 | 78 | 79 | class ResNet43_8s(nn.Module): 80 | def __init__(self, input_shape, output_dim, cfg, device, preprocess): 81 | super(ResNet43_8s, self).__init__() 82 | self.input_shape = input_shape 83 | self.input_dim = input_shape[-1] 84 | 
self.output_dim = output_dim 85 | self.cfg = cfg 86 | self.device = device 87 | self.batchnorm = self.cfg["train"]["batchnorm"] 88 | self.preprocess = preprocess 89 | 90 | self.layers = self._make_layers() 91 | 92 | def _make_layers(self): 93 | layers = nn.Sequential( 94 | # conv1 95 | nn.Conv2d(self.input_dim, 64, stride=1, kernel_size=3, padding=1), 96 | nn.BatchNorm2d(64) if self.batchnorm else nn.Identity(), 97 | nn.ReLU(True), 98 | # fcn 99 | ConvBlock( 100 | 64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm 101 | ), 102 | IdentityBlock( 103 | 64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm 104 | ), 105 | ConvBlock( 106 | 64, [128, 128, 128], kernel_size=3, stride=2, batchnorm=self.batchnorm 107 | ), 108 | IdentityBlock( 109 | 128, [128, 128, 128], kernel_size=3, stride=1, batchnorm=self.batchnorm 110 | ), 111 | ConvBlock( 112 | 128, [256, 256, 256], kernel_size=3, stride=2, batchnorm=self.batchnorm 113 | ), 114 | IdentityBlock( 115 | 256, [256, 256, 256], kernel_size=3, stride=1, batchnorm=self.batchnorm 116 | ), 117 | ConvBlock( 118 | 256, [512, 512, 512], kernel_size=3, stride=2, batchnorm=self.batchnorm 119 | ), 120 | IdentityBlock( 121 | 512, [512, 512, 512], kernel_size=3, stride=1, batchnorm=self.batchnorm 122 | ), 123 | # head 124 | ConvBlock( 125 | 512, [256, 256, 256], kernel_size=3, stride=1, batchnorm=self.batchnorm 126 | ), 127 | IdentityBlock( 128 | 256, [256, 256, 256], kernel_size=3, stride=1, batchnorm=self.batchnorm 129 | ), 130 | nn.UpsamplingBilinear2d(scale_factor=2), 131 | ConvBlock( 132 | 256, [128, 128, 128], kernel_size=3, stride=1, batchnorm=self.batchnorm 133 | ), 134 | IdentityBlock( 135 | 128, [128, 128, 128], kernel_size=3, stride=1, batchnorm=self.batchnorm 136 | ), 137 | nn.UpsamplingBilinear2d(scale_factor=2), 138 | ConvBlock( 139 | 128, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm 140 | ), 141 | IdentityBlock( 142 | 64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm 143 | ), 144 | nn.UpsamplingBilinear2d(scale_factor=2), 145 | # conv2 146 | ConvBlock( 147 | 64, 148 | [16, 16, self.output_dim], 149 | kernel_size=3, 150 | stride=1, 151 | final_relu=False, 152 | batchnorm=self.batchnorm, 153 | ), 154 | IdentityBlock( 155 | self.output_dim, 156 | [16, 16, self.output_dim], 157 | kernel_size=3, 158 | stride=1, 159 | final_relu=False, 160 | batchnorm=self.batchnorm, 161 | ), 162 | ) 163 | return layers 164 | 165 | def forward(self, x): 166 | x = self.preprocess(x, dist="transporter") 167 | 168 | out = self.layers(x) 169 | return out 170 | -------------------------------------------------------------------------------- /run_seed_fn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import gc 4 | from typing import List 5 | 6 | import hydra 7 | import numpy as np 8 | import torch 9 | from omegaconf import DictConfig 10 | 11 | from rlbench import CameraConfig, ObservationConfig 12 | from yarr.replay_buffer.wrappers.pytorch_replay_buffer import PyTorchReplayBuffer 13 | from yarr.runners.offline_train_runner import OfflineTrainRunner 14 | from yarr.utils.stat_accumulator import SimpleAccumulator 15 | 16 | from helpers.custom_rlbench_env import CustomRLBenchEnv, CustomMultiTaskRLBenchEnv 17 | import torch.distributed as dist 18 | 19 | from agents import agent_factory 20 | from agents import replay_utils 21 | 22 | import peract_config 23 | from functools import partial 24 | 25 | 26 | def run_seed( 27 | rank, 28 | cfg: 
DictConfig, 29 | obs_config: ObservationConfig, 30 | seed, 31 | world_size, 32 | ) -> None: 33 | peract_config.config_logging() 34 | 35 | dist.init_process_group("gloo", rank=rank, world_size=world_size) 36 | 37 | tasks = cfg.rlbench.tasks 38 | cams = cfg.rlbench.cameras 39 | 40 | task_folder = "multi" if len(tasks) > 1 else tasks[0] 41 | replay_path = os.path.join( 42 | cfg.replay.path, task_folder, cfg.method.name, "seed%d" % seed 43 | ) 44 | 45 | agent = agent_factory.create_agent(cfg) 46 | 47 | if not agent: 48 | print("Unable to create agent") 49 | return 50 | 51 | if cfg.method.name == "ARM": 52 | raise NotImplementedError("ARM is not supported yet") 53 | elif cfg.method.name == "BC_LANG": 54 | from agents.baselines import bc_lang 55 | 56 | assert cfg.ddp.num_devices == 1, "BC_LANG only supports single GPU training" 57 | replay_buffer = bc_lang.launch_utils.create_replay( 58 | cfg.replay.batch_size, 59 | cfg.replay.timesteps, 60 | cfg.replay.prioritisation, 61 | cfg.replay.task_uniform, 62 | replay_path if cfg.replay.use_disk else None, 63 | cams, 64 | cfg.rlbench.camera_resolution, 65 | ) 66 | 67 | bc_lang.launch_utils.fill_multi_task_replay( 68 | cfg, 69 | obs_config, 70 | rank, 71 | replay_buffer, 72 | tasks, 73 | cfg.rlbench.demos, 74 | cfg.method.demo_augmentation, 75 | cfg.method.demo_augmentation_every_n, 76 | cams, 77 | ) 78 | 79 | elif cfg.method.name == "VIT_BC_LANG": 80 | from agents.baselines import vit_bc_lang 81 | 82 | assert cfg.ddp.num_devices == 1, "VIT_BC_LANG only supports single GPU training" 83 | replay_buffer = vit_bc_lang.launch_utils.create_replay( 84 | cfg.replay.batch_size, 85 | cfg.replay.timesteps, 86 | cfg.replay.prioritisation, 87 | cfg.replay.task_uniform, 88 | replay_path if cfg.replay.use_disk else None, 89 | cams, 90 | cfg.rlbench.camera_resolution, 91 | ) 92 | 93 | vit_bc_lang.launch_utils.fill_multi_task_replay( 94 | cfg, 95 | obs_config, 96 | rank, 97 | replay_buffer, 98 | tasks, 99 | cfg.rlbench.demos, 100 | cfg.method.demo_augmentation, 101 | cfg.method.demo_augmentation_every_n, 102 | cams, 103 | ) 104 | 105 | elif cfg.method.name.startswith("ACT_BC_LANG"): 106 | from agents import act_bc_lang 107 | 108 | assert cfg.ddp.num_devices == 1, "ACT_BC_LANG only supports single GPU training" 109 | replay_buffer = act_bc_lang.launch_utils.create_replay( 110 | cfg.replay.batch_size, 111 | cfg.replay.timesteps, 112 | cfg.replay.prioritisation, 113 | cfg.replay.task_uniform, 114 | replay_path if cfg.replay.use_disk else None, 115 | cams, 116 | cfg.rlbench.camera_resolution, 117 | replay_size=3e5, 118 | prev_action_horizon=cfg.method.prev_action_horizon, 119 | next_action_horizon=cfg.method.next_action_horizon, 120 | ) 121 | 122 | act_bc_lang.launch_utils.fill_multi_task_replay( 123 | cfg, 124 | obs_config, 125 | rank, 126 | replay_buffer, 127 | tasks, 128 | cfg.rlbench.demos, 129 | cfg.method.demo_augmentation, 130 | cfg.method.demo_augmentation_every_n, 131 | cams, 132 | ) 133 | 134 | elif cfg.method.name == "C2FARM_LINGUNET_BC": 135 | from agents import c2farm_lingunet_bc 136 | 137 | replay_buffer = c2farm_lingunet_bc.launch_utils.create_replay( 138 | cfg.replay.batch_size, 139 | cfg.replay.timesteps, 140 | cfg.replay.prioritisation, 141 | cfg.replay.task_uniform, 142 | replay_path if cfg.replay.use_disk else None, 143 | cams, 144 | cfg.method.voxel_sizes, 145 | cfg.rlbench.camera_resolution, 146 | ) 147 | 148 | c2farm_lingunet_bc.launch_utils.fill_multi_task_replay( 149 | cfg, 150 | obs_config, 151 | rank, 152 | replay_buffer, 153 | tasks, 154 | 
cfg.rlbench.demos, 155 | cfg.method.demo_augmentation, 156 | cfg.method.demo_augmentation_every_n, 157 | cams, 158 | cfg.rlbench.scene_bounds, 159 | cfg.method.voxel_sizes, 160 | cfg.method.bounds_offset, 161 | cfg.method.rotation_resolution, 162 | cfg.method.crop_augmentation, 163 | keypoint_method=cfg.method.keypoint_method, 164 | ) 165 | 166 | elif ( 167 | cfg.method.name.startswith("BIMANUAL_PERACT") 168 | or cfg.method.name.startswith("RVT") 169 | or cfg.method.name.startswith("PERACT_BC") 170 | ): 171 | replay_buffer = replay_utils.create_replay(cfg, replay_path) 172 | 173 | replay_utils.fill_multi_task_replay(cfg, obs_config, rank, replay_buffer, tasks) 174 | 175 | elif cfg.method.name == "PERACT_RL": 176 | raise NotImplementedError("PERACT_RL is not supported yet") 177 | 178 | else: 179 | raise ValueError("Method %s does not exists." % cfg.method.name) 180 | 181 | wrapped_replay = PyTorchReplayBuffer( 182 | replay_buffer, num_workers=cfg.framework.num_workers 183 | ) 184 | stat_accum = SimpleAccumulator(eval_video_fps=30) 185 | 186 | cwd = os.getcwd() 187 | weightsdir = os.path.join(cwd, "seed%d" % seed, "weights") 188 | logdir = os.path.join(cwd, "seed%d" % seed) 189 | 190 | train_runner = OfflineTrainRunner( 191 | agent=agent, 192 | wrapped_replay_buffer=wrapped_replay, 193 | train_device=rank, 194 | stat_accumulator=stat_accum, 195 | iterations=cfg.framework.training_iterations, 196 | logdir=logdir, 197 | logging_level=cfg.framework.logging_level, 198 | log_freq=cfg.framework.log_freq, 199 | weightsdir=weightsdir, 200 | num_weights_to_keep=cfg.framework.num_weights_to_keep, 201 | save_freq=cfg.framework.save_freq, 202 | tensorboard_logging=cfg.framework.tensorboard_logging, 203 | csv_logging=cfg.framework.csv_logging, 204 | load_existing_weights=cfg.framework.load_existing_weights, 205 | rank=rank, 206 | world_size=world_size, 207 | ) 208 | 209 | train_runner._on_thread_start = partial( 210 | peract_config.config_logging, cfg.framework.logging_level 211 | ) 212 | 213 | train_runner.start() 214 | 215 | del train_runner 216 | del agent 217 | gc.collect() 218 | torch.cuda.empty_cache() 219 | -------------------------------------------------------------------------------- /agents/bimanual_peract/qattention_stack_agent.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from yarr.agents.agent import Agent, ActResult, Summary 5 | 6 | import numpy as np 7 | 8 | from helpers import utils 9 | from agents.bimanual_peract.qattention_peract_bc_agent import QAttentionPerActBCAgent 10 | 11 | NAME = "QAttentionStackAgent" 12 | 13 | 14 | class QAttentionStackAgent(Agent): 15 | def __init__( 16 | self, 17 | qattention_agents: List[QAttentionPerActBCAgent], 18 | rotation_resolution: float, 19 | camera_names: List[str], 20 | rotation_prediction_depth: int = 0, 21 | ): 22 | super(QAttentionStackAgent, self).__init__() 23 | self._qattention_agents = qattention_agents 24 | self._rotation_resolution = rotation_resolution 25 | self._camera_names = camera_names 26 | self._rotation_prediction_depth = rotation_prediction_depth 27 | 28 | def build(self, training: bool, device=None) -> None: 29 | self._device = device 30 | if self._device is None: 31 | self._device = torch.device("cpu") 32 | for qa in self._qattention_agents: 33 | qa.build(training, device) 34 | 35 | def update(self, step: int, replay_sample: dict) -> dict: 36 | priorities = 0 37 | total_losses = 0.0 38 | for qa in self._qattention_agents: 39 | update_dict 
= qa.update(step, replay_sample) 40 | replay_sample.update(update_dict) 41 | total_losses += update_dict["total_loss"] 42 | return { 43 | "total_losses": total_losses, 44 | } 45 | 46 | def act(self, step: int, observation: dict, deterministic=False) -> ActResult: 47 | observation_elements = {} 48 | ( 49 | right_translation_results, 50 | right_rot_grip_results, 51 | right_ignore_collisions_results, 52 | ) = ([], [], []) 53 | ( 54 | left_translation_results, 55 | left_rot_grip_results, 56 | left_ignore_collisions_results, 57 | ) = ([], [], []) 58 | 59 | infos = {} 60 | for depth, qagent in enumerate(self._qattention_agents): 61 | act_results = qagent.act(step, observation, deterministic) 62 | right_attention_coordinate = ( 63 | act_results.observation_elements["right_attention_coordinate"] 64 | .cpu() 65 | .numpy() 66 | ) 67 | left_attention_coordinate = ( 68 | act_results.observation_elements["left_attention_coordinate"] 69 | .cpu() 70 | .numpy() 71 | ) 72 | observation_elements[ 73 | "right_attention_coordinate_layer_%d" % depth 74 | ] = right_attention_coordinate[0] 75 | observation_elements[ 76 | "left_attention_coordinate_layer_%d" % depth 77 | ] = left_attention_coordinate[0] 78 | 79 | ( 80 | right_translation_idxs, 81 | right_rot_grip_idxs, 82 | right_ignore_collisions_idxs, 83 | left_translation_idxs, 84 | left_rot_grip_idxs, 85 | left_ignore_collisions_idxs, 86 | ) = act_results.action 87 | 88 | right_translation_results.append(right_translation_idxs) 89 | if right_rot_grip_idxs is not None: 90 | right_rot_grip_results.append(right_rot_grip_idxs) 91 | if right_ignore_collisions_idxs is not None: 92 | right_ignore_collisions_results.append(right_ignore_collisions_idxs) 93 | 94 | left_translation_results.append(left_translation_idxs) 95 | if left_rot_grip_idxs is not None: 96 | left_rot_grip_results.append(left_rot_grip_idxs) 97 | if left_ignore_collisions_idxs is not None: 98 | left_ignore_collisions_results.append(left_ignore_collisions_idxs) 99 | 100 | observation[ 101 | "right_attention_coordinate" 102 | ] = act_results.observation_elements["right_attention_coordinate"] 103 | observation["left_attention_coordinate"] = act_results.observation_elements[ 104 | "left_attention_coordinate" 105 | ] 106 | 107 | observation["prev_layer_voxel_grid"] = act_results.observation_elements[ 108 | "prev_layer_voxel_grid" 109 | ] 110 | observation["prev_layer_bounds"] = act_results.observation_elements[ 111 | "prev_layer_bounds" 112 | ] 113 | 114 | for n in self._camera_names: 115 | extrinsics = observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy() 116 | intrinsics = observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy() 117 | px, py = utils.point_to_pixel_index( 118 | right_attention_coordinate[0], extrinsics, intrinsics 119 | ) 120 | pc_t = torch.tensor( 121 | [[[py, px]]], dtype=torch.float32, device=self._device 122 | ) 123 | observation[f"right_{n}_pixel_coord"] = pc_t 124 | observation_elements[f"right_{n}_pixel_coord"] = [py, px] 125 | 126 | px, py = utils.point_to_pixel_index( 127 | left_attention_coordinate[0], extrinsics, intrinsics 128 | ) 129 | pc_t = torch.tensor( 130 | [[[py, px]]], dtype=torch.float32, device=self._device 131 | ) 132 | observation[f"left_{n}_pixel_coord"] = pc_t 133 | observation_elements[f"left_{n}_pixel_coord"] = [py, px] 134 | infos.update(act_results.info) 135 | 136 | right_rgai = torch.cat(right_rot_grip_results, 1)[0].cpu().numpy() 137 | # ..todo:: utils.correct_rotation_instability does nothing so we can ignore it 138 | # right_rgai = 
utils.correct_rotation_instability(right_rgai, self._rotation_resolution) 139 | right_ignore_collisions = ( 140 | torch.cat(right_ignore_collisions_results, 1)[0].cpu().numpy() 141 | ) 142 | right_trans_action_indicies = ( 143 | torch.cat(right_translation_results, 1)[0].cpu().numpy() 144 | ) 145 | 146 | observation_elements[ 147 | "right_trans_action_indicies" 148 | ] = right_trans_action_indicies[:3] 149 | observation_elements["right_rot_grip_action_indicies"] = right_rgai[:4] 150 | 151 | left_rgai = torch.cat(left_rot_grip_results, 1)[0].cpu().numpy() 152 | left_ignore_collisions = ( 153 | torch.cat(left_ignore_collisions_results, 1)[0].cpu().numpy() 154 | ) 155 | left_trans_action_indicies = ( 156 | torch.cat(left_translation_results, 1)[0].cpu().numpy() 157 | ) 158 | 159 | observation_elements["left_trans_action_indicies"] = left_trans_action_indicies[ 160 | 3: 161 | ] 162 | observation_elements["left_rot_grip_action_indicies"] = left_rgai[4:] 163 | 164 | continuous_action = np.concatenate( 165 | [ 166 | right_attention_coordinate[0], 167 | utils.discrete_euler_to_quaternion( 168 | right_rgai[-4:-1], self._rotation_resolution 169 | ), 170 | right_rgai[-1:], 171 | right_ignore_collisions, 172 | left_attention_coordinate[0], 173 | utils.discrete_euler_to_quaternion( 174 | left_rgai[-4:-1], self._rotation_resolution 175 | ), 176 | left_rgai[-1:], 177 | left_ignore_collisions, 178 | ] 179 | ) 180 | return ActResult( 181 | continuous_action, observation_elements=observation_elements, info=infos 182 | ) 183 | 184 | def update_summaries(self) -> List[Summary]: 185 | summaries = [] 186 | for qa in self._qattention_agents: 187 | summaries.extend(qa.update_summaries()) 188 | return summaries 189 | 190 | def act_summaries(self) -> List[Summary]: 191 | s = [] 192 | for qa in self._qattention_agents: 193 | s.extend(qa.act_summaries()) 194 | return s 195 | 196 | def load_weights(self, savedir: str): 197 | for qa in self._qattention_agents: 198 | qa.load_weights(savedir) 199 | 200 | def save_weights(self, savedir: str): 201 | for qa in self._qattention_agents: 202 | qa.save_weights(savedir) 203 | -------------------------------------------------------------------------------- /helpers/observation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rlbench.backend.observation import Observation 3 | 4 | from rlbench.backend.observation import BimanualObservation 5 | from rlbench import CameraConfig, ObservationConfig 6 | from pyrep.const import RenderMode 7 | from typing import List 8 | 9 | REMOVE_KEYS = [ 10 | "joint_velocities", 11 | "joint_positions", 12 | "joint_forces", 13 | "gripper_open", 14 | "gripper_pose", 15 | "gripper_joint_positions", 16 | "gripper_touch_forces", 17 | "task_low_dim_state", 18 | "misc", 19 | ] 20 | 21 | 22 | def extract_obs( 23 | obs: Observation, 24 | cameras, 25 | t: int = 0, 26 | prev_action=None, 27 | channels_last: bool = False, 28 | episode_length: int = 10, 29 | robot_name: str = "", 30 | ): 31 | if obs.is_bimanual: 32 | return extract_obs_bimanual( 33 | obs, cameras, t, prev_action, channels_last, episode_length, robot_name 34 | ) 35 | else: 36 | return extract_obs_unimanual( 37 | obs, cameras, t, prev_action, channels_last, episode_length 38 | ) 39 | 40 | 41 | def extract_obs_unimanual( 42 | obs: Observation, 43 | cameras, 44 | t: int = 0, 45 | prev_action=None, 46 | channels_last: bool = False, 47 | episode_length: int = 10, 48 | ): 49 | obs.joint_velocities = None 50 | grip_mat = 
obs.gripper_matrix 51 | grip_pose = obs.gripper_pose 52 | joint_pos = obs.joint_positions 53 | obs.gripper_pose = None 54 | obs.gripper_matrix = None 55 | obs.joint_positions = None 56 | if obs.gripper_joint_positions is not None: 57 | obs.gripper_joint_positions = np.clip(obs.gripper_joint_positions, 0.0, 0.04) 58 | 59 | obs_dict = vars(obs) 60 | obs_dict = {k: v for k, v in obs_dict.items() if v is not None} 61 | robot_state = obs.get_low_dim_data() 62 | # remove low-level proprioception variables that are not needed 63 | obs_dict = {k: v for k, v in obs_dict.items() if k not in REMOVE_KEYS} 64 | 65 | if not channels_last: 66 | # swap channels from last dim to 1st dim 67 | obs_dict = { 68 | k: np.transpose(v, [2, 0, 1]) if v.ndim == 3 else np.expand_dims(v, 0) 69 | for k, v in obs.perception_data.items() 70 | if type(v) == np.ndarray or type(v) == list 71 | } 72 | else: 73 | # add extra dim to depth data 74 | obs_dict = { 75 | k: v if v.ndim == 3 else np.expand_dims(v, -1) 76 | for k, v in obs.perception_data.items() 77 | } 78 | obs_dict["low_dim_state"] = np.array(robot_state, dtype=np.float32) 79 | 80 | # binary variable indicating if collisions are allowed or not while planning paths to reach poses 81 | obs_dict["ignore_collisions"] = np.array([obs.ignore_collisions], dtype=np.float32) 82 | for k, v in [(k, v) for k, v in obs_dict.items() if "point_cloud" in k]: 83 | obs_dict[k] = v.astype(np.float32) 84 | 85 | for camera_name in cameras: 86 | obs_dict["%s_camera_extrinsics" % camera_name] = obs.misc[ 87 | "%s_camera_extrinsics" % camera_name 88 | ] 89 | obs_dict["%s_camera_intrinsics" % camera_name] = obs.misc[ 90 | "%s_camera_intrinsics" % camera_name 91 | ] 92 | 93 | # add timestep to low_dim_state 94 | time = (1.0 - (t / float(episode_length - 1))) * 2.0 - 1.0 95 | obs_dict["low_dim_state"] = np.concatenate( 96 | [obs_dict["low_dim_state"], [time]] 97 | ).astype(np.float32) 98 | 99 | obs.gripper_matrix = grip_mat 100 | obs.joint_positions = joint_pos 101 | obs.gripper_pose = grip_pose 102 | 103 | return obs_dict 104 | 105 | 106 | def extract_obs_bimanual( 107 | obs: Observation, 108 | cameras, 109 | t: int = 0, 110 | prev_action=None, 111 | channels_last: bool = False, 112 | episode_length: int = 10, 113 | robot_name: str = "", 114 | ): 115 | obs.right.joint_velocities = None 116 | right_grip_mat = obs.right.gripper_matrix 117 | right_grip_pose = obs.right.gripper_pose 118 | right_joint_pos = obs.right.joint_positions 119 | obs.right.gripper_pose = None 120 | obs.right.gripper_matrix = None 121 | obs.right.joint_positions = None 122 | 123 | obs.left.joint_velocities = None 124 | left_grip_mat = obs.left.gripper_matrix 125 | left_grip_pose = obs.left.gripper_pose 126 | left_joint_pos = obs.left.joint_positions 127 | obs.left.gripper_pose = None 128 | obs.left.gripper_matrix = None 129 | obs.left.joint_positions = None 130 | 131 | if obs.right.gripper_joint_positions is not None: 132 | obs.right.gripper_joint_positions = np.clip( 133 | obs.right.gripper_joint_positions, 0.0, 0.04 134 | ) 135 | obs.left.gripper_joint_positions = np.clip( 136 | obs.left.gripper_joint_positions, 0.0, 0.04 137 | ) 138 | 139 | # fixme:: 140 | obs_dict = vars(obs) 141 | obs_dict = {k: v for k, v in obs_dict.items() if v is not None} 142 | 143 | right_robot_state = obs.get_low_dim_data(obs.right) 144 | left_robot_state = obs.get_low_dim_data(obs.left) 145 | 146 | # remove low-level proprioception variables that are not needed 147 | obs_dict = {k: v for k, v in obs_dict.items() if k not in REMOVE_KEYS} 148 
| 149 | if not channels_last: 150 | # swap channels from last dim to 1st dim 151 | obs_dict = { 152 | k: np.transpose(v, [2, 0, 1]) if v.ndim == 3 else np.expand_dims(v, 0) 153 | for k, v in obs.perception_data.items() 154 | if type(v) == np.ndarray or type(v) == list 155 | } 156 | else: 157 | # add extra dim to depth data 158 | obs_dict = { 159 | k: v if v.ndim == 3 else np.expand_dims(v, -1) 160 | for k, v in obs.perception_data.items() 161 | } 162 | 163 | if robot_name == "right": 164 | obs_dict["low_dim_state"] = right_robot_state.astype(np.float32) 165 | # binary variable indicating if collisions are allowed or not while planning paths to reach poses 166 | obs_dict["ignore_collisions"] = np.array( 167 | [obs.right.ignore_collisions], dtype=np.float32 168 | ) 169 | elif robot_name == "left": 170 | obs_dict["low_dim_state"] = left_robot_state.astype(np.float32) 171 | obs_dict["ignore_collisions"] = np.array( 172 | [obs.left.ignore_collisions], dtype=np.float32 173 | ) 174 | elif robot_name == "bimanual": 175 | obs_dict["right_low_dim_state"] = right_robot_state.astype(np.float32) 176 | obs_dict["left_low_dim_state"] = left_robot_state.astype(np.float32) 177 | obs_dict["right_ignore_collisions"] = np.array( 178 | [obs.right.ignore_collisions], dtype=np.float32 179 | ) 180 | obs_dict["left_ignore_collisions"] = np.array( 181 | [obs.left.ignore_collisions], dtype=np.float32 182 | ) 183 | 184 | for k, v in [(k, v) for k, v in obs_dict.items() if "point_cloud" in k]: 185 | # ..TODO:: switch to np.float16 186 | obs_dict[k] = v.astype(np.float32) 187 | 188 | for camera_name in cameras: 189 | obs_dict["%s_camera_extrinsics" % camera_name] = obs.misc[ 190 | "%s_camera_extrinsics" % camera_name 191 | ] 192 | obs_dict["%s_camera_intrinsics" % camera_name] = obs.misc[ 193 | "%s_camera_intrinsics" % camera_name 194 | ] 195 | 196 | # add timestep to low_dim_state 197 | time = (1.0 - (t / float(episode_length - 1))) * 2.0 - 1.0 198 | 199 | if "low_dim_state" in obs_dict: 200 | obs_dict["low_dim_state"] = np.concatenate( 201 | [obs_dict["low_dim_state"], [time]] 202 | ).astype(np.float32) 203 | else: 204 | obs_dict["right_low_dim_state"] = np.concatenate( 205 | [obs_dict["right_low_dim_state"], [time]] 206 | ).astype(np.float32) 207 | obs_dict["left_low_dim_state"] = np.concatenate( 208 | [obs_dict["left_low_dim_state"], [time]] 209 | ).astype(np.float32) 210 | 211 | obs.right.gripper_matrix = right_grip_mat 212 | obs.right.joint_positions = right_joint_pos 213 | obs.right.gripper_pose = right_grip_pose 214 | obs.left.gripper_matrix = left_grip_mat 215 | obs.left.joint_positions = left_joint_pos 216 | obs.left.gripper_pose = left_grip_pose 217 | 218 | return obs_dict 219 | 220 | 221 | def create_obs_config( 222 | camera_names: List[str], 223 | camera_resolution: List[int], 224 | method_name: str, 225 | robot_name: str = "bimanual", 226 | ): 227 | unused_cams = CameraConfig() 228 | unused_cams.set_all(False) 229 | used_cams = CameraConfig( 230 | rgb=True, 231 | point_cloud=True, 232 | mask=False, 233 | depth=False, 234 | image_size=camera_resolution, 235 | render_mode=RenderMode.OPENGL, 236 | ) 237 | 238 | camera_configs = {camera_name: used_cams for camera_name in camera_names} 239 | 240 | # Some of these obs are only used for keypoint detection. 
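# Only RGB + point clouds are rendered for the named cameras (no depth/mask),
# while proprioceptive fields (gripper pose/matrix/open state, joint positions
# and velocities) stay enabled; the latter are presumably what keypoint
# detection inspects (e.g. gripper open/close changes), hence the note above.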
241 | obs_config = ObservationConfig( 242 | camera_configs=camera_configs, 243 | joint_forces=False, 244 | joint_positions=True, 245 | joint_velocities=True, 246 | task_low_dim_state=False, 247 | gripper_touch_forces=False, 248 | gripper_pose=True, 249 | gripper_open=True, 250 | gripper_matrix=True, 251 | gripper_joint_positions=True, 252 | robot_name=robot_name, 253 | ) 254 | return obs_config 255 | -------------------------------------------------------------------------------- /voxel/voxel_grid.py: -------------------------------------------------------------------------------- 1 | # Voxelizer modified from ARM for DDP training 2 | # Source: https://github.com/stepjam/ARM 3 | # License: https://github.com/stepjam/ARM/LICENSE 4 | 5 | from functools import reduce 6 | from operator import mul 7 | 8 | import torch 9 | from torch import nn 10 | 11 | MIN_DENOMINATOR = 1e-12 12 | INCLUDE_PER_VOXEL_COORD = False 13 | 14 | 15 | class VoxelGrid(nn.Module): 16 | def __init__( 17 | self, 18 | coord_bounds, 19 | voxel_size: int, 20 | device, 21 | batch_size, 22 | feature_size, # e.g. rgb or image features 23 | max_num_coords: int, 24 | ): 25 | super(VoxelGrid, self).__init__() 26 | self._device = device 27 | self._voxel_size = voxel_size 28 | self._voxel_shape = [voxel_size] * 3 29 | self._voxel_d = float(self._voxel_shape[-1]) 30 | self._voxel_feature_size = 4 + feature_size 31 | self._voxel_shape_spec = ( 32 | torch.tensor( 33 | self._voxel_shape, 34 | ).unsqueeze(0) 35 | + 2 36 | ) # +2 because we crop the edges. 37 | self._coord_bounds = torch.tensor( 38 | coord_bounds, 39 | dtype=torch.float, 40 | ).unsqueeze(0) 41 | max_dims = self._voxel_shape_spec[0] 42 | self._total_dims_list = torch.cat( 43 | [ 44 | torch.tensor( 45 | [batch_size], 46 | ), 47 | max_dims, 48 | torch.tensor( 49 | [4 + feature_size], 50 | ), 51 | ], 52 | -1, 53 | ).tolist() 54 | 55 | self.register_buffer( 56 | "_ones_max_coords", torch.ones((batch_size, max_num_coords, 1)) 57 | ) 58 | self._num_coords = max_num_coords 59 | 60 | shape = self._total_dims_list 61 | result_dim_sizes = torch.tensor( 62 | [reduce(mul, shape[i + 1 :], 1) for i in range(len(shape) - 1)] + [1], 63 | ) 64 | self.register_buffer("_result_dim_sizes", result_dim_sizes) 65 | flat_result_size = reduce(mul, shape, 1) 66 | 67 | self._initial_val = torch.tensor(0, dtype=torch.float) 68 | flat_output = ( 69 | torch.ones(flat_result_size, dtype=torch.float) * self._initial_val 70 | ) 71 | self.register_buffer("_flat_output", flat_output) 72 | 73 | self.register_buffer("_arange_to_max_coords", torch.arange(4 + feature_size)) 74 | self._flat_zeros = torch.zeros(flat_result_size, dtype=torch.float) 75 | 76 | self._const_1 = torch.tensor( 77 | 1.0, 78 | ) 79 | self._batch_size = batch_size 80 | 81 | # Coordinate Bounds: 82 | bb_mins = self._coord_bounds[..., 0:3] 83 | self.register_buffer("_bb_mins", bb_mins) 84 | bb_maxs = self._coord_bounds[..., 3:6] 85 | bb_ranges = bb_maxs - bb_mins 86 | # get voxel dimensions. 
'DIMS' mode 87 | self._dims = dims = self._voxel_shape_spec.int() 88 | dims_orig = self._voxel_shape_spec.int() - 2 89 | self.register_buffer("_dims_orig", dims_orig) 90 | 91 | # self._dims_m_one = (dims - 1).int() 92 | dims_m_one = (dims - 1).int() 93 | self.register_buffer("_dims_m_one", dims_m_one) 94 | 95 | # BS x 1 x 3 96 | res = bb_ranges / (dims_orig.float() + MIN_DENOMINATOR) 97 | self._res_minis_2 = bb_ranges / (dims.float() - 2 + MIN_DENOMINATOR) 98 | self.register_buffer("_res", res) 99 | 100 | voxel_indicy_denmominator = res + MIN_DENOMINATOR 101 | self.register_buffer("_voxel_indicy_denmominator", voxel_indicy_denmominator) 102 | 103 | self.register_buffer("_dims_m_one_zeros", torch.zeros_like(dims_m_one)) 104 | 105 | batch_indices = torch.arange(self._batch_size, dtype=torch.int).view( 106 | self._batch_size, 1, 1 107 | ) 108 | self.register_buffer( 109 | "_tiled_batch_indices", batch_indices.repeat([1, self._num_coords, 1]) 110 | ) 111 | 112 | w = self._voxel_shape[0] + 2 113 | arange = torch.arange( 114 | 0, 115 | w, 116 | dtype=torch.float, 117 | ) 118 | index_grid = ( 119 | torch.cat( 120 | [ 121 | arange.view(w, 1, 1, 1).repeat([1, w, w, 1]), 122 | arange.view(1, w, 1, 1).repeat([w, 1, w, 1]), 123 | arange.view(1, 1, w, 1).repeat([w, w, 1, 1]), 124 | ], 125 | dim=-1, 126 | ) 127 | .unsqueeze(0) 128 | .repeat([self._batch_size, 1, 1, 1, 1]) 129 | ) 130 | self.register_buffer("_index_grid", index_grid) 131 | 132 | def _broadcast(self, src: torch.Tensor, other: torch.Tensor, dim: int): 133 | if dim < 0: 134 | dim = other.dim() + dim 135 | if src.dim() == 1: 136 | for _ in range(0, dim): 137 | src = src.unsqueeze(0) 138 | for _ in range(src.dim(), other.dim()): 139 | src = src.unsqueeze(-1) 140 | src = src.expand_as(other) 141 | return src 142 | 143 | def _scatter_mean( 144 | self, src: torch.Tensor, index: torch.Tensor, out: torch.Tensor, dim: int = -1 145 | ): 146 | out = out.scatter_add_(dim, index, src) 147 | 148 | index_dim = dim 149 | if index_dim < 0: 150 | index_dim = index_dim + src.dim() 151 | if index.dim() <= index_dim: 152 | index_dim = index.dim() - 1 153 | 154 | ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) 155 | out_count = torch.zeros(out.size(), dtype=out.dtype, device=out.device) 156 | out_count = out_count.scatter_add_(index_dim, index, ones) 157 | out_count.clamp_(1) 158 | count = self._broadcast(out_count, out, dim) 159 | if torch.is_floating_point(out): 160 | out.true_divide_(count) 161 | else: 162 | out.floor_divide_(count) 163 | return out 164 | 165 | def _scatter_nd(self, indices, updates): 166 | indices_shape = indices.shape 167 | num_index_dims = indices_shape[-1] 168 | flat_updates = updates.view((-1,)) 169 | indices_scales = self._result_dim_sizes[0:num_index_dims].view( 170 | [1] * (len(indices_shape) - 1) + [num_index_dims] 171 | ) 172 | indices_for_flat_tiled = ( 173 | ((indices * indices_scales).sum(dim=-1, keepdims=True)) 174 | .view(-1, 1) 175 | .repeat(*[1, self._voxel_feature_size]) 176 | ) 177 | 178 | implicit_indices = ( 179 | self._arange_to_max_coords[: self._voxel_feature_size] 180 | .unsqueeze(0) 181 | .repeat(*[indices_for_flat_tiled.shape[0], 1]) 182 | ) 183 | indices_for_flat = indices_for_flat_tiled + implicit_indices 184 | flat_indices_for_flat = indices_for_flat.view((-1,)).long() 185 | 186 | flat_scatter = self._scatter_mean( 187 | flat_updates, flat_indices_for_flat, out=torch.zeros_like(self._flat_output) 188 | ) 189 | return flat_scatter.view(self._total_dims_list) 190 | 191 | def 
coords_to_bounding_voxel_grid( 192 | self, coords, coord_features=None, coord_bounds=None 193 | ): 194 | voxel_indicy_denmominator = self._voxel_indicy_denmominator 195 | res, bb_mins = self._res, self._bb_mins 196 | if coord_bounds is not None: 197 | bb_mins = coord_bounds[..., 0:3] 198 | bb_maxs = coord_bounds[..., 3:6] 199 | bb_ranges = bb_maxs - bb_mins 200 | res = bb_ranges / (self._dims_orig.float() + MIN_DENOMINATOR) 201 | voxel_indicy_denmominator = res + MIN_DENOMINATOR 202 | 203 | bb_mins_shifted = bb_mins - res # shift back by one 204 | floor = torch.floor( 205 | (coords - bb_mins_shifted.unsqueeze(1)) 206 | / voxel_indicy_denmominator.unsqueeze(1) 207 | ).int() 208 | voxel_indices = torch.min(floor, self._dims_m_one) 209 | voxel_indices = torch.max(voxel_indices, self._dims_m_one_zeros) 210 | 211 | # BS x NC x 3 212 | voxel_values = coords 213 | if coord_features is not None: 214 | voxel_values = torch.cat([voxel_values, coord_features], -1) 215 | 216 | _, num_coords, _ = voxel_indices.shape 217 | # BS x N x (num_batch_dims + 2) 218 | all_indices = torch.cat( 219 | [self._tiled_batch_indices[:, :num_coords], voxel_indices], -1 220 | ) 221 | 222 | # BS x N x 4 223 | voxel_values_pruned_flat = torch.cat( 224 | [voxel_values, self._ones_max_coords[:, :num_coords]], -1 225 | ) 226 | 227 | # BS x x_max x y_max x z_max x 4 228 | scattered = self._scatter_nd( 229 | all_indices.view([-1, 1 + 3]), 230 | voxel_values_pruned_flat.view(-1, self._voxel_feature_size), 231 | ) 232 | 233 | vox = scattered[:, 1:-1, 1:-1, 1:-1] 234 | if INCLUDE_PER_VOXEL_COORD: 235 | res_expanded = res.unsqueeze(1).unsqueeze(1).unsqueeze(1) 236 | res_centre = (res_expanded * self._index_grid) + res_expanded / 2.0 237 | coord_positions = ( 238 | res_centre + bb_mins_shifted.unsqueeze(1).unsqueeze(1).unsqueeze(1) 239 | )[:, 1:-1, 1:-1, 1:-1] 240 | vox = torch.cat([vox[..., :-1], coord_positions, vox[..., -1:]], -1) 241 | 242 | occupied = (vox[..., -1:] > 0).float() 243 | vox = torch.cat([vox[..., :-1], occupied], -1) 244 | 245 | return torch.cat( 246 | [ 247 | vox[..., :-1], 248 | self._index_grid[:, :-2, :-2, :-2] / self._voxel_d, 249 | vox[..., -1:], 250 | ], 251 | -1, 252 | ) 253 | -------------------------------------------------------------------------------- /agents/arm/qattention_agent.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | from typing import List 5 | 6 | import PIL 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torchvision import transforms 11 | 12 | from yarr.agents.agent import ( 13 | Agent, 14 | ActResult, 15 | ScalarSummary, 16 | HistogramSummary, 17 | ImageSummary, 18 | Summary, 19 | ) 20 | 21 | from helpers import utils 22 | from helpers.utils import stack_on_channel 23 | 24 | NAME = "QAttentionAgent" 25 | REPLAY_BETA = 1.0 26 | 27 | 28 | class QFunction(nn.Module): 29 | def __init__(self, unet: nn.Module): 30 | super(QFunction, self).__init__() 31 | self._qnet = copy.deepcopy(unet) 32 | self._qnet2 = copy.deepcopy(unet) 33 | self._qnet.build() 34 | self._qnet2.build() 35 | 36 | def _argmax_2d(self, tensor): 37 | t_shape = tensor.shape 38 | m = tensor.view(t_shape[0], -1).argmax(1).view(-1, 1) 39 | indices = torch.cat((m // t_shape[-1], m % t_shape[-1]), dim=1) 40 | return indices 41 | 42 | def forward(self, x, robot_state): 43 | q = self._qnet(x, robot_state)[:, 0] 44 | q2 = self._qnet2(x, robot_state)[:, 0] 45 | coords = self._argmax_2d(torch.min(q, q2)) 46 | 
return q, q2, coords 47 | 48 | 49 | class QAttentionAgent(Agent): 50 | def __init__( 51 | self, 52 | pixel_unet: nn.Module, 53 | camera_name: str, 54 | tau: float = 0.005, 55 | gamma: float = 0.99, 56 | nstep: int = 1, 57 | lr: float = 0.0001, 58 | weight_decay: float = 1e-5, 59 | lambda_qreg: float = 1e-6, 60 | grad_clip: float = 20.0, 61 | include_low_dim_state: bool = False, 62 | ): 63 | self._pixel_unet = pixel_unet 64 | self._camera_name = camera_name 65 | self._tau = tau 66 | self._gamma = gamma 67 | self._nstep = nstep 68 | self._lr = lr 69 | self._weight_decay = weight_decay 70 | self._lambda_qreg = lambda_qreg 71 | self._grad_clip = grad_clip 72 | self._include_low_dim_state = include_low_dim_state 73 | 74 | def build(self, training: bool, device: torch.device = None): 75 | if device is None: 76 | device = torch.device("cpu") 77 | self._q = QFunction(self._pixel_unet).to(device).train(training) 78 | self._q_target = None 79 | if training: 80 | self._q_target = QFunction(self._pixel_unet).to(device).train(False) 81 | for p in self._q_target.parameters(): 82 | p.requires_grad = False 83 | utils.soft_updates(self._q, self._q_target, 1.0) 84 | self._optimizer = torch.optim.Adam( 85 | self._q.parameters(), lr=self._lr, weight_decay=self._weight_decay 86 | ) 87 | logging.info( 88 | "# Q-attention Params: %d" 89 | % sum(p.numel() for p in self._q.parameters() if p.requires_grad) 90 | ) 91 | else: 92 | for p in self._q.parameters(): 93 | p.requires_grad = False 94 | self._device = device 95 | 96 | def _get_q_from_pixel_coord(self, q, coord): 97 | b, h, w = q.shape 98 | flat_indicies = (coord[:, 0] * w + coord[:, 1])[:, None].long() 99 | return q.view(b, h * w).gather(1, flat_indicies) 100 | 101 | def _preprocess_inputs(self, replay_sample): 102 | observations = [ 103 | stack_on_channel(replay_sample["%s_rgb" % self._camera_name]), 104 | stack_on_channel(replay_sample["%s_point_cloud" % self._camera_name]), 105 | ] 106 | tp1_observations = [ 107 | stack_on_channel(replay_sample["%s_rgb_tp1" % self._camera_name]), 108 | stack_on_channel(replay_sample["%s_point_cloud_tp1" % self._camera_name]), 109 | ] 110 | return observations, tp1_observations 111 | 112 | def update(self, step: int, replay_sample: dict) -> dict: 113 | pixel_action = replay_sample["%s_pixel_coord" % self._camera_name][:, -1].int() 114 | reward = replay_sample["reward"] 115 | reward = torch.where(reward > 0, reward, torch.zeros_like(reward)) 116 | 117 | robot_state = robot_state_tp1 = None 118 | if self._include_low_dim_state: 119 | robot_state = stack_on_channel(replay_sample["low_dim_state"]) 120 | robot_state_tp1 = stack_on_channel(replay_sample["low_dim_state_tp1"]) 121 | 122 | # Don't want timeouts to be classed as terminals 123 | terminal = replay_sample["terminal"].float() - replay_sample["timeout"].float() 124 | 125 | obs, obs_tp1 = self._preprocess_inputs(replay_sample) 126 | q, q2, coords = self._q(obs, robot_state) 127 | 128 | with torch.no_grad(): 129 | # (B, h, w) 130 | _, _, coords_tp1 = self._q(obs_tp1, robot_state_tp1) 131 | q_tp1_targ, q2_tp1_targ, _ = self._q_target(obs_tp1, robot_state_tp1) 132 | q_tp1_targ = torch.min(q_tp1_targ, q2_tp1_targ) 133 | q_tp1_targ = self._get_q_from_pixel_coord(q_tp1_targ, coords_tp1) 134 | target = ( 135 | reward.unsqueeze(1) 136 | + (self._gamma**self._nstep) 137 | * (1 - terminal.unsqueeze(1)) 138 | * q_tp1_targ 139 | ) 140 | target = torch.clamp(target, 0.0, 100.0) 141 | 142 | q_pred = self._get_q_from_pixel_coord(q, pixel_action) 143 | delta = F.smooth_l1_loss(q_pred, 
target, reduction="none").mean(1) 144 | 145 | delta += F.smooth_l1_loss( 146 | self._get_q_from_pixel_coord(q2, pixel_action), target, reduction="none" 147 | ).mean(1) 148 | q_reg = ( 149 | (0.5 * torch.sum(q**2)) + (0.5 * torch.sum(q2**2)) 150 | ) * self._lambda_qreg 151 | 152 | loss_weights = utils.loss_weights(replay_sample, REPLAY_BETA) 153 | total_loss = ((delta) * loss_weights).mean() + q_reg 154 | new_priority = ((delta) + 1e-10).sqrt() 155 | new_priority /= new_priority.max() 156 | 157 | self._summaries = { 158 | "losses/bellman": delta.mean(), 159 | "losses/qreg": q_reg.mean(), 160 | "q/mean": q.mean(), 161 | "q/action_q": q_pred.mean(), 162 | } 163 | self._qvalues = q[:1] 164 | self._rgb_observation = replay_sample["front_rgb"][0, -1] 165 | self._optimizer.zero_grad() 166 | total_loss.backward() 167 | if self._grad_clip is not None: 168 | nn.utils.clip_grad_value_(self._q.parameters(), self._grad_clip) 169 | self._optimizer.step() 170 | utils.soft_updates(self._q, self._q_target, self._tau) 171 | 172 | return { 173 | "priority": new_priority, 174 | } 175 | 176 | def act(self, step: int, observation: dict, deterministic=False) -> ActResult: 177 | with torch.no_grad(): 178 | observations = [ 179 | stack_on_channel(observation["%s_rgb" % self._camera_name]), 180 | stack_on_channel(observation["%s_point_cloud" % self._camera_name]), 181 | ] 182 | robot_state = None 183 | if self._include_low_dim_state: 184 | robot_state = stack_on_channel(observation["low_dim_state"]) 185 | # Coords are stored as (y, x) 186 | q, q2, coords = self._q(observations, robot_state) 187 | self._act_qvalues = torch.min(q, q2)[:1] 188 | self._rgb_observation = observation["front_rgb"][0, -1] 189 | return ActResult( 190 | coords[0], 191 | observation_elements={ 192 | "%s_pixel_coord" % self._camera_name: coords[0], 193 | }, 194 | info={"q_values": self._act_qvalues}, 195 | ) 196 | 197 | @staticmethod 198 | def generate_heatmap(q_values, rgb_obs): 199 | norm_q = torch.clamp(q_values / 100.0, 0, 1) 200 | heatmap = torch.cat( 201 | [norm_q, torch.zeros_like(norm_q), torch.zeros_like(norm_q)] 202 | ) 203 | img = transforms.functional.to_pil_image(rgb_obs) 204 | h_img = transforms.functional.to_pil_image(heatmap).convert("RGB") 205 | ret = PIL.Image.blend(img, h_img, 0.75) 206 | return transforms.ToTensor()(ret).unsqueeze_(0) 207 | 208 | def update_summaries(self) -> List[Summary]: 209 | summaries = [ 210 | ImageSummary( 211 | "%s/Q" % NAME, 212 | QAttentionAgent.generate_heatmap( 213 | self._qvalues.cpu(), ((self._rgb_observation + 1) / 2.0).cpu() 214 | ), 215 | ) 216 | ] 217 | for n, v in self._summaries.items(): 218 | summaries.append(ScalarSummary("%s/%s" % (NAME, n), v)) 219 | 220 | for tag, param in self._q.named_parameters(): 221 | assert not torch.isnan(param.grad.abs() <= 1.0).all() 222 | summaries.append( 223 | HistogramSummary("%s/gradient/%s" % (NAME, tag), param.grad) 224 | ) 225 | summaries.append(HistogramSummary("%s/weight/%s" % (NAME, tag), param.data)) 226 | return summaries 227 | 228 | def act_summaries(self) -> List[Summary]: 229 | return [ 230 | ImageSummary( 231 | "%s/Q_act" % NAME, 232 | QAttentionAgent.generate_heatmap( 233 | self._act_qvalues.cpu(), ((self._rgb_observation + 1) / 2.0).cpu() 234 | ), 235 | ) 236 | ] 237 | 238 | def load_weights(self, savedir: str): 239 | self._q.load_state_dict( 240 | torch.load( 241 | os.path.join(savedir, "pixel_agent_q.pt"), 242 | map_location=torch.device("cpu"), 243 | ) 244 | ) 245 | 246 | def save_weights(self, savedir: str): 247 | 
torch.save(self._q.state_dict(), os.path.join(savedir, "pixel_agent_q.pt")) 248 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | import os 4 | import sys 5 | 6 | import peract_config 7 | 8 | import hydra 9 | import numpy as np 10 | import torch 11 | import pandas as pd 12 | from omegaconf import DictConfig, OmegaConf, ListConfig 13 | from rlbench.action_modes.action_mode import BimanualMoveArmThenGripper 14 | from rlbench.action_modes.action_mode import BimanualJointPositionActionMode 15 | from rlbench.action_modes.arm_action_modes import BimanualEndEffectorPoseViaPlanning 16 | from rlbench.action_modes.arm_action_modes import BimanualJointPosition, JointPosition 17 | from rlbench.action_modes.gripper_action_modes import BimanualDiscrete 18 | from rlbench.action_modes.action_mode import MoveArmThenGripper 19 | from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaPlanning 20 | from rlbench.action_modes.gripper_action_modes import Discrete 21 | 22 | from rlbench.backend import task as rlbench_task 23 | from rlbench.backend.utils import task_file_to_task_class 24 | from yarr.runners.independent_env_runner import IndependentEnvRunner 25 | from yarr.utils.stat_accumulator import SimpleAccumulator 26 | 27 | from helpers import utils 28 | from helpers import observation_utils 29 | 30 | from yarr.utils.rollout_generator import RolloutGenerator 31 | import torch.multiprocessing as mp 32 | 33 | from agents import agent_factory 34 | 35 | 36 | def eval_seed( 37 | train_cfg, eval_cfg, logdir, env_device, multi_task, seed, env_config 38 | ) -> None: 39 | tasks = eval_cfg.rlbench.tasks 40 | rg = RolloutGenerator() 41 | 42 | train_cfg.method.robot_name = eval_cfg.method.robot_name 43 | 44 | agent = agent_factory.create_agent(train_cfg) 45 | stat_accum = SimpleAccumulator(eval_video_fps=30) 46 | 47 | cwd = os.getcwd() 48 | weightsdir = os.path.join(logdir, "weights") 49 | 50 | env_runner = IndependentEnvRunner( 51 | train_env=None, 52 | agent=agent, 53 | train_replay_buffer=None, 54 | num_train_envs=0, 55 | num_eval_envs=eval_cfg.framework.eval_envs, 56 | rollout_episodes=99999, 57 | eval_episodes=eval_cfg.framework.eval_episodes, 58 | training_iterations=train_cfg.framework.training_iterations, 59 | eval_from_eps_number=eval_cfg.framework.eval_from_eps_number, 60 | episode_length=eval_cfg.rlbench.episode_length, 61 | stat_accumulator=stat_accum, 62 | weightsdir=weightsdir, 63 | logdir=logdir, 64 | env_device=env_device, 65 | rollout_generator=rg, 66 | num_eval_runs=len(tasks), 67 | multi_task=multi_task, 68 | ) 69 | 70 | env_runner._on_thread_start = peract_config.config_logging 71 | 72 | manager = mp.Manager() 73 | save_load_lock = manager.Lock() 74 | writer_lock = manager.Lock() 75 | 76 | # evaluate all checkpoints (0, 1000, ...) which don't have results, i.e. 
validation phase 77 | if eval_cfg.framework.eval_type == "missing": 78 | weight_folders = os.listdir(weightsdir) 79 | weight_folders = sorted(map(int, weight_folders)) 80 | 81 | env_data_csv_file = os.path.join(logdir, "eval_data.csv") 82 | if os.path.exists(env_data_csv_file): 83 | env_dict = pd.read_csv(env_data_csv_file).to_dict() 84 | evaluated_weights = sorted(map(int, list(env_dict["step"].values()))) 85 | weight_folders = [w for w in weight_folders if w not in evaluated_weights] 86 | 87 | print("Missing weights: ", weight_folders) 88 | 89 | # pick the best checkpoint from validation and evaluate, i.e. test phase 90 | elif eval_cfg.framework.eval_type == "best": 91 | env_data_csv_file = os.path.join(logdir, "eval_data.csv") 92 | if os.path.exists(env_data_csv_file): 93 | env_dict = pd.read_csv(env_data_csv_file).to_dict() 94 | existing_weights = list( 95 | map(int, sorted(os.listdir(os.path.join(logdir, "weights")))) 96 | ) 97 | task_weights = {} 98 | for task in tasks: 99 | weights = list(env_dict["step"].values()) 100 | 101 | if len(tasks) > 1: 102 | task_score = list(env_dict["eval_envs/return/%s" % task].values()) 103 | else: 104 | task_score = list(env_dict["eval_envs/return"].values()) 105 | 106 | avail_weights, avail_task_scores = [], [] 107 | for step_idx, step in enumerate(weights): 108 | if step in existing_weights: 109 | avail_weights.append(step) 110 | avail_task_scores.append(task_score[step_idx]) 111 | 112 | assert len(avail_weights) == len(avail_task_scores) 113 | best_weight = avail_weights[ 114 | np.argwhere(avail_task_scores == np.amax(avail_task_scores)) 115 | .flatten() 116 | .tolist()[-1] 117 | ] 118 | task_weights[task] = best_weight 119 | 120 | weight_folders = [task_weights] 121 | print("Best weights:", weight_folders) 122 | else: 123 | raise Exception("No existing eval_data.csv file found in %s" % logdir) 124 | 125 | # evaluate only the last checkpoint 126 | elif eval_cfg.framework.eval_type == "last": 127 | weight_folders = os.listdir(weightsdir) 128 | weight_folders = sorted(map(int, weight_folders)) 129 | weight_folders = [weight_folders[-1]] 130 | print("Last weight:", weight_folders) 131 | 132 | elif eval_cfg.framework.eval_type == "all": 133 | weight_folders = os.listdir(weightsdir) 134 | weight_folders = sorted(map(int, weight_folders)) 135 | 136 | # evaluate a specific checkpoint 137 | elif type(eval_cfg.framework.eval_type) == int: 138 | weight_folders = [int(eval_cfg.framework.eval_type)] 139 | print("Weight:", weight_folders) 140 | 141 | else: 142 | raise Exception("Unknown eval type") 143 | 144 | if len(weight_folders) == 0: 145 | logging.info( 146 | "No weights to evaluate. Results are already available in eval_data.csv" 147 | ) 148 | sys.exit(0) 149 | 150 | # evaluate several checkpoints in parallel 151 | # NOTE: in multi-task settings, each task is evaluated serially, which makes everything slow! 
152 | split_n = utils.split_list(weight_folders, eval_cfg.framework.eval_envs) 153 | for split in split_n: 154 | processes = [] 155 | for e_idx, weight in enumerate(split): 156 | p = mp.Process( 157 | target=env_runner.start, 158 | args=( 159 | weight, 160 | save_load_lock, 161 | writer_lock, 162 | env_config, 163 | e_idx % torch.cuda.device_count(), 164 | eval_cfg.framework.eval_save_metrics, 165 | eval_cfg.cinematic_recorder, 166 | ), 167 | ) 168 | p.start() 169 | processes.append(p) 170 | for p in processes: 171 | p.join() 172 | 173 | del env_runner 174 | del agent 175 | gc.collect() 176 | torch.cuda.empty_cache() 177 | 178 | 179 | @hydra.main(config_name="eval", config_path="conf") 180 | def main(eval_cfg: DictConfig) -> None: 181 | logging.info("\n" + OmegaConf.to_yaml(eval_cfg)) 182 | 183 | start_seed = eval_cfg.framework.start_seed 184 | logdir = os.path.join( 185 | eval_cfg.framework.logdir, 186 | eval_cfg.rlbench.task_name, 187 | eval_cfg.method.name, 188 | "seed%d" % start_seed, 189 | ) 190 | 191 | train_config_path = os.path.join(logdir, "config.yaml") 192 | 193 | if os.path.exists(train_config_path): 194 | with open(train_config_path, "r") as f: 195 | train_cfg = OmegaConf.load(f) 196 | else: 197 | raise Exception(f"Missing seed{start_seed}/config.yaml. Logdir is {logdir}") 198 | 199 | # sanity checks 200 | assert train_cfg.method.name == eval_cfg.method.name 201 | assert train_cfg.method.agent_type == eval_cfg.method.agent_type 202 | for task in eval_cfg.rlbench.tasks: 203 | assert task in train_cfg.rlbench.tasks 204 | 205 | env_device = utils.get_device(eval_cfg.framework.gpu) 206 | logging.info("Using env device %s." % str(env_device)) 207 | 208 | gripper_mode = eval(eval_cfg.rlbench.gripper_mode)() 209 | arm_action_mode = eval(eval_cfg.rlbench.arm_action_mode)() 210 | action_mode = eval(eval_cfg.rlbench.action_mode)(arm_action_mode, gripper_mode) 211 | 212 | is_bimanual = eval_cfg.method.robot_name == "bimanual" 213 | 214 | if is_bimanual: 215 | # TODO: automate instantiation with eval 216 | task_path = rlbench_task.BIMANUAL_TASKS_PATH 217 | else: 218 | task_path = rlbench_task.TASKS_PATH 219 | 220 | task_files = [ 221 | t.replace(".py", "") 222 | for t in os.listdir(task_path) 223 | if t != "__init__.py" and t.endswith(".py") 224 | ] 225 | eval_cfg.rlbench.cameras = ( 226 | eval_cfg.rlbench.cameras 227 | if isinstance(eval_cfg.rlbench.cameras, ListConfig) 228 | else [eval_cfg.rlbench.cameras] 229 | ) 230 | obs_config = observation_utils.create_obs_config( 231 | eval_cfg.rlbench.cameras, 232 | eval_cfg.rlbench.camera_resolution, 233 | eval_cfg.method.name, 234 | eval_cfg.method.robot_name, 235 | ) 236 | 237 | if eval_cfg.cinematic_recorder.enabled: 238 | obs_config.record_gripper_closing = True 239 | 240 | multi_task = len(eval_cfg.rlbench.tasks) > 1 241 | 242 | tasks = eval_cfg.rlbench.tasks 243 | task_classes = [] 244 | for task in tasks: 245 | if task not in task_files: 246 | raise ValueError("Task %s not recognised!." 
% task) 247 | task_classes.append(task_file_to_task_class(task, is_bimanual)) 248 | 249 | # single-task or multi-task 250 | if multi_task: 251 | env_config = ( 252 | task_classes, 253 | obs_config, 254 | action_mode, 255 | eval_cfg.rlbench.demo_path, 256 | eval_cfg.rlbench.episode_length, 257 | eval_cfg.rlbench.headless, 258 | eval_cfg.framework.eval_episodes, 259 | train_cfg.rlbench.include_lang_goal_in_obs, 260 | eval_cfg.rlbench.time_in_state, 261 | eval_cfg.framework.record_every_n, 262 | ) 263 | else: 264 | env_config = ( 265 | task_classes[0], 266 | obs_config, 267 | action_mode, 268 | eval_cfg.rlbench.demo_path, 269 | eval_cfg.rlbench.episode_length, 270 | eval_cfg.rlbench.headless, 271 | train_cfg.rlbench.include_lang_goal_in_obs, 272 | eval_cfg.rlbench.time_in_state, 273 | eval_cfg.framework.record_every_n, 274 | ) 275 | 276 | logging.info("Evaluating seed %d." % start_seed) 277 | eval_seed( 278 | train_cfg, 279 | eval_cfg, 280 | logdir, 281 | env_device, 282 | multi_task, 283 | start_seed, 284 | env_config, 285 | ) 286 | 287 | 288 | if __name__ == "__main__": 289 | peract_config.on_init() 290 | main() 291 | -------------------------------------------------------------------------------- /agents/c2farm_lingunet_bc/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from helpers.network_utils import ( 5 | Conv3DInceptionBlock, 6 | DenseBlock, 7 | SpatialSoftmax3D, 8 | Conv3DInceptionBlockUpsampleBlock, 9 | Conv3DBlock, 10 | ) 11 | 12 | 13 | class QattentionLingU3DNet(nn.Module): 14 | def __init__( 15 | self, 16 | in_channels: int, 17 | out_channels: int, 18 | out_dense: int, 19 | voxel_size: int, 20 | low_dim_size: int, 21 | kernels: int, 22 | norm: str = None, 23 | activation: str = "relu", 24 | dense_feats: int = 32, 25 | include_prev_layer=False, 26 | depth=0, 27 | lingunet_dropout=0.0, 28 | ): 29 | super(QattentionLingU3DNet, self).__init__() 30 | self._in_channels = in_channels 31 | self._out_channels = out_channels 32 | self._norm = norm 33 | self._activation = activation 34 | self._kernels = kernels 35 | self._low_dim_size = low_dim_size 36 | self._build_calls = 0 37 | self._voxel_size = voxel_size 38 | self._dense_feats = dense_feats 39 | self._out_dense = out_dense 40 | self._include_prev_layer = include_prev_layer 41 | self._depth = depth 42 | 43 | self._lingunet_dropout = lingunet_dropout 44 | self._clip_lang_feat_dim = 1024 45 | 46 | if self._voxel_size < 16: 47 | raise Exception( 48 | "Voxel size for C2FARM_LINGUNET_BC should be at least 16 or higher" 49 | ) 50 | 51 | def build(self): 52 | use_residual = False 53 | self._build_calls += 1 54 | if self._build_calls != 1: 55 | raise RuntimeError("Build needs to be called once.") 56 | 57 | spatial_size = self._voxel_size 58 | self._input_preprocess = Conv3DInceptionBlock( 59 | self._in_channels, 60 | self._kernels, 61 | norm=self._norm, 62 | activation=self._activation, 63 | ) 64 | 65 | d0_ins = self._input_preprocess.out_channels 66 | if self._include_prev_layer: 67 | PREV_VOXEL_CHANNELS = 0 68 | d0_ins += self._input_preprocess.out_channels * self._depth 69 | 70 | if self._low_dim_size > 0: 71 | self._proprio_preprocess = DenseBlock( 72 | self._low_dim_size, self._kernels, None, self._activation 73 | ) 74 | d0_ins += self._kernels 75 | 76 | self._down0 = Conv3DInceptionBlock( 77 | d0_ins, 78 | self._kernels, 79 | norm=self._norm, 80 | activation=self._activation, 81 | residual=use_residual, 82 | ) 83 | self._ss0 = 
SpatialSoftmax3D( 84 | spatial_size, spatial_size, spatial_size, self._down0.out_channels 85 | ) 86 | spatial_size //= 2 87 | self._down1 = Conv3DInceptionBlock( 88 | self._down0.out_channels, 89 | self._kernels * 2, 90 | norm=self._norm, 91 | activation=self._activation, 92 | residual=use_residual, 93 | ) 94 | self._ss1 = SpatialSoftmax3D( 95 | spatial_size, spatial_size, spatial_size, self._down1.out_channels 96 | ) 97 | spatial_size //= 2 98 | 99 | flat_size = self._down0.out_channels * 4 + self._down1.out_channels * 4 100 | 101 | k1 = self._down1.out_channels 102 | if self._voxel_size > 8: 103 | k1 += self._kernels 104 | self._down2 = Conv3DInceptionBlock( 105 | self._down1.out_channels, 106 | self._kernels * 4, 107 | norm=self._norm, 108 | activation=self._activation, 109 | residual=use_residual, 110 | ) 111 | self._lang_proj2 = DenseBlock( 112 | self._clip_lang_feat_dim, self._down2.out_channels, None, None 113 | ) 114 | self._dropout2 = nn.Dropout(self._lingunet_dropout) 115 | flat_size += self._down2.out_channels * 4 116 | self._ss2 = SpatialSoftmax3D( 117 | spatial_size, spatial_size, spatial_size, self._down2.out_channels 118 | ) 119 | spatial_size //= 2 120 | k2 = self._down2.out_channels 121 | if self._voxel_size > 16: 122 | k2 *= 2 123 | self._down3 = Conv3DInceptionBlock( 124 | self._down2.out_channels, 125 | self._kernels, 126 | norm=self._norm, 127 | activation=self._activation, 128 | residual=use_residual, 129 | ) 130 | self._lang_proj3 = DenseBlock( 131 | self._clip_lang_feat_dim, self._down3.out_channels, None, None 132 | ) 133 | self._dropout3 = nn.Dropout(self._lingunet_dropout) 134 | flat_size += self._down3.out_channels * 4 135 | self._ss3 = SpatialSoftmax3D( 136 | spatial_size, spatial_size, spatial_size, self._down3.out_channels 137 | ) 138 | self._up3 = Conv3DInceptionBlockUpsampleBlock( 139 | self._kernels, 140 | self._kernels * 4, 141 | 2, 142 | norm=self._norm, 143 | activation=self._activation, 144 | residual=use_residual, 145 | ) 146 | self._up2 = Conv3DInceptionBlockUpsampleBlock( 147 | k2, 148 | self._kernels, 149 | 2, 150 | norm=self._norm, 151 | activation=self._activation, 152 | residual=use_residual, 153 | ) 154 | 155 | self._up1 = Conv3DInceptionBlockUpsampleBlock( 156 | k1, 157 | self._kernels, 158 | 2, 159 | norm=self._norm, 160 | activation=self._activation, 161 | residual=use_residual, 162 | ) 163 | 164 | self._global_maxp = nn.AdaptiveMaxPool3d(1) 165 | self._local_maxp = nn.MaxPool3d(3, 2, padding=1) 166 | self._final = Conv3DBlock( 167 | self._kernels * 2, 168 | self._kernels, 169 | kernel_sizes=3, 170 | strides=1, 171 | norm=self._norm, 172 | activation=self._activation, 173 | ) 174 | self._final2 = Conv3DBlock( 175 | self._kernels, 176 | self._out_channels, 177 | kernel_sizes=3, 178 | strides=1, 179 | norm=None, 180 | activation=None, 181 | ) 182 | 183 | self._ss_final = SpatialSoftmax3D( 184 | self._voxel_size, self._voxel_size, self._voxel_size, self._kernels 185 | ) 186 | flat_size += self._kernels * 4 187 | 188 | if self._out_dense > 0: 189 | self._dense0 = DenseBlock( 190 | flat_size, self._dense_feats, None, self._activation 191 | ) 192 | self._dense1 = DenseBlock( 193 | self._dense_feats, self._dense_feats, None, self._activation 194 | ) 195 | self._dense2 = DenseBlock(self._dense_feats, self._out_dense, None, None) 196 | 197 | def _proj_feature(self, x, spatial_size, proj_fn): 198 | x = proj_fn(x) 199 | x = x.unsqueeze(2).unsqueeze(3).unsqueeze(4) 200 | x = x.repeat(1, 1, spatial_size, spatial_size, spatial_size) 201 | return x 202 | 
203 | def forward( 204 | self, 205 | ins, 206 | proprio, 207 | lang_goal_embs, 208 | lang_token_embs, 209 | bounds, 210 | prev_bounds, 211 | prev_layer_voxel_grid, 212 | ): 213 | b, _, d, h, w = ins.shape 214 | x = self._input_preprocess(ins) 215 | 216 | if self._include_prev_layer: 217 | for voxel_grid in prev_layer_voxel_grid: 218 | y = self._input_preprocess(voxel_grid) 219 | x = torch.cat([x, y], dim=1) 220 | 221 | if self._low_dim_size > 0: 222 | p = self._proprio_preprocess(proprio) 223 | p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w) 224 | x = torch.cat([x, p], dim=1) 225 | 226 | l_feat = lang_goal_embs 227 | l_feat = l_feat.to(dtype=x.dtype) 228 | 229 | d0 = self._down0(x) 230 | # l0 = self._proj_feature(l_feat, d0.shape[-1], self._lang_proj0) 231 | # d0 = self._dropout0(d0 * l0) 232 | ss0 = self._ss0(d0) 233 | maxp0 = self._global_maxp(d0).view(b, -1) 234 | 235 | d1 = u = self._down1(self._local_maxp(d0)) 236 | # l1 = self._proj_feature(l_feat, d1.shape[-1], self._lang_proj1) 237 | # d1 = self._dropout1(d1 * l1) 238 | ss1 = self._ss1(d1) 239 | maxp1 = self._global_maxp(d1).view(b, -1) 240 | 241 | feats = [ss0, maxp0, ss1, maxp1] 242 | 243 | if self._voxel_size > 8: 244 | d2 = u = self._down2(self._local_maxp(d1)) 245 | l2 = self._proj_feature(l_feat, d2.shape[-1], self._lang_proj2) 246 | d2 = self._dropout2(d2 * l2) 247 | feats.extend([self._ss2(d2), self._global_maxp(d2).view(b, -1)]) 248 | if self._voxel_size > 16: 249 | d3 = self._down3(self._local_maxp(d2)) 250 | l3 = self._proj_feature(l_feat, d3.shape[-1], self._lang_proj3) 251 | d3 = self._dropout3(d3 * l3) 252 | feats.extend([self._ss3(d3), self._global_maxp(d3).view(b, -1)]) 253 | u3 = self._up3(d3) 254 | u = torch.cat([d2, u3], dim=1) 255 | u2 = self._up2(u) 256 | u = torch.cat([d1, u2], dim=1) 257 | 258 | u1 = self._up1(u) 259 | f1 = self._final(torch.cat([d0, u1], dim=1)) 260 | trans = self._final2(f1) 261 | 262 | feats.extend([self._ss_final(f1), self._global_maxp(f1).view(b, -1)]) 263 | 264 | self.latent_dict = { 265 | "d0": d0.mean(-1).mean(-1).mean(-1), 266 | "d1": d1.mean(-1).mean(-1).mean(-1), 267 | "u1": u1.mean(-1).mean(-1).mean(-1), 268 | "trans_out": trans, 269 | } 270 | 271 | rot_and_grip_out, collision_out = None, None 272 | if self._out_dense > 0: 273 | dense0 = self._dense0(torch.cat(feats, 1)) 274 | dense1 = self._dense1(dense0) 275 | rot_and_grip_collision_out = self._dense2(dense1) 276 | rot_and_grip_out = rot_and_grip_collision_out[:, :-2] 277 | collision_out = rot_and_grip_collision_out[:, -2:] 278 | self.latent_dict.update( 279 | { 280 | "dense0": dense0, 281 | "dense1": dense1, 282 | "dense2": rot_and_grip_collision_out, 283 | } 284 | ) 285 | 286 | if self._voxel_size > 8: 287 | self.latent_dict.update( 288 | { 289 | "d2": d2.mean(-1).mean(-1).mean(-1), 290 | "u2": u2.mean(-1).mean(-1).mean(-1), 291 | } 292 | ) 293 | if self._voxel_size > 16: 294 | self.latent_dict.update( 295 | { 296 | "d3": d3.mean(-1).mean(-1).mean(-1), 297 | "u3": u3.mean(-1).mean(-1).mean(-1), 298 | } 299 | ) 300 | 301 | return trans, rot_and_grip_out, collision_out 302 | -------------------------------------------------------------------------------- /ARM_LICENSE: -------------------------------------------------------------------------------- 1 | Q-attention: Enabling Efficient Learning for Vision-based Robotic Manipulation 2 | 3 | LICENCE AGREEMENT 4 | 5 | WE (Imperial College of Science, Technology and Medicine, (“Imperial College London”)) 6 | ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee 
“You”) ONLY ON THE 7 | CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE 8 | FOLLOWING AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE 9 | DOWNLOADING THE SOFTWARE. BY EXERCISING THE OPTION TO DOWNLOAD 10 | THE SOFTWARE YOU AGREE TO BE BOUND BY THE TERMS OF THE AGREEMENT. 11 | SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS) 12 | 13 | 1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully paid-up, royalty 14 | free, non-transferable, non-sub- licensable licence (the “Licence”) to use the Q-attention 15 | source code, including any modification, part or derivative (the “Software”). 16 | Ownership and Licence. Your rights to use and download the Software onto your computer, 17 | and all other copies that You are authorised to make, are specified in this Agreement. 18 | However, we (or our licensors) retain all rights, including but not limited to all copyright and 19 | other intellectual property rights anywhere in the world, in the Software not expressly 20 | granted to You in this Agreement. 21 | 22 | 2. Permitted use of the Licence: 23 | 24 | (a) You may download and install the Software onto one computer or server for use in 25 | accordance with Clause 2(b) of this Agreement provided that You ensure that the Software is 26 | not accessible by other users unless they have themselves accepted the terms of this licence 27 | agreement. 28 | 29 | (b) You may use the Software solely for non-commercial, internal or academic research 30 | purposes and only in accordance with the terms of this Agreement. You may not use the 31 | Software for commercial purposes, including but not limited to (1) integration of all or part of 32 | the source code or the Software into a product for sale or licence by or on behalf of You to 33 | third parties or (2) use of the Software or any derivative of it for research to develop software 34 | products for sale or licence to a third party or (3) use of the Software or any derivative of it 35 | for research to develop non-software products for sale or licence to a third party, or (4) use of 36 | the Software to provide any service to an external organisation for which payment is 37 | received. 38 | 39 | Should You wish to use the Software for commercial purposes, You shall 40 | email researchcontracts.engineering@imperial.ac.uk . 41 | 42 | (c) Right to Copy. You may copy the Software for back-up and archival purposes, provided 43 | that each copy is kept in your possession and provided You reproduce our copyright notice 44 | (set out in Schedule 1) on each copy. 45 | 46 | (d) Transfer and sub-licensing. You may not rent, lend, or lease the Software and You may 47 | not transmit, transfer or sub-license this licence to use the Software or any of your rights or 48 | obligations under this Agreement to another party. 49 | 50 | (e) Identity of Licensee. The licence granted herein is personal to You. You shall not permit 51 | any third party to access, modify or otherwise use the Software nor shall You access modify 52 | or otherwise use the Software on behalf of any third party. If You wish to obtain a licence for 53 | mutiple users or a site licence for the Software please contact us 54 | at researchcontracts.engineering@imperial.ac.uk . 55 | 56 | (f) Publications and presentations. 
You may make public, results or data obtained from, 57 | dependent on or arising from research carried out using the Software, provided that any such 58 | presentation or publication identifies the Software as the source of the results or the data, 59 | including the Copyright Notice given in each element of the Software, and stating that the 60 | Software has been made available for use by You under licence from Imperial College London 61 | and You provide a copy of any such publication to Imperial College London. 62 | 63 | 3. Prohibited Uses. You may not, without written permission from us 64 | at researchcontracts.engineering@imperial.ac.uk : 65 | 66 | (a) Use, copy, modify, merge, or transfer copies of the Software or any documentation 67 | provided by us which relates to the Software except as provided in this Agreement; 68 | 69 | (b) Use any back-up or archival copies of the Software (or allow anyone else to use such 70 | copies) for any purpose other than to replace the original copy in the event it is destroyed or 71 | becomes defective; or 72 | 73 | (c) Disassemble, decompile or "unlock", reverse translate, or in any manner decode the 74 | Software for any reason. 75 | 76 | 4. Warranty Disclaimer 77 | 78 | (a) Disclaimer. The Software has been developed for research purposes only. You 79 | acknowledge that we are providing the Software to You under this licence agreement free of 80 | charge and on condition that the disclaimer set out below shall apply. We do not represent or 81 | warrant that the Software as to: (i) the quality, accuracy or reliability of the Software; (ii) the 82 | suitability of the Software for any particular use or for use under any specific conditions; and 83 | (iii) whether use of the Software will infringe third-party rights. 84 | You acknowledge that You have reviewed and evaluated the Software to determine that it 85 | meets your needs and that You assume all responsibility and liability for determining the 86 | suitability of the Software as fit for your particular purposes and requirements. Subject to 87 | Clause 4(b), we exclude and expressly disclaim all express and implied representations, 88 | warranties, conditions and terms not stated herein (including the implied conditions or 89 | warranties of satisfactory quality, merchantable quality, merchantability and fitness for 90 | purpose). 91 | 92 | (b) Savings. Some jurisdictions may imply warranties, conditions or terms or impose 93 | obligations upon us which cannot, in whole or in part, be excluded, restricted or modified or 94 | otherwise do not allow the exclusion of implied warranties, conditions or terms, in which 95 | case the above warranty disclaimer and exclusion will only apply to You to the extent 96 | permitted in the relevant jurisdiction and does not in any event exclude any implied 97 | warranties, conditions or terms which may not under applicable law be excluded. 98 | 99 | (c) Imperial College London disclaims all responsibility for the use which is made of the 100 | Software and any liability for the outcomes arising from using the Software. 101 | 102 | 5. Limitation of Liability 103 | 104 | (a) You acknowledge that we are providing the Software to You under this licence agreement 105 | free of charge and on condition that the limitation of liability set out below shall apply. 
106 | Accordingly, subject to Clause 5(b), we exclude all liability whether in contract, tort, 107 | negligence or otherwise, in respect of the Software and/or any related documentation 108 | provided to You by us including, but not limited to, liability for loss or corruption of data, 109 | loss of contracts, loss of income, loss of profits, loss of cover and any consequential or indirect 110 | loss or damage of any kind arising out of or in connection with this licence agreement, 111 | however caused. This exclusion shall apply even if we have been advised of the possibility of 112 | such loss or damage. 113 | 114 | (b) You agree to indemnify Imperial College London and hold it harmless from and against 115 | any and all claims, damages and liabilities asserted by third parties (including claims for 116 | negligence) which arise directly or indirectly from the use of the Software or any derivative 117 | of it or the sale of any products based on the Software. You undertake to make no liability 118 | claim against any employee, student, agent or appointee of Imperial College London, in 119 | connection with this Licence or the Software. 120 | 121 | (c) Nothing in this Agreement shall have the effect of excluding or limiting our statutory 122 | liability. 123 | 124 | (d) Some jurisdictions do not allow these limitations or exclusions either wholly or in part, 125 | and, to that extent, they may not apply to you. Nothing in this licence agreement will affect 126 | your statutory rights or other relevant statutory provisions which cannot be excluded, 127 | restricted or modified, and its terms and conditions must be read and construed subject to any 128 | such statutory rights and/or provisions. 129 | 130 | 6. Confidentiality. You agree not to disclose any confidential information provided to You by 131 | us pursuant to this Agreement to any third party without our prior written consent. The 132 | obligations in this Clause 6 shall survive the termination of this Agreement for any reason. 133 | 134 | 7. Termination. 135 | 136 | (a) We may terminate this licence agreement and your right to use the Software at any time 137 | with immediate effect upon written notice to You. 138 | 139 | (b) This licence agreement and your right to use the Software automatically terminate if You: 140 | (i) fail to comply with any provisions of this Agreement; or 141 | (ii) destroy the copies of the Software in your possession, or voluntarily return the Software 142 | to us. 143 | 144 | (c) Upon termination You will destroy all copies of the Software. 145 | 146 | (d) Otherwise, the restrictions on your rights to use the Software will expire 10 (ten) years 147 | after first use of the Software under this licence agreement. 148 | 149 | 8. Miscellaneous Provisions. 150 | 151 | (a) This Agreement will be governed by and construed in accordance with the substantive 152 | laws of England and Wales whose courts shall have exclusive jurisdiction over all disputes 153 | which may arise between us. 154 | 155 | (b) This is the entire agreement between us relating to the Software, and supersedes any prior 156 | purchase order, communications, advertising or representations concerning the Software. 157 | 158 | (c) No change or modification of this Agreement will be valid unless it is in writing, and is 159 | signed by us. 160 | 161 | (d) The unenforceability or invalidity of any part of this Agreement will not affect the 162 | enforceability or validity of the remaining parts. 
163 | 164 | BSD Elements of the Software 165 | 166 | For BSD elements of the Software, the following terms shall apply: 167 | 168 | Copyright as indicated in the header of the individual element of the Software. 169 | 170 | All rights reserved. 171 | 172 | Redistribution and use in source and binary forms, with or without modification, are 173 | permitted provided that the following conditions are met: 174 | 175 | 1. Redistributions of source code must retain the above copyright notice, this list of 176 | conditions and the following disclaimer. 177 | 178 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of 179 | conditions and the following disclaimer in the documentation and/or other materials 180 | provided with the distribution. 181 | 182 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to 183 | endorse or promote products derived from this software without specific prior written 184 | permission. 185 | 186 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 187 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 188 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 189 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 190 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 191 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 192 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 193 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 194 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 195 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 196 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 197 | --------------------------------------------------------------------------------
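eval.py above is a Hydra application (@hydra.main(config_name="eval", config_path="conf")), so evaluation is normally launched as "python eval.py key=value ...", overriding the keys the script reads: framework.eval_type selects which checkpoints to evaluate ("missing" for checkpoints without results, "best" for the highest-scoring validated checkpoint, "last", "all", or a specific training step given as an integer), framework.logdir and framework.start_seed locate the seed<N>/config.yaml written at training time, and rlbench.tasks must be a subset of the tasks the run was trained on. The sketch below composes the same config programmatically with Hydra's compose API; it assumes Hydra >= 1.2 and that conf/eval.yaml defines these keys, and every override value is an illustrative placeholder rather than a value taken from this repository.

from hydra import compose, initialize
from omegaconf import OmegaConf

# version_base was added in Hydra 1.2; drop the argument on older releases.
with initialize(config_path="conf", version_base=None):
    eval_cfg = compose(
        config_name="eval",
        overrides=[
            "framework.eval_type=last",      # or: missing / best / all / <training step>
            "framework.logdir=/tmp/peract_logs",
            "framework.start_seed=0",
            "rlbench.headless=True",
        ],
    )
    # Inspect the resolved config; this mirrors what @hydra.main passes into main().
    print(OmegaConf.to_yaml(eval_cfg))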
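The QattentionLingU3DNet in agents/c2farm_lingunet_bc/networks.py expects a voxelised observation of shape (B, in_channels, V, V, V) with V >= 16, a (B, low_dim_size) proprioception vector whenever low_dim_size > 0, and a (B, 1024) CLIP goal embedding; lang_token_embs, bounds and prev_bounds are accepted by forward() but unused, and build() must be called exactly once before the first forward pass. A minimal shape sketch follows, assuming the repo's helpers.network_utils is importable and using illustrative sizes (the batch size, kernel count and channel count below are placeholders, not values from the shipped configs).

import torch
from agents.c2farm_lingunet_bc.networks import QattentionLingU3DNet

B, V, IN_CH, LOW_DIM = 2, 16, 10, 4      # illustrative sizes only

net = QattentionLingU3DNet(
    in_channels=IN_CH,
    out_channels=1,        # one translation Q-value per voxel
    out_dense=0,           # > 0 would add the rotation/gripper + collision head
    voxel_size=V,          # must be >= 16 (checked in __init__)
    low_dim_size=LOW_DIM,
    kernels=64,
)
net.build()                # build() may only be called once

ins = torch.randn(B, IN_CH, V, V, V)     # voxel grid
proprio = torch.randn(B, LOW_DIM)        # low-dim robot state
lang_goal_embs = torch.randn(B, 1024)    # CLIP goal embedding (dim 1024)

trans, rot_grip, collision = net(
    ins, proprio, lang_goal_embs,
    lang_token_embs=None, bounds=None, prev_bounds=None,
    prev_layer_voxel_grid=None,          # only consumed when include_prev_layer=True
)
# trans: (B, 1, V, V, V); rot_grip and collision are None because out_dense == 0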