├── .gitignore ├── FOD ├── Custom_augmentation.py ├── FocusOnDepth.py ├── Fusion.py ├── Head.py ├── Loss.py ├── Predictor.py ├── Reassemble.py ├── Trainer.py ├── dataset.py └── utils.py ├── FocusOnDepth.pdf ├── LICENSE ├── README.md ├── config.json ├── images └── pull_figure.png ├── requirements.txt ├── run.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/* 2 | DPT/weights/*.pt 3 | .ipynb_checkpoints/ 4 | *.pyc 5 | *.pyc 6 | wandb/* 7 | *.png 8 | *.jpg 9 | *.right 10 | DPT/util/__pycache__/*.pyc 11 | models/FocusOnDepth.p 12 | *.tar 13 | *.zip -------------------------------------------------------------------------------- /FOD/Custom_augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class ToMask(object): 5 | """ 6 | Convert a 3 channel RGB image into a 1 channel segmentation mask 7 | """ 8 | def __init__(self, palette_dictionnary): 9 | self.nb_classes = len(palette_dictionnary) 10 | # sort the dictionary of the classes by the sum of rgb value -> to have always background = 0 11 | # self.converted_dictionnary = {i: v for i, (k, v) in enumerate(sorted(palette_dictionnary.items(), key=lambda item: sum(item[1])))} 12 | self.palette_dictionnary = palette_dictionnary 13 | 14 | def __call__(self, pil_image): 15 | # avoid taking the alpha channel 16 | image_array = np.array(pil_image)[:, :, :3] 17 | # get only one channel for the output 18 | output_array = np.zeros(image_array.shape, dtype="int")[:, :, 0] 19 | 20 | for label in self.palette_dictionnary.keys(): 21 | rgb_color = self.palette_dictionnary[label]['color'] 22 | mask = (image_array == rgb_color) 23 | output_array[mask[:, :, 0]] = int(label) 24 | 25 | output_array = torch.from_numpy(output_array).unsqueeze(0).long() 26 | return output_array 27 | -------------------------------------------------------------------------------- /FOD/FocusOnDepth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import timm 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | from FOD.Reassemble import Reassemble 9 | from FOD.Fusion import Fusion 10 | from FOD.Head import HeadDepth, HeadSeg 11 | 12 | torch.manual_seed(0) 13 | 14 | class FocusOnDepth(nn.Module): 15 | def __init__(self, 16 | image_size = (3, 384, 384), 17 | patch_size = 16, 18 | emb_dim = 1024, 19 | resample_dim = 256, 20 | read = 'projection', 21 | num_layers_encoder = 24, 22 | hooks = [5, 11, 17, 23], 23 | reassemble_s = [4, 8, 16, 32], 24 | transformer_dropout= 0, 25 | nclasses = 2, 26 | type = "full", 27 | model_timm = "vit_large_patch16_384"): 28 | """ 29 | Focus on Depth 30 | type : {"full", "depth", "segmentation"} 31 | image_size : (c, h, w) 32 | patch_size : *a square* 33 | emb_dim <=> D (in the paper) 34 | resample_dim <=> ^D (in the paper) 35 | read : {"ignore", "add", "projection"} 36 | """ 37 | super().__init__() 38 | 39 | #Splitting img into patches 40 | # channels, image_height, image_width = image_size 41 | # assert image_height % patch_size == 0 and image_width % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
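# NOTE: the manual patch-embedding code in this commented-out block is kept for reference only;
# the actual encoder is the pretrained timm ViT created below as self.transformer_encoders, and the
# decoder reads intermediate activations from its blocks through forward hooks (see _get_layers_from_hooks).
# e.g. with the defaults above (384x384 input, patch_size=16) the token grid is 24x24 = 576 patches,
# plus the extra readout (CLS) token that the Read modules in Reassemble take care of.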
42 | # num_patches = (image_height // patch_size) * (image_width // patch_size) 43 | # patch_dim = channels * patch_size * patch_size 44 | # self.to_patch_embedding = nn.Sequential( 45 | # Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size), 46 | # nn.Linear(patch_dim, emb_dim), 47 | # ) 48 | # #Embedding 49 | # self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim)) 50 | # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim)) 51 | 52 | #Transformer 53 | # encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dropout=transformer_dropout, dim_feedforward=emb_dim*4) 54 | # self.transformer_encoders = nn.TransformerEncoder(encoder_layer, num_layers=num_layers_encoder) 55 | self.transformer_encoders = timm.create_model(model_timm, pretrained=True) 56 | self.type_ = type 57 | 58 | #Register hooks 59 | self.activation = {} 60 | self.hooks = hooks 61 | self._get_layers_from_hooks(self.hooks) 62 | 63 | #Reassembles Fusion 64 | self.reassembles = [] 65 | self.fusions = [] 66 | for s in reassemble_s: 67 | self.reassembles.append(Reassemble(image_size, read, patch_size, s, emb_dim, resample_dim)) 68 | self.fusions.append(Fusion(resample_dim)) 69 | self.reassembles = nn.ModuleList(self.reassembles) 70 | self.fusions = nn.ModuleList(self.fusions) 71 | 72 | #Head 73 | if type == "full": 74 | self.head_depth = HeadDepth(resample_dim) 75 | self.head_segmentation = HeadSeg(resample_dim, nclasses=nclasses) 76 | elif type == "depth": 77 | self.head_depth = HeadDepth(resample_dim) 78 | self.head_segmentation = None 79 | else: 80 | self.head_depth = None 81 | self.head_segmentation = HeadSeg(resample_dim, nclasses=nclasses) 82 | 83 | def forward(self, img): 84 | # x = self.to_patch_embedding(img) 85 | # b, n, _ = x.shape 86 | # cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 87 | # x = torch.cat((cls_tokens, x), dim=1) 88 | # x += self.pos_embedding[:, :(n + 1)] 89 | # t = self.transformer_encoders(x) 90 | 91 | t = self.transformer_encoders(img) 92 | previous_stage = None 93 | for i in np.arange(len(self.fusions)-1, -1, -1): 94 | hook_to_take = 't'+str(self.hooks[i]) 95 | activation_result = self.activation[hook_to_take] 96 | reassemble_result = self.reassembles[i](activation_result) 97 | fusion_result = self.fusions[i](reassemble_result, previous_stage) 98 | previous_stage = fusion_result 99 | out_depth = None 100 | out_segmentation = None 101 | if self.head_depth != None: 102 | out_depth = self.head_depth(previous_stage) 103 | if self.head_segmentation != None: 104 | out_segmentation = self.head_segmentation(previous_stage) 105 | return out_depth, out_segmentation 106 | 107 | def _get_layers_from_hooks(self, hooks): 108 | def get_activation(name): 109 | def hook(model, input, output): 110 | self.activation[name] = output 111 | return hook 112 | for h in hooks: 113 | #self.transformer_encoders.layers[h].register_forward_hook(get_activation('t'+str(h))) 114 | self.transformer_encoders.blocks[h].register_forward_hook(get_activation('t'+str(h))) 115 | -------------------------------------------------------------------------------- /FOD/Fusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ResidualConvUnit(nn.Module): 6 | def __init__(self, features): 7 | super().__init__() 8 | 9 | self.conv1 = nn.Conv2d( 10 | features, features, kernel_size=3, stride=1, padding=1, bias=True) 11 | self.conv2 = nn.Conv2d( 12 | 
features, features, kernel_size=3, stride=1, padding=1, bias=True) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | def forward(self, x): 16 | """Forward pass. 17 | Args: 18 | x (tensor): input 19 | Returns: 20 | tensor: output 21 | """ 22 | out = self.relu(x) 23 | out = self.conv1(out) 24 | out = self.relu(out) 25 | out = self.conv2(out) 26 | return out + x 27 | 28 | class Fusion(nn.Module): 29 | def __init__(self, resample_dim): 30 | super(Fusion, self).__init__() 31 | self.res_conv1 = ResidualConvUnit(resample_dim) 32 | self.res_conv2 = ResidualConvUnit(resample_dim) 33 | #self.resample = nn.ConvTranspose2d(resample_dim, resample_dim, kernel_size=2, stride=2, padding=0, bias=True, dilation=1, groups=1) 34 | 35 | def forward(self, x, previous_stage=None): 36 | if previous_stage == None: 37 | previous_stage = torch.zeros_like(x) 38 | output_stage1 = self.res_conv1(x) 39 | output_stage1 += previous_stage 40 | output_stage2 = self.res_conv2(output_stage1) 41 | output_stage2 = nn.functional.interpolate(output_stage2, scale_factor=2, mode="bilinear", align_corners=True) 42 | return output_stage2 43 | -------------------------------------------------------------------------------- /FOD/Head.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | class Interpolate(nn.Module): 6 | def __init__(self, scale_factor, mode, align_corners=False): 7 | super(Interpolate, self).__init__() 8 | self.interp = nn.functional.interpolate 9 | self.scale_factor = scale_factor 10 | self.mode = mode 11 | self.align_corners = align_corners 12 | 13 | def forward(self, x): 14 | x = self.interp( 15 | x, 16 | scale_factor=self.scale_factor, 17 | mode=self.mode, 18 | align_corners=self.align_corners) 19 | return x 20 | 21 | class HeadDepth(nn.Module): 22 | def __init__(self, features): 23 | super(HeadDepth, self).__init__() 24 | self.head = nn.Sequential( 25 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 26 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 27 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 28 | nn.ReLU(), 29 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 30 | # nn.ReLU() 31 | nn.Sigmoid() 32 | ) 33 | def forward(self, x): 34 | x = self.head(x) 35 | # x = (x - x.min())/(x.max()-x.min() + 1e-15) 36 | return x 37 | 38 | class HeadSeg(nn.Module): 39 | def __init__(self, features, nclasses=2): 40 | super(HeadSeg, self).__init__() 41 | self.head = nn.Sequential( 42 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 43 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 44 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 45 | nn.ReLU(), 46 | nn.Conv2d(32, nclasses, kernel_size=1, stride=1, padding=0) 47 | ) 48 | def forward(self, x): 49 | x = self.head(x) 50 | return x 51 | -------------------------------------------------------------------------------- /FOD/Loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def compute_scale_and_shift(prediction, target, mask): 5 | # system matrix: A = [[a_00, a_01], [a_10, a_11]] 6 | a_00 = torch.sum(mask * prediction * prediction, (1, 2)) 7 | a_01 = torch.sum(mask * prediction, (1, 2)) 8 | a_11 = torch.sum(mask, (1, 2)) 9 | 10 | # right hand side: b = [b_0, b_1] 11 | b_0 = torch.sum(mask * prediction * target, (1, 2)) 12 | b_1 = torch.sum(mask * target, 
(1, 2)) 13 | 14 | # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b 15 | x_0 = torch.zeros_like(b_0) 16 | x_1 = torch.zeros_like(b_1) 17 | 18 | det = a_00 * a_11 - a_01 * a_01 19 | valid = det.nonzero() 20 | 21 | x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] 22 | x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] 23 | 24 | return x_0, x_1 25 | 26 | 27 | def reduction_batch_based(image_loss, M): 28 | # average of all valid pixels of the batch 29 | 30 | # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) 31 | divisor = torch.sum(M) 32 | 33 | if divisor == 0: 34 | return 0 35 | else: 36 | return torch.sum(image_loss) / divisor 37 | 38 | 39 | def reduction_image_based(image_loss, M): 40 | # mean of average of valid pixels of an image 41 | 42 | # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) 43 | valid = M.nonzero() 44 | 45 | image_loss[valid] = image_loss[valid] / M[valid] 46 | 47 | return torch.mean(image_loss) 48 | 49 | 50 | def mse_loss(prediction, target, mask, reduction=reduction_batch_based): 51 | 52 | M = torch.sum(mask, (1, 2)) 53 | res = prediction - target 54 | image_loss = torch.sum(mask * res * res, (1, 2)) 55 | 56 | return reduction(image_loss, 2 * M) 57 | 58 | 59 | def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): 60 | 61 | M = torch.sum(mask, (1, 2)) 62 | 63 | diff = prediction - target 64 | diff = torch.mul(mask, diff) 65 | 66 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) 67 | mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) 68 | grad_x = torch.mul(mask_x, grad_x) 69 | 70 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) 71 | mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) 72 | grad_y = torch.mul(mask_y, grad_y) 73 | 74 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) 75 | 76 | return reduction(image_loss, M) 77 | 78 | 79 | class MSELoss(nn.Module): 80 | def __init__(self, reduction='batch-based'): 81 | super().__init__() 82 | 83 | if reduction == 'batch-based': 84 | self.__reduction = reduction_batch_based 85 | else: 86 | self.__reduction = reduction_image_based 87 | 88 | def forward(self, prediction, target, mask): 89 | return mse_loss(prediction, target, mask, reduction=self.__reduction) 90 | 91 | 92 | class GradientLoss(nn.Module): 93 | def __init__(self, scales=4, reduction='batch-based'): 94 | super().__init__() 95 | 96 | if reduction == 'batch-based': 97 | self.__reduction = reduction_batch_based 98 | else: 99 | self.__reduction = reduction_image_based 100 | 101 | self.__scales = scales 102 | 103 | def forward(self, prediction, target, mask): 104 | total = 0 105 | 106 | for scale in range(self.__scales): 107 | step = pow(2, scale) 108 | 109 | total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step], 110 | mask[:, ::step, ::step], reduction=self.__reduction) 111 | 112 | return total 113 | 114 | 115 | class ScaleAndShiftInvariantLoss(nn.Module): 116 | def __init__(self, alpha=0.5, scales=4, reduction='batch-based'): 117 | super().__init__() 118 | 119 | self.__data_loss = MSELoss(reduction=reduction) 120 | self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction) 121 | self.__alpha = alpha 122 | 123 | self.__prediction_ssi = None 124 | 125 | def forward(self, prediction, target): 126 | #preprocessing 127 | mask = target > 0 128 | 129 | #calcul 130 | scale, shift = compute_scale_and_shift(prediction, target, mask) 131 | # 
print(scale, shift) 132 | self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) 133 | 134 | total = self.__data_loss(self.__prediction_ssi, target, mask) 135 | if self.__alpha > 0: 136 | total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask) 137 | 138 | return total 139 | 140 | def __get_prediction_ssi(self): 141 | return self.__prediction_ssi 142 | 143 | prediction_ssi = property(__get_prediction_ssi) 144 | -------------------------------------------------------------------------------- /FOD/Predictor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import cv2 6 | from torchvision import transforms 7 | from scipy.ndimage.filters import gaussian_filter 8 | 9 | from PIL import Image 10 | 11 | from FOD.FocusOnDepth import FocusOnDepth 12 | from FOD.utils import create_dir 13 | from FOD.dataset import show 14 | 15 | 16 | class Predictor(object): 17 | def __init__(self, config, input_images): 18 | self.input_images = input_images 19 | self.config = config 20 | self.type = self.config['General']['type'] 21 | 22 | self.device = torch.device(self.config['General']['device'] if torch.cuda.is_available() else "cpu") 23 | print("device: %s" % self.device) 24 | resize = config['Dataset']['transforms']['resize'] 25 | self.model = FocusOnDepth( 26 | image_size = (3,resize,resize), 27 | emb_dim = config['General']['emb_dim'], 28 | resample_dim= config['General']['resample_dim'], 29 | read = config['General']['read'], 30 | nclasses = len(config['Dataset']['classes']) + 1, 31 | hooks = config['General']['hooks'], 32 | model_timm = config['General']['model_timm'], 33 | type = self.type, 34 | patch_size = config['General']['patch_size'], 35 | ) 36 | path_model = os.path.join(config['General']['path_model'], 'FocusOnDepth_{}.p'.format(config['General']['model_timm'])) 37 | self.model.load_state_dict( 38 | torch.load(path_model, map_location=self.device)['model_state_dict'] 39 | ) 40 | self.model.eval() 41 | self.transform_image = transforms.Compose([ 42 | transforms.Resize((resize, resize)), 43 | transforms.ToTensor(), 44 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 45 | ]) 46 | self.output_dir = self.config['General']['path_predicted_images'] 47 | create_dir(self.output_dir) 48 | 49 | def run(self): 50 | with torch.no_grad(): 51 | for images in self.input_images: 52 | pil_im = Image.open(images) 53 | original_size = pil_im.size 54 | 55 | tensor_im = self.transform_image(pil_im).unsqueeze(0) 56 | output_depth, output_segmentation = self.model(tensor_im) 57 | output_depth = 1-output_depth 58 | 59 | output_segmentation = transforms.ToPILImage()(output_segmentation.squeeze(0).argmax(dim=0).float()).resize(original_size, resample=Image.NEAREST) 60 | output_depth = transforms.ToPILImage()(output_depth.squeeze(0).float()).resize(original_size, resample=Image.BICUBIC) 61 | 62 | path_dir_segmentation = os.path.join(self.output_dir, 'segmentations') 63 | path_dir_depths = os.path.join(self.output_dir, 'depths') 64 | create_dir(path_dir_segmentation) 65 | output_segmentation.save(os.path.join(path_dir_segmentation, os.path.basename(images))) 66 | 67 | path_dir_depths = os.path.join(self.output_dir, 'depths') 68 | create_dir(path_dir_depths) 69 | output_depth.save(os.path.join(path_dir_depths, os.path.basename(images))) 70 | 71 | ## TO DO: Apply AutoFocus 72 | 73 | # output_depth = np.array(output_depth) 74 | # 
output_segmentation = np.array(output_segmentation) 75 | 76 | # mask_person = (output_segmentation != 0) 77 | # depth_person = output_depth*mask_person 78 | # mean_depth_person = np.mean(depth_person[depth_person != 0]) 79 | # std_depth_person = np.std(depth_person[depth_person != 0]) 80 | 81 | # #print(mean_depth_person, std_depth_person) 82 | 83 | # mask_total = (depth_person >= mean_depth_person-2*std_depth_person) 84 | # mask_total = np.repeat(mask_total[:, :, np.newaxis], 3, axis=-1) 85 | # region_to_blur = np.ones(np_im.shape)*(1-mask_total) 86 | 87 | # #region_not_to_blur = np.zeros(np_im.shape) + np_im*(mask_total) 88 | # region_not_to_blur = np_im 89 | # blurred = cv2.blur(region_to_blur, (10, 10)) 90 | 91 | # #final_image = blurred + region_not_to_blur 92 | # final_image = cv2.addWeighted(region_not_to_blur.astype(np.uint8), 0.5, blurred.astype(np.uint8), 0.5, 0) 93 | # final_image = Image.fromarray((final_image).astype(np.uint8)) 94 | # final_image.save(os.path.join(self.output_dir, os.path.basename(images))) 95 | -------------------------------------------------------------------------------- /FOD/Reassemble.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | 7 | class Read_ignore(nn.Module): 8 | def __init__(self, start_index=1): 9 | super(Read_ignore, self).__init__() 10 | self.start_index = start_index 11 | 12 | def forward(self, x): 13 | return x[:, self.start_index:] 14 | 15 | 16 | class Read_add(nn.Module): 17 | def __init__(self, start_index=1): 18 | super(Read_add, self).__init__() 19 | self.start_index = start_index 20 | 21 | def forward(self, x): 22 | if self.start_index == 2: 23 | readout = (x[:, 0] + x[:, 1]) / 2 24 | else: 25 | readout = x[:, 0] 26 | return x[:, self.start_index :] + readout.unsqueeze(1) 27 | 28 | 29 | class Read_projection(nn.Module): 30 | def __init__(self, in_features, start_index=1): 31 | super(Read_projection, self).__init__() 32 | self.start_index = start_index 33 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) 34 | 35 | def forward(self, x): 36 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) 37 | features = torch.cat((x[:, self.start_index :], readout), -1) 38 | return self.project(features) 39 | 40 | class MyConvTranspose2d(nn.Module): 41 | def __init__(self, conv, output_size): 42 | super(MyConvTranspose2d, self).__init__() 43 | self.output_size = output_size 44 | self.conv = conv 45 | 46 | def forward(self, x): 47 | x = self.conv(x, output_size=self.output_size) 48 | return x 49 | 50 | class Resample(nn.Module): 51 | def __init__(self, p, s, h, emb_dim, resample_dim): 52 | super(Resample, self).__init__() 53 | assert (s in [4, 8, 16, 32]), "s must be in [0.5, 4, 8, 16, 32]" 54 | self.conv1 = nn.Conv2d(emb_dim, resample_dim, kernel_size=1, stride=1, padding=0) 55 | if s == 4: 56 | self.conv2 = nn.ConvTranspose2d(resample_dim, 57 | resample_dim, 58 | kernel_size=4, 59 | stride=4, 60 | padding=0, 61 | bias=True, 62 | dilation=1, 63 | groups=1) 64 | elif s == 8: 65 | self.conv2 = nn.ConvTranspose2d(resample_dim, 66 | resample_dim, 67 | kernel_size=2, 68 | stride=2, 69 | padding=0, 70 | bias=True, 71 | dilation=1, 72 | groups=1) 73 | elif s == 16: 74 | self.conv2 = nn.Identity() 75 | else: 76 | self.conv2 = nn.Conv2d(resample_dim, resample_dim, kernel_size=2,stride=2, padding=0, bias=True) 77 | 78 | 
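# Resample first projects the token feature map from emb_dim to resample_dim channels (1x1 conv),
# then rescales it spatially: s=4 upsamples x4, s=8 upsamples x2, s=16 keeps the size, s=32 downsamples x2.
# e.g. for a 384x384 input with p=16, the 24x24 token grid becomes 96x96, 48x48, 24x24 and 12x12 feature maps.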
def forward(self, x): 79 | x = self.conv1(x) 80 | x = self.conv2(x) 81 | return x 82 | 83 | class Reassemble(nn.Module): 84 | def __init__(self, image_size, read, p, s, emb_dim, resample_dim): 85 | """ 86 | p = patch size 87 | s = coefficient resample 88 | emb_dim <=> D (in the paper) 89 | resample_dim <=> ^D (in the paper) 90 | read : {"ignore", "add", "projection"} 91 | """ 92 | super(Reassemble, self).__init__() 93 | channels, image_height, image_width = image_size 94 | 95 | #Read 96 | self.read = Read_ignore() 97 | if read == 'add': 98 | self.read = Read_add() 99 | elif read == 'projection': 100 | self.read = Read_projection(emb_dim) 101 | 102 | #Concat after read 103 | self.concat = Rearrange('b (h w) c -> b c h w', 104 | c=emb_dim, 105 | h=(image_height // p), 106 | w=(image_width // p)) 107 | 108 | #Projection + Resample 109 | self.resample = Resample(p, s, image_height, emb_dim, resample_dim) 110 | 111 | def forward(self, x): 112 | x = self.read(x) 113 | x = self.concat(x) 114 | x = self.resample(x) 115 | return x 116 | -------------------------------------------------------------------------------- /FOD/Trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import wandb 6 | import cv2 7 | import torch.nn as nn 8 | 9 | from tqdm import tqdm 10 | from os import replace 11 | from numpy.core.numeric import Inf 12 | from FOD.utils import get_losses, get_optimizer, get_schedulers, create_dir 13 | from FOD.FocusOnDepth import FocusOnDepth 14 | 15 | class Trainer(object): 16 | def __init__(self, config): 17 | super().__init__() 18 | self.config = config 19 | self.type = self.config['General']['type'] 20 | 21 | self.device = torch.device(self.config['General']['device'] if torch.cuda.is_available() else "cpu") 22 | print("device: %s" % self.device) 23 | resize = config['Dataset']['transforms']['resize'] 24 | self.model = FocusOnDepth( 25 | image_size = (3,resize,resize), 26 | emb_dim = config['General']['emb_dim'], 27 | resample_dim= config['General']['resample_dim'], 28 | read = config['General']['read'], 29 | nclasses = len(config['Dataset']['classes']) + 1, 30 | hooks = config['General']['hooks'], 31 | model_timm = config['General']['model_timm'], 32 | type = self.type, 33 | patch_size = config['General']['patch_size'], 34 | ) 35 | 36 | self.model.to(self.device) 37 | # print(self.model) 38 | # exit(0) 39 | 40 | self.loss_depth, self.loss_segmentation = get_losses(config) 41 | self.optimizer_backbone, self.optimizer_scratch = get_optimizer(config, self.model) 42 | self.schedulers = get_schedulers([self.optimizer_backbone, self.optimizer_scratch]) 43 | 44 | def train(self, train_dataloader, val_dataloader): 45 | epochs = self.config['General']['epochs'] 46 | if self.config['wandb']['enable']: 47 | wandb.init(project="FocusOnDepth", entity=self.config['wandb']['username']) 48 | wandb.config = { 49 | "learning_rate_backbone": self.config['General']['lr_backbone'], 50 | "learning_rate_scratch": self.config['General']['lr_scratch'], 51 | "epochs": epochs, 52 | "batch_size": self.config['General']['batch_size'] 53 | } 54 | val_loss = Inf 55 | for epoch in range(epochs): # loop over the dataset multiple times 56 | print("Epoch ", epoch+1) 57 | running_loss = 0.0 58 | self.model.train() 59 | pbar = tqdm(train_dataloader) 60 | pbar.set_description("Training") 61 | for i, (X, Y_depths, Y_segmentations) in enumerate(pbar): 62 | # get the inputs; data is a list of 
[inputs, labels] 63 | X, Y_depths, Y_segmentations = X.to(self.device), Y_depths.to(self.device), Y_segmentations.to(self.device) 64 | # zero the parameter gradients 65 | self.optimizer_backbone.zero_grad() 66 | self.optimizer_scratch.zero_grad() 67 | # forward + backward + optimizer 68 | output_depths, output_segmentations = self.model(X) 69 | output_depths = output_depths.squeeze(1) if output_depths != None else None 70 | 71 | Y_depths = Y_depths.squeeze(1) #1xHxW -> HxW 72 | Y_segmentations = Y_segmentations.squeeze(1) #1xHxW -> HxW 73 | # get loss 74 | loss = self.loss_depth(output_depths, Y_depths) + self.loss_segmentation(output_segmentations, Y_segmentations) 75 | loss.backward() 76 | # step optimizer 77 | self.optimizer_scratch.step() 78 | self.optimizer_backbone.step() 79 | 80 | running_loss += loss.item() 81 | if np.isnan(running_loss): 82 | print('\n', 83 | X.min().item(), X.max().item(),'\n', 84 | Y_depths.min().item(), Y_depths.max().item(),'\n', 85 | output_depths.min().item(), output_depths.max().item(),'\n', 86 | loss.item(), 87 | ) 88 | exit(0) 89 | 90 | if self.config['wandb']['enable'] and ((i % 50 == 0 and i>0) or i==len(train_dataloader)-1): 91 | wandb.log({"loss": running_loss/(i+1)}) 92 | pbar.set_postfix({'training_loss': running_loss/(i+1)}) 93 | 94 | new_val_loss = self.run_eval(val_dataloader) 95 | 96 | if new_val_loss < val_loss: 97 | self.save_model() 98 | val_loss = new_val_loss 99 | 100 | self.schedulers[0].step(new_val_loss) 101 | self.schedulers[1].step(new_val_loss) 102 | 103 | print('Finished Training') 104 | 105 | def run_eval(self, val_dataloader): 106 | """ 107 | Evaluate the model on the validation set and visualize some results 108 | on wandb 109 | :- val_dataloader -: torch dataloader 110 | """ 111 | val_loss = 0. 
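# full pass over the validation set without gradients: accumulate the running loss, and keep the
# first batch aside so img_logger can push example depth/segmentation predictions to wandb
# when logging is enabled in the config.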
112 | self.model.eval() 113 | X_1 = None 114 | Y_depths_1 = None 115 | Y_segmentations_1 = None 116 | output_depths_1 = None 117 | output_segmentations_1 = None 118 | with torch.no_grad(): 119 | pbar = tqdm(val_dataloader) 120 | pbar.set_description("Validation") 121 | for i, (X, Y_depths, Y_segmentations) in enumerate(pbar): 122 | X, Y_depths, Y_segmentations = X.to(self.device), Y_depths.to(self.device), Y_segmentations.to(self.device) 123 | output_depths, output_segmentations = self.model(X) 124 | output_depths = output_depths.squeeze(1) if output_depths != None else None 125 | Y_depths = Y_depths.squeeze(1) 126 | Y_segmentations = Y_segmentations.squeeze(1) 127 | if i==0: 128 | X_1 = X 129 | Y_depths_1 = Y_depths 130 | Y_segmentations_1 = Y_segmentations 131 | output_depths_1 = output_depths 132 | output_segmentations_1 = output_segmentations 133 | # get loss 134 | loss = self.loss_depth(output_depths, Y_depths) + self.loss_segmentation(output_segmentations, Y_segmentations) 135 | val_loss += loss.item() 136 | pbar.set_postfix({'validation_loss': val_loss/(i+1)}) 137 | if self.config['wandb']['enable']: 138 | wandb.log({"val_loss": val_loss/(i+1)}) 139 | self.img_logger(X_1, Y_depths_1, Y_segmentations_1, output_depths_1, output_segmentations_1) 140 | return val_loss/(i+1) 141 | 142 | def save_model(self): 143 | path_model = os.path.join(self.config['General']['path_model'], self.model.__class__.__name__) 144 | create_dir(path_model) 145 | torch.save({'model_state_dict': self.model.state_dict(), 146 | 'optimizer_backbone_state_dict': self.optimizer_backbone.state_dict(), 147 | 'optimizer_scratch_state_dict': self.optimizer_scratch.state_dict() 148 | }, path_model+'.p') 149 | print('Model saved at : {}'.format(path_model)) 150 | 151 | def img_logger(self, X, Y_depths, Y_segmentations, output_depths, output_segmentations): 152 | nb_to_show = self.config['wandb']['images_to_show'] if self.config['wandb']['images_to_show'] <= len(X) else len(X) 153 | tmp = X[:nb_to_show].detach().cpu().numpy() 154 | imgs = (tmp - tmp.min()) / (tmp.max() - tmp.min()) 155 | if output_depths != None: 156 | tmp = Y_depths[:nb_to_show].unsqueeze(1).detach().cpu().numpy() 157 | depth_truths = np.repeat(tmp, 3, axis=1) 158 | tmp = output_depths[:nb_to_show].unsqueeze(1).detach().cpu().numpy() 159 | tmp = np.repeat(tmp, 3, axis=1) 160 | #depth_preds = 1.0 - tmp 161 | depth_preds = tmp 162 | if output_segmentations != None: 163 | tmp = Y_segmentations[:nb_to_show].unsqueeze(1).detach().cpu().numpy() 164 | segmentation_truths = np.repeat(tmp, 3, axis=1).astype('float32') 165 | tmp = torch.argmax(output_segmentations[:nb_to_show], dim=1) 166 | tmp = tmp.unsqueeze(1).detach().cpu().numpy() 167 | tmp = np.repeat(tmp, 3, axis=1) 168 | segmentation_preds = tmp.astype('float32') 169 | # print("******************************************************") 170 | # print(imgs.shape, imgs.mean().item(), imgs.max().item(), imgs.min().item()) 171 | # if output_depths != None: 172 | # print(depth_truths.shape, depth_truths.mean().item(), depth_truths.max().item(), depth_truths.min().item()) 173 | # print(depth_preds.shape, depth_preds.mean().item(), depth_preds.max().item(), depth_preds.min().item()) 174 | # if output_segmentations != None: 175 | # print(segmentation_truths.shape, segmentation_truths.mean().item(), segmentation_truths.max().item(), segmentation_truths.min().item()) 176 | # print(segmentation_preds.shape, segmentation_preds.mean().item(), segmentation_preds.max().item(), segmentation_preds.min().item()) 177 | # 
print("******************************************************") 178 | imgs = imgs.transpose(0,2,3,1) 179 | if output_depths != None: 180 | depth_truths = depth_truths.transpose(0,2,3,1) 181 | depth_preds = depth_preds.transpose(0,2,3,1) 182 | if output_segmentations != None: 183 | segmentation_truths = segmentation_truths.transpose(0,2,3,1) 184 | segmentation_preds = segmentation_preds.transpose(0,2,3,1) 185 | output_dim = (int(self.config['wandb']['im_w']), int(self.config['wandb']['im_h'])) 186 | 187 | wandb.log({ 188 | "img": [wandb.Image(cv2.resize(im, output_dim), caption='img_{}'.format(i+1)) for i, im in enumerate(imgs)] 189 | }) 190 | if output_depths != None: 191 | wandb.log({ 192 | "depth_truths": [wandb.Image(cv2.resize(im, output_dim), caption='depth_truths_{}'.format(i+1)) for i, im in enumerate(depth_truths)], 193 | "depth_preds": [wandb.Image(cv2.resize(im, output_dim), caption='depth_preds_{}'.format(i+1)) for i, im in enumerate(depth_preds)] 194 | }) 195 | if output_segmentations != None: 196 | wandb.log({ 197 | "seg_truths": [wandb.Image(cv2.resize(im, output_dim), caption='seg_truths_{}'.format(i+1)) for i, im in enumerate(segmentation_truths)], 198 | "seg_preds": [wandb.Image(cv2.resize(im, output_dim), caption='seg_preds_{}'.format(i+1)) for i, im in enumerate(segmentation_preds)] 199 | }) 200 | -------------------------------------------------------------------------------- /FOD/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from glob import glob 4 | 5 | import torch 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | from tqdm import tqdm 10 | from PIL import Image 11 | from torch.utils.data.dataloader import default_collate 12 | from torch.utils.data import Dataset, DataLoader 13 | from torchvision import transforms 14 | import torchvision.transforms.functional as TF 15 | 16 | from FOD.utils import get_total_paths, get_splitted_dataset, get_transforms 17 | 18 | def show(imgs): 19 | fix, axs = plt.subplots(ncols=len(imgs), squeeze=False) 20 | for i, img in enumerate(imgs): 21 | img = transforms.ToPILImage()(img.to('cpu').float()) 22 | axs[0, i].imshow(np.asarray(img)) 23 | axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) 24 | plt.show() 25 | 26 | class AutoFocusDataset(Dataset): 27 | """ 28 | Dataset class for the AutoFocus Task. Requires for each image, its depth ground-truth and 29 | segmentation mask 30 | Args: 31 | :- config -: json config file 32 | :- dataset_name -: str 33 | :- split -: split ['train', 'val', 'test'] 34 | """ 35 | def __init__(self, config, dataset_name, split=None): 36 | self.split = split 37 | self.config = config 38 | 39 | path_images = os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_images']) 40 | path_depths = os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_depths']) 41 | path_segmentations = os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_segmentations']) 42 | 43 | self.paths_images = get_total_paths(path_images, config['Dataset']['extensions']['ext_images']) 44 | self.paths_depths = get_total_paths(path_depths, config['Dataset']['extensions']['ext_depths']) 45 | self.paths_segmentations = get_total_paths(path_segmentations, config['Dataset']['extensions']['ext_segmentations']) 46 | 47 | assert (self.split in ['train', 'test', 'val']), "Invalid split!" 
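# sanity checks: the numbers of images, depth maps and segmentation masks must match,
# and the train/val/test fractions defined in the config must sum to 1.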
48 | assert (len(self.paths_images) == len(self.paths_depths)), "Different number of instances between the input and the depth maps" 49 | assert (len(self.paths_images) == len(self.paths_segmentations)), "Different number of instances between the input and the segmentation maps" 50 | assert (config['Dataset']['splits']['split_train']+config['Dataset']['splits']['split_test']+config['Dataset']['splits']['split_val'] == 1), "Invalid splits (sum must be equal to 1)" 51 | # check for segmentation 52 | 53 | # utility func for splitting 54 | self.paths_images, self.paths_depths, self.paths_segmentations = get_splitted_dataset(config, self.split, dataset_name, self.paths_images, self.paths_depths, self.paths_segmentations) 55 | 56 | # Get the transforms 57 | self.transform_image, self.transform_depth, self.transform_seg = get_transforms(config) 58 | 59 | # get p_flip from config 60 | self.p_flip = config['Dataset']['transforms']['p_flip'] if split=='train' else 0 61 | self.p_crop = config['Dataset']['transforms']['p_crop'] if split=='train' else 0 62 | self.p_rot = config['Dataset']['transforms']['p_rot'] if split=='train' else 0 63 | self.resize = config['Dataset']['transforms']['resize'] 64 | 65 | def __len__(self): 66 | """ 67 | Function to get the number of images using the given list of images 68 | """ 69 | return len(self.paths_images) 70 | 71 | def __getitem__(self, idx): 72 | """ 73 | Getter function in order to get the triplet of images / depth maps and segmentation masks 74 | """ 75 | if torch.is_tensor(idx): 76 | idx = idx.tolist() 77 | image = self.transform_image(Image.open(self.paths_images[idx])) 78 | depth = self.transform_depth(Image.open(self.paths_depths[idx])) 79 | segmentation = self.transform_seg(Image.open(self.paths_segmentations[idx])) 80 | imgorig = image.clone() 81 | 82 | if random.random() < self.p_flip: 83 | image = TF.hflip(image) 84 | depth = TF.hflip(depth) 85 | segmentation = TF.hflip(segmentation) 86 | 87 | if random.random() < self.p_crop: 88 | random_size = random.randint(256, self.resize-1) 89 | max_size = self.resize - random_size 90 | left = int(random.random()*max_size) 91 | top = int(random.random()*max_size) 92 | image = TF.crop(image, top, left, random_size, random_size) 93 | depth = TF.crop(depth, top, left, random_size, random_size) 94 | segmentation = TF.crop(segmentation, top, left, random_size, random_size) 95 | image = transforms.Resize((self.resize, self.resize))(image) 96 | depth = transforms.Resize((self.resize, self.resize))(depth) 97 | segmentation = transforms.Resize((self.resize, self.resize), interpolation=transforms.InterpolationMode.NEAREST)(segmentation) 98 | 99 | if random.random() < self.p_rot: 100 | #rotate 101 | random_angle = random.random()*20 - 10 #[-10 ; 10] 102 | mask = torch.ones((1,self.resize,self.resize)) #useful for the resize at the end 103 | mask = TF.rotate(mask, random_angle, interpolation=transforms.InterpolationMode.BILINEAR) 104 | image = TF.rotate(image, random_angle, interpolation=transforms.InterpolationMode.BILINEAR) 105 | depth = TF.rotate(depth, random_angle, interpolation=transforms.InterpolationMode.BILINEAR) 106 | segmentation = TF.rotate(segmentation, random_angle, interpolation=transforms.InterpolationMode.NEAREST) 107 | #crop to remove black borders due to the rotation 108 | left = torch.argmax(mask[:,0,:]).item() 109 | top = torch.argmax(mask[:,:,0]).item() 110 | coin = min(left,top) 111 | size = self.resize - 2*coin 112 | image = TF.crop(image, coin, coin, size, size) 113 | depth = TF.crop(depth, coin, 
coin, size, size) 114 | segmentation = TF.crop(segmentation, coin, coin, size, size) 115 | #Resize 116 | image = transforms.Resize((self.resize, self.resize))(image) 117 | depth = transforms.Resize((self.resize, self.resize))(depth) 118 | segmentation = transforms.Resize((self.resize, self.resize), interpolation=transforms.InterpolationMode.NEAREST)(segmentation) 119 | # show([imgorig, image, depth, segmentation]) 120 | # exit(0) 121 | return image, depth, segmentation 122 | -------------------------------------------------------------------------------- /FOD/utils.py: -------------------------------------------------------------------------------- 1 | import os, errno 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau 6 | 7 | from glob import glob 8 | from PIL import Image 9 | from torchvision import transforms, utils 10 | 11 | from FOD.Loss import ScaleAndShiftInvariantLoss 12 | from FOD.Custom_augmentation import ToMask 13 | 14 | def get_total_paths(path, ext): 15 | return glob(os.path.join(path, '*'+ext)) 16 | 17 | def get_splitted_dataset(config, split, dataset_name, path_images, path_depths, path_segmentation): 18 | list_files = [os.path.basename(im) for im in path_images] 19 | np.random.seed(config['General']['seed']) 20 | np.random.shuffle(list_files) 21 | if split == 'train': 22 | selected_files = list_files[:int(len(list_files)*config['Dataset']['splits']['split_train'])] 23 | elif split == 'val': 24 | selected_files = list_files[int(len(list_files)*config['Dataset']['splits']['split_train']):int(len(list_files)*config['Dataset']['splits']['split_train'])+int(len(list_files)*config['Dataset']['splits']['split_val'])] 25 | else: 26 | selected_files = list_files[int(len(list_files)*config['Dataset']['splits']['split_train'])+int(len(list_files)*config['Dataset']['splits']['split_val']):] 27 | 28 | path_images = [os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_images'], im[:-4]+config['Dataset']['extensions']['ext_images']) for im in selected_files] 29 | path_depths = [os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_depths'], im[:-4]+config['Dataset']['extensions']['ext_depths']) for im in selected_files] 30 | path_segmentation = [os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_segmentations'], im[:-4]+config['Dataset']['extensions']['ext_segmentations']) for im in selected_files] 31 | return path_images, path_depths, path_segmentation 32 | 33 | def get_transforms(config): 34 | im_size = config['Dataset']['transforms']['resize'] 35 | transform_image = transforms.Compose([ 36 | transforms.Resize((im_size, im_size)), 37 | transforms.ToTensor(), 38 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 39 | ]) 40 | transform_depth = transforms.Compose([ 41 | transforms.Resize((im_size, im_size)), 42 | transforms.Grayscale(num_output_channels=1) , 43 | transforms.ToTensor() 44 | ]) 45 | transform_seg = transforms.Compose([ 46 | transforms.Resize((im_size, im_size), interpolation=transforms.InterpolationMode.NEAREST), 47 | ToMask(config['Dataset']['classes']), 48 | ]) 49 | return transform_image, transform_depth, transform_seg 50 | 51 | def get_losses(config): 52 | def NoneFunction(a, b): 53 | return 0 54 | loss_depth = NoneFunction 55 | loss_segmentation = NoneFunction 56 | type = config['General']['type'] 57 | if type == 
"full" or type=="depth": 58 | if config['General']['loss_depth'] == 'mse': 59 | loss_depth = nn.MSELoss() 60 | elif config['General']['loss_depth'] == 'ssi': 61 | loss_depth = ScaleAndShiftInvariantLoss() 62 | if type == "full" or type=="segmentation": 63 | if config['General']['loss_segmentation'] == 'ce': 64 | loss_segmentation = nn.CrossEntropyLoss() 65 | return loss_depth, loss_segmentation 66 | 67 | def create_dir(directory): 68 | try: 69 | os.makedirs(directory) 70 | except OSError as e: 71 | if e.errno != errno.EEXIST: 72 | raise 73 | 74 | # def get_optimizer(config, net): 75 | # if config['General']['optim'] == 'adam': 76 | # optimizer = optim.Adam(net.parameters(), lr=config['General']['lr']) 77 | # elif config['General']['optim'] == 'sgd': 78 | # optimizer = optim.SGD(net.parameters(), lr=config['General']['lr'], momentum=config['General']['momentum']) 79 | # return optimizer 80 | 81 | def get_optimizer(config, net): 82 | names = set([name.split('.')[0] for name, _ in net.named_modules()]) - set(['', 'transformer_encoders']) 83 | params_backbone = net.transformer_encoders.parameters() 84 | params_scratch = list() 85 | for name in names: 86 | params_scratch += list(eval("net."+name).parameters()) 87 | 88 | if config['General']['optim'] == 'adam': 89 | optimizer_backbone = optim.Adam(params_backbone, lr=config['General']['lr_backbone']) 90 | optimizer_scratch = optim.Adam(params_scratch, lr=config['General']['lr_scratch']) 91 | elif config['General']['optim'] == 'sgd': 92 | optimizer_backbone = optim.SGD(params_backbone, lr=config['General']['lr_backbone'], momentum=config['General']['momentum']) 93 | optimizer_scratch = optim.SGD(params_scratch, lr=config['General']['lr_scratch'], momentum=config['General']['momentum']) 94 | return optimizer_backbone, optimizer_scratch 95 | 96 | def get_schedulers(optimizers): 97 | return [ReduceLROnPlateau(optimizer) for optimizer in optimizers] 98 | -------------------------------------------------------------------------------- /FocusOnDepth.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antocad/FocusOnDepth/17feb70d927752965b981a98e8359d94227d561e/FocusOnDepth.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Antoine Cadiou (github.com/antocad) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Focus On Depth - A single DPT encoder for AutoFocus application and Dense Prediction Tasks 2 | 3 | ![pytorch](https://img.shields.io/badge/pytorch-v1.10-green.svg?style=plastic) 4 | ![wandb](https://img.shields.io/badge/wandb-v0.12.10-blue.svg?style=plastic) 5 | ![scipy](https://img.shields.io/badge/scipy-v1.7.3-orange.svg?style=plastic) 6 | 7 | 8 | 9 |

10 | 11 | 12 | 13 | 14 | 15 | ## Abstract 16 | 17 | 19 | > Depth estimation is a classic task in computer vision, which is of 20 | great significance for many applications such as augmented 21 | reality, target tracking and autonomous driving. We first 22 | summarize the deep learning models for monocular depth 23 | estimation. Second, we implement a recent Vision 24 | Transformer-based architecture for this task. We seek 25 | to improve it by adding a segmentation head in order to 26 | perform multi-task learning using a custom-built dataset. 27 | Third, we apply our model to in-the-wild images (i.e. without control over the environment, the distance 28 | and size of the objects of interest, and their physical properties 29 | (rotation, dynamics, etc.)) for an auto-focus application on 30 | humans and give a qualitative comparison with other 31 | methods. 32 | 33 | ## :zap: New! Web demo 34 | 35 | You can check out the web demo, hosted on Hugging Face and powered by Gradio, [here](https://huggingface.co/spaces/ybelkada/FocusOnDepth). 36 | 37 | ## :pushpin: Requirements 38 | 39 | Run: ``` pip install -r requirements.txt ``` 40 | 41 | ## :rocket: Running the model 42 | 43 | You can first download one of the models from the model zoo: 44 | 45 | ### :bank: Model zoo 46 | 47 | Get the links of the following models: 48 | 49 | + [```FocusOnDepth_vit_base_patch16_384.p```](https://drive.google.com/file/d/1Q7I777FW_dz5p5UlMsD6aktWQ1eyR1vN/view?usp=sharing) 50 | + Other models coming soon... 51 | 52 | And put the ```.p``` file into the directory ```models/```. After that, you need to update the ```config.json``` ([Tutorial here](https://github.com/antocad/FocusOnDepth/wiki/Config-Wiki)) according to the pre-trained model you have chosen before running the predictions (for example, if you load a depth-only model, you have to set ```type``` to ```depth```). 53 | 54 | ### :dart: Run a prediction 55 | 56 | Put your input images (which must be ```.png``` or ```.jpg```) into the ```input/``` folder. Then, just run ```python run.py``` and you should get the depth maps as well as the segmentation masks in the ```output/``` folder. 57 | 58 | 59 | ## :hammer: Training 60 | 61 | ### :wrench: Build the dataset 62 | 63 | Our model is trained on a combination of: 64 | + [inria movie 3d dataset](https://www.di.ens.fr/willow/research/stereoseg/) | [view on Kaggle](https://www.kaggle.com/antocad/inria-fod/) 65 | + [NYU2 Dataset](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html) | [view on Kaggle](https://www.kaggle.com/antocad/nyuv2-fod) 66 | + [PoseTrack](https://posetrack.net/) | [view on Kaggle](https://www.kaggle.com/antocad/posetrack-fod) 67 | 68 | ### :pencil: Configure ```config.json``` 69 | 70 | Please refer to our [config wiki](https://github.com/antocad/FocusOnDepth/wiki/Config-Wiki) to learn how to modify the config file before running a training. 71 | 72 | ### :nut_and_bolt: Run the training script 73 | After that, you can simply run the training script: ```python train.py``` 74 | 75 | 76 | ## :scroll: Citations 77 | 78 | Our work is based on the work of Ranftl et al., so please do not forget to cite their work! :) 79 | You can also check our [report](https://github.com/antocad/FocusOnDepth/blob/master/FocusOnDepth.pdf) if you need more details.
80 | 81 | ``` 82 | @article{DPT, 83 | author = {Ren{\'{e}} Ranftl and 84 | Alexey Bochkovskiy and 85 | Vladlen Koltun}, 86 | title = {Vision Transformers for Dense Prediction}, 87 | journal = {CoRR}, 88 | volume = {abs/2103.13413}, 89 | year = {2021}, 90 | url = {https://arxiv.org/abs/2103.13413}, 91 | eprinttype = {arXiv}, 92 | eprint = {2103.13413}, 93 | timestamp = {Wed, 07 Apr 2021 15:31:46 +0200}, 94 | biburl = {https://dblp.org/rec/journals/corr/abs-2103-13413.bib}, 95 | bibsource = {dblp computer science bibliography, https://dblp.org} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "General":{ 3 | "device":"cuda", 4 | "type":"full", 5 | "model_timm":"vit_base_patch16_384", 6 | "emb_dim":768, 7 | "hooks":[2, 5, 8, 11], 8 | "read":"projection", 9 | "resample_dim":256, 10 | "optim":"adam", 11 | "lr_backbone":1e-5, 12 | "lr_scratch":3e-4, 13 | "loss_depth":"ssi", 14 | "loss_segmentation":"ce", 15 | "momentum":0.9, 16 | "epochs":20, 17 | "batch_size":1, 18 | "path_model":"models", 19 | "path_predicted_images":"output", 20 | "seed":0, 21 | "patch_size":16 22 | }, 23 | "Dataset":{ 24 | "paths":{ 25 | "path_dataset":"./datasets", 26 | "list_datasets":["inria", "nyuv2", "posetrack"], 27 | "path_images":"images", 28 | "path_segmentations":"segmentations", 29 | "path_depths":"depths" 30 | }, 31 | "extensions":{ 32 | "ext_images":".jpg", 33 | "ext_segmentations":".png", 34 | "ext_depths":".jpg" 35 | }, 36 | "splits":{ 37 | "split_train":0.6, 38 | "split_val":0.2, 39 | "split_test":0.2 40 | }, 41 | "transforms":{ 42 | "resize":384, 43 | "p_flip":0.5, 44 | "p_crop":0.3, 45 | "p_rot":0.2 46 | }, 47 | "classes":{ 48 | "1": { 49 | "name": "person", 50 | "color": [150,5,61] 51 | } 52 | } 53 | }, 54 | "wandb":{ 55 | "enable":false, 56 | "username":"younesbelkada", 57 | "images_to_show":3, 58 | "im_h":540, 59 | "im_w":980 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /images/pull_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antocad/FocusOnDepth/17feb70d927752965b981a98e8359d94227d561e/images/pull_figure.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.3.2 2 | matplotlib==3.1.2 3 | numpy==1.17.4 4 | opencv_python==4.4.0.46 5 | Pillow==9.0.1 6 | scipy==1.5.4 7 | timm==0.4.12 8 | torch==1.10.2 9 | torchvision==0.10.0 10 | tqdm==4.51.0 11 | wandb -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import json 2 | from glob import glob 3 | from FOD.Predictor import Predictor 4 | 5 | with open('config.json', 'r') as f: 6 | config = json.load(f) 7 | 8 | input_images = glob('input/*.jpg') + glob('input/*.png') 9 | predictor = Predictor(config, input_images) 10 | predictor.run() 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import torch 4 | 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data import ConcatDataset 7 | 8 | from FOD.Trainer import Trainer 9 | 
from FOD.dataset import AutoFocusDataset 10 | 11 | with open('config.json', 'r') as f: 12 | config = json.load(f) 13 | np.random.seed(config['General']['seed']) 14 | 15 | list_data = config['Dataset']['paths']['list_datasets'] 16 | 17 | ## train set 18 | autofocus_datasets_train = [] 19 | for dataset_name in list_data: 20 | autofocus_datasets_train.append(AutoFocusDataset(config, dataset_name, 'train')) 21 | train_data = ConcatDataset(autofocus_datasets_train) 22 | train_dataloader = DataLoader(train_data, batch_size=config['General']['batch_size'], shuffle=True) 23 | 24 | ## validation set 25 | autofocus_datasets_val = [] 26 | for dataset_name in list_data: 27 | autofocus_datasets_val.append(AutoFocusDataset(config, dataset_name, 'val')) 28 | val_data = ConcatDataset(autofocus_datasets_val) 29 | val_dataloader = DataLoader(val_data, batch_size=config['General']['batch_size'], shuffle=True) 30 | 31 | trainer = Trainer(config) 32 | trainer.train(train_dataloader, val_dataloader) 33 | --------------------------------------------------------------------------------
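For reference, a minimal sketch of how a single dataset split can be inspected on its own, outside of train.py. It assumes the config.json above and a prepared ```datasets/``` folder as described in the README; ```'nyuv2'``` is one of the dataset names listed in the config, and the printed shapes follow from the 384x384 resize in the transforms.

```python
import json
from FOD.dataset import AutoFocusDataset

with open('config.json', 'r') as f:
    config = json.load(f)

# build one split of one dataset and look at a single (image, depth, segmentation) triplet
ds = AutoFocusDataset(config, 'nyuv2', 'train')
image, depth, segmentation = ds[0]

# expected with the default config: torch.Size([3, 384, 384]) torch.Size([1, 384, 384]) torch.Size([1, 384, 384])
print(image.shape, depth.shape, segmentation.shape)
```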