├── Readme.md └── pytorch_randomized_svd.ipynb /Readme.md: -------------------------------------------------------------------------------- 1 | An attempt to speed up randomized SVD using a GPU, benchmarked against current state-of-the-art implementations. 2 | /nbs -- https://github.com/smortezavi/Randomized_SVD_GPU/blob/master/pytorch_randomized_svd.ipynb 3 | 4 | -------------------------------------------------------------------------------- /pytorch_randomized_svd.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import timeit\n", 13 | "import pandas as pd\n", 14 | "from sklearn import decomposition\n", 15 | "import fbpca\n", 16 | "import torch\n", 17 | "from scipy import linalg" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": { 24 | "collapsed": true 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "def simple_randomized_svd(M, k=10):\n", 29 | " m, n = M.shape\n", 30 | " transpose = False\n", 31 | " if m < n:\n", 32 | " transpose = True\n", 33 | " M = M.T\n", 34 | " rand_matrix = np.random.normal(size=(M.shape[1], k)) # short side by k\n", 35 | " Q, _ = np.linalg.qr(M @ rand_matrix, mode='reduced') # long side by k\n", 36 | " smaller_matrix = Q.T @ M # k by short side\n", 37 | " U_hat, s, V = np.linalg.svd(smaller_matrix, full_matrices=False)\n", 38 | " U = Q @ U_hat\n", 39 | " \n", 40 | " if transpose:\n", 41 | " return V.T, s.T, U.T\n", 42 | " else:\n", 43 | " return U, s, V" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "def simple_randomized_torch_svd(M, k=10):\n", 53 | " B = torch.tensor(M).cuda(0)\n", 54 | " m, n = B.size()\n", 55 | " transpose = False\n", 56 | " if m < n:\n", 57 | " transpose = True\n", 58 | " B = 
B.transpose(0, 1).cuda(0)\n", 59 | " m, n = B.size()\n", 60 | " rand_matrix = torch.rand((n,k), dtype=torch.double).cuda(0) # short side by k\n", 61 | " Q, _ = torch.qr(B @ rand_matrix) # long side by k\n", 62 | " Q.cuda(0)\n", 63 | " smaller_matrix = (Q.transpose(0, 1) @ B).cuda(0) # k by short side\n", 64 | " U_hat, s, V = torch.svd(smaller_matrix,False)\n", 65 | " U_hat.cuda(0)\n", 66 | " U = (Q @ U_hat)\n", 67 | " \n", 68 | " if transpose:\n", 69 | " return V.transpose(0, 1), s, U.transpose(0, 1)\n", 70 | " else:\n", 71 | " return U, s, V" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 4, 77 | "metadata": { 78 | "collapsed": true 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "# computes an orthonormal matrix whose range approximates the range of A\n", 83 | "# power-iteration normalizer options: safe_sparse_dot (fast but unstable), LU (in between), or QR (slow but most accurate); this version always normalizes with LU\n", 84 | "def randomized_range_finder(A, size, n_iter=5):\n", 85 | " Q = np.random.normal(size=(A.shape[1], size))\n", 86 | " \n", 87 | " for i in range(n_iter):\n", 88 | " Q, _ = linalg.lu(A @ Q, permute_l=True)\n", 89 | " Q, _ = linalg.lu(A.T @ Q, permute_l=True)\n", 90 | " \n", 91 | " Q, _ = linalg.qr(A @ Q, mode='economic')\n", 92 | " return Q" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 5, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "def randomized_svd(M, n_components, n_oversamples=10, n_iter=4):\n", 102 | " n_random = n_components + n_oversamples\n", 103 | " \n", 104 | " Q = torch.tensor(randomized_range_finder(M, n_random, n_iter)).cuda(0)\n", 105 | " # project M to the (k + p) dimensional space using the basis vectors\n", 106 | " M = torch.tensor(M).cuda(0)\n", 107 | " B = Q.transpose(0, 1) @ M\n", 108 | " # compute the SVD on the thin matrix: (k + p) wide\n", 109 | " Uhat, s, V = linalg.svd(B, full_matrices=False)\n", 110 | " Uhat = torch.tensor(Uhat).cuda(0)\n", 111 | " del B\n", 112 | " U = Q @ 
Uhat\n", 113 | " \n", 114 | " return U[:, :n_components], s[:n_components], V[:n_components, :]" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 6, 120 | "metadata": { 121 | "collapsed": true 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "def randomized_svd_original(M, n_components, n_oversamples=10, n_iter=4):\n", 126 | " \n", 127 | " n_random = n_components + n_oversamples\n", 128 | " \n", 129 | " Q = randomized_range_finder(M, n_random, n_iter)\n", 130 | " # project M to the (k + p) dimensional space using the basis vectors\n", 131 | " B = Q.T @ M\n", 132 | " # compute the SVD on the thin matrix: (k + p) wide\n", 133 | " Uhat, s, V = linalg.svd(B, full_matrices=False)\n", 134 | " del B\n", 135 | " U = Q @ Uhat\n", 136 | " \n", 137 | " return U[:, :n_components], s[:n_components], V[:n_components, :]" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 51, 143 | "metadata": {}, 144 | "outputs": [ 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "(tensor([[ 0.2665, -0.4784, -0.4189, 0.0461, 0.1967],\n", 149 | " [ 0.0450, 0.2502, -0.4644, 0.6517, -0.3875],\n", 150 | " [-0.0823, 0.5004, 0.1257, -0.2966, -0.1977],\n", 151 | " [ 0.2300, 0.0620, -0.3229, -0.6062, -0.0580],\n", 152 | " [ 0.4515, 0.2407, 0.3616, 0.1039, 0.1663],\n", 153 | " [ 0.5530, 0.1675, -0.0788, -0.0937, -0.4466],\n", 154 | " [ 0.0042, 0.1490, -0.4542, -0.2347, 0.0064],\n", 155 | " [ 0.1880, -0.5289, 0.1129, -0.1100, -0.3019],\n", 156 | " [ 0.1415, 0.2574, -0.2925, 0.0434, 0.6464],\n", 157 | " [-0.5497, -0.0077, -0.2174, -0.1706, -0.1805]], dtype=torch.float64, device='cuda:0'),\n", 158 | " array([ 291.1458, 272.6614, 258.4442, 243.4219, 235.0934]),\n", 159 | " array([[ 0.0518, 0.0131, 0.1641, 0.1506, 0.146 , 0.0222, -0.0469,\n", 160 | " 0.1355, -0.098 , -0.0233, 0.0092, 0.0286, 0.0281, 0.0301,\n", 161 | " 0.0223, 0.0542, 0.0396, 0.0861, 0.1835, 0.17 , -0.104 ,\n", 162 | " -0.0503, -0.1279, 0.2209, -0.115 , -0.0889, -0.1002, 0.119 
,\n", 163 | " 0.097 , -0.0717, 0.0485, -0.0142, -0.0282, 0.158 , 0.0102,\n", 164 | " 0.1407, -0.2371, 0.0467, 0.0846, 0.059 , 0.0277, -0.0588,\n", 165 | " -0.1047, -0.1641, -0.0593, -0.022 , 0.1137, -0.1559, 0.0385,\n", 166 | " 0.0147, 0.1155, -0.0089, 0.008 , -0.1397, 0.1101, -0.0368,\n", 167 | " -0.0584, -0.0091, 0.0055, -0.0985, 0.1003, -0.0447, 0.1091,\n", 168 | " 0.0239, 0.0129, 0.0153, -0.1851, 0.1479, 0.064 , -0.0041,\n", 169 | " -0.0896, 0.08 , 0.0343, -0.038 , 0.0705, -0.0915, 0.1503,\n", 170 | " 0.0479, 0.006 , 0.1062, 0.0698, 0.0503, 0.1777, -0.0383,\n", 171 | " 0.054 , 0.1732, 0.0136, -0.0869, -0.0776, -0.0567, 0.1193,\n", 172 | " 0.0808, 0.2016, 0.1375, -0.0678, 0.1257, 0.1043, 0.0659,\n", 173 | " -0.0394, -0.2398],\n", 174 | " [ 0.1047, -0.0278, 0.0394, -0.0116, -0.0278, 0.1107, 0.0449,\n", 175 | " -0.035 , 0.0247, 0.0578, -0.1732, -0.0237, 0.1903, 0.059 ,\n", 176 | " 0.09 , 0.0089, -0.0202, 0.111 , 0.004 , -0.016 , -0.0801,\n", 177 | " 0.1059, 0.0315, 0.03 , 0.084 , 0.0228, 0.0251, -0.0063,\n", 178 | " 0.0831, -0.0133, 0.1926, -0.2026, 0.1024, -0.0399, -0.1468,\n", 179 | " -0.1746, 0.1429, -0.1128, -0.0426, -0.1053, 0.039 , 0.0626,\n", 180 | " -0.0411, 0.0109, -0.0389, 0.1114, 0.018 , -0.1367, -0.0503,\n", 181 | " -0.1345, 0.0085, -0.0016, 0.1778, -0.069 , -0.1027, -0.205 ,\n", 182 | " -0.0415, 0.0277, -0.1185, 0.0991, -0.013 , 0.0184, -0.1554,\n", 183 | " 0.1983, -0.02 , 0.1438, 0.0074, 0.0213, 0.2181, -0.0763,\n", 184 | " -0.1173, 0.2085, 0.088 , 0.0155, -0.1364, 0.0703, -0.0937,\n", 185 | " -0.0265, -0.0185, -0.0257, 0.1124, 0.0795, 0.0913, 0.001 ,\n", 186 | " 0.2215, 0.1455, -0.0052, 0.0898, -0.0091, -0.2242, -0.1665,\n", 187 | " 0.0639, 0.0636, 0.0718, 0.0732, -0.1127, -0.1063, -0.1051,\n", 188 | " -0.0029, -0.0029],\n", 189 | " [-0.1149, 0.0279, -0.1338, 0.116 , 0.0468, 0.1232, 0.054 ,\n", 190 | " 0.0565, -0.1259, -0.0839, 0.152 , -0.0863, 0.0003, -0.2371,\n", 191 | " 0.0439, 0.051 , 0.109 , 0.055 , 0.131 , 0.0244, -0.0124,\n", 192 | " 0.2318, 
0.1091, -0.0065, -0.0539, 0.0326, 0.1813, 0.024 ,\n", 193 | " 0.0575, 0.1769, 0.0659, -0.0711, -0.0469, -0.0175, -0.0626,\n", 194 | " -0.0345, 0.0641, 0.1338, 0.0266, 0.0062, -0.0386, 0.1407,\n", 195 | " -0.0436, 0.0055, 0.0431, 0.1255, -0.041 , -0.0348, -0.1199,\n", 196 | " -0.0158, 0.1152, -0.2198, 0.1552, 0.0876, -0.1477, -0.1201,\n", 197 | " 0.068 , -0.0179, 0.0766, 0.0899, 0.0199, -0.1206, 0.1611,\n", 198 | " -0.1097, -0.0424, -0.0406, 0.1487, 0.1289, -0.0652, -0.1107,\n", 199 | " 0.0225, -0.1439, -0.1709, -0.1325, 0.026 , -0.1506, 0.0753,\n", 200 | " 0.1793, 0.0631, 0.1439, -0.0564, 0.1311, 0.09 , 0.0136,\n", 201 | " 0.1341, -0.0405, 0.0953, 0.0239, -0.0048, 0.0429, 0.0573,\n", 202 | " 0.0523, -0.0403, 0.0003, -0.0562, 0.0885, -0.1245, -0.0766,\n", 203 | " 0.1717, 0.0722],\n", 204 | " [-0.2255, -0.0402, -0.1114, 0.1937, -0.0568, 0.0984, -0.0648,\n", 205 | " -0.1507, 0.1158, 0.1139, -0.0229, 0.0422, 0.197 , 0.0318,\n", 206 | " 0.0226, -0.1479, 0.0431, -0.1541, -0.0681, -0.0846, 0.1133,\n", 207 | " 0.0326, 0.0527, -0.1195, -0.1615, 0.1094, 0.0358, 0.2004,\n", 208 | " -0.0675, 0.1282, 0.1066, 0.0048, -0.0814, 0.1301, 0.0967,\n", 209 | " 0.075 , 0.1008, 0.0662, -0.1028, -0.0577, 0.0941, 0.0233,\n", 210 | " -0.1145, 0.0581, -0.1916, -0.0421, 0.1936, 0.0785, -0.0758,\n", 211 | " -0.056 , 0.0159, 0.0372, -0.0504, -0.0189, -0.0927, -0.0653,\n", 212 | " -0.0319, 0.0388, 0.0135, -0.0001, -0.2637, 0.1775, 0.0873,\n", 213 | " 0.0548, 0.0287, -0.0601, -0.1801, 0.0968, -0.0181, -0.0947,\n", 214 | " 0.0702, 0.0035, 0.0148, 0.0786, 0.046 , -0.0466, -0.0671,\n", 215 | " -0.0059, -0.0293, 0.133 , 0.0439, -0.0543, 0.101 , -0.1619,\n", 216 | " 0.0623, -0.1021, -0.1044, -0.1816, -0.1118, 0.0177, -0.0705,\n", 217 | " -0.0982, -0.0008, 0.084 , -0.0429, -0.0115, 0.1613, -0.0457,\n", 218 | " -0.1003, -0.052 ],\n", 219 | " [ 0.1515, 0.0668, 0.1857, -0.1832, 0.1376, 0.0806, 0.02 ,\n", 220 | " -0.0895, 0.1381, 0.1017, 0.1031, -0.0905, 0.0288, 0.0122,\n", 221 | " 0.285 , 0.214 , 0.175 , 
0.0477, -0.1441, 0.1352, -0.0321,\n", 222 | " 0.013 , 0.0007, 0.0931, -0.1038, -0.1026, -0.0756, -0.0975,\n", 223 | " 0.1073, 0.0184, -0.1268, 0.0242, -0.0954, 0.0914, -0.0318,\n", 224 | " -0.05 , 0.0268, 0.0005, -0.1374, -0.0452, -0.1993, 0.1278,\n", 225 | " -0.039 , -0.0092, -0.1076, 0.0096, 0.1636, -0.0246, -0.1138,\n", 226 | " 0.1856, -0.0244, -0.1302, -0.0931, -0.007 , 0.1454, -0.1308,\n", 227 | " 0.1207, 0.0613, -0.1073, 0.0442, -0.0394, 0.1058, -0.0777,\n", 228 | " 0.0225, -0.0158, -0.1921, -0.0933, -0.0287, -0.0471, -0.1154,\n", 229 | " -0.0455, -0.0632, -0.0143, 0.1797, 0.0116, -0.1832, 0.0383,\n", 230 | " -0.0319, 0.0525, -0.1981, -0.0385, 0.0208, 0.0057, -0.1077,\n", 231 | " 0.0467, -0.0724, 0.0701, 0.0228, 0.034 , 0.0531, -0.0178,\n", 232 | " -0.0728, -0.1358, -0.0163, -0.0502, -0.0159, -0.0144, -0.0449,\n", 233 | " -0.0267, 0.0774]]))" 234 | ] 235 | }, 236 | "execution_count": 51, 237 | "metadata": {}, 238 | "output_type": "execute_result" 239 | } 240 | ], 241 | "source": [ 242 | "A = np.random.uniform(-40,40,[10,100]) \n", 243 | "randomized_svd(A, 5)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 28, 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "data": { 253 | "text/plain": [ 254 | "(array([[-0.0927, -0.3094, 0.0138, -0.498 , 0.1362, -0.6804, 0.3705,\n", 255 | " 0.0016, -0.1609, -0.0533],\n", 256 | " [-0.3696, -0.5729, 0.1055, 0.0333, -0.1139, 0.0144, -0.315 ,\n", 257 | " 0.4539, 0.4073, -0.1962],\n", 258 | " [ 0.4726, -0.0089, 0.1479, 0.0591, -0.3718, -0.1061, -0.0042,\n", 259 | " -0.0511, -0.0456, -0.7727],\n", 260 | " [ 0.2914, -0.5239, -0.3223, -0.251 , -0.0146, 0.3509, -0.2546,\n", 261 | " -0.0093, -0.5257, 0.0952],\n", 262 | " [-0.1348, 0.3822, 0.1706, -0.4915, 0.2152, 0.3631, 0.0837,\n", 263 | " 0.5217, -0.1876, -0.2691],\n", 264 | " [ 0.4823, -0.0967, 0.0848, 0.4299, 0.5133, -0.1813, 0.0785,\n", 265 | " 0.5014, -0.0588, 0.093 ],\n", 266 | " [ 0.093 , -0.0529, -0.6686, -0.0841, 0.2823, 0.2049, 0.332 
,\n", 267 | " -0.0763, 0.4755, -0.2658],\n", 268 | " [-0.2028, 0.0905, -0.422 , 0.2922, -0.5218, -0.0799, 0.3588,\n", 269 | " 0.4363, -0.292 , 0.0649],\n", 270 | " [ 0.496 , 0.079 , 0.0429, -0.3942, -0.4031, -0.0208, 0.0142,\n", 271 | " 0.2405, 0.4233, 0.4363],\n", 272 | " [ 0.0036, -0.3589, 0.4469, 0.101 , -0.076 , 0.4346, 0.668 ,\n", 273 | " -0.1165, 0.0226, 0.0793]]),\n", 274 | " array([ 277.6828, 266.5159, 255.7381, 251.1185, 231.8737, 227.6304,\n", 275 | " 208.1148, 193.3796, 180.6411, 165.6627]),\n", 276 | " array([[-0.1194, -0.1059, 0.0448, 0.0323, -0.0184, 0.0164, 0.1107,\n", 277 | " 0.1889, 0.0443, 0.152 , -0.1577, -0.1047, -0.0287, 0.0447,\n", 278 | " -0.0222, 0.1027, 0.0896, 0.0256, -0.0631, 0.2503, -0.0941,\n", 279 | " -0.2155, -0.1066, 0.0411, -0.0044, -0.1438, 0.0143, -0.0057,\n", 280 | " -0.0201, 0.094 , 0.0311, 0.1404, -0.1361, 0.0639, 0.038 ,\n", 281 | " 0.1444, 0.1051, -0.0501, -0.0036, 0.1643, 0.0518, 0.0598,\n", 282 | " -0.0852, -0.1342, 0.0366, 0.1767, -0.099 , -0.0103, 0.0108,\n", 283 | " 0.1778, -0.159 , -0.0396, 0.0145, -0.1709, 0.0699, -0.165 ,\n", 284 | " 0.0288, 0.0027, 0.0259, -0.0587, -0.0326, 0.0643, -0.0139,\n", 285 | " -0.227 , -0.0479, 0.0867, -0.0538, -0.0293, -0.0625, -0.0212,\n", 286 | " 0.1295, -0.0414, -0.0687, -0.1003, 0.1867, 0.0899, 0.0124,\n", 287 | " 0.0393, 0.1353, 0.0867, -0.1051, -0.0636, -0.0234, 0.1038,\n", 288 | " -0.0898, -0.1123, -0.1175, -0.1263, -0.0303, 0.1688, -0.0692,\n", 289 | " -0.0361, 0.0794, 0.0727, -0.0921, -0.0932, 0.0346, -0.0148,\n", 290 | " -0.1141, -0.1871],\n", 291 | " [-0.0684, -0.0398, 0.0206, -0.1001, -0.1431, 0.0881, -0.1984,\n", 292 | " -0.0424, -0.0399, -0.0002, 0.1922, 0.0108, -0.0755, 0.0587,\n", 293 | " -0.0342, 0.1755, 0.0795, 0.0001, -0.0849, 0.0995, 0.121 ,\n", 294 | " -0.167 , -0.1273, 0.0813, 0.0177, -0.0494, -0.1588, -0.0827,\n", 295 | " 0.1408, -0.0441, 0.0248, 0.0259, 0.2409, 0.1924, 0.0976,\n", 296 | " 0.0346, -0.1064, -0.0406, -0.1364, -0.0326, 0.0262, 0.0643,\n", 297 | " 0.0348, 
0.1109, -0.0144, 0.1152, -0.1018, 0.0876, 0.0459,\n", 298 | " 0.008 , 0.059 , 0.1505, -0.0883, 0.0658, -0.0711, 0.0704,\n", 299 | " -0.0777, -0.2006, -0.0437, 0.0167, -0.0064, 0.1083, 0.0233,\n", 300 | " -0.0497, -0.0578, -0.0621, 0.0751, -0.1627, -0.0343, -0.0834,\n", 301 | " -0.1485, -0.0777, -0.0034, -0.1765, 0.0568, 0.0163, 0.1305,\n", 302 | " 0.1266, -0.1724, -0.0842, 0.1017, 0.0856, 0.1258, 0.1359,\n", 303 | " -0.0683, -0.1306, 0.045 , 0.0655, 0.0264, -0.1812, -0.1439,\n", 304 | " -0.0052, -0.1109, 0.0527, -0.1811, -0.0578, -0.0426, -0.0676,\n", 305 | " 0.1123, -0.0313],\n", 306 | " [ 0.0098, 0.0089, 0.0564, -0.088 , 0.0003, -0.0827, 0.1666,\n", 307 | " -0.062 , 0.1427, -0.0718, 0.0482, 0.1847, -0.1329, -0.0236,\n", 308 | " 0.086 , -0.1303, -0.0225, -0.1029, 0.0635, 0.05 , -0.0386,\n", 309 | " 0.0086, -0.1426, 0.1299, -0.0772, -0.0782, -0.0802, 0.1941,\n", 310 | " 0.1048, -0.0166, 0.0384, 0.1791, -0.0119, -0.0255, -0.064 ,\n", 311 | " -0.1555, 0.0087, 0.132 , 0.1469, -0.0894, 0.0523, 0.0109,\n", 312 | " 0.0867, 0.0801, -0.0139, 0.1224, -0.1428, 0.0412, 0.2326,\n", 313 | " 0.0224, -0.0209, -0.085 , -0.0408, 0.0454, 0.0968, 0.0217,\n", 314 | " -0.2144, 0.1343, -0.0781, 0.0255, -0.0642, -0.0879, 0.0371,\n", 315 | " 0.0825, 0.0888, -0.1127, 0.028 , -0.0634, 0.139 , -0.1085,\n", 316 | " 0.1306, -0.0432, -0.1005, -0.0435, -0.0653, 0.0624, -0.0193,\n", 317 | " 0.0567, 0.1523, 0.031 , 0.0664, 0.0582, 0.0684, 0.1933,\n", 318 | " -0.0126, 0.0368, 0.0381, -0.0924, -0.1445, -0.0333, -0.1543,\n", 319 | " -0.1868, 0.0154, -0.2105, 0.1586, -0.1404, 0.1674, -0.0023,\n", 320 | " 0.0928, 0.0193],\n", 321 | " [-0.0312, 0.0933, -0.0229, 0.0331, 0.1376, 0.1745, -0.0012,\n", 322 | " 0.1118, 0.1125, 0.0059, -0.0473, -0.1874, 0.0003, -0.1555,\n", 323 | " 0.0708, 0.0351, -0.0353, -0.069 , 0.0834, 0.1013, 0.0272,\n", 324 | " 0.0905, -0.0961, -0.013 , -0.2107, 0.0895, 0.0946, 0.0179,\n", 325 | " 0.0648, -0.0176, -0.0776, -0.1891, 0.0275, 0.2144, 0.1968,\n", 326 | " -0.0629, -0.0736, 
0.0178, 0.1535, -0.1932, -0.1041, 0.0819,\n", 327 | " 0.0172, -0.0118, -0.208 , 0.1033, -0.0175, 0.139 , 0.1149,\n", 328 | " 0.0169, -0.07 , -0.0948, 0.0045, 0.037 , 0.0593, 0.0397,\n", 329 | " 0.1287, 0.0053, -0.0724, -0.0459, 0.0204, 0.0833, 0.2035,\n", 330 | " -0.0783, 0.0821, -0.0689, 0.0843, 0.1898, 0.1107, 0.0165,\n", 331 | " -0.053 , 0.0846, -0.0065, -0.0038, -0.1598, 0.1618, 0.0794,\n", 332 | " -0.0485, -0.1004, 0.0229, -0.0965, -0.1124, -0.0565, 0.057 ,\n", 333 | " -0.0121, 0.0612, 0.0166, 0.0287, -0.1395, -0.0852, 0.1669,\n", 334 | " 0.129 , 0.1284, -0.0236, -0.1488, -0.1463, -0.0404, -0.0865,\n", 335 | " -0.0594, -0.1029],\n", 336 | " [ 0.0581, -0.0834, -0.1478, 0.0444, -0.0717, 0.0233, 0.016 ,\n", 337 | " -0.0592, -0.0853, 0.0948, 0.0738, -0.1201, -0.2125, 0.2126,\n", 338 | " 0.1718, 0.0415, -0.1715, 0.0089, 0.1158, -0.0047, 0.0028,\n", 339 | " -0.0533, 0.0282, 0.0168, 0.0254, -0.0085, -0.1421, -0.0544,\n", 340 | " 0.1502, -0.0623, 0.0225, -0.1236, -0.0113, 0.0442, 0.0045,\n", 341 | " 0.0559, -0.0042, -0.243 , -0.0074, 0.0362, 0.0433, -0.0972,\n", 342 | " 0.1463, -0.0052, -0.0015, 0.0838, 0.0502, -0.06 , 0.1244,\n", 343 | " 0.0457, -0.0553, -0.042 , -0.0861, 0.0734, -0.1825, 0.0787,\n", 344 | " -0.0062, 0.1039, 0.1444, 0.1039, -0.0022, -0.2063, 0.0464,\n", 345 | " 0.0867, -0.1355, 0.0045, -0.091 , 0.0281, -0.2494, 0.1741,\n", 346 | " -0.0993, 0.0978, 0.1304, 0.1351, 0.032 , -0.0123, -0.0124,\n", 347 | " -0.1266, 0.1234, 0.0331, -0.1054, -0.192 , 0.0886, 0.0018,\n", 348 | " -0.1345, 0.039 , -0.0252, 0.0639, -0.1552, 0.065 , -0.0225,\n", 349 | " -0.0379, 0.0223, -0.1473, -0.0158, -0.0979, 0.1099, -0.1995,\n", 350 | " 0.0381, -0.0259],\n", 351 | " [-0.1494, 0.1693, 0.1817, -0.0668, 0.1318, -0.0355, 0.1074,\n", 352 | " 0.0049, -0.0429, -0.2172, 0.1686, -0.0052, 0.0159, -0.0334,\n", 353 | " 0.157 , 0.0632, -0.1313, 0.1278, -0.1335, -0.0233, 0.2449,\n", 354 | " -0.1248, 0.0389, -0.1911, 0.1007, -0.1928, 0.1366, -0.213 ,\n", 355 | " -0.0734, -0.0542, 0.0653, 
0.1547, 0.0024, -0.081 , -0.1056,\n", 356 | " 0.0168, -0.1497, -0.0766, 0.0437, 0.0195, 0.0475, -0.147 ,\n", 357 | " 0.1401, 0.0972, -0.1178, 0.0719, 0.0297, -0.0247, 0.1005,\n", 358 | " 0.0221, -0.1239, -0.0429, -0.0457, -0.072 , 0.0078, -0.0448,\n", 359 | " 0.0934, -0.1073, 0.0093, 0.0061, 0.0357, -0.0423, 0.0311,\n", 360 | " 0.015 , 0.1778, -0.0428, -0.0722, 0.0227, 0.0522, -0.0692,\n", 361 | " -0.1277, -0.0107, -0.0187, 0.022 , 0.1519, 0.0605, -0.1817,\n", 362 | " 0.0164, 0.0407, -0.1203, 0.0541, -0.0445, -0.0469, 0.0478,\n", 363 | " 0.0052, 0.0869, -0.1958, -0.0349, 0.0165, -0.0536, 0.1215,\n", 364 | " 0.0844, 0.0429, 0.0508, 0.1345, -0.0663, -0.0689, -0.0959,\n", 365 | " 0.0705, -0.096 ],\n", 366 | " [ 0.042 , 0.0353, -0.0529, 0.0904, -0.1003, 0.0023, -0.1248,\n", 367 | " 0.1678, 0.1003, 0.1283, 0.0016, 0.0692, 0.0479, 0.0497,\n", 368 | " 0.006 , 0.08 , -0.1133, 0.0636, 0.0552, -0.1907, 0.1857,\n", 369 | " 0.1229, 0.075 , 0.0927, -0.2153, -0.1263, 0.0556, 0.0235,\n", 370 | " -0.1678, 0.2431, 0.0308, -0.0247, -0.0646, -0.0142, -0.061 ,\n", 371 | " -0.1598, -0.1329, -0.0674, -0.0833, 0.1434, -0.1402, 0.1615,\n", 372 | " 0.1094, 0.013 , -0.0477, 0.1196, 0.1645, 0.1372, -0.0568,\n", 373 | " 0.0957, -0.0004, 0.1449, -0.016 , 0.1166, 0.0124, -0.0402,\n", 374 | " -0.0253, 0.0418, -0.1052, 0.087 , 0.1117, 0.0161, -0.0925,\n", 375 | " 0.0134, -0.0268, -0.1936, -0.1072, 0.1193, 0.0576, -0.0095,\n", 376 | " 0.0531, -0.1384, -0.0812, 0.1429, 0.1244, 0.0282, -0.0282,\n", 377 | " -0.0062, 0.0311, -0.0499, -0.1484, 0.0524, 0.0696, 0.1764,\n", 378 | " 0.1374, 0.0053, -0.1385, 0.0256, -0.0436, 0.0372, -0.1558,\n", 379 | " 0.06 , -0.1722, -0.0392, -0.0916, 0.0571, 0.0847, -0.0076,\n", 380 | " -0.045 , -0.0058],\n", 381 | " [-0.0644, 0.074 , 0.0298, 0.1139, 0.029 , 0.0728, 0.0545,\n", 382 | " 0.0066, 0.0438, -0.1246, -0.0512, 0.0649, -0.0015, -0.0626,\n", 383 | " 0.0177, 0.153 , -0.1139, -0.0785, -0.0914, -0.0269, -0.0965,\n", 384 | " -0.0256, 0.0953, -0.0307, 0.1179, -0.0643, 
0.0556, -0.0318,\n", 385 | " 0.0098, -0.2266, -0.0066, -0.1737, -0.012 , 0.0155, 0.0827,\n", 386 | " -0.1823, 0.1417, -0.0809, -0.0141, 0.1102, 0.0483, 0.1434,\n", 387 | " -0.0643, -0.0592, -0.1776, -0.0353, -0.1399, -0.0311, 0.0601,\n", 388 | " -0.0556, 0.2011, 0.026 , 0.0247, 0.1667, 0.0975, 0.1257,\n", 389 | " 0.0554, 0.0787, 0.267 , -0.0717, -0.1579, -0.0433, -0.0424,\n", 390 | " -0.0702, -0.1176, -0.1289, -0.182 , 0.0326, 0.0815, -0.0226,\n", 391 | " -0.0667, 0.2647, -0.0972, -0.1035, 0.0266, 0.0695, 0.0393,\n", 392 | " -0.0123, 0.0945, -0.1469, 0.0791, 0.1074, 0.0504, 0.0933,\n", 393 | " 0.1163, -0.0456, -0.1153, -0.0878, 0.1167, 0.1422, -0.0062,\n", 394 | " -0.1823, -0.1169, -0.0334, -0.1232, 0.0955, -0.0013, -0.0222,\n", 395 | " -0.0516, -0.0122],\n", 396 | " [ 0.1755, 0.0701, -0.0588, -0.0834, 0.0671, 0.0504, 0.0299,\n", 397 | " -0.1256, 0.0072, 0.1121, 0.1107, 0.036 , 0.0904, 0.0448,\n", 398 | " 0.1639, -0.0031, -0.0833, -0.096 , 0.0674, 0.0226, 0.0837,\n", 399 | " -0.029 , -0.125 , -0.0895, 0.0796, 0.0353, -0.0196, 0.1451,\n", 400 | " 0.0373, -0.0487, 0.0204, -0.1386, -0.1177, -0.0687, 0.1374,\n", 401 | " 0.0844, -0.1328, 0.1388, -0.091 , -0.0843, 0.1455, 0.0904,\n", 402 | " 0.1358, 0.0097, 0.0926, 0.1205, -0.109 , -0.1065, 0.0966,\n", 403 | " 0.0455, 0.0929, 0.0133, -0.0075, -0.0903, 0.0095, -0.1677,\n", 404 | " 0.0606, 0.038 , 0.0353, 0.1976, 0.0134, 0.1137, -0.2059,\n", 405 | " -0.0271, 0.0912, 0.1127, -0.0967, 0.0737, 0.1882, 0.0675,\n", 406 | " -0.1005, -0.1373, 0.2494, -0.0256, 0.01 , 0.1084, 0.0665,\n", 407 | " -0.0967, -0.1128, 0.1204, -0.0376, 0.1391, -0.1169, -0.0605,\n", 408 | " 0.1143, 0.0154, -0.0114, -0.1473, 0.1942, 0.1005, -0.0274,\n", 409 | " -0.045 , -0.0829, 0.0436, -0.0515, 0.0796, 0.1322, -0.1056,\n", 410 | " 0.0053, -0.1475],\n", 411 | " [ 0.032 , -0.0456, 0.1179, -0.0105, 0.0812, -0.1035, 0.0459,\n", 412 | " 0.1157, -0.1335, 0.0648, 0.0538, 0.0327, -0.0078, -0.0107,\n", 413 | " -0.0912, -0.0911, -0.0455, -0.2188, -0.0527, 0.0974, 
-0.1089,\n", 414 | " 0.1092, 0.0744, 0.0784, 0.0168, -0.0292, 0.177 , -0.0883,\n", 415 | " 0.0922, -0.0812, 0.2664, -0.0381, -0.0137, -0.0896, -0.0391,\n", 416 | " 0.0098, 0.031 , 0.0609, -0.2073, -0.0634, -0.0691, 0.1889,\n", 417 | " -0.1392, -0.2065, 0.0611, 0.0558, 0.0481, 0.0993, 0.2513,\n", 418 | " 0.1045, -0.0177, -0.0063, 0.0053, 0.0057, -0.0123, 0.0603,\n", 419 | " 0.1883, -0.0631, -0.0205, 0.062 , -0.0031, -0.0486, 0.15 ,\n", 420 | " -0.004 , -0.1527, 0.0807, 0.01 , -0.1239, 0.0519, 0.0103,\n", 421 | " -0.0206, -0.127 , 0.1075, 0.0564, -0.0445, -0.1144, 0.1374,\n", 422 | " -0.1299, -0.0625, -0.0414, -0.0567, 0.0076, 0.0664, 0.069 ,\n", 423 | " 0.1062, 0.0418, -0.2425, 0.0768, -0.0972, -0.1723, -0.0346,\n", 424 | " -0.0602, -0.0282, 0.1123, 0.2317, 0.0189, -0.0737, -0.0803,\n", 425 | " 0.0466, 0.0221]]))" 426 | ] 427 | }, 428 | "execution_count": 28, 429 | "metadata": {}, 430 | "output_type": "execute_result" 431 | } 432 | ], 433 | "source": [ 434 | "A = np.random.uniform(-40,40,[10,100]) \n", 435 | "simple_randomized_svd(A)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 29, 441 | "metadata": {}, 442 | "outputs": [ 443 | { 444 | "data": { 445 | "text/plain": [ 446 | "(tensor([[ 0.0022, -0.6062, -0.0613, -0.0791, 0.4002, 0.0024, 0.4948,\n", 447 | " -0.3578, -0.1446, 0.2620],\n", 448 | " [-0.4845, -0.1573, -0.2625, 0.5393, 0.0407, 0.3863, -0.0498,\n", 449 | " -0.1818, 0.0822, -0.4332],\n", 450 | " [ 0.1150, 0.0404, 0.2540, -0.4545, 0.3348, 0.1813, -0.1469,\n", 451 | " -0.3452, 0.5088, -0.4116],\n", 452 | " [-0.4244, -0.0384, -0.5109, -0.3065, -0.0551, -0.5866, -0.2308,\n", 453 | " -0.2045, 0.1428, 0.0278],\n", 454 | " [ 0.0904, -0.3397, 0.0358, 0.0862, -0.4325, -0.2028, 0.4178,\n", 455 | " 0.2802, 0.5992, -0.1656],\n", 456 | " [ 0.0009, -0.3461, 0.3275, -0.0386, 0.0213, -0.3891, -0.1232,\n", 457 | " 0.1403, -0.4726, -0.6011],\n", 458 | " [-0.5206, 0.2742, 0.4045, 0.1494, 0.4669, -0.2599, 0.2320,\n", 459 | " 0.2895, 0.1736, 
0.1205],\n", 460 | " [-0.1635, -0.5329, 0.2680, 0.0149, -0.0122, 0.1470, -0.6118,\n", 461 | " 0.2005, 0.1841, 0.3836],\n", 462 | " [-0.1608, 0.1006, 0.5005, 0.1803, -0.4526, -0.1572, 0.0226,\n", 463 | " -0.6543, -0.0230, 0.1494],\n", 464 | " [-0.4901, -0.0142, 0.0852, -0.5818, -0.3353, 0.4092, 0.2439,\n", 465 | " 0.1664, -0.2147, -0.0278]], dtype=torch.float64, device='cuda:0'),\n", 466 | " tensor([ 288.0379, 268.4482, 256.8885, 250.8847, 240.9807, 222.8867,\n", 467 | " 211.8652, 196.4339, 188.9815, 181.7164], dtype=torch.float64, device='cuda:0'),\n", 468 | " tensor([[ 0.0402, 0.0179, 0.1010, 0.1340, -0.1198, 0.1070, -0.0303,\n", 469 | " -0.0145, -0.1296, -0.0939, 0.0347, 0.0685, 0.1009, -0.1867,\n", 470 | " 0.0653, 0.1440, -0.0890, -0.1349, 0.0387, -0.0470, 0.1600,\n", 471 | " 0.0895, -0.0106, 0.0787, -0.1653, -0.0410, 0.0178, -0.0412,\n", 472 | " 0.1382, 0.0518, 0.0527, -0.1435, 0.2222, -0.0775, -0.0303,\n", 473 | " 0.1840, -0.1013, 0.2213, -0.0228, -0.0692, 0.1175, -0.0974,\n", 474 | " 0.0645, 0.1434, 0.0482, 0.0197, 0.1182, -0.1008, -0.0731,\n", 475 | " 0.0752, -0.1366, 0.0342, 0.0873, 0.1115, 0.0357, 0.0789,\n", 476 | " 0.0411, 0.0976, -0.1566, -0.0157, 0.0403, 0.1302, 0.0862,\n", 477 | " -0.0429, -0.1030, -0.0418, 0.0139, 0.0478, 0.0930, 0.0372,\n", 478 | " -0.0272, -0.0317, -0.0502, 0.1139, -0.0241, -0.1365, 0.1510,\n", 479 | " -0.0726, 0.1520, -0.0190, -0.0144, -0.1091, -0.1078, -0.1329,\n", 480 | " -0.1456, 0.0768, -0.1030, -0.0654, 0.0006, 0.0919, 0.0320,\n", 481 | " 0.1976, 0.2087, 0.1534, -0.0819, 0.0658, 0.0864, 0.0159,\n", 482 | " 0.0848, -0.0768],\n", 483 | " [ 0.0028, -0.2304, 0.0922, 0.1209, 0.1288, -0.0215, 0.0165,\n", 484 | " -0.1730, -0.1264, -0.0367, 0.1728, 0.0732, 0.0905, 0.0096,\n", 485 | " -0.0204, 0.0053, 0.1326, -0.1568, 0.0401, -0.0747, 0.0065,\n", 486 | " 0.0739, 0.0326, -0.0914, 0.0649, -0.1903, 0.1179, -0.0136,\n", 487 | " -0.2076, -0.0016, -0.0031, 0.0441, -0.0161, -0.0986, 0.0192,\n", 488 | " 0.0174, 0.1647, 0.0056, 0.0768, 0.0866, 
-0.0080, 0.0353,\n", 489 | " -0.1670, -0.0447, -0.0199, 0.0955, -0.0614, 0.1367, 0.0534,\n", 490 | " 0.1255, 0.0044, 0.0299, -0.0637, -0.1124, 0.0150, -0.0660,\n", 491 | " 0.1089, 0.2244, -0.0302, 0.0019, -0.0065, -0.1051, 0.0256,\n", 492 | " 0.1989, 0.0457, -0.0703, 0.0894, 0.1100, -0.0395, -0.1362,\n", 493 | " -0.0774, -0.0100, -0.0332, -0.0585, -0.0319, -0.0984, -0.0465,\n", 494 | " -0.0808, 0.1002, 0.0236, -0.1758, -0.1891, -0.0240, 0.1430,\n", 495 | " -0.0646, 0.2253, 0.0332, 0.0496, 0.1592, -0.0396, 0.0287,\n", 496 | " 0.0894, -0.0473, 0.1375, 0.1620, -0.0834, -0.1199, -0.1188,\n", 497 | " 0.1046, 0.1351],\n", 498 | " [-0.0884, 0.0999, 0.0426, -0.0744, -0.0532, 0.0521, -0.0875,\n", 499 | " 0.0857, 0.0061, 0.0447, 0.1417, 0.0376, -0.1017, 0.1446,\n", 500 | " -0.0487, -0.0861, 0.2344, -0.0448, -0.0421, 0.1018, -0.1235,\n", 501 | " -0.1093, -0.0251, 0.0719, 0.0190, 0.1050, 0.0458, -0.0777,\n", 502 | " 0.1354, 0.1192, -0.1047, -0.0277, -0.0419, 0.0380, 0.0036,\n", 503 | " -0.0073, -0.1259, -0.0438, 0.1315, 0.0686, 0.1742, 0.0673,\n", 504 | " -0.0877, 0.0119, 0.0763, 0.1042, -0.0186, -0.0021, 0.0122,\n", 505 | " 0.1370, -0.0361, 0.0547, -0.0003, -0.0124, 0.0549, 0.1181,\n", 506 | " -0.0179, -0.0130, -0.0742, -0.0555, 0.1165, 0.1674, -0.2366,\n", 507 | " 0.1009, 0.1254, 0.1372, -0.1783, 0.1805, -0.1533, -0.2399,\n", 508 | " -0.2054, 0.1292, -0.1053, 0.0793, 0.0486, -0.0115, 0.0378,\n", 509 | " -0.1045, 0.0460, 0.1473, 0.1193, 0.0784, 0.1839, 0.0148,\n", 510 | " -0.1172, 0.0546, -0.0588, -0.1027, -0.0312, 0.0176, -0.1419,\n", 511 | " -0.0781, 0.1442, 0.1013, -0.0276, 0.0649, -0.0431, 0.0485,\n", 512 | " 0.1122, 0.0601],\n", 513 | " [ 0.0264, 0.0399, -0.1116, -0.1556, 0.1822, 0.0549, -0.1786,\n", 514 | " 0.0368, -0.0447, -0.0710, 0.0534, -0.2455, 0.1285, -0.0170,\n", 515 | " -0.1809, 0.1463, 0.1507, 0.0374, 0.0847, -0.0951, -0.1087,\n", 516 | " 0.0752, -0.1242, 0.0321, -0.0560, 0.1063, -0.0261, -0.0935,\n", 517 | " -0.0339, -0.1423, -0.0467, -0.1533, 0.1397, -0.0679, 
0.0439,\n", 518 | " 0.0878, -0.0127, 0.0274, 0.1109, 0.0243, 0.0185, -0.1542,\n", 519 | " -0.0204, 0.0007, 0.0391, 0.1079, -0.0636, 0.1032, 0.1512,\n", 520 | " 0.0684, -0.0886, -0.0413, -0.1078, 0.0172, -0.0120, -0.1109,\n", 521 | " -0.0179, -0.0591, 0.1489, 0.1321, -0.0895, -0.0721, 0.0246,\n", 522 | " 0.0583, -0.1290, 0.0604, 0.0362, 0.1651, 0.1972, 0.1336,\n", 523 | " 0.0104, 0.0291, 0.0235, 0.2579, 0.0867, -0.0081, -0.1880,\n", 524 | " -0.1350, -0.0562, 0.1306, -0.0211, -0.0152, 0.0277, -0.0011,\n", 525 | " 0.0464, 0.0237, -0.0550, 0.0843, -0.1540, -0.0198, -0.1439,\n", 526 | " 0.0274, -0.1024, 0.0887, 0.0165, -0.0300, 0.0527, 0.1530,\n", 527 | " -0.1372, -0.0620],\n", 528 | " [ 0.1347, 0.0335, -0.0503, 0.0225, 0.0985, -0.1284, -0.1658,\n", 529 | " 0.0607, 0.0117, -0.0873, -0.1379, 0.0673, -0.1342, -0.0387,\n", 530 | " 0.1093, -0.0318, -0.0723, 0.0280, 0.0152, -0.1466, 0.0698,\n", 531 | " -0.0916, -0.1624, 0.0353, 0.1334, -0.1421, -0.0433, 0.1296,\n", 532 | " -0.0337, -0.0706, 0.0881, 0.1332, 0.0894, -0.0915, 0.3250,\n", 533 | " 0.0995, -0.0664, -0.0735, 0.0625, 0.2010, -0.0397, 0.0893,\n", 534 | " -0.0882, -0.0379, 0.1399, 0.0904, -0.1091, 0.0112, 0.0447,\n", 535 | " -0.0305, 0.2226, -0.0058, 0.1066, 0.0163, 0.1464, -0.0573,\n", 536 | " -0.0208, -0.0964, -0.1697, -0.0079, 0.0895, 0.0288, 0.0678,\n", 537 | " -0.0428, 0.0088, 0.0403, -0.0180, -0.0395, -0.0164, 0.0931,\n", 538 | " 0.0240, -0.1108, -0.1498, -0.1309, 0.0183, -0.1166, 0.0619,\n", 539 | " -0.0851, 0.0168, -0.0518, 0.1131, 0.1017, 0.0206, -0.0342,\n", 540 | " 0.0488, 0.1308, -0.2026, -0.0302, -0.0831, -0.1704, -0.0906,\n", 541 | " 0.0435, 0.0368, 0.1438, -0.0062, -0.1276, 0.1259, 0.1237,\n", 542 | " -0.0343, 0.1287],\n", 543 | " [-0.1197, -0.0973, 0.1977, 0.0149, 0.0162, 0.1260, -0.0780,\n", 544 | " 0.1430, 0.1305, -0.1031, -0.0307, 0.0298, -0.0824, -0.0827,\n", 545 | " -0.2335, 0.0854, -0.0503, -0.0273, 0.2594, 0.1589, 0.0108,\n", 546 | " -0.0249, -0.0452, -0.2144, -0.0325, -0.1060, 0.0481, 
0.1667,\n", 547 | " -0.0486, 0.0701, -0.0498, -0.2067, -0.0975, -0.0406, 0.1167,\n", 548 | " -0.1278, 0.1188, 0.0176, 0.1025, -0.0021, -0.1293, -0.0431,\n", 549 | " 0.0168, -0.0639, 0.0910, -0.0693, -0.1139, 0.0606, 0.0382,\n", 550 | " -0.0894, -0.0208, 0.1261, -0.1226, 0.2494, 0.0981, 0.1157,\n", 551 | " 0.1099, 0.0794, -0.0302, 0.0600, 0.1342, -0.0972, 0.0775,\n", 552 | " -0.0795, -0.0038, 0.0509, 0.0711, 0.0791, -0.0390, 0.0265,\n", 553 | " 0.0225, 0.0193, 0.0223, 0.0571, -0.0316, 0.0895, 0.0840,\n", 554 | " -0.0445, -0.1916, 0.0676, 0.1175, 0.1351, 0.0194, -0.1393,\n", 555 | " -0.0515, -0.0252, 0.0052, -0.1010, 0.2183, 0.0888, -0.0397,\n", 556 | " 0.0042, 0.0088, -0.0649, -0.0205, -0.1267, 0.0966, -0.0467,\n", 557 | " 0.0872, -0.0040],\n", 558 | " [ 0.0936, 0.1248, -0.0084, 0.2326, 0.0379, -0.0158, 0.0899,\n", 559 | " -0.0030, 0.1916, 0.0537, 0.0406, 0.0789, 0.0987, -0.1878,\n", 560 | " -0.0692, -0.0683, 0.0770, -0.0242, 0.1264, -0.1108, 0.0093,\n", 561 | " 0.0759, 0.0806, 0.0937, 0.1041, -0.0796, -0.0383, 0.0123,\n", 562 | " 0.0199, -0.0351, 0.0958, 0.0272, 0.1233, 0.2582, 0.0167,\n", 563 | " 0.0982, -0.0259, -0.0768, 0.0960, 0.0638, 0.1092, -0.0650,\n", 564 | " 0.1729, -0.0732, -0.0330, -0.0034, 0.0703, -0.0552, -0.0696,\n", 565 | " -0.0127, 0.0375, -0.1205, -0.1232, -0.0286, -0.2567, -0.0219,\n", 566 | " 0.0577, 0.0558, -0.0252, 0.0113, -0.0771, -0.0744, 0.0281,\n", 567 | " 0.0213, -0.0900, -0.0388, -0.1575, -0.0226, -0.0398, -0.0244,\n", 568 | " -0.2027, -0.0248, 0.0407, -0.0918, 0.0456, 0.1203, -0.0913,\n", 569 | " 0.1207, -0.0250, 0.1320, 0.0025, 0.1495, 0.1492, -0.0344,\n", 570 | " -0.0396, 0.1223, -0.2094, 0.0037, 0.1174, 0.1286, -0.1494,\n", 571 | " -0.0698, 0.0424, -0.0048, 0.1082, -0.1825, 0.1142, -0.1606,\n", 572 | " -0.0968, -0.1519],\n", 573 | " [-0.0016, -0.1800, 0.0698, 0.1645, 0.0473, -0.0481, -0.1001,\n", 574 | " 0.1240, -0.0739, 0.1824, -0.0996, -0.0741, -0.0305, -0.0974,\n", 575 | " -0.0156, -0.1295, -0.0522, 0.0011, 0.1147, -0.0729, 
-0.0135,\n", 576 | " -0.1363, 0.1198, -0.0800, 0.0343, 0.0096, 0.0645, 0.1102,\n", 577 | " 0.0368, 0.0777, 0.0726, -0.0835, 0.0123, 0.0795, -0.0760,\n", 578 | " 0.1144, -0.0869, -0.0926, -0.0165, -0.1925, -0.0161, -0.2408,\n", 579 | " -0.0979, -0.0226, 0.0420, -0.1627, 0.0011, -0.0376, 0.1536,\n", 580 | " 0.1317, -0.0440, 0.0126, 0.1233, -0.0229, -0.0363, 0.1184,\n", 581 | " -0.1023, -0.0566, 0.0192, 0.1838, -0.1512, 0.0650, -0.0489,\n", 582 | " 0.0779, 0.0833, 0.0991, 0.1060, 0.1444, -0.2267, -0.0181,\n", 583 | " 0.1801, -0.2471, -0.0391, 0.0135, 0.0094, -0.0075, -0.0308,\n", 584 | " -0.0231, 0.1661, -0.0127, -0.0828, -0.1450, 0.1563, 0.0933,\n", 585 | " 0.0916, -0.1271, -0.1525, 0.0836, 0.0689, -0.0273, -0.0632,\n", 586 | " -0.0709, 0.0688, -0.1060, 0.0094, 0.0156, 0.0996, 0.0740,\n", 587 | " -0.0237, 0.0647],\n", 588 | " [-0.0452, 0.0037, 0.0861, -0.0780, -0.0127, -0.0346, -0.1745,\n", 589 | " -0.0082, -0.0197, 0.1225, -0.0097, -0.1630, -0.0391, -0.1403,\n", 590 | " 0.0047, -0.0301, 0.0380, 0.0689, 0.0560, -0.1448, 0.0415,\n", 591 | " -0.0765, 0.0268, -0.0093, -0.0013, -0.0536, 0.1745, 0.0851,\n", 592 | " 0.1459, 0.0460, 0.0361, -0.0080, -0.0371, -0.0591, -0.0021,\n", 593 | " -0.0232, -0.1775, -0.0694, 0.2404, 0.2391, 0.1042, 0.0004,\n", 594 | " -0.0638, -0.0455, -0.1881, -0.0685, -0.0020, -0.0074, 0.0631,\n", 595 | " 0.0061, -0.0239, 0.0014, 0.1150, 0.0216, -0.0501, -0.1460,\n", 596 | " 0.0868, 0.0058, 0.0238, -0.0468, 0.1932, -0.1387, 0.1993,\n", 597 | " -0.1703, 0.0708, -0.1123, -0.0992, 0.0359, 0.1450, -0.0621,\n", 598 | " 0.0057, -0.1338, -0.0617, -0.0484, 0.2194, 0.0195, 0.0898,\n", 599 | " 0.1415, -0.0739, 0.1688, -0.0542, -0.0715, -0.0008, 0.2597,\n", 600 | " -0.0670, -0.0884, 0.1334, -0.0423, -0.0628, 0.1305, 0.1561,\n", 601 | " 0.0818, 0.0802, -0.0710, -0.0751, 0.1020, -0.1029, -0.0448,\n", 602 | " -0.0334, -0.0624],\n", 603 | " [ 0.1935, -0.1132, -0.0640, 0.0775, -0.1758, -0.0304, -0.0325,\n", 604 | " -0.0281, 0.0281, 0.0623, 0.0806, -0.0406, -0.0136, 
0.0796,\n", 605 | " -0.0326, -0.0139, 0.0010, -0.0886, 0.0059, 0.0165, -0.0906,\n", 606 | " 0.0594, -0.1253, 0.0824, -0.0844, -0.1503, -0.0779, 0.0404,\n", 607 | " -0.1407, 0.1032, 0.0024, 0.0365, -0.0915, 0.1998, -0.0369,\n", 608 | " 0.0309, 0.0805, 0.1246, 0.0020, -0.0381, 0.1327, 0.1264,\n", 609 | " -0.0554, 0.0573, -0.2027, 0.1270, 0.1991, 0.1515, 0.1560,\n", 610 | " -0.0187, 0.0049, -0.0156, 0.0914, 0.0688, 0.1862, 0.0877,\n", 611 | " -0.0511, -0.0257, 0.1836, -0.0647, 0.1148, -0.0108, 0.0483,\n", 612 | " 0.0404, 0.0392, 0.0261, 0.1588, 0.0488, 0.1369, 0.1264,\n", 613 | " -0.1223, -0.1065, 0.0010, -0.0762, 0.1735, -0.0131, 0.0517,\n", 614 | " -0.1084, -0.2028, -0.1659, -0.0472, 0.0291, 0.1472, 0.0132,\n", 615 | " -0.1019, -0.1006, -0.2296, 0.0467, 0.0651, -0.1048, -0.0653,\n", 616 | " -0.0799, -0.0006, -0.0679, -0.0992, 0.1511, 0.0393, -0.0768,\n", 617 | " 0.0185, -0.1413]], dtype=torch.float64, device='cuda:0'))" 618 | ] 619 | }, 620 | "execution_count": 29, 621 | "metadata": {}, 622 | "output_type": "execute_result" 623 | } 624 | ], 625 | "source": [ 626 | "A = np.random.uniform(-40,40,[10,100])\n", 627 | "simple_randomized_torch_svd(A)" 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "execution_count": 7, 633 | "metadata": {}, 634 | "outputs": [], 635 | "source": [ 636 | "np.set_printoptions(suppress=True, precision=4)" 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "execution_count": 8, 642 | "metadata": {}, 643 | "outputs": [], 644 | "source": [ 645 | "m_array = np.array([100, 1000, 20000])\n", 646 | "n_array = np.array([100, 1000, 20000])" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": 9, 652 | "metadata": {}, 653 | "outputs": [], 654 | "source": [ 655 | "index = pd.MultiIndex.from_product([m_array, n_array], names=['# rows', '# cols'])" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 10, 661 | "metadata": {}, 662 | "outputs": [], 663 | "source": [ 664 | 
"pd.options.display.float_format = '{:,.3f}'.format\n", 665 | "df1 = pd.DataFrame(index=m_array, columns=n_array)\n", 666 | "df2 = pd.DataFrame(index=m_array, columns=n_array)\n", 667 | "df3 = pd.DataFrame(index=m_array, columns=n_array)\n", 668 | "df4 = pd.DataFrame(index=m_array, columns=n_array)\n", 669 | "df5 = pd.DataFrame(index=m_array, columns=n_array)\n", 670 | "df6 = pd.DataFrame(index=m_array, columns=n_array)" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": 11, 676 | "metadata": {}, 677 | "outputs": [ 678 | { 679 | "name": "stderr", 680 | "output_type": "stream", 681 | "text": [ 682 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:11: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n", 683 | " # This is added back by InteractiveShellApp.init_path()\n", 684 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:12: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n", 685 | " if sys.path[0] == '':\n", 686 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:13: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n", 687 | " del sys.path[0]\n", 688 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:14: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n", 689 | " \n", 690 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:15: FutureWarning: set_value is deprecated and will be removed in a future release. 
Please use .at[] or .iat[] accessors instead\n", 691 | " from ipykernel import kernelapp as app\n", 692 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:16: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n", 693 | " app.launch_new_instance()\n" 694 | ] 695 | } 696 | ], 697 | "source": [ 698 | "### %%prun\n", 699 | "for m in m_array:\n", 700 | " for n in n_array: \n", 701 | " A = np.random.uniform(-40,40,[m,n]) \n", 702 | " t1 = timeit.timeit('simple_randomized_svd(A, 10)', number=3, globals=globals())\n", 703 | " t2 = timeit.timeit('decomposition.randomized_svd(A, 10)', number=3, globals=globals())\n", 704 | " t3 = timeit.timeit('fbpca.pca(A, 10)', number=3, globals=globals())\n", 705 | " t4 = timeit.timeit('simple_randomized_torch_svd(A, 10)', number=3, globals=globals())\n", 706 | " t5 = timeit.timeit('randomized_svd(A, 5)', number=3, globals=globals())\n", 707 | " t6 = timeit.timeit('randomized_svd_original(A, 5)', number=3, globals=globals())\n", 708 | " df1.set_value(m, n, t1)\n", 709 | " df2.set_value(m, n, t2)\n", 710 | " df3.set_value(m, n, t3)\n", 711 | " df4.set_value(m, n, t4)\n", 712 | " df5.set_value(m, n, t5)\n", 713 | " df6.set_value(m, n, t6)" 714 | ] 715 | }, 716 | { 717 | "cell_type": "code", 718 | "execution_count": 12, 719 | "metadata": {}, 720 | "outputs": [ 721 | { 722 | "data": { 723 | "text/html": [ 724 | "
\n", 725 | "\n", 738 | "\n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | "
100100020000
1000.0070.0210.018
10000.0250.0180.036
200000.0140.0420.364
\n", 768 | "
" 769 | ], 770 | "text/plain": [ 771 | " 100 1000 20000\n", 772 | "100 0.007 0.021 0.018\n", 773 | "1000 0.025 0.018 0.036\n", 774 | "20000 0.014 0.042 0.364" 775 | ] 776 | }, 777 | "execution_count": 12, 778 | "metadata": {}, 779 | "output_type": "execute_result" 780 | } 781 | ], 782 | "source": [ 783 | "df1/3" 784 | ] 785 | }, 786 | { 787 | "cell_type": "code", 788 | "execution_count": 13, 789 | "metadata": {}, 790 | "outputs": [ 791 | { 792 | "data": { 793 | "text/html": [ 794 | "
\n", 795 | "\n", 808 | "\n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | "
100100020000
1000.0590.0920.094
10000.1010.1040.511
200000.1240.4834.159
\n", 838 | "
" 839 | ], 840 | "text/plain": [ 841 | " 100 1000 20000\n", 842 | "100 0.059 0.092 0.094\n", 843 | "1000 0.101 0.104 0.511\n", 844 | "20000 0.124 0.483 4.159" 845 | ] 846 | }, 847 | "execution_count": 13, 848 | "metadata": {}, 849 | "output_type": "execute_result" 850 | } 851 | ], 852 | "source": [ 853 | "df2/3" 854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": 14, 859 | "metadata": {}, 860 | "outputs": [ 861 | { 862 | "data": { 863 | "text/html": [ 864 | "
\n", 865 | "\n", 878 | "\n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | "
100100020000
1000.0240.0770.084
10000.0560.0830.214
200000.0900.2251.424
\n", 908 | "
" 909 | ], 910 | "text/plain": [ 911 | " 100 1000 20000\n", 912 | "100 0.024 0.077 0.084\n", 913 | "1000 0.056 0.083 0.214\n", 914 | "20000 0.090 0.225 1.424" 915 | ] 916 | }, 917 | "execution_count": 14, 918 | "metadata": {}, 919 | "output_type": "execute_result" 920 | } 921 | ], 922 | "source": [ 923 | "df3/3" 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "execution_count": 15, 929 | "metadata": {}, 930 | "outputs": [ 931 | { 932 | "data": { 933 | "text/html": [ 934 | "
\n", 935 | "\n", 948 | "\n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | "
100100020000
1001.2020.0140.027
10000.0190.0270.096
200000.0290.1002.756
\n", 978 | "
" 979 | ], 980 | "text/plain": [ 981 | " 100 1000 20000\n", 982 | "100 1.202 0.014 0.027\n", 983 | "1000 0.019 0.027 0.096\n", 984 | "20000 0.029 0.100 2.756" 985 | ] 986 | }, 987 | "execution_count": 15, 988 | "metadata": {}, 989 | "output_type": "execute_result" 990 | } 991 | ], 992 | "source": [ 993 | "df4/4" 994 | ] 995 | }, 996 | { 997 | "cell_type": "code", 998 | "execution_count": 16, 999 | "metadata": {}, 1000 | "outputs": [ 1001 | { 1002 | "data": { 1003 | "text/html": [ 1004 | "
\n", 1005 | "\n", 1018 | "\n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | " \n", 1029 | " \n", 1030 | " \n", 1031 | " \n", 1032 | " \n", 1033 | " \n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | "
100100020000
1000.0300.0500.056
10000.0200.0280.252
200000.0500.2382.225
\n", 1048 | "
" 1049 | ], 1050 | "text/plain": [ 1051 | " 100 1000 20000\n", 1052 | "100 0.030 0.050 0.056\n", 1053 | "1000 0.020 0.028 0.252\n", 1054 | "20000 0.050 0.238 2.225" 1055 | ] 1056 | }, 1057 | "execution_count": 16, 1058 | "metadata": {}, 1059 | "output_type": "execute_result" 1060 | } 1061 | ], 1062 | "source": [ 1063 | "df5/5" 1064 | ] 1065 | }, 1066 | { 1067 | "cell_type": "code", 1068 | "execution_count": 17, 1069 | "metadata": {}, 1070 | "outputs": [ 1071 | { 1072 | "data": { 1073 | "text/html": [ 1074 | "
\n", 1075 | "\n", 1088 | "\n", 1089 | " \n", 1090 | " \n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | "
100100020000
1000.0260.0430.046
10000.0470.0280.181
200000.0460.1731.198
\n", 1118 | "
" 1119 | ], 1120 | "text/plain": [ 1121 | " 100 1000 20000\n", 1122 | "100 0.026 0.043 0.046\n", 1123 | "1000 0.047 0.028 0.181\n", 1124 | "20000 0.046 0.173 1.198" 1125 | ] 1126 | }, 1127 | "execution_count": 17, 1128 | "metadata": {}, 1129 | "output_type": "execute_result" 1130 | } 1131 | ], 1132 | "source": [ 1133 | "df6/6" 1134 | ] 1135 | }, 1136 | { 1137 | "cell_type": "code", 1138 | "execution_count": null, 1139 | "metadata": { 1140 | "collapsed": true 1141 | }, 1142 | "outputs": [], 1143 | "source": [] 1144 | } 1145 | ], 1146 | "metadata": { 1147 | "kernelspec": { 1148 | "display_name": "Python 3", 1149 | "language": "python", 1150 | "name": "python3" 1151 | }, 1152 | "language_info": { 1153 | "codemirror_mode": { 1154 | "name": "ipython", 1155 | "version": 3 1156 | }, 1157 | "file_extension": ".py", 1158 | "mimetype": "text/x-python", 1159 | "name": "python", 1160 | "nbconvert_exporter": "python", 1161 | "pygments_lexer": "ipython3", 1162 | "version": "3.6.3" 1163 | } 1164 | }, 1165 | "nbformat": 4, 1166 | "nbformat_minor": 2 1167 | } 1168 | --------------------------------------------------------------------------------