├── LICENSE
├── augment.py
├── dsac.py
├── images
│   ├── dsac.png
│   ├── dsac_eq.png
│   ├── example_output.png
│   ├── loss.png
│   └── task.png
├── line_dataset.py
├── line_loss.py
├── line_nn.py
├── main.py
└── readme.md

/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2018, Visual Learning Lab
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/augment.py:
--------------------------------------------------------------------------------
import cv2
import os

# Overlays captions and a legend on the visualization frames written by main.py,
# e.g. for compiling them into a training video.

infolder = './images_rf65_c4_h64_t0.05_schedule4_2'
outfolder = infolder + '_out'

files = os.listdir(infolder)
files.sort()

for i, f in enumerate(files):

    infile = infolder + '/' + f

    img = cv2.imread(infile)

    # make room for captions: grow the input grids (200x600 as produced by main.py)
    # to 300x600; ndarray.resize truncates or zero-pads the flattened buffer, so
    # flipping vertically first makes the padded rows land at the top of the image
    img = cv2.flip(img, 0)
    img.resize((250, 600, 3))
    img = cv2.flip(img, 0)
    img.resize((300, 600, 3)) # 50 more black rows at the bottom for the panel captions

    iteration = i * 100
    label = 'learning iteration %05d/50000' % iteration
    cv2.putText(img, label, (330, 30), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255))

    label = 'validation inputs'
    cv2.putText(img, label, (30, 270), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255))

    label = 'DSAC predictions'
    cv2.putText(img, label, (230, 270), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255))

    label = 'direct predictions'
    cv2.putText(img, label, (420, 270), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255))

    yoffset = 3

    cv2.line(img, (10, 5+yoffset), (20, 15+yoffset), (0, 255, 0))
    label = 'ground truth'
    cv2.putText(img, label, (25, 15+yoffset), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255))

    cv2.line(img, (10, 25+yoffset), (20, 35+yoffset), (255, 0, 0))
    label = 'predicted line'
    cv2.putText(img, label, (25, 35+yoffset), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255))

    cv2.rectangle(img, (170, 5+yoffset), (180, 15+yoffset), (0, 0, 254))
    label = 'incorrect'
    cv2.putText(img, label, (185, 15+yoffset), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255))

    cv2.rectangle(img, (170, 25+yoffset), (180, 35+yoffset), (0, 255, 0))
    label = 'correct'
    cv2.putText(img, label, (185, 35+yoffset), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255))

    if not os.path.exists(outfolder):
        os.mkdir(outfolder)

    outfile = (outfolder + '/out_%06d.png') % i

    cv2.imwrite(outfile, img)
--------------------------------------------------------------------------------
/dsac.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

import random

class DSAC:
    '''
    Differentiable RANSAC to robustly fit lines.
    '''

    def __init__(self, hyps, inlier_thresh, inlier_beta, inlier_alpha, loss_function):
        '''
        Constructor.

        hyps -- number of line hypotheses sampled for each image
        inlier_thresh -- threshold used in the soft inlier count, measured in relative image size (1 = image width)
        inlier_beta -- scaling factor within the sigmoid of the soft inlier count
        inlier_alpha -- scaling factor for the soft inlier scores (controls the peakiness of the hypothesis distribution)
        loss_function -- function to compute the quality of estimated line parameters w.r.t. ground truth
        '''

        self.hyps = hyps
        self.inlier_thresh = inlier_thresh
        self.inlier_beta = inlier_beta
        self.inlier_alpha = inlier_alpha
        self.loss_function = loss_function
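
    # A typical construction, using the default values from main.py:
    #   dsac = DSAC(hyps=64, inlier_thresh=0.05, inlier_beta=100.0,
    #               inlier_alpha=0.5, loss_function=LineLoss(64))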

    def __sample_hyp(self, x, y):
        '''
        Calculate a line hypothesis (slope, intercept) from two random points.

        x -- vector of x values
        y -- vector of y values
        '''

        # select two random points
        num_correspondences = x.size(0)

        idx1 = random.randint(0, num_correspondences-1)
        idx2 = random.randint(0, num_correspondences-1)

        tries = 1000

        # prevent the slope from getting too large
        while torch.abs(x[idx1] - x[idx2]) < 0.01 and tries > 0:
            idx2 = random.randint(0, num_correspondences-1)
            tries = tries - 1

        if tries == 0: return 0, 0, False # no valid hypothesis found, indicated by False

        slope = (y[idx1] - y[idx2]) / (x[idx1] - x[idx2])
        intercept = y[idx1] - slope * x[idx1]

        return slope, intercept, True # True indicates success

    def __soft_inlier_count(self, slope, intercept, x, y):
        '''
        Soft inlier count for a given line and a given set of points.

        slope -- slope of the line
        intercept -- intercept of the line
        x -- vector of x values
        y -- vector of y values
        '''

        # point-line distances
        dists = torch.abs(slope * x - y + intercept)
        dists = dists / torch.sqrt(slope * slope + 1)

        # soft inliers
        dists = 1 - torch.sigmoid(self.inlier_beta * (dists - self.inlier_thresh))
        score = torch.sum(dists)

        return score, dists
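
    # Worked example for the soft inlier count: with inlier_thresh = 0.05 and
    # inlier_beta = 100 (the defaults in main.py), a point at distance 0.01
    # from the line gets weight 1 - sigmoid(100 * (0.01 - 0.05)) ~ 0.98, while
    # a point at distance 0.10 gets 1 - sigmoid(100 * (0.10 - 0.05)) ~ 0.007,
    # so the count is a smooth, differentiable stand-in for hard thresholding.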

    def __refine_hyp(self, x, y, weights):
        '''
        Refinement by weighted Deming regression.

        Fits a line minimizing errors in x and y, implementation according to:
        'Performance of Deming regression analysis in case of misspecified
        analytical error ratio in method comparison studies'
        Kristian Linnet, in Clinical Chemistry, 1998

        x -- vector of x values
        y -- vector of y values
        weights -- vector of weights (1 per point)
        '''

        ws = weights.sum()
        xm = (x * weights).sum() / ws
        ym = (y * weights).sum() / ws

        u = (x - xm)**2
        u = (u * weights).sum()

        q = (y - ym)**2
        q = (q * weights).sum()

        p = torch.mul(x - xm, y - ym)
        p = (p * weights).sum()

        slope = (q - u + torch.sqrt((u - q)**2 + 4*p*p)) / (2*p)
        intercept = ym - slope * xm

        return slope, intercept


    def __call__(self, prediction, labels):
        '''
        Perform robust, differentiable line fitting according to DSAC.

        Returns the expected loss of choosing a good line hypothesis, which can be used for backprop.

        prediction -- predicted 2D points for a batch of images, array of shape (Bx2xM) where
            B is the number of images in the batch
            2 is the number of point dimensions (y, x)
            M is the number of predicted points per image
        labels -- ground truth labels for the batch, array of shape (Bx2) where
            B is the number of images in the batch
            2 is the number of line parameters (intercept, slope)
        '''

        # working on CPU because of many, small matrices
        prediction = prediction.cpu()

        batch_size = prediction.size(0)

        avg_exp_loss = 0 # expected loss
        avg_top_loss = 0 # loss of best hypothesis

        self.est_parameters = torch.zeros(batch_size, 2) # estimated lines
        self.est_losses = torch.zeros(batch_size) # loss of estimated lines
        self.batch_inliers = torch.zeros(batch_size, prediction.size(2)) # (soft) inliers for estimated lines

        for b in range(0, batch_size):

            hyp_losses = torch.zeros([self.hyps, 1]) # loss of each hypothesis
            hyp_scores = torch.zeros([self.hyps, 1]) # score of each hypothesis

            max_score = 0 # score of best hypothesis

            y = prediction[b, 0] # all y-values of the prediction
            x = prediction[b, 1] # all x-values of the prediction

            for h in range(0, self.hyps):

                # === step 1: sample hypothesis ===========================
                slope, intercept, valid = self.__sample_hyp(x, y)
                if not valid: continue # skip invalid hyps

                # === step 2: score hypothesis using soft inlier count ====
                score, inliers = self.__soft_inlier_count(slope, intercept, x, y)

                # === step 3: refine hypothesis ===========================
                slope, intercept = self.__refine_hyp(x, y, inliers)

                hyp = torch.zeros([2])
                hyp[1] = slope
                hyp[0] = intercept

                # === step 4: calculate loss of hypothesis ================
                loss = self.loss_function(hyp, labels[b])

                # store results
                hyp_losses[h] = loss
                hyp_scores[h] = score

                # keep track of the best hypothesis so far
                if score > max_score:
                    max_score = score
                    self.est_losses[b] = loss
                    self.est_parameters[b] = hyp
                    self.batch_inliers[b] = inliers

            # === step 5: calculate the expectation ===========================

            # softmax distribution from hypothesis scores
            hyp_scores = F.softmax(self.inlier_alpha * hyp_scores, 0)

            # expectation of loss
            exp_loss = torch.sum(hyp_losses * hyp_scores)
            avg_exp_loss = avg_exp_loss + exp_loss

            # loss of best hypothesis (for evaluation)
            avg_top_loss = avg_top_loss + self.est_losses[b]

        return avg_exp_loss / batch_size, avg_top_loss / batch_size
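
# Usage sketch (mirroring the training loop in main.py): `prediction` is the
# (B, 2, M) point tensor from LineNN, `labels` the (B, 2) ground truth:
#   exp_loss, top_loss = dsac(prediction, labels)
#   exp_loss.backward()   # gradients flow through scores and refinement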
--------------------------------------------------------------------------------
/images/dsac.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vislearn/DSACLine/60f29a75bcb423e7a96cb3320522c7cfa42544fd/images/dsac.png
--------------------------------------------------------------------------------
/images/dsac_eq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vislearn/DSACLine/60f29a75bcb423e7a96cb3320522c7cfa42544fd/images/dsac_eq.png
--------------------------------------------------------------------------------
/images/example_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vislearn/DSACLine/60f29a75bcb423e7a96cb3320522c7cfa42544fd/images/example_output.png
--------------------------------------------------------------------------------
/images/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vislearn/DSACLine/60f29a75bcb423e7a96cb3320522c7cfa42544fd/images/loss.png
--------------------------------------------------------------------------------
/images/task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vislearn/DSACLine/60f29a75bcb423e7a96cb3320522c7cfa42544fd/images/task.png
--------------------------------------------------------------------------------
/line_dataset.py:
--------------------------------------------------------------------------------
import random
import numpy as np
import math

from skimage.draw import line, line_aa, circle, set_color, circle_perimeter_aa
from skimage.io import imsave
from skimage.util import random_noise

maxSlope = 10 # restrict the maximum slope of generated lines for stability
minLength = 20 # restrict the minimum length of line segments

class LineDataset:
    '''
    Generator of line segment images.

    Each image contains one random line segment plus noise and distractor circles.
    The class also offers functionality for drawing line parameters, hypotheses and point predictions.
    '''

    def __init__(self, imgW = 64, imgH = 64, margin = -5, bg_clr = 0.5):
        '''
        Constructor.

        imgW -- image width (default 64)
        imgH -- image height (default 64)
        margin -- line segments are sampled within this margin, a negative value means that a line segment can start or end outside the image (default -5)
        bg_clr -- background intensity (default 0.5)
        '''

        self.imgW = imgW
        self.imgH = imgH
        self.margin = margin
        self.bg_clr = bg_clr

    def draw_line(self, data, lX1, lY1, lX2, lY2, clr, alpha=1.0):
        '''
        Draw a line with the given color and opacity.

        data -- image to draw to
        lX1 -- x value of line segment start point
        lY1 -- y value of line segment start point
        lX2 -- x value of line segment end point
        lY2 -- y value of line segment end point
        clr -- line color, triple of values
        alpha -- opacity (default 1.0)
        '''

        rr, cc, val = line_aa(lY1, lX1, lY2, lX2)
        set_color(data, (rr, cc), clr, val*alpha)

    def draw_hyps(self, labels, scores, data=None):
        '''
        Draw a set of line hypotheses for a batch of images.

        labels -- line parameters, array shape (NxMx2) where
            N is the number of images in the batch
            M is the number of hypotheses per image
            2 is the number of line parameters (intercept, slope)
        scores -- hypothesis scores, array shape (NxM), see above, hypotheses with a higher score are drawn with higher opacity
        data -- batch of images to draw to, if None a new batch will be created according to the shape of labels

        '''

        n = labels.shape[0] # number of images
        m = labels.shape[1] # number of hypotheses

        if data is None: # create new batch of images
            data = np.zeros((n, self.imgH, self.imgW, 3), dtype=np.float32)
            data.fill(self.bg_clr)

        clr = (0, 0, 1)

        for i in range(0, n):
            for j in range(0, m):
                lY1 = int(labels[i, j, 0] * self.imgH)
                lY2 = int(labels[i, j, 1] * self.imgW + labels[i, j, 0] * self.imgH)
                self.draw_line(data[i], 0, lY1, self.imgW, lY2, clr, scores[i, j])

        return data

    def draw_models(self, labels, data=None, correct=None):
        '''
        Draw lines for a batch of images.

        labels -- line parameters, array shape (Nx2) where
            N is the number of images in the batch
            2 is the number of line parameters (intercept, slope)
        data -- batch of images to draw to, if None a new batch will be created according to the shape
            of labels and lines will be drawn green, otherwise lines will be drawn blue
        correct -- array of shape (N) indicating whether a line estimate is correct
        '''

        n = labels.shape[0]
        if data is None:
            data = np.zeros((n, self.imgH, self.imgW, 3), dtype=np.float32)
            data.fill(self.bg_clr)
            clr = (0, 1, 0)
        else:
            clr = (0, 0, 1)

        for i in range(0, n):
            lY1 = int(labels[i, 0] * self.imgH)
            lY2 = int(labels[i, 1] * self.imgW + labels[i, 0] * self.imgH)
            self.draw_line(data[i], 0, lY1, self.imgW, lY2, clr)

            if correct is not None:

                # draw the border green if the estimate is correct, red otherwise
                if correct[i]: borderclr = (0, 1, 0)
                else: borderclr = (1, 0, 0)

                set_color(data[i], line(0, 0, 0, self.imgW-1), borderclr)
                set_color(data[i], line(0, 0, self.imgH-1, 0), borderclr)
                set_color(data[i], line(self.imgH-1, 0, self.imgH-1, self.imgW-1), borderclr)
                set_color(data[i], line(0, self.imgW-1, self.imgH-1, self.imgW-1), borderclr)

        return data

    def draw_points(self, points, data, inliers=None):
        '''
        Draw 2D points for a batch of images.

        points -- 2D points, array shape (Nx2xM) where
            N is the number of images in the batch
            2 is the number of point dimensions (x, y)
            M is the number of points
        data -- batch of images to draw to
        inliers -- soft inlier score for each point,
            if given, points with a score > 0.5 are drawn as light circles (inliers), dark otherwise
        '''

        n = points.shape[0] # number of images
        m = points.shape[2] # number of points

        for i in range(0, n):
            for j in range(0, m):

                clr = (0.2, 0.2, 0.2) # draw predicted points as dark circles
                if inliers is not None and inliers[i, j] > 0.5:
                    clr = (0.7, 0.7, 0.7) # draw inliers as light circles

                r = int(points[i, 0, j] * self.imgH)
                c = int(points[i, 1, j] * self.imgW)
                rr, cc = circle(r, c, 2)
                set_color(data[i], (rr, cc), clr)

        return data

    def sample_lines(self, n):
        '''
        Create new input images of random line segments and distractors, along with ground truth parameters.

        n -- number of images to create
        '''

        data = np.zeros((n, self.imgH, self.imgW, 3), dtype=np.float32)
        data.fill(self.bg_clr)
        labels = np.ndarray((n, 2, 1, 1), dtype=np.float32)

        for i in range(0, n): # for each image

            # create a random number of distractor circles
            nC = random.randint(2, 5)
            for c in range(0, nC):

                cR = random.randint(int(0.1 * self.imgW), int(1 * self.imgW))
                cX1 = random.randint(int(-0.5 * cR), int(self.imgW+0.5*cR+1))
                cY1 = random.randint(int(-0.5 * cR), int(self.imgH+0.5*cR+1))

                clr = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))

                rr, cc, val = circle_perimeter_aa(cY1, cX1, cR)
                set_color(data[i], (rr, cc), clr, val)

            # create line segment
            while True:

                # sample segment end points
                lX1 = random.randint(self.margin, self.imgW-self.margin+1)
                lX2 = random.randint(self.margin, self.imgW-self.margin+1)
                lY1 = random.randint(self.margin, self.imgH-self.margin+1)
                lY2 = random.randint(self.margin, self.imgH-self.margin+1)

                # check min length
                length = math.sqrt((lX1 - lX2) * (lX1 - lX2) + (lY1 - lY2) * (lY1 - lY2))
                if length < minLength: continue

                # random color
                clr = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))

                # calculate line ground truth parameters
                delta = lX2 - lX1
                if delta == 0: delta = 1

                slope = (lY2 - lY1) / delta
                intercept = lY1 - slope * lX1

                # not too steep for stability
                if abs(slope) < maxSlope: break

            labels[i, 0] = intercept / self.imgH
            labels[i, 1] = slope

            self.draw_line(data[i], lX1, lY1, lX2, lY2, clr)

            # apply some noise on top
            data[i] = random_noise(data[i], mode='speckle')

        return data, labels
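
# Note on the label format: a line is stored as (intercept, slope) for
# y = slope * x + intercept, with the intercept normalized by the image height
# (see sample_lines above); the drawing routines undo this normalization,
# e.g. lY1 = int(labels[i, 0] * self.imgH) in draw_models.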
--------------------------------------------------------------------------------
/line_loss.py:
--------------------------------------------------------------------------------
import torch

class LineLoss:
    '''
    Compares two lines by calculating the distance between their ends in the image.
    '''

    def __init__(self, image_size):
        '''
        Constructor.

        image_size -- size of the input images, used to normalize the loss
        '''
        self.image_size = image_size

    def __get_max_points(self, slope, intercept):
        '''
        Calculates the 2D points where a line intersects the image borders.

        slope -- slope of the line
        intercept -- intercept of the line
        '''
        pts = torch.zeros([2, 2])

        x0 = 0
        x1 = 1
        y0 = intercept
        y1 = intercept + slope

        # determine which image borders the line cuts
        cuts_x0 = (y0 >= 0 and y0 <= 1) # left border
        cuts_x1 = (y1 >= 0 and y1 <= 1) # right border
        cuts_y0 = (y0 <= 0 and y1 >= 0) or (y1 <= 0 and y0 >= 0) # upper border
        cuts_y1 = (y0 <= 1 and y1 >= 1) or (y1 <= 1 and y0 >= 1) # lower border

        if cuts_x0 and cuts_x1:
            # line goes from left to right
            # use initialization above
            pass

        elif cuts_x0 and cuts_y0:
            # line goes from left to top
            y1 = 0
            x1 = -intercept / slope

        elif cuts_x0 and cuts_y1:
            # line goes from left to bottom
            y1 = 1
            x1 = (1 - intercept) / slope

        elif cuts_x1 and cuts_y0:
            # line goes from top to right
            y0 = 0
            x0 = -intercept / slope

        elif cuts_x1 and cuts_y1:
            # line goes from bottom to right
            y0 = 1
            x0 = (1 - intercept) / slope

        elif cuts_y0 and cuts_y1:
            # line goes from top to bottom
            y0 = 0
            x0 = -intercept / slope
            y1 = 1
            x1 = (1 - intercept) / slope

        else:
            # line lies outside the image
            x0 = -intercept / slope
            if abs(x0) < abs(y0):
                y0 = 0
            else:
                x0 = 0

            x1 = (1 - intercept) / slope
            if abs(x1) < abs(y1):
                y1 = 1
            else:
                x1 = 1

        pts[0, 0] = x0
        pts[0, 1] = y0
        pts[1, 0] = x1
        pts[1, 1] = y1

        return pts

    def __call__(self, est, gt):
        '''
        Calculate the line loss.

        est -- estimated line, form: [intercept, slope]
        gt -- ground truth line, form: [intercept, slope]
        '''

        pts_est = self.__get_max_points(est[1], est[0])
        pts_gt = self.__get_max_points(gt[1], gt[0])

        # it is not clear which ends of the lines should be compared (there are ambiguous cases),
        # so we compute both assignments and take the minimum
        loss1 = pts_est - pts_gt
        loss1 = loss1.norm(2, 1).sum()

        flip_mat = torch.zeros([2, 2])
        flip_mat[0, 1] = 1
        flip_mat[1, 0] = 1

        loss2 = pts_est - flip_mat.mm(pts_gt)
        loss2 = loss2.norm(2, 1).sum()

        return min(loss1, loss2) * self.image_size
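
# Worked example (made-up numbers): for est = [0.20, 0.5] and gt = [0.25, 0.5],
# both lines cut the left and right borders, giving endpoints (0, 0.20)-(1, 0.70)
# and (0, 0.25)-(1, 0.75); loss1 = 0.05 + 0.05 = 0.1, and with image_size = 64
# the returned loss is 6.4.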
--------------------------------------------------------------------------------
/line_nn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class LineNN(nn.Module):
    '''
    Generic CNN architecture that can be used for 2D point prediction and direct prediction.

    It supports an FCN-style architecture with varying receptive fields, as well as a global
    CNN which produces one output per image.

    '''

    def __init__(self, net_capacity, receptive_field = 0, direct = False, image_size = 64, global_output_grid = 8):
        '''
        Constructor.

        net_capacity -- multiplicative factor for the number of layer channels
        receptive_field -- receptive field of the output neurons, the class will select
            filter strides accordingly (supported: 15, 29, 37, 51, 65, 0), 0 = global
            receptive field (default 0)
        direct -- if True, the model predicts the line parameters directly, otherwise it
            predicts multiple 2D points (default False)
        image_size -- size of the input images (default 64)
        global_output_grid -- number of 2D output points for a global model
            (receptive_field=0), points are distributed on a 2D grid, i.e. the number of
            points is this value squared, for a receptive_field > 0 (i.e. the FCN setting) the
            number of output points results from the input image dimensions (default 8)
        '''
        super(LineNN, self).__init__()

        c = net_capacity
        output_dim = 2

        if direct and receptive_field != 0:
            print('Warning: Direct models must have a global receptive field (0).')

        # set the conv strides to achieve the desired receptive field
        self.global_model = False
        if receptive_field == 15:
            strides = [1, 1, 1, 1, 1, 1, 8]
        elif receptive_field == 29:
            strides = [1, 1, 1, 2, 2, 1, 2]
        elif receptive_field == 37:
            strides = [1, 1, 1, 2, 2, 2, 1]
        elif receptive_field == 51:
            strides = [1, 1, 2, 2, 2, 1, 1]
        elif receptive_field == 65:
            strides = [1, 2, 2, 2, 1, 1, 1]
        else:
            if receptive_field != 0:
                print('Warning: Unknown receptive field, using 0 (global).')

            receptive_field = 2 * image_size # set global receptive field
            strides = [1, 2, 2, 2, 2, 2, 2]
            if not direct: output_dim = global_output_grid * global_output_grid * 2
            self.global_model = True

        # build network
        self.conv1 = nn.Conv2d(3, 4*c, 3, strides[0], 1)
        self.bn1 = nn.BatchNorm2d(4*c)
        self.conv2 = nn.Conv2d(4*c, 8*c, 3, strides[1], 1)
        self.bn2 = nn.BatchNorm2d(8*c)
        self.conv3 = nn.Conv2d(8*c, 16*c, 3, strides[2], 1)
        self.bn3 = nn.BatchNorm2d(16*c)
        self.conv4 = nn.Conv2d(16*c, 32*c, 3, strides[3], 1)
        self.bn4 = nn.BatchNorm2d(32*c)
        self.conv5 = nn.Conv2d(32*c, 64*c, 3, strides[4], 1)
        self.bn5 = nn.BatchNorm2d(64*c)
        self.conv6 = nn.Conv2d(64*c, 64*c, 3, strides[5], 1)
        self.bn6 = nn.BatchNorm2d(64*c)
        self.conv7 = nn.Conv2d(64*c, 64*c, 3, strides[6], 1)
        self.bn7 = nn.BatchNorm2d(64*c)

        self.pool = nn.AdaptiveMaxPool2d(1) # used only for global models, to support arbitrary image sizes

        self.fc1 = nn.Conv2d(64*c, 64*c, 1, 1, 0)
        self.bn8 = nn.BatchNorm2d(64*c)
        self.fc2 = nn.Conv2d(64*c, 64*c, 1, 1, 0)
        self.bn9 = nn.BatchNorm2d(64*c)
        self.fc3 = nn.Conv2d(64*c, output_dim, 1, 1, 0)

        self.patch_size = receptive_field / image_size
        self.global_output_grid = global_output_grid
        self.direct_model = direct
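
    # Example: receptive_field=65 selects strides [1, 2, 2, 2, 1, 1, 1]; their
    # product is 8, so a 64x64 input yields an 8x8 grid of output neurons,
    # i.e. 64 predicted points (one per 8-pixel-strided patch of the input).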

    def forward(self, input):
        '''
        Forward pass.

        input -- 4D data tensor (BxCxHxW)
        '''

        batch_size = input.size(0)

        x = F.relu(self.bn1(self.conv1(input)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        x = F.relu(self.bn7(self.conv7(x)))

        if self.global_model: x = self.pool(x)

        x = F.relu(self.bn8(self.fc1(x)))
        x = F.relu(self.bn9(self.fc2(x)))
        x = self.fc3(x)

        # the direct model predicts line parameters directly
        if self.direct_model:
            return x.squeeze()

        # otherwise, points are predicted
        x = torch.sigmoid(x) # normalize to 0,1

        if self.global_model:
            x = x.view(-1, 2, self.global_output_grid, self.global_output_grid)

        # map local (patch-centric) point predictions to global image coordinates,
        # i.e. distribute the points over the image
        patch_offset = 1 / x.size(2)

        x = x * self.patch_size - self.patch_size / 2 + patch_offset / 2

        for col in range(0, x.size(3)):
            x[:,1,:,col] = x[:,1,:,col] + col * patch_offset

        for row in range(0, x.size(2)):
            x[:,0,row,:] = x[:,0,row,:] + row * patch_offset

        return x.view(batch_size, 2, -1)
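
# Note on the coordinate mapping above: each output neuron first predicts a
# point relative to its own patch (a sigmoid output in [0, 1], rescaled to the
# patch extent and centered); adding row/col multiples of patch_offset then
# shifts each prediction into its cell of the output grid, so together the
# points cover the whole image in normalized (y, x) coordinates.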
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import torch.optim as optim
import torchvision.utils as vutils

import os
import time
import warnings
import argparse

from skimage.io import imsave

from line_dataset import LineDataset
from line_nn import LineNN
from line_loss import LineLoss

from dsac import DSAC

parser = argparse.ArgumentParser(description='This script creates a toy problem of fitting line parameters (slope+intercept) to synthetic images showing line segments, noise and distracting circles. Two networks are trained in parallel and compared: DirectNN predicts the line parameters directly (two output neurons). PointNN predicts a number of 2D points to which the line parameters are subsequently fitted using differentiable RANSAC (DSAC). The script will produce a sequence of images that illustrate the training process for both networks.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--receptivefield', '-rf', type=int, default=65, choices=[65, 51, 37, 29, 15, 0],
    help='receptive field size of the PointNN, i.e. one point prediction is made for each image patch of this size, different receptive fields are achieved by different striding strategies, 0 means global, i.e. the full image, the DirectNN will always use 0 (global)')

parser.add_argument('--capacity', '-c', type=int, default=4,
    help='controls the model capacity of both networks (PointNN and DirectNN), it is a multiplicative factor for the number of channels in each network layer')

parser.add_argument('--hypotheses', '-hyps', type=int, default=64,
    help='number of line hypotheses sampled for each image')

parser.add_argument('--inlierthreshold', '-it', type=float, default=0.05,
    help='threshold used in the soft inlier count, measured in relative image size (1 = image width)')

parser.add_argument('--inlieralpha', '-ia', type=float, default=0.5,
    help='scaling factor for the soft inlier scores (controls the peakiness of the hypothesis distribution)')

parser.add_argument('--inlierbeta', '-ib', type=float, default=100.0,
    help='scaling factor within the sigmoid of the soft inlier count')

parser.add_argument('--learningrate', '-lr', type=float, default=0.001,
    help='learning rate')

parser.add_argument('--lrstep', '-lrs', type=int, default=2500,
    help='cut the learning rate in half every x iterations')

parser.add_argument('--lrstepoffset', '-lro', type=int, default=30000,
    help='keep the initial learning rate for at least x iterations')

parser.add_argument('--batchsize', '-bs', type=int, default=32,
    help='training batch size')

parser.add_argument('--trainiterations', '-ti', type=int, default=50000,
    help='number of training iterations (= parameter updates)')

parser.add_argument('--imagesize', '-is', type=int, default=64,
    help='size of the generated input images, images are square')

parser.add_argument('--storeinterval', '-si', type=int, default=1000,
    help='store network weights and a prediction visualization every x training iterations')

parser.add_argument('--valsize', '-vs', type=int, default=9,
    help='number of validation images used to visualize predictions')

parser.add_argument('--valthresh', '-vt', type=float, default=5,
    help='threshold on the line loss for visualizing the correctness of predictions')

parser.add_argument('--cpu', '-cpu', action='store_true',
    help='execute the networks on the CPU. Note that the (RANSAC) line fitting runs on the CPU in any case')

parser.add_argument('--session', '-sid', default='',
    help='custom session name appended to output files, useful to separate different runs of the program')

opt = parser.parse_args()

if len(opt.session) > 0: opt.session = '_' + opt.session
sid = 'rf%d_c%d_h%d_t%.2f%s' % (opt.receptivefield, opt.capacity, opt.hypotheses, opt.inlierthreshold, opt.session)
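
# Example invocation (hypothetical session name): train with a global receptive
# field and a custom tag, writing outputs to images_rf0_c4_h64_t0.05_myrun:
#   python main.py --receptivefield 0 --session myrun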

# set up the training process
dataset = LineDataset(opt.imagesize, opt.imagesize)

loss = LineLoss(opt.imagesize)
dsac = DSAC(opt.hypotheses, opt.inlierthreshold, opt.inlierbeta, opt.inlieralpha, loss)

# we train two CNNs in parallel
# 1) a CNN that predicts points and is trained with DSAC -> PointNN (good idea)
point_nn = LineNN(opt.capacity, opt.receptivefield)
if not opt.cpu: point_nn = point_nn.cuda()
point_nn.train()
opt_point_nn = optim.Adam(point_nn.parameters(), lr=opt.learningrate)
lrs_point_nn = optim.lr_scheduler.StepLR(opt_point_nn, opt.lrstep, gamma=0.5)

# 2) a CNN that predicts the line parameters directly -> DirectNN (bad idea)
direct_nn = LineNN(opt.capacity, 0, True)
if not opt.cpu: direct_nn = direct_nn.cuda()
direct_nn.train()
opt_direct_nn = optim.Adam(direct_nn.parameters(), lr=opt.learningrate)
lrs_direct_nn = optim.lr_scheduler.StepLR(opt_direct_nn, opt.lrstep, gamma=0.5)

# keep track of training progress
train_log = open('log_'+sid+'.txt', 'w', 1)

# some helper functions
def prepare_data(inputs, labels):
    # convert from numpy images to normalized torch arrays

    inputs = torch.from_numpy(inputs)
    labels = torch.from_numpy(labels)

    if not opt.cpu: inputs = inputs.cuda()
    inputs.transpose_(1, 3).transpose_(2, 3)
    inputs = inputs - 0.5 # normalization

    return inputs, labels

def batch_loss(prediction, labels):
    # calculate the loss for each image in the batch

    losses = torch.zeros(labels.size(0))

    for b in range(0, labels.size(0)):
        losses[b] = loss(prediction[b], labels[b])

    return losses

# generate validation data (for consistent visualization only)
val_images, val_labels = dataset.sample_lines(opt.valsize)
val_inputs, val_labels = prepare_data(val_images, val_labels)

# start training
for iteration in range(0, opt.trainiterations+1):

    start_time = time.time()

    # generate training data
    inputs, labels = dataset.sample_lines(opt.batchsize)
    inputs, labels = prepare_data(inputs, labels)

    # point nn forward pass
    point_prediction = point_nn(inputs)

    # robust line fitting with DSAC
    exp_loss, top_loss = dsac(point_prediction, labels)

    exp_loss.backward()        # calculate gradients (pytorch autograd)
    opt_point_nn.step()        # update parameters
    opt_point_nn.zero_grad()   # reset gradient buffer
    if iteration >= opt.lrstepoffset:
        lrs_point_nn.step()    # update learning rate schedule
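
    # Note: the expected loss is differentiable because both the hypothesis
    # losses (via the soft-inlier-weighted refinement) and the selection
    # probabilities (softmax over soft inlier counts) depend smoothly on the
    # predicted points; the random two-point sampling itself carries no gradient.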

    # also train the direct nn
    direct_prediction = direct_nn(inputs)
    direct_loss = batch_loss(direct_prediction, labels).mean()

    direct_loss.backward()      # calculate gradients (pytorch autograd)
    opt_direct_nn.step()        # update parameters
    opt_direct_nn.zero_grad()   # reset gradient buffer
    if iteration >= opt.lrstepoffset:
        lrs_direct_nn.step()    # update learning rate schedule

    # wrap up
    end_time = time.time()-start_time
    print('Iteration: %6d, DSAC Expected Loss: %2.2f, DSAC Top Loss: %2.2f, Direct Loss: %2.2f, Time: %.2fs'
        % (iteration, exp_loss, top_loss, direct_loss, end_time), flush=True)

    train_log.write('%d %f %f %f\n' % (iteration, exp_loss, top_loss, direct_loss))

    del exp_loss, top_loss, direct_loss

    # store prediction visualization and nn weights (every couple of iterations)
    if iteration % opt.storeinterval == 0:

        point_nn.eval()
        direct_nn.eval()

        # DSAC validation prediction
        prediction = point_nn(val_inputs)
        val_exp, val_loss = dsac(prediction, val_labels)
        val_correct = dsac.est_losses < opt.valthresh

        # direct nn validation prediction
        direct_val_est = direct_nn(val_inputs)
        direct_val_loss = batch_loss(direct_val_est, val_labels)
        direct_val_correct = direct_val_loss < opt.valthresh

        direct_val_est = direct_val_est.detach().cpu().numpy()
        dsac_val_est = dsac.est_parameters.detach().cpu().numpy()
        points = prediction.detach().cpu().numpy()

        # draw DSAC estimates
        viz_dsac = dataset.draw_models(val_labels)
        viz_dsac = dataset.draw_points(points, viz_dsac, dsac.batch_inliers)
        viz_dsac = dataset.draw_models(dsac_val_est, viz_dsac, val_correct)

        # draw direct estimates
        viz_direct = dataset.draw_models(val_labels)
        viz_direct = dataset.draw_models(direct_val_est, viz_direct, direct_val_correct)

        def make_grid(batch):
            batch = torch.from_numpy(batch)
            batch.transpose_(1, 3).transpose_(2, 3)
            return vutils.make_grid(batch, nrow=3, normalize=False)

        viz_inputs = make_grid(val_images)
        viz_dsac = make_grid(viz_dsac)
        viz_direct = make_grid(viz_direct)

        viz = torch.cat((viz_inputs, viz_dsac, viz_direct), 2)
        viz.transpose_(0, 1).transpose_(1, 2)
        viz = viz.numpy()

        # store the image (and ignore the warning about loss of precision)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            outfolder = 'images_' + sid
            if not os.path.isdir(outfolder): os.mkdir(outfolder)
            imsave('./%s/prediction_%s_%06d.png' % (outfolder, sid, iteration), viz)

        # store model weights
        torch.save(point_nn.state_dict(), './weights_pointnn_' + sid + '.net')
        torch.save(direct_nn.state_dict(), './weights_directnn_' + sid + '.net')

        print('Storing snapshot. Validation loss: %2.2f' % val_loss, flush=True)

        del val_exp, val_loss, direct_val_loss

        point_nn.train()
        direct_nn.train()

print('Done without errors.')
train_log.close()
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# Differentiable RANSAC: Learning Robust Line Fitting

- [Introduction](#introduction)
- [Running the Code](#running-the-code)
- [How Does It Work?](#how-does-it-work)
- [Code Structure](#code-structure)
- [Publications](#publications)

## Introduction

This code illustrates the principles of differentiable RANSAC (DSAC) on a simple toy problem: fitting lines to noisy, synthetic images.

![Input and desired output.](./images/task.png)

**Left:** Input image. **Right:** Ground truth line.

We solve this task by training a CNN which predicts a set of 2D points within the image.
We fit our desired line to these points using RANSAC.

![DSAC line fitting.](./images/dsac.png)

**Left:** Input image. **Center:** Points predicted by a CNN. **Right:** Line (blue) fitted to the predictions.
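
For orientation, here is a minimal NumPy sketch of the vanilla, non-differentiable RANSAC loop that DSAC relaxes (function and parameter names are illustrative, not part of this repository):

```python
import numpy as np

def ransac_line(x, y, hyps=64, thresh=0.05):
    '''Fit y = slope * x + intercept to 2D points by classic RANSAC.'''
    best_score, best_params = -1, (0.0, 0.0)
    for _ in range(hyps):
        # sample a minimal set: two distinct points define a line
        i, j = np.random.choice(len(x), 2, replace=False)
        if abs(x[i] - x[j]) < 0.01:
            continue  # skip near-vertical hypotheses
        slope = (y[i] - y[j]) / (x[i] - x[j])
        intercept = y[i] - slope * x[i]
        # point-to-line distances, as in dsac.py
        dists = np.abs(slope * x - y + intercept) / np.sqrt(slope**2 + 1)
        score = np.sum(dists < thresh)  # hard inlier count
        if score > best_score:          # hard argmax selection
            best_score, best_params = score, (slope, intercept)
    return best_params
```

The hard inlier count and the argmax selection are the two non-differentiable steps; DSAC replaces them with a soft inlier count and a probabilistic selection, as described below.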

Ideally, the CNN would place all its point predictions on the line segment in the image.
But because RANSAC is robust to outlier points, the CNN may choose to allow some erroneous predictions in favor of overall accuracy.

We train the CNN end-to-end from scratch to minimize the deviation between the (robustly) fitted line and the ground truth line.

## Running the Code

Just execute `python main.py` to start a training run with the standard settings.
Running `python main.py -h` will list all parameter options for playing around.

The code generates training data on the fly, and trains two CNNs in parallel.
The first CNN predicts a set of 2D points to which the output line is fitted using DSAC.
The second CNN is a baseline where the line parameters are predicted directly, i.e. without DSAC.

At a specified interval during training, both CNNs are tested on a fixed validation set, and a visualization of the predictions is stored in an image such as the following:

![Training output.](./images/example_output.png)

**Left:** Validation inputs. **Center:** DSAC estimates, with dots marking the CNN's point predictions, the fitted line in blue and the ground truth line in green. Green borders mark accurate line predictions, red borders mark inaccurate ones (controlled by a threshold parameter). **Right:** Predictions of the baseline CNN (direct prediction of line parameters).

### Dependencies

This code requires the following packages, and was tested with the package versions in brackets.

`pytorch (0.5.0)`, `torchvision (0.2.1)`, `scikit-image (0.14.0)`

### Training Speed

Depending on your system specification, one training iteration (with the standard batch size of 32) can take more than one second.
This might seem excessive for a simple toy problem.
Note that this code is designed for educative clarity rather than speed.
The whole DSAC portion of training runs in native Python on a single CPU core, and backpropagation relies solely on standard PyTorch autograd.
In any production setting, one would write a C++/CUDA extension encapsulating DSAC for a huge runtime boost.
See for example our camera localization pipelines which utilize DSAC [here](https://github.com/cvlab-dresden/DSAC) and [here](https://github.com/vislearn/LessMore).

## How Does It Work?

Vanilla RANSAC works by creating a set of model hypotheses (line hypotheses in our case), scoring them e.g. by inlier counting, and selecting the best one.

DSAC is based on the idea of making hypothesis selection a probabilistic action.
The probability of selecting a hypothesis increases with its score (e.g. its inlier count).
Training the CNN aims at minimizing the expected loss of the selected hypothesis.

![Training output.](./images/dsac_eq.png)

More details and a formal description can be found in the papers referenced at the end of this document.

In a nutshell, the training process works like this:

1. CNN predicts 2D points
2. sample line hypotheses by choosing random pairs of points
3. score hypotheses by soft inlier counting, and calculate selection probabilities
4. refine hypotheses by re-fitting them to their soft inliers
5. calculate the expected loss of the refined hypotheses w.r.t. the selection probabilities
6. backprop, update CNN, repeat

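The core of step 5, as implemented in `dsac.py`, boils down to a few lines; the tensor values below are made up for illustration:

```python
import torch
import torch.nn.functional as F

# hypothetical per-image results of steps 2-4
hyp_losses = torch.tensor([2.0, 8.0, 1.5])    # loss of each refined hypothesis
hyp_scores = torch.tensor([40.0, 5.0, 55.0])  # soft inlier count of each hypothesis
inlier_alpha = 0.5                            # peakiness of the hypothesis distribution

# selection probabilities and expected loss, as in DSAC.__call__
probs = F.softmax(inlier_alpha * hyp_scores, dim=0)
exp_loss = torch.sum(hyp_losses * probs)  # differentiable; minimized by backprop
```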

### Loss Function

For this toy problem, we are interested in visually well-aligned lines rather than in the nominal error of the line parameters.
We thus measure the maximum distance between the predicted line and the ground truth within the image, and aim at minimizing this distance as our loss function.

![Loss function.](./images/loss.png)

Red arrows mark the error between the ground truth line (green) and the estimated line (blue) that we try to minimize.

## Code Structure

`main.py` Main script that handles the training loop.

`dsac.py` Encapsulates robust, differentiable line fitting with DSAC (sampling hypotheses, scoring, refinement, expected loss).

`line_dataset.py` Generates random, noisy input images with associated ground truth parameters. Also includes functions for visualizing predictions.

`line_loss.py` Loss function used to compare predicted and ground truth lines.

`line_nn.py` Definition of the CNN architecture which supports prediction of 2D points or direct regression of line parameters.

## Publications

The following paper introduced DSAC for camera localization ([paper link](https://arxiv.org/abs/1611.05705)).

```
@inproceedings{brachmann2017dsac,
  title={{DSAC}-{Differentiable RANSAC} for Camera Localization},
  author={Brachmann, Eric and Krull, Alexander and Nowozin, Sebastian and Shotton, Jamie and Michel, Frank and Gumhold, Stefan and Rother, Carsten},
  booktitle={CVPR},
  year={2017}
}
```

This code uses a soft inlier count instead of a learned scoring function, as suggested in the following paper ([paper link](https://arxiv.org/abs/1711.10228)).

```
@inproceedings{brachmann2018lessmore,
  title={Learning Less is More-{6D} Camera Localization via {3D} Surface Regression},
  author={Brachmann, Eric and Rother, Carsten},
  booktitle={CVPR},
  year={2018}
}
```

Please cite one of these papers if you use DSAC or parts of this code in your own work.
--------------------------------------------------------------------------------