├── train
│ ├── vint_train
│ │ ├── __init__.py
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── data_config.yaml
│ │ │ ├── data_utils.py
│ │ │ └── vint_dataset.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── navibridge
│ │ │ │ ├── __init__.py
│ │ │ │ ├── ddbm
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── script_util.py
│ │ │ │ │ ├── resample.py
│ │ │ │ │ ├── nn.py
│ │ │ │ │ └── karras_diffusion.py
│ │ │ │ ├── self_attention.py
│ │ │ │ ├── vae
│ │ │ │ │ ├── vae.py
│ │ │ │ │ └── conditional_mlp_1D_vae.py
│ │ │ │ ├── navibridg_utils.py
│ │ │ │ └── navibridge.py
│ │ │ ├── base_model.py
│ │ │ └── model_utils.py
│ │ ├── training
│ │ │ ├── __init__.py
│ │ │ ├── logger.py
│ │ │ └── train_eval_loop.py
│ │ ├── process_data
│ │ │ ├── __init__.py
│ │ │ ├── process_bags_config.yaml
│ │ │ └── process_data_utils.py
│ │ └── visualizing
│ │   ├── __init__.py
│ │   ├── visualize_utils.py
│ │   ├── distance_utils.py
│ │   └── action_utils.py
│ ├── setup.py
│ ├── train_environment.yml
│ ├── config
│ │ ├── defaults.yaml
│ │ ├── cvae.yaml
│ │ └── navibridge.yaml
│ ├── process_recon.py
│ ├── data_split.py
│ ├── process_bags.py
│ └── train.py
├── deployment
│ ├── config
│ │ ├── params.yaml
│ │ ├── robot.yaml
│ │ └── models.yaml
│ ├── deployment_environment.yaml
│ └── src
│   ├── navibridger_inference.py
│   └── utils_inference.py
├── assets
│ └── pipline.png
├── LICENSE
├── .gitignore
└── README.md
/train/vint_train/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/process_data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/visualizing/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/train/vint_train/models/navibridge/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/deployment/config/params.yaml:
--------------------------------------------------------------------------------
1 | image_path: "../images/"
--------------------------------------------------------------------------------
/train/vint_train/models/navibridge/ddbm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/pipline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hren20/NaiviBridger/HEAD/assets/pipline.png
--------------------------------------------------------------------------------
/train/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="vint_train",
5 | version="0.1.0",
6 | packages=find_packages(),
7 | )
8 |
--------------------------------------------------------------------------------
/deployment/deployment_environment.yaml:
--------------------------------------------------------------------------------
1 | name: navibridge
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | dependencies:
6 | - python=3.8.5
7 | - cudatoolkit=11.
8 | - torchvision
9 | - numpy
10 | - matplotlib
11 | - pyyaml
12 | - rospkg
13 | - pip:
14 | - torch
15 | - torchvision
16 | - efficientnet_pytorch
17 | - warmup_scheduler
18 | - diffusers==0.11.1
--------------------------------------------------------------------------------
/deployment/config/robot.yaml:
--------------------------------------------------------------------------------
1 | # linear and angular speed limits for the robot
2 | max_v: 0.3 #0.4 # m/s
3 | max_w: 0.4 #0.8 # rad/s
4 | # observation rate of the robot
5 | frame_rate: 4 # Hz
6 | graph_rate: 0.3333 # Hz
7 |
8 | # topic names (modify for different robots/nodes)
9 | vel_teleop_topic: /cmd_vel_mux/input/teleop
10 | vel_navi_topic: /cmd_vel
11 | vel_recovery_topic: /cmd_vel_mux/input/recovery
12 |
13 |
14 |
--------------------------------------------------------------------------------
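
The max_v / max_w values above bound the commands published on the velocity topics. A minimal sketch, assuming a scalar (v, w) command (the actual controller lives in deployment/src and is not shown here):

```python
# Illustration only: clamp a planned command to the limits in robot.yaml.
MAX_V = 0.3  # m/s, matches max_v above
MAX_W = 0.4  # rad/s, matches max_w above

def clip_velocity(v: float, w: float) -> tuple:
    """Clamp linear and angular speed to the configured limits."""
    v = max(-MAX_V, min(MAX_V, v))
    w = max(-MAX_W, min(MAX_W, w))
    return v, w

# clip_velocity(0.5, -1.0) -> (0.3, -0.4)
```
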
/deployment/config/models.yaml:
--------------------------------------------------------------------------------
1 | navibridger_handcraft:
2 | config_path: "../model_weights/navibridger_handcraft.yaml"
3 | ckpt_path: "../model_weights/navibridger_handcraft.pth"
4 |
5 | navibridger_noise:
6 | config_path: "../model_weights/navibridger_noise.yaml"
7 | ckpt_path: "../model_weights/navibridger_noise.pth"
8 |
9 | navibridger_cvae:
10 | config_path: "../model_weights/navibridger_cvae.yaml"
11 | ckpt_path: "../model_weights/navibridger_cvae.pth"
12 | cvae:
13 | config_path: "../model_weights/cvae.yaml"
14 | ckpt_path: "../model_weights/cvae.pth"
--------------------------------------------------------------------------------
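
Each entry above pairs a training config with a checkpoint under deployment/model_weights/. A sketch of how an entry might be resolved at inference time (the repo's navibridger_inference.py is not shown here, so the helper below is an assumption):

```python
import yaml
import torch

def load_model_entry(name: str, models_yaml: str = "../config/models.yaml"):
    """Illustrative helper: look up one entry of models.yaml and load its
    training config and checkpoint. Paths are relative to deployment/src/."""
    with open(models_yaml, "r") as f:
        registry = yaml.safe_load(f)
    entry = registry[name]                        # e.g. "navibridger_noise"
    with open(entry["config_path"], "r") as f:
        model_config = yaml.safe_load(f)
    checkpoint = torch.load(entry["ckpt_path"], map_location="cpu")
    return model_config, checkpoint

# model_config, ckpt = load_model_entry("navibridger_cvae")
```
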
/train/train_environment.yml:
--------------------------------------------------------------------------------
1 | name: navibridge_train
2 | channels:
3 | - defaults
4 | - pytorch
5 | dependencies:
6 | - python=3.8.5
7 | - cudatoolkit=10.
8 | - numpy
9 | - matplotlib
10 | - ipykernel
11 | - pip
12 | - pip:
13 | - torch
14 | - torchvision
15 | - tqdm==4.64.0
16 | - git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
17 | - opencv-python==4.6.0.66
18 | - h5py==3.6.0
19 | - wandb==0.12.18
20 | - --extra-index-url https://rospypi.github.io/simple/
21 | - rosbag
22 | - roslz4
23 | - prettytable
24 | - efficientnet-pytorch
25 | - warmup-scheduler
26 | - diffusers==0.11.1
27 | - lmdb
28 | - vit-pytorch
29 | - positional-encodings
30 | - piq==0.8.0
31 |
32 |
33 |
34 |
35 |
36 |
37 |
--------------------------------------------------------------------------------
/train/vint_train/visualizing/visualize_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch
4 |
5 | VIZ_IMAGE_SIZE = (640, 480)
6 | RED = np.array([1, 0, 0])
7 | GREEN = np.array([0, 1, 0])
8 | BLUE = np.array([0, 0, 1])
9 | CYAN = np.array([0, 1, 1])
10 | YELLOW = np.array([1, 1, 0])
11 | MAGENTA = np.array([1, 0, 1])
12 |
13 |
14 | def numpy_to_img(arr: np.ndarray) -> Image.Image:
15 | img = Image.fromarray(np.transpose(np.uint8(255 * arr), (1, 2, 0)))
16 | img = img.resize(VIZ_IMAGE_SIZE)
17 | return img
18 |
19 |
20 | def to_numpy(tensor: torch.Tensor) -> np.ndarray:
21 | return tensor.detach().cpu().numpy()
22 |
23 |
24 | def from_numpy(array: np.ndarray) -> torch.Tensor:
25 | return torch.from_numpy(array).float()
26 |
--------------------------------------------------------------------------------
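
For reference, the helpers above convert between numpy, torch, and PIL representations; numpy_to_img expects a CHW float array in [0, 1]:

```python
import numpy as np
from vint_train.visualizing.visualize_utils import numpy_to_img, to_numpy, from_numpy

obs = np.random.rand(3, 120, 160).astype(np.float32)  # CHW image in [0, 1]
pil_img = numpy_to_img(obs)        # PIL image, resized to VIZ_IMAGE_SIZE (640, 480)

t = from_numpy(obs)                # float32 torch.Tensor
back = to_numpy(t)                 # detached numpy array, same shape as obs
assert back.shape == obs.shape
```
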
/train/vint_train/process_data/process_bags_config.yaml:
--------------------------------------------------------------------------------
1 | tartan_drive:
2 | odomtopics: "/odometry/filtered_odom"
3 | imtopics: "/multisense/left/image_rect_color"
4 | ang_offset: 1.5707963267948966 # pi/2
5 | img_process_func: "process_tartan_img"
6 | odom_process_func: "nav_to_xy_yaw"
7 |
8 | scand:
9 | odomtopics: ["/odom", "/jackal_velocity_controller/odom"]
10 | imtopics: ["/image_raw/compressed", "/camera/rgb/image_raw/compressed"]
11 | ang_offset: 0.0
12 | img_process_func: "process_scand_img"
13 | odom_process_func: "nav_to_xy_yaw"
14 |
15 | locobot:
16 | odomtopics: "/odom"
17 | imtopics: "/usb_cam/image_raw"
18 | ang_offset: 0.0
19 | img_process_func: "process_locobot_img"
20 | odom_process_func: "nav_to_xy_yaw"
21 |
22 | sacson:
23 | odomtopics: "/odometry"
24 | imtopics: "/fisheye_image/compressed"
25 | ang_offset: 0.0
26 | img_process_func: "process_sacson_img"
27 | odom_process_func: "nav_to_xy_yaw"
28 |
29 | # add your own datasets below:
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Hao Ren, Yiming Zeng, Zetong Bi, Zhaoliang Wan,
4 | Junlong Huang, Hui Cheng
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in
14 | all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 | THE SOFTWARE.
23 |
--------------------------------------------------------------------------------
/train/config/defaults.yaml:
--------------------------------------------------------------------------------
1 | # defaults for training
2 | project_name: navibridge
3 | run_name: navibridge
4 |
5 | # training setup
6 | use_wandb: True # set to false if you don't want to log to wandb
7 | train: True
8 | batch_size: 400
9 | eval_batch_size: 400
10 | epochs: 30
11 | gpu_ids: [0]
12 | num_workers: 4
13 | lr: 5e-4
14 | optimizer: adam
15 | seed: 0
16 | clipping: False
17 | train_subset: 1.
18 |
19 | # model params
20 | model_type: navibridge
21 | obs_encoding_size: 1024
22 | goal_encoding_size: 1024
23 |
24 | # normalization for the action space
25 | normalize: True
26 |
27 | # context
28 | context_type: temporal
29 | context_size: 5
30 |
31 | # tradeoff between action and distance prediction loss
32 | alpha: 0.5
33 |
34 | # tradeoff between task loss and kld
35 | beta: 0.1
36 |
37 | obs_type: image
38 | goal_type: image
39 | scheduler: null
40 |
41 | # distance bounds for distance and action predictions
42 | distance:
43 | min_dist_cat: 0
44 | max_dist_cat: 20
45 | action:
46 | min_dist_cat: 2
47 | max_dist_cat: 10
48 | close_far_threshold: 10 # distance threshold used to separate the close and far subgoals sampled per datapoint
49 |
50 | # action output params
51 | len_traj_pred: 5
52 | learn_angle: True
53 |
54 | # dataset specific parameters
55 | image_size: [85, 64] # width, height
56 |
57 | # logging stuff
58 | ## =0 turns off
59 | print_log_freq: 100 # in iterations
60 | image_log_freq: 1000 # in iterations
61 | num_images_log: 8 # number of images to log in a logging iteration
62 | pairwise_test_freq: 10 # in epochs
63 | eval_fraction: 0.25 # fraction of the dataset to use for evaluation
64 | wandb_log_freq: 10 # in iterations
65 | eval_freq: 1 # in epochs
66 |
67 |
--------------------------------------------------------------------------------
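
These values serve as the base configuration; a sketch, assuming the training entry point merges them with a model-specific file such as config/navibridge.yaml (train.py itself is not shown here):

```python
import yaml

# Illustration only: start from defaults.yaml, then let a model-specific
# config override matching keys (run from the train/ directory).
with open("config/defaults.yaml", "r") as f:
    config = yaml.safe_load(f)
with open("config/navibridge.yaml", "r") as f:
    config.update(yaml.safe_load(f))

print(config["model_type"], config["batch_size"], config["len_traj_pred"])
```
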
/train/vint_train/models/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from typing import List, Dict, Optional, Tuple
5 |
6 |
7 | class BaseModel(nn.Module):
8 | def __init__(
9 | self,
10 | context_size: int = 5,
11 | len_traj_pred: Optional[int] = 5,
12 | learn_angle: Optional[bool] = True,
13 | ) -> None:
14 | """
15 | Base Model main class
16 | Args:
17 | context_size (int): how many previous observations to use for context
18 | len_traj_pred (int): how many waypoints to predict in the future
19 | learn_angle (bool): whether to predict the yaw of the robot
20 | """
21 | super(BaseModel, self).__init__()
22 | self.context_size = context_size
23 | self.learn_angle = learn_angle
24 | self.len_trajectory_pred = len_traj_pred
25 | if self.learn_angle:
26 | self.num_action_params = 4 # last two dims are the cos and sin of the angle
27 | else:
28 | self.num_action_params = 2
29 |
30 | def flatten(self, z: torch.Tensor) -> torch.Tensor:
31 | z = nn.functional.adaptive_avg_pool2d(z, (1, 1))
32 | z = torch.flatten(z, 1)
33 | return z
34 |
35 | def forward(
36 | self, obs_img: torch.Tensor, goal_img: torch.Tensor
37 | ) -> Tuple[torch.Tensor, torch.Tensor]:
38 | """
39 | Forward pass of the model
40 | Args:
41 | obs_img (torch.Tensor): batch of observations
42 | goal_img (torch.Tensor): batch of goals
43 | Returns:
44 | dist_pred (torch.Tensor): predicted distance to goal
45 | action_pred (torch.Tensor): predicted action
46 | """
47 | raise NotImplementedError
48 |
--------------------------------------------------------------------------------
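
BaseModel only fixes the interface; concrete models supply the encoders and prediction heads. A toy subclass (not one of the repo's models) illustrating the expected forward signature and output shapes:

```python
import torch
import torch.nn as nn

from vint_train.models.base_model import BaseModel


class DummyModel(BaseModel):
    """Toy example: flatten both images and regress a distance scalar plus
    len_traj_pred * num_action_params action values."""

    def __init__(self, context_size: int = 5, len_traj_pred: int = 5, learn_angle: bool = True):
        super().__init__(context_size, len_traj_pred, learn_angle)
        self.dist_head = nn.LazyLinear(1)
        self.action_head = nn.LazyLinear(self.len_trajectory_pred * self.num_action_params)

    def forward(self, obs_img: torch.Tensor, goal_img: torch.Tensor):
        x = torch.cat([obs_img.flatten(1), goal_img.flatten(1)], dim=1)
        dist_pred = self.dist_head(x)
        action_pred = self.action_head(x).view(-1, self.len_trajectory_pred, self.num_action_params)
        return dist_pred, action_pred


# obs: context_size + 1 stacked RGB frames; goal: a single RGB frame (sizes illustrative)
model = DummyModel()
obs, goal = torch.randn(2, 3 * 6, 64, 85), torch.randn(2, 3, 64, 85)
dist, actions = model(obs, goal)   # shapes: (2, 1) and (2, 5, 4)
```
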
/train/vint_train/training/logger.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Logger:
5 | def __init__(
6 | self,
7 | name: str,
8 | dataset: str,
9 | window_size: int = 10,
10 | rounding: int = 4,
11 | ):
12 | """
13 | Args:
14 | name (str): Name of the metric
15 | dataset (str): Name of the dataset
16 | window_size (int, optional): Size of the moving average window. Defaults to 10.
17 | rounding (int, optional): Number of decimals to round to. Defaults to 4.
18 | """
19 | self.data = []
20 | self.name = name
21 | self.dataset = dataset
22 | self.rounding = rounding
23 | self.window_size = window_size
24 |
25 | def display(self) -> str:
26 | latest = round(self.latest(), self.rounding)
27 | average = round(self.average(), self.rounding)
28 | moving_average = round(self.moving_average(), self.rounding)
29 | output = f"{self.full_name()}: {latest} ({self.window_size}pt moving_avg: {moving_average}) (avg: {average})"
30 | return output
31 |
32 | def log_data(self, data: float):
33 | if not np.isnan(data):
34 | self.data.append(data)
35 |
36 | def full_name(self) -> str:
37 | return f"{self.name} ({self.dataset})"
38 |
39 | def latest(self) -> float:
40 | if len(self.data) > 0:
41 | return self.data[-1]
42 | return np.nan
43 |
44 | def average(self) -> float:
45 | if len(self.data) > 0:
46 | return np.mean(self.data)
47 | return np.nan
48 |
49 | def moving_average(self) -> float:
50 | if len(self.data) > self.window_size:
51 | return np.mean(self.data[-self.window_size :])
52 | return self.average()
--------------------------------------------------------------------------------
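
Usage is straightforward, for example:

```python
from vint_train.training.logger import Logger

loss_logger = Logger("action_loss", "recon", window_size=10, rounding=4)
for value in (0.52, 0.47, 0.45):
    loss_logger.log_data(value)

print(loss_logger.latest())          # 0.45
print(loss_logger.moving_average())  # mean of the last <= 10 logged values
print(loss_logger.display())         # "action_loss (recon): 0.45 (10pt moving_avg: ...) (avg: ...)"
```
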
/train/vint_train/data/data_config.yaml:
--------------------------------------------------------------------------------
1 |
2 | # global params for diffusion model
3 | # normalized min and max
4 | action_stats:
5 | min: [-2.5, -4] # [min_dx, min_dy]
6 | max: [5, 4] # [max_dx, max_dy]
7 |
8 | # data specific params
9 | recon:
10 | metric_waypoint_spacing: 0.25 # average spacing between waypoints (meters)
11 |
12 | # OPTIONAL (FOR VISUALIZATION ONLY)
13 | camera_metrics: # https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html
14 | camera_height: 0.95 # meters
15 | camera_x_offset: 0.45 # distance between the center of the robot and the forward facing camera
16 | camera_matrix:
17 | fx: 272.547000
18 | fy: 266.358000
19 | cx: 320.000000
20 | cy: 220.000000
21 | dist_coeffs:
22 | k1: -0.038483
23 | k2: -0.010456
24 | p1: 0.003930
25 | p2: -0.001007
26 | k3: 0.0
27 |
28 | recon_test:
29 | metric_waypoint_spacing: 0.25 # average spacing between waypoints (meters)
30 |
31 | # OPTIONAL (FOR VISUALIZATION ONLY)
32 | camera_metrics: # https://docs.opencv.org/4.x/dc/dbb/tutorial_py_calibration.html
33 | camera_height: 0.95 # meters
34 | camera_x_offset: 0.45 # distance between the center of the robot and the forward facing camera
35 | camera_matrix:
36 | fx: 272.547000
37 | fy: 266.358000
38 | cx: 320.000000
39 | cy: 220.000000
40 | dist_coeffs:
41 | k1: -0.038483
42 | k2: -0.010456
43 | p1: 0.003930
44 | p2: -0.001007
45 | k3: 0.0
46 |
47 | scand:
48 | metric_waypoint_spacing: 0.38
49 |
50 | tartan_drive:
51 | metric_waypoint_spacing: 0.72
52 |
53 | go_stanford:
54 | metric_waypoint_spacing: 0.12
55 |
56 | # private datasets:
57 | cory_hall:
58 | metric_waypoint_spacing: 0.06
59 |
60 | seattle:
61 | metric_waypoint_spacing: 0.35
62 |
63 | racer:
64 | metric_waypoint_spacing: 0.38
65 |
66 | carla_intvns:
67 | metric_waypoint_spacing: 1.39
68 |
69 | carla_cil:
70 | metric_waypoint_spacing: 1.27
71 |
75 | carla:
76 | metric_waypoint_spacing: 1.59
77 | image_path_func: get_image_path
78 |
79 | sacson:
80 | metric_waypoint_spacing: 0.255
81 |
82 | # add your own dataset params here:
83 |
--------------------------------------------------------------------------------
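
action_stats above gives dataset-wide bounds on per-step (dx, dy) deltas. Below is a sketch of min-max normalization to [-1, 1] using those bounds; the normalization actually applied during training lives in the dataset/model code (not shown here), so treat this as an assumption:

```python
import numpy as np

ACTION_MIN = np.array([-2.5, -4.0])  # [min_dx, min_dy] from action_stats
ACTION_MAX = np.array([5.0, 4.0])    # [max_dx, max_dy] from action_stats

def normalize_deltas(deltas: np.ndarray) -> np.ndarray:
    """Map raw (dx, dy) deltas into [-1, 1] using the bounds above."""
    return 2.0 * (deltas - ACTION_MIN) / (ACTION_MAX - ACTION_MIN) - 1.0

def unnormalize_deltas(ndeltas: np.ndarray) -> np.ndarray:
    """Inverse of normalize_deltas."""
    return (ndeltas + 1.0) / 2.0 * (ACTION_MAX - ACTION_MIN) + ACTION_MIN

# unnormalize_deltas(normalize_deltas(x)) round-trips back to x
```
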
/train/vint_train/models/navibridge/self_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | class PositionalEncoding(nn.Module):
7 | def __init__(self, d_model, max_seq_len=6):
8 | super().__init__()
9 |
10 | # Compute the positional encoding once
11 | pos_enc = torch.zeros(max_seq_len, d_model)
12 | pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
13 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
14 | pos_enc[:, 0::2] = torch.sin(pos * div_term)
15 | pos_enc[:, 1::2] = torch.cos(pos * div_term)
16 | pos_enc = pos_enc.unsqueeze(0)
17 |
18 | # Register the positional encoding as a buffer to avoid it being
19 | # considered a parameter when saving the model
20 | self.register_buffer('pos_enc', pos_enc)
21 |
22 | def forward(self, x):
23 | # Add the positional encoding to the input
24 | x = x + self.pos_enc[:, :x.size(1), :]
25 | return x
26 |
27 | class MultiLayerDecoder(nn.Module):
28 | def __init__(self, embed_dim=512, seq_len=6, output_layers=[256, 128, 64], nhead=8, num_layers=8, ff_dim_factor=4):
29 | super(MultiLayerDecoder, self).__init__()
30 | self.positional_encoding = PositionalEncoding(embed_dim, max_seq_len=seq_len)
31 | self.sa_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=nhead, dim_feedforward=ff_dim_factor*embed_dim, activation="gelu", batch_first=True, norm_first=True)
32 | self.sa_decoder = nn.TransformerEncoder(self.sa_layer, num_layers=num_layers)
33 | self.output_layers = nn.ModuleList([nn.Linear(seq_len*embed_dim, embed_dim)])
34 | self.output_layers.append(nn.Linear(embed_dim, output_layers[0]))
35 | for i in range(len(output_layers)-1):
36 | self.output_layers.append(nn.Linear(output_layers[i], output_layers[i+1]))
37 |
38 | def forward(self, x):
39 | if self.positional_encoding: x = self.positional_encoding(x)
40 | x = self.sa_decoder(x)
41 | # currently, x is [batch_size, seq_len, embed_dim]
42 | x = x.reshape(x.shape[0], -1)
43 | for i in range(len(self.output_layers)):
44 | x = self.output_layers[i](x)
45 | x = F.relu(x)
46 | return x
47 |
--------------------------------------------------------------------------------
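
MultiLayerDecoder attends over a sequence of token embeddings (e.g. per-frame observation features plus a goal token) and flattens the result into a single feature vector; for reference:

```python
import torch
from vint_train.models.navibridge.self_attention import MultiLayerDecoder

decoder = MultiLayerDecoder(embed_dim=512, seq_len=6, output_layers=[256, 128, 64],
                            nhead=8, num_layers=8, ff_dim_factor=4)

tokens = torch.randn(4, 6, 512)   # (batch, seq_len, embed_dim)
features = decoder(tokens)
print(features.shape)             # torch.Size([4, 64]) -- last entry of output_layers
```
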
/.gitignore:
--------------------------------------------------------------------------------
1 | train/logs/*
2 | train/datasets/*
3 | train/vint_train/data/data_splits/*
4 | train/wandb/*
5 | train/gnm_dataset/*
6 | train/models
7 | *_test.yaml
8 |
9 | *.png
10 | *.jpg
11 | *.pth
12 | *.mp4
13 | *.gif
14 |
15 | deployment/model_weights/*
16 | deployment/topomaps/*
17 |
18 | .vscode/*
19 | */.vscode/*
20 |
21 |
22 | # Byte-compiled / optimized / DLL files
23 | __pycache__/
24 | *.py[cod]
25 | *$py.class
26 |
27 | # C extensions
28 | *.so
29 |
30 | # Distribution / packaging
31 | .Python
32 | build/
33 | develop-eggs/
34 | dist/
35 | downloads/
36 | eggs/
37 | .eggs/
38 | lib/
39 | lib64/
40 | parts/
41 | sdist/
42 | var/
43 | wheels/
44 | pip-wheel-metadata/
45 | share/python-wheels/
46 | *.egg-info/
47 | .installed.cfg
48 | *.egg
49 | MANIFEST
50 |
51 | # PyInstaller
52 | # Usually these files are written by a python script from a template
53 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
54 | *.manifest
55 | *.spec
56 |
57 | # Installer logs
58 | pip-log.txt
59 | pip-delete-this-directory.txt
60 |
61 | # Unit test / coverage reports
62 | htmlcov/
63 | .tox/
64 | .nox/
65 | .coverage
66 | .coverage.*
67 | .cache
68 | nosetests.xml
69 | coverage.xml
70 | *.cover
71 | *.py,cover
72 | .hypothesis/
73 | .pytest_cache/
74 |
75 | # Translations
76 | *.mo
77 | *.pot
78 |
79 | # Django stuff:
80 | *.log
81 | local_settings.py
82 | db.sqlite3
83 | db.sqlite3-journal
84 |
85 | # Flask stuff:
86 | instance/
87 | .webassets-cache
88 |
89 | # Scrapy stuff:
90 | .scrapy
91 |
92 | # Sphinx documentation
93 | docs/_build/
94 |
95 | # PyBuilder
96 | target/
97 |
98 | # Jupyter Notebook
99 | .ipynb_checkpoints
100 |
101 | # IPython
102 | profile_default/
103 | ipython_config.py
104 |
105 | # pyenv
106 | .python-version
107 |
108 | # pipenv
109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
112 | # install all needed dependencies.
113 | #Pipfile.lock
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
--------------------------------------------------------------------------------
/train/process_recon.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import os
3 | import pickle
4 | from PIL import Image
5 | import io
6 | import argparse
7 | import tqdm
8 |
9 |
10 | def main(args: argparse.Namespace):
11 | recon_dir = os.path.join(args.input_dir, "recon_release")
12 | output_dir = args.output_dir
13 |
14 | # create output dir if it doesn't exist
15 | if not os.path.exists(output_dir):
16 | os.makedirs(output_dir)
17 |
18 | # get all the folders in the recon dataset
19 | filenames = os.listdir(recon_dir)
20 | if args.num_trajs >= 0:
21 | filenames = filenames[: args.num_trajs]
22 |
23 | # processing loop
24 | for filename in tqdm.tqdm(filenames, desc="Trajectories processed"):
25 | # extract the name without the extension
26 | traj_name = filename.split(".")[0]
27 | # load the hdf5 file
28 | try:
29 | h5_f = h5py.File(os.path.join(recon_dir, filename), "r")
30 | except OSError:
31 | print(f"Error loading {filename}. Skipping...")
32 | continue
33 | # extract the position and yaw data
34 | position_data = h5_f["jackal"]["position"][:, :2]
35 | yaw_data = h5_f["jackal"]["yaw"][()]
36 | # save the data to a dictionary
37 | traj_data = {"position": position_data, "yaw": yaw_data}
38 | traj_folder = os.path.join(output_dir, traj_name)
39 | os.makedirs(traj_folder, exist_ok=True)
40 | with open(os.path.join(traj_folder, "traj_data.pkl"), "wb") as f:
41 | pickle.dump(traj_data, f)
42 | # make a folder for the file
43 | if not os.path.exists(traj_folder):
44 | os.makedirs(traj_folder)
45 | # save the image data to disk
46 | for i in range(h5_f["images"]["rgb_left"].shape[0]):
47 | img = Image.open(io.BytesIO(h5_f["images"]["rgb_left"][i]))
48 | img.save(os.path.join(traj_folder, f"{i}.jpg"))
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | # get arguments for the recon input dir and the output dir
54 | parser.add_argument(
55 | "--input-dir",
56 | "-i",
57 | type=str,
58 | help="path of the recon_dataset",
59 | required=True,
60 | )
61 | parser.add_argument(
62 | "--output-dir",
63 | "-o",
64 | default="datasets/recon/",
65 | type=str,
66 | help="path for processed recon dataset (default: datasets/recon/)",
67 | )
68 | # number of trajs to process
69 | parser.add_argument(
70 | "--num-trajs",
71 | "-n",
72 | default=-1,
73 | type=int,
74 | help="number of trajectories to process (default: -1, all)",
75 | )
76 |
77 | args = parser.parse_args()
78 | print("STARTING PROCESSING RECON DATASET")
79 | main(args)
80 | print("FINISHED PROCESSING RECON DATASET")
81 |
--------------------------------------------------------------------------------
/train/data_split.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import random
5 |
6 |
7 | def remove_files_in_dir(dir_path: str):
8 | for f in os.listdir(dir_path):
9 | file_path = os.path.join(dir_path, f)
10 | try:
11 | if os.path.isfile(file_path) or os.path.islink(file_path):
12 | os.unlink(file_path)
13 | elif os.path.isdir(file_path):
14 | shutil.rmtree(file_path)
15 | except Exception as e:
16 | print("Failed to delete %s. Reason: %s" % (file_path, e))
17 |
18 |
19 | def main(args: argparse.Namespace):
20 | # Get the names of the folders in the data directory that contain the file 'traj_data.pkl'
21 | folder_names = [
22 | f
23 | for f in os.listdir(args.data_dir)
24 | if os.path.isdir(os.path.join(args.data_dir, f))
25 | and "traj_data.pkl" in os.listdir(os.path.join(args.data_dir, f))
26 | ]
27 |
28 | # Randomly shuffle the names of the folders
29 | random.shuffle(folder_names)
30 |
31 | # Split the names of the folders into train and test sets
32 | split_index = int(args.split * len(folder_names))
33 | train_folder_names = folder_names[:split_index]
34 | test_folder_names = folder_names[split_index:]
35 |
36 | # Create directories for the train and test sets
37 | train_dir = os.path.join(args.data_splits_dir, args.dataset_name, "train")
38 | test_dir = os.path.join(args.data_splits_dir, args.dataset_name, "test")
39 | for dir_path in [train_dir, test_dir]:
40 | if os.path.exists(dir_path):
41 | print(f"Clearing files from {dir_path} for new data split")
42 | remove_files_in_dir(dir_path)
43 | else:
44 | print(f"Creating {dir_path}")
45 | os.makedirs(dir_path)
46 |
47 | # Write the names of the train and test folders to files
48 | with open(os.path.join(train_dir, "traj_names.txt"), "w") as f:
49 | for folder_name in train_folder_names:
50 | f.write(folder_name + "\n")
51 |
52 | with open(os.path.join(test_dir, "traj_names.txt"), "w") as f:
53 | for folder_name in test_folder_names:
54 | f.write(folder_name + "\n")
55 |
56 |
57 | if __name__ == "__main__":
58 | # Set up the command line argument parser
59 | parser = argparse.ArgumentParser()
60 |
61 | parser.add_argument(
62 | "--data-dir", "-i", help="Directory containing the data", required=True
63 | )
64 | parser.add_argument(
65 | "--dataset-name", "-d", help="Name of the dataset", required=True
66 | )
67 | parser.add_argument(
68 | "--split", "-s", type=float, default=0.8, help="Train/test split (default: 0.8)"
69 | )
70 | parser.add_argument(
71 | "--data-splits-dir", "-o", default="vint_train/data/data_splits", help="Data splits directory"
72 | )
73 | args = parser.parse_args()
74 | main(args)
75 | print("Done")
76 |
--------------------------------------------------------------------------------
/train/config/cvae.yaml:
--------------------------------------------------------------------------------
1 | project_name: cvae
2 | run_name: cvae
3 |
4 | # training setup
5 | use_wandb: True # set to false if you don't want to log to wandb
6 | train: True
7 | batch_size: 256
8 | epochs: 30
9 | gpu_ids: [0]
10 | num_workers: 12
11 | lr: 1e-4
12 | optimizer: adamw
13 | clipping: False
14 | max_norm: 1.
15 | scheduler: "cosine"
16 | warmup: True
17 | warmup_epochs: 4
18 | cyclic_period: 10
19 | plateau_patience: 3
20 | plateau_factor: 0.5
21 | seed: 0
22 | save_freq: 1
23 |
24 | # model params
25 | model_type: cvae
26 | vision_encoder: navibridge_encoder
27 | encoding_size: 256
28 | obs_encoder: efficientnet-b0
29 | attn_unet: False
30 | cond_predict_scale: False
31 | mha_num_attention_heads: 4
32 | mha_num_attention_layers: 4
33 | mha_ff_dim_factor: 4
34 | down_dims: [64, 128, 256]
35 |
36 | # diffusion model params
37 | num_diffusion_iters: 10
38 |
39 | # mask
40 | goal_mask_prob: 0.5
41 |
42 | # normalization for the action space
43 | normalize: True
44 |
45 | # context
46 | context_type: temporal
47 | context_size: 3 # 5
48 | alpha: 1e-4
49 |
50 | # distance bounds for distance and action predictions
51 | distance:
52 | min_dist_cat: 0
53 | max_dist_cat: 20
54 | action:
55 | min_dist_cat: 3
56 | max_dist_cat: 20
57 |
58 | # action output params
59 | len_traj_pred: 8
60 | action_dim: 2
61 | learn_angle: False
62 |
63 | # navibridge
64 | sampler_name: "uniform"
65 | pred_mode: "ve"
66 | weight_schedule: "karras"
67 | sigma_data: 0.5
68 | sigma_min: 0.002
69 | sigma_max: 80.0
70 | rho: 7.0
71 | beta_d: 2
72 | beta_min: 0.1
73 | cov_xy: 0.
74 | guidance: 1.
75 | # sample defaults
76 | clip_denoised: True
77 | sampler: "euler"
78 | churn_step_ratio: 0.33
79 | # prior settings
80 | prior_policy: "gaussian" # handcraft, gaussian, cvae
81 | class_num: 5
82 |
83 | angle_ranges: [[0, 67.5],
84 | [67.5, 112.5],
85 | [112.5, 180],
86 | [180, 270],
87 | [270, 360]]
88 | min_std_angle: 5.0
89 | max_std_angle: 20.0
90 | min_std_length: 1.0
91 | max_std_length: 5.0
92 |
93 | # cvae
94 | train_params:
95 | batch_size: 256
96 | num_itr: 3001
97 | lr: 0.5e-5
98 | lr_gamma: 0.99
99 | lr_step: 1000
100 | l2_norm: 0.0
101 | ema: 0.99
102 |
103 |
104 | diffuse_params:
105 | latent_dim: 64
106 | layer: 3
107 | net_type: vae_mlp
108 | ckpt_path: models/cvae.pth
109 | pretrain: False
110 |
111 | # dataset specific parameters
112 | image_size: [96, 96] # width, height
113 | datasets:
114 | recon:
115 | data_folder: ./datasets/recon
116 | train: ./datasets/data_splits/recon/train # path to train folder with traj_names.txt
117 | test: ./datasets/data_splits/recon/test # path to test folder with traj_names.txt
118 | end_slack: 3 # because many trajectories end in collisions
119 | goals_per_obs: 1 # how many goals are sampled per observation
120 | negative_mining: True # negative mining from the ViNG paper (Shah et al.)
121 | go_stanford:
122 | data_folder: ./datasets/go_stanford/ # datasets/stanford_go_new
123 | train: ./datasets/data_splits/go_stanford/train/
124 | test: ./datasets/data_splits/go_stanford/test/
125 | end_slack: 0
126 | goals_per_obs: 2 # increase dataset size
127 | negative_mining: True
128 |
129 | sacson:
130 | data_folder: ./datasets/sacson/
131 | train: ./datasets/data_splits/sacson/train/
132 | test: ./datasets/data_splits/sacson/test/
133 | end_slack: 3 # because many trajectories end in collisions
134 | goals_per_obs: 1
135 | negative_mining: True
136 |
137 | scand:
138 | data_folder: ./datasets/scand/
139 | train: ./datasets/data_splits/scand/train/
140 | test: ./datasets/data_splits/scand/test/
141 | end_slack: 0
142 | goals_per_obs: 1
143 | negative_mining: True
144 |
145 | # logging stuff
146 | ## =0 turns off
147 | print_log_freq: 500 # in iterations
148 | image_log_freq: 1000 #0 # in iterations
149 | num_images_log: 8 #0
150 | pairwise_test_freq: 0 # in epochs
151 | eval_fraction: 0.25
152 | wandb_log_freq: 10 # in iterations
153 | eval_freq: 1 # in epochs
--------------------------------------------------------------------------------
/train/config/navibridge.yaml:
--------------------------------------------------------------------------------
1 | project_name: navibridge
2 | run_name: navibridge
3 |
4 | # training setup
5 | use_wandb: True # set to false if you don't want to log to wandb
6 | train: True
7 | batch_size: 256
8 | epochs: 30
9 | gpu_ids: [0]
10 | num_workers: 12
11 | lr: 1e-4
12 | optimizer: adamw
13 | clipping: False
14 | max_norm: 1.
15 | scheduler: "cosine"
16 | warmup: True
17 | warmup_epochs: 4
18 | cyclic_period: 10
19 | plateau_patience: 3
20 | plateau_factor: 0.5
21 | seed: 0
22 | save_freq: 1
23 |
24 | # model params
25 | model_type: navibridge
26 | vision_encoder: navibridge_encoder
27 | encoding_size: 256
28 | obs_encoder: efficientnet-b0
29 | attn_unet: False
30 | cond_predict_scale: False
31 | mha_num_attention_heads: 4
32 | mha_num_attention_layers: 4
33 | mha_ff_dim_factor: 4
34 | down_dims: [64, 128, 256]
35 |
36 | # diffusion model params
37 | num_diffusion_iters: 10
38 |
39 | # mask
40 | goal_mask_prob: 0.5
41 |
42 | # normalization for the action space
43 | normalize: True
44 |
45 | # context
46 | context_type: temporal
47 | context_size: 3 # 5
48 | alpha: 1e-4
49 |
50 | # distance bounds for distance and action predictions
51 | distance:
52 | min_dist_cat: 0
53 | max_dist_cat: 20
54 | action:
55 | min_dist_cat: 3
56 | max_dist_cat: 20
57 |
58 | # action output params
59 | len_traj_pred: 8
60 | action_dim: 2
61 | learn_angle: False
62 |
63 | # navibridge
64 | sampler_name: "uniform"
65 | pred_mode: "ve"
66 | weight_schedule: "karras"
67 | sigma_data: 0.5
68 | sigma_min: 0.002
69 | sigma_max: 10.0
70 | rho: 7.0
71 | beta_d: 2
72 | beta_min: 0.1
73 | cov_xy: 0.
74 | guidance: 1.
75 |
76 | clip_denoised: True
77 | sampler: "euler"
78 | churn_step_ratio: 0.33
79 | # prior settings
80 | prior_policy: "gaussian" # handcraft, gaussian, cvae
81 | class_num: 5
82 |
83 | angle_ranges: [[0, 67.5],
84 | [67.5, 112.5],
85 | [112.5, 180],
86 | [180, 270],
87 | [270, 360]]
88 | min_std_angle: 5.0
89 | max_std_angle: 20.0
90 | min_std_length: 1.0
91 | max_std_length: 5.0
92 |
93 | # cvae
94 | train_params:
95 | batch_size: 256
96 | num_itr: 3001
97 | lr: 0.5e-5
98 | lr_gamma: 0.99
99 | lr_step: 1000
100 | l2_norm: 0.0
101 | ema: 0.99
102 |
103 |
104 | diffuse_params:
105 | latent_dim: 64
106 | layer: 3
107 | net_type: vae_mlp
108 | ckpt_path: models/cvae.pth
109 | pretrain: False
110 |
111 | # dataset specific parameters
112 | image_size: [96, 96] # width, height
113 | datasets:
114 | recon:
115 | data_folder: ./datasets/recon
116 | train: ./datasets/data_splits/recon/train # path to train folder with traj_names.txt
117 | test: ./datasets/data_splits/recon/test # path to test folder with traj_names.txt
118 | end_slack: 3 # because many trajectories end in collisions
119 | goals_per_obs: 1 # how many goals are sampled per observation
120 | negative_mining: True # negative mining from the ViNG paper (Shah et al.)
121 |
122 | go_stanford:
123 | data_folder: ./datasets/go_stanford/ # datasets/stanford_go_new
124 | train: ./datasets/data_splits/go_stanford/train/
125 | test: ./datasets/data_splits/go_stanford/test/
126 | end_slack: 0
127 | goals_per_obs: 2 # increase dataset size
128 | negative_mining: True
129 |
130 | sacson:
131 | data_folder: ./datasets/sacson/
132 | train: ./datasets/data_splits/sacson/train/
133 | test: ./datasets/data_splits/sacson/test/
134 | end_slack: 3 # because many trajectories end in collisions
135 | goals_per_obs: 1
136 | negative_mining: True
137 |
138 | scand:
139 | data_folder: ./datasets/scand/
140 | train: ./datasets/data_splits/scand/train/
141 | test: ./datasets/data_splits/scand/test/
142 | end_slack: 0
143 | goals_per_obs: 1
144 | negative_mining: True
145 |
146 | # logging stuff
147 | ## =0 turns off
148 | print_log_freq: 100 # in iterations
149 | image_log_freq: 1000 #0 # in iterations
150 | num_images_log: 8 #0
151 | pairwise_test_freq: 0 # in epochs
152 | eval_fraction: 0.25
153 | wandb_log_freq: 10 # in iterations
154 | eval_freq: 1 # in epochs
--------------------------------------------------------------------------------
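
prior_policy selects the source distribution of the bridge (handcraft, gaussian, or cvae). As a loose illustration of how the handcraft parameters above could be read, one might bin headings by angle_ranges and perturb a step length within the min/max std bounds; the repo's actual prior sampling is implemented in the navibridge model code and may differ:

```python
import numpy as np

# Assumed reading of the handcraft-prior parameters; illustration only.
ANGLE_RANGES = [(0, 67.5), (67.5, 112.5), (112.5, 180), (180, 270), (270, 360)]  # degrees

def sample_handcraft_waypoint(rng: np.random.Generator,
                              mean_length: float = 1.0,
                              std_length: float = 1.0) -> np.ndarray:
    cls = rng.integers(len(ANGLE_RANGES))            # one of class_num heading bins
    lo, hi = ANGLE_RANGES[cls]
    theta = np.deg2rad(rng.uniform(lo, hi))          # heading within the chosen bin
    length = max(rng.normal(mean_length, std_length), 0.0)
    return np.array([length * np.cos(theta), length * np.sin(theta)])

# rng = np.random.default_rng(0); print(sample_handcraft_waypoint(rng))
```
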
/train/vint_train/models/navibridge/ddbm/script_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from .karras_diffusion import KarrasDenoiser
4 | import numpy as np
5 |
6 |
7 | def get_workdir(exp):
8 | workdir = f'./workdir/{exp}'
9 | return workdir
10 |
11 | def cm_train_defaults():
12 | return dict(
13 | teacher_model_path="",
14 | teacher_dropout=0.1,
15 | training_mode="consistency_distillation",
16 | target_ema_mode="fixed",
17 | scale_mode="fixed",
18 | total_training_steps=600000,
19 | start_ema=0.0,
20 | start_scales=40,
21 | end_scales=40,
22 | distill_steps_per_iter=50000,
23 | loss_norm="lpips",
24 | )
25 |
26 | def sample_defaults():
27 | return dict(
28 | generator="determ",
29 | clip_denoised=True,
30 | sampler="euler",
31 | s_churn=0.0,
32 | s_tmin=0.002,
33 | s_tmax=80,
34 | s_noise=1.0,
35 | steps=40,
36 | model_path="",
37 | seed=42,
38 | ts="",
39 | )
40 |
41 | def create_ema_and_scales_fn(
42 | target_ema_mode,
43 | start_ema,
44 | scale_mode,
45 | start_scales,
46 | end_scales,
47 | total_steps,
48 | distill_steps_per_iter,
49 | ):
50 | def ema_and_scales_fn(step):
51 | if target_ema_mode == "fixed" and scale_mode == "fixed":
52 | target_ema = start_ema
53 | scales = start_scales
54 | elif target_ema_mode == "fixed" and scale_mode == "progressive":
55 | target_ema = start_ema
56 | scales = np.ceil(
57 | np.sqrt(
58 | (step / total_steps) * ((end_scales + 1) ** 2 - start_scales**2)
59 | + start_scales**2
60 | )
61 | - 1
62 | ).astype(np.int32)
63 | scales = np.maximum(scales, 1)
64 | scales = scales + 1
65 |
66 | elif target_ema_mode == "adaptive" and scale_mode == "progressive":
67 | scales = np.ceil(
68 | np.sqrt(
69 | (step / total_steps) * ((end_scales + 1) ** 2 - start_scales**2)
70 | + start_scales**2
71 | )
72 | - 1
73 | ).astype(np.int32)
74 | scales = np.maximum(scales, 1)
75 | c = -np.log(start_ema) * start_scales
76 | target_ema = np.exp(-c / scales)
77 | scales = scales + 1
78 | elif target_ema_mode == "fixed" and scale_mode == "progdist":
79 | distill_stage = step // distill_steps_per_iter
80 | scales = start_scales // (2**distill_stage)
81 | scales = np.maximum(scales, 2)
82 |
83 | sub_stage = np.maximum(
84 | step - distill_steps_per_iter * (np.log2(start_scales) - 1),
85 | 0,
86 | )
87 | sub_stage = sub_stage // (distill_steps_per_iter * 2)
88 | sub_scales = 2 // (2**sub_stage)
89 | sub_scales = np.maximum(sub_scales, 1)
90 |
91 | scales = np.where(scales == 2, sub_scales, scales)
92 |
93 | target_ema = 1.0
94 | else:
95 | raise NotImplementedError
96 |
97 | return float(target_ema), int(scales)
98 |
99 | return ema_and_scales_fn
100 |
101 |
102 | def add_dict_to_argparser(parser, default_dict):
103 | for k, v in default_dict.items():
104 | v_type = type(v)
105 | if v is None:
106 | v_type = str
107 | elif isinstance(v, bool):
108 | v_type = str2bool
109 | parser.add_argument(f"--{k}", default=v, type=v_type)
110 |
111 |
112 | def args_to_dict(args, keys):
113 | return {k: getattr(args, k) for k in keys}
114 |
115 |
116 | def str2bool(v):
117 | """
118 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
119 | """
120 | if isinstance(v, bool):
121 | return v
122 | if v.lower() in ("yes", "true", "t", "y", "1"):
123 | return True
124 | elif v.lower() in ("no", "false", "f", "n", "0"):
125 | return False
126 | else:
127 | raise argparse.ArgumentTypeError("boolean value expected")
128 |
--------------------------------------------------------------------------------
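
create_ema_and_scales_fn returns a closure mapping a training step to a (target EMA rate, number of scales) pair; for example, with a fixed EMA and progressive scale schedule:

```python
from vint_train.models.navibridge.ddbm.script_util import create_ema_and_scales_fn

ema_and_scales_fn = create_ema_and_scales_fn(
    target_ema_mode="fixed",
    start_ema=0.95,
    scale_mode="progressive",
    start_scales=2,
    end_scales=40,
    total_steps=600000,
    distill_steps_per_iter=50000,
)

for step in (0, 300000, 600000):
    target_ema, scales = ema_and_scales_fn(step)
    print(step, target_ema, scales)   # scales grows from start_scales toward end_scales + 1
```
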
/train/process_bags.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import pickle
4 | from PIL import Image
5 | import io
6 | import argparse
7 | import tqdm
8 | import yaml
9 | import rosbag
10 |
11 | # utils
12 | from vint_train.process_data.process_data_utils import *
13 |
14 |
15 | def main(args: argparse.Namespace):
16 |
17 | # load the config file
18 | with open("vint_train/process_data/process_bags_config.yaml", "r") as f:
19 | config = yaml.load(f, Loader=yaml.FullLoader)
20 |
21 | # create output dir if it doesn't exist
22 | if not os.path.exists(args.output_dir):
23 | os.makedirs(args.output_dir)
24 |
25 | # recursively walk args.input_dir and collect the paths of all files with a .bag extension
26 | bag_files = []
27 | for root, dirs, files in os.walk(args.input_dir):
28 | for file in files:
29 | if file.endswith(".bag"):
30 | bag_files.append(os.path.join(root, file))
31 | if args.num_trajs >= 0:
32 | bag_files = bag_files[: args.num_trajs]
33 |
34 | # processing loop
35 | for bag_path in tqdm.tqdm(bag_files, desc="Bags processed"):
36 | try:
37 | b = rosbag.Bag(bag_path)
38 | except rosbag.ROSBagException as e:
39 | print(e)
40 | print(f"Error loading {bag_path}. Skipping...")
41 | continue
42 |
43 | # traj name: parent folder and bag file name joined by "_", with the ".bag" extension stripped
44 | traj_name = "_".join(bag_path.split("/")[-2:])[:-4]
45 |
46 | # extract synchronized image and odometry data from the bag
47 | bag_img_data, bag_traj_data = get_images_and_odom(
48 | b,
49 | config[args.dataset_name]["imtopics"],
50 | config[args.dataset_name]["odomtopics"],
51 | eval(config[args.dataset_name]["img_process_func"]),
52 | eval(config[args.dataset_name]["odom_process_func"]),
53 | rate=args.sample_rate,
54 | ang_offset=config[args.dataset_name]["ang_offset"],
55 | )
56 |
57 |
58 | if bag_img_data is None or bag_traj_data is None:
59 | print(
60 | f"{bag_path} did not have the topics we were looking for. Skipping..."
61 | )
62 | continue
63 | # remove backwards movement
64 | cut_trajs = filter_backwards(bag_img_data, bag_traj_data)
65 |
66 | for i, (img_data_i, traj_data_i) in enumerate(cut_trajs):
67 | traj_name_i = traj_name + f"_{i}"
68 | traj_folder_i = os.path.join(args.output_dir, traj_name_i)
69 | # make a folder for the traj
70 | if not os.path.exists(traj_folder_i):
71 | os.makedirs(traj_folder_i)
72 | with open(os.path.join(traj_folder_i, "traj_data.pkl"), "wb") as f:
73 | pickle.dump(traj_data_i, f)
74 | # save the image data to disk
75 | for i, img in enumerate(img_data_i):
76 | img.save(os.path.join(traj_folder_i, f"{i}.jpg"))
77 |
78 |
79 | if __name__ == "__main__":
80 | parser = argparse.ArgumentParser()
81 | # get arguments for the recon input dir and the output dir
82 | # add dataset name
83 | parser.add_argument(
84 | "--dataset-name",
85 | "-d",
86 | type=str,
87 | help="name of the dataset (must be in process_config.yaml)",
88 | default="tartan_drive",
89 | required=True,
90 | )
91 | parser.add_argument(
92 | "--input-dir",
93 | "-i",
94 | type=str,
95 | help="path of the datasets with rosbags",
96 | required=True,
97 | )
98 | parser.add_argument(
99 | "--output-dir",
100 | "-o",
101 | default="../datasets/tartan_drive/",
102 | type=str,
103 | help="path for processed dataset (default: ../datasets/tartan_drive/)",
104 | )
105 | # number of trajs to process
106 | parser.add_argument(
107 | "--num-trajs",
108 | "-n",
109 | default=-1,
110 | type=int,
111 | help="number of bags to process (default: -1, all)",
112 | )
113 | # sampling rate
114 | parser.add_argument(
115 | "--sample-rate",
116 | "-s",
117 | default=4.0,
118 | type=float,
119 | help="sampling rate (default: 4.0 hz)",
120 | )
121 |
122 | args = parser.parse_args()
123 | # all caps for the dataset name
124 | print(f"STARTING PROCESSING {args.dataset_name.upper()} DATASET")
125 | main(args)
126 | print(f"FINISHED PROCESSING {args.dataset_name.upper()} DATASET")
127 |
--------------------------------------------------------------------------------
/train/vint_train/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from PIL import Image
4 | from typing import Any, Iterable, Tuple
5 |
6 | import torch
7 | from torchvision import transforms
8 | import torchvision.transforms.functional as TF
9 | import torch.nn.functional as F
10 | import io
11 | from typing import Union
12 |
13 | VISUALIZATION_IMAGE_SIZE = (160, 120)
14 | IMAGE_ASPECT_RATIO = (
15 | 4 / 3
16 | ) # all images are centered cropped to a 4:3 aspect ratio in training
17 |
18 |
19 |
20 | def get_data_path(data_folder: str, f: str, time: int, data_type: str = "image"):
21 | data_ext = {
22 | "image": ".jpg",
23 | # add more data types here
24 | }
25 | return os.path.join(data_folder, f, f"{str(time)}{data_ext[data_type]}")
26 |
27 |
28 | def yaw_rotmat(yaw: float) -> np.ndarray:
29 | return np.array(
30 | [
31 | [np.cos(yaw), -np.sin(yaw), 0.0],
32 | [np.sin(yaw), np.cos(yaw), 0.0],
33 | [0.0, 0.0, 1.0],
34 | ],
35 | )
36 |
37 |
38 | def to_local_coords(
39 | positions: np.ndarray, curr_pos: np.ndarray, curr_yaw: float
40 | ) -> np.ndarray:
41 | """
42 | Convert positions to local coordinates
43 |
44 | Args:
45 | positions (np.ndarray): positions to convert
46 | curr_pos (np.ndarray): current position
47 | curr_yaw (float): current yaw
48 | Returns:
49 | np.ndarray: positions in local coordinates
50 | """
51 | rotmat = yaw_rotmat(curr_yaw)
52 | if positions.shape[-1] == 2:
53 | rotmat = rotmat[:2, :2]
54 | elif positions.shape[-1] == 3:
55 | pass
56 | else:
57 | raise ValueError
58 |
59 | return (positions - curr_pos).dot(rotmat)
60 |
61 |
62 | def calculate_deltas(waypoints: torch.Tensor) -> torch.Tensor:
63 | """
64 | Calculate deltas between waypoints
65 |
66 | Args:
67 | waypoints (torch.Tensor): waypoints
68 | Returns:
69 | torch.Tensor: deltas
70 | """
71 | num_params = waypoints.shape[1]
72 | origin = torch.zeros(1, num_params)
73 | prev_waypoints = torch.concat((origin, waypoints[:-1]), axis=0)
74 | deltas = waypoints - prev_waypoints
75 | if num_params > 2:
76 | return calculate_sin_cos(deltas)
77 | return deltas
78 |
79 |
80 | def calculate_sin_cos(waypoints: torch.Tensor) -> torch.Tensor:
81 | """
82 | Calculate sin and cos of the angle
83 |
84 | Args:
85 | waypoints (torch.Tensor): waypoints
86 | Returns:
87 | torch.Tensor: waypoints with sin and cos of the angle
88 | """
89 | assert waypoints.shape[1] == 3
90 | angle_repr = torch.zeros_like(waypoints[:, :2])
91 | angle_repr[:, 0] = torch.cos(waypoints[:, 2])
92 | angle_repr[:, 1] = torch.sin(waypoints[:, 2])
93 | return torch.concat((waypoints[:, :2], angle_repr), axis=1)
94 |
95 |
96 | def transform_images(
97 | img: Image.Image, transform: transforms, image_resize_size: Tuple[int, int], aspect_ratio: float = IMAGE_ASPECT_RATIO
98 | ):
99 | w, h = img.size
100 | if w > h:
101 | img = TF.center_crop(img, (h, int(h * aspect_ratio))) # crop to the right ratio
102 | else:
103 | img = TF.center_crop(img, (int(w / aspect_ratio), w))
104 | viz_img = img.resize(VISUALIZATION_IMAGE_SIZE)
105 | viz_img = TF.to_tensor(viz_img)
106 | img = img.resize(image_resize_size)
107 | transf_img = transform(img)
108 | return viz_img, transf_img
109 |
110 |
111 | def resize_and_aspect_crop(
112 | img: Image.Image, image_resize_size: Tuple[int, int], aspect_ratio: float = IMAGE_ASPECT_RATIO
113 | ):
114 | w, h = img.size
115 | if w > h:
116 | img = TF.center_crop(img, (h, int(h * aspect_ratio))) # crop to the right ratio
117 | else:
118 | img = TF.center_crop(img, (int(w / aspect_ratio), w))
119 | img = img.resize(image_resize_size)
123 | resize_img = TF.to_tensor(img)
124 | return resize_img
125 |
126 |
127 | def img_path_to_data(path: Union[str, io.BytesIO], image_resize_size: Tuple[int, int]) -> torch.Tensor:
128 | """
129 | Load an image from a path and transform it
130 | Args:
131 | path (str): path to the image
132 | image_resize_size (Tuple[int, int]): size to resize the image to
133 | Returns:
134 | torch.Tensor: resized image as tensor
135 | """
136 | # return transform_images(Image.open(path), transform, image_resize_size, aspect_ratio)
137 | return resize_and_aspect_crop(Image.open(path), image_resize_size)
138 |
139 |
--------------------------------------------------------------------------------
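
to_local_coords rotates and translates world-frame positions into the robot's current frame; a quick worked example:

```python
import numpy as np
from vint_train.data.data_utils import to_local_coords

# robot at (1, 1) facing +y (yaw = pi/2); the query point is one meter ahead of it
positions = np.array([[1.0, 2.0]])
local = to_local_coords(positions, curr_pos=np.array([1.0, 1.0]), curr_yaw=np.pi / 2)
print(local)   # approximately [[1., 0.]] -- straight ahead in the local frame
```
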
/train/vint_train/models/navibridge/vae/vae.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------
2 | # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # This work is licensed under the NVIDIA Source Code License
5 | # for I2SB. To view a copy of this license, see the LICENSE file.
6 | # ---------------------------------------------------------------
7 |
8 | import numpy as np
9 | from tqdm import tqdm
10 | from functools import partial
11 | import torch
12 |
13 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
14 |
15 | from torch_ema import ExponentialMovingAverage
16 | from vint_train.models.navibridge.vae.conditional_mlp_1D_vae import *
17 | import os
18 |
19 | def kl_divergence_normal(mean1, mean2, std1, std2):
20 | kl_loss = ((std2 + 1e-9).log() - (std1 + 1e-9).log()
21 | + (std1.pow(2) + (mean2 - mean1).pow(2))
22 | / (2 * std2.pow(2) + 1e-9) - 0.5).sum(-1).mean()
23 | return kl_loss
24 |
25 |
26 | class VAEModel():
27 | def __init__(self, model_args):
28 | self.net = None
29 | self.ema = None
30 |
31 | self.anneal_factor = 0.0
32 | self.prior_policy = 'gaussian'
33 |
34 | def sample_prior(self, cond, device, num_samples=None, x_prior=None, diffuse_step=None):
35 | num_samples = cond.shape[0]
36 | latent_sample = torch.randn((num_samples, self.net.latent_dim)).to(device)
37 | action_hat = self.net.decoder(torch.cat([cond.flatten(start_dim=1), latent_sample], dim=-1))
38 | action_hat = action_hat.reshape(-1, self.net.len_traj_pred, self.net.action_dim)
39 | return action_hat
40 |
41 | def get_loss(self, obs_img, naction, device):
42 | nobs = obs_img.to(device).float().flatten(start_dim=1)
43 | naction = naction.to(device).float()
44 |
45 | latent_post_dist = self.net.encoder(torch.cat([nobs, naction.flatten(1)], dim=-1))
46 | latent_post_rsample = latent_post_dist.rsample()
47 | latent_post_mean = latent_post_dist.mean
48 | latent_post_std = latent_post_dist.stddev
49 |
50 | latent_prior_mean = torch.zeros_like(latent_post_mean).float().to(device)
51 | latent_prior_std = torch.ones_like(latent_post_std).float().to(device)
52 |
53 | action_rec = self.net.decoder(torch.cat([nobs, latent_post_rsample], dim=-1))
54 | rec_loss = torch.nn.functional.mse_loss(action_rec, naction.flatten(1)) * 10.0
55 | kl_loss = self.anneal_factor * kl_divergence_normal(latent_post_mean, latent_prior_mean, latent_post_std, latent_prior_std)
56 |
57 | self.anneal_factor += 0.0001
58 | self.anneal_factor = 0.1 if self.anneal_factor > 0.1 else self.anneal_factor
59 |
60 | loss = rec_loss + kl_loss
61 | loss_info = {
62 | 'loss': loss,
63 | 'rec_loss': rec_loss,
64 | 'kl_loss': kl_loss,
65 | }
66 | return loss, loss_info
67 |
68 | def log_info(self, writer, log, loss_info, optimizer, itr, num_itr):
69 | writer.add_scalar(itr, 'loss', loss_info['loss'].detach())
70 |
71 | log.info("train_it {}/{} | lr:{} | loss:{}".format(
72 | 1 + itr,
73 | num_itr,
74 | "{:.2e}".format(optimizer.param_groups[0]['lr']),
75 | "{:+.2f}".format(loss_info['loss'].item()),
76 | ))
77 |
78 | def load_model(self, model_args, device):
79 |
80 | if model_args['net_type'] == 'vae_mlp':
81 | self.net = VAEConditionalMLP(
82 | action_dim=model_args['action_dim'],
83 | len_traj_pred=model_args['len_traj_pred'],
84 | global_cond_dim=3 * model_args['image_size'][0] * model_args['image_size'][1] * (model_args['context_size'] + 1),
85 | latent_dim=model_args['latent_dim'],
86 | layer=model_args['layer']
87 | )
88 | else:
89 | raise NotImplementedError
90 |
91 | self.ema = ExponentialMovingAverage(self.net.parameters(), decay=0.99)
92 | if model_args['pretrain']:
93 | checkpoint = torch.load(model_args['ckpt_path'], map_location="cpu")
94 | self.net.load_state_dict(checkpoint['net'])
95 | self.ema.load_state_dict(checkpoint["ema"])
96 |
97 | self.net.to(device)
98 | self.ema.to(device)
99 |
100 | def save_model(self, ckpt_path, epoch):
101 | torch.save({
102 | "net": self.net.state_dict(),
103 | "ema": self.ema.state_dict(),
104 | }, ckpt_path)
105 | print(f"Saved model to {ckpt_path}")
106 |
107 |
108 | def unsqueeze_xdim(z, xdim):
109 | bc_dim = (...,) + (None,) * len(xdim)
110 | return z[bc_dim]
--------------------------------------------------------------------------------
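
VAEModel wraps the conditional MLP VAE that serves as the learned (cvae) prior. A usage sketch, with model_args keys taken from load_model and values matching config/cvae.yaml:

```python
import torch
from vint_train.models.navibridge.vae.vae import VAEModel

model_args = {
    "net_type": "vae_mlp",
    "action_dim": 2,
    "len_traj_pred": 8,
    "image_size": (96, 96),
    "context_size": 3,
    "latent_dim": 64,
    "layer": 3,
    "pretrain": False,             # set True and point ckpt_path at a checkpoint to load weights
    "ckpt_path": "models/cvae.pth",
}

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = VAEModel(model_args)
vae.load_model(model_args, device)

# conditioning: flattened context images, shape (B, 3 * W * H * (context_size + 1))
cond = torch.randn(4, 3 * 96 * 96 * (3 + 1), device=device)
actions = vae.sample_prior(cond, device)   # (4, len_traj_pred, action_dim)
print(actions.shape)
```
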
/train/vint_train/models/navibridge/vae/conditional_mlp_1D_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
5 | import numpy as np
6 |
7 |
8 | class SimpleMLP(nn.Module):
9 | def __init__(self, input_dim, hidden_dim,
10 | output_dim, layers: int, activation=nn.ELU):
11 | super().__init__()
12 | self._output_shape = (output_dim,)
13 | self._layers = layers
14 | self._hidden_size = hidden_dim
15 | self.activation = activation
16 | # For adjusting pytorch to tensorflow
17 | self._feature_size = input_dim
18 | # Defining the structure of the NN
19 | self.model = self.build_model()
20 | self.soft_plus = nn.Softplus()
21 |
22 | self._min = 0.01
23 | self._max = 10.0
24 |
25 | def build_model(self):
26 | model = [nn.Linear(self._feature_size, self._hidden_size)]
27 | model += [self.activation()]
28 | for i in range(self._layers - 1):
29 | model += [nn.Linear(self._hidden_size, self._hidden_size)]
30 | model += [self.activation()]
31 | model += [nn.Linear(self._hidden_size, int(np.prod(self._output_shape)))]
32 | return nn.Sequential(*model)
33 |
34 | def forward(self, features):
35 | shape_len = len(features.shape)
36 | if shape_len == 3:
37 | batch = features.shape[1]
38 | length = features.shape[0]
39 | features = features.reshape(-1, features.shape[-1])
40 | outputs = self.model(features)
41 |
42 | if shape_len == 3:
43 | outputs = outputs.reshape(length, batch, -1)
44 | return outputs
45 |
46 |
47 | class NormalMLP(nn.Module):
48 | def __init__(self, input_dim, hidden_dim,
49 | output_dim, layers: int, activation=nn.ELU):
50 | super().__init__()
51 | self._output_shape = (output_dim,)
52 | self._layers = layers
53 | self._hidden_size = hidden_dim
54 | self.activation = activation
55 | # For adjusting pytorch to tensorflow
56 | self._feature_size = input_dim
57 | # Defining the structure of the NN
58 | self.model = self.build_model()
59 | self.soft_plus = nn.Softplus()
60 |
61 | self._min = 0.01
62 | self._max = 10.0
63 |
64 | def build_model(self):
65 | model = [nn.Linear(self._feature_size, self._hidden_size)]
66 | model += [self.activation()]
67 | for i in range(self._layers - 1):
68 | model += [nn.Linear(self._hidden_size, self._hidden_size)]
69 | model += [self.activation()]
70 | model += [nn.Linear(self._hidden_size, 2 * int(np.prod(self._output_shape)))]
71 | return nn.Sequential(*model)
72 |
73 | def forward(self, features):
74 | shape_len = len(features.shape)
75 | if shape_len == 3:
76 | batch = features.shape[1]
77 | length = features.shape[0]
78 | features = features.reshape(-1, features.shape[-1])
79 | dist_inputs = self.model(features)
80 | reshaped_inputs_mean = dist_inputs[..., :np.prod(self._output_shape)]
81 | reshaped_inputs_std = dist_inputs[..., np.prod(self._output_shape):]
82 |
83 | reshaped_inputs_std = torch.clamp(self.soft_plus(reshaped_inputs_std), min=self._min, max=self._max)
84 |
85 | if shape_len == 3:
86 | reshaped_inputs_mean = reshaped_inputs_mean.reshape(length, batch, -1)
87 | reshaped_inputs_std = reshaped_inputs_std.reshape(length, batch, -1)
88 | return torch.distributions.independent.Independent(
89 | torch.distributions.Normal(reshaped_inputs_mean, reshaped_inputs_std), len(self._output_shape))
90 |
91 |
92 | class VAEConditionalMLP(nn.Module):
93 | def __init__(self,
94 | action_dim,
95 | len_traj_pred,
96 | global_cond_dim,
97 | hidden_dim=512,
98 | latent_dim=64,
99 | layer=3
100 | ):
101 | """
102 | input_dim: Dim of actions.
103 | global_cond_dim: Dim of global conditioning applied with FiLM
104 | in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
105 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
106 | down_dims: Channel size for each UNet level.
107 | The length of this array determines numebr of levels.
108 | kernel_size: Conv kernel size
109 | n_groups: Number of groups for GroupNorm
110 | """
111 |
112 | super().__init__()
113 | input_dim = action_dim * len_traj_pred
114 | self.encoder = NormalMLP(input_dim=input_dim + global_cond_dim, hidden_dim=hidden_dim,
115 | output_dim=latent_dim, layers=layer)
116 |
117 | self.decoder = SimpleMLP(input_dim=latent_dim + global_cond_dim, hidden_dim=hidden_dim,
118 | output_dim=input_dim, layers=layer)
119 |
120 | self.latent_dim = latent_dim
121 | self.action_dim = action_dim
122 | self.len_traj_pred = len_traj_pred
123 |
124 | print("number of parameters: {:e}".format(
125 | sum(p.numel() for p in self.parameters()))
126 | )
127 |
--------------------------------------------------------------------------------
/train/vint_train/models/navibridge/ddbm/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | from scipy.stats import norm
6 | import torch.distributed as dist
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | if name == "uniform":
10 | return UniformSampler(diffusion)
11 | elif name == "real-uniform":
12 | return RealUniformSampler(diffusion)
13 | elif name == "loss-second-moment":
14 | return LossSecondMomentResampler(diffusion)
15 | elif name == "lognormal":
16 | return LogNormalSampler(diffusion)
17 | else:
18 | raise NotImplementedError(f"unknown schedule sampler: {name}")
19 |
20 | class ScheduleSampler(ABC):
21 | @abstractmethod
22 | def weights(self):
23 |         """Get a numpy array of weights, one per diffusion step (need not
24 |         be normalized, but must be positive)."""
25 |
26 | def sample(self, batch_size, device):
27 | w = self.weights()
28 | p = w / np.sum(w)
29 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
30 | indices = th.from_numpy(indices_np).long().to(device)
31 | weights_np = 1 / (len(p) * p[indices_np])
32 | weights = th.from_numpy(weights_np).float().to(device)
33 | return indices, weights
34 |
35 | class UniformSampler(ScheduleSampler):
36 | def __init__(self, diffusion):
37 | self.diffusion = diffusion
38 | self._weights = np.ones([diffusion.num_timesteps])
39 |
40 | def weights(self):
41 | return self._weights
42 |
43 | class RealUniformSampler:
44 | def __init__(self, diffusion):
45 | self.diffusion = diffusion
46 | self.sigma_max = diffusion.sigma_max
47 | self.sigma_min = diffusion.sigma_min
48 |
49 | def sample(self, batch_size, device):
50 | ts = th.rand(batch_size).to(device) * (self.sigma_max - self.sigma_min) + self.sigma_min
51 | return ts, th.ones_like(ts)
52 |
53 | class LossAwareSampler(ScheduleSampler):
54 | def update_with_local_losses(self, local_ts, local_losses):
55 | batch_sizes = [
56 | th.tensor([0], dtype=th.int32, device=local_ts.device)
57 | for _ in range(dist.get_world_size())
58 | ]
59 | dist.all_gather(
60 | batch_sizes,
61 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
62 | )
63 |
64 | batch_sizes = [x.item() for x in batch_sizes]
65 | max_bs = max(batch_sizes)
66 |
67 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
68 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
69 | dist.all_gather(timestep_batches, local_ts)
70 | dist.all_gather(loss_batches, local_losses)
71 | timesteps = [
72 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
73 | ]
74 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
75 | self.update_with_all_losses(timesteps, losses)
76 |
77 | @abstractmethod
78 | def update_with_all_losses(self, ts, losses):
79 |         """Update the reweighting from (timestep, loss) pairs gathered
80 |         across all ranks."""
81 |
82 | class LossSecondMomentResampler(LossAwareSampler):
83 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
84 | self.diffusion = diffusion
85 | self.history_per_term = history_per_term
86 | self.uniform_prob = uniform_prob
87 | self._loss_history = np.zeros(
88 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
89 | )
90 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int is no longer available in recent NumPy
91 |
92 | def weights(self):
93 | if not self._warmed_up():
94 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
95 | weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
96 | weights /= np.sum(weights)
97 | weights *= 1 - self.uniform_prob
98 | weights += self.uniform_prob / len(weights)
99 | return weights
100 |
101 | def update_with_all_losses(self, ts, losses):
102 | for t, loss in zip(ts, losses):
103 | if self._loss_counts[t] == self.history_per_term:
104 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
105 | self._loss_history[t, -1] = loss
106 | else:
107 | self._loss_history[t, self._loss_counts[t]] = loss
108 | self._loss_counts[t] += 1
109 |
110 | def _warmed_up(self):
111 | return (self._loss_counts == self.history_per_term).all()
112 |
113 | class LogNormalSampler:
114 | def __init__(self, diffusion, p_mean=-1.2, p_std=1.2, even=False):
115 | self.p_mean = p_mean
116 | self.p_std = p_std
117 | self.even = even
118 | if self.even:
119 | self.inv_cdf = lambda x: norm.ppf(x, loc=p_mean, scale=p_std)
120 | self.rank, self.size = dist.get_rank(), dist.get_world_size()
121 |
122 | def sample(self, bs, device):
123 | if self.even:
124 | start_i, end_i = self.rank * bs, (self.rank + 1) * bs
125 | global_batch_size = self.size * bs
126 | locs = (th.arange(start_i, end_i) + th.rand(bs)) / global_batch_size
127 | log_sigmas = th.tensor(self.inv_cdf(locs), dtype=th.float32, device=device)
128 | else:
129 | log_sigmas = self.p_mean + self.p_std * th.randn(bs, device=device)
130 |
131 | sigmas = th.exp(log_sigmas)
132 | weights = th.ones_like(sigmas)
133 | return sigmas, weights
134 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Prior Does Matter: Visual Navigation via Denoising Diffusion Bridge Models (NaviBridger)
2 |
3 | > 🏆 Accepted at **CVPR 2025**
4 | > 🔗 [Github](https://github.com/hren20/NaiviBridger) | [arXiv](https://arxiv.org/abs/2504.10041)
5 |
6 |
7 |
8 |
9 |
10 | ---
11 |
12 | ## 📌 TLDR
13 |
14 | NaviBridger is a novel framework for visual navigation built upon **Denoising Diffusion Bridge Models (DDBMs)**. Unlike traditional diffusion policies that start from Gaussian noise, NaviBridger leverages **prior actions** (rule-based or learned) to guide the denoising process, accelerating convergence and improving trajectory accuracy.
15 |
16 | ---
17 |
18 | ## 🛠️ Key Features
19 | - 🔧 DDBM-based policy generation from arbitrary priors
20 | - 🔁 Unified framework supporting Gaussian, rule-based, and learning-based priors
21 | - 🏃♂️ Real-world deployment support on mobile robots (e.g., Diablo + Jetson Orin AGX)
22 |
23 | ---
24 |
25 | ## ✅ TODO List
26 |
27 | - [x] Model weights released
28 | - [ ] Deployment code updates
29 | - [ ] A refactored version of the code (in the coming weeks)
30 |
31 | ---
32 |
33 | ## 📁 Directory Overview
34 |
35 | ```
36 | navibridge/
37 | ├── train/ # Training code and dataset processing
38 | │ ├── vint_train/ # NaviBridger models, configs, and datasets
39 | │ ├── train.py # Training entry point
40 | │ ├── process_*.py # Data preprocessing scripts
41 | │ └── train_environment.yml # Conda setup for training
42 | ├── deployment/ # Inference and deployment
43 | │ ├── src/navibridger_inference.py
44 | │ ├── config/params.yaml # Inference config
45 | │ ├── deployment_environment.yaml
46 | │ └── model_weights/ # Place for .pth model weights and corresponding .yaml config
47 | └── README.md # This file
48 | ```
49 |
50 | ---
51 |
52 | ## ⚙️ Setup
53 |
54 | ### 🧪 Environment (Training)
55 |
56 | ```bash
57 | conda env create -f train/train_environment.yml
58 | conda activate navibridge_train
59 | pip install -e train/
60 | git clone git@github.com:real-stanford/diffusion_policy.git
61 | pip install -e diffusion_policy/
62 | ```
63 |
64 | ### 💻 Environment (Deployment)
65 |
66 | ```bash
67 | conda env create -f deployment/deployment_environment.yaml
68 | conda activate navibridge
69 | pip install -e train/
70 | pip install -e diffusion_policy/
71 | ```
72 |
73 | ---
74 |
75 | ## 📦 Data Preparation
76 |
77 | 1. Download public datasets:
78 | - [RECON](https://sites.google.com/view/recon-robot/dataset)
79 | - [SCAND](https://www.cs.utexas.edu/~xiao/SCAND/)
80 | - [GoStanford2](https://cvgl.stanford.edu/gonet/dataset/)
81 | - [SACSoN](https://sites.google.com/view/sacson-review/huron-dataset)
82 |
83 | 2. Process datasets:
84 | ```bash
85 | python train/process_recon.py # or process_bags.py
86 | python train/data_split.py --dataset <dataset_name>
87 | ```
88 |
89 | 3. Expected format:
90 | ```
91 | dataset_name/
92 | ├── traj1/
93 | │ ├── 0.jpg ... T_1.jpg
94 | │ └── traj_data.pkl
95 | └── ...
96 | ```
97 |
98 | After `data_split.py`, you should have:
99 | ```
100 | train/vint_train/data/data_splits/
101 | └── <dataset_name>/
102 | ├── train/traj_names.txt
103 | └── test/traj_names.txt
104 | ```
105 |
106 | ---
107 |
108 | ## 🧠 Model Training
109 |
110 | ```bash
111 | cd train/
112 | python train.py -c config/navibridge.yaml  # select the prior type via prior_policy (see the config sketch below)
113 | ```
114 |
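The prior used for denoising is selected by the `prior_policy` key in `config/navibridge.yaml`. Below is a hedged sketch of the relevant keys; the key names follow the model-creation code in `train/vint_train/models/model_utils.py` and the defaults in `navibridge.py`, but the values shown are illustrative rather than recommended settings.

```yaml
# Illustrative sketch only; see config/navibridge.yaml for the full configuration.
prior_policy: handcraft       # one of: gaussian, handcraft, cvae

# Keys used by the handcrafted (rule-based) prior
class_num: 5
angle_ranges: [[0, 67.5], [67.5, 112.5], [112.5, 180], [180, 270], [270, 360]]
min_std_angle: 5.0
max_std_angle: 20.0
min_std_length: 1.0
max_std_length: 5.0

len_traj_pred: 8              # number of predicted waypoints
action_dim: 2                 # (x, y) per waypoint
```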
115 | ---
116 |
117 | For the learning-based prior, train the CVAE first (its checkpoint is then referenced from the main config, as sketched below):
118 | ```bash
119 | python train.py -c config/cvae.yaml
120 | ```
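Once the CVAE is trained, point the NaviBridger config at its checkpoint so it can serve as the prior. A minimal sketch with placeholder values (the key names follow `load_prior_model` in `train/vint_train/models/model_utils.py`):

```yaml
prior_policy: cvae
ckpt_path: <path to the trained CVAE .pth>    # placeholder
net_type: <net spec used to train the CVAE>   # placeholder; must match config/cvae.yaml
```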
121 |
122 | You can download pretrained `*.pth` weights from [this link](https://drive.google.com/drive/folders/14YhZnqFH9M6Y2fJc6LJwnO1eBIEeDAMC?usp=sharing) to try out inference.
123 |
124 | ---
125 |
126 | ## 🚀 Inference Demo
127 |
128 | 1. Place your trained model and config in:
129 |
130 | ```
131 | deployment/model_weights/*.pth
132 | deployment/model_weights/*.yaml
133 | ```
134 | 2. Adjust the model paths in `deployment/config/models.yaml` (see the example sketch after this list).
135 | 3. Prepare input images (minimum 4): `0.png`, `1.png`, etc.
136 | Adjust input directory path in `deployment/config/params.yaml`.
137 |
138 | 4. Run:
139 |
140 | ```bash
141 | python deployment/src/navibridger_inference.py --model navibridge_cvae  # model name must match a key in deployment/config/models.yaml
142 | ```
143 |
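For reference, the inference script looks up `config_path` and `ckpt_path` under the model name passed via `--model` (see `deployment/src/navibridger_inference.py`), and reads the input image directory from `image_path` in `deployment/config/params.yaml`. A hedged sketch of `deployment/config/models.yaml` with placeholder paths:

```yaml
navibridge_cvae:
  config_path: ../model_weights/navibridge_cvae.yaml   # placeholder paths, relative to deployment/src/
  ckpt_path: ../model_weights/navibridge_cvae.pth
navibridge_noise:
  config_path: ../model_weights/navibridge_noise.yaml
  ckpt_path: ../model_weights/navibridge_noise.pth
```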
144 | ---
145 |
146 | ## 🤖 Hardware Tested
147 | Below is the deployment platform we tested on; comparable hardware can be substituted freely.
148 |
149 | - NVIDIA Jetson Orin AGX
150 | - Intel RealSense D435i
151 | - Diablo wheeled-legged robot
152 |
153 | > 📸 RGB-only input, no depth or LiDAR required.
154 |
155 | ---
156 |
157 | ## 🧪 Citation
158 |
159 | ```bibtex
160 | @inproceedings{ren2025prior,
161 | title={Prior does matter: Visual navigation via denoising diffusion bridge models},
162 | author={Ren, Hao and Zeng, Yiming and Bi, Zetong and Wan, Zhaoliang and Huang, Junlong and Cheng, Hui},
163 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
164 | pages={12100--12110},
165 | year={2025}
166 | }
167 | ```
168 |
169 | ---
170 |
171 | ## 📜 License
172 |
173 | This codebase is released under the [MIT License](LICENSE).
174 |
175 | ## Acknowledgment
176 | NaviBridger is inspired by the contributions of the following works to the open-source community: [DDBM](https://github.com/alexzhou907/DDBM), [NoMaD](https://github.com/robodhruv/visualnav-transformer), and [BRIDGER](https://github.com/clear-nus/bridger). We thank the authors for sharing their outstanding work.
--------------------------------------------------------------------------------
/train/vint_train/models/navibridge/ddbm/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 | import numpy as np
10 | import torch.nn.functional as F
11 |
12 |
13 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
14 | class SiLU(nn.Module):
15 | def forward(self, x):
16 | return x * th.sigmoid(x)
17 |
18 |
19 | class GroupNorm32(nn.GroupNorm):
20 | def forward(self, x):
21 | return super().forward(x.float()).type(x.dtype)
22 |
23 |
24 | def conv_nd(dims, *args, **kwargs):
25 | """
26 | Create a 1D, 2D, or 3D convolution module.
27 | """
28 | if dims == 1:
29 | return nn.Conv1d(*args, **kwargs)
30 | elif dims == 2:
31 | return nn.Conv2d(*args, **kwargs)
32 | elif dims == 3:
33 | return nn.Conv3d(*args, **kwargs)
34 | raise ValueError(f"unsupported dimensions: {dims}")
35 |
36 |
37 | def linear(*args, **kwargs):
38 | """
39 | Create a linear module.
40 | """
41 | return nn.Linear(*args, **kwargs)
42 |
43 |
44 | def avg_pool_nd(dims, *args, **kwargs):
45 | """
46 | Create a 1D, 2D, or 3D average pooling module.
47 | """
48 | if dims == 1:
49 | return nn.AvgPool1d(*args, **kwargs)
50 | elif dims == 2:
51 | return nn.AvgPool2d(*args, **kwargs)
52 | elif dims == 3:
53 | return nn.AvgPool3d(*args, **kwargs)
54 | raise ValueError(f"unsupported dimensions: {dims}")
55 |
56 |
57 | def update_ema(target_params, source_params, rate=0.99):
58 | """
59 | Update target parameters to be closer to those of source parameters using
60 | an exponential moving average.
61 |
62 | :param target_params: the target parameter sequence.
63 | :param source_params: the source parameter sequence.
64 | :param rate: the EMA rate (closer to 1 means slower).
65 | """
66 | for targ, src in zip(target_params, source_params):
67 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
68 |
69 |
70 | def zero_module(module):
71 | """
72 | Zero out the parameters of a module and return it.
73 | """
74 | for p in module.parameters():
75 | p.detach().zero_()
76 | return module
77 |
78 |
79 | def scale_module(module, scale):
80 | """
81 | Scale the parameters of a module and return it.
82 | """
83 | for p in module.parameters():
84 | p.detach().mul_(scale)
85 | return module
86 |
87 |
88 | def mean_flat(tensor):
89 | """
90 | Take the mean over all non-batch dimensions.
91 | """
92 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
93 |
94 |
95 | def append_dims(x, target_dims):
96 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
97 | dims_to_append = target_dims - x.ndim
98 | if dims_to_append < 0:
99 | raise ValueError(
100 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
101 | )
102 | return x[(...,) + (None,) * dims_to_append]
103 |
104 |
105 | def append_zero(x):
106 | return th.cat([x, x.new_zeros([1])])
107 |
108 |
109 | def normalization(channels):
110 | """
111 | Make a standard normalization layer.
112 |
113 | :param channels: number of input channels.
114 | :return: an nn.Module for normalization.
115 | """
116 | return GroupNorm32(32, channels)
117 |
118 |
119 | def timestep_embedding(timesteps, dim, max_period=10000):
120 | """
121 | Create sinusoidal timestep embeddings.
122 |
123 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
124 | These may be fractional.
125 | :param dim: the dimension of the output.
126 | :param max_period: controls the minimum frequency of the embeddings.
127 | :return: an [N x dim] Tensor of positional embeddings.
128 | """
129 | half = dim // 2
130 | freqs = th.exp(
131 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
132 | ).to(device=timesteps.device)
133 | args = timesteps[:, None].float() * freqs[None]
134 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
135 | if dim % 2:
136 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
137 | return embedding
138 |
139 |
140 | def checkpoint(func, inputs, params, flag):
141 | """
142 | Evaluate a function without caching intermediate activations, allowing for
143 | reduced memory at the expense of extra compute in the backward pass.
144 |
145 | :param func: the function to evaluate.
146 | :param inputs: the argument sequence to pass to `func`.
147 | :param params: a sequence of parameters `func` depends on but does not
148 | explicitly take as arguments.
149 | :param flag: if False, disable gradient checkpointing.
150 | """
151 | if flag:
152 | args = tuple(inputs) + tuple(params)
153 | return CheckpointFunction.apply(func, len(inputs), *args)
154 | else:
155 | return func(*inputs)
156 |
157 |
158 | class CheckpointFunction(th.autograd.Function):
159 | @staticmethod
160 | def forward(ctx, run_function, length, *args):
161 | ctx.run_function = run_function
162 | ctx.input_tensors = list(args[:length])
163 | ctx.input_params = list(args[length:])
164 | with th.no_grad():
165 | output_tensors = ctx.run_function(*ctx.input_tensors)
166 | return output_tensors
167 |
168 | @staticmethod
169 | def backward(ctx, *output_grads):
170 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
171 | with th.enable_grad():
172 | # Fixes a bug where the first op in run_function modifies the
173 | # Tensor storage in place, which is not allowed for detach()'d
174 | # Tensors.
175 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
176 | output_tensors = ctx.run_function(*shallow_copies)
177 | input_grads = th.autograd.grad(
178 | output_tensors,
179 | ctx.input_tensors + ctx.input_params,
180 | output_grads,
181 | allow_unused=True,
182 | )
183 | del ctx.input_tensors
184 | del ctx.input_params
185 | del output_tensors
186 | return (None, None) + input_grads
187 |
--------------------------------------------------------------------------------
/train/vint_train/visualizing/distance_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 | import numpy as np
4 | from typing import List, Optional, Tuple
5 | from vint_train.visualizing.visualize_utils import numpy_to_img
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | def visualize_dist_pred(
10 | batch_obs_images: np.ndarray,
11 | batch_goal_images: np.ndarray,
12 | batch_dist_preds: np.ndarray,
13 | batch_dist_labels: np.ndarray,
14 | eval_type: str,
15 | save_folder: str,
16 | epoch: int,
17 | num_images_preds: int = 8,
18 | use_wandb: bool = True,
19 | display: bool = False,
20 | rounding: int = 4,
21 | dist_error_threshold: float = 3.0,
22 | ):
23 | """
24 | Visualize the distance classification predictions and labels for an observation-goal image pair.
25 |
26 | Args:
27 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels]
28 | batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels]
29 | batch_dist_preds (np.ndarray): batch of distance predictions [batch_size]
30 | batch_dist_labels (np.ndarray): batch of distance labels [batch_size]
31 | eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.)
32 | epoch (int): current epoch number
33 | num_images_preds (int): number of images to visualize
34 | use_wandb (bool): whether to use wandb to log the images
35 | save_folder (str): folder to save the images. If None, will not save the images
36 | display (bool): whether to display the images
37 | rounding (int): number of decimal places to round the distance predictions and labels
38 | dist_error_threshold (float): distance error threshold for classifying the distance prediction as correct or incorrect (only used for visualization purposes)
39 | """
40 | visualize_path = os.path.join(
41 | save_folder,
42 | "visualize",
43 | eval_type,
44 | f"epoch{epoch}",
45 | "dist_classification",
46 | )
47 | if not os.path.isdir(visualize_path):
48 | os.makedirs(visualize_path)
49 | assert (
50 | len(batch_obs_images)
51 | == len(batch_goal_images)
52 | == len(batch_dist_preds)
53 | == len(batch_dist_labels)
54 | )
55 | batch_size = batch_obs_images.shape[0]
56 | wandb_list = []
57 | for i in range(min(batch_size, num_images_preds)):
58 | dist_pred = np.round(batch_dist_preds[i], rounding)
59 | dist_label = np.round(batch_dist_labels[i], rounding)
60 | obs_image = numpy_to_img(batch_obs_images[i])
61 | goal_image = numpy_to_img(batch_goal_images[i])
62 |
63 | save_path = None
64 | if save_folder is not None:
65 | save_path = os.path.join(visualize_path, f"{i}.png")
66 | text_color = "black"
67 | if abs(dist_pred - dist_label) > dist_error_threshold:
68 | text_color = "red"
69 |
70 | display_distance_pred(
71 | [obs_image, goal_image],
72 | ["Observation", "Goal"],
73 | dist_pred,
74 | dist_label,
75 | text_color,
76 | save_path,
77 | display,
78 | )
79 | if use_wandb:
80 | wandb_list.append(wandb.Image(save_path))
81 | if use_wandb:
82 | wandb.log({f"{eval_type}_dist_prediction": wandb_list}, commit=False)
83 |
84 |
85 | def visualize_dist_pairwise_pred(
86 | batch_obs_images: np.ndarray,
87 | batch_close_images: np.ndarray,
88 | batch_far_images: np.ndarray,
89 | batch_close_preds: np.ndarray,
90 | batch_far_preds: np.ndarray,
91 | batch_close_labels: np.ndarray,
92 | batch_far_labels: np.ndarray,
93 | eval_type: str,
94 | save_folder: str,
95 | epoch: int,
96 | num_images_preds: int = 8,
97 | use_wandb: bool = True,
98 | display: bool = False,
99 | rounding: int = 4,
100 | ):
101 | """
102 |     Visualize the pairwise distance classification predictions and labels for an observation image paired with close and far goal images.
103 |
104 | Args:
105 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels]
106 | batch_close_images (np.ndarray): batch of close goal images [batch_size, height, width, channels]
107 | batch_far_images (np.ndarray): batch of far goal images [batch_size, height, width, channels]
108 | batch_close_preds (np.ndarray): batch of close predictions [batch_size]
109 | batch_far_preds (np.ndarray): batch of far predictions [batch_size]
110 | batch_close_labels (np.ndarray): batch of close labels [batch_size]
111 | batch_far_labels (np.ndarray): batch of far labels [batch_size]
112 | eval_type (string): {data_type}_{eval_type} (e.g. recon_train, gs_test, etc.)
113 | save_folder (str): folder to save the images. If None, will not save the images
114 | epoch (int): current epoch number
115 | num_images_preds (int): number of images to visualize
116 | use_wandb (bool): whether to use wandb to log the images
117 | display (bool): whether to display the images
118 | rounding (int): number of decimal places to round the distance predictions and labels
119 | """
120 | visualize_path = os.path.join(
121 | save_folder,
122 | "visualize",
123 | eval_type,
124 | f"epoch{epoch}",
125 | "pairwise_dist_classification",
126 | )
127 | if not os.path.isdir(visualize_path):
128 | os.makedirs(visualize_path)
129 | assert (
130 | len(batch_obs_images)
131 | == len(batch_close_images)
132 | == len(batch_far_images)
133 | == len(batch_close_preds)
134 | == len(batch_far_preds)
135 | == len(batch_close_labels)
136 | == len(batch_far_labels)
137 | )
138 | batch_size = batch_obs_images.shape[0]
139 | wandb_list = []
140 | for i in range(min(batch_size, num_images_preds)):
141 | close_dist_pred = np.round(batch_close_preds[i], rounding)
142 | far_dist_pred = np.round(batch_far_preds[i], rounding)
143 | close_dist_label = np.round(batch_close_labels[i], rounding)
144 | far_dist_label = np.round(batch_far_labels[i], rounding)
145 | obs_image = numpy_to_img(batch_obs_images[i])
146 | close_image = numpy_to_img(batch_close_images[i])
147 | far_image = numpy_to_img(batch_far_images[i])
148 |
149 | save_path = None
150 | if save_folder is not None:
151 | save_path = os.path.join(visualize_path, f"{i}.png")
152 |
153 | if close_dist_pred < far_dist_pred:
154 | text_color = "black"
155 | else:
156 | text_color = "red"
157 |
158 | display_distance_pred(
159 | [obs_image, close_image, far_image],
160 | ["Observation", "Close Goal", "Far Goal"],
161 | f"close_pred = {close_dist_pred}, far_pred = {far_dist_pred}",
162 | f"close_label = {close_dist_label}, far_label = {far_dist_label}",
163 | text_color,
164 | save_path,
165 | display,
166 | )
167 | if use_wandb:
168 | wandb_list.append(wandb.Image(save_path))
169 | if use_wandb:
170 | wandb.log({f"{eval_type}_pairwise_classification": wandb_list}, commit=False)
171 |
172 |
173 | def display_distance_pred(
174 | imgs: list,
175 | titles: list,
176 | dist_pred: float,
177 | dist_label: float,
178 | text_color: str = "black",
179 | save_path: Optional[str] = None,
180 | display: bool = False,
181 | ):
182 |     # plt.subplots creates the figure used below
183 |     fig, ax = plt.subplots(1, len(imgs))
184 |
185 | plt.suptitle(f"prediction: {dist_pred}\nlabel: {dist_label}", color=text_color)
186 |
187 | for axis, img, title in zip(ax, imgs, titles):
188 | axis.imshow(img)
189 | axis.set_title(title)
190 | axis.xaxis.set_visible(False)
191 | axis.yaxis.set_visible(False)
192 |
193 | # make the plot large
194 | fig.set_size_inches((18.5 / 3) * len(imgs), 10.5)
195 |
196 | if save_path is not None:
197 | fig.savefig(
198 | save_path,
199 | bbox_inches="tight",
200 | )
201 | if not display:
202 | plt.close(fig)
203 |
--------------------------------------------------------------------------------
/train/vint_train/models/navibridge/navibridg_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | from typing import List, Dict, Optional, Tuple, Callable
6 | from efficientnet_pytorch import EfficientNet
7 | from vint_train.models.navibridge.self_attention import PositionalEncoding
8 |
9 | class NaviBridge_Encoder(nn.Module):
10 | def __init__(
11 | self,
12 | context_size: int = 5,
13 | obs_encoder: Optional[str] = "efficientnet-b0",
14 | obs_encoding_size: Optional[int] = 512,
15 | mha_num_attention_heads: Optional[int] = 2,
16 | mha_num_attention_layers: Optional[int] = 2,
17 | mha_ff_dim_factor: Optional[int] = 4,
18 | ) -> None:
19 | """
20 | Encoder class
21 | """
22 | super().__init__()
23 | self.obs_encoding_size = obs_encoding_size
24 | self.goal_encoding_size = obs_encoding_size
25 | self.context_size = context_size
26 |
27 | # Initialize the observation encoder
28 | if obs_encoder.split("-")[0] == "efficientnet":
29 | self.obs_encoder = EfficientNet.from_name(obs_encoder, in_channels=3) # context
30 | self.obs_encoder = replace_bn_with_gn(self.obs_encoder)
31 | self.num_obs_features = self.obs_encoder._fc.in_features
32 | self.obs_encoder_type = "efficientnet"
33 | else:
34 | raise NotImplementedError
35 |
36 | # Initialize the goal encoder
37 | self.goal_encoder = EfficientNet.from_name("efficientnet-b0", in_channels=6) # obs+goal
38 | self.goal_encoder = replace_bn_with_gn(self.goal_encoder)
39 | self.num_goal_features = self.goal_encoder._fc.in_features
40 |
41 | # Initialize compression layers if necessary
42 | if self.num_obs_features != self.obs_encoding_size:
43 | self.compress_obs_enc = nn.Linear(self.num_obs_features, self.obs_encoding_size)
44 | else:
45 | self.compress_obs_enc = nn.Identity()
46 |
47 | if self.num_goal_features != self.goal_encoding_size:
48 | self.compress_goal_enc = nn.Linear(self.num_goal_features, self.goal_encoding_size)
49 | else:
50 | self.compress_goal_enc = nn.Identity()
51 |
52 | # Initialize positional encoding and self-attention layers
53 | self.positional_encoding = PositionalEncoding(self.obs_encoding_size, max_seq_len=self.context_size + 2)
54 | self.sa_layer = nn.TransformerEncoderLayer(
55 | d_model=self.obs_encoding_size,
56 | nhead=mha_num_attention_heads,
57 | dim_feedforward=mha_ff_dim_factor*self.obs_encoding_size,
58 | activation="gelu",
59 | batch_first=True,
60 | norm_first=True
61 | )
62 | self.sa_encoder = nn.TransformerEncoder(self.sa_layer, num_layers=mha_num_attention_layers)
63 |
64 | # Definition of the goal mask (convention: 0 = no mask, 1 = mask)
65 | self.goal_mask = torch.zeros((1, self.context_size + 2), dtype=torch.bool)
66 | self.goal_mask[:, -1] = True # Mask out the goal
67 | self.no_mask = torch.zeros((1, self.context_size + 2), dtype=torch.bool)
68 | self.all_masks = torch.cat([self.no_mask, self.goal_mask], dim=0)
69 | self.avg_pool_mask = torch.cat([1 - self.no_mask.float(), (1 - self.goal_mask.float()) * ((self.context_size + 2)/(self.context_size + 1))], dim=0)
70 |
71 |
72 | def forward(self, obs_img: torch.tensor, goal_img: torch.tensor, input_goal_mask: torch.tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
73 |
74 | device = obs_img.device
75 |
76 | # Initialize the goal encoding
77 | goal_encoding = torch.zeros((obs_img.size()[0], 1, self.goal_encoding_size)).to(device)
78 |
79 | # Get the input goal mask
80 | if input_goal_mask is not None:
81 | goal_mask = input_goal_mask.to(device)
82 |
83 | # Get the goal encoding
84 | obsgoal_img = torch.cat([obs_img[:, 3*self.context_size:, :, :], goal_img], dim=1) # concatenate the obs image/context and goal image --> non image goal?
85 | obsgoal_encoding = self.goal_encoder.extract_features(obsgoal_img) # get encoding of this img
86 | obsgoal_encoding = self.goal_encoder._avg_pooling(obsgoal_encoding) # avg pooling
87 |
88 | if self.goal_encoder._global_params.include_top:
89 | obsgoal_encoding = obsgoal_encoding.flatten(start_dim=1)
90 | obsgoal_encoding = self.goal_encoder._dropout(obsgoal_encoding)
91 | obsgoal_encoding = self.compress_goal_enc(obsgoal_encoding)
92 |
93 | if len(obsgoal_encoding.shape) == 2:
94 | obsgoal_encoding = obsgoal_encoding.unsqueeze(1)
95 | assert obsgoal_encoding.shape[2] == self.goal_encoding_size
96 | goal_encoding = obsgoal_encoding
97 |
98 | # Get the observation encoding
99 | obs_img = torch.split(obs_img, 3, dim=1)
100 | obs_img = torch.concat(obs_img, dim=0)
101 |
102 | obs_encoding = self.obs_encoder.extract_features(obs_img)
103 | obs_encoding = self.obs_encoder._avg_pooling(obs_encoding)
104 | if self.obs_encoder._global_params.include_top:
105 | obs_encoding = obs_encoding.flatten(start_dim=1)
106 | obs_encoding = self.obs_encoder._dropout(obs_encoding)
107 | obs_encoding = self.compress_obs_enc(obs_encoding)
108 | obs_encoding = obs_encoding.unsqueeze(1)
109 | obs_encoding = obs_encoding.reshape((self.context_size+1, -1, self.obs_encoding_size))
110 | obs_encoding = torch.transpose(obs_encoding, 0, 1)
111 | obs_encoding = torch.cat((obs_encoding, goal_encoding), dim=1)
112 |
113 | # If a goal mask is provided, mask some of the goal tokens
114 | if goal_mask is not None:
115 | no_goal_mask = goal_mask.long()
116 | src_key_padding_mask = torch.index_select(self.all_masks.to(device), 0, no_goal_mask)
117 | else:
118 | src_key_padding_mask = None
119 |
120 | # Apply positional encoding
121 | if self.positional_encoding:
122 | obs_encoding = self.positional_encoding(obs_encoding)
123 |
124 | obs_encoding_tokens = self.sa_encoder(obs_encoding, src_key_padding_mask=src_key_padding_mask)
125 | if src_key_padding_mask is not None:
126 | avg_mask = torch.index_select(self.avg_pool_mask.to(device), 0, no_goal_mask).unsqueeze(-1)
127 | obs_encoding_tokens = obs_encoding_tokens * avg_mask
128 | obs_encoding_tokens = torch.mean(obs_encoding_tokens, dim=1)
129 |
130 | return obs_encoding_tokens
131 |
132 |
133 | # Utils for Group Norm
134 | def replace_bn_with_gn(
135 | root_module: nn.Module,
136 | features_per_group: int=16) -> nn.Module:
137 | """
138 |     Replace all BatchNorm layers with GroupNorm.
139 | """
140 | replace_submodules(
141 | root_module=root_module,
142 | predicate=lambda x: isinstance(x, nn.BatchNorm2d),
143 | func=lambda x: nn.GroupNorm(
144 | num_groups=x.num_features//features_per_group,
145 | num_channels=x.num_features)
146 | )
147 | return root_module
148 |
149 |
150 | def replace_submodules(
151 | root_module: nn.Module,
152 | predicate: Callable[[nn.Module], bool],
153 | func: Callable[[nn.Module], nn.Module]) -> nn.Module:
154 | """
155 | Replace all submodules selected by the predicate with
156 | the output of func.
157 |
158 | predicate: Return true if the module is to be replaced.
159 | func: Return new module to use.
160 | """
161 | if predicate(root_module):
162 | return func(root_module)
163 |
164 | bn_list = [k.split('.') for k, m
165 | in root_module.named_modules(remove_duplicate=True)
166 | if predicate(m)]
167 | for *parent, k in bn_list:
168 | parent_module = root_module
169 | if len(parent) > 0:
170 | parent_module = root_module.get_submodule('.'.join(parent))
171 | if isinstance(parent_module, nn.Sequential):
172 | src_module = parent_module[int(k)]
173 | else:
174 | src_module = getattr(parent_module, k)
175 | tgt_module = func(src_module)
176 | if isinstance(parent_module, nn.Sequential):
177 | parent_module[int(k)] = tgt_module
178 | else:
179 | setattr(parent_module, k, tgt_module)
180 | # verify that all modules are replaced
181 | bn_list = [k.split('.') for k, m
182 | in root_module.named_modules(remove_duplicate=True)
183 | if predicate(m)]
184 | assert len(bn_list) == 0
185 | return root_module
186 |
187 |
188 |
189 |
--------------------------------------------------------------------------------
/train/vint_train/models/model_utils.py:
--------------------------------------------------------------------------------
1 | from vint_train.models.navibridge.navibridge import NaviBridge, DenseNetwork, StatesPredNet
2 | from vint_train.models.navibridge.navibridg_utils import NaviBridge_Encoder, replace_bn_with_gn
3 | from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
4 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
5 |
6 | from vint_train.models.navibridge.ddbm.karras_diffusion import KarrasDenoiser
7 | from vint_train.models.navibridge.ddbm.resample import create_named_schedule_sampler
8 | from vint_train.models.navibridge.navibridge import PriorModel, Prior_HandCraft
9 | from vint_train.models.navibridge.vae.vae import VAEModel
10 |
11 | import os
12 | def create_model(config, device):
13 | """
14 | Create a model based on the provided configuration.
15 |
16 | Args:
17 | config (dict): Configuration dictionary that includes model type and various parameters.
18 |
19 | Returns:
20 | model (object): Created model based on the specified configuration.
21 |
22 | Raises:
23 | ValueError: If the model type or vision encoder is not supported.
24 | """
25 | # Create the model based on configuration
26 | if config["model_type"] == "navibridge":
27 | # Select and configure the vision encoder
28 | if config["vision_encoder"] == "navibridge_encoder":
29 | vision_encoder = NaviBridge_Encoder(
30 | obs_encoding_size=config["encoding_size"],
31 | context_size=config["context_size"],
32 | mha_num_attention_heads=config["mha_num_attention_heads"],
33 | mha_num_attention_layers=config["mha_num_attention_layers"],
34 | mha_ff_dim_factor=config["mha_ff_dim_factor"],
35 | )
36 | vision_encoder = replace_bn_with_gn(vision_encoder)
37 | else:
38 | raise ValueError(f"Vision encoder {config['vision_encoder']} not supported")
39 |
40 | # Create the noise prediction network and distribution prediction network
41 | noise_pred_net = ConditionalUnet1D(
42 | input_dim=2,
43 | global_cond_dim=config["encoding_size"],
44 | down_dims=config["down_dims"],
45 | cond_predict_scale=config["cond_predict_scale"],
46 | )
47 | dist_pred_network = DenseNetwork(embedding_dim=config["encoding_size"])
48 |
49 | states_pred_net = StatesPredNet(embedding_dim=config["encoding_size"],
50 | class_num=config["class_num"],
51 | len_traj_pred=config["len_traj_pred"])
52 |
53 | model = NaviBridge(
54 | vision_encoder=vision_encoder,
55 | noise_pred_net=noise_pred_net,
56 | dist_pred_net=dist_pred_network,
57 | states_pred_net=states_pred_net,
58 | )
59 | set_model_prior(model, config, device)
60 | elif config["model_type"] == "cvae":
61 | model = VAEModel(config)
62 | else:
63 | raise ValueError(f"Model {config['model_type']} not supported")
64 |
65 | return model
66 |
67 | def create_noise_scheduler(config):
68 | if config["model_type"] == "navibridge":
69 | diffusion = create_diffusion(config)
70 | noise_scheduler = create_named_schedule_sampler(config["sampler_name"], diffusion)
71 | return noise_scheduler, diffusion
72 |
73 | def create_diffusion(config,
74 | ):
75 | # ddbm params
76 | sigma_data=config["sigma_data"]
77 | sigma_min=config["sigma_min"]
78 | sigma_max=config["sigma_max"]
79 | pred_mode=config["pred_mode"]
80 | weight_schedule=config["weight_schedule"]
81 | beta_d=config["beta_d"]
82 | beta_min=config["beta_min"]
83 | cov_xy=config["cov_xy"]
84 | diffusion = KarrasDenoiser(
85 | sigma_data=sigma_data,
86 | sigma_max=sigma_max,
87 | sigma_min=sigma_min,
88 | beta_d=beta_d,
89 | beta_min=beta_min,
90 | cov_xy=cov_xy,
91 | weight_schedule=weight_schedule,
92 | pred_mode=pred_mode
93 | )
94 | return diffusion
95 |
96 | def set_model_prior(model, prior_args, device):
97 | prior = load_prior_model(prior_args, device)
98 | if prior_args['prior_policy'] == 'handcraft':
99 | prior_model = PriorModel(prior=prior, len_traj_pred=prior_args["len_traj_pred"], action_dim=prior_args["action_dim"])
100 | elif prior_args['prior_policy'] == 'cluster':
101 | pass
102 | elif prior_args['prior_policy'] == 'gaussian':
103 | prior_model = PriorModel(prior=prior, len_traj_pred=prior_args["len_traj_pred"], action_dim=prior_args["action_dim"])
104 | elif prior_args['prior_policy'] == 'cvae':
105 | prior_model = PriorModel(prior=prior, len_traj_pred=prior_args["len_traj_pred"], action_dim=prior_args["action_dim"])
106 | model.prior_model = prior_model
107 | return model
108 |
109 |
110 | def load_prior_model(prior_args, device):
111 | if prior_args['prior_policy'] == 'handcraft':
112 | prior_model = Prior_HandCraft(
113 | class_num=prior_args["class_num"],
114 | angle_ranges=prior_args["angle_ranges"],
115 | min_std_angle=prior_args["min_std_angle"],
116 | max_std_angle=prior_args["max_std_angle"],
117 | min_std_length=prior_args["min_std_length"],
118 | max_std_length=prior_args["max_std_length"],
119 | len_traj_pred=prior_args["len_traj_pred"],
120 | )
121 |     elif prior_args['prior_policy'] == 'cluster':
122 |         raise NotImplementedError("The 'cluster' prior is not implemented yet")
123 | elif prior_args['prior_policy'] == 'gaussian':
124 | prior_model = None
125 | elif prior_args['prior_policy'] == 'cvae':
126 | from vint_train.models.navibridge.vae.vae import VAEModel
127 | model_spec_names = prior_args['net_type']
128 | ckpt_path = prior_args["ckpt_path"]
129 | prior_args['pretrain'] = True
130 | prior_args['ckpt_path'] = ckpt_path
131 |
132 | prior_model = VAEModel(prior_args)
133 | prior_model.load_model(model_args=prior_args, device=device)
134 |
135 | else:
136 |         raise NotImplementedError(f"Unknown prior policy: {prior_args['prior_policy']}")
137 |
138 | return prior_model
139 |
140 | import torch
141 | from torch.optim import Adam, AdamW
142 | from torch.optim.lr_scheduler import (
143 | CosineAnnealingLR,
144 | CyclicLR,
145 | ReduceLROnPlateau,
146 | StepLR
147 | )
148 | from warmup_scheduler import GradualWarmupScheduler
149 |
150 | def get_optimizer_and_scheduler(config, model):
151 | lr = float(config["lr"])
152 | config["optimizer"] = config["optimizer"].lower()
153 |
154 | if config["optimizer"] == "adam":
155 | optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.98))
156 | elif config["optimizer"] == "adamw":
157 | optimizer = AdamW(model.parameters(), lr=lr)
158 | elif config["optimizer"] == "sgd":
159 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
160 | else:
161 | raise ValueError(f"Optimizer {config['optimizer']} not supported")
162 |
163 | scheduler = None
164 |
165 | if config["model_type"] == "cvae":
166 | if config['lr_gamma'] < 1.0:
167 | scheduler = StepLR(optimizer,
168 | step_size=config["lr_step"],
169 | gamma=config["lr_gamma"])
170 | else:
171 | scheduler = None
172 |
173 | elif config["scheduler"] is not None:
174 | config["scheduler"] = config["scheduler"].lower()
175 |
176 | if config["scheduler"] == "cosine":
177 | print("Using cosine annealing with T_max", config["epochs"])
178 | scheduler = CosineAnnealingLR(optimizer, T_max=config["epochs"])
179 | elif config["scheduler"] == "cyclic":
180 | print("Using cyclic LR with cycle", config["cyclic_period"])
181 | scheduler = CyclicLR(
182 | optimizer,
183 | base_lr=lr / 10.0,
184 | max_lr=lr,
185 | step_size_up=config["cyclic_period"] // 2,
186 | cycle_momentum=False,
187 | )
188 | elif config["scheduler"] == "plateau":
189 | print("Using ReduceLROnPlateau")
190 | scheduler = ReduceLROnPlateau(
191 | optimizer,
192 | factor=config["plateau_factor"],
193 | patience=config["plateau_patience"],
194 | verbose=True,
195 | )
196 | else:
197 | raise ValueError(f"Scheduler {config['scheduler']} not supported")
198 |
199 | if config.get("warmup", False):
200 | print("Using warmup scheduler")
201 | scheduler = GradualWarmupScheduler(
202 | optimizer,
203 | multiplier=1,
204 | total_epoch=config["warmup_epochs"],
205 | after_scheduler=scheduler,
206 | )
207 |
208 | return optimizer, scheduler
209 |
--------------------------------------------------------------------------------
/deployment/src/navibridger_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import sys
5 | sys.path.insert(1, os.path.join(sys.path[0], '..'))
6 |
7 |
8 | from vint_train.models.model_utils import create_noise_scheduler
9 | from vint_train.training.train_utils import get_action
10 | from vint_train.visualizing.action_utils import plot_trajs_and_points
11 | from vint_train.models.navibridge.ddbm.karras_diffusion import karras_sample
12 |
13 | import torch
14 | import numpy as np
15 | import yaml
16 | from PIL import Image as PILImage
17 | import matplotlib.pyplot as plt
18 |
19 | from utils_inference import to_numpy, transform_images, load_model, project_and_draw
20 |
21 | PARAMS_PATH = "../config/params.yaml"
22 | with open(PARAMS_PATH, "r") as f:
23 | params_config = yaml.safe_load(f)
24 | image_path = params_config["image_path"]
25 |
26 | # CONSTANTS
27 | MODEL_WEIGHTS_PATH = "../model_weights"
28 | ROBOT_CONFIG_PATH ="../config/robot.yaml"
29 | MODEL_CONFIG_PATH = "../config/models.yaml"
30 | with open(ROBOT_CONFIG_PATH, "r") as f:
31 | robot_config = yaml.safe_load(f)
32 | MAX_V = robot_config["max_v"]
33 | MAX_W = robot_config["max_w"]
34 | RATE = robot_config["frame_rate"]
35 | ACTION_STATS = {}
36 | ACTION_STATS['min'] = np.array([-2.5, -4])
37 | ACTION_STATS['max'] = np.array([5, 4])
38 | # GLOBALS
39 | context_queue = []
40 | context_size = None
41 |
42 | # Load the model
43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44 | print("Using device:", device)
45 |
46 | def get_bottom_folder_name(path):
47 | folder_path = os.path.dirname(path)
48 | bottom_folder_name = os.path.basename(folder_path)
49 | return bottom_folder_name
50 |
51 | def ensure_directory_exists(save_path):
52 | directory = os.path.dirname(save_path)
53 | if not os.path.exists(directory):
54 | os.makedirs(directory)
55 |
56 | def path_project_plot(img, path, args, camera_extrinsics, camera_intrinsics):
57 | for i, naction in enumerate(path):
58 | gc_actions = to_numpy(get_action(naction))
59 | fig = project_and_draw(img, gc_actions, camera_extrinsics, camera_intrinsics)
60 | dir_basename = get_bottom_folder_name(image_path)
61 | save_path = os.path.join('../output', dir_basename, f'png_{args.model}_image_with_trajs_{i}.png')
62 | ensure_directory_exists(save_path)
63 | fig.savefig(save_path)
64 | save_path = os.path.join('../output', dir_basename, f'svg_{args.model}_image_with_trajs_{i}.svg')
65 | ensure_directory_exists(save_path)
66 | fig.savefig(save_path)
67 | print(f"output image saved as {save_path}")
68 |
69 | def main(args):
70 | camera_intrinsics = np.array([[470.7520828622471, 0, 16.00531005859375],
71 | [0, 470.7520828622471, 403.38909912109375],
72 | [0, 0, 1]])
73 | camera_extrinsics = np.array([[0, 0, 1, -0.600],
74 | [-1, 0, 0, -0.000],
75 | [0, -1, 0, -0.042],
76 | [0, 0, 0, 1]])
77 | global context_size, image_path
78 | # load model parameters
79 | with open(MODEL_CONFIG_PATH, "r") as f:
80 | model_paths = yaml.safe_load(f)
81 |
82 | model_config_path = model_paths[args.model]["config_path"]
83 | with open(model_config_path, "r") as f:
84 | model_params = yaml.safe_load(f)
85 |
86 | if model_params["model_type"] == "cvae":
87 | if "train_params" in model_params:
88 | model_params.update(model_params["train_params"])
89 | if "diffuse_params" in model_params:
90 | model_params.update(model_params["diffuse_params"])
91 |
92 | if model_params.get("prior_policy", None) == "cvae":
93 | if "diffuse_params" in model_params:
94 | model_params.update(model_params["diffuse_params"])
95 |
96 | context_size = model_params["context_size"]
97 | # load model weights
98 | ckpth_path = model_paths[args.model]["ckpt_path"]
99 | if os.path.exists(ckpth_path):
100 | print(f"Loading model from {ckpth_path}")
101 | else:
102 | raise FileNotFoundError(f"Model weights not found at {ckpth_path}")
103 | model = load_model(
104 | ckpth_path,
105 | model_params,
106 | device,
107 | )
108 | model.eval()
109 |
110 | if model_params["model_type"] == "navibridge":
111 | noise_scheduler, diffusion = create_noise_scheduler(model_params)
112 |
113 | for i in range(4):
114 |
115 |         img_path = image_path + str(i + 18) + ".png"  # adjust the start index to match your image filenames (the README example uses 0.png, 1.png, ...)
116 | img = PILImage.open(img_path)
117 | context_queue.append(img)
118 |
119 | fake_goal = torch.randn((1, 3, *model_params["image_size"])).to(device)
120 |
121 | # infer action
122 | obs_images = transform_images(context_queue, model_params["image_size"], center_crop=False)
123 | obs_images = obs_images.to(device)
124 | fake_goal = torch.randn((1, 3, *model_params["image_size"])).to(device)
125 | mask = torch.ones(1).long().to(device)
126 |
127 | # You can change the fake_goal to a goal image, do the same operation as obs_images
128 | # goal_image = transform_images(goal_image, model_params["image_size"], center_crop=False)
129 | # goal_image = goal_image.to(device)
130 | # mask = torch.zeros(1).long().to(device)
131 |
132 | obs_cond_gc = model('vision_encoder', obs_img=obs_images, goal_img=fake_goal, input_goal_mask=mask)
133 | scale_factor=3 * MAX_V / RATE
134 | with torch.no_grad():
135 | if len(obs_cond_gc.shape) == 2:
136 | obs_cond_gc = obs_cond_gc.repeat(args.num_samples, 1)
137 | else:
138 | obs_cond_gc = obs_cond_gc.repeat(args.num_samples, 1, 1)
139 |
140 | if model_params["model_type"] == "navibridge":
141 | if model_params["prior_policy"] == "handcraft":
142 | # Predict aciton states
143 | states_pred = model("states_pred_net", obsgoal_cond=obs_cond_gc)
144 |
145 | if model_params["prior_policy"] == "cvae":
146 | prior_cond = obs_images.repeat_interleave(args.num_samples, dim=0)
147 | elif model_params["prior_policy"] == "handcraft":
148 | prior_cond = states_pred
149 | else:
150 | prior_cond = None
151 |
152 | # initialize action from Gaussian noise
153 | if model.prior_model.prior is None:
154 | initial_samples = torch.randn((args.num_samples, model_params["len_traj_pred"], 2), device=device)
155 | else:
156 | with torch.no_grad():
157 | initial_samples = model.prior_model.sample(cond=prior_cond, device=device)
158 | assert initial_samples.shape[-1] == 2, "action dim must be 2"
159 | naction, path, nfe = karras_sample(
160 | diffusion,
161 | model,
162 | initial_samples,
163 | None,
164 | steps=model_params["num_diffusion_iters"],
165 | model_kwargs=initial_samples,
166 | global_cond=obs_cond_gc,
167 | device=device,
168 | clip_denoised=model_params["clip_denoised"],
169 | sampler="heun",
170 | sigma_min=diffusion.sigma_min,
171 | sigma_max=diffusion.sigma_max,
172 | churn_step_ratio=model_params["churn_step_ratio"],
173 | rho=model_params["rho"],
174 | guidance=model_params["guidance"]
175 | )
176 |
177 | gc_actions = to_numpy(get_action(naction))
178 |
179 | gc_actions *= scale_factor
180 | if args.path_visual:
181 | path_project_plot(context_queue[-1], path, args, camera_extrinsics, camera_intrinsics)
182 |
183 |
184 | if __name__ == "__main__":
185 | parser = argparse.ArgumentParser(
186 | description="Code to run DIFFUSION NAVIGATION demo")
187 | parser.add_argument(
188 | "--model",
189 | "-m",
190 | default="navibridge_noise",
191 | type=str,
192 | help="model name (hint: check ../config/models.yaml)",
193 | )
194 | parser.add_argument(
195 | "--waypoint",
196 | "-w",
197 |     default=2,  # close waypoints exhibit straight line motion (the middle waypoint is a good default)
198 | type=int,
199 | help=f"""index of the waypoint used for navigation (between 0 and 4 or
200 | how many waypoints your model predicts) (default: 2)""",
201 | )
202 | parser.add_argument(
203 | "--num-samples",
204 | "-n",
205 | default=100,
206 | type=int,
207 | help=f"Number of actions sampled from the exploration model (default: 8)",
208 | )
209 | parser.add_argument(
210 | "--path-visual",
211 | default=True,
212 | type=bool,
213 | help="visualization",
214 | )
215 | parser.add_argument(
216 | "--device",
217 | "-d",
218 | default=0,
219 | type=int,
220 | help="device",
221 | )
222 | args = parser.parse_args()
223 | print(f"Using {device}")
224 | main(args)
225 |
--------------------------------------------------------------------------------
/train/vint_train/models/navibridge/navibridge.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import time
4 | import pdb
5 | import numpy as np
6 | import yaml
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from vint_train.training.train_utils import get_delta, normalize_data
12 | from vint_train.visualizing.visualize_utils import to_numpy, from_numpy
13 |
14 | # LOAD DATA CONFIG
15 | with open(os.path.join(os.path.dirname(__file__), "../../data/data_config.yaml"), "r") as f:
16 | data_config = yaml.safe_load(f)
17 | # POPULATE ACTION STATS
18 | ACTION_STATS = {}
19 | for key in data_config['action_stats']:
20 | ACTION_STATS[key] = np.array(data_config['action_stats'][key])
21 |
22 | class NaviBridge(nn.Module):
23 |
24 | def __init__(self, vision_encoder,
25 | noise_pred_net,
26 | dist_pred_net,
27 | states_pred_net):
28 | super(NaviBridge, self).__init__()
29 |
30 |
31 | self.vision_encoder = vision_encoder
32 | self.noise_pred_net = noise_pred_net
33 | self.dist_pred_net = dist_pred_net
34 | self.states_pred_net = states_pred_net
35 | self.prior_model = None
36 |
37 | def forward(self, func_name, **kwargs):
38 | if func_name == "vision_encoder" :
39 | output = self.vision_encoder(kwargs["obs_img"], kwargs["goal_img"], input_goal_mask=kwargs["input_goal_mask"])
40 | elif func_name == "noise_pred_net":
41 | output = self.noise_pred_net(sample=kwargs["sample"], timestep=kwargs["timestep"], global_cond=kwargs["global_cond"])
42 | elif func_name == "dist_pred_net":
43 | output = self.dist_pred_net(kwargs["obsgoal_cond"])
44 | elif func_name == "states_pred_net":
45 | output = self.states_pred_net(kwargs["obsgoal_cond"])
46 | else:
47 | raise NotImplementedError
48 | return output
49 |
50 |
51 | class DenseNetwork(nn.Module):
52 | def __init__(self, embedding_dim):
53 | super(DenseNetwork, self).__init__()
54 |
55 | self.embedding_dim = embedding_dim
56 | self.network = nn.Sequential(
57 | nn.Linear(self.embedding_dim, self.embedding_dim//4),
58 | nn.ReLU(),
59 | nn.Linear(self.embedding_dim//4, self.embedding_dim//16),
60 | nn.ReLU(),
61 | nn.Linear(self.embedding_dim//16, 1)
62 | )
63 |
64 | def forward(self, x):
65 | x = x.reshape((-1, self.embedding_dim))
66 | output = self.network(x)
67 | return output
68 |
69 |
70 | class StatesPredNet(nn.Module):
71 | def __init__(self, embedding_dim, class_num, len_traj_pred):
72 | super(StatesPredNet, self).__init__()
73 | self.output_length = class_num + len_traj_pred * 2
74 | self.embedding_dim = embedding_dim
75 | self.class_num = class_num
76 | self.len_traj_pred = len_traj_pred
77 |
78 | self.network = nn.Sequential(
79 | nn.Linear(self.embedding_dim, self.embedding_dim // 4),
80 | nn.ReLU(),
81 | nn.Linear(self.embedding_dim // 4, self.embedding_dim // 16),
82 | nn.ReLU(),
83 | nn.Linear(self.embedding_dim // 16, self.output_length)
84 | )
85 |
86 | def forward(self, x):
87 | x = x.reshape((-1, self.embedding_dim))
88 | output = self.network(x)
89 |
90 | class_logits = output[:, :self.class_num]
91 | coords = output[:, self.class_num:]
92 |
93 | class_probs = nn.functional.softmax(class_logits, dim=-1)
94 |
95 | coords = coords.reshape((-1, self.len_traj_pred, 2))
96 |
97 | return class_probs, coords
98 |
99 |
100 | class PriorModel:
101 | def __init__(self, prior=None, action_dim=2, len_traj_pred=8):
102 | self.prior = prior
103 | self.action_dim = action_dim
104 | self.len_traj_pred = len_traj_pred
105 |
106 | def sample(self, cond, device, num_samples=1):
107 | states = cond
108 |
109 | if self.prior is None:
110 |             if isinstance(states, np.ndarray):
111 | states_torch = torch.as_tensor(states)
112 | else:
113 | states_torch = states
114 |
115 | if isinstance(states_torch, tuple):
116 | prior_sample = torch.randn((states_torch[1].shape[0], self.len_traj_pred, self.action_dim), device=device)
117 | else:
118 | prior_sample = torch.randn((states_torch.shape[0], self.len_traj_pred, self.action_dim), device=device)
119 | else:
120 |             if isinstance(states, np.ndarray):
121 | states_torch = torch.as_tensor(states)
122 | else:
123 | states_torch = states
124 |
125 | prior_sample = self.prior.sample_prior(states_torch, num_samples=num_samples, device=device)
126 |
127 | prior_sample = prior_sample.to(device)
128 |
129 | return prior_sample
130 |
131 | class Prior_HandCraft:
132 | def __init__(self,
133 | class_num: int = 5,
134 | angle_ranges: list = None, # [(0, 67.5), (67.5, 112.5), (112.5, 180), (180, 270), (270, 360)]
135 | min_std_angle: float = 5.0,
136 | max_std_angle: float = 20.0,
137 | min_std_length: float = 1.0,
138 | max_std_length: float = 5.0,
139 | num_samples: int = 1,
140 | len_traj_pred=8,
141 | ):
142 | if angle_ranges is None:
143 | self._set_angle_ranges()
144 | else:
145 | self.angle_ranges = angle_ranges
146 | self.class_num = class_num
147 | self.min_std_angle = min_std_angle
148 | self.max_std_angle = max_std_angle
149 | self.min_std_length = min_std_length
150 | self.max_std_length = max_std_length
151 | self.num_samples = num_samples
152 | self.len_traj_pred = len_traj_pred
153 |
154 | def _set_angle_ranges(self):
155 | self.angle_ranges = [(0, 67.5),
156 | (67.5, 112.5),
157 | (112.5, 180),
158 | (180, 270),
159 | (270, 360)]
160 |
161 | def _extract_length(self, coords):
162 | x_end, y_end = coords[-1]
163 | length = np.sqrt(x_end**2 + y_end**2)
164 |
165 | return length
166 |
167 | def _exact_states(self, states_pred):
168 | class_probs, coords_pred = states_pred
169 | return class_probs.cpu(), coords_pred.cpu()
170 |
171 | def _convert_confidence_to_std(self, confidence, min_std=0.5, max_std=5.0):
172 | return min_std + (max_std - min_std) * (1 - confidence)
173 |
174 | def _preprocess(self, actions):
175 | actions = np.squeeze(actions)
176 | if actions.ndim == 2:
177 | actions = np.expand_dims(actions, axis=0)
178 | deltas = get_delta(actions)
179 | ndeltas = normalize_data(deltas, ACTION_STATS)
180 | naction = from_numpy(ndeltas)
181 | return naction
182 |
183 | def sample_prior(self, states, num_samples=None, device=None):
184 | if num_samples is not None:
185 | self.num_samples = num_samples
186 |
187 | class_probs, coords_pred = self._exact_states(states)
188 | batch_size = class_probs.shape[0]
189 | sampled_trajectories_batch = []
190 |
191 | for batch_idx in range(batch_size):
192 | class_probs_single = class_probs[batch_idx]
193 | coords_pred_single = coords_pred[batch_idx]
194 |
195 | self.path_length_mean = self._extract_length(coords_pred_single)
196 | length_confidence = class_probs_single
197 |
198 | sampled_trajectories = []
199 |
200 | for _ in range(self.num_samples):
201 | category = torch.multinomial(class_probs_single, num_samples=1).item()
202 |
203 | angle_range = self.angle_ranges[category]
204 |
205 | angle_std = self._convert_confidence_to_std(
206 | class_probs_single[category],
207 | min_std=self.min_std_angle,
208 | max_std=self.max_std_angle
209 | )
210 |
211 | length_std = self._convert_confidence_to_std(
212 | length_confidence[category],
213 | min_std=self.min_std_length,
214 | max_std=self.max_std_length
215 | )
216 |
217 | angle_mean = (angle_range[0] + angle_range[1]) / 2
218 | angle_sample = np.random.normal(loc=angle_mean, scale=angle_std)
219 |
220 | path_length_sample = np.random.normal(loc=self.path_length_mean, scale=length_std)
221 |
222 | x_end = path_length_sample * np.cos(np.radians(angle_sample))
223 | y_end = path_length_sample * np.sin(np.radians(angle_sample))
224 |
225 | h_mid = (0 + x_end) / 2
226 | h_range = (x_end - 0) / 2
227 | if y_end >= 0:
228 | h = np.random.uniform(x_end, x_end * 2)
229 | else:
230 | h = np.random.uniform(h_mid - h_range / 2, h_mid - h_range / 4)
231 |
232 | a = (y_end - 0) / ((x_end - 0) * (x_end + 0 - 2 * h))
233 | k = -a * h**2
234 |
235 | x_vals = np.linspace(0, x_end, self.len_traj_pred)
236 | y_vals = a * (x_vals - h) ** 2 + k
237 | combined = np.stack((y_vals, x_vals), axis=1)
238 | sampled_trajectories.append(combined)
239 |
240 | sampled_trajectories = np.stack(sampled_trajectories, axis=0)
241 |
242 | sampled_trajectories_batch.append(sampled_trajectories)
243 |
244 | actions = np.stack(sampled_trajectories_batch, axis=0)
245 | naction = self._preprocess(actions)
246 | return naction
247 |
--------------------------------------------------------------------------------
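
The hand-crafted prior above turns a predicted direction class and endpoint into candidate trajectories: the class confidence is mapped to a sampling standard deviation, an angle and a path length are drawn, and a parabola is fitted through the origin and the sampled endpoint. The following standalone sketch (not part of the repository; the bin choice, confidences, and seed are made-up illustration values) shows that construction:

import numpy as np

def confidence_to_std(confidence, min_std=0.5, max_std=5.0):
    # Higher confidence -> tighter sampling (smaller std), as in _convert_confidence_to_std above.
    return min_std + (max_std - min_std) * (1 - confidence)

rng = np.random.default_rng(0)
confidence, path_length_mean, len_traj_pred = 0.8, 4.0, 8
angle_mean = (0 + 67.5) / 2                        # centre of the first angle bin
angle = rng.normal(angle_mean, confidence_to_std(confidence, 0.1, 2.0))
length = abs(rng.normal(path_length_mean, confidence_to_std(confidence)))

x_end = length * np.cos(np.radians(angle))
y_end = length * np.sin(np.radians(angle))
h = rng.uniform(x_end, 2 * x_end)                  # vertex x-coordinate of the parabola
a = y_end / (x_end * (x_end - 2 * h))              # enforces y(0) = 0 and y(x_end) = y_end
k = -a * h ** 2
x_vals = np.linspace(0.0, x_end, len_traj_pred)
trajectory = np.stack((a * (x_vals - h) ** 2 + k, x_vals), axis=1)  # (y, x) waypoints
print(trajectory.shape)                            # (8, 2)
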
/deployment/src/utils_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import io
4 | import matplotlib.pyplot as plt
5 |
6 | # pytorch
7 | import torch
8 | import torch.nn as nn
9 | from torchvision import transforms
10 | import torchvision.transforms.functional as TF
11 |
12 | import numpy as np
13 | from PIL import Image as PILImage
14 | import cv2
15 | from typing import List, Tuple, Dict, Optional
16 | import importlib.resources as pkg_resources
17 | from tqdm import tqdm
18 |
19 | # models
20 | from vint_train.models.navibridge.navibridge import NaviBridge, DenseNetwork, StatesPredNet
21 | from vint_train.models.navibridge.navibridg_utils import NaviBridge_Encoder, replace_bn_with_gn
22 | from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
23 | from vint_train.data.data_utils import IMAGE_ASPECT_RATIO
24 |
25 | from vint_train.models.navibridge.ddbm.karras_diffusion import KarrasDenoiser
26 | from vint_train.models.navibridge.ddbm.resample import create_named_schedule_sampler
27 | from vint_train.models.navibridge.navibridge import PriorModel, Prior_HandCraft
28 | from vint_train.models.navibridge.vae.vae import VAEModel
29 | from vint_train.models.model_utils import set_model_prior
30 |
31 | def load_model(
32 | model_path: str,
33 | config: dict,
34 | device: torch.device = torch.device("cpu"),
35 | ) -> nn.Module:
36 | """Load a model from a checkpoint file (works with models trained on multiple GPUs)"""
37 | model_type = config["model_type"]
38 |
39 | if config["model_type"] == "navibridge":
40 | # Select and configure the vision encoder
41 | if config["vision_encoder"] == "navibridge_encoder":
42 | vision_encoder = NaviBridge_Encoder(
43 | obs_encoding_size=config["encoding_size"],
44 | context_size=config["context_size"],
45 | mha_num_attention_heads=config["mha_num_attention_heads"],
46 | mha_num_attention_layers=config["mha_num_attention_layers"],
47 | mha_ff_dim_factor=config["mha_ff_dim_factor"],
48 | )
49 | vision_encoder = replace_bn_with_gn(vision_encoder)
50 | else:
51 | raise ValueError(f"Vision encoder {config['vision_encoder']} not supported")
52 |
53 | # Create the noise prediction network and distribution prediction network
54 | noise_pred_net = ConditionalUnet1D(
55 | input_dim=2,
56 | global_cond_dim=config["encoding_size"],
57 | down_dims=config["down_dims"],
58 | cond_predict_scale=config["cond_predict_scale"],
59 | )
60 | dist_pred_network = DenseNetwork(embedding_dim=config["encoding_size"])
61 |
62 | states_pred_net = StatesPredNet(embedding_dim=config["encoding_size"],
63 | class_num=config["class_num"],
64 | len_traj_pred=config["len_traj_pred"])
65 |
66 | model = NaviBridge(
67 | vision_encoder=vision_encoder,
68 | noise_pred_net=noise_pred_net,
69 | dist_pred_net=dist_pred_network,
70 | states_pred_net=states_pred_net,
71 | )
72 | set_model_prior(model, config, device)
73 | elif config["model_type"] == "cvae":
74 | model = VAEModel(config)
75 | else:
76 | raise ValueError(f"Invalid model type: {model_type}")
77 |
78 | checkpoint = torch.load(model_path, map_location=device)
79 |
80 | if model_type == "navibridge":
81 | state_dict = checkpoint
82 | model.load_state_dict(state_dict, strict=False)
83 | else:
84 | loaded_model = checkpoint["model"]
85 | try:
86 | state_dict = loaded_model.module.state_dict()
87 | model.load_state_dict(state_dict, strict=False)
88 | except AttributeError as e:
89 | state_dict = loaded_model.state_dict()
90 | model.load_state_dict(state_dict, strict=False)
91 | model.to(device)
92 | return model
93 |
94 | def from_numpy(array: np.ndarray) -> torch.Tensor:
95 | return torch.from_numpy(array).float()
96 |
97 | def to_numpy(tensor):
98 | return tensor.cpu().detach().numpy()
99 |
100 |
101 | def transform_images(pil_imgs: List[PILImage.Image], image_size: List[int], center_crop: bool = False) -> torch.Tensor:
102 |     """Transform a list of PIL images into a single torch tensor."""
103 | transform_type = transforms.Compose(
104 | [
105 | transforms.ToTensor(),
106 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
107 | 0.229, 0.224, 0.225]),
108 | ]
109 | )
110 |     if not isinstance(pil_imgs, list):
111 | pil_imgs = [pil_imgs]
112 | transf_imgs = []
113 | for pil_img in pil_imgs:
114 | w, h = pil_img.size
115 | if center_crop:
116 | if w > h:
117 | pil_img = TF.center_crop(pil_img, (h, int(h * IMAGE_ASPECT_RATIO))) # crop to the right ratio
118 | else:
119 | pil_img = TF.center_crop(pil_img, (int(w / IMAGE_ASPECT_RATIO), w))
120 | pil_img = pil_img.resize(image_size)
121 | transf_img = transform_type(pil_img)
122 | transf_img = torch.unsqueeze(transf_img, 0)
123 | transf_imgs.append(transf_img)
124 | return torch.cat(transf_imgs, dim=1)
125 |
126 | def ensure_pil_image(item):
127 | if isinstance(item, PILImage.Image):
128 | return item
129 | elif isinstance(item, np.ndarray):
130 | return PILImage.fromarray(item)
131 | else:
132 | raise ValueError(f"Unsupported data type: {type(item)}")
133 |
134 | def alternate_merge(list1, list2):
135 | list1 = [ensure_pil_image(item) for item in list1]
136 | list2 = [ensure_pil_image(item) for item in list2]
137 |
138 | merged_list = [None] * (len(list1) + len(list2))
139 | merged_list[::2] = list1
140 | merged_list[1::2] = list2
141 |
142 | return merged_list
143 |
144 | # clip angle between -pi and pi
145 | def clip_angle(angle):
146 | return np.mod(angle + np.pi, 2 * np.pi) - np.pi
147 |
148 | def find_images(directory):
149 | files = os.listdir(directory)
150 | image_files = [file for file in files if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]
151 | return [os.path.join(directory, file) for file in image_files]
152 |
153 | def project_and_draw(image, trajectories, extrinsic_matrix, intrinsic_matrix):
154 | """
155 | Project trajectories from world coordinates to pixel coordinates and draw them on the image using matplotlib.
156 |     Note that the projection is approximate and intended for visualization only.
157 | Args:
158 | image (np.array): The image on which to draw the trajectories.
159 | trajectories (list of list of tuples): Each trajectory is a list of (x, y, z) tuples.
160 | extrinsic_matrix (np.array): The 4x4 extrinsic matrix (world to camera).
161 | intrinsic_matrix (np.array): The 3x4 intrinsic matrix (camera to image).
162 |
163 | Returns:
164 | None: Displays the image with trajectories drawn using matplotlib.
165 | """
166 | # Convert image to numpy array
167 | image_np = np.array(image)
168 | img_height, img_width = image_np.shape[:2]
169 |
170 | # Set up matplotlib figure and axis
171 | fig, ax = plt.subplots(figsize=(img_width / 100, img_height / 100), dpi=100)
172 | ax.imshow(image_np) # Show the image
173 |
174 | # Plot each trajectory
175 | for trajectory in trajectories:
176 | points = []
177 | for (x, y) in trajectory:
178 | # Convert from world to camera coordinates
179 | camera_coords = world_to_camera(np.array([x, y, -1.5]), extrinsic_matrix)
180 | # Project to pixel coordinates
181 | pixel_coords = camera_to_pixel(camera_coords, intrinsic_matrix)
182 | u, v = int(pixel_coords[0]), int(pixel_coords[1])
183 | u += (image.size[0] // 2)
184 | # v += (image.size[1] // 2)
185 | points.append((u, v))
186 |
187 | # Separate u and v coordinates for plotting
188 | if points:
189 | u_coords, v_coords = zip(*points) # Unpack list of tuples into separate lists
190 | ax.plot(u_coords, v_coords, color='yellow', linewidth=8) # Plot trajectory
191 |
192 | # Limit the view to the image's dimensions only
193 | ax.set_xlim(0, img_width)
194 | ax.set_ylim(img_height, 0) # Reverse the y-axis to match image coordinates
195 | plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
196 |
197 | # Display the final image with trajectories
198 | plt.axis('off') # Turn off axes for clarity
199 | return fig
200 |
201 | def world_to_camera(world_coords, extrinsic_matrix):
202 | """
203 | Convert world coordinates to camera coordinates using the extrinsic matrix.
204 |
205 | Args:
206 | world_coords (np.array): Coordinates in the world frame.
207 | extrinsic_matrix (np.array): The 4x4 extrinsic matrix (world to camera).
208 |
209 | Returns:
210 | np.array: Coordinates in the camera frame.
211 | """
212 | # Convert to homogeneous coordinates
213 | world_coords_homogeneous = np.append(world_coords, 1)
214 | # Transform to camera coordinates
215 | camera_coords_homogeneous = np.linalg.inv(extrinsic_matrix) @ world_coords_homogeneous
216 | # Convert back to 3D and normalize
217 | return camera_coords_homogeneous[:3] / camera_coords_homogeneous[3]
218 |
219 | def camera_to_pixel(camera_coords, intrinsic_matrix):
220 | """
221 | Project camera coordinates to pixel coordinates using the intrinsic matrix.
222 |
223 | Args:
224 | camera_coords (np.array): Coordinates in the camera frame.
225 | intrinsic_matrix (np.array): The 3x4 intrinsic matrix (camera to image).
226 |
227 | Returns:
228 | np.array: Coordinates in pixel space.
229 | """
230 | # Focal lengths and principal point from intrinsic matrix
231 | fx, fy = intrinsic_matrix[0, 0], intrinsic_matrix[1, 1]
232 | cx, cy = intrinsic_matrix[0, 2], intrinsic_matrix[1, 2]
233 | X, Y, Z = camera_coords[0], camera_coords[1], camera_coords[2]
234 |
235 | # Project to pixel coordinates
236 | x = (fx * X / Z) + cx
237 | y = (fy * Y / Z) + cy
238 |
239 | return np.array([x, y]) # Convert to 2D pixel coordinates
--------------------------------------------------------------------------------
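
For reference, the projection chain used by project_and_draw above reduces to an inverse extrinsic transform followed by a pinhole projection. The sketch below is illustrative only; the extrinsic and intrinsic values are made-up placeholders (identity extrinsic, fx = fy = 500, principal point at the centre of a 640x480 image):

import numpy as np

extrinsic = np.eye(4)                          # camera frame coincides with the world frame
intrinsic = np.array([[500.0,   0.0, 320.0, 0.0],
                      [  0.0, 500.0, 240.0, 0.0],
                      [  0.0,   0.0,   1.0, 0.0]])

world_point = np.array([1.0, 0.5, 3.0])        # (x, y, z) in world coordinates

# world -> camera: homogeneous transform with the inverse extrinsic, then dehomogenise
cam = np.linalg.inv(extrinsic) @ np.append(world_point, 1.0)
cam = cam[:3] / cam[3]

# camera -> pixel: pinhole projection with fx, fy, cx, cy
u = intrinsic[0, 0] * cam[0] / cam[2] + intrinsic[0, 2]
v = intrinsic[1, 1] * cam[1] / cam[2] + intrinsic[1, 2]
print(u, v)                                    # roughly (486.7, 323.3)
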
/train/vint_train/training/train_eval_loop.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | import os
3 | import numpy as np
4 | from typing import List, Optional, Dict
5 | from prettytable import PrettyTable
6 |
7 | from vint_train.training.train_utils import train_navibridge, evaluate_navibridge
8 | from vint_train.training.train_utils import train_cvae
9 | from vint_train.models.navibridge.ddbm.resample import *
10 | from vint_train.models.navibridge.ddbm.karras_diffusion import KarrasDenoiser
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | from torch.utils.data import DataLoader
16 | from torch.optim import Adam
17 | from torchvision import transforms
18 |
19 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
20 | from diffusers.training_utils import EMAModel
21 |
22 | def train_eval_loop_navibridge(
23 | train_model: bool,
24 | model: nn.Module,
25 | optimizer: Adam,
26 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
27 | noise_scheduler: UniformSampler,
28 |     diffusion: KarrasDenoiser,
29 | prior_policy: str,
30 | train_loader: DataLoader,
31 | test_dataloaders: Dict[str, DataLoader],
32 | transform: transforms,
33 | goal_mask_prob: float,
34 | epochs: int,
35 | device: torch.device,
36 | project_folder: str,
37 | print_log_freq: int = 100,
38 | wandb_log_freq: int = 10,
39 | image_log_freq: int = 1000,
40 | num_images_log: int = 8,
41 | current_epoch: int = 0,
42 | alpha: float = 1e-4,
43 | use_wandb: bool = True,
44 | eval_fraction: float = 0.25,
45 | eval_freq: int = 1,
46 | # ddbm params
47 | steps=10,
48 | clip_denoised: bool = True,
49 | sampler = "heun",
50 | sigma_min = 0.002,
51 | sigma_max = 80,
52 | churn_step_ratio = 0.,
53 | rho = 7.0,
54 | guidance = 1,
55 | ):
56 | """
57 | Train and evaluate the model for several epochs
58 | """
59 | latest_path = os.path.join(project_folder, f"latest.pth")
60 | ema_model = EMAModel(model=model,power=0.75)
61 |
62 | for epoch in range(current_epoch, current_epoch + epochs):
63 | if train_model:
64 | print(
65 | f"Start ViNT DP Training Epoch {epoch}/{current_epoch + epochs - 1}"
66 | )
67 | train_navibridge(
68 | model=model,
69 | ema_model=ema_model,
70 | optimizer=optimizer,
71 | dataloader=train_loader,
72 | transform=transform,
73 | device=device,
74 |                 diffusion=diffusion,
75 | noise_scheduler=noise_scheduler,
76 | prior_policy=prior_policy,
77 | goal_mask_prob=goal_mask_prob,
78 | project_folder=project_folder,
79 | epoch=epoch,
80 | print_log_freq=print_log_freq,
81 | wandb_log_freq=wandb_log_freq,
82 | image_log_freq=image_log_freq,
83 | num_images_log=num_images_log,
84 | use_wandb=use_wandb,
85 | alpha=alpha,
86 | # ddbm params
87 | steps=steps,
88 | clip_denoised=clip_denoised,
89 | sampler=sampler,
90 | sigma_min=sigma_min,
91 | sigma_max=sigma_max,
92 | churn_step_ratio=churn_step_ratio,
93 | rho=rho,
94 | guidance=guidance,
95 | )
96 | lr_scheduler.step()
97 |
98 | numbered_path = os.path.join(project_folder, f"ema_{epoch}.pth")
99 | torch.save(ema_model.averaged_model.state_dict(), numbered_path)
100 |             torch.save(ema_model.averaged_model.state_dict(), os.path.join(project_folder, "ema_latest.pth"))
101 |             print(f"Saved EMA model to {numbered_path}")
102 |
103 | numbered_path = os.path.join(project_folder, f"{epoch}.pth")
104 | torch.save(model.state_dict(), numbered_path)
105 | torch.save(model.state_dict(), latest_path)
106 | print(f"Saved model to {numbered_path}")
107 |
108 | # save optimizer
109 | numbered_path = os.path.join(project_folder, f"optimizer_{epoch}.pth")
110 | latest_optimizer_path = os.path.join(project_folder, f"optimizer_latest.pth")
111 | torch.save(optimizer.state_dict(), latest_optimizer_path)
112 |
113 | # save scheduler
114 | numbered_path = os.path.join(project_folder, f"scheduler_{epoch}.pth")
115 | latest_scheduler_path = os.path.join(project_folder, f"scheduler_latest.pth")
116 | torch.save(lr_scheduler.state_dict(), latest_scheduler_path)
117 |
118 |
119 | if (epoch + 1) % eval_freq == 0:
120 | for dataset_type in test_dataloaders:
121 | print(
122 | f"Start {dataset_type} ViNT DP Testing Epoch {epoch}/{current_epoch + epochs - 1}"
123 | )
124 | loader = test_dataloaders[dataset_type]
125 | evaluate_navibridge(
126 | eval_type=dataset_type,
127 | ema_model=ema_model,
128 | dataloader=loader,
129 | transform=transform,
130 | device=device,
131 | prior_policy=prior_policy,
132 |                     diffusion=diffusion,
133 | noise_scheduler=noise_scheduler,
134 | goal_mask_prob=goal_mask_prob,
135 | project_folder=project_folder,
136 | epoch=epoch,
137 | print_log_freq=print_log_freq,
138 | num_images_log=num_images_log,
139 | wandb_log_freq=wandb_log_freq,
140 | use_wandb=use_wandb,
141 | eval_fraction=eval_fraction,
142 | # ddbm params
143 | steps=steps,
144 | clip_denoised=clip_denoised,
145 | sampler=sampler,
146 | sigma_min=sigma_min,
147 | sigma_max=sigma_max,
148 | churn_step_ratio=churn_step_ratio,
149 | rho=rho,
150 | guidance=guidance,
151 | )
152 | wandb.log({
153 | "lr": optimizer.param_groups[0]["lr"],
154 | }, commit=False)
155 |
156 | if lr_scheduler is not None:
157 | lr_scheduler.step()
158 |
159 | # log average eval loss
160 | wandb.log({}, commit=False)
161 |
162 | wandb.log({
163 | "lr": optimizer.param_groups[0]["lr"],
164 | }, commit=False)
165 |
166 |
167 | # Flush the last set of eval logs
168 | wandb.log({})
169 | print()
170 |
171 | def train_eval_loop_cvae(
172 | train_model: bool,
173 | model: nn.Module,
174 |     optimizer: Adam,
175 | lr_scheduler: torch.optim.lr_scheduler.StepLR,
176 | train_loader: DataLoader,
177 | transform: transforms,
178 | # num_itr: int,
179 | epochs: int,
180 | prior_policy: str,
181 | # model_args,
182 | device: torch.device,
183 | project_folder: str,
184 | print_log_freq: int = 100,
185 | wandb_log_freq: int = 10,
186 | current_epoch: int = 0,
187 | save_freq=1,
188 | use_wandb: bool = True,
189 | ):
190 |     """
191 |     Training loop for the CVAE prior model.
192 | 
193 |     Args:
194 |         train_model (bool): Whether to run the training phase.
195 |         model (nn.Module): The CVAE model to be trained.
196 |         optimizer (Adam): Optimizer used to update the model parameters.
197 |         lr_scheduler (StepLR): Learning rate scheduler, stepped once per epoch.
198 |         train_loader (DataLoader): Dataloader for the training data.
199 |         transform (transforms): Image transform applied to the observations.
200 |         epochs (int): Number of epochs to train, starting from current_epoch.
201 |         prior_policy (str): Name of the prior policy the CVAE is trained for.
202 |         project_folder (str): Directory where checkpoints and optimizer/scheduler states are saved every save_freq epochs.
203 |     """
204 |
205 | for epoch in range(current_epoch, current_epoch + epochs):
206 | if train_model:
207 | print(
208 | f"Start ViNT DP Training Epoch {epoch}/{current_epoch + epochs - 1}"
209 | )
210 | train_cvae(
211 | model=model,
212 | optimizer=optimizer,
213 | dataloader=train_loader,
214 | transform=transform,
215 | device=device,
216 | project_folder=project_folder,
217 | epoch=epoch,
218 | print_log_freq=print_log_freq,
219 | wandb_log_freq=wandb_log_freq,
220 | use_wandb=use_wandb,
221 | )
222 | lr_scheduler.step()
223 |
224 | # Save model checkpoints and optimizer state
225 | if (epoch + 1) % save_freq == 0:
226 | numbered_path = os.path.join(project_folder, f"cvae_{epoch}.pth")
227 | model.save_model(ckpt_path=numbered_path, epoch=epoch)
228 |
229 | # save optimizer
230 | latest_optimizer_path = os.path.join(project_folder, f"optimizer_latest.pth")
231 | torch.save(optimizer.state_dict(), latest_optimizer_path)
232 |
233 | # save scheduler
234 | latest_scheduler_path = os.path.join(project_folder, f"scheduler_latest.pth")
235 | torch.save(lr_scheduler.state_dict(), latest_scheduler_path)
236 |
237 | if lr_scheduler is not None:
238 | lr_scheduler.step()
239 |
240 | def load_model(model, model_type, checkpoint: dict) -> None:
241 | """Load model from checkpoint."""
242 | if model_type == "navibridge":
243 | state_dict = checkpoint
244 | model.load_state_dict(state_dict, strict=False)
245 | else:
246 | loaded_model = checkpoint["model"]
247 | try:
248 | state_dict = loaded_model.module.state_dict()
249 | model.load_state_dict(state_dict, strict=False)
250 | except AttributeError as e:
251 | state_dict = loaded_model.state_dict()
252 | model.load_state_dict(state_dict, strict=False)
253 |
254 |
255 | def load_ema_model(ema_model, state_dict: dict) -> None:
256 |     """Load EMA model weights from a state dict."""
257 | ema_model.load_state_dict(state_dict)
258 |
259 |
260 | def count_parameters(model):
261 | table = PrettyTable(["Modules", "Parameters"])
262 | total_params = 0
263 | for name, parameter in model.named_parameters():
264 | if not parameter.requires_grad: continue
265 | params = parameter.numel()
266 | table.add_row([name, params])
267 | total_params+=params
268 |
269 | print(f"Total Trainable Params: {total_params/1e6:.2f}M")
270 | return total_params
--------------------------------------------------------------------------------
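
The loop above writes several files per epoch: `{epoch}.pth` and `latest.pth` for the raw model weights, `ema_{epoch}.pth` for the EMA weights, and `optimizer_latest.pth` / `scheduler_latest.pth` for the optimizer and scheduler state. A hedged sketch of that layout (the project folder and epoch index below are placeholders, not repository defaults):

import os

project_folder, epoch = "logs/run", 4
checkpoints = {
    "model": os.path.join(project_folder, f"{epoch}.pth"),           # raw weights for this epoch
    "model_latest": os.path.join(project_folder, "latest.pth"),      # what load_model() in train.py reloads
    "ema": os.path.join(project_folder, f"ema_{epoch}.pth"),         # EMA weights for this epoch
    "optimizer": os.path.join(project_folder, "optimizer_latest.pth"),
    "scheduler": os.path.join(project_folder, "scheduler_latest.pth"),
}
# e.g. model.load_state_dict(torch.load(checkpoints["model_latest"]), strict=False)
for name, path in checkpoints.items():
    print(f"{name}: {path}")
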
/train/vint_train/process_data/process_data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import io
3 | import os
4 | import rosbag
5 | from PIL import Image
6 | import cv2
7 | from typing import Any, Tuple, List, Dict, Union
8 | import torchvision.transforms.functional as TF
9 |
10 | IMAGE_SIZE = (160, 120)
11 | IMAGE_ASPECT_RATIO = 4 / 3
12 |
13 |
14 | def process_images(im_list: List, img_process_func) -> List:
15 | """
16 | Process image data from a topic that publishes ros images into a list of PIL images
17 | """
18 | images = []
19 | for img_msg in im_list:
20 | img = img_process_func(img_msg)
21 | images.append(img)
22 | return images
23 |
24 |
25 | def process_tartan_img(msg) -> Image:
26 | """
27 | Process image data from a topic that publishes sensor_msgs/Image to a PIL image for the tartan_drive dataset
28 | """
29 | img = ros_to_numpy(msg, output_resolution=IMAGE_SIZE) * 255
30 | img = img.astype(np.uint8)
31 | # reverse the axis order to get the image in the right orientation
32 | img = np.moveaxis(img, 0, -1)
33 | # convert rgb to bgr
34 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
35 | img = Image.fromarray(img)
36 | return img
37 |
38 |
39 | def process_locobot_img(msg) -> Image:
40 | """
41 | Process image data from a topic that publishes sensor_msgs/Image to a PIL image for the locobot dataset
42 | """
43 | img = np.frombuffer(msg.data, dtype=np.uint8).reshape(
44 | msg.height, msg.width, -1)
45 | pil_image = Image.fromarray(img)
46 | return pil_image
47 |
48 |
49 | def process_scand_img(msg) -> Image:
50 | """
51 | Process image data from a topic that publishes sensor_msgs/CompressedImage to a PIL image for the scand dataset
52 | """
53 | # convert sensor_msgs/CompressedImage to PIL image
54 | img = Image.open(io.BytesIO(msg.data))
55 | # center crop image to 4:3 aspect ratio
56 | w, h = img.size
57 | img = TF.center_crop(
58 | img, (h, int(h * IMAGE_ASPECT_RATIO))
59 | ) # crop to the right ratio
60 | # resize image to IMAGE_SIZE
61 | img = img.resize(IMAGE_SIZE)
62 | return img
63 |
64 |
65 | ############## Add custom image processing functions here #############
66 |
67 | def process_sacson_img(msg) -> Image:
68 |     np_arr = np.frombuffer(msg.data, np.uint8)  # np.fromstring is deprecated
69 | image_np = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
70 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
71 | pil_image = Image.fromarray(image_np)
72 | return pil_image
73 |
74 |
75 | #######################################################################
76 |
77 |
78 | def process_odom(
79 | odom_list: List,
80 | odom_process_func: Any,
81 | ang_offset: float = 0.0,
82 | ) -> Dict[str, np.ndarray]:
83 | """
84 | Process odom data from a topic that publishes nav_msgs/Odometry into position and yaw
85 | """
86 | xys = []
87 | yaws = []
88 | for odom_msg in odom_list:
89 | xy, yaw = odom_process_func(odom_msg, ang_offset)
90 | xys.append(xy)
91 | yaws.append(yaw)
92 | return {"position": np.array(xys), "yaw": np.array(yaws)}
93 |
94 |
95 | def nav_to_xy_yaw(odom_msg, ang_offset: float) -> Tuple[List[float], float]:
96 | """
97 | Process odom data from a topic that publishes nav_msgs/Odometry into position
98 | """
99 |
100 | position = odom_msg.pose.pose.position
101 | orientation = odom_msg.pose.pose.orientation
102 | yaw = (
103 | quat_to_yaw(orientation.x, orientation.y, orientation.z, orientation.w)
104 | + ang_offset
105 | )
106 | return [position.x, position.y], yaw
107 |
108 |
109 | ############ Add custom odometry processing functions here ############
110 |
111 |
112 | #######################################################################
113 |
114 |
115 | def get_images_and_odom(
116 | bag: rosbag.Bag,
117 |     imtopics: Union[List[str], str],
118 |     odomtopics: Union[List[str], str],
119 | img_process_func: Any,
120 | odom_process_func: Any,
121 | rate: float = 4.0,
122 | ang_offset: float = 0.0,
123 | ):
124 | """
125 | Get image and odom data from a bag file
126 |
127 | Args:
128 | bag (rosbag.Bag): bag file
129 | imtopics (list[str] or str): topic name(s) for image data
130 | odomtopics (list[str] or str): topic name(s) for odom data
131 | img_process_func (Any): function to process image data
132 | odom_process_func (Any): function to process odom data
133 | rate (float, optional): rate to sample data. Defaults to 4.0.
134 | ang_offset (float, optional): angle offset to add to odom data. Defaults to 0.0.
135 | Returns:
136 | img_data (list): list of PIL images
137 | traj_data (list): list of odom data
138 | """
139 | # check if bag has both topics
140 | odomtopic = None
141 | imtopic = None
142 | if type(imtopics) == str:
143 | imtopic = imtopics
144 | else:
145 | for imt in imtopics:
146 | if bag.get_message_count(imt) > 0:
147 | imtopic = imt
148 | break
149 | if type(odomtopics) == str:
150 | odomtopic = odomtopics
151 | else:
152 | for ot in odomtopics:
153 | if bag.get_message_count(ot) > 0:
154 | odomtopic = ot
155 | break
156 | if not (imtopic and odomtopic):
157 | # bag doesn't have both topics
158 | return None, None
159 |
160 | synced_imdata = []
161 | synced_odomdata = []
162 | # get start time of bag in seconds
163 | currtime = bag.get_start_time()
164 |
165 | curr_imdata = None
166 | curr_odomdata = None
167 |
168 | for topic, msg, t in bag.read_messages(topics=[imtopic, odomtopic]):
169 | if topic == imtopic:
170 | curr_imdata = msg
171 | elif topic == odomtopic:
172 | curr_odomdata = msg
173 | if (t.to_sec() - currtime) >= 1.0 / rate:
174 | if curr_imdata is not None and curr_odomdata is not None:
175 | synced_imdata.append(curr_imdata)
176 | synced_odomdata.append(curr_odomdata)
177 | currtime = t.to_sec()
178 |
179 | img_data = process_images(synced_imdata, img_process_func)
180 | traj_data = process_odom(
181 | synced_odomdata,
182 | odom_process_func,
183 | ang_offset=ang_offset,
184 | )
185 |
186 | return img_data, traj_data
187 |
188 |
189 | def is_backwards(
190 | pos1: np.ndarray, yaw1: float, pos2: np.ndarray, eps: float = 1e-5
191 | ) -> bool:
192 | """
193 | Check if the trajectory is going backwards given the position and yaw of two points
194 | Args:
195 |         pos1, yaw1: position and yaw of the first point
196 |         pos2: position of the second point (eps is the tolerance on forward motion)
197 | """
198 | dx, dy = pos2 - pos1
199 | return dx * np.cos(yaw1) + dy * np.sin(yaw1) < eps
200 |
201 |
202 | # cut out non-positive velocity segments of the trajectory
203 | def filter_backwards(
204 | img_list: List[Image.Image],
205 | traj_data: Dict[str, np.ndarray],
206 | start_slack: int = 0,
207 | end_slack: int = 0,
208 | ) -> List[Tuple[List, Dict]]:
209 |     """
210 |     Cut out non-positive velocity segments of the trajectory
211 |     Args:
212 |         img_list: list of images
213 |         traj_data: dictionary of position and yaw data
214 |         start_slack: number of points to ignore at the start of the trajectory
215 |         end_slack: number of points to ignore at the end of the trajectory
216 |     Returns:
217 |         cut_trajs: list of (image list, {"position": ..., "yaw": ...}) pairs,
218 |             one pair for each forward-moving segment of the trajectory
219 |             (backward-moving segments are discarded)
220 |     """
221 | traj_pos = traj_data["position"]
222 | traj_yaws = traj_data["yaw"]
223 | cut_trajs = []
224 | start = True
225 |
226 | def process_pair(traj_pair: list) -> Tuple[List, Dict]:
227 | new_img_list, new_traj_data = zip(*traj_pair)
228 | new_traj_data = np.array(new_traj_data)
229 | new_traj_pos = new_traj_data[:, :2]
230 | new_traj_yaws = new_traj_data[:, 2]
231 | return (new_img_list, {"position": new_traj_pos, "yaw": new_traj_yaws})
232 |
233 | for i in range(max(start_slack, 1), len(traj_pos) - end_slack):
234 | pos1 = traj_pos[i - 1]
235 | yaw1 = traj_yaws[i - 1]
236 | pos2 = traj_pos[i]
237 | if not is_backwards(pos1, yaw1, pos2):
238 | if start:
239 | new_traj_pairs = [
240 | (img_list[i - 1], [*traj_pos[i - 1], traj_yaws[i - 1]])
241 | ]
242 | start = False
243 | elif i == len(traj_pos) - end_slack - 1:
244 | cut_trajs.append(process_pair(new_traj_pairs))
245 | else:
246 | new_traj_pairs.append(
247 | (img_list[i - 1], [*traj_pos[i - 1], traj_yaws[i - 1]])
248 | )
249 | elif not start:
250 | cut_trajs.append(process_pair(new_traj_pairs))
251 | start = True
252 | return cut_trajs
253 |
254 |
255 | def quat_to_yaw(
256 | x: np.ndarray,
257 | y: np.ndarray,
258 | z: np.ndarray,
259 | w: np.ndarray,
260 | ) -> np.ndarray:
261 | """
262 | Convert a batch quaternion into a yaw angle
263 | yaw is rotation around z in radians (counterclockwise)
264 | """
265 | t3 = 2.0 * (w * z + x * y)
266 | t4 = 1.0 - 2.0 * (y * y + z * z)
267 | yaw = np.arctan2(t3, t4)
268 | return yaw
269 |
270 |
271 | def ros_to_numpy(
272 | msg, nchannels=3, empty_value=None, output_resolution=None, aggregate="none"
273 | ):
274 | """
275 | Convert a ROS image message to a numpy array
276 | """
277 | if output_resolution is None:
278 | output_resolution = (msg.width, msg.height)
279 |
280 | is_rgb = "8" in msg.encoding
281 | if is_rgb:
282 | data = np.frombuffer(msg.data, dtype=np.uint8).copy()
283 | else:
284 | data = np.frombuffer(msg.data, dtype=np.float32).copy()
285 |
286 | data = data.reshape(msg.height, msg.width, nchannels)
287 |
288 | if empty_value:
289 | mask = np.isclose(abs(data), empty_value)
290 | fill_value = np.percentile(data[~mask], 99)
291 | data[mask] = fill_value
292 |
293 | data = cv2.resize(
294 | data,
295 | dsize=(output_resolution[0], output_resolution[1]),
296 | interpolation=cv2.INTER_AREA,
297 | )
298 |
299 | if aggregate == "littleendian":
300 | data = sum([data[:, :, i] * (256**i) for i in range(nchannels)])
301 | elif aggregate == "bigendian":
302 | data = sum([data[:, :, -(i + 1)] * (256**i) for i in range(nchannels)])
303 |
304 | if len(data.shape) == 2:
305 | data = np.expand_dims(data, axis=0)
306 | else:
307 | data = np.moveaxis(data, 2, 0) # Switch to channels-first
308 |
309 | if is_rgb:
310 | data = data.astype(np.float32) / (
311 | 255.0 if aggregate == "none" else 255.0**nchannels
312 | )
313 |
314 | return data
315 |
--------------------------------------------------------------------------------
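
As a quick sanity check of the quaternion-to-yaw conversion above (a worked example, not repository code): a 90-degree rotation about the z axis corresponds to the quaternion (x, y, z, w) = (0, 0, sin(pi/4), cos(pi/4)) and should map to a yaw of pi/2.

import numpy as np

x, y, z, w = 0.0, 0.0, np.sin(np.pi / 4), np.cos(np.pi / 4)
t3 = 2.0 * (w * z + x * y)          # same expressions as quat_to_yaw above
t4 = 1.0 - 2.0 * (y * y + z * z)
yaw = np.arctan2(t3, t4)
print(np.isclose(yaw, np.pi / 2))   # True
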
/train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import wandb
4 | import argparse
5 | import numpy as np
6 | import yaml
7 | import time
8 | import pdb
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import DataLoader, ConcatDataset
13 |
14 | from torchvision import transforms
15 | import torch.backends.cudnn as cudnn
16 |
17 | from vint_train.models.model_utils import create_noise_scheduler, create_model, get_optimizer_and_scheduler
18 |
19 |
20 | from vint_train.data.vint_dataset import ViNT_Dataset
21 | from vint_train.training.train_eval_loop import (
22 | load_model,
23 | train_eval_loop_navibridge,
24 | train_eval_loop_cvae,
25 | )
26 |
27 |
28 | def main(config):
29 | assert config["distance"]["min_dist_cat"] < config["distance"]["max_dist_cat"]
30 | assert config["action"]["min_dist_cat"] < config["action"]["max_dist_cat"]
31 |
32 | if torch.cuda.is_available():
33 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
34 | if "gpu_ids" not in config:
35 | config["gpu_ids"] = [0]
36 | elif type(config["gpu_ids"]) == int:
37 | config["gpu_ids"] = [config["gpu_ids"]]
38 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
39 | [str(x) for x in config["gpu_ids"]]
40 | )
41 | print("Using cuda devices:", os.environ["CUDA_VISIBLE_DEVICES"])
42 | else:
43 | print("Using cpu")
44 |
45 | first_gpu_id = config["gpu_ids"][0]
46 | device = torch.device(
47 | f"cuda:{first_gpu_id}" if torch.cuda.is_available() else "cpu"
48 | )
49 |
50 | if "seed" in config:
51 | np.random.seed(config["seed"])
52 | torch.manual_seed(config["seed"])
53 | cudnn.deterministic = True
54 |
55 | cudnn.benchmark = True # good if input sizes don't vary
56 | transform = ([
57 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
58 | ])
59 | transform = transforms.Compose(transform)
60 |
61 | if config["model_type"] == "cvae":
62 | if "train_params" in config:
63 | config.update(config["train_params"])
64 | if "diffuse_params" in config:
65 | config.update(config["diffuse_params"])
66 |
67 | if config.get("prior_policy", None) == "cvae":
68 | if "diffuse_params" in config:
69 | config.update(config["diffuse_params"])
70 |
71 | # Load the data
72 | train_dataset = []
73 | test_dataloaders = {}
74 |
75 | if "context_type" not in config:
76 | config["context_type"] = "temporal"
77 |
78 | if "clip_goals" not in config:
79 | config["clip_goals"] = False
80 |
81 | for dataset_name in config["datasets"]:
82 | data_config = config["datasets"][dataset_name]
83 | if "negative_mining" not in data_config:
84 | data_config["negative_mining"] = True
85 | if "goals_per_obs" not in data_config:
86 | data_config["goals_per_obs"] = 1
87 | if "end_slack" not in data_config:
88 | data_config["end_slack"] = 0
89 | if "waypoint_spacing" not in data_config:
90 | data_config["waypoint_spacing"] = 1
91 |
92 | for data_split_type in ["train", "test"]:
93 | if data_split_type in data_config:
94 | dataset = ViNT_Dataset(
95 | data_folder=data_config["data_folder"],
96 | data_split_folder=data_config[data_split_type],
97 | dataset_name=dataset_name,
98 | image_size=config["image_size"],
99 | waypoint_spacing=data_config["waypoint_spacing"],
100 | min_dist_cat=config["distance"]["min_dist_cat"],
101 | max_dist_cat=config["distance"]["max_dist_cat"],
102 | min_action_distance=config["action"]["min_dist_cat"],
103 | max_action_distance=config["action"]["max_dist_cat"],
104 | negative_mining=data_config["negative_mining"],
105 | len_traj_pred=config["len_traj_pred"],
106 | learn_angle=config["learn_angle"],
107 | context_size=config["context_size"],
108 | context_type=config["context_type"],
109 | end_slack=data_config["end_slack"],
110 | goals_per_obs=data_config["goals_per_obs"],
111 | normalize=config["normalize"],
112 | goal_type=config["goal_type"],
113 | angle_ranges=config.get("angle_ranges", None),
114 | )
115 | if data_split_type == "train":
116 | train_dataset.append(dataset)
117 | else:
118 | dataset_type = f"{dataset_name}_{data_split_type}"
119 | if dataset_type not in test_dataloaders:
120 | test_dataloaders[dataset_type] = {}
121 | test_dataloaders[dataset_type] = dataset
122 |
123 | # combine all the datasets from different robots
124 | train_dataset = ConcatDataset(train_dataset)
125 |
126 | train_loader = DataLoader(
127 | train_dataset,
128 | batch_size=config["batch_size"],
129 | shuffle=True,
130 | num_workers=config["num_workers"],
131 | drop_last=False,
132 | persistent_workers=True,
133 | )
134 |
135 | if "eval_batch_size" not in config:
136 | config["eval_batch_size"] = config["batch_size"]
137 |
138 | for dataset_type, dataset in test_dataloaders.items():
139 | test_dataloaders[dataset_type] = DataLoader(
140 | dataset,
141 | batch_size=config["eval_batch_size"],
142 | shuffle=True,
143 | num_workers=0,
144 | drop_last=False,
145 | )
146 |
147 | model = create_model(config, device)
148 | if config["model_type"] == "navibridge":
149 | noise_scheduler, diffusion = create_noise_scheduler(config)
150 | elif config["model_type"] == "cvae":
151 | model.load_model(model_args=config, device=device)
152 | if config["clipping"]:
153 | print("Clipping gradients to", config["max_norm"])
154 | for p in model.parameters():
155 | if not p.requires_grad:
156 | continue
157 | p.register_hook(
158 | lambda grad: torch.clamp(
159 | grad, -1 * config["max_norm"], config["max_norm"]
160 | )
161 | )
162 |
163 | if config["model_type"] == "cvae":
164 | optimizer, scheduler = get_optimizer_and_scheduler(config, model.net)
165 | else:
166 | optimizer, scheduler = get_optimizer_and_scheduler(config, model)
167 | current_epoch = 0
168 | if "load_run" in config:
169 | load_project_folder = os.path.join("logs", config["load_run"])
170 | print("Loading model from ", load_project_folder)
171 | latest_path = os.path.join(load_project_folder, "latest.pth")
172 |         latest_checkpoint = torch.load(latest_path)  # optionally pass map_location=device here
173 | load_model(model, config["model_type"], latest_checkpoint)
174 | if "epoch" in latest_checkpoint:
175 | current_epoch = latest_checkpoint["epoch"] + 1
176 |
177 | # Multi-GPU
178 | if len(config["gpu_ids"]) > 1:
179 | model = nn.DataParallel(model, device_ids=config["gpu_ids"])
180 | if config["model_type"] == "cvae":
181 | model.net = model.net.to(device)
182 | else:
183 | model = model.to(device)
184 |
185 | if "load_run" in config: # load optimizer and scheduler after data parallel
186 | if "optimizer" in latest_checkpoint:
187 | optimizer.load_state_dict(latest_checkpoint["optimizer"].state_dict())
188 | if scheduler is not None and "scheduler" in latest_checkpoint:
189 | scheduler.load_state_dict(latest_checkpoint["scheduler"].state_dict())
190 |
191 | if config["model_type"] == "navibridge":
192 | train_eval_loop_navibridge(
193 | train_model=config["train"],
194 | model=model,
195 |             diffusion=diffusion,
196 | optimizer=optimizer,
197 | lr_scheduler=scheduler,
198 | noise_scheduler=noise_scheduler,
199 | prior_policy=config["prior_policy"],
200 | train_loader=train_loader,
201 | test_dataloaders=test_dataloaders,
202 | transform=transform,
203 | goal_mask_prob=config["goal_mask_prob"],
204 | epochs=config["epochs"],
205 | device=device,
206 | project_folder=config["project_folder"],
207 | print_log_freq=config["print_log_freq"],
208 | wandb_log_freq=config["wandb_log_freq"],
209 | image_log_freq=config["image_log_freq"],
210 | num_images_log=config["num_images_log"],
211 | current_epoch=current_epoch,
212 | alpha=float(config["alpha"]),
213 | use_wandb=config["use_wandb"],
214 | eval_fraction=config["eval_fraction"],
215 | eval_freq=config["eval_freq"],
216 | # ddbm params
217 | steps=config["num_diffusion_iters"],
218 | clip_denoised=config["clip_denoised"],
219 | sampler=config["sampler"],
220 | sigma_min=diffusion.sigma_min,
221 | sigma_max=diffusion.sigma_max,
222 | churn_step_ratio=config["churn_step_ratio"],
223 | rho=config["rho"],
224 | guidance=config["guidance"],
225 | )
226 | elif config["model_type"] == "cvae":
227 | train_eval_loop_cvae(
228 | train_model=config["train"],
229 | model=model,
230 | optimizer=optimizer,
231 | lr_scheduler=scheduler,
232 | train_loader=train_loader,
233 | transform=transform,
234 | epochs=config["epochs"],
235 | prior_policy=config["prior_policy"],
236 | device=device,
237 | project_folder=config["project_folder"],
238 | print_log_freq=config["print_log_freq"],
239 | wandb_log_freq=config["wandb_log_freq"],
240 | current_epoch=current_epoch,
241 | save_freq=config["save_freq"],
242 | use_wandb=config["use_wandb"],
243 | )
244 | print("FINISHED TRAINING")
245 |
246 |
247 | if __name__ == "__main__":
248 | torch.multiprocessing.set_start_method("spawn")
249 |
250 | parser = argparse.ArgumentParser(description="Visual Navigation Transformer")
251 |
252 | # project setup
253 | parser.add_argument(
254 | "--config",
255 | "-c",
256 | default="config/navibridge.yaml",
257 | type=str,
258 |         help="Path to the config file in the config folder",
259 | )
260 | args = parser.parse_args()
261 |
262 | with open("config/defaults.yaml", "r") as f:
263 | default_config = yaml.safe_load(f)
264 |
265 | config = default_config
266 |
267 | with open(args.config, "r") as f:
268 | user_config = yaml.safe_load(f)
269 |
270 | config.update(user_config)
271 |
272 | config["run_name"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
273 | config["project_folder"] = os.path.join(
274 | "logs", config["project_name"], config["run_name"]
275 | )
276 | os.makedirs(
277 | config[
278 | "project_folder"
279 |         ],  # should error if the dir already exists, to avoid overwriting an old project
280 | )
281 |
282 | if config["use_wandb"]:
283 | wandb.login()
284 | wandb.init(
285 | project=config["project_name"],
286 | settings=wandb.Settings(start_method="fork"),
287 | entity="offline", # TODO: change this to your wandb entity
288 | mode="offline",
289 | )
290 |
291 | # Manually copy the config file to the wandb run directory
292 | config_filename = os.path.basename(args.config)
293 | dest_config_path = os.path.join(wandb.run.dir, config_filename)
294 | shutil.copyfile(args.config, dest_config_path)
295 |
296 | # Save the copied config file to wandb
297 | wandb.save(dest_config_path, policy="now")
298 |
299 | wandb.run.name = config["run_name"]
300 | # Update the wandb args with the training configurations
301 | if wandb.run:
302 | wandb.config.update(config)
303 |
304 | print(config)
305 | main(config)
306 |
--------------------------------------------------------------------------------
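
Configuration in train.py is composed by shallow-merging the user YAML over config/defaults.yaml, so any top-level key in the user file completely replaces the default value (nested dicts are not merged recursively). A small illustration with made-up file contents:

import time
import yaml

defaults = yaml.safe_load("""
run_name: navibridge
batch_size: 256
use_wandb: true
""")
user = yaml.safe_load("""
batch_size: 64
use_wandb: false
""")

config = dict(defaults)
config.update(user)                      # shallow merge, as in train.py
config["run_name"] += "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
print(config["batch_size"], config["use_wandb"])   # 64 False
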
/train/vint_train/models/navibridge/ddbm/karras_diffusion.py:
--------------------------------------------------------------------------------
1 | """
2 | based on: https://github.com/crowsonkb/k-diffusion
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from piq import LPIPS
10 |
11 | from .nn import mean_flat, append_dims, append_zero
12 |
13 | from functools import partial
14 |
15 |
16 | def vp_logsnr(t, beta_d, beta_min):
17 | t = th.as_tensor(t)
18 | return - th.log((0.5 * beta_d * (t ** 2) + beta_min * t).exp() - 1)
19 |
20 | def vp_logs(t, beta_d, beta_min):
21 | t = th.as_tensor(t)
22 | return -0.25 * t ** 2 * (beta_d) - 0.5 * t * beta_min
23 |
24 | class KarrasDenoiser:
25 | def __init__(
26 | self,
27 | sigma_data: float = 0.5,
28 | sigma_max=80.0,
29 | sigma_min=0.002,
30 | beta_d=2,
31 | beta_min=0.1,
32 | cov_xy=0.,
33 | rho=7.0,
34 | image_size=64,
35 | num_timesteps=10,
36 | weight_schedule="karras",
37 | pred_mode='both',
38 | loss_norm="lpips",
39 | ):
40 | self.sigma_data = sigma_data
41 | self.sigma_max = sigma_max
42 | self.sigma_min = sigma_min
43 |
44 | self.beta_d = beta_d
45 | self.beta_min = beta_min
46 |
47 | self.sigma_data_end = self.sigma_data
48 | self.cov_xy = cov_xy
49 |
50 | self.c = 1
51 |
52 | self.weight_schedule = weight_schedule
53 | self.pred_mode = pred_mode
54 | self.loss_norm = loss_norm
55 | self.rho = rho
56 | self.num_timesteps = num_timesteps
57 | self.image_size = image_size
58 |
59 | def get_snr(self, sigmas):
60 | if self.pred_mode.startswith('vp'):
61 | return vp_logsnr(sigmas, self.beta_d, self.beta_min).exp()
62 | else:
63 | return sigmas**-2
64 |
65 | def get_sigmas(self, sigmas):
66 | return sigmas
67 |
68 | def get_weightings(self, sigma):
69 | snrs = self.get_snr(sigma)
70 |
71 | if self.weight_schedule == "snr":
72 | weightings = snrs
73 | elif self.weight_schedule == "snr+1":
74 | weightings = snrs + 1
75 | elif self.weight_schedule == "karras":
76 | weightings = snrs + 1.0 / self.sigma_data**2
77 | elif self.weight_schedule.startswith("bridge_karras"):
78 | if self.pred_mode == 've':
79 | A = sigma**4 / self.sigma_max**4 * self.sigma_data_end**2 + (1 - sigma**2 / self.sigma_max**2)**2 * self.sigma_data**2 + 2*sigma**2 / self.sigma_max**2 * (1 - sigma**2 / self.sigma_max**2) * self.cov_xy + self.c**2 * sigma**2 * (1 - sigma**2 / self.sigma_max**2)
80 | weightings = A / ((sigma/self.sigma_max)**4 * (self.sigma_data_end**2 * self.sigma_data**2 - self.cov_xy**2) + self.sigma_data**2 * self.c**2 * sigma**2 * (1 - sigma**2/self.sigma_max**2) )
81 |
82 | elif self.pred_mode == 'vp':
83 | logsnr_t = vp_logsnr(sigma, self.beta_d, self.beta_min)
84 | logsnr_T = vp_logsnr(1, self.beta_d, self.beta_min)
85 | logs_t = vp_logs(sigma, self.beta_d, self.beta_min)
86 | logs_T = vp_logs(1, self.beta_d, self.beta_min)
87 |
88 | a_t = (logsnr_T - logsnr_t + logs_t - logs_T).exp()
89 | b_t = -th.expm1(logsnr_T - logsnr_t) * logs_t.exp()
90 | c_t = -th.expm1(logsnr_T - logsnr_t) * (2*logs_t - logsnr_t).exp()
91 |
92 | A = a_t**2 * self.sigma_data_end**2 + b_t**2 * self.sigma_data**2 + 2*a_t * b_t * self.cov_xy + self.c**2 * c_t
93 | weightings = A / (a_t**2 * (self.sigma_data_end**2 * self.sigma_data**2 - self.cov_xy**2) + self.sigma_data**2 * self.c**2 * c_t )
94 |
95 | elif self.pred_mode == 'vp_simple' or self.pred_mode == 've_simple':
96 | weightings = th.ones_like(snrs)
97 | elif self.weight_schedule == "truncated-snr":
98 | weightings = th.clamp(snrs, min=1.0)
99 | elif self.weight_schedule == "uniform":
100 | weightings = th.ones_like(snrs)
101 | else:
102 | raise NotImplementedError()
103 |
104 | weightings = th.where(th.isfinite(weightings), weightings, th.zeros_like(weightings))
105 | return weightings
106 |
107 | def get_bridge_scalings(self, sigma):
108 | if self.pred_mode == 've':
109 | A = sigma**4 / self.sigma_max**4 * self.sigma_data_end**2 + \
110 | (1 - sigma**2 / self.sigma_max**2)**2 * self.sigma_data**2 + \
111 | 2*sigma**2 / self.sigma_max**2 * (1 - sigma**2 / self.sigma_max**2) * self.cov_xy + \
112 | self.c **2 * sigma**2 * (1 - sigma**2 / self.sigma_max**2)
113 | c_in = 1 / (A) ** 0.5
114 | c_skip = ((1 - sigma**2 / self.sigma_max**2) * self.sigma_data**2 + sigma**2 / self.sigma_max**2 * self.cov_xy) / A
115 | c_out = ((sigma/self.sigma_max)**4 * (self.sigma_data_end**2 * self.sigma_data**2 - self.cov_xy**2) + \
116 | self.sigma_data**2 * self.c **2 * sigma**2 * (1 - sigma**2/self.sigma_max**2) )**0.5 * c_in
117 | return c_skip, c_out, c_in
118 |
119 | elif self.pred_mode == 'vp':
120 | logsnr_t = vp_logsnr(sigma, self.beta_d, self.beta_min)
121 | logsnr_T = vp_logsnr(1, self.beta_d, self.beta_min)
122 | logs_t = vp_logs(sigma, self.beta_d, self.beta_min)
123 | logs_T = vp_logs(1, self.beta_d, self.beta_min)
124 |
125 | a_t = (logsnr_T - logsnr_t + logs_t - logs_T).exp()
126 | b_t = -th.expm1(logsnr_T - logsnr_t) * logs_t.exp()
127 | c_t = -th.expm1(logsnr_T - logsnr_t) * (2*logs_t - logsnr_t).exp()
128 |
129 | A = a_t**2 * self.sigma_data_end**2 + b_t**2 * self.sigma_data**2 + \
130 | 2*a_t * b_t * self.cov_xy + self.c**2 * c_t
131 |
132 | c_in = 1 / (A) ** 0.5
133 | c_skip = (b_t * self.sigma_data**2 + a_t * self.cov_xy) / A
134 | c_out = (a_t**2 * (self.sigma_data_end**2 * self.sigma_data**2 - self.cov_xy**2) + \
135 | self.sigma_data**2 * self.c **2 * c_t )**0.5 * c_in
136 | return c_skip, c_out, c_in
137 |
138 | elif self.pred_mode == 've_simple' or self.pred_mode == 'vp_simple':
139 | c_in = th.ones_like(sigma)
140 | c_out = th.ones_like(sigma)
141 | c_skip = th.zeros_like(sigma)
142 | return c_skip, c_out, c_in
143 |
144 | def training_bridge_losses(self, model, x_start, sigmas, global_cond=None, model_kwargs=None, noise=None, vae=None, train=True):
145 | assert model_kwargs is not None
146 | xT = model_kwargs['xT']
147 | if noise is None:
148 | noise = th.randn_like(x_start)
149 | sigmas = th.minimum(sigmas, th.ones_like(sigmas) * self.sigma_max)
150 | terms = {}
151 |
152 | dims = x_start.ndim
153 |
154 | def bridge_sample(x0, xT, t):
155 | t = append_dims(t, dims)
156 | if self.pred_mode.startswith('ve'):
157 | std_t = t * th.sqrt(1 - t**2 / self.sigma_max**2)
158 | mu_t = t**2 / self.sigma_max**2 * xT + (1 - t**2 / self.sigma_max**2) * x0
159 | samples = (mu_t + std_t * noise)
160 | elif self.pred_mode.startswith('vp'):
161 | logsnr_t = vp_logsnr(t, self.beta_d, self.beta_min)
162 | logsnr_T = vp_logsnr(self.sigma_max, self.beta_d, self.beta_min)
163 | logs_t = vp_logs(t, self.beta_d, self.beta_min)
164 | logs_T = vp_logs(self.sigma_max, self.beta_d, self.beta_min)
165 |
166 | a_t = (logsnr_T - logsnr_t + logs_t - logs_T).exp()
167 | b_t = -th.expm1(logsnr_T - logsnr_t) * logs_t.exp()
168 | std_t = (-th.expm1(logsnr_T - logsnr_t)).sqrt() * (logs_t - logsnr_t/2).exp()
169 |
170 | samples = a_t * xT + b_t * x0 + std_t * noise
171 | return samples
172 |
173 | x_t = bridge_sample(x_start, xT, sigmas)
174 |
175 | model_output, denoised = self.denoise(model, x_t, sigmas, global_cond, **model_kwargs)
176 |
177 | weights = self.get_weightings(sigmas)
178 | weights = append_dims((weights), dims)
179 |
180 | terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)
181 | terms["mse"] = mean_flat(weights * (denoised - x_start) ** 2)
182 | if th.isnan(terms["mse"]).any():
183 | import ipdb;ipdb.set_trace()
184 |
185 | if "vb" in terms:
186 | terms["loss"] = terms["mse"] + terms["vb"]
187 | else:
188 | terms["loss"] = terms["mse"]
189 | return terms, denoised
190 |
191 | def denoise(self, model, x_t, sigmas, global_cond=None, **model_kwargs):
192 | c_skip, c_out, c_in = [
193 | append_dims(x, x_t.ndim) for x in self.get_bridge_scalings(sigmas)
194 | ]
195 |
196 | rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
197 |
198 | model_output = model("noise_pred_net",
199 | sample=c_in * x_t,
200 | timestep=rescaled_t,
201 | global_cond=global_cond,
202 | **model_kwargs)
203 | denoised = c_out * model_output + c_skip * x_t
204 | return model_output, denoised
205 |
206 | def karras_sample(
207 | diffusion,
208 | model,
209 | x_T,
210 | x_0,
211 | steps,
212 | clip_denoised=True,
213 | progress=False,
214 | callback=None,
215 | model_kwargs=None,
216 | global_cond=None,
217 | device=None,
218 | sigma_min=0.002,
219 | sigma_max=80,
220 | rho=7.0,
221 | sampler="heun",
222 | churn_step_ratio=0.,
223 | guidance=1,
224 | ):
225 | assert sampler in ["heun", ], 'only heun sampler is supported currently'
226 |
227 | sigmas = get_sigmas_karras(steps, sigma_min, sigma_max-1e-4, rho, device=device)
228 |
229 | sample_fn = {
230 | "heun": partial(sample_heun, beta_d=diffusion.beta_d, beta_min=diffusion.beta_min),
231 | }[sampler]
232 |
233 | sampler_args = dict(
234 | pred_mode=diffusion.pred_mode, churn_step_ratio=churn_step_ratio, sigma_max=sigma_max
235 | )
236 |
237 | def denoiser(x_t, sigma, x_T=None, global_cond=None):
238 | _, denoised = diffusion.denoise(model, x_t, sigma, global_cond)
239 | if clip_denoised:
240 | denoised = denoised.clamp(-1, 1)
241 | return denoised
242 |
243 | x_0, path, nfe = sample_fn(
244 | denoiser,
245 | x_T,
246 | sigmas,
247 | global_cond=global_cond,
248 | progress=progress,
249 | callback=callback,
250 | guidance=guidance,
251 | **sampler_args,
252 | )
253 | print('nfe:', nfe)
254 |
255 | return x_0.clamp(-1, 1), [x.clamp(-1, 1) for x in path], nfe
256 |
257 | def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
258 | ramp = th.linspace(0, 1, n)
259 | min_inv_rho = sigma_min ** (1 / rho)
260 | max_inv_rho = sigma_max ** (1 / rho)
261 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
262 | return append_zero(sigmas).to(device)
263 |
264 | def get_bridge_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, eps=1e-4, device="cpu"):
265 | sigma_t_crit = sigma_max / np.sqrt(2)
266 | min_start_inv_rho = sigma_min ** (1 / rho)
267 | max_inv_rho = sigma_t_crit ** (1 / rho)
268 |
269 | sigmas_second_half = (max_inv_rho + th.linspace(0, 1, n//2 ) * (min_start_inv_rho - max_inv_rho)) ** rho
270 |
271 | sigmas_first_half = sigma_max - ((sigma_max - sigma_t_crit) ** (1 / rho) + th.linspace(0, 1, n - n//2 +1 ) * (eps ** (1 / rho) - (sigma_max - sigma_t_crit) ** (1 / rho))) ** rho
272 | sigmas = th.cat([sigmas_first_half.flip(0)[:-1], sigmas_second_half])
273 | return append_zero(sigmas).to(device)
274 |
275 | def to_d(x, sigma, denoised, x_T, sigma_max, w=1, stochastic=False):
276 | grad_pxtlx0 = (denoised - x) / append_dims(sigma**2, x.ndim)
277 | grad_pxTlxt = (x_T - x) / (append_dims(th.ones_like(sigma)*sigma_max**2, x.ndim) - append_dims(sigma**2, x.ndim))
278 | gt2 = 2 * sigma
279 | d = -(0.5 if not stochastic else 1) * gt2 * (grad_pxtlx0 - w * grad_pxTlxt * (0 if stochastic else 1))
280 | if stochastic:
281 | return d, gt2
282 | else:
283 | return d
284 |
285 | def get_d_vp(x, denoised, x_T, std_t, logsnr_t, logsnr_T, logs_t, logs_T, s_t_deriv, sigma_t, sigma_t_deriv, w, stochastic=False):
286 |
287 | a_t = (logsnr_T - logsnr_t + logs_t - logs_T).exp()
288 | b_t = -th.expm1(logsnr_T - logsnr_t) * logs_t.exp()
289 |
290 | mu_t = a_t * x_T + b_t * denoised
291 |
292 | grad_logq = -(x - mu_t) / std_t**2 / (-th.expm1(logsnr_T - logsnr_t))
293 |
294 | grad_logpxTlxt = -(x - th.exp(logs_t - logs_T) * x_T) / std_t**2 / th.expm1(logsnr_t - logsnr_T)
295 |
296 | f = s_t_deriv * (-logs_t).exp() * x
297 | gt2 = 2 * (logs_t).exp()**2 * sigma_t * sigma_t_deriv
298 |
299 | d = f - gt2 * ((0.5 if not stochastic else 1) * grad_logq - w * grad_logpxTlxt)
300 | if stochastic:
301 | return d, gt2
302 | else:
303 | return d
304 |
305 | @th.no_grad()
306 | def sample_heun(
307 | denoiser,
308 | x,
309 | sigmas,
310 | global_cond=None,
311 | pred_mode='both',
312 | progress=False,
313 | callback=None,
314 | sigma_max=80.0,
315 | beta_d=2,
316 | beta_min=0.1,
317 | churn_step_ratio=0.,
318 | guidance=1,
319 | ):
320 | x_T = x
321 | path = [x]
322 |
323 | s_in = x.new_ones([x.shape[0]])
324 | indices = range(len(sigmas) - 1)
325 |
326 | if progress:
327 | from tqdm.auto import tqdm
328 | indices = tqdm(indices)
329 |
330 | nfe = 0
331 | assert churn_step_ratio < 1
332 |
333 | if pred_mode.startswith('vp'):
334 | vp_snr_sqrt_reciprocal = lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
335 | vp_snr_sqrt_reciprocal_deriv = lambda t: 0.5 * (beta_min + beta_d * t) * (vp_snr_sqrt_reciprocal(t) + 1 / vp_snr_sqrt_reciprocal(t))
336 | s = lambda t: (1 + vp_snr_sqrt_reciprocal(t) ** 2).rsqrt()
337 | s_deriv = lambda t: -vp_snr_sqrt_reciprocal(t) * vp_snr_sqrt_reciprocal_deriv(t) * (s(t) ** 3)
338 | logs = lambda t: -0.25 * t ** 2 * (beta_d) - 0.5 * t * beta_min
339 | std = lambda t: vp_snr_sqrt_reciprocal(t) * s(t)
340 | logsnr = lambda t : -2 * th.log(vp_snr_sqrt_reciprocal(t))
341 |
342 | logsnr_T = logsnr(th.as_tensor(sigma_max))
343 | logs_T = logs(th.as_tensor(sigma_max))
344 |
345 | for j, i in enumerate(indices):
346 | if churn_step_ratio > 0:
347 | sigma_hat = (sigmas[i+1] - sigmas[i]) * churn_step_ratio + sigmas[i]
348 |
349 | denoised = denoiser(x, sigmas[i] * s_in, x_T, global_cond)
350 | if pred_mode == 've':
351 | d_1, gt2 = to_d(x, sigmas[i], denoised, x_T, sigma_max, w=guidance, stochastic=True)
352 | elif pred_mode.startswith('vp'):
353 | d_1, gt2 = get_d_vp(x, denoised, x_T, std(sigmas[i]), logsnr(sigmas[i]), logsnr_T, logs(sigmas[i]), logs_T, s_deriv(sigmas[i]), vp_snr_sqrt_reciprocal(sigmas[i]), vp_snr_sqrt_reciprocal_deriv(sigmas[i]), guidance, stochastic=True)
354 |
355 | dt = (sigma_hat - sigmas[i])
356 | x = x + d_1 * dt + th.randn_like(x) * (dt.abs() ** 0.5) * gt2.sqrt()
357 |
358 | nfe += 1
359 | path.append(x.detach().cpu())
360 | else:
361 | sigma_hat = sigmas[i]
362 |
363 | denoised = denoiser(x, sigma_hat * s_in, x_T, global_cond)
364 | if pred_mode == 've':
365 | d = to_d(x, sigma_hat, denoised, x_T, sigma_max, w=guidance)
366 | elif pred_mode.startswith('vp'):
367 | d = get_d_vp(x, denoised, x_T, std(sigma_hat), logsnr(sigma_hat), logsnr_T, logs(sigma_hat), logs_T, s_deriv(sigma_hat), vp_snr_sqrt_reciprocal(sigma_hat), vp_snr_sqrt_reciprocal_deriv(sigma_hat), guidance)
368 |
369 | nfe += 1
370 | if callback is not None:
371 | callback(
372 | {
373 | "x": x,
374 | "i": i,
375 | "sigma": sigmas[i],
376 | "sigma_hat": sigma_hat,
377 | "denoised": denoised,
378 | }
379 | )
380 |
381 | dt = sigmas[i + 1] - sigma_hat
382 |
383 | if sigmas[i + 1] == 0:
384 | x = x + d * dt
385 | else:
386 | x_2 = x + d * dt
387 | denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in, x_T, global_cond)
388 | if pred_mode == 've':
389 | d_2 = to_d(x_2, sigmas[i + 1], denoised_2, x_T, sigma_max, w=guidance)
390 | elif pred_mode.startswith('vp'):
391 | d_2 = get_d_vp(x_2, denoised_2, x_T, std(sigmas[i + 1]), logsnr(sigmas[i + 1]), logsnr_T, logs(sigmas[i + 1]), logs_T, s_deriv(sigmas[i + 1]), vp_snr_sqrt_reciprocal(sigmas[i + 1]), vp_snr_sqrt_reciprocal_deriv(sigmas[i + 1]), guidance)
392 |
393 | d_prime = (d + d_2) / 2
394 | x = x + d_prime * dt
395 | nfe += 1
396 |
397 | path.append(x.detach().cpu())
398 |
399 | return x, path, nfe
400 |
401 | @th.no_grad()
402 | def forward_sample(
403 | x0,
404 | y0,
405 | sigma_max,
406 | ):
407 | ts = th.linspace(0, sigma_max, 120)
408 | x = x0
409 | path = [x]
410 |
411 | for t in ts:
412 | std_t = th.sqrt(t) * th.sqrt(1 - t / sigma_max)
413 | mu_t = t / sigma_max * y0 + (1 - t / sigma_max) * x0
414 | xt = mu_t + std_t * th.randn_like(x0)
415 | path.append(xt)
416 |
417 | path.append(y0)
418 |
419 | return path
420 |
--------------------------------------------------------------------------------
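
For intuition, get_sigmas_karras above builds the usual rho-spaced noise schedule: interpolate linearly in sigma^(1/rho) space between sigma_max and sigma_min, raise back to the rho-th power, and append a trailing zero. A standalone sketch with a small n (values chosen only for illustration):

import torch as th

n, sigma_min, sigma_max, rho = 5, 0.002, 80.0, 7.0
ramp = th.linspace(0, 1, n)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
sigmas = th.cat([sigmas, sigmas.new_zeros(1)])     # same role as append_zero above
print(sigmas)   # monotonically decreasing from 80.0 down to 0.002, then the appended 0.0
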
/train/vint_train/data/vint_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pickle
4 | import yaml
5 | from typing import Any, Dict, List, Optional, Tuple
6 | import tqdm
7 | import io
8 | import lmdb
9 |
10 | import torch
11 | from torch.utils.data import Dataset
12 | import torchvision.transforms.functional as TF
13 |
14 | from vint_train.data.data_utils import (
15 | img_path_to_data,
16 | calculate_sin_cos,
17 | get_data_path,
18 | to_local_coords,
19 | )
20 |
21 | class ViNT_Dataset(Dataset):
22 | def __init__(
23 | self,
24 | data_folder: str,
25 | data_split_folder: str,
26 | dataset_name: str,
27 | image_size: Tuple[int, int],
28 | waypoint_spacing: int,
29 | min_dist_cat: int,
30 | max_dist_cat: int,
31 | min_action_distance: int,
32 | max_action_distance: int,
33 | negative_mining: bool,
34 | len_traj_pred: int,
35 | learn_angle: bool,
36 | context_size: int,
37 | context_type: str = "temporal",
38 | end_slack: int = 0,
39 | goals_per_obs: int = 1,
40 | normalize: bool = True,
41 | obs_type: str = "image",
42 | goal_type: str = "image",
43 | angle_ranges: list = None,
44 | ):
45 | """
46 | Main ViNT dataset class
47 |
48 | Args:
49 | data_folder (string): Directory with all the image data
50 |             data_split_folder (string): Directory with traj_names.txt, a list of all trajectory names in the dataset split, one per line
51 | dataset_name (string): Name of the dataset [recon, go_stanford, scand, tartandrive, etc.]
52 | waypoint_spacing (int): Spacing between waypoints
53 | min_dist_cat (int): Minimum distance category to use
54 | max_dist_cat (int): Maximum distance category to use
55 | negative_mining (bool): Whether to use negative mining from the ViNG paper (Shah et al.) (https://arxiv.org/abs/2012.09812)
56 | len_traj_pred (int): Length of trajectory of waypoints to predict if this is an action dataset
57 | learn_angle (bool): Whether to learn the yaw of the robot at each predicted waypoint if this is an action dataset
58 | context_size (int): Number of previous observations to use as context
59 | context_type (str): Whether to use temporal, randomized, or randomized temporal context
60 | end_slack (int): Number of timesteps to ignore at the end of the trajectory
61 | goals_per_obs (int): Number of goals to sample per observation
62 | normalize (bool): Whether to normalize the distances or actions
63 | goal_type (str): What data type to use for the goal. The only one supported is "image" for now.
64 | """
65 | self.data_folder = data_folder
66 | self.data_split_folder = data_split_folder
67 | self.dataset_name = dataset_name
68 |
69 | traj_names_file = os.path.join(data_split_folder, "traj_names.txt")
70 | with open(traj_names_file, "r") as f:
71 | file_lines = f.read()
72 | self.traj_names = file_lines.split("\n")
73 | if "" in self.traj_names:
74 | self.traj_names.remove("")
75 |
76 | self.image_size = image_size
77 | self.waypoint_spacing = waypoint_spacing
78 | self.distance_categories = list(
79 | range(min_dist_cat, max_dist_cat + 1, self.waypoint_spacing)
80 | )
81 | self.min_dist_cat = self.distance_categories[0]
82 | self.max_dist_cat = self.distance_categories[-1]
83 | self.negative_mining = negative_mining
84 | if self.negative_mining:
85 | self.distance_categories.append(-1)
86 | self.len_traj_pred = len_traj_pred
87 | self.learn_angle = learn_angle
88 |
89 | self.min_action_distance = min_action_distance
90 | self.max_action_distance = max_action_distance
91 |
92 | self.context_size = context_size
93 | assert context_type in {
94 | "temporal",
95 | "randomized",
96 | "randomized_temporal",
97 | }, "context_type must be one of temporal, randomized, randomized_temporal"
98 | self.context_type = context_type
99 | self.end_slack = end_slack
100 | self.goals_per_obs = goals_per_obs
101 | self.normalize = normalize
102 | self.obs_type = obs_type
103 | self.goal_type = goal_type
104 |
105 | # load data/data_config.yaml
106 | with open(
107 | os.path.join(os.path.dirname(__file__), "data_config.yaml"), "r"
108 | ) as f:
109 | all_data_config = yaml.safe_load(f)
110 | assert (
111 | self.dataset_name in all_data_config
112 | ), f"Dataset {self.dataset_name} not found in data_config.yaml"
113 | dataset_names = list(all_data_config.keys())
114 | dataset_names.sort()
115 | # use this index to retrieve the dataset name from the data_config.yaml
116 | self.dataset_index = dataset_names.index(self.dataset_name)
117 | self.data_config = all_data_config[self.dataset_name]
118 | self.trajectory_cache = {}
119 | self._load_index()
120 | self._build_caches()
121 |
122 | if self.learn_angle:
123 | self.num_action_params = 3
124 | else:
125 | self.num_action_params = 2
126 |
127 | if angle_ranges is None:
128 | self._set_angle_ranges()
129 | else:
130 | self.angle_ranges = angle_ranges
131 |
132 | def _set_angle_ranges(self):
133 | self.angle_ranges = [(0, 67.5),
134 | (67.5, 112.5),
135 | (112.5, 180),
136 | (180, 270),
137 | (270, 360)]
138 |
139 | def __getstate__(self):
140 | state = self.__dict__.copy()
141 | state["_image_cache"] = None
142 | return state
143 |
144 | def __setstate__(self, state):
145 | self.__dict__ = state
146 | self._build_caches()
147 |
148 | def _build_caches(self, use_tqdm: bool = True):
149 | """
150 | Build a cache of images for faster loading using LMDB
151 | """
152 | cache_filename = os.path.join(
153 | self.data_split_folder,
154 | f"dataset_{self.dataset_name}.lmdb",
155 | )
156 |
157 | # Load all the trajectories into memory. These should already be loaded, but just in case.
158 | for traj_name in self.traj_names:
159 | self._get_trajectory(traj_name)
160 |
161 | """
162 | If the cache file doesn't exist, create it by iterating through the dataset and writing each image to the cache
163 | """
164 | if not os.path.exists(cache_filename):
165 | tqdm_iterator = tqdm.tqdm(
166 | self.goals_index,
167 | disable=not use_tqdm,
168 | dynamic_ncols=True,
169 | desc=f"Building LMDB cache for {self.dataset_name}"
170 | )
171 | with lmdb.open(cache_filename, map_size=2**40) as image_cache:
172 | with image_cache.begin(write=True) as txn:
173 | for traj_name, time in tqdm_iterator:
174 | image_path = get_data_path(self.data_folder, traj_name, time)
175 | with open(image_path, "rb") as f:
176 | txn.put(image_path.encode(), f.read())
177 |
178 | # Reopen the cache file in read-only mode
179 | self._image_cache: lmdb.Environment = lmdb.open(cache_filename, readonly=True)
180 |
181 | def _build_index(self, use_tqdm: bool = False):
182 | """
183 | Build an index consisting of tuples (trajectory name, time, max goal distance)
184 | """
185 | samples_index = []
186 | goals_index = []
187 |
188 | for traj_name in tqdm.tqdm(self.traj_names, disable=not use_tqdm, dynamic_ncols=True):
189 | traj_data = self._get_trajectory(traj_name)
190 | traj_len = len(traj_data["position"])
191 |
192 | for goal_time in range(0, traj_len):
193 | goals_index.append((traj_name, goal_time))
194 |
195 | begin_time = self.context_size * self.waypoint_spacing
196 | end_time = traj_len - self.end_slack - self.len_traj_pred * self.waypoint_spacing
197 | for curr_time in range(begin_time, end_time):
198 | max_goal_distance = min(self.max_dist_cat * self.waypoint_spacing, traj_len - curr_time - 1)
199 | samples_index.append((traj_name, curr_time, max_goal_distance))
200 |
201 | return samples_index, goals_index
202 |
203 | def _sample_goal(self, trajectory_name, curr_time, max_goal_dist):
204 | """
205 | Sample a goal from the future in the same trajectory.
206 | Returns: (trajectory_name, goal_time, goal_is_negative)
207 | """
208 | goal_offset = np.random.randint(0, max_goal_dist + 1)
209 | if goal_offset == 0:
210 | trajectory_name, goal_time = self._sample_negative()
211 | return trajectory_name, goal_time, True
212 | else:
213 | goal_time = curr_time + int(goal_offset * self.waypoint_spacing)
214 | return trajectory_name, goal_time, False
215 |
216 | def _sample_negative(self):
217 | """
218 | Sample a goal from a (likely) different trajectory.
219 | """
220 | return self.goals_index[np.random.randint(0, len(self.goals_index))]
221 |
222 | def _load_index(self) -> None:
223 | """
224 | Generates a list of tuples of (obs_traj_name, goal_traj_name, obs_time, goal_time) for each observation in the dataset
225 | """
226 | index_to_data_path = os.path.join(
227 | self.data_split_folder,
228 | f"dataset_dist_{self.min_dist_cat}_to_{self.max_dist_cat}_context_{self.context_type}_n{self.context_size}_slack_{self.end_slack}.pkl",
229 | )
230 | try:
231 | # load the index_to_data if it already exists (to save time)
232 | with open(index_to_data_path, "rb") as f:
233 | self.index_to_data, self.goals_index = pickle.load(f)
234 |         except (FileNotFoundError, EOFError, pickle.UnpicklingError):
235 | # if the index_to_data file doesn't exist, create it
236 | self.index_to_data, self.goals_index = self._build_index()
237 | with open(index_to_data_path, "wb") as f:
238 | pickle.dump((self.index_to_data, self.goals_index), f)
239 |
240 | def _load_image(self, trajectory_name, time):
241 | image_path = get_data_path(self.data_folder, trajectory_name, time)
242 | # return img_path_to_data(image_path, self.image_size)
243 | try:
244 | with self._image_cache.begin() as txn:
245 | image_buffer = txn.get(image_path.encode())
246 | image_bytes = bytes(image_buffer)
247 | image_bytes = io.BytesIO(image_bytes)
248 | return img_path_to_data(image_bytes, self.image_size)
249 |         except TypeError:  # txn.get() returned None, i.e. the image is missing from the LMDB cache
250 | print(f"Failed to load image {image_path}")
251 |
252 | def _theta2category(self, theta):
253 | for i, (min_angle, max_angle) in enumerate(self.angle_ranges):
254 | if min_angle <= theta < max_angle:
255 | return i
256 |         return i  # fall back to the last bin if theta falls outside all configured ranges
257 |
258 | def _calculate_angle(self, waypoints):
259 | x_end, y_end = waypoints[-1]
260 |
261 | angle_rad = np.arctan2(y_end, x_end)
262 |
263 | angle_deg = np.degrees(angle_rad)
264 |
265 | if angle_deg < 0:
266 | angle_deg += 360
267 |
268 | return angle_deg
269 |
270 | def _compute_actions(self, traj_data, curr_time, goal_time):
271 | start_index = curr_time
272 | end_index = curr_time + self.len_traj_pred * self.waypoint_spacing + 1
273 |
274 | yaw = traj_data["yaw"][start_index:end_index:self.waypoint_spacing]
275 | positions = traj_data["position"][start_index:end_index:self.waypoint_spacing]
276 |
277 | goal_pos = traj_data["position"][min(goal_time, len(traj_data["position"]) - 1)]
278 |
279 | if len(yaw.shape) == 2:
280 | yaw = yaw.squeeze(1)
281 |
282 | if yaw.shape != (self.len_traj_pred + 1,):
283 | const_len = self.len_traj_pred + 1 - yaw.shape[0]
284 | yaw = np.concatenate([yaw, np.repeat(yaw[-1], const_len)])
285 | positions = np.concatenate([positions, np.repeat(positions[-1][None], const_len, axis=0)], axis=0)
286 |
287 | assert yaw.shape == (self.len_traj_pred + 1,), f"{yaw.shape} and {(self.len_traj_pred + 1,)} should be equal"
288 | assert positions.shape == (self.len_traj_pred + 1, 2), f"{positions.shape} and {(self.len_traj_pred + 1, 2)} should be equal"
289 |
290 | waypoints = to_local_coords(positions, positions[0], yaw[0])
291 | goal_pos = to_local_coords(goal_pos, positions[0], yaw[0])
292 |
293 | assert waypoints.shape == (self.len_traj_pred + 1, 2), f"{waypoints.shape} and {(self.len_traj_pred + 1, 2)} should be equal"
294 |
295 | if self.learn_angle:
296 | yaw = yaw[1:] - yaw[0]
297 | actions = np.concatenate([waypoints[1:], yaw[:, None]], axis=-1)
298 | else:
299 | actions = waypoints[1:]
300 |
301 | if self.normalize:
302 | actions[:, :2] /= self.data_config["metric_waypoint_spacing"] * self.waypoint_spacing
303 | goal_pos /= self.data_config["metric_waypoint_spacing"] * self.waypoint_spacing
304 |
305 | assert actions.shape == (self.len_traj_pred, self.num_action_params), f"{actions.shape} and {(self.len_traj_pred, self.num_action_params)} should be equal"
306 |
307 | theta = self._calculate_angle(waypoints)
308 | action_category = self._theta2category(theta)
309 |
310 | return actions, goal_pos, action_category
311 |
312 | def _get_trajectory(self, trajectory_name):
313 | if trajectory_name in self.trajectory_cache:
314 | return self.trajectory_cache[trajectory_name]
315 | else:
316 | with open(os.path.join(self.data_folder, trajectory_name, "traj_data.pkl"), "rb") as f:
317 | traj_data = pickle.load(f)
318 | self.trajectory_cache[trajectory_name] = traj_data
319 | return traj_data
320 |
321 | def __len__(self) -> int:
322 | return len(self.index_to_data)
323 |
324 | def __getitem__(self, i: int) -> Tuple[torch.Tensor]:
325 | """
326 | Args:
327 | i (int): index to ith datapoint
328 | Returns:
329 |             Tuple of tensors containing the stacked observation images, goal image, action labels, distance label, local goal position, dataset index, action mask, and action category
330 |             obs_image (torch.Tensor): tensor of shape [3 * (context_size + 1), H, W] containing the context and current observation images stacked along the channel dimension
331 |             goal_image (torch.Tensor): tensor of shape [3, H, W] containing the subgoal image
332 |             dist_label (torch.Tensor): tensor of shape (1,) containing the distance label from the observation to the goal
333 |             action_label (torch.Tensor): tensor of shape (5, 2) or (5, 4) (if training with angle) containing the action labels from the observation to the goal
334 |             which_dataset (torch.Tensor): index of the dataset the datapoint came from [for identifying the dataset for visualization when training on multiple datasets]
335 | """
336 | f_curr, curr_time, max_goal_dist = self.index_to_data[i]
337 | f_goal, goal_time, goal_is_negative = self._sample_goal(f_curr, curr_time, max_goal_dist)
338 |
339 | # Load images
340 | context = []
341 | if self.context_type == "temporal":
342 |             # use the previous self.context_size frames (spaced by waypoint_spacing) up to and including curr_time
343 | context_times = list(
344 | range(
345 |                     curr_time - self.context_size * self.waypoint_spacing,
346 | curr_time + 1,
347 | self.waypoint_spacing,
348 | )
349 | )
350 | context = [(f_curr, t) for t in context_times]
351 | else:
352 | raise ValueError(f"Invalid context type {self.context_type}")
353 |
354 | obs_image = torch.cat([
355 | self._load_image(f, t) for f, t in context
356 | ])
357 |
358 | # Load goal image
359 | goal_image = self._load_image(f_goal, goal_time)
360 |
361 | # Load other trajectory data
362 | curr_traj_data = self._get_trajectory(f_curr)
363 | curr_traj_len = len(curr_traj_data["position"])
364 | assert curr_time < curr_traj_len, f"{curr_time} and {curr_traj_len}"
365 |
366 | goal_traj_data = self._get_trajectory(f_goal)
367 | goal_traj_len = len(goal_traj_data["position"])
368 |         assert goal_time < goal_traj_len, f"{goal_time} and {goal_traj_len}"
369 |
370 | # Compute actions
371 | actions, goal_pos, action_category = self._compute_actions(curr_traj_data, curr_time, goal_time)
372 |
373 |         if actions.dtype != np.float32:
374 |             actions = np.array(actions, dtype=np.float32)
375 |         if goal_pos.dtype != np.float32:
376 |             goal_pos = np.array(goal_pos, dtype=np.float32)
377 |
378 | # Compute distances
379 | if goal_is_negative:
380 | distance = self.max_dist_cat
381 | else:
382 | distance = (goal_time - curr_time) // self.waypoint_spacing
383 | assert (goal_time - curr_time) % self.waypoint_spacing == 0, f"{goal_time} and {curr_time} should be separated by an integer multiple of {self.waypoint_spacing}"
384 |
385 | actions_torch = torch.as_tensor(actions, dtype=torch.float32)
386 | if self.learn_angle:
387 | actions_torch = calculate_sin_cos(actions_torch)
388 |
389 | action_mask = (
390 | (distance < self.max_action_distance) and
391 | (distance > self.min_action_distance) and
392 | (not goal_is_negative)
393 | )
394 |
395 | return (
396 | torch.as_tensor(obs_image, dtype=torch.float32),
397 | torch.as_tensor(goal_image, dtype=torch.float32),
398 | actions_torch,
399 | torch.as_tensor(distance, dtype=torch.int64),
400 | torch.as_tensor(goal_pos, dtype=torch.float32),
401 | torch.as_tensor(self.dataset_index, dtype=torch.int64),
402 | torch.as_tensor(action_mask, dtype=torch.float32),
403 | torch.as_tensor(action_category, dtype=torch.long)
404 | )
405 |
406 | def sample_prior():
407 |     pass  # stub left unimplemented
--------------------------------------------------------------------------------
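
Each item from ViNT_Dataset above is an 8-tuple: the channel-stacked context/observation images, the goal image, the (sin/cos-augmented, if learn_angle) action labels, the distance category, the local goal position, the dataset index, an action mask, and the action category derived from the heading of the final waypoint. A hedged usage sketch follows; the paths, image size, and horizon values are placeholders rather than values from the repo's configs:

from torch.utils.data import DataLoader
from vint_train.data.vint_dataset import ViNT_Dataset

dataset = ViNT_Dataset(
    data_folder="/path/to/recon",               # placeholder path to the processed data
    data_split_folder="/path/to/recon/train",   # must contain traj_names.txt
    dataset_name="recon",                       # must be a key in data/data_config.yaml
    image_size=(96, 96),                        # placeholder size
    waypoint_spacing=1,
    min_dist_cat=0,
    max_dist_cat=20,
    min_action_distance=2,
    max_action_distance=10,
    negative_mining=True,
    len_traj_pred=5,
    learn_angle=True,
    context_size=5,
)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
obs, goal, actions, dist, goal_pos, dset_idx, action_mask, action_cat = next(iter(loader))
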
/train/vint_train/visualizing/action_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import cv2
5 | from typing import Optional, List
6 | import wandb
7 | import yaml
8 | import torch
9 | import torch.nn as nn
10 | from vint_train.visualizing.visualize_utils import (
11 | to_numpy,
12 | numpy_to_img,
13 | VIZ_IMAGE_SIZE,
14 | RED,
15 | GREEN,
16 | BLUE,
17 | CYAN,
18 | YELLOW,
19 | MAGENTA,
20 | )
21 |
22 | # load data_config.yaml
23 | with open(os.path.join(os.path.dirname(__file__), "../data/data_config.yaml"), "r") as f:
24 | data_config = yaml.safe_load(f)
25 |
26 |
27 | def visualize_traj_pred(
28 | batch_obs_images: np.ndarray,
29 | batch_goal_images: np.ndarray,
30 | dataset_indices: np.ndarray,
31 | batch_goals: np.ndarray,
32 | batch_pred_waypoints: np.ndarray,
33 | batch_label_waypoints: np.ndarray,
34 | eval_type: str,
35 | normalized: bool,
36 | save_folder: str,
37 | epoch: int,
38 | num_images_preds: int = 8,
39 | use_wandb: bool = True,
40 | display: bool = False,
41 | ):
42 | """
43 | Compare predicted path with the gt path of waypoints using egocentric visualization. This visualization is for the last batch in the dataset.
44 |
45 | Args:
46 | batch_obs_images (np.ndarray): batch of observation images [batch_size, height, width, channels]
47 | batch_goal_images (np.ndarray): batch of goal images [batch_size, height, width, channels]
48 |         dataset_indices (np.ndarray): batch of dataset indices identifying which dataset each datapoint came from
49 | batch_goals (np.ndarray): batch of goal positions [batch_size, 2]
50 |         batch_pred_waypoints (np.ndarray): batch of predicted waypoints [batch_size, horizon, 4], [batch_size, horizon, 2], or [batch_size, num_trajs_sampled, horizon, {2 or 4}]
51 |         batch_label_waypoints (np.ndarray): batch of label waypoints [batch_size, horizon, 4] or [batch_size, horizon, 2]
52 | eval_type (string): f"{data_type}_{eval_type}" (e.g. "recon_train", "gs_test", etc.)
53 | normalized (bool): whether the waypoints are normalized
54 | save_folder (str): folder to save the images. If None, will not save the images
55 | epoch (int): current epoch number
56 | num_images_preds (int): number of images to visualize
57 | use_wandb (bool): whether to use wandb to log the images
58 | display (bool): whether to display the images
59 | """
60 | visualize_path = None
61 | if save_folder is not None:
62 | visualize_path = os.path.join(
63 | save_folder, "visualize", eval_type, f"epoch{epoch}", "action_prediction"
64 | )
65 |
66 | if not os.path.exists(visualize_path):
67 | os.makedirs(visualize_path)
68 |
69 | assert (
70 | len(batch_obs_images)
71 | == len(batch_goal_images)
72 | == len(batch_goals)
73 | == len(batch_pred_waypoints)
74 | == len(batch_label_waypoints)
75 | )
76 |
77 | dataset_names = list(data_config.keys())
78 | dataset_names.sort()
79 |
80 | batch_size = batch_obs_images.shape[0]
81 | wandb_list = []
82 | for i in range(min(batch_size, num_images_preds)):
83 | obs_img = numpy_to_img(batch_obs_images[i])
84 | goal_img = numpy_to_img(batch_goal_images[i])
85 | dataset_name = dataset_names[int(dataset_indices[i])]
86 | goal_pos = batch_goals[i]
87 | pred_waypoints = batch_pred_waypoints[i]
88 | label_waypoints = batch_label_waypoints[i]
89 |
90 | if normalized:
91 | pred_waypoints *= data_config[dataset_name]["metric_waypoint_spacing"]
92 | label_waypoints *= data_config[dataset_name]["metric_waypoint_spacing"]
93 | goal_pos *= data_config[dataset_name]["metric_waypoint_spacing"]
94 |
95 | save_path = None
96 | if visualize_path is not None:
97 | save_path = os.path.join(visualize_path, f"{str(i).zfill(4)}.png")
98 |
99 | compare_waypoints_pred_to_label(
100 | obs_img,
101 | goal_img,
102 | dataset_name,
103 | goal_pos,
104 | pred_waypoints,
105 | label_waypoints,
106 | save_path,
107 | display,
108 | )
109 | if use_wandb:
110 | wandb_list.append(wandb.Image(save_path))
111 | if use_wandb:
112 | wandb.log({f"{eval_type}_action_prediction": wandb_list}, commit=False)
113 |
114 |
115 | def compare_waypoints_pred_to_label(
116 | obs_img,
117 | goal_img,
118 | dataset_name: str,
119 | goal_pos: np.ndarray,
120 | pred_waypoints: np.ndarray,
121 | label_waypoints: np.ndarray,
122 | save_path: Optional[str] = None,
123 | display: Optional[bool] = False,
124 | ):
125 | """
126 | Compare predicted path with the gt path of waypoints using egocentric visualization.
127 |
128 | Args:
129 | obs_img: image of the observation
130 | goal_img: image of the goal
131 | dataset_name: name of the dataset found in data_config.yaml (e.g. "recon")
132 | goal_pos: goal position in the image
133 | pred_waypoints: predicted waypoints in the image
134 | label_waypoints: label waypoints in the image
135 | save_path: path to save the figure
136 | display: whether to display the figure
137 | """
138 |
139 | fig, ax = plt.subplots(1, 3)
140 | start_pos = np.array([0, 0])
141 | if len(pred_waypoints.shape) > 2:
142 | trajs = [*pred_waypoints, label_waypoints]
143 | else:
144 | trajs = [pred_waypoints, label_waypoints]
145 | plot_trajs_and_points(
146 | ax[0],
147 | trajs,
148 | [start_pos, goal_pos],
149 | traj_colors=[CYAN, MAGENTA],
150 | point_colors=[GREEN, RED],
151 | )
152 | plot_trajs_and_points_on_image(
153 | ax[1],
154 | obs_img,
155 | dataset_name,
156 | trajs,
157 | [start_pos, goal_pos],
158 | traj_colors=[CYAN, MAGENTA],
159 | point_colors=[GREEN, RED],
160 | )
161 | ax[2].imshow(goal_img)
162 |
163 | fig.set_size_inches(18.5, 10.5)
164 |     ax[0].set_title("Action Prediction")
165 |     ax[1].set_title("Observation")
166 |     ax[2].set_title("Goal")
167 |
168 | if save_path is not None:
169 | fig.savefig(
170 | save_path,
171 | bbox_inches="tight",
172 | )
173 |
174 | if not display:
175 | plt.close(fig)
176 |
177 |
178 | def plot_trajs_and_points_on_image(
179 | ax: plt.Axes,
180 | img: np.ndarray,
181 | dataset_name: str,
182 | list_trajs: list,
183 | list_points: list,
184 | traj_colors: list = [CYAN, MAGENTA],
185 | point_colors: list = [RED, GREEN],
186 | ):
187 | """
188 |     Plot trajectories and points on an image. If there is no configuration for the camera intrinsics of the dataset, the image will be plotted as is.
189 | Args:
190 | ax: matplotlib axis
191 | img: image to plot
192 | dataset_name: name of the dataset found in data_config.yaml (e.g. "recon")
193 | list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw)
194 | list_points: list of points, each point is a numpy array of shape (2,)
195 | traj_colors: list of colors for trajectories
196 | point_colors: list of colors for points
197 | """
198 | assert len(list_trajs) <= len(traj_colors), "Not enough colors for trajectories"
199 | assert len(list_points) <= len(point_colors), "Not enough colors for points"
200 | assert (
201 | dataset_name in data_config
202 | ), f"Dataset {dataset_name} not found in data/data_config.yaml"
203 |
204 | ax.imshow(img)
205 | if (
206 | "camera_metrics" in data_config[dataset_name]
207 | and "camera_height" in data_config[dataset_name]["camera_metrics"]
208 | and "camera_matrix" in data_config[dataset_name]["camera_metrics"]
209 | and "dist_coeffs" in data_config[dataset_name]["camera_metrics"]
210 | ):
211 | camera_height = data_config[dataset_name]["camera_metrics"]["camera_height"]
212 | camera_x_offset = data_config[dataset_name]["camera_metrics"]["camera_x_offset"]
213 |
214 | fx = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["fx"]
215 | fy = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["fy"]
216 | cx = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["cx"]
217 | cy = data_config[dataset_name]["camera_metrics"]["camera_matrix"]["cy"]
218 | camera_matrix = gen_camera_matrix(fx, fy, cx, cy)
219 |
220 | k1 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k1"]
221 | k2 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k2"]
222 | p1 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["p1"]
223 | p2 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["p2"]
224 | k3 = data_config[dataset_name]["camera_metrics"]["dist_coeffs"]["k3"]
225 | dist_coeffs = np.array([k1, k2, p1, p2, k3, 0.0, 0.0, 0.0])
226 |
227 | for i, traj in enumerate(list_trajs):
228 | xy_coords = traj[:, :2] # (horizon, 2)
229 | traj_pixels = get_pos_pixels(
230 | xy_coords, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=False
231 | )
232 | if len(traj_pixels.shape) == 2:
233 | ax.plot(
234 | traj_pixels[:250, 0],
235 | traj_pixels[:250, 1],
236 | color=traj_colors[i],
237 | lw=2.5,
238 | )
239 |
240 | for i, point in enumerate(list_points):
241 | if len(point.shape) == 1:
242 | # add a dimension to the front of point
243 | point = point[None, :2]
244 | else:
245 | point = point[:, :2]
246 | pt_pixels = get_pos_pixels(
247 | point, camera_height, camera_x_offset, camera_matrix, dist_coeffs, clip=True
248 | )
249 | ax.plot(
250 | pt_pixels[:250, 0],
251 | pt_pixels[:250, 1],
252 | color=point_colors[i],
253 | marker="o",
254 | markersize=10.0,
255 | )
256 | ax.xaxis.set_visible(False)
257 | ax.yaxis.set_visible(False)
258 | ax.set_xlim((0.5, VIZ_IMAGE_SIZE[0] - 0.5))
259 | ax.set_ylim((VIZ_IMAGE_SIZE[1] - 0.5, 0.5))
260 |
261 |
262 | def plot_trajs_and_points(
263 | ax: plt.Axes,
264 | list_trajs: list,
265 | list_points: list,
266 | traj_colors: list = [CYAN, MAGENTA],
267 | point_colors: list = [RED, GREEN],
268 | traj_labels: Optional[list] = ["prediction", "ground truth"],
269 | point_labels: Optional[list] = ["robot", "goal"],
270 | traj_alphas: Optional[list] = None,
271 | point_alphas: Optional[list] = None,
272 | quiver_freq: int = 1,
273 | default_coloring: bool = True,
274 | ):
275 | """
276 | Plot trajectories and points that could potentially have a yaw.
277 |
278 | Args:
279 | ax: matplotlib axis
280 | list_trajs: list of trajectories, each trajectory is a numpy array of shape (horizon, 2) (if there is no yaw) or (horizon, 4) (if there is yaw)
281 | list_points: list of points, each point is a numpy array of shape (2,)
282 | traj_colors: list of colors for trajectories
283 | point_colors: list of colors for points
284 | traj_labels: list of labels for trajectories
285 | point_labels: list of labels for points
286 | traj_alphas: list of alphas for trajectories
287 | point_alphas: list of alphas for points
288 | quiver_freq: frequency of quiver plot (if the trajectory data includes the yaw of the robot)
289 | """
290 | assert (
291 | len(list_trajs) <= len(traj_colors) or default_coloring
292 | ), "Not enough colors for trajectories"
293 | assert len(list_points) <= len(point_colors), "Not enough colors for points"
294 | assert (
295 | traj_labels is None or len(list_trajs) == len(traj_labels) or default_coloring
296 | ), "Not enough labels for trajectories"
297 | assert point_labels is None or len(list_points) == len(point_labels), "Not enough labels for points"
298 |
299 | for i, traj in enumerate(list_trajs):
300 | if traj_labels is None:
301 | ax.plot(
302 | traj[:, 0],
303 | traj[:, 1],
304 | color=traj_colors[i],
305 | alpha=traj_alphas[i] if traj_alphas is not None else 1.0,
306 | marker="o",
307 | )
308 | else:
309 | ax.plot(
310 | traj[:, 0],
311 | traj[:, 1],
312 | color=traj_colors[i],
313 | label=traj_labels[i],
314 | alpha=traj_alphas[i] if traj_alphas is not None else 1.0,
315 | marker="o",
316 | )
317 | if traj.shape[1] > 2 and quiver_freq > 0: # traj data also includes yaw of the robot
318 | bearings = gen_bearings_from_waypoints(traj)
319 | ax.quiver(
320 | traj[::quiver_freq, 0],
321 | traj[::quiver_freq, 1],
322 | bearings[::quiver_freq, 0],
323 | bearings[::quiver_freq, 1],
324 | color=traj_colors[i] * 0.5,
325 | scale=1.0,
326 | )
327 | for i, pt in enumerate(list_points):
328 | if point_labels is None:
329 | ax.plot(
330 | pt[0],
331 | pt[1],
332 | color=point_colors[i],
333 | alpha=point_alphas[i] if point_alphas is not None else 1.0,
334 | marker="o",
335 | markersize=7.0
336 | )
337 | else:
338 | ax.plot(
339 | pt[0],
340 | pt[1],
341 | color=point_colors[i],
342 | alpha=point_alphas[i] if point_alphas is not None else 1.0,
343 | marker="o",
344 | markersize=7.0,
345 | label=point_labels[i],
346 | )
347 |
348 |
349 | # put the legend below the plot
350 | if traj_labels is not None or point_labels is not None:
351 |
352 | ax.legend(bbox_to_anchor=(0.0, -0.5), loc="upper left", ncol=2)
353 | ax.set_aspect("equal", "box")
354 |
355 |
356 | def angle_to_unit_vector(theta):
357 | """Converts an angle to a unit vector."""
358 | return np.array([np.cos(theta), np.sin(theta)])
359 |
360 |
361 | def gen_bearings_from_waypoints(
362 | waypoints: np.ndarray,
363 | mag=0.2,
364 | ) -> np.ndarray:
365 | """Generate bearings from waypoints, (x, y, sin(theta), cos(theta))."""
366 | bearing = []
367 | for i in range(0, len(waypoints)):
368 | if waypoints.shape[1] > 3: # label is sin/cos repr
369 | v = waypoints[i, 2:]
370 | # normalize v
371 | v = v / np.linalg.norm(v)
372 | v = v * mag
373 | else: # label is radians repr
374 | v = mag * angle_to_unit_vector(waypoints[i, 2])
375 | bearing.append(v)
376 | bearing = np.array(bearing)
377 | return bearing
378 |
379 |
380 | def project_points(
381 | xy: np.ndarray,
382 | camera_height: float,
383 | camera_x_offset: float,
384 | camera_matrix: np.ndarray,
385 | dist_coeffs: np.ndarray,
386 | ):
387 | """
388 | Projects 3D coordinates onto a 2D image plane using the provided camera parameters.
389 |
390 | Args:
391 | xy: array of shape (batch_size, horizon, 2) representing (x, y) coordinates
392 | camera_height: height of the camera above the ground (in meters)
393 | camera_x_offset: offset of the camera from the center of the car (in meters)
394 | camera_matrix: 3x3 matrix representing the camera's intrinsic parameters
395 | dist_coeffs: vector of distortion coefficients
396 |
397 |
398 | Returns:
399 | uv: array of shape (batch_size, horizon, 2) representing (u, v) coordinates on the 2D image plane
400 | """
401 | batch_size, horizon, _ = xy.shape
402 |
403 | # create 3D coordinates with the camera positioned at the given height
404 | xyz = np.concatenate(
405 | [xy, -camera_height * np.ones(list(xy.shape[:-1]) + [1])], axis=-1
406 | )
407 |
408 | # create dummy rotation and translation vectors
409 | rvec = tvec = (0, 0, 0)
410 |
411 | xyz[..., 0] += camera_x_offset
412 | xyz_cv = np.stack([xyz[..., 1], -xyz[..., 2], xyz[..., 0]], axis=-1)
413 | uv, _ = cv2.projectPoints(
414 | xyz_cv.reshape(batch_size * horizon, 3), rvec, tvec, camera_matrix, dist_coeffs
415 | )
416 | uv = uv.reshape(batch_size, horizon, 2)
417 |
418 | return uv
419 |
420 |
421 | def get_pos_pixels(
422 | points: np.ndarray,
423 | camera_height: float,
424 | camera_x_offset: float,
425 | camera_matrix: np.ndarray,
426 | dist_coeffs: np.ndarray,
427 | clip: Optional[bool] = False,
428 | ):
429 | """
430 | Projects 3D coordinates onto a 2D image plane using the provided camera parameters.
431 | Args:
432 | points: array of shape (batch_size, horizon, 2) representing (x, y) coordinates
433 | camera_height: height of the camera above the ground (in meters)
434 | camera_x_offset: offset of the camera from the center of the car (in meters)
435 | camera_matrix: 3x3 matrix representing the camera's intrinsic parameters
436 | dist_coeffs: vector of distortion coefficients
437 |
438 | Returns:
439 | pixels: array of shape (batch_size, horizon, 2) representing (u, v) coordinates on the 2D image plane
440 | """
441 | pixels = project_points(
442 | points[np.newaxis], camera_height, camera_x_offset, camera_matrix, dist_coeffs
443 | )[0]
444 | pixels[:, 0] = VIZ_IMAGE_SIZE[0] - pixels[:, 0]
445 | if clip:
446 | pixels = np.array(
447 | [
448 | [
449 | np.clip(p[0], 0, VIZ_IMAGE_SIZE[0]),
450 | np.clip(p[1], 0, VIZ_IMAGE_SIZE[1]),
451 | ]
452 | for p in pixels
453 | ]
454 | )
455 | else:
456 | pixels = np.array(
457 | [
458 | p
459 | for p in pixels
460 | if np.all(p > 0) and np.all(p < [VIZ_IMAGE_SIZE[0], VIZ_IMAGE_SIZE[1]])
461 | ]
462 | )
463 | return pixels
464 |
465 |
466 | def gen_camera_matrix(fx: float, fy: float, cx: float, cy: float) -> np.ndarray:
467 | """
468 | Args:
469 | fx: focal length in x direction
470 | fy: focal length in y direction
471 | cx: principal point x coordinate
472 | cy: principal point y coordinate
473 | Returns:
474 | camera matrix
475 | """
476 | return np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]])
477 |
--------------------------------------------------------------------------------
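
As a quick way to sanity-check the projection helpers above, the sketch below builds a pinhole camera with gen_camera_matrix and projects a toy trajectory with project_points; all intrinsics and mounting values here are made up for illustration and do not come from data_config.yaml:

import numpy as np
from vint_train.visualizing.action_utils import gen_camera_matrix, project_points

# Hypothetical pinhole intrinsics and camera mounting (placeholder values).
camera_matrix = gen_camera_matrix(fx=300.0, fy=300.0, cx=160.0, cy=120.0)
dist_coeffs = np.zeros(8)      # no lens distortion in this toy example
camera_height = 0.25           # meters above the ground
camera_x_offset = 0.10         # meters forward of the robot center

# One trajectory of five (x, y) waypoints in the robot frame, in meters.
xy = np.array([[[0.5, 0.0], [1.0, 0.1], [1.5, 0.2], [2.0, 0.2], [2.5, 0.1]]])
uv = project_points(xy, camera_height, camera_x_offset, camera_matrix, dist_coeffs)
print(uv.shape)  # (1, 5, 2): pixel coordinates for each waypoint
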