├── .gitignore
├── README.md
├── data.py
├── evaluate.py
├── image_helper.py
├── main.py
├── model.py
└── test_run.py
/.gitignore:
--------------------------------------------------------------------------------
*~
data
models
data2/
plots/
new_plots/
logs/
nyu_depth_v2_labeled.mat
*.pyc
*.pth
*.csv
*.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Depth Map Prediction from a Single Image using a Multi-Scale Deep Network
1. [depth-map-prediction](https://github.com/imran3180/depth-map-prediction)
2. [unet-depth-prediction](https://github.com/DikshaMeghwal/unet-depth-prediction)
----
This repository is the first part of the project: a PyTorch implementation of *Depth Map Prediction from a Single Image using a Multi-Scale Deep Network* by David Eigen, Christian Puhrsch, and Rob Fergus. [Paper Link](https://cs.nyu.edu/~deigen/depth/depth_nips14.pdf)

Architecture
----------
Following the paper, the model consists of two stacks (see `model.py`): a coarse network that predicts a global, low-resolution depth map from the RGB input, and a fine network that refines the coarse prediction with local detail.

Data
----------
We used the [NYU Depth Dataset V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html). Its labeled subset (~2.8 GB) provides 1449 densely labeled pairs of aligned RGB and depth images, which we divided into three parts for our project (training: 1024, validation: 224, testing: 201). NYU also provides a raw dataset (~428 GB), which we could not train on due to machine capacity.

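As a minimal sketch of the split (assuming `nyu_depth_v2_labeled.mat` has been downloaded into the repository root, as `data.py` expects):

```python
import h5py

# The labeled .mat file is an HDF5 archive; h5py reads 'images' as
# (1449, 3, 640, 480) and 'depths' as (1449, 640, 480).
f = h5py.File('nyu_depth_v2_labeled.mat', 'r')
images, depths = f['images'], f['depths']

# Same index boundaries as NYUDataset in data.py.
train_images, train_depths = images[0:1024], depths[0:1024]     # 1024 pairs
val_images, val_depths = images[1024:1248], depths[1024:1248]   # 224 pairs
test_images, test_depths = images[1248:1449], depths[1248:1449] # 201 pairs
```
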
Training & Validation
-----------
`main.py` first trains the coarse network, then freezes it and trains the fine network, printing the paper's error metrics (delta accuracies, RMSE, relative differences) after every epoch. Checkpoints are saved under `models/<model_folder>/`. A CUDA-capable GPU is required, since both models are moved to the GPU.

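A typical run (where `my_run` is a hypothetical name for the folder that will hold the checkpoints):

```
python main.py my_run --epochs 10 --batch-size 32
```
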
Evaluation
-------------
`evaluate.py` loads a saved coarse/fine checkpoint pair and writes plots of the input, coarse prediction, fine prediction, and ground truth for each test batch to `new_plots/`.

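For example, to evaluate the checkpoints saved after epoch 10 of the hypothetical `my_run` above:

```
python evaluate.py my_run --model_no 10
```

`test_run.py` runs both networks on a single image on the CPU, given `--coarse_model_path`, `--fine_model_path`, and an image `--path`.
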
Contributors
---------------------------------

- [Imran](https://github.com/imran3180/)
- [Diksha Meghwal](https://github.com/DikshaMeghwal/)
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import zipfile
import os
import pdb
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
from PIL import Image
import torch.nn as nn
import torch
import copy

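# Depth targets are resized to the networks' 55x74 output resolution and
# converted to log-depth, since the loss in main.py is computed in log space.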
class TransposeDepthInput(object):
    def __call__(self, depth):
        depth = depth.transpose((2, 0, 1))  # (H, W, 1) -> (1, H, W)
        depth = torch.from_numpy(depth)
        depth = depth.view(1, depth.shape[0], depth.shape[1], depth.shape[2])
        depth = nn.functional.interpolate(depth, size=(55, 74), mode='bilinear', align_corners=False)
        depth = torch.log(depth)
        # depth = (depth - depth.min())/(depth.max() - depth.min())
        return depth[0]

# RGB inputs are resized to the networks' 228x304 input; the plotting transform
# resizes to 55x74 so inputs can be shown next to the predicted depth maps.
rgb_data_transforms = transforms.Compose([
    transforms.Resize((228, 304)),
    transforms.ToTensor(),
])

depth_data_transforms = transforms.Compose([
    TransposeDepthInput(),
])

input_for_plot_transforms = transforms.Compose([
    transforms.Resize((55, 74)),
    transforms.ToTensor(),
])

class NYUDataset(Dataset):
    def calculate_mean(self, images):
        mean_image = np.mean(images, axis=0)
        return mean_image

    def __init__(self, filename, split, rgb_transform=None, depth_transform=None):
        f = h5py.File(filename, 'r')
        # Optional: shuffle images and depths together before splitting:
        # images_data = copy.deepcopy(f['images'][0:1449])
        # depths_data = copy.deepcopy(f['depths'][0:1449])
        # merged_data = np.concatenate((images_data, depths_data.reshape((1449, 1, 640, 480))), axis=1)
        # np.random.shuffle(merged_data)
        # images_data = merged_data[:,0:3,:,:]
        # depths_data = merged_data[:,3:4,:,:]

        images_data = f['images'][0:1449]
        depths_data = f['depths'][0:1449]

        if split == "training":
            self.images = images_data[0:1024]
            self.depths = depths_data[0:1024]
        elif split == "validation":
            self.images = images_data[1024:1248]
            self.depths = depths_data[1024:1248]
        elif split == "test":
            self.images = images_data[1248:]
            self.depths = depths_data[1248:]
        self.rgb_transform = rgb_transform
        self.depth_transform = depth_transform
        self.mean_image = self.calculate_mean(images_data[0:1449])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        # image = (image - self.mean_image)/np.std(image)
        image = image.transpose((2, 1, 0))  # (3, W, H) -> (H, W, 3) for PIL
        # image = (image - image.min())/(image.max() - image.min())
        # image = image * 255
        # image = image.astype('uint8')
        image = Image.fromarray(image)
        if self.rgb_transform:
            image = self.rgb_transform(image)

        depth = self.depths[idx]
        depth = np.reshape(depth, (1, depth.shape[0], depth.shape[1]))
        depth = depth.transpose((2, 1, 0))  # (1, W, H) -> (H, W, 1)
        if self.depth_transform:
            depth = self.depth_transform(depth)
        sample = {'image': image, 'depth': depth}
        return sample
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
import matplotlib
import argparse

import torch

matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

from model import coarseNet, fineNet
import pdb
import numpy as np

parser = argparse.ArgumentParser(description='PyTorch depth prediction evaluation script')
parser.add_argument('model_folder', type=str, metavar='F',
                    help='folder under models/ from which to load the saved model')
parser.add_argument('--data', type=str, default='data', metavar='D',
                    help='folder where the data is located (the labeled .mat file is read from the repository root)')
parser.add_argument('--model_no', type=int, default=1, metavar='N',
                    help='which model (epoch) number to evaluate (default: 1, the first saved model)')
parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                    help='input batch size for evaluation (default: 8)')

args = parser.parse_args()

output_height = 55
output_width = 74

coarse_state_dict = torch.load("models/" + args.model_folder + "/coarse_model_" + str(args.model_no) + ".pth")
fine_state_dict = torch.load("models/" + args.model_folder + "/fine_model_" + str(args.model_no) + ".pth")

coarse_model = coarseNet()
fine_model = fineNet()
coarse_model.cuda()
fine_model.cuda()

coarse_model.load_state_dict(coarse_state_dict)
fine_model.load_state_dict(fine_state_dict)
coarse_model.eval()
fine_model.eval()

dtype = torch.cuda.FloatTensor

from data import NYUDataset, input_for_plot_transforms, rgb_data_transforms, depth_data_transforms

test_loader = torch.utils.data.DataLoader(NYUDataset('nyu_depth_v2_labeled.mat',
                                                     'test',
                                                     rgb_transform=rgb_data_transforms,
                                                     depth_transform=depth_data_transforms),
                                          batch_size=args.batch_size,
                                          shuffle=False, num_workers=0)

input_for_plot_loader = torch.utils.data.DataLoader(NYUDataset('nyu_depth_v2_labeled.mat',
                                                               'test',
                                                               rgb_transform=input_for_plot_transforms,
                                                               depth_transform=depth_data_transforms),
                                                    batch_size=args.batch_size,
                                                    shuffle=False, num_workers=0)

def plot_grid(fig, plot_input, coarse_output, fine_output, actual_output, row_no):
    # One row per image; columns: RGB input, coarse prediction, fine prediction, ground truth.
    grid = ImageGrid(fig, 141, nrows_ncols=(row_no, 4), axes_pad=0.05, label_mode="1")
    for i in range(row_no):
        grid[i*4 + 0].imshow(plot_input[i].cpu().numpy().transpose(1, 2, 0), interpolation="nearest")
        grid[i*4 + 1].imshow(coarse_output[i][0].detach().cpu().numpy(), interpolation="nearest")
        grid[i*4 + 2].imshow(fine_output[i][0].detach().cpu().numpy(), interpolation="nearest")
        grid[i*4 + 3].imshow(actual_output[i][0].detach().cpu().numpy(), interpolation="nearest")

for batch_idx, (data, plot_data) in enumerate(zip(test_loader, input_for_plot_loader)):
    rgb, depth = data['image'].cuda(), data['depth'].cuda()
    plot_input, actual_output = plot_data['image'].cuda(), plot_data['depth'].cuda()
    print('evaluating batch: ' + str(batch_idx))
    coarse_output = coarse_model(rgb.type(dtype))
    fine_output = fine_model(rgb.type(dtype), coarse_output.type(dtype))
    depth_dim = list(depth.size())
    fig = plt.figure(1, (30, 60))
    fig.subplots_adjust(left=0.05, right=0.95)
    # pdb.set_trace()
    # plot_input = torch.exp(plot_input)-1
    # coarse_output = torch.exp(coarse_output)-1
    # fine_output = torch.exp(fine_output)-1
    # actual_output = torch.exp(actual_output)-1

    plot_grid(fig, plot_input, coarse_output, fine_output, actual_output, depth_dim[0])
    plt.savefig("new_plots/" + args.model_folder + "_" + str(args.model_no) + "_" + str(batch_idx) + ".jpg")
    plt.show()
--------------------------------------------------------------------------------
/image_helper.py:
--------------------------------------------------------------------------------
import matplotlib

matplotlib.use('Agg')
from mpl_toolkits.axes_grid1 import ImageGrid

import pdb
import numpy as np

def plot_grid(fig, rgb, depth, row_no):
    # One row per image; columns: RGB input, ground-truth depth.
    grid = ImageGrid(fig, 141, nrows_ncols=(row_no, 2), axes_pad=0.05, label_mode="1")
    # pdb.set_trace()
    for i in range(row_no):
        grid[i*2 + 0].imshow(np.transpose(rgb[i].numpy(), (1, 2, 0)), interpolation="nearest")
        grid[i*2 + 1].imshow(depth[i][0].numpy(), interpolation="nearest")
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import pdb
from logger import Logger  # logger.py (a TensorBoard logging helper) is expected alongside main.py but is not included in this repository
import os

############## Image related
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
####################

# Training settings
parser = argparse.ArgumentParser(description='PyTorch depth map prediction example')
parser.add_argument('model_folder', type=str, metavar='F',
                    help='folder under models/ in which to save the model')
parser.add_argument('--data', type=str, default='data', metavar='D',
                    help='folder where the data is located (the labeled .mat file is read from the repository root)')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=20, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--suffix', type=str, default='', metavar='D',
                    help='suffix for the filename of models and output files')
args = parser.parse_args()

torch.manual_seed(args.seed)  # set seed for reproducible runs

output_height = 55
output_width = 74

from data import NYUDataset, rgb_data_transforms, depth_data_transforms
from image_helper import plot_grid

train_loader = torch.utils.data.DataLoader(NYUDataset('nyu_depth_v2_labeled.mat',
                                                      'training',
                                                      rgb_transform=rgb_data_transforms,
                                                      depth_transform=depth_data_transforms),
                                           batch_size=args.batch_size,
                                           shuffle=True, num_workers=5)

val_loader = torch.utils.data.DataLoader(NYUDataset('nyu_depth_v2_labeled.mat',
                                                    'validation',
                                                    rgb_transform=rgb_data_transforms,
                                                    depth_transform=depth_data_transforms),
                                         batch_size=args.batch_size,
                                         shuffle=False, num_workers=5)

test_loader = torch.utils.data.DataLoader(NYUDataset('nyu_depth_v2_labeled.mat',
                                                     'test',
                                                     rgb_transform=rgb_data_transforms,
                                                     depth_transform=depth_data_transforms),
                                          batch_size=args.batch_size,
                                          shuffle=False, num_workers=5)

from model import coarseNet, fineNet
coarse_model = coarseNet()
fine_model = fineNet()
coarse_model.cuda()
fine_model.cuda()

# Per-layer learning rates for SGD, as in the paper
coarse_optimizer = optim.SGD([{'params': coarse_model.conv1.parameters(), 'lr': 0.001},
                              {'params': coarse_model.conv2.parameters(), 'lr': 0.001},
                              {'params': coarse_model.conv3.parameters(), 'lr': 0.001},
                              {'params': coarse_model.conv4.parameters(), 'lr': 0.001},
                              {'params': coarse_model.conv5.parameters(), 'lr': 0.001},
                              {'params': coarse_model.fc1.parameters(), 'lr': 0.1},
                              {'params': coarse_model.fc2.parameters(), 'lr': 0.1}],
                             lr=0.001, momentum=0.9)
fine_optimizer = optim.SGD([{'params': fine_model.conv1.parameters(), 'lr': 0.001},
                            {'params': fine_model.conv2.parameters(), 'lr': 0.01},
                            {'params': fine_model.conv3.parameters(), 'lr': 0.001}],
                           lr=0.001, momentum=0.9)

# Changed values we experimented with:
# coarse_optimizer = optim.SGD([{'params': coarse_model.conv1.parameters(), 'lr': 0.01},{'params': coarse_model.conv2.parameters(), 'lr': 0.01},{'params': coarse_model.conv3.parameters(), 'lr': 0.01},{'params': coarse_model.conv4.parameters(), 'lr': 0.01},{'params': coarse_model.conv5.parameters(), 'lr': 0.01},{'params': coarse_model.fc1.parameters(), 'lr': 0.1},{'params': coarse_model.fc2.parameters(), 'lr': 0.1}], lr = 0.01, momentum = 0.9)
# fine_optimizer = optim.SGD(fine_model.parameters(), lr=args.lr, momentum=args.momentum)
# A modified fine optimizer; the per-layer rates above worked better:
# fine_optimizer = optim.SGD([{'params': fine_model.conv1.parameters(), 'lr': 0.01},{'params': fine_model.conv2.parameters(), 'lr': 0.1},{'params': fine_model.conv3.parameters(), 'lr': 0.01}], lr = 0.01, momentum = 0.9)

# Plain SGD with a single global learning rate did not work well:
# coarse_optimizer = optim.SGD(coarse_model.parameters(), lr=args.lr, momentum=args.momentum)
# fine_optimizer = optim.SGD(fine_model.parameters(), lr=args.lr, momentum=args.momentum)

# Other optimizers we tried:
# coarse_optimizer = optim.Adadelta(coarse_model.parameters(), lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
# fine_optimizer = optim.Adadelta(fine_model.parameters(), lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
# coarse_optimizer = optim.Adagrad(coarse_model.parameters(), lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)
# fine_optimizer = optim.Adagrad(fine_model.parameters(), lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)
# coarse_optimizer = optim.Adam(coarse_model.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
# fine_optimizer = optim.Adam(fine_model.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
# coarse_optimizer = optim.Adamax(coarse_model.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# fine_optimizer = optim.Adamax(fine_model.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# coarse_optimizer = optim.ASGD(coarse_model.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)
# fine_optimizer = optim.ASGD(fine_model.parameters(), lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)

dtype = torch.cuda.FloatTensor
logger = Logger('./logs/' + args.model_folder)

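# Training loss: the scale-invariant log-depth loss from Eigen et al.,
#   L(y, y*) = (1/n) * sum_i d_i^2 - (lambda / n^2) * (sum_i d_i)^2,
# where d_i = log y_i - log y*_i, n = 55 * 74 output pixels, and lambda = 0.5.
# scale_invariant() below is the same expression with lambda = 1 and is used
# only as a validation metric.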
def custom_loss_function(output, target):
    di = target - output
    n = (output_height * output_width)
    di2 = torch.pow(di, 2)
    first_term = torch.sum(di2, (1, 2, 3)) / n
    second_term = 0.5 * torch.pow(torch.sum(di, (1, 2, 3)), 2) / (n ** 2)
    loss = first_term - second_term
    return loss.mean()

def scale_invariant(output, target):
    di = target - output
    n = (output_height * output_width)
    di2 = torch.pow(di, 2)
    first_term = torch.sum(di2, (1, 2, 3)) / n
    second_term = torch.pow(torch.sum(di, (1, 2, 3)), 2) / (n ** 2)
    loss = first_term - second_term
    return loss.mean()

# An alternative per-image implementation of the loss (kept for reference):
# def custom_loss_function(output, target):
#     diff = target - output
#     alpha = torch.sum(diff, (1,2,3))/(output_height * output_width)
#     loss_val = 0
#     for i in range(alpha.shape[0]):
#         loss_val += torch.sum(torch.pow(((output[i] - target[i]) - alpha[i]), 2))/(2 * output_height * output_width)
#     loss_val = loss_val/output.shape[0]
#     return loss_val

# Error metrics from the paper
def threshold_percentage(output, target, threshold_val):
    # Fraction of pixels where max(y/y*, y*/y) < threshold (the delta accuracies).
    d1 = torch.exp(output) / torch.exp(target)
    d2 = torch.exp(target) / torch.exp(output)
    max_d1_d2 = torch.max(d1, d2)
    bit_mat = (max_d1_d2 < threshold_val).float()
    count_mat = torch.sum(bit_mat, (1, 2, 3))
    threshold_mat = count_mat / (output.shape[2] * output.shape[3])
    return threshold_mat.mean()

def rmse_linear(output, target):
    # Predictions and targets are log-depths; exponentiate to get actual depths.
    actual_output = torch.exp(output)
    actual_target = torch.exp(target)
    diff = actual_output - actual_target
    diff2 = torch.pow(diff, 2)
    mse = torch.sum(diff2, (1, 2, 3)) / (output.shape[2] * output.shape[3])
    rmse = torch.sqrt(mse)
    return rmse.mean()

def rmse_log(output, target):
    diff = output - target  # already in log space
    diff2 = torch.pow(diff, 2)
    mse = torch.sum(diff2, (1, 2, 3)) / (output.shape[2] * output.shape[3])
    rmse = torch.sqrt(mse)
    return rmse.mean()

def abs_relative_difference(output, target):
    actual_output = torch.exp(output)
    actual_target = torch.exp(target)
    abs_relative_diff = torch.abs(actual_output - actual_target) / actual_target
    abs_relative_diff = torch.sum(abs_relative_diff, (1, 2, 3)) / (output.shape[2] * output.shape[3])
    return abs_relative_diff.mean()

def squared_relative_difference(output, target):
    actual_output = torch.exp(output)
    actual_target = torch.exp(target)
    square_relative_diff = torch.pow(torch.abs(actual_output - actual_target), 2) / actual_target
    square_relative_diff = torch.sum(square_relative_diff, (1, 2, 3)) / (output.shape[2] * output.shape[3])
    return square_relative_diff.mean()

def train_coarse(epoch):
    coarse_model.train()
    train_coarse_loss = 0
    for batch_idx, data in enumerate(train_loader):
        # Inputs and targets do not need gradients; only the model parameters do.
        rgb, depth = data['image'].cuda(), data['depth'].cuda()
        coarse_optimizer.zero_grad()
        output = coarse_model(rgb.type(dtype))
        loss = custom_loss_function(output, depth)
        loss.backward()
        coarse_optimizer.step()
        train_coarse_loss += loss.item()
    train_coarse_loss /= (batch_idx + 1)
    return train_coarse_loss

def train_fine(epoch):
    coarse_model.eval()
    fine_model.train()
    train_fine_loss = 0
    for batch_idx, data in enumerate(train_loader):
        rgb, depth = data['image'].cuda(), data['depth'].cuda()
        fine_optimizer.zero_grad()
        coarse_output = coarse_model(rgb.type(dtype))  # coarse weights are fixed at this stage
        output = fine_model(rgb.type(dtype), coarse_output.type(dtype))
        loss = custom_loss_function(output, depth)
        loss.backward()
        fine_optimizer.step()
        train_fine_loss += loss.item()
    train_fine_loss /= (batch_idx + 1)
    return train_fine_loss

def coarse_validation(epoch, training_loss):
    coarse_model.eval()
    coarse_validation_loss = 0
    scale_invariant_loss = 0
    delta1_accuracy = 0
    delta2_accuracy = 0
    delta3_accuracy = 0
    rmse_linear_loss = 0
    rmse_log_loss = 0
    abs_relative_difference_loss = 0
    squared_relative_difference_loss = 0

    for batch_idx, data in enumerate(val_loader):
        rgb, depth = data['image'].cuda(), data['depth'].cuda()
        coarse_output = coarse_model(rgb.type(dtype))
        coarse_validation_loss += custom_loss_function(coarse_output, depth).item()
        # all error metrics
        scale_invariant_loss += scale_invariant(coarse_output, depth)
        delta1_accuracy += threshold_percentage(coarse_output, depth, 1.25)
        delta2_accuracy += threshold_percentage(coarse_output, depth, 1.25**2)
        delta3_accuracy += threshold_percentage(coarse_output, depth, 1.25**3)
        rmse_linear_loss += rmse_linear(coarse_output, depth)
        rmse_log_loss += rmse_log(coarse_output, depth)
        abs_relative_difference_loss += abs_relative_difference(coarse_output, depth)
        squared_relative_difference_loss += squared_relative_difference(coarse_output, depth)

    coarse_validation_loss /= (batch_idx + 1)
    scale_invariant_loss /= (batch_idx + 1)
    delta1_accuracy /= (batch_idx + 1)
    delta2_accuracy /= (batch_idx + 1)
    delta3_accuracy /= (batch_idx + 1)
    rmse_linear_loss /= (batch_idx + 1)
    rmse_log_loss /= (batch_idx + 1)
    abs_relative_difference_loss /= (batch_idx + 1)
    squared_relative_difference_loss /= (batch_idx + 1)
    logger.scalar_summary("coarse validation loss", coarse_validation_loss, epoch)
    print('Epoch: {} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f}'.format(epoch, training_loss,
          coarse_validation_loss, delta1_accuracy, delta2_accuracy, delta3_accuracy, rmse_linear_loss, rmse_log_loss,
          abs_relative_difference_loss, squared_relative_difference_loss))

def fine_validation(epoch, training_loss):
    fine_model.eval()
    fine_validation_loss = 0
    scale_invariant_loss = 0
    delta1_accuracy = 0
    delta2_accuracy = 0
    delta3_accuracy = 0
    rmse_linear_loss = 0
    rmse_log_loss = 0
    abs_relative_difference_loss = 0
    squared_relative_difference_loss = 0
    for batch_idx, data in enumerate(val_loader):
        rgb, depth = data['image'].cuda(), data['depth'].cuda()
        coarse_output = coarse_model(rgb.type(dtype))
        fine_output = fine_model(rgb.type(dtype), coarse_output.type(dtype))
        fine_validation_loss += custom_loss_function(fine_output, depth).item()
        # all error metrics
        scale_invariant_loss += scale_invariant(fine_output, depth)
        delta1_accuracy += threshold_percentage(fine_output, depth, 1.25)
        delta2_accuracy += threshold_percentage(fine_output, depth, 1.25**2)
        delta3_accuracy += threshold_percentage(fine_output, depth, 1.25**3)
        rmse_linear_loss += rmse_linear(fine_output, depth)
        rmse_log_loss += rmse_log(fine_output, depth)
        abs_relative_difference_loss += abs_relative_difference(fine_output, depth)
        squared_relative_difference_loss += squared_relative_difference(fine_output, depth)
    fine_validation_loss /= (batch_idx + 1)
    scale_invariant_loss /= (batch_idx + 1)
    delta1_accuracy /= (batch_idx + 1)
    delta2_accuracy /= (batch_idx + 1)
    delta3_accuracy /= (batch_idx + 1)
    rmse_linear_loss /= (batch_idx + 1)
    rmse_log_loss /= (batch_idx + 1)
    abs_relative_difference_loss /= (batch_idx + 1)
    squared_relative_difference_loss /= (batch_idx + 1)
    logger.scalar_summary("fine validation loss", fine_validation_loss, epoch)
    print('Epoch: {} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f}'.format(epoch, training_loss,
          fine_validation_loss, delta1_accuracy, delta2_accuracy, delta3_accuracy, rmse_linear_loss, rmse_log_loss,
          abs_relative_difference_loss, squared_relative_difference_loss))

folder_name = "models/" + args.model_folder
if not os.path.exists(folder_name): os.makedirs(folder_name)

print("********* Training the Coarse Model **************")
print("Epochs: Train_loss Val_loss Delta_1 Delta_2 Delta_3 rmse_lin rmse_log abs_rel. square_relative")
print("Paper Val: (0.618) (0.891) (0.969) (0.871) (0.283) (0.228) (0.223)")

for epoch in range(1, args.epochs + 1):
    training_loss = train_coarse(epoch)
    coarse_validation(epoch, training_loss)
    model_file = folder_name + "/" + 'coarse_model_' + str(epoch) + '.pth'
    if epoch % 10 == 0:
        torch.save(coarse_model.state_dict(), model_file)

coarse_model.eval()  # switch the coarse model to eval mode; its weights are not updated during fine training

print("********* Training the Fine Model ****************")
print("Epochs: Train_loss Val_loss Delta_1 Delta_2 Delta_3 rmse_lin rmse_log abs_rel. square_relative")
print("Paper Val: (0.611) (0.887) (0.971) (0.907) (0.285) (0.215) (0.212)")
for epoch in range(1, args.epochs + 1):
    training_loss = train_fine(epoch)
    fine_validation(epoch, training_loss)
    model_file = folder_name + "/" + 'fine_model_' + str(epoch) + '.pth'
    if epoch % 5 == 0:
        torch.save(fine_model.state_dict(), model_file)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


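# Two-stack architecture from Eigen et al.: coarseNet predicts a global 55x74
# log-depth map from the 228x304 RGB input, and fineNet refines it with local
# detail by concatenating the coarse prediction with its own first-layer
# feature maps.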
class coarseNet(nn.Module):
    def __init__(self, init_weights=True):
        super(coarseNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=2)
        self.fc1 = nn.Linear(12288, 4096)
        self.fc2 = nn.Linear(4096, 4070)
        self.pool = nn.MaxPool2d(2)
        self.dropout = nn.Dropout2d()
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # Shapes below assume a batch of 8 and a 228x304 input: [n, c, H, W]
        # [8, 3, 228, 304]
        x = self.conv1(x)          # [8, 96, 55, 74]
        x = F.relu(x)
        x = self.pool(x)           # [8, 96, 27, 37]
        x = self.conv2(x)          # [8, 256, 27, 37]
        x = F.relu(x)
        x = self.pool(x)           # [8, 256, 13, 18]
        x = self.conv3(x)          # [8, 384, 13, 18]
        x = F.relu(x)
        x = self.conv4(x)          # [8, 384, 13, 18]
        x = F.relu(x)
        x = self.conv5(x)          # [8, 256, 6, 8]
        x = F.relu(x)
        x = x.view(x.size(0), -1)  # [8, 12288]
        x = F.relu(self.fc1(x))    # [8, 4096]
        x = self.dropout(x)
        x = self.fc2(x)            # [8, 4070]; 55 x 74 = 4070
        x = x.view(-1, 1, 55, 74)
        return x

    # The paper initializes the coarse stack from an ImageNet-pretrained model;
    # here the weights are drawn from a random Gaussian instead.
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


class fineNet(nn.Module):
    def __init__(self, init_weights=True):
        super(fineNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 63, kernel_size=9, stride=2)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(64, 1, kernel_size=5, padding=2)
        self.pool = nn.MaxPool2d(2)
        if init_weights:
            self._initialize_weights()

    def forward(self, x, y):
        # x: RGB input [8, 3, 228, 304]; y: coarse prediction [8, 1, 55, 74]
        x = F.relu(self.conv1(x))  # [8, 63, 110, 148]
        x = self.pool(x)           # [8, 63, 55, 74]
        x = torch.cat((x, y), 1)   # [8, 64, 55, 74]
        x = F.relu(self.conv2(x))  # [8, 64, 55, 74]
        x = self.conv3(x)          # [8, 1, 55, 74]
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
--------------------------------------------------------------------------------
/test_run.py:
--------------------------------------------------------------------------------
import matplotlib
import argparse
from PIL import Image

import torch
# matplotlib.use('Agg')
import matplotlib.pyplot as plt

from model import coarseNet, fineNet
import numpy as np

import torchvision.transforms as transforms

parser = argparse.ArgumentParser(description='PyTorch depth prediction test run script')
parser.add_argument('--coarse_model_path', type=str, default='coarse_model.pth', metavar='F',
                    help='path of the coarse model checkpoint')
parser.add_argument('--fine_model_path', type=str, default='fine_model.pth', metavar='F',
                    help='path of the fine model checkpoint')
parser.add_argument('--path', type=str, default='sample_input2.jpg', metavar='D',
                    help='path of the input image (default: sample_input2.jpg)')

args = parser.parse_args()

# map_location lets the CPU load checkpoints that were saved on a GPU
coarse_state_dict = torch.load(args.coarse_model_path, map_location=lambda storage, loc: storage)
fine_state_dict = torch.load(args.fine_model_path, map_location=lambda storage, loc: storage)

coarse_model = coarseNet()
fine_model = fineNet()

coarse_model.load_state_dict(coarse_state_dict)
fine_model.load_state_dict(fine_state_dict)
coarse_model.eval()
fine_model.eval()

rgb_data_transforms = transforms.Compose([
    transforms.Resize((228, 304)),
    transforms.ToTensor(),
])

input_for_plot_transforms = transforms.Compose([
    transforms.Resize((55, 74)),  # so the plotted input matches the output size
    transforms.ToTensor(),
])

image = Image.open(args.path).convert('RGB')  # ensure a 3-channel input

input_image = input_for_plot_transforms(image)
image = rgb_data_transforms(image)
image = image.view(1, 3, 228, 304)

coarse_output = coarse_model(image)
fine_output = fine_model(image, coarse_output)

plt.figure(1, figsize=(9, 3))
plt.subplot(131)
plt.gca().set_title('input')
plt.imshow(np.transpose(input_image, (1, 2, 0)), interpolation="nearest")
plt.subplot(132)
plt.gca().set_title('coarse_output')
plt.imshow(coarse_output[0][0].detach().numpy(), interpolation="nearest")
plt.subplot(133)
plt.gca().set_title('fine_output')
plt.imshow(fine_output[0][0].detach().numpy(), interpolation="nearest")
plt.suptitle('Depth Map Prediction of Input Image')
plt.show()
--------------------------------------------------------------------------------