├── train
│   ├── vint_train
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── data_config.yaml
│   │   │   ├── data_utils.py
│   │   │   └── vint_dataset.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── nomad
│   │   │   │   ├── __init__.py
│   │   │   │   ├── vib_placeholder.py
│   │   │   │   ├── README.md
│   │   │   │   ├── nomad.py
│   │   │   │   └── nomad_vint.py
│   │   │   └── base_model.py
│   │   ├── training
│   │   │   ├── __init__.py
│   │   │   ├── logger.py
│   │   │   └── train_eval_loop.py
│   │   ├── process_data
│   │   │   ├── __init__.py
│   │   │   ├── process_bags_config.yaml
│   │   │   └── process_data_utils.py
│   │   └── visualizing
│   │       ├── __init__.py
│   │       ├── visualize_utils.py
│   │       ├── distance_utils.py
│   │       └── action_utils.py
│   ├── setup.py
│   ├── train_environment.yml
│   ├── config
│   │   ├── defaults.yaml
│   │   └── nomad.yaml
│   ├── process_recon.py
│   ├── data_split.py
│   ├── process_bags.py
│   ├── process_bag_diff.py
│   └── train.py
├── deployment
│   ├── config
│   │   ├── params.yaml
│   │   ├── joystick.yaml
│   │   ├── camera_front.yaml
│   │   ├── camera_reverse.yaml
│   │   ├── models.yaml
│   │   ├── robot.yaml
│   │   └── cmd_vel_mux.yaml
│   ├── src
│   │   ├── deployment_environment.yml
│   │   ├── joy_teleop.sh
│   │   ├── record_bag.sh
│   │   ├── create_topomap.sh
│   │   ├── vint_locobot.launch
│   │   ├── ros_data.py
│   │   ├── topic_names.py
│   │   ├── navigate.sh
│   │   ├── joy_teleop.py
│   │   ├── create_topomap.py
│   │   ├── pd_controller.py
│   │   ├── utils.py
│   │   ├── costmap_cfg.py
│   │   ├── cost_to_pcd.py
│   │   ├── tsdf_cost_map.py
│   │   ├── guide.py
│   │   └── navigate.py
│   └── deployment_environment.yaml
├── assets
│   └── pipeline.png
├── LICENSE
├── .gitignore
└── README.md
/train/vint_train/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/models/nomad/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/process_data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/visualizing/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/models/nomad/vib_placeholder.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/deployment/config/params.yaml:
--------------------------------------------------------------------------------
1 | image_path: "../topomaps/images/"
2 |
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/NaviD/HEAD/assets/pipeline.png
--------------------------------------------------------------------------------
/train/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="vint_train",
5 | version="0.1.0",
6 | packages=find_packages(),
7 | )
8 |
--------------------------------------------------------------------------------
/train/vint_train/models/nomad/README.md:
--------------------------------------------------------------------------------
1 | ### NoMaD (Navigation with Goal Masked Diffusion)
2 |
3 | Files
4 | - `nomad.py` : Main model file
5 | - `nomad_vint.py` : Implementation of ViNT class for NoMaD
6 |
--------------------------------------------------------------------------------
/deployment/config/joystick.yaml:
--------------------------------------------------------------------------------
1 | # joystick parameters for src/gnm_locobot.launch
2 | dev: "/dev/input/js0" # change this to your joystick device path
3 |
4 | # joystick parameters for src/joy_teleop.py
5 | deadman_switch: 5 # button index
6 | lin_vel_button: 4
7 | ang_vel_button: 0
8 |
--------------------------------------------------------------------------------
/deployment/config/camera_front.yaml:
--------------------------------------------------------------------------------
1 | # camera parameters for src/gnm_locobot.launch
2 | video_device: "/dev/video0" # change this to your video device path
3 | image_width: 160
4 | image_height: 120
5 | pixel_format: yuyv
6 | camera_frame_id: "usb_cam"
7 | io_method: "mmap"
8 | framerate: 9
--------------------------------------------------------------------------------
/deployment/config/camera_reverse.yaml:
--------------------------------------------------------------------------------
1 | # camera parameters for src/gnm_locobot.launch
2 | video_device: "/dev/video2" # change this to your video device path
3 | image_width: 160
4 | image_height: 120
5 | pixel_format: yuyv
6 | camera_frame_id: "usb_cam"
7 | io_method: "mmap"
8 | framerate: 9
--------------------------------------------------------------------------------
/deployment/src/deployment_environment.yml:
--------------------------------------------------------------------------------
1 | name: nomad_train
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | dependencies:
6 | - python=3.8.5
7 | - cudatoolkit=11.
8 | - torchvision
9 | - numpy
10 | - matplotlib
11 | - pyyaml
12 | - rospkg
13 | - pip:
14 | - torch
15 | - torchvision
16 | - efficientnet_pytorch
17 | - warmup_scheduler
--------------------------------------------------------------------------------
/deployment/config/models.yaml:
--------------------------------------------------------------------------------
1 | nomad:
2 | config_path: "../../train/config/nomad.yaml"
3 | ckpt_path: "../checkpoints/nomad.pth"
4 |
5 | navidiffusor:
6 | config_path: "../../train/config/nomad.yaml"
7 | ckpt_path: "../checkpoints/navidiffusor.pth"
8 |
9 |
10 | # add your own model configs here after saving the *.pth file to ../model_weight
--------------------------------------------------------------------------------
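For reference, a minimal sketch of how a deployment script might resolve one of these entries (paths are assumed relative to deployment/src, matching the "../" prefixes above; the repo's own loading code in navigate.py is not reproduced in this dump):

# Hedged sketch: resolving a model entry from models.yaml.
import yaml
import torch

MODEL_CONFIG_PATH = "../config/models.yaml"  # as laid out in this repo

def load_model_paths(model_name: str = "nomad"):
    with open(MODEL_CONFIG_PATH, "r") as f:
        model_paths = yaml.safe_load(f)
    entry = model_paths[model_name]
    with open(entry["config_path"], "r") as f:
        model_params = yaml.safe_load(f)          # training config for this model
    ckpt = torch.load(entry["ckpt_path"], map_location="cpu")  # illustrative checkpoint load
    return model_params, ckpt
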
/deployment/deployment_environment.yaml:
--------------------------------------------------------------------------------
1 | name: vint_deployment
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | dependencies:
6 | - python=3.8.5
7 | - cudatoolkit=11.
8 | - torchvision
9 | - numpy
10 | - matplotlib
11 | - pyyaml
12 | - rospkg
13 | - pip:
14 | - torch
15 | - torchvision
16 | - efficientnet_pytorch
17 | - warmup_scheduler
18 | - diffusers==0.11.1
--------------------------------------------------------------------------------
/deployment/config/robot.yaml:
--------------------------------------------------------------------------------
1 | # linear and angular speed limits for the robot
2 | max_v: 0.4 #0.4 # m/s
3 | max_w: 0.8 #0.8 # rad/s
4 | # observation rate of the robot
5 | frame_rate: 4 # Hz
6 | graph_rate: 0.3333 # Hz
7 |
8 | # topic names (modify for different robots/nodes)
9 | vel_teleop_topic: /cmd_vel_mux/input/teleop
10 | vel_navi_topic: /cmd_vel
11 | vel_recovery_topic: /cmd_vel_mux/input/recovery
12 |
13 |
14 |
--------------------------------------------------------------------------------
/deployment/config/cmd_vel_mux.yaml:
--------------------------------------------------------------------------------
1 | subscribers:
2 | - name: "gnm vels"
3 | topic: "/cmd_vel_mux/input/navi"
4 | timeout: 0.1
5 | priority: 0
6 | short_desc: "The default cmd_vel, controllers unaware that we are multiplexing cmd_vel should come here"
7 | - name: "teleop"
8 | topic: "/cmd_vel_mux/input/teleop"
9 | timeout: 0.5
10 | priority: 2
11 | short_desc: "Navigation stack controller"
12 | - name: "gnm recovery"
13 | topic: "/cmd_vel_mux/input/recovery"
14 | timeout: 0.1
15 | priority: 1
16 | publisher: "/mobile_base/commands/velocity"
17 |
--------------------------------------------------------------------------------
/deployment/src/joy_teleop.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Create a new tmux session
4 | session_name="teleop_locobot_$(date +%s)"
5 | tmux new-session -d -s $session_name
6 |
7 | # Split the window into two panes
8 | tmux selectp -t 0 # select the first (0) pane
9 | tmux splitw -v -p 50 # split it into two halves
10 |
11 | # Run the roslaunch command in the first pane
12 | tmux select-pane -t 0
13 | tmux send-keys "roslaunch gnm_locobot.launch" Enter
14 |
15 | # Run the teleop.py script in the second pane
16 | tmux select-pane -t 1
17 | tmux send-keys "conda activate gnm_deployment" Enter
18 | tmux send-keys "python joy_teleop.py" Enter
19 |
20 | # Attach to the tmux session
21 | tmux -2 attach-session -t $session_name
--------------------------------------------------------------------------------
/train/vint_train/visualizing/visualize_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 |
5 | VIZ_IMAGE_SIZE = (640, 480)
6 | RED = np.array([1, 0, 0])
7 | GREEN = np.array([0, 1, 0])
8 | BLUE = np.array([0, 0, 1])
9 | CYAN = np.array([0, 1, 1])
10 | YELLOW = np.array([1, 1, 0])
11 | MAGENTA = np.array([1, 0, 1])
12 |
13 |
14 | def numpy_to_img(arr: np.ndarray) -> Image.Image:
15 | img = Image.fromarray(np.transpose(np.uint8(255 * arr), (1, 2, 0)))
16 | img = img.resize(VIZ_IMAGE_SIZE)
17 | return img
18 |
19 |
20 | def to_numpy(tensor: torch.Tensor) -> np.ndarray:
21 | return tensor.detach().cpu().numpy()
22 |
23 |
24 | def from_numpy(array: np.ndarray) -> torch.Tensor:
25 | return torch.from_numpy(array).float()
26 |
--------------------------------------------------------------------------------
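A short usage example for the helpers above (array shapes are illustrative):

# Round trip with the visualization helpers: a CHW float image in [0, 1]
# becomes a resized PIL image, and tensors move to/from numpy.
import numpy as np
from vint_train.visualizing.visualize_utils import numpy_to_img, to_numpy, from_numpy

chw = np.random.rand(3, 120, 160)        # channel-first float image
pil_img = numpy_to_img(chw)              # PIL image resized to VIZ_IMAGE_SIZE (640x480)
t = from_numpy(chw)                      # float32 torch tensor
back = to_numpy(t)                       # detached numpy copy
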
/train/train_environment.yml:
--------------------------------------------------------------------------------
1 | name: navidiffusor
2 | channels:
3 | - defaults
4 | - pytorch
5 | dependencies:
6 | - python=3.8.5
7 | # - cudatoolkit=10.
8 | - numpy
9 | - matplotlib
10 | - ipykernel
11 | - pip
12 | - pip:
13 | - torch
14 | - torchvision
15 | - tqdm==4.64.0
16 | - git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
17 | - opencv-python==4.6.0.66
18 | - h5py==3.6.0
19 | - wandb==0.12.18
20 | - --extra-index-url https://rospypi.github.io/simple/
21 | - rosbag
22 | - roslz4
23 | - prettytable
24 | - efficientnet-pytorch
25 | - warmup-scheduler
26 | - diffusers==0.11.1
27 | - lmdb
28 | - vit-pytorch
29 | - positional-encodings
30 | - scipy
31 | - open3d
32 | - rospkg
33 | - pypose
34 |
35 |
36 |
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/train/vint_train/process_data/process_bags_config.yaml:
--------------------------------------------------------------------------------
1 | tartan_drive:
2 | odomtopics: "/odometry/filtered_odom"
3 | imtopics: "/multisense/left/image_rect_color"
4 | ang_offset: 1.5707963267948966 # pi/2
5 | img_process_func: "process_tartan_img"
6 | odom_process_func: "nav_to_xy_yaw"
7 |
8 | scand:
9 | odomtopics: ["/odom", "/jackal_velocity_controller/odom"]
10 | imtopics: ["/image_raw/compressed", "/camera/rgb/image_raw/compressed"]
11 | ang_offset: 0.0
12 | img_process_func: "process_scand_img"
13 | odom_process_func: "nav_to_xy_yaw"
14 |
15 | locobot:
16 | odomtopics: "/odom"
17 | imtopics: "/usb_cam/image_raw"
18 | ang_offset: 0.0
19 | img_process_func: "process_locobot_img"
20 | odom_process_func: "nav_to_xy_yaw"
21 |
22 | sacson:
23 | odomtopics: "/odometry"
24 | imtopics: "/fisheye_image/compressed"
25 | ang_offset: 0.0
26 | img_process_func: "process_sacson_img"
27 | odom_process_func: "nav_to_xy_yaw"
28 |
29 | # add your own datasets below:
30 |
--------------------------------------------------------------------------------
/deployment/src/record_bag.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Create a new tmux session
4 | session_name="record_bag_$(date +%s)"
5 | tmux new-session -d -s $session_name
6 |
7 | # Split the window into three panes
8 | tmux selectp -t 0 # select the first (0) pane
9 | tmux splitw -v -p 50 # split it into two halves
10 | tmux selectp -t 0 # go back to the first pane
11 | tmux splitw -h -p 50 # split it into two halves
12 |
13 | # Run the roslaunch command in the first pane
14 | tmux select-pane -t 0
15 | tmux send-keys "roslaunch vint_locobot.launch" Enter
16 |
17 | # Run the teleop.py script in the second pane
18 | tmux select-pane -t 1
19 | tmux send-keys "conda activate vint_deployment" Enter
20 | tmux send-keys "python joy_teleop.py" Enter
21 |
22 | # Change the directory to ../topomaps/bags and run the rosbag record command in the third pane
23 | tmux select-pane -t 2
24 | tmux send-keys "cd ../topomaps/bags" Enter
25 | tmux send-keys "rosbag record /usb_cam/image_raw -o $1" # change topic if necessary
26 |
27 | # Attach to the tmux session
28 | tmux -2 attach-session -t $session_name
--------------------------------------------------------------------------------
/deployment/src/create_topomap.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Create a new tmux session
4 | session_name="gnm_locobot_$(date +%s)"
5 | tmux new-session -d -s $session_name
6 |
7 | # Split the window into three panes
8 | tmux selectp -t 0 # select the first (0) pane
9 | tmux splitw -v -p 50 # split it into two halves
10 | tmux selectp -t 0 # go back to the first pane
11 | tmux splitw -h -p 50 # split it into two halves
12 |
13 | # Run roscore in the first pane
14 | tmux select-pane -t 0
15 | tmux send-keys "roscore" Enter
16 |
17 | # Run the create_topomap.py script with command line args in the second pane
18 | tmux select-pane -t 1
19 | tmux send-keys "conda activate navidiffusor" Enter
20 | tmux send-keys "python create_topomap.py --dt 1 --dir $1" Enter
21 |
22 | # Change the directory to ../topomaps/bags and run the rosbag play command in the third pane
23 | tmux select-pane -t 2
24 | tmux send-keys "mkdir -p ../topomaps/bags" Enter
25 | tmux send-keys "cd ../topomaps/bags" Enter
26 | tmux send-keys "rosbag play -r 0.7 $2" # feel free to change the playback rate to change the edge length in the graph
27 |
28 | # Attach to the tmux session
29 | tmux -2 attach-session -t $session_name
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Dhruv Shah, Ajay Sridhar, Nitish Dashora, Kyle Stachowicz, Kevin Black, Noriaki Hirose, Sergey Levine
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/deployment/src/vint_locobot.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/deployment/src/ros_data.py:
--------------------------------------------------------------------------------
1 | import rospy
2 |
3 | class ROSData:
4 | def __init__(self, timeout: int = 3, queue_size: int = 1, name: str = ""):
5 | self.timeout = timeout
6 | self.last_time_received = float("-inf")
7 | self.queue_size = queue_size
8 | self.data = None
9 | self.name = name
10 | self.phantom = False
11 |
12 | def get(self):
13 | return self.data
14 |
15 | def set(self, data):
16 | time_waited = rospy.get_time() - self.last_time_received
17 | if self.queue_size == 1:
18 | self.data = data
19 | else:
20 | if self.data is None or time_waited > self.timeout: # reset queue if timeout
21 | self.data = []
22 | if len(self.data) == self.queue_size:
23 | self.data.pop(0)
24 | self.data.append(data)
25 | self.last_time_received = rospy.get_time()
26 |
27 | def is_valid(self, verbose: bool = False):
28 | time_waited = rospy.get_time() - self.last_time_received
29 | valid = time_waited < self.timeout
30 | if self.queue_size > 1:
31 | valid = valid and len(self.data) == self.queue_size
32 | if verbose and not valid:
33 | print(f"Not receiving {self.name} data for {time_waited} seconds (timeout: {self.timout} seconds)")
34 | return valid
--------------------------------------------------------------------------------
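A sketch of how ROSData is typically wired into a subscriber as a freshness-checked buffer (the topic name and loop rate here are illustrative, not taken from the repo):

import rospy
from std_msgs.msg import Float32MultiArray
from ros_data import ROSData

rospy.init_node("example_consumer")
waypoint = ROSData(timeout=1, name="waypoint")

def callback(msg: Float32MultiArray):
    waypoint.set(msg.data)                   # records the data and the receive time

rospy.Subscriber("/waypoint", Float32MultiArray, callback, queue_size=1)
rate = rospy.Rate(10)
while not rospy.is_shutdown():
    if waypoint.is_valid(verbose=True):      # stale data is rejected after `timeout` seconds
        print(waypoint.get())
    rate.sleep()
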
/deployment/src/topic_names.py:
--------------------------------------------------------------------------------
1 | # topic names for ROS communication
2 |
3 | # image obs topics
4 | FRONT_IMAGE_TOPIC = "/usb_cam_front/image_raw"
5 | REVERSE_IMAGE_TOPIC = "/usb_cam_reverse/image_raw"
6 | IMAGE_TOPIC = "/rgb/image_raw"
7 | POS_TOPIC = "/model_position"
8 |
9 |
10 | # exploration topics
11 | SUBGOALS_TOPIC = "/subgoals"
12 | GRAPH_NAME_TOPIC = "/graph_name"
13 | WAYPOINT_TOPIC = "/waypoint"
14 | REVERSE_MODE_TOPIC = "/reverse_mode"
15 | SAMPLED_OUTPUTS_TOPIC = "/sampled_outputs"
16 | REACHED_GOAL_TOPIC = "/topoplan/reached_goal"
17 | SAMPLED_WAYPOINTS_GRAPH_TOPIC = "/sampled_waypoints_graph"
18 | BACKTRACKING_IMAGE_TOPIC = "/backtracking_image"
19 | FRONTIER_IMAGE_TOPIC = "/frontier_image"
20 | SUBGOALS_SHAPE_TOPIC = "/subgoal_shape"
21 | SAMPLED_ACTIONS_TOPIC = "/sampled_actions"
22 | ANNOTATED_IMAGE_TOPIC = "/annotated_image"
23 | CURRENT_NODE_IMAGE_TOPIC = "/current_node_image"
24 | FLIP_DIRECTION_TOPIC = "/flip_direction"
25 | TURNING_TOPIC = "/turning"
26 | SUBGOAL_GEN_RATE_TOPIC = "/subgoal_gen_rate"
27 | MARKER_TOPIC = "/visualization_marker_array"
28 | VIZ_NAV_IMAGE_TOPIC = "/nav_image"
29 |
30 | # visualization topics
31 | CHOSEN_SUBGOAL_TOPIC = "/chosen_subgoal"
32 | VISUAL_MARKER_TOPIC = "/path"
33 | SUB_GOAL_TOPIC = "/goal"
34 |
35 | # recorded on the robot
36 | ODOM_TOPIC = "/odom"
37 | BUMPER_TOPIC = "/mobile_base/events/bumper"
38 | JOY_BUMPER_TOPIC = "/joy_bumper"
39 |
40 | # move the robot
--------------------------------------------------------------------------------
/deployment/src/navigate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Create a new tmux session
4 | session_name="navidiffusor_$(date +%s)"
5 | tmux new-session -d -s $session_name
6 |
7 | # Split the window into four panes
8 | tmux selectp -t 0 # select the first (0) pane
9 | tmux splitw -h -p 50 # split it into two halves
10 | tmux selectp -t 0 # select the first (0) pane
11 | tmux splitw -v -p 50 # split it into two halves
12 |
13 | tmux selectp -t 2 # select the new, second (2) pane
14 | tmux splitw -v -p 50 # split it into two halves
15 | tmux selectp -t 0 # go back to the first pane
16 |
17 | # Run the roslaunch command in the first pane
18 | tmux select-pane -t 0
19 | tmux send-keys "roslaunch vint_locobot.launch" Enter
20 |
21 | # Run the navigate.py script with command line args in the second pane
22 | tmux select-pane -t 1
23 | # tmux send-keys "conda activate vint_deployment" Enter
24 | tmux send-keys "conda activate navidiffusor" Enter
25 | tmux send-keys "python navigate.py $@" Enter
26 |
27 | # Run the teleop.py script in the third pane
28 | tmux select-pane -t 2
29 | # tmux send-keys "conda activate vint_deployment" Enter
30 | tmux send-keys "conda activate navidiffusor" Enter
31 | tmux send-keys "python joy_teleop.py" Enter
32 |
33 | # Run the pd_controller.py script in the fourth pane
34 | tmux select-pane -t 3
35 | tmux send-keys "conda activate navidiffusor" Enter
36 | tmux send-keys "python pd_controller.py" Enter
37 |
38 | # Attach to the tmux session
39 | tmux -2 attach-session -t $session_name
40 |
--------------------------------------------------------------------------------
/train/vint_train/data/data_config.yaml:
--------------------------------------------------------------------------------
1 |
2 | # global params for diffusion model
3 | # normalized min and max
4 | action_stats:
5 | min: [-2.5, -4] # [min_dx, min_dy]
6 | max: [5, 4] # [max_dx, max_dy]
7 |
8 | # data specific params
9 | recon:
10 | metric_waypoint_spacing: 0.25 # average spacing between waypoints (meters)
11 |
12 | # OPTIONAL (FOR VISUALIZATION ONLY)
13 | camera_metrics: # https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html
14 | camera_height: 0.95 # meters
15 | camera_x_offset: 0.45 # distance between the center of the robot and the forward facing camera
16 | camera_matrix:
17 | fx: 272.547000
18 | fy: 266.358000
19 | cx: 320.000000
20 | cy: 220.000000
21 | dist_coeffs:
22 | k1: -0.038483
23 | k2: -0.010456
24 | p1: 0.003930
25 | p2: -0.001007
26 | k3: 0.0
27 |
28 | scand:
29 | metric_waypoint_spacing: 0.38
30 |
31 | tartan_drive:
32 | metric_waypoint_spacing: 0.72
33 |
34 | go_stanford:
35 | metric_waypoint_spacing: 0.12
36 |
37 | # private datasets:
38 | cory_hall:
39 | metric_waypoint_spacing: 0.06
40 |
41 | seattle:
42 | metric_waypoint_spacing: 0.35
43 |
44 | racer:
45 | metric_waypoint_spacing: 0.38
46 |
47 | carla_intvns:
48 | metric_waypoint_spacing: 1.39
49 |
50 | carla_cil:
51 | metric_waypoint_spacing: 1.27
52 |
53 | carla_intvns:
54 | metric_waypoint_spacing: 1.39
55 |
56 | carla:
57 | metric_waypoint_spacing: 1.59
58 | image_path_func: get_image_path
59 |
60 | sacson:
61 | metric_waypoint_spacing: 0.255
62 |
63 | # add your own dataset params here:
64 |
--------------------------------------------------------------------------------
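One plausible way the global action_stats bounds are consumed is min-max scaling of waypoint deltas to [-1, 1] for the diffusion head; the repo's actual normalization code is not included in this dump, so the sketch below is an illustrative assumption (paths are relative to the train/ directory):

import numpy as np
import yaml

with open("vint_train/data/data_config.yaml", "r") as f:
    data_config = yaml.safe_load(f)

ACTION_MIN = np.array(data_config["action_stats"]["min"])  # [min_dx, min_dy]
ACTION_MAX = np.array(data_config["action_stats"]["max"])  # [max_dx, max_dy]

def normalize_deltas(deltas: np.ndarray) -> np.ndarray:
    """Map raw (dx, dy) deltas into [-1, 1] using the dataset-wide bounds."""
    return 2.0 * (deltas - ACTION_MIN) / (ACTION_MAX - ACTION_MIN) - 1.0

def unnormalize_deltas(ndeltas: np.ndarray) -> np.ndarray:
    """Invert normalize_deltas."""
    return (ndeltas + 1.0) / 2.0 * (ACTION_MAX - ACTION_MIN) + ACTION_MIN
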
/train/config/defaults.yaml:
--------------------------------------------------------------------------------
1 | # defaults for training
2 | project_name: vint
3 | run_name: vint
4 |
5 | # training setup
6 | use_wandb: True # set to false if you don't want to log to wandb
7 | train: True
8 | batch_size: 400
9 | eval_batch_size: 400
10 | epochs: 30
11 | gpu_ids: [0]
12 | num_workers: 4
13 | lr: 5e-4
14 | optimizer: adam
15 | seed: 0
16 | clipping: False
17 | train_subset: 1.
18 |
19 | # model params
20 | model_type: gnm
21 | obs_encoding_size: 1024
22 | goal_encoding_size: 1024
23 |
24 | # normalization for the action space
25 | normalize: True
26 |
27 | # context
28 | context_type: temporal
29 | context_size: 5
30 |
31 | # tradeoff between action and distance prediction loss
32 | alpha: 0.5
33 |
34 | # tradeoff between task loss and kld
35 | beta: 0.1
36 |
37 | obs_type: image
38 | goal_type: image
39 | scheduler: null
40 |
41 | # distance bounds for distance and action predictions
42 | distance:
43 | min_dist_cat: 0
44 | max_dist_cat: 20
45 | action:
46 | min_dist_cat: 2
47 | max_dist_cat: 10
48 | close_far_threshold: 10 # distance threshold used to separate the close and the far subgoals that are sampled per datapoint
49 |
50 | # action output params
51 | len_traj_pred: 5
52 | learn_angle: True
53 |
54 | # dataset specific parameters
55 | image_size: [85, 64] # width, height
56 |
57 | # logging stuff
58 | ## =0 turns off
59 | print_log_freq: 100 # in iterations
60 | image_log_freq: 1000 # in iterations
61 | num_images_log: 8 # number of images to log in a logging iteration
62 | pairwise_test_freq: 10 # in epochs
63 | eval_fraction: 0.25 # fraction of the dataset to use for evaluation
64 | wandb_log_freq: 10 # in iterations
65 | eval_freq: 1 # in epochs
66 |
67 |
--------------------------------------------------------------------------------
/train/vint_train/models/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from typing import List, Dict, Optional, Tuple
5 |
6 |
7 | class BaseModel(nn.Module):
8 | def __init__(
9 | self,
10 | context_size: int = 5,
11 | len_traj_pred: Optional[int] = 5,
12 | learn_angle: Optional[bool] = True,
13 | ) -> None:
14 | """
15 | Base Model main class
16 | Args:
17 | context_size (int): how many previous observations to use as context
18 | len_traj_pred (int): how many waypoints to predict in the future
19 | learn_angle (bool): whether to predict the yaw of the robot
20 | """
21 | super(BaseModel, self).__init__()
22 | self.context_size = context_size
23 | self.learn_angle = learn_angle
24 | self.len_trajectory_pred = len_traj_pred
25 | if self.learn_angle:
26 | self.num_action_params = 4 # last two dims are the cos and sin of the angle
27 | else:
28 | self.num_action_params = 2
29 |
30 | def flatten(self, z: torch.Tensor) -> torch.Tensor:
31 | z = nn.functional.adaptive_avg_pool2d(z, (1, 1))
32 | z = torch.flatten(z, 1)
33 | return z
34 |
35 | def forward(
36 | self, obs_img: torch.Tensor, goal_img: torch.Tensor
37 | ) -> Tuple[torch.Tensor, torch.Tensor]:
38 | """
39 | Forward pass of the model
40 | Args:
41 | obs_img (torch.Tensor): batch of observations
42 | goal_img (torch.Tensor): batch of goals
43 | Returns:
44 | dist_pred (torch.Tensor): predicted distance to goal
45 | action_pred (torch.Tensor): predicted action
46 | """
47 | raise NotImplementedError
48 |
--------------------------------------------------------------------------------
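A toy subclass showing the contract BaseModel expects from forward() (this is not a model from the repo, just an illustration of the interface; the channel count assumes context frames are stacked along the channel dimension):

import torch
import torch.nn as nn
from vint_train.models.base_model import BaseModel

class TinyModel(BaseModel):
    def __init__(self, context_size: int = 5, len_traj_pred: int = 5, learn_angle: bool = True):
        super().__init__(context_size, len_traj_pred, learn_angle)
        in_ch = 3 * (context_size + 1) + 3          # stacked context frames + goal image
        self.encoder = nn.Conv2d(in_ch, 32, 3, padding=1)
        self.dist_head = nn.Linear(32, 1)
        self.action_head = nn.Linear(32, self.len_trajectory_pred * self.num_action_params)

    def forward(self, obs_img: torch.Tensor, goal_img: torch.Tensor):
        # encode, pool to a flat embedding, then predict distance and actions
        z = self.flatten(self.encoder(torch.cat([obs_img, goal_img], dim=1)))
        dist_pred = self.dist_head(z)
        action_pred = self.action_head(z).view(-1, self.len_trajectory_pred, self.num_action_params)
        return dist_pred, action_pred
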
/train/vint_train/models/nomad/nomad.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import time
4 | import pdb
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class NoMaD(nn.Module):
11 |
12 | def __init__(self, vision_encoder,
13 | noise_pred_net,
14 | dist_pred_net):
15 | super(NoMaD, self).__init__()
16 |
17 |
18 | self.vision_encoder = vision_encoder
19 | self.noise_pred_net = noise_pred_net
20 | self.dist_pred_net = dist_pred_net
21 |
22 | def forward(self, func_name, **kwargs):
23 | if func_name == "vision_encoder" :
24 | output = self.vision_encoder(kwargs["obs_img"], kwargs["goal_img"], input_goal_mask=kwargs["input_goal_mask"])
25 | elif func_name == "noise_pred_net":
26 | output = self.noise_pred_net(sample=kwargs["sample"], timestep=kwargs["timestep"], global_cond=kwargs["global_cond"])
27 | elif func_name == "dist_pred_net":
28 | output = self.dist_pred_net(kwargs["obsgoal_cond"])
29 | else:
30 | raise NotImplementedError
31 | return output
32 |
33 |
34 | class DenseNetwork(nn.Module):
35 | def __init__(self, embedding_dim):
36 | super(DenseNetwork, self).__init__()
37 |
38 | self.embedding_dim = embedding_dim
39 | self.network = nn.Sequential(
40 | nn.Linear(self.embedding_dim, self.embedding_dim//4),
41 | nn.ReLU(),
42 | nn.Linear(self.embedding_dim//4, self.embedding_dim//16),
43 | nn.ReLU(),
44 | nn.Linear(self.embedding_dim//16, 1)
45 | )
46 |
47 | def forward(self, x):
48 | x = x.reshape((-1, self.embedding_dim))
49 | output = self.network(x)
50 | return output
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
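A hedged sketch of the forward() dispatch above, using stand-in modules so the call pattern is visible without the real NoMaD_ViNT encoder or diffusion UNet (shapes and dims are illustrative):

import torch
import torch.nn as nn
from vint_train.models.nomad.nomad import NoMaD, DenseNetwork

class FakeEncoder(nn.Module):
    def forward(self, obs_img, goal_img, input_goal_mask=None):
        return torch.zeros(obs_img.shape[0], 256)            # (B, embedding_dim)

class FakeNoisePred(nn.Module):
    def forward(self, sample, timestep, global_cond):
        return torch.zeros_like(sample)                      # predicted noise

model = NoMaD(vision_encoder=FakeEncoder(),
              noise_pred_net=FakeNoisePred(),
              dist_pred_net=DenseNetwork(embedding_dim=256))

obs = torch.zeros(2, 12, 96, 96)                             # context-stacked RGB frames
goal = torch.zeros(2, 3, 96, 96)
cond = model("vision_encoder", obs_img=obs, goal_img=goal,
             input_goal_mask=torch.ones(2, dtype=torch.long))
dist = model("dist_pred_net", obsgoal_cond=cond)             # (B, 1) distance prediction
noise = model("noise_pred_net", sample=torch.zeros(2, 8, 2),
              timestep=torch.tensor([5, 5]), global_cond=cond)
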
/train/vint_train/training/logger.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Logger:
5 | def __init__(
6 | self,
7 | name: str,
8 | dataset: str,
9 | window_size: int = 10,
10 | rounding: int = 4,
11 | ):
12 | """
13 | Args:
14 | name (str): Name of the metric
15 | dataset (str): Name of the dataset
16 | window_size (int, optional): Size of the moving average window. Defaults to 10.
17 | rounding (int, optional): Number of decimals to round to. Defaults to 4.
18 | """
19 | self.data = []
20 | self.name = name
21 | self.dataset = dataset
22 | self.rounding = rounding
23 | self.window_size = window_size
24 |
25 | def display(self) -> str:
26 | latest = round(self.latest(), self.rounding)
27 | average = round(self.average(), self.rounding)
28 | moving_average = round(self.moving_average(), self.rounding)
29 | output = f"{self.full_name()}: {latest} ({self.window_size}pt moving_avg: {moving_average}) (avg: {average})"
30 | return output
31 |
32 | def log_data(self, data: float):
33 | if not np.isnan(data):
34 | self.data.append(data)
35 |
36 | def full_name(self) -> str:
37 | return f"{self.name} ({self.dataset})"
38 |
39 | def latest(self) -> float:
40 | if len(self.data) > 0:
41 | return self.data[-1]
42 | return np.nan
43 |
44 | def average(self) -> float:
45 | if len(self.data) > 0:
46 | return np.mean(self.data)
47 | return np.nan
48 |
49 | def moving_average(self) -> float:
50 | if len(self.data) > self.window_size:
51 | return np.mean(self.data[-self.window_size :])
52 | return self.average()
--------------------------------------------------------------------------------
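Example usage of the Logger (values are illustrative):

from vint_train.training.logger import Logger

loss_logger = Logger(name="action_loss", dataset="recon", window_size=10)
for step, loss in enumerate([0.92, 0.85, 0.81]):
    loss_logger.log_data(loss)
    if step % 100 == 0:                  # mirrors print_log_freq-style gating
        print(loss_logger.display())     # e.g. "action_loss (recon): 0.92 (10pt moving_avg: 0.92) (avg: 0.92)"
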
/deployment/src/joy_teleop.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 | # ROS
4 | import rospy
5 | from geometry_msgs.msg import Twist
6 | from sensor_msgs.msg import Joy
7 | from std_msgs.msg import Bool
8 |
9 | from topic_names import JOY_BUMPER_TOPIC
10 |
11 | vel_msg = Twist()
12 | CONFIG_PATH = "../config/robot.yaml"
13 | with open(CONFIG_PATH, "r") as f:
14 | robot_config = yaml.safe_load(f)
15 | MAX_V = 0.4
16 | MAX_W = 0.8
17 | VEL_TOPIC = robot_config["vel_teleop_topic"]
18 | JOY_CONFIG_PATH = "../config/joystick.yaml"
19 | with open(JOY_CONFIG_PATH, "r") as f:
20 | joy_config = yaml.safe_load(f)
21 | DEADMAN_SWITCH = joy_config["deadman_switch"] # button index
22 | LIN_VEL_BUTTON = joy_config["lin_vel_button"]
23 | ANG_VEL_BUTTON = joy_config["ang_vel_button"]
24 | RATE = 9
25 | vel_pub = rospy.Publisher(VEL_TOPIC, Twist, queue_size=1)
26 | bumper_pub = rospy.Publisher(JOY_BUMPER_TOPIC, Bool, queue_size=1)
27 | button = None
28 | bumper = False
29 |
30 |
31 | def callback_joy(data: Joy):
32 | """Callback function for the joystick subscriber"""
33 | global vel_msg, button, bumper
34 | button = data.buttons[DEADMAN_SWITCH]
35 | bumper_button = data.buttons[DEADMAN_SWITCH - 1]
36 | if button is not None: # hold down the dead-man switch to teleop the robot
37 | vel_msg.linear.x = MAX_V * data.axes[LIN_VEL_BUTTON]
38 | vel_msg.angular.z = MAX_W * data.axes[ANG_VEL_BUTTON]
39 | else:
40 | vel_msg = Twist()
41 | vel_pub.publish(vel_msg)
42 | if bumper_button is not None:
43 | bumper = bool(data.buttons[DEADMAN_SWITCH - 1])
44 | else:
45 | bumper = False
46 |
47 |
48 |
49 | def main():
50 | rospy.init_node("Joy2Locobot", anonymous=False)
51 | joy_sub = rospy.Subscriber("joy", Joy, callback_joy)
52 | rate = rospy.Rate(RATE)
53 | print("Registered with master node. Waiting for joystick input...")
54 | while not rospy.is_shutdown():
55 | if button:
56 | print(f"Teleoperating the robot:\n {vel_msg}")
57 | vel_pub.publish(vel_msg)
58 | rate.sleep()
59 | bumper_msg = Bool()
60 | bumper_msg.data = bumper
61 | bumper_pub.publish(bumper_msg)
62 | if bumper:
63 | print("Bumper pressed!")
64 |
65 |
66 | if __name__ == "__main__":
67 | main()
68 |
69 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | train/logs/*
2 | train/datasets/*
3 | train/vint_train/data/data_splits/*
4 | train/wandb/*
5 | train/gnm_dataset/*
6 | *_test.yaml
7 |
8 |
9 | *.jpg
10 | *.pth
11 | *.mp4
12 | *.gif
13 |
14 | deployment/model_weights/*
15 | deployment/topomaps/*
16 |
17 | .vscode/*
18 | */.vscode/*
19 |
20 |
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 |
26 | # C extensions
27 | *.so
28 |
29 | # Distribution / packaging
30 | .Python
31 | build/
32 | develop-eggs/
33 | dist/
34 | downloads/
35 | eggs/
36 | .eggs/
37 | lib/
38 | lib64/
39 | parts/
40 | sdist/
41 | var/
42 | wheels/
43 | pip-wheel-metadata/
44 | share/python-wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # Installer logs
57 | pip-log.txt
58 | pip-delete-this-directory.txt
59 |
60 | # Unit test / coverage reports
61 | htmlcov/
62 | .tox/
63 | .nox/
64 | .coverage
65 | .coverage.*
66 | .cache
67 | nosetests.xml
68 | coverage.xml
69 | *.cover
70 | *.py,cover
71 | .hypothesis/
72 | .pytest_cache/
73 |
74 | # Translations
75 | *.mo
76 | *.pot
77 |
78 | # Django stuff:
79 | *.log
80 | local_settings.py
81 | db.sqlite3
82 | db.sqlite3-journal
83 |
84 | # Flask stuff:
85 | instance/
86 | .webassets-cache
87 |
88 | # Scrapy stuff:
89 | .scrapy
90 |
91 | # Sphinx documentation
92 | docs/_build/
93 |
94 | # PyBuilder
95 | target/
96 |
97 | # Jupyter Notebook
98 | .ipynb_checkpoints
99 |
100 | # IPython
101 | profile_default/
102 | ipython_config.py
103 |
104 | # pyenv
105 | .python-version
106 |
107 | # pipenv
108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
111 | # install all needed dependencies.
112 | #Pipfile.lock
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
--------------------------------------------------------------------------------
/train/process_recon.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import os
3 | import pickle
4 | from PIL import Image
5 | import io
6 | import argparse
7 | import tqdm
8 |
9 |
10 | def main(args: argparse.Namespace):
11 | recon_dir = os.path.join(args.input_dir, "recon_release")
12 | output_dir = args.output_dir
13 |
14 | # create output dir if it doesn't exist
15 | if not os.path.exists(output_dir):
16 | os.makedirs(output_dir)
17 |
18 | # get all the folders in the recon dataset
19 | filenames = os.listdir(recon_dir)
20 | if args.num_trajs >= 0:
21 | filenames = filenames[: args.num_trajs]
22 |
23 | # processing loop
24 | for filename in tqdm.tqdm(filenames, desc="Trajectories processed"):
25 | # extract the name without the extension
26 | traj_name = filename.split(".")[0]
27 | # load the hdf5 file
28 | try:
29 | h5_f = h5py.File(os.path.join(recon_dir, filename), "r")
30 | except OSError:
31 | print(f"Error loading {filename}. Skipping...")
32 | continue
33 | # extract the position and yaw data
34 | position_data = h5_f["jackal"]["position"][:, :2]
35 | yaw_data = h5_f["jackal"]["yaw"][()]
36 | # save the data to a dictionary
37 | traj_data = {"position": position_data, "yaw": yaw_data}
38 | traj_folder = os.path.join(output_dir, traj_name)
39 | os.makedirs(traj_folder, exist_ok=True)
40 | with open(os.path.join(traj_folder, "traj_data.pkl"), "wb") as f:
41 | pickle.dump(traj_data, f)
42 | # make a folder for the file
43 | if not os.path.exists(traj_folder):
44 | os.makedirs(traj_folder)
45 | # save the image data to disk
46 | for i in range(h5_f["images"]["rgb_left"].shape[0]):
47 | img = Image.open(io.BytesIO(h5_f["images"]["rgb_left"][i]))
48 | img.save(os.path.join(traj_folder, f"{i}.jpg"))
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | # get arguments for the recon input dir and the output dir
54 | parser.add_argument(
55 | "--input-dir",
56 | "-i",
57 | type=str,
58 | help="path of the recon_dataset",
59 | required=True,
60 | )
61 | parser.add_argument(
62 | "--output-dir",
63 | "-o",
64 | default="datasets/recon/",
65 | type=str,
66 | help="path for processed recon dataset (default: datasets/recon/)",
67 | )
68 | # number of trajs to process
69 | parser.add_argument(
70 | "--num-trajs",
71 | "-n",
72 | default=-1,
73 | type=int,
74 | help="number of trajectories to process (default: -1, all)",
75 | )
76 |
77 | args = parser.parse_args()
78 | print("STARTING PROCESSING RECON DATASET")
79 | main(args)
80 | print("FINISHED PROCESSING RECON DATASET")
81 |
--------------------------------------------------------------------------------
/train/data_split.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import random
5 |
6 |
7 | def remove_files_in_dir(dir_path: str):
8 | for f in os.listdir(dir_path):
9 | file_path = os.path.join(dir_path, f)
10 | try:
11 | if os.path.isfile(file_path) or os.path.islink(file_path):
12 | os.unlink(file_path)
13 | elif os.path.isdir(file_path):
14 | shutil.rmtree(file_path)
15 | except Exception as e:
16 | print("Failed to delete %s. Reason: %s" % (file_path, e))
17 |
18 |
19 | def main(args: argparse.Namespace):
20 | # Get the names of the folders in the data directory that contain the file 'traj_data.pkl'
21 | folder_names = [
22 | f
23 | for f in os.listdir(args.data_dir)
24 | if os.path.isdir(os.path.join(args.data_dir, f))
25 | and "traj_data.pkl" in os.listdir(os.path.join(args.data_dir, f))
26 | ]
27 |
28 | # Randomly shuffle the names of the folders
29 | random.shuffle(folder_names)
30 |
31 | # Split the names of the folders into train and test sets
32 | split_index = int(args.split * len(folder_names))
33 | train_folder_names = folder_names[:split_index]
34 | test_folder_names = folder_names[split_index:]
35 |
36 | # Create directories for the train and test sets
37 | train_dir = os.path.join(args.data_splits_dir, args.dataset_name, "train")
38 | test_dir = os.path.join(args.data_splits_dir, args.dataset_name, "test")
39 | for dir_path in [train_dir, test_dir]:
40 | if os.path.exists(dir_path):
41 | print(f"Clearing files from {dir_path} for new data split")
42 | remove_files_in_dir(dir_path)
43 | else:
44 | print(f"Creating {dir_path}")
45 | os.makedirs(dir_path)
46 |
47 | # Write the names of the train and test folders to files
48 | with open(os.path.join(train_dir, "traj_names.txt"), "w") as f:
49 | for folder_name in train_folder_names:
50 | f.write(folder_name + "\n")
51 |
52 | with open(os.path.join(test_dir, "traj_names.txt"), "w") as f:
53 | for folder_name in test_folder_names:
54 | f.write(folder_name + "\n")
55 |
56 |
57 | if __name__ == "__main__":
58 | # Set up the command line argument parser
59 | parser = argparse.ArgumentParser()
60 |
61 | parser.add_argument(
62 | "--data-dir", "-i", help="Directory containing the data", required=True
63 | )
64 | parser.add_argument(
65 | "--dataset-name", "-d", help="Name of the dataset", required=True
66 | )
67 | parser.add_argument(
68 | "--split", "-s", type=float, default=0.8, help="Train/test split (default: 0.8)"
69 | )
70 | parser.add_argument(
71 | "--data-splits-dir", "-o", default="vint_train/data/data_splits", help="Data splits directory"
72 | )
73 | args = parser.parse_args()
74 | main(args)
75 | print("Done")
76 |
--------------------------------------------------------------------------------
/deployment/src/create_topomap.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, shutil  # shutil is needed by remove_files_in_dir below
3 | from utils import msg_to_pil
4 | import time
5 |
6 | # ROS
7 | import rospy
8 | from sensor_msgs.msg import Image
9 | from sensor_msgs.msg import Joy
10 |
11 | IMAGE_TOPIC = "/rgb/image_raw"
12 | TOPOMAP_IMAGES_DIR = "../topomaps/images"
13 | obs_img = None
14 |
15 |
16 | def remove_files_in_dir(dir_path: str):
17 | for f in os.listdir(dir_path):
18 | file_path = os.path.join(dir_path, f)
19 | try:
20 | if os.path.isfile(file_path) or os.path.islink(file_path):
21 | os.unlink(file_path)
22 | elif os.path.isdir(file_path):
23 | shutil.rmtree(file_path)
24 | except Exception as e:
25 | print("Failed to delete %s. Reason: %s" % (file_path, e))
26 |
27 |
28 | def callback_obs(msg: Image):
29 | global obs_img
30 | obs_img = msg_to_pil(msg)
31 |
32 |
33 | def callback_joy(msg: Joy):
34 | if msg.buttons[0]:
35 | rospy.signal_shutdown("shutdown")
36 |
37 |
38 | def main(args: argparse.Namespace):
39 | global obs_img
40 | rospy.init_node("CREATE_TOPOMAP", anonymous=False)
41 | image_curr_msg = rospy.Subscriber(
42 | IMAGE_TOPIC, Image, callback_obs, queue_size=1)
43 | subgoals_pub = rospy.Publisher(
44 | "/subgoals", Image, queue_size=1)
45 | joy_sub = rospy.Subscriber("joy", Joy, callback_joy)
46 |
47 | topomap_name_dir = os.path.join(TOPOMAP_IMAGES_DIR, args.dir)
48 | if not os.path.isdir(topomap_name_dir):
49 | os.makedirs(topomap_name_dir)
50 | else:
51 | print(f"{topomap_name_dir} already exists. Removing previous images...")
52 | remove_files_in_dir(topomap_name_dir)
53 |
54 |
55 | assert args.dt > 0, "dt must be positive"
56 | rate = rospy.Rate(1/args.dt)
57 | print("Registered with master node. Waiting for images...")
58 | i = 0
59 | start_time = float("inf")
60 | while not rospy.is_shutdown():
61 | if obs_img is not None:
62 | obs_img.save(os.path.join(topomap_name_dir, f"{i}.png"))
63 | print("published image", i)
64 | i += 1
65 | rate.sleep()
66 | start_time = time.time()
67 | obs_img = None
68 | if time.time() - start_time > 2 * args.dt:
69 | print(f"Topic {IMAGE_TOPIC} not publishing anymore. Shutting down...")
70 | rospy.signal_shutdown("shutdown")
71 |
72 |
73 | if __name__ == "__main__":
74 | parser = argparse.ArgumentParser(
75 | description=f"Code to generate topomaps from the {IMAGE_TOPIC} topic"
76 | )
77 | parser.add_argument(
78 | "--dir",
79 | "-d",
80 | default="topomap",
81 | type=str,
82 | help="path to topological map images in ../topomaps/images directory (default: topomap)",
83 | )
84 | parser.add_argument(
85 | "--dt",
86 | "-t",
87 | default=0.1,
88 | type=float,
89 | help=f"time between images sampled from the {IMAGE_TOPIC} topic (default: 3.0)",
90 | )
91 | args = parser.parse_args()
92 |
93 | main(args)
94 |
--------------------------------------------------------------------------------
/deployment/src/pd_controller.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import yaml
3 | from typing import Tuple
4 |
5 | # ROS
6 | import rospy
7 | from geometry_msgs.msg import Twist
8 | from std_msgs.msg import Float32MultiArray, Bool
9 |
10 | from topic_names import (WAYPOINT_TOPIC,
11 | REACHED_GOAL_TOPIC)
12 | from ros_data import ROSData
13 | from utils import clip_angle
14 |
15 | # CONSTS
16 | CONFIG_PATH = "../config/robot.yaml"
17 | with open(CONFIG_PATH, "r") as f:
18 | robot_config = yaml.safe_load(f)
19 | MAX_V = robot_config["max_v"]
20 | MAX_W = robot_config["max_w"]
21 | VEL_TOPIC = robot_config["vel_navi_topic"]
22 | DT = 1/robot_config["frame_rate"]
23 | RATE = 50
24 | EPS = 1e-8
25 | WAYPOINT_TIMEOUT = 4 # seconds # TODO: tune this
26 | FLIP_ANG_VEL = np.pi/4
27 |
28 | # GLOBALS
29 | vel_msg = Twist()
30 | waypoint = ROSData(WAYPOINT_TIMEOUT, name="waypoint")
31 | reached_goal = False
32 | reverse_mode = False
33 | current_yaw = None
34 |
35 | def clip_angle(theta) -> float:
36 | """Clip angle to [-pi, pi]"""
37 | theta %= 2 * np.pi
38 | if -np.pi < theta < np.pi:
39 | return theta
40 | return theta - 2 * np.pi
41 |
42 |
43 | def pd_controller(waypoint: np.ndarray) -> Tuple[float, float]:
44 | """PD controller for the robot"""
45 | assert len(waypoint) == 2 or len(waypoint) == 4, "waypoint must be a 2D or 4D vector"
46 | if len(waypoint) == 2:
47 | dx, dy = waypoint
48 | else:
49 | dx, dy, hx, hy = waypoint
50 | # this controller only uses the predicted heading if dx and dy near zero
51 | if len(waypoint) == 4 and np.abs(dx) < EPS and np.abs(dy) < EPS:
52 | v = 0
53 | w = clip_angle(np.arctan2(hy, hx))/DT
54 | elif np.abs(dx) < EPS:
55 | v = 0
56 | w = np.sign(dy) * np.pi/(2*DT)
57 | else:
58 | v = dx / DT
59 | w = np.arctan(dy/dx) / DT
60 | v = np.clip(v, 0, MAX_V)
61 | w = np.clip(w, -MAX_W, MAX_W)
62 | return v, w
63 |
64 |
65 | def callback_drive(waypoint_msg: Float32MultiArray):
66 | """Callback function for the waypoint subscriber"""
67 | global vel_msg
68 | print("seting waypoint")
69 | waypoint.set(waypoint_msg.data)
70 |
71 |
72 | def callback_reached_goal(reached_goal_msg: Bool):
73 | """Callback function for the reached goal subscriber"""
74 | global reached_goal
75 | reached_goal = reached_goal_msg.data
76 |
77 |
78 | def main():
79 | global vel_msg, reverse_mode
80 | rospy.init_node("PD_CONTROLLER", anonymous=False)
81 | waypoint_sub = rospy.Subscriber(WAYPOINT_TOPIC, Float32MultiArray, callback_drive, queue_size=1)
82 | reached_goal_sub = rospy.Subscriber(REACHED_GOAL_TOPIC, Bool, callback_reached_goal, queue_size=1)
83 | vel_out = rospy.Publisher(VEL_TOPIC, Twist, queue_size=1)
84 | rate = rospy.Rate(RATE)
85 | print("Registered with master node. Waiting for waypoints...")
86 | while not rospy.is_shutdown():
87 | vel_msg = Twist()
88 | if reached_goal:
89 | vel_out.publish(vel_msg)
90 | print("Reached goal! Stopping...")
91 | return
92 | elif waypoint.is_valid(verbose=True):
93 | v, w = pd_controller(waypoint.get())
94 | # if reverse_mode:
95 | # v *= -1
96 | vel_msg.linear.x = v
97 | vel_msg.angular.z = w
98 | print(f"publishing new vel: {v}, {w}")
99 | vel_out.publish(vel_msg)
100 | rate.sleep()
101 |
102 |
103 | if __name__ == '__main__':
104 | main()
105 |
--------------------------------------------------------------------------------
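A worked example of the controller above with illustrative numbers: at frame_rate = 4 Hz, DT = 0.25 s, so a waypoint 0.1 m ahead and 0.05 m to the left gives v exactly at the 0.4 m/s limit and an angular rate that is clipped to 0.8 rad/s:

import numpy as np

dx, dy = 0.1, 0.05
DT = 1 / 4                                        # from robot.yaml frame_rate
v = np.clip(dx / DT, 0, 0.4)                      # 0.1 / 0.25 = 0.4 m/s
w = np.clip(np.arctan(dy / dx) / DT, -0.8, 0.8)   # arctan(0.5)/0.25 ~ 1.85, clipped to 0.8 rad/s
print(v, w)
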
/train/config/nomad.yaml:
--------------------------------------------------------------------------------
1 | project_name: nomad
2 | run_name: nomad
3 |
4 | # training setup
5 | use_wandb: True # set to false if you don't want to log to wandb
6 | train: True
7 | batch_size: 256
8 | epochs: 100
9 | gpu_ids: [0]
10 | num_workers: 12
11 | lr: 1e-4
12 | optimizer: adamw
13 | clipping: False
14 | max_norm: 1.
15 | scheduler: "cosine"
16 | warmup: True
17 | warmup_epochs: 4
18 | cyclic_period: 10
19 | plateau_patience: 3
20 | plateau_factor: 0.5
21 | seed: 0
22 |
23 | # model params
24 | model_type: nomad
25 | vision_encoder: nomad_vint
26 | encoding_size: 256
27 | obs_encoder: efficientnet-b0
28 | attn_unet: False
29 | cond_predict_scale: False
30 | mha_num_attention_heads: 4
31 | mha_num_attention_layers: 4
32 | mha_ff_dim_factor: 4
33 | down_dims: [64, 128, 256]
34 |
35 | # diffusion model params
36 | num_diffusion_iters: 10
37 |
38 | # mask
39 | goal_mask_prob: 0.5
40 |
41 | # normalization for the action space
42 | normalize: True
43 |
44 | # context
45 | context_type: temporal
46 | context_size: 3 # 5
47 | alpha: 1e-4
48 |
49 | # distance bounds for distance and action predictions
50 | distance:
51 | min_dist_cat: 0
52 | max_dist_cat: 20
53 | action:
54 | min_dist_cat: 3
55 | max_dist_cat: 20
56 |
57 | # action output params
58 | len_traj_pred: 8
59 | learn_angle: False
60 |
61 | # dataset specific parameters
62 | image_size: [96, 96] # width, height
63 | datasets:
64 | recon:
65 | data_folder: /home//nomad_dataset/recon
66 | train: /home//data_splits/recon/train/ # path to train folder with traj_names.txt
67 | test: /home//data_splits/recon/test/ # path to test folder with traj_names.txt
68 | end_slack: 3 # because many trajectories end in collisions
69 | goals_per_obs: 1 # how many goals are sampled per observation
70 | negative_mining: True # negative mining from the ViNG paper (Shah et al.)
71 | go_stanford:
72 | data_folder: /home//nomad_dataset/go_stanford_cropped # datasets/stanford_go_new
73 | train: /home//data_splits/go_stanford/train/
74 | test: /home//data_splits/go_stanford/test/
75 | end_slack: 0
76 | goals_per_obs: 2 # increase dataset size
77 | negative_mining: True
78 | cory_hall:
79 | data_folder: /home//nomad_dataset/cory_hall/
80 | train: /home//data_splits/cory_hall/train/
81 | test: /home//data_splits/cory_hall/test/
82 | end_slack: 3 # because many trajectories end in collisions
83 | goals_per_obs: 1
84 | negative_mining: True
85 | tartan_drive:
86 | data_folder: /home//nomad_dataset/tartan_drive/
87 | train: /home//data_splits/tartan_drive/train/
88 | test: /home//data_splits/tartan_drive/test/
89 | end_slack: 3 # because many trajectories end in collisions
90 | goals_per_obs: 1
91 | negative_mining: True
92 | sacson:
93 | data_folder: /home//nomad_dataset/sacson/
94 | train: /home//data_splits/sacson/train/
95 | test: /home//data_splits/sacson/test/
96 | end_slack: 3 # because many trajectories end in collisions
97 | goals_per_obs: 1
98 | negative_mining: True
99 | # private datasets (uncomment if you have access)
100 | # seattle:
101 | # data_folder: /home//nomad_dataset/seattle/
102 | # train: /home//data_splits/seattle/train/
103 | # test: /home//data_splits/seattle/test/
104 | # end_slack: 0
105 | # goals_per_obs: 1
106 | # negative_mining: True
107 | # scand:
108 | # data_folder: /home//nomad_dataset/scand/
109 | # train: /home//data_splits/scand/train/
110 | # test: /home//data_splits/scand/test/
111 | # end_slack: 0
112 | # goals_per_obs: 1
113 | # negative_mining: True
114 |
115 | # logging stuff
116 | ## =0 turns off
117 | print_log_freq: 100 # in iterations
118 | image_log_freq: 1000 #0 # in iterations
119 | num_images_log: 8 #0
120 | pairwise_test_freq: 0 # in epochs
121 | eval_fraction: 0.25
122 | wandb_log_freq: 10 # in iterations
123 | eval_freq: 1 # in epochs
--------------------------------------------------------------------------------
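train.py is not included in this dump; a common pattern for paired files like these (and an assumption here) is to load defaults.yaml first and overlay the model config on top, for example:

import yaml

with open("config/defaults.yaml", "r") as f:
    config = yaml.safe_load(f)
with open("config/nomad.yaml", "r") as f:
    config.update(yaml.safe_load(f))     # model config overrides the defaults

print(config["model_type"], config["len_traj_pred"])  # -> nomad 8
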
/train/vint_train/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from PIL import Image
4 | from typing import Any, Iterable, Tuple
5 |
6 | import torch
7 | from torchvision import transforms
8 | import torchvision.transforms.functional as TF
9 | import torch.nn.functional as F
10 | import io
11 | from typing import Union
12 |
13 | VISUALIZATION_IMAGE_SIZE = (160, 120)
14 | IMAGE_ASPECT_RATIO = (
15 | 4 / 3
16 | ) # all images are centered cropped to a 4:3 aspect ratio in training
17 |
18 |
19 |
20 | def get_data_path(data_folder: str, f: str, time: int, data_type: str = "image"):
21 | data_ext = {
22 | "image": ".jpg",
23 | # add more data types here
24 | }
25 | return os.path.join(data_folder, f, f"{str(time)}{data_ext[data_type]}")
26 |
27 |
28 | def yaw_rotmat(yaw: float) -> np.ndarray:
29 | return np.array(
30 | [
31 | [np.cos(yaw), -np.sin(yaw), 0.0],
32 | [np.sin(yaw), np.cos(yaw), 0.0],
33 | [0.0, 0.0, 1.0],
34 | ],
35 | )
36 |
37 |
38 | def to_local_coords(
39 | positions: np.ndarray, curr_pos: np.ndarray, curr_yaw: float
40 | ) -> np.ndarray:
41 | """
42 | Convert positions to local coordinates
43 |
44 | Args:
45 | positions (np.ndarray): positions to convert
46 | curr_pos (np.ndarray): current position
47 | curr_yaw (float): current yaw
48 | Returns:
49 | np.ndarray: positions in local coordinates
50 | """
51 | rotmat = yaw_rotmat(curr_yaw)
52 | if positions.shape[-1] == 2:
53 | rotmat = rotmat[:2, :2]
54 | elif positions.shape[-1] == 3:
55 | pass
56 | else:
57 | raise ValueError
58 |
59 | return (positions - curr_pos).dot(rotmat)
60 |
61 |
62 | def calculate_deltas(waypoints: torch.Tensor) -> torch.Tensor:
63 | """
64 | Calculate deltas between waypoints
65 |
66 | Args:
67 | waypoints (torch.Tensor): waypoints
68 | Returns:
69 | torch.Tensor: deltas
70 | """
71 | num_params = waypoints.shape[1]
72 | origin = torch.zeros(1, num_params)
73 | prev_waypoints = torch.concat((origin, waypoints[:-1]), axis=0)
74 | deltas = waypoints - prev_waypoints
75 | if num_params > 2:
76 | return calculate_sin_cos(deltas)
77 | return deltas
78 |
79 |
80 | def calculate_sin_cos(waypoints: torch.Tensor) -> torch.Tensor:
81 | """
82 | Calculate sin and cos of the angle
83 |
84 | Args:
85 | waypoints (torch.Tensor): waypoints
86 | Returns:
87 | torch.Tensor: waypoints with sin and cos of the angle
88 | """
89 | assert waypoints.shape[1] == 3
90 | angle_repr = torch.zeros_like(waypoints[:, :2])
91 | angle_repr[:, 0] = torch.cos(waypoints[:, 2])
92 | angle_repr[:, 1] = torch.sin(waypoints[:, 2])
93 | return torch.concat((waypoints[:, :2], angle_repr), axis=1)
94 |
95 |
96 | def transform_images(
97 | img: Image.Image, transform: transforms, image_resize_size: Tuple[int, int], aspect_ratio: float = IMAGE_ASPECT_RATIO
98 | ):
99 | w, h = img.size
100 | if w > h:
101 | img = TF.center_crop(img, (h, int(h * aspect_ratio))) # crop to the right ratio
102 | else:
103 | img = TF.center_crop(img, (int(w / aspect_ratio), w))
104 | viz_img = img.resize(VISUALIZATION_IMAGE_SIZE)
105 | viz_img = TF.to_tensor(viz_img)
106 | img = img.resize(image_resize_size)
107 | transf_img = transform(img)
108 | return viz_img, transf_img
109 |
110 |
111 | def resize_and_aspect_crop(
112 | img: Image.Image, image_resize_size: Tuple[int, int], aspect_ratio: float = IMAGE_ASPECT_RATIO
113 | ):
114 | w, h = img.size
115 | if w > h:
116 | img = TF.center_crop(img, (h, int(h * aspect_ratio))) # crop to the right ratio
117 | else:
118 | img = TF.center_crop(img, (int(w / aspect_ratio), w))
119 | img = img.resize(image_resize_size)
120 | resize_img = TF.to_tensor(img)
121 | return resize_img
122 |
123 |
124 | def img_path_to_data(path: Union[str, io.BytesIO], image_resize_size: Tuple[int, int]) -> torch.Tensor:
125 | """
126 | Load an image from a path and transform it
127 | Args:
128 | path (str): path to the image
129 | image_resize_size (Tuple[int, int]): size to resize the image to
130 | Returns:
131 | torch.Tensor: resized image as tensor
132 | """
133 | # return transform_images(Image.open(path), transform, image_resize_size, aspect_ratio)
134 | return resize_and_aspect_crop(Image.open(path), image_resize_size)
135 |
136 |
--------------------------------------------------------------------------------
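Example usage of the helpers above: load one frame from a processed trajectory folder and express later positions in the robot's local frame (the dataset path and trajectory name are hypothetical):

import pickle
from vint_train.data.data_utils import img_path_to_data, get_data_path, to_local_coords

data_folder = "datasets/recon"                      # hypothetical processed dataset
traj, t = "some_traj_0", 0                          # hypothetical trajectory name and timestep
img_tensor = img_path_to_data(get_data_path(data_folder, traj, t), (96, 96))

with open(f"{data_folder}/{traj}/traj_data.pkl", "rb") as f:
    traj_data = pickle.load(f)
local_waypoints = to_local_coords(
    traj_data["position"][t + 1 : t + 9],           # next 8 waypoints
    traj_data["position"][t],                       # current position
    traj_data["yaw"][t],                            # current yaw
)
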
/train/process_bags.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import pickle
4 | from PIL import Image
5 | import io
6 | import argparse
7 | import tqdm
8 | import yaml
9 | import rosbag
10 |
11 | # utils
12 | from vint_train.process_data.process_data_utils import *
13 |
14 |
15 | def main(args: argparse.Namespace):
16 |
17 | # load the config file
18 | with open("vint_train/process_data/process_bags_config.yaml", "r") as f:
19 | config = yaml.load(f, Loader=yaml.FullLoader)
20 |
21 | # create output dir if it doesn't exist
22 | if not os.path.exists(args.output_dir):
23 | os.makedirs(args.output_dir)
24 |
25 | # iterate recursively through all the folders in args.input_dir and collect the paths of files with a .bag extension
26 | bag_files = []
27 | for root, dirs, files in os.walk(args.input_dir):
28 | for file in files:
29 | if file.endswith(".bag"):
30 | bag_files.append(os.path.join(root, file))
31 | if args.num_trajs >= 0:
32 | bag_files = bag_files[: args.num_trajs]
33 |
34 | # processing loop
35 | for bag_path in tqdm.tqdm(bag_files, desc="Bags processed"):
36 | try:
37 | b = rosbag.Bag(bag_path)
38 | except rosbag.ROSBagException as e:
39 | print(e)
40 | print(f"Error loading {bag_path}. Skipping...")
41 | continue
42 |
43 | # traj name: the last two path components joined with "_" (dropping the .bag extension)
44 | traj_name = "_".join(bag_path.split("/")[-2:])[:-4]
45 |
46 | # load the hdf5 file
47 | bag_img_data, bag_traj_data = get_images_and_odom(
48 | b,
49 | config[args.dataset_name]["imtopics"],
50 | config[args.dataset_name]["odomtopics"],
51 | eval(config[args.dataset_name]["img_process_func"]),
52 | eval(config[args.dataset_name]["odom_process_func"]),
53 | rate=args.sample_rate,
54 | ang_offset=config[args.dataset_name]["ang_offset"],
55 | )
56 |
57 |
58 | if bag_img_data is None or bag_traj_data is None:
59 | print(
60 | f"{bag_path} did not have the topics we were looking for. Skipping..."
61 | )
62 | continue
63 | # remove backwards movement
64 | cut_trajs = filter_backwards(bag_img_data, bag_traj_data)
65 |
66 | for i, (img_data_i, traj_data_i) in enumerate(cut_trajs):
67 | traj_name_i = traj_name + f"_{i}"
68 | traj_folder_i = os.path.join(args.output_dir, traj_name_i)
69 | # make a folder for the traj
70 | if not os.path.exists(traj_folder_i):
71 | os.makedirs(traj_folder_i)
72 | with open(os.path.join(traj_folder_i, "traj_data.pkl"), "wb") as f:
73 | pickle.dump(traj_data_i, f)
74 | # save the image data to disk
75 | for i, img in enumerate(img_data_i):
76 | img.save(os.path.join(traj_folder_i, f"{i}.jpg"))
77 |
78 |
79 | if __name__ == "__main__":
80 | parser = argparse.ArgumentParser()
81 | # get arguments for the recon input dir and the output dir
82 | # add dataset name
83 | parser.add_argument(
84 | "--dataset-name",
85 | "-d",
86 | type=str,
87 |         help="name of the dataset (must be in process_bags_config.yaml)",
88 | default="tartan_drive",
89 | required=True,
90 | )
91 | parser.add_argument(
92 | "--input-dir",
93 | "-i",
94 | type=str,
95 | help="path of the datasets with rosbags",
96 | required=True,
97 | )
98 | parser.add_argument(
99 | "--output-dir",
100 | "-o",
101 | default="../datasets/tartan_drive/",
102 | type=str,
103 | help="path for processed dataset (default: ../datasets/tartan_drive/)",
104 | )
105 | # number of trajs to process
106 | parser.add_argument(
107 | "--num-trajs",
108 | "-n",
109 | default=-1,
110 | type=int,
111 | help="number of bags to process (default: -1, all)",
112 | )
113 | # sampling rate
114 | parser.add_argument(
115 | "--sample-rate",
116 | "-s",
117 | default=4.0,
118 | type=float,
119 | help="sampling rate (default: 4.0 hz)",
120 | )
121 |
122 | args = parser.parse_args()
123 | # all caps for the dataset name
124 | print(f"STARTING PROCESSING {args.dataset_name.upper()} DATASET")
125 | main(args)
126 | print(f"FINISHED PROCESSING {args.dataset_name.upper()} DATASET")
127 |
--------------------------------------------------------------------------------
/train/process_bag_diff.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import pickle
4 | from PIL import Image
5 | import io
6 | import argparse
7 | import tqdm
8 | import yaml
9 | import rosbag
10 |
11 | # utils
12 | from vint_train.process_data.process_data_utils import *
13 |
14 |
15 | def main(args: argparse.Namespace):
16 |
17 | # load the config file
18 | with open("vint_train/process_data/process_bags_config.yaml", "r") as f:
19 | config = yaml.load(f, Loader=yaml.FullLoader)
20 |
21 | # create output dir if it doesn't exist
22 | if not os.path.exists(args.output_dir):
23 | os.makedirs(args.output_dir)
24 |
25 |     # recursively walk args.input_dir and collect the paths of all .bag files whose name contains "diff"
26 | bag_files = []
27 | for root, dirs, files in os.walk(args.input_dir):
28 | for file in files:
29 | if file.endswith(".bag") and "diff" in file:
30 | bag_files.append(os.path.join(root, file))
31 | if args.num_trajs >= 0:
32 | bag_files = bag_files[: args.num_trajs]
33 |
34 | # processing loop
35 | for bag_path in tqdm.tqdm(bag_files, desc="Bags processed"):
36 | try:
37 | b = rosbag.Bag(bag_path)
38 | except rosbag.ROSBagException as e:
39 | print(e)
40 | print(f"Error loading {bag_path}. Skipping...")
41 | continue
42 |
43 |         # trajectory name: parent folder and bag file name joined by "_", without the .bag extension
44 | traj_name = "_".join(bag_path.split("/")[-2:])[:-4]
45 |
46 |         # extract images and odometry from the bag
47 | bag_img_data, bag_traj_data = get_images_and_odom_2(
48 | b,
49 | ['/usb_cam_front/image_raw', '/chosen_subgoal'],
50 | ['/odom'],
51 | rate=args.sample_rate,
52 | )
53 |
54 | if bag_img_data is None:
55 | print(
56 | f"{bag_path} did not have the topics we were looking for. Skipping..."
57 | )
58 | continue
59 | # remove backwards movement
60 | # cut_trajs = filter_backwards(bag_img_data, bag_traj_data)
61 |
62 | # for i, (img_data_i, traj_data_i) in enumerate(cut_trajs):
63 | # traj_name_i = traj_name + f"_{i}"
64 | # traj_folder_i = os.path.join(args.output_dir, traj_name_i)
65 | # # make a folder for the traj
66 | # if not os.path.exists(traj_folder_i):
67 | # os.makedirs(traj_folder_i)
68 | # with open(os.path.join(traj_folder_i, "traj_data.pkl"), "wb") as f:
69 | # pickle.dump(traj_data_i, f)
70 | # # save the image data to disk
71 | # for i, img in enumerate(img_data_i):
72 | # img.save(os.path.join(traj_folder_i, f"{i}.jpg"))
73 |
74 | traj_folder = os.path.join(args.output_dir, traj_name)
75 | if not os.path.exists(traj_folder):
76 | os.makedirs(traj_folder)
77 |
78 | obs_images = bag_img_data["/usb_cam_front/image_raw"]
79 | diff_images = bag_img_data["/chosen_subgoal"]
80 | for i, img_data in enumerate(zip(obs_images, diff_images)):
81 | obs_image, diff_image = img_data
82 |             # save the observation image and its paired subgoal ("diff")
83 |             # image to disk
84 | obs_image.save(os.path.join(traj_folder, f"{i}.jpg"))
85 | diff_image.save(os.path.join(traj_folder, f"diff_{i}.jpg"))
86 |
87 | with open(os.path.join(traj_folder, "traj_data.pkl"), "wb") as f:
88 | pickle.dump(bag_traj_data['/odom'], f)
89 |
90 |
91 | if __name__ == "__main__":
92 | parser = argparse.ArgumentParser()
93 | # get arguments for the recon input dir and the output dir
94 | # add dataset name
95 | # parser.add_argument(
96 | # "--dataset-name",
97 | # "-d",
98 | # type=str,
99 | # help="name of the dataset (must be in process_config.yaml)",
100 | # default="tartan_drive",
101 | # required=True,
102 | # )
103 | parser.add_argument(
104 | "--input-dir",
105 | "-i",
106 | type=str,
107 | help="path of the datasets with rosbags",
108 | required=True,
109 | )
110 | parser.add_argument(
111 | "--output-dir",
112 | "-o",
113 | default="../datasets/tartan_drive/",
114 | type=str,
115 | help="path for processed dataset (default: ../datasets/tartan_drive/)",
116 | )
117 | # number of trajs to process
118 | parser.add_argument(
119 | "--num-trajs",
120 | "-n",
121 | default=-1,
122 | type=int,
123 | help="number of bags to process (default: -1, all)",
124 | )
125 | # sampling rate
126 | parser.add_argument(
127 | "--sample-rate",
128 | "-s",
129 | default=4.0,
130 | type=float,
131 | help="sampling rate (default: 4.0 hz)",
132 | )
133 |
134 | args = parser.parse_args()
135 |     # process all diff bags found in the input directory
136 |     print("STARTING PROCESSING DIFF DATASET")
137 |     main(args)
138 |     print("FINISHED PROCESSING DIFF DATASET")
139 |
--------------------------------------------------------------------------------
/deployment/src/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import io
5 | import matplotlib.pyplot as plt
6 | from scipy.spatial.transform import Rotation as R
7 |
8 | # ROS
9 | from sensor_msgs.msg import Image
10 | import tf
11 |
12 | # pytorch
13 | import torch
14 | import torch.nn as nn
15 | from torchvision import transforms
16 | import torchvision.transforms.functional as TF
17 |
18 | import numpy as np
19 | from PIL import Image as PILImage
20 | from typing import List, Tuple, Dict, Optional
21 |
22 | # models
23 | from vint_train.models.gnm.gnm import GNM
24 | from vint_train.models.vint.vint import ViNT
25 |
26 | from vint_train.models.vint.vit import ViT
27 | from vint_train.models.nomad.nomad import NoMaD, DenseNetwork
28 | from vint_train.models.nomad.nomad_vint import NoMaD_ViNT, replace_bn_with_gn
29 | from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
30 | from vint_train.data.data_utils import IMAGE_ASPECT_RATIO
31 |
32 |
33 | def load_model(
34 | model_path: str,
35 | config: dict,
36 | device: torch.device = torch.device("cpu"),
37 | ) -> nn.Module:
38 | model_type = config["model_type"]
39 |
40 | if model_type == "gnm":
41 | model = GNM(
42 | config["context_size"],
43 | config["len_traj_pred"],
44 | config["learn_angle"],
45 | config["obs_encoding_size"],
46 | config["goal_encoding_size"],
47 | )
48 | elif model_type == "vint":
49 | model = ViNT(
50 | context_size=config["context_size"],
51 | len_traj_pred=config["len_traj_pred"],
52 | learn_angle=config["learn_angle"],
53 | obs_encoder=config["obs_encoder"],
54 | obs_encoding_size=config["obs_encoding_size"],
55 | late_fusion=config["late_fusion"],
56 | mha_num_attention_heads=config["mha_num_attention_heads"],
57 | mha_num_attention_layers=config["mha_num_attention_layers"],
58 | mha_ff_dim_factor=config["mha_ff_dim_factor"],
59 | )
60 | elif config["model_type"] == "nomad":
61 | if config["vision_encoder"] == "nomad_vint":
62 | vision_encoder = NoMaD_ViNT(
63 | obs_encoding_size=config["encoding_size"],
64 | context_size=config["context_size"],
65 | mha_num_attention_heads=config["mha_num_attention_heads"],
66 | mha_num_attention_layers=config["mha_num_attention_layers"],
67 | mha_ff_dim_factor=config["mha_ff_dim_factor"],
68 | )
69 | vision_encoder = replace_bn_with_gn(vision_encoder)
70 | elif config["vision_encoder"] == "vit":
71 | vision_encoder = ViT(
72 | obs_encoding_size=config["encoding_size"],
73 | context_size=config["context_size"],
74 | image_size=config["image_size"],
75 | patch_size=config["patch_size"],
76 | mha_num_attention_heads=config["mha_num_attention_heads"],
77 | mha_num_attention_layers=config["mha_num_attention_layers"],
78 | )
79 | vision_encoder = replace_bn_with_gn(vision_encoder)
80 | else:
81 | raise ValueError(f"Vision encoder {config['vision_encoder']} not supported")
82 |
83 | noise_pred_net = ConditionalUnet1D(
84 | input_dim=2,
85 | global_cond_dim=config["encoding_size"],
86 | down_dims=config["down_dims"],
87 | cond_predict_scale=config["cond_predict_scale"],
88 | )
89 | dist_pred_network = DenseNetwork(embedding_dim=config["encoding_size"])
90 |
91 | model = NoMaD(
92 | vision_encoder=vision_encoder,
93 | noise_pred_net=noise_pred_net,
94 | dist_pred_net=dist_pred_network,
95 | )
96 | else:
97 | raise ValueError(f"Invalid model type: {model_type}")
98 |
99 | checkpoint = torch.load(model_path, map_location=device)
100 | if model_type == "nomad":
101 | state_dict = checkpoint
102 | model.load_state_dict(state_dict, strict=False)
103 | else:
104 | loaded_model = checkpoint["model"]
105 | try:
106 | state_dict = loaded_model.module.state_dict()
107 | model.load_state_dict(state_dict, strict=False)
108 | except AttributeError as e:
109 | state_dict = loaded_model.state_dict()
110 | model.load_state_dict(state_dict, strict=False)
111 | model.to(device)
112 | return model
113 |
114 |
115 | def msg_to_pil(msg: Image) -> PILImage.Image:
116 | img = np.frombuffer(msg.data, dtype=np.uint8).reshape(
117 | msg.height, msg.width, -1)
118 | img = img[:, :, :3]
119 | img = img[:, :, ::-1]
120 | pil_image = PILImage.fromarray(img)
121 | return pil_image
122 |
123 |
124 | def pil_to_msg(pil_img: PILImage.Image, encoding="mono8") -> Image:
125 | img = np.asarray(pil_img)
126 | ros_image = Image(encoding=encoding)
127 | ros_image.height, ros_image.width, _ = img.shape
128 | ros_image.data = img.ravel().tobytes()
129 | ros_image.step = ros_image.width
130 | return ros_image
131 |
132 |
133 | def to_numpy(tensor):
134 | return tensor.cpu().detach().numpy()
135 |
136 |
137 | def transform_images(pil_imgs: List[PILImage.Image], image_size: List[int], center_crop: bool = False) -> torch.Tensor:
138 | transform_type = transforms.Compose(
139 | [
140 | transforms.ToTensor(),
141 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
142 | 0.229, 0.224, 0.225]),
143 | ]
144 | )
145 | if type(pil_imgs) != list:
146 | pil_imgs = [pil_imgs]
147 | transf_imgs = []
148 | for pil_img in pil_imgs:
149 | w, h = pil_img.size
150 | if center_crop:
151 | if w > h:
152 | pil_img = TF.center_crop(pil_img, (h, int(h * IMAGE_ASPECT_RATIO)))
153 | else:
154 | pil_img = TF.center_crop(pil_img, (int(w / IMAGE_ASPECT_RATIO), w))
155 | pil_img = pil_img.resize(image_size)
156 | transf_img = transform_type(pil_img)
157 | transf_img = torch.unsqueeze(transf_img, 0)
158 | transf_imgs.append(transf_img)
159 | return torch.cat(transf_imgs, dim=1)
160 |
161 |
162 | def clip_angle(angle):
163 | return np.mod(angle + np.pi, 2 * np.pi) - np.pi
164 |
165 | def rotate_point_by_quaternion(point, quaternion):
166 |     # rotate a 3D point by the inverse of the rotation given by an xyzw quaternion
167 |     r = R.from_quat(quaternion)
168 |     rotation_matrix = r.as_matrix()
169 |     point_vec = np.array([point[0], point[1], point[2]])
170 | 
171 |     # the inverse of a rotation matrix is its transpose
172 |     rotated_point = np.dot(np.linalg.inv(rotation_matrix), point_vec)
173 |     return rotated_point
174 |
--------------------------------------------------------------------------------
/deployment/src/costmap_cfg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, ETH Zurich (Robotics Systems Lab)
2 | # Author: Pascal Roth
3 | # All rights reserved.
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 |
7 | # python
8 | import os
9 | # from dataclasses import dataclass
10 | from typing import Optional
11 |
12 | import yaml
13 |
14 |
15 | class Loader(yaml.SafeLoader):
16 | pass
17 |
18 |
19 | def construct_GeneralCostMapConfig(loader, node):
20 | return GeneralCostMapConfig(**loader.construct_mapping(node))
21 |
22 |
23 | Loader.add_constructor(
24 | "tag:yaml.org,2002:python/object:viplanner.config.costmap_cfg.GeneralCostMapConfig",
25 | construct_GeneralCostMapConfig,
26 | )
27 |
28 |
29 | def construct_ReconstructionCfg(loader, node):
30 | return ReconstructionCfg(**loader.construct_mapping(node))
31 |
32 |
33 | Loader.add_constructor(
34 | "tag:yaml.org,2002:python/object:viplanner.config.costmap_cfg.ReconstructionCfg",
35 | construct_ReconstructionCfg,
36 | )
37 |
38 |
39 | def construct_SemCostMapConfig(loader, node):
40 | return SemCostMapConfig(**loader.construct_mapping(node))
41 |
42 |
43 | Loader.add_constructor(
44 | "tag:yaml.org,2002:python/object:viplanner.config.costmap_cfg.SemCostMapConfig",
45 | construct_SemCostMapConfig,
46 | )
47 |
48 |
49 | def construct_TsdfCostMapConfig(loader, node):
50 | return TsdfCostMapConfig(**loader.construct_mapping(node))
51 |
52 |
53 | Loader.add_constructor(
54 | "tag:yaml.org,2002:python/object:viplanner.config.costmap_cfg.TsdfCostMapConfig",
55 | construct_TsdfCostMapConfig,
56 | )
57 |
58 |
59 | # @dataclass
60 | class ReconstructionCfg:
61 | """
62 | Arguments for 3D reconstruction using depth maps
63 | """
64 |
65 | # directory where the environment with the depth (and semantic) images is located
66 | data_dir: str = "${USER_PATH_TO_DATA}"
67 | # environment name
68 | env: str = "town01"
69 | # image suffix
70 | depth_suffix = "_cam0"
71 | sem_suffix = "_cam1"
72 |     # higher resolution depth images available for reconstruction (meaning that the depth images are also taken by the semantic camera)
73 | high_res_depth: bool = False
74 |
75 | # reconstruction parameters
76 | voxel_size: float = 0.1 # [m] 0.05 for matterport 0.1 for carla
77 | start_idx: int = 0 # start index for reconstruction
78 | max_images: Optional[int] = 1000 # maximum number of images to reconstruct, if None, all images are used
79 | depth_scale: float = 1000.0 # depth scale factor
80 | # semantic reconstruction
81 | semantics: bool = False
82 |
83 | # speed vs. memory trade-off parameters
84 |     point_cloud_batch_size: int = (
85 |         200  # number of images whose 3D points are added to the point cloud at once (higher values are faster but use more memory)
86 |     )
87 |
88 | """ Internal functions """
89 |
90 | def get_data_path(self) -> str:
91 | return os.path.join(self.data_dir, self.env)
92 |
93 | def get_out_path(self) -> str:
94 | return os.path.join(self.out_dir, self.env)
95 |
96 |
97 | # @dataclass
98 | class SemCostMapConfig:
99 | """Configuration for the semantic cost map"""
100 |
101 | # point-cloud filter parameters
102 | ground_height: Optional[float] = -0.5 # None for matterport -0.5 for carla -1.0 for nomoko
103 | robot_height: float = 0.70
104 | robot_height_factor: float = 3.0
105 | nb_neighbors: int = 100
106 | std_ratio: float = 2.0 # keep high, otherwise ground will be removed
107 | downsample: bool = False
108 | # smoothing
109 | nb_neigh: int = 15
110 | change_decimal: int = 3
111 |     conv_crit: float = (
112 |         0.45  # ratio of points that must change by at least the change_decimal-th decimal place for convergence
113 |     )
114 | nb_tasks: Optional[int] = 10 # number of tasks for parallel processing, if None, all available cores are used
115 | sigma_smooth: float = 2.5
116 | max_iterations: int = 1
117 | # obstacle threshold (multiplied with highest loss value defined for a semantic class)
118 | obstacle_threshold: float = 0.8 # 0.5/ 0.6 for matterport, 0.8 for carla
119 | # negative reward for space with smallest cost (introduces a gradient in area with smallest loss value, steering towards center)
120 | # NOTE: at the end cost map is elevated by that amount to ensure that the smallest cost is 0
121 | negative_reward: float = 0.5
122 | # loss values rounded up to decimal #round_decimal_traversable equal to 0.0 are selected and the traversable gradient is determined based on them
123 | round_decimal_traversable: int = 2
124 | # compute height map
125 | compute_height_map: bool = False # false for matterport, true for carla and nomoko
126 |
127 |
128 | # @dataclass
129 | class TsdfCostMapConfig:
130 | """Configuration for the tsdf cost map"""
131 |
132 | # offset of the point cloud
133 | offset_z: float = 0.3
134 | # filter parameters
135 | ground_height: float = 0.3
136 | robot_height: float = 0.7
137 | robot_height_factor: float = 1.2
138 | nb_neighbors: int = 50
139 | std_ratio: float = 0.2
140 | filter_outliers: bool = True
141 | # dilation parameters
142 | sigma_expand: float = 3.0
143 | obstacle_threshold: float = 0.01
144 | free_space_threshold: float = 0.8
145 |
146 |
147 | # @dataclass
148 | class GeneralCostMapConfig:
149 | """General Cost Map Configuration"""
150 |
151 | # path to point cloud
152 | root_path: str = "town01"
153 | ply_file: str = "cloud.ply"
154 | # resolution of the cost map
155 | resolution: float = 0.04 # [m] (0.04 for matterport, 0.1 for carla)
156 | # map parameters
157 | clear_dist: float = 1.0 # cost map expansion over the point cloud space (prevent paths to go out of the map)
158 | # smoothing parameters
159 | sigma_smooth: float = 3.0
160 | # cost map expansion
161 |     x_min: Optional[float] = 0
162 |     # [m] if None, the minimum of the point cloud is used (carla town01: -8.05, matterport: None)
163 |     y_min: Optional[float] = -3
164 |     # [m] if None, the minimum of the point cloud is used (carla town01: -8.05, matterport: None)
165 |     x_max: Optional[float] = 8
166 |     # [m] if None, the maximum of the point cloud is used (carla town01: 346.22, matterport: None)
167 |     y_max: Optional[float] = 3
168 |     # [m] if None, the maximum of the point cloud is used (carla town01: 336.65, matterport: None)
169 |
170 |
171 | # @dataclass
172 | class CostMapConfig:
173 | """General Cost Map Configuration"""
174 |
175 | # cost map domains
176 | semantics: bool = True
177 | geometry: bool = False
178 |
179 | # name
180 | map_name: str = "cost_map_sem"
181 |
182 | # general cost map configuration
183 | general: GeneralCostMapConfig = GeneralCostMapConfig()
184 |
185 | # individual cost map configurations
186 | sem_cost_map: SemCostMapConfig = SemCostMapConfig()
187 | tsdf_cost_map: TsdfCostMapConfig = TsdfCostMapConfig()
188 |
189 | # visualize cost map
190 | visualize: bool = True
191 |
192 | # FILLED BY CODE -> DO NOT CHANGE ###
193 |     x_start: Optional[float] = None
194 |     y_start: Optional[float] = None
195 |
196 |
197 | # EoF
198 |
--------------------------------------------------------------------------------
/train/vint_train/visualizing/distance_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 | import numpy as np
4 | from typing import List, Optional, Tuple
5 | from vint_train.visualizing.visualize_utils import numpy_to_img
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | def visualize_dist_pred(
10 | batch_obs_images: np.ndarray,
11 | batch_goal_images: np.ndarray,
12 | batch_dist_preds: np.ndarray,
13 | batch_dist_labels: np.ndarray,
14 | eval_type: str,
15 | save_folder: str,
16 | epoch: int,
17 | num_images_preds: int = 8,
18 | use_wandb: bool = True,
19 | display: bool = False,
20 | rounding: int = 4,
21 | dist_error_threshold: float = 3.0,
22 | ):
23 | """
24 | Visualize the distance classification predictions and labels for an observation-goal image pair.
25 |
26 | Args:
27 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels]
28 | batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels]
29 | batch_dist_preds (np.ndarray): batch of distance predictions [batch_size]
30 | batch_dist_labels (np.ndarray): batch of distance labels [batch_size]
31 | eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.)
32 | epoch (int): current epoch number
33 | num_images_preds (int): number of images to visualize
34 | use_wandb (bool): whether to use wandb to log the images
35 | save_folder (str): folder to save the images. If None, will not save the images
36 | display (bool): whether to display the images
37 | rounding (int): number of decimal places to round the distance predictions and labels
38 | dist_error_threshold (float): distance error threshold for classifying the distance prediction as correct or incorrect (only used for visualization purposes)
39 | """
40 | visualize_path = os.path.join(
41 | save_folder,
42 | "visualize",
43 | eval_type,
44 | f"epoch{epoch}",
45 | "dist_classification",
46 | )
47 | if not os.path.isdir(visualize_path):
48 | os.makedirs(visualize_path)
49 | assert (
50 | len(batch_obs_images)
51 | == len(batch_goal_images)
52 | == len(batch_dist_preds)
53 | == len(batch_dist_labels)
54 | )
55 | batch_size = batch_obs_images.shape[0]
56 | wandb_list = []
57 | for i in range(min(batch_size, num_images_preds)):
58 | dist_pred = np.round(batch_dist_preds[i], rounding)
59 | dist_label = np.round(batch_dist_labels[i], rounding)
60 | obs_image = numpy_to_img(batch_obs_images[i])
61 | goal_image = numpy_to_img(batch_goal_images[i])
62 |
63 | save_path = None
64 | if save_folder is not None:
65 | save_path = os.path.join(visualize_path, f"{i}.png")
66 | text_color = "black"
67 | if abs(dist_pred - dist_label) > dist_error_threshold:
68 | text_color = "red"
69 |
70 | display_distance_pred(
71 | [obs_image, goal_image],
72 | ["Observation", "Goal"],
73 | dist_pred,
74 | dist_label,
75 | text_color,
76 | save_path,
77 | display,
78 | )
79 | if use_wandb:
80 | wandb_list.append(wandb.Image(save_path))
81 | if use_wandb:
82 | wandb.log({f"{eval_type}_dist_prediction": wandb_list}, commit=False)
83 |
84 |
85 | def visualize_dist_pairwise_pred(
86 | batch_obs_images: np.ndarray,
87 | batch_close_images: np.ndarray,
88 | batch_far_images: np.ndarray,
89 | batch_close_preds: np.ndarray,
90 | batch_far_preds: np.ndarray,
91 | batch_close_labels: np.ndarray,
92 | batch_far_labels: np.ndarray,
93 | eval_type: str,
94 | save_folder: str,
95 | epoch: int,
96 | num_images_preds: int = 8,
97 | use_wandb: bool = True,
98 | display: bool = False,
99 | rounding: int = 4,
100 | ):
101 | """
102 |     Visualize the pairwise distance predictions and labels for an observation paired with close and far goal images.
103 |
104 | Args:
105 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels]
106 | batch_close_images (np.ndarray): batch of close goal images [batch_size, height, width, channels]
107 | batch_far_images (np.ndarray): batch of far goal images [batch_size, height, width, channels]
108 | batch_close_preds (np.ndarray): batch of close predictions [batch_size]
109 | batch_far_preds (np.ndarray): batch of far predictions [batch_size]
110 | batch_close_labels (np.ndarray): batch of close labels [batch_size]
111 | batch_far_labels (np.ndarray): batch of far labels [batch_size]
112 | eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.)
113 | save_folder (str): folder to save the images. If None, will not save the images
114 | epoch (int): current epoch number
115 | num_images_preds (int): number of images to visualize
116 | use_wandb (bool): whether to use wandb to log the images
117 | display (bool): whether to display the images
118 | rounding (int): number of decimal places to round the distance predictions and labels
119 | """
120 | visualize_path = os.path.join(
121 | save_folder,
122 | "visualize",
123 | eval_type,
124 | f"epoch{epoch}",
125 | "pairwise_dist_classification",
126 | )
127 | if not os.path.isdir(visualize_path):
128 | os.makedirs(visualize_path)
129 | assert (
130 | len(batch_obs_images)
131 | == len(batch_close_images)
132 | == len(batch_far_images)
133 | == len(batch_close_preds)
134 | == len(batch_far_preds)
135 | == len(batch_close_labels)
136 | == len(batch_far_labels)
137 | )
138 | batch_size = batch_obs_images.shape[0]
139 | wandb_list = []
140 | for i in range(min(batch_size, num_images_preds)):
141 | close_dist_pred = np.round(batch_close_preds[i], rounding)
142 | far_dist_pred = np.round(batch_far_preds[i], rounding)
143 | close_dist_label = np.round(batch_close_labels[i], rounding)
144 | far_dist_label = np.round(batch_far_labels[i], rounding)
145 | obs_image = numpy_to_img(batch_obs_images[i])
146 | close_image = numpy_to_img(batch_close_images[i])
147 | far_image = numpy_to_img(batch_far_images[i])
148 |
149 | save_path = None
150 | if save_folder is not None:
151 | save_path = os.path.join(visualize_path, f"{i}.png")
152 |
153 | if close_dist_pred < far_dist_pred:
154 | text_color = "black"
155 | else:
156 | text_color = "red"
157 |
158 | display_distance_pred(
159 | [obs_image, close_image, far_image],
160 | ["Observation", "Close Goal", "Far Goal"],
161 | f"close_pred = {close_dist_pred}, far_pred = {far_dist_pred}",
162 | f"close_label = {close_dist_label}, far_label = {far_dist_label}",
163 | text_color,
164 | save_path,
165 | display,
166 | )
167 | if use_wandb:
168 | wandb_list.append(wandb.Image(save_path))
169 | if use_wandb:
170 | wandb.log({f"{eval_type}_pairwise_classification": wandb_list}, commit=False)
171 |
172 |
173 | def display_distance_pred(
174 | imgs: list,
175 | titles: list,
176 | dist_pred: float,
177 | dist_label: float,
178 | text_color: str = "black",
179 | save_path: Optional[str] = None,
180 | display: bool = False,
181 | ):
182 |     # one subplot per image; the figure is created by plt.subplots below
183 | fig, ax = plt.subplots(1, len(imgs))
184 |
185 | plt.suptitle(f"prediction: {dist_pred}\nlabel: {dist_label}", color=text_color)
186 |
187 | for axis, img, title in zip(ax, imgs, titles):
188 | axis.imshow(img)
189 | axis.set_title(title)
190 | axis.xaxis.set_visible(False)
191 | axis.yaxis.set_visible(False)
192 |
193 | # make the plot large
194 | fig.set_size_inches((18.5 / 3) * len(imgs), 10.5)
195 |
196 | if save_path is not None:
197 | fig.savefig(
198 | save_path,
199 | bbox_inches="tight",
200 | )
201 | if not display:
202 | plt.close(fig)
203 |
--------------------------------------------------------------------------------
/deployment/src/cost_to_pcd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, ETH Zurich (Robotics Systems Lab)
2 | # Author: Pascal Roth
3 | # All rights reserved.
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 |
7 | import argparse
8 | import os
9 | from typing import Optional, Union
10 |
11 | import numpy as np
12 | import open3d as o3d
13 | import pypose as pp
14 | import torch
15 | import yaml
16 |
17 | # viplanner
18 | from costmap_cfg import CostMapConfig, Loader
19 |
20 | torch.set_default_dtype(torch.float32)
21 |
22 |
23 | class CostMapPCD:
24 | def __init__(
25 | self,
26 | cfg: CostMapConfig,
27 | tsdf_array: np.ndarray,
28 | viz_points: np.ndarray,
29 | ground_array: np.ndarray,
30 | gpu_id: Optional[int] = 0,
31 | load_from_file: Optional[bool] = False,
32 | ):
33 | # determine device
34 | if torch.cuda.is_available() and gpu_id is not None:
35 | self.device = torch.device("cuda:" + str(gpu_id))
36 | else:
37 | self.device = torch.device("cpu")
38 |
39 | # args
40 | self.cfg: CostMapConfig = cfg
41 | self.load_from_file: bool = load_from_file
42 | self.tsdf_array: torch.Tensor = torch.tensor(tsdf_array, device=self.device)
43 | self.viz_points: np.ndarray = viz_points
44 | self.ground_array: torch.Tensor = torch.tensor(ground_array, device=self.device)
45 |
46 | # init flag
47 | self.map_init = False
48 |
49 | # init pointclouds
50 | self.pcd_tsdf = o3d.geometry.PointCloud()
51 | self.pcd_viz = o3d.geometry.PointCloud()
52 |
53 | # execute setup
54 | self.num_x: int = 0
55 | self.num_y: int = 0
56 | self.setup()
57 | return
58 |
59 | def setup(self):
60 |         # extent of the cost map
61 | self.num_x, self.num_y = self.tsdf_array.shape
62 | # visualization points
63 | self.pcd_viz.points = o3d.utility.Vector3dVector(self.viz_points)
64 | # set cost map
65 | self.SetUpCostArray()
66 | # update pcd instance
67 | xv, yv = np.meshgrid(
68 | np.linspace(0, self.num_x * self.cfg.general.resolution, self.num_x),
69 | np.linspace(0, self.num_y * self.cfg.general.resolution, self.num_y),
70 | indexing="ij",
71 | )
72 | T = np.concatenate((np.expand_dims(xv, axis=0), np.expand_dims(yv, axis=0)), axis=0)
73 | T = np.concatenate(
74 | (
75 | T,
76 | np.expand_dims(self.cost_array.cpu().detach().numpy(), axis=0),
77 | ),
78 | axis=0,
79 | )
80 | if self.load_from_file:
81 | wps = T.reshape(3, -1).T + np.array([self.cfg.x_start, self.cfg.y_start, 0.0])
82 | self.pcd_tsdf.points = o3d.utility.Vector3dVector(wps)
83 | else:
84 | self.pcd_tsdf.points = o3d.utility.Vector3dVector(T.reshape(3, -1).T)
85 |
86 | self.map_init = True
87 | return
88 |
89 | def ShowTSDFMap(self, cost_map=True): # not run with cuda
90 | if not self.map_init:
91 |             print("Error: cannot show map, map has not been initialized yet!")
92 | return
93 | if cost_map:
94 | o3d.visualization.draw_geometries([self.pcd_tsdf])
95 | else:
96 | o3d.visualization.draw_geometries([self.pcd_viz])
97 | return
98 |
99 | def Pos2Ind(self, points: Union[torch.Tensor, pp.LieTensor]):
100 | # points [torch shapes [num_p, 3]]
101 | start_xy = torch.tensor(
102 | [self.cfg.x_start, self.cfg.y_start],
103 | dtype=torch.float64,
104 | device=points.device,
105 | ).expand(1, 1, -1)
106 | if isinstance(points, pp.LieTensor):
107 | H = (points.tensor()[:, :, 0:2] - start_xy) / self.cfg.general.resolution
108 | else:
109 | H = (points[:, :, 0:2] - start_xy) / self.cfg.general.resolution
110 | mask = torch.logical_and(
111 | (H > 0).all(axis=2),
112 | (H < torch.tensor([self.num_x, self.num_y], device=points.device)[None, None, :]).all(axis=2),
113 | )
114 | return self.NormInds(H), H[mask, :]
115 |
116 | def NormInds(self, H):
117 | norm_matrix = torch.tensor(
118 | [self.num_x / 2.0, self.num_y / 2.0],
119 | dtype=torch.float64,
120 | device=H.device,
121 | )
122 | H = (H - norm_matrix) / norm_matrix
123 | return H
124 |
125 | def DeNormInds(self, NH):
126 | norm_matrix = torch.tensor(
127 | [self.num_x / 2.0, self.num_y / 2.0],
128 | dtype=torch.float64,
129 | device=NH.device,
130 | )
131 | NH = NH * norm_matrix + norm_matrix
132 | return NH
133 |
134 | def SaveTSDFMap(self):
135 | if not self.map_init:
136 |             print("Error: map has not been initialized yet!")
137 | return
138 |
139 | # make directories
140 | os.makedirs(
141 | os.path.join(self.cfg.general.root_path, "maps", "data"),
142 | exist_ok=True,
143 | )
144 | os.makedirs(
145 | os.path.join(self.cfg.general.root_path, "maps", "cloud"),
146 | exist_ok=True,
147 | )
148 | os.makedirs(
149 | os.path.join(self.cfg.general.root_path, "maps", "params"),
150 | exist_ok=True,
151 | )
152 |
153 | map_path = os.path.join(
154 | self.cfg.general.root_path,
155 | "maps",
156 | "data",
157 | self.cfg.map_name + "_map.txt",
158 | )
159 | ground_path = os.path.join(
160 | self.cfg.general.root_path,
161 | "maps",
162 | "data",
163 | self.cfg.map_name + "_ground.txt",
164 | )
165 | cloud_path = os.path.join(
166 | self.cfg.general.root_path,
167 | "maps",
168 | "cloud",
169 | self.cfg.map_name + "_cloud.txt",
170 | )
171 | # save data
172 | np.savetxt(map_path, self.tsdf_array.cpu())
173 | np.savetxt(ground_path, self.ground_array.cpu())
174 | np.savetxt(cloud_path, self.viz_points)
175 | # save config parameters
176 | yaml_path = os.path.join(
177 | self.cfg.general.root_path,
178 | "maps",
179 | "params",
180 | f"config_{self.cfg.map_name}.yaml",
181 | )
182 | with open(yaml_path, "w+") as file:
183 | yaml.dump(
184 | vars(self.cfg),
185 | file,
186 | allow_unicode=True,
187 | default_flow_style=False,
188 | )
189 |
190 | print("TSDF Map saved.")
191 | return
192 |
193 | def SetUpCostArray(self):
194 | self.cost_array = self.tsdf_array
195 | return
196 |
197 | @classmethod
198 | def ReadTSDFMap(cls, root_path: str, map_name: str, gpu_id: Optional[int] = None):
199 | # read config
200 | with open(os.path.join(root_path, "maps", "params", f"config_{map_name}.yaml")) as f:
201 | cfg: CostMapConfig = CostMapConfig(**yaml.load(f, Loader))
202 |
203 | # load data
204 | tsdf_array = np.loadtxt(os.path.join(root_path, "maps", "data", map_name + "_map.txt"))
205 | viz_points = np.loadtxt(os.path.join(root_path, "maps", "cloud", map_name + "_cloud.txt"))
206 | ground_array = np.loadtxt(os.path.join(root_path, "maps", "data", map_name + "_ground.txt"))
207 |
208 | return cls(
209 | cfg=cfg,
210 | tsdf_array=tsdf_array,
211 | viz_points=viz_points,
212 | ground_array=ground_array,
213 | gpu_id=gpu_id,
214 | load_from_file=True,
215 | )
216 |
217 |
218 | if __name__ == "__main__":
219 | # parse environment directory and cost_map name
220 | parser = argparse.ArgumentParser(prog="Show Costmap", description="Show Costmap")
221 | parser.add_argument(
222 | "-e",
223 | "--env",
224 | type=str,
225 | help="path to the environment directory",
226 | required=True,
227 | )
228 | parser.add_argument("-m", "--map", type=str, help="name of the cost_map", required=True)
229 | args = parser.parse_args()
230 |
231 | # show costmap
232 | map = CostMapPCD.ReadTSDFMap(args.env, args.map)
233 | map.ShowTSDFMap()
234 |
235 | # EoF
236 |
--------------------------------------------------------------------------------
/train/vint_train/models/nomad/nomad_vint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | from typing import List, Dict, Optional, Tuple, Callable
6 | from efficientnet_pytorch import EfficientNet
7 | from vint_train.models.vint.self_attention import PositionalEncoding
8 |
9 | class NoMaD_ViNT(nn.Module):
10 | def __init__(
11 | self,
12 | context_size: int = 5,
13 | obs_encoder: Optional[str] = "efficientnet-b0",
14 | obs_encoding_size: Optional[int] = 512,
15 | mha_num_attention_heads: Optional[int] = 2,
16 | mha_num_attention_layers: Optional[int] = 2,
17 | mha_ff_dim_factor: Optional[int] = 4,
18 | ) -> None:
19 | """
20 | NoMaD ViNT Encoder class
21 | """
22 | super().__init__()
23 | self.obs_encoding_size = obs_encoding_size
24 | self.goal_encoding_size = obs_encoding_size
25 | self.context_size = context_size
26 |
27 | # Initialize the observation encoder
28 | if obs_encoder.split("-")[0] == "efficientnet":
29 | self.obs_encoder = EfficientNet.from_name(obs_encoder, in_channels=3) # context
30 | self.obs_encoder = replace_bn_with_gn(self.obs_encoder)
31 | self.num_obs_features = self.obs_encoder._fc.in_features
32 | self.obs_encoder_type = "efficientnet"
33 | else:
34 | raise NotImplementedError
35 |
36 | # Initialize the goal encoder
37 | self.goal_encoder = EfficientNet.from_name("efficientnet-b0", in_channels=6) # obs+goal
38 | self.goal_encoder = replace_bn_with_gn(self.goal_encoder)
39 | self.num_goal_features = self.goal_encoder._fc.in_features
40 |
41 | # Initialize compression layers if necessary
42 | if self.num_obs_features != self.obs_encoding_size:
43 | self.compress_obs_enc = nn.Linear(self.num_obs_features, self.obs_encoding_size)
44 | else:
45 | self.compress_obs_enc = nn.Identity()
46 |
47 | if self.num_goal_features != self.goal_encoding_size:
48 | self.compress_goal_enc = nn.Linear(self.num_goal_features, self.goal_encoding_size)
49 | else:
50 | self.compress_goal_enc = nn.Identity()
51 |
52 | # Initialize positional encoding and self-attention layers
53 | self.positional_encoding = PositionalEncoding(self.obs_encoding_size, max_seq_len=self.context_size + 2)
54 | self.sa_layer = nn.TransformerEncoderLayer(
55 | d_model=self.obs_encoding_size,
56 | nhead=mha_num_attention_heads,
57 | dim_feedforward=mha_ff_dim_factor*self.obs_encoding_size,
58 | activation="gelu",
59 | batch_first=True,
60 | norm_first=True
61 | )
62 | self.sa_encoder = nn.TransformerEncoder(self.sa_layer, num_layers=mha_num_attention_layers)
63 |
64 | # Definition of the goal mask (convention: 0 = no mask, 1 = mask)
65 | self.goal_mask = torch.zeros((1, self.context_size + 2), dtype=torch.bool)
66 | self.goal_mask[:, -1] = True # Mask out the goal
67 | self.no_mask = torch.zeros((1, self.context_size + 2), dtype=torch.bool)
68 | self.all_masks = torch.cat([self.no_mask, self.goal_mask], dim=0)
69 | self.avg_pool_mask = torch.cat([1 - self.no_mask.float(), (1 - self.goal_mask.float()) * ((self.context_size + 2)/(self.context_size + 1))], dim=0)
70 |
71 |
72 | def forward(self, obs_img: torch.tensor, goal_img: torch.tensor, input_goal_mask: torch.tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
73 |
74 | device = obs_img.device
75 |
76 | # Initialize the goal encoding
77 | goal_encoding = torch.zeros((obs_img.size()[0], 1, self.goal_encoding_size)).to(device)
78 |
79 | # Get the input goal mask
80 |         # default to None so the mask check below does not fail when no goal mask is given
81 |         goal_mask = input_goal_mask.to(device) if input_goal_mask is not None else None
82 |
83 | # Get the goal encoding
84 | obsgoal_img = torch.cat([obs_img[:, 3*self.context_size:, :, :], goal_img], dim=1) # concatenate the obs image/context and goal image --> non image goal?
85 | obsgoal_encoding = self.goal_encoder.extract_features(obsgoal_img) # get encoding of this img
86 | obsgoal_encoding = self.goal_encoder._avg_pooling(obsgoal_encoding) # avg pooling
87 |
88 | if self.goal_encoder._global_params.include_top:
89 | obsgoal_encoding = obsgoal_encoding.flatten(start_dim=1)
90 | obsgoal_encoding = self.goal_encoder._dropout(obsgoal_encoding)
91 | obsgoal_encoding = self.compress_goal_enc(obsgoal_encoding)
92 |
93 | if len(obsgoal_encoding.shape) == 2:
94 | obsgoal_encoding = obsgoal_encoding.unsqueeze(1)
95 | assert obsgoal_encoding.shape[2] == self.goal_encoding_size
96 | goal_encoding = obsgoal_encoding
97 |
98 | # Get the observation encoding
99 | obs_img = torch.split(obs_img, 3, dim=1)
100 | obs_img = torch.concat(obs_img, dim=0)
101 |
102 | obs_encoding = self.obs_encoder.extract_features(obs_img)
103 | obs_encoding = self.obs_encoder._avg_pooling(obs_encoding)
104 | if self.obs_encoder._global_params.include_top:
105 | obs_encoding = obs_encoding.flatten(start_dim=1)
106 | obs_encoding = self.obs_encoder._dropout(obs_encoding)
107 | obs_encoding = self.compress_obs_enc(obs_encoding)
108 | obs_encoding = obs_encoding.unsqueeze(1)
109 | obs_encoding = obs_encoding.reshape((self.context_size+1, -1, self.obs_encoding_size))
110 | obs_encoding = torch.transpose(obs_encoding, 0, 1)
111 | obs_encoding = torch.cat((obs_encoding, goal_encoding), dim=1)
112 |
113 | # If a goal mask is provided, mask some of the goal tokens
114 | if goal_mask is not None:
115 | no_goal_mask = goal_mask.long()
116 | src_key_padding_mask = torch.index_select(self.all_masks.to(device), 0, no_goal_mask)
117 | else:
118 | src_key_padding_mask = None
119 |
120 | # Apply positional encoding
121 | if self.positional_encoding:
122 | obs_encoding = self.positional_encoding(obs_encoding)
123 |
124 | obs_encoding_tokens = self.sa_encoder(obs_encoding, src_key_padding_mask=src_key_padding_mask)
125 | if src_key_padding_mask is not None:
126 | avg_mask = torch.index_select(self.avg_pool_mask.to(device), 0, no_goal_mask).unsqueeze(-1)
127 | obs_encoding_tokens = obs_encoding_tokens * avg_mask
128 | obs_encoding_tokens = torch.mean(obs_encoding_tokens, dim=1)
129 |
130 | return obs_encoding_tokens
131 |
132 |
133 |
134 | # Utils for Group Norm
135 | def replace_bn_with_gn(
136 | root_module: nn.Module,
137 | features_per_group: int=16) -> nn.Module:
138 | """
139 |     Replace all BatchNorm layers with GroupNorm.
140 | """
141 | replace_submodules(
142 | root_module=root_module,
143 | predicate=lambda x: isinstance(x, nn.BatchNorm2d),
144 | func=lambda x: nn.GroupNorm(
145 | num_groups=x.num_features//features_per_group,
146 | num_channels=x.num_features)
147 | )
148 | return root_module
149 |
150 |
151 | def replace_submodules(
152 | root_module: nn.Module,
153 | predicate: Callable[[nn.Module], bool],
154 | func: Callable[[nn.Module], nn.Module]) -> nn.Module:
155 | """
156 | Replace all submodules selected by the predicate with
157 | the output of func.
158 |
159 | predicate: Return true if the module is to be replaced.
160 | func: Return new module to use.
161 | """
162 | if predicate(root_module):
163 | return func(root_module)
164 |
165 | bn_list = [k.split('.') for k, m
166 | in root_module.named_modules(remove_duplicate=True)
167 | if predicate(m)]
168 | for *parent, k in bn_list:
169 | parent_module = root_module
170 | if len(parent) > 0:
171 | parent_module = root_module.get_submodule('.'.join(parent))
172 | if isinstance(parent_module, nn.Sequential):
173 | src_module = parent_module[int(k)]
174 | else:
175 | src_module = getattr(parent_module, k)
176 | tgt_module = func(src_module)
177 | if isinstance(parent_module, nn.Sequential):
178 | parent_module[int(k)] = tgt_module
179 | else:
180 | setattr(parent_module, k, tgt_module)
181 | # verify that all modules are replaced
182 | bn_list = [k.split('.') for k, m
183 | in root_module.named_modules(remove_duplicate=True)
184 | if predicate(m)]
185 | assert len(bn_list) == 0
186 | return root_module
187 |
188 |
189 |
190 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NaviDiffusor: Cost-Guided Diffusion Model for Visual Navigation
2 |
3 | > 🏆 Accepted at **ICRA 2025**
4 | > 🔗 [arXiv](https://arxiv.org/abs/2504.10003) | [Bilibili](https://www.bilibili.com/video/BV1PaLizwEkW/) | [Youtube](https://www.youtube.com/watch?v=94ODPEqyP0s)
5 |
6 |
7 |
8 |
9 |
10 | ---
11 |
12 | ## ✅ TODO List
13 |
14 | - [ ] Training code updates
15 | - [ ] Simulation Envs
16 |
17 | ## ⚙️ Setup
18 | Run the commands below inside the project directory:
19 | 1. Set up the conda environment:
20 | ```bash
21 | conda env create -f train/train_environment.yml
22 | ```
23 | 2. Source the conda environment:
24 | ```bash
25 | conda activate navidiffusor
26 | ```
27 | 3. Install the vint_train packages:
28 | ```bash
29 | pip install -e train/
30 | ```
31 | 4. Install the `diffusion_policy` package from this [repo](https://github.com/real-stanford/diffusion_policy):
32 | ```bash
33 | git clone git@github.com:real-stanford/diffusion_policy.git
34 | pip install -e diffusion_policy/
35 | ```
36 | 5. Install the `depth_anything_v2` package from this [repo](https://github.com/DepthAnything/Depth-Anything-V2):
37 | ```bash
38 | git clone https://github.com/DepthAnything/Depth-Anything-V2.git
39 | pip install -r Depth-Anything-V2/requirements.txt
40 | ```
41 |
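Note: `deployment/src/guide.py` loads a Depth Anything V2 checkpoint named `depth_anything_v2_<encoder>.pth` (default encoder: `vits`) from a `checkpoints/` folder resolved relative to the `depth_anything_v2` package. A minimal sketch, assuming the package is used from the clone made in the previous step:
```bash
# assumption: depth_anything_v2 lives inside the Depth-Anything-V2 clone from step 5
mkdir -p Depth-Anything-V2/checkpoints
# copy the checkpoint matching the encoder configured in deployment/src/guide.py
cp /path/to/depth_anything_v2_vits.pth Depth-Anything-V2/checkpoints/
```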
42 | ## Data
43 | - [RECON](https://sites.google.com/view/recon-robot/dataset)
44 | - [SCAND](https://www.cs.utexas.edu/~xiao/SCAND/SCAND.html#Links)
45 | - [GoStanford2 (Modified)](https://drive.google.com/drive/folders/1xrNvMl5q92oWed99noOt_UhqQnceJYV0?usp=share_link)
46 | - [SACSoN/HuRoN](https://sites.google.com/view/sacson-review/huron-dataset)
47 |
48 | We recommend downloading these (and any other datasets you may want to train on) and running the processing steps below.
49 |
50 | ### Data Processing
51 |
52 | We provide some sample scripts to process these datasets, either directly from a rosbag or from a custom format like HDF5s:
53 | 1. Run `process_bags.py` with the relevant args, or `process_recon.py` for processing RECON HDF5s. You can also manually add your own dataset by following our structure below.
54 | 2. Run `data_split.py` on your dataset folder with the relevant args.
55 | 3. Expected structure:
56 |
57 | ```
58 | ├── <dataset_name>
59 | │   ├── <name_of_traj1>
60 | │   │   ├── 0.jpg
61 | │   │   ├── 1.jpg
62 | │   │   ├── ...
63 | │   │   ├── T_1.jpg
64 | │   │   └── traj_data.pkl
65 | │   ├── <name_of_traj2>
66 | │   │   ├── 0.jpg
67 | │   │   ├── 1.jpg
68 | │   │   ├── ...
69 | │   │   ├── T_2.jpg
70 | │   │   └── traj_data.pkl
71 | │   ...
72 | └── └── <name_of_trajN>
73 |         ├── 0.jpg
74 |         ├── 1.jpg
75 |         ├── ...
76 |         ├── T_N.jpg
77 |         └── traj_data.pkl
78 | ```
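For reference, a typical invocation of `process_bags.py` (flags taken from its argument parser; run from inside `train/` so the relative path to `process_bags_config.yaml` resolves) could look like:

```bash
cd train
python process_bags.py \
    --dataset-name tartan_drive \
    --input-dir /path/to/rosbags \
    --output-dir ../datasets/tartan_drive/ \
    --sample-rate 4.0 \
    --num-trajs -1   # -1 processes all bags
```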
79 |
80 | Each `*.jpg` file contains a forward-facing RGB observation from the robot, and they are temporally labeled. The `traj_data.pkl` file is the odometry data for the trajectory. It's a pickled dictionary with the keys:
81 | - `"position"`: An np.ndarray [T, 2] of the xy-coordinates of the robot at each image observation.
82 | - `"yaw"`: An np.ndarray [T,] of the yaws of the robot at each image observation.
83 |
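For example, a minimal sketch (with a hypothetical trajectory path) for loading and sanity-checking one trajectory:

```python
import pickle
import numpy as np

# hypothetical trajectory folder produced by the processing scripts above
with open("datasets/tartan_drive/<traj_name>/traj_data.pkl", "rb") as f:
    traj_data = pickle.load(f)

positions = np.asarray(traj_data["position"])  # [T, 2] xy-coordinates
yaws = np.asarray(traj_data["yaw"])            # [T] headings
assert positions.shape[0] == yaws.shape[0]     # one odometry entry per image frame
```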
84 |
85 | After step 2 of data processing, the processed data-split should have the following structure inside `/train/vint_train/data/data_splits/`:
86 |
87 | ```
88 | ├── <dataset_name>
89 | │   ├── train
90 | │   │   └── traj_names.txt
91 | └── └── test
92 |         └── traj_names.txt
93 | ```
94 |
95 | ## Model Training
96 | ```bash
97 | cd /train
98 | python train.py -c <path_to_config_yaml_file>
99 | ```
100 | The config yaml files are in the `train/config` directory.
101 |
102 | ## Deployment
103 |
104 |
133 |
134 |
135 | ### Inference with Guidance
136 | 🚀 **Our method is designed to provide guidance for any diffusion-based navigation model at inference time, improving path generation quality for both PointGoal and ImageGoal tasks. Here, we use [NoMaD](https://github.com/robodhruv/visualnav-transformer) as an example; an adaptable implementation in [guide.py](./deployment/src/guide.py) is provided for integrating with your own diffusion model.**
137 |
138 | ```bash
139 | cd deployment/src/
140 | sh ./navigate.sh --model <model_name> --dir <topomap_dir> --point-goal False # set --point-goal True for PointGoal navigation, False for ImageGoal
141 | ```
142 |
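For intuition, below is a minimal, generic sketch of gradient-based cost guidance inside a denoising loop. It is not the exact interface of `guide.py`; `denoiser`, `cost_fn`, and `scale` are placeholders for your own diffusion model, a differentiable cost, and the guidance strength:

```python
import torch

def guided_denoise_step(denoiser, noisy_actions, t, obs_cond, cost_fn, scale=0.1):
    # standard denoiser prediction for this diffusion step (placeholder interface)
    pred = denoiser(noisy_actions, t, obs_cond)

    # differentiate a task cost (e.g. collision or goal cost) w.r.t. the current sample
    actions = noisy_actions.detach().requires_grad_(True)
    cost = cost_fn(actions).sum()
    grad = torch.autograd.grad(cost, actions)[0]

    # steer the prediction down the cost gradient before the next denoising step
    return pred - scale * grad
```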
143 | The `<model_name>` is the name of the model in the `/deployment/config/models.yaml` file. In this file, you specify the following parameters for each model:
144 | - `config_path` (str): path of the *.yaml file in `/train/config/` used to train the model
145 | - `ckpt_path` (str): path of the *.pth file in `/deployment/model_weights/`
146 |
147 |
148 | Make sure these configurations match what you used to train the model. The configurations for the models whose weights we provide are already included in this yaml file for reference.
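As a sketch (the entry name and relative paths below are placeholders, not guaranteed defaults), an entry in `models.yaml` might look like:

```yaml
<model_name>:
  config_path: ../../train/config/<model_config>.yaml   # config used to train the model
  ckpt_path: ../model_weights/<model_checkpoint>.pth    # weights in /deployment/model_weights/
```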
149 |
150 | The `<topomap_dir>` is the name of the directory in `/deployment/topomaps/images` that has the images corresponding to the nodes in the topological map. The images are ordered by name from 0 to N.
151 |
152 | This command opens up 4 windows:
153 |
154 | 1. `roslaunch vint_locobot.launch`: This launch file opens the usb_cam node for the camera, the joy node for the joystick, and several nodes for the robot’s mobile base.
155 | 2. `python navigate.py --model <model_name> --dir <topomap_dir>`: This python script starts a node that reads in image observations from the `/usb_cam/image_raw` topic, inputs the observations and the map into the model, and publishes actions to the `/waypoint` topic.
156 | 3. `python joy_teleop.py`: This python script starts a node that reads inputs from the joy topic and outputs them on topics that teleoperate the robot’s base.
157 | 4. `python pd_controller.py`: This python script starts a node that reads messages from the `/waypoint` topic (waypoints from the model) and outputs velocities to navigate the robot’s base.
158 |
159 | When the robot has finished navigating, kill the `pd_controller.py` script, and then kill the tmux session. If you want to take control of the robot while it is navigating, the `joy_teleop.py` script allows you to do so with the joystick.
160 |
161 | ## Citing
162 | ```
163 | @article{zeng2025navidiffusor,
164 | title={NaviDiffusor: Cost-Guided Diffusion Model for Visual Navigation},
165 | author={Zeng, Yiming and Ren, Hao and Wang, Shuhang and Huang, Junlong and Cheng, Hui},
166 | journal={arXiv preprint arXiv:2504.10003},
167 | year={2025}
168 | }
169 | ```
170 | ## Acknowledgment
171 | NaviDiffusor is inspired by the contributions of the following works to the open-source community: [NoMaD](https://github.com/robodhruv/visualnav-transformer), [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2) and [ViPlanner](https://github.com/leggedrobotics/viplanner). We thank the authors for sharing their outstanding work.
172 |
--------------------------------------------------------------------------------
/deployment/src/tsdf_cost_map.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023-2024, ETH Zurich (Robotics Systems Lab)
2 | # Author: Pascal Roth
3 | # All rights reserved.
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 |
7 | import math
8 |
9 | # python
10 | import os
11 |
12 | import torch
13 | import numpy as np
14 | import open3d as o3d
15 | from scipy import ndimage
16 | from scipy.ndimage import gaussian_filter, binary_dilation, binary_erosion
17 |
18 | # imperative-cost-map
19 | from costmap_cfg import GeneralCostMapConfig, TsdfCostMapConfig
20 |
21 |
22 | class TsdfCostMap:
23 | """
24 | Cost Map based on geometric information
25 | """
26 |
27 | def __init__(self, cfg_general: GeneralCostMapConfig, cfg_tsdf: TsdfCostMapConfig):
28 | self._cfg_general = cfg_general
29 | self._cfg_tsdf = cfg_tsdf
30 | # set init flag
31 | self.is_map_ready = False
32 | # init point clouds
33 | self.obs_pcd = o3d.geometry.PointCloud()
34 | self.free_pcd = o3d.geometry.PointCloud()
35 | return
36 |
37 | def Pos2Ind(self, points):
38 | start_xy = torch.tensor([self.start_x, self.start_y], dtype=torch.float64, device=points.device).expand(1, 1, -1)
39 | H = (points - start_xy) / self._cfg_general.resolution
40 | mask = torch.logical_and((H > 0).all(axis=2), (H < torch.tensor([self.num_x, self.num_y], device=points.device)[None,None,:]).all(axis=2))
41 | return self.NormInds(H), H
42 |
43 | def NormInds(self, H):
44 | norm_matrix = torch.tensor([self.num_x/2.0, self.num_y/2.0], dtype=torch.float64, device=H.device)
45 | H = (H - norm_matrix) / norm_matrix
46 | return H
47 |
48 | def UpdatePCDwithPs(self, P_obs, P_free, is_downsample=False):
49 | self.obs_pcd.points = o3d.utility.Vector3dVector(P_obs)
50 | self.free_pcd.points = o3d.utility.Vector3dVector(P_free)
51 | if is_downsample:
52 | self.obs_pcd = self.obs_pcd.voxel_down_sample(self._cfg_general.resolution)
53 | self.free_pcd = self.free_pcd.voxel_down_sample(self._cfg_general.resolution * 0.85)
54 |
55 | self.obs_points = np.asarray(self.obs_pcd.points)
56 | self.free_points = np.asarray(self.free_pcd.points)
57 | print("number of obs points: %d, free points: %d" % (self.obs_points.shape[0], self.free_points.shape[0]))
58 |
59 | def ReadPointFromFile(self):
60 | pcd_load = o3d.io.read_point_cloud(os.path.join(self._cfg_general.root_path, self._cfg_general.ply_file))
61 | obs_p, free_p = self.TerrainAnalysis(np.asarray(pcd_load.points))
62 | self.UpdatePCDwithPs(obs_p, free_p, is_downsample=True)
63 | if self._cfg_tsdf.filter_outliers:
64 | obs_p = self.FilterCloud(self.obs_points)
65 | free_p = self.FilterCloud(self.free_points, outlier_filter=False)
66 | self.UpdatePCDwithPs(obs_p, free_p)
67 | self.UpdateMapParams()
68 | return
69 |
70 | def LoadPointCloud(self, pcd):
71 | obs_p, free_p = self.TerrainAnalysis(np.asarray(pcd.points))
72 | self.UpdatePCDwithPs(obs_p, free_p, is_downsample=True)
73 | if self._cfg_tsdf.filter_outliers:
74 | obs_p = self.FilterCloud(self.obs_points, outlier_filter=False)
75 | free_p = self.FilterCloud(self.free_points, outlier_filter=False)
76 | self.UpdatePCDwithPs(obs_p, free_p)
77 | self.UpdateMapParams()
78 | return
79 |
80 | def TerrainAnalysis(self, input_points):
81 | obs_points = np.zeros(input_points.shape)
82 |         free_points = np.zeros(input_points.shape)
83 | obs_idx = 0
84 | free_idx = 0
85 | # naive approach with z values
86 | for p in input_points:
87 | p_height = p[2] + self._cfg_tsdf.offset_z
88 | if (p_height > self._cfg_tsdf.ground_height * 1.0) and (
89 | p_height < self._cfg_tsdf.robot_height * self._cfg_tsdf.robot_height_factor
90 | ): # remove ground and ceiling
91 | obs_points[obs_idx, :] = p
92 | obs_idx = obs_idx + 1
93 | elif p_height < self._cfg_tsdf.ground_height:
94 |                 free_points[free_idx, :] = p
95 | free_idx = free_idx + 1
96 |         return obs_points[:obs_idx, :], free_points[:free_idx, :]
97 |
98 | def UpdateMapParams(self):
99 | if self.obs_points.shape[0] == 0:
100 | print("No points received.")
101 | return
102 | max_x, max_y, _ = np.amax(self.obs_points, axis=0) + self._cfg_general.clear_dist
103 | min_x, min_y, _ = np.amin(self.obs_points, axis=0) - self._cfg_general.clear_dist
104 |
105 | self.num_x = np.ceil((max_x - min_x) / self._cfg_general.resolution / 10).astype(int) * 10
106 | self.num_y = np.ceil((max_y - min_y) / self._cfg_general.resolution / 10).astype(int) * 10
107 | self.start_x = (max_x + min_x) / 2.0 - self.num_x / 2.0 * self._cfg_general.resolution
108 | self.start_y = (max_y + min_y) / 2.0 - self.num_y / 2.0 * self._cfg_general.resolution
109 |
110 | print("tsdf map initialized, with size: %d, %d" % (self.num_x, self.num_y))
111 | self.is_map_ready = True
112 |
113 | def zero_single_isolated_ones(self, arr):
114 | neighbors = ndimage.convolve(arr, weights=np.ones((3, 3)), mode="constant", cval=0)
115 | isolated_ones = (arr == 1) & (neighbors == 1)
116 | arr[isolated_ones] = 0
117 | return arr
118 |
119 | def CreateTSDFMap(self):
120 | if not self.is_map_ready:
121 | raise ValueError("create tsdf map fails, no points received.")
122 | free_map = np.ones([self.num_x, self.num_y])
123 | obs_map = np.zeros([self.num_x, self.num_y])
124 | free_I = self.IndexArrayOfPs(self.free_points)
125 | obs_I = self.IndexArrayOfPs(self.obs_points)
126 | # create free place map
127 | for i in obs_I:
128 | obs_map[i[0], i[1]] = 1.0
129 | obs_map = self.zero_single_isolated_ones(obs_map)
130 | obs_map = gaussian_filter(obs_map, sigma=self._cfg_tsdf.sigma_expand)
131 | for i in free_I:
132 | if 0 < i[0] < self.num_x and 0 < i[1] < self.num_y:
133 |                 # index is guaranteed valid by the bounds check above
134 |                 free_map[i[0], i[1]] = 0
137 | free_map = gaussian_filter(free_map, sigma=self._cfg_tsdf.sigma_expand)
138 |
139 | free_map[free_map < self._cfg_tsdf.free_space_threshold] = 0
140 | # assign obstacles
141 | free_map[obs_map > self._cfg_tsdf.obstacle_threshold] = 1.0
142 | print("occupancy map generation completed.")
143 | # Distance Transform
144 | tsdf_array = ndimage.distance_transform_edt(free_map)
145 |
146 | tsdf_array[tsdf_array > 0.0] = np.log(tsdf_array[tsdf_array > 0.0] + math.e)
147 | tsdf_array = gaussian_filter(tsdf_array, sigma=self._cfg_general.sigma_smooth)
148 |
149 | viz_points = np.concatenate((self.obs_points, self.free_points), axis=0)
150 |
151 | ground_array = np.ones([self.num_x, self.num_y]) * 0.0
152 |
153 | return [tsdf_array, viz_points, ground_array], [
154 | float(self.start_x),
155 | float(self.start_y),
156 | ]
157 |
158 | def IndexArrayOfPs(self, points):
159 | indexes = points[:, :2] - np.array([self.start_x, self.start_y])
160 | indexes = (np.round(indexes / self._cfg_general.resolution)).astype(int)
161 | return indexes
162 |
163 | def FilterCloud(self, points, outlier_filter=True):
164 | # crop points
165 | if any(
166 | [
167 | self._cfg_general.x_max,
168 | self._cfg_general.x_min,
169 | self._cfg_general.y_max,
170 | self._cfg_general.y_min,
171 | ]
172 | ):
173 | points_x_idx_upper = (
174 | (points[:, 0] < self._cfg_general.x_max)
175 | if self._cfg_general.x_max is not None
176 | else np.ones(points.shape[0], dtype=bool)
177 | )
178 | points_x_idx_lower = (
179 | (points[:, 0] > self._cfg_general.x_min)
180 | if self._cfg_general.x_min is not None
181 | else np.ones(points.shape[0], dtype=bool)
182 | )
183 | points_y_idx_upper = (
184 | (points[:, 1] < self._cfg_general.y_max)
185 | if self._cfg_general.y_max is not None
186 | else np.ones(points.shape[0], dtype=bool)
187 | )
188 | points_y_idx_lower = (
189 | (points[:, 1] > self._cfg_general.y_min)
190 | if self._cfg_general.y_min is not None
191 | else np.ones(points.shape[0], dtype=bool)
192 | )
193 | points = points[
194 | np.vstack(
195 | (
196 | points_x_idx_lower,
197 | points_x_idx_upper,
198 | points_y_idx_upper,
199 | points_y_idx_lower,
200 | )
201 | ).all(axis=0)
202 | ]
203 |
204 | if outlier_filter:
205 | # Filter outlier in points
206 | pcd = o3d.geometry.PointCloud()
207 | pcd.points = o3d.utility.Vector3dVector(points)
208 | cl, _ = pcd.remove_statistical_outlier(
209 | nb_neighbors=self._cfg_tsdf.nb_neighbors,
210 | std_ratio=self._cfg_tsdf.std_ratio,
211 | )
212 | points = np.asarray(cl.points)
213 |
214 | return points
215 |
216 | def VizCloud(self, pcd):
217 | o3d.visualization.draw_geometries([pcd]) # visualize point cloud
218 |
219 |
220 | # EoF
221 |
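222 |
223 | # A minimal, hypothetical usage sketch (not part of the original module). It assumes
224 | # LoadPointCloud accepts an Open3D point cloud and populates the free/obstacle points,
225 | # as done in deployment/src/guide.py, and that CostMapConfig provides the two config
226 | # sections the constructor expects; whether the synthetic cloud yields a valid map
227 | # depends on the filtering done inside LoadPointCloud.
228 | if __name__ == "__main__":
229 |     from costmap_cfg import CostMapConfig
230 |
231 |     cfg = CostMapConfig()
232 |     demo_map = TsdfCostMap(cfg.general, cfg.tsdf_cost_map)
233 |     demo_pcd = o3d.geometry.PointCloud()
234 |     demo_pcd.points = o3d.utility.Vector3dVector(np.random.uniform(-5.0, 5.0, size=(2048, 3)))
235 |     demo_map.LoadPointCloud(demo_pcd)
236 |     (tsdf, viz_points, ground), (x0, y0) = demo_map.CreateTSDFMap()
237 |     print("TSDF grid shape:", tsdf.shape, "origin:", (x0, y0))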
--------------------------------------------------------------------------------
/deployment/src/guide.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from scipy.ndimage import gaussian_filter, distance_transform_edt
6 |
7 | import cv2
8 | import matplotlib.pyplot as plt
9 | from depth_anything_v2.dpt import DepthAnythingV2
10 |
11 | import importlib.util
12 | import os
13 | import open3d as o3d
14 | from tsdf_cost_map import TsdfCostMap
15 | from costmap_cfg import CostMapConfig
16 |
17 |
18 | def from_numpy(array: np.ndarray) -> torch.Tensor:
19 | return torch.from_numpy(array).float()
20 |
21 | def check_tensor(tensor, name="tensor"):
22 | if tensor.grad is not None:
23 | print(f"{name} grad: {tensor.grad}")
24 | else:
25 | print(f"{name} grad is None")
26 |
27 | class PathGuide:
28 |
29 | def __init__(self, device, ACTION_STATS, guide_cfgs=None):
30 |         """
31 |         Args: device (torch device for guidance computations), ACTION_STATS (dict with 'min'/'max' arrays used to un-normalize action deltas), guide_cfgs (optional guidance configuration).
32 |         """
33 | self.device = device
34 | self.guide_cfgs = guide_cfgs
35 | self.mse_loss = nn.MSELoss(reduction='mean')
36 | self.l1_loss = nn.L1Loss(reduction='mean')
37 | self.robot_width = 0.6
38 | self.spatial_resolution = 0.1
39 | self.max_distance = 10
40 | self.bev_dist = self.max_distance / self.spatial_resolution
41 | self.delta_min = from_numpy(ACTION_STATS['min']).to(self.device)
42 | self.delta_max = from_numpy(ACTION_STATS['max']).to(self.device)
43 |
44 | # TODO: Pass in parameters instead of constants
45 | self.camera_intrinsics = np.array([[607.99658203125, 0, 642.2532958984375],
46 | [0, 607.862060546875, 366.3480224609375],
47 | [0, 0, 1]])
48 | # robot to camera extrinsic
49 | self.camera_extrinsics = np.array([[0, 0, 1, -0.000],
50 | [-1, 0, 0, -0.000],
51 | [0, -1, 0, -0.042],
52 | [0, 0, 0, 1]])
53 |
54 | # depth anything v2 init
55 | model_configs = {
56 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
57 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
58 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
59 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
60 | }
61 |         encoder = 'vits'  # or 'vitb', 'vitl', 'vitg'
62 | self.model = DepthAnythingV2(**model_configs[encoder])
63 | package_name = 'depth_anything_v2'
64 | package_spec = importlib.util.find_spec(package_name)
65 | if package_spec is None:
66 | raise ImportError(f"Package '{package_name}' not found")
67 | package_path = os.path.dirname(package_spec.origin)
68 | self.model.load_state_dict(torch.load(os.path.join(package_path, f'../checkpoints/depth_anything_v2_{encoder}.pth'), map_location='cpu'))
69 | self.model = self.model.to(self.device).eval()
70 |
71 | # TSDF init
72 | self.tsdf_cfg = CostMapConfig()
73 | self.tsdf_cost_map = TsdfCostMap(self.tsdf_cfg.general, self.tsdf_cfg.tsdf_cost_map)
74 |
75 | def _norm_delta_to_ori_trajs(self, trajs):
76 | delta_tmp = (trajs + 1) / 2
77 | delta_ori = delta_tmp * (self.delta_max - self.delta_min) + self.delta_min
78 | trajs_ori = delta_ori.cumsum(dim=1)
79 | return trajs_ori
80 |
81 |     def goal_cost(self, trajs, goal, scale_factor=None):
82 |         # convert normalized action deltas into cumulative positions in the original scale
83 |         trajs_ori = self._norm_delta_to_ori_trajs(trajs)
84 | if scale_factor is not None:
85 | trajs_ori *= scale_factor
86 | trajs_end_positions = trajs_ori[:, -1, :]
87 |
88 | distances = torch.norm(goal - trajs_end_positions, dim=1)
89 |
90 | gloss = 0.05 * torch.sum(distances)
91 |
92 | if trajs.grad is not None:
93 | trajs.grad.zero_()
94 |
95 | gloss.backward()
96 | return trajs.grad
97 |
98 |     def generate_scale(self, n):
99 |         # linearly increasing weights so later waypoints receive stronger guidance,
100 |         # matching how the gradient is scaled in collision_cost
101 |         scale = torch.linspace(0, 1, steps=n)
102 |
103 |         return scale.to(self.device)
104 |
105 | def depth_to_pcd(self, depth_image, camera_intrinsics, camera_extrinsics, resize_factor=1.0, height_threshold=0.5, max_distance=10.0):
106 | height, width = depth_image.shape
107 | print("height: ", height, "width: ", width)
108 | fx, fy = camera_intrinsics[0, 0] * resize_factor, camera_intrinsics[1, 1] * resize_factor
109 | cx, cy = camera_intrinsics[0, 2] * resize_factor, camera_intrinsics[1, 2] * resize_factor
110 |
111 | x, y = np.meshgrid(np.arange(width), np.arange(height))
112 | z = depth_image.astype(np.float32)
113 |         z_safe = np.where(z == 0, np.nan, z)  # avoid division by zero
114 |         z = 1 / z_safe  # the network predicts inverse depth, so invert to get depth
115 | x = (x - width / 2) * z / fx
116 | y = (y - height / 2) * z / fy
117 | non_ground_mask = (z > 0.5) & (z < max_distance)
118 | x_non_ground = x[non_ground_mask]
119 | y_non_ground = y[non_ground_mask]
120 | z_non_ground = z[non_ground_mask]
121 |
122 | points = np.stack((x_non_ground, y_non_ground, z_non_ground), axis=-1).reshape(-1, 3)
123 |
124 | extrinsics = camera_extrinsics
125 | homogeneous_points = np.hstack((points, np.ones((points.shape[0], 1))))
126 | transformed_points = (extrinsics @ homogeneous_points.T).T[:, :3]
127 |
128 | point_cloud = o3d.geometry.PointCloud()
129 | point_cloud.points = o3d.utility.Vector3dVector(transformed_points)
130 |
131 | return point_cloud
132 |
133 | def add_robot_dim(self, world_ps):
134 | tangent = world_ps[:, 1:, 0:2] - world_ps[:, :-1, 0:2]
135 | tangent = tangent / torch.norm(tangent, dim=2, keepdim=True)
136 | normals = tangent[:, :, [1, 0]] * torch.tensor(
137 | [-1, 1], dtype=torch.float32, device=world_ps.device
138 | )
139 | world_ps_inflated = torch.vstack([world_ps[:, :-1, :]] * 3)
140 | world_ps_inflated[:, :, 0:2] = torch.vstack(
141 | [
142 | world_ps[:, :-1, 0:2] + normals * self.robot_width / 2,
143 | world_ps[:, :-1, 0:2], # center
144 | world_ps[:, :-1, 0:2] - normals * self.robot_width / 2,
145 | ]
146 | )
147 | return world_ps_inflated
148 |
149 | def get_cost_map_via_tsdf(self, img):
150 | original_width, original_height = img.size
151 | resize_factor = 0.25
152 | new_size = (int(original_width * resize_factor), int(original_height * resize_factor))
153 | img = img.resize(new_size)
154 | depth_image = self.model.infer_image(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
155 | pseudo_pcd = self.depth_to_pcd(depth_image, self.camera_intrinsics, self.camera_extrinsics, resize_factor=resize_factor)
156 |
157 | self.tsdf_cost_map.LoadPointCloud(pseudo_pcd)
158 | data, coord = self.tsdf_cost_map.CreateTSDFMap()
159 | if data is None:
160 | self.cost_map = None
161 | else:
162 | self.cost_map = torch.tensor(data[0]).requires_grad_(False).to(self.device)
163 |
164 | def collision_cost(self, trajs, scale_factor=None):
165 | if self.cost_map is None:
166 |             return torch.zeros_like(trajs), None
167 | batch_size, num_p, _ = trajs.shape
168 | trajs_ori = self._norm_delta_to_ori_trajs(trajs)
169 | trajs_ori = self.add_robot_dim(trajs_ori)
170 | if scale_factor is not None:
171 | trajs_ori *= scale_factor
172 | norm_inds, _ = self.tsdf_cost_map.Pos2Ind(trajs_ori)
173 | cost_grid = self.cost_map.T.expand(trajs_ori.shape[0], 1, -1, -1)
174 | oloss_M = F.grid_sample(cost_grid, norm_inds[:, None, :, :], mode='bicubic', padding_mode='border', align_corners=False).squeeze(1).squeeze(1)
175 | oloss_M = oloss_M.to(torch.float32)
176 |
177 | loss = 0.003 * torch.sum(oloss_M, axis=1)
178 | if trajs.grad is not None:
179 | trajs.grad.zero_()
180 | loss.backward(torch.ones_like(loss))
181 | cost_list = loss[1::3]
182 | generate_scale = self.generate_scale(trajs.shape[1])
183 | return generate_scale.unsqueeze(1).unsqueeze(0) * trajs.grad, cost_list
184 |
185 | def get_gradient(self, trajs, alpha=0.3, t=None, goal_pos=None, ACTION_STATS=None, scale_factor=None):
186 | trajs_in = trajs.detach().requires_grad_(True).to(self.device)
187 | if goal_pos is not None:
188 | goal_pos = torch.tensor(goal_pos).to(self.device)
189 | goal_cost = self.goal_cost(trajs_in, goal_pos, scale_factor=scale_factor)
190 | cost = goal_cost
191 | return cost, None
192 | else:
193 | collision_cost, cost_list = self.collision_cost(trajs_in, scale_factor=scale_factor)
194 | cost = collision_cost
195 | return cost, cost_list
196 |
197 |
198 |
199 | class PathOpt:
200 | def __init__(self):
201 | self.traj_cache = None
202 |
203 | def angle_between_vectors(self, vec1, vec2):
204 | dot_product = np.sum(vec1 * vec2, axis=1)
205 | norm_product = np.linalg.norm(vec1, axis=1) * np.linalg.norm(vec2, axis=1)
206 | angle = np.arccos(dot_product / norm_product)
207 | return np.degrees(angle)
208 |
209 | def select_trajectory(self, trajs, l=2, angle_threshold=45, collision_min_idx=None):
210 | if self.traj_cache is None or len(self.traj_cache) <= l:
211 |             idx = collision_min_idx if collision_min_idx is not None else 0
212 | self.traj_cache = trajs[idx]
213 | else:
214 | directions = trajs[:, l, :]
215 |
216 | historical_directions = self.traj_cache[l]
217 |
218 | historical_directions = np.broadcast_to(historical_directions, directions.shape)
219 |
220 | angle_diffs = self.angle_between_vectors(directions, historical_directions)
221 |
222 | sorted_indices = np.argsort(angle_diffs)
223 |
224 | if angle_diffs[sorted_indices[0]] > angle_threshold:
225 | idx = 0
226 | self.traj_cache = trajs[idx]
227 | else:
228 |                 idx = sorted_indices[0]
229 | self.traj_cache = trajs[idx]
230 |
231 | return trajs[idx], idx
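232 |
233 |
234 | # A minimal, hypothetical usage sketch (not part of the original module). The
235 | # ACTION_STATS bounds below are made up; the real values come from the deployment
236 | # model config, and the DepthAnythingV2 checkpoint loaded in PathGuide.__init__
237 | # must be available on disk.
238 | if __name__ == "__main__":
239 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
240 |     action_stats = {"min": np.array([-0.1, -0.1]), "max": np.array([0.1, 0.1])}
241 |     guide = PathGuide(device, action_stats)
242 |
243 |     # 8 candidate trajectories of 8 normalized (dx, dy) deltas each, created on device
244 |     trajs = torch.zeros(8, 8, 2, device=device)
245 |     grad, _ = guide.get_gradient(trajs, goal_pos=np.array([2.0, 0.0]), scale_factor=1.0)
246 |     print("goal-guidance gradient shape:", tuple(grad.shape))
247 |     # Collision guidance additionally needs a cost map built from the current RGB
248 |     # observation first, e.g. guide.get_cost_map_via_tsdf(pil_image), before calling
249 |     # get_gradient(trajs) without a goal.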
--------------------------------------------------------------------------------
/train/vint_train/process_data/process_data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import io
3 | import os
4 | import rosbag
5 | from PIL import Image
6 | import cv2
7 | from typing import Any, Tuple, List, Dict, Union
8 | import torchvision.transforms.functional as TF
9 |
10 | IMAGE_SIZE = (160, 120)
11 | IMAGE_ASPECT_RATIO = 4 / 3
12 |
13 |
14 | def process_images(im_list: List, img_process_func) -> List:
15 | """
16 | Process image data from a topic that publishes ros images into a list of PIL images
17 | """
18 | images = []
19 | for img_msg in im_list:
20 | img = img_process_func(img_msg)
21 | images.append(img)
22 | return images
23 |
24 |
25 | def process_tartan_img(msg) -> Image:
26 | """
27 | Process image data from a topic that publishes sensor_msgs/Image to a PIL image for the tartan_drive dataset
28 | """
29 | img = ros_to_numpy(msg, output_resolution=IMAGE_SIZE) * 255
30 | img = img.astype(np.uint8)
31 | # reverse the axis order to get the image in the right orientation
32 | img = np.moveaxis(img, 0, -1)
33 | # convert rgb to bgr
34 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
35 | img = Image.fromarray(img)
36 | return img
37 |
38 |
39 | def process_locobot_img(msg) -> Image:
40 | """
41 | Process image data from a topic that publishes sensor_msgs/Image to a PIL image for the locobot dataset
42 | """
43 | img = np.frombuffer(msg.data, dtype=np.uint8).reshape(
44 | msg.height, msg.width, -1)
45 | pil_image = Image.fromarray(img)
46 | return pil_image
47 |
48 |
49 | def process_scand_img(msg) -> Image:
50 | """
51 | Process image data from a topic that publishes sensor_msgs/CompressedImage to a PIL image for the scand dataset
52 | """
53 | # convert sensor_msgs/CompressedImage to PIL image
54 | img = Image.open(io.BytesIO(msg.data))
55 | # center crop image to 4:3 aspect ratio
56 | w, h = img.size
57 | img = TF.center_crop(
58 | img, (h, int(h * IMAGE_ASPECT_RATIO))
59 | ) # crop to the right ratio
60 | # resize image to IMAGE_SIZE
61 | img = img.resize(IMAGE_SIZE)
62 | return img
63 |
64 |
65 | ############## Add custom image processing functions here #############
66 |
67 | def process_sacson_img(msg) -> Image:
68 |     np_arr = np.frombuffer(msg.data, np.uint8)
69 | image_np = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
70 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
71 | pil_image = Image.fromarray(image_np)
72 | return pil_image
73 |
74 |
75 | #######################################################################
76 |
77 |
78 | def process_odom(
79 | odom_list: List,
80 | odom_process_func: Any,
81 | ang_offset: float = 0.0,
82 | ) -> Dict[str, np.ndarray]:
83 | """
84 | Process odom data from a topic that publishes nav_msgs/Odometry into position and yaw
85 | """
86 | xys = []
87 | yaws = []
88 | for odom_msg in odom_list:
89 | xy, yaw = odom_process_func(odom_msg, ang_offset)
90 | xys.append(xy)
91 | yaws.append(yaw)
92 | return {"position": np.array(xys), "yaw": np.array(yaws)}
93 |
94 |
95 | def nav_to_xy_yaw(odom_msg, ang_offset: float) -> Tuple[List[float], float]:
96 | """
97 | Process odom data from a topic that publishes nav_msgs/Odometry into position
98 | """
99 |
100 | position = odom_msg.pose.pose.position
101 | orientation = odom_msg.pose.pose.orientation
102 | yaw = (
103 | quat_to_yaw(orientation.x, orientation.y, orientation.z, orientation.w)
104 | + ang_offset
105 | )
106 | return [position.x, position.y], yaw
107 |
108 |
109 | ############ Add custom odometry processing functions here ############
110 |
111 |
112 | #######################################################################
113 |
114 |
115 | def get_images_and_odom(
116 | bag: rosbag.Bag,
117 |     imtopics: Union[List[str], str],
118 |     odomtopics: Union[List[str], str],
119 | img_process_func: Any,
120 | odom_process_func: Any,
121 | rate: float = 4.0,
122 | ang_offset: float = 0.0,
123 | ):
124 | """
125 | Get image and odom data from a bag file
126 |
127 | Args:
128 | bag (rosbag.Bag): bag file
129 | imtopics (list[str] or str): topic name(s) for image data
130 | odomtopics (list[str] or str): topic name(s) for odom data
131 | img_process_func (Any): function to process image data
132 | odom_process_func (Any): function to process odom data
133 | rate (float, optional): rate to sample data. Defaults to 4.0.
134 | ang_offset (float, optional): angle offset to add to odom data. Defaults to 0.0.
135 | Returns:
136 | img_data (list): list of PIL images
137 | traj_data (list): list of odom data
138 | """
139 | # check if bag has both topics
140 | odomtopic = None
141 | imtopic = None
142 | if type(imtopics) == str:
143 | imtopic = imtopics
144 | else:
145 | for imt in imtopics:
146 | if bag.get_message_count(imt) > 0:
147 | imtopic = imt
148 | break
149 | if type(odomtopics) == str:
150 | odomtopic = odomtopics
151 | else:
152 | for ot in odomtopics:
153 | if bag.get_message_count(ot) > 0:
154 | odomtopic = ot
155 | break
156 | if not (imtopic and odomtopic):
157 | # bag doesn't have both topics
158 | return None, None
159 |
160 | synced_imdata = []
161 | synced_odomdata = []
162 | # get start time of bag in seconds
163 | currtime = bag.get_start_time()
164 |
165 | curr_imdata = None
166 | curr_odomdata = None
167 |
168 | for topic, msg, t in bag.read_messages(topics=[imtopic, odomtopic]):
169 | if topic == imtopic:
170 | curr_imdata = msg
171 | elif topic == odomtopic:
172 | curr_odomdata = msg
173 | if (t.to_sec() - currtime) >= 1.0 / rate:
174 | if curr_imdata is not None and curr_odomdata is not None:
175 | synced_imdata.append(curr_imdata)
176 | synced_odomdata.append(curr_odomdata)
177 | currtime = t.to_sec()
178 |
179 | img_data = process_images(synced_imdata, img_process_func)
180 | traj_data = process_odom(
181 | synced_odomdata,
182 | odom_process_func,
183 | ang_offset=ang_offset,
184 | )
185 |
186 | return img_data, traj_data
187 |
188 |
189 | def is_backwards(
190 | pos1: np.ndarray, yaw1: float, pos2: np.ndarray, eps: float = 1e-5
191 | ) -> bool:
192 | """
193 | Check if the trajectory is going backwards given the position and yaw of two points
194 | Args:
195 |         pos1: position of the first point
196 |         yaw1: yaw of the first point; pos2: position of the second point; eps: forward-motion tolerance
197 | """
198 | dx, dy = pos2 - pos1
199 | return dx * np.cos(yaw1) + dy * np.sin(yaw1) < eps
200 |
201 |
202 | # cut out non-positive velocity segments of the trajectory
203 | def filter_backwards(
204 | img_list: List[Image.Image],
205 | traj_data: Dict[str, np.ndarray],
206 | start_slack: int = 0,
207 | end_slack: int = 0,
208 | ) -> List[Tuple[List, Dict]]:
209 | """
210 | Cut out non-positive velocity segments of the trajectory
211 | Args:
212 |         img_list: list of images
213 |         traj_data: dictionary of position and yaw data
214 |         start_slack: number of points to ignore at the start of the trajectory
215 |         end_slack: number of points to ignore at the end of the trajectory
216 |
217 | Returns:
218 |         cut_trajs: list of (image list, trajectory dict) pairs, one per
219 |             forward-moving segment
220 | """
221 | traj_pos = traj_data["position"]
222 | traj_yaws = traj_data["yaw"]
223 | cut_trajs = []
224 | start = True
225 |
226 | def process_pair(traj_pair: list) -> Tuple[List, Dict]:
227 | new_img_list, new_traj_data = zip(*traj_pair)
228 | new_traj_data = np.array(new_traj_data)
229 | new_traj_pos = new_traj_data[:, :2]
230 | new_traj_yaws = new_traj_data[:, 2]
231 | return (new_img_list, {"position": new_traj_pos, "yaw": new_traj_yaws})
232 |
233 | for i in range(max(start_slack, 1), len(traj_pos) - end_slack):
234 | pos1 = traj_pos[i - 1]
235 | yaw1 = traj_yaws[i - 1]
236 | pos2 = traj_pos[i]
237 | if not is_backwards(pos1, yaw1, pos2):
238 | if start:
239 | new_traj_pairs = [
240 | (img_list[i - 1], [*traj_pos[i - 1], traj_yaws[i - 1]])
241 | ]
242 | start = False
243 | elif i == len(traj_pos) - end_slack - 1:
244 | cut_trajs.append(process_pair(new_traj_pairs))
245 | else:
246 | new_traj_pairs.append(
247 | (img_list[i - 1], [*traj_pos[i - 1], traj_yaws[i - 1]])
248 | )
249 | elif not start:
250 | cut_trajs.append(process_pair(new_traj_pairs))
251 | start = True
252 | return cut_trajs
253 |
254 |
255 | def quat_to_yaw(
256 | x: np.ndarray,
257 | y: np.ndarray,
258 | z: np.ndarray,
259 | w: np.ndarray,
260 | ) -> np.ndarray:
261 | """
262 | Convert a batch quaternion into a yaw angle
263 | yaw is rotation around z in radians (counterclockwise)
264 | """
265 | t3 = 2.0 * (w * z + x * y)
266 | t4 = 1.0 - 2.0 * (y * y + z * z)
267 | yaw = np.arctan2(t3, t4)
268 | return yaw
269 |
270 |
271 | def ros_to_numpy(
272 | msg, nchannels=3, empty_value=None, output_resolution=None, aggregate="none"
273 | ):
274 | """
275 | Convert a ROS image message to a numpy array
276 | """
277 | if output_resolution is None:
278 | output_resolution = (msg.width, msg.height)
279 |
280 | is_rgb = "8" in msg.encoding
281 | if is_rgb:
282 | data = np.frombuffer(msg.data, dtype=np.uint8).copy()
283 | else:
284 | data = np.frombuffer(msg.data, dtype=np.float32).copy()
285 |
286 | data = data.reshape(msg.height, msg.width, nchannels)
287 |
288 | if empty_value:
289 | mask = np.isclose(abs(data), empty_value)
290 | fill_value = np.percentile(data[~mask], 99)
291 | data[mask] = fill_value
292 |
293 | data = cv2.resize(
294 | data,
295 | dsize=(output_resolution[0], output_resolution[1]),
296 | interpolation=cv2.INTER_AREA,
297 | )
298 |
299 | if aggregate == "littleendian":
300 | data = sum([data[:, :, i] * (256**i) for i in range(nchannels)])
301 | elif aggregate == "bigendian":
302 | data = sum([data[:, :, -(i + 1)] * (256**i) for i in range(nchannels)])
303 |
304 | if len(data.shape) == 2:
305 | data = np.expand_dims(data, axis=0)
306 | else:
307 | data = np.moveaxis(data, 2, 0) # Switch to channels-first
308 |
309 | if is_rgb:
310 | data = data.astype(np.float32) / (
311 | 255.0 if aggregate == "none" else 255.0**nchannels
312 | )
313 |
314 | return data
315 |
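316 | # A minimal, hypothetical usage sketch (not part of the original module). The bag
317 | # path and topic names are placeholders; pick the processing functions that match
318 | # your camera and odometry topics.
319 | if __name__ == "__main__":
320 |     bag = rosbag.Bag("example.bag")
321 |     img_data, traj_data = get_images_and_odom(
322 |         bag,
323 |         imtopics=["/camera/image_raw/compressed"],
324 |         odomtopics=["/odom"],
325 |         img_process_func=process_scand_img,
326 |         odom_process_func=nav_to_xy_yaw,
327 |         rate=4.0,
328 |     )
329 |     if img_data is not None:
330 |         cut_trajs = filter_backwards(img_data, traj_data)
331 |         print(f"Kept {len(cut_trajs)} forward-moving segments")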
--------------------------------------------------------------------------------
/train/vint_train/training/train_eval_loop.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | import os
3 | import numpy as np
4 | from typing import List, Optional, Dict
5 | from prettytable import PrettyTable
6 |
7 | from vint_train.training.train_utils import train, evaluate
8 | from vint_train.training.train_utils import train_nomad, evaluate_nomad
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import DataLoader
14 | from torch.optim import Adam
15 | from torchvision import transforms
16 |
17 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
18 | from diffusers.training_utils import EMAModel
19 |
20 | def train_eval_loop(
21 | train_model: bool,
22 | model: nn.Module,
23 | optimizer: Adam,
24 | scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
25 | dataloader: DataLoader,
26 | test_dataloaders: Dict[str, DataLoader],
27 | transform: transforms,
28 | epochs: int,
29 | device: torch.device,
30 | project_folder: str,
31 | normalized: bool,
32 | wandb_log_freq: int = 10,
33 | print_log_freq: int = 100,
34 | image_log_freq: int = 1000,
35 | num_images_log: int = 8,
36 | current_epoch: int = 0,
37 | alpha: float = 0.5,
38 | learn_angle: bool = True,
39 | use_wandb: bool = True,
40 | eval_fraction: float = 0.25,
41 | ):
42 | """
43 | Train and evaluate the model for several epochs (vint or gnm models)
44 |
45 | Args:
46 | train_model: whether to train the model or not
47 | model: model to train
48 | optimizer: optimizer to use
49 | scheduler: learning rate scheduler to use
50 | dataloader: dataloader for train dataset
51 | test_dataloaders: dict of dataloaders for testing
52 | transform: transform to apply to images
53 | epochs: number of epochs to train
54 | device: device to train on
55 | project_folder: folder to save checkpoints and logs
56 | normalized: whether to normalize the action space or not
57 | wandb_log_freq: frequency of logging to wandb
58 | print_log_freq: frequency of printing to console
59 | image_log_freq: frequency of logging images to wandb
60 | num_images_log: number of images to log to wandb
61 | current_epoch: epoch to start training from
62 | alpha: tradeoff between distance and action loss
63 | learn_angle: whether to learn the angle or not
64 | use_wandb: whether to log to wandb or not
65 | eval_fraction: fraction of training data to use for evaluation
66 | """
67 | assert 0 <= alpha <= 1
68 | latest_path = os.path.join(project_folder, f"latest.pth")
69 |
70 | for epoch in range(current_epoch, current_epoch + epochs):
71 | if train_model:
72 | print(
73 | f"Start ViNT Training Epoch {epoch}/{current_epoch + epochs - 1}"
74 | )
75 | train(
76 | model=model,
77 | optimizer=optimizer,
78 | dataloader=dataloader,
79 | transform=transform,
80 | device=device,
81 | project_folder=project_folder,
82 | normalized=normalized,
83 | epoch=epoch,
84 | alpha=alpha,
85 | learn_angle=learn_angle,
86 | print_log_freq=print_log_freq,
87 | wandb_log_freq=wandb_log_freq,
88 | image_log_freq=image_log_freq,
89 | num_images_log=num_images_log,
90 | use_wandb=use_wandb,
91 | )
92 |
93 | avg_total_test_loss = []
94 | for dataset_type in test_dataloaders:
95 | print(
96 | f"Start {dataset_type} ViNT Testing Epoch {epoch}/{current_epoch + epochs - 1}"
97 | )
98 | loader = test_dataloaders[dataset_type]
99 |
100 | test_dist_loss, test_action_loss, total_eval_loss = evaluate(
101 | eval_type=dataset_type,
102 | model=model,
103 | dataloader=loader,
104 | transform=transform,
105 | device=device,
106 | project_folder=project_folder,
107 | normalized=normalized,
108 | epoch=epoch,
109 | alpha=alpha,
110 | learn_angle=learn_angle,
111 | num_images_log=num_images_log,
112 | use_wandb=use_wandb,
113 | eval_fraction=eval_fraction,
114 | )
115 |
116 | avg_total_test_loss.append(total_eval_loss)
117 |
118 | checkpoint = {
119 | "epoch": epoch,
120 | "model": model,
121 | "optimizer": optimizer,
122 | "avg_total_test_loss": np.mean(avg_total_test_loss),
123 | "scheduler": scheduler
124 | }
125 | # log average eval loss
126 | wandb.log({}, commit=False)
127 |
128 | if scheduler is not None:
129 | # scheduler calls based on the type of scheduler
130 | if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
131 | scheduler.step(np.mean(avg_total_test_loss))
132 | else:
133 | scheduler.step()
134 | wandb.log({
135 | "avg_total_test_loss": np.mean(avg_total_test_loss),
136 | "lr": optimizer.param_groups[0]["lr"],
137 | }, commit=False)
138 |
139 | numbered_path = os.path.join(project_folder, f"{epoch}.pth")
140 | torch.save(checkpoint, latest_path)
141 | torch.save(checkpoint, numbered_path) # keep track of model at every epoch
142 |
143 | # Flush the last set of eval logs
144 | wandb.log({})
145 | print()
146 |
147 | def train_eval_loop_nomad(
148 | train_model: bool,
149 | model: nn.Module,
150 | optimizer: Adam,
151 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
152 | noise_scheduler: DDPMScheduler,
153 | train_loader: DataLoader,
154 | test_dataloaders: Dict[str, DataLoader],
155 | transform: transforms,
156 | goal_mask_prob: float,
157 | epochs: int,
158 | device: torch.device,
159 | project_folder: str,
160 | print_log_freq: int = 100,
161 | wandb_log_freq: int = 10,
162 | image_log_freq: int = 1000,
163 | num_images_log: int = 8,
164 | current_epoch: int = 0,
165 | alpha: float = 1e-4,
166 | use_wandb: bool = True,
167 | eval_fraction: float = 0.25,
168 | eval_freq: int = 1,
169 | ):
170 | """
171 |     Train and evaluate the model for several epochs (nomad models)
172 |
173 | Args:
174 | model: model to train
175 | optimizer: optimizer to use
176 | lr_scheduler: learning rate scheduler to use
177 | noise_scheduler: noise scheduler to use
178 | dataloader: dataloader for train dataset
179 | test_dataloaders: dict of dataloaders for testing
180 | transform: transform to apply to images
181 | goal_mask_prob: probability of masking the goal token during training
182 | epochs: number of epochs to train
183 | device: device to train on
184 | project_folder: folder to save checkpoints and logs
185 | wandb_log_freq: frequency of logging to wandb
186 | print_log_freq: frequency of printing to console
187 | image_log_freq: frequency of logging images to wandb
188 | num_images_log: number of images to log to wandb
189 | current_epoch: epoch to start training from
190 | alpha: tradeoff between distance and action loss
191 | use_wandb: whether to log to wandb or not
192 | eval_fraction: fraction of training data to use for evaluation
193 | eval_freq: frequency of evaluation
194 | """
195 | latest_path = os.path.join(project_folder, f"latest.pth")
196 |     ema_model = EMAModel(model=model, power=0.75)
197 |
198 | for epoch in range(current_epoch, current_epoch + epochs):
199 | if train_model:
200 | print(
201 | f"Start ViNT DP Training Epoch {epoch}/{current_epoch + epochs - 1}"
202 | )
203 | train_nomad(
204 | model=model,
205 | ema_model=ema_model,
206 | optimizer=optimizer,
207 | dataloader=train_loader,
208 | transform=transform,
209 | device=device,
210 | noise_scheduler=noise_scheduler,
211 | goal_mask_prob=goal_mask_prob,
212 | project_folder=project_folder,
213 | epoch=epoch,
214 | print_log_freq=print_log_freq,
215 | wandb_log_freq=wandb_log_freq,
216 | image_log_freq=image_log_freq,
217 | num_images_log=num_images_log,
218 | use_wandb=use_wandb,
219 | alpha=alpha,
220 | )
221 | lr_scheduler.step()
222 |
223 |             numbered_path = os.path.join(project_folder, f"ema_{epoch}.pth")
224 |             torch.save(ema_model.averaged_model.state_dict(), numbered_path)
225 |             torch.save(ema_model.averaged_model.state_dict(), os.path.join(project_folder, "ema_latest.pth"))
226 |             print(f"Saved EMA model to {numbered_path}")
227 |
228 | numbered_path = os.path.join(project_folder, f"{epoch}.pth")
229 | torch.save(model.state_dict(), numbered_path)
230 | torch.save(model.state_dict(), latest_path)
231 | print(f"Saved model to {numbered_path}")
232 |
233 | # save optimizer
234 | numbered_path = os.path.join(project_folder, f"optimizer_{epoch}.pth")
235 | latest_optimizer_path = os.path.join(project_folder, f"optimizer_latest.pth")
236 | torch.save(optimizer.state_dict(), latest_optimizer_path)
237 |
238 | # save scheduler
239 | numbered_path = os.path.join(project_folder, f"scheduler_{epoch}.pth")
240 | latest_scheduler_path = os.path.join(project_folder, f"scheduler_latest.pth")
241 | torch.save(lr_scheduler.state_dict(), latest_scheduler_path)
242 |
243 |
244 | if (epoch + 1) % eval_freq == 0:
245 | for dataset_type in test_dataloaders:
246 | print(
247 | f"Start {dataset_type} ViNT DP Testing Epoch {epoch}/{current_epoch + epochs - 1}"
248 | )
249 | loader = test_dataloaders[dataset_type]
250 | evaluate_nomad(
251 | eval_type=dataset_type,
252 | ema_model=ema_model,
253 | dataloader=loader,
254 | transform=transform,
255 | device=device,
256 | noise_scheduler=noise_scheduler,
257 | goal_mask_prob=goal_mask_prob,
258 | project_folder=project_folder,
259 | epoch=epoch,
260 | print_log_freq=print_log_freq,
261 | num_images_log=num_images_log,
262 | wandb_log_freq=wandb_log_freq,
263 | use_wandb=use_wandb,
264 | eval_fraction=eval_fraction,
265 | )
266 | wandb.log({
267 | "lr": optimizer.param_groups[0]["lr"],
268 | }, commit=False)
269 |
270 | if lr_scheduler is not None:
271 | lr_scheduler.step()
272 |
273 | # log average eval loss
274 | wandb.log({}, commit=False)
275 |
276 | wandb.log({
277 | "lr": optimizer.param_groups[0]["lr"],
278 | }, commit=False)
279 |
280 |
281 | # Flush the last set of eval logs
282 | wandb.log({})
283 | print()
284 |
285 | def load_model(model, model_type, checkpoint: dict) -> None:
286 | """Load model from checkpoint."""
287 | if model_type == "nomad":
288 | state_dict = checkpoint
289 | model.load_state_dict(state_dict, strict=False)
290 | else:
291 | loaded_model = checkpoint["model"]
292 | try:
293 | state_dict = loaded_model.module.state_dict()
294 | model.load_state_dict(state_dict, strict=False)
295 | except AttributeError as e:
296 | state_dict = loaded_model.state_dict()
297 | model.load_state_dict(state_dict, strict=False)
298 |
299 |
300 | def load_ema_model(ema_model, state_dict: dict) -> None:
301 | """Load model from checkpoint."""
302 | ema_model.load_state_dict(state_dict)
303 |
304 |
305 | def count_parameters(model):
306 | table = PrettyTable(["Modules", "Parameters"])
307 | total_params = 0
308 | for name, parameter in model.named_parameters():
309 | if not parameter.requires_grad: continue
310 | params = parameter.numel()
311 | table.add_row([name, params])
312 |         total_params += params
313 | # print(table)
314 | print(f"Total Trainable Params: {total_params/1e6:.2f}M")
315 | return total_params
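316 |
317 |
318 | # A minimal sketch (not part of the original module) exercising the helpers above on a
319 | # toy model; the real models and checkpoints are built and saved in train.py.
320 | if __name__ == "__main__":
321 |     toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
322 |     count_parameters(toy)
323 |     # non-nomad checkpoints store the full model object under the "model" key
324 |     load_model(toy, "gnm", {"model": toy, "epoch": 0})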
--------------------------------------------------------------------------------
/train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 | import argparse
4 | import numpy as np
5 | import yaml
6 | import time
7 | import pdb
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils.data import DataLoader, ConcatDataset
12 | from torch.optim import Adam, AdamW
13 | from torchvision import transforms
14 | import torch.backends.cudnn as cudnn
15 | from warmup_scheduler import GradualWarmupScheduler
16 |
17 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
18 | from diffusers.optimization import get_scheduler
19 |
20 | """
21 | IMPORT YOUR MODEL HERE
22 | """
23 | from vint_train.models.nomad.nomad import NoMaD, DenseNetwork
24 | from vint_train.models.nomad.nomad_vint import NoMaD_ViNT, replace_bn_with_gn
25 | from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
26 |
27 |
28 | from vint_train.data.vint_dataset import ViNT_Dataset
29 | from vint_train.training.train_eval_loop import (
30 | train_eval_loop,
31 | train_eval_loop_nomad,
32 | load_model,
33 | )
34 |
35 |
36 | def main(config):
37 | assert config["distance"]["min_dist_cat"] < config["distance"]["max_dist_cat"]
38 | assert config["action"]["min_dist_cat"] < config["action"]["max_dist_cat"]
39 |
40 | if torch.cuda.is_available():
41 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
42 | if "gpu_ids" not in config:
43 | config["gpu_ids"] = [0]
44 | elif type(config["gpu_ids"]) == int:
45 | config["gpu_ids"] = [config["gpu_ids"]]
46 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
47 | [str(x) for x in config["gpu_ids"]]
48 | )
49 | print("Using cuda devices:", os.environ["CUDA_VISIBLE_DEVICES"])
50 | else:
51 | print("Using cpu")
52 |
53 | first_gpu_id = config["gpu_ids"][0]
54 | device = torch.device(
55 | f"cuda:{first_gpu_id}" if torch.cuda.is_available() else "cpu"
56 | )
57 |
58 | if "seed" in config:
59 | np.random.seed(config["seed"])
60 | torch.manual_seed(config["seed"])
61 | cudnn.deterministic = True
62 |
63 | cudnn.benchmark = True # good if input sizes don't vary
64 | transform = ([
65 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
66 | ])
67 | transform = transforms.Compose(transform)
68 |
69 | # Load the data
70 | train_dataset = []
71 | test_dataloaders = {}
72 |
73 | if "context_type" not in config:
74 | config["context_type"] = "temporal"
75 |
76 | if "clip_goals" not in config:
77 | config["clip_goals"] = False
78 |
79 | for dataset_name in config["datasets"]:
80 | data_config = config["datasets"][dataset_name]
81 | if "negative_mining" not in data_config:
82 | data_config["negative_mining"] = True
83 | if "goals_per_obs" not in data_config:
84 | data_config["goals_per_obs"] = 1
85 | if "end_slack" not in data_config:
86 | data_config["end_slack"] = 0
87 | if "waypoint_spacing" not in data_config:
88 | data_config["waypoint_spacing"] = 1
89 |
90 | for data_split_type in ["train", "test"]:
91 | if data_split_type in data_config:
92 | dataset = ViNT_Dataset(
93 | data_folder=data_config["data_folder"],
94 | data_split_folder=data_config[data_split_type],
95 | dataset_name=dataset_name,
96 | image_size=config["image_size"],
97 | waypoint_spacing=data_config["waypoint_spacing"],
98 | min_dist_cat=config["distance"]["min_dist_cat"],
99 | max_dist_cat=config["distance"]["max_dist_cat"],
100 | min_action_distance=config["action"]["min_dist_cat"],
101 | max_action_distance=config["action"]["max_dist_cat"],
102 | negative_mining=data_config["negative_mining"],
103 | len_traj_pred=config["len_traj_pred"],
104 | learn_angle=config["learn_angle"],
105 | context_size=config["context_size"],
106 | context_type=config["context_type"],
107 | end_slack=data_config["end_slack"],
108 | goals_per_obs=data_config["goals_per_obs"],
109 | normalize=config["normalize"],
110 | goal_type=config["goal_type"],
111 | )
112 | if data_split_type == "train":
113 | train_dataset.append(dataset)
114 | else:
115 | dataset_type = f"{dataset_name}_{data_split_type}"
116 | if dataset_type not in test_dataloaders:
117 | test_dataloaders[dataset_type] = {}
118 | test_dataloaders[dataset_type] = dataset
119 |
120 | # combine all the datasets from different robots
121 | train_dataset = ConcatDataset(train_dataset)
122 |
123 | train_loader = DataLoader(
124 | train_dataset,
125 | batch_size=config["batch_size"],
126 | shuffle=True,
127 | num_workers=config["num_workers"],
128 | drop_last=False,
129 | persistent_workers=True,
130 | )
131 |
132 | if "eval_batch_size" not in config:
133 | config["eval_batch_size"] = config["batch_size"]
134 |
135 | for dataset_type, dataset in test_dataloaders.items():
136 | test_dataloaders[dataset_type] = DataLoader(
137 | dataset,
138 | batch_size=config["eval_batch_size"],
139 | shuffle=True,
140 | num_workers=0,
141 | drop_last=False,
142 | )
143 |
144 | # Create the model
145 | if config["model_type"] == "nomad":
146 | if config["vision_encoder"] == "nomad_vint":
147 | vision_encoder = NoMaD_ViNT(
148 | obs_encoding_size=config["encoding_size"],
149 | context_size=config["context_size"],
150 | mha_num_attention_heads=config["mha_num_attention_heads"],
151 | mha_num_attention_layers=config["mha_num_attention_layers"],
152 | mha_ff_dim_factor=config["mha_ff_dim_factor"],
153 | )
154 | vision_encoder = replace_bn_with_gn(vision_encoder)
155 | elif config["vision_encoder"] == "vib":
156 | vision_encoder = ViB(
157 | obs_encoding_size=config["encoding_size"],
158 | context_size=config["context_size"],
159 | mha_num_attention_heads=config["mha_num_attention_heads"],
160 | mha_num_attention_layers=config["mha_num_attention_layers"],
161 | mha_ff_dim_factor=config["mha_ff_dim_factor"],
162 | )
163 | vision_encoder = replace_bn_with_gn(vision_encoder)
164 | elif config["vision_encoder"] == "vit":
165 | vision_encoder = ViT(
166 | obs_encoding_size=config["encoding_size"],
167 | context_size=config["context_size"],
168 | image_size=config["image_size"],
169 | patch_size=config["patch_size"],
170 | mha_num_attention_heads=config["mha_num_attention_heads"],
171 | mha_num_attention_layers=config["mha_num_attention_layers"],
172 | )
173 | vision_encoder = replace_bn_with_gn(vision_encoder)
174 | else:
175 | raise ValueError(f"Vision encoder {config['vision_encoder']} not supported")
176 |
177 | noise_pred_net = ConditionalUnet1D(
178 | input_dim=2,
179 | global_cond_dim=config["encoding_size"],
180 | down_dims=config["down_dims"],
181 | cond_predict_scale=config["cond_predict_scale"],
182 | )
183 | dist_pred_network = DenseNetwork(embedding_dim=config["encoding_size"])
184 |
185 | model = NoMaD(
186 | vision_encoder=vision_encoder,
187 | noise_pred_net=noise_pred_net,
188 | dist_pred_net=dist_pred_network,
189 | )
190 |
191 | noise_scheduler = DDPMScheduler(
192 | num_train_timesteps=config["num_diffusion_iters"],
193 | beta_schedule='squaredcos_cap_v2',
194 | clip_sample=True,
195 | prediction_type='epsilon'
196 | )
197 | else:
198 |         raise ValueError(f"Model {config['model_type']} not supported")
199 |
200 | if config["clipping"]:
201 | print("Clipping gradients to", config["max_norm"])
202 | for p in model.parameters():
203 | if not p.requires_grad:
204 | continue
205 | p.register_hook(
206 | lambda grad: torch.clamp(
207 | grad, -1 * config["max_norm"], config["max_norm"]
208 | )
209 | )
210 |
211 | lr = float(config["lr"])
212 | config["optimizer"] = config["optimizer"].lower()
213 | if config["optimizer"] == "adam":
214 | optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.98))
215 | elif config["optimizer"] == "adamw":
216 | optimizer = AdamW(model.parameters(), lr=lr)
217 | elif config["optimizer"] == "sgd":
218 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
219 | else:
220 | raise ValueError(f"Optimizer {config['optimizer']} not supported")
221 |
222 | scheduler = None
223 | if config["scheduler"] is not None:
224 | config["scheduler"] = config["scheduler"].lower()
225 | if config["scheduler"] == "cosine":
226 | print("Using cosine annealing with T_max", config["epochs"])
227 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
228 | optimizer, T_max=config["epochs"]
229 | )
230 | elif config["scheduler"] == "cyclic":
231 | print("Using cyclic LR with cycle", config["cyclic_period"])
232 | scheduler = torch.optim.lr_scheduler.CyclicLR(
233 | optimizer,
234 | base_lr=lr / 10.,
235 | max_lr=lr,
236 | step_size_up=config["cyclic_period"] // 2,
237 | cycle_momentum=False,
238 | )
239 | elif config["scheduler"] == "plateau":
240 | print("Using ReduceLROnPlateau")
241 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
242 | optimizer,
243 | factor=config["plateau_factor"],
244 | patience=config["plateau_patience"],
245 | verbose=True,
246 | )
247 | else:
248 | raise ValueError(f"Scheduler {config['scheduler']} not supported")
249 |
250 | if config["warmup"]:
251 | print("Using warmup scheduler")
252 | scheduler = GradualWarmupScheduler(
253 | optimizer,
254 | multiplier=1,
255 | total_epoch=config["warmup_epochs"],
256 | after_scheduler=scheduler,
257 | )
258 |
259 | current_epoch = 0
260 | if "load_run" in config:
261 | load_project_folder = os.path.join("logs", config["load_run"])
262 | print("Loading model from ", load_project_folder)
263 | latest_path = os.path.join(load_project_folder, "latest.pth")
264 |         latest_checkpoint = torch.load(latest_path, map_location=device)
265 | load_model(model, config["model_type"], latest_checkpoint)
266 | if "epoch" in latest_checkpoint:
267 | current_epoch = latest_checkpoint["epoch"] + 1
268 |
269 | # Multi-GPU
270 | if len(config["gpu_ids"]) > 1:
271 | model = nn.DataParallel(model, device_ids=config["gpu_ids"])
272 | model = model.to(device)
273 |
274 | if "load_run" in config: # load optimizer and scheduler after data parallel
275 | if "optimizer" in latest_checkpoint:
276 | optimizer.load_state_dict(latest_checkpoint["optimizer"].state_dict())
277 | if scheduler is not None and "scheduler" in latest_checkpoint:
278 | scheduler.load_state_dict(latest_checkpoint["scheduler"].state_dict())
279 |
280 | if config["model_type"] == "vint" or config["model_type"] == "gnm":
281 | train_eval_loop(
282 | train_model=config["train"],
283 | model=model,
284 | optimizer=optimizer,
285 | scheduler=scheduler,
286 | dataloader=train_loader,
287 | test_dataloaders=test_dataloaders,
288 | transform=transform,
289 | epochs=config["epochs"],
290 | device=device,
291 | project_folder=config["project_folder"],
292 | normalized=config["normalize"],
293 | print_log_freq=config["print_log_freq"],
294 | image_log_freq=config["image_log_freq"],
295 | num_images_log=config["num_images_log"],
296 | current_epoch=current_epoch,
297 | learn_angle=config["learn_angle"],
298 | alpha=config["alpha"],
299 | use_wandb=config["use_wandb"],
300 | eval_fraction=config["eval_fraction"],
301 | )
302 | else:
303 | train_eval_loop_nomad(
304 | train_model=config["train"],
305 | model=model,
306 | optimizer=optimizer,
307 | lr_scheduler=scheduler,
308 | noise_scheduler=noise_scheduler,
309 | train_loader=train_loader,
310 | test_dataloaders=test_dataloaders,
311 | transform=transform,
312 | goal_mask_prob=config["goal_mask_prob"],
313 | epochs=config["epochs"],
314 | device=device,
315 | project_folder=config["project_folder"],
316 | print_log_freq=config["print_log_freq"],
317 | wandb_log_freq=config["wandb_log_freq"],
318 | image_log_freq=config["image_log_freq"],
319 | num_images_log=config["num_images_log"],
320 | current_epoch=current_epoch,
321 | alpha=float(config["alpha"]),
322 | use_wandb=config["use_wandb"],
323 | eval_fraction=config["eval_fraction"],
324 | eval_freq=config["eval_freq"],
325 | )
326 |
327 | print("FINISHED TRAINING")
328 |
329 |
330 | if __name__ == "__main__":
331 | torch.multiprocessing.set_start_method("spawn")
332 |
333 | parser = argparse.ArgumentParser(description="Visual Navigation Transformer")
334 |
335 | # project setup
336 | parser.add_argument(
337 | "--config",
338 | "-c",
339 | default="config/vint.yaml",
340 | type=str,
341 |         help="Path to the config file in the config folder",
342 | )
343 | args = parser.parse_args()
344 |
345 | with open("config/defaults.yaml", "r") as f:
346 | default_config = yaml.safe_load(f)
347 |
348 | config = default_config
349 |
350 | with open(args.config, "r") as f:
351 | user_config = yaml.safe_load(f)
352 |
353 | config.update(user_config)
354 |
355 | config["run_name"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
356 | config["project_folder"] = os.path.join(
357 | "logs", config["project_name"], config["run_name"]
358 | )
359 | os.makedirs(
360 | config[
361 | "project_folder"
362 | ], # should error if dir already exists to avoid overwriting and old project
363 | )
364 |
365 | if config["use_wandb"]:
366 | wandb.login()
367 | wandb.init(
368 | project=config["project_name"],
369 | settings=wandb.Settings(start_method="fork"),
370 | entity="gnmv2", # TODO: change this to your wandb entity
371 | )
372 | wandb.save(args.config, policy="now") # save the config file
373 | wandb.run.name = config["run_name"]
374 | # update the wandb args with the training configurations
375 | if wandb.run:
376 | wandb.config.update(config)
377 |
378 | print(config)
379 | main(config)
380 |
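381 | # Example invocation (a sketch, not part of the original script). Any YAML under
382 | # config/ that defines the keys read in main() (model_type, datasets, batch_size,
383 | # epochs, lr, optimizer, scheduler, ...) can be passed via --config; the argparse
384 | # default shown below is used otherwise:
385 | #
386 | #     python train.py --config config/vint.yaml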
--------------------------------------------------------------------------------
/train/vint_train/data/vint_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pickle
4 | import yaml
5 | from typing import Any, Dict, List, Optional, Tuple
6 | import tqdm
7 | import io
8 | import lmdb
9 |
10 | import torch
11 | from torch.utils.data import Dataset
12 | import torchvision.transforms.functional as TF
13 |
14 | from vint_train.data.data_utils import (
15 | img_path_to_data,
16 | calculate_sin_cos,
17 | get_data_path,
18 | to_local_coords,
19 | )
20 |
21 | class ViNT_Dataset(Dataset):
22 | def __init__(
23 | self,
24 | data_folder: str,
25 | data_split_folder: str,
26 | dataset_name: str,
27 | image_size: Tuple[int, int],
28 | waypoint_spacing: int,
29 | min_dist_cat: int,
30 | max_dist_cat: int,
31 | min_action_distance: int,
32 | max_action_distance: int,
33 | negative_mining: bool,
34 | len_traj_pred: int,
35 | learn_angle: bool,
36 | context_size: int,
37 | context_type: str = "temporal",
38 | end_slack: int = 0,
39 | goals_per_obs: int = 1,
40 | normalize: bool = True,
41 | obs_type: str = "image",
42 | goal_type: str = "image",
43 | ):
44 | """
45 | Main ViNT dataset class
46 |
47 | Args:
48 | data_folder (string): Directory with all the image data
49 |             data_split_folder (string): Directory with traj_names.txt, a list of all trajectory names in the dataset split, each separated by a newline
50 | dataset_name (string): Name of the dataset [recon, go_stanford, scand, tartandrive, etc.]
51 | waypoint_spacing (int): Spacing between waypoints
52 | min_dist_cat (int): Minimum distance category to use
53 | max_dist_cat (int): Maximum distance category to use
54 | negative_mining (bool): Whether to use negative mining from the ViNG paper (Shah et al.) (https://arxiv.org/abs/2012.09812)
55 | len_traj_pred (int): Length of trajectory of waypoints to predict if this is an action dataset
56 | learn_angle (bool): Whether to learn the yaw of the robot at each predicted waypoint if this is an action dataset
57 | context_size (int): Number of previous observations to use as context
58 | context_type (str): Whether to use temporal, randomized, or randomized temporal context
59 | end_slack (int): Number of timesteps to ignore at the end of the trajectory
60 | goals_per_obs (int): Number of goals to sample per observation
61 | normalize (bool): Whether to normalize the distances or actions
62 | goal_type (str): What data type to use for the goal. The only one supported is "image" for now.
63 | """
64 | self.data_folder = data_folder
65 | self.data_split_folder = data_split_folder
66 | self.dataset_name = dataset_name
67 |
68 | traj_names_file = os.path.join(data_split_folder, "traj_names.txt")
69 | with open(traj_names_file, "r") as f:
70 | file_lines = f.read()
71 | self.traj_names = file_lines.split("\n")
72 | if "" in self.traj_names:
73 | self.traj_names.remove("")
74 |
75 | self.image_size = image_size
76 | self.waypoint_spacing = waypoint_spacing
77 | self.distance_categories = list(
78 | range(min_dist_cat, max_dist_cat + 1, self.waypoint_spacing)
79 | )
80 | self.min_dist_cat = self.distance_categories[0]
81 | self.max_dist_cat = self.distance_categories[-1]
82 | self.negative_mining = negative_mining
83 | if self.negative_mining:
84 | self.distance_categories.append(-1)
85 | self.len_traj_pred = len_traj_pred
86 | self.learn_angle = learn_angle
87 |
88 | self.min_action_distance = min_action_distance
89 | self.max_action_distance = max_action_distance
90 |
91 | self.context_size = context_size
92 | assert context_type in {
93 | "temporal",
94 | "randomized",
95 | "randomized_temporal",
96 | }, "context_type must be one of temporal, randomized, randomized_temporal"
97 | self.context_type = context_type
98 | self.end_slack = end_slack
99 | self.goals_per_obs = goals_per_obs
100 | self.normalize = normalize
101 | self.obs_type = obs_type
102 | self.goal_type = goal_type
103 |
104 | # load data/data_config.yaml
105 | with open(
106 | os.path.join(os.path.dirname(__file__), "data_config.yaml"), "r"
107 | ) as f:
108 | all_data_config = yaml.safe_load(f)
109 | assert (
110 | self.dataset_name in all_data_config
111 | ), f"Dataset {self.dataset_name} not found in data_config.yaml"
112 | dataset_names = list(all_data_config.keys())
113 | dataset_names.sort()
114 | # use this index to retrieve the dataset name from the data_config.yaml
115 | self.dataset_index = dataset_names.index(self.dataset_name)
116 | self.data_config = all_data_config[self.dataset_name]
117 | self.trajectory_cache = {}
118 | self._load_index()
119 | self._build_caches()
120 |
121 | if self.learn_angle:
122 | self.num_action_params = 3
123 | else:
124 | self.num_action_params = 2
125 |
126 | def __getstate__(self):
127 | state = self.__dict__.copy()
128 | state["_image_cache"] = None
129 | return state
130 |
131 | def __setstate__(self, state):
132 | self.__dict__ = state
133 | self._build_caches()
134 |
135 | def _build_caches(self, use_tqdm: bool = True):
136 | """
137 | Build a cache of images for faster loading using LMDB
138 | """
139 | cache_filename = os.path.join(
140 | self.data_split_folder,
141 | f"dataset_{self.dataset_name}.lmdb",
142 | )
143 |
144 | # Load all the trajectories into memory. These should already be loaded, but just in case.
145 | for traj_name in self.traj_names:
146 | self._get_trajectory(traj_name)
147 |
148 | """
149 | If the cache file doesn't exist, create it by iterating through the dataset and writing each image to the cache
150 | """
151 | if not os.path.exists(cache_filename):
152 | tqdm_iterator = tqdm.tqdm(
153 | self.goals_index,
154 | disable=not use_tqdm,
155 | dynamic_ncols=True,
156 | desc=f"Building LMDB cache for {self.dataset_name}"
157 | )
158 | with lmdb.open(cache_filename, map_size=2**40) as image_cache:
159 | with image_cache.begin(write=True) as txn:
160 | for traj_name, time in tqdm_iterator:
161 | image_path = get_data_path(self.data_folder, traj_name, time)
162 | with open(image_path, "rb") as f:
163 | txn.put(image_path.encode(), f.read())
164 |
165 | # Reopen the cache file in read-only mode
166 | self._image_cache: lmdb.Environment = lmdb.open(cache_filename, readonly=True)
167 |
168 | def _build_index(self, use_tqdm: bool = False):
169 | """
170 | Build an index consisting of tuples (trajectory name, time, max goal distance)
171 | """
172 | samples_index = []
173 | goals_index = []
174 |
175 | for traj_name in tqdm.tqdm(self.traj_names, disable=not use_tqdm, dynamic_ncols=True):
176 | traj_data = self._get_trajectory(traj_name)
177 | traj_len = len(traj_data["position"])
178 |
179 | for goal_time in range(0, traj_len):
180 | goals_index.append((traj_name, goal_time))
181 |
182 | begin_time = self.context_size * self.waypoint_spacing
183 | end_time = traj_len - self.end_slack - self.len_traj_pred * self.waypoint_spacing
184 | for curr_time in range(begin_time, end_time):
185 | max_goal_distance = min(self.max_dist_cat * self.waypoint_spacing, traj_len - curr_time - 1)
186 | samples_index.append((traj_name, curr_time, max_goal_distance))
187 |
188 | return samples_index, goals_index
189 |
190 | def _sample_goal(self, trajectory_name, curr_time, max_goal_dist):
191 | """
192 | Sample a goal from the future in the same trajectory.
193 | Returns: (trajectory_name, goal_time, goal_is_negative)
194 | """
195 | goal_offset = np.random.randint(0, max_goal_dist + 1)
196 | if goal_offset == 0:
197 | trajectory_name, goal_time = self._sample_negative()
198 | return trajectory_name, goal_time, True
199 | else:
200 | goal_time = curr_time + int(goal_offset * self.waypoint_spacing)
201 | return trajectory_name, goal_time, False
202 |
203 | def _sample_negative(self):
204 | """
205 | Sample a goal from a (likely) different trajectory.
206 | """
207 | return self.goals_index[np.random.randint(0, len(self.goals_index))]
208 |
209 | def _load_index(self) -> None:
210 | """
211 | Generates a list of tuples of (obs_traj_name, goal_traj_name, obs_time, goal_time) for each observation in the dataset
212 | """
213 | index_to_data_path = os.path.join(
214 | self.data_split_folder,
215 | f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_context_{self.context_type}_n{self.context_size}_slack_{self.end_slack}.pkl",
216 | )
217 | try:
218 | # load the index_to_data if it already exists (to save time)
219 | with open(index_to_data_path, "rb") as f:
220 | self.index_to_data, self.goals_index = pickle.load(f)
221 |         except (FileNotFoundError, EOFError, pickle.UnpicklingError):
222 | # if the index_to_data file doesn't exist, create it
223 | self.index_to_data, self.goals_index = self._build_index()
224 | with open(index_to_data_path, "wb") as f:
225 | pickle.dump((self.index_to_data, self.goals_index), f)
226 |
227 | def _load_image(self, trajectory_name, time):
228 | image_path = get_data_path(self.data_folder, trajectory_name, time)
229 |
230 | try:
231 | with self._image_cache.begin() as txn:
232 | image_buffer = txn.get(image_path.encode())
233 | image_bytes = bytes(image_buffer)
234 | image_bytes = io.BytesIO(image_bytes)
235 | return img_path_to_data(image_bytes, self.image_size)
236 | except TypeError:
237 | print(f"Failed to load image {image_path}")
238 |
239 | def _compute_actions(self, traj_data, curr_time, goal_time):
240 | start_index = curr_time
241 | end_index = curr_time + self.len_traj_pred * self.waypoint_spacing + 1
242 | yaw = traj_data["yaw"][start_index:end_index:self.waypoint_spacing]
243 | positions = traj_data["position"][start_index:end_index:self.waypoint_spacing]
244 | goal_pos = traj_data["position"][min(goal_time, len(traj_data["position"]) - 1)]
245 |
246 | if len(yaw.shape) == 2:
247 | yaw = yaw.squeeze(1)
248 |
249 | if yaw.shape != (self.len_traj_pred + 1,):
250 | const_len = self.len_traj_pred + 1 - yaw.shape[0]
251 | yaw = np.concatenate([yaw, np.repeat(yaw[-1], const_len)])
252 | positions = np.concatenate([positions, np.repeat(positions[-1][None], const_len, axis=0)], axis=0)
253 |
254 | assert yaw.shape == (self.len_traj_pred + 1,), f"{yaw.shape} and {(self.len_traj_pred + 1,)} should be equal"
255 | assert positions.shape == (self.len_traj_pred + 1, 2), f"{positions.shape} and {(self.len_traj_pred + 1, 2)} should be equal"
256 |
257 | waypoints = to_local_coords(positions, positions[0], yaw[0])
258 | goal_pos = to_local_coords(goal_pos, positions[0], yaw[0])
259 |
260 | assert waypoints.shape == (self.len_traj_pred + 1, 2), f"{waypoints.shape} and {(self.len_traj_pred + 1, 2)} should be equal"
261 |
262 | if self.learn_angle:
263 | yaw = yaw[1:] - yaw[0]
264 | actions = np.concatenate([waypoints[1:], yaw[:, None]], axis=-1)
265 | else:
266 | actions = waypoints[1:]
267 |
268 | if self.normalize:
269 | actions[:, :2] /= self.data_config["metric_waypoint_spacing"] * self.waypoint_spacing
270 | goal_pos /= self.data_config["metric_waypoint_spacing"] * self.waypoint_spacing
271 |
272 | assert actions.shape == (self.len_traj_pred, self.num_action_params), f"{actions.shape} and {(self.len_traj_pred, self.num_action_params)} should be equal"
273 |
274 | return actions, goal_pos
275 |
276 | def _get_trajectory(self, trajectory_name):
277 | if trajectory_name in self.trajectory_cache:
278 | return self.trajectory_cache[trajectory_name]
279 | else:
280 | with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f:
281 | traj_data = pickle.load(f)
282 | self.trajectory_cache[trajectory_name] = traj_data
283 | return traj_data
284 |
285 | def __len__(self) -> int:
286 | return len(self.index_to_data)
287 |
288 | def __getitem__(self, i: int) -> Tuple[torch.Tensor]:
289 | """
290 | Args:
291 | i (int): index to ith datapoint
292 | Returns:
293 |             Tuple of tensors containing the stacked context/observation images, goal image, action label, distance label, local goal position, dataset index, and action mask
294 | obs_image (torch.Tensor): tensor of shape [3, H, W] containing the image of the robot's observation
295 | goal_image (torch.Tensor): tensor of shape [3, H, W] containing the subgoal image
296 | dist_label (torch.Tensor): tensor of shape (1,) containing the distance labels from the observation to the goal
297 | action_label (torch.Tensor): tensor of shape (5, 2) or (5, 4) (if training with angle) containing the action labels from the observation to the goal
298 | which_dataset (torch.Tensor): index of the datapoint in the dataset [for identifying the dataset for visualization when using multiple datasets]
299 | """
300 | f_curr, curr_time, max_goal_dist = self.index_to_data[i]
301 | f_goal, goal_time, goal_is_negative = self._sample_goal(f_curr, curr_time, max_goal_dist)
302 |
303 | # Load images
304 | context = []
305 | if self.context_type == "temporal":
306 | # sample the last self.context_size times from interval [0, curr_time)
307 | context_times = list(
308 | range(
309 | curr_time + -self.context_size * self.waypoint_spacing,
310 | curr_time + 1,
311 | self.waypoint_spacing,
312 | )
313 | )
314 | context = [(f_curr, t) for t in context_times]
315 | else:
316 | raise ValueError(f"Invalid context type {self.context_type}")
317 |
318 | obs_image = torch.cat([
319 | self._load_image(f, t) for f, t in context
320 | ])
321 |
322 | # Load goal image
323 | goal_image = self._load_image(f_goal, goal_time)
324 |
325 | # Load other trajectory data
326 | curr_traj_data = self._get_trajectory(f_curr)
327 | curr_traj_len = len(curr_traj_data["position"])
328 | assert curr_time < curr_traj_len, f"{curr_time} and {curr_traj_len}"
329 |
330 | goal_traj_data = self._get_trajectory(f_goal)
331 | goal_traj_len = len(goal_traj_data["position"])
332 |         assert goal_time < goal_traj_len, f"{goal_time} and {goal_traj_len}"
333 |
334 | # Compute actions
335 | actions, goal_pos = self._compute_actions(curr_traj_data, curr_time, goal_time)
336 |
337 | # Compute distances
338 | if goal_is_negative:
339 | distance = self.max_dist_cat
340 | else:
341 | distance = (goal_time - curr_time) // self.waypoint_spacing
342 | assert (goal_time - curr_time) % self.waypoint_spacing == 0, f"{goal_time} and {curr_time} should be separated by an integer multiple of {self.waypoint_spacing}"
343 |
344 | actions_torch = torch.as_tensor(actions, dtype=torch.float32)
345 | if self.learn_angle:
346 | actions_torch = calculate_sin_cos(actions_torch)
347 |
348 | action_mask = (
349 | (distance < self.max_action_distance) and
350 | (distance > self.min_action_distance) and
351 | (not goal_is_negative)
352 | )
353 |
354 | return (
355 | torch.as_tensor(obs_image, dtype=torch.float32),
356 | torch.as_tensor(goal_image, dtype=torch.float32),
357 | actions_torch,
358 | torch.as_tensor(distance, dtype=torch.int64),
359 | torch.as_tensor(goal_pos, dtype=torch.float32),
360 | torch.as_tensor(self.dataset_index, dtype=torch.int64),
361 | torch.as_tensor(action_mask, dtype=torch.float32),
362 | )
363 |
--------------------------------------------------------------------------------
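For readers skimming `_compute_actions` above: it expresses the future poses in the frame of the current pose and then normalizes by the effective waypoint spacing. Below is a minimal, self-contained sketch of that transform; the rotation-matrix form of `to_local_coords` (the repo's version lives in `vint_train/data/data_utils.py`) and the numeric spacing values are illustrative assumptions, not the repo's exact code.

import numpy as np

def to_local_coords_sketch(points, origin, yaw0):
    # Assumed equivalent of data_utils.to_local_coords: translate so the current
    # pose sits at the origin, then rotate by -yaw0 so the robot faces +x.
    c, s = np.cos(-yaw0), np.sin(-yaw0)
    rot = np.array([[c, -s], [s, c]])
    return (points - origin) @ rot.T

# toy trajectory: 6 poses spaced 0.5 m apart along +x, heading held at 0 rad
positions = np.stack([np.arange(6) * 0.5, np.zeros(6)], axis=1)
yaw = np.zeros(6)

waypoints = to_local_coords_sketch(positions, positions[0], yaw[0])
actions = waypoints[1:]                # drop the current pose, keep the 5 future waypoints
actions = actions / (0.25 * 1)         # metric_waypoint_spacing * waypoint_spacing (illustrative values)
print(actions.shape, actions[0])       # (5, 2), first normalized displacement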
/train/vint_train/visualizing/action_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import cv2
5 | from typing import Optional, List
6 | import wandb
7 | import yaml
8 | import torch
9 | import torch.nn as nn
10 | from vint_train.visualizing.visualize_utils import (
11 | to_numpy,
12 | numpy_to_img,
13 | VIZ_IMAGE_SIZE,
14 | RED,
15 | GREEN,
16 | BLUE,
17 | CYAN,
18 | YELLOW,
19 | MAGENTA,
20 | )
21 |
22 | # load data_config.yaml
23 | with open(os.path.join(os.path.dirname(__file__), "../data/data_config.yaml"), "r") as f:
24 | data_config = yaml.safe_load(f)
25 |
26 |
27 | def visualize_traj_pred(
28 | batch_obs_images: np.ndarray,
29 | batch_goal_images: np.ndarray,
30 | dataset_indices: np.ndarray,
31 | batch_goals: np.ndarray,
32 | batch_pred_waypoints: np.ndarray,
33 | batch_label_waypoints: np.ndarray,
34 | eval_type: str,
35 | normalized: bool,
36 | save_folder: str,
37 | epoch: int,
38 | num_images_preds: int = 8,
39 | use_wandb: bool = True,
40 | display: bool = False,
41 | ):
42 | """
43 | Compare predicted path with the gt path of waypoints using egocentric visualization. This visualization is for the last batch in the dataset.
44 |
45 | Args:
46 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels]
47 | batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels]
48 |         dataset_indices (np.ndarray): indices corresponding to the dataset names in data_config.yaml
49 | batch_goals (np.ndarray): batch of goal positions [batch_size, 2]
50 |         batch_pred_waypoints (np.ndarray): batch of predicted waypoints [batch_size, horizon, 4] or [batch_size, horizon, 2] or [batch_size, num_trajs_sampled, horizon, {2 or 4}]
51 |         batch_label_waypoints (np.ndarray): batch of label waypoints [batch_size, horizon, 4] or [batch_size, horizon, 2]
52 | eval_type (string): f"{data_type}_{eval_type}" (e.g. "recon_train", "gs_test", etc.)
53 | normalized (bool): whether the waypoints are normalized
54 | save_folder (str): folder to save the images. If None, will not save the images
55 | epoch (int): current epoch number
56 | num_images_preds (int): number of images to visualize
57 | use_wandb (bool): whether to use wandb to log the images
58 | display (bool): whether to display the images
59 | """
60 | visualize_path = None
61 | if save_folder is not None:
62 | visualize_path = os.path.join(
63 | save_folder, "visualize", eval_type, f"epoch{epoch}", "action_prediction"
64 | )
65 |
66 | if not os.path.exists(visualize_path):
67 | os.makedirs(visualize_path)
68 |
69 | assert (
70 | len(batch_obs_images)
71 | == len(batch_goal_images)
72 | == len(batch_goals)
73 | == len(batch_pred_waypoints)
74 | == len(batch_label_waypoints)
75 | )
76 |
77 | dataset_names = list(data_config.keys())
78 | dataset_names.sort()
79 |
80 | batch_size = batch_obs_images.shape[0]
81 | wandb_list = []
82 | for i in range(min(batch_size, num_images_preds)):
83 | obs_img = numpy_to_img(batch_obs_images[i])
84 | goal_img = numpy_to_img(batch_goal_images[i])
85 | dataset_name = dataset_names[int(dataset_indices[i])]
86 | goal_pos = batch_goals[i]
87 | pred_waypoints = batch_pred_waypoints[i]
88 | label_waypoints = batch_label_waypoints[i]
89 |
90 | if normalized:
91 | pred_waypoints *= data_config[dataset_name]["metric_waypoint_spacing"]
92 | label_waypoints *= data_config[dataset_name]["metric_waypoint_spacing"]
93 | goal_pos *= data_config[dataset_name]["metric_waypoint_spacing"]
94 |
95 | save_path = None
96 | if visualize_path is not None:
97 | save_path = os.path.join(visualize_path, f"{str(i).zfill(4)}.png")
98 |
99 | compare_waypoints_pred_to_label(
100 | obs_img,
101 | goal_img,
102 | dataset_name,
103 | goal_pos,
104 | pred_waypoints,
105 | label_waypoints,
106 | save_path,
107 | display,
108 | )
109 | if use_wandb:
110 | wandb_list.append(wandb.Image(save_path))
111 | if use_wandb:
112 | wandb.log({f"{eval_type}_action_prediction": wandb_list}, commit=False)
113 |
114 |
115 | def compare_waypoints_pred_to_label(
116 | obs_img,
117 | goal_img,
118 | dataset_name: str,
119 | goal_pos: np.ndarray,
120 | pred_waypoints: np.ndarray,
121 | label_waypoints: np.ndarray,
122 | save_path: Optional[str] = None,
123 | display: Optional[bool] = False,
124 | ):
125 | """
126 | Compare predicted path with the gt path of waypoints using egocentric visualization.
127 |
128 | Args:
129 | obs_img: image of the observation
130 | goal_img: image of the goal
131 | dataset_name: name of the dataset found in data_config.yaml (e.g. "recon")
132 |         goal_pos: goal position in the robot's local frame
133 |         pred_waypoints: predicted waypoints in the robot's local frame
134 |         label_waypoints: ground-truth waypoints in the robot's local frame
135 | save_path: path to save the figure
136 | display: whether to display the figure
137 | """
138 |
139 | fig, ax = plt.subplots(1, 3)
140 | start_pos = np.array([0, 0])
141 | if len(pred_waypoints.shape) > 2:
142 | trajs = [*pred_waypoints, label_waypoints]
143 | else:
144 | trajs = [pred_waypoints, label_waypoints]
145 | plot_trajs_and_points(
146 | ax[0],
147 | trajs,
148 | [start_pos, goal_pos],
149 | traj_colors=[CYAN, MAGENTA],
150 | point_colors=[GREEN, RED],
151 | )
152 | plot_trajs_and_points_on_image(
153 | ax[1],
154 | obs_img,
155 | dataset_name,
156 | trajs,
157 | [start_pos, goal_pos],
158 | traj_colors=[CYAN, MAGENTA],
159 | point_colors=[GREEN, RED],
160 | )
161 | ax[2].imshow(goal_img)
162 |
163 | fig.set_size_inches(18.5, 10.5)
164 | ax[0].set_title(f"Action Prediction")
165 | ax[1].set_title(f"Observation")
166 | ax[2].set_title(f"Goal")
167 |
168 | if save_path is not None:
169 | fig.savefig(
170 | save_path,
171 | bbox_inches="tight",
172 | )
173 |
174 | if not display:
175 | plt.close(fig)
176 |
177 |
178 | def plot_trajs_and_points_on_image(
179 | ax: plt.Axes,
180 | img: np.ndarray,
181 | dataset_name: str,
182 | list_trajs: list,
183 | list_points: list,
184 | traj_colors: list = [CYAN, MAGENTA],
185 | point_colors: list = [RED, GREEN],
186 | ):
187 | """
188 |     Plot trajectories and points on an image. If there is no configuration for the camera intrinsics of the dataset, the image will be plotted as is.
189 | Args:
190 | ax: matplotlib axis
191 | img: image to plot
192 | dataset_name: name of the dataset found in data_config.yaml (e.g. "recon")
193 | list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw)
194 | list_points: list of points, each point is a numpy array of shape (2,)
195 | traj_colors: list of colors for trajectories
196 | point_colors: list of colors for points
197 | """
198 | assert len(list_trajs) <= len(traj_colors), "Not enough colors for trajectories"
199 | assert len(list_points) <= len(point_colors), "Not enough colors for points"
200 | assert (
201 | dataset_name in data_config
202 | ), f"Dataset {dataset_name} not found in data/data_config.yaml"
203 |
204 | ax.imshow(img)
205 | if (
206 | "camera_metrics" in data_config[dataset_name]
207 | and "camera_height" in data_config[dataset_name]["camera_metrics"]
208 | and "camera_matrix" in data_config[dataset_name]["camera_metrics"]
209 | and "dist_coeffs" in data_config[dataset_name]["camera_metrics"]
210 | ):
211 | camera_height = data_config[dataset_name]["camera_metrics"]["camera_height"]
212 | camera_x_offset = data_config[dataset_name]["camera_metrics"]["camera_x_offset"]
213 |
214 | fx = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["fx"]
215 | fy = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["fy"]
216 | cx = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["cx"]
217 | cy = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["cy"]
218 | camera_matrix = gen_camera_matrix(fx, fy, cx, cy)
219 |
220 | k1 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k1"]
221 | k2 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k2"]
222 | p1 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["p1"]
223 | p2 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["p2"]
224 | k3 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k3"]
225 | dist_coeffs = np.array([k1, k2, p1, p2, k3, 0.0, 0.0, 0.0])
226 |
227 | for i, traj in enumerate(list_trajs):
228 | xy_coords = traj[:, :2] # (horizon, 2)
229 | traj_pixels = get_pos_pixels(
230 | xy_coords, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=False
231 | )
232 | if len(traj_pixels.shape) == 2:
233 | ax.plot(
234 | traj_pixels[:250, 0],
235 | traj_pixels[:250, 1],
236 | color=traj_colors[i],
237 | lw=2.5,
238 | )
239 |
240 | for i, point in enumerate(list_points):
241 | if len(point.shape) == 1:
242 | # add a dimension to the front of point
243 | point = point[None, :2]
244 | else:
245 | point = point[:, :2]
246 | pt_pixels = get_pos_pixels(
247 | point, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=True
248 | )
249 | ax.plot(
250 | pt_pixels[:250, 0],
251 | pt_pixels[:250, 1],
252 | color=point_colors[i],
253 | marker="o",
254 | markersize=10.0,
255 | )
256 | ax.xaxis.set_visible(False)
257 | ax.yaxis.set_visible(False)
258 | ax.set_xlim((0.5, VIZ_IMAGE_SIZE[0] - 0.5))
259 | ax.set_ylim((VIZ_IMAGE_SIZE[1] - 0.5, 0.5))
260 |
261 |
262 | def plot_trajs_and_points(
263 | ax: plt.Axes,
264 | list_trajs: list,
265 | list_points: list,
266 | traj_colors: list = [CYAN, MAGENTA],
267 | point_colors: list = [RED, GREEN],
268 | traj_labels: Optional[list] = ["prediction", "ground truth"],
269 | point_labels: Optional[list] = ["robot", "goal"],
270 | traj_alphas: Optional[list] = None,
271 | point_alphas: Optional[list] = None,
272 | quiver_freq: int = 1,
273 | default_coloring: bool = True,
274 | ):
275 | """
276 | Plot trajectories and points that could potentially have a yaw.
277 |
278 | Args:
279 | ax: matplotlib axis
280 | list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw)
281 | list_points: list of points, each point is a numpy array of shape (2,)
282 | traj_colors: list of colors for trajectories
283 | point_colors: list of colors for points
284 | traj_labels: list of labels for trajectories
285 | point_labels: list of labels for points
286 | traj_alphas: list of alphas for trajectories
287 | point_alphas: list of alphas for points
288 | quiver_freq: frequency of quiver plot (if the trajectory data includes the yaw of the robot)
289 | """
290 | assert (
291 | len(list_trajs) <= len(traj_colors) or default_coloring
292 | ), "Not enough colors for trajectories"
293 | assert len(list_points) <= len(point_colors), "Not enough colors for points"
294 | assert (
295 | traj_labels is None or len(list_trajs) == len(traj_labels) or default_coloring
296 | ), "Not enough labels for trajectories"
297 | assert point_labels is None or len(list_points) == len(point_labels), "Not enough labels for points"
298 |
299 | for i, traj in enumerate(list_trajs):
300 | if traj_labels is None:
301 | ax.plot(
302 | traj[:, 0],
303 | traj[:, 1],
304 | color=traj_colors[i],
305 | alpha=traj_alphas[i] if traj_alphas is not None else 1.0,
306 | marker="o",
307 | )
308 | else:
309 | ax.plot(
310 | traj[:, 0],
311 | traj[:, 1],
312 | color=traj_colors[i],
313 | label=traj_labels[i],
314 | alpha=traj_alphas[i] if traj_alphas is not None else 1.0,
315 | marker="o",
316 | )
317 | if traj.shape[1] > 2 and quiver_freq > 0: # traj data also includes yaw of the robot
318 | bearings = gen_bearings_from_waypoints(traj)
319 | ax.quiver(
320 | traj[::quiver_freq, 0],
321 | traj[::quiver_freq, 1],
322 | bearings[::quiver_freq, 0],
323 | bearings[::quiver_freq, 1],
324 | color=traj_colors[i] * 0.5,
325 | scale=1.0,
326 | )
327 | for i, pt in enumerate(list_points):
328 | if point_labels is None:
329 | ax.plot(
330 | pt[0],
331 | pt[1],
332 | color=point_colors[i],
333 | alpha=point_alphas[i] if point_alphas is not None else 1.0,
334 | marker="o",
335 | markersize=7.0
336 | )
337 | else:
338 | ax.plot(
339 | pt[0],
340 | pt[1],
341 | color=point_colors[i],
342 | alpha=point_alphas[i] if point_alphas is not None else 1.0,
343 | marker="o",
344 | markersize=7.0,
345 | label=point_labels[i],
346 | )
347 |
348 |
349 | # put the legend below the plot
350 | if traj_labels is not None or point_labels is not None:
351 | ax.legend()
352 | ax.legend(bbox_to_anchor=(0.0, -0.5), loc="upper left", ncol=2)
353 | ax.set_aspect("equal", "box")
354 |
355 |
356 | def angle_to_unit_vector(theta):
357 | """Converts an angle to a unit vector."""
358 | return np.array([np.cos(theta), np.sin(theta)])
359 |
360 |
361 | def gen_bearings_from_waypoints(
362 | waypoints: np.ndarray,
363 | mag=0.2,
364 | ) -> np.ndarray:
365 | """Generate bearings from waypoints, (x, y, sin(theta), cos(theta))."""
366 | bearing = []
367 | for i in range(0, len(waypoints)):
368 | if waypoints.shape[1] > 3: # label is sin/cos repr
369 | v = waypoints[i, 2:]
370 | # normalize v
371 | v = v / np.linalg.norm(v)
372 | v = v * mag
373 | else: # label is radians repr
374 | v = mag * angle_to_unit_vector(waypoints[i, 2])
375 | bearing.append(v)
376 | bearing = np.array(bearing)
377 | return bearing
378 |
379 |
380 | def project_points(
381 | xy: np.ndarray,
382 | camera_height: float,
383 | camera_x_offset: float,
384 | camera_matrix: np.ndarray,
385 | dist_coeffs: np.ndarray,
386 | ):
387 | """
388 | Projects 3D coordinates onto a 2D image plane using the provided camera parameters.
389 |
390 | Args:
391 | xy: array of shape (batch_size, horizon, 2) representing (x, y) coordinates
392 | camera_height: height of the camera above the ground (in meters)
393 | camera_x_offset: offset of the camera from the center of the car (in meters)
394 | camera_matrix: 3x3 matrix representing the camera's intrinsic parameters
395 | dist_coeffs: vector of distortion coefficients
396 |
397 |
398 | Returns:
399 | uv: array of shape (batch_size, horizon, 2) representing (u, v) coordinates on the 2D image plane
400 | """
401 | batch_size, horizon, _ = xy.shape
402 |
403 | # create 3D coordinates with the camera positioned at the given height
404 | xyz = np.concatenate(
405 | [xy, -camera_height * np.ones(list(xy.shape[:-1]) + [1])], axis=-1
406 | )
407 |
408 | # create dummy rotation and translation vectors
409 | rvec = tvec = (0, 0, 0)
410 |
411 | xyz[..., 0] += camera_x_offset
412 | xyz_cv = np.stack([xyz[..., 1], -xyz[..., 2], xyz[..., 0]], axis=-1)
413 | uv, _ = cv2.projectPoints(
414 | xyz_cv.reshape(batch_size * horizon, 3), rvec, tvec, camera_matrix, dist_coeffs
415 | )
416 | uv = uv.reshape(batch_size, horizon, 2)
417 |
418 | return uv
419 |
420 |
421 | def get_pos_pixels(
422 | points: np.ndarray,
423 | camera_height: float,
424 | camera_x_offset: float,
425 | camera_matrix: np.ndarray,
426 | dist_coeffs: np.ndarray,
427 | clip: Optional[bool] = False,
428 | ):
429 | """
430 | Projects 3D coordinates onto a 2D image plane using the provided camera parameters.
431 | Args:
432 |         points: array of shape (horizon, 2) representing (x, y) coordinates
433 | camera_height: height of the camera above the ground (in meters)
434 | camera_x_offset: offset of the camera from the center of the car (in meters)
435 | camera_matrix: 3x3 matrix representing the camera's intrinsic parameters
436 | dist_coeffs: vector of distortion coefficients
437 |
438 | Returns:
439 |         pixels: array of shape (horizon, 2) representing (u, v) pixel coordinates on the 2D image plane
440 | """
441 | pixels = project_points(
442 | points[np.newaxis], camera_height, camera_x_offset, camera_matrix, dist_coeffs
443 | )[0]
444 | pixels[:, 0] = VIZ_IMAGE_SIZE[0] - pixels[:, 0]
445 | if clip:
446 | pixels = np.array(
447 | [
448 | [
449 | np.clip(p[0], 0, VIZ_IMAGE_SIZE[0]),
450 | np.clip(p[1], 0, VIZ_IMAGE_SIZE[1]),
451 | ]
452 | for p in pixels
453 | ]
454 | )
455 | else:
456 | pixels = np.array(
457 | [
458 | p
459 | for p in pixels
460 | if np.all(p > 0) and np.all(p < [VIZ_IMAGE_SIZE[0], VIZ_IMAGE_SIZE[1]])
461 | ]
462 | )
463 | return pixels
464 |
465 |
466 | def gen_camera_matrix(fx: float, fy: float, cx: float, cy: float) -> np.ndarray:
467 | """
468 | Args:
469 | fx: focal length in x direction
470 | fy: focal length in y direction
471 | cx: principal point x coordinate
472 | cy: principal point y coordinate
473 | Returns:
474 | camera matrix
475 | """
476 | return np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]])
477 |
--------------------------------------------------------------------------------
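As a companion to the projection helpers above (`gen_camera_matrix`, `project_points`, `get_pos_pixels`): they lift ground-plane waypoints into 3D and delegate to `cv2.projectPoints`. The standalone sketch below walks through the same steps with made-up intrinsics and zero distortion rather than values from `data_config.yaml`, and it omits the final horizontal flip that `get_pos_pixels` applies via `VIZ_IMAGE_SIZE`.

import numpy as np
import cv2

# stand-in pinhole intrinsics and zero distortion (not values from data_config.yaml)
K = np.array([[200.0, 0.0, 160.0],
              [0.0, 200.0, 120.0],
              [0.0, 0.0, 1.0]])
dist = np.zeros(5)
camera_height = 0.3    # metres above the ground
camera_x_offset = 0.1  # metres ahead of the robot centre

# three waypoints on the ground, straight ahead of the robot
xy = np.array([[0.5, 0.0], [1.0, 0.0], [1.5, 0.0]])

# lift to 3D (z = -camera_height puts the ground below the camera) and re-order axes
# the same way project_points does: cam_x <- robot_y, cam_y <- -robot_z, cam_z <- robot_x
xyz = np.concatenate([xy, -camera_height * np.ones((len(xy), 1))], axis=1)
xyz[:, 0] += camera_x_offset
xyz_cv = np.stack([xyz[:, 1], -xyz[:, 2], xyz[:, 0]], axis=-1)

uv, _ = cv2.projectPoints(xyz_cv, np.zeros(3), np.zeros(3), K, dist)
print(uv.reshape(-1, 2))  # pixel (u, v) of each waypoint; farther waypoints land closer to the horizon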
/deployment/src/navigate.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import os
3 | from typing import Tuple, Sequence, Dict, Union, Optional, Callable
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
8 |
9 | import matplotlib.pyplot as plt
10 | import yaml
11 | import tracemalloc
12 |
13 | # ROS
14 | import rospy
15 | from sensor_msgs.msg import Image
16 | from visualization_msgs.msg import Marker
17 | from geometry_msgs.msg import Point, PoseStamped
18 | from std_msgs.msg import Bool, Float32MultiArray
19 | from utils import msg_to_pil, to_numpy, transform_images, load_model, rotate_point_by_quaternion
20 |
21 | from vint_train.training.train_utils import get_action
22 | from vint_train.visualizing.action_utils import plot_trajs_and_points
23 | from guide import PathGuide, PathOpt
24 | import torch
25 | from PIL import Image as PILImage
26 | import numpy as np
27 | import argparse
28 | import yaml
29 | import time
30 |
31 | # UTILS
32 | from topic_names import (IMAGE_TOPIC,
33 | WAYPOINT_TOPIC,
34 | SUB_GOAL_TOPIC,
35 | POS_TOPIC,
36 | SAMPLED_ACTIONS_TOPIC,
37 | VISUAL_MARKER_TOPIC)
38 |
39 |
40 | # CONSTANTS
41 | TOPOMAP_IMAGES_DIR = "../topomaps/images"
42 | MODEL_WEIGHTS_PATH = "../model_weights"
43 | ROBOT_CONFIG_PATH = "../config/robot.yaml"
44 | MODEL_CONFIG_PATH = "../config/models.yaml"
45 | with open(ROBOT_CONFIG_PATH, "r") as f:
46 | robot_config = yaml.safe_load(f)
47 | MAX_V = robot_config["max_v"]
48 | MAX_W = robot_config["max_w"]
49 | RATE = robot_config["frame_rate"]
50 | ACTION_STATS = {}
51 | ACTION_STATS['min'] = np.array([-2.5, -4])
52 | ACTION_STATS['max'] = np.array([5, 4])
53 |
54 | # GLOBALS
55 | context_queue = []
56 | context_size = None
57 | subgoal = []
58 |
59 | robo_pos = None
60 | robo_orientation = None
61 | rela_pos = None
62 | # Load the model
63 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64 | print("Using device:", device)
65 |
66 | def get_plt_param(uc_actions, gc_actions, goal_pos):
67 | traj_list = np.concatenate([
68 | uc_actions,
69 | gc_actions,
70 | ], axis=0)
71 | traj_colors = ["red"] * len(uc_actions) + ["green"] * len(gc_actions) + ["magenta"]
72 | traj_alphas = [0.1] * (len(uc_actions) + len(gc_actions)) + [1.0]
73 |
74 | point_list = [np.array([0, 0]), goal_pos]
75 | point_colors = ["green", "red"]
76 | point_alphas = [1.0, 1.0]
77 | return traj_list, traj_colors, traj_alphas, point_list, point_colors, point_alphas
78 |
79 | def action_plot(uc_actions, gc_actions, goal_pos):
80 | traj_list, traj_colors, traj_alphas, point_list, point_colors, point_alphas = get_plt_param(uc_actions, gc_actions, goal_pos)
81 | fig, ax = plt.subplots(1, 1)
82 | plot_trajs_and_points(
83 | ax,
84 | traj_list,
85 | point_list,
86 | traj_colors,
87 | point_colors,
88 | traj_labels=None,
89 | point_labels=None,
90 | quiver_freq=0,
91 | traj_alphas=traj_alphas,
92 | point_alphas=point_alphas,
93 | )
94 |
95 | save_path = os.path.join(f"output_goal_{rela_pos}.png")
96 | plt.savefig(save_path)
97 | plt.close(fig)
98 | print(f"output image saved as {save_path}")
99 |
100 | def Marker_process(points, id, selected_num, length=8):
101 | marker = Marker()
102 | marker.header.frame_id = "base_link"
103 | marker.header.stamp = rospy.Time.now()
104 | marker.ns= "points"
105 | marker.id = id
106 | marker.type = Marker.LINE_STRIP
107 | marker.action = Marker.ADD
108 | marker.scale.x = 0.01
109 | marker.scale.y = 0.01
110 | marker.scale.z = 0.01
111 | if selected_num == id:
112 | marker.color.a = 1.0
113 | marker.color.r = 1.0
114 | marker.color.g = 0.0
115 | marker.color.b = 0.0
116 | else:
117 | marker.color.a = 1.0
118 | marker.color.r = 0.0
119 | marker.color.g = 0.0
120 | marker.color.b = 1.0
121 | for i in range(length):
122 | p = Point()
123 | p.x = points[2 * i]
124 | p.y = points[2 * i + 1]
125 | p.z = 0
126 | marker.points.append(p)
127 | return marker
128 |
129 | def Marker_process_goal(points, marker, length=1):
130 | marker.header.frame_id = "base_link"
131 | marker.header.stamp = rospy.Time.now()
132 | marker.ns= "points"
133 | marker.id = 0
134 | marker.type = Marker.POINTS
135 | marker.action = Marker.ADD
136 | marker.scale.x = 0.1
137 | marker.scale.y = 0.1
138 | marker.color.a = 1.0
139 | marker.color.r = 1.0
140 | marker.color.g = 0.0
141 | marker.color.b = 0.0
142 |
143 | for i in range(length):
144 | p = Point()
145 | p.x = points[2 * i]
146 | p.y = points[2 * i + 1]
147 | p.z = 1
148 | marker.points.append(p)
149 | return marker
150 |
151 | def callback_obs(msg):
152 | obs_img = msg_to_pil(msg)
153 | if obs_img.mode == 'RGBA':
154 | obs_img = obs_img.convert('RGB')
155 | else:
156 | obs_img = obs_img
157 | if context_size is not None:
158 | if len(context_queue) < context_size + 1:
159 | context_queue.append(obs_img)
160 | else:
161 | context_queue.pop(0)
162 | context_queue.append(obs_img)
163 |
164 | def pos_callback(msg):
165 | global robo_pos, robo_orientation
166 | robo_pos = np.array([msg.pose.position.x, msg.pose.position.y, msg.pose.position.z])
167 | robo_orientation = np.array([msg.pose.orientation.x, msg.pose.orientation.y,
168 | msg.pose.orientation.z, msg.pose.orientation.w])
169 |
170 | def main(args: argparse.Namespace):
171 | global context_size, robo_pos, robo_orientation, rela_pos
172 |
173 | # load model parameters
174 | with open(MODEL_CONFIG_PATH, "r") as f:
175 | model_paths = yaml.safe_load(f)
176 |
177 | model_config_path = model_paths[args.model]["config_path"]
178 | with open(model_config_path, "r") as f:
179 | model_params = yaml.safe_load(f)
180 |
181 | if args.pos_goal:
182 | with open(os.path.join(TOPOMAP_IMAGES_DIR, args.dir, "position.txt"), 'r') as file:
183 | lines = file.readlines()
184 |
185 | context_size = model_params["context_size"]
186 |
187 | # load model weights
188 | ckpth_path = model_paths[args.model]["ckpt_path"]
189 | if os.path.exists(ckpth_path):
190 | print(f"Loading model from {ckpth_path}")
191 | else:
192 | raise FileNotFoundError(f"Model weights not found at {ckpth_path}")
193 | model = load_model(
194 | ckpth_path,
195 | model_params,
196 | device,
197 | )
198 | model = model.to(device)
199 | model.eval()
200 |
201 | pathguide = PathGuide(device, ACTION_STATS)
202 | pathopt = PathOpt()
203 | # load topomap
204 | topomap_filenames = sorted([filename for filename in os.listdir(os.path.join(
205 | TOPOMAP_IMAGES_DIR, args.dir)) if filename.endswith('.png')],
206 | key=lambda x: int(x.split(".")[0]))
207 | topomap_dir = f"{TOPOMAP_IMAGES_DIR}/{args.dir}"
208 | num_nodes = len(topomap_filenames)
209 | topomap = []
210 | for i in range(num_nodes):
211 | image_path = os.path.join(topomap_dir, topomap_filenames[i])
212 | topomap.append(PILImage.open(image_path))
213 |
214 | closest_node = args.init_node
215 | assert -1 <= args.goal_node < len(topomap), "Invalid goal index"
216 | if args.goal_node == -1:
217 | goal_node = len(topomap) - 1
218 | else:
219 | goal_node = args.goal_node
220 |
221 | # ROS
222 | rospy.init_node("EXPLORATION", anonymous=False)
223 | rate = rospy.Rate(RATE)
224 | image_curr_msg = rospy.Subscriber(
225 | IMAGE_TOPIC, Image, callback_obs, queue_size=1)
226 |
227 | if args.pos_goal:
228 | pos_curr_msg = rospy.Subscriber(
229 | POS_TOPIC, PoseStamped, pos_callback, queue_size=1)
230 | subgoal_pub = rospy.Publisher(
231 | SUB_GOAL_TOPIC, Marker, queue_size=1)
232 | robogoal_pub = rospy.Publisher(
233 | '/goal1', Marker, queue_size=1)
234 | waypoint_pub = rospy.Publisher(
235 | WAYPOINT_TOPIC, Float32MultiArray, queue_size=1)
236 | sampled_actions_pub = rospy.Publisher(SAMPLED_ACTIONS_TOPIC, Float32MultiArray, queue_size=1)
237 | goal_pub = rospy.Publisher("/topoplan/reached_goal", Bool, queue_size=1)
238 | marker_pub = rospy.Publisher(VISUAL_MARKER_TOPIC, Marker, queue_size=10)
239 |
240 | print("Registered with master node. Waiting for image observations...")
241 |
242 | if model_params["model_type"] == "nomad":
243 | num_diffusion_iters = model_params["num_diffusion_iters"]
244 | noise_scheduler = DDPMScheduler(
245 | num_train_timesteps=model_params["num_diffusion_iters"],
246 | beta_schedule='squaredcos_cap_v2',
247 | clip_sample=True,
248 | prediction_type='epsilon'
249 | )
250 |
251 | scale = 4.0
252 | scale_factor = scale * MAX_V / RATE
253 | # navigation loop
254 | while not rospy.is_shutdown():
255 | chosen_waypoint = np.zeros(4)
256 | if len(context_queue) > model_params["context_size"]:
257 | if model_params["model_type"] == "nomad":
258 | obs_images = transform_images(context_queue, model_params["image_size"], center_crop=False)
259 | if args.guide:
260 | pathguide.get_cost_map_via_tsdf(context_queue[-1])
261 | obs_images = torch.split(obs_images, 3, dim=1)
262 | obs_images = torch.cat(obs_images, dim=1)
263 | obs_images = obs_images.to(device)
264 | start = max(closest_node - args.radius, 0)
265 | end = min(closest_node + args.radius + 1, goal_node)
266 | if args.pos_goal:
267 | mask = torch.ones(1).long().to(device)
268 | goal_pos = np.array([float(lines[end].split()[0]), float(lines[end].split()[1]), float(lines[end].split()[2])])
269 | rela_pos = goal_pos - robo_pos
270 | rela_pos = rotate_point_by_quaternion(rela_pos, robo_orientation)[:2]
271 | print('rela_pos: ', rela_pos)
272 | marker_robogoal = Marker()
273 | Marker_process_goal(rela_pos[:2], marker_robogoal, 1)
274 | robogoal_pub.publish(marker_robogoal)
275 | else:
276 | mask = torch.zeros(1).long().to(device)
277 | goal_image = [transform_images(g_img, model_params["image_size"], center_crop=False).to(device) for g_img in topomap[start:end + 1]]
278 | goal_image = torch.concat(goal_image, dim=0)
279 | obsgoal_cond = model('vision_encoder', obs_img=obs_images.repeat(len(goal_image), 1, 1, 1), goal_img=goal_image, input_goal_mask=mask.repeat(len(goal_image)))
280 | if args.pos_goal:
281 | goal_poses = np.array([[float(lines[i].split()[0]), float(lines[i].split()[1]), float(lines[i].split()[2])] for i in range(start, end + 1)])
282 | min_idx = np.argmin(np.linalg.norm(goal_poses - robo_pos, axis=1))
283 | sg_idx = min_idx
284 | else:
285 | dists = model("dist_pred_net", obsgoal_cond=obsgoal_cond)
286 | dists = to_numpy(dists.flatten())
287 | min_idx = np.argmin(dists)
288 | sg_idx = min(min_idx + int(dists[min_idx] < args.close_threshold), len(obsgoal_cond) - 1)
289 | time4 = time.time()
290 | closest_node = min_idx + start
291 | print("closest node:", closest_node)
292 |
293 | obs_cond = obsgoal_cond[sg_idx].unsqueeze(0)
294 | # infer action
295 | with torch.no_grad():
296 | # encoder vision features
297 | if len(obs_cond.shape) == 2:
298 | obs_cond = obs_cond.repeat(args.num_samples, 1)
299 | else:
300 | obs_cond = obs_cond.repeat(args.num_samples, 1, 1)
301 |
302 | # initialize action from Gaussian noise
303 | noisy_action = torch.randn(
304 | (args.num_samples, model_params["len_traj_pred"], 2), device=device)
305 | naction = noisy_action
306 |
307 | # init scheduler
308 | noise_scheduler.set_timesteps(num_diffusion_iters)
309 |
310 | start_time = time.time()
311 | for k in noise_scheduler.timesteps[:]:
312 | with torch.no_grad():
313 | # predict noise
314 | noise_pred = model(
315 | 'noise_pred_net',
316 | sample=naction,
317 | timestep=k,
318 | global_cond=obs_cond
319 | )
320 | # inverse diffusion step (remove noise)
321 | naction = noise_scheduler.step(
322 | model_output=noise_pred,
323 | timestep=k,
324 | sample=naction
325 | ).prev_sample
326 | if args.guide:
327 | interval1 = 6
328 | period = 1
329 | if k <= interval1:
330 | if k % period == 0:
331 | if k > 2:
332 | grad, cost_list = pathguide.get_gradient(naction, goal_pos=rela_pos, scale_factor=scale_factor)
333 | grad_scale = 1.0
334 | naction -= grad_scale * grad
335 | else:
336 | if k>=0 and k <= 2:
337 | naction_tmp = naction.detach().clone()
338 | for i in range(1):
339 | grad, cost_list = pathguide.get_gradient(naction_tmp, goal_pos=rela_pos, scale_factor=scale_factor)
340 | naction_tmp -= grad
341 | naction = naction_tmp
342 |
343 | naction = to_numpy(get_action(naction))
344 | naction_selected, selected_num = pathopt.select_trajectory(naction, l=args.waypoint, angle_threshold=45)
345 | sampled_actions_msg = Float32MultiArray()
346 | sampled_actions_msg.data = np.concatenate((np.array([0]), naction.flatten()))
347 | for i in range(8):
348 | marker = Marker_process(sampled_actions_msg.data[i * 16 + 1 : (i + 1) * 16 + 1] * scale_factor, i, selected_num)
349 | marker_pub.publish(marker)
350 | print("published sampled actions")
351 | sampled_actions_pub.publish(sampled_actions_msg)
352 |
353 | chosen_waypoint = naction_selected[args.waypoint]
354 | elif (len(context_queue) > model_params["context_size"]):
355 | start = max(closest_node - args.radius, 0)
356 | end = min(closest_node + args.radius + 1, goal_node)
357 | distances = []
358 | waypoints = []
359 | batch_obs_imgs = []
360 | batch_goal_data = []
361 | for i, sg_img in enumerate(topomap[start: end + 1]):
362 | transf_obs_img = transform_images(context_queue, model_params["image_size"])
363 | goal_data = transform_images(sg_img, model_params["image_size"])
364 | batch_obs_imgs.append(transf_obs_img)
365 | batch_goal_data.append(goal_data)
366 |
367 | # predict distances and waypoints
368 | batch_obs_imgs = torch.cat(batch_obs_imgs, dim=0).to(device)
369 | batch_goal_data = torch.cat(batch_goal_data, dim=0).to(device)
370 |
371 | distances, waypoints = model(batch_obs_imgs, batch_goal_data)
372 | distances = to_numpy(distances)
373 | waypoints = to_numpy(waypoints)
374 | # look for closest node
375 | if args.pos_goal:
376 | goal_poses = np.array([[float(lines[i].split()[0]), float(lines[i].split()[1]), float(lines[i].split()[2])] for i in range(start, end + 1)])
377 | closest_node = np.argmin(np.linalg.norm(goal_poses - robo_pos, axis=1))
378 | else:
379 | closest_node = np.argmin(distances)
380 | # chose subgoal and output waypoints
381 | if distances[closest_node] > args.close_threshold:
382 | chosen_waypoint = waypoints[closest_node][args.waypoint]
383 | sg_img = topomap[start + closest_node]
384 | else:
385 | chosen_waypoint = waypoints[min(
386 | closest_node + 1, len(waypoints) - 1)][args.waypoint]
387 | sg_img = topomap[start + min(closest_node + 1, len(waypoints) - 1)]
388 |
389 | if model_params["normalize"]:
390 | chosen_waypoint[:2] *= (scale_factor / scale)
391 |
392 | waypoint_msg = Float32MultiArray()
393 | waypoint_msg.data = chosen_waypoint
394 | waypoint_pub.publish(waypoint_msg)
395 |
396 | torch.cuda.empty_cache()
397 |
398 | rate.sleep()
399 |
400 |
401 | if __name__ == "__main__":
402 | parser = argparse.ArgumentParser(
403 | description="Code to run GNM DIFFUSION EXPLORATION on the locobot")
404 | parser.add_argument(
405 | "--model",
406 | "-m",
407 | default="nomad",
408 | type=str,
409 | help="model name (only nomad is supported) (hint: check ../config/models.yaml) (default: nomad)",
410 | )
411 | parser.add_argument(
412 | "--waypoint",
413 | "-w",
414 |         default=2, # close waypoints exhibit straight-line motion (the middle waypoint is a good default)
415 | type=int,
416 |         help=f"""index of the waypoint used for navigation (between 0 and 4, or
417 |         however many waypoints your model predicts) (default: 2)""",
418 | )
419 | parser.add_argument(
420 | "--dir",
421 | "-d",
422 | default="topomap",
423 | type=str,
424 | help="path to topomap images",
425 | )
426 | parser.add_argument(
427 | "--init-node",
428 | "-i",
429 | default=0,
430 | type=int,
431 |         help="""initial node index in the topomap to start localization
432 |         from (default: 0)""",
433 | )
434 | parser.add_argument(
435 | "--goal-node",
436 | "-g",
437 | default=-1,
438 | type=int,
439 | help="""goal node index in the topomap (if -1, then the goal node is
440 | the last node in the topomap) (default: -1)""",
441 | )
442 | parser.add_argument(
443 | "--close-threshold",
444 | "-t",
445 | default=3,
446 | type=int,
447 |         help="""temporal distance to the next topomap node below which the
448 |         robot localizes to that node (default: 3)""",
449 | )
450 | parser.add_argument(
451 | "--radius",
452 | "-r",
453 | default=4,
454 | type=int,
455 |         help="""number of local topomap nodes to consider for
456 |         localization (default: 4)""",
457 | )
458 | parser.add_argument(
459 | "--num-samples",
460 | "-n",
461 | default=8,
462 | type=int,
463 | help=f"Number of actions sampled from the exploration model (default: 8)",
464 | )
465 | parser.add_argument(
466 | "--guide",
467 | default=True,
468 | type=bool,
469 | )
470 | parser.add_argument(
471 |         "--pos-goal",  # main() reads args.pos_goal, so the flag must map to that attribute
472 |         default=False,
473 |         type=bool,  # note: with type=bool, any non-empty string on the command line parses as True
474 | )
475 | args = parser.parse_args()
476 | print(f"Using {device}")
477 | main(args)
478 |
479 |
480 |
--------------------------------------------------------------------------------
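The cost-guided refinement inside the diffusion loop of navigate.py boils down to a simple pattern: after selected denoising steps, differentiate a trajectory cost with respect to the sampled actions and push the samples downhill. The sketch below captures that pattern with a stand-in quadratic goal cost; it does not reproduce PathGuide's TSDF cost map or the scheduler bookkeeping.

import torch

def goal_cost(naction, goal_pos):
    # stand-in cost: squared distance of each trajectory's final waypoint to the goal
    return ((naction[:, -1, :] - goal_pos) ** 2).sum(dim=-1)

def guidance_step(naction, goal_pos, grad_scale=1.0):
    # one guidance update, as in the k <= interval1 branch: gradient of the cost
    # w.r.t. the sampled trajectories, then a gradient step on the samples themselves
    naction = naction.detach().requires_grad_(True)
    cost = goal_cost(naction, goal_pos).sum()
    grad, = torch.autograd.grad(cost, naction)
    return (naction - grad_scale * grad).detach()

# 8 sampled trajectories of 8 (x, y) waypoints, matching the nomad branch above
naction = torch.randn(8, 8, 2)
goal_pos = torch.tensor([2.0, 0.5])
for _ in range(3):  # a few guided refinements, standing in for the late denoising steps
    naction = guidance_step(naction, goal_pos, grad_scale=0.5)
print(goal_cost(naction, goal_pos))  # per-trajectory costs shrink toward zero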