├── README.md ├── dataset └── nyuv2_splits.mat ├── download.sh ├── figure ├── displacement_field.png └── toy.png ├── lib ├── backbone │ └── resnet.py ├── datareader │ └── img_utils.py ├── datasets │ └── nyu.py ├── engine │ ├── engine.py │ ├── logger.py │ ├── lr_policy.py │ └── version.py ├── layers │ └── basic_module.py ├── misc │ └── utils.py └── utils │ ├── init_func.py │ └── pyt_utils.py └── model └── nyu ├── df_nyu_depth_only ├── config.py ├── dataloader.py ├── df.py ├── train.py └── unet.py ├── df_nyu_rgb_guidance ├── config.py ├── dataloader.py ├── df.py ├── train.py └── unet.py └── df_nyu_rgb_guidance_pos_encoding_attention_loss ├── config.py ├── dataloader.py ├── df.py ├── train.py └── unet.py /README.md: -------------------------------------------------------------------------------- 1 | # Displacement_Field 2 | Official implementation of the paper **Predicting Sharp and Accurate Occlusion Boundaries in Monocular Depth Estimation Using Displacement Fields** (CVPR 2020) [paper link](https://arxiv.org/abs/2002.12730) 3 | 4 | NYUv2-OC++ dataset (for test use only) [download link](https://drive.google.com/file/d/1Fk8uuH3oJJhyCN-4ffD3mdtCq2l4geJc/view) 5 | 6 | ## Visualization 7 | ### 1D example 8 | ![1D](./figure/toy.png) 9 | ------ 10 | ### 2D example on a blurry depth image (prediction of a depth estimator) 11 | ![2D](./figure/displacement_field.png) 12 | ------ 13 | ## Requirements 14 | - PyTorch >= 0.4 15 | - OpenCV 16 | - CUDA >= 8.0 (earlier versions untested) 17 | - EasyDict 18 | 19 | ## Data Preparation 20 | ```bash 21 | sh download.sh 22 | ``` 23 | 24 | ## Training 25 | ```bash 26 | # Use depth only as input 27 | cd model/nyu/df_nyu_depth_only 28 | python train.py -d 0 29 | 30 | # Use the RGB image as guidance 31 | cd model/nyu/df_nyu_rgb_guidance 32 | python train.py -d 0 33 | ``` 34 | ## Citation 35 | ```bibtex 36 | @InProceedings{Ramamonjisoa_2020_CVPR, 37 | author = {Ramamonjisoa, Michael and Du, Yuming and Lepetit, Vincent}, 38 | title = {Predicting Sharp and Accurate Occlusion Boundaries in Monocular Depth Estimation Using Displacement Fields}, 39 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 40 | month = {June}, 41 | year = {2020} 42 | } 43 | ``` 44 | 45 | ## Miscellaneous 46 | The model can be trained with only synthetic data ([SceneNet](https://robotvault.bitbucket.io/scenenet-rgbd.html), for example) and generalizes naturally to real data. 47 | 48 | ## Acknowledgement 49 | The code is based on [TorchSeg](https://github.com/ycszen/TorchSeg). 50 | 51 | The NYUv2-OC++ dataset was annotated manually by 4 PhD students majoring in computer vision. Special thanks to [Yang Xiao](https://youngxiao13.github.io/) and [Xuchong Qiu](https://imagine-lab.enpc.fr/staff-members/xuchong-qiu/) for their help in annotating the NYUv2-OC++ dataset.
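The third configuration in the tree, `df_nyu_rgb_guidance_pos_encoding_attention_loss` (RGB guidance plus positional encoding and an attention loss), is trained the same way as the other two:

```bash
cd model/nyu/df_nyu_rgb_guidance_pos_encoding_attention_loss
python train.py -d 0
```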
52 | -------------------------------------------------------------------------------- /dataset/nyuv2_splits.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dulucas/Displacement_Field/75c08f36774ed0a9a327fffd7e9da901cb0823c3/dataset/nyuv2_splits.mat -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | wget http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat -P dataset/ 2 | -------------------------------------------------------------------------------- /figure/displacement_field.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dulucas/Displacement_Field/75c08f36774ed0a9a327fffd7e9da901cb0823c3/figure/displacement_field.png -------------------------------------------------------------------------------- /figure/toy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dulucas/Displacement_Field/75c08f36774ed0a9a327fffd7e9da901cb0823c3/figure/toy.png -------------------------------------------------------------------------------- /lib/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | # Taken from pytorch official repository: 2 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from torch import load as th_load 7 | 8 | import os 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152', 'Bottleneck'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1, downsample=None): 39 | super(BasicBlock, self).__init__() 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = nn.BatchNorm2d(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4  # class-level attribute required by ResNet._make_layer; __init__ below may override it per instance 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, expansion=4, dilation=(1, 1)): 71
| super(Bottleneck, self).__init__() 72 | self.expansion = expansion 73 | self.conv1 = conv1x1(inplanes, planes) 74 | self.bn1 = nn.BatchNorm2d(planes) 75 | # self.conv2 = conv3x3(planes, planes, stride) # changed from original Bottleneck in ResNet without dilation 76 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 77 | padding=dilation[1], bias=False, 78 | dilation=dilation[1]) 79 | self.bn2 = nn.BatchNorm2d(planes) 80 | self.conv3 = conv1x1(planes, planes * self.expansion) 81 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.downsample = downsample 84 | self.stride = stride 85 | 86 | def forward(self, x): 87 | residual = x 88 | 89 | out = self.conv1(x) 90 | out = self.bn1(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv2(out) 94 | out = self.bn2(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv3(out) 98 | out = self.bn3(out) 99 | 100 | if self.downsample is not None: 101 | residual = self.downsample(x) 102 | 103 | out += residual 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | 109 | class ResNet(nn.Module): 110 | 111 | def __init__(self, block, layers, num_classes=1000): 112 | super(ResNet, self).__init__() 113 | self.inplanes = 64 114 | self.conv1_ = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, # 1-channel (depth) stem replacing the original 3-channel (RGB) 7x7x64 conv; renamed to conv1_ so pretrained RGB conv1 weights are skipped 115 | bias=False) 116 | self.bn1 = nn.BatchNorm2d(64) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 119 | self.layer1 = self._make_layer(block, 64, layers[0]) 120 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 121 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 122 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 123 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 124 | self.fc = nn.Linear(512 * block.expansion, num_classes) 125 | 126 | # initialize weights 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 130 | elif isinstance(m, nn.BatchNorm2d): 131 | nn.init.constant_(m.weight, 1) 132 | nn.init.constant_(m.bias, 0) 133 | 134 | def _make_layer(self, block, planes, blocks, stride=1): 135 | downsample = None 136 | # a projection shortcut is needed whenever the block changes resolution (stride != 1) or channel width (inplanes != planes * expansion) 137 | if stride != 1 or self.inplanes != planes * block.expansion: 138 | downsample = nn.Sequential( 139 | conv1x1(self.inplanes, planes * block.expansion, stride), 140 | nn.BatchNorm2d(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for _ in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | output = [] 153 | x = self.conv1_(x) 154 | x = self.bn1(x) 155 | x = self.relu(x) 156 | x = self.maxpool(x) 157 | 158 | x = self.layer1(x) 159 | output.append(x) 160 | x = self.layer2(x) 161 | output.append(x) 162 | x = self.layer3(x) 163 | output.append(x) 164 | x = self.layer4(x) 165 | output.append(x) 166 | 167 | return output 168 | 169 | 170 | def resnet18(pretrained=False, **kwargs): 171 | """Constructs a ResNet-18 model. 
172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) # strict=False: the pretrained RGB conv1 has no counterpart for the 1-channel conv1_ 178 | return model 179 | 180 | 181 | def resnet34(pretrained=False, **kwargs): 182 | """Constructs a ResNet-34 model. 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 187 | if pretrained: 188 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False) 189 | return model 190 | 191 | 192 | def resnet50(pretrained=False, **kwargs): 193 | """Constructs a ResNet-50 model. 194 | Args: 195 | pretrained (bool): If True, returns a model pre-trained on ImageNet 196 | """ 197 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 198 | model_path = os.path.join('models', model_urls['resnet50'].split('/')[-1]) 199 | if pretrained: 200 | if os.path.exists(model_path): 201 | model.load_state_dict(th_load(model_path), strict=False) 202 | else: 203 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False) 204 | return model 205 | 206 | 207 | def resnet101(pretrained=False, **kwargs): 208 | """Constructs a ResNet-101 model. 209 | Args: 210 | pretrained (bool): If True, returns a model pre-trained on ImageNet 211 | """ 212 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 213 | if pretrained: 214 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False) 215 | return model 216 | 217 | 218 | def resnet152(pretrained=False, **kwargs): 219 | """Constructs a ResNet-152 model. 220 | Args: 221 | pretrained (bool): If True, returns a model pre-trained on ImageNet 222 | """ 223 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 224 | if pretrained: 225 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False) 226 | return model 227 | -------------------------------------------------------------------------------- /lib/datareader/img_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import numbers 4 | import random 5 | import collections.abc # Iterable lives in collections.abc; it was removed from the collections top level in Python 3.10 6 | 7 | def rgb2gray(img): 8 | R = np.array(img[:, :, 0]) 9 | G = np.array(img[:, :, 1]) 10 | B = np.array(img[:, :, 2]) 11 | 12 | R = (R *.299) 13 | G = (G *.587) 14 | B = (B *.114) 15 | 16 | Avg = (R+G+B) 17 | 18 | return Avg 19 | 20 | def get_2dshape(shape, *, zero=True): 21 | if not isinstance(shape, collections.abc.Iterable): 22 | shape = int(shape) 23 | shape = (shape, shape) 24 | else: 25 | h, w = map(int, shape) 26 | shape = (h, w) 27 | if zero: 28 | minv = 0 29 | else: 30 | minv = 1 31 | 32 | assert min(shape) >= minv, 'invalid shape: {}'.format(shape) 33 | return shape 34 | 35 | 36 | def random_crop_pad_to_shape(img, crop_pos, crop_size, pad_label_value): 37 | h, w = img.shape[:2] 38 | start_crop_h, start_crop_w = crop_pos 39 | assert ((start_crop_h < h) and (start_crop_h >= 0)) 40 | assert ((start_crop_w < w) and (start_crop_w >= 0)) 41 | 42 | crop_size = get_2dshape(crop_size) 43 | crop_h, crop_w = crop_size 44 | 45 | img_crop = img[start_crop_h:start_crop_h + crop_h, 46 | start_crop_w:start_crop_w + crop_w, ...] 
47 | 48 | img_, margin = pad_image_to_shape(img_crop, crop_size, cv2.BORDER_CONSTANT, 49 | pad_label_value) 50 | 51 | return img_, margin 52 | 53 | 54 | def generate_random_crop_pos(ori_size, crop_size): 55 | ori_size = get_2dshape(ori_size) 56 | h, w = ori_size 57 | 58 | crop_size = get_2dshape(crop_size) 59 | crop_h, crop_w = crop_size 60 | 61 | pos_h, pos_w = 0, 0 62 | 63 | if h > crop_h: 64 | pos_h = random.randint(0, h - crop_h) # randint is inclusive on both ends 65 | 66 | if w > crop_w: 67 | pos_w = random.randint(0, w - crop_w) 68 | 69 | return pos_h, pos_w 70 | 71 | 72 | def pad_image_to_shape(img, shape, border_mode, value): 73 | margin = np.zeros(4, np.uint32) 74 | shape = get_2dshape(shape) 75 | pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0 76 | pad_width = shape[1] - img.shape[1] if shape[1] - img.shape[1] > 0 else 0 77 | 78 | margin[0] = pad_height // 2 79 | margin[1] = pad_height // 2 + pad_height % 2 80 | margin[2] = pad_width // 2 81 | margin[3] = pad_width // 2 + pad_width % 2 82 | 83 | img = cv2.copyMakeBorder(img, margin[0], margin[1], margin[2], margin[3], 84 | border_mode, value=value) 85 | 86 | return img, margin 87 | 88 | 89 | def pad_image_size_to_multiples_of(img, multiple, pad_value): 90 | h, w = img.shape[:2] 91 | d = multiple 92 | 93 | def canonicalize(s): 94 | v = s // d 95 | return (v + (v * d != s)) * d 96 | 97 | th, tw = map(canonicalize, (h, w)) 98 | 99 | return pad_image_to_shape(img, (th, tw), cv2.BORDER_CONSTANT, pad_value) 100 | 101 | 102 | def resize_ensure_shortest_edge(img, edge_length, 103 | interpolation_mode=cv2.INTER_LINEAR): 104 | assert isinstance(edge_length, int) and edge_length > 0, edge_length 105 | h, w = img.shape[:2] 106 | if h < w: 107 | ratio = float(edge_length) / h 108 | th, tw = edge_length, max(1, int(ratio * w)) 109 | else: 110 | ratio = float(edge_length) / w 111 | th, tw = max(1, int(ratio * h)), edge_length 112 | img = cv2.resize(img, (tw, th), interpolation=interpolation_mode) # interpolation must be passed by keyword; the third positional slot of cv2.resize is dst 113 | 114 | return img 115 | 116 | 117 | def random_scale(img, gt, scales): 118 | scale = random.choice(scales) 119 | if scale > 1: 120 | interpolation_ = cv2.INTER_CUBIC 121 | else: 122 | interpolation_ = cv2.INTER_LINEAR 123 | sh = int(img.shape[0] * scale) 124 | sw = int(img.shape[1] * scale) 125 | img = cv2.resize(img, (sw, sh), interpolation=interpolation_) 126 | gt = cv2.resize(gt, (sw, sh), interpolation=interpolation_) 127 | gt /= scale 128 | 129 | return img, gt, scale 130 | 131 | 132 | def random_scale_with_length(img, gt, length): 133 | size = random.choice(length) 134 | sh = size 135 | sw = size 136 | img = cv2.resize(img, (sw, sh), interpolation=cv2.INTER_LINEAR) 137 | gt = cv2.resize(gt, (sw, sh), interpolation=cv2.INTER_LINEAR) 138 | 139 | return img, gt, size 140 | 141 | 142 | def random_mirror(img, gt): 143 | if random.random() >= 0.5: 144 | img = cv2.flip(img, 1) 145 | gt = cv2.flip(gt, 1) 146 | return img, gt 147 | 148 | 149 | def random_rotation(img, gt): 150 | angle = random.random() * 20 - 10 151 | h, w = img.shape[:2] 152 | rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) 153 | img = cv2.warpAffine(img, rotation_matrix, (w, h), flags=cv2.INTER_LINEAR) 154 | gt = cv2.warpAffine(gt, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST) 155 | 156 | return img, gt 157 | 158 | 159 | def random_gaussian_blur(img): 160 | gauss_size = random.choice([1, 3, 5, 7]) 161 | if gauss_size > 1: 162 | # do the gaussian blur 163 | img = cv2.GaussianBlur(img, (gauss_size, gauss_size), 0) 164 | 165 | return img 166 | 167 | 168 | def
center_crop(img, shape): 169 | h, w = shape[0], shape[1] 170 | y = (img.shape[0] - h) // 2 171 | x = (img.shape[1] - w) // 2 172 | return img[y:y + h, x:x + w] 173 | 174 | 175 | def random_crop(img, gt, size): 176 | if isinstance(size, numbers.Number): 177 | size = (int(size), int(size)) 178 | else: 179 | size = size 180 | 181 | h, w = img.shape[:2] 182 | crop_h, crop_w = size[0], size[1] 183 | 184 | if h > crop_h: 185 | x = random.randint(0, h - crop_h) # randint is inclusive on both ends 186 | img = img[x:x + crop_h, :, :] 187 | gt = gt[x:x + crop_h, :] 188 | 189 | if w > crop_w: 190 | x = random.randint(0, w - crop_w) 191 | img = img[:, x:x + crop_w, :] 192 | gt = gt[:, x:x + crop_w] 193 | 194 | return img, gt 195 | 196 | 197 | def normalize(img, mean, std): 198 | # pytorch pretrained models expect inputs in the range 0-1 199 | img = img.astype(np.float32) / 255.0 200 | img = img - mean 201 | img = img / std 202 | 203 | return img 204 | 205 | def normalize_depth(depth): 206 | if depth.max() == 0: 207 | return depth 208 | m = depth[depth>0].min() 209 | M = depth[depth>0].max() 210 | depth[depth>0] = (depth[depth>0] - m) / (M - m) 211 | 212 | return depth 213 | 214 | def random_uniform_gaussian_blur(depth, kernel_range, max_kernel): 215 | gauss_size = int(random.uniform(kernel_range[0], kernel_range[1]) * max_kernel) 216 | if gauss_size % 2 == 0: 217 | gauss_size += 1 218 | if gauss_size > 1: 219 | # do the gaussian blur 220 | depth = cv2.GaussianBlur(depth, (gauss_size, gauss_size), 0) 221 | 222 | return depth 223 | 224 | def updown_sampling(depth, scale, interpolation): 225 | if interpolation == 'LINEAR': 226 | interpolation = cv2.INTER_LINEAR 227 | elif interpolation == 'CUBIC': 228 | interpolation = cv2.INTER_CUBIC 229 | elif interpolation == 'NEAREST': 230 | interpolation = cv2.INTER_NEAREST 231 | h, w = depth.shape[:2] 232 | depth = cv2.resize(depth, (w // scale, h // scale), interpolation=interpolation) 233 | depth = cv2.resize(depth, (w, h), interpolation=interpolation) 234 | 235 | return depth 236 | 237 | def generate_mask_by_shifting(depth, scale=1, kernel=10, step_size=1, delta=5): 238 | if scale > 1: 239 | depth = depth[::scale, ::scale] 240 | affinities = np.zeros(depth.shape) 241 | 242 | depth_pad = np.pad(depth, ((kernel//2, kernel//2), (kernel//2, kernel//2)), 'edge') 243 | rows, cols = depth.shape[0], depth.shape[1] 244 | for i in range(-(kernel//2), kernel//2 + 1, step_size): 245 | for j in range(-(kernel//2), kernel//2 + 1, step_size): 246 | if i == 0 and j == 0: 247 | continue 248 | affinities[max(i, 0):min(rows, rows+i), max(j, 0):min(cols, cols+j)] = \ 249 | np.abs(depth_pad[max(-i, 0):min(rows, rows-i), max(-j, 0):min(cols, cols-j)] - depth[max(i, 0):min(rows, rows+i), max(j, 0):min(cols, cols+j)]) 250 | 251 | return 1-np.exp(-affinities**2 * delta) 252 | 253 | -------------------------------------------------------------------------------- /lib/datasets/nyu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | # @Time : 2017/12/16 8:41 PM 4 | # @Author : yuchangqian 5 | # @Contact : changqian_yu@163.com 6 | # @File : nyu.py 7 | 8 | import os 9 | import time 10 | import cv2 11 | import torch 12 | import numpy as np 13 | 14 | import torch.utils.data as data 15 | 16 | 17 | class NYUDataset(data.Dataset): 18 | def __init__(self, setting, split_name, preprocess=None, 19 | file_length=None): 20 | super(NYUDataset, self).__init__() 21 | self._split_name = split_name 22 | self._imgs_source, self._labels_source = 
self._load_data_file(setting['data_source']) 23 | self._train_test_splits = self._load_split_file(setting['train_test_splits']) 24 | self._file_names = self._get_file_names(split_name, self._train_test_splits) 25 | self._file_length = file_length 26 | self.preprocess = preprocess 27 | 28 | def __len__(self): 29 | if self._file_length is not None: 30 | return self._file_length 31 | return len(self._file_names) 32 | 33 | def __getitem__(self, index): 34 | if self._file_length is not None: 35 | index_ = self._construct_new_file_names(self._file_length)[index] 36 | else: 37 | index_ = self._file_names[index] 38 | 39 | img, gt = self._fetch_data(index_) 40 | #img = img[:, :, ::-1] 41 | if self.preprocess is not None: 42 | img, ori, gt, mask, extra_dict = self.preprocess(img, gt) 43 | 44 | if self._split_name == 'train': 45 | img = torch.from_numpy(np.ascontiguousarray(img)).float() 46 | gt = torch.from_numpy(np.ascontiguousarray(gt)).float() 47 | ori = torch.from_numpy(np.ascontiguousarray(ori)).float() 48 | mask = torch.from_numpy(np.ascontiguousarray(mask)).float() 49 | 50 | if self.preprocess is not None and extra_dict is not None: 51 | for k, v in extra_dict.items(): 52 | extra_dict[k] = torch.from_numpy(np.ascontiguousarray(v)) 53 | if 'label' in k: 54 | extra_dict[k] = extra_dict[k].float() 55 | if 'img' in k: 56 | extra_dict[k] = extra_dict[k].float() 57 | 58 | output_dict = dict(guidance=img, data=ori, label=gt, mask=mask, 59 | n=len(self._file_names)) 60 | if self.preprocess is not None and extra_dict is not None: 61 | output_dict.update(**extra_dict) 62 | 63 | return output_dict 64 | 65 | def _fetch_data(self, index): 66 | img = self._imgs_source[index] 67 | gt = self._labels_source[index] 68 | 69 | return img, gt 70 | 71 | def _load_split_file(self, split_file_path): 72 | from scipy.io import loadmat 73 | split_file = loadmat(split_file_path) 74 | split_file['train'] = [i[0] - 1 for i in split_file['trainNdxs']] 75 | split_file['test'] = [i[0] - 1 for i in split_file['testNdxs']] 76 | return split_file 77 | 78 | def _load_data_file(self, data_source_path): 79 | import h5py 80 | data_file = h5py.File(data_source_path, 'r') 81 | depths = np.array(data_file['depths'], dtype=np.float64).transpose(0, 2, 1) 82 | images = np.array(data_file['images'], dtype=np.float64).transpose(0, 3, 2, 1) 83 | return images, depths 84 | 85 | 86 | def _get_file_names(self, split_name, split_file): 87 | assert split_name in ['train', 'test'] 88 | file_names = split_file[split_name] 89 | 90 | return file_names 91 | 92 | def _construct_new_file_names(self, length): 93 | assert isinstance(length, int) 94 | files_len = len(self._file_names) 95 | new_file_names = self._file_names * (length // files_len) 96 | 97 | rand_indices = torch.randperm(files_len).tolist() 98 | new_indices = rand_indices[:length % files_len] 99 | 100 | new_file_names += [self._file_names[i] for i in new_indices] 101 | 102 | return new_file_names 103 | 104 | def get_length(self): 105 | return self.__len__() 106 | 107 | 108 | -------------------------------------------------------------------------------- /lib/engine/engine.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | # @Time : 2018/8/2 3:23 PM 4 | # @Author : yuchangqian 5 | # @Contact : changqian_yu@163.com 6 | # @File : engine.py 7 | import os 8 | import os.path as osp 9 | import time 10 | import argparse 11 | 12 | import torch 13 | import torch.distributed as dist 14 | 15 | from .logger import 
get_logger 16 | from .version import __version__ 17 | from utils.pyt_utils import load_model, parse_devices, extant_file, link_file, \ 18 | ensure_dir 19 | 20 | logger = get_logger() 21 | 22 | 23 | class State(object): 24 | def __init__(self): 25 | self.epoch = 0 26 | self.iteration = 0 27 | self.dataloader = None 28 | self.model = None 29 | self.optimizer = None 30 | 31 | def register(self, **kwargs): 32 | for k, v in kwargs.items(): 33 | assert k in ['epoch', 'iteration', 'dataloader', 'model', 34 | 'optimizer'] 35 | setattr(self, k, v) 36 | 37 | 38 | class Engine(object): 39 | def __init__(self, custom_parser=None): 40 | self.version = __version__ 41 | logger.info( 42 | "PyTorch Version {}, Furnace Version {}".format(torch.__version__, 43 | self.version)) 44 | self.state = State() 45 | self.devices = None 46 | self.distributed = False 47 | 48 | if custom_parser is None: 49 | self.parser = argparse.ArgumentParser() 50 | else: 51 | assert isinstance(custom_parser, argparse.ArgumentParser) 52 | self.parser = custom_parser 53 | 54 | self.inject_default_parser() 55 | self.args = self.parser.parse_args() 56 | 57 | self.continue_state_object = self.args.continue_fpath 58 | 59 | if 'WORLD_SIZE' in os.environ: 60 | self.distributed = int(os.environ['WORLD_SIZE']) > 1 61 | 62 | self.devices = parse_devices(self.args.devices) 63 | 64 | def inject_default_parser(self): 65 | p = self.parser 66 | p.add_argument('-d', '--devices', default='', 67 | help='set data parallel training') 68 | p.add_argument('-c', '--continue', type=extant_file, 69 | metavar="FILE", 70 | dest="continue_fpath", 71 | help='continue from one certain checkpoint') 72 | p.add_argument('--local_rank', default=0, type=int, 73 | help='process rank on node') 74 | 75 | def register_state(self, **kwargs): 76 | self.state.register(**kwargs) 77 | 78 | def update_iteration(self, epoch, iteration): 79 | self.state.epoch = epoch 80 | self.state.iteration = iteration 81 | 82 | def save_checkpoint(self, path): 83 | logger.info("Saving checkpoint to file {}".format(path)) 84 | t_start = time.time() 85 | 86 | state_dict = {} 87 | 88 | from collections import OrderedDict 89 | new_state_dict = OrderedDict() 90 | for k, v in self.state.model.state_dict().items(): 91 | key = k 92 | if k.split('.')[0] == 'module': 93 | key = k[7:] 94 | new_state_dict[key] = v 95 | state_dict['model'] = new_state_dict 96 | state_dict['optimizer'] = self.state.optimizer.state_dict() 97 | state_dict['epoch'] = self.state.epoch 98 | state_dict['iteration'] = self.state.iteration 99 | 100 | t_iobegin = time.time() 101 | torch.save(state_dict, path) 102 | del state_dict 103 | del new_state_dict 104 | t_end = time.time() 105 | logger.info( 106 | "Save checkpoint to file {}, " 107 | "Time usage:\n\tprepare snapshot: {}, IO: {}".format( 108 | path, t_iobegin - t_start, t_end - t_iobegin)) 109 | 110 | def save_and_link_checkpoint(self, snapshot_dir, log_dir, log_dir_link): 111 | ensure_dir(snapshot_dir) 112 | if not osp.exists(log_dir_link): 113 | link_file(log_dir, log_dir_link) 114 | current_epoch_checkpoint = osp.join(snapshot_dir, 'epoch-{}.pth'.format( 115 | self.state.epoch)) 116 | self.save_checkpoint(current_epoch_checkpoint) 117 | last_epoch_checkpoint = osp.join(snapshot_dir, 118 | 'epoch-last.pth') 119 | link_file(current_epoch_checkpoint, last_epoch_checkpoint) 120 | 121 | def restore_checkpoint(self): 122 | t_start = time.time() 123 | #if self.distributed: 124 | # tmp = torch.load(self.continue_state_object, 125 | # map_location=lambda storage, loc: storage.cuda( 
126 | # self.local_rank)) 127 | #else: 128 | tmp = torch.load(self.continue_state_object) 129 | t_ioend = time.time() 130 | 131 | self.state.model = load_model(self.state.model, tmp['model'], 132 | True) 133 | self.state.optimizer.load_state_dict(tmp['optimizer']) 134 | self.state.epoch = tmp['epoch'] + 1 135 | self.state.iteration = tmp['iteration'] 136 | del tmp 137 | t_end = time.time() 138 | logger.info( 139 | "Load checkpoint from file {}, " 140 | "Time usage:\n\tIO: {}, restore snapshot: {}".format( 141 | self.continue_state_object, t_ioend - t_start, t_end - t_ioend)) 142 | 143 | def __enter__(self): 144 | return self 145 | 146 | def __exit__(self, type, value, tb): 147 | torch.cuda.empty_cache() 148 | if type is not None: 149 | logger.warning( 150 | "An exception occurred during the Engine run, " 151 | "giving up the running process") 152 | return False 153 | -------------------------------------------------------------------------------- /lib/engine/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | # @Time : 2018/8/2 11:48 AM 4 | # @Author : yuchangqian 5 | # @Contact : changqian_yu@163.com 6 | # @File : logger.py 7 | import os 8 | import sys 9 | import logging 10 | 11 | from utils.pyt_utils import ensure_dir 12 | 13 | 14 | _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO') 15 | _default_level = logging.getLevelName(_default_level_name.upper()) 16 | 17 | 18 | class LogFormatter(logging.Formatter): 19 | log_fout = None 20 | date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] ' 21 | date = '%(asctime)s ' 22 | msg = '%(message)s' 23 | 24 | def format(self, record): 25 | if record.levelno == logging.DEBUG: 26 | mcl, mtxt = self._color_dbg, 'DBG' 27 | elif record.levelno == logging.WARNING: 28 | mcl, mtxt = self._color_warn, 'WRN' 29 | elif record.levelno == logging.ERROR: 30 | mcl, mtxt = self._color_err, 'ERR' 31 | else: 32 | mcl, mtxt = self._color_normal, '' 33 | 34 | if mtxt: 35 | mtxt += ' ' 36 | 37 | if self.log_fout: 38 | self.__set_fmt(self.date_full + mtxt + self.msg) 39 | formatted = super(LogFormatter, self).format(record) 40 | # self.log_fout.write(formatted) 41 | # self.log_fout.write('\n') 42 | # self.log_fout.flush() 43 | return formatted 44 | 45 | self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg)) 46 | formatted = super(LogFormatter, self).format(record) 47 | 48 | return formatted 49 | 50 | if sys.version_info.major < 3: 51 | def __set_fmt(self, fmt): 52 | self._fmt = fmt 53 | else: 54 | def __set_fmt(self, fmt): 55 | self._style._fmt = fmt 56 | 57 | @staticmethod 58 | def _color_dbg(msg): 59 | return '\x1b[36m{}\x1b[0m'.format(msg) 60 | 61 | @staticmethod 62 | def _color_warn(msg): 63 | return '\x1b[1;31m{}\x1b[0m'.format(msg) 64 | 65 | @staticmethod 66 | def _color_err(msg): 67 | return '\x1b[1;4;31m{}\x1b[0m'.format(msg) 68 | 69 | @staticmethod 70 | def _color_omitted(msg): 71 | return '\x1b[35m{}\x1b[0m'.format(msg) 72 | 73 | @staticmethod 74 | def _color_normal(msg): 75 | return msg 76 | 77 | @staticmethod 78 | def _color_date(msg): 79 | return '\x1b[32m{}\x1b[0m'.format(msg) 80 | 81 | 82 | def get_logger(log_dir=None, log_file=None, formatter=LogFormatter): 83 | logger = logging.getLogger() 84 | logger.setLevel(_default_level) 85 | del logger.handlers[:] 86 | 87 | if log_dir and log_file: 88 | ensure_dir(log_dir) 89 | LogFormatter.log_fout = True 90 | file_handler = 
logging.FileHandler(log_file, mode='a') 91 | file_handler.setLevel(logging.INFO) 92 | file_handler.setFormatter(formatter) 93 | logger.addHandler(file_handler) 94 | 95 | stream_handler = logging.StreamHandler() 96 | stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S')) 97 | stream_handler.setLevel(0) 98 | logger.addHandler(stream_handler) 99 | return logger 100 | -------------------------------------------------------------------------------- /lib/engine/lr_policy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | # @Time : 2018/8/1 上午1:50 4 | # @Author : yuchangqian 5 | # @Contact : changqian_yu@163.com 6 | # @File : lr_policy.py.py 7 | 8 | from abc import ABCMeta, abstractmethod 9 | 10 | 11 | class BaseLR(): 12 | __metaclass__ = ABCMeta 13 | 14 | @abstractmethod 15 | def get_lr(self, cur_iter): pass 16 | 17 | 18 | class PolyLR(BaseLR): 19 | def __init__(self, start_lr, lr_power, total_iters): 20 | self.start_lr = start_lr 21 | self.lr_power = lr_power 22 | self.total_iters = total_iters + 0.0 23 | 24 | def get_lr(self, cur_iter): 25 | return self.start_lr * ( 26 | (1 - float(cur_iter) / self.total_iters) ** self.lr_power) 27 | 28 | 29 | class MultiStageLR(BaseLR): 30 | def __init__(self, lr_stages): 31 | assert type(lr_stages) in [list, tuple] and len(lr_stages[0]) == 2, \ 32 | 'lr_stages must be list or tuple, with [iters, lr] format' 33 | self._lr_stagess = lr_stages 34 | 35 | def get_lr(self, epoch): 36 | for it_lr in self._lr_stagess: 37 | if epoch < it_lr[0]: 38 | return it_lr[1] 39 | 40 | 41 | class LinearIncreaseLR(BaseLR): 42 | def __init__(self, start_lr, end_lr, warm_iters): 43 | self._start_lr = start_lr 44 | self._end_lr = end_lr 45 | self._warm_iters = warm_iters 46 | self._delta_lr = (end_lr - start_lr) / warm_iters 47 | 48 | def get_lr(self, cur_epoch): 49 | return self._start_lr + cur_epoch * self._delta_lr 50 | 51 | -------------------------------------------------------------------------------- /lib/engine/version.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | # @Time : 2018/8/3 下午2:59 4 | # @Author : yuchangqian 5 | # @Contact : changqian_yu@163.com 6 | # @File : version.py 7 | 8 | __version__ = '0.1.1' 9 | -------------------------------------------------------------------------------- /lib/layers/basic_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ConvBnRelu(nn.Module): 6 | def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1, 7 | groups=1, has_bn=True, norm_layer=nn.BatchNorm2d, bn_eps=1e-5, 8 | has_relu=True, inplace=True, has_bias=False): 9 | super(ConvBnRelu, self).__init__() 10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize, 11 | stride=stride, padding=pad, 12 | dilation=dilation, groups=groups, bias=has_bias) 13 | self.has_bn = has_bn 14 | if self.has_bn: 15 | self.bn = norm_layer(out_planes, eps=bn_eps) 16 | self.has_relu = has_relu 17 | if self.has_relu: 18 | self.relu = nn.ReLU(inplace=inplace) 19 | 20 | def forward(self, x): 21 | x = self.conv(x) 22 | if self.has_bn: 23 | x = self.bn(x) 24 | if self.has_relu: 25 | x = self.relu(x) 26 | 27 | return x 28 | 29 | class ConvBnLeakyRelu(nn.Module): 30 | def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1, 31 | groups=1, has_bn=True, 
norm_layer=nn.BatchNorm2d, bn_eps=1e-5, 32 | leaky_alpha=0.3, has_leaky_relu=True, inplace=True, has_bias=False): 33 | super(ConvBnLeakyRelu, self).__init__() 34 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize, 35 | stride=stride, padding=pad, 36 | dilation=dilation, groups=groups, bias=has_bias) 37 | self.has_bn = has_bn 38 | if self.has_bn: 39 | self.bn = norm_layer(out_planes, eps=bn_eps) 40 | self.has_leakyrelu = has_leaky_relu 41 | if self.has_leakyrelu: 42 | self.relu = nn.LeakyReLU(negative_slope=leaky_alpha, inplace=inplace) 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | if self.has_bn: 47 | x = self.bn(x) 48 | if self.has_leakyrelu: 49 | x = self.relu(x) 50 | 51 | return x 52 | 53 | class SeparableConvBnLeakyRelu(nn.Module): 54 | def __init__(self, in_channels, out_channels, 55 | kernel_size=1, stride=1, padding=0, dilation=1,has_bn=True,inplace=True, 56 | leaky_alpha=0.3, has_leaky_relu=True, norm_layer=nn.BatchNorm2d, has_bias=True): 57 | super(SeparableConvBnLeakyRelu, self).__init__() 58 | 59 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, 60 | padding, dilation, groups=in_channels, 61 | bias=has_bias) 62 | self.bn = norm_layer(in_channels) 63 | self.point_wise_cbr = ConvBnLeakyRelu(in_channels, out_channels, 1, 1, 0, 64 | has_bn=has_bn, norm_layer=norm_layer,\ 65 | leaky_alpha=leaky_alpha, 66 | has_leaky_relu=has_leaky_relu, inplace=inplace,\ 67 | has_bias=has_bias) 68 | 69 | def forward(self, x): 70 | x = self.conv1(x) 71 | x = self.bn(x) 72 | x = self.point_wise_cbr(x) 73 | return x 74 | 75 | 76 | class SeparableConvBnRelu(nn.Module): 77 | def __init__(self, in_channels, out_channels, 78 | kernel_size=1, stride=1, padding=0, dilation=1,has_bn=True,inplace=True, 79 | has_relu=True, norm_layer=nn.BatchNorm2d, has_bias=True): 80 | super(SeparableConvBnRelu, self).__init__() 81 | 82 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, 83 | padding, dilation, groups=in_channels, 84 | bias=has_bias) 85 | self.bn = norm_layer(in_channels) 86 | self.point_wise_cbr = ConvBnRelu(in_channels, out_channels, 1, 1, 0, 87 | has_bn=has_bn, norm_layer=norm_layer,\ 88 | has_relu=has_relu, inplace=inplace,\ 89 | has_bias=has_bias) 90 | 91 | def forward(self, x): 92 | x = self.conv1(x) 93 | x = self.bn(x) 94 | x = self.point_wise_cbr(x) 95 | return x 96 | 97 | class SELayer(nn.Module): 98 | def __init__(self, in_planes, out_planes, reduction=16): 99 | super(SELayer, self).__init__() 100 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 101 | self.fc = nn.Sequential( 102 | nn.Linear(in_planes, out_planes // reduction), 103 | nn.ReLU(inplace=True), 104 | nn.Linear(out_planes // reduction, out_planes), 105 | nn.Sigmoid() 106 | ) 107 | self.out_planes = out_planes 108 | 109 | def forward(self, x): 110 | b, c, _, _ = x.size() 111 | y = self.avg_pool(x).view(b, c) 112 | y = self.fc(y).view(b, self.out_planes, 1, 1) 113 | return y 114 | 115 | # For DFN 116 | class ChannelAttention(nn.Module): 117 | def __init__(self, in_planes, out_planes, reduction): 118 | super(ChannelAttention, self).__init__() 119 | self.channel_attention = SELayer(in_planes, out_planes, reduction) 120 | 121 | def forward(self, x1, x2): 122 | fm = torch.cat([x1, x2], 1) 123 | channel_attetion = self.channel_attention(fm) 124 | fm = x1 * channel_attetion + x2 125 | 126 | return fm 127 | 128 | class BNRefine(nn.Module): 129 | def __init__(self, in_planes, out_planes, ksize, has_bias=False, 130 | has_relu=False, norm_layer=nn.BatchNorm2d, bn_eps=1e-5): 131 | 
super(BNRefine, self).__init__() 132 | self.conv_bn_relu = ConvBnRelu(in_planes, out_planes, ksize, 1, 133 | ksize // 2, has_bias=has_bias, 134 | norm_layer=norm_layer, bn_eps=bn_eps) 135 | self.conv_refine = nn.Conv2d(out_planes, out_planes, kernel_size=ksize, 136 | stride=1, padding=ksize // 2, dilation=1, 137 | bias=has_bias) 138 | self.has_relu = has_relu 139 | if self.has_relu: 140 | self.relu = nn.ReLU(inplace=False) 141 | 142 | def forward(self, x): 143 | t = self.conv_bn_relu(x) 144 | t = self.conv_refine(t) 145 | if self.has_relu: 146 | return self.relu(t + x) 147 | return t + x 148 | 149 | class RefineResidual(nn.Module): 150 | def __init__(self, in_planes, out_planes, relu_layer, ksize=3, has_bias=False, 151 | has_relu=False, norm_layer=nn.BatchNorm2d, bn_eps=1e-5, leaky_alpha=0.3, inplace=True): 152 | super(RefineResidual, self).__init__() 153 | self.conv_1x1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, 154 | stride=1, padding=0, dilation=1, 155 | bias=has_bias) 156 | self.cbr = ConvBnRelu(out_planes, out_planes, ksize, 1, 157 | ksize // 2, has_bias=has_bias, 158 | norm_layer=norm_layer, bn_eps=bn_eps) 159 | self.conv_refine = nn.Conv2d(out_planes, out_planes, kernel_size=ksize, 160 | stride=1, padding=ksize // 2, dilation=1, 161 | bias=has_bias) 162 | self.has_relu = has_relu 163 | if self.has_relu: 164 | if relu_layer == 'ReLU': 165 | self.relu = nn.ReLU(inplace=inplace) 166 | elif relu_layer == 'LeakyReLU': 167 | self.relu = nn.LeakyReLU(negative_slope=leaky_alpha, inplace=inplace) 168 | 169 | def forward(self, x): 170 | x = self.conv_1x1(x) 171 | t = self.cbr(x) 172 | t = self.conv_refine(t) 173 | if self.has_relu: 174 | return self.relu(t + x) 175 | return t + x 176 | 177 | class SeparableRefineResidual(nn.Module): 178 | def __init__(self, in_planes, out_planes, relu_layer, ksize=3, has_bias=False, 179 | has_relu=False, norm_layer=nn.BatchNorm2d, bn_eps=1e-5, leaky_alpha=0.3, inplace=True): 180 | super(SeparableRefineResidual, self).__init__() 181 | self.conv_1x1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, 182 | stride=1, padding=0, dilation=1, 183 | bias=has_bias) 184 | self.cbr = SeparableConvBnRelu(out_planes, out_planes, ksize, 1, 185 | ksize // 2, has_bias=has_bias, 186 | norm_layer=norm_layer) 187 | self.conv_refine = nn.Conv2d(out_planes, out_planes, kernel_size=ksize, 188 | stride=1, padding=ksize // 2, dilation=1, 189 | bias=has_bias) 190 | self.has_relu = has_relu 191 | if self.has_relu: 192 | if relu_layer == 'ReLU': 193 | self.relu = nn.ReLU(inplace=inplace) 194 | elif relu_layer == 'LeakyReLU': 195 | self.relu = nn.LeakyReLU(negative_slope=leaky_alpha, inplace=inplace) 196 | 197 | def forward(self, x): 198 | x = self.conv_1x1(x) 199 | t = self.cbr(x) 200 | t = self.conv_refine(t) 201 | if self.has_relu: 202 | return self.relu(t + x) 203 | return t + x 204 | 205 | 206 | # For BiSeNet 207 | class AttentionRefinement(nn.Module): 208 | def __init__(self, in_planes, out_planes, 209 | norm_layer=nn.BatchNorm2d): 210 | super(AttentionRefinement, self).__init__() 211 | self.conv_3x3 = ConvBnRelu(in_planes, out_planes, 3, 1, 1, 212 | has_bn=True, norm_layer=norm_layer, 213 | has_relu=True, has_bias=False) 214 | self.channel_attention = nn.Sequential( 215 | nn.AdaptiveAvgPool2d(1), 216 | ConvBnRelu(out_planes, out_planes, 1, 1, 0, 217 | has_bn=True, norm_layer=norm_layer, 218 | has_relu=False, has_bias=False), 219 | nn.Sigmoid() 220 | ) 221 | 222 | def forward(self, x): 223 | fm = self.conv_3x3(x) 224 | fm_se = self.channel_attention(fm) 225 | fm = fm * fm_se 
226 | 227 | return fm 228 | 229 | class GlobalAvgPool2d(nn.Module): 230 | def __init__(self): 231 | """Global average pooling over the input's spatial dimensions""" 232 | super(GlobalAvgPool2d, self).__init__() 233 | 234 | def forward(self, inputs): 235 | in_size = inputs.size() 236 | inputs = inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 237 | inputs = inputs.view(in_size[0], in_size[1], 1, 1) 238 | 239 | return inputs 240 | 241 | class FeatureFusion(nn.Module): 242 | def __init__(self, in_planes, out_planes, 243 | reduction=1, norm_layer=nn.BatchNorm2d): 244 | super(FeatureFusion, self).__init__() 245 | self.conv_1x1 = ConvBnRelu(in_planes, out_planes, 1, 1, 0, 246 | has_bn=True, norm_layer=norm_layer, 247 | has_relu=True, has_bias=False) 248 | self.channel_attention = nn.Sequential( 249 | nn.AdaptiveAvgPool2d(1), 250 | ConvBnRelu(out_planes, out_planes // reduction, 1, 1, 0, 251 | has_bn=False, norm_layer=norm_layer, 252 | has_relu=True, has_bias=False), 253 | ConvBnRelu(out_planes // reduction, out_planes, 1, 1, 0, 254 | has_bn=False, norm_layer=norm_layer, 255 | has_relu=False, has_bias=False), 256 | nn.Sigmoid() 257 | ) 258 | 259 | def forward(self, x1, x2): 260 | fm = torch.cat([x1, x2], dim=1) 261 | fm = self.conv_1x1(fm) 262 | fm_se = self.channel_attention(fm) 263 | output = fm + fm * fm_se 264 | return output 265 | -------------------------------------------------------------------------------- /lib/misc/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def get_params(model): 5 | for m in model.modules(): 6 | if isinstance(m, nn.Conv2d): 7 | for p in m.parameters(): 8 | if p.requires_grad: 9 | yield p 10 | elif isinstance(m, nn.BatchNorm2d): 11 | for p in m.parameters(): 12 | if p.requires_grad: 13 | yield p 14 | -------------------------------------------------------------------------------- /lib/utils/init_func.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | # @Time : 2018/9/28 下午12:13 4 | # @Author : yuchangqian 5 | # @Contact : changqian_yu@163.com 6 | # @File : init_func.py.py 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, business_layer=False, 12 | **kwargs): 13 | for name, m in feature.named_modules(): 14 | if isinstance(m, (nn.Conv2d, nn.Conv3d)): 15 | conv_init(m.weight, **kwargs) 16 | if business_layer: 17 | nn.init.constant_(m.bias, 1) 18 | else: 19 | nn.init.constant_(m.bias, 0) 20 | elif isinstance(m, norm_layer): 21 | if m.weight is not None: 22 | m.eps = bn_eps 23 | m.momentum = bn_momentum 24 | nn.init.constant_(m.weight, 1) 25 | if m.bias is not None: 26 | nn.init.constant_(m.bias, 0) 27 | 28 | 29 | def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, 30 | **kwargs): 31 | if isinstance(module_list, list): 32 | for feature in module_list: 33 | __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, 34 | **kwargs) 35 | else: 36 | __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, 37 | **kwargs) 38 | 39 | 40 | def group_weight(weight_group, module, norm_layer, lr): 41 | group_decay = [] 42 | group_no_decay = [] 43 | for m in module.modules(): 44 | if isinstance(m, nn.Linear): 45 | #group_decay.append(m.weight) 46 | #group_decay.append(m.bias) 47 | for p in m.parameters(): 48 | yield p 49 | elif isinstance(m, nn.Conv2d): 50 | 
#group_decay.append(m.weight) 51 | #group_decay.append(m.bias) 52 | for p in m.parameters(): 53 | yield p 54 | elif isinstance(m, norm_layer) or isinstance(m, nn.GroupNorm): 55 | #if m.weight is not None: 56 | # group_decay.append(m.weight) 57 | #if m.bias is not None: 58 | # group_decay.append(m.bias) 59 | for p in m.parameters(): 60 | yield p 61 | else: 62 | #if 'weight' in dir(m) and m.weight is not None: 63 | # group_decay.append(m.weight) 64 | #if 'bias' in dir(m) and m.bias is not None: 65 | # group_decay.append(m.bias) 66 | for p in m.parameters(): 67 | yield p 68 | 69 | #assert len(list(module.parameters())) == len(group_decay) + len( 70 | # group_no_decay) 71 | #weight_group.append(dict(params=group_decay, lr=lr)) 72 | #weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr)) 73 | #return weight_group 74 | -------------------------------------------------------------------------------- /lib/utils/pyt_utils.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import os 3 | import sys 4 | import time 5 | import argparse 6 | import logging 7 | from collections import OrderedDict, defaultdict 8 | 9 | import torch 10 | import torch.utils.model_zoo as model_zoo 11 | 12 | _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO') 13 | _default_level = logging.getLevelName(_default_level_name.upper()) 14 | 15 | 16 | class LogFormatter(logging.Formatter): 17 | log_fout = None 18 | date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] ' 19 | date = '%(asctime)s ' 20 | msg = '%(message)s' 21 | 22 | def format(self, record): 23 | if record.levelno == logging.DEBUG: 24 | mcl, mtxt = self._color_dbg, 'DBG' 25 | elif record.levelno == logging.WARNING: 26 | mcl, mtxt = self._color_warn, 'WRN' 27 | elif record.levelno == logging.ERROR: 28 | mcl, mtxt = self._color_err, 'ERR' 29 | else: 30 | mcl, mtxt = self._color_normal, '' 31 | 32 | if mtxt: 33 | mtxt += ' ' 34 | 35 | if self.log_fout: 36 | self.__set_fmt(self.date_full + mtxt + self.msg) 37 | formatted = super(LogFormatter, self).format(record) 38 | # self.log_fout.write(formatted) 39 | # self.log_fout.write('\n') 40 | # self.log_fout.flush() 41 | return formatted 42 | 43 | self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg)) 44 | formatted = super(LogFormatter, self).format(record) 45 | 46 | return formatted 47 | 48 | if sys.version_info.major < 3: 49 | def __set_fmt(self, fmt): 50 | self._fmt = fmt 51 | else: 52 | def __set_fmt(self, fmt): 53 | self._style._fmt = fmt 54 | 55 | @staticmethod 56 | def _color_dbg(msg): 57 | return '\x1b[36m{}\x1b[0m'.format(msg) 58 | 59 | @staticmethod 60 | def _color_warn(msg): 61 | return '\x1b[1;31m{}\x1b[0m'.format(msg) 62 | 63 | @staticmethod 64 | def _color_err(msg): 65 | return '\x1b[1;4;31m{}\x1b[0m'.format(msg) 66 | 67 | @staticmethod 68 | def _color_omitted(msg): 69 | return '\x1b[35m{}\x1b[0m'.format(msg) 70 | 71 | @staticmethod 72 | def _color_normal(msg): 73 | return msg 74 | 75 | @staticmethod 76 | def _color_date(msg): 77 | return '\x1b[32m{}\x1b[0m'.format(msg) 78 | 79 | def get_logger(log_dir=None, log_file=None, formatter=LogFormatter): 80 | logger = logging.getLogger() 81 | logger.setLevel(_default_level) 82 | del logger.handlers[:] 83 | 84 | if log_dir and log_file: 85 | ensure_dir(log_dir) # ensure_dir is defined later in this module 86 | LogFormatter.log_fout = True 87 | file_handler = logging.FileHandler(log_file, mode='a') 88 | file_handler.setLevel(logging.INFO) 89 | file_handler.setFormatter(formatter) 90 | 
logger.addHandler(file_handler) 91 | 92 | stream_handler = logging.StreamHandler() 93 | stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S')) 94 | stream_handler.setLevel(0) 95 | logger.addHandler(stream_handler) 96 | return logger 97 | 98 | logger = get_logger() 99 | 100 | model_urls = { 101 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 102 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 103 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 104 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 105 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 106 | } 107 | 108 | 109 | def load_model(model, model_file, is_restore=False): 110 | t_start = time.time() 111 | if isinstance(model_file, str): 112 | state_dict = torch.load(model_file) 113 | if 'model' in state_dict.keys(): 114 | state_dict = state_dict['model'] 115 | else: 116 | state_dict = model_file 117 | t_ioend = time.time() 118 | 119 | if is_restore: 120 | new_state_dict = OrderedDict() 121 | for k, v in state_dict.items(): 122 | name = 'module.' + k 123 | new_state_dict[name] = v 124 | state_dict = new_state_dict 125 | 126 | model.load_state_dict(state_dict, strict=False) 127 | ckpt_keys = set(state_dict.keys()) 128 | own_keys = set(model.state_dict().keys()) 129 | missing_keys = own_keys - ckpt_keys 130 | unexpected_keys = ckpt_keys - own_keys 131 | 132 | if len(missing_keys) > 0: 133 | logger.warning('Missing key(s) in state_dict: {}'.format( 134 | ', '.join('{}'.format(k) for k in missing_keys))) 135 | 136 | if len(unexpected_keys) > 0: 137 | logger.warning('Unexpected key(s) in state_dict: {}'.format( 138 | ', '.join('{}'.format(k) for k in unexpected_keys))) 139 | 140 | del state_dict 141 | t_end = time.time() 142 | logger.info( 143 | "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format( 144 | t_ioend - t_start, t_end - t_ioend)) 145 | 146 | return model 147 | 148 | 149 | def parse_devices(input_devices): 150 | if input_devices.endswith('*'): 151 | devices = list(range(torch.cuda.device_count())) 152 | return devices 153 | 154 | devices = [] 155 | for d in input_devices.split(','): 156 | if '-' in d: 157 | start_device, end_device = d.split('-')[0], d.split('-')[1] 158 | assert start_device != '' 159 | assert end_device != '' 160 | start_device, end_device = int(start_device), int(end_device) 161 | assert start_device < end_device 162 | assert end_device < torch.cuda.device_count() 163 | for sd in range(start_device, end_device + 1): 164 | devices.append(sd) 165 | else: 166 | device = int(d) 167 | assert device < torch.cuda.device_count() 168 | devices.append(device) 169 | 170 | logger.info('using devices {}'.format( 171 | ', '.join([str(d) for d in devices]))) 172 | 173 | return devices 174 | 175 | 176 | def extant_file(x): 177 | """ 178 | 'Type' for argparse - checks that file exists but does not open. 
179 | """ 180 | if not os.path.exists(x): 181 | # Argparse uses the ArgumentTypeError to give a rejection message like: 182 | # error: argument input: x does not exist 183 | raise argparse.ArgumentTypeError("{0} does not exist".format(x)) 184 | return x 185 | 186 | 187 | def link_file(src, target): 188 | if os.path.isdir(target) or os.path.isfile(target): 189 | os.remove(target) 190 | os.system('ln -s {} {}'.format(src, target)) 191 | 192 | 193 | def ensure_dir(path): 194 | if not os.path.isdir(path): 195 | os.makedirs(path) 196 | 197 | 198 | def _dbg_interactive(var, value): 199 | from IPython import embed 200 | embed() 201 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_depth_only/config.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import time 4 | import numpy as np 5 | from easydict import EasyDict as edict 6 | import argparse 7 | 8 | C = edict() 9 | config = C 10 | cfg = C 11 | 12 | C.seed = 304 13 | 14 | """please configure ROOT_dir and user before first use""" 15 | C.abs_dir = osp.realpath(".") 16 | C.this_dir = C.abs_dir.split(osp.sep)[-1] 17 | C.label_dir = C.abs_dir.split(osp.sep)[-2] 18 | C.root_dir = C.abs_dir[:C.abs_dir.index('model')] 19 | C.log_dir = osp.abspath(osp.join(C.root_dir, 'log', C.label_dir, C.this_dir)) 20 | C.log_dir_link = osp.join(C.abs_dir, 'log') 21 | C.snapshot_dir = osp.abspath(osp.join(C.log_dir, "snapshot")) 22 | 23 | exp_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()) 24 | C.log_file = C.log_dir + '/log_' + exp_time + '.log' 25 | C.link_log_file = C.log_dir + '/log_last.log' 26 | C.val_log_file = C.log_dir + '/val_' + exp_time + '.log' 27 | C.link_val_log_file = C.log_dir + '/val_last.log' 28 | 29 | """Data Dir and Weight Dir""" 30 | C.data_source = '/home/duy/phd/Displacement_Field/dataset/nyu_depth_v2_labeled.mat' 31 | C.train_test_splits = '/home/duy/phd/Displacement_Field/dataset/nyuv2_splits.mat' 32 | C.is_test = False 33 | 34 | """Path Config""" 35 | 36 | def add_path(path): 37 | if path not in sys.path: 38 | sys.path.insert(0, path) 39 | 40 | 41 | add_path(osp.join(C.root_dir, 'lib')) 42 | 43 | """Image Config""" 44 | C.image_mean = np.array([0.485, 0.456, 0.406]) 45 | C.image_std = np.array([0.229, 0.224, 0.225]) 46 | C.use_gauss_blur = True 47 | C.gaussian_kernel_range = [.3, 1] 48 | C.max_kernel = 51 49 | C.downsampling_scale = 8 50 | C.interpolation = 'LINEAR' 51 | C.use_updown_sampling = False 52 | C.target_size = 320 53 | C.image_height = 320 54 | C.image_width = 320 55 | C.num_train_imgs = 795 56 | C.num_eval_imgs = 654 57 | 58 | """Settings for the network; these differ for each kind of model""" 59 | C.fix_bias = False 60 | C.fix_bn = False 61 | C.bn_eps = 1e-5 62 | C.bn_momentum = 0.1 63 | C.loss_weight = None 64 | C.pretrained_model = None 65 | 66 | """Train Config""" 67 | C.lr = 1e-3 68 | C.lr_power = 0.9 69 | C.momentum = 0.9 70 | C.weight_decay = 1e-6 71 | C.batch_size = 1 72 | C.nepochs = 20 73 | C.niters_per_epoch = 795 74 | C.num_workers = 4 75 | C.train_scale_array = [1., 1.5, 2, 2.5, 4] 76 | C.business_lr_ratio = 1.0 77 | C.aux_loss_ratio = 1 78 | 79 | """Display Config""" 80 | C.snapshot_iter = 5 81 | C.record_info_iter = 20 82 | C.display_iter = 50 83 | 84 | def open_tensorboard(): 85 | pass 86 | 87 | if __name__ == '__main__': 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument( 90 | '-tb', '--tensorboard', default=False, 
action='store_true') 91 | args = parser.parse_args() 92 | 93 | if args.tensorboard: 94 | open_tensorboard() 95 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_depth_only/dataloader.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils import data 6 | 7 | from config import config 8 | from datareader.img_utils import random_scale, random_mirror, normalize, \ 9 | generate_random_crop_pos, random_crop_pad_to_shape, \ 10 | random_uniform_gaussian_blur, normalize_depth, rgb2gray 11 | 12 | class TrainPre(object): 13 | def __init__(self, img_mean, img_std, target_size, use_gauss_blur=True): 14 | self.img_mean = img_mean 15 | self.img_std = img_std 16 | self.target_size = target_size 17 | self.use_gauss_blur = use_gauss_blur 18 | 19 | def __call__(self, img, gt): 20 | img, gt = random_mirror(img, gt) 21 | if config.train_scale_array is not None: 22 | img, gt, scale = random_scale(img, gt, config.train_scale_array) 23 | 24 | #img = normalize(img, self.img_mean, self.img_std) 25 | img = rgb2gray(img) 26 | img = img / 255. 27 | 28 | crop_size = (config.image_height, config.image_width) 29 | crop_pos = generate_random_crop_pos(img.shape[:2], crop_size) 30 | 31 | p_img, _ = random_crop_pad_to_shape(img, crop_pos, crop_size, 0) 32 | p_gt, _ = random_crop_pad_to_shape(gt, crop_pos, crop_size, 0) 33 | 34 | p_mask = np.zeros(p_gt.shape) 35 | p_mask[p_gt > 0] = 1 36 | p_depth = p_gt.copy() 37 | if self.use_gauss_blur: 38 | p_depth = random_uniform_gaussian_blur(p_depth, config.gaussian_kernel_range, config.max_kernel) 39 | p_gt = normalize_depth(p_gt) 40 | p_depth = normalize_depth(p_depth) 41 | 42 | p_img = np.expand_dims(p_img, axis=0) 43 | p_depth = np.expand_dims(p_depth, axis=0) 44 | extra_dict = None 45 | 46 | return p_img, p_depth, p_gt, p_mask, extra_dict 47 | 48 | 49 | def get_train_loader(engine, dataset): 50 | data_setting = {'data_source': config.data_source, 51 | 'train_test_splits': config.train_test_splits} 52 | train_preprocess = TrainPre(config.image_mean, config.image_std, 53 | config.target_size, config.use_gauss_blur 54 | ) 55 | 56 | train_dataset = dataset(data_setting, "train", train_preprocess, 57 | config.niters_per_epoch * config.batch_size) 58 | 59 | train_sampler = None 60 | is_shuffle = True 61 | batch_size = config.batch_size 62 | 63 | train_loader = data.DataLoader(train_dataset, 64 | batch_size=batch_size, 65 | num_workers=config.num_workers, 66 | shuffle=is_shuffle) 67 | 68 | return train_loader, train_sampler 69 | 70 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_depth_only/df.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from unet import Unet 5 | from config import config 6 | 7 | class Displacement_Field(nn.Module): 8 | def __init__(self): 9 | super(Displacement_Field, self).__init__() 10 | self.displacement_net = Unet(n_channels=1, rgb_channels=1, n_classes=2) 11 | self.theta = torch.Tensor([[1, 0, 0], 12 | [0, 1, 0]]) 13 | self.theta = self.theta.view(-1, 2, 3) 14 | 15 | def forward(self, x): 16 | max_disp = .9 17 | displacement_map = self.displacement_net(x) 18 | output = [] 19 | 20 | displacement_map = displacement_map / 320 21 | displacement_map = displacement_map.clamp(min=-max_disp, max=max_disp) 22 | 23 | theta = 
self.theta.repeat(x.size()[0], 1, 1) 24 | grid = F.affine_grid(theta, x.size()).cuda() 25 | grid = (grid + displacement_map.transpose(1,2).transpose(2,3)).clamp(min=-1, max=1) 26 | x = F.grid_sample(x, grid) 27 | return x 28 | 29 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_depth_only/train.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import argparse 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.modules.loss import MSELoss 10 | import torch.nn.functional as F 11 | 12 | from config import config 13 | from dataloader import get_train_loader 14 | from df import Displacement_Field 15 | from datasets.nyu import NYUDataset 16 | 17 | from utils.init_func import init_weight 18 | from misc.utils import get_params 19 | from engine.lr_policy import PolyLR 20 | from engine.engine import Engine 21 | 22 | class Mseloss(MSELoss): 23 | def __init__(self): 24 | super(Mseloss, self).__init__() 25 | 26 | def forward(self, input, target, mask=None): 27 | if mask is not None: 28 | input = input.squeeze(1) 29 | input = torch.mul(input, mask) 30 | target = torch.mul(target, mask) 31 | loss = F.mse_loss(input, target, reduction=self.reduction) 32 | 33 | return loss 34 | 35 | parser = argparse.ArgumentParser() 36 | 37 | with Engine(custom_parser=parser) as engine: 38 | args = parser.parse_args() 39 | 40 | seed = config.seed 41 | torch.manual_seed(seed) 42 | 43 | train_loader, train_sampler = get_train_loader(engine, NYUDataset) 44 | 45 | criterion = Mseloss() 46 | BatchNorm2d = nn.BatchNorm2d 47 | 48 | model = Displacement_Field() 49 | init_weight(model.displacement_net, nn.init.xavier_normal_, 50 | BatchNorm2d, config.bn_eps, config.bn_momentum) 51 | base_lr = config.lr 52 | 53 | total_iteration = config.nepochs * config.niters_per_epoch 54 | lr_policy = PolyLR(base_lr, config.lr_power, total_iteration) 55 | 56 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 57 | model.to(device) 58 | 59 | if engine.continue_state_object: 60 | engine.restore_checkpoint() 61 | model.zero_grad() 62 | model.train() 63 | 64 | optimizer = torch.optim.Adam(params=get_params(model), 65 | lr=base_lr) 66 | engine.register_state(dataloader=train_loader, model=model, 67 | optimizer=optimizer) 68 | 69 | for epoch in range(engine.state.epoch, config.nepochs): 70 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]' 71 | pbar = tqdm(range(config.niters_per_epoch), file=sys.stdout, 72 | bar_format=bar_format) 73 | dataloader = iter(train_loader) 74 | for idx in pbar: 75 | optimizer.zero_grad() 76 | engine.update_iteration(epoch, idx) 77 | minibatch = next(dataloader) # Python 3 iterators have no .next() method 78 | 79 | deps = minibatch['data'] 80 | gts = minibatch['label'] 81 | masks = minibatch['mask'] 82 | 83 | deps = deps.cuda() 84 | deps = torch.autograd.Variable(deps) 85 | gts = gts.cuda() 86 | gts = torch.autograd.Variable(gts) 87 | masks = masks.cuda() 88 | masks = torch.autograd.Variable(masks) 89 | 90 | pred = model(deps) 91 | loss = criterion(pred, gts, masks) 92 | current_idx = epoch * config.niters_per_epoch + idx 93 | lr = lr_policy.get_lr(current_idx) 94 | 95 | optimizer.param_groups[0]['lr'] = lr 96 | loss.backward() 97 | optimizer.step() 98 | print_str = 'Epoch{}/{}'.format(epoch, config.nepochs) \ 99 | + ' Iter{}/{}:'.format(idx + 1, config.niters_per_epoch) \ 100 | + ' lr=%.2e' % lr \ 101 | + ' loss=%.6f' % float(loss) 102 | 103 | 
pbar.set_description(print_str, refresh=False) 104 | 105 | if (epoch == (config.nepochs - 1)) or (epoch % config.snapshot_iter == 0): 106 | engine.save_and_link_checkpoint(config.snapshot_dir, 107 | config.log_dir, 108 | config.log_dir_link) 109 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_depth_only/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from config import config 5 | from layers.basic_module import ConvBnRelu, ConvBnLeakyRelu, SeparableConvBnLeakyRelu, SeparableConvBnRelu, \ 6 | SELayer, ChannelAttention, BNRefine, RefineResidual, AttentionRefinement, GlobalAvgPool2d, \ 7 | FeatureFusion 8 | 9 | ########################################################### 10 | 11 | class Unet(nn.Module): 12 | def __init__(self, n_channels, rgb_channels, n_classes): 13 | super(Unet, self).__init__() 14 | 15 | self.down_scale = nn.MaxPool2d(2) 16 | self.up_scale = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 17 | 18 | self.depth_down_layer0 = ConvBnLeakyRelu(n_channels, 8, 3, 1, 1, 1, 1,\ 19 | has_bn=True, leaky_alpha=0.3, \ 20 | has_leaky_relu=True, inplace=True, has_bias=True) 21 | self.depth_down_layer1 = ConvBnLeakyRelu(8, 16, 3, 1, 1, 1, 1,\ 22 | has_bn=True, leaky_alpha=0.3, \ 23 | has_leaky_relu=True, inplace=True, has_bias=True) 24 | self.depth_down_layer2 = ConvBnLeakyRelu(16, 32, 3, 1, 1, 1, 1,\ 25 | has_bn=True, leaky_alpha=0.3, \ 26 | has_leaky_relu=True, inplace=True, has_bias=True) 27 | self.depth_down_layer3 = ConvBnLeakyRelu(32, 64, 3, 1, 1, 1, 1,\ 28 | has_bn=True, leaky_alpha=0.3, \ 29 | has_leaky_relu=True, inplace=True, has_bias=True) 30 | 31 | self.depth_up_layer0 = RefineResidual(64, 64, relu_layer='LeakyReLU', \ 32 | has_bias=True, has_relu=True, leaky_alpha=0.3) 33 | self.depth_up_layer1 = RefineResidual(64, 32, relu_layer='LeakyReLU', \ 34 | has_bias=True, has_relu=True, leaky_alpha=0.3) 35 | self.depth_up_layer2 = RefineResidual(32, 16, relu_layer='LeakyReLU', \ 36 | has_bias=True, has_relu=True, leaky_alpha=0.3) 37 | self.depth_up_layer3 = RefineResidual(16, 8, relu_layer='LeakyReLU', \ 38 | has_bias=True, has_relu=True, leaky_alpha=0.3) 39 | 40 | self.depth_out_layer0 = RefineResidual(64, 1, relu_layer='LeakyReLU', \ 41 | has_bias=True, has_relu=True, leaky_alpha=0.3) 42 | self.depth_out_layer1 = RefineResidual(32, 1, relu_layer='LeakyReLU', \ 43 | has_bias=True, has_relu=True, leaky_alpha=0.3) 44 | self.depth_out_layer2 = RefineResidual(16, 1, relu_layer='LeakyReLU', \ 45 | has_bias=True, has_relu=True, leaky_alpha=0.3) 46 | self.depth_out_layer3 = RefineResidual(8, 1, relu_layer='LeakyReLU', \ 47 | has_bias=True, has_relu=True, leaky_alpha=0.3) 48 | 49 | self.refine_layer0 = ConvBnLeakyRelu(8, 4, 3, 1, 1, 1, 1,\ 50 | has_bn=True, leaky_alpha=0.3, \ 51 | has_leaky_relu=True, inplace=True, has_bias=True) 52 | self.refine_layer1 = ConvBnLeakyRelu(4, 4, 3, 1, 1, 1, 1,\ 53 | has_bn=True, leaky_alpha=0.3, \ 54 | has_leaky_relu=True, inplace=True, has_bias=True) 55 | 56 | self.output_layer = ConvBnRelu(4, 2, 3, 1, 1, 1, 1,\ 57 | has_bn=False, \ 58 | has_relu=False, inplace=True, has_bias=True) 59 | 60 | def forward(self, x): 61 | #### Depth #### 62 | x1 = self.depth_down_layer0(x) 63 | x1 = self.down_scale(x1) 64 | x1 = self.depth_down_layer1(x1) 65 | x1 = self.down_scale(x1) 66 | x2 = self.depth_down_layer2(x1) 67 | x2 = self.down_scale(x2) 68 | x = self.depth_down_layer3(x2) 69 
| x = self.down_scale(x) 70 | 71 | x = self.depth_up_layer0(x) 72 | x = self.up_scale(x) 73 | x = self.depth_up_layer1(x) 74 | x = x + x2 75 | x = self.up_scale(x) 76 | x = self.depth_up_layer2(x) 77 | x = x + x1 78 | x = self.up_scale(x) 79 | x = self.depth_up_layer3(x) 80 | x = self.up_scale(x) 81 | x = self.refine_layer0(x) 82 | x = self.refine_layer1(x) 83 | x = self.output_layer(x) 84 | return x 85 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_rgb_guidance/config.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import time 4 | import numpy as np 5 | from easydict import EasyDict as edict 6 | import argparse 7 | 8 | C = edict() 9 | config = C 10 | cfg = C 11 | 12 | C.seed = 304 13 | 14 | """please configure ROOT_dir and user before first use""" 15 | C.abs_dir = osp.realpath(".") 16 | C.this_dir = C.abs_dir.split(osp.sep)[-1] 17 | C.label_dir = C.abs_dir.split(osp.sep)[-2] 18 | C.root_dir = C.abs_dir[:C.abs_dir.index('model')] 19 | C.log_dir = osp.abspath(osp.join(C.root_dir, 'log', C.label_dir, C.this_dir)) 20 | C.log_dir_link = osp.join(C.abs_dir, 'log') 21 | C.snapshot_dir = osp.abspath(osp.join(C.log_dir, "snapshot")) 22 | 23 | exp_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()) 24 | C.log_file = C.log_dir + '/log_' + exp_time + '.log' 25 | C.link_log_file = C.log_dir + '/log_last.log' 26 | C.val_log_file = C.log_dir + '/val_' + exp_time + '.log' 27 | C.link_val_log_file = C.log_dir + '/val_last.log' 28 | 29 | """Data Dir and Weight Dir""" 30 | C.data_source = osp.join(C.root_dir, 'dataset', 'nyu_depth_v2_labeled.mat') 31 | C.train_test_splits = osp.join(C.root_dir, 'dataset', 'nyuv2_splits.mat') 32 | C.is_test = False 33 | 34 | """Path Config""" 35 | 36 | def add_path(path): 37 | if path not in sys.path: 38 | sys.path.insert(0, path) 39 | 40 | 41 | add_path(osp.join(C.root_dir, 'lib')) 42 | 43 | """Image Config""" 44 | C.image_mean = np.array([0.485, 0.456, 0.406]) # 0.485, 0.456, 0.406 45 | C.image_std = np.array([0.229, 0.224, 0.225]) 46 | C.use_gauss_blur = True 47 | C.gaussian_kernel_range = [.3, 1] 48 | C.max_kernel = 51 49 | C.downsampling_scale = 8 50 | C.interpolation = 'LINEAR' 51 | C.use_updown_sampling = False 52 | C.target_size = 320 53 | C.image_height = 320 54 | C.image_width = 320 55 | C.num_train_imgs = 795 56 | C.num_eval_imgs = 654 57 | 58 | """ Settings for network, this would be different for each kind of model""" 59 | C.fix_bias = False 60 | C.fix_bn = False 61 | C.bn_eps = 1e-5 62 | C.bn_momentum = 0.1 63 | C.loss_weight = None 64 | C.pretrained_model = None 65 | 66 | """Train Config""" 67 | C.lr = 1e-3 68 | C.lr_power = 0.9 69 | C.momentum = 0.9 70 | C.weight_decay = 1e-6 71 | C.batch_size = 1 72 | C.nepochs = 20 73 | C.niters_per_epoch = 795 74 | C.num_workers = 4 75 | C.train_scale_array = [1., 1.5, 2, 2.5, 4] 76 | C.business_lr_ratio = 1.0 77 | C.aux_loss_ratio = 1 78 | 79 | """Display Config""" 80 | C.snapshot_iter = 5 81 | C.record_info_iter = 20 82 | C.display_iter = 50 83 | 84 | def open_tensorboard(): 85 | pass 86 | 87 | if __name__ == '__main__': 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument( 90 | '-tb', '--tensorboard', default=False, action='store_true') 91 | args = parser.parse_args() 92 | 93 | if args.tensorboard: 94 | open_tensorboard() 95 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_rgb_guidance/dataloader.py: 
-------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils import data 6 | 7 | from config import config 8 | from datareader.img_utils import random_scale, random_mirror, normalize, \ 9 | generate_random_crop_pos, random_crop_pad_to_shape, \ 10 | random_uniform_gaussian_blur, normalize_depth, rgb2gray 11 | 12 | class TrainPre(object): 13 | def __init__(self, img_mean, img_std, target_size, use_gauss_blur=True): 14 | self.img_mean = img_mean 15 | self.img_std = img_std 16 | self.target_size = target_size 17 | self.use_gauss_blur = use_gauss_blur 18 | 19 | def __call__(self, img, gt): 20 | img, gt = random_mirror(img, gt) 21 | if config.train_scale_array is not None: 22 | img, gt, scale = random_scale(img, gt, config.train_scale_array) 23 | 24 | #img = normalize(img, self.img_mean, self.img_std) 25 | img = rgb2gray(img) 26 | img = img / 255. 27 | 28 | crop_size = (config.image_height, config.image_width) 29 | crop_pos = generate_random_crop_pos(img.shape[:2], crop_size) 30 | 31 | p_img, _ = random_crop_pad_to_shape(img, crop_pos, crop_size, 0) 32 | p_gt, _ = random_crop_pad_to_shape(gt, crop_pos, crop_size, 0) 33 | 34 | p_mask = np.zeros(p_gt.shape) 35 | p_mask[p_gt > 0] = 1 36 | p_depth = p_gt.copy() 37 | if self.use_gauss_blur: 38 | p_depth = random_uniform_gaussian_blur(p_depth, config.gaussian_kernel_range, config.max_kernel) 39 | p_gt = normalize_depth(p_gt) 40 | p_depth = normalize_depth(p_depth) 41 | 42 | p_img = np.expand_dims(p_img, axis=0) 43 | p_depth = np.expand_dims(p_depth, axis=0) 44 | extra_dict = None 45 | 46 | return p_img, p_depth, p_gt, p_mask, extra_dict 47 | 48 | 49 | def get_train_loader(engine, dataset): 50 | data_setting = {'data_source': config.data_source, 51 | 'train_test_splits': config.train_test_splits} 52 | train_preprocess = TrainPre(config.image_mean, config.image_std, 53 | config.target_size, config.use_gauss_blur 54 | ) 55 | 56 | train_dataset = dataset(data_setting, "train", train_preprocess, 57 | config.niters_per_epoch * config.batch_size) 58 | 59 | train_sampler = None 60 | is_shuffle = True 61 | batch_size = config.batch_size 62 | 63 | train_loader = data.DataLoader(train_dataset, 64 | batch_size=batch_size, 65 | num_workers=config.num_workers, 66 | shuffle=is_shuffle) 67 | 68 | return train_loader, train_sampler 69 | 70 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_rgb_guidance/df.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from unet import Unet 5 | from config import config 6 | 7 | class Displacement_Field(nn.Module): 8 | def __init__(self): 9 | super(Displacement_Field, self).__init__() 10 | self.displacement_net = Unet(n_channels=1, rgb_channels=1, n_classes=2) 11 | self.theta = torch.Tensor([[1, 0, 0], 12 | [0, 1, 0]]) 13 | self.theta = self.theta.view(-1, 2, 3) 14 | 15 | def forward(self, rgb, x): 16 | max_disp = .9 17 | displacement_map = self.displacement_net(rgb, x) 18 | output = [] 19 | 20 | displacement_map = displacement_map / 320 21 | displacement_map = displacement_map.clamp(min=-max_disp, max=max_disp) 22 | 23 | theta = self.theta.repeat(x.size()[0], 1, 1) 24 | grid = F.affine_grid(theta, x.size()).cuda() 25 | grid = (grid + displacement_map.transpose(1,2).transpose(2,3)).clamp(min=-1, max=1) 26 | x = F.grid_sample(x, grid) 27 | return x 28 | 29 | 
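The resampling step above is shared by all three `df.py` variants: the U-Net's two-channel output is treated as a per-pixel offset on the normalized sampling grid, and `grid_sample` re-reads the blurry depth at the shifted locations. Dividing the raw output by 320 (the crop size) converts it to normalized grid units, so a unit of network output moves a sampling point by roughly half a pixel. Below is a minimal CPU-only sketch of just this step; the random tensors stand in for the network's input and output, and the explicit `align_corners` argument is an assumption for PyTorch >= 1.3 (the original code relies on the older default).

```python
import torch
import torch.nn.functional as F

N, H, W = 1, 320, 320
depth = torch.rand(N, 1, H, W)      # stands in for the blurry input depth
raw_disp = torch.randn(N, 2, H, W)  # stands in for the Unet's 2-channel output

# Identity affine transform -> identity sampling grid of shape (N, H, W, 2)
# in normalized [-1, 1] coordinates.
theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]]).repeat(N, 1, 1)
grid = F.affine_grid(theta, depth.size(), align_corners=False)

# As in df.py: scale the offsets down by the image size, clamp them,
# shift the grid, and resample the depth at the displaced locations.
disp = (raw_disp / 320).clamp(min=-0.9, max=0.9)
grid = (grid + disp.permute(0, 2, 3, 1)).clamp(min=-1, max=1)
refined = F.grid_sample(depth, grid, align_corners=False)
print(refined.shape)  # torch.Size([1, 1, 320, 320])
```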
-------------------------------------------------------------------------------- /model/nyu/df_nyu_rgb_guidance/train.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import argparse 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.modules.loss import MSELoss 10 | import torch.nn.functional as F 11 | 12 | from config import config 13 | from dataloader import get_train_loader 14 | from df import Displacement_Field 15 | from datasets.nyu import NYUDataset 16 | 17 | from utils.init_func import init_weight 18 | from misc.utils import get_params 19 | from engine.lr_policy import PolyLR 20 | from engine.engine import Engine 21 | 22 | class Mseloss(MSELoss): 23 | def __init__(self): 24 | super(Mseloss, self).__init__() 25 | 26 | def forward(self, input, target, mask=None): 27 | if mask is not None: 28 | input = input.squeeze(1) 29 | input = torch.mul(input, mask) 30 | target = torch.mul(target, mask) 31 | loss = F.mse_loss(input, target, reduction=self.reduction) 32 | 33 | return loss 34 | 35 | parser = argparse.ArgumentParser() 36 | 37 | with Engine(custom_parser=parser) as engine: 38 | args = parser.parse_args() 39 | 40 | seed = config.seed 41 | torch.manual_seed(seed) 42 | 43 | train_loader, train_sampler = get_train_loader(engine, NYUDataset) 44 | 45 | criterion = Mseloss() 46 | BatchNorm2d = nn.BatchNorm2d 47 | 48 | model = Displacement_Field() 49 | init_weight(model.displacement_net, nn.init.xavier_normal_, 50 | BatchNorm2d, config.bn_eps, config.bn_momentum) 51 | base_lr = config.lr 52 | 53 | total_iteration = config.nepochs * config.niters_per_epoch 54 | lr_policy = PolyLR(base_lr, config.lr_power, total_iteration) 55 | 56 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 57 | model.to(device) 58 | 59 | if engine.continue_state_object: 60 | engine.restore_checkpoint() 61 | model.zero_grad() 62 | model.train() 63 | 64 | optimizer = torch.optim.Adam(params=get_params(model), 65 | lr=base_lr) 66 | engine.register_state(dataloader=train_loader, model=model, 67 | optimizer=optimizer) 68 | 69 | for epoch in range(engine.state.epoch, config.nepochs): 70 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]' 71 | pbar = tqdm(range(config.niters_per_epoch), file=sys.stdout, 72 | bar_format=bar_format) 73 | dataloader = iter(train_loader) 74 | for idx in pbar: 75 | optimizer.zero_grad() 76 | engine.update_iteration(epoch, idx) 77 | minibatch = next(dataloader) # Python 3 iterators have no .next() method 78 | 79 | imgs = minibatch['guidance'] 80 | deps = minibatch['data'] 81 | gts = minibatch['label'] 82 | masks = minibatch['mask'] 83 | 84 | imgs = imgs.cuda() 85 | imgs = torch.autograd.Variable(imgs) 86 | deps = deps.cuda() 87 | deps = torch.autograd.Variable(deps) 88 | gts = gts.cuda() 89 | gts = torch.autograd.Variable(gts) 90 | masks = masks.cuda() 91 | masks = torch.autograd.Variable(masks) 92 | 93 | pred = model(imgs, deps) 94 | loss = criterion(pred, gts, masks) 95 | current_idx = epoch * config.niters_per_epoch + idx 96 | lr = lr_policy.get_lr(current_idx) 97 | 98 | optimizer.param_groups[0]['lr'] = lr 99 | loss.backward() 100 | optimizer.step() 101 | print_str = 'Epoch{}/{}'.format(epoch, config.nepochs) \ 102 | + ' Iter{}/{}:'.format(idx + 1, config.niters_per_epoch) \ 103 | + ' lr=%.2e' % lr \ 104 | + ' loss=%.6f' % float(loss) 105 | 106 | pbar.set_description(print_str, refresh=False) 107 | 108 | if (epoch == (config.nepochs - 1)) or (epoch % 
config.snapshot_iter == 0): 109 | engine.save_and_link_checkpoint(config.snapshot_dir, 110 | config.log_dir, 111 | config.log_dir_link) 112 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_rgb_guidance/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from config import config 5 | from layers.basic_module import ConvBnRelu, ConvBnLeakyRelu, SeparableConvBnLeakyRelu, SeparableConvBnRelu, \ 6 | SELayer, ChannelAttention, BNRefine, RefineResidual, AttentionRefinement, GlobalAvgPool2d, \ 7 | FeatureFusion 8 | 9 | ########################################################### 10 | 11 | class Unet(nn.Module): 12 | def __init__(self, n_channels, rgb_channels, n_classes): 13 | super(Unet, self).__init__() 14 | 15 | self.down_scale = nn.MaxPool2d(2) 16 | self.up_scale = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 17 | 18 | self.depth_down_layer0 = ConvBnLeakyRelu(n_channels, 8, 3, 1, 1, 1, 1,\ 19 | has_bn=True, leaky_alpha=0.3, \ 20 | has_leaky_relu=True, inplace=True, has_bias=True) 21 | self.depth_down_layer1 = ConvBnLeakyRelu(8, 16, 3, 1, 1, 1, 1,\ 22 | has_bn=True, leaky_alpha=0.3, \ 23 | has_leaky_relu=True, inplace=True, has_bias=True) 24 | self.depth_down_layer2 = ConvBnLeakyRelu(16, 32, 3, 1, 1, 1, 1,\ 25 | has_bn=True, leaky_alpha=0.3, \ 26 | has_leaky_relu=True, inplace=True, has_bias=True) 27 | self.depth_down_layer3 = ConvBnLeakyRelu(32, 64, 3, 1, 1, 1, 1,\ 28 | has_bn=True, leaky_alpha=0.3, \ 29 | has_leaky_relu=True, inplace=True, has_bias=True) 30 | 31 | self.depth_up_layer0 = RefineResidual(64, 64, relu_layer='LeakyReLU', \ 32 | has_bias=True, has_relu=True, leaky_alpha=0.3) 33 | self.depth_up_layer1 = RefineResidual(64, 32, relu_layer='LeakyReLU', \ 34 | has_bias=True, has_relu=True, leaky_alpha=0.3) 35 | self.depth_up_layer2 = RefineResidual(32, 16, relu_layer='LeakyReLU', \ 36 | has_bias=True, has_relu=True, leaky_alpha=0.3) 37 | self.depth_up_layer3 = RefineResidual(16, 8, relu_layer='LeakyReLU', \ 38 | has_bias=True, has_relu=True, leaky_alpha=0.3) 39 | 40 | self.depth_out_layer0 = RefineResidual(64, 1, relu_layer='LeakyReLU', \ 41 | has_bias=True, has_relu=True, leaky_alpha=0.3) 42 | self.depth_out_layer1 = RefineResidual(32, 1, relu_layer='LeakyReLU', \ 43 | has_bias=True, has_relu=True, leaky_alpha=0.3) 44 | self.depth_out_layer2 = RefineResidual(16, 1, relu_layer='LeakyReLU', \ 45 | has_bias=True, has_relu=True, leaky_alpha=0.3) 46 | self.depth_out_layer3 = RefineResidual(8, 1, relu_layer='LeakyReLU', \ 47 | has_bias=True, has_relu=True, leaky_alpha=0.3) 48 | 49 | self.refine_layer0 = ConvBnLeakyRelu(8, 4, 3, 1, 1, 1, 1,\ 50 | has_bn=True, leaky_alpha=0.3, \ 51 | has_leaky_relu=True, inplace=True, has_bias=True) 52 | self.refine_layer1 = ConvBnLeakyRelu(4, 4, 3, 1, 1, 1, 1,\ 53 | has_bn=True, leaky_alpha=0.3, \ 54 | has_leaky_relu=True, inplace=True, has_bias=True) 55 | 56 | self.rgb_down_layer0 = ConvBnRelu(rgb_channels, 8, 3, 1, 1, 1, 1,\ 57 | has_bn=True, \ 58 | inplace=True, has_bias=True) 59 | self.rgb_down_layer1 = ConvBnRelu(8, 16, 3, 1, 1, 1, 1,\ 60 | has_bn=True, \ 61 | inplace=True, has_bias=True) 62 | self.rgb_down_layer2 = ConvBnRelu(16, 32, 3, 1, 1, 1, 1,\ 63 | has_bn=True, \ 64 | inplace=True, has_bias=True) 65 | self.rgb_down_layer3 = ConvBnRelu(32, 64, 3, 1, 1, 1, 1,\ 66 | has_bn=True, \ 67 | inplace=True, has_bias=True) 68 | 69 | self.rgb_refine_layer0 = 
RefineResidual(8, 8, relu_layer='ReLU', \ 70 | has_bias=True, has_relu=True) 71 | self.rgb_refine_layer1 = RefineResidual(16, 16, relu_layer='ReLU', \ 72 | has_bias=True, has_relu=True) 73 | self.rgb_refine_layer2 = RefineResidual(32, 32, relu_layer='ReLU', \ 74 | has_bias=True, has_relu=True) 75 | self.rgb_refine_layer3 = RefineResidual(64, 64, relu_layer='ReLU', \ 76 | has_bias=True, has_relu=True) 77 | 78 | self.output_layer = ConvBnRelu(4, 2, 3, 1, 1, 1, 1,\ 79 | has_bn=False, \ 80 | has_relu=False, inplace=True, has_bias=True) 81 | 82 | def forward(self, rgb, x): 83 | output = [] 84 | #### RGB #### 85 | r1 = self.rgb_down_layer0(rgb) 86 | r1 = self.down_scale(r1) 87 | r2 = self.rgb_down_layer1(r1) 88 | r2 = self.down_scale(r2) 89 | r3 = self.rgb_down_layer2(r2) 90 | r3 = self.down_scale(r3) 91 | r4 = self.rgb_down_layer3(r3) 92 | r4 = self.down_scale(r4) 93 | r1 = self.rgb_refine_layer0(r1) 94 | r2 = self.rgb_refine_layer1(r2) 95 | r3 = self.rgb_refine_layer2(r3) 96 | r4 = self.rgb_refine_layer3(r4) 97 | #### Depth #### 98 | x1 = self.depth_down_layer0(x) 99 | x1 = self.down_scale(x1) 100 | x1 = self.depth_down_layer1(x1) 101 | x1 = self.down_scale(x1) 102 | x2 = self.depth_down_layer2(x1) 103 | x2 = self.down_scale(x2) 104 | x = self.depth_down_layer3(x2) 105 | x = self.down_scale(x) 106 | x = self.depth_up_layer0(x) 107 | #out = self.depth_out_layer0(x) 108 | x = x + r4 109 | x = self.up_scale(x) 110 | x = self.depth_up_layer1(x) 111 | #out = self.depth_out_layer1(x) 112 | x = x + x2 + r3 113 | x = self.up_scale(x) 114 | x = self.depth_up_layer2(x) 115 | #out = self.depth_out_layer2(x) 116 | x = x + x1 + r2 117 | x = self.up_scale(x) 118 | x = self.depth_up_layer3(x) 119 | out = self.depth_out_layer3(x) 120 | x = x + r1 121 | x = self.up_scale(x) 122 | x = self.refine_layer0(x) 123 | x = self.refine_layer1(x) 124 | x = self.output_layer(x) 125 | return x 126 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_rgb_guidance_pos_encoding_attention_loss/config.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import time 4 | import numpy as np 5 | from easydict import EasyDict as edict 6 | import argparse 7 | 8 | C = edict() 9 | config = C 10 | cfg = C 11 | 12 | C.seed = 304 13 | 14 | """please configure ROOT_dir and user before first use""" 15 | C.abs_dir = osp.realpath(".") 16 | C.this_dir = C.abs_dir.split(osp.sep)[-1] 17 | C.label_dir = C.abs_dir.split(osp.sep)[-2] 18 | C.root_dir = C.abs_dir[:C.abs_dir.index('model')] 19 | C.log_dir = osp.abspath(osp.join(C.root_dir, 'log', C.label_dir, C.this_dir)) 20 | C.log_dir_link = osp.join(C.abs_dir, 'log') 21 | C.snapshot_dir = osp.abspath(osp.join(C.log_dir, "snapshot")) 22 | 23 | exp_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime()) 24 | C.log_file = C.log_dir + '/log_' + exp_time + '.log' 25 | C.link_log_file = C.log_dir + '/log_last.log' 26 | C.val_log_file = C.log_dir + '/val_' + exp_time + '.log' 27 | C.link_val_log_file = C.log_dir + '/val_last.log' 28 | 29 | """Data Dir and Weight Dir""" 30 | C.data_source = osp.join(C.root_dir, 'dataset', 'nyu_depth_v2_labeled.mat') 31 | C.train_test_splits = osp.join(C.root_dir, 'dataset', 'nyuv2_splits.mat') 32 | C.is_test = False 33 | 34 | """Path Config""" 35 | 36 | def add_path(path): 37 | if path not in sys.path: 38 | sys.path.insert(0, path) 39 | 40 | 41 | add_path(osp.join(C.root_dir, 'lib')) 42 | 43 | """Image Config""" 44 | C.image_mean = 
np.array([0.485, 0.456, 0.406]) # 0.485, 0.456, 0.406 45 | C.image_std = np.array([0.229, 0.224, 0.225]) 46 | C.use_gauss_blur = True 47 | C.gaussian_kernel_range = [.3, 1] 48 | C.max_kernel = 51 49 | C.downsampling_scale = 8 50 | C.interpolation = 'LINEAR' 51 | C.use_updown_sampling = False 52 | C.target_size = 320 53 | C.image_height = 320 54 | C.image_width = 320 55 | C.num_train_imgs = 795 56 | C.num_eval_imgs = 654 57 | 58 | """ Settings for network, this would be different for each kind of model""" 59 | C.fix_bias = False 60 | C.fix_bn = False 61 | C.bn_eps = 1e-5 62 | C.bn_momentum = 0.1 63 | C.loss_weight = None 64 | C.pretrained_model = None 65 | 66 | """Train Config""" 67 | C.lr = 1e-3 68 | C.lr_power = 0.9 69 | C.momentum = 0.9 70 | C.weight_decay = 1e-6 71 | C.batch_size = 1 72 | C.nepochs = 20 73 | C.niters_per_epoch = 795 74 | C.num_workers = 4 75 | C.train_scale_array = [1., 1.5, 2, 2.5, 4] 76 | C.business_lr_ratio = 1.0 77 | C.aux_loss_ratio = 1 78 | 79 | """Display Config""" 80 | C.snapshot_iter = 5 81 | C.record_info_iter = 20 82 | C.display_iter = 50 83 | 84 | def open_tensorboard(): 85 | pass 86 | 87 | if __name__ == '__main__': 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument( 90 | '-tb', '--tensorboard', default=False, action='store_true') 91 | args = parser.parse_args() 92 | 93 | if args.tensorboard: 94 | open_tensorboard() 95 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_rgb_guidance_pos_encoding_attention_loss/dataloader.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | from torch.utils import data 7 | 8 | from config import config 9 | from datareader.img_utils import random_scale, random_mirror, normalize, \ 10 | generate_random_crop_pos, random_crop_pad_to_shape, \ 11 | random_uniform_gaussian_blur, normalize_depth, rgb2gray, \ 12 | generate_mask_by_shifting 13 | 14 | class TrainPre(object): 15 | def __init__(self, img_mean, img_std, target_size, use_gauss_blur=True): 16 | self.img_mean = img_mean 17 | self.img_std = img_std 18 | self.target_size = target_size 19 | self.use_gauss_blur = use_gauss_blur 20 | 21 | def __call__(self, img, gt): 22 | img, gt = random_mirror(img, gt) 23 | if config.train_scale_array is not None: 24 | img, gt, scale = random_scale(img, gt, config.train_scale_array) 25 | 26 | #img = normalize(img, self.img_mean, self.img_std) 27 | img = rgb2gray(img) 28 | img = img / 255. 
29 | 30 | crop_size = (config.image_height, config.image_width) 31 | crop_pos = generate_random_crop_pos(img.shape[:2], crop_size) 32 | 33 | p_img, _ = random_crop_pad_to_shape(img, crop_pos, crop_size, 0) 34 | p_gt, _ = random_crop_pad_to_shape(gt, crop_pos, crop_size, 0) 35 | 36 | p_mask = generate_mask_by_shifting(p_gt, scale=1, kernel=10, step_size=1, delta=5) 37 | p_depth = p_gt.copy() 38 | if self.use_gauss_blur: 39 | p_depth = random_uniform_gaussian_blur(p_depth, config.gaussian_kernel_range, config.max_kernel) 40 | p_gt = normalize_depth(p_gt) 41 | p_depth = normalize_depth(p_depth) 42 | 43 | p_img = np.expand_dims(p_img, axis=0) 44 | p_depth = np.expand_dims(p_depth, axis=0) 45 | extra_dict = None 46 | 47 | return p_img, p_depth, p_gt, p_mask, extra_dict 48 | 49 | 50 | def get_train_loader(engine, dataset): 51 | data_setting = {'data_source': config.data_source, 52 | 'train_test_splits': config.train_test_splits} 53 | train_preprocess = TrainPre(config.image_mean, config.image_std, 54 | config.target_size, config.use_gauss_blur 55 | ) 56 | 57 | train_dataset = dataset(data_setting, "train", train_preprocess, 58 | config.niters_per_epoch * config.batch_size) 59 | 60 | train_sampler = None 61 | is_shuffle = True 62 | batch_size = config.batch_size 63 | 64 | train_loader = data.DataLoader(train_dataset, 65 | batch_size=batch_size, 66 | num_workers=config.num_workers, 67 | shuffle=is_shuffle) 68 | 69 | return train_loader, train_sampler 70 | 71 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_rgb_guidance_pos_encoding_attention_loss/df.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from unet import Unet 5 | from config import config 6 | 7 | class Displacement_Field(nn.Module): 8 | def __init__(self): 9 | super(Displacement_Field, self).__init__() 10 | self.displacement_net = Unet(n_channels=1, rgb_channels=1, n_classes=2) 11 | self.theta = torch.Tensor([[1, 0, 0], 12 | [0, 1, 0]]) 13 | self.theta = self.theta.view(-1, 2, 3) 14 | 15 | def forward(self, rgb, x): 16 | max_disp = .9 17 | displacement_map = self.displacement_net(rgb, x) 18 | output = [] 19 | 20 | displacement_map = displacement_map / 320 21 | displacement_map = displacement_map.clamp(min=-max_disp, max=max_disp) 22 | 23 | theta = self.theta.repeat(x.size()[0], 1, 1) 24 | grid = F.affine_grid(theta, x.size()).cuda() 25 | grid = (grid + displacement_map.transpose(1,2).transpose(2,3)).clamp(min=-1, max=1) 26 | x = F.grid_sample(x, grid) 27 | return x 28 | 29 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_rgb_guidance_pos_encoding_attention_loss/train.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import argparse 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.modules.loss import MSELoss 10 | import torch.nn.functional as F 11 | 12 | from config import config 13 | from dataloader import get_train_loader 14 | from df import Displacement_Field 15 | from datasets.nyu import NYUDataset 16 | 17 | from utils.init_func import init_weight 18 | from misc.utils import get_params 19 | from engine.lr_policy import PolyLR 20 | from engine.engine import Engine 21 | 22 | class Mseloss(MSELoss): 23 | def __init__(self): 24 | super(Mseloss, self).__init__() 25 | 26 | def 
forward(self, input, target, mask=None): 27 | if mask is not None: 28 | input = input.squeeze(1) 29 | input = torch.mul(input, mask) 30 | target = torch.mul(target, mask) 31 | loss = F.mse_loss(input, target, reduction=self.reduction) 32 | 33 | return loss 34 | 35 | parser = argparse.ArgumentParser() 36 | 37 | with Engine(custom_parser=parser) as engine: 38 | args = parser.parse_args() 39 | 40 | seed = config.seed 41 | torch.manual_seed(seed) 42 | 43 | train_loader, train_sampler = get_train_loader(engine, NYUDataset) 44 | 45 | criterion = Mseloss() 46 | BatchNorm2d = nn.BatchNorm2d 47 | 48 | model = Displacement_Field() 49 | init_weight(model.displacement_net, nn.init.xavier_normal_, 50 | BatchNorm2d, config.bn_eps, config.bn_momentum) 51 | base_lr = config.lr 52 | 53 | total_iteration = config.nepochs * config.niters_per_epoch 54 | lr_policy = PolyLR(base_lr, config.lr_power, total_iteration) 55 | 56 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 57 | model.to(device) 58 | 59 | if engine.continue_state_object: 60 | engine.restore_checkpoint() 61 | model.zero_grad() 62 | model.train() 63 | 64 | optimizer = torch.optim.Adam(params=get_params(model), 65 | lr=base_lr) 66 | engine.register_state(dataloader=train_loader, model=model, 67 | optimizer=optimizer) 68 | 69 | for epoch in range(engine.state.epoch, config.nepochs): 70 | bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]' 71 | pbar = tqdm(range(config.niters_per_epoch), file=sys.stdout, 72 | bar_format=bar_format) 73 | dataloader = iter(train_loader) 74 | for idx in pbar: 75 | optimizer.zero_grad() 76 | engine.update_iteration(epoch, idx) 77 | minibatch = next(dataloader) # Python 3 iterators have no .next() method 78 | 79 | imgs = minibatch['guidance'] 80 | deps = minibatch['data'] 81 | gts = minibatch['label'] 82 | masks = minibatch['mask'] 83 | 84 | imgs = imgs.cuda() 85 | imgs = torch.autograd.Variable(imgs) 86 | deps = deps.cuda() 87 | deps = torch.autograd.Variable(deps) 88 | gts = gts.cuda() 89 | gts = torch.autograd.Variable(gts) 90 | masks = masks.cuda() 91 | masks = torch.autograd.Variable(masks) 92 | 93 | pred = model(imgs, deps) 94 | loss = criterion(pred, gts, masks) 95 | current_idx = epoch * config.niters_per_epoch + idx 96 | lr = lr_policy.get_lr(current_idx) 97 | 98 | optimizer.param_groups[0]['lr'] = lr 99 | loss.backward() 100 | optimizer.step() 101 | print_str = 'Epoch{}/{}'.format(epoch, config.nepochs) \ 102 | + ' Iter{}/{}:'.format(idx + 1, config.niters_per_epoch) \ 103 | + ' lr=%.2e' % lr \ 104 | + ' loss=%.6f' % float(loss) 105 | 106 | pbar.set_description(print_str, refresh=False) 107 | 108 | if (epoch == (config.nepochs - 1)) or (epoch % config.snapshot_iter == 0): 109 | engine.save_and_link_checkpoint(config.snapshot_dir, 110 | config.log_dir, 111 | config.log_dir_link) 112 | -------------------------------------------------------------------------------- /model/nyu/df_nyu_rgb_guidance_pos_encoding_attention_loss/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from config import config 7 | from layers.basic_module import ConvBnRelu, ConvBnLeakyRelu, SeparableConvBnLeakyRelu, SeparableConvBnRelu, \ 8 | SELayer, ChannelAttention, BNRefine, RefineResidual, AttentionRefinement, GlobalAvgPool2d, \ 9 | FeatureFusion 10 | 11 | def positionalencoding2d(d_model, height, width): 12 | if d_model % 4 != 0: 13 | raise ValueError("Cannot use sin/cos positional encoding with " 14 | "a 
dimension not divisible by 4 (got dim={:d})".format(d_model)) 15 | pe = torch.zeros(d_model, height, width) 16 | # Each spatial dimension uses half of d_model 17 | d_model = int(d_model / 2) 18 | div_term = torch.exp(torch.arange(0., d_model, 2) * 19 | -(math.log(10000.0) / d_model)) 20 | pos_w = torch.arange(0., width).unsqueeze(1) 21 | pos_h = torch.arange(0., height).unsqueeze(1) 22 | pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 23 | pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 24 | pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 25 | pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 26 | 27 | return pe.cuda() 28 | 29 | def addpositionalembed(tensor): 30 | n,c,h,w = tensor.size() 31 | pos_embed = positionalencoding2d(c,h,w) 32 | pos_embed = pos_embed.unsqueeze(0).repeat(n,1,1,1) 33 | tensor = tensor + pos_embed # out-of-place add, safer for autograd than += 34 | return tensor 35 | 36 | def concatpositionalembed(tensor): 37 | n,c,h,w = tensor.size() 38 | pos_embed = positionalencoding2d(c,h,w) 39 | pos_embed = pos_embed.unsqueeze(0).repeat(n,1,1,1) 40 | tensor = torch.cat((tensor, pos_embed), 1) 41 | return tensor 42 | 43 | def multipositionalembed(tensor): 44 | n,c,h,w = tensor.size() 45 | pos_embed = positionalencoding2d(c,h,w) 46 | pos_embed = pos_embed.unsqueeze(0).repeat(n,1,1,1) 47 | tensor = tensor * pos_embed # out-of-place multiply, safer for autograd than *= 48 | return tensor 49 | 50 | ########################################################### 51 | 52 | class Unet(nn.Module): 53 | def __init__(self, n_channels, rgb_channels, n_classes): 54 | super(Unet, self).__init__() 55 | 56 | self.positionalembed = addpositionalembed 57 | 58 | self.down_scale = nn.MaxPool2d(2) 59 | self.up_scale = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 60 | 61 | self.depth_down_layer0 = ConvBnLeakyRelu(n_channels, 8, 3, 1, 1, 1, 1,\ 62 | has_bn=True, leaky_alpha=0.3, \ 63 | has_leaky_relu=True, inplace=True, has_bias=True) 64 | self.depth_down_layer1 = ConvBnLeakyRelu(8, 16, 3, 1, 1, 1, 1,\ 65 | has_bn=True, leaky_alpha=0.3, \ 66 | has_leaky_relu=True, inplace=True, has_bias=True) 67 | self.depth_down_layer2 = ConvBnLeakyRelu(16, 32, 3, 1, 1, 1, 1,\ 68 | has_bn=True, leaky_alpha=0.3, \ 69 | has_leaky_relu=True, inplace=True, has_bias=True) 70 | self.depth_down_layer3 = ConvBnLeakyRelu(32, 64, 3, 1, 1, 1, 1,\ 71 | has_bn=True, leaky_alpha=0.3, \ 72 | has_leaky_relu=True, inplace=True, has_bias=True) 73 | 74 | self.depth_down_embed0 = ConvBnLeakyRelu(16, 8, 3, 1, 1, 1, 1,\ 75 | has_bn=True, leaky_alpha=0.3, \ 76 | has_leaky_relu=True, inplace=True, has_bias=True) 77 | self.depth_down_embed1 = ConvBnLeakyRelu(32, 16, 3, 1, 1, 1, 1,\ 78 | has_bn=True, leaky_alpha=0.3, \ 79 | has_leaky_relu=True, inplace=True, has_bias=True) 80 | self.depth_down_embed2 = ConvBnLeakyRelu(64, 32, 3, 1, 1, 1, 1,\ 81 | has_bn=True, leaky_alpha=0.3, \ 82 | has_leaky_relu=True, inplace=True, has_bias=True) 83 | self.depth_down_embed3 = ConvBnLeakyRelu(128, 64, 3, 1, 1, 1, 1,\ 84 | has_bn=True, leaky_alpha=0.3, \ 85 | has_leaky_relu=True, inplace=True, has_bias=True) 86 | 87 | self.depth_up_layer0 = RefineResidual(64, 64, relu_layer='LeakyReLU', \ 88 | has_bias=True, has_relu=True, leaky_alpha=0.3) 89 | self.depth_up_layer1 = RefineResidual(64, 32, relu_layer='LeakyReLU', \ 90 | has_bias=True, has_relu=True, leaky_alpha=0.3) 91 | self.depth_up_layer2 = RefineResidual(32, 16, relu_layer='LeakyReLU', \ 92 | has_bias=True, 
has_relu=True, leaky_alpha=0.3) 93 | self.depth_up_layer3 = RefineResidual(16, 8, relu_layer='LeakyReLU', \ 94 | has_bias=True, has_relu=True, leaky_alpha=0.3) 95 | 96 | self.depth_up_embed0 = RefineResidual(128, 64, relu_layer='LeakyReLU', \ 97 | has_bias=True, has_relu=True, leaky_alpha=0.3) 98 | self.depth_up_embed1 = RefineResidual(128, 64, relu_layer='LeakyReLU', \ 99 | has_bias=True, has_relu=True, leaky_alpha=0.3) 100 | self.depth_up_embed2 = RefineResidual(64, 32, relu_layer='LeakyReLU', \ 101 | has_bias=True, has_relu=True, leaky_alpha=0.3) 102 | self.depth_up_embed3 = RefineResidual(32, 16, relu_layer='LeakyReLU', \ 103 | has_bias=True, has_relu=True, leaky_alpha=0.3) 104 | 105 | 106 | self.depth_out_layer0 = RefineResidual(64, 1, relu_layer='LeakyReLU', \ 107 | has_bias=True, has_relu=True, leaky_alpha=0.3) 108 | self.depth_out_layer1 = RefineResidual(32, 1, relu_layer='LeakyReLU', \ 109 | has_bias=True, has_relu=True, leaky_alpha=0.3) 110 | self.depth_out_layer2 = RefineResidual(16, 1, relu_layer='LeakyReLU', \ 111 | has_bias=True, has_relu=True, leaky_alpha=0.3) 112 | self.depth_out_layer3 = RefineResidual(8, 1, relu_layer='LeakyReLU', \ 113 | has_bias=True, has_relu=True, leaky_alpha=0.3) 114 | 115 | self.refine_layer0 = ConvBnLeakyRelu(8, 4, 3, 1, 1, 1, 1,\ 116 | has_bn=True, leaky_alpha=0.3, \ 117 | has_leaky_relu=True, inplace=True, has_bias=True) 118 | self.refine_layer1 = ConvBnLeakyRelu(4, 4, 3, 1, 1, 1, 1,\ 119 | has_bn=True, leaky_alpha=0.3, \ 120 | has_leaky_relu=True, inplace=True, has_bias=True) 121 | 122 | self.rgb_down_layer0 = ConvBnRelu(rgb_channels, 8, 3, 1, 1, 1, 1,\ 123 | has_bn=True, \ 124 | inplace=True, has_bias=True) 125 | self.rgb_down_layer1 = ConvBnRelu(8, 16, 3, 1, 1, 1, 1,\ 126 | has_bn=True, \ 127 | inplace=True, has_bias=True) 128 | self.rgb_down_layer2 = ConvBnRelu(16, 32, 3, 1, 1, 1, 1,\ 129 | has_bn=True, \ 130 | inplace=True, has_bias=True) 131 | self.rgb_down_layer3 = ConvBnRelu(32, 64, 3, 1, 1, 1, 1,\ 132 | has_bn=True, \ 133 | inplace=True, has_bias=True) 134 | 135 | self.rgb_down_embed0 = ConvBnRelu(16, 8, 3, 1, 1, 1, 1,\ 136 | has_bn=True, \ 137 | inplace=True, has_bias=True) 138 | self.rgb_down_embed1 = ConvBnRelu(32, 16, 3, 1, 1, 1, 1,\ 139 | has_bn=True, \ 140 | inplace=True, has_bias=True) 141 | self.rgb_down_embed2 = ConvBnRelu(64, 32, 3, 1, 1, 1, 1,\ 142 | has_bn=True, \ 143 | inplace=True, has_bias=True) 144 | self.rgb_down_embed3 = ConvBnRelu(128, 64, 3, 1, 1, 1, 1,\ 145 | has_bn=True, \ 146 | inplace=True, has_bias=True) 147 | 148 | self.rgb_refine_layer0 = RefineResidual(8, 8, relu_layer='ReLU', \ 149 | has_bias=True, has_relu=True) 150 | self.rgb_refine_layer1 = RefineResidual(16, 16, relu_layer='ReLU', \ 151 | has_bias=True, has_relu=True) 152 | self.rgb_refine_layer2 = RefineResidual(32, 32, relu_layer='ReLU', \ 153 | has_bias=True, has_relu=True) 154 | self.rgb_refine_layer3 = RefineResidual(64, 64, relu_layer='ReLU', \ 155 | has_bias=True, has_relu=True) 156 | 157 | self.output_layer = ConvBnRelu(4, 2, 3, 1, 1, 1, 1,\ 158 | has_bn=False, \ 159 | has_relu=False, inplace=True, has_bias=True) 160 | 161 | def forward(self, rgb, x): 162 | output = [] 163 | #### RGB #### 164 | r1 = self.rgb_down_layer0(rgb) 165 | r1 = self.down_scale(r1) 166 | r1 = self.positionalembed(r1) 167 | #r1 = self.rgb_down_embed0(r1) 168 | 169 | r2 = self.rgb_down_layer1(r1) 170 | r2 = self.down_scale(r2) 171 | r2 = self.positionalembed(r2) 172 | #r2 = self.rgb_down_embed1(r2) 173 | 174 | r3 = self.rgb_down_layer2(r2) 175 | r3 = self.down_scale(r3) 176 | 
r3 = self.positionalembed(r3) 177 | #r3 = self.rgb_down_embed2(r3) 178 | 179 | r4 = self.rgb_down_layer3(r3) 180 | r4 = self.down_scale(r4) 181 | r4 = self.positionalembed(r4) 182 | #r4 = self.rgb_down_embed3(r4) 183 | 184 | r1 = self.rgb_refine_layer0(r1) 185 | r2 = self.rgb_refine_layer1(r2) 186 | r3 = self.rgb_refine_layer2(r3) 187 | r4 = self.rgb_refine_layer3(r4) 188 | #### Depth #### 189 | x1 = self.depth_down_layer0(x) 190 | x1 = self.down_scale(x1) 191 | #x1 = self.positionalembed(x1) 192 | #x1 = self.depth_down_embed0(x1) 193 | 194 | x1 = self.depth_down_layer1(x1) 195 | x1 = self.down_scale(x1) 196 | #x1 = self.positionalembed(x1) 197 | #x1 = self.depth_down_embed1(x1) 198 | 199 | x2 = self.depth_down_layer2(x1) 200 | x2 = self.down_scale(x2) 201 | #x2 = self.positionalembed(x2) 202 | #x2 = self.depth_down_embed2(x2) 203 | 204 | x = self.depth_down_layer3(x2) 205 | x = self.down_scale(x) 206 | #x = self.positionalembed(x) 207 | #x = self.depth_down_embed3(x) 208 | 209 | #x = self.depth_up_embed0(x) 210 | x = self.depth_up_layer0(x) 211 | #out = self.depth_out_layer0(x) 212 | x = x + r4 213 | x = self.up_scale(x) 214 | 215 | #x = self.positionalembed(x) 216 | #x = self.depth_up_embed1(x) 217 | x = self.depth_up_layer1(x) 218 | #out = self.depth_out_layer1(x) 219 | x = x + x2 + r3 220 | x = self.up_scale(x) 221 | 222 | #x = self.positionalembed(x) 223 | #x = self.depth_up_embed2(x) 224 | x = self.depth_up_layer2(x) 225 | #out = self.depth_out_layer2(x) 226 | x = x + x1 + r2 227 | x = self.up_scale(x) 228 | 229 | #x = self.positionalembed(x) 230 | #x = self.depth_up_embed3(x) 231 | x = self.depth_up_layer3(x) 232 | x = x + r1 233 | x = self.positionalembed(x) 234 | 235 | x = self.up_scale(x) 236 | x = self.refine_layer0(x) 237 | x = self.refine_layer1(x) 238 | x = self.output_layer(x) 239 | return x 240 | 241 | if __name__ == '__main__': 242 | unet = Unet(1,1,1).cuda() 243 | img = torch.rand(1,1,320,320).cuda() 244 | depth = torch.rand(1,1,320,320).cuda() 245 | x = unet(img, depth) 246 | print(x.shape) 247 | --------------------------------------------------------------------------------
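As a closing reference, here is a minimal sketch of the masked MSE that `Mseloss` implements in the three `train.py` scripts. The random tensors stand in for the refined depth from `Displacement_Field`, the ground-truth depth, and the valid-pixel / occlusion-boundary mask; the names are illustrative, not repo code.

```python
import torch
import torch.nn.functional as F

pred = torch.rand(2, 1, 320, 320)  # refined depth, (N, 1, H, W)
gt = torch.rand(2, 320, 320)       # ground-truth depth, (N, H, W)
mask = (gt > 0.5).float()          # stands in for the valid / boundary-attention mask

# Mseloss zeroes out masked pixels on both prediction and target, then takes a
# plain MSE. Note the default 'mean' reduction still divides by *all* pixels,
# not only the unmasked ones, matching the original code.
loss = F.mse_loss(pred.squeeze(1) * mask, gt * mask)
print(float(loss))
```

One caveat worth knowing: because the mean is taken over every pixel, the mask rescales the loss magnitude as well as restricting its support, so sparser masks yield proportionally smaller loss values.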