├── LICENSE ├── README.md ├── Regularization.gif ├── Regularization.mp4 ├── formulae.png └── regularization.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Itay 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Why does L1 regularization induce sparse models? 2 | 3 | Many illustrate this using the least squares problem with a norm constraint, which is an equivalent problem to the regularized least squares problem. 4 | The least squares level sets are drawn next to the different unit "circles". 5 | 6 | I prepared a cool animation which I believe makes it even clearer than static images and helps develop intuition. 
:) 7 | 8 | ## The animation 9 | ![The outcome](https://github.com/ievron/RegularizationAnimation/blob/main/Regularization.gif?raw=true) 10 | 11 | Also available as a video file in the file list above. 12 | 13 | ### The plotted optimization problems 14 | The left side shows the LS problem constrained by the L2 norm, while the right side uses the L1 norm. 15 | That is, the following problems are illustrated: 16 | 17 | ![The LS formulae](https://github.com/ievron/RegularizationAnimation/blob/main/formulae.png?raw=true) 18 | 19 | ## The code 20 | The code can be found in the `regularization.py` script. 21 | 22 | ### Dependencies 23 | - `numpy` 24 | - `matplotlib` 25 | - [`celluloid`](https://pypi.org/project/celluloid/) (only if you want to create animations) 26 | 27 | ## Copyrights 28 | The code and animations are free to use, but please keep the copyright notice on the animations :) 29 | 30 | 31 | ## Other helpful resources 32 | - [YouTube: Sparsity and the L1 norm by Steve Brunton](https://www.youtube.com/watch?v=76B5cMEZA4Y&feature=youtu.be&ab_channel=SteveBrunton) 33 | - [Sam Petulla's interactive demo](https://observablehq.com/@petulla/l1-l2l_1-l_2l1-l2-norm-geometric-interpretation) 34 | - [Mathematical explanations on CrossValidated](https://stats.stackexchange.com/questions/45643/why-l1-norm-for-sparse-models/45644) 35 | -------------------------------------------------------------------------------- /Regularization.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ievron/RegularizationAnimation/a493c15cae2f225ea4117650e3dffcfdde18a054/Regularization.gif -------------------------------------------------------------------------------- /Regularization.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ievron/RegularizationAnimation/a493c15cae2f225ea4117650e3dffcfdde18a054/Regularization.mp4 
# --------------------------------------------------------------------------------
# /formulae.png:
# https://raw.githubusercontent.com/ievron/RegularizationAnimation/a493c15cae2f225ea4117650e3dffcfdde18a054/formulae.png
# --------------------------------------------------------------------------------
# /regularization.py:
# --------------------------------------------------------------------------------
"""Visualize why l1 regularization induces sparse least-squares solutions.

Two side-by-side axes show the level sets of a 2-D least-squares loss next to
the l2 unit ball (left axes) and the l1 unit ball (right axes), together with
the unregularized solution and the solution constrained to each ball.  The
problem is rotated and its true weight vector is shifted; every setting is
drawn as one frame.  When ``SHOULD_ANIMATE`` is true the frames are captured
with ``celluloid`` and saved as ``Regularization.gif`` and
``Regularization.mp4``; otherwise they are shown interactively.
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.collections import PatchCollection
import matplotlib as mpl


# Consts
STD_NOISE = 2*10**-1
DIM = 2
NUM_SAMPLES = 15
evals = [1, 2]  # underlying singular values
SEED = 19890

ylim = [-1.5, 4]
xlim = [-4, 4]
CONTOUR_LINES = 5
LABELS_FONTSIZE = 14
TITLE_FONTSIZE = 18
SUPTITLE_FONTSIZE = 20
FIGSIZE = (12, 6)
CONTOUR_LOSS_THRESHOLD = 20

STEPS = 100
ROTATION_STEPS = 200
w_1_range = [-2.5, 3]
w_t = [max(w_1_range), 1.8]  # true weights; w_t[0] is mutated by the driver

FPS = 30
SHOULD_ANIMATE = True


def SVD(M):
    """Return ``(U, S, V)`` such that ``M == U.dot(S).dot(V.T)``.

    ``S`` has the same shape as ``M`` with the singular values on its main
    diagonal, and the COLUMNS of ``V`` are the right singular vectors
    (``np.linalg.svd`` returns them as rows, hence the transpose).
    """
    U, D, Vh = np.linalg.svd(M)
    V = Vh.T
    S = np.zeros(M.shape, dtype=V.dtype)
    S[:D.shape[0], :D.shape[0]] += np.diag(D)

    return U, S, V


def generateData(NUM_SAMPLES, DIM, STD_NOISE, singularValues):
    """Build a data matrix with prescribed singular values, plus a noise vector.

    A Gaussian ``NUM_SAMPLES x DIM`` matrix is drawn only to obtain singular
    vectors; its singular values are then replaced by ``singularValues``.

    Returns:
        ``(X, noise)`` where ``X`` is the rebuilt matrix and ``noise`` is a
        length-``NUM_SAMPLES`` Gaussian vector scaled by ``STD_NOISE``.
    """
    # Not really random: only the singular vectors of this draw are kept
    X = np.random.randn(NUM_SAMPLES, DIM)
    U, _, V = SVD(X)

    S = np.zeros(X.shape, dtype=V.dtype)
    k = len(singularValues)
    S[:k, :k] += np.diag(singularValues)
    X = U.dot(S.dot(V.T))

    noise = np.random.randn(NUM_SAMPLES).astype(np.float64) * STD_NOISE

    return X, noise


def normToLambda(norm, S, U, y):
    """Find the ridge coefficient whose solution has l2 norm ``norm``.

    For ``X = U S V^T`` the ridge solution has squared norm
    ``sum_i (z_i / (s_i^2 + lambda))^2`` with ``z = S^T U^T y``.  That quantity
    decreases as ``lambda`` grows, so ``lambda`` is located by bisection
    inside ``[1e-5, 1e5]``.

    Args:
        norm: Target l2 norm of the regularized solution.
        S: Rectangular matrix with the singular values on its diagonal.
        U: Left singular vectors.
        y: Target vector.

    Returns:
        The coefficient for which the squared solution norm is within
        ``1e-10`` of ``norm ** 2``.

    Raises:
        RuntimeError: If no such coefficient is found within 10000
            iterations (for example when the target norm is not reachable
            inside the search interval).
    """
    z = np.ravel(S.T.dot(U.T.dot(y)))

    # Squared singular values aligned with z (zero-padded if z is longer)
    singular = np.diag(S)
    squared = np.zeros(z.shape, dtype=float)
    squared[:singular.shape[0]] = singular ** 2

    # The expression below is the SQUARED norm, so compare against norm ** 2
    target = norm ** 2

    minLambda = 10**-5
    maxLambda = 10**5
    currLambda = 10
    eps = 10**-10
    maxIterations = 10000

    # Since the relation is "monotonic", use binary search
    for _ in range(maxIterations):
        leftSide = float(np.sum((z / (squared + currLambda)) ** 2))

        if leftSide >= target + eps:
            # Lambda is too small
            minLambda = currLambda
            currLambda = maxLambda - (maxLambda - currLambda) / 2
        elif leftSide <= target - eps:
            # Lambda is too large
            maxLambda = currLambda
            currLambda = (currLambda - minLambda) / 2 + minLambda
        else:
            return currLambda

    raise RuntimeError(
        "normToLambda did not converge after %d iterations" % maxIterations)


def plotUnitCircle(alpha, fc, lw):
    """Draw the l2 unit ball on the module-level ``axes[0]``.

    Three stacked circles are used: a translucent filled disc, an opaque white
    edge that wipes the fill where it overlaps the outline, and a black edge.
    """
    disc = plt.Circle((0, 0), 1,
                      alpha=alpha, fc=fc, lw=lw, zorder=10)
    # Draw a non-transparent white edge to wipe the facecolor where they overlap
    wipe = plt.Circle((0, 0), 1,
                      alpha=1.0, ec='white', fc='none', lw=lw, zorder=10)
    # Now draw only the edge
    edge = plt.Circle((0, 0), 1,
                      fc='none', ec='k', lw=lw)
    collection = PatchCollection([disc, wipe, edge], match_original=True, zorder=10)
    axes[0].add_artist(collection)


def plotUnitDiamond(alpha, fc, lw):
    """Draw the l1 unit ball on the module-level ``axes[1]``.

    The diamond is an origin-centred square of side ``sqrt(2)`` rotated by
    -45 degrees, built from the same fill / white wipe / black edge stack as
    the circle.
    """
    side = np.sqrt(2)
    corner = (-1/np.sqrt(2), -1/np.sqrt(2))

    fill = patches.Rectangle(corner, side, side,
                             alpha=alpha, fc=fc, zorder=10)
    # Draw a non-transparent white edge to wipe the facecolor where they overlap
    wipe = patches.Rectangle(corner, side, side,
                             alpha=1.0, ec='white', fc='none', lw=lw, zorder=10)
    # Now draw only the edge
    edge = patches.Rectangle(corner, side, side,
                             fc='none', ec='k', lw=lw)
    # Rotate in data coordinates, then map to display coordinates
    rotation = mpl.transforms.Affine2D().rotate_deg(-45) + axes[1].transData
    collection = PatchCollection([fill, wipe, edge], match_original=True, zorder=10)
    collection.set_transform(rotation)

    axes[1].add_artist(collection)


# Create linear range between two points (https://stackoverflow.com/a/46694364/1947677)
def create_ranges_nd(start, stop, N, endpoint=True):
    """Return ``N`` evenly spaced values between the arrays ``start`` and ``stop``.

    The result has one extra trailing axis of length ``N``.  With ``endpoint``
    true the last value equals ``stop``; otherwise ``stop`` is excluded.
    """
    # Clamp so that N == 1 with endpoint yields just `start` instead of dividing by zero
    divisor = max(N - 1, 1) if endpoint else N
    steps = (1.0/divisor) * (stop - start)
    return start[..., None] + steps[..., None]*np.arange(N)


def errorForSolution(X, y, sol):
    """Return the squared l2 residual ``||X sol - y||^2``."""
    return np.linalg.norm(X.dot(sol) - y) ** 2


def findL1Solution(norm, X, y, unregSol):
    """Approximate the least-squares minimizer on the l1 ball of radius ``norm``.

    Only the edge of the diamond lying in the quadrant of the unregularized
    solution ``unregSol`` is searched, by evaluating the loss at 1000 evenly
    spaced candidates along it (more candidates => closer to actual solution).

    Returns:
        The candidate (length-2 array) with the smallest squared residual.
    """
    # A coordinate that is exactly zero is treated as positive
    xSign = 1 if unregSol[0] >= 0 else -1
    ySign = 1 if unregSol[1] >= 0 else -1
    p1 = np.array([0, ySign * norm])
    p2 = np.array([xSign * norm, 0])

    candidates = create_ranges_nd(p1, p2, N=10**3).T
    errors = [errorForSolution(X, y, s) for s in candidates]

    return candidates[np.argmin(errors), :]


# Draw the axes + copyrights + norm circles
def drawStaticShapes(axes):
    """Draw the frame-independent artists: axis arrows, credit text, unit balls.

    Note that the unit balls are drawn by ``plotUnitCircle`` /
    ``plotUnitDiamond``, which use the module-level ``axes`` rather than the
    argument.
    """
    for ax in axes:
        # Draw axes
        ax.arrow(xlim[0] + 0.2, 0, xlim[1] - xlim[0] - 0.5, 0., fc='k', ec='k',
                 lw=1.5, head_width=.2, head_length=.2,
                 length_includes_head=True, clip_on=False, zorder=2)

        ax.arrow(0, ylim[0] + 0.2, 0., ylim[1] - ylim[0] - 0.5, fc='k', ec='k',
                 lw=1.5, head_width=.2, head_length=.2,
                 length_includes_head=True, clip_on=False, zorder=2)

    # Plot copyrights.
    # `fontsize` and `size` are aliases of the same property, so only one may
    # be passed; 14 is the value that came last in the original call.
    axes[-1].annotate('by @itayevron',
                      fontsize=14, c='grey', zorder=20,
                      xy=(1, 0), xytext=(0, 20),
                      xycoords=('axes fraction', 'figure fraction'),
                      textcoords='offset points',
                      ha='right', va='bottom')

    fc = 'c'
    alpha = 0.7
    lw = 2
    plotUnitCircle(alpha, fc, lw)
    plotUnitDiamond(alpha, fc, lw)


def plotSolution(ax, pnt,):
    """Scatter a single solution point on ``ax`` and return the artist."""
    return ax.scatter(pnt[0], pnt[1], c='orange', s=70, edgecolors='k', zorder=11)


def plotContours(X, y):
    """Draw the least-squares level sets on every module-level axes.

    The loss ``||X w - y||^2`` is evaluated on a grid of roughly one million
    points covering ``xlim`` x ``ylim``.

    Returns:
        A list with one contour set per axes.
    """
    # Gather all discrete points
    delta = np.sqrt(((xlim[1] - xlim[0]) * (ylim[1] - ylim[0])) / 10 ** 6)  # Have one million grids
    Xs = np.arange(xlim[0], xlim[1], delta)
    Ys = np.arange(ylim[0], ylim[1], delta)
    Xpt, Ypt = np.meshgrid(Xs, Ys)

    # Compute all weight combinations (one column per grid point)
    weights = np.array([np.ravel(Xpt), np.ravel(Ypt)])
    residuals = X.dot(weights).T - y.T
    losses = np.linalg.norm(residuals, axis=1) ** 2
    Z = losses.reshape(Xpt.shape)

    # Nullify outside a threshold value to create an ellipsis (cleaner view)
    Z[Z > CONTOUR_LOSS_THRESHOLD] = 0

    return [ax.contour(Xpt, Ypt, Z, CONTOUR_LINES, linewidths=1, colors='k')
            for ax in axes]


def _removeContours(contourSet):
    """Remove one ``ax.contour`` result from its axes on old and new Matplotlib."""
    if isinstance(contourSet, mpl.artist.Artist):
        # Newer Matplotlib: the contour set is itself an artist and
        # `.collections` is deprecated/removed.
        contourSet.remove()
    else:
        # Older Matplotlib: one collection per contour level.
        for collection in contourSet.collections:
            collection.remove()


# Plot one "frame"/setting of the system
def plot(X, U, S, V, y):
    """Draw one frame for the problem ``X = U S V^T`` with targets ``y``.

    Draws the loss contours, the l2-constrained solution (left axes), the
    l1-constrained solution (right axes) and the unregularized solution (both
    axes).  When animating, the static shapes are redrawn and the frame is
    captured with the module-level ``camera``; otherwise the figure is
    refreshed and the per-frame artists are removed again.
    """
    plottedPoints = []

    # Find unconstrained (=unregularized solution)
    unregSol = V.dot(np.linalg.pinv(S).dot(U.T.dot(y)))

    # Plot contours
    contours = plotContours(X, y)

    # Plot l2-regularized solution: solve (X^T X + lambda I) w = X^T y
    regCoef = normToLambda(1, S, U, y)
    gram = X.T.dot(X) + regCoef * np.eye(X.shape[1])
    w_opt = np.linalg.solve(gram, X.T.dot(y))
    plottedPoints.append(plotSolution(axes[0], w_opt))

    # Plot l1-regularized solution
    l1_opt = findL1Solution(1, X, y, unregSol)
    plottedPoints.append(plotSolution(axes[1], l1_opt))

    # Plot unregularized solution
    for ax in axes:
        plottedPoints.append(ax.scatter(unregSol[0], unregSol[1], c='grey', s=40, zorder=11))

    if SHOULD_ANIMATE:
        # Animation handling (requires plotting *everything* again)
        drawStaticShapes(axes)
        camera.snap()
        return

    plt.draw()
    plt.show(block=False)

    plt.pause(0.1)

    # When not animating, the fixed shapes are drawn once
    # The rest of the shapes should be cleared at each iteration
    for contourSet in contours:
        _removeContours(contourSet)

    for pnt in plottedPoints:
        pnt.remove()


def rotateProblem(U, S, V, noise, angle):
    """Rotate the right singular vectors by ``angle`` degrees.

    Returns:
        ``(X, y, currV)``: the rotated data matrix, its targets computed from
        the module-level true weights ``w_t`` plus ``noise``, and the rotated
        ``V``.
    """
    theta = angle * np.pi / 180

    # Create rotation matrix and rotate
    rotation = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
    currV = rotation.dot(V)
    X = U.dot(S.dot(currV.T))
    y = X.dot(w_t) + noise

    return X, y, currV


def _playRotation(U, S, V, noise):
    """Plot one full 360-degree rotation of the problem, frame by frame."""
    for angle in np.linspace(0, 360, ROTATION_STEPS):
        X, y, currV = rotateProblem(U, S, V, noise, angle)
        plot(X, U, S, currV, y)


def _playShift(U, S, V, noise, start, stop):
    """Plot the frames while the first true weight ``w_t[0]`` moves start -> stop."""
    X = U.dot(S.dot(V.T))  # the data matrix does not change during a shift
    for w_1 in np.linspace(start, stop, STEPS):
        w_t[0] = w_1
        y = X.dot(w_t) + noise
        plot(X, U, S, V, y)


np.random.seed(SEED)

# Plotting starts
fig, axes = plt.subplots(1, 2, figsize=FIGSIZE)
plt.tight_layout(pad=2)
axes = np.ravel(axes)

# Initialize plots and axes
for ax in axes:
    ax.set_aspect('equal', adjustable='box')
    ax.set_ylim(ylim)
    ax.set_xlim(xlim)
    ax.grid(zorder=0, alpha=0.5)
    ax.set_xlabel(r"$w_1$", fontsize=LABELS_FONTSIZE)
    ax.set_ylabel(r"$w_2$", fontsize=LABELS_FONTSIZE)


# Set titles
plt.suptitle(r"$\ell^1$ induces sparse solutions for least squares", fontsize=SUPTITLE_FONTSIZE)
axes[0].set_title(r"$\ell^2$ regularization", fontsize=TITLE_FONTSIZE)
axes[1].set_title(r"$\ell^1$ regularization", fontsize=TITLE_FONTSIZE)


# Animation handling
if SHOULD_ANIMATE:
    from celluloid import Camera
    camera = Camera(fig)
else:
    drawStaticShapes(axes)


# Generate data
X, noise = generateData(NUM_SAMPLES, DIM, STD_NOISE, evals)
U, S, V = SVD(X)

# Rotate (w_t[0] starts at max(w_1_range))
print("Rotating...")
_playRotation(U, S, V, noise)

# Move left
print("Moving left...")
_playShift(U, S, V, noise, max(w_1_range), min(w_1_range))

# Rotate
print("Rotating...")
w_t[0] = min(w_1_range)
_playRotation(U, S, V, noise)

# Move back right
print("Moving right...")
_playShift(U, S, V, noise, min(w_1_range), max(w_1_range))

# Animation handling
if SHOULD_ANIMATE:
    print("Creating animation...")
    animation = camera.animate(interval=1000 // FPS, repeat=True, blit=True)
    animation.save('Regularization.gif')
    animation.save('Regularization.mp4')
else:
    plt.show()