├── train
├── vint_train
│ ├── __init__.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── data_config.yaml
│ │ ├── data_utils.py
│ │ └── vint_dataset.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── nomad
│ │ │ ├── __init__.py
│ │ │ ├── vib_placeholder.py
│ │ │ ├── README.md
│ │ │ ├── nomad.py
│ │ │ └── nomad_vint.py
│ │ └── base_model.py
│ ├── training
│ │ ├── __init__.py
│ │ ├── logger.py
│ │ └── train_eval_loop.py
│ ├── process_data
│ │ ├── __init__.py
│ │ ├── process_bags_config.yaml
│ │ └── process_data_utils.py
│ └── visualizing
│ │ ├── __init__.py
│ │ ├── visualize_utils.py
│ │ ├── distance_utils.py
│ │ └── action_utils.py
├── setup.py
├── train_environment.yml
├── config
│ ├── defaults.yaml
│ └── nomad.yaml
├── process_recon.py
├── data_split.py
├── process_bags.py
├── process_bag_diff.py
└── train.py
├── deployment
├── config
│ ├── params.yaml
│ ├── joystick.yaml
│ ├── camera_front.yaml
│ ├── camera_reverse.yaml
│ ├── models.yaml
│ ├── robot.yaml
│ └── cmd_vel_mux.yaml
├── src
│ ├── deployment_environment.yml
│ ├── joy_teleop.sh
│ ├── record_bag.sh
│ ├── create_topomap.sh
│ ├── vint_locobot.launch
│ ├── ros_data.py
│ ├── topic_names.py
│ ├── navigate.sh
│ ├── joy_teleop.py
│ ├── create_topomap.py
│ ├── pd_controller.py
│ ├── utils.py
│ ├── costmap_cfg.py
│ ├── cost_to_pcd.py
│ ├── tsdf_cost_map.py
│ ├── guide.py
│ └── navigate.py
└── deployment_environment.yaml
├── assets
└── pipeline.png
├── LICENSE
├── .gitignore
└── README.md
/train/vint_train/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/train/vint_train/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/train/vint_train/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/train/vint_train/training/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/train/vint_train/models/nomad/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/train/vint_train/process_data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/train/vint_train/visualizing/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/train/vint_train/models/nomad/vib_placeholder.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/deployment/config/params.yaml:
--------------------------------------------------------------------------------
1 | image_path: "../topomaps/images/"
2 | 
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/NaviD/HEAD/assets/pipeline.png
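The image_path entry in deployment/config/params.yaml above points at the folder of topomap images that create_topomap.py fills (its TOPOMAP_IMAGES_DIR constant appears later in this listing). Below is a minimal, illustrative sketch of how a deployment script can consume that config, mirroring the yaml.safe_load pattern used by joy_teleop.py and pd_controller.py further down; the helper name and its usage are assumptions, not code from this repository:

import os
import yaml

PARAMS_PATH = "../config/params.yaml"  # relative to deployment/src, like CONFIG_PATH in pd_controller.py

def get_topomap_dir(topomap_name: str) -> str:
    # Load the deployment params and resolve the directory for a named topomap.
    with open(PARAMS_PATH, "r") as f:
        params = yaml.safe_load(f)
    # params["image_path"] is "../topomaps/images/" in the config above.
    return os.path.join(params["image_path"], topomap_name)

# Example: get_topomap_dir("topomap") -> "../topomaps/images/topomap"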
-------------------------------------------------------------------------------- /train/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="vint_train", 5 | version="0.1.0", 6 | packages=find_packages(), 7 | ) 8 | -------------------------------------------------------------------------------- /train/vint_train/models/nomad/README.md: -------------------------------------------------------------------------------- 1 | ### NoMaD (Navigation with Goal Masked Diffusion) 2 | 3 | Files 4 | - `nomad.py` : Main model file 5 | - `nomad_vint.py` : Implementation of ViNT class for NoMaD 6 | -------------------------------------------------------------------------------- /deployment/config/joystick.yaml: -------------------------------------------------------------------------------- 1 | # joystick parameters for src/gnm_locobot.launch 2 | dev: "/dev/input/js0" # change this to your joystick device path 3 | 4 | # joystick parameters for src/joy_teleop.py 5 | deadman_switch: 5 # button index 6 | lin_vel_button: 4 7 | ang_vel_button: 0 8 | -------------------------------------------------------------------------------- /deployment/config/camera_front.yaml: -------------------------------------------------------------------------------- 1 | # camera parameters for src/gnm_locobot.launch 2 | video_device: "/dev/video0" # change this to your video device path 3 | image_width: 160 4 | image_height: 120 5 | pixel_format: yuyv 6 | camera_frame_id: "usb_cam" 7 | io_method: "mmap" 8 | framerate: 9 -------------------------------------------------------------------------------- /deployment/config/camera_reverse.yaml: -------------------------------------------------------------------------------- 1 | # camera parameters for src/gnm_locobot.launch 2 | video_device: "/dev/video2" # change this to your video device path 3 | image_width: 160 4 | image_height: 120 5 | pixel_format: yuyv 6 | camera_frame_id: "usb_cam" 7 | io_method: "mmap" 8 | framerate: 9 -------------------------------------------------------------------------------- /deployment/src/deployment_environment.yml: -------------------------------------------------------------------------------- 1 | name: nomad_train 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - python=3.8.5 7 | - cudatoolkit=11. 8 | - torchvision 9 | - numpy 10 | - matplotlib 11 | - pyyaml 12 | - rospkg 13 | - pip: 14 | - torch 15 | - torchvision 16 | - efficientnet_pytorch 17 | - warmup_scheduler -------------------------------------------------------------------------------- /deployment/config/models.yaml: -------------------------------------------------------------------------------- 1 | nomad: 2 | config_path: "../../train/config/nomad.yaml" 3 | ckpt_path: "../checkpoints/nomad.pth" 4 | 5 | navidiffusor: 6 | config_path: "../../train/config/nomad.yaml" 7 | ckpt_path: "../checkpoints/navidiffusor.pth" 8 | 9 | 10 | # add your own model configs here after saving the *.pth file to ../model_weight -------------------------------------------------------------------------------- /deployment/deployment_environment.yaml: -------------------------------------------------------------------------------- 1 | name: vint_deployment 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - python=3.8.5 7 | - cudatoolkit=11. 
8 | - torchvision 9 | - numpy 10 | - matplotlib 11 | - pyyaml 12 | - rospkg 13 | - pip: 14 | - torch 15 | - torchvision 16 | - efficientnet_pytorch 17 | - warmup_scheduler 18 | - diffusers==0.11.1 -------------------------------------------------------------------------------- /deployment/config/robot.yaml: -------------------------------------------------------------------------------- 1 | # linear and angular speed limits for the robot 2 | max_v: 0.4 #0.4 # m/s 3 | max_w: 0.8 #0.8 # rad/s 4 | # observation rate fo the robot 5 | frame_rate: 4 # Hz 6 | graph_rate: 0.3333 # Hz 7 | 8 | # topic names (modify for different robots/nodes) 9 | vel_teleop_topic: /cmd_vel_mux/input/teleop 10 | vel_navi_topic: /cmd_vel 11 | vel_recovery_topic: /cmd_vel_mux/input/recovery 12 | 13 | 14 | -------------------------------------------------------------------------------- /deployment/config/cmd_vel_mux.yaml: -------------------------------------------------------------------------------- 1 | subscribers: 2 | - name: "gnm vels" 3 | topic: "/cmd_vel_mux/input/navi" 4 | timeout: 0.1 5 | priority: 0 6 | short_desc: "The default cmd_vel, controllers unaware that we are multiplexing cmd_vel should come here" 7 | - name: "teleop" 8 | topic: "/cmd_vel_mux/input/teleop" 9 | timeout: 0.5 10 | priority: 2 11 | short_desc: "Navigation stack controller" 12 | - name: "gnm recovery" 13 | topic: "/cmd_vel_mux/input/recovery" 14 | timeout: 0.1 15 | priority: 1 16 | publisher: "/mobile_base/commands/velocity" 17 | -------------------------------------------------------------------------------- /deployment/src/joy_teleop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create a new tmux session 4 | session_name="teleop_locobot_$(date +%s)" 5 | tmux new-session -d -s $session_name 6 | 7 | # Split the window into two panes 8 | tmux selectp -t 0 # select the first (0) pane 9 | tmux splitw -v -p 50 # split it into two halves 10 | 11 | # Run the roslaunch command in the first pane 12 | tmux select-pane -t 0 13 | tmux send-keys "roslaunch gnm_locobot.launch" Enter 14 | 15 | # Run the teleop.py script in the second pane 16 | tmux select-pane -t 1 17 | tmux send-keys "conda activate gnm_deployment" Enter 18 | tmux send-keys "python joy_teleop.py" Enter 19 | 20 | # Attach to the tmux session 21 | tmux -2 attach-session -t $session_name -------------------------------------------------------------------------------- /train/vint_train/visualizing/visualize_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | VIZ_IMAGE_SIZE = (640, 480) 6 | RED = np.array([1, 0, 0]) 7 | GREEN = np.array([0, 1, 0]) 8 | BLUE = np.array([0, 0, 1]) 9 | CYAN = np.array([0, 1, 1]) 10 | YELLOW = np.array([1, 1, 0]) 11 | MAGENTA = np.array([1, 0, 1]) 12 | 13 | 14 | def numpy_to_img(arr: np.ndarray) -> Image: 15 | img = Image.fromarray(np.transpose(np.uint8(255 * arr), (1, 2, 0))) 16 | img = img.resize(VIZ_IMAGE_SIZE) 17 | return img 18 | 19 | 20 | def to_numpy(tensor: torch.Tensor) -> np.ndarray: 21 | return tensor.detach().cpu().numpy() 22 | 23 | 24 | def from_numpy(array: np.ndarray) -> torch.Tensor: 25 | return torch.from_numpy(array).float() 26 | -------------------------------------------------------------------------------- /train/train_environment.yml: -------------------------------------------------------------------------------- 1 | name: navidiffusor 2 | channels: 3 | - defaults 
4 | - pytorch 5 | dependencies: 6 | - python=3.8.5 7 | # - cudatoolkit=10. 8 | - numpy 9 | - matplotlib 10 | - ipykernel 11 | - pip 12 | - pip: 13 | - torch 14 | - torchvision 15 | - tqdm==4.64.0 16 | - git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 17 | - opencv-python==4.6.0.66 18 | - h5py==3.6.0 19 | - wandb==0.12.18 20 | - --extra-index-url https://rospypi.github.io/simple/ 21 | - rosbag 22 | - roslz4 23 | - prettytable 24 | - efficientnet-pytorch 25 | - warmup-scheduler 26 | - diffusers==0.11.1 27 | - lmdb 28 | - vit-pytorch 29 | - positional-encodings 30 | - scipy 31 | - open3d 32 | - rospkg 33 | - pypose 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /train/vint_train/process_data/process_bags_config.yaml: -------------------------------------------------------------------------------- 1 | tartan_drive: 2 | odomtopics: "/odometry/filtered_odom" 3 | imtopics: "/multisense/left/image_rect_color" 4 | ang_offset: 1.5707963267948966 # pi/2 5 | img_process_func: "process_tartan_img" 6 | odom_process_func: "nav_to_xy_yaw" 7 | 8 | scand: 9 | odomtopics: ["/odom", "/jackal_velocity_controller/odom"] 10 | imtopics: ["/image_raw/compressed", "/camera/rgb/image_raw/compressed"] 11 | ang_offset: 0.0 12 | img_process_func: "process_scand_img" 13 | odom_process_func: "nav_to_xy_yaw" 14 | 15 | locobot: 16 | odomtopics: "/odom" 17 | imtopics: "/usb_cam/image_raw" 18 | ang_offset: 0.0 19 | img_process_func: "process_locobot_img" 20 | odom_process_func: "nav_to_xy_yaw" 21 | 22 | sacson: 23 | odomtopics: "/odometry" 24 | imtopics: "/fisheye_image/compressed" 25 | ang_offset: 0.0 26 | img_process_func: "process_sacson_img" 27 | odom_process_func: "nav_to_xy_yaw" 28 | 29 | # add your own datasets below: 30 | -------------------------------------------------------------------------------- /deployment/src/record_bag.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create a new tmux session 4 | session_name="record_bag_$(date +%s)" 5 | tmux new-session -d -s $session_name 6 | 7 | # Split the window into three panes 8 | tmux selectp -t 0 # select the first (0) pane 9 | tmux splitw -v -p 50 # split it into two halves 10 | tmux selectp -t 0 # go back to the first pane 11 | tmux splitw -h -p 50 # split it into two halves 12 | 13 | # Run the roslaunch command in the first pane 14 | tmux select-pane -t 0 15 | tmux send-keys "roslaunch vint_locobot.launch" Enter 16 | 17 | # Run the teleop.py script in the second pane 18 | tmux select-pane -t 1 19 | tmux send-keys "conda activate vint_deployment" Enter 20 | tmux send-keys "python joy_teleop.py" Enter 21 | 22 | # Change the directory to ../topomaps/bags and run the rosbag record command in the third pane 23 | tmux select-pane -t 2 24 | tmux send-keys "cd ../topomaps/bags" Enter 25 | tmux send-keys "rosbag record /usb_cam/image_raw -o $1" # change topic if necessary 26 | 27 | # Attach to the tmux session 28 | tmux -2 attach-session -t $session_name -------------------------------------------------------------------------------- /deployment/src/create_topomap.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create a new tmux session 4 | session_name="gnm_locobot_$(date +%s)" 5 | tmux new-session -d -s $session_name 6 | 7 | # Split the window into three panes 8 | tmux selectp -t 0 # select the first (0) pane 9 | tmux splitw -v -p 50 # split it into two halves 10 | 
tmux selectp -t 0 # go back to the first pane 11 | tmux splitw -h -p 50 # split it into two halves 12 | 13 | # Run roscore in the first pane 14 | tmux select-pane -t 0 15 | tmux send-keys "roscore" Enter 16 | 17 | # Run the create_topoplan.py script with command line args in the second pane 18 | tmux select-pane -t 1 19 | tmux send-keys "conda activate navidiffusor" Enter 20 | tmux send-keys "python create_topomap.py --dt 1 --dir $1" Enter 21 | 22 | # Change the directory to ../topomaps/bags and run the rosbag play command in the third pane 23 | tmux select-pane -t 2 24 | tmux send-keys "mkdir -p ../topomaps/bags" Enter 25 | tmux send-keys "cd ../topomaps/bags" Enter 26 | tmux send-keys "rosbag play -r 0.7 $2" # feel free to change the playback rate to change the edge length in the graph 27 | 28 | # Attach to the tmux session 29 | tmux -2 attach-session -t $session_name 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Dhruv Shah, Ajay Sridhar, Nitish Dashora, Kyle Stachowicz, Kevin Black, Noriaki Hirose, Sergey Levine 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /deployment/src/vint_locobot.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /deployment/src/ros_data.py: -------------------------------------------------------------------------------- 1 | import rospy 2 | 3 | class ROSData: 4 | def __init__(self, timeout: int = 3, queue_size: int = 1, name: str = ""): 5 | self.timout = timeout 6 | self.last_time_received = float("-inf") 7 | self.queue_size = queue_size 8 | self.data = None 9 | self.name = name 10 | self.phantom = False 11 | 12 | def get(self): 13 | return self.data 14 | 15 | def set(self, data): 16 | time_waited = rospy.get_time() - self.last_time_received 17 | if self.queue_size == 1: 18 | self.data = data 19 | else: 20 | if self.data is None or time_waited > self.timout: # reset queue if timeout 21 | self.data = [] 22 | if len(self.data) == self.queue_size: 23 | self.data.pop(0) 24 | self.data.append(data) 25 | self.last_time_received = rospy.get_time() 26 | 27 | def is_valid(self, verbose: bool = False): 28 | time_waited = rospy.get_time() - self.last_time_received 29 | valid = time_waited < self.timout 30 | if self.queue_size > 1: 31 | valid = valid and len(self.data) == self.queue_size 32 | if verbose and not valid: 33 | print(f"Not receiving {self.name} data for {time_waited} seconds (timeout: {self.timout} seconds)") 34 | return valid -------------------------------------------------------------------------------- /deployment/src/topic_names.py: -------------------------------------------------------------------------------- 1 | # topic names for ROS communication 2 | 3 | # image obs topics 4 | FRONT_IMAGE_TOPIC = "/usb_cam_front/image_raw" 5 | REVERSE_IMAGE_TOPIC = "/usb_cam_reverse/image_raw" 6 | IMAGE_TOPIC = "/rgb/image_raw" 7 | POS_TOPIC ="/model_position" 8 | 9 | 10 | # exploration topics 11 | SUBGOALS_TOPIC = "/subgoals" 12 | GRAPH_NAME_TOPIC = "/graph_name" 13 | WAYPOINT_TOPIC = "/waypoint" 14 | REVERSE_MODE_TOPIC = "/reverse_mode" 15 | SAMPLED_OUTPUTS_TOPIC = "/sampled_outputs" 16 | REACHED_GOAL_TOPIC = "/topoplan/reached_goal" 17 | SAMPLED_WAYPOINTS_GRAPH_TOPIC = "/sampled_waypoints_graph" 18 | BACKTRACKING_IMAGE_TOPIC = "/backtracking_image" 19 | FRONTIER_IMAGE_TOPIC = "/frontier_image" 20 | SUBGOALS_SHAPE_TOPIC = "/subgoal_shape" 21 | SAMPLED_ACTIONS_TOPIC = "/sampled_actions" 22 | ANNOTATED_IMAGE_TOPIC = "/annotated_image" 23 | CURRENT_NODE_IMAGE_TOPIC = "/current_node_image" 24 | FLIP_DIRECTION_TOPIC = "/flip_direction" 25 | TURNING_TOPIC = "/turning" 26 | SUBGOAL_GEN_RATE_TOPIC = "/subgoal_gen_rate" 27 | MARKER_TOPIC = "/visualization_marker_array" 28 | VIZ_NAV_IMAGE_TOPIC = "/nav_image" 29 | 30 | # visualization topics 31 | CHOSEN_SUBGOAL_TOPIC = "/chosen_subgoal" 32 | VISUAL_MARKER_TOPIC = "/path" 33 | SUB_GOAL_TOPIC="/goal" 34 | 35 | # recorded ont the robot 36 | ODOM_TOPIC = "/odom" 37 | BUMPER_TOPIC = "/mobile_base/events/bumper" 38 | JOY_BUMPER_TOPIC = "/joy_bumper" 39 | 40 | # move the robot -------------------------------------------------------------------------------- /deployment/src/navigate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create a new tmux session 4 | session_name="navidiffusor_$(date +%s)" 5 
| tmux new-session -d -s $session_name 6 | 7 | # Split the window into four panes 8 | tmux selectp -t 0 # select the first (0) pane 9 | tmux splitw -h -p 50 # split it into two halves 10 | tmux selectp -t 0 # select the first (0) pane 11 | tmux splitw -v -p 50 # split it into two halves 12 | 13 | tmux selectp -t 2 # select the new, second (2) pane 14 | tmux splitw -v -p 50 # split it into two halves 15 | tmux selectp -t 0 # go back to the first pane 16 | 17 | # Run the roslaunch command in the first pane 18 | tmux select-pane -t 0 19 | tmux send-keys "roslaunch vint_locobot.launch" Enter 20 | 21 | # Run the navigate.py script with command line args in the second pane 22 | tmux select-pane -t 1 23 | # tmux send-keys "conda activate vint_deployment" Enter 24 | tmux send-keys "conda activate navidiffusor" Enter 25 | tmux send-keys "python navigate.py $@" Enter 26 | 27 | # Run the teleop.py script in the third pane 28 | tmux select-pane -t 2 29 | # tmux send-keys "conda activate vint_deployment" Enter 30 | tmux send-keys "conda activate navidiffusor" Enter 31 | tmux send-keys "python joy_teleop.py" Enter 32 | 33 | # Run the pd_controller.py script in the fourth pane 34 | tmux select-pane -t 3 35 | tmux send-keys "conda activate navidiffusor" Enter 36 | tmux send-keys "python pd_controller.py" Enter 37 | 38 | # Attach to the tmux session 39 | tmux -2 attach-session -t $session_name 40 | -------------------------------------------------------------------------------- /train/vint_train/data/data_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | # global params for diffusion model 3 | # normalized min and max 4 | action_stats: 5 | min: [-2.5, -4] # [min_dx, min_dy] 6 | max: [5, 4] # [max_dx, max_dy] 7 | 8 | # data specific params 9 | recon: 10 | metric_waypoint_spacing: 0.25 # average spacing between waypoints (meters) 11 | 12 | # OPTIONAL (FOR VISUALIZATION ONLY) 13 | camera_metrics: # https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html 14 | camera_height: 0.95 # meters 15 | camera_x_offset: 0.45 # distance between the center of the robot and the forward facing camera 16 | camera_matrix: 17 | fx: 272.547000 18 | fy: 266.358000 19 | cx: 320.000000 20 | cy: 220.000000 21 | dist_coeffs: 22 | k1: -0.038483 23 | k2: -0.010456 24 | p1: 0.003930 25 | p2: -0.001007 26 | k3: 0.0 27 | 28 | scand: 29 | metric_waypoint_spacing: 0.38 30 | 31 | tartan_drive: 32 | metric_waypoint_spacing: 0.72 33 | 34 | go_stanford: 35 | metric_waypoint_spacing: 0.12 36 | 37 | # private datasets: 38 | cory_hall: 39 | metric_waypoint_spacing: 0.06 40 | 41 | seattle: 42 | metric_waypoint_spacing: 0.35 43 | 44 | racer: 45 | metric_waypoint_spacing: 0.38 46 | 47 | carla_intvns: 48 | metric_waypoint_spacing: 1.39 49 | 50 | carla_cil: 51 | metric_waypoint_spacing: 1.27 52 | 53 | carla_intvns: 54 | metric_waypoint_spacing: 1.39 55 | 56 | carla: 57 | metric_waypoint_spacing: 1.59 58 | image_path_func: get_image_path 59 | 60 | sacson: 61 | metric_waypoint_spacing: 0.255 62 | 63 | # add your own dataset params here: 64 | -------------------------------------------------------------------------------- /train/config/defaults.yaml: -------------------------------------------------------------------------------- 1 | # defaults for training 2 | project_name: vint 3 | run_name: vint 4 | 5 | # training setup 6 | use_wandb: True # set to false if you don't want to log to wandb 7 | train: True 8 | batch_size: 400 9 | eval_batch_size: 400 10 | epochs: 30 11 | gpu_ids: [0] 12 | 
num_workers: 4 13 | lr: 5e-4 14 | optimizer: adam 15 | seed: 0 16 | clipping: False 17 | train_subset: 1. 18 | 19 | # model params 20 | model_type: gnm 21 | obs_encoding_size: 1024 22 | goal_encoding_size: 1024 23 | 24 | # normalization for the action space 25 | normalize: True 26 | 27 | # context 28 | context_type: temporal 29 | context_size: 5 30 | 31 | # tradeoff between action and distance prediction loss 32 | alpha: 0.5 33 | 34 | # tradeoff between task loss and kld 35 | beta: 0.1 36 | 37 | obs_type: image 38 | goal_type: image 39 | scheduler: null 40 | 41 | # distance bounds for distance and action and distance predictions 42 | distance: 43 | min_dist_cat: 0 44 | max_dist_cat: 20 45 | action: 46 | min_dist_cat: 2 47 | max_dist_cat: 10 48 | close_far_threshold: 10 # distance threshold used to seperate the close and the far subgoals that are sampled per datapoint 49 | 50 | # action output params 51 | len_traj_pred: 5 52 | learn_angle: True 53 | 54 | # dataset specific parameters 55 | image_size: [85, 64] # width, height 56 | 57 | # logging stuff 58 | ## =0 turns off 59 | print_log_freq: 100 # in iterations 60 | image_log_freq: 1000 # in iterations 61 | num_images_log: 8 # number of images to log in a logging iteration 62 | pairwise_test_freq: 10 # in epochs 63 | eval_fraction: 0.25 # fraction of the dataset to use for evaluation 64 | wandb_log_freq: 10 # in iterations 65 | eval_freq: 1 # in epochs 66 | 67 | -------------------------------------------------------------------------------- /train/vint_train/models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import List, Dict, Optional, Tuple 5 | 6 | 7 | class BaseModel(nn.Module): 8 | def __init__( 9 | self, 10 | context_size: int = 5, 11 | len_traj_pred: Optional[int] = 5, 12 | learn_angle: Optional[bool] = True, 13 | ) -> None: 14 | """ 15 | Base Model main class 16 | Args: 17 | context_size (int): how many previous observations to used for context 18 | len_traj_pred (int): how many waypoints to predict in the future 19 | learn_angle (bool): whether to predict the yaw of the robot 20 | """ 21 | super(BaseModel, self).__init__() 22 | self.context_size = context_size 23 | self.learn_angle = learn_angle 24 | self.len_trajectory_pred = len_traj_pred 25 | if self.learn_angle: 26 | self.num_action_params = 4 # last two dims are the cos and sin of the angle 27 | else: 28 | self.num_action_params = 2 29 | 30 | def flatten(self, z: torch.Tensor) -> torch.Tensor: 31 | z = nn.functional.adaptive_avg_pool2d(z, (1, 1)) 32 | z = torch.flatten(z, 1) 33 | return z 34 | 35 | def forward( 36 | self, obs_img: torch.tensor, goal_img: torch.tensor 37 | ) -> Tuple[torch.Tensor, torch.Tensor]: 38 | """ 39 | Forward pass of the model 40 | Args: 41 | obs_img (torch.Tensor): batch of observations 42 | goal_img (torch.Tensor): batch of goals 43 | Returns: 44 | dist_pred (torch.Tensor): predicted distance to goal 45 | action_pred (torch.Tensor): predicted action 46 | """ 47 | raise NotImplementedError 48 | -------------------------------------------------------------------------------- /train/vint_train/models/nomad/nomad.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import pdb 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class NoMaD(nn.Module): 11 | 12 | def __init__(self, vision_encoder, 13 | noise_pred_net, 14 | dist_pred_net): 15 | 
super(NoMaD, self).__init__() 16 | 17 | 18 | self.vision_encoder = vision_encoder 19 | self.noise_pred_net = noise_pred_net 20 | self.dist_pred_net = dist_pred_net 21 | 22 | def forward(self, func_name, **kwargs): 23 | if func_name == "vision_encoder" : 24 | output = self.vision_encoder(kwargs["obs_img"], kwargs["goal_img"], input_goal_mask=kwargs["input_goal_mask"]) 25 | elif func_name == "noise_pred_net": 26 | output = self.noise_pred_net(sample=kwargs["sample"], timestep=kwargs["timestep"], global_cond=kwargs["global_cond"]) 27 | elif func_name == "dist_pred_net": 28 | output = self.dist_pred_net(kwargs["obsgoal_cond"]) 29 | else: 30 | raise NotImplementedError 31 | return output 32 | 33 | 34 | class DenseNetwork(nn.Module): 35 | def __init__(self, embedding_dim): 36 | super(DenseNetwork, self).__init__() 37 | 38 | self.embedding_dim = embedding_dim 39 | self.network = nn.Sequential( 40 | nn.Linear(self.embedding_dim, self.embedding_dim//4), 41 | nn.ReLU(), 42 | nn.Linear(self.embedding_dim//4, self.embedding_dim//16), 43 | nn.ReLU(), 44 | nn.Linear(self.embedding_dim//16, 1) 45 | ) 46 | 47 | def forward(self, x): 48 | x = x.reshape((-1, self.embedding_dim)) 49 | output = self.network(x) 50 | return output 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /train/vint_train/training/logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Logger: 5 | def __init__( 6 | self, 7 | name: str, 8 | dataset: str, 9 | window_size: int = 10, 10 | rounding: int = 4, 11 | ): 12 | """ 13 | Args: 14 | name (str): Name of the metric 15 | dataset (str): Name of the dataset 16 | window_size (int, optional): Size of the moving average window. Defaults to 10. 17 | rounding (int, optional): Number of decimals to round to. Defaults to 4. 
18 | """ 19 | self.data = [] 20 | self.name = name 21 | self.dataset = dataset 22 | self.rounding = rounding 23 | self.window_size = window_size 24 | 25 | def display(self) -> str: 26 | latest = round(self.latest(), self.rounding) 27 | average = round(self.average(), self.rounding) 28 | moving_average = round(self.moving_average(), self.rounding) 29 | output = f"{self.full_name()}: {latest} ({self.window_size}pt moving_avg: {moving_average}) (avg: {average})" 30 | return output 31 | 32 | def log_data(self, data: float): 33 | if not np.isnan(data): 34 | self.data.append(data) 35 | 36 | def full_name(self) -> str: 37 | return f"{self.name} ({self.dataset})" 38 | 39 | def latest(self) -> float: 40 | if len(self.data) > 0: 41 | return self.data[-1] 42 | return np.nan 43 | 44 | def average(self) -> float: 45 | if len(self.data) > 0: 46 | return np.mean(self.data) 47 | return np.nan 48 | 49 | def moving_average(self) -> float: 50 | if len(self.data) > self.window_size: 51 | return np.mean(self.data[-self.window_size :]) 52 | return self.average() -------------------------------------------------------------------------------- /deployment/src/joy_teleop.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | # ROS 4 | import rospy 5 | from geometry_msgs.msg import Twist 6 | from sensor_msgs.msg import Joy 7 | from std_msgs.msg import Bool 8 | 9 | from topic_names import JOY_BUMPER_TOPIC 10 | 11 | vel_msg = Twist() 12 | CONFIG_PATH = "../config/robot.yaml" 13 | with open(CONFIG_PATH, "r") as f: 14 | robot_config = yaml.safe_load(f) 15 | MAX_V = 0.4 16 | MAX_W = 0.8 17 | VEL_TOPIC = robot_config["vel_teleop_topic"] 18 | JOY_CONFIG_PATH = "../config/joystick.yaml" 19 | with open(JOY_CONFIG_PATH, "r") as f: 20 | joy_config = yaml.safe_load(f) 21 | DEADMAN_SWITCH = joy_config["deadman_switch"] # button index 22 | LIN_VEL_BUTTON = joy_config["lin_vel_button"] 23 | ANG_VEL_BUTTON = joy_config["ang_vel_button"] 24 | RATE = 9 25 | vel_pub = rospy.Publisher(VEL_TOPIC, Twist, queue_size=1) 26 | bumper_pub = rospy.Publisher(JOY_BUMPER_TOPIC, Bool, queue_size=1) 27 | button = None 28 | bumper = False 29 | 30 | 31 | def callback_joy(data: Joy): 32 | """Callback function for the joystick subscriber""" 33 | global vel_msg, button, bumper 34 | button = data.buttons[DEADMAN_SWITCH] 35 | bumper_button = data.buttons[DEADMAN_SWITCH - 1] 36 | if button is not None: # hold down the dead-man switch to teleop the robot 37 | vel_msg.linear.x = MAX_V * data.axes[LIN_VEL_BUTTON] 38 | vel_msg.angular.z = MAX_W * data.axes[ANG_VEL_BUTTON] 39 | else: 40 | vel_msg = Twist() 41 | vel_pub.publish(vel_msg) 42 | if bumper_button is not None: 43 | bumper = bool(data.buttons[DEADMAN_SWITCH - 1]) 44 | else: 45 | bumper = False 46 | 47 | 48 | 49 | def main(): 50 | rospy.init_node("Joy2Locobot", anonymous=False) 51 | joy_sub = rospy.Subscriber("joy", Joy, callback_joy) 52 | rate = rospy.Rate(RATE) 53 | print("Registered with master node. 
Waiting for joystick input...") 54 | while not rospy.is_shutdown(): 55 | if button: 56 | print(f"Teleoperating the robot:\n {vel_msg}") 57 | vel_pub.publish(vel_msg) 58 | rate.sleep() 59 | bumper_msg = Bool() 60 | bumper_msg.data = bumper 61 | bumper_pub.publish(bumper_msg) 62 | if bumper: 63 | print("Bumper pressed!") 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | train/logs/* 2 | train/datasets/* 3 | train/vint_train/data/data_splits/* 4 | train/wandb/* 5 | train/gnm_dataset/* 6 | *_test.yaml 7 | 8 | 9 | *.jpg 10 | *.pth 11 | *.mp4 12 | *.gif 13 | 14 | deployment/model_weights/* 15 | deployment/topomaps/* 16 | 17 | .vscode/* 18 | */.vscode/* 19 | 20 | 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | pip-wheel-metadata/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # IPython 101 | profile_default/ 102 | ipython_config.py 103 | 104 | # pyenv 105 | .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | -------------------------------------------------------------------------------- /train/process_recon.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | import pickle 4 | from PIL import Image 5 | import io 6 | import argparse 7 | import tqdm 8 | 9 | 10 | def main(args: argparse.Namespace): 11 | recon_dir = os.path.join(args.input_dir, "recon_release") 12 | output_dir = args.output_dir 13 | 14 | # create output dir if it doesn't exist 15 | if not os.path.exists(output_dir): 16 | os.makedirs(output_dir) 17 | 18 | # get all the folders in the recon dataset 19 | filenames = os.listdir(recon_dir) 20 | if args.num_trajs >= 0: 21 | filenames = filenames[: args.num_trajs] 22 | 23 | # processing loop 24 | for filename in tqdm.tqdm(filenames, desc="Trajectories processed"): 25 | # extract the name without the extension 26 | traj_name = filename.split(".")[0] 27 | # load the hdf5 file 28 | try: 29 | h5_f = h5py.File(os.path.join(recon_dir, filename), "r") 30 | except OSError: 31 | print(f"Error loading {filename}. Skipping...") 32 | continue 33 | # extract the position and yaw data 34 | position_data = h5_f["jackal"]["position"][:, :2] 35 | yaw_data = h5_f["jackal"]["yaw"][()] 36 | # save the data to a dictionary 37 | traj_data = {"position": position_data, "yaw": yaw_data} 38 | traj_folder = os.path.join(output_dir, traj_name) 39 | os.makedirs(traj_folder, exist_ok=True) 40 | with open(os.path.join(traj_folder, "traj_data.pkl"), "wb") as f: 41 | pickle.dump(traj_data, f) 42 | # make a folder for the file 43 | if not os.path.exists(traj_folder): 44 | os.makedirs(traj_folder) 45 | # save the image data to disk 46 | for i in range(h5_f["images"]["rgb_left"].shape[0]): 47 | img = Image.open(io.BytesIO(h5_f["images"]["rgb_left"][i])) 48 | img.save(os.path.join(traj_folder, f"{i}.jpg")) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | # get arguments for the recon input dir and the output dir 54 | parser.add_argument( 55 | "--input-dir", 56 | "-i", 57 | type=str, 58 | help="path of the recon_dataset", 59 | required=True, 60 | ) 61 | parser.add_argument( 62 | "--output-dir", 63 | "-o", 64 | default="datasets/recon/", 65 | type=str, 66 | help="path for processed recon dataset (default: datasets/recon/)", 67 | ) 68 | # number of trajs to process 69 | parser.add_argument( 70 | "--num-trajs", 71 | "-n", 72 | default=-1, 73 | type=int, 74 | help="number of trajectories to process (default: -1, all)", 75 | ) 76 | 77 | args = parser.parse_args() 78 | print("STARTING PROCESSING RECON DATASET") 79 | main(args) 80 | print("FINISHED PROCESSING RECON DATASET") 81 | -------------------------------------------------------------------------------- /train/data_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import random 5 | 6 | 7 | 
def remove_files_in_dir(dir_path: str): 8 | for f in os.listdir(dir_path): 9 | file_path = os.path.join(dir_path, f) 10 | try: 11 | if os.path.isfile(file_path) or os.path.islink(file_path): 12 | os.unlink(file_path) 13 | elif os.path.isdir(file_path): 14 | shutil.rmtree(file_path) 15 | except Exception as e: 16 | print("Failed to delete %s. Reason: %s" % (file_path, e)) 17 | 18 | 19 | def main(args: argparse.Namespace): 20 | # Get the names of the folders in the data directory that contain the file 'traj_data.pkl' 21 | folder_names = [ 22 | f 23 | for f in os.listdir(args.data_dir) 24 | if os.path.isdir(os.path.join(args.data_dir, f)) 25 | and "traj_data.pkl" in os.listdir(os.path.join(args.data_dir, f)) 26 | ] 27 | 28 | # Randomly shuffle the names of the folders 29 | random.shuffle(folder_names) 30 | 31 | # Split the names of the folders into train and test sets 32 | split_index = int(args.split * len(folder_names)) 33 | train_folder_names = folder_names[:split_index] 34 | test_folder_names = folder_names[split_index:] 35 | 36 | # Create directories for the train and test sets 37 | train_dir = os.path.join(args.data_splits_dir, args.dataset_name, "train") 38 | test_dir = os.path.join(args.data_splits_dir, args.dataset_name, "test") 39 | for dir_path in [train_dir, test_dir]: 40 | if os.path.exists(dir_path): 41 | print(f"Clearing files from {dir_path} for new data split") 42 | remove_files_in_dir(dir_path) 43 | else: 44 | print(f"Creating {dir_path}") 45 | os.makedirs(dir_path) 46 | 47 | # Write the names of the train and test folders to files 48 | with open(os.path.join(train_dir, "traj_names.txt"), "w") as f: 49 | for folder_name in train_folder_names: 50 | f.write(folder_name + "\n") 51 | 52 | with open(os.path.join(test_dir, "traj_names.txt"), "w") as f: 53 | for folder_name in test_folder_names: 54 | f.write(folder_name + "\n") 55 | 56 | 57 | if __name__ == "__main__": 58 | # Set up the command line argument parser 59 | parser = argparse.ArgumentParser() 60 | 61 | parser.add_argument( 62 | "--data-dir", "-i", help="Directory containing the data", required=True 63 | ) 64 | parser.add_argument( 65 | "--dataset-name", "-d", help="Name of the dataset", required=True 66 | ) 67 | parser.add_argument( 68 | "--split", "-s", type=float, default=0.8, help="Train/test split (default: 0.8)" 69 | ) 70 | parser.add_argument( 71 | "--data-splits-dir", "-o", default="vint_train/data/data_splits", help="Data splits directory" 72 | ) 73 | args = parser.parse_args() 74 | main(args) 75 | print("Done") 76 | -------------------------------------------------------------------------------- /deployment/src/create_topomap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from utils import msg_to_pil 4 | import time 5 | import shutil  # needed by remove_files_in_dir 6 | # ROS 7 | import rospy 8 | from sensor_msgs.msg import Image 9 | from sensor_msgs.msg import Joy 10 | 11 | IMAGE_TOPIC = "/rgb/image_raw" 12 | TOPOMAP_IMAGES_DIR = "../topomaps/images" 13 | obs_img = None 14 | 15 | 16 | def remove_files_in_dir(dir_path: str): 17 | for f in os.listdir(dir_path): 18 | file_path = os.path.join(dir_path, f) 19 | try: 20 | if os.path.isfile(file_path) or os.path.islink(file_path): 21 | os.unlink(file_path) 22 | elif os.path.isdir(file_path): 23 | shutil.rmtree(file_path) 24 | except Exception as e: 25 | print("Failed to delete %s.
Reason: %s" % (file_path, e)) 26 | 27 | 28 | def callback_obs(msg: Image): 29 | global obs_img 30 | obs_img = msg_to_pil(msg) 31 | 32 | 33 | def callback_joy(msg: Joy): 34 | if msg.buttons[0]: 35 | rospy.signal_shutdown("shutdown") 36 | 37 | 38 | def main(args: argparse.Namespace): 39 | global obs_img 40 | rospy.init_node("CREATE_TOPOMAP", anonymous=False) 41 | image_curr_msg = rospy.Subscriber( 42 | IMAGE_TOPIC, Image, callback_obs, queue_size=1) 43 | subgoals_pub = rospy.Publisher( 44 | "/subgoals", Image, queue_size=1) 45 | joy_sub = rospy.Subscriber("joy", Joy, callback_joy) 46 | 47 | topomap_name_dir = os.path.join(TOPOMAP_IMAGES_DIR, args.dir) 48 | if not os.path.isdir(topomap_name_dir): 49 | os.makedirs(topomap_name_dir) 50 | else: 51 | print(f"{topomap_name_dir} already exists. Removing previous images...") 52 | remove_files_in_dir(topomap_name_dir) 53 | 54 | 55 | assert args.dt > 0, "dt must be positive" 56 | rate = rospy.Rate(1/args.dt) 57 | print("Registered with master node. Waiting for images...") 58 | i = 0 59 | start_time = float("inf") 60 | while not rospy.is_shutdown(): 61 | if obs_img is not None: 62 | obs_img.save(os.path.join(topomap_name_dir, f"{i}.png")) 63 | print("published image", i) 64 | i += 1 65 | rate.sleep() 66 | start_time = time.time() 67 | obs_img = None 68 | if time.time() - start_time > 2 * args.dt: 69 | print(f"Topic {IMAGE_TOPIC} not publishing anymore. Shutting down...") 70 | rospy.signal_shutdown("shutdown") 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser( 75 | description=f"Code to generate topomaps from the {IMAGE_TOPIC} topic" 76 | ) 77 | parser.add_argument( 78 | "--dir", 79 | "-d", 80 | default="topomap", 81 | type=str, 82 | help="path to topological map images in ../topomaps/images directory (default: topomap)", 83 | ) 84 | parser.add_argument( 85 | "--dt", 86 | "-t", 87 | default=0.1, 88 | type=float, 89 | help=f"time between images sampled from the {IMAGE_TOPIC} topic (default: 0.1)", 90 | ) 91 | args = parser.parse_args() 92 | 93 | main(args) 94 | -------------------------------------------------------------------------------- /deployment/src/pd_controller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import yaml 3 | from typing import Tuple 4 | 5 | # ROS 6 | import rospy 7 | from geometry_msgs.msg import Twist 8 | from std_msgs.msg import Float32MultiArray, Bool 9 | 10 | from topic_names import (WAYPOINT_TOPIC, 11 | REACHED_GOAL_TOPIC) 12 | from ros_data import ROSData 13 | from utils import clip_angle 14 | 15 | # CONSTS 16 | CONFIG_PATH = "../config/robot.yaml" 17 | with open(CONFIG_PATH, "r") as f: 18 | robot_config = yaml.safe_load(f) 19 | MAX_V = robot_config["max_v"] 20 | MAX_W = robot_config["max_w"] 21 | VEL_TOPIC = robot_config["vel_navi_topic"] 22 | DT = 1/robot_config["frame_rate"] 23 | RATE = 50 24 | EPS = 1e-8 25 | WAYPOINT_TIMEOUT = 4 # seconds # TODO: tune this 26 | FLIP_ANG_VEL = np.pi/4 27 | 28 | # GLOBALS 29 | vel_msg = Twist() 30 | waypoint = ROSData(WAYPOINT_TIMEOUT, name="waypoint") 31 | reached_goal = False 32 | reverse_mode = False 33 | current_yaw = None 34 | 35 | def clip_angle(theta) -> float: 36 | """Clip angle to [-pi, pi]""" 37 | theta %= 2 * np.pi 38 | if -np.pi < theta < np.pi: 39 | return theta 40 | return theta - 2 * np.pi 41 | 42 | 43 | def pd_controller(waypoint: np.ndarray) -> Tuple[float, float]: 44 | """PD controller for the robot""" 45 | assert len(waypoint) == 2 or len(waypoint) == 4, "waypoint must be a 2D or 4D
vector" 46 | if len(waypoint) == 2: 47 | dx, dy = waypoint 48 | else: 49 | dx, dy, hx, hy = waypoint 50 | # this controller only uses the predicted heading if dx and dy near zero 51 | if len(waypoint) == 4 and np.abs(dx) < EPS and np.abs(dy) < EPS: 52 | v = 0 53 | w = clip_angle(np.arctan2(hy, hx))/DT 54 | elif np.abs(dx) < EPS: 55 | v = 0 56 | w = np.sign(dy) * np.pi/(2*DT) 57 | else: 58 | v = dx / DT 59 | w = np.arctan(dy/dx) / DT 60 | v = np.clip(v, 0, MAX_V) 61 | w = np.clip(w, -MAX_W, MAX_W) 62 | return v, w 63 | 64 | 65 | def callback_drive(waypoint_msg: Float32MultiArray): 66 | """Callback function for the waypoint subscriber""" 67 | global vel_msg 68 | print("seting waypoint") 69 | waypoint.set(waypoint_msg.data) 70 | 71 | 72 | def callback_reached_goal(reached_goal_msg: Bool): 73 | """Callback function for the reached goal subscriber""" 74 | global reached_goal 75 | reached_goal = reached_goal_msg.data 76 | 77 | 78 | def main(): 79 | global vel_msg, reverse_mode 80 | rospy.init_node("PD_CONTROLLER", anonymous=False) 81 | waypoint_sub = rospy.Subscriber(WAYPOINT_TOPIC, Float32MultiArray, callback_drive, queue_size=1) 82 | reached_goal_sub = rospy.Subscriber(REACHED_GOAL_TOPIC, Bool, callback_reached_goal, queue_size=1) 83 | vel_out = rospy.Publisher(VEL_TOPIC, Twist, queue_size=1) 84 | rate = rospy.Rate(RATE) 85 | print("Registered with master node. Waiting for waypoints...") 86 | while not rospy.is_shutdown(): 87 | vel_msg = Twist() 88 | if reached_goal: 89 | vel_out.publish(vel_msg) 90 | print("Reached goal! Stopping...") 91 | return 92 | elif waypoint.is_valid(verbose=True): 93 | v, w = pd_controller(waypoint.get()) 94 | # if reverse_mode: 95 | # v *= -1 96 | vel_msg.linear.x = v 97 | vel_msg.angular.z = w 98 | print(f"publishing new vel: {v}, {w}") 99 | vel_out.publish(vel_msg) 100 | rate.sleep() 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /train/config/nomad.yaml: -------------------------------------------------------------------------------- 1 | project_name: nomad 2 | run_name: nomad 3 | 4 | # training setup 5 | use_wandb: True # set to false if you don't want to log to wandb 6 | train: True 7 | batch_size: 256 8 | epochs: 100 9 | gpu_ids: [0] 10 | num_workers: 12 11 | lr: 1e-4 12 | optimizer: adamw 13 | clipping: False 14 | max_norm: 1. 
15 | scheduler: "cosine" 16 | warmup: True 17 | warmup_epochs: 4 18 | cyclic_period: 10 19 | plateau_patience: 3 20 | plateau_factor: 0.5 21 | seed: 0 22 | 23 | # model params 24 | model_type: nomad 25 | vision_encoder: nomad_vint 26 | encoding_size: 256 27 | obs_encoder: efficientnet-b0 28 | attn_unet: False 29 | cond_predict_scale: False 30 | mha_num_attention_heads: 4 31 | mha_num_attention_layers: 4 32 | mha_ff_dim_factor: 4 33 | down_dims: [64, 128, 256] 34 | 35 | # diffusion model params 36 | num_diffusion_iters: 10 37 | 38 | # mask 39 | goal_mask_prob: 0.5 40 | 41 | # normalization for the action space 42 | normalize: True 43 | 44 | # context 45 | context_type: temporal 46 | context_size: 3 # 5 47 | alpha: 1e-4 48 | 49 | # distance bounds for distance and action and distance predictions 50 | distance: 51 | min_dist_cat: 0 52 | max_dist_cat: 20 53 | action: 54 | min_dist_cat: 3 55 | max_dist_cat: 20 56 | 57 | # action output params 58 | len_traj_pred: 8 59 | learn_angle: False 60 | 61 | # dataset specific parameters 62 | image_size: [96, 96] # width, height 63 | datasets: 64 | recon: 65 | data_folder: /home//nomad_dataset/recon 66 | train: /home//data_splits/recon/train/ # path to train folder with traj_names.txt 67 | test: /home//data_splits/recon/test/ # path to test folder with traj_names.txt 68 | end_slack: 3 # because many trajectories end in collisions 69 | goals_per_obs: 1 # how many goals are sampled per observation 70 | negative_mining: True # negative mining from the ViNG paper (Shah et al.) 71 | go_stanford: 72 | data_folder: /home//nomad_dataset/go_stanford_cropped # datasets/stanford_go_new 73 | train: /home//data_splits/go_stanford/train/ 74 | test: /home//data_splits/go_stanford/test/ 75 | end_slack: 0 76 | goals_per_obs: 2 # increase dataset size 77 | negative_mining: True 78 | cory_hall: 79 | data_folder: /home//nomad_dataset/cory_hall/ 80 | train: /home//data_splits/cory_hall/train/ 81 | test: /home//data_splits/cory_hall/test/ 82 | end_slack: 3 # because many trajectories end in collisions 83 | goals_per_obs: 1 84 | negative_mining: True 85 | tartan_drive: 86 | data_folder: /home//nomad_dataset/tartan_drive/ 87 | train: /home//data_splits/tartan_drive/train/ 88 | test: /home//data_splits/tartan_drive/test/ 89 | end_slack: 3 # because many trajectories end in collisions 90 | goals_per_obs: 1 91 | negative_mining: True 92 | sacson: 93 | data_folder: /home//nomad_dataset/sacson/ 94 | train: /home//data_splits/sacson/train/ 95 | test: /home//data_splits/sacson/test/ 96 | end_slack: 3 # because many trajectories end in collisions 97 | goals_per_obs: 1 98 | negative_mining: True 99 | # private datasets (uncomment if you have access) 100 | # seattle: 101 | # data_folder: /home//nomad_dataset/seattle/ 102 | # train: /home//data_splits/seattle/train/ 103 | # test: /home//data_splits/seattle/test/ 104 | # end_slack: 0 105 | # goals_per_obs: 1 106 | # negative_mining: True 107 | # scand: 108 | # data_folder: /home//nomad_dataset/scand/ 109 | # train: /home//data_splits/scand/train/ 110 | # test: /home//data_splits/scand/test/ 111 | # end_slack: 0 112 | # goals_per_obs: 1 113 | # negative_mining: True 114 | 115 | # logging stuff 116 | ## =0 turns off 117 | print_log_freq: 100 # in iterations 118 | image_log_freq: 1000 #0 # in iterations 119 | num_images_log: 8 #0 120 | pairwise_test_freq: 0 # in epochs 121 | eval_fraction: 0.25 122 | wandb_log_freq: 10 # in iterations 123 | eval_freq: 1 # in epochs 
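The nomad.yaml file above only overrides the options it needs; everything else comes from config/defaults.yaml shown earlier. train.py itself is not included in this excerpt, so the snippet below is only a minimal sketch of that defaults-then-override pattern, with paths relative to the train/ directory and the function name assumed for illustration:

import yaml

def load_train_config(run_config_path: str = "config/nomad.yaml") -> dict:
    # Start from the shared defaults (e.g. model_type: gnm, batch_size: 400) ...
    with open("config/defaults.yaml", "r") as f:
        config = yaml.safe_load(f)
    # ... then let the run-specific config override them (e.g. model_type: nomad, batch_size: 256).
    with open(run_config_path, "r") as f:
        config.update(yaml.safe_load(f))
    return config

# load_train_config()["model_type"] -> "nomad"; keys absent from nomad.yaml keep their defaults.yaml values.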
-------------------------------------------------------------------------------- /train/vint_train/data/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | from typing import Any, Iterable, Tuple 5 | 6 | import torch 7 | from torchvision import transforms 8 | import torchvision.transforms.functional as TF 9 | import torch.nn.functional as F 10 | import io 11 | from typing import Union 12 | 13 | VISUALIZATION_IMAGE_SIZE = (160, 120) 14 | IMAGE_ASPECT_RATIO = ( 15 | 4 / 3 16 | ) # all images are centered cropped to a 4:3 aspect ratio in training 17 | 18 | 19 | 20 | def get_data_path(data_folder: str, f: str, time: int, data_type: str = "image"): 21 | data_ext = { 22 | "image": ".jpg", 23 | # add more data types here 24 | } 25 | return os.path.join(data_folder, f, f"{str(time)}{data_ext[data_type]}") 26 | 27 | 28 | def yaw_rotmat(yaw: float) -> np.ndarray: 29 | return np.array( 30 | [ 31 | [np.cos(yaw), -np.sin(yaw), 0.0], 32 | [np.sin(yaw), np.cos(yaw), 0.0], 33 | [0.0, 0.0, 1.0], 34 | ], 35 | ) 36 | 37 | 38 | def to_local_coords( 39 | positions: np.ndarray, curr_pos: np.ndarray, curr_yaw: float 40 | ) -> np.ndarray: 41 | """ 42 | Convert positions to local coordinates 43 | 44 | Args: 45 | positions (np.ndarray): positions to convert 46 | curr_pos (np.ndarray): current position 47 | curr_yaw (float): current yaw 48 | Returns: 49 | np.ndarray: positions in local coordinates 50 | """ 51 | rotmat = yaw_rotmat(curr_yaw) 52 | if positions.shape[-1] == 2: 53 | rotmat = rotmat[:2, :2] 54 | elif positions.shape[-1] == 3: 55 | pass 56 | else: 57 | raise ValueError 58 | 59 | return (positions - curr_pos).dot(rotmat) 60 | 61 | 62 | def calculate_deltas(waypoints: torch.Tensor) -> torch.Tensor: 63 | """ 64 | Calculate deltas between waypoints 65 | 66 | Args: 67 | waypoints (torch.Tensor): waypoints 68 | Returns: 69 | torch.Tensor: deltas 70 | """ 71 | num_params = waypoints.shape[1] 72 | origin = torch.zeros(1, num_params) 73 | prev_waypoints = torch.concat((origin, waypoints[:-1]), axis=0) 74 | deltas = waypoints - prev_waypoints 75 | if num_params > 2: 76 | return calculate_sin_cos(deltas) 77 | return deltas 78 | 79 | 80 | def calculate_sin_cos(waypoints: torch.Tensor) -> torch.Tensor: 81 | """ 82 | Calculate sin and cos of the angle 83 | 84 | Args: 85 | waypoints (torch.Tensor): waypoints 86 | Returns: 87 | torch.Tensor: waypoints with sin and cos of the angle 88 | """ 89 | assert waypoints.shape[1] == 3 90 | angle_repr = torch.zeros_like(waypoints[:, :2]) 91 | angle_repr[:, 0] = torch.cos(waypoints[:, 2]) 92 | angle_repr[:, 1] = torch.sin(waypoints[:, 2]) 93 | return torch.concat((waypoints[:, :2], angle_repr), axis=1) 94 | 95 | 96 | def transform_images( 97 | img: Image.Image, transform: transforms, image_resize_size: Tuple[int, int], aspect_ratio: float = IMAGE_ASPECT_RATIO 98 | ): 99 | w, h = img.size 100 | if w > h: 101 | img = TF.center_crop(img, (h, int(h * aspect_ratio))) # crop to the right ratio 102 | else: 103 | img = TF.center_crop(img, (int(w / aspect_ratio), w)) 104 | viz_img = img.resize(VISUALIZATION_IMAGE_SIZE) 105 | viz_img = TF.to_tensor(viz_img) 106 | img = img.resize(image_resize_size) 107 | transf_img = transform(img) 108 | return viz_img, transf_img 109 | 110 | 111 | def resize_and_aspect_crop( 112 | img: Image.Image, image_resize_size: Tuple[int, int], aspect_ratio: float = IMAGE_ASPECT_RATIO 113 | ): 114 | w, h = img.size 115 | if w > h: 116 | img = 
TF.center_crop(img, (h, int(h * aspect_ratio))) # crop to the right ratio 117 | else: 118 | img = TF.center_crop(img, (int(w / aspect_ratio), w)) 119 | img = img.resize(image_resize_size) 120 | resize_img = TF.to_tensor(img) 121 | return resize_img 122 | 123 | 124 | def img_path_to_data(path: Union[str, io.BytesIO], image_resize_size: Tuple[int, int]) -> torch.Tensor: 125 | """ 126 | Load an image from a path and transform it 127 | Args: 128 | path (str): path to the image 129 | image_resize_size (Tuple[int, int]): size to resize the image to 130 | Returns: 131 | torch.Tensor: resized image as tensor 132 | """ 133 | # return transform_images(Image.open(path), transform, image_resize_size, aspect_ratio) 134 | return resize_and_aspect_crop(Image.open(path), image_resize_size) 135 | 136 | -------------------------------------------------------------------------------- /train/process_bags.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pickle 4 | from PIL import Image 5 | import io 6 | import argparse 7 | import tqdm 8 | import yaml 9 | import rosbag 10 | 11 | # utils 12 | from vint_train.process_data.process_data_utils import * 13 | 14 | 15 | def main(args: argparse.Namespace): 16 | 17 | # load the config file 18 | with open("vint_train/process_data/process_bags_config.yaml", "r") as f: 19 | config = yaml.load(f, Loader=yaml.FullLoader) 20 | 21 | # create output dir if it doesn't exist 22 | if not os.path.exists(args.output_dir): 23 | os.makedirs(args.output_dir) 24 | 25 | # iterate recurisively through all the folders and get the path of files with .bag extension in the args.input_dir 26 | bag_files = [] 27 | for root, dirs, files in os.walk(args.input_dir): 28 | for file in files: 29 | if file.endswith(".bag"): 30 | bag_files.append(os.path.join(root, file)) 31 | if args.num_trajs >= 0: 32 | bag_files = bag_files[: args.num_trajs] 33 | 34 | # processing loop 35 | for bag_path in tqdm.tqdm(bag_files, desc="Bags processed"): 36 | try: 37 | b = rosbag.Bag(bag_path) 38 | except rosbag.ROSBagException as e: 39 | print(e) 40 | print(f"Error loading {bag_path}. Skipping...") 41 | continue 42 | 43 | # name is that folders separated by _ and then the last part of the path 44 | traj_name = "_".join(bag_path.split("/")[-2:])[:-4] 45 | 46 | # load the hdf5 file 47 | bag_img_data, bag_traj_data = get_images_and_odom( 48 | b, 49 | config[args.dataset_name]["imtopics"], 50 | config[args.dataset_name]["odomtopics"], 51 | eval(config[args.dataset_name]["img_process_func"]), 52 | eval(config[args.dataset_name]["odom_process_func"]), 53 | rate=args.sample_rate, 54 | ang_offset=config[args.dataset_name]["ang_offset"], 55 | ) 56 | 57 | 58 | if bag_img_data is None or bag_traj_data is None: 59 | print( 60 | f"{bag_path} did not have the topics we were looking for. Skipping..." 
61 | ) 62 | continue 63 | # remove backwards movement 64 | cut_trajs = filter_backwards(bag_img_data, bag_traj_data) 65 | 66 | for i, (img_data_i, traj_data_i) in enumerate(cut_trajs): 67 | traj_name_i = traj_name + f"_{i}" 68 | traj_folder_i = os.path.join(args.output_dir, traj_name_i) 69 | # make a folder for the traj 70 | if not os.path.exists(traj_folder_i): 71 | os.makedirs(traj_folder_i) 72 | with open(os.path.join(traj_folder_i, "traj_data.pkl"), "wb") as f: 73 | pickle.dump(traj_data_i, f) 74 | # save the image data to disk 75 | for i, img in enumerate(img_data_i): 76 | img.save(os.path.join(traj_folder_i, f"{i}.jpg")) 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | # get arguments for the recon input dir and the output dir 82 | # add dataset name 83 | parser.add_argument( 84 | "--dataset-name", 85 | "-d", 86 | type=str, 87 | help="name of the dataset (must be in process_config.yaml)", 88 | default="tartan_drive", 89 | required=True, 90 | ) 91 | parser.add_argument( 92 | "--input-dir", 93 | "-i", 94 | type=str, 95 | help="path of the datasets with rosbags", 96 | required=True, 97 | ) 98 | parser.add_argument( 99 | "--output-dir", 100 | "-o", 101 | default="../datasets/tartan_drive/", 102 | type=str, 103 | help="path for processed dataset (default: ../datasets/tartan_drive/)", 104 | ) 105 | # number of trajs to process 106 | parser.add_argument( 107 | "--num-trajs", 108 | "-n", 109 | default=-1, 110 | type=int, 111 | help="number of bags to process (default: -1, all)", 112 | ) 113 | # sampling rate 114 | parser.add_argument( 115 | "--sample-rate", 116 | "-s", 117 | default=4.0, 118 | type=float, 119 | help="sampling rate (default: 4.0 hz)", 120 | ) 121 | 122 | args = parser.parse_args() 123 | # all caps for the dataset name 124 | print(f"STARTING PROCESSING {args.dataset_name.upper()} DATASET") 125 | main(args) 126 | print(f"FINISHED PROCESSING {args.dataset_name.upper()} DATASET") 127 | -------------------------------------------------------------------------------- /train/process_bag_diff.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pickle 4 | from PIL import Image 5 | import io 6 | import argparse 7 | import tqdm 8 | import yaml 9 | import rosbag 10 | 11 | # utils 12 | from vint_train.process_data.process_data_utils import * 13 | 14 | 15 | def main(args: argparse.Namespace): 16 | 17 | # load the config file 18 | with open("vint_train/process_data/process_bags_config.yaml", "r") as f: 19 | config = yaml.load(f, Loader=yaml.FullLoader) 20 | 21 | # create output dir if it doesn't exist 22 | if not os.path.exists(args.output_dir): 23 | os.makedirs(args.output_dir) 24 | 25 | # iterate recurisively through all the folders and get the path of files with .bag extension in the args.input_dir 26 | bag_files = [] 27 | for root, dirs, files in os.walk(args.input_dir): 28 | for file in files: 29 | if file.endswith(".bag") and "diff" in file: 30 | bag_files.append(os.path.join(root, file)) 31 | if args.num_trajs >= 0: 32 | bag_files = bag_files[: args.num_trajs] 33 | 34 | # processing loop 35 | for bag_path in tqdm.tqdm(bag_files, desc="Bags processed"): 36 | try: 37 | b = rosbag.Bag(bag_path) 38 | except rosbag.ROSBagException as e: 39 | print(e) 40 | print(f"Error loading {bag_path}. 
Skipping...") 41 | continue 42 | 43 | # name is that folders separated by _ and then the last part of the path 44 | traj_name = "_".join(bag_path.split("/")[-2:])[:-4] 45 | 46 | # load the bag file 47 | bag_img_data, bag_traj_data = get_images_and_odom_2( 48 | b, 49 | ['/usb_cam_front/image_raw', '/chosen_subgoal'], 50 | ['/odom'], 51 | rate=args.sample_rate, 52 | ) 53 | 54 | if bag_img_data is None: 55 | print( 56 | f"{bag_path} did not have the topics we were looking for. Skipping..." 57 | ) 58 | continue 59 | # remove backwards movement 60 | # cut_trajs = filter_backwards(bag_img_data, bag_traj_data) 61 | 62 | # for i, (img_data_i, traj_data_i) in enumerate(cut_trajs): 63 | # traj_name_i = traj_name + f"_{i}" 64 | # traj_folder_i = os.path.join(args.output_dir, traj_name_i) 65 | # # make a folder for the traj 66 | # if not os.path.exists(traj_folder_i): 67 | # os.makedirs(traj_folder_i) 68 | # with open(os.path.join(traj_folder_i, "traj_data.pkl"), "wb") as f: 69 | # pickle.dump(traj_data_i, f) 70 | # # save the image data to disk 71 | # for i, img in enumerate(img_data_i): 72 | # img.save(os.path.join(traj_folder_i, f"{i}.jpg")) 73 | 74 | traj_folder = os.path.join(args.output_dir, traj_name) 75 | if not os.path.exists(traj_folder): 76 | os.makedirs(traj_folder) 77 | 78 | obs_images = bag_img_data["/usb_cam_front/image_raw"] 79 | diff_images = bag_img_data["/chosen_subgoal"] 80 | for i, img_data in enumerate(zip(obs_images, diff_images)): 81 | obs_image, diff_image = img_data 82 | # save the image data to disk 83 | # save the image data to disk 84 | obs_image.save(os.path.join(traj_folder, f"{i}.jpg")) 85 | diff_image.save(os.path.join(traj_folder, f"diff_{i}.jpg")) 86 | 87 | with open(os.path.join(traj_folder, "traj_data.pkl"), "wb") as f: 88 | pickle.dump(bag_traj_data['/odom'], f) 89 | 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser() 93 | # get arguments for the recon input dir and the output dir 94 | # add dataset name 95 | # parser.add_argument( 96 | # "--dataset-name", 97 | # "-d", 98 | # type=str, 99 | # help="name of the dataset (must be in process_config.yaml)", 100 | # default="tartan_drive", 101 | # required=True, 102 | # ) 103 | parser.add_argument( 104 | "--input-dir", 105 | "-i", 106 | type=str, 107 | help="path of the datasets with rosbags", 108 | required=True, 109 | ) 110 | parser.add_argument( 111 | "--output-dir", 112 | "-o", 113 | default="../datasets/tartan_drive/", 114 | type=str, 115 | help="path for processed dataset (default: ../datasets/tartan_drive/)", 116 | ) 117 | # number of trajs to process 118 | parser.add_argument( 119 | "--num-trajs", 120 | "-n", 121 | default=-1, 122 | type=int, 123 | help="number of bags to process (default: -1, all)", 124 | ) 125 | # sampling rate 126 | parser.add_argument( 127 | "--sample-rate", 128 | "-s", 129 | default=4.0, 130 | type=float, 131 | help="sampling rate (default: 4.0 hz)", 132 | ) 133 | 134 | args = parser.parse_args() 135 | # all caps for the dataset name 136 | print(f"STARTING PROCESSING DIFF DATASET") 137 | main(args) 138 | print(f"FINISHED PROCESSING DIFF DATASET") 139 | -------------------------------------------------------------------------------- /deployment/src/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import io 5 | import matplotlib.pyplot as plt 6 | from scipy.spatial.transform import Rotation as R 7 | 8 | # ROS 9 | from sensor_msgs.msg import Image 10 | import tf 11 | 12 | # pytorch 13 | 
import torch 14 | import torch.nn as nn 15 | from torchvision import transforms 16 | import torchvision.transforms.functional as TF 17 | 18 | import numpy as np 19 | from PIL import Image as PILImage 20 | from typing import List, Tuple, Dict, Optional 21 | 22 | # models 23 | from vint_train.models.gnm.gnm import GNM 24 | from vint_train.models.vint.vint import ViNT 25 | 26 | from vint_train.models.vint.vit import ViT 27 | from vint_train.models.nomad.nomad import NoMaD, DenseNetwork 28 | from vint_train.models.nomad.nomad_vint import NoMaD_ViNT, replace_bn_with_gn 29 | from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D 30 | from vint_train.data.data_utils import IMAGE_ASPECT_RATIO 31 | 32 | 33 | def load_model( 34 | model_path: str, 35 | config: dict, 36 | device: torch.device = torch.device("cpu"), 37 | ) -> nn.Module: 38 | model_type = config["model_type"] 39 | 40 | if model_type == "gnm": 41 | model = GNM( 42 | config["context_size"], 43 | config["len_traj_pred"], 44 | config["learn_angle"], 45 | config["obs_encoding_size"], 46 | config["goal_encoding_size"], 47 | ) 48 | elif model_type == "vint": 49 | model = ViNT( 50 | context_size=config["context_size"], 51 | len_traj_pred=config["len_traj_pred"], 52 | learn_angle=config["learn_angle"], 53 | obs_encoder=config["obs_encoder"], 54 | obs_encoding_size=config["obs_encoding_size"], 55 | late_fusion=config["late_fusion"], 56 | mha_num_attention_heads=config["mha_num_attention_heads"], 57 | mha_num_attention_layers=config["mha_num_attention_layers"], 58 | mha_ff_dim_factor=config["mha_ff_dim_factor"], 59 | ) 60 | elif config["model_type"] == "nomad": 61 | if config["vision_encoder"] == "nomad_vint": 62 | vision_encoder = NoMaD_ViNT( 63 | obs_encoding_size=config["encoding_size"], 64 | context_size=config["context_size"], 65 | mha_num_attention_heads=config["mha_num_attention_heads"], 66 | mha_num_attention_layers=config["mha_num_attention_layers"], 67 | mha_ff_dim_factor=config["mha_ff_dim_factor"], 68 | ) 69 | vision_encoder = replace_bn_with_gn(vision_encoder) 70 | elif config["vision_encoder"] == "vit": 71 | vision_encoder = ViT( 72 | obs_encoding_size=config["encoding_size"], 73 | context_size=config["context_size"], 74 | image_size=config["image_size"], 75 | patch_size=config["patch_size"], 76 | mha_num_attention_heads=config["mha_num_attention_heads"], 77 | mha_num_attention_layers=config["mha_num_attention_layers"], 78 | ) 79 | vision_encoder = replace_bn_with_gn(vision_encoder) 80 | else: 81 | raise ValueError(f"Vision encoder {config['vision_encoder']} not supported") 82 | 83 | noise_pred_net = ConditionalUnet1D( 84 | input_dim=2, 85 | global_cond_dim=config["encoding_size"], 86 | down_dims=config["down_dims"], 87 | cond_predict_scale=config["cond_predict_scale"], 88 | ) 89 | dist_pred_network = DenseNetwork(embedding_dim=config["encoding_size"]) 90 | 91 | model = NoMaD( 92 | vision_encoder=vision_encoder, 93 | noise_pred_net=noise_pred_net, 94 | dist_pred_net=dist_pred_network, 95 | ) 96 | else: 97 | raise ValueError(f"Invalid model type: {model_type}") 98 | 99 | checkpoint = torch.load(model_path, map_location=device) 100 | if model_type == "nomad": 101 | state_dict = checkpoint 102 | model.load_state_dict(state_dict, strict=False) 103 | else: 104 | loaded_model = checkpoint["model"] 105 | try: 106 | state_dict = loaded_model.module.state_dict() 107 | model.load_state_dict(state_dict, strict=False) 108 | except AttributeError as e: 109 | state_dict = loaded_model.state_dict() 110 | 
model.load_state_dict(state_dict, strict=False) 111 | model.to(device) 112 | return model 113 | 114 | 115 | def msg_to_pil(msg: Image) -> PILImage.Image: 116 | img = np.frombuffer(msg.data, dtype=np.uint8).reshape( 117 | msg.height, msg.width, -1) 118 | img = img[:, :, :3] 119 | img = img[:, :, ::-1] 120 | pil_image = PILImage.fromarray(img) 121 | return pil_image 122 | 123 | 124 | def pil_to_msg(pil_img: PILImage.Image, encoding="mono8") -> Image: 125 | img = np.asarray(pil_img) 126 | ros_image = Image(encoding=encoding) 127 | ros_image.height, ros_image.width, _ = img.shape 128 | ros_image.data = img.ravel().tobytes() 129 | ros_image.step = ros_image.width 130 | return ros_image 131 | 132 | 133 | def to_numpy(tensor): 134 | return tensor.cpu().detach().numpy() 135 | 136 | 137 | def transform_images(pil_imgs: List[PILImage.Image], image_size: List[int], center_crop: bool = False) -> torch.Tensor: 138 | transform_type = transforms.Compose( 139 | [ 140 | transforms.ToTensor(), 141 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ 142 | 0.229, 0.224, 0.225]), 143 | ] 144 | ) 145 | if type(pil_imgs) != list: 146 | pil_imgs = [pil_imgs] 147 | transf_imgs = [] 148 | for pil_img in pil_imgs: 149 | w, h = pil_img.size 150 | if center_crop: 151 | if w > h: 152 | pil_img = TF.center_crop(pil_img, (h, int(h * IMAGE_ASPECT_RATIO))) 153 | else: 154 | pil_img = TF.center_crop(pil_img, (int(w / IMAGE_ASPECT_RATIO), w)) 155 | pil_img = pil_img.resize(image_size) 156 | transf_img = transform_type(pil_img) 157 | transf_img = torch.unsqueeze(transf_img, 0) 158 | transf_imgs.append(transf_img) 159 | return torch.cat(transf_imgs, dim=1) 160 | 161 | 162 | def clip_angle(angle): 163 | return np.mod(angle + np.pi, 2 * np.pi) - np.pi 164 | 165 | def rotate_point_by_quaternion(point, quaternion): 166 | r = R.from_quat(quaternion) 167 | rotation_matrix = r.as_matrix() 168 | 169 | point_h = np.array([point[0], point[1], point[2]]) 170 | 171 | rotated_point = np.dot(np.linalg.inv(rotation_matrix), point_h) 172 | 173 | return rotated_point 174 | -------------------------------------------------------------------------------- /deployment/src/costmap_cfg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, ETH Zurich (Robotics Systems Lab) 2 | # Author: Pascal Roth 3 | # All rights reserved. 
4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | 7 | # python 8 | import os 9 | # from dataclasses import dataclass 10 | from typing import Optional 11 | 12 | import yaml 13 | 14 | 15 | class Loader(yaml.SafeLoader): 16 | pass 17 | 18 | 19 | def construct_GeneralCostMapConfig(loader, node): 20 | return GeneralCostMapConfig(**loader.construct_mapping(node)) 21 | 22 | 23 | Loader.add_constructor( 24 | "tag:yaml.org,2002:python/object:viplanner.config.costmap_cfg.GeneralCostMapConfig", 25 | construct_GeneralCostMapConfig, 26 | ) 27 | 28 | 29 | def construct_ReconstructionCfg(loader, node): 30 | return ReconstructionCfg(**loader.construct_mapping(node)) 31 | 32 | 33 | Loader.add_constructor( 34 | "tag:yaml.org,2002:python/object:viplanner.config.costmap_cfg.ReconstructionCfg", 35 | construct_ReconstructionCfg, 36 | ) 37 | 38 | 39 | def construct_SemCostMapConfig(loader, node): 40 | return SemCostMapConfig(**loader.construct_mapping(node)) 41 | 42 | 43 | Loader.add_constructor( 44 | "tag:yaml.org,2002:python/object:viplanner.config.costmap_cfg.SemCostMapConfig", 45 | construct_SemCostMapConfig, 46 | ) 47 | 48 | 49 | def construct_TsdfCostMapConfig(loader, node): 50 | return TsdfCostMapConfig(**loader.construct_mapping(node)) 51 | 52 | 53 | Loader.add_constructor( 54 | "tag:yaml.org,2002:python/object:viplanner.config.costmap_cfg.TsdfCostMapConfig", 55 | construct_TsdfCostMapConfig, 56 | ) 57 | 58 | 59 | # @dataclass 60 | class ReconstructionCfg: 61 | """ 62 | Arguments for 3D reconstruction using depth maps 63 | """ 64 | 65 | # directory where the environment with the depth (and semantic) images is located 66 | data_dir: str = "${USER_PATH_TO_DATA}" 67 | # environment name 68 | env: str = "town01" 69 | # image suffix 70 | depth_suffix = "_cam0" 71 | sem_suffix = "_cam1" 72 | # higher resolution depth images available for reconstruction (meaning that the depth images are also taked by the semantic camera) 73 | high_res_depth: bool = False 74 | 75 | # reconstruction parameters 76 | voxel_size: float = 0.1 # [m] 0.05 for matterport 0.1 for carla 77 | start_idx: int = 0 # start index for reconstruction 78 | max_images: Optional[int] = 1000 # maximum number of images to reconstruct, if None, all images are used 79 | depth_scale: float = 1000.0 # depth scale factor 80 | # semantic reconstruction 81 | semantics: bool = False 82 | 83 | # speed vs. 
memory trade-off parameters 84 | point_cloud_batch_size: int = ( 85 | 200 # 3d points of nbr images added to point cloud at once (higher values use more memory but faster) 86 | ) 87 | 88 | """ Internal functions """ 89 | 90 | def get_data_path(self) -> str: 91 | return os.path.join(self.data_dir, self.env) 92 | 93 | def get_out_path(self) -> str: 94 | return os.path.join(self.out_dir, self.env) 95 | 96 | 97 | # @dataclass 98 | class SemCostMapConfig: 99 | """Configuration for the semantic cost map""" 100 | 101 | # point-cloud filter parameters 102 | ground_height: Optional[float] = -0.5 # None for matterport -0.5 for carla -1.0 for nomoko 103 | robot_height: float = 0.70 104 | robot_height_factor: float = 3.0 105 | nb_neighbors: int = 100 106 | std_ratio: float = 2.0 # keep high, otherwise ground will be removed 107 | downsample: bool = False 108 | # smoothing 109 | nb_neigh: int = 15 110 | change_decimal: int = 3 111 | conv_crit: float = ( 112 | 0.45 # ration of points that have to change by at least the #change_decimal decimal value to converge 113 | ) 114 | nb_tasks: Optional[int] = 10 # number of tasks for parallel processing, if None, all available cores are used 115 | sigma_smooth: float = 2.5 116 | max_iterations: int = 1 117 | # obstacle threshold (multiplied with highest loss value defined for a semantic class) 118 | obstacle_threshold: float = 0.8 # 0.5/ 0.6 for matterport, 0.8 for carla 119 | # negative reward for space with smallest cost (introduces a gradient in area with smallest loss value, steering towards center) 120 | # NOTE: at the end cost map is elevated by that amount to ensure that the smallest cost is 0 121 | negative_reward: float = 0.5 122 | # loss values rounded up to decimal #round_decimal_traversable equal to 0.0 are selected and the traversable gradient is determined based on them 123 | round_decimal_traversable: int = 2 124 | # compute height map 125 | compute_height_map: bool = False # false for matterport, true for carla and nomoko 126 | 127 | 128 | # @dataclass 129 | class TsdfCostMapConfig: 130 | """Configuration for the tsdf cost map""" 131 | 132 | # offset of the point cloud 133 | offset_z: float = 0.3 134 | # filter parameters 135 | ground_height: float = 0.3 136 | robot_height: float = 0.7 137 | robot_height_factor: float = 1.2 138 | nb_neighbors: int = 50 139 | std_ratio: float = 0.2 140 | filter_outliers: bool = True 141 | # dilation parameters 142 | sigma_expand: float = 3.0 143 | obstacle_threshold: float = 0.01 144 | free_space_threshold: float = 0.8 145 | 146 | 147 | # @dataclass 148 | class GeneralCostMapConfig: 149 | """General Cost Map Configuration""" 150 | 151 | # path to point cloud 152 | root_path: str = "town01" 153 | ply_file: str = "cloud.ply" 154 | # resolution of the cost map 155 | resolution: float = 0.04 # [m] (0.04 for matterport, 0.1 for carla) 156 | # map parameters 157 | clear_dist: float = 1.0 # cost map expansion over the point cloud space (prevent paths to go out of the map) 158 | # smoothing parameters 159 | sigma_smooth: float = 3.0 160 | # cost map expansion 161 | x_min: Optional[float] = 0 162 | # [m] if None, the minimum of the point cloud is used None (carla town01: -8.05 matterport: None) 163 | y_min: Optional[float] = -3 164 | # [m] if None, the minimum of the point cloud is used None (carla town01: -8.05 matterport: None) 165 | x_max: Optional[float] = 8 166 | # [m] if None, the maximum of the point cloud is used None (carla town01: 346.22 matterport: None) 167 | y_max: Optional[float] = 3 168 | # [m] if None, the 
maximum of the point cloud is used None (carla town01: 336.65 matterport: None) 169 | 170 | 171 | # @dataclass 172 | class CostMapConfig: 173 | """General Cost Map Configuration""" 174 | 175 | # cost map domains 176 | semantics: bool = True 177 | geometry: bool = False 178 | 179 | # name 180 | map_name: str = "cost_map_sem" 181 | 182 | # general cost map configuration 183 | general: GeneralCostMapConfig = GeneralCostMapConfig() 184 | 185 | # individual cost map configurations 186 | sem_cost_map: SemCostMapConfig = SemCostMapConfig() 187 | tsdf_cost_map: TsdfCostMapConfig = TsdfCostMapConfig() 188 | 189 | # visualize cost map 190 | visualize: bool = True 191 | 192 | # FILLED BY CODE -> DO NOT CHANGE ### 193 | x_start: float = None 194 | y_start: float = None 195 | 196 | 197 | # EoF 198 | -------------------------------------------------------------------------------- /train/vint_train/visualizing/distance_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | import numpy as np 4 | from typing import List, Optional, Tuple 5 | from vint_train.visualizing.visualize_utils import numpy_to_img 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def visualize_dist_pred( 10 | batch_obs_images: np.ndarray, 11 | batch_goal_images: np.ndarray, 12 | batch_dist_preds: np.ndarray, 13 | batch_dist_labels: np.ndarray, 14 | eval_type: str, 15 | save_folder: str, 16 | epoch: int, 17 | num_images_preds: int = 8, 18 | use_wandb: bool = True, 19 | display: bool = False, 20 | rounding: int = 4, 21 | dist_error_threshold: float = 3.0, 22 | ): 23 | """ 24 | Visualize the distance classification predictions and labels for an observation-goal image pair. 25 | 26 | Args: 27 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels] 28 | batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels] 29 | batch_dist_preds (np.ndarray): batch of distance predictions [batch_size] 30 | batch_dist_labels (np.ndarray): batch of distance labels [batch_size] 31 | eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.) 32 | epoch (int): current epoch number 33 | num_images_preds (int): number of images to visualize 34 | use_wandb (bool): whether to use wandb to log the images 35 | save_folder (str): folder to save the images. 
If None, will not save the images 36 | display (bool): whether to display the images 37 | rounding (int): number of decimal places to round the distance predictions and labels 38 | dist_error_threshold (float): distance error threshold for classifying the distance prediction as correct or incorrect (only used for visualization purposes) 39 | """ 40 | visualize_path = os.path.join( 41 | save_folder, 42 | "visualize", 43 | eval_type, 44 | f"epoch{epoch}", 45 | "dist_classification", 46 | ) 47 | if not os.path.isdir(visualize_path): 48 | os.makedirs(visualize_path) 49 | assert ( 50 | len(batch_obs_images) 51 | == len(batch_goal_images) 52 | == len(batch_dist_preds) 53 | == len(batch_dist_labels) 54 | ) 55 | batch_size = batch_obs_images.shape[0] 56 | wandb_list = [] 57 | for i in range(min(batch_size, num_images_preds)): 58 | dist_pred = np.round(batch_dist_preds[i], rounding) 59 | dist_label = np.round(batch_dist_labels[i], rounding) 60 | obs_image = numpy_to_img(batch_obs_images[i]) 61 | goal_image = numpy_to_img(batch_goal_images[i]) 62 | 63 | save_path = None 64 | if save_folder is not None: 65 | save_path = os.path.join(visualize_path, f"{i}.png") 66 | text_color = "black" 67 | if abs(dist_pred - dist_label) > dist_error_threshold: 68 | text_color = "red" 69 | 70 | display_distance_pred( 71 | [obs_image, goal_image], 72 | ["Observation", "Goal"], 73 | dist_pred, 74 | dist_label, 75 | text_color, 76 | save_path, 77 | display, 78 | ) 79 | if use_wandb: 80 | wandb_list.append(wandb.Image(save_path)) 81 | if use_wandb: 82 | wandb.log({f"{eval_type}_dist_prediction": wandb_list}, commit=False) 83 | 84 | 85 | def visualize_dist_pairwise_pred( 86 | batch_obs_images: np.ndarray, 87 | batch_close_images: np.ndarray, 88 | batch_far_images: np.ndarray, 89 | batch_close_preds: np.ndarray, 90 | batch_far_preds: np.ndarray, 91 | batch_close_labels: np.ndarray, 92 | batch_far_labels: np.ndarray, 93 | eval_type: str, 94 | save_folder: str, 95 | epoch: int, 96 | num_images_preds: int = 8, 97 | use_wandb: bool = True, 98 | display: bool = False, 99 | rounding: int = 4, 100 | ): 101 | """ 102 | Visualize the distance classification predictions and labels for an observation-goal image pair. 103 | 104 | Args: 105 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels] 106 | batch_close_images (np.ndarray): batch of close goal images [batch_size, height, width, channels] 107 | batch_far_images (np.ndarray): batch of far goal images [batch_size, height, width, channels] 108 | batch_close_preds (np.ndarray): batch of close predictions [batch_size] 109 | batch_far_preds (np.ndarray): batch of far predictions [batch_size] 110 | batch_close_labels (np.ndarray): batch of close labels [batch_size] 111 | batch_far_labels (np.ndarray): batch of far labels [batch_size] 112 | eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.) 113 | save_folder (str): folder to save the images. 
If None, will not save the images 114 | epoch (int): current epoch number 115 | num_images_preds (int): number of images to visualize 116 | use_wandb (bool): whether to use wandb to log the images 117 | display (bool): whether to display the images 118 | rounding (int): number of decimal places to round the distance predictions and labels 119 | """ 120 | visualize_path = os.path.join( 121 | save_folder, 122 | "visualize", 123 | eval_type, 124 | f"epoch{epoch}", 125 | "pairwise_dist_classification", 126 | ) 127 | if not os.path.isdir(visualize_path): 128 | os.makedirs(visualize_path) 129 | assert ( 130 | len(batch_obs_images) 131 | == len(batch_close_images) 132 | == len(batch_far_images) 133 | == len(batch_close_preds) 134 | == len(batch_far_preds) 135 | == len(batch_close_labels) 136 | == len(batch_far_labels) 137 | ) 138 | batch_size = batch_obs_images.shape[0] 139 | wandb_list = [] 140 | for i in range(min(batch_size, num_images_preds)): 141 | close_dist_pred = np.round(batch_close_preds[i], rounding) 142 | far_dist_pred = np.round(batch_far_preds[i], rounding) 143 | close_dist_label = np.round(batch_close_labels[i], rounding) 144 | far_dist_label = np.round(batch_far_labels[i], rounding) 145 | obs_image = numpy_to_img(batch_obs_images[i]) 146 | close_image = numpy_to_img(batch_close_images[i]) 147 | far_image = numpy_to_img(batch_far_images[i]) 148 | 149 | save_path = None 150 | if save_folder is not None: 151 | save_path = os.path.join(visualize_path, f"{i}.png") 152 | 153 | if close_dist_pred < far_dist_pred: 154 | text_color = "black" 155 | else: 156 | text_color = "red" 157 | 158 | display_distance_pred( 159 | [obs_image, close_image, far_image], 160 | ["Observation", "Close Goal", "Far Goal"], 161 | f"close_pred = {close_dist_pred}, far_pred = {far_dist_pred}", 162 | f"close_label = {close_dist_label}, far_label = {far_dist_label}", 163 | text_color, 164 | save_path, 165 | display, 166 | ) 167 | if use_wandb: 168 | wandb_list.append(wandb.Image(save_path)) 169 | if use_wandb: 170 | wandb.log({f"{eval_type}_pairwise_classification": wandb_list}, commit=False) 171 | 172 | 173 | def display_distance_pred( 174 | imgs: list, 175 | titles: list, 176 | dist_pred: float, 177 | dist_label: float, 178 | text_color: str = "black", 179 | save_path: Optional[str] = None, 180 | display: bool = False, 181 | ): 182 | plt.figure() 183 | fig, ax = plt.subplots(1, len(imgs)) 184 | 185 | plt.suptitle(f"prediction: {dist_pred}\nlabel: {dist_label}", color=text_color) 186 | 187 | for axis, img, title in zip(ax, imgs, titles): 188 | axis.imshow(img) 189 | axis.set_title(title) 190 | axis.xaxis.set_visible(False) 191 | axis.yaxis.set_visible(False) 192 | 193 | # make the plot large 194 | fig.set_size_inches((18.5 / 3) * len(imgs), 10.5) 195 | 196 | if save_path is not None: 197 | fig.savefig( 198 | save_path, 199 | bbox_inches="tight", 200 | ) 201 | if not display: 202 | plt.close(fig) 203 | -------------------------------------------------------------------------------- /deployment/src/cost_to_pcd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, ETH Zurich (Robotics Systems Lab) 2 | # Author: Pascal Roth 3 | # All rights reserved. 
4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | 7 | import argparse 8 | import os 9 | from typing import Optional, Union 10 | 11 | import numpy as np 12 | import open3d as o3d 13 | import pypose as pp 14 | import torch 15 | import yaml 16 | 17 | # viplanner 18 | from costmap_cfg import CostMapConfig, Loader 19 | 20 | torch.set_default_dtype(torch.float32) 21 | 22 | 23 | class CostMapPCD: 24 | def __init__( 25 | self, 26 | cfg: CostMapConfig, 27 | tsdf_array: np.ndarray, 28 | viz_points: np.ndarray, 29 | ground_array: np.ndarray, 30 | gpu_id: Optional[int] = 0, 31 | load_from_file: Optional[bool] = False, 32 | ): 33 | # determine device 34 | if torch.cuda.is_available() and gpu_id is not None: 35 | self.device = torch.device("cuda:" + str(gpu_id)) 36 | else: 37 | self.device = torch.device("cpu") 38 | 39 | # args 40 | self.cfg: CostMapConfig = cfg 41 | self.load_from_file: bool = load_from_file 42 | self.tsdf_array: torch.Tensor = torch.tensor(tsdf_array, device=self.device) 43 | self.viz_points: np.ndarray = viz_points 44 | self.ground_array: torch.Tensor = torch.tensor(ground_array, device=self.device) 45 | 46 | # init flag 47 | self.map_init = False 48 | 49 | # init pointclouds 50 | self.pcd_tsdf = o3d.geometry.PointCloud() 51 | self.pcd_viz = o3d.geometry.PointCloud() 52 | 53 | # execute setup 54 | self.num_x: int = 0 55 | self.num_y: int = 0 56 | self.setup() 57 | return 58 | 59 | def setup(self): 60 | # expand of cost map 61 | self.num_x, self.num_y = self.tsdf_array.shape 62 | # visualization points 63 | self.pcd_viz.points = o3d.utility.Vector3dVector(self.viz_points) 64 | # set cost map 65 | self.SetUpCostArray() 66 | # update pcd instance 67 | xv, yv = np.meshgrid( 68 | np.linspace(0, self.num_x * self.cfg.general.resolution, self.num_x), 69 | np.linspace(0, self.num_y * self.cfg.general.resolution, self.num_y), 70 | indexing="ij", 71 | ) 72 | T = np.concatenate((np.expand_dims(xv, axis=0), np.expand_dims(yv, axis=0)), axis=0) 73 | T = np.concatenate( 74 | ( 75 | T, 76 | np.expand_dims(self.cost_array.cpu().detach().numpy(), axis=0), 77 | ), 78 | axis=0, 79 | ) 80 | if self.load_from_file: 81 | wps = T.reshape(3, -1).T + np.array([self.cfg.x_start, self.cfg.y_start, 0.0]) 82 | self.pcd_tsdf.points = o3d.utility.Vector3dVector(wps) 83 | else: 84 | self.pcd_tsdf.points = o3d.utility.Vector3dVector(T.reshape(3, -1).T) 85 | 86 | self.map_init = True 87 | return 88 | 89 | def ShowTSDFMap(self, cost_map=True): # not run with cuda 90 | if not self.map_init: 91 | print("Error: cannot show map, map has not been init yet!") 92 | return 93 | if cost_map: 94 | o3d.visualization.draw_geometries([self.pcd_tsdf]) 95 | else: 96 | o3d.visualization.draw_geometries([self.pcd_viz]) 97 | return 98 | 99 | def Pos2Ind(self, points: Union[torch.Tensor, pp.LieTensor]): 100 | # points [torch shapes [num_p, 3]] 101 | start_xy = torch.tensor( 102 | [self.cfg.x_start, self.cfg.y_start], 103 | dtype=torch.float64, 104 | device=points.device, 105 | ).expand(1, 1, -1) 106 | if isinstance(points, pp.LieTensor): 107 | H = (points.tensor()[:, :, 0:2] - start_xy) / self.cfg.general.resolution 108 | else: 109 | H = (points[:, :, 0:2] - start_xy) / self.cfg.general.resolution 110 | mask = torch.logical_and( 111 | (H > 0).all(axis=2), 112 | (H < torch.tensor([self.num_x, self.num_y], device=points.device)[None, None, :]).all(axis=2), 113 | ) 114 | return self.NormInds(H), H[mask, :] 115 | 116 | def NormInds(self, H): 117 | norm_matrix = torch.tensor( 118 | [self.num_x / 2.0, self.num_y / 2.0], 119 | 
dtype=torch.float64, 120 | device=H.device, 121 | ) 122 | H = (H - norm_matrix) / norm_matrix 123 | return H 124 | 125 | def DeNormInds(self, NH): 126 | norm_matrix = torch.tensor( 127 | [self.num_x / 2.0, self.num_y / 2.0], 128 | dtype=torch.float64, 129 | device=NH.device, 130 | ) 131 | NH = NH * norm_matrix + norm_matrix 132 | return NH 133 | 134 | def SaveTSDFMap(self): 135 | if not self.map_init: 136 | print("Error: map has not been init yet!") 137 | return 138 | 139 | # make directories 140 | os.makedirs( 141 | os.path.join(self.cfg.general.root_path, "maps", "data"), 142 | exist_ok=True, 143 | ) 144 | os.makedirs( 145 | os.path.join(self.cfg.general.root_path, "maps", "cloud"), 146 | exist_ok=True, 147 | ) 148 | os.makedirs( 149 | os.path.join(self.cfg.general.root_path, "maps", "params"), 150 | exist_ok=True, 151 | ) 152 | 153 | map_path = os.path.join( 154 | self.cfg.general.root_path, 155 | "maps", 156 | "data", 157 | self.cfg.map_name + "_map.txt", 158 | ) 159 | ground_path = os.path.join( 160 | self.cfg.general.root_path, 161 | "maps", 162 | "data", 163 | self.cfg.map_name + "_ground.txt", 164 | ) 165 | cloud_path = os.path.join( 166 | self.cfg.general.root_path, 167 | "maps", 168 | "cloud", 169 | self.cfg.map_name + "_cloud.txt", 170 | ) 171 | # save data 172 | np.savetxt(map_path, self.tsdf_array.cpu()) 173 | np.savetxt(ground_path, self.ground_array.cpu()) 174 | np.savetxt(cloud_path, self.viz_points) 175 | # save config parameters 176 | yaml_path = os.path.join( 177 | self.cfg.general.root_path, 178 | "maps", 179 | "params", 180 | f"config_{self.cfg.map_name}.yaml", 181 | ) 182 | with open(yaml_path, "w+") as file: 183 | yaml.dump( 184 | vars(self.cfg), 185 | file, 186 | allow_unicode=True, 187 | default_flow_style=False, 188 | ) 189 | 190 | print("TSDF Map saved.") 191 | return 192 | 193 | def SetUpCostArray(self): 194 | self.cost_array = self.tsdf_array 195 | return 196 | 197 | @classmethod 198 | def ReadTSDFMap(cls, root_path: str, map_name: str, gpu_id: Optional[int] = None): 199 | # read config 200 | with open(os.path.join(root_path, "maps", "params", f"config_{map_name}.yaml")) as f: 201 | cfg: CostMapConfig = CostMapConfig(**yaml.load(f, Loader)) 202 | 203 | # load data 204 | tsdf_array = np.loadtxt(os.path.join(root_path, "maps", "data", map_name + "_map.txt")) 205 | viz_points = np.loadtxt(os.path.join(root_path, "maps", "cloud", map_name + "_cloud.txt")) 206 | ground_array = np.loadtxt(os.path.join(root_path, "maps", "data", map_name + "_ground.txt")) 207 | 208 | return cls( 209 | cfg=cfg, 210 | tsdf_array=tsdf_array, 211 | viz_points=viz_points, 212 | ground_array=ground_array, 213 | gpu_id=gpu_id, 214 | load_from_file=True, 215 | ) 216 | 217 | 218 | if __name__ == "__main__": 219 | # parse environment directory and cost_map name 220 | parser = argparse.ArgumentParser(prog="Show Costmap", description="Show Costmap") 221 | parser.add_argument( 222 | "-e", 223 | "--env", 224 | type=str, 225 | help="path to the environment directory", 226 | required=True, 227 | ) 228 | parser.add_argument("-m", "--map", type=str, help="name of the cost_map", required=True) 229 | args = parser.parse_args() 230 | 231 | # show costmap 232 | map = CostMapPCD.ReadTSDFMap(args.env, args.map) 233 | map.ShowTSDFMap() 234 | 235 | # EoF 236 | -------------------------------------------------------------------------------- /train/vint_train/models/nomad/nomad_vint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 
import torch.nn.functional as F 4 | import torchvision 5 | from typing import List, Dict, Optional, Tuple, Callable 6 | from efficientnet_pytorch import EfficientNet 7 | from vint_train.models.vint.self_attention import PositionalEncoding 8 | 9 | class NoMaD_ViNT(nn.Module): 10 | def __init__( 11 | self, 12 | context_size: int = 5, 13 | obs_encoder: Optional[str] = "efficientnet-b0", 14 | obs_encoding_size: Optional[int] = 512, 15 | mha_num_attention_heads: Optional[int] = 2, 16 | mha_num_attention_layers: Optional[int] = 2, 17 | mha_ff_dim_factor: Optional[int] = 4, 18 | ) -> None: 19 | """ 20 | NoMaD ViNT Encoder class 21 | """ 22 | super().__init__() 23 | self.obs_encoding_size = obs_encoding_size 24 | self.goal_encoding_size = obs_encoding_size 25 | self.context_size = context_size 26 | 27 | # Initialize the observation encoder 28 | if obs_encoder.split("-")[0] == "efficientnet": 29 | self.obs_encoder = EfficientNet.from_name(obs_encoder, in_channels=3) # context 30 | self.obs_encoder = replace_bn_with_gn(self.obs_encoder) 31 | self.num_obs_features = self.obs_encoder._fc.in_features 32 | self.obs_encoder_type = "efficientnet" 33 | else: 34 | raise NotImplementedError 35 | 36 | # Initialize the goal encoder 37 | self.goal_encoder = EfficientNet.from_name("efficientnet-b0", in_channels=6) # obs+goal 38 | self.goal_encoder = replace_bn_with_gn(self.goal_encoder) 39 | self.num_goal_features = self.goal_encoder._fc.in_features 40 | 41 | # Initialize compression layers if necessary 42 | if self.num_obs_features != self.obs_encoding_size: 43 | self.compress_obs_enc = nn.Linear(self.num_obs_features, self.obs_encoding_size) 44 | else: 45 | self.compress_obs_enc = nn.Identity() 46 | 47 | if self.num_goal_features != self.goal_encoding_size: 48 | self.compress_goal_enc = nn.Linear(self.num_goal_features, self.goal_encoding_size) 49 | else: 50 | self.compress_goal_enc = nn.Identity() 51 | 52 | # Initialize positional encoding and self-attention layers 53 | self.positional_encoding = PositionalEncoding(self.obs_encoding_size, max_seq_len=self.context_size + 2) 54 | self.sa_layer = nn.TransformerEncoderLayer( 55 | d_model=self.obs_encoding_size, 56 | nhead=mha_num_attention_heads, 57 | dim_feedforward=mha_ff_dim_factor*self.obs_encoding_size, 58 | activation="gelu", 59 | batch_first=True, 60 | norm_first=True 61 | ) 62 | self.sa_encoder = nn.TransformerEncoder(self.sa_layer, num_layers=mha_num_attention_layers) 63 | 64 | # Definition of the goal mask (convention: 0 = no mask, 1 = mask) 65 | self.goal_mask = torch.zeros((1, self.context_size + 2), dtype=torch.bool) 66 | self.goal_mask[:, -1] = True # Mask out the goal 67 | self.no_mask = torch.zeros((1, self.context_size + 2), dtype=torch.bool) 68 | self.all_masks = torch.cat([self.no_mask, self.goal_mask], dim=0) 69 | self.avg_pool_mask = torch.cat([1 - self.no_mask.float(), (1 - self.goal_mask.float()) * ((self.context_size + 2)/(self.context_size + 1))], dim=0) 70 | 71 | 72 | def forward(self, obs_img: torch.tensor, goal_img: torch.tensor, input_goal_mask: torch.tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: 73 | 74 | device = obs_img.device 75 | 76 | # Initialize the goal encoding 77 | goal_encoding = torch.zeros((obs_img.size()[0], 1, self.goal_encoding_size)).to(device) 78 | 79 | # Get the input goal mask 80 | if input_goal_mask is not None: 81 | goal_mask = input_goal_mask.to(device) 82 | 83 | # Get the goal encoding 84 | obsgoal_img = torch.cat([obs_img[:, 3*self.context_size:, :, :], goal_img], dim=1) # concatenate the obs 
image/context and goal image --> non image goal? 85 | obsgoal_encoding = self.goal_encoder.extract_features(obsgoal_img) # get encoding of this img 86 | obsgoal_encoding = self.goal_encoder._avg_pooling(obsgoal_encoding) # avg pooling 87 | 88 | if self.goal_encoder._global_params.include_top: 89 | obsgoal_encoding = obsgoal_encoding.flatten(start_dim=1) 90 | obsgoal_encoding = self.goal_encoder._dropout(obsgoal_encoding) 91 | obsgoal_encoding = self.compress_goal_enc(obsgoal_encoding) 92 | 93 | if len(obsgoal_encoding.shape) == 2: 94 | obsgoal_encoding = obsgoal_encoding.unsqueeze(1) 95 | assert obsgoal_encoding.shape[2] == self.goal_encoding_size 96 | goal_encoding = obsgoal_encoding 97 | 98 | # Get the observation encoding 99 | obs_img = torch.split(obs_img, 3, dim=1) 100 | obs_img = torch.concat(obs_img, dim=0) 101 | 102 | obs_encoding = self.obs_encoder.extract_features(obs_img) 103 | obs_encoding = self.obs_encoder._avg_pooling(obs_encoding) 104 | if self.obs_encoder._global_params.include_top: 105 | obs_encoding = obs_encoding.flatten(start_dim=1) 106 | obs_encoding = self.obs_encoder._dropout(obs_encoding) 107 | obs_encoding = self.compress_obs_enc(obs_encoding) 108 | obs_encoding = obs_encoding.unsqueeze(1) 109 | obs_encoding = obs_encoding.reshape((self.context_size+1, -1, self.obs_encoding_size)) 110 | obs_encoding = torch.transpose(obs_encoding, 0, 1) 111 | obs_encoding = torch.cat((obs_encoding, goal_encoding), dim=1) 112 | 113 | # If a goal mask is provided, mask some of the goal tokens 114 | if goal_mask is not None: 115 | no_goal_mask = goal_mask.long() 116 | src_key_padding_mask = torch.index_select(self.all_masks.to(device), 0, no_goal_mask) 117 | else: 118 | src_key_padding_mask = None 119 | 120 | # Apply positional encoding 121 | if self.positional_encoding: 122 | obs_encoding = self.positional_encoding(obs_encoding) 123 | 124 | obs_encoding_tokens = self.sa_encoder(obs_encoding, src_key_padding_mask=src_key_padding_mask) 125 | if src_key_padding_mask is not None: 126 | avg_mask = torch.index_select(self.avg_pool_mask.to(device), 0, no_goal_mask).unsqueeze(-1) 127 | obs_encoding_tokens = obs_encoding_tokens * avg_mask 128 | obs_encoding_tokens = torch.mean(obs_encoding_tokens, dim=1) 129 | 130 | return obs_encoding_tokens 131 | 132 | 133 | 134 | # Utils for Group Norm 135 | def replace_bn_with_gn( 136 | root_module: nn.Module, 137 | features_per_group: int=16) -> nn.Module: 138 | """ 139 | Relace all BatchNorm layers with GroupNorm. 140 | """ 141 | replace_submodules( 142 | root_module=root_module, 143 | predicate=lambda x: isinstance(x, nn.BatchNorm2d), 144 | func=lambda x: nn.GroupNorm( 145 | num_groups=x.num_features//features_per_group, 146 | num_channels=x.num_features) 147 | ) 148 | return root_module 149 | 150 | 151 | def replace_submodules( 152 | root_module: nn.Module, 153 | predicate: Callable[[nn.Module], bool], 154 | func: Callable[[nn.Module], nn.Module]) -> nn.Module: 155 | """ 156 | Replace all submodules selected by the predicate with 157 | the output of func. 158 | 159 | predicate: Return true if the module is to be replaced. 160 | func: Return new module to use. 
161 | """ 162 | if predicate(root_module): 163 | return func(root_module) 164 | 165 | bn_list = [k.split('.') for k, m 166 | in root_module.named_modules(remove_duplicate=True) 167 | if predicate(m)] 168 | for *parent, k in bn_list: 169 | parent_module = root_module 170 | if len(parent) > 0: 171 | parent_module = root_module.get_submodule('.'.join(parent)) 172 | if isinstance(parent_module, nn.Sequential): 173 | src_module = parent_module[int(k)] 174 | else: 175 | src_module = getattr(parent_module, k) 176 | tgt_module = func(src_module) 177 | if isinstance(parent_module, nn.Sequential): 178 | parent_module[int(k)] = tgt_module 179 | else: 180 | setattr(parent_module, k, tgt_module) 181 | # verify that all modules are replaced 182 | bn_list = [k.split('.') for k, m 183 | in root_module.named_modules(remove_duplicate=True) 184 | if predicate(m)] 185 | assert len(bn_list) == 0 186 | return root_module 187 | 188 | 189 | 190 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NaviDiffusor: Cost-Guided Diffusion Model for Visual Navigation 2 | 3 | > 🏆 Accepted at **ICRA 2025** 4 | > 🔗 [arXiv](https://arxiv.org/abs/2504.10003) | [Bilibili](https://www.bilibili.com/video/BV1PaLizwEkW/) | [Youtube](https://www.youtube.com/watch?v=94ODPEqyP0s) 5 | 6 |

 7 | ![Overview](assets/pipeline.png) 8 |

9 | 10 | --- 11 | 12 | ## ✅ TODO List 13 | 14 | - [ ] Training code updates 15 | - [ ] Simulation Envs 16 | 17 | ## ⚙️ Setup 18 | Run the commands below inside the project directory: 19 | 1. Set up the conda environment: 20 | ```bash 21 | conda env create -f train/train_environment.yml 22 | ``` 23 | 2. Source the conda environment: 24 | ``` 25 | conda activate navidiffusor 26 | ``` 27 | 3. Install the vint_train packages: 28 | ```bash 29 | pip install -e train/ 30 | ``` 31 | 4. Install the `diffusion_policy` package from this [repo](https://github.com/real-stanford/diffusion_policy): 32 | ```bash 33 | git clone git@github.com:real-stanford/diffusion_policy.git 34 | pip install -e diffusion_policy/ 35 | ``` 36 | 5. Install the `depth_anything_v2` package from this [repo](https://github.com/DepthAnything/Depth-Anything-V2): 37 | ```bash 38 | git clone https://github.com/DepthAnything/Depth-Anything-V2.git 39 | pip install -r Depth-Anything-V2/requirements.txt 40 | ``` 41 | 42 | ## Data 43 | - [RECON](https://sites.google.com/view/recon-robot/dataset) 44 | - [SCAND](https://www.cs.utexas.edu/~xiao/SCAND/SCAND.html#Links) 45 | - [GoStanford2 (Modified)](https://drive.google.com/drive/folders/1xrNvMl5q92oWed99noOt_UhqQnceJYV0?usp=share_link) 46 | - [SACSoN/HuRoN](https://sites.google.com/view/sacson-review/huron-dataset) 47 | 48 | We recommend downloading these (and any other datasets you may want to train on) and running the processing steps below. 49 | 50 | ### Data Processing 51 | 52 | We provide some sample scripts to process these datasets, either directly from a rosbag or from a custom format like HDF5s: 53 | 1. Run `process_bags.py` with the relevant args, or `process_recon.py` for processing RECON HDF5s. You can also manually add your own dataset by following our structure below. 54 | 2. Run `data_split.py` on your dataset folder with the relevant args. 55 | 3. Expected structure: 56 | 57 | ``` 58 | ├── <dataset_name> 59 | │ ├── <name_of_traj1> 60 | │ │ ├── 0.jpg 61 | │ │ ├── 1.jpg 62 | │ │ ├── ... 63 | │ │ ├── T_1.jpg 64 | │ │ └── traj_data.pkl 65 | │ ├── <name_of_traj2> 66 | │ │ ├── 0.jpg 67 | │ │ ├── 1.jpg 68 | │ │ ├── ... 69 | │ │ ├── T_2.jpg 70 | │ │ └── traj_data.pkl 71 | │ ... 72 | └── └── <name_of_trajN> 73 | ├── 0.jpg 74 | ├── 1.jpg 75 | ├── ... 76 | ├── T_N.jpg 77 | └── traj_data.pkl 78 | ``` 79 | 80 | Each `*.jpg` file contains a forward-facing RGB observation from the robot, and they are temporally labeled. The `traj_data.pkl` file is the odometry data for the trajectory. It’s a pickled dictionary with the following keys (see the example below): 81 | - `"position"`: An np.ndarray [T, 2] of the xy-coordinates of the robot at each image observation. 82 | - `"yaw"`: An np.ndarray [T,] of the yaws of the robot at each image observation.
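If you add your own dataset manually, the snippet below is a minimal sketch (not one of the provided scripts) of writing a `traj_data.pkl` in this format; the trajectory length `T` and the zero-valued odometry are placeholders you would replace with your own data. The provided processing scripts write this file for you for supported datasets.

```python
import pickle
import numpy as np

T = 100  # number of image observations in this trajectory (placeholder value)
traj_data = {
    "position": np.zeros((T, 2)),  # xy-coordinates of the robot at each observation
    "yaw": np.zeros((T,)),         # yaw of the robot at each observation
}
with open("traj_data.pkl", "wb") as f:
    pickle.dump(traj_data, f)
```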
83 | 84 | 85 | After step 2 of data processing, the processed data-split should have the following structure inside `/train/vint_train/data/data_splits/`: 86 | 87 | ``` 88 | ├── <dataset_name> 89 | │ ├── train 90 | | | └── traj_names.txt 91 | └── └── test 92 | └── traj_names.txt 93 | ``` 94 | 95 | ## Model Training 96 | ```bash 97 | cd /train 98 | python train.py -c <path_of_train_config_file> 99 | ``` 100 | The config yaml files are in the `train/config` directory. 101 | 102 | ## Deployment 103 | 104 | 133 | 134 | 135 | ### Inference with Guidance 136 | 🚀 **Our method is designed to provide guidance for any diffusion-based navigation model during inference, improving path generation quality for both PointGoal and ImageGoal tasks. Here, we use [NoMaD](https://github.com/robodhruv/visualnav-transformer) as an example; an adaptable implementation in [guide.py](./deployment/src/guide.py) is provided for integrating with your own diffusion model.** 137 | 138 | ```bash 139 | cd deployment/src/ 140 | sh ./navigate.sh --model <model_name> --dir <topomap_dir> --point-goal False # set --point-goal=True for PointGoal navigation, False for ImageGoal 141 | ``` 142 | 143 | The `<model_name>` is the name of the model in the `/deployment/config/models.yaml` file. In this file, you specify these parameters for each model (defaults used): 144 | - `config_path` (str): path of the *.yaml file in `/train/config/` used to train the model 145 | - `ckpt_path` (str): path of the *.pth file in `/deployment/model_weights/` 146 | 147 | 148 | Make sure these configurations match what you used to train the model. The configurations for the models we provide weights for are included in the yaml file for your reference. 149 | 150 | The `<topomap_dir>` is the name of the directory in `/deployment/topomaps/images` that has the images corresponding to the nodes in the topological map. The images are ordered by name from 0 to N. 151 | 152 | This command opens up 4 windows: 153 | 154 | 1. `roslaunch vint_locobot.launch`: This launch file opens the usb_cam node for the camera, the joy node for the joystick, and several nodes for the robot’s mobile base. 155 | 2. `python navigate.py --model <model_name> --dir <topomap_dir>`: This python script starts a node that reads in image observations from the `/usb_cam/image_raw` topic, inputs the observations and the map into the model, and publishes actions to the `/waypoint` topic. 156 | 3. `python joy_teleop.py`: This python script starts a node that reads inputs from the joy topic and outputs them on topics that teleoperate the robot’s base. 157 | 4. `python pd_controller.py`: This python script starts a node that reads messages from the `/waypoint` topic (waypoints from the model) and outputs velocities to navigate the robot’s base. 158 | 159 | When the robot has finished navigating, kill the `pd_controller.py` script, and then kill the tmux session. If you want to take control of the robot while it is navigating, the `joy_teleop.py` script allows you to do so with the joystick. 160 | 161 | ## Citing 162 | ``` 163 | @article{zeng2025navidiffusor, 164 | title={NaviDiffusor: Cost-Guided Diffusion Model for Visual Navigation}, 165 | author={Zeng, Yiming and Ren, Hao and Wang, Shuhang and Huang, Junlong and Cheng, Hui}, 166 | journal={arXiv preprint arXiv:2504.10003}, 167 | year={2025} 168 | } 169 | ``` 170 | ## Acknowledgment 171 | NaviDiffusor is inspired by the contributions of the following works to the open-source community: [NoMaD](https://github.com/robodhruv/visualnav-transformer), [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2) and [ViPlanner](https://github.com/leggedrobotics/viplanner). We thank the authors for sharing their outstanding work. 172 | -------------------------------------------------------------------------------- /deployment/src/tsdf_cost_map.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024, ETH Zurich (Robotics Systems Lab) 2 | # Author: Pascal Roth 3 | # All rights reserved.
4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | 7 | import math 8 | 9 | # python 10 | import os 11 | 12 | import torch 13 | import numpy as np 14 | import open3d as o3d 15 | from scipy import ndimage 16 | from scipy.ndimage import gaussian_filter, binary_dilation, binary_erosion 17 | 18 | # imperative-cost-map 19 | from costmap_cfg import GeneralCostMapConfig, TsdfCostMapConfig 20 | 21 | 22 | class TsdfCostMap: 23 | """ 24 | Cost Map based on geometric information 25 | """ 26 | 27 | def __init__(self, cfg_general: GeneralCostMapConfig, cfg_tsdf: TsdfCostMapConfig): 28 | self._cfg_general = cfg_general 29 | self._cfg_tsdf = cfg_tsdf 30 | # set init flag 31 | self.is_map_ready = False 32 | # init point clouds 33 | self.obs_pcd = o3d.geometry.PointCloud() 34 | self.free_pcd = o3d.geometry.PointCloud() 35 | return 36 | 37 | def Pos2Ind(self, points): 38 | start_xy = torch.tensor([self.start_x, self.start_y], dtype=torch.float64, device=points.device).expand(1, 1, -1) 39 | H = (points - start_xy) / self._cfg_general.resolution 40 | mask = torch.logical_and((H > 0).all(axis=2), (H < torch.tensor([self.num_x, self.num_y], device=points.device)[None,None,:]).all(axis=2)) 41 | return self.NormInds(H), H 42 | 43 | def NormInds(self, H): 44 | norm_matrix = torch.tensor([self.num_x/2.0, self.num_y/2.0], dtype=torch.float64, device=H.device) 45 | H = (H - norm_matrix) / norm_matrix 46 | return H 47 | 48 | def UpdatePCDwithPs(self, P_obs, P_free, is_downsample=False): 49 | self.obs_pcd.points = o3d.utility.Vector3dVector(P_obs) 50 | self.free_pcd.points = o3d.utility.Vector3dVector(P_free) 51 | if is_downsample: 52 | self.obs_pcd = self.obs_pcd.voxel_down_sample(self._cfg_general.resolution) 53 | self.free_pcd = self.free_pcd.voxel_down_sample(self._cfg_general.resolution * 0.85) 54 | 55 | self.obs_points = np.asarray(self.obs_pcd.points) 56 | self.free_points = np.asarray(self.free_pcd.points) 57 | print("number of obs points: %d, free points: %d" % (self.obs_points.shape[0], self.free_points.shape[0])) 58 | 59 | def ReadPointFromFile(self): 60 | pcd_load = o3d.io.read_point_cloud(os.path.join(self._cfg_general.root_path, self._cfg_general.ply_file)) 61 | obs_p, free_p = self.TerrainAnalysis(np.asarray(pcd_load.points)) 62 | self.UpdatePCDwithPs(obs_p, free_p, is_downsample=True) 63 | if self._cfg_tsdf.filter_outliers: 64 | obs_p = self.FilterCloud(self.obs_points) 65 | free_p = self.FilterCloud(self.free_points, outlier_filter=False) 66 | self.UpdatePCDwithPs(obs_p, free_p) 67 | self.UpdateMapParams() 68 | return 69 | 70 | def LoadPointCloud(self, pcd): 71 | obs_p, free_p = self.TerrainAnalysis(np.asarray(pcd.points)) 72 | self.UpdatePCDwithPs(obs_p, free_p, is_downsample=True) 73 | if self._cfg_tsdf.filter_outliers: 74 | obs_p = self.FilterCloud(self.obs_points, outlier_filter=False) 75 | free_p = self.FilterCloud(self.free_points, outlier_filter=False) 76 | self.UpdatePCDwithPs(obs_p, free_p) 77 | self.UpdateMapParams() 78 | return 79 | 80 | def TerrainAnalysis(self, input_points): 81 | obs_points = np.zeros(input_points.shape) 82 | free_poins = np.zeros(input_points.shape) 83 | obs_idx = 0 84 | free_idx = 0 85 | # naive approach with z values 86 | for p in input_points: 87 | p_height = p[2] + self._cfg_tsdf.offset_z 88 | if (p_height > self._cfg_tsdf.ground_height * 1.0) and ( 89 | p_height < self._cfg_tsdf.robot_height * self._cfg_tsdf.robot_height_factor 90 | ): # remove ground and ceiling 91 | obs_points[obs_idx, :] = p 92 | obs_idx = obs_idx + 1 93 | elif p_height < 
self._cfg_tsdf.ground_height: 94 | free_poins[free_idx, :] = p 95 | free_idx = free_idx + 1 96 | return obs_points[:obs_idx, :], free_poins[:free_idx, :] 97 | 98 | def UpdateMapParams(self): 99 | if self.obs_points.shape[0] == 0: 100 | print("No points received.") 101 | return 102 | max_x, max_y, _ = np.amax(self.obs_points, axis=0) + self._cfg_general.clear_dist 103 | min_x, min_y, _ = np.amin(self.obs_points, axis=0) - self._cfg_general.clear_dist 104 | 105 | self.num_x = np.ceil((max_x - min_x) / self._cfg_general.resolution / 10).astype(int) * 10 106 | self.num_y = np.ceil((max_y - min_y) / self._cfg_general.resolution / 10).astype(int) * 10 107 | self.start_x = (max_x + min_x) / 2.0 - self.num_x / 2.0 * self._cfg_general.resolution 108 | self.start_y = (max_y + min_y) / 2.0 - self.num_y / 2.0 * self._cfg_general.resolution 109 | 110 | print("tsdf map initialized, with size: %d, %d" % (self.num_x, self.num_y)) 111 | self.is_map_ready = True 112 | 113 | def zero_single_isolated_ones(self, arr): 114 | neighbors = ndimage.convolve(arr, weights=np.ones((3, 3)), mode="constant", cval=0) 115 | isolated_ones = (arr == 1) & (neighbors == 1) 116 | arr[isolated_ones] = 0 117 | return arr 118 | 119 | def CreateTSDFMap(self): 120 | if not self.is_map_ready: 121 | raise ValueError("create tsdf map fails, no points received.") 122 | free_map = np.ones([self.num_x, self.num_y]) 123 | obs_map = np.zeros([self.num_x, self.num_y]) 124 | free_I = self.IndexArrayOfPs(self.free_points) 125 | obs_I = self.IndexArrayOfPs(self.obs_points) 126 | # create free place map 127 | for i in obs_I: 128 | obs_map[i[0], i[1]] = 1.0 129 | obs_map = self.zero_single_isolated_ones(obs_map) 130 | obs_map = gaussian_filter(obs_map, sigma=self._cfg_tsdf.sigma_expand) 131 | for i in free_I: 132 | if 0 < i[0] < self.num_x and 0 < i[1] < self.num_y: 133 | try: 134 | free_map[i[0], i[1]] = 0 135 | except: 136 | import ipdb;ipdb.set_trace() 137 | free_map = gaussian_filter(free_map, sigma=self._cfg_tsdf.sigma_expand) 138 | 139 | free_map[free_map < self._cfg_tsdf.free_space_threshold] = 0 140 | # assign obstacles 141 | free_map[obs_map > self._cfg_tsdf.obstacle_threshold] = 1.0 142 | print("occupancy map generation completed.") 143 | # Distance Transform 144 | tsdf_array = ndimage.distance_transform_edt(free_map) 145 | 146 | tsdf_array[tsdf_array > 0.0] = np.log(tsdf_array[tsdf_array > 0.0] + math.e) 147 | tsdf_array = gaussian_filter(tsdf_array, sigma=self._cfg_general.sigma_smooth) 148 | 149 | viz_points = np.concatenate((self.obs_points, self.free_points), axis=0) 150 | 151 | ground_array = np.ones([self.num_x, self.num_y]) * 0.0 152 | 153 | return [tsdf_array, viz_points, ground_array], [ 154 | float(self.start_x), 155 | float(self.start_y), 156 | ] 157 | 158 | def IndexArrayOfPs(self, points): 159 | indexes = points[:, :2] - np.array([self.start_x, self.start_y]) 160 | indexes = (np.round(indexes / self._cfg_general.resolution)).astype(int) 161 | return indexes 162 | 163 | def FilterCloud(self, points, outlier_filter=True): 164 | # crop points 165 | if any( 166 | [ 167 | self._cfg_general.x_max, 168 | self._cfg_general.x_min, 169 | self._cfg_general.y_max, 170 | self._cfg_general.y_min, 171 | ] 172 | ): 173 | points_x_idx_upper = ( 174 | (points[:, 0] < self._cfg_general.x_max) 175 | if self._cfg_general.x_max is not None 176 | else np.ones(points.shape[0], dtype=bool) 177 | ) 178 | points_x_idx_lower = ( 179 | (points[:, 0] > self._cfg_general.x_min) 180 | if self._cfg_general.x_min is not None 181 | else 
np.ones(points.shape[0], dtype=bool) 182 | ) 183 | points_y_idx_upper = ( 184 | (points[:, 1] < self._cfg_general.y_max) 185 | if self._cfg_general.y_max is not None 186 | else np.ones(points.shape[0], dtype=bool) 187 | ) 188 | points_y_idx_lower = ( 189 | (points[:, 1] > self._cfg_general.y_min) 190 | if self._cfg_general.y_min is not None 191 | else np.ones(points.shape[0], dtype=bool) 192 | ) 193 | points = points[ 194 | np.vstack( 195 | ( 196 | points_x_idx_lower, 197 | points_x_idx_upper, 198 | points_y_idx_upper, 199 | points_y_idx_lower, 200 | ) 201 | ).all(axis=0) 202 | ] 203 | 204 | if outlier_filter: 205 | # Filter outlier in points 206 | pcd = o3d.geometry.PointCloud() 207 | pcd.points = o3d.utility.Vector3dVector(points) 208 | cl, _ = pcd.remove_statistical_outlier( 209 | nb_neighbors=self._cfg_tsdf.nb_neighbors, 210 | std_ratio=self._cfg_tsdf.std_ratio, 211 | ) 212 | points = np.asarray(cl.points) 213 | 214 | return points 215 | 216 | def VizCloud(self, pcd): 217 | o3d.visualization.draw_geometries([pcd]) # visualize point cloud 218 | 219 | 220 | # EoF 221 | -------------------------------------------------------------------------------- /deployment/src/guide.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from scipy.ndimage import gaussian_filter, distance_transform_edt 6 | 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | from depth_anything_v2.dpt import DepthAnythingV2 10 | 11 | import importlib.util 12 | import os 13 | import open3d as o3d 14 | from tsdf_cost_map import TsdfCostMap 15 | from costmap_cfg import CostMapConfig 16 | 17 | 18 | def from_numpy(array: np.ndarray) -> torch.Tensor: 19 | return torch.from_numpy(array).float() 20 | 21 | def check_tensor(tensor, name="tensor"): 22 | if tensor.grad is not None: 23 | print(f"{name} grad: {tensor.grad}") 24 | else: 25 | print(f"{name} grad is None") 26 | 27 | class PathGuide: 28 | 29 | def __init__(self, device, ACTION_STATS, guide_cfgs=None): 30 | """ 31 | Parameters: 32 | """ 33 | self.device = device 34 | self.guide_cfgs = guide_cfgs 35 | self.mse_loss = nn.MSELoss(reduction='mean') 36 | self.l1_loss = nn.L1Loss(reduction='mean') 37 | self.robot_width = 0.6 38 | self.spatial_resolution = 0.1 39 | self.max_distance = 10 40 | self.bev_dist = self.max_distance / self.spatial_resolution 41 | self.delta_min = from_numpy(ACTION_STATS['min']).to(self.device) 42 | self.delta_max = from_numpy(ACTION_STATS['max']).to(self.device) 43 | 44 | # TODO: Pass in parameters instead of constants 45 | self.camera_intrinsics = np.array([[607.99658203125, 0, 642.2532958984375], 46 | [0, 607.862060546875, 366.3480224609375], 47 | [0, 0, 1]]) 48 | # robot to camera extrinsic 49 | self.camera_extrinsics = np.array([[0, 0, 1, -0.000], 50 | [-1, 0, 0, -0.000], 51 | [0, -1, 0, -0.042], 52 | [0, 0, 0, 1]]) 53 | 54 | # depth anything v2 init 55 | model_configs = { 56 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 57 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, 58 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, 59 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} 60 | } 61 | encoder = 'vits' # or 'vits', 'vitb', 'vitg' 62 | self.model = DepthAnythingV2(**model_configs[encoder]) 63 | package_name = 'depth_anything_v2' 64 | package_spec = 
importlib.util.find_spec(package_name) 65 | if package_spec is None: 66 | raise ImportError(f"Package '{package_name}' not found") 67 | package_path = os.path.dirname(package_spec.origin) 68 | self.model.load_state_dict(torch.load(os.path.join(package_path, f'../checkpoints/depth_anything_v2_{encoder}.pth'), map_location='cpu')) 69 | self.model = self.model.to(self.device).eval() 70 | 71 | # TSDF init 72 | self.tsdf_cfg = CostMapConfig() 73 | self.tsdf_cost_map = TsdfCostMap(self.tsdf_cfg.general, self.tsdf_cfg.tsdf_cost_map) 74 | 75 | def _norm_delta_to_ori_trajs(self, trajs): 76 | delta_tmp = (trajs + 1) / 2 77 | delta_ori = delta_tmp * (self.delta_max - self.delta_min) + self.delta_min 78 | trajs_ori = delta_ori.cumsum(dim=1) 79 | return trajs_ori 80 | 81 | def goal_cost(self, trajs, goal, scale_factor=None): 82 | import time 83 | trajs_ori = self._norm_delta_to_ori_trajs(trajs) 84 | if scale_factor is not None: 85 | trajs_ori *= scale_factor 86 | trajs_end_positions = trajs_ori[:, -1, :] 87 | 88 | distances = torch.norm(goal - trajs_end_positions, dim=1) 89 | 90 | gloss = 0.05 * torch.sum(distances) 91 | 92 | if trajs.grad is not None: 93 | trajs.grad.zero_() 94 | 95 | gloss.backward() 96 | return trajs.grad 97 | 98 | def generate_scale(self, n): 99 | scale = torch.linspace(0, 1, steps=n) 100 | 101 | squared_scale = scale ** 1 102 | 103 | return squared_scale.to(self.device) 104 | 105 | def depth_to_pcd(self, depth_image, camera_intrinsics, camera_extrinsics, resize_factor=1.0, height_threshold=0.5, max_distance=10.0): 106 | height, width = depth_image.shape 107 | print("height: ", height, "width: ", width) 108 | fx, fy = camera_intrinsics[0, 0] * resize_factor, camera_intrinsics[1, 1] * resize_factor 109 | cx, cy = camera_intrinsics[0, 2] * resize_factor, camera_intrinsics[1, 2] * resize_factor 110 | 111 | x, y = np.meshgrid(np.arange(width), np.arange(height)) 112 | z = depth_image.astype(np.float32) 113 | z_safe = np.where(z == 0, np.nan, z) 114 | z = 1 / z_safe 115 | x = (x - width / 2) * z / fx 116 | y = (y - height / 2) * z / fy 117 | non_ground_mask = (z > 0.5) & (z < max_distance) 118 | x_non_ground = x[non_ground_mask] 119 | y_non_ground = y[non_ground_mask] 120 | z_non_ground = z[non_ground_mask] 121 | 122 | points = np.stack((x_non_ground, y_non_ground, z_non_ground), axis=-1).reshape(-1, 3) 123 | 124 | extrinsics = camera_extrinsics 125 | homogeneous_points = np.hstack((points, np.ones((points.shape[0], 1)))) 126 | transformed_points = (extrinsics @ homogeneous_points.T).T[:, :3] 127 | 128 | point_cloud = o3d.geometry.PointCloud() 129 | point_cloud.points = o3d.utility.Vector3dVector(transformed_points) 130 | 131 | return point_cloud 132 | 133 | def add_robot_dim(self, world_ps): 134 | tangent = world_ps[:, 1:, 0:2] - world_ps[:, :-1, 0:2] 135 | tangent = tangent / torch.norm(tangent, dim=2, keepdim=True) 136 | normals = tangent[:, :, [1, 0]] * torch.tensor( 137 | [-1, 1], dtype=torch.float32, device=world_ps.device 138 | ) 139 | world_ps_inflated = torch.vstack([world_ps[:, :-1, :]] * 3) 140 | world_ps_inflated[:, :, 0:2] = torch.vstack( 141 | [ 142 | world_ps[:, :-1, 0:2] + normals * self.robot_width / 2, 143 | world_ps[:, :-1, 0:2], # center 144 | world_ps[:, :-1, 0:2] - normals * self.robot_width / 2, 145 | ] 146 | ) 147 | return world_ps_inflated 148 | 149 | def get_cost_map_via_tsdf(self, img): 150 | original_width, original_height = img.size 151 | resize_factor = 0.25 152 | new_size = (int(original_width * resize_factor), int(original_height * resize_factor)) 153 | 
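# Cost-map pipeline (as implemented just below): downscale the RGB frame, run Depth Anything V2
# to estimate depth, back-project the depth map into a pseudo point cloud with depth_to_pcd,
# then hand the cloud to TsdfCostMap to build the 2-D traversability grid that collision_cost samples.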
img = img.resize(new_size) 154 | depth_image = self.model.infer_image(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) 155 | pseudo_pcd = self.depth_to_pcd(depth_image, self.camera_intrinsics, self.camera_extrinsics, resize_factor=resize_factor) 156 | 157 | self.tsdf_cost_map.LoadPointCloud(pseudo_pcd) 158 | data, coord = self.tsdf_cost_map.CreateTSDFMap() 159 | if data is None: 160 | self.cost_map = None 161 | else: 162 | self.cost_map = torch.tensor(data[0]).requires_grad_(False).to(self.device) 163 | 164 | def collision_cost(self, trajs, scale_factor=None): 165 | if self.cost_map is None: 166 | return torch.zeros(trajs.shape) 167 | batch_size, num_p, _ = trajs.shape 168 | trajs_ori = self._norm_delta_to_ori_trajs(trajs) 169 | trajs_ori = self.add_robot_dim(trajs_ori) 170 | if scale_factor is not None: 171 | trajs_ori *= scale_factor 172 | norm_inds, _ = self.tsdf_cost_map.Pos2Ind(trajs_ori) 173 | cost_grid = self.cost_map.T.expand(trajs_ori.shape[0], 1, -1, -1) 174 | oloss_M = F.grid_sample(cost_grid, norm_inds[:, None, :, :], mode='bicubic', padding_mode='border', align_corners=False).squeeze(1).squeeze(1) 175 | oloss_M = oloss_M.to(torch.float32) 176 | 177 | loss = 0.003 * torch.sum(oloss_M, axis=1) 178 | if trajs.grad is not None: 179 | trajs.grad.zero_() 180 | loss.backward(torch.ones_like(loss)) 181 | cost_list = loss[1::3] 182 | generate_scale = self.generate_scale(trajs.shape[1]) 183 | return generate_scale.unsqueeze(1).unsqueeze(0) * trajs.grad, cost_list 184 | 185 | def get_gradient(self, trajs, alpha=0.3, t=None, goal_pos=None, ACTION_STATS=None, scale_factor=None): 186 | trajs_in = trajs.detach().requires_grad_(True).to(self.device) 187 | if goal_pos is not None: 188 | goal_pos = torch.tensor(goal_pos).to(self.device) 189 | goal_cost = self.goal_cost(trajs_in, goal_pos, scale_factor=scale_factor) 190 | cost = goal_cost 191 | return cost, None 192 | else: 193 | collision_cost, cost_list = self.collision_cost(trajs_in, scale_factor=scale_factor) 194 | cost = collision_cost 195 | return cost, cost_list 196 | 197 | 198 | 199 | class PathOpt: 200 | def __init__(self): 201 | self.traj_cache = None 202 | 203 | def angle_between_vectors(self, vec1, vec2): 204 | dot_product = np.sum(vec1 * vec2, axis=1) 205 | norm_product = np.linalg.norm(vec1, axis=1) * np.linalg.norm(vec2, axis=1) 206 | angle = np.arccos(dot_product / norm_product) 207 | return np.degrees(angle) 208 | 209 | def select_trajectory(self, trajs, l=2, angle_threshold=45, collision_min_idx=None): 210 | if self.traj_cache is None or len(self.traj_cache) <= l: 211 | idx = collision_min_idx if collision_min_idx else 0 212 | self.traj_cache = trajs[idx] 213 | else: 214 | directions = trajs[:, l, :] 215 | 216 | historical_directions = self.traj_cache[l] 217 | 218 | historical_directions = np.broadcast_to(historical_directions, directions.shape) 219 | 220 | angle_diffs = self.angle_between_vectors(directions, historical_directions) 221 | 222 | sorted_indices = np.argsort(angle_diffs) 223 | 224 | if angle_diffs[sorted_indices[0]] > angle_threshold: 225 | idx = 0 226 | self.traj_cache = trajs[idx] 227 | else: 228 | idx =sorted_indices[0] 229 | self.traj_cache = trajs[idx] 230 | 231 | return trajs[idx], idx -------------------------------------------------------------------------------- /train/vint_train/process_data/process_data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import io 3 | import os 4 | import rosbag 5 | from PIL import Image 6 | import cv2 7 | 
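# The helpers in this module convert ROS image messages (raw and compressed) to PIL images,
# turn nav_msgs/Odometry into (x, y, yaw) samples, and pair the two streams at a fixed rate
# so recorded rosbags can be turned into training trajectories.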
from typing import Any, Tuple, List, Dict 8 | import torchvision.transforms.functional as TF 9 | 10 | IMAGE_SIZE = (160, 120) 11 | IMAGE_ASPECT_RATIO = 4 / 3 12 | 13 | 14 | def process_images(im_list: List, img_process_func) -> List: 15 | """ 16 | Process image data from a topic that publishes ros images into a list of PIL images 17 | """ 18 | images = [] 19 | for img_msg in im_list: 20 | img = img_process_func(img_msg) 21 | images.append(img) 22 | return images 23 | 24 | 25 | def process_tartan_img(msg) -> Image: 26 | """ 27 | Process image data from a topic that publishes sensor_msgs/Image to a PIL image for the tartan_drive dataset 28 | """ 29 | img = ros_to_numpy(msg, output_resolution=IMAGE_SIZE) * 255 30 | img = img.astype(np.uint8) 31 | # reverse the axis order to get the image in the right orientation 32 | img = np.moveaxis(img, 0, -1) 33 | # convert rgb to bgr 34 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 35 | img = Image.fromarray(img) 36 | return img 37 | 38 | 39 | def process_locobot_img(msg) -> Image: 40 | """ 41 | Process image data from a topic that publishes sensor_msgs/Image to a PIL image for the locobot dataset 42 | """ 43 | img = np.frombuffer(msg.data, dtype=np.uint8).reshape( 44 | msg.height, msg.width, -1) 45 | pil_image = Image.fromarray(img) 46 | return pil_image 47 | 48 | 49 | def process_scand_img(msg) -> Image: 50 | """ 51 | Process image data from a topic that publishes sensor_msgs/CompressedImage to a PIL image for the scand dataset 52 | """ 53 | # convert sensor_msgs/CompressedImage to PIL image 54 | img = Image.open(io.BytesIO(msg.data)) 55 | # center crop image to 4:3 aspect ratio 56 | w, h = img.size 57 | img = TF.center_crop( 58 | img, (h, int(h * IMAGE_ASPECT_RATIO)) 59 | ) # crop to the right ratio 60 | # resize image to IMAGE_SIZE 61 | img = img.resize(IMAGE_SIZE) 62 | return img 63 | 64 | 65 | ############## Add custom image processing functions here ############# 66 | 67 | def process_sacson_img(msg) -> Image: 68 | np_arr = np.fromstring(msg.data, np.uint8) 69 | image_np = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) 70 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) 71 | pil_image = Image.fromarray(image_np) 72 | return pil_image 73 | 74 | 75 | ####################################################################### 76 | 77 | 78 | def process_odom( 79 | odom_list: List, 80 | odom_process_func: Any, 81 | ang_offset: float = 0.0, 82 | ) -> Dict[np.ndarray, np.ndarray]: 83 | """ 84 | Process odom data from a topic that publishes nav_msgs/Odometry into position and yaw 85 | """ 86 | xys = [] 87 | yaws = [] 88 | for odom_msg in odom_list: 89 | xy, yaw = odom_process_func(odom_msg, ang_offset) 90 | xys.append(xy) 91 | yaws.append(yaw) 92 | return {"position": np.array(xys), "yaw": np.array(yaws)} 93 | 94 | 95 | def nav_to_xy_yaw(odom_msg, ang_offset: float) -> Tuple[List[float], float]: 96 | """ 97 | Process odom data from a topic that publishes nav_msgs/Odometry into position 98 | """ 99 | 100 | position = odom_msg.pose.pose.position 101 | orientation = odom_msg.pose.pose.orientation 102 | yaw = ( 103 | quat_to_yaw(orientation.x, orientation.y, orientation.z, orientation.w) 104 | + ang_offset 105 | ) 106 | return [position.x, position.y], yaw 107 | 108 | 109 | ############ Add custom odometry processing functions here ############ 110 | 111 | 112 | ####################################################################### 113 | 114 | 115 | def get_images_and_odom( 116 | bag: rosbag.Bag, 117 | imtopics: List[str] or str, 118 | odomtopics: List[str] or str, 
119 | img_process_func: Any, 120 | odom_process_func: Any, 121 | rate: float = 4.0, 122 | ang_offset: float = 0.0, 123 | ): 124 | """ 125 | Get image and odom data from a bag file 126 | 127 | Args: 128 | bag (rosbag.Bag): bag file 129 | imtopics (list[str] or str): topic name(s) for image data 130 | odomtopics (list[str] or str): topic name(s) for odom data 131 | img_process_func (Any): function to process image data 132 | odom_process_func (Any): function to process odom data 133 | rate (float, optional): rate to sample data. Defaults to 4.0. 134 | ang_offset (float, optional): angle offset to add to odom data. Defaults to 0.0. 135 | Returns: 136 | img_data (list): list of PIL images 137 | traj_data (list): list of odom data 138 | """ 139 | # check if bag has both topics 140 | odomtopic = None 141 | imtopic = None 142 | if type(imtopics) == str: 143 | imtopic = imtopics 144 | else: 145 | for imt in imtopics: 146 | if bag.get_message_count(imt) > 0: 147 | imtopic = imt 148 | break 149 | if type(odomtopics) == str: 150 | odomtopic = odomtopics 151 | else: 152 | for ot in odomtopics: 153 | if bag.get_message_count(ot) > 0: 154 | odomtopic = ot 155 | break 156 | if not (imtopic and odomtopic): 157 | # bag doesn't have both topics 158 | return None, None 159 | 160 | synced_imdata = [] 161 | synced_odomdata = [] 162 | # get start time of bag in seconds 163 | currtime = bag.get_start_time() 164 | 165 | curr_imdata = None 166 | curr_odomdata = None 167 | 168 | for topic, msg, t in bag.read_messages(topics=[imtopic, odomtopic]): 169 | if topic == imtopic: 170 | curr_imdata = msg 171 | elif topic == odomtopic: 172 | curr_odomdata = msg 173 | if (t.to_sec() - currtime) >= 1.0 / rate: 174 | if curr_imdata is not None and curr_odomdata is not None: 175 | synced_imdata.append(curr_imdata) 176 | synced_odomdata.append(curr_odomdata) 177 | currtime = t.to_sec() 178 | 179 | img_data = process_images(synced_imdata, img_process_func) 180 | traj_data = process_odom( 181 | synced_odomdata, 182 | odom_process_func, 183 | ang_offset=ang_offset, 184 | ) 185 | 186 | return img_data, traj_data 187 | 188 | 189 | def is_backwards( 190 | pos1: np.ndarray, yaw1: float, pos2: np.ndarray, eps: float = 1e-5 191 | ) -> bool: 192 | """ 193 | Check if the trajectory is going backwards given the position and yaw of two points 194 | Args: 195 | pos1: position of the first point 196 | 197 | """ 198 | dx, dy = pos2 - pos1 199 | return dx * np.cos(yaw1) + dy * np.sin(yaw1) < eps 200 | 201 | 202 | # cut out non-positive velocity segments of the trajectory 203 | def filter_backwards( 204 | img_list: List[Image.Image], 205 | traj_data: Dict[str, np.ndarray], 206 | start_slack: int = 0, 207 | end_slack: int = 0, 208 | ) -> Tuple[List[np.ndarray], List[int]]: 209 | """ 210 | Cut out non-positive velocity segments of the trajectory 211 | Args: 212 | traj_type: type of trajectory to cut 213 | img_list: list of images 214 | traj_data: dictionary of position and yaw data 215 | start_slack: number of points to ignore at the start of the trajectory 216 | end_slack: number of points to ignore at the end of the trajectory 217 | Returns: 218 | cut_trajs: list of cut trajectories 219 | start_times: list of start times of the cut trajectories 220 | """ 221 | traj_pos = traj_data["position"] 222 | traj_yaws = traj_data["yaw"] 223 | cut_trajs = [] 224 | start = True 225 | 226 | def process_pair(traj_pair: list) -> Tuple[List, Dict]: 227 | new_img_list, new_traj_data = zip(*traj_pair) 228 | new_traj_data = np.array(new_traj_data) 229 | 
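# each entry in traj_pair is (image, [x, y, yaw]); after stacking, recover the position and yaw arrays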
new_traj_pos = new_traj_data[:, :2] 230 | new_traj_yaws = new_traj_data[:, 2] 231 | return (new_img_list, {"position": new_traj_pos, "yaw": new_traj_yaws}) 232 | 233 | for i in range(max(start_slack, 1), len(traj_pos) - end_slack): 234 | pos1 = traj_pos[i - 1] 235 | yaw1 = traj_yaws[i - 1] 236 | pos2 = traj_pos[i] 237 | if not is_backwards(pos1, yaw1, pos2): 238 | if start: 239 | new_traj_pairs = [ 240 | (img_list[i - 1], [*traj_pos[i - 1], traj_yaws[i - 1]]) 241 | ] 242 | start = False 243 | elif i == len(traj_pos) - end_slack - 1: 244 | cut_trajs.append(process_pair(new_traj_pairs)) 245 | else: 246 | new_traj_pairs.append( 247 | (img_list[i - 1], [*traj_pos[i - 1], traj_yaws[i - 1]]) 248 | ) 249 | elif not start: 250 | cut_trajs.append(process_pair(new_traj_pairs)) 251 | start = True 252 | return cut_trajs 253 | 254 | 255 | def quat_to_yaw( 256 | x: np.ndarray, 257 | y: np.ndarray, 258 | z: np.ndarray, 259 | w: np.ndarray, 260 | ) -> np.ndarray: 261 | """ 262 | Convert a batch quaternion into a yaw angle 263 | yaw is rotation around z in radians (counterclockwise) 264 | """ 265 | t3 = 2.0 * (w * z + x * y) 266 | t4 = 1.0 - 2.0 * (y * y + z * z) 267 | yaw = np.arctan2(t3, t4) 268 | return yaw 269 | 270 | 271 | def ros_to_numpy( 272 | msg, nchannels=3, empty_value=None, output_resolution=None, aggregate="none" 273 | ): 274 | """ 275 | Convert a ROS image message to a numpy array 276 | """ 277 | if output_resolution is None: 278 | output_resolution = (msg.width, msg.height) 279 | 280 | is_rgb = "8" in msg.encoding 281 | if is_rgb: 282 | data = np.frombuffer(msg.data, dtype=np.uint8).copy() 283 | else: 284 | data = np.frombuffer(msg.data, dtype=np.float32).copy() 285 | 286 | data = data.reshape(msg.height, msg.width, nchannels) 287 | 288 | if empty_value: 289 | mask = np.isclose(abs(data), empty_value) 290 | fill_value = np.percentile(data[~mask], 99) 291 | data[mask] = fill_value 292 | 293 | data = cv2.resize( 294 | data, 295 | dsize=(output_resolution[0], output_resolution[1]), 296 | interpolation=cv2.INTER_AREA, 297 | ) 298 | 299 | if aggregate == "littleendian": 300 | data = sum([data[:, :, i] * (256**i) for i in range(nchannels)]) 301 | elif aggregate == "bigendian": 302 | data = sum([data[:, :, -(i + 1)] * (256**i) for i in range(nchannels)]) 303 | 304 | if len(data.shape) == 2: 305 | data = np.expand_dims(data, axis=0) 306 | else: 307 | data = np.moveaxis(data, 2, 0) # Switch to channels-first 308 | 309 | if is_rgb: 310 | data = data.astype(np.float32) / ( 311 | 255.0 if aggregate == "none" else 255.0**nchannels 312 | ) 313 | 314 | return data 315 | -------------------------------------------------------------------------------- /train/vint_train/training/train_eval_loop.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import os 3 | import numpy as np 4 | from typing import List, Optional, Dict 5 | from prettytable import PrettyTable 6 | 7 | from vint_train.training.train_utils import train, evaluate 8 | from vint_train.training.train_utils import train_nomad, evaluate_nomad 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import DataLoader 14 | from torch.optim import Adam 15 | from torchvision import transforms 16 | 17 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 18 | from diffusers.training_utils import EMAModel 19 | 20 | def train_eval_loop( 21 | train_model: bool, 22 | model: nn.Module, 23 | optimizer: Adam, 24 | scheduler: 
Optional[torch.optim.lr_scheduler._LRScheduler], 25 | dataloader: DataLoader, 26 | test_dataloaders: Dict[str, DataLoader], 27 | transform: transforms, 28 | epochs: int, 29 | device: torch.device, 30 | project_folder: str, 31 | normalized: bool, 32 | wandb_log_freq: int = 10, 33 | print_log_freq: int = 100, 34 | image_log_freq: int = 1000, 35 | num_images_log: int = 8, 36 | current_epoch: int = 0, 37 | alpha: float = 0.5, 38 | learn_angle: bool = True, 39 | use_wandb: bool = True, 40 | eval_fraction: float = 0.25, 41 | ): 42 | """ 43 | Train and evaluate the model for several epochs (vint or gnm models) 44 | 45 | Args: 46 | train_model: whether to train the model or not 47 | model: model to train 48 | optimizer: optimizer to use 49 | scheduler: learning rate scheduler to use 50 | dataloader: dataloader for train dataset 51 | test_dataloaders: dict of dataloaders for testing 52 | transform: transform to apply to images 53 | epochs: number of epochs to train 54 | device: device to train on 55 | project_folder: folder to save checkpoints and logs 56 | normalized: whether to normalize the action space or not 57 | wandb_log_freq: frequency of logging to wandb 58 | print_log_freq: frequency of printing to console 59 | image_log_freq: frequency of logging images to wandb 60 | num_images_log: number of images to log to wandb 61 | current_epoch: epoch to start training from 62 | alpha: tradeoff between distance and action loss 63 | learn_angle: whether to learn the angle or not 64 | use_wandb: whether to log to wandb or not 65 | eval_fraction: fraction of training data to use for evaluation 66 | """ 67 | assert 0 <= alpha <= 1 68 | latest_path = os.path.join(project_folder, f"latest.pth") 69 | 70 | for epoch in range(current_epoch, current_epoch + epochs): 71 | if train_model: 72 | print( 73 | f"Start ViNT Training Epoch {epoch}/{current_epoch + epochs - 1}" 74 | ) 75 | train( 76 | model=model, 77 | optimizer=optimizer, 78 | dataloader=dataloader, 79 | transform=transform, 80 | device=device, 81 | project_folder=project_folder, 82 | normalized=normalized, 83 | epoch=epoch, 84 | alpha=alpha, 85 | learn_angle=learn_angle, 86 | print_log_freq=print_log_freq, 87 | wandb_log_freq=wandb_log_freq, 88 | image_log_freq=image_log_freq, 89 | num_images_log=num_images_log, 90 | use_wandb=use_wandb, 91 | ) 92 | 93 | avg_total_test_loss = [] 94 | for dataset_type in test_dataloaders: 95 | print( 96 | f"Start {dataset_type} ViNT Testing Epoch {epoch}/{current_epoch + epochs - 1}" 97 | ) 98 | loader = test_dataloaders[dataset_type] 99 | 100 | test_dist_loss, test_action_loss, total_eval_loss = evaluate( 101 | eval_type=dataset_type, 102 | model=model, 103 | dataloader=loader, 104 | transform=transform, 105 | device=device, 106 | project_folder=project_folder, 107 | normalized=normalized, 108 | epoch=epoch, 109 | alpha=alpha, 110 | learn_angle=learn_angle, 111 | num_images_log=num_images_log, 112 | use_wandb=use_wandb, 113 | eval_fraction=eval_fraction, 114 | ) 115 | 116 | avg_total_test_loss.append(total_eval_loss) 117 | 118 | checkpoint = { 119 | "epoch": epoch, 120 | "model": model, 121 | "optimizer": optimizer, 122 | "avg_total_test_loss": np.mean(avg_total_test_loss), 123 | "scheduler": scheduler 124 | } 125 | # log average eval loss 126 | wandb.log({}, commit=False) 127 | 128 | if scheduler is not None: 129 | # scheduler calls based on the type of scheduler 130 | if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 131 | scheduler.step(np.mean(avg_total_test_loss)) 132 | else: 133 | 
scheduler.step() 134 | wandb.log({ 135 | "avg_total_test_loss": np.mean(avg_total_test_loss), 136 | "lr": optimizer.param_groups[0]["lr"], 137 | }, commit=False) 138 | 139 | numbered_path = os.path.join(project_folder, f"{epoch}.pth") 140 | torch.save(checkpoint, latest_path) 141 | torch.save(checkpoint, numbered_path) # keep track of model at every epoch 142 | 143 | # Flush the last set of eval logs 144 | wandb.log({}) 145 | print() 146 | 147 | def train_eval_loop_nomad( 148 | train_model: bool, 149 | model: nn.Module, 150 | optimizer: Adam, 151 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 152 | noise_scheduler: DDPMScheduler, 153 | train_loader: DataLoader, 154 | test_dataloaders: Dict[str, DataLoader], 155 | transform: transforms, 156 | goal_mask_prob: float, 157 | epochs: int, 158 | device: torch.device, 159 | project_folder: str, 160 | print_log_freq: int = 100, 161 | wandb_log_freq: int = 10, 162 | image_log_freq: int = 1000, 163 | num_images_log: int = 8, 164 | current_epoch: int = 0, 165 | alpha: float = 1e-4, 166 | use_wandb: bool = True, 167 | eval_fraction: float = 0.25, 168 | eval_freq: int = 1, 169 | ): 170 | """ 171 | Train and evaluate the model for several epochs (vint or gnm models) 172 | 173 | Args: 174 | model: model to train 175 | optimizer: optimizer to use 176 | lr_scheduler: learning rate scheduler to use 177 | noise_scheduler: noise scheduler to use 178 | dataloader: dataloader for train dataset 179 | test_dataloaders: dict of dataloaders for testing 180 | transform: transform to apply to images 181 | goal_mask_prob: probability of masking the goal token during training 182 | epochs: number of epochs to train 183 | device: device to train on 184 | project_folder: folder to save checkpoints and logs 185 | wandb_log_freq: frequency of logging to wandb 186 | print_log_freq: frequency of printing to console 187 | image_log_freq: frequency of logging images to wandb 188 | num_images_log: number of images to log to wandb 189 | current_epoch: epoch to start training from 190 | alpha: tradeoff between distance and action loss 191 | use_wandb: whether to log to wandb or not 192 | eval_fraction: fraction of training data to use for evaluation 193 | eval_freq: frequency of evaluation 194 | """ 195 | latest_path = os.path.join(project_folder, f"latest.pth") 196 | ema_model = EMAModel(model=model,power=0.75) 197 | 198 | for epoch in range(current_epoch, current_epoch + epochs): 199 | if train_model: 200 | print( 201 | f"Start ViNT DP Training Epoch {epoch}/{current_epoch + epochs - 1}" 202 | ) 203 | train_nomad( 204 | model=model, 205 | ema_model=ema_model, 206 | optimizer=optimizer, 207 | dataloader=train_loader, 208 | transform=transform, 209 | device=device, 210 | noise_scheduler=noise_scheduler, 211 | goal_mask_prob=goal_mask_prob, 212 | project_folder=project_folder, 213 | epoch=epoch, 214 | print_log_freq=print_log_freq, 215 | wandb_log_freq=wandb_log_freq, 216 | image_log_freq=image_log_freq, 217 | num_images_log=num_images_log, 218 | use_wandb=use_wandb, 219 | alpha=alpha, 220 | ) 221 | lr_scheduler.step() 222 | 223 | numbered_path = os.path.join(project_folder, f"ema_{epoch}.pth") 224 | torch.save(ema_model.averaged_model.state_dict(), numbered_path) 225 | numbered_path = os.path.join(project_folder, f"ema_latest.pth") 226 | print(f"Saved EMA model to {numbered_path}") 227 | 228 | numbered_path = os.path.join(project_folder, f"{epoch}.pth") 229 | torch.save(model.state_dict(), numbered_path) 230 | torch.save(model.state_dict(), latest_path) 231 | print(f"Saved 
model to {numbered_path}") 232 | 233 | # save optimizer 234 | numbered_path = os.path.join(project_folder, f"optimizer_{epoch}.pth") 235 | latest_optimizer_path = os.path.join(project_folder, f"optimizer_latest.pth") 236 | torch.save(optimizer.state_dict(), latest_optimizer_path) 237 | 238 | # save scheduler 239 | numbered_path = os.path.join(project_folder, f"scheduler_{epoch}.pth") 240 | latest_scheduler_path = os.path.join(project_folder, f"scheduler_latest.pth") 241 | torch.save(lr_scheduler.state_dict(), latest_scheduler_path) 242 | 243 | 244 | if (epoch + 1) % eval_freq == 0: 245 | for dataset_type in test_dataloaders: 246 | print( 247 | f"Start {dataset_type} ViNT DP Testing Epoch {epoch}/{current_epoch + epochs - 1}" 248 | ) 249 | loader = test_dataloaders[dataset_type] 250 | evaluate_nomad( 251 | eval_type=dataset_type, 252 | ema_model=ema_model, 253 | dataloader=loader, 254 | transform=transform, 255 | device=device, 256 | noise_scheduler=noise_scheduler, 257 | goal_mask_prob=goal_mask_prob, 258 | project_folder=project_folder, 259 | epoch=epoch, 260 | print_log_freq=print_log_freq, 261 | num_images_log=num_images_log, 262 | wandb_log_freq=wandb_log_freq, 263 | use_wandb=use_wandb, 264 | eval_fraction=eval_fraction, 265 | ) 266 | wandb.log({ 267 | "lr": optimizer.param_groups[0]["lr"], 268 | }, commit=False) 269 | 270 | if lr_scheduler is not None: 271 | lr_scheduler.step() 272 | 273 | # log average eval loss 274 | wandb.log({}, commit=False) 275 | 276 | wandb.log({ 277 | "lr": optimizer.param_groups[0]["lr"], 278 | }, commit=False) 279 | 280 | 281 | # Flush the last set of eval logs 282 | wandb.log({}) 283 | print() 284 | 285 | def load_model(model, model_type, checkpoint: dict) -> None: 286 | """Load model from checkpoint.""" 287 | if model_type == "nomad": 288 | state_dict = checkpoint 289 | model.load_state_dict(state_dict, strict=False) 290 | else: 291 | loaded_model = checkpoint["model"] 292 | try: 293 | state_dict = loaded_model.module.state_dict() 294 | model.load_state_dict(state_dict, strict=False) 295 | except AttributeError as e: 296 | state_dict = loaded_model.state_dict() 297 | model.load_state_dict(state_dict, strict=False) 298 | 299 | 300 | def load_ema_model(ema_model, state_dict: dict) -> None: 301 | """Load model from checkpoint.""" 302 | ema_model.load_state_dict(state_dict) 303 | 304 | 305 | def count_parameters(model): 306 | table = PrettyTable(["Modules", "Parameters"]) 307 | total_params = 0 308 | for name, parameter in model.named_parameters(): 309 | if not parameter.requires_grad: continue 310 | params = parameter.numel() 311 | table.add_row([name, params]) 312 | total_params+=params 313 | # print(table) 314 | print(f"Total Trainable Params: {total_params/1e6:.2f}M") 315 | return total_params -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | import argparse 4 | import numpy as np 5 | import yaml 6 | import time 7 | import pdb 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.data import DataLoader, ConcatDataset 12 | from torch.optim import Adam, AdamW 13 | from torchvision import transforms 14 | import torch.backends.cudnn as cudnn 15 | from warmup_scheduler import GradualWarmupScheduler 16 | 17 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 18 | from diffusers.optimization import get_scheduler 19 | 20 | """ 21 | IMPORT YOUR MODEL HERE 22 | """ 23 | from 
vint_train.models.nomad.nomad import NoMaD, DenseNetwork 24 | from vint_train.models.nomad.nomad_vint import NoMaD_ViNT, replace_bn_with_gn 25 | from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D 26 | 27 | 28 | from vint_train.data.vint_dataset import ViNT_Dataset 29 | from vint_train.training.train_eval_loop import ( 30 | train_eval_loop, 31 | train_eval_loop_nomad, 32 | load_model, 33 | ) 34 | 35 | 36 | def main(config): 37 | assert config["distance"]["min_dist_cat"] < config["distance"]["max_dist_cat"] 38 | assert config["action"]["min_dist_cat"] < config["action"]["max_dist_cat"] 39 | 40 | if torch.cuda.is_available(): 41 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 42 | if "gpu_ids" not in config: 43 | config["gpu_ids"] = [0] 44 | elif type(config["gpu_ids"]) == int: 45 | config["gpu_ids"] = [config["gpu_ids"]] 46 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( 47 | [str(x) for x in config["gpu_ids"]] 48 | ) 49 | print("Using cuda devices:", os.environ["CUDA_VISIBLE_DEVICES"]) 50 | else: 51 | print("Using cpu") 52 | 53 | first_gpu_id = config["gpu_ids"][0] 54 | device = torch.device( 55 | f"cuda:{first_gpu_id}" if torch.cuda.is_available() else "cpu" 56 | ) 57 | 58 | if "seed" in config: 59 | np.random.seed(config["seed"]) 60 | torch.manual_seed(config["seed"]) 61 | cudnn.deterministic = True 62 | 63 | cudnn.benchmark = True # good if input sizes don't vary 64 | transform = ([ 65 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 66 | ]) 67 | transform = transforms.Compose(transform) 68 | 69 | # Load the data 70 | train_dataset = [] 71 | test_dataloaders = {} 72 | 73 | if "context_type" not in config: 74 | config["context_type"] = "temporal" 75 | 76 | if "clip_goals" not in config: 77 | config["clip_goals"] = False 78 | 79 | for dataset_name in config["datasets"]: 80 | data_config = config["datasets"][dataset_name] 81 | if "negative_mining" not in data_config: 82 | data_config["negative_mining"] = True 83 | if "goals_per_obs" not in data_config: 84 | data_config["goals_per_obs"] = 1 85 | if "end_slack" not in data_config: 86 | data_config["end_slack"] = 0 87 | if "waypoint_spacing" not in data_config: 88 | data_config["waypoint_spacing"] = 1 89 | 90 | for data_split_type in ["train", "test"]: 91 | if data_split_type in data_config: 92 | dataset = ViNT_Dataset( 93 | data_folder=data_config["data_folder"], 94 | data_split_folder=data_config[data_split_type], 95 | dataset_name=dataset_name, 96 | image_size=config["image_size"], 97 | waypoint_spacing=data_config["waypoint_spacing"], 98 | min_dist_cat=config["distance"]["min_dist_cat"], 99 | max_dist_cat=config["distance"]["max_dist_cat"], 100 | min_action_distance=config["action"]["min_dist_cat"], 101 | max_action_distance=config["action"]["max_dist_cat"], 102 | negative_mining=data_config["negative_mining"], 103 | len_traj_pred=config["len_traj_pred"], 104 | learn_angle=config["learn_angle"], 105 | context_size=config["context_size"], 106 | context_type=config["context_type"], 107 | end_slack=data_config["end_slack"], 108 | goals_per_obs=data_config["goals_per_obs"], 109 | normalize=config["normalize"], 110 | goal_type=config["goal_type"], 111 | ) 112 | if data_split_type == "train": 113 | train_dataset.append(dataset) 114 | else: 115 | dataset_type = f"{dataset_name}_{data_split_type}" 116 | if dataset_type not in test_dataloaders: 117 | test_dataloaders[dataset_type] = {} 118 | test_dataloaders[dataset_type] = dataset 119 | 120 | # combine all the datasets from different 
robots 121 | train_dataset = ConcatDataset(train_dataset) 122 | 123 | train_loader = DataLoader( 124 | train_dataset, 125 | batch_size=config["batch_size"], 126 | shuffle=True, 127 | num_workers=config["num_workers"], 128 | drop_last=False, 129 | persistent_workers=True, 130 | ) 131 | 132 | if "eval_batch_size" not in config: 133 | config["eval_batch_size"] = config["batch_size"] 134 | 135 | for dataset_type, dataset in test_dataloaders.items(): 136 | test_dataloaders[dataset_type] = DataLoader( 137 | dataset, 138 | batch_size=config["eval_batch_size"], 139 | shuffle=True, 140 | num_workers=0, 141 | drop_last=False, 142 | ) 143 | 144 | # Create the model 145 | if config["model_type"] == "nomad": 146 | if config["vision_encoder"] == "nomad_vint": 147 | vision_encoder = NoMaD_ViNT( 148 | obs_encoding_size=config["encoding_size"], 149 | context_size=config["context_size"], 150 | mha_num_attention_heads=config["mha_num_attention_heads"], 151 | mha_num_attention_layers=config["mha_num_attention_layers"], 152 | mha_ff_dim_factor=config["mha_ff_dim_factor"], 153 | ) 154 | vision_encoder = replace_bn_with_gn(vision_encoder) 155 | elif config["vision_encoder"] == "vib": 156 | vision_encoder = ViB( 157 | obs_encoding_size=config["encoding_size"], 158 | context_size=config["context_size"], 159 | mha_num_attention_heads=config["mha_num_attention_heads"], 160 | mha_num_attention_layers=config["mha_num_attention_layers"], 161 | mha_ff_dim_factor=config["mha_ff_dim_factor"], 162 | ) 163 | vision_encoder = replace_bn_with_gn(vision_encoder) 164 | elif config["vision_encoder"] == "vit": 165 | vision_encoder = ViT( 166 | obs_encoding_size=config["encoding_size"], 167 | context_size=config["context_size"], 168 | image_size=config["image_size"], 169 | patch_size=config["patch_size"], 170 | mha_num_attention_heads=config["mha_num_attention_heads"], 171 | mha_num_attention_layers=config["mha_num_attention_layers"], 172 | ) 173 | vision_encoder = replace_bn_with_gn(vision_encoder) 174 | else: 175 | raise ValueError(f"Vision encoder {config['vision_encoder']} not supported") 176 | 177 | noise_pred_net = ConditionalUnet1D( 178 | input_dim=2, 179 | global_cond_dim=config["encoding_size"], 180 | down_dims=config["down_dims"], 181 | cond_predict_scale=config["cond_predict_scale"], 182 | ) 183 | dist_pred_network = DenseNetwork(embedding_dim=config["encoding_size"]) 184 | 185 | model = NoMaD( 186 | vision_encoder=vision_encoder, 187 | noise_pred_net=noise_pred_net, 188 | dist_pred_net=dist_pred_network, 189 | ) 190 | 191 | noise_scheduler = DDPMScheduler( 192 | num_train_timesteps=config["num_diffusion_iters"], 193 | beta_schedule='squaredcos_cap_v2', 194 | clip_sample=True, 195 | prediction_type='epsilon' 196 | ) 197 | else: 198 | raise ValueError(f"Model {config['model']} not supported") 199 | 200 | if config["clipping"]: 201 | print("Clipping gradients to", config["max_norm"]) 202 | for p in model.parameters(): 203 | if not p.requires_grad: 204 | continue 205 | p.register_hook( 206 | lambda grad: torch.clamp( 207 | grad, -1 * config["max_norm"], config["max_norm"] 208 | ) 209 | ) 210 | 211 | lr = float(config["lr"]) 212 | config["optimizer"] = config["optimizer"].lower() 213 | if config["optimizer"] == "adam": 214 | optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.98)) 215 | elif config["optimizer"] == "adamw": 216 | optimizer = AdamW(model.parameters(), lr=lr) 217 | elif config["optimizer"] == "sgd": 218 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) 219 | else: 220 | raise 
ValueError(f"Optimizer {config['optimizer']} not supported") 221 | 222 | scheduler = None 223 | if config["scheduler"] is not None: 224 | config["scheduler"] = config["scheduler"].lower() 225 | if config["scheduler"] == "cosine": 226 | print("Using cosine annealing with T_max", config["epochs"]) 227 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 228 | optimizer, T_max=config["epochs"] 229 | ) 230 | elif config["scheduler"] == "cyclic": 231 | print("Using cyclic LR with cycle", config["cyclic_period"]) 232 | scheduler = torch.optim.lr_scheduler.CyclicLR( 233 | optimizer, 234 | base_lr=lr / 10., 235 | max_lr=lr, 236 | step_size_up=config["cyclic_period"] // 2, 237 | cycle_momentum=False, 238 | ) 239 | elif config["scheduler"] == "plateau": 240 | print("Using ReduceLROnPlateau") 241 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 242 | optimizer, 243 | factor=config["plateau_factor"], 244 | patience=config["plateau_patience"], 245 | verbose=True, 246 | ) 247 | else: 248 | raise ValueError(f"Scheduler {config['scheduler']} not supported") 249 | 250 | if config["warmup"]: 251 | print("Using warmup scheduler") 252 | scheduler = GradualWarmupScheduler( 253 | optimizer, 254 | multiplier=1, 255 | total_epoch=config["warmup_epochs"], 256 | after_scheduler=scheduler, 257 | ) 258 | 259 | current_epoch = 0 260 | if "load_run" in config: 261 | load_project_folder = os.path.join("logs", config["load_run"]) 262 | print("Loading model from ", load_project_folder) 263 | latest_path = os.path.join(load_project_folder, "latest.pth") 264 | latest_checkpoint = torch.load(latest_path) #f"cuda:{}" if torch.cuda.is_available() else "cpu") 265 | load_model(model, config["model_type"], latest_checkpoint) 266 | if "epoch" in latest_checkpoint: 267 | current_epoch = latest_checkpoint["epoch"] + 1 268 | 269 | # Multi-GPU 270 | if len(config["gpu_ids"]) > 1: 271 | model = nn.DataParallel(model, device_ids=config["gpu_ids"]) 272 | model = model.to(device) 273 | 274 | if "load_run" in config: # load optimizer and scheduler after data parallel 275 | if "optimizer" in latest_checkpoint: 276 | optimizer.load_state_dict(latest_checkpoint["optimizer"].state_dict()) 277 | if scheduler is not None and "scheduler" in latest_checkpoint: 278 | scheduler.load_state_dict(latest_checkpoint["scheduler"].state_dict()) 279 | 280 | if config["model_type"] == "vint" or config["model_type"] == "gnm": 281 | train_eval_loop( 282 | train_model=config["train"], 283 | model=model, 284 | optimizer=optimizer, 285 | scheduler=scheduler, 286 | dataloader=train_loader, 287 | test_dataloaders=test_dataloaders, 288 | transform=transform, 289 | epochs=config["epochs"], 290 | device=device, 291 | project_folder=config["project_folder"], 292 | normalized=config["normalize"], 293 | print_log_freq=config["print_log_freq"], 294 | image_log_freq=config["image_log_freq"], 295 | num_images_log=config["num_images_log"], 296 | current_epoch=current_epoch, 297 | learn_angle=config["learn_angle"], 298 | alpha=config["alpha"], 299 | use_wandb=config["use_wandb"], 300 | eval_fraction=config["eval_fraction"], 301 | ) 302 | else: 303 | train_eval_loop_nomad( 304 | train_model=config["train"], 305 | model=model, 306 | optimizer=optimizer, 307 | lr_scheduler=scheduler, 308 | noise_scheduler=noise_scheduler, 309 | train_loader=train_loader, 310 | test_dataloaders=test_dataloaders, 311 | transform=transform, 312 | goal_mask_prob=config["goal_mask_prob"], 313 | epochs=config["epochs"], 314 | device=device, 315 | project_folder=config["project_folder"], 
316 | print_log_freq=config["print_log_freq"], 317 | wandb_log_freq=config["wandb_log_freq"], 318 | image_log_freq=config["image_log_freq"], 319 | num_images_log=config["num_images_log"], 320 | current_epoch=current_epoch, 321 | alpha=float(config["alpha"]), 322 | use_wandb=config["use_wandb"], 323 | eval_fraction=config["eval_fraction"], 324 | eval_freq=config["eval_freq"], 325 | ) 326 | 327 | print("FINISHED TRAINING") 328 | 329 | 330 | if __name__ == "__main__": 331 | torch.multiprocessing.set_start_method("spawn") 332 | 333 | parser = argparse.ArgumentParser(description="Visual Navigation Transformer") 334 | 335 | # project setup 336 | parser.add_argument( 337 | "--config", 338 | "-c", 339 | default="config/vint.yaml", 340 | type=str, 341 | help="Path to the config file in train_config folder", 342 | ) 343 | args = parser.parse_args() 344 | 345 | with open("config/defaults.yaml", "r") as f: 346 | default_config = yaml.safe_load(f) 347 | 348 | config = default_config 349 | 350 | with open(args.config, "r") as f: 351 | user_config = yaml.safe_load(f) 352 | 353 | config.update(user_config) 354 | 355 | config["run_name"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S") 356 | config["project_folder"] = os.path.join( 357 | "logs", config["project_name"], config["run_name"] 358 | ) 359 | os.makedirs( 360 | config[ 361 | "project_folder" 362 | ], # should error if dir already exists to avoid overwriting and old project 363 | ) 364 | 365 | if config["use_wandb"]: 366 | wandb.login() 367 | wandb.init( 368 | project=config["project_name"], 369 | settings=wandb.Settings(start_method="fork"), 370 | entity="gnmv2", # TODO: change this to your wandb entity 371 | ) 372 | wandb.save(args.config, policy="now") # save the config file 373 | wandb.run.name = config["run_name"] 374 | # update the wandb args with the training configurations 375 | if wandb.run: 376 | wandb.config.update(config) 377 | 378 | print(config) 379 | main(config) 380 | -------------------------------------------------------------------------------- /train/vint_train/data/vint_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle 4 | import yaml 5 | from typing import Any, Dict, List, Optional, Tuple 6 | import tqdm 7 | import io 8 | import lmdb 9 | 10 | import torch 11 | from torch.utils.data import Dataset 12 | import torchvision.transforms.functional as TF 13 | 14 | from vint_train.data.data_utils import ( 15 | img_path_to_data, 16 | calculate_sin_cos, 17 | get_data_path, 18 | to_local_coords, 19 | ) 20 | 21 | class ViNT_Dataset(Dataset): 22 | def __init__( 23 | self, 24 | data_folder: str, 25 | data_split_folder: str, 26 | dataset_name: str, 27 | image_size: Tuple[int, int], 28 | waypoint_spacing: int, 29 | min_dist_cat: int, 30 | max_dist_cat: int, 31 | min_action_distance: int, 32 | max_action_distance: int, 33 | negative_mining: bool, 34 | len_traj_pred: int, 35 | learn_angle: bool, 36 | context_size: int, 37 | context_type: str = "temporal", 38 | end_slack: int = 0, 39 | goals_per_obs: int = 1, 40 | normalize: bool = True, 41 | obs_type: str = "image", 42 | goal_type: str = "image", 43 | ): 44 | """ 45 | Main ViNT dataset class 46 | 47 | Args: 48 | data_folder (string): Directory with all the image data 49 | data_split_folder (string): Directory with filepaths.txt, a list of all trajectory names in the dataset split that are each seperated by a newline 50 | dataset_name (string): Name of the dataset [recon, go_stanford, scand, tartandrive, 
etc.] 51 | waypoint_spacing (int): Spacing between waypoints 52 | min_dist_cat (int): Minimum distance category to use 53 | max_dist_cat (int): Maximum distance category to use 54 | negative_mining (bool): Whether to use negative mining from the ViNG paper (Shah et al.) (https://arxiv.org/abs/2012.09812) 55 | len_traj_pred (int): Length of trajectory of waypoints to predict if this is an action dataset 56 | learn_angle (bool): Whether to learn the yaw of the robot at each predicted waypoint if this is an action dataset 57 | context_size (int): Number of previous observations to use as context 58 | context_type (str): Whether to use temporal, randomized, or randomized temporal context 59 | end_slack (int): Number of timesteps to ignore at the end of the trajectory 60 | goals_per_obs (int): Number of goals to sample per observation 61 | normalize (bool): Whether to normalize the distances or actions 62 | goal_type (str): What data type to use for the goal. The only one supported is "image" for now. 63 | """ 64 | self.data_folder = data_folder 65 | self.data_split_folder = data_split_folder 66 | self.dataset_name = dataset_name 67 | 68 | traj_names_file = os.path.join(data_split_folder, "traj_names.txt") 69 | with open(traj_names_file, "r") as f: 70 | file_lines = f.read() 71 | self.traj_names = file_lines.split("\n") 72 | if "" in self.traj_names: 73 | self.traj_names.remove("") 74 | 75 | self.image_size = image_size 76 | self.waypoint_spacing = waypoint_spacing 77 | self.distance_categories = list( 78 | range(min_dist_cat, max_dist_cat + 1, self.waypoint_spacing) 79 | ) 80 | self.min_dist_cat = self.distance_categories[0] 81 | self.max_dist_cat = self.distance_categories[-1] 82 | self.negative_mining = negative_mining 83 | if self.negative_mining: 84 | self.distance_categories.append(-1) 85 | self.len_traj_pred = len_traj_pred 86 | self.learn_angle = learn_angle 87 | 88 | self.min_action_distance = min_action_distance 89 | self.max_action_distance = max_action_distance 90 | 91 | self.context_size = context_size 92 | assert context_type in { 93 | "temporal", 94 | "randomized", 95 | "randomized_temporal", 96 | }, "context_type must be one of temporal, randomized, randomized_temporal" 97 | self.context_type = context_type 98 | self.end_slack = end_slack 99 | self.goals_per_obs = goals_per_obs 100 | self.normalize = normalize 101 | self.obs_type = obs_type 102 | self.goal_type = goal_type 103 | 104 | # load data/data_config.yaml 105 | with open( 106 | os.path.join(os.path.dirname(__file__), "data_config.yaml"), "r" 107 | ) as f: 108 | all_data_config = yaml.safe_load(f) 109 | assert ( 110 | self.dataset_name in all_data_config 111 | ), f"Dataset {self.dataset_name} not found in data_config.yaml" 112 | dataset_names = list(all_data_config.keys()) 113 | dataset_names.sort() 114 | # use this index to retrieve the dataset name from the data_config.yaml 115 | self.dataset_index = dataset_names.index(self.dataset_name) 116 | self.data_config = all_data_config[self.dataset_name] 117 | self.trajectory_cache = {} 118 | self._load_index() 119 | self._build_caches() 120 | 121 | if self.learn_angle: 122 | self.num_action_params = 3 123 | else: 124 | self.num_action_params = 2 125 | 126 | def __getstate__(self): 127 | state = self.__dict__.copy() 128 | state["_image_cache"] = None 129 | return state 130 | 131 | def __setstate__(self, state): 132 | self.__dict__ = state 133 | self._build_caches() 134 | 135 | def _build_caches(self, use_tqdm: bool = True): 136 | """ 137 | Build a cache of images for faster 
loading using LMDB 138 | """ 139 | cache_filename = os.path.join( 140 | self.data_split_folder, 141 | f"dataset_{self.dataset_name}.lmdb", 142 | ) 143 | 144 | # Load all the trajectories into memory. These should already be loaded, but just in case. 145 | for traj_name in self.traj_names: 146 | self._get_trajectory(traj_name) 147 | 148 | """ 149 | If the cache file doesn't exist, create it by iterating through the dataset and writing each image to the cache 150 | """ 151 | if not os.path.exists(cache_filename): 152 | tqdm_iterator = tqdm.tqdm( 153 | self.goals_index, 154 | disable=not use_tqdm, 155 | dynamic_ncols=True, 156 | desc=f"Building LMDB cache for {self.dataset_name}" 157 | ) 158 | with lmdb.open(cache_filename, map_size=2**40) as image_cache: 159 | with image_cache.begin(write=True) as txn: 160 | for traj_name, time in tqdm_iterator: 161 | image_path = get_data_path(self.data_folder, traj_name, time) 162 | with open(image_path, "rb") as f: 163 | txn.put(image_path.encode(), f.read()) 164 | 165 | # Reopen the cache file in read-only mode 166 | self._image_cache: lmdb.Environment = lmdb.open(cache_filename, readonly=True) 167 | 168 | def _build_index(self, use_tqdm: bool = False): 169 | """ 170 | Build an index consisting of tuples (trajectory name, time, max goal distance) 171 | """ 172 | samples_index = [] 173 | goals_index = [] 174 | 175 | for traj_name in tqdm.tqdm(self.traj_names, disable=not use_tqdm, dynamic_ncols=True): 176 | traj_data = self._get_trajectory(traj_name) 177 | traj_len = len(traj_data["position"]) 178 | 179 | for goal_time in range(0, traj_len): 180 | goals_index.append((traj_name, goal_time)) 181 | 182 | begin_time = self.context_size * self.waypoint_spacing 183 | end_time = traj_len - self.end_slack - self.len_traj_pred * self.waypoint_spacing 184 | for curr_time in range(begin_time, end_time): 185 | max_goal_distance = min(self.max_dist_cat * self.waypoint_spacing, traj_len - curr_time - 1) 186 | samples_index.append((traj_name, curr_time, max_goal_distance)) 187 | 188 | return samples_index, goals_index 189 | 190 | def _sample_goal(self, trajectory_name, curr_time, max_goal_dist): 191 | """ 192 | Sample a goal from the future in the same trajectory. 193 | Returns: (trajectory_name, goal_time, goal_is_negative) 194 | """ 195 | goal_offset = np.random.randint(0, max_goal_dist + 1) 196 | if goal_offset == 0: 197 | trajectory_name, goal_time = self._sample_negative() 198 | return trajectory_name, goal_time, True 199 | else: 200 | goal_time = curr_time + int(goal_offset * self.waypoint_spacing) 201 | return trajectory_name, goal_time, False 202 | 203 | def _sample_negative(self): 204 | """ 205 | Sample a goal from a (likely) different trajectory. 
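Returns a (trajectory_name, goal_time) tuple drawn uniformly from goals_index.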
206 | """ 207 | return self.goals_index[np.random.randint(0, len(self.goals_index))] 208 | 209 | def _load_index(self) -> None: 210 | """ 211 | Generates a list of tuples of (obs_traj_name, goal_traj_name, obs_time, goal_time) for each observation in the dataset 212 | """ 213 | index_to_data_path = os.path.join( 214 | self.data_split_folder, 215 | f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_context_{self.context_type}_n{self.context_size}_slack_{self.end_slack}.pkl", 216 | ) 217 | try: 218 | # load the index_to_data if it already exists (to save time) 219 | with open(index_to_data_path, "rb") as f: 220 | self.index_to_data, self.goals_index = pickle.load(f) 221 | except: 222 | # if the index_to_data file doesn't exist, create it 223 | self.index_to_data, self.goals_index = self._build_index() 224 | with open(index_to_data_path, "wb") as f: 225 | pickle.dump((self.index_to_data, self.goals_index), f) 226 | 227 | def _load_image(self, trajectory_name, time): 228 | image_path = get_data_path(self.data_folder, trajectory_name, time) 229 | 230 | try: 231 | with self._image_cache.begin() as txn: 232 | image_buffer = txn.get(image_path.encode()) 233 | image_bytes = bytes(image_buffer) 234 | image_bytes = io.BytesIO(image_bytes) 235 | return img_path_to_data(image_bytes, self.image_size) 236 | except TypeError: 237 | print(f"Failed to load image {image_path}") 238 | 239 | def _compute_actions(self, traj_data, curr_time, goal_time): 240 | start_index = curr_time 241 | end_index = curr_time + self.len_traj_pred * self.waypoint_spacing + 1 242 | yaw = traj_data["yaw"][start_index:end_index:self.waypoint_spacing] 243 | positions = traj_data["position"][start_index:end_index:self.waypoint_spacing] 244 | goal_pos = traj_data["position"][min(goal_time, len(traj_data["position"]) - 1)] 245 | 246 | if len(yaw.shape) == 2: 247 | yaw = yaw.squeeze(1) 248 | 249 | if yaw.shape != (self.len_traj_pred + 1,): 250 | const_len = self.len_traj_pred + 1 - yaw.shape[0] 251 | yaw = np.concatenate([yaw, np.repeat(yaw[-1], const_len)]) 252 | positions = np.concatenate([positions, np.repeat(positions[-1][None], const_len, axis=0)], axis=0) 253 | 254 | assert yaw.shape == (self.len_traj_pred + 1,), f"{yaw.shape} and {(self.len_traj_pred + 1,)} should be equal" 255 | assert positions.shape == (self.len_traj_pred + 1, 2), f"{positions.shape} and {(self.len_traj_pred + 1, 2)} should be equal" 256 | 257 | waypoints = to_local_coords(positions, positions[0], yaw[0]) 258 | goal_pos = to_local_coords(goal_pos, positions[0], yaw[0]) 259 | 260 | assert waypoints.shape == (self.len_traj_pred + 1, 2), f"{waypoints.shape} and {(self.len_traj_pred + 1, 2)} should be equal" 261 | 262 | if self.learn_angle: 263 | yaw = yaw[1:] - yaw[0] 264 | actions = np.concatenate([waypoints[1:], yaw[:, None]], axis=-1) 265 | else: 266 | actions = waypoints[1:] 267 | 268 | if self.normalize: 269 | actions[:, :2] /= self.data_config["metric_waypoint_spacing"] * self.waypoint_spacing 270 | goal_pos /= self.data_config["metric_waypoint_spacing"] * self.waypoint_spacing 271 | 272 | assert actions.shape == (self.len_traj_pred, self.num_action_params), f"{actions.shape} and {(self.len_traj_pred, self.num_action_params)} should be equal" 273 | 274 | return actions, goal_pos 275 | 276 | def _get_trajectory(self, trajectory_name): 277 | if trajectory_name in self.trajectory_cache: 278 | return self.trajectory_cache[trajectory_name] 279 | else: 280 | with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f: 281 | 
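# traj_data.pkl holds the processed odometry for this trajectory: a dict with "position" and "yaw"
# arrays (the format produced by process_odom in process_data_utils.py); it is cached in
# self.trajectory_cache after the first load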
traj_data = pickle.load(f) 282 | self.trajectory_cache[trajectory_name] = traj_data 283 | return traj_data 284 | 285 | def __len__(self) -> int: 286 | return len(self.index_to_data) 287 | 288 | def __getitem__(self, i: int) -> Tuple[torch.Tensor]: 289 | """ 290 | Args: 291 | i (int): index to ith datapoint 292 | Returns: 293 | Tuple of tensors containing the context, observation, goal, transformed context, transformed observation, transformed goal, distance label, and action label 294 | obs_image (torch.Tensor): tensor of shape [3, H, W] containing the image of the robot's observation 295 | goal_image (torch.Tensor): tensor of shape [3, H, W] containing the subgoal image 296 | dist_label (torch.Tensor): tensor of shape (1,) containing the distance labels from the observation to the goal 297 | action_label (torch.Tensor): tensor of shape (5, 2) or (5, 4) (if training with angle) containing the action labels from the observation to the goal 298 | which_dataset (torch.Tensor): index of the datapoint in the dataset [for identifying the dataset for visualization when using multiple datasets] 299 | """ 300 | f_curr, curr_time, max_goal_dist = self.index_to_data[i] 301 | f_goal, goal_time, goal_is_negative = self._sample_goal(f_curr, curr_time, max_goal_dist) 302 | 303 | # Load images 304 | context = [] 305 | if self.context_type == "temporal": 306 | # sample the last self.context_size times from interval [0, curr_time) 307 | context_times = list( 308 | range( 309 | curr_time + -self.context_size * self.waypoint_spacing, 310 | curr_time + 1, 311 | self.waypoint_spacing, 312 | ) 313 | ) 314 | context = [(f_curr, t) for t in context_times] 315 | else: 316 | raise ValueError(f"Invalid context type {self.context_type}") 317 | 318 | obs_image = torch.cat([ 319 | self._load_image(f, t) for f, t in context 320 | ]) 321 | 322 | # Load goal image 323 | goal_image = self._load_image(f_goal, goal_time) 324 | 325 | # Load other trajectory data 326 | curr_traj_data = self._get_trajectory(f_curr) 327 | curr_traj_len = len(curr_traj_data["position"]) 328 | assert curr_time < curr_traj_len, f"{curr_time} and {curr_traj_len}" 329 | 330 | goal_traj_data = self._get_trajectory(f_goal) 331 | goal_traj_len = len(goal_traj_data["position"]) 332 | assert goal_time < goal_traj_len, f"{goal_time} an {goal_traj_len}" 333 | 334 | # Compute actions 335 | actions, goal_pos = self._compute_actions(curr_traj_data, curr_time, goal_time) 336 | 337 | # Compute distances 338 | if goal_is_negative: 339 | distance = self.max_dist_cat 340 | else: 341 | distance = (goal_time - curr_time) // self.waypoint_spacing 342 | assert (goal_time - curr_time) % self.waypoint_spacing == 0, f"{goal_time} and {curr_time} should be separated by an integer multiple of {self.waypoint_spacing}" 343 | 344 | actions_torch = torch.as_tensor(actions, dtype=torch.float32) 345 | if self.learn_angle: 346 | actions_torch = calculate_sin_cos(actions_torch) 347 | 348 | action_mask = ( 349 | (distance < self.max_action_distance) and 350 | (distance > self.min_action_distance) and 351 | (not goal_is_negative) 352 | ) 353 | 354 | return ( 355 | torch.as_tensor(obs_image, dtype=torch.float32), 356 | torch.as_tensor(goal_image, dtype=torch.float32), 357 | actions_torch, 358 | torch.as_tensor(distance, dtype=torch.int64), 359 | torch.as_tensor(goal_pos, dtype=torch.float32), 360 | torch.as_tensor(self.dataset_index, dtype=torch.int64), 361 | torch.as_tensor(action_mask, dtype=torch.float32), 362 | ) 363 | 
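Note: the waypoint and goal labels above are expressed in the robot's current frame via to_local_coords, which is imported from vint_train.data.data_utils and not shown in this excerpt. The snippet below is a minimal sketch of the assumed transform (rotate world-frame offsets into the frame of the current pose), useful only for sanity-checking what _compute_actions returns; it is not the repository's implementation.

import numpy as np

# Sketch only: assumed behaviour of to_local_coords(points, origin, yaw0).
# After the transform, x is the forward distance and y is the leftward offset of each point
# relative to the pose (origin, yaw0), which is how waypoints and goal_pos are used above.
def to_local_coords_sketch(points: np.ndarray, origin: np.ndarray, yaw0: float) -> np.ndarray:
    c, s = np.cos(-yaw0), np.sin(-yaw0)
    rot = np.array([[c, -s], [s, c]])
    return (points - origin) @ rot.T

positions = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 1.0]])   # world-frame (x, y)
local = to_local_coords_sketch(positions, positions[0], yaw0=np.pi / 2)
print(local)  # robot faces +y, so the point [2, 1] maps to roughly [1, -2] (1 m ahead, 2 m to the right)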
-------------------------------------------------------------------------------- /train/vint_train/visualizing/action_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | from typing import Optional, List 6 | import wandb 7 | import yaml 8 | import torch 9 | import torch.nn as nn 10 | from vint_train.visualizing.visualize_utils import ( 11 | to_numpy, 12 | numpy_to_img, 13 | VIZ_IMAGE_SIZE, 14 | RED, 15 | GREEN, 16 | BLUE, 17 | CYAN, 18 | YELLOW, 19 | MAGENTA, 20 | ) 21 | 22 | # load data_config.yaml 23 | with open(os.path.join(os.path.dirname(__file__), "../data/data_config.yaml"), "r") as f: 24 | data_config = yaml.safe_load(f) 25 | 26 | 27 | def visualize_traj_pred( 28 | batch_obs_images: np.ndarray, 29 | batch_goal_images: np.ndarray, 30 | dataset_indices: np.ndarray, 31 | batch_goals: np.ndarray, 32 | batch_pred_waypoints: np.ndarray, 33 | batch_label_waypoints: np.ndarray, 34 | eval_type: str, 35 | normalized: bool, 36 | save_folder: str, 37 | epoch: int, 38 | num_images_preds: int = 8, 39 | use_wandb: bool = True, 40 | display: bool = False, 41 | ): 42 | """ 43 | Compare predicted path with the gt path of waypoints using egocentric visualization. This visualization is for the last batch in the dataset. 44 | 45 | Args: 46 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels] 47 | batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels] 48 | dataset_names: indices corresponding to the dataset name 49 | batch_goals (np.ndarray): batch of goal positions [batch_size, 2] 50 | batch_pred_waypoints (np.ndarray): batch of predicted waypoints [batch_size, horizon, 4] or [batch_size, horizon, 2] or [batch_size, num_trajs_sampled horizon, {2 or 4}] 51 | batch_label_waypoints (np.ndarray): batch of label waypoints [batch_size, T, 4] or [batch_size, horizon, 2] 52 | eval_type (string): f"{data_type}_{eval_type}" (e.g. "recon_train", "gs_test", etc.) 53 | normalized (bool): whether the waypoints are normalized 54 | save_folder (str): folder to save the images. 
If None, will not save the images 55 | epoch (int): current epoch number 56 | num_images_preds (int): number of images to visualize 57 | use_wandb (bool): whether to use wandb to log the images 58 | display (bool): whether to display the images 59 | """ 60 | visualize_path = None 61 | if save_folder is not None: 62 | visualize_path = os.path.join( 63 | save_folder, "visualize", eval_type, f"epoch{epoch}", "action_prediction" 64 | ) 65 | 66 | if not os.path.exists(visualize_path): 67 | os.makedirs(visualize_path) 68 | 69 | assert ( 70 | len(batch_obs_images) 71 | == len(batch_goal_images) 72 | == len(batch_goals) 73 | == len(batch_pred_waypoints) 74 | == len(batch_label_waypoints) 75 | ) 76 | 77 | dataset_names = list(data_config.keys()) 78 | dataset_names.sort() 79 | 80 | batch_size = batch_obs_images.shape[0] 81 | wandb_list = [] 82 | for i in range(min(batch_size, num_images_preds)): 83 | obs_img = numpy_to_img(batch_obs_images[i]) 84 | goal_img = numpy_to_img(batch_goal_images[i]) 85 | dataset_name = dataset_names[int(dataset_indices[i])] 86 | goal_pos = batch_goals[i] 87 | pred_waypoints = batch_pred_waypoints[i] 88 | label_waypoints = batch_label_waypoints[i] 89 | 90 | if normalized: 91 | pred_waypoints *= data_config[dataset_name]["metric_waypoint_spacing"] 92 | label_waypoints *= data_config[dataset_name]["metric_waypoint_spacing"] 93 | goal_pos *= data_config[dataset_name]["metric_waypoint_spacing"] 94 | 95 | save_path = None 96 | if visualize_path is not None: 97 | save_path = os.path.join(visualize_path, f"{str(i).zfill(4)}.png") 98 | 99 | compare_waypoints_pred_to_label( 100 | obs_img, 101 | goal_img, 102 | dataset_name, 103 | goal_pos, 104 | pred_waypoints, 105 | label_waypoints, 106 | save_path, 107 | display, 108 | ) 109 | if use_wandb: 110 | wandb_list.append(wandb.Image(save_path)) 111 | if use_wandb: 112 | wandb.log({f"{eval_type}_action_prediction": wandb_list}, commit=False) 113 | 114 | 115 | def compare_waypoints_pred_to_label( 116 | obs_img, 117 | goal_img, 118 | dataset_name: str, 119 | goal_pos: np.ndarray, 120 | pred_waypoints: np.ndarray, 121 | label_waypoints: np.ndarray, 122 | save_path: Optional[str] = None, 123 | display: Optional[bool] = False, 124 | ): 125 | """ 126 | Compare predicted path with the gt path of waypoints using egocentric visualization. 127 | 128 | Args: 129 | obs_img: image of the observation 130 | goal_img: image of the goal 131 | dataset_name: name of the dataset found in data_config.yaml (e.g. 
"recon") 132 | goal_pos: goal position in the image 133 | pred_waypoints: predicted waypoints in the image 134 | label_waypoints: label waypoints in the image 135 | save_path: path to save the figure 136 | display: whether to display the figure 137 | """ 138 | 139 | fig, ax = plt.subplots(1, 3) 140 | start_pos = np.array([0, 0]) 141 | if len(pred_waypoints.shape) > 2: 142 | trajs = [*pred_waypoints, label_waypoints] 143 | else: 144 | trajs = [pred_waypoints, label_waypoints] 145 | plot_trajs_and_points( 146 | ax[0], 147 | trajs, 148 | [start_pos, goal_pos], 149 | traj_colors=[CYAN, MAGENTA], 150 | point_colors=[GREEN, RED], 151 | ) 152 | plot_trajs_and_points_on_image( 153 | ax[1], 154 | obs_img, 155 | dataset_name, 156 | trajs, 157 | [start_pos, goal_pos], 158 | traj_colors=[CYAN, MAGENTA], 159 | point_colors=[GREEN, RED], 160 | ) 161 | ax[2].imshow(goal_img) 162 | 163 | fig.set_size_inches(18.5, 10.5) 164 | ax[0].set_title(f"Action Prediction") 165 | ax[1].set_title(f"Observation") 166 | ax[2].set_title(f"Goal") 167 | 168 | if save_path is not None: 169 | fig.savefig( 170 | save_path, 171 | bbox_inches="tight", 172 | ) 173 | 174 | if not display: 175 | plt.close(fig) 176 | 177 | 178 | def plot_trajs_and_points_on_image( 179 | ax: plt.Axes, 180 | img: np.ndarray, 181 | dataset_name: str, 182 | list_trajs: list, 183 | list_points: list, 184 | traj_colors: list = [CYAN, MAGENTA], 185 | point_colors: list = [RED, GREEN], 186 | ): 187 | """ 188 | Plot trajectories and points on an image. If there is no configuration for the camera interinstics of the dataset, the image will be plotted as is. 189 | Args: 190 | ax: matplotlib axis 191 | img: image to plot 192 | dataset_name: name of the dataset found in data_config.yaml (e.g. "recon") 193 | list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw) 194 | list_points: list of points, each point is a numpy array of shape (2,) 195 | traj_colors: list of colors for trajectories 196 | point_colors: list of colors for points 197 | """ 198 | assert len(list_trajs) <= len(traj_colors), "Not enough colors for trajectories" 199 | assert len(list_points) <= len(point_colors), "Not enough colors for points" 200 | assert ( 201 | dataset_name in data_config 202 | ), f"Dataset {dataset_name} not found in data/data_config.yaml" 203 | 204 | ax.imshow(img) 205 | if ( 206 | "camera_metrics" in data_config[dataset_name] 207 | and "camera_height" in data_config[dataset_name]["camera_metrics"] 208 | and "camera_matrix" in data_config[dataset_name]["camera_metrics"] 209 | and "dist_coeffs" in data_config[dataset_name]["camera_metrics"] 210 | ): 211 | camera_height = data_config[dataset_name]["camera_metrics"]["camera_height"] 212 | camera_x_offset = data_config[dataset_name]["camera_metrics"]["camera_x_offset"] 213 | 214 | fx = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["fx"] 215 | fy = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["fy"] 216 | cx = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["cx"] 217 | cy = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["cy"] 218 | camera_matrix = gen_camera_matrix(fx, fy, cx, cy) 219 | 220 | k1 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k1"] 221 | k2 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k2"] 222 | p1 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["p1"] 223 | p2 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["p2"] 
224 | k3 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k3"] 225 | dist_coeffs = np.array([k1, k2, p1, p2, k3, 0.0, 0.0, 0.0]) 226 | 227 | for i, traj in enumerate(list_trajs): 228 | xy_coords = traj[:, :2] # (horizon, 2) 229 | traj_pixels = get_pos_pixels( 230 | xy_coords, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=False 231 | ) 232 | if len(traj_pixels.shape) == 2: 233 | ax.plot( 234 | traj_pixels[:250, 0], 235 | traj_pixels[:250, 1], 236 | color=traj_colors[i], 237 | lw=2.5, 238 | ) 239 | 240 | for i, point in enumerate(list_points): 241 | if len(point.shape) == 1: 242 | # add a dimension to the front of point 243 | point = point[None, :2] 244 | else: 245 | point = point[:, :2] 246 | pt_pixels = get_pos_pixels( 247 | point, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=True 248 | ) 249 | ax.plot( 250 | pt_pixels[:250, 0], 251 | pt_pixels[:250, 1], 252 | color=point_colors[i], 253 | marker="o", 254 | markersize=10.0, 255 | ) 256 | ax.xaxis.set_visible(False) 257 | ax.yaxis.set_visible(False) 258 | ax.set_xlim((0.5, VIZ_IMAGE_SIZE[0] - 0.5)) 259 | ax.set_ylim((VIZ_IMAGE_SIZE[1] - 0.5, 0.5)) 260 | 261 | 262 | def plot_trajs_and_points( 263 | ax: plt.Axes, 264 | list_trajs: list, 265 | list_points: list, 266 | traj_colors: list = [CYAN, MAGENTA], 267 | point_colors: list = [RED, GREEN], 268 | traj_labels: Optional[list] = ["prediction", "ground truth"], 269 | point_labels: Optional[list] = ["robot", "goal"], 270 | traj_alphas: Optional[list] = None, 271 | point_alphas: Optional[list] = None, 272 | quiver_freq: int = 1, 273 | default_coloring: bool = True, 274 | ): 275 | """ 276 | Plot trajectories and points that could potentially have a yaw. 277 | 278 | Args: 279 | ax: matplotlib axis 280 | list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw) 281 | list_points: list of points, each point is a numpy array of shape (2,) 282 | traj_colors: list of colors for trajectories 283 | point_colors: list of colors for points 284 | traj_labels: list of labels for trajectories 285 | point_labels: list of labels for points 286 | traj_alphas: list of alphas for trajectories 287 | point_alphas: list of alphas for points 288 | quiver_freq: frequency of quiver plot (if the trajectory data includes the yaw of the robot) 289 | """ 290 | assert ( 291 | len(list_trajs) <= len(traj_colors) or default_coloring 292 | ), "Not enough colors for trajectories" 293 | assert len(list_points) <= len(point_colors), "Not enough colors for points" 294 | assert ( 295 | traj_labels is None or len(list_trajs) == len(traj_labels) or default_coloring 296 | ), "Not enough labels for trajectories" 297 | assert point_labels is None or len(list_points) == len(point_labels), "Not enough labels for points" 298 | 299 | for i, traj in enumerate(list_trajs): 300 | if traj_labels is None: 301 | ax.plot( 302 | traj[:, 0], 303 | traj[:, 1], 304 | color=traj_colors[i], 305 | alpha=traj_alphas[i] if traj_alphas is not None else 1.0, 306 | marker="o", 307 | ) 308 | else: 309 | ax.plot( 310 | traj[:, 0], 311 | traj[:, 1], 312 | color=traj_colors[i], 313 | label=traj_labels[i], 314 | alpha=traj_alphas[i] if traj_alphas is not None else 1.0, 315 | marker="o", 316 | ) 317 | if traj.shape[1] > 2 and quiver_freq > 0: # traj data also includes yaw of the robot 318 | bearings = gen_bearings_from_waypoints(traj) 319 | ax.quiver( 320 | traj[::quiver_freq, 0], 321 | traj[::quiver_freq, 1], 322 | 
bearings[::quiver_freq, 0], 323 | bearings[::quiver_freq, 1], 324 | color=traj_colors[i] * 0.5, 325 | scale=1.0, 326 | ) 327 | for i, pt in enumerate(list_points): 328 | if point_labels is None: 329 | ax.plot( 330 | pt[0], 331 | pt[1], 332 | color=point_colors[i], 333 | alpha=point_alphas[i] if point_alphas is not None else 1.0, 334 | marker="o", 335 | markersize=7.0 336 | ) 337 | else: 338 | ax.plot( 339 | pt[0], 340 | pt[1], 341 | color=point_colors[i], 342 | alpha=point_alphas[i] if point_alphas is not None else 1.0, 343 | marker="o", 344 | markersize=7.0, 345 | label=point_labels[i], 346 | ) 347 | 348 | 349 | # put the legend below the plot 350 | if traj_labels is not None or point_labels is not None: 351 | ax.legend() 352 | ax.legend(bbox_to_anchor=(0.0, -0.5), loc="upper left", ncol=2) 353 | ax.set_aspect("equal", "box") 354 | 355 | 356 | def angle_to_unit_vector(theta): 357 | """Converts an angle to a unit vector.""" 358 | return np.array([np.cos(theta), np.sin(theta)]) 359 | 360 | 361 | def gen_bearings_from_waypoints( 362 | waypoints: np.ndarray, 363 | mag=0.2, 364 | ) -> np.ndarray: 365 | """Generate bearings from waypoints, (x, y, sin(theta), cos(theta)).""" 366 | bearing = [] 367 | for i in range(0, len(waypoints)): 368 | if waypoints.shape[1] > 3: # label is sin/cos repr 369 | v = waypoints[i, 2:] 370 | # normalize v 371 | v = v / np.linalg.norm(v) 372 | v = v * mag 373 | else: # label is radians repr 374 | v = mag * angle_to_unit_vector(waypoints[i, 2]) 375 | bearing.append(v) 376 | bearing = np.array(bearing) 377 | return bearing 378 | 379 | 380 | def project_points( 381 | xy: np.ndarray, 382 | camera_height: float, 383 | camera_x_offset: float, 384 | camera_matrix: np.ndarray, 385 | dist_coeffs: np.ndarray, 386 | ): 387 | """ 388 | Projects 3D coordinates onto a 2D image plane using the provided camera parameters. 389 | 390 | Args: 391 | xy: array of shape (batch_size, horizon, 2) representing (x, y) coordinates 392 | camera_height: height of the camera above the ground (in meters) 393 | camera_x_offset: offset of the camera from the center of the car (in meters) 394 | camera_matrix: 3x3 matrix representing the camera's intrinsic parameters 395 | dist_coeffs: vector of distortion coefficients 396 | 397 | 398 | Returns: 399 | uv: array of shape (batch_size, horizon, 2) representing (u, v) coordinates on the 2D image plane 400 | """ 401 | batch_size, horizon, _ = xy.shape 402 | 403 | # create 3D coordinates with the camera positioned at the given height 404 | xyz = np.concatenate( 405 | [xy, -camera_height * np.ones(list(xy.shape[:-1]) + [1])], axis=-1 406 | ) 407 | 408 | # create dummy rotation and translation vectors 409 | rvec = tvec = (0, 0, 0) 410 | 411 | xyz[..., 0] += camera_x_offset 412 | xyz_cv = np.stack([xyz[..., 1], -xyz[..., 2], xyz[..., 0]], axis=-1) 413 | uv, _ = cv2.projectPoints( 414 | xyz_cv.reshape(batch_size * horizon, 3), rvec, tvec, camera_matrix, dist_coeffs 415 | ) 416 | uv = uv.reshape(batch_size, horizon, 2) 417 | 418 | return uv 419 | 420 | 421 | def get_pos_pixels( 422 | points: np.ndarray, 423 | camera_height: float, 424 | camera_x_offset: float, 425 | camera_matrix: np.ndarray, 426 | dist_coeffs: np.ndarray, 427 | clip: Optional[bool] = False, 428 | ): 429 | """ 430 | Projects 3D coordinates onto a 2D image plane using the provided camera parameters. 
431 | Args: 432 | points: array of shape (batch_size, horizon, 2) representing (x, y) coordinates 433 | camera_height: height of the camera above the ground (in meters) 434 | camera_x_offset: offset of the camera from the center of the car (in meters) 435 | camera_matrix: 3x3 matrix representing the camera's intrinsic parameters 436 | dist_coeffs: vector of distortion coefficients 437 | 438 | Returns: 439 | pixels: array of shape (batch_size, horizon, 2) representing (u, v) coordinates on the 2D image plane 440 | """ 441 | pixels = project_points( 442 | points[np.newaxis], camera_height, camera_x_offset, camera_matrix, dist_coeffs 443 | )[0] 444 | pixels[:, 0] = VIZ_IMAGE_SIZE[0] - pixels[:, 0] 445 | if clip: 446 | pixels = np.array( 447 | [ 448 | [ 449 | np.clip(p[0], 0, VIZ_IMAGE_SIZE[0]), 450 | np.clip(p[1], 0, VIZ_IMAGE_SIZE[1]), 451 | ] 452 | for p in pixels 453 | ] 454 | ) 455 | else: 456 | pixels = np.array( 457 | [ 458 | p 459 | for p in pixels 460 | if np.all(p > 0) and np.all(p < [VIZ_IMAGE_SIZE[0], VIZ_IMAGE_SIZE[1]]) 461 | ] 462 | ) 463 | return pixels 464 | 465 | 466 | def gen_camera_matrix(fx: float, fy: float, cx: float, cy: float) -> np.ndarray: 467 | """ 468 | Args: 469 | fx: focal length in x direction 470 | fy: focal length in y direction 471 | cx: principal point x coordinate 472 | cy: principal point y coordinate 473 | Returns: 474 | camera matrix 475 | """ 476 | return np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]]) 477 | -------------------------------------------------------------------------------- /deployment/src/navigate.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | from typing import Tuple, Sequence, Dict, Union, Optional, Callable 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 8 | 9 | import matplotlib.pyplot as plt 10 | import yaml 11 | import tracemalloc 12 | 13 | # ROS 14 | import rospy 15 | from sensor_msgs.msg import Image 16 | from visualization_msgs.msg import Marker 17 | from geometry_msgs.msg import Point, PoseStamped 18 | from std_msgs.msg import Bool, Float32MultiArray 19 | from utils import msg_to_pil, to_numpy, transform_images, load_model, rotate_point_by_quaternion 20 | 21 | from vint_train.training.train_utils import get_action 22 | from vint_train.visualizing.action_utils import plot_trajs_and_points 23 | from guide import PathGuide, PathOpt 24 | import torch 25 | from PIL import Image as PILImage 26 | import numpy as np 27 | import argparse 28 | import yaml 29 | import time 30 | 31 | # UTILS 32 | from topic_names import (IMAGE_TOPIC, 33 | WAYPOINT_TOPIC, 34 | SUB_GOAL_TOPIC, 35 | POS_TOPIC, 36 | SAMPLED_ACTIONS_TOPIC, 37 | VISUAL_MARKER_TOPIC) 38 | 39 | 40 | # CONSTANTS 41 | TOPOMAP_IMAGES_DIR = "../topomaps/images" 42 | MODEL_WEIGHTS_PATH = "../model_weights" 43 | ROBOT_CONFIG_PATH ="../config/robot.yaml" 44 | MODEL_CONFIG_PATH = "../config/models.yaml" 45 | with open(ROBOT_CONFIG_PATH, "r") as f: 46 | robot_config = yaml.safe_load(f) 47 | MAX_V = robot_config["max_v"] 48 | MAX_W = robot_config["max_w"] 49 | RATE = robot_config["frame_rate"] 50 | ACTION_STATS = {} 51 | ACTION_STATS['min'] = np.array([-2.5, -4]) 52 | ACTION_STATS['max'] = np.array([5, 4]) 53 | 54 | # GLOBALS 55 | context_queue = [] 56 | context_size = None 57 | subgoal = [] 58 | 59 | robo_pos = None 60 | robo_orientation = None 61 | rela_pos = None 62 | # Load the model 63 | device = 
torch.device("cuda" if torch.cuda.is_available() else "cpu") 64 | print("Using device:", device) 65 | 66 | def get_plt_param(uc_actions, gc_actions, goal_pos): 67 | traj_list = np.concatenate([ 68 | uc_actions, 69 | gc_actions, 70 | ], axis=0) 71 | traj_colors = ["red"] * len(uc_actions) + ["green"] * len(gc_actions) + ["magenta"] 72 | traj_alphas = [0.1] * (len(uc_actions) + len(gc_actions)) + [1.0] 73 | 74 | point_list = [np.array([0, 0]), goal_pos] 75 | point_colors = ["green", "red"] 76 | point_alphas = [1.0, 1.0] 77 | return traj_list, traj_colors, traj_alphas, point_list, point_colors, point_alphas 78 | 79 | def action_plot(uc_actions, gc_actions, goal_pos): 80 | traj_list, traj_colors, traj_alphas, point_list, point_colors, point_alphas = get_plt_param(uc_actions, gc_actions, goal_pos) 81 | fig, ax = plt.subplots(1, 1) 82 | plot_trajs_and_points( 83 | ax, 84 | traj_list, 85 | point_list, 86 | traj_colors, 87 | point_colors, 88 | traj_labels=None, 89 | point_labels=None, 90 | quiver_freq=0, 91 | traj_alphas=traj_alphas, 92 | point_alphas=point_alphas, 93 | ) 94 | 95 | save_path = os.path.join(f"output_goal_{rela_pos}.png") 96 | plt.savefig(save_path) 97 | plt.close(fig) 98 | print(f"output image saved as {save_path}") 99 | 100 | def Marker_process(points, id, selected_num, length=8): 101 | marker = Marker() 102 | marker.header.frame_id = "base_link" 103 | marker.header.stamp = rospy.Time.now() 104 | marker.ns= "points" 105 | marker.id = id 106 | marker.type = Marker.LINE_STRIP 107 | marker.action = Marker.ADD 108 | marker.scale.x = 0.01 109 | marker.scale.y = 0.01 110 | marker.scale.z = 0.01 111 | if selected_num == id: 112 | marker.color.a = 1.0 113 | marker.color.r = 1.0 114 | marker.color.g = 0.0 115 | marker.color.b = 0.0 116 | else: 117 | marker.color.a = 1.0 118 | marker.color.r = 0.0 119 | marker.color.g = 0.0 120 | marker.color.b = 1.0 121 | for i in range(length): 122 | p = Point() 123 | p.x = points[2 * i] 124 | p.y = points[2 * i + 1] 125 | p.z = 0 126 | marker.points.append(p) 127 | return marker 128 | 129 | def Marker_process_goal(points, marker, length=1): 130 | marker.header.frame_id = "base_link" 131 | marker.header.stamp = rospy.Time.now() 132 | marker.ns= "points" 133 | marker.id = 0 134 | marker.type = Marker.POINTS 135 | marker.action = Marker.ADD 136 | marker.scale.x = 0.1 137 | marker.scale.y = 0.1 138 | marker.color.a = 1.0 139 | marker.color.r = 1.0 140 | marker.color.g = 0.0 141 | marker.color.b = 0.0 142 | 143 | for i in range(length): 144 | p = Point() 145 | p.x = points[2 * i] 146 | p.y = points[2 * i + 1] 147 | p.z = 1 148 | marker.points.append(p) 149 | return marker 150 | 151 | def callback_obs(msg): 152 | obs_img = msg_to_pil(msg) 153 | if obs_img.mode == 'RGBA': 154 | obs_img = obs_img.convert('RGB') 155 | else: 156 | obs_img = obs_img 157 | if context_size is not None: 158 | if len(context_queue) < context_size + 1: 159 | context_queue.append(obs_img) 160 | else: 161 | context_queue.pop(0) 162 | context_queue.append(obs_img) 163 | 164 | def pos_callback(msg): 165 | global robo_pos, robo_orientation 166 | robo_pos = np.array([msg.pose.position.x, msg.pose.position.y, msg.pose.position.z]) 167 | robo_orientation = np.array([msg.pose.orientation.x, msg.pose.orientation.y, 168 | msg.pose.orientation.z, msg.pose.orientation.w]) 169 | 170 | def main(args: argparse.Namespace): 171 | global context_size, robo_pos, robo_orientation, rela_pos 172 | 173 | # load model parameters 174 | with open(MODEL_CONFIG_PATH, "r") as f: 175 | model_paths = yaml.safe_load(f) 
176 | 177 | model_config_path = model_paths[args.model]["config_path"] 178 | with open(model_config_path, "r") as f: 179 | model_params = yaml.safe_load(f) 180 | 181 | if args.pos_goal: 182 | with open(os.path.join(TOPOMAP_IMAGES_DIR, args.dir, "position.txt"), 'r') as file: 183 | lines = file.readlines() 184 | 185 | context_size = model_params["context_size"] 186 | 187 | # load model weights 188 | ckpth_path = model_paths[args.model]["ckpt_path"] 189 | if os.path.exists(ckpth_path): 190 | print(f"Loading model from {ckpth_path}") 191 | else: 192 | raise FileNotFoundError(f"Model weights not found at {ckpth_path}") 193 | model = load_model( 194 | ckpth_path, 195 | model_params, 196 | device, 197 | ) 198 | model = model.to(device) 199 | model.eval() 200 | 201 | pathguide = PathGuide(device, ACTION_STATS) 202 | pathopt = PathOpt() 203 | # load topomap 204 | topomap_filenames = sorted([filename for filename in os.listdir(os.path.join( 205 | TOPOMAP_IMAGES_DIR, args.dir)) if filename.endswith('.png')], 206 | key=lambda x: int(x.split(".")[0])) 207 | topomap_dir = f"{TOPOMAP_IMAGES_DIR}/{args.dir}" 208 | num_nodes = len(topomap_filenames) 209 | topomap = [] 210 | for i in range(num_nodes): 211 | image_path = os.path.join(topomap_dir, topomap_filenames[i]) 212 | topomap.append(PILImage.open(image_path)) 213 | 214 | closest_node = args.init_node 215 | assert -1 <= args.goal_node < len(topomap), "Invalid goal index" 216 | if args.goal_node == -1: 217 | goal_node = len(topomap) - 1 218 | else: 219 | goal_node = args.goal_node 220 | 221 | # ROS 222 | rospy.init_node("EXPLORATION", anonymous=False) 223 | rate = rospy.Rate(RATE) 224 | image_curr_msg = rospy.Subscriber( 225 | IMAGE_TOPIC, Image, callback_obs, queue_size=1) 226 | 227 | if args.pos_goal: 228 | pos_curr_msg = rospy.Subscriber( 229 | POS_TOPIC, PoseStamped, pos_callback, queue_size=1) 230 | subgoal_pub = rospy.Publisher( 231 | SUB_GOAL_TOPIC, Marker, queue_size=1) 232 | robogoal_pub = rospy.Publisher( 233 | '/goal1', Marker, queue_size=1) 234 | waypoint_pub = rospy.Publisher( 235 | WAYPOINT_TOPIC, Float32MultiArray, queue_size=1) 236 | sampled_actions_pub = rospy.Publisher(SAMPLED_ACTIONS_TOPIC, Float32MultiArray, queue_size=1) 237 | goal_pub = rospy.Publisher("/topoplan/reached_goal", Bool, queue_size=1) 238 | marker_pub = rospy.Publisher(VISUAL_MARKER_TOPIC, Marker, queue_size=10) 239 | 240 | print("Registered with master node. 
Waiting for image observations...") 241 | 242 | if model_params["model_type"] == "nomad": 243 | num_diffusion_iters = model_params["num_diffusion_iters"] 244 | noise_scheduler = DDPMScheduler( 245 | num_train_timesteps=model_params["num_diffusion_iters"], 246 | beta_schedule='squaredcos_cap_v2', 247 | clip_sample=True, 248 | prediction_type='epsilon' 249 | ) 250 | 251 | scale = 4.0 252 | scale_factor = scale * MAX_V / RATE 253 | # navigation loop 254 | while not rospy.is_shutdown(): 255 | chosen_waypoint = np.zeros(4) 256 | if len(context_queue) > model_params["context_size"]: 257 | if model_params["model_type"] == "nomad": 258 | obs_images = transform_images(context_queue, model_params["image_size"], center_crop=False) 259 | if args.guide: 260 | pathguide.get_cost_map_via_tsdf(context_queue[-1]) 261 | obs_images = torch.split(obs_images, 3, dim=1) 262 | obs_images = torch.cat(obs_images, dim=1) 263 | obs_images = obs_images.to(device) 264 | start = max(closest_node - args.radius, 0) 265 | end = min(closest_node + args.radius + 1, goal_node) 266 | if args.pos_goal: 267 | mask = torch.ones(1).long().to(device) 268 | goal_pos = np.array([float(lines[end].split()[0]), float(lines[end].split()[1]), float(lines[end].split()[2])]) 269 | rela_pos = goal_pos - robo_pos 270 | rela_pos = rotate_point_by_quaternion(rela_pos, robo_orientation)[:2] 271 | print('rela_pos: ', rela_pos) 272 | marker_robogoal = Marker() 273 | Marker_process_goal(rela_pos[:2], marker_robogoal, 1) 274 | robogoal_pub.publish(marker_robogoal) 275 | else: 276 | mask = torch.zeros(1).long().to(device) 277 | goal_image = [transform_images(g_img, model_params["image_size"], center_crop=False).to(device) for g_img in topomap[start:end + 1]] 278 | goal_image = torch.concat(goal_image, dim=0) 279 | obsgoal_cond = model('vision_encoder', obs_img=obs_images.repeat(len(goal_image), 1, 1, 1), goal_img=goal_image, input_goal_mask=mask.repeat(len(goal_image))) 280 | if args.pos_goal: 281 | goal_poses = np.array([[float(lines[i].split()[0]), float(lines[i].split()[1]), float(lines[i].split()[2])] for i in range(start, end + 1)]) 282 | min_idx = np.argmin(np.linalg.norm(goal_poses - robo_pos, axis=1)) 283 | sg_idx = min_idx 284 | else: 285 | dists = model("dist_pred_net", obsgoal_cond=obsgoal_cond) 286 | dists = to_numpy(dists.flatten()) 287 | min_idx = np.argmin(dists) 288 | sg_idx = min(min_idx + int(dists[min_idx] < args.close_threshold), len(obsgoal_cond) - 1) 289 | time4 = time.time() 290 | closest_node = min_idx + start 291 | print("closest node:", closest_node) 292 | 293 | obs_cond = obsgoal_cond[sg_idx].unsqueeze(0) 294 | # infer action 295 | with torch.no_grad(): 296 | # encoder vision features 297 | if len(obs_cond.shape) == 2: 298 | obs_cond = obs_cond.repeat(args.num_samples, 1) 299 | else: 300 | obs_cond = obs_cond.repeat(args.num_samples, 1, 1) 301 | 302 | # initialize action from Gaussian noise 303 | noisy_action = torch.randn( 304 | (args.num_samples, model_params["len_traj_pred"], 2), device=device) 305 | naction = noisy_action 306 | 307 | # init scheduler 308 | noise_scheduler.set_timesteps(num_diffusion_iters) 309 | 310 | start_time = time.time() 311 | for k in noise_scheduler.timesteps[:]: 312 | with torch.no_grad(): 313 | # predict noise 314 | noise_pred = model( 315 | 'noise_pred_net', 316 | sample=naction, 317 | timestep=k, 318 | global_cond=obs_cond 319 | ) 320 | # inverse diffusion step (remove noise) 321 | naction = noise_scheduler.step( 322 | model_output=noise_pred, 323 | timestep=k, 324 | sample=naction 325 | 
).prev_sample 326 | if args.guide: 327 | interval1 = 6 328 | period = 1 329 | if k <= interval1: 330 | if k % period == 0: 331 | if k > 2: 332 | grad, cost_list = pathguide.get_gradient(naction, goal_pos=rela_pos, scale_factor=scale_factor) 333 | grad_scale = 1.0 334 | naction -= grad_scale * grad 335 | else: 336 | if k>=0 and k <= 2: 337 | naction_tmp = naction.detach().clone() 338 | for i in range(1): 339 | grad, cost_list = pathguide.get_gradient(naction_tmp, goal_pos=rela_pos, scale_factor=scale_factor) 340 | naction_tmp -= grad 341 | naction = naction_tmp 342 | 343 | naction = to_numpy(get_action(naction)) 344 | naction_selected, selected_num = pathopt.select_trajectory(naction, l=args.waypoint, angle_threshold=45) 345 | sampled_actions_msg = Float32MultiArray() 346 | sampled_actions_msg.data = np.concatenate((np.array([0]), naction.flatten())) 347 | for i in range(8): 348 | marker = Marker_process(sampled_actions_msg.data[i * 16 + 1 : (i + 1) * 16 + 1] * scale_factor, i, selected_num) 349 | marker_pub.publish(marker) 350 | print("published sampled actions") 351 | sampled_actions_pub.publish(sampled_actions_msg) 352 | 353 | chosen_waypoint = naction_selected[args.waypoint] 354 | elif (len(context_queue) > model_params["context_size"]): 355 | start = max(closest_node - args.radius, 0) 356 | end = min(closest_node + args.radius + 1, goal_node) 357 | distances = [] 358 | waypoints = [] 359 | batch_obs_imgs = [] 360 | batch_goal_data = [] 361 | for i, sg_img in enumerate(topomap[start: end + 1]): 362 | transf_obs_img = transform_images(context_queue, model_params["image_size"]) 363 | goal_data = transform_images(sg_img, model_params["image_size"]) 364 | batch_obs_imgs.append(transf_obs_img) 365 | batch_goal_data.append(goal_data) 366 | 367 | # predict distances and waypoints 368 | batch_obs_imgs = torch.cat(batch_obs_imgs, dim=0).to(device) 369 | batch_goal_data = torch.cat(batch_goal_data, dim=0).to(device) 370 | 371 | distances, waypoints = model(batch_obs_imgs, batch_goal_data) 372 | distances = to_numpy(distances) 373 | waypoints = to_numpy(waypoints) 374 | # look for closest node 375 | if args.pos_goal: 376 | goal_poses = np.array([[float(lines[i].split()[0]), float(lines[i].split()[1]), float(lines[i].split()[2])] for i in range(start, end + 1)]) 377 | closest_node = np.argmin(np.linalg.norm(goal_poses - robo_pos, axis=1)) 378 | else: 379 | closest_node = np.argmin(distances) 380 | # chose subgoal and output waypoints 381 | if distances[closest_node] > args.close_threshold: 382 | chosen_waypoint = waypoints[closest_node][args.waypoint] 383 | sg_img = topomap[start + closest_node] 384 | else: 385 | chosen_waypoint = waypoints[min( 386 | closest_node + 1, len(waypoints) - 1)][args.waypoint] 387 | sg_img = topomap[start + min(closest_node + 1, len(waypoints) - 1)] 388 | 389 | if model_params["normalize"]: 390 | chosen_waypoint[:2] *= (scale_factor / scale) 391 | 392 | waypoint_msg = Float32MultiArray() 393 | waypoint_msg.data = chosen_waypoint 394 | waypoint_pub.publish(waypoint_msg) 395 | 396 | torch.cuda.empty_cache() 397 | 398 | rate.sleep() 399 | 400 | 401 | if __name__ == "__main__": 402 | parser = argparse.ArgumentParser( 403 | description="Code to run GNM DIFFUSION EXPLORATION on the locobot") 404 | parser.add_argument( 405 | "--model", 406 | "-m", 407 | default="nomad", 408 | type=str, 409 | help="model name (only nomad is supported) (hint: check ../config/models.yaml) (default: nomad)", 410 | ) 411 | parser.add_argument( 412 | "--waypoint", 413 | "-w", 414 | default=2, # close 
waypoints exhibit straight line motion (the middle waypoint is a good default) 415 | type=int, 416 | help=f"""index of the waypoint used for navigation (between 0 and 4 or 417 | how many waypoints your model predicts) (default: 2)""", 418 | ) 419 | parser.add_argument( 420 | "--dir", 421 | "-d", 422 | default="topomap", 423 | type=str, 424 | help="path to topomap images", 425 | ) 426 | parser.add_argument( 427 | "--init-node", 428 | "-i", 429 | default=0, 430 | type=int, 431 | help="""initial node index in the topomap from which to start 432 | localizing (default: 0)""", 433 | ) 434 | parser.add_argument( 435 | "--goal-node", 436 | "-g", 437 | default=-1, 438 | type=int, 439 | help="""goal node index in the topomap (if -1, then the goal node is 440 | the last node in the topomap) (default: -1)""", 441 | ) 442 | parser.add_argument( 443 | "--close-threshold", 444 | "-t", 445 | default=3, 446 | type=int, 447 | help="""temporal distance within the next node in the topomap before 448 | localizing to it (default: 3)""", 449 | ) 450 | parser.add_argument( 451 | "--radius", 452 | "-r", 453 | default=4, 454 | type=int, 455 | help="""number of local nodes to look at in the topomap for 456 | localization (default: 4)""", 457 | ) 458 | parser.add_argument( 459 | "--num-samples", 460 | "-n", 461 | default=8, 462 | type=int, 463 | help=f"Number of actions sampled from the exploration model (default: 8)", 464 | ) 465 | parser.add_argument( 466 | "--guide", 467 | default=True, 468 | type=bool,  # note: argparse's type=bool treats any non-empty string as True 469 | ) 470 | parser.add_argument( 471 | "--pos-goal",  # parsed as args.pos_goal, which is the attribute main() reads 472 | default=False, 473 | type=bool, 474 | ) 475 | args = parser.parse_args() 476 | print(f"Using {device}") 477 | main(args) 478 | 479 | 480 | --------------------------------------------------------------------------------
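Added note: the diffusion branch of `navigate.py` above follows a cost-guided sampling pattern: run a standard DDPM reverse step, then, on the later (low-noise) timesteps, push the sampled trajectories down a cost gradient before continuing. The self-contained sketch below reproduces only that structure; the toy noise predictor, the quadratic goal-distance cost, the guidance threshold, and the step size are illustrative stand-ins for the real NoMaD policy and the TSDF-based `PathGuide.get_gradient`.

```python
# Standalone sketch of cost-guided DDPM sampling (structure mirrors navigate.py;
# the noise predictor and cost function below are toy stand-ins, not the real
# NoMaD policy or the TSDF-based PathGuide).
import torch
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

num_samples, horizon = 8, 8          # trajectories sampled / waypoints per trajectory
num_diffusion_iters = 10
goal = torch.tensor([0.5, 0.3])      # assumed goal position in the robot frame

def toy_noise_pred(sample, t):
    # Stand-in for model('noise_pred_net', ...): predicts the noise to remove.
    return 0.1 * torch.randn_like(sample)

def goal_cost_gradient(actions, goal):
    # Stand-in for PathGuide.get_gradient: gradient of a simple cost that
    # penalizes the distance between the trajectory endpoint and the goal.
    actions = actions.detach().requires_grad_(True)
    endpoints = actions.cumsum(dim=1)[:, -1, :]   # integrate action deltas to positions
    cost = ((endpoints - goal) ** 2).sum()
    (grad,) = torch.autograd.grad(cost, actions)
    return grad

scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    beta_schedule="squaredcos_cap_v2",
    clip_sample=True,
    prediction_type="epsilon",
)
scheduler.set_timesteps(num_diffusion_iters)

naction = torch.randn(num_samples, horizon, 2)    # start from Gaussian noise
for k in scheduler.timesteps:
    noise_pred = toy_noise_pred(naction, k)
    naction = scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample
    if k <= 6:                                    # only guide the low-noise steps
        grad = goal_cost_gradient(naction, goal)
        naction = naction - 0.1 * grad            # nudge samples toward lower cost

print(naction.shape)  # (8, 8, 2): guided action samples
```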