├── README.md
├── STEPP
│   ├── DINO
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   └── dino_feature_extract.py
│   ├── SLIC
│   │   └── slic_segmentation.py
│   ├── __init__.py
│   ├── model
│   │   ├── mlp.py
│   │   └── training.py
│   └── utils
│       ├── colorbar.py
│       ├── data_loader.py
│       ├── extract_future_poses.py
│       ├── image_saver.py
│       ├── make_dataset.py
│       ├── make_unreal_data_pixel_file.py
│       ├── misc.py
│       ├── rename_files.py
│       └── testing.py
├── STEPP_ros
│   ├── CMakeLists.txt
│   ├── config
│   │   └── model_config.yaml
│   ├── launch
│   │   └── STEPP.launch
│   ├── msg
│   │   └── Float32Stamped.msg
│   ├── package.xml
│   ├── scripts
│   │   └── inference_node.py
│   └── src
│       └── depth_projection_synchronized.cpp
├── assets
│   ├── front_page.png
│   ├── outdoor_all_2.png
│   └── pre_train_pipeline.png
├── checkpoints
│   ├── all_ViT_small_input_700_big_nn_checkpoint_20240827-1935.pth
│   ├── richmond_forest_full_ViT_small_big_nn_checkpoint_20240821-1825.pth
│   └── unreal_full_ViT_small_big_nn_checkpoint_20240819-2003.pth
└── setup.py

/README.md:
--------------------------------------------------------------------------------
# Watch your STEPP: Semantic Traversability Estimation using Pose Projected Features #
![demo](assets/front_page.png)

**Authors**: [Sebastian Aegidius*](https://rvl.cs.toronto.edu/), [Dennis Hadjivelichkov](https://dennisushi.github.io/), [Jianhao Jiao](https://gogojjh.github.io/), [Jonathan Embly-Riches](https://rpl-as-ucl.github.io/people/), [Dimitrios Kanoulas](https://dkanou.github.io/)

[Project Page](https://rpl-cs-ucl.github.io/STEPP/)  [STEPP arXiv](https://arxiv.org/)

![demo](assets/outdoor_all_2.png)
![demo](assets/pre_train_pipeline.png)

## Installation ##
```bash
conda create -n STEPP python=3.8
conda activate STEPP
cd
# We use CUDA 12.1 drivers
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
```
Place the repository in the catkin workspace that contains your planner implementation of choice. We used the CMU Falco local planner from the [autonomous exploration development environment](https://www.cmu-exploration.com/).

```bash
# Assuming an already set up and built ROS workspace (a workspace containing cmu-exploration or any other navigation stack)
cd your_navigation_ws/src
git clone git@github.com:RPL-CS-UCL/STEPP-Code.git
cd STEPP-Code
pip install -e .
cd ../..
catkin build STEPP_ros
```

For installing JetPack, PyTorch, and Torchvision on a Jetson platform, see the [PyTorch build guide for Jetson](https://pytorch.org/audio/stable/build.jetson.html) and the [NVIDIA forum thread](https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048).
* Show the JetPack version: `apt-cache show nvidia-jetpack`
* [MUST] Create a conda environment with python=3.8 and download the wheel from this [link](https://nvidia.box.com/shared/static/i8pukc49h3lhak4kkn67tg9j4goqm0m7.whl).
* Then run `pip install torch-2.0.0+nv23.05-cp38-cp38-linux_aarch64.whl`.
* Install Torchvision (check the compatibility matrix for the corresponding PyTorch version).
  * If you hit `ValueError: Unknown CUDA arch (8.7+PTX) or GPU not supported`, see this [link](https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048/1285?page=63).
  * Command:
    ```
    pip install numpy && \
    pip install torch-2.0.0+nv23.05-cp38-cp38-linux_aarch64.whl && \
    cd torchvision/ && \
    export BUILD_VERSION=0.15.1 && \
    python setup.py install --user && \
    python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); import torchvision"
    ```

## Checkpoints ##

The following trained checkpoints are included in the repo:

| Model | Dataset | Image resolution | DINOv2 size | MLP architecture |
|-------|---------|------------------|-------------|------------------|
| [`richmond_forest.pth`](checkpoints/richmond_forest_full_ViT_small_big_nn_checkpoint_20240821-1825.pth) | Richmond Forest | 700x700 | dinov2_vits14 | big_nn |
| [`unreal_synthetic_data.pth`](checkpoints/unreal_full_ViT_small_big_nn_checkpoint_20240819-2003.pth) | Unreal Engine synthetic data | 700x700 | dinov2_vits14 | big_nn |
| [`all_data.pth`](checkpoints/all_ViT_small_input_700_big_nn_checkpoint_20240827-1935.pth) | Richmond Forest, Unreal Engine synthetic data | 700x700 | dinov2_vits14 | big_nn |

## Usage ##
To launch the model, set all required paths correctly, build your workspace, and run:
```bash
roslaunch STEPP_ros STEPP.launch
```

### STEPP.launch Arguments
- `model_path`: path to your chosen checkpoint `.pth` file.
- `visualize`: whether to output the traversability cost overlaid onto the image feed (slows down inference).
- `ump`: whether to use mixed precision for model inference.
  This makes inference faster, but the model weights need to be retrained for best performance.
- `cutoff`: sets the maximum normalized reconstruction error.
- `camera_type`: one of `zed2`, `D455`, `cmu_sim`; selects the camera intrinsics used for depth projection.
- `decayTime`: (unfinished) how long the depth point cloud with cost is remembered once it is outside the decay zone and the active camera view.

## Train Your Own STEPP Inference Model ##
To train your own STEPP traversability estimation model, all you need is a dataset consisting of an image folder and an odometry pose folder. Each SE(3) odometry pose must correspond to the exact position and orientation of the camera when its associated image was taken. With these two, you can run the `extract_future_poses.py` script to obtain a JSON file containing the pixels that represent the camera's future poses in each image frame.

With this JSON file and the associated images, you can run `make_dataset.py` to obtain a `.npy` file of DINOv2 feature vectors, each averaged over an image segment that one of the future poses falls into.
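
The segment-averaging step attributed to `make_dataset.py` can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the script's actual code. It assumes:

- a per-pixel feature map of shape `(D, H, W)`, for example DINOv2 features upsampled to image resolution;
- an integer superpixel label map of shape `(H, W)`, for example from SLIC;
- future-pose pixels given as `(row, col)` pairs.

The function name `segment_features_for_poses` is hypothetical.

```python
import numpy as np


def segment_features_for_poses(features, segment_img, pose_pixels):
    """Average per-pixel features over every superpixel that a future-pose pixel falls into.

    features:    (D, H, W) per-pixel feature map (assumed input, e.g. upsampled DINOv2 features)
    segment_img: (H, W) integer superpixel label map (e.g. SLIC output)
    pose_pixels: sequence of (row, col) pixel coordinates of projected future poses
    Returns an (N, D) array with one averaged feature vector per traversed segment,
    ordered by ascending segment ID.
    """
    pose_pixels = np.asarray(pose_pixels, dtype=int).reshape(-1, 2)
    h, w = segment_img.shape
    rows, cols = pose_pixels[:, 0], pose_pixels[:, 1]

    # Discard projected poses that land outside the image
    inside = (rows >= 0) & (rows < h) & (cols >= 0) & (cols < w)
    rows, cols = rows[inside], cols[inside]

    # Each traversed segment is used once, however many poses fall into it
    hit_segments = np.unique(segment_img[rows, cols])

    vectors = []
    for seg_id in hit_segments:
        mask = segment_img == seg_id                    # (H, W) boolean mask of this segment
        vectors.append(features[:, mask].mean(axis=1))  # (D,) mean over the segment's pixels
    if not vectors:
        return np.empty((0, features.shape[0]), dtype=features.dtype)
    return np.stack(vectors)


# Toy example: four 2x2 segments on a 4x4 image with 3-dimensional features
features = np.random.rand(3, 4, 4).astype(np.float32)
segment_img = np.array([[0, 0, 1, 1],
                        [0, 0, 1, 1],
                        [2, 2, 3, 3],
                        [2, 2, 3, 3]])
pose_pixels = [(3, 0), (2, 1), (1, 3)]  # fall into segments 2, 2 and 1
training_rows = segment_features_for_poses(features, segment_img, pose_pixels)
print(training_rows.shape)  # (2, 3)
```

Each returned row plays the role of one positive training sample. Collecting the rows over all images and saving them with `np.save` would give an `.npy` array of the kind described above. Whether the real script deduplicates segments across poses in the same way is an assumption of this sketch.
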
The resulting `.npy` file can then be used to train the STEPP model with `training.py`.

### Acknowledgement
https://github.com/leggedrobotics/wild_visual_navigation\
https://github.com/facebookresearch/dinov2\
https://github.com/HongbiaoZ/autonomous_exploration_development_environment

### Citation
If you find our work useful, please consider citing it:

```bibtex
Coming soon
```

--------------------------------------------------------------------------------
/STEPP/DINO/__init__.py:
--------------------------------------------------------------------------------
from .dino_feature_extract import DinoInterface, run_dino_interfacer
--------------------------------------------------------------------------------
/STEPP/DINO/backbone.py:
--------------------------------------------------------------------------------
#
# Copyright (c) Mark Hamilton. All rights reserved.
# Copyright (c) 2022-2024, ETH Zurich, Piotr Libera, Jonas Frey, Matias Mattamala.
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
#
import torch
from torch import nn
import numpy as np
from abc import ABC, abstractmethod


def get_backbone(cfg):
    """
    Returns a selected DINOv2VIT backbone.
    After implementing the Backbone class for your backbone, add it to be returned from this function with a desired name.
    The backbone can then be used by specifying its name in the STEGO configuration file.
    """
    if not hasattr(cfg, "backbone"):
        raise ValueError("Could not find 'backbone' option in the config file. Please check it")

    if cfg.backbone == "dinov2":
        return Dinov2ViT(cfg)
    else:
        raise ValueError("Backbone {} unavailable".format(cfg.backbone))


class Backbone(ABC, nn.Module):
    """
    Base class to provide an interface for new STEGO backbones.

    To add a new backbone for use in STEGO, add a new implementation of this class.
    """

    vit_name_long_to_short = {
        "vit_tiny": "T",
        "vit_small": "S",
        "vit_base": "B",
        "vit_large": "L",
        "vit_huge": "H",
        "vit_giant": "G",
    }

    # Initialize the backbone
    @abstractmethod
    def __init__(self, cfg):
        super().__init__()

    # Return the size of features generated by the backbone
    @abstractmethod
    def get_output_feat_dim(self) -> int:
        pass

    # Generate features for the given image
    @abstractmethod
    def forward(self, img):
        pass

    # Return a name that identifies the type of the backbone
    @abstractmethod
    def get_backbone_name(self):
        pass


class Dinov2ViT(Backbone):
    def __init__(self, cfg):
        super().__init__(cfg)
        self.cfg = cfg
        self.backbone_type = self.cfg.backbone_type
        self.patch_size = 14
        if self.backbone_type == "vit_small":
            self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14", verbose=False)
        elif self.backbone_type == "vit_base":
            self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14", verbose=False)
        elif self.backbone_type == "vit_large":
            self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14", verbose=False)
        elif self.backbone_type == "vit_giant":
            self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", verbose=False)
        elif self.backbone_type == "vit_small_reg":
            self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14_reg", verbose=False)
        elif self.backbone_type == "vit_base_reg":
            self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg", verbose=False)
        elif self.backbone_type == "vit_large_reg":
            self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg", verbose=False)
        elif self.backbone_type == "vit_giant_reg":
            self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14_reg", verbose=False)
        else:
            raise ValueError("Model type {} unavailable".format(cfg.backbone_type))

        for p in self.model.parameters():
            p.requires_grad = False
        # Fall back to the CPU when CUDA is unavailable instead of failing in .cuda()
        self.model.eval().to("cuda" if torch.cuda.is_available() else "cpu")
        self.dropout = torch.nn.Dropout2d(p=np.clip(self.cfg.dropout_p, 0.0, 1.0))

        # Feature dimension per ViT size; the "_reg" (register) variants share the size of their base model
        feat_dims = {"vit_small": 384, "vit_base": 768, "vit_large": 1024, "vit_giant": 1536}
        self.n_feats = feat_dims[self.backbone_type.replace("_reg", "")]

    def get_output_feat_dim(self):
        return self.n_feats

    def forward(self, img):
        self.model.eval()
        with torch.no_grad():
            assert img.shape[2] % self.patch_size == 0
            assert img.shape[3] % self.patch_size == 0

            # get selected layer activations
            feat = self.model.get_intermediate_layers(img)[0]

            feat_h = img.shape[2] // self.patch_size
            feat_w = img.shape[3] // self.patch_size

            image_feat = feat[:, :, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)

        if self.cfg.dropout_p > 0:
            return self.dropout(image_feat)
        else:
            return image_feat

    def get_backbone_name(self):
        # Look up the base size so that "_reg" variants do not raise a KeyError
        base_type = self.backbone_type.replace("_reg", "")
        suffix = "-reg" if self.backbone_type.endswith("_reg") else ""
        return "DINOv2-" + Backbone.vit_name_long_to_short[base_type] + "-" + str(self.patch_size) + suffix
--------------------------------------------------------------------------------
/STEPP/DINO/dino_feature_extract.py:
--------------------------------------------------------------------------------
#
# Copyright (c) 2022-2024, ETH Zurich, Jonas Frey, Matias Mattamala.
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
from os.path import join
import torch.nn.functional as F
import torch
import torch.quantization as quant
from torchvision import transforms as T
from omegaconf import OmegaConf
import numpy as np
from pytictac import Timer

from STEPP.DINO.backbone import get_backbone


class DinoInterface:
    def __init__(
        self,
        device: str,
        backbone: str = "dinov2",  # the only backbone get_backbone() supports
        input_size: int = 448,
        backbone_type: str = "vit_small",
        patch_size: int = 14,  # DINOv2 models use 14x14 patches
        projection_type: str = None,  # nonlinear or None
        dropout_p: float = 0,  # dropout probability in [0, 1]; 0 disables dropout
        pretrained_weights: str = None,
        interpolate: bool = True,
        use_mixed_precision: bool = False,
        cfg: OmegaConf = OmegaConf.create({}),
    ):
        # Load config
        if cfg.is_empty():
            self._cfg = OmegaConf.create(
                {
                    "backbone": backbone,
                    "backbone_type": backbone_type,
                    "input_size": input_size,
                    "patch_size": patch_size,
                    "projection_type": projection_type,
                    "dropout_p": dropout_p,
                    "pretrained_weights": pretrained_weights,
                    "interpolate": interpolate,
                }
            )
        else:
            self._cfg = cfg

        # Initialize DINO
        self._model = get_backbone(self._cfg)

        # Send to device
        self._model.to(device)
        self._device = device

        # self._model = quant.quantize_dynamic(self._model, dtype=torch.qint8, inplace=True)
        self.use_mixed_precision = use_mixed_precision
        if self.use_mixed_precision:
            self._model = self._model.to(torch.float16)

        # Other
        normalization = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self._transform = T.Compose(
            [
                T.Resize(input_size, T.InterpolationMode.NEAREST),
                T.CenterCrop(input_size),
                # T.CenterCrop((input_size, 1582)),
                normalization,
            ]
        )

    def change_device(self, device):
        """Changes the device of all the class members

        Args:
            device (str): new device
        """
        self._model.to(device)
        self._device = device

    @torch.no_grad()
    def inference(self, img: torch.tensor):
        """Perform inference using DINO

        Args:
            img (torch.tensor, dtype=torch.float32, shape=(B,3,H,W)): Input image

        Returns:
            features (torch.tensor, dtype=torch.float32, shape=(B,D,H,W)): per-pixel D-dimensional features
        """

        # Resize image and normalize
        resized_img = self._transform(img).to(self._device)
        if self.use_mixed_precision:
            resized_img = resized_img.half()

        # Extract features
        features = self._model.forward(resized_img)
        # print('features shape before interpolation', features.shape)

        if self._cfg.interpolate:
            # Resize and interpolate the features back to the input image resolution
            B, C, H, W = img.shape
            new_features_size = (H, W)
            # pad = int((W - H) / 2)
            features = F.interpolate(features, new_features_size, mode="bilinear", align_corners=True)
            # print('features shape after interpolation', features.shape)
            # features = F.pad(features, pad=[pad, pad, 0, 0])

        return features.to(torch.float32)

    @property
    def input_size(self):
        return self._cfg.input_size

    @property
    def backbone(self):
        return self._cfg.backbone

    @property
    def backbone_type(self):
        return self._cfg.backbone_type

    @property
    def vit_patch_size(self):
        return self._cfg.patch_size


def get_dino_features(img, dino_size, interpolate):
    # Inference model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # #convert image to torch tensor
    # img = torch.from_numpy(img)
    img = img.to(device)
    # img = F.interpolate(img, scale_factor=0.25)

    # Settings
    size = 896
    model = dino_size
    patch = 14
    backbone = "dinov2"

    # Inference with DINO
    # Create DINO
    di = DinoInterface(
        device=device,
        backbone=backbone,
        input_size=size,
        backbone_type=model,
        patch_size=patch,
        interpolate=interpolate,
    )

    # with Timer(f"DINO, input_size, {di.input_size}, model, {di.backbone_type}, patch_size, {di.vit_patch_size}"):
    feat_dino = di.inference(img)
    # print(f"Feature shape after interpolation: {feat_dino.shape}")

    return feat_dino

def average_dino_feature_segment(features, segment_img, segments=None):
    # features is a torch tensor of shape [1, 384, 64, 64]

    averaged_features = []

    if segments is None:
        segments = np.unique(segment_img)

    # Loop through each segment
    for segment_id in segments:
        segment_pixels = segment_img.astype(np.uint16) == segment_id
        selected_features = features[:, :, segment_pixels]
        vector = selected_features.mean(dim=-1)
        averaged_features.append(vector)

    # Stack all vectors vertically to form an m by n tensor
    averaged_features_tensor = torch.cat(averaged_features, dim=0)

    return averaged_features_tensor

def average_dino_feature_segment_tensor(features, segment_img, segments=None):
    """Averages the features of each segment without a Python loop over the segments.

    If `segments` is None, one row is returned per segment ID present in segment_img, in ascending order;
    otherwise one row is returned per requested ID, in the given order.
    """
    features_flattened = features.permute(0, 2, 3, 1).flatten(0, -2)  # (bhw x n_features)
    segment_ids = segment_img.flatten().long()  # (bhw)
    index = segment_ids.unsqueeze(-1).repeat(1, features_flattened.shape[-1])  # (bhw x n_features)
    num_segments = int(segment_ids.max()) + 1  # adding +1 for the 0 ID.
    output = torch.zeros((num_segments, features_flattened.shape[-1]), device=features.device, dtype=features.dtype)
    segment_sums = output.scatter_reduce(0, index, features_flattened, reduce="sum")
    # Pixel count per segment ID (zero for IDs that do not occur in segment_img)
    segments_count = torch.bincount(segment_ids, minlength=num_segments)
    if segments is None:
        segments = torch.nonzero(segments_count > 0).flatten()
    segments = torch.as_tensor(segments, device=features.device).long()
    averaged_features_tensor = segment_sums[segments] / segments_count[segments].unsqueeze(-1)

    return averaged_features_tensor

def run_dino_interfacer():
    """Performs inference using DINOv2 ViT and stores the results as images."""

    from pytictac import Timer
    from STEPP.utils.misc import get_img_from_fig, load_test_image, make_results_folder, remove_axes
    import matplotlib.pyplot as plt

    # Suppress warnings
    import warnings
    warnings.filterwarnings("ignore")

    # Create test directory
    outpath = make_results_folder("test_dino_interfacer")

    # Inference model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    img = load_test_image().to(device)
    # img = F.interpolate(img, scale_factor=0.25)

    print('image shape before going to the model', img.shape)

    plot = False
    save_features = True

    # Settings
    size = 896
    model = "vit_small"
    patch = 14
    backbone = "dinov2"

    # Inference with DINO
    # Create DINO
    di = DinoInterface(
        device=device,
        backbone=backbone,
        input_size=size,
        backbone_type=model,
        patch_size=patch,
    )

    with Timer(f"DINO, input_size, {di.input_size}, model, {di.backbone_type}, patch_size, {di.vit_patch_size}"):
        feat_dino = di.inference(img)
        print(f"Feature shape after interpolation: {feat_dino.shape}")

    if save_features:
        for i in range(5):
            fig = plt.figure(frameon=False)
            fig.set_size_inches(2, 2)
            ax = plt.Axes(fig, [0.0, 0.0, 1.0,
                                1.0])
            ax.set_axis_off()
            fig.add_axes(ax)
            ax.imshow(feat_dino[0][i].cpu(), cmap=plt.colormaps.get("inferno"))

            # Store results to test directory
            out_img = get_img_from_fig(fig)
            out_img.save(
                join(
                    outpath,
                    f"forest_clean_dino_feat{i:02}_{di.input_size}_{di.backbone_type}_{di.vit_patch_size}.png",
                )
            )
            plt.close("all")

    if plot:
        # Plot result as in colab
        fig, ax = plt.subplots(10, 11, figsize=(1 * 11, 1 * 11))

        for i in range(10):
            for j in range(11):
                if i == 0 and j == 0:
                    continue

                elif (i == 0 and j != 0) or (i != 0 and j == 0):
                    ax[i][j].imshow(img.permute(0, 2, 3, 1)[0].cpu())
                    ax[i][j].set_title("Image")
                else:
                    n = (i - 1) * 10 + (j - 1)
                    # DinoInterface has no get_feature_dim(); use the channel count of the features
                    if n >= feat_dino.shape[1]:
                        break
                    ax[i][j].imshow(feat_dino[0][n].cpu(), cmap=plt.colormaps.get("inferno"))
                    ax[i][j].set_title("Features [0]")
        remove_axes(ax)
        plt.tight_layout()

        # Store results to test directory
        out_img = get_img_from_fig(fig)
        out_img.save(
            join(
                outpath,
                f"forest_clean_{di.backbone}_{di.input_size}_{di.backbone_type}_{di.vit_patch_size}.png",
            )
        )
        plt.close("all")


if __name__ == "__main__":
    run_dino_interfacer()
--------------------------------------------------------------------------------
/STEPP/SLIC/slic_segmentation.py:
--------------------------------------------------------------------------------
# File to run SLIC segmentation on an image

import cv2
import numpy as np
import matplotlib.pyplot as plt
from fast_slic import Slic
import time
import json
from collections import defaultdict
from torchvision import transforms as T
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
import os
from pytictac import Timer

class SLIC():
    def __init__(self, crop_x=30, crop_y=20, num_superpixels=400, compactness=15):
        if crop_x == 0 and crop_y == 0:
            self.crop = False
        else:
            self.crop = True
        self.crop_x = crop_x
        self.crop_y = crop_y
        self.num_superpixels = num_superpixels
        self.compactness = compactness
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.slic = Slic(num_components=self.num_superpixels, compactness=self.compactness)

    def Slic_segmentation_for_given_pixels(self, pixels, image):
        # Crop the image border if requested
        if self.crop:
            only_img = image[self.crop_y:-self.crop_y, self.crop_x:-self.crop_x]
        else:
            only_img = image
        # Convert BGR image to RGB
        image_rgb = cv2.cvtColor(only_img, cv2.COLOR_BGR2RGB)

        # Create Slic object
        slic = Slic(num_components=self.num_superpixels, compactness=self.compactness)

        # Perform segmentation
        segmented_image = slic.iterate(image_rgb)

        # pixels is a list of (x, y) tuples or a 2D array where each row is an (x, y) pair,
        # i.e. x is the column index and y is the row index
        pixels_array = np.array(pixels)
        cols = pixels_array[:, 0]
        rows = pixels_array[:, 1]

        # Use advanced indexing to get the segment value at each (row, column) position
        segment_values = segmented_image[rows, cols]

        # Create a dictionary to hold lists of pixel coordinates for each segment
        segment_dict = defaultdict(list)

        # Populate the dictionary with (row, column) pixel coordinates grouped by their segment
        for i in range(len(segment_values)):
            segment = segment_values[i]
            pixel = (rows[i], cols[i])
            segment_dict[segment].append(pixel)

        return segment_dict, segmented_image

    def Slic_segmentation_for_all_pixels(self, image):
        # Crop the image border if requested
        if self.crop:
            only_img = image[self.crop_y:-self.crop_y,
                             self.crop_x:-self.crop_x]
        else:
            only_img = image

        # Convert BGR image to RGB
        image_rgb = cv2.cvtColor(only_img, cv2.COLOR_BGR2RGB)

        # Perform segmentation with the Slic object created in __init__
        segmented_image = self.slic.iterate(image_rgb)

        # Get unique segment values
        unique_segments = np.unique(segmented_image)

        return unique_segments, segmented_image

    def Slic_segmentation_for_all_pixels_torch(self, image):
        # Crop the image border if requested
        if self.crop:
            only_img = image[self.crop_y:-self.crop_y, self.crop_x:-self.crop_x]
        else:
            only_img = image

        # Convert BGR image to RGB
        image_rgb = cv2.cvtColor(only_img, cv2.COLOR_BGR2RGB)

        # Perform segmentation with the Slic object created in __init__
        segmented_image = self.slic.iterate(image_rgb)

        # Put the segmentation onto the GPU (if available)
        segmented_image = torch.from_numpy(segmented_image).to(self.device)

        # Get unique segment values
        unique_segments = torch.unique(segmented_image)

        return unique_segments, segmented_image

    def make_masks_smaller_numpy(self, segment_values, segmented_image, wanted_size):
        # Convert NumPy array to PIL image
        segmented_image_pil = Image.fromarray(segmented_image.astype('uint16'), mode='I;16')

        # Resize the image while maintaining the pixel values
        resized_segmented_image_pil = segmented_image_pil.resize((wanted_size, wanted_size), Image.NEAREST)

        # Convert the resized PIL image back to a NumPy array
        resized_segmented_image = np.array(resized_segmented_image_pil).astype(np.uint16)

        new_segment_dict = defaultdict(list)

        # Iterate over each unique segment value
        for key in segment_values:
            # Find the coordinates where
            # the pixel value equals the key
            coordinates = np.where(resized_segmented_image == key)

            # Zip the coordinates to get (row, column) pairs and store them in the dictionary
            new_segment_dict[key].extend(zip(coordinates[0], coordinates[1]))

        return resized_segmented_image, new_segment_dict

    def make_masks_smaller_torch(self, segment_values, segmented_image, wanted_size, return_dict=True):

        segmented_image = segmented_image.unsqueeze(0).unsqueeze(0).float()
        # Resize the image while maintaining the pixel values
        resized_segmented_image = F.interpolate(
            segmented_image,
            size=(wanted_size, wanted_size),
            mode='nearest')

        # Get rid of the first and second dimension
        resized_segmented_image = resized_segmented_image.squeeze(0).squeeze(0)

        new_segment_dict = defaultdict(list)
        if return_dict:
            # Iterate over each unique segment value
            with Timer("loop"):
                for key in segment_values:
                    # Find the coordinates where the pixel value equals the key
                    coordinates = torch.where(resized_segmented_image == key)

                    # Zip the coordinates to get (row, column) pairs and store them in the dictionary
                    new_segment_dict[key].extend(zip(coordinates[0].tolist(), coordinates[1].tolist()))
            print(f"looped {len(segment_values)} times")

        return resized_segmented_image, new_segment_dict

def get_difference_pixels(img1, img2):
    # Compute the absolute difference
    difference = cv2.absdiff(img1, img2)

    # Threshold the difference to find the significant changes
    _, thresholded_difference = cv2.threshold(difference, 25, 255, cv2.THRESH_BINARY)

    # Convert to grayscale
    gray_diff = cv2.cvtColor(thresholded_difference, cv2.COLOR_BGR2GRAY)

    # Find contours in the thresholded difference
    contours, _ = cv2.findContours(gray_diff, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Find the largest contour by area
    largest_contour = max(contours, key=cv2.contourArea)

    flattened_list = [item[0].tolist() for item in largest_contour]

    return flattened_list

def run_SLIC_segmentation():
    """Run SLIC on an image and visualize the segmented image"""

    ##############################################
    # This should all be coming from a config file
    ##############################################
    img_width = 1408
    img_height = 1408
    x_boarder = 200
    y_boarder = 200
    number = 10
    # pixels = path[number]
    # img_path = images[number]
    img_path = 'path_to_test_image'
    print('img_path:', img_path)
    # ##############################################

    # #plot the image with the pixels
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Crop the image to remove the border
    img = img[y_boarder:-y_boarder, x_boarder:-x_boarder]
    plt.figure(figsize=(10, 10))
    plt.imshow(img)#, cmap='inferno')
    plt.axis('off')
    # plt.show()

    # def overlay_images_1(n1_path, n2_path):
    #     n1_image = cv2.imread(n1_path)
    #     n2_image = cv2.imread(n2_path)
    #     # n2_image[..., 3] = 1

    #     mask = n2_image != 0

    #     # Create an output image with all black pixels
    #     output_image = np.zeros_like(n1_image)

    #     # Apply the mask to n1_image and store the result in output_image
    #     output_image[mask] = n1_image[mask]

    #     output_image[0:520] = 0

    #     #create a list of pixel coord pairs where the image is not black
    #     pixels = []
    #     non_black_pixels = np.argwhere(np.any(output_image != 0, axis=-1))
    #     pixels = non_black_pixels[:, ::-1].tolist()

    #     return output_image, pixels

    # def overlay_images_2(n1_path, n2_path):
    #     n1_image = cv2.imread(n1_path)
    #     n2_image = cv2.imread(n2_path)
    #     # n2_image[..., 3] = 1

    #     mask = n2_image != 0

    #     # Create an output image with all black pixels
    #     output_image = np.zeros_like(n1_image)

    #     # Apply the mask to n1_image and store the result in output_image
    #     output_image[mask] = n1_image[mask]

    #     output_image[0:520] = 0

    #     #create a list of pixel coord pairs where the image is not black
    #     pixels = []
    #     for i in range(output_image.shape[0]):
    #         for j in range(output_image.shape[1]):
    #             if np.any(output_image[i, j] != 0):
    #                 pixels.append([j, i])

    #     return output_image, pixels


    # with Timer('overlay_images'):
    #     output_img_1, pixels_1 = overlay_images_1(img_path, 'path_to_test_image')

    # with Timer('overlay_images'):
    #     output_img_2, pixels_2 = overlay_images_2(img_path, 'path_to_test_image')

    # print(pixels_1 == pixels_2)
    # print(pixels_1[:10])

    # exit()

    # output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
    # plt.figure(figsize=(10, 10))
    # plt.imshow(output_img)#, cmap='inferno')
    # plt.axis('off')

    #remove entries that contain values of larger than 720-20 and 1280-30
    # pixels = [pixel for pixel in pixels if pixel[0] < (img_width-y_boarder) and pixel[1] < (img_height-x_boarder)]
    # #also take off 20 from the x and 30 from the y
    # pixels = [(pixel[0] - x_boarder, pixel[1] - y_boarder) for pixel in pixels]

    slic = SLIC(crop_x=0, crop_y=0, num_superpixels=100, compactness=10)
    seg, seg_img = slic.Slic_segmentation_for_all_pixels(img)
    # segments, segmented_image = slic.Slic_segmentation_for_given_pixels(pixels, img)


    print('number of unique values in segmented image:', len(np.unique(seg_img)))
    print(seg)
    segmented_image_mask = seg_img

    # Make the values in each segment of seg_img random:
    # for value in seg:
    #     random_value = np.random.randint(0, 255)  # Generate a single random value for the current segment
    #     seg_img = np.where(seg_img == value, random_value, seg_img)  # seg_img = np.random.randint(0, len(np.unique(seg_img)), (seg_img.shape[0], seg_img.shape[1]))

    # Remap from a copy of the original labels: remapping seg_img in place would merge a segment whose
    # ID equals an already-assigned random value with the segment that received that value
    original_seg_img = seg_img.copy()
    unique_values = set()  # To keep track of the unique random values assigned
    for value in seg:
        random_value = np.random.randint(0, 255)

        # Ensure the random_value hasn't already been used
        while random_value in unique_values:
            random_value = np.random.randint(0, 255)  # Generate a new random value if a collision occurs

        # Assign the unique random value and record it
        seg_img = np.where(original_seg_img == value, random_value, seg_img)
        unique_values.add(random_value)  # Add to set of used values
    print(len(unique_values))
    pixel_list = [[420, 973],
                  [484, 833],
                  [475, 745],
                  [550, 778],
                  [520, 717],
                  [585, 678],
                  [683, 632],
                  [610, 610],
                  [660, 668],
                  [475, 1000]]
    values = []
    for pixels in pixel_list:
        # Each entry is an (x, y) = (column, row) pair
        val = seg_img[(pixels[1], pixels[0])]
        print('val:', val)
        values.append(val)

    segmented_image_mask = np.where(np.isin(seg_img, values), seg_img, 0)
    segmented_image_mask_expanded = np.expand_dims(segmented_image_mask, axis=-1)  # Adds a third dimension

    # Now segmented_image_mask_expanded will have shape (1008, 1008, 1)
    # Use np.where to compare and select values
    seg_img_path = np.where(segmented_image_mask_expanded != 0, img, 255)
    # for pixel in pixels:
    #     # point = (pixel[1], pixel[0])
    #     val = segmented_image[(pixel[1], pixel[0])]
    #     segmented_image_mask = np.where(segmented_image == val, 0, segmented_image_mask)

    # Optionally, visualize the segmented image
    plt.figure(figsize=(10, 10))
    plt.imshow(seg_img)#, cmap='inferno')
    plt.axis('off')

plt.figure(figsize=(10, 10)) 337 | plt.imshow(segmented_image_mask)#, cmap='inferno') 338 | plt.axis('off') 339 | 340 | plt.figure(figsize=(10, 10)) 341 | plt.imshow(seg_img_path)#, cmap='inferno') 342 | plt.axis('off') 343 | 344 | # resized_segmented_image, new_segment_dict = slic.make_masks_smaller(segments.keys(), segmented_image, 64) 345 | 346 | # print('new_segment_dict:', new_segment_dict) 347 | 348 | # print('number of unique values in resized segmented image:', len(np.unique(resized_segmented_image))) 349 | 350 | # resized_segmented_image_mask = resized_segmented_image 351 | 352 | # for key in new_segment_dict.keys(): 353 | # resized_segmented_image_mask = np.where(resized_segmented_image == float(key), 0, resized_segmented_image_mask) 354 | 355 | # plt.figure(figsize=(10, 10)) 356 | # plt.imshow(resized_segmented_image)#, cmap='inferno') 357 | # plt.axis('off') 358 | 359 | # # Optionally, visualize the resized segmented image 360 | # plt.figure(figsize=(10, 10)) 361 | # plt.imshow(resized_segmented_image_mask)#, cmap='inferno') 362 | # plt.axis('off') 363 | plt.show() 364 | 365 | 366 | if __name__ == "__main__": 367 | run_SLIC_segmentation() -------------------------------------------------------------------------------- /STEPP/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 5 | """Absolute path to the STEPP repository (the directory containing the STEPP package).""" 6 | # Add the directory containing STEPP to the Python path 7 | sys.path.append(ROOT_DIR) -------------------------------------------------------------------------------- /STEPP/model/mlp.py: -------------------------------------------------------------------------------- 1 | # Feature reconstruction networks: an MLP autoencoder and a VAE 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader,
TensorDataset 9 | # from sklearn.model_selection import train_test_split 10 | # from sklearn.datasets import make_classification 11 | # from sklearn.preprocessing import StandardScaler 12 | # from sklearn.metrics import accuracy_score 13 | import matplotlib.pyplot as plt 14 | import os 15 | import json 16 | import argparse 17 | 18 | #MLP encoder decoder architecture 19 | class ReconstructMLP(nn.Module): 20 | def __init__(self, input_dim, hidden_dim): 21 | super(ReconstructMLP, self).__init__() 22 | output_dim = input_dim 23 | layers = [] 24 | for hd in hidden_dim[:]: 25 | layers.append(nn.Linear(input_dim, hd)) 26 | layers.append(nn.ReLU()) 27 | input_dim = hd 28 | layers.append(nn.Linear(input_dim, output_dim)) 29 | 30 | self.model = nn.Sequential(*layers) 31 | 32 | def forward(self, x): 33 | return self.model(x) 34 | 35 | # input_dim = 384 36 | # hidden layer dim = [input dim, 256, 64, 32, 16, 32, 64, 256, input_dim] 37 | 38 | 39 | #VAE encoder decoder architecture 40 | class ReconstructVAE (nn.Module): 41 | def __init__(self, input_dim, hidden_dim, latent_dim): 42 | super(ReconstructVAE, self).__init__() 43 | self.encoder = nn.Sequential( 44 | nn.Linear(input_dim, hidden_dim[0]), 45 | nn.ReLU(), 46 | nn.Linear(hidden_dim[0], hidden_dim[1]), 47 | nn.ReLU(), 48 | nn.Linear(hidden_dim[1], hidden_dim[2]), 49 | nn.ReLU(), 50 | nn.Linear(hidden_dim[2], hidden_dim[3]), 51 | nn.ReLU(), 52 | nn.Linear(hidden_dim[3], hidden_dim[4]), 53 | nn.ReLU(), 54 | nn.Linear(hidden_dim[4], hidden_dim[5]), 55 | nn.ReLU(), 56 | nn.Linear(hidden_dim[5], hidden_dim[6]), 57 | nn.ReLU(), 58 | nn.Linear(hidden_dim[6], latent_dim * 2) 59 | ) 60 | 61 | self.decoder = nn.Sequential( 62 | nn.Linear(latent_dim, hidden_dim[6]), 63 | nn.ReLU(), 64 | nn.Linear(hidden_dim[6], hidden_dim[5]), 65 | nn.ReLU(), 66 | nn.Linear(hidden_dim[5], hidden_dim[4]), 67 | nn.ReLU(), 68 | nn.Linear(hidden_dim[4], hidden_dim[3]), 69 | nn.ReLU(), 70 | nn.Linear(hidden_dim[3], hidden_dim[2]), 71 | nn.ReLU(), 72 | 
nn.Linear(hidden_dim[2], hidden_dim[1]), 73 | nn.ReLU(), 74 | nn.Linear(hidden_dim[1], hidden_dim[0]), 75 | nn.ReLU(), 76 | nn.Linear(hidden_dim[0], input_dim) 77 | ) 78 | 79 | def forward(self, x): 80 | mu, log_var = torch.chunk(self.encoder(x), 2, dim=1) 81 | z = self.reparameterize(mu, log_var) 82 | return self.decoder(z), mu, log_var 83 | 84 | def reparameterize(self, mu, log_var): 85 | std = torch.exp(0.5 * log_var) 86 | eps = torch.randn_like(std) 87 | return mu + eps * std -------------------------------------------------------------------------------- /STEPP/model/training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | from STEPP.model.mlp import ReconstructMLP 6 | import numpy as np 7 | from STEPP.utils.misc import make_results_folder 8 | from STEPP.utils.testing import test_feature_reconstructor_with_model 9 | import time 10 | import wandb 11 | import sys 12 | import os 13 | 14 | 15 | # Data loader 16 | class FeatureDataset: 17 | def __init__(self, feature_dir, stack=False, transform=None, target_transform=None, batch_size=None) -> None: 18 | self.feature_dir = feature_dir 19 | self.transform = transform 20 | self.batch_size = batch_size 21 | self.target_transform = target_transform 22 | 23 | if stack: 24 | # Load every .npy file under feature_dir and stack them into one big array 25 | arrays = [] 26 | for root, dirs, files in os.walk(self.feature_dir): 27 | for file in sorted(files): 28 | if file.endswith('.npy'): 29 | arrays.append(np.load(os.path.join(root, file)).astype(np.float32)) 30 | # Concatenate once instead of re-allocating the whole array for every file 31 | self.avg_features = np.concatenate(arrays, axis=0) 32 | print(self.avg_features.shape) 33 | 34 | # Drop feature rows that contain NaNs 35 | self.avg_features =
self.avg_features[~np.isnan(self.avg_features).any(axis=1)] 36 | else: 37 | self.avg_features = np.load(self.feature_dir).astype(np.float32) 38 | self.avg_features = self.avg_features[~np.isnan(self.avg_features).any(axis=1)] 39 | 40 | def __len__(self) -> int: 41 | return len(self.avg_features) 42 | 43 | def __getitem__(self, idx: int): 44 | if self.batch_size: 45 | feature = self.avg_features[idx:idx+self.batch_size] 46 | print(feature.shape) 47 | else: 48 | feature = self.avg_features[idx] 49 | if self.transform: 50 | feature = self.transform(feature) 51 | if self.target_transform: 52 | feature = self.target_transform(feature) 53 | return feature 54 | 55 | class EarlyStopping: 56 | def __init__(self, patience=20, verbose=False, delta=0): 57 | self.patience = patience 58 | self.verbose = verbose 59 | self.delta = delta 60 | self.counter = 0 61 | self.best_score = None 62 | self.early_stop = False 63 | self.val_loss_min = float('inf') 64 | self.training_start_time = time.strftime("%Y%m%d-%H%M") 65 | 66 | def __call__(self, val_loss, model): 67 | score = -val_loss 68 | 69 | if self.best_score is None: 70 | self.best_score = score 71 | self.save_checkpoint(val_loss, model) 72 | elif score < self.best_score + self.delta: 73 | self.counter += 1 74 | if self.verbose: 75 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 76 | if self.counter >= self.patience: 77 | self.early_stop = True 78 | else: 79 | self.best_score = score 80 | self.save_checkpoint(val_loss, model) 81 | self.counter = 0 82 | 83 | def save_checkpoint(self, val_loss, model): 84 | '''Saves the model when the validation loss decreases.''' 85 | results_folder = make_results_folder('trained_model') 86 | if self.verbose: 87 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') 88 | torch.save(model.state_dict(), results_folder + f'/all_ViT_small_ump_input_700_small_nn_checkpoint_{self.training_start_time}.pth') 89 | self.val_loss_min = val_loss 90 | 91 | 92 | class TrainFeatureReconstructor(): 93 | 94 | def __init__(self, path, batch_size=32, epochs=1, learning_rate=1e-3): 95 | self.device = ( 96 | "cuda" 97 | if torch.cuda.is_available() 98 | else "cpu" 99 | ) 100 | print(f"Using {self.device} device") 101 | self.input_dim = 384 102 | # self.hidden_dim = [256, 128, 64, 32, 64, 128, 256] #big nn 103 | self.hidden_dim = [256, 64, 32, 16, 32, 64, 256] # small nn 104 | # self.hidden_dim = [1024, 512, 256, 64, 32, 16, 32, 64, 256, 512, 1024] #huge nn 105 | # self.hidden_dim = [256, 32] # wvn nn 106 | self.batch_size = batch_size 107 | self.epochs = epochs 108 | self.learning_rate = learning_rate 109 | self.data_path = path 110 | self.stack = True 111 | self.early_stopping = EarlyStopping(patience=10, verbose=True) 112 | 113 | # Training loop 114 | def train_loop(self, train_dataloader, loss_fn, optimizer): 115 | for epoch in range(self.epochs): 116 | # test_loop() switches the model to eval mode, so switch back to training mode every epoch 117 | self.model.train() 118 | running_loss = 0.0 119 | for data in train_dataloader: 120 | inputs = targets = data.to(self.device) 121 | 122 | # Zero the parameter gradients 123 | optimizer.zero_grad() 124 | 125 | # Forward pass 126 | outputs = self.model(inputs) 127 | 128 | loss = loss_fn(outputs, targets) 129 | 130 | # Backward pass and optimize 131 | loss.backward() 132 | optimizer.step() 133 | 134 | 135 | # Accumulate the loss for the epoch average 136 | running_loss += loss.item() 137 | 138 | epoch_loss = running_loss / len(train_dataloader) 139 | print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {epoch_loss:.4f}") 140 | 141 | meta = {'epoch': epoch, 'loss': epoch_loss} 142 | if (epoch+1) % 10 == 0: 143 | test_dict = self.test_loop(self.model, loss_fn) 144 | meta.update(test_dict) 145 | self.early_stopping(test_dict["test_loss"], self.model) 146 | wandb.log(meta) 147 | if
self.early_stopping.early_stop: 148 | print("Early stopping") 149 | break # Leave the epoch loop so that main() still runs the final test_loop 150 | 151 | 152 | print('Finished Training') 153 | 154 | def test_loop(self, model, loss_fn): 155 | # Set the model to evaluation mode 156 | model.eval() 157 | dataloader = self.test_dataloader 158 | num_batches = len(dataloader) 159 | test_loss = 0 160 | 161 | # Ensure no gradients are computed during test mode 162 | with torch.no_grad(): 163 | for X in dataloader: 164 | X = X.to(self.device) 165 | # Forward pass: compute the model output 166 | recon_X = model(X) 167 | # Compute the loss 168 | test_loss += loss_fn(recon_X, X).item() 169 | 170 | # Compute the average loss over all batches 171 | test_loss /= num_batches 172 | print(f"Test Error: \n Avg MSE Loss: {test_loss:>8f} \n") 173 | 174 | # Test on one validation image (test_feature_reconstructor_with_model in testing.py also takes a `thresh` argument) 175 | mode = 'segment_wise' 176 | test_image_path = 'path_to_test_image' 177 | figure = test_feature_reconstructor_with_model(mode,self.model, test_image_path) 178 | return dict(test_loss=test_loss, test_plot=figure) 179 | 180 | def data_split(self, dataset, train_split=0.8): 181 | train_size = int(train_split * len(dataset)) 182 | test_size = len(dataset) - train_size 183 | # training_data = dataset[:1000] 184 | # test_data = training_data 185 | training_data, test_data = torch.utils.data.random_split(dataset, [train_size, test_size]) 186 | 187 | train_dataloader = DataLoader(training_data, batch_size=self.batch_size, shuffle=True) 188 | test_dataloader = DataLoader(test_data, batch_size=self.batch_size, shuffle=False) 189 | 190 | return train_dataloader, test_dataloader 191 | 192 | def main(self): 193 | 194 | # Creating DataLoader 195 | dataset = FeatureDataset(self.data_path, self.stack) 196 | 197 | # Model instantiation 198 | self.model = ReconstructMLP(self.input_dim, self.hidden_dim).to(self.device) 199 | print(self.model) 200 | 201 | # Loss function and optimizer 202 | loss_fn = nn.MSELoss() 203 | optimizer = optim.Adam(self.model.parameters(), 
lr=self.learning_rate) 204 | 205 | # Splitting the data 206 | train_dataloader, self.test_dataloader = self.data_split(dataset) 207 | 208 | # Training the model 209 | self.train_loop(train_dataloader, loss_fn, optimizer) 210 | 211 | # Testing the model 212 | self.test_loop(self.model, loss_fn) 213 | 214 | 215 | if __name__ == '__main__': 216 | wandb.init(project='STEPP') 217 | 218 | path_to_features = 'path_to_features' # Placeholder: folder of .npy feature files (loaded with stack=True) 219 | TrainFeatureReconstructor(path_to_features, epochs=1000000).main() # Very large epoch budget: training effectively runs until EarlyStopping triggers -------------------------------------------------------------------------------- /STEPP/utils/colorbar.py: -------------------------------------------------------------------------------- 1 | # import matplotlib.pyplot as plt 2 | # import numpy as np 3 | # from matplotlib.colors import LinearSegmentedColormap, Normalize 4 | # from matplotlib.colorbar import ColorbarBase 5 | 6 | # # Your custom colormap stretching 7 | # s = 0.3 8 | # original_cmap = plt.cm.get_cmap("RdYlGn", 5000) 9 | # new_colors = np.vstack([ 10 | # original_cmap(np.linspace(0, s, 2500)), 11 | # original_cmap(np.linspace(1 - s, 1.0, 2500)) 12 | # ]) 13 | # new_cmap = LinearSegmentedColormap.from_list("stretched_RdYlBu", new_colors[::-1]) 14 | 15 | # fig, ax = plt.subplots(figsize=(2, 6)) # Size this appropriately to your needs 16 | 17 | # # Normalize the colormap 18 | # norm = Normalize(vmin=0, vmax=1) 19 | 20 | # # Create the colorbar 21 | # cbar = ColorbarBase(ax, cmap='hsv', norm=norm, orientation='vertical') 22 | # cbar.set_label('Predicted Traversability') # Label according to what the colors represent 23 | 24 | # plt.show() 25 | 26 | 27 | import matplotlib.pyplot as plt 28 | import numpy as np 29 | from matplotlib.colors import LinearSegmentedColormap, Normalize 30 | from matplotlib.colorbar import ColorbarBase 31 | 32 | # Create a segment of the 'hsv' colormap (red -> yellow -> green for 0..0.3) 33 | original_cmap = plt.get_cmap('hsv') # plt.cm.get_cmap is deprecated since Matplotlib 3.7 34 | segment = np.linspace(0, 0.3, 256) # Adjust 256 for smoother or coarser color transitions 35 | 
colors = original_cmap(segment) 36 | 37 | # Create a new colormap from this segment 38 | new_cmap = LinearSegmentedColormap.from_list('red_to_green', colors) 39 | 40 | # Setup figure and axes for the color bar 41 | fig, ax = plt.subplots(figsize=(1, 10)) # Adjust figure size as needed 42 | 43 | # Normalize the colormap 44 | norm = Normalize(vmin=0, vmax=1) 45 | 46 | # Create the color bar using the new colormap 47 | cbar = ColorbarBase(ax, cmap=new_cmap, norm=norm, orientation='vertical') 48 | cbar.set_label('Value Range') 49 | 50 | plt.show() 51 | -------------------------------------------------------------------------------- /STEPP/utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim import Adam 5 | from torch.utils.data import Dataset 6 | import numpy as np 7 | 8 | class FeatureDataset: 9 | def __init__(self, feature_dir, transform=None, target_transform=None, batch_size=None) -> None: 10 | self.feature_dir = feature_dir 11 | self.transform = transform 12 | self.batch_size = batch_size 13 | self.target_transform = target_transform 14 | self.avg_features = np.load(self.feature_dir) 15 | 16 | def __len__(self) -> int: 17 | return len(self.avg_features) 18 | 19 | def __getitem__(self, idx: int): 20 | 21 | if self.batch_size: 22 | feature = self.avg_features[idx:idx+self.batch_size] 23 | print(feature.shape) 24 | else: 25 | feature = self.avg_features[idx] 26 | if self.transform: 27 | feature = self.transform(feature) 28 | if self.target_transform: 29 | feature = self.target_transform(feature) 30 | return feature -------------------------------------------------------------------------------- /STEPP/utils/extract_future_poses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import cv2 4 | import numpy as np 5 | from camera import Camera 6 | from 
scipy.spatial.transform import Rotation as R 7 | import rospy 8 | from nav_msgs.msg import Odometry as odom 9 | import os 10 | import matplotlib.pyplot as plt 11 | import json 12 | 13 | 14 | class CameraPinhole(Camera): 15 | def __init__(self, width, height, camera_name, distortion_model, K, D, Rect, P): 16 | super().__init__(width, height, camera_name, distortion_model, K, D, Rect, P) 17 | 18 | def undistort(self, image): 19 | undistorted_image = cv2.undistort(image, self.K, self.D) 20 | return undistorted_image 21 | 22 | def main(): 23 | """Project future odometry poses into each image and save the resulting path pixels as JSON.""" 24 | # Create a pinhole camera model 25 | D = np.array([-0.28685832023620605, -2.0772109031677246, 0.0005875344504602253, -0.0005043392884545028, 1.5214914083480835, -0.39617425203323364, -1.8762085437774658, 1.4227665662765503]) 26 | K = np.array([607.9638061523438, 0.0, 638.83984375, 0.0, 607.9390869140625, 367.0916748046875, 0.0, 0.0, 1.0]).reshape(3, 3) 27 | Rect = np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).reshape(3, 3) 28 | P = np.array([607.9638061523438, 0.0, 638.83984375, 0.0, 0.0, 607.9390869140625, 367.0916748046875, 0.0, 0.0, 0.0, 1.0, 0.0]).reshape(3, 4) 29 | camera_pinhole = CameraPinhole(width=1280, height=720, camera_name='kinect_camera', 30 | distortion_model='rational_polynomial', 31 | K=K, D=D, Rect=Rect, P=P) 32 | 33 | # Initialize lists to store coordinates, orientations and odometry messages 34 | coordinates = [] 35 | orientations = [] 36 | directions = [] 37 | 38 | folder_path = 'path_to_image_folder' 39 | images = sorted([os.path.join(folder_path, img) for img in os.listdir(folder_path) if img.endswith((".png", ".jpg", ".jpeg"))]) 40 | img_file_names = [os.path.basename(img) for img in images] 41 | 42 | # Initialize the ROS node 43 | rospy.init_node('trajectory_publisher', anonymous=True) 44 | # Initialize the publishers 45 | pub = rospy.Publisher('/trajectory', odom, queue_size=10) 46 | pub2 = rospy.Publisher('/trajectory2', odom, queue_size=10) 47 | 48 | # Load the
coordinates and orientations 49 | coordinates_path = 'path_to_txt_file_containing_odometry_data' 50 | 51 | T_odom_list = [] 52 | with open(coordinates_path, 'r') as file: 53 | for line in file: 54 | if line.startswith('#'): 55 | continue # Skip comment lines 56 | parts = line.split() 57 | if parts: 58 | coordinates.append(np.array([float(parts[1]), float(parts[2]), float(parts[3])])) 59 | # print(coordinates[-1]) 60 | orientations.append(np.array([float(parts[4]), float(parts[5]), float(parts[6]), float(parts[7])])) #qx, qy, qz, qw 61 | # print(orientations[-1]) 62 | 63 | T_odom = np.eye(4, 4) 64 | T_odom[:3, :3] = R.from_quat(orientations[-1]).as_matrix()[:3, :3] 65 | T_odom[:3, 3] = coordinates[-1] 66 | T_odom_list.append(T_odom) 67 | 68 | #difference between odometry frame and camera frame 69 | translation = [-0.739, -0.056, -0.205] #x, y, z 70 | path_translation = [0.0, 1.0, 0.0] #x, y, z 71 | rotation = [0.466, -0.469, -0.533, 0.528] #quaternion 72 | T_imu_camera = np.eye(4, 4) 73 | T_imu_camera[:3, :3] = R.from_quat(rotation).as_matrix()[:3, :3] 74 | T_imu_camera[:3, 3] = translation 75 | 76 | # rotation = [-0.469, -0.533, 0.528, 0.466] #quaternion 77 | 78 | for i in range(len(coordinates)): 79 | T_world_camera = np.linalg.inv(T_imu_camera) @ T_odom_list[i] @ T_imu_camera 80 | 81 | coordinates[i] = T_world_camera[:3, 3] 82 | orientations[i] = R.from_matrix(T_world_camera[:3, :3]).as_quat() 83 | 84 | #create a list of odometry messages from the coord and orientation lists 85 | for i in range(len(coordinates)): 86 | # Create a new odometry message 87 | odom_msg = odom() 88 | # Set the header 89 | odom_msg.header.stamp = rospy.Time.now() 90 | odom_msg.header.frame_id = "odom" 91 | 92 | # Set the position 93 | odom_msg.pose.pose.position.x = coordinates[i][0] 94 | odom_msg.pose.pose.position.y = coordinates[i][1] 95 | odom_msg.pose.pose.position.z = coordinates[i][2] 96 | # Set the orientation 97 | odom_msg.pose.pose.orientation.x = orientations[i][0] 98 | 
odom_msg.pose.pose.orientation.y = orientations[i][1] 99 | odom_msg.pose.pose.orientation.z = orientations[i][2] 100 | odom_msg.pose.pose.orientation.w = orientations[i][3] 101 | # Append the message to the list 102 | directions.append(odom_msg) 103 | 104 | # publish data to ros topic 105 | # for i in range(len(coordinates)): 106 | # # Publish the message 107 | # pub.publish(directions[i]) 108 | # # Sleep for 0.01 seconds 109 | # rospy.sleep(0.01) 110 | # print(f"Published message {i+1}/{len(coordinates)}", end='\r') 111 | 112 | # if i == point: 113 | # pub2.publish(directions[i]) 114 | 115 | def unit_vector(vector): 116 | magnitude = np.linalg.norm(vector) 117 | if magnitude == 0: 118 | return vector 119 | return vector / magnitude 120 | 121 | def transform_coord(quat, coord): 122 | R1 = R.from_quat(quat).as_matrix() 123 | # R1 is a rotation matrix, so its transpose is its inverse 124 | return R1.T @ coord 125 | 126 | def translate_to_frame(coords, point, quat): 127 | # Offsets from each future position coords[i] to `point` (minus path_translation), rotated into the frame given by `quat` 128 | New_frame_coord = [] 129 | for i in range(1, len(coords)): 130 | c = transform_coord(quat, point - coords[i] - path_translation) 131 | New_frame_coord.append(c) 132 | # print('\n c:',c) 133 | return np.array(New_frame_coord) 134 | 135 | # #create cv2 window 136 | # cv2.namedWindow('image', cv2.WINDOW_NORMAL) 137 | # cv2.resizeWindow('image', 1280, 720) 138 | 139 | u_C2_past = np.zeros((2, 1)) 140 | save_flag = True 141 | img_points = [] 142 | 143 | future_steps = 50 144 | all_points = [] 145 | 146 | # exit() 147 | print('length of coordinates:', len(coordinates)) 148 | 149 | for i in range(1, len(coordinates)- future_steps): 150 | point = i 151 | points = translate_to_frame(coordinates[point:], coordinates[point], orientations[point]) 152 | p_C2 = points.T 153 | # Project a 3D point into the pixel plane 154 | # Load the image belonging to this pose (images holds the full, sorted paths) 155 | img = cv2.imread(images[point]) 156 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 157 | 158 | 
img_points = [] 159 | # Start the drawn path at the bottom centre of the image 160 | u_C2_past[0] = camera_pinhole.width / 2 161 | u_C2_past[1] = camera_pinhole.height 162 | 163 | for j in range(1, future_steps): 164 | p_C = p_C2[:, j] 165 | _, tmp_p = camera_pinhole.project(p_C) 166 | tmp_p = tmp_p.reshape(2, 1) 167 | u_C1 = tmp_p[:, 0] 168 | tmp_p, _ = cv2.projectPoints(p_C.reshape(1, 1, 3), np.zeros((3, 1)), np.zeros((3, 1)), K, D) 169 | u_C2 = tmp_p[0, 0, :2] 170 | if u_C2[0] < camera_pinhole.width and u_C2[0] > 30 and u_C2[1] < camera_pinhole.height-20 and u_C2[1] > 0: 171 | 172 | # set points to be drawn on the image 173 | cv2.circle(img, (int(u_C2[0]), int(u_C2[1])), 5, (0, 0, 255), -1) 174 | cv2.line(img, (int(u_C2_past[0]), int(u_C2_past[1])), (int(u_C2[0]), int(u_C2[1])), (255 - j*(255/future_steps), j*(255/future_steps), 0), 2) # BGR colour fades from blue (near) to green (far) 175 | 176 | # Record the projected pixel 177 | img_points.append([int(u_C2[0]),int(u_C2[1])]) 178 | 179 | u_C2_past = u_C2 180 | # print(img_points) 181 | # Store this image's path pixels (one list per image) 182 | 183 | all_points.append(img_points) 184 | 185 | # Display the image with the points drawn on it in the cv2 window 186 | cv2.imshow('image', img) 187 | cv2.waitKey(0) # Blocks until a key is pressed 188 | 189 | print(f"Point {point}/{len(coordinates)}", end='\r') 190 | 191 | # Number of images processed 192 | # print(all_points[-7:]) 193 | print(len(all_points)) 194 | 195 | # Save all_points as JSON: one list of [u, v] pixel coordinates per image 196 | with open('OPS_grass_pixels.json', 'w') as f: 197 | json.dump(all_points, f) 198 | 199 | 200 | if __name__ == '__main__': 201 | main() -------------------------------------------------------------------------------- /STEPP/utils/image_saver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import rospy 4 | from sensor_msgs.msg import Image, CompressedImage 5 | from cv_bridge import CvBridge, CvBridgeError 6 | import cv2 7 | import os 8 | 9 | class ImageSaver: 10 | def __init__(self, image_topic, save_directory): 11 | # Initialize the ROS node 12 | 
rospy.init_node('image_saver', anonymous=True) 13 | print('Node initialized') 14 | 15 | # Create a CvBridge object 16 | self.bridge = CvBridge() 17 | 18 | # Directory to save images 19 | self.save_directory = save_directory 20 | if not os.path.exists(self.save_directory): 21 | os.makedirs(self.save_directory) 22 | 23 | # Counter for naming images 24 | self.image_counter = 0 25 | 26 | # Subscribe last, so the callback cannot run before the attributes above exist 27 | self.image_sub = rospy.Subscriber(image_topic, CompressedImage, self.image_callback) 28 | 29 | def image_callback(self, msg): 30 | try: 31 | # Convert the ROS CompressedImage message to an OpenCV BGR image 32 | cv_image = self.bridge.compressed_imgmsg_to_cv2(msg, "bgr8") 33 | 34 | # Create a filename for each image 35 | filename = os.path.join(self.save_directory, "image_{:06d}.png".format(self.image_counter)) 36 | 37 | # Save the image to the specified directory 38 | cv2.imwrite(filename, cv_image) 39 | rospy.loginfo("Saved image: {}".format(filename)) 40 | 41 | # Increment the counter 42 | self.image_counter += 1 43 | 44 | except CvBridgeError as e: 45 | rospy.logerr("CvBridge Error: {}".format(e)) 46 | 47 | if __name__ == '__main__': 48 | try: 49 | # Parameters 50 | image_topic = "/rgb/image_rect_color/compressed" # Set the image topic 51 | save_directory = "path_to_save_folder" # Set the directory to save images 52 | 53 | # Create the ImageSaver object 54 | image_saver = ImageSaver(image_topic, save_directory) 55 | 56 | # Keep the node running 57 | rospy.spin() 58 | except rospy.ROSInterruptException: 59 | pass 60 | -------------------------------------------------------------------------------- /STEPP/utils/make_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import json 5 | import torch 6 | import torch.nn.functional as F 7 | import os 8 | from pytictac import Timer 9 | import warnings 10 | import argparse 11 | 12 | from
STEPP import ROOT_DIR 13 | from STEPP.DINO import run_dino_interfacer 14 | from STEPP.DINO.dino_feature_extract import DinoInterface 15 | from STEPP.SLIC.slic_segmentation import SLIC 16 | from STEPP.utils import misc 17 | from STEPP.utils.misc import load_image 18 | from STEPP.DINO.dino_feature_extract import get_dino_features, average_dino_feature_segment 19 | 20 | 21 | class FeatureDataSet: 22 | def __init__(self, path_to_image_folder, path_to_pixels): 23 | self.img_width = 1408#1280 24 | self.img_height = 1408#720 25 | self.x_boarder = 0 #20 26 | self.y_boarder = 0 #30 27 | self.start_image_idx = 0#750 28 | self.interpolate = False 29 | self.dino_size = 'vit_small' 30 | self.use_mixed_precision = True 31 | 32 | if self.dino_size == 'vit_small': 33 | self.feature_dim = 384 34 | elif self.dino_size == 'vit_base': 35 | self.feature_dim = 768 36 | elif self.dino_size == 'vit_large': 37 | self.feature_dim = 1024 38 | elif self.dino_size == 'vit_giant': 39 | self.feature_dim = 1536 40 | 41 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 42 | 43 | # Settings 44 | self.size = 700 45 | # dino_size ('vit_small') is set above, where it also selects feature_dim 46 | self.patch = 14 47 | self.backbone = "dinov2" 48 | 49 | # Inference with DINO 50 | # Create DINO 51 | self.di = DinoInterface( 52 | device=self.device, 53 | backbone=self.backbone, 54 | input_size=self.size, 55 | backbone_type=self.dino_size, 56 | patch_size=self.patch, 57 | interpolate=self.interpolate, 58 | use_mixed_precision=self.use_mixed_precision, 59 | ) 60 | 61 | # Load the path pixels: one list of [x, y] pixel coordinates per image 62 | with open(path_to_pixels, 'r') as f: 63 | path_pixels = json.load(f) 64 | 65 | path_pixels_resized = [] 66 | # Keep pixels inside the image minus the borders (both borders are 0 with the current settings) 67 | for pixels in path_pixels: 68 | pixels = [pixel for pixel in pixels if pixel[0] < (self.img_width-self.y_boarder) and pixel[1] < (self.img_height-self.x_boarder)] 69 | # Shift the remaining pixels by the border offsets 70 | pixels = [(pixel[0] - self.x_boarder, pixel[1] - 
self.y_boarder) for pixel in pixels] 71 | 72 | path_pixels_resized.append(pixels) 73 | self.path_pixels_resized = path_pixels_resized 74 | 75 | print("loaded pixels") 76 | 77 | # Load the image paths 78 | self.images = sorted([os.path.join(path_to_image_folder, img) for img in os.listdir(path_to_image_folder) if img.endswith((".png", ".jpg", ".jpeg"))]) 79 | # Drop trailing images that have no pixel list, so that images and path_pixels_resized have the same length 80 | if len(self.images) > len(self.path_pixels_resized): 81 | self.images = self.images[:len(self.path_pixels_resized)] 82 | 83 | print("loaded images") 84 | 85 | 86 | def main(feat): 87 | 88 | # Suppress warnings 89 | warnings.filterwarnings("ignore") 90 | slic = SLIC(crop_x=0, crop_y=0) 91 | average_features_segments = np.zeros((1, feat.feature_dim)) # Placeholder first row, removed after the loop 92 | 93 | for i in range(len(feat.images)): 94 | if feat.path_pixels_resized[i] == []: 95 | continue 96 | img = cv2.imread(feat.images[i]) 97 | segments, segmented_image = slic.Slic_segmentation_for_given_pixels(feat.path_pixels_resized[i], img) 98 | resized_segmented_image, new_segment_dict = slic.make_masks_smaller_numpy(segments.keys(), segmented_image, int(feat.size/feat.patch)) 99 | 100 | tensor_img = load_image(feat.images[i]).to(feat.device) 101 | 102 | #get dino features 103 | features = feat.di.inference(tensor_img) 104 | 105 | #average dino features over segments 106 | average_features = average_dino_feature_segment(features, resized_segmented_image, new_segment_dict.keys()) 107 | #convert to numpy array 108 | average_features = average_features.cpu().detach().numpy() 109 | average_features_segments = np.concatenate((average_features_segments,average_features), axis=0) 110 | 111 | print('processed image:', i,'/', len(feat.images))#, end='\r') 112 | 113 | average_features_segments = average_features_segments[1:] 114 | 115 | print('\n') 116 | print('average_features_segments shape:', average_features_segments.shape) 117 | 118 | return average_features_segments 119 | 120 | 121 | if __name__ == '__main__': 122 | 123 | 
path_to_image_folder = 'path_to_image_folder' 124 | path_to_pixels = 'path_to_pixels.json' 125 | data_preprocessing = FeatureDataSet(path_to_image_folder, path_to_pixels) 126 | dataset = main(data_preprocessing) 127 | 128 | #save dataset 129 | dataset_path = 'path_to_save_dataset' 130 | np.save(dataset_path, dataset) -------------------------------------------------------------------------------- /STEPP/utils/make_unreal_data_pixel_file.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import matplotlib.pyplot as plt 5 | import json 6 | from pytictac import Timer 7 | 8 | def overlay_images(n1_path, n2_path): 9 | n1_image = cv2.imread(n1_path) 10 | n2_image = cv2.imread(n2_path) 11 | # n2_image[..., 3] = 1 12 | 13 | mask = n2_image != 0 14 | 15 | # Create an output image with all black pixels 16 | output_image = np.zeros_like(n1_image) 17 | 18 | # Apply the mask to n1_image and store the result in output_image 19 | output_image[mask] = n1_image[mask] 20 | 21 | output_image[0:520] = 0 22 | 23 | #create a list of pixel coord pairs where the image is not black 24 | pixels = [] 25 | non_black_pixels = np.argwhere(np.any(output_image != 0, axis=-1)) 26 | pixels = non_black_pixels[:, ::-1].tolist() 27 | 28 | return output_image, pixels 29 | 30 | path_to_image_folder = 'path_to_image_folder' 31 | path_to_trajectory_folder = 'path_to_trajectory_folder' 32 | 33 | images = sorted([os.path.join(path_to_image_folder, img) for img in os.listdir(path_to_image_folder) if img.endswith((".png", ".jpg", ".jpeg"))]) 34 | trajectory_images = sorted([os.path.join(path_to_trajectory_folder, img) for img in os.listdir(path_to_trajectory_folder) if img.endswith((".png", ".jpg", ".jpeg"))]) 35 | 36 | 37 | 38 | all_pixels = [] 39 | for i in range(len(images)): 40 | output_img, pixels = overlay_images(images[i], trajectory_images[i]) 41 | 42 | all_pixels.append(pixels) 43 | 44 | print('processed 
image:', i,'/', len(images), end='\r') 45 | 46 | # Save the pixels to JSON 47 | path = 'path_to_save_pixels.json' 48 | with open(path, 'w') as f: 49 | json.dump(all_pixels, f) 50 | 51 | print('Finished saving pixels to:\n', path) 52 | -------------------------------------------------------------------------------- /STEPP/utils/misc.py: -------------------------------------------------------------------------------- 1 | from matplotlib.backends.backend_agg import FigureCanvasAgg 2 | from PIL import Image 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import os 7 | import cv2 8 | from STEPP import ROOT_DIR 9 | 10 | def make_results_folder(name): 11 | path = os.path.join(ROOT_DIR, "results", name) 12 | os.makedirs(path, exist_ok=True) 13 | return path 14 | 15 | def get_img_from_fig(fig, dpi=180): 16 | """Renders a Matplotlib figure and returns it as an RGB PIL image. 17 | 18 | Args: 19 | fig (matplotlib.figure.Figure): Input figure. 20 | dpi (int, optional): Resolution in dots per inch. Defaults to 180. 21 | 22 | Returns: 23 | buf (PIL.Image.Image): Resulting RGB image.
    """
    fig.set_dpi(dpi)
    canvas = FigureCanvasAgg(fig)
    # Retrieve a view on the renderer buffer
    canvas.draw()
    buf = canvas.buffer_rgba()
    # Convert to a NumPy array, then to an RGB PIL image
    buf = np.asarray(buf)
    buf = Image.fromarray(buf)
    buf = buf.convert("RGB")
    return buf

def load_test_image():
    np_img = cv2.imread(os.path.join(ROOT_DIR, "path_to_test_image"))
    np_img = np_img[200:-200, 200:-200]
    img = torch.from_numpy(cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB))
    img = img.permute(2, 0, 1)
    img = (img.type(torch.float32) / 255)[None]
    return img

def load_image(path):
    np_img = cv2.imread(path)
    img = torch.from_numpy(cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB))
    img = img.permute(2, 0, 1)
    img = (img.type(torch.float32) / 255)[None]
    return img

def _remove_axes(ax):
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])

def remove_axes(axes):
    if len(axes.shape) == 2:
        for ax1 in axes:
            for ax in ax1:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)

def save_dataset(dataset, path):
    # Create the folder if it does not exist, then save the array as <path>/dataset.npy
    os.makedirs(path, exist_ok=True)
    np.save(os.path.join(path, 'dataset'), dataset)
--------------------------------------------------------------------------------
/STEPP/utils/rename_files.py:
--------------------------------------------------------------------------------
import os

def rename_files_in_folder(folder_path):
    # Zero-pad numeric file names to six digits, e.g. "42.png" -> "000042.png"
    for filename in os.listdir(folder_path):
        stem, ext = os.path.splitext(filename)
        os.rename(os.path.join(folder_path, filename),
                  os.path.join(folder_path, f"{int(stem):06d}{ext}"))

if __name__ == '__main__':
    folder_path = 'path_to_folder'
    rename_files_in_folder(folder_path)
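The pixel-extraction step in `overlay_images` (`make_unreal_data_pixel_file.py`) can be exercised without image files. The sketch below is a minimal NumPy-only rework that takes arrays instead of paths. `trajectory_pixels` and its `min_row` parameter are hypothetical names introduced here; `min_row` stands in for the hard-coded `output_image[0:520] = 0`.

```python
import numpy as np


def trajectory_pixels(image, trajectory_mask, min_row=0):
    """Return the masked image and the [x, y] pixels covered by a trajectory.

    Hypothetical array-based variant of overlay_images(): pixels of `image`
    are kept where `trajectory_mask` is non-zero, rows above `min_row` are
    blanked, and the remaining non-black pixels are returned as [x, y] pairs.
    """
    mask = trajectory_mask != 0
    output = np.zeros_like(image)
    output[mask] = image[mask]
    output[:min_row] = 0
    # np.argwhere yields (row, col) = (y, x); reverse each pair to get (x, y)
    coords = np.argwhere(np.any(output != 0, axis=-1))
    return output, coords[:, ::-1].tolist()


if __name__ == "__main__":
    img = np.full((4, 5, 3), 7, dtype=np.uint8)
    traj = np.zeros_like(img)
    traj[1, 2] = 255
    traj[3, 4] = 255
    _, pixels = trajectory_pixels(img, traj)
    print(pixels)  # [[2, 1], [4, 3]]
    _, pixels = trajectory_pixels(img, traj, min_row=2)
    print(pixels)  # [[4, 3]]
```

As in the original, a masked pixel whose image value is pure black is dropped, because the final step tests the output image rather than the mask.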
--------------------------------------------------------------------------------
/STEPP/utils/testing.py:
--------------------------------------------------------------------------------
import torch
from STEPP.model.mlp import ReconstructMLP
from STEPP.utils.misc import load_image
from STEPP.SLIC.slic_segmentation import SLIC
from STEPP.utils.make_dataset import FeatureDataSet
from STEPP.DINO.dino_feature_extract import DinoInterface, get_dino_features, average_dino_feature_segment
import cv2
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import torch.nn as nn
from matplotlib.colors import LinearSegmentedColormap
import time
import torch.nn.functional as F
import warnings
from PIL import Image as PILImage
import seaborn as sns
from pytictac import Timer

warnings.filterwarnings("ignore")


def test_feature_reconstructor(mode, model_path, image_path, thresh):
    # mode = 'segment_wise' for running segment-wise inference
    # mode = 'pixel_wise' for running whole-image (per-patch) inference

    device = (
        "cuda"
        if torch.cuda.is_available()
        else "cpu"
    )

    # Load the model architecture
    model = ReconstructMLP(384, [256, 128, 64, 32, 64, 128, 256])  # [256, 32, 384])
    # Load the weights; map_location lets GPU-saved checkpoints load on CPU-only machines
    model.load_state_dict(torch.load(model_path, map_location=device))

    model.to(device)
    return test_feature_reconstructor_with_model(mode, model, image_path, thresh)

def test_feature_reconstructor_with_model(mode, model, image_path, thresh):
    start = time.time()

    alpha = 0.5

    # Load an image
    img = cv2.imread(image_path)
    torch_img = load_image(image_path)
    H, W, D = img.shape
    H1 = 64
    new_features_size = (H, H)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    small_image = cv2.resize(img, (new_features_size))
    # small_image = cv2.cvtColor(small_image, cv2.COLOR_BGR2RGB)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    threshold = thresh  # e.g. 0.1

    # Settings
    size = 700
    dino_size = "vit_small"
    patch = 14
    backbone = "dinov2"

    # Inference with DINO
    # Create DINO
    di = DinoInterface(
        device=device,
        backbone=backbone,
        input_size=size,
        backbone_type=dino_size,
        patch_size=patch,
        interpolate=False,
        use_mixed_precision=False,
    )

    torch_img = torch.from_numpy(small_image)
    torch_img = torch_img.permute(2, 0, 1)
    torch_img = (torch_img.type(torch.float32) / 255)[None].to(device)
    # torch_img.to(self.device)
    dino_size = 'vit_small'
    # features = get_dino_features(torch_img, dino_size, False)
    features = di.inference(torch_img)

    print('features shape', features.shape)

    if mode == 'segment_wise':
        # Segment the whole image and get each pixel for each segment value
        slic = SLIC(crop_x=0, crop_y=0)
        segments, segmented_image = slic.Slic_segmentation_for_all_pixels(small_image)
        print('segmented image shape:', segmented_image.shape)
        resized_segmented_img, new_segment_dict = slic.make_masks_smaller_numpy(segments, segmented_image, 50)

        # Average the features over the segments
        average_features = average_dino_feature_segment(features, resized_segmented_img)

        # Forward pass the entire batch
        reconstructed_features = model(average_features)

        # Calculate the losses for the entire batch
        loss_fn = nn.MSELoss(reduction='none')
        losses = loss_fn(average_features, reconstructed_features)
        losses = losses.mean(dim=1).cpu().detach().numpy()  # Average the losses across the feature dimension

        # Set each segment of the segmented image to its reconstruction loss
        for key, loss in zip(new_segment_dict.keys(), losses):
            segmented_image = np.where(segmented_image == int(key), loss, segmented_image)

        # Cap the losses at 10
        segmented_image = np.where(segmented_image > 10, 10, segmented_image)

        # Normalize the segmented image values to the range [0, 0.45]
        segmented_image = (segmented_image - segmented_image.min()) / (segmented_image.max() - segmented_image.min()) * 0.45

        # Clip all values above the threshold to the threshold
        segmented_image = np.where(segmented_image > threshold, threshold, segmented_image)
        # segmented_image = np.where(segmented_image < self.threshold, 0.0, segmented_image)

        # # Calculate the extent to center the segmented image
        # original_height, original_width = small_image.shape[:2]
        # segmented_height, segmented_width = segmented_image.shape[:2]

        # # Crop the original image to the segmented image size
        # x_offset = (original_width - segmented_width) // 2
        # y_offset = (original_height - segmented_height) // 2
        # small_image = img[y_offset:y_offset + segmented_height, x_offset:x_offset + segmented_width]

        # Create the colormap
        s = 0.3  # If bigger, get more fine-grained green, if smaller get more fine-grained red
        cmap = plt.get_cmap("RdYlBu", 256)  # or RdYlGn
        cmap = np.vstack([
            cmap(np.linspace(0, s, 128)),
            cmap(np.linspace(1 - s, 1.0, 128))
        ])  # Stretch the colormap
        cmap = (cmap[:, :3] * 255).astype(np.uint8)

        # Reverse the colormap if needed
        cmap = cmap[::-1]

        # Normalize the segmented image values to the range [0, 255]
        segmented_normalized = ((segmented_image - segmented_image.min()) /
                                (segmented_image.max() - segmented_image.min()) * 255).astype(np.uint8)

        # Map the segmented image values to colors
        color_mapped_img = cmap[segmented_normalized]

        # Convert images to RGBA
        img_rgba = PILImage.fromarray(np.uint8(small_image)).convert("RGBA")
        seg_rgba = PILImage.fromarray(color_mapped_img).convert("RGBA")

        # Adjust the alpha channel to vary the transparency
        seg_rgba_np = np.array(seg_rgba)
        alpha_channel = seg_rgba_np[:, :, 3]  # Extract alpha channel
        alpha_channel = (alpha_channel * 1.0).astype(np.uint8)  # Alpha scaled by 1.0: overlay stays fully opaque
        seg_rgba_np[:, :, 3] = alpha_channel  # Update alpha channel
        seg_rgba = PILImage.fromarray(seg_rgba_np)

        # Alpha composite the images
        img_new = PILImage.alpha_composite(img_rgba, seg_rgba)
        img_rgb = img_new.convert("RGB")

        # Resize the image to the original size
        img_rgb = img_rgb.resize((W, H))

        # Overlay the segmented image on the original image
        fig = plt.figure(figsize=(10, 10))
        plt.imshow(img_rgb)
        plt.title(mode + '_reconstruction_' + dino_size + '_threshold_' + str(threshold))
        plt.axis('off')

    elif mode == 'pixel_wise':

        # Torch shape is (1, 384, 50, 50) for a 700x700 input with patch size 14
        features = features.permute(2, 3, 1, 0)

        # Change the shape to (2500, 384)
        features_tensor = features.reshape(50*50, 384)

        with Timer('Inference: '):
            # Forward pass the entire batch
            reconstructed_features = model(features_tensor)

        # Calculate the losses for the entire batch
        loss_fn = nn.MSELoss(reduction='none')
        losses = loss_fn(features_tensor, reconstructed_features)
        losses = losses.mean(dim=1).cpu().detach().numpy()  # Average the losses across the feature dimension

        # Reshape the losses to 50x50
        losses = losses.reshape(50, 50)

        # Resize the cost map to the (square) working image size
        cost_map = cv2.resize(losses, (H, H))

        print('time to run inference:', time.time()-start)

        cost_map = np.where(cost_map > 10, 10, cost_map)

        # Normalize the cost map values to the range [0, 0.45]
        cost_map = (cost_map - cost_map.min()) / (cost_map.max() - cost_map.min()) * 0.45

        # Clip all values above the threshold to the threshold
        # cost_map = np.where(cost_map < 3, 0, cost_map)
        cost_map = np.where(cost_map > threshold, threshold, cost_map)

        # Create the colormap
        s = 0.3  # If bigger, get more fine-grained green, if smaller get more fine-grained red
        cmap = plt.get_cmap("RdYlBu", 256)  # or RdYlGn
        cmap = np.vstack([
            cmap(np.linspace(0, s, 128)),
            cmap(np.linspace(1 - s, 1.0, 128))
        ])  # Stretch the colormap
        cmap = (cmap[:, :3] * 255).astype(np.uint8)

        # Reverse the colormap if needed
        cmap = cmap[::-1]

        # Normalize the cost map values to the range [0, 255]
        cost_map_normalized = ((cost_map - cost_map.min()) /
                               (cost_map.max() - cost_map.min()) * 255).astype(np.uint8)

        # Map the cost map values to colors
        color_mapped_img = cmap[cost_map_normalized]

        # Convert images to RGBA
        img_rgba = PILImage.fromarray(np.uint8(small_image)).convert("RGBA")
        seg_rgba = PILImage.fromarray(color_mapped_img).convert("RGBA")

        # Adjust the alpha channel to vary the transparency
        seg_rgba_np = np.array(seg_rgba)
        alpha_channel = seg_rgba_np[:, :, 3]  # Extract alpha channel
        alpha_channel = (alpha_channel * 0.75).astype(np.uint8)  # Adjust transparency (25% transparent)
        seg_rgba_np[:, :, 3] = alpha_channel  # Update alpha channel
        seg_rgba = PILImage.fromarray(seg_rgba_np)

        # Alpha composite the images
        img_new = PILImage.alpha_composite(img_rgba, seg_rgba)
        img_rgb = img_new.convert("RGB")

        # Resize the image to the original size
        img_rgb = img_rgb.resize((W, H))

        # Overlay the cost map on the original image
        fig = plt.figure(figsize=(10, 10))
        plt.imshow(img_rgb)
        plt.title(mode + '_reconstruction_' + dino_size + '_threshold_' + str(threshold))
        plt.axis('off')

    # plt.show()
    return fig

if __name__ == '__main__':
    model_path = 'path_to_model.pth'
    image_path = 'path_to_test_image.png'
    threshold = 0.15
    test_feature_reconstructor('segment_wise', model_path, image_path, threshold)

    # Save the figure to the test folder
    count = time.strftime("%Y%m%d-%H%M")
    plt.savefig('folder_to_save_figure' + count + '.png')
    plt.show()
--------------------------------------------------------------------------------
/STEPP_ros/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.0.2)
project(STEPP_ros)

find_package(catkin REQUIRED COMPONENTS
  roscpp
  geometry_msgs
  octomap_ros
  octomap_msgs
  pcl_conversions
  pcl_ros
  rospy
  sensor_msgs
  std_msgs
  cv_bridge
  grid_map_ros
  image_transport
  message_generation
)

add_message_files(
  FILES
  Float32Stamped.msg
)

generate_messages(
  DEPENDENCIES
  std_msgs
)


find_package(PCL REQUIRED)
find_package(OpenMP REQUIRED)
find_package(OpenCV REQUIRED)

catkin_package(
  CATKIN_DEPENDS roscpp rospy sensor_msgs std_msgs nav_msgs cv_bridge message_runtime
)

include_directories(
  ${catkin_INCLUDE_DIRS}
  ${PCL_INCLUDE_DIRS}
)

link_directories(${PCL_LIBRARY_DIRS})
add_definitions(${PCL_DEFINITIONS})

# Add your C++ source files here
add_executable(depth_projection_synchronized src/depth_projection_synchronized.cpp) # src/utils.cpp)
target_link_libraries(depth_projection_synchronized ${catkin_LIBRARIES} ${PCL_LIBRARIES})
add_dependencies(depth_projection_synchronized ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS})

# Make the Python script executable
catkin_install_python(PROGRAMS scripts/inference_node.py
  DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)
--------------------------------------------------------------------------------
/STEPP_ros/config/model_config.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/STEPP_ros/config/model_config.yaml
--------------------------------------------------------------------------------
/STEPP_ros/launch/STEPP.launch:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/STEPP_ros/msg/Float32Stamped.msg:
--------------------------------------------------------------------------------
std_msgs/Header header
std_msgs/Float32MultiArray data
--------------------------------------------------------------------------------
/STEPP_ros/package.xml:
--------------------------------------------------------------------------------

STEPP_ros
0.0.1
Traversability estimation package using image features

Sebastian Aegidius
MIT

catkin
roscpp
rospy
sensor_msgs
std_msgs
nav_msgs
cv_bridge
message_generation

message_runtime
roscpp
rospy
sensor_msgs
std_msgs
nav_msgs
cv_bridge
torch
--------------------------------------------------------------------------------
/STEPP_ros/scripts/inference_node.py:
--------------------------------------------------------------------------------
#!/Rocket_ssd/miniconda3/envs/STEPP/bin/python3

import rospy
from sensor_msgs.msg import Image, CompressedImage
from std_msgs.msg import Float32MultiArray, MultiArrayDimension
import torch
import cv2
from cv_bridge import CvBridge
import numpy as np
import torch.nn as nn
import time
from PIL import Image as PILImage
from torchvision import transforms
# import seaborn as sns
from matplotlib import cm
import warnings
from queue import Queue
from threading import Thread, Lock

from STEPP.DINO.backbone import get_backbone
from STEPP.DINO.dino_feature_extract import DinoInterface
from STEPP.DINO.dino_feature_extract import get_dino_features, average_dino_feature_segment, average_dino_feature_segment_tensor
from STEPP.SLIC.slic_segmentation import SLIC
from STEPP.model.mlp import ReconstructMLP
from STEPP_ros.msg import Float32Stamped

warnings.filterwarnings("ignore")
CV_BRIDGE = CvBridge()
TO_TENSOR = transforms.ToTensor()
TO_PIL_IMAGE = transforms.ToPILImage()

from pytictac import Timer

class InferenceNode:
    def __init__(self):
        self.image_queue = Queue(maxsize=1)
        self.lock = Lock()

        self.processing = False
        self.image_sub = rospy.Subscriber('/camera/color/image_raw/compressed', CompressedImage, self.image_callback)
        self.inference_pub = rospy.Publisher('/inference/result', Float32MultiArray, queue_size=200)
        self.inference_stamped_pub = rospy.Publisher('/inference/results_stamped_post', Float32Stamped, queue_size=200)
        self.visu_traversability_pub = rospy.Publisher('/inference/visu_traversability_post', Image, queue_size=200)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Threshold for traversability
        self.threshold = 0.2

        # Settings
        self.size = 700
        self.dino_size = "vit_small"
        self.patch = 14
        self.backbone = "dinov2"
        self.ump = rospy.get_param('~ump', True)
        self.cutoff = rospy.get_param('~cutoff', 1.2)
        print(self.cutoff)
        print(type(self.cutoff))

        # Inference with DINO
        # Create DINO
        self.di = DinoInterface(
            device=self.device,
            backbone=self.backbone,
            input_size=self.size,
            backbone_type=self.dino_size,
            patch_size=self.patch,
            interpolate=False,
            use_mixed_precision=self.ump,
        )

        self.slic = SLIC(crop_x=0, crop_y=0)

        # Load model architecture
        # self.model = ReconstructMLP(384, [256, 64, 32, 16, 32, 64, 256])
        self.model = ReconstructMLP(384, [256, 128, 64, 32, 64, 128, 256])

        # Load model weights (map_location lets GPU-saved checkpoints load on CPU-only machines)
        state_dict = torch.load(rospy.get_param('~model_path'), map_location=self.device)
        self.model.load_state_dict(state_dict)

        # Move model to the device
        self.model.to(self.device)

        self.visualize = rospy.get_param('~visualize', False)

        # Daemon thread, so the blocking queue.get() in process_images cannot keep the process alive after shutdown
        self.thread = Thread(target=self.process_images, daemon=True)
        self.thread.start()

        print('Inference node initialized')

    def publish_matrix(self, matrix):
        msg = Float32MultiArray()
        msg.data = matrix.flatten().tolist()  # Flatten the matrix and convert to list
        msg.layout.dim.append(MultiArrayDimension())
        msg.layout.dim[0].label = "rows"
        msg.layout.dim[0].size = matrix.shape[0]
        msg.layout.dim[0].stride = matrix.shape[1]  # stride is the number of columns
        msg.layout.dim.append(MultiArrayDimension())
        msg.layout.dim[1].label = "columns"
        msg.layout.dim[1].size = matrix.shape[1]
        msg.layout.dim[1].stride = 1  # stride is 1 for columns
        self.inference_pub.publish(msg)

    def publish_array_stamped(self, matrix):
        msg = Float32Stamped()

        # Stamp the message with the current ROS time
        msg.header.stamp = rospy.Time.now()

        msg.data = Float32MultiArray()
        msg.data.data = matrix.flatten().tolist()  # Flatten the matrix and convert to list
        msg.data.layout.dim.append(MultiArrayDimension())
        msg.data.layout.dim[0].label = "rows"
        msg.data.layout.dim[0].size = matrix.shape[0]
        msg.data.layout.dim[0].stride = matrix.shape[1]  # stride is the number of columns
        msg.data.layout.dim.append(MultiArrayDimension())
        msg.data.layout.dim[1].label = "columns"
        msg.data.layout.dim[1].size = matrix.shape[1]
        msg.data.layout.dim[1].stride = 1  # stride is 1 for columns
        self.inference_stamped_pub.publish(msg)

    def process_images(self):
        while not rospy.is_shutdown():
            # with Timer("Full loop"):
            image_data = self.image_queue.get()
            if image_data is None:
                break

            with self.lock:
                if isinstance(image_data, CompressedImage):
                    cv_image = CV_BRIDGE.compressed_imgmsg_to_cv2(image_data, desired_encoding="bgr8")
                else:
                    cv_image = CV_BRIDGE.imgmsg_to_cv2(image_data, desired_encoding="bgr8")

            try:
                traversability_array, inference_img = self.inference_image(cv_image)
            except Exception as e:
                print(f'Error: {e}')
                self.processing = False
                continue

            # self.publish_matrix(traversability_array)
            self.publish_array_stamped(traversability_array)

            if self.visualize:
                self.visu_traversability_pub.publish(CV_BRIDGE.cv2_to_imgmsg(np.array(inference_img), "rgb8"))

            self.processing = False
            # print('-'*10)

    def inference_image(self, image):
        # Load an image
        org_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img = cv2.resize(org_img, (self.size, self.size))
        H, W, D = org_img.shape

        # with Timer("DINO feature extraction"):
        # Get the DINO features
        torch_img = torch.from_numpy(img)
        torch_img = torch_img.permute(2, 0, 1)
        torch_img = (torch_img.type(torch.float32) / 255)[None].to(self.device)
        # torch_img.to(self.device)
        dino_size = 'vit_small'
        # features = get_dino_features(torch_img, dino_size, False)
        features = self.di.inference(torch_img)

        # Segment the whole image and get each pixel for each segment value
        # with Timer("SLIC"):
        segments, segmented_image = self.slic.Slic_segmentation_for_all_pixels_torch(img)
        # with Timer("Make masks smaller"):
        resized_segmented_img, new_segment_dict = self.slic.make_masks_smaller_torch(segments, segmented_image, int(self.size/self.patch), return_dict=False)
        # Average the features over the segments
        # with Timer("Average dino feature"):
        average_features = average_dino_feature_segment_tensor(features, resized_segmented_img).to(self.device)
        # with Timer("Forward pass"):
        # Forward pass the entire batch
        reconstructed_features = self.model(average_features)

        # Calculate the losses for the entire batch
        # with Timer("Loss calculation"):
        loss_fn = nn.MSELoss(reduction='none')
        losses = loss_fn(average_features, reconstructed_features)
        losses = losses.mean(dim=1).cpu().detach().numpy()  # Average the losses across the feature dimension

        # with Timer("Set segment values optimized"):
        segmented_image = segmented_image.cpu().detach().numpy()
        # Get the unique keys from the resized segmented image
        unique_keys = np.unique(resized_segmented_img.cpu().detach().numpy()).astype(int)
        # Create an array that maps the unique segment values to the corresponding losses
        max_segment_value = np.max(segmented_image)
        default_loss = 1.0
        mapping_array = np.full(max_segment_value + 1, default_loss)
        # Fill the mapping array with the corresponding losses
        mapping_array[unique_keys] = losses
        # Use the mapping array as a lookup table to replace segment IDs with losses
        segmented_image = mapping_array[segmented_image]

        # Cut off the values at 10
        segmented_image = np.where(segmented_image > 10, 10, segmented_image)

        # Normalize the segmented image values to the range [0, self.cutoff]
        segmented_image = ((segmented_image - segmented_image.min()) / (segmented_image.max() - segmented_image.min())) * self.cutoff

        # Clip all values above the threshold to the threshold
        segmented_image = np.where(segmented_image > self.threshold, self.threshold, segmented_image)
        # segmented_image = np.where(segmented_image < self.threshold, 0.0, segmented_image)

        if self.visualize:
            # with Timer("image processing"):
            # Create the colormap
            s = 0.3  # If bigger, get more fine-grained green, if smaller get more fine-grained red
            cmap = cm.get_cmap("RdYlBu", 256)  # or RdYlGn
            cmap = np.vstack([
                cmap(np.linspace(0, s, 128)),
                cmap(np.linspace(1 - s, 1.0, 128))
            ])  # Stretch the colormap
            cmap = (cmap[:, :3] * 255).astype(np.uint8)

            # Reverse the colormap if needed
            cmap = cmap[::-1]

            # Normalize the segmented image values to the range [0, 255]
            segmented_normalized = ((segmented_image - segmented_image.min()) /
                                    (segmented_image.max() - segmented_image.min()) * 255).astype(np.uint8)

            # Map the segmented image values to colors
            color_mapped_img = cmap[segmented_normalized]

            # Convert images to RGBA
            img_rgba = PILImage.fromarray(np.uint8(img)).convert("RGBA")
            seg_rgba = PILImage.fromarray(color_mapped_img).convert("RGBA")

            # Adjust the alpha channel to vary the transparency
            seg_rgba_np = np.array(seg_rgba)
            alpha_channel = seg_rgba_np[:, :, 3]  # Extract alpha channel
            alpha_channel = (alpha_channel * 0.5).astype(np.uint8)  # Adjust transparency (50% transparent)
            seg_rgba_np[:, :, 3] = alpha_channel  # Update alpha channel
            seg_rgba = PILImage.fromarray(seg_rgba_np)

            # Alpha composite the images
            img_new = PILImage.alpha_composite(img_rgba, seg_rgba)
            img_rgb = img_new.convert("RGB")

            # Resize the image and the segmented image to the original size
            img_rgb = img_rgb.resize((W, H))
            segmented_image = cv2.resize(segmented_image, (W, H))

            return segmented_image, img_rgb
        else:
            segmented_image = cv2.resize(segmented_image, (W, H))

            return segmented_image, None

    def image_callback(self, data):
        if not self.processing:
            with self.lock:
                if not self.image_queue.full():
                    self.image_queue.put(data)
                    self.processing = True

if __name__ == '__main__':
    print('Starting inference node')
    rospy.init_node('inference_node')
    node = InferenceNode()
    rospy.spin()
--------------------------------------------------------------------------------
/STEPP_ros/src/depth_projection_synchronized.cpp:
--------------------------------------------------------------------------------
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

using namespace std;

const double PI = 3.1415926;
double depthCloudTime = 0.0;
double systemInitTime = 0;
bool systemInited = false;
bool firstLoss = true;
bool firstStampedLoss = true;
bool newDepthCloud = false;
float vehicleX = 0, vehicleY = 0, vehicleZ = 0;
float vehicleRoll = 0, vehiclePitch = 0, vehicleYaw = 0;
float sinVehicleRoll = 0, cosVehicleRoll = 0;
float sinVehiclePitch = 0, cosVehiclePitch = 0;
float sinVehicleYaw = 0, cosVehicleYaw = 0;
float voxel_size_ = 0.1;
double noDecayDis = 5.0;
double minDis = 1.5;
double clearingDis = 3.0;
double vehicleHeight = 0.5;
double decayTime = 8.0;
double height = 720;
double width = 1280;
float fovy;
float fovx;
float azimuth_buff = 0.0;
int rows = 1, cols = 1;
int row_stride = 1, col_stride = 1;

Eigen::Matrix4f cameraToMapTransform;
ros::Publisher cloudPub;

pcl::VoxelGrid downSizeFilter;

struct CameraIntrinsics {
    double fx;
    double fy;
    double cx;
    double cy;
};

CameraIntrinsics intrinsics;
tf::Transform odomTransform;
std_msgs::Float32MultiArray loss;
STEPP_ros::Float32Stamped losStamped;

pcl::PointCloud<pcl::PointXYZINormal>::Ptr
    cloud(new pcl::PointCloud<pcl::PointXYZINormal>);
pcl::PointCloud::Ptr
    sparseCloud(new pcl::PointCloud);
pcl::PointCloud::Ptr
    transformedCloud(new pcl::PointCloud);
pcl::PointCloud::Ptr
    terrainCloud(new pcl::PointCloud);
pcl::PointCloud::Ptr
    sparseTerrainCloud(new pcl::PointCloud);
pcl::PointCloud<pcl::PointXYZINormal>::Ptr
    currentCloud(new pcl::PointCloud<pcl::PointXYZINormal>);
pcl::PointCloud::Ptr
    pubCloud(new pcl::PointCloud);

void setCameraIntrinsics(const std::string& cameraType) {
    ROS_INFO("Setting camera intrinsics for %s camera", cameraType.c_str());
    if (cameraType == "D455") {
        intrinsics = {634.3491821289062, 632.8595581054688, 631.8179931640625, 375.0325622558594};
        height = 720;
        width = 1280;
    } else if (cameraType == "zed2") {
        intrinsics = {534.3699951171875, 534.47998046875, 477.2049865722656, 262.4590148925781};
        height = 540;
        width = 960-2*azimuth_buff;
    } else if (cameraType == "cmu_sim") {
        intrinsics = {205.46963709898583, 205.46963709898583, 320.5, 180.5};
        height = 360;
        width = 640-2*azimuth_buff;
    } else {
        ROS_ERROR("Invalid camera type specified. Please choose from 'D455', 'zed2', or 'cmu_sim'.");
        ros::shutdown();
    }

    fovy = 2 * atan(height / (2 * intrinsics.fy));
    fovx = 2 * atan(width / (2 * intrinsics.fx));
}

// Convert 2D pixel coordinates to a 3D point (pinhole camera model)
pcl::PointXYZ convertTo3DPoint(int u, int v, float depth, const CameraIntrinsics& intrinsics) {
    pcl::PointXYZ point;
    point.z = depth;
    point.x = (u - intrinsics.cx) / intrinsics.fx * depth;
    point.y = (v - intrinsics.cy) / intrinsics.fy * depth;
    return point;
}

void callback(const sensor_msgs::Image::ConstPtr& depthMsg,
              const nav_msgs::Odometry::ConstPtr& odomMsg,
              const STEPP_ros::Float32StampedConstPtr& customMsg) {

    // if (loss.data.empty()) { // Check if the loss data is not initialized
    //     ROS_WARN("Loss data not available yet.");
    //     return; // Skip this callback cycle
    // }
    if (firstStampedLoss) {
        rows = customMsg->data.layout.dim[0].size;
        cols = customMsg->data.layout.dim[1].size;
        row_stride = customMsg->data.layout.dim[0].stride;
        col_stride = customMsg->data.layout.dim[1].stride;
        firstStampedLoss = false;
    }
    // losStamped = *customMsg;

    // Extract the position and orientation from the odometry message
    double roll, pitch, yaw;
    geometry_msgs::Point position = odomMsg->pose.pose.position;
    geometry_msgs::Quaternion orientation = odomMsg->pose.pose.orientation;
    tf::Matrix3x3(tf::Quaternion(orientation.x, orientation.y, orientation.z, orientation.w))
        .getRPY(roll, pitch, yaw);

    vehicleX = odomMsg->pose.pose.position.x;
    vehicleY = odomMsg->pose.pose.position.y;
    vehicleZ = odomMsg->pose.pose.position.z;

    //temp [7.251, -10.919, -3.618]
    // vehicleX = vehicleX - 7.251;
    // vehicleY = vehicleY + 10.919;
    // vehicleZ = vehicleZ + 3.618;

    vehicleRoll = roll;
    vehiclePitch = pitch;
    vehicleYaw = yaw;

    sinVehicleRoll = sin(vehicleRoll);
    cosVehicleRoll = cos(vehicleRoll);
    sinVehiclePitch = sin(vehiclePitch);
    cosVehiclePitch = cos(vehiclePitch);
    sinVehicleYaw = sin(vehicleYaw);
    cosVehicleYaw = cos(vehicleYaw);

    // Convert the position and orientation into a transform
    tf::Transform transform;
    transform.setOrigin(tf::Vector3(position.x, position.y, position.z));
    tf::Quaternion quat(orientation.x, orientation.y, orientation.z, orientation.w);
    transform.setRotation(quat);

    // Store the transformation to be used when processing the point cloud
    odomTransform = transform;

    // Record the timestamp of the depth image
    depthCloudTime = depthMsg->header.stamp.toSec();

    if (!systemInited) {
        systemInitTime = depthCloudTime;
        systemInited = true;
    }

    cloud->clear();
    cv_bridge::CvImageConstPtr cv_ptr;
    try {
        cv_ptr = cv_bridge::toCvShare(depthMsg, depthMsg->encoding);
    } catch (cv_bridge::Exception& e) {
        ROS_ERROR("cv_bridge exception: %s", e.what());
        return;
    }

    if (depthMsg->encoding == sensor_msgs::image_encodings::TYPE_32FC1) {
        for (int v = 0; v < depthMsg->height; ++v) {
            for (int u = azimuth_buff; u < depthMsg->width-azimuth_buff; ++u) {
                float depth = cv_ptr->image.at<float>(v, u);  // Access the depth value as float (meters)
                if (depth > 0) {  // Check for valid depth
                    pcl::PointXYZ point = convertTo3DPoint(u, v, depth, intrinsics);
                    pcl::PointXYZINormal iPoint;
                    iPoint.x = point.x;
                    iPoint.y = point.y;
                    iPoint.z = point.z;
                    iPoint.intensity = depthCloudTime - systemInitTime;  // Seconds since the first depth frame
                    iPoint.curvature = customMsg->data.data[v * row_stride + u * col_stride];  // Per-pixel traversability cost
                    cloud->points.push_back(iPoint);
                }
            }
        }
    } else if (depthMsg->encoding == sensor_msgs::image_encodings::TYPE_16UC1) {
202 | for (int v = 0; v < depthMsg->height; ++v) { 203 | for (int u = azimuth_buff; u < depthMsg->width-azimuth_buff; ++u) { 204 | uint16_t depth_mm = cv_ptr->image.at(v, u); // Access the depth value as uint16_t 205 | float depth = depth_mm * 0.001f; // Convert millimeters to meters 206 | if (depth != 0) { // Check for valid depth 207 | pcl::PointXYZ point = convertTo3DPoint(u, v, depth, intrinsics); 208 | pcl::PointXYZINormal iPoint; 209 | iPoint.x = point.x; 210 | iPoint.y = point.y; 211 | iPoint.z = point.z; 212 | iPoint.intensity = depthCloudTime - systemInitTime; 213 | iPoint.curvature = customMsg->data.data[v * row_stride + u * col_stride]; 214 | cloud->points.push_back(iPoint); 215 | } 216 | } 217 | } 218 | } else { 219 | ROS_ERROR("Unsupported depth encoding: %s", depthMsg->encoding.c_str()); 220 | return; 221 | } 222 | newDepthCloud = true; 223 | // ROS_INFO("Input cloud size %zu", cloud->points.size()); 224 | } 225 | 226 | int main(int argc, char** argv) { 227 | ros::init(argc, argv, "depth_projection"); 228 | ros::NodeHandle nh; 229 | 230 | std::string cameraType; 231 | nh.getParam("/depth_projection/camera_type", cameraType); 232 | nh.getParam("/depth_projection/decayTime", decayTime); 233 | setCameraIntrinsics(cameraType); 234 | 235 | // Set up subscribers using message_filters 236 | message_filters::Subscriber depthSub(nh, "/camera/aligned_depth_to_color/image_raw", 1); 237 | message_filters::Subscriber odomSub(nh, "/state_estimation", 1); 238 | message_filters::Subscriber customMsgSub(nh, "/inference/results_stamped_post", 1); 239 | 240 | // Create ApproximateTime policy 241 | typedef message_filters::sync_policies::ApproximateTime MySyncPolicy; 242 | message_filters::Synchronizer sync(MySyncPolicy(10), depthSub, odomSub, customMsgSub); 243 | sync.setInterMessageLowerBound(ros::Duration(1.5)); // Adjust time tolerance 244 | sync.registerCallback(boost::bind(&callback, _1, _2, _3)); 245 | 246 | // ros::Subscriber lossSub = 
nh.subscribe("/inference/results", 10, lossCallback); 247 | 248 | // cameraToMapTransform << 0.0, 0.0, 1.0, 0.0, // CMU_SIM transform 249 | // -1.0, 0.0, 0.0, 0.0, 250 | // 0.0,-1.0, 0.0, 0.0, 251 | // 0.0, 0.0, 0.0, 1.0; 252 | 253 | cameraToMapTransform << 0.01165962, -0.02415892, 0.99964014, 0.482, 254 | -0.99953617, 0.02784553, 0.01233136, 0.04, 255 | -0.02813342, -0.99932026, -0.02382304, 0.249, 256 | 0.0, 0.0, 0.0, 1.0; 257 | 258 | cloudPub = nh.advertise("/depth_projection", 10); 259 | 260 | downSizeFilter.setLeafSize(voxel_size_, voxel_size_, voxel_size_); 261 | 262 | //print out the camera intrinsics 263 | ROS_INFO("Camera intrinsics: fx = %f, fy = %f, cx = %f, cy = %f", intrinsics.fx, intrinsics.fy, intrinsics.cx, intrinsics.cy); 264 | 265 | ros::Rate rate(200); 266 | bool status = ros::ok(); 267 | while (status) { 268 | ros::spinOnce(); 269 | if (newDepthCloud) { 270 | newDepthCloud = false; 271 | 272 | //clear point clouds 273 | terrainCloud->clear(); 274 | transformedCloud->clear(); 275 | sparseCloud->clear(); 276 | sparseTerrainCloud->clear(); 277 | 278 | // Update terrain cloud as to get rid of old points outside decay distance 279 | int currentCloudSize = currentCloud->points.size(); 280 | for (int i = 0; i < currentCloudSize; i++) { 281 | pcl::PointXYZINormal point = currentCloud->points[i]; 282 | 283 | // Translate point to vehicle coordinate frame 284 | float translatedX = point.x - vehicleX; 285 | float translatedY = point.y - vehicleY; 286 | float translatedZ = point.z - vehicleZ; 287 | 288 | // Rotate point according to vehicle orientation 289 | float rotatedX = cosVehicleYaw * translatedX + sinVehicleYaw * translatedY; 290 | float rotatedY = -sinVehicleYaw * translatedX + cosVehicleYaw * translatedY; 291 | float rotatedZ = cosVehiclePitch * translatedZ - sinVehiclePitch * rotatedX; 292 | 293 | // Calculate planar distance in XY plane 294 | float dis = sqrt(rotatedX * rotatedX + rotatedY * rotatedY); 295 | 296 | // Calculate azimuth and 
elevation angles 297 | float angle1 = atan2(rotatedY, rotatedX); // Azimuth angle 298 | float angle2 = atan2(rotatedZ, dis); // Elevation angle 299 | 300 | // Check if the point is outside the decay time OR within no-decay distance 301 | // Also, check if the point is outside the FOV in both azimuth and elevation 302 | if ((depthCloudTime - systemInitTime + point.intensity < decayTime || dis < clearingDis) 303 | && point.z < vehicleHeight 304 | && (((fabs(angle1) > (fovx / 2) - 8*(PI/180) || fabs(angle2) > (fovy / 2))) || dis < minDis)) { // Use OR instead of AND 305 | terrainCloud->push_back(point); 306 | } 307 | // ROS_INFO("sysinit %f, depth %f, intensity %f, time diff %f",systemInitTime, depthCloudTime, point.intensity, depthCloudTime - systemInitTime - point.intensity); 308 | } 309 | 310 | //filter the terrain cloud 311 | downSizeFilter.setInputCloud(terrainCloud); 312 | downSizeFilter.filter(*sparseTerrainCloud); 313 | 314 | //filter input depth cloud 315 | // downSizeFilter.setInputCloud(cloud); 316 | // downSizeFilter.filter(*sparseCloud); 317 | 318 | // Transform the point cloud to the map frame 319 | pcl::transformPointCloud(*cloud, *transformedCloud, cameraToMapTransform); 320 | 321 | // ROS_INFO("transformedCloud size %zu", transformedCloud->points.size()); 322 | 323 | // Transform each point in the cloud to be in the odometry frame 324 | int transformedCloudSize = transformedCloud->points.size(); 325 | for (int i =0; i < transformedCloudSize; i++) { 326 | pcl::PointXYZINormal point = transformedCloud->points[i]; 327 | tf::Vector3 p(point.x, point.y, point.z); 328 | tf::Vector3 pTransformed = odomTransform * p; 329 | pcl::PointXYZINormal newPoint; 330 | newPoint.x = pTransformed.x(); 331 | newPoint.y = pTransformed.y(); 332 | newPoint.z = pTransformed.z(); 333 | newPoint.intensity = point.intensity; 334 | newPoint.curvature = point.curvature; 335 | float dis = sqrt((newPoint.x - vehicleX) * (newPoint.x - vehicleX) + (newPoint.y - vehicleY) * (newPoint.y 
- vehicleY)); 336 | if (newPoint.z < vehicleZ + vehicleHeight && dis > minDis && dis < noDecayDis) { 337 | sparseTerrainCloud->push_back(newPoint); 338 | } 339 | } 340 | 341 | currentCloud = pcl::PointCloud::Ptr(new pcl::PointCloud(*sparseTerrainCloud)); 342 | 343 | //loop through the terrain cloud 344 | pubCloud->clear(); 345 | int terrainCloudSize = sparseTerrainCloud->points.size(); 346 | for (int i = 0; i < terrainCloudSize; i++) { 347 | pcl::PointXYZINormal point = sparseTerrainCloud->points[i]; 348 | pcl::PointXYZI newPoint; 349 | newPoint.x = point.x; 350 | newPoint.y = point.y; 351 | newPoint.z = point.z; 352 | newPoint.intensity = point.curvature; 353 | // newPoint.intensity = point.z; 354 | pubCloud->push_back(newPoint); 355 | } 356 | 357 | // Publish the terrain cloud 358 | sensor_msgs::PointCloud2 terrainCloud2; 359 | pcl::toROSMsg(*pubCloud, terrainCloud2); 360 | terrainCloud2.header.frame_id = "odom"; 361 | terrainCloud2.header.stamp = ros::Time().fromSec(depthCloudTime); 362 | cloudPub.publish(terrainCloud2); 363 | } 364 | 365 | status = ros::ok(); 366 | rate.sleep(); 367 | } 368 | 369 | return 0; 370 | } -------------------------------------------------------------------------------- /assets/front_page.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/assets/front_page.png -------------------------------------------------------------------------------- /assets/outdoor_all_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/assets/outdoor_all_2.png -------------------------------------------------------------------------------- /assets/pre_train_pipeline.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/assets/pre_train_pipeline.png
--------------------------------------------------------------------------------
/checkpoints/all_ViT_small_input_700_big_nn_checkpoint_20240827-1935.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/checkpoints/all_ViT_small_input_700_big_nn_checkpoint_20240827-1935.pth
--------------------------------------------------------------------------------
/checkpoints/richmond_forest_full_ViT_small_big_nn_checkpoint_20240821-1825.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/checkpoints/richmond_forest_full_ViT_small_big_nn_checkpoint_20240821-1825.pth
--------------------------------------------------------------------------------
/checkpoints/unreal_full_ViT_small_big_nn_checkpoint_20240819-2003.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RPL-CS-UCL/STEPP-Code/cf70e73c0e4f67eb54b5c0cbbea9ba23a656ba43/checkpoints/unreal_full_ViT_small_big_nn_checkpoint_20240819-2003.pth
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name="STEPP",
    version="1.0.0",
    author="Sebastian Aegidius",
    author_email="your.email@example.com",
    description="Traversability estimation package using image features",
    # long_description=open("README.md").read(),
    # long_description_content_type="text/markdown",
    url="https://github.com/RPL-CS-UCL/STEPP-Code",
    packages=find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires='>=3.8',
    install_requires=[
        # generic
        "numpy",
        "tqdm",
        "kornia>=0.6.5",
        "pip",
        "torchvision",
        "torch>=1.21",
        "torchmetrics",
        "pytorch_lightning>=1.6.5",
        "pytest",
        "scipy",
        "scikit-image",
        "scikit-learn",
        "matplotlib",
        "seaborn",
        "pandas",
        "pytictac",
        "torch_geometric",
        "omegaconf",
        "optuna",
        "neptune",
        "fast-slic",
        "hydra-core",
        "prettytable",
        "termcolor",
        "pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git",
        "liegroups@git+https://github.com/mmattamala/liegroups",
        "wget",
        "rospkg",
        "wandb",
        "opencv-python",
    ],
    include_package_data=True,
    package_data={
    },
)
--------------------------------------------------------------------------------
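The depth callback in `STEPP_ros/src/depth_projection_synchronized.cpp` converts 16UC1 depth from millimeters to meters, skips `azimuth_buff` columns on each side, and back-projects every valid pixel with `convertTo3DPoint`. That helper is defined outside this excerpt; the sketch below assumes it applies the standard pinhole model. `Intrinsics`, `depth_to_meters`, `back_project` and `project_image` are hypothetical names, and the lookup of the per-pixel inference value via `row_stride`/`col_stride` is not reproduced.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Intrinsics:
    # Hypothetical mirror of intrinsics.fx / fy / cx / cy in the node
    fx: float
    fy: float
    cx: float
    cy: float


def depth_to_meters(raw: float, encoding: str) -> float:
    """16UC1 images store millimeters; 32FC1 images already store meters."""
    if encoding == "16UC1":
        return raw * 0.001
    if encoding == "32FC1":
        return float(raw)
    raise ValueError(f"Unsupported depth encoding: {encoding}")


def back_project(u: int, v: int, depth: float, k: Intrinsics) -> Tuple[float, float, float]:
    """Assumed pinhole back-projection into the camera optical frame (x right, y down, z forward)."""
    x = (u - k.cx) * depth / k.fx
    y = (v - k.cy) * depth / k.fy
    return (x, y, depth)


def project_image(depth_rows: List[List[float]], encoding: str, k: Intrinsics,
                  azimuth_buff: int = 0) -> List[Tuple[float, float, float]]:
    """Back-project every valid pixel, skipping azimuth_buff columns on each side."""
    points = []
    for v, row in enumerate(depth_rows):
        for u in range(azimuth_buff, len(row) - azimuth_buff):
            d = depth_to_meters(row[u], encoding)
            if d > 0:  # zero depth marks an invalid pixel
                points.append(back_project(u, v, d, k))
    return points
```

Trimming `azimuth_buff` columns discards the image borders, where aligned depth from RGB-D cameras tends to be least reliable; the exact motivation in the node is not stated, so treat this as a plausible reading.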
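The pruning loop in the node's `main` keeps an accumulated point only if three conditions hold. The time test relies on each point's `intensity` storing `systemInitTime - captureTime`, so adding the current elapsed time `depthCloudTime - systemInitTime` yields the point's age. The Python sketch below reproduces that predicate, assuming angles in radians. `Params` and `keep_point` are hypothetical names, and the 8 degree azimuth margin and the yaw-then-pitch approximation are copied from the C++ code.

```python
import math
from dataclasses import dataclass


@dataclass
class Params:
    # Hypothetical mirror of the node's globals
    decay_time: float      # seconds a point survives beyond clearing_dis
    clearing_dis: float    # points closer than this (m) never decay
    vehicle_height: float  # points at or above this z are dropped
    fovx: float            # horizontal field of view (rad)
    fovy: float            # vertical field of view (rad)
    min_dis: float         # points closer than this (m) are kept even inside the FOV


AZIMUTH_MARGIN = math.radians(8.0)  # the node's 8*(PI/180)


def keep_point(px, py, pz, intensity, now, init_time, vehicle_xyz, yaw, pitch, p: Params) -> bool:
    vx, vy, vz = vehicle_xyz
    tx, ty, tz = px - vx, py - vy, pz - vz
    # Rotate by yaw, then apply the node's approximate pitch correction to z only
    rx = math.cos(yaw) * tx + math.sin(yaw) * ty
    ry = -math.sin(yaw) * tx + math.cos(yaw) * ty
    rz = math.cos(pitch) * tz - math.sin(pitch) * rx
    dis = math.hypot(rx, ry)
    azimuth = math.atan2(ry, rx)
    elevation = math.atan2(rz, dis)

    age = (now - init_time) + intensity  # intensity == init_time - capture_time
    fresh_or_near = age < p.decay_time or dis < p.clearing_dis
    below_height = pz < p.vehicle_height  # the node compares absolute z here
    outside_fov = (abs(azimuth) > p.fovx / 2 - AZIMUTH_MARGIN
                   or abs(elevation) > p.fovy / 2)
    return fresh_or_near and below_height and (outside_fov or dis < p.min_dis)
```

Dropping points that fall inside the current field of view appears to let the incoming depth frame replace them, so the map is refreshed rather than stacked with stale duplicates; points behind or beside the robot are remembered until they age out.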
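New depth points pass through two transforms in the node: the fixed 4x4 extrinsic `cameraToMapTransform`, then the odometry pose `odomTransform`. Judging by its rotation block, which sends the camera's optical +z axis to roughly +x and optical +y (down) to roughly -z, the extrinsic maps the camera optical frame into the vehicle body frame. The dependency-free sketch below applies the same chain. It assumes the orientation is a unit quaternion in (x, y, z, w) order, as read from the odometry message in the callback. `camera_to_odom` and the other helpers are hypothetical names; only the matrix values come from the node.

```python
from typing import Sequence, Tuple

Vec3 = Tuple[float, float, float]

# Camera-to-vehicle extrinsic copied from depth_projection_synchronized.cpp (row-major 4x4)
CAMERA_TO_VEHICLE = (
    (0.01165962, -0.02415892, 0.99964014, 0.482),
    (-0.99953617, 0.02784553, 0.01233136, 0.04),
    (-0.02813342, -0.99932026, -0.02382304, 0.249),
    (0.0, 0.0, 0.0, 1.0),
)


def apply_homogeneous(m: Sequence[Sequence[float]], p: Vec3) -> Vec3:
    """Multiply [x, y, z, 1] by a 4x4 homogeneous matrix and drop the last row."""
    x, y, z = p
    return tuple(m[r][0] * x + m[r][1] * y + m[r][2] * z + m[r][3] for r in range(3))


def cross(a: Vec3, b: Vec3) -> Vec3:
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def rotate_by_quaternion(q: Tuple[float, float, float, float], v: Vec3) -> Vec3:
    """Rotate v by unit quaternion q = (x, y, z, w): v + w*t + u x t with t = 2 (u x v)."""
    qx, qy, qz, qw = q
    u = (qx, qy, qz)
    t = tuple(2.0 * c for c in cross(u, v))
    ut = cross(u, t)
    return tuple(v[i] + qw * t[i] + ut[i] for i in range(3))


def camera_to_odom(p_cam: Vec3, position: Vec3, orientation) -> Vec3:
    """Camera optical frame -> vehicle frame (extrinsic) -> odom frame (vehicle pose)."""
    p_vehicle = apply_homogeneous(CAMERA_TO_VEHICLE, p_cam)
    r = rotate_by_quaternion(orientation, p_vehicle)
    return tuple(r[i] + position[i] for i in range(3))
```

With the identity pose, a point 1 m along the optical axis lands about 1.48 m ahead of the vehicle origin, which is consistent with a forward-facing camera mounted 0.482 m ahead and 0.249 m above it.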