├── LICENSE ├── .gitignore ├── README.md ├── RMT4ELM.py └── RMT4ELM.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Zhenyu LIAO 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RMT4ELM 2 | *A Random Matrix Approach to Extreme Learning Machine* 3 | 4 | This page contains a simple demo using [Python 3](https://www.python.org/) of the theoretical 
results in the following paper: 5 | 6 | [A Random Matrix Approach to Neural Networks](https://arxiv.org/abs/1702.05419) 7 | 8 | where recent advances in random matrix theory are used to analyze the performance of randomly connected single-layer neural networks (also referred to in the literature as *extreme learning machines*). 9 | 10 | ## About the code 11 | Comparison between theory and practice is available for data from 12 | 13 | * the [MNIST](http://yann.lecun.com/exdb/mnist/) database 14 | * a Gaussian mixture model 15 | 16 | for a variety of commonly used activation functions. 17 | 18 | ## Dependencies 19 | Running this code requires the following: 20 | 21 | * [Python](https://www.python.org/): tested with version 3.6 22 | * [Numpy](http://www.numpy.org/) and [Scipy](https://www.scipy.org/) 23 | * [Matplotlib](http://matplotlib.org/) for visualization 24 | * [Scikit-learn](http://scikit-learn.org/stable/) for the MNIST dataset 25 | 26 | We strongly recommend using [Jupyter notebook](http://jupyter.org/) to see the illustrations directly in your web browser: [here](http://nbviewer.jupyter.org/github/Zhenyu-LIAO/RMT4ELM/blob/master/RMT4ELM.ipynb). 27 | 28 | ## Contact information 29 | * Zhenyu LIAO 30 | * Ph.D. student at CentraleSupelec, Paris, France 31 | * Website: [https://zhenyu-liao.github.io/](https://zhenyu-liao.github.io/) 32 | * E-mail: [zhenyu.liao@l2s.centralesupelec.fr](mailto:zhenyu.liao@l2s.centralesupelec.fr) 33 | * Prof.
Romain COUILLET 34 | * Professor at CentraleSupelec, Paris, France 35 | * Website: [http://romaincouillet.hebfree.org/](http://romaincouillet.hebfree.org/) 36 | * E-mail: [romain.couillet@centralesupelec.fr](mailto:romain.couillet@centralesupelec.fr) 37 | 38 | 39 | -------------------------------------------------------------------------------- /RMT4ELM.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # In[7]: 5 | 6 | import math 7 | import scipy.special,scipy.linalg 8 | import numpy as np 9 | import time 10 | from matplotlib import pyplot as plt 11 | from sklearn.datasets import fetch_mldata 12 | 13 | 14 | # # Generate Data (MNIST or Gaussian mixture) 15 | 16 | # In[8]: 17 | 18 | def gen_data(testcase,Tr,Te,prop,means=None,covs=None): 19 | rng = np.random 20 | 21 | if testcase is 'MNIST': 22 | mnist=fetch_mldata('MNIST original') 23 | X,y = mnist.data,mnist.target 24 | X_train_full, X_test_full = X[:60000], X[60000:] 25 | y_train_full, y_test_full = y[:60000], y[60000:] 26 | 27 | selected_target = [7,9] 28 | K=len(selected_target) 29 | X_train = np.array([]).reshape(p,0) 30 | X_test = np.array([]).reshape(p,0) 31 | 32 | y_train = [] 33 | y_test = [] 34 | ind=0 35 | for i in selected_target: 36 | locate_target_train = np.where(y_train_full==i)[0][range(np.int(prop[ind]*Tr))] 37 | locate_target_test = np.where(y_test_full==i)[0][range(np.int(prop[ind]*Te))] 38 | X_train = np.concatenate( (X_train,X_train_full[locate_target_train].T),axis=1) 39 | y_train = np.concatenate( (y_train,2*(ind-K/2+.5)*np.ones(np.int(Tr*prop[ind]))) ) 40 | X_test = np.concatenate( (X_test,X_test_full[locate_target_test].T),axis=1) 41 | y_test = np.concatenate( (y_test,2*(ind-K/2+.5)*np.ones(np.int(Te*prop[ind]))) ) 42 | ind+=1 43 | 44 | X_train = X_train - np.mean(X_train,axis=1).reshape(p,1) 45 | X_train = X_train*np.sqrt(784)/np.sqrt(np.sum(X_train**2,(0,1))/Tr) 46 | 47 | X_test = X_test - 
np.mean(X_test,axis=1).reshape(p,1) 48 | X_test = X_test*np.sqrt(784)/np.sqrt(np.sum(X_test**2,(0,1))/Te) 49 | 50 | else: 51 | X_train = np.array([]).reshape(p,0) 52 | X_test = np.array([]).reshape(p,0) 53 | y_train = [] 54 | y_test = [] 55 | K = len(prop) 56 | for i in range(K): 57 | X_train = np.concatenate((X_train,rng.multivariate_normal(means[i],covs[i],size=np.int(Tr*prop[i])).T),axis=1) 58 | X_test = np.concatenate((X_test, rng.multivariate_normal(means[i],covs[i],size=np.int(Te*prop[i])).T),axis=1) 59 | y_train = np.concatenate( (y_train,2*(i-K/2+.5)*np.ones(np.int(Tr*prop[i]))) ) 60 | y_test = np.concatenate( (y_test,2*(i-K/2+.5)*np.ones(np.int(Te*prop[i]))) ) 61 | 62 | X_train = X_train/math.sqrt(p) 63 | X_test = X_test/math.sqrt(p) 64 | 65 | return X_train, X_test, y_train, y_test 66 | 67 | 68 | 69 | # # Generate $\sigma(\cdot)$ activation functions 70 | 71 | # In[9]: 72 | 73 | def gen_sig(fun,Z,polynom=None): 74 | 75 | if fun is 'poly2': 76 | sig = polynom[0]*Z**2+polynom[1]*Z+polynom[2] 77 | elif fun is 'ReLu': 78 | sig = np.maximum(Z,0) 79 | elif fun is 'sign': 80 | sig = np.sign(Z) 81 | elif fun is 'posit': 82 | sig = (Z>0).astype(int) 83 | elif fun is 'erf': 84 | sig = scipy.special.erf(Z) 85 | elif fun is 'cos': 86 | sig = np.cos(Z) 87 | elif fun is 'abs': 88 | sig = np.abs(Z) 89 | 90 | return sig 91 | 92 | 93 | # # Generate matrices $\Phi_{AB}$ 94 | 95 | # In[10]: 96 | 97 | def gen_Phi(fun,A,B,polynom=None,distrib=None,nu=None): 98 | normA = np.sqrt(np.sum(A**2,axis=0)) 99 | normB = np.sqrt(np.sum(B**2,axis=0)) 100 | 101 | AB = A.T @ B 102 | angle_AB = np.minimum( (1/normA).reshape((len(normA),1)) * AB * (1/normB).reshape( (1,len(normB)) ) ,1.) 
103 | 104 | if fun is 'poly2': 105 | mom = {'gauss': [1,0,3],'bern': [1,0,1],'bern_skewed': [1,-2/math.sqrt(3),7/3],'student':[1,0,6/(nu-4)+3]} 106 | A2 = A**2 107 | B2 = B**2 108 | Phi = polynom[0]**2*(mom[distrib][0]**2*(2*AB**2+(normA**2).reshape((len(normA),1))*(normB**2).reshape((1,len(normB))) )+(mom[distrib][2]-3*mom[distrib][0]**2)*(A2.T@B2))+polynom[1]**2*mom[distrib][0]*AB+polynom[1]*polynom[0]*mom[distrib][1]*(A2.T@B+A.T@B2)+polynom[2]*polynom[0]*mom[distrib][0]*( (normA**2).reshape( (len(normA),1) )+(normB**2).reshape( (1,len(normB)) ) )+polynom[2]**2 109 | 110 | elif fun is 'ReLu': 111 | Phi = 1/(2*math.pi)* normA.reshape((len(normA),1)) * (angle_AB*np.arccos(-angle_AB)+np.sqrt(1-angle_AB**2)) * normB.reshape( (1,len(normB)) ) 112 | 113 | elif fun is 'abs': 114 | Phi = 2/math.pi* normA.reshape((len(normA),1)) * (angle_AB*np.arcsin(angle_AB)+np.sqrt(1-angle_AB**2)) * normB.reshape( (1,len(normB)) ) 115 | 116 | elif fun is 'posit': 117 | Phi = 1/2-1/(2*math.pi)*np.arccos(angle_AB) 118 | 119 | elif fun is 'sign': 120 | Phi = 1-2/math.pi*np.arccos(angle_AB) 121 | 122 | elif fun is 'cos': 123 | Phi = np.exp(-.5*( (normA**2).reshape((len(normA),1))+(normB**2).reshape((1,len(normB))) ))*np.cosh(AB) 124 | 125 | elif fun is 'erf': 126 | Phi = 2/math.pi*np.arcsin(2*AB/np.sqrt((1+2*(normA**2).reshape((len(normA),1)))*(1+2*(normB**2).reshape((1,len(normB)))))) 127 | 128 | return Phi 129 | 130 | 131 | # # Generate $E_{\rm train}$ and $E_{\rm test}$ 132 | 133 | # In[11]: 134 | 135 | def gen_E_th(): 136 | d=0 137 | dt=-1 138 | 139 | while np.abs(d-dt)>1e-6: 140 | dt=d 141 | d=np.mean(L/(L*n/Tr/(1+d)+gamma)) 142 | 143 | L_psi = L*n/Tr/(1+d) 144 | L_bQ = 1/(L_psi+gamma) 145 | 146 | # E_train 147 | E_train_th = gamma**2*np.mean(Uy_train**2*L_bQ**2*(1/n*np.sum(L_psi*L_bQ**2)/(1-1/n*np.sum(L_psi**2*L_bQ**2))*L_psi+1)) 148 | 149 | #E_test 150 | E_test_th = 
np.mean((y_test-n/Tr/(1+d)*UPhi_cross.T@(L_bQ*Uy_train))**2)+(1/n*np.sum(Uy_train**2*L_psi*L_bQ**2))/(1-1/n*np.sum(L_psi**2*L_bQ**2))*(np.mean((n/Tr/(1+d))*D_Phi_test)-Tr/Te*np.mean( (n/Tr/(1+d))**2*D_UPhi_cross2*(1+gamma*L_bQ)*L_bQ)) 151 | 152 | return E_train_th,E_test_th 153 | 154 | 155 | 156 | # # Main code 157 | 158 | # In[12]: 159 | 160 | ## Parameter setting 161 | n=512 162 | p=256 163 | Tr=1024 # Training length 164 | Te=Tr # Testing length 165 | 166 | prop=[.5,.5] # proportions of each class 167 | K=len(prop) # number of data classes 168 | 169 | gammas = [10**x for x in np.arange(-4,2.25,.25)] # Range of gamma for simulations 170 | 171 | testcase='MNIST' # testcase for simulation, among 'iid','means','var','orth','mixed',MNIST' 172 | sigma='ReLu' # activation function, among 'ReLu', 'sign', 'posit', 'erf', 'poly2', 'cos', 'abs' 173 | 174 | 175 | # Only used for sigma='poly2' 176 | polynom=[-.5,0,1] # sigma(t)=polynom[0].t²+polynom[1].t+polynom[2] 177 | distrib='student' # distribution of Wij, among 'gauss','bern','bern_skewed','student' 178 | 179 | # Only used for sigma='poly2' and distrib='student' 180 | nu=7 # degrees of freedom of Student-t distribution 181 | 182 | 183 | ## Generate X_train,X_test,y_train,y_test 184 | if testcase is 'MNIST': 185 | p=784 186 | X_train,X_test,y_train,y_test = gen_data(testcase,Tr,Te,prop) 187 | else: 188 | means=[] 189 | covs=[] 190 | if testcase is 'iid': 191 | for i in range(K): 192 | means.append(np.zeros(p)) 193 | covs.append(np.eye(p)) 194 | elif testcase is 'means': 195 | for i in range(K): 196 | means.append( np.concatenate( (np.zeros(i),4*np.ones(1),np.zeros(p-i-1)) ) ) 197 | covs.append(np.eye(p)) 198 | elif testcase is 'var': 199 | for i in range(K): 200 | means.append(np.zeros(p)) 201 | covs.append(np.eye(p)*(1+8*i/np.sqrt(p))) 202 | elif testcase is 'orth': 203 | for i in range(K): 204 | means.append(np.zeros(p)) 205 | covs.append( np.diag(np.concatenate( 
(np.ones(np.int(np.sum(prop[0:i]*p))),4*np.ones(np.int(prop[i]*p)),np.ones(np.int(np.sum(prop[i+1:]*p))) ) ) )) 206 | elif testcase is 'mixed': 207 | for i in range(K): 208 | means.append( np.concatenate( (np.zeros(i),4*np.ones(1),np.zeros(p-i-1)) ) ) 209 | covs.append((1+4*i/np.sqrt(p))*scipy.linalg.toeplitz( [(.4*i)**x for x in range(p)] )) 210 | 211 | X_train,X_test,y_train,y_test = gen_data(testcase,Tr,Te,prop,means,covs) 212 | 213 | ##Theory 214 | start_th_calculus = time.time() 215 | 216 | Phi=gen_Phi(sigma,X_train,X_train,polynom,distrib,nu) 217 | L,U = np.linalg.eigh(Phi) 218 | Phi_cross = gen_Phi(sigma,X_train,X_test,polynom,distrib,nu) 219 | UPhi_cross = U.T@Phi_cross 220 | D_UPhi_cross2 = np.sum(UPhi_cross**2,axis=1) 221 | 222 | Phi_test = gen_Phi(sigma,X_test,X_test,polynom,distrib,nu) 223 | D_Phi_test = np.diag(Phi_test) 224 | Uy_train = U.T@y_train 225 | 226 | E_train_th=np.zeros(len(gammas)) 227 | E_test_th =np.zeros(len(gammas)) 228 | 229 | ind=0 230 | for gamma in gammas: 231 | E_train_th[ind],E_test_th[ind] = gen_E_th() 232 | ind+=1 233 | 234 | end_th_calculus = time.time() 235 | 236 | m,s = divmod(end_th_calculus-start_th_calculus,60) 237 | print('Time for Theoretical Computation {:d}min {:d}s'.format( int(m),math.ceil(s) )) 238 | 239 | ## Simulations 240 | start_sim_calculus = time.time() 241 | 242 | loops = 10 # Number of generations of W to be averaged over 243 | 244 | E_train=np.zeros(len(gammas)) 245 | E_test =np.zeros(len(gammas)) 246 | 247 | 248 | rng = np.random 249 | 250 | for loop in range(loops): 251 | if sigma is 'poly2': 252 | if distrib is 'student': 253 | W = rng.standard_t(nu,n*p).reshape(n,p)/np.sqrt(nu/(nu-2)) 254 | elif distrib is 'bern': 255 | W = np.sign(rng.randn(n,p)) 256 | elif distrib is 'bern_skewed': 257 | Z = rng.rand(n,p) 258 | W = (Z<.75)/np.sqrt(3)+(Z>.75)*(-np.sqrt(3)) 259 | elif distrib is 'gauss': 260 | W = rng.randn(n,p) 261 | else: 262 | W = rng.randn(n,p) 263 | 264 | S_train = gen_sig(sigma,W @ 
X_train,polynom) 265 | SS = S_train.T @ S_train 266 | 267 | S_test = gen_sig(sigma, W @ X_test,polynom) 268 | 269 | ind = 0 270 | for gamma in gammas: 271 | 272 | inv_resolv = np.linalg.solve( SS/Tr+gamma*np.eye(Tr),y_train) 273 | beta = S_train @ inv_resolv/Tr 274 | z_train = S_train.T @ beta 275 | 276 | z_test = S_test.T @ beta 277 | 278 | 279 | E_train[ind] += gamma**2*np.linalg.norm(inv_resolv)**2/Tr/loops 280 | E_test[ind] += np.linalg.norm(y_test-z_test)**2/Te/loops 281 | 282 | ind+=1 283 | 284 | end_sim_calculus = time.time() 285 | 286 | m,s = divmod(end_sim_calculus-start_sim_calculus,60) 287 | print('Time for Simulations Computation {:d}min {:d}s'.format( int(m),math.ceil(s) )) 288 | 289 | #Plots 290 | p11,=plt.plot(gammas,E_train,'bo') 291 | p21,=plt.plot(gammas,E_test,'ro') 292 | 293 | p12,=plt.plot(gammas,E_train_th,'b-') 294 | p22,=plt.plot(gammas,E_test_th,'r-') 295 | plt.xscale('log') 296 | plt.yscale('log') 297 | plt.xlim( gammas[0],gammas[-1] ) 298 | plt.ylim(np.amin( (E_train,E_train_th) ),np.amax( (E_test,E_test_th) )) 299 | plt.legend([p11,p12,p21,p22], ["E_train", "E_train Th","E_test","E_test Th"],bbox_to_anchor=(1, 1), loc='upper left') 300 | plt.show() 301 | 302 | -------------------------------------------------------------------------------- /RMT4ELM.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import math\n", 12 | "import scipy.special,scipy.linalg\n", 13 | "import numpy as np\n", 14 | "import time\n", 15 | "from matplotlib import pyplot as plt\n", 16 | "from sklearn.datasets import fetch_mldata" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "# Generate Data (MNIST or Gaussian mixture)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 8, 29 | "metadata": { 30 | 
"collapsed": true 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "def gen_data(testcase,Tr,Te,prop,means=None,covs=None):\n", 35 | " rng = np.random\n", 36 | " \n", 37 | " if testcase is 'MNIST':\n", 38 | " mnist=fetch_mldata('MNIST original')\n", 39 | " X,y = mnist.data,mnist.target\n", 40 | " X_train_full, X_test_full = X[:60000], X[60000:]\n", 41 | " y_train_full, y_test_full = y[:60000], y[60000:]\n", 42 | "\n", 43 | " selected_target = [7,9]\n", 44 | " K=len(selected_target)\n", 45 | " X_train = np.array([]).reshape(p,0)\n", 46 | " X_test = np.array([]).reshape(p,0) \n", 47 | " \n", 48 | " y_train = []\n", 49 | " y_test = []\n", 50 | " ind=0\n", 51 | " for i in selected_target:\n", 52 | " locate_target_train = np.where(y_train_full==i)[0][range(np.int(prop[ind]*Tr))]\n", 53 | " locate_target_test = np.where(y_test_full==i)[0][range(np.int(prop[ind]*Te))]\n", 54 | " X_train = np.concatenate( (X_train,X_train_full[locate_target_train].T),axis=1)\n", 55 | " y_train = np.concatenate( (y_train,2*(ind-K/2+.5)*np.ones(np.int(Tr*prop[ind]))) )\n", 56 | " X_test = np.concatenate( (X_test,X_test_full[locate_target_test].T),axis=1)\n", 57 | " y_test = np.concatenate( (y_test,2*(ind-K/2+.5)*np.ones(np.int(Te*prop[ind]))) )\n", 58 | " ind+=1 \n", 59 | " \n", 60 | " X_train = X_train - np.mean(X_train,axis=1).reshape(p,1)\n", 61 | " X_train = X_train*np.sqrt(784)/np.sqrt(np.sum(X_train**2,(0,1))/Tr)\n", 62 | " \n", 63 | " X_test = X_test - np.mean(X_test,axis=1).reshape(p,1)\n", 64 | " X_test = X_test*np.sqrt(784)/np.sqrt(np.sum(X_test**2,(0,1))/Te)\n", 65 | " \n", 66 | " else:\n", 67 | " X_train = np.array([]).reshape(p,0)\n", 68 | " X_test = np.array([]).reshape(p,0) \n", 69 | " y_train = []\n", 70 | " y_test = []\n", 71 | " K = len(prop)\n", 72 | " for i in range(K): \n", 73 | " X_train = np.concatenate((X_train,rng.multivariate_normal(means[i],covs[i],size=np.int(Tr*prop[i])).T),axis=1)\n", 74 | " X_test = np.concatenate((X_test, 
rng.multivariate_normal(means[i],covs[i],size=np.int(Te*prop[i])).T),axis=1)\n", 75 | " y_train = np.concatenate( (y_train,2*(i-K/2+.5)*np.ones(np.int(Tr*prop[i]))) )\n", 76 | " y_test = np.concatenate( (y_test,2*(i-K/2+.5)*np.ones(np.int(Te*prop[i]))) ) \n", 77 | " \n", 78 | " X_train = X_train/math.sqrt(p)\n", 79 | " X_test = X_test/math.sqrt(p)\n", 80 | " \n", 81 | " return X_train, X_test, y_train, y_test\n", 82 | " " 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "# Generate $\\sigma(\\cdot)$ activation functions" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 9, 95 | "metadata": { 96 | "collapsed": true 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "def gen_sig(fun,Z,polynom=None):\n", 101 | " \n", 102 | " if fun is 'poly2':\n", 103 | " sig = polynom[0]*Z**2+polynom[1]*Z+polynom[2]\n", 104 | " elif fun is 'ReLu':\n", 105 | " sig = np.maximum(Z,0)\n", 106 | " elif fun is 'sign':\n", 107 | " sig = np.sign(Z)\n", 108 | " elif fun is 'posit':\n", 109 | " sig = (Z>0).astype(int)\n", 110 | " elif fun is 'erf':\n", 111 | " sig = scipy.special.erf(Z)\n", 112 | " elif fun is 'cos':\n", 113 | " sig = np.cos(Z)\n", 114 | " elif fun is 'abs':\n", 115 | " sig = np.abs(Z)\n", 116 | " \n", 117 | " return sig" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "# Generate matrices $\\Phi_{AB}$" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 10, 130 | "metadata": { 131 | "collapsed": false 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "def gen_Phi(fun,A,B,polynom=None,distrib=None,nu=None):\n", 136 | " normA = np.sqrt(np.sum(A**2,axis=0))\n", 137 | " normB = np.sqrt(np.sum(B**2,axis=0))\n", 138 | " \n", 139 | " AB = A.T @ B\n", 140 | " angle_AB = np.minimum( (1/normA).reshape((len(normA),1)) * AB * (1/normB).reshape( (1,len(normB)) ) ,1.)\n", 141 | " \n", 142 | " if fun is 'poly2':\n", 143 | " mom = {'gauss': 
[1,0,3],'bern': [1,0,1],'bern_skewed': [1,-2/math.sqrt(3),7/3],'student':[1,0,6/(nu-4)+3]}\n", 144 | " A2 = A**2\n", 145 | " B2 = B**2\n", 146 | " Phi = polynom[0]**2*(mom[distrib][0]**2*(2*AB**2+(normA**2).reshape((len(normA),1))*(normB**2).reshape((1,len(normB))) )+(mom[distrib][2]-3*mom[distrib][0]**2)*(A2.T@B2))+polynom[1]**2*mom[distrib][0]*AB+polynom[1]*polynom[0]*mom[distrib][1]*(A2.T@B+A.T@B2)+polynom[2]*polynom[0]*mom[distrib][0]*( (normA**2).reshape( (len(normA),1) )+(normB**2).reshape( (1,len(normB)) ) )+polynom[2]**2\n", 147 | " \n", 148 | " elif fun is 'ReLu':\n", 149 | " Phi = 1/(2*math.pi)* normA.reshape((len(normA),1)) * (angle_AB*np.arccos(-angle_AB)+np.sqrt(1-angle_AB**2)) * normB.reshape( (1,len(normB)) )\n", 150 | " \n", 151 | " elif fun is 'abs':\n", 152 | " Phi = 2/math.pi* normA.reshape((len(normA),1)) * (angle_AB*np.arcsin(angle_AB)+np.sqrt(1-angle_AB**2)) * normB.reshape( (1,len(normB)) )\n", 153 | " \n", 154 | " elif fun is 'posit':\n", 155 | " Phi = 1/2-1/(2*math.pi)*np.arccos(angle_AB)\n", 156 | " \n", 157 | " elif fun is 'sign':\n", 158 | " Phi = 1-2/math.pi*np.arccos(angle_AB)\n", 159 | " \n", 160 | " elif fun is 'cos':\n", 161 | " Phi = np.exp(-.5*( (normA**2).reshape((len(normA),1))+(normB**2).reshape((1,len(normB))) ))*np.cosh(AB)\n", 162 | " \n", 163 | " elif fun is 'erf':\n", 164 | " Phi = 2/math.pi*np.arcsin(2*AB/np.sqrt((1+2*(normA**2).reshape((len(normA),1)))*(1+2*(normB**2).reshape((1,len(normB))))))\n", 165 | "\n", 166 | " return Phi" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "# Generate $E_{\\rm train}$ and $E_{\\rm test}$" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 11, 179 | "metadata": { 180 | "collapsed": false 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | "def gen_E_th():\n", 185 | " d=0\n", 186 | " dt=-1\n", 187 | "\n", 188 | " while np.abs(d-dt)>1e-6:\n", 189 | " dt=d\n", 190 | " d=np.mean(L/(L*n/Tr/(1+d)+gamma))\n", 
191 | " \n", 192 | " L_psi = L*n/Tr/(1+d)\n", 193 | " L_bQ = 1/(L_psi+gamma)\n", 194 | "\n", 195 | " # E_train\n", 196 | " E_train_th = gamma**2*np.mean(Uy_train**2*L_bQ**2*(1/n*np.sum(L_psi*L_bQ**2)/(1-1/n*np.sum(L_psi**2*L_bQ**2))*L_psi+1)) \n", 197 | " \n", 198 | " #E_test\n", 199 | " E_test_th = np.mean((y_test-n/Tr/(1+d)*UPhi_cross.T@(L_bQ*Uy_train))**2)+(1/n*np.sum(Uy_train**2*L_psi*L_bQ**2))/(1-1/n*np.sum(L_psi**2*L_bQ**2))*(np.mean((n/Tr/(1+d))*D_Phi_test)-Tr/Te*np.mean( (n/Tr/(1+d))**2*D_UPhi_cross2*(1+gamma*L_bQ)*L_bQ))\n", 200 | " \n", 201 | " return E_train_th,E_test_th\n", 202 | " " 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "# Main code" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 12, 215 | "metadata": { 216 | "collapsed": false 217 | }, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "Time for Theoretical Computation 0min 1s\n", 224 | "Time for Simulations Computation 0min 16s\n" 225 | ] 226 | }, 227 | { 228 | "data": { 229 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAdcAAAEACAYAAADhvzxWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl8U1XaB/DfSUtbyr4vBdpXoEAFAalUZxhBEAQFFMuA\nUkVAQQTGFwRGpYgIgqIMCoKUvSgFYRCYsqgviizjArZYsOyLlEVK2VqWLjTNef84RNI2KUmb5N4k\nv+/nk0/bm5vkuZTk6Tn33OcRUkoQERGR8xi0DoCIiMjbMLkSERE5GZMrERGRkzG5EhERORmTKxER\nkZMxuRIRETkZkysREZGTMbkSERE5GZMrERGRkzG5EhEROZm/1gGUpGbNmjIsLEzrMIiI9CsrC/L4\nCQhI3EIA/sAtXJJSaB2Wr9N1cg0LC0NSUpLWYRAR6YuUwNdfI3/iZJQ7fhyn0Agz8AaqIBPfYoLW\n0RF0nlyJiMjC7aRaMGky/JL24A8RimliETLC/4IZR55CMxxDpNYxEgCecyUi0j8pga++ginqQeDx\nx3FubwZewiKM63UUo1NfwobDEbj4yts46xeqdaR0m9Bzy7nIyEjJaWEi8jkJCUBsLJCWBtSuDVmx\nEsTJEzjjH4Z3jLE49beBePeDADz4YPGHCiGSpZQcwGqM08JERHqSkADj4KHwz8+BBCAyMvBHhj/e\nxiKktBiIqTMC0L07IBxYspScnFzb399/MYCW4Iyls5gApBqNxpfatWuXUfROJlciIh258eqbqJif\nAwBIR128hanYhQ74Z8X5WJgSAEMpUqO/v//iunXrtqhVq9ZVg8Gg3+lKD2IymcTFixcj0tPTFwPo\nXfR+Xf4FI4ToJYRYmJWVpXUoRETus38/Klw5AxOAyXgbD2AP2iAFv+E+DL7xSakS620ta9WqdY2J\n1XkMBoOsVatWFtRsQPH73RyPXaSUG6WUw6pUqaJ1KEREriclsHAhZFQUrqA6umAb0hCKQ4jAKMxD\nAPJxGo3K8goGJlbnu/1vajWP6jK5EhH5jGvXgAEDgJdfxg9+DyMCB/AENmEphqASbgAAbiIYs2pM\n0zhQcgSTKxGRVvbuBdq1g2nNv/FO4HT0CfwKL02oiwPl7kcaQmGCwCmEYlS5hYiaHeO2sOLiUL1+\nfbQyGNCufn20iotD9bI+p5+fX7vmzZtHmG8TJkyoa2vfKVOm1L5+/brD+Wn06NH1N2zYUKlskToH\nFzQREbmblMC8eZBjx+JGUC08btqO68074JcNQFgYkBARg06xMTh9GmjUCJg2DYhxU26Ni0P1MWMQ\nmpurBl/nzyNgzBiEAsDw4bhS2ucNDAw0HT58+KA9+y5YsKDO0KFDr1SqVMlU9D6j0Qh/f+up6+OP\nP/6jtPE5G0euRETulJkJ9O0L/OMf+LVmV/zPtRSE9O+AH35QiRVQifTUKcBkUl/dlVgBYMoUhJgT\nq1luLgxTpiDEHa//7rvv1s7IyCjXsWPH8KioqHAACA4Objt06NAGzZo1i/juu+8qjhs3rl7Lli1b\nNG3a9N5nn3021GRSOTg6Ojps2bJl1QAgJCSk1ZgxY+pHRES0CA8Pj/j111+D3BG/GZMrEZG77NkD\ntG0LmZiIWfU/xAN/JGL8+zWxahVQoYLWwSnp6QhwZLu98vLyDJbTwosWLapmbb+JEydm1K5dO3/H\njh1Hd+/efRQAcnJyDFFRUTePHDly8LHHHrsxfvz4jNTU1EPHjh07kJOTY/jiiy+srn6tWbOm8eDB\ng4eGDBly8f33369TlvgdxWlhIiJXsay0VK0acO0acms2wFMVduHnmw9i0xagRw+tgyysbl3cOn++\neCKtWxe3yvK8jkwLF+Xn54dBgwZdNf/81VdfVZo1a1bd3NxcQ2Zmpn9EREQOgGLXbg4YMOAqALRv\n3z47MTHRajJ3FY5ciYhcISEBGDZMJVYAuHoVJ0xhaJCRjLR6D
2LPHv0lVgCYNAnngoJQ6FxnUBBM\nkybhnFYxBQQEmMznWbOzs8XYsWND161bd+Lo0aMHn3vuuUu5ublWc1lQUJAEAH9/f2k0Gt3aho/J\nlYjIFWJjgexsAIAEsBLPoIk8hr8E7cXu3UB4uLbh2TJ8OK589BHS6tXDLSGAevVw66OPkFaWxUyO\nqlChQkFWVpbV/JSdnW0AgLp16xqzsrIMGzdudOuI1F6cFiYicgGZlgYBoAAGvIjFWI7BeAtT8HbO\nO/CrXKB1eCUaPhxXnJ1MzedczT937tw569NPP7U6Gn7hhRcude/ePbxOnTq3zOddzWrWrFkQExNz\nsUWLFvfWqlXL2Lp165vOjNNZdNkVRwjRC0CvJk2aDD127JjW4RAROebwYRS0aAkJoAu+RTIi8RkG\n4mmsx1m/UDQwnnLZS1vrirNv375TrVu3vuSyF/Vh+/btq9m6deuwott1OS3M8odE5LEOHwYeeQTX\nUAmP4WscQgR+wkN4GutxE8F4vYCVlnwBp4WJiJzl8GGgUyeYJNDR/0ccMjbBl3ga9+IATiEUEzAN\nP4a68aJVnevatWvjM2fOBFpumzZt2tno6OhrWsXkLEyuRETOcOgQ8MgjkACGhH2Pg5dbwD8QeDJv\n45+7BAcDCzlw/dPWrVtPaB2Dq+hyWpiIyKNYJNbX2m7H8j0tsHw5sGQJEBqqGpuHhgILF7q32hJp\nhyNXIqKyOHgQ6NwZUghM6/I9Pk5ojpkz7yRRJlPfxJErEVFp3U6sEAILn/kebyU0x2uvAWPHah0Y\naY3JlYioNA4eBB55BBAC61/9HsM/bo6YGODDD7UOjPSAyZWIyFEHDqjEajBg++Tt+PtbzdGtG7B0\nKWDgp6pVeuvn+uWXX1Y2xxIcHNw2LCysZfPmzSP69OkTNmfOnBoDBw5s5OjrW+I5VyIiRxw4oKaC\n/fyQ8tH3eHxwM7RpA6xdCwSUqW+Md9NbP9fo6Ohr0dHRBwGgffv2zWbOnHnm4YcfzgaAOXPm1LD3\neWxhciUiuhvL7jYGA1C5Mk6u/BldnmuG+vWBLVuASnaNl7Q3ZAgapqYi2JnP2bIlspcuxRlnPJdl\nP9dq1aoZd+/efTQ4OLhtTEzMxZ07d1aeM2fO6a1bt1b6+uuvq+bl5RkiIyNvJCQkpBkMBkRHR4f1\n7Nkza/DgwVdDQkJa9evX7/I333xTxWg0itWrV59s27Ztrr1xpKenl/vb3/7W9PTp04E9evTIjIuL\nO+vIcXACg4ioJEW725hMuJBdEY883wD+/sA33wC1a2sboifwtH6uBw8eDN6wYcPJQ4cOHUhMTKx2\n/Pjxco48XpcjV4vawlqHQkS+zqK7DQBcQ0V0vbUZV65I7EgCGjfWMLZScNYI01Ge1s+1Q4cO12rU\nqFEAAE2aNMk9ceJEYJMmTfLtfbwuR66sLUxEumEesQLIQzn0wiYcRnOsl31w//0axuVDtOjnGhAQ\n8GdXGz8/P5mfn+/Q43WZXImIdCE3FwhUpW8LIDAAq7ATHfEZBuLRUHbschX2cyUi8lYmEzB4MJCX\nh1wEYCxmYR2i8RFGoxc24r+PL0QHrWP0IOznqiORkZEyKSlJ6zCIyBfFxgLTp+P9qu9jfWZH7MGD\nGIcPMRLz/uxuc+qU1kEWx36u7mWrnytHrkRERS1dCkyfDgwdijcX/ROAOt02E+MxE+MBAOK0hvGR\n7jG5EhFZ2roVePll4LHHcGPGPPgvEzAai+/WqEz1ewhgP1ciIt+Qmgr07Qu0aAG5eg1eGVUOBQVq\nTVNe3p3dgoOBaezLWmbs50pE5O3++AN4/HGgYkVg82YsXVsZK1YA77zDvqzkOI5ciYhu3AB69QKu\nXAF27cJvmQ0xahTQpQswYQLg58dkSo5hciUi31ZQADz7LJCSAmzciBtN26LfA0DVqqryoZ+f1gGS\nJ+K0MBH5LimB0aOBTZuAT
z6B7PE4RowAjhxRibWOQ9VoqSTuaDkHAJ9//nnV5OTkoNJH6hxMrkTk\nu2bPBubOBcaOBUaMQHw88PnnwNtvq65yPisurjrq128Fg6Ed6tdvhbi46mV9SnNtYfNt+vTp6bb2\nXbBgQZ0bN26UKj9t2LCh6v79+8uXPlLnYHIlIt+0fj3w2mtAdDTwwQc4cAAYOVIl1YkTtQ5OQ3Fx\n1TFmTCjOnw+AlMD58wEYMybUGQnWHpYt56KiosIBYN26dZXbtGnTPCIiokWPHj3uMZdGHDFiREjj\nxo3vDQ8Pjxg2bFiDrVu3Vvj222+rTpw4sUHz5s0jDhw4EFjyq7kOKzQRke/Zswfo1Am47z7g++9x\n01QeDzyg1jOlpAB1bU5Y6l+ZKzTVr98K588Xb/ter94t/PHHb6WNy8/Pr13Tpk1zzD+PHTv2/NCh\nQ69a2zckJKRVUlLSoXr16hnPnz/v36tXr8bbtm07VrlyZVNsbGzdvLw8MW7cuIwHH3ywxcmTJ1MN\nBgMuXbrkV7NmzQLLnq6ljdURHlWhiS3niMglEhKAf/5TXXbj7w8MHAiUL4+Rg4DDh1X9CE9OrE6R\nnl48sZa03U6lbTm3ffv2CidOnAhq3759cwDIz88X7dq1u1GjRo2CwMBAU//+/cN69uyZ2b9//2It\n57Sky2lhtpwjIqdLSACGDlWJFQCMRmD8eMS//BOWLwfeektdeuPz6ta95dB2F5NSokOHDtfM52pP\nnDhxYM2aNWnlypVDSkrKob59+17dtGlT1U6dOjXVIj5bdJlciYicbsIEICen0KaD2aEYuag1OnUC\nJk3SJizdmTTpHIKCTIW2BQWZMGmS1Q42rmDZcq5Tp043k5KSKqampgYCwLVr1wz79+8PzMrKMly5\ncsWvf//+WXFxcWcOHz4cDAAVK1YsuHbtmua5TZfTwkRETmUyAacLV9q/iWD8Hf9GRXkdK1cG83pW\ns+HDrwAApkwJQXp6AOrWvYVJk879ub2UytJybsGCBaeeeeaZe27duiUA4O233z5XpUoVU8+ePZvk\n5eUJAJg6deoZAIiJibnyyiuvhMXFxdVZu3btiXvvvTfP2mu4Ghc0EZF3M1/LOmdOoc1DsATxGIRv\nag9E1wsrNArO+dhyzr1sLWjSfOhMRORS//qXSqw9eqiK+wA+w/NYhiGI9f8AXWf10DhA8kacFiYi\n77VqFTB+PNC/P7ByJf47ahXyFyzFK6b5iMJudHmxERAzQOsofRZbzhEReZpt24AXXgA6dgSWL0fC\nKgOGxscgx6Qq8O9GFJ74PAoL/8ai/FphyzkiIk+yfz/Qpw8QHg5s2AAEBiI2tthiYWRnA7Gx2oRI\n3o3JlYi8y+nT6vxqpUrAV1+p9jYA0tJs707kbJwWJiLvcfWqSqw3bwK7dgENGwIAzp4FDAZ1RU5R\njRq5OUbyCUyuROQdcnOBJ58Ejh8HvvkGaNUKgGrX+txzQLlygBBqN7PgYGDaNI3iJa/GaWEi8nwm\nE/D882q0+tlnqij/be+/D+zYAcyfDyxeDISGqiQbGgosXMjFTO6idT/X119/va75tS1jeffdd2tH\nR0eHLVu2rFppXs8WjlyJyLNJqVrHrV2rrmnt3//Pu376SfVm7d8fGDRIJVUmU204Urh/wYIFdYYO\nHXqlUqVKVibyS7Zhw4aqRqMxq127drmW22fMmJE+Y8aMdAAIDg5uaxlLdHR0mKOvczccuRKRZ5s1\nSzU9HzNGJdnbsrKAAQPUade4OJVYSf+06ue6Y8eOim3btm3eoEGDVs4YxXLkSkSeJyFBXUNjXgIc\nFQXMnPnn3VICr7wCnDmjZopvLxgmABgypCFSU4Od+pwtW2Zj6dIzJe1StLawrX6uEydOzJg/f36d\nHTt2HDX3c50+fXq9nTt3HjX3c506dWqdcePGZWzZsqVa0X6ujz76aGZp+rleuHChXFJS0uG
UlJSg\nPn36NClrP1gmVyLyLAkJwLBh6iJVs/37VTWm23O+n32mfpw6FXjoIY3ipEL03s+1d+/emX5+fmjX\nrl3u5cuXy5XluQAmVyLyNLGxhRMroKpDxMYCMTE4ehQYOVIVZnrzTW1C1LW7jDD1xtzPdePGjb8X\nvS8lJeVQYmJi5bVr11abP39+7Z9//vloaV8nKCjozy42zmhoo8tzrkKIXkKIhVlZumosT0R6UEI1\niFu31HnWgADg88/BNnIeyhv6uWoegDVSyo1SymFVqlTROhQi0pPVq23f16gRYmOB5GRgyZI/60eQ\nTpjPuZpvI0aMCLG1r7mfa1RUVHj9+vWN5n6u4eHhEZGRkc1/++23oMzMTL/u3bs3DQ8Pj3jooYea\nWfZznTNnTt0WLVo4tKDJ2djPlYg8w8KFwPDhQLNmavRqWSg4OBj/949EPDajC4YPV9e0+ir2c3Uv\n9nMlIs/1wQfAyy+r0oZ79wKLFhWqBpHx4XIMjO+CiAh1qSuR1vSdXJOTgbAwtTrQHgkJan+DwbHH\nEZE+SalWJb3+OvDMM8D69UD58mpV8KlTgMkE+fspDN7cF5mZwBdf/NkPnTxA165dG1tOFTdv3jzi\nyy+/rKx1XM6g/9XCaWnAiy8CJ06oFlKBgUBQ0J2vQUFq9cLKlYWX56elqZ+Bu5dkMV8zd/q0quI9\nbRrLuBBpzWRSy37j4tSodd48qyuU5swBtmwBPvnkz3LC5CG8uZ+rvs+5CiHLfMY1KAjo3Vu1n7J2\nS04GPv0UyMu785jgYPuKjjIpE7lGfr6qV7hypRq1vvee1RJLv/4KPPgg8NhjwH/+wypMgM1zridb\ntWp11WAw6PcD3wOZTCbx22+/VWvduvU9Re/T/8jV0tq1Kgnm5t75av7+nXesPyY3F9i3D7h+Xd1u\n3FBTTSXJzlZv7Llzgdq11a1WrTvf164N/PKLukLdvKjCkZEyEdmWkwP06wds2qSS6htvFNslIUHN\nFp85owazjz/OxHoXqRcvXoyoVatWFhOsc5hMJnHx4sUqAFKt3e85yTU0FIiOtn1/fLz1699CQ4HD\nh+/8bDKp5GlOts2aWX8+o1GNbNPSVCLNyFC9q0qSnQ2MGqVGy02aAI0bAxUrFt6Ho10i265dUzNN\nO3eqJb/DhxfbpWiBpoICYOxY9XblW8k6o9H4Unp6+uL09PSW0PtaG89hApBqNBpfsnanZ0wL2zNN\na60kmj2PCwuznZRPnbrzs8kEZGaqJJuRocq/2KNOHZVkmzRRDZw3bgRu3XIsRiJfcOmSWg2ckqLq\nFz77rNXd7H3L+ipr08LkfvpPrqGh9o/uSjMqdHZSbtgQ2LBBLcA6flx9NX9/9qz156pWDUhMBO6/\nn0sdybdYFuAvd7uc6/r1wBNP2HyIrelfIdTfwL6OyVUnpJS6vbVr1066xYoVUoaGSimE+rpihX2P\nCQ6WUp3BVbfg4JIfa7mvtZufn5Rt2kg5bJiUixdLuX+/lEZj6WMk0jNr76HAwBL/b6enq7eJtbdP\naKj7QtczAElSB5/fvn7T98hV7xWaHB0p2xrtNmigzi/t3g3s2aNumZnqvgoV1Gj4+HF1HtiM08nk\n6Ryc383PBx59FPj5Z3Upe65FK2y+He7gyFUfmFzdyd4paClVMjUn2wULCp+nNatTBzh3jtXJyfPk\n5qpiENbYmN999VV1LWtCgnqLcF2gdUyu+sDk6m6lOS9sMNi+fKhGDbUIpGdPdbEfu0KT3qWmqtY1\nv/1m/X4rI9f4eGDwYGDMGGDWLJdH6NGYXPWBS7LdzaJsG06dsu/P7UaNrG+vWVMt/PjqK1UarlYt\noHNn9elz9HZbQ5aEJL2QUl07HhmpVtyPH198AV9wsPqD08Ivv6grcjp3ViWGiTyC1id9S7q5bUGT\n3t1t8ZTRKOUPP0j55ptStmx5Z5+6daX097f9OCJ3uXB
ByieeUP8Hn3hC/SzlXRfqXbggZYMG6q6L\nF90dtGcCFzTp4sZpYU/hyHTyqVPA5s3AuHGFV32YNWpku+E0kbN9/bWqeJaZCcycqeoF21FOybyA\nac8e4Icf1JVqdHecFtYHTgt7Ckemk8PC1AeYZb1kS6dPq/m18+ddECjRbbm5wOjRak1ArVpAUpKq\nYGZnncJx41ShpkWLmFjJ8zC5ejNb52oDA1Ux9AYN1Dlbc81mImc5cACIigJmzwb+8Q81/GzZ0u6H\nf/aZ6nYzejTw3HMujJPIRZhcvdm0adYXjCxZAhw5ohLsvn3A3/8O1K+vPgSTk+/e2ICoKPPCOSHU\nCva2bYH0dHV6Ys4c25fdWJGcrK5Y69QJ+PBDl0VM5FJMrt4sJkZdQxsaqj70QkPvXFMbHg5Mn67O\nvX79NdCtm5p/i4wEWrdW+zRsyFXGdHfm67fN5/GvXFHV9N9+W7WrcUBGhmrbXKcOsGYN4O85rUWI\nCuGCJrrj6lVg9Wq16OREkR7G5cur5Msr9amo0FB1Ht/adgcq6efnA127qtopXMBUelzQpA+6HLkK\nIXoJIRZmZWVpHYpvqVZNXVBoWWbRLCdHTRubyzISSQmsW2c9sQK2t9swfjywY4eaXGFiJU+ny+Qq\npdwopRxWpUoVrUPxTbY+FK9eBUJCVAK2VV2HfMNPPwEdOqgey+ZuNkXZWlBnwfJU7ezZqsjY8887\nN1QiLegyuZLGbH0o1q2rKkEtXw7cd5/qabtmjZrPI99w/DjQty/wl78AJ0+qYebixXZVWiqq6Kla\nQF16w9P75A2YXKk4W6uMZ85UK43PnlXLOM+cAfr3V0OPKVPUdbMst+idLl1SlfNbtFAL4CZPBo4d\nA4YOBQYOtL1wrgSxsYV7WADq7ENsrOsOg8hduKCJrLOnIlRBgfqgnTtXfTV35ykouLMPe4F5tpwc\nNV/73nvAjRsqmU6erGYxyohNz12DC5r0gcmVnOPYMaBdO+D69eL3ObhqlDRk/qMqLU1dryqlurSm\nVy9gxgw1cnWCa9dU0SZrnRT536VsmFz1gdPC5BxNm6qRjTVpaVwA5QkSEtTI1HwS9PJltTp8wgQg\nMdFpiTU3F3jqKbUoPTCw8H12nKol8ghMruQ8Ja0Ove8+tbo0IYGlFvXoxg11TjUnp/B2k8mp582N\nRrUm7vvvVYnDJUscPlVL5BGYXMl5bC2Emj9fLYa6cEEVim3QQJVePHlSmzjpjpMngbFj1e/kyhXr\n+zh4vaotJhPw0kvAf/4DfPKJSqKlaW9M5AmYXMl5bJVbHD5cfYAfOQJs3Qo8/DDwr38BjRsD3bur\nT9vPPuMqY3eREti2DXjySaBJE1X7t0cP24uU7Lhe1Z6XHDtWXcX1zjuqOQ6RV9O6oWxJNzZL92Jn\nz0o5ebKUISFSAqpZNpu6u9bNm1IuWCDlvfeqf+OaNaWMjVW/CynVv3dwsEt+D1Onqqf73/+V0mQq\n89NRCcBm6bq4aR5ASTcmVx+Qny9lrVqy0Ae6+daggdbRea4VK6QMDVV/tISESPnEE1JWq6b+Xdu0\nkXLZMilzckp+XGioUxLr3LnqZQcOlLKgoMxPR3fB5KqPGy/FIe0ZDLbb3D3xhFoB07s3ULmye+Py\nVOZVv0UXJ7Vvr6bj//pXuxuWl9XKlepsQe/ewJdfssuNO/BSHH3gOVfSnq1zepUrq0t4nn9e9SDr\n2xf497+Ll/UhJTsbWL8eePnl4okVUAvKOnRwW2LdvBl44QXVl3X1aiZW8i1MrqQ9W6uMP/0U+P13\n4McfVRHaH34A+vUDatdWw6GNG4H4eN9eCHXliloM1qcPULMm8PTTwM2b1vd10qpfe+zapf4Wat1a\nrVcLCnLbSxPpg9bz0iXdeM7Vh9hzrs9olHLbNimHDZOyRg1p9Tytty2EsvbvcvaslPPmSdmli5R+\nfuq4Q0KkHDlSyu+
+k7JRI+v/NqGhbgl5714pK1eWslkzKTMy3PKSZAE856qLG8+5kmfKz1fXZmZk\nFL+vUiV1CVCHDmofT2VuG2M5DW4w3Cm826yZGrH26QNERqr7bD3OxTWeLasmGgyqNfCvvwING7rk\n5agEPOeqD5wWJs9Urhxw8aL1+65fB559Vn2yh4WpwhVxcUBq6p3EpNfuPVKq6dvERGDkyOLnl00m\noGpV4OBB4PBhVVC/ffs7iRWwfb2xCxOrZes4k0nNTO/c6ZKXI/IIHLmS5woLK9wM1KxRI7Ww57//\nVbddu4D0dHVf1arqcQcOFO5Da+/Izp5uQfY+rl8/4NAhICWl8O3q1ZKfS2dtY2z9GliAXxscueoD\nkyt5LnunP6VUC6PMyTY+3nqD96AgtQqnXj11q1+/8PcbNtz99aRUrV5u3lT1em/cANauBaZPL1xT\nWQjVos9ovPPa990HtGkDtG2rvvbrp3rmFqWzrMXWcfrC5KoPTK7k2UozkizputqwMNX03VpzASGs\nP87fX5UONCdTc8K8m8qV1XR1mzaqq1DRa1U0OHfqqFOnVOjWDllnfwP4DCZXfeCVZ+TZzNXfHdGo\nke15zN9/Vwk0MxP44w+VaM1f33jD+vMZjUDXrkDFinduFSrc+f6ZZ6w/znxuuKRjA0o3De0GJ04A\nnTsDAQFqEG759whbx5Gv48iVfE9pR4SlPbnohScljxwBunRRvVm3blXrq3T6N4DP4chVH7hamHxP\naVfT2ip2cbchWmkfp1MHD6qqS7duqb6sbduydRxRUUyu5JtKkw1Km5TdfGmMK+3frxIrAGzfDrRq\npWU0RPrFaWEissveverUcvnyqh1seLjWEZE1nBbWB45cieiu9uxR51grVlTFIZhYiUrG5EpEJfrx\nR+DRR4Hq1VVivecerSMi0j8mVyKyaccOoFs3VUdjxw51upiI7o7JlYis+u47oEcPdWnN9u2e3QOB\nyN2YXInoT+Z+BkKoxUs1a6rEWq+e1pEReRYmVyICULy7jZTApUuqSAQROYbJlYgAqApLRTvc5eSo\n7UTkGF0mVyFELyHEwqysLK1DIfIZ1io0AqqkIRE5RpfJVUq5UUo5rEqVKlqHQuQT5s+3fV+jRu6L\ng8hb6DK5EpH7/OtfwIgRqkZw+fKF7/PgEshEmmJyJfJRUgJTpgDjxqm+7Lt3A4sWeUUJZCLNsZ8r\nkQ+SUrWn/eADYNAgYPFi1ZO1NO1xiag4JlciH2MyAa++Csybp6aDP/kEMHAOi8ipmFyJfEhBAfDS\nS0B8vJrzsRf+AAAOmElEQVQO/uADNQVMRM7Fv1eJfER+vpryjY8HJk9mYiVyJY5ciXxAbi7Qvz+Q\nmAh8+KEatRKR6zC5Enm57GzgqadUGUPzeVYici1OCxN5IcsC/NWrA99+q6aDmViJ3IMjVyIvYy7A\nb64TnJcHBAQA/ny3E7kNR65EXsZaAf5bt1iAn8idmFyJvAwL8BNpj8mVyIucO2d7+pcF+Inch8mV\nyEucPg107KiSa2Bg4ftYgJ/IvZhcibzA77+rxHrpErB9O7BkCQvwE2mJ6weJPNyxY0DnzmoR07Zt\nwP33A1FRTKZEWmJyJfJghw4BXbqo0obffw/cd5/WERERwORK5LFSU1ViFUJNBd97r9YREZEZz7kS\neaCUFKBTJ7V4accOJlYivWFyJfIwSUnqHGuFCsDOnUCzZlpHRERFMbkSeZCfflJTwVWrqhFr48Za\nR0RE1jC5EnmIXbuAbt2AOnVUYg0L0zoiIrKFyZVIxyy723TsCFSurBJrw4ZaR0ZEJWFyJdIpc3cb\nc61gKYGrV9W1rESkb0yuRDplrbtNTg672xB5AiZXIp1idxsiz8XkSqRDX35p+z52tyHSPyZXIp1Z\nvRro3x8IDwfKly98H7vbEHkGJlciHVmxAhgwAPjrX4HkZGDRIna3IfJErC1MpBPx8
cCQIcAjjwCJ\niaoCU0wMkymRJ+LIlUgHFi0CBg8GunYFNm1SiZWIPBeTK5HGPv1UXc/6+OPAf/5T/DwrEXkeJlci\nDc2eDYwcCfTuDaxbBwQFaR0RETkDkyuRRmbOBEaPBp5+Gvj3v4HAQK0jIiJnYXIl0sB77wHjxwP9\n+gFffAEEBGgdERE5E5MrkRuYC/AbDKpd3IQJ6pKbhASgXDmtoyMiZ+OlOEQuZi7Ab64TnJUF+PkB\n3bsD/nwHEnkljlyJXMxaAf6CAuCtt7SJh4hcj8mVyMVYgJ/I9zC5ErmQ0Wi7IAQL8BN5LyZXIhfJ\ny1MF+G/eLL5oiQX4ibwbkyuRC2RnA08+qQpDfPQRsGwZC/AT+RKuVSRysmvXgJ49gf/+F1i8GHjx\nRbWdyZTIdzC5EjnR5cvqEpuUFGDVKjUtTES+h8mVyEnOn1ddbY4fV9PBvXppHRERaYXJlcgJ0tKA\nRx9VCXbzZqBLF60jIiItMbkSldHRoyqxXrsGbN0KPPSQ1hERkdaYXInKYP9+oFs3wGQCtm8H2rTR\nOiIi0gNeikPkAMsC/PXqAX/5i6oPvHMnEysR3cGRK5GdihbgT09X161OnQo0b65tbESkL24buQoh\n7hFCLBFCrHXXaxI5k7UC/FICs2drEw8R6ZddyVUIsVQIkSGESC2yvbsQ4ogQ4rgQ4o2SnkNKeVJK\n+WJZgiXSkq1C+yzAT0RF2TstHA9gLoDPzBuEEH4A5gHoCuAsgF+EEIkA/AC8V+TxQ6SUGWWOlkgj\nJhNQqZJaEVwUC/ATUVF2JVcp5U4hRFiRze0BHJdSngQAIcQXAJ6UUr4HoKczgyTSUm4uMHiwSqz+\n/qrTjRkL8BORNWU55xoC4IzFz2dvb7NKCFFDCBEHoK0Q4s0S9hsmhEgSQiRdvHixDOERld3ly6rq\n0hdfADNmsAA/EdnHbauFpZSXAQy3Y7+FABYCQGRkpHR1XES2nDwJ9OgBnDqlkqu5TvBzz2kaFhF5\ngLIk13MAGlr83OD2NiKPt3u3qg1cUAB89x3QoYPWERGRJynLtPAvAJoKIf5HCBEA4BkAic4Ji0g7\n69cDnTqpBUw//sjESkSOs/dSnFUAfgLQTAhxVgjxopTSCGAUgG8AHAKwRkp5wHWhErnexx8D0dGq\n2tLPPwPNmmkdERF5IntXCz9rY/sWAFucGhGRBgoKgNdeA+bMAZ5+GlixAihfXuuoiMhTsbYw+STL\nGsGNGgFRUSqxjhkDrFnDxEpEZcPawuRzitYIPnNG3QYOBGbN0jY2IvIOHLmSz7FWIxgAduxwfyxE\n5J10mVyFEL2EEAuzsrK0DoW8EGsEE5Gr6TK5Sik3SimHValSRetQyMtkZwMVKli/jzWCichZdJlc\niVzh8GG1cOnmTVUj2BJrBBORMzG5kk/4/HMgMhK4cAH4+msgPp41gonIdbhamLxadjYwapQquN+x\nI7ByJVC/vrqPyZSIXIUjV/JaBw8C7durUerEicC3395JrERErsSRK3ml5cuBESPU4qVvvlFt44iI\n3IUjV/IqN2+qxuaDBgEPPACkpDCxEpH7MbmSR7MsY1i/viq0v3w58NZbnAYmIu3oclpYCNELQK8m\nTZpoHQrpWNEyhufPq69vvAFMmaJdXEREuhy5sogE2cNWGcNVq9wfCxGRJV0mV6K7yc0F0tKs38cy\nhkSkNSZX8jibNwMtW9q+n2UMiUhrTK7kMU6cAHr1Anr2BMqVA15/XZUttMQyhkSkB0yupHvZ2Wr1\nb0QEsH078OGHwL59wPvvq7KFLGNIRHqjy9XCRAAgJbBuHfDaa+o8akwM8MEHhS+viYlhMiUi/eHI\nlXTB8nrVsDCVRLt1A/r2BapWVY3MV6zgdatE5Bk4ciXNFb1eNS3tzvnUTz4Bhg8v3iKOiEjP+JFF\nmrN1vWr16qqjDRGRp+G0MGnqyBHb16ueO+feW
IiInEWXyVUI0UsIsTArK0vrUMgFTCZg0ybgsceA\n5s1t78frVYnIU+kyubL8oXfKzARmzQKaNlXXq6amAlOnAvPm8XpVIvIuukyu5LmKrvpNSFBJdPhw\nICQEGDtWrfhdvRo4dUo1MR8xgterEpF3EVJKrWOwKTIyUiYlJWkdBtmp6KpfQCVZkwkIClLJctQo\noE0b7WIk8nZCiGQpZaTWcfg6rhYmp3nzzeKrfk0mdZ3q8eNAjRraxEVE5G5MrlRqBQXA3r3A1q3q\nduaM9f2ysphYici38JwrWWXt3CkA/P67Oh/6978DtWsD7dur61SvXgUqV7b+XFz1S0S+hsmVijGf\nO01LU/V909KAQYOAOnWAe+4BXn4Z+OknoHdvtW96OpCSAnz6KVf9EhEBTK5ez9YI1FJBgVq5+3//\nB8ydC7zySvFzp0YjcP06MHs2cPCgmgJetgwYMEAlXUAtWOKqXyIirhb2GAkJavr19Gk1zTpt2t2T\nlrXVu4GBwPPPq3OgR4+q2/HjQF7e3WMQQi1QIiL94mphfdD1yDU52fZoyxp7RmlaP660jyk6TTt0\nKDBzJvDzz8CWLapjzJw5wOTJwKuvqsT70kvFR6B5ecDixaqYw+HDQJMmav9Fi1TnmfPnbZ8j5blT\nIiL76HrkKkSkBJIQGAiMHg107qy2Wwt52zY1ZWk5AgsMVInjkUcKP8b8vZSq+fbcuYUfFxCgpkYf\nfljtY+22a5dKUrduFX5cTAwQGammWk0m9dV827sX2LAByM+/8xh/f6BDB6BhQyA3V93y8gp/PXBA\nTcvaq2pVVfT+5Enb++Tn2+40Y23EGxzMKV4iT8CRqz54RHL1dn5+alQYGKiKLZi/mr9PTLT92M2b\n1RRv9erqVrWqej5AjYytFcUPDVXnWEtSmmloItIek6s+6DK5CiF6AegF1BgKhFnck5xs+1Ht2tm+\nr0yPqwLASgcBhx53+/vSxnhfK6BcQPHt+UZg/z7bj6tZHWgUCgiL6X9pAk6nAZeuFI6tEHu2WTm+\nQt/XBHDJdmx2sfFv79B+pTk+e47VW47P1vd6OD57tzv6fxMo+/G56ndnbbujx9dMSlnJjtjIlaSU\nur0BSCrj4xeWdT9b9xXdXtLP1r4v67G58vjs2ebNx2fPsXrL8ZXwvebH56r3njOOz9s/W3gr+03X\nC5qcYKMT9rN1X9HtJf1s6/uyctXx2bPNm4/P3mMtKz0cn6uOzZHns/c9Zmu7N/3ftLZdi+OjMtLl\ntLCZECJJeum5A28+NoDH5+l4fJ7Lm4/Nk+h95LpQ6wBcyJuPDeDxeToen+fy5mPzGLoeuRIREXki\nvY9ciYiIPA6TKxERkZMxuRIRETmZxyZXIUQFIUSSEKKn1rE4mxCihRAiTgixVgjxitbxOJsQ4ikh\nxCIhxGohRDet43E2IcQ9QoglQoi1WsfiDLffa8tv/868rk6Xt/2+ivL295teuT25CiGWCiEyhBCp\nRbZ3F0IcEUIcF0K8YcdTvQ5gjWuiLD1nHJ+U8pCUcjiAfgD+6sp4HeWk49sgpRwKYDiA/q6M11FO\nOr6TUsoXXRtp2Th4nE8DWHv7d9bb7cGWgiPH5wm/r6IcPD7dvt+8mRYj13gA3S03CCH8AMwD0ANA\nBIBnhRARQohWQohNRW61hRBdARwEkOHu4O0QjzIe3+3H9AawGcAW94Z/V/FwwvHdNvH24/QkHs47\nPj2Lh53HCaABgDO3dytwY4xlEQ/7j88TxcPx49Pj+81r2eiL4jpSyp1CiLAim9sDOC6lPAkAQogv\nADwppXwPQLFpXyFEJwAVoP4D5QghtkgpddFp1BnHd/t5EgEkCiE2A1jpuogd46TfnwDwPoCvpJR7\nXRuxY5z1+9M7R44TwFmoBJsCDzmV5ODxHXRvdGXnyPEJIQ5Bp+83b6aXN0oI7vxlDKg3c4itnaWU\nsVLK0VBJZ
5FeEmsJHDo+IUQnIcQcIcQC6G/kao1DxwfgHwAeBdBXCDHclYE5iaO/vxpCiDgAbYUQ\nb7o6OCeydZzrAEQLIebDs0vsWT0+D/59FWXr9+dp7zev4PaRqzNJKeO1jsEVpJTbAWzXOAyXkVLO\nATBH6zhcRUp5Ger8lleQUt4EMFjrOFzF235fRXn7+02v9DJyPQegocXPDW5v8xY8Ps/m7cdn5u3H\nyeMjt9FLcv0FQFMhxP8IIQIAPAOghBbhHofH59m8/fjMvP04eXzkNlpcirMKwE8AmgkhzgohXpRS\nGgGMAvANgEMA1kgpD7g7Nmfg8fH4PIG3HyePz7OPzxuwcD8REZGT6WVamIiIyGswuRIRETkZkysR\nEZGTMbkSERE5GZMrERGRkzG5EhERORmTKxERkZMxuRIRETkZkysREZGT/T/pspA/hIoxbwAAAABJ\nRU5ErkJggg==\n", 230 | "text/plain": [ 231 | "" 232 | ] 233 | }, 234 | "metadata": {}, 235 | "output_type": "display_data" 236 | } 237 | ], 238 | "source": [ 239 | "## Parameter setting\n", 240 | "n=512\n", 241 | "p=256\n", 242 | "Tr=1024 # Training length\n", 243 | "Te=Tr # Testing length\n", 244 | "\n", 245 | "prop=[.5,.5] # proportions of each class\n", 246 | "K=len(prop) # number of data classes\n", 247 | "\n", 248 | "gammas = [10**x for x in np.arange(-4,2.25,.25)] # Range of gamma for simulations\n", 249 | "\n", 250 | "testcase='MNIST' # testcase for simulation, among 'iid','means','var','orth','mixed',MNIST'\n", 251 | "sigma='ReLu' # activation function, among 'ReLu', 'sign', 'posit', 'erf', 'poly2', 'cos', 'abs'\n", 252 | "\n", 253 | "\n", 254 | "# Only used for sigma='poly2'\n", 255 | "polynom=[-.5,0,1] # sigma(t)=polynom[0].t²+polynom[1].t+polynom[2]\n", 256 | "distrib='student' # distribution of Wij, among 'gauss','bern','bern_skewed','student'\n", 257 | "\n", 258 | "# Only used for sigma='poly2' and distrib='student'\n", 259 | "nu=7 # degrees of freedom of Student-t distribution\n", 260 | " \n", 261 | "\n", 262 | "## Generate X_train,X_test,y_train,y_test\n", 263 | "if testcase is 'MNIST':\n", 264 | " p=784\n", 265 | " X_train,X_test,y_train,y_test = gen_data(testcase,Tr,Te,prop)\n", 266 | "else: \n", 267 | " means=[]\n", 268 | " covs=[]\n", 269 | " if testcase is 'iid':\n", 270 | " for i in range(K):\n", 271 | " means.append(np.zeros(p))\n", 272 | " covs.append(np.eye(p)) \n", 273 | " elif testcase is 
'means':\n", 274 | " for i in range(K):\n", 275 | " means.append( np.concatenate( (np.zeros(i),4*np.ones(1),np.zeros(p-i-1)) ) )\n", 276 | " covs.append(np.eye(p))\n", 277 | " elif testcase is 'var':\n", 278 | " for i in range(K):\n", 279 | " means.append(np.zeros(p))\n", 280 | " covs.append(np.eye(p)*(1+8*i/np.sqrt(p)))\n", 281 | " elif testcase is 'orth':\n", 282 | " for i in range(K):\n", 283 | " means.append(np.zeros(p))\n", 284 | " covs.append( np.diag(np.concatenate( (np.ones(np.int(np.sum(prop[0:i]*p))),4*np.ones(np.int(prop[i]*p)),np.ones(np.int(np.sum(prop[i+1:]*p))) ) ) ))\n", 285 | " elif testcase is 'mixed':\n", 286 | " for i in range(K):\n", 287 | " means.append( np.concatenate( (np.zeros(i),4*np.ones(1),np.zeros(p-i-1)) ) )\n", 288 | " covs.append((1+4*i/np.sqrt(p))*scipy.linalg.toeplitz( [(.4*i)**x for x in range(p)] )) \n", 289 | "\n", 290 | " X_train,X_test,y_train,y_test = gen_data(testcase,Tr,Te,prop,means,covs)\n", 291 | "\n", 292 | "##Theory\n", 293 | "start_th_calculus = time.time()\n", 294 | "\n", 295 | "Phi=gen_Phi(sigma,X_train,X_train,polynom,distrib,nu)\n", 296 | "L,U = np.linalg.eigh(Phi)\n", 297 | "Phi_cross = gen_Phi(sigma,X_train,X_test,polynom,distrib,nu)\n", 298 | "UPhi_cross = U.T@Phi_cross\n", 299 | "D_UPhi_cross2 = np.sum(UPhi_cross**2,axis=1)\n", 300 | "\n", 301 | "Phi_test = gen_Phi(sigma,X_test,X_test,polynom,distrib,nu)\n", 302 | "D_Phi_test = np.diag(Phi_test)\n", 303 | "Uy_train = U.T@y_train\n", 304 | "\n", 305 | "E_train_th=np.zeros(len(gammas))\n", 306 | "E_test_th =np.zeros(len(gammas))\n", 307 | "\n", 308 | "ind=0\n", 309 | "for gamma in gammas:\n", 310 | " E_train_th[ind],E_test_th[ind] = gen_E_th()\n", 311 | " ind+=1\n", 312 | " \n", 313 | "end_th_calculus = time.time() \n", 314 | "\n", 315 | "m,s = divmod(end_th_calculus-start_th_calculus,60)\n", 316 | "print('Time for Theoretical Computation {:d}min {:d}s'.format( int(m),math.ceil(s) )) \n", 317 | " \n", 318 | "## Simulations\n", 319 | "start_sim_calculus = 
time.time()\n", 320 | "\n", 321 | "loops = 10 # Number of generations of W to be averaged over\n", 322 | "\n", 323 | "E_train=np.zeros(len(gammas))\n", 324 | "E_test =np.zeros(len(gammas))\n", 325 | "\n", 326 | "\n", 327 | "rng = np.random\n", 328 | "\n", 329 | "for loop in range(loops): \n", 330 | " if sigma is 'poly2':\n", 331 | " if distrib is 'student':\n", 332 | " W = rng.standard_t(nu,n*p).reshape(n,p)/np.sqrt(nu/(nu-2))\n", 333 | " elif distrib is 'bern':\n", 334 | " W = np.sign(rng.randn(n,p))\n", 335 | " elif distrib is 'bern_skewed':\n", 336 | " Z = rng.rand(n,p)\n", 337 | " W = (Z<.75)/np.sqrt(3)+(Z>.75)*(-np.sqrt(3))\n", 338 | " elif distrib is 'gauss':\n", 339 | " W = rng.randn(n,p)\n", 340 | " else:\n", 341 | " W = rng.randn(n,p)\n", 342 | "\n", 343 | " S_train = gen_sig(sigma,W @ X_train,polynom)\n", 344 | " SS = S_train.T @ S_train\n", 345 | "\n", 346 | " S_test = gen_sig(sigma, W @ X_test,polynom)\n", 347 | "\n", 348 | " ind = 0\n", 349 | " for gamma in gammas:\n", 350 | "\n", 351 | " inv_resolv = np.linalg.solve( SS/Tr+gamma*np.eye(Tr),y_train)\n", 352 | " beta = S_train @ inv_resolv/Tr\n", 353 | " z_train = S_train.T @ beta\n", 354 | "\n", 355 | " z_test = S_test.T @ beta\n", 356 | "\n", 357 | "\n", 358 | " E_train[ind] += gamma**2*np.linalg.norm(inv_resolv)**2/Tr/loops\n", 359 | " E_test[ind] += np.linalg.norm(y_test-z_test)**2/Te/loops\n", 360 | "\n", 361 | " ind+=1 \n", 362 | " \n", 363 | "end_sim_calculus = time.time() \n", 364 | "\n", 365 | "m,s = divmod(end_sim_calculus-start_sim_calculus,60)\n", 366 | "print('Time for Simulations Computation {:d}min {:d}s'.format( int(m),math.ceil(s) )) \n", 367 | " \n", 368 | "#Plots \n", 369 | "p11,=plt.plot(gammas,E_train,'bo')\n", 370 | "p21,=plt.plot(gammas,E_test,'ro')\n", 371 | "\n", 372 | "p12,=plt.plot(gammas,E_train_th,'b-')\n", 373 | "p22,=plt.plot(gammas,E_test_th,'r-')\n", 374 | "plt.xscale('log')\n", 375 | "plt.yscale('log')\n", 376 | "plt.xlim( gammas[0],gammas[-1] )\n", 377 | 
"plt.ylim(np.amin( (E_train,E_train_th) ),np.amax( (E_test,E_test_th) ))\n", 378 | "plt.legend([p11,p12,p21,p22], [\"E_train\", \"E_train Th\",\"E_test\",\"E_test Th\"],bbox_to_anchor=(1, 1), loc='upper left')\n", 379 | "plt.show()\n" 380 | ] 381 | } 382 | ], 383 | "metadata": { 384 | "kernelspec": { 385 | "display_name": "Python 3", 386 | "language": "python", 387 | "name": "python3" 388 | }, 389 | "language_info": { 390 | "codemirror_mode": { 391 | "name": "ipython", 392 | "version": 3 393 | }, 394 | "file_extension": ".py", 395 | "mimetype": "text/x-python", 396 | "name": "python", 397 | "nbconvert_exporter": "python", 398 | "pygments_lexer": "ipython3", 399 | "version": "3.6.0" 400 | } 401 | }, 402 | "nbformat": 4, 403 | "nbformat_minor": 2 404 | } 405 | --------------------------------------------------------------------------------