├── .gitignore ├── LICENSE.txt ├── README.md ├── lib ├── coarsening.py ├── graph.py ├── models.py └── utils.py ├── makefile ├── nips2016 ├── 20news.ipynb ├── makefile └── mnist.ipynb ├── rcv1.ipynb ├── requirements.txt ├── trials ├── 1_learning_filters.ipynb ├── 2_classification.ipynb ├── 3_tensorflow.ipynb ├── 4_coarsening.ipynb └── makefile └── usage.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # IPython checkpoints 6 | .ipynb_checkpoints/ 7 | 8 | # Datasets 9 | data/ 10 | 11 | # Tensorflow summaries 12 | summaries/ 13 | 14 | # Model parameters 15 | checkpoints/ 16 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Michaël Defferrard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering 2 | 3 | The code in this repository implements an efficient generalization of the 4 | popular Convolutional Neural Networks (CNNs) to arbitrary graphs, presented in 5 | our paper: 6 | 7 | Michaël Defferrard, Xavier Bresson, Pierre Vandergheynst, [Convolutional Neural 8 | Networks on Graphs with Fast Localized Spectral Filtering][arXiv], Neural 9 | Information Processing Systems (NIPS), 2016. 10 | 11 | Additional material: 12 | * [NIPS2016 spotlight video][video], 2016-11-22. 13 | * [Deep Learning on Graphs][slides_ntds], a lecture for EPFL's master course [A 14 | Network Tour of Data Science][ntds], 2016-12-21. 15 | * [Deep Learning on Graphs][slides_dlid], an invited talk at the [Deep Learning on 16 | Irregular Domains][dlid] workshop of BMVC, 2017-09-17. 
17 | 18 | [video]: https://www.youtube.com/watch?v=cIA_m7vwOVQ 19 | [slides_ntds]: https://doi.org/10.6084/m9.figshare.4491686 20 | [ntds]: https://github.com/mdeff/ntds_2016 21 | [slides_dlid]: https://doi.org/10.6084/m9.figshare.5394805 22 | [dlid]: http://dlid.swansea.ac.uk 23 | 24 | There is also implementations of the filters used in: 25 | * Joan Bruna, Wojciech Zaremba, Arthur Szlam, Yann LeCun, [Spectral Networks 26 | and Locally Connected Networks on Graphs][bruna], International Conference on 27 | Learning Representations (ICLR), 2014. 28 | * Mikael Henaff, Joan Bruna and Yann LeCun, [Deep Convolutional Networks on 29 | Graph-Structured Data][henaff], arXiv, 2015. 30 | 31 | [arXiv]: https://arxiv.org/abs/1606.09375 32 | [bruna]: https://arxiv.org/abs/1312.6203 33 | [henaff]: https://arxiv.org/abs/1506.05163 34 | 35 | ## Installation 36 | 37 | 1. Clone this repository. 38 | ```sh 39 | git clone https://github.com/mdeff/cnn_graph 40 | cd cnn_graph 41 | ``` 42 | 43 | 2. Install the dependencies. The code should run with TensorFlow 1.0 and newer. 44 | ```sh 45 | pip install -r requirements.txt # or make install 46 | ``` 47 | 48 | 3. Play with the Jupyter notebooks. 49 | ```sh 50 | jupyter notebook 51 | ``` 52 | 53 | ## Reproducing our results 54 | 55 | Run all the notebooks to reproduce the experiments on 56 | [MNIST](nips2016/mnist.ipynb) and [20NEWS](nips2016/20news.ipynb) presented in 57 | the paper. 58 | ```sh 59 | cd nips2016 60 | make 61 | ``` 62 | 63 | ## Using the model 64 | 65 | To use our graph ConvNet on your data, you need: 66 | 67 | 1. a data matrix where each row is a sample and each column is a feature, 68 | 2. a target vector, 69 | 3. optionally, an adjacency matrix which encodes the structure as a graph. 70 | 71 | See the [usage notebook][usage] for a simple example with fabricated data. 72 | Please get in touch if you are unsure about applying the model to a different 73 | setting. 74 | 75 | [usage]: http://nbviewer.jupyter.org/github/mdeff/cnn_graph/blob/outputs/usage.ipynb 76 | 77 | ## License & co 78 | 79 | The code in this repository is released under the terms of the [MIT license](LICENSE.txt). 80 | Please cite our [paper][arXiv] if you use it. 81 | 82 | ``` 83 | @inproceedings{cnn_graph, 84 | title = {Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering}, 85 | author = {Defferrard, Micha\"el and Bresson, Xavier and Vandergheynst, Pierre}, 86 | booktitle = {Advances in Neural Information Processing Systems}, 87 | year = {2016}, 88 | url = {https://arxiv.org/abs/1606.09375}, 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /lib/coarsening.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse 3 | 4 | 5 | def coarsen(A, levels, self_connections=False): 6 | """ 7 | Coarsen a graph, represented by its adjacency matrix A, at multiple 8 | levels. 
9 | """ 10 | graphs, parents = metis(A, levels) 11 | perms = compute_perm(parents) 12 | 13 | for i, A in enumerate(graphs): 14 | M, M = A.shape 15 | 16 | if not self_connections: 17 | A = A.tocoo() 18 | A.setdiag(0) 19 | 20 | if i < levels: 21 | A = perm_adjacency(A, perms[i]) 22 | 23 | A = A.tocsr() 24 | A.eliminate_zeros() 25 | graphs[i] = A 26 | 27 | Mnew, Mnew = A.shape 28 | print('Layer {0}: M_{0} = |V| = {1} nodes ({2} added),' 29 | '|E| = {3} edges'.format(i, Mnew, Mnew-M, A.nnz//2)) 30 | 31 | return graphs, perms[0] if levels > 0 else None 32 | 33 | 34 | def metis(W, levels, rid=None): 35 | """ 36 | Coarsen a graph multiple times using the METIS algorithm. 37 | 38 | INPUT 39 | W: symmetric sparse weight (adjacency) matrix 40 | levels: the number of coarsened graphs 41 | 42 | OUTPUT 43 | graph[0]: original graph of size N_1 44 | graph[2]: coarser graph of size N_2 < N_1 45 | graph[levels]: coarsest graph of Size N_levels < ... < N_2 < N_1 46 | parents[i] is a vector of size N_i with entries ranging from 1 to N_{i+1} 47 | which indicate the parents in the coarser graph[i+1] 48 | nd_sz{i} is a vector of size N_i that contains the size of the supernode in the graph{i} 49 | 50 | NOTE 51 | if "graph" is a list of length k, then "parents" will be a list of length k-1 52 | """ 53 | 54 | N, N = W.shape 55 | if rid is None: 56 | rid = np.random.permutation(range(N)) 57 | parents = [] 58 | degree = W.sum(axis=0) - W.diagonal() 59 | graphs = [] 60 | graphs.append(W) 61 | #supernode_size = np.ones(N) 62 | #nd_sz = [supernode_size] 63 | #count = 0 64 | 65 | #while N > maxsize: 66 | for _ in range(levels): 67 | 68 | #count += 1 69 | 70 | # CHOOSE THE WEIGHTS FOR THE PAIRING 71 | # weights = ones(N,1) # metis weights 72 | weights = degree # graclus weights 73 | # weights = supernode_size # other possibility 74 | weights = np.array(weights).squeeze() 75 | 76 | # PAIR THE VERTICES AND CONSTRUCT THE ROOT VECTOR 77 | idx_row, idx_col, val = scipy.sparse.find(W) 78 | perm = np.argsort(idx_row) 79 | rr = idx_row[perm] 80 | cc = idx_col[perm] 81 | vv = val[perm] 82 | cluster_id = metis_one_level(rr,cc,vv,rid,weights) # rr is ordered 83 | parents.append(cluster_id) 84 | 85 | # TO DO 86 | # COMPUTE THE SIZE OF THE SUPERNODES AND THEIR DEGREE 87 | #supernode_size = full( sparse(cluster_id, ones(N,1) , supernode_size ) ) 88 | #print(cluster_id) 89 | #print(supernode_size) 90 | #nd_sz{count+1}=supernode_size; 91 | 92 | # COMPUTE THE EDGES WEIGHTS FOR THE NEW GRAPH 93 | nrr = cluster_id[rr] 94 | ncc = cluster_id[cc] 95 | nvv = vv 96 | Nnew = cluster_id.max() + 1 97 | # CSR is more appropriate: row,val pairs appear multiple times 98 | W = scipy.sparse.csr_matrix((nvv,(nrr,ncc)), shape=(Nnew,Nnew)) 99 | W.eliminate_zeros() 100 | # Add new graph to the list of all coarsened graphs 101 | graphs.append(W) 102 | N, N = W.shape 103 | 104 | # COMPUTE THE DEGREE (OMIT OR NOT SELF LOOPS) 105 | degree = W.sum(axis=0) 106 | #degree = W.sum(axis=0) - W.diagonal() 107 | 108 | # CHOOSE THE ORDER IN WHICH VERTICES WILL BE VISTED AT THE NEXT PASS 109 | #[~, rid]=sort(ss); # arthur strategy 110 | #[~, rid]=sort(supernode_size); # thomas strategy 111 | #rid=randperm(N); # metis/graclus strategy 112 | ss = np.array(W.sum(axis=0)).squeeze() 113 | rid = np.argsort(ss) 114 | 115 | return graphs, parents 116 | 117 | 118 | # Coarsen a graph given by rr,cc,vv. 
rr is assumed to be ordered
119 | def metis_one_level(rr,cc,vv,rid,weights):
120 | 
121 |     nnz = rr.shape[0]
122 |     N = rr[nnz-1] + 1
123 | 
124 |     marked = np.zeros(N, bool)
125 |     rowstart = np.zeros(N, np.int32)
126 |     rowlength = np.zeros(N, np.int32)
127 |     cluster_id = np.zeros(N, np.int32)
128 | 
129 |     oldval = rr[0]
130 |     count = 0
131 |     clustercount = 0
132 | 
133 |     for ii in range(nnz):
134 |         rowlength[count] = rowlength[count] + 1
135 |         if rr[ii] > oldval:
136 |             oldval = rr[ii]
137 |             rowstart[count+1] = ii
138 |             count = count + 1
139 | 
140 |     for ii in range(N):
141 |         tid = rid[ii]
142 |         if not marked[tid]:
143 |             wmax = 0.0
144 |             rs = rowstart[tid]
145 |             marked[tid] = True
146 |             bestneighbor = -1
147 |             for jj in range(rowlength[tid]):
148 |                 nid = cc[rs+jj]
149 |                 if marked[nid]:
150 |                     tval = 0.0
151 |                 else:
152 |                     tval = vv[rs+jj] * (1.0/weights[tid] + 1.0/weights[nid])
153 |                 if tval > wmax:
154 |                     wmax = tval
155 |                     bestneighbor = nid
156 | 
157 |             cluster_id[tid] = clustercount
158 | 
159 |             if bestneighbor > -1:
160 |                 cluster_id[bestneighbor] = clustercount
161 |                 marked[bestneighbor] = True
162 | 
163 |             clustercount += 1
164 | 
165 |     return cluster_id
166 | 
167 | def compute_perm(parents):
168 |     """
169 |     Return a list of indices to reorder the adjacency and data matrices so
170 |     that the union of two neighbors from layer to layer forms a binary tree.
171 |     """
172 | 
173 |     # Order of last layer is random (chosen by the clustering algorithm).
174 |     indices = []
175 |     if len(parents) > 0:
176 |         M_last = max(parents[-1]) + 1
177 |         indices.append(list(range(M_last)))
178 | 
179 |     for parent in parents[::-1]:
180 |         #print('parent: {}'.format(parent))
181 | 
182 |         # Fake nodes go after real ones.
183 |         pool_singletons = len(parent)
184 | 
185 |         indices_layer = []
186 |         for i in indices[-1]:
187 |             indices_node = list(np.where(parent == i)[0])
188 |             assert 0 <= len(indices_node) <= 2
189 |             #print('indices_node: {}'.format(indices_node))
190 | 
191 |             # Add a node to go with a singleton.
192 |             if len(indices_node) == 1:
193 |                 indices_node.append(pool_singletons)
194 |                 pool_singletons += 1
195 |                 #print('new singleton: {}'.format(indices_node))
196 |             # Add two nodes as children of a singleton in the parent.
197 |             elif len(indices_node) == 0:
198 |                 indices_node.append(pool_singletons+0)
199 |                 indices_node.append(pool_singletons+1)
200 |                 pool_singletons += 2
201 |                 #print('singleton children: {}'.format(indices_node))
202 | 
203 |             indices_layer.extend(indices_node)
204 |         indices.append(indices_layer)
205 | 
206 |     # Sanity checks.
207 |     for i,indices_layer in enumerate(indices):
208 |         M = M_last*2**i
209 |         # Reduction by 2 at each layer (binary tree).
210 |         assert len(indices_layer) == M
211 |         # The new ordering does not omit an index.
212 |         assert sorted(indices_layer) == list(range(M))
213 | 
214 |     return indices[::-1]
215 | 
216 | assert (compute_perm([np.array([4,1,1,2,2,3,0,0,3]),np.array([2,1,0,1,0])])
217 |         == [[3,4,0,9,1,2,5,8,6,7,10,11],[2,4,1,3,0,5],[0,1,2]])
218 | 
219 | def perm_data(x, indices):
220 |     """
221 |     Permute data matrix, i.e. exchange node ids,
222 |     so that binary unions form the clustering tree.
223 |     """
224 |     if indices is None:
225 |         return x
226 | 
227 |     N, M = x.shape
228 |     Mnew = len(indices)
229 |     assert Mnew >= M
230 |     xnew = np.empty((N, Mnew))
231 |     for i,j in enumerate(indices):
232 |         # Existing vertex, i.e. real data.
233 |         if j < M:
234 |             xnew[:,i] = x[:,j]
235 |         # Fake vertex because of singletons.
236 |         # They will stay 0 so that max pooling chooses the singleton.
237 | # Or -infty ? 238 | else: 239 | xnew[:,i] = np.zeros(N) 240 | return xnew 241 | 242 | def perm_adjacency(A, indices): 243 | """ 244 | Permute adjacency matrix, i.e. exchange node ids, 245 | so that binary unions form the clustering tree. 246 | """ 247 | if indices is None: 248 | return A 249 | 250 | M, M = A.shape 251 | Mnew = len(indices) 252 | assert Mnew >= M 253 | A = A.tocoo() 254 | 255 | # Add Mnew - M isolated vertices. 256 | if Mnew > M: 257 | rows = scipy.sparse.coo_matrix((Mnew-M, M), dtype=np.float32) 258 | cols = scipy.sparse.coo_matrix((Mnew, Mnew-M), dtype=np.float32) 259 | A = scipy.sparse.vstack([A, rows]) 260 | A = scipy.sparse.hstack([A, cols]) 261 | 262 | # Permute the rows and the columns. 263 | perm = np.argsort(indices) 264 | A.row = np.array(perm)[A.row] 265 | A.col = np.array(perm)[A.col] 266 | 267 | # assert np.abs(A - A.T).mean() < 1e-9 268 | assert type(A) is scipy.sparse.coo.coo_matrix 269 | return A 270 | -------------------------------------------------------------------------------- /lib/graph.py: -------------------------------------------------------------------------------- 1 | import sklearn.metrics 2 | import sklearn.neighbors 3 | import matplotlib.pyplot as plt 4 | import scipy.sparse 5 | import scipy.sparse.linalg 6 | import scipy.spatial.distance 7 | import numpy as np 8 | 9 | 10 | def grid(m, dtype=np.float32): 11 | """Return the embedding of a grid graph.""" 12 | M = m**2 13 | x = np.linspace(0, 1, m, dtype=dtype) 14 | y = np.linspace(0, 1, m, dtype=dtype) 15 | xx, yy = np.meshgrid(x, y) 16 | z = np.empty((M, 2), dtype) 17 | z[:, 0] = xx.reshape(M) 18 | z[:, 1] = yy.reshape(M) 19 | return z 20 | 21 | 22 | def distance_scipy_spatial(z, k=4, metric='euclidean'): 23 | """Compute exact pairwise distances.""" 24 | d = scipy.spatial.distance.pdist(z, metric) 25 | d = scipy.spatial.distance.squareform(d) 26 | # k-NN graph. 27 | idx = np.argsort(d)[:, 1:k+1] 28 | d.sort() 29 | d = d[:, 1:k+1] 30 | return d, idx 31 | 32 | 33 | def distance_sklearn_metrics(z, k=4, metric='euclidean'): 34 | """Compute exact pairwise distances.""" 35 | d = sklearn.metrics.pairwise.pairwise_distances( 36 | z, metric=metric, n_jobs=-2) 37 | # k-NN graph. 38 | idx = np.argsort(d)[:, 1:k+1] 39 | d.sort() 40 | d = d[:, 1:k+1] 41 | return d, idx 42 | 43 | 44 | def distance_lshforest(z, k=4, metric='cosine'): 45 | """Return an approximation of the k-nearest cosine distances.""" 46 | assert metric is 'cosine' 47 | lshf = sklearn.neighbors.LSHForest() 48 | lshf.fit(z) 49 | dist, idx = lshf.kneighbors(z, n_neighbors=k+1) 50 | assert dist.min() < 1e-10 51 | dist[dist < 0] = 0 52 | return dist, idx 53 | 54 | # TODO: other ANNs s.a. NMSLIB, EFANNA, FLANN, Annoy, sklearn neighbors, PANN 55 | 56 | 57 | def adjacency(dist, idx): 58 | """Return the adjacency matrix of a kNN graph.""" 59 | M, k = dist.shape 60 | assert M, k == idx.shape 61 | assert dist.min() >= 0 62 | 63 | # Weights. 64 | sigma2 = np.mean(dist[:, -1])**2 65 | dist = np.exp(- dist**2 / sigma2) 66 | 67 | # Weight matrix. 68 | I = np.arange(0, M).repeat(k) 69 | J = idx.reshape(M*k) 70 | V = dist.reshape(M*k) 71 | W = scipy.sparse.coo_matrix((V, (I, J)), shape=(M, M)) 72 | 73 | # No self-connections. 74 | W.setdiag(0) 75 | 76 | # Non-directed graph. 
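    # The kNN edges are directed (each node points to its k nearest neighbors);
    # the next two lines keep the larger of W[i, j] and W[j, i] so that the
    # resulting weight matrix is symmetric.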
77 | bigger = W.T > W 78 | W = W - W.multiply(bigger) + W.T.multiply(bigger) 79 | 80 | assert W.nnz % 2 == 0 81 | assert np.abs(W - W.T).mean() < 1e-10 82 | assert type(W) is scipy.sparse.csr.csr_matrix 83 | return W 84 | 85 | 86 | def replace_random_edges(A, noise_level): 87 | """Replace randomly chosen edges by random edges.""" 88 | M, M = A.shape 89 | n = int(noise_level * A.nnz // 2) 90 | 91 | indices = np.random.permutation(A.nnz//2)[:n] 92 | rows = np.random.randint(0, M, n) 93 | cols = np.random.randint(0, M, n) 94 | vals = np.random.uniform(0, 1, n) 95 | assert len(indices) == len(rows) == len(cols) == len(vals) 96 | 97 | A_coo = scipy.sparse.triu(A, format='coo') 98 | assert A_coo.nnz == A.nnz // 2 99 | assert A_coo.nnz >= n 100 | A = A.tolil() 101 | 102 | for idx, row, col, val in zip(indices, rows, cols, vals): 103 | old_row = A_coo.row[idx] 104 | old_col = A_coo.col[idx] 105 | 106 | A[old_row, old_col] = 0 107 | A[old_col, old_row] = 0 108 | A[row, col] = 1 109 | A[col, row] = 1 110 | 111 | A.setdiag(0) 112 | A = A.tocsr() 113 | A.eliminate_zeros() 114 | return A 115 | 116 | 117 | def laplacian(W, normalized=True): 118 | """Return the Laplacian of the weigth matrix.""" 119 | 120 | # Degree matrix. 121 | d = W.sum(axis=0) 122 | 123 | # Laplacian matrix. 124 | if not normalized: 125 | D = scipy.sparse.diags(d.A.squeeze(), 0) 126 | L = D - W 127 | else: 128 | d += np.spacing(np.array(0, W.dtype)) 129 | d = 1 / np.sqrt(d) 130 | D = scipy.sparse.diags(d.A.squeeze(), 0) 131 | I = scipy.sparse.identity(d.size, dtype=W.dtype) 132 | L = I - D * W * D 133 | 134 | # assert np.abs(L - L.T).mean() < 1e-9 135 | assert type(L) is scipy.sparse.csr.csr_matrix 136 | return L 137 | 138 | 139 | def lmax(L, normalized=True): 140 | """Upper-bound on the spectrum.""" 141 | if normalized: 142 | return 2 143 | else: 144 | return scipy.sparse.linalg.eigsh( 145 | L, k=1, which='LM', return_eigenvectors=False)[0] 146 | 147 | 148 | def fourier(L, algo='eigh', k=1): 149 | """Return the Fourier basis, i.e. the EVD of the Laplacian.""" 150 | 151 | def sort(lamb, U): 152 | idx = lamb.argsort() 153 | return lamb[idx], U[:, idx] 154 | 155 | if algo is 'eig': 156 | lamb, U = np.linalg.eig(L.toarray()) 157 | lamb, U = sort(lamb, U) 158 | elif algo is 'eigh': 159 | lamb, U = np.linalg.eigh(L.toarray()) 160 | elif algo is 'eigs': 161 | lamb, U = scipy.sparse.linalg.eigs(L, k=k, which='SM') 162 | lamb, U = sort(lamb, U) 163 | elif algo is 'eigsh': 164 | lamb, U = scipy.sparse.linalg.eigsh(L, k=k, which='SM') 165 | 166 | return lamb, U 167 | 168 | 169 | def plot_spectrum(L, algo='eig'): 170 | """Plot the spectrum of a list of multi-scale Laplacians L.""" 171 | # Algo is eig to be sure to get all eigenvalues. 172 | plt.figure(figsize=(17, 5)) 173 | for i, lap in enumerate(L): 174 | lamb, U = fourier(lap, algo) 175 | step = 2**i 176 | x = range(step//2, L[0].shape[0], step) 177 | lb = 'L_{} spectrum in [{:1.2e}, {:1.2e}]'.format(i, lamb[0], lamb[-1]) 178 | plt.plot(x, lamb, '.', label=lb) 179 | plt.legend(loc='best') 180 | plt.xlim(0, L[0].shape[0]) 181 | plt.ylim(ymin=0) 182 | 183 | 184 | def lanczos(L, X, K): 185 | """ 186 | Given the graph Laplacian and a data matrix, return a data matrix which can 187 | be multiplied by the filter coefficients to filter X using the Lanczos 188 | polynomial approximation. 189 | """ 190 | M, N = X.shape 191 | assert L.dtype == X.dtype 192 | 193 | def basis(L, X, K): 194 | """ 195 | Lanczos algorithm which computes the orthogonal matrix V and the 196 | tri-diagonal matrix H. 
197 | """ 198 | a = np.empty((K, N), L.dtype) 199 | b = np.zeros((K, N), L.dtype) 200 | V = np.empty((K, M, N), L.dtype) 201 | V[0, ...] = X / np.linalg.norm(X, axis=0) 202 | for k in range(K-1): 203 | W = L.dot(V[k, ...]) 204 | a[k, :] = np.sum(W * V[k, ...], axis=0) 205 | W = W - a[k, :] * V[k, ...] - ( 206 | b[k, :] * V[k-1, ...] if k > 0 else 0) 207 | b[k+1, :] = np.linalg.norm(W, axis=0) 208 | V[k+1, ...] = W / b[k+1, :] 209 | a[K-1, :] = np.sum(L.dot(V[K-1, ...]) * V[K-1, ...], axis=0) 210 | return V, a, b 211 | 212 | def diag_H(a, b, K): 213 | """Diagonalize the tri-diagonal H matrix.""" 214 | H = np.zeros((K*K, N), a.dtype) 215 | H[:K**2:K+1, :] = a 216 | H[1:(K-1)*K:K+1, :] = b[1:, :] 217 | H.shape = (K, K, N) 218 | Q = np.linalg.eigh(H.T, UPLO='L')[1] 219 | Q = np.swapaxes(Q, 1, 2).T 220 | return Q 221 | 222 | V, a, b = basis(L, X, K) 223 | Q = diag_H(a, b, K) 224 | Xt = np.empty((K, M, N), L.dtype) 225 | for n in range(N): 226 | Xt[..., n] = Q[..., n].T.dot(V[..., n]) 227 | Xt *= Q[0, :, np.newaxis, :] 228 | Xt *= np.linalg.norm(X, axis=0) 229 | return Xt # Q[0, ...] 230 | 231 | 232 | def rescale_L(L, lmax=2): 233 | """Rescale the Laplacian eigenvalues in [-1,1].""" 234 | M, M = L.shape 235 | I = scipy.sparse.identity(M, format='csr', dtype=L.dtype) 236 | L /= lmax / 2 237 | L -= I 238 | return L 239 | 240 | 241 | def chebyshev(L, X, K): 242 | """Return T_k X where T_k are the Chebyshev polynomials of order up to K. 243 | Complexity is O(KMN).""" 244 | M, N = X.shape 245 | assert L.dtype == X.dtype 246 | 247 | # L = rescale_L(L, lmax) 248 | # Xt = T @ X: MxM @ MxN. 249 | Xt = np.empty((K, M, N), L.dtype) 250 | # Xt_0 = T_0 X = I X = X. 251 | Xt[0, ...] = X 252 | # Xt_1 = T_1 X = L X. 253 | if K > 1: 254 | Xt[1, ...] = L.dot(X) 255 | # Xt_k = 2 L Xt_k-1 - Xt_k-2. 256 | for k in range(2, K): 257 | Xt[k, ...] = 2 * L.dot(Xt[k-1, ...]) - Xt[k-2, ...] 258 | return Xt 259 | -------------------------------------------------------------------------------- /lib/models.py: -------------------------------------------------------------------------------- 1 | from . import graph 2 | 3 | import tensorflow as tf 4 | import sklearn 5 | import scipy.sparse 6 | import numpy as np 7 | import os, time, collections, shutil 8 | 9 | 10 | #NFEATURES = 28**2 11 | #NCLASSES = 10 12 | 13 | 14 | # Common methods for all models 15 | 16 | 17 | class base_model(object): 18 | 19 | def __init__(self): 20 | self.regularizers = [] 21 | 22 | # High-level interface which runs the constructed computational graph. 23 | 24 | def predict(self, data, labels=None, sess=None): 25 | loss = 0 26 | size = data.shape[0] 27 | predictions = np.empty(size) 28 | sess = self._get_session(sess) 29 | for begin in range(0, size, self.batch_size): 30 | end = begin + self.batch_size 31 | end = min([end, size]) 32 | 33 | batch_data = np.zeros((self.batch_size, data.shape[1])) 34 | tmp_data = data[begin:end,:] 35 | if type(tmp_data) is not np.ndarray: 36 | tmp_data = tmp_data.toarray() # convert sparse matrices 37 | batch_data[:end-begin] = tmp_data 38 | feed_dict = {self.ph_data: batch_data, self.ph_dropout: 1} 39 | 40 | # Compute loss if labels are given. 
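            # The last batch is zero-padded up to batch_size; only its first
            # end-begin predictions are kept below.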
41 | if labels is not None: 42 | batch_labels = np.zeros(self.batch_size) 43 | batch_labels[:end-begin] = labels[begin:end] 44 | feed_dict[self.ph_labels] = batch_labels 45 | batch_pred, batch_loss = sess.run([self.op_prediction, self.op_loss], feed_dict) 46 | loss += batch_loss 47 | else: 48 | batch_pred = sess.run(self.op_prediction, feed_dict) 49 | 50 | predictions[begin:end] = batch_pred[:end-begin] 51 | 52 | if labels is not None: 53 | return predictions, loss * self.batch_size / size 54 | else: 55 | return predictions 56 | 57 | def evaluate(self, data, labels, sess=None): 58 | """ 59 | Runs one evaluation against the full epoch of data. 60 | Return the precision and the number of correct predictions. 61 | Batch evaluation saves memory and enables this to run on smaller GPUs. 62 | 63 | sess: the session in which the model has been trained. 64 | op: the Tensor that returns the number of correct predictions. 65 | data: size N x M 66 | N: number of signals (samples) 67 | M: number of vertices (features) 68 | labels: size N 69 | N: number of signals (samples) 70 | """ 71 | t_process, t_wall = time.process_time(), time.time() 72 | predictions, loss = self.predict(data, labels, sess) 73 | #print(predictions) 74 | ncorrects = sum(predictions == labels) 75 | accuracy = 100 * sklearn.metrics.accuracy_score(labels, predictions) 76 | f1 = 100 * sklearn.metrics.f1_score(labels, predictions, average='weighted') 77 | string = 'accuracy: {:.2f} ({:d} / {:d}), f1 (weighted): {:.2f}, loss: {:.2e}'.format( 78 | accuracy, ncorrects, len(labels), f1, loss) 79 | if sess is None: 80 | string += '\ntime: {:.0f}s (wall {:.0f}s)'.format(time.process_time()-t_process, time.time()-t_wall) 81 | return string, accuracy, f1, loss 82 | 83 | def fit(self, train_data, train_labels, val_data, val_labels): 84 | t_process, t_wall = time.process_time(), time.time() 85 | sess = tf.Session(graph=self.graph) 86 | shutil.rmtree(self._get_path('summaries'), ignore_errors=True) 87 | writer = tf.summary.FileWriter(self._get_path('summaries'), self.graph) 88 | shutil.rmtree(self._get_path('checkpoints'), ignore_errors=True) 89 | os.makedirs(self._get_path('checkpoints')) 90 | path = os.path.join(self._get_path('checkpoints'), 'model') 91 | sess.run(self.op_init) 92 | 93 | # Training. 94 | accuracies = [] 95 | losses = [] 96 | indices = collections.deque() 97 | num_steps = int(self.num_epochs * train_data.shape[0] / self.batch_size) 98 | for step in range(1, num_steps+1): 99 | 100 | # Be sure to have used all the samples before using one a second time. 101 | if len(indices) < self.batch_size: 102 | indices.extend(np.random.permutation(train_data.shape[0])) 103 | idx = [indices.popleft() for i in range(self.batch_size)] 104 | 105 | batch_data, batch_labels = train_data[idx,:], train_labels[idx] 106 | if type(batch_data) is not np.ndarray: 107 | batch_data = batch_data.toarray() # convert sparse matrices 108 | feed_dict = {self.ph_data: batch_data, self.ph_labels: batch_labels, self.ph_dropout: self.dropout} 109 | learning_rate, loss_average = sess.run([self.op_train, self.op_loss_average], feed_dict) 110 | 111 | # Periodical evaluation of the model. 
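            # Evaluation runs on the validation set; results are printed, logged
            # to TensorBoard, and the model parameters are checkpointed.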
112 | if step % self.eval_frequency == 0 or step == num_steps: 113 | epoch = step * self.batch_size / train_data.shape[0] 114 | print('step {} / {} (epoch {:.2f} / {}):'.format(step, num_steps, epoch, self.num_epochs)) 115 | print(' learning_rate = {:.2e}, loss_average = {:.2e}'.format(learning_rate, loss_average)) 116 | string, accuracy, f1, loss = self.evaluate(val_data, val_labels, sess) 117 | accuracies.append(accuracy) 118 | losses.append(loss) 119 | print(' validation {}'.format(string)) 120 | print(' time: {:.0f}s (wall {:.0f}s)'.format(time.process_time()-t_process, time.time()-t_wall)) 121 | 122 | # Summaries for TensorBoard. 123 | summary = tf.Summary() 124 | summary.ParseFromString(sess.run(self.op_summary, feed_dict)) 125 | summary.value.add(tag='validation/accuracy', simple_value=accuracy) 126 | summary.value.add(tag='validation/f1', simple_value=f1) 127 | summary.value.add(tag='validation/loss', simple_value=loss) 128 | writer.add_summary(summary, step) 129 | 130 | # Save model parameters (for evaluation). 131 | self.op_saver.save(sess, path, global_step=step) 132 | 133 | print('validation accuracy: peak = {:.2f}, mean = {:.2f}'.format(max(accuracies), np.mean(accuracies[-10:]))) 134 | writer.close() 135 | sess.close() 136 | 137 | t_step = (time.time() - t_wall) / num_steps 138 | return accuracies, losses, t_step 139 | 140 | def get_var(self, name): 141 | sess = self._get_session() 142 | var = self.graph.get_tensor_by_name(name + ':0') 143 | val = sess.run(var) 144 | sess.close() 145 | return val 146 | 147 | # Methods to construct the computational graph. 148 | 149 | def build_graph(self, M_0): 150 | """Build the computational graph of the model.""" 151 | self.graph = tf.Graph() 152 | with self.graph.as_default(): 153 | 154 | # Inputs. 155 | with tf.name_scope('inputs'): 156 | self.ph_data = tf.placeholder(tf.float32, (self.batch_size, M_0), 'data') 157 | self.ph_labels = tf.placeholder(tf.int32, (self.batch_size), 'labels') 158 | self.ph_dropout = tf.placeholder(tf.float32, (), 'dropout') 159 | 160 | # Model. 161 | op_logits = self.inference(self.ph_data, self.ph_dropout) 162 | self.op_loss, self.op_loss_average = self.loss(op_logits, self.ph_labels, self.regularization) 163 | self.op_train = self.training(self.op_loss, self.learning_rate, 164 | self.decay_steps, self.decay_rate, self.momentum) 165 | self.op_prediction = self.prediction(op_logits) 166 | 167 | # Initialize variables, i.e. weights and biases. 168 | self.op_init = tf.global_variables_initializer() 169 | 170 | # Summaries for TensorBoard and Save for model parameters. 171 | self.op_summary = tf.summary.merge_all() 172 | self.op_saver = tf.train.Saver(max_to_keep=5) 173 | 174 | self.graph.finalize() 175 | 176 | def inference(self, data, dropout): 177 | """ 178 | It builds the model, i.e. the computational graph, as far as 179 | is required for running the network forward to make predictions, 180 | i.e. return logits given raw data. 181 | 182 | data: size N x M 183 | N: number of signals (samples) 184 | M: number of vertices (features) 185 | training: we may want to discriminate the two, e.g. for dropout. 186 | True: the model is built for training. 187 | False: the model is built for evaluation. 
188 | """ 189 | # TODO: optimizations for sparse data 190 | logits = self._inference(data, dropout) 191 | return logits 192 | 193 | def probabilities(self, logits): 194 | """Return the probability of a sample to belong to each class.""" 195 | with tf.name_scope('probabilities'): 196 | probabilities = tf.nn.softmax(logits) 197 | return probabilities 198 | 199 | def prediction(self, logits): 200 | """Return the predicted classes.""" 201 | with tf.name_scope('prediction'): 202 | prediction = tf.argmax(logits, axis=1) 203 | return prediction 204 | 205 | def loss(self, logits, labels, regularization): 206 | """Adds to the inference model the layers required to generate loss.""" 207 | with tf.name_scope('loss'): 208 | with tf.name_scope('cross_entropy'): 209 | labels = tf.to_int64(labels) 210 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) 211 | cross_entropy = tf.reduce_mean(cross_entropy) 212 | with tf.name_scope('regularization'): 213 | regularization *= tf.add_n(self.regularizers) 214 | loss = cross_entropy + regularization 215 | 216 | # Summaries for TensorBoard. 217 | tf.summary.scalar('loss/cross_entropy', cross_entropy) 218 | tf.summary.scalar('loss/regularization', regularization) 219 | tf.summary.scalar('loss/total', loss) 220 | with tf.name_scope('averages'): 221 | averages = tf.train.ExponentialMovingAverage(0.9) 222 | op_averages = averages.apply([cross_entropy, regularization, loss]) 223 | tf.summary.scalar('loss/avg/cross_entropy', averages.average(cross_entropy)) 224 | tf.summary.scalar('loss/avg/regularization', averages.average(regularization)) 225 | tf.summary.scalar('loss/avg/total', averages.average(loss)) 226 | with tf.control_dependencies([op_averages]): 227 | loss_average = tf.identity(averages.average(loss), name='control') 228 | return loss, loss_average 229 | 230 | def training(self, loss, learning_rate, decay_steps, decay_rate=0.95, momentum=0.9): 231 | """Adds to the loss model the Ops required to generate and apply gradients.""" 232 | with tf.name_scope('training'): 233 | # Learning rate. 234 | global_step = tf.Variable(0, name='global_step', trainable=False) 235 | if decay_rate != 1: 236 | learning_rate = tf.train.exponential_decay( 237 | learning_rate, global_step, decay_steps, decay_rate, staircase=True) 238 | tf.summary.scalar('learning_rate', learning_rate) 239 | # Optimizer. 240 | if momentum == 0: 241 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 242 | #optimizer = tf.train.AdamOptimizer(learning_rate=0.001) 243 | else: 244 | optimizer = tf.train.MomentumOptimizer(learning_rate, momentum) 245 | grads = optimizer.compute_gradients(loss) 246 | op_gradients = optimizer.apply_gradients(grads, global_step=global_step) 247 | # Histograms. 248 | for grad, var in grads: 249 | if grad is None: 250 | print('warning: {} has no gradient'.format(var.op.name)) 251 | else: 252 | tf.summary.histogram(var.op.name + '/gradients', grad) 253 | # The op return the learning rate. 254 | with tf.control_dependencies([op_gradients]): 255 | op_train = tf.identity(learning_rate, name='control') 256 | return op_train 257 | 258 | # Helper methods. 
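    # The helpers below resolve paths for summaries and checkpoints, restore a
    # session from the latest checkpoint, and create (optionally L2-regularized)
    # weight and bias variables shared by all models.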
259 | 260 | def _get_path(self, folder): 261 | path = os.path.dirname(os.path.realpath(__file__)) 262 | return os.path.join(path, '..', folder, self.dir_name) 263 | 264 | def _get_session(self, sess=None): 265 | """Restore parameters if no session given.""" 266 | if sess is None: 267 | sess = tf.Session(graph=self.graph) 268 | filename = tf.train.latest_checkpoint(self._get_path('checkpoints')) 269 | self.op_saver.restore(sess, filename) 270 | return sess 271 | 272 | def _weight_variable(self, shape, regularization=True): 273 | initial = tf.truncated_normal_initializer(0, 0.1) 274 | var = tf.get_variable('weights', shape, tf.float32, initializer=initial) 275 | if regularization: 276 | self.regularizers.append(tf.nn.l2_loss(var)) 277 | tf.summary.histogram(var.op.name, var) 278 | return var 279 | 280 | def _bias_variable(self, shape, regularization=True): 281 | initial = tf.constant_initializer(0.1) 282 | var = tf.get_variable('bias', shape, tf.float32, initializer=initial) 283 | if regularization: 284 | self.regularizers.append(tf.nn.l2_loss(var)) 285 | tf.summary.histogram(var.op.name, var) 286 | return var 287 | 288 | def _conv2d(self, x, W): 289 | return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 290 | 291 | 292 | # Fully connected 293 | 294 | 295 | class fc1(base_model): 296 | def __init__(self): 297 | super().__init__() 298 | def _inference(self, x, dropout): 299 | W = self._weight_variable([NFEATURES, NCLASSES]) 300 | b = self._bias_variable([NCLASSES]) 301 | y = tf.matmul(x, W) + b 302 | return y 303 | 304 | class fc2(base_model): 305 | def __init__(self, nhiddens): 306 | super().__init__() 307 | self.nhiddens = nhiddens 308 | def _inference(self, x, dropout): 309 | with tf.name_scope('fc1'): 310 | W = self._weight_variable([NFEATURES, self.nhiddens]) 311 | b = self._bias_variable([self.nhiddens]) 312 | y = tf.nn.relu(tf.matmul(x, W) + b) 313 | with tf.name_scope('fc2'): 314 | W = self._weight_variable([self.nhiddens, NCLASSES]) 315 | b = self._bias_variable([NCLASSES]) 316 | y = tf.matmul(y, W) + b 317 | return y 318 | 319 | 320 | # Convolutional 321 | 322 | 323 | class cnn2(base_model): 324 | """Simple convolutional model.""" 325 | def __init__(self, K, F): 326 | super().__init__() 327 | self.K = K # Patch size 328 | self.F = F # Number of features 329 | def _inference(self, x, dropout): 330 | with tf.name_scope('conv1'): 331 | W = self._weight_variable([self.K, self.K, 1, self.F]) 332 | b = self._bias_variable([self.F]) 333 | # b = self._bias_variable([1, 28, 28, self.F]) 334 | x_2d = tf.reshape(x, [-1,28,28,1]) 335 | y_2d = self._conv2d(x_2d, W) + b 336 | y_2d = tf.nn.relu(y_2d) 337 | with tf.name_scope('fc1'): 338 | y = tf.reshape(y_2d, [-1, NFEATURES*self.F]) 339 | W = self._weight_variable([NFEATURES*self.F, NCLASSES]) 340 | b = self._bias_variable([NCLASSES]) 341 | y = tf.matmul(y, W) + b 342 | return y 343 | 344 | class fcnn2(base_model): 345 | """CNN using the FFT.""" 346 | def __init__(self, F): 347 | super().__init__() 348 | self.F = F # Number of features 349 | def _inference(self, x, dropout): 350 | with tf.name_scope('conv1'): 351 | # Transform to Fourier domain 352 | x_2d = tf.reshape(x, [-1, 28, 28]) 353 | x_2d = tf.complex(x_2d, 0) 354 | xf_2d = tf.fft2d(x_2d) 355 | xf = tf.reshape(xf_2d, [-1, NFEATURES]) 356 | xf = tf.expand_dims(xf, 1) # NSAMPLES x 1 x NFEATURES 357 | xf = tf.transpose(xf) # NFEATURES x 1 x NSAMPLES 358 | # Filter 359 | Wreal = self._weight_variable([int(NFEATURES/2), self.F, 1]) 360 | Wimg = self._weight_variable([int(NFEATURES/2), 
self.F, 1]) 361 | W = tf.complex(Wreal, Wimg) 362 | xf = xf[:int(NFEATURES/2), :, :] 363 | yf = tf.matmul(W, xf) # for each feature 364 | yf = tf.concat([yf, tf.conj(yf)], axis=0) 365 | yf = tf.transpose(yf) # NSAMPLES x NFILTERS x NFEATURES 366 | yf_2d = tf.reshape(yf, [-1, 28, 28]) 367 | # Transform back to spatial domain 368 | y_2d = tf.ifft2d(yf_2d) 369 | y_2d = tf.real(y_2d) 370 | y = tf.reshape(y_2d, [-1, self.F, NFEATURES]) 371 | # Bias and non-linearity 372 | b = self._bias_variable([1, self.F, 1]) 373 | # b = self._bias_variable([1, self.F, NFEATURES]) 374 | y += b # NSAMPLES x NFILTERS x NFEATURES 375 | y = tf.nn.relu(y) 376 | with tf.name_scope('fc1'): 377 | W = self._weight_variable([self.F*NFEATURES, NCLASSES]) 378 | b = self._bias_variable([NCLASSES]) 379 | y = tf.reshape(y, [-1, self.F*NFEATURES]) 380 | y = tf.matmul(y, W) + b 381 | return y 382 | 383 | 384 | # Graph convolutional 385 | 386 | 387 | class fgcnn2(base_model): 388 | """Graph CNN with full weights, i.e. patch has the same size as input.""" 389 | def __init__(self, L, F): 390 | super().__init__() 391 | #self.L = L # Graph Laplacian, NFEATURES x NFEATURES 392 | self.F = F # Number of filters 393 | _, self.U = graph.fourier(L) 394 | def _inference(self, x, dropout): 395 | # x: NSAMPLES x NFEATURES 396 | with tf.name_scope('gconv1'): 397 | # Transform to Fourier domain 398 | U = tf.constant(self.U, dtype=tf.float32) 399 | xf = tf.matmul(x, U) 400 | xf = tf.expand_dims(xf, 1) # NSAMPLES x 1 x NFEATURES 401 | xf = tf.transpose(xf) # NFEATURES x 1 x NSAMPLES 402 | # Filter 403 | W = self._weight_variable([NFEATURES, self.F, 1]) 404 | yf = tf.matmul(W, xf) # for each feature 405 | yf = tf.transpose(yf) # NSAMPLES x NFILTERS x NFEATURES 406 | yf = tf.reshape(yf, [-1, NFEATURES]) 407 | # Transform back to graph domain 408 | Ut = tf.transpose(U) 409 | y = tf.matmul(yf, Ut) 410 | y = tf.reshape(yf, [-1, self.F, NFEATURES]) 411 | # Bias and non-linearity 412 | b = self._bias_variable([1, self.F, 1]) 413 | # b = self._bias_variable([1, self.F, NFEATURES]) 414 | y += b # NSAMPLES x NFILTERS x NFEATURES 415 | y = tf.nn.relu(y) 416 | with tf.name_scope('fc1'): 417 | W = self._weight_variable([self.F*NFEATURES, NCLASSES]) 418 | b = self._bias_variable([NCLASSES]) 419 | y = tf.reshape(y, [-1, self.F*NFEATURES]) 420 | y = tf.matmul(y, W) + b 421 | return y 422 | 423 | 424 | class lgcnn2_1(base_model): 425 | """Graph CNN which uses the Lanczos approximation.""" 426 | def __init__(self, L, F, K): 427 | super().__init__() 428 | self.L = L # Graph Laplacian, M x M 429 | self.F = F # Number of filters 430 | self.K = K # Polynomial order, i.e. 
filter size (number of hopes) 431 | def _inference(self, x, dropout): 432 | with tf.name_scope('gconv1'): 433 | N, M, K = x.get_shape() # N: number of samples, M: number of features 434 | M = int(M) 435 | # Transform to Lanczos basis 436 | xl = tf.reshape(x, [-1, self.K]) # NM x K 437 | # Filter 438 | W = self._weight_variable([self.K, self.F]) 439 | y = tf.matmul(xl, W) # NM x F 440 | y = tf.reshape(y, [-1, M, self.F]) # N x M x F 441 | # Bias and non-linearity 442 | b = self._bias_variable([1, 1, self.F]) 443 | # b = self._bias_variable([1, M, self.F]) 444 | y += b # N x M x F 445 | y = tf.nn.relu(y) 446 | with tf.name_scope('fc1'): 447 | W = self._weight_variable([self.F*M, NCLASSES]) 448 | b = self._bias_variable([NCLASSES]) 449 | y = tf.reshape(y, [-1, self.F*M]) 450 | y = tf.matmul(y, W) + b 451 | return y 452 | 453 | class lgcnn2_2(base_model): 454 | """Graph CNN which uses the Lanczos approximation.""" 455 | def __init__(self, L, F, K): 456 | super().__init__() 457 | self.L = L # Graph Laplacian, M x M 458 | self.F = F # Number of filters 459 | self.K = K # Polynomial order, i.e. filter size (number of hopes) 460 | def _inference(self, x, dropout): 461 | with tf.name_scope('gconv1'): 462 | N, M = x.get_shape() # N: number of samples, M: number of features 463 | M = int(M) 464 | # Transform to Lanczos basis 465 | xl = tf.transpose(x) # M x N 466 | def lanczos(x): 467 | return graph.lanczos(self.L, x, self.K) 468 | xl = tf.py_func(lanczos, [xl], [tf.float32])[0] 469 | xl = tf.transpose(xl) # N x M x K 470 | xl = tf.reshape(xl, [-1, self.K]) # NM x K 471 | # Filter 472 | W = self._weight_variable([self.K, self.F]) 473 | y = tf.matmul(xl, W) # NM x F 474 | y = tf.reshape(y, [-1, M, self.F]) # N x M x F 475 | # Bias and non-linearity 476 | # b = self._bias_variable([1, 1, self.F]) 477 | b = self._bias_variable([1, M, self.F]) 478 | y += b # N x M x F 479 | y = tf.nn.relu(y) 480 | with tf.name_scope('fc1'): 481 | W = self._weight_variable([self.F*M, NCLASSES]) 482 | b = self._bias_variable([NCLASSES]) 483 | y = tf.reshape(y, [-1, self.F*M]) 484 | y = tf.matmul(y, W) + b 485 | return y 486 | 487 | 488 | class cgcnn2_2(base_model): 489 | """Graph CNN which uses the Chebyshev approximation.""" 490 | def __init__(self, L, F, K): 491 | super().__init__() 492 | self.L = graph.rescale_L(L, lmax=2) # Graph Laplacian, M x M 493 | self.F = F # Number of filters 494 | self.K = K # Polynomial order, i.e. 
filter size (number of hopes) 495 | def _inference(self, x, dropout): 496 | with tf.name_scope('gconv1'): 497 | N, M = x.get_shape() # N: number of samples, M: number of features 498 | M = int(M) 499 | # Transform to Chebyshev basis 500 | xc = tf.transpose(x) # M x N 501 | def chebyshev(x): 502 | return graph.chebyshev(self.L, x, self.K) 503 | xc = tf.py_func(chebyshev, [xc], [tf.float32])[0] 504 | xc = tf.transpose(xc) # N x M x K 505 | xc = tf.reshape(xc, [-1, self.K]) # NM x K 506 | # Filter 507 | W = self._weight_variable([self.K, self.F]) 508 | y = tf.matmul(xc, W) # NM x F 509 | y = tf.reshape(y, [-1, M, self.F]) # N x M x F 510 | # Bias and non-linearity 511 | # b = self._bias_variable([1, 1, self.F]) 512 | b = self._bias_variable([1, M, self.F]) 513 | y += b # N x M x F 514 | y = tf.nn.relu(y) 515 | with tf.name_scope('fc1'): 516 | W = self._weight_variable([self.F*M, NCLASSES]) 517 | b = self._bias_variable([NCLASSES]) 518 | y = tf.reshape(y, [-1, self.F*M]) 519 | y = tf.matmul(y, W) + b 520 | return y 521 | 522 | 523 | class cgcnn2_3(base_model): 524 | """Graph CNN which uses the Chebyshev approximation.""" 525 | def __init__(self, L, F, K): 526 | super().__init__() 527 | L = graph.rescale_L(L, lmax=2) # Graph Laplacian, M x M 528 | self.L = L.toarray() 529 | self.F = F # Number of filters 530 | self.K = K # Polynomial order, i.e. filter size (number of hopes) 531 | def _inference(self, x, dropout): 532 | with tf.name_scope('gconv1'): 533 | N, M = x.get_shape() # N: number of samples, M: number of features 534 | M = int(M) 535 | # Filter 536 | W = self._weight_variable([self.K, self.F]) 537 | def filter(xt, k): 538 | xt = tf.reshape(xt, [-1, 1]) # NM x 1 539 | w = tf.slice(W, [k,0], [1,-1]) # 1 x F 540 | y = tf.matmul(xt, w) # NM x F 541 | return tf.reshape(y, [-1, M, self.F]) # N x M x F 542 | xt0 = x 543 | y = filter(xt0, 0) 544 | if self.K > 1: 545 | xt1 = tf.matmul(x, self.L, b_is_sparse=True) # N x M 546 | y += filter(xt1, 1) 547 | for k in range(2, self.K): 548 | xt2 = 2 * tf.matmul(xt1, self.L, b_is_sparse=True) - xt0 # N x M 549 | y += filter(xt2, k) 550 | xt0, xt1 = xt1, xt2 551 | # Bias and non-linearity 552 | # b = self._bias_variable([1, 1, self.F]) 553 | b = self._bias_variable([1, M, self.F]) 554 | y += b # N x M x F 555 | y = tf.nn.relu(y) 556 | with tf.name_scope('fc1'): 557 | W = self._weight_variable([self.F*M, NCLASSES]) 558 | b = self._bias_variable([NCLASSES]) 559 | y = tf.reshape(y, [-1, self.F*M]) 560 | y = tf.matmul(y, W) + b 561 | return y 562 | 563 | 564 | class cgcnn2_4(base_model): 565 | """Graph CNN which uses the Chebyshev approximation.""" 566 | def __init__(self, L, F, K): 567 | super().__init__() 568 | L = graph.rescale_L(L, lmax=2) # Graph Laplacian, M x M 569 | L = L.tocoo() 570 | data = L.data 571 | indices = np.empty((L.nnz, 2)) 572 | indices[:,0] = L.row 573 | indices[:,1] = L.col 574 | L = tf.SparseTensor(indices, data, L.shape) 575 | self.L = tf.sparse_reorder(L) 576 | self.F = F # Number of filters 577 | self.K = K # Polynomial order, i.e. 
filter size (number of hopes) 578 | def _inference(self, x, dropout): 579 | with tf.name_scope('gconv1'): 580 | N, M = x.get_shape() # N: number of samples, M: number of features 581 | M = int(M) 582 | # Filter 583 | W = self._weight_variable([self.K, self.F]) 584 | def filter(xt, k): 585 | xt = tf.transpose(xt) # N x M 586 | xt = tf.reshape(xt, [-1, 1]) # NM x 1 587 | w = tf.slice(W, [k,0], [1,-1]) # 1 x F 588 | y = tf.matmul(xt, w) # NM x F 589 | return tf.reshape(y, [-1, M, self.F]) # N x M x F 590 | xt0 = tf.transpose(x) # M x N 591 | y = filter(xt0, 0) 592 | if self.K > 1: 593 | xt1 = tf.sparse_tensor_dense_matmul(self.L, xt0) 594 | y += filter(xt1, 1) 595 | for k in range(2, self.K): 596 | xt2 = 2 * tf.sparse_tensor_dense_matmul(self.L, xt1) - xt0 # M x N 597 | y += filter(xt2, k) 598 | xt0, xt1 = xt1, xt2 599 | # Bias and non-linearity 600 | # b = self._bias_variable([1, 1, self.F]) 601 | b = self._bias_variable([1, M, self.F]) 602 | y += b # N x M x F 603 | y = tf.nn.relu(y) 604 | with tf.name_scope('fc1'): 605 | W = self._weight_variable([self.F*M, NCLASSES]) 606 | b = self._bias_variable([NCLASSES]) 607 | y = tf.reshape(y, [-1, self.F*M]) 608 | y = tf.matmul(y, W) + b 609 | return y 610 | 611 | 612 | class cgcnn2_5(base_model): 613 | """Graph CNN which uses the Chebyshev approximation.""" 614 | def __init__(self, L, F, K): 615 | super().__init__() 616 | L = graph.rescale_L(L, lmax=2) # Graph Laplacian, M x M 617 | L = L.tocoo() 618 | data = L.data 619 | indices = np.empty((L.nnz, 2)) 620 | indices[:,0] = L.row 621 | indices[:,1] = L.col 622 | L = tf.SparseTensor(indices, data, L.shape) 623 | self.L = tf.sparse_reorder(L) 624 | self.F = F # Number of filters 625 | self.K = K # Polynomial order, i.e. filter size (number of hopes) 626 | def _inference(self, x, dropout): 627 | with tf.name_scope('gconv1'): 628 | N, M = x.get_shape() # N: number of samples, M: number of features 629 | M = int(M) 630 | # Transform to Chebyshev basis 631 | xt0 = tf.transpose(x) # M x N 632 | xt = tf.expand_dims(xt0, 0) # 1 x M x N 633 | def concat(xt, x): 634 | x = tf.expand_dims(x, 0) # 1 x M x N 635 | return tf.concat([xt, x], axis=0) # K x M x N 636 | if self.K > 1: 637 | xt1 = tf.sparse_tensor_dense_matmul(self.L, xt0) 638 | xt = concat(xt, xt1) 639 | for k in range(2, self.K): 640 | xt2 = 2 * tf.sparse_tensor_dense_matmul(self.L, xt1) - xt0 # M x N 641 | xt = concat(xt, xt2) 642 | xt0, xt1 = xt1, xt2 643 | xt = tf.transpose(xt) # N x M x K 644 | xt = tf.reshape(xt, [-1,self.K]) # NM x K 645 | # Filter 646 | W = self._weight_variable([self.K, self.F]) 647 | y = tf.matmul(xt, W) # NM x F 648 | y = tf.reshape(y, [-1, M, self.F]) # N x M x F 649 | # Bias and non-linearity 650 | # b = self._bias_variable([1, 1, self.F]) 651 | b = self._bias_variable([1, M, self.F]) 652 | y += b # N x M x F 653 | y = tf.nn.relu(y) 654 | with tf.name_scope('fc1'): 655 | W = self._weight_variable([self.F*M, NCLASSES]) 656 | b = self._bias_variable([NCLASSES]) 657 | y = tf.reshape(y, [-1, self.F*M]) 658 | y = tf.matmul(y, W) + b 659 | return y 660 | 661 | 662 | def bspline_basis(K, x, degree=3): 663 | """ 664 | Return the B-spline basis. 665 | 666 | K: number of control points. 667 | x: evaluation points 668 | or number of evenly distributed evaluation points. 669 | degree: degree of the spline. Cubic spline by default. 670 | """ 671 | if np.isscalar(x): 672 | x = np.linspace(0, 1, x) 673 | 674 | # Evenly distributed knot vectors. 
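    # Clamped (open) knot vector: the boundary knots x.min() and x.max() have
    # multiplicity degree+1, the standard choice for a B-spline basis.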
675 | kv1 = x.min() * np.ones(degree) 676 | kv2 = np.linspace(x.min(), x.max(), K-degree+1) 677 | kv3 = x.max() * np.ones(degree) 678 | kv = np.concatenate((kv1, kv2, kv3)) 679 | 680 | # Cox - DeBoor recursive function to compute one spline over x. 681 | def cox_deboor(k, d): 682 | # Test for end conditions, the rectangular degree zero spline. 683 | if (d == 0): 684 | return ((x - kv[k] >= 0) & (x - kv[k + 1] < 0)).astype(int) 685 | 686 | denom1 = kv[k + d] - kv[k] 687 | term1 = 0 688 | if denom1 > 0: 689 | term1 = ((x - kv[k]) / denom1) * cox_deboor(k, d - 1) 690 | 691 | denom2 = kv[k + d + 1] - kv[k + 1] 692 | term2 = 0 693 | if denom2 > 0: 694 | term2 = ((-(x - kv[k + d + 1]) / denom2) * cox_deboor(k + 1, d - 1)) 695 | 696 | return term1 + term2 697 | 698 | # Compute basis for each point 699 | basis = np.column_stack([cox_deboor(k, degree) for k in range(K)]) 700 | basis[-1,-1] = 1 701 | return basis 702 | 703 | 704 | class cgcnn(base_model): 705 | """ 706 | Graph CNN which uses the Chebyshev approximation. 707 | 708 | The following are hyper-parameters of graph convolutional layers. 709 | They are lists, which length is equal to the number of gconv layers. 710 | F: Number of features. 711 | K: List of polynomial orders, i.e. filter sizes or number of hopes. 712 | p: Pooling size. 713 | Should be 1 (no pooling) or a power of 2 (reduction by 2 at each coarser level). 714 | Beware to have coarsened enough. 715 | 716 | L: List of Graph Laplacians. Size M x M. One per coarsening level. 717 | 718 | The following are hyper-parameters of fully connected layers. 719 | They are lists, which length is equal to the number of fc layers. 720 | M: Number of features per sample, i.e. number of hidden neurons. 721 | The last layer is the softmax, i.e. M[-1] is the number of classes. 722 | 723 | The following are choices of implementation for various blocks. 724 | filter: filtering operation, e.g. chebyshev5, lanczos2 etc. 725 | brelu: bias and relu, e.g. b1relu or b2relu. 726 | pool: pooling, e.g. mpool1. 727 | 728 | Training parameters: 729 | num_epochs: Number of training epochs. 730 | learning_rate: Initial learning rate. 731 | decay_rate: Base of exponential decay. No decay with 1. 732 | decay_steps: Number of steps after which the learning rate decays. 733 | momentum: Momentum. 0 indicates no momentum. 734 | 735 | Regularization parameters: 736 | regularization: L2 regularizations of weights and biases. 737 | dropout: Dropout (fc layers): probability to keep hidden neurons. No dropout with 1. 738 | batch_size: Batch size. Must divide evenly into the dataset sizes. 739 | eval_frequency: Number of steps between evaluations. 740 | 741 | Directories: 742 | dir_name: Name for directories (summaries and model parameters). 743 | """ 744 | def __init__(self, L, F, K, p, M, filter='chebyshev5', brelu='b1relu', pool='mpool1', 745 | num_epochs=20, learning_rate=0.1, decay_rate=0.95, decay_steps=None, momentum=0.9, 746 | regularization=0, dropout=0, batch_size=100, eval_frequency=200, 747 | dir_name=''): 748 | super().__init__() 749 | 750 | # Verify the consistency w.r.t. the number of layers. 751 | assert len(L) >= len(F) == len(K) == len(p) 752 | assert np.all(np.array(p) >= 1) 753 | p_log2 = np.where(np.array(p) > 1, np.log2(p), 0) 754 | assert np.all(np.mod(p_log2, 1) == 0) # Powers of 2. 755 | assert len(L) >= 1 + np.sum(p_log2) # Enough coarsening levels for pool sizes. 756 | 757 | # Keep the useful Laplacians only. May be zero. 
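        # After a pooling of size p, the next layer operates on a graph log2(p)
        # levels coarser, so the intermediate Laplacians are never used.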
758 | M_0 = L[0].shape[0] 759 | j = 0 760 | self.L = [] 761 | for pp in p: 762 | self.L.append(L[j]) 763 | j += int(np.log2(pp)) if pp > 1 else 0 764 | L = self.L 765 | 766 | # Print information about NN architecture. 767 | Ngconv = len(p) 768 | Nfc = len(M) 769 | print('NN architecture') 770 | print(' input: M_0 = {}'.format(M_0)) 771 | for i in range(Ngconv): 772 | print(' layer {0}: cgconv{0}'.format(i+1)) 773 | print(' representation: M_{0} * F_{1} / p_{1} = {2} * {3} / {4} = {5}'.format( 774 | i, i+1, L[i].shape[0], F[i], p[i], L[i].shape[0]*F[i]//p[i])) 775 | F_last = F[i-1] if i > 0 else 1 776 | print(' weights: F_{0} * F_{1} * K_{1} = {2} * {3} * {4} = {5}'.format( 777 | i, i+1, F_last, F[i], K[i], F_last*F[i]*K[i])) 778 | if brelu == 'b1relu': 779 | print(' biases: F_{} = {}'.format(i+1, F[i])) 780 | elif brelu == 'b2relu': 781 | print(' biases: M_{0} * F_{0} = {1} * {2} = {3}'.format( 782 | i+1, L[i].shape[0], F[i], L[i].shape[0]*F[i])) 783 | for i in range(Nfc): 784 | name = 'logits (softmax)' if i == Nfc-1 else 'fc{}'.format(i+1) 785 | print(' layer {}: {}'.format(Ngconv+i+1, name)) 786 | print(' representation: M_{} = {}'.format(Ngconv+i+1, M[i])) 787 | M_last = M[i-1] if i > 0 else M_0 if Ngconv == 0 else L[-1].shape[0] * F[-1] // p[-1] 788 | print(' weights: M_{} * M_{} = {} * {} = {}'.format( 789 | Ngconv+i, Ngconv+i+1, M_last, M[i], M_last*M[i])) 790 | print(' biases: M_{} = {}'.format(Ngconv+i+1, M[i])) 791 | 792 | # Store attributes and bind operations. 793 | self.L, self.F, self.K, self.p, self.M = L, F, K, p, M 794 | self.num_epochs, self.learning_rate = num_epochs, learning_rate 795 | self.decay_rate, self.decay_steps, self.momentum = decay_rate, decay_steps, momentum 796 | self.regularization, self.dropout = regularization, dropout 797 | self.batch_size, self.eval_frequency = batch_size, eval_frequency 798 | self.dir_name = dir_name 799 | self.filter = getattr(self, filter) 800 | self.brelu = getattr(self, brelu) 801 | self.pool = getattr(self, pool) 802 | 803 | # Build the computational graph. 
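        # M_0 is the number of vertices of the finest graph, i.e. the
        # dimensionality of the input signals.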
804 | self.build_graph(M_0) 805 | 806 | def filter_in_fourier(self, x, L, Fout, K, U, W): 807 | # TODO: N x F x M would avoid the permutations 808 | N, M, Fin = x.get_shape() 809 | N, M, Fin = int(N), int(M), int(Fin) 810 | x = tf.transpose(x, perm=[1, 2, 0]) # M x Fin x N 811 | # Transform to Fourier domain 812 | x = tf.reshape(x, [M, Fin*N]) # M x Fin*N 813 | x = tf.matmul(U, x) # M x Fin*N 814 | x = tf.reshape(x, [M, Fin, N]) # M x Fin x N 815 | # Filter 816 | x = tf.matmul(W, x) # for each feature 817 | x = tf.transpose(x) # N x Fout x M 818 | x = tf.reshape(x, [N*Fout, M]) # N*Fout x M 819 | # Transform back to graph domain 820 | x = tf.matmul(x, U) # N*Fout x M 821 | x = tf.reshape(x, [N, Fout, M]) # N x Fout x M 822 | return tf.transpose(x, perm=[0, 2, 1]) # N x M x Fout 823 | 824 | def fourier(self, x, L, Fout, K): 825 | assert K == L.shape[0] # artificial but useful to compute number of parameters 826 | N, M, Fin = x.get_shape() 827 | N, M, Fin = int(N), int(M), int(Fin) 828 | # Fourier basis 829 | _, U = graph.fourier(L) 830 | U = tf.constant(U.T, dtype=tf.float32) 831 | # Weights 832 | W = self._weight_variable([M, Fout, Fin], regularization=False) 833 | return self.filter_in_fourier(x, L, Fout, K, U, W) 834 | 835 | def spline(self, x, L, Fout, K): 836 | N, M, Fin = x.get_shape() 837 | N, M, Fin = int(N), int(M), int(Fin) 838 | # Fourier basis 839 | lamb, U = graph.fourier(L) 840 | U = tf.constant(U.T, dtype=tf.float32) # M x M 841 | # Spline basis 842 | B = bspline_basis(K, lamb, degree=3) # M x K 843 | #B = bspline_basis(K, len(lamb), degree=3) # M x K 844 | B = tf.constant(B, dtype=tf.float32) 845 | # Weights 846 | W = self._weight_variable([K, Fout*Fin], regularization=False) 847 | W = tf.matmul(B, W) # M x Fout*Fin 848 | W = tf.reshape(W, [M, Fout, Fin]) 849 | return self.filter_in_fourier(x, L, Fout, K, U, W) 850 | 851 | def chebyshev2(self, x, L, Fout, K): 852 | """ 853 | Filtering with Chebyshev interpolation 854 | Implementation: numpy. 855 | 856 | Data: x of size N x M x F 857 | N: number of signals 858 | M: number of vertices 859 | F: number of features per signal per vertex 860 | """ 861 | N, M, Fin = x.get_shape() 862 | N, M, Fin = int(N), int(M), int(Fin) 863 | # Rescale Laplacian. Copy to not modify the shared L. 864 | L = scipy.sparse.csr_matrix(L) 865 | L = graph.rescale_L(L, lmax=2) 866 | # Transform to Chebyshev basis 867 | x = tf.transpose(x, perm=[1, 2, 0]) # M x Fin x N 868 | x = tf.reshape(x, [M, Fin*N]) # M x Fin*N 869 | def chebyshev(x): 870 | return graph.chebyshev(L, x, K) 871 | x = tf.py_func(chebyshev, [x], [tf.float32])[0] # K x M x Fin*N 872 | x = tf.reshape(x, [K, M, Fin, N]) # K x M x Fin x N 873 | x = tf.transpose(x, perm=[3,1,2,0]) # N x M x Fin x K 874 | x = tf.reshape(x, [N*M, Fin*K]) # N*M x Fin*K 875 | # Filter: Fin*Fout filters of order K, i.e. one filterbank per feature. 876 | W = self._weight_variable([Fin*K, Fout], regularization=False) 877 | x = tf.matmul(x, W) # N*M x Fout 878 | return tf.reshape(x, [N, M, Fout]) # N x M x Fout 879 | 880 | def chebyshev5(self, x, L, Fout, K): 881 | N, M, Fin = x.get_shape() 882 | N, M, Fin = int(N), int(M), int(Fin) 883 | # Rescale Laplacian and store as a TF sparse tensor. Copy to not modify the shared L. 
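        # Chebyshev recurrence below: x_0 = x, x_1 = L x, x_k = 2 L x_{k-1} - x_{k-2},
        # with L rescaled so its spectrum lies in [-1, 1]; filtering thus costs K-1
        # sparse matrix-vector products and no eigendecomposition.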
884 | L = scipy.sparse.csr_matrix(L) 885 | L = graph.rescale_L(L, lmax=2) 886 | L = L.tocoo() 887 | indices = np.column_stack((L.row, L.col)) 888 | L = tf.SparseTensor(indices, L.data, L.shape) 889 | L = tf.sparse_reorder(L) 890 | # Transform to Chebyshev basis 891 | x0 = tf.transpose(x, perm=[1, 2, 0]) # M x Fin x N 892 | x0 = tf.reshape(x0, [M, Fin*N]) # M x Fin*N 893 | x = tf.expand_dims(x0, 0) # 1 x M x Fin*N 894 | def concat(x, x_): 895 | x_ = tf.expand_dims(x_, 0) # 1 x M x Fin*N 896 | return tf.concat([x, x_], axis=0) # K x M x Fin*N 897 | if K > 1: 898 | x1 = tf.sparse_tensor_dense_matmul(L, x0) 899 | x = concat(x, x1) 900 | for k in range(2, K): 901 | x2 = 2 * tf.sparse_tensor_dense_matmul(L, x1) - x0 # M x Fin*N 902 | x = concat(x, x2) 903 | x0, x1 = x1, x2 904 | x = tf.reshape(x, [K, M, Fin, N]) # K x M x Fin x N 905 | x = tf.transpose(x, perm=[3,1,2,0]) # N x M x Fin x K 906 | x = tf.reshape(x, [N*M, Fin*K]) # N*M x Fin*K 907 | # Filter: Fin*Fout filters of order K, i.e. one filterbank per feature pair. 908 | W = self._weight_variable([Fin*K, Fout], regularization=False) 909 | x = tf.matmul(x, W) # N*M x Fout 910 | return tf.reshape(x, [N, M, Fout]) # N x M x Fout 911 | 912 | def b1relu(self, x): 913 | """Bias and ReLU. One bias per filter.""" 914 | N, M, F = x.get_shape() 915 | b = self._bias_variable([1, 1, int(F)], regularization=False) 916 | return tf.nn.relu(x + b) 917 | 918 | def b2relu(self, x): 919 | """Bias and ReLU. One bias per vertex per filter.""" 920 | N, M, F = x.get_shape() 921 | b = self._bias_variable([1, int(M), int(F)], regularization=False) 922 | return tf.nn.relu(x + b) 923 | 924 | def mpool1(self, x, p): 925 | """Max pooling of size p. Should be a power of 2.""" 926 | if p > 1: 927 | x = tf.expand_dims(x, 3) # N x M x F x 1 928 | x = tf.nn.max_pool(x, ksize=[1,p,1,1], strides=[1,p,1,1], padding='SAME') 929 | #tf.maximum 930 | return tf.squeeze(x, [3]) # N x M/p x F 931 | else: 932 | return x 933 | 934 | def apool1(self, x, p): 935 | """Average pooling of size p. Should be a power of 2.""" 936 | if p > 1: 937 | x = tf.expand_dims(x, 3) # N x M x F x 1 938 | x = tf.nn.avg_pool(x, ksize=[1,p,1,1], strides=[1,p,1,1], padding='SAME') 939 | return tf.squeeze(x, [3]) # N x M/p x F 940 | else: 941 | return x 942 | 943 | def fc(self, x, Mout, relu=True): 944 | """Fully connected layer with Mout features.""" 945 | N, Min = x.get_shape() 946 | W = self._weight_variable([int(Min), Mout], regularization=True) 947 | b = self._bias_variable([Mout], regularization=True) 948 | x = tf.matmul(x, W) + b 949 | return tf.nn.relu(x) if relu else x 950 | 951 | def _inference(self, x, dropout): 952 | # Graph convolutional layers. 953 | x = tf.expand_dims(x, 2) # N x M x F=1 954 | for i in range(len(self.p)): 955 | with tf.variable_scope('conv{}'.format(i+1)): 956 | with tf.name_scope('filter'): 957 | x = self.filter(x, self.L[i], self.F[i], self.K[i]) 958 | with tf.name_scope('bias_relu'): 959 | x = self.brelu(x) 960 | with tf.name_scope('pooling'): 961 | x = self.pool(x, self.p[i]) 962 | 963 | # Fully connected hidden layers. 964 | N, M, F = x.get_shape() 965 | x = tf.reshape(x, [int(N), int(M*F)]) # N x M 966 | for i,M in enumerate(self.M[:-1]): 967 | with tf.variable_scope('fc{}'.format(i+1)): 968 | x = self.fc(x, M) 969 | x = tf.nn.dropout(x, dropout) 970 | 971 | # Logits linear layer, i.e. softmax without normalization. 
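        # M[-1] is the number of classes; the softmax itself is applied inside the
        # cross-entropy loss and in probabilities().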
972 | with tf.variable_scope('logits'): 973 | x = self.fc(x, self.M[-1], relu=False) 974 | return x 975 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import gensim 2 | import sklearn, sklearn.datasets 3 | import sklearn.naive_bayes, sklearn.linear_model, sklearn.svm, sklearn.neighbors, sklearn.ensemble 4 | import matplotlib.pyplot as plt 5 | import scipy.sparse 6 | import numpy as np 7 | import time, re 8 | 9 | 10 | # Helpers to process text documents. 11 | 12 | 13 | class TextDataset(object): 14 | def clean_text(self, num='substitute'): 15 | # TODO: stemming, lemmatisation 16 | for i,doc in enumerate(self.documents): 17 | # Digits. 18 | if num is 'spell': 19 | doc = doc.replace('0', ' zero ') 20 | doc = doc.replace('1', ' one ') 21 | doc = doc.replace('2', ' two ') 22 | doc = doc.replace('3', ' three ') 23 | doc = doc.replace('4', ' four ') 24 | doc = doc.replace('5', ' five ') 25 | doc = doc.replace('6', ' six ') 26 | doc = doc.replace('7', ' seven ') 27 | doc = doc.replace('8', ' eight ') 28 | doc = doc.replace('9', ' nine ') 29 | elif num is 'substitute': 30 | # All numbers are equal. Useful for embedding (countable words) ? 31 | doc = re.sub('(\\d+)', ' NUM ', doc) 32 | elif num is 'remove': 33 | # Numbers are uninformative (they are all over the place). Useful for bag-of-words ? 34 | # But maybe some kind of documents contain more numbers, e.g. finance. 35 | # Some documents are indeed full of numbers. At least in 20NEWS. 36 | doc = re.sub('[0-9]', ' ', doc) 37 | # Remove everything except a-z characters and single space. 38 | doc = doc.replace('$', ' dollar ') 39 | doc = doc.lower() 40 | doc = re.sub('[^a-z]', ' ', doc) 41 | doc = ' '.join(doc.split()) # same as doc = re.sub('\s{2,}', ' ', doc) 42 | self.documents[i] = doc 43 | 44 | def vectorize(self, **params): 45 | # TODO: count or tf-idf. Or in normalize ? 
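The TODO at the start of `vectorize` (just above; the method body continues below) leaves the count vs. tf-idf choice open. A hedged sketch of what a tf-idf variant could look like — the method name is hypothetical and not part of the original class, and it simply mirrors the `CountVectorizer`-based implementation that follows:

```python
import sklearn.feature_extraction.text


def vectorize_tfidf(self, **params):
    """Hypothetical tf-idf counterpart of vectorize(); illustrative only."""
    vectorizer = sklearn.feature_extraction.text.TfidfVectorizer(**params)
    self.data = vectorizer.fit_transform(self.documents)
    self.vocab = vectorizer.get_feature_names()
    assert len(self.vocab) == self.data.shape[1]
```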
46 |         vectorizer = sklearn.feature_extraction.text.CountVectorizer(**params)
47 |         self.data = vectorizer.fit_transform(self.documents)
48 |         self.vocab = vectorizer.get_feature_names()
49 |         assert len(self.vocab) == self.data.shape[1]
50 | 
51 |     def data_info(self, show_classes=False):
52 |         N, M = self.data.shape
53 |         sparsity = self.data.nnz / N / M * 100
54 |         print('N = {} documents, M = {} words, sparsity={:.4f}%'.format(N, M, sparsity))
55 |         if show_classes:
56 |             for i in range(len(self.class_names)):
57 |                 num = sum(self.labels == i)
58 |                 print(' {:5d} documents in class {:2d} ({})'.format(num, i, self.class_names[i]))
59 | 
60 |     def show_document(self, i):
61 |         label = self.labels[i]
62 |         name = self.class_names[label]
63 |         try:
64 |             text = self.documents[i]
65 |             wc = len(text.split())
66 |         except AttributeError:
67 |             text = None
68 |             wc = 'N/A'
69 |         print('document {}: label {} --> {}, {} words'.format(i, label, name, wc))
70 |         try:
71 |             vector = self.data[i,:]
72 |             for j in range(vector.shape[1]):
73 |                 if vector[0,j] != 0:
74 |                     print(' {:.2f} "{}" ({})'.format(vector[0,j], self.vocab[j], j))
75 |         except AttributeError:
76 |             pass
77 |         return text
78 | 
79 |     def keep_documents(self, idx):
80 |         """Keep the documents given by the index, discard the others."""
81 |         self.documents = [self.documents[i] for i in idx]
82 |         self.labels = self.labels[idx]
83 |         self.data = self.data[idx,:]
84 | 
85 |     def keep_words(self, idx):
86 |         """Keep the words given by the index, discard the others."""
87 |         self.data = self.data[:,idx]
88 |         self.vocab = [self.vocab[i] for i in idx]
89 |         try:
90 |             self.embeddings = self.embeddings[idx,:]
91 |         except AttributeError:
92 |             pass
93 | 
94 |     def remove_short_documents(self, nwords, vocab='selected'):
95 |         """Remove a document if it contains fewer than nwords words."""
96 |         if vocab == 'selected':
97 |             # Word count with selected vocabulary.
98 |             wc = self.data.sum(axis=1)
99 |             wc = np.squeeze(np.asarray(wc))
100 |         elif vocab == 'full':
101 |             # Word count with full vocabulary.
102 |             wc = np.empty(len(self.documents), dtype=int)
103 |             for i,doc in enumerate(self.documents):
104 |                 wc[i] = len(doc.split())
105 |         idx = np.argwhere(wc >= nwords).squeeze()
106 |         self.keep_documents(idx)
107 |         return wc
108 | 
109 |     def keep_top_words(self, M, Mprint=20):
110 |         """Keep in the vocabulary the M words that appear most often."""
111 |         freq = self.data.sum(axis=0)
112 |         freq = np.squeeze(np.asarray(freq))
113 |         idx = np.argsort(freq)[::-1]
114 |         idx = idx[:M]
115 |         self.keep_words(idx)
116 |         print('most frequent words')
117 |         for i in range(Mprint):
118 |             print(' {:3d}: {:10s} {:6d} counts'.format(i, self.vocab[i], freq[idx][i]))
119 |         return freq[idx]
120 | 
121 |     def normalize(self, norm='l1'):
122 |         """Normalize data to unit length."""
123 |         # TODO: TF-IDF.
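Taken together, the methods above form a preprocessing chain. A sketch of the intended order of calls, mirroring the 20NEWS notebook later in this repository (the path and thresholds are illustrative):

```python
from lib import utils

train = utils.Text20News(data_home='data/20news', subset='train',
                         remove=('headers', 'footers', 'quotes'))
train.clean_text(num='substitute')               # keep a-z only, map digits to NUM
train.vectorize(stop_words='english')            # bag-of-words counts
train.remove_short_documents(nwords=20, vocab='full')
train.keep_top_words(1000)                       # restrict the vocabulary
train.remove_short_documents(nwords=5, vocab='selected')
train.normalize(norm='l1')                       # each document sums to one
```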
124 | data = self.data.astype(np.float64) 125 | self.data = sklearn.preprocessing.normalize(data, axis=1, norm=norm) 126 | 127 | def embed(self, filename=None, size=100): 128 | """Embed the vocabulary using pre-trained vectors.""" 129 | if filename: 130 | model = gensim.models.Word2Vec.load_word2vec_format(filename, binary=True) 131 | size = model.vector_size 132 | else: 133 | class Sentences(object): 134 | def __init__(self, documents): 135 | self.documents = documents 136 | def __iter__(self): 137 | for document in self.documents: 138 | yield document.split() 139 | model = gensim.models.Word2Vec(Sentences(self.documents), size) 140 | self.embeddings = np.empty((len(self.vocab), size)) 141 | keep = [] 142 | not_found = 0 143 | for i,word in enumerate(self.vocab): 144 | try: 145 | self.embeddings[i,:] = model[word] 146 | keep.append(i) 147 | except KeyError: 148 | not_found += 1 149 | print('{} words not found in corpus'.format(not_found, i)) 150 | self.keep_words(keep) 151 | 152 | class Text20News(TextDataset): 153 | def __init__(self, **params): 154 | dataset = sklearn.datasets.fetch_20newsgroups(**params) 155 | self.documents = dataset.data 156 | self.labels = dataset.target 157 | self.class_names = dataset.target_names 158 | assert max(self.labels) + 1 == len(self.class_names) 159 | N, C = len(self.documents), len(self.class_names) 160 | print('N = {} documents, C = {} classes'.format(N, C)) 161 | 162 | class TextRCV1(TextDataset): 163 | def __init__(self, **params): 164 | dataset = sklearn.datasets.fetch_rcv1(**params) 165 | self.data = dataset.data 166 | self.target = dataset.target 167 | self.class_names = dataset.target_names 168 | assert len(self.class_names) == 103 # 103 categories according to LYRL2004 169 | N, C = self.target.shape 170 | assert C == len(self.class_names) 171 | print('N = {} documents, C = {} classes'.format(N, C)) 172 | 173 | def remove_classes(self, keep): 174 | ## Construct a lookup table for labels. 175 | labels_row = [] 176 | labels_col = [] 177 | class_lookup = {} 178 | for i,name in enumerate(self.class_names): 179 | class_lookup[name] = i 180 | self.class_names = keep 181 | 182 | # Index of classes to keep. 183 | idx_keep = np.empty(len(keep)) 184 | for i,cat in enumerate(keep): 185 | idx_keep[i] = class_lookup[cat] 186 | self.target = self.target[:,idx_keep] 187 | assert self.target.shape[1] == len(keep) 188 | 189 | def show_doc_per_class(self, print_=False): 190 | """Number of documents per class.""" 191 | docs_per_class = np.array(self.target.astype(np.uint64).sum(axis=0)).squeeze() 192 | print('categories ({} assignments in total)'.format(docs_per_class.sum())) 193 | if print_: 194 | for i,cat in enumerate(self.class_names): 195 | print(' {:5s}: {:6d} documents'.format(cat, docs_per_class[i])) 196 | plt.figure(figsize=(17,5)) 197 | plt.plot(sorted(docs_per_class[::-1]),'.') 198 | 199 | def show_classes_per_doc(self): 200 | """Number of classes per document.""" 201 | classes_per_doc = np.array(self.target.sum(axis=1)).squeeze() 202 | plt.figure(figsize=(17,5)) 203 | plt.plot(sorted(classes_per_doc[::-1]),'.') 204 | 205 | def select_documents(self): 206 | classes_per_doc = np.array(self.target.sum(axis=1)).squeeze() 207 | self.target = self.target[classes_per_doc==1] 208 | self.data = self.data[classes_per_doc==1, :] 209 | 210 | # Convert labels from indicator form to single value. 
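`select_documents` keeps only single-label documents so that, just below, the one-hot rows of `target` can be collapsed into an integer label vector with `tocoo().col` (the COO entries of a CSR matrix are ordered row by row, so the column indices line up with the documents). A toy check of that trick on a hypothetical 4 x 3 indicator matrix:

```python
import numpy as np
import scipy.sparse

target = scipy.sparse.csr_matrix(np.array([[1, 0, 0],
                                           [0, 1, 1],   # two labels: will be dropped
                                           [0, 0, 1],
                                           [0, 1, 0]]))
classes_per_doc = np.asarray(target.sum(axis=1)).squeeze()
target = target[classes_per_doc == 1]   # keep single-label documents only
labels = target.tocoo().col             # one nonzero per row -> its column index
print(labels)                           # [0 2 1]
```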
211 | N, C = self.target.shape 212 | target = self.target.tocoo() 213 | self.labels = target.col 214 | assert self.labels.min() == 0 215 | assert self.labels.max() == C - 1 216 | 217 | # Bruna and Dropout used 2 * 201369 = 402738 documents. Probably the difference btw v1 and v2. 218 | #return classes_per_doc 219 | 220 | ### Helpers to quantify classifier's quality. 221 | 222 | 223 | def baseline(train_data, train_labels, test_data, test_labels, omit=[]): 224 | """Train various classifiers to get a baseline.""" 225 | clf, train_accuracy, test_accuracy, train_f1, test_f1, exec_time = [], [], [], [], [], [] 226 | clf.append(sklearn.neighbors.KNeighborsClassifier(n_neighbors=10)) 227 | clf.append(sklearn.linear_model.LogisticRegression()) 228 | clf.append(sklearn.naive_bayes.BernoulliNB(alpha=.01)) 229 | clf.append(sklearn.ensemble.RandomForestClassifier()) 230 | clf.append(sklearn.naive_bayes.MultinomialNB(alpha=.01)) 231 | clf.append(sklearn.linear_model.RidgeClassifier()) 232 | clf.append(sklearn.svm.LinearSVC()) 233 | for i,c in enumerate(clf): 234 | if i not in omit: 235 | t_start = time.process_time() 236 | c.fit(train_data, train_labels) 237 | train_pred = c.predict(train_data) 238 | test_pred = c.predict(test_data) 239 | train_accuracy.append('{:5.2f}'.format(100*sklearn.metrics.accuracy_score(train_labels, train_pred))) 240 | test_accuracy.append('{:5.2f}'.format(100*sklearn.metrics.accuracy_score(test_labels, test_pred))) 241 | train_f1.append('{:5.2f}'.format(100*sklearn.metrics.f1_score(train_labels, train_pred, average='weighted'))) 242 | test_f1.append('{:5.2f}'.format(100*sklearn.metrics.f1_score(test_labels, test_pred, average='weighted'))) 243 | exec_time.append('{:5.2f}'.format(time.process_time() - t_start)) 244 | print('Train accuracy: {}'.format(' '.join(train_accuracy))) 245 | print('Test accuracy: {}'.format(' '.join(test_accuracy))) 246 | print('Train F1 (weighted): {}'.format(' '.join(train_f1))) 247 | print('Test F1 (weighted): {}'.format(' '.join(test_f1))) 248 | print('Execution time: {}'.format(' '.join(exec_time))) 249 | 250 | def grid_search(params, grid_params, train_data, train_labels, val_data, 251 | val_labels, test_data, test_labels, model): 252 | """Explore the hyper-parameter space with an exhaustive grid search.""" 253 | params = params.copy() 254 | train_accuracy, test_accuracy, train_f1, test_f1 = [], [], [], [] 255 | grid = sklearn.grid_search.ParameterGrid(grid_params) 256 | print('grid search: {} combinations to evaluate'.format(len(grid))) 257 | for grid_params in grid: 258 | params.update(grid_params) 259 | name = '{}'.format(grid) 260 | print('\n\n {} \n\n'.format(grid_params)) 261 | m = model(params) 262 | m.fit(train_data, train_labels, val_data, val_labels) 263 | string, accuracy, f1, loss = m.evaluate(train_data, train_labels) 264 | train_accuracy.append('{:5.2f}'.format(accuracy)); train_f1.append('{:5.2f}'.format(f1)) 265 | print('train {}'.format(string)) 266 | string, accuracy, f1, loss = m.evaluate(test_data, test_labels) 267 | test_accuracy.append('{:5.2f}'.format(accuracy)); test_f1.append('{:5.2f}'.format(f1)) 268 | print('test {}'.format(string)) 269 | print('\n\n') 270 | print('Train accuracy: {}'.format(' '.join(train_accuracy))) 271 | print('Test accuracy: {}'.format(' '.join(test_accuracy))) 272 | print('Train F1 (weighted): {}'.format(' '.join(train_f1))) 273 | print('Test F1 (weighted): {}'.format(' '.join(test_f1))) 274 | for i,grid_params in enumerate(grid): 275 | print('{} --> {} {} {} {}'.format(grid_params, 
train_accuracy[i], test_accuracy[i], train_f1[i], test_f1[i])) 276 | 277 | 278 | class model_perf(object): 279 | 280 | def __init__(s): 281 | s.names, s.params = set(), {} 282 | s.fit_accuracies, s.fit_losses, s.fit_time = {}, {}, {} 283 | s.train_accuracy, s.train_f1, s.train_loss = {}, {}, {} 284 | s.test_accuracy, s.test_f1, s.test_loss = {}, {}, {} 285 | 286 | def test(s, model, name, params, train_data, train_labels, val_data, val_labels, test_data, test_labels): 287 | s.params[name] = params 288 | s.fit_accuracies[name], s.fit_losses[name], s.fit_time[name] = \ 289 | model.fit(train_data, train_labels, val_data, val_labels) 290 | string, s.train_accuracy[name], s.train_f1[name], s.train_loss[name] = \ 291 | model.evaluate(train_data, train_labels) 292 | print('train {}'.format(string)) 293 | string, s.test_accuracy[name], s.test_f1[name], s.test_loss[name] = \ 294 | model.evaluate(test_data, test_labels) 295 | print('test {}'.format(string)) 296 | s.names.add(name) 297 | 298 | def show(s, fontsize=None): 299 | if fontsize: 300 | plt.rc('pdf', fonttype=42) 301 | plt.rc('ps', fonttype=42) 302 | plt.rc('font', size=fontsize) # controls default text sizes 303 | plt.rc('axes', titlesize=fontsize) # fontsize of the axes title 304 | plt.rc('axes', labelsize=fontsize) # fontsize of the x any y labels 305 | plt.rc('xtick', labelsize=fontsize) # fontsize of the tick labels 306 | plt.rc('ytick', labelsize=fontsize) # fontsize of the tick labels 307 | plt.rc('legend', fontsize=fontsize) # legend fontsize 308 | plt.rc('figure', titlesize=fontsize) # size of the figure title 309 | print(' accuracy F1 loss time [ms] name') 310 | print('test train test train test train') 311 | for name in sorted(s.names): 312 | print('{:5.2f} {:5.2f} {:5.2f} {:5.2f} {:.2e} {:.2e} {:3.0f} {}'.format( 313 | s.test_accuracy[name], s.train_accuracy[name], 314 | s.test_f1[name], s.train_f1[name], 315 | s.test_loss[name], s.train_loss[name], s.fit_time[name]*1000, name)) 316 | 317 | fig, ax = plt.subplots(1, 2, figsize=(15, 5)) 318 | for name in sorted(s.names): 319 | steps = np.arange(len(s.fit_accuracies[name])) + 1 320 | steps *= s.params[name]['eval_frequency'] 321 | ax[0].plot(steps, s.fit_accuracies[name], '.-', label=name) 322 | ax[1].plot(steps, s.fit_losses[name], '.-', label=name) 323 | ax[0].set_xlim(min(steps), max(steps)) 324 | ax[1].set_xlim(min(steps), max(steps)) 325 | ax[0].set_xlabel('step') 326 | ax[1].set_xlabel('step') 327 | ax[0].set_ylabel('validation accuracy') 328 | ax[1].set_ylabel('training loss') 329 | ax[0].legend(loc='lower right') 330 | ax[1].legend(loc='upper right') 331 | #fig.savefig('training.pdf') 332 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | NB = $(sort $(wildcard *.ipynb)) 2 | DIRS = nips2016 trials 3 | 4 | CLEANDIRS = $(DIRS:%=clean-%) 5 | 6 | run: $(NB) $(DIRS) 7 | 8 | $(NB): 9 | jupyter nbconvert --inplace --execute --ExecutePreprocessor.timeout=-1 $@ 10 | 11 | $(DIRS): 12 | $(MAKE) -C $@ 13 | 14 | clean: $(CLEANDIRS) 15 | jupyter nbconvert --inplace --ClearOutputPreprocessor.enabled=True $(NB) 16 | #rm -rf **/*.pyc 17 | 18 | $(CLEANDIRS): 19 | $(MAKE) clean -C $(@:clean-%=%) 20 | 21 | install: 22 | pip install --upgrade pip 23 | pip install -r requirements.txt 24 | 25 | readme: 26 | grip README.md 27 | 28 | .PHONY: run $(NB) $(DIRS) clean $(CLEANDIRS) install readme 29 | 
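`utils.grid_search` and `utils.model_perf` above are driven from the notebooks that follow. As a reference, a minimal sketch of how `grid_search` is meant to be invoked — it assumes `params`, the Laplacian list `L` and the data splits already exist, exactly as they are built in those notebooks, and the sweep values are purely illustrative:

```python
from lib import models, utils

grid_params = {
    'learning_rate': [0.05, 0.1, 0.2],
    'regularization': [0, 1e-4, 1e-3],
}
data = (train_data, train_labels, val_data, val_labels, test_data, test_labels)
utils.grid_search(params, grid_params, *data,
                  model=lambda p: models.cgcnn(L, **p))
```

Note that `grid_search` relies on `sklearn.grid_search.ParameterGrid`, which newer scikit-learn versions expose as `sklearn.model_selection.ParameterGrid`.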
-------------------------------------------------------------------------------- /nips2016/20news.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Parameters to reproduce the paper's results**:\n", 8 | "* change the optimizer from SGD to Adam in `lib/models.py`,\n", 9 | "* change the size of the vocabulary from 1000 to 10000 in `train.keep_top_words()` below." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "collapsed": false 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "%load_ext autoreload\n", 21 | "%autoreload 2\n", 22 | "\n", 23 | "import sys, os\n", 24 | "sys.path.insert(0, '..')\n", 25 | "from lib import models, graph, coarsening, utils\n", 26 | "\n", 27 | "import tensorflow as tf\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import scipy.sparse\n", 30 | "import numpy as np\n", 31 | "import time\n", 32 | "\n", 33 | "%matplotlib inline" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": { 40 | "collapsed": false 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "flags = tf.app.flags\n", 45 | "FLAGS = flags.FLAGS\n", 46 | "\n", 47 | "# Graphs.\n", 48 | "flags.DEFINE_integer('number_edges', 16, 'Graph: minimum number of edges per vertex.')\n", 49 | "flags.DEFINE_string('metric', 'cosine', 'Graph: similarity measure (between features).')\n", 50 | "# TODO: change cgcnn for combinatorial Laplacians.\n", 51 | "flags.DEFINE_bool('normalized_laplacian', True, 'Graph Laplacian: normalized.')\n", 52 | "flags.DEFINE_integer('coarsening_levels', 0, 'Number of coarsened graphs.')\n", 53 | "\n", 54 | "flags.DEFINE_string('dir_data', os.path.join('..', 'data', '20news'), 'Directory to store data.')\n", 55 | "flags.DEFINE_integer('val_size', 400, 'Size of the validation set.')" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# Data" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "collapsed": false 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "# Fetch dataset. Scikit-learn already performs some cleaning.\n", 74 | "remove = ('headers','footers','quotes') # (), ('headers') or ('headers','footers','quotes')\n", 75 | "train = utils.Text20News(data_home=FLAGS.dir_data, subset='train', remove=remove)\n", 76 | "\n", 77 | "# Pre-processing: transform everything to a-z and whitespace.\n", 78 | "print(train.show_document(1)[:400])\n", 79 | "train.clean_text(num='substitute')\n", 80 | "\n", 81 | "# Analyzing / tokenizing: transform documents to bags-of-words.\n", 82 | "#stop_words = set(sklearn.feature_extraction.text.ENGLISH_STOP_WORDS)\n", 83 | "# Or stop words from NLTK.\n", 84 | "# Add e.g. 
don, ve.\n", 85 | "train.vectorize(stop_words='english')\n", 86 | "print(train.show_document(1)[:400])" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": { 93 | "collapsed": false 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "# Remove short documents.\n", 98 | "train.data_info(True)\n", 99 | "wc = train.remove_short_documents(nwords=20, vocab='full')\n", 100 | "train.data_info()\n", 101 | "print('shortest: {}, longest: {} words'.format(wc.min(), wc.max()))\n", 102 | "plt.figure(figsize=(17,5))\n", 103 | "plt.semilogy(wc, '.');\n", 104 | "\n", 105 | "# Remove encoded images.\n", 106 | "def remove_encoded_images(dataset, freq=1e3):\n", 107 | " widx = train.vocab.index('ax')\n", 108 | " wc = train.data[:,widx].toarray().squeeze()\n", 109 | " idx = np.argwhere(wc < freq).squeeze()\n", 110 | " dataset.keep_documents(idx)\n", 111 | " return wc\n", 112 | "wc = remove_encoded_images(train)\n", 113 | "train.data_info()\n", 114 | "plt.figure(figsize=(17,5))\n", 115 | "plt.semilogy(wc, '.');" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "collapsed": false 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "# Word embedding\n", 127 | "if True:\n", 128 | " train.embed()\n", 129 | "else:\n", 130 | " train.embed(os.path.join('..', 'data', 'word2vec', 'GoogleNews-vectors-negative300.bin'))\n", 131 | "train.data_info()\n", 132 | "# Further feature selection. (TODO)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "collapsed": false 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "# Feature selection.\n", 144 | "# Other options include: mutual information or document count.\n", 145 | "freq = train.keep_top_words(1000, 20)\n", 146 | "train.data_info()\n", 147 | "train.show_document(1)\n", 148 | "plt.figure(figsize=(17,5))\n", 149 | "plt.semilogy(freq);\n", 150 | "\n", 151 | "# Remove documents whose signal would be the zero vector.\n", 152 | "wc = train.remove_short_documents(nwords=5, vocab='selected')\n", 153 | "train.data_info(True)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": { 160 | "collapsed": false 161 | }, 162 | "outputs": [], 163 | "source": [ 164 | "train.normalize(norm='l1')\n", 165 | "train.show_document(1);" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": { 172 | "collapsed": false 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "# Test dataset.\n", 177 | "test = utils.Text20News(data_home=FLAGS.dir_data, subset='test', remove=remove)\n", 178 | "test.clean_text(num='substitute')\n", 179 | "test.vectorize(vocabulary=train.vocab)\n", 180 | "test.data_info()\n", 181 | "wc = test.remove_short_documents(nwords=5, vocab='selected')\n", 182 | "print('shortest: {}, longest: {} words'.format(wc.min(), wc.max()))\n", 183 | "test.data_info(True)\n", 184 | "test.normalize(norm='l1')" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": { 191 | "collapsed": false 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "if True:\n", 196 | " train_data = train.data.astype(np.float32)\n", 197 | " test_data = test.data.astype(np.float32)\n", 198 | " train_labels = train.labels\n", 199 | " test_labels = test.labels\n", 200 | "else:\n", 201 | " perm = np.random.RandomState(seed=42).permutation(dataset.data.shape[0])\n", 202 | " Ntest = 6695\n", 203 | " perm_test = 
perm[:Ntest]\n", 204 | " perm_train = perm[Ntest:]\n", 205 | " train_data = train.data[perm_train,:].astype(np.float32)\n", 206 | " test_data = train.data[perm_test,:].astype(np.float32)\n", 207 | " train_labels = train.labels[perm_train]\n", 208 | " test_labels = train.labels[perm_test]\n", 209 | "\n", 210 | "if True:\n", 211 | " graph_data = train.embeddings.astype(np.float32)\n", 212 | "else:\n", 213 | " graph_data = train.data.T.astype(np.float32).toarray()\n", 214 | "\n", 215 | "#del train, test" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "# Feature graph" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": { 229 | "collapsed": false 230 | }, 231 | "outputs": [], 232 | "source": [ 233 | "t_start = time.process_time()\n", 234 | "dist, idx = graph.distance_sklearn_metrics(graph_data, k=FLAGS.number_edges, metric=FLAGS.metric)\n", 235 | "A = graph.adjacency(dist, idx)\n", 236 | "print(\"{} > {} edges\".format(A.nnz//2, FLAGS.number_edges*graph_data.shape[0]//2))\n", 237 | "A = graph.replace_random_edges(A, 0)\n", 238 | "graphs, perm = coarsening.coarsen(A, levels=FLAGS.coarsening_levels, self_connections=False)\n", 239 | "L = [graph.laplacian(A, normalized=True) for A in graphs]\n", 240 | "print('Execution time: {:.2f}s'.format(time.process_time() - t_start))\n", 241 | "#graph.plot_spectrum(L)\n", 242 | "#del graph_data, A, dist, idx" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": { 249 | "collapsed": false 250 | }, 251 | "outputs": [], 252 | "source": [ 253 | "t_start = time.process_time()\n", 254 | "train_data = scipy.sparse.csr_matrix(coarsening.perm_data(train_data.toarray(), perm))\n", 255 | "test_data = scipy.sparse.csr_matrix(coarsening.perm_data(test_data.toarray(), perm))\n", 256 | "print('Execution time: {:.2f}s'.format(time.process_time() - t_start))\n", 257 | "del perm" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "# Classification" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": { 271 | "collapsed": false 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "# Training set is shuffled already.\n", 276 | "#perm = np.random.permutation(train_data.shape[0])\n", 277 | "#train_data = train_data[perm,:]\n", 278 | "#train_labels = train_labels[perm]\n", 279 | "\n", 280 | "# Validation set.\n", 281 | "if False:\n", 282 | " val_data = train_data[:FLAGS.val_size,:]\n", 283 | " val_labels = train_labels[:FLAGS.val_size]\n", 284 | " train_data = train_data[FLAGS.val_size:,:]\n", 285 | " train_labels = train_labels[FLAGS.val_size:]\n", 286 | "else:\n", 287 | " val_data = test_data\n", 288 | " val_labels = test_labels" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": { 295 | "collapsed": false 296 | }, 297 | "outputs": [], 298 | "source": [ 299 | "if True:\n", 300 | " utils.baseline(train_data, train_labels, test_data, test_labels)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "metadata": { 307 | "collapsed": false 308 | }, 309 | "outputs": [], 310 | "source": [ 311 | "common = {}\n", 312 | "common['dir_name'] = '20news/'\n", 313 | "common['num_epochs'] = 80\n", 314 | "common['batch_size'] = 100\n", 315 | "common['decay_steps'] = len(train_labels) / common['batch_size']\n", 316 | "common['eval_frequency'] = 5 * 
common['num_epochs']\n", 317 | "common['filter'] = 'chebyshev5'\n", 318 | "common['brelu'] = 'b1relu'\n", 319 | "common['pool'] = 'mpool1'\n", 320 | "C = max(train_labels) + 1 # number of classes\n", 321 | "\n", 322 | "model_perf = utils.model_perf()" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": { 329 | "collapsed": false 330 | }, 331 | "outputs": [], 332 | "source": [ 333 | "if True:\n", 334 | " name = 'softmax'\n", 335 | " params = common.copy()\n", 336 | " params['dir_name'] += name\n", 337 | " params['regularization'] = 0\n", 338 | " params['dropout'] = 1\n", 339 | " params['learning_rate'] = 1e3\n", 340 | " params['decay_rate'] = 0.95\n", 341 | " params['momentum'] = 0.9\n", 342 | " params['F'] = []\n", 343 | " params['K'] = []\n", 344 | " params['p'] = []\n", 345 | " params['M'] = [C]\n", 346 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 347 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "metadata": { 354 | "collapsed": false 355 | }, 356 | "outputs": [], 357 | "source": [ 358 | "if True:\n", 359 | " name = 'fc_softmax'\n", 360 | " params = common.copy()\n", 361 | " params['dir_name'] += name\n", 362 | " params['regularization'] = 0\n", 363 | " params['dropout'] = 1\n", 364 | " params['learning_rate'] = 0.1\n", 365 | " params['decay_rate'] = 0.95\n", 366 | " params['momentum'] = 0.9\n", 367 | " params['F'] = []\n", 368 | " params['K'] = []\n", 369 | " params['p'] = []\n", 370 | " params['M'] = [2500, C]\n", 371 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 372 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": { 379 | "collapsed": false 380 | }, 381 | "outputs": [], 382 | "source": [ 383 | "if True:\n", 384 | " name = 'fc_fc_softmax'\n", 385 | " params = common.copy()\n", 386 | " params['dir_name'] += name\n", 387 | " params['regularization'] = 0\n", 388 | " params['dropout'] = 1\n", 389 | " params['learning_rate'] = 0.1\n", 390 | " params['decay_rate'] = 0.95\n", 391 | " params['momentum'] = 0.9\n", 392 | " params['F'] = []\n", 393 | " params['K'] = []\n", 394 | " params['p'] = []\n", 395 | " params['M'] = [2500, 500, C]\n", 396 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 397 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "metadata": { 404 | "collapsed": false 405 | }, 406 | "outputs": [], 407 | "source": [ 408 | "if True:\n", 409 | " name = 'fgconv_softmax'\n", 410 | " params = common.copy()\n", 411 | " params['dir_name'] += name\n", 412 | " params['filter'] = 'fourier'\n", 413 | " params['regularization'] = 0\n", 414 | " params['dropout'] = 1\n", 415 | " params['learning_rate'] = 0.001\n", 416 | " params['decay_rate'] = 1\n", 417 | " params['momentum'] = 0\n", 418 | " params['F'] = [32]\n", 419 | " params['K'] = [L[0].shape[0]]\n", 420 | " params['p'] = [1]\n", 421 | " params['M'] = [C]\n", 422 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 423 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": { 430 | "collapsed": false 431 | }, 432 | "outputs": [], 433 | "source": [ 
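For readers skimming these configuration cells: `F`, `K` and `p` are per-graph-convolutional-layer lists (feature maps, Chebyshev/filter order, pooling size) and `M` lists the fully connected layers, the last entry being the number of classes — see `_inference` in `lib/models.py` above. The single-layer `cgconv_softmax` model defined a few cells below amounts to the following settings (`common` is the shared dictionary built earlier in this notebook):

```python
C = 20                   # number of classes in 20NEWS
params = dict(common)    # start from the shared settings
params['F'] = [32]       # 32 feature maps in one graph-conv layer
params['K'] = [5]        # Chebyshev polynomial order (K-localized filters)
params['p'] = [1]        # no graph pooling after that layer
params['M'] = [C]        # a single logits/softmax layer on top
```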
434 | "if True:\n", 435 | " name = 'sgconv_softmax'\n", 436 | " params = common.copy()\n", 437 | " params['dir_name'] += name\n", 438 | " params['filter'] = 'spline'\n", 439 | " params['regularization'] = 1e-3\n", 440 | " params['dropout'] = 1\n", 441 | " params['learning_rate'] = 0.1\n", 442 | " params['decay_rate'] = 0.999\n", 443 | " params['momentum'] = 0\n", 444 | " params['F'] = [32]\n", 445 | " params['K'] = [5]\n", 446 | " params['p'] = [1]\n", 447 | " params['M'] = [C]\n", 448 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 449 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "metadata": { 456 | "collapsed": false 457 | }, 458 | "outputs": [], 459 | "source": [ 460 | "if True:\n", 461 | " name = 'cgconv_softmax'\n", 462 | " params = common.copy()\n", 463 | " params['dir_name'] += name\n", 464 | " params['regularization'] = 1e-3\n", 465 | " params['dropout'] = 1\n", 466 | " params['learning_rate'] = 0.1\n", 467 | " params['decay_rate'] = 0.999\n", 468 | " params['momentum'] = 0\n", 469 | " params['F'] = [32]\n", 470 | " params['K'] = [5]\n", 471 | " params['p'] = [1]\n", 472 | " params['M'] = [C]\n", 473 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 474 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": { 481 | "collapsed": false 482 | }, 483 | "outputs": [], 484 | "source": [ 485 | "if True:\n", 486 | " name = 'cgconv_fc_softmax'\n", 487 | " params = common.copy()\n", 488 | " params['dir_name'] += name\n", 489 | " params['regularization'] = 0\n", 490 | " params['dropout'] = 1\n", 491 | " params['learning_rate'] = 0.1\n", 492 | " params['decay_rate'] = 0.999\n", 493 | " params['momentum'] = 0\n", 494 | " params['F'] = [5]\n", 495 | " params['K'] = [15]\n", 496 | " params['p'] = [1]\n", 497 | " params['M'] = [100, C]\n", 498 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 499 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "metadata": { 506 | "collapsed": false 507 | }, 508 | "outputs": [], 509 | "source": [ 510 | "model_perf.show()" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "metadata": { 517 | "collapsed": true 518 | }, 519 | "outputs": [], 520 | "source": [ 521 | "if False:\n", 522 | " grid_params = {}\n", 523 | " data = (train_data, train_labels, val_data, val_labels, test_data, test_labels)\n", 524 | " utils.grid_search(params, grid_params, *data, model=lambda x: models.cgcnn(L,**x))" 525 | ] 526 | } 527 | ], 528 | "metadata": { 529 | "kernelspec": { 530 | "display_name": "Python 3", 531 | "language": "python", 532 | "name": "python3" 533 | }, 534 | "language_info": { 535 | "codemirror_mode": { 536 | "name": "ipython", 537 | "version": 3 538 | }, 539 | "file_extension": ".py", 540 | "mimetype": "text/x-python", 541 | "name": "python", 542 | "nbconvert_exporter": "python", 543 | "pygments_lexer": "ipython3", 544 | "version": "3.4.3" 545 | } 546 | }, 547 | "nbformat": 4, 548 | "nbformat_minor": 0 549 | } 550 | -------------------------------------------------------------------------------- /nips2016/makefile: -------------------------------------------------------------------------------- 1 | NB = $(sort $(wildcard *.ipynb)) 2 
| 3 | run: $(NB) 4 | 5 | $(NB): 6 | jupyter nbconvert --inplace --execute --ExecutePreprocessor.timeout=-1 $@ 7 | 8 | clean: 9 | jupyter nbconvert --inplace --ClearOutputPreprocessor.enabled=True $(NB) 10 | 11 | .PHONY: run $(NB) clean 12 | -------------------------------------------------------------------------------- /nips2016/mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%load_ext autoreload\n", 12 | "%autoreload 2\n", 13 | "\n", 14 | "import sys, os\n", 15 | "sys.path.insert(0, '..')\n", 16 | "from lib import models, graph, coarsening, utils\n", 17 | "\n", 18 | "import tensorflow as tf\n", 19 | "import numpy as np\n", 20 | "import time\n", 21 | "\n", 22 | "%matplotlib inline" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "collapsed": false 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "flags = tf.app.flags\n", 34 | "FLAGS = flags.FLAGS\n", 35 | "\n", 36 | "# Graphs.\n", 37 | "flags.DEFINE_integer('number_edges', 8, 'Graph: minimum number of edges per vertex.')\n", 38 | "flags.DEFINE_string('metric', 'euclidean', 'Graph: similarity measure (between features).')\n", 39 | "# TODO: change cgcnn for combinatorial Laplacians.\n", 40 | "flags.DEFINE_bool('normalized_laplacian', True, 'Graph Laplacian: normalized.')\n", 41 | "flags.DEFINE_integer('coarsening_levels', 4, 'Number of coarsened graphs.')\n", 42 | "\n", 43 | "# Directories.\n", 44 | "flags.DEFINE_string('dir_data', os.path.join('..', 'data', 'mnist'), 'Directory to store data.')" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "# Feature graph" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": { 58 | "collapsed": false 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "def grid_graph(m, corners=False):\n", 63 | " z = graph.grid(m)\n", 64 | " dist, idx = graph.distance_sklearn_metrics(z, k=FLAGS.number_edges, metric=FLAGS.metric)\n", 65 | " A = graph.adjacency(dist, idx)\n", 66 | "\n", 67 | " # Connections are only vertical or horizontal on the grid.\n", 68 | " # Corner vertices are connected to 2 neightbors only.\n", 69 | " if corners:\n", 70 | " import scipy.sparse\n", 71 | " A = A.toarray()\n", 72 | " A[A < A.max()/1.5] = 0\n", 73 | " A = scipy.sparse.csr_matrix(A)\n", 74 | " print('{} edges'.format(A.nnz))\n", 75 | "\n", 76 | " print(\"{} > {} edges\".format(A.nnz//2, FLAGS.number_edges*m**2//2))\n", 77 | " return A\n", 78 | "\n", 79 | "t_start = time.process_time()\n", 80 | "A = grid_graph(28, corners=False)\n", 81 | "A = graph.replace_random_edges(A, 0)\n", 82 | "graphs, perm = coarsening.coarsen(A, levels=FLAGS.coarsening_levels, self_connections=False)\n", 83 | "L = [graph.laplacian(A, normalized=True) for A in graphs]\n", 84 | "print('Execution time: {:.2f}s'.format(time.process_time() - t_start))\n", 85 | "graph.plot_spectrum(L)\n", 86 | "del A" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "# Data" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": { 100 | "collapsed": false 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "from tensorflow.examples.tutorials.mnist import input_data\n", 105 | "mnist = input_data.read_data_sets(FLAGS.dir_data, one_hot=False)\n", 
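One detail worth noting while reading this notebook: with `coarsening_levels = 4`, the list `L` is expected to hold one Laplacian per coarsening level, and each pooling of size 4 moves two levels down the hierarchy — which is why the two-layer models later index `L[0]` and `L[2]` when they need the graph sizes. This correspondence is inferred from how the notebook indexes `L` below, not from code shown in this excerpt; a small sketch of the bookkeeping:

```python
import numpy as np

p = [4, 4]                       # pooling sizes of the two conv layers (LeNet5-like nets below)
level = 0
for i, p_i in enumerate(p):
    print('conv{} filters on L[{}]'.format(i + 1, level))
    level += int(np.log2(p_i))   # pooling by 4 = two binary coarsening levels
# conv1 filters on L[0]
# conv2 filters on L[2]
```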
106 | "\n", 107 | "train_data = mnist.train.images.astype(np.float32)\n", 108 | "val_data = mnist.validation.images.astype(np.float32)\n", 109 | "test_data = mnist.test.images.astype(np.float32)\n", 110 | "train_labels = mnist.train.labels\n", 111 | "val_labels = mnist.validation.labels\n", 112 | "test_labels = mnist.test.labels\n", 113 | "\n", 114 | "t_start = time.process_time()\n", 115 | "train_data = coarsening.perm_data(train_data, perm)\n", 116 | "val_data = coarsening.perm_data(val_data, perm)\n", 117 | "test_data = coarsening.perm_data(test_data, perm)\n", 118 | "print('Execution time: {:.2f}s'.format(time.process_time() - t_start))\n", 119 | "del perm" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "# Neural networks" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "collapsed": true 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "#model = fc1()\n", 138 | "#model = fc2(nhiddens=100)\n", 139 | "#model = cnn2(K=5, F=10) # K=28 is equivalent to filtering with fgcnn.\n", 140 | "#model = fcnn2(F=10)\n", 141 | "#model = fgcnn2(L[0], F=10)\n", 142 | "#model = lgcnn2_2(L[0], F=10, K=10)\n", 143 | "#model = cgcnn2_3(L[0], F=10, K=5)\n", 144 | "#model = cgcnn2_4(L[0], F=10, K=5)\n", 145 | "#model = cgcnn2_5(L[0], F=10, K=5)\n", 146 | "\n", 147 | "if False:\n", 148 | " K = 5 # 5 or 5^2\n", 149 | " t_start = time.process_time()\n", 150 | " mnist.test._images = graph.lanczos(L, mnist.test._images.T, K).T\n", 151 | " mnist.train._images = graph.lanczos(L, mnist.train._images.T, K).T\n", 152 | " model = lgcnn2_1(L, F=10, K=K)\n", 153 | " print('Execution time: {:.2f}s'.format(time.process_time() - t_start))\n", 154 | " ph_data = tf.placeholder(tf.float32, (FLAGS.batch_size, mnist.train.images.shape[1], K), 'data')" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "collapsed": false 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "common = {}\n", 166 | "common['dir_name'] = 'mnist/'\n", 167 | "common['num_epochs'] = 20\n", 168 | "common['batch_size'] = 100\n", 169 | "common['decay_steps'] = mnist.train.num_examples / common['batch_size']\n", 170 | "common['eval_frequency'] = 30 * common['num_epochs']\n", 171 | "common['brelu'] = 'b1relu'\n", 172 | "common['pool'] = 'mpool1'\n", 173 | "C = max(mnist.train.labels) + 1 # number of classes\n", 174 | "\n", 175 | "model_perf = utils.model_perf()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": { 182 | "collapsed": false 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "if True:\n", 187 | " name = 'softmax'\n", 188 | " params = common.copy()\n", 189 | " params['dir_name'] += name\n", 190 | " params['regularization'] = 5e-4\n", 191 | " params['dropout'] = 1\n", 192 | " params['learning_rate'] = 0.02\n", 193 | " params['decay_rate'] = 0.95\n", 194 | " params['momentum'] = 0.9\n", 195 | " params['F'] = []\n", 196 | " params['K'] = []\n", 197 | " params['p'] = []\n", 198 | " params['M'] = [C]\n", 199 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 200 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": { 207 | "collapsed": true 208 | }, 209 | "outputs": [], 210 | "source": [ 211 | "# Common hyper-parameters for networks with one convolutional layer.\n", 212 | 
"common['regularization'] = 0\n", 213 | "common['dropout'] = 1\n", 214 | "common['learning_rate'] = 0.02\n", 215 | "common['decay_rate'] = 0.95\n", 216 | "common['momentum'] = 0.9\n", 217 | "common['F'] = [10]\n", 218 | "common['K'] = [20]\n", 219 | "common['p'] = [1]\n", 220 | "common['M'] = [C]" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": { 227 | "collapsed": false 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "if True:\n", 232 | " name = 'fgconv_softmax'\n", 233 | " params = common.copy()\n", 234 | " params['dir_name'] += name\n", 235 | " params['filter'] = 'fourier'\n", 236 | " params['K'] = [L[0].shape[0]]\n", 237 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 238 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": { 245 | "collapsed": false 246 | }, 247 | "outputs": [], 248 | "source": [ 249 | "if True:\n", 250 | " name = 'sgconv_softmax'\n", 251 | " params = common.copy()\n", 252 | " params['dir_name'] += name\n", 253 | " params['filter'] = 'spline'\n", 254 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 255 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": { 262 | "collapsed": false 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "# With 'chebyshev2' and 'b2relu', it corresponds to cgcnn2_2(L[0], F=10, K=20).\n", 267 | "if True:\n", 268 | " name = 'cgconv_softmax'\n", 269 | " params = common.copy()\n", 270 | " params['dir_name'] += name\n", 271 | " params['filter'] = 'chebyshev5'\n", 272 | "# params['filter'] = 'chebyshev2'\n", 273 | "# params['brelu'] = 'b2relu'\n", 274 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 275 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": { 282 | "collapsed": true 283 | }, 284 | "outputs": [], 285 | "source": [ 286 | "# Common hyper-parameters for LeNet5-like networks.\n", 287 | "common['regularization'] = 5e-4\n", 288 | "common['dropout'] = 0.5\n", 289 | "common['learning_rate'] = 0.02 # 0.03 in the paper but sgconv_sgconv_fc_softmax has difficulty to converge\n", 290 | "common['decay_rate'] = 0.95\n", 291 | "common['momentum'] = 0.9\n", 292 | "common['F'] = [32, 64]\n", 293 | "common['K'] = [25, 25]\n", 294 | "common['p'] = [4, 4]\n", 295 | "common['M'] = [512, C]" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": { 302 | "collapsed": false 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "# Architecture of TF MNIST conv model (LeNet-5-like).\n", 307 | "# Changes: regularization, dropout, decaying learning rate, momentum optimizer, stopping condition, size of biases.\n", 308 | "# Differences: training data randomization, init conv1 biases at 0.\n", 309 | "if True:\n", 310 | " name = 'fgconv_fgconv_fc_softmax' # 'Non-Param'\n", 311 | " params = common.copy()\n", 312 | " params['dir_name'] += name\n", 313 | " params['filter'] = 'fourier'\n", 314 | " params['K'] = [L[0].shape[0], L[2].shape[0]]\n", 315 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 316 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 
| "execution_count": null, 322 | "metadata": { 323 | "collapsed": false 324 | }, 325 | "outputs": [], 326 | "source": [ 327 | "if True:\n", 328 | " name = 'sgconv_sgconv_fc_softmax' # 'Spline'\n", 329 | " params = common.copy()\n", 330 | " params['dir_name'] += name\n", 331 | " params['filter'] = 'spline'\n", 332 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 333 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": { 340 | "collapsed": false 341 | }, 342 | "outputs": [], 343 | "source": [ 344 | "if True:\n", 345 | " name = 'cgconv_cgconv_fc_softmax' # 'Chebyshev'\n", 346 | " params = common.copy()\n", 347 | " params['dir_name'] += name\n", 348 | " params['filter'] = 'chebyshev5'\n", 349 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 350 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "metadata": { 357 | "collapsed": false 358 | }, 359 | "outputs": [], 360 | "source": [ 361 | "model_perf.show()" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": { 368 | "collapsed": true 369 | }, 370 | "outputs": [], 371 | "source": [ 372 | "if False:\n", 373 | " grid_params = {}\n", 374 | " data = (train_data, train_labels, val_data, val_labels, test_data, test_labels)\n", 375 | " utils.grid_search(params, grid_params, *data, model=lambda x: models.cgcnn(L,**x))" 376 | ] 377 | } 378 | ], 379 | "metadata": { 380 | "kernelspec": { 381 | "display_name": "Python 3", 382 | "language": "python", 383 | "name": "python3" 384 | }, 385 | "language_info": { 386 | "codemirror_mode": { 387 | "name": "ipython", 388 | "version": 3 389 | }, 390 | "file_extension": ".py", 391 | "mimetype": "text/x-python", 392 | "name": "python", 393 | "nbconvert_exporter": "python", 394 | "pygments_lexer": "ipython3", 395 | "version": "3.4.3" 396 | } 397 | }, 398 | "nbformat": 4, 399 | "nbformat_minor": 0 400 | } 401 | -------------------------------------------------------------------------------- /rcv1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "%load_ext autoreload\n", 12 | "%autoreload 2\n", 13 | "\n", 14 | "from lib import models, graph, coarsening, utils\n", 15 | "\n", 16 | "import tensorflow as tf\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import scipy.sparse\n", 19 | "import numpy as np\n", 20 | "import time, shutil\n", 21 | "\n", 22 | "%matplotlib inline" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "collapsed": false 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "flags = tf.app.flags\n", 34 | "FLAGS = flags.FLAGS\n", 35 | "\n", 36 | "# Graphs.\n", 37 | "flags.DEFINE_integer('number_edges', 16, 'Graph: minimum number of edges per vertex.')\n", 38 | "flags.DEFINE_string('metric', 'cosine', 'Graph: similarity measure (between features).')\n", 39 | "# TODO: change cgcnn for combinatorial Laplacians.\n", 40 | "flags.DEFINE_bool('normalized_laplacian', True, 'Graph Laplacian: normalized.')\n", 41 | "flags.DEFINE_integer('coarsening_levels', 0, 'Number of coarsened graphs.')\n", 42 | "\n", 43 | "flags.DEFINE_string('dir_data', 
os.path.join('data', 'rcv1'), 'Directory to store data.')\n", 44 | "flags.DEFINE_integer('val_size', 400, 'Size of the validation set.')" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "# Data\n", 52 | "\n", 53 | "**From Dropout (Bruna did the same).**\n", 54 | "We took the dataset and split it into 63 classes based on the the 63 categories at the second-level of the category tree. We removed 11 categories that did not have any data and one category that had only 4 training examples. We also removed one category that covered a huge chunk (25%) of the examples. This left us with 50 classes and 402,738 documents. We divided the documents into equal-sized training and test sets randomly. Each document was represented\n", 55 | "using the 2000 most frequent non-stopwords in the dataset." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "collapsed": false 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "# Fetch dataset from Scikit-learn.\n", 67 | "dataset = utils.TextRCV1(data_home=FLAGS.dir_data)\n", 68 | "\n", 69 | "# Pre-processing: transform everything to a-z and whitespace.\n", 70 | "#print(train.show_document(1)[:400])\n", 71 | "#train.clean_text(num='substitute')\n", 72 | "\n", 73 | "# Analyzing / tokenizing: transform documents to bags-of-words.\n", 74 | "#stop_words = set(sklearn.feature_extraction.text.ENGLISH_STOP_WORDS)\n", 75 | "# Or stop words from NLTK.\n", 76 | "# Add e.g. don, ve.\n", 77 | "#train.vectorize(stop_words='english')\n", 78 | "#print(train.show_document(1)[:400])" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "collapsed": false 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "# Selection of classes.\n", 90 | "keep = ['C11','C12','C13','C14','C15','C16','C17','C18','C21','C22','C23','C24',\n", 91 | " 'C31','C32','C33','C34','C41','C42','E11','E12','E13','E14','E21','E31',\n", 92 | " 'E41','E51','E61','E71','G15','GCRIM','GDEF','GDIP','GDIS','GENT','GENV',\n", 93 | " 'GFAS','GHEA','GJOB','GMIL','GOBIT','GODD','GPOL','GPRO','GREL','GSCI',\n", 94 | " 'GSPO','GTOUR','GVIO','GVOTE','GWEA','GWELF','M11','M12','M13','M14']\n", 95 | "assert len(keep) == 55 # There is 55 second-level categories according to LYRL2004.\n", 96 | "keep.remove('C15') # 151785 documents\n", 97 | "keep.remove('GMIL') # 5 documents only\n", 98 | "\n", 99 | "dataset.show_doc_per_class()\n", 100 | "dataset.show_classes_per_doc()\n", 101 | "dataset.remove_classes(keep)\n", 102 | "dataset.show_doc_per_class(True)\n", 103 | "dataset.show_classes_per_doc()" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": { 110 | "collapsed": false 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "# Remove documents with multiple classes.\n", 115 | "dataset.select_documents()\n", 116 | "dataset.data_info()" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": { 123 | "collapsed": false 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "# Remove short documents.\n", 128 | "#train.data_info(True)\n", 129 | "#wc = train.remove_short_documents(nwords=20, vocab='full')\n", 130 | "#train.data_info()\n", 131 | "#print('shortest: {}, longest: {} words'.format(wc.min(), wc.max()))\n", 132 | "#plt.figure(figsize=(17,5))\n", 133 | "#plt.semilogy(wc, '.');" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | 
"collapsed": false 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "# Feature selection.\n", 145 | "# Other options include: mutual information or document count.\n", 146 | "#freq = train.keep_top_words(1000, 20)\n", 147 | "#train.data_info()\n", 148 | "#train.show_document(1)\n", 149 | "#plt.figure(figsize=(17,5))\n", 150 | "#plt.semilogy(freq);\n", 151 | "\n", 152 | "# Remove documents whose signal would be the zero vector.\n", 153 | "#wc = train.remove_short_documents(nwords=5, vocab='selected')\n", 154 | "#train.data_info(True)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "collapsed": false 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "#dataset.normalize(norm='l1')\n", 166 | "dataset.show_document(1);" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": { 173 | "collapsed": false 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "# Word embedding\n", 178 | "#if True:\n", 179 | "# train.embed()\n", 180 | "#else:\n", 181 | "# train.embed('data_word2vec/GoogleNews-vectors-negative300.bin')\n", 182 | "#train.data_info()\n", 183 | "# Further feature selection. (TODO)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": { 190 | "collapsed": false 191 | }, 192 | "outputs": [], 193 | "source": [ 194 | "perm = np.random.RandomState(seed=42).permutation(dataset.data.shape[0])\n", 195 | "Ntest = dataset.data.shape[0] // 2\n", 196 | "perm_test = perm[:Ntest]\n", 197 | "perm_train = perm[Ntest:]\n", 198 | "train_data = dataset.data[perm_train,:].astype(np.float32)\n", 199 | "test_data = dataset.data[perm_test,:].astype(np.float32)\n", 200 | "train_labels = dataset.labels[perm_train]\n", 201 | "test_labels = dataset.labels[perm_test]\n", 202 | "\n", 203 | "if False:\n", 204 | " graph_data = train.embeddings.astype(np.float32)\n", 205 | "else:\n", 206 | " graph_data = dataset.data.T.astype(np.float32)\n", 207 | "\n", 208 | "#del dataset" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "# Feature graph" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": { 222 | "collapsed": false 223 | }, 224 | "outputs": [], 225 | "source": [ 226 | "t_start = time.process_time()\n", 227 | "dist, idx = graph.distance_lshforest(graph_data.astype(np.float64), k=FLAGS.number_edges, metric=FLAGS.metric)\n", 228 | "A = graph.adjacency(dist.astype(np.float32), idx)\n", 229 | "print(\"{} > {} edges\".format(A.nnz//2, FLAGS.number_edges*graph_data.shape[0]//2))\n", 230 | "A = graph.replace_random_edges(A, 0)\n", 231 | "graphs, perm = coarsening.coarsen(A, levels=FLAGS.coarsening_levels, self_connections=False)\n", 232 | "L = [graph.laplacian(A, normalized=True) for A in graphs]\n", 233 | "print('Execution time: {:.2f}s'.format(time.process_time() - t_start))\n", 234 | "#graph.plot_spectrum(L)\n", 235 | "#del graph_data, A, dist, idx" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": { 242 | "collapsed": false 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "assert FLAGS.coarsening_levels is 0\n", 247 | "#t_start = time.process_time()\n", 248 | "#train_data = scipy.sparse.csr_matrix(coarsening.perm_data(train_data.toarray(), perm))\n", 249 | "#test_data = scipy.sparse.csr_matrix(coarsening.perm_data(test_data.toarray(), perm))\n", 250 | "#print('Execution time: 
{:.2f}s'.format(time.process_time() - t_start))\n", 251 | "#del perm" 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "metadata": {}, 257 | "source": [ 258 | "# Classification" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "metadata": { 265 | "collapsed": false 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "# Training set is shuffled already.\n", 270 | "#perm = np.random.permutation(train_data.shape[0])\n", 271 | "#train_data = train_data[perm,:]\n", 272 | "#train_labels = train_labels[perm]\n", 273 | "\n", 274 | "# Validation set.\n", 275 | "if False:\n", 276 | " val_data = train_data[:FLAGS.val_size,:]\n", 277 | " val_labels = train_labels[:FLAGS.val_size]\n", 278 | " train_data = train_data[FLAGS.val_size:,:]\n", 279 | " train_labels = train_labels[FLAGS.val_size:]\n", 280 | "else:\n", 281 | " val_data = test_data\n", 282 | " val_labels = test_labels" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": { 289 | "collapsed": false 290 | }, 291 | "outputs": [], 292 | "source": [ 293 | "if False:\n", 294 | " utils.baseline(train_data, train_labels, test_data, test_labels)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": { 301 | "collapsed": false 302 | }, 303 | "outputs": [], 304 | "source": [ 305 | "common = {}\n", 306 | "common['dir_name'] = 'rcv1/'\n", 307 | "common['num_epochs'] = 4\n", 308 | "common['batch_size'] = 100\n", 309 | "common['decay_steps'] = len(train_labels) / common['batch_size']\n", 310 | "common['eval_frequency'] = 200\n", 311 | "common['filter'] = 'chebyshev5'\n", 312 | "common['brelu'] = 'b1relu'\n", 313 | "common['pool'] = 'mpool1'\n", 314 | "C = max(train_labels) + 1 # number of classes\n", 315 | "\n", 316 | "model_perf = utils.model_perf()" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": { 323 | "collapsed": false 324 | }, 325 | "outputs": [], 326 | "source": [ 327 | "if True:\n", 328 | " name = 'softmax'\n", 329 | " params = common.copy()\n", 330 | " params['dir_name'] += name\n", 331 | " params['regularization'] = 0\n", 332 | " params['dropout'] = 1\n", 333 | " params['learning_rate'] = 1e3\n", 334 | " params['decay_rate'] = 0.95\n", 335 | " params['momentum'] = 0.9\n", 336 | " params['F'] = []\n", 337 | " params['K'] = []\n", 338 | " params['p'] = []\n", 339 | " params['M'] = [C]\n", 340 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 341 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": { 348 | "collapsed": false 349 | }, 350 | "outputs": [], 351 | "source": [ 352 | "if True:\n", 353 | " name = 'fc_softmax'\n", 354 | " params = common.copy()\n", 355 | " params['dir_name'] += name\n", 356 | " params['regularization'] = 0\n", 357 | " params['dropout'] = 1\n", 358 | " params['learning_rate'] = 0.1\n", 359 | " params['decay_rate'] = 0.95\n", 360 | " params['momentum'] = 0.9\n", 361 | " params['F'] = []\n", 362 | " params['K'] = []\n", 363 | " params['p'] = []\n", 364 | " params['M'] = [2500, C]\n", 365 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 366 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": null, 372 | "metadata": { 373 | "collapsed": false 374 | }, 375 | "outputs": [], 376 | 
"source": [ 377 | "if True:\n", 378 | " name = 'fc_fc_softmax'\n", 379 | " params = common.copy()\n", 380 | " params['dir_name'] += name\n", 381 | " params['regularization'] = 0\n", 382 | " params['dropout'] = 1\n", 383 | " params['learning_rate'] = 0.1\n", 384 | " params['decay_rate'] = 0.95\n", 385 | " params['momentum'] = 0.9\n", 386 | " params['F'] = []\n", 387 | " params['K'] = []\n", 388 | " params['p'] = []\n", 389 | " params['M'] = [2500, 500, C]\n", 390 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 391 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": { 398 | "collapsed": false 399 | }, 400 | "outputs": [], 401 | "source": [ 402 | "if True:\n", 403 | " name = 'cgconv_softmax'\n", 404 | " params = common.copy()\n", 405 | " params['dir_name'] += name\n", 406 | " params['regularization'] = 1e-3\n", 407 | " params['dropout'] = 1\n", 408 | " params['learning_rate'] = 0.1\n", 409 | " params['decay_rate'] = 0.999\n", 410 | " params['momentum'] = 0\n", 411 | " params['F'] = [1]\n", 412 | " params['K'] = [5]\n", 413 | " params['p'] = [1]\n", 414 | " params['M'] = [C]\n", 415 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 416 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": { 423 | "collapsed": false 424 | }, 425 | "outputs": [], 426 | "source": [ 427 | "if True:\n", 428 | " name = 'cgconv_fc_softmax'\n", 429 | " params = common.copy()\n", 430 | " params['dir_name'] += name\n", 431 | " params['regularization'] = 0\n", 432 | " params['dropout'] = 1\n", 433 | " params['learning_rate'] = 0.1\n", 434 | " params['decay_rate'] = 0.999\n", 435 | " params['momentum'] = 0\n", 436 | " params['F'] = [5]\n", 437 | " params['K'] = [15]\n", 438 | " params['p'] = [1]\n", 439 | " params['M'] = [100, C]\n", 440 | " model_perf.test(models.cgcnn(L, **params), name, params,\n", 441 | " train_data, train_labels, val_data, val_labels, test_data, test_labels)" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": { 448 | "collapsed": true 449 | }, 450 | "outputs": [], 451 | "source": [ 452 | "model_perf.show()" 453 | ] 454 | } 455 | ], 456 | "metadata": { 457 | "kernelspec": { 458 | "display_name": "Python 3", 459 | "language": "python", 460 | "name": "python3" 461 | }, 462 | "language_info": { 463 | "codemirror_mode": { 464 | "name": "ipython", 465 | "version": 3 466 | }, 467 | "file_extension": ".py", 468 | "mimetype": "text/x-python", 469 | "name": "python", 470 | "nbconvert_exporter": "python", 471 | "pygments_lexer": "ipython3", 472 | "version": "3.4.3" 473 | } 474 | }, 475 | "nbformat": 4, 476 | "nbformat_minor": 0 477 | } 478 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | scikit-learn 4 | matplotlib 5 | 6 | gensim 7 | tensorflow-gpu 8 | #tensorflow 9 | 10 | jupyter 11 | ipython 12 | -------------------------------------------------------------------------------- /trials/2_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Trial 2: classification with learned graph 
filters\n", 8 | "\n", 9 | "We want to classify data by first extracting meaningful features from learned filters." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "collapsed": false 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import time\n", 21 | "import numpy as np\n", 22 | "import scipy.sparse, scipy.sparse.linalg, scipy.spatial.distance\n", 23 | "from sklearn import datasets, linear_model\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "%matplotlib inline\n", 26 | "\n", 27 | "import os\n", 28 | "import sys\n", 29 | "sys.path.append('..')\n", 30 | "from lib import graph" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "# Parameters" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "# Dataset\n", 45 | "\n", 46 | "* Two digits version of MNIST with N samples of each class.\n", 47 | "* Distinguishing 4 from 9 is the hardest." 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": { 54 | "collapsed": false 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "def mnist(a, b, N):\n", 59 | " \"\"\"Prepare data for binary classification of MNIST.\"\"\"\n", 60 | " folder = os.path.join('..', 'data')\n", 61 | " mnist = datasets.fetch_mldata('MNIST original', data_home=folder)\n", 62 | "\n", 63 | " assert N < min(sum(mnist.target==a), sum(mnist.target==b))\n", 64 | " M = mnist.data.shape[1]\n", 65 | " \n", 66 | " X = np.empty((M, 2, N))\n", 67 | " X[:,0,:] = mnist.data[mnist.target==a,:][:N,:].T\n", 68 | " X[:,1,:] = mnist.data[mnist.target==b,:][:N,:].T\n", 69 | " \n", 70 | " y = np.empty((2, N))\n", 71 | " y[0,:] = -1\n", 72 | " y[1,:] = +1\n", 73 | "\n", 74 | " X.shape = M, 2*N\n", 75 | " y.shape = 2*N, 1\n", 76 | " return X, y\n", 77 | "\n", 78 | "X, y = mnist(4, 9, 1000)\n", 79 | "\n", 80 | "print('Dimensionality: N={} samples, M={} features'.format(X.shape[1], X.shape[0]))\n", 81 | "\n", 82 | "X -= 127.5\n", 83 | "print('X in [{}, {}]'.format(np.min(X), np.max(X)))\n", 84 | "\n", 85 | "def plot_digit(nn):\n", 86 | " M, N = X.shape\n", 87 | " m = int(np.sqrt(M))\n", 88 | " fig, axes = plt.subplots(1,len(nn), figsize=(15,5))\n", 89 | " for i, n in enumerate(nn):\n", 90 | " n = int(n)\n", 91 | " img = X[:,n]\n", 92 | " axes[i].imshow(img.reshape((m,m)))\n", 93 | " axes[i].set_title('Label: y = {:.0f}'.format(y[n,0]))\n", 94 | "\n", 95 | "plot_digit([0, 1, 1e2, 1e2+1, 1e3, 1e3+1])" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "# Regularized least-square\n", 103 | "\n", 104 | "## Reference: sklearn ridge regression\n", 105 | "\n", 106 | "* With regularized data, the objective is the same with or without bias." 
107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": { 113 | "collapsed": false 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "def test_sklearn(tauR):\n", 118 | " \n", 119 | " def L(w, b=0):\n", 120 | " return np.linalg.norm(X.T @ w + b - y)**2 + tauR * np.linalg.norm(w)**2\n", 121 | "\n", 122 | " def dL(w):\n", 123 | " return 2 * X @ (X.T @ w - y) + 2 * tauR * w\n", 124 | "\n", 125 | " clf = linear_model.Ridge(alpha=tauR, fit_intercept=False)\n", 126 | " clf.fit(X.T, y)\n", 127 | " w = clf.coef_.T\n", 128 | "\n", 129 | " print('L = {}'.format(L(w, clf.intercept_)))\n", 130 | " print('|dLw| = {}'.format(np.linalg.norm(dL(w))))\n", 131 | "\n", 132 | " # Normalized data: intercept should be small.\n", 133 | " print('bias: {}'.format(abs(np.mean(y - X.T @ w))))\n", 134 | "\n", 135 | "test_sklearn(1e-3)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "## Linear classifier" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": { 149 | "collapsed": false 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "def test_optim(clf, X, y, ax=None):\n", 154 | " \"\"\"Test optimization on full dataset.\"\"\"\n", 155 | " tstart = time.process_time()\n", 156 | " ret = clf.fit(X, y)\n", 157 | " print('Processing time: {}'.format(time.process_time()-tstart))\n", 158 | " print('L = {}'.format(clf.L(*ret, y)))\n", 159 | " if hasattr(clf, 'dLc'):\n", 160 | " print('|dLc| = {}'.format(np.linalg.norm(clf.dLc(*ret, y))))\n", 161 | " if hasattr(clf, 'dLw'):\n", 162 | " print('|dLw| = {}'.format(np.linalg.norm(clf.dLw(*ret, y))))\n", 163 | " if hasattr(clf, 'loss'):\n", 164 | " if not ax:\n", 165 | " fig = plt.figure()\n", 166 | " ax = fig.add_subplot(111)\n", 167 | " ax.semilogy(clf.loss)\n", 168 | " ax.set_title('Convergence')\n", 169 | " ax.set_xlabel('Iteration number')\n", 170 | " ax.set_ylabel('Loss')\n", 171 | " if hasattr(clf, 'Lsplit'):\n", 172 | " print('Lsplit = {}'.format(clf.Lsplit(*ret, y)))\n", 173 | " print('|dLz| = {}'.format(np.linalg.norm(clf.dLz(*ret, y))))\n", 174 | " ax.semilogy(clf.loss_split)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": { 181 | "collapsed": false 182 | }, 183 | "outputs": [], 184 | "source": [ 185 | "class rls:\n", 186 | " \n", 187 | " def __init__(s, tauR, algo='solve'):\n", 188 | " s.tauR = tauR\n", 189 | " if algo is 'solve':\n", 190 | " s.fit = s.solve\n", 191 | " elif algo is 'inv':\n", 192 | " s.fit = s.inv\n", 193 | "\n", 194 | " def L(s, X, y):\n", 195 | " return np.linalg.norm(X.T @ s.w - y)**2 + s.tauR * np.linalg.norm(s.w)**2\n", 196 | "\n", 197 | " def dLw(s, X, y):\n", 198 | " return 2 * X @ (X.T @ s.w - y) + 2 * s.tauR * s.w\n", 199 | " \n", 200 | " def inv(s, X, y):\n", 201 | " s.w = np.linalg.inv(X @ X.T + s.tauR * np.identity(X.shape[0])) @ X @ y\n", 202 | " return (X,)\n", 203 | " \n", 204 | " def solve(s, X, y):\n", 205 | " s.w = np.linalg.solve(X @ X.T + s.tauR * np.identity(X.shape[0]), X @ y)\n", 206 | " return (X,)\n", 207 | " \n", 208 | " def predict(s, X):\n", 209 | " return X.T @ s.w\n", 210 | "\n", 211 | "test_optim(rls(1e-3, 'solve'), X, y)\n", 212 | "test_optim(rls(1e-3, 'inv'), X, y)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "metadata": {}, 218 | "source": [ 219 | "# Feature graph" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": { 226 | "collapsed": false 227 | 
}, 228 | "outputs": [], 229 | "source": [ 230 | "t_start = time.process_time()\n", 231 | "z = graph.grid(int(np.sqrt(X.shape[0])))\n", 232 | "dist, idx = graph.distance_sklearn_metrics(z, k=4)\n", 233 | "A = graph.adjacency(dist, idx)\n", 234 | "L = graph.laplacian(A, True)\n", 235 | "lmax = graph.lmax(L)\n", 236 | "print('Execution time: {:.2f}s'.format(time.process_time() - t_start))" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "# Lanczos basis" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": { 250 | "collapsed": false 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "def lanczos(L, X, K):\n", 255 | " M, N = X.shape\n", 256 | " a = np.empty((K, N))\n", 257 | " b = np.zeros((K, N))\n", 258 | " V = np.empty((K, M, N))\n", 259 | " V[0,...] = X / np.linalg.norm(X, axis=0)\n", 260 | " for k in range(K-1):\n", 261 | " W = L.dot(V[k,...])\n", 262 | " a[k,:] = np.sum(W * V[k,...], axis=0)\n", 263 | " W = W - a[k,:] * V[k,...] - (b[k,:] * V[k-1,...] if k>0 else 0)\n", 264 | " b[k+1,:] = np.linalg.norm(W, axis=0)\n", 265 | " V[k+1,...] = W / b[k+1,:]\n", 266 | " a[K-1,:] = np.sum(L.dot(V[K-1,...]) * V[K-1,...], axis=0)\n", 267 | " return V, a, b\n", 268 | "\n", 269 | "def lanczos_H_diag(a, b):\n", 270 | " K, N = a.shape\n", 271 | " H = np.zeros((K*K, N))\n", 272 | " H[:K**2:K+1, :] = a\n", 273 | " H[1:(K-1)*K:K+1, :] = b[1:,:]\n", 274 | " H.shape = (K, K, N)\n", 275 | " Q = np.linalg.eigh(H.T, UPLO='L')[1]\n", 276 | " Q = np.swapaxes(Q,1,2).T\n", 277 | " return Q\n", 278 | "\n", 279 | "def lanczos_basis_eval(L, X, K):\n", 280 | " V, a, b = lanczos(L, X, K)\n", 281 | " Q = lanczos_H_diag(a, b)\n", 282 | " M, N = X.shape\n", 283 | " Xt = np.empty((K, M, N))\n", 284 | " for n in range(N):\n", 285 | " Xt[...,n] = Q[...,n].T @ V[...,n]\n", 286 | " Xt *= Q[0,:,np.newaxis,:]\n", 287 | " Xt *= np.linalg.norm(X, axis=0)\n", 288 | " return Xt, Q[0,...]" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "# Tests\n", 296 | "\n", 297 | "* Memory arrangement for fastest computations: largest dimensions on the outside, i.e. fastest varying indices.\n", 298 | "* The einsum seems to be efficient for three operands." 
299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": { 305 | "collapsed": false 306 | }, 307 | "outputs": [], 308 | "source": [ 309 | "def test():\n", 310 | " \"\"\"Test the speed of filtering and weighting.\"\"\"\n", 311 | " \n", 312 | " def mult(impl=3):\n", 313 | " if impl is 0:\n", 314 | " Xb = Xt.view()\n", 315 | " Xb.shape = (K, M*N)\n", 316 | " XCb = Xb.T @ C # in MN x F\n", 317 | " XCb = XCb.T.reshape((F*M, N))\n", 318 | " return (XCb.T @ w).squeeze()\n", 319 | " elif impl is 1:\n", 320 | " tmp = np.tensordot(Xt, C, (0,0))\n", 321 | " return np.tensordot(tmp, W, ((0,2),(1,0)))\n", 322 | " elif impl is 2:\n", 323 | " tmp = np.tensordot(Xt, C, (0,0))\n", 324 | " return np.einsum('ijk,ki->j', tmp, W)\n", 325 | " elif impl is 3:\n", 326 | " return np.einsum('kmn,fm,kf->n', Xt, W, C)\n", 327 | " \n", 328 | " C = np.random.normal(0,1,(K,F))\n", 329 | " W = np.random.normal(0,1,(F,M))\n", 330 | " w = W.reshape((F*M, 1))\n", 331 | " a = mult(impl=0)\n", 332 | " for impl in range(4):\n", 333 | " tstart = time.process_time()\n", 334 | " for k in range(1000):\n", 335 | " b = mult(impl)\n", 336 | " print('Execution time (impl={}): {}'.format(impl, time.process_time() - tstart))\n", 337 | " np.testing.assert_allclose(a, b)\n", 338 | "#test()" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "metadata": {}, 344 | "source": [ 345 | "# GFL classification without weights\n", 346 | "\n", 347 | "* The matrix is singular thus not invertible." 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "metadata": { 354 | "collapsed": false 355 | }, 356 | "outputs": [], 357 | "source": [ 358 | "class gflc_noweights:\n", 359 | "\n", 360 | " def __init__(s, F, K, niter, algo='direct'):\n", 361 | " \"\"\"Model hyper-parameters\"\"\"\n", 362 | " s.F = F\n", 363 | " s.K = K\n", 364 | " s.niter = niter\n", 365 | " if algo is 'direct':\n", 366 | " s.fit = s.direct\n", 367 | " elif algo is 'sgd':\n", 368 | " s.fit = s.sgd\n", 369 | " \n", 370 | " def L(s, Xt, y):\n", 371 | " #tmp = np.einsum('kmn,kf,fm->n', Xt, s.C, np.ones((s.F,M))) - y.squeeze()\n", 372 | " #tmp = np.einsum('kmn,kf->mnf', Xt, s.C).sum((0,2)) - y.squeeze()\n", 373 | " #tmp = (C.T @ Xt.reshape((K,M*N))).reshape((F,M,N)).sum((0,2)) - y.squeeze()\n", 374 | " tmp = np.tensordot(s.C, Xt, (0,0)).sum((0,1)) - y.squeeze()\n", 375 | " return np.linalg.norm(tmp)**2\n", 376 | "\n", 377 | " def dLc(s, Xt, y):\n", 378 | " tmp = np.tensordot(s.C, Xt, (0,0)).sum(axis=(0,1)) - y.squeeze()\n", 379 | " return np.dot(Xt, tmp).sum(1)[:,np.newaxis].repeat(s.F,1)\n", 380 | " #return np.einsum('kmn,n->km', Xt, tmp).sum(1)[:,np.newaxis].repeat(s.F,1)\n", 381 | "\n", 382 | " def sgd(s, X, y):\n", 383 | " Xt, q = lanczos_basis_eval(L, X, s.K)\n", 384 | " s.C = np.random.normal(0, 1, (s.K, s.F))\n", 385 | " s.loss = [s.L(Xt, y)]\n", 386 | " for t in range(s.niter):\n", 387 | " s.C -= 1e-13 * s.dLc(Xt, y)\n", 388 | " s.loss.append(s.L(Xt, y))\n", 389 | " return (Xt,)\n", 390 | " \n", 391 | " def direct(s, X, y):\n", 392 | " M, N = X.shape\n", 393 | " Xt, q = lanczos_basis_eval(L, X, s.K)\n", 394 | " s.C = np.random.normal(0, 1, (s.K, s.F))\n", 395 | " W = np.ones((s.F, M))\n", 396 | " c = s.C.reshape((s.K*s.F, 1))\n", 397 | " s.loss = [s.L(Xt, y)]\n", 398 | " Xw = np.einsum('kmn,fm->kfn', Xt, W)\n", 399 | " #Xw = np.tensordot(Xt, W, (1,1))\n", 400 | " Xw.shape = (s.K*s.F, N)\n", 401 | " #np.linalg.inv(Xw @ Xw.T)\n", 402 | " c[:] = np.linalg.solve(Xw @ Xw.T, Xw @ y)\n", 403 | " 
s.loss.append(s.L(Xt, y))\n", 404 | " return (Xt,)\n", 405 | "\n", 406 | "#test_optim(gflc_noweights(1, 4, 100, 'sgd'), X, y)\n", 407 | "#test_optim(gflc_noweights(1, 4, 0, 'direct'), X, y)" 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": {}, 413 | "source": [ 414 | "# GFL classification with weights" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "metadata": { 421 | "collapsed": false, 422 | "scrolled": true 423 | }, 424 | "outputs": [], 425 | "source": [ 426 | "class gflc_weights():\n", 427 | "\n", 428 | " def __init__(s, F, K, tauR, niter, algo='direct'):\n", 429 | " \"\"\"Model hyper-parameters\"\"\"\n", 430 | " s.F = F\n", 431 | " s.K = K\n", 432 | " s.tauR = tauR\n", 433 | " s.niter = niter\n", 434 | " if algo is 'direct':\n", 435 | " s.fit = s.direct\n", 436 | " elif algo is 'sgd':\n", 437 | " s.fit = s.sgd\n", 438 | "\n", 439 | " def L(s, Xt, y):\n", 440 | " tmp = np.einsum('kmn,kf,fm->n', Xt, s.C, s.W) - y.squeeze()\n", 441 | " return np.linalg.norm(tmp)**2 + s.tauR * np.linalg.norm(s.W)**2\n", 442 | "\n", 443 | " def dLw(s, Xt, y):\n", 444 | " tmp = np.einsum('kmn,kf,fm->n', Xt, s.C, s.W) - y.squeeze()\n", 445 | " return 2 * np.einsum('kmn,kf,n->fm', Xt, s.C, tmp) + 2 * s.tauR * s.W\n", 446 | "\n", 447 | " def dLc(s, Xt, y):\n", 448 | " tmp = np.einsum('kmn,kf,fm->n', Xt, s.C, s.W) - y.squeeze()\n", 449 | " return 2 * np.einsum('kmn,n,fm->kf', Xt, tmp, s.W)\n", 450 | "\n", 451 | " def sgd(s, X, y):\n", 452 | " M, N = X.shape\n", 453 | " Xt, q = lanczos_basis_eval(L, X, s.K)\n", 454 | " s.C = np.random.normal(0, 1, (s.K, s.F))\n", 455 | " s.W = np.random.normal(0, 1, (s.F, M))\n", 456 | "\n", 457 | " s.loss = [s.L(Xt, y)]\n", 458 | "\n", 459 | " for t in range(s.niter):\n", 460 | " s.C -= 1e-12 * s.dLc(Xt, y)\n", 461 | " s.W -= 1e-12 * s.dLw(Xt, y)\n", 462 | " s.loss.append(s.L(Xt, y))\n", 463 | " \n", 464 | " return (Xt,)\n", 465 | "\n", 466 | " def direct(s, X, y):\n", 467 | " M, N = X.shape\n", 468 | " Xt, q = lanczos_basis_eval(L, X, s.K)\n", 469 | " s.C = np.random.normal(0, 1, (s.K, s.F))\n", 470 | " s.W = np.random.normal(0, 1, (s.F, M))\n", 471 | " #c = s.C.reshape((s.K*s.F, 1))\n", 472 | " #w = s.W.reshape((s.F*M, 1))\n", 473 | " c = s.C.view()\n", 474 | " c.shape = (s.K*s.F, 1)\n", 475 | " w = s.W.view()\n", 476 | " w.shape = (s.F*M, 1)\n", 477 | "\n", 478 | " s.loss = [s.L(Xt, y)]\n", 479 | "\n", 480 | " for t in range(s.niter):\n", 481 | " Xw = np.einsum('kmn,fm->kfn', Xt, s.W)\n", 482 | " #Xw = np.tensordot(Xt, s.W, (1,1))\n", 483 | " Xw.shape = (s.K*s.F, N)\n", 484 | " c[:] = np.linalg.solve(Xw @ Xw.T, Xw @ y)\n", 485 | "\n", 486 | " Z = np.einsum('kmn,kf->fmn', Xt, s.C)\n", 487 | " #Z = np.tensordot(Xt, s.C, (0,0))\n", 488 | " #Z = s.C.T @ Xt.reshape((K,M*N))\n", 489 | " Z.shape = (s.F*M, N)\n", 490 | " w[:] = np.linalg.solve(Z @ Z.T + s.tauR * np.identity(s.F*M), Z @ y)\n", 491 | "\n", 492 | " s.loss.append(s.L(Xt, y))\n", 493 | " \n", 494 | " return (Xt,)\n", 495 | "\n", 496 | " def predict(s, X):\n", 497 | " Xt, q = lanczos_basis_eval(L, X, s.K)\n", 498 | " return np.einsum('kmn,kf,fm->n', Xt, s.C, s.W)\n", 499 | "\n", 500 | "#test_optim(gflc_weights(3, 4, 1e-3, 50, 'sgd'), X, y)\n", 501 | "clf_weights = gflc_weights(F=3, K=50, tauR=1e4, niter=5, algo='direct')\n", 502 | "test_optim(clf_weights, X, y)" 503 | ] 504 | }, 505 | { 506 | "cell_type": "markdown", 507 | "metadata": {}, 508 | "source": [ 509 | "# GFL classification with splitting\n", 510 | "\n", 511 | "Solvers\n", 512 | "* Closed-form 
solution.\n", 513 | "* Stochastic gradient descent." 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "metadata": { 520 | "collapsed": false 521 | }, 522 | "outputs": [], 523 | "source": [ 524 | "class gflc_split():\n", 525 | "\n", 526 | " def __init__(s, F, K, tauR, tauF, niter, algo='direct'):\n", 527 | " \"\"\"Model hyper-parameters\"\"\"\n", 528 | " s.F = F\n", 529 | " s.K = K\n", 530 | " s.tauR = tauR\n", 531 | " s.tauF = tauF\n", 532 | " s.niter = niter\n", 533 | " if algo is 'direct':\n", 534 | " s.fit = s.direct\n", 535 | " elif algo is 'sgd':\n", 536 | " s.fit = s.sgd\n", 537 | "\n", 538 | " def L(s, Xt, XCb, Z, y):\n", 539 | " return np.linalg.norm(XCb.T @ s.w - y)**2 + s.tauR * np.linalg.norm(s.w)**2\n", 540 | "\n", 541 | " def Lsplit(s, Xt, XCb, Z, y):\n", 542 | " return np.linalg.norm(Z.T @ s.w - y)**2 + s.tauF * np.linalg.norm(XCb - Z)**2 + s.tauR * np.linalg.norm(s.w)**2\n", 543 | "\n", 544 | " def dLw(s, Xt, XCb, Z, y):\n", 545 | " return 2 * Z @ (Z.T @ s.w - y) + 2 * s.tauR * s.w\n", 546 | "\n", 547 | " def dLc(s, Xt, XCb, Z, y):\n", 548 | " Xb = Xt.reshape((s.K, -1)).T\n", 549 | " Zb = Z.reshape((s.F, -1)).T\n", 550 | " return 2 * s.tauF * Xb.T @ (Xb @ s.C - Zb)\n", 551 | "\n", 552 | " def dLz(s, Xt, XCb, Z, y):\n", 553 | " return 2 * s.w @ (s.w.T @ Z - y.T) + 2 * s.tauF * (Z - XCb)\n", 554 | "\n", 555 | " def lanczos_filter(s, Xt):\n", 556 | " M, N = Xt.shape[1:]\n", 557 | " Xb = Xt.reshape((s.K, M*N)).T\n", 558 | " #XCb = np.tensordot(Xb, C, (2,1))\n", 559 | " XCb = Xb @ s.C # in MN x F\n", 560 | " XCb = XCb.T.reshape((s.F*M, N)) # Needs to copy data.\n", 561 | " return XCb\n", 562 | "\n", 563 | " def sgd(s, X, y):\n", 564 | " M, N = X.shape\n", 565 | " Xt, q = lanczos_basis_eval(L, X, s.K)\n", 566 | " s.C = np.zeros((s.K, s.F))\n", 567 | " s.w = np.zeros((s.F*M, 1))\n", 568 | " Z = np.random.normal(0, 1, (s.F*M, N))\n", 569 | "\n", 570 | " XCb = np.empty((s.F*M, N))\n", 571 | "\n", 572 | " s.loss = [s.L(Xt, XCb, Z, y)]\n", 573 | " s.loss_split = [s.Lsplit(Xt, XCb, Z, y)]\n", 574 | "\n", 575 | " for t in range(s.niter):\n", 576 | " s.C -= 1e-7 * s.dLc(Xt, XCb, Z, y)\n", 577 | " XCb[:] = s.lanczos_filter(Xt)\n", 578 | " Z -= 1e-4 * s.dLz(Xt, XCb, Z, y)\n", 579 | " s.w -= 1e-4 * s.dLw(Xt, XCb, Z, y)\n", 580 | " s.loss.append(s.L(Xt, XCb, Z, y))\n", 581 | " s.loss_split.append(s.Lsplit(Xt, XCb, Z, y))\n", 582 | " \n", 583 | " return Xt, XCb, Z\n", 584 | "\n", 585 | " def direct(s, X, y):\n", 586 | " M, N = X.shape\n", 587 | " Xt, q = lanczos_basis_eval(L, X, s.K)\n", 588 | " s.C = np.zeros((s.K, s.F))\n", 589 | " s.w = np.zeros((s.F*M, 1))\n", 590 | " Z = np.random.normal(0, 1, (s.F*M, N))\n", 591 | "\n", 592 | " XCb = np.empty((s.F*M, N))\n", 593 | " Xb = Xt.reshape((s.K, M*N)).T\n", 594 | " Zb = Z.reshape((s.F, M*N)).T\n", 595 | "\n", 596 | " s.loss = [s.L(Xt, XCb, Z, y)]\n", 597 | " s.loss_split = [s.Lsplit(Xt, XCb, Z, y)]\n", 598 | "\n", 599 | " for t in range(s.niter):\n", 600 | "\n", 601 | " s.C[:] = Xb.T @ Zb / np.sum((np.linalg.norm(X, axis=0) * q)**2, axis=1)[:,np.newaxis]\n", 602 | " XCb[:] = s.lanczos_filter(Xt)\n", 603 | "\n", 604 | " #Z[:] = np.linalg.inv(s.tauF * np.identity(s.F*M) + s.w @ s.w.T) @ (s.tauF * XCb + s.w @ y.T)\n", 605 | " Z[:] = np.linalg.solve(s.tauF * np.identity(s.F*M) + s.w @ s.w.T, s.tauF * XCb + s.w @ y.T)\n", 606 | "\n", 607 | " #s.w[:] = np.linalg.inv(Z @ Z.T + s.tauR * np.identity(s.F*M)) @ Z @ y\n", 608 | " s.w[:] = np.linalg.solve(Z @ Z.T + s.tauR * np.identity(s.F*M), Z @ y)\n", 609 | "\n", 
610 | " s.loss.append(s.L(Xt, XCb, Z, y))\n", 611 | " s.loss_split.append(s.Lsplit(Xt, XCb, Z, y))\n", 612 | " \n", 613 | " return Xt, XCb, Z\n", 614 | "\n", 615 | " def predict(s, X):\n", 616 | " Xt, q = lanczos_basis_eval(L, X, s.K)\n", 617 | " XCb = s.lanczos_filter(Xt)\n", 618 | " return XCb.T @ s.w\n", 619 | "\n", 620 | "#test_optim(gflc_split(3, 4, 1e-3, 1e-3, 50, 'sgd'), X, y)\n", 621 | "clf_split = gflc_split(3, 4, 1e4, 1e-3, 8, 'direct')\n", 622 | "test_optim(clf_split, X, y)" 623 | ] 624 | }, 625 | { 626 | "cell_type": "markdown", 627 | "metadata": {}, 628 | "source": [ 629 | "# Filters visualization\n", 630 | "\n", 631 | "Observations:\n", 632 | "* Filters learned with the splitting scheme have much smaller amplitudes.\n", 633 | "* Maybe the energy sometimes goes in W ?\n", 634 | "* Why are the filters so different ?" 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": null, 640 | "metadata": { 641 | "collapsed": false 642 | }, 643 | "outputs": [], 644 | "source": [ 645 | "lamb, U = graph.fourier(L)\n", 646 | "print('Spectrum in [{:1.2e}, {:1.2e}]'.format(lamb[0], lamb[-1]))" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "metadata": { 653 | "collapsed": false 654 | }, 655 | "outputs": [], 656 | "source": [ 657 | "def plot_filters(C, spectrum=False):\n", 658 | " K, F = C.shape\n", 659 | " M, M = L.shape\n", 660 | " m = int(np.sqrt(M))\n", 661 | " X = np.zeros((M,1))\n", 662 | " X[int(m/2*(m+1))] = 1 # Kronecker\n", 663 | " Xt, q = lanczos_basis_eval(L, X, K)\n", 664 | " Z = np.einsum('kmn,kf->mnf', Xt, C)\n", 665 | " Xh = U.T @ X\n", 666 | " Zh = np.tensordot(U.T, Z, (1,0))\n", 667 | " \n", 668 | " pmin = int(m/2) - K\n", 669 | " pmax = int(m/2) + K + 1\n", 670 | " fig, axes = plt.subplots(2,int(np.ceil(F/2)), figsize=(15,5))\n", 671 | " for f in range(F):\n", 672 | " img = Z[:,0,f].reshape((m,m))[pmin:pmax,pmin:pmax]\n", 673 | " im = axes.flat[f].imshow(img, vmin=Z.min(), vmax=Z.max(), interpolation='none')\n", 674 | " axes.flat[f].set_title('Filter {}'.format(f))\n", 675 | " fig.subplots_adjust(right=0.8)\n", 676 | " cax = fig.add_axes([0.82, 0.16, 0.02, 0.7])\n", 677 | " fig.colorbar(im, cax=cax)\n", 678 | " \n", 679 | " if spectrum:\n", 680 | " ax = plt.figure(figsize=(15,5)).add_subplot(111)\n", 681 | " for f in range(F):\n", 682 | " ax.plot(lamb, Zh[...,f] / Xh, '.-', label='Filter {}'.format(f))\n", 683 | " ax.legend(loc='best')\n", 684 | " ax.set_title('Spectrum of learned filters')\n", 685 | " ax.set_xlabel('Frequency')\n", 686 | " ax.set_ylabel('Amplitude')\n", 687 | " ax.set_xlim(0, lmax)\n", 688 | "\n", 689 | "plot_filters(clf_weights.C, True)\n", 690 | "plot_filters(clf_split.C, True)" 691 | ] 692 | }, 693 | { 694 | "cell_type": "markdown", 695 | "metadata": {}, 696 | "source": [ 697 | "# Extracted features" 698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": null, 703 | "metadata": { 704 | "collapsed": false 705 | }, 706 | "outputs": [], 707 | "source": [ 708 | "def plot_features(C, x):\n", 709 | " K, F = C.shape\n", 710 | " m = int(np.sqrt(x.shape[0]))\n", 711 | " xt, q = lanczos_basis_eval(L, x, K)\n", 712 | " Z = np.einsum('kmn,kf->mnf', xt, C)\n", 713 | " \n", 714 | " fig, axes = plt.subplots(2,int(np.ceil(F/2)), figsize=(15,5))\n", 715 | " for f in range(F):\n", 716 | " img = Z[:,0,f].reshape((m,m))\n", 717 | " #im = axes.flat[f].imshow(img, vmin=Z.min(), vmax=Z.max(), interpolation='none')\n", 718 | " im = axes.flat[f].imshow(img, interpolation='none')\n", 719 | " 
axes.flat[f].set_title('Filter {}'.format(f))\n", 720 | " fig.subplots_adjust(right=0.8)\n", 721 | " cax = fig.add_axes([0.82, 0.16, 0.02, 0.7])\n", 722 | " fig.colorbar(im, cax=cax)\n", 723 | "\n", 724 | "plot_features(clf_weights.C, X[:,[0]])\n", 725 | "plot_features(clf_weights.C, X[:,[1000]])" 726 | ] 727 | }, 728 | { 729 | "cell_type": "markdown", 730 | "metadata": {}, 731 | "source": [ 732 | "# Performance w.r.t. hyper-parameters\n", 733 | "\n", 734 | "* F plays a big role.\n", 735 | " * Both for performance and training time.\n", 736 | " * Larger values lead to over-fitting !\n", 737 | "* Order $K \\in [3,5]$ seems sufficient.\n", 738 | "* $\\tau_R$ does not have much influence." 739 | ] 740 | }, 741 | { 742 | "cell_type": "code", 743 | "execution_count": null, 744 | "metadata": { 745 | "collapsed": false 746 | }, 747 | "outputs": [], 748 | "source": [ 749 | "def scorer(clf, X, y):\n", 750 | " yest = clf.predict(X).round().squeeze()\n", 751 | " y = y.squeeze()\n", 752 | " yy = np.ones(len(y))\n", 753 | " yy[yest < 0] = -1\n", 754 | " nerrs = np.count_nonzero(y - yy)\n", 755 | " return 1 - nerrs / len(y)" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": null, 761 | "metadata": { 762 | "collapsed": false 763 | }, 764 | "outputs": [], 765 | "source": [ 766 | "def perf(clf, nfolds=3):\n", 767 | " \"\"\"Test training accuracy.\"\"\"\n", 768 | " N = X.shape[1]\n", 769 | " inds = np.arange(N)\n", 770 | " np.random.shuffle(inds)\n", 771 | " inds.resize((nfolds, int(N/nfolds)))\n", 772 | " folds = np.arange(nfolds)\n", 773 | " test = inds[0,:]\n", 774 | " train = inds[folds != 0, :].reshape(-1)\n", 775 | " \n", 776 | " fig, axes = plt.subplots(1,3, figsize=(15,5))\n", 777 | " test_optim(clf, X[:,train], y[train], axes[2])\n", 778 | " \n", 779 | " axes[0].plot(train, clf.predict(X[:,train]), '.')\n", 780 | " axes[0].plot(train, y[train].squeeze(), '.')\n", 781 | " axes[0].set_ylim([-3,3])\n", 782 | " axes[0].set_title('Training set accuracy: {:.2f}'.format(scorer(clf, X[:,train], y[train])))\n", 783 | " axes[1].plot(test, clf.predict(X[:,test]), '.')\n", 784 | " axes[1].plot(test, y[test].squeeze(), '.')\n", 785 | " axes[1].set_ylim([-3,3])\n", 786 | " axes[1].set_title('Testing set accuracy: {:.2f}'.format(scorer(clf, X[:,test], y[test])))\n", 787 | " \n", 788 | " if hasattr(clf, 'C'):\n", 789 | " plot_filters(clf.C)\n", 790 | "\n", 791 | "perf(rls(tauR=1e6))\n", 792 | "for F in [1,3,5]:\n", 793 | " perf(gflc_weights(F=F, K=50, tauR=1e4, niter=5, algo='direct'))\n", 794 | "\n", 795 | "#perf(rls(tauR=1e-3))\n", 796 | "#for K in [2,3,5,7]:\n", 797 | "# perf(gflc_weights(F=3, K=K, tauR=1e-3, niter=5, algo='direct'))\n", 798 | "\n", 799 | "#for tauR in [1e-3, 1e-1, 1e1]:\n", 800 | "# perf(rls(tauR=tauR))\n", 801 | "# perf(gflc_weights(F=3, K=3, tauR=tauR, niter=5, algo='direct'))" 802 | ] 803 | }, 804 | { 805 | "cell_type": "markdown", 806 | "metadata": {}, 807 | "source": [ 808 | "# Classification\n", 809 | "\n", 810 | "* Greater is $F$, greater should $K$ be." 
811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": null, 816 | "metadata": { 817 | "collapsed": true 818 | }, 819 | "outputs": [], 820 | "source": [ 821 | "def cross_validation(clf, nfolds, nvalidations):\n", 822 | " M, N = X.shape\n", 823 | " scores = np.empty((nvalidations, nfolds))\n", 824 | " for nval in range(nvalidations):\n", 825 | " inds = np.arange(N)\n", 826 | " np.random.shuffle(inds)\n", 827 | " inds.resize((nfolds, int(N/nfolds)))\n", 828 | " folds = np.arange(nfolds)\n", 829 | " for n in folds:\n", 830 | " test = inds[n,:]\n", 831 | " train = inds[folds != n, :].reshape(-1)\n", 832 | " clf.fit(X[:,train], y[train])\n", 833 | " scores[nval, n] = scorer(clf, X[:,test], y[test])\n", 834 | " return scores.mean()*100, scores.std()*100\n", 835 | " #print('Accuracy: {:.2f} +- {:.2f}'.format(scores.mean()*100, scores.std()*100))\n", 836 | " #print(scores)" 837 | ] 838 | }, 839 | { 840 | "cell_type": "code", 841 | "execution_count": null, 842 | "metadata": { 843 | "collapsed": false 844 | }, 845 | "outputs": [], 846 | "source": [ 847 | "def test_classification(clf, params, param, values, nfolds=10, nvalidations=1):\n", 848 | " means = []\n", 849 | " stds = []\n", 850 | " fig, ax = plt.subplots(1,1, figsize=(15,5))\n", 851 | " for i,val in enumerate(values):\n", 852 | " params[param] = val\n", 853 | " mean, std = cross_validation(clf(**params), nfolds, nvalidations)\n", 854 | " means.append(mean)\n", 855 | " stds.append(std)\n", 856 | " ax.annotate('{:.2f} +- {:.2f}'.format(mean,std), xy=(i,mean), xytext=(10,10), textcoords='offset points')\n", 857 | " ax.errorbar(np.arange(len(values)), means, stds, fmt='.', markersize=10)\n", 858 | " ax.set_xlim(-.8, len(values)-.2)\n", 859 | " ax.set_xticks(np.arange(len(values)))\n", 860 | " ax.set_xticklabels(values)\n", 861 | " ax.set_xlabel(param)\n", 862 | " ax.set_ylim(50, 100)\n", 863 | " ax.set_ylabel('Accuracy')\n", 864 | " ax.set_title('Parameters: {}'.format(params))" 865 | ] 866 | }, 867 | { 868 | "cell_type": "code", 869 | "execution_count": null, 870 | "metadata": { 871 | "collapsed": false 872 | }, 873 | "outputs": [], 874 | "source": [ 875 | "test_classification(rls, {}, 'tauR', [1e8,1e7,1e6,1e5,1e4,1e3,1e-5,1e-8], 10, 10)" 876 | ] 877 | }, 878 | { 879 | "cell_type": "code", 880 | "execution_count": null, 881 | "metadata": { 882 | "collapsed": false 883 | }, 884 | "outputs": [], 885 | "source": [ 886 | "params = {'F':1, 'K':2, 'tauR':1e3, 'niter':5, 'algo':'direct'}\n", 887 | "test_classification(gflc_weights, params, 'tauR', [1e8,1e6,1e5,1e4,1e3,1e2,1e-3,1e-8], 10, 10)" 888 | ] 889 | }, 890 | { 891 | "cell_type": "code", 892 | "execution_count": null, 893 | "metadata": { 894 | "collapsed": false 895 | }, 896 | "outputs": [], 897 | "source": [ 898 | "params = {'F':2, 'K':10, 'tauR':1e4, 'niter':5, 'algo':'direct'}\n", 899 | "test_classification(gflc_weights, params, 'F', [1,2,3,5])" 900 | ] 901 | }, 902 | { 903 | "cell_type": "code", 904 | "execution_count": null, 905 | "metadata": { 906 | "collapsed": false 907 | }, 908 | "outputs": [], 909 | "source": [ 910 | "params = {'F':2, 'K':4, 'tauR':1e4, 'niter':5, 'algo':'direct'}\n", 911 | "test_classification(gflc_weights, params, 'K', [2,3,4,5,8,10,20,30,50,70])" 912 | ] 913 | }, 914 | { 915 | "cell_type": "markdown", 916 | "metadata": {}, 917 | "source": [ 918 | "# Sampled MNIST" 919 | ] 920 | }, 921 | { 922 | "cell_type": "code", 923 | "execution_count": null, 924 | "metadata": { 925 | "collapsed": true 926 | }, 927 | "outputs": [], 928 | "source": [ 929 | 
"Xfull = X" 930 | ] 931 | }, 932 | { 933 | "cell_type": "code", 934 | "execution_count": null, 935 | "metadata": { 936 | "collapsed": false 937 | }, 938 | "outputs": [], 939 | "source": [ 940 | "def sample(X, p, seed=None):\n", 941 | " M, N = X.shape\n", 942 | " z = graph.grid(int(np.sqrt(M)))\n", 943 | " \n", 944 | " # Select random pixels.\n", 945 | " np.random.seed(seed)\n", 946 | " mask = np.arange(M)\n", 947 | " np.random.shuffle(mask)\n", 948 | " mask = mask[:int(p*M)]\n", 949 | " \n", 950 | " return z[mask,:], X[mask,:]\n", 951 | "\n", 952 | "X = Xfull\n", 953 | "z, X = sample(X, .5)\n", 954 | "dist, idx = graph.distance_sklearn_metrics(z, k=4)\n", 955 | "A = graph.adjacency(dist, idx)\n", 956 | "L = graph.laplacian(A)\n", 957 | "lmax = graph.lmax(L)\n", 958 | "lamb, U = graph.fourier(L)\n", 959 | "print('Spectrum in [{:1.2e}, {:1.2e}]'.format(lamb[0], lamb[-1]))\n", 960 | "\n", 961 | "print(L.shape)\n", 962 | "\n", 963 | "def plot(n):\n", 964 | " M, N = X.shape\n", 965 | " m = int(np.sqrt(M))\n", 966 | " x = X[:,n]\n", 967 | " #print(x+127.5)\n", 968 | " plt.scatter(z[:,0], -z[:,1], s=20, c=x+127.5)\n", 969 | "plot(10)\n", 970 | "\n", 971 | "def plot_digit(nn):\n", 972 | " M, N = X.shape\n", 973 | " m = int(np.sqrt(M))\n", 974 | " fig, axes = plt.subplots(1,len(nn), figsize=(15,5))\n", 975 | " for i, n in enumerate(nn):\n", 976 | " n = int(n)\n", 977 | " img = X[:,n]\n", 978 | " axes[i].imshow(img.reshape((m,m)))\n", 979 | " axes[i].set_title('Label: y = {:.0f}'.format(y[n,0]))\n", 980 | "\n", 981 | "#plot_digit([0, 1, 1e2, 1e2+1, 1e3, 1e3+1])" 982 | ] 983 | }, 984 | { 985 | "cell_type": "code", 986 | "execution_count": null, 987 | "metadata": { 988 | "collapsed": false 989 | }, 990 | "outputs": [], 991 | "source": [ 992 | "#clf_weights = gflc_weights(F=3, K=4, tauR=1e-3, niter=5, algo='direct')\n", 993 | "#test_optim(clf_weights, X, y)\n", 994 | "#plot_filters(clf_weights.C, True)" 995 | ] 996 | }, 997 | { 998 | "cell_type": "code", 999 | "execution_count": null, 1000 | "metadata": { 1001 | "collapsed": false 1002 | }, 1003 | "outputs": [], 1004 | "source": [ 1005 | "#test_classification(rls, {}, 'tauR', [1e1,1e0])\n", 1006 | "#params = {'F':2, 'K':5, 'tauR':1e-3, 'niter':5, 'algo':'direct'}\n", 1007 | "#test_classification(gflc_weights, params, 'F', [1,2,3])" 1008 | ] 1009 | }, 1010 | { 1011 | "cell_type": "code", 1012 | "execution_count": null, 1013 | "metadata": { 1014 | "collapsed": false 1015 | }, 1016 | "outputs": [], 1017 | "source": [ 1018 | "test_classification(rls, {}, 'tauR', [1e8,1e7,1e6,1e5,1e4,1e3,1e-5,1e-8], 10, 10)" 1019 | ] 1020 | }, 1021 | { 1022 | "cell_type": "code", 1023 | "execution_count": null, 1024 | "metadata": { 1025 | "collapsed": false 1026 | }, 1027 | "outputs": [], 1028 | "source": [ 1029 | "params = {'F':2, 'K':2, 'tauR':1e3, 'niter':5, 'algo':'direct'}\n", 1030 | "test_classification(gflc_weights, params, 'tauR', [1e8,1e5,1e4,1e3,1e2,1e1,1e-3,1e-8], 10, 1)" 1031 | ] 1032 | }, 1033 | { 1034 | "cell_type": "code", 1035 | "execution_count": null, 1036 | "metadata": { 1037 | "collapsed": false 1038 | }, 1039 | "outputs": [], 1040 | "source": [ 1041 | "params = {'F':2, 'K':10, 'tauR':1e5, 'niter':5, 'algo':'direct'}\n", 1042 | "test_classification(gflc_weights, params, 'F', [1,2,3,4,5,10])" 1043 | ] 1044 | }, 1045 | { 1046 | "cell_type": "code", 1047 | "execution_count": null, 1048 | "metadata": { 1049 | "collapsed": false 1050 | }, 1051 | "outputs": [], 1052 | "source": [ 1053 | "params = {'F':2, 'K':4, 'tauR':1e5, 'niter':5, 'algo':'direct'}\n", 1054 
| "test_classification(gflc_weights, params, 'K', [2,3,4,5,6,7,8,10,20,30])" 1055 | ] 1056 | } 1057 | ], 1058 | "metadata": { 1059 | "kernelspec": { 1060 | "display_name": "Python 3", 1061 | "language": "python", 1062 | "name": "python3" 1063 | }, 1064 | "language_info": { 1065 | "codemirror_mode": { 1066 | "name": "ipython", 1067 | "version": 3 1068 | }, 1069 | "file_extension": ".py", 1070 | "mimetype": "text/x-python", 1071 | "name": "python", 1072 | "nbconvert_exporter": "python", 1073 | "pygments_lexer": "ipython3", 1074 | "version": "3.5.2" 1075 | } 1076 | }, 1077 | "nbformat": 4, 1078 | "nbformat_minor": 0 1079 | } 1080 | -------------------------------------------------------------------------------- /trials/3_tensorflow.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Trial 3: TensorFlow\n", 8 | "\n", 9 | "Small experiment to familiarize myself with TensorFlow." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "collapsed": false 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "import tensorflow as tf" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "# Data" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "collapsed": false 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "from tensorflow.examples.tutorials.mnist import input_data\n", 39 | "import os\n", 40 | "folder = os.path.join('..', 'data', 'mnist')\n", 41 | "mnist = input_data.read_data_sets(folder, one_hot=True)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "# Model" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": { 55 | "collapsed": true 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "x = tf.placeholder(tf.float32, [None, 784])\n", 60 | "W = tf.Variable(tf.zeros([784, 10]))\n", 61 | "b = tf.Variable(tf.zeros([10]))\n", 62 | "y = tf.nn.softmax(tf.matmul(x, W) + b)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "# Training" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": { 76 | "collapsed": true 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "y_ = tf.placeholder(tf.float32, [None, 10])\n", 81 | "cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))\n", 82 | "train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)\n", 83 | "\n", 84 | "init = tf.initialize_all_variables()\n", 85 | "sess = tf.Session()\n", 86 | "sess.run(init)\n", 87 | "\n", 88 | "for i in range(1000):\n", 89 | " batch_xs, batch_ys = mnist.train.next_batch(100)\n", 90 | " sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "# Evaluation" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": { 104 | "collapsed": false 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))\n", 109 | "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n", 110 | "print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))" 111 | ] 112 | } 113 | ], 114 | "metadata": { 115 | "kernelspec": { 116 | 
"display_name": "Python 3", 117 | "language": "python", 118 | "name": "python3" 119 | }, 120 | "language_info": { 121 | "codemirror_mode": { 122 | "name": "ipython", 123 | "version": 3 124 | }, 125 | "file_extension": ".py", 126 | "mimetype": "text/x-python", 127 | "name": "python", 128 | "nbconvert_exporter": "python", 129 | "pygments_lexer": "ipython3", 130 | "version": "3.5.2" 131 | } 132 | }, 133 | "nbformat": 4, 134 | "nbformat_minor": 0 135 | } 136 | -------------------------------------------------------------------------------- /trials/4_coarsening.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Trial 4: graph coarsening\n", 10 | "\n", 11 | "* First Python implementation of the greedy Metis and Graclus coarsening algorithms.\n", 12 | "* Results comparison with a previously developed matlab implementation.\n", 13 | "* Results comparison with the newer version in the `coarsening` module." 14 | ] 15 | }, 16 | { 17 | "cell_type": "raw", 18 | "metadata": {}, 19 | "source": [ 20 | "METIS COARSENING IMPLEMENTATION AS PROPOSED IN:\n", 21 | "An incremental reseeding strategy for clustering\n", 22 | "X Bresson, H Hu, T Laurent, A Szlam, J von Brecht\n", 23 | "arXiv preprint arXiv:1406.3837\n", 24 | "3 May 2016" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "import os\n", 36 | "import scipy.io\n", 37 | "import scipy.sparse\n", 38 | "import numpy as np" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "collapsed": false 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "if False:\n", 50 | " # CHECK PYTHON RESULTS WITH MATLAB CODE\n", 51 | " folder = os.path.join('..', 'data', 'metis_matlab.mat')\n", 52 | " mat = scipy.io.loadmat(folder)\n", 53 | " W = mat['W']\n", 54 | " W = scipy.sparse.csr_matrix(W)\n", 55 | " rid = mat['rid']-1\n", 56 | " rid = rid.T\n", 57 | " rid = rid.squeeze()\n", 58 | " #print(type(W))\n", 59 | " #print(type(rid))\n", 60 | " print(W.shape)\n", 61 | " print(W.nnz)\n", 62 | " #print(rid.shape)\n", 63 | "\n", 64 | "else:\n", 65 | " N = 533\n", 66 | " #np.random.seed(0)\n", 67 | " rid = np.random.permutation(range(N))\n", 68 | " W = np.random.uniform(0.01, 0.99, size=(N,N))\n", 69 | " mask = np.random.uniform(size=(N,N))\n", 70 | " W[mask<0.99] = 0\n", 71 | " W = scipy.sparse.csr_matrix(W)\n", 72 | " print(W.nnz)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "collapsed": true 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "# INPUT\n", 84 | "# W = symmetric sparse weight matrix\n", 85 | "# maxsize = the number of nodes for the coarsest graph\n", 86 | "# OUTPUT\n", 87 | "# graph{1}: original graph of size N_1\n", 88 | "# graph{2}: coarser graph of size N_2 < N_1\n", 89 | "# etc...\n", 90 | "# graph{k}: corsest graph of Size N_k <...< N_2 < N_1\n", 91 | "# parents{i} is a vector of size N_i with entries ranging from 1 to N_{i+1}\n", 92 | "# which indicate the parents in the coarser graph{i+1} \n", 93 | "# nd_sz{i} is a vector of size N_i that contains the size of the supernode in the graph{i}\n", 94 | "# NOTE\n", 95 | "# if \"graph\" is a cell of size k, then \"parents\" will be a cell of size k-1\n", 96 | "\n", 97 | "def metis_coarsening(W,maxsize,rid):\n", 98 | " \n", 99 | " N = 
W.shape[0]\n", 100 | " print('Size of original graph=',N)\n", 101 | " parents = []\n", 102 | " degree = W.sum(axis=0) - W.diagonal()\n", 103 | " graphs = []\n", 104 | " graphs.append(W)\n", 105 | " supernode_size = np.ones(N)\n", 106 | " nd_sz = [supernode_size]\n", 107 | " count = 0\n", 108 | " \n", 109 | " while N > maxsize:\n", 110 | " \n", 111 | " count = count + 1;\n", 112 | " print('level=',count)\n", 113 | " \n", 114 | " # CHOOSE THE WEIGHTS FOR THE PAIRING\n", 115 | " # weights = ones(N,1) # metis weights\n", 116 | " weights = degree # graclus weights\n", 117 | " # weights = supernode_size # other possibility\n", 118 | " weights = weights.T\n", 119 | " weights = np.array(weights)\n", 120 | " weights = weights.squeeze()\n", 121 | " \n", 122 | " # PAIR THE VERTICES AND CONSTRUCT THE ROOT VECTOR\n", 123 | " idx_row,idx_col,val = scipy.sparse.find(W) \n", 124 | " perm = np.argsort(idx_row)\n", 125 | " rr = idx_row[perm]\n", 126 | " cc = idx_col[perm]\n", 127 | " vv = val[perm]\n", 128 | " cluster_id = one_level_coarsening(rr,cc,vv,rid,weights) # rr is ordered \n", 129 | " parents.append(cluster_id)\n", 130 | " \n", 131 | " # TO DO\n", 132 | " # COMPUTE THE SIZE OF THE SUPERNODES AND THEIR DEGREE \n", 133 | " #supernode_size = full( sparse(cluster_id, ones(N,1) , supernode_size ) )\n", 134 | " #print(cluster_id)\n", 135 | " #print(supernode_size)\n", 136 | " #nd_sz{count+1}=supernode_size;\n", 137 | " \n", 138 | " # COMPUTE THE EDGES WEIGHTS FOR THE NEW GRAPH\n", 139 | " nrr = cluster_id[rr]\n", 140 | " ncc = cluster_id[cc]\n", 141 | " nvv = vv\n", 142 | " Nnew = int(cluster_id.max()) + 1\n", 143 | " print('Size of coarser graph=',Nnew)\n", 144 | " W = scipy.sparse.csr_matrix((nvv,(nrr,ncc)),shape=(Nnew,Nnew))\n", 145 | " # Add new graph to the list of all coarsened graphs\n", 146 | " graphs.append(W)\n", 147 | " N = W.shape[0]\n", 148 | " \n", 149 | " # COMPUTE THE DEGREE (OMIT OR NOT SELF LOOPS)\n", 150 | " degree = W.sum(axis=0)\n", 151 | " #degree = W.sum(axis=0) - W.diagonal()\n", 152 | " \n", 153 | " # CHOOSE THE ORDER IN WHICH VERTICES WILL BE VISTED AT THE NEXT PASS\n", 154 | " #[~, rid]=sort(ss); # arthur strategy\n", 155 | " #[~, rid]=sort(supernode_size); # thomas strategy\n", 156 | " #rid=randperm(N); # metis/graclus strategy \n", 157 | " ss = W.sum(axis=0).T\n", 158 | " rid = [i[0] for i in sorted(enumerate(ss), key=lambda x:x[1])] # [~, rid]=sort(ss);\n", 159 | " \n", 160 | " \n", 161 | " # Remove all diagonal entries in similarity matrices\n", 162 | " for i in range(len(graphs)): \n", 163 | " csr_setdiag_val(graphs[i])\n", 164 | " scipy.sparse.csr_matrix.eliminate_zeros(graphs[i])\n", 165 | " \n", 166 | " \n", 167 | " return graphs,parents" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": { 174 | "collapsed": true 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "#http://nbviewer.ipython.org/gist/Midnighter/9992103\n", 179 | "def csr_setdiag_val(csr, value=0):\n", 180 | " \"\"\"Set all diagonal nonzero elements\n", 181 | " (elements currently in the sparsity pattern)\n", 182 | " to the given value. 
Useful to set to 0 mostly.\n", 183 | " \"\"\"\n", 184 | " if csr.format != \"csr\":\n", 185 | " raise ValueError('Matrix given must be of CSR format.')\n", 186 | " csr.sort_indices()\n", 187 | " pointer = csr.indptr\n", 188 | " indices = csr.indices\n", 189 | " data = csr.data\n", 190 | " for i in range(min(csr.shape)):\n", 191 | " ind = indices[pointer[i]: pointer[i + 1]]\n", 192 | " j = ind.searchsorted(i)\n", 193 | " # matrix has only elements up until diagonal (in row i)\n", 194 | " if j == len(ind):\n", 195 | " continue\n", 196 | " j += pointer[i]\n", 197 | " # in case matrix has only elements after diagonal (in row i)\n", 198 | " if indices[j] == i:\n", 199 | " data[j] = value" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": { 206 | "collapsed": false 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "# Coarsen a graph given by rr,cc,vv. rr is assumed to be ordered\n", 211 | "def one_level_coarsening(rr,cc,vv,rid,weights):\n", 212 | " \n", 213 | " nnz = rr.shape[0]\n", 214 | " N = rr[nnz-1]+1\n", 215 | " #print(nnz,N)\n", 216 | " \n", 217 | " marked = np.zeros(N)\n", 218 | " rowstart = np.zeros(N)\n", 219 | " rowlength = np.zeros(N)\n", 220 | " cluster_id = np.zeros(N)\n", 221 | " \n", 222 | " oldval = rr[0]\n", 223 | " count = 0\n", 224 | " clustercount = 0\n", 225 | " \n", 226 | " for ii in range(nnz):\n", 227 | " rowlength[count] = rowlength[count] + 1\n", 228 | " if rr[ii] > oldval:\n", 229 | " oldval = rr[ii]\n", 230 | " rowstart[count+1] = ii\n", 231 | " count = count + 1\n", 232 | " \n", 233 | " for ii in range(N):\n", 234 | " tid = rid[ii]\n", 235 | " if marked[tid]==0.0:\n", 236 | " wmax = 0.0\n", 237 | " rs = rowstart[tid]\n", 238 | " marked[tid] = 1.0\n", 239 | " bestneighbor = -1\n", 240 | " for jj in range(int(rowlength[tid])):\n", 241 | " nid = cc[rs+jj]\n", 242 | " tval = (1.0-marked[nid]) * vv[rs+jj] * (1.0/weights[tid]+ 1.0/weights[nid])\n", 243 | " if tval > wmax:\n", 244 | " wmax = tval\n", 245 | " bestneighbor = nid\n", 246 | " \n", 247 | " cluster_id[tid] = clustercount;\n", 248 | " \n", 249 | " if bestneighbor > -1:\n", 250 | " cluster_id[bestneighbor] = clustercount\n", 251 | " marked[bestneighbor] = 1.0\n", 252 | " \n", 253 | " clustercount = clustercount + 1\n", 254 | " \n", 255 | " return cluster_id" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": { 262 | "collapsed": false 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "maxsize = 200\n", 267 | "N = W.shape[0]\n", 268 | "#rid = np.random.permutation(range(N))\n", 269 | "#print(N)\n", 270 | "#print(rid[0:10])\n", 271 | "\n", 272 | "graphs,parents = metis_coarsening(W.copy(),maxsize,rid)\n", 273 | "#print(graph)\n", 274 | "#print(parents)\n", 275 | "\n", 276 | "\n", 277 | "# CHECK RESULTS WITH MATLAB CODE\n", 278 | "graph0 = graphs[0]\n", 279 | "print(graph0.shape)\n", 280 | "print(graph0[0,:])\n", 281 | "\n", 282 | "graph1 = graphs[1]\n", 283 | "print(graph1.shape)\n", 284 | "print(graph1[0,:])\n", 285 | "\n", 286 | "graph2 = graphs[2]\n", 287 | "print(graph2.shape)\n", 288 | "print(graph2[0,:])\n", 289 | "\n", 290 | "parents0 = parents[0]\n", 291 | "print(parents0.shape)\n", 292 | "print(parents0[0:10])\n", 293 | "\n", 294 | "parents1 = parents[1]\n", 295 | "print(parents1.shape)\n", 296 | "print(parents1[0:10])" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": { 303 | "collapsed": false 304 | }, 305 | "outputs": [], 306 | "source": [ 307 | 
"import sys\n", 308 | "sys.path.append('..')\n", 309 | "from lib import coarsening\n", 310 | "\n", 311 | "graphs, parents = coarsening.metis(W, 2, rid)\n", 312 | "\n", 313 | "for i,A in enumerate(graphs):\n", 314 | " M, M = A.shape\n", 315 | " A = A.tocoo()\n", 316 | " A.setdiag(0)\n", 317 | " A = A.tocsr()\n", 318 | " A.eliminate_zeros()\n", 319 | " graphs[i] = A\n", 320 | " print('Layer {0}: M_{0} = {1} nodes, {2} edges'.format(i, M, A.nnz))\n", 321 | "\n", 322 | "# CHECK RESULTS WITH MATLAB CODE\n", 323 | "graph0 = graphs[0]\n", 324 | "print(graph0.shape)\n", 325 | "print(graph0[0,:])\n", 326 | "\n", 327 | "graph1 = graphs[1].tocsr()\n", 328 | "print(graph1.shape)\n", 329 | "print(graph1[0,:])\n", 330 | "\n", 331 | "graph2 = graphs[2].tocsr()\n", 332 | "print(graph2.shape)\n", 333 | "print(graph2[0,:])\n", 334 | "\n", 335 | "parents0 = parents[0]\n", 336 | "print(parents0.shape)\n", 337 | "print(parents0[0:10])\n", 338 | "\n", 339 | "parents1 = parents[1]\n", 340 | "print(parents1.shape)\n", 341 | "print(parents1[0:10])" 342 | ] 343 | }, 344 | { 345 | "cell_type": "raw", 346 | "metadata": {}, 347 | "source": [ 348 | "# Python results\n", 349 | "\n", 350 | "Size of original graph= 533\n", 351 | "level= 1\n", 352 | "Size of coarser graph= 279\n", 353 | "level= 2\n", 354 | "Size of coarser graph= 147\n", 355 | "(533, 533)\n", 356 | " (0, 18)\t0.810464124165\n", 357 | " (0, 59)\t0.349678536711\n", 358 | " (0, 60)\t0.591336229831\n", 359 | " (0, 83)\t0.388420442335\n", 360 | " (0, 105)\t0.255134781894\n", 361 | " (0, 210)\t0.656852096558\n", 362 | " (0, 226)\t0.900257809833\n", 363 | " (0, 299)\t0.065093756932\n", 364 | " (0, 340)\t0.810464124165\n", 365 | " (0, 407)\t0.431454676752\n", 366 | "(279, 279)\n", 367 | " (0, 44)\t1.63660876872\n", 368 | " (0, 58)\t2.42459126058\n", 369 | " (0, 71)\t0.186153138092\n", 370 | " (0, 115)\t1.99313658383\n", 371 | " (0, 167)\t1.24818832639\n", 372 | " (0, 168)\t2.95891026039\n", 373 | " (0, 179)\t0.388420442335\n", 374 | " (0, 240)\t0.431454676752\n", 375 | "(147, 147)\n", 376 | " (0, 21)\t5.1886032791\n", 377 | " (0, 85)\t1.08484314421\n", 378 | " (0, 87)\t0.353738954483\n", 379 | " (0, 127)\t0.186153138092\n", 380 | " (0, 135)\t1.88273900708\n", 381 | " (0, 141)\t0.255134781894\n", 382 | "(533,)\n", 383 | "[ 57. 148. 184. 237. 93. 93. 47. 28. 133. 71.]\n", 384 | "(279,)\n", 385 | "[ 127. 4. 88. 128. 50. 120. 54. 123. 146. 
26.]" 386 | ] 387 | }, 388 | { 389 | "cell_type": "raw", 390 | "metadata": { 391 | "collapsed": true 392 | }, 393 | "source": [ 394 | "# Matlab results\n", 395 | "\n", 396 | "ans =\n", 397 | "\n", 398 | " (1,19) 0.8105\n", 399 | " (1,60) 0.3497\n", 400 | " (1,61) 0.5913\n", 401 | " (1,84) 0.3884\n", 402 | " (1,106) 0.2551\n", 403 | " (1,211) 0.6569\n", 404 | " (1,227) 0.9003\n", 405 | " (1,300) 0.0651\n", 406 | " (1,341) 0.8105\n", 407 | " (1,408) 0.4315\n", 408 | "\n", 409 | "\n", 410 | "ans =\n", 411 | "\n", 412 | " (1,45) 1.6366\n", 413 | " (1,59) 2.4246\n", 414 | " (1,72) 0.1862\n", 415 | " (1,116) 1.9931\n", 416 | " (1,168) 1.2482\n", 417 | " (1,169) 2.9589\n", 418 | " (1,180) 0.3884\n", 419 | " (1,241) 0.4315\n", 420 | "\n", 421 | "\n", 422 | "ans =\n", 423 | "\n", 424 | " (1,22) 5.1886\n", 425 | " (1,86) 1.0848\n", 426 | " (1,88) 0.3537\n", 427 | " (1,128) 0.1862\n", 428 | " (1,136) 1.8827\n", 429 | " (1,142) 0.2551\n", 430 | "\n", 431 | "\n", 432 | "ans =\n", 433 | "\n", 434 | " 58\n", 435 | " 149\n", 436 | " 185\n", 437 | " 238\n", 438 | " 94\n", 439 | " 94\n", 440 | " 48\n", 441 | " 29\n", 442 | " 134\n", 443 | " 72\n", 444 | "\n", 445 | "\n", 446 | "ans =\n", 447 | "\n", 448 | " 128\n", 449 | " 5\n", 450 | " 89\n", 451 | " 129\n", 452 | " 51\n", 453 | " 121\n", 454 | " 55\n", 455 | " 124\n", 456 | " 147\n", 457 | " 27" 458 | ] 459 | } 460 | ], 461 | "metadata": { 462 | "kernelspec": { 463 | "display_name": "Python 3", 464 | "language": "python", 465 | "name": "python3" 466 | }, 467 | "language_info": { 468 | "codemirror_mode": { 469 | "name": "ipython", 470 | "version": 3 471 | }, 472 | "file_extension": ".py", 473 | "mimetype": "text/x-python", 474 | "name": "python", 475 | "nbconvert_exporter": "python", 476 | "pygments_lexer": "ipython3", 477 | "version": "3.5.2" 478 | } 479 | }, 480 | "nbformat": 4, 481 | "nbformat_minor": 0 482 | } 483 | -------------------------------------------------------------------------------- /trials/makefile: -------------------------------------------------------------------------------- 1 | NB = $(sort $(wildcard *.ipynb)) 2 | 3 | run: $(NB) 4 | 5 | $(NB): 6 | jupyter nbconvert --inplace --execute --ExecutePreprocessor.timeout=-1 $@ 7 | 8 | clean: 9 | jupyter nbconvert --inplace --ClearOutputPreprocessor.enabled=True $(NB) 10 | 11 | .PHONY: run $(NB) clean 12 | -------------------------------------------------------------------------------- /usage.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Introduction\n", 8 | "\n", 9 | "$\\newcommand{\\G}{\\mathcal{G}}$\n", 10 | "$\\newcommand{\\V}{\\mathcal{V}}$\n", 11 | "$\\newcommand{\\E}{\\mathcal{E}}$\n", 12 | "$\\newcommand{\\R}{\\mathbb{R}}$\n", 13 | "\n", 14 | "This notebook shows how to apply our graph ConvNet ([paper] & [code]), or any other, to your structured or unstructured data. 
For this example, we assume that we have $n$ samples $x_i \\in \\R^{d_x}$ arranged in a data matrix $$X = [x_1, ..., x_n]^T \\in \\R^{n \\times d_x}.$$ Each sample $x_i$ is associated with a vector $y_i \\in \\R^{d_y}$ for a regression task or a label $y_i \\in \\{0,\\ldots,C\\}$ for a classification task.\n", 15 | "\n", 16 | "[paper]: https://arxiv.org/abs/1606.09375\n", 17 | "[code]: https://github.com/mdeff/cnn_graph\n", 18 | "\n", 19 | "From there, we'll structure our data with a graph $\\G = (\\V, \\E, A)$ where $\\V$ is the set of $d_x = |\\V|$ vertices, $\\E$ is the set of edges and $A \\in \\R^{d_x \\times d_x}$ is the adjacency matrix. That matrix represents the weight of each edge, i.e. $A_{i,j}$ is the weight of the edge connecting $v_i \\in \\V$ to $v_j \\in \\V$. The weights of that feature graph thus represent pairwise relationships between features $i$ and $j$. We call that regime **signal classification / regression**, as the samples $x_i$ to be classified or regressed are graph signals.\n", 20 | "\n", 21 | "Other modelling possibilities include:\n", 22 | "1. Using a data graph, i.e. an adjacency matrix $A \\in \\R^{n \\times n}$ which represents pairwise relationships between samples $x_i \\in \\R^{d_x}$. The problem is here to predict a graph signal $y \\in \\R^{n \\times d_y}$ given a graph characterized by $A$ and some graph signals $X \\in \\R^{n \\times d_x}$. We call that regime **node classification / regression**, as we classify or regress nodes instead of signals.\n", 23 | "2. Another problem of interest is whole graph classification, with or without signals on top. We'll call that third regime **graph classification / regression**. The problem here is to classify or regress a whole graph $A_i \\in \\R^{n \\times n}$ (with or without an associated data matrix $X_i \\in \\R^{n \\times d_x}$) into $y_i \\in \\R^{d_y}$. In case we have no signal, we can use a constant vector $X_i = 1_n$ of size $n$." 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": { 30 | "collapsed": false 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "from lib import models, graph, coarsening, utils\n", 35 | "import numpy as np\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "%matplotlib inline" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "# 1 Data\n", 45 | "\n", 46 | "For the purpose of the demo, let's create a random data matrix $X \\in \\R^{n \\times d_x}$ and somehow infer a label $y_i = f(x_i)$." 
47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": { 53 | "collapsed": false 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "d = 100 # Dimensionality.\n", 58 | "n = 10000 # Number of samples.\n", 59 | "c = 5 # Number of feature communities.\n", 60 | "\n", 61 | "# Data matrix, structured in communities (feature-wise).\n", 62 | "X = np.random.normal(0, 1, (n, d)).astype(np.float32)\n", 63 | "X += np.linspace(0, 1, c).repeat(d // c)\n", 64 | "\n", 65 | "# Noisy non-linear target.\n", 66 | "w = np.random.normal(0, .02, d)\n", 67 | "t = X.dot(w) + np.random.normal(0, .001, n)\n", 68 | "t = np.tanh(t)\n", 69 | "plt.figure(figsize=(15, 5))\n", 70 | "plt.plot(t, '.')\n", 71 | "\n", 72 | "# Classification.\n", 73 | "y = np.ones(t.shape, dtype=np.uint8)\n", 74 | "y[t > t.mean() + 0.4 * t.std()] = 0\n", 75 | "y[t < t.mean() - 0.4 * t.std()] = 2\n", 76 | "print('Class imbalance: ', np.unique(y, return_counts=True)[1])" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "Then split this dataset into training, validation and testing sets." 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": { 90 | "collapsed": false 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "n_train = n // 2\n", 95 | "n_val = n // 10\n", 96 | "\n", 97 | "X_train = X[:n_train]\n", 98 | "X_val = X[n_train:n_train+n_val]\n", 99 | "X_test = X[n_train+n_val:]\n", 100 | "\n", 101 | "y_train = y[:n_train]\n", 102 | "y_val = y[n_train:n_train+n_val]\n", 103 | "y_test = y[n_train+n_val:]" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "# 2 Graph\n", 111 | "\n", 112 | "The second thing we need is a **graph between features**, i.e. an adjacency matrix $A \\in \\mathbb{R}^{d_x \\times d_x}$.\n", 113 | "Structuring data with graphs is very flexible: it can accommodate both structured and unstructured data.\n", 114 | "1. **Structured data**.\n", 115 | "    1. The data is structured by a Euclidean domain, e.g. $x_i$ represents an image, a sound or a video. We can use a classical ConvNet with 1D, 2D or 3D convolutions or a graph ConvNet with a line or grid graph (however losing the orientation).\n", 116 | "    2. The data is structured by a graph, e.g. the data lies on a transportation, energy, brain or social network.\n", 117 | "2. **Unstructured data**. We could use a fully connected network, but the learning and computational complexities will be large. An alternative is to construct a sparse similarity graph between features (or between samples) and use a graph ConvNet, effectively structuring the data and drastically reducing the number of parameters through weight sharing. As for classical ConvNets, the number of parameters is independent of the input size.\n", 118 | "\n", 119 | "There are many ways, supervised or unsupervised, to construct a graph given some data. The better the graph, the better the performance! For this example, we'll define the adjacency matrix as a simple similarity measure between features. Below are the choices one has to make when constructing such a graph.\n", 120 | "1. The distance function. We'll use the Euclidean distance $d_{ij} = \\|x_i - x_j\\|_2$.\n", 121 | "2. The kernel. We'll use the Gaussian kernel $a_{ij} = \\exp(-d_{ij}^2 / \\sigma^2)$.\n", 122 | "3. The type of graph. We'll use a $k$ nearest neighbors (kNN) graph."
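The cell below builds such a graph with the helpers in `lib/graph.py`. For reference, here is a rough NumPy/SciPy sketch of the same three choices; the helper name `knn_gaussian_graph`, the choice of $\sigma$ as the mean distance to the $k$-th neighbor, and the averaging symmetrization are illustrative assumptions, not the library's implementation.

```python
import numpy as np
import scipy.sparse
from scipy.spatial.distance import pdist, squareform

def knn_gaussian_graph(Z, k=10):
    """Sketch: weighted kNN graph between the rows of Z (here, the features)."""
    # 1. Distance function: pairwise Euclidean distances.
    d = squareform(pdist(Z, metric='euclidean'))
    # 3. Type of graph: keep the k nearest neighbors of each vertex (column 0 is the vertex itself).
    idx = np.argsort(d)[:, 1:k+1]
    dist = np.sort(d)[:, 1:k+1]
    # 2. Kernel: Gaussian weights; sigma is set to the mean distance to the k-th neighbor (an assumption).
    sigma2 = np.mean(dist[:, -1])**2
    w = np.exp(-dist**2 / sigma2)
    # Assemble a sparse adjacency matrix and symmetrize it (averaging the two directions).
    m = Z.shape[0]
    rows = np.repeat(np.arange(m), k)
    A = scipy.sparse.coo_matrix((w.ravel(), (rows, idx.ravel())), shape=(m, m))
    return ((A + A.T) / 2).tocsr()

# Feature graph: features are the rows of X_train.T, as in the cell below.
# A_sketch = knn_gaussian_graph(X_train.T, k=10)
```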
123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": { 129 | "collapsed": false 130 | }, 131 | "outputs": [], 132 | "source": [ 133 | "dist, idx = graph.distance_scipy_spatial(X_train.T, k=10, metric='euclidean')\n", 134 | "A = graph.adjacency(dist, idx).astype(np.float32)\n", 135 | "\n", 136 | "assert A.shape == (d, d)\n", 137 | "print('d = |V| = {}, k|V| < |E| = {}'.format(d, A.nnz))\n", 138 | "plt.spy(A, markersize=2, color='black');" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "To be able to pool graph signals, we first need to coarsen the graph, i.e. to find which vertices to group together. At the end we'll have multiple graphs, like a pyramid, each at one level of resolution. The finest graph is where the input data lies, the coarsest graph is where the data at the output of the graph convolutional layers lies. That data, of reduced spatial dimensionality, can then be fed to a fully connected layer.\n", 146 | "\n", 147 | "The parameter here is the number of times to coarsen the graph. Each coarsening approximately reduces the size of the graph by a factor of two. Thus if you want a pooling of size 4 in the first layer followed by a pooling of size 2 in the second, you'll need to coarsen $\\log_2(4 \\times 2) = 3$ times.\n", 148 | "\n", 149 | "After coarsening we rearrange the vertices (and add fake vertices) such that pooling a graph signal is analogous to pooling a 1D signal. See the [paper] for details.\n", 150 | "\n", 151 | "[paper]: https://arxiv.org/abs/1606.09375" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": { 158 | "collapsed": false 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "graphs, perm = coarsening.coarsen(A, levels=3, self_connections=False)\n", 163 | "\n", 164 | "X_train = coarsening.perm_data(X_train, perm)\n", 165 | "X_val = coarsening.perm_data(X_val, perm)\n", 166 | "X_test = coarsening.perm_data(X_test, perm)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "We finally need to compute the graph Laplacian $L$ for each of our graphs (the original and the coarsened versions), defined by their adjacency matrices $A$. The sole parameter here is the type of Laplacian, e.g. the combinatorial Laplacian, the normalized Laplacian or the random walk Laplacian." 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": { 180 | "collapsed": false 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | "L = [graph.laplacian(A, normalized=True) for A in graphs]\n", 185 | "graph.plot_spectrum(L)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "# 3 Graph ConvNet\n", 193 | "\n", 194 | "Here we apply the graph convolutional neural network to signals lying on graphs. After designing the architecture and setting the hyper-parameters, the model takes as inputs the data matrix $X$, the target $y$ and a list of graph Laplacians $L$, one per coarsening level.\n", 195 | "\n", 196 | "The data, architecture and hyper-parameters are absolutely *not engineered to showcase performance*. Their sole purpose is to illustrate usage and functionality."
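For reference, the normalized Laplacian computed above is presumably the symmetric normalized Laplacian $L = I - D^{-1/2} A D^{-1/2}$, where $D$ is the diagonal degree matrix. Below is a minimal SciPy sketch of that formula; the library's `graph.laplacian` may differ in details such as the handling of isolated vertices.

```python
import numpy as np
import scipy.sparse

def normalized_laplacian(A):
    """Sketch: L = I - D^{-1/2} A D^{-1/2} for a sparse adjacency matrix A."""
    degrees = np.asarray(A.sum(axis=1)).squeeze()                          # vertex degrees
    d_inv_sqrt = np.where(degrees > 0, 1.0 / np.sqrt(np.maximum(degrees, 1e-12)), 0)
    D_inv_sqrt = scipy.sparse.diags(d_inv_sqrt)                            # D^{-1/2}, isolated vertices get 0
    I = scipy.sparse.identity(A.shape[0], format='csr')
    return I - D_inv_sqrt.dot(A).dot(D_inv_sqrt)

# scipy.sparse.csgraph.laplacian(A, normed=True) computes a similar matrix.
# L_sketch = [normalized_laplacian(A) for A in graphs]
```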
197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": { 203 | "collapsed": false 204 | }, 205 | "outputs": [], 206 | "source": [ 207 | "params = dict()\n", 208 | "params['dir_name'] = 'demo'\n", 209 | "params['num_epochs'] = 40\n", 210 | "params['batch_size'] = 100\n", 211 | "params['eval_frequency'] = 200\n", 212 | "\n", 213 | "# Building blocks.\n", 214 | "params['filter'] = 'chebyshev5'\n", 215 | "params['brelu'] = 'b1relu'\n", 216 | "params['pool'] = 'apool1'\n", 217 | "\n", 218 | "# Number of classes.\n", 219 | "C = y.max() + 1\n", 220 | "assert C == np.unique(y).size\n", 221 | "\n", 222 | "# Architecture.\n", 223 | "params['F'] = [32, 64] # Number of graph convolutional filters.\n", 224 | "params['K'] = [20, 20] # Polynomial orders.\n", 225 | "params['p'] = [4, 2] # Pooling sizes.\n", 226 | "params['M'] = [512, C] # Output dimensionality of fully connected layers.\n", 227 | "\n", 228 | "# Optimization.\n", 229 | "params['regularization'] = 5e-4\n", 230 | "params['dropout'] = 1\n", 231 | "params['learning_rate'] = 1e-3\n", 232 | "params['decay_rate'] = 0.95\n", 233 | "params['momentum'] = 0.9\n", 234 | "params['decay_steps'] = n_train / params['batch_size']" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": { 241 | "collapsed": false 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "model = models.cgcnn(L, **params)\n", 246 | "accuracy, loss, t_step = model.fit(X_train, y_train, X_val, y_val)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "# 4 Evaluation\n", 254 | "\n", 255 | "We often want to monitor:\n", 256 | "1. The convergence, i.e. the training loss and the classification accuracy on the validation set.\n", 257 | "2. The performance, i.e. the classification accuracy on the testing set (to be compared with the training set accuracy to spot overfitting).\n", 258 | "\n", 259 | "The `model_perf` class in [utils.py](lib/utils.py) can be used to compactly evaluate multiple models."
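One simple way to compare the training and testing accuracies mentioned above is sketched below; it assumes that `model.evaluate` accepts any data/label split, as it is used for the test set in the last cell.

```python
# Sketch: evaluate the fitted model on both splits to spot overfitting.
print('Training set:')
print(model.evaluate(X_train, y_train)[0])
print('Testing set:')
print(model.evaluate(X_test, y_test)[0])
```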
260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": { 266 | "collapsed": false 267 | }, 268 | "outputs": [], 269 | "source": [ 270 | "fig, ax1 = plt.subplots(figsize=(15, 5))\n", 271 | "ax1.plot(accuracy, 'b.-')\n", 272 | "ax1.set_ylabel('validation accuracy', color='b')\n", 273 | "ax2 = ax1.twinx()\n", 274 | "ax2.plot(loss, 'g.-')\n", 275 | "ax2.set_ylabel('training loss', color='g')\n", 276 | "plt.show()" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "collapsed": false 284 | }, 285 | "outputs": [], 286 | "source": [ 287 | "print('Time per step: {:.2f} ms'.format(t_step*1000))" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": { 294 | "collapsed": false 295 | }, 296 | "outputs": [], 297 | "source": [ 298 | "res = model.evaluate(X_test, y_test)\n", 299 | "print(res[0])" 300 | ] 301 | } 302 | ], 303 | "metadata": { 304 | "kernelspec": { 305 | "display_name": "Python 3", 306 | "language": "python", 307 | "name": "python3" 308 | }, 309 | "language_info": { 310 | "codemirror_mode": { 311 | "name": "ipython", 312 | "version": 3 313 | }, 314 | "file_extension": ".py", 315 | "mimetype": "text/x-python", 316 | "name": "python", 317 | "nbconvert_exporter": "python", 318 | "pygments_lexer": "ipython3", 319 | "version": "3.4.3" 320 | } 321 | }, 322 | "nbformat": 4, 323 | "nbformat_minor": 0 324 | } 325 | --------------------------------------------------------------------------------