├── Readme.md
└── pytorch_randomized_svd.ipynb
/Readme.md:
--------------------------------------------------------------------------------
1 | An attempt to speed up randomized SVD using the GPU, tested against current top algorithms.
2 | Notebook: https://github.com/smortezavi/Randomized_SVD_GPU/blob/master/pytorch_randomized_svd.ipynb
3 |
4 |
--------------------------------------------------------------------------------
/pytorch_randomized_svd.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "import timeit\n",
13 | "import pandas as pd\n",
14 | "from sklearn import decomposition\n",
15 | "import fbpca\n",
16 | "import torch\n",
17 | "from scipy import linalg"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {
24 | "collapsed": true
25 | },
26 | "outputs": [],
27 | "source": [
28 | "def simple_randomized_svd(M, k=10, random_state=None):\n",
29 | "    \"\"\"Rank-k randomized SVD of a 2-D NumPy array (one Gaussian sketch, no power iterations).\n",
30 | "\n",
31 | "    Returns (U, s, Vh) with M ~= U @ np.diag(s) @ Vh: U is (m, r), s is (r,), Vh is (r, n),\n",
32 | "    where r = min(k, m, n).\n",
33 | "    random_state: optional integer seed; None draws from the global np.random state.\n",
34 | "    \"\"\"\n",
35 | "    m, n = M.shape\n",
36 | "    # Work on the tall orientation so the sketch is taken along the short side.\n",
37 | "    transpose = m < n\n",
38 | "    if transpose:\n",
39 | "        M = M.T\n",
40 | "    rng = np.random if random_state is None else np.random.RandomState(random_state)\n",
41 | "    rand_matrix = rng.normal(size=(M.shape[1], k))  # short side by k\n",
42 | "    Q, _ = np.linalg.qr(M @ rand_matrix, mode='reduced')  # long side by k\n",
43 | "    smaller_matrix = Q.T @ M  # k by short side\n",
44 | "    U_hat, s, V = np.linalg.svd(smaller_matrix, full_matrices=False)\n",
45 | "    U = Q @ U_hat\n",
46 | "\n",
47 | "    if transpose:\n",
48 | "        # M.T ~= U diag(s) V, hence M ~= V.T diag(s) U.T (s is 1-D, no transpose needed)\n",
49 | "        return V.T, s, U.T\n",
50 | "    return U, s, V"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 3,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "def simple_randomized_torch_svd(M, k=10, device='cuda:0'):\n",
53 | "    \"\"\"PyTorch version of simple_randomized_svd, run on `device` (default: first GPU).\n",
54 | "\n",
55 | "    M is converted with torch.tensor, so its dtype is kept (float64 for a float64 NumPy array).\n",
56 | "    Returns torch tensors (U, s, Vh) on `device` with M ~= U @ torch.diag(s) @ Vh,\n",
57 | "    the same (U, s, Vh) convention as the NumPy version.\n",
58 | "    \"\"\"\n",
59 | "    B = torch.tensor(M, device=device)\n",
60 | "    # Work on the tall orientation so the sketch is taken along the short side.\n",
61 | "    transpose = B.size(0) < B.size(1)\n",
62 | "    if transpose:\n",
63 | "        B = B.t()\n",
64 | "    n = B.size(1)\n",
65 | "    # Gaussian sketch in the input dtype, as in the NumPy version.\n",
66 | "    rand_matrix = torch.randn(n, k, dtype=B.dtype, device=device)  # short side by k\n",
67 | "    Q, _ = torch.qr(B @ rand_matrix)  # reduced QR: long side by k\n",
68 | "    smaller_matrix = Q.t() @ B  # k by short side\n",
69 | "    # torch.svd returns V itself (not V^T); some=True requests the reduced factors,\n",
70 | "    # like full_matrices=False in NumPy (some=False would return the full n x n V).\n",
71 | "    U_hat, s, V = torch.svd(smaller_matrix, some=True)\n",
72 | "    U = Q @ U_hat\n",
73 | "\n",
74 | "    if transpose:\n",
75 | "        # B = M^T ~= U diag(s) V^T, hence M ~= V diag(s) U^T\n",
76 | "        return V, s, U.t()\n",
77 | "    return U, s, V.t()"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 4,
77 | "metadata": {
78 | "collapsed": true
79 | },
80 | "outputs": [],
81 | "source": [
82 | "def randomized_range_finder(A, size, n_iter=5, random_state=None):\n",
83 | "    \"\"\"Return an orthonormal matrix Q (A.shape[0] rows, at most `size` columns) whose range\n",
84 | "    approximates the range of A.\n",
85 | "\n",
86 | "    Each of the n_iter power iterations is normalised with an LU factorisation (a middle ground\n",
87 | "    between no normalisation, fast but unstable, and QR, slow but most accurate); only the final\n",
88 | "    step uses an economic QR to make the basis orthonormal.\n",
89 | "    random_state: optional integer seed; None draws from the global np.random state.\n",
90 | "    \"\"\"\n",
91 | "    rng = np.random if random_state is None else np.random.RandomState(random_state)\n",
92 | "    Q = rng.normal(size=(A.shape[1], size))\n",
93 | "    for _ in range(n_iter):\n",
94 | "        Q, _ = linalg.lu(A @ Q, permute_l=True)\n",
95 | "        Q, _ = linalg.lu(A.T @ Q, permute_l=True)\n",
96 | "\n",
97 | "    Q, _ = linalg.qr(A @ Q, mode='economic')\n",
98 | "    return Q"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 5,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "def randomized_svd(M, n_components, n_oversamples=10, n_iter=4, device='cuda:0'):\n",
102 | "    \"\"\"Randomized SVD with the projection and reconstruction on `device` (default: first GPU).\n",
103 | "\n",
104 | "    The range finder and the SVD of the small projected matrix run on the host (NumPy/SciPy);\n",
105 | "    B = Q^T M and U = Q @ Uhat are computed with torch on `device`.\n",
106 | "    Returns (U, s, Vh) truncated to n_components: U is a torch tensor on `device`,\n",
107 | "    s and Vh are NumPy arrays.\n",
108 | "    \"\"\"\n",
109 | "    n_random = n_components + n_oversamples\n",
110 | "\n",
111 | "    Q = torch.tensor(randomized_range_finder(M, n_random, n_iter), device=device)\n",
112 | "    # project M to the (k + p) dimensional space using the basis vectors\n",
113 | "    B = Q.t() @ torch.tensor(M, device=device)\n",
114 | "    # SVD of the thin (k + p)-row matrix with SciPy. SciPy needs a host array, so copy it\n",
115 | "    # explicitly: recent PyTorch refuses to convert a CUDA tensor to NumPy implicitly.\n",
116 | "    Uhat, s, V = linalg.svd(B.cpu().numpy(), full_matrices=False)\n",
117 | "    del B\n",
118 | "    U = Q @ torch.tensor(Uhat, device=device)\n",
119 | "\n",
120 | "    return U[:, :n_components], s[:n_components], V[:n_components, :]"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 6,
120 | "metadata": {
121 | "collapsed": true
122 | },
123 | "outputs": [],
124 | "source": [
125 | "def randomized_svd_original(M, n_components, n_oversamples=10, n_iter=4):\n",
126 | "    \"\"\"CPU-only (NumPy/SciPy) randomized SVD, the reference for the torch variant.\n",
127 | "\n",
128 | "    Returns (U, s, Vh), each truncated to the leading n_components.\n",
129 | "    \"\"\"\n",
130 | "    basis = randomized_range_finder(M, n_components + n_oversamples, n_iter)\n",
131 | "    # Project M onto the (k + p)-dimensional basis, then decompose the small matrix.\n",
132 | "    projected = basis.T @ M\n",
133 | "    U_small, sing_vals, Vh = linalg.svd(projected, full_matrices=False)\n",
134 | "    del projected\n",
135 | "    keep = slice(None, n_components)\n",
136 | "    return (basis @ U_small)[:, keep], sing_vals[keep], Vh[keep, :]"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 51,
143 | "metadata": {},
144 | "outputs": [
145 | {
146 | "data": {
147 | "text/plain": [
148 | "(tensor([[ 0.2665, -0.4784, -0.4189, 0.0461, 0.1967],\n",
149 | " [ 0.0450, 0.2502, -0.4644, 0.6517, -0.3875],\n",
150 | " [-0.0823, 0.5004, 0.1257, -0.2966, -0.1977],\n",
151 | " [ 0.2300, 0.0620, -0.3229, -0.6062, -0.0580],\n",
152 | " [ 0.4515, 0.2407, 0.3616, 0.1039, 0.1663],\n",
153 | " [ 0.5530, 0.1675, -0.0788, -0.0937, -0.4466],\n",
154 | " [ 0.0042, 0.1490, -0.4542, -0.2347, 0.0064],\n",
155 | " [ 0.1880, -0.5289, 0.1129, -0.1100, -0.3019],\n",
156 | " [ 0.1415, 0.2574, -0.2925, 0.0434, 0.6464],\n",
157 | " [-0.5497, -0.0077, -0.2174, -0.1706, -0.1805]], dtype=torch.float64, device='cuda:0'),\n",
158 | " array([ 291.1458, 272.6614, 258.4442, 243.4219, 235.0934]),\n",
159 | " array([[ 0.0518, 0.0131, 0.1641, 0.1506, 0.146 , 0.0222, -0.0469,\n",
160 | " 0.1355, -0.098 , -0.0233, 0.0092, 0.0286, 0.0281, 0.0301,\n",
161 | " 0.0223, 0.0542, 0.0396, 0.0861, 0.1835, 0.17 , -0.104 ,\n",
162 | " -0.0503, -0.1279, 0.2209, -0.115 , -0.0889, -0.1002, 0.119 ,\n",
163 | " 0.097 , -0.0717, 0.0485, -0.0142, -0.0282, 0.158 , 0.0102,\n",
164 | " 0.1407, -0.2371, 0.0467, 0.0846, 0.059 , 0.0277, -0.0588,\n",
165 | " -0.1047, -0.1641, -0.0593, -0.022 , 0.1137, -0.1559, 0.0385,\n",
166 | " 0.0147, 0.1155, -0.0089, 0.008 , -0.1397, 0.1101, -0.0368,\n",
167 | " -0.0584, -0.0091, 0.0055, -0.0985, 0.1003, -0.0447, 0.1091,\n",
168 | " 0.0239, 0.0129, 0.0153, -0.1851, 0.1479, 0.064 , -0.0041,\n",
169 | " -0.0896, 0.08 , 0.0343, -0.038 , 0.0705, -0.0915, 0.1503,\n",
170 | " 0.0479, 0.006 , 0.1062, 0.0698, 0.0503, 0.1777, -0.0383,\n",
171 | " 0.054 , 0.1732, 0.0136, -0.0869, -0.0776, -0.0567, 0.1193,\n",
172 | " 0.0808, 0.2016, 0.1375, -0.0678, 0.1257, 0.1043, 0.0659,\n",
173 | " -0.0394, -0.2398],\n",
174 | " [ 0.1047, -0.0278, 0.0394, -0.0116, -0.0278, 0.1107, 0.0449,\n",
175 | " -0.035 , 0.0247, 0.0578, -0.1732, -0.0237, 0.1903, 0.059 ,\n",
176 | " 0.09 , 0.0089, -0.0202, 0.111 , 0.004 , -0.016 , -0.0801,\n",
177 | " 0.1059, 0.0315, 0.03 , 0.084 , 0.0228, 0.0251, -0.0063,\n",
178 | " 0.0831, -0.0133, 0.1926, -0.2026, 0.1024, -0.0399, -0.1468,\n",
179 | " -0.1746, 0.1429, -0.1128, -0.0426, -0.1053, 0.039 , 0.0626,\n",
180 | " -0.0411, 0.0109, -0.0389, 0.1114, 0.018 , -0.1367, -0.0503,\n",
181 | " -0.1345, 0.0085, -0.0016, 0.1778, -0.069 , -0.1027, -0.205 ,\n",
182 | " -0.0415, 0.0277, -0.1185, 0.0991, -0.013 , 0.0184, -0.1554,\n",
183 | " 0.1983, -0.02 , 0.1438, 0.0074, 0.0213, 0.2181, -0.0763,\n",
184 | " -0.1173, 0.2085, 0.088 , 0.0155, -0.1364, 0.0703, -0.0937,\n",
185 | " -0.0265, -0.0185, -0.0257, 0.1124, 0.0795, 0.0913, 0.001 ,\n",
186 | " 0.2215, 0.1455, -0.0052, 0.0898, -0.0091, -0.2242, -0.1665,\n",
187 | " 0.0639, 0.0636, 0.0718, 0.0732, -0.1127, -0.1063, -0.1051,\n",
188 | " -0.0029, -0.0029],\n",
189 | " [-0.1149, 0.0279, -0.1338, 0.116 , 0.0468, 0.1232, 0.054 ,\n",
190 | " 0.0565, -0.1259, -0.0839, 0.152 , -0.0863, 0.0003, -0.2371,\n",
191 | " 0.0439, 0.051 , 0.109 , 0.055 , 0.131 , 0.0244, -0.0124,\n",
192 | " 0.2318, 0.1091, -0.0065, -0.0539, 0.0326, 0.1813, 0.024 ,\n",
193 | " 0.0575, 0.1769, 0.0659, -0.0711, -0.0469, -0.0175, -0.0626,\n",
194 | " -0.0345, 0.0641, 0.1338, 0.0266, 0.0062, -0.0386, 0.1407,\n",
195 | " -0.0436, 0.0055, 0.0431, 0.1255, -0.041 , -0.0348, -0.1199,\n",
196 | " -0.0158, 0.1152, -0.2198, 0.1552, 0.0876, -0.1477, -0.1201,\n",
197 | " 0.068 , -0.0179, 0.0766, 0.0899, 0.0199, -0.1206, 0.1611,\n",
198 | " -0.1097, -0.0424, -0.0406, 0.1487, 0.1289, -0.0652, -0.1107,\n",
199 | " 0.0225, -0.1439, -0.1709, -0.1325, 0.026 , -0.1506, 0.0753,\n",
200 | " 0.1793, 0.0631, 0.1439, -0.0564, 0.1311, 0.09 , 0.0136,\n",
201 | " 0.1341, -0.0405, 0.0953, 0.0239, -0.0048, 0.0429, 0.0573,\n",
202 | " 0.0523, -0.0403, 0.0003, -0.0562, 0.0885, -0.1245, -0.0766,\n",
203 | " 0.1717, 0.0722],\n",
204 | " [-0.2255, -0.0402, -0.1114, 0.1937, -0.0568, 0.0984, -0.0648,\n",
205 | " -0.1507, 0.1158, 0.1139, -0.0229, 0.0422, 0.197 , 0.0318,\n",
206 | " 0.0226, -0.1479, 0.0431, -0.1541, -0.0681, -0.0846, 0.1133,\n",
207 | " 0.0326, 0.0527, -0.1195, -0.1615, 0.1094, 0.0358, 0.2004,\n",
208 | " -0.0675, 0.1282, 0.1066, 0.0048, -0.0814, 0.1301, 0.0967,\n",
209 | " 0.075 , 0.1008, 0.0662, -0.1028, -0.0577, 0.0941, 0.0233,\n",
210 | " -0.1145, 0.0581, -0.1916, -0.0421, 0.1936, 0.0785, -0.0758,\n",
211 | " -0.056 , 0.0159, 0.0372, -0.0504, -0.0189, -0.0927, -0.0653,\n",
212 | " -0.0319, 0.0388, 0.0135, -0.0001, -0.2637, 0.1775, 0.0873,\n",
213 | " 0.0548, 0.0287, -0.0601, -0.1801, 0.0968, -0.0181, -0.0947,\n",
214 | " 0.0702, 0.0035, 0.0148, 0.0786, 0.046 , -0.0466, -0.0671,\n",
215 | " -0.0059, -0.0293, 0.133 , 0.0439, -0.0543, 0.101 , -0.1619,\n",
216 | " 0.0623, -0.1021, -0.1044, -0.1816, -0.1118, 0.0177, -0.0705,\n",
217 | " -0.0982, -0.0008, 0.084 , -0.0429, -0.0115, 0.1613, -0.0457,\n",
218 | " -0.1003, -0.052 ],\n",
219 | " [ 0.1515, 0.0668, 0.1857, -0.1832, 0.1376, 0.0806, 0.02 ,\n",
220 | " -0.0895, 0.1381, 0.1017, 0.1031, -0.0905, 0.0288, 0.0122,\n",
221 | " 0.285 , 0.214 , 0.175 , 0.0477, -0.1441, 0.1352, -0.0321,\n",
222 | " 0.013 , 0.0007, 0.0931, -0.1038, -0.1026, -0.0756, -0.0975,\n",
223 | " 0.1073, 0.0184, -0.1268, 0.0242, -0.0954, 0.0914, -0.0318,\n",
224 | " -0.05 , 0.0268, 0.0005, -0.1374, -0.0452, -0.1993, 0.1278,\n",
225 | " -0.039 , -0.0092, -0.1076, 0.0096, 0.1636, -0.0246, -0.1138,\n",
226 | " 0.1856, -0.0244, -0.1302, -0.0931, -0.007 , 0.1454, -0.1308,\n",
227 | " 0.1207, 0.0613, -0.1073, 0.0442, -0.0394, 0.1058, -0.0777,\n",
228 | " 0.0225, -0.0158, -0.1921, -0.0933, -0.0287, -0.0471, -0.1154,\n",
229 | " -0.0455, -0.0632, -0.0143, 0.1797, 0.0116, -0.1832, 0.0383,\n",
230 | " -0.0319, 0.0525, -0.1981, -0.0385, 0.0208, 0.0057, -0.1077,\n",
231 | " 0.0467, -0.0724, 0.0701, 0.0228, 0.034 , 0.0531, -0.0178,\n",
232 | " -0.0728, -0.1358, -0.0163, -0.0502, -0.0159, -0.0144, -0.0449,\n",
233 | " -0.0267, 0.0774]]))"
234 | ]
235 | },
236 | "execution_count": 51,
237 | "metadata": {},
238 | "output_type": "execute_result"
239 | }
240 | ],
241 | "source": [
242 | "A = np.random.uniform(low=-40, high=40, size=(10, 100))\n",
243 | "randomized_svd(A, n_components=5)"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 28,
249 | "metadata": {},
250 | "outputs": [
251 | {
252 | "data": {
253 | "text/plain": [
254 | "(array([[-0.0927, -0.3094, 0.0138, -0.498 , 0.1362, -0.6804, 0.3705,\n",
255 | " 0.0016, -0.1609, -0.0533],\n",
256 | " [-0.3696, -0.5729, 0.1055, 0.0333, -0.1139, 0.0144, -0.315 ,\n",
257 | " 0.4539, 0.4073, -0.1962],\n",
258 | " [ 0.4726, -0.0089, 0.1479, 0.0591, -0.3718, -0.1061, -0.0042,\n",
259 | " -0.0511, -0.0456, -0.7727],\n",
260 | " [ 0.2914, -0.5239, -0.3223, -0.251 , -0.0146, 0.3509, -0.2546,\n",
261 | " -0.0093, -0.5257, 0.0952],\n",
262 | " [-0.1348, 0.3822, 0.1706, -0.4915, 0.2152, 0.3631, 0.0837,\n",
263 | " 0.5217, -0.1876, -0.2691],\n",
264 | " [ 0.4823, -0.0967, 0.0848, 0.4299, 0.5133, -0.1813, 0.0785,\n",
265 | " 0.5014, -0.0588, 0.093 ],\n",
266 | " [ 0.093 , -0.0529, -0.6686, -0.0841, 0.2823, 0.2049, 0.332 ,\n",
267 | " -0.0763, 0.4755, -0.2658],\n",
268 | " [-0.2028, 0.0905, -0.422 , 0.2922, -0.5218, -0.0799, 0.3588,\n",
269 | " 0.4363, -0.292 , 0.0649],\n",
270 | " [ 0.496 , 0.079 , 0.0429, -0.3942, -0.4031, -0.0208, 0.0142,\n",
271 | " 0.2405, 0.4233, 0.4363],\n",
272 | " [ 0.0036, -0.3589, 0.4469, 0.101 , -0.076 , 0.4346, 0.668 ,\n",
273 | " -0.1165, 0.0226, 0.0793]]),\n",
274 | " array([ 277.6828, 266.5159, 255.7381, 251.1185, 231.8737, 227.6304,\n",
275 | " 208.1148, 193.3796, 180.6411, 165.6627]),\n",
276 | " array([[-0.1194, -0.1059, 0.0448, 0.0323, -0.0184, 0.0164, 0.1107,\n",
277 | " 0.1889, 0.0443, 0.152 , -0.1577, -0.1047, -0.0287, 0.0447,\n",
278 | " -0.0222, 0.1027, 0.0896, 0.0256, -0.0631, 0.2503, -0.0941,\n",
279 | " -0.2155, -0.1066, 0.0411, -0.0044, -0.1438, 0.0143, -0.0057,\n",
280 | " -0.0201, 0.094 , 0.0311, 0.1404, -0.1361, 0.0639, 0.038 ,\n",
281 | " 0.1444, 0.1051, -0.0501, -0.0036, 0.1643, 0.0518, 0.0598,\n",
282 | " -0.0852, -0.1342, 0.0366, 0.1767, -0.099 , -0.0103, 0.0108,\n",
283 | " 0.1778, -0.159 , -0.0396, 0.0145, -0.1709, 0.0699, -0.165 ,\n",
284 | " 0.0288, 0.0027, 0.0259, -0.0587, -0.0326, 0.0643, -0.0139,\n",
285 | " -0.227 , -0.0479, 0.0867, -0.0538, -0.0293, -0.0625, -0.0212,\n",
286 | " 0.1295, -0.0414, -0.0687, -0.1003, 0.1867, 0.0899, 0.0124,\n",
287 | " 0.0393, 0.1353, 0.0867, -0.1051, -0.0636, -0.0234, 0.1038,\n",
288 | " -0.0898, -0.1123, -0.1175, -0.1263, -0.0303, 0.1688, -0.0692,\n",
289 | " -0.0361, 0.0794, 0.0727, -0.0921, -0.0932, 0.0346, -0.0148,\n",
290 | " -0.1141, -0.1871],\n",
291 | " [-0.0684, -0.0398, 0.0206, -0.1001, -0.1431, 0.0881, -0.1984,\n",
292 | " -0.0424, -0.0399, -0.0002, 0.1922, 0.0108, -0.0755, 0.0587,\n",
293 | " -0.0342, 0.1755, 0.0795, 0.0001, -0.0849, 0.0995, 0.121 ,\n",
294 | " -0.167 , -0.1273, 0.0813, 0.0177, -0.0494, -0.1588, -0.0827,\n",
295 | " 0.1408, -0.0441, 0.0248, 0.0259, 0.2409, 0.1924, 0.0976,\n",
296 | " 0.0346, -0.1064, -0.0406, -0.1364, -0.0326, 0.0262, 0.0643,\n",
297 | " 0.0348, 0.1109, -0.0144, 0.1152, -0.1018, 0.0876, 0.0459,\n",
298 | " 0.008 , 0.059 , 0.1505, -0.0883, 0.0658, -0.0711, 0.0704,\n",
299 | " -0.0777, -0.2006, -0.0437, 0.0167, -0.0064, 0.1083, 0.0233,\n",
300 | " -0.0497, -0.0578, -0.0621, 0.0751, -0.1627, -0.0343, -0.0834,\n",
301 | " -0.1485, -0.0777, -0.0034, -0.1765, 0.0568, 0.0163, 0.1305,\n",
302 | " 0.1266, -0.1724, -0.0842, 0.1017, 0.0856, 0.1258, 0.1359,\n",
303 | " -0.0683, -0.1306, 0.045 , 0.0655, 0.0264, -0.1812, -0.1439,\n",
304 | " -0.0052, -0.1109, 0.0527, -0.1811, -0.0578, -0.0426, -0.0676,\n",
305 | " 0.1123, -0.0313],\n",
306 | " [ 0.0098, 0.0089, 0.0564, -0.088 , 0.0003, -0.0827, 0.1666,\n",
307 | " -0.062 , 0.1427, -0.0718, 0.0482, 0.1847, -0.1329, -0.0236,\n",
308 | " 0.086 , -0.1303, -0.0225, -0.1029, 0.0635, 0.05 , -0.0386,\n",
309 | " 0.0086, -0.1426, 0.1299, -0.0772, -0.0782, -0.0802, 0.1941,\n",
310 | " 0.1048, -0.0166, 0.0384, 0.1791, -0.0119, -0.0255, -0.064 ,\n",
311 | " -0.1555, 0.0087, 0.132 , 0.1469, -0.0894, 0.0523, 0.0109,\n",
312 | " 0.0867, 0.0801, -0.0139, 0.1224, -0.1428, 0.0412, 0.2326,\n",
313 | " 0.0224, -0.0209, -0.085 , -0.0408, 0.0454, 0.0968, 0.0217,\n",
314 | " -0.2144, 0.1343, -0.0781, 0.0255, -0.0642, -0.0879, 0.0371,\n",
315 | " 0.0825, 0.0888, -0.1127, 0.028 , -0.0634, 0.139 , -0.1085,\n",
316 | " 0.1306, -0.0432, -0.1005, -0.0435, -0.0653, 0.0624, -0.0193,\n",
317 | " 0.0567, 0.1523, 0.031 , 0.0664, 0.0582, 0.0684, 0.1933,\n",
318 | " -0.0126, 0.0368, 0.0381, -0.0924, -0.1445, -0.0333, -0.1543,\n",
319 | " -0.1868, 0.0154, -0.2105, 0.1586, -0.1404, 0.1674, -0.0023,\n",
320 | " 0.0928, 0.0193],\n",
321 | " [-0.0312, 0.0933, -0.0229, 0.0331, 0.1376, 0.1745, -0.0012,\n",
322 | " 0.1118, 0.1125, 0.0059, -0.0473, -0.1874, 0.0003, -0.1555,\n",
323 | " 0.0708, 0.0351, -0.0353, -0.069 , 0.0834, 0.1013, 0.0272,\n",
324 | " 0.0905, -0.0961, -0.013 , -0.2107, 0.0895, 0.0946, 0.0179,\n",
325 | " 0.0648, -0.0176, -0.0776, -0.1891, 0.0275, 0.2144, 0.1968,\n",
326 | " -0.0629, -0.0736, 0.0178, 0.1535, -0.1932, -0.1041, 0.0819,\n",
327 | " 0.0172, -0.0118, -0.208 , 0.1033, -0.0175, 0.139 , 0.1149,\n",
328 | " 0.0169, -0.07 , -0.0948, 0.0045, 0.037 , 0.0593, 0.0397,\n",
329 | " 0.1287, 0.0053, -0.0724, -0.0459, 0.0204, 0.0833, 0.2035,\n",
330 | " -0.0783, 0.0821, -0.0689, 0.0843, 0.1898, 0.1107, 0.0165,\n",
331 | " -0.053 , 0.0846, -0.0065, -0.0038, -0.1598, 0.1618, 0.0794,\n",
332 | " -0.0485, -0.1004, 0.0229, -0.0965, -0.1124, -0.0565, 0.057 ,\n",
333 | " -0.0121, 0.0612, 0.0166, 0.0287, -0.1395, -0.0852, 0.1669,\n",
334 | " 0.129 , 0.1284, -0.0236, -0.1488, -0.1463, -0.0404, -0.0865,\n",
335 | " -0.0594, -0.1029],\n",
336 | " [ 0.0581, -0.0834, -0.1478, 0.0444, -0.0717, 0.0233, 0.016 ,\n",
337 | " -0.0592, -0.0853, 0.0948, 0.0738, -0.1201, -0.2125, 0.2126,\n",
338 | " 0.1718, 0.0415, -0.1715, 0.0089, 0.1158, -0.0047, 0.0028,\n",
339 | " -0.0533, 0.0282, 0.0168, 0.0254, -0.0085, -0.1421, -0.0544,\n",
340 | " 0.1502, -0.0623, 0.0225, -0.1236, -0.0113, 0.0442, 0.0045,\n",
341 | " 0.0559, -0.0042, -0.243 , -0.0074, 0.0362, 0.0433, -0.0972,\n",
342 | " 0.1463, -0.0052, -0.0015, 0.0838, 0.0502, -0.06 , 0.1244,\n",
343 | " 0.0457, -0.0553, -0.042 , -0.0861, 0.0734, -0.1825, 0.0787,\n",
344 | " -0.0062, 0.1039, 0.1444, 0.1039, -0.0022, -0.2063, 0.0464,\n",
345 | " 0.0867, -0.1355, 0.0045, -0.091 , 0.0281, -0.2494, 0.1741,\n",
346 | " -0.0993, 0.0978, 0.1304, 0.1351, 0.032 , -0.0123, -0.0124,\n",
347 | " -0.1266, 0.1234, 0.0331, -0.1054, -0.192 , 0.0886, 0.0018,\n",
348 | " -0.1345, 0.039 , -0.0252, 0.0639, -0.1552, 0.065 , -0.0225,\n",
349 | " -0.0379, 0.0223, -0.1473, -0.0158, -0.0979, 0.1099, -0.1995,\n",
350 | " 0.0381, -0.0259],\n",
351 | " [-0.1494, 0.1693, 0.1817, -0.0668, 0.1318, -0.0355, 0.1074,\n",
352 | " 0.0049, -0.0429, -0.2172, 0.1686, -0.0052, 0.0159, -0.0334,\n",
353 | " 0.157 , 0.0632, -0.1313, 0.1278, -0.1335, -0.0233, 0.2449,\n",
354 | " -0.1248, 0.0389, -0.1911, 0.1007, -0.1928, 0.1366, -0.213 ,\n",
355 | " -0.0734, -0.0542, 0.0653, 0.1547, 0.0024, -0.081 , -0.1056,\n",
356 | " 0.0168, -0.1497, -0.0766, 0.0437, 0.0195, 0.0475, -0.147 ,\n",
357 | " 0.1401, 0.0972, -0.1178, 0.0719, 0.0297, -0.0247, 0.1005,\n",
358 | " 0.0221, -0.1239, -0.0429, -0.0457, -0.072 , 0.0078, -0.0448,\n",
359 | " 0.0934, -0.1073, 0.0093, 0.0061, 0.0357, -0.0423, 0.0311,\n",
360 | " 0.015 , 0.1778, -0.0428, -0.0722, 0.0227, 0.0522, -0.0692,\n",
361 | " -0.1277, -0.0107, -0.0187, 0.022 , 0.1519, 0.0605, -0.1817,\n",
362 | " 0.0164, 0.0407, -0.1203, 0.0541, -0.0445, -0.0469, 0.0478,\n",
363 | " 0.0052, 0.0869, -0.1958, -0.0349, 0.0165, -0.0536, 0.1215,\n",
364 | " 0.0844, 0.0429, 0.0508, 0.1345, -0.0663, -0.0689, -0.0959,\n",
365 | " 0.0705, -0.096 ],\n",
366 | " [ 0.042 , 0.0353, -0.0529, 0.0904, -0.1003, 0.0023, -0.1248,\n",
367 | " 0.1678, 0.1003, 0.1283, 0.0016, 0.0692, 0.0479, 0.0497,\n",
368 | " 0.006 , 0.08 , -0.1133, 0.0636, 0.0552, -0.1907, 0.1857,\n",
369 | " 0.1229, 0.075 , 0.0927, -0.2153, -0.1263, 0.0556, 0.0235,\n",
370 | " -0.1678, 0.2431, 0.0308, -0.0247, -0.0646, -0.0142, -0.061 ,\n",
371 | " -0.1598, -0.1329, -0.0674, -0.0833, 0.1434, -0.1402, 0.1615,\n",
372 | " 0.1094, 0.013 , -0.0477, 0.1196, 0.1645, 0.1372, -0.0568,\n",
373 | " 0.0957, -0.0004, 0.1449, -0.016 , 0.1166, 0.0124, -0.0402,\n",
374 | " -0.0253, 0.0418, -0.1052, 0.087 , 0.1117, 0.0161, -0.0925,\n",
375 | " 0.0134, -0.0268, -0.1936, -0.1072, 0.1193, 0.0576, -0.0095,\n",
376 | " 0.0531, -0.1384, -0.0812, 0.1429, 0.1244, 0.0282, -0.0282,\n",
377 | " -0.0062, 0.0311, -0.0499, -0.1484, 0.0524, 0.0696, 0.1764,\n",
378 | " 0.1374, 0.0053, -0.1385, 0.0256, -0.0436, 0.0372, -0.1558,\n",
379 | " 0.06 , -0.1722, -0.0392, -0.0916, 0.0571, 0.0847, -0.0076,\n",
380 | " -0.045 , -0.0058],\n",
381 | " [-0.0644, 0.074 , 0.0298, 0.1139, 0.029 , 0.0728, 0.0545,\n",
382 | " 0.0066, 0.0438, -0.1246, -0.0512, 0.0649, -0.0015, -0.0626,\n",
383 | " 0.0177, 0.153 , -0.1139, -0.0785, -0.0914, -0.0269, -0.0965,\n",
384 | " -0.0256, 0.0953, -0.0307, 0.1179, -0.0643, 0.0556, -0.0318,\n",
385 | " 0.0098, -0.2266, -0.0066, -0.1737, -0.012 , 0.0155, 0.0827,\n",
386 | " -0.1823, 0.1417, -0.0809, -0.0141, 0.1102, 0.0483, 0.1434,\n",
387 | " -0.0643, -0.0592, -0.1776, -0.0353, -0.1399, -0.0311, 0.0601,\n",
388 | " -0.0556, 0.2011, 0.026 , 0.0247, 0.1667, 0.0975, 0.1257,\n",
389 | " 0.0554, 0.0787, 0.267 , -0.0717, -0.1579, -0.0433, -0.0424,\n",
390 | " -0.0702, -0.1176, -0.1289, -0.182 , 0.0326, 0.0815, -0.0226,\n",
391 | " -0.0667, 0.2647, -0.0972, -0.1035, 0.0266, 0.0695, 0.0393,\n",
392 | " -0.0123, 0.0945, -0.1469, 0.0791, 0.1074, 0.0504, 0.0933,\n",
393 | " 0.1163, -0.0456, -0.1153, -0.0878, 0.1167, 0.1422, -0.0062,\n",
394 | " -0.1823, -0.1169, -0.0334, -0.1232, 0.0955, -0.0013, -0.0222,\n",
395 | " -0.0516, -0.0122],\n",
396 | " [ 0.1755, 0.0701, -0.0588, -0.0834, 0.0671, 0.0504, 0.0299,\n",
397 | " -0.1256, 0.0072, 0.1121, 0.1107, 0.036 , 0.0904, 0.0448,\n",
398 | " 0.1639, -0.0031, -0.0833, -0.096 , 0.0674, 0.0226, 0.0837,\n",
399 | " -0.029 , -0.125 , -0.0895, 0.0796, 0.0353, -0.0196, 0.1451,\n",
400 | " 0.0373, -0.0487, 0.0204, -0.1386, -0.1177, -0.0687, 0.1374,\n",
401 | " 0.0844, -0.1328, 0.1388, -0.091 , -0.0843, 0.1455, 0.0904,\n",
402 | " 0.1358, 0.0097, 0.0926, 0.1205, -0.109 , -0.1065, 0.0966,\n",
403 | " 0.0455, 0.0929, 0.0133, -0.0075, -0.0903, 0.0095, -0.1677,\n",
404 | " 0.0606, 0.038 , 0.0353, 0.1976, 0.0134, 0.1137, -0.2059,\n",
405 | " -0.0271, 0.0912, 0.1127, -0.0967, 0.0737, 0.1882, 0.0675,\n",
406 | " -0.1005, -0.1373, 0.2494, -0.0256, 0.01 , 0.1084, 0.0665,\n",
407 | " -0.0967, -0.1128, 0.1204, -0.0376, 0.1391, -0.1169, -0.0605,\n",
408 | " 0.1143, 0.0154, -0.0114, -0.1473, 0.1942, 0.1005, -0.0274,\n",
409 | " -0.045 , -0.0829, 0.0436, -0.0515, 0.0796, 0.1322, -0.1056,\n",
410 | " 0.0053, -0.1475],\n",
411 | " [ 0.032 , -0.0456, 0.1179, -0.0105, 0.0812, -0.1035, 0.0459,\n",
412 | " 0.1157, -0.1335, 0.0648, 0.0538, 0.0327, -0.0078, -0.0107,\n",
413 | " -0.0912, -0.0911, -0.0455, -0.2188, -0.0527, 0.0974, -0.1089,\n",
414 | " 0.1092, 0.0744, 0.0784, 0.0168, -0.0292, 0.177 , -0.0883,\n",
415 | " 0.0922, -0.0812, 0.2664, -0.0381, -0.0137, -0.0896, -0.0391,\n",
416 | " 0.0098, 0.031 , 0.0609, -0.2073, -0.0634, -0.0691, 0.1889,\n",
417 | " -0.1392, -0.2065, 0.0611, 0.0558, 0.0481, 0.0993, 0.2513,\n",
418 | " 0.1045, -0.0177, -0.0063, 0.0053, 0.0057, -0.0123, 0.0603,\n",
419 | " 0.1883, -0.0631, -0.0205, 0.062 , -0.0031, -0.0486, 0.15 ,\n",
420 | " -0.004 , -0.1527, 0.0807, 0.01 , -0.1239, 0.0519, 0.0103,\n",
421 | " -0.0206, -0.127 , 0.1075, 0.0564, -0.0445, -0.1144, 0.1374,\n",
422 | " -0.1299, -0.0625, -0.0414, -0.0567, 0.0076, 0.0664, 0.069 ,\n",
423 | " 0.1062, 0.0418, -0.2425, 0.0768, -0.0972, -0.1723, -0.0346,\n",
424 | " -0.0602, -0.0282, 0.1123, 0.2317, 0.0189, -0.0737, -0.0803,\n",
425 | " 0.0466, 0.0221]]))"
426 | ]
427 | },
428 | "execution_count": 28,
429 | "metadata": {},
430 | "output_type": "execute_result"
431 | }
432 | ],
433 | "source": [
434 | "A = np.random.uniform(low=-40, high=40, size=(10, 100))\n",
435 | "simple_randomized_svd(A, k=10)"
436 | ]
437 | },
438 | {
439 | "cell_type": "code",
440 | "execution_count": 29,
441 | "metadata": {},
442 | "outputs": [
443 | {
444 | "data": {
445 | "text/plain": [
446 | "(tensor([[ 0.0022, -0.6062, -0.0613, -0.0791, 0.4002, 0.0024, 0.4948,\n",
447 | " -0.3578, -0.1446, 0.2620],\n",
448 | " [-0.4845, -0.1573, -0.2625, 0.5393, 0.0407, 0.3863, -0.0498,\n",
449 | " -0.1818, 0.0822, -0.4332],\n",
450 | " [ 0.1150, 0.0404, 0.2540, -0.4545, 0.3348, 0.1813, -0.1469,\n",
451 | " -0.3452, 0.5088, -0.4116],\n",
452 | " [-0.4244, -0.0384, -0.5109, -0.3065, -0.0551, -0.5866, -0.2308,\n",
453 | " -0.2045, 0.1428, 0.0278],\n",
454 | " [ 0.0904, -0.3397, 0.0358, 0.0862, -0.4325, -0.2028, 0.4178,\n",
455 | " 0.2802, 0.5992, -0.1656],\n",
456 | " [ 0.0009, -0.3461, 0.3275, -0.0386, 0.0213, -0.3891, -0.1232,\n",
457 | " 0.1403, -0.4726, -0.6011],\n",
458 | " [-0.5206, 0.2742, 0.4045, 0.1494, 0.4669, -0.2599, 0.2320,\n",
459 | " 0.2895, 0.1736, 0.1205],\n",
460 | " [-0.1635, -0.5329, 0.2680, 0.0149, -0.0122, 0.1470, -0.6118,\n",
461 | " 0.2005, 0.1841, 0.3836],\n",
462 | " [-0.1608, 0.1006, 0.5005, 0.1803, -0.4526, -0.1572, 0.0226,\n",
463 | " -0.6543, -0.0230, 0.1494],\n",
464 | " [-0.4901, -0.0142, 0.0852, -0.5818, -0.3353, 0.4092, 0.2439,\n",
465 | " 0.1664, -0.2147, -0.0278]], dtype=torch.float64, device='cuda:0'),\n",
466 | " tensor([ 288.0379, 268.4482, 256.8885, 250.8847, 240.9807, 222.8867,\n",
467 | " 211.8652, 196.4339, 188.9815, 181.7164], dtype=torch.float64, device='cuda:0'),\n",
468 | " tensor([[ 0.0402, 0.0179, 0.1010, 0.1340, -0.1198, 0.1070, -0.0303,\n",
469 | " -0.0145, -0.1296, -0.0939, 0.0347, 0.0685, 0.1009, -0.1867,\n",
470 | " 0.0653, 0.1440, -0.0890, -0.1349, 0.0387, -0.0470, 0.1600,\n",
471 | " 0.0895, -0.0106, 0.0787, -0.1653, -0.0410, 0.0178, -0.0412,\n",
472 | " 0.1382, 0.0518, 0.0527, -0.1435, 0.2222, -0.0775, -0.0303,\n",
473 | " 0.1840, -0.1013, 0.2213, -0.0228, -0.0692, 0.1175, -0.0974,\n",
474 | " 0.0645, 0.1434, 0.0482, 0.0197, 0.1182, -0.1008, -0.0731,\n",
475 | " 0.0752, -0.1366, 0.0342, 0.0873, 0.1115, 0.0357, 0.0789,\n",
476 | " 0.0411, 0.0976, -0.1566, -0.0157, 0.0403, 0.1302, 0.0862,\n",
477 | " -0.0429, -0.1030, -0.0418, 0.0139, 0.0478, 0.0930, 0.0372,\n",
478 | " -0.0272, -0.0317, -0.0502, 0.1139, -0.0241, -0.1365, 0.1510,\n",
479 | " -0.0726, 0.1520, -0.0190, -0.0144, -0.1091, -0.1078, -0.1329,\n",
480 | " -0.1456, 0.0768, -0.1030, -0.0654, 0.0006, 0.0919, 0.0320,\n",
481 | " 0.1976, 0.2087, 0.1534, -0.0819, 0.0658, 0.0864, 0.0159,\n",
482 | " 0.0848, -0.0768],\n",
483 | " [ 0.0028, -0.2304, 0.0922, 0.1209, 0.1288, -0.0215, 0.0165,\n",
484 | " -0.1730, -0.1264, -0.0367, 0.1728, 0.0732, 0.0905, 0.0096,\n",
485 | " -0.0204, 0.0053, 0.1326, -0.1568, 0.0401, -0.0747, 0.0065,\n",
486 | " 0.0739, 0.0326, -0.0914, 0.0649, -0.1903, 0.1179, -0.0136,\n",
487 | " -0.2076, -0.0016, -0.0031, 0.0441, -0.0161, -0.0986, 0.0192,\n",
488 | " 0.0174, 0.1647, 0.0056, 0.0768, 0.0866, -0.0080, 0.0353,\n",
489 | " -0.1670, -0.0447, -0.0199, 0.0955, -0.0614, 0.1367, 0.0534,\n",
490 | " 0.1255, 0.0044, 0.0299, -0.0637, -0.1124, 0.0150, -0.0660,\n",
491 | " 0.1089, 0.2244, -0.0302, 0.0019, -0.0065, -0.1051, 0.0256,\n",
492 | " 0.1989, 0.0457, -0.0703, 0.0894, 0.1100, -0.0395, -0.1362,\n",
493 | " -0.0774, -0.0100, -0.0332, -0.0585, -0.0319, -0.0984, -0.0465,\n",
494 | " -0.0808, 0.1002, 0.0236, -0.1758, -0.1891, -0.0240, 0.1430,\n",
495 | " -0.0646, 0.2253, 0.0332, 0.0496, 0.1592, -0.0396, 0.0287,\n",
496 | " 0.0894, -0.0473, 0.1375, 0.1620, -0.0834, -0.1199, -0.1188,\n",
497 | " 0.1046, 0.1351],\n",
498 | " [-0.0884, 0.0999, 0.0426, -0.0744, -0.0532, 0.0521, -0.0875,\n",
499 | " 0.0857, 0.0061, 0.0447, 0.1417, 0.0376, -0.1017, 0.1446,\n",
500 | " -0.0487, -0.0861, 0.2344, -0.0448, -0.0421, 0.1018, -0.1235,\n",
501 | " -0.1093, -0.0251, 0.0719, 0.0190, 0.1050, 0.0458, -0.0777,\n",
502 | " 0.1354, 0.1192, -0.1047, -0.0277, -0.0419, 0.0380, 0.0036,\n",
503 | " -0.0073, -0.1259, -0.0438, 0.1315, 0.0686, 0.1742, 0.0673,\n",
504 | " -0.0877, 0.0119, 0.0763, 0.1042, -0.0186, -0.0021, 0.0122,\n",
505 | " 0.1370, -0.0361, 0.0547, -0.0003, -0.0124, 0.0549, 0.1181,\n",
506 | " -0.0179, -0.0130, -0.0742, -0.0555, 0.1165, 0.1674, -0.2366,\n",
507 | " 0.1009, 0.1254, 0.1372, -0.1783, 0.1805, -0.1533, -0.2399,\n",
508 | " -0.2054, 0.1292, -0.1053, 0.0793, 0.0486, -0.0115, 0.0378,\n",
509 | " -0.1045, 0.0460, 0.1473, 0.1193, 0.0784, 0.1839, 0.0148,\n",
510 | " -0.1172, 0.0546, -0.0588, -0.1027, -0.0312, 0.0176, -0.1419,\n",
511 | " -0.0781, 0.1442, 0.1013, -0.0276, 0.0649, -0.0431, 0.0485,\n",
512 | " 0.1122, 0.0601],\n",
513 | " [ 0.0264, 0.0399, -0.1116, -0.1556, 0.1822, 0.0549, -0.1786,\n",
514 | " 0.0368, -0.0447, -0.0710, 0.0534, -0.2455, 0.1285, -0.0170,\n",
515 | " -0.1809, 0.1463, 0.1507, 0.0374, 0.0847, -0.0951, -0.1087,\n",
516 | " 0.0752, -0.1242, 0.0321, -0.0560, 0.1063, -0.0261, -0.0935,\n",
517 | " -0.0339, -0.1423, -0.0467, -0.1533, 0.1397, -0.0679, 0.0439,\n",
518 | " 0.0878, -0.0127, 0.0274, 0.1109, 0.0243, 0.0185, -0.1542,\n",
519 | " -0.0204, 0.0007, 0.0391, 0.1079, -0.0636, 0.1032, 0.1512,\n",
520 | " 0.0684, -0.0886, -0.0413, -0.1078, 0.0172, -0.0120, -0.1109,\n",
521 | " -0.0179, -0.0591, 0.1489, 0.1321, -0.0895, -0.0721, 0.0246,\n",
522 | " 0.0583, -0.1290, 0.0604, 0.0362, 0.1651, 0.1972, 0.1336,\n",
523 | " 0.0104, 0.0291, 0.0235, 0.2579, 0.0867, -0.0081, -0.1880,\n",
524 | " -0.1350, -0.0562, 0.1306, -0.0211, -0.0152, 0.0277, -0.0011,\n",
525 | " 0.0464, 0.0237, -0.0550, 0.0843, -0.1540, -0.0198, -0.1439,\n",
526 | " 0.0274, -0.1024, 0.0887, 0.0165, -0.0300, 0.0527, 0.1530,\n",
527 | " -0.1372, -0.0620],\n",
528 | " [ 0.1347, 0.0335, -0.0503, 0.0225, 0.0985, -0.1284, -0.1658,\n",
529 | " 0.0607, 0.0117, -0.0873, -0.1379, 0.0673, -0.1342, -0.0387,\n",
530 | " 0.1093, -0.0318, -0.0723, 0.0280, 0.0152, -0.1466, 0.0698,\n",
531 | " -0.0916, -0.1624, 0.0353, 0.1334, -0.1421, -0.0433, 0.1296,\n",
532 | " -0.0337, -0.0706, 0.0881, 0.1332, 0.0894, -0.0915, 0.3250,\n",
533 | " 0.0995, -0.0664, -0.0735, 0.0625, 0.2010, -0.0397, 0.0893,\n",
534 | " -0.0882, -0.0379, 0.1399, 0.0904, -0.1091, 0.0112, 0.0447,\n",
535 | " -0.0305, 0.2226, -0.0058, 0.1066, 0.0163, 0.1464, -0.0573,\n",
536 | " -0.0208, -0.0964, -0.1697, -0.0079, 0.0895, 0.0288, 0.0678,\n",
537 | " -0.0428, 0.0088, 0.0403, -0.0180, -0.0395, -0.0164, 0.0931,\n",
538 | " 0.0240, -0.1108, -0.1498, -0.1309, 0.0183, -0.1166, 0.0619,\n",
539 | " -0.0851, 0.0168, -0.0518, 0.1131, 0.1017, 0.0206, -0.0342,\n",
540 | " 0.0488, 0.1308, -0.2026, -0.0302, -0.0831, -0.1704, -0.0906,\n",
541 | " 0.0435, 0.0368, 0.1438, -0.0062, -0.1276, 0.1259, 0.1237,\n",
542 | " -0.0343, 0.1287],\n",
543 | " [-0.1197, -0.0973, 0.1977, 0.0149, 0.0162, 0.1260, -0.0780,\n",
544 | " 0.1430, 0.1305, -0.1031, -0.0307, 0.0298, -0.0824, -0.0827,\n",
545 | " -0.2335, 0.0854, -0.0503, -0.0273, 0.2594, 0.1589, 0.0108,\n",
546 | " -0.0249, -0.0452, -0.2144, -0.0325, -0.1060, 0.0481, 0.1667,\n",
547 | " -0.0486, 0.0701, -0.0498, -0.2067, -0.0975, -0.0406, 0.1167,\n",
548 | " -0.1278, 0.1188, 0.0176, 0.1025, -0.0021, -0.1293, -0.0431,\n",
549 | " 0.0168, -0.0639, 0.0910, -0.0693, -0.1139, 0.0606, 0.0382,\n",
550 | " -0.0894, -0.0208, 0.1261, -0.1226, 0.2494, 0.0981, 0.1157,\n",
551 | " 0.1099, 0.0794, -0.0302, 0.0600, 0.1342, -0.0972, 0.0775,\n",
552 | " -0.0795, -0.0038, 0.0509, 0.0711, 0.0791, -0.0390, 0.0265,\n",
553 | " 0.0225, 0.0193, 0.0223, 0.0571, -0.0316, 0.0895, 0.0840,\n",
554 | " -0.0445, -0.1916, 0.0676, 0.1175, 0.1351, 0.0194, -0.1393,\n",
555 | " -0.0515, -0.0252, 0.0052, -0.1010, 0.2183, 0.0888, -0.0397,\n",
556 | " 0.0042, 0.0088, -0.0649, -0.0205, -0.1267, 0.0966, -0.0467,\n",
557 | " 0.0872, -0.0040],\n",
558 | " [ 0.0936, 0.1248, -0.0084, 0.2326, 0.0379, -0.0158, 0.0899,\n",
559 | " -0.0030, 0.1916, 0.0537, 0.0406, 0.0789, 0.0987, -0.1878,\n",
560 | " -0.0692, -0.0683, 0.0770, -0.0242, 0.1264, -0.1108, 0.0093,\n",
561 | " 0.0759, 0.0806, 0.0937, 0.1041, -0.0796, -0.0383, 0.0123,\n",
562 | " 0.0199, -0.0351, 0.0958, 0.0272, 0.1233, 0.2582, 0.0167,\n",
563 | " 0.0982, -0.0259, -0.0768, 0.0960, 0.0638, 0.1092, -0.0650,\n",
564 | " 0.1729, -0.0732, -0.0330, -0.0034, 0.0703, -0.0552, -0.0696,\n",
565 | " -0.0127, 0.0375, -0.1205, -0.1232, -0.0286, -0.2567, -0.0219,\n",
566 | " 0.0577, 0.0558, -0.0252, 0.0113, -0.0771, -0.0744, 0.0281,\n",
567 | " 0.0213, -0.0900, -0.0388, -0.1575, -0.0226, -0.0398, -0.0244,\n",
568 | " -0.2027, -0.0248, 0.0407, -0.0918, 0.0456, 0.1203, -0.0913,\n",
569 | " 0.1207, -0.0250, 0.1320, 0.0025, 0.1495, 0.1492, -0.0344,\n",
570 | " -0.0396, 0.1223, -0.2094, 0.0037, 0.1174, 0.1286, -0.1494,\n",
571 | " -0.0698, 0.0424, -0.0048, 0.1082, -0.1825, 0.1142, -0.1606,\n",
572 | " -0.0968, -0.1519],\n",
573 | " [-0.0016, -0.1800, 0.0698, 0.1645, 0.0473, -0.0481, -0.1001,\n",
574 | " 0.1240, -0.0739, 0.1824, -0.0996, -0.0741, -0.0305, -0.0974,\n",
575 | " -0.0156, -0.1295, -0.0522, 0.0011, 0.1147, -0.0729, -0.0135,\n",
576 | " -0.1363, 0.1198, -0.0800, 0.0343, 0.0096, 0.0645, 0.1102,\n",
577 | " 0.0368, 0.0777, 0.0726, -0.0835, 0.0123, 0.0795, -0.0760,\n",
578 | " 0.1144, -0.0869, -0.0926, -0.0165, -0.1925, -0.0161, -0.2408,\n",
579 | " -0.0979, -0.0226, 0.0420, -0.1627, 0.0011, -0.0376, 0.1536,\n",
580 | " 0.1317, -0.0440, 0.0126, 0.1233, -0.0229, -0.0363, 0.1184,\n",
581 | " -0.1023, -0.0566, 0.0192, 0.1838, -0.1512, 0.0650, -0.0489,\n",
582 | " 0.0779, 0.0833, 0.0991, 0.1060, 0.1444, -0.2267, -0.0181,\n",
583 | " 0.1801, -0.2471, -0.0391, 0.0135, 0.0094, -0.0075, -0.0308,\n",
584 | " -0.0231, 0.1661, -0.0127, -0.0828, -0.1450, 0.1563, 0.0933,\n",
585 | " 0.0916, -0.1271, -0.1525, 0.0836, 0.0689, -0.0273, -0.0632,\n",
586 | " -0.0709, 0.0688, -0.1060, 0.0094, 0.0156, 0.0996, 0.0740,\n",
587 | " -0.0237, 0.0647],\n",
588 | " [-0.0452, 0.0037, 0.0861, -0.0780, -0.0127, -0.0346, -0.1745,\n",
589 | " -0.0082, -0.0197, 0.1225, -0.0097, -0.1630, -0.0391, -0.1403,\n",
590 | " 0.0047, -0.0301, 0.0380, 0.0689, 0.0560, -0.1448, 0.0415,\n",
591 | " -0.0765, 0.0268, -0.0093, -0.0013, -0.0536, 0.1745, 0.0851,\n",
592 | " 0.1459, 0.0460, 0.0361, -0.0080, -0.0371, -0.0591, -0.0021,\n",
593 | " -0.0232, -0.1775, -0.0694, 0.2404, 0.2391, 0.1042, 0.0004,\n",
594 | " -0.0638, -0.0455, -0.1881, -0.0685, -0.0020, -0.0074, 0.0631,\n",
595 | " 0.0061, -0.0239, 0.0014, 0.1150, 0.0216, -0.0501, -0.1460,\n",
596 | " 0.0868, 0.0058, 0.0238, -0.0468, 0.1932, -0.1387, 0.1993,\n",
597 | " -0.1703, 0.0708, -0.1123, -0.0992, 0.0359, 0.1450, -0.0621,\n",
598 | " 0.0057, -0.1338, -0.0617, -0.0484, 0.2194, 0.0195, 0.0898,\n",
599 | " 0.1415, -0.0739, 0.1688, -0.0542, -0.0715, -0.0008, 0.2597,\n",
600 | " -0.0670, -0.0884, 0.1334, -0.0423, -0.0628, 0.1305, 0.1561,\n",
601 | " 0.0818, 0.0802, -0.0710, -0.0751, 0.1020, -0.1029, -0.0448,\n",
602 | " -0.0334, -0.0624],\n",
603 | " [ 0.1935, -0.1132, -0.0640, 0.0775, -0.1758, -0.0304, -0.0325,\n",
604 | " -0.0281, 0.0281, 0.0623, 0.0806, -0.0406, -0.0136, 0.0796,\n",
605 | " -0.0326, -0.0139, 0.0010, -0.0886, 0.0059, 0.0165, -0.0906,\n",
606 | " 0.0594, -0.1253, 0.0824, -0.0844, -0.1503, -0.0779, 0.0404,\n",
607 | " -0.1407, 0.1032, 0.0024, 0.0365, -0.0915, 0.1998, -0.0369,\n",
608 | " 0.0309, 0.0805, 0.1246, 0.0020, -0.0381, 0.1327, 0.1264,\n",
609 | " -0.0554, 0.0573, -0.2027, 0.1270, 0.1991, 0.1515, 0.1560,\n",
610 | " -0.0187, 0.0049, -0.0156, 0.0914, 0.0688, 0.1862, 0.0877,\n",
611 | " -0.0511, -0.0257, 0.1836, -0.0647, 0.1148, -0.0108, 0.0483,\n",
612 | " 0.0404, 0.0392, 0.0261, 0.1588, 0.0488, 0.1369, 0.1264,\n",
613 | " -0.1223, -0.1065, 0.0010, -0.0762, 0.1735, -0.0131, 0.0517,\n",
614 | " -0.1084, -0.2028, -0.1659, -0.0472, 0.0291, 0.1472, 0.0132,\n",
615 | " -0.1019, -0.1006, -0.2296, 0.0467, 0.0651, -0.1048, -0.0653,\n",
616 | " -0.0799, -0.0006, -0.0679, -0.0992, 0.1511, 0.0393, -0.0768,\n",
617 | " 0.0185, -0.1413]], dtype=torch.float64, device='cuda:0'))"
618 | ]
619 | },
620 | "execution_count": 29,
621 | "metadata": {},
622 | "output_type": "execute_result"
623 | }
624 | ],
625 | "source": [
626 | "A = np.random.uniform(-40,40,[10,100])\n",
627 | "simple_randomized_torch_svd(A)"
628 | ]
629 | },
630 | {
631 | "cell_type": "code",
632 | "execution_count": 7,
633 | "metadata": {},
634 | "outputs": [],
635 | "source": [
636 | "np.set_printoptions(precision=4, suppress=True)"
637 | ]
638 | },
639 | {
640 | "cell_type": "code",
641 | "execution_count": 8,
642 | "metadata": {},
643 | "outputs": [],
644 | "source": [
645 | "m_array = np.array([100, 1000, 20000])\n",
646 | "n_array = m_array.copy()"
647 | ]
648 | },
649 | {
650 | "cell_type": "code",
651 | "execution_count": 9,
652 | "metadata": {},
653 | "outputs": [],
654 | "source": [
655 | "index = pd.MultiIndex.from_product(iterables=[m_array, n_array], names=['# rows', '# cols'])"
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": 10,
661 | "metadata": {},
662 | "outputs": [],
663 | "source": [
664 | "pd.set_option('display.float_format', '{:,.3f}'.format)\n",
665 | "df1, df2, df3, df4, df5, df6 = (pd.DataFrame(index=m_array, columns=n_array) for _ in range(6))"
671 | ]
672 | },
673 | {
674 | "cell_type": "code",
675 | "execution_count": 11,
676 | "metadata": {},
677 | "outputs": [
678 | {
679 | "name": "stderr",
680 | "output_type": "stream",
681 | "text": [
682 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:11: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n",
683 | " # This is added back by InteractiveShellApp.init_path()\n",
684 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:12: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n",
685 | " if sys.path[0] == '':\n",
686 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:13: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n",
687 | " del sys.path[0]\n",
688 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:14: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n",
689 | " \n",
690 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:15: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n",
691 | " from ipykernel import kernelapp as app\n",
692 | "/data/siavashmortezavi/anaconda3/envs/fastai/lib/python3.6/site-packages/ipykernel_launcher.py:16: FutureWarning: set_value is deprecated and will be removed in a future release. Please use .at[] or .iat[] accessors instead\n",
693 | " app.launch_new_instance()\n"
694 | ]
695 | }
696 | ],
697 | "source": [
698 | "simple_randomized_torch_svd(np.random.uniform(-40, 40, [10, 100]), 10)  # warm-up: keep one-time CUDA initialisation out of the timings below\n",
699 | "for m in m_array:\n",
700 | "    for n in n_array:\n",
701 | "        A = np.random.uniform(-40, 40, [m, n])\n",
702 | "        t1 = timeit.timeit('simple_randomized_svd(A, 10)', number=3, globals=globals())  # total seconds for 3 runs\n",
703 | "        t2 = timeit.timeit('decomposition.randomized_svd(A, 10)', number=3, globals=globals())\n",
704 | "        t3 = timeit.timeit('fbpca.pca(A, 10)', number=3, globals=globals())\n",
705 | "        t4 = timeit.timeit('simple_randomized_torch_svd(A, 10)', number=3, globals=globals())\n",
706 | "        t5 = timeit.timeit('randomized_svd(A, 5)', number=3, globals=globals())  # NOTE(review): second argument is 5 here vs 10 for the methods above -- confirm intended\n",
707 | "        t6 = timeit.timeit('randomized_svd_original(A, 5)', number=3, globals=globals())\n",
708 | "        df1.at[m, n] = t1  # DataFrame.set_value was deprecated and removed in pandas 1.0\n",
709 | "        df2.at[m, n] = t2\n",
710 | "        df3.at[m, n] = t3\n",
711 | "        df4.at[m, n] = t4\n",
712 | "        df5.at[m, n] = t5\n",
713 | "        df6.at[m, n] = t6"
714 | ]
715 | },
716 | {
717 | "cell_type": "code",
718 | "execution_count": 12,
719 | "metadata": {},
720 | "outputs": [
721 | {
722 | "data": {
723 | "text/html": [
724 | "
\n",
725 | "\n",
738 | "
\n",
739 | " \n",
740 | " \n",
741 | " | \n",
742 | " 100 | \n",
743 | " 1000 | \n",
744 | " 20000 | \n",
745 | "
\n",
746 | " \n",
747 | " \n",
748 | " \n",
749 | " | 100 | \n",
750 | " 0.007 | \n",
751 | " 0.021 | \n",
752 | " 0.018 | \n",
753 | "
\n",
754 | " \n",
755 | " | 1000 | \n",
756 | " 0.025 | \n",
757 | " 0.018 | \n",
758 | " 0.036 | \n",
759 | "
\n",
760 | " \n",
761 | " | 20000 | \n",
762 | " 0.014 | \n",
763 | " 0.042 | \n",
764 | " 0.364 | \n",
765 | "
\n",
766 | " \n",
767 | "
\n",
768 | "
"
769 | ],
770 | "text/plain": [
771 | " 100 1000 20000\n",
772 | "100 0.007 0.021 0.018\n",
773 | "1000 0.025 0.018 0.036\n",
774 | "20000 0.014 0.042 0.364"
775 | ]
776 | },
777 | "execution_count": 12,
778 | "metadata": {},
779 | "output_type": "execute_result"
780 | }
781 | ],
782 | "source": [
783 | "df1.div(3)"
784 | ]
785 | },
786 | {
787 | "cell_type": "code",
788 | "execution_count": 13,
789 | "metadata": {},
790 | "outputs": [
791 | {
792 | "data": {
793 | "text/html": [
794 | "\n",
795 | "\n",
808 | "
\n",
809 | " \n",
810 | " \n",
811 | " | \n",
812 | " 100 | \n",
813 | " 1000 | \n",
814 | " 20000 | \n",
815 | "
\n",
816 | " \n",
817 | " \n",
818 | " \n",
819 | " | 100 | \n",
820 | " 0.059 | \n",
821 | " 0.092 | \n",
822 | " 0.094 | \n",
823 | "
\n",
824 | " \n",
825 | " | 1000 | \n",
826 | " 0.101 | \n",
827 | " 0.104 | \n",
828 | " 0.511 | \n",
829 | "
\n",
830 | " \n",
831 | " | 20000 | \n",
832 | " 0.124 | \n",
833 | " 0.483 | \n",
834 | " 4.159 | \n",
835 | "
\n",
836 | " \n",
837 | "
\n",
838 | "
"
839 | ],
840 | "text/plain": [
841 | " 100 1000 20000\n",
842 | "100 0.059 0.092 0.094\n",
843 | "1000 0.101 0.104 0.511\n",
844 | "20000 0.124 0.483 4.159"
845 | ]
846 | },
847 | "execution_count": 13,
848 | "metadata": {},
849 | "output_type": "execute_result"
850 | }
851 | ],
852 | "source": [
853 | "df2.div(3)"
854 | ]
855 | },
856 | {
857 | "cell_type": "code",
858 | "execution_count": 14,
859 | "metadata": {},
860 | "outputs": [
861 | {
862 | "data": {
863 | "text/html": [
864 | "\n",
865 | "\n",
878 | "
\n",
879 | " \n",
880 | " \n",
881 | " | \n",
882 | " 100 | \n",
883 | " 1000 | \n",
884 | " 20000 | \n",
885 | "
\n",
886 | " \n",
887 | " \n",
888 | " \n",
889 | " | 100 | \n",
890 | " 0.024 | \n",
891 | " 0.077 | \n",
892 | " 0.084 | \n",
893 | "
\n",
894 | " \n",
895 | " | 1000 | \n",
896 | " 0.056 | \n",
897 | " 0.083 | \n",
898 | " 0.214 | \n",
899 | "
\n",
900 | " \n",
901 | " | 20000 | \n",
902 | " 0.090 | \n",
903 | " 0.225 | \n",
904 | " 1.424 | \n",
905 | "
\n",
906 | " \n",
907 | "
\n",
908 | "
"
909 | ],
910 | "text/plain": [
911 | " 100 1000 20000\n",
912 | "100 0.024 0.077 0.084\n",
913 | "1000 0.056 0.083 0.214\n",
914 | "20000 0.090 0.225 1.424"
915 | ]
916 | },
917 | "execution_count": 14,
918 | "metadata": {},
919 | "output_type": "execute_result"
920 | }
921 | ],
922 | "source": [
923 | "df3.div(3)"
924 | ]
925 | },
926 | {
927 | "cell_type": "code",
928 | "execution_count": 15,
929 | "metadata": {},
930 | "outputs": [
931 | {
932 | "data": {
933 | "text/html": [
934 | "\n",
935 | "\n",
948 | "
\n",
949 | " \n",
950 | " \n",
951 | " | \n",
952 | " 100 | \n",
953 | " 1000 | \n",
954 | " 20000 | \n",
955 | "
\n",
956 | " \n",
957 | " \n",
958 | " \n",
959 | " | 100 | \n",
960 | " 1.202 | \n",
961 | " 0.014 | \n",
962 | " 0.027 | \n",
963 | "
\n",
964 | " \n",
965 | " | 1000 | \n",
966 | " 0.019 | \n",
967 | " 0.027 | \n",
968 | " 0.096 | \n",
969 | "
\n",
970 | " \n",
971 | " | 20000 | \n",
972 | " 0.029 | \n",
973 | " 0.100 | \n",
974 | " 2.756 | \n",
975 | "
\n",
976 | " \n",
977 | "
\n",
978 | "
"
979 | ],
980 | "text/plain": [
981 | " 100 1000 20000\n",
982 | "100 1.202 0.014 0.027\n",
983 | "1000 0.019 0.027 0.096\n",
984 | "20000 0.029 0.100 2.756"
985 | ]
986 | },
987 | "execution_count": 15,
988 | "metadata": {},
989 | "output_type": "execute_result"
990 | }
991 | ],
992 | "source": [
993 | "df4/3"
994 | ]
995 | },
996 | {
997 | "cell_type": "code",
998 | "execution_count": 16,
999 | "metadata": {},
1000 | "outputs": [
1001 | {
1002 | "data": {
1003 | "text/html": [
1004 | "\n",
1005 | "\n",
1018 | "
\n",
1019 | " \n",
1020 | " \n",
1021 | " | \n",
1022 | " 100 | \n",
1023 | " 1000 | \n",
1024 | " 20000 | \n",
1025 | "
\n",
1026 | " \n",
1027 | " \n",
1028 | " \n",
1029 | " | 100 | \n",
1030 | " 0.030 | \n",
1031 | " 0.050 | \n",
1032 | " 0.056 | \n",
1033 | "
\n",
1034 | " \n",
1035 | " | 1000 | \n",
1036 | " 0.020 | \n",
1037 | " 0.028 | \n",
1038 | " 0.252 | \n",
1039 | "
\n",
1040 | " \n",
1041 | " | 20000 | \n",
1042 | " 0.050 | \n",
1043 | " 0.238 | \n",
1044 | " 2.225 | \n",
1045 | "
\n",
1046 | " \n",
1047 | "
\n",
1048 | "
"
1049 | ],
1050 | "text/plain": [
1051 | " 100 1000 20000\n",
1052 | "100 0.030 0.050 0.056\n",
1053 | "1000 0.020 0.028 0.252\n",
1054 | "20000 0.050 0.238 2.225"
1055 | ]
1056 | },
1057 | "execution_count": 16,
1058 | "metadata": {},
1059 | "output_type": "execute_result"
1060 | }
1061 | ],
1062 | "source": [
1063 | "df5/3"
1064 | ]
1065 | },
1066 | {
1067 | "cell_type": "code",
1068 | "execution_count": 17,
1069 | "metadata": {},
1070 | "outputs": [
1071 | {
1072 | "data": {
1073 | "text/html": [
1074 | "\n",
1075 | "\n",
1088 | "
\n",
1089 | " \n",
1090 | " \n",
1091 | " | \n",
1092 | " 100 | \n",
1093 | " 1000 | \n",
1094 | " 20000 | \n",
1095 | "
\n",
1096 | " \n",
1097 | " \n",
1098 | " \n",
1099 | " | 100 | \n",
1100 | " 0.026 | \n",
1101 | " 0.043 | \n",
1102 | " 0.046 | \n",
1103 | "
\n",
1104 | " \n",
1105 | " | 1000 | \n",
1106 | " 0.047 | \n",
1107 | " 0.028 | \n",
1108 | " 0.181 | \n",
1109 | "
\n",
1110 | " \n",
1111 | " | 20000 | \n",
1112 | " 0.046 | \n",
1113 | " 0.173 | \n",
1114 | " 1.198 | \n",
1115 | "
\n",
1116 | " \n",
1117 | "
\n",
1118 | "
"
1119 | ],
1120 | "text/plain": [
1121 | " 100 1000 20000\n",
1122 | "100 0.026 0.043 0.046\n",
1123 | "1000 0.047 0.028 0.181\n",
1124 | "20000 0.046 0.173 1.198"
1125 | ]
1126 | },
1127 | "execution_count": 17,
1128 | "metadata": {},
1129 | "output_type": "execute_result"
1130 | }
1131 | ],
1132 | "source": [
1133 | "df6/3"
1134 | ]
1135 | },
1136 | {
1137 | "cell_type": "code",
1138 | "execution_count": null,
1139 | "metadata": {
1140 | "collapsed": true
1141 | },
1142 | "outputs": [],
1143 | "source": []
1144 | }
1145 | ],
1146 | "metadata": {
1147 | "kernelspec": {
1148 | "display_name": "Python 3",
1149 | "language": "python",
1150 | "name": "python3"
1151 | },
1152 | "language_info": {
1153 | "codemirror_mode": {
1154 | "name": "ipython",
1155 | "version": 3
1156 | },
1157 | "file_extension": ".py",
1158 | "mimetype": "text/x-python",
1159 | "name": "python",
1160 | "nbconvert_exporter": "python",
1161 | "pygments_lexer": "ipython3",
1162 | "version": "3.6.3"
1163 | }
1164 | },
1165 | "nbformat": 4,
1166 | "nbformat_minor": 2
1167 | }
1168 |
--------------------------------------------------------------------------------