├── train ├── vint_train │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── data_config.yaml │ │ ├── data_utils.py │ │ └── vint_dataset.py │ ├── models │ │ ├── __init__.py │ │ ├── navibridge │ │ │ ├── __init__.py │ │ │ ├── ddbm │ │ │ │ ├── __init__.py │ │ │ │ ├── script_util.py │ │ │ │ ├── resample.py │ │ │ │ ├── nn.py │ │ │ │ └── karras_diffusion.py │ │ │ ├── self_attention.py │ │ │ ├── vae │ │ │ │ ├── vae.py │ │ │ │ └── conditional_mlp_1D_vae.py │ │ │ ├── navibridg_utils.py │ │ │ └── navibridge.py │ │ ├── base_model.py │ │ └── model_utils.py │ ├── training │ │ ├── __init__.py │ │ ├── logger.py │ │ └── train_eval_loop.py │ ├── process_data │ │ ├── __init__.py │ │ ├── process_bags_config.yaml │ │ └── process_data_utils.py │ └── visualizing │ │ ├── __init__.py │ │ ├── visualize_utils.py │ │ ├── distance_utils.py │ │ └── action_utils.py ├── setup.py ├── train_environment.yml ├── config │ ├── defaults.yaml │ ├── cvae.yaml │ └── navibridge.yaml ├── process_recon.py ├── data_split.py ├── process_bags.py └── train.py ├── deployment ├── config │ ├── params.yaml │ ├── robot.yaml │ └── models.yaml ├── deployment_environment.yaml └── src │ ├── navibridger_inference.py │ └── utils_inference.py ├── assets └── pipline.png ├── LICENSE ├── .gitignore └── README.md /train/vint_train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train/vint_train/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train/vint_train/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train/vint_train/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train/vint_train/process_data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train/vint_train/visualizing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train/vint_train/models/navibridge/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /deployment/config/params.yaml: -------------------------------------------------------------------------------- 1 | image_path: "../images/" -------------------------------------------------------------------------------- /train/vint_train/models/navibridge/ddbm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/pipline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hren20/NaiviBridger/HEAD/assets/pipline.png -------------------------------------------------------------------------------- /train/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, 
find_packages 2 | 3 | setup( 4 | name="vint_train", 5 | version="0.1.0", 6 | packages=find_packages(), 7 | ) 8 | -------------------------------------------------------------------------------- /deployment/deployment_environment.yaml: -------------------------------------------------------------------------------- 1 | name: navibridge 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - python=3.8.5 7 | - cudatoolkit=11. 8 | - torchvision 9 | - numpy 10 | - matplotlib 11 | - pyyaml 12 | - rospkg 13 | - pip: 14 | - torch 15 | - torchvision 16 | - efficientnet_pytorch 17 | - warmup_scheduler 18 | - diffusers==0.11.1 -------------------------------------------------------------------------------- /deployment/config/robot.yaml: -------------------------------------------------------------------------------- 1 | # linear and angular speed limits for the robot 2 | max_v: 0.3 #0.4 # m/s 3 | max_w: 0.4 #0.8 # rad/s 4 | # observation rate of the robot 5 | frame_rate: 4 # Hz 6 | graph_rate: 0.3333 # Hz 7 | 8 | # topic names (modify for different robots/nodes) 9 | vel_teleop_topic: /cmd_vel_mux/input/teleop 10 | vel_navi_topic: /cmd_vel 11 | vel_recovery_topic: /cmd_vel_mux/input/recovery 12 | 13 | 14 | -------------------------------------------------------------------------------- /deployment/config/models.yaml: -------------------------------------------------------------------------------- 1 | navibridger_handcraft: 2 | config_path: "../model_weights/navibridger_handcraft.yaml" 3 | ckpt_path: "../model_weights/navibridger_handcraft.pth" 4 | 5 | navibridger_noise: 6 | config_path: "../model_weights/navibridger_noise.yaml" 7 | ckpt_path: "../model_weights/navibridger_noise.pth" 8 | 9 | navibridger_cvae: 10 | config_path: "../model_weights/navibridger_cvae.yaml" 11 | ckpt_path: "../model_weights/navibridger_cvae.pth" 12 | cvae: 13 | config_path: "../model_weights/cvae.yaml" 14 | ckpt_path: "../model_weights/cvae.pth" -------------------------------------------------------------------------------- /train/train_environment.yml: -------------------------------------------------------------------------------- 1 | name: navibridge_train 2 | channels: 3 | - defaults 4 | - pytorch 5 | dependencies: 6 | - python=3.8.5 7 | - cudatoolkit=10.
8 | - numpy 9 | - matplotlib 10 | - ipykernel 11 | - pip 12 | - pip: 13 | - torch 14 | - torchvision 15 | - tqdm==4.64.0 16 | - git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 17 | - opencv-python==4.6.0.66 18 | - h5py==3.6.0 19 | - wandb==0.12.18 20 | - --extra-index-url https://rospypi.github.io/simple/ 21 | - rosbag 22 | - roslz4 23 | - prettytable 24 | - efficientnet-pytorch 25 | - warmup-scheduler 26 | - diffusers==0.11.1 27 | - lmdb 28 | - vit-pytorch 29 | - positional-encodings 30 | - piq==0.8.0 31 | 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /train/vint_train/visualizing/visualize_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | VIZ_IMAGE_SIZE = (640, 480) 6 | RED = np.array([1, 0, 0]) 7 | GREEN = np.array([0, 1, 0]) 8 | BLUE = np.array([0, 0, 1]) 9 | CYAN = np.array([0, 1, 1]) 10 | YELLOW = np.array([1, 1, 0]) 11 | MAGENTA = np.array([1, 0, 1]) 12 | 13 | 14 | def numpy_to_img(arr: np.ndarray) -> Image: 15 | img = Image.fromarray(np.transpose(np.uint8(255 * arr), (1, 2, 0))) 16 | img = img.resize(VIZ_IMAGE_SIZE) 17 | return img 18 | 19 | 20 | def to_numpy(tensor: torch.Tensor) -> np.ndarray: 21 | return tensor.detach().cpu().numpy() 22 | 23 | 24 | def from_numpy(array: np.ndarray) -> torch.Tensor: 25 | return torch.from_numpy(array).float() 26 | -------------------------------------------------------------------------------- /train/vint_train/process_data/process_bags_config.yaml: -------------------------------------------------------------------------------- 1 | tartan_drive: 2 | odomtopics: "/odometry/filtered_odom" 3 | imtopics: "/multisense/left/image_rect_color" 4 | ang_offset: 1.5707963267948966 # pi/2 5 | img_process_func: "process_tartan_img" 6 | odom_process_func: "nav_to_xy_yaw" 7 | 8 | scand: 9 | odomtopics: ["/odom", "/jackal_velocity_controller/odom"] 10 | imtopics: ["/image_raw/compressed", "/camera/rgb/image_raw/compressed"] 11 | ang_offset: 0.0 12 | img_process_func: "process_scand_img" 13 | odom_process_func: "nav_to_xy_yaw" 14 | 15 | locobot: 16 | odomtopics: "/odom" 17 | imtopics: "/usb_cam/image_raw" 18 | ang_offset: 0.0 19 | img_process_func: "process_locobot_img" 20 | odom_process_func: "nav_to_xy_yaw" 21 | 22 | sacson: 23 | odomtopics: "/odometry" 24 | imtopics: "/fisheye_image/compressed" 25 | ang_offset: 0.0 26 | img_process_func: "process_sacson_img" 27 | odom_process_func: "nav_to_xy_yaw" 28 | 29 | # add your own datasets below: 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Hao Ren, Yiming Zeng, Zetong Bi, Zhaoliang Wan, 4 | Junlong Huang, Hui Cheng 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /train/config/defaults.yaml: -------------------------------------------------------------------------------- 1 | # defaults for training 2 | project_name: navibridge 3 | run_name: navibridge 4 | 5 | # training setup 6 | use_wandb: True # set to false if you don't want to log to wandb 7 | train: True 8 | batch_size: 400 9 | eval_batch_size: 400 10 | epochs: 30 11 | gpu_ids: [0] 12 | num_workers: 4 13 | lr: 5e-4 14 | optimizer: adam 15 | seed: 0 16 | clipping: False 17 | train_subset: 1. 18 | 19 | # model params 20 | model_type: navibridge 21 | obs_encoding_size: 1024 22 | goal_encoding_size: 1024 23 | 24 | # normalization for the action space 25 | normalize: True 26 | 27 | # context 28 | context_type: temporal 29 | context_size: 5 30 | 31 | # tradeoff between action and distance prediction loss 32 | alpha: 0.5 33 | 34 | # tradeoff between task loss and kld 35 | beta: 0.1 36 | 37 | obs_type: image 38 | goal_type: image 39 | scheduler: null 40 | 41 | # distance bounds for distance and action and distance predictions 42 | distance: 43 | min_dist_cat: 0 44 | max_dist_cat: 20 45 | action: 46 | min_dist_cat: 2 47 | max_dist_cat: 10 48 | close_far_threshold: 10 # distance threshold used to separate the close and the far subgoals that are sampled per datapoint 49 | 50 | # action output params 51 | len_traj_pred: 5 52 | learn_angle: True 53 | 54 | # dataset specific parameters 55 | image_size: [85, 64] # width, height 56 | 57 | # logging stuff 58 | ## =0 turns off 59 | print_log_freq: 100 # in iterations 60 | image_log_freq: 1000 # in iterations 61 | num_images_log: 8 # number of images to log in a logging iteration 62 | pairwise_test_freq: 10 # in epochs 63 | eval_fraction: 0.25 # fraction of the dataset to use for evaluation 64 | wandb_log_freq: 10 # in iterations 65 | eval_freq: 1 # in epochs 66 | 67 | -------------------------------------------------------------------------------- /train/vint_train/models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import List, Dict, Optional, Tuple 5 | 6 | 7 | class BaseModel(nn.Module): 8 | def __init__( 9 | self, 10 | context_size: int = 5, 11 | len_traj_pred: Optional[int] = 5, 12 | learn_angle: Optional[bool] = True, 13 | ) -> None: 14 | """ 15 | Base Model main class 16 | Args: 17 | context_size (int): how many previous observations to use for context 18 | len_traj_pred (int): how many waypoints to predict in the future 19 | learn_angle (bool): whether to predict the yaw of the robot 20 | """ 21 | super(BaseModel, self).__init__() 22 | self.context_size = context_size 23 | self.learn_angle = learn_angle 24 | self.len_trajectory_pred = len_traj_pred 25 | if self.learn_angle: 26 | self.num_action_params = 4 # last two dims are the cos and sin of the angle 27 | else: 28 | self.num_action_params = 2 29 | 30 | def flatten(self, z: torch.Tensor) -> torch.Tensor: 31 | z =
nn.functional.adaptive_avg_pool2d(z, (1, 1)) 32 | z = torch.flatten(z, 1) 33 | return z 34 | 35 | def forward( 36 | self, obs_img: torch.tensor, goal_img: torch.tensor 37 | ) -> Tuple[torch.Tensor, torch.Tensor]: 38 | """ 39 | Forward pass of the model 40 | Args: 41 | obs_img (torch.Tensor): batch of observations 42 | goal_img (torch.Tensor): batch of goals 43 | Returns: 44 | dist_pred (torch.Tensor): predicted distance to goal 45 | action_pred (torch.Tensor): predicted action 46 | """ 47 | raise NotImplementedError 48 | -------------------------------------------------------------------------------- /train/vint_train/training/logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Logger: 5 | def __init__( 6 | self, 7 | name: str, 8 | dataset: str, 9 | window_size: int = 10, 10 | rounding: int = 4, 11 | ): 12 | """ 13 | Args: 14 | name (str): Name of the metric 15 | dataset (str): Name of the dataset 16 | window_size (int, optional): Size of the moving average window. Defaults to 10. 17 | rounding (int, optional): Number of decimals to round to. Defaults to 4. 18 | """ 19 | self.data = [] 20 | self.name = name 21 | self.dataset = dataset 22 | self.rounding = rounding 23 | self.window_size = window_size 24 | 25 | def display(self) -> str: 26 | latest = round(self.latest(), self.rounding) 27 | average = round(self.average(), self.rounding) 28 | moving_average = round(self.moving_average(), self.rounding) 29 | output = f"{self.full_name()}: {latest} ({self.window_size}pt moving_avg: {moving_average}) (avg: {average})" 30 | return output 31 | 32 | def log_data(self, data: float): 33 | if not np.isnan(data): 34 | self.data.append(data) 35 | 36 | def full_name(self) -> str: 37 | return f"{self.name} ({self.dataset})" 38 | 39 | def latest(self) -> float: 40 | if len(self.data) > 0: 41 | return self.data[-1] 42 | return np.nan 43 | 44 | def average(self) -> float: 45 | if len(self.data) > 0: 46 | return np.mean(self.data) 47 | return np.nan 48 | 49 | def moving_average(self) -> float: 50 | if len(self.data) > self.window_size: 51 | return np.mean(self.data[-self.window_size :]) 52 | return self.average() -------------------------------------------------------------------------------- /train/vint_train/data/data_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | # global params for diffusion model 3 | # normalized min and max 4 | action_stats: 5 | min: [-2.5, -4] # [min_dx, min_dy] 6 | max: [5, 4] # [max_dx, max_dy] 7 | 8 | # data specific params 9 | recon: 10 | metric_waypoint_spacing: 0.25 # average spacing between waypoints (meters) 11 | 12 | # OPTIONAL (FOR VISUALIZATION ONLY) 13 | camera_metrics: # https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html 14 | camera_height: 0.95 # meters 15 | camera_x_offset: 0.45 # distance between the center of the robot and the forward facing camera 16 | camera_matrix: 17 | fx: 272.547000 18 | fy: 266.358000 19 | cx: 320.000000 20 | cy: 220.000000 21 | dist_coeffs: 22 | k1: -0.038483 23 | k2: -0.010456 24 | p1: 0.003930 25 | p2: -0.001007 26 | k3: 0.0 27 | 28 | recon_test: 29 | metric_waypoint_spacing: 0.25 # average spacing between waypoints (meters) 30 | 31 | # OPTIONAL (FOR VISUALIZATION ONLY) 32 | camera_metrics: # https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html 33 | camera_height: 0.95 # meters 34 | camera_x_offset: 0.45 # distance between the center of the robot and the forward facing camera 35 | 
camera_matrix: 36 | fx: 272.547000 37 | fy: 266.358000 38 | cx: 320.000000 39 | cy: 220.000000 40 | dist_coeffs: 41 | k1: -0.038483 42 | k2: -0.010456 43 | p1: 0.003930 44 | p2: -0.001007 45 | k3: 0.0 46 | 47 | scand: 48 | metric_waypoint_spacing: 0.38 49 | 50 | tartan_drive: 51 | metric_waypoint_spacing: 0.72 52 | 53 | go_stanford: 54 | metric_waypoint_spacing: 0.12 55 | 56 | # private datasets: 57 | cory_hall: 58 | metric_waypoint_spacing: 0.06 59 | 60 | seattle: 61 | metric_waypoint_spacing: 0.35 62 | 63 | racer: 64 | metric_waypoint_spacing: 0.38 65 | 66 | carla_intvns: 67 | metric_waypoint_spacing: 1.39 68 | 69 | carla_cil: 70 | metric_waypoint_spacing: 1.27 71 | 72 | carla_intvns: 73 | metric_waypoint_spacing: 1.39 74 | 75 | carla: 76 | metric_waypoint_spacing: 1.59 77 | image_path_func: get_image_path 78 | 79 | sacson: 80 | metric_waypoint_spacing: 0.255 81 | 82 | # add your own dataset params here: 83 | -------------------------------------------------------------------------------- /train/vint_train/models/navibridge/self_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class PositionalEncoding(nn.Module): 7 | def __init__(self, d_model, max_seq_len=6): 8 | super().__init__() 9 | 10 | # Compute the positional encoding once 11 | pos_enc = torch.zeros(max_seq_len, d_model) 12 | pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) 13 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 14 | pos_enc[:, 0::2] = torch.sin(pos * div_term) 15 | pos_enc[:, 1::2] = torch.cos(pos * div_term) 16 | pos_enc = pos_enc.unsqueeze(0) 17 | 18 | # Register the positional encoding as a buffer to avoid it being 19 | # considered a parameter when saving the model 20 | self.register_buffer('pos_enc', pos_enc) 21 | 22 | def forward(self, x): 23 | # Add the positional encoding to the input 24 | x = x + self.pos_enc[:, :x.size(1), :] 25 | return x 26 | 27 | class MultiLayerDecoder(nn.Module): 28 | def __init__(self, embed_dim=512, seq_len=6, output_layers=[256, 128, 64], nhead=8, num_layers=8, ff_dim_factor=4): 29 | super(MultiLayerDecoder, self).__init__() 30 | self.positional_encoding = PositionalEncoding(embed_dim, max_seq_len=seq_len) 31 | self.sa_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=ff_dim_factor*embed_dim, activation="gelu", batch_first=True, norm_first=True) 32 | self.sa_decoder = nn.TransformerEncoder(self.sa_layer, num_layers=num_layers) 33 | self.output_layers = nn.ModuleList([nn.Linear(seq_len*embed_dim, embed_dim)]) 34 | self.output_layers.append(nn.Linear(embed_dim, output_layers[0])) 35 | for i in range(len(output_layers)-1): 36 | self.output_layers.append(nn.Linear(output_layers[i], output_layers[i+1])) 37 | 38 | def forward(self, x): 39 | if self.positional_encoding: x = self.positional_encoding(x) 40 | x = self.sa_decoder(x) 41 | # currently, x is [batch_size, seq_len, embed_dim] 42 | x = x.reshape(x.shape[0], -1) 43 | for i in range(len(self.output_layers)): 44 | x = self.output_layers[i](x) 45 | x = F.relu(x) 46 | return x 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | train/logs/* 2 | train/datasets/* 3 | train/vint_train/data/data_splits/* 4 | train/wandb/* 5 | train/gnm_dataset/* 6 | train/models 7 | 
*_test.yaml 8 | 9 | *.png 10 | *.jpg 11 | *.pth 12 | *.mp4 13 | *.gif 14 | 15 | deployment/model_weights/* 16 | deployment/topomaps/* 17 | 18 | .vscode/* 19 | */.vscode/* 20 | 21 | 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | .Python 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | wheels/ 44 | pip-wheel-metadata/ 45 | share/python-wheels/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | MANIFEST 50 | 51 | # PyInstaller 52 | # Usually these files are written by a python script from a template 53 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 54 | *.manifest 55 | *.spec 56 | 57 | # Installer logs 58 | pip-log.txt 59 | pip-delete-this-directory.txt 60 | 61 | # Unit test / coverage reports 62 | htmlcov/ 63 | .tox/ 64 | .nox/ 65 | .coverage 66 | .coverage.* 67 | .cache 68 | nosetests.xml 69 | coverage.xml 70 | *.cover 71 | *.py,cover 72 | .hypothesis/ 73 | .pytest_cache/ 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Django stuff: 80 | *.log 81 | local_settings.py 82 | db.sqlite3 83 | db.sqlite3-journal 84 | 85 | # Flask stuff: 86 | instance/ 87 | .webassets-cache 88 | 89 | # Scrapy stuff: 90 | .scrapy 91 | 92 | # Sphinx documentation 93 | docs/_build/ 94 | 95 | # PyBuilder 96 | target/ 97 | 98 | # Jupyter Notebook 99 | .ipynb_checkpoints 100 | 101 | # IPython 102 | profile_default/ 103 | ipython_config.py 104 | 105 | # pyenv 106 | .python-version 107 | 108 | # pipenv 109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 112 | # install all needed dependencies. 113 | #Pipfile.lock 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | -------------------------------------------------------------------------------- /train/process_recon.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | import pickle 4 | from PIL import Image 5 | import io 6 | import argparse 7 | import tqdm 8 | 9 | 10 | def main(args: argparse.Namespace): 11 | recon_dir = os.path.join(args.input_dir, "recon_release") 12 | output_dir = args.output_dir 13 | 14 | # create output dir if it doesn't exist 15 | if not os.path.exists(output_dir): 16 | os.makedirs(output_dir) 17 | 18 | # get all the folders in the recon dataset 19 | filenames = os.listdir(recon_dir) 20 | if args.num_trajs >= 0: 21 | filenames = filenames[: args.num_trajs] 22 | 23 | # processing loop 24 | for filename in tqdm.tqdm(filenames, desc="Trajectories processed"): 25 | # extract the name without the extension 26 | traj_name = filename.split(".")[0] 27 | # load the hdf5 file 28 | try: 29 | h5_f = h5py.File(os.path.join(recon_dir, filename), "r") 30 | except OSError: 31 | print(f"Error loading {filename}. Skipping...") 32 | continue 33 | # extract the position and yaw data 34 | position_data = h5_f["jackal"]["position"][:, :2] 35 | yaw_data = h5_f["jackal"]["yaw"][()] 36 | # save the data to a dictionary 37 | traj_data = {"position": position_data, "yaw": yaw_data} 38 | traj_folder = os.path.join(output_dir, traj_name) 39 | os.makedirs(traj_folder, exist_ok=True) 40 | with open(os.path.join(traj_folder, "traj_data.pkl"), "wb") as f: 41 | pickle.dump(traj_data, f) 42 | # make a folder for the file 43 | if not os.path.exists(traj_folder): 44 | os.makedirs(traj_folder) 45 | # save the image data to disk 46 | for i in range(h5_f["images"]["rgb_left"].shape[0]): 47 | img = Image.open(io.BytesIO(h5_f["images"]["rgb_left"][i])) 48 | img.save(os.path.join(traj_folder, f"{i}.jpg")) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | # get arguments for the recon input dir and the output dir 54 | parser.add_argument( 55 | "--input-dir", 56 | "-i", 57 | type=str, 58 | help="path of the recon_dataset", 59 | required=True, 60 | ) 61 | parser.add_argument( 62 | "--output-dir", 63 | "-o", 64 | default="datasets/recon/", 65 | type=str, 66 | help="path for processed recon dataset (default: datasets/recon/)", 67 | ) 68 | # number of trajs to process 69 | parser.add_argument( 70 | "--num-trajs", 71 | "-n", 72 | default=-1, 73 | type=int, 74 | help="number of trajectories to process (default: -1, all)", 75 | ) 76 | 77 | args = parser.parse_args() 78 | print("STARTING PROCESSING RECON DATASET") 79 | main(args) 80 | print("FINISHED PROCESSING RECON DATASET") 81 | -------------------------------------------------------------------------------- /train/data_split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import random 5 | 6 | 7 | 
def remove_files_in_dir(dir_path: str): 8 | for f in os.listdir(dir_path): 9 | file_path = os.path.join(dir_path, f) 10 | try: 11 | if os.path.isfile(file_path) or os.path.islink(file_path): 12 | os.unlink(file_path) 13 | elif os.path.isdir(file_path): 14 | shutil.rmtree(file_path) 15 | except Exception as e: 16 | print("Failed to delete %s. Reason: %s" % (file_path, e)) 17 | 18 | 19 | def main(args: argparse.Namespace): 20 | # Get the names of the folders in the data directory that contain the file 'traj_data.pkl' 21 | folder_names = [ 22 | f 23 | for f in os.listdir(args.data_dir) 24 | if os.path.isdir(os.path.join(args.data_dir, f)) 25 | and "traj_data.pkl" in os.listdir(os.path.join(args.data_dir, f)) 26 | ] 27 | 28 | # Randomly shuffle the names of the folders 29 | random.shuffle(folder_names) 30 | 31 | # Split the names of the folders into train and test sets 32 | split_index = int(args.split * len(folder_names)) 33 | train_folder_names = folder_names[:split_index] 34 | test_folder_names = folder_names[split_index:] 35 | 36 | # Create directories for the train and test sets 37 | train_dir = os.path.join(args.data_splits_dir, args.dataset_name, "train") 38 | test_dir = os.path.join(args.data_splits_dir, args.dataset_name, "test") 39 | for dir_path in [train_dir, test_dir]: 40 | if os.path.exists(dir_path): 41 | print(f"Clearing files from {dir_path} for new data split") 42 | remove_files_in_dir(dir_path) 43 | else: 44 | print(f"Creating {dir_path}") 45 | os.makedirs(dir_path) 46 | 47 | # Write the names of the train and test folders to files 48 | with open(os.path.join(train_dir, "traj_names.txt"), "w") as f: 49 | for folder_name in train_folder_names: 50 | f.write(folder_name + "\n") 51 | 52 | with open(os.path.join(test_dir, "traj_names.txt"), "w") as f: 53 | for folder_name in test_folder_names: 54 | f.write(folder_name + "\n") 55 | 56 | 57 | if __name__ == "__main__": 58 | # Set up the command line argument parser 59 | parser = argparse.ArgumentParser() 60 | 61 | parser.add_argument( 62 | "--data-dir", "-i", help="Directory containing the data", required=True 63 | ) 64 | parser.add_argument( 65 | "--dataset-name", "-d", help="Name of the dataset", required=True 66 | ) 67 | parser.add_argument( 68 | "--split", "-s", type=float, default=0.8, help="Train/test split (default: 0.8)" 69 | ) 70 | parser.add_argument( 71 | "--data-splits-dir", "-o", default="vint_train/data/data_splits", help="Data splits directory" 72 | ) 73 | args = parser.parse_args() 74 | main(args) 75 | print("Done") 76 | -------------------------------------------------------------------------------- /train/config/cvae.yaml: -------------------------------------------------------------------------------- 1 | project_name: cvae 2 | run_name: cvae 3 | 4 | # training setup 5 | use_wandb: True # set to false if you don't want to log to wandb 6 | train: True 7 | batch_size: 256 8 | epochs: 30 9 | gpu_ids: [0] 10 | num_workers: 12 11 | lr: 1e-4 12 | optimizer: adamw 13 | clipping: False 14 | max_norm: 1. 
15 | scheduler: "cosine" 16 | warmup: True 17 | warmup_epochs: 4 18 | cyclic_period: 10 19 | plateau_patience: 3 20 | plateau_factor: 0.5 21 | seed: 0 22 | save_freq: 1 23 | 24 | # model params 25 | model_type: cvae 26 | vision_encoder: navibridge_encoder 27 | encoding_size: 256 28 | obs_encoder: efficientnet-b0 29 | attn_unet: False 30 | cond_predict_scale: False 31 | mha_num_attention_heads: 4 32 | mha_num_attention_layers: 4 33 | mha_ff_dim_factor: 4 34 | down_dims: [64, 128, 256] 35 | 36 | # diffusion model params 37 | num_diffusion_iters: 10 38 | 39 | # mask 40 | goal_mask_prob: 0.5 41 | 42 | # normalization for the action space 43 | normalize: True 44 | 45 | # context 46 | context_type: temporal 47 | context_size: 3 # 5 48 | alpha: 1e-4 49 | 50 | # distance bounds for distance and action and distance predictions 51 | distance: 52 | min_dist_cat: 0 53 | max_dist_cat: 20 54 | action: 55 | min_dist_cat: 3 56 | max_dist_cat: 20 57 | 58 | # action output params 59 | len_traj_pred: 8 60 | action_dim: 2 61 | learn_angle: False 62 | 63 | # navibridge 64 | sampler_name: "uniform" 65 | pred_mode: "ve" 66 | weight_schedule: "karras" 67 | sigma_data: 0.5 68 | sigma_min: 0.002 69 | sigma_max: 80.0 70 | rho: 7.0 71 | beta_d: 2 72 | beta_min: 0.1 73 | cov_xy: 0. 74 | guidance: 1. 75 | # sample defaults 76 | clip_denoised: True 77 | sampler: "euler" 78 | churn_step_ratio: 0.33 79 | # prior settings 80 | prior_policy: "gaussian" # handcraft, gaussian, cvae 81 | class_num: 5 82 | 83 | angle_ranges: [[0, 67.5], 84 | [67.5, 112.5], 85 | [112.5, 180], 86 | [180, 270], 87 | [270, 360]] 88 | min_std_angle: 5.0 89 | max_std_angle: 20.0 90 | min_std_length: 1.0 91 | max_std_length: 5.0 92 | 93 | # cvae 94 | train_params: 95 | batch_size: 256 96 | num_itr: 3001 97 | lr: 0.5e-5 98 | lr_gamma: 0.99 99 | lr_step: 1000 100 | l2_norm: 0.0 101 | ema: 0.99 102 | 103 | 104 | diffuse_params: 105 | latent_dim: 64 106 | layer: 3 107 | net_type: vae_mlp 108 | ckpt_path: models/cvae.pth 109 | pretrain: False 110 | 111 | # dataset specific parameters 112 | image_size: [96, 96] # width, height 113 | datasets: 114 | recon: 115 | data_folder: ./datasets/recon 116 | train: ./datasets/data_splits/recon/train # path to train folder with traj_names.txt 117 | test: ./datasets/data_splits/recon/test # path to test folder with traj_names.txt 118 | end_slack: 3 # because many trajectories end in collisions 119 | goals_per_obs: 1 # how many goals are sampled per observation 120 | negative_mining: True # negative mining from the ViNG paper (Shah et al.) 
121 | go_stanford: 122 | data_folder: ./datasets/go_stanford/ # datasets/stanford_go_new 123 | train: ./datasets/data_splits/go_stanford/train/ 124 | test: ./datasets/data_splits/go_stanford/test/ 125 | end_slack: 0 126 | goals_per_obs: 2 # increase dataset size 127 | negative_mining: True 128 | 129 | sacson: 130 | data_folder: ./datasets/sacson/ 131 | train: ./datasets/data_splits/sacson/train/ 132 | test: ./datasets/data_splits/sacson/test/ 133 | end_slack: 3 # because many trajectories end in collisions 134 | goals_per_obs: 1 135 | negative_mining: True 136 | 137 | scand: 138 | data_folder: ./datasets/scand/ 139 | train: ./datasets/data_splits/scand/train/ 140 | test: ./datasets/data_splits/scand/test/ 141 | end_slack: 0 142 | goals_per_obs: 1 143 | negative_mining: True 144 | 145 | # logging stuff 146 | ## =0 turns off 147 | print_log_freq: 500 # in iterations 148 | image_log_freq: 1000 #0 # in iterations 149 | num_images_log: 8 #0 150 | pairwise_test_freq: 0 # in epochs 151 | eval_fraction: 0.25 152 | wandb_log_freq: 10 # in iterations 153 | eval_freq: 1 # in epochs -------------------------------------------------------------------------------- /train/config/navibridge.yaml: -------------------------------------------------------------------------------- 1 | project_name: navibridge 2 | run_name: navibridge 3 | 4 | # training setup 5 | use_wandb: True # set to false if you don't want to log to wandb 6 | train: True 7 | batch_size: 256 8 | epochs: 30 9 | gpu_ids: [0] 10 | num_workers: 12 11 | lr: 1e-4 12 | optimizer: adamw 13 | clipping: False 14 | max_norm: 1. 15 | scheduler: "cosine" 16 | warmup: True 17 | warmup_epochs: 4 18 | cyclic_period: 10 19 | plateau_patience: 3 20 | plateau_factor: 0.5 21 | seed: 0 22 | save_freq: 1 23 | 24 | # model params 25 | model_type: navibridge 26 | vision_encoder: navibridge_encoder 27 | encoding_size: 256 28 | obs_encoder: efficientnet-b0 29 | attn_unet: False 30 | cond_predict_scale: False 31 | mha_num_attention_heads: 4 32 | mha_num_attention_layers: 4 33 | mha_ff_dim_factor: 4 34 | down_dims: [64, 128, 256] 35 | 36 | # diffusion model params 37 | num_diffusion_iters: 10 38 | 39 | # mask 40 | goal_mask_prob: 0.5 41 | 42 | # normalization for the action space 43 | normalize: True 44 | 45 | # context 46 | context_type: temporal 47 | context_size: 3 # 5 48 | alpha: 1e-4 49 | 50 | # distance bounds for distance and action and distance predictions 51 | distance: 52 | min_dist_cat: 0 53 | max_dist_cat: 20 54 | action: 55 | min_dist_cat: 3 56 | max_dist_cat: 20 57 | 58 | # action output params 59 | len_traj_pred: 8 60 | action_dim: 2 61 | learn_angle: False 62 | 63 | # navibridge 64 | sampler_name: "uniform" 65 | pred_mode: "ve" 66 | weight_schedule: "karras" 67 | sigma_data: 0.5 68 | sigma_min: 0.002 69 | sigma_max: 10.0 70 | rho: 7.0 71 | beta_d: 2 72 | beta_min: 0.1 73 | cov_xy: 0. 74 | guidance: 1. 
75 | 76 | clip_denoised: True 77 | sampler: "euler" 78 | churn_step_ratio: 0.33 79 | # prior settings 80 | prior_policy: "gaussian" # handcraft, gaussian, cvae 81 | class_num: 5 82 | 83 | angle_ranges: [[0, 67.5], 84 | [67.5, 112.5], 85 | [112.5, 180], 86 | [180, 270], 87 | [270, 360]] 88 | min_std_angle: 5.0 89 | max_std_angle: 20.0 90 | min_std_length: 1.0 91 | max_std_length: 5.0 92 | 93 | # cvae 94 | train_params: 95 | batch_size: 256 96 | num_itr: 3001 97 | lr: 0.5e-5 98 | lr_gamma: 0.99 99 | lr_step: 1000 100 | l2_norm: 0.0 101 | ema: 0.99 102 | 103 | 104 | diffuse_params: 105 | latent_dim: 64 106 | layer: 3 107 | net_type: vae_mlp 108 | ckpt_path: models/cvae.pth 109 | pretrain: False 110 | 111 | # dataset specific parameters 112 | image_size: [96, 96] # width, height 113 | datasets: 114 | recon: 115 | data_folder: ./datasets/recon 116 | train: ./datasets/data_splits/recon/train # path to train folder with traj_names.txt 117 | test: ./datasets/data_splits/recon/test # path to test folder with traj_names.txt 118 | end_slack: 3 # because many trajectories end in collisions 119 | goals_per_obs: 1 # how many goals are sampled per observation 120 | negative_mining: True # negative mining from the ViNG paper (Shah et al.) 121 | 122 | go_stanford: 123 | data_folder: ./datasets/go_stanford/ # datasets/stanford_go_new 124 | train: ./datasets/data_splits/go_stanford/train/ 125 | test: ./datasets/data_splits/go_stanford/test/ 126 | end_slack: 0 127 | goals_per_obs: 2 # increase dataset size 128 | negative_mining: True 129 | 130 | sacson: 131 | data_folder: ./datasets/sacson/ 132 | train: ./datasets/data_splits/sacson/train/ 133 | test: ./datasets/data_splits/sacson/test/ 134 | end_slack: 3 # because many trajectories end in collisions 135 | goals_per_obs: 1 136 | negative_mining: True 137 | 138 | scand: 139 | data_folder: ./datasets/scand/ 140 | train: ./datasets/data_splits/scand/train/ 141 | test: ./datasets/data_splits/scand/test/ 142 | end_slack: 0 143 | goals_per_obs: 1 144 | negative_mining: True 145 | 146 | # logging stuff 147 | ## =0 turns off 148 | print_log_freq: 100 # in iterations 149 | image_log_freq: 1000 #0 # in iterations 150 | num_images_log: 8 #0 151 | pairwise_test_freq: 0 # in epochs 152 | eval_fraction: 0.25 153 | wandb_log_freq: 10 # in iterations 154 | eval_freq: 1 # in epochs -------------------------------------------------------------------------------- /train/vint_train/models/navibridge/ddbm/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from .karras_diffusion import KarrasDenoiser 4 | import numpy as np 5 | 6 | 7 | def get_workdir(exp): 8 | workdir = f'./workdir/{exp}' 9 | return workdir 10 | 11 | def cm_train_defaults(): 12 | return dict( 13 | teacher_model_path="", 14 | teacher_dropout=0.1, 15 | training_mode="consistency_distillation", 16 | target_ema_mode="fixed", 17 | scale_mode="fixed", 18 | total_training_steps=600000, 19 | start_ema=0.0, 20 | start_scales=40, 21 | end_scales=40, 22 | distill_steps_per_iter=50000, 23 | loss_norm="lpips", 24 | ) 25 | 26 | def sample_defaults(): 27 | return dict( 28 | generator="determ", 29 | clip_denoised=True, 30 | sampler="euler", 31 | s_churn=0.0, 32 | s_tmin=0.002, 33 | s_tmax=80, 34 | s_noise=1.0, 35 | steps=40, 36 | model_path="", 37 | seed=42, 38 | ts="", 39 | ) 40 | 41 | def create_ema_and_scales_fn( 42 | target_ema_mode, 43 | start_ema, 44 | scale_mode, 45 | start_scales, 46 | end_scales, 47 | total_steps, 48 | distill_steps_per_iter, 49 | 
): 50 | def ema_and_scales_fn(step): 51 | if target_ema_mode == "fixed" and scale_mode == "fixed": 52 | target_ema = start_ema 53 | scales = start_scales 54 | elif target_ema_mode == "fixed" and scale_mode == "progressive": 55 | target_ema = start_ema 56 | scales = np.ceil( 57 | np.sqrt( 58 | (step / total_steps) * ((end_scales + 1) ** 2 - start_scales**2) 59 | + start_scales**2 60 | ) 61 | - 1 62 | ).astype(np.int32) 63 | scales = np.maximum(scales, 1) 64 | scales = scales + 1 65 | 66 | elif target_ema_mode == "adaptive" and scale_mode == "progressive": 67 | scales = np.ceil( 68 | np.sqrt( 69 | (step / total_steps) * ((end_scales + 1) ** 2 - start_scales**2) 70 | + start_scales**2 71 | ) 72 | - 1 73 | ).astype(np.int32) 74 | scales = np.maximum(scales, 1) 75 | c = -np.log(start_ema) * start_scales 76 | target_ema = np.exp(-c / scales) 77 | scales = scales + 1 78 | elif target_ema_mode == "fixed" and scale_mode == "progdist": 79 | distill_stage = step // distill_steps_per_iter 80 | scales = start_scales // (2**distill_stage) 81 | scales = np.maximum(scales, 2) 82 | 83 | sub_stage = np.maximum( 84 | step - distill_steps_per_iter * (np.log2(start_scales) - 1), 85 | 0, 86 | ) 87 | sub_stage = sub_stage // (distill_steps_per_iter * 2) 88 | sub_scales = 2 // (2**sub_stage) 89 | sub_scales = np.maximum(sub_scales, 1) 90 | 91 | scales = np.where(scales == 2, sub_scales, scales) 92 | 93 | target_ema = 1.0 94 | else: 95 | raise NotImplementedError 96 | 97 | return float(target_ema), int(scales) 98 | 99 | return ema_and_scales_fn 100 | 101 | 102 | def add_dict_to_argparser(parser, default_dict): 103 | for k, v in default_dict.items(): 104 | v_type = type(v) 105 | if v is None: 106 | v_type = str 107 | elif isinstance(v, bool): 108 | v_type = str2bool 109 | parser.add_argument(f"--{k}", default=v, type=v_type) 110 | 111 | 112 | def args_to_dict(args, keys): 113 | return {k: getattr(args, k) for k in keys} 114 | 115 | 116 | def str2bool(v): 117 | """ 118 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 119 | """ 120 | if isinstance(v, bool): 121 | return v 122 | if v.lower() in ("yes", "true", "t", "y", "1"): 123 | return True 124 | elif v.lower() in ("no", "false", "f", "n", "0"): 125 | return False 126 | else: 127 | raise argparse.ArgumentTypeError("boolean value expected") 128 | -------------------------------------------------------------------------------- /train/process_bags.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pickle 4 | from PIL import Image 5 | import io 6 | import argparse 7 | import tqdm 8 | import yaml 9 | import rosbag 10 | 11 | # utils 12 | from vint_train.process_data.process_data_utils import * 13 | 14 | 15 | def main(args: argparse.Namespace): 16 | 17 | # load the config file 18 | with open("vint_train/process_data/process_bags_config.yaml", "r") as f: 19 | config = yaml.load(f, Loader=yaml.FullLoader) 20 | 21 | # create output dir if it doesn't exist 22 | if not os.path.exists(args.output_dir): 23 | os.makedirs(args.output_dir) 24 | 25 | # iterate recurisively through all the folders and get the path of files with .bag extension in the args.input_dir 26 | bag_files = [] 27 | for root, dirs, files in os.walk(args.input_dir): 28 | for file in files: 29 | if file.endswith(".bag"): 30 | bag_files.append(os.path.join(root, file)) 31 | if args.num_trajs >= 0: 32 | bag_files = bag_files[: args.num_trajs] 33 | 34 | # processing loop 35 | for bag_path in tqdm.tqdm(bag_files, 
desc="Bags processed"): 36 | try: 37 | b = rosbag.Bag(bag_path) 38 | except rosbag.ROSBagException as e: 39 | print(e) 40 | print(f"Error loading {bag_path}. Skipping...") 41 | continue 42 | 43 | # name is that folders separated by _ and then the last part of the path 44 | traj_name = "_".join(bag_path.split("/")[-2:])[:-4] 45 | 46 | # load the hdf5 file 47 | bag_img_data, bag_traj_data = get_images_and_odom( 48 | b, 49 | config[args.dataset_name]["imtopics"], 50 | config[args.dataset_name]["odomtopics"], 51 | eval(config[args.dataset_name]["img_process_func"]), 52 | eval(config[args.dataset_name]["odom_process_func"]), 53 | rate=args.sample_rate, 54 | ang_offset=config[args.dataset_name]["ang_offset"], 55 | ) 56 | 57 | 58 | if bag_img_data is None or bag_traj_data is None: 59 | print( 60 | f"{bag_path} did not have the topics we were looking for. Skipping..." 61 | ) 62 | continue 63 | # remove backwards movement 64 | cut_trajs = filter_backwards(bag_img_data, bag_traj_data) 65 | 66 | for i, (img_data_i, traj_data_i) in enumerate(cut_trajs): 67 | traj_name_i = traj_name + f"_{i}" 68 | traj_folder_i = os.path.join(args.output_dir, traj_name_i) 69 | # make a folder for the traj 70 | if not os.path.exists(traj_folder_i): 71 | os.makedirs(traj_folder_i) 72 | with open(os.path.join(traj_folder_i, "traj_data.pkl"), "wb") as f: 73 | pickle.dump(traj_data_i, f) 74 | # save the image data to disk 75 | for i, img in enumerate(img_data_i): 76 | img.save(os.path.join(traj_folder_i, f"{i}.jpg")) 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | # get arguments for the recon input dir and the output dir 82 | # add dataset name 83 | parser.add_argument( 84 | "--dataset-name", 85 | "-d", 86 | type=str, 87 | help="name of the dataset (must be in process_config.yaml)", 88 | default="tartan_drive", 89 | required=True, 90 | ) 91 | parser.add_argument( 92 | "--input-dir", 93 | "-i", 94 | type=str, 95 | help="path of the datasets with rosbags", 96 | required=True, 97 | ) 98 | parser.add_argument( 99 | "--output-dir", 100 | "-o", 101 | default="../datasets/tartan_drive/", 102 | type=str, 103 | help="path for processed dataset (default: ../datasets/tartan_drive/)", 104 | ) 105 | # number of trajs to process 106 | parser.add_argument( 107 | "--num-trajs", 108 | "-n", 109 | default=-1, 110 | type=int, 111 | help="number of bags to process (default: -1, all)", 112 | ) 113 | # sampling rate 114 | parser.add_argument( 115 | "--sample-rate", 116 | "-s", 117 | default=4.0, 118 | type=float, 119 | help="sampling rate (default: 4.0 hz)", 120 | ) 121 | 122 | args = parser.parse_args() 123 | # all caps for the dataset name 124 | print(f"STARTING PROCESSING {args.dataset_name.upper()} DATASET") 125 | main(args) 126 | print(f"FINISHED PROCESSING {args.dataset_name.upper()} DATASET") 127 | -------------------------------------------------------------------------------- /train/vint_train/data/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | from typing import Any, Iterable, Tuple 5 | 6 | import torch 7 | from torchvision import transforms 8 | import torchvision.transforms.functional as TF 9 | import torch.nn.functional as F 10 | import io 11 | from typing import Union 12 | 13 | VISUALIZATION_IMAGE_SIZE = (160, 120) 14 | IMAGE_ASPECT_RATIO = ( 15 | 4 / 3 16 | ) # all images are centered cropped to a 4:3 aspect ratio in training 17 | 18 | 19 | 20 | def get_data_path(data_folder: 
str, f: str, time: int, data_type: str = "image"): 21 | data_ext = { 22 | "image": ".jpg", 23 | # add more data types here 24 | } 25 | return os.path.join(data_folder, f, f"{str(time)}{data_ext[data_type]}") 26 | 27 | 28 | def yaw_rotmat(yaw: float) -> np.ndarray: 29 | return np.array( 30 | [ 31 | [np.cos(yaw), -np.sin(yaw), 0.0], 32 | [np.sin(yaw), np.cos(yaw), 0.0], 33 | [0.0, 0.0, 1.0], 34 | ], 35 | ) 36 | 37 | 38 | def to_local_coords( 39 | positions: np.ndarray, curr_pos: np.ndarray, curr_yaw: float 40 | ) -> np.ndarray: 41 | """ 42 | Convert positions to local coordinates 43 | 44 | Args: 45 | positions (np.ndarray): positions to convert 46 | curr_pos (np.ndarray): current position 47 | curr_yaw (float): current yaw 48 | Returns: 49 | np.ndarray: positions in local coordinates 50 | """ 51 | rotmat = yaw_rotmat(curr_yaw) 52 | if positions.shape[-1] == 2: 53 | rotmat = rotmat[:2, :2] 54 | elif positions.shape[-1] == 3: 55 | pass 56 | else: 57 | raise ValueError 58 | 59 | return (positions - curr_pos).dot(rotmat) 60 | 61 | 62 | def calculate_deltas(waypoints: torch.Tensor) -> torch.Tensor: 63 | """ 64 | Calculate deltas between waypoints 65 | 66 | Args: 67 | waypoints (torch.Tensor): waypoints 68 | Returns: 69 | torch.Tensor: deltas 70 | """ 71 | num_params = waypoints.shape[1] 72 | origin = torch.zeros(1, num_params) 73 | prev_waypoints = torch.concat((origin, waypoints[:-1]), axis=0) 74 | deltas = waypoints - prev_waypoints 75 | if num_params > 2: 76 | return calculate_sin_cos(deltas) 77 | return deltas 78 | 79 | 80 | def calculate_sin_cos(waypoints: torch.Tensor) -> torch.Tensor: 81 | """ 82 | Calculate sin and cos of the angle 83 | 84 | Args: 85 | waypoints (torch.Tensor): waypoints 86 | Returns: 87 | torch.Tensor: waypoints with sin and cos of the angle 88 | """ 89 | assert waypoints.shape[1] == 3 90 | angle_repr = torch.zeros_like(waypoints[:, :2]) 91 | angle_repr[:, 0] = torch.cos(waypoints[:, 2]) 92 | angle_repr[:, 1] = torch.sin(waypoints[:, 2]) 93 | return torch.concat((waypoints[:, :2], angle_repr), axis=1) 94 | 95 | 96 | def transform_images( 97 | img: Image.Image, transform: transforms, image_resize_size: Tuple[int, int], aspect_ratio: float = IMAGE_ASPECT_RATIO 98 | ): 99 | w, h = img.size 100 | if w > h: 101 | img = TF.center_crop(img, (h, int(h * aspect_ratio))) # crop to the right ratio 102 | else: 103 | img = TF.center_crop(img, (int(w / aspect_ratio), w)) 104 | viz_img = img.resize(VISUALIZATION_IMAGE_SIZE) 105 | viz_img = TF.to_tensor(viz_img) 106 | img = img.resize(image_resize_size) 107 | transf_img = transform(img) 108 | return viz_img, transf_img 109 | 110 | 111 | def resize_and_aspect_crop( 112 | img: Image.Image, image_resize_size: Tuple[int, int], aspect_ratio: float = IMAGE_ASPECT_RATIO 113 | ): 114 | w, h = img.size 115 | if w > h: 116 | img = TF.center_crop(img, (h, int(h * aspect_ratio))) # crop to the right ratio 117 | else: 118 | img = TF.center_crop(img, (int(w / aspect_ratio), w)) 119 | try: 120 | img = img.resize(image_resize_size) 121 | except: 122 | print("111111111111 ", image_resize_size) 123 | resize_img = TF.to_tensor(img) 124 | return resize_img 125 | 126 | 127 | def img_path_to_data(path: Union[str, io.BytesIO], image_resize_size: Tuple[int, int]) -> torch.Tensor: 128 | """ 129 | Load an image from a path and transform it 130 | Args: 131 | path (str): path to the image 132 | image_resize_size (Tuple[int, int]): size to resize the image to 133 | Returns: 134 | torch.Tensor: resized image as tensor 135 | """ 136 | # return 
transform_images(Image.open(path), transform, image_resize_size, aspect_ratio) 137 | return resize_and_aspect_crop(Image.open(path), image_resize_size) 138 | 139 | -------------------------------------------------------------------------------- /train/vint_train/models/navibridge/vae/vae.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # This work is licensed under the NVIDIA Source Code License 5 | # for I2SB. To view a copy of this license, see the LICENSE file. 6 | # --------------------------------------------------------------- 7 | 8 | import numpy as np 9 | from tqdm import tqdm 10 | from functools import partial 11 | import torch 12 | 13 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 14 | 15 | from torch_ema import ExponentialMovingAverage 16 | from vint_train.models.navibridge.vae.conditional_mlp_1D_vae import * 17 | import os 18 | 19 | def kl_divergence_normal(mean1, mean2, std1, std2): 20 | kl_loss = ((std2 + 1e-9).log() - (std1 + 1e-9).log() 21 | + (std1.pow(2) + (mean2 - mean1).pow(2)) 22 | / (2 * std2.pow(2) + 1e-9) - 0.5).sum(-1).mean() 23 | return kl_loss 24 | 25 | 26 | class VAEModel(): 27 | def __init__(self, model_args): 28 | self.net = None 29 | self.ema = None 30 | 31 | self.anneal_factor = 0.0 32 | self.prior_policy = 'gaussian' 33 | 34 | def sample_prior(self, cond, device, num_samples=None, x_prior=None, diffuse_step=None): 35 | num_samples = cond.shape[0] 36 | latent_sample = torch.randn((num_samples, self.net.latent_dim)).to(device) 37 | action_hat = self.net.decoder(torch.cat([cond.flatten(start_dim=1), latent_sample], dim=-1)) 38 | action_hat = action_hat.reshape(-1, self.net.len_traj_pred, self.net.action_dim) 39 | return action_hat 40 | 41 | def get_loss(self, obs_img, naction, device): 42 | nobs = obs_img.to(device).float().flatten(start_dim=1) 43 | naction = naction.to(device).float() 44 | 45 | latent_post_dist = self.net.encoder(torch.cat([nobs, naction.flatten(1)], dim=-1)) 46 | latent_post_rsample = latent_post_dist.rsample() 47 | latent_post_mean = latent_post_dist.mean 48 | latent_post_std = latent_post_dist.stddev 49 | 50 | latent_prior_mean = torch.zeros_like(latent_post_mean).float().to(device) 51 | latent_prior_std = torch.ones_like(latent_post_std).float().to(device) 52 | 53 | action_rec = self.net.decoder(torch.cat([nobs, latent_post_rsample], dim=-1)) 54 | rec_loss = torch.nn.functional.mse_loss(action_rec, naction.flatten(1)) * 10.0 55 | kl_loss = self.anneal_factor * kl_divergence_normal(latent_post_mean, latent_prior_mean, latent_post_std, latent_prior_std) 56 | 57 | self.anneal_factor += 0.0001 58 | self.anneal_factor = 0.1 if self.anneal_factor > 0.1 else self.anneal_factor 59 | 60 | loss = rec_loss + kl_loss 61 | loss_info = { 62 | 'loss': loss, 63 | 'rec_loss': rec_loss, 64 | 'kl_loss': kl_loss, 65 | } 66 | return loss, loss_info 67 | 68 | def log_info(self, writer, log, loss_info, optimizer, itr, num_itr): 69 | writer.add_scalar(itr, 'loss', loss_info['loss'].detach()) 70 | 71 | log.info("train_it {}/{} | lr:{} | loss:{}".format( 72 | 1 + itr, 73 | num_itr, 74 | "{:.2e}".format(optimizer.param_groups[0]['lr']), 75 | "{:+.2f}".format(loss_info['loss'].item()), 76 | )) 77 | 78 | def load_model(self, model_args, device): 79 | 80 | if model_args['net_type'] == 'vae_mlp': 81 | self.net = VAEConditionalMLP( 82 | action_dim=model_args['action_dim'], 
83 | len_traj_pred=model_args['len_traj_pred'], 84 | global_cond_dim=3 * model_args['image_size'][0] * model_args['image_size'][1] * (model_args['context_size'] + 1), 85 | latent_dim=model_args['latent_dim'], 86 | layer=model_args['layer'] 87 | ) 88 | else: 89 | raise NotImplementedError 90 | 91 | self.ema = ExponentialMovingAverage(self.net.parameters(), decay=0.99) 92 | if model_args['pretrain']: 93 | checkpoint = torch.load(model_args['ckpt_path'], map_location="cpu") 94 | self.net.load_state_dict(checkpoint['net']) 95 | self.ema.load_state_dict(checkpoint["ema"]) 96 | 97 | self.net.to(device) 98 | self.ema.to(device) 99 | 100 | def save_model(self, ckpt_path, epoch): 101 | torch.save({ 102 | "net": self.net.state_dict(), 103 | "ema": self.ema.state_dict(), 104 | }, ckpt_path) 105 | print(f"Saved model to {ckpt_path}") 106 | 107 | 108 | def unsqueeze_xdim(z, xdim): 109 | bc_dim = (...,) + (None,) * len(xdim) 110 | return z[bc_dim] -------------------------------------------------------------------------------- /train/vint_train/models/navibridge/vae/conditional_mlp_1D_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 5 | import numpy as np 6 | 7 | 8 | class SimpleMLP(nn.Module): 9 | def __init__(self, input_dim, hidden_dim, 10 | output_dim, layers: int, activation=nn.ELU): 11 | super().__init__() 12 | self._output_shape = (output_dim,) 13 | self._layers = layers 14 | self._hidden_size = hidden_dim 15 | self.activation = activation 16 | # For adjusting pytorch to tensorflow 17 | self._feature_size = input_dim 18 | # Defining the structure of the NN 19 | self.model = self.build_model() 20 | self.soft_plus = nn.Softplus() 21 | 22 | self._min = 0.01 23 | self._max = 10.0 24 | 25 | def build_model(self): 26 | model = [nn.Linear(self._feature_size, self._hidden_size)] 27 | model += [self.activation()] 28 | for i in range(self._layers - 1): 29 | model += [nn.Linear(self._hidden_size, self._hidden_size)] 30 | model += [self.activation()] 31 | model += [nn.Linear(self._hidden_size, int(np.prod(self._output_shape)))] 32 | return nn.Sequential(*model) 33 | 34 | def forward(self, features): 35 | shape_len = len(features.shape) 36 | if shape_len == 3: 37 | batch = features.shape[1] 38 | length = features.shape[0] 39 | features = features.reshape(-1, features.shape[-1]) 40 | outputs = self.model(features) 41 | 42 | if shape_len == 3: 43 | outputs = outputs.reshape(length, batch, -1) 44 | return outputs 45 | 46 | 47 | class NormalMLP(nn.Module): 48 | def __init__(self, input_dim, hidden_dim, 49 | output_dim, layers: int, activation=nn.ELU): 50 | super().__init__() 51 | self._output_shape = (output_dim,) 52 | self._layers = layers 53 | self._hidden_size = hidden_dim 54 | self.activation = activation 55 | # For adjusting pytorch to tensorflow 56 | self._feature_size = input_dim 57 | # Defining the structure of the NN 58 | self.model = self.build_model() 59 | self.soft_plus = nn.Softplus() 60 | 61 | self._min = 0.01 62 | self._max = 10.0 63 | 64 | def build_model(self): 65 | model = [nn.Linear(self._feature_size, self._hidden_size)] 66 | model += [self.activation()] 67 | for i in range(self._layers - 1): 68 | model += [nn.Linear(self._hidden_size, self._hidden_size)] 69 | model += [self.activation()] 70 | model += [nn.Linear(self._hidden_size, 2 * int(np.prod(self._output_shape)))] 71 | return nn.Sequential(*model) 72 | 73 | def forward(self, 
features): 74 | shape_len = len(features.shape) 75 | if shape_len == 3: 76 | batch = features.shape[1] 77 | length = features.shape[0] 78 | features = features.reshape(-1, features.shape[-1]) 79 | dist_inputs = self.model(features) 80 | reshaped_inputs_mean = dist_inputs[..., :np.prod(self._output_shape)] 81 | reshaped_inputs_std = dist_inputs[..., np.prod(self._output_shape):] 82 | 83 | reshaped_inputs_std = torch.clamp(self.soft_plus(reshaped_inputs_std), min=self._min, max=self._max) 84 | 85 | if shape_len == 3: 86 | reshaped_inputs_mean = reshaped_inputs_mean.reshape(length, batch, -1) 87 | reshaped_inputs_std = reshaped_inputs_std.reshape(length, batch, -1) 88 | return torch.distributions.independent.Independent( 89 | torch.distributions.Normal(reshaped_inputs_mean, reshaped_inputs_std), len(self._output_shape)) 90 | 91 | 92 | class VAEConditionalMLP(nn.Module): 93 | def __init__(self, 94 | action_dim, 95 | len_traj_pred, 96 | global_cond_dim, 97 | hidden_dim=512, 98 | latent_dim=64, 99 | layer=3 100 | ): 101 | """ 102 | input_dim: Dim of actions. 103 | global_cond_dim: Dim of global conditioning applied with FiLM 104 | in addition to diffusion step embedding. This is usually obs_horizon * obs_dim 105 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k 106 | down_dims: Channel size for each UNet level. 107 | The length of this array determines numebr of levels. 108 | kernel_size: Conv kernel size 109 | n_groups: Number of groups for GroupNorm 110 | """ 111 | 112 | super().__init__() 113 | input_dim = action_dim * len_traj_pred 114 | self.encoder = NormalMLP(input_dim=input_dim + global_cond_dim, hidden_dim=hidden_dim, 115 | output_dim=latent_dim, layers=layer) 116 | 117 | self.decoder = SimpleMLP(input_dim=latent_dim + global_cond_dim, hidden_dim=hidden_dim, 118 | output_dim=input_dim, layers=layer) 119 | 120 | self.latent_dim = latent_dim 121 | self.action_dim = action_dim 122 | self.len_traj_pred = len_traj_pred 123 | 124 | print("number of parameters: {:e}".format( 125 | sum(p.numel() for p in self.parameters())) 126 | ) 127 | -------------------------------------------------------------------------------- /train/vint_train/models/navibridge/ddbm/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | from scipy.stats import norm 6 | import torch.distributed as dist 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | if name == "uniform": 10 | return UniformSampler(diffusion) 11 | elif name == "real-uniform": 12 | return RealUniformSampler(diffusion) 13 | elif name == "loss-second-moment": 14 | return LossSecondMomentResampler(diffusion) 15 | elif name == "lognormal": 16 | return LogNormalSampler(diffusion) 17 | else: 18 | raise NotImplementedError(f"unknown schedule sampler: {name}") 19 | 20 | class ScheduleSampler(ABC): 21 | @abstractmethod 22 | def weights(self): 23 | """ 24 | """ 25 | 26 | def sample(self, batch_size, device): 27 | w = self.weights() 28 | p = w / np.sum(w) 29 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 30 | indices = th.from_numpy(indices_np).long().to(device) 31 | weights_np = 1 / (len(p) * p[indices_np]) 32 | weights = th.from_numpy(weights_np).float().to(device) 33 | return indices, weights 34 | 35 | class UniformSampler(ScheduleSampler): 36 | def __init__(self, diffusion): 37 | self.diffusion = diffusion 38 | self._weights = np.ones([diffusion.num_timesteps]) 39 | 
40 |     def weights(self):
41 |         return self._weights
42 | 
43 | class RealUniformSampler:
44 |     def __init__(self, diffusion):
45 |         self.diffusion = diffusion
46 |         self.sigma_max = diffusion.sigma_max
47 |         self.sigma_min = diffusion.sigma_min
48 | 
49 |     def sample(self, batch_size, device):
50 |         ts = th.rand(batch_size).to(device) * (self.sigma_max - self.sigma_min) + self.sigma_min
51 |         return ts, th.ones_like(ts)
52 | 
53 | class LossAwareSampler(ScheduleSampler):
54 |     def update_with_local_losses(self, local_ts, local_losses):
55 |         batch_sizes = [
56 |             th.tensor([0], dtype=th.int32, device=local_ts.device)
57 |             for _ in range(dist.get_world_size())
58 |         ]
59 |         dist.all_gather(
60 |             batch_sizes,
61 |             th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
62 |         )
63 | 
64 |         batch_sizes = [x.item() for x in batch_sizes]
65 |         max_bs = max(batch_sizes)
66 | 
67 |         timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
68 |         loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
69 |         dist.all_gather(timestep_batches, local_ts)
70 |         dist.all_gather(loss_batches, local_losses)
71 |         timesteps = [
72 |             x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
73 |         ]
74 |         losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
75 |         self.update_with_all_losses(timesteps, losses)
76 | 
77 |     @abstractmethod
78 |     def update_with_all_losses(self, ts, losses):
79 |         """Update the reweighting with a full batch of (timestep, loss) pairs.
80 |         """
81 | 
82 | class LossSecondMomentResampler(LossAwareSampler):
83 |     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
84 |         self.diffusion = diffusion
85 |         self.history_per_term = history_per_term
86 |         self.uniform_prob = uniform_prob
87 |         self._loss_history = np.zeros(
88 |             [diffusion.num_timesteps, history_per_term], dtype=np.float64
89 |         )
90 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in recent NumPy
91 | 
92 |     def weights(self):
93 |         if not self._warmed_up():
94 |             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
95 |         weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
96 |         weights /= np.sum(weights)
97 |         weights *= 1 - self.uniform_prob
98 |         weights += self.uniform_prob / len(weights)
99 |         return weights
100 | 
101 |     def update_with_all_losses(self, ts, losses):
102 |         for t, loss in zip(ts, losses):
103 |             if self._loss_counts[t] == self.history_per_term:
104 |                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
105 |                 self._loss_history[t, -1] = loss
106 |             else:
107 |                 self._loss_history[t, self._loss_counts[t]] = loss
108 |                 self._loss_counts[t] += 1
109 | 
110 |     def _warmed_up(self):
111 |         return (self._loss_counts == self.history_per_term).all()
112 | 
113 | class LogNormalSampler:
114 |     def __init__(self, diffusion, p_mean=-1.2, p_std=1.2, even=False):
115 |         self.p_mean = p_mean
116 |         self.p_std = p_std
117 |         self.even = even
118 |         if self.even:
119 |             self.inv_cdf = lambda x: norm.ppf(x, loc=p_mean, scale=p_std)
120 |             self.rank, self.size = dist.get_rank(), dist.get_world_size()
121 | 
122 |     def sample(self, bs, device):
123 |         if self.even:
124 |             start_i, end_i = self.rank * bs, (self.rank + 1) * bs
125 |             global_batch_size = self.size * bs
126 |             locs = (th.arange(start_i, end_i) + th.rand(bs)) / global_batch_size
127 |             log_sigmas = th.tensor(self.inv_cdf(locs), dtype=th.float32, device=device)
128 |         else:
129 |             log_sigmas = self.p_mean + self.p_std * th.randn(bs, device=device)
130 | 
131 |         sigmas = th.exp(log_sigmas)
132 |         weights = th.ones_like(sigmas)
133 |         return sigmas, weights
134 | 
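# Illustrative usage sketch: UniformSampler and LossSecondMomentResampler expect
# `diffusion.num_timesteps`, while RealUniformSampler reads `diffusion.sigma_min`
# and `diffusion.sigma_max`. Assuming such a `diffusion` object already exists:
#
#   sampler = create_named_schedule_sampler("uniform", diffusion)
#   t, weights = sampler.sample(batch_size=32, device=th.device("cuda"))
#   # `t` holds the sampled timesteps; `weights` are importance weights used to
#   # rescale each per-sample loss before averaging.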
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prior Does Matter: Visual Navigation via Denoising Diffusion Bridge Models (NaviBridger) 2 | 3 | > 🏆 Accepted at **CVPR 2025** 4 | > 🔗 [Github](https://github.com/hren20/NaiviBridger) | [arXiv](https://arxiv.org/abs/2504.10041) 5 | 6 |

7 | Overview 8 |

9 | 
10 | ---
11 | 
12 | ## 📌 TLDR
13 | 
14 | NaviBridger is a novel framework for visual navigation built upon **Denoising Diffusion Bridge Models (DDBMs)**. Unlike traditional diffusion policies that start from Gaussian noise, NaviBridger leverages **prior actions** (rule-based or learned) to guide the denoising process, accelerating convergence and improving trajectory accuracy.
15 | 
16 | ---
17 | 
18 | ## 🛠️ Key Features
19 | - 🔧 DDBM-based policy generation from arbitrary priors
20 | - 🔁 Unified framework supporting Gaussian, rule-based, and learning-based priors
21 | - 🏃‍♂️ Real-world deployment support on mobile robots (e.g., Diablo + Jetson Orin AGX)
22 | 
23 | ---
24 | 
25 | ## ✅ TODO List
26 | 
27 | - [x] Model weights released
28 | - [ ] Deployment code updates
29 | - [ ] A refactored version of the code (in the coming weeks)
30 | 
31 | ---
32 | 
33 | ## 📁 Directory Overview
34 | 
35 | ```
36 | navibridge/
37 | ├── train/                        # Training code and dataset processing
38 | │   ├── vint_train/               # NaviBridger models, configs, and datasets
39 | │   ├── train.py                  # Training entry point
40 | │   ├── process_*.py              # Data preprocessing scripts
41 | │   └── train_environment.yml     # Conda setup for training
42 | ├── deployment/                   # Inference and deployment
43 | │   ├── src/navibridger_inference.py
44 | │   ├── config/params.yaml        # Inference config
45 | │   ├── deployment_environment.yaml
46 | │   └── model_weights/            # Place for .pth model weights and corresponding .yaml config
47 | └── README.md                     # This file
48 | ```
49 | 
50 | ---
51 | 
52 | ## ⚙️ Setup
53 | 
54 | ### 🧪 Environment (Training)
55 | 
56 | ```bash
57 | conda env create -f train/train_environment.yml
58 | conda activate navibridge_train
59 | pip install -e train/
60 | git clone git@github.com:real-stanford/diffusion_policy.git
61 | pip install -e diffusion_policy/
62 | ```
63 | 
64 | ### 💻 Environment (Deployment)
65 | 
66 | ```bash
67 | conda env create -f deployment/deployment_environment.yaml
68 | conda activate navibridge
69 | pip install -e train/
70 | pip install -e diffusion_policy/
71 | ```
72 | 
73 | ---
74 | 
75 | ## 📦 Data Preparation
76 | 
77 | 1. Download public datasets:
78 |    - [RECON](https://sites.google.com/view/recon-robot/dataset)
79 |    - [SCAND](https://www.cs.utexas.edu/~xiao/SCAND/)
80 |    - [GoStanford2](https://cvgl.stanford.edu/gonet/dataset/)
81 |    - [SACSoN](https://sites.google.com/view/sacson-review/huron-dataset)
82 | 
83 | 2. Process datasets:
84 | ```bash
85 | python train/process_recon.py   # or process_bags.py
86 | python train/data_split.py --dataset <dataset_name>
87 | ```
88 | 
89 | 3. Expected format:
90 | ```
91 | dataset_name/
92 | ├── traj1/
93 | │   ├── 0.jpg ... T_1.jpg
94 | │   └── traj_data.pkl
95 | └── ...
96 | ```
97 | 
98 | After `data_split.py`, you should have:
99 | ```
100 | train/vint_train/data/data_splits/
101 | └── <dataset_name>/
102 |     ├── train/traj_names.txt
103 |     └── test/traj_names.txt
104 | ```
105 | 
106 | ---
107 | 
108 | ## 🧠 Model Training
109 | 
110 | ```bash
111 | cd train/
112 | python train.py -c config/navibridge.yaml   # Select the training type by changing prior_policy
113 | ```
114 | Set `prior_policy` in the config to `gaussian`, `handcraft`, or `cvae` to choose which prior the bridge denoises from (see `train/vint_train/models/model_utils.py` for how each prior is built).
115 | 
116 | ---
117 | 
118 | For the learning-based (`cvae`) prior, train the CVAE first:
119 | ```bash
120 | python train.py -c config/cvae.yaml
121 | ```
122 | 
123 | You can download pretrained `*.pth` weights from [this link](https://drive.google.com/drive/folders/14YhZnqFH9M6Y2fJc6LJwnO1eBIEeDAMC?usp=sharing) to try out the inference demo.
124 | 
125 | ---
126 | 
127 | ## 🚀 Inference Demo
128 | 
129 | 1. Place your trained model and config in:
130 | 
131 | ```
132 | deployment/model_weights/*.pth
133 | deployment/model_weights/*.yaml
134 | ```
135 | 2. Adjust the model path in `deployment/config/models.yaml`.
136 | 3. Prepare input images (at least 4): `0.png`, `1.png`, etc.
137 |    Adjust the input directory path (`image_path`) in `deployment/config/params.yaml`.
138 | 
139 | 4. Run:
140 | 
141 | ```bash
142 | python deployment/src/navibridger_inference.py --model navibridge_cvae   # Model name must match a key in deployment/config/models.yaml
143 | ```
144 | 
145 | ---
146 | 
147 | ## 🤖 Hardware Tested
148 | This is the platform we deployed on; comparable hardware should also work.
149 | 
150 | - NVIDIA Jetson Orin AGX
151 | - Intel RealSense D435i
152 | - Diablo wheeled-legged robot
153 | 
154 | > 📸 RGB-only input, no depth or LiDAR required.
155 | 
156 | ---
157 | 
158 | ## 🧪 Citation
159 | 
160 | ```bibtex
161 | @inproceedings{ren2025prior,
162 |   title={Prior does matter: Visual navigation via denoising diffusion bridge models},
163 |   author={Ren, Hao and Zeng, Yiming and Bi, Zetong and Wan, Zhaoliang and Huang, Junlong and Cheng, Hui},
164 |   booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
165 |   pages={12100--12110},
166 |   year={2025}
167 | }
168 | ```
169 | 
170 | ---
171 | 
172 | ## 📜 License
173 | 
174 | This codebase is released under the [MIT License](LICENSE).
175 | 
176 | ## Acknowledgment
177 | NaviBridger is inspired by the contributions of the following works to the open-source community: [DDBM](https://github.com/alexzhou907/DDBM), [NoMaD](https://github.com/robodhruv/visualnav-transformer), and [BRIDGER](https://github.com/clear-nus/bridger). We thank the authors for sharing their outstanding work.
--------------------------------------------------------------------------------
/train/vint_train/models/navibridge/ddbm/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 | 
5 | import math
6 | 
7 | import torch as th
8 | import torch.nn as nn
9 | import numpy as np
10 | import torch.nn.functional as F
11 | 
12 | 
13 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
14 | class SiLU(nn.Module):
15 |     def forward(self, x):
16 |         return x * th.sigmoid(x)
17 | 
18 | 
19 | class GroupNorm32(nn.GroupNorm):
20 |     def forward(self, x):
21 |         return super().forward(x.float()).type(x.dtype)
22 | 
23 | 
24 | def conv_nd(dims, *args, **kwargs):
25 |     """
26 |     Create a 1D, 2D, or 3D convolution module.
27 |     """
28 |     if dims == 1:
29 |         return nn.Conv1d(*args, **kwargs)
30 |     elif dims == 2:
31 |         return nn.Conv2d(*args, **kwargs)
32 |     elif dims == 3:
33 |         return nn.Conv3d(*args, **kwargs)
34 |     raise ValueError(f"unsupported dimensions: {dims}")
35 | 
36 | 
37 | def linear(*args, **kwargs):
38 |     """
39 |     Create a linear module.
40 |     """
41 |     return nn.Linear(*args, **kwargs)
42 | 
43 | 
44 | def avg_pool_nd(dims, *args, **kwargs):
45 |     """
46 |     Create a 1D, 2D, or 3D average pooling module.
47 |     """
48 |     if dims == 1:
49 |         return nn.AvgPool1d(*args, **kwargs)
50 |     elif dims == 2:
51 |         return nn.AvgPool2d(*args, **kwargs)
52 |     elif dims == 3:
53 |         return nn.AvgPool3d(*args, **kwargs)
54 |     raise ValueError(f"unsupported dimensions: {dims}")
55 | 
56 | 
57 | def update_ema(target_params, source_params, rate=0.99):
58 |     """
59 |     Update target parameters to be closer to those of source parameters using
60 |     an exponential moving average.
61 | 
62 |     :param target_params: the target parameter sequence.
63 | :param source_params: the source parameter sequence. 64 | :param rate: the EMA rate (closer to 1 means slower). 65 | """ 66 | for targ, src in zip(target_params, source_params): 67 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 68 | 69 | 70 | def zero_module(module): 71 | """ 72 | Zero out the parameters of a module and return it. 73 | """ 74 | for p in module.parameters(): 75 | p.detach().zero_() 76 | return module 77 | 78 | 79 | def scale_module(module, scale): 80 | """ 81 | Scale the parameters of a module and return it. 82 | """ 83 | for p in module.parameters(): 84 | p.detach().mul_(scale) 85 | return module 86 | 87 | 88 | def mean_flat(tensor): 89 | """ 90 | Take the mean over all non-batch dimensions. 91 | """ 92 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 93 | 94 | 95 | def append_dims(x, target_dims): 96 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 97 | dims_to_append = target_dims - x.ndim 98 | if dims_to_append < 0: 99 | raise ValueError( 100 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 101 | ) 102 | return x[(...,) + (None,) * dims_to_append] 103 | 104 | 105 | def append_zero(x): 106 | return th.cat([x, x.new_zeros([1])]) 107 | 108 | 109 | def normalization(channels): 110 | """ 111 | Make a standard normalization layer. 112 | 113 | :param channels: number of input channels. 114 | :return: an nn.Module for normalization. 115 | """ 116 | return GroupNorm32(32, channels) 117 | 118 | 119 | def timestep_embedding(timesteps, dim, max_period=10000): 120 | """ 121 | Create sinusoidal timestep embeddings. 122 | 123 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 124 | These may be fractional. 125 | :param dim: the dimension of the output. 126 | :param max_period: controls the minimum frequency of the embeddings. 127 | :return: an [N x dim] Tensor of positional embeddings. 128 | """ 129 | half = dim // 2 130 | freqs = th.exp( 131 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 132 | ).to(device=timesteps.device) 133 | args = timesteps[:, None].float() * freqs[None] 134 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 135 | if dim % 2: 136 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 137 | return embedding 138 | 139 | 140 | def checkpoint(func, inputs, params, flag): 141 | """ 142 | Evaluate a function without caching intermediate activations, allowing for 143 | reduced memory at the expense of extra compute in the backward pass. 144 | 145 | :param func: the function to evaluate. 146 | :param inputs: the argument sequence to pass to `func`. 147 | :param params: a sequence of parameters `func` depends on but does not 148 | explicitly take as arguments. 149 | :param flag: if False, disable gradient checkpointing. 
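    :return: the output of `func(*inputs)`; when `flag` is True the intermediate
        activations are recomputed during the backward pass instead of stored.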
150 | """ 151 | if flag: 152 | args = tuple(inputs) + tuple(params) 153 | return CheckpointFunction.apply(func, len(inputs), *args) 154 | else: 155 | return func(*inputs) 156 | 157 | 158 | class CheckpointFunction(th.autograd.Function): 159 | @staticmethod 160 | def forward(ctx, run_function, length, *args): 161 | ctx.run_function = run_function 162 | ctx.input_tensors = list(args[:length]) 163 | ctx.input_params = list(args[length:]) 164 | with th.no_grad(): 165 | output_tensors = ctx.run_function(*ctx.input_tensors) 166 | return output_tensors 167 | 168 | @staticmethod 169 | def backward(ctx, *output_grads): 170 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 171 | with th.enable_grad(): 172 | # Fixes a bug where the first op in run_function modifies the 173 | # Tensor storage in place, which is not allowed for detach()'d 174 | # Tensors. 175 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 176 | output_tensors = ctx.run_function(*shallow_copies) 177 | input_grads = th.autograd.grad( 178 | output_tensors, 179 | ctx.input_tensors + ctx.input_params, 180 | output_grads, 181 | allow_unused=True, 182 | ) 183 | del ctx.input_tensors 184 | del ctx.input_params 185 | del output_tensors 186 | return (None, None) + input_grads 187 | -------------------------------------------------------------------------------- /train/vint_train/visualizing/distance_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | import numpy as np 4 | from typing import List, Optional, Tuple 5 | from vint_train.visualizing.visualize_utils import numpy_to_img 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def visualize_dist_pred( 10 | batch_obs_images: np.ndarray, 11 | batch_goal_images: np.ndarray, 12 | batch_dist_preds: np.ndarray, 13 | batch_dist_labels: np.ndarray, 14 | eval_type: str, 15 | save_folder: str, 16 | epoch: int, 17 | num_images_preds: int = 8, 18 | use_wandb: bool = True, 19 | display: bool = False, 20 | rounding: int = 4, 21 | dist_error_threshold: float = 3.0, 22 | ): 23 | """ 24 | Visualize the distance classification predictions and labels for an observation-goal image pair. 25 | 26 | Args: 27 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels] 28 | batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels] 29 | batch_dist_preds (np.ndarray): batch of distance predictions [batch_size] 30 | batch_dist_labels (np.ndarray): batch of distance labels [batch_size] 31 | eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.) 32 | epoch (int): current epoch number 33 | num_images_preds (int): number of images to visualize 34 | use_wandb (bool): whether to use wandb to log the images 35 | save_folder (str): folder to save the images. 
If None, will not save the images 36 | display (bool): whether to display the images 37 | rounding (int): number of decimal places to round the distance predictions and labels 38 | dist_error_threshold (float): distance error threshold for classifying the distance prediction as correct or incorrect (only used for visualization purposes) 39 | """ 40 | visualize_path = os.path.join( 41 | save_folder, 42 | "visualize", 43 | eval_type, 44 | f"epoch{epoch}", 45 | "dist_classification", 46 | ) 47 | if not os.path.isdir(visualize_path): 48 | os.makedirs(visualize_path) 49 | assert ( 50 | len(batch_obs_images) 51 | == len(batch_goal_images) 52 | == len(batch_dist_preds) 53 | == len(batch_dist_labels) 54 | ) 55 | batch_size = batch_obs_images.shape[0] 56 | wandb_list = [] 57 | for i in range(min(batch_size, num_images_preds)): 58 | dist_pred = np.round(batch_dist_preds[i], rounding) 59 | dist_label = np.round(batch_dist_labels[i], rounding) 60 | obs_image = numpy_to_img(batch_obs_images[i]) 61 | goal_image = numpy_to_img(batch_goal_images[i]) 62 | 63 | save_path = None 64 | if save_folder is not None: 65 | save_path = os.path.join(visualize_path, f"{i}.png") 66 | text_color = "black" 67 | if abs(dist_pred - dist_label) > dist_error_threshold: 68 | text_color = "red" 69 | 70 | display_distance_pred( 71 | [obs_image, goal_image], 72 | ["Observation", "Goal"], 73 | dist_pred, 74 | dist_label, 75 | text_color, 76 | save_path, 77 | display, 78 | ) 79 | if use_wandb: 80 | wandb_list.append(wandb.Image(save_path)) 81 | if use_wandb: 82 | wandb.log({f"{eval_type}_dist_prediction": wandb_list}, commit=False) 83 | 84 | 85 | def visualize_dist_pairwise_pred( 86 | batch_obs_images: np.ndarray, 87 | batch_close_images: np.ndarray, 88 | batch_far_images: np.ndarray, 89 | batch_close_preds: np.ndarray, 90 | batch_far_preds: np.ndarray, 91 | batch_close_labels: np.ndarray, 92 | batch_far_labels: np.ndarray, 93 | eval_type: str, 94 | save_folder: str, 95 | epoch: int, 96 | num_images_preds: int = 8, 97 | use_wandb: bool = True, 98 | display: bool = False, 99 | rounding: int = 4, 100 | ): 101 | """ 102 | Visualize the distance classification predictions and labels for an observation-goal image pair. 103 | 104 | Args: 105 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels] 106 | batch_close_images (np.ndarray): batch of close goal images [batch_size, height, width, channels] 107 | batch_far_images (np.ndarray): batch of far goal images [batch_size, height, width, channels] 108 | batch_close_preds (np.ndarray): batch of close predictions [batch_size] 109 | batch_far_preds (np.ndarray): batch of far predictions [batch_size] 110 | batch_close_labels (np.ndarray): batch of close labels [batch_size] 111 | batch_far_labels (np.ndarray): batch of far labels [batch_size] 112 | eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.) 113 | save_folder (str): folder to save the images. 
If None, will not save the images 114 | epoch (int): current epoch number 115 | num_images_preds (int): number of images to visualize 116 | use_wandb (bool): whether to use wandb to log the images 117 | display (bool): whether to display the images 118 | rounding (int): number of decimal places to round the distance predictions and labels 119 | """ 120 | visualize_path = os.path.join( 121 | save_folder, 122 | "visualize", 123 | eval_type, 124 | f"epoch{epoch}", 125 | "pairwise_dist_classification", 126 | ) 127 | if not os.path.isdir(visualize_path): 128 | os.makedirs(visualize_path) 129 | assert ( 130 | len(batch_obs_images) 131 | == len(batch_close_images) 132 | == len(batch_far_images) 133 | == len(batch_close_preds) 134 | == len(batch_far_preds) 135 | == len(batch_close_labels) 136 | == len(batch_far_labels) 137 | ) 138 | batch_size = batch_obs_images.shape[0] 139 | wandb_list = [] 140 | for i in range(min(batch_size, num_images_preds)): 141 | close_dist_pred = np.round(batch_close_preds[i], rounding) 142 | far_dist_pred = np.round(batch_far_preds[i], rounding) 143 | close_dist_label = np.round(batch_close_labels[i], rounding) 144 | far_dist_label = np.round(batch_far_labels[i], rounding) 145 | obs_image = numpy_to_img(batch_obs_images[i]) 146 | close_image = numpy_to_img(batch_close_images[i]) 147 | far_image = numpy_to_img(batch_far_images[i]) 148 | 149 | save_path = None 150 | if save_folder is not None: 151 | save_path = os.path.join(visualize_path, f"{i}.png") 152 | 153 | if close_dist_pred < far_dist_pred: 154 | text_color = "black" 155 | else: 156 | text_color = "red" 157 | 158 | display_distance_pred( 159 | [obs_image, close_image, far_image], 160 | ["Observation", "Close Goal", "Far Goal"], 161 | f"close_pred = {close_dist_pred}, far_pred = {far_dist_pred}", 162 | f"close_label = {close_dist_label}, far_label = {far_dist_label}", 163 | text_color, 164 | save_path, 165 | display, 166 | ) 167 | if use_wandb: 168 | wandb_list.append(wandb.Image(save_path)) 169 | if use_wandb: 170 | wandb.log({f"{eval_type}_pairwise_classification": wandb_list}, commit=False) 171 | 172 | 173 | def display_distance_pred( 174 | imgs: list, 175 | titles: list, 176 | dist_pred: float, 177 | dist_label: float, 178 | text_color: str = "black", 179 | save_path: Optional[str] = None, 180 | display: bool = False, 181 | ): 182 | plt.figure() 183 | fig, ax = plt.subplots(1, len(imgs)) 184 | 185 | plt.suptitle(f"prediction: {dist_pred}\nlabel: {dist_label}", color=text_color) 186 | 187 | for axis, img, title in zip(ax, imgs, titles): 188 | axis.imshow(img) 189 | axis.set_title(title) 190 | axis.xaxis.set_visible(False) 191 | axis.yaxis.set_visible(False) 192 | 193 | # make the plot large 194 | fig.set_size_inches((18.5 / 3) * len(imgs), 10.5) 195 | 196 | if save_path is not None: 197 | fig.savefig( 198 | save_path, 199 | bbox_inches="tight", 200 | ) 201 | if not display: 202 | plt.close(fig) 203 | -------------------------------------------------------------------------------- /train/vint_train/models/navibridge/navibridg_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | from typing import List, Dict, Optional, Tuple, Callable 6 | from efficientnet_pytorch import EfficientNet 7 | from vint_train.models.navibridge.self_attention import PositionalEncoding 8 | 9 | class NaviBridge_Encoder(nn.Module): 10 | def __init__( 11 | self, 12 | context_size: int = 5, 13 | 
obs_encoder: Optional[str] = "efficientnet-b0", 14 | obs_encoding_size: Optional[int] = 512, 15 | mha_num_attention_heads: Optional[int] = 2, 16 | mha_num_attention_layers: Optional[int] = 2, 17 | mha_ff_dim_factor: Optional[int] = 4, 18 | ) -> None: 19 | """ 20 | Encoder class 21 | """ 22 | super().__init__() 23 | self.obs_encoding_size = obs_encoding_size 24 | self.goal_encoding_size = obs_encoding_size 25 | self.context_size = context_size 26 | 27 | # Initialize the observation encoder 28 | if obs_encoder.split("-")[0] == "efficientnet": 29 | self.obs_encoder = EfficientNet.from_name(obs_encoder, in_channels=3) # context 30 | self.obs_encoder = replace_bn_with_gn(self.obs_encoder) 31 | self.num_obs_features = self.obs_encoder._fc.in_features 32 | self.obs_encoder_type = "efficientnet" 33 | else: 34 | raise NotImplementedError 35 | 36 | # Initialize the goal encoder 37 | self.goal_encoder = EfficientNet.from_name("efficientnet-b0", in_channels=6) # obs+goal 38 | self.goal_encoder = replace_bn_with_gn(self.goal_encoder) 39 | self.num_goal_features = self.goal_encoder._fc.in_features 40 | 41 | # Initialize compression layers if necessary 42 | if self.num_obs_features != self.obs_encoding_size: 43 | self.compress_obs_enc = nn.Linear(self.num_obs_features, self.obs_encoding_size) 44 | else: 45 | self.compress_obs_enc = nn.Identity() 46 | 47 | if self.num_goal_features != self.goal_encoding_size: 48 | self.compress_goal_enc = nn.Linear(self.num_goal_features, self.goal_encoding_size) 49 | else: 50 | self.compress_goal_enc = nn.Identity() 51 | 52 | # Initialize positional encoding and self-attention layers 53 | self.positional_encoding = PositionalEncoding(self.obs_encoding_size, max_seq_len=self.context_size + 2) 54 | self.sa_layer = nn.TransformerEncoderLayer( 55 | d_model=self.obs_encoding_size, 56 | nhead=mha_num_attention_heads, 57 | dim_feedforward=mha_ff_dim_factor*self.obs_encoding_size, 58 | activation="gelu", 59 | batch_first=True, 60 | norm_first=True 61 | ) 62 | self.sa_encoder = nn.TransformerEncoder(self.sa_layer, num_layers=mha_num_attention_layers) 63 | 64 | # Definition of the goal mask (convention: 0 = no mask, 1 = mask) 65 | self.goal_mask = torch.zeros((1, self.context_size + 2), dtype=torch.bool) 66 | self.goal_mask[:, -1] = True # Mask out the goal 67 | self.no_mask = torch.zeros((1, self.context_size + 2), dtype=torch.bool) 68 | self.all_masks = torch.cat([self.no_mask, self.goal_mask], dim=0) 69 | self.avg_pool_mask = torch.cat([1 - self.no_mask.float(), (1 - self.goal_mask.float()) * ((self.context_size + 2)/(self.context_size + 1))], dim=0) 70 | 71 | 72 | def forward(self, obs_img: torch.tensor, goal_img: torch.tensor, input_goal_mask: torch.tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: 73 | 74 | device = obs_img.device 75 | 76 | # Initialize the goal encoding 77 | goal_encoding = torch.zeros((obs_img.size()[0], 1, self.goal_encoding_size)).to(device) 78 | 79 | # Get the input goal mask 80 | if input_goal_mask is not None: 81 | goal_mask = input_goal_mask.to(device) 82 | 83 | # Get the goal encoding 84 | obsgoal_img = torch.cat([obs_img[:, 3*self.context_size:, :, :], goal_img], dim=1) # concatenate the obs image/context and goal image --> non image goal? 
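        # obsgoal_img is (B, 6, H, W): the most recent context frame stacked with the
        # goal frame, matching the goal_encoder constructed with in_channels=6 above.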
85 | obsgoal_encoding = self.goal_encoder.extract_features(obsgoal_img) # get encoding of this img 86 | obsgoal_encoding = self.goal_encoder._avg_pooling(obsgoal_encoding) # avg pooling 87 | 88 | if self.goal_encoder._global_params.include_top: 89 | obsgoal_encoding = obsgoal_encoding.flatten(start_dim=1) 90 | obsgoal_encoding = self.goal_encoder._dropout(obsgoal_encoding) 91 | obsgoal_encoding = self.compress_goal_enc(obsgoal_encoding) 92 | 93 | if len(obsgoal_encoding.shape) == 2: 94 | obsgoal_encoding = obsgoal_encoding.unsqueeze(1) 95 | assert obsgoal_encoding.shape[2] == self.goal_encoding_size 96 | goal_encoding = obsgoal_encoding 97 | 98 | # Get the observation encoding 99 | obs_img = torch.split(obs_img, 3, dim=1) 100 | obs_img = torch.concat(obs_img, dim=0) 101 | 102 | obs_encoding = self.obs_encoder.extract_features(obs_img) 103 | obs_encoding = self.obs_encoder._avg_pooling(obs_encoding) 104 | if self.obs_encoder._global_params.include_top: 105 | obs_encoding = obs_encoding.flatten(start_dim=1) 106 | obs_encoding = self.obs_encoder._dropout(obs_encoding) 107 | obs_encoding = self.compress_obs_enc(obs_encoding) 108 | obs_encoding = obs_encoding.unsqueeze(1) 109 | obs_encoding = obs_encoding.reshape((self.context_size+1, -1, self.obs_encoding_size)) 110 | obs_encoding = torch.transpose(obs_encoding, 0, 1) 111 | obs_encoding = torch.cat((obs_encoding, goal_encoding), dim=1) 112 | 113 | # If a goal mask is provided, mask some of the goal tokens 114 | if goal_mask is not None: 115 | no_goal_mask = goal_mask.long() 116 | src_key_padding_mask = torch.index_select(self.all_masks.to(device), 0, no_goal_mask) 117 | else: 118 | src_key_padding_mask = None 119 | 120 | # Apply positional encoding 121 | if self.positional_encoding: 122 | obs_encoding = self.positional_encoding(obs_encoding) 123 | 124 | obs_encoding_tokens = self.sa_encoder(obs_encoding, src_key_padding_mask=src_key_padding_mask) 125 | if src_key_padding_mask is not None: 126 | avg_mask = torch.index_select(self.avg_pool_mask.to(device), 0, no_goal_mask).unsqueeze(-1) 127 | obs_encoding_tokens = obs_encoding_tokens * avg_mask 128 | obs_encoding_tokens = torch.mean(obs_encoding_tokens, dim=1) 129 | 130 | return obs_encoding_tokens 131 | 132 | 133 | # Utils for Group Norm 134 | def replace_bn_with_gn( 135 | root_module: nn.Module, 136 | features_per_group: int=16) -> nn.Module: 137 | """ 138 | Relace all BatchNorm layers with GroupNorm. 139 | """ 140 | replace_submodules( 141 | root_module=root_module, 142 | predicate=lambda x: isinstance(x, nn.BatchNorm2d), 143 | func=lambda x: nn.GroupNorm( 144 | num_groups=x.num_features//features_per_group, 145 | num_channels=x.num_features) 146 | ) 147 | return root_module 148 | 149 | 150 | def replace_submodules( 151 | root_module: nn.Module, 152 | predicate: Callable[[nn.Module], bool], 153 | func: Callable[[nn.Module], nn.Module]) -> nn.Module: 154 | """ 155 | Replace all submodules selected by the predicate with 156 | the output of func. 157 | 158 | predicate: Return true if the module is to be replaced. 159 | func: Return new module to use. 
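    root_module: The module whose matching submodules are replaced in place.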
160 | """ 161 | if predicate(root_module): 162 | return func(root_module) 163 | 164 | bn_list = [k.split('.') for k, m 165 | in root_module.named_modules(remove_duplicate=True) 166 | if predicate(m)] 167 | for *parent, k in bn_list: 168 | parent_module = root_module 169 | if len(parent) > 0: 170 | parent_module = root_module.get_submodule('.'.join(parent)) 171 | if isinstance(parent_module, nn.Sequential): 172 | src_module = parent_module[int(k)] 173 | else: 174 | src_module = getattr(parent_module, k) 175 | tgt_module = func(src_module) 176 | if isinstance(parent_module, nn.Sequential): 177 | parent_module[int(k)] = tgt_module 178 | else: 179 | setattr(parent_module, k, tgt_module) 180 | # verify that all modules are replaced 181 | bn_list = [k.split('.') for k, m 182 | in root_module.named_modules(remove_duplicate=True) 183 | if predicate(m)] 184 | assert len(bn_list) == 0 185 | return root_module 186 | 187 | 188 | 189 | -------------------------------------------------------------------------------- /train/vint_train/models/model_utils.py: -------------------------------------------------------------------------------- 1 | from vint_train.models.navibridge.navibridge import NaviBridge, DenseNetwork, StatesPredNet 2 | from vint_train.models.navibridge.navibridg_utils import NaviBridge_Encoder, replace_bn_with_gn 3 | from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D 4 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 5 | 6 | from vint_train.models.navibridge.ddbm.karras_diffusion import KarrasDenoiser 7 | from vint_train.models.navibridge.ddbm.resample import create_named_schedule_sampler 8 | from vint_train.models.navibridge.navibridge import PriorModel, Prior_HandCraft 9 | from vint_train.models.navibridge.vae.vae import VAEModel 10 | 11 | import os 12 | def create_model(config, device): 13 | """ 14 | Create a model based on the provided configuration. 15 | 16 | Args: 17 | config (dict): Configuration dictionary that includes model type and various parameters. 18 | 19 | Returns: 20 | model (object): Created model based on the specified configuration. 21 | 22 | Raises: 23 | ValueError: If the model type or vision encoder is not supported. 
24 | """ 25 | # Create the model based on configuration 26 | if config["model_type"] == "navibridge": 27 | # Select and configure the vision encoder 28 | if config["vision_encoder"] == "navibridge_encoder": 29 | vision_encoder = NaviBridge_Encoder( 30 | obs_encoding_size=config["encoding_size"], 31 | context_size=config["context_size"], 32 | mha_num_attention_heads=config["mha_num_attention_heads"], 33 | mha_num_attention_layers=config["mha_num_attention_layers"], 34 | mha_ff_dim_factor=config["mha_ff_dim_factor"], 35 | ) 36 | vision_encoder = replace_bn_with_gn(vision_encoder) 37 | else: 38 | raise ValueError(f"Vision encoder {config['vision_encoder']} not supported") 39 | 40 | # Create the noise prediction network and distribution prediction network 41 | noise_pred_net = ConditionalUnet1D( 42 | input_dim=2, 43 | global_cond_dim=config["encoding_size"], 44 | down_dims=config["down_dims"], 45 | cond_predict_scale=config["cond_predict_scale"], 46 | ) 47 | dist_pred_network = DenseNetwork(embedding_dim=config["encoding_size"]) 48 | 49 | states_pred_net = StatesPredNet(embedding_dim=config["encoding_size"], 50 | class_num=config["class_num"], 51 | len_traj_pred=config["len_traj_pred"]) 52 | 53 | model = NaviBridge( 54 | vision_encoder=vision_encoder, 55 | noise_pred_net=noise_pred_net, 56 | dist_pred_net=dist_pred_network, 57 | states_pred_net=states_pred_net, 58 | ) 59 | set_model_prior(model, config, device) 60 | elif config["model_type"] == "cvae": 61 | model = VAEModel(config) 62 | else: 63 | raise ValueError(f"Model {config['model_type']} not supported") 64 | 65 | return model 66 | 67 | def create_noise_scheduler(config): 68 | if config["model_type"] == "navibridge": 69 | diffusion = create_diffusion(config) 70 | noise_scheduler = create_named_schedule_sampler(config["sampler_name"], diffusion) 71 | return noise_scheduler, diffusion 72 | 73 | def create_diffusion(config, 74 | ): 75 | # ddbm params 76 | sigma_data=config["sigma_data"] 77 | sigma_min=config["sigma_min"] 78 | sigma_max=config["sigma_max"] 79 | pred_mode=config["pred_mode"] 80 | weight_schedule=config["weight_schedule"] 81 | beta_d=config["beta_d"] 82 | beta_min=config["beta_min"] 83 | cov_xy=config["cov_xy"] 84 | diffusion = KarrasDenoiser( 85 | sigma_data=sigma_data, 86 | sigma_max=sigma_max, 87 | sigma_min=sigma_min, 88 | beta_d=beta_d, 89 | beta_min=beta_min, 90 | cov_xy=cov_xy, 91 | weight_schedule=weight_schedule, 92 | pred_mode=pred_mode 93 | ) 94 | return diffusion 95 | 96 | def set_model_prior(model, prior_args, device): 97 | prior = load_prior_model(prior_args, device) 98 | if prior_args['prior_policy'] == 'handcraft': 99 | prior_model = PriorModel(prior=prior, len_traj_pred=prior_args["len_traj_pred"], action_dim=prior_args["action_dim"]) 100 | elif prior_args['prior_policy'] == 'cluster': 101 | pass 102 | elif prior_args['prior_policy'] == 'gaussian': 103 | prior_model = PriorModel(prior=prior, len_traj_pred=prior_args["len_traj_pred"], action_dim=prior_args["action_dim"]) 104 | elif prior_args['prior_policy'] == 'cvae': 105 | prior_model = PriorModel(prior=prior, len_traj_pred=prior_args["len_traj_pred"], action_dim=prior_args["action_dim"]) 106 | model.prior_model = prior_model 107 | return model 108 | 109 | 110 | def load_prior_model(prior_args, device): 111 | if prior_args['prior_policy'] == 'handcraft': 112 | prior_model = Prior_HandCraft( 113 | class_num=prior_args["class_num"], 114 | angle_ranges=prior_args["angle_ranges"], 115 | min_std_angle=prior_args["min_std_angle"], 116 | 
max_std_angle=prior_args["max_std_angle"], 117 | min_std_length=prior_args["min_std_length"], 118 | max_std_length=prior_args["max_std_length"], 119 | len_traj_pred=prior_args["len_traj_pred"], 120 | ) 121 | elif prior_args['prior_policy'] == 'cluster': 122 | pass 123 | elif prior_args['prior_policy'] == 'gaussian': 124 | prior_model = None 125 | elif prior_args['prior_policy'] == 'cvae': 126 | from vint_train.models.navibridge.vae.vae import VAEModel 127 | model_spec_names = prior_args['net_type'] 128 | ckpt_path = prior_args["ckpt_path"] 129 | prior_args['pretrain'] = True 130 | prior_args['ckpt_path'] = ckpt_path 131 | 132 | prior_model = VAEModel(prior_args) 133 | prior_model.load_model(model_args=prior_args, device=device) 134 | 135 | else: 136 | raise NotImplementedError(f"Can not be found prior policy: {prior_args['prior_policy']}") 137 | 138 | return prior_model 139 | 140 | import torch 141 | from torch.optim import Adam, AdamW 142 | from torch.optim.lr_scheduler import ( 143 | CosineAnnealingLR, 144 | CyclicLR, 145 | ReduceLROnPlateau, 146 | StepLR 147 | ) 148 | from warmup_scheduler import GradualWarmupScheduler 149 | 150 | def get_optimizer_and_scheduler(config, model): 151 | lr = float(config["lr"]) 152 | config["optimizer"] = config["optimizer"].lower() 153 | 154 | if config["optimizer"] == "adam": 155 | optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.98)) 156 | elif config["optimizer"] == "adamw": 157 | optimizer = AdamW(model.parameters(), lr=lr) 158 | elif config["optimizer"] == "sgd": 159 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) 160 | else: 161 | raise ValueError(f"Optimizer {config['optimizer']} not supported") 162 | 163 | scheduler = None 164 | 165 | if config["model_type"] == "cvae": 166 | if config['lr_gamma'] < 1.0: 167 | scheduler = StepLR(optimizer, 168 | step_size=config["lr_step"], 169 | gamma=config["lr_gamma"]) 170 | else: 171 | scheduler = None 172 | 173 | elif config["scheduler"] is not None: 174 | config["scheduler"] = config["scheduler"].lower() 175 | 176 | if config["scheduler"] == "cosine": 177 | print("Using cosine annealing with T_max", config["epochs"]) 178 | scheduler = CosineAnnealingLR(optimizer, T_max=config["epochs"]) 179 | elif config["scheduler"] == "cyclic": 180 | print("Using cyclic LR with cycle", config["cyclic_period"]) 181 | scheduler = CyclicLR( 182 | optimizer, 183 | base_lr=lr / 10.0, 184 | max_lr=lr, 185 | step_size_up=config["cyclic_period"] // 2, 186 | cycle_momentum=False, 187 | ) 188 | elif config["scheduler"] == "plateau": 189 | print("Using ReduceLROnPlateau") 190 | scheduler = ReduceLROnPlateau( 191 | optimizer, 192 | factor=config["plateau_factor"], 193 | patience=config["plateau_patience"], 194 | verbose=True, 195 | ) 196 | else: 197 | raise ValueError(f"Scheduler {config['scheduler']} not supported") 198 | 199 | if config.get("warmup", False): 200 | print("Using warmup scheduler") 201 | scheduler = GradualWarmupScheduler( 202 | optimizer, 203 | multiplier=1, 204 | total_epoch=config["warmup_epochs"], 205 | after_scheduler=scheduler, 206 | ) 207 | 208 | return optimizer, scheduler 209 | -------------------------------------------------------------------------------- /deployment/src/navibridger_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import sys 5 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 6 | 7 | 8 | from vint_train.models.model_utils import create_noise_scheduler 9 | from 
vint_train.training.train_utils import get_action 10 | from vint_train.visualizing.action_utils import plot_trajs_and_points 11 | from vint_train.models.navibridge.ddbm.karras_diffusion import karras_sample 12 | 13 | import torch 14 | import numpy as np 15 | import yaml 16 | from PIL import Image as PILImage 17 | import matplotlib.pyplot as plt 18 | 19 | from utils_inference import to_numpy, transform_images, load_model, project_and_draw 20 | 21 | PARAMS_PATH = "../config/params.yaml" 22 | with open(PARAMS_PATH, "r") as f: 23 | params_config = yaml.safe_load(f) 24 | image_path = params_config["image_path"] 25 | 26 | # CONSTANTS 27 | MODEL_WEIGHTS_PATH = "../model_weights" 28 | ROBOT_CONFIG_PATH ="../config/robot.yaml" 29 | MODEL_CONFIG_PATH = "../config/models.yaml" 30 | with open(ROBOT_CONFIG_PATH, "r") as f: 31 | robot_config = yaml.safe_load(f) 32 | MAX_V = robot_config["max_v"] 33 | MAX_W = robot_config["max_w"] 34 | RATE = robot_config["frame_rate"] 35 | ACTION_STATS = {} 36 | ACTION_STATS['min'] = np.array([-2.5, -4]) 37 | ACTION_STATS['max'] = np.array([5, 4]) 38 | # GLOBALS 39 | context_queue = [] 40 | context_size = None 41 | 42 | # Load the model 43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 44 | print("Using device:", device) 45 | 46 | def get_bottom_folder_name(path): 47 | folder_path = os.path.dirname(path) 48 | bottom_folder_name = os.path.basename(folder_path) 49 | return bottom_folder_name 50 | 51 | def ensure_directory_exists(save_path): 52 | directory = os.path.dirname(save_path) 53 | if not os.path.exists(directory): 54 | os.makedirs(directory) 55 | 56 | def path_project_plot(img, path, args, camera_extrinsics, camera_intrinsics): 57 | for i, naction in enumerate(path): 58 | gc_actions = to_numpy(get_action(naction)) 59 | fig = project_and_draw(img, gc_actions, camera_extrinsics, camera_intrinsics) 60 | dir_basename = get_bottom_folder_name(image_path) 61 | save_path = os.path.join('../output', dir_basename, f'png_{args.model}_image_with_trajs_{i}.png') 62 | ensure_directory_exists(save_path) 63 | fig.savefig(save_path) 64 | save_path = os.path.join('../output', dir_basename, f'svg_{args.model}_image_with_trajs_{i}.svg') 65 | ensure_directory_exists(save_path) 66 | fig.savefig(save_path) 67 | print(f"output image saved as {save_path}") 68 | 69 | def main(args): 70 | camera_intrinsics = np.array([[470.7520828622471, 0, 16.00531005859375], 71 | [0, 470.7520828622471, 403.38909912109375], 72 | [0, 0, 1]]) 73 | camera_extrinsics = np.array([[0, 0, 1, -0.600], 74 | [-1, 0, 0, -0.000], 75 | [0, -1, 0, -0.042], 76 | [0, 0, 0, 1]]) 77 | global context_size, image_path 78 | # load model parameters 79 | with open(MODEL_CONFIG_PATH, "r") as f: 80 | model_paths = yaml.safe_load(f) 81 | 82 | model_config_path = model_paths[args.model]["config_path"] 83 | with open(model_config_path, "r") as f: 84 | model_params = yaml.safe_load(f) 85 | 86 | if model_params["model_type"] == "cvae": 87 | if "train_params" in model_params: 88 | model_params.update(model_params["train_params"]) 89 | if "diffuse_params" in model_params: 90 | model_params.update(model_params["diffuse_params"]) 91 | 92 | if model_params.get("prior_policy", None) == "cvae": 93 | if "diffuse_params" in model_params: 94 | model_params.update(model_params["diffuse_params"]) 95 | 96 | context_size = model_params["context_size"] 97 | # load model weights 98 | ckpth_path = model_paths[args.model]["ckpt_path"] 99 | if os.path.exists(ckpth_path): 100 | print(f"Loading model from {ckpth_path}") 101 | 
else: 102 | raise FileNotFoundError(f"Model weights not found at {ckpth_path}") 103 | model = load_model( 104 | ckpth_path, 105 | model_params, 106 | device, 107 | ) 108 | model.eval() 109 | 110 | if model_params["model_type"] == "navibridge": 111 | noise_scheduler, diffusion = create_noise_scheduler(model_params) 112 | 113 | for i in range(4): 114 | 115 | img_path = image_path + str(i + 18) + ".png" 116 | img = PILImage.open(img_path) 117 | context_queue.append(img) 118 | 119 | fake_goal = torch.randn((1, 3, *model_params["image_size"])).to(device) 120 | 121 | # infer action 122 | obs_images = transform_images(context_queue, model_params["image_size"], center_crop=False) 123 | obs_images = obs_images.to(device) 124 | fake_goal = torch.randn((1, 3, *model_params["image_size"])).to(device) 125 | mask = torch.ones(1).long().to(device) 126 | 127 | # You can change the fake_goal to a goal image, do the same operation as obs_images 128 | # goal_image = transform_images(goal_image, model_params["image_size"], center_crop=False) 129 | # goal_image = goal_image.to(device) 130 | # mask = torch.zeros(1).long().to(device) 131 | 132 | obs_cond_gc = model('vision_encoder', obs_img=obs_images, goal_img=fake_goal, input_goal_mask=mask) 133 | scale_factor=3 * MAX_V / RATE 134 | with torch.no_grad(): 135 | if len(obs_cond_gc.shape) == 2: 136 | obs_cond_gc = obs_cond_gc.repeat(args.num_samples, 1) 137 | else: 138 | obs_cond_gc = obs_cond_gc.repeat(args.num_samples, 1, 1) 139 | 140 | if model_params["model_type"] == "navibridge": 141 | if model_params["prior_policy"] == "handcraft": 142 | # Predict aciton states 143 | states_pred = model("states_pred_net", obsgoal_cond=obs_cond_gc) 144 | 145 | if model_params["prior_policy"] == "cvae": 146 | prior_cond = obs_images.repeat_interleave(args.num_samples, dim=0) 147 | elif model_params["prior_policy"] == "handcraft": 148 | prior_cond = states_pred 149 | else: 150 | prior_cond = None 151 | 152 | # initialize action from Gaussian noise 153 | if model.prior_model.prior is None: 154 | initial_samples = torch.randn((args.num_samples, model_params["len_traj_pred"], 2), device=device) 155 | else: 156 | with torch.no_grad(): 157 | initial_samples = model.prior_model.sample(cond=prior_cond, device=device) 158 | assert initial_samples.shape[-1] == 2, "action dim must be 2" 159 | naction, path, nfe = karras_sample( 160 | diffusion, 161 | model, 162 | initial_samples, 163 | None, 164 | steps=model_params["num_diffusion_iters"], 165 | model_kwargs=initial_samples, 166 | global_cond=obs_cond_gc, 167 | device=device, 168 | clip_denoised=model_params["clip_denoised"], 169 | sampler="heun", 170 | sigma_min=diffusion.sigma_min, 171 | sigma_max=diffusion.sigma_max, 172 | churn_step_ratio=model_params["churn_step_ratio"], 173 | rho=model_params["rho"], 174 | guidance=model_params["guidance"] 175 | ) 176 | 177 | gc_actions = to_numpy(get_action(naction)) 178 | 179 | gc_actions *= scale_factor 180 | if args.path_visual: 181 | path_project_plot(context_queue[-1], path, args, camera_extrinsics, camera_intrinsics) 182 | 183 | 184 | if __name__ == "__main__": 185 | parser = argparse.ArgumentParser( 186 | description="Code to run DIFFUSION NAVIGATION demo") 187 | parser.add_argument( 188 | "--model", 189 | "-m", 190 | default="navibridge_noise", 191 | type=str, 192 | help="model name (hint: check ../config/models.yaml)", 193 | ) 194 | parser.add_argument( 195 | "--waypoint", 196 | "-w", 197 | default=2, # close waypoints exihibit straight line motion (the middle waypoint is a good default) 
198 | type=int, 199 | help=f"""index of the waypoint used for navigation (between 0 and 4 or 200 | how many waypoints your model predicts) (default: 2)""", 201 | ) 202 | parser.add_argument( 203 | "--num-samples", 204 | "-n", 205 | default=100, 206 | type=int, 207 | help=f"Number of actions sampled from the exploration model (default: 8)", 208 | ) 209 | parser.add_argument( 210 | "--path-visual", 211 | default=True, 212 | type=bool, 213 | help="visualization", 214 | ) 215 | parser.add_argument( 216 | "--device", 217 | "-d", 218 | default=0, 219 | type=int, 220 | help="device", 221 | ) 222 | args = parser.parse_args() 223 | print(f"Using {device}") 224 | main(args) 225 | -------------------------------------------------------------------------------- /train/vint_train/models/navibridge/navibridge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import pdb 5 | import numpy as np 6 | import yaml 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from vint_train.training.train_utils import get_delta, normalize_data 12 | from vint_train.visualizing.visualize_utils import to_numpy, from_numpy 13 | 14 | # LOAD DATA CONFIG 15 | with open(os.path.join(os.path.dirname(__file__), "../../data/data_config.yaml"), "r") as f: 16 | data_config = yaml.safe_load(f) 17 | # POPULATE ACTION STATS 18 | ACTION_STATS = {} 19 | for key in data_config['action_stats']: 20 | ACTION_STATS[key] = np.array(data_config['action_stats'][key]) 21 | 22 | class NaviBridge(nn.Module): 23 | 24 | def __init__(self, vision_encoder, 25 | noise_pred_net, 26 | dist_pred_net, 27 | states_pred_net): 28 | super(NaviBridge, self).__init__() 29 | 30 | 31 | self.vision_encoder = vision_encoder 32 | self.noise_pred_net = noise_pred_net 33 | self.dist_pred_net = dist_pred_net 34 | self.states_pred_net = states_pred_net 35 | self.prior_model = None 36 | 37 | def forward(self, func_name, **kwargs): 38 | if func_name == "vision_encoder" : 39 | output = self.vision_encoder(kwargs["obs_img"], kwargs["goal_img"], input_goal_mask=kwargs["input_goal_mask"]) 40 | elif func_name == "noise_pred_net": 41 | output = self.noise_pred_net(sample=kwargs["sample"], timestep=kwargs["timestep"], global_cond=kwargs["global_cond"]) 42 | elif func_name == "dist_pred_net": 43 | output = self.dist_pred_net(kwargs["obsgoal_cond"]) 44 | elif func_name == "states_pred_net": 45 | output = self.states_pred_net(kwargs["obsgoal_cond"]) 46 | else: 47 | raise NotImplementedError 48 | return output 49 | 50 | 51 | class DenseNetwork(nn.Module): 52 | def __init__(self, embedding_dim): 53 | super(DenseNetwork, self).__init__() 54 | 55 | self.embedding_dim = embedding_dim 56 | self.network = nn.Sequential( 57 | nn.Linear(self.embedding_dim, self.embedding_dim//4), 58 | nn.ReLU(), 59 | nn.Linear(self.embedding_dim//4, self.embedding_dim//16), 60 | nn.ReLU(), 61 | nn.Linear(self.embedding_dim//16, 1) 62 | ) 63 | 64 | def forward(self, x): 65 | x = x.reshape((-1, self.embedding_dim)) 66 | output = self.network(x) 67 | return output 68 | 69 | 70 | class StatesPredNet(nn.Module): 71 | def __init__(self, embedding_dim, class_num, len_traj_pred): 72 | super(StatesPredNet, self).__init__() 73 | self.output_length = class_num + len_traj_pred * 2 74 | self.embedding_dim = embedding_dim 75 | self.class_num = class_num 76 | self.len_traj_pred = len_traj_pred 77 | 78 | self.network = nn.Sequential( 79 | nn.Linear(self.embedding_dim, self.embedding_dim // 4), 80 | nn.ReLU(), 81 | 
nn.Linear(self.embedding_dim // 4, self.embedding_dim // 16), 82 | nn.ReLU(), 83 | nn.Linear(self.embedding_dim // 16, self.output_length) 84 | ) 85 | 86 | def forward(self, x): 87 | x = x.reshape((-1, self.embedding_dim)) 88 | output = self.network(x) 89 | 90 | class_logits = output[:, :self.class_num] 91 | coords = output[:, self.class_num:] 92 | 93 | class_probs = nn.functional.softmax(class_logits, dim=-1) 94 | 95 | coords = coords.reshape((-1, self.len_traj_pred, 2)) 96 | 97 | return class_probs, coords 98 | 99 | 100 | class PriorModel: 101 | def __init__(self, prior=None, action_dim=2, len_traj_pred=8): 102 | self.prior = prior 103 | self.action_dim = action_dim 104 | self.len_traj_pred = len_traj_pred 105 | 106 | def sample(self, cond, device, num_samples=1): 107 | states = cond 108 | 109 | if self.prior is None: 110 | if type(states) is np.array: 111 | states_torch = torch.as_tensor(states) 112 | else: 113 | states_torch = states 114 | 115 | if isinstance(states_torch, tuple): 116 | prior_sample = torch.randn((states_torch[1].shape[0], self.len_traj_pred, self.action_dim), device=device) 117 | else: 118 | prior_sample = torch.randn((states_torch.shape[0], self.len_traj_pred, self.action_dim), device=device) 119 | else: 120 | if type(states) is np.array: 121 | states_torch = torch.as_tensor(states) 122 | else: 123 | states_torch = states 124 | 125 | prior_sample = self.prior.sample_prior(states_torch, num_samples=num_samples, device=device) 126 | 127 | prior_sample = prior_sample.to(device) 128 | 129 | return prior_sample 130 | 131 | class Prior_HandCraft: 132 | def __init__(self, 133 | class_num: int = 5, 134 | angle_ranges: list = None, # [(0, 67.5), (67.5, 112.5), (112.5, 180), (180, 270), (270, 360)] 135 | min_std_angle: float = 5.0, 136 | max_std_angle: float = 20.0, 137 | min_std_length: float = 1.0, 138 | max_std_length: float = 5.0, 139 | num_samples: int = 1, 140 | len_traj_pred=8, 141 | ): 142 | if angle_ranges is None: 143 | self._set_angle_ranges() 144 | else: 145 | self.angle_ranges = angle_ranges 146 | self.class_num = class_num 147 | self.min_std_angle = min_std_angle 148 | self.max_std_angle = max_std_angle 149 | self.min_std_length = min_std_length 150 | self.max_std_length = max_std_length 151 | self.num_samples = num_samples 152 | self.len_traj_pred = len_traj_pred 153 | 154 | def _set_angle_ranges(self): 155 | self.angle_ranges = [(0, 67.5), 156 | (67.5, 112.5), 157 | (112.5, 180), 158 | (180, 270), 159 | (270, 360)] 160 | 161 | def _extract_length(self, coords): 162 | x_end, y_end = coords[-1] 163 | length = np.sqrt(x_end**2 + y_end**2) 164 | 165 | return length 166 | 167 | def _exact_states(self, states_pred): 168 | class_probs, coords_pred = states_pred 169 | return class_probs.cpu(), coords_pred.cpu() 170 | 171 | def _convert_confidence_to_std(self, confidence, min_std=0.5, max_std=5.0): 172 | return min_std + (max_std - min_std) * (1 - confidence) 173 | 174 | def _preprocess(self, actions): 175 | actions = np.squeeze(actions) 176 | if actions.ndim == 2: 177 | actions = np.expand_dims(actions, axis=0) 178 | deltas = get_delta(actions) 179 | ndeltas = normalize_data(deltas, ACTION_STATS) 180 | naction = from_numpy(ndeltas) 181 | return naction 182 | 183 | def sample_prior(self, states, num_samples=None, device=None): 184 | if num_samples is not None: 185 | self.num_samples = num_samples 186 | 187 | class_probs, coords_pred = self._exact_states(states) 188 | batch_size = class_probs.shape[0] 189 | sampled_trajectories_batch = [] 190 | 191 | for batch_idx in 
range(batch_size): 192 | class_probs_single = class_probs[batch_idx] 193 | coords_pred_single = coords_pred[batch_idx] 194 | 195 | self.path_length_mean = self._extract_length(coords_pred_single) 196 | length_confidence = class_probs_single 197 | 198 | sampled_trajectories = [] 199 | 200 | for _ in range(self.num_samples): 201 | category = torch.multinomial(class_probs_single, num_samples=1).item() 202 | 203 | angle_range = self.angle_ranges[category] 204 | 205 | angle_std = self._convert_confidence_to_std( 206 | class_probs_single[category], 207 | min_std=self.min_std_angle, 208 | max_std=self.max_std_angle 209 | ) 210 | 211 | length_std = self._convert_confidence_to_std( 212 | length_confidence[category], 213 | min_std=self.min_std_length, 214 | max_std=self.max_std_length 215 | ) 216 | 217 | angle_mean = (angle_range[0] + angle_range[1]) / 2 218 | angle_sample = np.random.normal(loc=angle_mean, scale=angle_std) 219 | 220 | path_length_sample = np.random.normal(loc=self.path_length_mean, scale=length_std) 221 | 222 | x_end = path_length_sample * np.cos(np.radians(angle_sample)) 223 | y_end = path_length_sample * np.sin(np.radians(angle_sample)) 224 | 225 | h_mid = (0 + x_end) / 2 226 | h_range = (x_end - 0) / 2 227 | if y_end >= 0: 228 | h = np.random.uniform(x_end, x_end * 2) 229 | else: 230 | h = np.random.uniform(h_mid - h_range / 2, h_mid - h_range / 4) 231 | 232 | a = (y_end - 0) / ((x_end - 0) * (x_end + 0 - 2 * h)) 233 | k = -a * h**2 234 | 235 | x_vals = np.linspace(0, x_end, self.len_traj_pred) 236 | y_vals = a * (x_vals - h) ** 2 + k 237 | combined = np.stack((y_vals, x_vals), axis=1) 238 | sampled_trajectories.append(combined) 239 | 240 | sampled_trajectories = np.stack(sampled_trajectories, axis=0) 241 | 242 | sampled_trajectories_batch.append(sampled_trajectories) 243 | 244 | actions = np.stack(sampled_trajectories_batch, axis=0) 245 | naction = self._preprocess(actions) 246 | return naction 247 | -------------------------------------------------------------------------------- /deployment/src/utils_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import io 4 | import matplotlib.pyplot as plt 5 | 6 | # pytorch 7 | import torch 8 | import torch.nn as nn 9 | from torchvision import transforms 10 | import torchvision.transforms.functional as TF 11 | 12 | import numpy as np 13 | from PIL import Image as PILImage 14 | import cv2 15 | from typing import List, Tuple, Dict, Optional 16 | import importlib.resources as pkg_resources 17 | from tqdm import tqdm 18 | 19 | # models 20 | from vint_train.models.navibridge.navibridge import NaviBridge, DenseNetwork, StatesPredNet 21 | from vint_train.models.navibridge.navibridg_utils import NaviBridge_Encoder, replace_bn_with_gn 22 | from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D 23 | from vint_train.data.data_utils import IMAGE_ASPECT_RATIO 24 | 25 | from vint_train.models.navibridge.ddbm.karras_diffusion import KarrasDenoiser 26 | from vint_train.models.navibridge.ddbm.resample import create_named_schedule_sampler 27 | from vint_train.models.navibridge.navibridge import PriorModel, Prior_HandCraft 28 | from vint_train.models.navibridge.vae.vae import VAEModel 29 | from vint_train.models.model_utils import set_model_prior 30 | 31 | def load_model( 32 | model_path: str, 33 | config: dict, 34 | device: torch.device = torch.device("cpu"), 35 | ) -> nn.Module: 36 | """Load a model from a checkpoint file (works with models trained on 
multiple GPUs)""" 37 | model_type = config["model_type"] 38 | 39 | if config["model_type"] == "navibridge": 40 | # Select and configure the vision encoder 41 | if config["vision_encoder"] == "navibridge_encoder": 42 | vision_encoder = NaviBridge_Encoder( 43 | obs_encoding_size=config["encoding_size"], 44 | context_size=config["context_size"], 45 | mha_num_attention_heads=config["mha_num_attention_heads"], 46 | mha_num_attention_layers=config["mha_num_attention_layers"], 47 | mha_ff_dim_factor=config["mha_ff_dim_factor"], 48 | ) 49 | vision_encoder = replace_bn_with_gn(vision_encoder) 50 | else: 51 | raise ValueError(f"Vision encoder {config['vision_encoder']} not supported") 52 | 53 | # Create the noise prediction network and distribution prediction network 54 | noise_pred_net = ConditionalUnet1D( 55 | input_dim=2, 56 | global_cond_dim=config["encoding_size"], 57 | down_dims=config["down_dims"], 58 | cond_predict_scale=config["cond_predict_scale"], 59 | ) 60 | dist_pred_network = DenseNetwork(embedding_dim=config["encoding_size"]) 61 | 62 | states_pred_net = StatesPredNet(embedding_dim=config["encoding_size"], 63 | class_num=config["class_num"], 64 | len_traj_pred=config["len_traj_pred"]) 65 | 66 | model = NaviBridge( 67 | vision_encoder=vision_encoder, 68 | noise_pred_net=noise_pred_net, 69 | dist_pred_net=dist_pred_network, 70 | states_pred_net=states_pred_net, 71 | ) 72 | set_model_prior(model, config, device) 73 | elif config["model_type"] == "cvae": 74 | model = VAEModel(config) 75 | else: 76 | raise ValueError(f"Invalid model type: {model_type}") 77 | 78 | checkpoint = torch.load(model_path, map_location=device) 79 | 80 | if model_type == "navibridge": 81 | state_dict = checkpoint 82 | model.load_state_dict(state_dict, strict=False) 83 | else: 84 | loaded_model = checkpoint["model"] 85 | try: 86 | state_dict = loaded_model.module.state_dict() 87 | model.load_state_dict(state_dict, strict=False) 88 | except AttributeError as e: 89 | state_dict = loaded_model.state_dict() 90 | model.load_state_dict(state_dict, strict=False) 91 | model.to(device) 92 | return model 93 | 94 | def from_numpy(array: np.ndarray) -> torch.Tensor: 95 | return torch.from_numpy(array).float() 96 | 97 | def to_numpy(tensor): 98 | return tensor.cpu().detach().numpy() 99 | 100 | 101 | def transform_images(pil_imgs: List[PILImage.Image], image_size: List[int], center_crop: bool = False) -> torch.Tensor: 102 | """Transforms a list of PIL image to a torch tensor.""" 103 | transform_type = transforms.Compose( 104 | [ 105 | transforms.ToTensor(), 106 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ 107 | 0.229, 0.224, 0.225]), 108 | ] 109 | ) 110 | if type(pil_imgs) != list: 111 | pil_imgs = [pil_imgs] 112 | transf_imgs = [] 113 | for pil_img in pil_imgs: 114 | w, h = pil_img.size 115 | if center_crop: 116 | if w > h: 117 | pil_img = TF.center_crop(pil_img, (h, int(h * IMAGE_ASPECT_RATIO))) # crop to the right ratio 118 | else: 119 | pil_img = TF.center_crop(pil_img, (int(w / IMAGE_ASPECT_RATIO), w)) 120 | pil_img = pil_img.resize(image_size) 121 | transf_img = transform_type(pil_img) 122 | transf_img = torch.unsqueeze(transf_img, 0) 123 | transf_imgs.append(transf_img) 124 | return torch.cat(transf_imgs, dim=1) 125 | 126 | def ensure_pil_image(item): 127 | if isinstance(item, PILImage.Image): 128 | return item 129 | elif isinstance(item, np.ndarray): 130 | return PILImage.fromarray(item) 131 | else: 132 | raise ValueError(f"Unsupported data type: {type(item)}") 133 | 134 | def alternate_merge(list1, list2): 135 
| list1 = [ensure_pil_image(item) for item in list1] 136 | list2 = [ensure_pil_image(item) for item in list2] 137 | 138 | merged_list = [None] * (len(list1) + len(list2)) 139 | merged_list[::2] = list1 140 | merged_list[1::2] = list2 141 | 142 | return merged_list 143 | 144 | # clip angle between -pi and pi 145 | def clip_angle(angle): 146 | return np.mod(angle + np.pi, 2 * np.pi) - np.pi 147 | 148 | def find_images(directory): 149 | files = os.listdir(directory) 150 | image_files = [file for file in files if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))] 151 | return [os.path.join(directory, file) for file in image_files] 152 | 153 | def project_and_draw(image, trajectories, extrinsic_matrix, intrinsic_matrix): 154 | """ 155 | Project trajectories from world coordinates to pixel coordinates and draw them on the image using matplotlib. 156 | Note that the projection is not accurate, only for visualization. 157 | Args: 158 | image (np.array): The image on which to draw the trajectories. 159 | trajectories (list of list of tuples): Each trajectory is a list of (x, y, z) tuples. 160 | extrinsic_matrix (np.array): The 4x4 extrinsic matrix (world to camera). 161 | intrinsic_matrix (np.array): The 3x4 intrinsic matrix (camera to image). 162 | 163 | Returns: 164 | None: Displays the image with trajectories drawn using matplotlib. 165 | """ 166 | # Convert image to numpy array 167 | image_np = np.array(image) 168 | img_height, img_width = image_np.shape[:2] 169 | 170 | # Set up matplotlib figure and axis 171 | fig, ax = plt.subplots(figsize=(img_width / 100, img_height / 100), dpi=100) 172 | ax.imshow(image_np) # Show the image 173 | 174 | # Plot each trajectory 175 | for trajectory in trajectories: 176 | points = [] 177 | for (x, y) in trajectory: 178 | # Convert from world to camera coordinates 179 | camera_coords = world_to_camera(np.array([x, y, -1.5]), extrinsic_matrix) 180 | # Project to pixel coordinates 181 | pixel_coords = camera_to_pixel(camera_coords, intrinsic_matrix) 182 | u, v = int(pixel_coords[0]), int(pixel_coords[1]) 183 | u += (image.size[0] // 2) 184 | # v += (image.size[1] // 2) 185 | points.append((u, v)) 186 | 187 | # Separate u and v coordinates for plotting 188 | if points: 189 | u_coords, v_coords = zip(*points) # Unpack list of tuples into separate lists 190 | ax.plot(u_coords, v_coords, color='yellow', linewidth=8) # Plot trajectory 191 | 192 | # Limit the view to the image's dimensions only 193 | ax.set_xlim(0, img_width) 194 | ax.set_ylim(img_height, 0) # Reverse the y-axis to match image coordinates 195 | plt.subplots_adjust(left=0, right=1, top=1, bottom=0) 196 | 197 | # Display the final image with trajectories 198 | plt.axis('off') # Turn off axes for clarity 199 | return fig 200 | 201 | def world_to_camera(world_coords, extrinsic_matrix): 202 | """ 203 | Convert world coordinates to camera coordinates using the extrinsic matrix. 204 | 205 | Args: 206 | world_coords (np.array): Coordinates in the world frame. 207 | extrinsic_matrix (np.array): The 4x4 extrinsic matrix (world to camera). 208 | 209 | Returns: 210 | np.array: Coordinates in the camera frame. 
211 | """ 212 | # Convert to homogeneous coordinates 213 | world_coords_homogeneous = np.append(world_coords, 1) 214 | # Transform to camera coordinates 215 | camera_coords_homogeneous = np.linalg.inv(extrinsic_matrix) @ world_coords_homogeneous 216 | # Convert back to 3D and normalize 217 | return camera_coords_homogeneous[:3] / camera_coords_homogeneous[3] 218 | 219 | def camera_to_pixel(camera_coords, intrinsic_matrix): 220 | """ 221 | Project camera coordinates to pixel coordinates using the intrinsic matrix. 222 | 223 | Args: 224 | camera_coords (np.array): Coordinates in the camera frame. 225 | intrinsic_matrix (np.array): The 3x4 intrinsic matrix (camera to image). 226 | 227 | Returns: 228 | np.array: Coordinates in pixel space. 229 | """ 230 | # Focal lengths and principal point from intrinsic matrix 231 | fx, fy = intrinsic_matrix[0, 0], intrinsic_matrix[1, 1] 232 | cx, cy = intrinsic_matrix[0, 2], intrinsic_matrix[1, 2] 233 | X, Y, Z = camera_coords[0], camera_coords[1], camera_coords[2] 234 | 235 | # Project to pixel coordinates 236 | x = (fx * X / Z) + cx 237 | y = (fy * Y / Z) + cy 238 | 239 | return np.array([x, y]) # Convert to 2D pixel coordinates -------------------------------------------------------------------------------- /train/vint_train/training/train_eval_loop.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import os 3 | import numpy as np 4 | from typing import List, Optional, Dict 5 | from prettytable import PrettyTable 6 | 7 | from vint_train.training.train_utils import train_navibridge, evaluate_navibridge 8 | from vint_train.training.train_utils import train_cvae 9 | from vint_train.models.navibridge.ddbm.resample import * 10 | from vint_train.models.navibridge.ddbm.karras_diffusion import KarrasDenoiser 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.utils.data import DataLoader 16 | from torch.optim import Adam 17 | from torchvision import transforms 18 | 19 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 20 | from diffusers.training_utils import EMAModel 21 | 22 | def train_eval_loop_navibridge( 23 | train_model: bool, 24 | model: nn.Module, 25 | optimizer: Adam, 26 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 27 | noise_scheduler: UniformSampler, 28 | diffusuon: KarrasDenoiser, 29 | prior_policy: str, 30 | train_loader: DataLoader, 31 | test_dataloaders: Dict[str, DataLoader], 32 | transform: transforms, 33 | goal_mask_prob: float, 34 | epochs: int, 35 | device: torch.device, 36 | project_folder: str, 37 | print_log_freq: int = 100, 38 | wandb_log_freq: int = 10, 39 | image_log_freq: int = 1000, 40 | num_images_log: int = 8, 41 | current_epoch: int = 0, 42 | alpha: float = 1e-4, 43 | use_wandb: bool = True, 44 | eval_fraction: float = 0.25, 45 | eval_freq: int = 1, 46 | # ddbm params 47 | steps=10, 48 | clip_denoised: bool = True, 49 | sampler = "heun", 50 | sigma_min = 0.002, 51 | sigma_max = 80, 52 | churn_step_ratio = 0., 53 | rho = 7.0, 54 | guidance = 1, 55 | ): 56 | """ 57 | Train and evaluate the model for several epochs 58 | """ 59 | latest_path = os.path.join(project_folder, f"latest.pth") 60 | ema_model = EMAModel(model=model,power=0.75) 61 | 62 | for epoch in range(current_epoch, current_epoch + epochs): 63 | if train_model: 64 | print( 65 | f"Start ViNT DP Training Epoch {epoch}/{current_epoch + epochs - 1}" 66 | ) 67 | train_navibridge( 68 | model=model, 69 | ema_model=ema_model, 70 | optimizer=optimizer, 
71 | dataloader=train_loader, 72 | transform=transform, 73 | device=device, 74 | diffusion=diffusuon, 75 | noise_scheduler=noise_scheduler, 76 | prior_policy=prior_policy, 77 | goal_mask_prob=goal_mask_prob, 78 | project_folder=project_folder, 79 | epoch=epoch, 80 | print_log_freq=print_log_freq, 81 | wandb_log_freq=wandb_log_freq, 82 | image_log_freq=image_log_freq, 83 | num_images_log=num_images_log, 84 | use_wandb=use_wandb, 85 | alpha=alpha, 86 | # ddbm params 87 | steps=steps, 88 | clip_denoised=clip_denoised, 89 | sampler=sampler, 90 | sigma_min=sigma_min, 91 | sigma_max=sigma_max, 92 | churn_step_ratio=churn_step_ratio, 93 | rho=rho, 94 | guidance=guidance, 95 | ) 96 | lr_scheduler.step() 97 | 98 | numbered_path = os.path.join(project_folder, f"ema_{epoch}.pth") 99 | torch.save(ema_model.averaged_model.state_dict(), numbered_path) 100 | numbered_path = os.path.join(project_folder, f"ema_latest.pth") 101 | print(f"Saved EMA model to {numbered_path}") 102 | 103 | numbered_path = os.path.join(project_folder, f"{epoch}.pth") 104 | torch.save(model.state_dict(), numbered_path) 105 | torch.save(model.state_dict(), latest_path) 106 | print(f"Saved model to {numbered_path}") 107 | 108 | # save optimizer 109 | numbered_path = os.path.join(project_folder, f"optimizer_{epoch}.pth") 110 | latest_optimizer_path = os.path.join(project_folder, f"optimizer_latest.pth") 111 | torch.save(optimizer.state_dict(), latest_optimizer_path) 112 | 113 | # save scheduler 114 | numbered_path = os.path.join(project_folder, f"scheduler_{epoch}.pth") 115 | latest_scheduler_path = os.path.join(project_folder, f"scheduler_latest.pth") 116 | torch.save(lr_scheduler.state_dict(), latest_scheduler_path) 117 | 118 | 119 | if (epoch + 1) % eval_freq == 0: 120 | for dataset_type in test_dataloaders: 121 | print( 122 | f"Start {dataset_type} ViNT DP Testing Epoch {epoch}/{current_epoch + epochs - 1}" 123 | ) 124 | loader = test_dataloaders[dataset_type] 125 | evaluate_navibridge( 126 | eval_type=dataset_type, 127 | ema_model=ema_model, 128 | dataloader=loader, 129 | transform=transform, 130 | device=device, 131 | prior_policy=prior_policy, 132 | diffusion=diffusuon, 133 | noise_scheduler=noise_scheduler, 134 | goal_mask_prob=goal_mask_prob, 135 | project_folder=project_folder, 136 | epoch=epoch, 137 | print_log_freq=print_log_freq, 138 | num_images_log=num_images_log, 139 | wandb_log_freq=wandb_log_freq, 140 | use_wandb=use_wandb, 141 | eval_fraction=eval_fraction, 142 | # ddbm params 143 | steps=steps, 144 | clip_denoised=clip_denoised, 145 | sampler=sampler, 146 | sigma_min=sigma_min, 147 | sigma_max=sigma_max, 148 | churn_step_ratio=churn_step_ratio, 149 | rho=rho, 150 | guidance=guidance, 151 | ) 152 | wandb.log({ 153 | "lr": optimizer.param_groups[0]["lr"], 154 | }, commit=False) 155 | 156 | if lr_scheduler is not None: 157 | lr_scheduler.step() 158 | 159 | # log average eval loss 160 | wandb.log({}, commit=False) 161 | 162 | wandb.log({ 163 | "lr": optimizer.param_groups[0]["lr"], 164 | }, commit=False) 165 | 166 | 167 | # Flush the last set of eval logs 168 | wandb.log({}) 169 | print() 170 | 171 | def train_eval_loop_cvae( 172 | train_model: bool, 173 | model: nn.Module, 174 | optimizer:Adam, 175 | lr_scheduler: torch.optim.lr_scheduler.StepLR, 176 | train_loader: DataLoader, 177 | transform: transforms, 178 | # num_itr: int, 179 | epochs: int, 180 | prior_policy: str, 181 | # model_args, 182 | device: torch.device, 183 | project_folder: str, 184 | print_log_freq: int = 100, 185 | wandb_log_freq: int = 10, 186 | 
current_epoch: int = 0, 187 | save_freq=1, 188 | use_wandb: bool = True, 189 | ): 190 | """ 191 | Training and (optional) evaluation loop for the CVAE prior model. 192 | 193 | Args: 194 | train_model (bool): Whether to run the training phase each epoch. 195 | model (torch.nn.Module): The CVAE model to be trained. 196 | optimizer (torch.optim.Optimizer): The optimizer used to update model parameters. 197 | lr_scheduler (torch.optim.lr_scheduler.StepLR): Learning rate scheduler, stepped once per epoch. 198 | train_loader (torch.utils.data.DataLoader): Dataloader for training data. 199 | epochs (int): Number of epochs to train, starting from current_epoch. 200 | project_folder (str): Path where model checkpoints and optimizer/scheduler states are saved. 201 | print_log_freq / wandb_log_freq (int, optional): Frequency of console and wandb logging. 202 | save_freq (int, optional): Frequency of saving the model in epochs, defaults to 1. 203 | """ 204 | 205 | for epoch in range(current_epoch, current_epoch + epochs): 206 | if train_model: 207 | print( 208 | f"Start ViNT DP Training Epoch {epoch}/{current_epoch + epochs - 1}" 209 | ) 210 | train_cvae( 211 | model=model, 212 | optimizer=optimizer, 213 | dataloader=train_loader, 214 | transform=transform, 215 | device=device, 216 | project_folder=project_folder, 217 | epoch=epoch, 218 | print_log_freq=print_log_freq, 219 | wandb_log_freq=wandb_log_freq, 220 | use_wandb=use_wandb, 221 | ) 222 | lr_scheduler.step() 223 | 224 | # Save model checkpoints and optimizer state 225 | if (epoch + 1) % save_freq == 0: 226 | numbered_path = os.path.join(project_folder, f"cvae_{epoch}.pth") 227 | model.save_model(ckpt_path=numbered_path, epoch=epoch) 228 | 229 | # save optimizer 230 | latest_optimizer_path = os.path.join(project_folder, f"optimizer_latest.pth") 231 | torch.save(optimizer.state_dict(), latest_optimizer_path) 232 | 233 | # save scheduler 234 | latest_scheduler_path = os.path.join(project_folder, f"scheduler_latest.pth") 235 | torch.save(lr_scheduler.state_dict(), latest_scheduler_path) 236 | 237 | if lr_scheduler is not None: 238 | lr_scheduler.step() 239 | 240 | def load_model(model, model_type, checkpoint: dict) -> None: 241 | """Load model from checkpoint.""" 242 | if model_type == "navibridge": 243 | state_dict = checkpoint 244 | model.load_state_dict(state_dict, strict=False) 245 | else: 246 | loaded_model = checkpoint["model"] 247 | try: 248 | state_dict = loaded_model.module.state_dict() 249 | model.load_state_dict(state_dict, strict=False) 250 | except AttributeError as e: 251 | state_dict = loaded_model.state_dict() 252 | model.load_state_dict(state_dict, strict=False) 253 | 254 | 255 | def load_ema_model(ema_model, state_dict: dict) -> None: 256 | """Load EMA model weights from a state dict.""" 257 | ema_model.load_state_dict(state_dict) 258 | 259 | 260 | def count_parameters(model): 261 | table = PrettyTable(["Modules", "Parameters"]) 262 | total_params = 0 263 | for name, parameter in model.named_parameters(): 264 | if not parameter.requires_grad: continue 265 | params = parameter.numel() 266 | table.add_row([name, params]) 267 | total_params+=params 268 | 269 | print(f"Total Trainable Params: {total_params/1e6:.2f}M") 270 | return total_params -------------------------------------------------------------------------------- /train/vint_train/process_data/process_data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import io 3 | import os 4 | import rosbag 5 | from PIL import
Image 6 | import cv2 7 | from typing import Any, Tuple, List, Dict 8 | import torchvision.transforms.functional as TF 9 | 10 | IMAGE_SIZE = (160, 120) 11 | IMAGE_ASPECT_RATIO = 4 / 3 12 | 13 | 14 | def process_images(im_list: List, img_process_func) -> List: 15 | """ 16 | Process image data from a topic that publishes ros images into a list of PIL images 17 | """ 18 | images = [] 19 | for img_msg in im_list: 20 | img = img_process_func(img_msg) 21 | images.append(img) 22 | return images 23 | 24 | 25 | def process_tartan_img(msg) -> Image.Image: 26 | """ 27 | Process image data from a topic that publishes sensor_msgs/Image to a PIL image for the tartan_drive dataset 28 | """ 29 | img = ros_to_numpy(msg, output_resolution=IMAGE_SIZE) * 255 30 | img = img.astype(np.uint8) 31 | # reverse the axis order to get the image in the right orientation 32 | img = np.moveaxis(img, 0, -1) 33 | # convert rgb to bgr 34 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 35 | img = Image.fromarray(img) 36 | return img 37 | 38 | 39 | def process_locobot_img(msg) -> Image.Image: 40 | """ 41 | Process image data from a topic that publishes sensor_msgs/Image to a PIL image for the locobot dataset 42 | """ 43 | img = np.frombuffer(msg.data, dtype=np.uint8).reshape( 44 | msg.height, msg.width, -1) 45 | pil_image = Image.fromarray(img) 46 | return pil_image 47 | 48 | 49 | def process_scand_img(msg) -> Image.Image: 50 | """ 51 | Process image data from a topic that publishes sensor_msgs/CompressedImage to a PIL image for the scand dataset 52 | """ 53 | # convert sensor_msgs/CompressedImage to PIL image 54 | img = Image.open(io.BytesIO(msg.data)) 55 | # center crop image to 4:3 aspect ratio 56 | w, h = img.size 57 | img = TF.center_crop( 58 | img, (h, int(h * IMAGE_ASPECT_RATIO)) 59 | ) # crop to the right ratio 60 | # resize image to IMAGE_SIZE 61 | img = img.resize(IMAGE_SIZE) 62 | return img 63 | 64 | 65 | ############## Add custom image processing functions here ############# 66 | 67 | def process_sacson_img(msg) -> Image.Image: 68 | np_arr = np.frombuffer(msg.data, np.uint8) 69 | image_np = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) 70 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) 71 | pil_image = Image.fromarray(image_np) 72 | return pil_image 73 | 74 | 75 | ####################################################################### 76 | 77 | 78 | def process_odom( 79 | odom_list: List, 80 | odom_process_func: Any, 81 | ang_offset: float = 0.0, 82 | ) -> Dict[str, np.ndarray]: 83 | """ 84 | Process odom data from a topic that publishes nav_msgs/Odometry into position and yaw 85 | """ 86 | xys = [] 87 | yaws = [] 88 | for odom_msg in odom_list: 89 | xy, yaw = odom_process_func(odom_msg, ang_offset) 90 | xys.append(xy) 91 | yaws.append(yaw) 92 | return {"position": np.array(xys), "yaw": np.array(yaws)} 93 | 94 | 95 | def nav_to_xy_yaw(odom_msg, ang_offset: float) -> Tuple[List[float], float]: 96 | """ 97 | Process odom data from a topic that publishes nav_msgs/Odometry into position and yaw 98 | """ 99 | 100 | position = odom_msg.pose.pose.position 101 | orientation = odom_msg.pose.pose.orientation 102 | yaw = ( 103 | quat_to_yaw(orientation.x, orientation.y, orientation.z, orientation.w) 104 | + ang_offset 105 | ) 106 | return [position.x, position.y], yaw 107 | 108 | 109 | ############ Add custom odometry processing functions here ############ 110 | 111 | 112 | ####################################################################### 113 | 114 | 115 | def get_images_and_odom( 116 | bag: rosbag.Bag, 117 | imtopics: List[str] or str, 118 |
odomtopics: List[str] or str, 119 | img_process_func: Any, 120 | odom_process_func: Any, 121 | rate: float = 4.0, 122 | ang_offset: float = 0.0, 123 | ): 124 | """ 125 | Get image and odom data from a bag file 126 | 127 | Args: 128 | bag (rosbag.Bag): bag file 129 | imtopics (list[str] or str): topic name(s) for image data 130 | odomtopics (list[str] or str): topic name(s) for odom data 131 | img_process_func (Any): function to process image data 132 | odom_process_func (Any): function to process odom data 133 | rate (float, optional): rate to sample data. Defaults to 4.0. 134 | ang_offset (float, optional): angle offset to add to odom data. Defaults to 0.0. 135 | Returns: 136 | img_data (list): list of PIL images 137 | traj_data (dict): dictionary with "position" and "yaw" arrays 138 | """ 139 | # check if bag has both topics 140 | odomtopic = None 141 | imtopic = None 142 | if type(imtopics) == str: 143 | imtopic = imtopics 144 | else: 145 | for imt in imtopics: 146 | if bag.get_message_count(imt) > 0: 147 | imtopic = imt 148 | break 149 | if type(odomtopics) == str: 150 | odomtopic = odomtopics 151 | else: 152 | for ot in odomtopics: 153 | if bag.get_message_count(ot) > 0: 154 | odomtopic = ot 155 | break 156 | if not (imtopic and odomtopic): 157 | # bag doesn't have both topics 158 | return None, None 159 | 160 | synced_imdata = [] 161 | synced_odomdata = [] 162 | # get start time of bag in seconds 163 | currtime = bag.get_start_time() 164 | 165 | curr_imdata = None 166 | curr_odomdata = None 167 | 168 | for topic, msg, t in bag.read_messages(topics=[imtopic, odomtopic]): 169 | if topic == imtopic: 170 | curr_imdata = msg 171 | elif topic == odomtopic: 172 | curr_odomdata = msg 173 | if (t.to_sec() - currtime) >= 1.0 / rate: 174 | if curr_imdata is not None and curr_odomdata is not None: 175 | synced_imdata.append(curr_imdata) 176 | synced_odomdata.append(curr_odomdata) 177 | currtime = t.to_sec() 178 | 179 | img_data = process_images(synced_imdata, img_process_func) 180 | traj_data = process_odom( 181 | synced_odomdata, 182 | odom_process_func, 183 | ang_offset=ang_offset, 184 | ) 185 | 186 | return img_data, traj_data 187 | 188 | 189 | def is_backwards( 190 | pos1: np.ndarray, yaw1: float, pos2: np.ndarray, eps: float = 1e-5 191 | ) -> bool: 192 | """ 193 | Check if the trajectory is going backwards given the position and yaw of two points 194 | Args: 195 | pos1: position of the first point 196 | yaw1: yaw of the first point; pos2: position of the second point; eps: forward-progress tolerance 197 | """ 198 | dx, dy = pos2 - pos1 199 | return dx * np.cos(yaw1) + dy * np.sin(yaw1) < eps 200 | 201 | 202 | # cut out non-positive velocity segments of the trajectory 203 | def filter_backwards( 204 | img_list: List[Image.Image], 205 | traj_data: Dict[str, np.ndarray], 206 | start_slack: int = 0, 207 | end_slack: int = 0, 208 | ) -> List[Tuple[List, Dict]]: 209 | """ 210 | Cut out non-positive velocity segments of the trajectory 211 | Args: 212 | img_list: list of images 213 | traj_data: dictionary of position and yaw data 214 | start_slack: number of points to ignore at the start of the trajectory 215 | end_slack: number of points to ignore at the end of the trajectory 216 | 217 | Returns: 218 | cut_trajs: list of (img_list, traj_data) tuples, one per forward-moving segment 219 | 220 | """ 221 | traj_pos = traj_data["position"] 222 | traj_yaws = traj_data["yaw"] 223 | cut_trajs = [] 224 | start = True 225 | 226 | def process_pair(traj_pair: list) -> Tuple[List, Dict]: 227 | new_img_list, new_traj_data = zip(*traj_pair) 228 | new_traj_data =
np.array(new_traj_data) 229 | new_traj_pos = new_traj_data[:, :2] 230 | new_traj_yaws = new_traj_data[:, 2] 231 | return (new_img_list, {"position": new_traj_pos, "yaw": new_traj_yaws}) 232 | 233 | for i in range(max(start_slack, 1), len(traj_pos) - end_slack): 234 | pos1 = traj_pos[i - 1] 235 | yaw1 = traj_yaws[i - 1] 236 | pos2 = traj_pos[i] 237 | if not is_backwards(pos1, yaw1, pos2): 238 | if start: 239 | new_traj_pairs = [ 240 | (img_list[i - 1], [*traj_pos[i - 1], traj_yaws[i - 1]]) 241 | ] 242 | start = False 243 | elif i == len(traj_pos) - end_slack - 1: 244 | cut_trajs.append(process_pair(new_traj_pairs)) 245 | else: 246 | new_traj_pairs.append( 247 | (img_list[i - 1], [*traj_pos[i - 1], traj_yaws[i - 1]]) 248 | ) 249 | elif not start: 250 | cut_trajs.append(process_pair(new_traj_pairs)) 251 | start = True 252 | return cut_trajs 253 | 254 | 255 | def quat_to_yaw( 256 | x: np.ndarray, 257 | y: np.ndarray, 258 | z: np.ndarray, 259 | w: np.ndarray, 260 | ) -> np.ndarray: 261 | """ 262 | Convert a batch quaternion into a yaw angle 263 | yaw is rotation around z in radians (counterclockwise) 264 | """ 265 | t3 = 2.0 * (w * z + x * y) 266 | t4 = 1.0 - 2.0 * (y * y + z * z) 267 | yaw = np.arctan2(t3, t4) 268 | return yaw 269 | 270 | 271 | def ros_to_numpy( 272 | msg, nchannels=3, empty_value=None, output_resolution=None, aggregate="none" 273 | ): 274 | """ 275 | Convert a ROS image message to a numpy array 276 | """ 277 | if output_resolution is None: 278 | output_resolution = (msg.width, msg.height) 279 | 280 | is_rgb = "8" in msg.encoding 281 | if is_rgb: 282 | data = np.frombuffer(msg.data, dtype=np.uint8).copy() 283 | else: 284 | data = np.frombuffer(msg.data, dtype=np.float32).copy() 285 | 286 | data = data.reshape(msg.height, msg.width, nchannels) 287 | 288 | if empty_value: 289 | mask = np.isclose(abs(data), empty_value) 290 | fill_value = np.percentile(data[~mask], 99) 291 | data[mask] = fill_value 292 | 293 | data = cv2.resize( 294 | data, 295 | dsize=(output_resolution[0], output_resolution[1]), 296 | interpolation=cv2.INTER_AREA, 297 | ) 298 | 299 | if aggregate == "littleendian": 300 | data = sum([data[:, :, i] * (256**i) for i in range(nchannels)]) 301 | elif aggregate == "bigendian": 302 | data = sum([data[:, :, -(i + 1)] * (256**i) for i in range(nchannels)]) 303 | 304 | if len(data.shape) == 2: 305 | data = np.expand_dims(data, axis=0) 306 | else: 307 | data = np.moveaxis(data, 2, 0) # Switch to channels-first 308 | 309 | if is_rgb: 310 | data = data.astype(np.float32) / ( 311 | 255.0 if aggregate == "none" else 255.0**nchannels 312 | ) 313 | 314 | return data 315 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import wandb 4 | import argparse 5 | import numpy as np 6 | import yaml 7 | import time 8 | import pdb 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils.data import DataLoader, ConcatDataset 13 | 14 | from torchvision import transforms 15 | import torch.backends.cudnn as cudnn 16 | 17 | from vint_train.models.model_utils import create_noise_scheduler, create_model, get_optimizer_and_scheduler 18 | 19 | 20 | from vint_train.data.vint_dataset import ViNT_Dataset 21 | from vint_train.training.train_eval_loop import ( 22 | load_model, 23 | train_eval_loop_navibridge, 24 | train_eval_loop_cvae, 25 | ) 26 | 27 | 28 | def main(config): 29 | assert config["distance"]["min_dist_cat"] < 
config["distance"]["max_dist_cat"] 30 | assert config["action"]["min_dist_cat"] < config["action"]["max_dist_cat"] 31 | 32 | if torch.cuda.is_available(): 33 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 34 | if "gpu_ids" not in config: 35 | config["gpu_ids"] = [0] 36 | elif type(config["gpu_ids"]) == int: 37 | config["gpu_ids"] = [config["gpu_ids"]] 38 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( 39 | [str(x) for x in config["gpu_ids"]] 40 | ) 41 | print("Using cuda devices:", os.environ["CUDA_VISIBLE_DEVICES"]) 42 | else: 43 | print("Using cpu") 44 | 45 | first_gpu_id = config["gpu_ids"][0] 46 | device = torch.device( 47 | f"cuda:{first_gpu_id}" if torch.cuda.is_available() else "cpu" 48 | ) 49 | 50 | if "seed" in config: 51 | np.random.seed(config["seed"]) 52 | torch.manual_seed(config["seed"]) 53 | cudnn.deterministic = True 54 | 55 | cudnn.benchmark = True # good if input sizes don't vary 56 | transform = ([ 57 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 58 | ]) 59 | transform = transforms.Compose(transform) 60 | 61 | if config["model_type"] == "cvae": 62 | if "train_params" in config: 63 | config.update(config["train_params"]) 64 | if "diffuse_params" in config: 65 | config.update(config["diffuse_params"]) 66 | 67 | if config.get("prior_policy", None) == "cvae": 68 | if "diffuse_params" in config: 69 | config.update(config["diffuse_params"]) 70 | 71 | # Load the data 72 | train_dataset = [] 73 | test_dataloaders = {} 74 | 75 | if "context_type" not in config: 76 | config["context_type"] = "temporal" 77 | 78 | if "clip_goals" not in config: 79 | config["clip_goals"] = False 80 | 81 | for dataset_name in config["datasets"]: 82 | data_config = config["datasets"][dataset_name] 83 | if "negative_mining" not in data_config: 84 | data_config["negative_mining"] = True 85 | if "goals_per_obs" not in data_config: 86 | data_config["goals_per_obs"] = 1 87 | if "end_slack" not in data_config: 88 | data_config["end_slack"] = 0 89 | if "waypoint_spacing" not in data_config: 90 | data_config["waypoint_spacing"] = 1 91 | 92 | for data_split_type in ["train", "test"]: 93 | if data_split_type in data_config: 94 | dataset = ViNT_Dataset( 95 | data_folder=data_config["data_folder"], 96 | data_split_folder=data_config[data_split_type], 97 | dataset_name=dataset_name, 98 | image_size=config["image_size"], 99 | waypoint_spacing=data_config["waypoint_spacing"], 100 | min_dist_cat=config["distance"]["min_dist_cat"], 101 | max_dist_cat=config["distance"]["max_dist_cat"], 102 | min_action_distance=config["action"]["min_dist_cat"], 103 | max_action_distance=config["action"]["max_dist_cat"], 104 | negative_mining=data_config["negative_mining"], 105 | len_traj_pred=config["len_traj_pred"], 106 | learn_angle=config["learn_angle"], 107 | context_size=config["context_size"], 108 | context_type=config["context_type"], 109 | end_slack=data_config["end_slack"], 110 | goals_per_obs=data_config["goals_per_obs"], 111 | normalize=config["normalize"], 112 | goal_type=config["goal_type"], 113 | angle_ranges=config.get("angle_ranges", None), 114 | ) 115 | if data_split_type == "train": 116 | train_dataset.append(dataset) 117 | else: 118 | dataset_type = f"{dataset_name}_{data_split_type}" 119 | if dataset_type not in test_dataloaders: 120 | test_dataloaders[dataset_type] = {} 121 | test_dataloaders[dataset_type] = dataset 122 | 123 | # combine all the datasets from different robots 124 | train_dataset = ConcatDataset(train_dataset) 125 | 126 | train_loader = DataLoader( 127 | 
train_dataset, 128 | batch_size=config["batch_size"], 129 | shuffle=True, 130 | num_workers=config["num_workers"], 131 | drop_last=False, 132 | persistent_workers=True, 133 | ) 134 | 135 | if "eval_batch_size" not in config: 136 | config["eval_batch_size"] = config["batch_size"] 137 | 138 | for dataset_type, dataset in test_dataloaders.items(): 139 | test_dataloaders[dataset_type] = DataLoader( 140 | dataset, 141 | batch_size=config["eval_batch_size"], 142 | shuffle=True, 143 | num_workers=0, 144 | drop_last=False, 145 | ) 146 | 147 | model = create_model(config, device) 148 | if config["model_type"] == "navibridge": 149 | noise_scheduler, diffusion = create_noise_scheduler(config) 150 | elif config["model_type"] == "cvae": 151 | model.load_model(model_args=config, device=device) 152 | if config["clipping"]: 153 | print("Clipping gradients to", config["max_norm"]) 154 | for p in model.parameters(): 155 | if not p.requires_grad: 156 | continue 157 | p.register_hook( 158 | lambda grad: torch.clamp( 159 | grad, -1 * config["max_norm"], config["max_norm"] 160 | ) 161 | ) 162 | 163 | if config["model_type"] == "cvae": 164 | optimizer, scheduler = get_optimizer_and_scheduler(config, model.net) 165 | else: 166 | optimizer, scheduler = get_optimizer_and_scheduler(config, model) 167 | current_epoch = 0 168 | if "load_run" in config: 169 | load_project_folder = os.path.join("logs", config["load_run"]) 170 | print("Loading model from ", load_project_folder) 171 | latest_path = os.path.join(load_project_folder, "latest.pth") 172 | latest_checkpoint = torch.load(latest_path) #f"cuda:{}" if torch.cuda.is_available() else "cpu") 173 | load_model(model, config["model_type"], latest_checkpoint) 174 | if "epoch" in latest_checkpoint: 175 | current_epoch = latest_checkpoint["epoch"] + 1 176 | 177 | # Multi-GPU 178 | if len(config["gpu_ids"]) > 1: 179 | model = nn.DataParallel(model, device_ids=config["gpu_ids"]) 180 | if config["model_type"] == "cvae": 181 | model.net = model.net.to(device) 182 | else: 183 | model = model.to(device) 184 | 185 | if "load_run" in config: # load optimizer and scheduler after data parallel 186 | if "optimizer" in latest_checkpoint: 187 | optimizer.load_state_dict(latest_checkpoint["optimizer"].state_dict()) 188 | if scheduler is not None and "scheduler" in latest_checkpoint: 189 | scheduler.load_state_dict(latest_checkpoint["scheduler"].state_dict()) 190 | 191 | if config["model_type"] == "navibridge": 192 | train_eval_loop_navibridge( 193 | train_model=config["train"], 194 | model=model, 195 | diffusuon=diffusion, 196 | optimizer=optimizer, 197 | lr_scheduler=scheduler, 198 | noise_scheduler=noise_scheduler, 199 | prior_policy=config["prior_policy"], 200 | train_loader=train_loader, 201 | test_dataloaders=test_dataloaders, 202 | transform=transform, 203 | goal_mask_prob=config["goal_mask_prob"], 204 | epochs=config["epochs"], 205 | device=device, 206 | project_folder=config["project_folder"], 207 | print_log_freq=config["print_log_freq"], 208 | wandb_log_freq=config["wandb_log_freq"], 209 | image_log_freq=config["image_log_freq"], 210 | num_images_log=config["num_images_log"], 211 | current_epoch=current_epoch, 212 | alpha=float(config["alpha"]), 213 | use_wandb=config["use_wandb"], 214 | eval_fraction=config["eval_fraction"], 215 | eval_freq=config["eval_freq"], 216 | # ddbm params 217 | steps=config["num_diffusion_iters"], 218 | clip_denoised=config["clip_denoised"], 219 | sampler=config["sampler"], 220 | sigma_min=diffusion.sigma_min, 221 | sigma_max=diffusion.sigma_max, 222 
| churn_step_ratio=config["churn_step_ratio"], 223 | rho=config["rho"], 224 | guidance=config["guidance"], 225 | ) 226 | elif config["model_type"] == "cvae": 227 | train_eval_loop_cvae( 228 | train_model=config["train"], 229 | model=model, 230 | optimizer=optimizer, 231 | lr_scheduler=scheduler, 232 | train_loader=train_loader, 233 | transform=transform, 234 | epochs=config["epochs"], 235 | prior_policy=config["prior_policy"], 236 | device=device, 237 | project_folder=config["project_folder"], 238 | print_log_freq=config["print_log_freq"], 239 | wandb_log_freq=config["wandb_log_freq"], 240 | current_epoch=current_epoch, 241 | save_freq=config["save_freq"], 242 | use_wandb=config["use_wandb"], 243 | ) 244 | print("FINISHED TRAINING") 245 | 246 | 247 | if __name__ == "__main__": 248 | torch.multiprocessing.set_start_method("spawn") 249 | 250 | parser = argparse.ArgumentParser(description="Visual Navigation Transformer") 251 | 252 | # project setup 253 | parser.add_argument( 254 | "--config", 255 | "-c", 256 | default="config/navibridge.yaml", 257 | type=str, 258 | help="Path to the config file in train_config folder", 259 | ) 260 | args = parser.parse_args() 261 | 262 | with open("config/defaults.yaml", "r") as f: 263 | default_config = yaml.safe_load(f) 264 | 265 | config = default_config 266 | 267 | with open(args.config, "r") as f: 268 | user_config = yaml.safe_load(f) 269 | 270 | config.update(user_config) 271 | 272 | config["run_name"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S") 273 | config["project_folder"] = os.path.join( 274 | "logs", config["project_name"], config["run_name"] 275 | ) 276 | os.makedirs( 277 | config[ 278 | "project_folder" 279 | ], # should error if dir already exists to avoid overwriting and old project 280 | ) 281 | 282 | if config["use_wandb"]: 283 | wandb.login() 284 | wandb.init( 285 | project=config["project_name"], 286 | settings=wandb.Settings(start_method="fork"), 287 | entity="offline", # TODO: change this to your wandb entity 288 | mode="offline", 289 | ) 290 | 291 | # Manually copy the config file to the wandb run directory 292 | config_filename = os.path.basename(args.config) 293 | dest_config_path = os.path.join(wandb.run.dir, config_filename) 294 | shutil.copyfile(args.config, dest_config_path) 295 | 296 | # Save the copied config file to wandb 297 | wandb.save(dest_config_path, policy="now") 298 | 299 | wandb.run.name = config["run_name"] 300 | # Update the wandb args with the training configurations 301 | if wandb.run: 302 | wandb.config.update(config) 303 | 304 | print(config) 305 | main(config) 306 | -------------------------------------------------------------------------------- /train/vint_train/models/navibridge/ddbm/karras_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | based on: https://github.com/crowsonkb/k-diffusion 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from piq import LPIPS 10 | 11 | from .nn import mean_flat, append_dims, append_zero 12 | 13 | from functools import partial 14 | 15 | 16 | def vp_logsnr(t, beta_d, beta_min): 17 | t = th.as_tensor(t) 18 | return - th.log((0.5 * beta_d * (t ** 2) + beta_min * t).exp() - 1) 19 | 20 | def vp_logs(t, beta_d, beta_min): 21 | t = th.as_tensor(t) 22 | return -0.25 * t ** 2 * (beta_d) - 0.5 * t * beta_min 23 | 24 | class KarrasDenoiser: 25 | def __init__( 26 | self, 27 | sigma_data: float = 0.5, 28 | sigma_max=80.0, 29 | sigma_min=0.002, 30 | beta_d=2, 31 
| beta_min=0.1, 32 | cov_xy=0., 33 | rho=7.0, 34 | image_size=64, 35 | num_timesteps=10, 36 | weight_schedule="karras", 37 | pred_mode='both', 38 | loss_norm="lpips", 39 | ): 40 | self.sigma_data = sigma_data 41 | self.sigma_max = sigma_max 42 | self.sigma_min = sigma_min 43 | 44 | self.beta_d = beta_d 45 | self.beta_min = beta_min 46 | 47 | self.sigma_data_end = self.sigma_data 48 | self.cov_xy = cov_xy 49 | 50 | self.c = 1 51 | 52 | self.weight_schedule = weight_schedule 53 | self.pred_mode = pred_mode 54 | self.loss_norm = loss_norm 55 | self.rho = rho 56 | self.num_timesteps = num_timesteps 57 | self.image_size = image_size 58 | 59 | def get_snr(self, sigmas): 60 | if self.pred_mode.startswith('vp'): 61 | return vp_logsnr(sigmas, self.beta_d, self.beta_min).exp() 62 | else: 63 | return sigmas**-2 64 | 65 | def get_sigmas(self, sigmas): 66 | return sigmas 67 | 68 | def get_weightings(self, sigma): 69 | snrs = self.get_snr(sigma) 70 | 71 | if self.weight_schedule == "snr": 72 | weightings = snrs 73 | elif self.weight_schedule == "snr+1": 74 | weightings = snrs + 1 75 | elif self.weight_schedule == "karras": 76 | weightings = snrs + 1.0 / self.sigma_data**2 77 | elif self.weight_schedule.startswith("bridge_karras"): 78 | if self.pred_mode == 've': 79 | A = sigma**4 / self.sigma_max**4 * self.sigma_data_end**2 + (1 - sigma**2 / self.sigma_max**2)**2 * self.sigma_data**2 + 2*sigma**2 / self.sigma_max**2 * (1 - sigma**2 / self.sigma_max**2) * self.cov_xy + self.c**2 * sigma**2 * (1 - sigma**2 / self.sigma_max**2) 80 | weightings = A / ((sigma/self.sigma_max)**4 * (self.sigma_data_end**2 * self.sigma_data**2 - self.cov_xy**2) + self.sigma_data**2 * self.c**2 * sigma**2 * (1 - sigma**2/self.sigma_max**2) ) 81 | 82 | elif self.pred_mode == 'vp': 83 | logsnr_t = vp_logsnr(sigma, self.beta_d, self.beta_min) 84 | logsnr_T = vp_logsnr(1, self.beta_d, self.beta_min) 85 | logs_t = vp_logs(sigma, self.beta_d, self.beta_min) 86 | logs_T = vp_logs(1, self.beta_d, self.beta_min) 87 | 88 | a_t = (logsnr_T - logsnr_t + logs_t - logs_T).exp() 89 | b_t = -th.expm1(logsnr_T - logsnr_t) * logs_t.exp() 90 | c_t = -th.expm1(logsnr_T - logsnr_t) * (2*logs_t - logsnr_t).exp() 91 | 92 | A = a_t**2 * self.sigma_data_end**2 + b_t**2 * self.sigma_data**2 + 2*a_t * b_t * self.cov_xy + self.c**2 * c_t 93 | weightings = A / (a_t**2 * (self.sigma_data_end**2 * self.sigma_data**2 - self.cov_xy**2) + self.sigma_data**2 * self.c**2 * c_t ) 94 | 95 | elif self.pred_mode == 'vp_simple' or self.pred_mode == 've_simple': 96 | weightings = th.ones_like(snrs) 97 | elif self.weight_schedule == "truncated-snr": 98 | weightings = th.clamp(snrs, min=1.0) 99 | elif self.weight_schedule == "uniform": 100 | weightings = th.ones_like(snrs) 101 | else: 102 | raise NotImplementedError() 103 | 104 | weightings = th.where(th.isfinite(weightings), weightings, th.zeros_like(weightings)) 105 | return weightings 106 | 107 | def get_bridge_scalings(self, sigma): 108 | if self.pred_mode == 've': 109 | A = sigma**4 / self.sigma_max**4 * self.sigma_data_end**2 + \ 110 | (1 - sigma**2 / self.sigma_max**2)**2 * self.sigma_data**2 + \ 111 | 2*sigma**2 / self.sigma_max**2 * (1 - sigma**2 / self.sigma_max**2) * self.cov_xy + \ 112 | self.c **2 * sigma**2 * (1 - sigma**2 / self.sigma_max**2) 113 | c_in = 1 / (A) ** 0.5 114 | c_skip = ((1 - sigma**2 / self.sigma_max**2) * self.sigma_data**2 + sigma**2 / self.sigma_max**2 * self.cov_xy) / A 115 | c_out = ((sigma/self.sigma_max)**4 * (self.sigma_data_end**2 * self.sigma_data**2 - self.cov_xy**2) + \ 116 | 
self.sigma_data**2 * self.c **2 * sigma**2 * (1 - sigma**2/self.sigma_max**2) )**0.5 * c_in 117 | return c_skip, c_out, c_in 118 | 119 | elif self.pred_mode == 'vp': 120 | logsnr_t = vp_logsnr(sigma, self.beta_d, self.beta_min) 121 | logsnr_T = vp_logsnr(1, self.beta_d, self.beta_min) 122 | logs_t = vp_logs(sigma, self.beta_d, self.beta_min) 123 | logs_T = vp_logs(1, self.beta_d, self.beta_min) 124 | 125 | a_t = (logsnr_T - logsnr_t + logs_t - logs_T).exp() 126 | b_t = -th.expm1(logsnr_T - logsnr_t) * logs_t.exp() 127 | c_t = -th.expm1(logsnr_T - logsnr_t) * (2*logs_t - logsnr_t).exp() 128 | 129 | A = a_t**2 * self.sigma_data_end**2 + b_t**2 * self.sigma_data**2 + \ 130 | 2*a_t * b_t * self.cov_xy + self.c**2 * c_t 131 | 132 | c_in = 1 / (A) ** 0.5 133 | c_skip = (b_t * self.sigma_data**2 + a_t * self.cov_xy) / A 134 | c_out = (a_t**2 * (self.sigma_data_end**2 * self.sigma_data**2 - self.cov_xy**2) + \ 135 | self.sigma_data**2 * self.c **2 * c_t )**0.5 * c_in 136 | return c_skip, c_out, c_in 137 | 138 | elif self.pred_mode == 've_simple' or self.pred_mode == 'vp_simple': 139 | c_in = th.ones_like(sigma) 140 | c_out = th.ones_like(sigma) 141 | c_skip = th.zeros_like(sigma) 142 | return c_skip, c_out, c_in 143 | 144 | def training_bridge_losses(self, model, x_start, sigmas, global_cond=None, model_kwargs=None, noise=None, vae=None, train=True): 145 | assert model_kwargs is not None 146 | xT = model_kwargs['xT'] 147 | if noise is None: 148 | noise = th.randn_like(x_start) 149 | sigmas = th.minimum(sigmas, th.ones_like(sigmas) * self.sigma_max) 150 | terms = {} 151 | 152 | dims = x_start.ndim 153 | 154 | def bridge_sample(x0, xT, t): 155 | t = append_dims(t, dims) 156 | if self.pred_mode.startswith('ve'): 157 | std_t = t * th.sqrt(1 - t**2 / self.sigma_max**2) 158 | mu_t = t**2 / self.sigma_max**2 * xT + (1 - t**2 / self.sigma_max**2) * x0 159 | samples = (mu_t + std_t * noise) 160 | elif self.pred_mode.startswith('vp'): 161 | logsnr_t = vp_logsnr(t, self.beta_d, self.beta_min) 162 | logsnr_T = vp_logsnr(self.sigma_max, self.beta_d, self.beta_min) 163 | logs_t = vp_logs(t, self.beta_d, self.beta_min) 164 | logs_T = vp_logs(self.sigma_max, self.beta_d, self.beta_min) 165 | 166 | a_t = (logsnr_T - logsnr_t + logs_t - logs_T).exp() 167 | b_t = -th.expm1(logsnr_T - logsnr_t) * logs_t.exp() 168 | std_t = (-th.expm1(logsnr_T - logsnr_t)).sqrt() * (logs_t - logsnr_t/2).exp() 169 | 170 | samples = a_t * xT + b_t * x0 + std_t * noise 171 | return samples 172 | 173 | x_t = bridge_sample(x_start, xT, sigmas) 174 | 175 | model_output, denoised = self.denoise(model, x_t, sigmas, global_cond, **model_kwargs) 176 | 177 | weights = self.get_weightings(sigmas) 178 | weights = append_dims((weights), dims) 179 | 180 | terms["xs_mse"] = mean_flat((denoised - x_start) ** 2) 181 | terms["mse"] = mean_flat(weights * (denoised - x_start) ** 2) 182 | if th.isnan(terms["mse"]).any(): 183 | import ipdb;ipdb.set_trace() 184 | 185 | if "vb" in terms: 186 | terms["loss"] = terms["mse"] + terms["vb"] 187 | else: 188 | terms["loss"] = terms["mse"] 189 | return terms, denoised 190 | 191 | def denoise(self, model, x_t, sigmas, global_cond=None, **model_kwargs): 192 | c_skip, c_out, c_in = [ 193 | append_dims(x, x_t.ndim) for x in self.get_bridge_scalings(sigmas) 194 | ] 195 | 196 | rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44) 197 | 198 | model_output = model("noise_pred_net", 199 | sample=c_in * x_t, 200 | timestep=rescaled_t, 201 | global_cond=global_cond, 202 | **model_kwargs) 203 | denoised = c_out * model_output + 
c_skip * x_t 204 | return model_output, denoised 205 | 206 | def karras_sample( 207 | diffusion, 208 | model, 209 | x_T, 210 | x_0, 211 | steps, 212 | clip_denoised=True, 213 | progress=False, 214 | callback=None, 215 | model_kwargs=None, 216 | global_cond=None, 217 | device=None, 218 | sigma_min=0.002, 219 | sigma_max=80, 220 | rho=7.0, 221 | sampler="heun", 222 | churn_step_ratio=0., 223 | guidance=1, 224 | ): 225 | assert sampler in ["heun", ], 'only heun sampler is supported currently' 226 | 227 | sigmas = get_sigmas_karras(steps, sigma_min, sigma_max-1e-4, rho, device=device) 228 | 229 | sample_fn = { 230 | "heun": partial(sample_heun, beta_d=diffusion.beta_d, beta_min=diffusion.beta_min), 231 | }[sampler] 232 | 233 | sampler_args = dict( 234 | pred_mode=diffusion.pred_mode, churn_step_ratio=churn_step_ratio, sigma_max=sigma_max 235 | ) 236 | 237 | def denoiser(x_t, sigma, x_T=None, global_cond=None): 238 | _, denoised = diffusion.denoise(model, x_t, sigma, global_cond) 239 | if clip_denoised: 240 | denoised = denoised.clamp(-1, 1) 241 | return denoised 242 | 243 | x_0, path, nfe = sample_fn( 244 | denoiser, 245 | x_T, 246 | sigmas, 247 | global_cond=global_cond, 248 | progress=progress, 249 | callback=callback, 250 | guidance=guidance, 251 | **sampler_args, 252 | ) 253 | print('nfe:', nfe) 254 | 255 | return x_0.clamp(-1, 1), [x.clamp(-1, 1) for x in path], nfe 256 | 257 | def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"): 258 | ramp = th.linspace(0, 1, n) 259 | min_inv_rho = sigma_min ** (1 / rho) 260 | max_inv_rho = sigma_max ** (1 / rho) 261 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 262 | return append_zero(sigmas).to(device) 263 | 264 | def get_bridge_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, eps=1e-4, device="cpu"): 265 | sigma_t_crit = sigma_max / np.sqrt(2) 266 | min_start_inv_rho = sigma_min ** (1 / rho) 267 | max_inv_rho = sigma_t_crit ** (1 / rho) 268 | 269 | sigmas_second_half = (max_inv_rho + th.linspace(0, 1, n//2 ) * (min_start_inv_rho - max_inv_rho)) ** rho 270 | 271 | sigmas_first_half = sigma_max - ((sigma_max - sigma_t_crit) ** (1 / rho) + th.linspace(0, 1, n - n//2 +1 ) * (eps ** (1 / rho) - (sigma_max - sigma_t_crit) ** (1 / rho))) ** rho 272 | sigmas = th.cat([sigmas_first_half.flip(0)[:-1], sigmas_second_half]) 273 | return append_zero(sigmas).to(device) 274 | 275 | def to_d(x, sigma, denoised, x_T, sigma_max, w=1, stochastic=False): 276 | grad_pxtlx0 = (denoised - x) / append_dims(sigma**2, x.ndim) 277 | grad_pxTlxt = (x_T - x) / (append_dims(th.ones_like(sigma)*sigma_max**2, x.ndim) - append_dims(sigma**2, x.ndim)) 278 | gt2 = 2 * sigma 279 | d = -(0.5 if not stochastic else 1) * gt2 * (grad_pxtlx0 - w * grad_pxTlxt * (0 if stochastic else 1)) 280 | if stochastic: 281 | return d, gt2 282 | else: 283 | return d 284 | 285 | def get_d_vp(x, denoised, x_T, std_t, logsnr_t, logsnr_T, logs_t, logs_T, s_t_deriv, sigma_t, sigma_t_deriv, w, stochastic=False): 286 | 287 | a_t = (logsnr_T - logsnr_t + logs_t - logs_T).exp() 288 | b_t = -th.expm1(logsnr_T - logsnr_t) * logs_t.exp() 289 | 290 | mu_t = a_t * x_T + b_t * denoised 291 | 292 | grad_logq = -(x - mu_t) / std_t**2 / (-th.expm1(logsnr_T - logsnr_t)) 293 | 294 | grad_logpxTlxt = -(x - th.exp(logs_t - logs_T) * x_T) / std_t**2 / th.expm1(logsnr_t - logsnr_T) 295 | 296 | f = s_t_deriv * (-logs_t).exp() * x 297 | gt2 = 2 * (logs_t).exp()**2 * sigma_t * sigma_t_deriv 298 | 299 | d = f - gt2 * ((0.5 if not stochastic else 1) * grad_logq - w * grad_logpxTlxt) 300 | if 
stochastic: 301 | return d, gt2 302 | else: 303 | return d 304 | 305 | @th.no_grad() 306 | def sample_heun( 307 | denoiser, 308 | x, 309 | sigmas, 310 | global_cond=None, 311 | pred_mode='both', 312 | progress=False, 313 | callback=None, 314 | sigma_max=80.0, 315 | beta_d=2, 316 | beta_min=0.1, 317 | churn_step_ratio=0., 318 | guidance=1, 319 | ): 320 | x_T = x 321 | path = [x] 322 | 323 | s_in = x.new_ones([x.shape[0]]) 324 | indices = range(len(sigmas) - 1) 325 | 326 | if progress: 327 | from tqdm.auto import tqdm 328 | indices = tqdm(indices) 329 | 330 | nfe = 0 331 | assert churn_step_ratio < 1 332 | 333 | if pred_mode.startswith('vp'): 334 | vp_snr_sqrt_reciprocal = lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 335 | vp_snr_sqrt_reciprocal_deriv = lambda t: 0.5 * (beta_min + beta_d * t) * (vp_snr_sqrt_reciprocal(t) + 1 / vp_snr_sqrt_reciprocal(t)) 336 | s = lambda t: (1 + vp_snr_sqrt_reciprocal(t) ** 2).rsqrt() 337 | s_deriv = lambda t: -vp_snr_sqrt_reciprocal(t) * vp_snr_sqrt_reciprocal_deriv(t) * (s(t) ** 3) 338 | logs = lambda t: -0.25 * t ** 2 * (beta_d) - 0.5 * t * beta_min 339 | std = lambda t: vp_snr_sqrt_reciprocal(t) * s(t) 340 | logsnr = lambda t : -2 * th.log(vp_snr_sqrt_reciprocal(t)) 341 | 342 | logsnr_T = logsnr(th.as_tensor(sigma_max)) 343 | logs_T = logs(th.as_tensor(sigma_max)) 344 | 345 | for j, i in enumerate(indices): 346 | if churn_step_ratio > 0: 347 | sigma_hat = (sigmas[i+1] - sigmas[i]) * churn_step_ratio + sigmas[i] 348 | 349 | denoised = denoiser(x, sigmas[i] * s_in, x_T, global_cond) 350 | if pred_mode == 've': 351 | d_1, gt2 = to_d(x, sigmas[i], denoised, x_T, sigma_max, w=guidance, stochastic=True) 352 | elif pred_mode.startswith('vp'): 353 | d_1, gt2 = get_d_vp(x, denoised, x_T, std(sigmas[i]), logsnr(sigmas[i]), logsnr_T, logs(sigmas[i]), logs_T, s_deriv(sigmas[i]), vp_snr_sqrt_reciprocal(sigmas[i]), vp_snr_sqrt_reciprocal_deriv(sigmas[i]), guidance, stochastic=True) 354 | 355 | dt = (sigma_hat - sigmas[i]) 356 | x = x + d_1 * dt + th.randn_like(x) * (dt.abs() ** 0.5) * gt2.sqrt() 357 | 358 | nfe += 1 359 | path.append(x.detach().cpu()) 360 | else: 361 | sigma_hat = sigmas[i] 362 | 363 | denoised = denoiser(x, sigma_hat * s_in, x_T, global_cond) 364 | if pred_mode == 've': 365 | d = to_d(x, sigma_hat, denoised, x_T, sigma_max, w=guidance) 366 | elif pred_mode.startswith('vp'): 367 | d = get_d_vp(x, denoised, x_T, std(sigma_hat), logsnr(sigma_hat), logsnr_T, logs(sigma_hat), logs_T, s_deriv(sigma_hat), vp_snr_sqrt_reciprocal(sigma_hat), vp_snr_sqrt_reciprocal_deriv(sigma_hat), guidance) 368 | 369 | nfe += 1 370 | if callback is not None: 371 | callback( 372 | { 373 | "x": x, 374 | "i": i, 375 | "sigma": sigmas[i], 376 | "sigma_hat": sigma_hat, 377 | "denoised": denoised, 378 | } 379 | ) 380 | 381 | dt = sigmas[i + 1] - sigma_hat 382 | 383 | if sigmas[i + 1] == 0: 384 | x = x + d * dt 385 | else: 386 | x_2 = x + d * dt 387 | denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in, x_T, global_cond) 388 | if pred_mode == 've': 389 | d_2 = to_d(x_2, sigmas[i + 1], denoised_2, x_T, sigma_max, w=guidance) 390 | elif pred_mode.startswith('vp'): 391 | d_2 = get_d_vp(x_2, denoised_2, x_T, std(sigmas[i + 1]), logsnr(sigmas[i + 1]), logsnr_T, logs(sigmas[i + 1]), logs_T, s_deriv(sigmas[i + 1]), vp_snr_sqrt_reciprocal(sigmas[i + 1]), vp_snr_sqrt_reciprocal_deriv(sigmas[i + 1]), guidance) 392 | 393 | d_prime = (d + d_2) / 2 394 | x = x + d_prime * dt 395 | nfe += 1 396 | 397 | path.append(x.detach().cpu()) 398 | 399 | return x, path, nfe 400 | 401 
| @th.no_grad() 402 | def forward_sample( 403 | x0, 404 | y0, 405 | sigma_max, 406 | ): 407 | ts = th.linspace(0, sigma_max, 120) 408 | x = x0 409 | path = [x] 410 | 411 | for t in ts: 412 | std_t = th.sqrt(t) * th.sqrt(1 - t / sigma_max) 413 | mu_t = t / sigma_max * y0 + (1 - t / sigma_max) * x0 414 | xt = mu_t + std_t * th.randn_like(x0) 415 | path.append(xt) 416 | 417 | path.append(y0) 418 | 419 | return path 420 | -------------------------------------------------------------------------------- /train/vint_train/data/vint_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle 4 | import yaml 5 | from typing import Any, Dict, List, Optional, Tuple 6 | import tqdm 7 | import io 8 | import lmdb 9 | 10 | import torch 11 | from torch.utils.data import Dataset 12 | import torchvision.transforms.functional as TF 13 | 14 | from vint_train.data.data_utils import ( 15 | img_path_to_data, 16 | calculate_sin_cos, 17 | get_data_path, 18 | to_local_coords, 19 | ) 20 | 21 | class ViNT_Dataset(Dataset): 22 | def __init__( 23 | self, 24 | data_folder: str, 25 | data_split_folder: str, 26 | dataset_name: str, 27 | image_size: Tuple[int, int], 28 | waypoint_spacing: int, 29 | min_dist_cat: int, 30 | max_dist_cat: int, 31 | min_action_distance: int, 32 | max_action_distance: int, 33 | negative_mining: bool, 34 | len_traj_pred: int, 35 | learn_angle: bool, 36 | context_size: int, 37 | context_type: str = "temporal", 38 | end_slack: int = 0, 39 | goals_per_obs: int = 1, 40 | normalize: bool = True, 41 | obs_type: str = "image", 42 | goal_type: str = "image", 43 | angle_ranges: list = None, 44 | ): 45 | """ 46 | Main ViNT dataset class 47 | 48 | Args: 49 | data_folder (string): Directory with all the image data 50 | data_split_folder (string): Directory with filepaths.txt, a list of all trajectory names in the dataset split that are each seperated by a newline 51 | dataset_name (string): Name of the dataset [recon, go_stanford, scand, tartandrive, etc.] 52 | waypoint_spacing (int): Spacing between waypoints 53 | min_dist_cat (int): Minimum distance category to use 54 | max_dist_cat (int): Maximum distance category to use 55 | negative_mining (bool): Whether to use negative mining from the ViNG paper (Shah et al.) (https://arxiv.org/abs/2012.09812) 56 | len_traj_pred (int): Length of trajectory of waypoints to predict if this is an action dataset 57 | learn_angle (bool): Whether to learn the yaw of the robot at each predicted waypoint if this is an action dataset 58 | context_size (int): Number of previous observations to use as context 59 | context_type (str): Whether to use temporal, randomized, or randomized temporal context 60 | end_slack (int): Number of timesteps to ignore at the end of the trajectory 61 | goals_per_obs (int): Number of goals to sample per observation 62 | normalize (bool): Whether to normalize the distances or actions 63 | goal_type (str): What data type to use for the goal. The only one supported is "image" for now. 
64 | """ 65 | self.data_folder = data_folder 66 | self.data_split_folder = data_split_folder 67 | self.dataset_name = dataset_name 68 | 69 | traj_names_file = os.path.join(data_split_folder, "traj_names.txt") 70 | with open(traj_names_file, "r") as f: 71 | file_lines = f.read() 72 | self.traj_names = file_lines.split("\n") 73 | if "" in self.traj_names: 74 | self.traj_names.remove("") 75 | 76 | self.image_size = image_size 77 | self.waypoint_spacing = waypoint_spacing 78 | self.distance_categories = list( 79 | range(min_dist_cat, max_dist_cat + 1, self.waypoint_spacing) 80 | ) 81 | self.min_dist_cat = self.distance_categories[0] 82 | self.max_dist_cat = self.distance_categories[-1] 83 | self.negative_mining = negative_mining 84 | if self.negative_mining: 85 | self.distance_categories.append(-1) 86 | self.len_traj_pred = len_traj_pred 87 | self.learn_angle = learn_angle 88 | 89 | self.min_action_distance = min_action_distance 90 | self.max_action_distance = max_action_distance 91 | 92 | self.context_size = context_size 93 | assert context_type in { 94 | "temporal", 95 | "randomized", 96 | "randomized_temporal", 97 | }, "context_type must be one of temporal, randomized, randomized_temporal" 98 | self.context_type = context_type 99 | self.end_slack = end_slack 100 | self.goals_per_obs = goals_per_obs 101 | self.normalize = normalize 102 | self.obs_type = obs_type 103 | self.goal_type = goal_type 104 | 105 | # load data/data_config.yaml 106 | with open( 107 | os.path.join(os.path.dirname(__file__), "data_config.yaml"), "r" 108 | ) as f: 109 | all_data_config = yaml.safe_load(f) 110 | assert ( 111 | self.dataset_name in all_data_config 112 | ), f"Dataset {self.dataset_name} not found in data_config.yaml" 113 | dataset_names = list(all_data_config.keys()) 114 | dataset_names.sort() 115 | # use this index to retrieve the dataset name from the data_config.yaml 116 | self.dataset_index = dataset_names.index(self.dataset_name) 117 | self.data_config = all_data_config[self.dataset_name] 118 | self.trajectory_cache = {} 119 | self._load_index() 120 | self._build_caches() 121 | 122 | if self.learn_angle: 123 | self.num_action_params = 3 124 | else: 125 | self.num_action_params = 2 126 | 127 | if angle_ranges is None: 128 | self._set_angle_ranges() 129 | else: 130 | self.angle_ranges = angle_ranges 131 | 132 | def _set_angle_ranges(self): 133 | self.angle_ranges = [(0, 67.5), 134 | (67.5, 112.5), 135 | (112.5, 180), 136 | (180, 270), 137 | (270, 360)] 138 | 139 | def __getstate__(self): 140 | state = self.__dict__.copy() 141 | state["_image_cache"] = None 142 | return state 143 | 144 | def __setstate__(self, state): 145 | self.__dict__ = state 146 | self._build_caches() 147 | 148 | def _build_caches(self, use_tqdm: bool = True): 149 | """ 150 | Build a cache of images for faster loading using LMDB 151 | """ 152 | cache_filename = os.path.join( 153 | self.data_split_folder, 154 | f"dataset_{self.dataset_name}.lmdb", 155 | ) 156 | 157 | # Load all the trajectories into memory. These should already be loaded, but just in case. 
158 | for traj_name in self.traj_names: 159 | self._get_trajectory(traj_name) 160 | 161 | """ 162 | If the cache file doesn't exist, create it by iterating through the dataset and writing each image to the cache 163 | """ 164 | if not os.path.exists(cache_filename): 165 | tqdm_iterator = tqdm.tqdm( 166 | self.goals_index, 167 | disable=not use_tqdm, 168 | dynamic_ncols=True, 169 | desc=f"Building LMDB cache for {self.dataset_name}" 170 | ) 171 | with lmdb.open(cache_filename, map_size=2**40) as image_cache: 172 | with image_cache.begin(write=True) as txn: 173 | for traj_name, time in tqdm_iterator: 174 | image_path = get_data_path(self.data_folder, traj_name, time) 175 | with open(image_path, "rb") as f: 176 | txn.put(image_path.encode(), f.read()) 177 | 178 | # Reopen the cache file in read-only mode 179 | self._image_cache: lmdb.Environment = lmdb.open(cache_filename, readonly=True) 180 | 181 | def _build_index(self, use_tqdm: bool = False): 182 | """ 183 | Build an index consisting of tuples (trajectory name, time, max goal distance) 184 | """ 185 | samples_index = [] 186 | goals_index = [] 187 | 188 | for traj_name in tqdm.tqdm(self.traj_names, disable=not use_tqdm, dynamic_ncols=True): 189 | traj_data = self._get_trajectory(traj_name) 190 | traj_len = len(traj_data["position"]) 191 | 192 | for goal_time in range(0, traj_len): 193 | goals_index.append((traj_name, goal_time)) 194 | 195 | begin_time = self.context_size * self.waypoint_spacing 196 | end_time = traj_len - self.end_slack - self.len_traj_pred * self.waypoint_spacing 197 | for curr_time in range(begin_time, end_time): 198 | max_goal_distance = min(self.max_dist_cat * self.waypoint_spacing, traj_len - curr_time - 1) 199 | samples_index.append((traj_name, curr_time, max_goal_distance)) 200 | 201 | return samples_index, goals_index 202 | 203 | def _sample_goal(self, trajectory_name, curr_time, max_goal_dist): 204 | """ 205 | Sample a goal from the future in the same trajectory. 206 | Returns: (trajectory_name, goal_time, goal_is_negative) 207 | """ 208 | goal_offset = np.random.randint(0, max_goal_dist + 1) 209 | if goal_offset == 0: 210 | trajectory_name, goal_time = self._sample_negative() 211 | return trajectory_name, goal_time, True 212 | else: 213 | goal_time = curr_time + int(goal_offset * self.waypoint_spacing) 214 | return trajectory_name, goal_time, False 215 | 216 | def _sample_negative(self): 217 | """ 218 | Sample a goal from a (likely) different trajectory. 
219 | """ 220 | return self.goals_index[np.random.randint(0, len(self.goals_index))] 221 | 222 | def _load_index(self) -> None: 223 | """ 224 | Generates a list of tuples of (obs_traj_name, goal_traj_name, obs_time, goal_time) for each observation in the dataset 225 | """ 226 | index_to_data_path = os.path.join( 227 | self.data_split_folder, 228 | f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_context_{self.context_type}_n{self.context_size}_slack_{self.end_slack}.pkl", 229 | ) 230 | try: 231 | # load the index_to_data if it already exists (to save time) 232 | with open(index_to_data_path, "rb") as f: 233 | self.index_to_data, self.goals_index = pickle.load(f) 234 | except: 235 | # if the index_to_data file doesn't exist, create it 236 | self.index_to_data, self.goals_index = self._build_index() 237 | with open(index_to_data_path, "wb") as f: 238 | pickle.dump((self.index_to_data, self.goals_index), f) 239 | 240 | def _load_image(self, trajectory_name, time): 241 | image_path = get_data_path(self.data_folder, trajectory_name, time) 242 | # return img_path_to_data(image_path, self.image_size) 243 | try: 244 | with self._image_cache.begin() as txn: 245 | image_buffer = txn.get(image_path.encode()) 246 | image_bytes = bytes(image_buffer) 247 | image_bytes = io.BytesIO(image_bytes) 248 | return img_path_to_data(image_bytes, self.image_size) 249 | except TypeError: 250 | print(f"Failed to load image {image_path}") 251 | 252 | def _theta2category(self, theta): 253 | for i, (min_angle, max_angle) in enumerate(self.angle_ranges): 254 | if min_angle <= theta < max_angle: 255 | return i 256 | return i 257 | 258 | def _calculate_angle(self, waypoints): 259 | x_end, y_end = waypoints[-1] 260 | 261 | angle_rad = np.arctan2(y_end, x_end) 262 | 263 | angle_deg = np.degrees(angle_rad) 264 | 265 | if angle_deg < 0: 266 | angle_deg += 360 267 | 268 | return angle_deg 269 | 270 | def _compute_actions(self, traj_data, curr_time, goal_time): 271 | start_index = curr_time 272 | end_index = curr_time + self.len_traj_pred * self.waypoint_spacing + 1 273 | 274 | yaw = traj_data["yaw"][start_index:end_index:self.waypoint_spacing] 275 | positions = traj_data["position"][start_index:end_index:self.waypoint_spacing] 276 | 277 | goal_pos = traj_data["position"][min(goal_time, len(traj_data["position"]) - 1)] 278 | 279 | if len(yaw.shape) == 2: 280 | yaw = yaw.squeeze(1) 281 | 282 | if yaw.shape != (self.len_traj_pred + 1,): 283 | const_len = self.len_traj_pred + 1 - yaw.shape[0] 284 | yaw = np.concatenate([yaw, np.repeat(yaw[-1], const_len)]) 285 | positions = np.concatenate([positions, np.repeat(positions[-1][None], const_len, axis=0)], axis=0) 286 | 287 | assert yaw.shape == (self.len_traj_pred + 1,), f"{yaw.shape} and {(self.len_traj_pred + 1,)} should be equal" 288 | assert positions.shape == (self.len_traj_pred + 1, 2), f"{positions.shape} and {(self.len_traj_pred + 1, 2)} should be equal" 289 | 290 | waypoints = to_local_coords(positions, positions[0], yaw[0]) 291 | goal_pos = to_local_coords(goal_pos, positions[0], yaw[0]) 292 | 293 | assert waypoints.shape == (self.len_traj_pred + 1, 2), f"{waypoints.shape} and {(self.len_traj_pred + 1, 2)} should be equal" 294 | 295 | if self.learn_angle: 296 | yaw = yaw[1:] - yaw[0] 297 | actions = np.concatenate([waypoints[1:], yaw[:, None]], axis=-1) 298 | else: 299 | actions = waypoints[1:] 300 | 301 | if self.normalize: 302 | actions[:, :2] /= self.data_config["metric_waypoint_spacing"] * self.waypoint_spacing 303 | goal_pos /= 
self.data_config["metric_waypoint_spacing"] * self.waypoint_spacing 304 | 305 | assert actions.shape == (self.len_traj_pred, self.num_action_params), f"{actions.shape} and {(self.len_traj_pred, self.num_action_params)} should be equal" 306 | 307 | theta = self._calculate_angle(waypoints) 308 | action_category = self._theta2category(theta) 309 | 310 | return actions, goal_pos, action_category 311 | 312 | def _get_trajectory(self, trajectory_name): 313 | if trajectory_name in self.trajectory_cache: 314 | return self.trajectory_cache[trajectory_name] 315 | else: 316 | with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f: 317 | traj_data = pickle.load(f) 318 | self.trajectory_cache[trajectory_name] = traj_data 319 | return traj_data 320 | 321 | def __len__(self) -> int: 322 | return len(self.index_to_data) 323 | 324 | def __getitem__(self, i: int) -> Tuple[torch.Tensor]: 325 | """ 326 | Args: 327 | i (int): index to ith datapoint 328 | Returns: 329 | Tuple of tensors containing the context, observation, goal, transformed context, transformed observation, transformed goal, distance label, and action label 330 | obs_image (torch.Tensor): tensor of shape [3, H, W] containing the image of the robot's observation 331 | goal_image (torch.Tensor): tensor of shape [3, H, W] containing the subgoal image 332 | dist_label (torch.Tensor): tensor of shape (1,) containing the distance labels from the observation to the goal 333 | action_label (torch.Tensor): tensor of shape (5, 2) or (5, 4) (if training with angle) containing the action labels from the observation to the goal 334 | which_dataset (torch.Tensor): index of the datapoint in the dataset [for identifying the dataset for visualization when using multiple datasets] 335 | """ 336 | f_curr, curr_time, max_goal_dist = self.index_to_data[i] 337 | f_goal, goal_time, goal_is_negative = self._sample_goal(f_curr, curr_time, max_goal_dist) 338 | 339 | # Load images 340 | context = [] 341 | if self.context_type == "temporal": 342 | # sample the last self.context_size times from interval [0, curr_time) 343 | context_times = list( 344 | range( 345 | curr_time + -self.context_size * self.waypoint_spacing, 346 | curr_time + 1, 347 | self.waypoint_spacing, 348 | ) 349 | ) 350 | context = [(f_curr, t) for t in context_times] 351 | else: 352 | raise ValueError(f"Invalid context type {self.context_type}") 353 | 354 | obs_image = torch.cat([ 355 | self._load_image(f, t) for f, t in context 356 | ]) 357 | 358 | # Load goal image 359 | goal_image = self._load_image(f_goal, goal_time) 360 | 361 | # Load other trajectory data 362 | curr_traj_data = self._get_trajectory(f_curr) 363 | curr_traj_len = len(curr_traj_data["position"]) 364 | assert curr_time < curr_traj_len, f"{curr_time} and {curr_traj_len}" 365 | 366 | goal_traj_data = self._get_trajectory(f_goal) 367 | goal_traj_len = len(goal_traj_data["position"]) 368 | assert goal_time < goal_traj_len, f"{goal_time} an {goal_traj_len}" 369 | 370 | # Compute actions 371 | actions, goal_pos, action_category = self._compute_actions(curr_traj_data, curr_time, goal_time) 372 | 373 | if actions.dtype is not np.float32: 374 | actions = np.array(actions, dtype=np.float32) 375 | if goal_pos.dtype is not np.float32: 376 | goal_pos = np.array(goal_pos, dtype=np.float32) 377 | 378 | # Compute distances 379 | if goal_is_negative: 380 | distance = self.max_dist_cat 381 | else: 382 | distance = (goal_time - curr_time) // self.waypoint_spacing 383 | assert (goal_time - curr_time) % self.waypoint_spacing 
== 0, f"{goal_time} and {curr_time} should be separated by an integer multiple of {self.waypoint_spacing}" 384 | 385 | actions_torch = torch.as_tensor(actions, dtype=torch.float32) 386 | if self.learn_angle: 387 | actions_torch = calculate_sin_cos(actions_torch) 388 | 389 | action_mask = ( 390 | (distance < self.max_action_distance) and 391 | (distance > self.min_action_distance) and 392 | (not goal_is_negative) 393 | ) 394 | 395 | return ( 396 | torch.as_tensor(obs_image, dtype=torch.float32), 397 | torch.as_tensor(goal_image, dtype=torch.float32), 398 | actions_torch, 399 | torch.as_tensor(distance, dtype=torch.int64), 400 | torch.as_tensor(goal_pos, dtype=torch.float32), 401 | torch.as_tensor(self.dataset_index, dtype=torch.int64), 402 | torch.as_tensor(action_mask, dtype=torch.float32), 403 | torch.as_tensor(action_category, dtype=torch.long) 404 | ) 405 | 406 | def sample_prior(): 407 | pass -------------------------------------------------------------------------------- /train/vint_train/visualizing/action_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | from typing import Optional, List 6 | import wandb 7 | import yaml 8 | import torch 9 | import torch.nn as nn 10 | from vint_train.visualizing.visualize_utils import ( 11 | to_numpy, 12 | numpy_to_img, 13 | VIZ_IMAGE_SIZE, 14 | RED, 15 | GREEN, 16 | BLUE, 17 | CYAN, 18 | YELLOW, 19 | MAGENTA, 20 | ) 21 | 22 | # load data_config.yaml 23 | with open(os.path.join(os.path.dirname(__file__), "../data/data_config.yaml"), "r") as f: 24 | data_config = yaml.safe_load(f) 25 | 26 | 27 | def visualize_traj_pred( 28 | batch_obs_images: np.ndarray, 29 | batch_goal_images: np.ndarray, 30 | dataset_indices: np.ndarray, 31 | batch_goals: np.ndarray, 32 | batch_pred_waypoints: np.ndarray, 33 | batch_label_waypoints: np.ndarray, 34 | eval_type: str, 35 | normalized: bool, 36 | save_folder: str, 37 | epoch: int, 38 | num_images_preds: int = 8, 39 | use_wandb: bool = True, 40 | display: bool = False, 41 | ): 42 | """ 43 | Compare predicted path with the gt path of waypoints using egocentric visualization. This visualization is for the last batch in the dataset. 44 | 45 | Args: 46 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels] 47 | batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels] 48 | dataset_names: indices corresponding to the dataset name 49 | batch_goals (np.ndarray): batch of goal positions [batch_size, 2] 50 | batch_pred_waypoints (np.ndarray): batch of predicted waypoints [batch_size, horizon, 4] or [batch_size, horizon, 2] or [batch_size, num_trajs_sampled horizon, {2 or 4}] 51 | batch_label_waypoints (np.ndarray): batch of label waypoints [batch_size, T, 4] or [batch_size, horizon, 2] 52 | eval_type (string): f"{data_type}_{eval_type}" (e.g. "recon_train", "gs_test", etc.) 53 | normalized (bool): whether the waypoints are normalized 54 | save_folder (str): folder to save the images. 
If None, will not save the images 55 | epoch (int): current epoch number 56 | num_images_preds (int): number of images to visualize 57 | use_wandb (bool): whether to use wandb to log the images 58 | display (bool): whether to display the images 59 | """ 60 | visualize_path = None 61 | if save_folder is not None: 62 | visualize_path = os.path.join( 63 | save_folder, "visualize", eval_type, f"epoch{epoch}", "action_prediction" 64 | ) 65 | 66 | if not os.path.exists(visualize_path): 67 | os.makedirs(visualize_path) 68 | 69 | assert ( 70 | len(batch_obs_images) 71 | == len(batch_goal_images) 72 | == len(batch_goals) 73 | == len(batch_pred_waypoints) 74 | == len(batch_label_waypoints) 75 | ) 76 | 77 | dataset_names = list(data_config.keys()) 78 | dataset_names.sort() 79 | 80 | batch_size = batch_obs_images.shape[0] 81 | wandb_list = [] 82 | for i in range(min(batch_size, num_images_preds)): 83 | obs_img = numpy_to_img(batch_obs_images[i]) 84 | goal_img = numpy_to_img(batch_goal_images[i]) 85 | dataset_name = dataset_names[int(dataset_indices[i])] 86 | goal_pos = batch_goals[i] 87 | pred_waypoints = batch_pred_waypoints[i] 88 | label_waypoints = batch_label_waypoints[i] 89 | 90 | if normalized: 91 | pred_waypoints *= data_config[dataset_name]["metric_waypoint_spacing"] 92 | label_waypoints *= data_config[dataset_name]["metric_waypoint_spacing"] 93 | goal_pos *= data_config[dataset_name]["metric_waypoint_spacing"] 94 | 95 | save_path = None 96 | if visualize_path is not None: 97 | save_path = os.path.join(visualize_path, f"{str(i).zfill(4)}.png") 98 | 99 | compare_waypoints_pred_to_label( 100 | obs_img, 101 | goal_img, 102 | dataset_name, 103 | goal_pos, 104 | pred_waypoints, 105 | label_waypoints, 106 | save_path, 107 | display, 108 | ) 109 | if use_wandb: 110 | wandb_list.append(wandb.Image(save_path)) 111 | if use_wandb: 112 | wandb.log({f"{eval_type}_action_prediction": wandb_list}, commit=False) 113 | 114 | 115 | def compare_waypoints_pred_to_label( 116 | obs_img, 117 | goal_img, 118 | dataset_name: str, 119 | goal_pos: np.ndarray, 120 | pred_waypoints: np.ndarray, 121 | label_waypoints: np.ndarray, 122 | save_path: Optional[str] = None, 123 | display: Optional[bool] = False, 124 | ): 125 | """ 126 | Compare predicted path with the gt path of waypoints using egocentric visualization. 127 | 128 | Args: 129 | obs_img: image of the observation 130 | goal_img: image of the goal 131 | dataset_name: name of the dataset found in data_config.yaml (e.g. 
"recon") 132 | goal_pos: goal position in the image 133 | pred_waypoints: predicted waypoints in the image 134 | label_waypoints: label waypoints in the image 135 | save_path: path to save the figure 136 | display: whether to display the figure 137 | """ 138 | 139 | fig, ax = plt.subplots(1, 3) 140 | start_pos = np.array([0, 0]) 141 | if len(pred_waypoints.shape) > 2: 142 | trajs = [*pred_waypoints, label_waypoints] 143 | else: 144 | trajs = [pred_waypoints, label_waypoints] 145 | plot_trajs_and_points( 146 | ax[0], 147 | trajs, 148 | [start_pos, goal_pos], 149 | traj_colors=[CYAN, MAGENTA], 150 | point_colors=[GREEN, RED], 151 | ) 152 | plot_trajs_and_points_on_image( 153 | ax[1], 154 | obs_img, 155 | dataset_name, 156 | trajs, 157 | [start_pos, goal_pos], 158 | traj_colors=[CYAN, MAGENTA], 159 | point_colors=[GREEN, RED], 160 | ) 161 | ax[2].imshow(goal_img) 162 | 163 | fig.set_size_inches(18.5, 10.5) 164 | ax[0].set_title(f"Action Prediction") 165 | ax[1].set_title(f"Observation") 166 | ax[2].set_title(f"Goal") 167 | 168 | if save_path is not None: 169 | fig.savefig( 170 | save_path, 171 | bbox_inches="tight", 172 | ) 173 | 174 | if not display: 175 | plt.close(fig) 176 | 177 | 178 | def plot_trajs_and_points_on_image( 179 | ax: plt.Axes, 180 | img: np.ndarray, 181 | dataset_name: str, 182 | list_trajs: list, 183 | list_points: list, 184 | traj_colors: list = [CYAN, MAGENTA], 185 | point_colors: list = [RED, GREEN], 186 | ): 187 | """ 188 | Plot trajectories and points on an image. If there is no configuration for the camera interinstics of the dataset, the image will be plotted as is. 189 | Args: 190 | ax: matplotlib axis 191 | img: image to plot 192 | dataset_name: name of the dataset found in data_config.yaml (e.g. "recon") 193 | list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw) 194 | list_points: list of points, each point is a numpy array of shape (2,) 195 | traj_colors: list of colors for trajectories 196 | point_colors: list of colors for points 197 | """ 198 | assert len(list_trajs) <= len(traj_colors), "Not enough colors for trajectories" 199 | assert len(list_points) <= len(point_colors), "Not enough colors for points" 200 | assert ( 201 | dataset_name in data_config 202 | ), f"Dataset {dataset_name} not found in data/data_config.yaml" 203 | 204 | ax.imshow(img) 205 | if ( 206 | "camera_metrics" in data_config[dataset_name] 207 | and "camera_height" in data_config[dataset_name]["camera_metrics"] 208 | and "camera_matrix" in data_config[dataset_name]["camera_metrics"] 209 | and "dist_coeffs" in data_config[dataset_name]["camera_metrics"] 210 | ): 211 | camera_height = data_config[dataset_name]["camera_metrics"]["camera_height"] 212 | camera_x_offset = data_config[dataset_name]["camera_metrics"]["camera_x_offset"] 213 | 214 | fx = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["fx"] 215 | fy = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["fy"] 216 | cx = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["cx"] 217 | cy = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["cy"] 218 | camera_matrix = gen_camera_matrix(fx, fy, cx, cy) 219 | 220 | k1 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k1"] 221 | k2 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k2"] 222 | p1 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["p1"] 223 | p2 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["p2"] 
224 | k3 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k3"] 225 | dist_coeffs = np.array([k1, k2, p1, p2, k3, 0.0, 0.0, 0.0]) 226 | 227 | for i, traj in enumerate(list_trajs): 228 | xy_coords = traj[:, :2] # (horizon, 2) 229 | traj_pixels = get_pos_pixels( 230 | xy_coords, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=False 231 | ) 232 | if len(traj_pixels.shape) == 2: 233 | ax.plot( 234 | traj_pixels[:250, 0], 235 | traj_pixels[:250, 1], 236 | color=traj_colors[i], 237 | lw=2.5, 238 | ) 239 | 240 | for i, point in enumerate(list_points): 241 | if len(point.shape) == 1: 242 | # add a dimension to the front of point 243 | point = point[None, :2] 244 | else: 245 | point = point[:, :2] 246 | pt_pixels = get_pos_pixels( 247 | point, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=True 248 | ) 249 | ax.plot( 250 | pt_pixels[:250, 0], 251 | pt_pixels[:250, 1], 252 | color=point_colors[i], 253 | marker="o", 254 | markersize=10.0, 255 | ) 256 | ax.xaxis.set_visible(False) 257 | ax.yaxis.set_visible(False) 258 | ax.set_xlim((0.5, VIZ_IMAGE_SIZE[0] - 0.5)) 259 | ax.set_ylim((VIZ_IMAGE_SIZE[1] - 0.5, 0.5)) 260 | 261 | 262 | def plot_trajs_and_points( 263 | ax: plt.Axes, 264 | list_trajs: list, 265 | list_points: list, 266 | traj_colors: list = [CYAN, MAGENTA], 267 | point_colors: list = [RED, GREEN], 268 | traj_labels: Optional[list] = ["prediction", "ground truth"], 269 | point_labels: Optional[list] = ["robot", "goal"], 270 | traj_alphas: Optional[list] = None, 271 | point_alphas: Optional[list] = None, 272 | quiver_freq: int = 1, 273 | default_coloring: bool = True, 274 | ): 275 | """ 276 | Plot trajectories and points that could potentially have a yaw. 277 | 278 | Args: 279 | ax: matplotlib axis 280 | list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw) 281 | list_points: list of points, each point is a numpy array of shape (2,) 282 | traj_colors: list of colors for trajectories 283 | point_colors: list of colors for points 284 | traj_labels: list of labels for trajectories 285 | point_labels: list of labels for points 286 | traj_alphas: list of alphas for trajectories 287 | point_alphas: list of alphas for points 288 | quiver_freq: frequency of quiver plot (if the trajectory data includes the yaw of the robot) 289 | """ 290 | assert ( 291 | len(list_trajs) <= len(traj_colors) or default_coloring 292 | ), "Not enough colors for trajectories" 293 | assert len(list_points) <= len(point_colors), "Not enough colors for points" 294 | assert ( 295 | traj_labels is None or len(list_trajs) == len(traj_labels) or default_coloring 296 | ), "Not enough labels for trajectories" 297 | assert point_labels is None or len(list_points) == len(point_labels), "Not enough labels for points" 298 | 299 | for i, traj in enumerate(list_trajs): 300 | if traj_labels is None: 301 | ax.plot( 302 | traj[:, 0], 303 | traj[:, 1], 304 | color=traj_colors[i], 305 | alpha=traj_alphas[i] if traj_alphas is not None else 1.0, 306 | marker="o", 307 | ) 308 | else: 309 | ax.plot( 310 | traj[:, 0], 311 | traj[:, 1], 312 | color=traj_colors[i], 313 | label=traj_labels[i], 314 | alpha=traj_alphas[i] if traj_alphas is not None else 1.0, 315 | marker="o", 316 | ) 317 | if traj.shape[1] > 2 and quiver_freq > 0: # traj data also includes yaw of the robot 318 | bearings = gen_bearings_from_waypoints(traj) 319 | ax.quiver( 320 | traj[::quiver_freq, 0], 321 | traj[::quiver_freq, 1], 322 | 
bearings[::quiver_freq, 0], 323 | bearings[::quiver_freq, 1], 324 | color=traj_colors[i] * 0.5, 325 | scale=1.0, 326 | ) 327 | for i, pt in enumerate(list_points): 328 | if point_labels is None: 329 | ax.plot( 330 | pt[0], 331 | pt[1], 332 | color=point_colors[i], 333 | alpha=point_alphas[i] if point_alphas is not None else 1.0, 334 | marker="o", 335 | markersize=7.0 336 | ) 337 | else: 338 | ax.plot( 339 | pt[0], 340 | pt[1], 341 | color=point_colors[i], 342 | alpha=point_alphas[i] if point_alphas is not None else 1.0, 343 | marker="o", 344 | markersize=7.0, 345 | label=point_labels[i], 346 | ) 347 | 348 | 349 | # put the legend below the plot 350 | if traj_labels is not None or point_labels is not None: 351 | ax.legend() 352 | ax.legend(bbox_to_anchor=(0.0, -0.5), loc="upper left", ncol=2) 353 | ax.set_aspect("equal", "box") 354 | 355 | 356 | def angle_to_unit_vector(theta): 357 | """Converts an angle to a unit vector.""" 358 | return np.array([np.cos(theta), np.sin(theta)]) 359 | 360 | 361 | def gen_bearings_from_waypoints( 362 | waypoints: np.ndarray, 363 | mag=0.2, 364 | ) -> np.ndarray: 365 | """Generate bearings from waypoints, (x, y, sin(theta), cos(theta)).""" 366 | bearing = [] 367 | for i in range(0, len(waypoints)): 368 | if waypoints.shape[1] > 3: # label is sin/cos repr 369 | v = waypoints[i, 2:] 370 | # normalize v 371 | v = v / np.linalg.norm(v) 372 | v = v * mag 373 | else: # label is radians repr 374 | v = mag * angle_to_unit_vector(waypoints[i, 2]) 375 | bearing.append(v) 376 | bearing = np.array(bearing) 377 | return bearing 378 | 379 | 380 | def project_points( 381 | xy: np.ndarray, 382 | camera_height: float, 383 | camera_x_offset: float, 384 | camera_matrix: np.ndarray, 385 | dist_coeffs: np.ndarray, 386 | ): 387 | """ 388 | Projects 3D coordinates onto a 2D image plane using the provided camera parameters. 389 | 390 | Args: 391 | xy: array of shape (batch_size, horizon, 2) representing (x, y) coordinates 392 | camera_height: height of the camera above the ground (in meters) 393 | camera_x_offset: offset of the camera from the center of the car (in meters) 394 | camera_matrix: 3x3 matrix representing the camera's intrinsic parameters 395 | dist_coeffs: vector of distortion coefficients 396 | 397 | 398 | Returns: 399 | uv: array of shape (batch_size, horizon, 2) representing (u, v) coordinates on the 2D image plane 400 | """ 401 | batch_size, horizon, _ = xy.shape 402 | 403 | # create 3D coordinates with the camera positioned at the given height 404 | xyz = np.concatenate( 405 | [xy, -camera_height * np.ones(list(xy.shape[:-1]) + [1])], axis=-1 406 | ) 407 | 408 | # create dummy rotation and translation vectors 409 | rvec = tvec = (0, 0, 0) 410 | 411 | xyz[..., 0] += camera_x_offset 412 | xyz_cv = np.stack([xyz[..., 1], -xyz[..., 2], xyz[..., 0]], axis=-1) 413 | uv, _ = cv2.projectPoints( 414 | xyz_cv.reshape(batch_size * horizon, 3), rvec, tvec, camera_matrix, dist_coeffs 415 | ) 416 | uv = uv.reshape(batch_size, horizon, 2) 417 | 418 | return uv 419 | 420 | 421 | def get_pos_pixels( 422 | points: np.ndarray, 423 | camera_height: float, 424 | camera_x_offset: float, 425 | camera_matrix: np.ndarray, 426 | dist_coeffs: np.ndarray, 427 | clip: Optional[bool] = False, 428 | ): 429 | """ 430 | Projects 3D coordinates onto a 2D image plane using the provided camera parameters. 
431 | Args: 432 | points: array of shape (batch_size, horizon, 2) representing (x, y) coordinates 433 | camera_height: height of the camera above the ground (in meters) 434 | camera_x_offset: offset of the camera from the center of the car (in meters) 435 | camera_matrix: 3x3 matrix representing the camera's intrinsic parameters 436 | dist_coeffs: vector of distortion coefficients 437 | 438 | Returns: 439 | pixels: array of shape (batch_size, horizon, 2) representing (u, v) coordinates on the 2D image plane 440 | """ 441 | pixels = project_points( 442 | points[np.newaxis], camera_height, camera_x_offset, camera_matrix, dist_coeffs 443 | )[0] 444 | pixels[:, 0] = VIZ_IMAGE_SIZE[0] - pixels[:, 0] 445 | if clip: 446 | pixels = np.array( 447 | [ 448 | [ 449 | np.clip(p[0], 0, VIZ_IMAGE_SIZE[0]), 450 | np.clip(p[1], 0, VIZ_IMAGE_SIZE[1]), 451 | ] 452 | for p in pixels 453 | ] 454 | ) 455 | else: 456 | pixels = np.array( 457 | [ 458 | p 459 | for p in pixels 460 | if np.all(p > 0) and np.all(p < [VIZ_IMAGE_SIZE[0], VIZ_IMAGE_SIZE[1]]) 461 | ] 462 | ) 463 | return pixels 464 | 465 | 466 | def gen_camera_matrix(fx: float, fy: float, cx: float, cy: float) -> np.ndarray: 467 | """ 468 | Args: 469 | fx: focal length in x direction 470 | fy: focal length in y direction 471 | cx: principal point x coordinate 472 | cy: principal point y coordinate 473 | Returns: 474 | camera matrix 475 | """ 476 | return np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]]) 477 | --------------------------------------------------------------------------------
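The projection helpers at the end of action_utils.py (gen_camera_matrix, project_points, get_pos_pixels) convert metric, robot-frame (x, y) waypoints into pixel coordinates so predicted and ground-truth trajectories can be drawn over the observation image. Below is a minimal usage sketch, not part of the repository: the intrinsics, distortion coefficients, camera height, and camera x-offset are illustrative placeholder values, not values from data_config.yaml, and importing the module assumes data/data_config.yaml is readable, since it is loaded at import time.

# Illustrative sketch only: camera parameters below are assumed values,
# not taken from any dataset entry in data_config.yaml.
import numpy as np
from vint_train.visualizing.action_utils import gen_camera_matrix, project_points

camera_matrix = gen_camera_matrix(fx=262.0, fy=262.0, cx=320.0, cy=240.0)
dist_coeffs = np.zeros(8)      # assume an undistorted camera
camera_height = 0.25           # meters above the ground plane (assumed)
camera_x_offset = 0.10         # meters ahead of the robot center (assumed)

# five robot-frame (x, y) waypoints heading straight ahead
waypoints = np.stack([np.linspace(0.2, 1.0, 5), np.zeros(5)], axis=-1)

# project_points expects a batch dimension: (batch_size, horizon, 2)
uv = project_points(waypoints[np.newaxis], camera_height, camera_x_offset,
                    camera_matrix, dist_coeffs)
print(uv.shape)                # -> (1, 5, 2) pixel coordinates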