├── README.md
├── STEPP
│   ├── DINO
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   └── dino_feature_extract.py
│   ├── SLIC
│   │   └── slic_segmentation.py
│   ├── __init__.py
│   ├── model
│   │   ├── mlp.py
│   │   └── training.py
│   └── utils
│       ├── colorbar.py
│       ├── data_loader.py
│       ├── extract_future_poses.py
│       ├── image_saver.py
│       ├── make_dataset.py
│       ├── make_unreal_data_pixel_file.py
│       ├── misc.py
│       ├── rename_files.py
│       └── testing.py
├── STEPP_ros
│   ├── CMakeLists.txt
│   ├── config
│   │   └── model_config.yaml
│   ├── launch
│   │   └── STEPP.launch
│   ├── msg
│   │   └── Float32Stamped.msg
│   ├── package.xml
│   ├── scripts
│   │   └── inference_node.py
│   └── src
│       └── depth_projection_synchronized.cpp
├── assets
│   ├── front_page.png
│   ├── outdoor_all_2.png
│   └── pre_train_pipeline.png
├── checkpoints
│   ├── all_ViT_small_input_700_big_nn_checkpoint_20240827-1935.pth
│   ├── richmond_forest_full_ViT_small_big_nn_checkpoint_20240821-1825.pth
│   └── unreal_full_ViT_small_big_nn_checkpoint_20240819-2003.pth
└── setup.py
/README.md:
--------------------------------------------------------------------------------
1 | # Watch your STEPP: Semantic Traversability Estimation using Pose Projected Features #
2 | 
3 |
4 | **Authors**: [Sebastian Aegidius*](https://rvl.cs.toronto.edu/), [Dennis Hadjivelichkov](https://dennisushi.github.io/), [Jianhao Jiao](https://gogojjh.github.io/), [Jonathan Embly-Riches](https://rpl-as-ucl.github.io/people/), [Dimitrios Kanoulas](https://dkanou.github.io/)
5 |
6 |
7 |
8 | [Project Page](https://rpl-cs-ucl.github.io/STEPP/) [STEPP arXiv](https://arxiv.org/)
9 |
10 |
11 |
12 |
13 | 
14 | 
15 |
16 | ## Installation ##
17 | ```bash
18 | conda create -n STEPP python=3.8
19 | conda activate STEPP
20 | cd
21 | # We use CUDA 12.1 drivers
22 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
23 | ```
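You can optionally verify that the CUDA build of PyTorch was picked up (a similar check is used in the Jetson instructions below):
```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```
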
24 | Place the repository in the catkin workspace that contains your planner of choice; we used the CMU Falco local planner from the [autonomous exploration development environment](https://www.cmu-exploration.com/).
25 |
26 | ```bash
27 | # Assuming an already set up and built ROS workspace (containing cmu-exploration or any other navigation stack)
28 | cd your_navigation_ws/src
29 | git clone git@github.com:RPL-CS-UCL/STEPP-Code.git
30 | cd STEPP-Code
31 | pip install -e .
32 | cd ../..
33 | catkin build STEPP_ros
34 | ```
35 |
36 |
37 | For installation of JetPack, PyTorch, and torchvision on your Jetson platform, see: [Link](https://pytorch.org/audio/stable/build.jetson.html) and [Link](https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048)
38 | * Show JetPack version: ```apt-cache show nvidia-jetpack```
39 | * [MUST] Create a conda environment with python=3.8 and download the PyTorch wheel from this [link](https://nvidia.box.com/shared/static/i8pukc49h3lhak4kkn67tg9j4goqm0m7.whl)
40 | * Then ```pip install torch-2.0.0+nv23.05-cp38-cp38-linux_aarch64.whl```
41 | * Install torchvision (check the compatibility matrix for the corresponding PyTorch version).
42 | * Check this [link](https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048/1285?page=63) if you hit ```ValueError: Unknown CUDA arch (8.7+PTX) or GPU not supported```
43 | * Command:
44 | ```
45 | pip install numpy && \
46 | pip install torch-2.0.0+nv23.05-cp38-cp38-linux_aarch64.whl && \
47 | cd torchvision/ && \
48 | export BUILD_VERSION=0.15.1 && \
49 | python setup.py install --user && \
50 | python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); import torchvision"
51 | ```
52 | ## Checkpoints ##
53 |
54 | The following trained checkpoints are included in the repo:
55 |
56 | | Model name | Dataset | Image resolution | DINOv2 size | MLP architecture |
57 | |------------|---------|------------------|-------------|------------------|
58 | | [`richmond_forest.pth`](checkpoints/richmond_forest_full_ViT_small_big_nn_checkpoint_20240821-1825.pth) | Richmond Forest | 700x700 | dinov2_vits14 | big_nn |
59 | | [`unreal_synthetic_data.pth`](checkpoints/unreal_full_ViT_small_big_nn_checkpoint_20240819-2003.pth) | Unreal Engine synthetic data | 700x700 | dinov2_vits14 | big_nn |
60 | | [`all_data.pth`](checkpoints/all_ViT_small_input_700_big_nn_checkpoint_20240827-1935.pth) | Richmond Forest + Unreal Engine synthetic data | 700x700 | dinov2_vits14 | big_nn |
61 |
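
For reference, here is a minimal sketch of loading one of these checkpoints outside of ROS. It assumes the `big_nn` hidden-layer sizes listed in `STEPP/model/training.py` (`[256, 128, 64, 32, 64, 128, 256]`) and the 384-dimensional `dinov2_vits14` features; adjust both to match the checkpoint you pick.

```python
import torch
from STEPP.model.mlp import ReconstructMLP

# Architecture must match the checkpoint: 384-dim DINOv2 ViT-S features, big_nn hidden sizes
model = ReconstructMLP(384, [256, 128, 64, 32, 64, 128, 256])
state_dict = torch.load(
    "checkpoints/all_ViT_small_input_700_big_nn_checkpoint_20240827-1935.pth",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.eval()

# The reconstruction error of a segment-averaged DINOv2 feature vector acts as the traversability score
features = torch.randn(1, 384)  # placeholder for a real segment-averaged feature
error = torch.nn.functional.mse_loss(model(features), features)
print(error.item())
```
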
62 | ## Usage ##
63 | To launch the model, set all required paths, build your workspace, and run:
64 | ```bash
65 | roslaunch STEPP_ros STEPP.launch
66 | ```
67 |
68 | ### STEPP.launch Arguments
69 | - `model_path`: Path to your chosen checkpoint `.pth` file.
70 | - `visualize`: Whether to overlay the traversability cost onto the image feed (slows down inference).
71 | - `ump`: Use mixed precision for model inference. Makes inference faster, but requires retraining the model weights for best performance.
72 | - `cutoff`: Sets the maximum normalized reconstruction error.
73 | - `camera_type`: One of `zed2`, `D455`, `cmu_sim`; selects the camera intrinsics used for depth projection.
74 | - `decayTime`: (unfinished) How long the depth point cloud with cost is remembered outside the decay zone and the active camera view.
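
Assuming these are exposed as standard roslaunch arguments, an example invocation with placeholder values might look like this:

```bash
roslaunch STEPP_ros STEPP.launch \
    model_path:=/path/to/checkpoints/all_ViT_small_input_700_big_nn_checkpoint_20240827-1935.pth \
    visualize:=true \
    camera_type:=zed2 \
    cutoff:=0.6
```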
75 |
76 | ## Train Your Own STEPP inference model ##
77 | To train your own STEPP traversability estimation model, all you need is a dataset consisting of an image folder and an odometry pose folder, where each SE(3) odometry pose corresponds to the exact position and orientation of the image it belongs to. With these two, you can run the `extract_future_poses.py` script to obtain a JSON file containing the pixels that represent the camera's future poses in each image frame.
78 |
79 | With this JSON file and the associated images, you can run `make_dataset.py` to obtain a `.npy` file of the segment-averaged DINOv2 feature vectors for every segment that the future poses in each image belong to. This can in turn be used to train the STEPP model with `training.py`.
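
Each script reads its input and output paths from variables set near the top of the file (there are no command-line flags), so after editing those paths a full data-preparation and training run is roughly the following sketch:

```bash
# 1. Project future odometry poses into each image and save the pixel coordinates as a JSON file
#    (this script initializes a ROS node, so a roscore must be running)
python STEPP/utils/extract_future_poses.py

# 2. SLIC-segment the images, extract DINOv2 features, and average them per traversed segment (.npy)
python STEPP/utils/make_dataset.py

# 3. Train the reconstruction MLP on the averaged feature vectors (logs to wandb)
python STEPP/model/training.py
```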
80 |
81 | ### Acknowledgement
82 | https://github.com/leggedrobotics/wild_visual_navigation\
83 | https://github.com/facebookresearch/dinov2\
84 | https://github.com/HongbiaoZ/autonomous_exploration_development_environment
85 |
86 | ### Citation
87 | If you found any of our work useful, please consider citing it:
88 |
89 | ```bibtex
90 | Coming soon
91 | ```
92 |
93 |
--------------------------------------------------------------------------------
/STEPP/DINO/__init__.py:
--------------------------------------------------------------------------------
1 | from .dino_feature_extract import DinoInterface, run_dino_interfacer
--------------------------------------------------------------------------------
/STEPP/DINO/backbone.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) Mark Hamilton. All rights reserved.
3 | # Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala.
4 | # All rights reserved. Licensed under the MIT license.
5 | # See LICENSE file in the project root for details.
6 | #
7 | #
8 | import torch
9 | from torch import nn
10 | import numpy as np
11 | from abc import ABC, abstractmethod
12 |
13 |
14 | def get_backbone(cfg):
15 | """
16 |     Returns the selected DINOv2 ViT backbone.
17 |     After implementing the Backbone class for your backbone, add it to this function under a desired name.
18 | The backbone can then be used by specifying its name in the STEGO configuration file.
19 | """
20 | if not hasattr(cfg, "backbone"):
21 | raise ValueError("Could not find 'backbone' option in the config file. Please check it")
22 |
23 | if cfg.backbone == "dinov2":
24 | return Dinov2ViT(cfg)
25 | else:
26 | raise ValueError("Backbone {} unavailable".format(cfg.backbone))
27 |
28 |
29 | class Backbone(ABC, nn.Module):
30 | """
31 | Base class to provide an interface for new STEGO backbones.
32 |
33 | To add a new backbone for use in STEGO, add a new implementation of this class.
34 | """
35 |
36 | vit_name_long_to_short = {
37 | "vit_tiny": "T",
38 | "vit_small": "S",
39 | "vit_base": "B",
40 | "vit_large": "L",
41 | "vit_huge": "H",
42 | "vit_giant": "G",
43 | }
44 |
45 | # Initialize the backbone
46 | @abstractmethod
47 | def __init__(self, cfg):
48 | super().__init__()
49 |
50 | # Return the size of features generated by the backbone
51 | @abstractmethod
52 | def get_output_feat_dim(self) -> int:
53 | pass
54 |
55 | # Generate features for the given image
56 | @abstractmethod
57 | def forward(self, img):
58 | pass
59 |
60 |     # Return a name that identifies the type of the backbone
61 | @abstractmethod
62 | def get_backbone_name(self):
63 | pass
64 |
65 |
66 | class Dinov2ViT(Backbone):
67 | def __init__(self, cfg):
68 | super().__init__(cfg)
69 | self.cfg = cfg
70 | self.backbone_type = self.cfg.backbone_type
71 | self.patch_size = 14
72 | if self.backbone_type == "vit_small":
73 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14", verbose=False)
74 | elif self.backbone_type == "vit_base":
75 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14", verbose=False)
76 | elif self.backbone_type == "vit_large":
77 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14", verbose=False)
78 | elif self.backbone_type == "vit_giant":
79 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", verbose=False)
80 | elif self.backbone_type == "vit_small_reg":
81 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14_reg", verbose=False)
82 | elif self.backbone_type == "vit_base_reg":
83 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg", verbose=False)
84 | elif self.backbone_type == "vit_large_reg":
85 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg", verbose=False)
86 | elif self.backbone_type == "vit_giant_reg":
87 | self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14_reg", verbose=False)
88 | else:
89 | raise ValueError("Model type {} unavailable".format(cfg.backbone_type))
90 |
91 | for p in self.model.parameters():
92 | p.requires_grad = False
93 | self.model.eval().cuda()
94 | self.dropout = torch.nn.Dropout2d(p=np.clip(self.cfg.dropout_p, 0.0, 1.0))
95 |
96 |         if "vit_small" in self.backbone_type:
97 |             self.n_feats = 384
98 |         elif "vit_base" in self.backbone_type:
99 |             self.n_feats = 768
100 |         elif "vit_large" in self.backbone_type:
101 |             self.n_feats = 1024
102 |         elif "vit_giant" in self.backbone_type:
103 |             self.n_feats = 1536
104 |         else:
105 |             self.n_feats = 768
106 |
107 | def get_output_feat_dim(self):
108 | return self.n_feats
109 |
110 | def forward(self, img):
111 | self.model.eval()
112 | with torch.no_grad():
113 | assert img.shape[2] % self.patch_size == 0
114 | assert img.shape[3] % self.patch_size == 0
115 |
116 | # get selected layer activations
117 | feat = self.model.get_intermediate_layers(img)[0]
118 |
119 | feat_h = img.shape[2] // self.patch_size
120 | feat_w = img.shape[3] // self.patch_size
121 |
122 | image_feat = feat[:, :, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
123 |
124 | if self.cfg.dropout_p > 0:
125 | return self.dropout(image_feat)
126 | else:
127 | return image_feat
128 |
129 | def get_backbone_name(self):
130 |         return "DINOv2-" + Backbone.vit_name_long_to_short[self.backbone_type.replace("_reg", "")] + "-" + str(self.patch_size)
--------------------------------------------------------------------------------
/STEPP/DINO/dino_feature_extract.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (c) 2022-2024, ETH Zurich, Jonas Frey, Matias Mattamala.
3 | # All rights reserved. Licensed under the MIT license.
4 | # See LICENSE file in the project root for details.
5 | #
6 | from os.path import join
7 | import torch.nn.functional as F
8 | import torch
9 | import torch.quantization as quant
10 | from torchvision import transforms as T
11 | from omegaconf import OmegaConf
12 | import numpy as np
13 | from pytictac import Timer
14 |
15 | from STEPP.DINO.backbone import get_backbone
16 |
17 |
18 | class DinoInterface:
19 | def __init__(
20 | self,
21 | device: str,
22 | backbone: str = "dino",
23 | input_size: int = 448,
24 | backbone_type: str = "vit_small",
25 | patch_size: int = 8,
26 | projection_type: str = None, # nonlinear or None
27 |         dropout_p: float = 0,  # dropout probability, 0 disables dropout
28 | pretrained_weights: str = None,
29 | interpolate: bool = True,
30 | use_mixed_precision: bool = False,
31 | cfg: OmegaConf = OmegaConf.create({}),
32 | ):
33 | # Load config
34 | if cfg.is_empty():
35 | self._cfg = OmegaConf.create(
36 | {
37 | "backbone": backbone,
38 | "backbone_type": backbone_type,
39 | "input_size": input_size,
40 | "patch_size": patch_size,
41 | "projection_type": projection_type,
42 | "dropout_p": dropout_p,
43 | "pretrained_weights": pretrained_weights,
44 | "interpolate": interpolate,
45 | }
46 | )
47 | else:
48 | self._cfg = cfg
49 |
50 | # Initialize DINO
51 | self._model = get_backbone(self._cfg)
52 |
53 | # Send to device
54 | self._model.to(device)
55 | self._device = device
56 |
57 | # self._model = quant.quantize_dynamic(self._model, dtype=torch.qint8, inplace=True)
58 | self.use_mixed_precision = use_mixed_precision
59 | if self.use_mixed_precision:
60 | self._model = self._model.to(torch.float16)
61 |
62 |
63 | # Other
64 | normalization = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
65 | self._transform = T.Compose(
66 | [
67 | T.Resize(input_size, T.InterpolationMode.NEAREST),
68 | T.CenterCrop(input_size),
69 | # T.CenterCrop((input_size, 1582)),
70 | normalization,
71 | ]
72 | )
73 |
74 | def change_device(self, device):
75 | """Changes the device of all the class members
76 |
77 | Args:
78 | device (str): new device
79 | """
80 | self._model.to(device)
81 | self._device = device
82 |
83 | @torch.no_grad()
84 | def inference(self, img: torch.tensor):
85 |         """Performs inference using DINO
86 | Args:
87 |             img (torch.tensor, dtype=torch.float32, shape=(B,3,H,W)): Input image
88 |
89 | Returns:
90 | features (torch.tensor, dtype=torch.float32, shape=(B,D,H,W)): per-pixel D-dimensional features
91 | """
92 |
93 | # Resize image and normalize
94 | resized_img = self._transform(img).to(self._device)
95 | if self.use_mixed_precision:
96 | resized_img=resized_img.half()
97 |
98 | # Extract features
99 | features = self._model.forward(resized_img)
100 | # print('features shape before interpolation', features.shape)
101 |
102 | if self._cfg.interpolate:
103 | # resize and interpolate features
104 | B, D, H, W = img.shape
105 | new_features_size = (H, W)
106 | # pad = int((W - H) / 2)
107 | features = F.interpolate(features, new_features_size, mode="bilinear", align_corners=True)
108 | print('features shape after interpolation', features.shape)
109 | # features = F.pad(features, pad=[pad, pad, 0, 0])
110 |
111 | return features.to(torch.float32)
112 |
113 | @property
114 | def input_size(self):
115 | return self._cfg.input_size
116 |
117 | @property
118 | def backbone(self):
119 | return self._cfg.backbone
120 |
121 | @property
122 | def backbone_type(self):
123 | return self._cfg.backbone_type
124 |
125 | @property
126 | def vit_patch_size(self):
127 | return self._cfg.patch_size
128 |
129 |
130 | def get_dino_features(img, dino_size, interpolate):
131 | # Inference model
132 | device = "cuda" if torch.cuda.is_available() else "cpu"
133 | # #convert image to torch tensor
134 | # img = torch.from_numpy(img)
135 | img = img.to(device)
136 | # img = F.interpolate(img, scale_factor=0.25)
137 |
138 | # Settings
139 | size = 896
140 | model = dino_size
141 | patch = 14
142 | backbone = "dinov2"
143 |
144 | # Inference with DINO
145 | # Create DINO
146 | di = DinoInterface(
147 | device=device,
148 | backbone=backbone,
149 | input_size=size,
150 | backbone_type=model,
151 | patch_size=patch,
152 | interpolate=interpolate,
153 | )
154 |
155 | # with Timer(f"DINO, input_size, {di.input_size}, model, {di.backbone_type}, patch_size, {di.vit_patch_size}"):
156 | feat_dino = di.inference(img)
157 | # print(f"Feature shape after interpolation: {feat_dino.shape}")
158 |
159 | return feat_dino
160 |
161 | def average_dino_feature_segment(features, segment_img, segments=None):
162 | #features is a torch tensor of shape [1, 384, 64, 64]
163 |
164 | averaged_features = []
165 |
166 | if segments is None:
167 | segments = np.unique(segment_img)
168 |
169 | # Loop through each segment
170 | for segment_id in segments:
171 | segment_pixels = segment_img.astype(np.uint16) == segment_id
172 | selected_features = features[:, :, segment_pixels]
173 | vector = selected_features.mean(dim=-1)
174 | averaged_features.append(vector)
175 |
176 | # Stack all vectors vertically to form a m by n tensor
177 | averaged_features_tensor = torch.cat(averaged_features, dim=0)
178 |
179 | return averaged_features_tensor
180 |
181 | def average_dino_feature_segment_tensor(features, segment_img, segments=None):
182 |
183 | if segments is None:
184 | segments, segments_count = torch.unique(segment_img, return_counts=True)
185 |
186 | features_flattened = features.permute(0,2,3,1).flatten(0,-2) # (bhw x n_features)
187 | index = segment_img.flatten().unsqueeze(-1).repeat(1,features_flattened.shape[-1]).long() # (bhw x n_features)
188 | num_segments = torch.max(segment_img).int()+1 # adding +1 for the 0 ID.
189 | output = torch.zeros( (num_segments, features_flattened.shape[-1]), device="cuda", dtype=features.dtype)
190 | segment_means = output.scatter_reduce(0,index, features_flattened, reduce="sum")
191 | segment_means = segment_means[segment_means.sum(-1)!=0] / segments_count.unsqueeze(-1)
192 | # print("Difference between two methods",(segment_means-averaged_features_tensor).sum())
193 | averaged_features_tensor = segment_means
194 |
195 | return averaged_features_tensor
196 |
197 | def run_dino_interfacer():
198 |     """Performs inference using the DINOv2 ViT and stores the result as an image."""
199 |
200 | from pytictac import Timer
201 | from STEPP.utils.misc import get_img_from_fig, load_test_image, make_results_folder, remove_axes
202 | import matplotlib.pyplot as plt
203 |
204 |     # suppress warnings
205 | import warnings
206 | warnings.filterwarnings("ignore")
207 |
208 |
209 | # Create test directory
210 | outpath = make_results_folder("test_dino_interfacer")
211 |
212 | # Inference model
213 | device = "cuda" if torch.cuda.is_available() else "cpu"
214 | img = load_test_image().to(device)
215 | # img = F.interpolate(img, scale_factor=0.25)
216 |
217 | print('image after interpolation before going to model', img.shape)
218 |
219 | plot = False
220 | save_features = True
221 |
222 | # Settings
223 | size = 896
224 | model = "vit_small"
225 | patch = 14
226 | backbone = "dinov2"
227 |
228 | # Inference with DINO
229 | # Create DINO
230 | di = DinoInterface(
231 | device=device,
232 | backbone=backbone,
233 | input_size=size,
234 | backbone_type=model,
235 | patch_size=patch,
236 | )
237 |
238 | with Timer(f"DINO, input_size, {di.input_size}, model, {di.backbone_type}, patch_size, {di.vit_patch_size}"):
239 | feat_dino = di.inference(img)
240 | print(f"Feature shape after interpolation: {feat_dino.shape}")
241 |
242 | if save_features:
243 | for i in range(5):
244 | fig = plt.figure(frameon=False)
245 | fig.set_size_inches(2, 2)
246 | ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
247 | ax.set_axis_off()
248 | fig.add_axes(ax)
249 | ax.imshow(feat_dino[0][i].cpu(), cmap=plt.colormaps.get("inferno"))
250 |
251 | # Store results to test directory
252 | out_img = get_img_from_fig(fig)
253 | out_img.save(
254 | join(
255 | outpath,
256 | f"forest_clean_dino_feat{i:02}_{di.input_size}_{di.backbone_type}_{di.vit_patch_size}.png",
257 | )
258 | )
259 | plt.close("all")
260 |
261 | if plot:
262 | # Plot result as in colab
263 | fig, ax = plt.subplots(10, 11, figsize=(1 * 11, 1 * 11))
264 |
265 | for i in range(10):
266 | for j in range(11):
267 | if i == 0 and j == 0:
268 | continue
269 |
270 | elif (i == 0 and j != 0) or (i != 0 and j == 0):
271 | ax[i][j].imshow(img.permute(0, 2, 3, 1)[0].cpu())
272 | ax[i][j].set_title("Image")
273 | else:
274 | n = (i - 1) * 10 + (j - 1)
275 |                     if n >= di._model.get_output_feat_dim():
276 | break
277 | ax[i][j].imshow(feat_dino[0][n].cpu(), cmap=plt.colormaps.get("inferno"))
278 | ax[i][j].set_title("Features [0]")
279 | remove_axes(ax)
280 | plt.tight_layout()
281 |
282 | # Store results to test directory
283 | out_img = get_img_from_fig(fig)
284 | out_img.save(
285 | join(
286 | outpath,
287 | f"forest_clean_{di.backbone}_{di.input_size}_{di.backbone_type}_{di.vit_patch_size}.png",
288 | )
289 | )
290 | plt.close("all")
291 |
292 |
293 | if __name__ == "__main__":
294 | run_dino_interfacer()
295 |
--------------------------------------------------------------------------------
/STEPP/SLIC/slic_segmentation.py:
--------------------------------------------------------------------------------
1 | #file to run SLIC segmentation on an image
2 |
3 | import cv2
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from fast_slic import Slic
7 | import time
8 | import json
9 | from collections import defaultdict
10 | from torchvision import transforms as T
11 | from PIL import Image
12 | import torch
13 | import torch.nn.functional as F
14 | from torchvision import transforms
15 | import os
16 | from pytictac import Timer
17 |
18 | class SLIC():
19 | def __init__(self, crop_x=30, crop_y=20, num_superpixels=400, compactness=15):
20 | if crop_x == 0 and crop_y == 0:
21 | self.crop = False
22 | else:
23 | self.crop = True
24 | self.crop_x = crop_x
25 | self.crop_y = crop_y
26 | self.num_superpixels = num_superpixels
27 | self.compactness = compactness
28 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
29 | self.slic = Slic(num_components=self.num_superpixels, compactness=self.compactness)
30 |
31 | def Slic_segmentation_for_given_pixels(self, pixels, image):
32 | # Load your image
33 | if self.crop:
34 | only_img = image[self.crop_y:-self.crop_y, self.crop_x:-self.crop_x]
35 | else:
36 | only_img = image
37 | # Convert BGR image to RGB for matplotlib
38 | image_rgb = cv2.cvtColor(only_img, cv2.COLOR_BGR2RGB)
39 |
40 | # Create Slic object
41 | slic = Slic(num_components=self.num_superpixels, compactness=self.compactness)
42 |
43 | # Perform segmentation
44 | segmented_image = slic.iterate(image_rgb)
45 |
46 | # Assuming pixels is a list of (x, y) tuples or a 2D array where each row is an (x, y) pair
47 | pixels_array = np.array(pixels)
48 |
49 | # Extract the x and y coordinates
50 | y_coords = pixels_array[:, 0]
51 | x_coords = pixels_array[:, 1]
52 |
53 | # Use advanced indexing to get the segment values at the given (x, y) coordinates
54 | segment_values = segmented_image[x_coords, y_coords]
55 |
56 | # Create a dictionary to hold lists of pixel coordinates for each segment
57 | segment_dict = defaultdict(list)
58 |
59 | # Populate the dictionary with pixel coordinates grouped by their segment
60 | for i in range(len(segment_values)):
61 | segment = segment_values[i]
62 | pixel = (x_coords[i], y_coords[i])
63 | segment_dict[segment].append(pixel)
64 |
65 | return segment_dict, segmented_image
66 |
67 | def Slic_segmentation_for_all_pixels(self, image):
68 | # Load your image
69 | if self.crop:
70 | only_img = image[self.crop_y:-self.crop_y, self.crop_x:-self.crop_x]
71 | else:
72 | only_img = image
73 |
74 | # Convert BGR image to RGB
75 | image_rgb = cv2.cvtColor(only_img, cv2.COLOR_BGR2RGB)
76 |
77 | # Create Slic object
78 | slic = Slic(num_components=self.num_superpixels, compactness=self.compactness)
79 |
80 | # Perform segmentation
81 | segmented_image = self.slic.iterate(image_rgb)
82 |
83 | # Get unique segment values
84 | unique_segments = np.unique(segmented_image)
85 |
86 | return unique_segments, segmented_image
87 |
88 | def Slic_segmentation_for_all_pixels_torch(self, image):
89 | # Load your image
90 | if self.crop:
91 | only_img = image[self.crop_y:-self.crop_y, self.crop_x:-self.crop_x]
92 | else:
93 | only_img = image
94 |
95 | # Convert BGR image to RGB
96 | image_rgb = cv2.cvtColor(only_img, cv2.COLOR_BGR2RGB)
97 |
98 | # Create Slic object
99 | slic = Slic(num_components=self.num_superpixels, compactness=self.compactness)
100 |
101 | # Perform segmentation
102 | segmented_image = self.slic.iterate(image_rgb)
103 |
104 |
105 | #put image onto the gpu
106 | segmented_image = torch.from_numpy(segmented_image).to(self.device)
107 |
108 | # Get unique segment values
109 | unique_segments = torch.unique(segmented_image)
110 |
111 | return unique_segments, segmented_image
112 |
113 | def make_masks_smaller_numpy(self, segment_values, segmented_image, wanted_size):
114 | # Convert NumPy array to PIL image
115 | segmented_image_pil = Image.fromarray(segmented_image.astype('uint16'), mode='I;16')
116 |
117 |
118 | # Resize the image while maintaining the pixel values
119 | resized_segmented_image_pil = segmented_image_pil.resize((wanted_size, wanted_size), Image.NEAREST)
120 |
121 | # Convert the resized PIL image back to a NumPy array
122 | resized_segmented_image = np.array(resized_segmented_image_pil).astype(np.uint16)
123 |
124 | new_segment_dict = defaultdict(list)
125 |
126 | # Iterate over each unique segment value
127 | for key in segment_values:
128 | # Find the coordinates where the pixel value equals the key
129 | coordinates = np.where(resized_segmented_image == key)
130 |
131 | # Zip the coordinates to get (row, column) pairs and store them in the dictionary
132 | new_segment_dict[key].extend(zip(coordinates[0], coordinates[1]))
133 |
134 | return resized_segmented_image, new_segment_dict
135 |
136 | def make_masks_smaller_torch(self, segment_values, segmented_image, wanted_size, return_dict=True):
137 |
138 | segmented_image = segmented_image.unsqueeze(0).unsqueeze(0).float()
139 | # Resize the image while maintaining the pixel values
140 | resized_segmented_image = F.interpolate(
141 | segmented_image,
142 | size=(wanted_size, wanted_size),
143 | mode='nearest')
144 |
145 | #get rid of the first and second dimension
146 | resized_segmented_image = resized_segmented_image.squeeze(0).squeeze(0)
147 |
148 | new_segment_dict = defaultdict(list)
149 | if return_dict:
150 | # Iterate over each unique segment value
151 | with Timer("loop"):
152 | for key in segment_values:
153 | # Find the coordinates where the pixel value equals the key
154 | coordinates = torch.where(resized_segmented_image == key)
155 |
156 | # Zip the coordinates to get (row, column) pairs and store them in the dictionary
157 | new_segment_dict[key].extend(zip(coordinates[0].tolist(), coordinates[1].tolist()))
158 | print (f"looped {len(segment_values)} times")
159 |
160 | return resized_segmented_image, new_segment_dict
161 |
162 | def get_difference_pixels(img1, img2):
163 | # Compute the absolute difference
164 | difference = cv2.absdiff(img1, img2)
165 |
166 | # Threshold the difference to find the significant changes
167 | _, thresholded_difference = cv2.threshold(difference, 25, 255, cv2.THRESH_BINARY)
168 |
169 | # Convert to grayscale
170 | gray_diff = cv2.cvtColor(thresholded_difference, cv2.COLOR_BGR2GRAY)
171 |
172 | # Find contours in the thresholded difference
173 | contours, _ = cv2.findContours(gray_diff, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
174 |
175 | # Find the largest contour by area
176 | largest_contour = max(contours, key=cv2.contourArea)
177 |
178 | flattened_list = [item[0].tolist() for item in largest_contour]
179 |
180 | return flattened_list
181 |
182 | def run_SLIC_segmentation():
183 | """Run SLIC on an image and visualize the segmented image"""
184 |
185 | ##############################################
186 | # This should all be coming from a config file
187 | ##############################################
188 | img_width = 1408
189 | img_height = 1408
190 | x_boarder = 200
191 | y_boarder = 200
192 | number = 10
193 | # pixels = path[number]
194 | # img_path = images[number]
195 | img_path = 'path_to_test_image'
196 | print('img_path:', img_path)
197 | # ##############################################
198 |
199 | # #plot the image with the pixels
200 | img = cv2.imread(img_path)
201 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
202 | #crop image to remove the boarder
203 | img = img[y_boarder:-y_boarder, x_boarder:-x_boarder]
204 | plt.figure(figsize=(10, 10))
205 | plt.imshow(img)#, cmap='inferno')
206 | plt.axis('off')
207 | # plt.show()
208 |
209 | # def overlay_images_1(n1_path, n2_path):
210 | # n1_image = cv2.imread(n1_path)
211 | # n2_image = cv2.imread(n2_path)
212 | # # n2_image[..., 3] = 1
213 |
214 | # mask = n2_image != 0
215 |
216 | # # Create an output image with all black pixels
217 | # output_image = np.zeros_like(n1_image)
218 |
219 | # # Apply the mask to n1_image and store the result in output_image
220 | # output_image[mask] = n1_image[mask]
221 |
222 | # output_image[0:520] = 0
223 |
224 | # #create a list of pixel coord pairs where the image is not black
225 | # pixels = []
226 | # non_black_pixels = np.argwhere(np.any(output_image != 0, axis=-1))
227 | # pixels = non_black_pixels[:, ::-1].tolist()
228 |
229 | # return output_image, pixels
230 |
231 | # def overlay_images_2(n1_path, n2_path):
232 | # n1_image = cv2.imread(n1_path)
233 | # n2_image = cv2.imread(n2_path)
234 | # # n2_image[..., 3] = 1
235 |
236 | # mask = n2_image != 0
237 |
238 | # # Create an output image with all black pixels
239 | # output_image = np.zeros_like(n1_image)
240 |
241 | # # Apply the mask to n1_image and store the result in output_image
242 | # output_image[mask] = n1_image[mask]
243 |
244 | # output_image[0:520] = 0
245 |
246 | # #create a list of pixel coord pairs where the image is not black
247 | # pixels = []
248 | # for i in range(output_image.shape[0]):
249 | # for j in range(output_image.shape[1]):
250 | # if np.any(output_image[i, j] != 0):
251 | # pixels.append([j, i])
252 |
253 | # return output_image, pixels
254 |
255 |
256 | # with Timer('overlay_images'):
257 | # output_img_1, pixels_1 = overlay_images_1(img_path, 'path_to_test_image')
258 |
259 | # with Timer('overlay_images'):
260 | # output_img_2, pixels_2 = overlay_images_2(img_path, 'path_to_test_image')
261 |
262 | # print(pixels_1 == pixels_2)
263 | # print(pixels_1[:10])
264 |
265 | # exit()
266 |
267 | # output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
268 | # plt.figure(figsize=(10, 10))
269 | # plt.imshow(output_img)#, cmap='inferno')
270 | # plt.axis('off')
271 |
272 | #remove entries that contain values of larger than 720-20 and 1280-30
273 | # pixels = [pixel for pixel in pixels if pixel[0] < (img_width-y_boarder) and pixel[1] < (img_height-x_boarder)]
274 | # #also take off 20 from the x and 30 from the y
275 | # pixels = [(pixel[0] - x_boarder, pixel[1] - y_boarder) for pixel in pixels]
276 |
277 | slic = SLIC(crop_x=0, crop_y=0, num_superpixels=100, compactness=10)
278 | seg, seg_img = slic.Slic_segmentation_for_all_pixels(img)
279 | # segments, segmented_image = slic.Slic_segmentation_for_given_pixels(pixels, img)
280 |
281 |
282 | print('number of unique values in segmented image:', len(np.unique(seg_img)))
283 | print(seg)
284 | segmented_image_mask = seg_img
285 |
286 | #make values in each segment in seg_img random:
287 | # for value in seg:
288 | # random_value = np.random.randint(0, 255) # Generate a single random value for the current segment
289 | # seg_img = np.where(seg_img == value, random_value, seg_img) # seg_img = np.random.randint(0, len(np.unique(seg_img)), (seg_img.shape[0], seg_img.shape[1]))
290 |
291 | unique_values = set() # To keep track of the unique random values assigned
292 | for value in seg:
293 | random_value = np.random.randint(0, 255)
294 |
295 | # Ensure the random_value hasn't already been used
296 | while random_value in unique_values:
297 | random_value = np.random.randint(0, 255) # Generate a new random value if a collision occurs
298 |
299 | # Assign the unique random value and record it
300 | seg_img = np.where(seg_img == value, random_value, seg_img)
301 | unique_values.add(random_value) # Add to set of used values
302 | print(len(unique_values))
303 | pixel_list = [[420, 973],
304 | [484, 833],
305 | [475, 745],
306 | [550, 778],
307 | [520, 717],
308 | [585, 678],
309 | [683, 632],
310 | [610, 610],
311 | [660, 668],
312 | [475,1000]]
313 | values = []
314 | for pixels in pixel_list:
315 | # point = (pixel[1], pixel[0])
316 | val = seg_img[(pixels[1], pixels[0])]
317 | print('val:', val)
318 | values.append(val)
319 |
320 | segmented_image_mask = np.where(np.isin(seg_img, values), seg_img, 0)
321 | segmented_image_mask_expanded = np.expand_dims(segmented_image_mask, axis=-1) # Adds a third dimension
322 |
323 | # Now segmented_image_mask_expanded will have shape (1008, 1008, 1)
324 | # Use np.where to compare and select values
325 | seg_img_path = np.where(segmented_image_mask_expanded != 0, img, 255)
326 | # for pixel in pixels:
327 | # # point = (pixel[1], pixel[0])
328 | # val = segmented_image[(pixel[1], pixel[0])]
329 | # segmented_image_mask = np.where(segmented_image == val, 0, segmented_image_mask)
330 |
331 | # Optionally, visualize the segmented image
332 | plt.figure(figsize=(10, 10))
333 | plt.imshow(seg_img)#, cmap='inferno')
334 | plt.axis('off')
335 |
336 | plt.figure(figsize=(10, 10))
337 | plt.imshow(segmented_image_mask)#, cmap='inferno')
338 | plt.axis('off')
339 |
340 | plt.figure(figsize=(10, 10))
341 | plt.imshow(seg_img_path)#, cmap='inferno')
342 | plt.axis('off')
343 |
344 | # resized_segmented_image, new_segment_dict = slic.make_masks_smaller(segments.keys(), segmented_image, 64)
345 |
346 | # print('new_segment_dict:', new_segment_dict)
347 |
348 | # print('number of unique values in resized segmented image:', len(np.unique(resized_segmented_image)))
349 |
350 | # resized_segmented_image_mask = resized_segmented_image
351 |
352 | # for key in new_segment_dict.keys():
353 | # resized_segmented_image_mask = np.where(resized_segmented_image == float(key), 0, resized_segmented_image_mask)
354 |
355 | # plt.figure(figsize=(10, 10))
356 | # plt.imshow(resized_segmented_image)#, cmap='inferno')
357 | # plt.axis('off')
358 |
359 | # # Optionally, visualize the resized segmented image
360 | # plt.figure(figsize=(10, 10))
361 | # plt.imshow(resized_segmented_image_mask)#, cmap='inferno')
362 | # plt.axis('off')
363 | plt.show()
364 |
365 |
366 | if __name__ == "__main__":
367 | run_SLIC_segmentation()
--------------------------------------------------------------------------------
/STEPP/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # Add the directory containing STEPP to the Python path
5 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6 | """Absolute path to the STEPP repository."""
7 | sys.path.append(ROOT_DIR)
--------------------------------------------------------------------------------
/STEPP/model/mlp.py:
--------------------------------------------------------------------------------
1 | # MLP and VAE encoder-decoder architectures for feature reconstruction
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | from torch.utils.data import DataLoader, TensorDataset
9 | # from sklearn.model_selection import train_test_split
10 | # from sklearn.datasets import make_classification
11 | # from sklearn.preprocessing import StandardScaler
12 | # from sklearn.metrics import accuracy_score
13 | import matplotlib.pyplot as plt
14 | import os
15 | import json
16 | import argparse
17 |
18 | #MLP encoder decoder architecture
19 | class ReconstructMLP(nn.Module):
20 | def __init__(self, input_dim, hidden_dim):
21 | super(ReconstructMLP, self).__init__()
22 | output_dim = input_dim
23 | layers = []
24 | for hd in hidden_dim[:]:
25 | layers.append(nn.Linear(input_dim, hd))
26 | layers.append(nn.ReLU())
27 | input_dim = hd
28 | layers.append(nn.Linear(input_dim, output_dim))
29 |
30 | self.model = nn.Sequential(*layers)
31 |
32 | def forward(self, x):
33 | return self.model(x)
34 |
35 | # input_dim = 384
36 | # hidden layer dim = [input dim, 256, 64, 32, 16, 32, 64, 256, input_dim]
37 |
38 |
39 | #VAE encoder decoder architecture
40 | class ReconstructVAE (nn.Module):
41 | def __init__(self, input_dim, hidden_dim, latent_dim):
42 | super(ReconstructVAE, self).__init__()
43 | self.encoder = nn.Sequential(
44 | nn.Linear(input_dim, hidden_dim[0]),
45 | nn.ReLU(),
46 | nn.Linear(hidden_dim[0], hidden_dim[1]),
47 | nn.ReLU(),
48 | nn.Linear(hidden_dim[1], hidden_dim[2]),
49 | nn.ReLU(),
50 | nn.Linear(hidden_dim[2], hidden_dim[3]),
51 | nn.ReLU(),
52 | nn.Linear(hidden_dim[3], hidden_dim[4]),
53 | nn.ReLU(),
54 | nn.Linear(hidden_dim[4], hidden_dim[5]),
55 | nn.ReLU(),
56 | nn.Linear(hidden_dim[5], hidden_dim[6]),
57 | nn.ReLU(),
58 | nn.Linear(hidden_dim[6], latent_dim * 2)
59 | )
60 |
61 | self.decoder = nn.Sequential(
62 | nn.Linear(latent_dim, hidden_dim[6]),
63 | nn.ReLU(),
64 | nn.Linear(hidden_dim[6], hidden_dim[5]),
65 | nn.ReLU(),
66 | nn.Linear(hidden_dim[5], hidden_dim[4]),
67 | nn.ReLU(),
68 | nn.Linear(hidden_dim[4], hidden_dim[3]),
69 | nn.ReLU(),
70 | nn.Linear(hidden_dim[3], hidden_dim[2]),
71 | nn.ReLU(),
72 | nn.Linear(hidden_dim[2], hidden_dim[1]),
73 | nn.ReLU(),
74 | nn.Linear(hidden_dim[1], hidden_dim[0]),
75 | nn.ReLU(),
76 | nn.Linear(hidden_dim[0], input_dim)
77 | )
78 |
79 | def forward(self, x):
80 | mu, log_var = torch.chunk(self.encoder(x), 2, dim=1)
81 | z = self.reparameterize(mu, log_var)
82 | return self.decoder(z), mu, log_var
83 |
84 | def reparameterize(self, mu, log_var):
85 | std = torch.exp(0.5 * log_var)
86 | eps = torch.randn_like(std)
87 | return mu + eps * std
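

# Minimal usage sketch: the MLP is trained as an autoencoder on segment-averaged DINOv2
# features, and its reconstruction error serves as the traversability cost at inference time.
# The hidden sizes below assume the "small nn" configuration from STEPP/model/training.py.
if __name__ == "__main__":
    model = ReconstructMLP(input_dim=384, hidden_dim=[256, 64, 32, 16, 32, 64, 256])
    dummy_features = torch.randn(8, 384)  # batch of segment-averaged DINOv2 feature vectors
    reconstruction = model(dummy_features)
    per_segment_error = F.mse_loss(reconstruction, dummy_features, reduction="none").mean(dim=1)
    print(per_segment_error.shape)  # one reconstruction error per segment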
--------------------------------------------------------------------------------
/STEPP/model/training.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from torch.utils.data import DataLoader
5 | from STEPP.model.mlp import ReconstructMLP
6 | import numpy as np
7 | from STEPP.utils.misc import make_results_folder
8 | from STEPP.utils.testing import test_feature_reconstructor_with_model
9 | import time
10 | import wandb
11 | import sys
12 | import os
13 |
14 |
15 | # Data loader
16 | class FeatureDataset:
17 | def __init__(self, feature_dir, stack=False, transform=None, target_transform=None, batch_size=None) -> None:
18 | self.feature_dir = feature_dir
19 | self.transform = transform
20 | self.batch_size = batch_size
21 | self.target_transform = target_transform
22 |
23 | if stack:
24 | #from the folder, load all numpy files and combine them into one big numpy array
25 | #loop through all files in the folder
26 | for root, dirs, files in os.walk(self.feature_dir):
27 | for file in files:
28 | if file.endswith('.npy'):
29 | #load the numpy file
30 | if not hasattr(self, 'avg_features'):
31 | self.avg_features = np.load(os.path.join(root, file)).astype(np.float32)
32 | else:
33 | self.avg_features = np.concatenate((self.avg_features, np.load(os.path.join(root, file)).astype(np.float32)), axis=0)
34 | print(self.avg_features.shape)
35 | self.avg_features = self.avg_features[~np.isnan(self.avg_features).any(axis=1)]
36 | else:
37 | self.avg_features = np.load(self.feature_dir).astype(np.float32)
38 | self.avg_features = self.avg_features[~np.isnan(self.avg_features).any(axis=1)]
39 |
40 | def __len__(self) -> int:
41 | return len(self.avg_features)
42 |
43 | def __getitem__(self, idx: int):
44 | if self.batch_size:
45 | feature = self.avg_features[idx:idx+self.batch_size]
46 | print(feature.shape)
47 | else:
48 | feature = self.avg_features[idx]
49 | if self.transform:
50 | feature = self.transform(feature)
51 | if self.target_transform:
52 | feature = self.target_transform(feature)
53 | return feature
54 |
55 | class EarlyStopping:
56 | def __init__(self, patience=20, verbose=False, delta=0):
57 | self.patience = patience
58 | self.verbose = verbose
59 | self.delta = delta
60 | self.counter = 0
61 | self.best_score = None
62 | self.early_stop = False
63 | self.val_loss_min = float('inf')
64 | self.training_start_time = time.strftime("%Y%m%d-%H%M")
65 |
66 | def __call__(self, val_loss, model):
67 | score = -val_loss
68 |
69 | if self.best_score is None:
70 | self.best_score = score
71 | self.save_checkpoint(val_loss, model)
72 | elif score < self.best_score + self.delta:
73 | self.counter += 1
74 | if self.verbose:
75 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
76 | if self.counter >= self.patience:
77 | self.early_stop = True
78 | else:
79 | self.best_score = score
80 | self.save_checkpoint(val_loss, model)
81 | self.counter = 0
82 |
83 | def save_checkpoint(self, val_loss, model):
84 | '''Saves model when validation loss decrease.'''
85 | results_folder = make_results_folder('trained_model')
86 | if self.verbose:
87 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
88 | torch.save(model.state_dict(), results_folder + f'/all_ViT_small_ump_input_700_small_nn_checkpoint_{self.training_start_time}.pth')
89 | self.val_loss_min = val_loss
90 |
91 |
92 | class TrainFeatureReconstructor():
93 |
94 | def __init__(self, path, batch_size=32, epochs=1, learning_rate=1e-3):
95 | self.device = (
96 | "cuda"
97 | if torch.cuda.is_available()
98 | else "cpu"
99 | )
100 | print(f"Using {self.device} device")
101 | self.input_dim = 384
102 | # self.hidden_dim = [256, 128, 64, 32, 64, 128, 256] #big nn
103 | self.hidden_dim = [256, 64, 32, 16, 32, 64, 256] # small nn
104 | # self.hidden_dim = [1024, 512, 256, 64, 32, 16, 32, 64, 256, 512, 1024] #huge nn
105 | # self.hidden_dim = [256, 32] # wvn nn
106 | self.batch_size = batch_size
107 | self.epochs = epochs
108 | self.learning_rate = learning_rate
109 | self.data_path = path
110 | self.stack = True
111 | self.early_stopping = EarlyStopping(patience=10, verbose=True)
112 |
113 | # Training loop
114 | def train_loop(self, train_dataloader, loss_fn, optimizer):
115 | self.model.train()
116 |
117 | for epoch in range(self.epochs):
118 | running_loss = 0.0
119 | for data in train_dataloader:
120 | inputs = targets = data.to(self.device)
121 |
122 | # Zero the parameter gradients
123 | optimizer.zero_grad()
124 |
125 | # Forward pass
126 |                 outputs = self.model(inputs)
127 |
128 | loss = loss_fn(outputs, targets)
129 |
130 | # Backward pass and optimize
131 | loss.backward()
132 | optimizer.step()
133 | optimizer.zero_grad()
134 |
135 | # Print statistics
136 | running_loss += loss.item()
137 |
138 | epoch_loss = running_loss / len(train_dataloader)
139 | print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {epoch_loss:.4f}")
140 |
141 | meta = {'epoch': epoch, 'loss': epoch_loss}
142 | if (epoch+1) % 10 == 0:
143 | test_dict = self.test_loop(self.model, loss_fn)
144 | meta.update(test_dict)
145 | self.early_stopping(test_dict["test_loss"], self.model)
146 | wandb.log(meta)
147 | if self.early_stopping.early_stop:
148 | print("Early stopping")
149 | exit()
150 |
151 |
152 | print('Finished Training')
153 |
154 | def test_loop(self, model, loss_fn):
155 | # Set the model to evaluation mode
156 | model.eval()
157 | dataloader = self.test_dataloader
158 | num_batches = len(dataloader)
159 | test_loss = 0
160 |
161 | # Ensure no gradients are computed during test mode
162 | with torch.no_grad():
163 | for X in dataloader:
164 | X = X.to(self.device)
165 | # Forward pass: compute the model output
166 | recon_X = model(X)
167 | # Compute the loss
168 | test_loss += loss_fn(recon_X, X).item()
169 |
170 | # Compute the average loss over all batches
171 | test_loss /= num_batches
172 | print(f"Test Error: \n Avg MSE Loss: {test_loss:>8f} \n")
173 |
174 | # test on one validation image
175 | mode = 'segment_wise'
176 | test_image_path = 'path_to_test_image'
177 | figure = test_feature_reconstructor_with_model(mode,self.model, test_image_path)
178 | return dict(test_loss=test_loss, test_plot=figure)
179 |
180 | def data_split(self, dataset, train_split=0.8):
181 | train_size = int(train_split * len(dataset))
182 | test_size = len(dataset) - train_size
183 | # training_data = dataset[:1000]
184 | # test_data = training_data
185 | training_data, test_data = torch.utils.data.random_split(dataset, [train_size, test_size])
186 |
187 | train_dataloader = DataLoader(training_data, batch_size=self.batch_size, shuffle=True)
188 | test_dataloader = DataLoader(test_data, batch_size=self.batch_size, shuffle=True)
189 |
190 | return train_dataloader, test_dataloader
191 |
192 | def main(self):
193 |
194 | # Creating DataLoader
195 | dataset = FeatureDataset(self.data_path, self.stack)
196 |
197 | # Model instantiation
198 | self.model = ReconstructMLP(self.input_dim, self.hidden_dim).to(self.device)
199 | print(self.model)
200 |
201 | # Loss function and optimizer
202 | loss_fn = nn.MSELoss()
203 | optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
204 |
205 | # Splitting the data
206 | train_dataloader, self.test_dataloader = self.data_split(dataset)
207 |
208 | # Training the model
209 | self.train_loop(train_dataloader, loss_fn, optimizer)
210 |
211 | # Testing the model
212 | self.test_loop(self.model, loss_fn)
213 |
214 |
215 | if __name__ == '__main__':
216 | wandb.init(project='STEPP')
217 |
218 | path_to_features = f'path_to_features'
219 | TrainFeatureReconstructor(path_to_features, epochs=1000000).main()
--------------------------------------------------------------------------------
/STEPP/utils/colorbar.py:
--------------------------------------------------------------------------------
1 | # import matplotlib.pyplot as plt
2 | # import numpy as np
3 | # from matplotlib.colors import LinearSegmentedColormap, Normalize
4 | # from matplotlib.colorbar import ColorbarBase
5 |
6 | # # Your custom colormap stretching
7 | # s = 0.3
8 | # original_cmap = plt.cm.get_cmap("RdYlGn", 5000)
9 | # new_colors = np.vstack([
10 | # original_cmap(np.linspace(0, s, 2500)),
11 | # original_cmap(np.linspace(1 - s, 1.0, 2500))
12 | # ])
13 | # new_cmap = LinearSegmentedColormap.from_list("stretched_RdYlBu", new_colors[::-1])
14 |
15 | # fig, ax = plt.subplots(figsize=(2, 6)) # Size this appropriately to your needs
16 |
17 | # # Normalize the colormap
18 | # norm = Normalize(vmin=0, vmax=1)
19 |
20 | # # Create the colorbar
21 | # cbar = ColorbarBase(ax, cmap='hsv', norm=norm, orientation='vertical')
22 | # cbar.set_label('Predicted Traversability') # Label according to what the colors represent
23 |
24 | # plt.show()
25 |
26 |
27 | import matplotlib.pyplot as plt
28 | import numpy as np
29 | from matplotlib.colors import LinearSegmentedColormap, Normalize
30 | from matplotlib.colorbar import ColorbarBase
31 |
32 | # Create a segment of the 'hsv' colormap
33 | original_cmap = plt.cm.get_cmap('hsv')
34 | segment = np.linspace(0, 0.3, 256) # Adjust 256 for smoother or coarser color transitions
35 | colors = original_cmap(segment)
36 |
37 | # Create a new colormap from this segment
38 | new_cmap = LinearSegmentedColormap.from_list('red_to_green', colors)
39 |
40 | # Setup figure and axes for the color bar
41 | fig, ax = plt.subplots(figsize=(1, 10)) # Adjust figure size as needed
42 |
43 | # Normalize the colormap
44 | norm = Normalize(vmin=0, vmax=1)
45 |
46 | # Create the color bar using the new colormap
47 | cbar = ColorbarBase(ax, cmap=new_cmap, norm=norm, orientation='vertical')
48 | cbar.set_label('Value Range')
49 |
50 | plt.show()
51 |
--------------------------------------------------------------------------------
/STEPP/utils/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.optim import Adam
5 | from torch.utils.data import Dataset
6 | import numpy as np
7 |
8 | class FeatureDataset:
9 | def __init__(self, feature_dir, transform=None, target_transform=None, batch_size=None) -> None:
10 | self.feature_dir = feature_dir
11 | self.transform = transform
12 | self.batch_size = batch_size
13 | self.target_transform = target_transform
14 | self.avg_features = np.load(self.feature_dir)
15 |
16 | def __len__(self) -> int:
17 | return len(self.avg_features)
18 |
19 | def __getitem__(self, idx: int):
20 |
21 | if self.batch_size:
22 | feature = self.avg_features[idx:idx+self.batch_size]
23 | print(feature.shape)
24 | else:
25 | feature = self.avg_features[idx]
26 | if self.transform:
27 | feature = self.transform(feature)
28 | if self.target_transform:
29 | feature = self.target_transform(feature)
30 | return feature
--------------------------------------------------------------------------------
/STEPP/utils/extract_future_poses.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 |
3 | import cv2
4 | import numpy as np
5 | from camera import Camera
6 | from scipy.spatial.transform import Rotation as R
7 | import rospy
8 | from nav_msgs.msg import Odometry as odom
9 | import os
10 | import matplotlib.pyplot as plt
11 | import json
12 |
13 |
14 | class CameraPinhole(Camera):
15 | def __init__(self, width, height, camera_name, distortion_model, K, D, Rect, P):
16 | super().__init__(width, height, camera_name, distortion_model, K, D, Rect, P)
17 |
18 | def undistort(self, image):
19 | undistorted_image = cv2.undistort(image, self.K, self.D)
20 | return undistorted_image
21 |
22 | def main():
23 | """Main function to test the Camera class."""
24 | # Create a pinhole camera model
25 | D = np.array([-0.28685832023620605, -2.0772109031677246, 0.0005875344504602253, -0.0005043392884545028, 1.5214914083480835, -0.39617425203323364, -1.8762085437774658, 1.4227665662765503])
26 | K = np.array([607.9638061523438, 0.0, 638.83984375, 0.0, 607.9390869140625, 367.0916748046875, 0.0, 0.0, 1.0]).reshape(3, 3)
27 | Rect = np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).reshape(3, 3)
28 | P = np.array([607.9638061523438, 0.0, 638.83984375, 0.0, 0.0, 607.9390869140625, 367.0916748046875, 0.0, 0.0, 0.0, 1.0, 0.0]).reshape(3, 4)
29 | camera_pinhole = CameraPinhole(width=1280, height=720, camera_name='kinect_camera',
30 | distortion_model='rational_polynomial',
31 | K=K, D=D, Rect=Rect, P=P)
32 |
33 | # Initialize lists to store coordinates and orientations
34 | coordinates = []
35 | orientations = []
36 | directions = []
37 |
38 | folder_path = 'path_to_image_folder'
39 | images = sorted([os.path.join(folder_path, img) for img in os.listdir(folder_path) if img.endswith((".png", ".jpg", ".jpeg"))])
40 | img_file_names = [os.path.basename(img) for img in images]
41 |
42 | # Initialize the ROS node
43 | rospy.init_node('trajectory_publisher', anonymous=True)
44 | # Initialize the publisher
45 | pub = rospy.Publisher('/trajectory', odom, queue_size=10)
46 | pub2 = rospy.Publisher('/trajectory2', odom, queue_size=10)
47 |
48 | # Load the coordinates and orientations
49 | coordinates_path = 'path_to_txt_file_containing_odometry_data'
50 |
51 | T_odom_list = []
52 | with open(coordinates_path, 'r') as file:
53 | for line in file:
54 | if line.startswith('#'):
55 | continue # Skip comment lines
56 | parts = line.split()
57 | if parts:
58 | coordinates.append(np.array([float(parts[1]), float(parts[2]), float(parts[3])]))
59 | # print(coordinates[-1])
60 | orientations.append(np.array([float(parts[4]), float(parts[5]), float(parts[6]), float(parts[7])])) #qx, qy, qz, qw
61 | # print(orientations[-1])
62 |
63 | T_odom = np.eye(4, 4)
64 | T_odom[:3, :3] = R.from_quat(orientations[-1]).as_matrix()[:3, :3]
65 | T_odom[:3, 3] = coordinates[-1]
66 | T_odom_list.append(T_odom)
67 |
68 | #difference between odometry frame and camera frame
69 | translation = [-0.739, -0.056, -0.205] #x, y, z
70 | path_translation = [0.0, 1.0, 0.0] #x, y, z
71 | rotation = [0.466, -0.469, -0.533, 0.528] #quaternion
72 | T_imu_camera = np.eye(4, 4)
73 | T_imu_camera[:3, :3] = R.from_quat(rotation).as_matrix()[:3, :3]
74 | T_imu_camera[:3, 3] = translation
75 |
76 | # rotation = [-0.469, -0.533, 0.528, 0.466] #quaternion
77 |
78 | for i in range(len(coordinates)):
79 | T_world_camera = np.linalg.inv(T_imu_camera) @ T_odom_list[i] @ T_imu_camera
80 |
81 | coordinates[i] = T_world_camera[:3, 3]
82 | orientations[i] = R.from_matrix(T_world_camera[:3, :3]).as_quat()
83 |
84 | #create a list of odometry messages from the coord and orientation lists
85 | for i in range(len(coordinates)):
86 | # Create a new odometry message
87 | odom_msg = odom()
88 | # Set the header
89 | odom_msg.header.stamp = rospy.Time.now()
90 | odom_msg.header.frame_id = "odom"
91 |
92 | # Set the position
93 | odom_msg.pose.pose.position.x = coordinates[i][0]
94 | odom_msg.pose.pose.position.y = coordinates[i][1]
95 | odom_msg.pose.pose.position.z = coordinates[i][2]
96 | # Set the orientation
97 | odom_msg.pose.pose.orientation.x = orientations[i][0]
98 | odom_msg.pose.pose.orientation.y = orientations[i][1]
99 | odom_msg.pose.pose.orientation.z = orientations[i][2]
100 | odom_msg.pose.pose.orientation.w = orientations[i][3]
101 | # Append the message to the list
102 | directions.append(odom_msg)
103 |
104 | # publish data to ros topic
105 | # for i in range(len(coordinates)):
106 | # # Publish the message
107 | # pub.publish(directions[i])
108 | # # Sleep for 0.1 seconds
109 | # rospy.sleep(0.01)
110 | # print(f"Published message {i+1}/{len(coordinates)}", end='\r')
111 |
112 | # if i == point:
113 | # pub2.publish(directions[i])
114 |
115 | def unit_vector(vector):
116 | magnitude = np.linalg.norm(vector)
117 | if magnitude == 0:
118 | return vector
119 | return vector / magnitude
120 |
121 |     def transform_coord(quat, coord):
122 | R1 = R.from_quat(quat).as_matrix()
123 | #transpose R1 to get the inverse
124 | return R1.T @ coord
125 |
126 | def translate_to_frame(coords, point, quat):
127 | # for a given coordinate and orientation pair
128 | New_frame_coord = []
129 | for i in range(1, len(coords)):
130 |             c = transform_coord(quat, point - coords[i] - path_translation)
131 | New_frame_coord.append(c)
132 | # print('\n c:',c)
133 | return np.array(New_frame_coord)
134 |
135 | # #create cv2 window
136 | # cv2.namedWindow('image', cv2.WINDOW_NORMAL)
137 | # cv2.resizeWindow('image', 1280, 720)
138 |
139 | u_C2_past = np.zeros((2, 1))
140 | save_flag = True
141 | img_points = []
142 |
143 | future_steps = 50
144 | all_points = []
145 |
146 | # exit()
147 | print('length of coordinates:', len(coordinates))
148 |
149 | for i in range(1, len(coordinates)- future_steps):
150 | point = i
151 | points = translate_to_frame(coordinates[point:], coordinates[point], orientations[point])
152 | p_C2 = points.T
153 | # Project a 3D point into the pixel plane
154 | #make the point number a 6 digit string
155 |         img = cv2.imread(os.path.join(folder_path, img_file_names[point]))
156 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
157 |
158 | img_points = []
159 |
160 | u_C2_past[0] = 1280/2
161 | u_C2_past[1] = 720
162 |
163 | for j in range(1, future_steps):
164 | p_C = p_C2[:, j]
165 | _, tmp_p = camera_pinhole.project(p_C)
166 | tmp_p = tmp_p.reshape(2, 1)
167 | u_C1 = tmp_p[:, 0]
168 | tmp_p, _ = cv2.projectPoints(p_C.reshape(1, 1, 3), np.zeros((3, 1)), np.zeros((3, 1)), K, D)
169 | u_C2 = tmp_p[0, 0, :2]
170 | if u_C2[0] < camera_pinhole.width and u_C2[0] > 30 and u_C2[1] < camera_pinhole.height-20 and u_C2[1] > 0:
171 |
172 | # set points to be drawn on the image
173 | cv2.circle(img, (int(u_C2[0]), int(u_C2[1])), 5, (0, 0, 255), -1)
174 | cv2.line(img, (int(u_C2_past[0]), int(u_C2_past[1])), (int(u_C2[0]), int(u_C2[1])), (255 - j*(255/future_steps), j*(255/future_steps), 0), 2) # green line
175 |
176 | #append img_points to img_points
177 | img_points.append([int(u_C2[0]),int(u_C2[1])])
178 |
179 | u_C2_past = u_C2
180 | # print(img_points)
181 | #append img_points to all_points as another dimension
182 |
183 | all_points.append(img_points)
184 |
185 | # Display the image with the points drawn on it in the cv2 window
186 | cv2.imshow('image', img)
187 | cv2.waitKey(0)
188 |
189 | print(f"Point {point}/{len(coordinates)}", end='\r')
190 |
191 | #save all_points to a numpy file
192 | # print(all_points[-7:])
193 | print(len(all_points))
194 |
195 | #save all_pints as a json
196 | with open('OPS_grass_pixels.json', 'w') as f:
197 | json.dump(all_points, f)
198 |
199 |
200 | if __name__ == '__main__':
201 | main()
--------------------------------------------------------------------------------
/STEPP/utils/image_saver.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import rospy
4 | from sensor_msgs.msg import Image, CompressedImage
5 | from cv_bridge import CvBridge, CvBridgeError
6 | import cv2
7 | import os
8 |
9 | class ImageSaver:
10 | def __init__(self, image_topic, save_directory):
11 | # Initialize the ROS node
12 | rospy.init_node('image_saver', anonymous=True)
13 | print('Node initialized')
14 |
15 | # Create a CvBridge object
16 | self.bridge = CvBridge()
17 |
18 | # Subscribe to the image topic
19 | self.image_sub = rospy.Subscriber(image_topic, CompressedImage, self.image_callback)
20 |
21 | # Directory to save images
22 | self.save_directory = save_directory
23 | if not os.path.exists(self.save_directory):
24 | os.makedirs(self.save_directory)
25 |
26 | # Counter for naming images
27 | self.image_counter = 0
28 |
29 | def image_callback(self, msg):
30 | try:
31 |             # Convert the ROS CompressedImage message to an OpenCV BGR image
32 | cv_image = self.bridge.compressed_imgmsg_to_cv2(msg, "bgr8")
33 |
34 | # Create a filename for each image
35 | filename = os.path.join(self.save_directory, "image_{:06d}.png".format(self.image_counter))
36 |
37 | # Save the image to the specified directory
38 | cv2.imwrite(filename, cv_image)
39 | rospy.loginfo("Saved image: {}".format(filename))
40 |
41 | # Increment the counter
42 | self.image_counter += 1
43 |
44 | except CvBridgeError as e:
45 | rospy.logerr("CvBridge Error: {}".format(e))
46 |
47 | if __name__ == '__main__':
48 | try:
49 | # Parameters
50 | image_topic = "/rgb/image_rect_color/compressed" # Set the image topic
51 | save_directory = "path_to_save_folder" # Set the directory to save images
52 |
53 | # Create the ImageSaver object
54 | image_saver = ImageSaver(image_topic, save_directory)
55 |
56 | # Keep the node running
57 | rospy.spin()
58 | except rospy.ROSInterruptException:
59 | pass
60 |
--------------------------------------------------------------------------------
/STEPP/utils/make_dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import json
5 | import torch
6 | import torch.nn.functional as F
7 | import os
8 | from pytictac import Timer
9 | import warnings
10 | import argparse
11 |
12 | from STEPP import ROOT_DIR
13 | from STEPP.DINO import run_dino_interfacer
14 | from STEPP.DINO.dino_feature_extract import DinoInterface
15 | from STEPP.SLIC.slic_segmentation import SLIC
16 | from STEPP.utils import misc
17 | from STEPP.utils.misc import load_image
18 | from STEPP.DINO.dino_feature_extract import get_dino_features, average_dino_feature_segment
19 |
20 |
21 | class FeatureDataSet:
22 | def __init__(self, path_to_image_folder, path_to_pixels):
23 | self.img_width = 1408#1280
24 | self.img_height = 1408#720
25 | self.x_boarder = 0 #20
26 | self.y_boarder = 0 #30
27 | self.start_image_idx = 0#750
28 | self.interpolate = False
29 | self.dino_size = 'vit_small'
30 | self.use_mixed_precision = True
31 |
32 | if self.dino_size == 'vit_small':
33 | self.feature_dim = 384
34 | elif self.dino_size == 'vit_base':
35 | self.feature_dim = 768
36 | elif self.dino_size == 'vit_large':
37 | self.feature_dim = 1024
38 | elif self.dino_size == 'vit_giant':
39 | self.feature_dim = 1536
40 |
41 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42 |
43 | # Settings
44 | self.size = 700
45 | self.dino_size = "vit_small"
46 | self.patch = 14
47 | self.backbone = "dinov2"
48 |
49 | # Inference with DINO
50 | # Create DINO
51 | self.di = DinoInterface(
52 | device=self.device,
53 | backbone=self.backbone,
54 | input_size=self.size,
55 | backbone_type=self.dino_size,
56 | patch_size=self.patch,
57 | interpolate=False,
58 | use_mixed_precision=self.use_mixed_precision,
59 | )
60 |
61 | #points
62 | with open(path_to_pixels, 'r') as f:
63 | path_pixels = json.load(f)
64 |
65 | path_pixels_resized = []
66 |         # drop pixels that fall outside the (border-cropped) image bounds
67 | for pixels in path_pixels:
68 | pixels = [pixel for pixel in pixels if pixel[0] < (self.img_width-self.y_boarder) and pixel[1] < (self.img_height-self.x_boarder)]
69 | #also take off 20 from the x and 30 from the y
70 | pixels = [(pixel[0] - self.x_boarder, pixel[1] - self.y_boarder) for pixel in pixels]
71 |
72 | path_pixels_resized.append(pixels)
73 | self.path_pixels_resized = path_pixels_resized
74 |
75 | print("loaded pixels")
76 |
77 | #images
78 | self.images = sorted([os.path.join(path_to_image_folder, img) for img in os.listdir(path_to_image_folder) if img.endswith((".png", ".jpg", ".jpeg"))])
79 |         # if there are more images than pixel lists, trim the extra images so both line up
80 | if len(self.images) > len(self.path_pixels_resized):
81 | self.images = self.images[:-(len(self.images) -len(self.path_pixels_resized))]
82 |
83 | print("loaded images")
84 |
85 |
86 | def main(feat):
87 |
88 | #supress warnings
89 | warnings.filterwarnings("ignore")
90 | slic = SLIC(crop_x=0, crop_y=0)
91 | average_features_segments = np.zeros((1, feat.feature_dim))
92 |
93 | for i in range(len(feat.images)):
94 | if feat.path_pixels_resized[i] == []:
95 | continue
96 | img = cv2.imread(feat.images[i])
97 | segments, segmented_image = slic.Slic_segmentation_for_given_pixels(feat.path_pixels_resized[i], img)
98 | resized_segmented_image, new_segment_dict = slic.make_masks_smaller_numpy(segments.keys(), segmented_image, int(feat.size/feat.patch))
99 |
100 | tensor_img = load_image(feat.images[i]).to(feat.device)
101 |
102 | #get dino features
103 | features = feat.di.inference(tensor_img)
104 |
105 | #average dino features over segments
106 | average_features = average_dino_feature_segment(features, resized_segmented_image, new_segment_dict.keys())
107 | #convert to numpy array
108 | average_features = average_features.cpu().detach().numpy()
109 | average_features_segments = np.concatenate((average_features_segments,average_features), axis=0)
110 |
111 | print('processed image:', i,'/', len(feat.images))#, end='\r')
112 |
113 | average_features_segments = average_features_segments[1:]
114 |
115 | print('\n')
116 | print('average_features_segments shape:', average_features_segments.shape)
117 |
118 | return average_features_segments
119 |
120 |
121 | if __name__ == '__main__':
122 |
123 | path_to_image_folder = 'path_to_image_folder'
124 | path_to_pixels = 'path_to_pixels.json'
125 | data_preprocessing = FeatureDataSet(path_to_image_folder, path_to_pixels)
126 | dataset = main(data_preprocessing)
127 |
128 | #save dataset
129 | dataset_path = 'path_to_save_dataset'
130 | np.save(dataset_path, dataset)
--------------------------------------------------------------------------------
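For reference, the `.npy` feature array written by `make_dataset.py` is what the reconstruction MLP is trained on. The project's own training code lives in `STEPP/model/training.py` (not shown here); the snippet below is only a minimal sketch of that step, assuming a placeholder dataset path, batch size, learning rate and epoch count, and reusing the `ReconstructMLP(384, [256, 128, 64, 32, 64, 128, 256])` architecture that `testing.py` and `inference_node.py` later load.

```python
# Minimal sketch (not the repo's training.py): fit the reconstruction MLP on the
# saved per-segment DINO features with an MSE reconstruction objective, matching
# how testing.py / inference_node.py score traversability from reconstruction error.
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from STEPP.model.mlp import ReconstructMLP

# 'path_to_saved_dataset.npy' is a placeholder for the file written by make_dataset.py
features = torch.from_numpy(np.load('path_to_saved_dataset.npy')).float()
loader = DataLoader(TensorDataset(features), batch_size=256, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ReconstructMLP(384, [256, 128, 64, 32, 64, 128, 256]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for epoch in range(100):
    for (batch,) in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(batch), batch)  # reconstruct the input features
        loss.backward()
        optimizer.step()
    print(f'epoch {epoch}: loss {loss.item():.4f}')

# the resulting state_dict can be loaded via model.load_state_dict(...) as in testing.py
torch.save(model.state_dict(), 'checkpoint.pth')
```

Hyperparameters here are illustrative only; the released checkpoints were produced by the repository's own training script.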
/STEPP/utils/make_unreal_data_pixel_file.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import matplotlib.pyplot as plt
5 | import json
6 | from pytictac import Timer
7 |
8 | def overlay_images(n1_path, n2_path):
9 | n1_image = cv2.imread(n1_path)
10 | n2_image = cv2.imread(n2_path)
11 | # n2_image[..., 3] = 1
12 |
13 | mask = n2_image != 0
14 |
15 | # Create an output image with all black pixels
16 | output_image = np.zeros_like(n1_image)
17 |
18 | # Apply the mask to n1_image and store the result in output_image
19 | output_image[mask] = n1_image[mask]
20 |
21 | output_image[0:520] = 0
22 |
23 | #create a list of pixel coord pairs where the image is not black
24 | pixels = []
25 | non_black_pixels = np.argwhere(np.any(output_image != 0, axis=-1))
26 | pixels = non_black_pixels[:, ::-1].tolist()
27 |
28 | return output_image, pixels
29 |
30 | path_to_image_folder = 'path_to_image_folder'
31 | path_to_trajectory_folder = 'path_to_trajectory_folder'
32 |
33 | images = sorted([os.path.join(path_to_image_folder, img) for img in os.listdir(path_to_image_folder) if img.endswith((".png", ".jpg", ".jpeg"))])
34 | trajectory_images = sorted([os.path.join(path_to_trajectory_folder, img) for img in os.listdir(path_to_trajectory_folder) if img.endswith((".png", ".jpg", ".jpeg"))])
35 |
36 |
37 |
38 | all_pixels = []
39 | for i in range(len(images)):
40 | output_img, pixels = overlay_images(images[i], trajectory_images[i])
41 |
42 | all_pixels.append(pixels)
43 |
44 | print('processed image:', i,'/', len(images), end='\r')
45 |
46 | #save the pixels to json
47 | path = 'path_to_save_pixels.json'
48 | with open(path, 'w') as f:
49 | json.dump(all_pixels, f)
50 |
51 | print('Finished saving pixels to:\n', path)
52 |
--------------------------------------------------------------------------------
/STEPP/utils/misc.py:
--------------------------------------------------------------------------------
1 | from matplotlib.backends.backend_agg import FigureCanvasAgg
2 | from PIL import Image
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import os
7 | import cv2
8 | from STEPP import ROOT_DIR
9 |
10 | def make_results_folder(name):
11 | path = os.path.join(ROOT_DIR, "results", name)
12 | os.makedirs(path, exist_ok=True)
13 | return path
14 |
15 | def get_img_from_fig(fig, dpi=180):
16 |     """Returns an image rendered from a matplotlib figure.
17 |
18 | Args:
19 | fig (matplotlib.figure.Figure): Input figure.
20 | dpi (int, optional): Resolution. Defaults to 180.
21 |
22 | Returns:
23 |         buf (PIL.Image.Image): Resulting RGB image.
24 | """
25 | fig.set_dpi(dpi)
26 | canvas = FigureCanvasAgg(fig)
27 | # Retrieve a view on the renderer buffer
28 | canvas.draw()
29 | buf = canvas.buffer_rgba()
30 | # convert to a NumPy array
31 | buf = np.asarray(buf)
32 | buf = Image.fromarray(buf)
33 | buf = buf.convert("RGB")
34 | return buf
35 |
36 | def load_test_image():
37 | np_img = cv2.imread(os.path.join(ROOT_DIR, "path_to_test_image"))
38 | np_img = np_img[200:-200, 200:-200]
39 | img = torch.from_numpy(cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB))
40 | img = img.permute(2, 0, 1)
41 | img = (img.type(torch.float32) / 255)[None]
42 | return img
43 |
44 | def load_image(path):
45 | np_img = cv2.imread(path)
46 | img = torch.from_numpy(cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB))
47 | img = img.permute(2, 0, 1)
48 | img = (img.type(torch.float32) / 255)[None]
49 | return img
50 |
51 | def _remove_axes(ax):
52 | ax.xaxis.set_major_formatter(plt.NullFormatter())
53 | ax.yaxis.set_major_formatter(plt.NullFormatter())
54 | ax.set_xticks([])
55 | ax.set_yticks([])
56 |
57 | def remove_axes(axes):
58 | if len(axes.shape) == 2:
59 | for ax1 in axes:
60 | for ax in ax1:
61 | _remove_axes(ax)
62 | else:
63 | for ax in axes:
64 | _remove_axes(ax)
65 |
66 | def save_dataset(dataset, path):
67 |     # create the destination folder if it does not exist, then save the dataset inside it
68 |     os.makedirs(path, exist_ok=True)
69 |     # np.save appends the .npy extension automatically
70 |     np.save(os.path.join(path, 'dataset'), dataset)
71 |
--------------------------------------------------------------------------------
/STEPP/utils/rename_files.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import cv2
3 | import os
4 |
5 | def rename_files_in_folder(folder_path):
6 | for i, filename in enumerate(os.listdir(folder_path)):
7 | os.rename(os.path.join(folder_path, filename), os.path.join(folder_path, f"{int(filename[:-4]):06d}.png"))
8 |
9 | if __name__ == '__main__':
10 | folder_path = 'path_to_folder'
11 | rename_files_in_folder(folder_path)
--------------------------------------------------------------------------------
/STEPP/utils/testing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from STEPP.model.mlp import ReconstructMLP
3 | from STEPP.utils.misc import load_image
4 | from STEPP.SLIC.slic_segmentation import SLIC
5 | from STEPP.utils.make_dataset import FeatureDataSet
6 | from STEPP.DINO.dino_feature_extract import DinoInterface, get_dino_features, average_dino_feature_segment
7 | import cv2
8 | import matplotlib.pyplot as plt
9 | from matplotlib import cm
10 | import numpy as np
11 | import torch.nn as nn
12 | from matplotlib.colors import LinearSegmentedColormap
13 | import time
14 | import torch.nn.functional as F
15 | import warnings
16 | from PIL import Image as PILImage
17 | import seaborn as sns
18 | from pytictac import Timer
19 |
20 | warnings.filterwarnings("ignore")
21 |
22 |
23 | def test_feature_reconstructor(mode, model_path, image_path, thresh):
24 |     # mode = 'segment_wise' for segment-wise inference
25 |     # mode = 'pixel_wise' for per-patch inference over the whole image
26 |
27 | device = (
28 | "cuda"
29 | if torch.cuda.is_available()
30 | else "cpu"
31 | )
32 |
33 | # load the model
34 |     model = ReconstructMLP(384, [256, 128, 64, 32, 64, 128, 256])  # alternative layer sizes: [256, 32, 384]
35 | #load the model with the weights
36 | model.load_state_dict(torch.load(model_path))
37 |
38 | model.to(device)
39 | return test_feature_reconstructor_with_model(mode, model, image_path, thresh)
40 |
41 | def test_feature_reconstructor_with_model(mode,model, image_path, thresh):
42 | start = time.time()
43 |
44 | alpha = 0.5
45 |
46 | #load an image
47 | img = cv2.imread(image_path)
48 | torch_img = load_image(image_path)
49 | H, W, D = img.shape
50 | H1 = 64
51 | new_features_size = (H, H)
52 |
53 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
54 | small_image = cv2.resize(img, (new_features_size))
55 | # small_image = cv2.cvtColor(small_image, cv2.COLOR_BGR2RGB)
56 |
57 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58 |
59 |     threshold = thresh  # e.g. 0.1
60 |
61 | # Settings
62 | size = 700
63 | dino_size = "vit_small"
64 | patch = 14
65 | backbone = "dinov2"
66 |
67 | # Inference with DINO
68 | # Create DINO
69 | di = DinoInterface(
70 | device=device,
71 | backbone=backbone,
72 | input_size=size,
73 | backbone_type=dino_size,
74 | patch_size=patch,
75 | interpolate=False,
76 | use_mixed_precision = False,
77 | )
78 |
79 | torch_img = torch.from_numpy(small_image)
80 | torch_img = torch_img.permute(2, 0, 1)
81 | torch_img = (torch_img.type(torch.float32) / 255)[None].to(device)
82 | # torch_img.to(self.device)
83 | dino_size = 'vit_small'
84 | # features = get_dino_features(torch_img, dino_size, False)
85 | features = di.inference(torch_img)
86 |
87 | print('features shape',features.shape)
88 |
89 | if mode == 'segment_wise':
90 | #segment the whole image and get each pixel for each segment value
91 | slic = SLIC(crop_x=0, crop_y=0)
92 | segments, segmented_image = slic.Slic_segmentation_for_all_pixels(small_image)
93 | print('segmented image shape:', segmented_image.shape)
94 | resized_segmented_img, new_segment_dict = slic.make_masks_smaller_numpy(segments, segmented_image, 50)
95 |
96 | #average the features over the segments
97 | average_features = average_dino_feature_segment(features, resized_segmented_img)
98 |
99 | # Forward pass the entire batch
100 | reconstructed_features = model(average_features)
101 |
102 | # Calculate the losses for the entire batch
103 | loss_fn = nn.MSELoss(reduction='none')
104 | losses = loss_fn(average_features, reconstructed_features)
105 | losses = losses.mean(dim=1).cpu().detach().numpy() # Average the losses across the feature dimension
106 |
107 | #set the segment values of the segmented image to equal the loss in losses
108 | for key, loss in zip(new_segment_dict.keys(), losses):
109 | segmented_image = np.where(segmented_image == int(key), loss, segmented_image)
110 |
111 |         segmented_image = np.where(segmented_image > 10, 10, segmented_image)  # cap the losses at 10
112 |
113 |         # Normalize the segmented image values to the range [0, 0.45]
114 |         segmented_image = (segmented_image - segmented_image.min()) / (segmented_image.max() - segmented_image.min()) * 0.45
115 |
116 |         # Clamp all values above the threshold to the threshold
117 |         segmented_image = np.where(segmented_image > threshold, threshold, segmented_image)
118 | # segmented_image = np.where(segmented_image < self.threshold, 0.0, segmented_image)
119 |
120 | # # Calculate the extent to center the segmented image
121 | # original_height, original_width = small_image.shape[:2]
122 | # segmented_height, segmented_width = segmented_image.shape[:2]
123 |
124 | # # Crop the original image to the segmented image size
125 | # x_offset = (original_width - segmented_width) // 2
126 | # y_offset = (original_height - segmented_height) // 2
127 | # small_image = img[y_offset:y_offset + segmented_height, x_offset:x_offset + segmented_width]
128 |
129 | # Create the colormap
130 | s = 0.3 # If bigger, get more fine-grained green, if smaller get more fine-grained red
131 | cmap = cm.get_cmap("RdYlBu", 256) # or RdYlGn
132 | cmap = np.vstack([
133 | cmap(np.linspace(0, s, 128)),
134 | cmap(np.linspace(1 - s, 1.0, 128))
135 | ]) # Stretch the colormap
136 | cmap = (cmap[:, :3] * 255).astype(np.uint8)
137 |
138 | # Reverse the colormap if needed
139 | cmap = cmap[::-1]
140 |
141 | # Normalize the segmented image values to the range [0, 255]
142 | segmented_normalized = ((segmented_image - segmented_image.min()) /
143 | (segmented_image.max() - segmented_image.min()) * 255).astype(np.uint8)
144 |
145 | # Map the segmented image values to colors
146 | color_mapped_img = cmap[segmented_normalized]
147 |
148 | # Convert images to RGBA
149 | img_rgba = PILImage.fromarray(np.uint8(small_image)).convert("RGBA")
150 | seg_rgba = PILImage.fromarray(color_mapped_img).convert("RGBA")
151 |
152 | # Adjust the alpha channel to vary the transparency
153 | seg_rgba_np = np.array(seg_rgba)
154 | alpha_channel = seg_rgba_np[:, :, 3] # Extract alpha channel
155 |         alpha_channel = (alpha_channel * 1.0).astype(np.uint8)  # keep the overlay fully opaque
156 | seg_rgba_np[:, :, 3] = alpha_channel # Update alpha channel
157 | seg_rgba = PILImage.fromarray(seg_rgba_np)
158 |
159 | # Alpha composite the images
160 | img_new = PILImage.alpha_composite(img_rgba, seg_rgba)
161 | img_rgb = img_new.convert("RGB")
162 |
163 | #resize the image to the original size
164 | img_rgb = img_rgb.resize((W,H))
165 |
166 | # Overlay the segmented image on the original image
167 | fig = plt.figure(figsize=(10, 10))
168 | plt.imshow(img_rgb)
169 | plt.title(mode + '_reconstruction_' + dino_size + '_threshold_' + str(threshold))
170 | plt.axis('off')
171 |
172 | elif mode == 'pixel_wise':
173 |
174 |         # features has shape (1, 384, 50, 50); rearrange to (50, 50, 384, 1)
175 |         features = features.permute(2, 3, 1, 0)
176 |
177 |         # flatten to (2500, 384)
178 |         features_tensor = features.reshape(50*50, 384)
179 |
180 | with Timer('Inference: '):
181 | # Forward pass the entire batch
182 | reconstructed_features = model(features_tensor)
183 |
184 | # Calculate the losses for the entire batch
185 | loss_fn = nn.MSELoss(reduction='none')
186 | losses = loss_fn(features_tensor, reconstructed_features)
187 | losses = losses.mean(dim=1).cpu().detach().numpy() # Average the losses across the feature dimension
188 |
189 |         # reshape the per-patch losses to a 50x50 grid
190 | losses = losses.reshape(50, 50)
191 |
192 | #resize the cost map to the original image size
193 | cost_map = cv2.resize(losses, (H, H))
194 |
195 | print('time to run inference:', time.time()-start)
196 |
197 | cost_map = np.where(cost_map > 10,10, cost_map)
198 |
199 |
200 |         # Normalize the cost map values to the range [0, 0.45]
201 |         cost_map = (cost_map - cost_map.min()) / (cost_map.max() - cost_map.min()) * 0.45
202 |
203 |
204 |         # Clamp all values above the threshold to the threshold
205 |         # cost_map = np.where(cost_map < 3, 0, cost_map)
206 |         cost_map = np.where(cost_map > threshold, threshold, cost_map)
207 |
208 | # Create the colormap
209 | s = 0.3 # If bigger, get more fine-grained green, if smaller get more fine-grained red
210 | cmap = cm.get_cmap("RdYlBu", 256) # or RdYlGn
211 | cmap = np.vstack([
212 | cmap(np.linspace(0, s, 128)),
213 | cmap(np.linspace(1 - s, 1.0, 128))
214 | ]) # Stretch the colormap
215 | cmap = (cmap[:, :3] * 255).astype(np.uint8)
216 |
217 | # Reverse the colormap if needed
218 | cmap = cmap[::-1]
219 |
220 | # Normalize the segmented image values to the range [0, 255]
221 | cost_map_normalized = ((cost_map - cost_map.min()) /
222 | (cost_map.max() - cost_map.min()) * 255).astype(np.uint8)
223 |
224 | # Map the segmented image values to colors
225 | color_mapped_img = cmap[cost_map_normalized]
226 |
227 | # Convert images to RGBA
228 | img_rgba = PILImage.fromarray(np.uint8(small_image)).convert("RGBA")
229 | seg_rgba = PILImage.fromarray(color_mapped_img).convert("RGBA")
230 |
231 | # Adjust the alpha channel to vary the transparency
232 | seg_rgba_np = np.array(seg_rgba)
233 | alpha_channel = seg_rgba_np[:, :, 3] # Extract alpha channel
234 |         alpha_channel = (alpha_channel * 0.75).astype(np.uint8)  # Adjust transparency (75% opaque)
235 | seg_rgba_np[:, :, 3] = alpha_channel # Update alpha channel
236 | seg_rgba = PILImage.fromarray(seg_rgba_np)
237 |
238 | # Alpha composite the images
239 | img_new = PILImage.alpha_composite(img_rgba, seg_rgba)
240 | img_rgb = img_new.convert("RGB")
241 |
242 | #resize the image to the original size
243 | img_rgb = img_rgb.resize((W,H))
244 |
245 | # Overlay the segmented image on the original image
246 | fig = plt.figure(figsize=(10, 10))
247 | plt.imshow(img_rgb)
248 | plt.title(mode + '_reconstruction_' + dino_size + '_threshold_' + str(threshold))
249 | plt.axis('off')
250 |
251 | # plt.show()
252 | return fig
253 |
254 | if __name__ == '__main__':
255 | model_path = 'path_to_model.pth'
256 | image_path = 'path_to_test_image.png'
257 | threshold = 0.15
258 | test_feature_reconstructor('segment_wise',model_path, image_path, threshold)
259 |
260 | #save figure to test folder
261 | count = time.strftime("%Y%m%d-%H%M")
262 | plt.savefig('folder_to_save_figure'+ count +'.png')
263 | plt.show()
--------------------------------------------------------------------------------
/STEPP_ros/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.0.2)
2 | project(STEPP_ros)
3 |
4 | find_package(catkin REQUIRED COMPONENTS
5 | roscpp
6 | geometry_msgs
7 | octomap_ros
8 | octomap_msgs
9 | pcl_conversions
10 | pcl_ros
11 | rospy
12 | sensor_msgs
13 | std_msgs
14 | cv_bridge
15 | grid_map_ros
16 | image_transport
17 | message_generation
18 | )
19 |
20 | add_message_files(
21 | FILES
22 | Float32Stamped.msg
23 | )
24 |
25 | generate_messages(
26 | DEPENDENCIES
27 | std_msgs
28 | )
29 |
30 |
31 | find_package(PCL REQUIRED)
32 | find_package(OpenMP REQUIRED)
33 | find_package(OpenCV REQUIRED)
34 |
35 | catkin_package(
36 | CATKIN_DEPENDS roscpp rospy sensor_msgs std_msgs nav_msgs cv_bridge message_runtime
37 | )
38 |
39 | include_directories(
40 | ${catkin_INCLUDE_DIRS}
41 | ${PCL_INCLUDE_DIRS}
42 | )
43 |
44 | link_directories(${PCL_LIBRARY_DIRS})
45 | add_definitions(${PCL_DEFINITIONS})
46 |
47 | # Add your C++ source files here
48 | add_executable(depth_projection_synchronized src/depth_projection_synchronized.cpp) # src/utils.cpp)
49 | target_link_libraries(depth_projection_synchronized ${catkin_LIBRARIES} ${PCL_LIBRARIES})
50 | add_dependencies(depth_projection_synchronized ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})
51 |
52 | # Make the Python script executable
53 | catkin_install_python(PROGRAMS scripts/inference_node.py
54 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
55 | )
56 |
--------------------------------------------------------------------------------
/STEPP_ros/config/model_config.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/STEPP_ros/config/model_config.yaml
--------------------------------------------------------------------------------
/STEPP_ros/launch/STEPP.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/STEPP_ros/msg/Float32Stamped.msg:
--------------------------------------------------------------------------------
1 | std_msgs/Header header
2 | std_msgs/Float32MultiArray data
3 |
--------------------------------------------------------------------------------
/STEPP_ros/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package format="2">
3 |   <name>STEPP_ros</name>
4 |   <version>0.0.1</version>
5 |   <description>Traversability estimation package using image features</description>
6 |
7 |   <maintainer email="">Sebastian Aegidius</maintainer>
8 |   <license>MIT</license>
9 |
10 |   <buildtool_depend>catkin</buildtool_depend>
11 |   <build_depend>roscpp</build_depend>
12 |   <build_depend>rospy</build_depend>
13 |   <build_depend>sensor_msgs</build_depend>
14 |   <build_depend>std_msgs</build_depend>
15 |   <build_depend>nav_msgs</build_depend>
16 |   <build_depend>cv_bridge</build_depend>
17 |   <build_depend>message_generation</build_depend>
18 |
19 |   <exec_depend>message_runtime</exec_depend>
20 |   <exec_depend>roscpp</exec_depend>
21 |   <exec_depend>rospy</exec_depend>
22 |   <exec_depend>sensor_msgs</exec_depend>
23 |   <exec_depend>std_msgs</exec_depend>
24 |   <exec_depend>nav_msgs</exec_depend>
25 |   <exec_depend>cv_bridge</exec_depend>
26 |   <exec_depend>torch</exec_depend>
27 |
28 | </package>
--------------------------------------------------------------------------------
/STEPP_ros/scripts/inference_node.py:
--------------------------------------------------------------------------------
1 | #!/Rocket_ssd/miniconda3/envs/STEPP/bin/python3
2 |
3 | import rospy
4 | from sensor_msgs.msg import Image, CompressedImage
5 | from std_msgs.msg import Float32MultiArray, MultiArrayDimension
6 | import torch
7 | import cv2
8 | from cv_bridge import CvBridge
9 | import numpy as np
10 | import torch.nn as nn
11 | import time
12 | from PIL import Image as PILImage
13 | from torchvision import transforms
14 | # import seaborn as sns
15 | from matplotlib import cm
16 | import warnings
17 | from queue import Queue
18 | from threading import Thread, Lock
19 |
20 | from STEPP.DINO.backbone import get_backbone
21 | from STEPP.DINO.dino_feature_extract import DinoInterface
22 | from STEPP.DINO.dino_feature_extract import get_dino_features, average_dino_feature_segment, average_dino_feature_segment_tensor
23 | from STEPP.SLIC.slic_segmentation import SLIC
24 | from STEPP.model.mlp import ReconstructMLP
25 | from STEPP_ros.msg import Float32Stamped
26 |
27 | warnings.filterwarnings("ignore")
28 | CV_BRIDGE = CvBridge()
29 | TO_TENSOR = transforms.ToTensor()
30 | TO_PIL_IMAGE = transforms.ToPILImage()
31 |
32 | from pytictac import Timer
33 |
34 | class InferenceNode:
35 | def __init__(self):
36 | self.image_queue = Queue(maxsize=1)
37 | self.lock = Lock()
38 |
39 | self.processing = False
40 | self.image_sub = rospy.Subscriber('/camera/color/image_raw/compressed', CompressedImage, self.image_callback)
41 | self.inference_pub = rospy.Publisher('/inference/result', Float32MultiArray, queue_size=200)
42 | self.inference_stamped_pub = rospy.Publisher('/inference/results_stamped_post', Float32Stamped, queue_size=200)
43 | self.visu_traversability_pub = rospy.Publisher('/inference/visu_traversability_post', Image, queue_size=200)
44 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45 |
46 | # Threshold for traversability
47 | self.threshold = 0.2
48 |
49 | # Settings
50 | self.size = 700
51 | self.dino_size = "vit_small"
52 | self.patch = 14
53 | self.backbone = "dinov2"
54 | self.ump = rospy.get_param('~ump', True)
55 | self.cutoff = rospy.get_param('~cutoff', 1.2)
56 | print(self.cutoff)
57 | print(type(self.cutoff))
58 |
59 | # Inference with DINO
60 | # Create DINO
61 | self.di = DinoInterface(
62 | device=self.device,
63 | backbone=self.backbone,
64 | input_size=self.size,
65 | backbone_type=self.dino_size,
66 | patch_size=self.patch,
67 | interpolate=False,
68 | use_mixed_precision = self.ump,
69 | )
70 |
71 | self.slic = SLIC(crop_x=0, crop_y=0)
72 |
73 | # Load model architecture
74 | # self.model = ReconstructMLP(384, [256, 64, 32, 16, 32, 64, 256])
75 | self.model = ReconstructMLP(384, [256, 128, 64, 32, 64, 128, 256])
76 |
77 | # Load model weights
78 | state_dict = torch.load(rospy.get_param('~model_path'))
79 | self.model.load_state_dict(state_dict)
80 |
81 | # Move model to the device
82 | self.model.to(self.device)
83 |
84 | self.visualize = rospy.get_param('~visualize', False)
85 |
86 | self.thread = Thread(target=self.process_images)
87 | self.thread.start()
88 |
89 | print('Inference node initialized')
90 |
91 | def publish_matrix(self, matrix):
92 | msg = Float32MultiArray()
93 | msg.data = matrix.flatten().tolist() # Flatten the matrix and convert to list
94 | msg.layout.dim.append(MultiArrayDimension())
95 | msg.layout.dim[0].label = "rows"
96 | msg.layout.dim[0].size = matrix.shape[0]
97 | msg.layout.dim[0].stride = matrix.shape[1] # stride is the number of columns
98 | msg.layout.dim.append(MultiArrayDimension())
99 | msg.layout.dim[1].label = "columns"
100 | msg.layout.dim[1].size = matrix.shape[1]
101 | msg.layout.dim[1].stride = 1 # stride is 1 for columns
102 | self.inference_pub.publish(msg)
103 |
104 | def publish_array_stamped(self, matrix):
105 | msg = Float32Stamped()
106 |
107 | # Get the current time in nanoseconds
108 | msg.header.stamp = rospy.Time.now()
109 |
110 | msg.data = Float32MultiArray()
111 | msg.data.data = matrix.flatten().tolist() # Flatten the matrix and convert to list
112 | msg.data.layout.dim.append(MultiArrayDimension())
113 | msg.data.layout.dim[0].label = "rows"
114 | msg.data.layout.dim[0].size = matrix.shape[0]
115 | msg.data.layout.dim[0].stride = matrix.shape[1] # stride is the number of columns
116 | msg.data.layout.dim.append(MultiArrayDimension())
117 | msg.data.layout.dim[1].label = "columns"
118 | msg.data.layout.dim[1].size = matrix.shape[1]
119 | msg.data.layout.dim[1].stride = 1 # stride is 1 for columns
120 | self.inference_stamped_pub.publish(msg)
121 |
122 | def process_images(self):
123 | while not rospy.is_shutdown():
124 | # with Timer("Full loop"):
125 | image_data = self.image_queue.get()
126 | if image_data is None:
127 | break
128 |
129 | with self.lock:
130 | if isinstance(image_data, CompressedImage):
131 | cv_image = CV_BRIDGE.compressed_imgmsg_to_cv2(image_data, desired_encoding="bgr8")
132 | else:
133 | cv_image = CV_BRIDGE.imgmsg_to_cv2(image_data, desired_encoding="bgr8")
134 |
135 | try:
136 | traversability_array, inference_img = self.inference_image(cv_image)
137 | except Exception as e:
138 | print(f'Error: {e}')
139 | self.processing = False
140 | continue
141 |
142 | # self.publish_matrix(traversability_array)
143 | self.publish_array_stamped(traversability_array)
144 |
145 | if self.visualize:
146 | self.visu_traversability_pub.publish(CV_BRIDGE.cv2_to_imgmsg(np.array(inference_img), "rgb8"))
147 |
148 | self.processing = False
149 | # print('-'*10)
150 |
151 | def inference_image(self, image):
152 | # Load an image
153 | org_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
154 | img = cv2.resize(org_img, (self.size, self.size))
155 | H, W, D = org_img.shape
156 |
157 | # with Timer("DINO feature extraction"):
158 | # Get the dino features
159 | torch_img = torch.from_numpy(img)
160 | torch_img = torch_img.permute(2, 0, 1)
161 | torch_img = (torch_img.type(torch.float32) / 255)[None].to(self.device)
162 | # torch_img.to(self.device)
163 | dino_size = 'vit_small'
164 | # features = get_dino_features(torch_img, dino_size, False)
165 | features = self.di.inference(torch_img)
166 |
167 | # Segment the whole image and get each pixel for each segment value
168 | # with Timer("SLIC"):
169 | segments, segmented_image = self.slic.Slic_segmentation_for_all_pixels_torch(img)
170 | # with Timer("Make masks smaller"):
171 | resized_segmented_img, new_segment_dict = self.slic.make_masks_smaller_torch(segments, segmented_image, int(self.size/self.patch), return_dict=False)
172 | # Average the features over the segments
173 | # with Timer("Average dino feature"):
174 | average_features = average_dino_feature_segment_tensor(features, resized_segmented_img).to(self.device)
175 | # with Timer("Forward pass"):
176 | # Forward pass the entire batch
177 | reconstructed_features = self.model(average_features)
178 |
179 | # Calculate the losses for the entire batch
180 | # with Timer("Loss calculation"):
181 | loss_fn = nn.MSELoss(reduction='none')
182 | losses = loss_fn(average_features, reconstructed_features)
183 | losses = losses.mean(dim=1).cpu().detach().numpy() # Average the losses across the feature dimension
184 |
185 | # with Timer("Set segment values optimized"):
186 | segmented_image = segmented_image.cpu().detach().numpy()
187 | # Get the unique keys from the resized segmented image
188 | unique_keys = np.unique(resized_segmented_img.cpu().detach().numpy()).astype(int)
189 | # Create an array that maps the unique segment values to the corresponding losses
190 | max_segment_value = np.max(segmented_image)
191 | default_loss = 1.0
192 | mapping_array = np.full(max_segment_value + 1, default_loss)
193 | # Fill the mapping array with the corresponding losses
194 | mapping_array[unique_keys] = losses
195 | # Use the mapping array to replace values in segmented_image
196 | segmented_image = mapping_array[segmented_image]
197 |
198 |         # cut off the loss values at 10
199 |         segmented_image = np.where(segmented_image > 10, 10, segmented_image)
200 |
201 |         # Normalize the segmented image values to the range [0, cutoff]
202 |         segmented_image = ((segmented_image - segmented_image.min()) / (segmented_image.max() - segmented_image.min())) * self.cutoff
203 |
204 |         # Clamp all values above the threshold to the threshold
205 |         segmented_image = np.where(segmented_image > self.threshold, self.threshold, segmented_image)
206 | # segmented_image = np.where(segmented_image < self.threshold, 0.0, segmented_image)
207 |
208 | if self.visualize:
209 | # with Timer("image processing"):
210 | # Create the colormap
211 | s = 0.3 # If bigger, get more fine-grained green, if smaller get more fine-grained red
212 | cmap = cm.get_cmap("RdYlBu", 256) # or RdYlGn
213 | cmap = np.vstack([
214 | cmap(np.linspace(0, s, 128)),
215 | cmap(np.linspace(1 - s, 1.0, 128))
216 | ]) # Stretch the colormap
217 | cmap = (cmap[:, :3] * 255).astype(np.uint8)
218 |
219 | # Reverse the colormap if needed
220 | cmap = cmap[::-1]
221 |
222 | # Normalize the segmented image values to the range [0, 255]
223 | segmented_normalized = ((segmented_image - segmented_image.min()) /
224 | (segmented_image.max() - segmented_image.min()) * 255).astype(np.uint8)
225 |
226 | # Map the segmented image values to colors
227 | color_mapped_img = cmap[segmented_normalized]
228 |
229 | # Convert images to RGBA
230 | img_rgba = PILImage.fromarray(np.uint8(img)).convert("RGBA")
231 | seg_rgba = PILImage.fromarray(color_mapped_img).convert("RGBA")
232 |
233 | # Adjust the alpha channel to vary the transparency
234 | seg_rgba_np = np.array(seg_rgba)
235 | alpha_channel = seg_rgba_np[:, :, 3] # Extract alpha channel
236 | alpha_channel = (alpha_channel * 0.5).astype(np.uint8) # Adjust transparency (50% transparent)
237 | seg_rgba_np[:, :, 3] = alpha_channel # Update alpha channel
238 | seg_rgba = PILImage.fromarray(seg_rgba_np)
239 |
240 | # Alpha composite the images
241 | img_new = PILImage.alpha_composite(img_rgba, seg_rgba)
242 | img_rgb = img_new.convert("RGB")
243 |
244 | #resize the image and the segmented image to the original size
245 | img_rgb = img_rgb.resize((W,H))
246 | segmented_image = cv2.resize(segmented_image, (W,H))
247 |
248 | return segmented_image, img_rgb
249 | else:
250 | segmented_image = cv2.resize(segmented_image, (W,H))
251 |
252 | return segmented_image, None
253 |
254 | def image_callback(self, data):
255 | if not self.processing:
256 | with self.lock:
257 | if not self.image_queue.full():
258 | self.image_queue.put(data)
259 | self.processing = True
260 |
261 | if __name__ == '__main__':
262 | print('Starting inference node')
263 | rospy.init_node('inference_node')
264 | node = InferenceNode()
265 | rospy.spin()
266 |
--------------------------------------------------------------------------------
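The depth projection node below consumes the stamped matrices published above. For reference, the same `Float32Stamped` message can also be decoded on the Python side; the following is only a minimal sketch of such a subscriber (the node name and log output are illustrative and not part of the repo), reshaping the flattened data with the `rows`/`columns` layout dimensions set in `publish_array_stamped`.

```python
# Minimal sketch of a consumer for /inference/results_stamped_post: rebuild the
# per-pixel loss matrix from the Float32Stamped layout filled in by inference_node.py.
import numpy as np
import rospy
from STEPP_ros.msg import Float32Stamped

def callback(msg):
    rows = msg.data.layout.dim[0].size
    cols = msg.data.layout.dim[1].size
    # the data is flattened row-major (stride of dim[0] equals the number of columns)
    loss_matrix = np.asarray(msg.data.data, dtype=np.float32).reshape(rows, cols)
    rospy.loginfo('received %dx%d traversability matrix, mean loss %.3f',
                  rows, cols, float(loss_matrix.mean()))

if __name__ == '__main__':
    rospy.init_node('traversability_listener')
    rospy.Subscriber('/inference/results_stamped_post', Float32Stamped, callback)
    rospy.spin()
```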
/STEPP_ros/src/depth_projection_synchronized.cpp:
--------------------------------------------------------------------------------
1 | #include <ros/ros.h>
2 | #include <sensor_msgs/Image.h>
3 | #include <sensor_msgs/PointCloud2.h>
4 | #include <sensor_msgs/image_encodings.h>
5 | #include <nav_msgs/Odometry.h>
6 | #include <geometry_msgs/Point.h>
7 | #include <geometry_msgs/Quaternion.h>
8 | #include <std_msgs/Float32MultiArray.h>
9 | #include <tf/transform_datatypes.h>
10 | #include <cv_bridge/cv_bridge.h>
11 | #include <message_filters/subscriber.h>
12 | #include <message_filters/synchronizer.h>
13 | #include <message_filters/sync_policies/approximate_time.h>
14 | #include <pcl/point_cloud.h>
15 | #include <pcl/point_types.h>
16 | #include <pcl/filters/voxel_grid.h>
17 | #include <pcl/common/transforms.h>
18 | #include <pcl_conversions/pcl_conversions.h>
19 | #include <Eigen/Dense>
20 | #include <STEPP_ros/Float32Stamped.h>
21 | #include <cmath>
22 |
23 | using namespace std;
24 |
25 | const double PI = 3.1415926;
26 | double depthCloudTime = 0.0;
27 | double systemInitTime = 0;
28 | bool systemInited = false;
29 | bool firstLoss = true;
30 | bool firstStampedLoss = true;
31 | bool newDepthCloud = false;
32 | float vehicleX = 0, vehicleY = 0, vehicleZ = 0;
33 | float vehicleRoll = 0, vehiclePitch = 0, vehicleYaw = 0;
34 | float sinVehicleRoll = 0, cosVehicleRoll = 0;
35 | float sinVehiclePitch = 0, cosVehiclePitch = 0;
36 | float sinVehicleYaw = 0, cosVehicleYaw = 0;
37 | float voxel_size_ = 0.1;
38 | double noDecayDis = 5.0;
39 | double minDis = 1.5;
40 | double clearingDis = 3.0;
41 | double vehicleHeight = 0.5;
42 | double decayTime = 8.0;
43 | double height = 720;
44 | double width = 1280;
45 | float fovy;
46 | float fovx;
47 | float azimuth_buff = 0.0;
48 | int rows = 1, cols = 1;
49 | int row_stride = 1, col_stride = 1;
50 |
51 | Eigen::Matrix4f cameraToMapTransform;
52 | ros::Publisher cloudPub;
53 |
54 | pcl::VoxelGrid<pcl::PointXYZINormal> downSizeFilter;
55 |
56 | struct CameraIntrinsics {
57 | double fx;
58 | double fy;
59 | double cx;
60 | double cy;
61 | };
62 |
63 | CameraIntrinsics intrinsics;
64 | tf::Transform odomTransform;
65 | std_msgs::Float32MultiArray loss;
66 | STEPP_ros::Float32Stamped losStamped;
67 |
68 | pcl::PointCloud<pcl::PointXYZINormal>::Ptr
69 |     cloud(new pcl::PointCloud<pcl::PointXYZINormal>);
70 | pcl::PointCloud<pcl::PointXYZINormal>::Ptr
71 |     sparseCloud(new pcl::PointCloud<pcl::PointXYZINormal>);
72 | pcl::PointCloud<pcl::PointXYZINormal>::Ptr
73 |     transformedCloud(new pcl::PointCloud<pcl::PointXYZINormal>);
74 | pcl::PointCloud<pcl::PointXYZINormal>::Ptr
75 |     terrainCloud(new pcl::PointCloud<pcl::PointXYZINormal>);
76 | pcl::PointCloud<pcl::PointXYZINormal>::Ptr
77 |     sparseTerrainCloud(new pcl::PointCloud<pcl::PointXYZINormal>);
78 | pcl::PointCloud<pcl::PointXYZINormal>::Ptr
79 |     currentCloud(new pcl::PointCloud<pcl::PointXYZINormal>);
80 | pcl::PointCloud<pcl::PointXYZI>::Ptr
81 |     pubCloud(new pcl::PointCloud<pcl::PointXYZI>);
82 |
83 | void setCameraIntrinsics(const std::string& cameraType) {
84 | ROS_INFO("Setting camera intrinsics for %s camera", cameraType.c_str());
85 | if (cameraType == "D455") {
86 | intrinsics = {634.3491821289062, 632.8595581054688, 631.8179931640625, 375.0325622558594};
87 | height = 720;
88 | width = 1280;
89 | } else if (cameraType == "zed2") {
90 | intrinsics = {534.3699951171875, 534.47998046875, 477.2049865722656, 262.4590148925781};
91 | height = 540;
92 | width = 960-2*azimuth_buff;
93 | } else if (cameraType == "cmu_sim") {
94 | intrinsics = {205.46963709898583, 205.46963709898583, 320.5, 180.5};
95 | height = 360;
96 | width = 640-2*azimuth_buff;
97 | } else {
98 | ROS_ERROR("Invalid camera type specified. Please choose from 'D455', 'zed2', or 'cmu_sim'.");
99 | ros::shutdown();
100 | }
101 |
102 | fovy = 2 * atan(height / (2 * intrinsics.fy));
103 | fovx = 2 * atan(width / (2 * intrinsics.fx));
104 | }
105 |
106 | // Convert 2D pixel coordinates to 3D point
107 | pcl::PointXYZ convertTo3DPoint(int u, int v, float depth, const CameraIntrinsics& intrinsics) {
108 | pcl::PointXYZ point;
109 | point.z = depth;
110 | point.x = (u - intrinsics.cx) / intrinsics.fx * depth;
111 | point.y = (v - intrinsics.cy) / intrinsics.fy * depth;
112 | return point;
113 | }
114 |
115 | void callback(const sensor_msgs::Image::ConstPtr& depthMsg,
116 | const nav_msgs::Odometry::ConstPtr& odomMsg,
117 | const STEPP_ros::Float32StampedConstPtr& customMsg) {
118 |
119 | // if (loss.data.empty()) { // Check if the loss data is not initialized
120 | // ROS_WARN("Loss data not available yet.");
121 | // return; // Skip this callback cycle
122 | // }
123 | if (firstStampedLoss) {
124 | rows = customMsg->data.layout.dim[0].size;
125 | cols = customMsg->data.layout.dim[1].size;
126 | row_stride = customMsg->data.layout.dim[0].stride;
127 | col_stride = customMsg->data.layout.dim[1].stride;
128 | firstStampedLoss = false;
129 | }
130 | // losStamped = *customMsg;
131 |
132 | // Extract the position and orientation from the odometry message
133 | double roll, pitch, yaw;
134 | geometry_msgs::Point position = odomMsg->pose.pose.position;
135 | geometry_msgs::Quaternion orientation = odomMsg->pose.pose.orientation;
136 | tf::Matrix3x3(tf::Quaternion(orientation.x, orientation.y, orientation.z, orientation.w))
137 | .getRPY(roll, pitch, yaw);
138 |
139 | vehicleX = odomMsg->pose.pose.position.x;
140 | vehicleY = odomMsg->pose.pose.position.y;
141 | vehicleZ = odomMsg->pose.pose.position.z;
142 |
143 | //temp [7.251, -10.919, -3.618]
144 | // vehicleX = vehicleX - 7.251;
145 | // vehicleY = vehicleY + 10.919;
146 | // vehicleZ = vehicleZ + 3.618;
147 |
148 | vehicleRoll = roll;
149 | vehiclePitch = pitch;
150 | vehicleYaw = yaw;
151 |
152 | sinVehicleRoll = sin(vehicleRoll);
153 | cosVehicleRoll = cos(vehicleRoll);
154 | sinVehiclePitch = sin(vehiclePitch);
155 | cosVehiclePitch = cos(vehiclePitch);
156 | sinVehicleYaw = sin(vehicleYaw);
157 | cosVehicleYaw = cos(vehicleYaw);
158 |
159 | // Convert the position and orientation into a transform
160 | tf::Transform transform;
161 | transform.setOrigin(tf::Vector3(position.x, position.y, position.z));
162 | tf::Quaternion quat(orientation.x, orientation.y, orientation.z, orientation.w);
163 | transform.setRotation(quat);
164 |
165 | // Store the transformation to be used when processing the point cloud
166 | odomTransform = transform;
167 |
168 | // Extract the depth image from the depth message
169 | depthCloudTime = depthMsg->header.stamp.toSec();
170 |
171 | if (!systemInited) {
172 | systemInitTime = depthCloudTime;
173 | systemInited = true;
174 | }
175 |
176 | cloud->clear();
177 | cv_bridge::CvImageConstPtr cv_ptr;
178 | try {
179 | cv_ptr = cv_bridge::toCvShare(depthMsg, depthMsg->encoding);
180 | } catch (cv_bridge::Exception& e) {
181 | ROS_ERROR("cv_bridge exception: %s", e.what());
182 | return;
183 | }
184 |
185 | if (depthMsg->encoding == sensor_msgs::image_encodings::TYPE_32FC1) {
186 | for (int v = 0; v < depthMsg->height; ++v) {
187 | for (int u = azimuth_buff; u < depthMsg->width-azimuth_buff; ++u) {
188 |         float depth = cv_ptr->image.at<float>(v, u); // Access the depth value as float (meters)
189 | if (depth > 0) { // Check for valid depth
190 | pcl::PointXYZ point = convertTo3DPoint(u, v, depth, intrinsics);
191 | pcl::PointXYZINormal iPoint;
192 | iPoint.x = point.x;
193 | iPoint.y = point.y;
194 | iPoint.z = point.z;
195 |           iPoint.intensity = systemInitTime - depthCloudTime;
196 | iPoint.curvature = customMsg->data.data[v * row_stride + u * col_stride];
197 | cloud->points.push_back(iPoint);
198 | }
199 | }
200 | }
201 | } else if (depthMsg->encoding == sensor_msgs::image_encodings::TYPE_16UC1) {
202 | for (int v = 0; v < depthMsg->height; ++v) {
203 | for (int u = azimuth_buff; u < depthMsg->width-azimuth_buff; ++u) {
204 |         uint16_t depth_mm = cv_ptr->image.at<uint16_t>(v, u); // Access the depth value as uint16_t
205 | float depth = depth_mm * 0.001f; // Convert millimeters to meters
206 | if (depth != 0) { // Check for valid depth
207 | pcl::PointXYZ point = convertTo3DPoint(u, v, depth, intrinsics);
208 | pcl::PointXYZINormal iPoint;
209 | iPoint.x = point.x;
210 | iPoint.y = point.y;
211 | iPoint.z = point.z;
212 |           iPoint.intensity = systemInitTime - depthCloudTime; // same sign convention as the 32FC1 branch so the decay check measures point age
213 | iPoint.curvature = customMsg->data.data[v * row_stride + u * col_stride];
214 | cloud->points.push_back(iPoint);
215 | }
216 | }
217 | }
218 | } else {
219 | ROS_ERROR("Unsupported depth encoding: %s", depthMsg->encoding.c_str());
220 | return;
221 | }
222 | newDepthCloud = true;
223 | // ROS_INFO("Input cloud size %zu", cloud->points.size());
224 | }
225 |
226 | int main(int argc, char** argv) {
227 | ros::init(argc, argv, "depth_projection");
228 | ros::NodeHandle nh;
229 |
230 | std::string cameraType;
231 | nh.getParam("/depth_projection/camera_type", cameraType);
232 | nh.getParam("/depth_projection/decayTime", decayTime);
233 | setCameraIntrinsics(cameraType);
234 |
235 | // Set up subscribers using message_filters
236 |   message_filters::Subscriber<sensor_msgs::Image> depthSub(nh, "/camera/aligned_depth_to_color/image_raw", 1);
237 |   message_filters::Subscriber<nav_msgs::Odometry> odomSub(nh, "/state_estimation", 1);
238 |   message_filters::Subscriber<STEPP_ros::Float32Stamped> customMsgSub(nh, "/inference/results_stamped_post", 1);
239 |
240 | // Create ApproximateTime policy
241 |   typedef message_filters::sync_policies::ApproximateTime<sensor_msgs::Image, nav_msgs::Odometry, STEPP_ros::Float32Stamped> MySyncPolicy;
242 |   message_filters::Synchronizer<MySyncPolicy> sync(MySyncPolicy(10), depthSub, odomSub, customMsgSub);
243 | sync.setInterMessageLowerBound(ros::Duration(1.5)); // Adjust time tolerance
244 | sync.registerCallback(boost::bind(&callback, _1, _2, _3));
245 |
246 | // ros::Subscriber lossSub = nh.subscribe("/inference/results", 10, lossCallback);
247 |
248 | // cameraToMapTransform << 0.0, 0.0, 1.0, 0.0, // CMU_SIM transform
249 | // -1.0, 0.0, 0.0, 0.0,
250 | // 0.0,-1.0, 0.0, 0.0,
251 | // 0.0, 0.0, 0.0, 1.0;
252 |
253 | cameraToMapTransform << 0.01165962, -0.02415892, 0.99964014, 0.482,
254 | -0.99953617, 0.02784553, 0.01233136, 0.04,
255 | -0.02813342, -0.99932026, -0.02382304, 0.249,
256 | 0.0, 0.0, 0.0, 1.0;
257 |
258 |   cloudPub = nh.advertise<sensor_msgs::PointCloud2>("/depth_projection", 10);
259 |
260 | downSizeFilter.setLeafSize(voxel_size_, voxel_size_, voxel_size_);
261 |
262 | //print out the camera intrinsics
263 | ROS_INFO("Camera intrinsics: fx = %f, fy = %f, cx = %f, cy = %f", intrinsics.fx, intrinsics.fy, intrinsics.cx, intrinsics.cy);
264 |
265 | ros::Rate rate(200);
266 | bool status = ros::ok();
267 | while (status) {
268 | ros::spinOnce();
269 | if (newDepthCloud) {
270 | newDepthCloud = false;
271 |
272 | //clear point clouds
273 | terrainCloud->clear();
274 | transformedCloud->clear();
275 | sparseCloud->clear();
276 | sparseTerrainCloud->clear();
277 |
278 | // Update terrain cloud as to get rid of old points outside decay distance
279 | int currentCloudSize = currentCloud->points.size();
280 | for (int i = 0; i < currentCloudSize; i++) {
281 | pcl::PointXYZINormal point = currentCloud->points[i];
282 |
283 | // Translate point to vehicle coordinate frame
284 | float translatedX = point.x - vehicleX;
285 | float translatedY = point.y - vehicleY;
286 | float translatedZ = point.z - vehicleZ;
287 |
288 | // Rotate point according to vehicle orientation
289 | float rotatedX = cosVehicleYaw * translatedX + sinVehicleYaw * translatedY;
290 | float rotatedY = -sinVehicleYaw * translatedX + cosVehicleYaw * translatedY;
291 | float rotatedZ = cosVehiclePitch * translatedZ - sinVehiclePitch * rotatedX;
292 |
293 | // Calculate planar distance in XY plane
294 | float dis = sqrt(rotatedX * rotatedX + rotatedY * rotatedY);
295 |
296 | // Calculate azimuth and elevation angles
297 | float angle1 = atan2(rotatedY, rotatedX); // Azimuth angle
298 | float angle2 = atan2(rotatedZ, dis); // Elevation angle
299 |
300 | // Check if the point is outside the decay time OR within no-decay distance
301 | // Also, check if the point is outside the FOV in both azimuth and elevation
302 | if ((depthCloudTime - systemInitTime + point.intensity < decayTime || dis < clearingDis)
303 | && point.z < vehicleHeight
304 | && (((fabs(angle1) > (fovx / 2) - 8*(PI/180) || fabs(angle2) > (fovy / 2))) || dis < minDis)) { // Use OR instead of AND
305 | terrainCloud->push_back(point);
306 | }
307 | // ROS_INFO("sysinit %f, depth %f, intensity %f, time diff %f",systemInitTime, depthCloudTime, point.intensity, depthCloudTime - systemInitTime - point.intensity);
308 | }
309 |
310 | //filter the terrain cloud
311 | downSizeFilter.setInputCloud(terrainCloud);
312 | downSizeFilter.filter(*sparseTerrainCloud);
313 |
314 | //filter input depth cloud
315 | // downSizeFilter.setInputCloud(cloud);
316 | // downSizeFilter.filter(*sparseCloud);
317 |
318 | // Transform the point cloud to the map frame
319 | pcl::transformPointCloud(*cloud, *transformedCloud, cameraToMapTransform);
320 |
321 | // ROS_INFO("transformedCloud size %zu", transformedCloud->points.size());
322 |
323 | // Transform each point in the cloud to be in the odometry frame
324 | int transformedCloudSize = transformedCloud->points.size();
325 | for (int i =0; i < transformedCloudSize; i++) {
326 | pcl::PointXYZINormal point = transformedCloud->points[i];
327 | tf::Vector3 p(point.x, point.y, point.z);
328 | tf::Vector3 pTransformed = odomTransform * p;
329 | pcl::PointXYZINormal newPoint;
330 | newPoint.x = pTransformed.x();
331 | newPoint.y = pTransformed.y();
332 | newPoint.z = pTransformed.z();
333 | newPoint.intensity = point.intensity;
334 | newPoint.curvature = point.curvature;
335 | float dis = sqrt((newPoint.x - vehicleX) * (newPoint.x - vehicleX) + (newPoint.y - vehicleY) * (newPoint.y - vehicleY));
336 | if (newPoint.z < vehicleZ + vehicleHeight && dis > minDis && dis < noDecayDis) {
337 | sparseTerrainCloud->push_back(newPoint);
338 | }
339 | }
340 |
341 |       currentCloud = pcl::PointCloud<pcl::PointXYZINormal>::Ptr(new pcl::PointCloud<pcl::PointXYZINormal>(*sparseTerrainCloud));
342 |
343 | //loop through the terrain cloud
344 | pubCloud->clear();
345 | int terrainCloudSize = sparseTerrainCloud->points.size();
346 | for (int i = 0; i < terrainCloudSize; i++) {
347 | pcl::PointXYZINormal point = sparseTerrainCloud->points[i];
348 | pcl::PointXYZI newPoint;
349 | newPoint.x = point.x;
350 | newPoint.y = point.y;
351 | newPoint.z = point.z;
352 | newPoint.intensity = point.curvature;
353 | // newPoint.intensity = point.z;
354 | pubCloud->push_back(newPoint);
355 | }
356 |
357 | // Publish the terrain cloud
358 | sensor_msgs::PointCloud2 terrainCloud2;
359 | pcl::toROSMsg(*pubCloud, terrainCloud2);
360 | terrainCloud2.header.frame_id = "odom";
361 | terrainCloud2.header.stamp = ros::Time().fromSec(depthCloudTime);
362 | cloudPub.publish(terrainCloud2);
363 | }
364 |
365 | status = ros::ok();
366 | rate.sleep();
367 | }
368 |
369 | return 0;
370 | }
--------------------------------------------------------------------------------
/assets/front_page.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/assets/front_page.png
--------------------------------------------------------------------------------
/assets/outdoor_all_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/assets/outdoor_all_2.png
--------------------------------------------------------------------------------
/assets/pre_train_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/assets/pre_train_pipeline.png
--------------------------------------------------------------------------------
/checkpoints/all_ViT_small_input_700_big_nn_checkpoint_20240827-1935.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/checkpoints/all_ViT_small_input_700_big_nn_checkpoint_20240827-1935.pth
--------------------------------------------------------------------------------
/checkpoints/richmond_forest_full_ViT_small_big_nn_checkpoint_20240821-1825.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/checkpoints/richmond_forest_full_ViT_small_big_nn_checkpoint_20240821-1825.pth
--------------------------------------------------------------------------------
/checkpoints/unreal_full_ViT_small_big_nn_checkpoint_20240819-2003.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/checkpoints/unreal_full_ViT_small_big_nn_checkpoint_20240819-2003.pth
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="STEPP",
5 | version="1.0.0",
6 | author="Sebastian Aegidius",
7 | author_email="your.email@example.com",
8 | description="Traversability estimation package using image features",
9 | # long_description=open("README.md").read(),
10 | # long_description_content_type="text/markdown",
11 | url="https://github.com/RPL-CS-UCL/STEPP-Code",
12 | packages=find_packages(),
13 | classifiers=[
14 | "Programming Language :: Python :: 3",
15 | "License :: OSI Approved :: MIT License",
16 | "Operating System :: OS Independent",
17 | ],
18 | python_requires='>=3.8',
19 | install_requires=[
20 | #generic
21 | "numpy",
22 | "tqdm",
23 | "kornia>=0.6.5",
24 | "pip",
25 | "torchvision",
26 | "torch>=1.21",
27 | "torchmetrics",
28 | "pytorch_lightning>=1.6.5",
29 | "pytest",
30 | "scipy",
31 | "scikit-image",
32 | "scikit-learn",
33 | "matplotlib",
34 | "seaborn",
35 | "pandas",
36 | "pytictac",
37 | "torch_geometric",
38 | "omegaconf",
39 | "optuna",
40 | "neptune",
41 | "fast-slic",
42 | "hydra-core",
43 | "prettytable",
44 | "termcolor",
45 | "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git",
46 | "liegroups@git+https://github.com/mmattamala/liegroups",
47 | "wget",
48 | "rospkg",
49 | "wandb",
50 | "opencv-python",
51 | ],
52 | include_package_data=True,
53 | package_data={
54 | },
55 | )
--------------------------------------------------------------------------------