├── .gitignore ├── LICENSE ├── README.md ├── demo ├── EigenPro_cifar-10.ipynb ├── datasets.py └── main.py ├── eigenpro2 ├── __init__.py ├── kernels.py ├── models.py └── utils │ ├── __init__.py │ └── eigh.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | */.ipynb_checkpoints/ 4 | */__pycache__/ 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EigenPro2-pytorch 2 | 3 | EigenPro (short for Eigenspace Projections) is a fast iterative solver for Kernel Regression. 4 | **Paper:** [Kernel machines that adapt to GPUs for effective large batch training](https://arxiv.org/abs/1806.06144), SysML (2019). 5 | **Authors:** Siyuan Ma and Mikhail Belkin. (Bibtex below) 6 | 7 | It has a $O(n)$ memory and $O(n^2)$ time complexity with respect to number of samples. \ 8 | The algorithm is based on preconditioned SGD and has autotuned hyperparameters to maximize GPU utilization. 9 | 10 | # Installation 11 | ``` 12 | pip install git+https://github.com/EigenPro/EigenPro-pytorch.git 13 | ``` 14 | Requires a PyTorch installation 15 | 16 | ## Stable behavior 17 | Currently this code has been tested with n=1,000,000 samples.\ 18 | with Python 3.9 and `PyTorch >= 1.13` 19 | 20 | 21 | # Test installation with Laplacian kernel 22 | ```python 23 | import torch 24 | from eigenpro2.kernels import laplacian 25 | from eigenpro2.models import KernelModel 26 | 27 | n = 1000 # number of samples 28 | d = 100 # dimensions 29 | c = 3 # number of targets 30 | 31 | w_star=torch.randn(d, c) 32 | x_train, x_test = torch.randn(n, d), torch.randn(n, d) 33 | y_train, y_test = x_train @ w_star, x_test @ w_star 34 | 35 | if torch.cuda.is_available(): 36 | DEVICE = torch.device("cuda") 37 | DEV_MEM = torch.cuda.get_device_properties(DEVICE).total_memory//1024**3 - 1 # GPU memory in GB, keeping aside 1GB for safety 38 | else: 39 | DEVICE = torch.device("cpu") 40 | DEV_MEM = 8 # RAM available for computing 41 | 42 | kernel_fn = lambda x, y: laplacian(x, y, bandwidth=1.) 43 | model = KernelModel(kernel_fn, x_train, c, device=DEVICE) 44 | result = model.fit(x_train, y_train, x_test, y_test, epochs=30, print_every=5, mem_gb=DEV_MEM) 45 | print('Laplacian test complete!') 46 | ``` 47 | 48 | ### Bibtex 49 | ```latex 50 | @article{ma2019kernel, 51 | title={Kernel machines that adapt to GPUs for effective large batch training}, 52 | author={Ma, Siyuan and Belkin, Mikhail}, 53 | journal={Proceedings of Machine Learning and Systems}, 54 | volume={1}, 55 | pages={360--373}, 56 | year={2019} 57 | } 58 | ``` 59 | -------------------------------------------------------------------------------- /demo/datasets.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | from torchvision.datasets import MNIST, EMNIST, FashionMNIST, KMNIST, CIFAR10 3 | from torch.nn.functional import one_hot 4 | 5 | def unit_range_normalize(samples): 6 | samples -= samples.min(dim=0, keepdim=True).values 7 | return samples/samples.max(dim=1, keepdim=True).values 8 | 9 | def load_cifar10_data(**kwargs): 10 | train_data = CIFAR10(os.environ['DATA_DIR'], train=True) 11 | test_data = CIFAR10(os.environ['DATA_DIR'], train=False) 12 | n_class = len(train_data.classes) 13 | return ( 14 | n_class, 15 | (torch.from_numpy(train_data.data), torch.LongTensor(train_data.targets)), 16 | (torch.from_numpy(test_data.data), torch.LongTensor(test_data.targets)), 17 | ) 18 | 19 | 20 | def load_mnist_data(**kwargs): 21 | train_data = MNIST(os.environ['DATA_DIR'], train=True) 22 | test_data = MNIST(os.environ['DATA_DIR'], train=False) 23 | n_class = len(train_data.classes) 24 | return ( 25 | n_class, 26 | (train_data.data, train_data.targets), 27 | (test_data.data, test_data.targets), 28 | ) 29 | 30 | def load_emnist_data(**kwargs): 31 | train_data = EMNIST(os.environ['DATA_DIR'], train=True, **kwargs) 32 | test_data = EMNIST(os.environ['DATA_DIR'], train=False, **kwargs) 33 | n_class = len(train_data.classes) 34 | return ( 35 | n_class, 36 | (train_data.data, train_data.targets), 37 | (test_data.data, test_data.targets), 38 | ) 39 | 40 | def load_fmnist_data(**kwargs): 41 | train_data = FashionMNIST(os.environ['DATA_DIR'], train=True) 42 | test_data = FashionMNIST(os.environ['DATA_DIR'], train=False) 43 | n_class = len(train_data.classes) 44 | return ( 45 | n_class, 46 | (train_data.data, train_data.targets), 47 | (test_data.data, test_data.targets), 48 | ) 49 | 50 | def load_kmnist_data(**kwargs): 51 | train_data = KMNIST(os.environ['DATA_DIR'], train=True) 52 | test_data = KMNIST(os.environ['DATA_DIR'], train=False) 53 | n_class = len(train_data.classes) 54 | return ( 55 | n_class, 56 | (train_data.data, train_data.targets), 57 | (test_data.data, test_data.targets), 58 | ) 59 | 60 | 61 | def load(dataset='mnist', DEVICE=torch.device('cpu'), **kwargs): 62 | n_class, (x_train, y_train), (x_test, y_test) = eval(f'load_{dataset}_data')(**kwargs) 63 | 64 | x_train = x_train.reshape(x_train.shape[0], -1).to(DEVICE).float() 65 | x_test = x_test.reshape(x_test.shape[0], -1).to(DEVICE).float() 66 | 67 | x_train = unit_range_normalize(x_train) 68 | x_test = unit_range_normalize(x_test) 69 | y_train = one_hot(y_train, n_class).to(DEVICE).float() 70 | y_test = one_hot(y_test, n_class).to(DEVICE).float() 71 | print(f"Loaded {dataset.upper()} dataset to {DEVICE}") 72 | print(f"{n_class} classes") 73 | print(x_train.shape[0], 'train samples') 74 | print(x_test.shape[0], 'test samples') 75 | print('-'*20) 76 | 77 | return n_class, (x_train, y_train), (x_test, y_test) 78 | 79 | if __name__ == "__main__": 80 | n_class, (a,b), (c,d) = load() 81 | -------------------------------------------------------------------------------- /demo/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datasets 3 | from eigenpro2.models import KernelModel 4 | from eigenpro2.kernels import laplacian, ntk_relu_unit_sphere 5 | 6 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | n_class, (x_train, y_train), (x_test, y_test) = datasets.load('cifar10', DEVICE, split='digits') 9 | 10 | x_train=x_train/x_train.norm(dim=-1,keepdim=True) 11 | x_test=x_test/x_test.norm(dim=-1,keepdim=True) 12 | 13 | kernel_fn = lambda x, y: laplacian(x, y, bandwidth=1.) 14 | #kernel_fn = lambda x, z: ntk_relu_unit_sphere(x, z, depth=3) 15 | 16 | model = KernelModel(kernel_fn, x_train, n_class, device=DEVICE) 17 | 18 | results = model.fit(x_train, y_train, x_test, y_test, epochs=20, print_every=2, mem_gb=20) 19 | -------------------------------------------------------------------------------- /eigenpro2/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import KernelModel 2 | 3 | __version__ = 0.1 4 | -------------------------------------------------------------------------------- /eigenpro2/kernels.py: -------------------------------------------------------------------------------- 1 | '''Implementation of kernel functions.''' 2 | 3 | import torch 4 | 5 | eps = 1e-12 6 | 7 | def euclidean(samples, centers, squared=True): 8 | '''Calculate the pointwise distance. 9 | 10 | Args: 11 | samples: of shape (n_sample, n_feature). 12 | centers: of shape (n_center, n_feature). 13 | squared: boolean. 14 | 15 | Returns: 16 | pointwise distances (n_sample, n_center). 17 | ''' 18 | samples_norm = torch.sum(samples**2, dim=1, keepdim=True) 19 | if samples is centers: 20 | centers_norm = samples_norm 21 | else: 22 | centers_norm = torch.sum(centers**2, dim=1, keepdim=True) 23 | centers_norm = torch.reshape(centers_norm, (1, -1)) 24 | 25 | distances = samples.mm(torch.t(centers)) 26 | distances.mul_(-2) 27 | distances.add_(samples_norm) 28 | distances.add_(centers_norm) 29 | if not squared: 30 | distances.clamp_(min=0) 31 | distances.sqrt_() 32 | 33 | return distances 34 | 35 | 36 | def gaussian(samples, centers, bandwidth): 37 | '''Gaussian kernel. 38 | 39 | Args: 40 | samples: of shape (n_sample, n_feature). 41 | centers: of shape (n_center, n_feature). 42 | bandwidth: kernel bandwidth. 43 | 44 | Returns: 45 | kernel matrix of shape (n_sample, n_center). 46 | ''' 47 | assert bandwidth > 0 48 | kernel_mat = euclidean(samples, centers) 49 | kernel_mat.clamp_(min=0) 50 | gamma = 1. / (2 * bandwidth ** 2) 51 | kernel_mat.mul_(-gamma) 52 | kernel_mat.exp_() 53 | return kernel_mat 54 | 55 | 56 | def laplacian(samples, centers, bandwidth): 57 | '''Laplacian kernel. 58 | 59 | Args: 60 | samples: of shape (n_sample, n_feature). 61 | centers: of shape (n_center, n_feature). 62 | bandwidth: kernel bandwidth. 63 | 64 | Returns: 65 | kernel matrix of shape (n_sample, n_center). 66 | ''' 67 | assert bandwidth > 0 68 | kernel_mat = euclidean(samples, centers, squared=False) 69 | kernel_mat.clamp_(min=0) 70 | gamma = 1. / bandwidth 71 | kernel_mat.mul_(-gamma) 72 | kernel_mat.exp_() 73 | return kernel_mat 74 | 75 | 76 | def dispersal(samples, centers, bandwidth, gamma): 77 | '''Dispersal kernel. 78 | 79 | Args: 80 | samples: of shape (n_sample, n_feature). 81 | centers: of shape (n_center, n_feature). 82 | bandwidth: kernel bandwidth. 83 | gamma: dispersal factor. 84 | 85 | Returns: 86 | kernel matrix of shape (n_sample, n_center). 87 | ''' 88 | assert bandwidth > 0 89 | kernel_mat = euclidean(samples, centers) 90 | kernel_mat.pow_(gamma / 2.) 91 | kernel_mat.mul_(-1. / bandwidth) 92 | kernel_mat.exp_() 93 | return kernel_mat 94 | 95 | 96 | def ntk_relu(X, Z, depth=1, bias=0.): 97 | """ 98 | Returns the evaluation of nngp and ntk kernels 99 | for fully connected neural networks 100 | with ReLU nonlinearity. 101 | 102 | depth (int): number of layers of the network 103 | bias (float): (default=0.) 104 | """ 105 | from torch import acos, pi 106 | kappa_0 = lambda u: (1-acos(u)/pi) 107 | kappa_1 = lambda u: u*kappa_0(u) + (1-u.pow(2)).sqrt()/pi 108 | Z = Z if Z is not None else X 109 | norm_x = X.norm(dim=-1)[:, None].clip(min=eps) 110 | norm_z = Z.norm(dim=-1)[None, :].clip(min=eps) 111 | S = X @ Z.T 112 | N = S + bias**2 113 | for k in range(1, depth): 114 | in_ = (S/norm_x/norm_z).clip(min=-1+eps,max=1-eps) 115 | S = norm_x*norm_z*kappa_1(in_) 116 | N = N * kappa_0(in_) + S + bias**2 117 | return N 118 | 119 | def ntk_relu_unit_sphere(X, Z, depth=1, bias=0.): 120 | """ 121 | Returns the evaluation of nngp and ntk kernels 122 | for fully connected neural networks 123 | with ReLU nonlinearity. 124 | Assumes inputs are normalized to unit norm. 125 | 126 | depth (int): number of layers of the network 127 | bias (float): (default=0.) 128 | """ 129 | from torch import acos, pi 130 | kappa_0 = lambda u: (1-acos(u)/pi) 131 | kappa_1 = lambda u: u*kappa_0(u) + (1-u.pow(2)).sqrt()/pi 132 | Z = Z if Z is not None else X 133 | S = X @ Z.T 134 | N = S + bias**2 135 | for k in range(1, depth): 136 | in_ = (S).clip(min=-1+eps,max=1-eps) 137 | S = kappa_1(in_) 138 | N = N * kappa_0(in_) + S + bias**2 139 | return N 140 | 141 | if __name__ == "__main__": 142 | import torch 143 | from torch.nn.functional import normalize 144 | n, m, d = 1000, 800, 10 145 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 146 | X = torch.randn(n, d, device=DEVICE) 147 | X_ = normalize(X, dim=-1) 148 | Z = torch.randn(m, d, device=DEVICE) 149 | Z_ = normalize(Z, dim=-1) 150 | KXZ_ntk = ntk_relu(X, Z, 64, bias=1.) 151 | KXZ_ntk_ = ntk_relu_normalized(X_, Z_, 64, bias=1.) 152 | print( 153 | KXZ_ntk.diag().max().item(), 154 | KXZ_ntk_.diag().max().item() 155 | ) 156 | -------------------------------------------------------------------------------- /eigenpro2/models.py: -------------------------------------------------------------------------------- 1 | '''Construct kernel model with EigenPro optimizer.''' 2 | import collections 3 | import time 4 | import torch 5 | 6 | import torch.nn as nn 7 | 8 | from .utils.eigh import nystrom_kernel_eigh 9 | 10 | four_spaces = ' '*4 11 | 12 | def asm_eigenpro_fn(samples, map_fn, top_q, bs_gpu, alpha, min_q=5, seed=1): 13 | """Prepare gradient map for EigenPro and calculate 14 | scale factor for learning ratesuch that the update rule, 15 | p <- p - eta * g 16 | becomes, 17 | p <- p - scale * eta * (g - eigenpro_fn(g)) 18 | 19 | Arguments: 20 | samples: matrix of shape (n_sample, n_feature). 21 | map_fn: kernel k(samples, centers) where centers are specified. 22 | top_q: top-q eigensystem for constructing eigenpro iteration/kernel. 23 | bs_gpu: maxinum batch size corresponding to GPU memory. 24 | alpha: exponential factor (<= 1) for eigenvalue rescaling due to approximation. 25 | min_q: minimum value of q when q (if None) is calculated automatically. 26 | seed: seed for random number generation. 27 | 28 | Returns: 29 | eigenpro_fn: tensor function. 30 | scale: factor that rescales learning rate. 31 | top_eigval: largest eigenvalue. 32 | beta: largest k(x, x) for the EigenPro kernel. 33 | """ 34 | 35 | start = time.time() 36 | n_sample, _ = samples.shape 37 | 38 | if top_q is None: 39 | svd_q = min(n_sample//3 - 1, 1000) 40 | else: 41 | svd_q = min(top_q, n_sample//3 - 1) 42 | 43 | eigvals, eigvecs, beta = nystrom_kernel_eigh(samples, map_fn, svd_q) 44 | 45 | # Choose k such that the batch size is bounded by 46 | # the subsample size and the memory size. 47 | # Keep the original k if it is pre-specified. 48 | if top_q is None: 49 | max_bs = min(max(n_sample / 5, bs_gpu), n_sample) 50 | top_q = torch.sum((eigvals).pow(-alpha) < max_bs) - 1 51 | top_q = max(top_q, min_q) 52 | 53 | eigvals, tail_eigval = eigvals[:top_q - 1], eigvals[top_q - 1] 54 | eigvecs = eigvecs[:, :top_q - 1] 55 | 56 | device = samples.device 57 | eigvals_t = eigvals.to(device) 58 | eigvecs_t = eigvecs.to(device) 59 | tail_eigval_t = tail_eigval.to(device) 60 | 61 | scale = (eigvals[0]/tail_eigval).pow(alpha) 62 | diag_t = (1 - torch.pow(tail_eigval_t / eigvals_t, alpha)) / eigvals_t 63 | 64 | def eigenpro_fn(grad, kmat): 65 | '''Function to apply EigenPro preconditioner.''' 66 | return torch.mm(eigvecs_t * diag_t, 67 | torch.t(torch.mm(torch.mm(torch.t(grad), 68 | kmat), 69 | eigvecs_t))) 70 | 71 | print("SVD time: %.2fs, top_q: %d, top_eigval: %.2f, new top_eigval: %.2e" % 72 | (time.time() - start, top_q, eigvals[0], eigvals[0] / scale)) 73 | 74 | 75 | return eigenpro_fn, scale, eigvals[0], beta 76 | 77 | 78 | class KernelModel(nn.Module): 79 | '''Fast Kernel Regression using EigenPro iteration.''' 80 | def __init__(self, kernel_fn, centers, y_dim, device="cuda"): 81 | super(KernelModel, self).__init__() 82 | self.kernel_fn = kernel_fn 83 | self.n_centers, self.x_dim = centers.shape 84 | self.device = device 85 | self.pinned_list = [] 86 | 87 | self.centers = self.tensor(centers, release=True) 88 | self.weight = self.tensor(torch.zeros( 89 | self.n_centers, y_dim), release=True) 90 | 91 | def __del__(self): 92 | for pinned in self.pinned_list: 93 | _ = pinned.to("cpu") 94 | if torch.cuda.is_available(): 95 | torch.cuda.empty_cache() 96 | 97 | def tensor(self, data, dtype=None, release=False): 98 | tensor = torch.as_tensor(data, dtype=dtype, device=self.device) 99 | if release: 100 | self.pinned_list.append(tensor) 101 | return tensor 102 | 103 | def kernel_matrix(self, samples): 104 | return self.kernel_fn(samples, self.centers) 105 | 106 | def forward(self, samples, weight=None): 107 | if weight is None: 108 | weight = self.weight 109 | kmat = self.kernel_matrix(samples) 110 | pred = kmat.mm(weight) 111 | return pred 112 | 113 | def predict(self, samples): 114 | self.forward(samples) 115 | 116 | def primal_gradient(self, samples, labels, weight): 117 | pred = self.forward(samples, weight) 118 | grad = pred - labels.type(pred.type()) 119 | return grad 120 | 121 | @staticmethod 122 | def _compute_opt_params(bs, bs_gpu, beta, top_eigval): 123 | if bs is None: 124 | bs = min(int(beta / top_eigval + 1), bs_gpu) 125 | 126 | if bs < beta / top_eigval + 1: 127 | eta = bs / beta 128 | else: 129 | eta = 0.99 * 2 * bs / (beta + (bs - 1) * top_eigval) 130 | return bs, eta 131 | 132 | def eigenpro_iterate(self, samples, x_batch, y_batch, eigenpro_fn, 133 | eta, sample_ids, batch_ids): 134 | # update random coordiate block (for mini-batch) 135 | grad = self.primal_gradient(x_batch, y_batch.type(x_batch.type()), self.weight) 136 | self.weight.index_add_(0, batch_ids, -eta * grad) 137 | 138 | # update fixed coordinate block (for EigenPro) 139 | kmat = self.kernel_fn(x_batch, samples) 140 | correction = eigenpro_fn(grad, kmat) 141 | self.weight.index_add_(0, sample_ids, eta * correction) 142 | return 143 | 144 | def evaluate(self, x_eval, y_eval, bs, 145 | metrics=('mse', 'multiclass-acc')): 146 | if (x_eval is None) or (y_eval is None): 147 | return {a: None for a in metrics} 148 | p_list = [] 149 | n_sample, _ = x_eval.shape 150 | for batch_ids in torch.split(torch.arange(n_sample), bs): 151 | x_batch = self.tensor(x_eval[batch_ids]) 152 | p_batch = self.forward(x_batch) 153 | p_list.append(p_batch) 154 | p_eval = torch.vstack(p_list) 155 | 156 | eval_metrics = collections.OrderedDict() 157 | if 'mse' in metrics: 158 | eval_metrics['mse'] = (p_eval - self.tensor(y_eval.type(x_eval.type()))).pow(2).mean() 159 | if 'multiclass-acc' in metrics: 160 | y_class = self.tensor(y_eval.type(x_eval.type())).argmax(-1) 161 | p_class = p_eval.argmax(-1) 162 | eval_metrics['multiclass-acc'] = (1.*(y_class == p_class)).mean() 163 | 164 | return eval_metrics 165 | 166 | def score(self, samples, targets, metric='mse'): 167 | preds = self.predict(samples) 168 | if metric=='mse': 169 | return (preds - targets).pow(2).mean() 170 | elif metric=="accuracy": 171 | return 1.*(preds.argmax(-1)==targets.argmax(-1)).mean()*100 172 | 173 | def fit(self, x_train, y_train, x_val=None, y_val=None, epochs=1, mem_gb=1, 174 | print_every=1, 175 | n_subsamples=None, top_q=None, bs=None, eta=None, 176 | n_train_eval=5000, run_epoch_eval=True, scale=1, seed=1, **kwargs): 177 | 178 | n_samples, n_labels = y_train.shape 179 | if n_subsamples is None: 180 | if n_samples < 100000: 181 | n_subsamples = min(n_samples, 2000) 182 | else: 183 | n_subsamples = 12000 184 | 185 | mem_bytes = (mem_gb - 1) * 1024**3 # preserve 1GB 186 | bsizes = torch.arange(n_subsamples) 187 | mem_usages = ((self.x_dim + 3 * n_labels + bsizes + 1) 188 | * self.n_centers + n_subsamples * 1000) * 4 189 | bs_gpu = torch.sum(mem_usages < mem_bytes) # device-dependent batch size 190 | 191 | # Calculate batch size / learning rate for improved EigenPro iteration. 192 | sample_ids = torch.randperm(n_samples)[:n_subsamples] 193 | sample_ids = self.tensor(sample_ids, dtype=torch.int64) 194 | samples = self.centers[sample_ids] 195 | eigenpro_f, gap, top_eigval, beta = asm_eigenpro_fn( 196 | samples, self.kernel_fn, top_q, bs_gpu, alpha=.95, seed=seed) 197 | new_top_eigval = top_eigval / gap 198 | 199 | if eta is None: 200 | bs, eta = self._compute_opt_params( 201 | bs, bs_gpu, beta, new_top_eigval) 202 | else: 203 | bs, _ = self._compute_opt_params(bs, bs_gpu, beta, new_top_eigval) 204 | 205 | print("n_subsamples=%d, bs_gpu=%d, eta=%.2f, bs=%d, top_eigval=%.2e, beta=%.2f" % 206 | (n_subsamples, bs_gpu, eta, bs, top_eigval, beta)) 207 | print('-'*20) 208 | eta = self.tensor(scale * eta / bs, dtype=torch.float) 209 | 210 | # Subsample training data for fast estimation of training loss. 211 | ids = torch.randperm(n_samples, device=x_train.device)[:min(n_samples, n_train_eval)] 212 | x_train_eval, y_train_eval = x_train[ids], y_train[ids] 213 | 214 | results = dict() 215 | initial_epoch = 0 216 | train_sec = 0 # training time in seconds 217 | 218 | 219 | for epoch in range(epochs): 220 | start_time = time.time() 221 | epoch_ids = torch.randperm(n_samples, device=x_train.device) 222 | for batch_num, batch_ids in enumerate(torch.split(epoch_ids, bs)): 223 | x_batch = self.tensor(x_train[batch_ids]) 224 | y_batch = self.tensor(y_train[batch_ids]) 225 | self.eigenpro_iterate(samples, x_batch, y_batch, eigenpro_f, 226 | eta, sample_ids, self.tensor(batch_ids)) 227 | del x_batch, y_batch, batch_ids 228 | 229 | if run_epoch_eval and ((epoch%print_every)==0): 230 | tr_score = self.evaluate(x_train_eval, y_train_eval, bs) 231 | tv_score = self.evaluate(x_val, y_val, bs) 232 | message = f"epoch: {epoch:3d}{four_spaces}" 233 | message += f"time: {time.time() - start_time:04.1f}s{four_spaces}" 234 | message += f"train accuracy: {tr_score['multiclass-acc']*100:.2f}%{four_spaces}" 235 | message += f"val accuracy: {tv_score['multiclass-acc']*100:.2f}%{four_spaces}" if tv_score['multiclass-acc'] is not None else " " 236 | message += f"train mse: {tr_score['mse']:.2e}{four_spaces}" 237 | message += f"val mse: {tv_score['mse']:.2e}" if tv_score['mse'] is not None else " " 238 | print(message) 239 | results[epoch] = (tr_score, tv_score, train_sec) 240 | 241 | return results 242 | -------------------------------------------------------------------------------- /eigenpro2/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EigenPro/EigenPro-pytorch/bec49271ca51889c14ae20e2eb87a90d4291fdca/eigenpro2/utils/__init__.py -------------------------------------------------------------------------------- /eigenpro2/utils/eigh.py: -------------------------------------------------------------------------------- 1 | '''Utility functions for performing fast SVD.''' 2 | # import scipy.linalg as linalg 3 | import torch, math 4 | 5 | 6 | def nystrom_kernel_eigh(samples, kernel_fn, top_q): 7 | """Compute top eigensystem of kernel matrix using Nystrom method. 8 | 9 | Arguments: 10 | samples: data matrix of shape (n_sample, n_feature). 11 | kernel_fn: tensor function k(X, Y) that returns kernel matrix. 12 | top_q: top-q eigensystem. 13 | 14 | Returns: 15 | eigvals: top eigenvalues of shape (top_q). 16 | eigvecs: (rescaled) top eigenvectors of shape (n_sample, top_q). 17 | """ 18 | 19 | n_sample, _ = samples.shape 20 | samples_ = samples #.cpu() 21 | kmat = kernel_fn(samples_, samples_) 22 | scaled_kmat = kmat / n_sample 23 | vals, vecs = torch.lobpcg(scaled_kmat, min(top_q+1, n_sample//3)) 24 | # vals, vecs = linalg.eigh(scaled_kmat, 25 | # eigvals=(n_sample - top_q, n_sample - 1)) 26 | #eigvals = torch.from_numpy(vals).flip(0) 27 | #eigvecs = torch.from_numpy(vecs).fliplr()/sqrt(n_sample) 28 | beta = kmat.diag().max() 29 | 30 | # return eigvals.float(), eigvecs.float(), beta 31 | return vals, vecs/math.sqrt(n_sample), beta 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import eigenpro2 3 | 4 | with open("README.md", "r", encoding="utf-8") as fh: 5 | long_description = fh.read() 6 | 7 | setup( 8 | name='eigenpro2', 9 | version=eigenpro2.__version__, 10 | author='Siyuan Ma, Adityanarayanan Radhakrishnan, Parthe Pandit', 11 | author_email='parthe1292@gmail.com', 12 | description='Fast solver for Kernel Regression using GPUs with linear space and time complexity', 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | url='https://github.com/EigenPro/EigenPro-pytorch/tree/pytorch', 16 | project_urls = { 17 | "Bug Tracker": "https://github.com/EigenPro/EigenPro-pytorch/issues" 18 | }, 19 | license='Apache-2.0 license', 20 | packages=find_packages(), 21 | install_requires=[], 22 | ) 23 | --------------------------------------------------------------------------------