├── .gitignore ├── Inference ├── depth_pipeline.py └── depth_pipeline_half.py ├── LICENSE ├── README.md ├── datafiles ├── KITTI │ ├── kitti_example_train.txt │ ├── kitti_example_val.txt │ ├── kitti_raw_all.txt │ ├── kitti_raw_train.txt │ └── kitti_raw_val.txt ├── filenames │ ├── kitti_example_train.txt │ ├── kitti_example_val.txt │ ├── kitti_raw_all.txt │ ├── kitti_raw_train.txt │ └── kitti_raw_val.txt ├── middlebury │ ├── MiddleBury_mix.list │ ├── Middleburry.list │ ├── MiddleburryV3_val.list │ ├── middleburry2003_train.list │ ├── middleburry2005_train.list │ ├── middleburry2006_train.list │ ├── middleburry2014_train.list │ ├── middleburryV3_train.list │ ├── middleburry_2014_mix.list │ └── middleburry_submit.list └── sceneflow │ ├── Driving_train.list │ ├── FlyingThings3D_Test_With_Occ.list │ ├── SceneFlow_With_Occ.list │ ├── SceneFlow_With_Occ_mix.list │ ├── Things3D_train.list │ └── scneflow_mid_mix.list ├── dataloader ├── __pycache__ │ ├── file_io.cpython-39.pyc │ ├── sceneflow_loader.cpython-39.pyc │ ├── transforms.cpython-39.pyc │ └── utils.cpython-39.pyc ├── file_io.py ├── sceneflow_loader.py ├── transforms.py └── utils.py ├── playground └── check_depth_est.py ├── run └── run_inference.py ├── scripts ├── inference.sh └── train.sh ├── training ├── dataset_configuration.py └── depth2image_trainer.py └── utils ├── __pycache__ ├── colormap.cpython-39.pyc ├── common.cpython-39.pyc ├── de_normalized.cpython-39.pyc ├── depth_ensemble.cpython-39.pyc ├── image_util.cpython-39.pyc └── seed_all.cpython-39.pyc ├── batch_size.py ├── colormap.py ├── common.py ├── de_normalized.py ├── depth_ensemble.py ├── image_util.py └── seed_all.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.jpg 3 | *.pt 4 | *.png 5 | *.safetensors 6 | *.bin 7 | *.pyc 8 | outputs 9 | *.json 10 | *.bin 11 | *.pkl 12 | -------------------------------------------------------------------------------- /Inference/depth_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, Dict, Union 3 | 4 | import torch 5 | from torch.utils.data import DataLoader, TensorDataset 6 | import numpy as np 7 | from tqdm.auto import tqdm 8 | from PIL import Image 9 | from diffusers import ( 10 | DiffusionPipeline, 11 | DDIMScheduler, 12 | UNet2DConditionModel, 13 | AutoencoderKL, 14 | ) 15 | from diffusers.utils import BaseOutput 16 | from transformers import CLIPTextModel, CLIPTokenizer 17 | 18 | from utils.image_util import resize_max_res,chw2hwc,colorize_depth_maps 19 | from utils.colormap import kitti_colormap 20 | from utils.depth_ensemble import ensemble_depths 21 | 22 | 23 | 24 | class DepthPipelineOutput(BaseOutput): 25 | """ 26 | Output class for Marigold monocular depth prediction pipeline. 27 | 28 | Args: 29 | depth_np (`np.ndarray`): 30 | Predicted depth map, with depth values in the range of [0, 1]. 31 | depth_colored (`PIL.Image.Image`): 32 | Colorized depth map, with the shape of [3, H, W] and values in [0, 1]. 33 | uncertainty (`None` or `np.ndarray`): 34 | Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. 
35 | """ 36 | depth_np: np.ndarray 37 | depth_colored: Image.Image 38 | uncertainty: Union[None, np.ndarray] 39 | 40 | 41 | class DepthEstimationPipeline(DiffusionPipeline): 42 | # two hyper-parameters 43 | rgb_latent_scale_factor = 0.18215 44 | depth_latent_scale_factor = 0.18215 45 | 46 | def __init__(self, 47 | unet:UNet2DConditionModel, 48 | vae:AutoencoderKL, 49 | scheduler:DDIMScheduler, 50 | text_encoder:CLIPTextModel, 51 | tokenizer:CLIPTokenizer, 52 | ): 53 | super().__init__() 54 | 55 | self.register_modules( 56 | unet=unet, 57 | vae=vae, 58 | scheduler=scheduler, 59 | text_encoder=text_encoder, 60 | tokenizer=tokenizer, 61 | ) 62 | self.empty_text_embed = None 63 | 64 | 65 | 66 | 67 | @torch.no_grad() 68 | def __call__(self, 69 | input_image:Image, 70 | denosing_steps: int =10, 71 | ensemble_size: int =10, 72 | processing_res: int = 768, 73 | match_input_res:bool =True, 74 | batch_size:int =0, 75 | color_map: str="Spectral", 76 | show_progress_bar:bool = True, 77 | ensemble_kwargs: Dict = None, 78 | ) -> DepthPipelineOutput: 79 | 80 | # inherit from thea Diffusion Pipeline 81 | device = self.device 82 | input_size = input_image.size 83 | 84 | # adjust the input resolution. 85 | if not match_input_res: 86 | assert ( 87 | processing_res is not None 88 | )," Value Error: `resize_output_back` is only valid with " 89 | 90 | assert processing_res >=0 91 | assert denosing_steps >=1 92 | assert ensemble_size >=1 93 | 94 | # --------------- Image Processing ------------------------ 95 | # Resize image 96 | if processing_res >0: 97 | input_image = resize_max_res( 98 | input_image, max_edge_resolution=processing_res 99 | ) # resize image: for kitti is 231, 768 100 | 101 | 102 | # Convert the image to RGB, to 1. reomve the alpha channel. 103 | input_image = input_image.convert("RGB") 104 | image = np.array(input_image) 105 | 106 | 107 | # Normalize RGB Values. 
108 | rgb = np.transpose(image,(2,0,1)) 109 | rgb_norm = rgb / 255.0 110 | rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype) 111 | rgb_norm = rgb_norm.to(device) 112 | 113 | 114 | assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0 115 | 116 | # ----------------- predicting depth ----------------- 117 | duplicated_rgb = torch.stack([rgb_norm] * ensemble_size) 118 | single_rgb_dataset = TensorDataset(duplicated_rgb) 119 | 120 | # find the batch size 121 | if batch_size>0: 122 | _bs = batch_size 123 | else: 124 | _bs = 1 125 | 126 | single_rgb_loader = DataLoader(single_rgb_dataset,batch_size=_bs,shuffle=False) 127 | 128 | # predicted the depth 129 | depth_pred_ls = [] 130 | 131 | if show_progress_bar: 132 | iterable_bar = tqdm( 133 | single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False 134 | ) 135 | else: 136 | iterable_bar = single_rgb_loader 137 | 138 | for batch in iterable_bar: 139 | (batched_image,)= batch # here the image is still around 0-1 140 | depth_pred_raw = self.single_infer( 141 | input_rgb=batched_image, 142 | num_inference_steps=denosing_steps, 143 | show_pbar=show_progress_bar, 144 | ) 145 | depth_pred_ls.append(depth_pred_raw.detach().clone()) 146 | 147 | depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() #(10,224,768) 148 | torch.cuda.empty_cache() # clear vram cache for ensembling 149 | 150 | 151 | # ----------------- Test-time ensembling ----------------- 152 | if ensemble_size > 1: 153 | depth_pred, pred_uncert = ensemble_depths( 154 | depth_preds, **(ensemble_kwargs or {}) 155 | ) 156 | else: 157 | depth_pred = depth_preds 158 | pred_uncert = None 159 | 160 | # ----------------- Post processing ----------------- 161 | # Scale prediction to [0, 1] 162 | min_d = torch.min(depth_pred) 163 | max_d = torch.max(depth_pred) 164 | depth_pred = (depth_pred - min_d) / (max_d - min_d) 165 | 166 | # Convert to numpy 167 | depth_pred = depth_pred.cpu().numpy().astype(np.float32) 168 | 169 | # Resize back to original resolution 170 | if match_input_res: 171 | pred_img = Image.fromarray(depth_pred) 172 | pred_img = pred_img.resize(input_size) 173 | depth_pred = np.asarray(pred_img) 174 | 175 | # Clip output range: current size is the original size 176 | depth_pred = depth_pred.clip(0, 1) 177 | 178 | # colorization using the KITTI Color Plan. 
179 | depth_pred_vis = depth_pred * 70 180 | disp_vis = 400/(depth_pred_vis+1e-3) 181 | disp_vis = disp_vis.clip(0,500) 182 | 183 | depth_color_pred = kitti_colormap(disp_vis) 184 | 185 | # Colorize 186 | depth_colored = colorize_depth_maps( 187 | depth_pred, 0, 1, cmap=color_map 188 | ).squeeze() # [3, H, W], value in (0, 1) 189 | depth_colored = (depth_colored * 255).astype(np.uint8) 190 | depth_colored_hwc = chw2hwc(depth_colored) 191 | depth_colored_img = Image.fromarray(depth_colored_hwc) 192 | 193 | 194 | return DepthPipelineOutput( 195 | depth_np = depth_pred, 196 | depth_colored = depth_colored_img, 197 | uncertainty=pred_uncert, 198 | ) 199 | 200 | 201 | def __encode_empty_text(self): 202 | """ 203 | Encode text embedding for empty prompt 204 | """ 205 | prompt = "" 206 | text_inputs = self.tokenizer( 207 | prompt, 208 | padding="do_not_pad", 209 | max_length=self.tokenizer.model_max_length, 210 | truncation=True, 211 | return_tensors="pt", 212 | ) 213 | text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) #[1,2] 214 | # print(text_input_ids.shape) 215 | self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) #[1,2,1024] 216 | 217 | 218 | 219 | @torch.no_grad() 220 | def single_infer(self,input_rgb:torch.Tensor, 221 | num_inference_steps:int, 222 | show_pbar:bool,): 223 | 224 | 225 | device = input_rgb.device 226 | 227 | # Set timesteps: inherit from the diffuison pipeline 228 | self.scheduler.set_timesteps(num_inference_steps, device=device) # here the numbers of the steps is only 10. 229 | timesteps = self.scheduler.timesteps # [T] 230 | 231 | # encode image 232 | rgb_latent = self.encode_RGB(input_rgb) # 1/8 Resolution with a channel nums of 4. 233 | 234 | 235 | # Initial depth map (Guassian noise) 236 | depth_latent = torch.randn( 237 | rgb_latent.shape, device=device, dtype=self.dtype 238 | ) # [B, 4, H/8, W/8] 239 | 240 | 241 | # Batched empty text embedding 242 | if self.empty_text_embed is None: 243 | self.__encode_empty_text() 244 | 245 | batch_empty_text_embed = self.empty_text_embed.repeat( 246 | (rgb_latent.shape[0], 1, 1) 247 | ) # [B, 2, 1024] 248 | 249 | # Denoising loop 250 | if show_pbar: 251 | iterable = tqdm( 252 | enumerate(timesteps), 253 | total=len(timesteps), 254 | leave=False, 255 | desc=" " * 4 + "Diffusion denoising", 256 | ) 257 | else: 258 | iterable = enumerate(timesteps) 259 | 260 | for i, t in iterable: 261 | unet_input = torch.cat( 262 | [rgb_latent, depth_latent], dim=1 263 | ) # this order is important: [1,8,H,W] 264 | 265 | # print(unet_input.shape) 266 | 267 | # predict the noise residual 268 | noise_pred = self.unet( 269 | unet_input, t, encoder_hidden_states=batch_empty_text_embed 270 | ).sample # [B, 4, h, w] 271 | 272 | # compute the previous noisy sample x_t -> x_t-1 273 | depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample 274 | 275 | torch.cuda.empty_cache() 276 | depth = self.decode_depth(depth_latent) 277 | # clip prediction 278 | depth = torch.clip(depth, -1.0, 1.0) 279 | # shift to [0, 1] 280 | depth = (depth + 1.0) / 2.0 281 | return depth 282 | 283 | 284 | def encode_RGB(self, rgb_in: torch.Tensor) -> torch.Tensor: 285 | """ 286 | Encode RGB image into latent. 287 | 288 | Args: 289 | rgb_in (`torch.Tensor`): 290 | Input RGB image to be encoded. 291 | 292 | Returns: 293 | `torch.Tensor`: Image latent. 
294 | """ 295 | 296 | 297 | # encode 298 | h = self.vae.encoder(rgb_in) 299 | 300 | moments = self.vae.quant_conv(h) 301 | mean, logvar = torch.chunk(moments, 2, dim=1) 302 | # scale latent 303 | rgb_latent = mean * self.rgb_latent_scale_factor 304 | 305 | return rgb_latent 306 | 307 | def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: 308 | """ 309 | Decode depth latent into depth map. 310 | 311 | Args: 312 | depth_latent (`torch.Tensor`): 313 | Depth latent to be decoded. 314 | 315 | Returns: 316 | `torch.Tensor`: Decoded depth map. 317 | """ 318 | # scale latent 319 | depth_latent = depth_latent / self.depth_latent_scale_factor 320 | # decode 321 | z = self.vae.post_quant_conv(depth_latent) 322 | stacked = self.vae.decoder(z) 323 | # mean of output channels 324 | depth_mean = stacked.mean(dim=1, keepdim=True) 325 | return depth_mean 326 | 327 | 328 | -------------------------------------------------------------------------------- /Inference/depth_pipeline_half.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, Dict, Union 3 | 4 | import torch 5 | from torch.utils.data import DataLoader, TensorDataset 6 | import numpy as np 7 | from tqdm.auto import tqdm 8 | from PIL import Image 9 | from diffusers import ( 10 | DiffusionPipeline, 11 | DDIMScheduler, 12 | UNet2DConditionModel, 13 | AutoencoderKL, 14 | ) 15 | from diffusers.utils import BaseOutput 16 | from transformers import CLIPTextModel, CLIPTokenizer 17 | 18 | from utils.image_util import resize_max_res,chw2hwc,colorize_depth_maps 19 | from utils.colormap import kitti_colormap 20 | from utils.depth_ensemble import ensemble_depths 21 | 22 | 23 | 24 | class DepthPipelineOutput(BaseOutput): 25 | """ 26 | Output class for Marigold monocular depth prediction pipeline. 27 | 28 | Args: 29 | depth_np (`np.ndarray`): 30 | Predicted depth map, with depth values in the range of [0, 1]. 31 | depth_colored (`PIL.Image.Image`): 32 | Colorized depth map, with the shape of [3, H, W] and values in [0, 1]. 33 | uncertainty (`None` or `np.ndarray`): 34 | Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. 35 | """ 36 | depth_np: np.ndarray 37 | depth_colored: Image.Image 38 | uncertainty: Union[None, np.ndarray] 39 | 40 | 41 | class DepthEstimationPipeline(DiffusionPipeline): 42 | # two hyper-parameters 43 | rgb_latent_scale_factor = 0.18215 44 | depth_latent_scale_factor = 0.18215 45 | 46 | def __init__(self, 47 | unet:UNet2DConditionModel, 48 | vae:AutoencoderKL, 49 | scheduler:DDIMScheduler, 50 | text_encoder:CLIPTextModel, 51 | tokenizer:CLIPTokenizer, 52 | ): 53 | super().__init__() 54 | 55 | self.register_modules( 56 | unet=unet, 57 | vae=vae, 58 | scheduler=scheduler, 59 | text_encoder=text_encoder, 60 | tokenizer=tokenizer, 61 | ) 62 | self.empty_text_embed = None 63 | 64 | # self.current_dtype = torch.float16 65 | 66 | 67 | 68 | 69 | @torch.no_grad() 70 | def __call__(self, 71 | input_image:Image, 72 | denosing_steps: int =10, 73 | ensemble_size: int =10, 74 | processing_res: int = 768, 75 | match_input_res:bool =True, 76 | batch_size:int =0, 77 | color_map: str="Spectral", 78 | show_progress_bar:bool = True, 79 | ensemble_kwargs: Dict = None, 80 | ) -> DepthPipelineOutput: 81 | 82 | # inherit from thea Diffusion Pipeline 83 | device = self.device 84 | input_size = input_image.size 85 | 86 | # adjust the input resolution. 
87 | if not match_input_res: 88 | assert ( 89 | processing_res is not None 90 | )," Value Error: `resize_output_back` is only valid with " 91 | 92 | assert processing_res >=0 93 | assert denosing_steps >=1 94 | assert ensemble_size >=1 95 | 96 | # --------------- Image Processing ------------------------ 97 | # Resize image 98 | if processing_res >0: 99 | input_image = resize_max_res( 100 | input_image, max_edge_resolution=processing_res 101 | ) # resize image: for kitti is 231, 768 102 | 103 | 104 | # Convert the image to RGB, to 1. reomve the alpha channel. 105 | input_image = input_image.convert("RGB") 106 | image = np.array(input_image) 107 | 108 | 109 | # Normalize RGB Values. 110 | rgb = np.transpose(image,(2,0,1)) 111 | rgb_norm = rgb / 255.0 112 | rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype) 113 | rgb_norm = rgb_norm.to(device) 114 | 115 | rgb_norm = rgb_norm.half() 116 | 117 | 118 | assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0 119 | 120 | # ----------------- predicting depth ----------------- 121 | duplicated_rgb = torch.stack([rgb_norm] * ensemble_size) 122 | single_rgb_dataset = TensorDataset(duplicated_rgb) 123 | 124 | # find the batch size 125 | if batch_size>0: 126 | _bs = batch_size 127 | else: 128 | _bs = 1 129 | 130 | single_rgb_loader = DataLoader(single_rgb_dataset,batch_size=_bs,shuffle=False) 131 | 132 | # predicted the depth 133 | depth_pred_ls = [] 134 | 135 | if show_progress_bar: 136 | iterable_bar = tqdm( 137 | single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False 138 | ) 139 | else: 140 | iterable_bar = single_rgb_loader 141 | 142 | for batch in iterable_bar: 143 | (batched_image,)= batch # here the image is still around 0-1 144 | depth_pred_raw = self.single_infer( 145 | input_rgb=batched_image, 146 | num_inference_steps=denosing_steps, 147 | show_pbar=show_progress_bar, 148 | ) 149 | depth_pred_ls.append(depth_pred_raw.detach().clone()) 150 | 151 | depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() #(10,224,768) 152 | torch.cuda.empty_cache() # clear vram cache for ensembling 153 | 154 | 155 | # ----------------- Test-time ensembling ----------------- 156 | if ensemble_size > 1: 157 | depth_pred, pred_uncert = ensemble_depths( 158 | depth_preds, **(ensemble_kwargs or {}) 159 | ) 160 | else: 161 | depth_pred = depth_preds 162 | pred_uncert = None 163 | 164 | # ----------------- Post processing ----------------- 165 | # Scale prediction to [0, 1] 166 | min_d = torch.min(depth_pred) 167 | max_d = torch.max(depth_pred) 168 | depth_pred = (depth_pred - min_d) / (max_d - min_d) 169 | 170 | # Convert to numpy 171 | depth_pred = depth_pred.cpu().numpy().astype(np.float32) 172 | 173 | # Resize back to original resolution 174 | if match_input_res: 175 | pred_img = Image.fromarray(depth_pred) 176 | pred_img = pred_img.resize(input_size) 177 | depth_pred = np.asarray(pred_img) 178 | 179 | # Clip output range: current size is the original size 180 | depth_pred = depth_pred.clip(0, 1) 181 | 182 | # colorization using the KITTI Color Plan. 
183 | depth_pred_vis = depth_pred * 70 184 | disp_vis = 400/(depth_pred_vis+1e-3) 185 | disp_vis = disp_vis.clip(0,500) 186 | 187 | depth_color_pred = kitti_colormap(disp_vis) 188 | 189 | # Colorize 190 | depth_colored = colorize_depth_maps( 191 | depth_pred, 0, 1, cmap=color_map 192 | ).squeeze() # [3, H, W], value in (0, 1) 193 | depth_colored = (depth_colored * 255).astype(np.uint8) 194 | depth_colored_hwc = chw2hwc(depth_colored) 195 | depth_colored_img = Image.fromarray(depth_colored_hwc) 196 | 197 | 198 | return DepthPipelineOutput( 199 | depth_np = depth_pred, 200 | depth_colored = depth_colored_img, 201 | uncertainty=pred_uncert, 202 | ) 203 | 204 | 205 | def __encode_empty_text(self): 206 | """ 207 | Encode text embedding for empty prompt 208 | """ 209 | prompt = "" 210 | text_inputs = self.tokenizer( 211 | prompt, 212 | padding="do_not_pad", 213 | max_length=self.tokenizer.model_max_length, 214 | truncation=True, 215 | return_tensors="pt", 216 | ) 217 | text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) #[1,2] 218 | # print(text_input_ids.shape) 219 | self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) #[1,2,1024] 220 | self.empty_text_embed = self.empty_text_embed.half() 221 | 222 | 223 | @torch.no_grad() 224 | def single_infer(self,input_rgb:torch.Tensor, 225 | num_inference_steps:int, 226 | show_pbar:bool,): 227 | 228 | 229 | device = input_rgb.device 230 | 231 | # Set timesteps: inherit from the diffuison pipeline 232 | self.scheduler.set_timesteps(num_inference_steps, device=device) # here the numbers of the steps is only 10. 233 | timesteps = self.scheduler.timesteps # [T] 234 | 235 | # encode image 236 | rgb_latent = self.encode_RGB(input_rgb) # 1/8 Resolution with a channel nums of 4. 237 | 238 | 239 | # Initial depth map (Guassian noise) 240 | depth_latent = torch.randn( 241 | rgb_latent.shape, device=device, dtype=self.dtype 242 | ) # [B, 4, H/8, W/8] 243 | 244 | depth_latent = depth_latent.half() 245 | 246 | 247 | # Batched empty text embedding 248 | if self.empty_text_embed is None: 249 | self.__encode_empty_text() 250 | 251 | batch_empty_text_embed = self.empty_text_embed.repeat( 252 | (rgb_latent.shape[0], 1, 1) 253 | ) # [B, 2, 1024] 254 | 255 | # Denoising loop 256 | if show_pbar: 257 | iterable = tqdm( 258 | enumerate(timesteps), 259 | total=len(timesteps), 260 | leave=False, 261 | desc=" " * 4 + "Diffusion denoising", 262 | ) 263 | else: 264 | iterable = enumerate(timesteps) 265 | 266 | for i, t in iterable: 267 | unet_input = torch.cat( 268 | [rgb_latent, depth_latent], dim=1 269 | ) # this order is important: [1,8,H,W] 270 | 271 | # print(unet_input.shape) 272 | 273 | # predict the noise residual 274 | noise_pred = self.unet( 275 | unet_input, t, encoder_hidden_states=batch_empty_text_embed 276 | ).sample # [B, 4, h, w] 277 | 278 | # compute the previous noisy sample x_t -> x_t-1 279 | depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample 280 | 281 | torch.cuda.empty_cache() 282 | depth = self.decode_depth(depth_latent) 283 | # clip prediction 284 | depth = torch.clip(depth, -1.0, 1.0) 285 | # shift to [0, 1] 286 | depth = (depth + 1.0) / 2.0 287 | return depth 288 | 289 | 290 | def encode_RGB(self, rgb_in: torch.Tensor) -> torch.Tensor: 291 | """ 292 | Encode RGB image into latent. 293 | 294 | Args: 295 | rgb_in (`torch.Tensor`): 296 | Input RGB image to be encoded. 297 | 298 | Returns: 299 | `torch.Tensor`: Image latent. 
300 | """ 301 | 302 | 303 | # encode 304 | h = self.vae.encoder(rgb_in) 305 | 306 | moments = self.vae.quant_conv(h) 307 | mean, logvar = torch.chunk(moments, 2, dim=1) 308 | # scale latent 309 | rgb_latent = mean * self.rgb_latent_scale_factor 310 | 311 | return rgb_latent 312 | 313 | def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: 314 | """ 315 | Decode depth latent into depth map. 316 | 317 | Args: 318 | depth_latent (`torch.Tensor`): 319 | Depth latent to be decoded. 320 | 321 | Returns: 322 | `torch.Tensor`: Decoded depth map. 323 | """ 324 | # scale latent 325 | depth_latent = depth_latent / self.depth_latent_scale_factor 326 | 327 | depth_latent = depth_latent.half() 328 | # decode 329 | z = self.vae.post_quant_conv(depth_latent) 330 | stacked = self.vae.decoder(z) 331 | # mean of output channels 332 | depth_mean = stacked.mean(dim=1, keepdim=True) 333 | return depth_mean 334 | 335 | 336 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Luke Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Accelerator-Simple-Template 2 | 3 | This an example of training the [Marigold Depth Estimation](https://huggingface.co/spaces/toshas/marigold) using accelerator using the sceneflow dataset. Since the original training code is not open source, only the inference pipeline is released, so the performance is not guaranteed. BTW, Any other dataset is fine, just change the dataloader. 4 | 5 | Reference Code: [Marigold-ETH](https://github.com/prs-eth/marigold) 6 | 7 | Reference Paper: [Marigold: Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation](https://arxiv.org/abs/2312.02145) 8 | 9 | 10 | #### Run the Inference of Monodepth estimation: 11 | 12 | ``` 13 | cd scripts 14 | sh inference.sh 15 | ``` 16 | 17 | #### Run the Inference of Monodepth Training, Using SceneFlow as an example: 18 | ``` 19 | cd scripts 20 | sh train.sh 21 | ``` 22 | 23 | Note the training at least takes 21 VRAM even the batch size is set to 1. 
-------------------------------------------------------------------------------- /datafiles/KITTI/kitti_example_train.txt: -------------------------------------------------------------------------------- 1 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000000.png 2 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000001.png 3 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000002.png 4 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000003.png 5 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000004.png 6 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000005.png 7 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000006.png 8 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000007.png 9 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000008.png 10 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000009.png 11 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000010.png 12 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000011.png 13 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000012.png 14 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000013.png 15 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000014.png 16 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000015.png 17 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000016.png 18 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000017.png 19 | -------------------------------------------------------------------------------- /datafiles/KITTI/kitti_example_val.txt: -------------------------------------------------------------------------------- 1 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000007.png 2 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000008.png 3 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000015.png 4 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000016.png 5 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000017.png 6 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000018.png -------------------------------------------------------------------------------- /datafiles/filenames/kitti_example_train.txt: -------------------------------------------------------------------------------- 1 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000000.png 2 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000001.png 3 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000002.png 4 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000003.png 5 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000004.png 6 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000005.png 7 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000006.png 8 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000007.png 9 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000008.png 10 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000009.png 11 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000010.png 12 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000011.png 13 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000012.png 14 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000013.png 15 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000014.png 16 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000015.png 17 | 2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000016.png 18 | 
2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000017.png 19 | -------------------------------------------------------------------------------- /datafiles/filenames/kitti_example_val.txt: -------------------------------------------------------------------------------- 1 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000007.png 2 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000008.png 3 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000015.png 4 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000016.png 5 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000017.png 6 | 2011_09_26/2011_09_26_drive_0002_sync/image_02/data/0000000018.png -------------------------------------------------------------------------------- /datafiles/middlebury/MiddleBury_mix.list: -------------------------------------------------------------------------------- 1 | MID_2005/trainingH/Reindeer/im0.png MID_2005/trainingH/Reindeer/im1.png MID_2005/trainingH/Reindeer/disp0.pfm MID_2005/trainingH/Reindeer/occ0.npy MID_2005/trainingH/Reindeer/im0.png MID_2005/trainingH/Reindeer/im1.png MID_2005/trainingH/Reindeer/disp0.pfm 2 | MID_2005/trainingH/Books/im0.png MID_2005/trainingH/Books/im1.png MID_2005/trainingH/Books/disp0.pfm MID_2005/trainingH/Books/occ0.npy MID_2005/trainingH/Books/im0.png MID_2005/trainingH/Books/im1.png MID_2005/trainingH/Books/disp0.pfm 3 | MID_2005/trainingH/Laundry/im0.png MID_2005/trainingH/Laundry/im1.png MID_2005/trainingH/Laundry/disp0.pfm MID_2005/trainingH/Laundry/occ0.npy MID_2005/trainingH/Laundry/im0.png MID_2005/trainingH/Laundry/im1.png MID_2005/trainingH/Laundry/disp0.pfm 4 | MID_2005/trainingH/Dolls/im0.png MID_2005/trainingH/Dolls/im1.png MID_2005/trainingH/Dolls/disp0.pfm MID_2005/trainingH/Dolls/occ0.npy MID_2005/trainingH/Dolls/im0.png MID_2005/trainingH/Dolls/im1.png MID_2005/trainingH/Dolls/disp0.pfm 5 | MID_2005/trainingH/Moebius/im0.png MID_2005/trainingH/Moebius/im1.png MID_2005/trainingH/Moebius/disp0.pfm MID_2005/trainingH/Moebius/occ0.npy MID_2005/trainingH/Moebius/im0.png MID_2005/trainingH/Moebius/im1.png MID_2005/trainingH/Moebius/disp0.pfm 6 | MID_2006/trainingH/Monopoly/im0.png MID_2006/trainingH/Monopoly/im1.png MID_2006/trainingH/Monopoly/disp0.pfm MID_2006/trainingH/Monopoly/occ0.npy MID_2006/trainingH/Monopoly/im0.png MID_2006/trainingH/Monopoly/im1.png MID_2006/trainingH/Monopoly/disp0.pfm 7 | MID_2006/trainingH/Cloth2/im0.png MID_2006/trainingH/Cloth2/im1.png MID_2006/trainingH/Cloth2/disp0.pfm MID_2006/trainingH/Cloth2/occ0.npy MID_2006/trainingH/Cloth2/im0.png MID_2006/trainingH/Cloth2/im1.png MID_2006/trainingH/Cloth2/disp0.pfm 8 | MID_2006/trainingH/Baby2/im0.png MID_2006/trainingH/Baby2/im1.png MID_2006/trainingH/Baby2/disp0.pfm MID_2006/trainingH/Baby2/occ0.npy MID_2006/trainingH/Baby2/im0.png MID_2006/trainingH/Baby2/im1.png MID_2006/trainingH/Baby2/disp0.pfm 9 | MID_2006/trainingH/Bowling2/im0.png MID_2006/trainingH/Bowling2/im1.png MID_2006/trainingH/Bowling2/disp0.pfm MID_2006/trainingH/Bowling2/occ0.npy MID_2006/trainingH/Bowling2/im0.png MID_2006/trainingH/Bowling2/im1.png MID_2006/trainingH/Bowling2/disp0.pfm 10 | MID_2006/trainingH/Baby3/im0.png MID_2006/trainingH/Baby3/im1.png MID_2006/trainingH/Baby3/disp0.pfm MID_2006/trainingH/Baby3/occ0.npy MID_2006/trainingH/Baby3/im0.png MID_2006/trainingH/Baby3/im1.png MID_2006/trainingH/Baby3/disp0.pfm 11 | MID_2006/trainingH/Flowerpots/im0.png MID_2006/trainingH/Flowerpots/im1.png MID_2006/trainingH/Flowerpots/disp0.pfm 
MID_2006/trainingH/Flowerpots/occ0.npy MID_2006/trainingH/Flowerpots/im0.png MID_2006/trainingH/Flowerpots/im1.png MID_2006/trainingH/Flowerpots/disp0.pfm 12 | MID_2006/trainingH/Lampshade2/im0.png MID_2006/trainingH/Lampshade2/im1.png MID_2006/trainingH/Lampshade2/disp0.pfm MID_2006/trainingH/Lampshade2/occ0.npy MID_2006/trainingH/Lampshade2/im0.png MID_2006/trainingH/Lampshade2/im1.png MID_2006/trainingH/Lampshade2/disp0.pfm 13 | MID_2006/trainingH/Bowling1/im0.png MID_2006/trainingH/Bowling1/im1.png MID_2006/trainingH/Bowling1/disp0.pfm MID_2006/trainingH/Bowling1/occ0.npy MID_2006/trainingH/Bowling1/im0.png MID_2006/trainingH/Bowling1/im1.png MID_2006/trainingH/Bowling1/disp0.pfm 14 | MID_2006/trainingH/Baby1/im0.png MID_2006/trainingH/Baby1/im1.png MID_2006/trainingH/Baby1/disp0.pfm MID_2006/trainingH/Baby1/occ0.npy MID_2006/trainingH/Baby1/im0.png MID_2006/trainingH/Baby1/im1.png MID_2006/trainingH/Baby1/disp0.pfm 15 | MID_2006/trainingH/Cloth1/im0.png MID_2006/trainingH/Cloth1/im1.png MID_2006/trainingH/Cloth1/disp0.pfm MID_2006/trainingH/Cloth1/occ0.npy MID_2006/trainingH/Cloth1/im0.png MID_2006/trainingH/Cloth1/im1.png MID_2006/trainingH/Cloth1/disp0.pfm 16 | MID_2006/trainingH/Rocks2/im0.png MID_2006/trainingH/Rocks2/im1.png MID_2006/trainingH/Rocks2/disp0.pfm MID_2006/trainingH/Rocks2/occ0.npy MID_2006/trainingH/Rocks2/im0.png MID_2006/trainingH/Rocks2/im1.png MID_2006/trainingH/Rocks2/disp0.pfm 17 | MID_2006/trainingH/Midd1/im0.png MID_2006/trainingH/Midd1/im1.png MID_2006/trainingH/Midd1/disp0.pfm MID_2006/trainingH/Midd1/occ0.npy MID_2006/trainingH/Midd1/im0.png MID_2006/trainingH/Midd1/im1.png MID_2006/trainingH/Midd1/disp0.pfm 18 | MID_2006/trainingH/Aloe/im0.png MID_2006/trainingH/Aloe/im1.png MID_2006/trainingH/Aloe/disp0.pfm MID_2006/trainingH/Aloe/occ0.npy MID_2006/trainingH/Aloe/im0.png MID_2006/trainingH/Aloe/im1.png MID_2006/trainingH/Aloe/disp0.pfm 19 | MID_2006/trainingH/Plastic/im0.png MID_2006/trainingH/Plastic/im1.png MID_2006/trainingH/Plastic/disp0.pfm MID_2006/trainingH/Plastic/occ0.npy MID_2006/trainingH/Plastic/im0.png MID_2006/trainingH/Plastic/im1.png MID_2006/trainingH/Plastic/disp0.pfm 20 | MID_2006/trainingH/Wood2/im0.png MID_2006/trainingH/Wood2/im1.png MID_2006/trainingH/Wood2/disp0.pfm MID_2006/trainingH/Wood2/occ0.npy MID_2006/trainingH/Wood2/im0.png MID_2006/trainingH/Wood2/im1.png MID_2006/trainingH/Wood2/disp0.pfm 21 | MID_2006/trainingH/Cloth4/im0.png MID_2006/trainingH/Cloth4/im1.png MID_2006/trainingH/Cloth4/disp0.pfm MID_2006/trainingH/Cloth4/occ0.npy MID_2006/trainingH/Cloth4/im0.png MID_2006/trainingH/Cloth4/im1.png MID_2006/trainingH/Cloth4/disp0.pfm 22 | MID_2006/trainingH/Rocks1/im0.png MID_2006/trainingH/Rocks1/im1.png MID_2006/trainingH/Rocks1/disp0.pfm MID_2006/trainingH/Rocks1/occ0.npy MID_2006/trainingH/Rocks1/im0.png MID_2006/trainingH/Rocks1/im1.png MID_2006/trainingH/Rocks1/disp0.pfm 23 | MID_2006/trainingH/Lampshade1/im0.png MID_2006/trainingH/Lampshade1/im1.png MID_2006/trainingH/Lampshade1/disp0.pfm MID_2006/trainingH/Lampshade1/occ0.npy MID_2006/trainingH/Lampshade1/im0.png MID_2006/trainingH/Lampshade1/im1.png MID_2006/trainingH/Lampshade1/disp0.pfm 24 | MID_2006/trainingH/Cloth3/im0.png MID_2006/trainingH/Cloth3/im1.png MID_2006/trainingH/Cloth3/disp0.pfm MID_2006/trainingH/Cloth3/occ0.npy MID_2006/trainingH/Cloth3/im0.png MID_2006/trainingH/Cloth3/im1.png MID_2006/trainingH/Cloth3/disp0.pfm 25 | MID_2006/trainingH/Midd2/im0.png MID_2006/trainingH/Midd2/im1.png MID_2006/trainingH/Midd2/disp0.pfm 
MID_2006/trainingH/Midd2/occ0.npy MID_2006/trainingH/Midd2/im0.png MID_2006/trainingH/Midd2/im1.png MID_2006/trainingH/Midd2/disp0.pfm 26 | MID_2006/trainingH/Wood1/im0.png MID_2006/trainingH/Wood1/im1.png MID_2006/trainingH/Wood1/disp0.pfm MID_2006/trainingH/Wood1/occ0.npy MID_2006/trainingH/Wood1/im0.png MID_2006/trainingH/Wood1/im1.png MID_2006/trainingH/Wood1/disp0.pfm 27 | MID_EVAL/trainingH/Pipes/im0.png MID_EVAL/trainingH/Pipes/im1.png MID_EVAL/trainingH/Pipes/disp0.pfm MID_EVAL/trainingH/Pipes/occ0.npy MID_EVAL/trainingH/Pipes/im0.png MID_EVAL/trainingH/Pipes/im1.png MID_EVAL/trainingH/Pipes/disp0.pfm 28 | MID_EVAL/trainingH/Teddy/im0.png MID_EVAL/trainingH/Teddy/im1.png MID_EVAL/trainingH/Teddy/disp0.pfm MID_EVAL/trainingH/Teddy/occ0.npy MID_EVAL/trainingH/Teddy/im0.png MID_EVAL/trainingH/Teddy/im1.png MID_EVAL/trainingH/Teddy/disp0.pfm 29 | MID_EVAL/trainingH/Jadeplant/im0.png MID_EVAL/trainingH/Jadeplant/im1.png MID_EVAL/trainingH/Jadeplant/disp0.pfm MID_EVAL/trainingH/Jadeplant/occ0.npy MID_EVAL/trainingH/Jadeplant/im0.png MID_EVAL/trainingH/Jadeplant/im1.png MID_EVAL/trainingH/Jadeplant/disp0.pfm 30 | MID_EVAL/trainingH/Adirondack/im0.png MID_EVAL/trainingH/Adirondack/im1.png MID_EVAL/trainingH/Adirondack/disp0.pfm MID_EVAL/trainingH/Adirondack/occ0.npy MID_EVAL/trainingH/Adirondack/im0.png MID_EVAL/trainingH/Adirondack/im1.png MID_EVAL/trainingH/Adirondack/disp0.pfm 31 | MID_EVAL/trainingH/MotorcycleE/im0.png MID_EVAL/trainingH/MotorcycleE/im1.png MID_EVAL/trainingH/MotorcycleE/disp0.pfm MID_EVAL/trainingH/MotorcycleE/occ0.npy MID_EVAL/trainingH/MotorcycleE/im0.png MID_EVAL/trainingH/MotorcycleE/im1.png MID_EVAL/trainingH/MotorcycleE/disp0.pfm 32 | MID_EVAL/trainingH/Piano/im0.png MID_EVAL/trainingH/Piano/im1.png MID_EVAL/trainingH/Piano/disp0.pfm MID_EVAL/trainingH/Piano/occ0.npy MID_EVAL/trainingH/Piano/im0.png MID_EVAL/trainingH/Piano/im1.png MID_EVAL/trainingH/Piano/disp0.pfm 33 | MID_EVAL/trainingH/PianoL/im0.png MID_EVAL/trainingH/PianoL/im1.png MID_EVAL/trainingH/PianoL/disp0.pfm MID_EVAL/trainingH/PianoL/occ0.npy MID_EVAL/trainingH/PianoL/im0.png MID_EVAL/trainingH/PianoL/im1.png MID_EVAL/trainingH/PianoL/disp0.pfm 34 | MID_EVAL/trainingH/Motorcycle/im0.png MID_EVAL/trainingH/Motorcycle/im1.png MID_EVAL/trainingH/Motorcycle/disp0.pfm MID_EVAL/trainingH/Motorcycle/occ0.npy MID_EVAL/trainingH/Motorcycle/im0.png MID_EVAL/trainingH/Motorcycle/im1.png MID_EVAL/trainingH/Motorcycle/disp0.pfm 35 | MID_EVAL/trainingH/PlaytableP/im0.png MID_EVAL/trainingH/PlaytableP/im1.png MID_EVAL/trainingH/PlaytableP/disp0.pfm MID_EVAL/trainingH/PlaytableP/occ0.npy MID_EVAL/trainingH/PlaytableP/im0.png MID_EVAL/trainingH/PlaytableP/im1.png MID_EVAL/trainingH/PlaytableP/disp0.pfm 36 | MID_EVAL/trainingH/ArtL/im0.png MID_EVAL/trainingH/ArtL/im1.png MID_EVAL/trainingH/ArtL/disp0.pfm MID_EVAL/trainingH/ArtL/occ0.npy MID_EVAL/trainingH/ArtL/im0.png MID_EVAL/trainingH/ArtL/im1.png MID_EVAL/trainingH/ArtL/disp0.pfm 37 | MID_EVAL/trainingH/Recycle/im0.png MID_EVAL/trainingH/Recycle/im1.png MID_EVAL/trainingH/Recycle/disp0.pfm MID_EVAL/trainingH/Recycle/occ0.npy MID_EVAL/trainingH/Recycle/im0.png MID_EVAL/trainingH/Recycle/im1.png MID_EVAL/trainingH/Recycle/disp0.pfm 38 | MID_EVAL/trainingH/Shelves/im0.png MID_EVAL/trainingH/Shelves/im1.png MID_EVAL/trainingH/Shelves/disp0.pfm MID_EVAL/trainingH/Shelves/occ0.npy MID_EVAL/trainingH/Shelves/im0.png MID_EVAL/trainingH/Shelves/im1.png MID_EVAL/trainingH/Shelves/disp0.pfm 39 | MID_EVAL/trainingH/Playroom/im0.png MID_EVAL/trainingH/Playroom/im1.png 
MID_EVAL/trainingH/Playroom/disp0.pfm MID_EVAL/trainingH/Playroom/occ0.npy MID_EVAL/trainingH/Playroom/im0.png MID_EVAL/trainingH/Playroom/im1.png MID_EVAL/trainingH/Playroom/disp0.pfm 40 | MID_EVAL/trainingH/Playtable/im0.png MID_EVAL/trainingH/Playtable/im1.png MID_EVAL/trainingH/Playtable/disp0.pfm MID_EVAL/trainingH/Playtable/occ0.npy MID_EVAL/trainingH/Playtable/im0.png MID_EVAL/trainingH/Playtable/im1.png MID_EVAL/trainingH/Playtable/disp0.pfm 41 | MID_EVAL/trainingH/Vintage/im0.png MID_EVAL/trainingH/Vintage/im1.png MID_EVAL/trainingH/Vintage/disp0.pfm MID_EVAL/trainingH/Vintage/occ0.npy MID_EVAL/trainingH/Vintage/im0.png MID_EVAL/trainingH/Vintage/im1.png MID_EVAL/trainingH/Vintage/disp0.pfm 42 | MID_2014/trainingH/Storage-perfect/im0.png MID_2014/trainingH/Storage-perfect/im1.png MID_2014/trainingH/Storage-perfect/disp0.pfm MID_2014/trainingH/Storage-perfect/occ0.npy MID_2014/trainingH/Storage-perfect/im0.png MID_2014/trainingH/Storage-perfect/im1.png MID_2014/trainingH/Storage-perfect/disp0.pfm 43 | MID_2014/trainingH/Bicycle1-perfect/im0.png MID_2014/trainingH/Bicycle1-perfect/im1.png MID_2014/trainingH/Bicycle1-perfect/disp0.pfm MID_2014/trainingH/Bicycle1-perfect/occ0.npy MID_2014/trainingH/Bicycle1-perfect/im0.png MID_2014/trainingH/Bicycle1-perfect/im1.png MID_2014/trainingH/Bicycle1-perfect/disp0.pfm 44 | MID_2014/trainingH/Flowers-perfect/im0.png MID_2014/trainingH/Flowers-perfect/im1.png MID_2014/trainingH/Flowers-perfect/disp0.pfm MID_2014/trainingH/Flowers-perfect/occ0.npy MID_2014/trainingH/Flowers-perfect/im0.png MID_2014/trainingH/Flowers-perfect/im1.png MID_2014/trainingH/Flowers-perfect/disp0.pfm 45 | MID_2014/trainingH/Shopvac-perfect/im0.png MID_2014/trainingH/Shopvac-perfect/im1.png MID_2014/trainingH/Shopvac-perfect/disp0.pfm MID_2014/trainingH/Shopvac-perfect/occ0.npy MID_2014/trainingH/Shopvac-perfect/im0.png MID_2014/trainingH/Shopvac-perfect/im1.png MID_2014/trainingH/Shopvac-perfect/disp0.pfm 46 | MID_2014/trainingH/Umbrella-perfect/im0.png MID_2014/trainingH/Umbrella-perfect/im1.png MID_2014/trainingH/Umbrella-perfect/disp0.pfm MID_2014/trainingH/Umbrella-perfect/occ0.npy MID_2014/trainingH/Umbrella-perfect/im0.png MID_2014/trainingH/Umbrella-perfect/im1.png MID_2014/trainingH/Umbrella-perfect/disp0.pfm 47 | MID_2014/trainingH/Sticks-perfect/im0.png MID_2014/trainingH/Sticks-perfect/im1.png MID_2014/trainingH/Sticks-perfect/disp0.pfm MID_2014/trainingH/Sticks-perfect/occ0.npy MID_2014/trainingH/Sticks-perfect/im0.png MID_2014/trainingH/Sticks-perfect/im1.png MID_2014/trainingH/Sticks-perfect/disp0.pfm 48 | MID_2014/trainingH/Couch-perfect/im0.png MID_2014/trainingH/Couch-perfect/im1.png MID_2014/trainingH/Couch-perfect/disp0.pfm MID_2014/trainingH/Couch-perfect/occ0.npy MID_2014/trainingH/Couch-perfect/im0.png MID_2014/trainingH/Couch-perfect/im1.png MID_2014/trainingH/Couch-perfect/disp0.pfm 49 | MID_2014/trainingH/Sword2-perfect/im0.png MID_2014/trainingH/Sword2-perfect/im1.png MID_2014/trainingH/Sword2-perfect/disp0.pfm MID_2014/trainingH/Sword2-perfect/occ0.npy MID_2014/trainingH/Sword2-perfect/im0.png MID_2014/trainingH/Sword2-perfect/im1.png MID_2014/trainingH/Sword2-perfect/disp0.pfm 50 | MID_2014/trainingH/Sword1-perfect/im0.png MID_2014/trainingH/Sword1-perfect/im1.png MID_2014/trainingH/Sword1-perfect/disp0.pfm MID_2014/trainingH/Sword1-perfect/occ0.npy MID_2014/trainingH/Sword1-perfect/im0.png MID_2014/trainingH/Sword1-perfect/im1.png MID_2014/trainingH/Sword1-perfect/disp0.pfm 51 | MID_2014/trainingH/Classroom1-perfect/im0.png 
MID_2014/trainingH/Classroom1-perfect/im1.png MID_2014/trainingH/Classroom1-perfect/disp0.pfm MID_2014/trainingH/Classroom1-perfect/occ0.npy MID_2014/trainingH/Classroom1-perfect/im0.png MID_2014/trainingH/Classroom1-perfect/im1.png MID_2014/trainingH/Classroom1-perfect/disp0.pfm 52 | MID_2014/trainingH/Mask-perfect/im0.png MID_2014/trainingH/Mask-perfect/im1.png MID_2014/trainingH/Mask-perfect/disp0.pfm MID_2014/trainingH/Mask-perfect/occ0.npy MID_2014/trainingH/Mask-perfect/im0.png MID_2014/trainingH/Mask-perfect/im1.png MID_2014/trainingH/Mask-perfect/disp0.pfm 53 | MID_2014/trainingH/Cable-perfect/im0.png MID_2014/trainingH/Cable-perfect/im1.png MID_2014/trainingH/Cable-perfect/disp0.pfm MID_2014/trainingH/Cable-perfect/occ0.npy MID_2014/trainingH/Cable-perfect/im0.png MID_2014/trainingH/Cable-perfect/im1.png MID_2014/trainingH/Cable-perfect/disp0.pfm 54 | MID_2014/trainingH/Backpack-perfect/im0.png MID_2014/trainingH/Backpack-perfect/im1.png MID_2014/trainingH/Backpack-perfect/disp0.pfm MID_2014/trainingH/Backpack-perfect/occ0.npy MID_2014/trainingH/Backpack-perfect/im0.png MID_2014/trainingH/Backpack-perfect/im1.png MID_2014/trainingH/Backpack-perfect/disp0.pfm 55 | -------------------------------------------------------------------------------- /datafiles/middlebury/Middleburry.list: -------------------------------------------------------------------------------- 1 | MiddEval3/trainingH/Adirondack/im0.png MiddEval3/trainingH/Adirondack/im1.png MiddEval3/disp/Adirondack/disp0GT.pfm 2 | MiddEval3/trainingH/ArtL/im0.png MiddEval3/trainingH/ArtL/im1.png MiddEval3/disp/ArtL/disp0GT.pfm 3 | MiddEval3/trainingH/Jadeplant/im0.png MiddEval3/trainingH/Jadeplant/im1.png MiddEval3/disp/Jadeplant/disp0GT.pfm 4 | MiddEval3/trainingH/Motorcycle/im0.png MiddEval3/trainingH/Motorcycle/im1.png MiddEval3/disp/Motorcycle/disp0GT.pfm 5 | MiddEval3/trainingH/MotorcycleE/im0.png MiddEval3/trainingH/MotorcycleE/im1.png MiddEval3/disp/MotorcycleE/disp0GT.pfm 6 | MiddEval3/trainingH/Piano/im0.png MiddEval3/trainingH/Piano/im1.png MiddEval3/disp/Piano/disp0GT.pfm 7 | MiddEval3/trainingH/PianoL/im0.png MiddEval3/trainingH/PianoL/im1.png MiddEval3/disp/PianoL/disp0GT.pfm 8 | MiddEval3/trainingH/Pipes/im0.png MiddEval3/trainingH/Pipes/im1.png MiddEval3/disp/Pipes/disp0GT.pfm 9 | MiddEval3/trainingH/Playroom/im0.png MiddEval3/trainingH/Playroom/im1.png MiddEval3/disp/Playroom/disp0GT.pfm 10 | MiddEval3/trainingH/Playtable/im0.png MiddEval3/trainingH/Playtable/im1.png MiddEval3/disp/Playtable/disp0GT.pfm 11 | MiddEval3/trainingH/PlaytableP/im0.png MiddEval3/trainingH/PlaytableP/im1.png MiddEval3/disp/PlaytableP/disp0GT.pfm 12 | MiddEval3/trainingH/Recycle/im0.png MiddEval3/trainingH/Recycle/im1.png MiddEval3/disp/Recycle/disp0GT.pfm 13 | MiddEval3/trainingH/Shelves/im0.png MiddEval3/trainingH/Shelves/im1.png MiddEval3/disp/Shelves/disp0GT.pfm 14 | MiddEval3/trainingH/Teddy/im0.png MiddEval3/trainingH/Teddy/im1.png MiddEval3/disp/Teddy/disp0GT.pfm 15 | MiddEval3/trainingH/Vintage/im0.png MiddEval3/trainingH/Vintage/im1.png MiddEval3/disp/Vintage/disp0GT.pfm 16 | MiddEval3/trainingH/Jadeplant/im0.png MiddEval3/trainingH/Jadeplant/im1.png MiddEval3/disp/Jadeplant/disp0GT.pfm 17 | MiddEval3/trainingH/Piano/im0.png MiddEval3/trainingH/Piano/im1.png MiddEval3/disp/Piano/disp0GT.pfm 18 | MiddEval3/trainingH/Shelves/im0.png MiddEval3/trainingH/Shelves/im1.png MiddEval3/disp/Shelves/disp0GT.pfm -------------------------------------------------------------------------------- /datafiles/middlebury/MiddleburryV3_val.list: 
-------------------------------------------------------------------------------- 1 | MiddEval3/trainingH/Jadeplant/im0.png MiddEval3/trainingH/Jadeplant/im1.png MiddEval3/disp/Jadeplant/disp0GT.pfm 2 | MiddEval3/trainingH/Piano/im0.png MiddEval3/trainingH/Piano/im1.png MiddEval3/disp/Piano/disp0GT.pfm 3 | MiddEval3/trainingH/Shelves/im0.png MiddEval3/trainingH/Shelves/im1.png MiddEval3/disp/Shelves/disp0GT.pfm 4 | -------------------------------------------------------------------------------- /datafiles/middlebury/middleburry2003_train.list: -------------------------------------------------------------------------------- 1 | MID_2003/trainingH/teddyH/im0.png MID_2003/trainingH/teddyH/im1.png MID_2003/trainingH/teddyH/disp0.pfm MID_2003/trainingH/teddyH/occ0.npy 2 | MID_2003/trainingH/conesH/im0.png MID_2003/trainingH/conesH/im1.png MID_2003/trainingH/conesH/disp0.pfm MID_2003/trainingH/conesH/occ0.npy -------------------------------------------------------------------------------- /datafiles/middlebury/middleburry2005_train.list: -------------------------------------------------------------------------------- 1 | MID_2005/trainingH/Reindeer/im0.png MID_2005/trainingH/Reindeer/im1.png MID_2005/trainingH/Reindeer/disp0.pfm MID_2005/trainingH/Reindeer/occ0.npy 2 | MID_2005/trainingH/Books/im0.png MID_2005/trainingH/Books/im1.png MID_2005/trainingH/Books/disp0.pfm MID_2005/trainingH/Books/occ0.npy 3 | MID_2005/trainingH/Laundry/im0.png MID_2005/trainingH/Laundry/im1.png MID_2005/trainingH/Laundry/disp0.pfm MID_2005/trainingH/Laundry/occ0.npy 4 | MID_2005/trainingH/Dolls/im0.png MID_2005/trainingH/Dolls/im1.png MID_2005/trainingH/Dolls/disp0.pfm MID_2005/trainingH/Dolls/occ0.npy 5 | MID_2005/trainingH/Art/im0.png MID_2005/trainingH/Art/im1.png MID_2005/trainingH/Art/disp0.pfm MID_2005/trainingH/Art/occ0.npy 6 | MID_2005/trainingH/Moebius/im0.png MID_2005/trainingH/Moebius/im1.png MID_2005/trainingH/Moebius/disp0.pfm MID_2005/trainingH/Moebius/occ0.npy -------------------------------------------------------------------------------- /datafiles/middlebury/middleburry2006_train.list: -------------------------------------------------------------------------------- 1 | MID_2006/trainingH/Monopoly/im0.png MID_2006/trainingH/Monopoly/im1.png MID_2006/trainingH/Monopoly/disp0.pfm MID_2006/trainingH/Monopoly/occ0.npy 2 | MID_2006/trainingH/Cloth2/im0.png MID_2006/trainingH/Cloth2/im1.png MID_2006/trainingH/Cloth2/disp0.pfm MID_2006/trainingH/Cloth2/occ0.npy 3 | MID_2006/trainingH/Baby2/im0.png MID_2006/trainingH/Baby2/im1.png MID_2006/trainingH/Baby2/disp0.pfm MID_2006/trainingH/Baby2/occ0.npy 4 | MID_2006/trainingH/Bowling2/im0.png MID_2006/trainingH/Bowling2/im1.png MID_2006/trainingH/Bowling2/disp0.pfm MID_2006/trainingH/Bowling2/occ0.npy 5 | MID_2006/trainingH/Baby3/im0.png MID_2006/trainingH/Baby3/im1.png MID_2006/trainingH/Baby3/disp0.pfm MID_2006/trainingH/Baby3/occ0.npy 6 | MID_2006/trainingH/Flowerpots/im0.png MID_2006/trainingH/Flowerpots/im1.png MID_2006/trainingH/Flowerpots/disp0.pfm MID_2006/trainingH/Flowerpots/occ0.npy 7 | MID_2006/trainingH/Lampshade2/im0.png MID_2006/trainingH/Lampshade2/im1.png MID_2006/trainingH/Lampshade2/disp0.pfm MID_2006/trainingH/Lampshade2/occ0.npy 8 | MID_2006/trainingH/Bowling1/im0.png MID_2006/trainingH/Bowling1/im1.png MID_2006/trainingH/Bowling1/disp0.pfm MID_2006/trainingH/Bowling1/occ0.npy 9 | MID_2006/trainingH/Baby1/im0.png MID_2006/trainingH/Baby1/im1.png MID_2006/trainingH/Baby1/disp0.pfm MID_2006/trainingH/Baby1/occ0.npy 10 | 
MID_2006/trainingH/Cloth1/im0.png MID_2006/trainingH/Cloth1/im1.png MID_2006/trainingH/Cloth1/disp0.pfm MID_2006/trainingH/Cloth1/occ0.npy 11 | MID_2006/trainingH/Rocks2/im0.png MID_2006/trainingH/Rocks2/im1.png MID_2006/trainingH/Rocks2/disp0.pfm MID_2006/trainingH/Rocks2/occ0.npy 12 | MID_2006/trainingH/Midd1/im0.png MID_2006/trainingH/Midd1/im1.png MID_2006/trainingH/Midd1/disp0.pfm MID_2006/trainingH/Midd1/occ0.npy 13 | MID_2006/trainingH/Aloe/im0.png MID_2006/trainingH/Aloe/im1.png MID_2006/trainingH/Aloe/disp0.pfm MID_2006/trainingH/Aloe/occ0.npy 14 | MID_2006/trainingH/Plastic/im0.png MID_2006/trainingH/Plastic/im1.png MID_2006/trainingH/Plastic/disp0.pfm MID_2006/trainingH/Plastic/occ0.npy 15 | MID_2006/trainingH/Wood2/im0.png MID_2006/trainingH/Wood2/im1.png MID_2006/trainingH/Wood2/disp0.pfm MID_2006/trainingH/Wood2/occ0.npy 16 | MID_2006/trainingH/Cloth4/im0.png MID_2006/trainingH/Cloth4/im1.png MID_2006/trainingH/Cloth4/disp0.pfm MID_2006/trainingH/Cloth4/occ0.npy 17 | MID_2006/trainingH/Rocks1/im0.png MID_2006/trainingH/Rocks1/im1.png MID_2006/trainingH/Rocks1/disp0.pfm MID_2006/trainingH/Rocks1/occ0.npy 18 | MID_2006/trainingH/Lampshade1/im0.png MID_2006/trainingH/Lampshade1/im1.png MID_2006/trainingH/Lampshade1/disp0.pfm MID_2006/trainingH/Lampshade1/occ0.npy 19 | MID_2006/trainingH/Cloth3/im0.png MID_2006/trainingH/Cloth3/im1.png MID_2006/trainingH/Cloth3/disp0.pfm MID_2006/trainingH/Cloth3/occ0.npy 20 | MID_2006/trainingH/Midd2/im0.png MID_2006/trainingH/Midd2/im1.png MID_2006/trainingH/Midd2/disp0.pfm MID_2006/trainingH/Midd2/occ0.npy 21 | MID_2006/trainingH/Wood1/im0.png MID_2006/trainingH/Wood1/im1.png MID_2006/trainingH/Wood1/disp0.pfm MID_2006/trainingH/Wood1/occ0.npy -------------------------------------------------------------------------------- /datafiles/middlebury/middleburry2014_train.list: -------------------------------------------------------------------------------- 1 | MID_2014/trainingH/Adirondack-perfect/im0.png MID_2014/trainingH/Adirondack-perfect/im1.png MID_2014/trainingH/Adirondack-perfect/disp0.pfm MID_2014/trainingH/Adirondack-perfect/occ0.npy 2 | MID_2014/trainingH/Piano-perfect/im0.png MID_2014/trainingH/Piano-perfect/im1.png MID_2014/trainingH/Piano-perfect/disp0.pfm MID_2014/trainingH/Piano-perfect/occ0.npy 3 | MID_2014/trainingH/Vintage-perfect/im0.png MID_2014/trainingH/Vintage-perfect/im1.png MID_2014/trainingH/Vintage-perfect/disp0.pfm MID_2014/trainingH/Vintage-perfect/occ0.npy 4 | MID_2014/trainingH/Storage-perfect/im0.png MID_2014/trainingH/Storage-perfect/im1.png MID_2014/trainingH/Storage-perfect/disp0.pfm MID_2014/trainingH/Storage-perfect/occ0.npy 5 | MID_2014/trainingH/Bicycle1-perfect/im0.png MID_2014/trainingH/Bicycle1-perfect/im1.png MID_2014/trainingH/Bicycle1-perfect/disp0.pfm MID_2014/trainingH/Bicycle1-perfect/occ0.npy 6 | MID_2014/trainingH/Flowers-perfect/im0.png MID_2014/trainingH/Flowers-perfect/im1.png MID_2014/trainingH/Flowers-perfect/disp0.pfm MID_2014/trainingH/Flowers-perfect/occ0.npy 7 | MID_2014/trainingH/Shopvac-perfect/im0.png MID_2014/trainingH/Shopvac-perfect/im1.png MID_2014/trainingH/Shopvac-perfect/disp0.pfm MID_2014/trainingH/Shopvac-perfect/occ0.npy 8 | MID_2014/trainingH/Playroom-perfect/im0.png MID_2014/trainingH/Playroom-perfect/im1.png MID_2014/trainingH/Playroom-perfect/disp0.pfm MID_2014/trainingH/Playroom-perfect/occ0.npy 9 | MID_2014/trainingH/Umbrella-perfect/im0.png MID_2014/trainingH/Umbrella-perfect/im1.png MID_2014/trainingH/Umbrella-perfect/disp0.pfm 
MID_2014/trainingH/Umbrella-perfect/occ0.npy 10 | MID_2014/trainingH/Sticks-perfect/im0.png MID_2014/trainingH/Sticks-perfect/im1.png MID_2014/trainingH/Sticks-perfect/disp0.pfm MID_2014/trainingH/Sticks-perfect/occ0.npy 11 | MID_2014/trainingH/Couch-perfect/im0.png MID_2014/trainingH/Couch-perfect/im1.png MID_2014/trainingH/Couch-perfect/disp0.pfm MID_2014/trainingH/Couch-perfect/occ0.npy 12 | MID_2014/trainingH/Recycle-perfect/im0.png MID_2014/trainingH/Recycle-perfect/im1.png MID_2014/trainingH/Recycle-perfect/disp0.pfm MID_2014/trainingH/Recycle-perfect/occ0.npy 13 | MID_2014/trainingH/Sword2-perfect/im0.png MID_2014/trainingH/Sword2-perfect/im1.png MID_2014/trainingH/Sword2-perfect/disp0.pfm MID_2014/trainingH/Sword2-perfect/occ0.npy 14 | MID_2014/trainingH/Motorcycle-perfect/im0.png MID_2014/trainingH/Motorcycle-perfect/im1.png MID_2014/trainingH/Motorcycle-perfect/disp0.pfm MID_2014/trainingH/Motorcycle-perfect/occ0.npy 15 | MID_2014/trainingH/Sword1-perfect/im0.png MID_2014/trainingH/Sword1-perfect/im1.png MID_2014/trainingH/Sword1-perfect/disp0.pfm MID_2014/trainingH/Sword1-perfect/occ0.npy 16 | MID_2014/trainingH/Classroom1-perfect/im0.png MID_2014/trainingH/Classroom1-perfect/im1.png MID_2014/trainingH/Classroom1-perfect/disp0.pfm MID_2014/trainingH/Classroom1-perfect/occ0.npy 17 | MID_2014/trainingH/Mask-perfect/im0.png MID_2014/trainingH/Mask-perfect/im1.png MID_2014/trainingH/Mask-perfect/disp0.pfm MID_2014/trainingH/Mask-perfect/occ0.npy 18 | MID_2014/trainingH/Shelves-perfect/im0.png MID_2014/trainingH/Shelves-perfect/im1.png MID_2014/trainingH/Shelves-perfect/disp0.pfm MID_2014/trainingH/Shelves-perfect/occ0.npy 19 | MID_2014/trainingH/Cable-perfect/im0.png MID_2014/trainingH/Cable-perfect/im1.png MID_2014/trainingH/Cable-perfect/disp0.pfm MID_2014/trainingH/Cable-perfect/occ0.npy 20 | MID_2014/trainingH/Pipes-perfect/im0.png MID_2014/trainingH/Pipes-perfect/im1.png MID_2014/trainingH/Pipes-perfect/disp0.pfm MID_2014/trainingH/Pipes-perfect/occ0.npy 21 | MID_2014/trainingH/Playtable-perfect/im0.png MID_2014/trainingH/Playtable-perfect/im1.png MID_2014/trainingH/Playtable-perfect/disp0.pfm MID_2014/trainingH/Playtable-perfect/occ0.npy 22 | MID_2014/trainingH/Backpack-perfect/im0.png MID_2014/trainingH/Backpack-perfect/im1.png MID_2014/trainingH/Backpack-perfect/disp0.pfm MID_2014/trainingH/Backpack-perfect/occ0.npy 23 | MID_2014/trainingH/Jadeplant-perfect/im0.png MID_2014/trainingH/Jadeplant-perfect/im1.png MID_2014/trainingH/Jadeplant-perfect/disp0.pfm MID_2014/trainingH/Jadeplant-perfect/occ0.npy -------------------------------------------------------------------------------- /datafiles/middlebury/middleburryV3_train.list: -------------------------------------------------------------------------------- 1 | MID_EVAL/trainingH/Pipes/im0.png MID_EVAL/trainingH/Pipes/im1.png MID_EVAL/trainingH/Pipes/disp0.pfm MID_EVAL/trainingH/Pipes/occ0.npy 2 | MID_EVAL/trainingH/Teddy/im0.png MID_EVAL/trainingH/Teddy/im1.png MID_EVAL/trainingH/Teddy/disp0.pfm MID_EVAL/trainingH/Teddy/occ0.npy 3 | MID_EVAL/trainingH/Jadeplant/im0.png MID_EVAL/trainingH/Jadeplant/im1.png MID_EVAL/trainingH/Jadeplant/disp0.pfm MID_EVAL/trainingH/Jadeplant/occ0.npy 4 | MID_EVAL/trainingH/Adirondack/im0.png MID_EVAL/trainingH/Adirondack/im1.png MID_EVAL/trainingH/Adirondack/disp0.pfm MID_EVAL/trainingH/Adirondack/occ0.npy 5 | MID_EVAL/trainingH/MotorcycleE/im0.png MID_EVAL/trainingH/MotorcycleE/im1.png MID_EVAL/trainingH/MotorcycleE/disp0.pfm MID_EVAL/trainingH/MotorcycleE/occ0.npy 6 | 
MID_EVAL/trainingH/Piano/im0.png MID_EVAL/trainingH/Piano/im1.png MID_EVAL/trainingH/Piano/disp0.pfm MID_EVAL/trainingH/Piano/occ0.npy 7 | MID_EVAL/trainingH/PianoL/im0.png MID_EVAL/trainingH/PianoL/im1.png MID_EVAL/trainingH/PianoL/disp0.pfm MID_EVAL/trainingH/PianoL/occ0.npy 8 | MID_EVAL/trainingH/Motorcycle/im0.png MID_EVAL/trainingH/Motorcycle/im1.png MID_EVAL/trainingH/Motorcycle/disp0.pfm MID_EVAL/trainingH/Motorcycle/occ0.npy 9 | MID_EVAL/trainingH/PlaytableP/im0.png MID_EVAL/trainingH/PlaytableP/im1.png MID_EVAL/trainingH/PlaytableP/disp0.pfm MID_EVAL/trainingH/PlaytableP/occ0.npy 10 | MID_EVAL/trainingH/ArtL/im0.png MID_EVAL/trainingH/ArtL/im1.png MID_EVAL/trainingH/ArtL/disp0.pfm MID_EVAL/trainingH/ArtL/occ0.npy 11 | MID_EVAL/trainingH/Recycle/im0.png MID_EVAL/trainingH/Recycle/im1.png MID_EVAL/trainingH/Recycle/disp0.pfm MID_EVAL/trainingH/Recycle/occ0.npy 12 | MID_EVAL/trainingH/Shelves/im0.png MID_EVAL/trainingH/Shelves/im1.png MID_EVAL/trainingH/Shelves/disp0.pfm MID_EVAL/trainingH/Shelves/occ0.npy 13 | MID_EVAL/trainingH/Playroom/im0.png MID_EVAL/trainingH/Playroom/im1.png MID_EVAL/trainingH/Playroom/disp0.pfm MID_EVAL/trainingH/Playroom/occ0.npy 14 | MID_EVAL/trainingH/Playtable/im0.png MID_EVAL/trainingH/Playtable/im1.png MID_EVAL/trainingH/Playtable/disp0.pfm MID_EVAL/trainingH/Playtable/occ0.npy 15 | MID_EVAL/trainingH/Vintage/im0.png MID_EVAL/trainingH/Vintage/im1.png MID_EVAL/trainingH/Vintage/disp0.pfm MID_EVAL/trainingH/Vintage/occ0.npy -------------------------------------------------------------------------------- /datafiles/middlebury/middleburry_2014_mix.list: -------------------------------------------------------------------------------- 1 | MID_EVAL/trainingH/Pipes/im0.png MID_EVAL/trainingH/Pipes/im1.png MID_EVAL/trainingH/Pipes/disp0.pfm MID_EVAL/trainingH/Pipes/occ0.npy MID_EVAL/trainingH/Pipes/im0.png MID_EVAL/trainingH/Pipes/im1.png MID_EVAL/trainingH/Pipes/disp0.pfm 2 | MID_EVAL/trainingH/Teddy/im0.png MID_EVAL/trainingH/Teddy/im1.png MID_EVAL/trainingH/Teddy/disp0.pfm MID_EVAL/trainingH/Teddy/occ0.npy MID_EVAL/trainingH/Teddy/im0.png MID_EVAL/trainingH/Teddy/im1.png MID_EVAL/trainingH/Teddy/disp0.pfm 3 | MID_EVAL/trainingH/Jadeplant/im0.png MID_EVAL/trainingH/Jadeplant/im1.png MID_EVAL/trainingH/Jadeplant/disp0.pfm MID_EVAL/trainingH/Jadeplant/occ0.npy MID_EVAL/trainingH/Jadeplant/im0.png MID_EVAL/trainingH/Jadeplant/im1.png MID_EVAL/trainingH/Jadeplant/disp0.pfm 4 | MID_EVAL/trainingH/Adirondack/im0.png MID_EVAL/trainingH/Adirondack/im1.png MID_EVAL/trainingH/Adirondack/disp0.pfm MID_EVAL/trainingH/Adirondack/occ0.npy MID_EVAL/trainingH/Adirondack/im0.png MID_EVAL/trainingH/Adirondack/im1.png MID_EVAL/trainingH/Adirondack/disp0.pfm 5 | MID_EVAL/trainingH/MotorcycleE/im0.png MID_EVAL/trainingH/MotorcycleE/im1.png MID_EVAL/trainingH/MotorcycleE/disp0.pfm MID_EVAL/trainingH/MotorcycleE/occ0.npy MID_EVAL/trainingH/MotorcycleE/im0.png MID_EVAL/trainingH/MotorcycleE/im1.png MID_EVAL/trainingH/MotorcycleE/disp0.pfm 6 | MID_EVAL/trainingH/Piano/im0.png MID_EVAL/trainingH/Piano/im1.png MID_EVAL/trainingH/Piano/disp0.pfm MID_EVAL/trainingH/Piano/occ0.npy MID_EVAL/trainingH/Piano/im0.png MID_EVAL/trainingH/Piano/im1.png MID_EVAL/trainingH/Piano/disp0.pfm 7 | MID_EVAL/trainingH/PianoL/im0.png MID_EVAL/trainingH/PianoL/im1.png MID_EVAL/trainingH/PianoL/disp0.pfm MID_EVAL/trainingH/PianoL/occ0.npy MID_EVAL/trainingH/PianoL/im0.png MID_EVAL/trainingH/PianoL/im1.png MID_EVAL/trainingH/PianoL/disp0.pfm 8 | MID_EVAL/trainingH/Motorcycle/im0.png 
MID_EVAL/trainingH/Motorcycle/im1.png MID_EVAL/trainingH/Motorcycle/disp0.pfm MID_EVAL/trainingH/Motorcycle/occ0.npy MID_EVAL/trainingH/Motorcycle/im0.png MID_EVAL/trainingH/Motorcycle/im1.png MID_EVAL/trainingH/Motorcycle/disp0.pfm 9 | MID_EVAL/trainingH/PlaytableP/im0.png MID_EVAL/trainingH/PlaytableP/im1.png MID_EVAL/trainingH/PlaytableP/disp0.pfm MID_EVAL/trainingH/PlaytableP/occ0.npy MID_EVAL/trainingH/PlaytableP/im0.png MID_EVAL/trainingH/PlaytableP/im1.png MID_EVAL/trainingH/PlaytableP/disp0.pfm 10 | MID_EVAL/trainingH/ArtL/im0.png MID_EVAL/trainingH/ArtL/im1.png MID_EVAL/trainingH/ArtL/disp0.pfm MID_EVAL/trainingH/ArtL/occ0.npy MID_EVAL/trainingH/ArtL/im0.png MID_EVAL/trainingH/ArtL/im1.png MID_EVAL/trainingH/ArtL/disp0.pfm 11 | MID_EVAL/trainingH/Recycle/im0.png MID_EVAL/trainingH/Recycle/im1.png MID_EVAL/trainingH/Recycle/disp0.pfm MID_EVAL/trainingH/Recycle/occ0.npy MID_EVAL/trainingH/Recycle/im0.png MID_EVAL/trainingH/Recycle/im1.png MID_EVAL/trainingH/Recycle/disp0.pfm 12 | MID_EVAL/trainingH/Shelves/im0.png MID_EVAL/trainingH/Shelves/im1.png MID_EVAL/trainingH/Shelves/disp0.pfm MID_EVAL/trainingH/Shelves/occ0.npy MID_EVAL/trainingH/Shelves/im0.png MID_EVAL/trainingH/Shelves/im1.png MID_EVAL/trainingH/Shelves/disp0.pfm 13 | MID_EVAL/trainingH/Playroom/im0.png MID_EVAL/trainingH/Playroom/im1.png MID_EVAL/trainingH/Playroom/disp0.pfm MID_EVAL/trainingH/Playroom/occ0.npy MID_EVAL/trainingH/Playroom/im0.png MID_EVAL/trainingH/Playroom/im1.png MID_EVAL/trainingH/Playroom/disp0.pfm 14 | MID_EVAL/trainingH/Playtable/im0.png MID_EVAL/trainingH/Playtable/im1.png MID_EVAL/trainingH/Playtable/disp0.pfm MID_EVAL/trainingH/Playtable/occ0.npy MID_EVAL/trainingH/Playtable/im0.png MID_EVAL/trainingH/Playtable/im1.png MID_EVAL/trainingH/Playtable/disp0.pfm 15 | MID_EVAL/trainingH/Vintage/im0.png MID_EVAL/trainingH/Vintage/im1.png MID_EVAL/trainingH/Vintage/disp0.pfm MID_EVAL/trainingH/Vintage/occ0.npy MID_EVAL/trainingH/Vintage/im0.png MID_EVAL/trainingH/Vintage/im1.png MID_EVAL/trainingH/Vintage/disp0.pfm 16 | MID_2014/trainingH/Storage-perfect/im0.png MID_2014/trainingH/Storage-perfect/im1.png MID_2014/trainingH/Storage-perfect/disp0.pfm MID_2014/trainingH/Storage-perfect/occ0.npy MID_2014/trainingH/Storage-perfect/im0.png MID_2014/trainingH/Storage-perfect/im1.png MID_2014/trainingH/Storage-perfect/disp0.pfm 17 | MID_2014/trainingH/Bicycle1-perfect/im0.png MID_2014/trainingH/Bicycle1-perfect/im1.png MID_2014/trainingH/Bicycle1-perfect/disp0.pfm MID_2014/trainingH/Bicycle1-perfect/occ0.npy MID_2014/trainingH/Bicycle1-perfect/im0.png MID_2014/trainingH/Bicycle1-perfect/im1.png MID_2014/trainingH/Bicycle1-perfect/disp0.pfm 18 | MID_2014/trainingH/Flowers-perfect/im0.png MID_2014/trainingH/Flowers-perfect/im1.png MID_2014/trainingH/Flowers-perfect/disp0.pfm MID_2014/trainingH/Flowers-perfect/occ0.npy MID_2014/trainingH/Flowers-perfect/im0.png MID_2014/trainingH/Flowers-perfect/im1.png MID_2014/trainingH/Flowers-perfect/disp0.pfm 19 | MID_2014/trainingH/Shopvac-perfect/im0.png MID_2014/trainingH/Shopvac-perfect/im1.png MID_2014/trainingH/Shopvac-perfect/disp0.pfm MID_2014/trainingH/Shopvac-perfect/occ0.npy MID_2014/trainingH/Shopvac-perfect/im0.png MID_2014/trainingH/Shopvac-perfect/im1.png MID_2014/trainingH/Shopvac-perfect/disp0.pfm 20 | MID_2014/trainingH/Umbrella-perfect/im0.png MID_2014/trainingH/Umbrella-perfect/im1.png MID_2014/trainingH/Umbrella-perfect/disp0.pfm MID_2014/trainingH/Umbrella-perfect/occ0.npy MID_2014/trainingH/Umbrella-perfect/im0.png 
MID_2014/trainingH/Umbrella-perfect/im1.png MID_2014/trainingH/Umbrella-perfect/disp0.pfm 21 | MID_2014/trainingH/Sticks-perfect/im0.png MID_2014/trainingH/Sticks-perfect/im1.png MID_2014/trainingH/Sticks-perfect/disp0.pfm MID_2014/trainingH/Sticks-perfect/occ0.npy MID_2014/trainingH/Sticks-perfect/im0.png MID_2014/trainingH/Sticks-perfect/im1.png MID_2014/trainingH/Sticks-perfect/disp0.pfm 22 | MID_2014/trainingH/Couch-perfect/im0.png MID_2014/trainingH/Couch-perfect/im1.png MID_2014/trainingH/Couch-perfect/disp0.pfm MID_2014/trainingH/Couch-perfect/occ0.npy MID_2014/trainingH/Couch-perfect/im0.png MID_2014/trainingH/Couch-perfect/im1.png MID_2014/trainingH/Couch-perfect/disp0.pfm 23 | MID_2014/trainingH/Sword2-perfect/im0.png MID_2014/trainingH/Sword2-perfect/im1.png MID_2014/trainingH/Sword2-perfect/disp0.pfm MID_2014/trainingH/Sword2-perfect/occ0.npy MID_2014/trainingH/Sword2-perfect/im0.png MID_2014/trainingH/Sword2-perfect/im1.png MID_2014/trainingH/Sword2-perfect/disp0.pfm 24 | MID_2014/trainingH/Sword1-perfect/im0.png MID_2014/trainingH/Sword1-perfect/im1.png MID_2014/trainingH/Sword1-perfect/disp0.pfm MID_2014/trainingH/Sword1-perfect/occ0.npy MID_2014/trainingH/Sword1-perfect/im0.png MID_2014/trainingH/Sword1-perfect/im1.png MID_2014/trainingH/Sword1-perfect/disp0.pfm 25 | MID_2014/trainingH/Classroom1-perfect/im0.png MID_2014/trainingH/Classroom1-perfect/im1.png MID_2014/trainingH/Classroom1-perfect/disp0.pfm MID_2014/trainingH/Classroom1-perfect/occ0.npy MID_2014/trainingH/Classroom1-perfect/im0.png MID_2014/trainingH/Classroom1-perfect/im1.png MID_2014/trainingH/Classroom1-perfect/disp0.pfm 26 | MID_2014/trainingH/Mask-perfect/im0.png MID_2014/trainingH/Mask-perfect/im1.png MID_2014/trainingH/Mask-perfect/disp0.pfm MID_2014/trainingH/Mask-perfect/occ0.npy MID_2014/trainingH/Mask-perfect/im0.png MID_2014/trainingH/Mask-perfect/im1.png MID_2014/trainingH/Mask-perfect/disp0.pfm 27 | MID_2014/trainingH/Cable-perfect/im0.png MID_2014/trainingH/Cable-perfect/im1.png MID_2014/trainingH/Cable-perfect/disp0.pfm MID_2014/trainingH/Cable-perfect/occ0.npy MID_2014/trainingH/Cable-perfect/im0.png MID_2014/trainingH/Cable-perfect/im1.png MID_2014/trainingH/Cable-perfect/disp0.pfm 28 | MID_2014/trainingH/Backpack-perfect/im0.png MID_2014/trainingH/Backpack-perfect/im1.png MID_2014/trainingH/Backpack-perfect/disp0.pfm MID_2014/trainingH/Backpack-perfect/occ0.npy MID_2014/trainingH/Backpack-perfect/im0.png MID_2014/trainingH/Backpack-perfect/im1.png MID_2014/trainingH/Backpack-perfect/disp0.pfm -------------------------------------------------------------------------------- /datafiles/middlebury/middleburry_submit.list: -------------------------------------------------------------------------------- 1 | MID_EVAL/trainingH/CrusadeP/im0.png MID_EVAL/trainingH/CrusadeP/im1.png 2 | MID_EVAL/trainingH/Staircase/im0.png MID_EVAL/trainingH/Staircase/im1.png 3 | MID_EVAL/trainingH/Computer/im0.png MID_EVAL/trainingH/Computer/im1.png 4 | MID_EVAL/trainingH/Djembe/im0.png MID_EVAL/trainingH/Djembe/im1.png 5 | MID_EVAL/trainingH/Plants_color/im0.png MID_EVAL/trainingH/Plants_color/im1.png 6 | MID_EVAL/trainingH/Bicycle2/im0.png MID_EVAL/trainingH/Bicycle2/im1.png 7 | MID_EVAL/trainingH/Classroom2/im0.png MID_EVAL/trainingH/Classroom2/im1.png 8 | MID_EVAL/trainingH/DjembeL/im0.png MID_EVAL/trainingH/DjembeL/im1.png 9 | MID_EVAL/trainingH/Australia/im0.png MID_EVAL/trainingH/Australia/im1.png 10 | MID_EVAL/trainingH/Plants/im0.png MID_EVAL/trainingH/Plants/im1.png 11 | MID_EVAL/trainingH/AustraliaP/im0.png 
MID_EVAL/trainingH/AustraliaP/im1.png 12 | MID_EVAL/trainingH/Crusade/im0.png MID_EVAL/trainingH/Crusade/im1.png 13 | MID_EVAL/trainingH/Hoops/im0.png MID_EVAL/trainingH/Hoops/im1.png 14 | MID_EVAL/trainingH/Classroom2E/im0.png MID_EVAL/trainingH/Classroom2E/im1.png 15 | MID_EVAL/trainingH/Newkuba/im0.png MID_EVAL/trainingH/Newkuba/im1.png 16 | MID_EVAL/trainingH/Livingroom/im0.png MID_EVAL/trainingH/Livingroom/im1.png -------------------------------------------------------------------------------- /dataloader/__pycache__/file_io.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Magicboomliu/Accelerator-Simple-Template/92a82436df60ca4c9402e729e5175f60439c027e/dataloader/__pycache__/file_io.cpython-39.pyc -------------------------------------------------------------------------------- /dataloader/__pycache__/sceneflow_loader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Magicboomliu/Accelerator-Simple-Template/92a82436df60ca4c9402e729e5175f60439c027e/dataloader/__pycache__/sceneflow_loader.cpython-39.pyc -------------------------------------------------------------------------------- /dataloader/__pycache__/transforms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Magicboomliu/Accelerator-Simple-Template/92a82436df60ca4c9402e729e5175f60439c027e/dataloader/__pycache__/transforms.cpython-39.pyc -------------------------------------------------------------------------------- /dataloader/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Magicboomliu/Accelerator-Simple-Template/92a82436df60ca4c9402e729e5175f60439c027e/dataloader/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /dataloader/file_io.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import re 7 | from PIL import Image 8 | import sys 9 | import torch 10 | import cv2 11 | 12 | def depth2disp(left_depth_path): 13 | left_depth = np.array(Image.open(left_depth_path))/(10000) 14 | focal_length = 768.16058349609375 15 | baseline = 0.06 16 | left_disp = focal_length*baseline /left_depth 17 | return left_disp 18 | 19 | 20 | def depth2disp_cv(left_depth_path): 21 | left_depth = cv2.imread(left_depth_path, cv2.IMREAD_UNCHANGED)/ 10000 22 | 23 | 24 | focal_length = 768.16058349609375 25 | baseline = 0.06 26 | left_disp = focal_length*baseline /left_depth 27 | return left_disp 28 | 29 | 30 | def read_img(filename): 31 | # Convert to RGB for scene flow finalpass data 32 | img = np.array(Image.open(filename).convert('RGB')).astype(np.float32) 33 | return img 34 | 35 | 36 | def read_disp(filename, subset=False): 37 | # Scene Flow dataset 38 | if filename.endswith('pfm'): 39 | # For finalpass and cleanpass, gt disparity is positive, subset is negative 40 | disp = np.ascontiguousarray(_read_pfm(filename)[0]) 41 | if subset: 42 | disp = -disp 43 | # KITTI 44 | elif filename.endswith('png'): 45 | disp = _read_kitti_disp(filename) 46 | elif filename.endswith('npy'): 47 | disp = np.load(filename) 48 | else: 49 | raise Exception('Invalid disparity file format!') 50 
| return disp # [H, W] 51 | 52 | 53 | def _read_pfm(file): 54 | file = open(file, 'rb') 55 | 56 | color = None 57 | width = None 58 | height = None 59 | scale = None 60 | endian = None 61 | 62 | header = file.readline().rstrip() 63 | if header.decode("ascii") == 'PF': 64 | color = True 65 | elif header.decode("ascii") == 'Pf': 66 | color = False 67 | else: 68 | raise Exception('Not a PFM file.') 69 | 70 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) 71 | if dim_match: 72 | width, height = list(map(int, dim_match.groups())) 73 | else: 74 | raise Exception('Malformed PFM header.') 75 | 76 | scale = float(file.readline().decode("ascii").rstrip()) 77 | if scale < 0: # little-endian 78 | endian = '<' 79 | scale = -scale 80 | else: 81 | endian = '>' # big-endian 82 | 83 | data = np.fromfile(file, endian + 'f') 84 | shape = (height, width, 3) if color else (height, width) 85 | 86 | data = np.reshape(data, shape) 87 | data = np.flipud(data) 88 | return data, scale 89 | 90 | 91 | def write_pfm(file, image, scale=1): 92 | file = open(file, 'wb') 93 | 94 | color = None 95 | 96 | if image.dtype.name != 'float32': 97 | raise Exception('Image dtype must be float32.') 98 | 99 | image = np.flipud(image) 100 | 101 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 102 | color = True 103 | elif len(image.shape) == 2 or len( 104 | image.shape) == 3 and image.shape[2] == 1: # greyscale 105 | color = False 106 | else: 107 | raise Exception( 108 | 'Image must have H x W x 3, H x W x 1 or H x W dimensions.') 109 | 110 | file.write(b'PF\n' if color else b'Pf\n') 111 | file.write(b'%d %d\n' % (image.shape[1], image.shape[0])) 112 | 113 | endian = image.dtype.byteorder 114 | 115 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 116 | scale = -scale 117 | 118 | file.write(b'%f\n' % scale) 119 | 120 | image.tofile(file) 121 | 122 | 123 | def _read_kitti_disp(filename): 124 | depth = np.array(Image.open(filename)) 125 | depth = depth.astype(np.float32) / 256. 
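# KITTI ground-truth disparity is stored as a uint16 PNG scaled by 256, so the
# division above recovers sub-pixel disparity in pixels (e.g. a stored value of
# 12800 corresponds to 50.0 px); a stored value of 0 marks an invalid pixel.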
126 | return depth 127 | 128 | 129 | def read_occlusion_mid(filename): 130 | img = Image.open(filename) 131 | img_np = np.array(img) 132 | valid_mask_combine = (img_np<=128).astype(np.float32) 133 | return valid_mask_combine -------------------------------------------------------------------------------- /dataloader/sceneflow_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from torch.utils.data import Dataset 6 | import os 7 | 8 | from dataloader.utils import read_text_lines 9 | from dataloader.file_io import read_disp,read_img 10 | # from utils.utils import read_text_lines 11 | # from utils.file_io import read_disp,read_img 12 | from skimage import io, transform 13 | import numpy as np 14 | 15 | class StereoDataset(Dataset): 16 | def __init__(self, data_dir, 17 | train_datalist, 18 | test_datalist, 19 | dataset_name='SceneFlow', 20 | mode='train', 21 | save_filename=False, 22 | load_pseudo_gt=False, 23 | transform=None): 24 | super(StereoDataset, self).__init__() 25 | 26 | self.data_dir = data_dir 27 | self.dataset_name = dataset_name 28 | self.mode = mode 29 | self.save_filename = save_filename 30 | self.transform = transform 31 | self.train_datalist = train_datalist 32 | self.test_datalist = test_datalist 33 | self.img_size=(540, 960) 34 | self.scale_size =(576,960) 35 | 36 | 37 | sceneflow_finalpass_dict = { 38 | 'train': self.train_datalist, 39 | 'val': self.test_datalist, 40 | 'test': self.test_datalist 41 | } 42 | 43 | kitti_2012_dict = { 44 | 'train': 'filenames/KITTI_2012_train.txt', 45 | 'train_all': 'filenames/KITTI_2012_train_all.txt', 46 | 'val': 'filenames/KITTI_2012_val.txt', 47 | 'test': 'filenames/KITTI_2012_test.txt' 48 | } 49 | 50 | kitti_2015_dict = { 51 | 'train': 'filenames/KITTI_2015_train.txt', 52 | 'train_all': 'filenames/KITTI_2015_train_all.txt', 53 | 'val': 'filenames/KITTI_2015_val.txt', 54 | 'test': 'filenames/KITTI_2015_test.txt' 55 | } 56 | 57 | kitti_mix_dict = { 58 | 'train': 'filenames/KITTI_mix.txt', 59 | 'test': 'filenames/KITTI_2015_test.txt' 60 | } 61 | 62 | dataset_name_dict = { 63 | 'SceneFlow': sceneflow_finalpass_dict, 64 | 'KITTI2012': kitti_2012_dict, 65 | 'KITTI2015': kitti_2015_dict, 66 | 'KITTI_mix': kitti_mix_dict, 67 | } 68 | 69 | assert dataset_name in dataset_name_dict.keys() 70 | self.dataset_name = dataset_name 71 | 72 | self.samples = [] 73 | 74 | data_filenames = dataset_name_dict[dataset_name][mode] 75 | 76 | lines = read_text_lines(data_filenames) 77 | 78 | for line in lines: 79 | splits = line.split() 80 | 81 | left_img, right_img = splits[:2] 82 | gt_disp = None if len(splits) == 2 else splits[2] 83 | 84 | sample = dict() 85 | 86 | if self.save_filename: 87 | sample['left_name'] = left_img.split('/', 1)[1] 88 | 89 | sample['left'] = os.path.join(data_dir, left_img) 90 | sample['right'] = os.path.join(data_dir, right_img) 91 | sample['disp'] = os.path.join(data_dir, gt_disp) if gt_disp is not None else None 92 | 93 | if load_pseudo_gt and sample['disp'] is not None: 94 | # KITTI 2015 95 | if 'disp_occ_0' in sample['disp']: 96 | sample['pseudo_disp'] = (sample['disp']).replace('disp_occ_0', 97 | 'disp_occ_0_pseudo_gt') 98 | # KITTI 2012 99 | elif 'disp_occ' in sample['disp']: 100 | sample['pseudo_disp'] = (sample['disp']).replace('disp_occ', 101 | 'disp_occ_pseudo_gt') 102 | else: 103 | raise NotImplementedError 104 | else: 105 | sample['pseudo_disp'] = None 106 |
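# Each entry of self.samples only stores file paths (left/right image, optional
# GT disparity, optional pseudo GT); the actual arrays are loaded lazily in
# __getitem__ via read_img / read_disp.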
107 | self.samples.append(sample) 108 | 109 | def __getitem__(self, index): 110 | sample = {} 111 | sample_path = self.samples[index] 112 | 113 | if self.save_filename: 114 | sample['left_name'] = sample_path['left_name'] 115 | 116 | sample['img_left'] = read_img(sample_path['left']) # [H, W, 3] 117 | sample['img_right'] = read_img(sample_path['right']) 118 | 119 | 120 | # GT disparity of subset if negative, finalpass and cleanpass is positive 121 | subset = True if 'subset' in self.dataset_name else False 122 | if sample_path['disp'] is not None: 123 | sample['gt_disp'] = read_disp(sample_path['disp'], subset=subset) # [H, W] 124 | 125 | if self.mode=='test' or self.mode=='val': 126 | # img_left = transform.resize(sample['img_left'], [576,960], preserve_range=True) 127 | # img_right = transform.resize(sample['img_right'], [576,960], preserve_range=True) 128 | img_left = sample['img_left'] 129 | img_right = sample['img_right'] 130 | 131 | img_left = img_left.astype(np.float32) 132 | img_right = img_right.astype(np.float32) 133 | 134 | sample['img_left'] = img_left 135 | sample['img_right'] = img_right 136 | 137 | if self.transform is not None: 138 | sample = self.transform(sample) 139 | 140 | return sample 141 | 142 | def __len__(self): 143 | return len(self.samples) 144 | 145 | def get_img_size(self): 146 | return self.img_size 147 | 148 | def get_scale_size(self): 149 | return self.scale_size -------------------------------------------------------------------------------- /dataloader/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from genericpath import samefile 3 | import torch 4 | import numpy as np 5 | from PIL import Image, ImageEnhance 6 | import torchvision.transforms.functional as F 7 | import random 8 | import cv2 9 | 10 | 11 | class Compose(object): 12 | def __init__(self, transforms): 13 | self.transforms = transforms 14 | 15 | def __call__(self, sample): 16 | for t in self.transforms: 17 | sample = t(sample) 18 | return sample 19 | 20 | 21 | 22 | class ToTensor(object): 23 | """Convert numpy array to torch tensor""" 24 | def __call__(self, sample): 25 | left = np.transpose(sample['img_left'], (2, 0, 1)) # [3, H, W] 26 | sample['img_left'] = torch.from_numpy(left) / 255. 27 | right = np.transpose(sample['img_right'], (2, 0, 1)) 28 | sample['img_right'] = torch.from_numpy(right) / 255. 
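# Both views go from HWC numpy arrays to CHW float tensors scaled to [0, 1];
# the ground-truth disparity handled below keeps its raw pixel values and is
# only wrapped into a tensor, not rescaled.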
29 | 30 | # disp = np.expand_dims(sample['disp'], axis=0) # [1, H, W] 31 | if 'gt_disp' in sample.keys(): 32 | 33 | disp = sample['gt_disp'] # [H, W] 34 | sample['gt_disp'] = torch.from_numpy(disp) 35 | 36 | return sample 37 | 38 | 39 | class Normalize(object): 40 | """Normalize image, with type tensor""" 41 | 42 | def __init__(self, mean, std): 43 | self.mean = mean 44 | self.std = std 45 | 46 | def __call__(self, sample): 47 | 48 | norm_keys = ['img_left', 'img_right'] 49 | for key in norm_keys: 50 | # Images have converted to tensor, with shape [C, H, W] 51 | for t, m, s in zip(sample[key], self.mean, self.std): 52 | t.sub_(m).div_(s) 53 | 54 | return sample 55 | 56 | 57 | class RandomCrop(object): 58 | def __init__(self, img_height, img_width, validate=False): 59 | self.img_height = img_height 60 | self.img_width = img_width 61 | self.validate = validate 62 | 63 | def __call__(self, sample): 64 | 65 | 66 | ori_height, ori_width = sample['img_left'].shape[:2] 67 | if self.img_height > ori_height or self.img_width > ori_width: 68 | top_pad = self.img_height - ori_height 69 | right_pad = self.img_width - ori_width 70 | 71 | assert top_pad >= 0 and right_pad >= 0 72 | 73 | sample['img_left'] = np.lib.pad(sample['img_left'], 74 | ((top_pad, 0), (0, right_pad), (0, 0)), 75 | mode='constant', 76 | constant_values=0) 77 | sample['img_right'] = np.lib.pad(sample['img_right'], 78 | ((top_pad, 0), (0, right_pad), (0, 0)), 79 | mode='constant', 80 | constant_values=0) 81 | if 'gt_disp' in sample.keys(): 82 | sample['gt_disp'] = np.lib.pad(sample['gt_disp'], 83 | ((top_pad, 0), (0, right_pad)), 84 | mode='constant', 85 | constant_values=0) 86 | 87 | else: 88 | assert self.img_height <= ori_height and self.img_width <= ori_width 89 | # Training: random crop 90 | if not self.validate: 91 | 92 | self.offset_x = np.random.randint(ori_width - self.img_width + 1) 93 | 94 | start_height = 0 95 | assert ori_height - start_height >= self.img_height 96 | 97 | self.offset_y = np.random.randint(start_height, ori_height - self.img_height + 1) 98 | 99 | # Validatoin, center crop 100 | else: 101 | self.offset_x = (ori_width - self.img_width) // 2 102 | self.offset_y = (ori_height - self.img_height) // 2 103 | 104 | sample['img_left'] = self.crop_img(sample['img_left']) 105 | sample['img_right'] = self.crop_img(sample['img_right']) 106 | if 'gt_disp' in sample.keys(): 107 | sample['gt_disp'] = self.crop_img(sample['gt_disp']) 108 | 109 | return sample 110 | 111 | def crop_img(self, img): 112 | return img[self.offset_y:self.offset_y + self.img_height, 113 | self.offset_x:self.offset_x + self.img_width] 114 | 115 | import matplotlib.pyplot as plt 116 | 117 | class RandomVerticalFlip(object): 118 | """Randomly vertically filps""" 119 | 120 | def __call__(self, sample): 121 | if np.random.random() < 0.09: 122 | sample['img_left'] = np.copy(np.flipud(sample['img_left'])) 123 | sample['img_right'] = np.copy(np.flipud(sample['img_right'])) 124 | 125 | sample['gt_disp'] = np.copy(np.flipud(sample['gt_disp'])) 126 | 127 | return sample 128 | 129 | 130 | class ToPILImage(object): 131 | 132 | def __call__(self, sample): 133 | sample['img_left'] = Image.fromarray(sample['img_left'].astype('uint8')) 134 | sample['img_right'] = Image.fromarray(sample['img_right'].astype('uint8')) 135 | 136 | return sample 137 | 138 | 139 | class ToNumpyArray(object): 140 | 141 | def __call__(self, sample): 142 | sample['img_left'] = np.array(sample['img_left']).astype(np.float32) 143 | sample['img_right'] = 
np.array(sample['img_right']).astype(np.float32) 144 | 145 | return sample 146 | 147 | 148 | # Random coloring 149 | class RandomContrast(object): 150 | """Random contrast""" 151 | 152 | def __call__(self, sample): 153 | if np.random.random() < 0.5: 154 | contrast_factor = np.random.uniform(0.8, 1.2) 155 | sample['img_left'] = F.adjust_contrast(sample['img_left'], contrast_factor) 156 | sample['img_right'] = F.adjust_contrast(sample['img_right'], contrast_factor) 157 | 158 | return sample 159 | 160 | 161 | class RandomGamma(object): 162 | 163 | def __call__(self, sample): 164 | if np.random.random() < 0.5: 165 | gamma = np.random.uniform(0.7, 1.5) # adopted from FlowNet 166 | 167 | sample['img_left'] = F.adjust_gamma(sample['img_left'], gamma) 168 | sample['img_right'] = F.adjust_gamma(sample['img_right'], gamma) 169 | 170 | return sample 171 | 172 | 173 | class RandomBrightness(object): 174 | 175 | def __call__(self, sample): 176 | if np.random.random() < 0.5: 177 | brightness = np.random.uniform(0.5, 2.0) 178 | 179 | sample['img_left'] = F.adjust_brightness(sample['img_left'], brightness) 180 | sample['img_right'] = F.adjust_brightness(sample['img_right'], brightness) 181 | 182 | return sample 183 | 184 | 185 | class RandomHue(object): 186 | 187 | def __call__(self, sample): 188 | if np.random.random() < 0.5: 189 | hue = np.random.uniform(-0.1, 0.1) 190 | 191 | sample['img_left'] = F.adjust_hue(sample['img_left'], hue) 192 | sample['img_right'] = F.adjust_hue(sample['img_right'], hue) 193 | 194 | return sample 195 | 196 | 197 | class RandomSaturation(object): 198 | 199 | def __call__(self, sample): 200 | if np.random.random() < 0.5: 201 | saturation = np.random.uniform(0.8, 1.2) 202 | sample['img_left'] = F.adjust_saturation(sample['img_left'], saturation) 203 | sample['img_right'] = F.adjust_saturation(sample['img_right'], saturation) 204 | 205 | return sample 206 | 207 | 208 | class RandomColor(object): 209 | 210 | def __call__(self, sample): 211 | transforms = [RandomContrast(), 212 | RandomGamma(), 213 | RandomBrightness(), 214 | RandomHue(), 215 | RandomSaturation()] 216 | 217 | sample = ToPILImage()(sample) 218 | 219 | if np.random.random() < 0.5: 220 | # A single transform 221 | t = random.choice(transforms) 222 | sample = t(sample) 223 | else: 224 | # Combination of transforms 225 | # Random order 226 | random.shuffle(transforms) 227 | for t in transforms: 228 | sample = t(sample) 229 | 230 | sample = ToNumpyArray()(sample) 231 | 232 | return sample -------------------------------------------------------------------------------- /dataloader/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import torch 5 | from glob import glob 6 | import logging 7 | 8 | import numpy as np 9 | 10 | 11 | def find_occ_mask(disp_left, disp_right): 12 | """ 13 | find occlusion map 14 | 1 indicates occlusion 15 | disp range [0,w] 16 | """ 17 | w = disp_left.shape[-1] 18 | 19 | # # left occlusion 20 | # find corresponding pixels in target image 21 | coord = np.linspace(0, w - 1, w)[None,] # 1xW 22 | right_shifted = coord - disp_left 23 | 24 | # 1. negative locations will be occlusion 25 | occ_mask_l = right_shifted <= 0 26 | 27 | # 2. 
wrong matches will be occlusion 28 | right_shifted[occ_mask_l] = 0 # set negative locations to 0 29 | right_shifted = right_shifted.astype(np.int64) 30 | disp_right_selected = np.take_along_axis(disp_right, right_shifted, 31 | axis=1) # find tgt disparity at src-shifted locations 32 | wrong_matches = np.abs(disp_right_selected - disp_left) > 1 # theoretically, these two should match perfectly 33 | wrong_matches[disp_right_selected <= 0.0] = False 34 | wrong_matches[disp_left <= 0.0] = False 35 | 36 | # produce final occ 37 | wrong_matches[occ_mask_l] = True # apply case 1 occlusion to case 2 38 | occ_mask_l = wrong_matches 39 | 40 | # # right occlusion 41 | # find corresponding pixels in target image 42 | coord = np.linspace(0, w - 1, w)[None,] # 1xW 43 | left_shifted = coord + disp_right 44 | 45 | # 1. locations shifted beyond the right border will be occlusion 46 | occ_mask_r = left_shifted >= w 47 | 48 | # 2. wrong matches will be occlusion 49 | left_shifted[occ_mask_r] = 0 # set out-of-range locations to 0 50 | left_shifted = left_shifted.astype(np.int64) 51 | disp_left_selected = np.take_along_axis(disp_left, left_shifted, 52 | axis=1) # find tgt disparity at src-shifted locations 53 | wrong_matches = np.abs(disp_left_selected - disp_right) > 1 # theoretically, these two should match perfectly 54 | wrong_matches[disp_left_selected <= 0.0] = False 55 | wrong_matches[disp_right <= 0.0] = False 56 | 57 | # produce final occ 58 | wrong_matches[occ_mask_r] = True # apply case 1 occlusion to case 2 59 | occ_mask_r = wrong_matches 60 | 61 | return occ_mask_l, occ_mask_r 62 | 63 | 64 | 65 | def read_text_lines(filepath): 66 | with open(filepath, 'r') as f: 67 | lines = f.readlines() 68 | lines = [l.rstrip() for l in lines] 69 | return lines 70 | 71 | 72 | def check_path(path): 73 | if not os.path.exists(path): 74 | os.makedirs(path, exist_ok=True) # explicitly set exist_ok when multi-processing 75 | 76 | 77 | def save_command(save_path, filename='command_train.txt'): 78 | check_path(save_path) 79 | command = sys.argv 80 | save_file = os.path.join(save_path, filename) 81 | with open(save_file, 'w') as f: 82 | f.write(' '.join(command)) 83 | 84 | 85 | def save_args(args, filename='args.json'): 86 | args_dict = vars(args) 87 | check_path(args.checkpoint_dir) 88 | save_path = os.path.join(args.checkpoint_dir, filename) 89 | 90 | with open(save_path, 'w') as f: 91 | json.dump(args_dict, f, indent=4, sort_keys=False) 92 | 93 | 94 | def int_list(s): 95 | """Convert string to int list""" 96 | return [int(x) for x in s.split(',')] 97 | 98 | 99 | def save_checkpoint(save_path, optimizer, aanet, epoch, num_iter, 100 | epe, best_epe, best_epoch, filename=None, save_optimizer=True): 101 | # AANet 102 | aanet_state = { 103 | 'epoch': epoch, 104 | 'num_iter': num_iter, 105 | 'epe': epe, 106 | 'best_epe': best_epe, 107 | 'best_epoch': best_epoch, 108 | 'state_dict': aanet.state_dict() 109 | } 110 | aanet_filename = 'aanet_epoch_{:0>3d}.pth'.format(epoch) if filename is None else filename 111 | aanet_save_path = os.path.join(save_path, aanet_filename) 112 | torch.save(aanet_state, aanet_save_path) 113 | 114 | # Optimizer 115 | if save_optimizer: 116 | optimizer_state = { 117 | 'epoch': epoch, 118 | 'num_iter': num_iter, 119 | 'epe': epe, 120 | 'best_epe': best_epe, 121 | 'best_epoch': best_epoch, 122 | 'state_dict': optimizer.state_dict() 123 | } 124 | optimizer_name = aanet_filename.replace('aanet', 'optimizer') 125 | optimizer_save_path = os.path.join(save_path, optimizer_name) 126 | torch.save(optimizer_state, optimizer_save_path) 127 |
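# Illustrative pairing of save_checkpoint with resume_latest_ckpt below; the
# directory name and metric values are placeholders, not taken from the repo:
#   save_checkpoint('checkpoints', optimizer, aanet, epoch=10, num_iter=5000,
#                   epe=1.20, best_epe=1.05, best_epoch=8)  # -> checkpoints/aanet_epoch_010.pth
#   epoch, num_iter, best_epe, best_epoch = resume_latest_ckpt('checkpoints', aanet, 'aanet')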
128 | 129 | def load_pretrained_net(net, pretrained_path, return_epoch_iter=False, resume=False, 130 | no_strict=False): 131 | if pretrained_path is not None: 132 | if torch.cuda.is_available(): 133 | state = torch.load(pretrained_path, map_location='cuda') 134 | else: 135 | state = torch.load(pretrained_path, map_location='cpu') 136 | 137 | from collections import OrderedDict 138 | new_state_dict = OrderedDict() 139 | 140 | weights = state['state_dict'] if 'state_dict' in state.keys() else state 141 | 142 | for k, v in weights.items(): 143 | name = k[7:] if 'module' in k and not resume else k 144 | new_state_dict[name] = v 145 | 146 | if no_strict: 147 | net.load_state_dict(new_state_dict, strict=False) # ignore intermediate output 148 | else: 149 | net.load_state_dict(new_state_dict) # optimizer has no argument `strict` 150 | 151 | if return_epoch_iter: 152 | epoch = state['epoch'] if 'epoch' in state.keys() else None 153 | num_iter = state['num_iter'] if 'num_iter' in state.keys() else None 154 | best_epe = state['best_epe'] if 'best_epe' in state.keys() else None 155 | best_epoch = state['best_epoch'] if 'best_epoch' in state.keys() else None 156 | return epoch, num_iter, best_epe, best_epoch 157 | 158 | 159 | def resume_latest_ckpt(checkpoint_dir, net, net_name): 160 | ckpts = sorted(glob(checkpoint_dir + '/' + net_name + '*.pth')) 161 | 162 | if len(ckpts) == 0: 163 | raise RuntimeError('=> No checkpoint found while resuming training') 164 | 165 | latest_ckpt = ckpts[-1] 166 | print('=> Resume latest %s checkpoint: %s' % (net_name, os.path.basename(latest_ckpt))) 167 | epoch, num_iter, best_epe, best_epoch = load_pretrained_net(net, latest_ckpt, True, True) 168 | 169 | return epoch, num_iter, best_epe, best_epoch 170 | 171 | 172 | def fix_net_parameters(net): 173 | for param in net.parameters(): 174 | param.requires_grad = False 175 | 176 | 177 | def count_parameters(model): 178 | num = sum(p.numel() for p in model.parameters() if p.requires_grad) 179 | return num 180 | 181 | 182 | def filter_specific_params(kv): 183 | specific_layer_name = ['offset_conv.weight', 'offset_conv.bias'] 184 | for name in specific_layer_name: 185 | if name in kv[0]: 186 | return True 187 | return False 188 | 189 | 190 | def filter_base_params(kv): 191 | specific_layer_name = ['offset_conv.weight', 'offset_conv.bias'] 192 | for name in specific_layer_name: 193 | if name in kv[0]: 194 | return False 195 | return True 196 | 197 | 198 | def get_logger(): 199 | logger_name = "main-logger" 200 | logger = logging.getLogger(logger_name) 201 | logger.setLevel(logging.INFO) 202 | handler = logging.StreamHandler() 203 | # fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" 204 | fmt = "[%(asctime)s] %(message)s" 205 | handler.setFormatter(logging.Formatter(fmt)) 206 | logger.addHandler(handler) 207 | return logger -------------------------------------------------------------------------------- /playground/check_depth_est.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | import sys 8 | sys.path.append("..") 9 | 10 | from dataloader.sceneflow_loader import StereoDataset 11 | from torch.utils.data import DataLoader 12 | from dataloader import transforms 13 | import os 14 | 15 | from utils.common import logger 16 | from utils.image_util import resize_max_res,chw2hwc,colorize_depth_maps 17 | 18 | 19 | # IMAGENET NORMALIZATION 20 | 
IMAGENET_MEAN = [0.485, 0.456, 0.406] 21 | IMAGENET_STD = [0.229, 0.224, 0.225] 22 | 23 | from diffusers import AutoencoderKL 24 | 25 | from utils.de_normalized import de_normalization 26 | 27 | 28 | # Get Dataset Here 29 | def prepare_dataset(data_name, 30 | datapath=None, 31 | trainlist=None, 32 | vallist=None, 33 | batch_size=1, 34 | test_batch=1, 35 | datathread=4, 36 | logger=None): 37 | 38 | # set the config parameters 39 | dataset_config_dict = dict() 40 | 41 | if data_name == 'sceneflow': 42 | train_transform_list = [ 43 | transforms.ToTensor(), 44 | # transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 45 | ] 46 | train_transform = transforms.Compose(train_transform_list) 47 | 48 | val_transform_list = [transforms.ToTensor(), 49 | # transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 50 | ] 51 | val_transform = transforms.Compose(val_transform_list) 52 | 53 | train_dataset = StereoDataset(data_dir=datapath,train_datalist=trainlist,test_datalist=vallist, 54 | dataset_name='SceneFlow',mode='train',transform=train_transform) 55 | test_dataset = StereoDataset(data_dir=datapath,train_datalist=trainlist,test_datalist=vallist, 56 | dataset_name='SceneFlow',mode='val',transform=val_transform) 57 | 58 | img_height, img_width = train_dataset.get_img_size() 59 | 60 | 61 | datathread=4 62 | if os.environ.get('datathread') is not None: 63 | datathread = int(os.environ.get('datathread')) 64 | 65 | if logger is not None: 66 | logger.info("Use %d processes to load data..." % datathread) 67 | 68 | train_loader = DataLoader(train_dataset, batch_size = batch_size, \ 69 | shuffle = True, num_workers = datathread, \ 70 | pin_memory = True) 71 | 72 | test_loader = DataLoader(test_dataset, batch_size = test_batch, \ 73 | shuffle = False, num_workers = datathread, \ 74 | pin_memory = True) 75 | 76 | num_batches_per_epoch = len(train_loader) 77 | 78 | 79 | dataset_config_dict['num_batches_per_epoch'] = num_batches_per_epoch 80 | dataset_config_dict['img_size'] = (img_height,img_width) 81 | 82 | 83 | return (train_loader,test_loader),dataset_config_dict 84 | 85 | def Disparity_Normalization(disparity): 86 | min_value = torch.min(disparity) 87 | max_value = torch.max(disparity) 88 | normalized_disparity = ((disparity -min_value)/(max_value-min_value+1e-5) - 0.5) * 2 89 | return normalized_disparity 90 | 91 | def resize_max_res_tensor(input_tensor,is_disp=False,recom_resolution=768): 92 | assert input_tensor.shape[1]==3 93 | original_H, original_W = input_tensor.shape[2:] 94 | 95 | downscale_factor = min(recom_resolution/original_H, 96 | recom_resolution/original_W) 97 | 98 | resized_input_tensor = F.interpolate(input_tensor, 99 | scale_factor=downscale_factor,mode='bilinear', 100 | align_corners=False) 101 | 102 | if is_disp: 103 | return resized_input_tensor * downscale_factor 104 | else: 105 | return resized_input_tensor 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | if __name__=="__main__": 118 | 119 | 120 | datapath = "/data1/liu/" 121 | trainlist = "/home/zliu/ECCV2024/Accelerator-Simple-Template/datafiles/sceneflow/SceneFlow_With_Occ.list" 122 | vallist = "/home/zliu/ECCV2024/Accelerator-Simple-Template/datafiles/sceneflow/FlyingThings3D_Test_With_Occ.list" 123 | 124 | 125 | (train_loader,test_loader), dataset_config_dict = prepare_dataset(data_name='sceneflow', 126 | datapath=datapath,trainlist=trainlist, 127 | vallist=vallist,batch_size=1, 128 | test_batch=1,datathread=4,logger=logger) 129 | 130 | pretrained_model_name_path = "stabilityai/stable-diffusion-2" 
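# What follows is a VAE round-trip sanity check: a GT disparity map is resized
# with resize_max_res_tensor, mapped to [-1, 1] by Disparity_Normalization
# (e.g. for a map spanning 5-85 px, 5 -> -1.0, 45 -> ~0.0, 85 -> ~1.0; these
# numbers are illustrative), encoded by the frozen Stable Diffusion 2 VAE with
# the 0.18215 latent scaling, decoded again, and de-normalized so that the mean
# absolute reconstruction error against the original disparity can be printed.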
131 | 132 | # define the vae 133 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_path, subfolder="vae") 134 | vae.requires_grad_(False) 135 | vae.cuda() 136 | print("Loaded the VAE pre-trained model successfully!") 137 | 138 | 139 | for idx, sample in enumerate(train_loader): 140 | left_img = sample['img_left'] 141 | right_img = sample['img_right'] 142 | left_disp_single = sample['gt_disp'] 143 | 144 | left_disp_single = left_disp_single.unsqueeze(0) 145 | left_disp = left_disp_single.repeat(1,3,1,1) 146 | 147 | resized_left_disp = resize_max_res_tensor(left_disp,is_disp=True) 148 | 149 | normaliazed_left_disp = Disparity_Normalization(resized_left_disp) 150 | 151 | normaliazed_left_disp = normaliazed_left_disp.cuda() 152 | 153 | with torch.no_grad(): 154 | latents = vae.encode(normaliazed_left_disp).latent_dist.sample() 155 | latents = latents * 0.18215 156 | 157 | 158 | # recovered image tensor back 159 | latents_recovered = 1 / 0.18215 * latents 160 | recovered_depth_normalized = vae.decode(latents_recovered).sample 161 | 162 | 163 | recovered_denoise = de_normalization(resized_left_disp.squeeze(0).permute(1,2,0).cpu().numpy(),recovered_depth_normalized.squeeze(0).permute(1,2,0).cpu().numpy()) 164 | 165 | print(np.mean(np.abs(recovered_denoise-resized_left_disp.squeeze(0).permute(1,2,0).cpu().numpy()))) 166 | 167 | 168 | # print((normaliazed_left_disp-recovered_depth_normalized).mean()) 169 | # print(normaliazed_left_disp.mean()) 170 | # print(recovered_depth_normalized.mean()) 171 | 172 | break 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /run/run_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import logging 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from tqdm.auto import tqdm 9 | 10 | import sys 11 | sys.path.append("../") 12 | from Inference.depth_pipeline import DepthEstimationPipeline 13 | from utils.seed_all import seed_all 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | from diffusers import ( 18 | DiffusionPipeline, 19 | DDIMScheduler, 20 | UNet2DConditionModel, 21 | AutoencoderKL, 22 | ) 23 | from transformers import CLIPTextModel, CLIPTokenizer 24 | 25 | 26 | 27 | 28 | if __name__=="__main__": 29 | 30 | use_seperate = True 31 | stable_diffusion_repo_path = "stabilityai/stable-diffusion-2" 32 | 33 | logging.basicConfig(level=logging.INFO) 34 | 35 | '''Set the Args''' 36 | parser = argparse.ArgumentParser( 37 | description="Run MonoDepth Estimation using Stable Diffusion." 38 | ) 39 | parser.add_argument( 40 | "--pretrained_model_path", 41 | type=str, 42 | default='None', 43 | help="pretrained model path from hugging face or local dir", 44 | ) 45 | 46 | 47 | parser.add_argument( 48 | "--input_rgb_path", 49 | type=str, 50 | required=True, 51 | help="Path to the input image.", 52 | ) 53 | 54 | parser.add_argument( 55 | "--output_dir", type=str, required=True, help="Output directory." 
56 | ) 57 | 58 | # inference setting 59 | parser.add_argument( 60 | "--denoise_steps", 61 | type=int, 62 | default=10, 63 | help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed.", 64 | ) 65 | parser.add_argument( 66 | "--ensemble_size", 67 | type=int, 68 | default=10, 69 | help="Number of predictions to be ensembled, more inference gives better results but runs slower.", 70 | ) 71 | parser.add_argument( 72 | "--half_precision", 73 | action="store_true", 74 | help="Run with half-precision (16-bit float), might lead to suboptimal result.", 75 | ) 76 | 77 | # resolution setting 78 | parser.add_argument( 79 | "--processing_res", 80 | type=int, 81 | default=768, 82 | help="Maximum resolution of processing. 0 for using input image resolution. Default: 768.", 83 | ) 84 | parser.add_argument( 85 | "--output_processing_res", 86 | action="store_true", 87 | help="When input is resized, out put depth at resized operating resolution. Default: False.", 88 | ) 89 | 90 | # depth map colormap 91 | parser.add_argument( 92 | "--color_map", 93 | type=str, 94 | default="Spectral", 95 | help="Colormap used to render depth predictions.", 96 | ) 97 | # other settings 98 | parser.add_argument("--seed", type=int, default=None, help="Random seed.") 99 | parser.add_argument( 100 | "--batch_size", 101 | type=int, 102 | default=0, 103 | help="Inference batch size. Default: 0 (will be set automatically).", 104 | ) 105 | 106 | args = parser.parse_args() 107 | 108 | checkpoint_path = args.pretrained_model_path 109 | input_image_path = args.input_rgb_path 110 | output_dir = args.output_dir 111 | denoise_steps = args.denoise_steps 112 | ensemble_size = args.ensemble_size 113 | 114 | if ensemble_size>15: 115 | logging.warning("long ensemble steps, low speed..") 116 | 117 | half_precision = args.half_precision 118 | 119 | processing_res = args.processing_res 120 | match_input_res = not args.output_processing_res 121 | 122 | color_map = args.color_map 123 | seed = args.seed 124 | batch_size = args.batch_size 125 | 126 | if batch_size==0: 127 | batch_size = 1 # set default batchsize 128 | 129 | # -------------------- Preparation -------------------- 130 | # Random seed 131 | if seed is None: 132 | import time 133 | 134 | seed = int(time.time()) 135 | seed_all(seed) 136 | 137 | # Output directories 138 | output_dir_color = os.path.join(output_dir, "depth_colored") 139 | output_dir_npy = os.path.join(output_dir, "depth_npy") 140 | os.makedirs(output_dir, exist_ok=True) 141 | os.makedirs(output_dir_color, exist_ok=True) 142 | os.makedirs(output_dir_npy, exist_ok=True) 143 | logging.info(f"output dir = {output_dir}") 144 | 145 | # -------------------- Device -------------------- 146 | if torch.cuda.is_available(): 147 | device = torch.device("cuda") 148 | else: 149 | device = torch.device("cpu") 150 | logging.warning("CUDA is not available. 
Running on CPU will be slow.") 151 | logging.info(f"device = {device}") 152 | 153 | 154 | # -------------------Data---------------------------- 155 | logging.info("Inference Image Path from {}".format(input_image_path)) 156 | 157 | # -------------------- Model -------------------- 158 | if half_precision: 159 | dtype = torch.float16 160 | logging.info(f"Running with half precision ({dtype}).") 161 | else: 162 | dtype = torch.float32 163 | 164 | # declare a pipeline 165 | # unet = UNet2DConditionModel.from_pretrained(checkpoint_path,subfolder='unet') 166 | 167 | 168 | if not use_seperate: 169 | pipe = DepthEstimationPipeline.from_pretrained(checkpoint_path, torch_dtype=dtype) 170 | print("Using Completed") 171 | else: 172 | 173 | vae = AutoencoderKL.from_pretrained(stable_diffusion_repo_path,subfolder='vae') 174 | scheduler = DDIMScheduler.from_pretrained(stable_diffusion_repo_path,subfolder='scheduler') 175 | text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_repo_path,subfolder='text_encoder') 176 | tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_repo_path,subfolder='tokenizer') 177 | 178 | # https://huggingface.co/docs/diffusers/training/adapt_a_model 179 | unet = UNet2DConditionModel.from_pretrained(checkpoint_path,subfolder="unet", 180 | in_channels=8, sample_size=96, 181 | low_cpu_mem_usage=False, 182 | ignore_mismatched_sizes=True) 183 | 184 | pipe = DepthEstimationPipeline(unet=unet, 185 | vae=vae, 186 | scheduler=scheduler, 187 | text_encoder=text_encoder, 188 | tokenizer=tokenizer) 189 | print("Using Seperated Modules") 190 | 191 | logging.info("loading pipeline whole successfully.") 192 | 193 | try: 194 | 195 | pipe.enable_xformers_memory_efficient_attention() 196 | except: 197 | pass # run without xformers 198 | 199 | pipe = pipe.to(device) 200 | 201 | # -------------------- Inference and saving -------------------- 202 | with torch.no_grad(): 203 | os.makedirs(output_dir, exist_ok=True) 204 | # load the example image. 
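# The pipeline takes the raw PIL image; conversion to RGB, resizing of the longer
# edge to processing_res and aggregation of the ensemble_size diffusion runs all
# happen inside DepthEstimationPipeline.__call__ (see Inference/depth_pipeline.py).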
205 | input_image_pil = Image.open(input_image_path) 206 | # input_image_pil.save("input_image.png") 207 | # predict the depth here 208 | pipe_out = pipe(input_image_pil, 209 | denosing_steps=denoise_steps, 210 | ensemble_size= ensemble_size, 211 | processing_res = processing_res, 212 | match_input_res = match_input_res, 213 | batch_size = batch_size, 214 | color_map = color_map, 215 | show_progress_bar = True, 216 | ) 217 | 218 | depth_pred: np.ndarray = pipe_out.depth_np 219 | depth_colored: Image.Image = pipe_out.depth_colored 220 | # depth_colored: np.ndarray = pipe_out.depth_colored 221 | 222 | 223 | 224 | # savd as npy 225 | rgb_name_base = os.path.splitext(os.path.basename(input_image_path))[0] 226 | pred_name_base = rgb_name_base + "_pred" 227 | 228 | 229 | npy_save_path = os.path.join(output_dir_npy, f"{pred_name_base}.npy") 230 | if os.path.exists(npy_save_path): 231 | logging.warning(f"Existing file: '{npy_save_path}' will be overwritten") 232 | np.save(npy_save_path, depth_pred) 233 | 234 | # Colorize 235 | colored_save_path = os.path.join( 236 | output_dir_color, f"{pred_name_base}_colored.png" 237 | ) 238 | if os.path.exists(colored_save_path): 239 | logging.warning( 240 | f"Existing file: '{colored_save_path}' will be overwritten" 241 | ) 242 | depth_colored.save(colored_save_path) 243 | -------------------------------------------------------------------------------- /scripts/inference.sh: -------------------------------------------------------------------------------- 1 | inference_single_image(){ 2 | input_rgb_path="/home/zliu/ECCV2024/Accelerator-Simple-Template/data_sample/kitti3d_000025.png" 3 | output_dir="outputs" 4 | pretrained_model_path="Bingxin/Marigold" # your checkpoint here 5 | ensemble_size=10 6 | 7 | cd .. 8 | cd run 9 | 10 | CUDA_VISIBLE_DEVICES=0 python run_inference.py \ 11 | --input_rgb_path $input_rgb_path \ 12 | --output_dir $output_dir \ 13 | --pretrained_model_path $pretrained_model_path \ 14 | --ensemble_size $ensemble_size 15 | } 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | LAUNCH_TRAINING(){ 2 | 3 | # accelerate config default 4 | cd .. 
5 | cd training 6 | pretrained_model_name_or_path='stabilityai/stable-diffusion-2' 7 | root_path='/data1/liu' 8 | dataset_name='sceneflow' 9 | trainlist='/home/zliu/Desktop/ECCV2024/code/Diffusion/sf_double_check/Accelerator-Simple-Template/datafiles/sceneflow/SceneFlow_With_Occ.list' 10 | vallist='/home/zliu/ECCV2024/Accelerator-Simple-Template/datafiles/sceneflow/FlyingThings3D_Test_With_Occ.list' 11 | output_dir='../outputs/sceneflow_fine_tune_hardest' 12 | train_batch_size=1 13 | num_train_epochs=10 14 | gradient_accumulation_steps=8 15 | learning_rate=1e-5 16 | lr_warmup_steps=0 17 | dataloader_num_workers=4 18 | tracker_project_name='sceneflow_pretrain_tracker_hardest' 19 | 20 | 21 | CUDA_VISIBLE_DEVICES=0,1 accelerate launch --mixed_precision="fp16" --multi_gpu depth2image_trainer.py \ 22 | --pretrained_model_name_or_path $pretrained_model_name_or_path \ 23 | --dataset_name $dataset_name --trainlist $trainlist \ 24 | --dataset_path $root_path --vallist $vallist \ 25 | --output_dir $output_dir \ 26 | --train_batch_size $train_batch_size \ 27 | --num_train_epochs $num_train_epochs \ 28 | --gradient_accumulation_steps $gradient_accumulation_steps\ 29 | --gradient_checkpointing \ 30 | --learning_rate $learning_rate \ 31 | --lr_warmup_steps $lr_warmup_steps \ 32 | --dataloader_num_workers $dataloader_num_workers \ 33 | --tracker_project_name $tracker_project_name \ 34 | --gradient_checkpointing \ 35 | --enable_xformers_memory_efficient_attention \ 36 | 37 | } 38 | 39 | 40 | 41 | LAUNCH_TRAINING 42 | -------------------------------------------------------------------------------- /training/dataset_configuration.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import sys 6 | sys.path.append("..") 7 | 8 | from dataloader.sceneflow_loader import StereoDataset 9 | from torch.utils.data import DataLoader 10 | from dataloader import transforms 11 | import os 12 | 13 | 14 | # Get Dataset Here 15 | def prepare_dataset(data_name, 16 | datapath=None, 17 | trainlist=None, 18 | vallist=None, 19 | batch_size=1, 20 | test_batch=1, 21 | datathread=4, 22 | logger=None): 23 | 24 | # set the config parameters 25 | dataset_config_dict = dict() 26 | 27 | if data_name == 'sceneflow': 28 | train_transform_list = [ 29 | transforms.ToTensor(), 30 | # transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 31 | ] 32 | train_transform = transforms.Compose(train_transform_list) 33 | 34 | val_transform_list = [transforms.ToTensor(), 35 | # transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 36 | ] 37 | val_transform = transforms.Compose(val_transform_list) 38 | 39 | train_dataset = StereoDataset(data_dir=datapath,train_datalist=trainlist,test_datalist=vallist, 40 | dataset_name='SceneFlow',mode='train',transform=train_transform) 41 | test_dataset = StereoDataset(data_dir=datapath,train_datalist=trainlist,test_datalist=vallist, 42 | dataset_name='SceneFlow',mode='val',transform=val_transform) 43 | 44 | img_height, img_width = train_dataset.get_img_size() 45 | 46 | 47 | datathread=4 48 | if os.environ.get('datathread') is not None: 49 | datathread = int(os.environ.get('datathread')) 50 | 51 | if logger is not None: 52 | logger.info("Use %d processes to load data..." 
% datathread) 53 | 54 | train_loader = DataLoader(train_dataset, batch_size = batch_size, \ 55 | shuffle = True, num_workers = datathread, \ 56 | pin_memory = True) 57 | 58 | test_loader = DataLoader(test_dataset, batch_size = test_batch, \ 59 | shuffle = False, num_workers = datathread, \ 60 | pin_memory = True) 61 | 62 | num_batches_per_epoch = len(train_loader) 63 | 64 | 65 | dataset_config_dict['num_batches_per_epoch'] = num_batches_per_epoch 66 | dataset_config_dict['img_size'] = (img_height,img_width) 67 | 68 | 69 | return (train_loader,test_loader),dataset_config_dict 70 | 71 | def Disparity_Normalization(disparity): 72 | min_value = torch.min(disparity) 73 | max_value = torch.max(disparity) 74 | normalized_disparity = ((disparity -min_value)/(max_value-min_value+1e-5) - 0.5) * 2 75 | return normalized_disparity 76 | 77 | def resize_max_res_tensor(input_tensor,is_disp=False,recom_resolution=768): 78 | assert input_tensor.shape[1]==3 79 | original_H, original_W = input_tensor.shape[2:] 80 | 81 | downscale_factor = min(recom_resolution/original_H, 82 | recom_resolution/original_W) 83 | 84 | resized_input_tensor = F.interpolate(input_tensor, 85 | scale_factor=downscale_factor,mode='bilinear', 86 | align_corners=False) 87 | 88 | if is_disp: 89 | return resized_input_tensor * downscale_factor 90 | else: 91 | return resized_input_tensor -------------------------------------------------------------------------------- /training/depth2image_trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import math 4 | import random 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.utils.checkpoint 10 | 11 | import os 12 | import logging 13 | import tqdm 14 | 15 | from accelerate import Accelerator 16 | import transformers 17 | import datasets 18 | import numpy as np 19 | from accelerate.logging import get_logger 20 | from accelerate.utils import set_seed 21 | from accelerate.state import AcceleratorState 22 | from accelerate.utils import ProjectConfiguration, set_seed 23 | import shutil 24 | 25 | 26 | import diffusers 27 | from diffusers import ( 28 | DiffusionPipeline, 29 | DDIMScheduler, 30 | UNet2DConditionModel, 31 | AutoencoderKL, 32 | ) 33 | 34 | from diffusers.optimization import get_scheduler 35 | from diffusers.training_utils import EMAModel, compute_snr 36 | from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid 37 | from diffusers.utils.import_utils import is_xformers_available 38 | from diffusers.utils.torch_utils import is_compiled_module 39 | 40 | 41 | from packaging import version 42 | from torchvision import transforms 43 | from tqdm.auto import tqdm 44 | from transformers import CLIPTextModel, CLIPTokenizer 45 | from transformers.utils import ContextManagers 46 | import accelerate 47 | 48 | import sys 49 | sys.path.append("..") 50 | from training.dataset_configuration import prepare_dataset,Disparity_Normalization,resize_max_res_tensor 51 | 52 | from Inference.depth_pipeline_half import DepthEstimationPipeline 53 | # from Inference.depth_pipeline import DepthEstimationPipeline 54 | from PIL import Image 55 | 56 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
57 | check_min_version("0.26.0.dev0") 58 | 59 | logger = get_logger(__name__, log_level="INFO") 60 | 61 | 62 | 63 | def log_validation(vae,text_encoder,tokenizer,unet,args,accelerator,weight_dtype,scheduler,epoch, 64 | input_image_path="/data1/liu/sceneflow/frames_cleanpass/flythings3d/TEST/A/0000/left/0006.png" 65 | ): 66 | 67 | denoise_steps = 10 68 | ensemble_size = 10 69 | processing_res = 768 70 | match_input_res = True 71 | batch_size = 1 72 | color_map="Spectral" 73 | 74 | 75 | logger.info("Running validation ... ") 76 | pipeline = DepthEstimationPipeline.from_pretrained(pretrained_model_name_or_path=args.pretrained_model_name_or_path, 77 | vae=accelerator.unwrap_model(vae), 78 | text_encoder=accelerator.unwrap_model(text_encoder), 79 | tokenizer=tokenizer, 80 | unet = accelerator.unwrap_model(unet), 81 | safety_checker=None, 82 | scheduler = accelerator.unwrap_model(scheduler), 83 | ) 84 | 85 | pipeline = pipeline.to(accelerator.device) 86 | try: 87 | pipeline.enable_xformers_memory_efficient_attention() 88 | except: 89 | pass 90 | 91 | # -------------------- Inference and saving -------------------- 92 | with torch.no_grad(): 93 | input_image_pil = Image.open(input_image_path) 94 | 95 | pipe_out = pipeline(input_image_pil, 96 | denosing_steps=denoise_steps, 97 | ensemble_size= ensemble_size, 98 | processing_res = processing_res, 99 | match_input_res = match_input_res, 100 | batch_size = batch_size, 101 | color_map = color_map, 102 | show_progress_bar = True, 103 | ) 104 | 105 | depth_pred: np.ndarray = pipe_out.depth_np 106 | depth_colored: Image.Image = pipe_out.depth_colored 107 | 108 | # savd as npy 109 | rgb_name_base = os.path.splitext(os.path.basename(input_image_path))[0] 110 | pred_name_base = rgb_name_base + "_pred" 111 | 112 | npy_save_path = os.path.join(args.output_dir, f"{pred_name_base}.npy") 113 | if os.path.exists(npy_save_path): 114 | logging.warning(f"Existing file: '{npy_save_path}' will be overwritten") 115 | np.save(npy_save_path, depth_pred) 116 | 117 | # Colorize 118 | colored_save_path = os.path.join( 119 | args.output_dir, f"{pred_name_base}_{epoch}_colored.png" 120 | ) 121 | if os.path.exists(colored_save_path): 122 | logging.warning( 123 | f"Existing file: '{colored_save_path}' will be overwritten" 124 | ) 125 | depth_colored.save(colored_save_path) 126 | 127 | 128 | del depth_colored 129 | del pipeline 130 | torch.cuda.empty_cache() 131 | 132 | 133 | 134 | 135 | 136 | def parse_args(): 137 | parser = argparse.ArgumentParser(description="Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation") 138 | 139 | parser.add_argument( 140 | "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." 
141 | ) 142 | 143 | parser.add_argument( 144 | "--pretrained_model_name_or_path", 145 | type=str, 146 | default=None, 147 | required=True, 148 | help="Path to pretrained model or model identifier from huggingface.co/models.", 149 | ) 150 | 151 | parser.add_argument( 152 | "--dataset_name", 153 | type=str, 154 | default="sceneflow", 155 | required=True, 156 | help="Specify the dataset name used for training/validation.", 157 | ) 158 | parser.add_argument( 159 | "--dataset_path", 160 | type=str, 161 | default="/data1/liu", 162 | required=True, 163 | help="The Root Dataset Path.", 164 | ) 165 | parser.add_argument( 166 | "--trainlist", 167 | type=str, 168 | default="/home/zliu/ECCV2024/Accelerator-Simple-Template/datafiles/sceneflow/SceneFlow_With_Occ.list", 169 | required=True, 170 | help="train file listing the training files", 171 | ) 172 | parser.add_argument( 173 | "--vallist", 174 | type=str, 175 | default="/home/zliu/ECCV2024/Accelerator-Simple-Template/datafiles/sceneflow/FlyingThings3D_Test_With_Occ.list", 176 | required=True, 177 | help="validation file listing the validation files", 178 | ) 179 | 180 | parser.add_argument( 181 | "--max_train_samples", 182 | type=int, 183 | default=None, 184 | help=( 185 | "For debugging purposes or quicker training, truncate the number of training examples to this " 186 | "value if set." 187 | ), 188 | ) 189 | parser.add_argument( 190 | "--output_dir", 191 | type=str, 192 | default="saved_models", 193 | help="The output directory where the model predictions and checkpoints will be written.", 194 | ) 195 | parser.add_argument( 196 | "--cache_dir", 197 | type=str, 198 | default=None, 199 | help="The directory where the downloaded models and datasets will be stored.", 200 | ) 201 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 202 | 203 | parser.add_argument( 204 | "--recom_resolution", 205 | type=int, 206 | default=768, 207 | help=( 208 | "The resolution for resizeing the input images and the depth/disparity to make full use of the pre-trained model from \ 209 | from the stable diffusion vae, for common cases, do not change this parameter" 210 | ), 211 | ) 212 | #TODO : Data Augmentation 213 | parser.add_argument( 214 | "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." 215 | ) 216 | parser.add_argument("--num_train_epochs", type=int, default=70) 217 | 218 | parser.add_argument( 219 | "--max_train_steps", 220 | type=int, 221 | default=None, 222 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 223 | ) 224 | parser.add_argument( 225 | "--gradient_accumulation_steps", 226 | type=int, 227 | default=1, 228 | help="Number of updates steps to accumulate before performing a backward/update pass.", 229 | ) 230 | parser.add_argument( 231 | "--gradient_checkpointing", 232 | action="store_true", 233 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 234 | ) 235 | parser.add_argument( 236 | "--learning_rate", 237 | type=float, 238 | default=1e-4, 239 | help="Initial learning rate (after the potential warmup period) to use.", 240 | ) 241 | parser.add_argument( 242 | "--scale_lr", 243 | action="store_true", 244 | default=False, 245 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 246 | ) 247 | parser.add_argument( 248 | "--lr_scheduler", 249 | type=str, 250 | default="constant", 251 | help=( 252 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 253 | ' "constant", "constant_with_warmup"]' 254 | ), 255 | ) 256 | parser.add_argument( 257 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 258 | ) 259 | parser.add_argument( 260 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 261 | ) 262 | 263 | # using EMA for improving the generalization 264 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") 265 | 266 | # dataloaderes 267 | parser.add_argument( 268 | "--dataloader_num_workers", 269 | type=int, 270 | default=0, 271 | help=( 272 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 273 | ), 274 | ) 275 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 276 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 277 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 278 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 279 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 280 | 281 | parser.add_argument( 282 | "--prediction_type", 283 | type=str, 284 | default=None, 285 | help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", 286 | ) 287 | 288 | parser.add_argument( 289 | "--logging_dir", 290 | type=str, 291 | default="logs", 292 | help=( 293 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 294 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 295 | ), 296 | ) 297 | parser.add_argument( 298 | "--mixed_precision", 299 | type=str, 300 | default=None, 301 | choices=["no", "fp16", "bf16"], 302 | help=( 303 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 304 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 305 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 
306 | ), 307 | ) 308 | parser.add_argument( 309 | "--report_to", 310 | type=str, 311 | default="tensorboard", 312 | help=( 313 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 314 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 315 | ), 316 | ) 317 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 318 | 319 | # how many steps csave a checkpoints 320 | parser.add_argument( 321 | "--checkpointing_steps", 322 | type=int, 323 | default=10000, 324 | help=( 325 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 326 | " training using `--resume_from_checkpoint`." 327 | ), 328 | ) 329 | 330 | parser.add_argument( 331 | "--checkpoints_total_limit", 332 | type=int, 333 | default=None, 334 | help=("Max number of checkpoints to store."), 335 | ) 336 | 337 | parser.add_argument( 338 | "--resume_from_checkpoint", 339 | type=str, 340 | default=None, 341 | help=( 342 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 343 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 344 | ), 345 | ) 346 | # using xformers for efficient training 347 | parser.add_argument( 348 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 349 | ) 350 | 351 | # noise offset?::: #TODO HERE 352 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") 353 | 354 | # validations every 5 Epochs 355 | parser.add_argument( 356 | "--validation_epochs", 357 | type=int, 358 | default=5, 359 | help="Run validation every X epochs.", 360 | ) 361 | 362 | parser.add_argument( 363 | "--tracker_project_name", 364 | type=str, 365 | default="text2image-fine-tune", 366 | help=( 367 | "The `project_name` argument passed to Accelerator.init_trackers for" 368 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 369 | ), 370 | ) 371 | 372 | # get the local rank 373 | args = parser.parse_args() 374 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 375 | 376 | 377 | if env_local_rank != -1 and env_local_rank != args.local_rank: 378 | args.local_rank = env_local_rank 379 | 380 | # Sanity checks 381 | if args.dataset_name is None and args.dataset_path is None: 382 | raise ValueError("Need either a dataset name or a DataPath.") 383 | 384 | return args 385 | 386 | 387 | def main(): 388 | 389 | ''' ------------------------Configs Preparation----------------------------''' 390 | # give the args parsers 391 | args = parse_args() 392 | # save the tensorboard log files 393 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 394 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 395 | 396 | # tell the gradient_accumulation_steps, mix precison, and tensorboard 397 | accelerator = Accelerator( 398 | gradient_accumulation_steps=args.gradient_accumulation_steps, 399 | mixed_precision=args.mixed_precision, 400 | log_with=args.report_to, 401 | project_config=accelerator_project_config, 402 | ) 403 | 404 | # Make one log on every process with the configuration for debugging. 
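    # (The basicConfig call below applies to every rank; the per-library verbosity
    #  settings a few lines further down then keep datasets/transformers/diffusers at
    #  warning/info level on the local main process and error-only on the other ranks,
    #  so multi-GPU runs do not produce duplicated log output.)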
405 | logging.basicConfig( 406 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 407 | datefmt="%m/%d/%Y %H:%M:%S", 408 | level=logging.INFO, 409 | ) 410 | logger.info(accelerator.state, main_process_only=True) # only the main process show the logs 411 | # set the warning levels 412 | if accelerator.is_local_main_process: 413 | datasets.utils.logging.set_verbosity_warning() 414 | transformers.utils.logging.set_verbosity_warning() 415 | diffusers.utils.logging.set_verbosity_info() 416 | else: 417 | datasets.utils.logging.set_verbosity_error() 418 | transformers.utils.logging.set_verbosity_error() 419 | diffusers.utils.logging.set_verbosity_error() 420 | 421 | # If passed along, set the training seed now. 422 | if args.seed is not None: 423 | set_seed(args.seed) 424 | 425 | # Doing I/O at the main proecss 426 | if accelerator.is_main_process: 427 | if args.output_dir is not None: 428 | os.makedirs(args.output_dir, exist_ok=True) 429 | 430 | ''' ------------------------Non-NN Modules Definition----------------------------''' 431 | noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path,subfolder='scheduler') 432 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path,subfolder='tokenizer') 433 | logger.info("loading the noise scheduler and the tokenizer from {}".format(args.pretrained_model_name_or_path),main_process_only=True) 434 | 435 | def deepspeed_zero_init_disabled_context_manager(): 436 | """ 437 | returns either a context list that includes one that will disable zero.Init or an empty context list 438 | """ 439 | deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None 440 | if deepspeed_plugin is None: 441 | return [] 442 | 443 | return [deepspeed_plugin.zero3_init_context_manager(enable=False)] 444 | 445 | with ContextManagers(deepspeed_zero_init_disabled_context_manager()): 446 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, 447 | subfolder='vae') 448 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, 449 | subfolder='text_encoder') 450 | 451 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,subfolder="unet", 452 | in_channels=8, sample_size=96, 453 | low_cpu_mem_usage=False, 454 | ignore_mismatched_sizes=True) 455 | 456 | # Freeze vae and text_encoder and set unet to trainable. 457 | vae.requires_grad_(False) 458 | text_encoder.requires_grad_(False) 459 | unet.train() # only make the unet-trainable 460 | 461 | # using EMA 462 | if args.use_ema: 463 | ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,subfolder="unet", 464 | in_channels=8, sample_size=96, 465 | low_cpu_mem_usage=False, 466 | ignore_mismatched_sizes=True) 467 | ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) 468 | 469 | 470 | # using xformers for efficient attentions. 471 | if args.enable_xformers_memory_efficient_attention: 472 | if is_xformers_available(): 473 | import xformers 474 | xformers_version = version.parse(xformers.__version__) 475 | if xformers_version == version.parse("0.0.16"): 476 | logger.warn( 477 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 
478 | ) 479 | unet.enable_xformers_memory_efficient_attention() 480 | else: 481 | raise ValueError("xformers is not available. Make sure it is installed correctly") 482 | 483 | # `accelerate` 0.16.0 will have better support for customized saving 484 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 485 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 486 | def save_model_hook(models, weights, output_dir): 487 | if accelerator.is_main_process: 488 | if args.use_ema: 489 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) 490 | for i, model in enumerate(models): 491 | model.save_pretrained(os.path.join(output_dir, "unet")) 492 | # make sure to pop weight so that corresponding model is not saved again 493 | weights.pop() 494 | 495 | def load_model_hook(models, input_dir): 496 | if args.use_ema: 497 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) 498 | ema_unet.load_state_dict(load_model.state_dict()) 499 | ema_unet.to(accelerator.device) 500 | del load_model 501 | 502 | for i in range(len(models)): 503 | # pop models so that they are not loaded again 504 | model = models.pop() 505 | # load diffusers style into model 506 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") 507 | model.register_to_config(**load_model.config) 508 | model.load_state_dict(load_model.state_dict()) 509 | del load_model 510 | 511 | accelerator.register_save_state_pre_hook(save_model_hook) 512 | accelerator.register_load_state_pre_hook(load_model_hook) 513 | 514 | 515 | # using checkpint for saving the memories 516 | if args.gradient_checkpointing: 517 | unet.enable_gradient_checkpointing() 518 | 519 | # how many cards did we use: accelerator.num_processes 520 | if args.scale_lr: 521 | args.learning_rate = ( 522 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 523 | ) 524 | 525 | # Initialize the optimizer 526 | if args.use_8bit_adam: 527 | try: 528 | import bitsandbytes as bnb 529 | except ImportError: 530 | raise ImportError( 531 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" 532 | ) 533 | 534 | optimizer_cls = bnb.optim.AdamW8bit 535 | else: 536 | optimizer_cls = torch.optim.AdamW 537 | 538 | # optimizer settings 539 | optimizer = optimizer_cls( 540 | unet.parameters(), 541 | lr=args.learning_rate, 542 | betas=(args.adam_beta1, args.adam_beta2), 543 | weight_decay=args.adam_weight_decay, 544 | eps=args.adam_epsilon, 545 | ) 546 | with accelerator.main_process_first(): 547 | (train_loader,test_loader), dataset_config_dict = prepare_dataset(data_name=args.dataset_name, 548 | datapath=args.dataset_path, 549 | trainlist=args.trainlist, 550 | vallist=args.vallist,batch_size=args.train_batch_size, 551 | test_batch=1, 552 | datathread=args.dataloader_num_workers, 553 | logger=logger) 554 | 555 | # because the optimizer not optimized every time, so we need to calculate how many steps it optimizes, 556 | # it is usually optimized by 557 | # Scheduler and math around the number of training steps. 
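    # (Illustrative arithmetic only, with made-up counts: if the train list yields
    #  40,000 batches per epoch at train_batch_size=1 and gradient_accumulation_steps=8,
    #  then num_update_steps_per_epoch = ceil(40000 / 8) = 5000 optimizer updates, and
    #  num_train_epochs=70 gives max_train_steps = 70 * 5000 = 350,000 unless
    #  --max_train_steps is passed explicitly.)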
558 |     overrode_max_train_steps = False
559 |     num_update_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
560 |     if args.max_train_steps is None:
561 |         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
562 |         overrode_max_train_steps = True
563 | 
564 |     lr_scheduler = get_scheduler(
565 |         args.lr_scheduler,
566 |         optimizer=optimizer,
567 |         num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
568 |         num_training_steps=args.max_train_steps * accelerator.num_processes,
569 |     )
570 | 
571 |     # Prepare everything with our `accelerator`.
572 |     unet, optimizer, train_loader, test_loader, lr_scheduler = accelerator.prepare(
573 |         unet, optimizer, train_loader, test_loader, lr_scheduler
574 |     )
575 | 
576 |     # VAE latent scale factors (the same value Stable Diffusion uses).
577 |     rgb_latent_scale_factor = 0.18215
578 |     depth_latent_scale_factor = 0.18215
579 | 
580 | 
581 |     # For mixed precision training we cast the frozen, non-trainable weights (vae and text_encoder) to half precision,
582 |     # as these weights are only used for inference; keeping them in full precision is not required.
583 |     weight_dtype = torch.float32
584 |     if accelerator.mixed_precision == "fp16":
585 |         weight_dtype = torch.float16
586 |         args.mixed_precision = accelerator.mixed_precision
587 |     elif accelerator.mixed_precision == "bf16":
588 |         weight_dtype = torch.bfloat16
589 |         args.mixed_precision = accelerator.mixed_precision
590 | 
591 |     # Move text_encoder and vae to the GPU and cast them to weight_dtype
592 |     text_encoder.to(accelerator.device, dtype=weight_dtype)
593 |     vae.to(accelerator.device, dtype=weight_dtype)
594 | 
595 | 
596 |     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
597 |     num_update_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
598 |     if overrode_max_train_steps:
599 |         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
600 |     # Afterwards we recalculate our number of training epochs
601 |     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
602 | 
603 | 
604 |     # We need to initialize the trackers we use, and also store our configuration.
605 |     # The trackers are initialized automatically on the main process.
606 |     if accelerator.is_main_process:
607 |         tracker_config = dict(vars(args))
608 |         accelerator.init_trackers(args.tracker_project_name, tracker_config)
609 | 
610 | 
611 |     # Total batch size under DDP = per-device batch size * num processes * gradient accumulation (e.g. 4 when running one sample per device on 4 GPUs).
612 |     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
613 | 
614 |     logger.info("***** Running training *****")
615 |     logger.info(f"  Num batches per epoch = {len(train_loader)}")
616 |     logger.info(f"  Num Epochs = {args.num_train_epochs}")
617 |     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
618 |     logger.info(f"  Total train batch size (w.
parallel, distributed & accumulation) = {total_batch_size}") 619 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 620 | logger.info(f" Total optimization steps = {args.max_train_steps}") 621 | global_step = 0 622 | first_epoch = 0 623 | 624 | # Potentially load in the weights and states from a previous save 625 | if args.resume_from_checkpoint: 626 | if args.resume_from_checkpoint != "latest": 627 | path = os.path.basename(args.resume_from_checkpoint) 628 | else: 629 | # Get the most recent checkpoint 630 | dirs = os.listdir(args.output_dir) 631 | dirs = [d for d in dirs if d.startswith("checkpoint")] 632 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 633 | path = dirs[-1] if len(dirs) > 0 else None 634 | 635 | if path is None: 636 | accelerator.print( 637 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 638 | ) 639 | args.resume_from_checkpoint = None 640 | initial_global_step = 0 641 | else: 642 | accelerator.print(f"Resuming from checkpoint {path}") 643 | accelerator.load_state(os.path.join(args.output_dir, path)) 644 | global_step = int(path.split("-")[1]) 645 | 646 | initial_global_step = global_step 647 | first_epoch = global_step // num_update_steps_per_epoch 648 | 649 | else: 650 | initial_global_step = 0 651 | 652 | progress_bar = tqdm( 653 | range(0, args.max_train_steps), 654 | initial=initial_global_step, 655 | desc="Steps", 656 | # Only show the progress bar once on each machine. 657 | disable=not accelerator.is_local_main_process, 658 | ) 659 | 660 | if accelerator.is_main_process: 661 | unet.eval() 662 | log_validation( 663 | vae=vae, 664 | text_encoder=text_encoder, 665 | tokenizer=tokenizer, 666 | unet=unet, 667 | args=args, 668 | accelerator=accelerator, 669 | weight_dtype=weight_dtype, 670 | scheduler=noise_scheduler, 671 | epoch=0, 672 | input_image_path="/data1/liu/sceneflow/frames_cleanpass/flythings3d/TEST/A/0000/left/0006.png") 673 | 674 | 675 | 676 | # using the epochs to training the model 677 | for epoch in range(first_epoch, args.num_train_epochs): 678 | unet.train() 679 | train_loss = 0.0 680 | for step, batch in enumerate(train_loader): 681 | with accelerator.accumulate(unet): 682 | # convert the images and the depths into lantent space. 683 | left_image_data = batch['img_left'] 684 | left_disparity = batch['gt_disp'] 685 | 686 | left_disp_single = left_disparity.unsqueeze(0) 687 | left_disparity_stacked = left_disp_single.repeat(1,3,1,1) 688 | left_image_data_resized = resize_max_res_tensor(left_image_data,is_disp=False) #range in (0-1) 689 | 690 | left_disparity_resized = resize_max_res_tensor(left_disparity_stacked,is_disp=True) # not range 691 | # depth normalization: [([1, 3, 432, 768])] 692 | left_disparity_resized_normalized = Disparity_Normalization(left_disparity_resized) 693 | 694 | # convert images and the disparity into latent space. 
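                # (Both encodings below follow the same recipe: run the frozen SD VAE
                #  encoder, split the predicted moments into mean and log-variance, and
                #  keep only the mean, scaled by 0.18215 -- a deterministic latent rather
                #  than a sample. A minimal sketch of that shared step, assuming the
                #  diffusers AutoencoderKL layout used here, would be:
                #
                #      def encode_to_latent(vae, x, scale=0.18215):
                #          moments = vae.quant_conv(vae.encoder(x))    # [B, 8, h/8, w/8]
                #          mean, _logvar = torch.chunk(moments, 2, dim=1)
                #          return mean * scale                         # [B, 4, h/8, w/8]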
695 | 
696 |                 # encode RGB to latents
697 |                 h_rgb = vae.encoder(left_image_data_resized.to(weight_dtype))
698 |                 moments_rgb = vae.quant_conv(h_rgb)
699 |                 mean_rgb, logvar_rgb = torch.chunk(moments_rgb, 2, dim=1)
700 |                 rgb_latents = mean_rgb * rgb_latent_scale_factor  # torch.Size([1, 4, 54, 96])
701 | 
702 |                 # encode disparity to latents
703 |                 h_disp = vae.encoder(left_disparity_resized_normalized.to(weight_dtype))
704 |                 moments_disp = vae.quant_conv(h_disp)
705 |                 mean_disp, logvar_disp = torch.chunk(moments_disp, 2, dim=1)
706 |                 disp_latents = mean_disp * depth_latent_scale_factor
707 | 
708 |                 # Sample noise that we'll add to the latents
709 |                 noise = torch.randn_like(disp_latents)  # create noise
710 |                 # batch size (1 in our default setting)
711 |                 bsz = disp_latents.shape[0]
712 | 
713 |                 # Stable Diffusion uses 1000 training timesteps for adding noise and denoising.
714 |                 # Sample a random timestep for each image
715 |                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=disp_latents.device)
716 |                 timesteps = timesteps.long()
717 | 
718 |                 # add noise to the depth latents
719 |                 noisy_disp_latents = noise_scheduler.add_noise(disp_latents, noise, timesteps)
720 | 
721 |                 # Encode text embedding for empty prompt
722 |                 prompt = ""
723 |                 text_inputs = tokenizer(
724 |                     prompt,
725 |                     padding="do_not_pad",
726 |                     max_length=tokenizer.model_max_length,
727 |                     truncation=True,
728 |                     return_tensors="pt",
729 |                 )
730 |                 text_input_ids = text_inputs.input_ids.to(text_encoder.device)  # [1, 2]
731 |                 # print(text_input_ids.shape)
732 |                 empty_text_embed = text_encoder(text_input_ids)[0].to(weight_dtype)
733 | 
734 | 
735 |                 # Get the target for the loss depending on the prediction type
736 |                 if args.prediction_type is not None:
737 |                     # set prediction_type of scheduler if defined
738 |                     noise_scheduler.register_to_config(prediction_type=args.prediction_type)
739 |                 if noise_scheduler.config.prediction_type == "epsilon":
740 |                     target = noise
741 |                 elif noise_scheduler.config.prediction_type == "v_prediction":
742 |                     target = noise_scheduler.get_velocity(disp_latents, noise, timesteps)
743 |                 else:
744 |                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
745 | 
746 |                 batch_empty_text_embed = empty_text_embed.repeat((noisy_disp_latents.shape[0], 1, 1))  # [B, 2, 1024]
747 | 
748 |                 # predict the noise residual and compute the loss.
749 |                 unet_input = torch.cat([rgb_latents, noisy_disp_latents], dim=1)  # this order is important: [1, 8, H, W]
750 | 
751 |                 # predict the noise residual
752 |                 noise_pred = unet(unet_input,
753 |                                   timesteps,
754 |                                   encoder_hidden_states=batch_empty_text_embed).sample  # [B, 4, h, w]
755 | 
756 |                 # loss function
757 |                 loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
758 | 
759 |                 # Gather the losses across all processes for logging (if we use distributed training).
760 |                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
761 |                 train_loss += avg_loss.item() / args.gradient_accumulation_steps
762 | 
763 | 
764 |                 # Backpropagate
765 |                 accelerator.backward(loss)
766 |                 if accelerator.sync_gradients:
767 |                     accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
768 | 
769 |                 optimizer.step()
770 |                 lr_scheduler.step()
771 |                 optimizer.zero_grad()
772 | 
773 | 
774 |             # EMA update and bookkeeping (not used in the current runs; the EMA branch only runs when --use_ema is set).
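            # (accelerator.sync_gradients is True only on the iterations where the
            #  accumulated gradients were actually applied, so the EMA weights, the
            #  progress bar, the logged train_loss and the checkpoint counter all
            #  advance once per optimizer step rather than once per micro-batch.
            #  EMAModel.step() roughly performs ema_p = decay * ema_p + (1 - decay) * p
            #  for every parameter, with a decay that warms up over the first updates.)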
775 | if accelerator.sync_gradients: 776 | if args.use_ema: 777 | ema_unet.step(unet.parameters()) 778 | progress_bar.update(1) 779 | global_step += 1 780 | accelerator.log({"train_loss": train_loss}, step=global_step) 781 | train_loss = 0.0 782 | 783 | # saving the checkpoints 784 | if global_step % args.checkpointing_steps == 0: 785 | if accelerator.is_main_process: 786 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 787 | if args.checkpoints_total_limit is not None: 788 | checkpoints = os.listdir(args.output_dir) 789 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 790 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 791 | 792 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 793 | if len(checkpoints) >= args.checkpoints_total_limit: 794 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 795 | removing_checkpoints = checkpoints[0:num_to_remove] 796 | 797 | logger.info( 798 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 799 | ) 800 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 801 | 802 | for removing_checkpoint in removing_checkpoints: 803 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 804 | shutil.rmtree(removing_checkpoint) 805 | 806 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 807 | accelerator.save_state(save_path) 808 | logger.info(f"Saved state to {save_path}") 809 | 810 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 811 | progress_bar.set_postfix(**logs) 812 | 813 | # Stop training 814 | if global_step >= args.max_train_steps: 815 | break 816 | 817 | 818 | 819 | if accelerator.is_main_process: 820 | # validation each epoch by calculate the epe and the visualization depth 821 | if args.use_ema: 822 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 823 | ema_unet.store(unet.parameters()) 824 | ema_unet.copy_to(unet.parameters()) 825 | 826 | # validation inference here 827 | log_validation( 828 | vae=vae, 829 | text_encoder=text_encoder, 830 | tokenizer=tokenizer, 831 | unet=unet, 832 | args=args, 833 | accelerator=accelerator, 834 | weight_dtype=weight_dtype, 835 | scheduler=noise_scheduler, 836 | epoch=epoch, 837 | input_image_path="/data1/liu/sceneflow/frames_cleanpass/flythings3d/TEST/A/0000/left/0006.png" 838 | ) 839 | 840 | if args.use_ema: 841 | # Switch back to the original UNet parameters. 
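                # (ema_unet.store() stashed the live UNet parameters and copy_to()
                #  overwrote them before log_validation, so the validation above ran
                #  with the EMA weights; restore() below puts the stashed parameters
                #  back and optimization continues from the non-averaged weights.)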
842 | ema_unet.restore(unet.parameters()) 843 | 844 | 845 | 846 | 847 | 848 | 849 | # Create the pipeline for training and savet 850 | accelerator.wait_for_everyone() 851 | accelerator.end_training() 852 | 853 | 854 | 855 | 856 | 857 | 858 | 859 | 860 | if __name__=="__main__": 861 | main() 862 | -------------------------------------------------------------------------------- /utils/__pycache__/colormap.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Magicboomliu/Accelerator-Simple-Template/92a82436df60ca4c9402e729e5175f60439c027e/utils/__pycache__/colormap.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/common.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Magicboomliu/Accelerator-Simple-Template/92a82436df60ca4c9402e729e5175f60439c027e/utils/__pycache__/common.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/de_normalized.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Magicboomliu/Accelerator-Simple-Template/92a82436df60ca4c9402e729e5175f60439c027e/utils/__pycache__/de_normalized.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/depth_ensemble.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Magicboomliu/Accelerator-Simple-Template/92a82436df60ca4c9402e729e5175f60439c027e/utils/__pycache__/depth_ensemble.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/image_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Magicboomliu/Accelerator-Simple-Template/92a82436df60ca4c9402e729e5175f60439c027e/utils/__pycache__/image_util.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/seed_all.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Magicboomliu/Accelerator-Simple-Template/92a82436df60ca4c9402e729e5175f60439c027e/utils/__pycache__/seed_all.cpython-39.pyc -------------------------------------------------------------------------------- /utils/batch_size.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import math 4 | 5 | 6 | # Search table for suggested max. 
inference batch size 7 | bs_search_table = [ 8 | # tested on A100-PCIE-80GB 9 | {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32}, 10 | {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32}, 11 | # tested on A100-PCIE-40GB 12 | {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32}, 13 | {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32}, 14 | {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16}, 15 | {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16}, 16 | # tested on RTX3090, RTX4090 17 | {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32}, 18 | {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32}, 19 | {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32}, 20 | {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16}, 21 | {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16}, 22 | {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16}, 23 | # tested on GTX1080Ti 24 | {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32}, 25 | {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32}, 26 | {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16}, 27 | {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16}, 28 | {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16}, 29 | ] 30 | 31 | 32 | def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int: 33 | """ 34 | Automatically search for suitable operating batch size. 35 | 36 | Args: 37 | ensemble_size (`int`): 38 | Number of predictions to be ensembled. 39 | input_res (`int`): 40 | Operating resolution of the input image. 41 | 42 | Returns: 43 | `int`: Operating batch size. 44 | """ 45 | if not torch.cuda.is_available(): 46 | return 1 47 | 48 | total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 49 | filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype] 50 | for settings in sorted( 51 | filtered_bs_search_table, 52 | key=lambda k: (k["res"], -k["total_vram"]), 53 | ): 54 | if input_res <= settings["res"] and total_vram >= settings["total_vram"]: 55 | bs = settings["bs"] 56 | if bs > ensemble_size: 57 | bs = ensemble_size 58 | elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: 59 | bs = math.ceil(ensemble_size / 2) 60 | return bs 61 | 62 | return 1 -------------------------------------------------------------------------------- /utils/colormap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def kitti_colormap(disparity, maxval=-1): 5 | """ 6 | A utility function to reproduce KITTI fake colormap 7 | Arguments: 8 | - disparity: numpy float32 array of dimension HxW 9 | - maxval: maximum disparity value for normalization (if equal to -1, the maximum value in disparity will be used) 10 | 11 | Returns a numpy uint8 array of shape HxWx3. 
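
    Example (illustrative; the file path is a placeholder):
        disp = read_16bit_gt("training/disp_occ_0/000000_10.png")  # HxW float32 disparity
        bgr = kitti_colormap(disp)                                 # HxWx3 uint8, channel order suited to cv2.imwrite
        cv2.imwrite("000000_10_colored.png", bgr)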
12 | """ 13 | if maxval < 0: 14 | maxval = np.max(disparity) 15 | 16 | colormap = np.asarray([[0,0,0,114],[0,0,1,185],[1,0,0,114],[1,0,1,174],[0,1,0,114],[0,1,1,185],[1,1,0,114],[1,1,1,0]]) 17 | weights = np.asarray([8.771929824561404,5.405405405405405,8.771929824561404,5.747126436781609,8.771929824561404,5.405405405405405,8.771929824561404,0]) 18 | cumsum = np.asarray([0,0.114,0.299,0.413,0.587,0.701,0.8859999999999999,0.9999999999999999]) 19 | 20 | colored_disp = np.zeros([disparity.shape[0], disparity.shape[1], 3]) 21 | values = np.expand_dims(np.minimum(np.maximum(disparity/maxval, 0.), 1.), -1) 22 | bins = np.repeat(np.repeat(np.expand_dims(np.expand_dims(cumsum,axis=0),axis=0), disparity.shape[1], axis=1), disparity.shape[0], axis=0) 23 | diffs = np.where((np.repeat(values, 8, axis=-1) - bins) > 0, -1000, (np.repeat(values, 8, axis=-1) - bins)) 24 | index = np.argmax(diffs, axis=-1)-1 25 | 26 | w = 1-(values[:,:,0]-cumsum[index])*np.asarray(weights)[index] 27 | 28 | 29 | colored_disp[:,:,2] = (w*colormap[index][:,:,0] + (1.-w)*colormap[index+1][:,:,0]) 30 | colored_disp[:,:,1] = (w*colormap[index][:,:,1] + (1.-w)*colormap[index+1][:,:,1]) 31 | colored_disp[:,:,0] = (w*colormap[index][:,:,2] + (1.-w)*colormap[index+1][:,:,2]) 32 | 33 | return (colored_disp*np.expand_dims((disparity>0),-1)*255).astype(np.uint8) 34 | 35 | def read_16bit_gt(path): 36 | """ 37 | A utility function to read KITTI 16bit gt 38 | Arguments: 39 | - path: filepath 40 | Returns a numpy float32 array of shape HxW. 41 | """ 42 | gt = cv2.imread(path,-1).astype(np.float32)/256. 43 | return gt -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import json 2 | import yaml 3 | import logging 4 | import os 5 | import numpy as np 6 | import sys 7 | 8 | def load_loss_scheme(loss_config): 9 | with open(loss_config, 'r') as f: 10 | loss_json = yaml.safe_load(f) 11 | return loss_json 12 | 13 | 14 | DEBUG =0 15 | logger = logging.getLogger() 16 | 17 | 18 | if DEBUG: 19 | #coloredlogs.install(level='DEBUG') 20 | logger.setLevel(logging.DEBUG) 21 | else: 22 | #coloredlogs.install(level='INFO') 23 | logger.setLevel(logging.INFO) 24 | 25 | 26 | strhdlr = logging.StreamHandler() 27 | logger.addHandler(strhdlr) 28 | formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s') 29 | strhdlr.setFormatter(formatter) 30 | 31 | 32 | 33 | def count_parameters(model): 34 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 35 | 36 | def check_path(path): 37 | if not os.path.exists(path): 38 | os.makedirs(path, exist_ok=True) 39 | 40 | 41 | -------------------------------------------------------------------------------- /utils/de_normalized.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import least_squares 3 | 4 | 5 | def optimization_function(params, stereo_depth, normalized_depth): 6 | s, t = params 7 | adjusted_depth = normalized_depth * s + t 8 | return np.sum((stereo_depth - adjusted_depth) ** 2) 9 | 10 | 11 | def de_normalization(stereo_depth_map, normalized_depth_map): 12 | # Optimize for s and t 13 | initial_guess = [1, 0] 14 | result = least_squares(optimization_function, initial_guess, args=(stereo_depth_map, normalized_depth_map)) 15 | s_optimized, t_optimized = result.x 16 | # Recover true depth map 17 | recovered = normalized_depth_map * s_optimized + 
t_optimized 18 | 19 | return recovered -------------------------------------------------------------------------------- /utils/depth_ensemble.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from scipy.optimize import minimize 5 | 6 | def inter_distances(tensors: torch.Tensor): 7 | """ 8 | To calculate the distance between each two depth maps. 9 | """ 10 | distances = [] 11 | for i, j in torch.combinations(torch.arange(tensors.shape[0])): 12 | arr1 = tensors[i : i + 1] 13 | arr2 = tensors[j : j + 1] 14 | distances.append(arr1 - arr2) 15 | dist = torch.concat(distances, dim=0) 16 | return dist 17 | 18 | 19 | def ensemble_depths(input_images:torch.Tensor, 20 | regularizer_strength: float =0.02, 21 | max_iter: int =2, 22 | tol:float =1e-3, 23 | reduction: str='median', 24 | max_res: int=None): 25 | """ 26 | To ensemble multiple affine-invariant depth images (up to scale and shift), 27 | by aligning estimating the scale and shift 28 | """ 29 | 30 | device = input_images.device 31 | dtype = input_images.dtype 32 | np_dtype = np.float32 33 | 34 | original_input = input_images.clone() 35 | n_img = input_images.shape[0] #10 36 | ori_shape = input_images.shape #10,224,768 37 | 38 | if max_res is not None: 39 | scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:])) 40 | if scale_factor < 1: 41 | downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") 42 | input_images = downscaler(torch.from_numpy(input_images)).numpy() 43 | 44 | 45 | # init guess 46 | _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) # get the min value of each possible depth 47 | _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) # get the max value of each possible depth 48 | s_init = 1.0 / (_max - _min).reshape((-1, 1, 1)) #(10,1,1) : re-scale'f scale 49 | t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1)) #(10,1,1) 50 | 51 | x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype) #(20,) 52 | 53 | 54 | input_images = input_images.to(device) 55 | 56 | # objective function 57 | def closure(x): 58 | l = len(x) 59 | s = x[: int(l / 2)] 60 | t = x[int(l / 2) :] 61 | s = torch.from_numpy(s).to(dtype=dtype).to(device) 62 | t = torch.from_numpy(t).to(dtype=dtype).to(device) 63 | 64 | transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1)) 65 | dists = inter_distances(transformed_arrays) 66 | sqrt_dist = torch.sqrt(torch.mean(dists**2)) 67 | 68 | if "mean" == reduction: 69 | pred = torch.mean(transformed_arrays, dim=0) 70 | elif "median" == reduction: 71 | pred = torch.median(transformed_arrays, dim=0).values 72 | else: 73 | raise ValueError 74 | 75 | near_err = torch.sqrt((0 - torch.min(pred)) ** 2) 76 | far_err = torch.sqrt((1 - torch.max(pred)) ** 2) 77 | 78 | err = sqrt_dist + (near_err + far_err) * regularizer_strength 79 | err = err.detach().cpu().numpy().astype(np_dtype) 80 | return err 81 | 82 | res = minimize( 83 | closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False} 84 | ) 85 | x = res.x 86 | l = len(x) 87 | s = x[: int(l / 2)] 88 | t = x[int(l / 2) :] 89 | 90 | # Prediction 91 | s = torch.from_numpy(s).to(dtype=dtype).to(device) 92 | t = torch.from_numpy(t).to(dtype=dtype).to(device) 93 | transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1) #[10,H,W] 94 | 95 | 96 | if "mean" == reduction: 97 | aligned_images = torch.mean(transformed_arrays, dim=0) 98 | std = 
torch.std(transformed_arrays, dim=0) 99 | uncertainty = std 100 | 101 | elif "median" == reduction: 102 | aligned_images = torch.median(transformed_arrays, dim=0).values 103 | # MAD (median absolute deviation) as uncertainty indicator 104 | abs_dev = torch.abs(transformed_arrays - aligned_images) 105 | mad = torch.median(abs_dev, dim=0).values 106 | uncertainty = mad 107 | 108 | # Scale and shift to [0, 1] 109 | _min = torch.min(aligned_images) 110 | _max = torch.max(aligned_images) 111 | aligned_images = (aligned_images - _min) / (_max - _min) 112 | uncertainty /= _max - _min 113 | 114 | return aligned_images, uncertainty -------------------------------------------------------------------------------- /utils/image_util.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | 6 | 7 | 8 | 9 | def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image: 10 | """ 11 | Resize image to limit maximum edge length while keeping aspect ratio. 12 | Args: 13 | img (`Image.Image`): 14 | Image to be resized. 15 | max_edge_resolution (`int`): 16 | Maximum edge length (pixel). 17 | Returns: 18 | `Image.Image`: Resized image. 19 | """ 20 | 21 | original_width, original_height = img.size 22 | 23 | downscale_factor = min( 24 | max_edge_resolution / original_width, max_edge_resolution / original_height 25 | ) 26 | 27 | new_width = int(original_width * downscale_factor) 28 | new_height = int(original_height * downscale_factor) 29 | 30 | resized_img = img.resize((new_width, new_height)) 31 | return resized_img 32 | 33 | 34 | def colorize_depth_maps( 35 | depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None 36 | ): 37 | """ 38 | Colorize depth maps. 39 | """ 40 | assert len(depth_map.shape) >= 2, "Invalid dimension" 41 | 42 | if isinstance(depth_map, torch.Tensor): 43 | depth = depth_map.detach().clone().squeeze().numpy() 44 | elif isinstance(depth_map, np.ndarray): 45 | depth = depth_map.copy().squeeze() 46 | # reshape to [ (B,) H, W ] 47 | if depth.ndim < 3: 48 | depth = depth[np.newaxis, :, :] 49 | 50 | # colorize 51 | cm = matplotlib.colormaps[cmap] 52 | depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1) 53 | img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1 54 | img_colored_np = np.rollaxis(img_colored_np, 3, 1) 55 | 56 | if valid_mask is not None: 57 | if isinstance(depth_map, torch.Tensor): 58 | valid_mask = valid_mask.detach().numpy() 59 | valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W] 60 | if valid_mask.ndim < 3: 61 | valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] 62 | else: 63 | valid_mask = valid_mask[:, np.newaxis, :, :] 64 | valid_mask = np.repeat(valid_mask, 3, axis=1) 65 | img_colored_np[~valid_mask] = 0 66 | 67 | if isinstance(depth_map, torch.Tensor): 68 | img_colored = torch.from_numpy(img_colored_np).float() 69 | elif isinstance(depth_map, np.ndarray): 70 | img_colored = img_colored_np 71 | 72 | return img_colored 73 | 74 | 75 | def chw2hwc(chw): 76 | assert 3 == len(chw.shape) 77 | if isinstance(chw, torch.Tensor): 78 | hwc = torch.permute(chw, (1, 2, 0)) 79 | elif isinstance(chw, np.ndarray): 80 | hwc = np.moveaxis(chw, 0, -1) 81 | return hwc 82 | -------------------------------------------------------------------------------- /utils/seed_all.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Bingxin Ke, ETH Zurich. 
All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # -------------------------------------------------------------------------- 15 | # If you find this code useful, we kindly ask you to cite our paper in your work. 16 | # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation 17 | # More information about the method can be found at https://marigoldmonodepth.github.io 18 | # -------------------------------------------------------------------------- 19 | 20 | 21 | import numpy as np 22 | import random 23 | import torch 24 | 25 | 26 | def seed_all(seed: int = 0): 27 | """ 28 | Set random seeds of all components. 29 | """ 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | --------------------------------------------------------------------------------